ollama/x/imagegen/tokenizer/tokenizer.go

//go:build mlx

// tokenizer.go - BPE, SentencePiece, and WordPiece tokenizer for HuggingFace models
//
// Based on the standard BPE algorithm (Sennrich et al. 2015) with:
// - GPT-2 byte-level encoding (as used by OpenAI's GPT-2 encoder and tiktoken)
// - HuggingFace tokenizer.json pretokenizer patterns
// - SentencePiece ▁-style space handling
// - WordPiece ## continuation handling (BERT-style)
package tokenizer
import (
"encoding/json"
"fmt"
"os"
"regexp"
"runtime"
"sort"
"strconv"
"strings"
"sync"
"unicode"
"unicode/utf8"
)
// TokenizerType identifies the tokenization algorithm
type TokenizerType int
const (
TokenizerBPE TokenizerType = iota // GPT-2 style byte-level BPE
TokenizerSentencePiece // SentencePiece with ▁ for spaces
TokenizerWordPiece // BERT style with ## continuations
)
// Vocabulary holds the tokenizer vocabulary and merges
type Vocabulary struct {
Values []string
Reverse map[string]int32
Merges map[string]int
BOS int32
EOS []int32 // Multiple EOS tokens supported (e.g., Gemma has <eos> and <end_of_turn>)
PAD int32 // Padding token (often <|endoftext|> or <pad>)
AddBOS bool
AddEOS bool
// Precomputed byte token IDs for <0xNN> fallback (256 entries, -1 if not found)
byteTokens [256]int32
}
// Tokenizer handles BPE, SentencePiece, and WordPiece tokenization
type Tokenizer struct {
vocab *Vocabulary
pretokenizer *regexp.Regexp
specialTokens map[string]int32 // Special tokens for direct lookup
typ TokenizerType // Algorithm type
unkToken int32 // [UNK] token ID for WordPiece fallback
}
// Precomputed GPT-2 byte-level encoding table
// Maps byte values to their encoded rune equivalents
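// For example, byte 0x20 (space) is encoded as U+0120 'Ġ' and byte 0x0A (newline)
// as U+010A 'Ċ', while printable ASCII bytes map to themselves; Decode reverses
// this mapping.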
var byteToRune [256]rune
func init() {
for b := 0; b < 256; b++ {
r := rune(b)
switch {
case r == 0x00ad:
r = 0x0143
case r <= 0x0020:
r = r + 0x0100
case r >= 0x007f && r <= 0x00a0:
r = r + 0x00a2
}
byteToRune[b] = r
}
}
// loadSpecialTokenConfig loads special token configuration from HuggingFace companion files.
//
// Loading priority for EOS tokens (can be single int or []int):
// 1. generation_config.json - eos_token_id (preferred, matches HuggingFace generation)
// 2. config.json - eos_token_id (model config fallback)
// 3. tokenizer_config.json - eos_token string + add_bos/add_eos flags
// 4. special_tokens_map.json - final fallback
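//
// For example, a generation_config.json containing
//
//	{"bos_token_id": 1, "eos_token_id": [2, 107]}
//
// yields BOS=1 and EOS=[2, 107] (the IDs here are illustrative, not tied to a
// particular model).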
func loadSpecialTokenConfig(dir string, t *Tokenizer) {
	// No directory means no companion files to read (e.g. LoadFromBytes).
	if dir == "" {
		return
	}
// Helper to parse eos_token_id which can be int or []int
parseTokenIDs := func(v interface{}) []int32 {
switch val := v.(type) {
case float64:
return []int32{int32(val)}
case []interface{}:
ids := make([]int32, 0, len(val))
for _, id := range val {
if f, ok := id.(float64); ok {
ids = append(ids, int32(f))
}
}
return ids
}
return nil
}
// Priority 1: generation_config.json (eos_token_id can be int or []int)
if data, err := os.ReadFile(dir + "generation_config.json"); err == nil {
var config struct {
EOSTokenID interface{} `json:"eos_token_id"`
BOSTokenID interface{} `json:"bos_token_id"`
}
if err := json.Unmarshal(data, &config); err == nil {
if ids := parseTokenIDs(config.EOSTokenID); len(ids) > 0 {
t.vocab.EOS = ids
}
if ids := parseTokenIDs(config.BOSTokenID); len(ids) > 0 {
t.vocab.BOS = ids[0]
}
}
}
// Priority 2: config.json (model config, same format)
if len(t.vocab.EOS) == 0 || t.vocab.BOS < 0 {
if data, err := os.ReadFile(dir + "config.json"); err == nil {
var config struct {
EOSTokenID interface{} `json:"eos_token_id"`
BOSTokenID interface{} `json:"bos_token_id"`
}
if err := json.Unmarshal(data, &config); err == nil {
if len(t.vocab.EOS) == 0 {
if ids := parseTokenIDs(config.EOSTokenID); len(ids) > 0 {
t.vocab.EOS = ids
}
}
if t.vocab.BOS < 0 {
if ids := parseTokenIDs(config.BOSTokenID); len(ids) > 0 {
t.vocab.BOS = ids[0]
}
}
}
}
}
// Priority 3: tokenizer_config.json (token strings + add_bos/add_eos flags)
if data, err := os.ReadFile(dir + "tokenizer_config.json"); err == nil {
var config struct {
BOSToken interface{} `json:"bos_token"`
EOSToken interface{} `json:"eos_token"`
PADToken interface{} `json:"pad_token"`
AddBOSToken *bool `json:"add_bos_token"`
AddEOSToken *bool `json:"add_eos_token"`
}
if err := json.Unmarshal(data, &config); err == nil {
if t.vocab.BOS < 0 {
if bosStr := extractTokenString(config.BOSToken); bosStr != "" {
if id, ok := t.specialTokens[bosStr]; ok {
t.vocab.BOS = id
}
}
}
if len(t.vocab.EOS) == 0 {
if eosStr := extractTokenString(config.EOSToken); eosStr != "" {
if id, ok := t.specialTokens[eosStr]; ok {
t.vocab.EOS = []int32{id}
}
}
}
if t.vocab.PAD < 0 {
if padStr := extractTokenString(config.PADToken); padStr != "" {
if id, ok := t.specialTokens[padStr]; ok {
t.vocab.PAD = id
}
}
}
if config.AddBOSToken != nil {
t.vocab.AddBOS = *config.AddBOSToken
}
if config.AddEOSToken != nil {
t.vocab.AddEOS = *config.AddEOSToken
}
}
}
// Priority 4: special_tokens_map.json (final fallback)
if data, err := os.ReadFile(dir + "special_tokens_map.json"); err == nil {
var tokensMap map[string]interface{}
if err := json.Unmarshal(data, &tokensMap); err == nil {
if t.vocab.BOS < 0 {
if bosStr := extractTokenString(tokensMap["bos_token"]); bosStr != "" {
if id, ok := t.specialTokens[bosStr]; ok {
t.vocab.BOS = id
}
}
}
if len(t.vocab.EOS) == 0 {
if eosStr := extractTokenString(tokensMap["eos_token"]); eosStr != "" {
if id, ok := t.specialTokens[eosStr]; ok {
t.vocab.EOS = []int32{id}
}
}
}
if t.vocab.PAD < 0 {
if padStr := extractTokenString(tokensMap["pad_token"]); padStr != "" {
if id, ok := t.specialTokens[padStr]; ok {
t.vocab.PAD = id
}
}
}
}
}
}
// extractTokenString extracts the token string from various formats used in HuggingFace configs.
// Tokens can be represented as:
// - string: "token"
// - object: {"content": "token", ...}
func extractTokenString(v interface{}) string {
if v == nil {
return ""
}
// Direct string
if s, ok := v.(string); ok {
return s
}
// Object with content field
if m, ok := v.(map[string]interface{}); ok {
if content, ok := m["content"].(string); ok {
return content
}
}
return ""
}
// rewritePatternForRE2 rewrites HuggingFace pretokenizer regex patterns to be
// compatible with Go's regexp package (RE2). HuggingFace patterns rely on the PCRE
// negative lookahead (?!\S), which RE2 does not support.
//
// We replace \s+(?!\S)|\s+ with \s+ and fix whitespace boundaries in Encode().
// The lookahead version splits "a b" into ["a", " ", " b"] (space prepended to word);
// simple \s+ would give ["a", " ", "b"], so Encode() post-processes match boundaries
// to match Python's behavior. Case-insensitive contraction groups (?i:'s|...) are
// also expanded into explicit alternations.
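//
// For example, the default GPT-2 pattern
//
//	's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
//
// is rewritten to end in a plain \s+ alternative that RE2 can compile.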
func rewritePatternForRE2(pattern string) string {
	// Replace lookahead pattern with simple \s+ - we fix boundaries in Encode()
pattern = strings.ReplaceAll(pattern, `\s+(?!\S)|\s+`, `\s+`)
// Handle the pattern when it appears with a ? suffix (optional contractions in GPT-4o style)
// IMPORTANT: Must be done before the non-optional version to avoid partial replacement
pattern = strings.ReplaceAll(pattern,
`(?i:'s|'t|'re|'ve|'m|'ll|'d)?`,
`(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?`)
// Expand case-insensitive contraction pattern to explicit alternations
// (?i:'s|'t|'re|'ve|'m|'ll|'d) -> '[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD]
pattern = strings.ReplaceAll(pattern,
`(?i:'s|'t|'re|'ve|'m|'ll|'d)`,
`(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])`)
return pattern
}
// LoadFromBytes loads a tokenizer from tokenizer.json bytes.
// This is useful when loading from blob storage where the file content is already in memory.
// Note: This won't load special token config from companion files. Use LoadFromBytesWithConfig
// to provide tokenizer_config.json data for proper PAD/EOS token loading.
func LoadFromBytes(data []byte) (*Tokenizer, error) {
return loadFromTokenizerJSON(data, "")
}
// TokenizerConfig holds optional configuration data that can be passed to LoadFromBytesWithConfig.
type TokenizerConfig struct {
TokenizerConfigJSON []byte // tokenizer_config.json content
GenerationConfigJSON []byte // generation_config.json content
SpecialTokensMapJSON []byte // special_tokens_map.json content
ConfigJSON []byte // config.json content
}
// LoadFromBytesWithConfig loads a tokenizer from tokenizer.json bytes with additional config files.
// This is useful when loading from blob storage where companion config files are also blobs.
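//
// A minimal usage sketch (variable names are illustrative; the byte slices are
// assumed to hold the corresponding JSON blobs):
//
//	tok, err := LoadFromBytesWithConfig(tokenizerJSON, &TokenizerConfig{
//		TokenizerConfigJSON:  tokenizerConfigJSON,
//		GenerationConfigJSON: generationConfigJSON,
//	})
//	if err != nil {
//		return err
//	}
//	ids := tok.Encode("hello world", true)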
func LoadFromBytesWithConfig(data []byte, config *TokenizerConfig) (*Tokenizer, error) {
t, err := loadFromTokenizerJSON(data, "")
if err != nil {
return nil, err
}
if config == nil {
return t, nil
}
// Apply special token configs from provided data
loadSpecialTokenConfigFromBytes(t, config)
return t, nil
}
// loadSpecialTokenConfigFromBytes loads special token configuration from byte slices.
func loadSpecialTokenConfigFromBytes(t *Tokenizer, config *TokenizerConfig) {
// Helper to parse eos_token_id which can be int or []int
parseTokenIDs := func(v interface{}) []int32 {
switch val := v.(type) {
case float64:
return []int32{int32(val)}
case []interface{}:
ids := make([]int32, 0, len(val))
for _, id := range val {
if f, ok := id.(float64); ok {
ids = append(ids, int32(f))
}
}
return ids
}
return nil
}
// Priority 1: generation_config.json
if len(config.GenerationConfigJSON) > 0 {
var genConfig struct {
EOSTokenID interface{} `json:"eos_token_id"`
BOSTokenID interface{} `json:"bos_token_id"`
}
if err := json.Unmarshal(config.GenerationConfigJSON, &genConfig); err == nil {
if ids := parseTokenIDs(genConfig.EOSTokenID); len(ids) > 0 {
t.vocab.EOS = ids
}
if ids := parseTokenIDs(genConfig.BOSTokenID); len(ids) > 0 {
t.vocab.BOS = ids[0]
}
}
}
// Priority 2: config.json
if len(config.ConfigJSON) > 0 && (len(t.vocab.EOS) == 0 || t.vocab.BOS < 0) {
var modelConfig struct {
EOSTokenID interface{} `json:"eos_token_id"`
BOSTokenID interface{} `json:"bos_token_id"`
}
if err := json.Unmarshal(config.ConfigJSON, &modelConfig); err == nil {
if len(t.vocab.EOS) == 0 {
if ids := parseTokenIDs(modelConfig.EOSTokenID); len(ids) > 0 {
t.vocab.EOS = ids
}
}
if t.vocab.BOS < 0 {
if ids := parseTokenIDs(modelConfig.BOSTokenID); len(ids) > 0 {
t.vocab.BOS = ids[0]
}
}
}
}
// Priority 3: tokenizer_config.json
if len(config.TokenizerConfigJSON) > 0 {
var tokConfig struct {
BOSToken interface{} `json:"bos_token"`
EOSToken interface{} `json:"eos_token"`
PADToken interface{} `json:"pad_token"`
AddBOSToken *bool `json:"add_bos_token"`
AddEOSToken *bool `json:"add_eos_token"`
}
if err := json.Unmarshal(config.TokenizerConfigJSON, &tokConfig); err == nil {
if t.vocab.BOS < 0 {
if bosStr := extractTokenString(tokConfig.BOSToken); bosStr != "" {
if id, ok := t.specialTokens[bosStr]; ok {
t.vocab.BOS = id
}
}
}
if len(t.vocab.EOS) == 0 {
if eosStr := extractTokenString(tokConfig.EOSToken); eosStr != "" {
if id, ok := t.specialTokens[eosStr]; ok {
t.vocab.EOS = []int32{id}
}
}
}
if t.vocab.PAD < 0 {
if padStr := extractTokenString(tokConfig.PADToken); padStr != "" {
if id, ok := t.specialTokens[padStr]; ok {
t.vocab.PAD = id
}
}
}
if tokConfig.AddBOSToken != nil {
t.vocab.AddBOS = *tokConfig.AddBOSToken
}
if tokConfig.AddEOSToken != nil {
t.vocab.AddEOS = *tokConfig.AddEOSToken
}
}
}
// Priority 4: special_tokens_map.json
if len(config.SpecialTokensMapJSON) > 0 {
var tokensMap map[string]interface{}
if err := json.Unmarshal(config.SpecialTokensMapJSON, &tokensMap); err == nil {
if t.vocab.BOS < 0 {
if bosStr := extractTokenString(tokensMap["bos_token"]); bosStr != "" {
if id, ok := t.specialTokens[bosStr]; ok {
t.vocab.BOS = id
}
}
}
if len(t.vocab.EOS) == 0 {
if eosStr := extractTokenString(tokensMap["eos_token"]); eosStr != "" {
if id, ok := t.specialTokens[eosStr]; ok {
t.vocab.EOS = []int32{id}
}
}
}
if t.vocab.PAD < 0 {
if padStr := extractTokenString(tokensMap["pad_token"]); padStr != "" {
if id, ok := t.specialTokens[padStr]; ok {
t.vocab.PAD = id
}
}
}
}
}
}
// Load loads a tokenizer from a path which can be:
// - A tokenizer.json file
// - A directory containing tokenizer.json or vocab.json + merges.txt
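//
// A minimal usage sketch (the path is illustrative):
//
//	tok, err := Load("/path/to/model") // directory or tokenizer.json path
//	if err != nil {
//		return err
//	}
//	ids := tok.Encode("hello world", true)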
func Load(path string) (*Tokenizer, error) {
// Check if path is a directory
if info, err := os.Stat(path); err == nil && info.IsDir() {
dir := strings.TrimSuffix(path, "/") + "/"
// Try tokenizer.json first
if data, err := os.ReadFile(dir + "tokenizer.json"); err == nil {
return loadFromTokenizerJSON(data, dir)
}
// Fall back to vocab.json + merges.txt
return LoadVocabMerges(path)
}
// It's a file - read it directly
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("failed to read tokenizer: %w", err)
}
// Get directory for loading companion files
dir := ""
if idx := strings.LastIndex(path, "/"); idx >= 0 {
dir = path[:idx+1]
}
return loadFromTokenizerJSON(data, dir)
}
// loadFromTokenizerJSON parses a tokenizer.json file
func loadFromTokenizerJSON(data []byte, dir string) (*Tokenizer, error) {
var raw struct {
Model struct {
Type string `json:"type"` // "BPE" or "WordPiece"
Vocab map[string]int32 `json:"vocab"`
			Merges   json.RawMessage  `json:"merges"`    // Can be []string or [][]string (BPE only)
			UnkToken string           `json:"unk_token"` // Unknown-token string (e.g. "[UNK]"), used for WordPiece fallback
} `json:"model"`
PreTokenizer json.RawMessage `json:"pre_tokenizer"`
Decoder json.RawMessage `json:"decoder"`
AddedTokens []struct {
ID int32 `json:"id"`
Content string `json:"content"`
Special bool `json:"special"`
} `json:"added_tokens"`
}
if err := json.Unmarshal(data, &raw); err != nil {
return nil, fmt.Errorf("failed to parse tokenizer: %w", err)
}
// Parse merges - can be []string (Llama) or [][]string (GPT-OSS)
// WordPiece models don't have merges
var mergesStrings []string
if raw.Model.Type != "WordPiece" && raw.Model.Merges != nil {
var mergesArrays [][]string
if err := json.Unmarshal(raw.Model.Merges, &mergesStrings); err != nil {
// Try array of arrays format
if err := json.Unmarshal(raw.Model.Merges, &mergesArrays); err != nil {
return nil, fmt.Errorf("failed to parse merges: %w", err)
}
// Convert [][]string to []string
mergesStrings = make([]string, len(mergesArrays))
for i, pair := range mergesArrays {
mergesStrings[i] = pair[0] + " " + pair[1]
}
}
}
// Build tokenizer
	t := &Tokenizer{
		vocab: &Vocabulary{
			Values:  make([]string, len(raw.Model.Vocab)),
			Reverse: raw.Model.Vocab,
			Merges:  make(map[string]int, len(mergesStrings)),
			BOS:     -1,
			PAD:     -1,
		},
		specialTokens: make(map[string]int32),
		unkToken:      -1, // resolved below for WordPiece models
	}
// Build values array
for token, id := range raw.Model.Vocab {
if int(id) >= len(t.vocab.Values) {
newValues := make([]string, id+1)
copy(newValues, t.vocab.Values)
t.vocab.Values = newValues
}
t.vocab.Values[id] = token
}
// Build merges map
for i, merge := range mergesStrings {
t.vocab.Merges[merge] = i
}
// Add special tokens to vocabulary
for _, tok := range raw.AddedTokens {
if int(tok.ID) >= len(t.vocab.Values) {
newValues := make([]string, tok.ID+1)
copy(newValues, t.vocab.Values)
t.vocab.Values = newValues
}
t.vocab.Values[tok.ID] = tok.Content
if tok.Special {
t.specialTokens[tok.Content] = tok.ID
}
}
// Load special token configuration from companion files
loadSpecialTokenConfig(dir, t)
// Precompute byte token IDs for <0xNN> fallback
initByteTokens(t)
	// Determine tokenizer type
	switch {
	case raw.Model.Type == "WordPiece":
		t.typ = TokenizerWordPiece
		// Resolve the unknown-token ID; without it, unmatched pieces would be dropped.
		unk := raw.Model.UnkToken
		if unk == "" {
			unk = "[UNK]"
		}
		if id, ok := t.vocab.Reverse[unk]; ok {
			t.unkToken = id
		}
	case detectSentencePiece(raw.Decoder):
		t.typ = TokenizerSentencePiece
	default:
		t.typ = TokenizerBPE
	}
// Parse and compile pretokenizer pattern (BPE only - SentencePiece doesn't use pretokenizer)
if t.typ == TokenizerBPE {
pattern := extractPretokenizer(raw.PreTokenizer)
if pattern == "" {
pattern = `'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`
}
re, err := regexp.Compile(rewritePatternForRE2(pattern))
if err != nil {
return nil, fmt.Errorf("failed to compile pretokenizer regex %q: %w", pattern, err)
}
t.pretokenizer = re
}
return t, nil
}
// detectSentencePiece checks if the decoder uses SentencePiece-style (▁ for spaces)
// vs GPT-2 byte-level encoding
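//
// A SentencePiece-style tokenizer.json typically carries a decoder such as
// (abridged):
//
//	{"type": "Sequence", "decoders": [{"type": "Replace", "pattern": {"String": "▁"}, "content": " "}, ...]}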
func detectSentencePiece(data json.RawMessage) bool {
if data == nil {
return false
}
// Check for Sequence decoder with Replace step (SentencePiece style)
var seq struct {
Type string `json:"type"`
Decoders []struct {
Type string `json:"type"`
Pattern struct {
String string `json:"String"`
} `json:"pattern"`
} `json:"decoders"`
}
if err := json.Unmarshal(data, &seq); err == nil {
if seq.Type == "Sequence" {
for _, dec := range seq.Decoders {
// Look for Replace decoder that converts ▁ to space
if dec.Type == "Replace" && dec.Pattern.String == "▁" {
return true
}
}
}
}
// Check for direct ByteLevel decoder (GPT-2 style)
var simple struct {
Type string `json:"type"`
}
if err := json.Unmarshal(data, &simple); err == nil {
if simple.Type == "ByteLevel" {
return false
}
}
return false
}
// initByteTokens precomputes byte token IDs for <0xNN> fallback encoding
func initByteTokens(t *Tokenizer) {
for i := range t.vocab.byteTokens {
t.vocab.byteTokens[i] = -1
}
for b := 0; b < 256; b++ {
token := fmt.Sprintf("<0x%02X>", b)
if id, ok := t.vocab.Reverse[token]; ok {
t.vocab.byteTokens[b] = id
}
}
}
// extractPretokenizer extracts the regex pattern from the pre_tokenizer config
func extractPretokenizer(data json.RawMessage) string {
if data == nil {
return ""
}
// Try to parse as a single Split pretokenizer
var single struct {
Type string `json:"type"`
Pattern struct {
Regex string `json:"Regex"`
} `json:"pattern"`
}
if err := json.Unmarshal(data, &single); err == nil && single.Pattern.Regex != "" {
return single.Pattern.Regex
}
// Try to parse as Sequence of pretokenizers - use first Split pattern
var seq struct {
Type string `json:"type"`
Pretokenizers []struct {
Type string `json:"type"`
Pattern struct {
Regex string `json:"Regex"`
} `json:"pattern"`
} `json:"pretokenizers"`
}
if err := json.Unmarshal(data, &seq); err == nil && seq.Type == "Sequence" {
for _, pt := range seq.Pretokenizers {
if pt.Type == "Split" && pt.Pattern.Regex != "" {
return pt.Pattern.Regex
}
}
}
return ""
}
// isNonNewlineWhitespace returns true if s contains only whitespace characters (no newlines)
func isNonNewlineWhitespace(s string) bool {
if s == "" {
return false
}
for _, r := range s {
if r == '\n' || r == '\r' {
return false
}
if !unicode.IsSpace(r) {
return false
}
}
return true
}
// splitBySpecialTokens splits text into parts, keeping special tokens as separate elements
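// For example, assuming "<|eot_id|>" is a registered special token,
// "Hi<|eot_id|>Bye" is split into ["Hi", "<|eot_id|>", "Bye"].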
func (t *Tokenizer) splitBySpecialTokens(s string) []string {
if len(t.specialTokens) == 0 {
return []string{s}
}
// Sort special tokens by length (longest first) to match greedily
tokens := make([]string, 0, len(t.specialTokens))
for tok := range t.specialTokens {
tokens = append(tokens, tok)
}
sort.Slice(tokens, func(i, j int) bool {
return len(tokens[i]) > len(tokens[j])
})
var result []string
remaining := s
for len(remaining) > 0 {
found := false
for _, tok := range tokens {
if strings.HasPrefix(remaining, tok) {
result = append(result, tok)
remaining = remaining[len(tok):]
found = true
break
}
}
if !found {
// Find next special token position
nextPos := len(remaining)
for _, tok := range tokens {
if idx := strings.Index(remaining, tok); idx != -1 && idx < nextPos {
nextPos = idx
}
}
if nextPos > 0 {
result = append(result, remaining[:nextPos])
}
remaining = remaining[nextPos:]
}
}
return result
}
// Encode tokenizes text to token IDs. Parallelizes across chunks for large inputs (>4KB).
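//
// A minimal usage sketch (resulting token IDs are model-dependent):
//
//	ids := tok.Encode("Hello, world!", true) // true prepends BOS when the model defines one
//	text := tok.Decode(ids)                  // typically round-trips the input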
func (t *Tokenizer) Encode(s string, addBOS bool) []int32 {
// First: split by special tokens
parts := t.splitBySpecialTokens(s)
// Second: collect all pretokenizer chunks
type chunk struct {
text string
isSpecial bool
}
var allChunks []chunk
if t.pretokenizer != nil {
re := t.pretokenizer
for _, part := range parts {
if _, ok := t.specialTokens[part]; ok {
allChunks = append(allChunks, chunk{part, true})
continue
}
// Split by pretokenizer regex
type match struct{ start, end int }
var matches []match
offset := 0
for offset < len(part) {
loc := re.FindStringIndex(part[offset:])
if loc == nil {
break
}
matches = append(matches, match{offset + loc[0], offset + loc[1]})
offset += loc[1]
}
// Apply whitespace boundary fix for Python regex compatibility
for i := 0; i < len(matches)-1; i++ {
m := part[matches[i].start:matches[i].end]
next := part[matches[i+1].start:matches[i+1].end]
if isNonNewlineWhitespace(m) && len(next) > 0 {
firstRune, _ := utf8.DecodeRuneInString(next)
if unicode.IsLetter(firstRune) {
lastSpaceStart := matches[i].end
for j := matches[i].end; j > matches[i].start; {
r, size := utf8.DecodeLastRuneInString(part[matches[i].start:j])
if unicode.IsSpace(r) {
lastSpaceStart = j - size
break
}
j -= size
}
if lastSpaceStart > matches[i].start {
matches[i].end = lastSpaceStart
matches[i+1].start = lastSpaceStart
} else {
matches[i+1].start = matches[i].start
matches[i].end = matches[i].start
}
}
}
}
for _, m := range matches {
if m.end > m.start {
allChunks = append(allChunks, chunk{part[m.start:m.end], false})
}
}
}
} else {
// No pretokenizer - treat each part as a single chunk
for _, part := range parts {
if _, ok := t.specialTokens[part]; ok {
allChunks = append(allChunks, chunk{part, true})
} else {
allChunks = append(allChunks, chunk{part, false})
}
}
}
// Encode chunks - parallel for large inputs (>4KB), sequential otherwise
var ids []int32
if len(s) < 4096 {
for _, c := range allChunks {
if c.isSpecial {
if id, ok := t.specialTokens[c.text]; ok {
ids = append(ids, id)
}
} else {
ids = t.encodeChunkInto(c.text, ids)
}
}
} else {
		numWorkers := runtime.GOMAXPROCS(0)
		if numWorkers > len(allChunks) {
			numWorkers = len(allChunks)
		}
		if numWorkers < 1 {
			numWorkers = 1 // guard against an empty chunk list
		}
chunksPer := (len(allChunks) + numWorkers - 1) / numWorkers
results := make([][]int32, numWorkers)
var wg sync.WaitGroup
for i := 0; i < numWorkers; i++ {
start := i * chunksPer
end := start + chunksPer
if end > len(allChunks) {
end = len(allChunks)
}
if start >= end {
continue
}
wg.Add(1)
go func(i int, chunks []chunk) {
defer wg.Done()
var r []int32
for _, c := range chunks {
if c.isSpecial {
if id, ok := t.specialTokens[c.text]; ok {
r = append(r, id)
}
} else {
r = t.encodeChunkInto(c.text, r)
}
}
results[i] = r
}(i, allChunks[start:end])
}
wg.Wait()
for _, r := range results {
ids = append(ids, r...)
}
}
if addBOS && t.vocab.BOS >= 0 {
ids = append([]int32{t.vocab.BOS}, ids...)
}
return ids
}
// encodeChunkInto appends encoded tokens to ids and returns the extended slice.
// WordPiece chunks use greedy longest-match; BPE and SentencePiece chunks use the
// BPE merge algorithm, with a fast path when the whole chunk is a single token.
func (t *Tokenizer) encodeChunkInto(s string, ids []int32) []int32 {
if t.typ == TokenizerWordPiece {
return t.encodeWordPieceInto(s, ids)
}
if s == "" {
return ids
}
// Apply encoding transformation
// SentencePiece: replace space with ▁
// BPE: convert bytes using precomputed table (GPT-2 byte-level encoding)
var encoded string
if t.typ == TokenizerSentencePiece {
encoded = strings.ReplaceAll(s, " ", "▁")
} else {
var sb strings.Builder
sb.Grow(len(s) * 2)
for i := 0; i < len(s); i++ {
sb.WriteRune(byteToRune[s[i]])
}
encoded = sb.String()
}
// Fast path: check if entire chunk is a single token
if id, ok := t.vocab.Reverse[encoded]; ok {
return append(ids, id)
}
return t.encodeBPEMerge(encoded, ids)
}
// encodeBPEMerge encodes using BPE merge algorithm.
// Repeatedly merges the pair with lowest rank until no more merges possible.
// Works correctly with empty merges (falls back to individual rune/byte encoding).
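//
// For example, with an illustrative merge table {"l o": 0, "lo w": 1}, the chunk
// "low" starts as ["l", "o", "w"], merges to ["lo", "w"] (rank 0), then to
// ["low"] (rank 1), which is then looked up in the vocabulary.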
func (t *Tokenizer) encodeBPEMerge(encoded string, ids []int32) []int32 {
// Start with individual runes as parts
runes := []rune(encoded)
parts := make([]string, len(runes))
for i, r := range runes {
parts[i] = string(r)
}
// Repeatedly merge lowest-rank pair
for len(parts) > 1 {
minRank := int(0x7FFFFFFF)
minIdx := -1
for i := 0; i < len(parts)-1; i++ {
// Merge key format: "token1 token2" (space-separated)
mergeKey := parts[i] + " " + parts[i+1]
if rank, ok := t.vocab.Merges[mergeKey]; ok {
if rank < minRank {
minRank = rank
minIdx = i
}
}
}
if minIdx < 0 {
break // No more merges possible
}
// Merge the pair
parts[minIdx] = parts[minIdx] + parts[minIdx+1]
parts = append(parts[:minIdx+1], parts[minIdx+2:]...)
}
// Convert parts to token IDs
for _, part := range parts {
if id, ok := t.vocab.Reverse[part]; ok {
ids = append(ids, id)
} else {
// Byte fallback for unknown parts
for _, b := range []byte(part) {
if id := t.vocab.byteTokens[b]; id >= 0 {
ids = append(ids, id)
}
}
}
}
return ids
}
// encodeWordPieceInto appends WordPiece tokens to ids and returns extended slice
// Uses greedy longest-match with ## prefix for continuation tokens
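//
// For example, assuming the vocabulary contains "un", "##aff" and "##able" but not
// "unaffable", the input "unaffable" encodes to ["un", "##aff", "##able"].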
func (t *Tokenizer) encodeWordPieceInto(s string, ids []int32) []int32 {
if s == "" {
return ids
}
// Check if entire string is in vocabulary (common case)
if id, ok := t.vocab.Reverse[s]; ok {
return append(ids, id)
}
runes := []rune(s)
start := 0
for start < len(runes) {
end := len(runes)
found := false
// Greedy longest-match
for end > start {
substr := string(runes[start:end])
if start > 0 {
// Continuation token: prefix with ##
substr = "##" + substr
}
if id, ok := t.vocab.Reverse[substr]; ok {
ids = append(ids, id)
found = true
start = end
break
}
end--
}
if !found {
// No match found - use [UNK] token or skip
if t.unkToken >= 0 {
ids = append(ids, t.unkToken)
}
start++
}
}
return ids
}
// Decode converts token IDs back to text
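//
// For example, the byte-level BPE token "Ġworld" (Ġ being the encoded 0x20 space
// byte) decodes to " world", and a SentencePiece byte-fallback token such as
// <0x0A> decodes to a newline.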
func (t *Tokenizer) Decode(ids []int32) string {
var sb strings.Builder
for _, id := range ids {
if int(id) >= len(t.vocab.Values) {
continue
}
token := t.vocab.Values[id]
switch t.typ {
		case TokenizerWordPiece:
			// WordPiece style: ## continuation tokens attach to the previous token;
			// other tokens are separated by a single space.
			if strings.HasPrefix(token, "##") {
				sb.WriteString(token[2:])
			} else {
				if sb.Len() > 0 {
					sb.WriteByte(' ')
				}
				sb.WriteString(token)
			}
case TokenizerSentencePiece:
// SentencePiece style: replace ▁ with space, decode byte tokens
token = strings.ReplaceAll(token, "▁", " ")
// Handle byte fallback tokens like <0x0D>
if len(token) == 6 && token[0] == '<' && token[1] == '0' && token[2] == 'x' && token[5] == '>' {
if v, err := strconv.ParseUint(token[3:5], 16, 8); err == nil {
sb.WriteByte(byte(v))
continue
}
}
sb.WriteString(token)
default:
// GPT-2 BPE style: decode byte-level encoding
for _, r := range token {
switch {
case r == 0x0100:
// NULL byte (0x00 encoded as 0x0100)
sb.WriteByte(0)
continue
case r == 0x0143:
r = 0x00ad
case r > 0x0100 && r <= 0x0120:
r = r - 0x0100
case r > 0x0120 && r <= 0x0142:
r = r - 0x00a2
}
// Write as byte, not UTF-8 encoded rune
sb.WriteByte(byte(r))
}
}
}
return sb.String()
}
// VocabSize returns the vocabulary size
func (t *Tokenizer) VocabSize() int {
return len(t.vocab.Values)
}
// BOS returns the beginning of sequence token ID
func (t *Tokenizer) BOS() int32 {
return t.vocab.BOS
}
// EOS returns the first end of sequence token ID (for backwards compatibility)
func (t *Tokenizer) EOS() int32 {
if len(t.vocab.EOS) > 0 {
return t.vocab.EOS[0]
}
return -1
}
// EOSTokens returns all end of sequence token IDs
func (t *Tokenizer) EOSTokens() []int32 {
return t.vocab.EOS
}
// PAD returns the padding token ID, or -1 if not set
func (t *Tokenizer) PAD() int32 {
return t.vocab.PAD
}
// IsEOS returns true if the token ID is an end of sequence token
func (t *Tokenizer) IsEOS(id int32) bool {
for _, eos := range t.vocab.EOS {
if id == eos {
return true
}
}
return false
}
// GetSpecialToken returns the token ID for a special token string
func (t *Tokenizer) GetSpecialToken(name string) (int32, bool) {
id, ok := t.specialTokens[name]
return id, ok
}
// LoadVocabMerges loads a tokenizer from vocab.json + merges.txt format (GPT-style)
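//
// merges.txt lists one space-separated token pair per line in merge-priority order,
// optionally preceded by "#..." header lines, for example (illustrative):
//
//	#version: 0.2
//	Ġ t
//	h e
//	i n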
func LoadVocabMerges(dir string) (*Tokenizer, error) {
vocabPath := dir + "/vocab.json"
mergesPath := dir + "/merges.txt"
addedTokensPath := dir + "/added_tokens.json"
// Load vocab
vocabData, err := os.ReadFile(vocabPath)
if err != nil {
return nil, fmt.Errorf("failed to read vocab.json: %w", err)
}
vocabMap := make(map[string]int32)
if err := json.Unmarshal(vocabData, &vocabMap); err != nil {
return nil, fmt.Errorf("failed to parse vocab.json: %w", err)
}
// Load merges
mergesData, err := os.ReadFile(mergesPath)
if err != nil {
return nil, fmt.Errorf("failed to read merges.txt: %w", err)
}
mergesLines := strings.Split(string(mergesData), "\n")
var mergesStrings []string
for _, line := range mergesLines {
line = strings.TrimSpace(line)
if line == "" || strings.HasPrefix(line, "#") {
continue
}
mergesStrings = append(mergesStrings, line)
}
// Build tokenizer
	t := &Tokenizer{
		vocab: &Vocabulary{
			Values:  make([]string, len(vocabMap)),
			Reverse: vocabMap,
			Merges:  make(map[string]int, len(mergesStrings)),
			BOS:     -1,
			PAD:     -1,
		},
		specialTokens: make(map[string]int32),
		unkToken:      -1, // WordPiece-only fallback, unused for this BPE path
	}
// Load added tokens if exists
if addedData, err := os.ReadFile(addedTokensPath); err == nil {
addedMap := make(map[string]int32)
if err := json.Unmarshal(addedData, &addedMap); err == nil {
for token, id := range addedMap {
vocabMap[token] = id
t.specialTokens[token] = id
}
}
}
// Build values array
for token, id := range vocabMap {
if int(id) >= len(t.vocab.Values) {
newValues := make([]string, id+1)
copy(newValues, t.vocab.Values)
t.vocab.Values = newValues
}
t.vocab.Values[id] = token
}
// Build merges map
for i, merge := range mergesStrings {
t.vocab.Merges[merge] = i
}
// Load special token configuration from companion files
loadSpecialTokenConfig(dir+"/", t)
// Precompute byte token IDs for <0xNN> fallback
initByteTokens(t)
	// Default tiktoken-style pretokenizer pattern (cl100k_base split rules)
pattern := `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`
re, err := regexp.Compile(rewritePatternForRE2(pattern))
if err != nil {
return nil, fmt.Errorf("failed to compile pretokenizer regex: %w", err)
}
t.pretokenizer = re
return t, nil
}