//go:build mlx

// Package qwen_image_edit contains the MLX-based image generation pipeline
// for Qwen image editing; this file provides the VAE decoder building
// blocks: causal 3D convolutions, RMS-normalized residual blocks, per-frame
// spatial attention, and 2x nearest-neighbor upsampling.
package qwen_image_edit

import (
	"fmt"
	"math"

	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/safetensors"
)

// CausalConv3d is a causal 3D convolution (for temporal causality)
type CausalConv3d struct {
	Weight       *mlx.Array
	Bias         *mlx.Array
	BiasReshaped *mlx.Array // [1, C, 1, 1, 1]
	KernelT      int32
}

// newCausalConv3d creates a 3D causal conv
func newCausalConv3d(weights *safetensors.ModelWeights, prefix string) (*CausalConv3d, error) {
	weight, err := weights.Get(prefix + ".weight")
	if err != nil {
		return nil, fmt.Errorf("weight not found: %s", prefix)
	}
	bias, _ := weights.Get(prefix + ".bias")

	kernelT := weight.Shape()[2]
	outC := weight.Shape()[0]

	var biasReshaped *mlx.Array
	if bias != nil {
		biasReshaped = mlx.Reshape(bias, 1, outC, 1, 1, 1)
	}

	return &CausalConv3d{
		Weight:       weight,
		Bias:         bias,
		BiasReshaped: biasReshaped,
		KernelT:      kernelT,
	}, nil
}

// Forward applies causal 3D convolution (or 2D if the weight is 4D)
// x: [B, T, H, W, C] (channels-last, MLX format)
func (c *CausalConv3d) Forward(x *mlx.Array) *mlx.Array {
	shape := c.Weight.Shape()

	// Handle both 5D (3D conv) and 4D (2D conv) weights
	if len(shape) == 4 {
		// 2D conv: [O, I, kH, kW] - apply per-frame
		return c.forward2D(x)
	}

	// 3D conv: [O, I, kT, kH, kW]
	kernelT := shape[2]
	kernelH := shape[3]
	kernelW := shape[4]

	// Causal temporal padding, 'same' spatial padding
	padT := kernelT - 1
	padH := kernelH / 2
	padW := kernelW / 2

	// Stage 1: Pad
	{
		x = pad3DChannelsLast(x, padT, 0, padH, padH, padW, padW)
		mlx.Eval(x)
	}

	// Stage 2: Conv + bias
	var out *mlx.Array
	{
		prev := x
		// [O, I, kT, kH, kW] -> [O, kT, kH, kW, I] for channels-last conv
		weight := mlx.Transpose(c.Weight, 0, 2, 3, 4, 1)
		out = mlx.Conv3d(x, weight, 1, 1, 1, 0, 0, 0)
		if c.Bias != nil {
			bias := mlx.Reshape(c.Bias, 1, 1, 1, 1, c.Bias.Dim(0))
			out = mlx.Add(out, bias)
		}
		prev.Free()
		mlx.Eval(out)
	}

	return out
}

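// Causality note: pad3DChannelsLast above adds kernelT-1 frames of zero
// padding before the clip and none after, so output frame t never sees
// input frames later than t. For example, with kT=3 and stride 1, output
// frame 0 is computed from [pad, pad, x0], frame 1 from [pad, x0, x1], and
// frame 2 from [x0, x1, x2].
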
// forward2D applies a 2D conv per frame to [B, T, H, W, C] input
func (c *CausalConv3d) forward2D(x *mlx.Array) *mlx.Array {
	xShape := x.Shape()
	B := xShape[0]
	T := xShape[1]
	H := xShape[2]
	W := xShape[3]
	C := xShape[4]

	wShape := c.Weight.Shape() // [O, I, kH, kW]
	kernelH := wShape[2]
	kernelW := wShape[3]
	outC := wShape[0]

	padH := kernelH / 2
	padW := kernelW / 2

	// Fold time into the batch: [B, T, H, W, C] -> [B*T, H, W, C]
	x = mlx.Reshape(x, B*T, H, W, C)

	// Pad spatially ('same' padding)
	x = mlx.Pad(x, []int32{0, 0, padH, padH, padW, padW, 0, 0})

	// Apply 2D conv
	weight := mlx.Transpose(c.Weight, 0, 2, 3, 1) // [O, I, kH, kW] -> [O, kH, kW, I]
	x = mlx.Conv2d(x, weight, 1, 0)

	if c.Bias != nil {
		bias := mlx.Reshape(c.Bias, 1, 1, 1, outC)
		x = mlx.Add(x, bias)
	}

	// With stride 1 and 'same' padding the spatial dims are unchanged
	outH := H
	outW := W

	// Reshape back to [B, T, H, W, C]
	x = mlx.Reshape(x, B, T, outH, outW, outC)
	mlx.Eval(x)

	return x
}

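// Note: padH = kernelH/2 gives exact 'same' padding only for odd kernel
// sizes (out = H + 2*(k/2) - k + 1, which equals H when k is odd). For an
// even k the output would be H+1 and the reshape back to [B, T, H, W, C]
// would fail, so the spatial kernels here are presumably all odd-sized
// (e.g. 3x3 -> pad 1).
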
// RMSNorm3D applies RMS normalization over channels
type RMSNorm3D struct {
	Gamma *mlx.Array // [1, 1, 1, 1, C] for broadcasting
}

// newRMSNorm3D creates an RMS norm
func newRMSNorm3D(weights *safetensors.ModelWeights, prefix string, dim int32) (*RMSNorm3D, error) {
	gamma, err := weights.Get(prefix + ".gamma")
	if err != nil {
		return nil, err
	}
	gamma = mlx.Reshape(gamma, 1, 1, 1, 1, gamma.Dim(0))
	return &RMSNorm3D{Gamma: gamma}, nil
}

// Forward applies RMS norm to channels-last input [B, T, H, W, C]
func (n *RMSNorm3D) Forward(x *mlx.Array) *mlx.Array {
	normalized := mlx.RMSNormNoWeight(x, 1e-6)
	return mlx.Mul(normalized, n.Gamma)
}

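// Assuming mlx.RMSNormNoWeight is a standard weightless RMSNorm over the
// trailing (channel) axis, the effective computation is:
//
//	y = gamma * x / sqrt(mean_C(x^2) + 1e-6)
//
// Gamma is pre-reshaped to [1, 1, 1, 1, C] so the multiply broadcasts over
// [B, T, H, W, C] without any transposes.
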
// ResBlock is a residual block with RMS norm and causal convs
type ResBlock struct {
	Norm1    *RMSNorm3D
	Conv1    *CausalConv3d
	Norm2    *RMSNorm3D
	Conv2    *CausalConv3d
	Shortcut *CausalConv3d // projection used when in/out dims differ; nil otherwise
}

// newResBlock creates a residual block
func newResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*ResBlock, error) {
	norm1, err := newRMSNorm3D(weights, prefix+".norm1", inDim)
	if err != nil {
		return nil, err
	}
	conv1, err := newCausalConv3d(weights, prefix+".conv1")
	if err != nil {
		return nil, err
	}
	norm2, err := newRMSNorm3D(weights, prefix+".norm2", outDim)
	if err != nil {
		return nil, err
	}
	conv2, err := newCausalConv3d(weights, prefix+".conv2")
	if err != nil {
		return nil, err
	}

	// A conv shortcut is only present when the channel count changes
	var shortcut *CausalConv3d
	if inDim != outDim {
		shortcut, err = newCausalConv3d(weights, prefix+".conv_shortcut")
		if err != nil {
			return nil, err
		}
	}

	return &ResBlock{
		Norm1:    norm1,
		Conv1:    conv1,
		Norm2:    norm2,
		Conv2:    conv2,
		Shortcut: shortcut,
	}, nil
}

// Forward applies the residual block
func (r *ResBlock) Forward(x *mlx.Array) *mlx.Array {
	var h *mlx.Array

	// Retain x: it is still needed for the residual connection below
	mlx.Keep(x)

	// Stage 1: norm1 + silu
	{
		h = r.Norm1.Forward(x)
		h = silu3D(h)
		mlx.Eval(h)
	}

	// Stage 2: conv1
	{
		prev := h
		h = r.Conv1.Forward(h)
		prev.Free()
	}

	// Stage 3: norm2 + silu
	{
		prev := h
		h = r.Norm2.Forward(h)
		h = silu3D(h)
		prev.Free()
		mlx.Eval(h)
	}

	// Stage 4: conv2
	{
		prev := h
		h = r.Conv2.Forward(h)
		prev.Free()
	}

	// Residual connection
	if r.Shortcut != nil {
		shortcut := r.Shortcut.Forward(x)
		h = mlx.Add(h, shortcut)
		mlx.Eval(h)
	} else {
		h = mlx.Add(h, x)
		mlx.Eval(h)
	}

	return h
}

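// Memory staging: each stage above is evaluated eagerly (mlx.Eval) and its
// input freed (prev.Free) before the next stage runs, so at most one large
// [B, T, H, W, C] intermediate is alive at a time. mlx.Keep(x) retains the
// block input across those frees so it is still valid for the residual Add
// at the end; ownership of x itself stays with the caller.
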
// AttentionBlock is a 2D (per-frame) spatial attention block
type AttentionBlock struct {
	Norm      *RMSNorm3D
	ToQKV     *mlx.Array
	ToQKVBias *mlx.Array
	Proj      *mlx.Array
	ProjBias  *mlx.Array
	Dim       int32
}

// newAttentionBlock creates an attention block
func newAttentionBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*AttentionBlock, error) {
	norm, err := newRMSNorm3D(weights, prefix+".norm", dim)
	if err != nil {
		return nil, err
	}
	toQKV, _ := weights.Get(prefix + ".to_qkv.weight")
	toQKVBias, _ := weights.Get(prefix + ".to_qkv.bias")
	proj, _ := weights.Get(prefix + ".proj.weight")
	projBias, _ := weights.Get(prefix + ".proj.bias")

	return &AttentionBlock{
		Norm:      norm,
		ToQKV:     toQKV,
		ToQKVBias: toQKVBias,
		Proj:      proj,
		ProjBias:  projBias,
		Dim:       dim,
	}, nil
}

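// Note: to_qkv and proj may be stored as 1x1 convolution weights with shape
// [O, I, 1, 1] rather than plain [O, I] linear weights; Forward below checks
// for a 4D shape and flattens to [O, I] before use, so both layouts work.
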
// Forward applies 2D attention
// Input: [B, T, H, W, C] (channels-last)
func (a *AttentionBlock) Forward(x *mlx.Array) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	T := shape[1]
	H := shape[2]
	W := shape[3]
	C := shape[4]

	identity := x

	// Reshape to [B*T, 1, H, W, C] for the norm
	x = mlx.Reshape(x, B*T, 1, H, W, C)
	x = a.Norm.Forward(x)
	x = mlx.Reshape(x, B*T, H, W, C)

	// Flatten spatial to [B*T, H*W, C]
	x = mlx.Reshape(x, B*T, H*W, C)

	// Linear to get packed Q, K, V
	wShape := a.ToQKV.Shape()
	var w *mlx.Array
	if len(wShape) == 4 {
		w = mlx.Reshape(a.ToQKV, wShape[0], wShape[1])
	} else {
		w = a.ToQKV
	}
	w = mlx.Transpose(w, 1, 0)

	qkv := mlx.Linear(x, w)
	if a.ToQKVBias != nil {
		qkv = mlx.Add(qkv, a.ToQKVBias)
	}
	qkv = mlx.Reshape(qkv, B*T, 1, H*W, 3*C)

	// Split into Q, K, V, each [B*T, 1, H*W, C]
	q := mlx.Slice(qkv, []int32{0, 0, 0, 0}, []int32{B * T, 1, H * W, C})
	k := mlx.Slice(qkv, []int32{0, 0, 0, C}, []int32{B * T, 1, H * W, 2 * C})
	v := mlx.Slice(qkv, []int32{0, 0, 0, 2 * C}, []int32{B * T, 1, H * W, 3 * C})

	scale := float32(1.0 / math.Sqrt(float64(C)))
	out := mlx.ScaledDotProductAttention(q, k, v, scale, false)

	out = mlx.Reshape(out, B*T, H*W, C)

	// Project back
	pShape := a.Proj.Shape()
	var p *mlx.Array
	if len(pShape) == 4 {
		p = mlx.Reshape(a.Proj, pShape[0], pShape[1])
	} else {
		p = a.Proj
	}
	p = mlx.Transpose(p, 1, 0)
	out = mlx.Linear(out, p)
	if a.ProjBias != nil {
		out = mlx.Add(out, a.ProjBias)
	}

	out = mlx.Reshape(out, B, T, H, W, C)
	return mlx.Add(out, identity)
}

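// Shape walkthrough: attention runs per frame as a single head over the H*W
// spatial positions. Tokens are [B*T, H*W, C]; to_qkv produces a packed
// [B*T, 1, H*W, 3*C] tensor that is sliced into Q, K, V of width C each, and
// scores are scaled by 1/sqrt(C) since the head dim is the full channel
// width. The result is reshaped back to [B, T, H, W, C] and added to the
// untouched input (identity) as a residual.
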
// UpBlock handles upsampling in decoder
type UpBlock struct {
	ResBlocks []*ResBlock
	Upsampler *Upsample
}

// newUpBlock creates an up block
func newUpBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, upsampleMode string) (*UpBlock, error) {
	resBlocks := make([]*ResBlock, numBlocks+1)

	currentDim := inDim
	for i := int32(0); i <= numBlocks; i++ {
		resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
		block, err := newResBlock(weights, resPrefix, currentDim, outDim)
		if err != nil {
			return nil, err
		}
		resBlocks[i] = block
		currentDim = outDim
	}

	var upsampler *Upsample
	if upsampleMode != "" {
		upsampler = newUpsample(weights, prefix+".upsamplers.0", outDim, upsampleMode)
	}

	return &UpBlock{
		ResBlocks: resBlocks,
		Upsampler: upsampler,
	}, nil
}

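// Note: the loop above is inclusive, so each up block holds numBlocks+1
// resnets, and only the first can change channel width (currentDim flips to
// outDim after it). This presumably mirrors the diffusers decoder convention
// of layers_per_block+1 resnets per up block.
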
// Forward applies up block
func (u *UpBlock) Forward(x *mlx.Array) *mlx.Array {
	for _, block := range u.ResBlocks {
		prev := x
		x = block.Forward(x)
		prev.Free()
	}

	if u.Upsampler != nil {
		prev := x
		x = u.Upsampler.Forward(x)
		prev.Free()
	}
	return x
}

// Upsample handles spatial upsampling
type Upsample struct {
	Conv *mlx.Array
	Bias *mlx.Array
	Mode string
}

// newUpsample creates an upsampler
func newUpsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Upsample {
	conv, _ := weights.Get(prefix + ".resample.1.weight")
	bias, _ := weights.Get(prefix + ".resample.1.bias")
	return &Upsample{
		Conv: conv,
		Bias: bias,
		Mode: mode,
	}
}

// Forward applies upsampling to channels-last input [B, T, H, W, C]
func (u *Upsample) Forward(x *mlx.Array) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	T := shape[1]
	H := shape[2]
	W := shape[3]
	C := shape[4]
	outC := u.Conv.Shape()[0]

	// Stage 1: 2x nearest neighbor upsample
	{
		x = mlx.Reshape(x, B*T, H, W, C)
		x = upsample2xChannelsLast(x)
		mlx.Eval(x)
	}

	// Stage 2: Conv + bias
	{
		prev := x
		weight := mlx.Transpose(u.Conv, 0, 2, 3, 1)
		x = conv2D3x3PaddedChannelsLast(x, weight)
		if u.Bias != nil {
			bias := mlx.Reshape(u.Bias, 1, 1, 1, outC)
			x = mlx.Add(x, bias)
		}
		x = mlx.Reshape(x, B, T, H*2, W*2, outC)
		prev.Free()
		mlx.Eval(x)
	}

	return x
}

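// The upsampler is "nearest-neighbor 2x, then 3x3 conv": doubling H and W by
// index duplication before convolving is a common alternative to transposed
// convolution that avoids the checkerboard artifacts it can introduce. Note
// that Mode is stored but not consulted here; Forward always takes the 2x
// nearest + conv path.
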
// MidBlock is the middle (bottleneck) block: resnet -> attention -> resnet
type MidBlock struct {
	ResBlock1 *ResBlock
	Attention *AttentionBlock
	ResBlock2 *ResBlock
}

// newMidBlock creates a mid block
func newMidBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*MidBlock, error) {
	res1, err := newResBlock(weights, prefix+".resnets.0", dim, dim)
	if err != nil {
		return nil, err
	}
	attn, err := newAttentionBlock(weights, prefix+".attentions.0", dim)
	if err != nil {
		return nil, err
	}
	res2, err := newResBlock(weights, prefix+".resnets.1", dim, dim)
	if err != nil {
		return nil, err
	}

	return &MidBlock{
		ResBlock1: res1,
		Attention: attn,
		ResBlock2: res2,
	}, nil
}

// Forward applies mid block
func (m *MidBlock) Forward(x *mlx.Array) *mlx.Array {
	prev := x
	x = m.ResBlock1.Forward(x)
	prev.Free()

	prev = x
	x = m.Attention.Forward(x)
	prev.Free()

	prev = x
	x = m.ResBlock2.Forward(x)
	prev.Free()

	return x
}

// Helper functions

// silu3D applies the SiLU activation x * sigmoid(x) elementwise
func silu3D(x *mlx.Array) *mlx.Array {
	return mlx.Mul(x, mlx.Sigmoid(x))
}

// pad3DChannelsLast pads a channels-last [B, T, H, W, C] tensor
func pad3DChannelsLast(x *mlx.Array, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array {
	if tBefore == 0 && tAfter == 0 && hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 {
		return x
	}
	return mlx.Pad(x, []int32{0, 0, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter, 0, 0})
}

// upsample2xChannelsLast upsamples channels-last input [B, H, W, C] by 2x
// using nearest-neighbor interpolation
func upsample2xChannelsLast(x *mlx.Array) *mlx.Array {
	shape := x.Shape()
	H := shape[1]
	W := shape[2]

	// Row indices [0, 0, 1, 1, ..., H-1, H-1]
	rowIdxData := make([]int32, H*2)
	for i := int32(0); i < H; i++ {
		rowIdxData[i*2] = i
		rowIdxData[i*2+1] = i
	}
	rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2})

	// Column indices [0, 0, 1, 1, ..., W-1, W-1]
	colIdxData := make([]int32, W*2)
	for i := int32(0); i < W; i++ {
		colIdxData[i*2] = i
		colIdxData[i*2+1] = i
	}
	colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2})

	x = mlx.Take(x, rowIdx, 1)
	x = mlx.Take(x, colIdx, 2)

	return x
}

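// Example: for H=2 the row index array is [0, 0, 1, 1], so mlx.Take along
// axis 1 repeats each row twice; the same applies to columns along axis 2.
// A [1, 2, 2, C] input therefore becomes [1, 4, 4, C] with each pixel
// duplicated into a 2x2 block - plain nearest-neighbor upsampling.
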
// conv2D3x3PaddedChannelsLast applies a 3x3 conv with 'same' padding
// (1 pixel per side) to channels-last input [B, H, W, C]
func conv2D3x3PaddedChannelsLast(x, weight *mlx.Array) *mlx.Array {
	x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0})
	return mlx.Conv2d(x, weight, 1, 0)
}

// conv2DStrided applies conv with stride > 1 using manual patch extraction
// x: [B, H, W, C] (channels-last), weight: [O, kH, kW, I]
func conv2DStrided(x, weight *mlx.Array, stride int32) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	H := shape[1]
	W := shape[2]

	wShape := weight.Shape()
	Cout := wShape[0]
	kH := wShape[1]
	kW := wShape[2]

	outH := (H - kH) / stride + 1
	outW := (W - kW) / stride + 1

	patches := extractPatches2DStrided(x, kH, kW, stride)
	wFlat := mlx.Reshape(weight, Cout, -1)
	patches = mlx.Reshape(patches, B*outH*outW, -1)
	out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
	return mlx.Reshape(out, B, outH, outW, Cout)
}

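// This is the classic im2col reduction: each kH*kW*C receptive field is
// flattened into one row, turning the strided convolution into a single
// matmul. E.g. a [1, 8, 8, C] input with a 2x2 kernel and stride 2 yields
// 4*4 = 16 patches of length 4*C, multiplied by the [4*C, Cout] weight to
// give [1, 4, 4, Cout].
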
// conv3DStrided applies a 3D conv with strides using manual patch extraction
// x: [B, T, H, W, C] (channels-last), weight: [O, I, kT, kH, kW] (PyTorch format)
// strideT, strideH, strideW are the strides for each dimension.
// Patches are extracted in [C, T, H, W] order to match Python's preprocessing.
func conv3DStrided(x, weight *mlx.Array, strideT, strideH, strideW int32) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	T := shape[1]
	H := shape[2]
	W := shape[3]
	C := shape[4]

	wShape := weight.Shape()
	Cout := wShape[0]
	kT := wShape[2]
	kH := wShape[3]
	kW := wShape[4]

	// Temporal handling: if T < kT, repeat frames to reach kT. For a single
	// image with T=1 and kT=2 the frame is duplicated, matching Python
	// Qwen2.5-VL, which duplicates rather than zero-pads. (Assumes T == 1
	// here; tiling by kT then yields exactly kT frames.)
	if T < kT {
		// Tile along T: [B, T, H, W, C] -> [B, kT, H, W, C]
		x = mlx.Tile(x, []int32{1, kT, 1, 1, 1})
		T = kT
	}

	outT := (T - kT) / strideT + 1
	outH := (H - kH) / strideH + 1
	outW := (W - kW) / strideW + 1

	// Extract 3D patches in [C, T, H, W] order to match Python
	patches := extractPatches3DStrided(x, kT, kH, kW, strideT, strideH, strideW)
	// patches shape: [B, outT, outH, outW, C*kT*kH*kW]

	// Weight is [O, I, kT, kH, kW] - flatten to [O, I*kT*kH*kW] to match the
	// patch order [C, T, H, W]
	wFlat := mlx.Reshape(weight, Cout, -1) // [Cout, I*kT*kH*kW]
	patches = mlx.Reshape(patches, B*outT*outH*outW, C*kT*kH*kW)
	out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
	return mlx.Reshape(out, B, outT, outH, outW, Cout)
}

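// When the strides equal the kernel sizes this reduces to non-overlapping
// patch embedding: e.g. kT=2, kH=kW=14 with matching strides turns
// [B, 2, 28, 28, C] into [B, 1, 2, 2, Cout], one output per 2x14x14 patch.
// Flattening patch values in [C, T, H, W] order lines up exactly with the
// row-major flattening of the PyTorch [O, I, kT, kH, kW] weight.
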
// extractPatches3DStrided extracts 3D patches with the given strides.
// Returned patch values are in [C, T, H, W] order to match Python's
// preprocessing.
func extractPatches3DStrided(x *mlx.Array, kT, kH, kW, strideT, strideH, strideW int32) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	T := shape[1]
	H := shape[2]
	W := shape[3]
	C := shape[4]

	outT := (T - kT) / strideT + 1
	outH := (H - kH) / strideH + 1
	outW := (W - kW) / strideW + 1

	numPatches := outT * outH * outW
	patches := make([]*mlx.Array, numPatches)
	idx := 0
	for t := int32(0); t < outT; t++ {
		for i := int32(0); i < outH; i++ {
			for j := int32(0); j < outW; j++ {
				startT := t * strideT
				startH := i * strideH
				startW := j * strideW
				// Extract patch: [B, kT, kH, kW, C]
				patch := mlx.Slice(x,
					[]int32{0, startT, startH, startW, 0},
					[]int32{B, startT + kT, startH + kH, startW + kW, C})
				// Transpose from [B, kT, kH, kW, C] to [B, C, kT, kH, kW]
				// to match Python's order
				patch = mlx.Transpose(patch, 0, 4, 1, 2, 3)
				// Flatten to [B, C*kT*kH*kW]
				patch = mlx.Reshape(patch, B, C*kT*kH*kW)
				patches[idx] = patch
				idx++
			}
		}
	}

	// Stack patches along a new axis, then fold into the output grid
	for i := range patches {
		patches[i] = mlx.ExpandDims(patches[i], 1)
	}
	stacked := mlx.Concatenate(patches, 1)
	return mlx.Reshape(stacked, B, outT, outH, outW, C*kT*kH*kW)
}

// extractPatches2DStrided extracts patches with the given stride
func extractPatches2DStrided(x *mlx.Array, kH, kW, stride int32) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	H := shape[1]
	W := shape[2]
	C := shape[3]

	outH := (H - kH) / stride + 1
	outW := (W - kW) / stride + 1

	patches := make([]*mlx.Array, outH*outW)
	idx := 0
	for i := int32(0); i < outH; i++ {
		for j := int32(0); j < outW; j++ {
			startH := i * stride
			startW := j * stride
			patch := mlx.Slice(x, []int32{0, startH, startW, 0}, []int32{B, startH + kH, startW + kW, C})
			patch = mlx.Reshape(patch, B, kH*kW*C)
			patches[idx] = patch
			idx++
		}
	}

	for i := range patches {
		patches[i] = mlx.ExpandDims(patches[i], 1)
	}
	stacked := mlx.Concatenate(patches, 1)
	return mlx.Reshape(stacked, B, outH, outW, kH*kW*C)
}

// layerNormNoAffine applies layer norm without learnable parameters
func layerNormNoAffine(x *mlx.Array, eps float32) *mlx.Array {
	ndim := x.Ndim()
	lastAxis := ndim - 1
	mean := mlx.Mean(x, lastAxis, true)
	xCentered := mlx.Sub(x, mean)
	variance := mlx.Mean(mlx.Square(xCentered), lastAxis, true)
	return mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, eps)))
}
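
// For reference, layerNormNoAffine computes standard layer normalization
// over the last axis without scale or shift parameters:
//
//	y = (x - mean(x)) / sqrt(var(x) + eps)
//
// where the mean and variance are taken per position over the trailing
// channel dimension.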