* WIP - MLX backend with gemma3
* MLX: add cmake and go tag build toggles. To build the new MLX backend code:
    cmake --preset MLX
    cmake --build --preset MLX --parallel
    cmake --install build --component MLX
    go build -tags mlx .
  Note: the main.go entrypoint for the MLX engine will change in a follow up commit.
* add experimental image generation runtime
* add experimental image generation runtime
* MLX: wire up cuda build for linux
* MLX: get dependencies correct and dedup. This is still too large for a unified github artifact, but is now "correct" for the mlx_cuda_v13 directory.
* fix relative link bug in dedup
* Add darwin build and readme
* add go build tag for mlx dependent code and wire up build_darwin.sh
* lint cleanup
* macos: build mlx for x86. This will be CPU only.
* cuda build instructions and fix drift from mlx bump
* stale comment
* Delete agent helper doc
* Clean up readme.md
* Revise README for tokenizer clarity and details. Updated README to clarify tokenizer functionality and removed correctness section.

---------

Co-authored-by: jmorganca <jmorganca@gmail.com>
//go:build mlx

package gemma3

import (
	"math"

	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/nn"
)

// VisionConfig holds configuration for the SigLIP vision tower
type VisionConfig struct {
	HiddenSize        int32 `json:"hidden_size"`
	ImageSize         int32 `json:"image_size"`
	IntermediateSize  int32 `json:"intermediate_size"`
	NumAttentionHeads int32 `json:"num_attention_heads"`
	NumHiddenLayers   int32 `json:"num_hidden_layers"`
	PatchSize         int32 `json:"patch_size"`
}

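// Note: the patch embedding in Forward uses stride = PatchSize with no padding, so a
// square ImageSize x ImageSize input produces an (ImageSize/PatchSize) x (ImageSize/PatchSize)
// grid of patches, i.e. (ImageSize/PatchSize)^2 tokens into the encoder.
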
// VisionTower is the SigLIP vision encoder
type VisionTower struct {
	Embeddings    *VisionEmbeddings     `weight:"vision_model.embeddings"`
	Encoder       []*VisionEncoderLayer `weight:"vision_model.encoder.layers"`
	PostLayerNorm *nn.LayerNorm         `weight:"vision_model.post_layernorm"`
	Config        *VisionConfig
}

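// Config carries no weight tag, so it is expected to be filled in from the model's
// vision configuration by the caller rather than by the checkpoint weight loader.
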
// VisionEmbeddings handles patch and position embeddings
type VisionEmbeddings struct {
	// PatchWeight: [O, C, kH, kW] from PyTorch, transposed to [O, kH, kW, C] for MLX
	PatchWeight *mlx.Array    `weight:"patch_embedding.weight"`
	PatchBias   *mlx.Array    `weight:"patch_embedding.bias"`
	PosEmbed    *nn.Embedding `weight:"position_embedding"`
}

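// SigLIP uses learned absolute position embeddings and has no class token, so the
// position embedding table holds exactly one entry per image patch (see Forward).
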
// VisionEncoderLayer is a single transformer encoder layer
type VisionEncoderLayer struct {
	LayerNorm1 *nn.LayerNorm    `weight:"layer_norm1"`
	Attention  *VisionAttention `weight:"self_attn"`
	LayerNorm2 *nn.LayerNorm    `weight:"layer_norm2"`
	MLP        *VisionMLP       `weight:"mlp"`
}

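// The layer is pre-norm: LayerNorm1 feeds self-attention and LayerNorm2 feeds the MLP,
// with a residual connection around each sub-block (see Forward below).
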
// VisionAttention implements multi-head self-attention
type VisionAttention struct {
	QProj   *nn.Linear `weight:"q_proj"`
	KProj   *nn.Linear `weight:"k_proj"`
	VProj   *nn.Linear `weight:"v_proj"`
	OutProj *nn.Linear `weight:"out_proj"`
}

// VisionMLP is the feed-forward network
type VisionMLP struct {
	FC1 *nn.Linear `weight:"fc1"`
	FC2 *nn.Linear `weight:"fc2"`
}

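// As in the standard SigLIP MLP, FC1 expands hidden_size to the config's
// intermediate_size and FC2 projects back to hidden_size, with a GELU in between
// (see Forward below).
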
// Forward runs the vision tower on preprocessed images
// Input: [B, H, W, C] normalized image tensor (NHWC layout for MLX)
// Output: [B, num_patches, hidden_size]
func (v *VisionTower) Forward(x *mlx.Array) *mlx.Array {
	// Patch embedding conv: input [B, H, W, C], weight [O, kH, kW, C] -> [B, grid, grid, O]
	// Weight comes as [O, C, kH, kW] from PyTorch, transpose to [O, kH, kW, C]
	weight := mlx.Transpose(v.Embeddings.PatchWeight, 0, 2, 3, 1)
	h := mlx.Conv2d(x, weight, v.Config.PatchSize, 0) // stride=patch_size, no padding

	// Add bias: [O] -> [1, 1, 1, O] for broadcasting
	bias := mlx.Reshape(v.Embeddings.PatchBias, 1, 1, 1, v.Embeddings.PatchBias.Shape()[0])
	h = mlx.Add(h, bias)

	// h is [B, grid, grid, hidden], flatten to [B, num_patches, hidden]
	B := h.Shape()[0]
	gridH, gridW := h.Shape()[1], h.Shape()[2]
	hidden := h.Shape()[3]
	numPatches := gridH * gridW
	h = mlx.Reshape(h, B, numPatches, hidden)

	// Add position embeddings
	posIds := mlx.ArangeInt(0, numPatches, 1, mlx.DtypeInt32)
	posEmbed := v.Embeddings.PosEmbed.Forward(posIds)
	h = mlx.Add(h, posEmbed)

	// Encoder layers
	headDim := float32(v.Config.HiddenSize / v.Config.NumAttentionHeads)
	scale := float32(1.0 / math.Sqrt(float64(headDim)))
	for _, layer := range v.Encoder {
		h = layer.Forward(h, v.Config, scale)
	}

	// Final layer norm
	h = v.PostLayerNorm.Forward(h)

	return h
}

// Forward runs a vision encoder layer
func (l *VisionEncoderLayer) Forward(x *mlx.Array, cfg *VisionConfig, scale float32) *mlx.Array {
	// Pre-norm attention
	h := l.LayerNorm1.Forward(x)
	h = l.Attention.Forward(h, cfg, scale)
	x = mlx.Add(x, h)

	// Pre-norm MLP
	h = l.LayerNorm2.Forward(x)
	h = l.MLP.Forward(h)
	return mlx.Add(x, h)
}

// Forward runs multi-head self-attention
func (a *VisionAttention) Forward(x *mlx.Array, cfg *VisionConfig, scale float32) *mlx.Array {
	B, L := x.Shape()[0], x.Shape()[1]
	headDim := cfg.HiddenSize / cfg.NumAttentionHeads

	q := a.QProj.Forward(x)
	k := a.KProj.Forward(x)
	v := a.VProj.Forward(x)

	// Reshape to [B, num_heads, L, head_dim]
	q = mlx.Transpose(mlx.Reshape(q, B, L, cfg.NumAttentionHeads, headDim), 0, 2, 1, 3)
	k = mlx.Transpose(mlx.Reshape(k, B, L, cfg.NumAttentionHeads, headDim), 0, 2, 1, 3)
	v = mlx.Transpose(mlx.Reshape(v, B, L, cfg.NumAttentionHeads, headDim), 0, 2, 1, 3)

	// Scaled dot-product attention (no causal mask for vision)
	out := mlx.ScaledDotProductAttention(q, k, v, scale, false)

	// Reshape back: [B, num_heads, L, head_dim] -> [B, L, hidden]
	out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.HiddenSize)

	return a.OutProj.Forward(out)
}

// Forward runs the MLP with GELU activation
func (m *VisionMLP) Forward(x *mlx.Array) *mlx.Array {
	h := mlx.GELU(m.FC1.Forward(x))
	return m.FC2.Forward(h)
}
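
// Usage sketch (illustrative only): assuming the weight-tagged fields above have been
// populated by the imagegen weight loader, and `pixels` is a normalized [1, H, W, 3]
// NHWC *mlx.Array, the tower produces one embedding per patch:
//
//	patches := tower.Forward(pixels) // [1, (H/PatchSize)*(W/PatchSize), HiddenSize]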