Files
ollama/x/imagegen/models/gemma3/vision.go
Daniel Hiltgen 33ee7168ba Add experimental MLX backend and engine with imagegen support (#13648)
* WIP - MLX backend with gemma3

* MLX: add cmake and go tag build toggles

To build the new MLX backend code:
  cmake --preset MLX
  cmake --build --preset MLX --parallel
  cmake --install build --component MLX
  go build -tags mlx .

Note: the main.go entrypoint for the MLX engine will change in a follow-up commit.

* add experimental image generation runtime

* add experimental image generation runtime

* MLX: wire up cuda build for linux

* MLX: get dependencies correct and dedup

This is still too large for a unified github artifact, but is now "correct" for the mlx_cuda_v13
directory.

* fix relative link bug in dedup

* Add darwin build and readme

* add go build tag for mlx dependent code and wire up build_darwin.sh

* lint cleanup

* macos: build mlx for x86

This will be CPU only.

* cuda build instructions and fix drift from mlx bump

* stale comment

* Delete agent helper doc

* Clean up readme.md

* Revise README for tokenizer clarity and details

Updated README to clarify tokenizer functionality and removed correctness section.

---------

Co-authored-by: jmorganca <jmorganca@gmail.com>
2026-01-08 16:18:59 -08:00

139 lines
4.5 KiB
Go

//go:build mlx
package gemma3
import (
"math"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
)
// VisionConfig carries the hyperparameters of the SigLIP vision tower.
// Every field is decoded from the JSON key named in its struct tag.
type VisionConfig struct {
	// HiddenSize is the model width; the attention code divides it by
	// NumAttentionHeads to obtain the per-head dimension.
	HiddenSize int32 `json:"hidden_size"`
	// ImageSize is read from "image_size" (not referenced by the code in this file).
	ImageSize int32 `json:"image_size"`
	// IntermediateSize is read from "intermediate_size" (not referenced by the code in this file).
	IntermediateSize int32 `json:"intermediate_size"`
	// NumAttentionHeads is the number of heads used by self-attention.
	NumAttentionHeads int32 `json:"num_attention_heads"`
	// NumHiddenLayers is read from "num_hidden_layers" (not referenced by the code in this file).
	NumHiddenLayers int32 `json:"num_hidden_layers"`
	// PatchSize is passed to the patch-embedding convolution as its stride.
	PatchSize int32 `json:"patch_size"`
}
// VisionTower is the SigLIP vision encoder. It bundles the patch/position
// embeddings, the stack of transformer encoder layers, and the layer norm
// applied after the last layer.
//
// The `weight` tags name the checkpoint prefix associated with each field
// (presumably consumed by the project's weight loader — confirm there).
type VisionTower struct {
	// Embeddings turns image patches into token embeddings.
	Embeddings *VisionEmbeddings `weight:"vision_model.embeddings"`
	// Encoder holds the encoder layers, applied in slice order.
	Encoder []*VisionEncoderLayer `weight:"vision_model.encoder.layers"`
	// PostLayerNorm is applied to the output of the final encoder layer.
	PostLayerNorm *nn.LayerNorm `weight:"vision_model.post_layernorm"`
	// Config supplies the dimensions used during the forward pass; it has
	// no weight tag.
	Config *VisionConfig
}
// VisionEmbeddings holds the parameters for the patch embedding convolution
// and the learned position embedding.
type VisionEmbeddings struct {
	// PatchWeight is the convolution kernel. It is stored in the PyTorch
	// layout [O, C, kH, kW] and is transposed to the MLX layout
	// [O, kH, kW, C] when the vision tower runs.
	PatchWeight *mlx.Array `weight:"patch_embedding.weight"`
	// PatchBias is the convolution bias, one value per output channel.
	PatchBias *mlx.Array `weight:"patch_embedding.bias"`
	// PosEmbed maps patch indices to position embeddings.
	PosEmbed *nn.Embedding `weight:"position_embedding"`
}
// VisionEncoderLayer is one transformer encoder block: a layer norm followed
// by self-attention, then a second layer norm followed by an MLP.
type VisionEncoderLayer struct {
	// LayerNorm1 normalizes the input to the attention sub-block.
	LayerNorm1 *nn.LayerNorm `weight:"layer_norm1"`
	// Attention is the multi-head self-attention sub-block.
	Attention *VisionAttention `weight:"self_attn"`
	// LayerNorm2 normalizes the input to the MLP sub-block.
	LayerNorm2 *nn.LayerNorm `weight:"layer_norm2"`
	// MLP is the feed-forward sub-block.
	MLP *VisionMLP `weight:"mlp"`
}
// VisionAttention holds the four linear projections used by multi-head
// self-attention.
type VisionAttention struct {
	// QProj produces the queries.
	QProj *nn.Linear `weight:"q_proj"`
	// KProj produces the keys.
	KProj *nn.Linear `weight:"k_proj"`
	// VProj produces the values.
	VProj *nn.Linear `weight:"v_proj"`
	// OutProj is applied to the merged heads after attention.
	OutProj *nn.Linear `weight:"out_proj"`
}
// VisionMLP is the two-layer feed-forward network of an encoder layer.
type VisionMLP struct {
	// FC1 is the first linear layer; its output goes through GELU.
	FC1 *nn.Linear `weight:"fc1"`
	// FC2 is the second linear layer, applied after the activation.
	FC2 *nn.Linear `weight:"fc2"`
}
// Forward encodes a batch of preprocessed images into per-patch features.
//
// Input:  x, a normalized image tensor in NHWC layout, [B, H, W, C].
// Output: [B, num_patches, hidden_size].
func (v *VisionTower) Forward(x *mlx.Array) *mlx.Array {
	emb := v.Embeddings

	// The kernel is stored as [O, C, kH, kW] (PyTorch); move the channel axis
	// last to get the [O, kH, kW, C] layout the MLX convolution expects.
	kernel := mlx.Transpose(emb.PatchWeight, 0, 2, 3, 1)

	// Non-overlapping patches: stride equals the patch size, padding is zero.
	// Result is [B, grid, grid, O].
	patches := mlx.Conv2d(x, kernel, v.Config.PatchSize, 0)

	// Reshape the bias from [O] to [1, 1, 1, O] so it broadcasts over the
	// batch and both grid axes.
	patches = mlx.Add(patches, mlx.Reshape(emb.PatchBias, 1, 1, 1, emb.PatchBias.Shape()[0]))

	// Collapse the two grid axes: [B, grid, grid, hidden] -> [B, num_patches, hidden].
	dims := patches.Shape()
	batch, width := dims[0], dims[3]
	numPatches := dims[1] * dims[2]
	tokens := mlx.Reshape(patches, batch, numPatches, width)

	// Look up one position embedding per patch index 0..numPatches-1 and add it.
	positions := mlx.ArangeInt(0, numPatches, 1, mlx.DtypeInt32)
	tokens = mlx.Add(tokens, emb.PosEmbed.Forward(positions))

	// Attention scale 1/sqrt(head_dim) is the same for every layer, so it is
	// computed once and handed to each of them.
	perHead := float32(v.Config.HiddenSize / v.Config.NumAttentionHeads)
	scale := float32(1.0 / math.Sqrt(float64(perHead)))
	for i := range v.Encoder {
		tokens = v.Encoder[i].Forward(tokens, v.Config, scale)
	}

	// Final layer norm over the encoder output.
	return v.PostLayerNorm.Forward(tokens)
}
// Forward applies one pre-norm encoder block to x and returns the result.
// cfg and scale are forwarded unchanged to the attention sub-block.
func (l *VisionEncoderLayer) Forward(x *mlx.Array, cfg *VisionConfig, scale float32) *mlx.Array {
	// Attention sub-block: normalize, attend, then add the result back onto
	// the un-normalized input (residual connection).
	attended := l.Attention.Forward(l.LayerNorm1.Forward(x), cfg, scale)
	afterAttn := mlx.Add(x, attended)

	// MLP sub-block, same pattern: normalize, transform, add residual.
	fed := l.MLP.Forward(l.LayerNorm2.Forward(afterAttn))
	return mlx.Add(afterAttn, fed)
}
// Forward computes multi-head self-attention over x, which is indexed as
// [B, L, hidden]. cfg provides the head count and hidden size; scale is the
// factor handed to the scaled dot-product attention kernel.
func (a *VisionAttention) Forward(x *mlx.Array, cfg *VisionConfig, scale float32) *mlx.Array {
	dims := x.Shape()
	batch, seqLen := dims[0], dims[1]
	heads := cfg.NumAttentionHeads
	perHead := cfg.HiddenSize / heads

	// splitHeads views a projection as [B, L, heads, head_dim] and swaps the
	// sequence and head axes to give [B, heads, L, head_dim].
	splitHeads := func(t *mlx.Array) *mlx.Array {
		return mlx.Transpose(mlx.Reshape(t, batch, seqLen, heads, perHead), 0, 2, 1, 3)
	}

	q := a.QProj.Forward(x)
	k := a.KProj.Forward(x)
	v := a.VProj.Forward(x)
	q, k, v = splitHeads(q), splitHeads(k), splitHeads(v)

	// Vision attention is bidirectional, so the causal flag is false.
	attn := mlx.ScaledDotProductAttention(q, k, v, scale, false)

	// Undo the head split: [B, heads, L, head_dim] -> [B, L, hidden].
	merged := mlx.Reshape(mlx.Transpose(attn, 0, 2, 1, 3), batch, seqLen, cfg.HiddenSize)
	return a.OutProj.Forward(merged)
}
// Forward applies the feed-forward network to x: FC1, then GELU, then FC2.
func (m *VisionMLP) Forward(x *mlx.Array) *mlx.Array {
	expanded := m.FC1.Forward(x)
	activated := mlx.GELU(expanded)
	return m.FC2.Forward(activated)
}