mirror of https://github.com/ollama/ollama.git
synced 2026-01-12 00:06:57 +08:00
* WIP - MLX backend with gemma3

* MLX: add cmake and go tag build toggles

  To build the new MLX backend code:

      cmake --preset MLX
      cmake --build --preset MLX --parallel
      cmake --install build --component MLX
      go build -tags mlx .

  Note: the main.go entrypoint for the MLX engine will change in a follow up commit.

* add experimental image generation runtime

* add experimental image generation runtime

* MLX: wire up cuda build for linux

* MLX: get dependencies correct and dedup

  This is still too large for a unified github artifact, but is now "correct" for the mlx_cuda_v13 directory.

* fix relative link bug in dedup

* Add darwin build and readme

* add go build tag for mlx dependent code and wire up build_darwin.sh

* lint cleanup

* macos: build mlx for x86

  This will be CPU only.

* cuda build instructions and fix drift from mlx bump

* stale comment

* Delete agent helper doc

* Clean up readme.md

* Revise README for tokenizer clarity and details

  Updated README to clarify tokenizer functionality and removed correctness section.

---------

Co-authored-by: jmorganca <jmorganca@gmail.com>
360 lines
9.2 KiB
Go
//go:build mlx

package main

import (
	"context"
	"fmt"
	"time"
	"unicode/utf8"

	"github.com/ollama/ollama/x/imagegen/cache"
	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/tokenizer"
)

// Dedicated stream for generation (like mlx-lm's generation_stream)
var generationStream *mlx.Stream

// utf8Streamer buffers decoded text and emits only complete UTF-8 characters.
// This handles cases where tokenizers output partial multi-byte sequences.
type utf8Streamer struct {
	buffer []byte
}

// Write adds decoded text to the buffer and returns complete UTF-8 characters.
func (s *utf8Streamer) Write(text string) string {
	s.buffer = append(s.buffer, text...)

	// Find the last position that ends with a complete UTF-8 character
	validLen := 0
	for i := 0; i < len(s.buffer); {
		r, size := utf8.DecodeRune(s.buffer[i:])
		if r == utf8.RuneError && size == 1 {
			// Invalid or incomplete UTF-8 sequence at this position.
			// With fewer than 4 bytes remaining it may be the prefix of
			// an incomplete multi-byte sequence, so keep it buffered.
			if len(s.buffer)-i < 4 {
				break
			}
			// At least 4 bytes are available and still invalid:
			// definitely not a rune prefix, skip this byte
			i++
			validLen = i
		} else {
			i += size
			validLen = i
		}
	}

	if validLen == 0 {
		return ""
	}

	result := string(s.buffer[:validLen])
	s.buffer = s.buffer[validLen:]
	return result
}

// Flush returns any remaining buffered bytes (may be incomplete UTF-8).
func (s *utf8Streamer) Flush() string {
	if len(s.buffer) == 0 {
		return ""
	}
	result := string(s.buffer)
	s.buffer = nil
	return result
}
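
// Illustrative sketch of the buffering behavior (not called by the runtime):
// 'é' is the two-byte UTF-8 sequence 0xC3 0xA9, so when a tokenizer splits
// it across chunks the first byte is held back until the rune completes.
//
//	var s utf8Streamer
//	_ = s.Write("caf\xc3") // returns "caf"; 0xC3 stays buffered
//	_ = s.Write("\xa9")    // returns "é" once the rune is complete
//	_ = s.Flush()          // returns "" (nothing left over)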

func init() {
	generationStream = mlx.NewStream()
}

// withStream runs fn with the generation stream as default
func withStream(fn func()) {
	orig := mlx.GetDefaultStream()
	mlx.SetDefaultStream(generationStream)
	fn()
	mlx.SetDefaultStream(orig)
}
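
// Usage sketch: build the graph inside the closure so its ops target the
// dedicated generation stream, then evaluate the results outside it, as
// prefill and step do below:
//
//	withStream(func() {
//		logits = d.model.Forward(x, d.caches)
//	})
//	mlx.Eval(logits)
//
// Note that withStream restores the original default stream when fn
// returns normally, but not if fn panics (no defer is used).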

type Model interface {
	Tokenizer() *tokenizer.Tokenizer
	VocabSize() int32
	NewCache(maxSeqLen int32) []cache.Cache
	Forward(input *mlx.Array, caches []cache.Cache) *mlx.Array
}

// ChatModel is an optional interface for models that support chat formatting
type ChatModel interface {
	FormatPrompt(prompt string) string
}

// MultimodalModel is for models that support image input
type MultimodalModel interface {
	Model
	FormatPromptWithImage(prompt string) string
	ExpandImageTokens(tokens []int32) []int32
	ForwardWithImage(tokens *mlx.Array, image *mlx.Array, caches []cache.Cache) *mlx.Array
	ImageSize() int32 // Returns the expected image size for preprocessing
}

// ImageLoader loads and preprocesses an image for multimodal models.
// Returns nil if path is empty.
type ImageLoader func(path string, imageSize int32) (*mlx.Array, error)

type input struct {
	Prompt       string
	Image        *mlx.Array // Optional preprocessed image for multimodal models
	MaxTokens    int
	Temperature  float32
	TopP         float32
	TopK         int
	WiredLimitGB int // Metal wired memory limit in GB (default 32)
}
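
// Illustrative request (hypothetical values): only Prompt is strictly
// needed; zero values fall back to the defaults applied in generate below
// (MaxTokens is raised to at least 100, WiredLimitGB defaults to 32).
//
//	in := input{
//		Prompt:      "Describe the sky at sunset",
//		MaxTokens:   256,
//		Temperature: 0.7,
//		TopP:        0.9,
//		TopK:        40,
//	}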

type output struct {
	Text          string
	Done          bool
	PrefillTokSec float64
	GenTokSec     float64
}

// Decoder wraps model + cache for autoregressive generation.
type Decoder struct {
	model         Model
	caches        []cache.Cache
	vocabSize     int32
	temp          float32
	topK          int
	topP          float32
	token         *mlx.Array   // Current token (kept across pools)
	oldCacheState []*mlx.Array // Preallocated slice for old cache state
	image         *mlx.Array   // Optional image for multimodal prefill
}

func NewDecoder(m Model, temp float32, topK int, topP float32) *Decoder {
	caches := m.NewCache(0)
	return &Decoder{
		model:         m,
		caches:        caches,
		vocabSize:     m.VocabSize(),
		temp:          temp,
		topK:          topK,
		topP:          topP,
		oldCacheState: make([]*mlx.Array, 0, len(caches)*2),
	}
}

// SetImage sets the image for multimodal prefill (call before prefill)
func (d *Decoder) SetImage(img *mlx.Array) {
	d.image = img
}
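
// Minimal decode-loop sketch (generate below is the real driver; tokenizer
// setup and output handling are elided). step returns the token sampled by
// the previous call while the next forward pass runs asynchronously.
//
//	dec := NewDecoder(m, 0.7, 40, 0.9)
//	dec.prefill(promptTokens)
//	for {
//		id := dec.step()
//		if tok.IsEOS(id) {
//			break
//		}
//	}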

func (d *Decoder) prefill(inputIDs []int32) int {
	processed := 0

	// Track old cache state to free after each chunk
	var oldCacheState []*mlx.Array

	// For multimodal models with an image, all tokens must go through the
	// first forward pass together so the image embeddings can be inserted
	// properly. Skip chunking for multimodal prefill.
	isMultimodal := d.image != nil

	// Process all but the last token in chunks, evaluating cache state
	// after each chunk for memory management. For multimodal, everything
	// is processed in the final step instead.
	if !isMultimodal {
		for len(inputIDs) > 1 {
			chunkSize := min(2048, len(inputIDs)-1)
			if chunkSize <= 0 {
				break
			}
			chunk := inputIDs[:chunkSize]

			// Save old cache state before forward
			oldCacheState = oldCacheState[:0]
			for _, c := range d.caches {
				oldCacheState = append(oldCacheState, c.State()...)
			}

			var cacheState []*mlx.Array
			withStream(func() {
				x := mlx.NewArrayInt32(chunk, []int32{1, int32(len(chunk))})
				d.model.Forward(x, d.caches)
				for _, c := range d.caches {
					cacheState = append(cacheState, c.State()...)
				}
			})
			mlx.Eval(cacheState...)

			// Free old cache state
			for _, arr := range oldCacheState {
				if arr != nil {
					arr.Free()
				}
			}

			inputIDs = inputIDs[chunkSize:]
			processed += chunkSize
		}
	}

	// Save old cache state before the final step
	oldCacheState = oldCacheState[:0]
	for _, c := range d.caches {
		oldCacheState = append(oldCacheState, c.State()...)
	}

	// Final token + sampling (or all tokens for multimodal)
	withStream(func() {
		x := mlx.NewArrayInt32(inputIDs, []int32{1, int32(len(inputIDs))})
		mlx.Eval(x) // Materialize before any other evals

		var logits *mlx.Array
		// Use ForwardWithImage if we have an image and the model supports it
		if d.image != nil {
			if mm, ok := d.model.(MultimodalModel); ok {
				logits = mm.ForwardWithImage(x, d.image, d.caches)
				d.image = nil // Only use the image for the first forward pass
			} else {
				logits = d.model.Forward(x, d.caches)
			}
		} else {
			logits = d.model.Forward(x, d.caches)
		}
		d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
	})
	// Keep cache state (the token is auto-kept by AsyncEval)
	for _, c := range d.caches {
		mlx.Keep(c.State()...)
	}
	mlx.AsyncEval(d.token)

	// Free old cache state from before the final step
	for _, arr := range oldCacheState {
		if arr != nil {
			arr.Free()
		}
	}

	mlx.ClearCache()

	return processed + len(inputIDs)
}
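
// Worked example of the chunking above: a 5000-token text-only prompt is
// prefilled as chunks of 2048, 2048, and 903 tokens (the loop always leaves
// the last token out), then the final token goes through Forward with
// sampling, so prefill returns 2048 + 2048 + 903 + 1 = 5000.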

func (d *Decoder) step() int32 {
	prevToken := d.token

	// Save old cache state (reuse preallocated slice)
	d.oldCacheState = d.oldCacheState[:0]
	for _, c := range d.caches {
		d.oldCacheState = append(d.oldCacheState, c.State()...)
	}

	withStream(func() {
		logits := d.model.Forward(mlx.Reshape(prevToken, 1, 1), d.caches)
		d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
	})
	// Keep token and new cache state so they survive cleanup
	mlx.Keep(d.token)
	for _, c := range d.caches {
		mlx.Keep(c.State()...)
	}
	mlx.AsyncEval(d.token)

	// Sync on the previous token (the GPU is already working on the next step)
	val := prevToken.ItemInt32()

	// Free old token and old cache state
	prevToken.Free()
	for _, arr := range d.oldCacheState {
		arr.Free()
	}
	return val
}

func generate(ctx context.Context, m Model, in input, cb func(output)) error {
	mlx.EnableCompile()
	wiredLimit := in.WiredLimitGB
	if wiredLimit <= 0 {
		wiredLimit = 32 // default 32GB
	}
	mlx.MetalSetWiredLimit(uint64(wiredLimit) << 30)

	temp := in.Temperature
	if temp < 0 {
		temp = 0.7 // a negative temperature selects the default
	}

	tok := m.Tokenizer()
	dec := NewDecoder(m, temp, in.TopK, in.TopP)

	// Apply the chat template; use the image template if we have an image
	prompt := in.Prompt
	var tokens []int32
	if mm, ok := m.(MultimodalModel); ok && in.Image != nil {
		prompt = mm.FormatPromptWithImage(prompt)
		tokens = tok.Encode(prompt, true)
		tokens = mm.ExpandImageTokens(tokens) // Expand <start_of_image> to 256 image tokens
		dec.SetImage(in.Image)
	} else if cm, ok := m.(ChatModel); ok {
		prompt = cm.FormatPrompt(prompt)
		tokens = tok.Encode(prompt, true)
	} else {
		tokens = tok.Encode(prompt, true)
	}

	prefillStart := time.Now()
	prefillTokens := dec.prefill(tokens)
	// The prefill measurement includes time to first token (like mlx-lm):
	// step() waits for prefill to complete and returns the first token.
	firstToken := dec.step()
	prefillTokSec := float64(prefillTokens) / time.Since(prefillStart).Seconds()

	genStart := time.Now()
	maxTokens := max(in.MaxTokens, 100)
	var genTokens int

	// UTF-8 streamer to handle partial multi-byte characters
	streamer := &utf8Streamer{}

	// Handle the first token
	genTokens++
	if tok.IsEOS(firstToken) {
		cb(output{Done: true, PrefillTokSec: prefillTokSec, GenTokSec: 0})
		return nil
	}
	if text := streamer.Write(tok.Decode([]int32{firstToken})); text != "" {
		cb(output{Text: text})
	}

	for n := 1; n < maxTokens; n++ {
		if ctx.Err() != nil {
			return ctx.Err()
		}
		token := dec.step()
		genTokens++

		if tok.IsEOS(token) {
			break
		}
		if text := streamer.Write(tok.Decode([]int32{token})); text != "" {
			cb(output{Text: text})
		}

		if n%256 == 0 {
			mlx.ClearCache()
		}
	}

	// Flush any remaining buffered bytes
	if text := streamer.Flush(); text != "" {
		cb(output{Text: text})
	}

	fmt.Printf("\nPeak memory: %.2fGB\n", float64(mlx.MetalGetPeakMemory())/(1<<30))
	cb(output{Done: true, PrefillTokSec: prefillTokSec,
		GenTokSec: float64(genTokens) / time.Since(genStart).Seconds()})
	return nil
}
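
// Illustrative call site (ctx and model construction are assumed; per the
// commit message above, the real main.go entrypoint will change in a
// follow-up commit):
//
//	err := generate(ctx, model, input{
//		Prompt:    "Why is the sky blue?",
//		MaxTokens: 256,
//	}, func(out output) {
//		fmt.Print(out.Text)
//		if out.Done {
//			fmt.Printf("\nprefill: %.1f tok/s, gen: %.1f tok/s\n",
//				out.PrefillTokSec, out.GenTokSec)
//		}
//	})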