Files
ollama/x/imagegen/models/zimage/transformer.go
2026-01-09 21:09:46 -08:00

679 lines
22 KiB
Go

//go:build mlx
// Package zimage implements the Z-Image diffusion transformer model.
package zimage
import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// TransformerConfig holds the Z-Image transformer configuration, decoded from
// transformer/config.json by (*Transformer).Load.
//
// How this file uses the fields:
//   - Dim / NHeads is the per-head dimension used by Attention.
//   - NLayers sizes Transformer.Layers; NRefinerLayers sizes both the noise
//     and the context refiner stacks.
//   - NormEps is the epsilon of the transformer blocks' RMSNorm layers.
//   - TScale multiplies timesteps before they reach the timestep embedder.
//   - AxesDims is the rotary dimension of each position axis; only the first
//     three entries are read.
//   - PatchSize is not part of the JSON; Load copies it from AllPatchSize[0].
//   - NKVHeads, InChannels, CapFeatDim, RopeTheta, QKNorm and AxesLens are
//     decoded but not referenced by the code in this file.
//     NOTE(review): prepareRoPE3D uses a fixed theta of 256 instead of RopeTheta.
type TransformerConfig struct {
	Dim            int32   `json:"dim"`
	NHeads         int32   `json:"n_heads"`
	NKVHeads       int32   `json:"n_kv_heads"`
	NLayers        int32   `json:"n_layers"`
	NRefinerLayers int32   `json:"n_refiner_layers"`
	InChannels     int32   `json:"in_channels"`
	PatchSize      int32   `json:"-"` // Computed from AllPatchSize
	CapFeatDim     int32   `json:"cap_feat_dim"`
	NormEps        float32 `json:"norm_eps"`
	RopeTheta      float32 `json:"rope_theta"`
	TScale         float32 `json:"t_scale"`
	QKNorm         bool    `json:"qk_norm"`
	AxesDims       []int32 `json:"axes_dims"`
	AxesLens       []int32 `json:"axes_lens"`
	AllPatchSize   []int32 `json:"all_patch_size"` // JSON array, PatchSize = first element
}
// TimestepEmbedder turns diffusion timesteps into conditioning vectors: a
// sinusoidal encoding of width FreqEmbedSize followed by a two-layer MLP
// (weights "mlp.0" and "mlp.2"). Its output is the conditioning input passed to
// the modulated transformer blocks and to the final layer.
type TimestepEmbedder struct {
	Linear1       *nn.Linear `weight:"mlp.0"`
	Linear2       *nn.Linear `weight:"mlp.2"`
	FreqEmbedSize int32      // sinusoid width; not loaded, set to 256 by initComputedFields
}
// Forward maps a batch of timesteps to conditioning vectors.
//
// t is expected to be a 1-D array of timesteps, one per sample (TODO confirm
// dtype with callers). The timesteps are encoded as [cos(t*f), sin(t*f)] over the
// frequency ladder f_k = exp(-ln(10000) * k / half), k in [0, half), where
// half = FreqEmbedSize/2. The encoding then passes through Linear1 -> SiLU -> Linear2.
// The result has one row per timestep; its width is Linear2's output width.
func (te *TimestepEmbedder) Forward(t *mlx.Array) *mlx.Array {
	halfDim := te.FreqEmbedSize / 2

	negLogPeriod := -math.Log(10000.0)
	table := make([]float32, halfDim)
	for k := range table {
		table[k] = float32(math.Exp(negLogPeriod * float64(k) / float64(halfDim)))
	}
	freqRow := mlx.NewArray(table, []int32{1, halfDim})

	// Outer product of timesteps and frequencies: [B, 1] * [1, halfDim] -> [B, halfDim].
	phases := mlx.Mul(mlx.ExpandDims(t, 1), freqRow)
	features := mlx.Concatenate([]*mlx.Array{mlx.Cos(phases), mlx.Sin(phases)}, 1)

	hidden := mlx.SiLU(te.Linear1.Forward(features))
	return te.Linear2.Forward(hidden)
}
// XEmbedder projects patchified image latents to the model width with a single
// Linear layer. Its weights come from the "2-1" entry (nested under
// all_x_embedder in Transformer).
// NOTE(review): "2-1" presumably encodes spatial patch 2 / frame patch 1 — confirm.
type XEmbedder struct {
	Linear *nn.Linear `weight:"2-1"`
}
// Forward projects patch tokens into the model width by applying Linear to the
// last axis. Inputs are presumably [B, L, in_channels*patch^2] as produced by
// PatchifyLatents (TODO confirm); the output keeps the leading axes.
func (xe *XEmbedder) Forward(x *mlx.Array) *mlx.Array {
	projected := xe.Linear.Forward(x)
	return projected
}
// CapEmbedder projects caption (text-encoder) features to the model width:
// an RMSNorm ("0") followed by a Linear projection ("1").
//
// PadToken carries no weight tag, so presumably the module loader leaves it
// alone. The caption pad token is loaded into Transformer.CapPadToken instead.
// NOTE(review): PadToken is not referenced in this file — confirm it is still needed.
type CapEmbedder struct {
	Norm     *nn.RMSNorm `weight:"0"`
	Linear   *nn.Linear  `weight:"1"`
	PadToken *mlx.Array  // loaded separately at root level
}
// Forward normalises caption features over the last axis (RMSNorm with a fixed
// epsilon of 1e-6) and projects them with Linear. This presumably maps
// [B, L, cap_feat_dim] to [B, L, dim].
func (ce *CapEmbedder) Forward(capFeats *mlx.Array) *mlx.Array {
	const capNormEps = 1e-6
	normalized := ce.Norm.Forward(capFeats, capNormEps)
	return ce.Linear.Forward(normalized)
}
// FeedForward is the SwiGLU feed-forward network: W2(silu(W1(x)) * W3(x)).
// OutDim is not loaded; initTransformerBlock sets it from W2's weight shape.
type FeedForward struct {
	W1     *nn.Linear `weight:"w1"` // gate projection
	W2     *nn.Linear `weight:"w2"` // down projection
	W3     *nn.Linear `weight:"w3"` // up projection
	OutDim int32      // computed from W2
}
// Forward applies the SwiGLU block to a [B, L, D] input and returns
// [B, L, OutDim]. Tokens are flattened to [B*L, D] for the projections.
func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
	dims := x.Shape()
	batch, seqLen, width := dims[0], dims[1], dims[2]

	flat := mlx.Reshape(x, batch*seqLen, width)
	activated := mlx.Mul(mlx.SiLU(ff.W1.Forward(flat)), ff.W3.Forward(flat))
	projected := ff.W2.Forward(activated)

	return mlx.Reshape(projected, batch, seqLen, ff.OutDim)
}
// Attention is multi-head self-attention with per-head RMS normalisation of
// queries and keys (NormQ/NormK) and optional 3-axis rotary embeddings.
// K and V are split into the same number of heads as Q (NHeads); the config's
// NKVHeads is not consulted here.
//
// NHeads, HeadDim, Dim and Scale are not loaded. initTransformerBlock derives
// them from TransformerConfig, with HeadDim = Dim/NHeads and Scale = 1/sqrt(HeadDim).
type Attention struct {
	ToQ   *nn.Linear `weight:"to_q"`
	ToK   *nn.Linear `weight:"to_k"`
	ToV   *nn.Linear `weight:"to_v"`
	ToOut *nn.Linear `weight:"to_out.0"`
	NormQ *mlx.Array `weight:"norm_q.weight"` // [head_dim] for per-head RMSNorm
	NormK *mlx.Array `weight:"norm_k.weight"`

	// Computed fields
	NHeads  int32
	HeadDim int32
	Dim     int32
	Scale   float32
}
// Forward runs self-attention over x ([B, L, D]) and returns [B, L, Dim].
//
// Queries and keys are RMS-normalised per head (eps 1e-5). When both cos and sin
// are non-nil they are rotated by applyRoPE3D; cos and sin must broadcast against
// [B, L, NHeads, HeadDim/2].
func (attn *Attention) Forward(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
	dims := x.Shape()
	batch, seqLen, width := dims[0], dims[1], dims[2]

	flat := mlx.Reshape(x, batch*seqLen, width)

	// splitHeads projects the flattened tokens and views them as [B, L, heads, head_dim].
	splitHeads := func(proj *nn.Linear) *mlx.Array {
		return mlx.Reshape(proj.Forward(flat), batch, seqLen, attn.NHeads, attn.HeadDim)
	}
	q := splitHeads(attn.ToQ)
	k := splitHeads(attn.ToK)
	v := splitHeads(attn.ToV)

	q = mlx.RMSNorm(q, attn.NormQ, 1e-5)
	k = mlx.RMSNorm(k, attn.NormK, 1e-5)

	if cos != nil && sin != nil {
		q = applyRoPE3D(q, cos, sin)
		k = applyRoPE3D(k, cos, sin)
	}

	// SDPA wants heads ahead of the sequence axis: [B, heads, L, head_dim].
	headMajor := func(a *mlx.Array) *mlx.Array { return mlx.Transpose(a, 0, 2, 1, 3) }
	attended := mlx.ScaledDotProductAttention(headMajor(q), headMajor(k), headMajor(v), attn.Scale, false)

	merged := mlx.Reshape(mlx.Transpose(attended, 0, 2, 1, 3), batch*seqLen, attn.Dim)
	return mlx.Reshape(attn.ToOut.Forward(merged), batch, seqLen, attn.Dim)
}
// applyRoPE3D rotates interleaved (even, odd) pairs of the last axis of x by the
// angles encoded in cos/sin.
//
// x:        [B, L, nheads, head_dim] (head_dim must be even)
// cos, sin: 4-D, broadcastable to [B, L, nheads, head_dim/2]
//
// Each pair (x[2i], x[2i+1]) becomes
// (x[2i]*cos_i - x[2i+1]*sin_i, x[2i]*sin_i + x[2i+1]*cos_i), written back in
// the same interleaved order.
//
// The pairs are taken by viewing the last axis as [head_dim/2, 2] and slicing.
// This replaces building index arrays and gathering on every call, which ran for
// q and k in every layer at every step.
func applyRoPE3D(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
	shape := x.Shape()
	B, L, nheads, headDim := shape[0], shape[1], shape[2], shape[3]
	half := headDim / 2

	// [B, L, nheads, head_dim] -> [B, L, nheads, half, 2]; [..., 0] is the even
	// element of each pair and [..., 1] the odd one.
	pairs := mlx.Reshape(x, B, L, nheads, half, 2)
	x1 := mlx.Slice(pairs, []int32{0, 0, 0, 0, 0}, []int32{B, L, nheads, half, 1}) // [B, L, nheads, half, 1]
	x2 := mlx.Slice(pairs, []int32{0, 0, 0, 0, 1}, []int32{B, L, nheads, half, 2}) // [B, L, nheads, half, 1]

	// Give cos/sin the same trailing singleton axis so they broadcast per pair.
	cos = mlx.ExpandDims(cos, 4)
	sin = mlx.ExpandDims(sin, 4)

	r1 := mlx.Sub(mlx.Mul(x1, cos), mlx.Mul(x2, sin))
	r2 := mlx.Add(mlx.Mul(x1, sin), mlx.Mul(x2, cos))

	// Re-interleave: [r1_0, r2_0, r1_1, r2_1, ...].
	rotated := mlx.Concatenate([]*mlx.Array{r1, r2}, 4) // [B, L, nheads, half, 2]
	return mlx.Reshape(rotated, B, L, nheads, headDim)
}
// TransformerBlock is one transformer block with "sandwich" normalisation. It
// has an attention sublayer and a feed-forward sublayer. Each sublayer's input
// is RMS-normalised (…Norm1) and its output is RMS-normalised again (…Norm2)
// before the residual add.
//
// When AdaLN is present and Forward receives a conditioning vector, AdaLN
// produces four equal chunks: attention scale, attention gate, FFN scale and
// FFN gate. The normalised inputs are multiplied by (1+scale) and the residual
// branches by tanh(gate). Blocks without AdaLN (the context refiners) run
// unmodulated.
//
// HasModulation and Dim are set by initTransformerBlock; Forward does not read them.
type TransformerBlock struct {
	Attention      *Attention   `weight:"attention"`
	FeedForward    *FeedForward `weight:"feed_forward"`
	AttentionNorm1 *nn.RMSNorm  `weight:"attention_norm1"`
	AttentionNorm2 *nn.RMSNorm  `weight:"attention_norm2"`
	FFNNorm1       *nn.RMSNorm  `weight:"ffn_norm1"`
	FFNNorm2       *nn.RMSNorm  `weight:"ffn_norm2"`
	AdaLN          *nn.Linear   `weight:"adaLN_modulation.0,optional"` // only if modulation

	// Computed fields
	HasModulation bool
	Dim           int32
}
// Forward applies the block to x ([B, L, D]) and returns the updated hidden state.
//
// adaln is the conditioning vector (one row per sample). It is only used when the
// block also has an AdaLN layer; otherwise the block runs unmodulated. cos/sin
// are forwarded to Attention, and eps is used for all four RMSNorms.
func (tb *TransformerBlock) Forward(x *mlx.Array, adaln *mlx.Array, cos, sin *mlx.Array, eps float32) *mlx.Array {
	if tb.AdaLN == nil || adaln == nil {
		// Unmodulated path (context refiner): plain sandwich-norm residuals.
		attnOut := tb.Attention.Forward(tb.AttentionNorm1.Forward(x, eps), cos, sin)
		x = mlx.Add(x, tb.AttentionNorm2.Forward(attnOut, eps))
		ffnOut := tb.FeedForward.Forward(tb.FFNNorm1.Forward(x, eps))
		return mlx.Add(x, tb.FFNNorm2.Forward(ffnOut, eps))
	}

	// Project the conditioning and split axis 1 into four equal chunks:
	// scale_msa, gate_msa, scale_mlp, gate_mlp.
	mod := tb.AdaLN.Forward(adaln)
	modShape := mod.Shape()
	rows, width := modShape[0], modShape[1]/4
	chunk := func(idx int32) *mlx.Array {
		part := mlx.Slice(mod, []int32{0, idx * width}, []int32{rows, (idx + 1) * width})
		return mlx.ExpandDims(part, 1) // [B, 1, width] so it broadcasts over tokens
	}
	scaleMSA, gateMSA, scaleMLP, gateMLP := chunk(0), chunk(1), chunk(2), chunk(3)

	// Modulated attention branch.
	normed := mlx.Mul(tb.AttentionNorm1.Forward(x, eps), mlx.AddScalar(scaleMSA, 1.0))
	attnOut := tb.AttentionNorm2.Forward(tb.Attention.Forward(normed, cos, sin), eps)
	x = mlx.Add(x, mlx.Mul(mlx.Tanh(gateMSA), attnOut))

	// Modulated feed-forward branch.
	normed = mlx.Mul(tb.FFNNorm1.Forward(x, eps), mlx.AddScalar(scaleMLP, 1.0))
	ffnOut := tb.FFNNorm2.Forward(tb.FeedForward.Forward(normed), eps)
	return mlx.Add(x, mlx.Mul(mlx.Tanh(gateMLP), ffnOut))
}
// FinalLayer maps image tokens back to patch space. It applies a parameter-free
// LayerNorm (eps 1e-6), scales the result by (1 + AdaLN(SiLU(c))) per sample,
// and then applies the Output projection. OutDim is set by initComputedFields
// from Output's weight shape.
type FinalLayer struct {
	AdaLN  *nn.Linear `weight:"adaLN_modulation.1"` // [256] -> [dim]
	Output *nn.Linear `weight:"linear"`             // [dim] -> [out_channels]
	OutDim int32      // computed from Output
}
// Forward produces per-token patch predictions from image tokens x ([B, L, D])
// and the conditioning vector c (one row per sample). It returns [B, L, OutDim].
func (fl *FinalLayer) Forward(x *mlx.Array, c *mlx.Array) *mlx.Array {
	// Per-sample scale from the conditioning, broadcast over tokens: [B, 1, D].
	scale := mlx.ExpandDims(fl.AdaLN.Forward(mlx.SiLU(c)), 1)

	normed := mlx.Mul(layerNormNoAffine(x, 1e-6), mlx.AddScalar(scale, 1.0))

	dims := normed.Shape()
	batch, seqLen := dims[0], dims[1]
	flat := mlx.Reshape(normed, batch*seqLen, dims[2])
	return mlx.Reshape(fl.Output.Forward(flat), batch, seqLen, fl.OutDim)
}
// layerNormNoAffine normalises x over its last axis to zero mean and unit
// (biased) variance: (x - mean) / sqrt(var + eps). It has no learned scale or shift.
func layerNormNoAffine(x *mlx.Array, eps float32) *mlx.Array {
	axis := x.Ndim() - 1
	centered := mlx.Sub(x, mlx.Mean(x, axis, true))
	variance := mlx.Mean(mlx.Square(centered), axis, true)
	denom := mlx.Sqrt(mlx.AddScalar(variance, eps))
	return mlx.Div(centered, denom)
}
// Transformer is the full Z-Image diffusion transformer (DiT).
//
// Data flow in Forward:
//   - timesteps go through TEmbed;
//   - image patches go through XEmbed and then the modulated NoiseRefiners;
//   - caption features go through CapEmbed and then the unmodulated ContextRefiners;
//   - the two token sequences are concatenated (image first) and run through Layers;
//   - the image tokens are decoded by FinalLayer.
//
// Load sizes the refiner and layer slices from the config before loading
// weights. XPadToken and CapPadToken are loaded but not referenced elsewhere in
// this file.
type Transformer struct {
	TEmbed          *TimestepEmbedder   `weight:"t_embedder"`
	XEmbed          *XEmbedder          `weight:"all_x_embedder"`
	CapEmbed        *CapEmbedder        `weight:"cap_embedder"`
	NoiseRefiners   []*TransformerBlock `weight:"noise_refiner"`
	ContextRefiners []*TransformerBlock `weight:"context_refiner"`
	Layers          []*TransformerBlock `weight:"layers"`
	FinalLayer      *FinalLayer         `weight:"all_final_layer.2-1"`
	XPadToken       *mlx.Array          `weight:"x_pad_token"`
	CapPadToken     *mlx.Array          `weight:"cap_pad_token"`
	*TransformerConfig
}
// Load reads the Z-Image transformer config and weights from ollama blob storage.
//
// The config comes from "transformer/config.json" and is validated before any
// weights are loaded. Malformed configs therefore surface as errors instead of
// later panics: negative slice lengths, integer division by zero in
// initTransformerBlock, out-of-range AxesDims indexing, or RoPE/head-dim
// broadcast failures at inference. Weights are loaded as bfloat16, and
// weights.ReleaseAll is deferred once the weight load has succeeded.
func (m *Transformer) Load(manifest *imagegen.ModelManifest) error {
	fmt.Print(" Loading transformer... ")

	var cfg TransformerConfig
	if err := manifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
		return fmt.Errorf("config: %w", err)
	}
	if len(cfg.AllPatchSize) > 0 {
		cfg.PatchSize = cfg.AllPatchSize[0]
	}
	if err := validateTransformerConfig(&cfg); err != nil {
		return fmt.Errorf("config: %w", err)
	}
	m.TransformerConfig = &cfg

	// Pre-size the block stacks; presumably the module loader fills them by index.
	m.NoiseRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
	m.ContextRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
	m.Layers = make([]*TransformerBlock, cfg.NLayers)

	// Load weights from tensor blobs with BF16 conversion.
	weights, err := imagegen.LoadWeightsFromManifest(manifest, "transformer")
	if err != nil {
		return fmt.Errorf("weights: %w", err)
	}
	if err := weights.Load(mlx.DtypeBFloat16); err != nil {
		return fmt.Errorf("load weights: %w", err)
	}
	defer weights.ReleaseAll()

	return m.loadWeights(weights)
}

// validateTransformerConfig rejects configurations that would otherwise panic
// during model construction (make, Dim/NHeads) or RoPE preparation/application.
func validateTransformerConfig(cfg *TransformerConfig) error {
	switch {
	case cfg.NLayers < 0 || cfg.NRefinerLayers < 0:
		return fmt.Errorf("layer counts must be non-negative (n_layers=%d, n_refiner_layers=%d)", cfg.NLayers, cfg.NRefinerLayers)
	case cfg.NHeads <= 0:
		return fmt.Errorf("n_heads must be positive, got %d", cfg.NHeads)
	case cfg.Dim <= 0 || cfg.Dim%cfg.NHeads != 0:
		return fmt.Errorf("dim %d must be a positive multiple of n_heads %d", cfg.Dim, cfg.NHeads)
	case len(cfg.AxesDims) < 3:
		return fmt.Errorf("axes_dims needs 3 entries, got %d", len(cfg.AxesDims))
	}

	// The rotary tables are sum(axes_dims[i]/2) wide and are multiplied against
	// head_dim/2 features in applyRoPE3D, so the two widths must agree.
	headDim := cfg.Dim / cfg.NHeads
	ropeWidth := cfg.AxesDims[0]/2 + cfg.AxesDims[1]/2 + cfg.AxesDims[2]/2
	if ropeWidth != headDim/2 {
		return fmt.Errorf("axes_dims %v give rotary width %d, want head_dim/2 = %d", cfg.AxesDims[:3], ropeWidth, headDim/2)
	}
	return nil
}
// loadWeights populates the model's tagged fields from weights (any
// safetensors.WeightSource) and then derives the computed fields. It prints a
// check mark when done and wraps loader failures with "load module".
func (m *Transformer) loadWeights(weights safetensors.WeightSource) error {
	err := safetensors.LoadModule(m, weights, "")
	if err != nil {
		return fmt.Errorf("load module: %w", err)
	}

	m.initComputedFields()
	fmt.Println("✓")
	return nil
}
// initComputedFields fills the fields that are not loaded from weights: the
// timestep sinusoid width (256), the final layer's output width (from its weight
// shape), and the caption RMSNorm epsilon (1e-6). It then calls
// initTransformerBlock on every noise refiner, context refiner and main layer,
// in that order.
func (m *Transformer) initComputedFields() {
	m.TEmbed.FreqEmbedSize = 256
	m.FinalLayer.OutDim = m.FinalLayer.Output.Weight.Shape()[0]
	m.CapEmbed.Norm.Eps = 1e-6

	cfg := m.TransformerConfig
	for _, stack := range [][]*TransformerBlock{m.NoiseRefiners, m.ContextRefiners, m.Layers} {
		for _, block := range stack {
			initTransformerBlock(block, cfg)
		}
	}
}
// initTransformerBlock derives a block's non-loaded fields from cfg:
//   - block-level Dim, and HasModulation (true iff AdaLN was loaded);
//   - Attention head count, HeadDim = Dim/NHeads, and scale 1/sqrt(HeadDim);
//   - FeedForward output width, taken from W2's weight shape;
//   - NormEps on all four RMSNorm layers.
func initTransformerBlock(block *TransformerBlock, cfg *TransformerConfig) {
	block.Dim = cfg.Dim
	block.HasModulation = block.AdaLN != nil

	headDim := cfg.Dim / cfg.NHeads
	attn := block.Attention
	attn.NHeads, attn.HeadDim, attn.Dim = cfg.NHeads, headDim, cfg.Dim
	attn.Scale = float32(1.0 / math.Sqrt(float64(headDim)))

	ffn := block.FeedForward
	ffn.OutDim = ffn.W2.Weight.Shape()[0]

	for _, norm := range []*nn.RMSNorm{block.AttentionNorm1, block.AttentionNorm2, block.FFNNorm1, block.FFNNorm2} {
		norm.Eps = cfg.NormEps
	}
}
// RoPECache holds rotary cos/sin tables precomputed by PrepareRoPECache for a
// joint [image | caption] token sequence.
//
// UnifiedCos/UnifiedSin cover the whole joint sequence. ImgCos/ImgSin are its
// image prefix and CapCos/CapSin its caption suffix. ImgLen and CapLen are the
// token counts of the two parts.
type RoPECache struct {
	ImgCos     *mlx.Array
	ImgSin     *mlx.Array
	CapCos     *mlx.Array
	CapSin     *mlx.Array
	UnifiedCos *mlx.Array
	UnifiedSin *mlx.Array
	ImgLen     int32
	CapLen     int32
}
// PrepareRoPECache precomputes rotary tables for the joint [image | caption]
// sequence used by Forward and ForwardWithCache.
//
// hTok and wTok are the image token grid sizes (latentH/patchSize,
// latentW/patchSize); capLen is the caption token count. Positions are 3-D:
//   - image tokens use a (1, hTok, wTok) grid whose first coordinate is capLen+1;
//   - caption tokens use a (capLen, 1, 1) grid whose first coordinate starts at 1.
//
// Positions stay float32. bfloat16 has an 8-bit significand and cannot represent
// every integer above 256, so casting would merge neighbouring positions once
// capLen+1 exceeds 256. The rotary width of the slices is read from the computed
// tables (sum(AxesDims[i]/2)) instead of being hard-coded.
func (m *Transformer) PrepareRoPECache(hTok, wTok, capLen int32) *RoPECache {
	imgLen := hTok * wTok

	imgPos := createCoordinateGrid(1, hTok, wTok, capLen+1, 0, 0)
	capPos := createCoordinateGrid(capLen, 1, 1, 1, 0, 0)

	// RoPE for the unified sequence (image first, then caption), matching the
	// concatenation order in Forward.
	unifiedPos := mlx.Concatenate([]*mlx.Array{imgPos, capPos}, 1)
	unifiedCos, unifiedSin := prepareRoPE3D(unifiedPos, m.TransformerConfig.AxesDims)

	// Tables are [1, imgLen+capLen, 1, ropeDim]; split them into the two parts.
	ropeDim := unifiedCos.Shape()[3]
	imgStart, imgStop := []int32{0, 0, 0, 0}, []int32{1, imgLen, 1, ropeDim}
	capStart, capStop := []int32{0, imgLen, 0, 0}, []int32{1, imgLen + capLen, 1, ropeDim}

	return &RoPECache{
		ImgCos:     mlx.Slice(unifiedCos, imgStart, imgStop),
		ImgSin:     mlx.Slice(unifiedSin, imgStart, imgStop),
		CapCos:     mlx.Slice(unifiedCos, capStart, capStop),
		CapSin:     mlx.Slice(unifiedSin, capStart, capStop),
		UnifiedCos: unifiedCos,
		UnifiedSin: unifiedSin,
		ImgLen:     imgLen,
		CapLen:     capLen,
	}
}
// Forward runs one full transformer pass using the precomputed RoPE tables.
//
// x are patchified latents (see PatchifyLatents), t the timesteps (scaled by
// TScale before embedding), and capFeats the caption features. The image and
// caption streams are refined separately, concatenated (image first), passed
// through all main layers, and only the first rope.ImgLen tokens are decoded by
// FinalLayer.
func (m *Transformer) Forward(x *mlx.Array, t *mlx.Array, capFeats *mlx.Array, rope *RoPECache) *mlx.Array {
	eps := m.NormEps
	temb := m.TEmbed.Forward(mlx.MulScalar(t, m.TransformerConfig.TScale))

	// Image stream: embed, then timestep-modulated noise refiners.
	img := m.XEmbed.Forward(x)
	for _, block := range m.NoiseRefiners {
		img = block.Forward(img, temb, rope.ImgCos, rope.ImgSin, eps)
	}

	// Caption stream: embed, then unmodulated context refiners.
	caption := m.CapEmbed.Forward(capFeats)
	for _, block := range m.ContextRefiners {
		caption = block.Forward(caption, nil, rope.CapCos, rope.CapSin, eps)
	}

	// Joint attention over [image | caption] with the unified RoPE tables.
	joint := mlx.Concatenate([]*mlx.Array{img, caption}, 1)
	for _, block := range m.Layers {
		joint = block.Forward(joint, temb, rope.UnifiedCos, rope.UnifiedSin, eps)
	}

	dims := joint.Shape()
	imgTokens := mlx.Slice(joint, []int32{0, 0, 0}, []int32{dims[0], rope.ImgLen, dims[2]})
	return m.FinalLayer.Forward(imgTokens, temb)
}
// ForwardWithCache is Forward with reuse across denoising steps:
//   - The refined caption (CapEmbed + ContextRefiners) depends only on capFeats.
//     It is computed when stepCache has no constant yet, stored with
//     SetConstant, and reused afterwards.
//   - The first stepCache.NumLayers() main layers form the cacheable prefix.
//     On refresh steps (stepCache.ShouldRefresh(step, cacheInterval);
//     presumably every cacheInterval steps) they are computed and stored. On
//     other steps, a stored output replaces the running hidden state, so the
//     current latents do not reach that layer. A missing entry is recomputed.
//
// Layers past the cacheable prefix always run.
func (m *Transformer) ForwardWithCache(
	x *mlx.Array,
	t *mlx.Array,
	capFeats *mlx.Array,
	rope *RoPECache,
	stepCache *cache.StepCache,
	step int,
	cacheInterval int,
) *mlx.Array {
	eps := m.NormEps
	shallow := stepCache.NumLayers()

	temb := m.TEmbed.Forward(mlx.MulScalar(t, m.TransformerConfig.TScale))
	img := m.XEmbed.Forward(x)

	caption := stepCache.GetConstant()
	if caption == nil {
		caption = m.CapEmbed.Forward(capFeats)
		for _, block := range m.ContextRefiners {
			caption = block.Forward(caption, nil, rope.CapCos, rope.CapSin, eps)
		}
		stepCache.SetConstant(caption)
	}

	// The image stream changes every step, so the noise refiners always run.
	for _, block := range m.NoiseRefiners {
		img = block.Forward(img, temb, rope.ImgCos, rope.ImgSin, eps)
	}

	joint := mlx.Concatenate([]*mlx.Array{img, caption}, 1)
	refresh := stepCache.ShouldRefresh(step, cacheInterval)

	for i, block := range m.Layers {
		cacheable := i < shallow
		if cacheable && !refresh && stepCache.Get(i) != nil {
			joint = stepCache.Get(i)
			continue
		}
		joint = block.Forward(joint, temb, rope.UnifiedCos, rope.UnifiedSin, eps)
		if cacheable && refresh {
			stepCache.Set(i, joint)
		}
	}

	dims := joint.Shape()
	imgTokens := mlx.Slice(joint, []int32{0, 0, 0}, []int32{dims[0], rope.ImgLen, dims[2]})
	return m.FinalLayer.Forward(imgTokens, temb)
}
// createCoordinateGrid enumerates every integer point of a d0 x d1 x d2 box
// whose corner is (s0, s1, s2). The points are in row-major order (last axis
// fastest) and are returned as a float32 array of shape [1, d0*d1*d2, 3].
func createCoordinateGrid(d0, d1, d2, s0, s1, s2 int32) *mlx.Array {
	total := d0 * d1 * d2
	coords := make([]float32, total*3)

	pos := 0
	for i := int32(0); i < d0; i++ {
		for j := int32(0); j < d1; j++ {
			for k := int32(0); k < d2; k++ {
				triple := coords[pos : pos+3]
				triple[0], triple[1], triple[2] = float32(s0+i), float32(s1+j), float32(s2+k)
				pos += 3
			}
		}
	}
	return mlx.NewArray(coords, []int32{1, total, 3})
}
// prepareRoPE3D computes cos/sin tables for 3-axis rotary embeddings.
//
// positions is [B, L, 3]. In PrepareRoPECache the coordinates are
// (sequence/frame index, row, column). axesDims gives the rotary dimension of
// each axis; only the first three entries are used (e.g. [32, 48, 48]). Axis a
// contributes axesDims[a]/2 frequencies theta^(-k/half_a) with theta fixed at
// 256. The results are concatenated, so both returned arrays are
// [B, L, 1, sum(axesDims[a]/2)].
// NOTE(review): TransformerConfig.RopeTheta is not consulted here.
func prepareRoPE3D(positions *mlx.Array, axesDims []int32) (*mlx.Array, *mlx.Array) {
	const ropeTheta = float32(256.0)

	shape := positions.Shape()
	batch, seqLen := shape[0], shape[1]

	parts := make([]*mlx.Array, 0, 3)
	for axis := int32(0); axis < 3; axis++ {
		half := axesDims[axis] / 2
		inv := make([]float32, half)
		for i := range inv {
			inv[i] = float32(math.Exp(-math.Log(float64(ropeTheta)) * float64(i) / float64(half)))
		}
		freq := mlx.NewArray(inv, []int32{1, 1, 1, half})

		// positions[:, :, axis] kept as [B, L, 1], then [B, L, 1, 1] for broadcasting.
		coord := mlx.Slice(positions, []int32{0, 0, axis}, []int32{batch, seqLen, axis + 1})
		parts = append(parts, mlx.Mul(mlx.ExpandDims(coord, 3), freq))
	}

	args := mlx.Concatenate(parts, 3)
	return mlx.Cos(args), mlx.Sin(args)
}
// PatchifyLatents converts latents [B, C, H, W] into patch tokens
// [B, (H/p)*(W/p), C*p*p], where p = patchSize. H and W must be divisible by p.
//
// Tokens are ordered row-major over the (H/p, W/p) patch grid. Each token's
// features are laid out as (row offset, column offset, channel), channel fastest.
// For B == 1 this matches the Python reference
//
//	x.reshape(C, 1, 1, H_tok, p, W_tok, p).transpose(1, 2, 3, 5, 4, 6, 0).reshape(1, -1, C*p*p)
//
// with the singleton frame axes dropped. Unlike the previous version, which
// reshaped as if B were always 1, larger batches are patchified per sample.
func PatchifyLatents(latents *mlx.Array, patchSize int32) *mlx.Array {
	shape := latents.Shape()
	B, C, H, W := shape[0], shape[1], shape[2], shape[3]
	pH := H / patchSize // H_tok
	pW := W / patchSize // W_tok

	// [B, C, H, W] -> [B, C, pH, p, pW, p]
	x := mlx.Reshape(latents, B, C, pH, patchSize, pW, patchSize)
	// -> [B, pH, pW, p(row), p(col), C]
	x = mlx.Transpose(x, 0, 2, 4, 3, 5, 1)
	// -> [B, pH*pW, C*p*p]
	return mlx.Reshape(x, B, pH*pW, C*patchSize*patchSize)
}
// UnpatchifyLatents is the inverse of PatchifyLatents. It turns patch tokens
// [B, (H/p)*(W/p), C*p*p] back into latents [B, C, H, W], where p = patchSize.
//
// For B == 1 this matches the Python reference
//
//	out.reshape(1, 1, H_tok, W_tok, p, p, C).transpose(6, 0, 1, 2, 4, 3, 5).reshape(1, C, H, W)
//
// The batch size is taken from a 3-D input. Inputs of any other rank keep the
// previous behaviour of producing a single sample.
func UnpatchifyLatents(patches *mlx.Array, patchSize, H, W, C int32) *mlx.Array {
	B := int32(1)
	if shape := patches.Shape(); len(shape) == 3 {
		B = shape[0]
	}
	pH := H / patchSize
	pW := W / patchSize

	// [B, pH*pW, C*p*p] -> [B, pH, pW, p(row), p(col), C]
	x := mlx.Reshape(patches, B, pH, pW, patchSize, patchSize, C)
	// -> [B, C, pH, p(row), pW, p(col)]
	x = mlx.Transpose(x, 0, 5, 1, 3, 2, 4)
	// -> [B, C, H, W]
	return mlx.Reshape(x, B, C, H, W)
}