Files
ollama/x/imagegen/models/qwen_image/pipeline_test.go
Daniel Hiltgen 33ee7168ba Add experimental MLX backend and engine with imagegen support (#13648)
* WIP - MLX backend with gemma3

* MLX: add cmake and go tag build toggles

To build the new MLX backend code:
  cmake --preset MLX
  cmake --build --preset MLX --parallel
  cmake --install build --component MLX
  go build -tags mlx .

Note: the main.go entrypoint for the MLX engine will change in a follow up commit.

* add experimental image generation runtime

* add experimental image generation runtime

* MLX: wire up cuda build for linux

* MLX: get dependencies correct and dedup

This is still too large for a unified github artifact, but is now "correct" for the mlx_cuda_v13
directory.

* fix relative link bug in dedup

* Add darwin build and readme

* add go build tag for mlx dependent code and wire up build_darwin.sh

* lint cleanup

* macos: build mlx for x86

This will be CPU only.

* cuda build instructions and fix drift from mlx bump

* stale comment

* Delete agent helper doc

* Clean up readme.md

* Revise README for tokenizer clarity and details

Updated README to clarify tokenizer functionality and removed correctness section.

---------

Co-authored-by: jmorganca <jmorganca@gmail.com>
2026-01-08 16:18:59 -08:00

67 lines
1.5 KiB
Go

//go:build mlx
package qwen_image
import (
"os"
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// TestPipelineOutput runs the full pipeline (integration test).
// Skips if model weights not found. Requires ~50GB VRAM.
func TestPipelineOutput(t *testing.T) {
modelPath := "../../../weights/Qwen-Image-2512"
if _, err := os.Stat(modelPath); os.IsNotExist(err) {
t.Skip("Skipping: model weights not found at " + modelPath)
}
// Load model
pm, err := LoadPersistent(modelPath)
if err != nil {
t.Skipf("Skipping: failed to load model: %v", err)
}
// Run 2-step pipeline (minimum for stable scheduler)
cfg := &GenerateConfig{
Prompt: "a cat",
Width: 256,
Height: 256,
Steps: 2,
Seed: 42,
}
output, err := pm.GenerateFromConfig(cfg)
if err != nil {
t.Fatalf("Pipeline failed: %v", err)
}
mlx.Eval(output)
// Verify output shape [1, C, H, W]
shape := output.Shape()
if len(shape) != 4 {
t.Errorf("Expected 4D output, got %v", shape)
}
if shape[0] != 1 || shape[1] != 3 || shape[2] != cfg.Height || shape[3] != cfg.Width {
t.Errorf("Shape mismatch: got %v, expected [1, 3, %d, %d]", shape, cfg.Height, cfg.Width)
}
// Verify values in expected range [0, 1]
data := output.Data()
minVal, maxVal := float32(1.0), float32(0.0)
for _, v := range data {
if v < minVal {
minVal = v
}
if v > maxVal {
maxVal = v
}
}
t.Logf("Output range: [%.4f, %.4f]", minVal, maxVal)
if minVal < -0.1 || maxVal > 1.1 {
t.Errorf("Output values out of range: [%.4f, %.4f]", minVal, maxVal)
}
}