Files
ollama/x/imagegen/models/qwen_image/scheduler_test.go
Daniel Hiltgen 33ee7168ba Add experimental MLX backend and engine with imagegen support (#13648)
* WIP - MLX backend with gemma3

* MLX: add cmake and go tag build toggles

To build the new MLX backend code:
  cmake --preset MLX
  cmake --build --preset MLX --parallel
  cmake --install build --component MLX
  go build -tags mlx .

Note: the main.go entrypoint for the MLX engine will change in a follow-up commit.

* add experimental image generation runtime

* add experimental image generation runtime

* MLX: wire up cuda build for linux

* MLX: get dependencies correct and dedup

This is still too large for a unified GitHub artifact, but is now "correct" for the mlx_cuda_v13
directory.

* fix relative link bug in dedup

* Add darwin build and readme

* add go build tag for mlx dependent code and wire up build_darwin.sh

* lint cleanup

* macos: build mlx for x86

This will be CPU only.

* cuda build instructions and fix drift from mlx bump

* stale comment

* Delete agent helper doc

* Clean up readme.md

* Revise README for tokenizer clarity and details

Updated README to clarify tokenizer functionality and removed correctness section.

---------

Co-authored-by: jmorganca <jmorganca@gmail.com>
2026-01-08 16:18:59 -08:00

136 lines
4.1 KiB
Go

//go:build mlx
package qwen_image
import (
"math"
"testing"
)
// TestSchedulerSetTimesteps verifies scheduler sigmas match the Python diffusers reference.
//
// The scheduler is configured with DefaultSchedulerConfig and SetTimesteps(30, 4096);
// the first three sigmas, the last three non-terminal sigmas, and the terminal
// sigma are compared against golden values.
//
// Golden values generated via:
//
//	python3 -c "
//	from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
//	import numpy as np
//	s = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, base_shift=0.5, max_shift=0.9,
//	    base_image_seq_len=256, max_image_seq_len=8192, shift_terminal=0.02, use_dynamic_shifting=True)
//	mu = 4096 * (0.9-0.5)/(8192-256) + 0.5 - (0.9-0.5)/(8192-256)*256
//	sigmas = np.linspace(1.0, 1.0/30, 30)
//	s.set_timesteps(sigmas=sigmas, mu=mu)
//	print(s.sigmas.numpy())"
func TestSchedulerSetTimesteps(t *testing.T) {
	const steps = 30

	cfg := DefaultSchedulerConfig()
	scheduler := NewFlowMatchScheduler(cfg)
	scheduler.SetTimesteps(steps, 4096)

	// Check the length before touching any element: every index used below
	// (0..2 and 27..30) is only valid when there are steps+1 entries, and an
	// out-of-range index would panic instead of producing a readable failure.
	if len(scheduler.Sigmas) != steps+1 {
		t.Fatalf("sigmas length: got %d, want %d", len(scheduler.Sigmas), steps+1)
	}

	// Golden values from Python diffusers (first 3, last 3 before terminal)
	wantFirst := []float32{1.000000, 0.982251, 0.963889}
	wantLast := []float32{0.142924, 0.083384, 0.020000}

	// Check first 3
	for i, want := range wantFirst {
		got := scheduler.Sigmas[i]
		if abs32(got-want) > 1e-4 {
			t.Errorf("sigma[%d]: got %v, want %v", i, got, want)
		}
	}

	// Check last 3 before the terminal entry (indices 27, 28, 29)
	for i, want := range wantLast {
		idx := steps - len(wantLast) + i
		got := scheduler.Sigmas[idx]
		if abs32(got-want) > 1e-4 {
			t.Errorf("sigma[%d]: got %v, want %v", idx, got, want)
		}
	}

	// Check terminal is exactly 0
	if scheduler.Sigmas[steps] != 0.0 {
		t.Errorf("terminal sigma: got %v, want 0", scheduler.Sigmas[steps])
	}
}
// TestSchedulerProperties tests mathematical invariants of the scheduler's
// sigma schedule after SetTimesteps(30, 4096): length, ordering, and the
// values at both ends.
func TestSchedulerProperties(t *testing.T) {
	cfg := DefaultSchedulerConfig()
	scheduler := NewFlowMatchScheduler(cfg)
	scheduler.SetTimesteps(30, 4096)

	sigmas := scheduler.Sigmas

	// Property: length = steps + 1
	if len(sigmas) != scheduler.NumSteps+1 {
		t.Errorf("sigmas length should be steps+1: got %d, want %d",
			len(sigmas), scheduler.NumSteps+1)
	}

	// The remaining checks index the first, last, and second-to-last
	// entries; stop with a clear failure rather than an index panic when the
	// schedule is too short for that.
	if len(sigmas) < 2 {
		t.Fatalf("need at least 2 sigmas to check properties, got %d", len(sigmas))
	}

	// Property: sigmas monotonically decreasing (equal neighbours are tolerated)
	for i := 1; i < len(sigmas); i++ {
		if sigmas[i] > sigmas[i-1] {
			t.Errorf("sigmas not monotonically decreasing at %d: %v > %v",
				i, sigmas[i], sigmas[i-1])
		}
	}

	// Property: first sigma should be ~1.0 (with time shift)
	if sigmas[0] < 0.9 || sigmas[0] > 1.01 {
		t.Errorf("first sigma out of expected range [0.9, 1.01]: %v", sigmas[0])
	}

	// Property: terminal sigma should be exactly 0
	if sigmas[len(sigmas)-1] != 0.0 {
		t.Errorf("terminal sigma should be 0, got %v", sigmas[len(sigmas)-1])
	}

	// Property: last non-terminal sigma should be shift_terminal (0.02)
	lastNonTerminal := sigmas[len(sigmas)-2]
	if abs32(lastNonTerminal-0.02) > 1e-5 {
		t.Errorf("last non-terminal sigma should be 0.02, got %v", lastNonTerminal)
	}
}
// TestCalculateShift checks CalculateShift against hand-computed reference
// values of the linear interpolation
//
//	mu = img_seq_len*m + b, where m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
//
// using base_seq_len=256, max_seq_len=8192, base_shift=0.5, max_shift=0.9.
func TestCalculateShift(t *testing.T) {
	// Allowed absolute deviation from the (rounded) reference values.
	const tolerance = 0.001

	type shiftCase struct {
		imgSeqLen int32
		want      float32
	}
	table := []shiftCase{
		{imgSeqLen: 256, want: 0.5},     // base case
		{imgSeqLen: 8192, want: 0.9},    // max case
		{imgSeqLen: 4096, want: 0.6935}, // middle case (rounded)
	}

	for i := range table {
		tc := table[i]
		got := CalculateShift(tc.imgSeqLen, 256, 8192, 0.5, 0.9)
		if diff := abs32(got - tc.want); diff > tolerance {
			t.Errorf("CalculateShift(%d): got %v, want %v", tc.imgSeqLen, got, tc.want)
		}
	}
}
// TestSchedulerStep checks the sign of the first Euler step size.
//
// It computes dt = sigma[1] - sigma[0] from the schedule produced by
// SetTimesteps(30, 4096) and requires it to be negative, i.e. the schedule
// decreases on the first step. It only inspects scheduler.Sigmas; no step
// computation is executed here.
func TestSchedulerStep(t *testing.T) {
	cfg := DefaultSchedulerConfig()
	scheduler := NewFlowMatchScheduler(cfg)
	scheduler.SetTimesteps(30, 4096)

	// Two entries are needed to form dt; fail readably instead of panicking
	// on an out-of-range index if the schedule is shorter than that.
	if len(scheduler.Sigmas) < 2 {
		t.Fatalf("need at least 2 sigmas to compute dt, got %d", len(scheduler.Sigmas))
	}

	// Verify dt calculation for first step
	sigma0 := scheduler.Sigmas[0]
	sigma1 := scheduler.Sigmas[1]
	dt := sigma1 - sigma0

	// dt should be negative (sigmas decrease)
	if dt >= 0 {
		t.Errorf("expected negative dt, got %v (sigma0=%v, sigma1=%v)", dt, sigma0, sigma1)
	}
}
// abs32 returns the absolute value of x as a float32.
// The computation is done in float64 and the result narrowed back.
func abs32(x float32) float32 {
	// Forcing the sign to that of +1 yields the magnitude of x.
	magnitude := math.Copysign(float64(x), 1)
	return float32(magnitude)
}