Files
ollama/x/imagegen/mlx/mlx.go
Daniel Hiltgen 33ee7168ba Add experimental MLX backend and engine with imagegen support (#13648)
* WIP - MLX backend with gemma3

* MLX: add cmake and go tag build toggles

To build the new MLX backend code:
  cmake --preset MLX
  cmake --build --preset MLX --parallel
  cmake --install build --component MLX
  go build -tags mlx .

Note: the main.go entrypoint for the MLX engine will change in a follow up commit.

* add experimental image generation runtime

* add experimental image generation runtime

* MLX: wire up cuda build for linux

* MLX: get dependencies correct and dedup

This is still too large for a unified github artifact, but is now "correct" for the mlx_cuda_v13
directory.

* fix relative link bug in dedup

* Add darwin build and readme

* add go build tag for mlx dependent code and wire up build_darwin.sh

* lint cleanup

* macos: build mlx for x86

This will be CPU only.

* cuda build instructions and fix drift from mlx bump

* stale comment

* Delete agent helper doc

* Clean up readme.md

* Revise README for tokenizer clarity and details

Updated README to clarify tokenizer functionality and removed correctness section.

---------

Co-authored-by: jmorganca <jmorganca@gmail.com>
2026-01-08 16:18:59 -08:00

2239 lines
64 KiB
Go

//go:build mlx
package mlx
/*
#cgo CFLAGS: -O3 -I${SRCDIR}/../../../build/_deps/mlx-c-src
#cgo LDFLAGS: -L${SRCDIR}/../../../build/lib/ollama/ -lmlxc -Wl,-rpath,${SRCDIR}/../../../build/lib/ollama/
#cgo darwin LDFLAGS: -lc++ -framework Metal -framework Foundation -framework Accelerate
#cgo linux LDFLAGS: -lstdc++ -lcuda -lcudart -lnvrtc
#include "mlx/c/mlx.h"
#include <stdlib.h>
#include <stdint.h>
// Cached default GPU stream for all ops
static mlx_stream _default_stream = {0};
static mlx_stream _cpu_stream = {0};
static inline mlx_stream default_stream() {
if (_default_stream.ctx == NULL) {
_default_stream = mlx_default_gpu_stream_new();
}
return _default_stream;
}
static inline void set_default_stream(mlx_stream s) {
_default_stream = s;
}
// CPU stream for file loading (Load primitive only runs on CPU)
static inline mlx_stream cpu_stream() {
if (_cpu_stream.ctx == NULL) {
_cpu_stream = mlx_default_cpu_stream_new();
}
return _cpu_stream;
}
// CGO noescape/nocallback hints to reduce CGO overhead
// noescape: pointers won't escape, no heap allocation needed
// nocallback: function won't call back into Go
#cgo noescape mlx_add
#cgo nocallback mlx_add
#cgo noescape mlx_subtract
#cgo nocallback mlx_subtract
#cgo noescape mlx_multiply
#cgo nocallback mlx_multiply
#cgo noescape mlx_divide
#cgo nocallback mlx_divide
#cgo noescape mlx_negative
#cgo nocallback mlx_negative
#cgo noescape mlx_abs
#cgo nocallback mlx_abs
#cgo noescape mlx_exp
#cgo nocallback mlx_exp
#cgo noescape mlx_log
#cgo nocallback mlx_log
#cgo noescape mlx_sqrt
#cgo nocallback mlx_sqrt
#cgo noescape mlx_rsqrt
#cgo nocallback mlx_rsqrt
#cgo noescape mlx_square
#cgo nocallback mlx_square
#cgo noescape mlx_power
#cgo nocallback mlx_power
#cgo noescape mlx_erf
#cgo nocallback mlx_erf
#cgo noescape mlx_sigmoid
#cgo nocallback mlx_sigmoid
#cgo noescape mlx_tanh
#cgo nocallback mlx_tanh
#cgo noescape mlx_sin
#cgo nocallback mlx_sin
#cgo noescape mlx_cos
#cgo nocallback mlx_cos
#cgo noescape mlx_maximum
#cgo nocallback mlx_maximum
#cgo noescape mlx_minimum
#cgo nocallback mlx_minimum
#cgo noescape mlx_clip
#cgo nocallback mlx_clip
#cgo noescape mlx_sum
#cgo nocallback mlx_sum
#cgo noescape mlx_sum_axis
#cgo nocallback mlx_sum_axis
#cgo noescape mlx_mean
#cgo nocallback mlx_mean
#cgo noescape mlx_mean_axis
#cgo nocallback mlx_mean_axis
#cgo noescape mlx_var_axis
#cgo nocallback mlx_var_axis
#cgo noescape mlx_argmax
#cgo nocallback mlx_argmax
#cgo noescape mlx_argmax_axis
#cgo nocallback mlx_argmax_axis
#cgo noescape mlx_softmax_axis
#cgo nocallback mlx_softmax_axis
#cgo noescape mlx_cumsum
#cgo nocallback mlx_cumsum
#cgo noescape mlx_matmul
#cgo nocallback mlx_matmul
#cgo noescape mlx_addmm
#cgo nocallback mlx_addmm
#cgo noescape mlx_gather_mm
#cgo nocallback mlx_gather_mm
#cgo noescape mlx_gather_qmm
#cgo nocallback mlx_gather_qmm
#cgo noescape mlx_reshape
#cgo nocallback mlx_reshape
#cgo noescape mlx_transpose_axes
#cgo nocallback mlx_transpose_axes
#cgo noescape mlx_expand_dims
#cgo nocallback mlx_expand_dims
#cgo noescape mlx_squeeze_axis
#cgo nocallback mlx_squeeze_axis
#cgo noescape mlx_flatten
#cgo nocallback mlx_flatten
#cgo noescape mlx_concatenate_axis
#cgo nocallback mlx_concatenate_axis
#cgo noescape mlx_slice
#cgo nocallback mlx_slice
#cgo noescape mlx_slice_update
#cgo nocallback mlx_slice_update
#cgo noescape mlx_as_strided
#cgo nocallback mlx_as_strided
#cgo noescape mlx_view
#cgo nocallback mlx_view
#cgo noescape mlx_contiguous
#cgo nocallback mlx_contiguous
#cgo noescape mlx_pad
#cgo nocallback mlx_pad
#cgo noescape mlx_tile
#cgo nocallback mlx_tile
#cgo noescape mlx_take_axis
#cgo nocallback mlx_take_axis
#cgo noescape mlx_take_along_axis
#cgo nocallback mlx_take_along_axis
#cgo noescape mlx_put_along_axis
#cgo nocallback mlx_put_along_axis
#cgo noescape mlx_where
#cgo nocallback mlx_where
#cgo noescape mlx_argsort_axis
#cgo nocallback mlx_argsort_axis
#cgo noescape mlx_argpartition_axis
#cgo nocallback mlx_argpartition_axis
#cgo noescape mlx_topk_axis
#cgo nocallback mlx_topk_axis
#cgo noescape mlx_less
#cgo nocallback mlx_less
#cgo noescape mlx_greater_equal
#cgo nocallback mlx_greater_equal
#cgo noescape mlx_logical_and
#cgo nocallback mlx_logical_and
#cgo noescape mlx_zeros
#cgo nocallback mlx_zeros
#cgo noescape mlx_zeros_like
#cgo nocallback mlx_zeros_like
#cgo noescape mlx_ones
#cgo nocallback mlx_ones
#cgo noescape mlx_full
#cgo nocallback mlx_full
#cgo noescape mlx_arange
#cgo nocallback mlx_arange
#cgo noescape mlx_linspace
#cgo nocallback mlx_linspace
#cgo noescape mlx_tri
#cgo nocallback mlx_tri
#cgo noescape mlx_astype
#cgo nocallback mlx_astype
#cgo noescape mlx_fast_rms_norm
#cgo nocallback mlx_fast_rms_norm
#cgo noescape mlx_fast_rope
#cgo nocallback mlx_fast_rope
#cgo noescape mlx_fast_scaled_dot_product_attention
#cgo nocallback mlx_fast_scaled_dot_product_attention
#cgo noescape mlx_conv2d
#cgo nocallback mlx_conv2d
#cgo noescape mlx_conv3d
#cgo nocallback mlx_conv3d
#cgo noescape mlx_random_key
#cgo nocallback mlx_random_key
#cgo noescape mlx_random_split
#cgo nocallback mlx_random_split
#cgo noescape mlx_random_categorical_num_samples
#cgo nocallback mlx_random_categorical_num_samples
#cgo noescape mlx_random_normal
#cgo nocallback mlx_random_normal
#cgo noescape mlx_random_uniform
#cgo nocallback mlx_random_uniform
#cgo noescape mlx_array_eval
#cgo nocallback mlx_array_eval
#cgo noescape mlx_eval
#cgo nocallback mlx_eval
#cgo noescape mlx_async_eval
#cgo nocallback mlx_async_eval
#cgo noescape mlx_synchronize
#cgo nocallback mlx_synchronize
#cgo noescape mlx_array_new
#cgo nocallback mlx_array_new
#cgo noescape mlx_array_new_data
#cgo nocallback mlx_array_new_data
#cgo noescape mlx_array_new_float
#cgo nocallback mlx_array_new_float
#cgo noescape mlx_array_free
#cgo nocallback mlx_array_free
#cgo noescape mlx_array_size
#cgo nocallback mlx_array_size
#cgo noescape mlx_array_ndim
#cgo nocallback mlx_array_ndim
#cgo noescape mlx_array_dim
#cgo nocallback mlx_array_dim
#cgo noescape mlx_array_dtype
#cgo nocallback mlx_array_dtype
#cgo noescape mlx_array_item_int32
#cgo nocallback mlx_array_item_int32
#cgo noescape mlx_vector_array_new_data
#cgo nocallback mlx_vector_array_new_data
#cgo noescape mlx_vector_array_free
#cgo nocallback mlx_vector_array_free
#cgo noescape mlx_array_new_int
#cgo nocallback mlx_array_new_int
#cgo noescape mlx_stream_new_device
#cgo nocallback mlx_stream_new_device
#cgo noescape mlx_get_default_stream
#cgo nocallback mlx_get_default_stream
#cgo noescape mlx_set_default_stream
#cgo nocallback mlx_set_default_stream
*/
import "C"
import (
	"fmt"
	"reflect"
	"runtime"
	"sort"
	"sync"
	"sync/atomic"
	"time"
	"unsafe"
)
// Dtype represents MLX data types.
// Values mirror the mlx_dtype C enum exactly, so a Dtype can be passed
// straight to C as C.mlx_dtype without translation.
type Dtype int

const (
	DtypeBool      Dtype = C.MLX_BOOL
	DtypeUint8     Dtype = C.MLX_UINT8
	DtypeUint16    Dtype = C.MLX_UINT16
	DtypeUint32    Dtype = C.MLX_UINT32
	DtypeUint64    Dtype = C.MLX_UINT64
	DtypeInt8      Dtype = C.MLX_INT8
	DtypeInt16     Dtype = C.MLX_INT16
	DtypeInt32     Dtype = C.MLX_INT32
	DtypeInt64     Dtype = C.MLX_INT64
	DtypeFloat16   Dtype = C.MLX_FLOAT16
	DtypeFloat32   Dtype = C.MLX_FLOAT32
	DtypeFloat64   Dtype = C.MLX_FLOAT64
	DtypeBFloat16  Dtype = C.MLX_BFLOAT16
	DtypeComplex64 Dtype = C.MLX_COMPLEX64
)
// String implements fmt.Stringer for Dtype, using short numpy-style
// names ("f32", "bf16", ...). Unknown values print as "unknown".
func (d Dtype) String() string {
	names := map[Dtype]string{
		DtypeBool:      "bool",
		DtypeUint8:     "u8",
		DtypeUint16:    "u16",
		DtypeUint32:    "u32",
		DtypeUint64:    "u64",
		DtypeInt8:      "i8",
		DtypeInt16:     "i16",
		DtypeInt32:     "i32",
		DtypeInt64:     "i64",
		DtypeFloat16:   "f16",
		DtypeFloat32:   "f32",
		DtypeFloat64:   "f64",
		DtypeBFloat16:  "bf16",
		DtypeComplex64: "c64",
	}
	if name, ok := names[d]; ok {
		return name
	}
	return "unknown"
}
// Memory Management:
//
// All arrays are automatically tracked for cleanup. On Eval(), non-kept arrays are freed.
//
// x := mlx.Matmul(input, weight) // x is tracked for cleanup
// mlx.Keep(x) // mark x as persistent
// mlx.Eval(x) // eval + free non-kept arrays
//
// Use Keep() for arrays that should persist (weights, caches).
// Use Free() to mark a kept array for cleanup on next Eval().
//
// Note: Not goroutine-safe. Use from a single goroutine.

// Array wraps an MLX array handle.
// Arrays are freed via Eval() cleanup (deterministic) or GC (fallback).
type Array struct {
	c     C.mlx_array // underlying MLX handle; c.ctx == nil means "no live handle"
	freed bool        // Prevents double-free
	kept  bool        // If true, survives Eval() cleanup
}

// arrays tracks all live arrays. On Eval(), non-kept arrays are freed.
// Not goroutine-safe.
var arrays = make([]*Array, 0, 4096)

// evalHandles is a pre-allocated slice for passing arrays to MLX eval.
// Reused across Eval/AsyncEval calls; another reason this package is
// single-goroutine only.
var evalHandles = make([]C.mlx_array, 0, 64)

// arrayPool reduces allocations for intermediate arrays
var arrayPool = sync.Pool{
	New: func() any { return &Array{} },
}
// newArray wraps a C handle in an *Array and registers it in the global
// tracking list so the next Eval() can free it (unless Keep'd).
func newArray(array C.mlx_array) *Array {
	// In compiled closures, MLX manages memory - skip Go tracking.
	// NOTE(review): InClosureCallback is defined elsewhere in this package;
	// untracked arrays from this path are never pooled or cleaned up here.
	if InClosureCallback() {
		return &Array{c: array}
	}
	// Use pooled Array struct for efficiency; reset flags since the pooled
	// value may carry state from a previous life.
	a := arrayPool.Get().(*Array)
	a.c = array
	a.freed = false
	a.kept = false
	// Track in global list
	arrays = append(arrays, a)
	return a
}
// Collect uses reflection to find all *Array fields in a struct (recursively).
// Use this to automatically gather model weights, cache state, etc.
// Returns nil when no live arrays are found.
func Collect(v any) []*Array {
	var found []*Array
	collect(reflect.ValueOf(v), &found, make(map[uintptr]bool))
	return found
}
// collect is the recursive worker for Collect. It walks v, appending every
// reachable non-nil *Array with a live handle to *arrays. seen guards
// against cycles through pointers.
func collect(v reflect.Value, arrays *[]*Array, seen map[uintptr]bool) {
	if !v.IsValid() {
		return
	}
	// Handle pointers
	if v.Kind() == reflect.Ptr {
		if v.IsNil() {
			return
		}
		// Avoid infinite loops
		ptr := v.Pointer()
		if seen[ptr] {
			return
		}
		seen[ptr] = true
		// Check if it's *Array; only arrays with a live C handle are collected.
		if arr, ok := v.Interface().(*Array); ok {
			if arr != nil && arr.c.ctx != nil {
				*arrays = append(*arrays, arr)
			}
			return
		}
		collect(v.Elem(), arrays, seen)
		return
	}
	// Handle structs. Unexported fields (CanInterface == false) are skipped.
	if v.Kind() == reflect.Struct {
		for i := 0; i < v.NumField(); i++ {
			field := v.Field(i)
			if field.CanInterface() {
				collect(field, arrays, seen)
			}
		}
		return
	}
	// Handle slices
	if v.Kind() == reflect.Slice {
		for i := 0; i < v.Len(); i++ {
			collect(v.Index(i), arrays, seen)
		}
		return
	}
	// Handle maps
	if v.Kind() == reflect.Map {
		for _, key := range v.MapKeys() {
			collect(v.MapIndex(key), arrays, seen)
		}
		return
	}
	// Handle interfaces
	if v.Kind() == reflect.Interface {
		if !v.IsNil() {
			collect(v.Elem(), arrays, seen)
		}
		return
	}
}
// FreeStruct releases all *Array fields in a struct (recursively).
// Use this to free model weights when unloading a model.
// Note: Free only marks arrays for cleanup; actual release happens on
// the next Eval().
func FreeStruct(v any) {
	arrs := Collect(v)
	for i := range arrs {
		arrs[i].Free()
	}
}
// Keep marks arrays to persist across Eval() cleanup.
// Kept arrays will NOT be freed when Eval() runs cleanup.
// nil entries are ignored.
func Keep(arrays ...*Array) {
	for _, arr := range arrays {
		if arr == nil {
			continue
		}
		arr.kept = true
	}
}
// cleanup frees non-kept arrays and compacts the live array list.
// Returns number of arrays freed.
//
// Kept arrays are compacted to the front of the global slice; everything
// else with a live handle is freed via the C API and its wrapper struct is
// returned to arrayPool (newArray resets the flags on reuse).
func cleanup() int {
	freed := 0
	n := 0
	for _, a := range arrays {
		if a.kept {
			arrays[n] = a
			n++
		} else if a.c.ctx != nil && !a.freed {
			C.mlx_array_free(a.c)
			// Nil the ctx so Valid() reports false and a later free is a no-op.
			a.c.ctx = nil
			arrayPool.Put(a)
			freed++
		}
	}
	arrays = arrays[:n]
	return freed
}
// DebugArrays prints summary info about all tracked arrays: how many are
// kept vs. unkept, and their combined size in GB.
func DebugArrays() {
	var kept, unkept int
	var total int64
	for _, a := range arrays {
		if a.kept {
			kept++
		} else {
			unkept++
		}
		total += a.Nbytes()
	}
	fmt.Printf("[DEBUG] Arrays: %d kept, %d unkept, %.2f GB total\n",
		kept, unkept, float64(total)/(1<<30))
}
// DebugArraysVerbose prints detailed info about all tracked arrays, sorted by
// size descending. At most topN entries are printed; kept arrays are tagged
// with "[kept]".
func DebugArraysVerbose(topN int) {
	type arrayInfo struct {
		shape []int32
		dtype Dtype
		bytes int64
		kept  bool
	}
	// Pre-size to avoid repeated growth; one entry per tracked array.
	infos := make([]arrayInfo, 0, len(arrays))
	var totalBytes int64
	for _, a := range arrays {
		bytes := a.Nbytes()
		infos = append(infos, arrayInfo{
			shape: a.Shape(),
			dtype: a.Dtype(),
			bytes: bytes,
			kept:  a.kept,
		})
		totalBytes += bytes
	}
	// Sort by size descending (replaces a hand-rolled O(n^2) selection sort).
	sort.Slice(infos, func(i, j int) bool { return infos[i].bytes > infos[j].bytes })
	fmt.Printf("[DEBUG] %d arrays, %.2f GB total:\n", len(infos), float64(totalBytes)/(1024*1024*1024))
	for i, info := range infos {
		if i >= topN {
			break
		}
		keptStr := ""
		if info.kept {
			keptStr = " [kept]"
		}
		fmt.Printf(" %3d. %8.2f MB %v %v%s\n",
			i+1, float64(info.bytes)/(1024*1024), info.shape, info.dtype, keptStr)
	}
}
// Eval synchronously evaluates arrays and cleans up non-kept arrays.
// Outputs are automatically kept (survive cleanup). Returns them for chaining.
//
// Order matters: outputs are marked kept BEFORE cleanup() so cleanup cannot
// free them. The shared evalHandles scratch slice is reused across calls
// (not goroutine-safe).
func Eval(outputs ...*Array) []*Array {
	// Keep outputs so cleanup doesn't free them
	for _, o := range outputs {
		if o != nil {
			o.kept = true
		}
	}
	// Cleanup non-kept arrays
	cleanup()
	// Then evaluate
	if len(outputs) > 0 {
		evalHandles = evalHandles[:0]
		for _, o := range outputs {
			if o != nil {
				evalHandles = append(evalHandles, o.c)
			}
		}
		if len(evalHandles) > 0 {
			// Temporary vector wrapper for the C eval call; freed immediately after.
			vec := C.mlx_vector_array_new_data(&evalHandles[0], C.size_t(len(evalHandles)))
			C.mlx_eval(vec)
			C.mlx_vector_array_free(vec)
		}
	}
	return outputs
}
// AsyncEval dispatches async evaluation and cleans up non-kept arrays.
// Outputs are automatically kept (survive cleanup).
//
// Mirrors Eval() but uses mlx_async_eval; pair with Sync() to wait for
// completion. Shares the evalHandles scratch slice (not goroutine-safe).
func AsyncEval(outputs ...*Array) {
	// Keep outputs so cleanup doesn't free them
	for _, o := range outputs {
		if o != nil {
			o.kept = true
		}
	}
	// Cleanup non-kept arrays
	cleanup()
	// Then dispatch async eval
	if len(outputs) > 0 {
		evalHandles = evalHandles[:0]
		for _, o := range outputs {
			if o != nil {
				evalHandles = append(evalHandles, o.c)
			}
		}
		if len(evalHandles) > 0 {
			vec := C.mlx_vector_array_new_data(&evalHandles[0], C.size_t(len(evalHandles)))
			C.mlx_async_eval(vec)
			C.mlx_vector_array_free(vec)
		}
	}
}
// Sync waits for all async operations on the default GPU stream to complete
// (no cleanup). Use after AsyncEval when results must be materialized.
func Sync() {
	C.mlx_synchronize(C.default_stream())
}
// Free marks this array for cleanup on the next Eval().
// The array is not immediately freed - cleanup happens during Eval().
// Safe to call on a nil receiver.
//
// Pattern for loops:
//
// oldLatents.Free() // mark for cleanup
// mlx.Eval(newLatents) // frees old, evals new
func (a *Array) Free() {
	if a == nil {
		return
	}
	a.kept = false
}
// Eval evaluates this single array and runs cleanup.
// Convenience wrapper around the package-level Eval; returns the receiver
// for chaining.
func (a *Array) Eval() *Array {
	Eval(a)
	return a
}
// Valid returns true if the array hasn't been freed.
// A nil receiver or a nil'd C context (set by cleanup) both report invalid.
func (a *Array) Valid() bool {
	return a != nil && a.c.ctx != nil
}
// int32ToCInt reinterprets a Go []int32 as a *C.int without copying.
// Valid because C int is 32 bits on all supported targets; returns nil for
// an empty slice. The caller must keep s alive across the C call.
func int32ToCInt(s []int32) *C.int {
	if len(s) == 0 {
		return nil
	}
	return (*C.int)(unsafe.Pointer(&s[0]))
}
// NewArray creates a new MLX array from float32 data.
// The data is copied into MLX-owned storage by mlx_array_new_data.
// NOTE(review): &data[0] panics on an empty slice — callers must pass
// at least one element.
func NewArray(data []float32, shape []int32) *Array {
	handle := C.mlx_array_new_data(
		unsafe.Pointer(&data[0]),
		int32ToCInt(shape),
		C.int(len(shape)),
		C.MLX_FLOAT32,
	)
	return newArray(handle)
}

// NewArrayInt32 creates a new MLX array from int32 data.
// Same copy semantics and non-empty requirement as NewArray.
func NewArrayInt32(data []int32, shape []int32) *Array {
	handle := C.mlx_array_new_data(
		unsafe.Pointer(&data[0]),
		int32ToCInt(shape),
		C.int(len(shape)),
		C.MLX_INT32,
	)
	return newArray(handle)
}

// NewArrayFloat32 creates a new float32 array from data.
// Alias for NewArray, kept for API symmetry with NewArrayInt32.
func NewArrayFloat32(data []float32, shape []int32) *Array {
	return NewArray(data, shape)
}
// ============ Array Creation ============
//
// Each constructor allocates a result handle, calls the MLX C op on the
// default GPU stream, and registers the result via newArray (freed on the
// next Eval unless Keep'd).

// Zeros creates an array of zeros with optional dtype (default float32)
func Zeros(shape []int32, dtype ...Dtype) *Array {
	res := C.mlx_array_new()
	dt := DtypeFloat32
	if len(dtype) > 0 {
		dt = dtype[0]
	}
	C.mlx_zeros(&res, int32ToCInt(shape), C.size_t(len(shape)), C.mlx_dtype(dt), C.default_stream())
	return newArray(res)
}

// ZerosLike creates a zeros array with the same dtype as a.
// If shape is provided, uses that shape; otherwise uses a's shape.
func ZerosLike(a *Array, shape ...int32) *Array {
	res := C.mlx_array_new()
	if len(shape) == 0 {
		C.mlx_zeros_like(&res, a.c, C.default_stream())
	} else {
		dtype := a.Dtype()
		C.mlx_zeros(&res, int32ToCInt(shape), C.size_t(len(shape)), C.mlx_dtype(dtype), C.default_stream())
	}
	return newArray(res)
}

// Ones creates an array of ones (always float32)
func Ones(shape ...int32) *Array {
	res := C.mlx_array_new()
	C.mlx_ones(&res, int32ToCInt(shape), C.size_t(len(shape)), C.MLX_FLOAT32, C.default_stream())
	return newArray(res)
}

// Full creates a float32 array filled with a value
func Full(value float32, shape ...int32) *Array {
	// Scalar fill value is a temporary MLX array, freed after the call.
	vals := C.mlx_array_new_float(C.float(value))
	res := C.mlx_array_new()
	C.mlx_full(&res, int32ToCInt(shape), C.size_t(len(shape)), vals, C.MLX_FLOAT32, C.default_stream())
	C.mlx_array_free(vals)
	return newArray(res)
}

// Arange creates a float32 range of values [start, stop) with the given step
func Arange(start, stop, step float32) *Array {
	res := C.mlx_array_new()
	C.mlx_arange(&res, C.double(start), C.double(stop), C.double(step), C.MLX_FLOAT32, C.default_stream())
	return newArray(res)
}

// Linspace creates `steps` evenly spaced float32 values from start to stop
func Linspace(start, stop float32, steps int32) *Array {
	res := C.mlx_array_new()
	C.mlx_linspace(&res, C.double(start), C.double(stop), C.int(steps), C.MLX_FLOAT32, C.default_stream())
	return newArray(res)
}
// ============ Math Operations ============
//
// Binary-op wrappers: allocate a result handle, invoke the MLX C op on the
// default GPU stream, track the result via newArray. Broadcasting semantics
// are whatever the underlying MLX op provides.

// Add adds two arrays element-wise
func Add(a, b *Array) *Array {
	res := C.mlx_array_new()
	C.mlx_add(&res, a.c, b.c, C.default_stream())
	return newArray(res)
}

// AddRaw is like Add - kept for API compatibility (now identical to Add)
func AddRaw(a, b *Array) *Array {
	return Add(a, b)
}

// Sub subtracts two arrays element-wise
func Sub(a, b *Array) *Array {
	res := C.mlx_array_new()
	C.mlx_subtract(&res, a.c, b.c, C.default_stream())
	return newArray(res)
}

// Mul multiplies two arrays element-wise
func Mul(a, b *Array) *Array {
	res := C.mlx_array_new()
	C.mlx_multiply(&res, a.c, b.c, C.default_stream())
	return newArray(res)
}

// Div divides two arrays element-wise
func Div(a, b *Array) *Array {
	res := C.mlx_array_new()
	C.mlx_divide(&res, a.c, b.c, C.default_stream())
	return newArray(res)
}

// Matmul performs matrix multiplication
func Matmul(a, b *Array) *Array {
	res := C.mlx_array_new()
	C.mlx_matmul(&res, a.c, b.c, C.default_stream())
	return newArray(res)
}

// AddMM computes: result = beta*c + alpha*(a @ b)
// This fuses bias addition with matmul into a single op.
func AddMM(c, a, b *Array, alpha, beta float32) *Array {
	res := C.mlx_array_new()
	C.mlx_addmm(&res, c.c, a.c, b.c, C.float(alpha), C.float(beta), C.default_stream())
	return newArray(res)
}

// Linear performs matrix multiplication: a @ weight
func Linear(a, weight *Array) *Array {
	return Matmul(a, weight)
}
// Element-wise unary and binary math wrappers. All follow the same pattern:
// new handle, MLX C op on the default stream, tracked result.

// Sqrt computes element-wise square root
func Sqrt(a *Array) *Array {
	res := C.mlx_array_new()
	C.mlx_sqrt(&res, a.c, C.default_stream())
	return newArray(res)
}

// RSqrt computes element-wise reciprocal square root
func RSqrt(a *Array) *Array {
	res := C.mlx_array_new()
	C.mlx_rsqrt(&res, a.c, C.default_stream())
	return newArray(res)
}

// Erf computes element-wise error function
func Erf(a *Array) *Array {
	res := C.mlx_array_new()
	C.mlx_erf(&res, a.c, C.default_stream())
	return newArray(res)
}

// Exp computes element-wise exponential
func Exp(a *Array) *Array {
	res := C.mlx_array_new()
	C.mlx_exp(&res, a.c, C.default_stream())
	return newArray(res)
}

// Log computes element-wise natural logarithm
func Log(a *Array) *Array {
	res := C.mlx_array_new()
	C.mlx_log(&res, a.c, C.default_stream())
	return newArray(res)
}

// Sin computes element-wise sine
func Sin(a *Array) *Array {
	res := C.mlx_array_new()
	C.mlx_sin(&res, a.c, C.default_stream())
	return newArray(res)
}

// Cos computes element-wise cosine
func Cos(a *Array) *Array {
	res := C.mlx_array_new()
	C.mlx_cos(&res, a.c, C.default_stream())
	return newArray(res)
}

// Neg negates the array
func Neg(a *Array) *Array {
	res := C.mlx_array_new()
	C.mlx_negative(&res, a.c, C.default_stream())
	return newArray(res)
}

// Abs computes element-wise absolute value
func Abs(a *Array) *Array {
	res := C.mlx_array_new()
	C.mlx_abs(&res, a.c, C.default_stream())
	return newArray(res)
}

// Square computes element-wise square
func Square(a *Array) *Array {
	res := C.mlx_array_new()
	C.mlx_square(&res, a.c, C.default_stream())
	return newArray(res)
}

// Pow raises a to the power of b element-wise
func Pow(a, b *Array) *Array {
	res := C.mlx_array_new()
	C.mlx_power(&res, a.c, b.c, C.default_stream())
	return newArray(res)
}

// Max computes element-wise maximum
func Max(a, b *Array) *Array {
	res := C.mlx_array_new()
	C.mlx_maximum(&res, a.c, b.c, C.default_stream())
	return newArray(res)
}

// Min computes element-wise minimum
func Min(a, b *Array) *Array {
	res := C.mlx_array_new()
	C.mlx_minimum(&res, a.c, b.c, C.default_stream())
	return newArray(res)
}
// scalarWithDtype creates a scalar array matching the dtype of a (critical for graph fusion!)
//
// Ownership: the returned handle is NOT tracked by newArray; the caller is
// responsible for freeing it with mlx_array_free after use (all callers in
// this file do so).
func scalarWithDtype(s float32, a *Array) C.mlx_array {
	// Create float32 scalar, then cast to match input dtype
	f32 := C.mlx_array_new_float(C.float(s))
	dtype := a.Dtype()
	if dtype == DtypeFloat32 {
		return f32 // No cast needed
	}
	// Cast to match input dtype; the intermediate f32 scalar is freed here.
	casted := C.mlx_array_new()
	C.mlx_astype(&casted, f32, C.mlx_dtype(dtype), C.default_stream())
	C.mlx_array_free(f32)
	return casted
}
// Scalar-operand wrappers. Each builds a temporary scalar array (dtype-matched
// where noted), applies the binary op, then frees the temporary.

// AddScalar adds a scalar to an array (matches dtype for graph fusion)
func AddScalar(a *Array, s float32) *Array {
	scalar := scalarWithDtype(s, a)
	res := C.mlx_array_new()
	C.mlx_add(&res, a.c, scalar, C.default_stream())
	C.mlx_array_free(scalar)
	return newArray(res)
}

// MulScalar multiplies an array by a scalar (matches dtype for graph fusion)
func MulScalar(a *Array, s float32) *Array {
	scalar := scalarWithDtype(s, a)
	res := C.mlx_array_new()
	C.mlx_multiply(&res, a.c, scalar, C.default_stream())
	C.mlx_array_free(scalar)
	return newArray(res)
}

// DivScalar divides an array by a scalar (matches dtype for graph fusion)
func DivScalar(a *Array, s float32) *Array {
	scalar := scalarWithDtype(s, a)
	res := C.mlx_array_new()
	C.mlx_divide(&res, a.c, scalar, C.default_stream())
	C.mlx_array_free(scalar)
	return newArray(res)
}

// DivScalarInt divides an int array by an int scalar (regular division, may return float)
func DivScalarInt(a *Array, s int32) *Array {
	scalar := C.mlx_array_new_int(C.int(s))
	res := C.mlx_array_new()
	C.mlx_divide(&res, a.c, scalar, C.default_stream())
	C.mlx_array_free(scalar)
	return newArray(res)
}

// FloorDivideScalar performs integer floor division (a // s), preserving int dtype
func FloorDivideScalar(a *Array, s int32) *Array {
	scalar := C.mlx_array_new_int(C.int(s))
	res := C.mlx_array_new()
	C.mlx_floor_divide(&res, a.c, scalar, C.default_stream())
	C.mlx_array_free(scalar)
	return newArray(res)
}
// ============ Reduction Operations ============

// Sum reduces along an axis
func Sum(a *Array, axis int, keepdims bool) *Array {
	res := C.mlx_array_new()
	C.mlx_sum_axis(&res, a.c, C.int(axis), C._Bool(keepdims), C.default_stream())
	return newArray(res)
}

// SumAll reduces the entire array to a scalar (keepdims=false)
func SumAll(a *Array) *Array {
	res := C.mlx_array_new()
	C.mlx_sum(&res, a.c, false, C.default_stream())
	return newArray(res)
}

// Mean reduces along an axis
func Mean(a *Array, axis int, keepdims bool) *Array {
	res := C.mlx_array_new()
	C.mlx_mean_axis(&res, a.c, C.int(axis), C._Bool(keepdims), C.default_stream())
	return newArray(res)
}

// MeanAll reduces the entire array to a scalar (keepdims=false)
func MeanAll(a *Array) *Array {
	res := C.mlx_array_new()
	C.mlx_mean(&res, a.c, false, C.default_stream())
	return newArray(res)
}

// Var computes variance along an axis (ddof=0, i.e. population variance)
func Var(a *Array, axis int, keepdims bool) *Array {
	res := C.mlx_array_new()
	C.mlx_var_axis(&res, a.c, C.int(axis), C._Bool(keepdims), 0, C.default_stream())
	return newArray(res)
}

// Argmax returns indices of maximum values along an axis
func Argmax(a *Array, axis int, keepdims bool) *Array {
	res := C.mlx_array_new()
	C.mlx_argmax_axis(&res, a.c, C.int(axis), C._Bool(keepdims), C.default_stream())
	return newArray(res)
}
// ArgmaxAll returns the index of the maximum element (flattened).
// Triggers cleanup of non-kept arrays.
//
// Bug fix: cleanup() frees every non-kept tracked array, which previously
// included the input a itself when it was not kept — the subsequent C calls
// then operated on a freed handle (use-after-free). The input is now pinned
// across cleanup and its kept flag restored afterward, so external behavior
// (a still eligible for cleanup on the next Eval) is unchanged.
func ArgmaxAll(a *Array) int32 {
	wasKept := a.kept
	a.kept = true
	cleanup()
	a.kept = wasKept
	// Flatten, then argmax with keepdims=false
	flat := C.mlx_array_new()
	C.mlx_flatten(&flat, a.c, 0, -1, C.default_stream())
	res := C.mlx_array_new()
	C.mlx_argmax(&res, flat, false, C.default_stream())
	// Force evaluation so the scalar value can be read out.
	C.mlx_array_eval(res)
	var val C.int32_t
	C.mlx_array_item_int32(&val, res)
	// Temporaries are freed directly; they were never registered with newArray.
	C.mlx_array_free(flat)
	C.mlx_array_free(res)
	return int32(val)
}
// Shape-manipulation wrappers.

// Reshape reshapes the array
func Reshape(a *Array, shape ...int32) *Array {
	res := C.mlx_array_new()
	C.mlx_reshape(&res, a.c, int32ToCInt(shape), C.size_t(len(shape)), C.default_stream())
	return newArray(res)
}

// Transpose permutes the dimensions.
// NOTE(review): &cAxes[0] panics when called with no axes — at least one
// axis must be supplied.
func Transpose(a *Array, axes ...int) *Array {
	cAxes := make([]C.int, len(axes))
	for i, ax := range axes {
		cAxes[i] = C.int(ax)
	}
	res := C.mlx_array_new()
	C.mlx_transpose_axes(&res, a.c, &cAxes[0], C.size_t(len(axes)), C.default_stream())
	return newArray(res)
}

// AsStrided creates a view with custom strides. Useful for fusing reshape+transpose.
// offset is in elements, not bytes.
func AsStrided(a *Array, shape []int32, strides []int64, offset int64) *Array {
	cShape := make([]C.int, len(shape))
	for i, s := range shape {
		cShape[i] = C.int(s)
	}
	cStrides := make([]C.int64_t, len(strides))
	for i, s := range strides {
		cStrides[i] = C.int64_t(s)
	}
	res := C.mlx_array_new()
	C.mlx_as_strided(&res, a.c, &cShape[0], C.size_t(len(shape)), &cStrides[0], C.size_t(len(strides)), C.size_t(offset), C.default_stream())
	return newArray(res)
}

// ExpandDims adds a dimension at the specified axis
func ExpandDims(a *Array, axis int) *Array {
	res := C.mlx_array_new()
	C.mlx_expand_dims(&res, a.c, C.int(axis), C.default_stream())
	return newArray(res)
}

// Squeeze removes a dimension at the specified axis
func Squeeze(a *Array, axis int) *Array {
	res := C.mlx_array_new()
	C.mlx_squeeze_axis(&res, a.c, C.int(axis), C.default_stream())
	return newArray(res)
}

// Flatten flattens the array to 1D (all axes, 0 through -1)
func Flatten(a *Array) *Array {
	res := C.mlx_array_new()
	C.mlx_flatten(&res, a.c, 0, -1, C.default_stream())
	return newArray(res)
}

// FlattenRange flattens consecutive axes from start_axis to end_axis (intermediates)
func FlattenRange(a *Array, startAxis, endAxis int) *Array {
	res := C.mlx_array_new()
	C.mlx_flatten(&res, a.c, C.int(startAxis), C.int(endAxis), C.default_stream())
	return newArray(res)
}

// View reinterprets the array with a new dtype (no data copy)
func View(a *Array, dtype int) *Array {
	res := C.mlx_array_new()
	C.mlx_view(&res, a.c, C.mlx_dtype(dtype), C.default_stream())
	return newArray(res)
}

// Contiguous returns a contiguous copy of the array
func Contiguous(a *Array) *Array {
	res := C.mlx_array_new()
	C.mlx_contiguous(&res, a.c, true, C.default_stream())
	return newArray(res)
}
// Clip clips values to [min, max]. Pass nil for no bound on that side.
// A nil bound is forwarded to C as a zero-value mlx_array handle —
// presumably MLX treats a null handle as "unbounded"; verify against mlx-c.
func Clip(a *Array, aMin, aMax *Array) *Array {
	res := C.mlx_array_new()
	var minH, maxH C.mlx_array
	if aMin != nil {
		minH = aMin.c
	}
	if aMax != nil {
		maxH = aMax.c
	}
	C.mlx_clip(&res, a.c, minH, maxH, C.default_stream())
	return newArray(res)
}

// ClipScalar clips array values using scalar bounds (matches dtype for graph fusion)
// Pass math.NaN() or set hasMin/hasMax to false for unbounded
func ClipScalar(a *Array, minVal, maxVal float32, hasMin, hasMax bool) *Array {
	var minArr, maxArr C.mlx_array
	if hasMin {
		minArr = scalarWithDtype(minVal, a)
	}
	if hasMax {
		maxArr = scalarWithDtype(maxVal, a)
	}
	res := C.mlx_array_new()
	C.mlx_clip(&res, a.c, minArr, maxArr, C.default_stream())
	// Free the temporary dtype-matched scalar bounds.
	if hasMin {
		C.mlx_array_free(minArr)
	}
	if hasMax {
		C.mlx_array_free(maxArr)
	}
	return newArray(res)
}
// Comparison and equality wrappers. Note that the "boolean" results are
// returned as MLX arrays (bool dtype), not Go bools.

// GreaterEqual returns element-wise a >= b
func GreaterEqual(a, b *Array) *Array {
	res := C.mlx_array_new()
	C.mlx_greater_equal(&res, a.c, b.c, C.default_stream())
	return newArray(res)
}

// LessArray returns element-wise a < b
func LessArray(a, b *Array) *Array {
	res := C.mlx_array_new()
	C.mlx_less(&res, a.c, b.c, C.default_stream())
	return newArray(res)
}

// LogicalAnd returns element-wise a && b
func LogicalAnd(a, b *Array) *Array {
	res := C.mlx_array_new()
	C.mlx_logical_and(&res, a.c, b.c, C.default_stream())
	return newArray(res)
}

// AllClose returns a scalar bool array that is true if all elements of a and
// b are within tolerance. Uses rtol (relative tolerance) and atol (absolute
// tolerance):
// |a - b| <= atol + rtol * |b|
func AllClose(a, b *Array, rtol, atol float64) *Array {
	res := C.mlx_array_new()
	C.mlx_allclose(&res, a.c, b.c, C.double(rtol), C.double(atol), C.bool(false), C.default_stream())
	return newArray(res)
}

// AllCloseEqualNaN is like AllClose but treats NaN as equal to NaN.
func AllCloseEqualNaN(a, b *Array, rtol, atol float64) *Array {
	res := C.mlx_array_new()
	C.mlx_allclose(&res, a.c, b.c, C.double(rtol), C.double(atol), C.bool(true), C.default_stream())
	return newArray(res)
}

// ArrayEqual returns a scalar bool array that is true if the arrays have the
// same shape and all elements are equal.
func ArrayEqual(a, b *Array) *Array {
	res := C.mlx_array_new()
	C.mlx_array_equal(&res, a.c, b.c, C.bool(false), C.default_stream())
	return newArray(res)
}

// ArrayEqualNaN is like ArrayEqual but treats NaN as equal to NaN.
func ArrayEqualNaN(a, b *Array) *Array {
	res := C.mlx_array_new()
	C.mlx_array_equal(&res, a.c, b.c, C.bool(true), C.default_stream())
	return newArray(res)
}

// IsClose returns element-wise bool array indicating if values are within tolerance.
// |a - b| <= atol + rtol * |b|
func IsClose(a, b *Array, rtol, atol float64) *Array {
	res := C.mlx_array_new()
	C.mlx_isclose(&res, a.c, b.c, C.double(rtol), C.double(atol), C.bool(false), C.default_stream())
	return newArray(res)
}

// IsCloseEqualNaN is like IsClose but treats NaN as equal to NaN.
func IsCloseEqualNaN(a, b *Array, rtol, atol float64) *Array {
	res := C.mlx_array_new()
	C.mlx_isclose(&res, a.c, b.c, C.double(rtol), C.double(atol), C.bool(true), C.default_stream())
	return newArray(res)
}

// ReduceMax reduces array to max value over all dimensions (keepdims=false).
func ReduceMax(a *Array) *Array {
	res := C.mlx_array_new()
	C.mlx_max(&res, a.c, C.bool(false), C.default_stream())
	return newArray(res)
}

// ArangeInt creates an array with values from start to stop with step and specified dtype
func ArangeInt(start, stop, step int32, dtype Dtype) *Array {
	res := C.mlx_array_new()
	C.mlx_arange(&res, C.double(start), C.double(stop), C.double(step), C.mlx_dtype(dtype), C.default_stream())
	return newArray(res)
}
// Concatenate concatenates arrays along an axis.
// NOTE(review): &handles[0] panics on an empty arrays slice — callers must
// pass at least one array.
func Concatenate(arrays []*Array, axis int) *Array {
	handles := make([]C.mlx_array, len(arrays))
	for i, arr := range arrays {
		handles[i] = arr.c
	}
	// Temporary vector wrapper for the C call; freed immediately after.
	vec := C.mlx_vector_array_new_data(&handles[0], C.size_t(len(handles)))
	res := C.mlx_array_new()
	C.mlx_concatenate_axis(&res, vec, C.int(axis), C.default_stream())
	C.mlx_vector_array_free(vec)
	return newArray(res)
}

// Concat is a convenience function to concatenate two arrays
func Concat(a, b *Array, axis int) *Array {
	return Concatenate([]*Array{a, b}, axis)
}
// Slice slices the array
func Slice(a *Array, start, stop []int32) *Array {
n := len(start)
cStart := make([]C.int, n)
cStop := make([]C.int, n)
cStrides := make([]C.int, n)
for i := 0; i < n; i++ {
cStart[i] = C.int(start[i])
cStop[i] = C.int(stop[i])
cStrides[i] = 1 // Default stride of 1
}
res := C.mlx_array_new()
C.mlx_slice(&res, a.c, &cStart[0], C.size_t(n), &cStop[0], C.size_t(n), &cStrides[0], C.size_t(n), C.default_stream())
return newArray(res)
}
// SliceStride slices with start:stop:stride like Python a[start:stop:stride].
// start, stop and strides must all have the same length (one entry per
// dimension). Previously a longer start indexed cStop/cStrides out of
// range, and a shorter start silently passed inconsistent lengths to
// the C API; both are now rejected up front.
func SliceStride(a *Array, start, stop, strides []int32) *Array {
	n := len(start)
	if len(stop) != n || len(strides) != n {
		panic("mlx: SliceStride requires start, stop and strides of equal length")
	}
	cStart := make([]C.int, n)
	cStop := make([]C.int, n)
	cStrides := make([]C.int, n)
	for i := 0; i < n; i++ {
		cStart[i] = C.int(start[i])
		cStop[i] = C.int(stop[i])
		cStrides[i] = C.int(strides[i])
	}
	res := C.mlx_array_new()
	C.mlx_slice(&res, a.c, &cStart[0], C.size_t(n), &cStop[0], C.size_t(n), &cStrides[0], C.size_t(n), C.default_stream())
	return newArray(res)
}
// Tile repeats the array along each dimension
// reps has one repeat count per dimension.
func Tile(a *Array, reps []int32) *Array {
	res := C.mlx_array_new()
	C.mlx_tile(&res, a.c, int32ToCInt(reps), C.size_t(len(reps)), C.default_stream())
	return newArray(res)
}
// BroadcastTo broadcasts an array to a given shape
func BroadcastTo(a *Array, shape []int32) *Array {
	res := C.mlx_array_new()
	C.mlx_broadcast_to(&res, a.c, int32ToCInt(shape), C.size_t(len(shape)), C.default_stream())
	return newArray(res)
}
// ============ Neural Network Operations ============
// Softmax computes softmax along an axis
func Softmax(a *Array, axis int) *Array {
	res := C.mlx_array_new()
	// false: presumably the "precise" flag of mlx_softmax_axis — TODO confirm.
	C.mlx_softmax_axis(&res, a.c, C.int(axis), false, C.default_stream())
	return newArray(res)
}
// Take gathers elements along an axis using indices
func Take(a *Array, indices *Array, axis int) *Array {
	res := C.mlx_array_new()
	C.mlx_take_axis(&res, a.c, indices.c, C.int(axis), C.default_stream())
	return newArray(res)
}
// Argsort returns indices that would sort the array along an axis
func Argsort(a *Array, axis int) *Array {
	res := C.mlx_array_new()
	C.mlx_argsort_axis(&res, a.c, C.int(axis), C.default_stream())
	return newArray(res)
}
// Sigmoid computes element-wise sigmoid
func Sigmoid(a *Array) *Array {
	res := C.mlx_array_new()
	C.mlx_sigmoid(&res, a.c, C.default_stream())
	return newArray(res)
}
// ReLU computes element-wise ReLU: max(0, x)
func ReLU(a *Array) *Array {
	// ReLU = maximum(x, 0) - mlx-c doesn't have mlx_relu, but we can use maximum
	zero := C.mlx_array_new_float(0.0)
	res := C.mlx_array_new()
	C.mlx_maximum(&res, a.c, zero, C.default_stream())
	// zero was created directly via the C API (not via newArray), so it is
	// not tracked by the Go-side cleanup and must be freed manually.
	C.mlx_array_free(zero)
	return newArray(res)
}
// SiLU computes element-wise SiLU (Swish): x * sigmoid(x)
func SiLU(a *Array) *Array {
	// SiLU = x * sigmoid(x)
	sig := C.mlx_array_new()
	C.mlx_sigmoid(&sig, a.c, C.default_stream())
	res := C.mlx_array_new()
	C.mlx_multiply(&res, a.c, sig, C.default_stream())
	// sig is an untracked C-side intermediate; free it manually.
	C.mlx_array_free(sig)
	return newArray(res)
}
// GELU computes element-wise GELU (Gaussian Error Linear Unit)
// GELU(x) = x * 0.5 * (1 + erf(x / sqrt(2)))
func GELU(a *Array) *Array {
	sqrt2 := C.mlx_array_new_float(1.4142135623730951)
	scaled := C.mlx_array_new()
	C.mlx_divide(&scaled, a.c, sqrt2, C.default_stream())
	erfd := C.mlx_array_new()
	C.mlx_erf(&erfd, scaled, C.default_stream())
	one := C.mlx_array_new_float(1.0)
	erfdPlusOne := C.mlx_array_new()
	C.mlx_add(&erfdPlusOne, erfd, one, C.default_stream())
	half := C.mlx_array_new_float(0.5)
	halfErfdPlusOne := C.mlx_array_new()
	C.mlx_multiply(&halfErfdPlusOne, half, erfdPlusOne, C.default_stream())
	res := C.mlx_array_new()
	C.mlx_multiply(&res, a.c, halfErfdPlusOne, C.default_stream())
	// All intermediates above were created directly via the C API and are
	// not tracked by newArray; free each before returning the tracked result.
	C.mlx_array_free(sqrt2)
	C.mlx_array_free(scaled)
	C.mlx_array_free(erfd)
	C.mlx_array_free(one)
	C.mlx_array_free(erfdPlusOne)
	C.mlx_array_free(half)
	C.mlx_array_free(halfErfdPlusOne)
	return newArray(res)
}
// Tanh computes element-wise tanh
func Tanh(a *Array) *Array {
	res := C.mlx_array_new()
	C.mlx_tanh(&res, a.c, C.default_stream())
	return newArray(res)
}
// RMSNorm computes RMS normalization using mlx.fast
func RMSNorm(x, weight *Array, eps float32) *Array {
	res := C.mlx_array_new()
	C.mlx_fast_rms_norm(&res, x.c, weight.c, C.float(eps), C.default_stream())
	return newArray(res)
}
// RMSNormNoWeight applies RMS normalization without a weight:
// x * rsqrt(mean(x^2) + eps).
// A weight vector of ones makes mlx_fast_rms_norm behave as the
// unweighted normalization while keeping its f32 accumulation precision.
func RMSNormNoWeight(x *Array, eps float32) *Array {
	shape := x.Shape()
	featDim := shape[len(shape)-1]
	unitWeight := AsType(Full(1.0, featDim), x.Dtype())
	return RMSNorm(x, unitWeight, eps)
}
// RoPE applies rotary position embeddings using mlx.fast
func RoPE(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
	res := C.mlx_array_new()
	optBase := C.mlx_optional_float{value: C.float(base), has_value: true}
	// C.mlx_array{} passes a zero-valued handle: no custom frequencies.
	C.mlx_fast_rope(&res, x.c, C.int(dims), C._Bool(traditional), optBase, C.float(scale), C.int(offset), C.mlx_array{}, C.default_stream())
	return newArray(res)
}
// RoPEWithFreqs applies rotary position embeddings with custom frequencies (for YaRN)
// freqs is required - use RoPE() if you don't have custom frequencies
func RoPEWithFreqs(x, freqs *Array, dims int, traditional bool, scale float32, offset int) *Array {
	res := C.mlx_array_new()
	optBase := C.mlx_optional_float{has_value: false} // No base when using freqs
	C.mlx_fast_rope(&res, x.c, C.int(dims), C._Bool(traditional), optBase, C.float(scale), C.int(offset), freqs.c, C.default_stream())
	return newArray(res)
}
// ============ Indexing ============
// EmbeddingLookup performs embedding lookup (gathers from table)
// table: [vocab_size, hidden_size], indices: [batch, seq_len]
// returns: [batch, seq_len, hidden_size]
func EmbeddingLookup(table, indices *Array) *Array {
	return Take(table, indices, 0)
}
// Gather gathers elements using indices - simplified to use take axis 0
func Gather(a, indices *Array) *Array {
	return Take(a, indices, 0)
}
// ============ Array Properties ============
// Ndim returns the number of dimensions
func (a *Array) Ndim() int {
	return int(C.mlx_array_ndim(a.c))
}
// Size returns the total number of elements
func (a *Array) Size() int {
	return int(C.mlx_array_size(a.c))
}
// IsContiguous returns whether the array's data is contiguous in memory.
// Non-contiguous arrays (e.g., from SliceStride) must call Contiguous() before Data().
// _mlx_array_is_contiguous is presumably a helper declared in the cgo preamble.
func (a *Array) IsContiguous() bool {
	var res C.bool
	C._mlx_array_is_contiguous(&res, a.c)
	return bool(res)
}
// Dim returns the size of a dimension
func (a *Array) Dim(axis int) int32 {
	return int32(C.mlx_array_dim(a.c, C.int(axis)))
}
// Shape returns the shape as a slice
func (a *Array) Shape() []int32 {
	ndim := a.Ndim()
	shape := make([]int32, ndim)
	for i := 0; i < ndim; i++ {
		shape[i] = a.Dim(i)
	}
	return shape
}
// IsValid returns true if the array hasn't been freed
func (a *Array) IsValid() bool {
	return a != nil && a.c.ctx != nil
}
// Dtype returns the data type
func (a *Array) Dtype() Dtype {
	return Dtype(C.mlx_array_dtype(a.c))
}
// Nbytes returns the total size in bytes
func (a *Array) Nbytes() int64 {
	return int64(a.Size()) * a.Dtype().ItemSize()
}
// ItemSize returns the size in bytes of one element for this dtype.
// Unrecognized dtypes fall back to 4 bytes.
func (d Dtype) ItemSize() int64 {
	switch d {
	case DtypeUint64, DtypeInt64, DtypeFloat64, DtypeComplex64:
		return 8
	case DtypeUint32, DtypeInt32, DtypeFloat32:
		return 4
	case DtypeUint16, DtypeInt16, DtypeFloat16, DtypeBFloat16:
		return 2
	case DtypeBool, DtypeUint8, DtypeInt8:
		return 1
	}
	// Unknown dtype: assume the most common element size.
	return 4
}
// ============ Data Access ============
// Data copies the float32 data out of the array.
// Note: For non-contiguous arrays (e.g., from SliceStride), call Contiguous() first.
// Note: Arrays of other dtypes (bf16, f16, etc) are automatically converted to float32.
// Note: Triggers cleanup of non-kept arrays.
func (a *Array) Data() []float32 {
	cleanup()
	size := a.Size()
	if size == 0 {
		return nil
	}
	arr := a
	if a.Dtype() != DtypeFloat32 {
		// Cast to f32 and force evaluation so the host pointer below is valid.
		arr = AsType(a, DtypeFloat32)
		arr.Eval()
		// Cast array will be cleaned up on next Eval
	}
	ptr := C.mlx_array_data_float32(arr.c)
	if ptr == nil {
		return nil
	}
	// Copy out of MLX-owned memory so the returned slice stays valid
	// after the array is freed.
	data := make([]float32, size)
	copy(data, unsafe.Slice((*float32)(unsafe.Pointer(ptr)), size))
	return data
}
// Item returns the scalar value from a 0-dimensional array.
// Converts to float32 if necessary. Triggers cleanup.
// Returns 0 when the array has no readable data.
func (a *Array) Item() float32 {
	data := a.Data() // Data() calls cleanup()
	if len(data) == 0 {
		return 0
	}
	return data[0]
}
// DataInt32 copies the int32 data out of the array.
// Note: For non-contiguous arrays (e.g., from SliceStride), call Contiguous() first.
// Note: Triggers cleanup of non-kept arrays.
// Unlike Data, no dtype conversion is performed here; presumably the
// array must already be int32 (nil is returned otherwise) — TODO confirm.
func (a *Array) DataInt32() []int32 {
	cleanup()
	size := a.Size()
	if size == 0 {
		return nil
	}
	ptr := C.mlx_array_data_int32(a.c)
	if ptr == nil {
		return nil
	}
	// Copy out of MLX-owned memory so the slice outlives the array.
	data := make([]int32, size)
	copy(data, unsafe.Slice((*int32)(unsafe.Pointer(ptr)), size))
	return data
}
// ItemInt32 gets a single scalar value efficiently (no array copy).
// Note: Triggers cleanup of non-kept arrays.
func (a *Array) ItemInt32() int32 {
	cleanup()
	var val C.int32_t
	C.mlx_array_item_int32(&val, a.c)
	return int32(val)
}
// ============ Utility ============
// String returns a compact, human-readable description of the array.
// Small arrays (<= 20 elements) include their values; larger ones only
// report shape and element count.
func (a *Array) String() string {
	const maxInline = 20
	shape := a.Shape()
	n := a.Size()
	if n > maxInline {
		return fmt.Sprintf("Array(shape=%v, size=%d)", shape, n)
	}
	return fmt.Sprintf("Array(shape=%v, data=%v)", shape, a.Data())
}
// ============ Safetensors Support ============
// NewArrayFromBytes creates an array from raw bytes (for safetensors).
// data must contain exactly the bytes implied by shape and dtype.
// Empty data (legitimate for shapes with a zero dimension) and empty
// shape (a 0-dim scalar) are handled by passing NULL pointers instead
// of indexing empty slices, which previously panicked.
func NewArrayFromBytes(data []byte, shape []int32, dtype Dtype) *Array {
	var cData unsafe.Pointer
	if len(data) > 0 {
		cData = unsafe.Pointer(&data[0])
	}
	intShape := make([]C.int, len(shape))
	for i, s := range shape {
		intShape[i] = C.int(s)
	}
	var shapePtr *C.int
	if len(intShape) > 0 {
		shapePtr = &intShape[0]
	}
	handle := C.mlx_array_new_data(cData, shapePtr, C.int(len(shape)), C.mlx_dtype(dtype))
	return newArray(handle)
}
// ============ Device Control ============
// SetDefaultDeviceGPU sets the default device to GPU (Metal)
func SetDefaultDeviceGPU() {
	dev := C.mlx_device_new_type(C.MLX_GPU, 0)
	C.mlx_set_default_device(dev)
	C.mlx_device_free(dev)
}
// SetDefaultDeviceCPU sets the default device to CPU
func SetDefaultDeviceCPU() {
	dev := C.mlx_device_new_type(C.MLX_CPU, 0)
	C.mlx_set_default_device(dev)
	C.mlx_device_free(dev)
}
// MetalIsAvailable returns true if Metal GPU is available
func MetalIsAvailable() bool {
	var available C._Bool
	C.mlx_metal_is_available(&available)
	return bool(available)
}
// MetalStartCapture starts a GPU trace capture to the given file path.
// The path must not already exist. Run with MTL_CAPTURE_ENABLED=1 env var.
// Open the resulting .gputrace file in Xcode for analysis.
func MetalStartCapture(path string) {
	cPath := C.CString(path)
	defer C.free(unsafe.Pointer(cPath))
	C.mlx_metal_start_capture(cPath)
}
// MetalStopCapture stops the current GPU trace capture.
func MetalStopCapture() {
	C.mlx_metal_stop_capture()
}
// GPUIsAvailable returns true if any GPU (Metal or CUDA) is available
func GPUIsAvailable() bool {
	// On Linux with CUDA build, GPU is available
	// On macOS, check Metal availability
	if MetalIsAvailable() {
		return true
	}
	// CUDA is available if we compiled with CUDA support (Linux)
	// NOTE(review): assumes every Linux build of this file links CUDA; a
	// CPU-only Linux build would wrongly report a GPU — confirm.
	return runtime.GOOS == "linux"
}
// GetDefaultDeviceType returns the current default device (0=CPU, 1=GPU)
func GetDefaultDeviceType() int {
	var dev C.mlx_device
	C.mlx_get_default_device(&dev)
	var devType C.mlx_device_type
	C.mlx_device_get_type(&devType, dev)
	C.mlx_device_free(dev)
	return int(devType)
}
// Synchronize waits for all GPU operations to complete
func Synchronize() {
	C.mlx_synchronize(C.default_stream())
}
// ScaledDotProductAttention computes optimized attention using GPU kernel
// Q, K, V should be [batch, heads, seq, head_dim]
func ScaledDotProductAttention(q, k, v *Array, scale float32, causalMask bool) *Array {
	res := C.mlx_array_new()
	maskMode := "" // empty string for no mask
	if causalMask {
		maskMode = "causal"
	}
	cMaskMode := C.CString(maskMode)
	defer C.free(unsafe.Pointer(cMaskMode))
	// The two empty C.mlx_array{} handles are the optional mask and sinks.
	C.mlx_fast_scaled_dot_product_attention(&res, q.c, k.c, v.c, C.float(scale), cMaskMode, C.mlx_array{}, C.mlx_array{}, C.default_stream())
	return newArray(res)
}
// ScaledDotProductAttentionWithSinks computes attention with sinks support
// maskMode: "causal", "sliding_window", or "" for none
// mask: optional attention mask array (nil for none)
// sinks: attention sinks array (nil for none)
func ScaledDotProductAttentionWithSinks(q, k, v *Array, scale float32, maskMode string, mask, sinks *Array) *Array {
	res := C.mlx_array_new()
	cMaskMode := C.CString(maskMode)
	defer C.free(unsafe.Pointer(cMaskMode))
	// Zero-valued handles signal "absent" to the C API.
	var maskH, sinksH C.mlx_array
	if mask != nil {
		maskH = mask.c
	}
	if sinks != nil {
		sinksH = sinks.c
	}
	C.mlx_fast_scaled_dot_product_attention(&res, q.c, k.c, v.c, C.float(scale), cMaskMode, maskH, sinksH, C.default_stream())
	return newArray(res)
}
// ============ Native Safetensors Loading ============
// SafetensorsFile represents a loaded safetensors file
type SafetensorsFile struct {
	// arrays maps tensor name -> array handle; freed by Free.
	arrays C.mlx_map_string_to_array
	// metadata holds the file-level string metadata; freed by Free.
	metadata C.mlx_map_string_to_string
}
// LoadSafetensorsNative loads a safetensors file using MLX's optimized loader
// Note: Uses CPU stream because Load primitive only runs on CPU
// The caller owns the result and must call Free when done.
func LoadSafetensorsNative(path string) (*SafetensorsFile, error) {
	cPath := C.CString(path)
	defer C.free(unsafe.Pointer(cPath))
	var arrays C.mlx_map_string_to_array
	var metadata C.mlx_map_string_to_string
	if C.mlx_load_safetensors(&arrays, &metadata, cPath, C.cpu_stream()) != 0 {
		return nil, fmt.Errorf("failed to load safetensors: %s", path)
	}
	return &SafetensorsFile{arrays: arrays, metadata: metadata}, nil
}
// Get retrieves a tensor by name
// Returns nil when the tensor is not present.
func (s *SafetensorsFile) Get(name string) *Array {
	cName := C.CString(name)
	defer C.free(unsafe.Pointer(cName))
	var arr C.mlx_array
	if C.mlx_map_string_to_array_get(&arr, s.arrays, cName) != 0 {
		return nil
	}
	if arr.ctx == nil {
		return nil
	}
	return newArray(arr)
}
// Set replaces a tensor in the map (like Python's weights[k] = v)
func (s *SafetensorsFile) Set(name string, arr *Array) {
	cName := C.CString(name)
	defer C.free(unsafe.Pointer(cName))
	C.mlx_map_string_to_array_insert(s.arrays, cName, arr.c)
}
// Count returns the number of tensors (not directly available, would need iterator)
// NOTE(review): always returns 0; callers must not rely on it.
func (s *SafetensorsFile) Count() int {
	// mlx-c doesn't have a direct count - would need to iterate
	return 0
}
// Free releases the safetensors file
func (s *SafetensorsFile) Free() {
	C.mlx_map_string_to_array_free(s.arrays)
	C.mlx_map_string_to_string_free(s.metadata)
}
// ============ NPY Loading ============
// LoadNpy loads a numpy array from an npy file
// Note: Uses CPU stream because Load primitive only runs on CPU
func LoadNpy(path string) (*Array, error) {
	cPath := C.CString(path)
	defer C.free(unsafe.Pointer(cPath))
	var arr C.mlx_array
	// A non-zero status or a nil handle both indicate a failed load;
	// the short-circuit keeps arr.ctx unread when mlx_load itself fails.
	if C.mlx_load(&arr, cPath, C.cpu_stream()) != 0 || arr.ctx == nil {
		return nil, fmt.Errorf("failed to load npy: %s", path)
	}
	return newArray(arr), nil
}
// ============ Slice Update ============
// SliceUpdate updates a slice of the array with new values
// start/stop have one entry per dimension; the stride is fixed at 1.
// Returns a new array (MLX arrays are immutable).
func SliceUpdate(a, update *Array, start, stop []int32) *Array {
	n := len(start)
	cStart := make([]C.int, n)
	cStop := make([]C.int, n)
	cStrides := make([]C.int, n)
	for i := 0; i < n; i++ {
		cStart[i] = C.int(start[i])
		cStop[i] = C.int(stop[i])
		cStrides[i] = 1 // Default stride of 1
	}
	res := C.mlx_array_new()
	C.mlx_slice_update(&res, a.c, update.c, &cStart[0], C.size_t(n), &cStop[0], C.size_t(n), &cStrides[0], C.size_t(n), C.default_stream())
	return newArray(res)
}
// SliceUpdateInplace updates a slice and returns a new array.
// Note: Despite the name, this is NOT in-place - MLX arrays are immutable.
// The caller must use the returned value.
func SliceUpdateInplace(a, update *Array, start, stop []int32) *Array {
	return SliceUpdate(a, update, start, stop)
}
// ============ Optimized Operations ============
// SampleArgmax gets the last logit position and returns argmax (fused operation)
// Reading the scalar via ItemInt32 triggers cleanup of non-kept arrays.
func SampleArgmax(logits *Array) int32 {
	result := Argmax(logits, -1, false)
	return result.ItemInt32()
}
// ArgmaxKeepArray returns argmax as an Array (for pipelining, no sync)
// This is like mlx-lm's sampler that returns y as an array, not .item()
func ArgmaxKeepArray(logits *Array) *Array {
	// For greedy decoding: logits shape is [1, 1, vocab]
	// We want argmax over vocab dimension, return shape []
	return Argmax(logits, -1, false)
}
// RandomState is the global PRNG state, analogous to mx.random.state in Python.
// It's a slice containing a single key array. Random functions use and update this state.
//
// Thread safety: Protected by randomStateMu, mimicking Python's GIL behavior.
// All random functions that use global state acquire this lock.
var RandomState = []*Array{nil}
var randomStateMu sync.Mutex
func init() {
	// Lock main goroutine to OS thread for CUDA context stability.
	// CUDA contexts are bound to threads; Go can migrate goroutines between threads.
	runtime.LockOSThread()
	// Seed the global key from wall-clock time.
	RandomState[0] = RandomKey(uint64(time.Now().UnixMilli()))
	Keep(RandomState[0]) // Global state should persist
}
// RandomKey creates a PRNG key from a seed
func RandomKey(seed uint64) *Array {
	var res C.mlx_array
	C.mlx_random_key(&res, C.uint64_t(seed))
	return newArray(res)
}
// RandomSplit splits a PRNG key into two new keys
func RandomSplit(key *Array) (*Array, *Array) {
	var key1, key2 C.mlx_array
	C.mlx_random_split(&key1, &key2, key.c, C.default_stream())
	return newArray(key1), newArray(key2)
}
// RandomCategoricalWithKey samples from categorical distribution using provided key.
func RandomCategoricalWithKey(logits, key *Array, axis int, numSamples int) *Array {
	res := C.mlx_array_new()
	C.mlx_random_categorical_num_samples(&res, logits.c, C.int(axis), C.int(numSamples), key.c, C.default_stream())
	return newArray(res)
}
// RandomCategorical samples using global RandomState.
// For simple scripts - production code should use RandomCategoricalWithKey with explicit key management.
func RandomCategorical(logits *Array, axis int, numSamples int) *Array {
	randomStateMu.Lock()
	oldKey := RandomState[0]
	// Split the state: key1 replaces the global key, key2 is consumed below.
	key1, key2 := RandomSplit(oldKey)
	Keep(key1) // key1 becomes the new global state
	oldKey.Free()
	RandomState[0] = key1
	randomStateMu.Unlock()
	return RandomCategoricalWithKey(logits, key2, axis, numSamples)
}
// RandomNormal creates a random normal (Gaussian) tensor
// Output dtype is float32; mean 0.0, stddev 1.0.
func RandomNormal(shape []int32, seed uint64) *Array {
	key := RandomKey(seed)
	res := C.mlx_array_new()
	C.mlx_random_normal(&res, int32ToCInt(shape), C.size_t(len(shape)), C.MLX_FLOAT32, 0.0, 1.0, key.c, C.default_stream())
	return newArray(res)
}
// RandomUniform generates uniform random values in [0, 1) with the given shape
func RandomUniform(shape []int32, seed uint64) *Array {
	key := RandomKey(seed)
	low := C.mlx_array_new_float(0.0)
	high := C.mlx_array_new_float(1.0)
	res := C.mlx_array_new()
	C.mlx_random_uniform(&res, low, high, int32ToCInt(shape), C.size_t(len(shape)), C.MLX_FLOAT32, key.c, C.default_stream())
	// low/high are untracked C-side scalars; free them manually.
	C.mlx_array_free(low)
	C.mlx_array_free(high)
	return newArray(res)
}
// Conv2d performs 2D convolution
// input: [N, H, W, C], weight: [O, kH, kW, C] (MLX uses NHWC layout)
// Returns: [N, H', W', O]
// The same stride/padding is applied to both spatial dimensions; the
// trailing 1s presumably pin dilation and groups to 1 — confirm against
// the mlx_conv2d signature.
func Conv2d(input, weight *Array, stride, padding int32) *Array {
	res := C.mlx_array_new()
	C.mlx_conv2d(&res, input.c, weight.c, C.int(stride), C.int(stride), C.int(padding), C.int(padding), 1, 1, 1, C.default_stream())
	return newArray(res)
}
// Conv3d performs 3D convolution
// input: [N, D, H, W, C], weight: [O, kD, kH, kW, C] (MLX uses NDHWC layout)
// Returns: [N, D', H', W', O]
func Conv3d(input, weight *Array, strideD, strideH, strideW, padD, padH, padW int32) *Array {
	res := C.mlx_array_new()
	C.mlx_conv3d(&res, input.c, weight.c, C.int(strideD), C.int(strideH), C.int(strideW), C.int(padD), C.int(padH), C.int(padW), 1, 1, 1, 1, C.default_stream())
	return newArray(res)
}
// ============ Compilation Control ============
// EnableCompile enables global compilation/graph fusion
func EnableCompile() {
	C.mlx_enable_compile()
}
// DisableCompile disables global compilation
func DisableCompile() {
	C.mlx_disable_compile()
}
// SetCompileMode sets the compile mode
// 0=disabled, 1=no_simplify, 2=no_fuse, 3=enabled
func SetCompileMode(mode int) {
	C.mlx_set_compile_mode(C.mlx_compile_mode(mode))
}
// ============ Stream Control ============
// Stream represents an MLX execution stream
type Stream struct {
	c C.mlx_stream
}
// NewStream creates a new execution stream on the default device
// The caller must call Free when done with the stream.
func NewStream() *Stream {
	var dev C.mlx_device
	C.mlx_get_default_device(&dev)
	stream := C.mlx_stream_new_device(dev)
	C.mlx_device_free(dev)
	return &Stream{c: stream}
}
// Free releases the stream
// Safe to call more than once; subsequent calls are no-ops.
func (s *Stream) Free() {
	if s.c.ctx != nil {
		C.mlx_stream_free(s.c)
		s.c.ctx = nil
	}
}
// SetDefaultStream sets the default stream for operations
func SetDefaultStream(s *Stream) {
	C.mlx_set_default_stream(s.c)
	C.set_default_stream(s.c) // Also update our cached stream
}
// GetDefaultStream returns the current default stream
func GetDefaultStream() *Stream {
	var stream C.mlx_stream
	var dev C.mlx_device
	C.mlx_get_default_device(&dev)
	C.mlx_get_default_stream(&stream, dev)
	C.mlx_device_free(dev)
	return &Stream{c: stream}
}
// SynchronizeStream waits for all operations on the stream to complete
func SynchronizeStream(s *Stream) {
	C.mlx_synchronize(s.c)
}
// ============ Metal Memory Control ============
// (The underlying mlx_get_*/mlx_set_* calls carry no Metal prefix and are
// presumably backend-agnostic; the Metal naming here is historical — confirm.)
// MetalGetCacheMemory returns the current cache memory usage in bytes
func MetalGetCacheMemory() uint64 {
	var size C.size_t
	C.mlx_get_cache_memory(&size)
	return uint64(size)
}
// MetalGetPeakMemory returns the peak memory usage in bytes
func MetalGetPeakMemory() uint64 {
	var size C.size_t
	C.mlx_get_peak_memory(&size)
	return uint64(size)
}
// MetalResetPeakMemory resets the peak memory counter
func MetalResetPeakMemory() {
	C.mlx_reset_peak_memory()
}
// MetalSetWiredLimit sets the wired memory limit and returns the previous limit
// This keeps tensors pinned in GPU memory for faster access
func MetalSetWiredLimit(limit uint64) uint64 {
	var prev C.size_t
	C.mlx_set_wired_limit(&prev, C.size_t(limit))
	return uint64(prev)
}
// MetalGetActiveMemory returns the current active memory usage in bytes
func MetalGetActiveMemory() uint64 {
	var size C.size_t
	C.mlx_get_active_memory(&size)
	return uint64(size)
}
// ClearCache clears the MLX memory cache
func ClearCache() {
	C.mlx_clear_cache()
}
// SetCacheLimit sets the free cache limit in bytes
// Setting to 0 disables caching (useful for memory-constrained generation)
// Returns the previous cache limit
func SetCacheLimit(limit uint64) uint64 {
	var prev C.size_t
	C.mlx_set_cache_limit(&prev, C.size_t(limit))
	return uint64(prev)
}
// SetMemoryLimit sets the overall memory limit in bytes
// This is a guideline for maximum memory during graph evaluation.
// When Metal is available, defaults to 1.5x the max recommended working set.
// Returns the previous memory limit
func SetMemoryLimit(limit uint64) uint64 {
	var prev C.size_t
	C.mlx_set_memory_limit(&prev, C.size_t(limit))
	return uint64(prev)
}
// GetMemoryLimit returns the current memory limit in bytes
func GetMemoryLimit() uint64 {
	var size C.size_t
	C.mlx_get_memory_limit(&size)
	return uint64(size)
}
// ============ MoE Operations ============
// GatherMM performs gather matrix multiplication for MoE
// a: input, b: weight matrices
// lhsIndices, rhsIndices: optional expert selection indices (nil for none)
// Zero-valued C handles signal "absent" to the C API.
func GatherMM(a, b *Array, lhsIndices, rhsIndices *Array, sortedIndices bool) *Array {
	var lhs, rhs C.mlx_array
	if lhsIndices != nil {
		lhs = lhsIndices.c
	}
	if rhsIndices != nil {
		rhs = rhsIndices.c
	}
	res := C.mlx_array_new()
	C.mlx_gather_mm(&res, a.c, b.c, lhs, rhs, C._Bool(sortedIndices), C.default_stream())
	return newArray(res)
}
// GatherQMM performs quantized gather matrix multiplication for MoE
// Used for MXFP4 and other quantized MoE inference
// biases, lhsIndices and rhsIndices may each be nil.
func GatherQMM(x, w, scales *Array, biases, lhsIndices, rhsIndices *Array, transpose bool, groupSize, bits int, mode string, sortedIndices bool) *Array {
	var b, lhs, rhs C.mlx_array
	if biases != nil {
		b = biases.c
	}
	if lhsIndices != nil {
		lhs = lhsIndices.c
	}
	if rhsIndices != nil {
		rhs = rhsIndices.c
	}
	cMode := C.CString(mode)
	defer C.free(unsafe.Pointer(cMode))
	optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
	optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
	res := C.mlx_array_new()
	C.mlx_gather_qmm(&res, x.c, w.c, scales.c, b, lhs, rhs, C._Bool(transpose), optGroupSize, optBits, cMode, C._Bool(sortedIndices), C.default_stream())
	return newArray(res)
}
// ============ Quantization ============
// Quantize quantizes weights to specified bits per element.
// Returns (quantized_weights, scales, biases).
// groupSize: number of elements quantized together (default 64)
// bits: bits per element, 2, 4, or 8 (default 4)
// mode: "affine" (default) or "mxfp4"
func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, biases *Array) {
	cMode := C.CString(mode)
	defer C.free(unsafe.Pointer(cMode))
	optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
	optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
	res := C.mlx_vector_array_new()
	C.mlx_quantize(&res, w.c, optGroupSize, optBits, cMode, C.default_stream())
	// Result is a vector of 3 arrays: [weights, scales, biases]
	var w0, w1, w2 C.mlx_array
	C.mlx_vector_array_get(&w0, res, 0)
	C.mlx_vector_array_get(&w1, res, 1)
	C.mlx_vector_array_get(&w2, res, 2)
	// Free the vector container now that the elements have been extracted.
	C.mlx_vector_array_free(res)
	return newArray(w0), newArray(w1), newArray(w2)
}
// Dequantize reconstructs weights from quantized form.
// groupSize: number of elements quantized together (default 64)
// bits: bits per element, 2, 4, or 8 (default 4)
// mode: "affine" (default) or "mxfp4"
// biases may be nil.
func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Array {
	cMode := C.CString(mode)
	defer C.free(unsafe.Pointer(cMode))
	optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
	optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
	// No explicit output dtype: let MLX choose.
	optDtype := C.mlx_optional_dtype{has_value: false}
	var b C.mlx_array
	if biases != nil {
		b = biases.c
	}
	res := C.mlx_array_new()
	C.mlx_dequantize(&res, w.c, scales.c, b, optGroupSize, optBits, cMode, optDtype, C.default_stream())
	return newArray(res)
}
// QuantizedMatmul performs matrix multiplication with quantized weights.
// x: input tensor [batch..., in_features]
// w: quantized weights
// scales, biases: from Quantize
// transpose: if true, compute x @ w.T (typical for Linear layers)
// groupSize, bits, mode: must match what was used in Quantize
func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bits int, mode string) *Array {
	cMode := C.CString(mode)
	defer C.free(unsafe.Pointer(cMode))
	optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
	optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
	var b C.mlx_array
	if biases != nil {
		b = biases.c
	}
	res := C.mlx_array_new()
	C.mlx_quantized_matmul(&res, x.c, w.c, scales.c, b, C._Bool(transpose), optGroupSize, optBits, cMode, C.default_stream())
	return newArray(res)
}
// ============ Sorting and Top-K ============
// TopK returns the k largest elements along an axis
func TopK(a *Array, k int, axis int) *Array {
	res := C.mlx_array_new()
	C.mlx_topk_axis(&res, a.c, C.int(k), C.int(axis), C.default_stream())
	return newArray(res)
}
// Argpartition returns indices for partial sort (k-th smallest first)
func Argpartition(a *Array, kth int, axis int) *Array {
	res := C.mlx_array_new()
	C.mlx_argpartition_axis(&res, a.c, C.int(kth), C.int(axis), C.default_stream())
	return newArray(res)
}
// TakeAlongAxis takes elements from array using indices along axis
func TakeAlongAxis(a, indices *Array, axis int) *Array {
	res := C.mlx_array_new()
	C.mlx_take_along_axis(&res, a.c, indices.c, C.int(axis), C.default_stream())
	return newArray(res)
}
// PutAlongAxis puts values into array at indices along axis
// Returns a new array; the input is not modified.
func PutAlongAxis(a, indices, values *Array, axis int) *Array {
	res := C.mlx_array_new()
	C.mlx_put_along_axis(&res, a.c, indices.c, values.c, C.int(axis), C.default_stream())
	return newArray(res)
}
// Cumsum computes cumulative sum along an axis
// The two false flags are presumably reverse and inclusive — confirm
// against the mlx_cumsum signature.
func Cumsum(a *Array, axis int) *Array {
	res := C.mlx_array_new()
	C.mlx_cumsum(&res, a.c, C.int(axis), false, false, C.default_stream())
	return newArray(res)
}
// Where selects elements: condition ? a : b
func Where(condition, a, b *Array) *Array {
	res := C.mlx_array_new()
	C.mlx_where(&res, condition.c, a.c, b.c, C.default_stream())
	return newArray(res)
}
// LessScalar returns element-wise a < scalar
func LessScalar(a *Array, s float32) *Array {
	scalar := C.mlx_array_new_float(C.float(s))
	res := C.mlx_array_new()
	C.mlx_less(&res, a.c, scalar, C.default_stream())
	// scalar is an untracked C-side temporary; free it manually.
	C.mlx_array_free(scalar)
	return newArray(res)
}
// FullDtype creates an array filled with a value with specific dtype.
// With no shape arguments it produces a 0-dimensional (scalar) array;
// previously that case panicked on &intShape[0] of an empty slice, so a
// NULL shape pointer is passed instead.
func FullDtype(value float32, dtype Dtype, shape ...int32) *Array {
	intShape := make([]C.int, len(shape))
	for i, s := range shape {
		intShape[i] = C.int(s)
	}
	var shapePtr *C.int
	if len(intShape) > 0 {
		shapePtr = &intShape[0]
	}
	vals := C.mlx_array_new_float(C.float(value))
	res := C.mlx_array_new()
	C.mlx_full(&res, shapePtr, C.size_t(len(shape)), vals, C.mlx_dtype(dtype), C.default_stream())
	// vals is an untracked C-side scalar; free it manually.
	C.mlx_array_free(vals)
	return newArray(res)
}
// AsType casts an array to a different dtype
func AsType(a *Array, dtype Dtype) *Array {
	res := C.mlx_array_new()
	C.mlx_astype(&res, a.c, C.mlx_dtype(dtype), C.default_stream())
	return newArray(res)
}
// ToBFloat16 casts an array to bfloat16
func ToBFloat16(a *Array) *Array {
	return AsType(a, DtypeBFloat16)
}
// ============ VibeVoice Helper Functions ============
// NewScalarArray creates a true 0-dimensional scalar array from a float32 value
func NewScalarArray(value float32) *Array {
	return newArray(C.mlx_array_new_float(C.float(value)))
}
// Global random seed counter for RandN
// Incremented atomically so concurrent callers get distinct seeds.
var randnSeedCounter uint64 = uint64(time.Now().UnixNano())
// RandN creates an array of random samples from a standard normal distribution
func RandN(shape []int32) *Array {
	// Use incrementing seed for unique random values each call
	seed := atomic.AddUint64(&randnSeedCounter, 1)
	return RandomNormal(shape, seed)
}
// Pad pads an array with zeros
// paddings: [before_0, after_0, before_1, after_1, ...] for each dimension
// NOTE(review): an odd trailing entry in paddings is silently dropped by
// the integer division below.
func Pad(a *Array, paddings []int32) *Array {
	numAxes := len(paddings) / 2
	// Convert to low/high pairs
	lowPad := make([]C.int, numAxes)
	highPad := make([]C.int, numAxes)
	for i := 0; i < numAxes; i++ {
		lowPad[i] = C.int(paddings[i*2])
		highPad[i] = C.int(paddings[i*2+1])
	}
	zero := C.mlx_array_new_float(0.0)
	res := C.mlx_array_new()
	// mlx_pad takes axes, low, high arrays
	axes := make([]C.int, numAxes)
	for i := 0; i < numAxes; i++ {
		axes[i] = C.int(i)
	}
	cMode := C.CString("constant")
	defer C.free(unsafe.Pointer(cMode))
	C.mlx_pad(&res, a.c, &axes[0], C.size_t(numAxes), &lowPad[0], C.size_t(numAxes), &highPad[0], C.size_t(numAxes), zero, cMode, C.default_stream())
	// zero is an untracked C-side scalar; free it manually.
	C.mlx_array_free(zero)
	return newArray(res)
}
// Conv1d performs 1D convolution
// x: [B, L, Cin], weight: [Cout, K, Cin] (MLX uses NLC layout)
// bias: optional (nil for no bias)
// Padding is fixed at 0; dilation and groups at 1.
func Conv1d(x, weight *Array, bias *Array, stride int32) *Array {
	res := C.mlx_array_new()
	C.mlx_conv1d(&res, x.c, weight.c, C.int(stride), C.int(0), C.int(1), 1, C.default_stream())
	// Apply bias if provided
	if bias != nil {
		biased := C.mlx_array_new()
		C.mlx_add(&biased, res, bias.c, C.default_stream())
		// res is an untracked intermediate once the bias is applied; free it.
		C.mlx_array_free(res)
		return newArray(biased)
	}
	return newArray(res)
}
// ConvTranspose1d performs transposed 1D convolution
// x: [B, L, Cin], weight: [Cout, K, Cin] (MLX uses NLC layout)
// bias: optional (nil for no bias)
func ConvTranspose1d(x, weight *Array, bias *Array, stride int32) *Array {
	res := C.mlx_array_new()
	// stride, padding, dilation, output_padding, groups
	C.mlx_conv_transpose1d(&res, x.c, weight.c, C.int(stride), 0, 1, 0, 1, C.default_stream())
	// Apply bias if provided
	if bias != nil {
		biased := C.mlx_array_new()
		C.mlx_add(&biased, res, bias.c, C.default_stream())
		// res is an untracked intermediate once the bias is applied; free it.
		C.mlx_array_free(res)
		return newArray(biased)
	}
	return newArray(res)
}
// DepthwiseConv1d performs depthwise 1D convolution (groups=Cin)
// x: [B, L, C], weight: [1, K, C] (groups = C)
// bias: optional (nil for no bias)
func DepthwiseConv1d(x, weight *Array, bias *Array) *Array {
	// Get number of input channels for groups
	shape := x.Shape()
	groups := int(shape[len(shape)-1])
	res := C.mlx_array_new()
	C.mlx_conv1d(&res, x.c, weight.c, 1, 0, 1, C.int(groups), C.default_stream())
	// Apply bias if provided
	if bias != nil {
		biased := C.mlx_array_new()
		C.mlx_add(&biased, res, bias.c, C.default_stream())
		// res is an untracked intermediate once the bias is applied; free it.
		C.mlx_array_free(res)
		return newArray(biased)
	}
	return newArray(res)
}
// SliceAxis extracts a slice along a specific axis.
// axis may be negative to count from the last dimension (e.g. -1 for
// the innermost axis), matching the convention of the other ops in this
// file that take an axis. Previously a negative axis matched no
// dimension and silently returned an unsliced copy.
func SliceAxis(a *Array, axis int, start, stop int32) *Array {
	shape := a.Shape()
	if axis < 0 {
		axis += len(shape)
	}
	// Full range on every dimension except the requested axis.
	starts := make([]int32, len(shape))
	stops := make([]int32, len(shape))
	for i := range shape {
		if i == axis {
			starts[i] = start
			stops[i] = stop
		} else {
			starts[i] = 0
			stops[i] = shape[i]
		}
	}
	return Slice(a, starts, stops)
}
// Tri creates a lower triangular matrix
// n x m float32 result; k offsets the diagonal (numpy.tri convention).
func Tri(n, m int32, k int) *Array {
	res := C.mlx_array_new()
	C.mlx_tri(&res, C.int(n), C.int(m), C.int(k), C.MLX_FLOAT32, C.default_stream())
	return newArray(res)
}