package ml

import (
	"fmt"
	"log/slog"
	"os"

	"github.com/ollama/ollama/fs"
)

type Backend interface {
	// Close frees all memory associated with this backend
	// Close()

	// Load(ctx context.Context, progress func(float32)) error

	// BackendMemory returns the memory allocations that were made for this model
	// BackendMemory() BackendMemory

	Config() fs.Config
	Get(name string) Tensor
	NewContext() Context
	// NewContextSize(size int) Context

	// Enumerate the devices available for inference via this backend
	// BackendDevices() []DeviceInfo
}

// BackendCacheConfig should be implemented by backends that need special output
// from the cache to meet specific requirements. It is frequently implemented in
// conjunction with ScaledDotProductAttention.
type BackendCacheConfig interface {
	CacheConfig() CacheConfig
}

// CacheConfig controls optimizations (mostly backend-specific) that may transform
// the output of the cache to work better with specific kernels.
type CacheConfig struct {
	// CachePadding specifies the multiple for the number of tokens of cache history
	// that will be returned from cache Get for k, v and mask. The capacity of the
	// cache itself will also be increased to a multiple of this size if needed.
	CachePadding int

	// PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put
	// and returns the permuted version via Get. This uses the cache copy operation
	// to avoid a Contiguous call on the permuted tensor.
	PermutedV bool

	// MaskDType specifies the data type for generating the mask. If unset it will
	// default to DTypeFloat32.
	MaskDType DType

	// MaskBatchPadding specifies the multiple for the batch size dimension in the mask.
	// Any position that does not correspond to an actual token will be filled with -Inf.
	MaskBatchPadding int
}
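// Example (illustrative sketch, not part of this package): a backend that wants
// padded, permuted cache output can satisfy BackendCacheConfig roughly as below.
// The exampleBackend type and the specific values are hypothetical:
//
//	func (b *exampleBackend) CacheConfig() CacheConfig {
//		return CacheConfig{
//			CachePadding:     32,
//			PermutedV:        true,
//			MaskDType:        DTypeFloat16,
//			MaskBatchPadding: 1,
//		}
//	}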
// BackendParams controls how the backend loads and executes models
type BackendParams struct {
	// AllocMemory causes the backend to allocate memory for the model. If
	// false, the backend is only being used to discover the required amount of
	// memory and cannot load the model for running.
	AllocMemory bool

	// NumThreads sets the number of threads to use if running on the CPU
	NumThreads int

	// GPULayers is the set of layers to offload to GPUs
	GPULayers GPULayersList

	// FlashAttention indicates that we should use a fused flash attention kernel
	FlashAttention bool
}

var backends = make(map[string]func(string, BackendParams) (Backend, error))

func RegisterBackend(name string, f func(string, BackendParams) (Backend, error)) {
	if _, ok := backends[name]; ok {
		panic("backend: backend already registered")
	}

	backends[name] = f
}

func NewBackend(modelPath string, params BackendParams) (Backend, error) {
	be := os.Getenv("OLLAMA_BACKEND")
	if be == "" {
		be = "mlx"
		slog.Info("Defaulting to " + be + ". Set OLLAMA_BACKEND to override")
	}
	slog.Info("Loading new engine", "backend", be)
	if backend, ok := backends[be]; ok {
		return backend(modelPath, params)
	}

	return nil, fmt.Errorf("unsupported backend")
}
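// Example (illustrative sketch, not part of this package): a backend can register
// itself from an init function and is then selected either via the OLLAMA_BACKEND
// environment variable or via the "mlx" default above. The "example" name and the
// newExampleBackend constructor are hypothetical:
//
//	func init() {
//		RegisterBackend("example", func(modelPath string, params BackendParams) (Backend, error) {
//			return newExampleBackend(modelPath, params)
//		})
//	}
//
//	// Callers then construct BackendParams and load the model:
//	b, err := NewBackend("/path/to/model", BackendParams{
//		AllocMemory:    true,
//		NumThreads:     8,
//		FlashAttention: true,
//	})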
type Context interface {
	Empty(dtype DType, shape ...int) Tensor
	Zeros(dtype DType, shape ...int) Tensor
	// FromBytes(dtype DType, s []byte, shape ...int) Tensor
	FromFloats(s []float32, shape ...int) Tensor
	FromInts(s []int32, shape ...int) Tensor
	RandomNormal(shape []int, dtype DType, loc, scale float32, key Tensor) Tensor

	// Arange creates a 1D tensor with values in the half-open interval [start, stop), increasing by step.
	Arange(start, stop, step float32, dtype DType) Tensor

	Forward(...Tensor) Context

	// SetBatchSize provides a hint on the batch size to optimize processing.
	// Uses heuristics if not set.
	// SetBatchSize(int)

	Compute(...Tensor)
	// ComputeWithNotify(func(), ...Tensor) // notify callback once compute has begun

	// Reserve is analogous to Compute but, rather than executing a
	// graph, simply preallocates memory. Typically called with a
	// worst case graph to ensure all resources are available
	// for future inference.
	// Reserve()

	// MaxGraphNodes() int
	Close()

	// Input returns a context appropriate for creating tensors that are
	// inputs to the model (which includes things like output locations)
	Input() Context

	// Layer returns a context appropriate for creating intermediate tensors
	Layer(int) Context

	// CompareWith loads tensors from the "filename" safetensors file and compares them
	// with the input tensors. It returns an error if a shape is inconsistent or a
	// similarity measure is below 99%.
	CompareWith(filename string, tensors map[string]Tensor, abortOnError bool) error
}
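// Example (illustrative sketch, not part of this package): a typical flow builds
// input tensors on the Input context, records the graph with Forward, and then
// materializes results with Compute. Here backend is a Backend from NewBackend,
// and the tensor name, shapes and values are hypothetical:
//
//	ctx := backend.NewContext()
//	defer ctx.Close()
//
//	x := ctx.Input().FromFloats([]float32{1, 2, 3, 4}, 2, 2)
//	w := backend.Get("example.weight") // hypothetical tensor name
//	y := x.Matmul(ctx, w)
//
//	ctx.Forward(y).Compute(y)
//	fmt.Println(y.Floats())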
// RoPEOptions holds optional parameters for RoPE, set via functional options.
type RoPEOptions struct {
	Base  *float32
	Freqs Tensor
}

// WithRoPEBase overrides the default RoPE frequency base.
func WithRoPEBase(base float32) func(*RoPEOptions) {
	return func(opts *RoPEOptions) {
		opts.Base = &base
	}
}

// WithRoPEFreqs supplies precomputed rotary frequencies.
func WithRoPEFreqs(freqs Tensor) func(*RoPEOptions) {
	return func(opts *RoPEOptions) {
		opts.Freqs = freqs
	}
}
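// Example (illustrative sketch, not part of this package): RoPE options are applied
// with the functional-options pattern. The dims, scale, offset, base and freqs values
// here are hypothetical:
//
//	q = q.RoPE(ctx, 128, false, 1.0, offset, WithRoPEBase(10000))
//
//	// or, with precomputed frequencies:
//	q = q.RoPE(ctx, 128, false, 1.0, offset, WithRoPEFreqs(freqs))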
type Tensor interface {
	ToString() string
	RoPE(ctx Context, dims int, traditional bool, scale float32, offset int, options ...func(*RoPEOptions)) Tensor
	ScaledDotProductAttention(ctx Context, keys, values Tensor, scale float64, maskMode string, mask Tensor, sinks Tensor) Tensor
	TakeAxes(ctx Context, indices Tensor, axes int) Tensor
	// TakeAxes(ctx Context, axes int, indices ...int) Tensor

	Dim(n int) int
	Stride(n int) int

	Shape() []int
	DType() DType
	// Cast(ctx Context, dtype DType) Tensor

	// Bytes() []byte
	Floats() []float32
	Ints() []int32

	// FromBytes([]byte)
	// FromFloats([]float32)
	// FromInts([]int32)

	Add(ctx Context, t2 Tensor) Tensor
	Sub(ctx Context, t2 Tensor) Tensor
	// Mul(ctx Context, t2 Tensor) Tensor
	// Div(ctx Context, t2 Tensor) Tensor

	Max(ctx Context, axes []int, keepDims bool) Tensor
	Min(ctx Context, axes []int, keepDims bool) Tensor

	Matmul(ctx Context, a2 Tensor) Tensor
	// Mulmat(ctx Context, t2 Tensor) Tensor
	// MulmatFullPrec(ctx Context, t2 Tensor) Tensor
	// MulmatID(ctx Context, t2, ids Tensor) Tensor
	// AddID(ctx Context, t2, ids Tensor) Tensor

	Softmax(ctx Context) Tensor
	L2Norm(ctx Context, eps float32) Tensor
	LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
	RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
	Scale(ctx Context, s float64) Tensor
	// SumRows(ctx Context) Tensor

	AvgPool2D(ctx Context, k, s int, p float32) Tensor
	Conv2D(ctx Context, weight Tensor, stride0, stride1, padding0, padding1, dilation0, dilation1, groups int) Tensor
	Conv3D(ctx Context, weight Tensor, stride0, stride1, stride2, padding0, padding1, padding2, dilation0, dilation1, dilation2, groups int) Tensor

	// IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor

	// Sin(ctx Context) Tensor
	// Cos(ctx Context) Tensor
	// Tanh(ctx Context) Tensor
	GELU(ctx Context, up ...Tensor) Tensor
	// QuickGELU(ctx Context, up ...Tensor) Tensor
	// SILU(ctx Context, up ...Tensor) Tensor
	// RELU(ctx Context, up ...Tensor) Tensor
	// Sigmoid(ctx Context) Tensor

	// SILUAlphaLimit is a variant of SILU that clamps the input to the range [-limit, limit]
	// SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor

	Reshape(ctx Context, shape ...int) Tensor
	AsStrided(ctx Context, shape, strides []int, offset int) Tensor
	Transpose(ctx Context, shape ...int) Tensor
	Contiguous(ctx Context, allowColMajor bool) Tensor

	// Pad(ctx Context, shape ...int) Tensor

	// Stack(ctx Context, dim int, s ...Tensor) Tensor

	// Repeat repeats the tensor n times along dimension dim
	// Repeat(ctx Context, dim, n int) Tensor
	// Concat(ctx Context, t2 Tensor, dim int) Tensor
	// Rows(ctx Context, t2 Tensor) Tensor

	// TODO: these probably aren't actually needed - false starts on trying to wire up the cache
	// SliceUpdate(ctx Context, update Tensor, start, stop, strides []int) Tensor
	// SliceUpdateDynamic(ctx Context, update, start Tensor, axes []int) Tensor
	// PutAlongAxis(ctx Context, indices, values Tensor, axis int) Tensor

	Scatter(ctx Context, indices []Tensor, updates Tensor, axes []int) Tensor

	Copy(ctx Context, t2 Tensor) Tensor
	// Duplicate(ctx Context) Tensor

	// Slice(ctx Context, dim, low, high, step int) Tensor
	// Chunk(ctx Context, dim int, size int) []Tensor
	// ChunkSections(ctx Context, dim int, sections ...int) []Tensor

	// TopK(ctx Context, k int) Tensor
	// Argsort(ctx Context) Tensor
	// Mean(ctx Context) Tensor
	// Variance(ctx Context) Tensor
	// Stddev(ctx Context) Tensor
	// Sqr(ctx Context) Tensor
	// Sqrt(ctx Context) Tensor

	// Interpolate(ctx Context, dims [4]int, samplingMode SamplingMode) Tensor
}
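// Example (illustrative sketch, not part of this package): Tensor operations chain
// on a Context to build a graph, for instance a simple feed-forward block with a
// residual connection. The x, wUp, wDown and normWeight tensors are hypothetical:
//
//	h := x.Matmul(ctx, wUp)
//	h = h.GELU(ctx)
//	h = h.Matmul(ctx, wDown)
//	out := x.Add(ctx, h).RMSNorm(ctx, normWeight, 1e-6)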
// ScaledDotProductAttention implements a fused attention
// operation equivalent to the following code on a tensor named
// query:
//
//	query = query.Permute(ctx, 0, 2, 1, 3)
//	key = key.Permute(ctx, 0, 2, 1, 3)
//	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
//
//	kq := key.MulmatFullPrec(ctx, query)
//
//	kq = kq.Scale(ctx, scale)
//
//	if mask != nil {
//		kq = kq.Add(ctx, mask)
//	}
//
//	kq = kq.Softmax(ctx)
//
//	kqv := value.Mulmat(ctx, kq)
//	return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
//
// type ScaledDotProductAttention interface {
//	ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64) Tensor
// }
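// Example (illustrative sketch, not part of this package): the fused method on the
// Tensor interface above would be called roughly as follows. The query, keys, values
// and headDim names, and the "causal" mask mode string, are assumptions for the sake
// of illustration, not values defined by this package:
//
//	scale := 1.0 / math.Sqrt(float64(headDim))
//	attn := query.ScaledDotProductAttention(ctx, keys, values, scale, "causal", nil, nil)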
// type number interface {
//	~int | ~int8 | ~int16 | ~int32 | ~int64 |
//		~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
//		~float32 | ~float64 |
//		~complex64 | ~complex128
// }

// func mul[T number](s ...T) T {
//	p := T(1)
//	for _, v := range s {
//		p *= v
//	}
//
//	return p
// }

// type DumpOptions func(*dumpOptions)

// // DumpWithPrecision sets the number of decimal places to print. Applies to float32 and float64.
// func DumpWithPrecision(n int) DumpOptions {
//	return func(opts *dumpOptions) {
//		opts.Precision = n
//	}
// }

// // DumpWithThreshold sets the threshold for printing the entire tensor. If the number of elements
// // is less than or equal to this value, the entire tensor will be printed. Otherwise, only the
// // beginning and end of each dimension will be printed.
// func DumpWithThreshold(n int) DumpOptions {
//	return func(opts *dumpOptions) {
//		opts.Threshold = n
//	}
// }

// // DumpWithEdgeItems sets the number of elements to print at the beginning and end of each dimension.
// func DumpWithEdgeItems(n int) DumpOptions {
//	return func(opts *dumpOptions) {
//		opts.EdgeItems = n
//	}
// }

// type dumpOptions struct {
//	Precision, Threshold, EdgeItems int
// }

// func Dump(ctx Context, t Tensor, optsFuncs ...DumpOptions) string {
//	opts := dumpOptions{Precision: 4, Threshold: 1000, EdgeItems: 3}
//	for _, optsFunc := range optsFuncs {
//		optsFunc(&opts)
//	}
//
//	if mul(t.Shape()...) <= opts.Threshold {
//		opts.EdgeItems = math.MaxInt
//	}
//
//	switch t.DType() {
//	case DTypeFloat32:
//		return dump[[]float32](ctx, t, opts.EdgeItems, func(f float32) string {
//			return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
//		})
//	case DTypeFloat16: // TODO other types...
//		f32 := ctx.Input().Empty(DTypeFloat32, t.Shape()...)
//		f32 = t.Copy(ctx, f32)
//		return dump[[]float32](ctx, f32, opts.EdgeItems, func(f float32) string {
//			return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
//		})
//	case DTypeInt32:
//		return dump[[]int32](ctx, t, opts.EdgeItems, func(i int32) string {
//			return strconv.FormatInt(int64(i), 10)
//		})
//	default:
//		return "<unsupported>"
//	}
// }

// func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
//	if t.Bytes() == nil {
//		ctx.Compute(t)
//	}
//
//	s := make(S, mul(t.Shape()...))
//	if err := binary.Read(bytes.NewBuffer(t.Bytes()), binary.LittleEndian, &s); err != nil {
//		panic(err)
//	}
//
//	shape := t.Shape()
//	slices.Reverse(shape)
//
//	var sb strings.Builder
//	var f func([]int, int)
//	f = func(dims []int, stride int) {
//		prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
//		sb.WriteString("[")
//		defer func() { sb.WriteString("]") }()
//		for i := 0; i < dims[0]; i++ {
//			if i >= items && i < dims[0]-items {
//				sb.WriteString("..., ")
//				// skip to next printable element
//				skip := dims[0] - 2*items
//				if len(dims) > 1 {
//					stride += mul(append(dims[1:], skip)...)
//					fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix)
//				}
//				i += skip - 1
//			} else if len(dims) > 1 {
//				f(dims[1:], stride)
//				stride += mul(dims[1:]...)
//				if i < dims[0]-1 {
//					fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
//				}
//			} else {
//				text := fn(s[stride+i])
//				if len(text) > 0 && text[0] != '-' {
//					sb.WriteString(" ")
//				}
//
//				sb.WriteString(text)
//				if i < dims[0]-1 {
//					sb.WriteString(", ")
//				}
//			}
//		}
//	}
//	f(shape, 0)
//
//	return sb.String()
// }
type DType int

const (
	DTypeBool DType = iota
	DTypeUint8
	DTypeUint16
	DTypeUint32
	DTypeUint64
	DTypeInt8
	DTypeInt16
	DTypeInt32
	DTypeInt64
	DTypeFloat16
	DTypeFloat32
	DTypeFloat64
	DTypeBfloat16
	DTypeComplex64
)

type SamplingMode int

const (
	SamplingModeNearest SamplingMode = iota
	SamplingModeBilinear
)