Files
ollama/x/imagegen/models/zimage/vae.go
Daniel Hiltgen 33ee7168ba Add experimental MLX backend and engine with imagegen support (#13648)
* WIP - MLX backend with gemma3

* MLX: add cmake and go tag build toggles

To build the new MLX backend code:
  cmake --preset MLX
  cmake --build --preset MLX --parallel
  cmake --install build --component MLX
  go build -tags mlx .

Note: the main.go entrypoint for the MLX engine will change in a follow up commit.

* add experimental image generation runtime

* add experimental image generation runtime

* MLX: wire up cuda build for linux

* MLX: get dependencies correct and dedup

This is still too large for a unified github artifact, but is now "correct" for the mlx_cuda_v13
directory.

* fix relative link bug in dedup

* Add darwin build and readme

* add go build tag for mlx dependent code and wire up build_darwin.sh

* lint cleanup

* macos: build mlx for x86

This will be CPU only.

* cuda build instructions and fix drift from mlx bump

* stale comment

* Delete agent helper doc

* Clean up readme.md

* Revise README for tokenizer clarity and details

Updated README to clarify tokenizer functionality and removed correctness section.

---------

Co-authored-by: jmorganca <jmorganca@gmail.com>
2026-01-08 16:18:59 -08:00

653 lines
15 KiB
Go

//go:build mlx
package zimage
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// VAEConfig holds VAE decoder configuration
type VAEConfig struct {
	InChannels       int32   `json:"in_channels"`        // config value; not referenced by the decoder code in this file
	OutChannels      int32   `json:"out_channels"`       // config value; not referenced by the decoder code in this file
	LatentChannels   int32   `json:"latent_channels"`    // channels of the latent tensor passed to Decode
	BlockOutChannels []int32 `json:"block_out_channels"` // per-resolution channel widths; length = number of up blocks loaded
	LayersPerBlock   int32   `json:"layers_per_block"`   // ResNet layers per encoder block; the decoder loads this + 1 per up block
	NormNumGroups    int32   `json:"norm_num_groups"`    // group count used for every GroupNorm layer
	ScalingFactor    float32 `json:"scaling_factor"`     // latents are divided by this at the start of Decode
	ShiftFactor      float32 `json:"shift_factor"`       // added to the scaled latents at the start of Decode
}
// loadVAEConfig reads the JSON file at path and decodes it into a VAEConfig.
func loadVAEConfig(path string) (*VAEConfig, error) {
	raw, err := os.ReadFile(path)
	if err != nil {
		return nil, fmt.Errorf("read config: %w", err)
	}
	cfg := &VAEConfig{}
	if err := json.Unmarshal(raw, cfg); err != nil {
		return nil, fmt.Errorf("parse config: %w", err)
	}
	return cfg, nil
}
// GroupNormLayer implements group normalization
type GroupNormLayer struct {
	Weight    *mlx.Array // per-channel scale, shape [C]; may be nil (no scaling applied)
	Bias      *mlx.Array // per-channel shift, shape [C]; may be nil (no shift applied)
	NumGroups int32      // number of channel groups; must evenly divide C
	Eps       float32    // added to the variance before the square root for numerical stability
}
// NewGroupNorm creates a group norm layer with the default epsilon of 1e-5.
// weight and bias are the per-channel affine parameters; either may be nil.
func NewGroupNorm(weight, bias *mlx.Array, numGroups int32) *GroupNormLayer {
	layer := &GroupNormLayer{NumGroups: numGroups, Eps: 1e-5}
	layer.Weight, layer.Bias = weight, bias
	return layer
}
// Forward applies group normalization to x, which is laid out [B, C, H, W].
// Channels are split into NumGroups groups; each group is normalized by its
// own mean and variance, then the optional per-channel affine is applied.
func (gn *GroupNormLayer) Forward(x *mlx.Array) *mlx.Array {
	dims := x.Shape()
	batch, channels, height, width := dims[0], dims[1], dims[2], dims[3]
	// View the channels as groups: [B, C, H, W] -> [B, G, C/G, H, W].
	perGroup := channels / gn.NumGroups
	grouped := mlx.Reshape(x, batch, gn.NumGroups, perGroup, height, width)
	// Mean over the within-group channel axis and both spatial axes (keepdims).
	mean := mlx.Mean(grouped, 2, true)
	mean = mlx.Mean(mean, 3, true)
	mean = mlx.Mean(mean, 4, true)
	centered := mlx.Sub(grouped, mean)
	// Variance as E[(x - mean)^2] over the same axes.
	variance := mlx.Mean(mlx.Square(centered), 2, true)
	variance = mlx.Mean(variance, 3, true)
	variance = mlx.Mean(variance, 4, true)
	normed := mlx.Div(centered, mlx.Sqrt(mlx.AddScalar(variance, gn.Eps)))
	// Restore [B, C, H, W] and apply the optional per-channel scale/shift.
	normed = mlx.Reshape(normed, batch, channels, height, width)
	if gn.Weight != nil {
		normed = mlx.Mul(normed, mlx.Reshape(gn.Weight, 1, channels, 1, 1))
	}
	if gn.Bias != nil {
		normed = mlx.Add(normed, mlx.Reshape(gn.Bias, 1, channels, 1, 1))
	}
	return normed
}
// Conv2D represents a 2D convolution layer
// MLX uses NHWC format, but we store weights in OHWI format for MLX conv
type Conv2D struct {
	Weight  *mlx.Array // [out_channels, kH, kW, in_channels] (OHWI for MLX)
	Bias    *mlx.Array // [out_channels]; may be nil (no bias added)
	Stride  int32      // passed straight through to mlx.Conv2d
	Padding int32      // passed straight through to mlx.Conv2d
}
// NewConv2D creates a Conv2D layer.
//
// weight arrives in PyTorch's OIHW layout [out_channels, in_channels, kH, kW]
// and is transposed once here into the OHWI layout
// [out_channels, kH, kW, in_channels] that the MLX convolution expects.
func NewConv2D(weight, bias *mlx.Array, stride, padding int32) *Conv2D {
	conv := &Conv2D{Bias: bias, Stride: stride, Padding: padding}
	// [O, I, H, W] -> [O, H, W, I]
	conv.Weight = mlx.Transpose(weight, 0, 2, 3, 1)
	return conv
}
// Forward applies the convolution to x.
//
// x is NCHW; MLX convolutions operate on NHWC, so the input is transposed
// before the call and the result is transposed back afterwards. The bias,
// if present, is broadcast-added across the channel dimension.
func (conv *Conv2D) Forward(x *mlx.Array) *mlx.Array {
	// [N, C, H, W] -> [N, H, W, C]
	nhwc := mlx.Transpose(x, 0, 2, 3, 1)
	nhwc = mlx.Conv2d(nhwc, conv.Weight, conv.Stride, conv.Padding)
	// [N, H, W, C] -> [N, C, H, W]
	result := mlx.Transpose(nhwc, 0, 3, 1, 2)
	if conv.Bias == nil {
		return result
	}
	return mlx.Add(result, mlx.Reshape(conv.Bias, 1, conv.Bias.Dim(0), 1, 1))
}
// ResnetBlock2D implements a ResNet block for VAE
//
// Forward runs norm1 -> silu -> conv1 -> norm2 -> silu -> conv2 and adds a
// residual connection from the input (through ConvShortcut when present).
type ResnetBlock2D struct {
	Norm1        *GroupNormLayer
	Conv1        *Conv2D
	Norm2        *GroupNormLayer
	Conv2        *Conv2D
	ConvShortcut *Conv2D // nil if in_channels == out_channels
}
// NewResnetBlock2D loads a ResNet block's weights from the given prefix.
// The 1x1 conv_shortcut projection is loaded only when its weight tensor
// exists (i.e. when the block changes channel count).
func NewResnetBlock2D(weights *safetensors.ModelWeights, prefix string, numGroups int32) (*ResnetBlock2D, error) {
	// tensor fetches prefix+suffix, remembering the first failure so the
	// remaining lookups become no-ops.
	var loadErr error
	tensor := func(suffix string) *mlx.Array {
		if loadErr != nil {
			return nil
		}
		t, err := weights.GetTensor(prefix + suffix)
		if err != nil {
			loadErr = err
		}
		return t
	}
	norm1W, norm1B := tensor(".norm1.weight"), tensor(".norm1.bias")
	conv1W, conv1B := tensor(".conv1.weight"), tensor(".conv1.bias")
	norm2W, norm2B := tensor(".norm2.weight"), tensor(".norm2.bias")
	conv2W, conv2B := tensor(".conv2.weight"), tensor(".conv2.bias")
	if loadErr != nil {
		return nil, loadErr
	}
	block := &ResnetBlock2D{
		Norm1: NewGroupNorm(norm1W, norm1B, numGroups),
		Conv1: NewConv2D(conv1W, conv1B, 1, 1),
		Norm2: NewGroupNorm(norm2W, norm2B, numGroups),
		Conv2: NewConv2D(conv2W, conv2B, 1, 1),
	}
	if weights.HasTensor(prefix + ".conv_shortcut.weight") {
		shortW, shortB := tensor(".conv_shortcut.weight"), tensor(".conv_shortcut.bias")
		if loadErr != nil {
			return nil, loadErr
		}
		block.ConvShortcut = NewConv2D(shortW, shortB, 1, 0)
	}
	return block, nil
}
// Forward applies the ResNet block with staged evaluation
//
// Each stage forces evaluation with mlx.Eval and frees the previous
// intermediate as soon as it is no longer needed, trading graph fusion for
// a lower peak memory footprint. The input x is NOT freed here: it is still
// needed for the residual connection, and the caller owns it.
func (rb *ResnetBlock2D) Forward(x *mlx.Array) *mlx.Array {
	var h *mlx.Array
	// Stage 1: norm1
	{
		h = rb.Norm1.Forward(x)
		mlx.Eval(h)
	}
	// Stage 2: silu + conv1
	{
		prev := h
		h = mlx.SiLU(h)
		h = rb.Conv1.Forward(h)
		prev.Free() // the normalized pre-activation tensor is no longer needed
		mlx.Eval(h)
	}
	// Stage 3: norm2
	{
		prev := h
		h = rb.Norm2.Forward(h)
		prev.Free()
		mlx.Eval(h)
	}
	// Stage 4: silu + conv2
	{
		prev := h
		h = mlx.SiLU(h)
		h = rb.Conv2.Forward(h)
		prev.Free()
		mlx.Eval(h)
	}
	// Residual connection: project x through the 1x1 shortcut when the block
	// changes channel count, otherwise add x directly.
	{
		prev := h
		if rb.ConvShortcut != nil {
			shortcut := rb.ConvShortcut.Forward(x)
			h = mlx.Add(h, shortcut)
		} else {
			h = mlx.Add(h, x)
		}
		prev.Free()
		mlx.Eval(h)
	}
	return h
}
// VAEAttentionBlock implements self-attention for VAE
//
// The projection weights are transposed once at load time (see
// NewVAEAttentionBlock) so they can be fed directly to mlx.Linear.
type VAEAttentionBlock struct {
	GroupNorm   *GroupNormLayer
	ToQWeight   *mlx.Array // query projection, pre-transposed for mlx.Linear
	ToQBias     *mlx.Array
	ToKWeight   *mlx.Array // key projection, pre-transposed for mlx.Linear
	ToKBias     *mlx.Array
	ToVWeight   *mlx.Array // value projection, pre-transposed for mlx.Linear
	ToVBias     *mlx.Array
	ToOutWeight *mlx.Array // output projection, pre-transposed for mlx.Linear
	ToOutBias   *mlx.Array
	NumHeads    int32 // always 1 as constructed by NewVAEAttentionBlock
}
// NewVAEAttentionBlock loads a single-head attention block from the given
// prefix. The q/k/v/out projection weights are transposed here, once, into
// the orientation mlx.Linear expects.
func NewVAEAttentionBlock(weights *safetensors.ModelWeights, prefix string, numGroups int32) (*VAEAttentionBlock, error) {
	// tensor fetches prefix+suffix, remembering the first failure so the
	// remaining lookups become no-ops.
	var loadErr error
	tensor := func(suffix string) *mlx.Array {
		if loadErr != nil {
			return nil
		}
		t, err := weights.GetTensor(prefix + suffix)
		if err != nil {
			loadErr = err
		}
		return t
	}
	normW, normB := tensor(".group_norm.weight"), tensor(".group_norm.bias")
	qW, qB := tensor(".to_q.weight"), tensor(".to_q.bias")
	kW, kB := tensor(".to_k.weight"), tensor(".to_k.bias")
	vW, vB := tensor(".to_v.weight"), tensor(".to_v.bias")
	outW, outB := tensor(".to_out.0.weight"), tensor(".to_out.0.bias")
	if loadErr != nil {
		return nil, loadErr
	}
	return &VAEAttentionBlock{
		GroupNorm:   NewGroupNorm(normW, normB, numGroups),
		ToQWeight:   mlx.Transpose(qW, 1, 0),
		ToQBias:     qB,
		ToKWeight:   mlx.Transpose(kW, 1, 0),
		ToKBias:     kB,
		ToVWeight:   mlx.Transpose(vW, 1, 0),
		ToVBias:     vB,
		ToOutWeight: mlx.Transpose(outW, 1, 0),
		ToOutBias:   outB,
		NumHeads:    1,
	}, nil
}
// Forward applies attention with staged evaluation
//
// x is [B, C, H, W]. The spatial grid is flattened to a sequence of H*W
// tokens, single-head scaled dot-product attention is applied, and the
// result is reshaped back and added to the input (residual). Each stage is
// forced with mlx.Eval and intermediates are freed eagerly to cap peak memory.
func (ab *VAEAttentionBlock) Forward(x *mlx.Array) *mlx.Array {
	residual := x
	shape := x.Shape()
	B := shape[0]
	C := shape[1]
	H := shape[2]
	W := shape[3]
	var h *mlx.Array
	// Stage 1: GroupNorm + reshape to a token sequence [B, H*W, C]
	{
		h = ab.GroupNorm.Forward(x)
		h = mlx.Transpose(h, 0, 2, 3, 1)
		h = mlx.Reshape(h, B, H*W, C)
		mlx.Eval(h)
	}
	var out *mlx.Array
	// Stage 2: Q, K, V projections + attention
	{
		q := mlx.Linear(h, ab.ToQWeight)
		q = mlx.Add(q, ab.ToQBias)
		k := mlx.Linear(h, ab.ToKWeight)
		k = mlx.Add(k, ab.ToKBias)
		v := mlx.Linear(h, ab.ToVWeight)
		v = mlx.Add(v, ab.ToVBias)
		h.Free() // normalized tokens no longer needed once q/k/v exist
		// Insert a head axis of size 1 for the attention kernel.
		q = mlx.ExpandDims(q, 1)
		k = mlx.ExpandDims(k, 1)
		v = mlx.ExpandDims(v, 1)
		// Standard 1/sqrt(d) attention scaling with d = C (single head).
		scale := float32(1.0 / math.Sqrt(float64(C)))
		out = mlx.ScaledDotProductAttention(q, k, v, scale, false)
		out = mlx.Squeeze(out, 1)
		mlx.Eval(out)
	}
	// Stage 3: Output projection + reshape back to [B, C, H, W] + residual
	{
		prev := out
		out = mlx.Linear(out, ab.ToOutWeight)
		out = mlx.Add(out, ab.ToOutBias)
		out = mlx.Reshape(out, B, H, W, C)
		out = mlx.Transpose(out, 0, 3, 1, 2)
		out = mlx.Add(out, residual)
		prev.Free()
		mlx.Eval(out)
	}
	return out
}
// UpDecoderBlock2D implements an upsampling decoder block
type UpDecoderBlock2D struct {
	ResnetBlocks []*ResnetBlock2D
	Upsample     *Conv2D // 3x3 conv applied after 2x nearest-neighbor upsampling; nil for the final block
}
// NewUpDecoderBlock2D loads an up decoder block consisting of numLayers
// ResNet blocks and, when hasUpsample is true, a trailing upsample conv.
func NewUpDecoderBlock2D(weights *safetensors.ModelWeights, prefix string, numLayers, numGroups int32, hasUpsample bool) (*UpDecoderBlock2D, error) {
	block := &UpDecoderBlock2D{ResnetBlocks: make([]*ResnetBlock2D, numLayers)}
	for i := range block.ResnetBlocks {
		resnet, err := NewResnetBlock2D(weights, fmt.Sprintf("%s.resnets.%d", prefix, i), numGroups)
		if err != nil {
			return nil, err
		}
		block.ResnetBlocks[i] = resnet
	}
	if hasUpsample {
		upWeight, err := weights.GetTensor(prefix + ".upsamplers.0.conv.weight")
		if err != nil {
			return nil, err
		}
		upBias, err := weights.GetTensor(prefix + ".upsamplers.0.conv.bias")
		if err != nil {
			return nil, err
		}
		block.Upsample = NewConv2D(upWeight, upBias, 1, 1)
	}
	return block, nil
}
// Forward runs the block's ResNet layers and then, if configured, doubles
// the spatial resolution. Evaluation is staged and intermediates are freed
// eagerly to keep peak memory down; the caller's input is freed here once
// the first ResNet layer has consumed it.
func (ub *UpDecoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
	for _, resnet := range ub.ResnetBlocks {
		in := x
		x = resnet.Forward(in) // ResNet stages its own evals internally
		in.Free()
	}
	if ub.Upsample == nil {
		return x
	}
	// Stage 1: 2x nearest-neighbor upsampling.
	in := x
	x = Upsample2x(in)
	in.Free()
	mlx.Eval(x)
	// Stage 2: smoothing convolution over the upsampled grid.
	in = x
	x = ub.Upsample.Forward(in)
	in.Free()
	mlx.Eval(x)
	return x
}
// VAEMidBlock is the middle block with attention
// Applied in order: Resnet1 -> Attention -> Resnet2.
type VAEMidBlock struct {
	Resnet1   *ResnetBlock2D
	Attention *VAEAttentionBlock
	Resnet2   *ResnetBlock2D
}
// NewVAEMidBlock loads the resnet/attention/resnet mid block from the
// given prefix.
func NewVAEMidBlock(weights *safetensors.ModelWeights, prefix string, numGroups int32) (*VAEMidBlock, error) {
	mb := &VAEMidBlock{}
	var err error
	if mb.Resnet1, err = NewResnetBlock2D(weights, prefix+".resnets.0", numGroups); err != nil {
		return nil, err
	}
	if mb.Attention, err = NewVAEAttentionBlock(weights, prefix+".attentions.0", numGroups); err != nil {
		return nil, err
	}
	if mb.Resnet2, err = NewResnetBlock2D(weights, prefix+".resnets.1", numGroups); err != nil {
		return nil, err
	}
	return mb, nil
}
// Forward runs resnet -> attention -> resnet, freeing each stage's input
// once it has been consumed. Every sub-module stages its own evals.
func (mb *VAEMidBlock) Forward(x *mlx.Array) *mlx.Array {
	stages := []func(*mlx.Array) *mlx.Array{
		mb.Resnet1.Forward,
		mb.Attention.Forward,
		mb.Resnet2.Forward,
	}
	for _, stage := range stages {
		in := x
		x = stage(in)
		in.Free()
	}
	return x
}
// VAEDecoder is the full VAE decoder
// Decode pipeline: ConvIn -> MidBlock -> UpBlocks -> ConvNormOut -> silu -> ConvOut.
type VAEDecoder struct {
	Config      *VAEConfig
	ConvIn      *Conv2D
	MidBlock    *VAEMidBlock
	UpBlocks    []*UpDecoderBlock2D // one per entry in Config.BlockOutChannels
	ConvNormOut *GroupNormLayer
	ConvOut     *Conv2D
}
// Load loads the VAE decoder from a directory containing config.json and
// safetensors weight files, populating every field of m. Progress is
// printed to stdout as each component loads. Tensor handles held by the
// weight loader are released once all components are constructed.
func (m *VAEDecoder) Load(path string) error {
	fmt.Println("Loading VAE decoder...")
	// Config first: it drives group counts and block layout below.
	cfg, err := loadVAEConfig(filepath.Join(path, "config.json"))
	if err != nil {
		return fmt.Errorf("config: %w", err)
	}
	m.Config = cfg
	weights, err := safetensors.LoadModelWeights(path)
	if err != nil {
		return fmt.Errorf("weights: %w", err)
	}
	// conv_in: projects latent channels into the decoder's feature space.
	fmt.Print(" Loading conv_in... ")
	w, err := weights.GetTensor("decoder.conv_in.weight")
	if err != nil {
		return err
	}
	b, err := weights.GetTensor("decoder.conv_in.bias")
	if err != nil {
		return err
	}
	m.ConvIn = NewConv2D(w, b, 1, 1)
	fmt.Println("✓")
	// mid block: resnet + attention + resnet.
	fmt.Print(" Loading mid block... ")
	if m.MidBlock, err = NewVAEMidBlock(weights, "decoder.mid_block", cfg.NormNumGroups); err != nil {
		return err
	}
	fmt.Println("✓")
	// up blocks: one per configured resolution; all but the last upsample.
	fmt.Print(" Loading up blocks... ")
	m.UpBlocks = make([]*UpDecoderBlock2D, len(cfg.BlockOutChannels))
	for i := range m.UpBlocks {
		hasUpsample := i < len(m.UpBlocks)-1
		m.UpBlocks[i], err = NewUpDecoderBlock2D(weights, fmt.Sprintf("decoder.up_blocks.%d", i), cfg.LayersPerBlock+1, cfg.NormNumGroups, hasUpsample)
		if err != nil {
			return err
		}
	}
	fmt.Printf("✓ [%d blocks]\n", len(m.UpBlocks))
	// conv_norm_out: final group norm before the output projection.
	fmt.Print(" Loading conv_norm_out... ")
	if w, err = weights.GetTensor("decoder.conv_norm_out.weight"); err != nil {
		return err
	}
	if b, err = weights.GetTensor("decoder.conv_norm_out.bias"); err != nil {
		return err
	}
	m.ConvNormOut = NewGroupNorm(w, b, cfg.NormNumGroups)
	fmt.Println("✓")
	// conv_out: projects features back to image channels.
	fmt.Print(" Loading conv_out... ")
	if w, err = weights.GetTensor("decoder.conv_out.weight"); err != nil {
		return err
	}
	if b, err = weights.GetTensor("decoder.conv_out.bias"); err != nil {
		return err
	}
	m.ConvOut = NewConv2D(w, b, 1, 1)
	fmt.Println("✓")
	weights.ReleaseAll()
	return nil
}
// Decode decodes latents to images.
//
// The latents are first un-normalized (divide by ScalingFactor, add
// ShiftFactor), pushed through the conv-in/mid/up-block pipeline, and
// finished with norm + silu + conv-out. The decoder's [-1, 1] output is
// remapped and clipped to [0, 1]. Evaluation is staged with eager frees to
// reduce peak memory; the caller retains ownership of latents.
func (vae *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
	// Stage 1: undo latent normalization and project into feature space.
	scaled := mlx.AddScalar(mlx.DivScalar(latents, vae.Config.ScalingFactor), vae.Config.ShiftFactor)
	h := vae.ConvIn.Forward(scaled)
	mlx.Eval(h)
	// Mid and up blocks stage their own evals and frees.
	h = vae.MidBlock.Forward(h)
	for _, upBlock := range vae.UpBlocks {
		h = upBlock.Forward(h)
	}
	// Final head: norm, activation, output projection, range remap.
	in := h
	h = vae.ConvNormOut.Forward(in)
	h = mlx.SiLU(h)
	h = vae.ConvOut.Forward(h)
	// VAE outputs [-1, 1]; map to [0, 1] and clip.
	h = mlx.AddScalar(mlx.MulScalar(h, 0.5), 0.5)
	h = mlx.ClipScalar(h, 0.0, 1.0, true, true)
	in.Free()
	mlx.Eval(h)
	return h
}
// Upsample2x performs 2x nearest-neighbor upsampling using broadcast.
//
// Each pixel is duplicated along both spatial axes:
// [B, C, H, W] -> [B, C, H, 1, W, 1] -> [B, C, H, 2, W, 2] -> [B, C, H*2, W*2]
func Upsample2x(x *mlx.Array) *mlx.Array {
	dims := x.Shape()
	b, c, h, w := dims[0], dims[1], dims[2], dims[3]
	expanded := mlx.Reshape(x, b, c, h, 1, w, 1)
	doubled := mlx.BroadcastTo(expanded, []int32{b, c, h, 2, w, 2})
	return mlx.Reshape(doubled, b, c, 2*h, 2*w)
}