Add z-image image generation prototype (#13659)

Author: Jeffrey Morgan
Date: 2026-01-09 21:09:46 -08:00
Committed by: GitHub
Parent: c6d4c0c7f2
Commit: 2584940016

44 changed files with 6422 additions and 269 deletions


@@ -161,10 +161,6 @@ ARG GOFLAGS="'-ldflags=-w -s'"
ENV CGO_ENABLED=1
ARG CGO_CFLAGS
ARG CGO_CXXFLAGS
# TODO wire up the actual MLX engine here instead of building the main binary...
RUN mkdir -p dist/bin
RUN go build -tags mlx -trimpath -buildmode=pie -o dist/bin/imagegen ./x/imagegen/cmd/engine
FROM base AS build
WORKDIR /go/src/github.com/ollama/ollama


@@ -46,6 +46,8 @@ import (
"github.com/ollama/ollama/types/syncmap"
"github.com/ollama/ollama/version"
xcmd "github.com/ollama/ollama/x/cmd"
"github.com/ollama/ollama/x/imagegen"
imagegenclient "github.com/ollama/ollama/x/imagegen/client"
)
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
@@ -96,6 +98,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
filename, err := getModelfileName(cmd)
if os.IsNotExist(err) {
if filename == "" {
// No Modelfile found - check if current directory is an image gen model
if imagegen.IsTensorModelDir(".") {
return imagegenclient.CreateModel(args[0], ".", p)
}
reader = strings.NewReader("FROM .\n")
} else {
return errModelfileNotFound
@@ -457,6 +463,15 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
name := args[0]
// Check if this is a known image generation model (skip Show/Pull)
if imagegen.HasTensorLayers(name) {
if opts.Prompt == "" && !interactive {
return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
}
return imagegen.RunCLI(cmd, name, opts.Prompt, interactive, opts.KeepAlive)
}
info, err := func() (*api.ShowResponse, error) {
showReq := &api.ShowRequest{Name: name}
info, err := client.Show(cmd.Context(), showReq)
@@ -822,6 +837,11 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
}
func ShowHandler(cmd *cobra.Command, args []string) error {
// Check if this is an image generation model
if imagegen.HasTensorLayers(args[0]) {
return imagegen.Show(args[0], os.Stdout)
}
client, err := api.ClientFromEnvironment()
if err != nil {
return err
@@ -1767,6 +1787,9 @@ func NewCLI() *cobra.Command {
runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools")
runCmd.Flags().Bool("experimental-yolo", false, "Skip all tool approval prompts (use with caution)")
// Image generation flags (width, height, steps, seed, etc.)
imagegen.RegisterFlags(runCmd)
stopCmd := &cobra.Command{
Use: "stop MODEL",
Short: "Stop a running model",

progress/stepbar.go (new file, 33 lines)

@@ -0,0 +1,33 @@
package progress
import (
"fmt"
"strings"
)
// StepBar displays step-based progress (e.g., for image generation steps).
type StepBar struct {
message string
current int
total int
}
func NewStepBar(message string, total int) *StepBar {
return &StepBar{message: message, total: total}
}
func (s *StepBar) Set(current int) {
s.current = current
}
func (s *StepBar) String() string {
percent := float64(s.current) / float64(s.total) * 100
barWidth := s.total
empty := barWidth - s.current
// "Generating 0% ▕ ▏ 0/9"
return fmt.Sprintf("%s %3.0f%% ▕%s%s▏ %d/%d",
s.message, percent,
strings.Repeat("█", s.current), strings.Repeat(" ", empty),
s.current, s.total)
}
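For reference, a minimal standalone sketch of how `StepBar` renders (in this commit it is driven through `progress.Progress`, as shown later in cli.go; the expected output below is derived from the format string above):

```go
package main

import (
	"fmt"

	"github.com/ollama/ollama/progress"
)

func main() {
	// A 9-step bar at step 3. The bar body is total cells wide,
	// one "█" per completed step.
	bar := progress.NewStepBar("Generating", 9)
	bar.Set(3)
	fmt.Println(bar.String())
	// Generating  33% ▕███      ▏ 3/9
}
```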


@@ -3,6 +3,7 @@ package runner
import (
"github.com/ollama/ollama/runner/llamarunner"
"github.com/ollama/ollama/runner/ollamarunner"
imagerunner "github.com/ollama/ollama/x/imagegen/runner"
)
func Execute(args []string) error {
@@ -11,12 +12,19 @@ func Execute(args []string) error {
}
var newRunner bool
var imageRunner bool
if len(args) > 0 && args[0] == "--ollama-engine" {
args = args[1:]
newRunner = true
}
if len(args) > 0 && args[0] == "--image-engine" {
args = args[1:]
imageRunner = true
}
if imageRunner {
return imagerunner.Execute(args)
} else if newRunner {
return ollamarunner.Execute(args)
} else {
return llamarunner.Execute(args)


@@ -30,6 +30,7 @@ import (
"github.com/ollama/ollama/thinking"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
"github.com/ollama/ollama/x/imagegen/transfer"
)
var (
@@ -73,6 +74,11 @@ type Model struct {
func (m *Model) Capabilities() []model.Capability {
capabilities := []model.Capability{}
// Check for image generation model via config capabilities
if slices.Contains(m.Config.Capabilities, "image") {
return []model.Capability{model.CapabilityImageGeneration}
}
// Check for completion capability
if m.ModelPath != "" {
f, err := gguf.Open(m.ModelPath)
@@ -555,6 +561,24 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
layers = append(layers, manifest.Config)
}
// Use fast transfer for models with tensor layers (many small blobs)
if hasTensorLayers(layers) {
// Read raw manifest JSON to preserve tensor metadata fields
manifestPath, err := mp.GetManifestPath()
if err != nil {
return err
}
manifestJSON, err := os.ReadFile(manifestPath)
if err != nil {
return err
}
if err := pushWithTransfer(ctx, mp, layers, manifestJSON, regOpts, fn); err != nil {
return err
}
fn(api.ProgressResponse{Status: "success"})
return nil
}
for _, layer := range layers {
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
slog.Info(fmt.Sprintf("error uploading blob: %v", err))
@@ -620,6 +644,15 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
layers = append(layers, manifest.Config)
}
// Use fast transfer for models with tensor layers (many small blobs)
if hasTensorLayers(layers) {
if err := pullWithTransfer(ctx, mp, layers, manifest, regOpts, fn); err != nil {
return err
}
fn(api.ProgressResponse{Status: "success"})
return nil
}
skipVerify := make(map[string]bool)
for _, layer := range layers {
cacheHit, err := downloadBlob(ctx, downloadOpts{
@@ -634,7 +667,6 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
skipVerify[layer.Digest] = cacheHit
delete(deleteMap, layer.Digest)
}
delete(deleteMap, manifest.Config.Digest)
fn(api.ProgressResponse{Status: "verifying sha256 digest"})
for _, layer := range layers {
@@ -643,13 +675,11 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
if err := verifyBlob(layer.Digest); err != nil {
if errors.Is(err, errDigestMismatch) {
// something went wrong, delete the blob
fp, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
if err := os.Remove(fp); err != nil {
// log this, but return the original error
slog.Info(fmt.Sprintf("couldn't remove file with digest mismatch '%s': %v", fp, err))
}
}
@@ -657,6 +687,11 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
}
for _, layer := range layers {
delete(deleteMap, layer.Digest)
}
delete(deleteMap, manifest.Config.Digest)
fn(api.ProgressResponse{Status: "writing manifest"})
manifestJSON, err := json.Marshal(manifest)
@@ -690,6 +725,148 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
return nil
}
// hasTensorLayers checks if any layer has tensor media type.
func hasTensorLayers(layers []Layer) bool {
for _, layer := range layers {
if layer.MediaType == MediaTypeImageTensor {
return true
}
}
return false
}
// pullWithTransfer uses the simplified x/transfer package for downloading blobs.
func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifest *Manifest, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
blobs := make([]transfer.Blob, len(layers))
for i, layer := range layers {
blobs[i] = transfer.Blob{
Digest: layer.Digest,
Size: layer.Size,
}
}
destDir, err := GetBlobsPath("")
if err != nil {
return err
}
base := mp.BaseURL()
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
base.Scheme = "http"
}
baseURL := base.String()
var totalSize int64
for _, blob := range blobs {
totalSize += blob.Size
}
progress := func(completed, total int64) {
fn(api.ProgressResponse{
Status: "pulling model",
Digest: "sha256:model",
Total: total,
Completed: completed,
})
}
getToken := func(ctx context.Context, challenge transfer.AuthChallenge) (string, error) {
return getAuthorizationToken(ctx, registryChallenge{
Realm: challenge.Realm,
Service: challenge.Service,
Scope: challenge.Scope,
})
}
if err := transfer.Download(ctx, transfer.DownloadOptions{
Blobs: blobs,
BaseURL: baseURL,
DestDir: destDir,
Repository: mp.GetNamespaceRepository(),
Progress: progress,
Token: regOpts.Token,
GetToken: getToken,
Logger: slog.Default(),
}); err != nil {
return err
}
// Write manifest
fn(api.ProgressResponse{Status: "writing manifest"})
manifestJSON, err := json.Marshal(manifest)
if err != nil {
return err
}
fp, err := mp.GetManifestPath()
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil {
return err
}
return os.WriteFile(fp, manifestJSON, 0o644)
}
// pushWithTransfer uses the simplified x/transfer package for uploading blobs and manifest.
func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifestJSON []byte, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
blobs := make([]transfer.Blob, len(layers))
for i, layer := range layers {
blobs[i] = transfer.Blob{
Digest: layer.Digest,
Size: layer.Size,
From: layer.From,
}
}
srcDir, err := GetBlobsPath("")
if err != nil {
return err
}
base := mp.BaseURL()
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
base.Scheme = "http"
}
baseURL := base.String()
var totalSize int64
for _, blob := range blobs {
totalSize += blob.Size
}
progress := func(completed, total int64) {
fn(api.ProgressResponse{
Status: "pushing model",
Digest: "sha256:model",
Total: total,
Completed: completed,
})
}
getToken := func(ctx context.Context, challenge transfer.AuthChallenge) (string, error) {
return getAuthorizationToken(ctx, registryChallenge{
Realm: challenge.Realm,
Service: challenge.Service,
Scope: challenge.Scope,
})
}
return transfer.Upload(ctx, transfer.UploadOptions{
Blobs: blobs,
BaseURL: baseURL,
SrcDir: srcDir,
Progress: progress,
Token: regOpts.Token,
GetToken: getToken,
Logger: slog.Default(),
Manifest: manifestJSON,
ManifestRef: mp.Tag,
Repository: mp.GetNamespaceRepository(),
})
}
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) {
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)


@@ -47,6 +47,15 @@ func TestModelCapabilities(t *testing.T) {
model Model
expectedCaps []model.Capability
}{
{
name: "model with image generation capability via config",
model: Model{
Config: model.ConfigV2{
Capabilities: []string{"image"},
},
},
expectedCaps: []model.Capability{model.CapabilityImageGeneration},
},
{
name: "model with completion capability",
model: Model{


@@ -13,9 +13,14 @@ type Layer struct {
Digest string `json:"digest"`
Size int64 `json:"size"`
From string `json:"from,omitempty"`
Name string `json:"name,omitempty"` // tensor name, e.g., "text_encoder/model.embed_tokens.weight"
status string
}
const (
MediaTypeImageTensor = "application/vnd.ollama.image.tensor"
)
func NewLayer(r io.Reader, mediatype string) (Layer, error) {
blobs, err := GetBlobsPath("")
if err != nil {


@@ -50,6 +50,8 @@ import (
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
"github.com/ollama/ollama/x/imagegen"
imagegenapi "github.com/ollama/ollama/x/imagegen/api"
)
const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
@@ -162,6 +164,29 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
return runner.llama, model, &opts, nil
}
// ScheduleImageGenRunner schedules an image generation model runner.
// This implements the imagegenapi.RunnerScheduler interface.
func (s *Server) ScheduleImageGenRunner(c *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error) {
m := &Model{
Name: modelName,
ShortName: modelName,
ModelPath: modelName, // For image gen, ModelPath is just the model name
Config: model.ConfigV2{
Capabilities: []string{"image"},
},
}
runnerCh, errCh := s.sched.GetRunner(c.Request.Context(), m, opts, keepAlive)
var runner *runnerRef
select {
case runner = <-runnerCh:
case err := <-errCh:
return nil, err
}
return runner.llama, nil
}
func signinURL() (string, error) {
pubKey, err := auth.GetPublicKey()
if err != nil {
@@ -189,6 +214,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
// Check if this is a known image generation model
if imagegen.ResolveModelName(req.Model) != "" {
imagegenapi.HandleGenerateRequest(c, s, req.Model, req.Prompt, req.KeepAlive, streamResponse)
return
}
name := model.ParseName(req.Model)
if !name.IsValid() {
// Ideally this is "invalid model name" but we're keeping with
@@ -1547,6 +1578,9 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
// Inference (Anthropic compatibility)
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
// Experimental image generation support
imagegenapi.RegisterRoutes(r, s)
if rc != nil {
// wrap old with new
rs := &registry.Local{


@@ -21,6 +21,7 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/imagegen"
)
type LlmRequest struct {
@@ -194,6 +195,14 @@ func (s *Scheduler) processPending(ctx context.Context) {
slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "gpu_count", len(gpus))
}
// Check for image generation model before attempting GGML load
if slices.Contains(pending.model.Config.Capabilities, "image") {
if s.loadImageGen(pending) {
break
}
continue
}
// Load model for fitting
logutil.Trace("loading model metadata", "model", pending.model.ModelPath)
ggml, err := llm.LoadModel(pending.model.ModelPath, 1024)
@@ -543,6 +552,48 @@ iGPUScan:
return false
}
// loadImageGen loads an image generation model.
func (s *Scheduler) loadImageGen(req *LlmRequest) bool {
// Use model name for imagegen (it resolves manifests by name, not file path)
modelName := req.model.ShortName
server, err := imagegen.NewServer(modelName)
if err != nil {
req.errCh <- err
return true
}
sessionDuration := envconfig.KeepAlive()
if req.sessionDuration != nil {
sessionDuration = req.sessionDuration.Duration
}
runner := &runnerRef{
model: req.model,
modelPath: req.model.ModelPath,
llama: server,
Options: &req.opts,
loading: false,
sessionDuration: sessionDuration,
refCount: 1,
}
s.loadedMu.Lock()
s.loaded[req.model.ModelPath] = runner
s.loadedMu.Unlock()
// Set up expiration timer
runner.refMu.Lock()
if sessionDuration > 0 {
runner.expireTimer = time.AfterFunc(sessionDuration, func() {
s.expiredCh <- runner
})
}
runner.refMu.Unlock()
req.useLoadedRunner(runner, s.finishedReqCh)
return true
}
func (s *Scheduler) updateFreeSpace(allGpus []ml.DeviceInfo) {
if len(allGpus) == 0 {
return


@@ -6,6 +6,7 @@ import (
"errors"
"log/slog"
"os"
"slices"
"testing"
"time"
@@ -16,6 +17,7 @@ import (
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
)
func TestMain(m *testing.M) {
@@ -804,3 +806,61 @@ func (s *mockLlm) GetPort() int { return -
func (s *mockLlm) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return nil }
func (s *mockLlm) HasExited() bool { return false }
func (s *mockLlm) GetActiveDeviceIDs() []ml.DeviceID { return nil }
// TestImageGenCapabilityDetection verifies that models with "image" capability
// are correctly identified and routed differently from language models.
func TestImageGenCapabilityDetection(t *testing.T) {
// Model with image capability should be detected
imageModel := &Model{
Config: model.ConfigV2{
Capabilities: []string{"image"},
},
}
require.True(t, slices.Contains(imageModel.Config.Capabilities, "image"))
// Model without image capability should not be detected
langModel := &Model{
Config: model.ConfigV2{
Capabilities: []string{"completion"},
},
}
require.False(t, slices.Contains(langModel.Config.Capabilities, "image"))
// Empty capabilities should not match
emptyModel := &Model{}
require.False(t, slices.Contains(emptyModel.Config.Capabilities, "image"))
}
// TestImageGenRunnerCanBeEvicted verifies that an image generation model
// loaded in the scheduler can be evicted by a language model request.
func TestImageGenRunnerCanBeEvicted(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
defer done()
s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
s.getSystemInfoFn = getSystemInfoFn
// Simulate an image gen runner already loaded
imageGenRunner := &runnerRef{
model: &Model{Name: "z-image", ModelPath: "/fake/image/model"},
modelPath: "/fake/image/model",
llama: &mockLlm{vramSize: 21 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{}},
sessionDuration: 5 * time.Millisecond,
refCount: 0, // idle
}
s.loadedMu.Lock()
s.loaded["/fake/image/model"] = imageGenRunner
s.loadedMu.Unlock()
// Verify the image gen runner is loaded
s.loadedMu.Lock()
require.Len(t, s.loaded, 1)
s.loadedMu.Unlock()
// findRunnerToUnload should find the idle image gen runner
runner := s.findRunnerToUnload()
require.NotNil(t, runner)
require.Equal(t, "/fake/image/model", runner.modelPath)
}


@@ -3,12 +3,13 @@ package model
type Capability string
const (
CapabilityCompletion = Capability("completion")
CapabilityTools = Capability("tools")
CapabilityInsert = Capability("insert")
CapabilityVision = Capability("vision")
CapabilityEmbedding = Capability("embedding")
CapabilityThinking = Capability("thinking")
CapabilityImageGeneration = Capability("image")
)
func (c Capability) String() string {


@@ -1,61 +1,236 @@
# imagegen

This is a package that uses MLX to run image generation models, ahead of being integrated into Ollama's primary runner.
in `CMakeLists.txt` and rebuild.

### 1. Download a Model

Download Llama 3.1 8B (or any compatible model) in safetensors format:

```bash
mkdir -p ./weights
# Example using huggingface-cli
hf download meta-llama/Llama-3.1-8B --local-dir ./weights/Llama-3.1-8B
hf download openai/gpt-oss-20b --local-dir ./weights/gpt-oss-20b
```

### 2. Run Inference

```bash
# Build
go build ./cmd/engine

# Text generation
./engine -model ./weights/Llama-3.1-8B -prompt "Hello, world!" -max-tokens 250

# Qwen-Image 2512 (text-to-image)
./engine -qwen-image -model ./weights/Qwen-Image-2512 -prompt "A mountain landscape at sunset" \
  -width 1024 -height 1024 -steps 20 -seed 42 -output landscape.png

# Qwen-Image Edit (experimental) - 8 steps for speed, but model recommends 50
./engine -qwen-image-edit -model ./weights/Qwen-Image-Edit-2511 \
  -input-image input.png -prompt "Make it winter" -negative-prompt " " -cfg-scale 4.0 \
  -steps 8 -seed 42 -output edited.png
```

## Memory Management

MLX Python/C++ uses scope-based memory management - arrays are freed when they go out of scope. Go's garbage collector is non-deterministic, so we can't rely on finalizers to free GPU memory promptly.

Instead, arrays are automatically tracked and freed on `Eval()`:

```go
// All arrays are automatically tracked when created
x := mlx.Add(a, b)
y := mlx.Matmul(x, w)

// Eval frees non-kept arrays, evaluates outputs (auto-kept)
mlx.Eval(y)

// After copying to CPU, free the array
data := y.Data()
y.Free()
```

Key points:

- All created arrays are automatically tracked
- `mlx.Eval(outputs...)` frees non-kept arrays, evaluates outputs (outputs auto-kept)
- `mlx.Keep(arrays...)` marks arrays to survive multiple Eval cycles (for weights, caches)
- Call `.Free()` when done with an array

# Image Generation in Ollama (Experimental)

Generate images from text prompts using local AI models.

## Quick Start

```bash
# Run with a prompt
ollama run z-image "a sunset over mountains"
Generating: step 30/30
Image saved to: /tmp/ollama-image-1704067200.png
```

On macOS, the generated image will automatically open in Preview.

## Supported Models

| Model | VRAM Required | Notes |
|-------|---------------|-------|
| z-image | ~12GB | Based on Flux architecture |

## CLI Usage

```bash
# Generate an image
ollama run z-image "a cat playing piano"

# Check if model is running
ollama ps

# Stop the model
ollama stop z-image
```

## API

### OpenAI-Compatible Endpoint

```bash
POST /v1/images/generations
```

**Request:**

```json
{
  "model": "z-image",
  "prompt": "a sunset over mountains",
  "size": "1024x1024",
  "response_format": "b64_json"
}
```

**Response:**

```json
{
  "created": 1704067200,
  "data": [
    {
      "b64_json": "iVBORw0KGgo..."
    }
  ]
}
```
### Example: cURL
```bash
curl http://localhost:11434/v1/images/generations \
-H "Content-Type: application/json" \
-d '{
"model": "z-image",
"prompt": "a white cat",
"size": "1024x1024"
}'
```
### Example: Save to File
```bash
curl -s http://localhost:11434/v1/images/generations \
-H "Content-Type: application/json" \
-d '{
"model": "z-image",
"prompt": "a white cat",
"size": "1024x1024"
}' | jq -r '.data[0].b64_json' | base64 -d > image.png
```
### Streaming Progress
Enable streaming to receive progress updates via SSE:
```bash
curl http://localhost:11434/v1/images/generations \
-H "Content-Type: application/json" \
-d '{"model": "z-image", "prompt": "a sunset", "stream": true}'
```
Events:
```
event: progress
data: {"step": 1, "total": 30}
event: progress
data: {"step": 2, "total": 30}
...
event: done
data: {"created": 1704067200, "data": [{"b64_json": "..."}]}
```
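A minimal Go sketch of consuming this stream is shown below. This is hypothetical client code, not part of this commit; the event framing follows the example above:

```go
package main

import (
	"bufio"
	"fmt"
	"net/http"
	"strings"
)

func main() {
	body := strings.NewReader(`{"model": "z-image", "prompt": "a sunset", "stream": true}`)
	resp, err := http.Post("http://localhost:11434/v1/images/generations", "application/json", body)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// SSE frames arrive as "event: <name>" / "data: <json>" line pairs.
	var event string
	sc := bufio.NewScanner(resp.Body)
	sc.Buffer(make([]byte, 0, 64*1024), 32*1024*1024) // the done event carries base64 image data
	for sc.Scan() {
		line := sc.Text()
		switch {
		case strings.HasPrefix(line, "event: "):
			event = strings.TrimPrefix(line, "event: ")
		case strings.HasPrefix(line, "data: "):
			fmt.Println(event+":", strings.TrimPrefix(line, "data: "))
		}
	}
}
```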
## Parameters
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| model | string | required | Model name |
| prompt | string | required | Text description of image |
| size | string | "1024x1024" | Image dimensions (WxH) |
| n | int | 1 | Number of images (currently only 1 supported) |
| response_format | string | "b64_json" | "b64_json" or "url" |
| stream | bool | false | Enable progress streaming |
## Requirements
- macOS with Apple Silicon (M1/M2/M3/M4)
- CUDA: tested on CUDA 12 Blackwell, more testing coming soon
- Sufficient VRAM (see model table above)
- Ollama built with MLX support
## Limitations
- macOS only (uses MLX backend)
- Single image per request
- Fixed step count (30 steps)
- Modelfiles not yet supported (use `ollama create` from model directory)
---
# Tensor Model Storage Format
Tensor models store each tensor as a separate blob with metadata in the manifest. This enables faster downloads (parallel fetching) and deduplication (shared tensors are stored once).
## Manifest Structure
The manifest follows the standard ollama format with tensor-specific layer metadata:
```json
{
"schemaVersion": 2,
"mediaType": "application/vnd.docker.distribution.manifest.v2+json",
"config": { "digest": "sha256:...", "size": 1234 },
"layers": [
{
"mediaType": "application/vnd.ollama.image.tensor",
"digest": "sha256:25b36eed...",
"size": 49807448,
"name": "text_encoder/model.layers.0.mlp.down_proj.weight",
"dtype": "BF16",
"shape": [2560, 9728]
},
{
"mediaType": "application/vnd.ollama.image.json",
"digest": "sha256:abc123...",
"size": 512,
"name": "text_encoder/config.json"
}
]
}
```
Each tensor layer includes:
- `name`: Path-style tensor name (e.g., `text_encoder/model.layers.0.mlp.down_proj.weight`)
- `dtype`: Data type (BF16, F32, etc.)
- `shape`: Tensor dimensions
Config layers use the same path-style naming (e.g., `tokenizer/tokenizer.json`).
## Blob Format
Each tensor blob is a minimal safetensors file:
```
[8 bytes: header size (uint64 LE)]
[~80 bytes: JSON header, padded to 8-byte alignment]
[N bytes: raw tensor data]
```
Header contains a single tensor named `"data"`:
```json
{"data":{"dtype":"BF16","shape":[2560,9728],"data_offsets":[0,49807360]}}
```
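As a sketch, this framing can be produced in a few lines of Go. `buildTensorBlob` is an illustrative helper following the layout above, not code from this commit:

```go
package blobsketch

import (
	"encoding/binary"
	"encoding/json"
)

// buildTensorBlob wraps raw tensor bytes in the minimal safetensors
// framing: an 8-byte little-endian header length, the JSON header
// padded with spaces to 8-byte alignment, then the raw tensor data.
func buildTensorBlob(dtype string, shape []int32, data []byte) ([]byte, error) {
	header := map[string]any{
		"data": map[string]any{
			"dtype":        dtype,
			"shape":        shape,
			"data_offsets": []int64{0, int64(len(data))},
		},
	}
	hdr, err := json.Marshal(header)
	if err != nil {
		return nil, err
	}
	for len(hdr)%8 != 0 {
		hdr = append(hdr, ' ') // safetensors permits trailing whitespace in the header
	}
	blob := make([]byte, 8, 8+len(hdr)+len(data))
	binary.LittleEndian.PutUint64(blob, uint64(len(hdr)))
	blob = append(blob, hdr...)
	blob = append(blob, data...)
	return blob, nil
}
```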
## Why Include the Header?
The ~88 byte safetensors header enables MLX's native `mlx_load_safetensors` function, which:
1. **Uses mmap** - Maps file directly into memory, no copies
2. **Zero-copy to GPU** - MLX reads directly from mapped pages
3. **No custom code** - Standard MLX API, battle-tested
Without the header, we'd need custom C++ code to create MLX arrays from raw mmap'd data. MLX's public API doesn't expose this - it always copies when creating arrays from external pointers.
The overhead is negligible: 88 bytes per tensor = ~100KB total for a 13GB model (0.0007%).
## Why Per-Tensor Blobs?
**Deduplication**: Blobs are content-addressed by SHA256. If two models share identical tensors (same weights, dtype, shape), they share the same blob file.
Example: Model A and Model B both use the same text encoder. The text encoder's 400 tensors are stored once, referenced by both manifests.
```
~/.ollama/models/
blobs/
sha256-25b36eed... <- shared by both models
sha256-abc123...
manifests/
library/model-a/latest <- references sha256-25b36eed
library/model-b/latest <- references sha256-25b36eed
```
## Import Flow

```bash
cd ./weights/Z-Image-Turbo
ollama create z-image
```

1. Scan component directories (text_encoder/, transformer/, vae/)
2. For each .safetensors file:
   - Extract individual tensors
   - Wrap each in minimal safetensors format (88B header + data)
   - Write to blob store (SHA256 content-addressed)
   - Add layer entry to manifest with path-style name
3. Copy config files (*.json) as config layers
4. Write manifest

x/imagegen/api/handler.go (new file, 235 lines)

@@ -0,0 +1,235 @@
package api
import (
"encoding/base64"
"fmt"
"net/http"
"os"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/x/imagegen"
)
// RunnerScheduler is the interface for scheduling a model runner.
// This is implemented by server.Server to avoid circular imports.
type RunnerScheduler interface {
ScheduleImageGenRunner(ctx *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error)
}
// RegisterRoutes registers the image generation API routes.
func RegisterRoutes(r gin.IRouter, scheduler RunnerScheduler) {
r.POST("/v1/images/generations", func(c *gin.Context) {
ImageGenerationHandler(c, scheduler)
})
}
// ImageGenerationHandler handles OpenAI-compatible image generation requests.
func ImageGenerationHandler(c *gin.Context, scheduler RunnerScheduler) {
var req ImageGenerationRequest
if err := c.BindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": err.Error()}})
return
}
// Validate required fields
if req.Model == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "model is required"}})
return
}
if req.Prompt == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "prompt is required"}})
return
}
// Apply defaults
if req.N == 0 {
req.N = 1
}
if req.Size == "" {
req.Size = "1024x1024"
}
if req.ResponseFormat == "" {
req.ResponseFormat = "b64_json"
}
// Verify model exists
if imagegen.ResolveModelName(req.Model) == "" {
c.JSON(http.StatusNotFound, gin.H{"error": gin.H{"message": fmt.Sprintf("model %q not found", req.Model)}})
return
}
// Parse size
width, height := parseSize(req.Size)
// Build options - we repurpose NumCtx/NumGPU for width/height
opts := api.Options{}
opts.NumCtx = int(width)
opts.NumGPU = int(height)
// Schedule runner
runner, err := scheduler.ScheduleImageGenRunner(c, req.Model, opts, nil)
if err != nil {
status := http.StatusInternalServerError
if strings.Contains(err.Error(), "not found") {
status = http.StatusNotFound
}
c.JSON(status, gin.H{"error": gin.H{"message": err.Error()}})
return
}
// Build completion request
completionReq := llm.CompletionRequest{
Prompt: req.Prompt,
Options: &opts,
}
if req.Stream {
handleStreamingResponse(c, runner, completionReq, req.ResponseFormat)
} else {
handleNonStreamingResponse(c, runner, completionReq, req.ResponseFormat)
}
}
func handleStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
var imagePath string
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
if resp.Done {
imagePath = extractPath(resp.Content)
} else {
progress := parseProgress(resp.Content)
if progress.Total > 0 {
c.SSEvent("progress", progress)
c.Writer.Flush()
}
}
})
if err != nil {
c.SSEvent("error", gin.H{"error": err.Error()})
return
}
c.SSEvent("done", buildResponse(imagePath, format))
}
func handleNonStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
var imagePath string
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
if resp.Done {
imagePath = extractPath(resp.Content)
}
})
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{"message": err.Error()}})
return
}
c.JSON(http.StatusOK, buildResponse(imagePath, format))
}
func parseSize(size string) (int32, int32) {
parts := strings.Split(size, "x")
if len(parts) != 2 {
return 1024, 1024
}
w, _ := strconv.Atoi(parts[0])
h, _ := strconv.Atoi(parts[1])
if w == 0 {
w = 1024
}
if h == 0 {
h = 1024
}
return int32(w), int32(h)
}
func extractPath(content string) string {
if idx := strings.Index(content, "Image saved to: "); idx >= 0 {
return strings.TrimSpace(content[idx+16:])
}
return ""
}
func parseProgress(content string) ImageProgressEvent {
var step, total int
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
return ImageProgressEvent{Step: step, Total: total}
}
func buildResponse(imagePath, format string) ImageGenerationResponse {
resp := ImageGenerationResponse{
Created: time.Now().Unix(),
Data: make([]ImageData, 1),
}
if imagePath == "" {
return resp
}
if format == "url" {
resp.Data[0].URL = "file://" + imagePath
} else {
data, err := os.ReadFile(imagePath)
if err == nil {
resp.Data[0].B64JSON = base64.StdEncoding.EncodeToString(data)
}
}
return resp
}
// HandleGenerateRequest handles Ollama /api/generate requests for image gen models.
// This allows routes.go to delegate image generation with minimal code.
func HandleGenerateRequest(c *gin.Context, scheduler RunnerScheduler, modelName, prompt string, keepAlive *api.Duration, streamFn func(c *gin.Context, ch chan any)) {
opts := api.Options{}
// Schedule runner
runner, err := scheduler.ScheduleImageGenRunner(c, modelName, opts, keepAlive)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// Build completion request
completionReq := llm.CompletionRequest{
Prompt: prompt,
Options: &opts,
}
// Stream responses via channel
ch := make(chan any)
go func() {
defer close(ch)
err := runner.Completion(c.Request.Context(), completionReq, func(resp llm.CompletionResponse) {
ch <- GenerateResponse{
Model: modelName,
CreatedAt: time.Now().UTC(),
Response: resp.Content,
Done: resp.Done,
}
})
if err != nil {
// We can't return an error response here; the channel is already
// being consumed, so drop the error.
_ = err
}
}()
streamFn(c, ch)
}
// GenerateResponse matches api.GenerateResponse structure for streaming.
type GenerateResponse struct {
Model string `json:"model"`
CreatedAt time.Time `json:"created_at"`
Response string `json:"response"`
Done bool `json:"done"`
}

x/imagegen/api/types.go (new file, 31 lines)

@@ -0,0 +1,31 @@
// Package api provides OpenAI-compatible image generation API types.
package api
// ImageGenerationRequest is an OpenAI-compatible image generation request.
type ImageGenerationRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Stream bool `json:"stream,omitempty"`
}
// ImageGenerationResponse is an OpenAI-compatible image generation response.
type ImageGenerationResponse struct {
Created int64 `json:"created"`
Data []ImageData `json:"data"`
}
// ImageData contains the generated image data.
type ImageData struct {
URL string `json:"url,omitempty"`
B64JSON string `json:"b64_json,omitempty"`
RevisedPrompt string `json:"revised_prompt,omitempty"`
}
// ImageProgressEvent is sent during streaming to indicate generation progress.
type ImageProgressEvent struct {
Step int `json:"step"`
Total int `json:"total"`
}

x/imagegen/cli.go (new file, 539 lines)

@@ -0,0 +1,539 @@
// cli.go provides CLI commands for image generation models.
//
// TODO (jmorganca): Integrate these commands into cmd/cmd.go when stable.
// Currently these are separate to keep experimental code isolated.
package imagegen
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"strconv"
"strings"
"time"
"github.com/spf13/cobra"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/readline"
)
// ImageGenOptions holds options for image generation.
// These can be set via environment variables or interactive commands.
type ImageGenOptions struct {
Width int
Height int
Steps int
Seed int
NegativePrompt string
}
// DefaultOptions returns the default image generation options.
func DefaultOptions() ImageGenOptions {
return ImageGenOptions{
Width: 1024,
Height: 1024,
Steps: 9,
Seed: 0, // 0 means random
}
}
// Show displays information about an image generation model.
func Show(modelName string, w io.Writer) error {
manifest, err := LoadManifest(modelName)
if err != nil {
return fmt.Errorf("failed to load manifest: %w", err)
}
// Count total size
var totalSize int64
for _, layer := range manifest.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" {
totalSize += layer.Size
}
}
// Read model_index.json for architecture
var architecture string
if data, err := manifest.ReadConfig("model_index.json"); err == nil {
var index struct {
Architecture string `json:"architecture"`
}
if json.Unmarshal(data, &index) == nil {
architecture = index.Architecture
}
}
// Estimate parameter count from total size (assuming BF16 = 2 bytes per param)
paramCount := totalSize / 2
paramStr := formatParamCount(paramCount)
// Print Model info
fmt.Fprintln(w, " Model")
if architecture != "" {
fmt.Fprintf(w, " %-20s %s\n", "architecture", architecture)
}
fmt.Fprintf(w, " %-20s %s\n", "parameters", paramStr)
fmt.Fprintf(w, " %-20s %s\n", "quantization", "BF16")
fmt.Fprintln(w)
// Print Capabilities
fmt.Fprintln(w, " Capabilities")
fmt.Fprintf(w, " %s\n", "image")
fmt.Fprintln(w)
return nil
}
// formatParamCount formats parameter count as human-readable string.
func formatParamCount(count int64) string {
if count >= 1_000_000_000 {
return fmt.Sprintf("%.1fB", float64(count)/1_000_000_000)
}
if count >= 1_000_000 {
return fmt.Sprintf("%.1fM", float64(count)/1_000_000)
}
return fmt.Sprintf("%d", count)
}
// RegisterFlags adds image generation flags to the given command.
// Flags are hidden since they only apply to image generation models.
func RegisterFlags(cmd *cobra.Command) {
cmd.Flags().Int("width", 1024, "Image width")
cmd.Flags().Int("height", 1024, "Image height")
cmd.Flags().Int("steps", 9, "Denoising steps")
cmd.Flags().Int("seed", 0, "Random seed (0 for random)")
cmd.Flags().String("negative", "", "Negative prompt")
cmd.Flags().MarkHidden("width")
cmd.Flags().MarkHidden("height")
cmd.Flags().MarkHidden("steps")
cmd.Flags().MarkHidden("seed")
cmd.Flags().MarkHidden("negative")
}
// RunCLI handles the CLI for image generation models.
// Supports flags: --width, --height, --steps, --seed, --negative
func RunCLI(cmd *cobra.Command, name string, prompt string, interactive bool, keepAlive *api.Duration) error {
// Verify it's a valid image gen model
if ResolveModelName(name) == "" {
return fmt.Errorf("unknown image generation model: %s", name)
}
// Get options from flags (with env var defaults)
opts := DefaultOptions()
if cmd != nil && cmd.Flags() != nil {
if v, err := cmd.Flags().GetInt("width"); err == nil && v > 0 {
opts.Width = v
}
if v, err := cmd.Flags().GetInt("height"); err == nil && v > 0 {
opts.Height = v
}
if v, err := cmd.Flags().GetInt("steps"); err == nil && v > 0 {
opts.Steps = v
}
if v, err := cmd.Flags().GetInt("seed"); err == nil && v != 0 {
opts.Seed = v
}
if v, err := cmd.Flags().GetString("negative"); err == nil && v != "" {
opts.NegativePrompt = v
}
}
if interactive {
return runInteractive(cmd, name, keepAlive, opts)
}
// One-shot generation
return generateImageWithOptions(cmd, name, prompt, keepAlive, opts)
}
// generateImageWithOptions generates an image with the given options.
func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keepAlive *api.Duration, opts ImageGenOptions) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
// Build request with image gen options encoded in Options fields
// NumCtx=width, NumGPU=height, NumPredict=steps, Seed=seed
req := &api.GenerateRequest{
Model: modelName,
Prompt: prompt,
Options: map[string]any{
"num_ctx": opts.Width,
"num_gpu": opts.Height,
"num_predict": opts.Steps,
"seed": opts.Seed,
},
}
if keepAlive != nil {
req.KeepAlive = keepAlive
}
// Show loading spinner until generation starts
p := progress.NewProgress(os.Stderr)
spinner := progress.NewSpinner("")
p.Add("", spinner)
var stepBar *progress.StepBar
var imagePath string
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
content := resp.Response
// Handle progress updates - parse step info and switch to step bar
if strings.HasPrefix(content, "\rGenerating:") {
var step, total int
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
if stepBar == nil && total > 0 {
spinner.Stop()
stepBar = progress.NewStepBar("Generating", total)
p.Add("", stepBar)
}
if stepBar != nil {
stepBar.Set(step)
}
return nil
}
// Handle final response with image path
if resp.Done && strings.Contains(content, "Image saved to:") {
if idx := strings.Index(content, "Image saved to: "); idx >= 0 {
imagePath = strings.TrimSpace(content[idx+16:])
}
}
return nil
})
p.Stop()
if err != nil {
return err
}
if imagePath != "" {
displayImageInTerminal(imagePath)
fmt.Printf("Image saved to: %s\n", imagePath)
}
return nil
}
// runInteractive runs an interactive REPL for image generation.
func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duration, opts ImageGenOptions) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
scanner, err := readline.New(readline.Prompt{
Prompt: ">>> ",
Placeholder: "Describe an image to generate (/help for commands)",
})
if err != nil {
return err
}
if envconfig.NoHistory() {
scanner.HistoryDisable()
}
for {
line, err := scanner.Readline()
switch {
case errors.Is(err, io.EOF):
fmt.Println()
return nil
case errors.Is(err, readline.ErrInterrupt):
if line == "" {
fmt.Println("\nUse Ctrl + d or /bye to exit.")
}
continue
case err != nil:
return err
}
line = strings.TrimSpace(line)
if line == "" {
continue
}
// Handle commands
switch {
case strings.HasPrefix(line, "/bye"):
return nil
case strings.HasPrefix(line, "/?"), strings.HasPrefix(line, "/help"):
printInteractiveHelp(opts)
continue
case strings.HasPrefix(line, "/set "):
if err := handleSetCommand(line[5:], &opts); err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
}
continue
case strings.HasPrefix(line, "/show"):
printCurrentSettings(opts)
continue
case strings.HasPrefix(line, "/"):
fmt.Fprintf(os.Stderr, "Unknown command: %s (try /help)\n", line)
continue
}
// Generate image with current options
req := &api.GenerateRequest{
Model: modelName,
Prompt: line,
Options: map[string]any{
"num_ctx": opts.Width,
"num_gpu": opts.Height,
"num_predict": opts.Steps,
"seed": opts.Seed,
},
}
if keepAlive != nil {
req.KeepAlive = keepAlive
}
// Show loading spinner until generation starts
p := progress.NewProgress(os.Stderr)
spinner := progress.NewSpinner("")
p.Add("", spinner)
var stepBar *progress.StepBar
var imagePath string
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
content := resp.Response
// Handle progress updates - parse step info and switch to step bar
if strings.HasPrefix(content, "\rGenerating:") {
var step, total int
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
if stepBar == nil && total > 0 {
spinner.Stop()
stepBar = progress.NewStepBar("Generating", total)
p.Add("", stepBar)
}
if stepBar != nil {
stepBar.Set(step)
}
return nil
}
// Handle final response with image path
if resp.Done && strings.Contains(content, "Image saved to:") {
if idx := strings.Index(content, "Image saved to: "); idx >= 0 {
imagePath = strings.TrimSpace(content[idx+16:])
}
}
return nil
})
p.Stop()
if err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
continue
}
// Copy image to current directory with descriptive name
if imagePath != "" {
// Create filename from prompt (sanitized)
safeName := sanitizeFilename(line)
if len(safeName) > 50 {
safeName = safeName[:50]
}
timestamp := time.Now().Format("20060102-150405")
newName := fmt.Sprintf("%s-%s.png", safeName, timestamp)
// Copy file to CWD
if err := copyFile(imagePath, newName); err != nil {
fmt.Fprintf(os.Stderr, "Error saving to current directory: %v\n", err)
displayImageInTerminal(imagePath)
fmt.Printf("Image saved to: %s\n", imagePath)
} else {
displayImageInTerminal(newName)
fmt.Printf("Image saved to: %s\n", newName)
}
}
fmt.Println()
}
}
// sanitizeFilename removes characters that aren't safe for filenames.
func sanitizeFilename(s string) string {
s = strings.ToLower(s)
s = strings.ReplaceAll(s, " ", "-")
// Remove any character that's not alphanumeric or hyphen
var result strings.Builder
for _, r := range s {
if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '-' {
result.WriteRune(r)
}
}
return result.String()
}
// copyFile copies a file from src to dst.
func copyFile(src, dst string) error {
sourceFile, err := os.Open(src)
if err != nil {
return err
}
defer sourceFile.Close()
destFile, err := os.Create(dst)
if err != nil {
return err
}
defer destFile.Close()
_, err = io.Copy(destFile, sourceFile)
return err
}
// printInteractiveHelp prints help for interactive mode commands.
func printInteractiveHelp(opts ImageGenOptions) {
fmt.Fprintln(os.Stderr, "Commands:")
fmt.Fprintln(os.Stderr, " /set width <n> Set image width (current:", opts.Width, ")")
fmt.Fprintln(os.Stderr, " /set height <n> Set image height (current:", opts.Height, ")")
fmt.Fprintln(os.Stderr, " /set steps <n> Set denoising steps (current:", opts.Steps, ")")
fmt.Fprintln(os.Stderr, " /set seed <n> Set random seed (current:", opts.Seed, ", 0=random)")
fmt.Fprintln(os.Stderr, " /set negative <s> Set negative prompt")
fmt.Fprintln(os.Stderr, " /show Show current settings")
fmt.Fprintln(os.Stderr, " /bye Exit")
fmt.Fprintln(os.Stderr)
fmt.Fprintln(os.Stderr, "Or type a prompt to generate an image.")
fmt.Fprintln(os.Stderr)
}
// printCurrentSettings prints the current image generation settings.
func printCurrentSettings(opts ImageGenOptions) {
fmt.Fprintf(os.Stderr, "Current settings:\n")
fmt.Fprintf(os.Stderr, " width: %d\n", opts.Width)
fmt.Fprintf(os.Stderr, " height: %d\n", opts.Height)
fmt.Fprintf(os.Stderr, " steps: %d\n", opts.Steps)
fmt.Fprintf(os.Stderr, " seed: %d (0=random)\n", opts.Seed)
if opts.NegativePrompt != "" {
fmt.Fprintf(os.Stderr, " negative: %s\n", opts.NegativePrompt)
}
fmt.Fprintln(os.Stderr)
}
// handleSetCommand handles /set commands to change options.
func handleSetCommand(args string, opts *ImageGenOptions) error {
parts := strings.SplitN(args, " ", 2)
if len(parts) < 2 {
return fmt.Errorf("usage: /set <option> <value>")
}
key := strings.ToLower(parts[0])
value := strings.TrimSpace(parts[1])
switch key {
case "width", "w":
v, err := strconv.Atoi(value)
if err != nil || v <= 0 {
return fmt.Errorf("width must be a positive integer")
}
opts.Width = v
fmt.Fprintf(os.Stderr, "Set width to %d\n", v)
case "height", "h":
v, err := strconv.Atoi(value)
if err != nil || v <= 0 {
return fmt.Errorf("height must be a positive integer")
}
opts.Height = v
fmt.Fprintf(os.Stderr, "Set height to %d\n", v)
case "steps", "s":
v, err := strconv.Atoi(value)
if err != nil || v <= 0 {
return fmt.Errorf("steps must be a positive integer")
}
opts.Steps = v
fmt.Fprintf(os.Stderr, "Set steps to %d\n", v)
case "seed":
v, err := strconv.Atoi(value)
if err != nil {
return fmt.Errorf("seed must be an integer")
}
opts.Seed = v
fmt.Fprintf(os.Stderr, "Set seed to %d\n", v)
case "negative", "neg", "n":
opts.NegativePrompt = value
if value == "" {
fmt.Fprintln(os.Stderr, "Cleared negative prompt")
} else {
fmt.Fprintf(os.Stderr, "Set negative prompt to: %s\n", value)
}
default:
return fmt.Errorf("unknown option: %s (try /help)", key)
}
return nil
}
// displayImageInTerminal attempts to render an image inline in the terminal.
// Supports iTerm2, Kitty, WezTerm, Ghostty, and other terminals with inline image support.
// Returns true if the image was displayed, false otherwise.
func displayImageInTerminal(imagePath string) bool {
// Check if terminal supports inline images
termProgram := os.Getenv("TERM_PROGRAM")
kittyWindowID := os.Getenv("KITTY_WINDOW_ID")
weztermPane := os.Getenv("WEZTERM_PANE")
ghostty := os.Getenv("GHOSTTY_RESOURCES_DIR")
// Read the image file
data, err := os.ReadFile(imagePath)
if err != nil {
return false
}
encoded := base64.StdEncoding.EncodeToString(data)
switch {
case termProgram == "iTerm.app" || termProgram == "WezTerm" || weztermPane != "":
// iTerm2/WezTerm inline image protocol
// ESC ] 1337 ; File = [arguments] : base64 BEL
fmt.Printf("\033]1337;File=inline=1;preserveAspectRatio=1:%s\a\n", encoded)
return true
case kittyWindowID != "" || ghostty != "" || termProgram == "ghostty":
// Kitty graphics protocol (also used by Ghostty)
// Send in chunks for large images
const chunkSize = 4096
for i := 0; i < len(encoded); i += chunkSize {
end := i + chunkSize
if end > len(encoded) {
end = len(encoded)
}
chunk := encoded[i:end]
if i == 0 {
// First chunk: a=T (transmit), f=100 (PNG), m=1 (more chunks follow) or m=0 (last chunk)
more := 1
if end >= len(encoded) {
more = 0
}
fmt.Printf("\033_Ga=T,f=100,m=%d;%s\033\\", more, chunk)
} else if end >= len(encoded) {
// Last chunk
fmt.Printf("\033_Gm=0;%s\033\\", chunk)
} else {
// Middle chunk
fmt.Printf("\033_Gm=1;%s\033\\", chunk)
}
}
fmt.Println()
return true
default:
return false
}
}

x/imagegen/client/create.go (new file, 130 lines)

@@ -0,0 +1,130 @@
// Package client provides client-side model creation for tensor-based models.
//
// This package is in x/ because the tensor model storage format is under development.
// It also exists to break an import cycle: server imports x/imagegen, so x/imagegen
// cannot import server. This sub-package can import server because server doesn't
// import it.
//
// TODO (jmorganca): This is temporary. When tensor models are promoted to production:
// 1. Add proper API endpoints for tensor model creation
// 2. Move tensor extraction to server-side
// 3. Remove this package
// 4. Follow the same client→server pattern as regular model creation
package client
import (
"bytes"
"encoding/json"
"fmt"
"io"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/imagegen"
)
// MinOllamaVersion is the minimum Ollama version required for image generation models.
const MinOllamaVersion = "0.14.0"
// CreateModel imports a tensor-based model from a local directory.
// This creates blobs and manifest directly on disk, bypassing the HTTP API.
//
// TODO (jmorganca): Replace with API-based creation when promoted to production.
func CreateModel(modelName, modelDir string, p *progress.Progress) error {
if !imagegen.IsTensorModelDir(modelDir) {
return fmt.Errorf("%s is not an image generation model directory (model_index.json not found)", modelDir)
}
status := "importing image generation model"
spinner := progress.NewSpinner(status)
p.Add("imagegen", spinner)
// Create layer callback for config files
createLayer := func(r io.Reader, mediaType, name string) (imagegen.LayerInfo, error) {
layer, err := server.NewLayer(r, mediaType)
if err != nil {
return imagegen.LayerInfo{}, err
}
layer.Name = name
return imagegen.LayerInfo{
Digest: layer.Digest,
Size: layer.Size,
MediaType: layer.MediaType,
Name: name,
}, nil
}
// Create tensor layer callback for individual tensors
// name is path-style: "component/tensor_name"
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32) (imagegen.LayerInfo, error) {
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
if err != nil {
return imagegen.LayerInfo{}, err
}
layer.Name = name
return imagegen.LayerInfo{
Digest: layer.Digest,
Size: layer.Size,
MediaType: layer.MediaType,
Name: name,
}, nil
}
// Create manifest writer callback
writeManifest := func(modelName string, config imagegen.LayerInfo, layers []imagegen.LayerInfo) error {
name := model.ParseName(modelName)
if !name.IsValid() {
return fmt.Errorf("invalid model name: %s", modelName)
}
// Create a proper config blob with version requirement
configData := model.ConfigV2{
ModelFormat: "safetensors",
Capabilities: []string{"image"},
Requires: MinOllamaVersion,
}
configJSON, err := json.Marshal(configData)
if err != nil {
return fmt.Errorf("failed to marshal config: %w", err)
}
// Create config layer blob
configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
if err != nil {
return fmt.Errorf("failed to create config layer: %w", err)
}
// Convert LayerInfo to server.Layer (include the original model_index.json in layers)
serverLayers := make([]server.Layer, len(layers))
for i, l := range layers {
serverLayers[i] = server.Layer{
MediaType: l.MediaType,
Digest: l.Digest,
Size: l.Size,
Name: l.Name,
}
}
return server.WriteManifest(name, configLayer, serverLayers)
}
// Progress callback
progressFn := func(msg string) {
spinner.Stop()
status = msg
spinner = progress.NewSpinner(status)
p.Add("imagegen", spinner)
}
err := imagegen.CreateModel(modelName, modelDir, createLayer, createTensorLayer, writeManifest, progressFn)
spinner.Stop()
if err != nil {
return err
}
fmt.Printf("Created image generation model '%s'\n", modelName)
return nil
}


@@ -0,0 +1,35 @@
# MLX Engine
Experimental MLX backend for running models on Apple Silicon and CUDA.
## Build
```bash
go build -tags mlx -o engine ./x/imagegen/cmd/engine
```
## Text Generation
```bash
./engine -model /path/to/model -prompt "Hello" -max-tokens 100
```
Options:
- `-temperature` - sampling temperature (default 0.7)
- `-top-p` - nucleus sampling (default 0.9)
- `-top-k` - top-k sampling (default 40)
Supports: Llama, Gemma3, GPT-OSS
## Image Generation
```bash
./engine -zimage -model /path/to/z-image -prompt "a cat" -output cat.png
```
Options:
- `-width`, `-height` - image dimensions (default 1024x1024)
- `-steps` - denoising steps (default 9)
- `-seed` - random seed (default 42)


@@ -98,7 +98,7 @@ func main() {
log.Fatal(loadErr)
}
var img *mlx.Array
img, err = m.GenerateFromConfig(&zimage.GenerateConfig{
img, err = m.GenerateFromConfig(context.Background(), &zimage.GenerateConfig{
Prompt: *prompt,
Width: int32(*width),
Height: int32(*height),

x/imagegen/create.go (new file, 183 lines)

@@ -0,0 +1,183 @@
package imagegen
import (
"bytes"
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// IsTensorModelDir checks if the directory contains a tensor model
// by looking for model_index.json, which is the standard diffusers pipeline config.
func IsTensorModelDir(dir string) bool {
_, err := os.Stat(filepath.Join(dir, "model_index.json"))
return err == nil
}
// LayerInfo holds metadata for a created layer.
type LayerInfo struct {
Digest string
Size int64
MediaType string
Name string // Path-style name: "component/tensor" or "path/to/config.json"
}
// LayerCreator is called to create a blob layer.
// name is the path-style name (e.g., "tokenizer/tokenizer.json")
type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error)
// TensorLayerCreator creates a tensor blob layer with metadata.
// name is the path-style name including component (e.g., "text_encoder/model.embed_tokens.weight")
type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error)
// ManifestWriter writes the manifest file.
type ManifestWriter func(modelName string, config LayerInfo, layers []LayerInfo) error
// CreateModel imports an image generation model from a directory.
// Stores each tensor as a separate blob for fine-grained deduplication.
// Layer creation and manifest writing are done via callbacks to avoid import cycles.
func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTensorLayer TensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
var layers []LayerInfo
var configLayer LayerInfo
// Components to process - extract individual tensors from each
components := []string{"text_encoder", "transformer", "vae"}
for _, component := range components {
componentDir := filepath.Join(modelDir, component)
if _, err := os.Stat(componentDir); os.IsNotExist(err) {
continue
}
// Find all safetensors files in this component
entries, err := os.ReadDir(componentDir)
if err != nil {
return fmt.Errorf("failed to read %s: %w", component, err)
}
for _, entry := range entries {
if !strings.HasSuffix(entry.Name(), ".safetensors") {
continue
}
stPath := filepath.Join(componentDir, entry.Name())
// Extract individual tensors from safetensors file
extractor, err := safetensors.OpenForExtraction(stPath)
if err != nil {
return fmt.Errorf("failed to open %s: %w", stPath, err)
}
tensorNames := extractor.ListTensors()
fn(fmt.Sprintf("importing %s/%s (%d tensors)", component, entry.Name(), len(tensorNames)))
for _, tensorName := range tensorNames {
td, err := extractor.GetTensor(tensorName)
if err != nil {
extractor.Close()
return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
}
// Store as minimal safetensors format (88 bytes header overhead)
// This enables native mmap loading via mlx_load_safetensors
// Use path-style name: "component/tensor_name"
fullName := component + "/" + tensorName
layer, err := createTensorLayer(td.SafetensorsReader(), fullName, td.Dtype, td.Shape)
if err != nil {
extractor.Close()
return fmt.Errorf("failed to create layer for %s: %w", fullName, err)
}
layers = append(layers, layer)
}
extractor.Close()
}
}
// Import config files
configFiles := []string{
"model_index.json",
"text_encoder/config.json",
"text_encoder/generation_config.json",
"transformer/config.json",
"vae/config.json",
"scheduler/scheduler_config.json",
"tokenizer/tokenizer.json",
"tokenizer/tokenizer_config.json",
"tokenizer/vocab.json",
}
for _, cfgPath := range configFiles {
fullPath := filepath.Join(modelDir, cfgPath)
if _, err := os.Stat(fullPath); os.IsNotExist(err) {
continue
}
fn(fmt.Sprintf("importing config %s", cfgPath))
var r io.Reader
// For model_index.json, normalize to Ollama format
if cfgPath == "model_index.json" {
data, err := os.ReadFile(fullPath)
if err != nil {
return fmt.Errorf("failed to read %s: %w", cfgPath, err)
}
var cfg map[string]any
if err := json.Unmarshal(data, &cfg); err != nil {
return fmt.Errorf("failed to parse %s: %w", cfgPath, err)
}
// Rename _class_name to architecture, remove diffusers-specific fields
if className, ok := cfg["_class_name"]; ok {
cfg["architecture"] = className
delete(cfg, "_class_name")
}
delete(cfg, "_diffusers_version")
data, err = json.MarshalIndent(cfg, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal %s: %w", cfgPath, err)
}
r = bytes.NewReader(data)
} else {
f, err := os.Open(fullPath)
if err != nil {
return fmt.Errorf("failed to open %s: %w", cfgPath, err)
}
defer f.Close()
r = f
}
layer, err := createLayer(r, "application/vnd.ollama.image.json", cfgPath)
if err != nil {
return fmt.Errorf("failed to create layer for %s: %w", cfgPath, err)
}
// Use model_index.json as the config layer
if cfgPath == "model_index.json" {
configLayer = layer
}
layers = append(layers, layer)
}
if configLayer.Digest == "" {
return fmt.Errorf("model_index.json not found in %s", modelDir)
}
fn(fmt.Sprintf("writing manifest for %s", modelName))
if err := writeManifest(modelName, configLayer, layers); err != nil {
return fmt.Errorf("failed to write manifest: %w", err)
}
fn(fmt.Sprintf("successfully imported %s with %d layers", modelName, len(layers)))
return nil
}

x/imagegen/image.go (new file, 107 lines)

@@ -0,0 +1,107 @@
//go:build mlx
package imagegen
import (
"bytes"
"encoding/base64"
"fmt"
"image"
"image/png"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// SaveImage saves an MLX array as a PNG image file.
// Expected format: [B, C, H, W] with values in [0, 1] range and C=3 (RGB).
func SaveImage(arr *mlx.Array, path string) error {
img, err := ArrayToImage(arr)
if err != nil {
return err
}
if filepath.Ext(path) != ".png" {
path = path + ".png"
}
f, err := os.Create(path)
if err != nil {
return err
}
defer f.Close()
return png.Encode(f, img)
}
// EncodeImageBase64 encodes an MLX array as a base64-encoded PNG.
// Expected format: [B, C, H, W] with values in [0, 1] range and C=3 (RGB).
func EncodeImageBase64(arr *mlx.Array) (string, error) {
img, err := ArrayToImage(arr)
if err != nil {
return "", err
}
var buf bytes.Buffer
if err := png.Encode(&buf, img); err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(buf.Bytes()), nil
}
// ArrayToImage converts an MLX array to a Go image.RGBA.
// Expected format: [B, C, H, W] with values in [0, 1] range and C=3 (RGB).
func ArrayToImage(arr *mlx.Array) (*image.RGBA, error) {
shape := arr.Shape()
if len(shape) != 4 {
return nil, fmt.Errorf("expected 4D array [B, C, H, W], got %v", shape)
}
// Transform to [H, W, C] for image conversion
img := mlx.Squeeze(arr, 0)
img = mlx.Transpose(img, 1, 2, 0)
img = mlx.Contiguous(img)
mlx.Eval(img)
imgShape := img.Shape()
H := int(imgShape[0])
W := int(imgShape[1])
C := int(imgShape[2])
if C != 3 {
img.Free()
return nil, fmt.Errorf("expected 3 channels (RGB), got %d", C)
}
// Copy to CPU and free GPU memory
data := img.Data()
img.Free()
// Write directly to Pix slice (faster than SetRGBA)
goImg := image.NewRGBA(image.Rect(0, 0, W, H))
pix := goImg.Pix
for y := 0; y < H; y++ {
for x := 0; x < W; x++ {
srcIdx := (y*W + x) * C
dstIdx := (y*W + x) * 4
pix[dstIdx+0] = uint8(clampF(data[srcIdx+0]*255+0.5, 0, 255))
pix[dstIdx+1] = uint8(clampF(data[srcIdx+1]*255+0.5, 0, 255))
pix[dstIdx+2] = uint8(clampF(data[srcIdx+2]*255+0.5, 0, 255))
pix[dstIdx+3] = 255
}
}
return goImg, nil
}
func clampF(v, min, max float32) float32 {
if v < min {
return min
}
if v > max {
return max
}
return v
}
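Taken together, a minimal usage sketch (illustrative, not part of this commit; it assumes an *mlx.Array produced by the pipeline in [B, C, H, W] with values in [0, 1]):

//go:build mlx

package example

import (
    "fmt"

    "github.com/ollama/ollama/x/imagegen"
    "github.com/ollama/ollama/x/imagegen/mlx"
)

// saveAndEncode persists a generated image and also returns it as a
// base64-encoded PNG data URI, e.g. for embedding in an API response.
func saveAndEncode(arr *mlx.Array) (string, error) {
    // SaveImage appends ".png" if the extension is missing.
    if err := imagegen.SaveImage(arr, "out"); err != nil {
        return "", err
    }
    b64, err := imagegen.EncodeImageBase64(arr)
    if err != nil {
        return "", err
    }
    return fmt.Sprintf("data:image/png;base64,%s", b64), nil
}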

177
x/imagegen/manifest.go Normal file
View File

@@ -0,0 +1,177 @@
package imagegen
import (
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"runtime"
"strings"
)
// ManifestLayer represents a layer in the manifest.
type ManifestLayer struct {
MediaType string `json:"mediaType"`
Digest string `json:"digest"`
Size int64 `json:"size"`
Name string `json:"name,omitempty"` // Path-style name: "component/tensor" or "path/to/config.json"
}
// Manifest represents the manifest JSON structure.
type Manifest struct {
SchemaVersion int `json:"schemaVersion"`
MediaType string `json:"mediaType"`
Config ManifestLayer `json:"config"`
Layers []ManifestLayer `json:"layers"`
}
// ModelManifest holds a parsed manifest with helper methods.
type ModelManifest struct {
Manifest *Manifest
BlobDir string
}
// DefaultBlobDir returns the default blob storage directory.
func DefaultBlobDir() string {
home, err := os.UserHomeDir()
if err != nil {
home = "."
}
// The layout is identical on darwin, linux, and windows.
return filepath.Join(home, ".ollama", "models", "blobs")
}
// DefaultManifestDir returns the default manifest storage directory.
func DefaultManifestDir() string {
home, err := os.UserHomeDir()
if err != nil {
home = "."
}
return filepath.Join(home, ".ollama", "models", "manifests")
}
// LoadManifest loads a manifest for the given model name.
// Model name format: "modelname" or "modelname:tag" or "host/namespace/name:tag"
func LoadManifest(modelName string) (*ModelManifest, error) {
manifestPath := resolveManifestPath(modelName)
data, err := os.ReadFile(manifestPath)
if err != nil {
return nil, fmt.Errorf("read manifest: %w", err)
}
var manifest Manifest
if err := json.Unmarshal(data, &manifest); err != nil {
return nil, fmt.Errorf("parse manifest: %w", err)
}
return &ModelManifest{
Manifest: &manifest,
BlobDir: DefaultBlobDir(),
}, nil
}
// resolveManifestPath converts a model name to a manifest file path.
func resolveManifestPath(modelName string) string {
// Parse model name into components
// Default: registry.ollama.ai/library/<name>/<tag>
host := "registry.ollama.ai"
namespace := "library"
name := modelName
tag := "latest"
// Handle explicit tag
if idx := strings.LastIndex(name, ":"); idx != -1 {
tag = name[idx+1:]
name = name[:idx]
}
// Handle full path like "host/namespace/name"
parts := strings.Split(name, "/")
switch len(parts) {
case 3:
host = parts[0]
namespace = parts[1]
name = parts[2]
case 2:
namespace = parts[0]
name = parts[1]
}
return filepath.Join(DefaultManifestDir(), host, namespace, name, tag)
}
// BlobPath returns the full path to a blob given its digest.
func (m *ModelManifest) BlobPath(digest string) string {
// Convert "sha256:abc123" to "sha256-abc123"
blobName := strings.Replace(digest, ":", "-", 1)
return filepath.Join(m.BlobDir, blobName)
}
// GetTensorLayers returns all tensor layers for a given component.
// Component should be "text_encoder", "transformer", or "vae".
// Tensor names are path-style: "component/tensor_name" (e.g., "text_encoder/model.embed_tokens.weight").
func (m *ModelManifest) GetTensorLayers(component string) []ManifestLayer {
prefix := component + "/"
var layers []ManifestLayer
for _, layer := range m.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" && strings.HasPrefix(layer.Name, prefix) {
layers = append(layers, layer)
}
}
return layers
}
// GetConfigLayer returns the config layer for a given path.
func (m *ModelManifest) GetConfigLayer(configPath string) *ManifestLayer {
for _, layer := range m.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.json" && layer.Name == configPath {
return &layer
}
}
return nil
}
// ReadConfig reads and returns the content of a config file.
func (m *ModelManifest) ReadConfig(configPath string) ([]byte, error) {
layer := m.GetConfigLayer(configPath)
if layer == nil {
return nil, fmt.Errorf("config %q not found in manifest", configPath)
}
blobPath := m.BlobPath(layer.Digest)
return os.ReadFile(blobPath)
}
// ReadConfigJSON reads and unmarshals a config file.
func (m *ModelManifest) ReadConfigJSON(configPath string, v any) error {
data, err := m.ReadConfig(configPath)
if err != nil {
return err
}
return json.Unmarshal(data, v)
}
// OpenBlob opens a blob for reading.
func (m *ModelManifest) OpenBlob(digest string) (io.ReadCloser, error) {
return os.Open(m.BlobPath(digest))
}
// HasTensorLayers returns true if the manifest has any tensor layers.
func (m *ModelManifest) HasTensorLayers() bool {
for _, layer := range m.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" {
return true
}
}
return false
}
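To make the layout concrete, a small sketch (illustrative, not part of this commit) that resolves a manifest and walks its layers; "z-image" stands in for any locally created model, and LoadManifest("z-image") reads ~/.ollama/models/manifests/registry.ollama.ai/library/z-image/latest:

package main

import (
    "fmt"
    "log"

    "github.com/ollama/ollama/x/imagegen"
)

func main() {
    // "z-image" is a hypothetical model name used for illustration.
    m, err := imagegen.LoadManifest("z-image")
    if err != nil {
        log.Fatal(err)
    }
    // A digest "sha256:abc..." maps to the blob file ".../blobs/sha256-abc..."
    for _, layer := range m.GetTensorLayers("transformer") {
        fmt.Println(layer.Name, "->", m.BlobPath(layer.Digest))
    }
    // Config blobs are addressed by their path-style name.
    var index struct {
        Architecture string `json:"architecture"`
    }
    if err := m.ReadConfigJSON("model_index.json", &index); err != nil {
        log.Fatal(err)
    }
    fmt.Println("architecture:", index.Architecture)
}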

102
x/imagegen/memory.go Normal file
View File

@@ -0,0 +1,102 @@
// Package imagegen provides experimental image generation capabilities for Ollama.
//
// This package is in x/ because the tensor model storage format is under development.
// The goal is to integrate these capabilities into the main Ollama packages once
// the format is stable.
//
// TODO (jmorganca): Integrate into main packages when stable:
// - CLI commands → cmd/
// - API endpoints → api/
// - Model creation → server/
package imagegen
import (
"encoding/json"
"fmt"
"runtime"
)
// GB is a convenience constant for gigabytes.
const GB = 1024 * 1024 * 1024
// SupportedBackends lists the backends that support image generation.
var SupportedBackends = []string{"metal", "cuda", "cpu"}
// modelVRAMEstimates maps pipeline class names to their estimated VRAM requirements.
var modelVRAMEstimates = map[string]uint64{
"ZImagePipeline": 21 * GB, // ~21GB for Z-Image (text encoder + transformer + VAE)
"FluxPipeline": 21 * GB, // ~21GB for Flux (same architecture)
"QwenImagePipeline": 80 * GB, // TODO: verify actual requirements, using conservative estimate for now
}
// CheckPlatformSupport validates that image generation is supported on the current platform.
// Returns nil if supported, or an error describing why it's not supported.
func CheckPlatformSupport() error {
switch runtime.GOOS {
case "darwin":
// macOS: Metal is supported via MLX
if runtime.GOARCH != "arm64" {
return fmt.Errorf("image generation on macOS requires Apple Silicon (arm64), got %s", runtime.GOARCH)
}
return nil
case "linux", "windows":
// Linux/Windows: CUDA support (requires mlx or cuda build)
// The actual backend availability is checked at runtime
return nil
default:
return fmt.Errorf("image generation is not supported on %s", runtime.GOOS)
}
}
// CheckMemoryRequirements validates that there's enough memory for image generation.
// Returns nil if memory is sufficient, or an error if not.
func CheckMemoryRequirements(modelName string, availableMemory uint64) error {
required := EstimateVRAM(modelName)
if availableMemory < required {
return fmt.Errorf("insufficient memory for image generation: need %d GB, have %d GB",
required/GB, availableMemory/GB)
}
return nil
}
// ResolveModelName checks whether a model name refers to a known image generation model.
// It returns the model name unchanged if its manifest contains tensor layers, or an empty string otherwise.
func ResolveModelName(modelName string) string {
manifest, err := LoadManifest(modelName)
if err == nil && manifest.HasTensorLayers() {
return modelName
}
return ""
}
// EstimateVRAM returns the estimated VRAM needed for an image generation model.
// Returns a conservative default of 21GB if the model type cannot be determined.
func EstimateVRAM(modelName string) uint64 {
manifest, err := LoadManifest(modelName)
if err != nil {
return 21 * GB
}
data, err := manifest.ReadConfig("model_index.json")
if err != nil {
return 21 * GB
}
// Parse just the class name
var index struct {
ClassName string `json:"_class_name"`
}
if err := json.Unmarshal(data, &index); err != nil {
return 21 * GB
}
if estimate, ok := modelVRAMEstimates[index.ClassName]; ok {
return estimate
}
return 21 * GB
}
// HasTensorLayers checks if the given model has tensor layers.
func HasTensorLayers(modelName string) bool {
return ResolveModelName(modelName) != ""
}
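A sketch of the intended preflight sequence assembled from these helpers (illustrative; the function and its caller are assumptions, not part of this commit):

// preflight is a hypothetical gate run before spawning the image runner.
func preflight(modelName string, availableMemory uint64) error {
    if err := CheckPlatformSupport(); err != nil {
        return err
    }
    if !HasTensorLayers(modelName) {
        return fmt.Errorf("%s is not an image generation model", modelName)
    }
    // EstimateVRAM falls back to 21 GB when the pipeline class is unknown,
    // so this check is conservative rather than exact.
    return CheckMemoryRequirements(modelName, availableMemory)
}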

110
x/imagegen/memory_test.go Normal file
View File

@@ -0,0 +1,110 @@
package imagegen
import (
"runtime"
"testing"
)
func TestCheckPlatformSupport(t *testing.T) {
err := CheckPlatformSupport()
switch runtime.GOOS {
case "darwin":
if runtime.GOARCH == "arm64" {
if err != nil {
t.Errorf("Expected nil error on darwin/arm64, got: %v", err)
}
} else {
if err == nil {
t.Error("Expected error on darwin/non-arm64")
}
}
case "linux", "windows":
if err != nil {
t.Errorf("Expected nil error on %s, got: %v", runtime.GOOS, err)
}
default:
if err == nil {
t.Errorf("Expected error on unsupported platform %s", runtime.GOOS)
}
}
}
func TestCheckMemoryRequirements(t *testing.T) {
tests := []struct {
name string
availableMemory uint64
wantErr bool
}{
{
name: "sufficient memory",
availableMemory: 32 * GB,
wantErr: false,
},
{
name: "exactly enough memory",
availableMemory: 21 * GB,
wantErr: false,
},
{
name: "insufficient memory",
availableMemory: 16 * GB,
wantErr: true,
},
{
name: "zero memory",
availableMemory: 0,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Use a non-existent model name which will default to 21GB estimate
err := CheckMemoryRequirements("nonexistent-model", tt.availableMemory)
if (err != nil) != tt.wantErr {
t.Errorf("CheckMemoryRequirements() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestModelVRAMEstimates(t *testing.T) {
// Verify the VRAM estimates map has expected entries
expected := map[string]uint64{
"ZImagePipeline": 21 * GB,
"FluxPipeline": 21 * GB,
"QwenImagePipeline": 80 * GB,
}
for name, expectedVRAM := range expected {
if actual, ok := modelVRAMEstimates[name]; !ok {
t.Errorf("Missing VRAM estimate for %s", name)
} else if actual != expectedVRAM {
t.Errorf("VRAM estimate for %s = %d GB, want %d GB", name, actual/GB, expectedVRAM/GB)
}
}
}
func TestEstimateVRAMDefault(t *testing.T) {
// Non-existent model should return default 21GB
vram := EstimateVRAM("nonexistent-model-that-does-not-exist")
if vram != 21*GB {
t.Errorf("EstimateVRAM() = %d GB, want 21 GB", vram/GB)
}
}
func TestHasTensorLayers(t *testing.T) {
// Non-existent model should return false
if HasTensorLayers("nonexistent-model") {
t.Error("HasTensorLayers() should return false for non-existent model")
}
}
func TestResolveModelName(t *testing.T) {
// Non-existent model should return empty string
result := ResolveModelName("nonexistent-model")
if result != "" {
t.Errorf("ResolveModelName() = %q, want empty string", result)
}
}

View File

@@ -11,6 +11,10 @@ package mlx
#include "mlx/c/mlx.h"
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
// Forward declare cpu_stream
static mlx_stream cpu_stream();
// Cached default GPU stream for all ops
static mlx_stream _default_stream = {0};
@@ -1026,10 +1030,11 @@ func View(a *Array, dtype int) *Array {
return newArray(res)
}
// Contiguous returns a contiguous copy of the array (row-major)
func Contiguous(a *Array) *Array {
res := C.mlx_array_new()
// Use allow_col=false to force row-major contiguous layout
C.mlx_contiguous(&res, a.c, false, C.default_stream())
return newArray(res)
}
@@ -1762,11 +1767,16 @@ func RandomCategorical(logits *Array, axis int, numSamples int) *Array {
return RandomCategoricalWithKey(logits, key2, axis, numSamples)
}
// RandomNormal creates a random normal (Gaussian) tensor in float32
func RandomNormal(shape []int32, seed uint64) *Array {
return RandomNormalWithDtype(shape, seed, DtypeFloat32)
}
// RandomNormalWithDtype creates a random normal (Gaussian) tensor with specified dtype
func RandomNormalWithDtype(shape []int32, seed uint64, dtype Dtype) *Array {
key := RandomKey(seed)
res := C.mlx_array_new()
C.mlx_random_normal(&res, int32ToCInt(shape), C.size_t(len(shape)), C.mlx_dtype(dtype), 0.0, 1.0, key.c, C.default_stream())
return newArray(res)
}

View File

@@ -128,14 +128,9 @@ func (s *FlowMatchEulerScheduler) AddNoise(cleanSample, noise *mlx.Array, timest
return mlx.Add(scaledClean, scaledNoise)
}
// InitNoise creates initial noise for sampling (BFloat16 for GPU efficiency)
func (s *FlowMatchEulerScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
return mlx.RandomNormalWithDtype(shape, uint64(seed), mlx.DtypeBFloat16)
}
// GetLatentShape returns the latent shape for a given image size

View File

@@ -3,12 +3,10 @@
package zimage
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
@@ -28,19 +26,6 @@ type Qwen3Config struct {
HeadDim int32 `json:"head_dim"`
}
// Qwen3Attention implements Qwen3 attention with QK norms
type Qwen3Attention struct {
QProj *nn.Linear `weight:"q_proj"`
@@ -194,33 +179,44 @@ type Qwen3TextEncoder struct {
*Qwen3Config
}
// Load loads the Qwen3 text encoder from ollama blob storage.
func (m *Qwen3TextEncoder) Load(manifest *imagegen.ModelManifest) error {
fmt.Print(" Loading text encoder... ")
// Load config from blob
var cfg Qwen3Config
if err := manifest.ReadConfigJSON("text_encoder/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Qwen3Config = &cfg
// Pre-allocate layers slice
m.Layers = make([]*Qwen3Block, cfg.NumHiddenLayers)
// Load weights from tensor blobs
weights, err := imagegen.LoadWeightsFromManifest(manifest, "text_encoder")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(0); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
fmt.Print(" Loading weights via struct tags... ")
return m.loadWeights(weights)
}
// loadWeights loads weights from any WeightSource into the model
func (m *Qwen3TextEncoder) loadWeights(weights safetensors.WeightSource) error {
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
m.initComputedFields()
fmt.Println("✓")
return nil
}
// initComputedFields initializes computed fields after loading weights
func (m *Qwen3TextEncoder) initComputedFields() {
cfg := m.Qwen3Config
m.FinalNorm.Eps = cfg.RMSNormEps
for _, block := range m.Layers {
// Attention
@@ -235,9 +231,6 @@ func (m *Qwen3TextEncoder) Load(path string) error {
block.InputLayerNorm.Eps = cfg.RMSNormEps
block.PostAttnLayerNorm.Eps = cfg.RMSNormEps
}
}
// Forward encodes text tokens

View File

@@ -4,12 +4,10 @@
package zimage
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
@@ -335,41 +333,49 @@ type Transformer struct {
*TransformerConfig
}
// Load loads the Z-Image transformer from ollama blob storage.
func (m *Transformer) Load(manifest *imagegen.ModelManifest) error {
fmt.Print(" Loading transformer... ")
// Load config from blob
var cfg TransformerConfig
if err := manifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
if len(cfg.AllPatchSize) > 0 {
cfg.PatchSize = cfg.AllPatchSize[0]
}
m.TransformerConfig = &cfg
// Pre-allocate slices for loader
m.NoiseRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
m.ContextRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
m.Layers = make([]*TransformerBlock, cfg.NLayers)
// Load weights from tensor blobs with BF16 conversion
weights, err := imagegen.LoadWeightsFromManifest(manifest, "transformer")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
fmt.Print(" Loading weights as bf16... ")
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
return fmt.Errorf("load weights: %w", err)
}
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
defer weights.ReleaseAll()
fmt.Print(" Loading weights via struct tags... ")
return m.loadWeights(weights)
}
// loadWeights loads weights from any WeightSource into the model
func (m *Transformer) loadWeights(weights safetensors.WeightSource) error {
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
m.initComputedFields()
fmt.Println("✓")
return nil
}
// initComputedFields initializes computed fields after loading weights
func (m *Transformer) initComputedFields() {
cfg := m.TransformerConfig
m.TEmbed.FreqEmbedSize = 256
m.FinalLayer.OutDim = m.FinalLayer.Output.Weight.Shape()[0]
m.CapEmbed.Norm.Eps = 1e-6
@@ -383,26 +389,6 @@ func (m *Transformer) Load(path string) error {
for _, block := range m.Layers {
initTransformerBlock(block, cfg)
}
}
// initTransformerBlock sets computed fields on a transformer block

View File

@@ -3,12 +3,10 @@
package zimage
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
@@ -25,19 +23,6 @@ type VAEConfig struct {
ShiftFactor float32 `json:"shift_factor"`
}
// GroupNormLayer implements group normalization
type GroupNormLayer struct {
Weight *mlx.Array
@@ -57,49 +42,183 @@ func NewGroupNorm(weight, bias *mlx.Array, numGroups int32) *GroupNormLayer {
}
// Forward applies group normalization
// Input and output are in NHWC format [B, H, W, C]
func (gn *GroupNormLayer) Forward(x *mlx.Array) *mlx.Array {
// x: [B, H, W, C] (NHWC format)
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
C := shape[3]
// For large spatial sizes, use tiled computation to avoid CUDA grid limits
// CUDA grid.y max is 65535, so H*W/16 must be <= 65535, meaning H*W <= ~1M
// To be safe, tile when H*W > 512*512 = 262144
if H*W > 512*512 {
return gn.forwardTiled(x, B, H, W, C)
}
return gn.forwardSmall(x, B, H, W, C)
}
// forwardSmall is the standard GroupNorm for tensors that fit within CUDA grid limits
func (gn *GroupNormLayer) forwardSmall(x *mlx.Array, B, H, W, C int32) *mlx.Array {
// Reshape to [B, H, W, groups, C/groups]
groupSize := C / gn.NumGroups
x = mlx.Reshape(x, B, H, W, gn.NumGroups, groupSize)
// Compute mean and variance per group (over H, W, and C/groups dimensions)
mean := mlx.Mean(x, 1, true)
mean = mlx.Mean(mean, 2, true)
mean = mlx.Mean(mean, 4, true)
xCentered := mlx.Sub(x, mean)
// Variance over same axes
sq := mlx.Square(xCentered)
variance := mlx.Mean(sq, 1, true)
variance = mlx.Mean(variance, 2, true)
variance = mlx.Mean(variance, 4, true)
// Normalize
xNorm := mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, gn.Eps)))
// Reshape back to [B, H, W, C]
xNorm = mlx.Reshape(xNorm, B, H, W, C)
// Scale and shift (weight and bias are [C])
if gn.Weight != nil {
weight := mlx.Reshape(gn.Weight, 1, 1, 1, C)
xNorm = mlx.Mul(xNorm, weight)
}
if gn.Bias != nil {
bias := mlx.Reshape(gn.Bias, 1, 1, 1, C)
xNorm = mlx.Add(xNorm, bias)
}
return xNorm
}
// forwardTiled handles large tensors by processing in H-tiles to avoid CUDA grid limits
func (gn *GroupNormLayer) forwardTiled(x *mlx.Array, B, H, W, C int32) *mlx.Array {
groupSize := C / gn.NumGroups
// Keep the input - we need it for slicing tiles later
mlx.Keep(x)
// Compute per-group mean and variance using flattened spatial dimensions
// Build the entire compute graph first, then eval once
// Reshape to [B, H*W, groups, groupSize]
xFlat := mlx.Reshape(x, B, H*W, gn.NumGroups, groupSize)
// Mean over spatial (axis 1) and groupSize (axis 3) dimensions
// Result shape: [B, 1, groups, 1]
mean1 := mlx.Mean(xFlat, 1, true)
mean := mlx.Mean(mean1, 3, true)
// Variance using E[X^2] - E[X]^2
xSq := mlx.Square(xFlat)
meanSq1 := mlx.Mean(xSq, 1, true)
meanSq := mlx.Mean(meanSq1, 3, true)
meanSquared := mlx.Square(mean)
variance := mlx.Sub(meanSq, meanSquared)
// invStd = 1/sqrt(var + eps)
varPlusEps := mlx.AddScalar(variance, gn.Eps)
stdDev := mlx.Sqrt(varPlusEps)
one := mlx.Full(1.0, 1)
invStd := mlx.Div(one, stdDev)
// Eval mean and invStd together - these are what we need for the tile loop
mlx.Keep(mean, invStd)
mlx.Eval(mean, invStd)
// Tile along H dimension
tileH := int32(512 * 512 / W)
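// keeps each tile's H*W at or below the 512*512 threshold used in Forward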
if tileH < 1 {
tileH = 1
}
if tileH > H {
tileH = H
}
// Prepare weight and bias reshaped for 4D broadcast [1, 1, groups, groupSize]
var weightGN, biasGN *mlx.Array
if gn.Weight != nil {
weightGN = mlx.Reshape(gn.Weight, 1, 1, gn.NumGroups, groupSize)
mlx.Keep(weightGN)
mlx.Eval(weightGN)
}
if gn.Bias != nil {
biasGN = mlx.Reshape(gn.Bias, 1, 1, gn.NumGroups, groupSize)
mlx.Keep(biasGN)
mlx.Eval(biasGN)
}
var tiles []*mlx.Array
for hStart := int32(0); hStart < H; hStart += tileH {
hEnd := hStart + tileH
if hEnd > H {
hEnd = H
}
tileHeight := hEnd - hStart
spatialSize := tileHeight * W
// Build the compute graph for this tile (no intermediate Evals)
// Extract tile and flatten spatial dims: [B, tileH*W, groups, groupSize]
tile := mlx.Slice(x, []int32{0, hStart, 0, 0}, []int32{B, hEnd, W, C})
tileFlat := mlx.Reshape(tile, B, spatialSize, gn.NumGroups, groupSize)
// Normalize: (x - mean) * invStd
tileCentered := mlx.Sub(tileFlat, mean)
tileNorm := mlx.Mul(tileCentered, invStd)
// Apply scale and shift in 4D space
if weightGN != nil {
tileNorm = mlx.Mul(tileNorm, weightGN)
}
if biasGN != nil {
tileNorm = mlx.Add(tileNorm, biasGN)
}
// Reshape back to [B, tileH, W, C]
tileOut := mlx.Reshape(tileNorm, B, tileHeight, W, C)
// Now eval and keep this tile
mlx.Keep(tileOut)
mlx.Eval(tileOut)
tiles = append(tiles, tileOut)
}
// Concatenate tiles along H axis
var result *mlx.Array
if len(tiles) == 1 {
result = tiles[0]
} else {
result = mlx.Concatenate(tiles, 1)
mlx.Eval(result)
// Free the individual tiles now that they're concatenated
for _, t := range tiles {
t.Free()
}
}
// Clean up kept arrays
mean.Free()
invStd.Free()
if weightGN != nil {
weightGN.Free()
}
if biasGN != nil {
biasGN.Free()
}
return result
}
// Conv2D represents a 2D convolution layer
// Works natively in NHWC format (MLX's native format)
type Conv2D struct {
Weight *mlx.Array // [out_channels, kH, kW, in_channels] (OHWI for MLX)
Bias *mlx.Array // [out_channels]
@@ -123,21 +242,17 @@ func NewConv2D(weight, bias *mlx.Array, stride, padding int32) *Conv2D {
}
// Forward applies convolution
// Input and output are in NHWC format [N, H, W, C]
func (conv *Conv2D) Forward(x *mlx.Array) *mlx.Array {
// Conv in NHWC format (MLX native)
out := mlx.Conv2d(x, conv.Weight, conv.Stride, conv.Padding)
if conv.Bias != nil {
// Bias is [C], reshape to [1, 1, 1, C] for NHWC broadcast
bias := mlx.Reshape(conv.Bias, 1, 1, 1, conv.Bias.Dim(0))
out = mlx.Add(out, bias)
}
return out
}
@@ -151,7 +266,7 @@ type ResnetBlock2D struct {
}
// NewResnetBlock2D creates a ResNet block
func NewResnetBlock2D(weights safetensors.WeightSource, prefix string, numGroups int32) (*ResnetBlock2D, error) {
norm1Weight, err := weights.GetTensor(prefix + ".norm1.weight")
if err != nil {
return nil, err
@@ -216,13 +331,13 @@ func (rb *ResnetBlock2D) Forward(x *mlx.Array) *mlx.Array {
// Stage 1: norm1
{
h = rb.Norm1.Forward(x)
mlx.Eval(h)
}
// Stage 2: silu + conv1
{
prev := h
h = mlx.SiLU(h)
h = rb.Conv1.Forward(h)
prev.Free()
@@ -231,7 +346,7 @@ func (rb *ResnetBlock2D) Forward(x *mlx.Array) *mlx.Array {
// Stage 3: norm2
{
prev := h
h = rb.Norm2.Forward(h)
prev.Free()
mlx.Eval(h)
@@ -239,7 +354,7 @@ func (rb *ResnetBlock2D) Forward(x *mlx.Array) *mlx.Array {
// Stage 4: silu + conv2
{
prev := h
h = mlx.SiLU(h)
h = rb.Conv2.Forward(h)
prev.Free()
@@ -248,7 +363,7 @@ func (rb *ResnetBlock2D) Forward(x *mlx.Array) *mlx.Array {
// Residual connection
{
prev := h
if rb.ConvShortcut != nil {
shortcut := rb.ConvShortcut.Forward(x)
h = mlx.Add(h, shortcut)
@@ -277,7 +392,7 @@ type VAEAttentionBlock struct {
}
// NewVAEAttentionBlock creates an attention block
func NewVAEAttentionBlock(weights safetensors.WeightSource, prefix string, numGroups int32) (*VAEAttentionBlock, error) {
normWeight, err := weights.GetTensor(prefix + ".group_norm.weight")
if err != nil {
return nil, err
@@ -338,20 +453,20 @@ func NewVAEAttentionBlock(weights *safetensors.ModelWeights, prefix string, numG
}
// Forward applies attention with staged evaluation
// Input and output are in NHWC format [B, H, W, C]
func (ab *VAEAttentionBlock) Forward(x *mlx.Array) *mlx.Array {
residual := x
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
C := shape[3]
var h *mlx.Array
// Stage 1: GroupNorm + reshape to [B, H*W, C]
{
h = ab.GroupNorm.Forward(x)
h = mlx.Reshape(h, B, H*W, C)
mlx.Eval(h)
}
@@ -360,7 +475,7 @@ func (ab *VAEAttentionBlock) Forward(x *mlx.Array) *mlx.Array {
// Stage 2: Q, K, V projections + attention
{
q := mlx.Linear(h, ab.ToQWeight)
q = mlx.Add(q, ab.ToQBias)
k := mlx.Linear(h, ab.ToKWeight)
k = mlx.Add(k, ab.ToKBias)
@@ -380,11 +495,10 @@ func (ab *VAEAttentionBlock) Forward(x *mlx.Array) *mlx.Array {
// Stage 3: Output projection + reshape + residual
{
prev := out
out = mlx.Linear(out, ab.ToOutWeight)
out = mlx.Add(out, ab.ToOutBias)
out = mlx.Reshape(out, B, H, W, C)
out = mlx.Add(out, residual)
prev.Free()
mlx.Eval(out)
@@ -400,7 +514,7 @@ type UpDecoderBlock2D struct {
}
// NewUpDecoderBlock2D creates an up decoder block
func NewUpDecoderBlock2D(weights safetensors.WeightSource, prefix string, numLayers, numGroups int32, hasUpsample bool) (*UpDecoderBlock2D, error) {
resnets := make([]*ResnetBlock2D, numLayers)
for i := int32(0); i < numLayers; i++ {
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
@@ -467,7 +581,7 @@ type VAEMidBlock struct {
}
// NewVAEMidBlock creates the mid block
func NewVAEMidBlock(weights safetensors.WeightSource, prefix string, numGroups int32) (*VAEMidBlock, error) {
resnet1, err := NewResnetBlock2D(weights, prefix+".resnets.0", numGroups)
if err != nil {
return nil, err
@@ -518,22 +632,31 @@ type VAEDecoder struct {
ConvOut *Conv2D
}
// Load loads the VAE decoder from ollama blob storage.
func (m *VAEDecoder) Load(manifest *imagegen.ModelManifest) error {
// Load config from blob
var cfg VAEConfig
if err := manifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Config = &cfg
// Load weights from tensor blobs
weights, err := imagegen.LoadWeightsFromManifest(manifest, "vae")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(0); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
return m.loadWeights(weights, &cfg)
}
// loadWeights loads VAE weights from any WeightSource
func (m *VAEDecoder) loadWeights(weights safetensors.WeightSource, cfg *VAEConfig) error {
var err error
// Load conv_in
fmt.Print(" Loading conv_in... ")
@@ -596,20 +719,20 @@ func (m *VAEDecoder) Load(path string) error {
m.ConvOut = NewConv2D(convOutWeight, convOutBias, 1, 1)
fmt.Println("✓")
return nil
}
// Decode decodes latents to images.
// Uses staged pools to free intermediate arrays and reduce peak memory.
// Input latents are in NCHW format, output is in NCHW format.
// Internally uses NHWC format (MLX native) for all operations.
func (vae *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
// Scale latents
z := mlx.DivScalar(latents, vae.Config.ScalingFactor)
z = mlx.AddScalar(z, vae.Config.ShiftFactor)
// Convert NCHW -> NHWC for internal processing
z = mlx.Transpose(z, 0, 2, 3, 1)
h := vae.ConvIn.Forward(z)
mlx.Eval(h)
h = vae.MidBlock.Forward(h)
@@ -617,36 +740,51 @@ func (vae *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
h = upBlock.Forward(h)
}
prev := h
h = vae.ConvNormOut.Forward(h)
mlx.Eval(h) // Eval after GroupNorm to avoid grid dimension issues
h = mlx.SiLU(h)
h = vae.ConvOut.Forward(h)
mlx.Eval(h)
// VAE outputs [-1, 1], convert to [0, 1]
h = mlx.MulScalar(h, 0.5)
h = mlx.AddScalar(h, 0.5)
h = mlx.ClipScalar(h, 0.0, 1.0, true, true)
// Convert NHWC -> NCHW for output
h = mlx.Transpose(h, 0, 3, 1, 2)
prev.Free()
mlx.Eval(h)
return h
}
// Upsample2x performs 2x nearest neighbor upsampling using Take.
// Input and output are in NHWC format: [B, H, W, C] -> [B, H*2, W*2, C]
// Uses Take with repeated indices to produce contiguous output.
func Upsample2x(x *mlx.Array) *mlx.Array {
shape := x.Shape()
H := shape[1]
W := shape[2]
// Create indices [0, 0, 1, 1, 2, 2, ...] for nearest neighbor
// For H dimension
hIdx := mlx.ArangeInt(0, H, 1, mlx.DtypeInt32)
hIdx = mlx.Reshape(hIdx, H, 1)
hIdx = mlx.BroadcastTo(hIdx, []int32{H, 2})
hIdx = mlx.Reshape(hIdx, H*2)
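// e.g. H=2: [0 1] -> [[0] [1]] -> [[0 0] [1 1]] -> [0 0 1 1]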
// For W dimension
wIdx := mlx.ArangeInt(0, W, 1, mlx.DtypeInt32)
wIdx = mlx.Reshape(wIdx, W, 1)
wIdx = mlx.BroadcastTo(wIdx, []int32{W, 2})
wIdx = mlx.Reshape(wIdx, W*2)
// Take along H axis (axis 1 in NHWC)
x = mlx.Take(x, hIdx, 1)
// Take along W axis (axis 2 in NHWC)
x = mlx.Take(x, wIdx, 2)
return x
}

View File

@@ -6,9 +6,9 @@ package zimage
import (
"context"
"fmt"
"path/filepath"
"time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/tokenizer"
@@ -37,16 +37,16 @@ type ProgressFunc func(step, totalSteps int)
// Model represents a Z-Image diffusion model.
type Model struct {
ModelName string
Tokenizer *tokenizer.Tokenizer
TextEncoder *Qwen3TextEncoder
Transformer *Transformer
VAEDecoder *VAEDecoder
}
// Load loads the Z-Image model from ollama blob storage.
func (m *Model) Load(modelName string) error {
fmt.Printf("Loading Z-Image model from manifest: %s...\n", modelName)
start := time.Now()
if mlx.GPUIsAvailable() {
@@ -54,12 +54,34 @@ func (m *Model) Load(modelPath string) error {
mlx.EnableCompile()
}
m.ModelName = modelName
// Load manifest
manifest, err := imagegen.LoadManifest(modelName)
if err != nil {
return fmt.Errorf("load manifest: %w", err)
}
// Load tokenizer from manifest with config
fmt.Print(" Loading tokenizer... ")
tokData, err := manifest.ReadConfig("tokenizer/tokenizer.json")
if err != nil {
return fmt.Errorf("tokenizer: %w", err)
}
// Try to read tokenizer config files from manifest
tokConfig := &tokenizer.TokenizerConfig{}
if data, err := manifest.ReadConfig("tokenizer/tokenizer_config.json"); err == nil {
tokConfig.TokenizerConfigJSON = data
}
if data, err := manifest.ReadConfig("tokenizer/generation_config.json"); err == nil {
tokConfig.GenerationConfigJSON = data
}
if data, err := manifest.ReadConfig("tokenizer/special_tokens_map.json"); err == nil {
tokConfig.SpecialTokensMapJSON = data
}
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
if err != nil {
return fmt.Errorf("tokenizer: %w", err)
}
@@ -68,7 +90,7 @@ func (m *Model) Load(modelPath string) error {
// Load text encoder
m.TextEncoder = &Qwen3TextEncoder{}
if err := m.TextEncoder.Load(manifest); err != nil {
return fmt.Errorf("text encoder: %w", err)
}
mlx.Eval(mlx.Collect(m.TextEncoder)...)
@@ -78,7 +100,7 @@ func (m *Model) Load(modelPath string) error {
// Load transformer
m.Transformer = &Transformer{}
if err := m.Transformer.Load(manifest); err != nil {
return fmt.Errorf("transformer: %w", err)
}
mlx.Eval(mlx.Collect(m.Transformer)...)
@@ -88,7 +110,7 @@ func (m *Model) Load(modelPath string) error {
// Load VAE decoder
m.VAEDecoder = &VAEDecoder{}
if err := m.VAEDecoder.Load(manifest); err != nil {
return fmt.Errorf("VAE decoder: %w", err)
}
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
@@ -104,7 +126,7 @@ func (m *Model) Load(modelPath string) error {
// Generate creates an image from a prompt.
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
@@ -115,7 +137,7 @@ func (m *Model) Generate(prompt string, width, height int32, steps int, seed int
// GenerateWithProgress creates an image with progress callback.
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) {
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
@@ -127,7 +149,7 @@ func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps i
// GenerateWithCFG creates an image with classifier-free guidance.
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress ProgressFunc) (*mlx.Array, error) {
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
Prompt: prompt,
NegativePrompt: negativePrompt,
CFGScale: cfgScale,
@@ -140,9 +162,9 @@ func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int
}
// GenerateFromConfig generates an image using the unified config struct.
func (m *Model) GenerateFromConfig(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
start := time.Now()
result, err := m.generate(ctx, cfg)
if err != nil {
return nil, err
}
@@ -160,7 +182,7 @@ func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height
}
// generate is the internal denoising pipeline.
func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
// Apply defaults
if cfg.Width <= 0 {
cfg.Width = 1024
@@ -247,11 +269,19 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
}
// Denoising loop
if cfg.Progress != nil {
cfg.Progress(0, cfg.Steps) // Start at 0%
}
for i := 0; i < cfg.Steps; i++ {
// Check for cancellation
if ctx != nil {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
}
stepStart := time.Now()
// GPU capture on step 2 if requested
if cfg.CapturePath != "" && i == 1 {
@@ -295,6 +325,7 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
noisePred := UnpatchifyLatents(output, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
noisePred = mlx.Neg(noisePred)
oldLatents := latents
latents = scheduler.Step(noisePred, latents, i)
@@ -313,6 +344,10 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
peakMem := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
fmt.Printf(" Step %d/%d: t=%.4f (%.2fs) [%.1f GB active, %.1f GB peak]\n",
i+1, cfg.Steps, tCurr, time.Since(stepStart).Seconds(), activeMem, peakMem)
if cfg.Progress != nil {
cfg.Progress(i+1, cfg.Steps) // Report completed step
}
}
// Free denoising temporaries before VAE decode

217
x/imagegen/runner/runner.go Normal file
View File

@@ -0,0 +1,217 @@
//go:build mlx
// Package runner provides a subprocess server for image generation.
// It listens on a port and handles HTTP requests for image generation.
package runner
import (
"context"
"encoding/json"
"flag"
"fmt"
"log/slog"
"net/http"
"os"
"os/signal"
"path/filepath"
"sync"
"syscall"
"time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/zimage"
)
// Request is the image generation request format
type Request struct {
Prompt string `json:"prompt"`
Width int32 `json:"width,omitempty"`
Height int32 `json:"height,omitempty"`
Steps int `json:"steps,omitempty"`
Seed int64 `json:"seed,omitempty"`
}
// Response is streamed back for each progress update
type Response struct {
Content string `json:"content"`
Done bool `json:"done"`
}
// Server holds the model and handles requests
type Server struct {
mu sync.Mutex
model *zimage.Model
modelName string
}
// Execute is the entry point for the image runner subprocess
func Execute(args []string) error {
fs := flag.NewFlagSet("image-runner", flag.ExitOnError)
modelName := fs.String("model", "", "path to image model")
port := fs.Int("port", 0, "port to listen on")
if err := fs.Parse(args); err != nil {
return err
}
if *modelName == "" {
return fmt.Errorf("--model is required")
}
if *port == 0 {
return fmt.Errorf("--port is required")
}
slog.Info("starting image runner", "model", *modelName, "port", *port)
// Check memory requirements before loading
requiredMemory := imagegen.EstimateVRAM(*modelName)
availableMemory := mlx.GetMemoryLimit()
if availableMemory > 0 && availableMemory < requiredMemory {
return fmt.Errorf("insufficient memory for image generation: need %d GB, have %d GB",
requiredMemory/(1024*1024*1024), availableMemory/(1024*1024*1024))
}
// Load model
model := &zimage.Model{}
if err := model.Load(*modelName); err != nil {
return fmt.Errorf("failed to load model: %w", err)
}
server := &Server{
model: model,
modelName: *modelName,
}
// Set up HTTP handlers
mux := http.NewServeMux()
mux.HandleFunc("/health", server.healthHandler)
mux.HandleFunc("/completion", server.completionHandler)
httpServer := &http.Server{
Addr: fmt.Sprintf("127.0.0.1:%d", *port),
Handler: mux,
}
// Handle shutdown
done := make(chan struct{})
go func() {
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
<-sigCh
slog.Info("shutting down image runner")
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
httpServer.Shutdown(ctx)
close(done)
}()
slog.Info("image runner listening", "addr", httpServer.Addr)
if err := httpServer.ListenAndServe(); err != http.ErrServerClosed {
return err
}
<-done
return nil
}
func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
}
func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
var req Request
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
// Serialize generation requests - MLX model may not handle concurrent generation
s.mu.Lock()
defer s.mu.Unlock()
// Apply defaults
if req.Width <= 0 {
req.Width = 1024
}
if req.Height <= 0 {
req.Height = 1024
}
if req.Steps <= 0 {
req.Steps = 9
}
if req.Seed <= 0 {
req.Seed = time.Now().UnixNano()
}
// Set up streaming response
w.Header().Set("Content-Type", "application/x-ndjson")
w.Header().Set("Transfer-Encoding", "chunked")
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "streaming not supported", http.StatusInternalServerError)
return
}
// Generate image
ctx := r.Context()
img, err := s.model.GenerateFromConfig(ctx, &zimage.GenerateConfig{
Prompt: req.Prompt,
Width: req.Width,
Height: req.Height,
Steps: req.Steps,
Seed: req.Seed,
Progress: func(step, total int) {
resp := Response{
Content: fmt.Sprintf("\rGenerating: step %d/%d", step, total),
Done: false,
}
data, _ := json.Marshal(resp)
w.Write(data)
w.Write([]byte("\n"))
flusher.Flush()
},
})
if err != nil {
// Don't send error for cancellation
if ctx.Err() != nil {
return
}
resp := Response{Content: fmt.Sprintf("error: %v", err), Done: true}
data, _ := json.Marshal(resp)
w.Write(data)
w.Write([]byte("\n"))
return
}
// Save image
outPath := filepath.Join(os.TempDir(), fmt.Sprintf("ollama-image-%d.png", time.Now().UnixNano()))
if err := imagegen.SaveImage(img, outPath); err != nil {
resp := Response{Content: fmt.Sprintf("error saving: %v", err), Done: true}
data, _ := json.Marshal(resp)
w.Write(data)
w.Write([]byte("\n"))
return
}
// Free the generated image array and clean up MLX state
img.Free()
mlx.ClearCache()
// Send final response
resp := Response{
Content: fmt.Sprintf("\n\nImage saved to: %s\n", outPath),
Done: true,
}
data, _ := json.Marshal(resp)
w.Write(data)
w.Write([]byte("\n"))
flusher.Flush()
}
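The wire format is one JSON object per line (NDJSON). A minimal client sketch (illustrative, not part of this commit; it would live alongside the mlx-tagged runner and assumes the runner is already listening on port):

package runner

import (
    "bytes"
    "encoding/json"
    "fmt"
    "net/http"
)

// generate is a hypothetical client for the /completion endpoint above.
func generate(port int, prompt string) error {
    body, err := json.Marshal(Request{Prompt: prompt, Steps: 9})
    if err != nil {
        return err
    }
    resp, err := http.Post(
        fmt.Sprintf("http://127.0.0.1:%d/completion", port),
        "application/json", bytes.NewReader(body))
    if err != nil {
        return err
    }
    defer resp.Body.Close()
    dec := json.NewDecoder(resp.Body) // reads one JSON object per line
    for {
        var r Response
        if err := dec.Decode(&r); err != nil {
            return err // io.EOF before Done=true means the stream was cut short
        }
        fmt.Print(r.Content)
        if r.Done {
            return nil
        }
    }
}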

View File

@@ -0,0 +1,10 @@
//go:build !mlx
package runner
import "errors"
// Execute returns an error when not built with MLX support.
func Execute(args []string) error {
return errors.New("image generation not available: build with mlx tag")
}

View File

@@ -0,0 +1,176 @@
package safetensors
import (
"bytes"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"os"
"sort"
)
// tensorInfo holds tensor metadata from safetensors headers.
// This avoids depending on safetensors.go which requires the mlx tag.
type tensorInfo struct {
Dtype string `json:"dtype"`
Shape []int32 `json:"shape"`
DataOffsets [2]int `json:"data_offsets"`
}
// TensorExtractor extracts individual tensors from a safetensors file.
// It provides io.Reader interfaces for each tensor's raw data, enabling
// streaming writes to blobs without loading entire tensors into memory.
type TensorExtractor struct {
file *os.File
dataOffset int64 // Start of tensor data region
header map[string]tensorInfo
}
// TensorData holds tensor metadata and a reader for its raw bytes.
type TensorData struct {
Name string
Dtype string
Shape []int32
Size int64
reader *io.SectionReader
}
// Reader returns an io.Reader for the tensor's raw bytes.
func (td *TensorData) Reader() io.Reader {
return td.reader
}
// SafetensorsReader returns a reader that outputs the tensor wrapped in
// minimal safetensors format. This allows using mlx_load_safetensors on
// individual tensor blobs for native zero-copy loading.
func (td *TensorData) SafetensorsReader() io.Reader {
// Build minimal safetensors header with tensor named "data"
header := map[string]tensorInfo{
"data": {
Dtype: td.Dtype,
Shape: td.Shape,
DataOffsets: [2]int{0, int(td.Size)},
},
}
headerJSON, _ := json.Marshal(header)
// Pad header to 8-byte alignment
padding := (8 - len(headerJSON)%8) % 8
headerJSON = append(headerJSON, bytes.Repeat([]byte(" "), padding)...)
// Build header with size prefix
headerBuf := new(bytes.Buffer)
binary.Write(headerBuf, binary.LittleEndian, uint64(len(headerJSON)))
headerBuf.Write(headerJSON)
// Return multi-reader: header + tensor data
td.reader.Seek(0, io.SeekStart)
return io.MultiReader(headerBuf, td.reader)
}
// SafetensorsSize returns the total size of the safetensors-wrapped tensor.
func (td *TensorData) SafetensorsSize() int64 {
header := map[string]tensorInfo{
"data": {
Dtype: td.Dtype,
Shape: td.Shape,
DataOffsets: [2]int{0, int(td.Size)},
},
}
headerJSON, _ := json.Marshal(header)
padding := (8 - len(headerJSON)%8) % 8
return 8 + int64(len(headerJSON)) + int64(padding) + td.Size
}
// OpenForExtraction opens a safetensors file for tensor extraction.
// The caller must call Close() when done.
func OpenForExtraction(path string) (*TensorExtractor, error) {
f, err := os.Open(path)
if err != nil {
return nil, fmt.Errorf("failed to open file: %w", err)
}
var headerSize uint64
if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
f.Close()
return nil, fmt.Errorf("failed to read header size: %w", err)
}
headerBytes := make([]byte, headerSize)
if _, err := io.ReadFull(f, headerBytes); err != nil {
f.Close()
return nil, fmt.Errorf("failed to read header: %w", err)
}
var header map[string]tensorInfo
if err := json.Unmarshal(headerBytes, &header); err != nil {
f.Close()
return nil, fmt.Errorf("failed to parse header: %w", err)
}
delete(header, "__metadata__")
return &TensorExtractor{
file: f,
dataOffset: 8 + int64(headerSize), // 8 bytes for header size + header content
header: header,
}, nil
}
// GetTensor returns tensor metadata and a reader for extracting a single tensor.
func (te *TensorExtractor) GetTensor(name string) (*TensorData, error) {
info, ok := te.header[name]
if !ok {
return nil, fmt.Errorf("tensor %q not found", name)
}
start := te.dataOffset + int64(info.DataOffsets[0])
size := int64(info.DataOffsets[1] - info.DataOffsets[0])
return &TensorData{
Name: name,
Dtype: info.Dtype,
Shape: info.Shape,
Size: size,
reader: io.NewSectionReader(te.file, start, size),
}, nil
}
// ListTensors returns all tensor names in sorted order.
func (te *TensorExtractor) ListTensors() []string {
names := make([]string, 0, len(te.header))
for name := range te.header {
names = append(names, name)
}
sort.Strings(names)
return names
}
// TensorCount returns the number of tensors in the file.
func (te *TensorExtractor) TensorCount() int {
return len(te.header)
}
// Close closes the underlying file.
func (te *TensorExtractor) Close() error {
return te.file.Close()
}
// ExtractAll returns TensorData for all tensors in the file.
// Each TensorData has a reader that reads from the original file.
// The caller must call Close() on the TensorExtractor when done.
func (te *TensorExtractor) ExtractAll() ([]*TensorData, error) {
names := te.ListTensors()
tensors := make([]*TensorData, 0, len(names))
for _, name := range names {
td, err := te.GetTensor(name)
if err != nil {
return nil, err
}
tensors = append(tensors, td)
}
return tensors, nil
}
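Because SafetensorsReader emits a complete single-tensor safetensors file, a wrapped blob can be re-opened with OpenForExtraction itself. A round-trip sketch (illustrative; the file paths are hypothetical):

package main

import (
    "io"
    "log"
    "os"

    "github.com/ollama/ollama/x/imagegen/safetensors"
)

func main() {
    ex, err := safetensors.OpenForExtraction("model.safetensors") // hypothetical input
    if err != nil {
        log.Fatal(err)
    }
    defer ex.Close()

    // Take the first tensor (assumes the file is non-empty).
    td, err := ex.GetTensor(ex.ListTensors()[0])
    if err != nil {
        log.Fatal(err)
    }

    // Write it as a minimal single-tensor safetensors blob...
    out, err := os.Create("blob.safetensors")
    if err != nil {
        log.Fatal(err)
    }
    if _, err := io.Copy(out, td.SafetensorsReader()); err != nil {
        log.Fatal(err)
    }
    out.Close()

    // ...and read it back: the wrapped tensor is always named "data".
    back, err := safetensors.OpenForExtraction("blob.safetensors")
    if err != nil {
        log.Fatal(err)
    }
    defer back.Close()
    data, err := back.GetTensor("data")
    if err != nil {
        log.Fatal(err)
    }
    log.Printf("dtype=%s shape=%v size=%d", data.Dtype, data.Shape, data.Size)
}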

View File

@@ -10,6 +10,14 @@ import (
"github.com/ollama/ollama/x/imagegen/mlx"
)
// WeightSource is an interface for loading weights.
// Both ModelWeights (directory-based) and ManifestWeights (blob-based) implement this.
type WeightSource interface {
GetTensor(name string) (*mlx.Array, error)
ListTensors() []string
HasTensor(name string) bool
}
// LoadModule loads weights into a struct using reflection and struct tags.
//
// Struct tags use the format: `weight:"path[,optional]"`
@@ -31,7 +39,7 @@ import (
// }
//
// err := LoadModule(&attn, weights, "model.layers.0")
func LoadModule(dst any, weights WeightSource, prefix string) error {
v := reflect.ValueOf(dst)
if v.Kind() != reflect.Ptr || v.IsNil() {
return fmt.Errorf("LoadModule: dst must be a non-nil pointer")
@@ -51,7 +59,7 @@ func LoadModule(dst any, weights *ModelWeights, prefix string) error {
}
// loadStruct recursively loads weights into a struct value.
func loadStruct(v reflect.Value, weights WeightSource, prefix string, errs *[]string, parentOptional bool) {
t := v.Type()
for i := 0; i < t.NumField(); i++ {
@@ -136,7 +144,7 @@ func loadStruct(v reflect.Value, weights *ModelWeights, prefix string, errs *[]s
}
// hasWeightsWithPrefix checks if any weights exist with the given prefix.
func hasWeightsWithPrefix(weights WeightSource, prefix string) bool {
for _, name := range weights.ListTensors() {
if strings.HasPrefix(name, prefix+".") || name == prefix {
return true
@@ -146,7 +154,7 @@ func hasWeightsWithPrefix(weights *ModelWeights, prefix string) bool {
}
// loadSlice loads weights into each element of a slice of struct pointers.
func loadSlice(v reflect.Value, weights WeightSource, prefix string, errs *[]string) {
elemStructType := v.Type().Elem().Elem()
for i := 0; i < v.Len(); i++ {

View File

@@ -118,6 +118,34 @@ func LoadModelWeights(dir string) (*ModelWeights, error) {
return mw, nil
}
// LoadModelWeightsFromPaths loads weights from specific safetensor file paths.
// Used for loading from blob storage where files are not in a directory.
func LoadModelWeightsFromPaths(paths []string) (*ModelWeights, error) {
mw := &ModelWeights{
tensorFiles: make(map[string]string),
tensorInfo: make(map[string]TensorInfo),
nativeCache: make(map[string]*mlx.SafetensorsFile),
}
for _, path := range paths {
header, err := parseSafetensorHeader(path)
if err != nil {
return nil, fmt.Errorf("failed to parse %s: %w", path, err)
}
for name, info := range header {
mw.tensorFiles[name] = path
mw.tensorInfo[name] = info
}
}
if len(mw.tensorFiles) == 0 {
return nil, fmt.Errorf("no tensors found in provided paths")
}
return mw, nil
}
// Load loads all tensors into cache with the specified dtype.
// If dtype is 0, tensors are loaded in their original dtype.
// Automatically uses streaming (memory-efficient) when dtype conversion is needed,

353
x/imagegen/server.go Normal file
View File

@@ -0,0 +1,353 @@
package imagegen
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"math/rand"
"net"
"net/http"
"os"
"os/exec"
"path/filepath"
"strconv"
"sync"
"time"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
)
// Server wraps an image generation subprocess to implement llm.LlamaServer.
type Server struct {
mu sync.Mutex
cmd *exec.Cmd
port int
modelName string
vramSize uint64
done chan error
client *http.Client
lastErr string // Last stderr line for error reporting
lastErrLock sync.Mutex
}
// completionRequest is sent to the subprocess
type completionRequest struct {
Prompt string `json:"prompt"`
Width int32 `json:"width,omitempty"`
Height int32 `json:"height,omitempty"`
Steps int `json:"steps,omitempty"`
Seed int64 `json:"seed,omitempty"`
}
// completionResponse is received from the subprocess
type completionResponse struct {
Content string `json:"content"`
Done bool `json:"done"`
}
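// The wire protocol between server and subprocess is newline-delimited JSON:
// Completion below scans one completionResponse per line until Done is set.
// A sketch of what a stream might look like (the exact content encoding is
// an assumption of this note, not confirmed by this commit):
//
//	{"content":"","done":false}
//	{"content":"<image payload>","done":true}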
// NewServer spawns a new image generation subprocess and waits until it's ready.
func NewServer(modelName string) (*Server, error) {
// Validate platform support before attempting to start
if err := CheckPlatformSupport(); err != nil {
return nil, err
}
// Find a free port
port := 0
if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
if l, err := net.ListenTCP("tcp", a); err == nil {
port = l.Addr().(*net.TCPAddr).Port
l.Close()
}
}
if port == 0 {
port = rand.Intn(65535-49152) + 49152
}
// Get the ollama executable path
exe, err := os.Executable()
if err != nil {
return nil, fmt.Errorf("unable to lookup executable path: %w", err)
}
if eval, err := filepath.EvalSymlinks(exe); err == nil {
exe = eval
}
// Spawn subprocess: ollama runner --image-engine --model <path> --port <port>
cmd := exec.Command(exe, "runner", "--image-engine", "--model", modelName, "--port", strconv.Itoa(port))
cmd.Env = os.Environ()
s := &Server{
cmd: cmd,
port: port,
modelName: modelName,
vramSize: EstimateVRAM(modelName),
done: make(chan error, 1),
client: &http.Client{Timeout: 10 * time.Minute},
}
// Forward subprocess stdout/stderr to server logs
stdout, err := cmd.StdoutPipe()
if err != nil {
return nil, fmt.Errorf("stdout pipe: %w", err)
}
stderr, err := cmd.StderrPipe()
if err != nil {
return nil, fmt.Errorf("stderr pipe: %w", err)
}
go func() {
scanner := bufio.NewScanner(stdout)
for scanner.Scan() {
slog.Info("image-runner", "msg", scanner.Text())
}
}()
go func() {
scanner := bufio.NewScanner(stderr)
for scanner.Scan() {
line := scanner.Text()
slog.Warn("image-runner", "msg", line)
// Capture last error line for better error reporting
s.lastErrLock.Lock()
s.lastErr = line
s.lastErrLock.Unlock()
}
}()
slog.Info("starting image runner subprocess", "model", modelName, "port", port)
if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("failed to start image runner: %w", err)
}
// Reap subprocess when it exits
go func() {
err := cmd.Wait()
s.done <- err
}()
// Wait for subprocess to be ready
if err := s.waitUntilRunning(); err != nil {
s.Close()
return nil, err
}
return s, nil
}
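// Example (hypothetical lifecycle sketch; the model name is illustrative):
//
//	srv, err := NewServer("z-image")
//	if err != nil {
//		return err
//	}
//	defer srv.Close()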
// ModelPath returns the path to the model.
func (s *Server) ModelPath() string {
return s.modelName
}
// Load is called by the scheduler after the server is created.
func (s *Server) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
return nil, nil
}
// Ping checks if the subprocess is healthy.
func (s *Server) Ping(ctx context.Context) error {
url := fmt.Sprintf("http://127.0.0.1:%d/health", s.port)
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return err
}
resp, err := s.client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("health check failed: %d", resp.StatusCode)
}
return nil
}
// waitUntilRunning waits for the subprocess to be ready.
func (s *Server) waitUntilRunning() error {
ctx := context.Background()
timeout := time.After(2 * time.Minute)
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
for {
select {
case err := <-s.done:
// Include last stderr line for better error context
s.lastErrLock.Lock()
lastErr := s.lastErr
s.lastErrLock.Unlock()
if lastErr != "" {
return fmt.Errorf("image runner failed: %s (exit: %v)", lastErr, err)
}
return fmt.Errorf("image runner exited unexpectedly: %w", err)
case <-timeout:
s.lastErrLock.Lock()
lastErr := s.lastErr
s.lastErrLock.Unlock()
if lastErr != "" {
return fmt.Errorf("timeout waiting for image runner: %s", lastErr)
}
return errors.New("timeout waiting for image runner to start")
case <-ticker.C:
if err := s.Ping(ctx); err == nil {
slog.Info("image runner is ready", "port", s.port)
return nil
}
}
}
}
// WaitUntilRunning implements the LlamaServer interface (no-op since NewServer waits).
func (s *Server) WaitUntilRunning(ctx context.Context) error {
return nil
}
// Completion generates an image from the prompt via the subprocess.
func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
// Build request
creq := completionRequest{
Prompt: req.Prompt,
Width: 1024,
Height: 1024,
Steps: 9,
Seed: time.Now().UnixNano(),
}
if req.Options != nil {
if req.Options.NumCtx > 0 && req.Options.NumCtx <= 4096 {
creq.Width = int32(req.Options.NumCtx)
}
if req.Options.NumGPU > 0 && req.Options.NumGPU <= 4096 {
creq.Height = int32(req.Options.NumGPU)
}
if req.Options.NumPredict > 0 && req.Options.NumPredict <= 100 {
creq.Steps = req.Options.NumPredict
}
if req.Options.Seed > 0 {
creq.Seed = int64(req.Options.Seed)
}
}
// Encode request body
body, err := json.Marshal(creq)
if err != nil {
return err
}
// Send request to subprocess
url := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
if err != nil {
return err
}
httpReq.Header.Set("Content-Type", "application/json")
resp, err := s.client.Do(httpReq)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("completion request failed: %d", resp.StatusCode)
}
// Stream responses
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
var cresp completionResponse
if err := json.Unmarshal(scanner.Bytes(), &cresp); err != nil {
continue
}
fn(llm.CompletionResponse{
Content: cresp.Content,
Done: cresp.Done,
})
if cresp.Done {
break
}
}
return scanner.Err()
}
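// Example (hypothetical usage sketch): callers reuse llm.CompletionRequest,
// so the image parameters travel through repurposed option fields (NumCtx
// carries width, NumGPU height, NumPredict the step count), subject to the
// bounds checks above:
//
//	var opts api.Options
//	opts.NumCtx = 768   // repurposed: image width
//	opts.NumGPU = 768   // repurposed: image height
//	opts.NumPredict = 9 // repurposed: diffusion steps
//	err := srv.Completion(ctx, llm.CompletionRequest{
//		Prompt:  "a watercolor fox",
//		Options: &opts,
//	}, func(r llm.CompletionResponse) { fmt.Print(r.Content) })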
// Close terminates the subprocess.
func (s *Server) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.cmd != nil && s.cmd.Process != nil {
slog.Info("stopping image runner subprocess", "pid", s.cmd.Process.Pid)
s.cmd.Process.Signal(os.Interrupt)
// Wait briefly for graceful shutdown
select {
case <-s.done:
case <-time.After(5 * time.Second):
s.cmd.Process.Kill()
}
s.cmd = nil
}
return nil
}
// VRAMSize returns the estimated VRAM usage.
func (s *Server) VRAMSize() uint64 {
return s.vramSize
}
// TotalSize returns the total memory usage.
func (s *Server) TotalSize() uint64 {
return s.vramSize
}
// VRAMByGPU returns VRAM usage for a specific GPU.
func (s *Server) VRAMByGPU(id ml.DeviceID) uint64 {
return s.vramSize
}
// Embedding is not supported for image generation models.
func (s *Server) Embedding(ctx context.Context, input string) ([]float32, int, error) {
return nil, 0, errors.New("embedding not supported for image generation models")
}
// Tokenize is not supported for image generation models.
func (s *Server) Tokenize(ctx context.Context, content string) ([]int, error) {
return nil, errors.New("tokenize not supported for image generation models")
}
// Detokenize is not supported for image generation models.
func (s *Server) Detokenize(ctx context.Context, tokens []int) (string, error) {
return "", errors.New("detokenize not supported for image generation models")
}
// Pid returns the subprocess PID.
func (s *Server) Pid() int {
s.mu.Lock()
defer s.mu.Unlock()
if s.cmd != nil && s.cmd.Process != nil {
return s.cmd.Process.Pid
}
return -1
}
// GetPort returns the subprocess port.
func (s *Server) GetPort() int {
return s.port
}
// GetDeviceInfos returns nil since we don't track GPU info.
func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
return nil
}
// HasExited returns true if the subprocess has exited.
func (s *Server) HasExited() bool {
select {
case err := <-s.done:
// Put the value back (the channel is buffered with capacity 1) so that
// Close and waitUntilRunning still observe the exit.
s.done <- err
return true
default:
return false
}
}
// Ensure Server implements llm.LlamaServer
var _ llm.LlamaServer = (*Server)(nil)

82
x/imagegen/server_test.go Normal file

@@ -0,0 +1,82 @@
package imagegen
import (
"runtime"
"testing"
)
// TestPlatformSupport verifies platform validation works correctly.
func TestPlatformSupport(t *testing.T) {
err := CheckPlatformSupport()
switch runtime.GOOS {
case "darwin":
if runtime.GOARCH == "arm64" {
// Apple Silicon should be supported
if err != nil {
t.Errorf("Expected nil error on darwin/arm64, got: %v", err)
}
} else {
// Intel Mac should fail
if err == nil {
t.Error("Expected error on darwin/amd64 (Intel), got nil")
}
if err != nil && err.Error() == "" {
t.Error("Expected meaningful error message for unsupported platform")
}
}
case "linux", "windows":
// Linux/Windows are allowed (CUDA support checked at runtime)
if err != nil {
t.Errorf("Expected nil error on %s, got: %v", runtime.GOOS, err)
}
default:
// Other platforms should fail
if err == nil {
t.Errorf("Expected error on unsupported platform %s, got nil", runtime.GOOS)
}
}
}
// TestMemoryRequirementsError verifies memory check returns clear error.
func TestMemoryRequirementsError(t *testing.T) {
// Test with insufficient memory
err := CheckMemoryRequirements("test-model", 8*GB)
if err == nil {
t.Error("Expected error for insufficient memory (8GB < 21GB default)")
}
// Test with sufficient memory
err = CheckMemoryRequirements("test-model", 32*GB)
if err != nil {
t.Errorf("Expected no error for sufficient memory (32GB), got: %v", err)
}
}
// TestEstimateVRAMReturnsReasonableDefaults verifies VRAM estimates are sensible.
func TestEstimateVRAMReturnsReasonableDefaults(t *testing.T) {
// Unknown model should return default (21GB)
vram := EstimateVRAM("unknown-model")
if vram < 10*GB || vram > 100*GB {
t.Errorf("VRAM estimate %d GB is outside reasonable range (10-100 GB)", vram/GB)
}
// Verify known pipeline estimates exist and are reasonable
for name, estimate := range modelVRAMEstimates {
if estimate < 10*GB {
t.Errorf("VRAM estimate for %s (%d GB) is suspiciously low", name, estimate/GB)
}
if estimate > 200*GB {
t.Errorf("VRAM estimate for %s (%d GB) is suspiciously high", name, estimate/GB)
}
}
}
// TestServerInterfaceCompliance verifies Server implements llm.LlamaServer.
// This is a compile-time check but we document it as a test.
func TestServerInterfaceCompliance(t *testing.T) {
// The var _ llm.LlamaServer = (*Server)(nil) line in server.go
// ensures compile-time interface compliance.
// This test documents that requirement.
t.Log("Server implements llm.LlamaServer interface (compile-time checked)")
}


@@ -256,6 +256,164 @@ func rewritePatternForRE2(pattern string) string {
return pattern
}
// LoadFromBytes loads a tokenizer from tokenizer.json bytes.
// This is useful when loading from blob storage where the file content is already in memory.
// Note: This won't load special token config from companion files. Use LoadFromBytesWithConfig
// to provide tokenizer_config.json data for proper PAD/EOS token loading.
func LoadFromBytes(data []byte) (*Tokenizer, error) {
return loadFromTokenizerJSON(data, "")
}
// TokenizerConfig holds optional configuration data that can be passed to LoadFromBytesWithConfig.
type TokenizerConfig struct {
TokenizerConfigJSON []byte // tokenizer_config.json content
GenerationConfigJSON []byte // generation_config.json content
SpecialTokensMapJSON []byte // special_tokens_map.json content
ConfigJSON []byte // config.json content
}
// LoadFromBytesWithConfig loads a tokenizer from tokenizer.json bytes with additional config files.
// This is useful when loading from blob storage where companion config files are also blobs.
func LoadFromBytesWithConfig(data []byte, config *TokenizerConfig) (*Tokenizer, error) {
t, err := loadFromTokenizerJSON(data, "")
if err != nil {
return nil, err
}
if config == nil {
return t, nil
}
// Apply special token configs from provided data
loadSpecialTokenConfigFromBytes(t, config)
return t, nil
}
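// Example (hypothetical usage sketch, assuming the caller has already read
// the JSON blobs into memory):
//
//	tok, err := LoadFromBytesWithConfig(tokenizerJSON, &TokenizerConfig{
//		TokenizerConfigJSON:  tokenizerConfigJSON,
//		GenerationConfigJSON: generationConfigJSON,
//	})
//	if err != nil {
//		return err
//	}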
// loadSpecialTokenConfigFromBytes loads special token configuration from byte slices.
func loadSpecialTokenConfigFromBytes(t *Tokenizer, config *TokenizerConfig) {
// Helper to parse eos_token_id which can be int or []int
parseTokenIDs := func(v interface{}) []int32 {
switch val := v.(type) {
case float64:
return []int32{int32(val)}
case []interface{}:
ids := make([]int32, 0, len(val))
for _, id := range val {
if f, ok := id.(float64); ok {
ids = append(ids, int32(f))
}
}
return ids
}
return nil
}
// Priority 1: generation_config.json
if len(config.GenerationConfigJSON) > 0 {
var genConfig struct {
EOSTokenID interface{} `json:"eos_token_id"`
BOSTokenID interface{} `json:"bos_token_id"`
}
if err := json.Unmarshal(config.GenerationConfigJSON, &genConfig); err == nil {
if ids := parseTokenIDs(genConfig.EOSTokenID); len(ids) > 0 {
t.vocab.EOS = ids
}
if ids := parseTokenIDs(genConfig.BOSTokenID); len(ids) > 0 {
t.vocab.BOS = ids[0]
}
}
}
// Priority 2: config.json
if len(config.ConfigJSON) > 0 && (len(t.vocab.EOS) == 0 || t.vocab.BOS < 0) {
var modelConfig struct {
EOSTokenID interface{} `json:"eos_token_id"`
BOSTokenID interface{} `json:"bos_token_id"`
}
if err := json.Unmarshal(config.ConfigJSON, &modelConfig); err == nil {
if len(t.vocab.EOS) == 0 {
if ids := parseTokenIDs(modelConfig.EOSTokenID); len(ids) > 0 {
t.vocab.EOS = ids
}
}
if t.vocab.BOS < 0 {
if ids := parseTokenIDs(modelConfig.BOSTokenID); len(ids) > 0 {
t.vocab.BOS = ids[0]
}
}
}
}
// Priority 3: tokenizer_config.json
if len(config.TokenizerConfigJSON) > 0 {
var tokConfig struct {
BOSToken interface{} `json:"bos_token"`
EOSToken interface{} `json:"eos_token"`
PADToken interface{} `json:"pad_token"`
AddBOSToken *bool `json:"add_bos_token"`
AddEOSToken *bool `json:"add_eos_token"`
}
if err := json.Unmarshal(config.TokenizerConfigJSON, &tokConfig); err == nil {
if t.vocab.BOS < 0 {
if bosStr := extractTokenString(tokConfig.BOSToken); bosStr != "" {
if id, ok := t.specialTokens[bosStr]; ok {
t.vocab.BOS = id
}
}
}
if len(t.vocab.EOS) == 0 {
if eosStr := extractTokenString(tokConfig.EOSToken); eosStr != "" {
if id, ok := t.specialTokens[eosStr]; ok {
t.vocab.EOS = []int32{id}
}
}
}
if t.vocab.PAD < 0 {
if padStr := extractTokenString(tokConfig.PADToken); padStr != "" {
if id, ok := t.specialTokens[padStr]; ok {
t.vocab.PAD = id
}
}
}
if tokConfig.AddBOSToken != nil {
t.vocab.AddBOS = *tokConfig.AddBOSToken
}
if tokConfig.AddEOSToken != nil {
t.vocab.AddEOS = *tokConfig.AddEOSToken
}
}
}
// Priority 4: special_tokens_map.json
if len(config.SpecialTokensMapJSON) > 0 {
var tokensMap map[string]interface{}
if err := json.Unmarshal(config.SpecialTokensMapJSON, &tokensMap); err == nil {
if t.vocab.BOS < 0 {
if bosStr := extractTokenString(tokensMap["bos_token"]); bosStr != "" {
if id, ok := t.specialTokens[bosStr]; ok {
t.vocab.BOS = id
}
}
}
if len(t.vocab.EOS) == 0 {
if eosStr := extractTokenString(tokensMap["eos_token"]); eosStr != "" {
if id, ok := t.specialTokens[eosStr]; ok {
t.vocab.EOS = []int32{id}
}
}
}
if t.vocab.PAD < 0 {
if padStr := extractTokenString(tokensMap["pad_token"]); padStr != "" {
if id, ok := t.specialTokens[padStr]; ok {
t.vocab.PAD = id
}
}
}
}
}
}
// Load loads a tokenizer from a path which can be:
// - A tokenizer.json file
// - A directory containing tokenizer.json or vocab.json + merges.txt


@@ -0,0 +1,320 @@
package transfer
import (
"cmp"
"context"
"crypto/sha256"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"os"
"path/filepath"
"slices"
"sync"
"sync/atomic"
"time"
"golang.org/x/sync/errgroup"
"golang.org/x/sync/semaphore"
)
var (
errStalled = errors.New("download stalled")
errSlow = errors.New("download too slow")
)
type downloader struct {
client *http.Client
baseURL string
destDir string
repository string // Repository path for blob URLs (e.g., "library/model")
token *string
getToken func(context.Context, AuthChallenge) (string, error)
userAgent string
stallTimeout time.Duration
progress *progressTracker
speeds *speedTracker
logger *slog.Logger
}
func download(ctx context.Context, opts DownloadOptions) error {
if len(opts.Blobs) == 0 {
return nil
}
// Filter existing
var blobs []Blob
var total int64
for _, b := range opts.Blobs {
if fi, _ := os.Stat(filepath.Join(opts.DestDir, digestToPath(b.Digest))); fi != nil && fi.Size() == b.Size {
if opts.Logger != nil {
opts.Logger.Debug("blob already exists", "digest", b.Digest, "size", b.Size)
}
continue
}
blobs = append(blobs, b)
total += b.Size
}
if len(blobs) == 0 {
return nil
}
token := opts.Token
d := &downloader{
client: cmp.Or(opts.Client, defaultClient),
baseURL: opts.BaseURL,
destDir: opts.DestDir,
repository: cmp.Or(opts.Repository, "library/_"),
token: &token,
getToken: opts.GetToken,
userAgent: cmp.Or(opts.UserAgent, defaultUserAgent),
stallTimeout: cmp.Or(opts.StallTimeout, defaultStallTimeout),
progress: newProgressTracker(total, opts.Progress),
speeds: &speedTracker{},
logger: opts.Logger,
}
concurrency := cmp.Or(opts.Concurrency, DefaultDownloadConcurrency)
sem := semaphore.NewWeighted(int64(concurrency))
g, ctx := errgroup.WithContext(ctx)
for _, blob := range blobs {
g.Go(func() error {
if err := sem.Acquire(ctx, 1); err != nil {
return err
}
defer sem.Release(1)
return d.download(ctx, blob)
})
}
return g.Wait()
}
func (d *downloader) download(ctx context.Context, blob Blob) error {
var lastErr error
var slowRetries int
attempt := 0
for attempt < maxRetries {
if attempt > 0 {
if err := backoff(ctx, attempt, time.Second<<uint(attempt-1)); err != nil {
return err
}
}
start := time.Now()
n, err := d.downloadOnce(ctx, blob)
if err == nil {
if s := time.Since(start).Seconds(); s > 0 {
d.speeds.record(float64(blob.Size) / s)
}
return nil
}
d.progress.add(-n) // rollback
switch {
case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
return err
case errors.Is(err, errStalled):
// Don't count stall retries against limit
case errors.Is(err, errSlow):
if slowRetries++; slowRetries >= 3 {
attempt++ // Only count after 3 slow retries
}
default:
attempt++
}
lastErr = err
}
return fmt.Errorf("%w: %v", errMaxRetriesExceeded, lastErr)
}
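// Retry accounting in the loop above: stalls never consume an attempt (the
// connection is assumed transient and is simply restarted), the first three
// slow restarts are free and each one after that counts, and any other error
// consumes an attempt until maxRetries is reached.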
func (d *downloader) downloadOnce(ctx context.Context, blob Blob) (int64, error) {
if d.logger != nil {
d.logger.Debug("downloading blob", "digest", blob.Digest, "size", blob.Size)
}
baseURL, _ := url.Parse(d.baseURL)
u, err := d.resolve(ctx, fmt.Sprintf("%s/v2/%s/blobs/%s", d.baseURL, d.repository, blob.Digest))
if err != nil {
return 0, err
}
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
req.Header.Set("User-Agent", d.userAgent)
// Add auth only for same-host (not CDN)
if u.Host == baseURL.Host && *d.token != "" {
req.Header.Set("Authorization", "Bearer "+*d.token)
}
resp, err := d.client.Do(req)
if err != nil {
return 0, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return 0, fmt.Errorf("status %d", resp.StatusCode)
}
return d.save(ctx, blob, resp.Body)
}
func (d *downloader) save(ctx context.Context, blob Blob, r io.Reader) (int64, error) {
dest := filepath.Join(d.destDir, digestToPath(blob.Digest))
tmp := dest + ".tmp"
os.MkdirAll(filepath.Dir(dest), 0o755)
f, err := os.Create(tmp)
if err != nil {
return 0, err
}
defer f.Close()
setSparse(f)
h := sha256.New()
n, err := d.copy(ctx, f, r, h)
if err != nil {
os.Remove(tmp)
return n, err
}
if err := f.Close(); err != nil {
os.Remove(tmp)
return n, err
}
if got := fmt.Sprintf("sha256:%x", h.Sum(nil)); got != blob.Digest {
os.Remove(tmp)
return n, fmt.Errorf("digest mismatch")
}
if n != blob.Size {
os.Remove(tmp)
return n, fmt.Errorf("size mismatch")
}
return n, os.Rename(tmp, dest)
}
func (d *downloader) copy(ctx context.Context, dst io.Writer, src io.Reader, h io.Writer) (int64, error) {
var n int64
var copied atomic.Int64 // bytes copied so far, shared with the watchdog goroutine
var lastRead atomic.Int64
lastRead.Store(time.Now().UnixNano())
start := time.Now()
ctx, cancel := context.WithCancelCause(ctx)
defer cancel(nil)
go func() {
tick := time.NewTicker(time.Second)
defer tick.Stop()
for {
select {
case <-ctx.Done():
return
case <-tick.C:
if time.Since(time.Unix(0, lastRead.Load())) > d.stallTimeout {
cancel(errStalled)
return
}
if e := time.Since(start); e > 5*time.Second {
if m := d.speeds.median(); m > 0 && float64(copied.Load())/e.Seconds() < m*0.1 {
cancel(errSlow)
return
}
}
}
}
}()
buf := make([]byte, 32*1024)
for {
if err := ctx.Err(); err != nil {
if c := context.Cause(ctx); c != nil {
return n, c
}
return n, err
}
nr, err := src.Read(buf)
if nr > 0 {
lastRead.Store(time.Now().UnixNano())
if _, werr := dst.Write(buf[:nr]); werr != nil {
return n, werr
}
h.Write(buf[:nr])
d.progress.add(int64(nr))
n += int64(nr)
copied.Store(n)
}
if err == io.EOF {
return n, nil
}
if err != nil {
return n, err
}
}
}
func (d *downloader) resolve(ctx context.Context, rawURL string) (*url.URL, error) {
u, _ := url.Parse(rawURL)
for range 10 {
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
req.Header.Set("User-Agent", d.userAgent)
if *d.token != "" {
req.Header.Set("Authorization", "Bearer "+*d.token)
}
resp, err := d.client.Do(req)
if err != nil {
return nil, err
}
resp.Body.Close()
switch resp.StatusCode {
case http.StatusOK:
return u, nil
case http.StatusUnauthorized:
if d.getToken == nil {
return nil, fmt.Errorf("unauthorized")
}
ch := parseAuthChallenge(resp.Header.Get("WWW-Authenticate"))
if *d.token, err = d.getToken(ctx, ch); err != nil {
return nil, err
}
case http.StatusTemporaryRedirect, http.StatusFound, http.StatusMovedPermanently:
loc, err := resp.Location()
if err != nil {
return nil, err
}
if loc.Host != u.Host {
return loc, nil
}
u = loc
default:
return nil, fmt.Errorf("status %d", resp.StatusCode)
}
}
return nil, fmt.Errorf("too many redirects")
}
type speedTracker struct {
mu sync.Mutex
speeds []float64
}
func (s *speedTracker) record(v float64) {
s.mu.Lock()
s.speeds = append(s.speeds, v)
if len(s.speeds) > 30 {
s.speeds = s.speeds[1:]
}
s.mu.Unlock()
}
func (s *speedTracker) median() float64 {
s.mu.Lock()
defer s.mu.Unlock()
if len(s.speeds) < 5 {
return 0
}
sorted := make([]float64, len(s.speeds))
copy(sorted, s.speeds)
slices.Sort(sorted)
return sorted[len(sorted)/2]
}
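// Worked example: with recorded speeds of [2, 8, 9, 10, 11] MB/s, median()
// returns 9, so the watchdog in (*downloader).copy cancels any transfer
// running below 0.9 MB/s (10% of the median) once its 5-second grace period
// has passed. With fewer than five samples median() returns 0, which
// disables the check entirely.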
const defaultStallTimeout = 10 * time.Second


@@ -0,0 +1,12 @@
//go:build !windows
package transfer
import "os"
// setSparse is a no-op on non-Windows platforms.
// On Windows, this sets the FSCTL_SET_SPARSE attribute which allows the OS
// to not allocate disk blocks for zero-filled regions. This is useful for
// partial downloads where not all data has been written yet. On Unix-like
// systems, filesystems typically handle this automatically (sparse by default).
func setSparse(_ *os.File) {}


@@ -0,0 +1,31 @@
//go:build windows
package transfer
import (
"os"
"golang.org/x/sys/windows"
)
// setSparse sets the FSCTL_SET_SPARSE attribute on Windows files.
// This allows the OS to not allocate disk blocks for zero-filled regions,
// which is useful for large files that may not be fully written (e.g., partial
// downloads). Without this, Windows may pre-allocate disk space for the full
// file size even if most of it is zeros.
//
// Note: Errors are intentionally ignored because:
// 1. The file will still work correctly without sparse support
// 2. Not all Windows filesystems support sparse files (e.g., FAT32)
// 3. This is an optimization, not a requirement
func setSparse(file *os.File) {
var bytesReturned uint32
_ = windows.DeviceIoControl(
windows.Handle(file.Fd()),
windows.FSCTL_SET_SPARSE,
nil, 0,
nil, 0,
&bytesReturned,
nil,
)
}


@@ -0,0 +1,218 @@
// Package transfer provides minimal, fast blob transfer for tensor-based models.
//
// This package is in x/ because the tensor model storage format is under development.
// It provides optimized transfer for models with many small blobs (tensor models)
// rather than few large blobs (typical LLMs).
//
// TODO (jmorganca): Integrate into server/download.go and server/upload.go when stable.
//
// Design Philosophy:
// This package is intentionally simpler than the main server's download/upload code.
// Key simplifications for many-small-blob workloads:
//
// - Whole-blob transfers: No part-based chunking. Each blob downloads/uploads as one unit.
// - No resume: If a transfer fails, it restarts from scratch (fine for small blobs).
// - Inline hashing: SHA256 computed during streaming, not asynchronously after parts complete.
// - Stall and speed detection: Cancels on no data (stall) or speed below 10% of median.
//
// For large models (multi-GB), use the server's download/upload code which has:
// - Part-based transfers with 64MB chunks
// - Resumable downloads with JSON state files
// - Async streamHasher that hashes from OS page cache as parts complete
// - Speed tracking with rolling median to detect and restart slow parts
package transfer
import (
"context"
"errors"
"log/slog"
"math/rand/v2"
"net/http"
"strings"
"sync/atomic"
"time"
)
// Blob represents a content-addressed blob to transfer.
type Blob struct {
Digest string // sha256:...
Size int64
// From enables cross-repository blob mounting (upload only).
// When set, the upload will first attempt to mount the blob from this source
// repository instead of uploading the data. This is a Docker Registry v2 API
// feature that avoids re-uploading blobs that already exist elsewhere.
//
// Example: From="library/source-model" will add ?mount=<digest>&from=library/source-model
// to the POST /blobs/uploads/ request. If the registry returns 201 Created,
// the blob was mounted successfully and no upload is needed.
//
// See: https://distribution.github.io/distribution/spec/api/#cross-repository-blob-mount
From string
}
// DownloadOptions configures a parallel download operation.
type DownloadOptions struct {
Blobs []Blob // Blobs to download
BaseURL string // Registry base URL
DestDir string // Destination directory for blobs
Repository string // Repository path for blob URLs (e.g., "library/model")
Concurrency int // Max parallel downloads (default 64)
Progress func(completed, total int64) // Progress callback (optional)
Client *http.Client // HTTP client (optional, uses default)
Token string // Auth token (optional)
GetToken func(ctx context.Context, challenge AuthChallenge) (string, error) // Token refresh callback
Logger *slog.Logger // Optional structured logger
UserAgent string // User-Agent header (optional, has default)
StallTimeout time.Duration // Timeout for stall detection (default 10s)
}
// UploadOptions configures a parallel upload operation.
type UploadOptions struct {
Blobs []Blob // Blobs to upload
BaseURL string // Registry base URL
SrcDir string // Source directory containing blobs
Concurrency int // Max parallel uploads (default 32)
Progress func(completed, total int64) // Progress callback (optional)
Client *http.Client // HTTP client (optional, uses default)
Token string // Auth token (optional)
GetToken func(ctx context.Context, challenge AuthChallenge) (string, error) // Token refresh callback
Logger *slog.Logger // Optional structured logger
UserAgent string // User-Agent header (optional, has default)
// Manifest fields (optional) - if set, manifest is pushed after all blobs complete
Manifest []byte // Raw manifest JSON to push
ManifestRef string // Tag or digest for the manifest (e.g., "latest", "sha256:...")
Repository string // Repository path for manifest URL (e.g., "library/model")
}
// AuthChallenge represents a parsed WWW-Authenticate challenge.
type AuthChallenge struct {
Realm string
Service string
Scope string
}
// Default concurrency limits and settings
const (
DefaultDownloadConcurrency = 64
DefaultUploadConcurrency = 32
maxRetries = 6
defaultUserAgent = "ollama-transfer/1.0"
)
var errMaxRetriesExceeded = errors.New("max retries exceeded")
// defaultClient is a shared HTTP client with connection pooling.
var defaultClient = &http.Client{
Transport: &http.Transport{
MaxIdleConns: 100,
MaxIdleConnsPerHost: 100,
IdleConnTimeout: 90 * time.Second,
},
Timeout: 5 * time.Minute,
// Don't follow redirects automatically - we handle them manually
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
// progressTracker aggregates progress across concurrent operations.
type progressTracker struct {
completed atomic.Int64
total int64
callback func(completed, total int64)
}
func newProgressTracker(total int64, callback func(completed, total int64)) *progressTracker {
return &progressTracker{
total: total,
callback: callback,
}
}
func (p *progressTracker) add(n int64) {
if p == nil || p.callback == nil {
return
}
completed := p.completed.Add(n)
p.callback(completed, p.total)
}
// Download downloads blobs in parallel with streaming hash verification.
func Download(ctx context.Context, opts DownloadOptions) error {
return download(ctx, opts)
}
// Upload uploads blobs in parallel.
func Upload(ctx context.Context, opts UploadOptions) error {
return upload(ctx, opts)
}
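// Example (hypothetical usage sketch; the registry URL, repository, and blob
// values are illustrative):
//
//	err := Download(ctx, DownloadOptions{
//		BaseURL:    "https://registry.ollama.ai",
//		Repository: "library/some-model",
//		DestDir:    "/models/blobs",
//		Blobs:      []Blob{{Digest: "sha256:<digest>", Size: 1 << 20}},
//		Progress: func(completed, total int64) {
//			fmt.Printf("\r%d/%d bytes", completed, total)
//		},
//	})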
// digestToPath converts sha256:abc123 to sha256-abc123
func digestToPath(digest string) string {
if len(digest) > 7 && digest[6] == ':' {
return digest[:6] + "-" + digest[7:]
}
return digest
}
// parseAuthChallenge parses a WWW-Authenticate header value.
// Example: Bearer realm="https://auth.example.com",service="registry",scope="repository:foo:pull"
func parseAuthChallenge(header string) AuthChallenge {
header = strings.TrimPrefix(header, "Bearer ")
getValue := func(key string) string {
startIdx := strings.Index(header, key+"=")
if startIdx == -1 {
return ""
}
startIdx += len(key) + 1
if startIdx >= len(header) {
return ""
}
// Handle quoted values
if header[startIdx] == '"' {
startIdx++
endIdx := strings.Index(header[startIdx:], "\"")
if endIdx == -1 {
return header[startIdx:]
}
return header[startIdx : startIdx+endIdx]
}
// Unquoted value - ends at comma or end of string
endIdx := strings.Index(header[startIdx:], ",")
if endIdx == -1 {
return header[startIdx:]
}
return header[startIdx : startIdx+endIdx]
}
return AuthChallenge{
Realm: getValue("realm"),
Service: getValue("service"),
Scope: getValue("scope"),
}
}
// backoff sleeps before the next retry attempt using quadratic growth with
// jitter, capped at maxBackoff, and returns early if ctx is done.
func backoff(ctx context.Context, attempt int, maxBackoff time.Duration) error {
if ctx.Err() != nil {
return ctx.Err()
}
// n^2 backoff with jitter
d := min(time.Duration(attempt*attempt)*10*time.Millisecond, maxBackoff)
d = time.Duration(float64(d) * (rand.Float64() + 0.5))
t := time.NewTimer(d)
defer t.Stop()
select {
case <-ctx.Done():
return ctx.Err()
case <-t.C:
return nil
}
}
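// Worked example: attempt 3 yields min(3*3*10ms, maxBackoff) = 90ms, and the
// uniform jitter factor in [0.5, 1.5) puts the actual sleep between 45ms and
// 135ms. Callers pass maxBackoff = time.Second<<(attempt-1), so the quadratic
// ramp dominates early attempts and the cap takes over later.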

File diff suppressed because it is too large


@@ -0,0 +1,346 @@
package transfer
import (
"bytes"
"cmp"
"context"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"os"
"path/filepath"
"time"
"golang.org/x/sync/errgroup"
"golang.org/x/sync/semaphore"
)
type uploader struct {
client *http.Client
baseURL string
srcDir string
repository string // Repository path for blob URLs (e.g., "library/model")
token *string
getToken func(context.Context, AuthChallenge) (string, error)
userAgent string
progress *progressTracker
logger *slog.Logger
}
func upload(ctx context.Context, opts UploadOptions) error {
if len(opts.Blobs) == 0 && len(opts.Manifest) == 0 {
return nil
}
token := opts.Token
u := &uploader{
client: cmp.Or(opts.Client, defaultClient),
baseURL: opts.BaseURL,
srcDir: opts.SrcDir,
repository: cmp.Or(opts.Repository, "library/_"),
token: &token,
getToken: opts.GetToken,
userAgent: cmp.Or(opts.UserAgent, defaultUserAgent),
logger: opts.Logger,
}
if len(opts.Blobs) > 0 {
// Phase 1: Fast parallel HEAD checks to find which blobs need uploading
needsUpload := make([]bool, len(opts.Blobs))
{
sem := semaphore.NewWeighted(128) // High concurrency for HEAD checks
g, gctx := errgroup.WithContext(ctx)
for i, blob := range opts.Blobs {
g.Go(func() error {
if err := sem.Acquire(gctx, 1); err != nil {
return err
}
defer sem.Release(1)
exists, err := u.exists(gctx, blob)
if err != nil {
return err
}
if !exists {
needsUpload[i] = true
} else if u.logger != nil {
u.logger.Debug("blob exists", "digest", blob.Digest)
}
return nil
})
}
if err := g.Wait(); err != nil {
return err
}
}
// Filter to only blobs that need uploading
var toUpload []Blob
var total int64
for i, blob := range opts.Blobs {
if needsUpload[i] {
toUpload = append(toUpload, blob)
total += blob.Size
}
}
if len(toUpload) == 0 {
if u.logger != nil {
u.logger.Debug("all blobs exist, nothing to upload")
}
} else {
// Phase 2: Upload blobs that don't exist
u.progress = newProgressTracker(total, opts.Progress)
concurrency := cmp.Or(opts.Concurrency, DefaultUploadConcurrency)
sem := semaphore.NewWeighted(int64(concurrency))
g, gctx := errgroup.WithContext(ctx)
for _, blob := range toUpload {
g.Go(func() error {
if err := sem.Acquire(gctx, 1); err != nil {
return err
}
defer sem.Release(1)
return u.upload(gctx, blob)
})
}
if err := g.Wait(); err != nil {
return err
}
}
}
if len(opts.Manifest) > 0 && opts.ManifestRef != "" && opts.Repository != "" {
return u.pushManifest(ctx, opts.Repository, opts.ManifestRef, opts.Manifest)
}
return nil
}
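// Example (hypothetical usage sketch; repository, tag, and blob values are
// illustrative) showing blobs pushed and a manifest tagged in one call:
//
//	err := Upload(ctx, UploadOptions{
//		BaseURL:     "https://registry.ollama.ai",
//		Repository:  "library/some-model",
//		SrcDir:      "/models/blobs",
//		Blobs:       []Blob{{Digest: "sha256:<digest>", Size: 1 << 20}},
//		Manifest:    manifestJSON,
//		ManifestRef: "latest",
//	})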
func (u *uploader) upload(ctx context.Context, blob Blob) error {
var lastErr error
var n int64
for attempt := range maxRetries {
if attempt > 0 {
if err := backoff(ctx, attempt, time.Second<<uint(attempt-1)); err != nil {
return err
}
}
var err error
n, err = u.uploadOnce(ctx, blob)
if err == nil {
return nil
}
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return err
}
u.progress.add(-n)
lastErr = err
}
return fmt.Errorf("%w: %v", errMaxRetriesExceeded, lastErr)
}
func (u *uploader) uploadOnce(ctx context.Context, blob Blob) (int64, error) {
if u.logger != nil {
u.logger.Debug("uploading blob", "digest", blob.Digest, "size", blob.Size)
}
// Init upload
uploadURL, err := u.initUpload(ctx, blob)
if err != nil {
return 0, err
}
// Open file
f, err := os.Open(filepath.Join(u.srcDir, digestToPath(blob.Digest)))
if err != nil {
return 0, err
}
defer f.Close()
// PUT blob
return u.put(ctx, uploadURL, f, blob.Size)
}
func (u *uploader) exists(ctx context.Context, blob Blob) (bool, error) {
req, _ := http.NewRequestWithContext(ctx, http.MethodHead, fmt.Sprintf("%s/v2/%s/blobs/%s", u.baseURL, u.repository, blob.Digest), nil)
req.Header.Set("User-Agent", u.userAgent)
if *u.token != "" {
req.Header.Set("Authorization", "Bearer "+*u.token)
}
resp, err := u.client.Do(req)
if err != nil {
return false, err
}
resp.Body.Close()
if resp.StatusCode == http.StatusUnauthorized && u.getToken != nil {
ch := parseAuthChallenge(resp.Header.Get("WWW-Authenticate"))
if *u.token, err = u.getToken(ctx, ch); err != nil {
return false, err
}
return u.exists(ctx, blob)
}
return resp.StatusCode == http.StatusOK, nil
}
func (u *uploader) initUpload(ctx context.Context, blob Blob) (string, error) {
endpoint, _ := url.Parse(fmt.Sprintf("%s/v2/%s/blobs/uploads/", u.baseURL, u.repository))
q := endpoint.Query()
q.Set("digest", blob.Digest)
endpoint.RawQuery = q.Encode()
req, _ := http.NewRequestWithContext(ctx, http.MethodPost, endpoint.String(), nil)
req.Header.Set("User-Agent", u.userAgent)
if *u.token != "" {
req.Header.Set("Authorization", "Bearer "+*u.token)
}
resp, err := u.client.Do(req)
if err != nil {
return "", err
}
resp.Body.Close()
if resp.StatusCode == http.StatusUnauthorized && u.getToken != nil {
ch := parseAuthChallenge(resp.Header.Get("WWW-Authenticate"))
if *u.token, err = u.getToken(ctx, ch); err != nil {
return "", err
}
return u.initUpload(ctx, blob)
}
if resp.StatusCode != http.StatusAccepted {
return "", fmt.Errorf("init: status %d", resp.StatusCode)
}
loc := resp.Header.Get("Docker-Upload-Location")
if loc == "" {
loc = resp.Header.Get("Location")
}
if loc == "" {
return "", fmt.Errorf("no upload location")
}
locURL, _ := url.Parse(loc)
if !locURL.IsAbs() {
base, _ := url.Parse(u.baseURL)
locURL = base.ResolveReference(locURL)
}
q = locURL.Query()
q.Set("digest", blob.Digest)
locURL.RawQuery = q.Encode()
return locURL.String(), nil
}
func (u *uploader) put(ctx context.Context, uploadURL string, f *os.File, size int64) (int64, error) {
pr := &progressReader{reader: f, tracker: u.progress}
req, _ := http.NewRequestWithContext(ctx, http.MethodPut, uploadURL, pr)
req.ContentLength = size
req.Header.Set("Content-Type", "application/octet-stream")
req.Header.Set("User-Agent", u.userAgent)
if *u.token != "" {
req.Header.Set("Authorization", "Bearer "+*u.token)
}
resp, err := u.client.Do(req)
if err != nil {
return pr.n, err
}
defer resp.Body.Close()
// Handle auth retry
if resp.StatusCode == http.StatusUnauthorized && u.getToken != nil {
ch := parseAuthChallenge(resp.Header.Get("WWW-Authenticate"))
if *u.token, err = u.getToken(ctx, ch); err != nil {
return pr.n, err
}
f.Seek(0, 0)
u.progress.add(-pr.n)
return u.put(ctx, uploadURL, f, size)
}
// Handle redirect to CDN
if resp.StatusCode == http.StatusTemporaryRedirect {
loc, err := resp.Location()
if err != nil {
return pr.n, err
}
f.Seek(0, 0)
u.progress.add(-pr.n)
pr2 := &progressReader{reader: f, tracker: u.progress}
req2, _ := http.NewRequestWithContext(ctx, http.MethodPut, loc.String(), pr2)
req2.ContentLength = size
req2.Header.Set("Content-Type", "application/octet-stream")
req2.Header.Set("User-Agent", u.userAgent)
resp2, err := u.client.Do(req2)
if err != nil {
return pr2.n, err
}
defer resp2.Body.Close()
if resp2.StatusCode != http.StatusCreated && resp2.StatusCode != http.StatusAccepted {
body, _ := io.ReadAll(resp2.Body)
return pr2.n, fmt.Errorf("status %d: %s", resp2.StatusCode, body)
}
return pr2.n, nil
}
if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusAccepted {
body, _ := io.ReadAll(resp.Body)
return pr.n, fmt.Errorf("status %d: %s", resp.StatusCode, body)
}
return pr.n, nil
}
func (u *uploader) pushManifest(ctx context.Context, repo, ref string, manifest []byte) error {
req, _ := http.NewRequestWithContext(ctx, http.MethodPut, fmt.Sprintf("%s/v2/%s/manifests/%s", u.baseURL, repo, ref), bytes.NewReader(manifest))
req.Header.Set("Content-Type", "application/vnd.docker.distribution.manifest.v2+json")
req.Header.Set("User-Agent", u.userAgent)
if *u.token != "" {
req.Header.Set("Authorization", "Bearer "+*u.token)
}
resp, err := u.client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusUnauthorized && u.getToken != nil {
ch := parseAuthChallenge(resp.Header.Get("WWW-Authenticate"))
if *u.token, err = u.getToken(ctx, ch); err != nil {
return err
}
return u.pushManifest(ctx, repo, ref, manifest)
}
if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("status %d: %s", resp.StatusCode, body)
}
return nil
}
type progressReader struct {
reader io.Reader
tracker *progressTracker
n int64
}
func (r *progressReader) Read(p []byte) (int, error) {
n, err := r.reader.Read(p)
if n > 0 {
r.n += int64(n)
r.tracker.add(int64(n))
}
return n, err
}

116
x/imagegen/weights.go Normal file

@@ -0,0 +1,116 @@
//go:build mlx
package imagegen
import (
"fmt"
"sort"
"strings"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// ManifestWeights provides fast weight loading from tensor blobs.
// Uses native mmap loading with synthetic safetensors headers for zero-copy.
type ManifestWeights struct {
manifest *ModelManifest
component string
tensors map[string]ManifestLayer // name -> layer
cache map[string]*mlx.Array // name -> loaded array
nativeCache []*mlx.SafetensorsFile // keep native handles alive
}
// LoadWeightsFromManifest creates a weight loader for a component from manifest storage.
func LoadWeightsFromManifest(manifest *ModelManifest, component string) (*ManifestWeights, error) {
layers := manifest.GetTensorLayers(component)
if len(layers) == 0 {
return nil, fmt.Errorf("no tensor layers found for component %q", component)
}
// Strip component prefix from tensor names for model loading
// e.g., "text_encoder/model.embed_tokens.weight" -> "model.embed_tokens.weight"
prefix := component + "/"
tensors := make(map[string]ManifestLayer, len(layers))
for _, layer := range layers {
tensorName := strings.TrimPrefix(layer.Name, prefix)
tensors[tensorName] = layer
}
return &ManifestWeights{
manifest: manifest,
component: component,
tensors: tensors,
cache: make(map[string]*mlx.Array),
}, nil
}
// Load loads all tensor blobs using native mmap (zero-copy).
// Blobs are stored in safetensors format for native mlx_load_safetensors mmap.
// If dtype is non-zero, tensors are converted to the specified dtype.
func (mw *ManifestWeights) Load(dtype mlx.Dtype) error {
for name, layer := range mw.tensors {
path := mw.manifest.BlobPath(layer.Digest)
// Load blob as safetensors (native mmap, zero-copy)
sf, err := mlx.LoadSafetensorsNative(path)
if err != nil {
return fmt.Errorf("load %s: %w", name, err)
}
// Blob contains single tensor named "data"
arr := sf.Get("data")
if arr == nil {
sf.Free()
return fmt.Errorf("tensor 'data' not found in blob for %s", name)
}
// Convert dtype if needed
if dtype != 0 && arr.Dtype() != dtype {
arr = mlx.AsType(arr, dtype)
}
// ALWAYS make a contiguous copy to ensure independence from mmap
arr = mlx.Contiguous(arr)
mlx.Eval(arr)
mw.cache[name] = arr
sf.Free() // Safe to free - arr is now an independent copy
}
return nil
}
// GetTensor returns a tensor from cache. Call Load() first.
func (mw *ManifestWeights) GetTensor(name string) (*mlx.Array, error) {
if mw.cache == nil {
return nil, fmt.Errorf("cache not initialized: call Load() first")
}
arr, ok := mw.cache[name]
if !ok {
return nil, fmt.Errorf("tensor %q not found", name)
}
return arr, nil
}
// ListTensors returns all tensor names in sorted order.
func (mw *ManifestWeights) ListTensors() []string {
names := make([]string, 0, len(mw.tensors))
for name := range mw.tensors {
names = append(names, name)
}
sort.Strings(names)
return names
}
// HasTensor checks if a tensor exists.
func (mw *ManifestWeights) HasTensor(name string) bool {
_, ok := mw.tensors[name]
return ok
}
// ReleaseAll frees all native handles and clears the tensor cache.
func (mw *ManifestWeights) ReleaseAll() {
for _, sf := range mw.nativeCache {
sf.Free()
}
mw.nativeCache = nil
mw.cache = nil
}
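// Example (hypothetical usage sketch; the component and tensor names are
// illustrative):
//
//	mw, err := LoadWeightsFromManifest(manifest, "text_encoder")
//	if err != nil {
//		return err
//	}
//	if err := mw.Load(0); err != nil { // 0 keeps each tensor's original dtype
//		return err
//	}
//	defer mw.ReleaseAll()
//	arr, err := mw.GetTensor("model.embed_tokens.weight")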