//go:build mlx

package gpt_oss

import (
	"encoding/json"
	"fmt"
	"math"
	"os"
	"path/filepath"

	"github.com/ollama/ollama/x/imagegen/cache"
	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/nn"
	"github.com/ollama/ollama/x/imagegen/safetensors"
	"github.com/ollama/ollama/x/imagegen/tokenizer"
)

// RopeScaling holds YaRN or other RoPE scaling configuration.
type RopeScaling struct {
	RopeType                      string  `json:"rope_type"`
	Factor                        float32 `json:"factor"`
	OriginalMaxPositionEmbeddings int32   `json:"original_max_position_embeddings"`
	BetaFast                      float32 `json:"beta_fast"`
	BetaSlow                      float32 `json:"beta_slow"`
}

type Config struct {
	HiddenSize        int32        `json:"hidden_size"`
	NumHiddenLayers   int32        `json:"num_hidden_layers"`
	IntermediateSize  int32        `json:"intermediate_size"`
	NumAttentionHeads int32        `json:"num_attention_heads"`
	NumKeyValueHeads  int32        `json:"num_key_value_heads"`
	VocabSize         int32        `json:"vocab_size"`
	RMSNormEps        float32      `json:"rms_norm_eps"`
	RopeTheta         float32      `json:"rope_theta"`
	HeadDim           int32        `json:"head_dim"`
	SlidingWindow     int32        `json:"sliding_window"`
	NumLocalExperts   int32        `json:"num_local_experts"`
	NumExpertsPerTok  int32        `json:"num_experts_per_tok"`
	LayerTypes        []string     `json:"layer_types"`
	SwiGLULimit       float32      `json:"swiglu_limit"`
	RopeScaling       *RopeScaling `json:"rope_scaling"`
	Scale             float32      `json:"-"` // computed: 1/sqrt(HeadDim)
}

type Attention struct {
	QProj *nn.Linear `weight:"self_attn.q_proj"`
	KProj *nn.Linear `weight:"self_attn.k_proj"`
	VProj *nn.Linear `weight:"self_attn.v_proj"`
	OProj *nn.Linear `weight:"self_attn.o_proj"`
	Sinks *mlx.Array `weight:"self_attn.sinks,optional"`

	YarnFreqs  *mlx.Array // computed
	YarnMscale float32
}

// swiGLU applies the GPT-OSS custom SwiGLU activation.
// Formula: (gate * sigmoid(alpha * gate)) * (up + 1)
// with clipping: gate to (-inf, limit], up to [-limit, limit].
func swiGLU(gate, up *mlx.Array, alpha, limit float32) *mlx.Array {
	// Clip gate to (-inf, limit]
	gateClipped := mlx.ClipScalar(gate, 0, limit, false, true)

	// Clip up to [-limit, limit]
	upClipped := mlx.ClipScalar(up, -limit, limit, true, true)

	// glu_scaled = alpha * gate_clipped
	gluScaled := mlx.MulScalar(gateClipped, alpha)

	// sig = sigmoid(glu_scaled)
	sig := mlx.Sigmoid(gluScaled)

	// out_glu = gate_clipped * sig
	outGlu := mlx.Mul(gateClipped, sig)

	// result = out_glu * (up_clipped + 1)
	return mlx.Mul(outGlu, mlx.AddScalar(upClipped, 1.0))
}

// compiledSwiGLU is a singleton compiled SwiGLU function shared across all layers.
var compiledSwiGLU *mlx.CompiledFunc

// getCompiledSwiGLU returns the compiled SwiGLU function, creating it once if needed.
func getCompiledSwiGLU() *mlx.CompiledFunc {
	if compiledSwiGLU == nil {
		const alpha float32 = 1.702
		const limit float32 = 7.0
		compiledSwiGLU = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array {
			return []*mlx.Array{swiGLU(inputs[0], inputs[1], alpha, limit)}
		}, true) // shapeless=true so it works for any input size
	}
	return compiledSwiGLU
}
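
// swiGLUScalar is an illustrative scalar sketch of the same activation,
// included only to make the formula concrete; the model always runs the
// compiled MLX version above. For gate=2, up=3 (alpha=1.702, limit=7):
// sigmoid(1.702*2) ≈ 0.968, so the result is 2*0.968*(3+1) ≈ 7.74.
func swiGLUScalar(gate, up, alpha, limit float32) float32 {
	if gate > limit { // gate is clipped from above only
		gate = limit
	}
	if up > limit { // up is clipped on both sides
		up = limit
	} else if up < -limit {
		up = -limit
	}
	sig := float32(1 / (1 + math.Exp(-float64(alpha*gate))))
	return gate * sig * (up + 1)
}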

// ComputeYarnFreqs computes YaRN-modified RoPE frequencies.
// Based on mlx-lm's YarnRoPE implementation.
func ComputeYarnFreqs(dims int32, base, scalingFactor float32, origMaxPos int32, betaFast, betaSlow float32) (*mlx.Array, float32) {
	// yarn_find_correction_dim
	yarnFindCorrectionDim := func(numRotations float64) float64 {
		return float64(dims) * math.Log(float64(origMaxPos)/(numRotations*2*math.Pi)) / (2 * math.Log(float64(base)))
	}

	// yarn_find_correction_range
	low := int(math.Floor(yarnFindCorrectionDim(float64(betaFast))))
	high := int(math.Ceil(yarnFindCorrectionDim(float64(betaSlow))))
	if low < 0 {
		low = 0
	}
	if high > int(dims)-1 {
		high = int(dims) - 1
	}

	// yarn_get_mscale
	yarnGetMscale := func(scale, mscale float64) float64 {
		if scale <= 1 {
			return 1.0
		}
		return 0.1*mscale*math.Log(scale) + 1.0
	}
	mscale := float32(yarnGetMscale(float64(scalingFactor), 1.0) / yarnGetMscale(float64(scalingFactor), 0.0))

	// Compute frequencies:
	// freq_extra = base ** (arange(0, dims, 2) / dims)
	// freq_inter = scaling_factor * freq_extra
	halfDims := dims / 2
	freqData := make([]float32, halfDims)
	for i := int32(0); i < halfDims; i++ {
		exp := float64(2*i) / float64(dims)
		freqExtra := math.Pow(float64(base), exp)
		freqInter := float64(scalingFactor) * freqExtra

		// linear ramp mask
		var freqMask float64
		if low == high {
			freqMask = 0.0
		} else {
			t := (float64(i) - float64(low)) / float64(high-low)
			if t < 0 {
				t = 0
			}
			if t > 1 {
				t = 1
			}
			freqMask = 1.0 - t
		}

		// Combined frequency: (inter * extra) / (inter * mask + extra * (1 - mask))
		freqData[i] = float32((freqInter * freqExtra) / (freqInter*freqMask + freqExtra*(1-freqMask)))
	}

	return mlx.NewArray(freqData, []int32{halfDims}), mscale
}
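
// Example (hypothetical values, for illustration only):
//
//	freqs, mscale := ComputeYarnFreqs(64, 150000, 32.0, 4096, 32.0, 1.0)
//
// returns head_dim/2 = 32 interpolated frequencies plus an mscale of about
// 0.1*ln(32)+1 ≈ 1.35, which Attention.Forward multiplies into the queries
// before RoPE.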

// initYarn initializes YaRN RoPE if configured.
func (a *Attention) initYarn(cfg *Config) {
	a.YarnMscale = 1.0
	if cfg.RopeScaling != nil && cfg.RopeScaling.RopeType == "yarn" {
		a.YarnFreqs, a.YarnMscale = ComputeYarnFreqs(
			cfg.HeadDim,
			cfg.RopeTheta,
			cfg.RopeScaling.Factor,
			cfg.RopeScaling.OriginalMaxPositionEmbeddings,
			cfg.RopeScaling.BetaFast,
			cfg.RopeScaling.BetaSlow,
		)
	}
}

func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, mask *mlx.Array, maskMode string, cfg *Config) *mlx.Array {
	q := a.QProj.Forward(x)
	k := a.KProj.Forward(x)
	v := a.VProj.Forward(x)

	// Reshape via AsStrided: [B, L, n_heads * head_dim] -> [B, n_heads, L, head_dim]
	q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
		[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
	k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
		[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
	v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
		[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)

	// Apply RoPE at the current cache offset, using YaRN-scaled frequencies
	// when configured.
	offset := 0
	if c != nil {
		offset = c.Offset()
	}
	if a.YarnFreqs != nil {
		if a.YarnMscale != 1.0 {
			q = mlx.MulScalar(q, a.YarnMscale)
		}
		q = mlx.RoPEWithFreqs(q, a.YarnFreqs, int(cfg.HeadDim), false, 1.0, offset)
		k = mlx.RoPEWithFreqs(k, a.YarnFreqs, int(cfg.HeadDim), false, 1.0, offset)
	} else {
		q = mlx.RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
		k = mlx.RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
	}

	if c != nil {
		k, v = c.Update(k, v, int(L))
	}

	out := mlx.ScaledDotProductAttentionWithSinks(q, k, v, cfg.Scale, maskMode, mask, a.Sinks)
	out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
	return a.OProj.Forward(out)
}

// CreateSlidingWindowMask creates a causal mask with a sliding window.
// Mirrors mlx-lm's create_causal_mask with window_size.
func CreateSlidingWindowMask(seqLen, queryStart, keyStart, keyLen, windowSize int) *mlx.Array {
	// Build mask aligned to actual cache length (may be rotated).
	// rinds covers existing keys: [keyStart, keyStart+keyLen)
	// linds covers new queries: [queryStart, queryStart+seqLen)
	rinds := mlx.Arange(float32(keyStart), float32(keyStart+keyLen), 1)     // [keyLen]
	linds := mlx.Arange(float32(queryStart), float32(queryStart+seqLen), 1) // [seqLen]

	linds = mlx.ExpandDims(linds, 1) // [seqLen, 1]
	rinds = mlx.ExpandDims(rinds, 0) // [1, keyLen]

	causalMask := mlx.GreaterEqual(linds, rinds) // [seqLen, keyLen]
	windowLimit := mlx.AddScalar(rinds, float32(windowSize))
	windowMask := mlx.LessArray(linds, windowLimit) // [seqLen, keyLen]

	return mlx.LogicalAnd(causalMask, windowMask)
}
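
// Worked example (illustrative): CreateSlidingWindowMask(3, 0, 0, 3, 2)
// allows key k for query q only when 0 <= q-k < 2, i.e. each query attends
// to itself and the previous position:
//
//	q0: [T F F]
//	q1: [T T F]
//	q2: [F T T]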

// MoE represents the Mixture of Experts SwiGLU layer with quantized experts.
type MoE struct {
	Router     *nn.Linear `weight:"mlp.router"`
	TopK       int32
	HiddenSize int32
	GroupSize  int
	Bits       int

	// Expert weights (loaded manually via sanitizeExpertWeights)
	GateBlocks, GateScales, GateBias *mlx.Array
	UpBlocks, UpScales, UpBias       *mlx.Array
	DownBlocks, DownScales, DownBias *mlx.Array
}

func (moe *MoE) Forward(x *mlx.Array, B, L int32) *mlx.Array {
	// Top-k routing: pick the TopK highest router logits per token and
	// softmax only over those to get the expert mixture weights.
	logits := moe.Router.Forward(x)
	neg := mlx.Neg(logits)
	part := mlx.Argpartition(neg, int(moe.TopK)-1, -1)
	topKIdx := mlx.Slice(part, []int32{0, 0, 0}, []int32{B, L, moe.TopK})
	topKVal := mlx.TakeAlongAxis(logits, topKIdx, -1)
	weights := mlx.Softmax(topKVal, -1)

	xFlat := mlx.Reshape(x, B*L, 1, 1, moe.HiddenSize)
	idxFlat := mlx.Reshape(topKIdx, B*L, moe.TopK)

	// For larger batches, sort the token->expert assignments by expert index
	// so GatherQMM sees contiguous runs per expert (sorted=true), then undo
	// the permutation afterwards via invOrder.
	doSort := B*L >= 64
	var invOrder *mlx.Array
	sorted := false
	n := B * L * moe.TopK

	if doSort {
		idxAll := mlx.Flatten(idxFlat)
		order := mlx.Argsort(idxAll, 0)
		invOrder = mlx.Argsort(order, 0)
		xFlat = mlx.ExpandDims(mlx.Take(mlx.Squeeze(xFlat, 1), mlx.FloorDivideScalar(order, moe.TopK), 0), 1)
		idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1)
		sorted = true
	}

	// Quantized (MXFP4) expert matmuls gathered by expert index.
	gate := mlx.GatherQMM(xFlat, moe.GateBlocks, moe.GateScales, nil, nil, idxFlat, true, moe.GroupSize, moe.Bits, "mxfp4", sorted)
	up := mlx.GatherQMM(xFlat, moe.UpBlocks, moe.UpScales, nil, nil, idxFlat, true, moe.GroupSize, moe.Bits, "mxfp4", sorted)

	if moe.GateBias != nil {
		gate = mlx.Add(gate, mlx.ExpandDims(mlx.Take(moe.GateBias, idxFlat, 0), 2))
	}
	if moe.UpBias != nil {
		up = mlx.Add(up, mlx.ExpandDims(mlx.Take(moe.UpBias, idxFlat, 0), 2))
	}

	hidden := getCompiledSwiGLU().Call(gate, up)[0]

	down := mlx.GatherQMM(hidden, moe.DownBlocks, moe.DownScales, nil, nil, idxFlat, true, moe.GroupSize, moe.Bits, "mxfp4", sorted)
	if moe.DownBias != nil {
		down = mlx.Add(down, mlx.ExpandDims(mlx.Take(moe.DownBias, idxFlat, 0), 2))
	}

	if doSort {
		down = mlx.Reshape(mlx.Take(mlx.Squeeze(mlx.Squeeze(down, 2), 1), invOrder, 0), B*L, moe.TopK, moe.HiddenSize)
	} else {
		down = mlx.Squeeze(down, 2)
	}

	// Weighted sum of the TopK expert outputs per token.
	ewFlat := mlx.Reshape(weights, B*L, moe.TopK, 1)
	return mlx.Reshape(mlx.Sum(mlx.Mul(down, ewFlat), 1, false), B, L, moe.HiddenSize)
}
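
// Shape walk-through (illustrative, with hypothetical sizes B=1, L=8, TopK=4,
// HiddenSize=2880): x is [1, 8, 2880], topKIdx is [1, 8, 4], xFlat is
// [8, 1, 1, 2880], and after the expert matmuls and the weighted sum over
// the TopK axis the output is back to [1, 8, 2880].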

type Block struct {
	Attention    *Attention
	MLP          *MoE
	InputNorm    *nn.RMSNorm `weight:"input_layernorm"`
	PostAttnNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
	LayerType    string // "sliding_attention" or "full_attention"
}

func (b *Block) Forward(x *mlx.Array, c cache.Cache, B, L int32, mask *mlx.Array, maskMode string, cfg *Config) *mlx.Array {
	h := mlx.Add(x, b.Attention.Forward(b.InputNorm.Forward(x, cfg.RMSNormEps), c, B, L, mask, maskMode, cfg))
	return mlx.Add(h, b.MLP.Forward(b.PostAttnNorm.Forward(h, cfg.RMSNormEps), B, L))
}

type Model struct {
	EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
	Layers      []*Block      `weight:"-"` // loaded manually due to MoE sanitization
	Norm        *nn.RMSNorm   `weight:"model.norm"`
	LMHead      *nn.Linear    `weight:"lm_head"`

	tok *tokenizer.Tokenizer
	*Config
}

func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
func (m *Model) NumLayers() int                  { return len(m.Layers) }
func (m *Model) VocabSize() int32                { return m.Config.VocabSize }

func (m *Model) NewCache(int32) []cache.Cache {
	caches := make([]cache.Cache, len(m.Layers))
	for i, layer := range m.Layers {
		if layer.LayerType == "sliding_attention" && m.SlidingWindow > 0 {
			caches[i] = cache.NewRotatingKVCache(int(m.SlidingWindow))
		} else {
			caches[i] = cache.NewKVCache()
		}
	}
	return caches
}

func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
	B, L := tokens.Shape()[0], tokens.Shape()[1]
	x := m.EmbedTokens.Forward(tokens)

	// Find a representative cache index for sliding-window attention
	swaIdx := -1
	for i, layer := range m.Layers {
		if layer.LayerType == "sliding_attention" {
			swaIdx = i
			break
		}
	}

	// Create masks once at model level
	var fullMask, swaMask *mlx.Array
	var fullMaskMode, swaMaskMode string

	if L > 1 {
		fullMaskMode = "causal"
		if swaIdx >= 0 && m.SlidingWindow > 0 && caches != nil {
			c := caches[swaIdx]
			offset := c.Offset()
			windowSize := int(m.SlidingWindow)
			cacheLen := min(int(L), windowSize)
			if offset > 0 {
				cacheLen = min(c.Len()+int(L), windowSize)
			}
			// Build an explicit sliding-window mask only when the query span
			// exceeds the window; otherwise fall back to causal mask mode.
			if int(L) > windowSize {
				swaMask = CreateSlidingWindowMask(int(L), offset, offset+int(L)-cacheLen, cacheLen, windowSize)
			} else {
				swaMaskMode = "causal"
			}
		} else {
			swaMaskMode = "causal"
		}
	}

	for i, layer := range m.Layers {
		var c cache.Cache
		if caches != nil {
			c = caches[i]
		}
		mask, maskMode := fullMask, fullMaskMode
		if layer.LayerType == "sliding_attention" {
			mask, maskMode = swaMask, swaMaskMode
		}
		x = layer.Forward(x, c, B, L, mask, maskMode, m.Config)
	}

	return m.LMHead.Forward(m.Norm.Forward(x, m.RMSNormEps))
}

// sanitizeExpertWeights splits merged gate_up weights into separate gate/up arrays.
// MXFP4 quantized weights require contiguous memory; strided views give wrong results.
func sanitizeExpertWeights(weights *safetensors.ModelWeights, prefix string) (moe *MoE) {
	gateUpBlocks, _ := weights.GetTensor(prefix + ".mlp.experts.gate_up_proj_blocks")
	gateUpScales, _ := weights.GetTensor(prefix + ".mlp.experts.gate_up_proj_scales")
	gateUpBias, _ := weights.GetTensor(prefix + ".mlp.experts.gate_up_proj_bias")
	downBlocks, _ := weights.GetTensor(prefix + ".mlp.experts.down_proj_blocks")
	downScales, _ := weights.GetTensor(prefix + ".mlp.experts.down_proj_scales")
	downBias, _ := weights.GetTensor(prefix + ".mlp.experts.down_proj_bias")

	moe = &MoE{GroupSize: 32, Bits: 4, DownScales: downScales, DownBias: downBias}

	if gateUpBlocks != nil {
		gub := mlx.FlattenRange(mlx.View(gateUpBlocks, int(mlx.DtypeUint32)), -2, -1)
		s := gub.Shape()
		moe.GateBlocks = mlx.Contiguous(mlx.SliceStride(gub, []int32{0, 0, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
		moe.UpBlocks = mlx.Contiguous(mlx.SliceStride(gub, []int32{0, 1, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
	}
	if gateUpScales != nil {
		s := gateUpScales.Shape()
		moe.GateScales = mlx.Contiguous(mlx.SliceStride(gateUpScales, []int32{0, 0, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
		moe.UpScales = mlx.Contiguous(mlx.SliceStride(gateUpScales, []int32{0, 1, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
	}
	if gateUpBias != nil {
		s := gateUpBias.Shape()
		moe.GateBias = mlx.Contiguous(mlx.SliceStride(gateUpBias, []int32{0, 0}, []int32{s[0], s[1]}, []int32{1, 2}))
		moe.UpBias = mlx.Contiguous(mlx.SliceStride(gateUpBias, []int32{0, 1}, []int32{s[0], s[1]}, []int32{1, 2}))
	}
	if downBlocks != nil {
		moe.DownBlocks = mlx.FlattenRange(mlx.View(downBlocks, int(mlx.DtypeUint32)), -2, -1)
	}
	return moe
}
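
// Note: the merged gate_up tensors interleave gate and up rows (gate at even
// offsets, up at odd offsets) along the sliced axis, which is why the slices
// above start at 0 and 1 respectively with a stride of 2 and are then made
// contiguous before the quantized matmuls.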

func Load(modelPath string) (*Model, error) {
	data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
	if err != nil {
		return nil, fmt.Errorf("load config: %w", err)
	}
	var cfg Config
	if err := json.Unmarshal(data, &cfg); err != nil {
		return nil, fmt.Errorf("parse config: %w", err)
	}
	cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))

	weights, err := safetensors.LoadModelWeights(modelPath)
	if err != nil {
		return nil, fmt.Errorf("load weights: %w", err)
	}

	tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
	if err != nil {
		return nil, fmt.Errorf("load tokenizer: %w", err)
	}

	m := &Model{
		Layers: make([]*Block, cfg.NumHiddenLayers),
		Config: &cfg,
		tok:    tok,
	}

	// Load simple weights via struct tags
	if err := safetensors.LoadModule(m, weights, ""); err != nil {
		return nil, err
	}

	// Load layers with custom MoE handling
	for i := int32(0); i < cfg.NumHiddenLayers; i++ {
		prefix := fmt.Sprintf("model.layers.%d", i)
		layer := &Block{}
		if err := safetensors.LoadModule(layer, weights, prefix); err != nil {
			return nil, fmt.Errorf("layer %d: %w", i, err)
		}

		// Initialize attention YaRN
		layer.Attention.initYarn(&cfg)

		// Load MoE with weight sanitization
		moe := sanitizeExpertWeights(weights, prefix)
		moe.Router = layer.MLP.Router // Router was loaded by LoadModule
		moe.TopK = cfg.NumExpertsPerTok
		moe.HiddenSize = cfg.HiddenSize
		layer.MLP = moe

		// Set layer type
		layer.LayerType = "full_attention"
		if int(i) < len(cfg.LayerTypes) {
			layer.LayerType = cfg.LayerTypes[i]
		}

		m.Layers[i] = layer
	}

	// Release safetensors BEFORE eval: the lazy arrays have already captured
	// their data, so freeing the mmap here reduces peak memory during
	// materialization.
	weights.ReleaseAll()
	mlx.Eval(mlx.Collect(m)...)

	return m, nil
}

func (m *Model) MaxContextLength() int32 {
	if m.RopeScaling != nil && m.RopeScaling.OriginalMaxPositionEmbeddings > 0 {
		return m.RopeScaling.OriginalMaxPositionEmbeddings
	}
	return 131072
}
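
// Usage sketch (illustrative only, not an API contract): a caller loads the
// model once, creates per-layer caches, and calls Forward with a
// [batch, seqLen] token array, reusing the same caches across decode steps.
//
//	m, err := Load("/path/to/gpt-oss") // hypothetical model directory
//	if err != nil {
//		// handle error
//	}
//	caches := m.NewCache(1)
//	logits := m.Forward(promptTokens, caches) // promptTokens: [1, L] token ids
//	// later decode steps pass a single new token with the same caches.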