//go:build mlx

package qwen_image

import (
	"errors"
	"fmt"
	"math"
	"path/filepath"

	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/safetensors"
	"github.com/ollama/ollama/x/imagegen/tokenizer"
)

// Qwen25VLConfig holds Qwen2.5-VL configuration
type Qwen25VLConfig struct {
	// Text model config
	HiddenSize        int32   `json:"hidden_size"`         // 3584
	NumHiddenLayers   int32   `json:"num_hidden_layers"`   // 28
	IntermediateSize  int32   `json:"intermediate_size"`   // 18944
	NumAttentionHeads int32   `json:"num_attention_heads"` // 28
	NumKeyValueHeads  int32   `json:"num_key_value_heads"` // 4
	VocabSize         int32   `json:"vocab_size"`          // 152064
	RMSNormEps        float32 `json:"rms_norm_eps"`        // 1e-6
	RopeTheta         float32 `json:"rope_theta"`          // 1000000
	HeadDim           int32   // Calculated: HiddenSize / NumAttentionHeads
	MRoPESection      []int32 // [16, 24, 24] for temporal, height, width

	// Vision config
	VisionHiddenSize    int32   `json:"vision_hidden_size"`   // 1280
	VisionNumLayers     int32   `json:"vision_num_layers"`    // 32
	VisionNumHeads      int32   `json:"vision_num_heads"`     // 16
	VisionIntermSize    int32   `json:"vision_intermediate"`  // 3420
	VisionPatchSize     int32   `json:"vision_patch_size"`    // 14
	VisionOutHiddenSize int32   `json:"vision_out_hidden"`    // 3584
	VisionSpatialMerge  int32   `json:"vision_spatial_merge"` // 2
	VisionWindowSize    int32   `json:"vision_window_size"`   // 112
	VisionFullAttIdx    []int32 // [7, 15, 23, 31]

	// Special tokens
	ImageTokenID       int32 // 151655
	VisionStartTokenID int32 // 151652
	VisionEndTokenID   int32 // 151653
}

// defaultQwen25VLConfig returns default config
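//
// Derived values (from the defaults below): HeadDim = 3584/28 = 128, and the
// M-RoPE sections [16, 24, 24] sum to HeadDim/2 = 64, partitioning the rotary
// half-dimension across the temporal/height/width position components.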
func defaultQwen25VLConfig() *Qwen25VLConfig {
	cfg := &Qwen25VLConfig{
		// Text
		HiddenSize:        3584,
		NumHiddenLayers:   28,
		IntermediateSize:  18944,
		NumAttentionHeads: 28,
		NumKeyValueHeads:  4,
		VocabSize:         152064,
		RMSNormEps:        1e-6,
		RopeTheta:         1000000,
		MRoPESection:      []int32{16, 24, 24},

		// Vision
		VisionHiddenSize:    1280,
		VisionNumLayers:     32,
		VisionNumHeads:      16,
		VisionIntermSize:    3420,
		VisionPatchSize:     14,
		VisionOutHiddenSize: 3584,
		VisionSpatialMerge:  2,
		VisionWindowSize:    112,
		VisionFullAttIdx:    []int32{7, 15, 23, 31},

		// Special tokens
		ImageTokenID:       151655,
		VisionStartTokenID: 151652,
		VisionEndTokenID:   151653,
	}
	cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
	return cfg
}

// Qwen25VL is the Qwen2.5-VL vision-language encoder
type Qwen25VL struct {
	Config *Qwen25VLConfig

	// Text model
	Embedding *mlx.Array
	Blocks    []*VLTextBlock
	FinalNorm *mlx.Array

	// Vision tower (optional - nil for text-only models)
	VisionPatchEmbed *VisionPatchEmbed
	VisionBlocks     []*VisionBlock
	VisionMerger     *VisionMerger
	HasVision        bool // True if vision tower is loaded
}

// LoadTextOnly loads only the text encoder components (skips vision tower)
// Use this for text-to-image generation where vision components are not needed
func (m *Qwen25VL) LoadTextOnly(path string) error {
	return m.load(path, false)
}

// Load loads the vision-language encoder from a directory
// Vision components are loaded if weights exist
func (m *Qwen25VL) Load(path string) error {
	return m.load(path, true)
}

// load is the internal loading function
func (m *Qwen25VL) load(path string, loadVision bool) error {
	fmt.Println("Loading Qwen2.5-VL encoder...")

	cfg := defaultQwen25VLConfig()
	m.Config = cfg

	weights, err := safetensors.LoadModelWeights(path)
	if err != nil {
		return fmt.Errorf("weights: %w", err)
	}

	// Bulk load all weights as bf16
	fmt.Print(" Loading weights as bf16... ")
	if err := weights.Load(mlx.DtypeBFloat16); err != nil {
		return fmt.Errorf("failed to load weights: %w", err)
	}
	fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))

	// Load text embedding
	fmt.Print(" Loading text embeddings... ")
	embedding, err := weights.Get("model.embed_tokens.weight")
	if err != nil {
		return err
	}
	m.Embedding = embedding
	fmt.Printf("✓ [%v]\n", embedding.Shape())

	// Load text blocks
	m.Blocks = make([]*VLTextBlock, cfg.NumHiddenLayers)
	for i := int32(0); i < cfg.NumHiddenLayers; i++ {
		fmt.Printf("\r Loading text blocks... %d/%d", i+1, cfg.NumHiddenLayers)
		block, err := newVLTextBlock(weights, int(i), cfg)
		if err != nil {
			return fmt.Errorf("failed to load text block %d: %w", i, err)
		}
		m.Blocks[i] = block
	}
	fmt.Printf("\r Loading text blocks... ✓ [%d blocks] \n", cfg.NumHiddenLayers)

	// Load final norm
	fmt.Print(" Loading final norm... ")
	finalNorm, err := weights.Get("model.norm.weight")
	if err != nil {
		return err
	}
	m.FinalNorm = finalNorm
	fmt.Println("✓")

	// Try to load vision tower (optional)
	m.HasVision = false
	if loadVision {
		if _, err := weights.Get("visual.patch_embed.proj.weight"); err == nil {
			fmt.Print(" Loading vision patch embed... ")
			m.VisionPatchEmbed, err = newVisionPatchEmbed(weights, cfg)
			if err != nil {
				return fmt.Errorf("vision patch embed: %w", err)
			}
			fmt.Println("✓")

			m.VisionBlocks = make([]*VisionBlock, cfg.VisionNumLayers)
			for i := int32(0); i < cfg.VisionNumLayers; i++ {
				fmt.Printf("\r Loading vision blocks... %d/%d", i+1, cfg.VisionNumLayers)
				block, err := newVisionBlock(weights, int(i), cfg)
				if err != nil {
					return fmt.Errorf("failed to load vision block %d: %w", i, err)
				}
				m.VisionBlocks[i] = block
			}
			fmt.Printf("\r Loading vision blocks... ✓ [%d blocks] \n", cfg.VisionNumLayers)

			fmt.Print(" Loading vision merger... ")
			m.VisionMerger, err = newVisionMerger(weights, cfg)
			if err != nil {
				return fmt.Errorf("vision merger: %w", err)
			}
			fmt.Println("✓")

			m.HasVision = true
		} else {
			fmt.Println(" (No vision tower - text-only mode)")
		}
	} else {
		fmt.Println(" (Skipping vision tower)")
	}

	weights.ReleaseAll()
	return nil
}

// EncodePrompt encodes a text prompt for image generation (text-only mode)
// Uses the Qwen-Image template and drops the first 34 tokens (system prefix)
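//
// Illustrative usage (a sketch; assumes a loaded model m and tokenizer tok):
//
//	emb := m.EncodePrompt(tok, "a red bicycle leaning against a brick wall")
//	// emb: [1, L-34, 3584] prompt embeddings for the downstream image model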
func (m *Qwen25VL) EncodePrompt(tok *tokenizer.Tokenizer, prompt string) *mlx.Array {
	cfg := m.Config

	// Template from Python: prompt_template_encode (for image generation)
	template := "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n%s<|im_end|>\n<|im_start|>assistant\n"
	formattedPrompt := fmt.Sprintf(template, prompt)

	// Tokenize
	tokens := tok.Encode(formattedPrompt, false)

	// Create token array
	seqLen := int32(len(tokens))
	tokenArr := mlx.NewArrayInt32(tokens, []int32{1, seqLen})

	// Get text embeddings
	textEmbed := mlx.EmbeddingLookup(m.Embedding, tokenArr)

	// Compute RoPE
	cossin := m.computeTextRoPE(seqLen, 1)

	// Forward through ALL text blocks
	x := textEmbed
	for _, block := range m.Blocks {
		x = block.Forward(x, cossin)
	}

	// Apply final norm
	x = mlx.RMSNorm(x, m.FinalNorm, cfg.RMSNormEps)

	// Drop first 34 tokens (system prefix)
	// prompt_template_encode_start_idx = 34
	dropIdx := int32(34)
	if x.Shape()[1] > dropIdx {
		x = mlx.Slice(x, []int32{0, dropIdx, 0}, []int32{1, x.Shape()[1], cfg.HiddenSize})
	}

	return x
}

// EncodePromptWithImage encodes a text prompt with an image
// Returns: embeddings [B, L, hidden_size], mask [B, L], error
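//
// The single <|image_pad|> token in the template is expanded in place into one
// placeholder token per vision token before the embedding lookup. For example,
// a 28x28 patch grid merges 2x2 to 14x14 = 196 vision tokens, so the one
// <|image_pad|> becomes 196 placeholders whose embeddings are then replaced by
// the vision tower output.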
func (m *Qwen25VL) EncodePromptWithImage(tok *tokenizer.Tokenizer, prompt string, image *mlx.Array) (*mlx.Array, *mlx.Array, error) {
	if !m.HasVision {
		return nil, nil, errors.New("EncodePromptWithImage called on text-only model")
	}

	cfg := m.Config

	// Template from Python diffusers pipeline: prompt_template_encode
	// Python's _get_qwen_prompt_embeds adds "Picture 1: " before vision tokens
	template := "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\nPicture 1: <|vision_start|><|image_pad|><|vision_end|>%s<|im_end|>\n<|im_start|>assistant\n"
	formattedPrompt := fmt.Sprintf(template, prompt)

	// Tokenize
	tokens := tok.Encode(formattedPrompt, false)

	// Process vision if image provided
	var visionEmbeddings *mlx.Array
	var numImageTokens int32
	var visionH, visionW int32 // Grid dims in patches (before spatial merge)
	if image != nil {
		visionEmbeddings = m.encodeVision(image)
		numImageTokens = visionEmbeddings.Shape()[1]
		// Get original grid dimensions from image shape
		imgShape := image.Shape()
		visionH = imgShape[2] / cfg.VisionPatchSize // Height in patches
		visionW = imgShape[3] / cfg.VisionPatchSize // Width in patches
	}

	// Find image token position and expand
	expandedTokens := make([]int32, 0, len(tokens)+int(numImageTokens))
	imageTokenPos := int32(-1)
	textAfterCount := int32(0)
	for i, t := range tokens {
		if t == cfg.ImageTokenID {
			imageTokenPos = int32(len(expandedTokens))
			// Insert placeholder tokens for image
			for j := int32(0); j < numImageTokens; j++ {
				expandedTokens = append(expandedTokens, cfg.ImageTokenID)
			}
			// Count remaining tokens after image
			textAfterCount = int32(len(tokens) - i - 1)
		} else {
			expandedTokens = append(expandedTokens, t)
		}
	}

	// Create token array
	seqLen := int32(len(expandedTokens))
	tokenArr := mlx.NewArrayInt32(expandedTokens, []int32{1, seqLen})

	// Get text embeddings
	textEmbed := mlx.EmbeddingLookup(m.Embedding, tokenArr) // [1, L, hidden]

	// Replace image token embeddings with vision embeddings
	if visionEmbeddings != nil && imageTokenPos >= 0 {
		// Split, replace, concat
		before := mlx.Slice(textEmbed, []int32{0, 0, 0}, []int32{1, imageTokenPos, cfg.HiddenSize})
		after := mlx.Slice(textEmbed, []int32{0, imageTokenPos + numImageTokens, 0}, []int32{1, seqLen, cfg.HiddenSize})
		textEmbed = mlx.Concatenate([]*mlx.Array{before, visionEmbeddings, after}, 1)
	}

	// Compute RoPE - use multimodal RoPE when image is present
	var cossin [2]*mlx.Array
	if image != nil && imageTokenPos >= 0 {
		cossin = m.ComputeMultimodalRoPE(imageTokenPos, visionH, visionW, textAfterCount, cfg.VisionSpatialMerge)
	} else {
		cossin = m.computeTextRoPE(seqLen, 1)
	}

	// Forward through ALL text blocks
	// Python uses hidden_states[-1] (LAST layer output, not second-to-last!)
	x := textEmbed
	for _, block := range m.Blocks {
		x = block.Forward(x, cossin)
	}

	// Apply final norm (Python DOES apply this for the output)
	x = mlx.RMSNorm(x, m.FinalNorm, cfg.RMSNormEps)

	// Drop first N tokens (system prefix)
	// prompt_template_encode_start_idx = 64
	dropIdx := int32(64)
	if x.Shape()[1] > dropIdx {
		x = mlx.Slice(x, []int32{0, dropIdx, 0}, []int32{1, x.Shape()[1], cfg.HiddenSize})
	}

	// Create attention mask (all ones for now)
	mask := mlx.Ones(1, x.Shape()[1])

	return x, mask, nil
}

// EncodeVision encodes an image through the vision tower (exported for testing)
// image: [B, C, H, W] normalized image tensor
// Returns: [B, num_tokens, hidden_size] vision embeddings
func (m *Qwen25VL) EncodeVision(image *mlx.Array) *mlx.Array {
	return m.encodeVision(image)
}

// VisionRegion describes where vision embeddings are inserted in the sequence
type VisionRegion struct {
	StartPos  int32 // Position in sequence where vision tokens start
	NumTokens int32 // Number of vision tokens
	GridH     int32 // Vision grid height (in patches, after spatial merge)
	GridW     int32 // Vision grid width (in patches, after spatial merge)
}

// EncodePromptWithImages encodes a text prompt with multiple images
// Returns: embeddings [B, L, hidden_size], mask [B, L], regions []VisionRegion, error
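//
// Illustrative usage (a sketch; img1, img2 are placeholder names for
// preprocessed [1, 3, H, W] image tensors):
//
//	emb, mask, regions, err := m.EncodePromptWithImages(tok, "swap the hats", []*mlx.Array{img1, img2})
//	// regions[i] records where image i's tokens landed in the output sequence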
func (m *Qwen25VL) EncodePromptWithImages(tok *tokenizer.Tokenizer, prompt string, images []*mlx.Array) (*mlx.Array, *mlx.Array, []VisionRegion, error) {
	if !m.HasVision {
		return nil, nil, nil, errors.New("EncodePromptWithImages called on text-only model")
	}
	if len(images) == 0 {
		return nil, nil, nil, errors.New("EncodePromptWithImages called with no images")
	}

	cfg := m.Config

	// Build image prompt prefix: "Picture 1: <vision>...Picture N: <vision>..."
	imgPromptTemplate := "Picture %d: <|vision_start|><|image_pad|><|vision_end|>"
	imgPrompt := ""
	for i := range images {
		imgPrompt += fmt.Sprintf(imgPromptTemplate, i+1)
	}

	// Template from Python diffusers pipeline: prompt_template_encode
	template := "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n%s%s<|im_end|>\n<|im_start|>assistant\n"
	formattedPrompt := fmt.Sprintf(template, imgPrompt, prompt)

	// Tokenize
	tokens := tok.Encode(formattedPrompt, false)

	// Process each image through vision tower
	visionEmbeddings := make([]*mlx.Array, len(images))
	numImageTokens := make([]int32, len(images))
	visionGridH := make([]int32, len(images))
	visionGridW := make([]int32, len(images))

	for i, image := range images {
		visionEmbeddings[i] = m.encodeVision(image)
		numImageTokens[i] = visionEmbeddings[i].Shape()[1]
		// Get original grid dimensions from image shape
		imgShape := image.Shape()
		visionH := imgShape[2] / cfg.VisionPatchSize // Height in patches
		visionW := imgShape[3] / cfg.VisionPatchSize // Width in patches
		// After spatial merge, grid is halved
		visionGridH[i] = visionH / cfg.VisionSpatialMerge
		visionGridW[i] = visionW / cfg.VisionSpatialMerge
	}

	// Find all image token positions and expand tokens
	expandedTokens := make([]int32, 0, len(tokens)+int(sum(numImageTokens)))
	imagePositions := make([]int32, 0, len(images)) // Start position for each image's tokens
	imageIdx := 0

	for _, t := range tokens {
		if t == cfg.ImageTokenID {
			if imageIdx < len(images) {
				imagePositions = append(imagePositions, int32(len(expandedTokens)))
				// Insert placeholder tokens for this image
				for j := int32(0); j < numImageTokens[imageIdx]; j++ {
					expandedTokens = append(expandedTokens, cfg.ImageTokenID)
				}
				imageIdx++
			}
		} else {
			expandedTokens = append(expandedTokens, t)
		}
	}

	// Create token array
	seqLen := int32(len(expandedTokens))
	tokenArr := mlx.NewArrayInt32(expandedTokens, []int32{1, seqLen})

	// Get text embeddings
	textEmbed := mlx.EmbeddingLookup(m.Embedding, tokenArr) // [1, L, hidden]

	// Replace image token embeddings with vision embeddings
	// Build list of segments to concatenate
	segments := make([]*mlx.Array, 0, len(images)*2+1)
	regions := make([]VisionRegion, len(images))
	lastEnd := int32(0)

	for i, imgPos := range imagePositions {
		// Text segment before this image
		if imgPos > lastEnd {
			segments = append(segments, mlx.Slice(textEmbed, []int32{0, lastEnd, 0}, []int32{1, imgPos, cfg.HiddenSize}))
		}
		// Vision embeddings for this image
		segments = append(segments, visionEmbeddings[i])
		regions[i] = VisionRegion{
			StartPos:  imgPos,
			NumTokens: numImageTokens[i],
			GridH:     visionGridH[i],
			GridW:     visionGridW[i],
		}
		lastEnd = imgPos + numImageTokens[i]
	}
	// Remaining text after last image
	if lastEnd < seqLen {
		segments = append(segments, mlx.Slice(textEmbed, []int32{0, lastEnd, 0}, []int32{1, seqLen, cfg.HiddenSize}))
	}

	// Concatenate all segments
	textEmbed = mlx.Concatenate(segments, 1)

	// Compute RoPE - use multimodal RoPE for multiple images
	cossin, err := m.ComputeMultiImageRoPE(imagePositions, visionGridH, visionGridW, numImageTokens, seqLen)
	if err != nil {
		return nil, nil, nil, fmt.Errorf("computing RoPE: %w", err)
	}

	// Forward through ALL text blocks
	x := textEmbed
	for _, block := range m.Blocks {
		x = block.Forward(x, cossin)
	}

	// Apply final norm
	x = mlx.RMSNorm(x, m.FinalNorm, cfg.RMSNormEps)

	// Drop first N tokens (system prefix)
	// prompt_template_encode_start_idx = 64
	dropIdx := int32(64)
	if x.Shape()[1] > dropIdx {
		x = mlx.Slice(x, []int32{0, dropIdx, 0}, []int32{1, x.Shape()[1], cfg.HiddenSize})
		// Adjust region positions
		for i := range regions {
			regions[i].StartPos -= dropIdx
		}
	}

	// Create attention mask (all ones)
	mask := mlx.Ones(1, x.Shape()[1])

	return x, mask, regions, nil
}

// sum returns the sum of an int32 slice
func sum(arr []int32) int32 {
	var s int32
	for _, v := range arr {
		s += v
	}
	return s
}

// EncodeTextOnly encodes text tokens through all text blocks (exported for testing)
// tokens: array of token IDs
// Returns: [B, L, hidden_size] text embeddings after all blocks
func (m *Qwen25VL) EncodeTextOnly(tokens []int32) *mlx.Array {
	seqLen := int32(len(tokens))
	tokenArr := mlx.NewArrayInt32(tokens, []int32{1, seqLen})

	// Get text embeddings
	textEmbed := mlx.EmbeddingLookup(m.Embedding, tokenArr) // [1, L, hidden]

	// Compute RoPE
	cossin := m.computeTextRoPE(seqLen, 1)

	// Forward through ALL text blocks
	x := textEmbed
	for _, block := range m.Blocks {
		x = block.Forward(x, cossin)
	}

	// Apply final norm
	x = mlx.RMSNorm(x, m.FinalNorm, m.Config.RMSNormEps)

	return x
}

// encodeVision encodes an image through the vision tower
// image: [B, C, H, W] normalized image tensor
// Returns: [B, num_tokens, hidden_size] vision embeddings
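//
// Patch embeddings come out of the conv in row-major order, but the merger
// expects 2x2-block order. For a 4x4 patch grid (spatialMerge=2) the reorder
// index below is:
//
//	row-major: 0 1 2 3 / 4 5 6 7 / 8 9 10 11 / 12 13 14 15
//	2x2-block: 0 1 4 5 | 2 3 6 7 | 8 9 12 13 | 10 11 14 15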
func (m *Qwen25VL) encodeVision(image *mlx.Array) *mlx.Array {
	cfg := m.Config

	// Calculate grid dimensions from image
	imgShape := image.Shape()
	imgH := imgShape[2]
	imgW := imgShape[3]
	pH := imgH / cfg.VisionPatchSize // grid height in patches
	pW := imgW / cfg.VisionPatchSize // grid width in patches

	// Patch embed
	x := m.VisionPatchEmbed.Forward(image)
	mlx.Eval(x)

	// Get window reordering info
	winInfo := m.getWindowInfo(pH, pW)

	// Compute vision RoPE embeddings (already in 2x2-block order)
	posEmb := m.computeVisionRoPE(pH, pW)

	shape := x.Shape()
	B := shape[0]
	L := shape[1] // num patches = pH * pW
	D := shape[2]
	spatialMergeUnit := winInfo.SpatialMergeUnit
	spatialMerge := cfg.VisionSpatialMerge

	// Convert patch embed from row-major to 2x2-block order
	// Row-major: (0,0), (0,1), (0,2), ..., (1,0), (1,1), ...
	// 2x2-block: (0,0), (0,1), (1,0), (1,1), (0,2), (0,3), (1,2), (1,3), ...
	llmGridH := pH / spatialMerge
	llmGridW := pW / spatialMerge
	blockReorderIdx := make([]int32, L)
	idx := int32(0)
	for hBlock := int32(0); hBlock < llmGridH; hBlock++ {
		for wBlock := int32(0); wBlock < llmGridW; wBlock++ {
			for dh := int32(0); dh < spatialMerge; dh++ {
				for dw := int32(0); dw < spatialMerge; dw++ {
					h := hBlock*spatialMerge + dh
					w := wBlock*spatialMerge + dw
					rowMajorIdx := h*pW + w
					blockReorderIdx[idx] = rowMajorIdx
					idx++
				}
			}
		}
	}
	blockIdxArr := mlx.NewArrayInt32(blockReorderIdx, []int32{L})
	x = mlx.Take(x, blockIdxArr, 1) // Reorder patches to 2x2-block order

	// Window reorder hidden states and RoPE before blocks
	// Python: reshape to [L/4, 4, D], reorder dim 0, reshape back
	// Reshape x: [B, L, D] -> [B, L/4, 4, D]
	x = mlx.Reshape(x, B, L/spatialMergeUnit, spatialMergeUnit, D)
	// Reorder using window index
	winIdxArr := mlx.NewArrayInt32(winInfo.WindowIndex, []int32{int32(len(winInfo.WindowIndex))})
	x = mlx.Take(x, winIdxArr, 1) // Take along axis 1
	// Reshape back: [B, L/4, 4, D] -> [B, L, D]
	x = mlx.Reshape(x, B, L, D)

	// Similarly reorder RoPE: [L, headDim] -> [L/4, 4, headDim] -> reorder -> [L, headDim]
	cosShape := posEmb[0].Shape()
	ropeL := cosShape[0]
	ropeD := cosShape[1]
	cos := mlx.Reshape(posEmb[0], ropeL/spatialMergeUnit, spatialMergeUnit, ropeD)
	sin := mlx.Reshape(posEmb[1], ropeL/spatialMergeUnit, spatialMergeUnit, ropeD)
	cos = mlx.Take(cos, winIdxArr, 0)
	sin = mlx.Take(sin, winIdxArr, 0)
	cos = mlx.Reshape(cos, ropeL, ropeD)
	sin = mlx.Reshape(sin, ropeL, ropeD)
	posEmb = [2]*mlx.Array{cos, sin}

	// Materialize to prevent freeing during block evaluations
	mlx.Eval(x, posEmb[0], posEmb[1])

	// Full sequence cu_seqlens for full attention blocks
	cuSeqlensFull := []int32{0, L}

	// Vision blocks - use window attention except at full attention indices
	for i, block := range m.VisionBlocks {
		useFullAttention := false
		for _, idx := range cfg.VisionFullAttIdx {
			if int32(i) == idx {
				useFullAttention = true
				break
			}
		}

		var cuSeqlens []int32
		if useFullAttention {
			cuSeqlens = cuSeqlensFull
		} else {
			cuSeqlens = winInfo.CuWindowSeqlens
		}

		x = block.Forward(x, posEmb, cuSeqlens)
	}

	// Spatial merge (2x2 -> 1)
	x = m.VisionMerger.ForwardWithDims(x, pH, pW)

	// Reverse window reorder after merger
	revIdxArr := mlx.NewArrayInt32(winInfo.ReverseIndex, []int32{int32(len(winInfo.ReverseIndex))})
	x = mlx.Take(x, revIdxArr, 1)

	return x
}

// WindowInfo holds window reordering and attention boundary info
type WindowInfo struct {
	WindowIndex      []int32 // Reordering indices for merged tokens
	ReverseIndex     []int32 // Reverse reordering indices
	CuWindowSeqlens  []int32 // Cumulative window boundaries in UNMERGED sequence
	SpatialMergeUnit int32   // Number of patches per merged token (4 = 2x2)
}

// getWindowInfo computes window reordering indices and attention boundaries
// pH, pW: patch grid dimensions before 2x2 merge
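//
// Worked example: a 6x6 merged grid with vitMergerWindowSize=4 pads to 8x8,
// giving 2x2 windows holding 16, 8, 8, and 4 merged tokens. With
// SpatialMergeUnit=4 the unmerged boundaries are
// CuWindowSeqlens = [0, 64, 96, 128, 144].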
func (m *Qwen25VL) getWindowInfo(pH, pW int32) *WindowInfo {
	cfg := m.Config
	spatialMergeUnit := cfg.VisionSpatialMerge * cfg.VisionSpatialMerge // 4

	// After 2x2 merge
	llmGridH := pH / cfg.VisionSpatialMerge
	llmGridW := pW / cfg.VisionSpatialMerge
	numTokens := llmGridH * llmGridW

	// Window size in merged tokens
	// window_size=112, spatial_merge_size=2, patch_size=14
	// vit_merger_window_size = 112 / 2 / 14 = 4
	vitMergerWindowSize := cfg.VisionWindowSize / cfg.VisionSpatialMerge / cfg.VisionPatchSize

	// Calculate padding and number of windows
	padH := vitMergerWindowSize - llmGridH%vitMergerWindowSize
	if padH == vitMergerWindowSize {
		padH = 0
	}
	padW := vitMergerWindowSize - llmGridW%vitMergerWindowSize
	if padW == vitMergerWindowSize {
		padW = 0
	}

	numWindowsH := (llmGridH + padH) / vitMergerWindowSize
	numWindowsW := (llmGridW + padW) / vitMergerWindowSize

	// Create padded grid with -1 for padding
	paddedH := llmGridH + padH
	paddedW := llmGridW + padW
	grid := make([]int32, paddedH*paddedW)
	for i := range grid {
		grid[i] = -1
	}
	for h := int32(0); h < llmGridH; h++ {
		for w := int32(0); w < llmGridW; w++ {
			grid[h*paddedW+w] = h*llmGridW + w
		}
	}

	// Reorder into windows and track window sizes
	windowIndex := make([]int32, 0, numTokens)
	windowSizes := make([]int32, 0, numWindowsH*numWindowsW)
	ws := vitMergerWindowSize

	for wh := int32(0); wh < numWindowsH; wh++ {
		for ww := int32(0); ww < numWindowsW; ww++ {
			windowStart := len(windowIndex)
			// Extract window
			for h := int32(0); h < ws; h++ {
				for w := int32(0); w < ws; w++ {
					idx := (wh*ws+h)*paddedW + (ww*ws + w)
					if grid[idx] >= 0 {
						windowIndex = append(windowIndex, grid[idx])
					}
				}
			}
			windowSize := int32(len(windowIndex) - windowStart)
			windowSizes = append(windowSizes, windowSize)
		}
	}

	// Create reverse index (argsort of windowIndex)
	reverseIndex := make([]int32, numTokens)
	for i, idx := range windowIndex {
		reverseIndex[idx] = int32(i)
	}

	// Compute cumulative sequence lengths in UNMERGED sequence
	// Each merged token corresponds to spatialMergeUnit patches
	cuWindowSeqlens := make([]int32, len(windowSizes)+1)
	cuWindowSeqlens[0] = 0
	for i, size := range windowSizes {
		cuWindowSeqlens[i+1] = cuWindowSeqlens[i] + size*spatialMergeUnit
	}

	return &WindowInfo{
		WindowIndex:      windowIndex,
		ReverseIndex:     reverseIndex,
		CuWindowSeqlens:  cuWindowSeqlens,
		SpatialMergeUnit: spatialMergeUnit,
	}
}

// ComputeMultiImageRoPE computes M-RoPE for combined text + multiple vision regions + text sequences
// This extends ComputeMultimodalRoPE to handle N images instead of just one.
//
// Parameters:
// - imagePositions: starting position of each image's tokens in the sequence
// - visionGridH, visionGridW: grid dimensions for each image (after spatial merge)
// - numImageTokens: number of tokens for each image
// - totalLen: total sequence length
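//
// Worked example: 10 text tokens followed by one image with a 4x6 (HxW) merged
// grid starting at position 10. Text tokens take positions 0..9 in all three
// dims; vision tokens take temporal=10, height=10+row, width=10+col; the next
// text token then continues at max(10, 13, 15) + 1 = 16 in all dims.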
func (m *Qwen25VL) ComputeMultiImageRoPE(imagePositions []int32, visionGridH, visionGridW, numImageTokens []int32, totalLen int32) ([2]*mlx.Array, error) {
	numImages := len(imagePositions)

	// Build 3D position IDs: [3, 1, totalLen]
	// Dimension 0: temporal, Dimension 1: height, Dimension 2: width
	posIDs := make([]float32, 3*totalLen)

	// Process sequence in order
	stIdx := int32(0) // Running text position counter
	seqIdx := int32(0)

	for i := 0; i < numImages; i++ {
		imgPos := imagePositions[i]
		gridH := visionGridH[i]
		gridW := visionGridW[i]
		numTokens := numImageTokens[i]

		// Text segment before this image
		for seqIdx < imgPos {
			posIDs[0*totalLen+seqIdx] = float32(stIdx)
			posIDs[1*totalLen+seqIdx] = float32(stIdx)
			posIDs[2*totalLen+seqIdx] = float32(stIdx)
			stIdx++
			seqIdx++
		}

		// Vision tokens for this image
		// Python uses stIdx as base offset for all position dimensions
		for h := int32(0); h < gridH; h++ {
			for w := int32(0); w < gridW; w++ {
				posIDs[0*totalLen+seqIdx] = float32(stIdx)     // temporal: constant = stIdx
				posIDs[1*totalLen+seqIdx] = float32(stIdx + h) // height: stIdx + row_index
				posIDs[2*totalLen+seqIdx] = float32(stIdx + w) // width: stIdx + col_index
				seqIdx++
			}
		}

		// Verify we processed the expected number of tokens
		if seqIdx != imgPos+numTokens {
			return [2]*mlx.Array{}, fmt.Errorf("mismatch: processed %d but expected %d tokens for image %d", seqIdx-imgPos, numTokens, i)
		}

		// Update stIdx for next text segment: max(temporal, height, width) + 1
		maxVisionPos := stIdx // temporal max
		if stIdx+gridH-1 > maxVisionPos {
			maxVisionPos = stIdx + gridH - 1
		}
		if stIdx+gridW-1 > maxVisionPos {
			maxVisionPos = stIdx + gridW - 1
		}
		stIdx = maxVisionPos + 1
	}

	// Text after last image
	for seqIdx < totalLen {
		posIDs[0*totalLen+seqIdx] = float32(stIdx)
		posIDs[1*totalLen+seqIdx] = float32(stIdx)
		posIDs[2*totalLen+seqIdx] = float32(stIdx)
		stIdx++
		seqIdx++
	}

	posIDsArr := mlx.NewArray(posIDs, []int32{3, 1, totalLen})
	return m.computeRoPEFromPositions(posIDsArr, totalLen, 1), nil
}

// computeTextRoPE computes M-RoPE for text-only sequences
func (m *Qwen25VL) computeTextRoPE(L, B int32) [2]*mlx.Array {
	// For text-only, all 3 dims use same positions [0, 1, 2, ..., L-1]
	posArr := make([]float32, L*3)
	for d := 0; d < 3; d++ {
		for i := int32(0); i < L; i++ {
			posArr[int32(d)*L+i] = float32(i)
		}
	}
	posIDs := mlx.NewArray(posArr, []int32{3, 1, L})
	posIDs = mlx.Tile(posIDs, []int32{1, B, 1})
	return m.computeRoPEFromPositions(posIDs, L, B)
}

// ComputeMultimodalRoPE computes M-RoPE for combined text + vision + text sequences
// This matches Python's get_rope_index behavior exactly.
// Exported for testing.
//
// Python pattern discovered from testing:
//
// Vision row 1: temporal=stIdx, height=stIdx, width=[stIdx, stIdx+1, ..., stIdx+gridW-1]
// Vision row 2: temporal=stIdx, height=stIdx+1, width=[stIdx, stIdx+1, ..., stIdx+gridW-1]
// Text after: all dims = stIdx + max(0, gridH-1, gridW-1) + 1 + i
func (m *Qwen25VL) ComputeMultimodalRoPE(textBefore, visionH, visionW, textAfter int32, spatialMerge int32) [2]*mlx.Array {
	// Vision grid after spatial merge
	llmGridH := visionH / spatialMerge
	llmGridW := visionW / spatialMerge
	visionLen := llmGridH * llmGridW
	totalLen := textBefore + visionLen + textAfter

	// Build 3D position IDs: [3, 1, totalLen]
	// Dimension 0: temporal, Dimension 1: height, Dimension 2: width
	posIDs := make([]float32, 3*totalLen)

	// Text before vision: all dims same [0, 1, 2, ..., textBefore-1]
	for d := 0; d < 3; d++ {
		for i := int32(0); i < textBefore; i++ {
			posIDs[int32(d)*totalLen+i] = float32(i)
		}
	}

	// Vision tokens: 3D grid positions
	// Python uses stIdx (textBefore) as base offset for all position dimensions
	stIdx := textBefore
	for h := int32(0); h < llmGridH; h++ {
		for w := int32(0); w < llmGridW; w++ {
			idx := stIdx + h*llmGridW + w
			posIDs[0*totalLen+idx] = float32(stIdx)     // temporal: constant = stIdx
			posIDs[1*totalLen+idx] = float32(stIdx + h) // height: stIdx + row_index
			posIDs[2*totalLen+idx] = float32(stIdx + w) // width: stIdx + col_index
		}
	}

	// Text after vision: ALL dimensions continue from max(temporal, height, width) + 1
	// max is max(stIdx, stIdx+llmGridH-1, stIdx+llmGridW-1) = stIdx + max(0, llmGridH-1, llmGridW-1)
	// Then st_idx = max + 1
	maxVisionPos := stIdx // temporal max
	if stIdx+llmGridH-1 > maxVisionPos {
		maxVisionPos = stIdx + llmGridH - 1
	}
	if stIdx+llmGridW-1 > maxVisionPos {
		maxVisionPos = stIdx + llmGridW - 1
	}
	textAfterStart := maxVisionPos + 1
	for i := int32(0); i < textAfter; i++ {
		seqIdx := textBefore + visionLen + i
		posIDs[0*totalLen+seqIdx] = float32(textAfterStart + i) // temporal
		posIDs[1*totalLen+seqIdx] = float32(textAfterStart + i) // height
		posIDs[2*totalLen+seqIdx] = float32(textAfterStart + i) // width
	}

	posIDsArr := mlx.NewArray(posIDs, []int32{3, 1, totalLen})
	return m.computeRoPEFromPositions(posIDsArr, totalLen, 1)
}

// computeRoPEFromPositions computes cos/sin from 3D position IDs
// posIDs: [3, B, L] where dim 0 is temporal, 1 is height, 2 is width
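//
// Frequencies follow the standard RoPE schedule, computed per position dim:
//
//	inv_freq[i] = RopeTheta^(-2i/HeadDim)  for i in [0, HeadDim/2)
//
// cos/sin are tiled as [freqs, freqs] along the feature axis so they can be
// applied with the rotate-half formulation used in applyMRoPE.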
func (m *Qwen25VL) computeRoPEFromPositions(posIDs *mlx.Array, L, B int32) [2]*mlx.Array {
	cfg := m.Config
	half := cfg.HeadDim / 2

	// Compute inv_freq
	invFreqArr := make([]float32, half)
	for i := int32(0); i < half; i++ {
		invFreqArr[i] = float32(1.0 / math.Pow(float64(cfg.RopeTheta), 2.0*float64(i)/float64(cfg.HeadDim)))
	}
	invFreq := mlx.NewArray(invFreqArr, []int32{half})

	// Process each position dimension
	var cosAll, sinAll []*mlx.Array
	for d := int32(0); d < 3; d++ {
		// Get positions for this dimension: [B, L]
		pos := mlx.Slice(posIDs, []int32{d, 0, 0}, []int32{d + 1, B, L})
		pos = mlx.Squeeze(pos, 0) // [B, L]

		posExp := mlx.ExpandDims(pos, 2)               // [B, L, 1]
		invFreqExp := mlx.Reshape(invFreq, 1, 1, half) // [1, 1, half]
		freqs := mlx.Mul(posExp, invFreqExp)           // [B, L, half]
		emb := mlx.Tile(freqs, []int32{1, 1, 2})       // [B, L, D]

		cosAll = append(cosAll, mlx.ExpandDims(mlx.Cos(emb), 0))
		sinAll = append(sinAll, mlx.ExpandDims(mlx.Sin(emb), 0))
	}

	cos := mlx.Concatenate(cosAll, 0) // [3, B, L, D]
	sin := mlx.Concatenate(sinAll, 0)

	return [2]*mlx.Array{cos, sin}
}

// computeVisionRoPE computes RoPE embeddings for vision patches
// pH, pW: grid dimensions in patches
// Returns: [2]*mlx.Array containing (cos, sin) each of shape [numPatches, headDim]
func (m *Qwen25VL) computeVisionRoPE(pH, pW int32) [2]*mlx.Array {
	cfg := m.Config
	headDim := cfg.VisionHiddenSize / cfg.VisionNumHeads // 80 for 1280/16
	halfDim := headDim / 2                               // 40
	quarterDim := halfDim / 2                            // 20
	spatialMerge := cfg.VisionSpatialMerge               // 2

	// Python Qwen2_5_VisionRotaryEmbedding uses dim=head_dim/2=40
	// inv_freq = 1.0 / (theta ** (arange(0, dim, 2) / dim)) -> 20 elements
	theta := float64(10000.0)
	invFreqArr := make([]float32, quarterDim)
	for i := int32(0); i < quarterDim; i++ {
		invFreqArr[i] = float32(1.0 / math.Pow(theta, float64(2*i)/float64(halfDim)))
	}
	invFreq := mlx.NewArray(invFreqArr, []int32{quarterDim})

	// Create position IDs matching Python's 2x2 block ordering:
	// Python does: reshape(h//2, 2, w//2, 2), permute(0, 2, 1, 3), flatten
	// This groups patches by 2x2 merged token blocks
	numPatches := pH * pW
	hPosArr := make([]float32, numPatches)
	wPosArr := make([]float32, numPatches)

	// Number of merged token blocks
	llmGridH := pH / spatialMerge
	llmGridW := pW / spatialMerge

	idx := int32(0)
	for hBlock := int32(0); hBlock < llmGridH; hBlock++ {
		for wBlock := int32(0); wBlock < llmGridW; wBlock++ {
			// Within each 2x2 block: (0,0), (0,1), (1,0), (1,1)
			for dh := int32(0); dh < spatialMerge; dh++ {
				for dw := int32(0); dw < spatialMerge; dw++ {
					h := hBlock*spatialMerge + dh
					w := wBlock*spatialMerge + dw
					hPosArr[idx] = float32(h)
					wPosArr[idx] = float32(w)
					idx++
				}
			}
		}
	}

	hPos := mlx.NewArray(hPosArr, []int32{numPatches, 1})
	wPos := mlx.NewArray(wPosArr, []int32{numPatches, 1})
	invFreqExp := mlx.Reshape(invFreq, 1, quarterDim)

	// Compute freqs: [numPatches, quarterDim] for each of h and w
	hFreqs := mlx.Mul(hPos, invFreqExp) // [L, 20]
	wFreqs := mlx.Mul(wPos, invFreqExp) // [L, 20]

	// Concatenate h and w freqs: [numPatches, halfDim] = [L, 40]
	freqs := mlx.Concatenate([]*mlx.Array{hFreqs, wFreqs}, 1)

	// Double for cos/sin application: [L, 40] -> [L, 80] = [L, headDim]
	emb := mlx.Concatenate([]*mlx.Array{freqs, freqs}, 1)

	cos := mlx.Cos(emb)
	sin := mlx.Sin(emb)

	return [2]*mlx.Array{cos, sin}
}

// VLTextBlock is a single Qwen2.5 transformer block (for VL model)
type VLTextBlock struct {
	Attention         *VLTextAttention
	MLP               *VLTextMLP
	InputLayerNorm    *mlx.Array
	PostAttnLayerNorm *mlx.Array
	NormEps           float32
}

// newVLTextBlock creates a text block
func newVLTextBlock(weights *safetensors.ModelWeights, layerIdx int, cfg *Qwen25VLConfig) (*VLTextBlock, error) {
	prefix := fmt.Sprintf("model.layers.%d", layerIdx)

	inputNorm, err := weights.Get(prefix + ".input_layernorm.weight")
	if err != nil {
		return nil, err
	}
	postAttnNorm, err := weights.Get(prefix + ".post_attention_layernorm.weight")
	if err != nil {
		return nil, err
	}

	attention, err := newVLTextAttention(weights, prefix, cfg)
	if err != nil {
		return nil, err
	}

	mlpLayer, err := newVLTextMLP(weights, prefix)
	if err != nil {
		return nil, err
	}

	return &VLTextBlock{
		Attention:         attention,
		MLP:               mlpLayer,
		InputLayerNorm:    inputNorm,
		PostAttnLayerNorm: postAttnNorm,
		NormEps:           cfg.RMSNormEps,
	}, nil
}

// Forward applies the block
func (tb *VLTextBlock) Forward(x *mlx.Array, cossin [2]*mlx.Array) *mlx.Array {
	h := mlx.RMSNorm(x, tb.InputLayerNorm, tb.NormEps)
	attnOut := tb.Attention.Forward(h, cossin)
	x = mlx.Add(x, attnOut)

	h = mlx.RMSNorm(x, tb.PostAttnLayerNorm, tb.NormEps)
	mlpOut := tb.MLP.Forward(h)
	x = mlx.Add(x, mlpOut)

	return x
}

// VLTextAttention implements Qwen2.5 attention with M-RoPE
type VLTextAttention struct {
	QProj        *mlx.Array
	KProj        *mlx.Array
	VProj        *mlx.Array
	OProj        *mlx.Array
	QBias        *mlx.Array
	KBias        *mlx.Array
	VBias        *mlx.Array
	NHeads       int32
	NKVHeads     int32
	HeadDim      int32
	Scale        float32
	MRoPESection []int32
}

// newVLTextAttention creates a text attention layer
func newVLTextAttention(weights *safetensors.ModelWeights, prefix string, cfg *Qwen25VLConfig) (*VLTextAttention, error) {
	qProj, err := weights.Get(prefix + ".self_attn.q_proj.weight")
	if err != nil {
		return nil, err
	}
	kProj, err := weights.Get(prefix + ".self_attn.k_proj.weight")
	if err != nil {
		return nil, err
	}
	vProj, err := weights.Get(prefix + ".self_attn.v_proj.weight")
	if err != nil {
		return nil, err
	}
	oProj, err := weights.Get(prefix + ".self_attn.o_proj.weight")
	if err != nil {
		return nil, err
	}

	qBias, _ := weights.Get(prefix + ".self_attn.q_proj.bias")
	kBias, _ := weights.Get(prefix + ".self_attn.k_proj.bias")
	vBias, _ := weights.Get(prefix + ".self_attn.v_proj.bias")

	return &VLTextAttention{
		QProj:        mlx.Transpose(qProj, 1, 0),
		KProj:        mlx.Transpose(kProj, 1, 0),
		VProj:        mlx.Transpose(vProj, 1, 0),
		OProj:        mlx.Transpose(oProj, 1, 0),
		QBias:        qBias,
		KBias:        kBias,
		VBias:        vBias,
		NHeads:       cfg.NumAttentionHeads,
		NKVHeads:     cfg.NumKeyValueHeads,
		HeadDim:      cfg.HeadDim,
		Scale:        float32(1.0 / math.Sqrt(float64(cfg.HeadDim))),
		MRoPESection: cfg.MRoPESection,
	}, nil
}

// Forward computes attention
func (attn *VLTextAttention) Forward(x *mlx.Array, cossin [2]*mlx.Array) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	L := shape[1]

	q := mlx.Linear(x, attn.QProj)
	if attn.QBias != nil {
		q = mlx.Add(q, attn.QBias)
	}
	k := mlx.Linear(x, attn.KProj)
	if attn.KBias != nil {
		k = mlx.Add(k, attn.KBias)
	}
	v := mlx.Linear(x, attn.VProj)
	if attn.VBias != nil {
		v = mlx.Add(v, attn.VBias)
	}

	q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
	k = mlx.Reshape(k, B, L, attn.NKVHeads, attn.HeadDim)
	v = mlx.Reshape(v, B, L, attn.NKVHeads, attn.HeadDim)

	q = mlx.Transpose(q, 0, 2, 1, 3)
	k = mlx.Transpose(k, 0, 2, 1, 3)
	v = mlx.Transpose(v, 0, 2, 1, 3)

	// Apply M-RoPE
	if cossin[0] != nil && cossin[1] != nil {
		q = applyMRoPE(q, cossin[0], cossin[1], attn.MRoPESection)
		k = applyMRoPE(k, cossin[0], cossin[1], attn.MRoPESection)
	}

	// Repeat KV for GQA
	if attn.NKVHeads < attn.NHeads {
		repeats := attn.NHeads / attn.NKVHeads
		k = repeatKV(k, repeats)
		v = repeatKV(v, repeats)
	}

	out := mlx.ScaledDotProductAttention(q, k, v, attn.Scale, true)

	out = mlx.Transpose(out, 0, 2, 1, 3)
	out = mlx.Reshape(out, B, L, attn.NHeads*attn.HeadDim)

	return mlx.Linear(out, attn.OProj)
}

// applyMRoPE applies Multimodal RoPE (M-RoPE)
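//
// The M-RoPE sections [16, 24, 24] are doubled to [32, 48, 48] feature widths
// (cos/sin span the full head dim). Section i draws its positions from dim
// i%3 of the 3D position IDs, so features 0:32 rotate by temporal position,
// 32:80 by height, and 80:128 by width.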
func applyMRoPE(x *mlx.Array, cos, sin *mlx.Array, section []int32) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	H := shape[1]
	L := shape[2]
	D := shape[3]
	half := D / 2

	fullSection := make([]int32, len(section))
	for i, s := range section {
		fullSection[i] = s * 2
	}

	var cosParts, sinParts []*mlx.Array
	offset := int32(0)
	for i, size := range fullSection {
		posDim := int32(i % 3)
		cosSection := mlx.Slice(cos, []int32{posDim, 0, 0, offset}, []int32{posDim + 1, B, L, offset + size})
		sinSection := mlx.Slice(sin, []int32{posDim, 0, 0, offset}, []int32{posDim + 1, B, L, offset + size})
		cosSection = mlx.Squeeze(cosSection, 0)
		sinSection = mlx.Squeeze(sinSection, 0)
		cosParts = append(cosParts, cosSection)
		sinParts = append(sinParts, sinSection)
		offset += size
	}

	cosFlat := mlx.Concatenate(cosParts, 2)
	sinFlat := mlx.Concatenate(sinParts, 2)

	cosFlat = mlx.Reshape(cosFlat, B, 1, L, D)
	sinFlat = mlx.Reshape(sinFlat, B, 1, L, D)

	x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, H, L, half})
	x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, H, L, D})
	negX2 := mlx.MulScalar(x2, -1)
	rotatedX := mlx.Concatenate([]*mlx.Array{negX2, x1}, 3)

	return mlx.Add(mlx.Mul(x, cosFlat), mlx.Mul(rotatedX, sinFlat))
}

// repeatKV repeats key/value heads for GQA
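//
// With the defaults here (28 query heads, 4 KV heads) repeats = 7, expanding
// K/V from [B, 4, L, D] to [B, 28, L, D] so standard attention can be applied.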
func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
	if repeats == 1 {
		return x
	}
	shape := x.Shape()
	x = mlx.ExpandDims(x, 2)
	x = mlx.Tile(x, []int32{1, 1, repeats, 1, 1})
	return mlx.Reshape(x, shape[0], shape[1]*repeats, shape[2], shape[3])
}

// VLTextMLP implements Qwen2.5 SwiGLU MLP
type VLTextMLP struct {
	GateProj *mlx.Array
	UpProj   *mlx.Array
	DownProj *mlx.Array
}

// newVLTextMLP creates a text MLP layer
func newVLTextMLP(weights *safetensors.ModelWeights, prefix string) (*VLTextMLP, error) {
	gateProj, err := weights.Get(prefix + ".mlp.gate_proj.weight")
	if err != nil {
		return nil, err
	}
	upProj, err := weights.Get(prefix + ".mlp.up_proj.weight")
	if err != nil {
		return nil, err
	}
	downProj, err := weights.Get(prefix + ".mlp.down_proj.weight")
	if err != nil {
		return nil, err
	}

	return &VLTextMLP{
		GateProj: mlx.Transpose(gateProj, 1, 0),
		UpProj:   mlx.Transpose(upProj, 1, 0),
		DownProj: mlx.Transpose(downProj, 1, 0),
	}, nil
}

// Forward applies the SwiGLU MLP
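//
//	y = DownProj( SiLU(GateProj(x)) * UpProj(x) )   (elementwise product)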
func (mlp *VLTextMLP) Forward(x *mlx.Array) *mlx.Array {
	gate := mlx.Linear(x, mlp.GateProj)
	gate = mlx.SiLU(gate)
	up := mlx.Linear(x, mlp.UpProj)
	h := mlx.Mul(gate, up)
	return mlx.Linear(h, mlp.DownProj)
}

// VisionPatchEmbed embeds image patches
type VisionPatchEmbed struct {
	ProjWeight *mlx.Array
	ProjBias   *mlx.Array
	PatchSize  int32
}

// newVisionPatchEmbed creates a vision patch embed layer
func newVisionPatchEmbed(weights *safetensors.ModelWeights, cfg *Qwen25VLConfig) (*VisionPatchEmbed, error) {
	projWeight, err := weights.Get("visual.patch_embed.proj.weight")
	if err != nil {
		return nil, err
	}
	projBias, _ := weights.Get("visual.patch_embed.proj.bias")

	return &VisionPatchEmbed{
		ProjWeight: projWeight,
		ProjBias:   projBias,
		PatchSize:  cfg.VisionPatchSize,
	}, nil
}

// Forward embeds patches from an image
// image: [B, C, H, W]
// Returns: [B, num_patches, hidden_size]
func (pe *VisionPatchEmbed) Forward(image *mlx.Array) *mlx.Array {
	// Qwen2.5-VL uses 3D conv for patch embedding to support video
	// Weight shape is [O, I, kT, kH, kW] e.g. [1280, 3, 2, 14, 14]
	// For single image, we duplicate the frame to match temporal_patch_size

	wShape := pe.ProjWeight.Shape()
	if len(wShape) == 5 {
		// 3D convolution case
		temporalPatchSize := wShape[2] // kT from weight shape

		// Add temporal dimension: [B, C, H, W] -> [B, C, 1, H, W]
		image = mlx.ExpandDims(image, 2)

		// Duplicate frame to match temporal_patch_size (Python does this for single images)
		// [B, C, 1, H, W] -> [B, C, T, H, W] where T = temporal_patch_size
		if temporalPatchSize > 1 {
			image = mlx.Tile(image, []int32{1, 1, temporalPatchSize, 1, 1})
		}

		// Convert to channels-last: [B, C, T, H, W] -> [B, T, H, W, C]
		image = mlx.Transpose(image, 0, 2, 3, 4, 1)

		// Weight is [O, I, kT, kH, kW] - keep as-is since patches are now in [I, kT, kH, kW] order
		// (extractPatches3DStrided transposes each patch to [C, T, H, W] to match Python)

		// Apply 3D conv using manual patch extraction
		// Strides: (temporal_patch_size, patch_size, patch_size)
		x := conv3DStrided(image, pe.ProjWeight, temporalPatchSize, pe.PatchSize, pe.PatchSize)

		if pe.ProjBias != nil {
			outC := pe.ProjBias.Dim(0)
			bias := mlx.Reshape(pe.ProjBias, 1, 1, 1, 1, outC)
			x = mlx.Add(x, bias)
		}

		// x is [B, T', H', W', C], squeeze T' and flatten spatial
		shape := x.Shape()
		// T' should be 1 for single image (since we used stride=temporal_patch_size)
		x = mlx.Reshape(x, shape[0], shape[2]*shape[3], shape[4])

		return x
	}

	// Original 2D case (fallback)
	// Convert to channels-last for Conv2d
	image = mlx.Transpose(image, 0, 2, 3, 1) // [B, H, W, C]

	// Apply conv with stride=patch_size using manual strided convolution
	weight := mlx.Transpose(pe.ProjWeight, 0, 2, 3, 1) // [O, I, kH, kW] -> [O, kH, kW, I]
	x := conv2DStrided(image, weight, pe.PatchSize)
	if pe.ProjBias != nil {
		bias := mlx.Reshape(pe.ProjBias, 1, 1, 1, pe.ProjBias.Dim(0))
		x = mlx.Add(x, bias)
	}

	// Flatten patches: [B, pH, pW, C] -> [B, pH*pW, C]
	shape := x.Shape()
	x = mlx.Reshape(x, shape[0], shape[1]*shape[2], shape[3])

	return x
}

// VisionBlock is a single vision transformer block
type VisionBlock struct {
	Norm1     *mlx.Array
	Norm2     *mlx.Array
	Attention *VisionAttention
	MLP       *VisionMLP
}

// newVisionBlock creates a vision block
func newVisionBlock(weights *safetensors.ModelWeights, layerIdx int, cfg *Qwen25VLConfig) (*VisionBlock, error) {
	prefix := fmt.Sprintf("visual.blocks.%d", layerIdx)

	norm1, err := weights.Get(prefix + ".norm1.weight")
	if err != nil {
		return nil, err
	}
	norm2, err := weights.Get(prefix + ".norm2.weight")
	if err != nil {
		return nil, err
	}

	attention, err := newVisionAttention(weights, prefix, cfg)
	if err != nil {
		return nil, err
	}

	mlpLayer, err := newVisionMLP(weights, prefix, cfg)
	if err != nil {
		return nil, err
	}

	return &VisionBlock{
		Norm1:     norm1,
		Norm2:     norm2,
		Attention: attention,
		MLP:       mlpLayer,
	}, nil
}

// Forward applies the vision block
// posEmb: [2]*mlx.Array containing (cos, sin) for RoPE, can be nil
// cuSeqlens: cumulative sequence lengths for window attention
func (vb *VisionBlock) Forward(x *mlx.Array, posEmb [2]*mlx.Array, cuSeqlens []int32) *mlx.Array {
	// Python uses RMSNorm, not LayerNorm!
	h := mlx.RMSNormNoWeight(x, 1e-6)
	h = mlx.Mul(h, vb.Norm1)
	attnOut := vb.Attention.Forward(h, posEmb, cuSeqlens)
	x = mlx.Add(x, attnOut)

	h = mlx.RMSNormNoWeight(x, 1e-6)
	h = mlx.Mul(h, vb.Norm2)
	mlpOut := vb.MLP.Forward(h)
	x = mlx.Add(x, mlpOut)

	return x
}

// VisionAttention implements vision attention
type VisionAttention struct {
	QKVProj *mlx.Array
	QKVBias *mlx.Array
	OutProj *mlx.Array
	OutBias *mlx.Array
	NHeads  int32
	HeadDim int32
	Scale   float32
}

// newVisionAttention creates a vision attention layer
func newVisionAttention(weights *safetensors.ModelWeights, prefix string, cfg *Qwen25VLConfig) (*VisionAttention, error) {
	qkvProj, err := weights.Get(prefix + ".attn.qkv.weight")
	if err != nil {
		return nil, err
	}
	qkvBias, _ := weights.Get(prefix + ".attn.qkv.bias")
	outProj, err := weights.Get(prefix + ".attn.proj.weight")
	if err != nil {
		return nil, err
	}
	outBias, _ := weights.Get(prefix + ".attn.proj.bias")

	headDim := cfg.VisionHiddenSize / cfg.VisionNumHeads

	return &VisionAttention{
		QKVProj: mlx.Transpose(qkvProj, 1, 0),
		QKVBias: qkvBias,
		OutProj: mlx.Transpose(outProj, 1, 0),
		OutBias: outBias,
		NHeads:  cfg.VisionNumHeads,
		HeadDim: headDim,
		Scale:   float32(1.0 / math.Sqrt(float64(headDim))),
	}, nil
}

// Forward applies vision attention with optional RoPE and window attention
// posEmb: [2]*mlx.Array containing (cos, sin) for RoPE, can be nil
// cuSeqlens: cumulative sequence lengths for window boundaries
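//
// Window attention here is a block-diagonal pattern: each window attends only
// within itself. For example, cuSeqlens = [0, 64, 96] yields two independent
// spans, [0:64) and [64:96), each run through attention separately and then
// concatenated back along the sequence axis.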
func (attn *VisionAttention) Forward(x *mlx.Array, posEmb [2]*mlx.Array, cuSeqlens []int32) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	L := shape[1]
	D := shape[2]

	qkv := mlx.Linear(x, attn.QKVProj)
	if attn.QKVBias != nil {
		qkv = mlx.Add(qkv, attn.QKVBias)
	}

	// Split into Q, K, V
	qkv = mlx.Reshape(qkv, B, L, 3, attn.NHeads, attn.HeadDim)
	q := mlx.Slice(qkv, []int32{0, 0, 0, 0, 0}, []int32{B, L, 1, attn.NHeads, attn.HeadDim})
	k := mlx.Slice(qkv, []int32{0, 0, 1, 0, 0}, []int32{B, L, 2, attn.NHeads, attn.HeadDim})
	v := mlx.Slice(qkv, []int32{0, 0, 2, 0, 0}, []int32{B, L, 3, attn.NHeads, attn.HeadDim})

	q = mlx.Squeeze(q, 2) // [B, L, H, D]
	k = mlx.Squeeze(k, 2)
	v = mlx.Squeeze(v, 2)

	// Apply RoPE if position embeddings provided
	if posEmb[0] != nil && posEmb[1] != nil {
		q, k = applyVisionRoPE(q, k, posEmb[0], posEmb[1])
	}

	q = mlx.Transpose(q, 0, 2, 1, 3) // [B, H, L, D]
	k = mlx.Transpose(k, 0, 2, 1, 3)
	v = mlx.Transpose(v, 0, 2, 1, 3)

	var out *mlx.Array

	// Check if we need window attention (more than 1 window)
	numWindows := len(cuSeqlens) - 1
	if numWindows <= 1 {
		// Full attention - single window covering entire sequence
		out = mlx.ScaledDotProductAttention(q, k, v, attn.Scale, false)
	} else {
		// Window attention - process each window separately
		attnOutputs := make([]*mlx.Array, numWindows)

		for w := 0; w < numWindows; w++ {
			start := cuSeqlens[w]
			end := cuSeqlens[w+1]

			// Slice Q, K, V for this window: [B, H, winLen, D]
			qWin := mlx.Slice(q, []int32{0, 0, start, 0}, []int32{B, attn.NHeads, end, attn.HeadDim})
			kWin := mlx.Slice(k, []int32{0, 0, start, 0}, []int32{B, attn.NHeads, end, attn.HeadDim})
			vWin := mlx.Slice(v, []int32{0, 0, start, 0}, []int32{B, attn.NHeads, end, attn.HeadDim})

			// Compute attention for this window
			attnWin := mlx.ScaledDotProductAttention(qWin, kWin, vWin, attn.Scale, false)
			attnOutputs[w] = attnWin
		}

		// Concatenate all window outputs along sequence dimension
		out = mlx.Concatenate(attnOutputs, 2)
	}

	out = mlx.Transpose(out, 0, 2, 1, 3) // [B, L, H, D]
	out = mlx.Reshape(out, B, L, D)

	out = mlx.Linear(out, attn.OutProj)
	if attn.OutBias != nil {
		out = mlx.Add(out, attn.OutBias)
	}

	return out
}

// applyVisionRoPE applies rotary position embedding to Q and K for vision
// q, k: [B, L, H, D], cos, sin: [L, D] (already doubled: D = head_dim)
// Returns: rotated q, k with same shape
// Note: Python does this computation in float32 for numerical stability
func applyVisionRoPE(q, k, cos, sin *mlx.Array) (*mlx.Array, *mlx.Array) {
	// Convert to float32 for numerical stability (matches Python)
	origDtype := q.Dtype()
	q = mlx.AsType(q, mlx.DtypeFloat32)
	k = mlx.AsType(k, mlx.DtypeFloat32)
	cos = mlx.AsType(cos, mlx.DtypeFloat32)
	sin = mlx.AsType(sin, mlx.DtypeFloat32)

	// Expand cos/sin to match q/k shape: [L, D] -> [1, L, 1, D]
	cos = mlx.ExpandDims(cos, 0)
	cos = mlx.ExpandDims(cos, 2)
	sin = mlx.ExpandDims(sin, 0)
	sin = mlx.ExpandDims(sin, 2)

	// rotate_half: split last dim in half and swap with negation
	// q_rot = q * cos + rotate_half(q) * sin
	qRotated := rotateHalf(q)
	kRotated := rotateHalf(k)

	qOut := mlx.Add(mlx.Mul(q, cos), mlx.Mul(qRotated, sin))
	kOut := mlx.Add(mlx.Mul(k, cos), mlx.Mul(kRotated, sin))

	// Convert back to original dtype
	qOut = mlx.AsType(qOut, origDtype)
	kOut = mlx.AsType(kOut, origDtype)

	return qOut, kOut
}

// rotateHalf rotates the last dimension by splitting in half and swapping with negation
// x: [..., D] -> split to [..., D/2] and [..., D/2], then concat(-x2, x1)
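//
// Combined with cos/sin in applyVisionRoPE this realizes the 2D rotation of
// each feature pair (x[i], x[i+D/2]):
//
//	out[i]     = x[i]*cos - x[i+D/2]*sin
//	out[i+D/2] = x[i]*sin + x[i+D/2]*cos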
func rotateHalf(x *mlx.Array) *mlx.Array {
	shape := x.Shape()
	lastDim := shape[len(shape)-1]
	halfDim := lastDim / 2

	// Split into two halves
	x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{shape[0], shape[1], shape[2], halfDim})
	x2 := mlx.Slice(x, []int32{0, 0, 0, halfDim}, []int32{shape[0], shape[1], shape[2], lastDim})

	// Negate x2 and concatenate
	x2Neg := mlx.MulScalar(x2, -1.0)
	return mlx.Concatenate([]*mlx.Array{x2Neg, x1}, 3)
}

// VisionMLP implements vision SwiGLU MLP
type VisionMLP struct {
	GateProj     *mlx.Array
	GateProjBias *mlx.Array
	UpProj       *mlx.Array
	UpProjBias   *mlx.Array
	DownProj     *mlx.Array
	DownProjBias *mlx.Array
}

// newVisionMLP creates a vision MLP layer
func newVisionMLP(weights *safetensors.ModelWeights, prefix string, cfg *Qwen25VLConfig) (*VisionMLP, error) {
	gateProj, err := weights.Get(prefix + ".mlp.gate_proj.weight")
	if err != nil {
		return nil, err
	}
	gateProjBias, _ := weights.Get(prefix + ".mlp.gate_proj.bias")
	upProj, err := weights.Get(prefix + ".mlp.up_proj.weight")
	if err != nil {
		return nil, err
	}
	upProjBias, _ := weights.Get(prefix + ".mlp.up_proj.bias")
	downProj, err := weights.Get(prefix + ".mlp.down_proj.weight")
	if err != nil {
		return nil, err
	}
	downProjBias, _ := weights.Get(prefix + ".mlp.down_proj.bias")

	return &VisionMLP{
		GateProj:     mlx.Transpose(gateProj, 1, 0),
		GateProjBias: gateProjBias,
		UpProj:       mlx.Transpose(upProj, 1, 0),
		UpProjBias:   upProjBias,
		DownProj:     mlx.Transpose(downProj, 1, 0),
		DownProjBias: downProjBias,
	}, nil
}

// Forward applies the vision SwiGLU MLP
func (m *VisionMLP) Forward(x *mlx.Array) *mlx.Array {
	gate := mlx.Linear(x, m.GateProj)
	if m.GateProjBias != nil {
		gate = mlx.Add(gate, m.GateProjBias)
	}
	gate = mlx.SiLU(gate)

	up := mlx.Linear(x, m.UpProj)
	if m.UpProjBias != nil {
		up = mlx.Add(up, m.UpProjBias)
	}

	h := mlx.Mul(gate, up)
	h = mlx.Linear(h, m.DownProj)
	if m.DownProjBias != nil {
		h = mlx.Add(h, m.DownProjBias)
	}
	return h
}

// VisionMerger merges spatial patches (2x2 -> 1)
type VisionMerger struct {
	MLP0Weight *mlx.Array
	MLP0Bias   *mlx.Array
	MLP2Weight *mlx.Array
	MLP2Bias   *mlx.Array
	LNWeight   *mlx.Array
}

// newVisionMerger creates a vision merger
func newVisionMerger(weights *safetensors.ModelWeights, cfg *Qwen25VLConfig) (*VisionMerger, error) {
	mlp0Weight, err := weights.Get("visual.merger.mlp.0.weight")
	if err != nil {
		return nil, err
	}
	mlp0Bias, _ := weights.Get("visual.merger.mlp.0.bias")
	mlp2Weight, err := weights.Get("visual.merger.mlp.2.weight")
	if err != nil {
		return nil, err
	}
	mlp2Bias, _ := weights.Get("visual.merger.mlp.2.bias")
	lnWeight, _ := weights.Get("visual.merger.ln_q.weight")

	return &VisionMerger{
		MLP0Weight: mlx.Transpose(mlp0Weight, 1, 0),
		MLP0Bias:   mlp0Bias,
		MLP2Weight: mlx.Transpose(mlp2Weight, 1, 0),
		MLP2Bias:   mlp2Bias,
		LNWeight:   lnWeight,
	}, nil
}

// Forward merges 2x2 patches into 1 (assumes square grid - use ForwardWithDims for non-square)
func (m *VisionMerger) Forward(x *mlx.Array) *mlx.Array {
	shape := x.Shape()
	L := shape[1]
	side := int32(math.Sqrt(float64(L)))
	return m.ForwardWithDims(x, side, side)
}

// ForwardWithDims merges 2x2 patches into 1 with explicit grid dimensions
// After window reordering, every 4 consecutive patches form a 2x2 block, so we
// just reshape [B, L, D] -> [B, L/4, 4*D] without 2D spatial rearrangement.
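//
// With the defaults here D = 1280, so each merged token carries 4*D = 5120
// features; the two-layer MLP (Linear -> GELU -> Linear) then projects them
// down to the text model width (VisionOutHiddenSize = 3584).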
func (m *VisionMerger) ForwardWithDims(x *mlx.Array, pH, pW int32) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	L := shape[1]
	D := shape[2]

	// RMSNorm BEFORE merge (applied to each token with D dimensions)
	// Python: ln_q = Qwen2RMSNorm(context_dim, eps=1e-6)
	if m.LNWeight != nil {
		x = mlx.RMSNormNoWeight(x, 1e-6)
		x = mlx.Mul(x, m.LNWeight)
	}

	// After window reordering, consecutive 4 patches belong to a 2x2 block
	// Just reshape to [B, L/4, 4*D] - no spatial rearrangement needed
	newL := L / 4
	x = mlx.Reshape(x, B, newL, 4*D)

	// MLP
	h := mlx.Linear(x, m.MLP0Weight)
	if m.MLP0Bias != nil {
		h = mlx.Add(h, m.MLP0Bias)
	}
	h = mlx.GELU(h)
	h = mlx.Linear(h, m.MLP2Weight)
	if m.MLP2Bias != nil {
		h = mlx.Add(h, m.MLP2Bias)
	}

	return h
}

// LoadQwen25VLFromPath loads the encoder from path
func LoadQwen25VLFromPath(path string) (*Qwen25VL, error) {
	m := &Qwen25VL{}
	if err := m.Load(filepath.Join(path, "text_encoder")); err != nil {
		return nil, err
	}
	return m, nil
}

// conv2DStrided applies conv with stride > 1 using manual patch extraction
// x: [B, H, W, C] (channels-last), weight: [O, kH, kW, I]
func conv2DStrided(x, weight *mlx.Array, stride int32) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	H := shape[1]
	W := shape[2]

	wShape := weight.Shape()
	Cout := wShape[0]
	kH := wShape[1]
	kW := wShape[2]

	outH := (H-kH)/stride + 1
	outW := (W-kW)/stride + 1

	patches := extractPatches2DStrided(x, kH, kW, stride)
	wFlat := mlx.Reshape(weight, Cout, -1)
	patches = mlx.Reshape(patches, B*outH*outW, -1)
	out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
	return mlx.Reshape(out, B, outH, outW, Cout)
}

// conv3DStrided applies 3D conv with strides using manual patch extraction
// x: [B, T, H, W, C] (channels-last), weight: [O, I, kT, kH, kW] (PyTorch format)
// strideT, strideH, strideW are the strides for each dimension
// Patches are extracted in [C, T, H, W] order to match Python's preprocessing
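//
// This is the classic im2col reduction: each output location's receptive field
// is flattened into one row so the convolution becomes a single matmul. With
// the patch-embed weight [1280, 3, 2, 14, 14], each row holds 3*2*14*14 = 1176
// values and the matmul maps [N, 1176] x [1176, 1280] -> [N, 1280].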
func conv3DStrided(x, weight *mlx.Array, strideT, strideH, strideW int32) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	T := shape[1]
	H := shape[2]
	W := shape[3]
	C := shape[4]

	wShape := weight.Shape()
	Cout := wShape[0]
	// I := wShape[1]
	kT := wShape[2]
	kH := wShape[3]
	kW := wShape[4]

	// For temporal: if T < kT, we need to repeat frames temporally
	// For single image with T=1 and kT=2, we duplicate the frame to T=kT
	// Python Qwen2.5-VL duplicates the frame, not zero-pads
	if T < kT {
		// Tile along T dimension: [B, T, H, W, C] -> [B, kT, H, W, C]
		x = mlx.Tile(x, []int32{1, kT, 1, 1, 1})
		T = kT
	}

	outT := (T-kT)/strideT + 1
	outH := (H-kH)/strideH + 1
	outW := (W-kW)/strideW + 1

	// Extract 3D patches in [C, T, H, W] order to match Python
	patches := extractPatches3DStrided(x, kT, kH, kW, strideT, strideH, strideW)
	// patches shape: [B, outT, outH, outW, C*kT*kH*kW]

	// Weight is [O, I, kT, kH, kW] - flatten to [O, I*kT*kH*kW] to match patch order [C, T, H, W]
	wFlat := mlx.Reshape(weight, Cout, -1) // [Cout, I*kT*kH*kW]
	patches = mlx.Reshape(patches, B*outT*outH*outW, C*kT*kH*kW)
	out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
	return mlx.Reshape(out, B, outT, outH, outW, Cout)
}

// extractPatches3DStrided extracts 3D patches with given strides
// Returns patches with values in [C, T, H, W] order to match Python's preprocessing
func extractPatches3DStrided(x *mlx.Array, kT, kH, kW, strideT, strideH, strideW int32) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	T := shape[1]
	H := shape[2]
	W := shape[3]
	C := shape[4]

	outT := (T-kT)/strideT + 1
	outH := (H-kH)/strideH + 1
	outW := (W-kW)/strideW + 1

	numPatches := outT * outH * outW
	patches := make([]*mlx.Array, numPatches)
	idx := 0
	for t := int32(0); t < outT; t++ {
		for i := int32(0); i < outH; i++ {
			for j := int32(0); j < outW; j++ {
				startT := t * strideT
				startH := i * strideH
				startW := j * strideW
				// Extract patch: [B, kT, kH, kW, C]
				patch := mlx.Slice(x,
					[]int32{0, startT, startH, startW, 0},
					[]int32{B, startT + kT, startH + kH, startW + kW, C})
				// Transpose from [B, T, H, W, C] to [B, C, T, H, W] to match Python's order
				patch = mlx.Transpose(patch, 0, 4, 1, 2, 3)
				// Flatten to [B, C*T*H*W]
				patch = mlx.Reshape(patch, B, C*kT*kH*kW)
				patches[idx] = patch
				idx++
			}
		}
	}

	for i := range patches {
		patches[i] = mlx.ExpandDims(patches[i], 1)
	}
	stacked := mlx.Concatenate(patches, 1)
	return mlx.Reshape(stacked, B, outT, outH, outW, C*kT*kH*kW)
}

// extractPatches2DStrided extracts patches with given stride
func extractPatches2DStrided(x *mlx.Array, kH, kW, stride int32) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	H := shape[1]
	W := shape[2]
	C := shape[3]

	outH := (H-kH)/stride + 1
	outW := (W-kW)/stride + 1

	patches := make([]*mlx.Array, outH*outW)
	idx := 0
	for i := int32(0); i < outH; i++ {
		for j := int32(0); j < outW; j++ {
			startH := i * stride
			startW := j * stride
			patch := mlx.Slice(x, []int32{0, startH, startW, 0}, []int32{B, startH + kH, startW + kW, C})
			patch = mlx.Reshape(patch, B, kH*kW*C)
			patches[idx] = patch
			idx++
		}
	}

	for i := range patches {
		patches[i] = mlx.ExpandDims(patches[i], 1)
	}
	stacked := mlx.Concatenate(patches, 1)
	return mlx.Reshape(stacked, B, outH, outW, kH*kW*C)
}