Files
ollama/x/imagegen/models/qwen_image/vae.go
Daniel Hiltgen 33ee7168ba Add experimental MLX backend and engine with imagegen support (#13648)
* WIP - MLX backend with gemma3

* MLX: add cmake and go tag build toggles

To build the new MLX backend code:
  cmake --preset MLX
  cmake --build --preset MLX --parallel
  cmake --install build --component MLX
  go build -tags mlx .

Note: the main.go entrypoint for the MLX engine will change in a follow up commit.

* add experimental image generation runtime

* add experimental image generation runtime

* MLX: wire up cuda build for linux

* MLX: get dependencies correct and dedup

This is still too large for a unified github artifact, but is now "correct" for the mlx_cuda_v13
directory.

* fix relative link bug in dedup

* Add darwin build and readme

* add go build tag for mlx dependent code and wire up build_darwin.sh

* lint cleanup

* macos: build mlx for x86

This will be CPU only.

* cuda build instructions and fix drift from mlx bump

* stale comment

* Delete agent helper doc

* Clean up readme.md

* Revise README for tokenizer clarity and details

Updated README to clarify tokenizer functionality and removed correctness section.

---------

Co-authored-by: jmorganca <jmorganca@gmail.com>
2026-01-08 16:18:59 -08:00

855 lines
21 KiB
Go

//go:build mlx
package qwen_image
import (
"fmt"
"math"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// VAEConfig holds Qwen-Image VAE configuration
type VAEConfig struct {
ZDim int32 `json:"z_dim"` // 16
BaseDim int32 `json:"base_dim"` // 96
DimMult []int32 `json:"dim_mult"` // [1, 2, 4, 4]
NumResBlocks int32 `json:"num_res_blocks"` // 2
LatentsMean []float32 `json:"latents_mean"` // 16 values
LatentsStd []float32 `json:"latents_std"` // 16 values
TemperalDownsample []bool `json:"temperal_downsample"` // [false, true, true]
}
// defaultVAEConfig returns config for Qwen-Image VAE
func defaultVAEConfig() *VAEConfig {
return &VAEConfig{
ZDim: 16,
BaseDim: 96,
DimMult: []int32{1, 2, 4, 4},
NumResBlocks: 2,
LatentsMean: []float32{
-0.7571, -0.7089, -0.9113, 0.1075,
-0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632,
-0.1922, -0.9497, 0.2503, -0.2921,
},
LatentsStd: []float32{
2.8184, 1.4541, 2.3275, 2.6558,
1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579,
1.6382, 1.1253, 2.8251, 1.916,
},
TemperalDownsample: []bool{false, true, true},
}
}
// CausalConv3d is a causal 3D convolution (for temporal causality)
type CausalConv3d struct {
Weight *mlx.Array
Bias *mlx.Array
BiasReshaped *mlx.Array // [1, C, 1, 1, 1]
KernelT int32
}
// newCausalConv3d creates a 3D causal conv
func newCausalConv3d(weights *safetensors.ModelWeights, prefix string) (*CausalConv3d, error) {
weight, err := weights.Get(prefix + ".weight")
if err != nil {
return nil, fmt.Errorf("weight not found: %s", prefix)
}
bias, _ := weights.Get(prefix + ".bias")
kernelT := weight.Shape()[2]
outC := weight.Shape()[0]
var biasReshaped *mlx.Array
if bias != nil {
biasReshaped = mlx.Reshape(bias, 1, outC, 1, 1, 1)
}
return &CausalConv3d{
Weight: weight,
Bias: bias,
BiasReshaped: biasReshaped,
KernelT: kernelT,
}, nil
}
// Forward applies causal 3D convolution
// x: [B, T, H, W, C] (channels-last, MLX format)
func (c *CausalConv3d) Forward(x *mlx.Array) *mlx.Array {
shape := c.Weight.Shape() // PyTorch format: [O, I, kT, kH, kW]
kernelT := shape[2]
kernelH := shape[3]
kernelW := shape[4]
// Causal temporal padding, same spatial padding
// Input is channels-last: [B, T, H, W, C]
padT := kernelT - 1
padH := kernelH / 2
padW := kernelW / 2
// Stage 1: Pad
{
x = pad3DChannelsLast(x, padT, 0, padH, padH, padW, padW)
mlx.Eval(x)
}
// Stage 2: Conv + bias
var out *mlx.Array
{
prev := x
weight := mlx.Transpose(c.Weight, 0, 2, 3, 4, 1)
out = mlx.Conv3d(x, weight, 1, 1, 1, 0, 0, 0)
if c.Bias != nil {
bias := mlx.Reshape(c.Bias, 1, 1, 1, 1, c.Bias.Dim(0))
out = mlx.Add(out, bias)
}
prev.Free()
mlx.Eval(out)
}
return out
}
// RMSNorm3D applies RMS normalization over channels
// Works with channels-last [B, T, H, W, C] format
type RMSNorm3D struct {
Gamma *mlx.Array // [1, 1, 1, 1, C] for broadcasting
}
// newRMSNorm3D creates an RMS norm
func newRMSNorm3D(weights *safetensors.ModelWeights, prefix string, dim int32) (*RMSNorm3D, error) {
gamma, err := weights.Get(prefix + ".gamma")
if err != nil {
return nil, err
}
// Reshape for channels-last broadcasting: [1, 1, 1, 1, C]
gamma = mlx.Reshape(gamma, 1, 1, 1, 1, gamma.Dim(0))
return &RMSNorm3D{Gamma: gamma}, nil
}
// Forward applies RMS norm to channels-last input [B, T, H, W, C]
func (n *RMSNorm3D) Forward(x *mlx.Array) *mlx.Array {
// RMSNorm: x * rsqrt(mean(x^2) + eps) * gamma
normalized := mlx.RMSNormNoWeight(x, 1e-6)
return mlx.Mul(normalized, n.Gamma)
}
// ResBlock is a residual block with RMS norm and causal convs
type ResBlock struct {
Norm1 *RMSNorm3D
Conv1 *CausalConv3d
Norm2 *RMSNorm3D
Conv2 *CausalConv3d
Shortcut *CausalConv3d
}
// newResBlock creates a residual block
func newResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*ResBlock, error) {
norm1, err := newRMSNorm3D(weights, prefix+".norm1", inDim)
if err != nil {
return nil, err
}
conv1, err := newCausalConv3d(weights, prefix+".conv1")
if err != nil {
return nil, err
}
norm2, err := newRMSNorm3D(weights, prefix+".norm2", outDim)
if err != nil {
return nil, err
}
conv2, err := newCausalConv3d(weights, prefix+".conv2")
if err != nil {
return nil, err
}
var shortcut *CausalConv3d
if inDim != outDim {
shortcut, err = newCausalConv3d(weights, prefix+".conv_shortcut")
if err != nil {
return nil, err
}
}
return &ResBlock{
Norm1: norm1,
Conv1: conv1,
Norm2: norm2,
Conv2: conv2,
Shortcut: shortcut,
}, nil
}
// Forward applies the residual block
func (r *ResBlock) Forward(x *mlx.Array) *mlx.Array {
// Use h as working variable, keep x intact for residual (caller will free x)
// Conv handles its own pools, so we just need pools for non-conv operations
var h *mlx.Array
// Keep x so it survives Eval() cleanup - needed for residual connection
mlx.Keep(x)
// Stage 1: norm1 + silu
{
h = r.Norm1.Forward(x)
h = silu3D(h)
mlx.Eval(h)
}
// Stage 2: conv1 (handles its own pools)
{
prev := h
h = r.Conv1.Forward(h)
prev.Free()
}
// Stage 3: norm2 + silu
{
prev := h
h = r.Norm2.Forward(h)
h = silu3D(h)
prev.Free()
mlx.Eval(h)
}
// Stage 4: conv2 (handles its own pools)
{
prev := h
h = r.Conv2.Forward(h)
prev.Free()
}
// Residual connection (shortcut handles its own pools if present)
if r.Shortcut != nil {
shortcut := r.Shortcut.Forward(x)
h = mlx.Add(h, shortcut)
mlx.Eval(h)
} else {
h = mlx.Add(h, x)
mlx.Eval(h)
}
return h
}
// AttentionBlock is a 2D attention block
type AttentionBlock struct {
Norm *RMSNorm3D
ToQKV *mlx.Array
ToQKVBias *mlx.Array
Proj *mlx.Array
ProjBias *mlx.Array
Dim int32
}
// newAttentionBlock creates an attention block
func newAttentionBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*AttentionBlock, error) {
norm, err := newRMSNorm3D(weights, prefix+".norm", dim)
if err != nil {
return nil, err
}
toQKV, _ := weights.Get(prefix + ".to_qkv.weight")
toQKVBias, _ := weights.Get(prefix + ".to_qkv.bias")
proj, _ := weights.Get(prefix + ".proj.weight")
projBias, _ := weights.Get(prefix + ".proj.bias")
return &AttentionBlock{
Norm: norm,
ToQKV: toQKV,
ToQKVBias: toQKVBias,
Proj: proj,
ProjBias: projBias,
Dim: dim,
}, nil
}
// Forward applies 2D attention
// Input: [B, T, H, W, C] (channels-last)
func (a *AttentionBlock) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
T := shape[1]
H := shape[2]
W := shape[3]
C := shape[4]
identity := x
// Flatten to [B*T, 1, H, W, C] for norm
x = mlx.Reshape(x, B*T, 1, H, W, C)
x = a.Norm.Forward(x)
x = mlx.Reshape(x, B*T, H, W, C)
// Flatten spatial to [B*T, H*W, C]
x = mlx.Reshape(x, B*T, H*W, C)
// Linear to get Q, K, V: [B*T, H*W, 3*C]
// Weight is [outC, inC] or [outC, inC, 1, 1]
wShape := a.ToQKV.Shape()
var w *mlx.Array
if len(wShape) == 4 {
w = mlx.Reshape(a.ToQKV, wShape[0], wShape[1])
} else {
w = a.ToQKV
}
w = mlx.Transpose(w, 1, 0) // [inC, outC]
qkv := mlx.Linear(x, w) // [B*T, H*W, 3*C]
if a.ToQKVBias != nil {
qkv = mlx.Add(qkv, a.ToQKVBias)
}
qkv = mlx.Reshape(qkv, B*T, 1, H*W, 3*C)
q := mlx.Slice(qkv, []int32{0, 0, 0, 0}, []int32{B * T, 1, H * W, C})
k := mlx.Slice(qkv, []int32{0, 0, 0, C}, []int32{B * T, 1, H * W, 2 * C})
v := mlx.Slice(qkv, []int32{0, 0, 0, 2 * C}, []int32{B * T, 1, H * W, 3 * C})
scale := float32(1.0 / math.Sqrt(float64(C)))
out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
// out: [B*T, 1, H*W, C]
out = mlx.Reshape(out, B*T, H*W, C)
// Project back
pShape := a.Proj.Shape()
var p *mlx.Array
if len(pShape) == 4 {
p = mlx.Reshape(a.Proj, pShape[0], pShape[1])
} else {
p = a.Proj
}
p = mlx.Transpose(p, 1, 0) // [inC, outC]
out = mlx.Linear(out, p) // [B*T, H*W, C]
if a.ProjBias != nil {
out = mlx.Add(out, a.ProjBias)
}
out = mlx.Reshape(out, B, T, H, W, C)
return mlx.Add(out, identity)
}
// UpBlock handles upsampling in decoder
type UpBlock struct {
ResBlocks []*ResBlock
Upsampler *Upsample
}
// newUpBlock creates an up block
func newUpBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, upsampleMode string) (*UpBlock, error) {
resBlocks := make([]*ResBlock, numBlocks+1)
currentDim := inDim
for i := int32(0); i <= numBlocks; i++ {
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
block, err := newResBlock(weights, resPrefix, currentDim, outDim)
if err != nil {
return nil, err
}
resBlocks[i] = block
currentDim = outDim
}
var upsampler *Upsample
if upsampleMode != "" {
upsampler = newUpsample(weights, prefix+".upsamplers.0", outDim, upsampleMode)
}
return &UpBlock{
ResBlocks: resBlocks,
Upsampler: upsampler,
}, nil
}
// Forward applies up block with staged memory management
func (u *UpBlock) Forward(x *mlx.Array) *mlx.Array {
// ResBlocks handle their own pools
for _, block := range u.ResBlocks {
prev := x
x = block.Forward(x)
prev.Free()
}
// Upsampler handles its own pools
if u.Upsampler != nil {
prev := x
x = u.Upsampler.Forward(x)
prev.Free()
}
return x
}
// Upsample handles spatial upsampling
type Upsample struct {
Conv *mlx.Array
Bias *mlx.Array
Mode string
}
// newUpsample creates an upsampler
func newUpsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Upsample {
conv, _ := weights.Get(prefix + ".resample.1.weight")
bias, _ := weights.Get(prefix + ".resample.1.bias")
return &Upsample{
Conv: conv,
Bias: bias,
Mode: mode,
}
}
// Forward applies upsampling to channels-last input [B, T, H, W, C]
// Uses staged pools to reduce peak memory during 2x upsampling
func (u *Upsample) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
T := shape[1]
H := shape[2]
W := shape[3]
C := shape[4]
outC := u.Conv.Shape()[0]
// Stage 1: 2x nearest neighbor upsample
{
x = mlx.Reshape(x, B*T, H, W, C)
x = upsample2xChannelsLast(x)
mlx.Eval(x)
}
// Stage 2: Conv + bias
{
prev := x
weight := mlx.Transpose(u.Conv, 0, 2, 3, 1)
x = conv2D3x3PaddedChannelsLast(x, weight)
if u.Bias != nil {
bias := mlx.Reshape(u.Bias, 1, 1, 1, outC)
x = mlx.Add(x, bias)
}
x = mlx.Reshape(x, B, T, H*2, W*2, outC)
prev.Free()
mlx.Eval(x)
}
return x
}
// MidBlock is the middle block of decoder
type MidBlock struct {
ResBlock1 *ResBlock
Attention *AttentionBlock
ResBlock2 *ResBlock
}
// newMidBlock creates a mid block
func newMidBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*MidBlock, error) {
res1, err := newResBlock(weights, prefix+".resnets.0", dim, dim)
if err != nil {
return nil, err
}
attn, err := newAttentionBlock(weights, prefix+".attentions.0", dim)
if err != nil {
return nil, err
}
res2, err := newResBlock(weights, prefix+".resnets.1", dim, dim)
if err != nil {
return nil, err
}
return &MidBlock{
ResBlock1: res1,
Attention: attn,
ResBlock2: res2,
}, nil
}
// Forward applies mid block
func (m *MidBlock) Forward(x *mlx.Array) *mlx.Array {
// Each component handles its own pools; we just free inputs
prev := x
x = m.ResBlock1.Forward(x)
prev.Free()
prev = x
x = m.Attention.Forward(x)
prev.Free()
prev = x
x = m.ResBlock2.Forward(x)
prev.Free()
return x
}
// VAEDecoder is the full VAE decoder
type VAEDecoder struct {
Config *VAEConfig
PostQuantConv *CausalConv3d
ConvIn *CausalConv3d
MidBlock *MidBlock
UpBlocks []*UpBlock
NormOut *RMSNorm3D
ConvOut *CausalConv3d
}
// Load loads the VAE decoder from a directory
func (m *VAEDecoder) Load(path string) error {
fmt.Println("Loading Qwen-Image VAE decoder...")
cfg := defaultVAEConfig()
m.Config = cfg
weights, err := safetensors.LoadModelWeights(path)
if err != nil {
return fmt.Errorf("weights: %w", err)
}
// Bulk load all weights as bf16
fmt.Print(" Loading weights as bf16... ")
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
return fmt.Errorf("failed to load weights: %w", err)
}
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
fmt.Print(" Loading post_quant_conv... ")
postQuantConv, err := newCausalConv3d(weights, "post_quant_conv")
if err != nil {
return err
}
m.PostQuantConv = postQuantConv
fmt.Println("✓")
fmt.Print(" Loading conv_in... ")
convIn, err := newCausalConv3d(weights, "decoder.conv_in")
if err != nil {
return err
}
m.ConvIn = convIn
fmt.Println("✓")
// Mid block (dim = base_dim * dim_mult[-1] = 96 * 4 = 384)
fmt.Print(" Loading mid_block... ")
midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1]
midBlock, err := newMidBlock(weights, "decoder.mid_block", midDim)
if err != nil {
return err
}
m.MidBlock = midBlock
fmt.Println("✓")
// Up blocks (reversed dim_mult)
fmt.Print(" Loading up_blocks... ")
numUpBlocks := len(cfg.DimMult)
m.UpBlocks = make([]*UpBlock, numUpBlocks)
dimsMult := make([]int32, numUpBlocks+1)
dimsMult[0] = cfg.DimMult[numUpBlocks-1]
for i := 0; i < numUpBlocks; i++ {
dimsMult[i+1] = cfg.DimMult[numUpBlocks-1-i]
}
temporalUpsample := make([]bool, len(cfg.TemperalDownsample))
for i := range cfg.TemperalDownsample {
temporalUpsample[i] = cfg.TemperalDownsample[len(cfg.TemperalDownsample)-1-i]
}
for i := 0; i < numUpBlocks; i++ {
inDim := cfg.BaseDim * dimsMult[i]
outDim := cfg.BaseDim * dimsMult[i+1]
if i > 0 {
inDim = inDim / 2
}
upsampleMode := ""
if i < numUpBlocks-1 {
if temporalUpsample[i] {
upsampleMode = "upsample3d"
} else {
upsampleMode = "upsample2d"
}
}
prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
upBlock, err := newUpBlock(weights, prefix, inDim, outDim, cfg.NumResBlocks, upsampleMode)
if err != nil {
return err
}
m.UpBlocks[i] = upBlock
}
fmt.Printf("✓ [%d blocks]\n", numUpBlocks)
fmt.Print(" Loading output layers... ")
normOut, err := newRMSNorm3D(weights, "decoder.norm_out", cfg.BaseDim)
if err != nil {
return err
}
m.NormOut = normOut
convOut, err := newCausalConv3d(weights, "decoder.conv_out")
if err != nil {
return err
}
m.ConvOut = convOut
fmt.Println("✓")
weights.ReleaseAll()
return nil
}
// LoadVAEDecoderFromPath is a convenience function to load VAE from path
func LoadVAEDecoderFromPath(path string) (*VAEDecoder, error) {
m := &VAEDecoder{}
if err := m.Load(filepath.Join(path, "vae")); err != nil {
return nil, err
}
return m, nil
}
// Decode converts latents to image
// z: [B, C, T, H, W] normalized latents
// Uses staged pools to free intermediate arrays and reduce peak memory.
func (vae *VAEDecoder) Decode(z *mlx.Array) *mlx.Array {
var x *mlx.Array
// Stage 1a: Denormalize and transpose
{
z = vae.Denormalize(z)
// Convert from channels-first [N, C, T, H, W] to channels-last [N, T, H, W, C]
z = mlx.Contiguous(mlx.Transpose(z, 0, 2, 3, 4, 1))
mlx.Eval(z)
}
// Stage 1b: PostQuantConv (handles its own pools)
x = vae.PostQuantConv.Forward(z)
z.Free()
// Stage 1c: ConvIn (handles its own pools)
{
prev := x
x = vae.ConvIn.Forward(x)
prev.Free()
}
// Stage 2: Mid block (handles its own pools)
x = vae.MidBlock.Forward(x)
// Stage 3: Up blocks (each handles its own pools)
for _, upBlock := range vae.UpBlocks {
x = upBlock.Forward(x)
}
// Stage 4a: NormOut + silu
{
prev := x
x = vae.NormOut.Forward(x)
x = silu3D(x)
prev.Free()
mlx.Eval(x)
}
// Stage 4b: ConvOut (handles its own pools)
{
prev := x
x = vae.ConvOut.Forward(x)
prev.Free()
}
// Stage 4c: Post-processing
{
prev := x
// Clamp to [-1, 1]
x = mlx.ClipScalar(x, -1.0, 1.0, true, true)
// Convert back from channels-last to channels-first
x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3))
prev.Free()
mlx.Eval(x)
}
return x
}
// Denormalize reverses the normalization applied during encoding
func (vae *VAEDecoder) Denormalize(z *mlx.Array) *mlx.Array {
shape := z.Shape()
C := shape[1]
mean := mlx.NewArray(vae.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1})
std := mlx.NewArray(vae.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1})
mean = mlx.ToBFloat16(mean)
std = mlx.ToBFloat16(std)
return mlx.Add(mlx.Mul(z, std), mean)
}
// Helper functions
func silu3D(x *mlx.Array) *mlx.Array {
return mlx.Mul(x, mlx.Sigmoid(x))
}
// pad3DChannelsLast pads a channels-last [B, T, H, W, C] tensor
func pad3DChannelsLast(x *mlx.Array, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array {
if tBefore == 0 && tAfter == 0 && hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 {
return x
}
// Pad dims: [B before, B after, T before, T after, H before, H after, W before, W after, C before, C after]
return mlx.Pad(x, []int32{0, 0, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter, 0, 0})
}
func pad2D(x *mlx.Array, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array {
if hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 {
return x
}
return mlx.Pad(x, []int32{0, 0, 0, 0, hBefore, hAfter, wBefore, wAfter})
}
func conv2D1x1(x, weight *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
H := shape[2]
W := shape[3]
x = mlx.Transpose(x, 0, 2, 3, 1)
x = mlx.Reshape(x, B*H*W, shape[1])
wShape := weight.Shape()
var w *mlx.Array
if len(wShape) == 4 {
w = mlx.Reshape(weight, wShape[0], wShape[1])
} else {
w = weight
}
w = mlx.Transpose(w, 1, 0)
out := mlx.Linear(x, w)
outC := w.Dim(1)
out = mlx.Reshape(out, B, H, W, outC)
return mlx.Transpose(out, 0, 3, 1, 2)
}
func conv2D3x3Padded(x, weight *mlx.Array) *mlx.Array {
x = pad2D(x, 1, 1, 1, 1)
return conv2D(x, weight, 1, 1)
}
func conv2D(x, w *mlx.Array, strideH, strideW int32) *mlx.Array {
x = mlx.Transpose(x, 0, 2, 3, 1)
w = mlx.Transpose(w, 0, 2, 3, 1)
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
wShape := w.Shape()
Cout := wShape[0]
kH := wShape[1]
kW := wShape[2]
outH := (H-kH)/strideH + 1
outW := (W-kW)/strideW + 1
patches := extractPatches2D(x, kH, kW, strideH, strideW)
wFlat := mlx.Reshape(w, Cout, -1)
patches = mlx.Reshape(patches, B*outH*outW, -1)
out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
out = mlx.Reshape(out, B, outH, outW, Cout)
return mlx.Transpose(out, 0, 3, 1, 2)
}
func extractPatches2D(x *mlx.Array, kH, kW, strideH, strideW int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
C := shape[3]
outH := (H-kH)/strideH + 1
outW := (W-kW)/strideW + 1
patches := make([]*mlx.Array, outH*outW)
idx := 0
for i := int32(0); i < outH; i++ {
for j := int32(0); j < outW; j++ {
startH := i * strideH
startW := j * strideW
patch := mlx.Slice(x, []int32{0, startH, startW, 0}, []int32{B, startH + kH, startW + kW, C})
patch = mlx.Reshape(patch, B, kH*kW*C)
patches[idx] = patch
idx++
}
}
for i := range patches {
patches[i] = mlx.ExpandDims(patches[i], 1)
}
stacked := mlx.Concatenate(patches, 1)
return mlx.Reshape(stacked, B, outH, outW, kH*kW*C)
}
func upsample2x(x *mlx.Array) *mlx.Array {
shape := x.Shape()
H := shape[2]
W := shape[3]
rowIdxData := make([]int32, H*2)
for i := int32(0); i < H; i++ {
rowIdxData[i*2] = i
rowIdxData[i*2+1] = i
}
rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2})
colIdxData := make([]int32, W*2)
for i := int32(0); i < W; i++ {
colIdxData[i*2] = i
colIdxData[i*2+1] = i
}
colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2})
x = mlx.Take(x, rowIdx, 2)
x = mlx.Take(x, colIdx, 3)
return x
}
// upsample2xChannelsLast upsamples channels-last input [B, H, W, C] by 2x
func upsample2xChannelsLast(x *mlx.Array) *mlx.Array {
shape := x.Shape()
H := shape[1]
W := shape[2]
// Create repeat indices for rows
rowIdxData := make([]int32, H*2)
for i := int32(0); i < H; i++ {
rowIdxData[i*2] = i
rowIdxData[i*2+1] = i
}
rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2})
// Create repeat indices for columns
colIdxData := make([]int32, W*2)
for i := int32(0); i < W; i++ {
colIdxData[i*2] = i
colIdxData[i*2+1] = i
}
colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2})
// Take along H (axis 1) then W (axis 2)
x = mlx.Take(x, rowIdx, 1)
x = mlx.Take(x, colIdx, 2)
return x
}
// conv2D3x3PaddedChannelsLast applies 3x3 conv with padding to channels-last input [B, H, W, C]
// weight: [outC, kH, kW, inC] (MLX channels-last format)
func conv2D3x3PaddedChannelsLast(x, weight *mlx.Array) *mlx.Array {
// Pad spatial dims: [B, H, W, C] -> pad H and W by 1 each side
x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0})
// Conv2d expects: input [B, H, W, inC], weight [outC, kH, kW, inC]
// stride=1, padding=0 (we already padded manually)
return mlx.Conv2d(x, weight, 1, 0)
}