//go:build mlx

package qwen_image_edit

import (
	"fmt"

	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/safetensors"
)

// VAEConfig holds Qwen-Image VAE configuration
type VAEConfig struct {
	ZDim               int32     `json:"z_dim"`               // 16
	BaseDim            int32     `json:"base_dim"`            // 96
	DimMult            []int32   `json:"dim_mult"`            // [1, 2, 4, 4]
	NumResBlocks       int32     `json:"num_res_blocks"`      // 2
	LatentsMean        []float32 `json:"latents_mean"`        // 16 values
	LatentsStd         []float32 `json:"latents_std"`         // 16 values
	TemperalDownsample []bool    `json:"temperal_downsample"` // [false, true, true]
}

// defaultVAEConfig returns config for Qwen-Image VAE
func defaultVAEConfig() *VAEConfig {
	return &VAEConfig{
		ZDim:         16,
		BaseDim:      96,
		DimMult:      []int32{1, 2, 4, 4},
		NumResBlocks: 2,
		LatentsMean: []float32{
			-0.7571, -0.7089, -0.9113, 0.1075,
			-0.1745, 0.9653, -0.1517, 1.5508,
			0.4134, -0.0715, 0.5517, -0.3632,
			-0.1922, -0.9497, 0.2503, -0.2921,
		},
		LatentsStd: []float32{
			2.8184, 1.4541, 2.3275, 2.6558,
			1.2196, 1.7708, 2.6052, 2.0743,
			3.2687, 2.1526, 2.8652, 1.5579,
			1.6382, 1.1253, 2.8251, 1.916,
		},
		TemperalDownsample: []bool{false, true, true},
	}
}

// VAE is the full VAE with encoder and decoder
type VAE struct {
	Config  *VAEConfig
	Encoder *VAEEncoder
	Decoder *VAEDecoder
}

// Load loads the VAE from a directory
func (m *VAE) Load(path string) error {
	fmt.Println("Loading Qwen-Image-Edit VAE (encoder + decoder)...")

	cfg := defaultVAEConfig()
	m.Config = cfg

	weights, err := safetensors.LoadModelWeights(path)
	if err != nil {
		return fmt.Errorf("weights: %w", err)
	}

	// Load weights as f32 for quality (matches Python default behavior)
	// VAE decoder precision is critical for final image quality
	fmt.Print(" Loading weights as f32... ")
	if err := weights.Load(mlx.DtypeFloat32); err != nil {
		return fmt.Errorf("failed to load weights: %w", err)
	}
	fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))

	// Load encoder
	fmt.Print(" Loading encoder... ")
	m.Encoder = &VAEEncoder{}
	if err := m.Encoder.loadFromWeights(weights, cfg); err != nil {
		return fmt.Errorf("encoder: %w", err)
	}
	fmt.Println("✓")

	// Load decoder
	fmt.Print(" Loading decoder... ")
	m.Decoder = &VAEDecoder{}
	if err := m.Decoder.loadFromWeights(weights, cfg); err != nil {
		return fmt.Errorf("decoder: %w", err)
	}
	fmt.Println("✓")

	weights.ReleaseAll()
	return nil
}
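
// Loading sketch (illustrative only; the directory path is a placeholder and error
// handling is elided):
//
//	vae := &VAE{}
//	if err := vae.Load("/path/to/qwen-image-edit/vae"); err != nil {
//		// handle error
//	}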

// Encode encodes an image to latents
// x: [B, C, T, H, W] image tensor in [-1, 1] range
// Returns: [B, C, T, H/8, W/8] latents (unnormalized)
func (m *VAE) Encode(x *mlx.Array) *mlx.Array {
	return m.Encoder.Encode(x)
}

// Decode decodes latents to image
// z: [B, C, T, H, W] latents (denormalized)
// Returns: [B, C, T, H*8, W*8] image in [-1, 1]
func (m *VAE) Decode(z *mlx.Array) *mlx.Array {
	return m.Decoder.Decode(z)
}

// Normalize applies latent normalization
// Input z should be f32 (from VAE encoder), output is f32 for transformer
func (m *VAE) Normalize(z *mlx.Array) *mlx.Array {
	shape := z.Shape()
	C := shape[1]

	mean := mlx.NewArray(m.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1})
	std := mlx.NewArray(m.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1})

	// Mean/std are f32, will match z dtype through broadcasting
	return mlx.Div(mlx.Sub(z, mean), std)
}

// Denormalize reverses latent normalization
// Input z is bf16 (from transformer), output converted to f32 for VAE decoder
func (m *VAE) Denormalize(z *mlx.Array) *mlx.Array {
	shape := z.Shape()
	C := shape[1]

	// Convert latents to f32 for VAE decoder quality
	z = mlx.AsType(z, mlx.DtypeFloat32)

	mean := mlx.NewArray(m.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1})
	std := mlx.NewArray(m.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1})

	return mlx.Add(mlx.Mul(z, std), mean)
}
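
// Normalize and Denormalize are inverses: Normalize maps encoder output into the
// transformer's latent space via (z - mean) / std, and Denormalize maps transformer
// output back via z*std + mean before decoding. Minimal dataflow sketch (illustrative
// only; `img` and `denoised` are hypothetical tensors and the transformer step is elided):
//
//	z := vae.Normalize(vae.Encode(img))             // latents fed to the transformer
//	// ... diffusion transformer produces `denoised` from z ...
//	pixels := vae.Decode(vae.Denormalize(denoised)) // back to [-1, 1] pixels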

// VAEEncoder is the encoder part of the VAE
// The encoder uses a flat structure where down_blocks contains a mix of ResBlocks and Downsamplers:
// - Blocks 0,1: ResBlocks (base_dim)
// - Block 2: Downsample
// - Blocks 3,4: ResBlocks (base_dim*2)
// - Block 5: Downsample + temporal
// - Blocks 6,7: ResBlocks (base_dim*4)
// - Block 8: Downsample + temporal
// - Blocks 9,10: ResBlocks (base_dim*4)
type VAEEncoder struct {
	Config *VAEConfig

	ConvIn    *CausalConv3d
	Blocks    []EncoderBlock // Flat list of ResBlocks and Downsamplers
	MidBlock  *MidBlock
	NormOut   *RMSNorm3D
	ConvOut   *CausalConv3d
	QuantConv *CausalConv3d
}

// EncoderBlock is either a ResBlock or a Downsample
type EncoderBlock interface {
	Forward(x *mlx.Array) *mlx.Array
	IsDownsample() bool
}

// EncoderResBlock wraps ResBlock
type EncoderResBlock struct {
	*ResBlock
}

func (b *EncoderResBlock) IsDownsample() bool { return false }

// EncoderDownsample is a downsample layer
type EncoderDownsample struct {
	Resample *CausalConv3d
	TimeConv *CausalConv3d // Optional temporal downsample
}

func (d *EncoderDownsample) IsDownsample() bool { return true }

func (d *EncoderDownsample) Forward(x *mlx.Array) *mlx.Array {
	// Spatial downsample with stride 2
	// WAN VAE uses: ZeroPad2d(0,1,0,1) + Conv2d(3x3, stride=2)
	x = d.forwardSpatialDownsample(x)

	// NOTE: In WAN VAE, time_conv is ONLY used in streaming/chunked mode
	// with feat_cache. For single-frame encoding (T=1), time_conv is skipped.
	// The Python forward checks: if feat_cache is not None ... then use time_conv.
	// Since we don't support streaming, we skip time_conv entirely.
	return x
}

// forwardSpatialDownsample applies 2D conv with stride 2 for spatial downsampling
func (d *EncoderDownsample) forwardSpatialDownsample(x *mlx.Array) *mlx.Array {
	xShape := x.Shape()
	B := xShape[0]
	T := xShape[1]
	H := xShape[2]
	W := xShape[3]
	C := xShape[4]

	wShape := d.Resample.Weight.Shape()
	outC := wShape[0]

	// Reshape to [B*T, H, W, C] for 2D conv
	x = mlx.Reshape(x, B*T, H, W, C)

	// Asymmetric padding: pad right and bottom by 1 (WAN VAE style)
	// ZeroPad2d(0, 1, 0, 1) means (left=0, right=1, top=0, bottom=1)
	x = mlx.Pad(x, []int32{0, 0, 0, 1, 0, 1, 0, 0}) // [B*T, H, W, C] -> pad H and W

	// Apply 2D conv with stride 2
	weight := mlx.Transpose(d.Resample.Weight, 0, 2, 3, 1) // [O, I, kH, kW] -> [O, kH, kW, I]
	x = conv2DStrided(x, weight, 2)

	if d.Resample.Bias != nil {
		bias := mlx.Reshape(d.Resample.Bias, 1, 1, 1, outC)
		x = mlx.Add(x, bias)
	}

	// Output dims after stride 2: (H+1)/2, (W+1)/2
	outH := (H + 1) / 2
	outW := (W + 1) / 2

	// Reshape back to [B, T, H', W', C]
	x = mlx.Reshape(x, B, T, outH, outW, outC)
	mlx.Eval(x)

	return x
}
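
// Worked example for the asymmetric padding above (for intuition, assuming an even
// spatial size): with H = W = 512, padding right/bottom by 1 gives 513x513, and a
// 3x3 conv with stride 2 produces floor((513-3)/2)+1 = 256 samples per side,
// matching outH = (512+1)/2 = 256.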

// loadFromWeights loads the encoder from pre-loaded weights
func (e *VAEEncoder) loadFromWeights(weights *safetensors.ModelWeights, cfg *VAEConfig) error {
	e.Config = cfg

	// Conv in
	convIn, err := newCausalConv3d(weights, "encoder.conv_in")
	if err != nil {
		return err
	}
	e.ConvIn = convIn

	// Encoder uses flat block structure:
	// dim_mult = [1, 2, 4, 4], num_res_blocks = 2, temporal_downsample = [false, true, true]
	// Block layout: res,res,down, res,res,down+t, res,res,down+t, res,res
	// That's 11 blocks: 0,1=res, 2=down, 3,4=res, 5=down+t, 6,7=res, 8=down+t, 9,10=res
	e.Blocks = make([]EncoderBlock, 0, 11)

	// Track dimensions
	dims := []int32{cfg.BaseDim, cfg.BaseDim * 2, cfg.BaseDim * 4, cfg.BaseDim * 4}
	blockIdx := 0

	for stage := 0; stage < len(cfg.DimMult); stage++ {
		inDim := cfg.BaseDim
		if stage > 0 {
			inDim = dims[stage-1]
		}
		outDim := dims[stage]

		// ResBlocks for this stage (num_res_blocks per stage)
		for r := int32(0); r < cfg.NumResBlocks; r++ {
			prefix := fmt.Sprintf("encoder.down_blocks.%d", blockIdx)
			currentInDim := inDim
			if r > 0 {
				currentInDim = outDim
			}
			block, err := newEncoderResBlock(weights, prefix, currentInDim, outDim)
			if err != nil {
				return fmt.Errorf("encoder res block %d: %w", blockIdx, err)
			}
			e.Blocks = append(e.Blocks, block)
			blockIdx++
		}

		// Downsample after each stage except the last
		if stage < len(cfg.DimMult)-1 {
			prefix := fmt.Sprintf("encoder.down_blocks.%d", blockIdx)
			down, err := newEncoderDownsample(weights, prefix, cfg.TemperalDownsample[stage])
			if err != nil {
				return fmt.Errorf("encoder downsample %d: %w", blockIdx, err)
			}
			e.Blocks = append(e.Blocks, down)
			blockIdx++
		}
	}

	// Mid block
	midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1]
	midBlock, err := newMidBlock(weights, "encoder.mid_block", midDim)
	if err != nil {
		return err
	}
	e.MidBlock = midBlock

	// Norm out
	normOut, err := newRMSNorm3D(weights, "encoder.norm_out", midDim)
	if err != nil {
		return err
	}
	e.NormOut = normOut

	// Conv out
	convOut, err := newCausalConv3d(weights, "encoder.conv_out")
	if err != nil {
		return err
	}
	e.ConvOut = convOut

	// Quant conv
	quantConv, err := newCausalConv3d(weights, "quant_conv")
	if err != nil {
		return err
	}
	e.QuantConv = quantConv

	return nil
}

// newEncoderResBlock creates a ResBlock for the encoder (flat structure)
func newEncoderResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*EncoderResBlock, error) {
	block, err := newResBlock(weights, prefix, inDim, outDim)
	if err != nil {
		return nil, err
	}
	return &EncoderResBlock{block}, nil
}

// newEncoderDownsample creates a downsample layer for the encoder
func newEncoderDownsample(weights *safetensors.ModelWeights, prefix string, temporal bool) (*EncoderDownsample, error) {
	resample, err := newCausalConv3d(weights, prefix+".resample.1")
	if err != nil {
		return nil, err
	}

	var timeConv *CausalConv3d
	if temporal {
		timeConv, _ = newCausalConv3d(weights, prefix+".time_conv")
	}

	return &EncoderDownsample{
		Resample: resample,
		TimeConv: timeConv,
	}, nil
}
// Encode encodes an image to latents
|
|
// x: [B, C, T, H, W] image tensor (channels-first)
|
|
// Returns: [B, latent_C, T, H/8, W/8] latent distribution mode
|
|
func (e *VAEEncoder) Encode(x *mlx.Array) *mlx.Array {
|
|
// Convert from channels-first [N, C, T, H, W] to channels-last [N, T, H, W, C]
|
|
x = mlx.Contiguous(mlx.Transpose(x, 0, 2, 3, 4, 1))
|
|
mlx.Eval(x)
|
|
|
|
// Conv in
|
|
x = e.ConvIn.Forward(x)
|
|
|
|
// Encoder blocks (mix of ResBlocks and Downsamplers)
|
|
for _, block := range e.Blocks {
|
|
prev := x
|
|
x = block.Forward(x)
|
|
prev.Free()
|
|
}
|
|
|
|
// Mid block
|
|
x = e.MidBlock.Forward(x)
|
|
|
|
// Norm + silu
|
|
{
|
|
prev := x
|
|
x = e.NormOut.Forward(x)
|
|
x = silu3D(x)
|
|
prev.Free()
|
|
mlx.Eval(x)
|
|
}
|
|
|
|
// Conv out
|
|
{
|
|
prev := x
|
|
x = e.ConvOut.Forward(x)
|
|
prev.Free()
|
|
}
|
|
|
|
// Quant conv
|
|
{
|
|
prev := x
|
|
x = e.QuantConv.Forward(x)
|
|
prev.Free()
|
|
}
|
|
|
|
// Get mode from distribution (first half of channels = mean)
|
|
// Output is [B, T, H, W, 2*latent_C], we take first latent_C channels
|
|
shape := x.Shape()
|
|
latentC := shape[4] / 2
|
|
x = mlx.Slice(x, []int32{0, 0, 0, 0, 0}, []int32{shape[0], shape[1], shape[2], shape[3], latentC})
|
|
|
|
// Convert back to channels-first [N, C, T, H, W]
|
|
x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3))
|
|
mlx.Eval(x)
|
|
|
|
return x
|
|
}
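
// Shape example (for intuition, assuming a single RGB frame): an input of
// [1, 3, 1, 1024, 1024] is reduced 8x spatially by the three downsample blocks, and
// quant_conv emits 2*z_dim = 32 channels, i.e. [1, 1, 128, 128, 32] channels-last.
// Slicing the first 16 channels keeps the distribution mean, which is the mode of
// the Gaussian posterior.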

// VAEDecoder is the decoder part of the VAE
type VAEDecoder struct {
	Config *VAEConfig

	PostQuantConv *CausalConv3d
	ConvIn        *CausalConv3d
	MidBlock      *MidBlock
	UpBlocks      []*UpBlock
	NormOut       *RMSNorm3D
	ConvOut       *CausalConv3d
}

// loadFromWeights loads the decoder from pre-loaded weights
func (d *VAEDecoder) loadFromWeights(weights *safetensors.ModelWeights, cfg *VAEConfig) error {
	d.Config = cfg

	postQuantConv, err := newCausalConv3d(weights, "post_quant_conv")
	if err != nil {
		return err
	}
	d.PostQuantConv = postQuantConv

	convIn, err := newCausalConv3d(weights, "decoder.conv_in")
	if err != nil {
		return err
	}
	d.ConvIn = convIn

	// Mid block
	midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1]
	midBlock, err := newMidBlock(weights, "decoder.mid_block", midDim)
	if err != nil {
		return err
	}
	d.MidBlock = midBlock

	// Up blocks (reversed dim_mult)
	numUpBlocks := len(cfg.DimMult)
	d.UpBlocks = make([]*UpBlock, numUpBlocks)

	dimsMult := make([]int32, numUpBlocks+1)
	dimsMult[0] = cfg.DimMult[numUpBlocks-1]
	for i := 0; i < numUpBlocks; i++ {
		dimsMult[i+1] = cfg.DimMult[numUpBlocks-1-i]
	}

	temporalUpsample := make([]bool, len(cfg.TemperalDownsample))
	for i := range cfg.TemperalDownsample {
		temporalUpsample[i] = cfg.TemperalDownsample[len(cfg.TemperalDownsample)-1-i]
	}

	for i := 0; i < numUpBlocks; i++ {
		inDim := cfg.BaseDim * dimsMult[i]
		outDim := cfg.BaseDim * dimsMult[i+1]

		if i > 0 {
			inDim = inDim / 2
		}

		upsampleMode := ""
		if i < numUpBlocks-1 {
			if temporalUpsample[i] {
				upsampleMode = "upsample3d"
			} else {
				upsampleMode = "upsample2d"
			}
		}

		prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
		upBlock, err := newUpBlock(weights, prefix, inDim, outDim, cfg.NumResBlocks, upsampleMode)
		if err != nil {
			return err
		}
		d.UpBlocks[i] = upBlock
	}

	normOut, err := newRMSNorm3D(weights, "decoder.norm_out", cfg.BaseDim)
	if err != nil {
		return err
	}
	d.NormOut = normOut

	convOut, err := newCausalConv3d(weights, "decoder.conv_out")
	if err != nil {
		return err
	}
	d.ConvOut = convOut

	return nil
}
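
// For the default config (dim_mult = [1, 2, 4, 4], base_dim = 96), the loop above
// builds up blocks with (inDim, outDim) = (384, 384), (192, 384), (192, 192), (96, 96);
// the halving for i > 0 accounts for the preceding upsampler halving the channel count.
// The upsample modes come out as upsample3d, upsample3d, upsample2d, and none for the
// last block, mirroring temperal_downsample = [false, true, true] in reverse.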

// Decode converts latents to image
// z: [B, C, T, H, W] denormalized latents
func (d *VAEDecoder) Decode(z *mlx.Array) *mlx.Array {
	var x *mlx.Array

	// Convert from channels-first to channels-last
	{
		z = mlx.Contiguous(mlx.Transpose(z, 0, 2, 3, 4, 1))
		mlx.Eval(z)
	}

	// PostQuantConv
	x = d.PostQuantConv.Forward(z)
	z.Free()

	// ConvIn
	{
		prev := x
		x = d.ConvIn.Forward(x)
		prev.Free()
	}

	// Mid block
	x = d.MidBlock.Forward(x)

	// Up blocks
	for _, upBlock := range d.UpBlocks {
		x = upBlock.Forward(x)
	}

	// NormOut + silu
	{
		prev := x
		x = d.NormOut.Forward(x)
		x = silu3D(x)
		prev.Free()
		mlx.Eval(x)
	}

	// ConvOut
	{
		prev := x
		x = d.ConvOut.Forward(x)
		prev.Free()
	}

	// Post-processing: clamp and convert back to channels-first
	{
		prev := x
		x = mlx.ClipScalar(x, -1.0, 1.0, true, true)
		x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3))
		prev.Free()
		mlx.Eval(x)
	}

	return x
}

// DownBlock handles downsampling in encoder
type DownBlock struct {
	ResBlocks   []*ResBlock
	Downsampler *Downsample
}

// newDownBlock creates a down block
func newDownBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, downsampleMode string) (*DownBlock, error) {
	resBlocks := make([]*ResBlock, numBlocks+1)

	currentDim := inDim
	for i := int32(0); i <= numBlocks; i++ {
		resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
		block, err := newResBlock(weights, resPrefix, currentDim, outDim)
		if err != nil {
			return nil, err
		}
		resBlocks[i] = block
		currentDim = outDim
	}

	var downsampler *Downsample
	if downsampleMode != "" {
		downsampler = newDownsample(weights, prefix+".downsamplers.0", outDim, downsampleMode)
	}

	return &DownBlock{
		ResBlocks:   resBlocks,
		Downsampler: downsampler,
	}, nil
}

// Forward applies down block
func (d *DownBlock) Forward(x *mlx.Array) *mlx.Array {
	for _, block := range d.ResBlocks {
		prev := x
		x = block.Forward(x)
		prev.Free()
	}

	if d.Downsampler != nil {
		prev := x
		x = d.Downsampler.Forward(x)
		prev.Free()
	}
	return x
}

// Downsample handles spatial downsampling
type Downsample struct {
	Conv *mlx.Array
	Bias *mlx.Array
	Mode string
}

// newDownsample creates a downsampler
func newDownsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Downsample {
	conv, _ := weights.Get(prefix + ".resample.1.weight")
	bias, _ := weights.Get(prefix + ".resample.1.bias")
	return &Downsample{
		Conv: conv,
		Bias: bias,
		Mode: mode,
	}
}

// Forward applies downsampling to channels-last input [B, T, H, W, C]
func (d *Downsample) Forward(x *mlx.Array) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	T := shape[1]
	H := shape[2]
	W := shape[3]
	C := shape[4]
	outC := d.Conv.Shape()[0]

	// Reshape to [B*T, H, W, C] for 2D conv
	x = mlx.Reshape(x, B*T, H, W, C)

	// Pad for the stride-2 conv: a 3x3 kernel needs (3-1)/2 = 1 pixel of padding,
	// applied symmetrically to H and W
	x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0})

	// Conv with stride 2 using manual strided patching
	weight := mlx.Transpose(d.Conv, 0, 2, 3, 1)
	x = conv2DStrided(x, weight, 2)
	if d.Bias != nil {
		bias := mlx.Reshape(d.Bias, 1, 1, 1, outC)
		x = mlx.Add(x, bias)
	}

	x = mlx.Reshape(x, B, T, H/2, W/2, outC)
	mlx.Eval(x)

	return x
}
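
// Note on the reshape above: with symmetric padding of 1 and a 3x3 stride-2 kernel,
// an even input size H yields floor((H+2-3)/2)+1 = H/2 output positions, which is why
// the output is reshaped to H/2 and W/2 directly.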