Files
ollama/x/imagegen/safetensors/loader.go
Daniel Hiltgen 33ee7168ba Add experimental MLX backend and engine with imagegen support (#13648)
* WIP - MLX backend with gemma3

* MLX: add cmake and go tag build toggles

To build the new MLX backend code:
  cmake --preset MLX
  cmake --build --preset MLX --parallel
  cmake --install build --component MLX
  go build -tags mlx .

Note: the main.go entrypoint for the MLX engine will change in a follow-up commit.

* add experimental image generation runtime

* add experimental image generation runtime

* MLX: wire up cuda build for linux

* MLX: get dependencies correct and dedup

This is still too large for a unified GitHub artifact, but is now "correct" for the mlx_cuda_v13
directory.

* fix relative link bug in dedup

* Add darwin build and readme

* add go build tag for mlx dependent code and wire up build_darwin.sh

* lint cleanup

* macos: build mlx for x86

This will be CPU only.

* cuda build instructions and fix drift from mlx bump

* stale comment

* Delete agent helper doc

* Clean up readme.md

* Revise README for tokenizer clarity and details

Updated README to clarify tokenizer functionality and removed correctness section.

---------

Co-authored-by: jmorganca <jmorganca@gmail.com>
2026-01-08 16:18:59 -08:00

171 lines
4.5 KiB
Go

//go:build mlx
package safetensors
import (
"fmt"
"reflect"
"strings"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// LoadModule populates the struct that dst points to with tensors from
// weights, driven by reflection over `weight` struct tags.
//
// Tag format: `weight:"path[,optional]"`
//   - path: weight name suffix, joined to prefix with a dot
//   - optional: if present, missing weights don't cause errors
//   - "-": the field is skipped entirely
//   - struct pointer without a tag: loaded using the current prefix
//   - *mlx.Array without a tag: skipped (computed fields are never loaded)
//
// Slices of struct pointers are walked element by element using the suffixes
// .0, .1, .2, and so on. Only the elements already present are visited, so
// the slice must be pre-allocated to the correct length.
//
// dst must be a non-nil pointer to a struct; anything else yields an error.
// When required weights are missing, the returned error lists each full
// weight path on its own line.
//
// Example:
//
//	type Attention struct {
//		QProj *nn.Linear `weight:"self_attn.q_proj"`
//		KProj *nn.Linear `weight:"self_attn.k_proj"`
//		Cache *mlx.Array // no tag = skipped (computed field)
//	}
//
//	err := LoadModule(&attn, weights, "model.layers.0")
func LoadModule(dst any, weights *ModelWeights, prefix string) error {
	ptr := reflect.ValueOf(dst)
	if ptr.Kind() != reflect.Ptr || ptr.IsNil() {
		return fmt.Errorf("LoadModule: dst must be a non-nil pointer")
	}

	target := ptr.Elem()
	if kind := target.Kind(); kind != reflect.Struct {
		return fmt.Errorf("LoadModule: dst must be a pointer to struct, got %v", kind)
	}

	// loadStruct records the full path of every required weight it could
	// not load; an empty list means the whole tree was populated.
	var missing []string
	loadStruct(target, weights, prefix, &missing, false)
	if len(missing) == 0 {
		return nil
	}
	return fmt.Errorf("LoadModule: missing weights:\n %s", strings.Join(missing, "\n "))
}
// loadStruct walks the settable fields of struct value v and fills them from
// weights. The full path of every required tensor that cannot be fetched is
// appended to errs; loading continues with the next field.
//
// prefix is the dotted weight path of v itself. parentOptional is the
// optional state inherited from the enclosing field: a field is optional when
// it inherits that state or when its own `weight` tag lists the "optional"
// option. Tag options are the comma-separated items after the path and are
// matched exactly (surrounding spaces ignored), so an unrelated option that
// merely contains the text "optional" does not make a field optional.
//
// Per-field handling:
//   - non-settable (unexported) fields and fields tagged "-" are skipped
//   - *mlx.Array: fetched from the joined path; skipped when untagged
//   - untagged pointer to another struct: allocated if nil and loaded with
//     the unchanged prefix
//   - tagged pointer to struct: allocated if nil and loaded under its own
//     path, unless it is optional and no tensor exists under that path
//   - slice of struct pointers: delegated to loadSlice
//   - every other kind of field is ignored
func loadStruct(v reflect.Value, weights *ModelWeights, prefix string, errs *[]string, parentOptional bool) {
	// Loop-invariant: the reflect type that identifies tensor fields.
	arrayPtrType := reflect.TypeOf((*mlx.Array)(nil))

	t := v.Type()
	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)
		fieldVal := v.Field(i)

		// Skip unexported fields; reflection cannot assign to them.
		if !fieldVal.CanSet() {
			continue
		}

		tag, hasTag := field.Tag.Lookup("weight")
		if tag == "-" {
			continue
		}

		// Split `path[,opt...]` and look for an exact "optional" option.
		weightPath, rawOpts, _ := strings.Cut(tag, ",")
		optional := parentOptional
		for _, opt := range strings.Split(rawOpts, ",") {
			if strings.TrimSpace(opt) == "optional" {
				optional = true
				break
			}
		}

		fullPath := joinPath(prefix, weightPath)
		fieldType := fieldVal.Type()

		switch fieldVal.Kind() {
		case reflect.Ptr:
			// Tensor field: load it directly.
			if fieldType == arrayPtrType {
				if !hasTag {
					continue // no tag on *mlx.Array = computed field, skip
				}
				arr, err := weights.GetTensor(fullPath)
				if err != nil {
					// Any GetTensor failure is reported as a missing
					// weight; the underlying error is not kept.
					if !optional {
						*errs = append(*errs, fullPath)
					}
					continue
				}
				fieldVal.Set(reflect.ValueOf(arr))
				continue
			}

			elemType := fieldType.Elem()
			if elemType.Kind() != reflect.Struct {
				continue
			}

			// Untagged struct pointer: its fields live directly under the
			// current prefix, so recurse without extending the path.
			if !hasTag {
				if fieldVal.IsNil() {
					fieldVal.Set(reflect.New(elemType))
				}
				loadStruct(fieldVal.Elem(), weights, prefix, errs, optional)
				continue
			}

			// Tagged struct pointer: leave an optional sub-module untouched
			// when the weights contain nothing under its path.
			if optional && !hasWeightsWithPrefix(weights, fullPath) {
				continue
			}
			if fieldVal.IsNil() {
				fieldVal.Set(reflect.New(elemType))
			}
			loadStruct(fieldVal.Elem(), weights, fullPath, errs, optional)

		case reflect.Slice:
			elemType := fieldType.Elem()
			if elemType.Kind() == reflect.Ptr && elemType.Elem().Kind() == reflect.Struct {
				// NOTE(review): loadSlice takes no optional flag, so the
				// optional state computed above is not applied to slice
				// elements — confirm whether that is intended.
				loadSlice(fieldVal, weights, fullPath, errs)
			}
		}
	}
}
// hasWeightsWithPrefix reports whether weights lists at least one tensor that
// is named exactly prefix or whose name starts with prefix followed by a dot.
func hasWeightsWithPrefix(weights *ModelWeights, prefix string) bool {
	nested := prefix + "."
	for _, tensorName := range weights.ListTensors() {
		if tensorName == prefix || strings.HasPrefix(tensorName, nested) {
			return true
		}
	}
	return false
}
// loadSlice fills every element of v, a slice of struct pointers, from
// weights. Element i is loaded under the path "<prefix>.<i>"; nil elements
// are allocated first. Only the elements already present are visited, and
// each one is loaded with parentOptional set to false.
func loadSlice(v reflect.Value, weights *ModelWeights, prefix string, errs *[]string) {
	// v has type []*T; structType is T.
	structType := v.Type().Elem().Elem()
	for idx := 0; idx < v.Len(); idx++ {
		item := v.Index(idx)
		if item.IsNil() {
			item.Set(reflect.New(structType))
		}
		itemPrefix := fmt.Sprintf("%s.%d", prefix, idx)
		loadStruct(item.Elem(), weights, itemPrefix, errs, false)
	}
}
// joinPath concatenates prefix and suffix with a single dot between them.
// When either side is empty the other is returned unchanged, so no leading
// or trailing dot is ever introduced.
func joinPath(prefix, suffix string) string {
	switch {
	case prefix == "":
		return suffix
	case suffix == "":
		return prefix
	default:
		return prefix + "." + suffix
	}
}