Files
ollama/x/kvcache/causal_test.go
Daniel Hiltgen 33ee7168ba Add experimental MLX backend and engine with imagegen support (#13648)
* WIP - MLX backend with gemma3

* MLX: add cmake and go tag build toggles

To build the new MLX backend code:
  cmake --preset MLX
  cmake --build --preset MLX --parallel
  cmake --install build --component MLX
  go build -tags mlx .

Note: the main.go entrypoint for the MLX engine will change in a follow up commit.

* add experimental image generation runtime

* add experimental image generation runtime

* MLX: wire up cuda build for linux

* MLX: get dependencies correct and dedup

This is still too large for a unified github artifact, but is now "correct" for the mlx_cuda_v13
directory.

* fix relative link bug in dedup

* Add darwin build and readme

* add go build tag for mlx dependent code and wire up build_darwin.sh

* lint cleanup

* macos: build mlx for x86

This will be CPU only.

* cuda build instructions and fix drift from mlx bump

* stale comment

* Delete agent helper doc

* Clean up readme.md

* Revise README for tokenizer clarity and details

Updated README to clarify tokenizer functionality and removed correctness section.

---------

Co-authored-by: jmorganca <jmorganca@gmail.com>
2026-01-08 16:18:59 -08:00

974 lines
26 KiB
Go

package kvcache
// import (
// "fmt"
// "math"
// "slices"
// "testing"
// "github.com/ollama/ollama/ml"
// "github.com/ollama/ollama/model/input"
// )
// type testCase struct {
// name string
// in []float32
// inShape []int
// seqs []int
// pos []int32
// expected []float32
// expectedShape []int
// expectedMask []float32
// }
// func runPermutedVariants(t *testing.T, fn func(t *testing.T, backend *testBackend)) {
// t.Helper()
// for _, permuted := range []bool{false, true} {
// t.Run(fmt.Sprintf("PermutedV=%t", permuted), func(t *testing.T) {
// fn(t, &testBackend{permutedV: permuted})
// })
// }
// }
// func TestStore(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewCausalCache(nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// tests := []testCase{
// {
// name: "FirstBatch",
// in: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
// inShape: []int{2, 3, 4},
// seqs: []int{0, 0, 0, 0},
// pos: []int32{0, 1, 2, 3},
// expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
// expectedShape: []int{2, 3, 4},
// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
// },
// {
// name: "SecondBatch",
// in: []float32{115, 215, 125, 225, 135, 235},
// inShape: []int{2, 3, 1},
// seqs: []int{0},
// pos: []int32{4},
// expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234, 115, 215, 125, 225, 135, 235},
// expectedShape: []int{2, 3, 5},
// expectedMask: []float32{0, 0, 0, 0, 0},
// },
// }
// testCache(t, backend, cache, tests)
// })
// }
// func TestSWA(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewSWACache(1, nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// x := float32(math.Inf(-1))
// tests := []testCase{
// {
// name: "FirstBatch",
// in: []float32{1, 2, 3, 4},
// inShape: []int{1, 1, 4},
// seqs: []int{0, 0, 0, 0},
// pos: []int32{0, 1, 2, 3},
// expected: []float32{1, 2, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{
// 0, x, x, x,
// 0, 0, x, x,
// x, 0, 0, x,
// x, x, 0, 0,
// },
// },
// {
// name: "SecondBatch",
// in: []float32{5, 6},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 0},
// pos: []int32{4, 5},
// expected: []float32{5, 6, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{
// 0, x, x, 0,
// 0, 0, x, x,
// },
// },
// }
// testCache(t, backend, cache, tests)
// })
// }
// func TestSWASeparateBatches(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewSWACache(1, nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 2, 16, 2)
// x := float32(math.Inf(-1))
// tests := []testCase{
// {
// name: "First seq 0",
// in: []float32{1, 2},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 0},
// pos: []int32{0, 1},
// expected: []float32{1, 2},
// expectedShape: []int{1, 1, 2},
// expectedMask: []float32{
// 0, x,
// 0, 0,
// },
// },
// {
// name: "Second seq 0",
// in: []float32{3, 4},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 0},
// pos: []int32{2, 3},
// expected: []float32{2, 3, 4},
// expectedShape: []int{1, 1, 3},
// expectedMask: []float32{
// 0, 0, x,
// x, 0, 0,
// },
// },
// {
// name: "First seq 1",
// in: []float32{5, 6},
// inShape: []int{1, 1, 2},
// seqs: []int{1, 1},
// pos: []int32{0, 1},
// expected: []float32{5, 6},
// expectedShape: []int{1, 1, 2},
// expectedMask: []float32{
// 0, x,
// 0, 0,
// },
// },
// {
// name: "Second seq 1",
// in: []float32{7, 8},
// inShape: []int{1, 1, 2},
// seqs: []int{1, 1},
// pos: []int32{2, 3},
// expected: []float32{6, 3, 4, 7, 8},
// expectedShape: []int{1, 1, 5},
// expectedMask: []float32{
// 0, x, x, 0, x,
// x, x, x, 0, 0,
// },
// },
// {
// name: "Third seq 0",
// in: []float32{9, 10},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 0},
// pos: []int32{4, 5},
// expected: []float32{9, 10, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{
// 0, x, x, 0,
// 0, 0, x, x,
// },
// },
// }
// testCache(t, backend, cache, tests)
// })
// }
// func TestSWAMem(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewSWAMemCache(1, 3, nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// x := float32(math.Inf(-1))
// tests := []testCase{
// {
// name: "FirstBatch",
// in: []float32{1, 2, 3, 4},
// inShape: []int{1, 1, 4},
// seqs: []int{0, 0, 0, 0},
// pos: []int32{0, 1, 2, 3},
// expected: []float32{1, 2, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{
// 0, x, x, x,
// 0, 0, x, x,
// x, 0, 0, x,
// x, x, 0, 0,
// },
// },
// {
// name: "SecondBatch",
// in: []float32{5, 6},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 0},
// pos: []int32{4, 5},
// expected: []float32{5, 2, 3, 4, 6},
// expectedShape: []int{1, 1, 5},
// expectedMask: []float32{
// 0, x, x, 0, x,
// 0, x, x, x, 0,
// },
// },
// }
// testCache(t, backend, cache, tests)
// })
// }
// func TestChunkedAttention(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewChunkedAttentionCache(2, nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// x := float32(math.Inf(-1))
// testCache(
// t, backend, cache,
// []testCase{
// {
// name: "FirstBatch",
// in: []float32{1, 2, 3, 4},
// inShape: []int{1, 1, 4},
// seqs: []int{0, 0, 0, 0},
// pos: []int32{0, 1, 2, 3},
// expected: []float32{1, 2, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{
// 0, x, x, x,
// 0, 0, x, x,
// x, x, 0, x,
// x, x, 0, 0,
// },
// },
// {
// name: "SecondBatch",
// in: []float32{5, 6, 7},
// inShape: []int{1, 1, 3},
// seqs: []int{0, 0, 0},
// pos: []int32{4, 5, 6},
// expected: []float32{1, 2, 3, 4, 5, 6, 7},
// expectedShape: []int{1, 1, 7},
// expectedMask: []float32{
// x, x, x, x, 0, x, x,
// x, x, x, x, 0, 0, x,
// x, x, x, x, x, x, 0,
// },
// },
// {
// name: "ThirdBatch",
// in: []float32{8, 9},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 0},
// pos: []int32{7, 8},
// expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9},
// expectedShape: []int{1, 1, 9},
// expectedMask: []float32{
// x, x, x, x, x, x, 0, 0, x,
// x, x, x, x, x, x, x, x, 0,
// },
// },
// },
// )
// })
// }
// func TestSequences(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewCausalCache(nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// tests := []testCase{
// {
// name: "FirstBatch",
// in: []float32{1, 2, 3, 4},
// inShape: []int{1, 1, 4},
// seqs: []int{0, 0, 1, 1},
// pos: []int32{0, 1, 0, 1},
// expected: []float32{1, 2, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
// },
// {
// name: "SecondBatch",
// in: []float32{5, 6},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 1},
// pos: []int32{2, 2},
// expected: []float32{1, 2, 3, 4, 5, 6},
// expectedShape: []int{1, 1, 6},
// expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
// },
// }
// testCache(t, backend, cache, tests)
// })
// }
// func TestRemove(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// return key.Add(ctx, shift), nil
// })
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// x := float32(math.Inf(-1))
// tests := []testCase{
// {
// name: "FirstBatch",
// in: []float32{1, 2, 3, 4},
// inShape: []int{1, 1, 4},
// seqs: []int{0, 0, 1, 1},
// pos: []int32{0, 1, 0, 1},
// expected: []float32{1, 2, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{
// 0, x, x, x,
// 0, 0, x, x,
// x, x, 0, x,
// x, x, 0, 0,
// },
// },
// }
// testCache(t, backend, cache, tests)
// err := cache.Remove(0, 1, math.MaxInt32)
// if err != nil {
// panic(err)
// }
// tests = []testCase{
// {
// name: "RemoveEnd",
// in: []float32{5, 6},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 1},
// pos: []int32{1, 2},
// expected: []float32{1, 5, 3, 4, 6},
// expectedShape: []int{1, 1, 5},
// expectedMask: []float32{
// 0, 0, x, x, x,
// x, x, 0, 0, 0,
// },
// },
// }
// testCache(t, backend, cache, tests)
// err = cache.Remove(0, 0, 1)
// if err != nil {
// panic(err)
// }
// tests = []testCase{
// {
// name: "RemoveMiddle",
// in: []float32{7, 8},
// inShape: []int{1, 1, 2},
// seqs: []int{0, 0},
// pos: []int32{1, 2},
// expected: []float32{7, 4, 3, 4, 6, 8},
// expectedShape: []int{1, 1, 6},
// expectedMask: []float32{
// 0, 0, x, x, x, x,
// 0, 0, x, x, x, 0,
// },
// },
// }
// testCache(t, backend, cache, tests)
// })
// }
// func TestCopy(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// tests := []testCase{
// {
// name: "FirstBatch",
// in: []float32{1, 2, 3, 4},
// inShape: []int{1, 1, 4},
// seqs: []int{0, 0, 0, 0},
// pos: []int32{0, 1, 2, 3},
// expected: []float32{1, 2, 3, 4},
// expectedShape: []int{1, 1, 4},
// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
// },
// }
// testCache(t, backend, cache, tests)
// cache.CopyPrefix(0, 1, 2)
// tests = []testCase{
// {
// name: "Copy",
// in: []float32{5, 6},
// inShape: []int{1, 1, 2},
// seqs: []int{1, 1},
// pos: []int32{3, 4},
// expected: []float32{1, 2, 3, 4, 5, 6},
// expectedShape: []int{1, 1, 6},
// expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
// },
// }
// testCache(t, backend, cache, tests)
// })
// }
// func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) {
// for _, test := range tests {
// t.Run(test.name, func(t *testing.T) {
// context := backend.NewContext()
// defer context.Close()
// err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs}, false)
// if err != nil {
// panic(err)
// }
// cache.SetLayer(0)
// tensor := context.FromFloats(test.in, test.inShape...)
// cache.Put(context, tensor, tensor)
// out, _, mask := cache.Get(context)
// context.Forward(out, mask).Compute(out, mask)
// if !slices.Equal(out.Floats(), test.expected) {
// t.Errorf("TestCache: have %v; want %v", out.Floats(), test.expected)
// }
// if !slices.Equal(out.Shape(), test.expectedShape) {
// t.Errorf("TestCache: has shape %v; want %v", out.Shape(), test.expectedShape)
// }
// if !slices.Equal(mask.Floats(), test.expectedMask) {
// t.Errorf("TestCache: have mask: have %v want %v", mask.Floats(), test.expectedMask)
// }
// })
// }
// }
// func TestCanResume(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// windowSize := int32(4)
// cache := NewSWACache(windowSize, nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// context := backend.NewContext()
// defer context.Close()
// err := cache.StartForward(context, input.Batch{
// Positions: []int32{0, 1, 2, 3, 4},
// Sequences: []int{0, 0, 0, 0, 0},
// }, false)
// if err != nil {
// t.Fatalf("StartForward failed: %v", err)
// }
// cache.SetLayer(0)
// tensor := context.FromFloats([]float32{1, 2, 3, 4, 5}, 1, 1, 5)
// cache.Put(context, tensor, tensor)
// // with window size 4, nothing has slid out of the window yet
// if !cache.CanResume(0, 0) {
// t.Errorf("CanResume(0, 0) = false, want true (within window)")
// }
// if !cache.CanResume(0, 1) {
// t.Errorf("CanResume(0, 1) = false, want true (within window)")
// }
// if !cache.CanResume(0, 2) {
// t.Errorf("CanResume(0, 2) = false, want true (within window)")
// }
// if !cache.CanResume(0, 3) {
// t.Errorf("CanResume(0, 3) = false, want true (latest position)")
// }
// if !cache.CanResume(0, 4) {
// t.Errorf("CanResume(0, 4) = false, want true (latest position)")
// }
// // shift window by adding position 5
// err = cache.StartForward(context, input.Batch{
// Positions: []int32{5},
// Sequences: []int{0},
// }, false)
// if err != nil {
// t.Fatalf("StartForward failed: %v", err)
// }
// cache.SetLayer(0)
// tensor = context.FromFloats([]float32{6}, 1, 1, 1)
// cache.Put(context, tensor, tensor)
// // only the latest position has overlapping windows
// if cache.CanResume(0, 0) {
// t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
// }
// if cache.CanResume(0, 1) {
// t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
// }
// if cache.CanResume(0, 2) {
// t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
// }
// if cache.CanResume(0, 3) {
// t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
// }
// if cache.CanResume(0, 4) {
// t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
// }
// if !cache.CanResume(0, 5) {
// t.Errorf("after shift: CanResume(0, 5) = false, want true (latest position)")
// }
// })
// }
// func TestCanResumeSWAMem(t *testing.T) {
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
// windowSize := int32(4)
// memSize := int32(5)
// cache := NewSWAMemCache(windowSize, memSize, nil)
// defer cache.Close()
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
// context := backend.NewContext()
// defer context.Close()
// err := cache.StartForward(context, input.Batch{
// Positions: []int32{0, 1, 2, 3, 4, 5, 6},
// Sequences: []int{0, 0, 0, 0, 0, 0, 0},
// }, false)
// if err != nil {
// t.Fatalf("StartForward failed: %v", err)
// }
// cache.SetLayer(0)
// tensor := context.FromFloats([]float32{1, 2, 3, 4, 5, 6, 7}, 1, 1, 7)
// cache.Put(context, tensor, tensor)
// // shift window by adding position 7
// err = cache.StartForward(context, input.Batch{
// Positions: []int32{7},
// Sequences: []int{0},
// }, false)
// if err != nil {
// t.Fatalf("StartForward failed: %v", err)
// }
// cache.SetLayer(0)
// tensor = context.FromFloats([]float32{8}, 1, 1, 1)
// cache.Put(context, tensor, tensor)
// // only the latest position has overlapping windows
// if cache.CanResume(0, 0) {
// t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
// }
// if cache.CanResume(0, 1) {
// t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
// }
// if cache.CanResume(0, 2) {
// t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
// }
// if cache.CanResume(0, 3) {
// t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
// }
// if cache.CanResume(0, 4) {
// t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
// }
// if cache.CanResume(0, 5) {
// t.Errorf("after shift: CanResume(0, 5) = true, want false (outside window)")
// }
// if !cache.CanResume(0, 6) {
// t.Errorf("after shift: CanResume(0, 6) = false, want true (inside window)")
// }
// if !cache.CanResume(0, 7) {
// t.Errorf("after shift: CanResume(0, 7) = false, want true (latest position)")
// }
// })
// }
// type testBackend struct {
// ml.Backend
// permutedV bool
// }
// func (b *testBackend) NewContext() ml.Context {
// return &testContext{}
// }
// func (b *testBackend) NewContextSize(int) ml.Context {
// return &testContext{}
// }
// func (b *testBackend) CacheConfig() ml.CacheConfig {
// return ml.CacheConfig{PermutedV: b.permutedV}
// }
// type testContext struct {
// ml.Context
// }
// func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
// total := 0
// if len(shape) > 0 {
// total = 1
// for _, s := range shape {
// total *= s
// }
// }
// return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape}
// }
// func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
// return c.Empty(dtype, shape...)
// }
// func (c *testContext) FromFloats(s []float32, shape ...int) ml.Tensor {
// t := c.Empty(ml.DTypeF32, shape...).(*testTensor)
// copy(t.data, s)
// return t
// }
// func (c *testContext) FromInts(s []int32, shape ...int) ml.Tensor {
// f := make([]float32, len(s))
// for i := range f {
// f[i] = float32(s[i])
// }
// out := c.FromFloats(f, shape...)
// out.(*testTensor).dtype = ml.DTypeI32
// return out
// }
// func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
// s := make([]float32, 0, int((stop-start)/step))
// for i := start; i < stop; i += step {
// s = append(s, i)
// }
// out := c.FromFloats(s, len(s))
// out.(*testTensor).dtype = dtype
// return out
// }
// func (c *testContext) Input() ml.Context { return c }
// func (c *testContext) Layer(int) ml.Context { return c }
// func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
// func (c *testContext) Compute(...ml.Tensor) {}
// func (c *testContext) Reserve() {}
// func (c *testContext) MaxGraphNodes() int {
// return 10
// }
// func (c *testContext) Close() {}
// type testTensor struct {
// ml.Tensor
// dtype ml.DType
// elementSize int
// data []float32
// shape []int
// }
// func (t *testTensor) Dim(n int) int {
// return t.shape[n]
// }
// func (t *testTensor) Stride(n int) int {
// stride := t.elementSize
// for i := range n {
// stride *= t.shape[i]
// }
// return stride
// }
// func (t *testTensor) Shape() []int {
// return t.shape
// }
// func (t *testTensor) DType() ml.DType {
// return t.dtype
// }
// func (t *testTensor) Floats() []float32 {
// out := make([]float32, len(t.data))
// copy(out, t.data)
// return out
// }
// func (t *testTensor) Neg(ctx ml.Context) ml.Tensor {
// out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
// for i := range out.data {
// out.data[i] = -t.data[i]
// }
// return out
// }
// func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
// out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
// for i := range out.data {
// out.data[i] = t.data[i] + t2.(*testTensor).data[i]
// }
// return out
// }
// func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
// return &testTensor{
// dtype: t.dtype,
// elementSize: t.elementSize,
// data: t.data,
// shape: shape,
// }
// }
// func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
// offset /= t.elementSize
// var s []int
// switch len(shape) {
// case 1:
// s = []int{shape[0]}
// case 3:
// s = []int{shape[0], shape[2]}
// case 5:
// s = []int{shape[0], shape[2], shape[4]}
// default:
// panic("unsupported number of dimensions")
// }
// context := &testContext{}
// view := context.Empty(t.dtype, s...).(*testTensor)
// view.data = t.data[offset : offset+len(view.data)]
// return view
// }
// func (t *testTensor) Permute(ctx ml.Context, order ...int) ml.Tensor {
// if len(t.shape) > 4 || len(order) > 4 {
// panic("permute only supports up to 4 dimensions")
// }
// if len(order) != len(t.shape) && len(order) != 4 {
// panic("invalid number of dimensions for permute")
// }
// // ggml_permute expects 4 axes, so fill in any missing dimensions.
// orderFull := append(make([]int, 0, 4), order...)
// for len(orderFull) < 4 {
// orderFull = append(orderFull, len(orderFull))
// }
// seen := [4]bool{}
// shape4 := [4]int{1, 1, 1, 1}
// for i := 0; i < len(t.shape) && i < 4; i++ {
// shape4[i] = t.shape[i]
// }
// newShape4 := [4]int{1, 1, 1, 1}
// for axis := range 4 {
// dst := orderFull[axis]
// if dst < 0 || dst >= 4 {
// panic("invalid axis for permute")
// }
// if seen[dst] {
// panic("duplicate axis for permute")
// }
// seen[dst] = true
// newShape4[dst] = shape4[axis]
// }
// total := len(t.data)
// newData := make([]float32, total)
// if total > 0 {
// oldDims := shape4
// newDims := newShape4
// oldStride := [4]int{1, 1, 1, 1}
// newStride := [4]int{1, 1, 1, 1}
// for i := 1; i < 4; i++ {
// oldStride[i] = oldStride[i-1] * oldDims[i-1]
// newStride[i] = newStride[i-1] * newDims[i-1]
// }
// var coords [4]int
// var newCoords [4]int
// for idx := range total {
// remainder := idx
// for axis := range 4 {
// dim := oldDims[axis]
// if dim == 0 {
// coords[axis] = 0
// continue
// }
// coords[axis] = remainder % dim
// remainder /= dim
// }
// for axis := range 4 {
// newCoords[orderFull[axis]] = coords[axis]
// }
// newIndex := 0
// for axis := range 4 {
// if newDims[axis] == 0 {
// continue
// }
// newIndex += newCoords[axis] * newStride[axis]
// }
// newData[newIndex] = t.data[idx]
// }
// }
// numDims := 4
// for numDims > 1 && newShape4[numDims-1] <= 1 {
// numDims--
// }
// newShape := make([]int, numDims)
// copy(newShape, newShape4[:numDims])
// return &testTensor{
// dtype: t.dtype,
// elementSize: t.elementSize,
// data: newData,
// shape: newShape,
// }
// }
// func (t *testTensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tensor {
// dst := t
// srcTensor := src.(*testTensor)
// idxTensor := idxs.(*testTensor)
// shapeTo4D := func(shape []int) [4]int {
// out := [4]int{1, 1, 1, 1}
// for i := 0; i < len(shape) && i < 4; i++ {
// out[i] = shape[i]
// }
// return out
// }
// computeStrides := func(shape [4]int) [4]int {
// out := [4]int{1, 1, 1, 1}
// for i := 1; i < 4; i++ {
// out[i] = out[i-1] * shape[i-1]
// }
// return out
// }
// dstShape4D := shapeTo4D(dst.shape)
// srcShape4D := shapeTo4D(srcTensor.shape)
// idxShape4D := shapeTo4D(idxTensor.shape)
// if dstShape4D[0] != srcShape4D[0] || dstShape4D[2] != srcShape4D[2] || dstShape4D[3] != srcShape4D[3] {
// panic("SetRows requires matching tensor shapes")
// }
// if srcShape4D[1] != idxShape4D[0] {
// panic("SetRows rows/index mismatch")
// }
// if srcShape4D[2]%idxShape4D[1] != 0 || srcShape4D[3]%idxShape4D[2] != 0 {
// panic("SetRows cannot broadcast indices")
// }
// if idxShape4D[3] != 1 {
// panic("SetRows expects 1D or 2D index tensors")
// }
// dstStride := computeStrides(dstShape4D)
// srcStride := computeStrides(srcShape4D)
// idxStride := computeStrides(idxShape4D)
// numColumns := srcShape4D[0]
// numRows := srcShape4D[1]
// for dim3Index := range dstShape4D[3] {
// for dim2Index := range dstShape4D[2] {
// idxDim2 := 0
// idxDim3 := 0
// if idxShape4D[1] > 0 {
// idxDim2 = dim2Index % idxShape4D[1]
// }
// if idxShape4D[2] > 0 {
// idxDim3 = dim3Index % idxShape4D[2]
// }
// idxBase := idxDim3*idxStride[2] + idxDim2*idxStride[1]
// srcBase := dim3Index*srcStride[3] + dim2Index*srcStride[2]
// dstBase := dim3Index*dstStride[3] + dim2Index*dstStride[2]
// for row := range numRows {
// idx := int(idxTensor.data[idxBase+row*idxStride[0]])
// if idx < 0 || idx >= dstShape4D[1] {
// panic("SetRows index out of range")
// }
// srcOffset := srcBase + row*srcStride[1]
// dstOffset := dstBase + idx*dstStride[1]
// copy(dst.data[dstOffset:dstOffset+numColumns], srcTensor.data[srcOffset:srcOffset+numColumns])
// }
// }
// }
// return dst
// }
// func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
// copy(t2.(*testTensor).data, t.data)
// return nil
// }