//go:build mlx

package zimage

import (
	"fmt"
	"math"

	"github.com/ollama/ollama/x/imagegen"
	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/safetensors"
)

// VAEConfig holds VAE decoder configuration
type VAEConfig struct {
	InChannels       int32   `json:"in_channels"`
	OutChannels      int32   `json:"out_channels"`
	LatentChannels   int32   `json:"latent_channels"`
	BlockOutChannels []int32 `json:"block_out_channels"`
	LayersPerBlock   int32   `json:"layers_per_block"`
	NormNumGroups    int32   `json:"norm_num_groups"`
	ScalingFactor    float32 `json:"scaling_factor"`
	ShiftFactor      float32 `json:"shift_factor"`
}

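// For shape intuition only, a hypothetical config (values illustrative, not
// taken from any particular model):
//
//	{
//	  "latent_channels": 16,
//	  "block_out_channels": [128, 256, 512, 512],
//	  "layers_per_block": 2,
//	  "norm_num_groups": 32,
//	  "scaling_factor": 0.3611,
//	  "shift_factor": 0.1159
//	}
//
// Decode divides incoming latents by ScalingFactor and adds ShiftFactor
// before running the decoder stack.
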
// GroupNormLayer implements group normalization
type GroupNormLayer struct {
	Weight    *mlx.Array
	Bias      *mlx.Array
	NumGroups int32
	Eps       float32
}

// NewGroupNorm creates a group norm layer
func NewGroupNorm(weight, bias *mlx.Array, numGroups int32) *GroupNormLayer {
	return &GroupNormLayer{
		Weight:    weight,
		Bias:      bias,
		NumGroups: numGroups,
		Eps:       1e-5,
	}
}

// Forward applies group normalization
// Input and output are in NHWC format [B, H, W, C]
func (gn *GroupNormLayer) Forward(x *mlx.Array) *mlx.Array {
	// x: [B, H, W, C] (NHWC format)
	shape := x.Shape()
	B := shape[0]
	H := shape[1]
	W := shape[2]
	C := shape[3]

	// For large spatial sizes, use tiled computation to avoid CUDA grid limits
	// CUDA grid.y max is 65535, so H*W/16 must be <= 65535, meaning H*W <= ~1M
	// To be safe, tile when H*W > 512*512 = 262144
	if H*W > 512*512 {
		return gn.forwardTiled(x, B, H, W, C)
	}

	return gn.forwardSmall(x, B, H, W, C)
}

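// Grid-limit arithmetic behind the 512*512 cutoff (following the comment
// above): with grid.y capped at 65535 and H*W/16 blocks along that axis,
// H*W can reach at most 65535*16 ≈ 1.05M elements, so tiling at
// H*W > 262144 leaves ample headroom. A 1024x1024 activation
// (H*W = 1048576) therefore takes the tiled path.
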
// forwardSmall is the standard GroupNorm for tensors that fit within CUDA grid limits
func (gn *GroupNormLayer) forwardSmall(x *mlx.Array, B, H, W, C int32) *mlx.Array {
	// Reshape to [B, H, W, groups, C/groups]
	groupSize := C / gn.NumGroups
	x = mlx.Reshape(x, B, H, W, gn.NumGroups, groupSize)

	// Compute mean and variance per group (over H, W, and C/groups dimensions)
	mean := mlx.Mean(x, 1, true)
	mean = mlx.Mean(mean, 2, true)
	mean = mlx.Mean(mean, 4, true)

	xCentered := mlx.Sub(x, mean)

	// Variance over same axes
	sq := mlx.Square(xCentered)
	variance := mlx.Mean(sq, 1, true)
	variance = mlx.Mean(variance, 2, true)
	variance = mlx.Mean(variance, 4, true)

	// Normalize
	xNorm := mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, gn.Eps)))

	// Reshape back to [B, H, W, C]
	xNorm = mlx.Reshape(xNorm, B, H, W, C)

	// Scale and shift (weight and bias are [C])
	if gn.Weight != nil {
		weight := mlx.Reshape(gn.Weight, 1, 1, 1, C)
		xNorm = mlx.Mul(xNorm, weight)
	}
	if gn.Bias != nil {
		bias := mlx.Reshape(gn.Bias, 1, 1, 1, C)
		xNorm = mlx.Add(xNorm, bias)
	}

	return xNorm
}

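// Shape walk for forwardSmall with hypothetical sizes: for C = 512 and
// NumGroups = 32, groupSize = 16, so [B, H, W, 512] is viewed as
// [B, H, W, 32, 16]; the means over axes 1, 2, and 4 leave per-(batch,
// group) statistics of shape [B, 1, 1, 32, 1], which broadcast back over
// the spatial and within-group axes.
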
// forwardTiled handles large tensors by processing in H-tiles to avoid CUDA grid limits
func (gn *GroupNormLayer) forwardTiled(x *mlx.Array, B, H, W, C int32) *mlx.Array {
	groupSize := C / gn.NumGroups

	// Keep the input - we need it for slicing tiles later
	mlx.Keep(x)

	// Compute per-group mean and variance using flattened spatial dimensions
	// Build the entire compute graph first, then eval once
	// Reshape to [B, H*W, groups, groupSize]
	xFlat := mlx.Reshape(x, B, H*W, gn.NumGroups, groupSize)

	// Mean over spatial (axis 1) and groupSize (axis 3) dimensions
	// Result shape: [B, 1, groups, 1]
	mean1 := mlx.Mean(xFlat, 1, true)
	mean := mlx.Mean(mean1, 3, true)

	// Variance using E[X^2] - E[X]^2
	xSq := mlx.Square(xFlat)
	meanSq1 := mlx.Mean(xSq, 1, true)
	meanSq := mlx.Mean(meanSq1, 3, true)
	meanSquared := mlx.Square(mean)
	variance := mlx.Sub(meanSq, meanSquared)

	// invStd = 1/sqrt(var + eps)
	varPlusEps := mlx.AddScalar(variance, gn.Eps)
	stdDev := mlx.Sqrt(varPlusEps)
	one := mlx.Full(1.0, 1)
	invStd := mlx.Div(one, stdDev)

	// Eval mean and invStd together - these are what we need for the tile loop
	mlx.Keep(mean, invStd)
	mlx.Eval(mean, invStd)

	// Tile along H dimension
	tileH := int32(512 * 512 / W)
	if tileH < 1 {
		tileH = 1
	}
	if tileH > H {
		tileH = H
	}

	// Prepare weight and bias reshaped for 4D broadcast [1, 1, groups, groupSize]
	var weightGN, biasGN *mlx.Array
	if gn.Weight != nil {
		weightGN = mlx.Reshape(gn.Weight, 1, 1, gn.NumGroups, groupSize)
		mlx.Keep(weightGN)
		mlx.Eval(weightGN)
	}
	if gn.Bias != nil {
		biasGN = mlx.Reshape(gn.Bias, 1, 1, gn.NumGroups, groupSize)
		mlx.Keep(biasGN)
		mlx.Eval(biasGN)
	}

	var tiles []*mlx.Array
	for hStart := int32(0); hStart < H; hStart += tileH {
		hEnd := hStart + tileH
		if hEnd > H {
			hEnd = H
		}
		tileHeight := hEnd - hStart
		spatialSize := tileHeight * W

		// Build the compute graph for this tile (no intermediate Evals)
		// Extract tile and flatten spatial dims: [B, tileH*W, groups, groupSize]
		tile := mlx.Slice(x, []int32{0, hStart, 0, 0}, []int32{B, hEnd, W, C})
		tileFlat := mlx.Reshape(tile, B, spatialSize, gn.NumGroups, groupSize)

		// Normalize: (x - mean) * invStd
		tileCentered := mlx.Sub(tileFlat, mean)
		tileNorm := mlx.Mul(tileCentered, invStd)

		// Apply scale and shift in 4D space
		if weightGN != nil {
			tileNorm = mlx.Mul(tileNorm, weightGN)
		}
		if biasGN != nil {
			tileNorm = mlx.Add(tileNorm, biasGN)
		}

		// Reshape back to [B, tileH, W, C]
		tileOut := mlx.Reshape(tileNorm, B, tileHeight, W, C)

		// Now eval and keep this tile
		mlx.Keep(tileOut)
		mlx.Eval(tileOut)

		tiles = append(tiles, tileOut)
	}

	// Concatenate tiles along H axis
	var result *mlx.Array
	if len(tiles) == 1 {
		result = tiles[0]
	} else {
		result = mlx.Concatenate(tiles, 1)
		mlx.Eval(result)
		// Free the individual tiles now that they're concatenated
		for _, t := range tiles {
			t.Free()
		}
	}

	// Clean up kept arrays
	mean.Free()
	invStd.Free()
	if weightGN != nil {
		weightGN.Free()
	}
	if biasGN != nil {
		biasGN.Free()
	}

	return result
}

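// Tile arithmetic (illustrative): for W = 1024, tileH = 262144/1024 = 256,
// so a 1024-row feature map is normalized in four slices. The statistics
// are exact rather than per-tile: E[X^2] - E[X]^2 is computed once over the
// full [B, H*W, groups, groupSize] view, so each tile is normalized with
// the same mean and invStd the untiled path would use.
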
// Conv2D represents a 2D convolution layer
// Works natively in NHWC format (MLX's native format)
type Conv2D struct {
	Weight  *mlx.Array // [out_channels, kH, kW, in_channels] (OHWI for MLX)
	Bias    *mlx.Array // [out_channels]
	Stride  int32
	Padding int32
}

// NewConv2D creates a Conv2D layer
// weight comes in as [out_channels, in_channels, kH, kW] (OIHW from PyTorch)
// we transpose to [out_channels, kH, kW, in_channels] (OHWI for MLX)
func NewConv2D(weight, bias *mlx.Array, stride, padding int32) *Conv2D {
	// Transpose weight from OIHW to OHWI
	// [O, I, H, W] -> [O, H, W, I]
	weightOHWI := mlx.Transpose(weight, 0, 2, 3, 1)
	return &Conv2D{
		Weight:  weightOHWI,
		Bias:    bias,
		Stride:  stride,
		Padding: padding,
	}
}

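// Layout example with hypothetical sizes: a 3x3 PyTorch weight of shape
// [128, 64, 3, 3] (OIHW: out=128, in=64) becomes [128, 3, 3, 64] (OHWI)
// after the Transpose above, matching the NHWC activations used throughout.
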
// Forward applies convolution
// Input and output are in NHWC format [N, H, W, C]
func (conv *Conv2D) Forward(x *mlx.Array) *mlx.Array {
	// Conv in NHWC format (MLX native)
	out := mlx.Conv2d(x, conv.Weight, conv.Stride, conv.Padding)

	if conv.Bias != nil {
		// Bias is [C], reshape to [1, 1, 1, C] for NHWC broadcast
		bias := mlx.Reshape(conv.Bias, 1, 1, 1, conv.Bias.Dim(0))
		out = mlx.Add(out, bias)
	}

	return out
}

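// With kernel size k, stride s, and padding p, the output spatial size
// follows the standard convolution formula outH = (H + 2p - k)/s + 1; the
// 3x3, stride-1, padding-1 convolutions used in this decoder therefore
// preserve H and W, and the 1x1, padding-0 shortcut conv does as well.
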
// ResnetBlock2D implements a ResNet block for VAE
type ResnetBlock2D struct {
	Norm1        *GroupNormLayer
	Conv1        *Conv2D
	Norm2        *GroupNormLayer
	Conv2        *Conv2D
	ConvShortcut *Conv2D // nil if in_channels == out_channels
}

// NewResnetBlock2D creates a ResNet block
func NewResnetBlock2D(weights safetensors.WeightSource, prefix string, numGroups int32) (*ResnetBlock2D, error) {
	norm1Weight, err := weights.GetTensor(prefix + ".norm1.weight")
	if err != nil {
		return nil, err
	}
	norm1Bias, err := weights.GetTensor(prefix + ".norm1.bias")
	if err != nil {
		return nil, err
	}

	conv1Weight, err := weights.GetTensor(prefix + ".conv1.weight")
	if err != nil {
		return nil, err
	}
	conv1Bias, err := weights.GetTensor(prefix + ".conv1.bias")
	if err != nil {
		return nil, err
	}

	norm2Weight, err := weights.GetTensor(prefix + ".norm2.weight")
	if err != nil {
		return nil, err
	}
	norm2Bias, err := weights.GetTensor(prefix + ".norm2.bias")
	if err != nil {
		return nil, err
	}

	conv2Weight, err := weights.GetTensor(prefix + ".conv2.weight")
	if err != nil {
		return nil, err
	}
	conv2Bias, err := weights.GetTensor(prefix + ".conv2.bias")
	if err != nil {
		return nil, err
	}

	block := &ResnetBlock2D{
		Norm1: NewGroupNorm(norm1Weight, norm1Bias, numGroups),
		Conv1: NewConv2D(conv1Weight, conv1Bias, 1, 1),
		Norm2: NewGroupNorm(norm2Weight, norm2Bias, numGroups),
		Conv2: NewConv2D(conv2Weight, conv2Bias, 1, 1),
	}

	if weights.HasTensor(prefix + ".conv_shortcut.weight") {
		shortcutWeight, err := weights.GetTensor(prefix + ".conv_shortcut.weight")
		if err != nil {
			return nil, err
		}
		shortcutBias, err := weights.GetTensor(prefix + ".conv_shortcut.bias")
		if err != nil {
			return nil, err
		}
		block.ConvShortcut = NewConv2D(shortcutWeight, shortcutBias, 1, 0)
	}

	return block, nil
}

// Forward applies the ResNet block with staged evaluation
func (rb *ResnetBlock2D) Forward(x *mlx.Array) *mlx.Array {
	var h *mlx.Array

	// Stage 1: norm1
	{
		h = rb.Norm1.Forward(x)
		mlx.Eval(h)
	}

	// Stage 2: silu + conv1
	{
		prev := h
		h = mlx.SiLU(h)
		h = rb.Conv1.Forward(h)
		prev.Free()
		mlx.Eval(h)
	}

	// Stage 3: norm2
	{
		prev := h
		h = rb.Norm2.Forward(h)
		prev.Free()
		mlx.Eval(h)
	}

	// Stage 4: silu + conv2
	{
		prev := h
		h = mlx.SiLU(h)
		h = rb.Conv2.Forward(h)
		prev.Free()
		mlx.Eval(h)
	}

	// Residual connection
	{
		prev := h
		if rb.ConvShortcut != nil {
			shortcut := rb.ConvShortcut.Forward(x)
			h = mlx.Add(h, shortcut)
		} else {
			h = mlx.Add(h, x)
		}
		prev.Free()
		mlx.Eval(h)
	}

	return h
}

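// The block above computes, written as one expression:
//
//	h = conv2(silu(norm2(conv1(silu(norm1(x)))))) + shortcut(x)
//
// where shortcut is the identity unless ConvShortcut is set (used when the
// channel count changes). The staged Eval/Free calls trade kernel fusion for
// a lower peak working set.
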
// VAEAttentionBlock implements self-attention for VAE
type VAEAttentionBlock struct {
	GroupNorm   *GroupNormLayer
	ToQWeight   *mlx.Array
	ToQBias     *mlx.Array
	ToKWeight   *mlx.Array
	ToKBias     *mlx.Array
	ToVWeight   *mlx.Array
	ToVBias     *mlx.Array
	ToOutWeight *mlx.Array
	ToOutBias   *mlx.Array
	NumHeads    int32
}

// NewVAEAttentionBlock creates an attention block
func NewVAEAttentionBlock(weights safetensors.WeightSource, prefix string, numGroups int32) (*VAEAttentionBlock, error) {
	normWeight, err := weights.GetTensor(prefix + ".group_norm.weight")
	if err != nil {
		return nil, err
	}
	normBias, err := weights.GetTensor(prefix + ".group_norm.bias")
	if err != nil {
		return nil, err
	}

	toQWeight, err := weights.GetTensor(prefix + ".to_q.weight")
	if err != nil {
		return nil, err
	}
	toQBias, err := weights.GetTensor(prefix + ".to_q.bias")
	if err != nil {
		return nil, err
	}

	toKWeight, err := weights.GetTensor(prefix + ".to_k.weight")
	if err != nil {
		return nil, err
	}
	toKBias, err := weights.GetTensor(prefix + ".to_k.bias")
	if err != nil {
		return nil, err
	}

	toVWeight, err := weights.GetTensor(prefix + ".to_v.weight")
	if err != nil {
		return nil, err
	}
	toVBias, err := weights.GetTensor(prefix + ".to_v.bias")
	if err != nil {
		return nil, err
	}

	toOutWeight, err := weights.GetTensor(prefix + ".to_out.0.weight")
	if err != nil {
		return nil, err
	}
	toOutBias, err := weights.GetTensor(prefix + ".to_out.0.bias")
	if err != nil {
		return nil, err
	}

	return &VAEAttentionBlock{
		GroupNorm:   NewGroupNorm(normWeight, normBias, numGroups),
		ToQWeight:   mlx.Transpose(toQWeight, 1, 0),
		ToQBias:     toQBias,
		ToKWeight:   mlx.Transpose(toKWeight, 1, 0),
		ToKBias:     toKBias,
		ToVWeight:   mlx.Transpose(toVWeight, 1, 0),
		ToVBias:     toVBias,
		ToOutWeight: mlx.Transpose(toOutWeight, 1, 0),
		ToOutBias:   toOutBias,
		NumHeads:    1,
	}, nil
}

// Forward applies attention with staged evaluation
// Input and output are in NHWC format [B, H, W, C]
func (ab *VAEAttentionBlock) Forward(x *mlx.Array) *mlx.Array {
	residual := x
	shape := x.Shape()
	B := shape[0]
	H := shape[1]
	W := shape[2]
	C := shape[3]

	var h *mlx.Array

	// Stage 1: GroupNorm + reshape to [B, H*W, C]
	{
		h = ab.GroupNorm.Forward(x)
		h = mlx.Reshape(h, B, H*W, C)
		mlx.Eval(h)
	}

	var out *mlx.Array

	// Stage 2: Q, K, V projections + attention
	{
		q := mlx.Linear(h, ab.ToQWeight)
		q = mlx.Add(q, ab.ToQBias)
		k := mlx.Linear(h, ab.ToKWeight)
		k = mlx.Add(k, ab.ToKBias)
		v := mlx.Linear(h, ab.ToVWeight)
		v = mlx.Add(v, ab.ToVBias)
		h.Free()

		q = mlx.ExpandDims(q, 1)
		k = mlx.ExpandDims(k, 1)
		v = mlx.ExpandDims(v, 1)

		scale := float32(1.0 / math.Sqrt(float64(C)))
		out = mlx.ScaledDotProductAttention(q, k, v, scale, false)
		out = mlx.Squeeze(out, 1)
		mlx.Eval(out)
	}

	// Stage 3: Output projection + reshape + residual
	{
		prev := out
		out = mlx.Linear(out, ab.ToOutWeight)
		out = mlx.Add(out, ab.ToOutBias)
		out = mlx.Reshape(out, B, H, W, C)
		out = mlx.Add(out, residual)
		prev.Free()
		mlx.Eval(out)
	}

	return out
}

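// Shape walk: the [B, H, W, C] input is flattened to a sequence of H*W
// tokens of width C, given a singleton head axis ([B, 1, H*W, C]) for
// ScaledDotProductAttention with scale 1/sqrt(C) (NumHeads is fixed at 1),
// then projected, reshaped back to [B, H, W, C], and added to the residual.
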
// UpDecoderBlock2D implements an upsampling decoder block
type UpDecoderBlock2D struct {
	ResnetBlocks []*ResnetBlock2D
	Upsample     *Conv2D
}

// NewUpDecoderBlock2D creates an up decoder block
func NewUpDecoderBlock2D(weights safetensors.WeightSource, prefix string, numLayers, numGroups int32, hasUpsample bool) (*UpDecoderBlock2D, error) {
	resnets := make([]*ResnetBlock2D, numLayers)
	for i := int32(0); i < numLayers; i++ {
		resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
		resnet, err := NewResnetBlock2D(weights, resPrefix, numGroups)
		if err != nil {
			return nil, err
		}
		resnets[i] = resnet
	}

	var upsample *Conv2D
	if hasUpsample {
		upWeight, err := weights.GetTensor(prefix + ".upsamplers.0.conv.weight")
		if err != nil {
			return nil, err
		}
		upBias, err := weights.GetTensor(prefix + ".upsamplers.0.conv.bias")
		if err != nil {
			return nil, err
		}
		upsample = NewConv2D(upWeight, upBias, 1, 1)
	}

	return &UpDecoderBlock2D{
		ResnetBlocks: resnets,
		Upsample:     upsample,
	}, nil
}

// Forward applies the up decoder block with staged evaluation to reduce peak memory
func (ub *UpDecoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
	for _, resnet := range ub.ResnetBlocks {
		prev := x
		x = resnet.Forward(x) // ResNet handles its own pools
		prev.Free()
	}

	if ub.Upsample != nil {
		// Stage 1: Upsample2x (nearest neighbor)
		{
			prev := x
			x = Upsample2x(x)
			prev.Free()
			mlx.Eval(x)
		}

		// Stage 2: Upsample conv
		{
			prev := x
			x = ub.Upsample.Forward(x)
			prev.Free()
			mlx.Eval(x)
		}
	}

	return x
}

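// Each upsampling stage doubles H and W via nearest-neighbor interpolation
// (Upsample2x, defined below) followed by a 3x3 stride-1 convolution, the
// same resize-then-conv pattern as diffusers' Upsample2D.
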
// VAEMidBlock is the middle block with attention
type VAEMidBlock struct {
	Resnet1   *ResnetBlock2D
	Attention *VAEAttentionBlock
	Resnet2   *ResnetBlock2D
}

// NewVAEMidBlock creates the mid block
func NewVAEMidBlock(weights safetensors.WeightSource, prefix string, numGroups int32) (*VAEMidBlock, error) {
	resnet1, err := NewResnetBlock2D(weights, prefix+".resnets.0", numGroups)
	if err != nil {
		return nil, err
	}

	attention, err := NewVAEAttentionBlock(weights, prefix+".attentions.0", numGroups)
	if err != nil {
		return nil, err
	}

	resnet2, err := NewResnetBlock2D(weights, prefix+".resnets.1", numGroups)
	if err != nil {
		return nil, err
	}

	return &VAEMidBlock{
		Resnet1:   resnet1,
		Attention: attention,
		Resnet2:   resnet2,
	}, nil
}

// Forward applies the mid block with staged evaluation
func (mb *VAEMidBlock) Forward(x *mlx.Array) *mlx.Array {
	prev := x
	x = mb.Resnet1.Forward(x) // ResNet handles its own pools
	prev.Free()

	// Attention handles its own pools
	prev = x
	x = mb.Attention.Forward(x)
	prev.Free()

	prev = x
	x = mb.Resnet2.Forward(x) // ResNet handles its own pools
	prev.Free()

	return x
}

// VAEDecoder is the full VAE decoder
type VAEDecoder struct {
	Config      *VAEConfig
	ConvIn      *Conv2D
	MidBlock    *VAEMidBlock
	UpBlocks    []*UpDecoderBlock2D
	ConvNormOut *GroupNormLayer
	ConvOut     *Conv2D
}

// Load loads the VAE decoder from ollama blob storage.
func (m *VAEDecoder) Load(manifest *imagegen.ModelManifest) error {
	// Load config from blob
	var cfg VAEConfig
	if err := manifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
		return fmt.Errorf("config: %w", err)
	}
	m.Config = &cfg

	// Load weights from tensor blobs
	weights, err := imagegen.LoadWeightsFromManifest(manifest, "vae")
	if err != nil {
		return fmt.Errorf("weights: %w", err)
	}
	if err := weights.Load(0); err != nil {
		return fmt.Errorf("load weights: %w", err)
	}
	defer weights.ReleaseAll()

	return m.loadWeights(weights, &cfg)
}

// loadWeights loads VAE weights from any WeightSource
func (m *VAEDecoder) loadWeights(weights safetensors.WeightSource, cfg *VAEConfig) error {
	var err error

	// Load conv_in
	fmt.Print(" Loading conv_in... ")
	convInWeight, err := weights.GetTensor("decoder.conv_in.weight")
	if err != nil {
		return err
	}
	convInBias, err := weights.GetTensor("decoder.conv_in.bias")
	if err != nil {
		return err
	}
	m.ConvIn = NewConv2D(convInWeight, convInBias, 1, 1)
	fmt.Println("✓")

	// Load mid block
	fmt.Print(" Loading mid block... ")
	m.MidBlock, err = NewVAEMidBlock(weights, "decoder.mid_block", cfg.NormNumGroups)
	if err != nil {
		return err
	}
	fmt.Println("✓")

	// Load up blocks. Decoder up blocks carry layers_per_block+1 resnets
	// (matching the diffusers UpDecoderBlock2D layout), and the final block
	// has no upsampler.
	fmt.Print(" Loading up blocks... ")
	numBlocks := len(cfg.BlockOutChannels)
	m.UpBlocks = make([]*UpDecoderBlock2D, numBlocks)
	for i := 0; i < numBlocks; i++ {
		prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
		hasUpsample := i < numBlocks-1
		m.UpBlocks[i], err = NewUpDecoderBlock2D(weights, prefix, cfg.LayersPerBlock+1, cfg.NormNumGroups, hasUpsample)
		if err != nil {
			return err
		}
	}
	fmt.Printf("✓ [%d blocks]\n", numBlocks)

	// Load conv_norm_out
	fmt.Print(" Loading conv_norm_out... ")
	normWeight, err := weights.GetTensor("decoder.conv_norm_out.weight")
	if err != nil {
		return err
	}
	normBias, err := weights.GetTensor("decoder.conv_norm_out.bias")
	if err != nil {
		return err
	}
	m.ConvNormOut = NewGroupNorm(normWeight, normBias, cfg.NormNumGroups)
	fmt.Println("✓")

	// Load conv_out
	fmt.Print(" Loading conv_out... ")
	convOutWeight, err := weights.GetTensor("decoder.conv_out.weight")
	if err != nil {
		return err
	}
	convOutBias, err := weights.GetTensor("decoder.conv_out.bias")
	if err != nil {
		return err
	}
	m.ConvOut = NewConv2D(convOutWeight, convOutBias, 1, 1)
	fmt.Println("✓")

	return nil
}

// Decode decodes latents to images.
// Input latents are in NCHW format, output is in NCHW format.
// Internally uses NHWC format (MLX native) for all operations.
func (vae *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
	// Scale latents
	z := mlx.DivScalar(latents, vae.Config.ScalingFactor)
	z = mlx.AddScalar(z, vae.Config.ShiftFactor)
	// Convert NCHW -> NHWC for internal processing
	z = mlx.Transpose(z, 0, 2, 3, 1)
	h := vae.ConvIn.Forward(z)
	mlx.Eval(h)

	h = vae.MidBlock.Forward(h)

	for _, upBlock := range vae.UpBlocks {
		h = upBlock.Forward(h)
	}

	prev := h
	h = vae.ConvNormOut.Forward(h)
	mlx.Eval(h) // Eval after GroupNorm to avoid grid dimension issues
	h = mlx.SiLU(h)
	h = vae.ConvOut.Forward(h)
	mlx.Eval(h)

	// VAE outputs [-1, 1], convert to [0, 1]
	h = mlx.MulScalar(h, 0.5)
	h = mlx.AddScalar(h, 0.5)
	h = mlx.ClipScalar(h, 0.0, 1.0, true, true)

	// Convert NHWC -> NCHW for output
	h = mlx.Transpose(h, 0, 3, 1, 2)
	prev.Free()
	mlx.Eval(h)

	return h
}

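// Hypothetical end-to-end usage (caller-side sketch; everything here other
// than Load and Decode is illustrative):
//
//	var dec VAEDecoder
//	if err := dec.Load(manifest); err != nil {
//		// handle error
//	}
//	img := dec.Decode(latents) // latents: [B, latent_channels, h, w] (NCHW)
//	// img: [B, out_channels, h*8, w*8] in [0, 1]; the 8x factor assumes
//	// three upsampling stages, i.e. len(BlockOutChannels) == 4.
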
// Upsample2x performs 2x nearest neighbor upsampling using Take.
// Input and output are in NHWC format: [B, H, W, C] -> [B, H*2, W*2, C]
// Uses Take with repeated indices to produce contiguous output.
func Upsample2x(x *mlx.Array) *mlx.Array {
	shape := x.Shape()
	H := shape[1]
	W := shape[2]

	// Create indices [0, 0, 1, 1, 2, 2, ...] for nearest neighbor
	// For H dimension
	hIdx := mlx.ArangeInt(0, H, 1, mlx.DtypeInt32)
	hIdx = mlx.Reshape(hIdx, H, 1)
	hIdx = mlx.BroadcastTo(hIdx, []int32{H, 2})
	hIdx = mlx.Reshape(hIdx, H*2)

	// For W dimension
	wIdx := mlx.ArangeInt(0, W, 1, mlx.DtypeInt32)
	wIdx = mlx.Reshape(wIdx, W, 1)
	wIdx = mlx.BroadcastTo(wIdx, []int32{W, 2})
	wIdx = mlx.Reshape(wIdx, W*2)

	// Take along H axis (axis 1 in NHWC)
	x = mlx.Take(x, hIdx, 1)
	// Take along W axis (axis 2 in NHWC)
	x = mlx.Take(x, wIdx, 2)

	return x
}

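// Index pattern example: for H = 3, hIdx is built as [0, 1, 2] -> [[0], [1],
// [2]] -> [[0, 0], [1, 1], [2, 2]] -> [0, 0, 1, 1, 2, 2], so Take along
// axis 1 repeats each row twice; wIdx does the same for columns along axis 2.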