ollama/x/ml/backend.go
Daniel Hiltgen 33ee7168ba Add experimental MLX backend and engine with imagegen support (#13648)
* WIP - MLX backend with gemma3

* MLX: add cmake and go tag build toggles

To build the new MLX backend code:
  cmake --preset MLX
  cmake --build --preset MLX --parallel
  cmake --install build --component MLX
  go build -tags mlx .

Note: the main.go entrypoint for the MLX engine will change in a follow-up commit.

* add experimental image generation runtime

* MLX: wire up cuda build for linux

* MLX: get dependencies correct and dedup

This is still too large for a unified GitHub artifact, but is now "correct" for the mlx_cuda_v13 directory.

* fix relative link bug in dedup

* Add darwin build and readme

* add go build tag for mlx dependent code and wire up build_darwin.sh

* lint cleanup

* macos: build mlx for x86

This will be CPU only.

* cuda build instructions and fix drift from mlx bump

* stale comment

* Delete agent helper doc

* Clean up readme.md

* Revise README for tokenizer clarity and details

Updated README to clarify tokenizer functionality and removed correctness section.

---------

Co-authored-by: jmorganca <jmorganca@gmail.com>
2026-01-08 16:18:59 -08:00

package ml

import (
    "fmt"
    "log/slog"
    "os"

    "github.com/ollama/ollama/fs"
)
type Backend interface {
    // Close frees all memory associated with this backend
    // Close()

    // Load(ctx context.Context, progress func(float32)) error

    // BackendMemory returns the memory allocations that were made for this model
    // BackendMemory() BackendMemory

    Config() fs.Config
    Get(name string) Tensor
    NewContext() Context
    // NewContextSize(size int) Context

    // Enumerate the devices available for inference via this backend
    // BackendDevices() []DeviceInfo
}
// BackendCacheConfig should be implemented by backends that need special output
// from the cache to meet specific requirements. It is frequently implemented in
// conjunction with ScaledDotProductAttention.
type BackendCacheConfig interface {
    CacheConfig() CacheConfig
}
// CacheConfig controls optimizations (mostly backend-specific) that may transform
// the output of the cache to work better with specific kernels.
type CacheConfig struct {
    // CachePadding specifies the multiple for the number of tokens of cache history
    // that will be returned from cache Get for k, v and mask. The capacity of the
    // cache itself will also be increased to a multiple of this size if needed.
    CachePadding int

    // PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put
    // and returns the permuted version via Get. This uses the cache copy operation
    // to avoid a Contiguous call on the permuted tensor.
    PermutedV bool

    // MaskDType specifies the data type for generating the mask. If unset it will
    // default to DTypeFloat32.
    MaskDType DType

    // MaskBatchPadding specifies the multiple for the batch size dimension in the mask.
    // Any position that does not correspond to an actual token will be filled with -Inf.
    MaskBatchPadding int
}
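
// Illustrative sketch (not part of the upstream file): a backend that wants
// padded cache history and permuted v tensors could implement
// BackendCacheConfig like this. The stubBackend receiver is a hypothetical
// type (see the registration sketch following NewBackend below), and all
// field values here are assumptions for demonstration only.
//
// func (b *stubBackend) CacheConfig() CacheConfig {
//     return CacheConfig{
//         CachePadding:     32,
//         PermutedV:        true,
//         MaskDType:        DTypeFloat16,
//         MaskBatchPadding: 1,
//     }
// }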
// BackendParams controls how the backend loads and executes models
type BackendParams struct {
    // AllocMemory causes the backend to allocate memory for the model. If
    // false, this is only being used for discovering the required amount of
    // memory and cannot load the model for running.
    AllocMemory bool

    // NumThreads sets the number of threads to use if running on the CPU
    NumThreads int

    // GPULayers is the set of layers to offload to GPUs
    GPULayers GPULayersList

    // FlashAttention indicates that we should use a fused flash attention kernel
    FlashAttention bool
}
var backends = make(map[string]func(string, BackendParams) (Backend, error))

func RegisterBackend(name string, f func(string, BackendParams) (Backend, error)) {
    if _, ok := backends[name]; ok {
        panic("backend: backend already registered")
    }

    backends[name] = f
}
func NewBackend(modelPath string, params BackendParams) (Backend, error) {
    be := os.Getenv("OLLAMA_BACKEND")
    if be == "" {
        be = "mlx"
        slog.Info("Defaulting to " + be + ". Set OLLAMA_BACKEND to override")
    }
    slog.Info("Loading new engine", "backend", be)

    if backend, ok := backends[be]; ok {
        return backend(modelPath, params)
    }

    return nil, fmt.Errorf("unsupported backend %q", be)
}
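
// Illustrative sketch (not in the upstream file): how a backend registers
// itself and is then selected at load time. The "stub" name and stubBackend
// type are hypothetical.
//
// func init() {
//     RegisterBackend("stub", func(modelPath string, params BackendParams) (Backend, error) {
//         return &stubBackend{}, nil
//     })
// }
//
// A caller would then pick it via the environment before loading:
//
// os.Setenv("OLLAMA_BACKEND", "stub")
// backend, err := NewBackend("/path/to/model", BackendParams{
//     AllocMemory: true,
//     NumThreads:  4,
// })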
type Context interface {
    Empty(dtype DType, shape ...int) Tensor
    Zeros(dtype DType, shape ...int) Tensor
    // FromBytes(dtype DType, s []byte, shape ...int) Tensor
    FromFloats(s []float32, shape ...int) Tensor
    FromInts(s []int32, shape ...int) Tensor
    RandomNormal(shape []int, dtype DType, loc, scale float32, key Tensor) Tensor

    // Arange creates a 1D tensor with values within the interval (start, stop], increasing by step.
    Arange(start, stop, step float32, dtype DType) Tensor

    Forward(...Tensor) Context

    // SetBatchSize provides a hint on the batch size to optimize processing.
    // Uses heuristics if not set.
    // SetBatchSize(int)

    Compute(...Tensor)
    // ComputeWithNotify(func(), ...Tensor) // notify callback once compute has begun

    // Reserve is analogous to Compute but rather than executing a
    // graph, simply preallocates memory. Typically called with a
    // worst case graph to ensure all resources are available
    // for future inference.
    // Reserve()

    // MaxGraphNodes() int
    Close()

    // Input returns a context appropriate for creating tensors that are
    // inputs to the model (which includes things like output locations)
    Input() Context

    // Layer returns a context appropriate for creating intermediate tensors
    Layer(int) Context

    // CompareWith loads tensors from the safetensors file "filename" and compares
    // them with the given input tensors. It returns an error if a shape is
    // inconsistent or a similarity measure is below 99%.
    CompareWith(filename string, tensors map[string]Tensor, abortOnError bool) error
}
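
// Illustrative sketch (assumed usage, not from the upstream file): building
// and executing a small graph with a Context. Assumes "backend" is a loaded
// Backend; the shapes and values are arbitrary.
//
// ctx := backend.NewContext()
// defer ctx.Close()
//
// a := ctx.FromFloats([]float32{1, 2, 3, 4}, 2, 2)
// b := ctx.FromFloats([]float32{5, 6, 7, 8}, 2, 2)
// c := a.Matmul(ctx, b)
//
// ctx.Forward(c).Compute(c)
// out := c.Floats()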
type RoPEOptions struct {
    Base  *float32
    Freqs Tensor
}

func WithRoPEBase(base float32) func(*RoPEOptions) {
    return func(opts *RoPEOptions) {
        opts.Base = &base
    }
}

func WithRoPEFreqs(freqs Tensor) func(*RoPEOptions) {
    return func(opts *RoPEOptions) {
        opts.Freqs = freqs
    }
}
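
// Illustrative sketch: applying rotary position embeddings with the
// functional options above. The dims, scale, offset, and base values are
// assumptions for demonstration, not defaults taken from the upstream code.
//
// q = q.RoPE(ctx, 128, false, 1.0, offset, WithRoPEBase(10000))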
type Tensor interface {
    ToString() string

    RoPE(ctx Context, dims int, traditional bool, scale float32, offset int, options ...func(*RoPEOptions)) Tensor
    ScaledDotProductAttention(ctx Context, keys, values Tensor, scale float64, maskMode string, mask Tensor, sinks Tensor) Tensor
    TakeAxes(ctx Context, indices Tensor, axes int) Tensor
    // TakeAxes(ctx Context, axes int, indices ...int) Tensor

    Dim(n int) int
    Stride(n int) int
    Shape() []int
    DType() DType
    // Cast(ctx Context, dtype DType) Tensor

    // Bytes() []byte
    Floats() []float32
    Ints() []int32
    // FromBytes([]byte)
    // FromFloats([]float32)
    // FromInts([]int32)

    Add(ctx Context, t2 Tensor) Tensor
    Sub(ctx Context, t2 Tensor) Tensor
    // Mul(ctx Context, t2 Tensor) Tensor
    // Div(ctx Context, t2 Tensor) Tensor
    Max(ctx Context, axes []int, keepDims bool) Tensor
    Min(ctx Context, axes []int, keepDims bool) Tensor
    Matmul(ctx Context, a2 Tensor) Tensor
    // Mulmat(ctx Context, t2 Tensor) Tensor
    // MulmatFullPrec(ctx Context, t2 Tensor) Tensor
    // MulmatID(ctx Context, t2, ids Tensor) Tensor
    // AddID(ctx Context, t2, ids Tensor) Tensor

    Softmax(ctx Context) Tensor
    L2Norm(ctx Context, eps float32) Tensor
    LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
    RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
    Scale(ctx Context, s float64) Tensor
    // SumRows(ctx Context) Tensor

    AvgPool2D(ctx Context, k, s int, p float32) Tensor
    Conv2D(ctx Context, weight Tensor, stride0, stride1, padding0, padding1, dilation0, dilation1, groups int) Tensor
    Conv3D(ctx Context, weight Tensor, stride0, stride1, stride2, padding0, padding1, padding2, dilation0, dilation1, dilation2, groups int) Tensor
    // IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor

    // Sin(ctx Context) Tensor
    // Cos(ctx Context) Tensor
    // Tanh(ctx Context) Tensor
    GELU(ctx Context, up ...Tensor) Tensor
    // QuickGELU(ctx Context, up ...Tensor) Tensor
    // SILU(ctx Context, up ...Tensor) Tensor
    // RELU(ctx Context, up ...Tensor) Tensor
    // Sigmoid(ctx Context) Tensor

    // SILUAlphaLimit is a variant of SILU that clamps the input to the range [-limit, limit]
    // SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor

    Reshape(ctx Context, shape ...int) Tensor
    AsStrided(ctx Context, shape, strides []int, offset int) Tensor
    Transpose(ctx Context, shape ...int) Tensor
    Contiguous(ctx Context, allowColMajor bool) Tensor
    // Pad(ctx Context, shape ...int) Tensor
    // Stack(ctx Context, dim int, s ...Tensor) Tensor
    // Repeat repeats the tensor n times along dimension dim
    // Repeat(ctx Context, dim, n int) Tensor
    // Concat(ctx Context, t2 Tensor, dim int) Tensor
    // Rows(ctx Context, t2 Tensor) Tensor

    // TODO these probably aren't actually needed - false starts on trying to wire up the cache
    // SliceUpdate(ctx Context, update Tensor, start, stop, strides []int) Tensor
    // SliceUpdateDynamic(ctx Context, update, start Tensor, axes []int) Tensor
    // PutAlongAxis(ctx Context, indices, values Tensor, axis int) Tensor

    Scatter(ctx Context, indices []Tensor, updates Tensor, axes []int) Tensor
    Copy(ctx Context, t2 Tensor) Tensor
    // Duplicate(ctx Context) Tensor

    // Slice(ctx Context, dim, low, high, step int) Tensor
    // Chunk(ctx Context, dim int, size int) []Tensor
    // ChunkSections(ctx Context, dim int, sections ...int) []Tensor
    // TopK(ctx Context, k int) Tensor
    // Argsort(ctx Context) Tensor
    // Mean(ctx Context) Tensor
    // Variance(ctx Context) Tensor
    // Stddev(ctx Context) Tensor
    // Sqr(ctx Context) Tensor
    // Sqrt(ctx Context) Tensor
    // Interpolate(ctx Context, dims [4]int, samplingMode SamplingMode) Tensor
}
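
// Illustrative sketch: a typical shape-manipulation chain using the methods
// above. The dimension variables are hypothetical, and treating Transpose's
// variadic ints as an axis permutation is an assumption.
//
// x = x.Reshape(ctx, batch, seq, heads, headDim).
//     Transpose(ctx, 0, 2, 1, 3).
//     Contiguous(ctx, false)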
// ScaledDotProductAttention implements a fused attention operation
// equivalent to the following code on a tensor named query:
//
//     query = query.Permute(ctx, 0, 2, 1, 3)
//     key = key.Permute(ctx, 0, 2, 1, 3)
//     value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
//
//     kq := key.MulmatFullPrec(ctx, query)
//     kq = kq.Scale(ctx, scale)
//
//     if mask != nil {
//         kq = kq.Add(ctx, mask)
//     }
//
//     kq = kq.Softmax(ctx)
//
//     kqv := value.Mulmat(ctx, kq)
//     return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
//
// type ScaledDotProductAttention interface {
//     ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64) Tensor
// }
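
// Illustrative sketch: invoking the fused attention method on a query tensor.
// The "causal" maskMode string and the headDim variable are assumptions for
// demonstration, not values confirmed by the upstream code.
//
// scale := 1.0 / math.Sqrt(float64(headDim))
// attn := query.ScaledDotProductAttention(ctx, keys, values, scale, "causal", nil, nil)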
// type number interface {
//     ~int | ~int8 | ~int16 | ~int32 | ~int64 |
//         ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
//         ~float32 | ~float64 |
//         ~complex64 | ~complex128
// }

// func mul[T number](s ...T) T {
//     p := T(1)
//     for _, v := range s {
//         p *= v
//     }
//     return p
// }

// type DumpOptions func(*dumpOptions)

// // DumpWithPrecision sets the number of decimal places to print. Applies to float32 and float64.
// func DumpWithPrecision(n int) DumpOptions {
//     return func(opts *dumpOptions) {
//         opts.Precision = n
//     }
// }

// // DumpWithThreshold sets the threshold for printing the entire tensor. If the number of elements
// // is less than or equal to this value, the entire tensor will be printed. Otherwise, only the
// // beginning and end of each dimension will be printed.
// func DumpWithThreshold(n int) DumpOptions {
//     return func(opts *dumpOptions) {
//         opts.Threshold = n
//     }
// }

// // DumpWithEdgeItems sets the number of elements to print at the beginning and end of each dimension.
// func DumpWithEdgeItems(n int) DumpOptions {
//     return func(opts *dumpOptions) {
//         opts.EdgeItems = n
//     }
// }

// type dumpOptions struct {
//     Precision, Threshold, EdgeItems int
// }

// func Dump(ctx Context, t Tensor, optsFuncs ...DumpOptions) string {
//     opts := dumpOptions{Precision: 4, Threshold: 1000, EdgeItems: 3}
//     for _, optsFunc := range optsFuncs {
//         optsFunc(&opts)
//     }
//
//     if mul(t.Shape()...) <= opts.Threshold {
//         opts.EdgeItems = math.MaxInt
//     }
//
//     switch t.DType() {
//     case DTypeFloat32:
//         return dump[[]float32](ctx, t, opts.EdgeItems, func(f float32) string {
//             return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
//         })
//     case DTypeFloat16: // TODO other types...
//         f32 := ctx.Input().Empty(DTypeFloat32, t.Shape()...)
//         f32 = t.Copy(ctx, f32)
//         return dump[[]float32](ctx, f32, opts.EdgeItems, func(f float32) string {
//             return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
//         })
//     case DTypeInt32:
//         return dump[[]int32](ctx, t, opts.EdgeItems, func(i int32) string {
//             return strconv.FormatInt(int64(i), 10)
//         })
//     default:
//         return "<unsupported>"
//     }
// }

// func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
//     if t.Bytes() == nil {
//         ctx.Compute(t)
//     }
//
//     s := make(S, mul(t.Shape()...))
//     if err := binary.Read(bytes.NewBuffer(t.Bytes()), binary.LittleEndian, &s); err != nil {
//         panic(err)
//     }
//
//     shape := t.Shape()
//     slices.Reverse(shape)
//
//     var sb strings.Builder
//     var f func([]int, int)
//     f = func(dims []int, stride int) {
//         prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
//         sb.WriteString("[")
//         defer func() { sb.WriteString("]") }()
//         for i := 0; i < dims[0]; i++ {
//             if i >= items && i < dims[0]-items {
//                 sb.WriteString("..., ")
//                 // skip to next printable element
//                 skip := dims[0] - 2*items
//                 if len(dims) > 1 {
//                     stride += mul(append(dims[1:], skip)...)
//                     fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix)
//                 }
//                 i += skip - 1
//             } else if len(dims) > 1 {
//                 f(dims[1:], stride)
//                 stride += mul(dims[1:]...)
//                 if i < dims[0]-1 {
//                     fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
//                 }
//             } else {
//                 text := fn(s[stride+i])
//                 if len(text) > 0 && text[0] != '-' {
//                     sb.WriteString(" ")
//                 }
//                 sb.WriteString(text)
//                 if i < dims[0]-1 {
//                     sb.WriteString(", ")
//                 }
//             }
//         }
//     }
//     f(shape, 0)
//
//     return sb.String()
// }
type DType int

const (
    DTypeBool DType = iota
    DTypeUint8
    DTypeUint16
    DTypeUint32
    DTypeUint64
    DTypeInt8
    DTypeInt16
    DTypeInt32
    DTypeInt64
    DTypeFloat16
    DTypeFloat32
    DTypeFloat64
    DTypeBfloat16
    DTypeComplex64
)
type SamplingMode int

const (
    SamplingModeNearest SamplingMode = iota
    SamplingModeBilinear
)