//go:build mlx

package qwen_image

import (
	"math"
	"os"
	"testing"

	"github.com/ollama/ollama/x/imagegen/mlx"
)

// TestTransformerConfig tests configuration invariants.
func TestTransformerConfig(t *testing.T) {
	cfg := defaultTransformerConfig()

	// Property: hidden_dim = n_heads * head_dim
	if cfg.HiddenDim != cfg.NHeads*cfg.HeadDim {
		t.Errorf("hidden_dim != n_heads * head_dim: %d != %d * %d",
			cfg.HiddenDim, cfg.NHeads, cfg.HeadDim)
	}

	// Property: axes_dims_rope sums to head_dim
	var ropeSum int32
	for _, d := range cfg.AxesDimsRope {
		ropeSum += d
	}
	if ropeSum != cfg.HeadDim {
		t.Errorf("axes_dims_rope sum != head_dim: %d != %d", ropeSum, cfg.HeadDim)
	}

	// Property: in_channels = out_channels * patch_size^2
	expectedIn := cfg.OutChannels * cfg.PatchSize * cfg.PatchSize
	if cfg.InChannels != expectedIn {
		t.Errorf("in_channels != out_channels * patch_size^2: %d != %d", cfg.InChannels, expectedIn)
	}
}
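
// configInvariantsExample is an illustrative sketch (added here for clarity;
// it is not exercised by TestTransformerConfig) that walks the three
// invariants above with hypothetical numbers. The real values come from
// defaultTransformerConfig and may differ.
func configInvariantsExample() (hiddenDim, ropeSum, inChannels int32) {
	nHeads, headDim := int32(24), int32(128) // hypothetical head layout
	hiddenDim = nHeads * headDim             // 24 * 128 = 3072

	for _, d := range []int32{16, 56, 56} { // hypothetical rope axes
		ropeSum += d // 16 + 56 + 56 = 128, matching head_dim
	}

	patchSize, outChannels := int32(2), int32(16)    // hypothetical patching
	inChannels = outChannels * patchSize * patchSize // 16 * 2 * 2 = 64

	return hiddenDim, ropeSum, inChannels
}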

// TestTransformerRoPE tests that RoPE frequency computation produces valid values.
func TestTransformerRoPE(t *testing.T) {
	cfg := defaultTransformerConfig()

	// Test with small image dimensions
	imgH, imgW := int32(4), int32(4) // 4x4 latent = 16 patches
	txtLen := int32(5)

	ropeCache := PrepareRoPE(imgH, imgW, txtLen, cfg.AxesDimsRope)
	mlx.Eval(ropeCache.ImgFreqs, ropeCache.TxtFreqs)

	// Verify shapes: [seq_len, head_dim]
	imgSeqLen := imgH * imgW
	if ropeCache.ImgFreqs.Shape()[0] != imgSeqLen {
		t.Errorf("ImgFreqs seq_len: got %d, want %d", ropeCache.ImgFreqs.Shape()[0], imgSeqLen)
	}
	if ropeCache.ImgFreqs.Shape()[1] != cfg.HeadDim {
		t.Errorf("ImgFreqs head_dim: got %d, want %d", ropeCache.ImgFreqs.Shape()[1], cfg.HeadDim)
	}

	if ropeCache.TxtFreqs.Shape()[0] != txtLen {
		t.Errorf("TxtFreqs seq_len: got %d, want %d", ropeCache.TxtFreqs.Shape()[0], txtLen)
	}

	// Verify values are finite
	imgData := ropeCache.ImgFreqs.Data()
	for i := 0; i < min(100, len(imgData)); i++ {
		if math.IsNaN(float64(imgData[i])) || math.IsInf(float64(imgData[i]), 0) {
			t.Errorf("ImgFreqs[%d] not finite: %v", i, imgData[i])
			break
		}
	}
}
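
// ropeAxisAngles is a conceptual sketch of standard RoPE angles for a single
// position axis (added for illustration; the actual math and layout inside
// PrepareRoPE are assumptions here, not confirmed by this file). For position
// p and pair index i in [0, d/2), the angle is p / base^(2i/d); each angle
// fills a cos/sin pair, so a row holds d entries. Concatenating rows across
// axes whose dims sum to head_dim yields the [seq_len, head_dim] shape
// checked above.
func ropeAxisAngles(positions []int32, d int32, base float64) [][]float32 {
	rows := make([][]float32, len(positions))
	for pi, p := range positions {
		row := make([]float32, d)
		for i := int32(0); i < d/2; i++ {
			theta := float64(p) / math.Pow(base, float64(2*i)/float64(d))
			row[2*i] = float32(theta)   // cos slot
			row[2*i+1] = float32(theta) // sin slot
		}
		rows[pi] = row
	}
	return rows
}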

// TestTransformerForward tests the full forward pass (integration test).
// It skips if model weights are not available.
func TestTransformerForward(t *testing.T) {
	weightsPath := "../../../weights/Qwen-Image-2512/transformer"
	if _, err := os.Stat(weightsPath); os.IsNotExist(err) {
		t.Skip("Skipping: model weights not found at " + weightsPath)
	}

	transformer := &Transformer{}
	if err := transformer.Load(weightsPath); err != nil {
		t.Fatalf("Failed to load transformer: %v", err)
	}
	mlx.Keep(mlx.Collect(transformer)...)
	cfg := transformer.Config

	// Small test inputs
	batchSize := int32(1)
	imgH, imgW := int32(4), int32(4)
	imgSeqLen := imgH * imgW
	txtSeqLen := int32(5)

	hiddenStates := mlx.RandomNormal([]int32{batchSize, imgSeqLen, cfg.InChannels}, 0)
	encoderHiddenStates := mlx.RandomNormal([]int32{batchSize, txtSeqLen, cfg.JointAttentionDim}, 0)
	timestep := mlx.NewArray([]float32{0.5}, []int32{batchSize})

	ropeCache := PrepareRoPE(imgH, imgW, txtSeqLen, cfg.AxesDimsRope)

	// Forward pass
	out := transformer.Forward(hiddenStates, encoderHiddenStates, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
	mlx.Eval(out)

	// Verify output shape: [batch, img_seq_len, in_channels]
	wantShape := []int32{batchSize, imgSeqLen, cfg.InChannels}
	gotShape := out.Shape()
	if gotShape[0] != wantShape[0] || gotShape[1] != wantShape[1] || gotShape[2] != wantShape[2] {
		t.Errorf("output shape: got %v, want %v", gotShape, wantShape)
	}

	// Verify output is finite
	outData := out.Data()
	for i := 0; i < min(100, len(outData)); i++ {
		if math.IsNaN(float64(outData[i])) || math.IsInf(float64(outData[i]), 0) {
			t.Errorf("output[%d] not finite: %v", i, outData[i])
			break
		}
	}
}
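
// checkFinite is a suggested helper (added as a sketch; the tests above keep
// their inline loops) that factors out the repeated NaN/Inf scan over the
// first n values, assuming Data() returns []float32 as used above.
// Usage: checkFinite(t, "output", out.Data(), 100)
func checkFinite(t *testing.T, name string, data []float32, n int) {
	t.Helper()
	for i := 0; i < min(n, len(data)); i++ {
		v := float64(data[i])
		if math.IsNaN(v) || math.IsInf(v, 0) {
			t.Errorf("%s[%d] not finite: %v", name, i, data[i])
			return
		}
	}
}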