mirror of
https://github.com/ollama/ollama.git
synced 2026-01-12 00:06:57 +08:00
* WIP - MLX backend with gemma3 * MLX: add cmake and go tag build toggles To build the new MLX backend code: cmake --preset MLX cmake --build --preset MLX --parallel cmake --install build --component MLX go build -tags mlx . Note: the main.go entrypoint for the MLX engine will change in a follow up commit. * add experimental image generation runtime * add experimental image generation runtime * MLX: wire up cuda build for linux * MLX: get dependencies correct and dedup This is still too large for a unified github artifact, but is now "correct" for the mlx_cuda_v13 directory. * fix relative link bug in dedup * Add darwin build and readme * add go build tag for mlx dependent code and wire up build_darwin.sh * lint cleanup * macos: build mlx for x86 This will be CPU only. * cuda build instructions and fix drift from mlx bump * stale comment * Delete agent helper doc * Clean up readme.md * Revise README for tokenizer clarity and details Updated README to clarify tokenizer functionality and removed correctness section. --------- Co-authored-by: jmorganca <jmorganca@gmail.com>
174 lines
4.7 KiB
Go
174 lines
4.7 KiB
Go
//go:build mlx
|
|
|
|
package mlx
|
|
|
|
/*
|
|
#include "mlx/c/mlx.h"
|
|
#include <stdlib.h>
|
|
|
|
// Forward declaration for Go callback
|
|
extern int goClosureCallback(mlx_vector_array* res, mlx_vector_array input, void* payload);
|
|
|
|
// Destructor for payload (Go handle)
|
|
extern void goClosureDestructor(void* payload);
|
|
*/
|
|
import "C"
|
|
import (
|
|
"runtime/cgo"
|
|
"sync"
|
|
"unsafe"
|
|
)
|
|
|
|
// Closure-callback state shared by the callback trampoline and its readers.
// inClosureCallback is true while goClosureCallback is executing a Go closure
// body; closureCallbackMu guards every read and write of that flag.
var (
	inClosureCallback bool
	closureCallbackMu sync.Mutex
)
|
|
|
|
// InClosureCallback returns true if we're currently executing inside a closure callback.
|
|
func InClosureCallback() bool {
|
|
closureCallbackMu.Lock()
|
|
defer closureCallbackMu.Unlock()
|
|
return inClosureCallback
|
|
}
|
|
|
|
// CompiledFunc is a compiled MLX function that can be called efficiently.
|
|
// All intermediate arrays during execution stay inside MLX - only inputs
|
|
// and outputs cross the Go boundary.
|
|
type CompiledFunc struct {
|
|
closure C.mlx_closure
|
|
compiled C.mlx_closure
|
|
}
|
|
|
|
// ClosureFunc is the signature for functions that can be compiled.
|
|
// It takes a slice of input arrays and returns a slice of output arrays.
|
|
type ClosureFunc func(inputs []*Array) []*Array
|
|
|
|
// Compile compiles a Go function into an optimized MLX closure.
|
|
// The function is traced once during compilation, then subsequent calls
|
|
// run the optimized graph without creating Go intermediate arrays.
|
|
//
|
|
// Example:
|
|
//
|
|
// compiled := mlx.Compile(func(inputs []*mlx.Array) []*mlx.Array {
|
|
// a, b := inputs[0], inputs[1]
|
|
// c := mlx.Add(a, b)
|
|
// d := mlx.Mul(c, c)
|
|
// return []*mlx.Array{d}
|
|
// })
|
|
// defer compiled.Free()
|
|
//
|
|
// result := compiled.Call(x, y)[0]
|
|
func Compile(fn ClosureFunc) *CompiledFunc {
|
|
return CompileShapeless(fn, false)
|
|
}
|
|
|
|
// CompileShapeless compiles with optional shapeless mode.
|
|
// If shapeless=true, the function works for any input shape after tracing.
|
|
func CompileShapeless(fn ClosureFunc, shapeless bool) *CompiledFunc {
|
|
// Create a cgo.Handle to prevent the Go function from being GC'd
|
|
handle := cgo.NewHandle(fn)
|
|
|
|
// Create the closure from the Go callback
|
|
closure := C.mlx_closure_new_func_payload(
|
|
(*[0]byte)(C.goClosureCallback),
|
|
unsafe.Pointer(handle),
|
|
(*[0]byte)(C.goClosureDestructor),
|
|
)
|
|
|
|
// Compile the closure
|
|
compiled := C.mlx_closure_new()
|
|
C.mlx_compile(&compiled, closure, C.bool(shapeless))
|
|
|
|
return &CompiledFunc{
|
|
closure: closure,
|
|
compiled: compiled,
|
|
}
|
|
}
|
|
|
|
// Call invokes the compiled function with the given inputs.
|
|
func (cf *CompiledFunc) Call(inputs ...*Array) []*Array {
|
|
// Pack inputs into vector
|
|
inputVec := C.mlx_vector_array_new()
|
|
for _, arr := range inputs {
|
|
C.mlx_vector_array_append_value(inputVec, arr.c)
|
|
}
|
|
|
|
// Apply compiled closure
|
|
outputVec := C.mlx_vector_array_new()
|
|
C.mlx_closure_apply(&outputVec, cf.compiled, inputVec)
|
|
C.mlx_vector_array_free(inputVec)
|
|
|
|
// Unpack outputs
|
|
numOutputs := int(C.mlx_vector_array_size(outputVec))
|
|
outputs := make([]*Array, numOutputs)
|
|
for i := 0; i < numOutputs; i++ {
|
|
var arr C.mlx_array
|
|
C.mlx_vector_array_get(&arr, outputVec, C.size_t(i))
|
|
outputs[i] = newArray(arr)
|
|
}
|
|
C.mlx_vector_array_free(outputVec)
|
|
|
|
return outputs
|
|
}
|
|
|
|
// CallEval invokes the compiled function and evaluates the results.
|
|
func (cf *CompiledFunc) CallEval(inputs ...*Array) []*Array {
|
|
outputs := cf.Call(inputs...)
|
|
Eval(outputs...)
|
|
return outputs
|
|
}
|
|
|
|
// Free releases the compiled function resources.
|
|
func (cf *CompiledFunc) Free() {
|
|
C.mlx_closure_free(cf.compiled)
|
|
C.mlx_closure_free(cf.closure)
|
|
}
|
|
|
|
// borrowArray wraps a C array WITHOUT setting up GC cleanup.
|
|
// Use this for arrays we don't own (e.g., borrowed references in callbacks).
|
|
func borrowArray(array C.mlx_array) *Array {
|
|
return &Array{c: array}
|
|
}
|
|
|
|
//export goClosureCallback
|
|
func goClosureCallback(res *C.mlx_vector_array, input C.mlx_vector_array, payload unsafe.Pointer) C.int {
|
|
// Set flag to disable AddCleanup during callback
|
|
closureCallbackMu.Lock()
|
|
inClosureCallback = true
|
|
closureCallbackMu.Unlock()
|
|
defer func() {
|
|
closureCallbackMu.Lock()
|
|
inClosureCallback = false
|
|
closureCallbackMu.Unlock()
|
|
}()
|
|
|
|
// Recover the Go function from the handle
|
|
handle := cgo.Handle(payload)
|
|
fn := handle.Value().(ClosureFunc)
|
|
|
|
// Convert input vector to Go slice - use borrowArray since MLX owns these
|
|
numInputs := int(C.mlx_vector_array_size(input))
|
|
inputs := make([]*Array, numInputs)
|
|
for i := 0; i < numInputs; i++ {
|
|
var arr C.mlx_array
|
|
C.mlx_vector_array_get(&arr, input, C.size_t(i))
|
|
inputs[i] = borrowArray(arr) // Don't set up cleanup - MLX owns these
|
|
}
|
|
|
|
// Call the Go function
|
|
outputs := fn(inputs)
|
|
|
|
// Build output vector
|
|
*res = C.mlx_vector_array_new()
|
|
for _, arr := range outputs {
|
|
C.mlx_vector_array_append_value(*res, arr.c)
|
|
}
|
|
|
|
return 0
|
|
}
|
|
|
|
//export goClosureDestructor
|
|
func goClosureDestructor(payload unsafe.Pointer) {
|
|
handle := cgo.Handle(payload)
|
|
handle.Delete()
|
|
}
|