Files
ollama/x/imagegen/models/gpt_oss/gpt_oss.go
Daniel Hiltgen 33ee7168ba Add experimental MLX backend and engine with imagegen support (#13648)
* WIP - MLX backend with gemma3

* MLX: add cmake and go tag build toggles

To build the new MLX backend code:
  cmake --preset MLX
  cmake --build --preset MLX --parallel
  cmake --install build --component MLX
  go build -tags mlx .

Note: the main.go entrypoint for the MLX engine will change in a follow-up commit.

* add experimental image generation runtime

* add experimental image generation runtime

* MLX: wire up cuda build for linux

* MLX: get dependencies correct and dedup

This is still too large for a unified GitHub artifact, but is now "correct" for the mlx_cuda_v13
directory.

* fix relative link bug in dedup

* Add darwin build and readme

* add go build tag for mlx dependent code and wire up build_darwin.sh

* lint cleanup

* macos: build mlx for x86

This will be CPU only.

* cuda build instructions and fix drift from mlx bump

* stale comment

* Delete agent helper doc

* Clean up readme.md

* Revise README for tokenizer clarity and details

Updated README to clarify tokenizer functionality and removed correctness section.

---------

Co-authored-by: jmorganca <jmorganca@gmail.com>
2026-01-08 16:18:59 -08:00

488 lines
17 KiB
Go

//go:build mlx
package gpt_oss
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// RopeScaling mirrors the "rope_scaling" object of config.json. It carries the
// parameters of a RoPE scaling scheme; within this package only the "yarn"
// rope type triggers any extra computation.
//
// Factor, OriginalMaxPositionEmbeddings, BetaFast and BetaSlow are forwarded
// unchanged to ComputeYarnFreqs when RopeType is "yarn".
type RopeScaling struct {
	RopeType                      string  `json:"rope_type"`
	Factor                        float32 `json:"factor"`
	OriginalMaxPositionEmbeddings int32   `json:"original_max_position_embeddings"`
	BetaFast                      float32 `json:"beta_fast"`
	BetaSlow                      float32 `json:"beta_slow"`
}
// Config holds the hyperparameters decoded from the model's config.json plus
// one derived value.
//
// LayerTypes lists the attention kind of each layer by index; the strings this
// package compares against are "sliding_attention" and "full_attention".
// RopeScaling is nil when config.json has no "rope_scaling" object.
// Scale is never read from JSON; Load fills it in after parsing.
type Config struct {
	HiddenSize        int32        `json:"hidden_size"`
	NumHiddenLayers   int32        `json:"num_hidden_layers"`
	IntermediateSize  int32        `json:"intermediate_size"`
	NumAttentionHeads int32        `json:"num_attention_heads"`
	NumKeyValueHeads  int32        `json:"num_key_value_heads"`
	VocabSize         int32        `json:"vocab_size"`
	RMSNormEps        float32      `json:"rms_norm_eps"`
	RopeTheta         float32      `json:"rope_theta"`
	HeadDim           int32        `json:"head_dim"`
	SlidingWindow     int32        `json:"sliding_window"`
	NumLocalExperts   int32        `json:"num_local_experts"`
	NumExpertsPerTok  int32        `json:"num_experts_per_tok"`
	LayerTypes        []string     `json:"layer_types"`
	SwiGLULimit       float32      `json:"swiglu_limit"`
	RopeScaling       *RopeScaling `json:"rope_scaling"`
	Scale             float32      `json:"-"` // derived: 1/sqrt(HeadDim)
}
// Attention bundles the weights and precomputed RoPE state of one
// self-attention layer.
//
// The four projections and the optional Sinks tensor are populated from the
// checkpoint through their `weight` tags. YarnFreqs and YarnMscale are not
// loaded: initYarn sets them (YarnFreqs stays nil unless YaRN is configured).
type Attention struct {
	QProj *nn.Linear `weight:"self_attn.q_proj"`
	KProj *nn.Linear `weight:"self_attn.k_proj"`
	VProj *nn.Linear `weight:"self_attn.v_proj"`
	OProj *nn.Linear `weight:"self_attn.o_proj"`
	Sinks *mlx.Array `weight:"self_attn.sinks,optional"`

	YarnFreqs  *mlx.Array // derived by initYarn; nil means plain RoPE
	YarnMscale float32    // derived by initYarn; 1.0 when YaRN is off
}
// swiGLU evaluates the GPT-OSS variant of the SwiGLU activation:
//
//	g = clip(gate, max=limit)
//	u = clip(up, -limit, limit)
//	out = g * sigmoid(alpha*g) * (u + 1)
//
// The gate is bounded from above only. The lower-bound argument passed for it
// is a placeholder: the trailing booleans of mlx.ClipScalar presumably switch
// each bound on or off — confirm against the mlx package.
func swiGLU(gate, up *mlx.Array, alpha, limit float32) *mlx.Array {
	g := mlx.ClipScalar(gate, 0, limit, false, true)
	u := mlx.ClipScalar(up, -limit, limit, true, true)

	// Swish-style gating term: g * sigmoid(alpha * g).
	gated := mlx.Mul(g, mlx.Sigmoid(mlx.MulScalar(g, alpha)))

	// The linear branch is shifted by one before it multiplies the gate.
	return mlx.Mul(gated, mlx.AddScalar(u, 1.0))
}
// compiledSwiGLU caches the compiled SwiGLU function shared by every MoE layer.
// It starts out nil and is filled in on the first call to getCompiledSwiGLU.
// NOTE(review): the lazy initialisation is a plain nil check with no
// synchronisation — confirm it is only reached from a single goroutine.
var compiledSwiGLU *mlx.CompiledFunc
// getCompiledSwiGLU hands back the package-wide compiled SwiGLU function,
// building it on first use. The compiled function expects two inputs, gate and
// up (in that order), and yields a single output.
func getCompiledSwiGLU() *mlx.CompiledFunc {
	if compiledSwiGLU != nil {
		return compiledSwiGLU
	}

	// Activation constants baked into the compiled function.
	const (
		alpha float32 = 1.702
		limit float32 = 7.0
	)

	kernel := func(inputs []*mlx.Array) []*mlx.Array {
		return []*mlx.Array{swiGLU(inputs[0], inputs[1], alpha, limit)}
	}

	// shapeless=true so the one compiled function serves any input size.
	compiledSwiGLU = mlx.CompileShapeless(kernel, true)
	return compiledSwiGLU
}
// ComputeYarnFreqs computes YaRN-modified RoPE frequencies, following mlx-lm's
// YarnRoPE implementation.
//
// dims is the rotary dimension (the result has dims/2 entries), base is the
// RoPE theta, scalingFactor is the YaRN context-extension factor, origMaxPos is
// the pre-scaling maximum position, and betaFast/betaSlow bound the band of
// dimensions over which frequencies are blended.
//
// It returns the per-pair frequencies as a 1-D array of length dims/2 and the
// YaRN attention scale (mscale).
func ComputeYarnFreqs(dims int32, base, scalingFactor float32, origMaxPos int32, betaFast, betaSlow float32) (*mlx.Array, float32) {
	freqData, mscale := yarnFreqData(dims, base, scalingFactor, origMaxPos, betaFast, betaSlow)
	return mlx.NewArray(freqData, []int32{dims / 2}), mscale
}

// yarnFreqData is the pure-Go numeric core of ComputeYarnFreqs: it returns the
// dims/2 blended frequencies and the mscale value without touching MLX.
func yarnFreqData(dims int32, base, scalingFactor float32, origMaxPos int32, betaFast, betaSlow float32) ([]float32, float32) {
	// yarn_find_correction_dim: dimension index at which a given number of
	// rotations is completed within the original context length.
	correctionDim := func(numRotations float64) float64 {
		return float64(dims) * math.Log(float64(origMaxPos)/(numRotations*2*math.Pi)) / (2 * math.Log(float64(base)))
	}

	// yarn_find_correction_range, clamped to valid dimension indices.
	low := max(int(math.Floor(correctionDim(float64(betaFast)))), 0)
	high := min(int(math.Ceil(correctionDim(float64(betaSlow)))), int(dims)-1)

	// yarn_get_mscale
	getMscale := func(scale, mscale float64) float64 {
		if scale <= 1 {
			return 1.0
		}
		return 0.1*mscale*math.Log(scale) + 1.0
	}
	mscale := float32(getMscale(float64(scalingFactor), 1.0) / getMscale(float64(scalingFactor), 0.0))

	// Linear ramp bounds. When both ends coincide the reference widens the
	// upper bound slightly instead of dividing by zero; the ramp then becomes
	// a step, so indices <= low keep their un-interpolated frequency.
	rampLow, rampHigh := float64(low), float64(high)
	if low == high {
		rampHigh += 0.001
	}

	halfDims := dims / 2
	freqData := make([]float32, halfDims)
	for i := int32(0); i < halfDims; i++ {
		// freq_extra = base ** (2i / dims); freq_inter = scaling_factor * freq_extra
		freqExtra := math.Pow(float64(base), float64(2*i)/float64(dims))
		freqInter := float64(scalingFactor) * freqExtra

		// freq_mask = 1 - clip((i - low) / (high - low), 0, 1)
		ramp := (float64(i) - rampLow) / (rampHigh - rampLow)
		if ramp < 0 {
			ramp = 0
		}
		if ramp > 1 {
			ramp = 1
		}
		freqMask := 1.0 - ramp

		// Combined frequency: (inter * extra) / (inter * mask + extra * (1 - mask))
		freqData[i] = float32((freqInter * freqExtra) / (freqInter*freqMask + freqExtra*(1-freqMask)))
	}
	return freqData, mscale
}
// initYarn prepares the layer's RoPE state from cfg. YarnMscale is always
// reset to 1.0; YarnFreqs and YarnMscale are then overwritten with the YaRN
// values only when cfg.RopeScaling is present and its rope type is "yarn".
func (a *Attention) initYarn(cfg *Config) {
	a.YarnMscale = 1.0

	rs := cfg.RopeScaling
	if rs == nil || rs.RopeType != "yarn" {
		return
	}

	a.YarnFreqs, a.YarnMscale = ComputeYarnFreqs(
		cfg.HeadDim,
		cfg.RopeTheta,
		rs.Factor,
		rs.OriginalMaxPositionEmbeddings,
		rs.BetaFast,
		rs.BetaSlow,
	)
}
// Forward runs self-attention on x and returns the output projection.
//
// x is projected to queries, keys and values, each re-viewed per head as
// [B, heads, L, HeadDim]. RoPE is applied starting at the cache offset (0 when
// c is nil), using the precomputed YaRN frequencies when present and plain
// RoPE with cfg.RopeTheta otherwise. When a cache is supplied, the new keys
// and values are passed through c.Update and its results are attended over.
// mask and maskMode are forwarded unchanged to the attention kernel together
// with the layer's optional sinks.
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, mask *mlx.Array, maskMode string, cfg *Config) *mlx.Array {
	q := a.QProj.Forward(x)
	k := a.KProj.Forward(x)
	v := a.VProj.Forward(x)

	// splitHeads views [B, L, heads*HeadDim] as [B, heads, L, HeadDim] via
	// strides. Strides are computed in int64 from the start so the products
	// cannot wrap around in int32 for long sequences or wide heads.
	headDim := int64(cfg.HeadDim)
	splitHeads := func(t *mlx.Array, heads int32) *mlx.Array {
		tokenStride := int64(heads) * headDim
		return mlx.AsStrided(t,
			[]int32{B, heads, L, cfg.HeadDim},
			[]int64{int64(L) * tokenStride, headDim, tokenStride, 1}, 0)
	}
	q = splitHeads(q, cfg.NumAttentionHeads)
	k = splitHeads(k, cfg.NumKeyValueHeads)
	v = splitHeads(v, cfg.NumKeyValueHeads)

	// Positions continue from whatever the cache already holds.
	offset := 0
	if c != nil {
		offset = c.Offset()
	}

	if a.YarnFreqs != nil {
		// NOTE(review): the YaRN mscale is applied to q only, not to k —
		// confirm this is the intended scaling.
		if a.YarnMscale != 1.0 {
			q = mlx.MulScalar(q, a.YarnMscale)
		}
		q = mlx.RoPEWithFreqs(q, a.YarnFreqs, int(cfg.HeadDim), false, 1.0, offset)
		k = mlx.RoPEWithFreqs(k, a.YarnFreqs, int(cfg.HeadDim), false, 1.0, offset)
	} else {
		q = mlx.RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
		k = mlx.RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
	}

	if c != nil {
		k, v = c.Update(k, v, int(L))
	}

	out := mlx.ScaledDotProductAttentionWithSinks(q, k, v, cfg.Scale, maskMode, mask, a.Sinks)

	// [B, heads, L, HeadDim] -> [B, L, heads*HeadDim] for the output projection.
	out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
	return a.OProj.Forward(out)
}
// CreateSlidingWindowMask builds a boolean [seqLen, keyLen] attention mask that
// is both causal and limited to a sliding window, mirroring mlx-lm's
// create_causal_mask with window_size.
//
// Queries occupy positions [queryStart, queryStart+seqLen) and keys occupy
// [keyStart, keyStart+keyLen). Entry (q, k) is true when
// keyPos <= queryPos < keyPos + windowSize.
func CreateSlidingWindowMask(seqLen, queryStart, keyStart, keyLen, windowSize int) *mlx.Array {
	// Key positions as a row vector: [1, keyLen].
	keyPos := mlx.ExpandDims(mlx.Arange(float32(keyStart), float32(keyStart+keyLen), 1), 0)
	// Query positions as a column vector: [seqLen, 1].
	queryPos := mlx.ExpandDims(mlx.Arange(float32(queryStart), float32(queryStart+seqLen), 1), 1)

	// A query may not look at keys that come after it.
	notFuture := mlx.GreaterEqual(queryPos, keyPos)
	// A query may only look back fewer than windowSize positions.
	inWindow := mlx.LessArray(queryPos, mlx.AddScalar(keyPos, float32(windowSize)))

	return mlx.LogicalAnd(notFuture, inWindow)
}
// MoE is the Mixture-of-Experts SwiGLU feed-forward layer with quantized
// expert weights.
//
// Router is loaded through its `weight` tag. TopK and HiddenSize are copied
// from the model config by Load. GroupSize and Bits are the quantization
// parameters handed to mlx.GatherQMM. The expert tensors are not loaded by
// tag; sanitizeExpertWeights fills them in, and any of them may be nil when
// the checkpoint lacks the corresponding tensor.
type MoE struct {
	Router     *nn.Linear `weight:"mlp.router"`
	TopK       int32
	HiddenSize int32
	GroupSize  int
	Bits       int

	// Gate projection of every expert.
	GateBlocks *mlx.Array
	GateScales *mlx.Array
	GateBias   *mlx.Array

	// Up projection of every expert.
	UpBlocks *mlx.Array
	UpScales *mlx.Array
	UpBias   *mlx.Array

	// Down projection of every expert.
	DownBlocks *mlx.Array
	DownScales *mlx.Array
	DownBias   *mlx.Array
}
// Forward routes every token of x to its TopK experts and returns the
// softmax-weighted sum of the expert outputs, reshaped to [B, L, HiddenSize].
//
// The router's TopK highest logits per token select the experts; a softmax
// over just those logits gives the mixing weights. When there are at least 64
// tokens, the (token, expert) pairs are sorted by expert index before the
// gathered quantized matmuls and restored to their original order afterwards.
func (moe *MoE) Forward(x *mlx.Array, B, L int32) *mlx.Array {
	routerLogits := moe.Router.Forward(x)

	// Top-k selection: partition the negated logits so the TopK largest
	// originals occupy the first TopK slots of the last axis.
	partitioned := mlx.Argpartition(mlx.Neg(routerLogits), int(moe.TopK)-1, -1)
	expertIdx := mlx.Slice(partitioned, []int32{0, 0, 0}, []int32{B, L, moe.TopK})
	mixWeights := mlx.Softmax(mlx.TakeAlongAxis(routerLogits, expertIdx, -1), -1)

	inputs := mlx.Reshape(x, B*L, 1, 1, moe.HiddenSize)
	indices := mlx.Reshape(expertIdx, B*L, moe.TopK)

	sortByExpert := B*L >= 64
	var restoreOrder *mlx.Array
	if sortByExpert {
		flatIdx := mlx.Flatten(indices)
		perm := mlx.Argsort(flatIdx, 0)
		restoreOrder = mlx.Argsort(perm, 0)
		// perm/TopK maps each sorted (token, expert) slot back to its token row.
		inputs = mlx.ExpandDims(mlx.Take(mlx.Squeeze(inputs, 1), mlx.FloorDivideScalar(perm, moe.TopK), 0), 1)
		indices = mlx.Reshape(mlx.Take(flatIdx, perm, 0), B*L*moe.TopK, 1)
	}

	// expertMatmul applies the selected experts' quantized weights to in.
	expertMatmul := func(in, blocks, scales *mlx.Array) *mlx.Array {
		return mlx.GatherQMM(in, blocks, scales, nil, nil, indices, true, moe.GroupSize, moe.Bits, "mxfp4", sortByExpert)
	}
	// withBias adds the selected experts' bias rows when a bias tensor exists.
	withBias := func(out, bias *mlx.Array) *mlx.Array {
		if bias == nil {
			return out
		}
		return mlx.Add(out, mlx.ExpandDims(mlx.Take(bias, indices, 0), 2))
	}

	gate := expertMatmul(inputs, moe.GateBlocks, moe.GateScales)
	up := expertMatmul(inputs, moe.UpBlocks, moe.UpScales)
	gate = withBias(gate, moe.GateBias)
	up = withBias(up, moe.UpBias)

	hidden := getCompiledSwiGLU().Call(gate, up)[0]

	down := withBias(expertMatmul(hidden, moe.DownBlocks, moe.DownScales), moe.DownBias)

	if sortByExpert {
		// Undo the expert sort, then regroup the rows per token.
		rows := mlx.Squeeze(mlx.Squeeze(down, 2), 1)
		down = mlx.Reshape(mlx.Take(rows, restoreOrder, 0), B*L, moe.TopK, moe.HiddenSize)
	} else {
		down = mlx.Squeeze(down, 2)
	}

	// Weight each expert's output and sum over the TopK axis.
	weighted := mlx.Mul(down, mlx.Reshape(mixWeights, B*L, moe.TopK, 1))
	return mlx.Reshape(mlx.Sum(weighted, 1, false), B, L, moe.HiddenSize)
}
// Block is one transformer layer: self-attention followed by the MoE
// feed-forward, each preceded by its own RMS norm.
//
// LayerType is assigned by Load and decides which mask and cache kind the
// layer gets; the values used are "sliding_attention" and "full_attention".
type Block struct {
	Attention *Attention
	MLP       *MoE

	InputNorm    *nn.RMSNorm `weight:"input_layernorm"`
	PostAttnNorm *nn.RMSNorm `weight:"post_attention_layernorm"`

	LayerType string
}
// Forward applies the layer to x using pre-norm residual connections:
//
//	h   = x + Attention(InputNorm(x))
//	out = h + MLP(PostAttnNorm(h))
//
// c, mask and maskMode are passed straight through to the attention sub-layer.
func (b *Block) Forward(x *mlx.Array, c cache.Cache, B, L int32, mask *mlx.Array, maskMode string, cfg *Config) *mlx.Array {
	attnOut := b.Attention.Forward(b.InputNorm.Forward(x, cfg.RMSNormEps), c, B, L, mask, maskMode, cfg)
	h := mlx.Add(x, attnOut)

	mlpOut := b.MLP.Forward(b.PostAttnNorm.Forward(h, cfg.RMSNormEps), B, L)
	return mlx.Add(h, mlpOut)
}
// Model is the complete GPT-OSS network: token embedding, the stack of
// transformer blocks, the final norm and the LM head, together with its
// tokenizer and the embedded configuration.
//
// Layers carries the tag `weight:"-"` because Load builds each block itself
// in order to run the MoE weight sanitization.
type Model struct {
	EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
	Layers      []*Block      `weight:"-"` // built by Load, not by struct tags
	Norm        *nn.RMSNorm   `weight:"model.norm"`
	LMHead      *nn.Linear    `weight:"lm_head"`

	tok *tokenizer.Tokenizer
	*Config
}
// Tokenizer returns the tokenizer that Load attached to the model.
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
	return m.tok
}
// NumLayers reports how many transformer blocks the model holds.
func (m *Model) NumLayers() int {
	return len(m.Layers)
}
// VocabSize returns the vocabulary size recorded in the model's config. The
// field is reached through m.Config explicitly because the method shares its
// name with the embedded Config field.
func (m *Model) VocabSize() int32 {
	return m.Config.VocabSize
}
// NewCache allocates one KV cache per layer. Sliding-attention layers get a
// rotating cache sized to the configured sliding window (when that window is
// positive); every other layer gets a plain KV cache. The int32 argument is
// accepted but not used.
func (m *Model) NewCache(int32) []cache.Cache {
	caches := make([]cache.Cache, len(m.Layers))
	for i, blk := range m.Layers {
		switch {
		case blk.LayerType == "sliding_attention" && m.SlidingWindow > 0:
			caches[i] = cache.NewRotatingKVCache(int(m.SlidingWindow))
		default:
			caches[i] = cache.NewKVCache()
		}
	}
	return caches
}
// Forward embeds tokens, runs them through every layer and returns the LM head
// output. tokens is indexed as [B, L]; caches, when non-nil, supplies one cache
// per layer at the matching index.
//
// Masks are decided once for the whole model. For a single-token step (L == 1)
// no mask or mask mode is set. Otherwise full-attention layers use the "causal"
// mode, and sliding-attention layers use either "causal" or, when L exceeds the
// sliding window and caches are available, an explicit sliding-window mask
// derived from the first sliding layer's cache.
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
	shape := tokens.Shape()
	B, L := shape[0], shape[1]

	x := m.EmbedTokens.Forward(tokens)

	// The first sliding-attention layer's cache stands in for all of them.
	swaIdx := -1
	for i := range m.Layers {
		if m.Layers[i].LayerType == "sliding_attention" {
			swaIdx = i
			break
		}
	}

	var fullMask, swaMask *mlx.Array
	var fullMaskMode, swaMaskMode string
	if L > 1 {
		fullMaskMode = "causal"
		// Default for sliding layers; cleared below if an explicit mask is built.
		swaMaskMode = "causal"

		if swaIdx >= 0 && m.SlidingWindow > 0 && caches != nil {
			swaCache := caches[swaIdx]
			offset := swaCache.Offset()
			window := int(m.SlidingWindow)

			keyLen := min(int(L), window)
			if offset > 0 {
				keyLen = min(swaCache.Len()+int(L), window)
			}

			if int(L) > window {
				swaMask = CreateSlidingWindowMask(int(L), offset, offset+int(L)-keyLen, keyLen, window)
				swaMaskMode = ""
			}
		}
	}

	for i, blk := range m.Layers {
		var layerCache cache.Cache
		if caches != nil {
			layerCache = caches[i]
		}

		if blk.LayerType == "sliding_attention" {
			x = blk.Forward(x, layerCache, B, L, swaMask, swaMaskMode, m.Config)
		} else {
			x = blk.Forward(x, layerCache, B, L, fullMask, fullMaskMode, m.Config)
		}
	}

	return m.LMHead.Forward(m.Norm.Forward(x, m.RMSNormEps))
}
// sanitizeExpertWeights builds a MoE from the expert tensors stored under
// prefix, splitting the merged gate_up tensors into separate gate and up
// arrays. The merged tensors interleave the two along their second axis: even
// rows belong to the gate projection and odd rows to the up projection.
//
// MXFP4 quantized weights require contiguous memory - strided views give wrong
// results - so every split is materialized with mlx.Contiguous.
//
// Lookup errors are discarded: a tensor that cannot be fetched simply leaves
// the corresponding MoE field nil. Router, TopK and HiddenSize are left for
// the caller to fill in.
func sanitizeExpertWeights(weights *safetensors.ModelWeights, prefix string) (moe *MoE) {
	// fetch returns the named expert tensor, or nil when the lookup yields none.
	fetch := func(suffix string) *mlx.Array {
		t, _ := weights.GetTensor(prefix + ".mlp.experts." + suffix)
		return t
	}
	// split3D separates the even and odd entries along axis 1 of a rank-3 tensor.
	split3D := func(t *mlx.Array) (even, odd *mlx.Array) {
		s := t.Shape()
		even = mlx.Contiguous(mlx.SliceStride(t, []int32{0, 0, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
		odd = mlx.Contiguous(mlx.SliceStride(t, []int32{0, 1, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
		return even, odd
	}
	// packBlocks reinterprets block data as uint32 and merges the last two axes.
	packBlocks := func(t *mlx.Array) *mlx.Array {
		return mlx.FlattenRange(mlx.View(t, int(mlx.DtypeUint32)), -2, -1)
	}

	gateUpBlocks := fetch("gate_up_proj_blocks")
	gateUpScales := fetch("gate_up_proj_scales")
	gateUpBias := fetch("gate_up_proj_bias")
	downBlocks := fetch("down_proj_blocks")
	downScales := fetch("down_proj_scales")
	downBias := fetch("down_proj_bias")

	moe = &MoE{GroupSize: 32, Bits: 4, DownScales: downScales, DownBias: downBias}

	if gateUpBlocks != nil {
		moe.GateBlocks, moe.UpBlocks = split3D(packBlocks(gateUpBlocks))
	}
	if gateUpScales != nil {
		moe.GateScales, moe.UpScales = split3D(gateUpScales)
	}
	if gateUpBias != nil {
		// The bias is rank 2, so the same even/odd split runs over axis 1 of 2.
		s := gateUpBias.Shape()
		moe.GateBias = mlx.Contiguous(mlx.SliceStride(gateUpBias, []int32{0, 0}, []int32{s[0], s[1]}, []int32{1, 2}))
		moe.UpBias = mlx.Contiguous(mlx.SliceStride(gateUpBias, []int32{0, 1}, []int32{s[0], s[1]}, []int32{1, 2}))
	}
	if downBlocks != nil {
		moe.DownBlocks = packBlocks(downBlocks)
	}
	return moe
}
// Load reads a GPT-OSS model from modelPath, which must contain config.json,
// tokenizer.json and the safetensors weights.
//
// It parses the config, derives the attention scale, loads the top-level
// weights through struct tags, then builds each layer: tagged weights are
// loaded, YaRN state is initialised, the merged expert tensors are split into
// a fresh MoE (keeping the router that was already loaded), and the layer type
// is taken from the config ("full_attention" when the config lists none for
// that index). Finally the safetensors handles are released and all model
// arrays are evaluated.
//
// An error is returned when a file cannot be read or parsed, when head_dim is
// not positive, or when weight loading fails.
func Load(modelPath string) (*Model, error) {
	data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
	if err != nil {
		return nil, fmt.Errorf("load config: %w", err)
	}
	var cfg Config
	if err := json.Unmarshal(data, &cfg); err != nil {
		return nil, fmt.Errorf("parse config: %w", err)
	}
	// A missing or zero head_dim would make the scale +Inf and every per-head
	// view empty, so reject it up front instead of failing obscurely later.
	if cfg.HeadDim <= 0 {
		return nil, fmt.Errorf("parse config: invalid head_dim %d", cfg.HeadDim)
	}
	cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))

	weights, err := safetensors.LoadModelWeights(modelPath)
	if err != nil {
		return nil, fmt.Errorf("load weights: %w", err)
	}

	tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
	if err != nil {
		return nil, fmt.Errorf("load tokenizer: %w", err)
	}

	m := &Model{
		Layers: make([]*Block, cfg.NumHiddenLayers),
		Config: &cfg,
		tok:    tok,
	}

	// Load simple weights via struct tags
	if err := safetensors.LoadModule(m, weights, ""); err != nil {
		return nil, fmt.Errorf("load module: %w", err)
	}

	// Load layers with custom MoE handling
	for i := int32(0); i < cfg.NumHiddenLayers; i++ {
		prefix := fmt.Sprintf("model.layers.%d", i)
		layer := &Block{}
		if err := safetensors.LoadModule(layer, weights, prefix); err != nil {
			return nil, fmt.Errorf("layer %d: %w", i, err)
		}

		// Initialize attention YaRN
		layer.Attention.initYarn(&cfg)

		// Load MoE with weight sanitization
		moe := sanitizeExpertWeights(weights, prefix)
		moe.Router = layer.MLP.Router // Router was loaded by LoadModule
		moe.TopK = cfg.NumExpertsPerTok
		moe.HiddenSize = cfg.HiddenSize
		layer.MLP = moe

		// Set layer type; layers beyond the configured list default to full attention.
		layer.LayerType = "full_attention"
		if int(i) < len(cfg.LayerTypes) {
			layer.LayerType = cfg.LayerTypes[i]
		}

		m.Layers[i] = layer
	}

	// Release safetensors BEFORE eval - lazy arrays have captured data,
	// this reduces peak memory by freeing mmap during materialization
	weights.ReleaseAll()
	mlx.Eval(mlx.Collect(m)...)

	return m, nil
}
// MaxContextLength reports the context length the model advertises: the
// rope-scaling section's original_max_position_embeddings when that section
// exists and the value is positive, otherwise 131072.
// NOTE(review): with rope scaling configured this is the pre-scaling length,
// not that length multiplied by the scaling factor — confirm this is intended.
func (m *Model) MaxContextLength() int32 {
	const defaultMaxContext int32 = 131072

	rs := m.RopeScaling
	if rs == nil || rs.OriginalMaxPositionEmbeddings <= 0 {
		return defaultMaxContext
	}
	return rs.OriginalMaxPositionEmbeddings
}