mirror of
https://github.com/ollama/ollama.git
synced 2026-01-12 00:06:57 +08:00
* WIP - MLX backend with gemma3 * MLX: add cmake and go tag build toggles To build the new MLX backend code: cmake --preset MLX cmake --build --preset MLX --parallel cmake --install build --component MLX go build -tags mlx . Note: the main.go entrypoint for the MLX engine will change in a follow up commit. * add experimental image generation runtime * add experimental image generation runtime * MLX: wire up cuda build for linux * MLX: get dependencies correct and dedup This is still too large for a unified github artifact, but is now "correct" for the mlx_cuda_v13 directory. * fix relative link bug in dedup * Add darwin build and readme * add go build tag for mlx dependent code and wire up build_darwin.sh * lint cleanup * macos: build mlx for x86 This will be CPU only. * cuda build instructions and fix drift from mlx bump * stale comment * Delete agent helper doc * Clean up readme.md * Revise README for tokenizer clarity and details Updated README to clarify tokenizer functionality and removed correctness section. --------- Co-authored-by: jmorganca <jmorganca@gmail.com>
51 lines
1.7 KiB
Go
51 lines
1.7 KiB
Go
//go:build mlx
|
|
|
|
package gemma3
|
|
|
|
import (
|
|
"github.com/ollama/ollama/x/imagegen/mlx"
|
|
"github.com/ollama/ollama/x/imagegen/nn"
|
|
)
|
|
|
|
// MultiModalProjector projects vision features to text embedding space
|
|
type MultiModalProjector struct {
|
|
// mm_input_projection_weight: [vision_hidden, text_hidden]
|
|
InputProjection *mlx.Array `weight:"mm_input_projection_weight"`
|
|
SoftEmbNorm *nn.RMSNorm `weight:"mm_soft_emb_norm"`
|
|
|
|
// Precomputed (1 + weight) for Gemma-style RMSNorm
|
|
SoftEmbNormScaled *mlx.Array `weight:"-"`
|
|
}
|
|
|
|
// Forward projects vision features to text space
|
|
// Input: [B, num_patches, vision_hidden] (e.g., [1, 4096, 1152])
|
|
// Output: [B, num_image_tokens, text_hidden] (e.g., [1, 256, 2560])
|
|
func (p *MultiModalProjector) Forward(visionFeatures *mlx.Array, eps float32) *mlx.Array {
|
|
// Average pool 4x4: [B, 4096, 1152] -> [B, 256, 1152]
|
|
// 4096 patches = 64x64 grid, pool to 16x16 = 256 tokens
|
|
B := visionFeatures.Shape()[0]
|
|
visionHidden := visionFeatures.Shape()[2]
|
|
|
|
// Reshape to [B, 64, 64, hidden]
|
|
gridSize := int32(64) // sqrt(4096)
|
|
pooledSize := int32(16) // 64/4
|
|
h := mlx.Reshape(visionFeatures, B, gridSize, gridSize, visionHidden)
|
|
|
|
// Reshape to [B, 16, 4, 16, 4, hidden] for 4x4 pooling
|
|
h = mlx.Reshape(h, B, pooledSize, 4, pooledSize, 4, visionHidden)
|
|
|
|
// Average over pooling dimensions (axes 2 and 4)
|
|
h = mlx.Mean(h, 4, false)
|
|
h = mlx.Mean(h, 2, false)
|
|
|
|
// h is now [B, 16, 16, hidden], reshape to [B, 256, hidden]
|
|
numTokens := pooledSize * pooledSize
|
|
h = mlx.Reshape(h, B, numTokens, visionHidden)
|
|
|
|
// Apply Gemma-style RMS norm (use precomputed 1 + weight)
|
|
h = mlx.RMSNorm(h, p.SoftEmbNormScaled, eps)
|
|
|
|
// Project to text space: [B, 256, vision_hidden] @ [vision_hidden, text_hidden]
|
|
return mlx.Linear(h, p.InputProjection)
|
|
}
|