Mirror of https://github.com/ollama/ollama.git, synced 2026-01-12 00:06:57 +08:00
* WIP - MLX backend with gemma3

* MLX: add cmake and go tag build toggles

  To build the new MLX backend code:

      cmake --preset MLX
      cmake --build --preset MLX --parallel
      cmake --install build --component MLX
      go build -tags mlx .

  Note: the main.go entrypoint for the MLX engine will change in a follow-up commit.

* add experimental image generation runtime
* MLX: wire up cuda build for linux
* MLX: get dependencies correct and dedup

  This is still too large for a unified github artifact, but is now "correct" for the mlx_cuda_v13 directory.

* fix relative link bug in dedup
* Add darwin build and readme
* add go build tag for mlx dependent code and wire up build_darwin.sh
* lint cleanup
* macos: build mlx for x86 (CPU only)
* cuda build instructions and fix drift from mlx bump
* stale comment
* Delete agent helper doc
* Clean up readme.md
* Revise README for tokenizer clarity and details

  Updated README to clarify tokenizer functionality and removed the correctness section.

---------

Co-authored-by: jmorganca <jmorganca@gmail.com>
315 lines
7.9 KiB
Go
//go:build mlx

package mlx

import (
	"log/slog"
	"os"
	"reflect"
	"strings"
	"testing"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/runner/common"
	"github.com/ollama/ollama/sample"
	"github.com/ollama/ollama/x/ml"
	"github.com/ollama/ollama/x/model"
	"github.com/ollama/ollama/x/model/input"
	// imported for side effects: registers the gemma3 model implementation
	_ "github.com/ollama/ollama/x/model/models/gemma3"
)

// init routes debug-level slog output to stdout for test runs.
func init() {
	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
	slog.SetDefault(logger)
}

func TestLoadModel(t *testing.T) {
	// Hard-coded developer path; see the TODO in TestGemma3 about generalizing it.
	dir := "/Users/daniel/Models/gemma-3-4b-it/"
	b := &Backend{}
	err := b.LoadSafeTensors(dir)
	if err != nil {
		t.Fatalf("load failed: %s", err)
	}
}
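
// modelDir is a hypothetical helper sketching how the hard-coded path above
// could be generalized (per the TODO in TestGemma3): honor an
// OLLAMA_TEST_MODEL_DIR environment variable (a name assumed here, not an
// existing convention), falling back to the current developer path. The
// tests in this file do not call it yet.
func modelDir() string {
	if dir := os.Getenv("OLLAMA_TEST_MODEL_DIR"); dir != "" {
		return dir
	}
	return "/Users/daniel/Models/gemma-3-4b-it/"
}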

func TestFromInts(t *testing.T) {
	b := &Backend{}
	c := b.NewContext()
	defer c.Close()
	data := []int32{1, 2, 3, 4, 5, 6}
	a := c.FromInts(data, 2, 3)
	slog.Info("", "array", a)
	t.Log(a.ToString())
	if !reflect.DeepEqual(a.Shape(), []int{2, 3}) {
		t.Fatalf("incorrect shape: %v", a.Shape())
	}
}

func TestFromFloats(t *testing.T) {
	b := &Backend{}
	c := b.NewContext()
	defer c.Close()
	data := []float32{1, 2, 3, 4, 5, 6}
	a := c.FromFloats(data, 2, 3)
	slog.Info("", "array", a)
	t.Log(a.ToString())
	if !reflect.DeepEqual(a.Shape(), []int{2, 3}) {
		t.Fatalf("incorrect shape: %v", a.Shape())
	}
	res := a.Floats()
	if !reflect.DeepEqual(res, data) {
		t.Fatalf("incorrect results: %v", res)
	}
}

func TestAdd(t *testing.T) {
	b := &Backend{}
	c := b.NewContext()
	defer c.Close()
	t1 := c.Arange(0, 24, 1, ml.DTypeFloat16)
	t2 := c.Arange(0, 24, 1, ml.DTypeFloat16)
	// Adding a range to itself doubles each element, i.e. an Arange with step 2
	exp := c.Arange(0, 48, 2, ml.DTypeFloat16)
	t3 := t1.Add(c, t2)
	c.Compute(t3, exp)
	t3f := t3.Floats()
	if !reflect.DeepEqual(t3f, exp.Floats()) {
		t.Fatalf("incorrect result: %v", t3f)
	}
}

func TestReshapeTranspose(t *testing.T) {
	b := &Backend{}
	c := b.NewContext()
	defer c.Close()
	t1 := c.Arange(0, 24, 1, ml.DTypeFloat16).Reshape(c, 2, 3, 4).Transpose(c, 0, 2, 1).Contiguous(c, false)
	c.Compute(t1)
	t1f := t1.Floats()
	// [2, 3, 4] transposed on the last two axes gives [2, 4, 3]: each output
	// row reads down a column of one of the original 3x4 matrices.
	exp := []float32{
		0, 4, 8,
		1, 5, 9,
		2, 6, 10,
		3, 7, 11,

		12, 16, 20,
		13, 17, 21,
		14, 18, 22,
		15, 19, 23,
	}
	if !reflect.DeepEqual(t1f, exp) {
		t.Fatalf("incorrect results: %v", t1f)
	}
}

// prod returns the product of vals; used to size Arange calls from a shape.
func prod(vals ...int) int {
	r := 1
	for _, v := range vals {
		r *= v
	}
	return r
}

func TestMatmul(t *testing.T) {
	// TODO create scenarios...
	b := &Backend{}
	c := b.NewContext()
	defer c.Close()
	s1 := []int{1, 3, 2, 4}
	t1 := c.Arange(0, float32(prod(s1...)), 1, ml.DTypeFloat16).Reshape(c, s1...)
	s2 := []int{4, 2}
	t2 := c.Arange(0, float32(prod(s2...)), 1, ml.DTypeFloat16).Reshape(c, s2...)
	t3 := t1.Matmul(c, t2)
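	// t2 ([4, 2]) broadcasts across the batch dims of t1 ([1, 3, 2, 4]), so t3
	// has shape [1, 3, 2, 2]. First expected entry, worked by hand:
	// [0 1 2 3] . [0 2 4 6] = 0 + 2 + 8 + 18 = 28.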
	exp := []float32{
		28, 34,
		76, 98,

		124, 162,
		172, 226,

		220, 290,
		268, 354,
	}
	c.Compute(t3)
	t3f := t3.Floats()
	if !reflect.DeepEqual(t3f, exp) {
		t.Fatalf("incorrect result: %v", t3f)
	}
}

func TestRows(t *testing.T) {
	b := &Backend{}
	c := b.NewContext()
	defer c.Close()
	t1 := c.Arange(0, 12, 1, ml.DTypeFloat32).Reshape(c, 1, 4, 3)
	outputs := c.Zeros(ml.DTypeInt32, 1)
	// Take row 0 along axis 1; this test only logs, no assertion yet
	t2 := t1.TakeAxes(c, outputs, 1)
	c.Forward(t1, t2).Compute(t1, t2)
	t.Log(t1.ToString())
	t.Log(t2.ToString())
	f := t2.Floats()
	t.Logf("Result: %v", f)
}

func TestCaching(t *testing.T) {
	// Validate the caching algorithm
	b := &Backend{}
	c := b.NewContext()
	defer c.Close()
	batchSize := 3
	headDim := 4
	numKVHeads := 2
	// Make the cache twice the size of one test batch
	cells := batchSize * 2
	cellSize := numKVHeads * headDim
	shape := []int{1, numKVHeads, batchSize, headDim}
	stop := float32(1)
	for _, x := range shape {
		stop *= float32(x)
	}
	// Create the cache
	cache := c.Zeros(ml.DTypeFloat16, cells, cellSize)
	t.Logf("Empty Cache shape%v\n"+cache.ToString(), []int{cells, cellSize})

	// Input tensor
	t1 := c.Arange(0, stop, 1, ml.DTypeFloat16).Reshape(c, shape...)
	t.Logf("Initial Data shape%v\n"+t1.ToString(), shape)

	// Reshape to copy into the cache
	/*
		From MLX python/src/indexing.cpp mlx_scatter_args_array:

		// The update shape must broadcast with indices.shape + [1] + src.shape[1:]
		auto up_shape = indices.shape();
		up_shape.insert(up_shape.end(), src.shape().begin() + 1, src.shape().end());
		up = broadcast_to(up, up_shape);
		up_shape.insert(up_shape.begin() + indices.ndim(), 1);
		up = reshape(up, up_shape);
	*/
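	// Concretely for this test: indices.shape is [3] and the cache being
	// scattered into has shape [6, 8], so the update must be broadcastable to
	// [3] + [1] + [8] = [3, 1, 8] -- exactly the Reshape below.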
	numRows := 3
	up := t1.Reshape(c, numRows, 1, cellSize) // The shape has to look like this for scatter to work properly
	t.Logf("Data reshaped for cache input shape%v\n"+up.ToString(), []int{batchSize, numKVHeads * headDim})

	// Simulate cells 1, 3, and 5 being available
	indices := []ml.Tensor{c.FromInts([]int32{1, 3, 5}, numRows)}
	t.Logf("Indices shape%v\n"+indices[0].ToString(), []int{numRows})
	axis := []int{0} // The indices 1, 3, 5 refer to axis 0 of the cache shape
	cache.Scatter(c, indices, up, axis)

	c.Forward(cache)
	// Cache should contain the data now
	t.Log("Cache after put\n" + cache.ToString())

	// Retrieve cache content and verify it matches
	out := cache.TakeAxes(c, indices[0], 0).Reshape(c, shape...)
	t.Logf("Output shape%v\n"+out.ToString(), out.Shape())

	t1f := t1.Floats()
	outf := out.Floats()
	if !reflect.DeepEqual(t1f, outf) {
		t.Fatalf("mismatched in->out\n%v\n ->\n%v", t1f, outf)
	}
}
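
// scatterUpdateShape restates the MLX rule quoted in TestCaching as plain Go:
// scatter updates must be broadcastable to indices.shape + [1] + src.shape[1:].
// It is purely illustrative -- not part of the backend API -- but documents
// why TestCaching reshapes its update to [3, 1, 8].
func scatterUpdateShape(indicesShape, srcShape []int) []int {
	shape := append([]int{}, indicesShape...)
	shape = append(shape, 1)
	shape = append(shape, srcShape[1:]...)
	return shape
}

func TestScatterUpdateShape(t *testing.T) {
	// The shapes used by TestCaching: 3 indices into a [6, 8] cache.
	got := scatterUpdateShape([]int{3}, []int{6, 8})
	if !reflect.DeepEqual(got, []int{3, 1, 8}) {
		t.Fatalf("incorrect update shape: %v", got)
	}
}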

func TestGemma3(t *testing.T) {
	// "Why is the sky blue" pre-tokenized for gemma3
	inputs := []int32{2, 105, 2364, 107, 36425, 563, 506, 7217, 3730, 106, 107, 105, 4368}
	limit := 50

	// TODO generalize this
	dir := "/Users/daniel/Models/gemma-3-4b-it/"

	m, err := model.New(dir, ml.BackendParams{})
	if err != nil {
		t.Fatalf("unable to load model: %s", err)
	}
	b := m.Backend()
	ctx := b.NewContext()
	defer ctx.Close()

	batch := input.Batch{
		Inputs:    ctx.FromInts(inputs, 1, len(inputs)),
		Positions: make([]int32, len(inputs)),
		Sequences: make([]int, len(inputs)),
		Outputs:   ctx.FromInts([]int32{int32(len(inputs) - 1)}, 1),
		Offset:    0,
	}
	for i := range len(inputs) {
		batch.Positions[i] = int32(i)
	}
	offset := len(inputs)

	cache := m.Config().Cache
	if cache != nil {
		numSlots := 1
		batchSize := 512
		numCtx := 4096

		// Note: this is inconsistent with mlx-py, but aims to be consistent with the GGML cache impl to get things working
		// cache.SetConfig(ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeBfloat16, MaskBatchPadding: 64})
		cache.SetConfig(ml.CacheConfig{CachePadding: 0, MaskDType: ml.DTypeBfloat16, MaskBatchPadding: 0})

		cache.Init(b, ml.DTypeBfloat16, numSlots, numCtx, batchSize)
		err := cache.StartForward(ctx, batch, false)
		if err != nil {
			t.Fatalf("failed cache.StartForward: %s", err)
		}
	}
	opts := api.DefaultOptions()
	var grammar *sample.GrammarSampler
	sampler := sample.NewSampler(
		opts.Temperature,
		opts.TopK,
		opts.TopP,
		opts.MinP,
		opts.Seed,
		grammar,
	)

	t.Log("Starting forward pass loop")
	pendingResponses := []string{}
	for {
		out, err := m.Forward(ctx, batch)
		if err != nil {
			t.Fatalf("failed forward pass: %s", err)
		}
		ctx.Forward(out)
		outputs := out.Floats()
		t.Logf("finished forward pass! length:%d", len(outputs))
		// Sample a token from the logits
		logits := outputs
		token, err := sampler.Sample(logits)
		if err != nil {
			t.Fatalf("unable to sample token: %s", err)
		}
		t.Logf("Sampled token: %v", token)
		if m.(model.TextProcessor).Is(token, model.SpecialEOS) {
			t.Log("hit EOS")
			break
		}
		piece, err := m.(model.TextProcessor).Decode([]int32{token})
		if err != nil {
			t.Fatalf("unable to decode token: %s", err)
		}

		pendingResponses = append(pendingResponses, piece)
		sequence := strings.Join(pendingResponses, "")
		if ok, stop := common.FindStop(sequence, opts.Stop); ok {
			t.Logf("hit stop sequence: %v", stop)
			break
		}
		t.Logf("RESULTS: %s", sequence)
		// Build the next single-token decode batch
		batch = input.Batch{
			Inputs:    ctx.FromInts([]int32{token}, 1, 1),
			Positions: make([]int32, 1),
			Sequences: make([]int, 1),
			Outputs:   ctx.FromInts([]int32{0}, 1),
			Offset:    offset,
		}
		offset++
		batch.Positions[0] = 0
		err = cache.StartForward(ctx, batch, false)
		if err != nil {
			t.Fatalf("failed cache.StartForward: %s", err)
		}
		if offset > limit {
			break
		}
	}
}
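
// greedyToken is a self-contained sketch of the simplest possible sampling
// strategy: argmax over the logits. The loop above uses sample.NewSampler
// instead; this helper is illustrative only -- it shows what a sampler
// consumes ([]float32 logits) and returns (an int32 token id) -- and is not
// called by the tests.
func greedyToken(logits []float32) int32 {
	best := 0
	for i, v := range logits {
		if v > logits[best] {
			best = i
		}
	}
	return int32(best)
}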