//go:build mlx

package zimage

import (
	"encoding/json"
	"fmt"
	"math"
	"os"
	"path/filepath"

	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/safetensors"
)

// VAEConfig holds VAE decoder configuration
type VAEConfig struct {
	InChannels       int32   `json:"in_channels"`
	OutChannels      int32   `json:"out_channels"`
	LatentChannels   int32   `json:"latent_channels"`
	BlockOutChannels []int32 `json:"block_out_channels"`
	LayersPerBlock   int32   `json:"layers_per_block"`
	NormNumGroups    int32   `json:"norm_num_groups"`
	ScalingFactor    float32 `json:"scaling_factor"`
	ShiftFactor      float32 `json:"shift_factor"`
}

// loadVAEConfig loads VAE config from a JSON file
func loadVAEConfig(path string) (*VAEConfig, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, fmt.Errorf("read config: %w", err)
	}
	var cfg VAEConfig
	if err := json.Unmarshal(data, &cfg); err != nil {
		return nil, fmt.Errorf("parse config: %w", err)
	}
	return &cfg, nil
}
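
// For reference, a decoder config consumed by this loader might look like the
// following (illustrative values only, not taken from any specific
// checkpoint):
//
//	{
//	  "in_channels": 3,
//	  "out_channels": 3,
//	  "latent_channels": 16,
//	  "block_out_channels": [128, 256, 512, 512],
//	  "layers_per_block": 2,
//	  "norm_num_groups": 32,
//	  "scaling_factor": 1.0,
//	  "shift_factor": 0.0
//	}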

// GroupNormLayer implements group normalization
type GroupNormLayer struct {
	Weight    *mlx.Array
	Bias      *mlx.Array
	NumGroups int32
	Eps       float32
}

// NewGroupNorm creates a group norm layer
func NewGroupNorm(weight, bias *mlx.Array, numGroups int32) *GroupNormLayer {
	return &GroupNormLayer{
		Weight:    weight,
		Bias:      bias,
		NumGroups: numGroups,
		Eps:       1e-5,
	}
}

// Forward applies group normalization
func (gn *GroupNormLayer) Forward(x *mlx.Array) *mlx.Array {
	// x: [B, C, H, W]
	shape := x.Shape()
	B := shape[0]
	C := shape[1]
	H := shape[2]
	W := shape[3]

	// Reshape to [B, groups, C/groups, H, W]
	groupSize := C / gn.NumGroups
	x = mlx.Reshape(x, B, gn.NumGroups, groupSize, H, W)

	// Compute mean and variance per group
	mean := mlx.Mean(x, 2, true)
	mean = mlx.Mean(mean, 3, true)
	mean = mlx.Mean(mean, 4, true)

	xCentered := mlx.Sub(x, mean)
	variance := mlx.Mean(mlx.Square(xCentered), 2, true)
	variance = mlx.Mean(variance, 3, true)
	variance = mlx.Mean(variance, 4, true)

	// Normalize
	xNorm := mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, gn.Eps)))

	// Reshape back to [B, C, H, W]
	xNorm = mlx.Reshape(xNorm, B, C, H, W)

	// Scale and shift (weight and bias are [C])
	if gn.Weight != nil {
		weight := mlx.Reshape(gn.Weight, 1, C, 1, 1)
		xNorm = mlx.Mul(xNorm, weight)
	}
	if gn.Bias != nil {
		bias := mlx.Reshape(gn.Bias, 1, C, 1, 1)
		xNorm = mlx.Add(xNorm, bias)
	}

	return xNorm
}
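
// Worked example (hypothetical shapes): for x of shape [1, 512, 64, 64] with
// NumGroups=32, the reshape yields [1, 32, 16, 64, 64]; mean and variance are
// reduced over axes 2-4, i.e. over the 16 channels and all spatial positions
// of each group, so every group is normalized by its own statistics before
// the per-channel affine is applied:
//
//	y[b,c,h,w] = (x[b,c,h,w] - mean_g) / sqrt(var_g + eps) * Weight[c] + Bias[c]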

// Conv2D represents a 2D convolution layer
// MLX uses NHWC format, but we store weights in OHWI format for MLX conv
type Conv2D struct {
	Weight  *mlx.Array // [out_channels, kH, kW, in_channels] (OHWI for MLX)
	Bias    *mlx.Array // [out_channels]
	Stride  int32
	Padding int32
}

// NewConv2D creates a Conv2D layer
// weight comes in as [out_channels, in_channels, kH, kW] (OIHW from PyTorch)
// we transpose to [out_channels, kH, kW, in_channels] (OHWI for MLX)
func NewConv2D(weight, bias *mlx.Array, stride, padding int32) *Conv2D {
	// Transpose weight from OIHW to OHWI
	// [O, I, H, W] -> [O, H, W, I]
	weightOHWI := mlx.Transpose(weight, 0, 2, 3, 1)
	return &Conv2D{
		Weight:  weightOHWI,
		Bias:    bias,
		Stride:  stride,
		Padding: padding,
	}
}

// Forward applies convolution
// Input x is in NCHW format, we convert to NHWC for MLX, then back to NCHW
func (conv *Conv2D) Forward(x *mlx.Array) *mlx.Array {
	// x: [N, C, H, W] -> [N, H, W, C]
	xNHWC := mlx.Transpose(x, 0, 2, 3, 1)

	// Conv in NHWC format
	outNHWC := mlx.Conv2d(xNHWC, conv.Weight, conv.Stride, conv.Padding)

	// Convert back to NCHW: [N, H, W, C] -> [N, C, H, W]
	out := mlx.Transpose(outNHWC, 0, 3, 1, 2)

	if conv.Bias != nil {
		bias := mlx.Reshape(conv.Bias, 1, conv.Bias.Dim(0), 1, 1)
		out = mlx.Add(out, bias)
	}
	return out
}
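
// Note on layout: activations stay in PyTorch-style NCHW throughout this
// file, presumably so the code can be compared against the reference
// implementation, at the cost of two transposes per convolution to satisfy
// MLX's NHWC layout. All decoder convolutions here are 3x3 with stride 1 and
// padding 1 (except the 1x1 shortcut), so spatial size is preserved:
// H_out = (H + 2*1 - 3)/1 + 1 = H.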

// ResnetBlock2D implements a ResNet block for VAE
type ResnetBlock2D struct {
	Norm1        *GroupNormLayer
	Conv1        *Conv2D
	Norm2        *GroupNormLayer
	Conv2        *Conv2D
	ConvShortcut *Conv2D // nil if in_channels == out_channels
}

// NewResnetBlock2D creates a ResNet block
func NewResnetBlock2D(weights *safetensors.ModelWeights, prefix string, numGroups int32) (*ResnetBlock2D, error) {
	norm1Weight, err := weights.GetTensor(prefix + ".norm1.weight")
	if err != nil {
		return nil, err
	}
	norm1Bias, err := weights.GetTensor(prefix + ".norm1.bias")
	if err != nil {
		return nil, err
	}

	conv1Weight, err := weights.GetTensor(prefix + ".conv1.weight")
	if err != nil {
		return nil, err
	}
	conv1Bias, err := weights.GetTensor(prefix + ".conv1.bias")
	if err != nil {
		return nil, err
	}

	norm2Weight, err := weights.GetTensor(prefix + ".norm2.weight")
	if err != nil {
		return nil, err
	}
	norm2Bias, err := weights.GetTensor(prefix + ".norm2.bias")
	if err != nil {
		return nil, err
	}

	conv2Weight, err := weights.GetTensor(prefix + ".conv2.weight")
	if err != nil {
		return nil, err
	}
	conv2Bias, err := weights.GetTensor(prefix + ".conv2.bias")
	if err != nil {
		return nil, err
	}

	block := &ResnetBlock2D{
		Norm1: NewGroupNorm(norm1Weight, norm1Bias, numGroups),
		Conv1: NewConv2D(conv1Weight, conv1Bias, 1, 1),
		Norm2: NewGroupNorm(norm2Weight, norm2Bias, numGroups),
		Conv2: NewConv2D(conv2Weight, conv2Bias, 1, 1),
	}

	if weights.HasTensor(prefix + ".conv_shortcut.weight") {
		shortcutWeight, err := weights.GetTensor(prefix + ".conv_shortcut.weight")
		if err != nil {
			return nil, err
		}
		shortcutBias, err := weights.GetTensor(prefix + ".conv_shortcut.bias")
		if err != nil {
			return nil, err
		}
		block.ConvShortcut = NewConv2D(shortcutWeight, shortcutBias, 1, 0)
	}

	return block, nil
}

// Forward applies the ResNet block with staged evaluation
func (rb *ResnetBlock2D) Forward(x *mlx.Array) *mlx.Array {
	var h *mlx.Array

	// Stage 1: norm1
	{
		h = rb.Norm1.Forward(x)
		mlx.Eval(h)
	}

	// Stage 2: silu + conv1
	{
		prev := h
		h = mlx.SiLU(h)
		h = rb.Conv1.Forward(h)
		prev.Free()
		mlx.Eval(h)
	}

	// Stage 3: norm2
	{
		prev := h
		h = rb.Norm2.Forward(h)
		prev.Free()
		mlx.Eval(h)
	}

	// Stage 4: silu + conv2
	{
		prev := h
		h = mlx.SiLU(h)
		h = rb.Conv2.Forward(h)
		prev.Free()
		mlx.Eval(h)
	}

	// Residual connection
	{
		prev := h
		if rb.ConvShortcut != nil {
			shortcut := rb.ConvShortcut.Forward(x)
			h = mlx.Add(h, shortcut)
		} else {
			h = mlx.Add(h, x)
		}
		prev.Free()
		mlx.Eval(h)
	}

	return h
}
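
// The block scoping above is cosmetic in Go, but the pattern is deliberate:
// calling mlx.Eval after each stage materializes one intermediate at a time
// instead of letting the whole lazy graph accumulate, and freeing prev first
// can let the allocator reuse that buffer for the next stage. The trade-off
// is fewer fusion opportunities across stage boundaries in exchange for a
// smaller peak working set, which matters at image-decoder activation sizes
// (a single [1, 512, 512, 512] float32 tensor is already 512 MiB).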

// VAEAttentionBlock implements self-attention for VAE
type VAEAttentionBlock struct {
	GroupNorm   *GroupNormLayer
	ToQWeight   *mlx.Array
	ToQBias     *mlx.Array
	ToKWeight   *mlx.Array
	ToKBias     *mlx.Array
	ToVWeight   *mlx.Array
	ToVBias     *mlx.Array
	ToOutWeight *mlx.Array
	ToOutBias   *mlx.Array
	NumHeads    int32
}

// NewVAEAttentionBlock creates an attention block
func NewVAEAttentionBlock(weights *safetensors.ModelWeights, prefix string, numGroups int32) (*VAEAttentionBlock, error) {
	normWeight, err := weights.GetTensor(prefix + ".group_norm.weight")
	if err != nil {
		return nil, err
	}
	normBias, err := weights.GetTensor(prefix + ".group_norm.bias")
	if err != nil {
		return nil, err
	}

	toQWeight, err := weights.GetTensor(prefix + ".to_q.weight")
	if err != nil {
		return nil, err
	}
	toQBias, err := weights.GetTensor(prefix + ".to_q.bias")
	if err != nil {
		return nil, err
	}

	toKWeight, err := weights.GetTensor(prefix + ".to_k.weight")
	if err != nil {
		return nil, err
	}
	toKBias, err := weights.GetTensor(prefix + ".to_k.bias")
	if err != nil {
		return nil, err
	}

	toVWeight, err := weights.GetTensor(prefix + ".to_v.weight")
	if err != nil {
		return nil, err
	}
	toVBias, err := weights.GetTensor(prefix + ".to_v.bias")
	if err != nil {
		return nil, err
	}

	toOutWeight, err := weights.GetTensor(prefix + ".to_out.0.weight")
	if err != nil {
		return nil, err
	}
	toOutBias, err := weights.GetTensor(prefix + ".to_out.0.bias")
	if err != nil {
		return nil, err
	}

	return &VAEAttentionBlock{
		GroupNorm:   NewGroupNorm(normWeight, normBias, numGroups),
		ToQWeight:   mlx.Transpose(toQWeight, 1, 0),
		ToQBias:     toQBias,
		ToKWeight:   mlx.Transpose(toKWeight, 1, 0),
		ToKBias:     toKBias,
		ToVWeight:   mlx.Transpose(toVWeight, 1, 0),
		ToVBias:     toVBias,
		ToOutWeight: mlx.Transpose(toOutWeight, 1, 0),
		ToOutBias:   toOutBias,
		NumHeads:    1,
	}, nil
}

// Forward applies attention with staged evaluation
func (ab *VAEAttentionBlock) Forward(x *mlx.Array) *mlx.Array {
	residual := x
	shape := x.Shape()
	B := shape[0]
	C := shape[1]
	H := shape[2]
	W := shape[3]

	var h *mlx.Array

	// Stage 1: GroupNorm + reshape
	{
		h = ab.GroupNorm.Forward(x)
		h = mlx.Transpose(h, 0, 2, 3, 1)
		h = mlx.Reshape(h, B, H*W, C)
		mlx.Eval(h)
	}

	var out *mlx.Array

	// Stage 2: Q, K, V projections + attention
	{
		q := mlx.Linear(h, ab.ToQWeight)
		q = mlx.Add(q, ab.ToQBias)
		k := mlx.Linear(h, ab.ToKWeight)
		k = mlx.Add(k, ab.ToKBias)
		v := mlx.Linear(h, ab.ToVWeight)
		v = mlx.Add(v, ab.ToVBias)
		h.Free()

		q = mlx.ExpandDims(q, 1)
		k = mlx.ExpandDims(k, 1)
		v = mlx.ExpandDims(v, 1)

		scale := float32(1.0 / math.Sqrt(float64(C)))
		out = mlx.ScaledDotProductAttention(q, k, v, scale, false)
		out = mlx.Squeeze(out, 1)
		mlx.Eval(out)
	}

	// Stage 3: Output projection + reshape + residual
	{
		prev := out
		out = mlx.Linear(out, ab.ToOutWeight)
		out = mlx.Add(out, ab.ToOutBias)
		out = mlx.Reshape(out, B, H, W, C)
		out = mlx.Transpose(out, 0, 3, 1, 2)
		out = mlx.Add(out, residual)
		prev.Free()
		mlx.Eval(out)
	}

	return out
}
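
// This is full (non-causal) self-attention over all H*W spatial positions
// with a single head: each pixel token of width C attends to every other,
// and the scale is 1/sqrt(C) because the head dimension equals the channel
// count. The cost grows quadratically in H*W, which is why this block only
// appears in the mid block, at the lowest spatial resolution.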

// UpDecoderBlock2D implements an upsampling decoder block
type UpDecoderBlock2D struct {
	ResnetBlocks []*ResnetBlock2D
	Upsample     *Conv2D
}

// NewUpDecoderBlock2D creates an up decoder block
func NewUpDecoderBlock2D(weights *safetensors.ModelWeights, prefix string, numLayers, numGroups int32, hasUpsample bool) (*UpDecoderBlock2D, error) {
	resnets := make([]*ResnetBlock2D, numLayers)
	for i := int32(0); i < numLayers; i++ {
		resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
		resnet, err := NewResnetBlock2D(weights, resPrefix, numGroups)
		if err != nil {
			return nil, err
		}
		resnets[i] = resnet
	}

	var upsample *Conv2D
	if hasUpsample {
		upWeight, err := weights.GetTensor(prefix + ".upsamplers.0.conv.weight")
		if err != nil {
			return nil, err
		}
		upBias, err := weights.GetTensor(prefix + ".upsamplers.0.conv.bias")
		if err != nil {
			return nil, err
		}
		upsample = NewConv2D(upWeight, upBias, 1, 1)
	}

	return &UpDecoderBlock2D{
		ResnetBlocks: resnets,
		Upsample:     upsample,
	}, nil
}

// Forward applies the up decoder block with staged evaluation to reduce peak memory
func (ub *UpDecoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
	for _, resnet := range ub.ResnetBlocks {
		prev := x
		x = resnet.Forward(x) // ResNet handles its own pools
		prev.Free()
	}

	if ub.Upsample != nil {
		// Stage 1: Upsample2x (nearest neighbor)
		{
			prev := x
			x = Upsample2x(x)
			prev.Free()
			mlx.Eval(x)
		}

		// Stage 2: Upsample conv
		{
			prev := x
			x = ub.Upsample.Forward(x)
			prev.Free()
			mlx.Eval(x)
		}
	}

	return x
}
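
// Nearest-neighbor resize followed by a 3x3 convolution is the usual
// learned-upsampling recipe in these decoders; it avoids the checkerboard
// artifacts that transposed convolutions tend to produce. Every up block
// except the last doubles the spatial resolution, so with the typical four
// blocks the latent grows by 2^3 = 8x overall (e.g. 64x64 -> 512x512).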

// VAEMidBlock is the middle block with attention
type VAEMidBlock struct {
	Resnet1   *ResnetBlock2D
	Attention *VAEAttentionBlock
	Resnet2   *ResnetBlock2D
}

// NewVAEMidBlock creates the mid block
func NewVAEMidBlock(weights *safetensors.ModelWeights, prefix string, numGroups int32) (*VAEMidBlock, error) {
	resnet1, err := NewResnetBlock2D(weights, prefix+".resnets.0", numGroups)
	if err != nil {
		return nil, err
	}

	attention, err := NewVAEAttentionBlock(weights, prefix+".attentions.0", numGroups)
	if err != nil {
		return nil, err
	}

	resnet2, err := NewResnetBlock2D(weights, prefix+".resnets.1", numGroups)
	if err != nil {
		return nil, err
	}

	return &VAEMidBlock{
		Resnet1:   resnet1,
		Attention: attention,
		Resnet2:   resnet2,
	}, nil
}

// Forward applies the mid block with staged evaluation
func (mb *VAEMidBlock) Forward(x *mlx.Array) *mlx.Array {
	prev := x
	x = mb.Resnet1.Forward(x) // ResNet handles its own pools
	prev.Free()

	// Attention handles its own pools
	prev = x
	x = mb.Attention.Forward(x)
	prev.Free()

	prev = x
	x = mb.Resnet2.Forward(x) // ResNet handles its own pools
	prev.Free()

	return x
}

// VAEDecoder is the full VAE decoder
type VAEDecoder struct {
	Config      *VAEConfig
	ConvIn      *Conv2D
	MidBlock    *VAEMidBlock
	UpBlocks    []*UpDecoderBlock2D
	ConvNormOut *GroupNormLayer
	ConvOut     *Conv2D
}

// Load loads the VAE decoder from a directory
func (m *VAEDecoder) Load(path string) error {
	fmt.Println("Loading VAE decoder...")

	// Load config
	cfg, err := loadVAEConfig(filepath.Join(path, "config.json"))
	if err != nil {
		return fmt.Errorf("config: %w", err)
	}
	m.Config = cfg

	// Load weights
	weights, err := safetensors.LoadModelWeights(path)
	if err != nil {
		return fmt.Errorf("weights: %w", err)
	}

	// Load conv_in
	fmt.Print(" Loading conv_in... ")
	convInWeight, err := weights.GetTensor("decoder.conv_in.weight")
	if err != nil {
		return err
	}
	convInBias, err := weights.GetTensor("decoder.conv_in.bias")
	if err != nil {
		return err
	}
	m.ConvIn = NewConv2D(convInWeight, convInBias, 1, 1)
	fmt.Println("✓")

	// Load mid block
	fmt.Print(" Loading mid block... ")
	m.MidBlock, err = NewVAEMidBlock(weights, "decoder.mid_block", cfg.NormNumGroups)
	if err != nil {
		return err
	}
	fmt.Println("✓")

	// Load up blocks
	fmt.Print(" Loading up blocks... ")
	numBlocks := len(cfg.BlockOutChannels)
	m.UpBlocks = make([]*UpDecoderBlock2D, numBlocks)
	for i := 0; i < numBlocks; i++ {
		prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
		hasUpsample := i < numBlocks-1
		m.UpBlocks[i], err = NewUpDecoderBlock2D(weights, prefix, cfg.LayersPerBlock+1, cfg.NormNumGroups, hasUpsample)
		if err != nil {
			return err
		}
	}
	fmt.Printf("✓ [%d blocks]\n", numBlocks)

	// Load conv_norm_out
	fmt.Print(" Loading conv_norm_out... ")
	normWeight, err := weights.GetTensor("decoder.conv_norm_out.weight")
	if err != nil {
		return err
	}
	normBias, err := weights.GetTensor("decoder.conv_norm_out.bias")
	if err != nil {
		return err
	}
	m.ConvNormOut = NewGroupNorm(normWeight, normBias, cfg.NormNumGroups)
	fmt.Println("✓")

	// Load conv_out
	fmt.Print(" Loading conv_out... ")
	convOutWeight, err := weights.GetTensor("decoder.conv_out.weight")
	if err != nil {
		return err
	}
	convOutBias, err := weights.GetTensor("decoder.conv_out.bias")
	if err != nil {
		return err
	}
	m.ConvOut = NewConv2D(convOutWeight, convOutBias, 1, 1)
	fmt.Println("✓")

	weights.ReleaseAll()
	return nil
}
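
// Minimal usage sketch (hypothetical model path):
//
//	var vae VAEDecoder
//	if err := vae.Load("/path/to/vae"); err != nil {
//		log.Fatal(err)
//	}
//	img := vae.Decode(latents) // latents: [1, LatentChannels, h, w]
//	// img: [1, 3, 8*h, 8*w] float32 in [0, 1]; the 8x factor assumes the
//	// typical four-block configuration.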

// Decode decodes latents to images.
// Uses staged pools to free intermediate arrays and reduce peak memory.
func (vae *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
	var h *mlx.Array
	{
		z := mlx.DivScalar(latents, vae.Config.ScalingFactor)
		z = mlx.AddScalar(z, vae.Config.ShiftFactor)
		h = vae.ConvIn.Forward(z)
		mlx.Eval(h)
	}

	h = vae.MidBlock.Forward(h)

	for _, upBlock := range vae.UpBlocks {
		h = upBlock.Forward(h)
	}

	{
		prev := h
		h = vae.ConvNormOut.Forward(h)
		h = mlx.SiLU(h)
		h = vae.ConvOut.Forward(h)
		// VAE outputs [-1, 1], convert to [0, 1]
		h = mlx.AddScalar(mlx.MulScalar(h, 0.5), 0.5)
		h = mlx.ClipScalar(h, 0.0, 1.0, true, true)
		prev.Free()
		mlx.Eval(h)
	}

	return h
}
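
// The un-scaling at the top inverts the transform applied when the latents
// were produced: if the pipeline stores latents as (z - shift) * scale, the
// decoder recovers z = latents / scale + shift before running the network,
// matching the scaling_factor/shift_factor convention of diffusers-style VAEs.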

// Upsample2x performs 2x nearest neighbor upsampling using broadcast.
// x: [B, C, H, W] -> [B, C, H*2, W*2]
func Upsample2x(x *mlx.Array) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	C := shape[1]
	H := shape[2]
	W := shape[3]

	// [B, C, H, W] -> [B, C, H, 1, W, 1]
	x = mlx.Reshape(x, B, C, H, 1, W, 1)
	// Broadcast to [B, C, H, 2, W, 2]
	x = mlx.BroadcastTo(x, []int32{B, C, H, 2, W, 2})
	// Reshape to [B, C, H*2, W*2]
	x = mlx.Reshape(x, B, C, H*2, W*2)

	return x
}
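
// The reshape/broadcast/reshape trick is the tensor form of repeating each
// element twice along an axis. A 1D analogue: [a, b] -> [[a], [b]] ->
// [[a, a], [b, b]] -> [a, a, b, b]. Applied to both spatial axes at once it
// turns every pixel into a 2x2 block, which is exactly nearest-neighbor
// upsampling without needing a gather kernel.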