mirror of
https://github.com/ollama/ollama.git
synced 2026-01-12 00:06:57 +08:00
* WIP - MLX backend with gemma3 * MLX: add cmake and go tag build toggles To build the new MLX backend code: cmake --preset MLX cmake --build --preset MLX --parallel cmake --install build --component MLX go build -tags mlx . Note: the main.go entrypoint for the MLX engine will change in a follow up commit. * add experimental image generation runtime * add experimental image generation runtime * MLX: wire up cuda build for linux * MLX: get dependencies correct and dedup This is still too large for a unified github artifact, but is now "correct" for the mlx_cuda_v13 directory. * fix relative link bug in dedup * Add darwin build and readme * add go build tag for mlx dependent code and wire up build_darwin.sh * lint cleanup * macos: build mlx for x86 This will be CPU only. * cuda build instructions and fix drift from mlx bump * stale comment * Delete agent helper doc * Clean up readme.md * Revise README for tokenizer clarity and details Updated README to clarify tokenizer functionality and removed correctness section. --------- Co-authored-by: jmorganca <jmorganca@gmail.com>
281 lines
7.1 KiB
Go
281 lines
7.1 KiB
Go
//go:build mlx
|
|
|
|
package safetensors
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"encoding/json"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"sort"
|
|
"strings"
|
|
|
|
"github.com/ollama/ollama/x/imagegen/mlx"
|
|
)
|
|
|
|
// SafetensorHeader represents the JSON header of a safetensors file
|
|
type SafetensorHeader map[string]TensorInfo
|
|
|
|
// TensorInfo contains metadata about a tensor
|
|
type TensorInfo struct {
|
|
Dtype string `json:"dtype"`
|
|
Shape []int32 `json:"shape"`
|
|
DataOffsets [2]int `json:"data_offsets"`
|
|
}
|
|
|
|
// parseSafetensorHeader reads only the JSON header from a safetensors file.
|
|
func parseSafetensorHeader(path string) (SafetensorHeader, error) {
|
|
f, err := os.Open(path)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to open file: %w", err)
|
|
}
|
|
defer f.Close()
|
|
|
|
var headerSize uint64
|
|
if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
|
|
return nil, fmt.Errorf("failed to read header size: %w", err)
|
|
}
|
|
|
|
headerBytes := make([]byte, headerSize)
|
|
if _, err := f.Read(headerBytes); err != nil {
|
|
return nil, fmt.Errorf("failed to read header: %w", err)
|
|
}
|
|
|
|
var header SafetensorHeader
|
|
if err := json.Unmarshal(headerBytes, &header); err != nil {
|
|
return nil, fmt.Errorf("failed to parse header: %w", err)
|
|
}
|
|
|
|
delete(header, "__metadata__")
|
|
return header, nil
|
|
}
|
|
|
|
// dtypeFromString converts safetensors dtype string to mlx.Dtype
|
|
func dtypeFromString(s string) mlx.Dtype {
|
|
switch strings.ToUpper(s) {
|
|
case "F32", "FLOAT32":
|
|
return mlx.DtypeFloat32
|
|
case "F16", "FLOAT16":
|
|
return mlx.DtypeFloat16
|
|
case "BF16", "BFLOAT16":
|
|
return mlx.DtypeBFloat16
|
|
case "I32", "INT32":
|
|
return mlx.DtypeInt32
|
|
case "I64", "INT64":
|
|
return mlx.DtypeInt64
|
|
case "U8", "UINT8":
|
|
return mlx.DtypeUint8
|
|
default:
|
|
return mlx.DtypeFloat32
|
|
}
|
|
}
|
|
|
|
// ModelWeights manages weights from multiple safetensor files.
|
|
type ModelWeights struct {
|
|
dir string // Model directory
|
|
tensorFiles map[string]string // tensor name -> file path
|
|
tensorInfo map[string]TensorInfo // tensor name -> metadata
|
|
nativeCache map[string]*mlx.SafetensorsFile // file path -> loaded native handle
|
|
cache map[string]*mlx.Array // tensor name -> array (after Load)
|
|
}
|
|
|
|
// LoadModelWeights scans safetensor files and builds a tensor index.
|
|
// This only reads JSON headers, not tensor data.
|
|
func LoadModelWeights(dir string) (*ModelWeights, error) {
|
|
mw := &ModelWeights{
|
|
dir: dir,
|
|
tensorFiles: make(map[string]string),
|
|
tensorInfo: make(map[string]TensorInfo),
|
|
nativeCache: make(map[string]*mlx.SafetensorsFile),
|
|
}
|
|
|
|
entries, err := os.ReadDir(dir)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read directory: %w", err)
|
|
}
|
|
|
|
for _, entry := range entries {
|
|
if strings.HasSuffix(entry.Name(), ".safetensors") {
|
|
path := filepath.Join(dir, entry.Name())
|
|
|
|
header, err := parseSafetensorHeader(path)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to parse %s: %w", entry.Name(), err)
|
|
}
|
|
|
|
for name, info := range header {
|
|
mw.tensorFiles[name] = path
|
|
mw.tensorInfo[name] = info
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(mw.tensorFiles) == 0 {
|
|
return nil, fmt.Errorf("no safetensor files found in %s", dir)
|
|
}
|
|
|
|
return mw, nil
|
|
}
|
|
|
|
// Load loads all tensors into cache with the specified dtype.
|
|
// If dtype is 0, tensors are loaded in their original dtype.
|
|
// Automatically uses streaming (memory-efficient) when dtype conversion is needed,
|
|
// or native loading when tensors are already in the target dtype.
|
|
func (mw *ModelWeights) Load(dtype mlx.Dtype) error {
|
|
if dtype == 0 {
|
|
return mw.loadNative()
|
|
}
|
|
|
|
// Check if any tensor needs conversion
|
|
needsConversion := false
|
|
for name := range mw.tensorFiles {
|
|
info := mw.tensorInfo[name]
|
|
if dtypeFromString(info.Dtype) != dtype {
|
|
needsConversion = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if needsConversion {
|
|
return mw.loadStreaming(dtype)
|
|
}
|
|
return mw.loadNative()
|
|
}
|
|
|
|
// loadNative loads all tensors using the native memory-mapped loader.
|
|
func (mw *ModelWeights) loadNative() error {
|
|
mw.cache = make(map[string]*mlx.Array)
|
|
|
|
fileToTensors := make(map[string][]string)
|
|
for name, path := range mw.tensorFiles {
|
|
fileToTensors[path] = append(fileToTensors[path], name)
|
|
}
|
|
|
|
for path, names := range fileToTensors {
|
|
native, err := mlx.LoadSafetensorsNative(path)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to load %s: %w", path, err)
|
|
}
|
|
|
|
for _, name := range names {
|
|
arr := native.Get(name)
|
|
if arr == nil {
|
|
native.Free()
|
|
return fmt.Errorf("tensor %q not found in %s", name, path)
|
|
}
|
|
mw.cache[name] = arr
|
|
}
|
|
|
|
mw.nativeCache[path] = native
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// loadStreaming loads tensors with dtype conversion.
|
|
// Uses the same pattern as Python: replace each entry in the map after conversion,
|
|
// so the original tensor loses its reference and can be freed.
|
|
func (mw *ModelWeights) loadStreaming(dtype mlx.Dtype) error {
|
|
mw.cache = make(map[string]*mlx.Array)
|
|
|
|
fileToTensors := make(map[string][]string)
|
|
for name, path := range mw.tensorFiles {
|
|
fileToTensors[path] = append(fileToTensors[path], name)
|
|
}
|
|
|
|
for path, names := range fileToTensors {
|
|
native, err := mlx.LoadSafetensorsNative(path)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to load %s: %w", path, err)
|
|
}
|
|
|
|
for _, name := range names {
|
|
src := native.Get(name)
|
|
if src == nil {
|
|
native.Free()
|
|
return fmt.Errorf("tensor %q not found in %s", name, path)
|
|
}
|
|
|
|
dst := mlx.AsType(src, dtype)
|
|
mlx.Eval(dst)
|
|
native.Set(name, dst)
|
|
mw.cache[name] = dst
|
|
}
|
|
|
|
native.Free()
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Get returns a tensor from cache. Call Load() first.
|
|
func (mw *ModelWeights) Get(name string) (*mlx.Array, error) {
|
|
if mw.cache == nil {
|
|
return nil, fmt.Errorf("cache not initialized: call Load() first")
|
|
}
|
|
arr, ok := mw.cache[name]
|
|
if !ok {
|
|
return nil, fmt.Errorf("tensor %q not found in cache", name)
|
|
}
|
|
return arr, nil
|
|
}
|
|
|
|
// GetTensor loads a tensor using the native loader without caching.
|
|
// For bulk loading, use Load() + Get() instead.
|
|
func (mw *ModelWeights) GetTensor(name string) (*mlx.Array, error) {
|
|
if mw.cache != nil {
|
|
if arr, ok := mw.cache[name]; ok {
|
|
return arr, nil
|
|
}
|
|
}
|
|
|
|
path, ok := mw.tensorFiles[name]
|
|
if !ok {
|
|
return nil, fmt.Errorf("tensor %q not found", name)
|
|
}
|
|
|
|
native, ok := mw.nativeCache[path]
|
|
if !ok {
|
|
var err error
|
|
native, err = mlx.LoadSafetensorsNative(path)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to load %s: %w", path, err)
|
|
}
|
|
mw.nativeCache[path] = native
|
|
}
|
|
|
|
return native.Get(name), nil
|
|
}
|
|
|
|
// GetTensorInfo returns metadata about a tensor without loading it.
|
|
func (mw *ModelWeights) GetTensorInfo(name string) (TensorInfo, bool) {
|
|
info, ok := mw.tensorInfo[name]
|
|
return info, ok
|
|
}
|
|
|
|
// ListTensors returns all tensor names.
|
|
func (mw *ModelWeights) ListTensors() []string {
|
|
names := make([]string, 0, len(mw.tensorFiles))
|
|
for name := range mw.tensorFiles {
|
|
names = append(names, name)
|
|
}
|
|
sort.Strings(names)
|
|
return names
|
|
}
|
|
|
|
// HasTensor checks if a tensor exists.
|
|
func (mw *ModelWeights) HasTensor(name string) bool {
|
|
_, ok := mw.tensorFiles[name]
|
|
return ok
|
|
}
|
|
|
|
// ReleaseAll releases all cached native file handles.
|
|
func (mw *ModelWeights) ReleaseAll() {
|
|
for path, native := range mw.nativeCache {
|
|
native.Free()
|
|
delete(mw.nativeCache, path)
|
|
}
|
|
}
|
|
|