//go:build mlx

package main

import (
	"context"
	"encoding/json"
	"flag"
	"fmt"
	"log"
	"os"
	"path/filepath"
	"runtime/pprof"

	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/models/gemma3"
	"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
	"github.com/ollama/ollama/x/imagegen/models/llama"
	"github.com/ollama/ollama/x/imagegen/models/qwen_image"
	"github.com/ollama/ollama/x/imagegen/models/qwen_image_edit"
	"github.com/ollama/ollama/x/imagegen/models/zimage"
	"github.com/ollama/ollama/x/imagegen/safetensors"
)

// stringSlice is a flag type that accumulates multiple values.
type stringSlice []string

func (s *stringSlice) String() string {
	return fmt.Sprintf("%v", *s)
}

func (s *stringSlice) Set(value string) error {
	*s = append(*s, value)
	return nil
}
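
// Example invocations (illustrative; the binary name depends on how this
// package is built, e.g. with "go build -tags mlx ."):
//
//	runner -model /path/to/gemma3 -prompt "Hello" -max-tokens 200
//	runner -zimage -model /path/to/z-image -prompt "a red fox" -width 768 -height 768
//	runner -qwen-image-edit -model /path/to/qwen-image-edit -input-image in.png -prompt "add a hat"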
func main() {
	modelPath := flag.String("model", "", "Model directory")
	prompt := flag.String("prompt", "Hello", "Prompt")

	// Text generation params
	maxTokens := flag.Int("max-tokens", 100, "Max tokens")
	temperature := flag.Float64("temperature", 0.7, "Temperature")
	topP := flag.Float64("top-p", 0.9, "Top-p sampling")
	topK := flag.Int("top-k", 40, "Top-k sampling")
	imagePath := flag.String("image", "", "Image path for multimodal models")

	// Image generation params
	width := flag.Int("width", 1024, "Image width")
	height := flag.Int("height", 1024, "Image height")
	steps := flag.Int("steps", 9, "Denoising steps")
	seed := flag.Int64("seed", 42, "Random seed")
	out := flag.String("output", "output.png", "Output path")

	// Utility flags
	listTensors := flag.Bool("list", false, "List tensors only")
	cpuProfile := flag.String("cpuprofile", "", "Write CPU profile to file")
	gpuCapture := flag.String("gpu-capture", "", "Capture GPU trace to .gputrace file (run with MTL_CAPTURE_ENABLED=1)")
	layerCache := flag.Bool("layer-cache", false, "Enable layer caching for faster diffusion (Z-Image, Qwen-Image). Not compatible with CFG/negative prompts.")
	wiredLimitGB := flag.Int("wired-limit", 32, "Metal wired memory limit in GB")

	// Legacy mode flags
	zimageFlag := flag.Bool("zimage", false, "Z-Image generation")
	qwenImage := flag.Bool("qwen-image", false, "Qwen-Image text-to-image generation")
	qwenImageEdit := flag.Bool("qwen-image-edit", false, "Qwen-Image-Edit image editing")
	var inputImages stringSlice
	flag.Var(&inputImages, "input-image", "Input image for image editing (can be specified multiple times)")
	negativePrompt := flag.String("negative-prompt", "", "Negative prompt for CFG (empty = no CFG, matching Python)")
	cfgScale := flag.Float64("cfg-scale", 4.0, "CFG scale for image editing")

	flag.Parse()

	if *modelPath == "" {
		flag.Usage()
		return
	}

	// CPU profiling
	if *cpuProfile != "" {
		f, err := os.Create(*cpuProfile)
		if err != nil {
			log.Fatal(err)
		}
		defer f.Close()
		if err := pprof.StartCPUProfile(f); err != nil {
			log.Fatal(err)
		}
		defer pprof.StopCPUProfile()
	}

	var err error

	// Handle legacy mode flags that aren't unified yet
	switch {
	case *zimageFlag:
		m := &zimage.Model{}
		if loadErr := m.Load(*modelPath); loadErr != nil {
			log.Fatal(loadErr)
		}
		var img *mlx.Array
		img, err = m.GenerateFromConfig(&zimage.GenerateConfig{
			Prompt:      *prompt,
			Width:       int32(*width),
			Height:      int32(*height),
			Steps:       *steps,
			Seed:        *seed,
			CapturePath: *gpuCapture,
			LayerCache:  *layerCache,
		})
		if err == nil {
			err = saveImageArray(img, *out)
		}
	case *qwenImage:
		m, loadErr := qwen_image.LoadPersistent(*modelPath)
		if loadErr != nil {
			log.Fatal(loadErr)
		}
		var img *mlx.Array
		img, err = m.GenerateFromConfig(&qwen_image.GenerateConfig{
			Prompt:         *prompt,
			NegativePrompt: *negativePrompt,
			CFGScale:       float32(*cfgScale),
			Width:          int32(*width),
			Height:         int32(*height),
			Steps:          *steps,
			Seed:           *seed,
			LayerCache:     *layerCache,
		})
		if err == nil {
			err = saveImageArray(img, *out)
		}
	case *qwenImageEdit:
		if len(inputImages) == 0 {
			log.Fatal("qwen-image-edit requires at least one -input-image")
		}

		m, loadErr := qwen_image_edit.LoadPersistent(*modelPath)
		if loadErr != nil {
			log.Fatal(loadErr)
		}
		// For image editing, use 0 for dimensions to auto-detect from the
		// input image unless explicitly overridden from the defaults.
		editWidth := int32(0)
		editHeight := int32(0)
		if *width != 1024 {
			editWidth = int32(*width)
		}
		if *height != 1024 {
			editHeight = int32(*height)
		}

		cfg := &qwen_image_edit.GenerateConfig{
			Prompt:         *prompt,
			NegativePrompt: *negativePrompt,
			CFGScale:       float32(*cfgScale),
			Width:          editWidth,
			Height:         editHeight,
			Steps:          *steps,
			Seed:           *seed,
		}

		var img *mlx.Array
		img, err = m.EditFromConfig(inputImages, cfg)
		if err == nil {
			err = saveImageArray(img, *out)
		}
	case *listTensors:
		err = listModelTensors(*modelPath)
	default:
		// llm path; use loadErr so the outer err (checked after the
		// switch) is not shadowed
		m, loadErr := load(*modelPath)
		if loadErr != nil {
			log.Fatal(loadErr)
		}

		// Load image if provided and model supports it
		var image *mlx.Array
		if *imagePath != "" {
			if mm, ok := m.(interface{ ImageSize() int32 }); ok {
				image, err = gemma3.ProcessImage(*imagePath, mm.ImageSize())
				if err != nil {
					log.Fatal("load image:", err)
				}
			} else {
				log.Fatal("model does not support image input")
			}
		}

		err = generate(context.Background(), m, input{
			Prompt:       *prompt,
			Image:        image,
			MaxTokens:    *maxTokens,
			Temperature:  float32(*temperature),
			TopP:         float32(*topP),
			TopK:         *topK,
			WiredLimitGB: *wiredLimitGB,
		}, func(out output) {
			if out.Text != "" {
				fmt.Print(out.Text)
			}
			if out.Done {
				fmt.Printf("\n\n[prefill: %.1f tok/s, gen: %.1f tok/s]\n", out.PrefillTokSec, out.GenTokSec)
			}
		})
	}

	if err != nil {
		log.Fatal(err)
	}
}
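
// listModelTensors prints the name, shape, and dtype of every tensor in the
// model directory's safetensors weights.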
func listModelTensors(modelPath string) error {
	weights, err := safetensors.LoadModelWeights(modelPath)
	if err != nil {
		return err
	}
	for _, name := range weights.ListTensors() {
		info, _ := weights.GetTensorInfo(name)
		fmt.Printf("%s: %v (%s)\n", name, info.Shape, info.Dtype)
	}
	return nil
}

// loadModel builds and evaluates a model using the common load pattern.
// Release safetensors BEFORE eval - lazy arrays have captured their data,
// and this reduces peak memory by ~6GB (matches mlx-lm behavior).
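//
// Illustrative call shape (the builder and cleanup shown here are
// hypothetical; real call sites live in the individual model packages):
//
//	m := loadModel(func() *gemma3.Model { return newModel(weights) }, weights.Release)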
func loadModel[T Model](build func() T, cleanup func()) T {
	m := build()
	weights := mlx.Collect(m)
	cleanup()
	mlx.Eval(weights...)
	return m
}
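
// load opens a text-generation model, choosing a loader based on the
// detected model kind and falling back to the llama loader for unknown
// model types.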
func load(modelPath string) (Model, error) {
	kind, err := detectModelKind(modelPath)
	if err != nil {
		return nil, fmt.Errorf("detect model kind: %w", err)
	}

	switch kind {
	case "gpt_oss":
		return gpt_oss.Load(modelPath)
	case "gemma3":
		return gemma3.Load(modelPath)
	case "gemma3_text":
		return gemma3.LoadText(modelPath)
	default:
		return llama.Load(modelPath)
	}
}
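
// detectModelKind inspects the model directory layout: diffusion pipelines
// ship a model_index.json whose "_class_name" identifies the pipeline, while
// text models ship a config.json whose "model_type" names the architecture.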
func detectModelKind(modelPath string) (string, error) {
	indexPath := filepath.Join(modelPath, "model_index.json")
	if _, err := os.Stat(indexPath); err == nil {
		data, err := os.ReadFile(indexPath)
		if err != nil {
			return "zimage", nil
		}
		var index struct {
			ClassName string `json:"_class_name"`
		}
		if err := json.Unmarshal(data, &index); err == nil {
			switch index.ClassName {
			case "FluxPipeline", "ZImagePipeline":
				return "zimage", nil
			}
		}
		// Unknown or unreadable pipeline indexes currently fall back to
		// the zimage loader as well.
		return "zimage", nil
	}

	configPath := filepath.Join(modelPath, "config.json")
	data, err := os.ReadFile(configPath)
	if err != nil {
		return "", fmt.Errorf("no config.json or model_index.json found: %w", err)
	}

	var cfg struct {
		ModelType string `json:"model_type"`
	}
	if err := json.Unmarshal(data, &cfg); err != nil {
		return "", fmt.Errorf("parse config.json: %w", err)
	}

	return cfg.ModelType, nil
}