ollama/x/ml/backend/mlx/mlx.go
Daniel Hiltgen 33ee7168ba Add experimental MLX backend and engine with imagegen support (#13648)
* WIP - MLX backend with gemma3

* MLX: add cmake and go tag build toggles

To build the new MLX backend code:
  cmake --preset MLX
  cmake --build --preset MLX --parallel
  cmake --install build --component MLX
  go build -tags mlx .

Note: the main.go entrypoint for the MLX engine will change in a follow up commit.

* add experimental image generation runtime

* add experimental image generation runtime

* MLX: wire up cuda build for linux

* MLX: get dependencies correct and dedup

This is still too large for a unified github artifact, but is now "correct" for the mlx_cuda_v13
directory.

* fix relative link bug in dedup

* Add darwin build and readme

* add go build tag for mlx dependent code and wire up build_darwin.sh

* lint cleanup

* macos: build mlx for x86

This will be CPU only.

* cuda build instructions and fix drift from mlx bump

* stale comment

* Delete agent helper doc

* Clean up readme.md

* Revise README for tokenizer clarity and details

Updated README to clarify tokenizer functionality and removed correctness section.

---------

Co-authored-by: jmorganca <jmorganca@gmail.com>
2026-01-08 16:18:59 -08:00

//go:build mlx

package mlx

/*
#cgo CPPFLAGS: -I${SRCDIR}/../../../../build/_deps/mlx-c-src
#cgo LDFLAGS: -L${SRCDIR}/../../../../build/lib/ollama/ -lmlxc -lmlx
#cgo LDFLAGS: -framework Accelerate
#cgo LDFLAGS: -Wl,-rpath,${SRCDIR}/../../../../build/lib/ollama/

#include <stdio.h>
#include <stdlib.h>
#include "mlx/c/mlx.h"

static inline size_t stride(const mlx_array a, int i) { return mlx_array_strides(a)[i]; }

extern void goStackTrace();

// Fatal error handler: print the MLX error and the Go stack before exiting.
static void error_handler(const char *msg, void *data) {
	fprintf(stderr, "MLX error: %s\n", msg);
	goStackTrace();
	exit(-1); // TODO adjust so this can become a return code on the current thread instead of exit
}

static void set_error_handler() { mlx_set_error_handler(&error_handler, NULL, NULL); }

static void *mlx_array_data_float16_asvoid(const mlx_array a) { return (void *)mlx_array_data_float16(a); }

typedef const char cchar_t;
*/
import "C"

import (
	"encoding/json"
	"fmt"
	"log/slog"
	"math"
	"os"
	"path/filepath"
	"reflect"
	"runtime"
	"runtime/debug"
	"sync"
	"unsafe"

	"github.com/ollama/ollama/convert"
	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/x/ml"
	"github.com/x448/float16"
)

func init() {
	ml.RegisterBackend("mlx", New)
	C.set_error_handler()
}

//export goStackTrace
func goStackTrace() {
	debug.PrintStack()
}

type SafetensorsIndexMetadata struct {
	TotalSize uint64 `json:"total_size"`
}

type SafetensorsIndex struct {
	Metadata  SafetensorsIndexMetadata `json:"metadata"`
	WeightMap map[string]string        `json:"weight_map"`
}

type Backend struct {
	meta    fs.Config
	tensors map[string]*Array
}

func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
	// TODO assumes modelPath is actually a directory for now...
	kv, tokenizer, err := convert.LoadModelMetadata(os.DirFS(modelPath))
	if err != nil {
		return nil, fmt.Errorf("unable to load model: %w", err)
	}
	b := &Backend{
		meta: kv.KV(tokenizer),
	}
	err = b.LoadSafeTensors(modelPath)
	if err != nil {
		return nil, fmt.Errorf("safetensors load failed: %w", err)
	}
	return b, nil
}

func (b *Backend) LoadSafeTensors(dir string) error {
	if _, err := os.Stat(dir); err != nil {
		return fmt.Errorf("failed to stat dir: %w", err)
	}
	// other variations to try?
	stFilename := filepath.Join(dir, "model.safetensors.index.json")
	if _, err := os.Stat(stFilename); err != nil {
		return fmt.Errorf("failed to stat %s: %w", stFilename, err)
	}
	fp, err := os.Open(stFilename)
	if err != nil {
		return fmt.Errorf("failed to open safetensor index: %s: %w", stFilename, err)
	}
	defer fp.Close()
	decoder := json.NewDecoder(fp)
	var index SafetensorsIndex
	if err := decoder.Decode(&index); err != nil {
		return fmt.Errorf("decode error: %s: %w", stFilename, err)
	}
	slog.Info("XXX parsed metadata", "size", index.Metadata.TotalSize, "weights", len(index.WeightMap))
	filenames := map[string]struct{}{}
	for _, filename := range index.WeightMap {
		filenames[filename] = struct{}{}
	}
	stream := C.mlx_default_cpu_stream_new()
	b.tensors = map[string]*Array{}
	for filename := range filenames {
		filepath := filepath.Join(dir, filename)
		if _, err := os.Stat(filepath); err != nil {
			return fmt.Errorf("failed to stat %s: %w", filepath, err)
		}
		slog.Info("Loading tensors from", "filename", filename)
		cFilename := C.CString(filepath)
		defer C.free(unsafe.Pointer(cFilename))
		data := C.mlx_map_string_to_array_new() // TODO is this needed or just var it?
		metadata := C.mlx_map_string_to_string_new()
		defer C.mlx_map_string_to_array_free(data)
		defer C.mlx_map_string_to_string_free(metadata)
		if C.mlx_load_safetensors(&data, &metadata, cFilename, stream) != 0 {
			// TODO with the current error handling, this will never happen
			return fmt.Errorf("load failed")
		}
		it := C.mlx_map_string_to_array_iterator_new(data)
		// defer C.mlx_array_free(shaped)
		// TODO confusing how memory management works with this...
		for {
			var key *C.cchar_t
			var value C.mlx_array
			if C.mlx_map_string_to_array_iterator_next(&key, &value, it) != 0 {
				break
			}
			k := C.GoString((*C.char)(key))
			b.tensors[k] = &Array{
				name: k,
				a:    value,
			}
			// slog.Info("XXX read", "tensor", b.tensors[k], "type", b.tensors[k].TypeString())
		}
	}
	return nil
}

func (b *Backend) Get(name string) ml.Tensor {
	var t ml.Tensor
	var ok bool
	if t, ok = b.tensors[name]; !ok {
		// slog.Warn("unable to locate", "tensor", name)
		return nil
	}
	// slog.Info("Fetching", "tensor", name, "type", b.tensors[name].TypeString())
	return t
}
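
// Usage sketch (hypothetical paths and tensor names): New expects modelPath
// to be a directory holding model.safetensors.index.json plus the shard
// files it references:
//
//	backend, err := New("/path/to/model", ml.BackendParams{})
//	if err != nil {
//		// handle load failure
//	}
//	embed := backend.Get("model.embed_tokens.weight") // nil when absent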

func (b *Backend) NewContext() ml.Context {
	// slog.Info("MLX.NewContext")
	return &Context{
		stream: C.mlx_default_gpu_stream_new(),
	}
}

func (b *Backend) Config() fs.Config {
	return b.meta
}

type Context struct {
	stream C.mlx_stream
	mu     sync.Mutex
	arrays []C.mlx_array // TODO should we do some bookkeeping to ensure none of these Arrays are still lingering?
}

func (c *Context) Close() {
	// C.mlx_synchronize(c.stream) // ???
	C.mlx_stream_free(c.stream)
	c.mu.Lock()
	defer c.mu.Unlock()
	for _, a := range c.arrays {
		slog.Info("XXX freeing", "array", a)
		C.mlx_array_free(a)
	}
}

func (c *Context) Compute(tensors ...ml.Tensor) {
	// TODO - for the zero tensor case this feels like it might not be correct...
	needSync := true
	sync := func() {
		if needSync {
			C.mlx_synchronize(c.stream)
			needSync = false
		}
	}
	vec := C.mlx_vector_array_new()
	defer C.mlx_vector_array_free(vec)
	for _, t := range tensors {
		C.mlx_vector_array_append_value(vec, t.(*Array).a)
		t.(*Array).sync = sync
	}
	C.mlx_async_eval(vec)
}

func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
	vec := C.mlx_vector_array_new()
	defer C.mlx_vector_array_free(vec)
	needSync := true
	sync := func() {
		if needSync {
			C.mlx_synchronize(c.stream)
			needSync = false
		}
	}
	for _, t := range tensors {
		t.(*Array).sync = sync
		C.mlx_vector_array_append_value(vec, t.(*Array).a)
	}
	C.mlx_async_eval(vec)
	return c
}
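
// Evaluation is lazy: tensor ops only build a graph. A typical flow (sketch,
// assuming x, w, and b are tensors from this backend and ctx is a *Context)
// queues an async eval and lets the first host read synchronize the stream:
//
//	out := x.Matmul(ctx, w).Add(ctx, b)
//	ctx.Forward(out)     // mlx_async_eval on the context's stream
//	vals := out.Floats() // triggers the deferred mlx_synchronize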

func (c *Context) Input() ml.Context {
	return c
}

// func (c *Context) Output() ml.Context {
// 	return c
// }

func (c *Context) Layer(_ int) ml.Context {
	return c
}

func (c *Context) RandomNormal(shape []int, dtype ml.DType, loc, scale float32, key ml.Tensor) ml.Tensor {
	var r C.mlx_array
	var k C.mlx_array
	if key != nil {
		k = key.(*Array).a
	}
	sh := make([]C.int, len(shape))
	for i := range shape {
		sh[i] = C.int(shape[i])
	}
	C.mlx_random_normal(
		&r,
		&sh[0],
		C.size_t(len(shape)),
		C.mlx_dtype(dtype),
		C.float(loc),
		C.float(scale),
		k,
		c.stream,
	)
	return newArray(c, r)
}

func (c *Context) CompareWith(filepath string, tensors map[string]ml.Tensor, abortOnError bool) (err error) {
	minCosine := float32(0.96) // TODO too low...
	fileTensors := map[string]*Array{}
	defer func() {
		if err != nil {
			for k, v := range tensors {
				fmt.Fprintln(os.Stderr, "input tensor "+k+"\n"+v.ToString())
				if fv, ok := fileTensors[k]; ok {
					fmt.Fprintln(os.Stderr, " file tensor "+k+"\n"+fv.ToString())
				} else {
					fmt.Fprintln(os.Stderr, " file tensor "+k+" missing!")
				}
			}
			if abortOnError {
				panic(fmt.Sprintf("%s", err))
			}
		}
	}()
	if _, err = os.Stat(filepath); err != nil {
		filepath += ".safetensors"
		if _, err = os.Stat(filepath); err != nil {
			err = fmt.Errorf("failed to stat %s: %w", filepath, err)
			return
		}
		err = nil
	}
	// slog.Info("Loading tensors from", "filename", filepath)
	cFilename := C.CString(filepath)
	defer C.free(unsafe.Pointer(cFilename))
	data := C.mlx_map_string_to_array_new() // TODO is this needed or just var it?
	metadata := C.mlx_map_string_to_string_new()
	defer C.mlx_map_string_to_array_free(data)
	defer C.mlx_map_string_to_string_free(metadata)
	stream := C.mlx_default_cpu_stream_new()
	if C.mlx_load_safetensors(&data, &metadata, cFilename, stream) != 0 {
		// TODO with the current error handling, this will never happen
		err = fmt.Errorf("load failed")
		return
	}
	it := C.mlx_map_string_to_array_iterator_new(data)
	allTensors := []ml.Tensor{}
	for _, t := range tensors {
		allTensors = append(allTensors, t)
	}
	for {
		var key *C.cchar_t
		var value C.mlx_array
		if C.mlx_map_string_to_array_iterator_next(&key, &value, it) != 0 {
			break
		}
		k := C.GoString((*C.char)(key))
		var r C.mlx_array
		C.mlx_astype(
			&r,
			value,
			C.MLX_FLOAT32,
			stream,
		)
		// The converted array keeps its own reference to value, so the raw
		// handle from the iterator can be released here. Note: deferring the
		// free before the handle is assigned would capture a zero value.
		C.mlx_array_free(value)
		defer C.mlx_array_free(r)
		fileTensors[k] = &Array{
			name: k,
			a:    r,
		}
		// slog.Info("XXX read", "tensor", t, "type", t.TypeString())
		allTensors = append(allTensors, fileTensors[k])
	}
	c.Forward(allTensors...)
	for k, t := range tensors {
		a, ok := fileTensors[k]
		if !ok {
			err = fmt.Errorf("tensor named %s not found in file", k)
			return
		}
		if !reflect.DeepEqual(a.Shape(), t.Shape()) {
			err = fmt.Errorf("mismatched shapes: file: %v vs. input %v", a.Shape(), t.Shape())
			return
		}
		// slog.Info("XXX shapes match", "shape", t.Shape())
		// TODO handle int types...
		tDType := t.DType()
		if tDType != ml.DTypeFloat16 && tDType != ml.DTypeFloat32 {
			var r C.mlx_array
			C.mlx_astype(
				&r,
				t.(*Array).a,
				C.MLX_FLOAT32,
				stream,
			)
			defer C.mlx_array_free(r)
			t = &Array{
				a: r,
			}
			c.Forward(t)
		}
		af := a.Floats()
		tf := t.Floats()
		cos := cosineSimilarity(af, tf)
		diff := a.Sub(c, t)
		min := diff.Min(c, nil, true)
		max := diff.Max(c, nil, true)
		c.Forward(min, max)
		minf := min.Floats()
		maxf := max.Floats()
		if cos < minCosine {
			err = fmt.Errorf("%s shapes match, but not similar enough: %v min_difference=%v max_difference=%v", k, cos, minf, maxf)
			return
		}
		slog.Info("XXX tensors are similar", k, cos, "shape", t.Shape(), "min_difference", minf, "max_difference", maxf)
	}
	err = nil
	return
}

func dotProduct[V float32 | float64](v1, v2 []V) V {
	var result V = 0
	if len(v1) != len(v2) {
		return result
	}
	for i := 0; i < len(v1); i++ {
		result += v1[i] * v2[i]
	}
	return result
}

func magnitude[V float32 | float64](v []V) V {
	var result V = 0
	for _, val := range v {
		result += val * val
	}
	return V(math.Sqrt(float64(result)))
}

func cosineSimilarity[V float32 | float64](v1, v2 []V) V {
	mag1 := magnitude(v1)
	mag2 := magnitude(v2)
	if mag1 == 0 || mag2 == 0 {
		return 0
	}
	return dotProduct(v1, v2) / (mag1 * mag2)
}

func euclideanDistance[V float32 | float64](v1, v2 []V) V {
	if len(v1) != len(v2) {
		return V(math.Inf(1))
	}
	var sum V = 0
	for i := 0; i < len(v1); i++ {
		diff := v1[i] - v2[i]
		sum += diff * diff
	}
	return V(math.Sqrt(float64(sum)))
}

func manhattanDistance[V float32 | float64](v1, v2 []V) V {
	if len(v1) != len(v2) {
		return V(math.Inf(1))
	}
	var sum V = 0
	for i := 0; i < len(v1); i++ {
		sum += V(math.Abs(float64(v1[i] - v2[i])))
	}
	return sum
}
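
// Worked example for the similarity helpers above (values rounded):
//
//	cosineSimilarity([]float32{1, 0}, []float32{1, 1})  // ≈ 0.7071 (1/√2)
//	euclideanDistance([]float32{1, 0}, []float32{1, 1}) // = 1
//	manhattanDistance([]float32{1, 0}, []float32{1, 1}) // = 1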

type Array struct {
	name string
	a    C.mlx_array
	c    *Context
	sync func()
}

func newArray(ctx *Context, a C.mlx_array) *Array {
	// TODO measure impact and if this slows things down, make it conditional on some debugging flag at load time
	var name string
	_, f, l, ok := runtime.Caller(2)
	if ok {
		name = fmt.Sprintf("%s:%d", f, l)
	}
	t := &Array{
		name: name,
		a:    a,
		c:    ctx,
	}
	// DEBUG memory allocation problems...
	// slog.Info("XXX Allocated", "array", t, "a", a)
	ctx.mu.Lock()
	defer ctx.mu.Unlock()
	ctx.arrays = append(ctx.arrays, a)
	return t
}

// FromFloats implements ml.Context.
func (c *Context) FromFloats(s []float32, shape ...int) ml.Tensor {
	u16s := make([]float16.Float16, len(s))
	for i := range u16s {
		u16s[i] = float16.Fromfloat32(s[i])
	}
	cshape := make([]C.int, len(shape))
	for i, dim := range shape {
		cshape[i] = C.int(dim)
	}
	return newArray(c,
		C.mlx_array_new_data(
			unsafe.Pointer(&u16s[0]),
			&cshape[0],
			C.int(len(cshape)),
			C.MLX_FLOAT16,
		),
	)
}
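
// Note that FromFloats materializes the data as float16, so round-trips
// through Floats() are lossy (float16 carries roughly 3 significant decimal
// digits). Sketch:
//
//	t := ctx.FromFloats([]float32{1, 2, 3, 4}, 2, 2)
//	ctx.Forward(t)
//	_ = t.Floats() // [1 2 3 4], stored internally as MLX_FLOAT16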

func (a *Array) Floats() []float32 {
	if a.sync != nil {
		a.sync()
	}
	l := (int)(C.mlx_array_size(a.a))
	switch C.mlx_array_dtype(a.a) {
	case C.MLX_BFLOAT16:
		panic("bfloat16 not yet implemented")
	case C.MLX_FLOAT16:
		data := C.mlx_array_data_float16_asvoid(a.a)
		if data == nil {
			panic("nil data, wasn't eval'd")
		}
		u16s := unsafe.Slice((*uint16)(data), l)
		f32s := make([]float32, len(u16s))
		for i := range u16s {
			f32s[i] = float16.Frombits(u16s[i]).Float32()
		}
		return f32s
	case C.MLX_FLOAT32:
		data := C.mlx_array_data_float32(a.a)
		if data == nil {
			panic("nil data, wasn't eval'd")
		}
		f32s := unsafe.Slice((*float32)(data), l)
		return f32s
	default:
		panic(fmt.Sprintf("unsupported dtype for Floats: %d", C.mlx_array_dtype(a.a)))
	}
}

// FromInts implements ml.Context.
func (c *Context) FromInts(s []int32, shape ...int) ml.Tensor {
	cshape := make([]C.int, len(shape))
	for i, dim := range shape {
		cshape[i] = C.int(dim)
	}
	return newArray(c,
		C.mlx_array_new_data(
			unsafe.Pointer(&s[0]),
			&cshape[0],
			C.int(len(cshape)),
			C.MLX_INT32,
		),
	)
}

func (a *Array) Ints() []int32 {
	if a.sync != nil {
		a.sync()
	}
	l := (int)(C.mlx_array_size(a.a))
	switch C.mlx_array_dtype(a.a) {
	case C.MLX_INT32:
		data := C.mlx_array_data_int32(a.a)
		if data == nil {
			panic("nil data, wasn't eval'd")
		}
		i32s := unsafe.Slice((*int32)(data), l)
		return i32s
	// TODO other types via conversion?
	default:
		panic(fmt.Sprintf("unsupported dtype for Ints: %d", C.mlx_array_dtype(a.a)))
	}
}

func (c *Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
	sh := make([]C.int, len(shape))
	for i, s := range shape {
		sh[i] = (C.int)(s)
	}
	var r C.mlx_array
	C.mlx_zeros(
		&r,
		&sh[0],
		(C.size_t)(len(sh)),
		C.mlx_dtype(dtype),
		c.stream,
	)
	return newArray(c, r)
}

func (c *Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
	// TODO more efficient impl?
	return c.Zeros(dtype, shape...)
}

func (a *Array) DType() ml.DType {
	return (ml.DType)(C.mlx_array_dtype(a.a))
}

func (a *Array) Dim(n int) int {
	return int(C.mlx_array_dim(a.a, C.int(n)))
}

func (a *Array) Stride(n int) int {
	return (int)(C.stride(a.a, (C.int)(n)))
}

func (c *Context) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
	var r C.mlx_array
	C.mlx_arange(
		&r,
		C.double(start),
		C.double(stop),
		C.double(step),
		(C.mlx_dtype)(dtype),
		c.stream,
	)
	return newArray(c, r)
}

// Scale implements ml.Tensor.
func (a *Array) Scale(ctx ml.Context, s float64) ml.Tensor {
	scale := C.mlx_array_new_float(C.float(s))
	defer C.mlx_array_free(scale) // the multiply op retains its own reference
	var r C.mlx_array
	C.mlx_multiply(
		&r,
		a.a,
		scale,
		ctx.(*Context).stream,
	)
	return newArray(ctx.(*Context), r)
}

func (a *Array) Softmax(ctx ml.Context) ml.Tensor {
	var r C.mlx_array
	C.mlx_softmax(
		&r,
		a.a,
		false, // TODO - precise?
		ctx.(*Context).stream,
	)
	return newArray(ctx.(*Context), r)
}

func (a *Array) SliceUpdate(ctx ml.Context, update ml.Tensor, start, stop, strides []int) ml.Tensor {
	cStart := make([]C.int, len(start))
	for i := range start {
		cStart[i] = C.int(start[i])
	}
	cStop := make([]C.int, len(stop))
	for i := range stop {
		cStop[i] = C.int(stop[i])
	}
	cStrides := make([]C.int, len(strides))
	for i := range strides {
		cStrides[i] = C.int(strides[i])
	}
	var r C.mlx_array
	C.mlx_slice_update(
		&r,
		a.a,
		update.(*Array).a,
		(*C.int)(unsafe.Pointer(&cStart[0])),
		C.size_t(len(cStart)),
		(*C.int)(unsafe.Pointer(&cStop[0])),
		C.size_t(len(cStop)),
		(*C.int)(unsafe.Pointer(&cStrides[0])),
		C.size_t(len(cStrides)),
		ctx.(*Context).stream,
	)
	// Release the old array and replace with the new one to ensure the same underlying buffer is used
	a.c.mu.Lock()
	defer a.c.mu.Unlock()
	for i := range a.c.arrays {
		if a.c.arrays[i] == a.a {
			C.mlx_array_free(a.a)
			a.a = r
			a.c.arrays = append(a.c.arrays[:i], a.c.arrays[i+1:]...)
			return a
		}
	}
	panic("unable to locate array in context")
}
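
// Sketch of a KV-cache style in-place write (hypothetical shapes: cache is
// [1, H, T_max, D], kCur is [1, H, 1, D], pos is the current position):
//
//	cache.SliceUpdate(ctx, kCur,
//		[]int{0, 0, pos, 0},     // start
//		[]int{1, H, pos + 1, D}, // stop (exclusive)
//		[]int{1, 1, 1, 1})       // strides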

func (a *Array) SliceUpdateDynamic(ctx ml.Context, update, start ml.Tensor, axes []int) ml.Tensor {
	cAxes := make([]C.int, len(axes))
	for i := range axes {
		cAxes[i] = C.int(axes[i])
	}
	var r C.mlx_array
	C.mlx_slice_update_dynamic(
		&r,
		a.a,
		update.(*Array).a,
		start.(*Array).a,
		(*C.int)(unsafe.Pointer(&cAxes[0])),
		C.size_t(len(cAxes)),
		ctx.(*Context).stream,
	)
	// Release the old array and replace with the new one to ensure the same underlying buffer is used
	a.c.mu.Lock()
	defer a.c.mu.Unlock()
	for i := range a.c.arrays {
		if a.c.arrays[i] == a.a {
			C.mlx_array_free(a.a)
			a.a = r
			a.c.arrays = append(a.c.arrays[:i], a.c.arrays[i+1:]...)
			return a
		}
	}
	panic("unable to locate array in context")
}

func (a *Array) PutAlongAxis(ctx ml.Context, indices, values ml.Tensor, axis int) ml.Tensor {
	var r C.mlx_array
	C.mlx_put_along_axis(
		&r,
		a.a,
		indices.(*Array).a,
		values.(*Array).a,
		C.int(axis),
		ctx.(*Context).stream,
	)
	// Release the old array and replace with the new one to ensure the same underlying buffer is used
	a.c.mu.Lock()
	defer a.c.mu.Unlock()
	for i := range a.c.arrays {
		if a.c.arrays[i] == a.a {
			C.mlx_array_free(a.a)
			a.a = r
			a.c.arrays = append(a.c.arrays[:i], a.c.arrays[i+1:]...)
			return a
		}
	}
	panic("unable to locate array in context")
}

func (a *Array) Scatter(ctx ml.Context, indices []ml.Tensor, updates ml.Tensor, axes []int) ml.Tensor {
	cAxes := make([]C.int, len(axes))
	for i := range axes {
		cAxes[i] = C.int(axes[i])
	}
	var cAxes0 *C.int
	if len(cAxes) > 0 {
		cAxes0 = (*C.int)(unsafe.Pointer(&cAxes[0]))
	}
	indicesVec := C.mlx_vector_array_new()
	defer C.mlx_vector_array_free(indicesVec)
	for _, ind := range indices {
		C.mlx_vector_array_append_value(indicesVec, ind.(*Array).a)
	}
	var r C.mlx_array
	C.mlx_scatter(
		&r,
		a.a,
		indicesVec,
		updates.(*Array).a,
		cAxes0,
		C.size_t(len(cAxes)),
		ctx.(*Context).stream,
	)
	// Release the old array and replace with the new one to ensure the same underlying buffer is used
	a.c.mu.Lock()
	defer a.c.mu.Unlock()
	for i := range a.c.arrays {
		if a.c.arrays[i] == a.a {
			C.mlx_array_free(a.a)
			a.a = r
			a.c.arrays[i] = r
			return a
		}
	}
	panic("unable to locate array in context")
}

func (a *Array) Copy(ctx ml.Context, a2 ml.Tensor) ml.Tensor {
	C.mlx_copy(
		&a2.(*Array).a,
		a.a,
		ctx.(*Context).stream,
	)
	// TODO - view?
	return newArray(ctx.(*Context), a2.(*Array).a)
}

func (a *Array) Add(ctx ml.Context, a2 ml.Tensor) ml.Tensor {
	var r C.mlx_array
	C.mlx_add(
		&r,
		a.a,
		a2.(*Array).a,
		ctx.(*Context).stream,
	)
	return newArray(ctx.(*Context), r)
}

func (a *Array) Sub(ctx ml.Context, a2 ml.Tensor) ml.Tensor {
	var r C.mlx_array
	C.mlx_subtract(
		&r,
		a.a,
		a2.(*Array).a,
		ctx.(*Context).stream,
	)
	return newArray(ctx.(*Context), r)
}

func (a *Array) Max(ctx ml.Context, axes []int, keepDims bool) ml.Tensor {
	var r C.mlx_array
	cAxes := make([]C.int, len(axes))
	for i := range axes {
		cAxes[i] = C.int(axes[i])
	}
	var cAxes0 *C.int
	if len(cAxes) > 0 {
		cAxes0 = (*C.int)(unsafe.Pointer(&cAxes[0]))
		C.mlx_max_axes(
			&r,
			a.a,
			cAxes0,
			C.size_t(len(cAxes)),
			C._Bool(keepDims),
			ctx.(*Context).stream,
		)
	} else {
		C.mlx_max(
			&r,
			a.a,
			C._Bool(keepDims),
			ctx.(*Context).stream,
		)
	}
	return newArray(ctx.(*Context), r)
}

func (a *Array) Min(ctx ml.Context, axes []int, keepDims bool) ml.Tensor {
	var r C.mlx_array
	cAxes := make([]C.int, len(axes))
	for i := range axes {
		cAxes[i] = C.int(axes[i])
	}
	var cAxes0 *C.int
	if len(cAxes) > 0 {
		cAxes0 = (*C.int)(unsafe.Pointer(&cAxes[0]))
		C.mlx_min_axes(
			&r,
			a.a,
			cAxes0,
			C.size_t(len(cAxes)),
			C._Bool(keepDims),
			ctx.(*Context).stream,
		)
	} else {
		C.mlx_min(
			&r,
			a.a,
			C._Bool(keepDims),
			ctx.(*Context).stream,
		)
	}
	return newArray(ctx.(*Context), r)
}

func (a *Array) Matmul(ctx ml.Context, a2 ml.Tensor) ml.Tensor {
	var r C.mlx_array
	C.mlx_matmul(
		&r,
		a.a,
		a2.(*Array).a,
		ctx.(*Context).stream,
	)
	return newArray(ctx.(*Context), r)
}

func (a *Array) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
	// slog.Info("MLX.RMSNorm", "a", a, "w", w)
	var r C.mlx_array
	C.mlx_fast_rms_norm(
		&r,
		a.a,
		w.(*Array).a,
		C.float(eps),
		ctx.(*Context).stream,
	)
	return newArray(ctx.(*Context), r)
}

func (a *Array) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
	var r C.mlx_array
	C.mlx_fast_layer_norm(
		&r,
		a.a,
		w.(*Array).a,
		b.(*Array).a,
		C.float(eps),
		ctx.(*Context).stream,
	)
	return newArray(ctx.(*Context), r)
}

func (a *Array) L2Norm(ctx ml.Context, eps float32) ml.Tensor {
	// TODO implement
	panic("NOT YET IMPLEMENTED")
}

func (a *Array) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
	panic("NOT YET IMPLEMENTED")
}

// RoPE implements Rotary Positional Encoding.
//
//   - dims (int) The feature dimensions to be rotated. If the input feature is larger than dims then the rest is left unchanged.
//   - traditional (bool) If set to True choose the traditional implementation which rotates consecutive dimensions.
//   - scale (float) The scale used to scale the positions.
//   - offset (int) The position offset to start at. TODO MLX-C does not yet expose Offset as an Array
//   - WithBase (float, optional) The base used to compute angular frequency for each dimension in the positional encodings. Exactly one of base and freqs must be None.
//   - WithFreqs (array, optional) Optional frequencies to use with RoPE. If set, the base parameter must be None. Default: None.
func (a *Array) RoPE(ctx ml.Context, dims int, traditional bool, scale float32, offset int, options ...func(*ml.RoPEOptions)) ml.Tensor {
	opts := ml.RoPEOptions{}
	// Apply any provided options
	for _, option := range options {
		option(&opts)
	}
	var r C.mlx_array
	var base C.mlx_optional_float
	var freqs C.mlx_array
	if opts.Base != nil {
		base.value = C.float(*opts.Base)
		base.has_value = true
	}
	if opts.Freqs != nil {
		freqs = opts.Freqs.(*Array).a
	}
	C.mlx_fast_rope(
		&r,
		a.a,
		C.int(dims),
		C._Bool(traditional),
		base,
		C.float(scale),
		C.int(offset),
		freqs,
		ctx.(*Context).stream,
	)
	return newArray(ctx.(*Context), r)
}
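
// Minimal usage sketch (hypothetical; assumes q holds queries shaped
// [B, N_q, T_q, D], ctx came from Backend.NewContext, and RoPEOptions.Base
// is a *float32):
//
//	base := float32(10000)
//	rotated := q.RoPE(ctx, headDim, false, 1.0, offset,
//		func(o *ml.RoPEOptions) { o.Base = &base })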

// ScaledDotProductAttention is a fast implementation of multi-head attention: O = softmax(Q @ K.T, dim=-1) @ V.
//
// Supports:
//   - Multi-Head Attention
//   - Grouped Query Attention
//   - Multi-Query Attention
//
// Note:
//   - The softmax operation is performed in float32 regardless of the input precision.
//   - For Grouped Query Attention and Multi-Query Attention, the k and v inputs should not be pre-tiled to match q.
//
// In the following the dimensions are given by:
//   - B: The batch size.
//   - N_q: The number of query heads.
//   - N_kv: The number of key and value heads.
//   - T_q: The number of queries per example.
//   - T_kv: The number of keys and values per example.
//   - D: The per-head dimension.
//
// Parameters:
//   - [subject array] queries (array) Queries with shape [B, N_q, T_q, D].
//   - keys (array) with shape [B, N_kv, T_kv, D].
//   - values (array) with shape [B, N_kv, T_kv, D].
//   - scale (float) Scale for queries (typically 1.0 / sqrt(q.shape(-1))).
//   - mask (str or array, optional) The mask to apply to the query-key scores.
//     The mask can be an array or a string indicating the mask type. The only supported string type is "causal".
//     If the mask is an array it can be a boolean or additive mask. The mask can have at most 4 dimensions and
//     must be broadcast-compatible with the shape [B, N, T_q, T_kv]. If an additive mask is given its type must
//     promote to the promoted type of q, k, and v.
//   - sinks (array, optional) An optional array of attention sinks. Default: None.
func (queries *Array) ScaledDotProductAttention(ctx ml.Context, keys, values ml.Tensor, scale float64, maskMode string, mask ml.Tensor, sinks ml.Tensor) ml.Tensor {
	var r C.mlx_array
	var s C.mlx_array
	if sinks != nil {
		s = sinks.(*Array).a
	}
	maskModeC := C.CString(maskMode)
	defer C.free(unsafe.Pointer(maskModeC))
	var maskArr C.mlx_array
	if mask != nil {
		maskArr = mask.(*Array).a
	}
	C.mlx_fast_scaled_dot_product_attention(
		&r,
		queries.a,
		keys.(*Array).a,
		values.(*Array).a,
		C.float(scale),
		maskModeC,
		maskArr,
		s,
		ctx.(*Context).stream,
	)
	return newArray(ctx.(*Context), r)
}
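
// Minimal causal-attention sketch (hypothetical; q is [B, N_q, T_q, D],
// k and v are [B, N_kv, T_kv, D], and headDim == D):
//
//	scale := 1.0 / math.Sqrt(float64(headDim))
//	out := q.ScaledDotProductAttention(ctx, k, v, scale, "causal", nil, nil)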

func (a *Array) TakeAxes(ctx ml.Context, indices ml.Tensor, axes int) ml.Tensor {
	var r C.mlx_array
	C.mlx_take_axis(&r, a.a, indices.(*Array).a, C.int(axes), ctx.(*Context).stream)
	return newArray(ctx.(*Context), r)
}

// TODO not sure if we'll want this variation taking raw ints instead of a tensor...
// func (a *Array) TakeAxes(ctx ml.Context, axes int, indices ...int) ml.Tensor {
// 	var i C.mlx_array
// 	var r C.mlx_array
// 	if indices != nil {
// 		shape := []C.int{C.int(len(indices))}
// 		cindices := make([]int32, len(indices))
// 		for i, v := range indices {
// 			cindices[i] = int32(v)
// 		}
// 		i = C.mlx_array_new_data(
// 			unsafe.Pointer(&cindices[0]),
// 			&shape[0],
// 			C.int(len(shape)),
// 			C.MLX_INT32,
// 		)
// 	}
// 	C.mlx_take_axis(&r, a.a, i, C.int(axes), ctx.(*Context).stream)
// 	return newArray(ctx.(*Context), r)
// }

func (a *Array) GELU(ctx ml.Context, up ...ml.Tensor) ml.Tensor {
	// TODO precise vs fast, and compile
	// Fast sigmoid approximation of GELU: x * mx.sigmoid(1.702 * x),
	// optionally fused with a gating multiply by up[0].
	u16s := []float16.Float16{float16.Fromfloat32(1.702)}
	cshape := []C.int{1}
	f := C.mlx_array_new_data(unsafe.Pointer(&u16s[0]), &cshape[0], 1, C.MLX_FLOAT16)
	defer C.mlx_array_free(f)
	var r1, r2, r3 C.mlx_array
	C.mlx_multiply(&r1, a.a, f, ctx.(*Context).stream)
	defer C.mlx_array_free(r1)
	C.mlx_sigmoid(&r2, r1, ctx.(*Context).stream)
	defer C.mlx_array_free(r2)
	C.mlx_multiply(&r3, a.a, r2, ctx.(*Context).stream)
	if len(up) > 0 {
		var r4 C.mlx_array
		defer C.mlx_array_free(r3)
		C.mlx_multiply(&r4, r3, up[0].(*Array).a, ctx.(*Context).stream)
		return newArray(ctx.(*Context), r4)
	}
	return newArray(ctx.(*Context), r3)
}

// AsStrided creates a view into the array with the given shape and strides.
//
// The resulting array will always be as if the provided array was row
// contiguous regardless of the provided array's storage order and current
// strides.
//
// Note that this function should be used with caution as it changes the shape
// and strides of the array directly. This can lead to the resulting array
// pointing to invalid memory locations, which can result in crashes.
//
// Parameters:
//   - shape (list(int), optional) The shape of the resulting array. If None it defaults to a.shape().
//   - strides (list(int), optional) The strides of the resulting array. If None it defaults to the
//     reverse exclusive cumulative product of a.shape().
//   - offset (int) Skip that many elements from the beginning of the input array.
func (a *Array) AsStrided(ctx ml.Context, shape, strides []int, offset int) ml.Tensor {
	var r C.mlx_array
	sh := make([]C.int, len(shape))
	st := make([]C.int64_t, len(strides))
	var sh0 *C.int
	var st0 *C.int64_t
	for i, s := range shape {
		sh[i] = C.int(s)
	}
	for i, s := range strides {
		st[i] = C.int64_t(s)
	}
	if len(sh) > 0 {
		sh0 = (*C.int)(unsafe.Pointer(&sh[0]))
	}
	if len(st) > 0 {
		st0 = (*C.int64_t)(unsafe.Pointer(&st[0]))
	}
	C.mlx_as_strided(
		&r,
		a.a,
		sh0,
		C.size_t(len(sh)),
		st0,
		C.size_t(len(st)),
		C.size_t(offset),
		ctx.(*Context).stream,
	)
	return newArray(ctx.(*Context), r)
}
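
// Example sketch (hypothetical): given a contiguous [4, 6] array t, a [2, 3]
// view over rows 1-2 of the first three columns is
//
//	v := t.AsStrided(ctx, []int{2, 3}, []int{6, 1}, 6)
//
// i.e. row stride 6, element stride 1, skipping the first row (offset 6).
// Callers must keep shape, strides, and offset within the source buffer.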

func (a *Array) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
	cshape := make([]C.int, len(shape))
	for i, dim := range shape {
		cshape[i] = C.int(dim)
	}
	var r C.mlx_array
	C.mlx_reshape(&r, a.a, &cshape[0], C.size_t(len(cshape)), ctx.(*Context).stream)
	return newArray(ctx.(*Context), r)
}

func (a *Array) Transpose(ctx ml.Context, shape ...int) ml.Tensor {
	ndim := min(C.mlx_array_ndim(a.a), C.size_t(len(shape)))
	var r C.mlx_array
	sh := make([]C.int, ndim)
	for i := range ndim {
		sh[i] = (C.int)(shape[i])
		if int(sh[i]) >= int(ndim) {
			slog.Error("Permute error", "tensor", a, "shape", shape)
			panic("invalid permute call")
		}
	}
	if len(sh) > 0 {
		C.mlx_transpose_axes(
			&r,
			a.a,
			&sh[0],
			ndim,
			ctx.(*Context).stream,
		)
	} else {
		C.mlx_transpose(
			&r,
			a.a,
			ctx.(*Context).stream,
		)
	}
	return newArray(ctx.(*Context), r)
}

func (a *Array) Contiguous(ctx ml.Context, allowColMajor bool) ml.Tensor {
	var r C.mlx_array
	C.mlx_contiguous(
		&r,
		a.a,
		(C._Bool)(allowColMajor),
		ctx.(*Context).stream,
	)
	return newArray(ctx.(*Context), r)
}

// Conv2D implements ml.Tensor.
//
// GGML layout:
//
//	input:  [N, IC, IH, IW]
//	weight: [OC, IC, KH, KW]
//	result: [N, OC, OH, OW]
//
// MLX layout (channels last):
//
//	input:  (N, IH, IW, C_in)
//	weight: (C_out, KH, KW, C_in)
//	result: (N, OH, OW, C_out)
func (input *Array) Conv2D(ctx ml.Context, weight ml.Tensor, stride0, stride1, padding0, padding1, dilation0, dilation1, groups int) ml.Tensor {
	var r C.mlx_array
	C.mlx_conv2d(
		&r,
		input.a,
		weight.(*Array).a,
		C.int(stride0),
		C.int(stride1),
		C.int(padding0),
		C.int(padding1),
		C.int(dilation0),
		C.int(dilation1),
		C.int(groups),
		ctx.(*Context).stream,
	)
	return newArray(ctx.(*Context), r)
}
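
// Layout sketch (hypothetical): callers holding GGML-style NCHW tensors can
// permute to MLX's channels-last layout before the convolution and back after:
//
//	x := input.Transpose(ctx, 0, 2, 3, 1)  // [N, IC, IH, IW] -> [N, IH, IW, IC]
//	w := weight.Transpose(ctx, 0, 2, 3, 1) // [OC, IC, KH, KW] -> [OC, KH, KW, IC]
//	y := x.Conv2D(ctx, w, 1, 1, 0, 0, 1, 1, 1)
//	y = y.Transpose(ctx, 0, 3, 1, 2)       // [N, OH, OW, OC] -> [N, OC, OH, OW]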

func (input *Array) Conv3D(ctx ml.Context, weight ml.Tensor, stride0, stride1, stride2, padding0, padding1, padding2, dilation0, dilation1, dilation2, groups int) ml.Tensor {
	var r C.mlx_array
	C.mlx_conv3d(
		&r,
		input.a,
		weight.(*Array).a,
		C.int(stride0),
		C.int(stride1),
		C.int(stride2),
		C.int(padding0),
		C.int(padding1),
		C.int(padding2),
		C.int(dilation0),
		C.int(dilation1),
		C.int(dilation2),
		C.int(groups),
		ctx.(*Context).stream,
	)
	return newArray(ctx.(*Context), r)
}

func (a *Array) ToString() string {
	str := C.mlx_string_new()
	defer C.mlx_string_free(str)
	C.mlx_array_tostring(&str, a.a)
	s := C.mlx_string_data(str)
	return C.GoString(s)
}

func (a *Array) LogValue() slog.Value {
	dims := int(C.mlx_array_ndim(a.a))
	strides := make([]int, dims)
	for i := range strides {
		strides[i] = int(C.stride(a.a, (C.int)(i)))
	}
	return slog.GroupValue(
		slog.String("name", a.name),
		slog.String("type", a.TypeString()),
		slog.Any("shape", a.Shape()),
		slog.Any("strides", strides),
		// slog.String("values", C.GoString(s)),
	)
}

func (a *Array) Shape() []int {
	shape := make([]int, C.mlx_array_ndim(a.a))
	for i := range shape {
		shape[i] = int(C.mlx_array_dim(a.a, C.int(i)))
	}
	return shape
}

func (a *Array) TypeString() string {
	switch C.mlx_array_dtype(a.a) {
	case C.MLX_BOOL:
		return "bool"
	case C.MLX_UINT8:
		return "uint8"
	case C.MLX_UINT16:
		return "uint16"
	case C.MLX_UINT32:
		return "uint32"
	case C.MLX_UINT64:
		return "uint64"
	case C.MLX_INT8:
		return "int8"
	case C.MLX_INT16:
		return "int16"
	case C.MLX_INT32:
		return "int32"
	case C.MLX_INT64:
		return "int64"
	case C.MLX_FLOAT16:
		return "float16"
	case C.MLX_FLOAT32:
		return "float32"
	case C.MLX_BFLOAT16:
		return "bfloat16"
	case C.MLX_COMPLEX64:
		return "complex64"
	default:
		return "unknown"
	}
}