Files
ollama/x/imagegen/models/zimage/zimage.go
Daniel Hiltgen 33ee7168ba Add experimental MLX backend and engine with imagegen support (#13648)
* WIP - MLX backend with gemma3

* MLX: add cmake and go tag build toggles

To build the new MLX backend code:
  cmake --preset MLX
  cmake --build --preset MLX --parallel
  cmake --install build --component MLX
  go build -tags mlx .

Note: the main.go entrypoint for the MLX engine will change in a follow up commit.

* add experimental image generation runtime

* add experimental image generation runtime

* MLX: wire up cuda build for linux

* MLX: get dependencies correct and dedup

This is still too large for a unified github artifact, but is now "correct" for the mlx_cuda_v13
directory.

* fix relative link bug in dedup

* Add darwin build and readme

* add go build tag for mlx dependent code and wire up build_darwin.sh

* lint cleanup

* macos: build mlx for x86

This will be CPU only.

* cuda build instructions and fix drift from mlx bump

* stale comment

* Delete agent helper doc

* Clean up readme.md

* Revise README for tokenizer clarity and details

Updated README to clarify tokenizer functionality and removed correctness section.

---------

Co-authored-by: jmorganca <jmorganca@gmail.com>
2026-01-08 16:18:59 -08:00

364 lines
11 KiB
Go

//go:build mlx
// Package zimage implements the Z-Image diffusion transformer model.
package zimage
import (
"context"
"fmt"
"path/filepath"
"time"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// GenerateConfig bundles every option understood by the image generation
// pipeline. Fields left at their zero value fall back to the default noted on
// the field when the config is passed to the generator.
type GenerateConfig struct {
	// Prompt is the text the image is generated from.
	Prompt string
	// NegativePrompt enables classifier-free guidance (CFG) when non-empty.
	NegativePrompt string
	// CFGScale is the guidance strength; it only matters when NegativePrompt
	// is set. Values <= 0 select the default of 4.0.
	CFGScale float32
	// Width and Height are the output image size. Values <= 0 select the
	// default of 1024.
	Width, Height int32
	// Steps is the number of denoising steps. Values <= 0 select the default
	// of 9 (turbo).
	Steps int
	// Seed is the random seed used to initialize the latent noise.
	Seed int64
	// Progress, when non-nil, is invoked once per denoising step.
	Progress ProgressFunc
	// CapturePath, when non-empty, requests a GPU capture of the second
	// denoising step at this path (debug aid).
	CapturePath string

	// Layer caching options (speedup via shallow layer reuse).

	// LayerCache turns layer caching on. Disabled by default.
	LayerCache bool
	// CacheInterval is the number of steps between cache refreshes and
	// CacheLayers is the number of shallow layers to cache. They are only
	// consulted when LayerCache is true; values <= 0 select the defaults of
	// 3 and 15 respectively.
	CacheInterval, CacheLayers int
}
// ProgressFunc is the callback signature used to report generation progress.
// step is the 1-based index of the denoising step being reported and
// totalSteps is the total number of steps in the run.
type ProgressFunc func(step int, totalSteps int)
// Model represents a Z-Image diffusion model.
//
// It groups the components used by the generation pipeline. All fields are
// populated by Load; the generation methods read them without reloading.
type Model struct {
// ModelPath is the directory passed to Load.
ModelPath string
// Tokenizer is loaded from <ModelPath>/tokenizer/tokenizer.json and is
// handed to the text encoder when prompts are encoded.
Tokenizer *tokenizer.Tokenizer
// TextEncoder is loaded from <ModelPath>/text_encoder and turns prompts
// into embeddings.
TextEncoder *Qwen3TextEncoder
// Transformer is loaded from <ModelPath>/transformer and is run once (or
// twice with CFG) per denoising step.
Transformer *Transformer
// VAEDecoder is loaded from <ModelPath>/vae and decodes the final latents.
VAEDecoder *VAEDecoder
}
// Load reads all Z-Image components from the sub-directories of modelPath
// (tokenizer, text_encoder, transformer, vae) and stores them on m.
//
// When a GPU is available it is selected as the default device and compilation
// is enabled before anything is loaded. Progress and memory statistics are
// printed to stdout. Loading stops at the first component that fails, and the
// returned error is wrapped with that component's name.
func (m *Model) Load(modelPath string) error {
	const bytesPerGB = 1024 * 1024 * 1024

	// printMemory reports active and peak memory after a component is loaded.
	printMemory := func() {
		active := float64(mlx.MetalGetActiveMemory()) / bytesPerGB
		peak := float64(mlx.MetalGetPeakMemory()) / bytesPerGB
		fmt.Printf(" (%.1f GB, peak %.1f GB)\n", active, peak)
	}

	fmt.Println("Loading Z-Image model...")
	begin := time.Now()

	if mlx.GPUIsAvailable() {
		mlx.SetDefaultDeviceGPU()
		mlx.EnableCompile()
	}

	m.ModelPath = modelPath

	// Tokenizer
	fmt.Print(" Loading tokenizer... ")
	tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer", "tokenizer.json"))
	if err != nil {
		return fmt.Errorf("tokenizer: %w", err)
	}
	m.Tokenizer = tok
	fmt.Println("✓")

	// Text encoder. The field is assigned before loading, so it stays set
	// even when Load fails. Eval forces the collected arrays to be evaluated
	// before memory is reported.
	m.TextEncoder = &Qwen3TextEncoder{}
	if err = m.TextEncoder.Load(filepath.Join(modelPath, "text_encoder")); err != nil {
		return fmt.Errorf("text encoder: %w", err)
	}
	mlx.Eval(mlx.Collect(m.TextEncoder)...)
	printMemory()

	// Transformer
	m.Transformer = &Transformer{}
	if err = m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil {
		return fmt.Errorf("transformer: %w", err)
	}
	mlx.Eval(mlx.Collect(m.Transformer)...)
	printMemory()

	// VAE decoder
	m.VAEDecoder = &VAEDecoder{}
	if err = m.VAEDecoder.Load(filepath.Join(modelPath, "vae")); err != nil {
		return fmt.Errorf("VAE decoder: %w", err)
	}
	mlx.Eval(mlx.Collect(m.VAEDecoder)...)
	printMemory()

	activeGB := float64(mlx.MetalGetActiveMemory()) / bytesPerGB
	fmt.Printf(" Loaded in %.2fs (%.1f GB VRAM)\n", time.Since(begin).Seconds(), activeGB)
	return nil
}
// Generate renders an image for prompt at the given size, step count and seed.
// It is the simplest entry point: no negative prompt and no progress callback.
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
	// A nil callback leaves GenerateConfig.Progress unset, which is exactly
	// the configuration this method has always produced.
	return m.GenerateWithProgress(prompt, width, height, steps, seed, nil)
}
// GenerateWithProgress renders an image like Generate and additionally reports
// each denoising step through progress (which may be nil).
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) {
	// An empty negative prompt disables CFG and a zero scale is the field's
	// zero value, so this yields the same config as setting neither.
	const (
		noNegativePrompt = ""
		unsetCFGScale    = 0
	)
	return m.GenerateWithCFG(prompt, noNegativePrompt, width, height, steps, seed, unsetCFGScale, progress)
}
// GenerateWithCFG renders an image using classifier-free guidance: the
// prediction for negativePrompt is pushed away from by a factor of cfgScale.
// progress may be nil.
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress ProgressFunc) (*mlx.Array, error) {
	opts := GenerateConfig{
		Prompt:         prompt,
		NegativePrompt: negativePrompt,
		CFGScale:       cfgScale,
		Width:          width,
		Height:         height,
		Steps:          steps,
		Seed:           seed,
		Progress:       progress,
	}
	return m.GenerateFromConfig(&opts)
}
// GenerateFromConfig generates an image using the unified config struct and
// prints a one-line timing summary to stdout on success.
//
// cfg must not be nil; a nil config yields an error instead of a panic. The
// pipeline writes its defaults back into cfg, so the step count and CFG scale
// in the summary are the effective values rather than the caller's zeros.
func (m *Model) GenerateFromConfig(cfg *GenerateConfig) (*mlx.Array, error) {
	if cfg == nil {
		return nil, fmt.Errorf("generate config is nil")
	}

	start := time.Now()
	result, err := m.generate(cfg)
	if err != nil {
		return nil, err
	}

	if cfg.NegativePrompt != "" {
		fmt.Printf("Generated with CFG (scale=%.1f) in %.2fs (%d steps)\n", cfg.CFGScale, time.Since(start).Seconds(), cfg.Steps)
	} else {
		fmt.Printf("Generated in %.2fs (%d steps)\n", time.Since(start).Seconds(), cfg.Steps)
	}
	return result, nil
}
// GenerateImage implements model.ImageModel interface.
//
// The denoising pipeline does not observe ctx while it runs, so cancellation
// is only honored up front: if ctx is already cancelled or past its deadline,
// no work is started and ctx.Err() is returned. A nil ctx is tolerated and
// treated as "never cancelled".
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
	if ctx != nil {
		if err := ctx.Err(); err != nil {
			return nil, err
		}
	}
	return m.Generate(prompt, width, height, steps, seed)
}
// generate is the internal denoising pipeline.
//
// It fills defaults into cfg (the caller's struct is modified), encodes the
// prompt(s), initializes noise latents from cfg.Seed, runs cfg.Steps
// transformer passes with the flow-match Euler scheduler, and finally decodes
// the latents with the VAE decoder. When cfg.NegativePrompt is non-empty each
// step runs the transformer twice and combines the results with
// classifier-free guidance.
//
// An error is returned, before any work is done, when cfg is nil, when the
// model components have not been loaded, when the transformer patch size is
// not positive, or when the image size is not a multiple of 8*PatchSize in
// both dimensions (the latent and patch grids would otherwise be truncated by
// integer division and stop matching each other).
func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
	if cfg == nil {
		return nil, fmt.Errorf("generate config is nil")
	}
	if m.Tokenizer == nil || m.TextEncoder == nil || m.Transformer == nil || m.VAEDecoder == nil {
		return nil, fmt.Errorf("model is not loaded")
	}

	// Apply defaults. These are written back into cfg on purpose:
	// GenerateFromConfig reports the effective values afterwards.
	if cfg.Width <= 0 {
		cfg.Width = 1024
	}
	if cfg.Height <= 0 {
		cfg.Height = 1024
	}
	if cfg.Steps <= 0 {
		cfg.Steps = 9 // Turbo default
	}
	if cfg.CFGScale <= 0 {
		cfg.CFGScale = 4.0
	}
	if cfg.LayerCache {
		if cfg.CacheInterval <= 0 {
			cfg.CacheInterval = 3
		}
		if cfg.CacheLayers <= 0 {
			cfg.CacheLayers = 15 // Half of 30 layers
		}
	}

	useCFG := cfg.NegativePrompt != ""
	tcfg := m.Transformer.TransformerConfig

	// Latents are 1/8 of the pixel size and are consumed in
	// PatchSize x PatchSize patches, so both dimensions must divide evenly.
	if tcfg.PatchSize <= 0 {
		return nil, fmt.Errorf("invalid transformer patch size %d", tcfg.PatchSize)
	}
	align := 8 * tcfg.PatchSize
	if cfg.Width%align != 0 || cfg.Height%align != 0 {
		return nil, fmt.Errorf("image size %dx%d must be a multiple of %d in both dimensions", cfg.Width, cfg.Height, align)
	}

	latentH := cfg.Height / 8
	latentW := cfg.Width / 8
	hTok := latentH / tcfg.PatchSize
	wTok := latentW / tcfg.PatchSize

	// Text encoding with padding to multiple of 32.
	// NOTE(review): the second return value of EncodePrompt is discarded here;
	// confirm it is not an error that should be propagated.
	var posEmb, negEmb *mlx.Array
	posEmb, _ = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt, 512)
	if useCFG {
		negEmb, _ = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.NegativePrompt, 512)
	}

	// Pad both embeddings to the same length, rounded up to a multiple of 32,
	// so the positive and negative passes see identical sequence lengths.
	maxLen := posEmb.Shape()[1]
	if useCFG && negEmb.Shape()[1] > maxLen {
		maxLen = negEmb.Shape()[1]
	}
	if pad := (32 - (maxLen % 32)) % 32; pad > 0 {
		maxLen += pad
	}

	posEmb = padToLength(posEmb, maxLen)
	if useCFG {
		negEmb = padToLength(negEmb, maxLen)
		mlx.Keep(posEmb, negEmb)
		mlx.Eval(posEmb, negEmb)
	} else {
		mlx.Keep(posEmb)
		mlx.Eval(posEmb)
	}

	// Scheduler; the shift depends on the number of image tokens.
	scheduler := NewFlowMatchEulerScheduler(DefaultFlowMatchSchedulerConfig())
	scheduler.SetTimestepsWithMu(cfg.Steps, CalculateShift(hTok*wTok))

	// Init latents [B, C, H, W]
	latents := scheduler.InitNoise([]int32{1, tcfg.InChannels, latentH, latentW}, cfg.Seed)
	mlx.Eval(latents)

	// RoPE cache, sized for the image token grid and the padded text length.
	ropeCache := m.Transformer.PrepareRoPECache(hTok, wTok, posEmb.Shape()[1])
	mlx.Keep(ropeCache.ImgCos, ropeCache.ImgSin, ropeCache.CapCos, ropeCache.CapSin,
		ropeCache.UnifiedCos, ropeCache.UnifiedSin)
	mlx.Eval(ropeCache.UnifiedCos)

	// Step cache for shallow layer reuse (DeepCache/Learning-to-Cache style)
	var stepCache *cache.StepCache
	if cfg.LayerCache {
		stepCache = cache.NewStepCache(cfg.CacheLayers)
		fmt.Printf(" Layer caching enabled: %d layers, refresh every %d steps\n",
			cfg.CacheLayers, cfg.CacheInterval)
	}

	// forward runs one transformer pass for the given text embedding, going
	// through the step cache when layer caching is enabled.
	forward := func(patches, timestep, emb *mlx.Array, step int) *mlx.Array {
		if stepCache != nil {
			return m.Transformer.ForwardWithCache(patches, timestep, emb, ropeCache,
				stepCache, step, cfg.CacheInterval)
		}
		return m.Transformer.Forward(patches, timestep, emb, ropeCache)
	}

	// Denoising loop
	for i := 0; i < cfg.Steps; i++ {
		stepStart := time.Now()
		if cfg.Progress != nil {
			cfg.Progress(i+1, cfg.Steps)
		}

		// GPU capture on step 2 if requested
		if cfg.CapturePath != "" && i == 1 {
			mlx.MetalStartCapture(cfg.CapturePath)
		}

		tCurr := scheduler.Timesteps[i]
		// The transformer is fed 1 - t rather than the scheduler's t.
		timestep := mlx.ToBFloat16(mlx.NewArray([]float32{1.0 - tCurr}, []int32{1}))

		patches := PatchifyLatents(latents, tcfg.PatchSize)

		output := forward(patches, timestep, posEmb, i)
		if useCFG {
			// Note: with layer caching the positive and negative passes
			// share one cache. This is approximate but fast - the negative
			// prompt uses the same cached shallow layers.
			negOutput := forward(patches, timestep, negEmb, i)

			// CFG: neg + scale * (pos - neg)
			diff := mlx.Sub(output, negOutput)
			scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
			output = mlx.Add(negOutput, scaledDiff)
		}

		noisePred := UnpatchifyLatents(output, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
		// The sign of the prediction is flipped before it is handed to the
		// scheduler.
		noisePred = mlx.Neg(noisePred)

		oldLatents := latents
		latents = scheduler.Step(noisePred, latents, i)

		// Keep latents and any cached arrays
		if stepCache != nil {
			mlx.Keep(stepCache.Arrays()...)
		}
		mlx.Eval(latents)
		oldLatents.Free()

		if cfg.CapturePath != "" && i == 1 {
			mlx.MetalStopCapture()
		}

		activeMem := float64(mlx.MetalGetActiveMemory()) / (1024 * 1024 * 1024)
		peakMem := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
		fmt.Printf(" Step %d/%d: t=%.4f (%.2fs) [%.1f GB active, %.1f GB peak]\n",
			i+1, cfg.Steps, tCurr, time.Since(stepStart).Seconds(), activeMem, peakMem)
	}

	// Free denoising temporaries before VAE decode
	posEmb.Free()
	if negEmb != nil {
		negEmb.Free()
	}
	ropeCache.ImgCos.Free()
	ropeCache.ImgSin.Free()
	ropeCache.CapCos.Free()
	ropeCache.CapSin.Free()
	ropeCache.UnifiedCos.Free()
	ropeCache.UnifiedSin.Free()
	if stepCache != nil {
		stepCache.Free()
	}

	// VAE decode
	decoded := m.VAEDecoder.Decode(latents)
	latents.Free()

	return decoded, nil
}
// padToLength extends x along axis 1 until it is targetLen long by appending
// copies of its last entry on that axis. x is returned unchanged when it is
// already at least targetLen long. x is indexed as a rank-3 array.
func padToLength(x *mlx.Array, targetLen int32) *mlx.Array {
	dims := x.Shape()
	seqLen := dims[1]
	if targetLen <= seqLen {
		return x
	}

	// Take x[:, seqLen-1:seqLen, :] and repeat it once per missing position.
	from := []int32{0, seqLen - 1, 0}
	to := []int32{dims[0], seqLen, dims[2]}
	last := mlx.Slice(x, from, to)
	filler := mlx.Tile(last, []int32{1, targetLen - seqLen, 1})

	return mlx.Concatenate([]*mlx.Array{x, filler}, 1)
}
// CalculateShift computes the mu shift value for dynamic scheduling.
//
// The shift is a linear function of the image sequence length that passes
// through (256, 0.5) and (4096, 1.15); lengths outside that range are
// extrapolated along the same line.
func CalculateShift(imgSeqLen int32) float32 {
	var (
		minLen   float32 = 256
		maxLen   float32 = 4096
		minShift float32 = 0.5
		maxShift float32 = 1.15
	)

	slope := (maxShift - minShift) / (maxLen - minLen)
	intercept := minShift - slope*minLen

	return float32(imgSeqLen)*slope + intercept
}