mirror of
https://github.com/ollama/ollama.git
synced 2026-01-12 00:06:57 +08:00
* WIP - MLX backend with gemma3 * MLX: add cmake and go tag build toggles To build the new MLX backend code: cmake --preset MLX cmake --build --preset MLX --parallel cmake --install build --component MLX go build -tags mlx . Note: the main.go entrypoint for the MLX engine will change in a follow up commit. * add experimental image generation runtime * add experimental image generation runtime * MLX: wire up cuda build for linux * MLX: get dependencies correct and dedup This is still too large for a unified github artifact, but is now "correct" for the mlx_cuda_v13 directory. * fix relative link bug in dedup * Add darwin build and readme * add go build tag for mlx dependent code and wire up build_darwin.sh * lint cleanup * macos: build mlx for x86 This will be CPU only. * cuda build instructions and fix drift from mlx bump * stale comment * Delete agent helper doc * Clean up readme.md * Revise README for tokenizer clarity and details Updated README to clarify tokenizer functionality and removed correctness section. --------- Co-authored-by: jmorganca <jmorganca@gmail.com>
357 lines
8.2 KiB
Go
357 lines
8.2 KiB
Go
//go:build mlx
|
|
|
|
package nn
|
|
|
|
import (
|
|
"math"
|
|
"testing"
|
|
|
|
"github.com/ollama/ollama/x/imagegen/mlx"
|
|
)
|
|
|
|
// TestLinearNoBias verifies Linear without bias computes x @ w.T correctly.
|
|
func TestLinearNoBias(t *testing.T) {
|
|
// Weight: [out=2, in=3] -> transposed at forward time
|
|
weight := mlx.NewArrayFloat32([]float32{
|
|
1, 2, 3, // row 0
|
|
4, 5, 6, // row 1
|
|
}, []int32{2, 3})
|
|
mlx.Eval(weight)
|
|
|
|
linear := NewLinear(weight, nil)
|
|
|
|
// Input: [1, 3]
|
|
x := mlx.NewArrayFloat32([]float32{1, 1, 1}, []int32{1, 3})
|
|
mlx.Eval(x)
|
|
|
|
out := linear.Forward(x)
|
|
mlx.Eval(out)
|
|
|
|
// Expected: [1,1,1] @ [[1,4],[2,5],[3,6]] = [6, 15]
|
|
data := out.Data()
|
|
if len(data) != 2 || data[0] != 6 || data[1] != 15 {
|
|
t.Errorf("expected [6, 15], got %v", data)
|
|
}
|
|
}
|
|
|
|
// TestLinearWithBias verifies Linear with bias computes x @ w.T + b correctly.
|
|
func TestLinearWithBias(t *testing.T) {
|
|
weight := mlx.NewArrayFloat32([]float32{
|
|
1, 2, 3,
|
|
4, 5, 6,
|
|
}, []int32{2, 3})
|
|
bias := mlx.NewArrayFloat32([]float32{10, 20}, []int32{2})
|
|
mlx.Eval(weight, bias)
|
|
|
|
linear := NewLinear(weight, bias)
|
|
|
|
x := mlx.NewArrayFloat32([]float32{1, 1, 1}, []int32{1, 3})
|
|
mlx.Eval(x)
|
|
|
|
out := linear.Forward(x)
|
|
mlx.Eval(out)
|
|
|
|
// Expected: [6, 15] + [10, 20] = [16, 35]
|
|
data := out.Data()
|
|
if len(data) != 2 || data[0] != 16 || data[1] != 35 {
|
|
t.Errorf("expected [16, 35], got %v", data)
|
|
}
|
|
}
|
|
|
|
// TestLinearBatched verifies Linear works with batched input.
|
|
func TestLinearBatched(t *testing.T) {
|
|
weight := mlx.NewArrayFloat32([]float32{
|
|
1, 0,
|
|
0, 1,
|
|
}, []int32{2, 2}) // Identity
|
|
mlx.Eval(weight)
|
|
|
|
linear := NewLinear(weight, nil)
|
|
|
|
// Batch of 3 inputs
|
|
x := mlx.NewArrayFloat32([]float32{
|
|
1, 2,
|
|
3, 4,
|
|
5, 6,
|
|
}, []int32{3, 2})
|
|
mlx.Eval(x)
|
|
|
|
out := linear.Forward(x)
|
|
mlx.Eval(out)
|
|
|
|
// Identity should return same values
|
|
data := out.Data()
|
|
expected := []float32{1, 2, 3, 4, 5, 6}
|
|
for i, v := range expected {
|
|
if data[i] != v {
|
|
t.Errorf("at %d: expected %f, got %f", i, v, data[i])
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestRMSNorm verifies RMSNorm computation.
|
|
func TestRMSNorm(t *testing.T) {
|
|
weight := mlx.NewArrayFloat32([]float32{1, 1, 1, 1}, []int32{4})
|
|
mlx.Eval(weight)
|
|
|
|
norm := NewRMSNorm(weight, 1e-5)
|
|
|
|
// Input with known RMS
|
|
x := mlx.NewArrayFloat32([]float32{2, 2, 2, 2}, []int32{1, 4})
|
|
mlx.Eval(x)
|
|
|
|
out := norm.Forward(x, 0) // eps=0 uses stored Eps
|
|
mlx.Eval(out)
|
|
|
|
// RMS of [2,2,2,2] = 2, so normalized = [1,1,1,1]
|
|
data := out.Data()
|
|
for i, v := range data {
|
|
if math.Abs(float64(v-1.0)) > 1e-4 {
|
|
t.Errorf("at %d: expected ~1.0, got %f", i, v)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestRMSNormWithScale verifies RMSNorm applies weight scaling.
|
|
func TestRMSNormWithScale(t *testing.T) {
|
|
weight := mlx.NewArrayFloat32([]float32{2, 2, 2, 2}, []int32{4})
|
|
mlx.Eval(weight)
|
|
|
|
norm := NewRMSNorm(weight, 1e-5)
|
|
|
|
x := mlx.NewArrayFloat32([]float32{2, 2, 2, 2}, []int32{1, 4})
|
|
mlx.Eval(x)
|
|
|
|
out := norm.Forward(x, 0) // eps=0 uses stored Eps
|
|
mlx.Eval(out)
|
|
|
|
// Normalized [1,1,1,1] * weight [2,2,2,2] = [2,2,2,2]
|
|
data := out.Data()
|
|
for i, v := range data {
|
|
if math.Abs(float64(v-2.0)) > 1e-4 {
|
|
t.Errorf("at %d: expected ~2.0, got %f", i, v)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestEmbedding verifies embedding lookup.
|
|
func TestEmbedding(t *testing.T) {
|
|
// Embedding table: 4 tokens, dim 3
|
|
weight := mlx.NewArrayFloat32([]float32{
|
|
0, 0, 0, // token 0
|
|
1, 1, 1, // token 1
|
|
2, 2, 2, // token 2
|
|
3, 3, 3, // token 3
|
|
}, []int32{4, 3})
|
|
mlx.Eval(weight)
|
|
|
|
emb := NewEmbedding(weight)
|
|
|
|
// Look up tokens [1, 3, 0]
|
|
indices := mlx.NewArrayInt32([]int32{1, 3, 0}, []int32{3})
|
|
mlx.Eval(indices)
|
|
|
|
out := emb.Forward(indices)
|
|
mlx.Eval(out)
|
|
|
|
data := out.Data()
|
|
expected := []float32{1, 1, 1, 3, 3, 3, 0, 0, 0}
|
|
for i, v := range expected {
|
|
if data[i] != v {
|
|
t.Errorf("at %d: expected %f, got %f", i, v, data[i])
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestRepeatKV verifies K/V repetition for GQA.
|
|
func TestRepeatKV(t *testing.T) {
|
|
// [B=1, num_kv_heads=2, S=2, head_dim=2]
|
|
x := mlx.NewArrayFloat32([]float32{
|
|
// head 0
|
|
1, 2, // pos 0
|
|
3, 4, // pos 1
|
|
// head 1
|
|
5, 6, // pos 0
|
|
7, 8, // pos 1
|
|
}, []int32{1, 2, 2, 2})
|
|
mlx.Eval(x)
|
|
|
|
// Repeat factor 2: 2 kv heads -> 4 heads
|
|
out := RepeatKV(x, 2)
|
|
mlx.Eval(out)
|
|
|
|
shape := out.Shape()
|
|
if shape[0] != 1 || shape[1] != 4 || shape[2] != 2 || shape[3] != 2 {
|
|
t.Errorf("expected shape [1,4,2,2], got %v", shape)
|
|
}
|
|
|
|
data := out.Data()
|
|
// After repeat: head0, head0, head1, head1
|
|
expected := []float32{
|
|
1, 2, 3, 4, // head 0 (original)
|
|
1, 2, 3, 4, // head 0 (repeat)
|
|
5, 6, 7, 8, // head 1 (original)
|
|
5, 6, 7, 8, // head 1 (repeat)
|
|
}
|
|
for i, v := range expected {
|
|
if data[i] != v {
|
|
t.Errorf("at %d: expected %f, got %f", i, v, data[i])
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestRepeatKVNoOp verifies RepeatKV with factor 1 returns input unchanged.
|
|
func TestRepeatKVNoOp(t *testing.T) {
|
|
x := mlx.NewArrayFloat32([]float32{1, 2, 3, 4}, []int32{1, 1, 2, 2})
|
|
mlx.Eval(x)
|
|
|
|
out := RepeatKV(x, 1)
|
|
// Should return same pointer
|
|
if out != x {
|
|
t.Error("RepeatKV with factor 1 should return input unchanged")
|
|
}
|
|
}
|
|
|
|
// TestApplyCausalMask verifies causal masking.
|
|
func TestApplyCausalMask(t *testing.T) {
|
|
// [B=1, heads=1, S=3, S=3] - all ones
|
|
scores := mlx.Ones(1, 1, 3, 3)
|
|
mlx.Eval(scores)
|
|
|
|
out := ApplyCausalMask(scores)
|
|
mlx.Eval(out)
|
|
|
|
data := out.Data()
|
|
// Lower triangular should be 1, upper should be -1e9
|
|
// Row 0: [1, -inf, -inf]
|
|
// Row 1: [1, 1, -inf]
|
|
// Row 2: [1, 1, 1]
|
|
if data[0] != 1 || data[1] >= 0 || data[2] >= 0 {
|
|
t.Errorf("row 0 wrong: %v", data[0:3])
|
|
}
|
|
if data[3] != 1 || data[4] != 1 || data[5] >= 0 {
|
|
t.Errorf("row 1 wrong: %v", data[3:6])
|
|
}
|
|
if data[6] != 1 || data[7] != 1 || data[8] != 1 {
|
|
t.Errorf("row 2 wrong: %v", data[6:9])
|
|
}
|
|
}
|
|
|
|
// TestApplyCausalMaskWithOffset verifies causal masking with cache offset.
|
|
func TestApplyCausalMaskWithOffset(t *testing.T) {
|
|
// Simulating: cache has 2 tokens, adding 1 new query
|
|
// scores: [B=1, heads=1, queryLen=1, keyLen=3]
|
|
scores := mlx.Ones(1, 1, 1, 3)
|
|
mlx.Eval(scores)
|
|
|
|
out := ApplyCausalMaskWithOffset(scores, 2)
|
|
mlx.Eval(out)
|
|
|
|
data := out.Data()
|
|
// With offset=2, query at position 2 can attend to all 3 positions
|
|
if data[0] != 1 || data[1] != 1 || data[2] != 1 {
|
|
t.Errorf("expected [1, 1, 1], got %v", data)
|
|
}
|
|
}
|
|
|
|
// TestApplyCausalMaskWithOffsetZero verifies offset=0 falls back to regular causal.
|
|
func TestApplyCausalMaskWithOffsetZero(t *testing.T) {
|
|
scores := mlx.Ones(1, 1, 2, 2)
|
|
mlx.Eval(scores)
|
|
|
|
out := ApplyCausalMaskWithOffset(scores, 0)
|
|
mlx.Eval(out)
|
|
|
|
data := out.Data()
|
|
// Standard causal: [1, -inf], [1, 1]
|
|
if data[0] != 1 || data[1] >= 0 {
|
|
t.Errorf("row 0 wrong: %v", data[0:2])
|
|
}
|
|
if data[2] != 1 || data[3] != 1 {
|
|
t.Errorf("row 1 wrong: %v", data[2:4])
|
|
}
|
|
}
|
|
|
|
// BenchmarkLinearSmall benchmarks small Linear forward pass.
|
|
func BenchmarkLinearSmall(b *testing.B) {
|
|
weight := mlx.RandomNormal([]int32{256, 256}, 42)
|
|
mlx.Eval(weight)
|
|
|
|
linear := NewLinear(weight, nil)
|
|
|
|
x := mlx.RandomNormal([]int32{1, 256}, 43)
|
|
mlx.Eval(x)
|
|
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
out := linear.Forward(x)
|
|
mlx.Eval(out)
|
|
}
|
|
}
|
|
|
|
// BenchmarkLinearLarge benchmarks larger Linear forward pass.
|
|
func BenchmarkLinearLarge(b *testing.B) {
|
|
weight := mlx.RandomNormal([]int32{4096, 4096}, 42)
|
|
mlx.Eval(weight)
|
|
|
|
linear := NewLinear(weight, nil)
|
|
|
|
x := mlx.RandomNormal([]int32{1, 4096}, 43)
|
|
mlx.Eval(x)
|
|
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
out := linear.Forward(x)
|
|
mlx.Eval(out)
|
|
}
|
|
}
|
|
|
|
// BenchmarkRMSNorm benchmarks RMSNorm forward pass.
|
|
func BenchmarkRMSNorm(b *testing.B) {
|
|
weight := mlx.Ones(4096)
|
|
mlx.Eval(weight)
|
|
|
|
norm := NewRMSNorm(weight, 1e-5)
|
|
|
|
x := mlx.RandomNormal([]int32{1, 4096}, 42)
|
|
mlx.Eval(x)
|
|
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
out := norm.Forward(x, 0)
|
|
mlx.Eval(out)
|
|
}
|
|
}
|
|
|
|
// BenchmarkEmbedding benchmarks embedding lookup.
|
|
func BenchmarkEmbedding(b *testing.B) {
|
|
// Typical vocab size
|
|
weight := mlx.RandomNormal([]int32{32000, 4096}, 42)
|
|
mlx.Eval(weight)
|
|
|
|
emb := NewEmbedding(weight)
|
|
|
|
// Single token lookup
|
|
indices := mlx.NewArrayInt32([]int32{1000}, []int32{1})
|
|
mlx.Eval(indices)
|
|
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
out := emb.Forward(indices)
|
|
mlx.Eval(out)
|
|
}
|
|
}
|
|
|
|
// BenchmarkRepeatKV benchmarks K/V repetition.
|
|
func BenchmarkRepeatKV(b *testing.B) {
|
|
// Typical GQA setup: 8 kv heads -> 32 heads
|
|
x := mlx.RandomNormal([]int32{1, 8, 512, 128}, 42)
|
|
mlx.Eval(x)
|
|
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
out := RepeatKV(x, 4)
|
|
mlx.Eval(out)
|
|
}
|
|
}
|