ollama/x/imagegen/nn/nn_test.go
Commit 33ee7168ba by Daniel Hiltgen: Add experimental MLX backend and engine with imagegen support (#13648)
* WIP - MLX backend with gemma3

* MLX: add cmake and go tag build toggles

To build the new MLX backend code:
  cmake --preset MLX
  cmake --build --preset MLX --parallel
  cmake --install build --component MLX
  go build -tags mlx .
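
To run the tag-gated tests (for example, the nn tests in this file), a standard invocation is:
  go test -tags mlx ./x/imagegen/nn/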

Note: the main.go entrypoint for the MLX engine will change in a follow up commit.

* add experimental image generation runtime

* MLX: wire up cuda build for linux

* MLX: get dependencies correct and dedup

This is still too large for a unified GitHub artifact, but is now "correct" for the mlx_cuda_v13 directory.

* fix relative link bug in dedup

* Add darwin build and readme

* add go build tag for mlx dependent code and wire up build_darwin.sh

* lint cleanup

* macos: build mlx for x86

This will be CPU only.

* cuda build instructions and fix drift from mlx bump

* stale comment

* Delete agent helper doc

* Clean up readme.md

* Revise README for tokenizer clarity and details

Updated README to clarify tokenizer functionality and removed correctness section.

---------

Co-authored-by: jmorganca <jmorganca@gmail.com>
2026-01-08 16:18:59 -08:00

//go:build mlx

package nn

import (
	"math"
	"testing"

	"github.com/ollama/ollama/x/imagegen/mlx"
)

// TestLinearNoBias verifies Linear without bias computes x @ w.T correctly.
func TestLinearNoBias(t *testing.T) {
	// Weight: [out=2, in=3] -> transposed at forward time
	weight := mlx.NewArrayFloat32([]float32{
		1, 2, 3, // row 0
		4, 5, 6, // row 1
	}, []int32{2, 3})
	mlx.Eval(weight)
	linear := NewLinear(weight, nil)
	// Input: [1, 3]
	x := mlx.NewArrayFloat32([]float32{1, 1, 1}, []int32{1, 3})
	mlx.Eval(x)
	out := linear.Forward(x)
	mlx.Eval(out)
	// Expected: [1,1,1] @ [[1,4],[2,5],[3,6]] = [6, 15]
	data := out.Data()
	if len(data) != 2 || data[0] != 6 || data[1] != 15 {
		t.Errorf("expected [6, 15], got %v", data)
	}
}

// TestLinearWithBias verifies Linear with bias computes x @ w.T + b correctly.
func TestLinearWithBias(t *testing.T) {
	weight := mlx.NewArrayFloat32([]float32{
		1, 2, 3,
		4, 5, 6,
	}, []int32{2, 3})
	bias := mlx.NewArrayFloat32([]float32{10, 20}, []int32{2})
	mlx.Eval(weight, bias)
	linear := NewLinear(weight, bias)
	x := mlx.NewArrayFloat32([]float32{1, 1, 1}, []int32{1, 3})
	mlx.Eval(x)
	out := linear.Forward(x)
	mlx.Eval(out)
	// Expected: [6, 15] + [10, 20] = [16, 35]
	data := out.Data()
	if len(data) != 2 || data[0] != 16 || data[1] != 35 {
		t.Errorf("expected [16, 35], got %v", data)
	}
}

// TestLinearBatched verifies Linear works with batched input.
func TestLinearBatched(t *testing.T) {
	weight := mlx.NewArrayFloat32([]float32{
		1, 0,
		0, 1,
	}, []int32{2, 2}) // Identity
	mlx.Eval(weight)
	linear := NewLinear(weight, nil)
	// Batch of 3 inputs
	x := mlx.NewArrayFloat32([]float32{
		1, 2,
		3, 4,
		5, 6,
	}, []int32{3, 2})
	mlx.Eval(x)
	out := linear.Forward(x)
	mlx.Eval(out)
	// Identity should return same values
	data := out.Data()
	expected := []float32{1, 2, 3, 4, 5, 6}
	for i, v := range expected {
		if data[i] != v {
			t.Errorf("at %d: expected %f, got %f", i, v, data[i])
		}
	}
}

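// Note (editor addition): Linear stores its weight as [out_features, in_features]
// and transposes it in Forward, so all three tests above follow the x @ w.T
// convention spelled out in the comments.
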
// TestRMSNorm verifies RMSNorm computation.
func TestRMSNorm(t *testing.T) {
	weight := mlx.NewArrayFloat32([]float32{1, 1, 1, 1}, []int32{4})
	mlx.Eval(weight)
	norm := NewRMSNorm(weight, 1e-5)
	// Input with known RMS
	x := mlx.NewArrayFloat32([]float32{2, 2, 2, 2}, []int32{1, 4})
	mlx.Eval(x)
	out := norm.Forward(x, 0) // eps=0 uses stored Eps
	mlx.Eval(out)
	// RMS of [2,2,2,2] = 2, so normalized = [1,1,1,1]
	data := out.Data()
	for i, v := range data {
		if math.Abs(float64(v-1.0)) > 1e-4 {
			t.Errorf("at %d: expected ~1.0, got %f", i, v)
		}
	}
}

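// Note (editor addition): RMSNorm computes y = x / sqrt(mean(x*x) + eps) * weight.
// The test above checks normalization with a unit weight; the next one checks
// that the weight scaling is applied.
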
// TestRMSNormWithScale verifies RMSNorm applies weight scaling.
func TestRMSNormWithScale(t *testing.T) {
	weight := mlx.NewArrayFloat32([]float32{2, 2, 2, 2}, []int32{4})
	mlx.Eval(weight)
	norm := NewRMSNorm(weight, 1e-5)
	x := mlx.NewArrayFloat32([]float32{2, 2, 2, 2}, []int32{1, 4})
	mlx.Eval(x)
	out := norm.Forward(x, 0) // eps=0 uses stored Eps
	mlx.Eval(out)
	// Normalized [1,1,1,1] * weight [2,2,2,2] = [2,2,2,2]
	data := out.Data()
	for i, v := range data {
		if math.Abs(float64(v-2.0)) > 1e-4 {
			t.Errorf("at %d: expected ~2.0, got %f", i, v)
		}
	}
}

// TestEmbedding verifies embedding lookup.
func TestEmbedding(t *testing.T) {
	// Embedding table: 4 tokens, dim 3
	weight := mlx.NewArrayFloat32([]float32{
		0, 0, 0, // token 0
		1, 1, 1, // token 1
		2, 2, 2, // token 2
		3, 3, 3, // token 3
	}, []int32{4, 3})
	mlx.Eval(weight)
	emb := NewEmbedding(weight)
	// Look up tokens [1, 3, 0]
	indices := mlx.NewArrayInt32([]int32{1, 3, 0}, []int32{3})
	mlx.Eval(indices)
	out := emb.Forward(indices)
	mlx.Eval(out)
	data := out.Data()
	expected := []float32{1, 1, 1, 3, 3, 3, 0, 0, 0}
	for i, v := range expected {
		if data[i] != v {
			t.Errorf("at %d: expected %f, got %f", i, v, data[i])
		}
	}
}

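// Note (editor addition): embedding Forward is a row gather, out[i] = weight[indices[i]],
// which is why the expected values above are the table rows for tokens 1, 3, and 0.
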
// TestRepeatKV verifies K/V repetition for GQA.
func TestRepeatKV(t *testing.T) {
	// [B=1, num_kv_heads=2, S=2, head_dim=2]
	x := mlx.NewArrayFloat32([]float32{
		// head 0
		1, 2, // pos 0
		3, 4, // pos 1
		// head 1
		5, 6, // pos 0
		7, 8, // pos 1
	}, []int32{1, 2, 2, 2})
	mlx.Eval(x)
	// Repeat factor 2: 2 kv heads -> 4 heads
	out := RepeatKV(x, 2)
	mlx.Eval(out)
	shape := out.Shape()
	if shape[0] != 1 || shape[1] != 4 || shape[2] != 2 || shape[3] != 2 {
		t.Errorf("expected shape [1,4,2,2], got %v", shape)
	}
	data := out.Data()
	// After repeat: head0, head0, head1, head1
	expected := []float32{
		1, 2, 3, 4, // head 0 (original)
		1, 2, 3, 4, // head 0 (repeat)
		5, 6, 7, 8, // head 1 (original)
		5, 6, 7, 8, // head 1 (repeat)
	}
	for i, v := range expected {
		if data[i] != v {
			t.Errorf("at %d: expected %f, got %f", i, v, data[i])
		}
	}
}

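// Note (editor addition): in grouped-query attention each K/V head is shared by a
// group of query heads. RepeatKV tiles the K/V heads along the head axis so they
// match the query head count, giving the head0, head0, head1, head1 layout above.
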
// TestRepeatKVNoOp verifies RepeatKV with factor 1 returns input unchanged.
func TestRepeatKVNoOp(t *testing.T) {
	x := mlx.NewArrayFloat32([]float32{1, 2, 3, 4}, []int32{1, 1, 2, 2})
	mlx.Eval(x)
	out := RepeatKV(x, 1)
	// Should return same pointer
	if out != x {
		t.Error("RepeatKV with factor 1 should return input unchanged")
	}
}

// TestApplyCausalMask verifies causal masking.
func TestApplyCausalMask(t *testing.T) {
	// [B=1, heads=1, S=3, S=3] - all ones
	scores := mlx.Ones(1, 1, 3, 3)
	mlx.Eval(scores)
	out := ApplyCausalMask(scores)
	mlx.Eval(out)
	data := out.Data()
	// Lower triangular should be 1, upper should be -1e9
	// Row 0: [1, -inf, -inf]
	// Row 1: [1, 1, -inf]
	// Row 2: [1, 1, 1]
	if data[0] != 1 || data[1] >= 0 || data[2] >= 0 {
		t.Errorf("row 0 wrong: %v", data[0:3])
	}
	if data[3] != 1 || data[4] != 1 || data[5] >= 0 {
		t.Errorf("row 1 wrong: %v", data[3:6])
	}
	if data[6] != 1 || data[7] != 1 || data[8] != 1 {
		t.Errorf("row 2 wrong: %v", data[6:9])
	}
}

// TestApplyCausalMaskWithOffset verifies causal masking with cache offset.
func TestApplyCausalMaskWithOffset(t *testing.T) {
	// Simulating: cache has 2 tokens, adding 1 new query
	// scores: [B=1, heads=1, queryLen=1, keyLen=3]
	scores := mlx.Ones(1, 1, 1, 3)
	mlx.Eval(scores)
	out := ApplyCausalMaskWithOffset(scores, 2)
	mlx.Eval(out)
	data := out.Data()
	// With offset=2, query at position 2 can attend to all 3 positions
	if data[0] != 1 || data[1] != 1 || data[2] != 1 {
		t.Errorf("expected [1, 1, 1], got %v", data)
	}
}

// TestApplyCausalMaskWithOffsetZero verifies offset=0 falls back to regular causal.
func TestApplyCausalMaskWithOffsetZero(t *testing.T) {
	scores := mlx.Ones(1, 1, 2, 2)
	mlx.Eval(scores)
	out := ApplyCausalMaskWithOffset(scores, 0)
	mlx.Eval(out)
	data := out.Data()
	// Standard causal: [1, -inf], [1, 1]
	if data[0] != 1 || data[1] >= 0 {
		t.Errorf("row 0 wrong: %v", data[0:2])
	}
	if data[2] != 1 || data[3] != 1 {
		t.Errorf("row 1 wrong: %v", data[2:4])
	}
}

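// Note (editor addition): with a KV cache already holding offset tokens, query row i
// sits at absolute position offset+i and may attend to key positions 0..offset+i;
// offset=0 therefore reduces to the standard causal mask, as verified above.
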
// BenchmarkLinearSmall benchmarks small Linear forward pass.
func BenchmarkLinearSmall(b *testing.B) {
	weight := mlx.RandomNormal([]int32{256, 256}, 42)
	mlx.Eval(weight)
	linear := NewLinear(weight, nil)
	x := mlx.RandomNormal([]int32{1, 256}, 43)
	mlx.Eval(x)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		out := linear.Forward(x)
		mlx.Eval(out)
	}
}

// BenchmarkLinearLarge benchmarks larger Linear forward pass.
func BenchmarkLinearLarge(b *testing.B) {
	weight := mlx.RandomNormal([]int32{4096, 4096}, 42)
	mlx.Eval(weight)
	linear := NewLinear(weight, nil)
	x := mlx.RandomNormal([]int32{1, 4096}, 43)
	mlx.Eval(x)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		out := linear.Forward(x)
		mlx.Eval(out)
	}
}

// BenchmarkRMSNorm benchmarks RMSNorm forward pass.
func BenchmarkRMSNorm(b *testing.B) {
	weight := mlx.Ones(4096)
	mlx.Eval(weight)
	norm := NewRMSNorm(weight, 1e-5)
	x := mlx.RandomNormal([]int32{1, 4096}, 42)
	mlx.Eval(x)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		out := norm.Forward(x, 0)
		mlx.Eval(out)
	}
}

// BenchmarkEmbedding benchmarks embedding lookup.
func BenchmarkEmbedding(b *testing.B) {
	// Typical vocab size
	weight := mlx.RandomNormal([]int32{32000, 4096}, 42)
	mlx.Eval(weight)
	emb := NewEmbedding(weight)
	// Single token lookup
	indices := mlx.NewArrayInt32([]int32{1000}, []int32{1})
	mlx.Eval(indices)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		out := emb.Forward(indices)
		mlx.Eval(out)
	}
}

// BenchmarkRepeatKV benchmarks K/V repetition.
func BenchmarkRepeatKV(b *testing.B) {
	// Typical GQA setup: 8 kv heads -> 32 heads
	x := mlx.RandomNormal([]int32{1, 8, 512, 128}, 42)
	mlx.Eval(x)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		out := RepeatKV(x, 4)
		mlx.Eval(out)
	}
}

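// TestPipelineSketch is an editor addition, not part of the original file: a
// minimal sketch composing the primitives tested above (Embedding -> RMSNorm ->
// Linear), using only APIs already exercised in this file. The function name is
// hypothetical.
func TestPipelineSketch(t *testing.T) {
	// Embedding table: 4 tokens, dim 2; token 2 maps to [2, 2].
	embWeight := mlx.NewArrayFloat32([]float32{
		0, 0,
		1, 1,
		2, 2,
		3, 3,
	}, []int32{4, 2})
	// Unit RMSNorm weight and an identity projection [out=2, in=2].
	normWeight := mlx.NewArrayFloat32([]float32{1, 1}, []int32{2})
	projWeight := mlx.NewArrayFloat32([]float32{1, 0, 0, 1}, []int32{2, 2})
	mlx.Eval(embWeight, normWeight, projWeight)
	emb := NewEmbedding(embWeight)
	norm := NewRMSNorm(normWeight, 1e-5)
	proj := NewLinear(projWeight, nil)
	// Look up token 2 -> [2, 2]; RMS of [2, 2] is 2, so the norm yields [1, 1];
	// the identity projection leaves that unchanged.
	ids := mlx.NewArrayInt32([]int32{2}, []int32{1})
	mlx.Eval(ids)
	h := norm.Forward(emb.Forward(ids), 0)
	out := proj.Forward(h)
	mlx.Eval(out)
	data := out.Data()
	for i, v := range data {
		if math.Abs(float64(v-1.0)) > 1e-4 {
			t.Errorf("at %d: expected ~1.0, got %f", i, v)
		}
	}
}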