//go:build mlx

// Package zimage implements the Z-Image diffusion transformer model.
package zimage

import (
	"fmt"
	"math"

	"github.com/ollama/ollama/x/imagegen"
	"github.com/ollama/ollama/x/imagegen/cache"
	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/nn"
	"github.com/ollama/ollama/x/imagegen/safetensors"
)

// TransformerConfig holds Z-Image transformer configuration
type TransformerConfig struct {
	Dim            int32   `json:"dim"`
	NHeads         int32   `json:"n_heads"`
	NKVHeads       int32   `json:"n_kv_heads"`
	NLayers        int32   `json:"n_layers"`
	NRefinerLayers int32   `json:"n_refiner_layers"`
	InChannels     int32   `json:"in_channels"`
	PatchSize      int32   `json:"-"` // Computed from AllPatchSize
	CapFeatDim     int32   `json:"cap_feat_dim"`
	NormEps        float32 `json:"norm_eps"`
	RopeTheta      float32 `json:"rope_theta"`
	TScale         float32 `json:"t_scale"`
	QKNorm         bool    `json:"qk_norm"`
	AxesDims       []int32 `json:"axes_dims"`
	AxesLens       []int32 `json:"axes_lens"`
	AllPatchSize   []int32 `json:"all_patch_size"` // JSON array, PatchSize = first element
}

// TimestepEmbedder creates sinusoidal timestep embeddings
// Output dimension is 256 (fixed), used for AdaLN modulation
type TimestepEmbedder struct {
	Linear1       *nn.Linear `weight:"mlp.0"`
	Linear2       *nn.Linear `weight:"mlp.2"`
	FreqEmbedSize int32      // 256 (computed)
}

// Forward computes timestep embeddings -> [B, 256]
func (te *TimestepEmbedder) Forward(t *mlx.Array) *mlx.Array {
	// t: [B] timesteps
	// Create sinusoidal embedding
	half := te.FreqEmbedSize / 2

	// freqs = exp(-log(10000) * arange(half) / half)
	freqs := make([]float32, half)
	for i := int32(0); i < half; i++ {
		freqs[i] = float32(math.Exp(-math.Log(10000.0) * float64(i) / float64(half)))
	}
	freqsArr := mlx.NewArray(freqs, []int32{1, half})

	// t[:, None] * freqs[None, :] -> [B, half]
	tExpanded := mlx.ExpandDims(t, 1) // [B, 1]
	args := mlx.Mul(tExpanded, freqsArr)

	// embedding = [cos(args), sin(args)] -> [B, 256]
	cosArgs := mlx.Cos(args)
	sinArgs := mlx.Sin(args)
	embedding := mlx.Concatenate([]*mlx.Array{cosArgs, sinArgs}, 1)

	// MLP: linear1 -> silu -> linear2
	h := te.Linear1.Forward(embedding)
	h = mlx.SiLU(h)
	h = te.Linear2.Forward(h)
	return h
}

// XEmbedder embeds image patches to model dimension
type XEmbedder struct {
	Linear *nn.Linear `weight:"2-1"`
}

// Forward embeds patchified image latents
func (xe *XEmbedder) Forward(x *mlx.Array) *mlx.Array {
	// x: [B, L, in_channels * 4] -> [B, L, dim]
	return xe.Linear.Forward(x)
}

// CapEmbedder projects caption features to model dimension
type CapEmbedder struct {
	Norm     *nn.RMSNorm `weight:"0"`
	Linear   *nn.Linear  `weight:"1"`
	PadToken *mlx.Array  // loaded separately at root level
}

// Forward projects caption embeddings: [B, L, cap_feat_dim] -> [B, L, dim]
func (ce *CapEmbedder) Forward(capFeats *mlx.Array) *mlx.Array {
	// RMSNorm on last axis (uses 1e-6)
	h := ce.Norm.Forward(capFeats, 1e-6)
	// Linear projection
	return ce.Linear.Forward(h)
}

// FeedForward implements SwiGLU FFN
type FeedForward struct {
	W1     *nn.Linear `weight:"w1"` // gate projection
	W2     *nn.Linear `weight:"w2"` // down projection
	W3     *nn.Linear `weight:"w3"` // up projection
	OutDim int32      // computed from W2
}

// Forward applies SwiGLU: silu(W1(x)) * W3(x), then W2
func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	L := shape[1]
	D := shape[2]

	// Reshape for matmul
	x = mlx.Reshape(x, B*L, D)
	gate := ff.W1.Forward(x)
	gate = mlx.SiLU(gate)
	up := ff.W3.Forward(x)
	h := mlx.Mul(gate, up)
	out := ff.W2.Forward(h)
	return mlx.Reshape(out, B, L, ff.OutDim)
}
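// swigluSketch is a hypothetical, plain-Go illustration of the SwiGLU
// computation performed by FeedForward.Forward above: for a single feature
// vector x it returns W2(silu(W1 x) * (W3 x)). The weight matrices are passed
// as row-major slices and are assumed to have compatible shapes; this helper
// exists only to document the math and is not called by the model.
func swigluSketch(x []float32, w1, w2, w3 [][]float32) []float32 {
	silu := func(v float32) float32 { return v / (1 + float32(math.Exp(float64(-v)))) }
	matvec := func(w [][]float32, v []float32) []float32 {
		out := make([]float32, len(w))
		for i, row := range w {
			for j, wij := range row {
				out[i] += wij * v[j]
			}
		}
		return out
	}

	gate, up := matvec(w1, x), matvec(w3, x) // gate and up projections
	h := make([]float32, len(gate))
	for i := range h {
		h[i] = silu(gate[i]) * up[i] // elementwise SwiGLU gating
	}
	return matvec(w2, h) // down projection
}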
// Attention implements multi-head attention with QK norm
type Attention struct {
	ToQ   *nn.Linear `weight:"to_q"`
	ToK   *nn.Linear `weight:"to_k"`
	ToV   *nn.Linear `weight:"to_v"`
	ToOut *nn.Linear `weight:"to_out.0"`
	NormQ *mlx.Array `weight:"norm_q.weight"` // [head_dim] for per-head RMSNorm
	NormK *mlx.Array `weight:"norm_k.weight"`

	// Computed fields
	NHeads  int32
	HeadDim int32
	Dim     int32
	Scale   float32
}

// Forward computes attention
func (attn *Attention) Forward(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	L := shape[1]
	D := shape[2]

	// Project Q, K, V
	xFlat := mlx.Reshape(x, B*L, D)
	q := attn.ToQ.Forward(xFlat)
	k := attn.ToK.Forward(xFlat)
	v := attn.ToV.Forward(xFlat)

	// Reshape to [B, L, nheads, head_dim]
	q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
	k = mlx.Reshape(k, B, L, attn.NHeads, attn.HeadDim)
	v = mlx.Reshape(v, B, L, attn.NHeads, attn.HeadDim)

	// QK norm
	q = mlx.RMSNorm(q, attn.NormQ, 1e-5)
	k = mlx.RMSNorm(k, attn.NormK, 1e-5)

	// Apply RoPE if provided
	if cos != nil && sin != nil {
		q = applyRoPE3D(q, cos, sin)
		k = applyRoPE3D(k, cos, sin)
	}

	// Transpose to [B, nheads, L, head_dim]
	q = mlx.Transpose(q, 0, 2, 1, 3)
	k = mlx.Transpose(k, 0, 2, 1, 3)
	v = mlx.Transpose(v, 0, 2, 1, 3)

	// SDPA
	out := mlx.ScaledDotProductAttention(q, k, v, attn.Scale, false)

	// Transpose back and reshape
	out = mlx.Transpose(out, 0, 2, 1, 3)
	out = mlx.Reshape(out, B*L, attn.Dim)
	out = attn.ToOut.Forward(out)
	return mlx.Reshape(out, B, L, attn.Dim)
}
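// rotatePairsSketch is a hypothetical, plain-Go illustration of the per-pair
// rotation that applyRoPE3D below performs with mlx ops. Channels are laid out
// as interleaved (even, odd) pairs [x0, x1, x2, x3, ...], and cos/sin hold one
// value per pair (len(x)/2 entries). It exists only to document the math and
// is not called by the model.
func rotatePairsSketch(x, cos, sin []float32) []float32 {
	out := make([]float32, len(x))
	for i := 0; i < len(x)/2; i++ {
		x1, x2 := x[2*i], x[2*i+1] // even/odd pair
		out[2*i] = x1*cos[i] - x2*sin[i]
		out[2*i+1] = x1*sin[i] + x2*cos[i]
	}
	return out
}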
// applyRoPE3D applies 3-axis rotary position embeddings
// x: [B, L, nheads, head_dim]
// cos, sin: [B, L, 1, head_dim/2]
func applyRoPE3D(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	L := shape[1]
	nheads := shape[2]
	headDim := shape[3]
	half := headDim / 2

	// Create even/odd index arrays
	evenIdx := make([]int32, half)
	oddIdx := make([]int32, half)
	for i := int32(0); i < half; i++ {
		evenIdx[i] = i * 2
		oddIdx[i] = i*2 + 1
	}
	evenIndices := mlx.NewArrayInt32(evenIdx, []int32{half})
	oddIndices := mlx.NewArrayInt32(oddIdx, []int32{half})

	// Extract x1 (even indices) and x2 (odd indices) along last axis
	x1 := mlx.Take(x, evenIndices, 3) // [B, L, nheads, half]
	x2 := mlx.Take(x, oddIndices, 3)  // [B, L, nheads, half]

	// Apply rotation: [x1*cos - x2*sin, x1*sin + x2*cos]
	r1 := mlx.Sub(mlx.Mul(x1, cos), mlx.Mul(x2, sin))
	r2 := mlx.Add(mlx.Mul(x1, sin), mlx.Mul(x2, cos))

	// Stack and reshape to interleave: [r1_0, r2_0, r1_1, r2_1, ...]
	r1 = mlx.ExpandDims(r1, 4)                          // [B, L, nheads, half, 1]
	r2 = mlx.ExpandDims(r2, 4)                          // [B, L, nheads, half, 1]
	stacked := mlx.Concatenate([]*mlx.Array{r1, r2}, 4) // [B, L, nheads, half, 2]
	return mlx.Reshape(stacked, B, L, nheads, headDim)
}

// TransformerBlock is a single transformer block with optional AdaLN modulation
type TransformerBlock struct {
	Attention      *Attention   `weight:"attention"`
	FeedForward    *FeedForward `weight:"feed_forward"`
	AttentionNorm1 *nn.RMSNorm  `weight:"attention_norm1"`
	AttentionNorm2 *nn.RMSNorm  `weight:"attention_norm2"`
	FFNNorm1       *nn.RMSNorm  `weight:"ffn_norm1"`
	FFNNorm2       *nn.RMSNorm  `weight:"ffn_norm2"`
	AdaLN          *nn.Linear   `weight:"adaLN_modulation.0,optional"` // only if modulation

	// Computed fields
	HasModulation bool
	Dim           int32
}

// Forward applies the transformer block
func (tb *TransformerBlock) Forward(x *mlx.Array, adaln *mlx.Array, cos, sin *mlx.Array, eps float32) *mlx.Array {
	if tb.AdaLN != nil && adaln != nil {
		// Compute modulation: [B, 256] -> [B, 4*dim]
		chunks := tb.AdaLN.Forward(adaln)

		// Split into 4 parts: scale_msa, gate_msa, scale_mlp, gate_mlp
		chunkShape := chunks.Shape()
		chunkDim := chunkShape[1] / 4
		scaleMSA := mlx.Slice(chunks, []int32{0, 0}, []int32{chunkShape[0], chunkDim})
		gateMSA := mlx.Slice(chunks, []int32{0, chunkDim}, []int32{chunkShape[0], chunkDim * 2})
		scaleMLP := mlx.Slice(chunks, []int32{0, chunkDim * 2}, []int32{chunkShape[0], chunkDim * 3})
		gateMLP := mlx.Slice(chunks, []int32{0, chunkDim * 3}, []int32{chunkShape[0], chunkDim * 4})

		// Expand for broadcasting: [B, 1, dim]
		scaleMSA = mlx.ExpandDims(scaleMSA, 1)
		gateMSA = mlx.ExpandDims(gateMSA, 1)
		scaleMLP = mlx.ExpandDims(scaleMLP, 1)
		gateMLP = mlx.ExpandDims(gateMLP, 1)

		// Attention with modulation
		normX := tb.AttentionNorm1.Forward(x, eps)
		normX = mlx.Mul(normX, mlx.AddScalar(scaleMSA, 1.0))
		attnOut := tb.Attention.Forward(normX, cos, sin)
		attnOut = tb.AttentionNorm2.Forward(attnOut, eps)
		x = mlx.Add(x, mlx.Mul(mlx.Tanh(gateMSA), attnOut))

		// FFN with modulation
		normFFN := tb.FFNNorm1.Forward(x, eps)
		normFFN = mlx.Mul(normFFN, mlx.AddScalar(scaleMLP, 1.0))
		ffnOut := tb.FeedForward.Forward(normFFN)
		ffnOut = tb.FFNNorm2.Forward(ffnOut, eps)
		x = mlx.Add(x, mlx.Mul(mlx.Tanh(gateMLP), ffnOut))
	} else {
		// No modulation (context refiner)
		attnOut := tb.Attention.Forward(tb.AttentionNorm1.Forward(x, eps), cos, sin)
		x = mlx.Add(x, tb.AttentionNorm2.Forward(attnOut, eps))
		ffnOut := tb.FeedForward.Forward(tb.FFNNorm1.Forward(x, eps))
		x = mlx.Add(x, tb.FFNNorm2.Forward(ffnOut, eps))
	}
	return x
}
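// For reference, the modulated branch above implements the residual updates
//
//	x = x + tanh(gateMSA) * AttentionNorm2(Attention((1 + scaleMSA) * AttentionNorm1(x)))
//	x = x + tanh(gateMLP) * FFNNorm2(FeedForward((1 + scaleMLP) * FFNNorm1(x)))
//
// where (scaleMSA, gateMSA, scaleMLP, gateMLP) are the four chunks of the
// AdaLN projection of the timestep embedding. Without modulation (the context
// refiners), the scales and gates drop out and the block reduces to the plain
// sandwich-norm residual form in the else branch.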
// FinalLayer outputs the denoised patches
type FinalLayer struct {
	AdaLN  *nn.Linear `weight:"adaLN_modulation.1"` // [256] -> [dim]
	Output *nn.Linear `weight:"linear"`             // [dim] -> [out_channels]
	OutDim int32      // computed from Output
}

// Forward computes final output
func (fl *FinalLayer) Forward(x *mlx.Array, c *mlx.Array) *mlx.Array {
	// c: [B, 256] -> scale: [B, dim]
	scale := mlx.SiLU(c)
	scale = fl.AdaLN.Forward(scale)
	scale = mlx.ExpandDims(scale, 1) // [B, 1, dim]

	// LayerNorm (affine=False) then scale
	x = layerNormNoAffine(x, 1e-6)
	x = mlx.Mul(x, mlx.AddScalar(scale, 1.0))

	// Output projection
	shape := x.Shape()
	B := shape[0]
	L := shape[1]
	D := shape[2]
	x = mlx.Reshape(x, B*L, D)
	x = fl.Output.Forward(x)
	return mlx.Reshape(x, B, L, fl.OutDim)
}

// layerNormNoAffine applies layer norm without learnable parameters
func layerNormNoAffine(x *mlx.Array, eps float32) *mlx.Array {
	ndim := x.Ndim()
	lastAxis := ndim - 1
	mean := mlx.Mean(x, lastAxis, true)
	xCentered := mlx.Sub(x, mean)
	variance := mlx.Mean(mlx.Square(xCentered), lastAxis, true)
	return mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, eps)))
}

// Transformer is the full Z-Image DiT model
type Transformer struct {
	TEmbed          *TimestepEmbedder   `weight:"t_embedder"`
	XEmbed          *XEmbedder          `weight:"all_x_embedder"`
	CapEmbed        *CapEmbedder        `weight:"cap_embedder"`
	NoiseRefiners   []*TransformerBlock `weight:"noise_refiner"`
	ContextRefiners []*TransformerBlock `weight:"context_refiner"`
	Layers          []*TransformerBlock `weight:"layers"`
	FinalLayer      *FinalLayer         `weight:"all_final_layer.2-1"`
	XPadToken       *mlx.Array          `weight:"x_pad_token"`
	CapPadToken     *mlx.Array          `weight:"cap_pad_token"`

	*TransformerConfig
}

// Load loads the Z-Image transformer from ollama blob storage.
func (m *Transformer) Load(manifest *imagegen.ModelManifest) error {
	fmt.Print(" Loading transformer... ")

	// Load config from blob
	var cfg TransformerConfig
	if err := manifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
		return fmt.Errorf("config: %w", err)
	}
	if len(cfg.AllPatchSize) > 0 {
		cfg.PatchSize = cfg.AllPatchSize[0]
	}
	m.TransformerConfig = &cfg
	m.NoiseRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
	m.ContextRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
	m.Layers = make([]*TransformerBlock, cfg.NLayers)

	// Load weights from tensor blobs with BF16 conversion
	weights, err := imagegen.LoadWeightsFromManifest(manifest, "transformer")
	if err != nil {
		return fmt.Errorf("weights: %w", err)
	}
	if err := weights.Load(mlx.DtypeBFloat16); err != nil {
		return fmt.Errorf("load weights: %w", err)
	}
	defer weights.ReleaseAll()

	return m.loadWeights(weights)
}

// loadWeights loads weights from any WeightSource into the model
func (m *Transformer) loadWeights(weights safetensors.WeightSource) error {
	if err := safetensors.LoadModule(m, weights, ""); err != nil {
		return fmt.Errorf("load module: %w", err)
	}
	m.initComputedFields()
	fmt.Println("✓")
	return nil
}

// initComputedFields initializes computed fields after loading weights
func (m *Transformer) initComputedFields() {
	cfg := m.TransformerConfig
	m.TEmbed.FreqEmbedSize = 256
	m.FinalLayer.OutDim = m.FinalLayer.Output.Weight.Shape()[0]
	m.CapEmbed.Norm.Eps = 1e-6
	for _, block := range m.NoiseRefiners {
		initTransformerBlock(block, cfg)
	}
	for _, block := range m.ContextRefiners {
		initTransformerBlock(block, cfg)
	}
	for _, block := range m.Layers {
		initTransformerBlock(block, cfg)
	}
}

// initTransformerBlock sets computed fields on a transformer block
func initTransformerBlock(block *TransformerBlock, cfg *TransformerConfig) {
	block.Dim = cfg.Dim
	block.HasModulation = block.AdaLN != nil

	// Init attention computed fields
	attn := block.Attention
	attn.NHeads = cfg.NHeads
	attn.HeadDim = cfg.Dim / cfg.NHeads
	attn.Dim = cfg.Dim
	attn.Scale = float32(1.0 / math.Sqrt(float64(attn.HeadDim)))

	// Init feedforward OutDim
	block.FeedForward.OutDim = block.FeedForward.W2.Weight.Shape()[0]

	// Set eps on all RMSNorm layers
	block.AttentionNorm1.Eps = cfg.NormEps
	block.AttentionNorm2.Eps = cfg.NormEps
	block.FFNNorm1.Eps = cfg.NormEps
	block.FFNNorm2.Eps = cfg.NormEps
}

// RoPECache holds precomputed RoPE values
type RoPECache struct {
	ImgCos     *mlx.Array
	ImgSin     *mlx.Array
	CapCos     *mlx.Array
	CapSin     *mlx.Array
	UnifiedCos *mlx.Array
	UnifiedSin *mlx.Array
	ImgLen     int32
	CapLen     int32
}
// PrepareRoPECache precomputes RoPE values for the given image and caption lengths.
// hTok and wTok are the number of tokens in each dimension (latentH/patchSize, latentW/patchSize).
func (m *Transformer) PrepareRoPECache(hTok, wTok, capLen int32) *RoPECache {
	imgLen := hTok * wTok

	// Image positions: grid over (1, H, W) starting at (capLen+1, 0, 0)
	imgPos := createCoordinateGrid(1, hTok, wTok, capLen+1, 0, 0)
	imgPos = mlx.ToBFloat16(imgPos)

	// Caption positions: grid over (capLen, 1, 1) starting at (1, 0, 0)
	capPos := createCoordinateGrid(capLen, 1, 1, 1, 0, 0)
	capPos = mlx.ToBFloat16(capPos)

	// Compute RoPE from UNIFIED positions
	unifiedPos := mlx.Concatenate([]*mlx.Array{imgPos, capPos}, 1)
	unifiedCos, unifiedSin := prepareRoPE3D(unifiedPos, m.TransformerConfig.AxesDims)

	// Slice RoPE for image and caption parts
	imgCos := mlx.Slice(unifiedCos, []int32{0, 0, 0, 0}, []int32{1, imgLen, 1, 64})
	imgSin := mlx.Slice(unifiedSin, []int32{0, 0, 0, 0}, []int32{1, imgLen, 1, 64})
	capCos := mlx.Slice(unifiedCos, []int32{0, imgLen, 0, 0}, []int32{1, imgLen + capLen, 1, 64})
	capSin := mlx.Slice(unifiedSin, []int32{0, imgLen, 0, 0}, []int32{1, imgLen + capLen, 1, 64})

	return &RoPECache{
		ImgCos:     imgCos,
		ImgSin:     imgSin,
		CapCos:     capCos,
		CapSin:     capSin,
		UnifiedCos: unifiedCos,
		UnifiedSin: unifiedSin,
		ImgLen:     imgLen,
		CapLen:     capLen,
	}
}

// Forward runs the Z-Image transformer with precomputed RoPE
func (m *Transformer) Forward(x *mlx.Array, t *mlx.Array, capFeats *mlx.Array, rope *RoPECache) *mlx.Array {
	imgLen := rope.ImgLen

	// Timestep embedding -> [B, 256]
	temb := m.TEmbed.Forward(mlx.MulScalar(t, m.TransformerConfig.TScale))

	// Embed image patches -> [B, L_img, dim]
	x = m.XEmbed.Forward(x)

	// Embed caption features -> [B, L_cap, dim]
	capEmb := m.CapEmbed.Forward(capFeats)

	eps := m.NormEps

	// Noise refiner: refine image patches with modulation
	for _, refiner := range m.NoiseRefiners {
		x = refiner.Forward(x, temb, rope.ImgCos, rope.ImgSin, eps)
	}

	// Context refiner: refine caption (no modulation)
	for _, refiner := range m.ContextRefiners {
		capEmb = refiner.Forward(capEmb, nil, rope.CapCos, rope.CapSin, eps)
	}

	// Concatenate image and caption for joint attention
	unified := mlx.Concatenate([]*mlx.Array{x, capEmb}, 1)

	// Main transformer layers use full unified RoPE
	for _, layer := range m.Layers {
		unified = layer.Forward(unified, temb, rope.UnifiedCos, rope.UnifiedSin, eps)
	}

	// Extract image tokens only
	unifiedShape := unified.Shape()
	B := unifiedShape[0]
	imgOut := mlx.Slice(unified, []int32{0, 0, 0}, []int32{B, imgLen, unifiedShape[2]})

	// Final layer
	return m.FinalLayer.Forward(imgOut, temb)
}
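// denoiseStepSketch is a hypothetical illustration of how one evaluation of
// the model is wired together: patchify the latents, build the RoPE cache,
// and call Forward. It is not part of the model API, it makes assumptions
// about dtypes and batch size (B = 1) that live outside this file, and real
// callers would prepare the RoPE cache once and reuse it across steps.
func denoiseStepSketch(m *Transformer, latents, capFeats *mlx.Array, t float32) *mlx.Array {
	shape := latents.Shape() // assumed [1, C, H, W]
	hTok := shape[2] / m.PatchSize
	wTok := shape[3] / m.PatchSize
	capLen := capFeats.Shape()[1]

	rope := m.PrepareRoPECache(hTok, wTok, capLen)
	x := PatchifyLatents(latents, m.PatchSize)     // [1, hTok*wTok, C*patch^2]
	tArr := mlx.NewArray([]float32{t}, []int32{1}) // [B] timestep; TScale is applied inside Forward
	return m.Forward(x, tArr, capFeats, rope)
}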
// ForwardWithCache runs the transformer with layer caching for faster inference.
// On refresh steps (step % cacheInterval == 0), all layers are computed and cached.
// On other steps, shallow layers (0 to cacheLayers-1) reuse cached outputs.
func (m *Transformer) ForwardWithCache(
	x *mlx.Array,
	t *mlx.Array,
	capFeats *mlx.Array,
	rope *RoPECache,
	stepCache *cache.StepCache,
	step int,
	cacheInterval int,
) *mlx.Array {
	imgLen := rope.ImgLen
	cacheLayers := stepCache.NumLayers()
	eps := m.NormEps

	// Timestep embedding -> [B, 256]
	temb := m.TEmbed.Forward(mlx.MulScalar(t, m.TransformerConfig.TScale))

	// Embed image patches -> [B, L_img, dim]
	x = m.XEmbed.Forward(x)

	// Context refiners: compute once on step 0, reuse forever
	// (caption embedding doesn't depend on timestep or latents)
	var capEmb *mlx.Array
	if stepCache.GetConstant() != nil {
		capEmb = stepCache.GetConstant()
	} else {
		capEmb = m.CapEmbed.Forward(capFeats)
		for _, refiner := range m.ContextRefiners {
			capEmb = refiner.Forward(capEmb, nil, rope.CapCos, rope.CapSin, eps)
		}
		stepCache.SetConstant(capEmb)
	}

	// Noise refiners: always compute (depend on x which changes each step)
	for _, refiner := range m.NoiseRefiners {
		x = refiner.Forward(x, temb, rope.ImgCos, rope.ImgSin, eps)
	}

	// Concatenate image and caption for joint attention
	unified := mlx.Concatenate([]*mlx.Array{x, capEmb}, 1)

	// Determine if this is a cache refresh step
	refreshCache := stepCache.ShouldRefresh(step, cacheInterval)

	// Main transformer layers with caching
	for i, layer := range m.Layers {
		if i < cacheLayers && !refreshCache && stepCache.Get(i) != nil {
			// Use cached output for shallow layers
			unified = stepCache.Get(i)
		} else {
			// Compute layer
			unified = layer.Forward(unified, temb, rope.UnifiedCos, rope.UnifiedSin, eps)
			// Cache shallow layer outputs on refresh steps
			if i < cacheLayers && refreshCache {
				stepCache.Set(i, unified)
			}
		}
	}

	// Extract image tokens only
	unifiedShape := unified.Shape()
	B := unifiedShape[0]
	imgOut := mlx.Slice(unified, []int32{0, 0, 0}, []int32{B, imgLen, unifiedShape[2]})

	// Final layer
	return m.FinalLayer.Forward(imgOut, temb)
}

// createCoordinateGrid creates 3D position grid [1, d0*d1*d2, 3]
func createCoordinateGrid(d0, d1, d2, s0, s1, s2 int32) *mlx.Array {
	// Create meshgrid and stack
	total := d0 * d1 * d2
	coords := make([]float32, total*3)
	idx := 0
	for i := int32(0); i < d0; i++ {
		for j := int32(0); j < d1; j++ {
			for k := int32(0); k < d2; k++ {
				coords[idx*3+0] = float32(s0 + i)
				coords[idx*3+1] = float32(s1 + j)
				coords[idx*3+2] = float32(s2 + k)
				idx++
			}
		}
	}
	return mlx.NewArray(coords, []int32{1, total, 3})
}
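// Worked example: with hTok = 2, wTok = 2 and capLen = 3, the caption grid
// createCoordinateGrid(3, 1, 1, 1, 0, 0) yields positions
// (1,0,0), (2,0,0), (3,0,0), while the image grid
// createCoordinateGrid(1, 2, 2, 4, 0, 0) yields
// (4,0,0), (4,0,1), (4,1,0), (4,1,1). Caption tokens advance along the first
// axis; image tokens share first-axis position capLen+1 and vary over (h, w).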
// prepareRoPE3D computes cos/sin for 3-axis RoPE
// positions: [B, L, 3] with (h, w, t) coordinates
// axesDims: [32, 48, 48] - dimensions for each axis
// Returns: cos, sin each [B, L, 1, head_dim/2]
func prepareRoPE3D(positions *mlx.Array, axesDims []int32) (*mlx.Array, *mlx.Array) {
	// Compute frequencies for each axis
	// dims = [32, 48, 48], so halves = [16, 24, 24]
	ropeTheta := float32(256.0)
	freqs := make([]*mlx.Array, 3)
	for axis := 0; axis < 3; axis++ {
		half := axesDims[axis] / 2
		f := make([]float32, half)
		for i := int32(0); i < half; i++ {
			f[i] = float32(math.Exp(-math.Log(float64(ropeTheta)) * float64(i) / float64(half)))
		}
		freqs[axis] = mlx.NewArray(f, []int32{1, 1, 1, half})
	}

	// Extract position coordinates
	shape := positions.Shape()
	B := shape[0]
	L := shape[1]

	// positions[:, :, 0] -> h positions
	posH := mlx.Slice(positions, []int32{0, 0, 0}, []int32{B, L, 1})
	posW := mlx.Slice(positions, []int32{0, 0, 1}, []int32{B, L, 2})
	posT := mlx.Slice(positions, []int32{0, 0, 2}, []int32{B, L, 3})

	// Compute args: pos * freqs for each axis
	posH = mlx.ExpandDims(posH, 3) // [B, L, 1, 1]
	posW = mlx.ExpandDims(posW, 3)
	posT = mlx.ExpandDims(posT, 3)
	argsH := mlx.Mul(posH, freqs[0]) // [B, L, 1, 16]
	argsW := mlx.Mul(posW, freqs[1]) // [B, L, 1, 24]
	argsT := mlx.Mul(posT, freqs[2]) // [B, L, 1, 24]

	// Concatenate: [B, L, 1, 16+24+24=64]
	args := mlx.Concatenate([]*mlx.Array{argsH, argsW, argsT}, 3)

	// Compute cos and sin
	return mlx.Cos(args), mlx.Sin(args)
}

// PatchifyLatents converts latents [B, C, H, W] to patches [B, L, C*patch^2]
// Matches Python: x.reshape(C, 1, 1, H_tok, 2, W_tok, 2).transpose(1,2,3,5,4,6,0).reshape(1,-1,C*4)
func PatchifyLatents(latents *mlx.Array, patchSize int32) *mlx.Array {
	shape := latents.Shape()
	C := shape[1]
	H := shape[2]
	W := shape[3]
	pH := H / patchSize // H_tok
	pW := W / patchSize // W_tok

	// Match Python exactly: reshape treating B=1 as part of contiguous data
	// [1, C, H, W] -> [C, 1, 1, pH, 2, pW, 2]
	x := mlx.Reshape(latents, C, 1, 1, pH, patchSize, pW, patchSize)

	// Python: transpose(1, 2, 3, 5, 4, 6, 0)
	// [C, 1, 1, pH, 2, pW, 2] -> [1, 1, pH, pW, 2, 2, C]
	x = mlx.Transpose(x, 1, 2, 3, 5, 4, 6, 0)

	// [1, 1, pH, pW, 2, 2, C] -> [1, pH*pW, C*4]
	return mlx.Reshape(x, 1, pH*pW, C*patchSize*patchSize)
}

// UnpatchifyLatents converts patches [B, L, C*patch^2] back to [B, C, H, W]
// Matches Python: out.reshape(1,1,H_tok,W_tok,2,2,C).transpose(6,0,1,2,4,3,5).reshape(1,C,H,W)
func UnpatchifyLatents(patches *mlx.Array, patchSize, H, W, C int32) *mlx.Array {
	pH := H / patchSize
	pW := W / patchSize

	// [1, L, C*4] -> [1, 1, pH, pW, 2, 2, C]
	x := mlx.Reshape(patches, 1, 1, pH, pW, patchSize, patchSize, C)

	// Python: transpose(6, 0, 1, 2, 4, 3, 5)
	// [1, 1, pH, pW, 2, 2, C] -> [C, 1, 1, pH, 2, pW, 2]
	x = mlx.Transpose(x, 6, 0, 1, 2, 4, 3, 5)

	// [C, 1, 1, pH, 2, pW, 2] -> [1, C, H, W]
	return mlx.Reshape(x, 1, C, H, W)
}
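// patchifyReference documents, in plain Go, the index mapping performed by
// PatchifyLatents above: latent element (c, h, w) of a [1, C, H, W] tensor
// lands in token ph*pW+pw at feature (py*p+px)*C+c, where h = ph*p+py and
// w = pw*p+px. This hypothetical helper is a reference sketch for testing the
// mapping only; it is not used by the model.
func patchifyReference(latents []float32, C, H, W, p int) [][]float32 {
	pH, pW := H/p, W/p
	out := make([][]float32, pH*pW)
	for tok := range out {
		out[tok] = make([]float32, C*p*p)
	}
	for c := 0; c < C; c++ {
		for h := 0; h < H; h++ {
			for w := 0; w < W; w++ {
				ph, py := h/p, h%p // patch row, offset within patch
				pw, px := w/p, w%p // patch col, offset within patch
				tok := ph*pW + pw
				feat := (py*p+px)*C + c
				out[tok][feat] = latents[(c*H+h)*W+w] // [1, C, H, W] row-major
			}
		}
	}
	return out
}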