//go:build mlx

package mlx

import (
	"fmt"
	"testing"
)
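
// The tests below pin down the binding's ownership contract: Keep marks
// arrays that must survive cleanup; Eval evaluates its arguments,
// auto-keeping them (see TestEvalAutoKeeps), and frees every other
// non-kept array; AsyncEval dispatches work without blocking. The
// canonical per-step pattern, as exercised throughout this file:
//
//	out := Matmul(x, weight) // x is transient: freed at the next Eval
//	Keep(out)                // out survives the cleanup pass
//	Eval(out)                // evaluate, then free everything not kept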

// TestBasicCleanup verifies non-kept arrays are freed and kept arrays survive.
func TestBasicCleanup(t *testing.T) {
	weight := NewArrayFloat32([]float32{1, 2, 3, 4}, []int32{2, 2})
	Keep(weight)
	weight.Eval()

	intermediate := NewArrayFloat32([]float32{1, 1}, []int32{1, 2})
	result := Matmul(intermediate, weight)
	Keep(result)

	// Before eval: intermediate should be valid
	if !intermediate.Valid() {
		t.Fatal("intermediate should be valid before Eval")
	}

	Eval(result)

	// After eval: intermediate should be freed
	if intermediate.Valid() {
		t.Fatal("intermediate should be freed after Eval")
	}

	// Result should have correct values
	data := result.Data()
	if data[0] != 4 || data[1] != 6 {
		t.Errorf("expected [4, 6], got %v", data)
	}

	// Weight should survive
	if !weight.Valid() {
		t.Error("weight was freed")
	}
}

// TestKeptSurvives verifies kept arrays are not freed.
func TestKeptSurvives(t *testing.T) {
	a := NewArrayFloat32([]float32{1, 2}, []int32{2})
	b := NewArrayFloat32([]float32{3, 4}, []int32{2})
	result := Add(a, b)
	Keep(result)

	Eval(result)

	if !result.Valid() {
		t.Error("kept result was freed")
	}

	data := result.Data()
	if data[0] != 4 || data[1] != 6 {
		t.Errorf("expected [4, 6], got %v", data)
	}
}

// TestEvalAutoKeeps verifies Eval automatically keeps its outputs.
func TestEvalAutoKeeps(t *testing.T) {
	a := NewArrayFloat32([]float32{1, 2}, []int32{2})
	b := NewArrayFloat32([]float32{3, 4}, []int32{2})
	result := Add(a, b)

	// Don't call Keep(result) - Eval should auto-keep it
	Eval(result)

	// Result should survive (auto-kept by Eval)
	if !result.Valid() {
		t.Error("Eval output was freed - should be auto-kept")
	}

	// Inputs should be freed (not kept)
	if a.Valid() {
		t.Error("input 'a' should be freed")
	}
	if b.Valid() {
		t.Error("input 'b' should be freed")
	}

	// Verify data is correct
	data := result.Data()
	if data[0] != 4 || data[1] != 6 {
		t.Errorf("expected [4, 6], got %v", data)
	}
}

// TestWeightsSurvive verifies kept arrays survive multiple Eval cycles.
func TestWeightsSurvive(t *testing.T) {
	weight := NewArrayFloat32([]float32{1, 2, 3, 4}, []int32{2, 2})
	Keep(weight)
	weight.Eval()

	for i := 0; i < 5; i++ {
		x := NewArrayFloat32([]float32{1, 1}, []int32{1, 2})
		result := Matmul(x, weight)
		Keep(result)
		Eval(result)
	}

	if !weight.Valid() {
		t.Error("weight was freed after multiple iterations")
	}
}

// TestAsyncEvalCleanup verifies AsyncEval cleans up and dispatches.
func TestAsyncEvalCleanup(t *testing.T) {
	weight := NewArrayFloat32([]float32{1, 0, 0, 1}, []int32{2, 2}) // Identity matrix
	Keep(weight)
	weight.Eval()

	// First async step
	x1 := NewArrayFloat32([]float32{1, 2}, []int32{1, 2})
	result1 := Matmul(x1, weight)
	Keep(result1)
	AsyncEval(result1)

	// Second async step
	x2 := NewArrayFloat32([]float32{3, 4}, []int32{1, 2})
	result2 := Matmul(x2, weight)
	Keep(result2)
	AsyncEval(result2)

	// Sync and verify results
	result1.Eval()
	d1 := result1.Data()
	if d1[0] != 1 || d1[1] != 2 {
		t.Errorf("result1: expected [1, 2], got %v", d1)
	}

	result2.Eval()
	d2 := result2.Data()
	if d2[0] != 3 || d2[1] != 4 {
		t.Errorf("result2: expected [3, 4], got %v", d2)
	}

	if !weight.Valid() {
		t.Error("weight was freed during async")
	}
}

// TestMultiOutput verifies multiple kept arrays survive.
func TestMultiOutput(t *testing.T) {
	a := NewArrayFloat32([]float32{1, 2, 3, 4}, []int32{2, 2})
	sum := Add(a, a)
	prod := Mul(a, a)
	Keep(sum, prod)

	Eval(sum, prod)

	// Both kept arrays should be valid
	if !sum.Valid() || !prod.Valid() {
		t.Error("kept arrays should survive cleanup")
	}

	// Verify values
	sumData := sum.Data()
	prodData := prod.Data()
	if sumData[0] != 2 || prodData[0] != 1 {
		t.Errorf("unexpected results: sum=%v prod=%v", sumData, prodData)
	}
}

// TestChaining verifies the output of one step can be used in the next.
func TestChaining(t *testing.T) {
	weight := NewArrayFloat32([]float32{1, 0, 0, 1}, []int32{2, 2})
	Keep(weight)
	weight.Eval()

	// First step
	x := NewArrayFloat32([]float32{1, 2}, []int32{1, 2})
	out1 := Matmul(x, weight)
	Keep(out1)
	AsyncEval(out1)

	// Second step uses output of first
	out2 := Add(out1, out1)
	Keep(out2)
	Eval(out2)

	// out1 should survive (was kept)
	if !out1.Valid() {
		t.Error("out1 was freed but used by second step")
	}

	// Final result should be correct
	data := out2.Data()
	if data[0] != 2 || data[1] != 4 {
		t.Errorf("expected [2, 4], got %v", data)
	}
}

// TestGenerationLoop simulates the LLM generation pattern with cache.
func TestGenerationLoop(t *testing.T) {
	weight := NewArrayFloat32([]float32{1, 0, 0, 1}, []int32{2, 2})
	Keep(weight)
	weight.Eval()

	// Simulate cache - starts as zeros
	cache := NewArrayFloat32([]float32{0, 0}, []int32{1, 2})
	Keep(cache)
	cache.Eval()

	var lastToken *Array

	// Simulate 5 generation steps
	for step := 0; step < 5; step++ {
		oldCache := cache

		// Simulate forward pass
		input := NewArrayFloat32([]float32{float32(step + 1), float32(step + 2)}, []int32{1, 2})
		output := Matmul(input, weight)

		// Simulate cache update
		newCache := Add(output, cache)

		// Mark what survives
		Keep(output, newCache)

		if step < 4 {
			AsyncEval(output, newCache)
		} else {
			Eval(output, newCache)
		}

		// Free old cache, update references
		oldCache.Free()
		lastToken = output
		cache = newCache
	}

	// Token output should be valid
	if !lastToken.Valid() {
		t.Error("token output was freed")
	}

	// Cache should be valid
	if !cache.Valid() {
		t.Error("cache was freed")
	}

	// Weight should survive all iterations
	if !weight.Valid() {
		t.Error("weight was freed")
	}
}
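
// ============ Cleanup & Generation Benchmarks ============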

// BenchmarkCleanupOnly isolates cleanup cost without MLX ops.
func BenchmarkCleanupOnly(b *testing.B) {
	// Pre-create weight
	weight := NewArrayFloat32([]float32{1, 0, 0, 1}, []int32{2, 2})
	Keep(weight)
	weight.Eval()

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		// Create 100 arrays - minimal ops
		arrays := make([]*Array, 100)
		for j := range arrays {
			arrays[j] = NewArrayFloat32([]float32{1, 2}, []int32{1, 2})
		}
		Keep(arrays[0])
		Eval() // Just cleanup
	}
}

// BenchmarkNewArrayOnly measures array creation overhead.
func BenchmarkNewArrayOnly(b *testing.B) {
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_ = NewArrayFloat32([]float32{1, 2, 3, 4}, []int32{2, 2})
	}
}

// BenchmarkCGOCallOverhead measures raw CGO call cost.
func BenchmarkCGOCallOverhead(b *testing.B) {
	arr := NewArrayFloat32([]float32{1, 2, 3, 4}, []int32{2, 2})
	Keep(arr)
	arr.Eval()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_ = arr.Ndim() // Simple CGO call
	}
}

// BenchmarkCleanup_50 measures cleanup with 50 arrays.
func BenchmarkCleanup_50(b *testing.B) {
	benchCleanup(b, 50)
}

// BenchmarkCleanup_500 measures cleanup with 500 arrays (LLM scale).
func BenchmarkCleanup_500(b *testing.B) {
	benchCleanup(b, 500)
}

// BenchmarkCleanup_1000 measures cleanup with 1000 arrays.
func BenchmarkCleanup_1000(b *testing.B) {
	benchCleanup(b, 1000)
}

func benchCleanup(b *testing.B, numArrays int) {
	weight := NewArrayFloat32([]float32{1, 0, 0, 1}, []int32{2, 2})
	Keep(weight)
	weight.Eval()

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		x := NewArrayFloat32([]float32{1, 2}, []int32{1, 2})
		for j := 0; j < numArrays; j++ {
			x = Add(x, x)
		}
		result := Matmul(x, weight)
		Keep(result)
		Eval(result)
	}
}

// BenchmarkGenerationLoop_10 simulates 10 token generation steps.
func BenchmarkGenerationLoop_10(b *testing.B) {
	benchGenerationLoop(b, 10)
}

// BenchmarkGenerationLoop_100 simulates 100 token generation steps.
func BenchmarkGenerationLoop_100(b *testing.B) {
	benchGenerationLoop(b, 100)
}

func benchGenerationLoop(b *testing.B, steps int) {
	weight := NewArrayFloat32([]float32{1, 0, 0, 1}, []int32{2, 2})
	Keep(weight)
	weight.Eval()

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		cache := NewArrayFloat32([]float32{0, 0}, []int32{1, 2})
		Keep(cache)
		cache.Eval()

		for step := 0; step < steps; step++ {
			oldCache := cache
			input := NewArrayFloat32([]float32{1, 2}, []int32{1, 2})
			output := Matmul(input, weight)
			newCache := Add(output, cache)
			Keep(output, newCache)

			if step < steps-1 {
				AsyncEval(output, newCache)
			} else {
				Eval(output, newCache)
			}
			oldCache.Free()
			cache = newCache
		}
	}
}

// BenchmarkLLMForward simulates a realistic LLM forward pass with ~500 ops.
func BenchmarkLLMForward(b *testing.B) {
	// Simulate weights for 32 layers
	numLayers := 32
	weights := make([]*Array, numLayers*4) // q, k, v, o per layer
	for i := range weights {
		weights[i] = NewArrayFloat32([]float32{1, 0, 0, 1}, []int32{2, 2})
	}
	Keep(weights...)
	Eval(weights...)

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		x := NewArrayFloat32([]float32{1, 2}, []int32{1, 2})

		// Simulate 32 transformer layers
		for layer := 0; layer < numLayers; layer++ {
			// Attention block (simplified)
			q := Matmul(x, weights[layer*4])
			k := Matmul(x, weights[layer*4+1])
			v := Matmul(x, weights[layer*4+2])
			attn := Matmul(Softmax(Matmul(q, Transpose(k, 1, 0)), -1), v)
			attnOut := Matmul(attn, weights[layer*4+3])

			// Residual + layernorm (simplified)
			x = Add(x, attnOut)
			x = RMSNormNoWeight(x, 1e-5)

			// FFN (simplified as single matmul)
			ffn := Matmul(x, weights[layer*4])
			ffn = SiLU(ffn)
			x = Add(x, ffn)
		}
		Keep(x)
		Eval(x)
	}
}

// ============ Compile Tests ============
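
// Compile wraps a function over arrays into a reusable compiled graph: the
// function is traced and later Call invocations reuse the compiled graph
// (TestCompileReuse exercises this). Compiled functions must be pure with
// respect to hidden state; anything mutable, such as a random key, is
// threaded through the inputs and outputs explicitly (see
// TestCompileWithRandomState). The returned handle owns native resources
// and must be released with Free.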

// gelu implements GELU activation: x * 0.5 * (1 + erf(x / sqrt(2)))
func gelu(x *Array) *Array {
	sqrt2 := NewScalarArray(1.4142135623730951)
	half := NewScalarArray(0.5)
	one := NewScalarArray(1.0)
	scaled := Div(x, sqrt2)
	erfd := Erf(scaled)
	return Mul(Mul(x, half), Add(one, erfd))
}

// TestCompileBasic verifies a compiled function produces correct output.
func TestCompileBasic(t *testing.T) {
	x := NewArrayFloat32([]float32{-1, 0, 1, 2}, []int32{4})
	Keep(x)
	x.Eval()

	// Uncompiled
	expected := gelu(x)
	Keep(expected)
	Eval(expected)

	// Compiled
	compiled := Compile(func(inputs []*Array) []*Array {
		return []*Array{gelu(inputs[0])}
	})
	defer compiled.Free()

	result := compiled.Call(x)[0]
	Keep(result)
	Eval(result)

	// Compare with tolerance
	expData := expected.Data()
	resData := result.Data()
	for i := range expData {
		diff := expData[i] - resData[i]
		if diff < 0 {
			diff = -diff
		}
		if diff > 1e-5 {
			t.Errorf("mismatch at %d: expected %f, got %f (diff=%e)", i, expData[i], resData[i], diff)
		}
	}
}

// TestCompileMultipleInputs verifies a compiled function with multiple inputs.
func TestCompileMultipleInputs(t *testing.T) {
	a := NewArrayFloat32([]float32{1, 2, 3, 4}, []int32{4})
	b := NewArrayFloat32([]float32{5, 6, 7, 8}, []int32{4})
	Keep(a, b)
	Eval(a, b)

	compiled := Compile(func(inputs []*Array) []*Array {
		sum := Add(inputs[0], inputs[1])
		prod := Mul(inputs[0], inputs[1])
		return []*Array{sum, prod}
	})
	defer compiled.Free()

	outputs := compiled.Call(a, b)
	Keep(outputs...)
	Eval(outputs...)

	sumData := outputs[0].Data()
	prodData := outputs[1].Data()
	if sumData[0] != 6 || prodData[0] != 5 {
		t.Errorf("unexpected: sum[0]=%f, prod[0]=%f", sumData[0], prodData[0])
	}
}

// TestCompileReuse verifies a compiled function can be called multiple times.
func TestCompileReuse(t *testing.T) {
	compiled := Compile(func(inputs []*Array) []*Array {
		return []*Array{Add(inputs[0], inputs[0])}
	})
	defer compiled.Free()

	for i := 0; i < 5; i++ {
		x := NewArrayFloat32([]float32{float32(i)}, []int32{1})
		Keep(x)
		x.Eval()
		result := compiled.Call(x)[0]
		Keep(result)
		Eval(result)
		data := result.Data()
		expected := float32(i * 2)
		if data[0] != expected {
			t.Errorf("iteration %d: expected %f, got %f", i, expected, data[0])
		}
	}
}

// BenchmarkGELUUncompiled benchmarks uncompiled GELU.
func BenchmarkGELUUncompiled(b *testing.B) {
	x := RandomNormal([]int32{1000, 1024}, 42)
	Keep(x)
	x.Eval()

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		y := x
		for j := 0; j < 10; j++ {
			y = gelu(y)
		}
		Keep(y)
		Eval(y)
	}
}

// BenchmarkGELUCompiled benchmarks compiled GELU.
func BenchmarkGELUCompiled(b *testing.B) {
	x := RandomNormal([]int32{1000, 1024}, 42)
	Keep(x)
	x.Eval()

	compiled := Compile(func(inputs []*Array) []*Array {
		y := inputs[0]
		for j := 0; j < 10; j++ {
			y = gelu(y)
		}
		return []*Array{y}
	})
	defer compiled.Free()

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		result := compiled.Call(x)
		Keep(result[0])
		Eval(result[0])
	}
}

// TestCompileNoMemoryLeak verifies compiled functions don't leak memory.
func TestCompileNoMemoryLeak(t *testing.T) {
	x := RandomNormal([]int32{100, 100}, 42)
	Keep(x)
	x.Eval()

	compiled := Compile(func(inputs []*Array) []*Array {
		y := inputs[0]
		for j := 0; j < 5; j++ {
			y = gelu(y)
		}
		return []*Array{y}
	})
	defer compiled.Free()

	// Warmup to establish baseline
	for i := 0; i < 10; i++ {
		result := compiled.Call(x)
		Keep(result[0])
		Eval(result[0])
		result[0].Free()
	}

	MetalResetPeakMemory()
	initialMem := MetalGetActiveMemory()

	for i := 0; i < 100; i++ {
		result := compiled.Call(x)
		Keep(result[0])
		Eval(result[0])
		result[0].Free()
	}

	Eval() // Final cleanup

	finalMem := MetalGetActiveMemory()
	peakMem := MetalGetPeakMemory()

	// Memory should not grow significantly (allow 10MB slack for caching)
	growth := int64(finalMem) - int64(initialMem)
	if growth > 10*1024*1024 {
		t.Errorf("memory grew by %d bytes over 100 iterations", growth)
	}
	t.Logf("memory: initial=%dMB, final=%dMB, peak=%dMB, growth=%dKB",
		initialMem/(1<<20), finalMem/(1<<20), peakMem/(1<<20), growth/1024)
}

// TestCompileWithRandomState verifies a compiled function can thread and
// update random state through explicit inputs and outputs.
func TestCompileWithRandomState(t *testing.T) {
	// Simulate logits for sampling
	logits := NewArrayFloat32([]float32{0.1, 0.2, 0.3, 0.4}, []int32{1, 4})
	Keep(logits)
	logits.Eval()

	// Initial random key
	key := RandomKey(42)
	Keep(key)

	// Compile a sampling function that splits the key
	compiled := Compile(func(inputs []*Array) []*Array {
		logits := inputs[0]
		keyIn := inputs[1]

		// Split key: one for sampling, one for next iteration
		key1, key2 := RandomSplit(keyIn)

		// Sample from logits
		sample := RandomCategoricalWithKey(logits, key2, -1, 1)

		return []*Array{sample, key1}
	})
	defer compiled.Free()

	// Run multiple sampling steps
	samples := make([]int32, 10)
	for i := 0; i < 10; i++ {
		outputs := compiled.Call(logits, key)
		Keep(outputs...)
		Eval(outputs...)
		samples[i] = outputs[0].ItemInt32()
		key.Free()
		key = outputs[1]
	}

	// Verify we got valid samples (0-3)
	for i, s := range samples {
		if s < 0 || s > 3 {
			t.Errorf("sample %d out of range: %d", i, s)
		}
	}
	t.Logf("samples: %v", samples)

	// Verify samples aren't all the same (randomness works)
	allSame := true
	for i := 1; i < len(samples); i++ {
		if samples[i] != samples[0] {
			allSame = false
			break
		}
	}
	if allSame {
		t.Error("all samples are the same - random state may not be updating")
	}
}
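
// For reference, the activation below in math form (assuming the two
// ClipScalar booleans enable the lower/upper bounds respectively): with
// g = min(gate, limit) and u = clamp(up, -limit, limit),
//
//	swiGLU(gate, up) = g * sigmoid(alpha*g) * (u + 1)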

// swiGLU implements the GPT-OSS custom SwiGLU activation.
func swiGLU(gate, up *Array, alpha, limit float32) *Array {
	gateClipped := ClipScalar(gate, 0, limit, false, true)
	upClipped := ClipScalar(up, -limit, limit, true, true)
	gluScaled := MulScalar(gateClipped, alpha)
	sig := Sigmoid(gluScaled)
	outGlu := Mul(gateClipped, sig)
	return Mul(outGlu, AddScalar(upClipped, 1.0))
}

// TestCompileSwiGLU verifies compiled SwiGLU produces correct output.
func TestCompileSwiGLU(t *testing.T) {
	gate := NewArrayFloat32([]float32{-1, 0, 1, 2, 5, 10}, []int32{6})
	up := NewArrayFloat32([]float32{-5, -1, 0, 1, 5, 10}, []int32{6})
	Keep(gate, up)
	Eval(gate, up)

	const alpha float32 = 1.702
	const limit float32 = 7.0

	// Uncompiled
	expected := swiGLU(gate, up, alpha, limit)
	Keep(expected)
	Eval(expected)

	// Compiled
	compiled := Compile(func(inputs []*Array) []*Array {
		return []*Array{swiGLU(inputs[0], inputs[1], alpha, limit)}
	})
	defer compiled.Free()

	result := compiled.Call(gate, up)[0]
	Keep(result)
	Eval(result)

	// Compare
	expData := expected.Data()
	resData := result.Data()
	for i := range expData {
		diff := expData[i] - resData[i]
		if diff < 0 {
			diff = -diff
		}
		if diff > 1e-5 {
			t.Errorf("mismatch at %d: expected %f, got %f", i, expData[i], resData[i])
		}
	}
	t.Logf("SwiGLU results: %v", resData)
}

// BenchmarkSwiGLUUncompiled benchmarks uncompiled SwiGLU.
func BenchmarkSwiGLUUncompiled(b *testing.B) {
	gate := RandomNormal([]int32{1, 2880}, 42)
	up := RandomNormal([]int32{1, 2880}, 43)
	Keep(gate, up)
	Eval(gate, up)

	const alpha float32 = 1.702
	const limit float32 = 7.0

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		result := swiGLU(gate, up, alpha, limit)
		Keep(result)
		Eval(result)
	}
}

// BenchmarkSwiGLUCompiled benchmarks compiled SwiGLU.
func BenchmarkSwiGLUCompiled(b *testing.B) {
	gate := RandomNormal([]int32{1, 2880}, 42)
	up := RandomNormal([]int32{1, 2880}, 43)
	Keep(gate, up)
	Eval(gate, up)

	const alpha float32 = 1.702
	const limit float32 = 7.0

	compiled := Compile(func(inputs []*Array) []*Array {
		return []*Array{swiGLU(inputs[0], inputs[1], alpha, limit)}
	})
	defer compiled.Free()

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		result := compiled.Call(gate, up)
		Keep(result[0])
		Eval(result[0])
	}
}

// BenchmarkSwiGLU10xUncompiled benchmarks 10 chained SwiGLU ops uncompiled.
func BenchmarkSwiGLU10xUncompiled(b *testing.B) {
	x := RandomNormal([]int32{1, 2880}, 42)
	Keep(x)
	x.Eval()

	const alpha float32 = 1.702
	const limit float32 = 7.0

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		y := x
		for j := 0; j < 10; j++ {
			y = swiGLU(y, y, alpha, limit)
		}
		Keep(y)
		Eval(y)
	}
}

// BenchmarkSwiGLU10xCompiled benchmarks 10 chained SwiGLU ops compiled.
func BenchmarkSwiGLU10xCompiled(b *testing.B) {
	x := RandomNormal([]int32{1, 2880}, 42)
	Keep(x)
	x.Eval()

	const alpha float32 = 1.702
	const limit float32 = 7.0

	compiled := Compile(func(inputs []*Array) []*Array {
		y := inputs[0]
		for j := 0; j < 10; j++ {
			y = swiGLU(y, y, alpha, limit)
		}
		return []*Array{y}
	})
	defer compiled.Free()

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		result := compiled.Call(x)
		Keep(result[0])
		Eval(result[0])
	}
}

// ============ Sampler Benchmarks ============
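
// Both samplers follow the functional-RNG pattern from
// TestCompileWithRandomState: split the incoming key, consume one half for
// this draw, and return the other half as the key for the next call.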

// sampleTopK implements top-k sampling.
func sampleTopK(logits, key *Array, k int) (*Array, *Array) {
	// Indices of the k largest logits (argpartition on negated values).
	neg := Neg(logits)
	indices := Argpartition(neg, k-1, -1)
	topK := Slice(indices, []int32{0}, []int32{int32(k)})
	values := TakeAlongAxis(logits, topK, -1)
	// Sample within the top-k, then map back to vocabulary indices.
	key1, key2 := RandomSplit(key)
	sampled := RandomCategoricalWithKey(values, key2, -1, 1)
	return Take(topK, sampled, -1), key1
}

// sampleTopP implements top-p (nucleus) sampling.
func sampleTopP(logits, key *Array, p float32, vocabSize int32) (*Array, *Array) {
	// Sort logits descending and accumulate probabilities.
	sorted := Argsort(Neg(logits), -1)
	sortedLogits := TakeAlongAxis(logits, sorted, -1)
	probs := Softmax(sortedLogits, -1)
	cumProbs := Cumsum(probs, -1)
	// Keep tokens while cumulative probability stays below p; mask the rest.
	mask := LessScalar(cumProbs, p)
	negInf := FullDtype(float32(-1e9), logits.Dtype(), vocabSize)
	masked := Where(mask, sortedLogits, negInf)
	key1, key2 := RandomSplit(key)
	sampled := RandomCategoricalWithKey(masked, key2, -1, 1)
	return Take(sorted, sampled, -1), key1
}

// BenchmarkSampleTopKUncompiled benchmarks uncompiled top-k sampling.
func BenchmarkSampleTopKUncompiled(b *testing.B) {
	vocabSize := int32(32000)
	logits := RandomNormal([]int32{vocabSize}, 42)
	key := RandomKey(42)
	Keep(logits, key)
	Eval(logits, key)

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		var token *Array
		token, key = sampleTopK(logits, key, 40)
		Keep(token, key)
		Eval(token)
	}
}

// BenchmarkSampleTopKCompiled benchmarks compiled top-k sampling.
func BenchmarkSampleTopKCompiled(b *testing.B) {
	vocabSize := int32(32000)
	logits := RandomNormal([]int32{vocabSize}, 42)
	key := RandomKey(42)
	Keep(logits, key)
	Eval(logits, key)

	compiled := Compile(func(inputs []*Array) []*Array {
		token, newKey := sampleTopK(inputs[0], inputs[1], 40)
		return []*Array{token, newKey}
	})
	defer compiled.Free()

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		outputs := compiled.Call(logits, key)
		Keep(outputs...)
		Eval(outputs[0])
		key = outputs[1]
	}
}

// BenchmarkSampleTopPUncompiled benchmarks uncompiled top-p sampling.
func BenchmarkSampleTopPUncompiled(b *testing.B) {
	vocabSize := int32(32000)
	logits := RandomNormal([]int32{vocabSize}, 42)
	key := RandomKey(42)
	Keep(logits, key)
	Eval(logits, key)

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		var token *Array
		token, key = sampleTopP(logits, key, 0.9, vocabSize)
		Keep(token, key)
		Eval(token)
	}
}

// BenchmarkSampleTopPCompiled benchmarks compiled top-p sampling.
func BenchmarkSampleTopPCompiled(b *testing.B) {
	vocabSize := int32(32000)
	logits := RandomNormal([]int32{vocabSize}, 42)
	key := RandomKey(42)
	Keep(logits, key)
	Eval(logits, key)

	compiled := Compile(func(inputs []*Array) []*Array {
		token, newKey := sampleTopP(inputs[0], inputs[1], 0.9, vocabSize)
		return []*Array{token, newKey}
	})
	defer compiled.Free()

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		outputs := compiled.Call(logits, key)
		Keep(outputs...)
		Eval(outputs[0])
		key = outputs[1]
	}
}

// TestCompiledSamplerMemoryStable verifies compiled samplers don't leak memory.
func TestCompiledSamplerMemoryStable(t *testing.T) {
	vocabSize := int32(32000)
	logits := RandomNormal([]int32{vocabSize}, 42)
	key := RandomKey(42)
	Keep(logits, key)
	Eval(logits, key)

	compiledTopK := Compile(func(inputs []*Array) []*Array {
		token, newKey := sampleTopK(inputs[0], inputs[1], 40)
		return []*Array{token, newKey}
	})
	defer compiledTopK.Free()

	compiledTopP := Compile(func(inputs []*Array) []*Array {
		token, newKey := sampleTopP(inputs[0], inputs[1], 0.9, vocabSize)
		return []*Array{token, newKey}
	})
	defer compiledTopP.Free()

	// Warmup
	for i := 0; i < 10; i++ {
		out := compiledTopK.Call(logits, key)
		Keep(out...)
		Eval(out[0])
		out[0].Free()
		key = out[1]
	}

	MetalResetPeakMemory()
	initialMem := MetalGetActiveMemory()

	// Run 500 iterations of each sampler
	for i := 0; i < 500; i++ {
		// TopK
		out := compiledTopK.Call(logits, key)
		Keep(out...)
		Eval(out[0])
		out[0].Free()
		key = out[1]

		// TopP
		out = compiledTopP.Call(logits, key)
		Keep(out...)
		Eval(out[0])
		out[0].Free()
		key = out[1]
	}

	Eval() // Final cleanup

	finalMem := MetalGetActiveMemory()
	peakMem := MetalGetPeakMemory()

	growth := int64(finalMem) - int64(initialMem)
	t.Logf("memory: initial=%dMB, final=%dMB, peak=%dMB, growth=%dKB",
		initialMem/(1<<20), finalMem/(1<<20), peakMem/(1<<20), growth/1024)

	// Memory should stay bounded (allow 20MB for caching overhead)
	if growth > 20*1024*1024 {
		t.Errorf("memory grew by %d bytes over 1000 sampler calls - possible leak!", growth)
	}
}
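
// ============ Scaling Benchmarks ============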

// BenchmarkSimpleOps measures simple ops with cleanup.
func BenchmarkSimpleOps(b *testing.B) {
	weight := NewArrayFloat32([]float32{1, 0, 0, 1}, []int32{2, 2})
	Keep(weight)
	weight.Eval()

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		x := NewArrayFloat32([]float32{1, 2}, []int32{1, 2})
		result := Matmul(x, weight)
		Keep(result)
		AsyncEval(result)
		result.Eval()
	}
}

// BenchmarkLayerLike measures layer-like ops (~15 ops).
func BenchmarkLayerLike(b *testing.B) {
	hidden := int32(256)
	w := Ones(hidden, hidden)
	Keep(w)
	w.Eval()

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		x := Ones(1, hidden)
		// Simulate attention-like ops with proper shapes
		h := Matmul(x, w)                    // [1, 256] @ [256, 256] = [1, 256]
		h = Add(h, Matmul(h, w))             // residual
		h = Mul(h, Sigmoid(Matmul(h, w)))    // gating
		h = Matmul(h, w)                     // output projection
		h = Add(x, RMSNormNoWeight(h, 1e-5)) // residual + norm
		Keep(h)
		AsyncEval(h)
		Eval(h)
	}
}

// BenchmarkManyOps measures eval and cleanup with increasing op counts.
func BenchmarkManyOps(b *testing.B) {
	w := Ones(64, 64)
	Keep(w)
	w.Eval()

	for _, numOps := range []int{10, 50, 100, 500, 1000} {
		b.Run(fmt.Sprintf("ops_%d", numOps), func(b *testing.B) {
			for i := 0; i < b.N; i++ {
				x := Ones(1, 64)
				for j := 0; j < numOps; j++ {
					x = Add(x, Matmul(x, w))
				}
				Keep(x)
				AsyncEval(x)
				Eval(x)
			}
		})
	}
}

// BenchmarkLLMScale measures at LLM-realistic scale (~1344 arrays).
func BenchmarkLLMScale(b *testing.B) {
	// Simulate a Qwen-like model: 24 layers, each with ~56 ops = 1344 arrays
	numLayers := 24
	opsPerLayer := 56

	// Create weights
	hidden := int32(64)
	weights := make([]*Array, numLayers*4)
	for i := range weights {
		weights[i] = Ones(hidden, hidden)
	}
	Keep(weights...)
	Eval(weights...)

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		x := Ones(1, hidden)

		for layer := 0; layer < numLayers; layer++ {
			for op := 0; op < opsPerLayer/4; op++ {
				x = Add(x, Matmul(x, weights[layer*4]))
				x = Mul(x, Sigmoid(x))
			}
		}
		Keep(x)
		AsyncEval(x)
		Eval(x)
	}
}

// BenchmarkArrayFreeLoop measures the cost of freeing N arrays.
func BenchmarkArrayFreeLoop(b *testing.B) {
	for _, count := range []int{100, 500, 1000, 1500} {
		b.Run(fmt.Sprintf("arrays_%d", count), func(b *testing.B) {
			for i := 0; i < b.N; i++ {
				b.StopTimer()
				arrays := make([]*Array, count)
				for j := 0; j < count; j++ {
					arrays[j] = NewArrayFloat32([]float32{1, 2, 3, 4}, []int32{2, 2})
				}
				b.StartTimer()

				// Cleanup all arrays
				Eval()
			}
		})
	}
}

// BenchmarkCleanupIsolated measures just cleanup time.
func BenchmarkCleanupIsolated(b *testing.B) {
	w := NewArrayFloat32([]float32{1}, []int32{1, 1})
	Keep(w)
	w.Eval()

	for _, count := range []int{100, 500, 1000, 1500} {
		b.Run(fmt.Sprintf("arrays_%d", count), func(b *testing.B) {
			b.ResetTimer()
			for i := 0; i < b.N; i++ {
				b.StopTimer()
				x := NewArrayFloat32([]float32{1}, []int32{1})
				for j := 0; j < count; j++ {
					x = Add(x, x)
				}
				Keep(x)
				b.StartTimer()
				Eval() // Just cleanup
			}
		})
	}
}
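
// Like the leak tests above, TestMemoryStable reads the Metal allocator
// counters (MetalGetActiveMemory / MetalGetPeakMemory, after
// MetalResetPeakMemory) and tolerates bounded growth, since the allocator's
// buffer cache may legitimately retain freed blocks.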

// TestMemoryStable verifies that cleanup doesn't cause unbounded memory growth.
func TestMemoryStable(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping memory test in short mode")
	}

	// Create realistic-sized arrays (like KV cache)
	batchSize := int32(1)
	numHeads := int32(8)
	seqLen := int32(256)
	headDim := int32(64)
	cacheShape := []int32{batchSize, numHeads, seqLen, headDim}
	cacheSize := batchSize * numHeads * seqLen * headDim * 4 // float32 = 4 bytes

	// Initial cache
	keys := Zeros(cacheShape, DtypeFloat32)
	values := Zeros(cacheShape, DtypeFloat32)
	Keep(keys, values)
	Eval(keys, values)

	// Warmup
	for i := 0; i < 5; i++ {
		oldKeys, oldValues := keys, values

		newKeys := Add(keys, keys)
		newValues := Add(values, values)
		Keep(newKeys, newValues)
		Eval(newKeys, newValues)

		oldKeys.Free()
		oldValues.Free()
		keys, values = newKeys, newValues
	}

	MetalResetPeakMemory()
	initialMem := MetalGetActiveMemory()

	// Run 100 steps
	for step := 0; step < 100; step++ {
		oldKeys, oldValues := keys, values

		newKeys := Add(keys, keys)
		newValues := Add(values, values)
		Keep(newKeys, newValues)
		Eval(newKeys, newValues)

		oldKeys.Free()
		oldValues.Free()
		keys, values = newKeys, newValues
	}

	Eval() // Final cleanup

	finalMem := MetalGetActiveMemory()
	peakMem := MetalGetPeakMemory()

	growth := int64(finalMem) - int64(initialMem)
	// Generous bound: roughly ten generations of the double-buffered K/V pair.
	expectedMaxGrowth := int64(cacheSize * 4 * 10)

	t.Logf("cache size: %d bytes", cacheSize*2) // keys + values
	t.Logf("memory: initial=%dMB, final=%dMB, peak=%dMB, growth=%dKB",
		initialMem/(1<<20), finalMem/(1<<20), peakMem/(1<<20), growth/1024)

	if growth > expectedMaxGrowth {
		t.Errorf("memory grew by %d bytes over 100 steps (expected max %d) - possible leak",
			growth, expectedMaxGrowth)
	}
}