mirror of
https://github.com/ollama/ollama.git
synced 2026-01-12 00:06:57 +08:00
117 lines
3.3 KiB
Go
117 lines
3.3 KiB
Go
//go:build mlx
|
|
|
|
package imagegen
|
|
|
|
import (
|
|
"fmt"
|
|
"sort"
|
|
"strings"
|
|
|
|
"github.com/ollama/ollama/x/imagegen/mlx"
|
|
)
|
|
|
|
// ManifestWeights provides fast weight loading from tensor blobs.
|
|
// Uses native mmap loading with synthetic safetensors headers for zero-copy.
|
|
type ManifestWeights struct {
|
|
manifest *ModelManifest
|
|
component string
|
|
tensors map[string]ManifestLayer // name -> layer
|
|
cache map[string]*mlx.Array // name -> loaded array
|
|
nativeCache []*mlx.SafetensorsFile // keep native handles alive
|
|
}
|
|
|
|
// LoadWeightsFromManifest creates a weight loader for a component from manifest storage.
|
|
func LoadWeightsFromManifest(manifest *ModelManifest, component string) (*ManifestWeights, error) {
|
|
layers := manifest.GetTensorLayers(component)
|
|
if len(layers) == 0 {
|
|
return nil, fmt.Errorf("no tensor layers found for component %q", component)
|
|
}
|
|
|
|
// Strip component prefix from tensor names for model loading
|
|
// e.g., "text_encoder/model.embed_tokens.weight" -> "model.embed_tokens.weight"
|
|
prefix := component + "/"
|
|
tensors := make(map[string]ManifestLayer, len(layers))
|
|
for _, layer := range layers {
|
|
tensorName := strings.TrimPrefix(layer.Name, prefix)
|
|
tensors[tensorName] = layer
|
|
}
|
|
|
|
return &ManifestWeights{
|
|
manifest: manifest,
|
|
component: component,
|
|
tensors: tensors,
|
|
cache: make(map[string]*mlx.Array),
|
|
}, nil
|
|
}
|
|
|
|
// Load loads all tensor blobs using native mmap (zero-copy).
|
|
// Blobs are stored in safetensors format for native mlx_load_safetensors mmap.
|
|
// If dtype is non-zero, tensors are converted to the specified dtype.
|
|
func (mw *ManifestWeights) Load(dtype mlx.Dtype) error {
|
|
for name, layer := range mw.tensors {
|
|
path := mw.manifest.BlobPath(layer.Digest)
|
|
|
|
// Load blob as safetensors (native mmap, zero-copy)
|
|
sf, err := mlx.LoadSafetensorsNative(path)
|
|
if err != nil {
|
|
return fmt.Errorf("load %s: %w", name, err)
|
|
}
|
|
|
|
// Blob contains single tensor named "data"
|
|
arr := sf.Get("data")
|
|
if arr == nil {
|
|
sf.Free()
|
|
return fmt.Errorf("tensor 'data' not found in blob for %s", name)
|
|
}
|
|
|
|
// Convert dtype if needed
|
|
if dtype != 0 && arr.Dtype() != dtype {
|
|
arr = mlx.AsType(arr, dtype)
|
|
}
|
|
// ALWAYS make a contiguous copy to ensure independence from mmap
|
|
arr = mlx.Contiguous(arr)
|
|
mlx.Eval(arr)
|
|
mw.cache[name] = arr
|
|
sf.Free() // Safe to free - arr is now an independent copy
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetTensor returns a tensor from cache. Call Load() first.
|
|
func (mw *ManifestWeights) GetTensor(name string) (*mlx.Array, error) {
|
|
if mw.cache == nil {
|
|
return nil, fmt.Errorf("cache not initialized: call Load() first")
|
|
}
|
|
arr, ok := mw.cache[name]
|
|
if !ok {
|
|
return nil, fmt.Errorf("tensor %q not found", name)
|
|
}
|
|
return arr, nil
|
|
}
|
|
|
|
// ListTensors returns all tensor names in sorted order.
|
|
func (mw *ManifestWeights) ListTensors() []string {
|
|
names := make([]string, 0, len(mw.tensors))
|
|
for name := range mw.tensors {
|
|
names = append(names, name)
|
|
}
|
|
sort.Strings(names)
|
|
return names
|
|
}
|
|
|
|
// HasTensor checks if a tensor exists.
|
|
func (mw *ManifestWeights) HasTensor(name string) bool {
|
|
_, ok := mw.tensors[name]
|
|
return ok
|
|
}
|
|
|
|
// ReleaseAll frees all native handles and clears the tensor cache.
|
|
func (mw *ManifestWeights) ReleaseAll() {
|
|
for _, sf := range mw.nativeCache {
|
|
sf.Free()
|
|
}
|
|
mw.nativeCache = nil
|
|
mw.cache = nil
|
|
}
|