diff --git a/CMakeLists.txt b/CMakeLists.txt index 2820dee09..6ea95f7d3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,6 +2,22 @@ cmake_minimum_required(VERSION 3.21) project(Ollama C CXX) +# Handle cross-compilation on macOS: when CMAKE_OSX_ARCHITECTURES is set to a +# single architecture different from the host, override CMAKE_SYSTEM_PROCESSOR +# to match. This is necessary because CMAKE_SYSTEM_PROCESSOR defaults to the +# host architecture, but downstream projects (like MLX) use it to detect the +# target architecture. +if(CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES ";") + # Single architecture specified + if(CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64") + message(STATUS "Cross-compiling for x86_64: overriding CMAKE_SYSTEM_PROCESSOR from ${CMAKE_SYSTEM_PROCESSOR} to x86_64") + set(CMAKE_SYSTEM_PROCESSOR "x86_64") + elseif(CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") + message(STATUS "Cross-compiling for arm64: overriding CMAKE_SYSTEM_PROCESSOR from ${CMAKE_SYSTEM_PROCESSOR} to arm64") + set(CMAKE_SYSTEM_PROCESSOR "arm64") + endif() +endif() + include(CheckLanguage) include(GNUInstallDirs) @@ -12,7 +28,7 @@ set(BUILD_SHARED_LIBS ON) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) -set(CMAKE_CXX_EXTENSIONS OFF) +set(CMAKE_CXX_EXTENSIONS ON) # Recent versions of MLX Requires gnu++17 extensions to compile properly set(GGML_BUILD ON) set(GGML_SHARED ON) @@ -147,14 +163,48 @@ if(CMAKE_HIP_COMPILER) endif() endif() -find_package(Vulkan) -if(Vulkan_FOUND) - add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan) - install(TARGETS ggml-vulkan - RUNTIME_DEPENDENCIES - PRE_INCLUDE_REGEXES vulkan - PRE_EXCLUDE_REGEXES ".*" - RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan - LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan - ) +if(NOT APPLE) + find_package(Vulkan) + if(Vulkan_FOUND) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan) + install(TARGETS ggml-vulkan + RUNTIME_DEPENDENCIES + PRE_INCLUDE_REGEXES vulkan + PRE_EXCLUDE_REGEXES ".*" + RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan + LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan + ) + endif() endif() + +option(MLX_ENGINE "Enable MLX backend" OFF) + +if(MLX_ENGINE) + message(STATUS "Setting up MLX (this takes a while...)") + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/ml/backend/mlx) + + # Find CUDA toolkit if MLX is built with CUDA support + find_package(CUDAToolkit) + + install(TARGETS mlx mlxc + RUNTIME_DEPENDENCIES + DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR} + PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc cudnn nccl + PRE_EXCLUDE_REGEXES ".*" + RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX + LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX + FRAMEWORK DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX + ) + + # Manually install cudart and cublas since they might not be picked up as direct dependencies + if(CUDAToolkit_FOUND) + file(GLOB CUDART_LIBS + "${CUDAToolkit_LIBRARY_DIR}/libcudart.so*" + "${CUDAToolkit_LIBRARY_DIR}/libcublas.so*") + if(CUDART_LIBS) + install(FILES ${CUDART_LIBS} + DESTINATION ${OLLAMA_INSTALL_DIR} + COMPONENT MLX) + endif() + endif() +endif() \ No newline at end of file diff --git a/CMakePresets.json b/CMakePresets.json index 6fcdf4d25..64b7fd58a 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -41,7 +41,7 @@ "inherits": [ "CUDA" ], "cacheVariables": { "CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;103-virtual;110-virtual;120-virtual;121-virtual", - "CMAKE_CUDA_FLAGS": "-t 2", + "CMAKE_CUDA_FLAGS": "-t 4", "OLLAMA_RUNNER_DIR": "cuda_v13" } }, @@ -83,6 +83,28 @@ "cacheVariables": { "OLLAMA_RUNNER_DIR": "vulkan" } + }, + { + "name": "MLX", + "inherits": [ "Default" ], + "cacheVariables": { + "MLX_ENGINE": "ON", + "OLLAMA_RUNNER_DIR": "mlx" + } + }, + { + "name": "MLX CUDA 12", + "inherits": [ "MLX", "CUDA 12" ], + "cacheVariables": { + "OLLAMA_RUNNER_DIR": "mlx_cuda_v12" + } + }, + { + "name": "MLX CUDA 13", + "inherits": [ "MLX", "CUDA 13" ], + "cacheVariables": { + "OLLAMA_RUNNER_DIR": "mlx_cuda_v13" + } } ], "buildPresets": [ @@ -140,6 +162,21 @@ "name": "Vulkan", "targets": [ "ggml-vulkan" ], "configurePreset": "Vulkan" + }, + { + "name": "MLX", + "targets": [ "mlx", "mlxc" ], + "configurePreset": "MLX" + }, + { + "name": "MLX CUDA 12", + "targets": [ "mlx", "mlxc" ], + "configurePreset": "MLX CUDA 12" + }, + { + "name": "MLX CUDA 13", + "targets": [ "mlx", "mlxc" ], + "configurePreset": "MLX CUDA 13" } ] } diff --git a/Dockerfile b/Dockerfile index c46cfe08e..6d893455e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -131,7 +131,39 @@ COPY ml/backend/ggml/ggml ml/backend/ggml/ggml RUN --mount=type=cache,target=/root/.ccache \ cmake --preset 'Vulkan' \ && cmake --build --parallel --preset 'Vulkan' \ - && cmake --install build --component Vulkan --strip --parallel 8 + && cmake --install build --component Vulkan --strip --parallel 8 + +FROM base AS mlx +ARG CUDA13VERSION=13.0 +RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-} \ + && dnf install -y openblas-devel lapack-devel \ + && dnf install -y libcudnn9-cuda-13 libcudnn9-devel-cuda-13 \ + && dnf install -y libnccl libnccl-devel +ENV PATH=/usr/local/cuda-13/bin:$PATH +ENV BLAS_INCLUDE_DIRS=/usr/include/openblas +ENV LAPACK_INCLUDE_DIRS=/usr/include/openblas +ENV CGO_LDFLAGS="-L/usr/local/cuda-13/lib64 -L/usr/local/cuda-13/targets/x86_64-linux/lib/stubs" +ARG PARALLEL +WORKDIR /go/src/github.com/ollama/ollama +COPY CMakeLists.txt CMakePresets.json . +COPY ml/backend/ggml/ggml ml/backend/ggml/ggml +COPY x/ml/backend/mlx x/ml/backend/mlx +COPY go.mod go.sum . +RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local +ENV PATH=/usr/local/go/bin:$PATH +RUN go mod download +RUN --mount=type=cache,target=/root/.ccache \ + cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \ + && cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \ + && cmake --install build --component MLX --strip --parallel ${PARALLEL} +COPY . . +ARG GOFLAGS="'-ldflags=-w -s'" +ENV CGO_ENABLED=1 +ARG CGO_CFLAGS +ARG CGO_CXXFLAGS +# TODO wire up the actual MLX engine here instead of building the main binary... +RUN mkdir -p dist/bin +RUN go build -tags mlx -trimpath -buildmode=pie -o dist/bin/imagegen ./x/imagegen/cmd/engine FROM base AS build @@ -153,6 +185,8 @@ FROM --platform=linux/amd64 scratch AS amd64 COPY --from=cuda-12 dist/lib/ollama /lib/ollama/ COPY --from=cuda-13 dist/lib/ollama /lib/ollama/ COPY --from=vulkan dist/lib/ollama /lib/ollama/ +COPY --from=mlx /go/src/github.com/ollama/ollama/dist/lib/ollama /lib/ollama/ +COPY --from=mlx /go/src/github.com/ollama/ollama/dist/bin/ /bin/ FROM --platform=linux/arm64 scratch AS arm64 # COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/ diff --git a/convert/convert.go b/convert/convert.go index a6d286683..bd3c84344 100644 --- a/convert/convert.go +++ b/convert/convert.go @@ -6,11 +6,14 @@ import ( "errors" "fmt" "io/fs" + "iter" "log/slog" + "maps" "os" "slices" "strings" + ofs "github.com/ollama/ollama/fs" "github.com/ollama/ollama/fs/ggml" ) @@ -18,8 +21,13 @@ type ModelParameters struct { Architectures []string `json:"architectures"` VocabSize uint32 `json:"vocab_size"` + // TODO is this needed? + ModelType string `json:"model_type"` + TextModel struct { - VocabSize uint32 `json:"vocab_size"` + VocabSize uint32 `json:"vocab_size"` + HiddenSize uint32 `json:"hidden_size"` + ModelType string `json:"model_type"` } `json:"text_config"` } @@ -33,8 +41,94 @@ type AdapterParameters struct { } `json:"lora_parameters"` } -func (ModelParameters) KV(t *Tokenizer) ggml.KV { - kv := ggml.KV{ +type KV map[string]any + +func (kv KV) Architecture() string { + return kv.String("general.architecture", "unknown") +} + +type valueTypes interface { + uint8 | int8 | uint16 | int16 | + uint32 | int32 | uint64 | int64 | + string | float32 | float64 | bool +} + +type arrayValueTypes interface { + []uint8 | []int8 | []uint16 | []int16 | + []uint32 | []int32 | []uint64 | []int64 | + []string | []float32 | []float64 | []bool +} + +func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) (T, bool) { + if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") { + key = kv.Architecture() + "." + key + } + + if val, ok := kv[key].(T); ok { + return val, true + } + return defaultValue[0], false +} + +func (kv KV) String(key string, defaultValue ...string) string { + val, _ := keyValue(kv, key, append(defaultValue, "")...) + return val +} + +func (kv KV) Uint(key string, defaultValue ...uint32) uint32 { + val, _ := keyValue(kv, key, append(defaultValue, 0)...) + return val +} + +func (kv KV) Float(key string, defaultValue ...float32) float32 { + val, _ := keyValue(kv, key, append(defaultValue, 0)...) + return val +} + +func (kv KV) Bool(key string, defaultValue ...bool) bool { + val, _ := keyValue(kv, key, append(defaultValue, false)...) + return val +} + +func (kv KV) Strings(key string, defaultValue ...[]string) []string { + val, _ := keyValue(kv, key, append(defaultValue, []string{""})...) + return val +} + +func (kv KV) Ints(key string, defaultValue ...[]int32) []int32 { + val, _ := keyValue(kv, key, append(defaultValue, []int32{0})...) + return val +} + +func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 { + val, _ := keyValue(kv, key, append(defaultValue, []uint32{0})...) + return val +} + +func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 { + val, _ := keyValue(kv, key, append(defaultValue, []float32{0})...) + return val +} + +func (kv KV) Bools(key string, defaultValue ...[]bool) []bool { + val, _ := keyValue(kv, key, append(defaultValue, []bool{false})...) + return val +} + +func (kv KV) Len() int { + return len(kv) +} + +func (kv KV) Keys() iter.Seq[string] { + return maps.Keys(kv) +} + +func (kv KV) Value(key string) any { + return kv[key] +} + +func (ModelParameters) KV(t *Tokenizer) KV { + kv := KV{ "general.file_type": uint32(1), "general.quantization_version": uint32(2), "tokenizer.ggml.pre": t.Pre, @@ -63,7 +157,7 @@ func (ModelParameters) KV(t *Tokenizer) ggml.KV { return kv } -func (p AdapterParameters) KV() ggml.KV { +func (p AdapterParameters) KV() KV { var alpha float32 if p.LoraParameters.Alpha == 0 { alpha = float32(p.Alpha) @@ -71,7 +165,7 @@ func (p AdapterParameters) KV() ggml.KV { alpha = p.LoraParameters.Alpha } - kv := ggml.KV{ + kv := KV{ "adapter.lora.alpha": alpha, "adapter.type": "lora", "general.file_type": uint32(1), @@ -88,9 +182,14 @@ func (ModelParameters) specialTokenTypes() []string { } } -type ModelConverter interface { +type ModelKV interface { // KV maps parameters to LLM key-values - KV(*Tokenizer) ggml.KV + KV(*Tokenizer) KV +} + +type ModelConverter interface { + ModelKV + // Tensors maps input tensors to LLM tensors. Model specific modifications can be done here. Tensors([]Tensor) []*ggml.Tensor // Replacements returns a list of string pairs to replace in tensor names. @@ -107,7 +206,7 @@ type moreParser interface { type AdapterConverter interface { // KV maps parameters to LLM key-values - KV(ggml.KV) ggml.KV + KV(ofs.Config) KV // Tensors maps input tensors to LLM tensors. Adapter specific modifications can be done here. Tensors([]Tensor) []*ggml.Tensor // Replacements returns a list of string pairs to replace in tensor names. @@ -115,7 +214,7 @@ type AdapterConverter interface { Replacements() []string } -func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error { +func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error { bts, err := fs.ReadFile(fsys, "adapter_config.json") if err != nil { return err @@ -126,8 +225,8 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error { return err } - arch, ok := baseKV["general.architecture"] - if !ok { + arch := baseKV.Architecture() + if arch == "" { return errors.New("architecture not set for the base model") } @@ -153,23 +252,19 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error { return writeFile(f, conv.KV(baseKV), conv.Tensors(ts)) } -// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations -// and files it finds in the input path. -// Supported input model formats include safetensors. -// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model. -func ConvertModel(fsys fs.FS, f *os.File) error { +func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) { bts, err := fs.ReadFile(fsys, "config.json") if err != nil { - return err + return nil, nil, err } var p ModelParameters if err := json.Unmarshal(bts, &p); err != nil { - return err + return nil, nil, err } if len(p.Architectures) < 1 { - return errors.New("unknown architecture") + return nil, nil, errors.New("unknown architecture") } var conv ModelConverter @@ -217,22 +312,22 @@ func ConvertModel(fsys fs.FS, f *os.File) error { case "DeepseekV3ForCausalLM": conv = &deepseek2Model{} default: - return fmt.Errorf("unsupported architecture %q", p.Architectures[0]) + return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0]) } if err := json.Unmarshal(bts, conv); err != nil { - return err + return nil, nil, err } if t, ok := conv.(moreParser); ok { if err := t.parseMore(fsys); err != nil { - return err + return nil, nil, err } } t, err := parseTokenizer(fsys, conv.specialTokenTypes()) if err != nil { - return err + return nil, nil, err } vocabSize := int(cmp.Or(p.VocabSize, p.TextModel.VocabSize)) @@ -254,6 +349,19 @@ func ConvertModel(fsys fs.FS, f *os.File) error { default: slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens)) } + return conv, t, nil +} + +// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations +// and files it finds in the input path. +// Supported input model formats include safetensors. +// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model. +func ConvertModel(fsys fs.FS, f *os.File) error { + kv, t, err := LoadModelMetadata(fsys) + if err != nil { + return err + } + conv := kv.(ModelConverter) ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...)) if err != nil { @@ -263,7 +371,7 @@ func ConvertModel(fsys fs.FS, f *os.File) error { return writeFile(f, conv.KV(t), conv.Tensors(ts)) } -func writeFile(f *os.File, kv ggml.KV, ts []*ggml.Tensor) error { +func writeFile(f *os.File, kv KV, ts []*ggml.Tensor) error { for i := range ts { ts[i].Shape = slices.Clone(ts[i].Shape) slices.Reverse(ts[i].Shape) diff --git a/convert/convert_bert.go b/convert/convert_bert.go index 6b0d0030a..7acdf53ef 100644 --- a/convert/convert_bert.go +++ b/convert/convert_bert.go @@ -88,7 +88,7 @@ func (p *bertModel) parseMore(fsys fs.FS) error { return nil } -func (p *bertModel) KV(t *Tokenizer) ggml.KV { +func (p *bertModel) KV(t *Tokenizer) KV { kv := p.ModelParameters.KV(t) kv["general.architecture"] = "bert" kv["bert.attention.causal"] = false diff --git a/convert/convert_commandr.go b/convert/convert_commandr.go index a909515bd..48b2bb3f5 100644 --- a/convert/convert_commandr.go +++ b/convert/convert_commandr.go @@ -24,7 +24,7 @@ type commandrModel struct { var _ ModelConverter = (*commandrModel)(nil) -func (p *commandrModel) KV(t *Tokenizer) ggml.KV { +func (p *commandrModel) KV(t *Tokenizer) KV { kv := p.ModelParameters.KV(t) kv["general.architecture"] = "command-r" kv["general.name"] = "command-r" diff --git a/convert/convert_deepseek2.go b/convert/convert_deepseek2.go index aa6203277..dce81f3eb 100644 --- a/convert/convert_deepseek2.go +++ b/convert/convert_deepseek2.go @@ -47,7 +47,7 @@ type deepseek2Model struct { Architecture string } -func (p *deepseek2Model) KV(t *Tokenizer) ggml.KV { +func (p *deepseek2Model) KV(t *Tokenizer) KV { kv := p.ModelParameters.KV(t) kv["general.architecture"] = "deepseek2" kv["general.type"] = "model" diff --git a/convert/convert_deepseekocr.go b/convert/convert_deepseekocr.go index cf1dfa0c4..f2de490b0 100644 --- a/convert/convert_deepseekocr.go +++ b/convert/convert_deepseekocr.go @@ -41,7 +41,7 @@ type deepseekocr struct { } `json:"vision_config"` } -func (m *deepseekocr) KV(t *Tokenizer) ggml.KV { +func (m *deepseekocr) KV(t *Tokenizer) KV { kv := m.ModelParameters.KV(t) kv["general.architecture"] = "deepseekocr" kv["block_count"] = m.LanguageConfig.HiddenLayers diff --git a/convert/convert_gemma.go b/convert/convert_gemma.go index 26698d6a6..ab53c062e 100644 --- a/convert/convert_gemma.go +++ b/convert/convert_gemma.go @@ -23,7 +23,7 @@ type gemmaModel struct { var _ ModelConverter = (*gemmaModel)(nil) -func (p *gemmaModel) KV(t *Tokenizer) ggml.KV { +func (p *gemmaModel) KV(t *Tokenizer) KV { kv := p.ModelParameters.KV(t) kv["general.architecture"] = "gemma" kv["gemma.context_length"] = p.MaxPositionEmbeddings diff --git a/convert/convert_gemma2.go b/convert/convert_gemma2.go index 4917e42cd..aecc67ff7 100644 --- a/convert/convert_gemma2.go +++ b/convert/convert_gemma2.go @@ -1,7 +1,5 @@ package convert -import "github.com/ollama/ollama/fs/ggml" - type gemma2Model struct { gemmaModel SlidingWindow uint32 `json:"sliding_window"` @@ -9,7 +7,7 @@ type gemma2Model struct { FinalLogitSoftcap float32 `json:"final_logit_softcapping"` } -func (p *gemma2Model) KV(t *Tokenizer) ggml.KV { +func (p *gemma2Model) KV(t *Tokenizer) KV { kv := p.ModelParameters.KV(t) kv["general.architecture"] = "gemma2" kv["gemma2.context_length"] = p.MaxPositionEmbeddings diff --git a/convert/convert_gemma2_adapter.go b/convert/convert_gemma2_adapter.go index 6299cd9e0..fa070e073 100644 --- a/convert/convert_gemma2_adapter.go +++ b/convert/convert_gemma2_adapter.go @@ -6,6 +6,7 @@ import ( "github.com/pdevine/tensor" "github.com/pdevine/tensor/native" + "github.com/ollama/ollama/fs" "github.com/ollama/ollama/fs/ggml" ) @@ -15,7 +16,7 @@ type gemma2Adapter struct { var _ AdapterConverter = (*gemma2Adapter)(nil) -func (p *gemma2Adapter) KV(baseKV ggml.KV) ggml.KV { +func (p *gemma2Adapter) KV(baseKV fs.Config) KV { kv := p.AdapterParameters.KV() kv["general.architecture"] = "gemma2" return kv diff --git a/convert/convert_gemma3.go b/convert/convert_gemma3.go index 5e6e6904c..bd5cc211a 100644 --- a/convert/convert_gemma3.go +++ b/convert/convert_gemma3.go @@ -3,8 +3,6 @@ package convert import ( "cmp" "slices" - - "github.com/ollama/ollama/fs/ggml" ) type gemma3Model struct { @@ -55,7 +53,7 @@ const ( gemma27BLayerCount = 62 ) -func (p *gemma3Model) KV(t *Tokenizer) ggml.KV { +func (p *gemma3Model) KV(t *Tokenizer) KV { kv := p.ModelParameters.KV(t) kv["general.architecture"] = "gemma3" diff --git a/convert/convert_gemma3n.go b/convert/convert_gemma3n.go index 135ebaa55..03f878632 100644 --- a/convert/convert_gemma3n.go +++ b/convert/convert_gemma3n.go @@ -38,7 +38,7 @@ type gemma3nModel struct { VisionModel struct{} `json:"vision_config"` } -func (m *gemma3nModel) KV(t *Tokenizer) ggml.KV { +func (m *gemma3nModel) KV(t *Tokenizer) KV { kv := m.ModelParameters.KV(t) kv["general.architecture"] = "gemma3n" kv["gemma3n.activation_sparsity_scale"] = slices.Collect(func(yield func(float32) bool) { diff --git a/convert/convert_gptoss.go b/convert/convert_gptoss.go index d7bfb361d..462e92179 100644 --- a/convert/convert_gptoss.go +++ b/convert/convert_gptoss.go @@ -37,7 +37,7 @@ type gptossModel struct { var _ ModelConverter = (*gptossModel)(nil) -func (m *gptossModel) KV(t *Tokenizer) ggml.KV { +func (m *gptossModel) KV(t *Tokenizer) KV { kv := m.ModelParameters.KV(t) kv["general.architecture"] = "gptoss" kv["general.file_type"] = uint32(4) diff --git a/convert/convert_llama.go b/convert/convert_llama.go index 43969749c..b72b7220e 100644 --- a/convert/convert_llama.go +++ b/convert/convert_llama.go @@ -48,7 +48,7 @@ type llamaModel struct { var _ ModelConverter = (*llamaModel)(nil) -func (p *llamaModel) KV(t *Tokenizer) ggml.KV { +func (p *llamaModel) KV(t *Tokenizer) KV { kv := p.ModelParameters.KV(t) kv["general.architecture"] = "llama" kv["llama.vocab_size"] = p.VocabSize diff --git a/convert/convert_llama4.go b/convert/convert_llama4.go index 3e3792339..be3a6b2a8 100644 --- a/convert/convert_llama4.go +++ b/convert/convert_llama4.go @@ -35,7 +35,7 @@ type llama4Model struct { } // KV implements ModelConverter. -func (p *llama4Model) KV(t *Tokenizer) ggml.KV { +func (p *llama4Model) KV(t *Tokenizer) KV { kv := p.ModelParameters.KV(t) kv["general.architecture"] = "llama4" diff --git a/convert/convert_llama_adapter.go b/convert/convert_llama_adapter.go index 4cc451153..6ea6806c8 100644 --- a/convert/convert_llama_adapter.go +++ b/convert/convert_llama_adapter.go @@ -7,6 +7,7 @@ import ( "github.com/pdevine/tensor" "github.com/pdevine/tensor/native" + "github.com/ollama/ollama/fs" "github.com/ollama/ollama/fs/ggml" ) @@ -18,13 +19,13 @@ type llamaAdapter struct { var _ AdapterConverter = (*llamaAdapter)(nil) -func (p *llamaAdapter) KV(baseKV ggml.KV) ggml.KV { +func (p *llamaAdapter) KV(baseKV fs.Config) KV { kv := p.AdapterParameters.KV() kv["general.architecture"] = "llama" - kv["llama.attention.head_count"] = baseKV["llama.attention.head_count"] - kv["llama.attention.head_count_kv"] = baseKV["llama.attention.head_count_kv"] + kv["llama.attention.head_count"] = baseKV.Value("llama.attention.head_count") + kv["llama.attention.head_count_kv"] = baseKV.Value("llama.attention.head_count_kv") - p.NumAttentionHeads = baseKV["llama.attention.head_count"].(uint32) + p.NumAttentionHeads = baseKV.Value("llama.attention.head_count").(uint32) return kv } diff --git a/convert/convert_mistral.go b/convert/convert_mistral.go index f11bd9644..1b9f9a3b6 100644 --- a/convert/convert_mistral.go +++ b/convert/convert_mistral.go @@ -60,7 +60,7 @@ type mistral3Model struct { ProjectorHiddenAct string `json:"projector_hidden_act"` } -func (p *mistral3Model) KV(t *Tokenizer) ggml.KV { +func (p *mistral3Model) KV(t *Tokenizer) KV { kv := p.ModelParameters.KV(t) kv["general.architecture"] = "mistral3" kv["mistral3.vocab_size"] = p.TextModel.VocabSize diff --git a/convert/convert_mistral_causal.go b/convert/convert_mistral_causal.go index 99a483736..3aeaa41b4 100644 --- a/convert/convert_mistral_causal.go +++ b/convert/convert_mistral_causal.go @@ -39,7 +39,7 @@ type mistral3CausalModel struct { } `json:"rope_parameters"` } -func (p *mistral3CausalModel) KV(t *Tokenizer) ggml.KV { +func (p *mistral3CausalModel) KV(t *Tokenizer) KV { kv := p.ModelParameters.KV(t) kv["general.architecture"] = "mistral3" kv["mistral3.vocab_size"] = p.VocabSize diff --git a/convert/convert_mixtral.go b/convert/convert_mixtral.go index 7d60146bd..7104b67d1 100644 --- a/convert/convert_mixtral.go +++ b/convert/convert_mixtral.go @@ -12,7 +12,7 @@ type mixtralModel struct { NumExpertsPerToken uint32 `json:"num_experts_per_tok"` } -func (p *mixtralModel) KV(t *Tokenizer) ggml.KV { +func (p *mixtralModel) KV(t *Tokenizer) KV { kv := p.llamaModel.KV(t) if p.NumLocalExperts > 0 { diff --git a/convert/convert_mllama.go b/convert/convert_mllama.go index 69d7f5882..5c9e7ac69 100644 --- a/convert/convert_mllama.go +++ b/convert/convert_mllama.go @@ -34,7 +34,7 @@ type mllamaModel struct { } `json:"vision_config"` } -func (m *mllamaModel) KV(t *Tokenizer) ggml.KV { +func (m *mllamaModel) KV(t *Tokenizer) KV { kv := m.ModelParameters.KV(t) kv["general.architecture"] = "mllama" diff --git a/convert/convert_nomicbert.go b/convert/convert_nomicbert.go index 6aed5ee75..fa4ed6ac8 100644 --- a/convert/convert_nomicbert.go +++ b/convert/convert_nomicbert.go @@ -87,7 +87,7 @@ func (p *nomicbertModel) parseMore(fsys fs.FS) error { return nil } -func (p *nomicbertModel) KV(t *Tokenizer) ggml.KV { +func (p *nomicbertModel) KV(t *Tokenizer) KV { kv := p.ModelParameters.KV(t) // Determine architecture based on MoE parameters (following qwen3 pattern) diff --git a/convert/convert_olmo.go b/convert/convert_olmo.go index f75c68477..bd5788d82 100644 --- a/convert/convert_olmo.go +++ b/convert/convert_olmo.go @@ -34,7 +34,7 @@ type olmoModel struct { var _ ModelConverter = (*olmoModel)(nil) -func (p *olmoModel) KV(t *Tokenizer) ggml.KV { +func (p *olmoModel) KV(t *Tokenizer) KV { kv := p.ModelParameters.KV(t) kv["general.architecture"] = "olmo3" kv["olmo3.block_count"] = p.NumHiddenLayers diff --git a/convert/convert_phi3.go b/convert/convert_phi3.go index 5a6756053..5fb72cf27 100644 --- a/convert/convert_phi3.go +++ b/convert/convert_phi3.go @@ -37,7 +37,7 @@ type phi3Model struct { var _ ModelConverter = (*phi3Model)(nil) -func (p *phi3Model) KV(t *Tokenizer) ggml.KV { +func (p *phi3Model) KV(t *Tokenizer) KV { kv := p.ModelParameters.KV(t) kv["general.architecture"] = "phi3" kv["phi3.context_length"] = p.MaxPositionEmbeddings diff --git a/convert/convert_qwen2.go b/convert/convert_qwen2.go index 3647c4e54..0c9c84eee 100644 --- a/convert/convert_qwen2.go +++ b/convert/convert_qwen2.go @@ -22,7 +22,7 @@ type qwen2Model struct { var _ ModelConverter = (*qwen2Model)(nil) -func (q *qwen2Model) KV(t *Tokenizer) ggml.KV { +func (q *qwen2Model) KV(t *Tokenizer) KV { kv := q.ModelParameters.KV(t) kv["general.architecture"] = "qwen2" kv["qwen2.block_count"] = q.HiddenLayers diff --git a/convert/convert_qwen25vl.go b/convert/convert_qwen25vl.go index 6e4c96408..60d6f4943 100644 --- a/convert/convert_qwen25vl.go +++ b/convert/convert_qwen25vl.go @@ -29,7 +29,7 @@ type qwen25VLModel struct { var _ ModelConverter = (*qwen25VLModel)(nil) -func (q *qwen25VLModel) KV(t *Tokenizer) ggml.KV { +func (q *qwen25VLModel) KV(t *Tokenizer) KV { kv := q.ModelParameters.KV(t) kv["general.architecture"] = "qwen25vl" diff --git a/convert/convert_qwen3.go b/convert/convert_qwen3.go index f54418a9c..6ceebedee 100644 --- a/convert/convert_qwen3.go +++ b/convert/convert_qwen3.go @@ -32,7 +32,7 @@ type qwen3Model struct { } // KV implements ModelConverter. -func (q *qwen3Model) KV(t *Tokenizer) ggml.KV { +func (q *qwen3Model) KV(t *Tokenizer) KV { arch := "qwen3" if q.NumExperts > 0 { arch += "moe" diff --git a/convert/convert_qwen3vl.go b/convert/convert_qwen3vl.go index e0ccb805f..041e15017 100644 --- a/convert/convert_qwen3vl.go +++ b/convert/convert_qwen3vl.go @@ -45,7 +45,7 @@ func (m *qwen3VLModel) parseMore(fsys fs.FS) error { return json.Unmarshal(bts, &m.VisionModel) } -func (m *qwen3VLModel) KV(t *Tokenizer) ggml.KV { +func (m *qwen3VLModel) KV(t *Tokenizer) KV { kv := m.qwen3Model.KV(t) arch := "qwen3vl" diff --git a/convert/convert_test.go b/convert/convert_test.go index a63e478f3..fa5d7488a 100644 --- a/convert/convert_test.go +++ b/convert/convert_test.go @@ -19,6 +19,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + fsc "github.com/ollama/ollama/fs" "github.com/ollama/ollama/fs/ggml" ) @@ -28,7 +29,7 @@ type tensorData struct { Shape []int `json:"shape"` } -func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) { +func convertFull(t *testing.T, fsys fs.FS) (*os.File, fsc.Config, ggml.Tensors) { t.Helper() f, err := os.CreateTemp(t.TempDir(), "f16") @@ -59,9 +60,10 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) { return r, m.KV(), m.Tensors() } -func generateResultsJSON(t *testing.T, f *os.File, kv ggml.KV, tensors ggml.Tensors) map[string]string { +func generateResultsJSON(t *testing.T, f *os.File, kv fsc.Config, tensors ggml.Tensors) map[string]string { actual := make(map[string]string) - for k, v := range kv { + for k := range kv.Keys() { + v := kv.Value(k) if s, ok := v.(json.Marshaler); !ok { actual[k] = fmt.Sprintf("%v", v) } else { @@ -277,7 +279,7 @@ func generateSafetensorTestData(t *testing.T, tempDir string, tensorData map[str func TestConvertAdapter(t *testing.T) { type AdapterCase struct { Name string - BaseKV map[string]any + BaseKV KV Expected map[string]string } diff --git a/fs/config.go b/fs/config.go index 3d6ae90ec..db305bc56 100644 --- a/fs/config.go +++ b/fs/config.go @@ -1,5 +1,7 @@ package fs +import "iter" + type Config interface { Architecture() string String(string, ...string) string @@ -11,4 +13,8 @@ type Config interface { Ints(string, ...[]int32) []int32 Floats(string, ...[]float32) []float32 Bools(string, ...[]bool) []bool + + Len() int + Keys() iter.Seq[string] + Value(key string) any } diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index 44a48511c..4d0dcb07c 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -6,7 +6,9 @@ import ( "errors" "fmt" "io" + "iter" "log/slog" + "maps" "math" "slices" "strings" @@ -239,6 +241,18 @@ func (kv KV) Bools(key string, defaultValue ...[]bool) []bool { return val.values } +func (kv KV) Len() int { + return len(kv) +} + +func (kv KV) Keys() iter.Seq[string] { + return maps.Keys(kv) +} + +func (kv KV) Value(key string) any { + return kv[key] +} + func (kv KV) OllamaEngineRequired() bool { return slices.Contains([]string{ "bert", diff --git a/fs/ggml/gguf.go b/fs/ggml/gguf.go index e093efea1..3cae4979d 100644 --- a/fs/ggml/gguf.go +++ b/fs/ggml/gguf.go @@ -8,12 +8,12 @@ import ( "fmt" "io" "log/slog" - "maps" "os" "runtime" "slices" "strings" + "github.com/ollama/ollama/fs" "golang.org/x/sync/errgroup" ) @@ -508,7 +508,7 @@ func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error { return binary.Write(w, binary.LittleEndian, s) } -func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error { +func WriteGGUF(f *os.File, kv fs.Config, ts []*Tensor) error { arch := kv.String("general.architecture") if arch == "" { return fmt.Errorf("architecture not set") @@ -526,12 +526,12 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error { return err } - if err := binary.Write(f, binary.LittleEndian, uint64(len(kv))); err != nil { + if err := binary.Write(f, binary.LittleEndian, uint64(kv.Len())); err != nil { return err } - for _, key := range slices.Sorted(maps.Keys(kv)) { - if err := ggufWriteKV(f, arch, key, kv[key]); err != nil { + for _, key := range slices.Sorted(kv.Keys()) { + if err := ggufWriteKV(f, arch, key, kv.Value(key)); err != nil { return err } } diff --git a/parser/parser_test.go b/parser/parser_test.go index 4b97e8c20..4dcfed0cb 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -21,6 +21,7 @@ import ( "golang.org/x/text/encoding/unicode" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/convert" "github.com/ollama/ollama/fs/ggml" ) @@ -801,7 +802,7 @@ func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string, } defer f.Close() - base := map[string]any{"general.architecture": "test"} + var base convert.KV = map[string]any{"general.architecture": "test"} maps.Copy(base, kv) if err := ggml.WriteGGUF(f, base, ti); err != nil { diff --git a/scripts/build_darwin.sh b/scripts/build_darwin.sh index 7ee9e2817..c5294e04a 100755 --- a/scripts/build_darwin.sh +++ b/scripts/build_darwin.sh @@ -42,18 +42,39 @@ shift $(( $OPTIND - 1 )) _build_darwin() { for ARCH in $ARCHS; do status "Building darwin $ARCH" - INSTALL_PREFIX=dist/darwin-$ARCH/ - GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 go build -o $INSTALL_PREFIX . + INSTALL_PREFIX=dist/darwin-$ARCH/ if [ "$ARCH" = "amd64" ]; then status "Building darwin $ARCH dynamic backends" - cmake -B build/darwin-$ARCH \ + BUILD_DIR=build/darwin-$ARCH + cmake -B $BUILD_DIR \ -DCMAKE_OSX_ARCHITECTURES=x86_64 \ - -DCMAKE_OSX_DEPLOYMENT_TARGET=11.3 \ + -DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \ + -DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX \ + -DMLX_ENGINE=ON \ + -DMLX_ENABLE_X64_MAC=ON \ + -DOLLAMA_RUNNER_DIR=./ + cmake --build $BUILD_DIR --target ggml-cpu -j + cmake --build $BUILD_DIR --target mlx mlxc -j + cmake --install $BUILD_DIR --component CPU + cmake --install $BUILD_DIR --component MLX + # Override CGO flags to point to the amd64 build directory + MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0" + MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Accelerate -mmacosx-version-min=14.0" + else + BUILD_DIR=build + cmake --preset MLX \ + -DOLLAMA_RUNNER_DIR=./ \ + -DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \ -DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX - cmake --build build/darwin-$ARCH --target ggml-cpu -j - cmake --install build/darwin-$ARCH --component CPU + cmake --build --preset MLX --parallel + cmake --install $BUILD_DIR --component MLX + # Use default CGO flags from mlx.go for arm64 + MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0" + MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0" fi + GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX/imagegen ./x/imagegen/cmd/engine + GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 go build -o $INSTALL_PREFIX . done } @@ -61,10 +82,12 @@ _sign_darwin() { status "Creating universal binary..." mkdir -p dist/darwin lipo -create -output dist/darwin/ollama dist/darwin-*/ollama + lipo -create -output dist/darwin/imagegen dist/darwin-*/imagegen chmod +x dist/darwin/ollama + chmod +x dist/darwin/imagegen if [ -n "$APPLE_IDENTITY" ]; then - for F in dist/darwin/ollama dist/darwin-amd64/lib/ollama/*; do + for F in dist/darwin/ollama dist/darwin-*/lib/ollama/* dist/darwin/imagegen; do codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime $F done @@ -131,17 +154,23 @@ _build_macapp() { mkdir -p dist/Ollama.app/Contents/Resources if [ -d dist/darwin-amd64 ]; then lipo -create -output dist/Ollama.app/Contents/Resources/ollama dist/darwin-amd64/ollama dist/darwin-arm64/ollama - cp dist/darwin-amd64/lib/ollama/*.so dist/darwin-amd64/lib/ollama/*.dylib dist/Ollama.app/Contents/Resources/ + lipo -create -output dist/Ollama.app/Contents/Resources/imagegen dist/darwin-amd64/imagegen dist/darwin-arm64/imagegen + for F in dist/darwin-amd64/lib/ollama/*mlx*.dylib ; do + lipo -create -output dist/darwin/$(basename $F) $F dist/darwin-arm64/lib/ollama/$(basename $F) + done + cp dist/darwin-*/lib/ollama/*.so dist/darwin-*/lib/ollama/*.dylib dist/Ollama.app/Contents/Resources/ + cp dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/ else cp -a dist/darwin/ollama dist/Ollama.app/Contents/Resources/ollama cp dist/darwin/*.so dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/ fi + cp -a dist/darwin/imagegen dist/Ollama.app/Contents/Resources/imagegen chmod a+x dist/Ollama.app/Contents/Resources/ollama # Sign if [ -n "$APPLE_IDENTITY" ]; then codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.app/Contents/Resources/ollama - for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib ; do + for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/imagegen ; do codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime ${lib} done codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier com.electron.ollama --deep --options=runtime dist/Ollama.app @@ -149,7 +178,7 @@ _build_macapp() { rm -f dist/Ollama-darwin.zip ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip - (cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama *.so *.dylib) | gzip -9vc > dist/ollama-darwin.tgz + (cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama imagegen *.so *.dylib) | gzip -9vc > dist/ollama-darwin.tgz # Notarize and Staple if [ -n "$APPLE_IDENTITY" ]; then diff --git a/scripts/build_linux.sh b/scripts/build_linux.sh index 7a5bfc9b8..11b278cbd 100755 --- a/scripts/build_linux.sh +++ b/scripts/build_linux.sh @@ -48,6 +48,55 @@ if echo $PLATFORM | grep "amd64" > /dev/null; then . fi +# Deduplicate CUDA libraries across mlx_* and cuda_* directories +deduplicate_cuda_libs() { + local base_dir="$1" + echo "Deduplicating CUDA libraries in ${base_dir}..." + + # Find all mlx_cuda_* directories + for mlx_dir in "${base_dir}"/lib/ollama/mlx_cuda_*; do + [ -d "${mlx_dir}" ] || continue + + # Extract CUDA version (e.g., v12, v13) + cuda_version=$(basename "${mlx_dir}" | sed 's/mlx_cuda_//') + cuda_dir="${base_dir}/lib/ollama/cuda_${cuda_version}" + + # Skip if corresponding cuda_* directory doesn't exist + [ -d "${cuda_dir}" ] || continue + + echo " Checking ${mlx_dir} against ${cuda_dir}..." + + # Find all .so* files in mlx directory + find "${mlx_dir}" -type f -name "*.so*" | while read mlx_file; do + filename=$(basename "${mlx_file}") + cuda_file="${cuda_dir}/${filename}" + + # Skip if file doesn't exist in cuda directory + [ -f "${cuda_file}" ] || continue + + # Compare checksums + mlx_sum=$(sha256sum "${mlx_file}" | awk '{print $1}') + cuda_sum=$(sha256sum "${cuda_file}" | awk '{print $1}') + + if [ "${mlx_sum}" = "${cuda_sum}" ]; then + echo " Deduplicating ${filename}" + # Calculate relative path from mlx_dir to cuda_dir + rel_path="../cuda_${cuda_version}/${filename}" + rm -f "${mlx_file}" + ln -s "${rel_path}" "${mlx_file}" + fi + done + done +} + +# Run deduplication for each platform output directory +if echo $PLATFORM | grep "," > /dev/null ; then + deduplicate_cuda_libs "./dist/linux_amd64" + deduplicate_cuda_libs "./dist/linux_arm64" +elif echo $PLATFORM | grep "amd64\|arm64" > /dev/null ; then + deduplicate_cuda_libs "./dist" +fi + # buildx behavior changes for single vs. multiplatform echo "Compressing linux tar bundles..." if echo $PLATFORM | grep "," > /dev/null ; then diff --git a/server/create.go b/server/create.go index 15e364e1e..0944f7685 100644 --- a/server/create.go +++ b/server/create.go @@ -26,6 +26,7 @@ import ( "github.com/ollama/ollama/convert" "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" + ofs "github.com/ollama/ollama/fs" "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/template" "github.com/ollama/ollama/types/errtypes" @@ -454,7 +455,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is return layers, nil } -func kvFromLayers(baseLayers []*layerGGML) (ggml.KV, error) { +func kvFromLayers(baseLayers []*layerGGML) (ofs.Config, error) { for _, l := range baseLayers { if l.GGML != nil { return l.KV(), nil diff --git a/server/routes_create_test.go b/server/routes_create_test.go index b1b1a2882..3d2ac3b5d 100644 --- a/server/routes_create_test.go +++ b/server/routes_create_test.go @@ -22,6 +22,7 @@ import ( gocmpopts "github.com/google/go-cmp/cmp/cmpopts" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/convert" "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/types/model" @@ -41,7 +42,7 @@ func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string, } defer f.Close() - base := map[string]any{"general.architecture": "test"} + var base convert.KV = map[string]any{"general.architecture": "test"} maps.Copy(base, kv) if err := ggml.WriteGGUF(f, base, ti); err != nil { diff --git a/x/README.md b/x/README.md new file mode 100644 index 000000000..56791bbed --- /dev/null +++ b/x/README.md @@ -0,0 +1,24 @@ +# Experimental Features + +## MLX Backend + +We're working on a new experimental backend based on the [MLX project](https://github.com/ml-explore/mlx) + +Support is currently limited to MacOS and Linux with CUDA GPUs. We're looking to add support for Windows CUDA soon, and other GPU vendors. To build: + +``` +cmake --preset MLX +cmake --build --preset MLX --parallel +cmake --install --component MLX +go build -tags mlx . +``` + +On linux, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with the default Ollama NVIDIA GPU architectures enabled. + +## Image Generation + +Based on the experimental MLX backend, we're working on adding imagegen support. After running the cmake commands above: + +``` +go build -o imagegen ./x/imagegen/cmd/engine +``` diff --git a/x/imagegen/.gitignore b/x/imagegen/.gitignore new file mode 100644 index 000000000..2b00c701a --- /dev/null +++ b/x/imagegen/.gitignore @@ -0,0 +1,38 @@ +# Build directories +build/ +dist/ + +# CMake +CMakeCache.txt +CMakeFiles/ +cmake_install.cmake +Makefile +*.cmake + +# IDE +.idea/ +.vscode/ +*.swp +*.swo +*~ + +# macOS +.DS_Store +*.dSYM/ + +# Go +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Python +*.npy + +/engine +weights +outputs + +prompt.txt +negative.txt diff --git a/x/imagegen/README.md b/x/imagegen/README.md new file mode 100644 index 000000000..e68f295b8 --- /dev/null +++ b/x/imagegen/README.md @@ -0,0 +1,61 @@ +# imagegen + +This is a package that uses MLX to run image generation models, ahead of being integrated into Ollama's primary runner. +in `CMakeLists.txt` and rebuild. + +### 1. Download a Model + +Download Llama 3.1 8B (or any compatible model) in safetensors format: + +```bash +mkdir -p ./weights + +# Example using huggingface-cli +hf download meta-llama/Llama-3.1-8B --local-dir ./weights/Llama-3.1-8B +hf download openai/gpt-oss-20b --local-dir ./weights/gpt-oss-20b +``` + +### 2. Run Inference + +```bash +# Build +go build ./cmd/engine + +# Text generation +./engine -model ./weights/Llama-3.1-8B -prompt "Hello, world!" -max-tokens 250 + +# Qwen-Image 2512 (text-to-image) +./engine -qwen-image -model ./weights/Qwen-Image-2512 -prompt "A mountain landscape at sunset" \ + -width 1024 -height 1024 -steps 20 -seed 42 -output landscape.png + +# Qwen-Image Edit (experimental) - 8 steps for speed, but model recommends 50 +./engine -qwen-image-edit -model ./weights/Qwen-Image-Edit-2511 \ + -input-image input.png -prompt "Make it winter" -negative-prompt " " -cfg-scale 4.0 \ + -steps 8 -seed 42 -output edited.png +``` + +## Memory Management + +MLX Python/C++ uses scope-based memory management - arrays are freed when they go out of scope. Go's garbage collector is non-deterministic, so we can't rely on finalizers to free GPU memory promptly. + +Instead, arrays are automatically tracked and freed on `Eval()`: + +```go +// All arrays are automatically tracked when created +x := mlx.Add(a, b) +y := mlx.Matmul(x, w) + +// Eval frees non-kept arrays, evaluates outputs (auto-kept) +mlx.Eval(y) + +// After copying to CPU, free the array +data := y.Data() +y.Free() +``` + +Key points: + +- All created arrays are automatically tracked +- `mlx.Eval(outputs...)` frees non-kept arrays, evaluates outputs (outputs auto-kept) +- `mlx.Keep(arrays...)` marks arrays to survive multiple Eval cycles (for weights, caches) +- Call `.Free()` when done with an array diff --git a/x/imagegen/cache/cache.go b/x/imagegen/cache/cache.go new file mode 100644 index 000000000..4faa2412e --- /dev/null +++ b/x/imagegen/cache/cache.go @@ -0,0 +1,156 @@ +//go:build mlx + +package cache + +import "github.com/ollama/ollama/x/imagegen/mlx" + +type Cache interface { + Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) + Offset() int + Len() int + State() []*mlx.Array +} + +type KVCache struct { + keys, values *mlx.Array + offset int + step int +} + +func NewKVCache() *KVCache { + return &KVCache{step: 256} +} + +func (c *KVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) { + prev := c.offset + shape := k.Shape() + B, H, Dk := shape[0], shape[1], shape[3] + Dv := v.Shape()[3] + + // Grow buffer if needed + if c.keys == nil || (prev+seqLen) > int(c.keys.Shape()[2]) { + nSteps := (c.step + seqLen - 1) / c.step + newK := mlx.Zeros([]int32{B, H, int32(nSteps * c.step), Dk}, k.Dtype()) + newV := mlx.Zeros([]int32{B, H, int32(nSteps * c.step), Dv}, v.Dtype()) + + if c.keys != nil { + if prev%c.step != 0 { + c.keys = mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dk}) + c.values = mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dv}) + } + c.keys = mlx.Concatenate([]*mlx.Array{c.keys, newK}, 2) + c.values = mlx.Concatenate([]*mlx.Array{c.values, newV}, 2) + } else { + c.keys, c.values = newK, newV + } + } + + c.offset += seqLen + c.keys = mlx.SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dk}) + c.values = mlx.SliceUpdateInplace(c.values, v, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dv}) + + return mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dk}), + mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dv}) +} + +func (c *KVCache) State() []*mlx.Array { + if c.keys == nil { + return nil + } + return []*mlx.Array{c.keys, c.values} +} + +func (c *KVCache) Offset() int { return c.offset } +func (c *KVCache) Len() int { return c.offset } + +// RotatingKVCache implements sliding window attention with bounded memory +type RotatingKVCache struct { + keys, values *mlx.Array + offset int + maxSize int + step int + idx int +} + +func NewRotatingKVCache(maxSize int) *RotatingKVCache { + return &RotatingKVCache{maxSize: maxSize, step: 256} +} + +func (c *RotatingKVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) { + if seqLen > 1 { + return c.updateConcat(k, v, seqLen) + } + return c.updateInPlace(k, v) +} + +func (c *RotatingKVCache) updateInPlace(k, v *mlx.Array) (*mlx.Array, *mlx.Array) { + shape := k.Shape() + B, H, Dk := shape[0], shape[1], shape[3] + Dv := v.Shape()[3] + + // Grow buffer if not yet at max + if c.keys == nil || (c.idx >= int(c.keys.Shape()[2]) && int(c.keys.Shape()[2]) < c.maxSize) { + var cap int + if c.keys != nil { + cap = int(c.keys.Shape()[2]) + } + newSize := min(c.step, c.maxSize-cap) + newK := mlx.Zeros([]int32{B, H, int32(newSize), Dk}, k.Dtype()) + newV := mlx.Zeros([]int32{B, H, int32(newSize), Dv}, v.Dtype()) + if c.keys != nil { + c.keys = mlx.Concatenate([]*mlx.Array{c.keys, newK}, 2) + c.values = mlx.Concatenate([]*mlx.Array{c.values, newV}, 2) + } else { + c.keys, c.values = newK, newV + } + } + + // Rotate when hitting max + if c.idx >= c.maxSize { + c.idx = 0 + } + + c.keys = mlx.SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dk}) + c.values = mlx.SliceUpdateInplace(c.values, v, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dv}) + + c.offset++ + c.idx++ + + validLen := int32(min(c.offset, c.maxSize)) + return mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dk}), + mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dv}) +} + +func (c *RotatingKVCache) updateConcat(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) { + shape := k.Shape() + B, H, Dk := shape[0], shape[1], shape[3] + Dv := v.Shape()[3] + + if c.keys == nil { + c.keys, c.values = k, v + } else { + c.keys = mlx.Concatenate([]*mlx.Array{c.keys, k}, 2) + c.values = mlx.Concatenate([]*mlx.Array{c.values, v}, 2) + } + c.offset += seqLen + + // Trim to max_size to maintain sliding window + cap := int(c.keys.Shape()[2]) + if trim := cap - c.maxSize; trim > 0 { + c.keys = mlx.Slice(c.keys, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dk}) + c.values = mlx.Slice(c.values, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dv}) + } + + c.idx = int(c.keys.Shape()[2]) + return c.keys, c.values +} + +func (c *RotatingKVCache) State() []*mlx.Array { + if c.keys == nil { + return nil + } + return []*mlx.Array{c.keys, c.values} +} + +func (c *RotatingKVCache) Offset() int { return c.offset } +func (c *RotatingKVCache) Len() int { return min(c.offset, c.maxSize) } diff --git a/x/imagegen/cache/step.go b/x/imagegen/cache/step.go new file mode 100644 index 000000000..830df447f --- /dev/null +++ b/x/imagegen/cache/step.go @@ -0,0 +1,164 @@ +//go:build mlx + +package cache + +import "github.com/ollama/ollama/x/imagegen/mlx" + +// StepCache caches layer outputs across diffusion denoising steps. +// Based on DeepCache (CVPR 2024) and Learning-to-Cache (NeurIPS 2024): +// shallow layers change little between consecutive steps, so we can +// cache their outputs and skip recomputation on non-refresh steps. +// +// Supports both single-stream (Z-Image) and dual-stream (Qwen-Image) architectures: +// - Single-stream: use Get/Set for the single output per layer +// - Dual-stream: use Get/Set for stream 1 (imgH), Get2/Set2 for stream 2 (txtH) +// +// Usage (single-stream): +// +// cache := NewStepCache(15) // cache first 15 layers +// for step := 0; step < numSteps; step++ { +// refresh := cache.ShouldRefresh(step, 3) // refresh every 3 steps +// for i, layer := range layers { +// if i < 15 && !refresh && cache.Get(i) != nil { +// output = cache.Get(i) // reuse cached +// } else { +// output = layer.Forward(input) +// if i < 15 && refresh { +// cache.Set(i, output) +// } +// } +// } +// } +// cache.Free() // cleanup when done +// +// Usage (dual-stream): +// +// cache := NewStepCache(15) +// for step := 0; step < numSteps; step++ { +// refresh := cache.ShouldRefresh(step, 3) +// for i, layer := range layers { +// if i < 15 && !refresh && cache.Get(i) != nil { +// imgH, txtH = cache.Get(i), cache.Get2(i) +// } else { +// imgH, txtH = layer.Forward(imgH, txtH, ...) +// if i < 15 && refresh { +// cache.Set(i, imgH) +// cache.Set2(i, txtH) +// } +// } +// } +// } +type StepCache struct { + layers []*mlx.Array // cached layer outputs (stream 1) + layers2 []*mlx.Array // cached layer outputs (stream 2, for dual-stream models) + constant *mlx.Array // optional constant (e.g., text embeddings) +} + +// NewStepCache creates a cache for the given number of layers. +func NewStepCache(numLayers int) *StepCache { + return &StepCache{ + layers: make([]*mlx.Array, numLayers), + layers2: make([]*mlx.Array, numLayers), + } +} + +// ShouldRefresh returns true if the cache should be refreshed at this step. +// Refresh happens on step 0, interval, 2*interval, etc. +func (c *StepCache) ShouldRefresh(step, interval int) bool { + return step%interval == 0 +} + +// Get returns the cached output for a layer, or nil if not cached. +func (c *StepCache) Get(layer int) *mlx.Array { + if layer < len(c.layers) { + return c.layers[layer] + } + return nil +} + +// Set stores a layer output (stream 1), freeing any previous value. +func (c *StepCache) Set(layer int, arr *mlx.Array) { + if layer < len(c.layers) { + if c.layers[layer] != nil { + c.layers[layer].Free() + } + c.layers[layer] = arr + } +} + +// Get2 returns the cached output for a layer (stream 2), or nil if not cached. +// Used for dual-stream architectures like Qwen-Image. +func (c *StepCache) Get2(layer int) *mlx.Array { + if layer < len(c.layers2) { + return c.layers2[layer] + } + return nil +} + +// Set2 stores a layer output (stream 2), freeing any previous value. +// Used for dual-stream architectures like Qwen-Image. +func (c *StepCache) Set2(layer int, arr *mlx.Array) { + if layer < len(c.layers2) { + if c.layers2[layer] != nil { + c.layers2[layer].Free() + } + c.layers2[layer] = arr + } +} + +// GetConstant returns the cached constant value. +func (c *StepCache) GetConstant() *mlx.Array { + return c.constant +} + +// SetConstant stores a constant value, freeing any previous value. +func (c *StepCache) SetConstant(arr *mlx.Array) { + if c.constant != nil { + c.constant.Free() + } + c.constant = arr +} + +// Arrays returns all non-nil cached arrays (for pool.Keep). +func (c *StepCache) Arrays() []*mlx.Array { + var result []*mlx.Array + if c.constant != nil { + result = append(result, c.constant) + } + for _, arr := range c.layers { + if arr != nil { + result = append(result, arr) + } + } + for _, arr := range c.layers2 { + if arr != nil { + result = append(result, arr) + } + } + return result +} + +// Free releases all cached arrays. Call when generation completes. +func (c *StepCache) Free() { + if c.constant != nil { + c.constant.Free() + c.constant = nil + } + for i, arr := range c.layers { + if arr != nil { + arr.Free() + c.layers[i] = nil + } + } + for i, arr := range c.layers2 { + if arr != nil { + arr.Free() + c.layers2[i] = nil + } + } +} + +// NumLayers returns the number of layers this cache can store. +func (c *StepCache) NumLayers() int { + return len(c.layers) +} diff --git a/x/imagegen/cmd/engine/generate.go b/x/imagegen/cmd/engine/generate.go new file mode 100644 index 000000000..506a48c54 --- /dev/null +++ b/x/imagegen/cmd/engine/generate.go @@ -0,0 +1,359 @@ +//go:build mlx + +package main + +import ( + "context" + "fmt" + "time" + "unicode/utf8" + + "github.com/ollama/ollama/x/imagegen/cache" + "github.com/ollama/ollama/x/imagegen/mlx" + "github.com/ollama/ollama/x/imagegen/tokenizer" +) + +// Dedicated stream for generation (like mlx-lm's generation_stream) +var generationStream *mlx.Stream + +// utf8Streamer buffers decoded text and emits only complete UTF-8 characters. +// This handles cases where tokenizers output partial multi-byte sequences. +type utf8Streamer struct { + buffer []byte +} + +// Write adds decoded text to the buffer and returns complete UTF-8 characters. +func (s *utf8Streamer) Write(text string) string { + s.buffer = append(s.buffer, text...) + + // Find the last position that ends with a complete UTF-8 character + validLen := 0 + for i := 0; i < len(s.buffer); { + r, size := utf8.DecodeRune(s.buffer[i:]) + if r == utf8.RuneError && size == 1 { + // Invalid or incomplete UTF-8 sequence at this position + // Check if it could be a valid start of a multi-byte sequence + if len(s.buffer)-i < 4 { + // Might be incomplete, keep it in buffer + break + } + // Definitely invalid, skip this byte + i++ + validLen = i + } else { + i += size + validLen = i + } + } + + if validLen == 0 { + return "" + } + + result := string(s.buffer[:validLen]) + s.buffer = s.buffer[validLen:] + return result +} + +// Flush returns any remaining buffered bytes (may be incomplete UTF-8). +func (s *utf8Streamer) Flush() string { + if len(s.buffer) == 0 { + return "" + } + result := string(s.buffer) + s.buffer = nil + return result +} + +func init() { + generationStream = mlx.NewStream() +} + +// withStream runs fn with the generation stream as default +func withStream(fn func()) { + orig := mlx.GetDefaultStream() + mlx.SetDefaultStream(generationStream) + fn() + mlx.SetDefaultStream(orig) +} + +type Model interface { + Tokenizer() *tokenizer.Tokenizer + VocabSize() int32 + NewCache(maxSeqLen int32) []cache.Cache + Forward(input *mlx.Array, caches []cache.Cache) *mlx.Array +} + +// ChatModel is an optional interface for models that support chat formatting +type ChatModel interface { + FormatPrompt(prompt string) string +} + +// MultimodalModel is for models that support image input +type MultimodalModel interface { + Model + FormatPromptWithImage(prompt string) string + ExpandImageTokens(tokens []int32) []int32 + ForwardWithImage(tokens *mlx.Array, image *mlx.Array, caches []cache.Cache) *mlx.Array + ImageSize() int32 // Returns expected image size for preprocessing +} + +// ImageLoader loads and preprocesses an image for multimodal models +// Returns nil if path is empty +type ImageLoader func(path string, imageSize int32) (*mlx.Array, error) + +type input struct { + Prompt string + Image *mlx.Array // Optional preprocessed image for multimodal models + MaxTokens int + Temperature float32 + TopP float32 + TopK int + WiredLimitGB int // Metal wired memory limit in GB (default 32) +} + +type output struct { + Text string + Done bool + PrefillTokSec float64 + GenTokSec float64 +} + +// Decoder wraps model + cache for autoregressive generation. +type Decoder struct { + model Model + caches []cache.Cache + vocabSize int32 + temp float32 + topK int + topP float32 + token *mlx.Array // Current token (kept across pools) + oldCacheState []*mlx.Array // Preallocated slice for old cache state + image *mlx.Array // Optional image for multimodal prefill +} + +func NewDecoder(m Model, temp float32, topK int, topP float32) *Decoder { + caches := m.NewCache(0) + return &Decoder{ + model: m, + caches: caches, + vocabSize: m.VocabSize(), + temp: temp, + topK: topK, + topP: topP, + oldCacheState: make([]*mlx.Array, 0, len(caches)*2), + } +} + +// SetImage sets the image for multimodal prefill (call before prefill) +func (d *Decoder) SetImage(img *mlx.Array) { + d.image = img +} + +func (d *Decoder) prefill(inputIDs []int32) int { + processed := 0 + + // Track old cache state to free after each chunk + var oldCacheState []*mlx.Array + + // For multimodal models with an image, we need to process all tokens together + // in the first forward pass so the image embeddings can be inserted properly. + // Skip chunking for multimodal prefill. + isMultimodal := d.image != nil + + // Process all-but-1 tokens in chunks, eval cache state for memory management + // Skip chunking for multimodal - process everything in the final step + if !isMultimodal { + for len(inputIDs) > 1 { + chunkSize := min(2048, len(inputIDs)-1) + if chunkSize <= 0 { + break + } + chunk := inputIDs[:chunkSize] + + // Save old cache state before forward + oldCacheState = oldCacheState[:0] + for _, c := range d.caches { + oldCacheState = append(oldCacheState, c.State()...) + } + + var cacheState []*mlx.Array + withStream(func() { + x := mlx.NewArrayInt32(chunk, []int32{1, int32(len(chunk))}) + d.model.Forward(x, d.caches) + for _, c := range d.caches { + cacheState = append(cacheState, c.State()...) + } + }) + mlx.Eval(cacheState...) + + // Free old cache state + for _, arr := range oldCacheState { + if arr != nil { + arr.Free() + } + } + + inputIDs = inputIDs[chunkSize:] + processed += chunkSize + } + } + + // Save old cache state before final step + oldCacheState = oldCacheState[:0] + for _, c := range d.caches { + oldCacheState = append(oldCacheState, c.State()...) + } + + // Final token + sampling (or all tokens for multimodal) + withStream(func() { + x := mlx.NewArrayInt32(inputIDs, []int32{1, int32(len(inputIDs))}) + mlx.Eval(x) // Materialize before any other evals + + var logits *mlx.Array + // Use ForwardWithImage if we have an image and model supports it + if d.image != nil { + if mm, ok := d.model.(MultimodalModel); ok { + logits = mm.ForwardWithImage(x, d.image, d.caches) + d.image = nil // Only use image for first forward + } else { + logits = d.model.Forward(x, d.caches) + } + } else { + logits = d.model.Forward(x, d.caches) + } + d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize) + }) + // Keep cache state (token auto-kept by AsyncEval) + for _, c := range d.caches { + mlx.Keep(c.State()...) + } + mlx.AsyncEval(d.token) + + // Free old cache state from before final step + for _, arr := range oldCacheState { + if arr != nil { + arr.Free() + } + } + + mlx.ClearCache() + + return processed + len(inputIDs) +} + +func (d *Decoder) step() int32 { + prevToken := d.token + + // Save old cache state (reuse preallocated slice) + d.oldCacheState = d.oldCacheState[:0] + for _, c := range d.caches { + d.oldCacheState = append(d.oldCacheState, c.State()...) + } + + withStream(func() { + logits := d.model.Forward(mlx.Reshape(prevToken, 1, 1), d.caches) + d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize) + }) + // Keep token and new cache state so they survive cleanup + mlx.Keep(d.token) + for _, c := range d.caches { + mlx.Keep(c.State()...) + } + mlx.AsyncEval(d.token) + + // Sync on previous token (GPU already working on next step) + val := prevToken.ItemInt32() + + // Free old token and old cache state + prevToken.Free() + for _, arr := range d.oldCacheState { + arr.Free() + } + return val +} + +func generate(ctx context.Context, m Model, in input, cb func(output)) error { + mlx.EnableCompile() + wiredLimit := in.WiredLimitGB + if wiredLimit <= 0 { + wiredLimit = 32 // default 32GB + } + mlx.MetalSetWiredLimit(uint64(wiredLimit) << 30) + + temp := in.Temperature + if temp < 0 { + temp = 0.7 + } + + tok := m.Tokenizer() + dec := NewDecoder(m, temp, in.TopK, in.TopP) + + // Apply chat template - use image template if we have an image + prompt := in.Prompt + var tokens []int32 + if mm, ok := m.(MultimodalModel); ok && in.Image != nil { + prompt = mm.FormatPromptWithImage(prompt) + tokens = tok.Encode(prompt, true) + tokens = mm.ExpandImageTokens(tokens) // Expand to 256 image tokens + dec.SetImage(in.Image) + } else if cm, ok := m.(ChatModel); ok { + prompt = cm.FormatPrompt(prompt) + tokens = tok.Encode(prompt, true) + } else { + tokens = tok.Encode(prompt, true) + } + + prefillStart := time.Now() + prefillTokens := dec.prefill(tokens) + // Prefill measurement should include time to first token (like mlx-lm) + // Step() waits for prefill to complete and returns first token + firstToken := dec.step() + prefillTokSec := float64(prefillTokens) / time.Since(prefillStart).Seconds() + + genStart := time.Now() + maxTokens := max(in.MaxTokens, 100) + var genTokens int + + // UTF-8 streamer to handle partial multi-byte characters + streamer := &utf8Streamer{} + + // Handle first token + genTokens++ + if tok.IsEOS(firstToken) { + cb(output{Done: true, PrefillTokSec: prefillTokSec, GenTokSec: 0}) + return nil + } + if text := streamer.Write(tok.Decode([]int32{firstToken})); text != "" { + cb(output{Text: text}) + } + + for n := 1; n < maxTokens; n++ { + if ctx.Err() != nil { + return ctx.Err() + } + token := dec.step() + genTokens++ + + if tok.IsEOS(token) { + break + } + if text := streamer.Write(tok.Decode([]int32{token})); text != "" { + cb(output{Text: text}) + } + + if n%256 == 0 { + mlx.ClearCache() + } + } + + // Flush any remaining buffered bytes + if text := streamer.Flush(); text != "" { + cb(output{Text: text}) + } + + fmt.Printf("\nPeak memory: %.2fGB\n", float64(mlx.MetalGetPeakMemory())/(1<<30)) + cb(output{Done: true, PrefillTokSec: prefillTokSec, + GenTokSec: float64(genTokens) / time.Since(genStart).Seconds()}) + return nil +} diff --git a/x/imagegen/cmd/engine/image.go b/x/imagegen/cmd/engine/image.go new file mode 100644 index 000000000..e8af2222a --- /dev/null +++ b/x/imagegen/cmd/engine/image.go @@ -0,0 +1,89 @@ +//go:build mlx + +package main + +import ( + "fmt" + "image" + "image/png" + "os" + "path/filepath" + + "github.com/ollama/ollama/x/imagegen/mlx" +) + +// saveImageArray saves an MLX array as a PNG image. +// Expected format: [B, C, H, W] with values in [0, 1] range and C=3 (RGB). +func saveImageArray(arr *mlx.Array, path string) error { + img, err := arrayToImage(arr) + if err != nil { + return err + } + return savePNG(img, path) +} + +func savePNG(img *image.RGBA, path string) error { + if filepath.Ext(path) != ".png" { + path = path + ".png" + } + f, err := os.Create(path) + if err != nil { + return err + } + defer f.Close() + return png.Encode(f, img) +} + +func arrayToImage(arr *mlx.Array) (*image.RGBA, error) { + shape := arr.Shape() + if len(shape) != 4 { + return nil, fmt.Errorf("expected 4D array [B, C, H, W], got %v", shape) + } + + // Transform to [H, W, C] for image conversion + img := mlx.Squeeze(arr, 0) + arr.Free() + img = mlx.Transpose(img, 1, 2, 0) + img = mlx.Contiguous(img) + mlx.Eval(img) + + imgShape := img.Shape() + H := int(imgShape[0]) + W := int(imgShape[1]) + C := int(imgShape[2]) + + if C != 3 { + img.Free() + return nil, fmt.Errorf("expected 3 channels (RGB), got %d", C) + } + + // Copy to CPU and free GPU memory + data := img.Data() + img.Free() + + // Write directly to Pix slice (faster than SetRGBA) + goImg := image.NewRGBA(image.Rect(0, 0, W, H)) + pix := goImg.Pix + for y := 0; y < H; y++ { + for x := 0; x < W; x++ { + srcIdx := (y*W + x) * C + dstIdx := (y*W + x) * 4 + pix[dstIdx+0] = uint8(clampF(data[srcIdx+0]*255+0.5, 0, 255)) + pix[dstIdx+1] = uint8(clampF(data[srcIdx+1]*255+0.5, 0, 255)) + pix[dstIdx+2] = uint8(clampF(data[srcIdx+2]*255+0.5, 0, 255)) + pix[dstIdx+3] = 255 + } + } + + return goImg, nil +} + +func clampF(v, min, max float32) float32 { + if v < min { + return min + } + if v > max { + return max + } + return v +} diff --git a/x/imagegen/cmd/engine/main.go b/x/imagegen/cmd/engine/main.go new file mode 100644 index 000000000..f2fca5450 --- /dev/null +++ b/x/imagegen/cmd/engine/main.go @@ -0,0 +1,286 @@ +//go:build mlx + +package main + +import ( + "context" + "encoding/json" + "flag" + "fmt" + "log" + "os" + "path/filepath" + "runtime/pprof" + + "github.com/ollama/ollama/x/imagegen/mlx" + "github.com/ollama/ollama/x/imagegen/models/gemma3" + "github.com/ollama/ollama/x/imagegen/models/gpt_oss" + "github.com/ollama/ollama/x/imagegen/models/llama" + "github.com/ollama/ollama/x/imagegen/models/qwen_image" + "github.com/ollama/ollama/x/imagegen/models/qwen_image_edit" + "github.com/ollama/ollama/x/imagegen/models/zimage" + "github.com/ollama/ollama/x/imagegen/safetensors" +) + +// stringSlice is a flag type that accumulates multiple values +type stringSlice []string + +func (s *stringSlice) String() string { + return fmt.Sprintf("%v", *s) +} + +func (s *stringSlice) Set(value string) error { + *s = append(*s, value) + return nil +} + +func main() { + modelPath := flag.String("model", "", "Model directory") + prompt := flag.String("prompt", "Hello", "Prompt") + + // Text generation params + maxTokens := flag.Int("max-tokens", 100, "Max tokens") + temperature := flag.Float64("temperature", 0.7, "Temperature") + topP := flag.Float64("top-p", 0.9, "Top-p sampling") + topK := flag.Int("top-k", 40, "Top-k sampling") + imagePath := flag.String("image", "", "Image path for multimodal models") + + // Image generation params + width := flag.Int("width", 1024, "Image width") + height := flag.Int("height", 1024, "Image height") + steps := flag.Int("steps", 9, "Denoising steps") + seed := flag.Int64("seed", 42, "Random seed") + out := flag.String("output", "output.png", "Output path") + + // Utility flags + listTensors := flag.Bool("list", false, "List tensors only") + cpuProfile := flag.String("cpuprofile", "", "Write CPU profile to file") + gpuCapture := flag.String("gpu-capture", "", "Capture GPU trace to .gputrace file (run with MTL_CAPTURE_ENABLED=1)") + layerCache := flag.Bool("layer-cache", false, "Enable layer caching for faster diffusion (Z-Image, Qwen-Image). Not compatible with CFG/negative prompts.") + wiredLimitGB := flag.Int("wired-limit", 32, "Metal wired memory limit in GB") + + // Legacy mode flags + zimageFlag := flag.Bool("zimage", false, "Z-Image generation") + qwenImage := flag.Bool("qwen-image", false, "Qwen-Image text-to-image generation") + qwenImageEdit := flag.Bool("qwen-image-edit", false, "Qwen-Image-Edit image editing") + var inputImages stringSlice + flag.Var(&inputImages, "input-image", "Input image for image editing (can be specified multiple times)") + negativePrompt := flag.String("negative-prompt", "", "Negative prompt for CFG (empty = no CFG, matching Python)") + cfgScale := flag.Float64("cfg-scale", 4.0, "CFG scale for image editing") + + flag.Parse() + + if *modelPath == "" { + flag.Usage() + return + } + + // CPU profiling + if *cpuProfile != "" { + f, err := os.Create(*cpuProfile) + if err != nil { + log.Fatal(err) + } + defer f.Close() + if err := pprof.StartCPUProfile(f); err != nil { + log.Fatal(err) + } + defer pprof.StopCPUProfile() + } + + var err error + + // Handle legacy mode flags that aren't unified yet + switch { + case *zimageFlag: + m := &zimage.Model{} + if loadErr := m.Load(*modelPath); loadErr != nil { + log.Fatal(loadErr) + } + var img *mlx.Array + img, err = m.GenerateFromConfig(&zimage.GenerateConfig{ + Prompt: *prompt, + Width: int32(*width), + Height: int32(*height), + Steps: *steps, + Seed: *seed, + CapturePath: *gpuCapture, + LayerCache: *layerCache, + }) + if err == nil { + err = saveImageArray(img, *out) + } + case *qwenImage: + m, loadErr := qwen_image.LoadPersistent(*modelPath) + if loadErr != nil { + log.Fatal(loadErr) + } + var img *mlx.Array + img, err = m.GenerateFromConfig(&qwen_image.GenerateConfig{ + Prompt: *prompt, + NegativePrompt: *negativePrompt, + CFGScale: float32(*cfgScale), + Width: int32(*width), + Height: int32(*height), + Steps: *steps, + Seed: *seed, + LayerCache: *layerCache, + }) + if err == nil { + err = saveImageArray(img, *out) + } + case *qwenImageEdit: + if len(inputImages) == 0 { + log.Fatal("qwen-image-edit requires at least one -input-image") + } + + m, loadErr := qwen_image_edit.LoadPersistent(*modelPath) + if loadErr != nil { + log.Fatal(loadErr) + } + // For image editing, use 0 for dimensions to auto-detect from input image + // unless explicitly overridden from defaults + editWidth := int32(0) + editHeight := int32(0) + if *width != 1024 { + editWidth = int32(*width) + } + if *height != 1024 { + editHeight = int32(*height) + } + + cfg := &qwen_image_edit.GenerateConfig{ + Prompt: *prompt, + NegativePrompt: *negativePrompt, + CFGScale: float32(*cfgScale), + Width: editWidth, + Height: editHeight, + Steps: *steps, + Seed: *seed, + } + + var img *mlx.Array + img, err = m.EditFromConfig(inputImages, cfg) + if err == nil { + err = saveImageArray(img, *out) + } + case *listTensors: + err = listModelTensors(*modelPath) + default: + // llm path + m, err := load(*modelPath) + if err != nil { + log.Fatal(err) + } + + // Load image if provided and model supports it + var image *mlx.Array + if *imagePath != "" { + if mm, ok := m.(interface{ ImageSize() int32 }); ok { + image, err = gemma3.ProcessImage(*imagePath, mm.ImageSize()) + if err != nil { + log.Fatal("load image:", err) + } + } else { + log.Fatal("model does not support image input") + } + } + + err = generate(context.Background(), m, input{ + Prompt: *prompt, + Image: image, + MaxTokens: *maxTokens, + Temperature: float32(*temperature), + TopP: float32(*topP), + TopK: *topK, + WiredLimitGB: *wiredLimitGB, + }, func(out output) { + if out.Text != "" { + fmt.Print(out.Text) + } + if out.Done { + fmt.Printf("\n\n[prefill: %.1f tok/s, gen: %.1f tok/s]\n", out.PrefillTokSec, out.GenTokSec) + } + }) + } + + if err != nil { + log.Fatal(err) + } +} + +func listModelTensors(modelPath string) error { + weights, err := safetensors.LoadModelWeights(modelPath) + if err != nil { + return err + } + for _, name := range weights.ListTensors() { + info, _ := weights.GetTensorInfo(name) + fmt.Printf("%s: %v (%s)\n", name, info.Shape, info.Dtype) + } + return nil +} + +// loadModel builds and evaluates a model using the common load pattern. +// Release safetensors BEFORE eval - lazy arrays have captured their data, +// and this reduces peak memory by ~6GB (matches mlx-lm behavior). +func loadModel[T Model](build func() T, cleanup func()) T { + m := build() + weights := mlx.Collect(m) + cleanup() + mlx.Eval(weights...) + return m +} + +func load(modelPath string) (Model, error) { + kind, err := detectModelKind(modelPath) + if err != nil { + return nil, fmt.Errorf("detect model kind: %w", err) + } + + switch kind { + case "gpt_oss": + return gpt_oss.Load(modelPath) + case "gemma3": + return gemma3.Load(modelPath) + case "gemma3_text": + return gemma3.LoadText(modelPath) + default: + return llama.Load(modelPath) + } +} + +func detectModelKind(modelPath string) (string, error) { + indexPath := filepath.Join(modelPath, "model_index.json") + if _, err := os.Stat(indexPath); err == nil { + data, err := os.ReadFile(indexPath) + if err != nil { + return "zimage", nil + } + var index struct { + ClassName string `json:"_class_name"` + } + if err := json.Unmarshal(data, &index); err == nil { + switch index.ClassName { + case "FluxPipeline", "ZImagePipeline": + return "zimage", nil + } + } + return "zimage", nil + } + + configPath := filepath.Join(modelPath, "config.json") + data, err := os.ReadFile(configPath) + if err != nil { + return "", fmt.Errorf("no config.json or model_index.json found: %w", err) + } + + var cfg struct { + ModelType string `json:"model_type"` + } + if err := json.Unmarshal(data, &cfg); err != nil { + return "", fmt.Errorf("parse config.json: %w", err) + } + + return cfg.ModelType, nil +} diff --git a/x/imagegen/cmd/engine/sample.go b/x/imagegen/cmd/engine/sample.go new file mode 100644 index 000000000..5d723e6dc --- /dev/null +++ b/x/imagegen/cmd/engine/sample.go @@ -0,0 +1,49 @@ +//go:build mlx + +package main + +import "github.com/ollama/ollama/x/imagegen/mlx" + +// sampleTopK samples from top-k logits using global random state +func sampleTopK(scaledLogits *mlx.Array, k int) *mlx.Array { + neg := mlx.Neg(scaledLogits) + indices := mlx.Argpartition(neg, k-1, -1) + topKIdx := mlx.Slice(indices, []int32{0}, []int32{int32(k)}) + values := mlx.TakeAlongAxis(scaledLogits, topKIdx, -1) + sampled := mlx.RandomCategorical(values, -1, 1) + return mlx.Take(topKIdx, sampled, -1) +} + +// sampleTopP samples using nucleus sampling with global random state +func sampleTopP(scaledLogits *mlx.Array, p float32, vocabSize int32) *mlx.Array { + sorted := mlx.Argsort(mlx.Neg(scaledLogits), -1) + sortedLogits := mlx.TakeAlongAxis(scaledLogits, sorted, -1) + probs := mlx.Softmax(sortedLogits, -1) + cumProbs := mlx.Cumsum(probs, -1) + mask := mlx.LessScalar(cumProbs, p) + negInf := mlx.FullDtype(float32(-1e9), scaledLogits.Dtype(), vocabSize) + masked := mlx.Where(mask, sortedLogits, negInf) + sampled := mlx.RandomCategorical(masked, -1, 1) + return mlx.Take(sorted, sampled, -1) +} + +// sample samples from logits at the last position +func sample(logits *mlx.Array, temp float32, topK int, topP float32, vocab int32) *mlx.Array { + // Get last position logits: [1, L, vocab] -> [vocab] + shape := logits.Shape() + seqLen := shape[1] + lastLogits := mlx.Slice(logits, []int32{0, seqLen - 1, 0}, []int32{1, seqLen, vocab}) + lastLogits = mlx.Reshape(lastLogits, vocab) + + if temp == 0 { + return mlx.Argmax(lastLogits, -1, false) + } + scaled := mlx.DivScalar(lastLogits, temp) + if topK > 0 && topK < int(vocab) { + return sampleTopK(scaled, topK) + } + if topP > 0 && topP < 1.0 { + return sampleTopP(scaled, topP, vocab) + } + return mlx.RandomCategorical(scaled, -1, 1) +} diff --git a/x/imagegen/mlx/README.md b/x/imagegen/mlx/README.md new file mode 100644 index 000000000..3a2f7b3d8 --- /dev/null +++ b/x/imagegen/mlx/README.md @@ -0,0 +1,46 @@ +# MLX Memory Management + +| This package will get consolidated with `x/ml/backend/mlx` in the future. + +## Automatic Tracking + +All arrays are automatically tracked when created. On `Eval()`, non-kept arrays are freed. + +### API + +```go +result := mlx.Matmul(x, w) // arrays automatically tracked +mlx.Eval(result) // free non-kept, eval result (auto-kept) +``` + +### Key Functions + +- `mlx.Eval(outputs...)` - free non-kept arrays, then evaluate (outputs auto-kept) +- `mlx.AsyncEval(outputs...)` - async version of Eval (outputs auto-kept) +- `mlx.Keep(arrays...)` - mark arrays to survive cleanup (for weights, caches) +- `array.Free()` - mark array for cleanup on next Eval + +### Loop Pattern + +```go +for step := 0; step < maxTokens; step++ { + logits := model.Forward(token, caches) + oldToken := token + token = sample(logits) + + // Keep cache state across iterations + for _, c := range caches { + mlx.Keep(c.State()...) + } + + oldToken.Free() // mark for cleanup + mlx.AsyncEval(token) // frees old, evals new +} +``` + +### Notes + +- `Eval()` and `AsyncEval()` auto-keep their outputs +- `Free()` marks for cleanup - actual free happens during next Eval +- Use `Keep()` for weights and cache state that must survive multiple Eval cycles +- Arrays created inside compiled closures are managed by MLX, not tracked diff --git a/x/imagegen/mlx/compile.go b/x/imagegen/mlx/compile.go new file mode 100644 index 000000000..36de65c5f --- /dev/null +++ b/x/imagegen/mlx/compile.go @@ -0,0 +1,173 @@ +//go:build mlx + +package mlx + +/* +#include "mlx/c/mlx.h" +#include + +// Forward declaration for Go callback +extern int goClosureCallback(mlx_vector_array* res, mlx_vector_array input, void* payload); + +// Destructor for payload (Go handle) +extern void goClosureDestructor(void* payload); +*/ +import "C" +import ( + "runtime/cgo" + "sync" + "unsafe" +) + +// inClosureCallback is set to true during closure callback execution. +var inClosureCallback bool +var closureCallbackMu sync.Mutex + +// InClosureCallback returns true if we're currently executing inside a closure callback. +func InClosureCallback() bool { + closureCallbackMu.Lock() + defer closureCallbackMu.Unlock() + return inClosureCallback +} + +// CompiledFunc is a compiled MLX function that can be called efficiently. +// All intermediate arrays during execution stay inside MLX - only inputs +// and outputs cross the Go boundary. +type CompiledFunc struct { + closure C.mlx_closure + compiled C.mlx_closure +} + +// ClosureFunc is the signature for functions that can be compiled. +// It takes a slice of input arrays and returns a slice of output arrays. +type ClosureFunc func(inputs []*Array) []*Array + +// Compile compiles a Go function into an optimized MLX closure. +// The function is traced once during compilation, then subsequent calls +// run the optimized graph without creating Go intermediate arrays. +// +// Example: +// +// compiled := mlx.Compile(func(inputs []*mlx.Array) []*mlx.Array { +// a, b := inputs[0], inputs[1] +// c := mlx.Add(a, b) +// d := mlx.Mul(c, c) +// return []*mlx.Array{d} +// }) +// defer compiled.Free() +// +// result := compiled.Call(x, y)[0] +func Compile(fn ClosureFunc) *CompiledFunc { + return CompileShapeless(fn, false) +} + +// CompileShapeless compiles with optional shapeless mode. +// If shapeless=true, the function works for any input shape after tracing. +func CompileShapeless(fn ClosureFunc, shapeless bool) *CompiledFunc { + // Create a cgo.Handle to prevent the Go function from being GC'd + handle := cgo.NewHandle(fn) + + // Create the closure from the Go callback + closure := C.mlx_closure_new_func_payload( + (*[0]byte)(C.goClosureCallback), + unsafe.Pointer(handle), + (*[0]byte)(C.goClosureDestructor), + ) + + // Compile the closure + compiled := C.mlx_closure_new() + C.mlx_compile(&compiled, closure, C.bool(shapeless)) + + return &CompiledFunc{ + closure: closure, + compiled: compiled, + } +} + +// Call invokes the compiled function with the given inputs. +func (cf *CompiledFunc) Call(inputs ...*Array) []*Array { + // Pack inputs into vector + inputVec := C.mlx_vector_array_new() + for _, arr := range inputs { + C.mlx_vector_array_append_value(inputVec, arr.c) + } + + // Apply compiled closure + outputVec := C.mlx_vector_array_new() + C.mlx_closure_apply(&outputVec, cf.compiled, inputVec) + C.mlx_vector_array_free(inputVec) + + // Unpack outputs + numOutputs := int(C.mlx_vector_array_size(outputVec)) + outputs := make([]*Array, numOutputs) + for i := 0; i < numOutputs; i++ { + var arr C.mlx_array + C.mlx_vector_array_get(&arr, outputVec, C.size_t(i)) + outputs[i] = newArray(arr) + } + C.mlx_vector_array_free(outputVec) + + return outputs +} + +// CallEval invokes the compiled function and evaluates the results. +func (cf *CompiledFunc) CallEval(inputs ...*Array) []*Array { + outputs := cf.Call(inputs...) + Eval(outputs...) + return outputs +} + +// Free releases the compiled function resources. +func (cf *CompiledFunc) Free() { + C.mlx_closure_free(cf.compiled) + C.mlx_closure_free(cf.closure) +} + +// borrowArray wraps a C array WITHOUT setting up GC cleanup. +// Use this for arrays we don't own (e.g., borrowed references in callbacks). +func borrowArray(array C.mlx_array) *Array { + return &Array{c: array} +} + +//export goClosureCallback +func goClosureCallback(res *C.mlx_vector_array, input C.mlx_vector_array, payload unsafe.Pointer) C.int { + // Set flag to disable AddCleanup during callback + closureCallbackMu.Lock() + inClosureCallback = true + closureCallbackMu.Unlock() + defer func() { + closureCallbackMu.Lock() + inClosureCallback = false + closureCallbackMu.Unlock() + }() + + // Recover the Go function from the handle + handle := cgo.Handle(payload) + fn := handle.Value().(ClosureFunc) + + // Convert input vector to Go slice - use borrowArray since MLX owns these + numInputs := int(C.mlx_vector_array_size(input)) + inputs := make([]*Array, numInputs) + for i := 0; i < numInputs; i++ { + var arr C.mlx_array + C.mlx_vector_array_get(&arr, input, C.size_t(i)) + inputs[i] = borrowArray(arr) // Don't set up cleanup - MLX owns these + } + + // Call the Go function + outputs := fn(inputs) + + // Build output vector + *res = C.mlx_vector_array_new() + for _, arr := range outputs { + C.mlx_vector_array_append_value(*res, arr.c) + } + + return 0 +} + +//export goClosureDestructor +func goClosureDestructor(payload unsafe.Pointer) { + handle := cgo.Handle(payload) + handle.Delete() +} diff --git a/x/imagegen/mlx/mlx.go b/x/imagegen/mlx/mlx.go new file mode 100644 index 000000000..2c172196c --- /dev/null +++ b/x/imagegen/mlx/mlx.go @@ -0,0 +1,2238 @@ +//go:build mlx + +package mlx + +/* +#cgo CFLAGS: -O3 -I${SRCDIR}/../../../build/_deps/mlx-c-src +#cgo LDFLAGS: -L${SRCDIR}/../../../build/lib/ollama/ -lmlxc -Wl,-rpath,${SRCDIR}/../../../build/lib/ollama/ +#cgo darwin LDFLAGS: -lc++ -framework Metal -framework Foundation -framework Accelerate +#cgo linux LDFLAGS: -lstdc++ -lcuda -lcudart -lnvrtc + +#include "mlx/c/mlx.h" +#include +#include + +// Cached default GPU stream for all ops +static mlx_stream _default_stream = {0}; +static mlx_stream _cpu_stream = {0}; + +static inline mlx_stream default_stream() { + if (_default_stream.ctx == NULL) { + _default_stream = mlx_default_gpu_stream_new(); + } + return _default_stream; +} + +static inline void set_default_stream(mlx_stream s) { + _default_stream = s; +} + +// CPU stream for file loading (Load primitive only runs on CPU) +static inline mlx_stream cpu_stream() { + if (_cpu_stream.ctx == NULL) { + _cpu_stream = mlx_default_cpu_stream_new(); + } + return _cpu_stream; +} + +// CGO noescape/nocallback hints to reduce CGO overhead +// noescape: pointers won't escape, no heap allocation needed +// nocallback: function won't call back into Go +#cgo noescape mlx_add +#cgo nocallback mlx_add +#cgo noescape mlx_subtract +#cgo nocallback mlx_subtract +#cgo noescape mlx_multiply +#cgo nocallback mlx_multiply +#cgo noescape mlx_divide +#cgo nocallback mlx_divide +#cgo noescape mlx_negative +#cgo nocallback mlx_negative +#cgo noescape mlx_abs +#cgo nocallback mlx_abs +#cgo noescape mlx_exp +#cgo nocallback mlx_exp +#cgo noescape mlx_log +#cgo nocallback mlx_log +#cgo noescape mlx_sqrt +#cgo nocallback mlx_sqrt +#cgo noescape mlx_rsqrt +#cgo nocallback mlx_rsqrt +#cgo noescape mlx_square +#cgo nocallback mlx_square +#cgo noescape mlx_power +#cgo nocallback mlx_power +#cgo noescape mlx_erf +#cgo nocallback mlx_erf +#cgo noescape mlx_sigmoid +#cgo nocallback mlx_sigmoid +#cgo noescape mlx_tanh +#cgo nocallback mlx_tanh +#cgo noescape mlx_sin +#cgo nocallback mlx_sin +#cgo noescape mlx_cos +#cgo nocallback mlx_cos +#cgo noescape mlx_maximum +#cgo nocallback mlx_maximum +#cgo noescape mlx_minimum +#cgo nocallback mlx_minimum +#cgo noescape mlx_clip +#cgo nocallback mlx_clip +#cgo noescape mlx_sum +#cgo nocallback mlx_sum +#cgo noescape mlx_sum_axis +#cgo nocallback mlx_sum_axis +#cgo noescape mlx_mean +#cgo nocallback mlx_mean +#cgo noescape mlx_mean_axis +#cgo nocallback mlx_mean_axis +#cgo noescape mlx_var_axis +#cgo nocallback mlx_var_axis +#cgo noescape mlx_argmax +#cgo nocallback mlx_argmax +#cgo noescape mlx_argmax_axis +#cgo nocallback mlx_argmax_axis +#cgo noescape mlx_softmax_axis +#cgo nocallback mlx_softmax_axis +#cgo noescape mlx_cumsum +#cgo nocallback mlx_cumsum +#cgo noescape mlx_matmul +#cgo nocallback mlx_matmul +#cgo noescape mlx_addmm +#cgo nocallback mlx_addmm +#cgo noescape mlx_gather_mm +#cgo nocallback mlx_gather_mm +#cgo noescape mlx_gather_qmm +#cgo nocallback mlx_gather_qmm +#cgo noescape mlx_reshape +#cgo nocallback mlx_reshape +#cgo noescape mlx_transpose_axes +#cgo nocallback mlx_transpose_axes +#cgo noescape mlx_expand_dims +#cgo nocallback mlx_expand_dims +#cgo noescape mlx_squeeze_axis +#cgo nocallback mlx_squeeze_axis +#cgo noescape mlx_flatten +#cgo nocallback mlx_flatten +#cgo noescape mlx_concatenate_axis +#cgo nocallback mlx_concatenate_axis +#cgo noescape mlx_slice +#cgo nocallback mlx_slice +#cgo noescape mlx_slice_update +#cgo nocallback mlx_slice_update +#cgo noescape mlx_as_strided +#cgo nocallback mlx_as_strided +#cgo noescape mlx_view +#cgo nocallback mlx_view +#cgo noescape mlx_contiguous +#cgo nocallback mlx_contiguous +#cgo noescape mlx_pad +#cgo nocallback mlx_pad +#cgo noescape mlx_tile +#cgo nocallback mlx_tile +#cgo noescape mlx_take_axis +#cgo nocallback mlx_take_axis +#cgo noescape mlx_take_along_axis +#cgo nocallback mlx_take_along_axis +#cgo noescape mlx_put_along_axis +#cgo nocallback mlx_put_along_axis +#cgo noescape mlx_where +#cgo nocallback mlx_where +#cgo noescape mlx_argsort_axis +#cgo nocallback mlx_argsort_axis +#cgo noescape mlx_argpartition_axis +#cgo nocallback mlx_argpartition_axis +#cgo noescape mlx_topk_axis +#cgo nocallback mlx_topk_axis +#cgo noescape mlx_less +#cgo nocallback mlx_less +#cgo noescape mlx_greater_equal +#cgo nocallback mlx_greater_equal +#cgo noescape mlx_logical_and +#cgo nocallback mlx_logical_and +#cgo noescape mlx_zeros +#cgo nocallback mlx_zeros +#cgo noescape mlx_zeros_like +#cgo nocallback mlx_zeros_like +#cgo noescape mlx_ones +#cgo nocallback mlx_ones +#cgo noescape mlx_full +#cgo nocallback mlx_full +#cgo noescape mlx_arange +#cgo nocallback mlx_arange +#cgo noescape mlx_linspace +#cgo nocallback mlx_linspace +#cgo noescape mlx_tri +#cgo nocallback mlx_tri +#cgo noescape mlx_astype +#cgo nocallback mlx_astype +#cgo noescape mlx_fast_rms_norm +#cgo nocallback mlx_fast_rms_norm +#cgo noescape mlx_fast_rope +#cgo nocallback mlx_fast_rope +#cgo noescape mlx_fast_scaled_dot_product_attention +#cgo nocallback mlx_fast_scaled_dot_product_attention +#cgo noescape mlx_conv2d +#cgo nocallback mlx_conv2d +#cgo noescape mlx_conv3d +#cgo nocallback mlx_conv3d +#cgo noescape mlx_random_key +#cgo nocallback mlx_random_key +#cgo noescape mlx_random_split +#cgo nocallback mlx_random_split +#cgo noescape mlx_random_categorical_num_samples +#cgo nocallback mlx_random_categorical_num_samples +#cgo noescape mlx_random_normal +#cgo nocallback mlx_random_normal +#cgo noescape mlx_random_uniform +#cgo nocallback mlx_random_uniform +#cgo noescape mlx_array_eval +#cgo nocallback mlx_array_eval +#cgo noescape mlx_eval +#cgo nocallback mlx_eval +#cgo noescape mlx_async_eval +#cgo nocallback mlx_async_eval +#cgo noescape mlx_synchronize +#cgo nocallback mlx_synchronize +#cgo noescape mlx_array_new +#cgo nocallback mlx_array_new +#cgo noescape mlx_array_new_data +#cgo nocallback mlx_array_new_data +#cgo noescape mlx_array_new_float +#cgo nocallback mlx_array_new_float +#cgo noescape mlx_array_free +#cgo nocallback mlx_array_free +#cgo noescape mlx_array_size +#cgo nocallback mlx_array_size +#cgo noescape mlx_array_ndim +#cgo nocallback mlx_array_ndim +#cgo noescape mlx_array_dim +#cgo nocallback mlx_array_dim +#cgo noescape mlx_array_dtype +#cgo nocallback mlx_array_dtype +#cgo noescape mlx_array_item_int32 +#cgo nocallback mlx_array_item_int32 +#cgo noescape mlx_vector_array_new_data +#cgo nocallback mlx_vector_array_new_data +#cgo noescape mlx_vector_array_free +#cgo nocallback mlx_vector_array_free +#cgo noescape mlx_array_new_int +#cgo nocallback mlx_array_new_int +#cgo noescape mlx_stream_new_device +#cgo nocallback mlx_stream_new_device +#cgo noescape mlx_get_default_stream +#cgo nocallback mlx_get_default_stream +#cgo noescape mlx_set_default_stream +#cgo nocallback mlx_set_default_stream +*/ +import "C" +import ( + "fmt" + "reflect" + "runtime" + "sync" + "sync/atomic" + "time" + "unsafe" +) + +// Dtype represents MLX data types +type Dtype int + +const ( + DtypeBool Dtype = C.MLX_BOOL + DtypeUint8 Dtype = C.MLX_UINT8 + DtypeUint16 Dtype = C.MLX_UINT16 + DtypeUint32 Dtype = C.MLX_UINT32 + DtypeUint64 Dtype = C.MLX_UINT64 + DtypeInt8 Dtype = C.MLX_INT8 + DtypeInt16 Dtype = C.MLX_INT16 + DtypeInt32 Dtype = C.MLX_INT32 + DtypeInt64 Dtype = C.MLX_INT64 + DtypeFloat16 Dtype = C.MLX_FLOAT16 + DtypeFloat32 Dtype = C.MLX_FLOAT32 + DtypeFloat64 Dtype = C.MLX_FLOAT64 + DtypeBFloat16 Dtype = C.MLX_BFLOAT16 + DtypeComplex64 Dtype = C.MLX_COMPLEX64 +) + +// String implements fmt.Stringer for Dtype +func (d Dtype) String() string { + switch d { + case DtypeBool: + return "bool" + case DtypeUint8: + return "u8" + case DtypeUint16: + return "u16" + case DtypeUint32: + return "u32" + case DtypeUint64: + return "u64" + case DtypeInt8: + return "i8" + case DtypeInt16: + return "i16" + case DtypeInt32: + return "i32" + case DtypeInt64: + return "i64" + case DtypeFloat16: + return "f16" + case DtypeFloat32: + return "f32" + case DtypeFloat64: + return "f64" + case DtypeBFloat16: + return "bf16" + case DtypeComplex64: + return "c64" + default: + return "unknown" + } +} + +// Memory Management: +// +// All arrays are automatically tracked for cleanup. On Eval(), non-kept arrays are freed. +// +// x := mlx.Matmul(input, weight) // x is tracked for cleanup +// mlx.Keep(x) // mark x as persistent +// mlx.Eval(x) // eval + free non-kept arrays +// +// Use Keep() for arrays that should persist (weights, caches). +// Use Free() to mark a kept array for cleanup on next Eval(). +// +// Note: Not goroutine-safe. Use from a single goroutine. + +// Array wraps an MLX array handle. +// Arrays are freed via Eval() cleanup (deterministic) or GC (fallback). +type Array struct { + c C.mlx_array + freed bool // Prevents double-free + kept bool // If true, survives Eval() cleanup +} + +// arrays tracks all live arrays. On Eval(), non-kept arrays are freed. +// Not goroutine-safe. +var arrays = make([]*Array, 0, 4096) + +// evalHandles is a pre-allocated slice for passing arrays to MLX eval. +var evalHandles = make([]C.mlx_array, 0, 64) + +// arrayPool reduces allocations for intermediate arrays +var arrayPool = sync.Pool{ + New: func() any { return &Array{} }, +} + +func newArray(array C.mlx_array) *Array { + // In compiled closures, MLX manages memory - skip Go tracking + if InClosureCallback() { + return &Array{c: array} + } + + // Use pooled Array struct for efficiency + a := arrayPool.Get().(*Array) + a.c = array + a.freed = false + a.kept = false + + // Track in global list + arrays = append(arrays, a) + + return a +} + +// Collect uses reflection to find all *Array fields in a struct (recursively). +// Use this to automatically gather model weights, cache state, etc. +func Collect(v any) []*Array { + var arrays []*Array + seen := make(map[uintptr]bool) + collect(reflect.ValueOf(v), &arrays, seen) + return arrays +} + +func collect(v reflect.Value, arrays *[]*Array, seen map[uintptr]bool) { + if !v.IsValid() { + return + } + + // Handle pointers + if v.Kind() == reflect.Ptr { + if v.IsNil() { + return + } + // Avoid infinite loops + ptr := v.Pointer() + if seen[ptr] { + return + } + seen[ptr] = true + + // Check if it's *Array + if arr, ok := v.Interface().(*Array); ok { + if arr != nil && arr.c.ctx != nil { + *arrays = append(*arrays, arr) + } + return + } + collect(v.Elem(), arrays, seen) + return + } + + // Handle structs + if v.Kind() == reflect.Struct { + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + if field.CanInterface() { + collect(field, arrays, seen) + } + } + return + } + + // Handle slices + if v.Kind() == reflect.Slice { + for i := 0; i < v.Len(); i++ { + collect(v.Index(i), arrays, seen) + } + return + } + + // Handle maps + if v.Kind() == reflect.Map { + for _, key := range v.MapKeys() { + collect(v.MapIndex(key), arrays, seen) + } + return + } + + // Handle interfaces + if v.Kind() == reflect.Interface { + if !v.IsNil() { + collect(v.Elem(), arrays, seen) + } + return + } +} + +// FreeStruct releases all *Array fields in a struct (recursively). +// Use this to free model weights when unloading a model. +func FreeStruct(v any) { + for _, arr := range Collect(v) { + arr.Free() + } +} + +// Keep marks arrays to persist across Eval() cleanup. +// Kept arrays will NOT be freed when Eval() runs cleanup. +func Keep(arrays ...*Array) { + for _, a := range arrays { + if a != nil { + a.kept = true + } + } +} + +// cleanup frees non-kept arrays and compacts the live array list. +// Returns number of arrays freed. +func cleanup() int { + freed := 0 + n := 0 + for _, a := range arrays { + if a.kept { + arrays[n] = a + n++ + } else if a.c.ctx != nil && !a.freed { + C.mlx_array_free(a.c) + a.c.ctx = nil + arrayPool.Put(a) + freed++ + } + } + arrays = arrays[:n] + return freed +} + +// DebugArrays prints summary info about all tracked arrays. +func DebugArrays() { + var totalBytes int64 + var keptCount, unkeptCount int + for _, a := range arrays { + if a.kept { + keptCount++ + } else { + unkeptCount++ + } + totalBytes += a.Nbytes() + } + fmt.Printf("[DEBUG] Arrays: %d kept, %d unkept, %.2f GB total\n", + keptCount, unkeptCount, float64(totalBytes)/(1024*1024*1024)) +} + +// DebugArraysVerbose prints detailed info about all tracked arrays, sorted by size. +func DebugArraysVerbose(topN int) { + type arrayInfo struct { + shape []int32 + dtype Dtype + bytes int64 + kept bool + } + + var infos []arrayInfo + var totalBytes int64 + for _, a := range arrays { + bytes := a.Nbytes() + infos = append(infos, arrayInfo{ + shape: a.Shape(), + dtype: a.Dtype(), + bytes: bytes, + kept: a.kept, + }) + totalBytes += bytes + } + + // Sort by size descending + for i := 0; i < len(infos)-1; i++ { + for j := i + 1; j < len(infos); j++ { + if infos[j].bytes > infos[i].bytes { + infos[i], infos[j] = infos[j], infos[i] + } + } + } + + fmt.Printf("[DEBUG] %d arrays, %.2f GB total:\n", len(infos), float64(totalBytes)/(1024*1024*1024)) + for i, info := range infos { + if i >= topN { + break + } + keptStr := "" + if info.kept { + keptStr = " [kept]" + } + fmt.Printf(" %3d. %8.2f MB %v %v%s\n", + i+1, float64(info.bytes)/(1024*1024), info.shape, info.dtype, keptStr) + } +} + +// Eval synchronously evaluates arrays and cleans up non-kept arrays. +// Outputs are automatically kept (survive cleanup). Returns them for chaining. +func Eval(outputs ...*Array) []*Array { + // Keep outputs so cleanup doesn't free them + for _, o := range outputs { + if o != nil { + o.kept = true + } + } + + // Cleanup non-kept arrays + cleanup() + + // Then evaluate + if len(outputs) > 0 { + evalHandles = evalHandles[:0] + for _, o := range outputs { + if o != nil { + evalHandles = append(evalHandles, o.c) + } + } + if len(evalHandles) > 0 { + vec := C.mlx_vector_array_new_data(&evalHandles[0], C.size_t(len(evalHandles))) + C.mlx_eval(vec) + C.mlx_vector_array_free(vec) + } + } + return outputs +} + +// AsyncEval dispatches async evaluation and cleans up non-kept arrays. +// Outputs are automatically kept (survive cleanup). +func AsyncEval(outputs ...*Array) { + // Keep outputs so cleanup doesn't free them + for _, o := range outputs { + if o != nil { + o.kept = true + } + } + + // Cleanup non-kept arrays + cleanup() + + // Then dispatch async eval + if len(outputs) > 0 { + evalHandles = evalHandles[:0] + for _, o := range outputs { + if o != nil { + evalHandles = append(evalHandles, o.c) + } + } + if len(evalHandles) > 0 { + vec := C.mlx_vector_array_new_data(&evalHandles[0], C.size_t(len(evalHandles))) + C.mlx_async_eval(vec) + C.mlx_vector_array_free(vec) + } + } +} + +// Sync waits for all async operations to complete (no cleanup). +func Sync() { + C.mlx_synchronize(C.default_stream()) +} + +// Free marks this array for cleanup on the next Eval(). +// The array is not immediately freed - cleanup happens during Eval(). +// +// Pattern for loops: +// +// oldLatents.Free() // mark for cleanup +// mlx.Eval(newLatents) // frees old, evals new +func (a *Array) Free() { + if a != nil { + a.kept = false + } +} + +// Eval evaluates this single array and runs cleanup. +func (a *Array) Eval() *Array { + Eval(a) + return a +} + +// Valid returns true if the array hasn't been freed. +func (a *Array) Valid() bool { + return a != nil && a.c.ctx != nil +} + +func int32ToCInt(s []int32) *C.int { + if len(s) == 0 { + return nil + } + return (*C.int)(unsafe.Pointer(&s[0])) +} + +// NewArray creates a new MLX array from float32 data +func NewArray(data []float32, shape []int32) *Array { + handle := C.mlx_array_new_data( + unsafe.Pointer(&data[0]), + int32ToCInt(shape), + C.int(len(shape)), + C.MLX_FLOAT32, + ) + return newArray(handle) +} + +// NewArrayInt32 creates a new MLX array from int32 data +func NewArrayInt32(data []int32, shape []int32) *Array { + handle := C.mlx_array_new_data( + unsafe.Pointer(&data[0]), + int32ToCInt(shape), + C.int(len(shape)), + C.MLX_INT32, + ) + return newArray(handle) +} + +// NewArrayFloat32 creates a new float32 array from data +func NewArrayFloat32(data []float32, shape []int32) *Array { + return NewArray(data, shape) +} + +// Zeros creates an array of zeros with optional dtype (default float32) +func Zeros(shape []int32, dtype ...Dtype) *Array { + res := C.mlx_array_new() + dt := DtypeFloat32 + if len(dtype) > 0 { + dt = dtype[0] + } + C.mlx_zeros(&res, int32ToCInt(shape), C.size_t(len(shape)), C.mlx_dtype(dt), C.default_stream()) + return newArray(res) +} + +// ZerosLike creates a zeros array with the same dtype as a. +// If shape is provided, uses that shape; otherwise uses a's shape. +func ZerosLike(a *Array, shape ...int32) *Array { + res := C.mlx_array_new() + if len(shape) == 0 { + C.mlx_zeros_like(&res, a.c, C.default_stream()) + } else { + dtype := a.Dtype() + C.mlx_zeros(&res, int32ToCInt(shape), C.size_t(len(shape)), C.mlx_dtype(dtype), C.default_stream()) + } + return newArray(res) +} + +// Ones creates an array of ones +func Ones(shape ...int32) *Array { + res := C.mlx_array_new() + C.mlx_ones(&res, int32ToCInt(shape), C.size_t(len(shape)), C.MLX_FLOAT32, C.default_stream()) + return newArray(res) +} + +// Full creates an array filled with a value +func Full(value float32, shape ...int32) *Array { + vals := C.mlx_array_new_float(C.float(value)) + res := C.mlx_array_new() + C.mlx_full(&res, int32ToCInt(shape), C.size_t(len(shape)), vals, C.MLX_FLOAT32, C.default_stream()) + C.mlx_array_free(vals) + return newArray(res) +} + +// Arange creates a range of values +func Arange(start, stop, step float32) *Array { + res := C.mlx_array_new() + C.mlx_arange(&res, C.double(start), C.double(stop), C.double(step), C.MLX_FLOAT32, C.default_stream()) + return newArray(res) +} + +// Linspace creates evenly spaced values +func Linspace(start, stop float32, steps int32) *Array { + res := C.mlx_array_new() + C.mlx_linspace(&res, C.double(start), C.double(stop), C.int(steps), C.MLX_FLOAT32, C.default_stream()) + return newArray(res) +} + +// ============ Math Operations ============ + +// Add adds two arrays element-wise +func Add(a, b *Array) *Array { + res := C.mlx_array_new() + C.mlx_add(&res, a.c, b.c, C.default_stream()) + return newArray(res) +} + +// AddRaw is like Add - kept for API compatibility (now identical to Add) +func AddRaw(a, b *Array) *Array { + return Add(a, b) +} + +// Sub subtracts two arrays element-wise +func Sub(a, b *Array) *Array { + res := C.mlx_array_new() + C.mlx_subtract(&res, a.c, b.c, C.default_stream()) + return newArray(res) +} + +// Mul multiplies two arrays element-wise +func Mul(a, b *Array) *Array { + res := C.mlx_array_new() + C.mlx_multiply(&res, a.c, b.c, C.default_stream()) + return newArray(res) +} + +// Div divides two arrays element-wise +func Div(a, b *Array) *Array { + res := C.mlx_array_new() + C.mlx_divide(&res, a.c, b.c, C.default_stream()) + return newArray(res) +} + +// Matmul performs matrix multiplication +func Matmul(a, b *Array) *Array { + res := C.mlx_array_new() + C.mlx_matmul(&res, a.c, b.c, C.default_stream()) + return newArray(res) +} + +// AddMM computes: result = beta*c + alpha*(a @ b) +// This fuses bias addition with matmul into a single op. +func AddMM(c, a, b *Array, alpha, beta float32) *Array { + res := C.mlx_array_new() + C.mlx_addmm(&res, c.c, a.c, b.c, C.float(alpha), C.float(beta), C.default_stream()) + return newArray(res) +} + +// Linear performs matrix multiplication: a @ weight +func Linear(a, weight *Array) *Array { + return Matmul(a, weight) +} + +// Sqrt computes element-wise square root +func Sqrt(a *Array) *Array { + res := C.mlx_array_new() + C.mlx_sqrt(&res, a.c, C.default_stream()) + return newArray(res) +} + +// RSqrt computes element-wise reciprocal square root +func RSqrt(a *Array) *Array { + res := C.mlx_array_new() + C.mlx_rsqrt(&res, a.c, C.default_stream()) + return newArray(res) +} + +// Erf computes element-wise error function +func Erf(a *Array) *Array { + res := C.mlx_array_new() + C.mlx_erf(&res, a.c, C.default_stream()) + return newArray(res) +} + +// Exp computes element-wise exponential +func Exp(a *Array) *Array { + res := C.mlx_array_new() + C.mlx_exp(&res, a.c, C.default_stream()) + return newArray(res) +} + +// Log computes element-wise natural logarithm +func Log(a *Array) *Array { + res := C.mlx_array_new() + C.mlx_log(&res, a.c, C.default_stream()) + return newArray(res) +} + +// Sin computes element-wise sine +func Sin(a *Array) *Array { + res := C.mlx_array_new() + C.mlx_sin(&res, a.c, C.default_stream()) + return newArray(res) +} + +// Cos computes element-wise cosine +func Cos(a *Array) *Array { + res := C.mlx_array_new() + C.mlx_cos(&res, a.c, C.default_stream()) + return newArray(res) +} + +// Neg negates the array +func Neg(a *Array) *Array { + res := C.mlx_array_new() + C.mlx_negative(&res, a.c, C.default_stream()) + return newArray(res) +} + +// Abs computes element-wise absolute value +func Abs(a *Array) *Array { + res := C.mlx_array_new() + C.mlx_abs(&res, a.c, C.default_stream()) + return newArray(res) +} + +// Square computes element-wise square +func Square(a *Array) *Array { + res := C.mlx_array_new() + C.mlx_square(&res, a.c, C.default_stream()) + return newArray(res) +} + +// Pow raises a to the power of b element-wise +func Pow(a, b *Array) *Array { + res := C.mlx_array_new() + C.mlx_power(&res, a.c, b.c, C.default_stream()) + return newArray(res) +} + +// Max computes element-wise maximum +func Max(a, b *Array) *Array { + res := C.mlx_array_new() + C.mlx_maximum(&res, a.c, b.c, C.default_stream()) + return newArray(res) +} + +// Min computes element-wise minimum +func Min(a, b *Array) *Array { + res := C.mlx_array_new() + C.mlx_minimum(&res, a.c, b.c, C.default_stream()) + return newArray(res) +} + +// scalarWithDtype creates a scalar array matching the dtype of a (critical for graph fusion!) +func scalarWithDtype(s float32, a *Array) C.mlx_array { + // Create float32 scalar, then cast to match input dtype + f32 := C.mlx_array_new_float(C.float(s)) + dtype := a.Dtype() + if dtype == DtypeFloat32 { + return f32 // No cast needed + } + // Cast to match input dtype + casted := C.mlx_array_new() + C.mlx_astype(&casted, f32, C.mlx_dtype(dtype), C.default_stream()) + C.mlx_array_free(f32) + return casted +} + +// AddScalar adds a scalar to an array (matches dtype for graph fusion) +func AddScalar(a *Array, s float32) *Array { + scalar := scalarWithDtype(s, a) + res := C.mlx_array_new() + C.mlx_add(&res, a.c, scalar, C.default_stream()) + C.mlx_array_free(scalar) + return newArray(res) +} + +// MulScalar multiplies an array by a scalar (matches dtype for graph fusion) +func MulScalar(a *Array, s float32) *Array { + scalar := scalarWithDtype(s, a) + res := C.mlx_array_new() + C.mlx_multiply(&res, a.c, scalar, C.default_stream()) + C.mlx_array_free(scalar) + return newArray(res) +} + +// DivScalar divides an array by a scalar (matches dtype for graph fusion) +func DivScalar(a *Array, s float32) *Array { + scalar := scalarWithDtype(s, a) + res := C.mlx_array_new() + C.mlx_divide(&res, a.c, scalar, C.default_stream()) + C.mlx_array_free(scalar) + return newArray(res) +} + +// DivScalarInt divides an int array by an int scalar (regular division, may return float) +func DivScalarInt(a *Array, s int32) *Array { + scalar := C.mlx_array_new_int(C.int(s)) + res := C.mlx_array_new() + C.mlx_divide(&res, a.c, scalar, C.default_stream()) + C.mlx_array_free(scalar) + return newArray(res) +} + +// FloorDivideScalar performs integer floor division (a // s), preserving int dtype +func FloorDivideScalar(a *Array, s int32) *Array { + scalar := C.mlx_array_new_int(C.int(s)) + res := C.mlx_array_new() + C.mlx_floor_divide(&res, a.c, scalar, C.default_stream()) + C.mlx_array_free(scalar) + return newArray(res) +} + +// ============ Reduction Operations ============ + +// Sum reduces along an axis +func Sum(a *Array, axis int, keepdims bool) *Array { + res := C.mlx_array_new() + C.mlx_sum_axis(&res, a.c, C.int(axis), C._Bool(keepdims), C.default_stream()) + return newArray(res) +} + +// SumAll reduces the entire array to a scalar +func SumAll(a *Array) *Array { + res := C.mlx_array_new() + C.mlx_sum(&res, a.c, false, C.default_stream()) + return newArray(res) +} + +// Mean reduces along an axis +func Mean(a *Array, axis int, keepdims bool) *Array { + res := C.mlx_array_new() + C.mlx_mean_axis(&res, a.c, C.int(axis), C._Bool(keepdims), C.default_stream()) + return newArray(res) +} + +// MeanAll reduces the entire array to a scalar +func MeanAll(a *Array) *Array { + res := C.mlx_array_new() + C.mlx_mean(&res, a.c, false, C.default_stream()) + return newArray(res) +} + +// Var computes variance along an axis +func Var(a *Array, axis int, keepdims bool) *Array { + res := C.mlx_array_new() + C.mlx_var_axis(&res, a.c, C.int(axis), C._Bool(keepdims), 0, C.default_stream()) + return newArray(res) +} + +// Argmax returns indices of maximum values along an axis +func Argmax(a *Array, axis int, keepdims bool) *Array { + res := C.mlx_array_new() + C.mlx_argmax_axis(&res, a.c, C.int(axis), C._Bool(keepdims), C.default_stream()) + return newArray(res) +} + +// ArgmaxAll returns the index of the maximum element (flattened). +// Triggers cleanup of non-kept arrays. +func ArgmaxAll(a *Array) int32 { + cleanup() + // Flatten, then argmax with keepdims=false + flat := C.mlx_array_new() + C.mlx_flatten(&flat, a.c, 0, -1, C.default_stream()) + res := C.mlx_array_new() + C.mlx_argmax(&res, flat, false, C.default_stream()) + C.mlx_array_eval(res) + var val C.int32_t + C.mlx_array_item_int32(&val, res) + C.mlx_array_free(flat) + C.mlx_array_free(res) + return int32(val) +} + +// Reshape reshapes the array +func Reshape(a *Array, shape ...int32) *Array { + res := C.mlx_array_new() + C.mlx_reshape(&res, a.c, int32ToCInt(shape), C.size_t(len(shape)), C.default_stream()) + return newArray(res) +} + +// Transpose permutes the dimensions +func Transpose(a *Array, axes ...int) *Array { + cAxes := make([]C.int, len(axes)) + for i, ax := range axes { + cAxes[i] = C.int(ax) + } + res := C.mlx_array_new() + C.mlx_transpose_axes(&res, a.c, &cAxes[0], C.size_t(len(axes)), C.default_stream()) + return newArray(res) +} + +// AsStrided creates a view with custom strides. Useful for fusing reshape+transpose. +func AsStrided(a *Array, shape []int32, strides []int64, offset int64) *Array { + cShape := make([]C.int, len(shape)) + for i, s := range shape { + cShape[i] = C.int(s) + } + cStrides := make([]C.int64_t, len(strides)) + for i, s := range strides { + cStrides[i] = C.int64_t(s) + } + res := C.mlx_array_new() + C.mlx_as_strided(&res, a.c, &cShape[0], C.size_t(len(shape)), &cStrides[0], C.size_t(len(strides)), C.size_t(offset), C.default_stream()) + return newArray(res) +} + +// ExpandDims adds a dimension at the specified axis +func ExpandDims(a *Array, axis int) *Array { + res := C.mlx_array_new() + C.mlx_expand_dims(&res, a.c, C.int(axis), C.default_stream()) + return newArray(res) +} + +// Squeeze removes a dimension at the specified axis +func Squeeze(a *Array, axis int) *Array { + res := C.mlx_array_new() + C.mlx_squeeze_axis(&res, a.c, C.int(axis), C.default_stream()) + return newArray(res) +} + +// Flatten flattens the array to 1D +func Flatten(a *Array) *Array { + res := C.mlx_array_new() + C.mlx_flatten(&res, a.c, 0, -1, C.default_stream()) + return newArray(res) +} + +// FlattenRange flattens consecutive axes from start_axis to end_axis (intermediates) +func FlattenRange(a *Array, startAxis, endAxis int) *Array { + res := C.mlx_array_new() + C.mlx_flatten(&res, a.c, C.int(startAxis), C.int(endAxis), C.default_stream()) + return newArray(res) +} + +// View reinterprets the array with a new dtype (no data copy) +func View(a *Array, dtype int) *Array { + res := C.mlx_array_new() + C.mlx_view(&res, a.c, C.mlx_dtype(dtype), C.default_stream()) + return newArray(res) +} + +// Contiguous returns a contiguous copy of the array +func Contiguous(a *Array) *Array { + res := C.mlx_array_new() + C.mlx_contiguous(&res, a.c, true, C.default_stream()) + return newArray(res) +} + +// Clip clips values to [min, max]. Pass nil for no bound on that side. +func Clip(a *Array, aMin, aMax *Array) *Array { + res := C.mlx_array_new() + var minH, maxH C.mlx_array + if aMin != nil { + minH = aMin.c + } + if aMax != nil { + maxH = aMax.c + } + C.mlx_clip(&res, a.c, minH, maxH, C.default_stream()) + return newArray(res) +} + +// ClipScalar clips array values using scalar bounds (matches dtype for graph fusion) +// Pass math.NaN() or set hasMin/hasMax to false for unbounded +func ClipScalar(a *Array, minVal, maxVal float32, hasMin, hasMax bool) *Array { + var minArr, maxArr C.mlx_array + if hasMin { + minArr = scalarWithDtype(minVal, a) + } + if hasMax { + maxArr = scalarWithDtype(maxVal, a) + } + res := C.mlx_array_new() + C.mlx_clip(&res, a.c, minArr, maxArr, C.default_stream()) + if hasMin { + C.mlx_array_free(minArr) + } + if hasMax { + C.mlx_array_free(maxArr) + } + return newArray(res) +} + +// GreaterEqual returns element-wise a >= b +func GreaterEqual(a, b *Array) *Array { + res := C.mlx_array_new() + C.mlx_greater_equal(&res, a.c, b.c, C.default_stream()) + return newArray(res) +} + +// LessArray returns element-wise a < b +func LessArray(a, b *Array) *Array { + res := C.mlx_array_new() + C.mlx_less(&res, a.c, b.c, C.default_stream()) + return newArray(res) +} + +// LogicalAnd returns element-wise a && b +func LogicalAnd(a, b *Array) *Array { + res := C.mlx_array_new() + C.mlx_logical_and(&res, a.c, b.c, C.default_stream()) + return newArray(res) +} + +// AllClose returns true if all elements of a and b are within tolerance. +// Uses rtol (relative tolerance) and atol (absolute tolerance): +// |a - b| <= atol + rtol * |b| +func AllClose(a, b *Array, rtol, atol float64) *Array { + res := C.mlx_array_new() + C.mlx_allclose(&res, a.c, b.c, C.double(rtol), C.double(atol), C.bool(false), C.default_stream()) + return newArray(res) +} + +// AllCloseEqualNaN is like AllClose but treats NaN as equal to NaN. +func AllCloseEqualNaN(a, b *Array, rtol, atol float64) *Array { + res := C.mlx_array_new() + C.mlx_allclose(&res, a.c, b.c, C.double(rtol), C.double(atol), C.bool(true), C.default_stream()) + return newArray(res) +} + +// ArrayEqual returns true if arrays have same shape and all elements are equal. +func ArrayEqual(a, b *Array) *Array { + res := C.mlx_array_new() + C.mlx_array_equal(&res, a.c, b.c, C.bool(false), C.default_stream()) + return newArray(res) +} + +// ArrayEqualNaN is like ArrayEqual but treats NaN as equal to NaN. +func ArrayEqualNaN(a, b *Array) *Array { + res := C.mlx_array_new() + C.mlx_array_equal(&res, a.c, b.c, C.bool(true), C.default_stream()) + return newArray(res) +} + +// IsClose returns element-wise bool array indicating if values are within tolerance. +// |a - b| <= atol + rtol * |b| +func IsClose(a, b *Array, rtol, atol float64) *Array { + res := C.mlx_array_new() + C.mlx_isclose(&res, a.c, b.c, C.double(rtol), C.double(atol), C.bool(false), C.default_stream()) + return newArray(res) +} + +// IsCloseEqualNaN is like IsClose but treats NaN as equal to NaN. +func IsCloseEqualNaN(a, b *Array, rtol, atol float64) *Array { + res := C.mlx_array_new() + C.mlx_isclose(&res, a.c, b.c, C.double(rtol), C.double(atol), C.bool(true), C.default_stream()) + return newArray(res) +} + +// ReduceMax reduces array to max value over all dimensions. +func ReduceMax(a *Array) *Array { + res := C.mlx_array_new() + C.mlx_max(&res, a.c, C.bool(false), C.default_stream()) + return newArray(res) +} + +// ArangeInt creates an array with values from start to stop with step and specified dtype +func ArangeInt(start, stop, step int32, dtype Dtype) *Array { + res := C.mlx_array_new() + C.mlx_arange(&res, C.double(start), C.double(stop), C.double(step), C.mlx_dtype(dtype), C.default_stream()) + return newArray(res) +} + +// Concatenate concatenates arrays along an axis +func Concatenate(arrays []*Array, axis int) *Array { + handles := make([]C.mlx_array, len(arrays)) + for i, arr := range arrays { + handles[i] = arr.c + } + vec := C.mlx_vector_array_new_data(&handles[0], C.size_t(len(handles))) + res := C.mlx_array_new() + C.mlx_concatenate_axis(&res, vec, C.int(axis), C.default_stream()) + C.mlx_vector_array_free(vec) + return newArray(res) +} + +// Concat is a convenience function to concatenate two arrays +func Concat(a, b *Array, axis int) *Array { + return Concatenate([]*Array{a, b}, axis) +} + +// Slice slices the array +func Slice(a *Array, start, stop []int32) *Array { + n := len(start) + cStart := make([]C.int, n) + cStop := make([]C.int, n) + cStrides := make([]C.int, n) + for i := 0; i < n; i++ { + cStart[i] = C.int(start[i]) + cStop[i] = C.int(stop[i]) + cStrides[i] = 1 // Default stride of 1 + } + res := C.mlx_array_new() + C.mlx_slice(&res, a.c, &cStart[0], C.size_t(n), &cStop[0], C.size_t(n), &cStrides[0], C.size_t(n), C.default_stream()) + return newArray(res) +} + +// SliceStride slices with start:stop:stride like Python a[start:stop:stride] +func SliceStride(a *Array, start, stop, strides []int32) *Array { + cStart := make([]C.int, len(start)) + cStop := make([]C.int, len(stop)) + cStrides := make([]C.int, len(strides)) + for i := range start { + cStart[i] = C.int(start[i]) + cStop[i] = C.int(stop[i]) + cStrides[i] = C.int(strides[i]) + } + res := C.mlx_array_new() + C.mlx_slice(&res, a.c, &cStart[0], C.size_t(len(start)), &cStop[0], C.size_t(len(stop)), &cStrides[0], C.size_t(len(strides)), C.default_stream()) + return newArray(res) +} + +// Tile repeats the array along each dimension +func Tile(a *Array, reps []int32) *Array { + res := C.mlx_array_new() + C.mlx_tile(&res, a.c, int32ToCInt(reps), C.size_t(len(reps)), C.default_stream()) + return newArray(res) +} + +// BroadcastTo broadcasts an array to a given shape +func BroadcastTo(a *Array, shape []int32) *Array { + res := C.mlx_array_new() + C.mlx_broadcast_to(&res, a.c, int32ToCInt(shape), C.size_t(len(shape)), C.default_stream()) + return newArray(res) +} + +// ============ Neural Network Operations ============ + +// Softmax computes softmax along an axis +func Softmax(a *Array, axis int) *Array { + res := C.mlx_array_new() + C.mlx_softmax_axis(&res, a.c, C.int(axis), false, C.default_stream()) + return newArray(res) +} + +// Take gathers elements along an axis using indices +func Take(a *Array, indices *Array, axis int) *Array { + res := C.mlx_array_new() + C.mlx_take_axis(&res, a.c, indices.c, C.int(axis), C.default_stream()) + return newArray(res) +} + +// Argsort returns indices that would sort the array along an axis +func Argsort(a *Array, axis int) *Array { + res := C.mlx_array_new() + C.mlx_argsort_axis(&res, a.c, C.int(axis), C.default_stream()) + return newArray(res) +} + +// Sigmoid computes element-wise sigmoid +func Sigmoid(a *Array) *Array { + res := C.mlx_array_new() + C.mlx_sigmoid(&res, a.c, C.default_stream()) + return newArray(res) +} + +// ReLU computes element-wise ReLU: max(0, x) +func ReLU(a *Array) *Array { + // ReLU = maximum(x, 0) - mlx-c doesn't have mlx_relu, but we can use maximum + zero := C.mlx_array_new_float(0.0) + res := C.mlx_array_new() + C.mlx_maximum(&res, a.c, zero, C.default_stream()) + C.mlx_array_free(zero) + return newArray(res) +} + +// SiLU computes element-wise SiLU (Swish): x * sigmoid(x) +func SiLU(a *Array) *Array { + // SiLU = x * sigmoid(x) + sig := C.mlx_array_new() + C.mlx_sigmoid(&sig, a.c, C.default_stream()) + res := C.mlx_array_new() + C.mlx_multiply(&res, a.c, sig, C.default_stream()) + C.mlx_array_free(sig) + return newArray(res) +} + +// GELU computes element-wise GELU (Gaussian Error Linear Unit) +// GELU(x) = x * 0.5 * (1 + erf(x / sqrt(2))) +func GELU(a *Array) *Array { + sqrt2 := C.mlx_array_new_float(1.4142135623730951) + scaled := C.mlx_array_new() + C.mlx_divide(&scaled, a.c, sqrt2, C.default_stream()) + erfd := C.mlx_array_new() + C.mlx_erf(&erfd, scaled, C.default_stream()) + one := C.mlx_array_new_float(1.0) + erfdPlusOne := C.mlx_array_new() + C.mlx_add(&erfdPlusOne, erfd, one, C.default_stream()) + half := C.mlx_array_new_float(0.5) + halfErfdPlusOne := C.mlx_array_new() + C.mlx_multiply(&halfErfdPlusOne, half, erfdPlusOne, C.default_stream()) + res := C.mlx_array_new() + C.mlx_multiply(&res, a.c, halfErfdPlusOne, C.default_stream()) + C.mlx_array_free(sqrt2) + C.mlx_array_free(scaled) + C.mlx_array_free(erfd) + C.mlx_array_free(one) + C.mlx_array_free(erfdPlusOne) + C.mlx_array_free(half) + C.mlx_array_free(halfErfdPlusOne) + return newArray(res) +} + +// Tanh computes element-wise tanh +func Tanh(a *Array) *Array { + res := C.mlx_array_new() + C.mlx_tanh(&res, a.c, C.default_stream()) + return newArray(res) +} + +// RMSNorm computes RMS normalization using mlx.fast +func RMSNorm(x, weight *Array, eps float32) *Array { + res := C.mlx_array_new() + C.mlx_fast_rms_norm(&res, x.c, weight.c, C.float(eps), C.default_stream()) + return newArray(res) +} + +// RMSNormNoWeight applies RMS normalization without a weight +// x * rsqrt(mean(x^2) + eps) +// Uses mlx_fast_rms_norm with ones weight for f32 accumulation precision +func RMSNormNoWeight(x *Array, eps float32) *Array { + // Create weight of ones matching last dimension + lastDim := x.Shape()[len(x.Shape())-1] + ones := AsType(Full(1.0, lastDim), x.Dtype()) + return RMSNorm(x, ones, eps) +} + +// RoPE applies rotary position embeddings using mlx.fast +func RoPE(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array { + res := C.mlx_array_new() + optBase := C.mlx_optional_float{value: C.float(base), has_value: true} + C.mlx_fast_rope(&res, x.c, C.int(dims), C._Bool(traditional), optBase, C.float(scale), C.int(offset), C.mlx_array{}, C.default_stream()) + return newArray(res) +} + +// RoPEWithFreqs applies rotary position embeddings with custom frequencies (for YaRN) +// freqs is required - use RoPE() if you don't have custom frequencies +func RoPEWithFreqs(x, freqs *Array, dims int, traditional bool, scale float32, offset int) *Array { + res := C.mlx_array_new() + optBase := C.mlx_optional_float{has_value: false} // No base when using freqs + C.mlx_fast_rope(&res, x.c, C.int(dims), C._Bool(traditional), optBase, C.float(scale), C.int(offset), freqs.c, C.default_stream()) + return newArray(res) +} + +// ============ Indexing ============ + +// EmbeddingLookup performs embedding lookup (gathers from table) +// table: [vocab_size, hidden_size], indices: [batch, seq_len] +// returns: [batch, seq_len, hidden_size] +func EmbeddingLookup(table, indices *Array) *Array { + return Take(table, indices, 0) +} + +// Gather gathers elements using indices - simplified to use take axis 0 +func Gather(a, indices *Array) *Array { + return Take(a, indices, 0) +} + +// ============ Array Properties ============ + +// Ndim returns the number of dimensions +func (a *Array) Ndim() int { + return int(C.mlx_array_ndim(a.c)) +} + +// Size returns the total number of elements +func (a *Array) Size() int { + return int(C.mlx_array_size(a.c)) +} + +// IsContiguous returns whether the array's data is contiguous in memory. +// Non-contiguous arrays (e.g., from SliceStride) must call Contiguous() before Data(). +func (a *Array) IsContiguous() bool { + var res C.bool + C._mlx_array_is_contiguous(&res, a.c) + return bool(res) +} + +// Dim returns the size of a dimension +func (a *Array) Dim(axis int) int32 { + return int32(C.mlx_array_dim(a.c, C.int(axis))) +} + +// Shape returns the shape as a slice +func (a *Array) Shape() []int32 { + ndim := a.Ndim() + shape := make([]int32, ndim) + for i := 0; i < ndim; i++ { + shape[i] = a.Dim(i) + } + return shape +} + +// IsValid returns true if the array hasn't been freed +func (a *Array) IsValid() bool { + return a != nil && a.c.ctx != nil +} + +// Dtype returns the data type +func (a *Array) Dtype() Dtype { + return Dtype(C.mlx_array_dtype(a.c)) +} + +// Nbytes returns the total size in bytes +func (a *Array) Nbytes() int64 { + return int64(a.Size()) * a.Dtype().ItemSize() +} + +// ItemSize returns the size in bytes of one element for this dtype +func (d Dtype) ItemSize() int64 { + switch d { + case DtypeBool, DtypeUint8, DtypeInt8: + return 1 + case DtypeUint16, DtypeInt16, DtypeFloat16, DtypeBFloat16: + return 2 + case DtypeUint32, DtypeInt32, DtypeFloat32: + return 4 + case DtypeUint64, DtypeInt64, DtypeFloat64, DtypeComplex64: + return 8 + default: + return 4 + } +} + +// ============ Data Access ============ + +// Data copies the float32 data out of the array. +// Note: For non-contiguous arrays (e.g., from SliceStride), call Contiguous() first. +// Note: Arrays of other dtypes (bf16, f16, etc) are automatically converted to float32. +// Note: Triggers cleanup of non-kept arrays. +func (a *Array) Data() []float32 { + cleanup() + size := a.Size() + if size == 0 { + return nil + } + + arr := a + if a.Dtype() != DtypeFloat32 { + arr = AsType(a, DtypeFloat32) + arr.Eval() + // Cast array will be cleaned up on next Eval + } + + ptr := C.mlx_array_data_float32(arr.c) + if ptr == nil { + return nil + } + data := make([]float32, size) + copy(data, unsafe.Slice((*float32)(unsafe.Pointer(ptr)), size)) + return data +} + +// Item returns the scalar value from a 0-dimensional array. +// Converts to float32 if necessary. Triggers cleanup. +func (a *Array) Item() float32 { + data := a.Data() // Data() calls cleanup() + if len(data) == 0 { + return 0 + } + return data[0] +} + +// DataInt32 copies the int32 data out of the array. +// Note: For non-contiguous arrays (e.g., from SliceStride), call Contiguous() first. +// Note: Triggers cleanup of non-kept arrays. +func (a *Array) DataInt32() []int32 { + cleanup() + size := a.Size() + if size == 0 { + return nil + } + ptr := C.mlx_array_data_int32(a.c) + if ptr == nil { + return nil + } + data := make([]int32, size) + copy(data, unsafe.Slice((*int32)(unsafe.Pointer(ptr)), size)) + return data +} + +// ItemInt32 gets a single scalar value efficiently (no array copy). +// Note: Triggers cleanup of non-kept arrays. +func (a *Array) ItemInt32() int32 { + cleanup() + var val C.int32_t + C.mlx_array_item_int32(&val, a.c) + return int32(val) +} + +// ============ Utility ============ + +// String returns a string representation +func (a *Array) String() string { + shape := a.Shape() + size := a.Size() + if size <= 20 { + data := a.Data() + return fmt.Sprintf("Array(shape=%v, data=%v)", shape, data) + } + return fmt.Sprintf("Array(shape=%v, size=%d)", shape, size) +} + +// ============ Safetensors Support ============ + +// NewArrayFromBytes creates an array from raw bytes (for safetensors) +func NewArrayFromBytes(data []byte, shape []int32, dtype Dtype) *Array { + cData := unsafe.Pointer(&data[0]) + intShape := make([]C.int, len(shape)) + for i, s := range shape { + intShape[i] = C.int(s) + } + handle := C.mlx_array_new_data(cData, &intShape[0], C.int(len(shape)), C.mlx_dtype(dtype)) + return newArray(handle) +} + +// ============ Device Control ============ + +// SetDefaultDeviceGPU sets the default device to GPU (Metal) +func SetDefaultDeviceGPU() { + dev := C.mlx_device_new_type(C.MLX_GPU, 0) + C.mlx_set_default_device(dev) + C.mlx_device_free(dev) +} + +// SetDefaultDeviceCPU sets the default device to CPU +func SetDefaultDeviceCPU() { + dev := C.mlx_device_new_type(C.MLX_CPU, 0) + C.mlx_set_default_device(dev) + C.mlx_device_free(dev) +} + +// MetalIsAvailable returns true if Metal GPU is available +func MetalIsAvailable() bool { + var available C._Bool + C.mlx_metal_is_available(&available) + return bool(available) +} + +// MetalStartCapture starts a GPU trace capture to the given file path. +// The path must not already exist. Run with MTL_CAPTURE_ENABLED=1 env var. +// Open the resulting .gputrace file in Xcode for analysis. +func MetalStartCapture(path string) { + cPath := C.CString(path) + defer C.free(unsafe.Pointer(cPath)) + C.mlx_metal_start_capture(cPath) +} + +// MetalStopCapture stops the current GPU trace capture. +func MetalStopCapture() { + C.mlx_metal_stop_capture() +} + +// GPUIsAvailable returns true if any GPU (Metal or CUDA) is available +func GPUIsAvailable() bool { + // On Linux with CUDA build, GPU is available + // On macOS, check Metal availability + if MetalIsAvailable() { + return true + } + // CUDA is available if we compiled with CUDA support (Linux) + return runtime.GOOS == "linux" +} + +// GetDefaultDeviceType returns the current default device (0=CPU, 1=GPU) +func GetDefaultDeviceType() int { + var dev C.mlx_device + C.mlx_get_default_device(&dev) + var devType C.mlx_device_type + C.mlx_device_get_type(&devType, dev) + C.mlx_device_free(dev) + return int(devType) +} + +// Synchronize waits for all GPU operations to complete +func Synchronize() { + C.mlx_synchronize(C.default_stream()) +} + +// ScaledDotProductAttention computes optimized attention using GPU kernel +// Q, K, V should be [batch, heads, seq, head_dim] +func ScaledDotProductAttention(q, k, v *Array, scale float32, causalMask bool) *Array { + res := C.mlx_array_new() + maskMode := "" // empty string for no mask + if causalMask { + maskMode = "causal" + } + cMaskMode := C.CString(maskMode) + defer C.free(unsafe.Pointer(cMaskMode)) + C.mlx_fast_scaled_dot_product_attention(&res, q.c, k.c, v.c, C.float(scale), cMaskMode, C.mlx_array{}, C.mlx_array{}, C.default_stream()) + return newArray(res) +} + +// ScaledDotProductAttentionWithSinks computes attention with sinks support +// maskMode: "causal", "sliding_window", or "" for none +// mask: optional attention mask array (nil for none) +// sinks: attention sinks array (nil for none) +func ScaledDotProductAttentionWithSinks(q, k, v *Array, scale float32, maskMode string, mask, sinks *Array) *Array { + res := C.mlx_array_new() + cMaskMode := C.CString(maskMode) + defer C.free(unsafe.Pointer(cMaskMode)) + var maskH, sinksH C.mlx_array + if mask != nil { + maskH = mask.c + } + if sinks != nil { + sinksH = sinks.c + } + C.mlx_fast_scaled_dot_product_attention(&res, q.c, k.c, v.c, C.float(scale), cMaskMode, maskH, sinksH, C.default_stream()) + return newArray(res) +} + +// ============ Native Safetensors Loading ============ + +// SafetensorsFile represents a loaded safetensors file +type SafetensorsFile struct { + arrays C.mlx_map_string_to_array + metadata C.mlx_map_string_to_string +} + +// LoadSafetensorsNative loads a safetensors file using MLX's optimized loader +// Note: Uses CPU stream because Load primitive only runs on CPU +func LoadSafetensorsNative(path string) (*SafetensorsFile, error) { + cPath := C.CString(path) + defer C.free(unsafe.Pointer(cPath)) + + var arrays C.mlx_map_string_to_array + var metadata C.mlx_map_string_to_string + if C.mlx_load_safetensors(&arrays, &metadata, cPath, C.cpu_stream()) != 0 { + return nil, fmt.Errorf("failed to load safetensors: %s", path) + } + return &SafetensorsFile{arrays: arrays, metadata: metadata}, nil +} + +// Get retrieves a tensor by name +func (s *SafetensorsFile) Get(name string) *Array { + cName := C.CString(name) + defer C.free(unsafe.Pointer(cName)) + + var arr C.mlx_array + if C.mlx_map_string_to_array_get(&arr, s.arrays, cName) != 0 { + return nil + } + if arr.ctx == nil { + return nil + } + return newArray(arr) +} + +// Set replaces a tensor in the map (like Python's weights[k] = v) +func (s *SafetensorsFile) Set(name string, arr *Array) { + cName := C.CString(name) + defer C.free(unsafe.Pointer(cName)) + C.mlx_map_string_to_array_insert(s.arrays, cName, arr.c) +} + +// Count returns the number of tensors (not directly available, would need iterator) +func (s *SafetensorsFile) Count() int { + // mlx-c doesn't have a direct count - would need to iterate + return 0 +} + +// Free releases the safetensors file +func (s *SafetensorsFile) Free() { + C.mlx_map_string_to_array_free(s.arrays) + C.mlx_map_string_to_string_free(s.metadata) +} + +// ============ NPY Loading ============ + +// LoadNpy loads a numpy array from an npy file +// Note: Uses CPU stream because Load primitive only runs on CPU +func LoadNpy(path string) (*Array, error) { + cPath := C.CString(path) + defer C.free(unsafe.Pointer(cPath)) + + var arr C.mlx_array + if C.mlx_load(&arr, cPath, C.cpu_stream()) != 0 { + return nil, fmt.Errorf("failed to load npy: %s", path) + } + if arr.ctx == nil { + return nil, fmt.Errorf("failed to load npy: %s", path) + } + return newArray(arr), nil +} + +// ============ Slice Update ============ + +// SliceUpdate updates a slice of the array with new values +func SliceUpdate(a, update *Array, start, stop []int32) *Array { + n := len(start) + cStart := make([]C.int, n) + cStop := make([]C.int, n) + cStrides := make([]C.int, n) + for i := 0; i < n; i++ { + cStart[i] = C.int(start[i]) + cStop[i] = C.int(stop[i]) + cStrides[i] = 1 // Default stride of 1 + } + res := C.mlx_array_new() + C.mlx_slice_update(&res, a.c, update.c, &cStart[0], C.size_t(n), &cStop[0], C.size_t(n), &cStrides[0], C.size_t(n), C.default_stream()) + return newArray(res) +} + +// SliceUpdateInplace updates a slice and returns a new array. +// Note: Despite the name, this is NOT in-place - MLX arrays are immutable. +// The caller must use the returned value. +func SliceUpdateInplace(a, update *Array, start, stop []int32) *Array { + return SliceUpdate(a, update, start, stop) +} + +// ============ Optimized Operations ============ + +// SampleArgmax gets the last logit position and returns argmax (fused operation) +func SampleArgmax(logits *Array) int32 { + result := Argmax(logits, -1, false) + return result.ItemInt32() +} + +// ArgmaxKeepArray returns argmax as an Array (for pipelining, no sync) +// This is like mlx-lm's sampler that returns y as an array, not .item() +func ArgmaxKeepArray(logits *Array) *Array { + // For greedy decoding: logits shape is [1, 1, vocab] + // We want argmax over vocab dimension, return shape [] + return Argmax(logits, -1, false) +} + +// RandomState is the global PRNG state, analogous to mx.random.state in Python. +// It's a slice containing a single key array. Random functions use and update this state. +// +// Thread safety: Protected by randomStateMu, mimicking Python's GIL behavior. +// All random functions that use global state acquire this lock. +var RandomState = []*Array{nil} +var randomStateMu sync.Mutex + +func init() { + // Lock main goroutine to OS thread for CUDA context stability. + // CUDA contexts are bound to threads; Go can migrate goroutines between threads. + runtime.LockOSThread() + RandomState[0] = RandomKey(uint64(time.Now().UnixMilli())) + Keep(RandomState[0]) // Global state should persist +} + +// RandomKey creates a PRNG key from a seed +func RandomKey(seed uint64) *Array { + var res C.mlx_array + C.mlx_random_key(&res, C.uint64_t(seed)) + return newArray(res) +} + +// RandomSplit splits a PRNG key into two new keys +func RandomSplit(key *Array) (*Array, *Array) { + var key1, key2 C.mlx_array + C.mlx_random_split(&key1, &key2, key.c, C.default_stream()) + return newArray(key1), newArray(key2) +} + +// RandomCategoricalWithKey samples from categorical distribution using provided key. +func RandomCategoricalWithKey(logits, key *Array, axis int, numSamples int) *Array { + res := C.mlx_array_new() + C.mlx_random_categorical_num_samples(&res, logits.c, C.int(axis), C.int(numSamples), key.c, C.default_stream()) + return newArray(res) +} + +// RandomCategorical samples using global RandomState. +// For simple scripts - production code should use RandomCategoricalWithKey with explicit key management. +func RandomCategorical(logits *Array, axis int, numSamples int) *Array { + randomStateMu.Lock() + oldKey := RandomState[0] + key1, key2 := RandomSplit(oldKey) + Keep(key1) // key1 becomes the new global state + oldKey.Free() + RandomState[0] = key1 + randomStateMu.Unlock() + return RandomCategoricalWithKey(logits, key2, axis, numSamples) +} + +// RandomNormal creates a random normal (Gaussian) tensor +func RandomNormal(shape []int32, seed uint64) *Array { + key := RandomKey(seed) + res := C.mlx_array_new() + C.mlx_random_normal(&res, int32ToCInt(shape), C.size_t(len(shape)), C.MLX_FLOAT32, 0.0, 1.0, key.c, C.default_stream()) + return newArray(res) +} + +// RandomUniform generates uniform random values in [0, 1) with the given shape +func RandomUniform(shape []int32, seed uint64) *Array { + key := RandomKey(seed) + low := C.mlx_array_new_float(0.0) + high := C.mlx_array_new_float(1.0) + res := C.mlx_array_new() + C.mlx_random_uniform(&res, low, high, int32ToCInt(shape), C.size_t(len(shape)), C.MLX_FLOAT32, key.c, C.default_stream()) + C.mlx_array_free(low) + C.mlx_array_free(high) + return newArray(res) +} + +// Conv2d performs 2D convolution +// input: [N, H, W, C], weight: [O, kH, kW, C] (MLX uses NHWC layout) +// Returns: [N, H', W', O] +func Conv2d(input, weight *Array, stride, padding int32) *Array { + res := C.mlx_array_new() + C.mlx_conv2d(&res, input.c, weight.c, C.int(stride), C.int(stride), C.int(padding), C.int(padding), 1, 1, 1, C.default_stream()) + return newArray(res) +} + +// Conv3d performs 3D convolution +// input: [N, D, H, W, C], weight: [O, kD, kH, kW, C] (MLX uses NDHWC layout) +// Returns: [N, D', H', W', O] +func Conv3d(input, weight *Array, strideD, strideH, strideW, padD, padH, padW int32) *Array { + res := C.mlx_array_new() + C.mlx_conv3d(&res, input.c, weight.c, C.int(strideD), C.int(strideH), C.int(strideW), C.int(padD), C.int(padH), C.int(padW), 1, 1, 1, 1, C.default_stream()) + return newArray(res) +} + +// ============ Compilation Control ============ + +// EnableCompile enables global compilation/graph fusion +func EnableCompile() { + C.mlx_enable_compile() +} + +// DisableCompile disables global compilation +func DisableCompile() { + C.mlx_disable_compile() +} + +// SetCompileMode sets the compile mode +// 0=disabled, 1=no_simplify, 2=no_fuse, 3=enabled +func SetCompileMode(mode int) { + C.mlx_set_compile_mode(C.mlx_compile_mode(mode)) +} + +// ============ Stream Control ============ + +// Stream represents an MLX execution stream +type Stream struct { + c C.mlx_stream +} + +// NewStream creates a new execution stream on the default device +func NewStream() *Stream { + var dev C.mlx_device + C.mlx_get_default_device(&dev) + stream := C.mlx_stream_new_device(dev) + C.mlx_device_free(dev) + return &Stream{c: stream} +} + +// Free releases the stream +func (s *Stream) Free() { + if s.c.ctx != nil { + C.mlx_stream_free(s.c) + s.c.ctx = nil + } +} + +// SetDefaultStream sets the default stream for operations +func SetDefaultStream(s *Stream) { + C.mlx_set_default_stream(s.c) + C.set_default_stream(s.c) // Also update our cached stream +} + +// GetDefaultStream returns the current default stream +func GetDefaultStream() *Stream { + var stream C.mlx_stream + var dev C.mlx_device + C.mlx_get_default_device(&dev) + C.mlx_get_default_stream(&stream, dev) + C.mlx_device_free(dev) + return &Stream{c: stream} +} + +// SynchronizeStream waits for all operations on the stream to complete +func SynchronizeStream(s *Stream) { + C.mlx_synchronize(s.c) +} + +// ============ Metal Memory Control ============ + +// MetalGetCacheMemory returns the current cache memory usage in bytes +func MetalGetCacheMemory() uint64 { + var size C.size_t + C.mlx_get_cache_memory(&size) + return uint64(size) +} + +// MetalGetPeakMemory returns the peak memory usage in bytes +func MetalGetPeakMemory() uint64 { + var size C.size_t + C.mlx_get_peak_memory(&size) + return uint64(size) +} + +// MetalResetPeakMemory resets the peak memory counter +func MetalResetPeakMemory() { + C.mlx_reset_peak_memory() +} + +// MetalSetWiredLimit sets the wired memory limit and returns the previous limit +// This keeps tensors pinned in GPU memory for faster access +func MetalSetWiredLimit(limit uint64) uint64 { + var prev C.size_t + C.mlx_set_wired_limit(&prev, C.size_t(limit)) + return uint64(prev) +} + +// MetalGetActiveMemory returns the current active memory usage in bytes +func MetalGetActiveMemory() uint64 { + var size C.size_t + C.mlx_get_active_memory(&size) + return uint64(size) +} + +// ClearCache clears the MLX memory cache +func ClearCache() { + C.mlx_clear_cache() +} + +// SetCacheLimit sets the free cache limit in bytes +// Setting to 0 disables caching (useful for memory-constrained generation) +// Returns the previous cache limit +func SetCacheLimit(limit uint64) uint64 { + var prev C.size_t + C.mlx_set_cache_limit(&prev, C.size_t(limit)) + return uint64(prev) +} + +// SetMemoryLimit sets the overall memory limit in bytes +// This is a guideline for maximum memory during graph evaluation. +// When Metal is available, defaults to 1.5x the max recommended working set. +// Returns the previous memory limit +func SetMemoryLimit(limit uint64) uint64 { + var prev C.size_t + C.mlx_set_memory_limit(&prev, C.size_t(limit)) + return uint64(prev) +} + +// GetMemoryLimit returns the current memory limit in bytes +func GetMemoryLimit() uint64 { + var size C.size_t + C.mlx_get_memory_limit(&size) + return uint64(size) +} + +// ============ MoE Operations ============ + +// GatherMM performs gather matrix multiplication for MoE +// a: input, b: weight matrices +// lhsIndices, rhsIndices: optional expert selection indices (nil for none) +func GatherMM(a, b *Array, lhsIndices, rhsIndices *Array, sortedIndices bool) *Array { + var lhs, rhs C.mlx_array + if lhsIndices != nil { + lhs = lhsIndices.c + } + if rhsIndices != nil { + rhs = rhsIndices.c + } + res := C.mlx_array_new() + C.mlx_gather_mm(&res, a.c, b.c, lhs, rhs, C._Bool(sortedIndices), C.default_stream()) + return newArray(res) +} + +// GatherQMM performs quantized gather matrix multiplication for MoE +// Used for MXFP4 and other quantized MoE inference +func GatherQMM(x, w, scales *Array, biases, lhsIndices, rhsIndices *Array, transpose bool, groupSize, bits int, mode string, sortedIndices bool) *Array { + var b, lhs, rhs C.mlx_array + if biases != nil { + b = biases.c + } + if lhsIndices != nil { + lhs = lhsIndices.c + } + if rhsIndices != nil { + rhs = rhsIndices.c + } + cMode := C.CString(mode) + defer C.free(unsafe.Pointer(cMode)) + optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true} + optBits := C.mlx_optional_int{value: C.int(bits), has_value: true} + res := C.mlx_array_new() + C.mlx_gather_qmm(&res, x.c, w.c, scales.c, b, lhs, rhs, C._Bool(transpose), optGroupSize, optBits, cMode, C._Bool(sortedIndices), C.default_stream()) + return newArray(res) +} + +// ============ Quantization ============ + +// Quantize quantizes weights to specified bits per element. +// Returns (quantized_weights, scales, biases). +// groupSize: number of elements quantized together (default 64) +// bits: bits per element, 2, 4, or 8 (default 4) +// mode: "affine" (default) or "mxfp4" +func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, biases *Array) { + cMode := C.CString(mode) + defer C.free(unsafe.Pointer(cMode)) + optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true} + optBits := C.mlx_optional_int{value: C.int(bits), has_value: true} + res := C.mlx_vector_array_new() + C.mlx_quantize(&res, w.c, optGroupSize, optBits, cMode, C.default_stream()) + + // Result is a vector of 3 arrays: [weights, scales, biases] + var w0, w1, w2 C.mlx_array + C.mlx_vector_array_get(&w0, res, 0) + C.mlx_vector_array_get(&w1, res, 1) + C.mlx_vector_array_get(&w2, res, 2) + C.mlx_vector_array_free(res) + + return newArray(w0), newArray(w1), newArray(w2) +} + +// Dequantize reconstructs weights from quantized form. +// groupSize: number of elements quantized together (default 64) +// bits: bits per element, 2, 4, or 8 (default 4) +// mode: "affine" (default) or "mxfp4" +func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Array { + cMode := C.CString(mode) + defer C.free(unsafe.Pointer(cMode)) + optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true} + optBits := C.mlx_optional_int{value: C.int(bits), has_value: true} + optDtype := C.mlx_optional_dtype{has_value: false} + + var b C.mlx_array + if biases != nil { + b = biases.c + } + + res := C.mlx_array_new() + C.mlx_dequantize(&res, w.c, scales.c, b, optGroupSize, optBits, cMode, optDtype, C.default_stream()) + return newArray(res) +} + +// QuantizedMatmul performs matrix multiplication with quantized weights. +// x: input tensor [batch..., in_features] +// w: quantized weights +// scales, biases: from Quantize +// transpose: if true, compute x @ w.T (typical for Linear layers) +// groupSize, bits, mode: must match what was used in Quantize +func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bits int, mode string) *Array { + cMode := C.CString(mode) + defer C.free(unsafe.Pointer(cMode)) + optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true} + optBits := C.mlx_optional_int{value: C.int(bits), has_value: true} + + var b C.mlx_array + if biases != nil { + b = biases.c + } + + res := C.mlx_array_new() + C.mlx_quantized_matmul(&res, x.c, w.c, scales.c, b, C._Bool(transpose), optGroupSize, optBits, cMode, C.default_stream()) + return newArray(res) +} + +// ============ Sorting and Top-K ============ + +// TopK returns the k largest elements along an axis +func TopK(a *Array, k int, axis int) *Array { + res := C.mlx_array_new() + C.mlx_topk_axis(&res, a.c, C.int(k), C.int(axis), C.default_stream()) + return newArray(res) +} + +// Argpartition returns indices for partial sort (k-th smallest first) +func Argpartition(a *Array, kth int, axis int) *Array { + res := C.mlx_array_new() + C.mlx_argpartition_axis(&res, a.c, C.int(kth), C.int(axis), C.default_stream()) + return newArray(res) +} + +// TakeAlongAxis takes elements from array using indices along axis +func TakeAlongAxis(a, indices *Array, axis int) *Array { + res := C.mlx_array_new() + C.mlx_take_along_axis(&res, a.c, indices.c, C.int(axis), C.default_stream()) + return newArray(res) +} + +// PutAlongAxis puts values into array at indices along axis +func PutAlongAxis(a, indices, values *Array, axis int) *Array { + res := C.mlx_array_new() + C.mlx_put_along_axis(&res, a.c, indices.c, values.c, C.int(axis), C.default_stream()) + return newArray(res) +} + +// Cumsum computes cumulative sum along an axis +func Cumsum(a *Array, axis int) *Array { + res := C.mlx_array_new() + C.mlx_cumsum(&res, a.c, C.int(axis), false, false, C.default_stream()) + return newArray(res) +} + +// Where selects elements: condition ? a : b +func Where(condition, a, b *Array) *Array { + res := C.mlx_array_new() + C.mlx_where(&res, condition.c, a.c, b.c, C.default_stream()) + return newArray(res) +} + +// LessScalar returns element-wise a < scalar +func LessScalar(a *Array, s float32) *Array { + scalar := C.mlx_array_new_float(C.float(s)) + res := C.mlx_array_new() + C.mlx_less(&res, a.c, scalar, C.default_stream()) + C.mlx_array_free(scalar) + return newArray(res) +} + +// FullDtype creates an array filled with a value with specific dtype +func FullDtype(value float32, dtype Dtype, shape ...int32) *Array { + intShape := make([]C.int, len(shape)) + for i, s := range shape { + intShape[i] = C.int(s) + } + vals := C.mlx_array_new_float(C.float(value)) + res := C.mlx_array_new() + C.mlx_full(&res, &intShape[0], C.size_t(len(shape)), vals, C.mlx_dtype(dtype), C.default_stream()) + C.mlx_array_free(vals) + return newArray(res) +} + +// AsType casts an array to a different dtype +func AsType(a *Array, dtype Dtype) *Array { + res := C.mlx_array_new() + C.mlx_astype(&res, a.c, C.mlx_dtype(dtype), C.default_stream()) + return newArray(res) +} + +// ToBFloat16 casts an array to bfloat16 +func ToBFloat16(a *Array) *Array { + return AsType(a, DtypeBFloat16) +} + +// ============ VibeVoice Helper Functions ============ + +// NewScalarArray creates a true 0-dimensional scalar array from a float32 value +func NewScalarArray(value float32) *Array { + return newArray(C.mlx_array_new_float(C.float(value))) +} + +// Global random seed counter for RandN +var randnSeedCounter uint64 = uint64(time.Now().UnixNano()) + +// RandN creates an array of random samples from a standard normal distribution +func RandN(shape []int32) *Array { + // Use incrementing seed for unique random values each call + seed := atomic.AddUint64(&randnSeedCounter, 1) + return RandomNormal(shape, seed) +} + +// Pad pads an array with zeros +// paddings: [before_0, after_0, before_1, after_1, ...] for each dimension +func Pad(a *Array, paddings []int32) *Array { + numAxes := len(paddings) / 2 + // Convert to low/high pairs + lowPad := make([]C.int, numAxes) + highPad := make([]C.int, numAxes) + for i := 0; i < numAxes; i++ { + lowPad[i] = C.int(paddings[i*2]) + highPad[i] = C.int(paddings[i*2+1]) + } + zero := C.mlx_array_new_float(0.0) + res := C.mlx_array_new() + // mlx_pad takes axes, low, high arrays + axes := make([]C.int, numAxes) + for i := 0; i < numAxes; i++ { + axes[i] = C.int(i) + } + cMode := C.CString("constant") + defer C.free(unsafe.Pointer(cMode)) + C.mlx_pad(&res, a.c, &axes[0], C.size_t(numAxes), &lowPad[0], C.size_t(numAxes), &highPad[0], C.size_t(numAxes), zero, cMode, C.default_stream()) + C.mlx_array_free(zero) + return newArray(res) +} + +// Conv1d performs 1D convolution +// x: [B, L, Cin], weight: [Cout, K, Cin] (MLX uses NLC layout) +// bias: optional (nil for no bias) +func Conv1d(x, weight *Array, bias *Array, stride int32) *Array { + res := C.mlx_array_new() + C.mlx_conv1d(&res, x.c, weight.c, C.int(stride), C.int(0), C.int(1), 1, C.default_stream()) + // Apply bias if provided + if bias != nil { + biased := C.mlx_array_new() + C.mlx_add(&biased, res, bias.c, C.default_stream()) + C.mlx_array_free(res) + return newArray(biased) + } + return newArray(res) +} + +// ConvTranspose1d performs transposed 1D convolution +// x: [B, L, Cin], weight: [Cout, K, Cin] (MLX uses NLC layout) +// bias: optional (nil for no bias) +func ConvTranspose1d(x, weight *Array, bias *Array, stride int32) *Array { + res := C.mlx_array_new() + // stride, padding, dilation, output_padding, groups + C.mlx_conv_transpose1d(&res, x.c, weight.c, C.int(stride), 0, 1, 0, 1, C.default_stream()) + // Apply bias if provided + if bias != nil { + biased := C.mlx_array_new() + C.mlx_add(&biased, res, bias.c, C.default_stream()) + C.mlx_array_free(res) + return newArray(biased) + } + return newArray(res) +} + +// DepthwiseConv1d performs depthwise 1D convolution (groups=Cin) +// x: [B, L, C], weight: [1, K, C] (groups = C) +// bias: optional (nil for no bias) +func DepthwiseConv1d(x, weight *Array, bias *Array) *Array { + // Get number of input channels for groups + shape := x.Shape() + groups := int(shape[len(shape)-1]) + res := C.mlx_array_new() + C.mlx_conv1d(&res, x.c, weight.c, 1, 0, 1, C.int(groups), C.default_stream()) + // Apply bias if provided + if bias != nil { + biased := C.mlx_array_new() + C.mlx_add(&biased, res, bias.c, C.default_stream()) + C.mlx_array_free(res) + return newArray(biased) + } + return newArray(res) +} + +// SliceAxis extracts a slice along a specific axis +func SliceAxis(a *Array, axis int, start, stop int32) *Array { + shape := a.Shape() + + // Build start and stop indices for all dimensions + starts := make([]int32, len(shape)) + stops := make([]int32, len(shape)) + for i := range shape { + if i == axis { + starts[i] = start + stops[i] = stop + } else { + starts[i] = 0 + stops[i] = shape[i] + } + } + + return Slice(a, starts, stops) +} + +// Tri creates a lower triangular matrix +func Tri(n, m int32, k int) *Array { + res := C.mlx_array_new() + C.mlx_tri(&res, C.int(n), C.int(m), C.int(k), C.MLX_FLOAT32, C.default_stream()) + return newArray(res) +} diff --git a/x/imagegen/mlx/mlx_test.go b/x/imagegen/mlx/mlx_test.go new file mode 100644 index 000000000..db8fe394f --- /dev/null +++ b/x/imagegen/mlx/mlx_test.go @@ -0,0 +1,1145 @@ +//go:build mlx + +package mlx + +import ( + "fmt" + "testing" +) + +// TestBasicCleanup verifies non-kept arrays are freed and kept arrays survive. +func TestBasicCleanup(t *testing.T) { + weight := NewArrayFloat32([]float32{1, 2, 3, 4}, []int32{2, 2}) + Keep(weight) + weight.Eval() + + intermediate := NewArrayFloat32([]float32{1, 1}, []int32{1, 2}) + result := Matmul(intermediate, weight) + Keep(result) + + // Before eval: intermediate should be valid + if !intermediate.Valid() { + t.Fatal("intermediate should be valid before Eval") + } + + Eval(result) + + // After eval: intermediate should be freed + if intermediate.Valid() { + t.Fatal("intermediate should be freed after Eval") + } + + // Result should have correct values + data := result.Data() + if data[0] != 4 || data[1] != 6 { + t.Errorf("expected [4, 6], got %v", data) + } + + // Weight should survive + if !weight.Valid() { + t.Error("weight was freed") + } +} + +// TestKeptSurvives verifies kept arrays are not freed. +func TestKeptSurvives(t *testing.T) { + a := NewArrayFloat32([]float32{1, 2}, []int32{2}) + b := NewArrayFloat32([]float32{3, 4}, []int32{2}) + result := Add(a, b) + Keep(result) + + Eval(result) + + if !result.Valid() { + t.Error("kept result was freed") + } + + data := result.Data() + if data[0] != 4 || data[1] != 6 { + t.Errorf("expected [4, 6], got %v", data) + } +} + +// TestEvalAutoKeeps verifies Eval automatically keeps its outputs. +func TestEvalAutoKeeps(t *testing.T) { + a := NewArrayFloat32([]float32{1, 2}, []int32{2}) + b := NewArrayFloat32([]float32{3, 4}, []int32{2}) + result := Add(a, b) + + // Don't call Keep(result) - Eval should auto-keep it + Eval(result) + + // Result should survive (auto-kept by Eval) + if !result.Valid() { + t.Error("Eval output was freed - should be auto-kept") + } + + // Inputs should be freed (not kept) + if a.Valid() { + t.Error("input 'a' should be freed") + } + if b.Valid() { + t.Error("input 'b' should be freed") + } + + // Verify data is correct + data := result.Data() + if data[0] != 4 || data[1] != 6 { + t.Errorf("expected [4, 6], got %v", data) + } +} + +// TestWeightsSurvive verifies kept arrays survive multiple Eval cycles. +func TestWeightsSurvive(t *testing.T) { + weight := NewArrayFloat32([]float32{1, 2, 3, 4}, []int32{2, 2}) + Keep(weight) + weight.Eval() + + for i := 0; i < 5; i++ { + x := NewArrayFloat32([]float32{1, 1}, []int32{1, 2}) + result := Matmul(x, weight) + Keep(result) + Eval(result) + } + + if !weight.Valid() { + t.Error("weight was freed after multiple iterations") + } +} + +// TestAsyncEvalCleanup verifies AsyncEval cleans up and dispatches. +func TestAsyncEvalCleanup(t *testing.T) { + weight := NewArrayFloat32([]float32{1, 0, 0, 1}, []int32{2, 2}) // Identity matrix + Keep(weight) + weight.Eval() + + // First async step + x1 := NewArrayFloat32([]float32{1, 2}, []int32{1, 2}) + result1 := Matmul(x1, weight) + Keep(result1) + AsyncEval(result1) + + // Second async step + x2 := NewArrayFloat32([]float32{3, 4}, []int32{1, 2}) + result2 := Matmul(x2, weight) + Keep(result2) + AsyncEval(result2) + + // Sync and verify results + result1.Eval() + d1 := result1.Data() + if d1[0] != 1 || d1[1] != 2 { + t.Errorf("result1: expected [1, 2], got %v", d1) + } + + result2.Eval() + d2 := result2.Data() + if d2[0] != 3 || d2[1] != 4 { + t.Errorf("result2: expected [3, 4], got %v", d2) + } + + if !weight.Valid() { + t.Error("weight was freed during async") + } +} + +// TestMultiOutput verifies multiple kept arrays survive. +func TestMultiOutput(t *testing.T) { + a := NewArrayFloat32([]float32{1, 2, 3, 4}, []int32{2, 2}) + sum := Add(a, a) + prod := Mul(a, a) + Keep(sum, prod) + + Eval(sum, prod) + + // Both kept arrays should be valid + if !sum.Valid() || !prod.Valid() { + t.Error("kept arrays should survive cleanup") + } + + // Verify values + sumData := sum.Data() + prodData := prod.Data() + if sumData[0] != 2 || prodData[0] != 1 { + t.Errorf("unexpected results: sum=%v prod=%v", sumData, prodData) + } +} + +// TestChaining verifies output from one step can be used in next. +func TestChaining(t *testing.T) { + weight := NewArrayFloat32([]float32{1, 0, 0, 1}, []int32{2, 2}) + Keep(weight) + weight.Eval() + + // First step + x := NewArrayFloat32([]float32{1, 2}, []int32{1, 2}) + out1 := Matmul(x, weight) + Keep(out1) + AsyncEval(out1) + + // Second step uses output of first + out2 := Add(out1, out1) + Keep(out2) + Eval(out2) + + // out1 should survive (was kept) + if !out1.Valid() { + t.Error("out1 was freed but used by second step") + } + + // Final result should be correct + data := out2.Data() + if data[0] != 2 || data[1] != 4 { + t.Errorf("expected [2, 4], got %v", data) + } +} + +// TestGenerationLoop simulates the LLM generation pattern with cache. +func TestGenerationLoop(t *testing.T) { + weight := NewArrayFloat32([]float32{1, 0, 0, 1}, []int32{2, 2}) + Keep(weight) + weight.Eval() + + // Simulate cache - starts as zeros + cache := NewArrayFloat32([]float32{0, 0}, []int32{1, 2}) + Keep(cache) + cache.Eval() + + var lastToken *Array + + // Simulate 5 generation steps + for step := 0; step < 5; step++ { + oldCache := cache + + // Simulate forward pass + input := NewArrayFloat32([]float32{float32(step + 1), float32(step + 2)}, []int32{1, 2}) + output := Matmul(input, weight) + + // Simulate cache update + newCache := Add(output, cache) + + // Mark what survives + Keep(output, newCache) + + if step < 4 { + AsyncEval(output, newCache) + } else { + Eval(output, newCache) + } + + // Free old cache, update references + oldCache.Free() + lastToken = output + cache = newCache + } + + // Token output should be valid + if !lastToken.Valid() { + t.Error("token output was freed") + } + + // Cache should be valid + if !cache.Valid() { + t.Error("cache was freed") + } + + // Weight should survive all iterations + if !weight.Valid() { + t.Error("weight was freed") + } +} + +// BenchmarkCleanupOnly isolates cleanup cost without MLX ops. +func BenchmarkCleanupOnly(b *testing.B) { + // Pre-create weight + weight := NewArrayFloat32([]float32{1, 0, 0, 1}, []int32{2, 2}) + Keep(weight) + weight.Eval() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Create 100 arrays - minimal ops + arrays := make([]*Array, 100) + for j := range arrays { + arrays[j] = NewArrayFloat32([]float32{1, 2}, []int32{1, 2}) + } + Keep(arrays[0]) + Eval() // Just cleanup + } +} + +// BenchmarkNewArrayOnly measures array creation overhead. +func BenchmarkNewArrayOnly(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = NewArrayFloat32([]float32{1, 2, 3, 4}, []int32{2, 2}) + } +} + +// BenchmarkCGOCallOverhead measures raw CGO call cost. +func BenchmarkCGOCallOverhead(b *testing.B) { + arr := NewArrayFloat32([]float32{1, 2, 3, 4}, []int32{2, 2}) + Keep(arr) + arr.Eval() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = arr.Ndim() // Simple CGO call + } +} + +// BenchmarkCleanup_50 measures cleanup with 50 arrays. +func BenchmarkCleanup_50(b *testing.B) { + benchCleanup(b, 50) +} + +// BenchmarkCleanup_500 measures cleanup with 500 arrays (LLM scale). +func BenchmarkCleanup_500(b *testing.B) { + benchCleanup(b, 500) +} + +// BenchmarkCleanup_1000 measures cleanup with 1000 arrays. +func BenchmarkCleanup_1000(b *testing.B) { + benchCleanup(b, 1000) +} + +func benchCleanup(b *testing.B, numArrays int) { + weight := NewArrayFloat32([]float32{1, 0, 0, 1}, []int32{2, 2}) + Keep(weight) + weight.Eval() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + x := NewArrayFloat32([]float32{1, 2}, []int32{1, 2}) + for j := 0; j < numArrays; j++ { + x = Add(x, x) + } + result := Matmul(x, weight) + Keep(result) + Eval(result) + } +} + +// BenchmarkGenerationLoop_10 simulates 10 token generation steps. +func BenchmarkGenerationLoop_10(b *testing.B) { + benchGenerationLoop(b, 10) +} + +// BenchmarkGenerationLoop_100 simulates 100 token generation steps. +func BenchmarkGenerationLoop_100(b *testing.B) { + benchGenerationLoop(b, 100) +} + +func benchGenerationLoop(b *testing.B, steps int) { + weight := NewArrayFloat32([]float32{1, 0, 0, 1}, []int32{2, 2}) + Keep(weight) + weight.Eval() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + cache := NewArrayFloat32([]float32{0, 0}, []int32{1, 2}) + Keep(cache) + cache.Eval() + + for step := 0; step < steps; step++ { + oldCache := cache + input := NewArrayFloat32([]float32{1, 2}, []int32{1, 2}) + output := Matmul(input, weight) + newCache := Add(output, cache) + Keep(output, newCache) + + if step < steps-1 { + AsyncEval(output, newCache) + } else { + Eval(output, newCache) + } + oldCache.Free() + cache = newCache + } + } +} + +// BenchmarkLLMForward simulates a realistic LLM forward pass with ~500 ops. +func BenchmarkLLMForward(b *testing.B) { + // Simulate weights for 32 layers + numLayers := 32 + weights := make([]*Array, numLayers*4) // q, k, v, o per layer + for i := range weights { + weights[i] = NewArrayFloat32([]float32{1, 0, 0, 1}, []int32{2, 2}) + } + Keep(weights...) + Eval(weights...) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + x := NewArrayFloat32([]float32{1, 2}, []int32{1, 2}) + + // Simulate 32 transformer layers + for layer := 0; layer < numLayers; layer++ { + // Attention block (simplified) + q := Matmul(x, weights[layer*4]) + k := Matmul(x, weights[layer*4+1]) + v := Matmul(x, weights[layer*4+2]) + attn := Matmul(Softmax(Matmul(q, Transpose(k, 1, 0)), -1), v) + attnOut := Matmul(attn, weights[layer*4+3]) + + // Residual + layernorm (simplified) + x = Add(x, attnOut) + x = RMSNormNoWeight(x, 1e-5) + + // FFN (simplified as single matmul) + ffn := Matmul(x, weights[layer*4]) + ffn = SiLU(ffn) + x = Add(x, ffn) + } + Keep(x) + Eval(x) + } +} + +// ============ Compile Tests ============ + +// gelu implements GELU activation: x * 0.5 * (1 + erf(x / sqrt(2))) +func gelu(x *Array) *Array { + sqrt2 := NewScalarArray(1.4142135623730951) + half := NewScalarArray(0.5) + one := NewScalarArray(1.0) + scaled := Div(x, sqrt2) + erfd := Erf(scaled) + return Mul(Mul(x, half), Add(one, erfd)) +} + +// TestCompileBasic verifies compiled function produces correct output. +func TestCompileBasic(t *testing.T) { + x := NewArrayFloat32([]float32{-1, 0, 1, 2}, []int32{4}) + Keep(x) + x.Eval() + + // Uncompiled + expected := gelu(x) + Keep(expected) + Eval(expected) + + // Compiled + compiled := Compile(func(inputs []*Array) []*Array { + return []*Array{gelu(inputs[0])} + }) + defer compiled.Free() + + result := compiled.Call(x)[0] + Keep(result) + Eval(result) + + // Compare with tolerance + expData := expected.Data() + resData := result.Data() + for i := range expData { + diff := expData[i] - resData[i] + if diff < 0 { + diff = -diff + } + if diff > 1e-5 { + t.Errorf("mismatch at %d: expected %f, got %f (diff=%e)", i, expData[i], resData[i], diff) + } + } +} + +// TestCompileMultipleInputs verifies compiled function with multiple inputs. +func TestCompileMultipleInputs(t *testing.T) { + a := NewArrayFloat32([]float32{1, 2, 3, 4}, []int32{4}) + b := NewArrayFloat32([]float32{5, 6, 7, 8}, []int32{4}) + Keep(a, b) + Eval(a, b) + + compiled := Compile(func(inputs []*Array) []*Array { + sum := Add(inputs[0], inputs[1]) + prod := Mul(inputs[0], inputs[1]) + return []*Array{sum, prod} + }) + defer compiled.Free() + + outputs := compiled.Call(a, b) + Keep(outputs...) + Eval(outputs...) + + sumData := outputs[0].Data() + prodData := outputs[1].Data() + if sumData[0] != 6 || prodData[0] != 5 { + t.Errorf("unexpected: sum[0]=%f, prod[0]=%f", sumData[0], prodData[0]) + } +} + +// TestCompileReuse verifies compiled function can be called multiple times. +func TestCompileReuse(t *testing.T) { + compiled := Compile(func(inputs []*Array) []*Array { + return []*Array{Add(inputs[0], inputs[0])} + }) + defer compiled.Free() + + for i := 0; i < 5; i++ { + x := NewArrayFloat32([]float32{float32(i)}, []int32{1}) + Keep(x) + x.Eval() + result := compiled.Call(x)[0] + Keep(result) + Eval(result) + data := result.Data() + expected := float32(i * 2) + if data[0] != expected { + t.Errorf("iteration %d: expected %f, got %f", i, expected, data[0]) + } + } +} + +// BenchmarkGELUUncompiled benchmarks uncompiled GELU. +func BenchmarkGELUUncompiled(b *testing.B) { + x := RandomNormal([]int32{1000, 1024}, 42) + Keep(x) + x.Eval() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + y := x + for j := 0; j < 10; j++ { + y = gelu(y) + } + Keep(y) + Eval(y) + } +} + +// BenchmarkGELUCompiled benchmarks compiled GELU. +func BenchmarkGELUCompiled(b *testing.B) { + x := RandomNormal([]int32{1000, 1024}, 42) + Keep(x) + x.Eval() + + compiled := Compile(func(inputs []*Array) []*Array { + y := inputs[0] + for j := 0; j < 10; j++ { + y = gelu(y) + } + return []*Array{y} + }) + defer compiled.Free() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + result := compiled.Call(x) + Keep(result[0]) + Eval(result[0]) + } +} + +// TestCompileNoMemoryLeak verifies compiled functions don't leak memory. +func TestCompileNoMemoryLeak(t *testing.T) { + x := RandomNormal([]int32{100, 100}, 42) + Keep(x) + x.Eval() + + compiled := Compile(func(inputs []*Array) []*Array { + y := inputs[0] + for j := 0; j < 5; j++ { + y = gelu(y) + } + return []*Array{y} + }) + defer compiled.Free() + + // Warmup to establish baseline + for i := 0; i < 10; i++ { + result := compiled.Call(x) + Keep(result[0]) + Eval(result[0]) + result[0].Free() + } + + MetalResetPeakMemory() + initialMem := MetalGetActiveMemory() + + for i := 0; i < 100; i++ { + result := compiled.Call(x) + Keep(result[0]) + Eval(result[0]) + result[0].Free() + } + + Eval() // Final cleanup + + finalMem := MetalGetActiveMemory() + peakMem := MetalGetPeakMemory() + + // Memory should not grow significantly (allow 10MB slack for caching) + growth := int64(finalMem) - int64(initialMem) + if growth > 10*1024*1024 { + t.Errorf("memory grew by %d bytes over 100 iterations", growth) + } + t.Logf("memory: initial=%dMB, final=%dMB, peak=%dMB, growth=%dKB", + initialMem/(1<<20), finalMem/(1<<20), peakMem/(1<<20), growth/1024) +} + +// TestCompileWithRandomState verifies compiled function can capture and update random state. +func TestCompileWithRandomState(t *testing.T) { + // Simulate logits for sampling + logits := NewArrayFloat32([]float32{0.1, 0.2, 0.3, 0.4}, []int32{1, 4}) + Keep(logits) + logits.Eval() + + // Initial random key + key := RandomKey(42) + Keep(key) + + // Compile a sampling function that splits the key + compiled := Compile(func(inputs []*Array) []*Array { + logits := inputs[0] + keyIn := inputs[1] + + // Split key: one for sampling, one for next iteration + key1, key2 := RandomSplit(keyIn) + + // Sample from logits + sample := RandomCategoricalWithKey(logits, key2, -1, 1) + + return []*Array{sample, key1} + }) + defer compiled.Free() + + // Run multiple sampling steps + samples := make([]int32, 10) + for i := 0; i < 10; i++ { + outputs := compiled.Call(logits, key) + Keep(outputs...) + Eval(outputs...) + samples[i] = outputs[0].ItemInt32() + key.Free() + key = outputs[1] + } + + // Verify we got valid samples (0-3) + for i, s := range samples { + if s < 0 || s > 3 { + t.Errorf("sample %d out of range: %d", i, s) + } + } + t.Logf("samples: %v", samples) + + // Verify samples aren't all the same (randomness works) + allSame := true + for i := 1; i < len(samples); i++ { + if samples[i] != samples[0] { + allSame = false + break + } + } + if allSame { + t.Error("all samples are the same - random state may not be updating") + } +} + +// swiGLU implements the GPT-OSS custom SwiGLU activation. +func swiGLU(gate, up *Array, alpha, limit float32) *Array { + gateClipped := ClipScalar(gate, 0, limit, false, true) + upClipped := ClipScalar(up, -limit, limit, true, true) + gluScaled := MulScalar(gateClipped, alpha) + sig := Sigmoid(gluScaled) + outGlu := Mul(gateClipped, sig) + return Mul(outGlu, AddScalar(upClipped, 1.0)) +} + +// TestCompileSwiGLU verifies compiled SwiGLU produces correct output. +func TestCompileSwiGLU(t *testing.T) { + gate := NewArrayFloat32([]float32{-1, 0, 1, 2, 5, 10}, []int32{6}) + up := NewArrayFloat32([]float32{-5, -1, 0, 1, 5, 10}, []int32{6}) + Keep(gate, up) + Eval(gate, up) + + const alpha float32 = 1.702 + const limit float32 = 7.0 + + // Uncompiled + expected := swiGLU(gate, up, alpha, limit) + Keep(expected) + Eval(expected) + + // Compiled + compiled := Compile(func(inputs []*Array) []*Array { + return []*Array{swiGLU(inputs[0], inputs[1], alpha, limit)} + }) + defer compiled.Free() + + result := compiled.Call(gate, up)[0] + Keep(result) + Eval(result) + + // Compare + expData := expected.Data() + resData := result.Data() + for i := range expData { + diff := expData[i] - resData[i] + if diff < 0 { + diff = -diff + } + if diff > 1e-5 { + t.Errorf("mismatch at %d: expected %f, got %f", i, expData[i], resData[i]) + } + } + t.Logf("SwiGLU results: %v", resData) +} + +// BenchmarkSwiGLUUncompiled benchmarks uncompiled SwiGLU. +func BenchmarkSwiGLUUncompiled(b *testing.B) { + gate := RandomNormal([]int32{1, 2880}, 42) + up := RandomNormal([]int32{1, 2880}, 43) + Keep(gate, up) + Eval(gate, up) + + const alpha float32 = 1.702 + const limit float32 = 7.0 + + b.ResetTimer() + for i := 0; i < b.N; i++ { + result := swiGLU(gate, up, alpha, limit) + Keep(result) + Eval(result) + } +} + +// BenchmarkSwiGLUCompiled benchmarks compiled SwiGLU. +func BenchmarkSwiGLUCompiled(b *testing.B) { + gate := RandomNormal([]int32{1, 2880}, 42) + up := RandomNormal([]int32{1, 2880}, 43) + Keep(gate, up) + Eval(gate, up) + + const alpha float32 = 1.702 + const limit float32 = 7.0 + + compiled := Compile(func(inputs []*Array) []*Array { + return []*Array{swiGLU(inputs[0], inputs[1], alpha, limit)} + }) + defer compiled.Free() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + result := compiled.Call(gate, up) + Keep(result[0]) + Eval(result[0]) + } +} + +// BenchmarkSwiGLU10xUncompiled benchmarks 10 chained SwiGLU ops uncompiled. +func BenchmarkSwiGLU10xUncompiled(b *testing.B) { + x := RandomNormal([]int32{1, 2880}, 42) + Keep(x) + x.Eval() + + const alpha float32 = 1.702 + const limit float32 = 7.0 + + b.ResetTimer() + for i := 0; i < b.N; i++ { + y := x + for j := 0; j < 10; j++ { + y = swiGLU(y, y, alpha, limit) + } + Keep(y) + Eval(y) + } +} + +// BenchmarkSwiGLU10xCompiled benchmarks 10 chained SwiGLU ops compiled. +func BenchmarkSwiGLU10xCompiled(b *testing.B) { + x := RandomNormal([]int32{1, 2880}, 42) + Keep(x) + x.Eval() + + const alpha float32 = 1.702 + const limit float32 = 7.0 + + compiled := Compile(func(inputs []*Array) []*Array { + y := inputs[0] + for j := 0; j < 10; j++ { + y = swiGLU(y, y, alpha, limit) + } + return []*Array{y} + }) + defer compiled.Free() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + result := compiled.Call(x) + Keep(result[0]) + Eval(result[0]) + } +} + +// ============ Sampler Benchmarks ============ + +// sampleTopK implements top-k sampling +func sampleTopK(logits, key *Array, k int) (*Array, *Array) { + neg := Neg(logits) + indices := Argpartition(neg, k-1, -1) + topK := Slice(indices, []int32{0}, []int32{int32(k)}) + values := TakeAlongAxis(logits, topK, -1) + key1, key2 := RandomSplit(key) + sampled := RandomCategoricalWithKey(values, key2, -1, 1) + return Take(topK, sampled, -1), key1 +} + +// sampleTopP implements top-p (nucleus) sampling +func sampleTopP(logits, key *Array, p float32, vocabSize int32) (*Array, *Array) { + sorted := Argsort(Neg(logits), -1) + sortedLogits := TakeAlongAxis(logits, sorted, -1) + probs := Softmax(sortedLogits, -1) + cumProbs := Cumsum(probs, -1) + mask := LessScalar(cumProbs, p) + negInf := FullDtype(float32(-1e9), logits.Dtype(), vocabSize) + masked := Where(mask, sortedLogits, negInf) + key1, key2 := RandomSplit(key) + sampled := RandomCategoricalWithKey(masked, key2, -1, 1) + return Take(sorted, sampled, -1), key1 +} + +// BenchmarkSampleTopKUncompiled benchmarks uncompiled top-k sampling. +func BenchmarkSampleTopKUncompiled(b *testing.B) { + vocabSize := int32(32000) + logits := RandomNormal([]int32{vocabSize}, 42) + key := RandomKey(42) + Keep(logits, key) + Eval(logits, key) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + var token *Array + token, key = sampleTopK(logits, key, 40) + Keep(token, key) + Eval(token) + } +} + +// BenchmarkSampleTopKCompiled benchmarks compiled top-k sampling. +func BenchmarkSampleTopKCompiled(b *testing.B) { + vocabSize := int32(32000) + logits := RandomNormal([]int32{vocabSize}, 42) + key := RandomKey(42) + Keep(logits, key) + Eval(logits, key) + + compiled := Compile(func(inputs []*Array) []*Array { + token, newKey := sampleTopK(inputs[0], inputs[1], 40) + return []*Array{token, newKey} + }) + defer compiled.Free() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + outputs := compiled.Call(logits, key) + Keep(outputs...) + Eval(outputs[0]) + key = outputs[1] + } +} + +// BenchmarkSampleTopPUncompiled benchmarks uncompiled top-p sampling. +func BenchmarkSampleTopPUncompiled(b *testing.B) { + vocabSize := int32(32000) + logits := RandomNormal([]int32{vocabSize}, 42) + key := RandomKey(42) + Keep(logits, key) + Eval(logits, key) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + var token *Array + token, key = sampleTopP(logits, key, 0.9, vocabSize) + Keep(token, key) + Eval(token) + } +} + +// BenchmarkSampleTopPCompiled benchmarks compiled top-p sampling. +func BenchmarkSampleTopPCompiled(b *testing.B) { + vocabSize := int32(32000) + logits := RandomNormal([]int32{vocabSize}, 42) + key := RandomKey(42) + Keep(logits, key) + Eval(logits, key) + + compiled := Compile(func(inputs []*Array) []*Array { + token, newKey := sampleTopP(inputs[0], inputs[1], 0.9, vocabSize) + return []*Array{token, newKey} + }) + defer compiled.Free() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + outputs := compiled.Call(logits, key) + Keep(outputs...) + Eval(outputs[0]) + key = outputs[1] + } +} + +// TestCompiledSamplerMemoryStable verifies compiled samplers don't leak memory. +func TestCompiledSamplerMemoryStable(t *testing.T) { + vocabSize := int32(32000) + logits := RandomNormal([]int32{vocabSize}, 42) + key := RandomKey(42) + Keep(logits, key) + Eval(logits, key) + + compiledTopK := Compile(func(inputs []*Array) []*Array { + token, newKey := sampleTopK(inputs[0], inputs[1], 40) + return []*Array{token, newKey} + }) + defer compiledTopK.Free() + + compiledTopP := Compile(func(inputs []*Array) []*Array { + token, newKey := sampleTopP(inputs[0], inputs[1], 0.9, vocabSize) + return []*Array{token, newKey} + }) + defer compiledTopP.Free() + + // Warmup + for i := 0; i < 10; i++ { + out := compiledTopK.Call(logits, key) + Keep(out...) + Eval(out[0]) + out[0].Free() + key = out[1] + } + + MetalResetPeakMemory() + initialMem := MetalGetActiveMemory() + + // Run 500 iterations of each sampler + for i := 0; i < 500; i++ { + // TopK + out := compiledTopK.Call(logits, key) + Keep(out...) + Eval(out[0]) + out[0].Free() + key = out[1] + + // TopP + out = compiledTopP.Call(logits, key) + Keep(out...) + Eval(out[0]) + out[0].Free() + key = out[1] + } + + Eval() // Final cleanup + + finalMem := MetalGetActiveMemory() + peakMem := MetalGetPeakMemory() + + growth := int64(finalMem) - int64(initialMem) + t.Logf("memory: initial=%dMB, final=%dMB, peak=%dMB, growth=%dKB", + initialMem/(1<<20), finalMem/(1<<20), peakMem/(1<<20), growth/1024) + + // Memory should stay bounded (allow 20MB for caching overhead) + if growth > 20*1024*1024 { + t.Errorf("memory grew by %d bytes over 1000 sampler calls - possible leak!", growth) + } +} + +// BenchmarkSimpleOps measures simple ops with cleanup +func BenchmarkSimpleOps(b *testing.B) { + weight := NewArrayFloat32([]float32{1, 0, 0, 1}, []int32{2, 2}) + Keep(weight) + weight.Eval() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + x := NewArrayFloat32([]float32{1, 2}, []int32{1, 2}) + result := Matmul(x, weight) + Keep(result) + AsyncEval(result) + result.Eval() + } +} + +// BenchmarkLayerLike measures layer-like ops (~15 ops) +func BenchmarkLayerLike(b *testing.B) { + hidden := int32(256) + w := Ones(hidden, hidden) + Keep(w) + w.Eval() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + x := Ones(1, hidden) + // Simulate attention-like ops with proper shapes + h := Matmul(x, w) // [1, 256] @ [256, 256] = [1, 256] + h = Add(h, Matmul(h, w)) // residual + h = Mul(h, Sigmoid(Matmul(h, w))) // gating + h = Matmul(h, w) // output projection + h = Add(x, RMSNormNoWeight(h, 1e-5)) // residual + norm + Keep(h) + AsyncEval(h) + Eval(h) + } +} + +// BenchmarkManyOps measures with increasing op counts +func BenchmarkManyOps(b *testing.B) { + w := Ones(64, 64) + Keep(w) + w.Eval() + + for _, numOps := range []int{10, 50, 100, 500, 1000} { + b.Run(fmt.Sprintf("ops_%d", numOps), func(b *testing.B) { + for i := 0; i < b.N; i++ { + x := Ones(1, 64) + for j := 0; j < numOps; j++ { + x = Add(x, Matmul(x, w)) + } + Keep(x) + AsyncEval(x) + Eval(x) + } + }) + } +} + +// BenchmarkLLMScale measures at LLM-realistic scale (~1348 arrays) +func BenchmarkLLMScale(b *testing.B) { + // Simulate Qwen-like model: 24 layers, each with ~56 ops = 1344 arrays + numLayers := 24 + opsPerLayer := 56 + + // Create weights + hidden := int32(64) + weights := make([]*Array, numLayers*4) + for i := range weights { + weights[i] = Ones(hidden, hidden) + } + Keep(weights...) + Eval(weights...) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + x := Ones(1, hidden) + + for layer := 0; layer < numLayers; layer++ { + for op := 0; op < opsPerLayer/4; op++ { + x = Add(x, Matmul(x, weights[layer*4])) + x = Mul(x, Sigmoid(x)) + } + } + Keep(x) + AsyncEval(x) + Eval(x) + } +} + +// BenchmarkArrayFreeLoop measures the cost of freeing N arrays +func BenchmarkArrayFreeLoop(b *testing.B) { + for _, count := range []int{100, 500, 1000, 1500} { + b.Run(fmt.Sprintf("arrays_%d", count), func(b *testing.B) { + for i := 0; i < b.N; i++ { + b.StopTimer() + arrays := make([]*Array, count) + for j := 0; j < count; j++ { + arrays[j] = NewArrayFloat32([]float32{1, 2, 3, 4}, []int32{2, 2}) + } + b.StartTimer() + + // Cleanup all arrays + Eval() + } + }) + } +} + +// BenchmarkCleanupIsolated measures just cleanup time +func BenchmarkCleanupIsolated(b *testing.B) { + w := NewArrayFloat32([]float32{1}, []int32{1, 1}) + Keep(w) + w.Eval() + + for _, count := range []int{100, 500, 1000, 1500} { + b.Run(fmt.Sprintf("arrays_%d", count), func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + b.StopTimer() + x := NewArrayFloat32([]float32{1}, []int32{1}) + for j := 0; j < count; j++ { + x = Add(x, x) + } + Keep(x) + b.StartTimer() + Eval() // Just cleanup + } + }) + } +} + +// TestMemoryStable verifies that cleanup doesn't cause unbounded memory growth. +func TestMemoryStable(t *testing.T) { + if testing.Short() { + t.Skip("skipping memory test in short mode") + } + + // Create realistic-sized arrays (like KV cache) + batchSize := int32(1) + numHeads := int32(8) + seqLen := int32(256) + headDim := int32(64) + cacheShape := []int32{batchSize, numHeads, seqLen, headDim} + cacheSize := batchSize * numHeads * seqLen * headDim * 4 // float32 = 4 bytes + + // Initial cache + keys := Zeros(cacheShape, DtypeFloat32) + values := Zeros(cacheShape, DtypeFloat32) + Keep(keys, values) + Eval(keys, values) + + // Warmup + for i := 0; i < 5; i++ { + oldKeys, oldValues := keys, values + + newKeys := Add(keys, keys) + newValues := Add(values, values) + Keep(newKeys, newValues) + Eval(newKeys, newValues) + + oldKeys.Free() + oldValues.Free() + keys, values = newKeys, newValues + } + + MetalResetPeakMemory() + initialMem := MetalGetActiveMemory() + + // Run 100 steps + for step := 0; step < 100; step++ { + oldKeys, oldValues := keys, values + + newKeys := Add(keys, keys) + newValues := Add(values, values) + Keep(newKeys, newValues) + Eval(newKeys, newValues) + + oldKeys.Free() + oldValues.Free() + keys, values = newKeys, newValues + } + + Eval() // Final cleanup + + finalMem := MetalGetActiveMemory() + peakMem := MetalGetPeakMemory() + + growth := int64(finalMem) - int64(initialMem) + expectedMaxGrowth := int64(cacheSize * 4 * 10) + + t.Logf("cache size: %d bytes", cacheSize*2) + t.Logf("memory: initial=%dMB, final=%dMB, peak=%dMB, growth=%dKB", + initialMem/(1<<20), finalMem/(1<<20), peakMem/(1<<20), growth/1024) + + if growth > expectedMaxGrowth { + t.Errorf("memory grew by %d bytes over 100 steps (expected max %d) - possible leak", + growth, expectedMaxGrowth) + } +} diff --git a/x/imagegen/models/gemma3/gemma3.go b/x/imagegen/models/gemma3/gemma3.go new file mode 100644 index 000000000..b56adc797 --- /dev/null +++ b/x/imagegen/models/gemma3/gemma3.go @@ -0,0 +1,614 @@ +//go:build mlx + +package gemma3 + +import ( + "encoding/json" + "fmt" + "math" + "os" + "path/filepath" + + "github.com/ollama/ollama/x/imagegen/cache" + "github.com/ollama/ollama/x/imagegen/mlx" + "github.com/ollama/ollama/x/imagegen/nn" + "github.com/ollama/ollama/x/imagegen/safetensors" + "github.com/ollama/ollama/x/imagegen/tokenizer" +) + +// TextConfig holds configuration for the text model +type TextConfig struct { + HiddenSize int32 `json:"hidden_size"` + NumHiddenLayers int32 `json:"num_hidden_layers"` + IntermediateSize int32 `json:"intermediate_size"` + NumAttentionHeads int32 `json:"num_attention_heads"` + NumKeyValueHeads int32 `json:"num_key_value_heads"` + HeadDim int32 `json:"head_dim"` + VocabSize int32 `json:"vocab_size"` + RMSNormEps float32 `json:"rms_norm_eps"` + RopeTheta float32 `json:"rope_theta"` + RopeLocalBaseFreq float32 `json:"rope_local_base_freq"` + MaxPositionEmbeddings int32 `json:"max_position_embeddings"` + SlidingWindow int32 `json:"sliding_window"` + SlidingWindowPattern int32 `json:"sliding_window_pattern"` + + // Computed fields + Scale float32 `json:"-"` +} + +// TextModel is the Gemma 3 text-only model +type TextModel struct { + EmbedTokens *nn.Embedding `weight:"model.embed_tokens"` + Layers []*DecoderLayer `weight:"model.layers"` + Norm *nn.RMSNorm `weight:"model.norm"` + Output *nn.Linear `weight:"-"` // Tied to EmbedTokens, set manually + + // Precomputed (1 + weight) for Gemma-style RMSNorm to avoid allocation per forward + NormScaled *mlx.Array `weight:"-"` + + tok *tokenizer.Tokenizer + *TextConfig +} + +// DecoderLayer is a single transformer block +type DecoderLayer struct { + InputNorm *nn.RMSNorm `weight:"input_layernorm"` + Attention *Attention + PostAttnNorm *nn.RMSNorm `weight:"post_attention_layernorm"` + PreFFNorm *nn.RMSNorm `weight:"pre_feedforward_layernorm"` + MLP *MLP + PostFFNorm *nn.RMSNorm `weight:"post_feedforward_layernorm"` + + // Precomputed (1 + weight) for Gemma-style RMSNorm + InputNormScaled *mlx.Array `weight:"-"` + PostAttnNormScaled *mlx.Array `weight:"-"` + PreFFNormScaled *mlx.Array `weight:"-"` + PostFFNormScaled *mlx.Array `weight:"-"` + + // Whether this layer uses sliding window attention + IsSliding bool + LayerIdx int32 +} + +// Attention implements Gemma 3 attention with Q/K normalization +type Attention struct { + QProj *nn.Linear `weight:"self_attn.q_proj"` + KProj *nn.Linear `weight:"self_attn.k_proj"` + VProj *nn.Linear `weight:"self_attn.v_proj"` + OProj *nn.Linear `weight:"self_attn.o_proj"` + QNorm *nn.RMSNorm `weight:"self_attn.q_norm"` + KNorm *nn.RMSNorm `weight:"self_attn.k_norm"` + + // Precomputed (1 + weight) for Gemma-style RMSNorm + QNormScaled *mlx.Array `weight:"-"` + KNormScaled *mlx.Array `weight:"-"` +} + +// MLP is the feed-forward network with GELU activation +type MLP struct { + GateProj *nn.Linear `weight:"mlp.gate_proj"` + UpProj *nn.Linear `weight:"mlp.up_proj"` + DownProj *nn.Linear `weight:"mlp.down_proj"` +} + +// LoadText loads the text-only Gemma 3 model +func LoadText(modelPath string) (*TextModel, error) { + data, err := os.ReadFile(filepath.Join(modelPath, "config.json")) + if err != nil { + return nil, fmt.Errorf("load config: %w", err) + } + var cfg TextConfig + if err := json.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("parse config: %w", err) + } + + // Compute scale + cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim))) + + // Set defaults if not specified + if cfg.RopeTheta == 0 { + cfg.RopeTheta = 1000000 + } + if cfg.RopeLocalBaseFreq == 0 { + cfg.RopeLocalBaseFreq = 10000 + } + if cfg.RMSNormEps == 0 { + cfg.RMSNormEps = 1e-6 + } + + weights, err := safetensors.LoadModelWeights(modelPath) + if err != nil { + return nil, fmt.Errorf("load weights: %w", err) + } + + tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json")) + if err != nil { + return nil, fmt.Errorf("load tokenizer: %w", err) + } + + m := &TextModel{ + Layers: make([]*DecoderLayer, cfg.NumHiddenLayers), + TextConfig: &cfg, + tok: tok, + } + + // Initialize layer metadata + for i := range m.Layers { + m.Layers[i] = &DecoderLayer{ + LayerIdx: int32(i), + IsSliding: isLayerSliding(int32(i), cfg.SlidingWindowPattern), + } + } + + if err := safetensors.LoadModule(m, weights, ""); err != nil { + return nil, err + } + + // Tied embeddings for output + m.Output = nn.NewLinear(m.EmbedTokens.Weight, nil) + + mlx.Eval(mlx.Collect(m)...) + weights.ReleaseAll() + + // Precompute (1 + weight) for Gemma-style RMSNorm to avoid per-forward allocation + precomputeGemmaScaledWeights(m) + + return m, nil +} + +// precomputeGemmaScaledWeights computes (1 + weight) for all RMSNorm layers +// This avoids creating temporary arrays on every forward pass +func precomputeGemmaScaledWeights(m *TextModel) { + m.NormScaled = mlx.AddScalar(m.Norm.Weight, 1.0) + + for _, layer := range m.Layers { + layer.InputNormScaled = mlx.AddScalar(layer.InputNorm.Weight, 1.0) + layer.PostAttnNormScaled = mlx.AddScalar(layer.PostAttnNorm.Weight, 1.0) + layer.PreFFNormScaled = mlx.AddScalar(layer.PreFFNorm.Weight, 1.0) + layer.PostFFNormScaled = mlx.AddScalar(layer.PostFFNorm.Weight, 1.0) + + layer.Attention.QNormScaled = mlx.AddScalar(layer.Attention.QNorm.Weight, 1.0) + layer.Attention.KNormScaled = mlx.AddScalar(layer.Attention.KNorm.Weight, 1.0) + } + + // Eval all the precomputed weights + var scaled []*mlx.Array + scaled = append(scaled, m.NormScaled) + for _, layer := range m.Layers { + scaled = append(scaled, layer.InputNormScaled, layer.PostAttnNormScaled, + layer.PreFFNormScaled, layer.PostFFNormScaled, + layer.Attention.QNormScaled, layer.Attention.KNormScaled) + } + mlx.Eval(scaled...) +} + +// isLayerSliding determines if a layer uses sliding window attention +// Pattern N means: layers 0 to N-1 sliding, N full, N+1 to 2N-1 sliding, 2N full, etc. +func isLayerSliding(layerIdx, pattern int32) bool { + if pattern <= 0 { + return false // No sliding window + } + // Layer is full attention if (layerIdx + 1) % pattern == 0 + return (layerIdx+1)%pattern != 0 +} + +// Forward runs the text model forward pass +func (m *TextModel) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array { + B, L := tokens.Shape()[0], tokens.Shape()[1] + + // Get embeddings and scale by sqrt(hidden_size) + h := m.EmbedTokens.Forward(tokens) + h = mlx.MulScalar(h, float32(math.Sqrt(float64(m.HiddenSize)))) + + for i, layer := range m.Layers { + h = layer.Forward(h, caches[i], B, L, m.TextConfig) + } + + // Final norm and output projection + return m.Output.Forward(mlx.RMSNorm(h, m.NormScaled, m.RMSNormEps)) +} + +// Forward runs a decoder layer +func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig) *mlx.Array { + // Pre-attention norm (use precomputed scaled weight) + normed := mlx.RMSNorm(x, l.InputNormScaled, cfg.RMSNormEps) + + // Attention + attnOut := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg) + + // Post-attention norm and residual + attnOut = mlx.RMSNorm(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps) + h := mlx.Add(x, attnOut) + + // Pre-FFN norm + normed = mlx.RMSNorm(h, l.PreFFNormScaled, cfg.RMSNormEps) + + // MLP + mlpOut := l.MLP.Forward(normed) + + // Post-FFN norm and residual + mlpOut = mlx.RMSNorm(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps) + return mlx.Add(h, mlpOut) +} + +// Forward runs attention with Q/K normalization +func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig) *mlx.Array { + q := a.QProj.Forward(x) + k := a.KProj.Forward(x) + v := a.VProj.Forward(x) + + // Reshape to [B, num_heads, L, head_dim] + q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim}, + []int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0) + k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, + []int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0) + v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, + []int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0) + + // Q/K normalization after reshaping (use precomputed scaled weight) + q = mlx.RMSNorm(q, a.QNormScaled, cfg.RMSNormEps) + k = mlx.RMSNorm(k, a.KNormScaled, cfg.RMSNormEps) + + // Apply RoPE with appropriate theta + ropeTheta := cfg.RopeTheta + if isSliding { + ropeTheta = cfg.RopeLocalBaseFreq + } + q = mlx.RoPE(q, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset()) + k = mlx.RoPE(k, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset()) + + // Update cache + k, v = c.Update(k, v, int(L)) + + // Repeat K/V for GQA if needed + repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads + if repeatFactor > 1 { + k = nn.RepeatKV(k, repeatFactor) + v = nn.RepeatKV(v, repeatFactor) + } + + // Attention + out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1) + out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim) + return a.OProj.Forward(out) +} + +// compiledGeluApprox is a singleton compiled GELU function shared across all layers +var compiledGeluApprox *mlx.CompiledFunc + +// getCompiledGeluApprox returns the compiled GELU function, creating it once if needed +func getCompiledGeluApprox() *mlx.CompiledFunc { + if compiledGeluApprox == nil { + compiledGeluApprox = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array { + return []*mlx.Array{geluApproxImpl(inputs[0])} + }, true) + } + return compiledGeluApprox +} + +// Forward runs the MLP with GELU approximation (tanh variant) +func (m *MLP) Forward(x *mlx.Array) *mlx.Array { + gate := getCompiledGeluApprox().Call(m.GateProj.Forward(x))[0] + return m.DownProj.Forward(mlx.Mul(gate, m.UpProj.Forward(x))) +} + +// geluApproxImpl computes GELU using the tanh approximation (gelu_pytorch_tanh): +// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) +func geluApproxImpl(x *mlx.Array) *mlx.Array { + // Constants + const sqrt2OverPi = 0.7978845608028654 // sqrt(2/pi) + const coeff = 0.044715 + + // x^3 + x3 := mlx.Mul(mlx.Mul(x, x), x) + // x + 0.044715 * x^3 + inner := mlx.Add(x, mlx.MulScalar(x3, coeff)) + // sqrt(2/pi) * (x + 0.044715 * x^3) + scaled := mlx.MulScalar(inner, sqrt2OverPi) + // tanh(...) + tanh := mlx.Tanh(scaled) + // 1 + tanh(...) + onePlusTanh := mlx.AddScalar(tanh, 1.0) + // 0.5 * x * (1 + tanh(...)) + return mlx.Mul(mlx.MulScalar(x, 0.5), onePlusTanh) +} + +// gemmaRMSNorm applies Gemma-style RMS normalization: x * rsqrt(mean(x^2) + eps) * (1 + weight) +// Uses mlx.RMSNorm fast kernel with pre-computed (1 + weight) +func gemmaRMSNorm(x, weight *mlx.Array, eps float32) *mlx.Array { + // Gemma uses (1 + weight) instead of weight + scaledWeight := mlx.AddScalar(weight, 1.0) + return mlx.RMSNorm(x, scaledWeight, eps) +} + +// Interface methods +func (m *TextModel) NumLayers() int { return len(m.Layers) } +func (m *TextModel) MaxContextLength() int32 { return m.MaxPositionEmbeddings } +func (m *TextModel) VocabSize() int32 { return m.TextConfig.VocabSize } + +// Tokenizer returns the tokenizer wrapped to add BOS and apply chat template +func (m *TextModel) Tokenizer() *tokenizer.Tokenizer { + return m.tok +} + +// FormatPrompt applies the Gemma 3 chat template to a prompt +func (m *TextModel) FormatPrompt(prompt string) string { + // Gemma 3 chat format: user\n{prompt}\nmodel\n + return fmt.Sprintf("user\n%s\nmodel\n", prompt) +} + +func (m *TextModel) NewCache(maxSeqLen int32) []cache.Cache { + caches := make([]cache.Cache, len(m.Layers)) + for i := range caches { + if m.Layers[i].IsSliding { + // Use rotating cache for sliding window layers + caches[i] = cache.NewRotatingKVCache(int(m.SlidingWindow)) + } else { + // Use regular cache for global attention layers + caches[i] = cache.NewKVCache() + } + } + return caches +} + +// Config holds config for the full multimodal model +type Config struct { + TextConfig TextConfig `json:"text_config"` + VisionConfig VisionConfig `json:"vision_config"` + + // Image token config (from config.json) + BOITokenIndex int32 `json:"boi_token_index"` // = 255999 + EOITokenIndex int32 `json:"eoi_token_index"` // = 256000 + ImageTokenIndex int32 `json:"image_token_index"` // = 262144 + MMTokensPerImage int32 `json:"mm_tokens_per_image"` // 256 +} + +// Model is the full Gemma 3 multimodal model +type Model struct { + VisionTower *VisionTower `weight:"vision_tower"` + Projector *MultiModalProjector `weight:"multi_modal_projector"` + TextModel *TextModel `weight:"language_model"` + Config *Config + tok *tokenizer.Tokenizer +} + +// Load loads the full multimodal Gemma 3 model +func Load(modelPath string) (*Model, error) { + data, err := os.ReadFile(filepath.Join(modelPath, "config.json")) + if err != nil { + return nil, fmt.Errorf("load config: %w", err) + } + + var cfg Config + if err := json.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("parse config: %w", err) + } + + // Set defaults for text config (multimodal config often has incomplete text_config) + // These defaults match transformers.Gemma3TextConfig defaults + tc := &cfg.TextConfig + if tc.HeadDim == 0 { + tc.HeadDim = 256 // Gemma 3 uses head_dim=256 + } + if tc.NumAttentionHeads == 0 { + // Gemma 3 4B uses 8 attention heads (cannot infer from hidden_size/head_dim) + tc.NumAttentionHeads = 8 + } + if tc.NumKeyValueHeads == 0 { + // Gemma 3 4B uses 4 KV heads (GQA with 2:1 ratio) + tc.NumKeyValueHeads = 4 + } + if tc.VocabSize == 0 { + tc.VocabSize = 262208 // Gemma 3 vocab size (not 262144!) + } + if tc.RopeTheta == 0 { + tc.RopeTheta = 1000000 + } + if tc.RopeLocalBaseFreq == 0 { + tc.RopeLocalBaseFreq = 10000 + } + if tc.RMSNormEps == 0 { + tc.RMSNormEps = 1e-6 + } + if tc.SlidingWindowPattern == 0 { + tc.SlidingWindowPattern = 6 + } + if tc.MaxPositionEmbeddings == 0 { + tc.MaxPositionEmbeddings = 131072 // Gemma 3 4B default + } + + // Compute text model scale + tc.Scale = float32(1.0 / math.Sqrt(float64(tc.HeadDim))) + + // Set defaults for image token config + if cfg.BOITokenIndex == 0 { + cfg.BOITokenIndex = 255999 // + } + if cfg.EOITokenIndex == 0 { + cfg.EOITokenIndex = 256000 // + } + if cfg.ImageTokenIndex == 0 { + cfg.ImageTokenIndex = 262144 // + } + if cfg.MMTokensPerImage == 0 { + cfg.MMTokensPerImage = 256 + } + + weights, err := safetensors.LoadModelWeights(modelPath) + if err != nil { + return nil, fmt.Errorf("load weights: %w", err) + } + + tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json")) + if err != nil { + return nil, fmt.Errorf("load tokenizer: %w", err) + } + + m := &Model{ + VisionTower: &VisionTower{ + Embeddings: &VisionEmbeddings{}, + Encoder: make([]*VisionEncoderLayer, cfg.VisionConfig.NumHiddenLayers), + Config: &cfg.VisionConfig, + }, + Projector: &MultiModalProjector{}, + TextModel: &TextModel{ + Layers: make([]*DecoderLayer, cfg.TextConfig.NumHiddenLayers), + TextConfig: &cfg.TextConfig, + }, + Config: &cfg, + tok: tok, + } + + // Initialize text layer metadata + for i := range m.TextModel.Layers { + m.TextModel.Layers[i] = &DecoderLayer{ + LayerIdx: int32(i), + IsSliding: isLayerSliding(int32(i), cfg.TextConfig.SlidingWindowPattern), + } + } + + // Initialize vision encoder layers + for i := range m.VisionTower.Encoder { + m.VisionTower.Encoder[i] = &VisionEncoderLayer{} + } + + if err := safetensors.LoadModule(m, weights, ""); err != nil { + return nil, err + } + + // Tied embeddings for text output + m.TextModel.Output = nn.NewLinear(m.TextModel.EmbedTokens.Weight, nil) + m.TextModel.tok = tok + + mlx.Eval(mlx.Collect(m)...) + weights.ReleaseAll() + + // Precompute (1 + weight) for Gemma-style RMSNorm + precomputeGemmaScaledWeights(m.TextModel) + + // Precompute projector's scaled weight + m.Projector.SoftEmbNormScaled = mlx.AddScalar(m.Projector.SoftEmbNorm.Weight, 1.0) + mlx.Eval(m.Projector.SoftEmbNormScaled) + + return m, nil +} + +// Forward runs the text-only forward pass +func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array { + return m.TextModel.Forward(tokens, caches) +} + +// ForwardWithImage runs the multimodal forward pass +// tokens: [B, L] input token IDs (with image placeholder tokens) +// image: [B, H, W, C] preprocessed image tensor +func (m *Model) ForwardWithImage(tokens *mlx.Array, image *mlx.Array, caches []cache.Cache) *mlx.Array { + B, L := tokens.Shape()[0], tokens.Shape()[1] + cfg := m.Config.TextConfig + + // Find image token position FIRST before any eval that might free tokens + imageStartPos := int32(-1) + if image != nil && B == 1 { + tokenData := tokens.DataInt32() // This evals tokens + for i, t := range tokenData { + if t == m.Config.ImageTokenIndex { + imageStartPos = int32(i) + break + } + } + } + + // Get text embeddings and scale + h := m.TextModel.EmbedTokens.Forward(tokens) + h = mlx.MulScalar(h, float32(math.Sqrt(float64(cfg.HiddenSize)))) + + // Process image if provided + if image != nil && imageStartPos >= 0 { + // Vision tower: [B, H, W, C] -> [B, num_patches, vision_hidden] + visionFeatures := m.VisionTower.Forward(image) + + // Project to text space: [B, num_patches, vision_hidden] -> [B, 256, text_hidden] + imageEmbeds := m.Projector.Forward(visionFeatures, cfg.RMSNormEps) + + // Eval h and imageEmbeds together so neither gets freed + mlx.Eval(h, imageEmbeds) + + // Cast imageEmbeds to match text embeddings dtype (bf16) + if imageEmbeds.Dtype() != h.Dtype() { + imageEmbeds = mlx.AsType(imageEmbeds, h.Dtype()) + mlx.Eval(imageEmbeds) + } + + // Insert image embeddings at the known position + h = m.insertImageEmbeddingsAt(h, imageEmbeds, imageStartPos) + } + + // Run through text model layers + for i, layer := range m.TextModel.Layers { + h = layer.Forward(h, caches[i], B, L, m.TextModel.TextConfig) + } + + // Final norm and output projection + return m.TextModel.Output.Forward(mlx.RMSNorm(h, m.TextModel.NormScaled, cfg.RMSNormEps)) +} + +// insertImageEmbeddingsAt replaces image placeholder tokens with actual image embeddings +// at a known position (to avoid re-scanning tokens after eval) +// textEmbeds: [B, L, hidden_size] text embeddings +// imageEmbeds: [B, 256, hidden_size] image embeddings from projector +// startPos: starting position of image tokens in the sequence +func (m *Model) insertImageEmbeddingsAt(textEmbeds, imageEmbeds *mlx.Array, startPos int32) *mlx.Array { + numImageTokens := imageEmbeds.Shape()[1] + L := textEmbeds.Shape()[1] + + // Split text embeddings: [0:startPos] + imageEmbeds + [startPos+256:L] + afterStart := startPos + numImageTokens + + // Slice before image tokens: textEmbeds[:, 0:startPos, :] + before := mlx.SliceAxis(textEmbeds, 1, 0, startPos) + + // Slice after image tokens: textEmbeds[:, startPos+256:L, :] + after := mlx.SliceAxis(textEmbeds, 1, afterStart, L) + + // Concatenate: before + imageEmbeds + after along axis 1 + return mlx.Concatenate([]*mlx.Array{before, imageEmbeds, after}, 1) +} + +// Interface methods for Model +func (m *Model) NumLayers() int { return len(m.TextModel.Layers) } +func (m *Model) MaxContextLength() int32 { return m.Config.TextConfig.MaxPositionEmbeddings } +func (m *Model) VocabSize() int32 { return m.Config.TextConfig.VocabSize } +func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok } +func (m *Model) NewCache(maxSeqLen int32) []cache.Cache { return m.TextModel.NewCache(maxSeqLen) } +func (m *Model) ImageSize() int32 { return m.Config.VisionConfig.ImageSize } + +// FormatPrompt applies the Gemma 3 multimodal chat template +func (m *Model) FormatPrompt(prompt string) string { + return fmt.Sprintf("user\n%s\nmodel\n", prompt) +} + +// FormatPromptWithImage applies the Gemma 3 multimodal chat template with image +func (m *Model) FormatPromptWithImage(prompt string) string { + return fmt.Sprintf("user\n%s\nmodel\n", prompt) +} + +// ExpandImageTokens expands into 256 image placeholder tokens +// Input tokens containing boi_token (255999) are expanded to: +// boi_token + 256 * image_token + eoi_token +func (m *Model) ExpandImageTokens(tokens []int32) []int32 { + result := make([]int32, 0, len(tokens)+int(m.Config.MMTokensPerImage)+1) + + for _, t := range tokens { + if t == m.Config.BOITokenIndex { + // Expand: boi + 256 * image_token + eoi + result = append(result, m.Config.BOITokenIndex) + for i := int32(0); i < m.Config.MMTokensPerImage; i++ { + result = append(result, m.Config.ImageTokenIndex) + } + result = append(result, m.Config.EOITokenIndex) + } else { + result = append(result, t) + } + } + + return result +} diff --git a/x/imagegen/models/gemma3/image.go b/x/imagegen/models/gemma3/image.go new file mode 100644 index 000000000..9532d852d --- /dev/null +++ b/x/imagegen/models/gemma3/image.go @@ -0,0 +1,58 @@ +//go:build mlx + +package gemma3 + +import ( + "fmt" + "image" + _ "image/jpeg" + _ "image/png" + "os" + + "github.com/ollama/ollama/x/imagegen/mlx" + "golang.org/x/image/draw" +) + +// ProcessImage loads and preprocesses an image for the vision tower +// Returns [1, H, W, C] tensor in NHWC format normalized for SigLIP +func ProcessImage(path string, imageSize int32) (*mlx.Array, error) { + f, err := os.Open(path) + if err != nil { + return nil, fmt.Errorf("open image: %w", err) + } + defer f.Close() + + img, _, err := image.Decode(f) + if err != nil { + return nil, fmt.Errorf("decode image: %w", err) + } + + return ProcessImageData(img, imageSize) +} + +// ProcessImageData preprocesses an image.Image for the vision tower +func ProcessImageData(img image.Image, imageSize int32) (*mlx.Array, error) { + // Resize to target size using bilinear interpolation + resized := image.NewRGBA(image.Rect(0, 0, int(imageSize), int(imageSize))) + draw.BiLinear.Scale(resized, resized.Bounds(), img, img.Bounds(), draw.Over, nil) + + // Convert to float32 array [H, W, C] and normalize + // SigLIP normalization: (pixel / 255.0 - 0.5) / 0.5 = pixel / 127.5 - 1.0 + data := make([]float32, imageSize*imageSize*3) + idx := 0 + for y := int32(0); y < imageSize; y++ { + for x := int32(0); x < imageSize; x++ { + r, g, b, _ := resized.At(int(x), int(y)).RGBA() + // RGBA returns 16-bit values, convert to 8-bit + data[idx] = float32(r>>8)/127.5 - 1.0 + data[idx+1] = float32(g>>8)/127.5 - 1.0 + data[idx+2] = float32(b>>8)/127.5 - 1.0 + idx += 3 + } + } + + // Create MLX array [1, H, W, C] for NHWC layout + arr := mlx.NewArrayFloat32(data, []int32{1, imageSize, imageSize, 3}) + mlx.Eval(arr) // Materialize to prevent use-after-free + return arr, nil +} diff --git a/x/imagegen/models/gemma3/projector.go b/x/imagegen/models/gemma3/projector.go new file mode 100644 index 000000000..ecdbe6941 --- /dev/null +++ b/x/imagegen/models/gemma3/projector.go @@ -0,0 +1,50 @@ +//go:build mlx + +package gemma3 + +import ( + "github.com/ollama/ollama/x/imagegen/mlx" + "github.com/ollama/ollama/x/imagegen/nn" +) + +// MultiModalProjector projects vision features to text embedding space +type MultiModalProjector struct { + // mm_input_projection_weight: [vision_hidden, text_hidden] + InputProjection *mlx.Array `weight:"mm_input_projection_weight"` + SoftEmbNorm *nn.RMSNorm `weight:"mm_soft_emb_norm"` + + // Precomputed (1 + weight) for Gemma-style RMSNorm + SoftEmbNormScaled *mlx.Array `weight:"-"` +} + +// Forward projects vision features to text space +// Input: [B, num_patches, vision_hidden] (e.g., [1, 4096, 1152]) +// Output: [B, num_image_tokens, text_hidden] (e.g., [1, 256, 2560]) +func (p *MultiModalProjector) Forward(visionFeatures *mlx.Array, eps float32) *mlx.Array { + // Average pool 4x4: [B, 4096, 1152] -> [B, 256, 1152] + // 4096 patches = 64x64 grid, pool to 16x16 = 256 tokens + B := visionFeatures.Shape()[0] + visionHidden := visionFeatures.Shape()[2] + + // Reshape to [B, 64, 64, hidden] + gridSize := int32(64) // sqrt(4096) + pooledSize := int32(16) // 64/4 + h := mlx.Reshape(visionFeatures, B, gridSize, gridSize, visionHidden) + + // Reshape to [B, 16, 4, 16, 4, hidden] for 4x4 pooling + h = mlx.Reshape(h, B, pooledSize, 4, pooledSize, 4, visionHidden) + + // Average over pooling dimensions (axes 2 and 4) + h = mlx.Mean(h, 4, false) + h = mlx.Mean(h, 2, false) + + // h is now [B, 16, 16, hidden], reshape to [B, 256, hidden] + numTokens := pooledSize * pooledSize + h = mlx.Reshape(h, B, numTokens, visionHidden) + + // Apply Gemma-style RMS norm (use precomputed 1 + weight) + h = mlx.RMSNorm(h, p.SoftEmbNormScaled, eps) + + // Project to text space: [B, 256, vision_hidden] @ [vision_hidden, text_hidden] + return mlx.Linear(h, p.InputProjection) +} diff --git a/x/imagegen/models/gemma3/vision.go b/x/imagegen/models/gemma3/vision.go new file mode 100644 index 000000000..1c4d8e54f --- /dev/null +++ b/x/imagegen/models/gemma3/vision.go @@ -0,0 +1,138 @@ +//go:build mlx + +package gemma3 + +import ( + "math" + + "github.com/ollama/ollama/x/imagegen/mlx" + "github.com/ollama/ollama/x/imagegen/nn" +) + +// VisionConfig holds configuration for the SigLIP vision tower +type VisionConfig struct { + HiddenSize int32 `json:"hidden_size"` + ImageSize int32 `json:"image_size"` + IntermediateSize int32 `json:"intermediate_size"` + NumAttentionHeads int32 `json:"num_attention_heads"` + NumHiddenLayers int32 `json:"num_hidden_layers"` + PatchSize int32 `json:"patch_size"` +} + +// VisionTower is the SigLIP vision encoder +type VisionTower struct { + Embeddings *VisionEmbeddings `weight:"vision_model.embeddings"` + Encoder []*VisionEncoderLayer `weight:"vision_model.encoder.layers"` + PostLayerNorm *nn.LayerNorm `weight:"vision_model.post_layernorm"` + Config *VisionConfig +} + +// VisionEmbeddings handles patch and position embeddings +type VisionEmbeddings struct { + // PatchWeight: [O, C, kH, kW] from PyTorch, transposed to [O, kH, kW, C] for MLX + PatchWeight *mlx.Array `weight:"patch_embedding.weight"` + PatchBias *mlx.Array `weight:"patch_embedding.bias"` + PosEmbed *nn.Embedding `weight:"position_embedding"` +} + +// VisionEncoderLayer is a single transformer encoder layer +type VisionEncoderLayer struct { + LayerNorm1 *nn.LayerNorm `weight:"layer_norm1"` + Attention *VisionAttention `weight:"self_attn"` + LayerNorm2 *nn.LayerNorm `weight:"layer_norm2"` + MLP *VisionMLP `weight:"mlp"` +} + +// VisionAttention implements multi-head self-attention +type VisionAttention struct { + QProj *nn.Linear `weight:"q_proj"` + KProj *nn.Linear `weight:"k_proj"` + VProj *nn.Linear `weight:"v_proj"` + OutProj *nn.Linear `weight:"out_proj"` +} + +// VisionMLP is the feed-forward network +type VisionMLP struct { + FC1 *nn.Linear `weight:"fc1"` + FC2 *nn.Linear `weight:"fc2"` +} + +// Forward runs the vision tower on preprocessed images +// Input: [B, H, W, C] normalized image tensor (NHWC layout for MLX) +// Output: [B, num_patches, hidden_size] +func (v *VisionTower) Forward(x *mlx.Array) *mlx.Array { + // Patch embedding conv: input [B, H, W, C], weight [O, kH, kW, C] -> [B, grid, grid, O] + // Weight comes as [O, C, kH, kW] from PyTorch, transpose to [O, kH, kW, C] + weight := mlx.Transpose(v.Embeddings.PatchWeight, 0, 2, 3, 1) + h := mlx.Conv2d(x, weight, v.Config.PatchSize, 0) // stride=patch_size, no padding + + // Add bias: [O] -> [1, 1, 1, O] for broadcasting + bias := mlx.Reshape(v.Embeddings.PatchBias, 1, 1, 1, v.Embeddings.PatchBias.Shape()[0]) + h = mlx.Add(h, bias) + + // h is [B, grid, grid, hidden], flatten to [B, num_patches, hidden] + B := h.Shape()[0] + gridH, gridW := h.Shape()[1], h.Shape()[2] + hidden := h.Shape()[3] + numPatches := gridH * gridW + h = mlx.Reshape(h, B, numPatches, hidden) + + // Add position embeddings + posIds := mlx.ArangeInt(0, numPatches, 1, mlx.DtypeInt32) + posEmbed := v.Embeddings.PosEmbed.Forward(posIds) + h = mlx.Add(h, posEmbed) + + // Encoder layers + headDim := float32(v.Config.HiddenSize / v.Config.NumAttentionHeads) + scale := float32(1.0 / math.Sqrt(float64(headDim))) + for _, layer := range v.Encoder { + h = layer.Forward(h, v.Config, scale) + } + + // Final layer norm + h = v.PostLayerNorm.Forward(h) + + return h +} + +// Forward runs a vision encoder layer +func (l *VisionEncoderLayer) Forward(x *mlx.Array, cfg *VisionConfig, scale float32) *mlx.Array { + // Pre-norm attention + h := l.LayerNorm1.Forward(x) + h = l.Attention.Forward(h, cfg, scale) + x = mlx.Add(x, h) + + // Pre-norm MLP + h = l.LayerNorm2.Forward(x) + h = l.MLP.Forward(h) + return mlx.Add(x, h) +} + +// Forward runs multi-head self-attention +func (a *VisionAttention) Forward(x *mlx.Array, cfg *VisionConfig, scale float32) *mlx.Array { + B, L := x.Shape()[0], x.Shape()[1] + headDim := cfg.HiddenSize / cfg.NumAttentionHeads + + q := a.QProj.Forward(x) + k := a.KProj.Forward(x) + v := a.VProj.Forward(x) + + // Reshape to [B, num_heads, L, head_dim] + q = mlx.Transpose(mlx.Reshape(q, B, L, cfg.NumAttentionHeads, headDim), 0, 2, 1, 3) + k = mlx.Transpose(mlx.Reshape(k, B, L, cfg.NumAttentionHeads, headDim), 0, 2, 1, 3) + v = mlx.Transpose(mlx.Reshape(v, B, L, cfg.NumAttentionHeads, headDim), 0, 2, 1, 3) + + // Scaled dot-product attention (no causal mask for vision) + out := mlx.ScaledDotProductAttention(q, k, v, scale, false) + + // Reshape back: [B, num_heads, L, head_dim] -> [B, L, hidden] + out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.HiddenSize) + + return a.OutProj.Forward(out) +} + +// Forward runs the MLP with GELU activation +func (m *VisionMLP) Forward(x *mlx.Array) *mlx.Array { + h := mlx.GELU(m.FC1.Forward(x)) + return m.FC2.Forward(h) +} diff --git a/x/imagegen/models/gpt_oss/gpt_oss.go b/x/imagegen/models/gpt_oss/gpt_oss.go new file mode 100644 index 000000000..bbf01370f --- /dev/null +++ b/x/imagegen/models/gpt_oss/gpt_oss.go @@ -0,0 +1,487 @@ +//go:build mlx + +package gpt_oss + +import ( + "encoding/json" + "fmt" + "math" + "os" + "path/filepath" + + "github.com/ollama/ollama/x/imagegen/cache" + "github.com/ollama/ollama/x/imagegen/mlx" + "github.com/ollama/ollama/x/imagegen/nn" + "github.com/ollama/ollama/x/imagegen/safetensors" + "github.com/ollama/ollama/x/imagegen/tokenizer" +) + +// RopeScaling holds YaRN or other RoPE scaling configuration +type RopeScaling struct { + RopeType string `json:"rope_type"` + Factor float32 `json:"factor"` + OriginalMaxPositionEmbeddings int32 `json:"original_max_position_embeddings"` + BetaFast float32 `json:"beta_fast"` + BetaSlow float32 `json:"beta_slow"` +} + +type Config struct { + HiddenSize int32 `json:"hidden_size"` + NumHiddenLayers int32 `json:"num_hidden_layers"` + IntermediateSize int32 `json:"intermediate_size"` + NumAttentionHeads int32 `json:"num_attention_heads"` + NumKeyValueHeads int32 `json:"num_key_value_heads"` + VocabSize int32 `json:"vocab_size"` + RMSNormEps float32 `json:"rms_norm_eps"` + RopeTheta float32 `json:"rope_theta"` + HeadDim int32 `json:"head_dim"` + SlidingWindow int32 `json:"sliding_window"` + NumLocalExperts int32 `json:"num_local_experts"` + NumExpertsPerTok int32 `json:"num_experts_per_tok"` + LayerTypes []string `json:"layer_types"` + SwiGLULimit float32 `json:"swiglu_limit"` + RopeScaling *RopeScaling `json:"rope_scaling"` + Scale float32 `json:"-"` // computed: 1/sqrt(HeadDim) +} + +type Attention struct { + QProj *nn.Linear `weight:"self_attn.q_proj"` + KProj *nn.Linear `weight:"self_attn.k_proj"` + VProj *nn.Linear `weight:"self_attn.v_proj"` + OProj *nn.Linear `weight:"self_attn.o_proj"` + Sinks *mlx.Array `weight:"self_attn.sinks,optional"` + YarnFreqs *mlx.Array // computed + YarnMscale float32 +} + +// swiGLU applies the GPT-OSS custom SwiGLU activation. +// Formula: (gate * sigmoid(alpha * gate)) * (up + 1) +// with clipping: gate to [None, limit], up to [-limit, limit] +func swiGLU(gate, up *mlx.Array, alpha, limit float32) *mlx.Array { + // Clip gate to [None, limit] + gateClipped := mlx.ClipScalar(gate, 0, limit, false, true) + + // Clip up to [-limit, limit] + upClipped := mlx.ClipScalar(up, -limit, limit, true, true) + + // glu_scaled = alpha * gate_clipped + gluScaled := mlx.MulScalar(gateClipped, alpha) + + // sig = sigmoid(glu_scaled) + sig := mlx.Sigmoid(gluScaled) + + // out_glu = gate_clipped * sig + outGlu := mlx.Mul(gateClipped, sig) + + // result = out_glu * (up_clipped + 1) + return mlx.Mul(outGlu, mlx.AddScalar(upClipped, 1.0)) +} + +// compiledSwiGLU is a singleton compiled SwiGLU function shared across all layers +var compiledSwiGLU *mlx.CompiledFunc + +// getCompiledSwiGLU returns the compiled SwiGLU function, creating it once if needed +func getCompiledSwiGLU() *mlx.CompiledFunc { + if compiledSwiGLU == nil { + const alpha float32 = 1.702 + const limit float32 = 7.0 + compiledSwiGLU = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array { + return []*mlx.Array{swiGLU(inputs[0], inputs[1], alpha, limit)} + }, true) // shapeless=true so it works for any input size + } + return compiledSwiGLU +} + +// ComputeYarnFreqs computes YaRN-modified RoPE frequencies +// Based on mlx-lm's YarnRoPE implementation +func ComputeYarnFreqs(dims int32, base, scalingFactor float32, origMaxPos int32, betaFast, betaSlow float32) (*mlx.Array, float32) { + // yarn_find_correction_dim + yarnFindCorrectionDim := func(numRotations float64) float64 { + return float64(dims) * math.Log(float64(origMaxPos)/(numRotations*2*math.Pi)) / (2 * math.Log(float64(base))) + } + + // yarn_find_correction_range + low := int(math.Floor(yarnFindCorrectionDim(float64(betaFast)))) + high := int(math.Ceil(yarnFindCorrectionDim(float64(betaSlow)))) + if low < 0 { + low = 0 + } + if high > int(dims)-1 { + high = int(dims) - 1 + } + + // yarn_get_mscale + yarnGetMscale := func(scale, mscale float64) float64 { + if scale <= 1 { + return 1.0 + } + return 0.1*mscale*math.Log(scale) + 1.0 + } + mscale := float32(yarnGetMscale(float64(scalingFactor), 1.0) / yarnGetMscale(float64(scalingFactor), 0.0)) + + // Compute frequencies + // freq_extra = base ** (arange(0, dims, 2) / dims) + // freq_inter = scaling_factor * freq_extra + halfDims := dims / 2 + freqData := make([]float32, halfDims) + for i := int32(0); i < halfDims; i++ { + exp := float64(2*i) / float64(dims) + freqExtra := math.Pow(float64(base), exp) + freqInter := float64(scalingFactor) * freqExtra + + // linear ramp mask + var freqMask float64 + if low == high { + freqMask = 0.0 + } else { + t := (float64(i) - float64(low)) / float64(high-low) + if t < 0 { + t = 0 + } + if t > 1 { + t = 1 + } + freqMask = 1.0 - t + } + + // Combined frequency: (inter * extra) / (inter * mask + extra * (1 - mask)) + freqData[i] = float32((freqInter * freqExtra) / (freqInter*freqMask + freqExtra*(1-freqMask))) + } + + return mlx.NewArray(freqData, []int32{halfDims}), mscale +} + +// initYarn initializes YaRN RoPE if configured +func (a *Attention) initYarn(cfg *Config) { + a.YarnMscale = 1.0 + if cfg.RopeScaling != nil && cfg.RopeScaling.RopeType == "yarn" { + a.YarnFreqs, a.YarnMscale = ComputeYarnFreqs( + cfg.HeadDim, + cfg.RopeTheta, + cfg.RopeScaling.Factor, + cfg.RopeScaling.OriginalMaxPositionEmbeddings, + cfg.RopeScaling.BetaFast, + cfg.RopeScaling.BetaSlow, + ) + } +} + +func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, mask *mlx.Array, maskMode string, cfg *Config) *mlx.Array { + q := a.QProj.Forward(x) + k := a.KProj.Forward(x) + v := a.VProj.Forward(x) + + // Reshape via AsStrided: [B, L, n_heads * head_dim] -> [B, n_heads, L, head_dim] + q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim}, + []int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0) + k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, + []int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0) + v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, + []int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0) + + offset := 0 + if c != nil { + offset = c.Offset() + } + if a.YarnFreqs != nil { + if a.YarnMscale != 1.0 { + q = mlx.MulScalar(q, a.YarnMscale) + } + q = mlx.RoPEWithFreqs(q, a.YarnFreqs, int(cfg.HeadDim), false, 1.0, offset) + k = mlx.RoPEWithFreqs(k, a.YarnFreqs, int(cfg.HeadDim), false, 1.0, offset) + } else { + q = mlx.RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset) + k = mlx.RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset) + } + + if c != nil { + k, v = c.Update(k, v, int(L)) + } + + out := mlx.ScaledDotProductAttentionWithSinks(q, k, v, cfg.Scale, maskMode, mask, a.Sinks) + out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim) + return a.OProj.Forward(out) +} + +// CreateSlidingWindowMask creates a causal mask with sliding window +// Mirrors mlx-lm's create_causal_mask with window_size +func CreateSlidingWindowMask(seqLen, queryStart, keyStart, keyLen, windowSize int) *mlx.Array { + // Build mask aligned to actual cache length (may be rotated) + // rinds covers existing keys: [keyStart, keyStart+keyLen) + // linds covers new queries: [queryStart, queryStart+seqLen) + rinds := mlx.Arange(float32(keyStart), float32(keyStart+keyLen), 1) // [keyLen] + linds := mlx.Arange(float32(queryStart), float32(queryStart+seqLen), 1) // [seqLen] + + linds = mlx.ExpandDims(linds, 1) // [seqLen, 1] + rinds = mlx.ExpandDims(rinds, 0) // [1, keyLen] + + causalMask := mlx.GreaterEqual(linds, rinds) // [seqLen, keyLen] + windowLimit := mlx.AddScalar(rinds, float32(windowSize)) + windowMask := mlx.LessArray(linds, windowLimit) // [seqLen, keyLen] + + return mlx.LogicalAnd(causalMask, windowMask) +} + +// MoE represents the Mixture of Experts SwiGLU layer with quantized experts. +type MoE struct { + Router *nn.Linear `weight:"mlp.router"` + TopK int32 + HiddenSize int32 + GroupSize int + Bits int + // Expert weights (loaded manually via sanitizeExpertWeights) + GateBlocks, GateScales, GateBias *mlx.Array + UpBlocks, UpScales, UpBias *mlx.Array + DownBlocks, DownScales, DownBias *mlx.Array +} + +func (moe *MoE) Forward(x *mlx.Array, B, L int32) *mlx.Array { + logits := moe.Router.Forward(x) + neg := mlx.Neg(logits) + part := mlx.Argpartition(neg, int(moe.TopK)-1, -1) + topKIdx := mlx.Slice(part, []int32{0, 0, 0}, []int32{B, L, moe.TopK}) + topKVal := mlx.TakeAlongAxis(logits, topKIdx, -1) + weights := mlx.Softmax(topKVal, -1) + + xFlat := mlx.Reshape(x, B*L, 1, 1, moe.HiddenSize) + idxFlat := mlx.Reshape(topKIdx, B*L, moe.TopK) + + doSort := B*L >= 64 + var invOrder *mlx.Array + sorted := false + n := B * L * moe.TopK + + if doSort { + idxAll := mlx.Flatten(idxFlat) + order := mlx.Argsort(idxAll, 0) + invOrder = mlx.Argsort(order, 0) + xFlat = mlx.ExpandDims(mlx.Take(mlx.Squeeze(xFlat, 1), mlx.FloorDivideScalar(order, moe.TopK), 0), 1) + idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1) + sorted = true + } + + gate := mlx.GatherQMM(xFlat, moe.GateBlocks, moe.GateScales, nil, nil, idxFlat, true, moe.GroupSize, moe.Bits, "mxfp4", sorted) + up := mlx.GatherQMM(xFlat, moe.UpBlocks, moe.UpScales, nil, nil, idxFlat, true, moe.GroupSize, moe.Bits, "mxfp4", sorted) + + if moe.GateBias != nil { + gate = mlx.Add(gate, mlx.ExpandDims(mlx.Take(moe.GateBias, idxFlat, 0), 2)) + } + if moe.UpBias != nil { + up = mlx.Add(up, mlx.ExpandDims(mlx.Take(moe.UpBias, idxFlat, 0), 2)) + } + + hidden := getCompiledSwiGLU().Call(gate, up)[0] + + down := mlx.GatherQMM(hidden, moe.DownBlocks, moe.DownScales, nil, nil, idxFlat, true, moe.GroupSize, moe.Bits, "mxfp4", sorted) + if moe.DownBias != nil { + down = mlx.Add(down, mlx.ExpandDims(mlx.Take(moe.DownBias, idxFlat, 0), 2)) + } + + if doSort { + down = mlx.Reshape(mlx.Take(mlx.Squeeze(mlx.Squeeze(down, 2), 1), invOrder, 0), B*L, moe.TopK, moe.HiddenSize) + } else { + down = mlx.Squeeze(down, 2) + } + + ewFlat := mlx.Reshape(weights, B*L, moe.TopK, 1) + return mlx.Reshape(mlx.Sum(mlx.Mul(down, ewFlat), 1, false), B, L, moe.HiddenSize) +} + +type Block struct { + Attention *Attention + MLP *MoE + InputNorm *nn.RMSNorm `weight:"input_layernorm"` + PostAttnNorm *nn.RMSNorm `weight:"post_attention_layernorm"` + LayerType string // "sliding_attention" or "full_attention" +} + +func (b *Block) Forward(x *mlx.Array, c cache.Cache, B, L int32, mask *mlx.Array, maskMode string, cfg *Config) *mlx.Array { + h := mlx.Add(x, b.Attention.Forward(b.InputNorm.Forward(x, cfg.RMSNormEps), c, B, L, mask, maskMode, cfg)) + return mlx.Add(h, b.MLP.Forward(b.PostAttnNorm.Forward(h, cfg.RMSNormEps), B, L)) +} + +type Model struct { + EmbedTokens *nn.Embedding `weight:"model.embed_tokens"` + Layers []*Block `weight:"-"` // loaded manually due to MoE sanitization + Norm *nn.RMSNorm `weight:"model.norm"` + LMHead *nn.Linear `weight:"lm_head"` + + tok *tokenizer.Tokenizer + *Config +} + +func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok } +func (m *Model) NumLayers() int { return len(m.Layers) } +func (m *Model) VocabSize() int32 { return m.Config.VocabSize } + +func (m *Model) NewCache(int32) []cache.Cache { + caches := make([]cache.Cache, len(m.Layers)) + for i, layer := range m.Layers { + if layer.LayerType == "sliding_attention" && m.SlidingWindow > 0 { + caches[i] = cache.NewRotatingKVCache(int(m.SlidingWindow)) + } else { + caches[i] = cache.NewKVCache() + } + } + return caches +} + +func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array { + B, L := tokens.Shape()[0], tokens.Shape()[1] + x := m.EmbedTokens.Forward(tokens) + + // Find representative cache indices for sliding window attention + var swaIdx int = -1 + for i, layer := range m.Layers { + if layer.LayerType == "sliding_attention" { + swaIdx = i + break + } + } + + // Create masks once at model level + var fullMask, swaMask *mlx.Array + var fullMaskMode, swaMaskMode string + + if L > 1 { + fullMaskMode = "causal" + if swaIdx >= 0 && m.SlidingWindow > 0 && caches != nil { + c := caches[swaIdx] + offset := c.Offset() + windowSize := int(m.SlidingWindow) + cacheLen := min(int(L), windowSize) + if offset > 0 { + cacheLen = min(c.Len()+int(L), windowSize) + } + if int(L) > windowSize { + swaMask = CreateSlidingWindowMask(int(L), offset, offset+int(L)-cacheLen, cacheLen, windowSize) + } else { + swaMaskMode = "causal" + } + } else { + swaMaskMode = "causal" + } + } + + for i, layer := range m.Layers { + var c cache.Cache + if caches != nil { + c = caches[i] + } + mask, maskMode := fullMask, fullMaskMode + if layer.LayerType == "sliding_attention" { + mask, maskMode = swaMask, swaMaskMode + } + x = layer.Forward(x, c, B, L, mask, maskMode, m.Config) + } + + return m.LMHead.Forward(m.Norm.Forward(x, m.RMSNormEps)) +} + +// sanitizeExpertWeights splits merged gate_up weights into separate gate/up arrays. +// MXFP4 quantized weights require contiguous memory - strided views give wrong results. +func sanitizeExpertWeights(weights *safetensors.ModelWeights, prefix string) (moe *MoE) { + gateUpBlocks, _ := weights.GetTensor(prefix + ".mlp.experts.gate_up_proj_blocks") + gateUpScales, _ := weights.GetTensor(prefix + ".mlp.experts.gate_up_proj_scales") + gateUpBias, _ := weights.GetTensor(prefix + ".mlp.experts.gate_up_proj_bias") + downBlocks, _ := weights.GetTensor(prefix + ".mlp.experts.down_proj_blocks") + downScales, _ := weights.GetTensor(prefix + ".mlp.experts.down_proj_scales") + downBias, _ := weights.GetTensor(prefix + ".mlp.experts.down_proj_bias") + + moe = &MoE{GroupSize: 32, Bits: 4, DownScales: downScales, DownBias: downBias} + + if gateUpBlocks != nil { + gub := mlx.FlattenRange(mlx.View(gateUpBlocks, int(mlx.DtypeUint32)), -2, -1) + s := gub.Shape() + moe.GateBlocks = mlx.Contiguous(mlx.SliceStride(gub, []int32{0, 0, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1})) + moe.UpBlocks = mlx.Contiguous(mlx.SliceStride(gub, []int32{0, 1, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1})) + } + if gateUpScales != nil { + s := gateUpScales.Shape() + moe.GateScales = mlx.Contiguous(mlx.SliceStride(gateUpScales, []int32{0, 0, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1})) + moe.UpScales = mlx.Contiguous(mlx.SliceStride(gateUpScales, []int32{0, 1, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1})) + } + if gateUpBias != nil { + s := gateUpBias.Shape() + moe.GateBias = mlx.Contiguous(mlx.SliceStride(gateUpBias, []int32{0, 0}, []int32{s[0], s[1]}, []int32{1, 2})) + moe.UpBias = mlx.Contiguous(mlx.SliceStride(gateUpBias, []int32{0, 1}, []int32{s[0], s[1]}, []int32{1, 2})) + } + if downBlocks != nil { + moe.DownBlocks = mlx.FlattenRange(mlx.View(downBlocks, int(mlx.DtypeUint32)), -2, -1) + } + return moe +} + +func Load(modelPath string) (*Model, error) { + data, err := os.ReadFile(filepath.Join(modelPath, "config.json")) + if err != nil { + return nil, fmt.Errorf("load config: %w", err) + } + var cfg Config + if err := json.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("parse config: %w", err) + } + cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim))) + + weights, err := safetensors.LoadModelWeights(modelPath) + if err != nil { + return nil, fmt.Errorf("load weights: %w", err) + } + + tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json")) + if err != nil { + return nil, fmt.Errorf("load tokenizer: %w", err) + } + + m := &Model{ + Layers: make([]*Block, cfg.NumHiddenLayers), + Config: &cfg, + tok: tok, + } + + // Load simple weights via struct tags + if err := safetensors.LoadModule(m, weights, ""); err != nil { + return nil, err + } + + // Load layers with custom MoE handling + for i := int32(0); i < cfg.NumHiddenLayers; i++ { + prefix := fmt.Sprintf("model.layers.%d", i) + layer := &Block{} + if err := safetensors.LoadModule(layer, weights, prefix); err != nil { + return nil, fmt.Errorf("layer %d: %w", i, err) + } + + // Initialize attention YaRN + layer.Attention.initYarn(&cfg) + + // Load MoE with weight sanitization + moe := sanitizeExpertWeights(weights, prefix) + moe.Router = layer.MLP.Router // Router was loaded by LoadModule + moe.TopK = cfg.NumExpertsPerTok + moe.HiddenSize = cfg.HiddenSize + layer.MLP = moe + + // Set layer type + layer.LayerType = "full_attention" + if int(i) < len(cfg.LayerTypes) { + layer.LayerType = cfg.LayerTypes[i] + } + + m.Layers[i] = layer + } + + // Release safetensors BEFORE eval - lazy arrays have captured data, + // this reduces peak memory by freeing mmap during materialization + weights.ReleaseAll() + mlx.Eval(mlx.Collect(m)...) + + return m, nil +} + +func (m *Model) MaxContextLength() int32 { + if m.RopeScaling != nil && m.RopeScaling.OriginalMaxPositionEmbeddings > 0 { + return m.RopeScaling.OriginalMaxPositionEmbeddings + } + return 131072 +} diff --git a/x/imagegen/models/llama/llama.go b/x/imagegen/models/llama/llama.go new file mode 100644 index 000000000..2b695f78e --- /dev/null +++ b/x/imagegen/models/llama/llama.go @@ -0,0 +1,152 @@ +//go:build mlx + +package llama + +import ( + "encoding/json" + "fmt" + "math" + "os" + "path/filepath" + + "github.com/ollama/ollama/x/imagegen/cache" + "github.com/ollama/ollama/x/imagegen/mlx" + "github.com/ollama/ollama/x/imagegen/nn" + "github.com/ollama/ollama/x/imagegen/safetensors" + "github.com/ollama/ollama/x/imagegen/tokenizer" +) + +type Config struct { + HiddenSize int32 `json:"hidden_size"` + NumHiddenLayers int32 `json:"num_hidden_layers"` + IntermediateSize int32 `json:"intermediate_size"` + NumAttentionHeads int32 `json:"num_attention_heads"` + NumKeyValueHeads int32 `json:"num_key_value_heads"` + VocabSize int32 `json:"vocab_size"` + RMSNormEps float32 `json:"rms_norm_eps"` + RopeTheta float32 `json:"rope_theta"` + MaxPositionEmbeddings int32 `json:"max_position_embeddings"` + HeadDim int32 `json:"-"` + Scale float32 `json:"-"` +} + +type Model struct { + EmbedTokens *nn.Embedding `weight:"model.embed_tokens"` + Layers []*Layer `weight:"model.layers"` + Norm *nn.RMSNorm `weight:"model.norm"` + Output *nn.Linear `weight:"lm_head,optional"` + + tok *tokenizer.Tokenizer + *Config +} + +type Layer struct { + Attention *Attention + MLP *MLP + AttentionNorm *nn.RMSNorm `weight:"input_layernorm"` + MLPNorm *nn.RMSNorm `weight:"post_attention_layernorm"` +} + +type Attention struct { + QProj *nn.Linear `weight:"self_attn.q_proj"` + KProj *nn.Linear `weight:"self_attn.k_proj"` + VProj *nn.Linear `weight:"self_attn.v_proj"` + OProj *nn.Linear `weight:"self_attn.o_proj"` +} + +type MLP struct { + GateProj *nn.Linear `weight:"mlp.gate_proj"` + UpProj *nn.Linear `weight:"mlp.up_proj"` + DownProj *nn.Linear `weight:"mlp.down_proj"` +} + +func Load(modelPath string) (*Model, error) { + data, err := os.ReadFile(filepath.Join(modelPath, "config.json")) + if err != nil { + return nil, fmt.Errorf("load config: %w", err) + } + var cfg Config + if err := json.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("parse config: %w", err) + } + cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads + cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim))) + + weights, err := safetensors.LoadModelWeights(modelPath) + if err != nil { + return nil, fmt.Errorf("load weights: %w", err) + } + + tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json")) + if err != nil { + return nil, fmt.Errorf("load tokenizer: %w", err) + } + + m := &Model{ + Layers: make([]*Layer, cfg.NumHiddenLayers), + Config: &cfg, + tok: tok, + } + if err := safetensors.LoadModule(m, weights, ""); err != nil { + return nil, err + } + m.Output = nn.NewLinear(m.EmbedTokens.Weight, nil) + + mlx.Eval(mlx.Collect(m)...) + weights.ReleaseAll() + + return m, nil +} + +func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array { + B, L := tokens.Shape()[0], tokens.Shape()[1] + h := m.EmbedTokens.Forward(tokens) + for i, layer := range m.Layers { + h = layer.Forward(h, caches[i], B, L, m.Config) + } + return m.Output.Forward(m.Norm.Forward(h, m.RMSNormEps)) +} + +func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array { + h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)) + return mlx.Add(h, l.MLP.Forward(l.MLPNorm.Forward(h, cfg.RMSNormEps))) +} + +func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array { + q := a.QProj.Forward(x) + k := a.KProj.Forward(x) + v := a.VProj.Forward(x) + + q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim}, + []int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0) + k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, + []int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0) + v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, + []int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0) + + q = mlx.RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset()) + k = mlx.RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset()) + + k, v = c.Update(k, v, int(L)) + out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1) + out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim) + return a.OProj.Forward(out) +} + +func (m *MLP) Forward(x *mlx.Array) *mlx.Array { + return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x))) +} + +// Interface methods +func (m *Model) NumLayers() int { return len(m.Layers) } +func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings } +func (m *Model) VocabSize() int32 { return m.Config.VocabSize } +func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok } + +func (m *Model) NewCache(maxSeqLen int32) []cache.Cache { + caches := make([]cache.Cache, len(m.Layers)) + for i := range caches { + caches[i] = cache.NewKVCache() + } + return caches +} diff --git a/x/imagegen/models/qwen_image/pipeline_test.go b/x/imagegen/models/qwen_image/pipeline_test.go new file mode 100644 index 000000000..5625427f5 --- /dev/null +++ b/x/imagegen/models/qwen_image/pipeline_test.go @@ -0,0 +1,66 @@ +//go:build mlx + +package qwen_image + +import ( + "os" + "testing" + + "github.com/ollama/ollama/x/imagegen/mlx" +) + +// TestPipelineOutput runs the full pipeline (integration test). +// Skips if model weights not found. Requires ~50GB VRAM. +func TestPipelineOutput(t *testing.T) { + modelPath := "../../../weights/Qwen-Image-2512" + if _, err := os.Stat(modelPath); os.IsNotExist(err) { + t.Skip("Skipping: model weights not found at " + modelPath) + } + + // Load model + pm, err := LoadPersistent(modelPath) + if err != nil { + t.Skipf("Skipping: failed to load model: %v", err) + } + + // Run 2-step pipeline (minimum for stable scheduler) + cfg := &GenerateConfig{ + Prompt: "a cat", + Width: 256, + Height: 256, + Steps: 2, + Seed: 42, + } + + output, err := pm.GenerateFromConfig(cfg) + if err != nil { + t.Fatalf("Pipeline failed: %v", err) + } + mlx.Eval(output) + + // Verify output shape [1, C, H, W] + shape := output.Shape() + if len(shape) != 4 { + t.Errorf("Expected 4D output, got %v", shape) + } + if shape[0] != 1 || shape[1] != 3 || shape[2] != cfg.Height || shape[3] != cfg.Width { + t.Errorf("Shape mismatch: got %v, expected [1, 3, %d, %d]", shape, cfg.Height, cfg.Width) + } + + // Verify values in expected range [0, 1] + data := output.Data() + minVal, maxVal := float32(1.0), float32(0.0) + for _, v := range data { + if v < minVal { + minVal = v + } + if v > maxVal { + maxVal = v + } + } + t.Logf("Output range: [%.4f, %.4f]", minVal, maxVal) + + if minVal < -0.1 || maxVal > 1.1 { + t.Errorf("Output values out of range: [%.4f, %.4f]", minVal, maxVal) + } +} diff --git a/x/imagegen/models/qwen_image/qwen25vl.go b/x/imagegen/models/qwen_image/qwen25vl.go new file mode 100644 index 000000000..af519ee7d --- /dev/null +++ b/x/imagegen/models/qwen_image/qwen25vl.go @@ -0,0 +1,1802 @@ +//go:build mlx + +package qwen_image + +import ( + "errors" + "fmt" + "math" + "path/filepath" + + "github.com/ollama/ollama/x/imagegen/mlx" + "github.com/ollama/ollama/x/imagegen/safetensors" + "github.com/ollama/ollama/x/imagegen/tokenizer" +) + +// Qwen25VLConfig holds Qwen2.5-VL configuration +type Qwen25VLConfig struct { + // Text model config + HiddenSize int32 `json:"hidden_size"` // 3584 + NumHiddenLayers int32 `json:"num_hidden_layers"` // 28 + IntermediateSize int32 `json:"intermediate_size"` // 18944 + NumAttentionHeads int32 `json:"num_attention_heads"` // 28 + NumKeyValueHeads int32 `json:"num_key_value_heads"` // 4 + VocabSize int32 `json:"vocab_size"` // 152064 + RMSNormEps float32 `json:"rms_norm_eps"` // 1e-6 + RopeTheta float32 `json:"rope_theta"` // 1000000 + HeadDim int32 // Calculated: HiddenSize / NumAttentionHeads + MRoPESection []int32 // [16, 24, 24] for temporal, height, width + + // Vision config + VisionHiddenSize int32 `json:"vision_hidden_size"` // 1280 + VisionNumLayers int32 `json:"vision_num_layers"` // 32 + VisionNumHeads int32 `json:"vision_num_heads"` // 16 + VisionIntermSize int32 `json:"vision_intermediate"` // 3420 + VisionPatchSize int32 `json:"vision_patch_size"` // 14 + VisionOutHiddenSize int32 `json:"vision_out_hidden"` // 3584 + VisionSpatialMerge int32 `json:"vision_spatial_merge"` // 2 + VisionWindowSize int32 `json:"vision_window_size"` // 112 + VisionFullAttIdx []int32 // [7, 15, 23, 31] + + // Special tokens + ImageTokenID int32 // 151655 + VisionStartTokenID int32 // 151652 + VisionEndTokenID int32 // 151653 +} + +// defaultQwen25VLConfig returns default config +func defaultQwen25VLConfig() *Qwen25VLConfig { + cfg := &Qwen25VLConfig{ + // Text + HiddenSize: 3584, + NumHiddenLayers: 28, + IntermediateSize: 18944, + NumAttentionHeads: 28, + NumKeyValueHeads: 4, + VocabSize: 152064, + RMSNormEps: 1e-6, + RopeTheta: 1000000, + MRoPESection: []int32{16, 24, 24}, + + // Vision + VisionHiddenSize: 1280, + VisionNumLayers: 32, + VisionNumHeads: 16, + VisionIntermSize: 3420, + VisionPatchSize: 14, + VisionOutHiddenSize: 3584, + VisionSpatialMerge: 2, + VisionWindowSize: 112, + VisionFullAttIdx: []int32{7, 15, 23, 31}, + + // Special tokens + ImageTokenID: 151655, + VisionStartTokenID: 151652, + VisionEndTokenID: 153653, + } + cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads + return cfg +} + +// Qwen25VL is the Qwen2.5-VL vision-language encoder +type Qwen25VL struct { + Config *Qwen25VLConfig + + // Text model + Embedding *mlx.Array + Blocks []*VLTextBlock + FinalNorm *mlx.Array + + // Vision tower (optional - nil for text-only models) + VisionPatchEmbed *VisionPatchEmbed + VisionBlocks []*VisionBlock + VisionMerger *VisionMerger + HasVision bool // True if vision tower is loaded +} + +// LoadTextOnly loads only the text encoder components (skips vision tower) +// Use this for text-to-image generation where vision components are not needed +func (m *Qwen25VL) LoadTextOnly(path string) error { + return m.load(path, false) +} + +// Load loads the vision-language encoder from a directory +// Vision components are loaded if weights exist +func (m *Qwen25VL) Load(path string) error { + return m.load(path, true) +} + +// load is the internal loading function +func (m *Qwen25VL) load(path string, loadVision bool) error { + fmt.Println("Loading Qwen2.5-VL encoder...") + + cfg := defaultQwen25VLConfig() + m.Config = cfg + + weights, err := safetensors.LoadModelWeights(path) + if err != nil { + return fmt.Errorf("weights: %w", err) + } + + // Bulk load all weights as bf16 + fmt.Print(" Loading weights as bf16... ") + if err := weights.Load(mlx.DtypeBFloat16); err != nil { + return fmt.Errorf("failed to load weights: %w", err) + } + fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024)) + + // Load text embedding + fmt.Print(" Loading text embeddings... ") + embedding, err := weights.Get("model.embed_tokens.weight") + if err != nil { + return err + } + m.Embedding = embedding + fmt.Printf("✓ [%v]\n", embedding.Shape()) + + // Load text blocks + m.Blocks = make([]*VLTextBlock, cfg.NumHiddenLayers) + for i := int32(0); i < cfg.NumHiddenLayers; i++ { + fmt.Printf("\r Loading text blocks... %d/%d", i+1, cfg.NumHiddenLayers) + block, err := newVLTextBlock(weights, int(i), cfg) + if err != nil { + return fmt.Errorf("failed to load text block %d: %w", i, err) + } + m.Blocks[i] = block + } + fmt.Printf("\r Loading text blocks... ✓ [%d blocks] \n", cfg.NumHiddenLayers) + + // Load final norm + fmt.Print(" Loading final norm... ") + finalNorm, err := weights.Get("model.norm.weight") + if err != nil { + return err + } + m.FinalNorm = finalNorm + fmt.Println("✓") + + // Try to load vision tower (optional) + m.HasVision = false + if loadVision { + if _, err := weights.Get("visual.patch_embed.proj.weight"); err == nil { + fmt.Print(" Loading vision patch embed... ") + m.VisionPatchEmbed, err = newVisionPatchEmbed(weights, cfg) + if err != nil { + return fmt.Errorf("vision patch embed: %w", err) + } + fmt.Println("✓") + + m.VisionBlocks = make([]*VisionBlock, cfg.VisionNumLayers) + for i := int32(0); i < cfg.VisionNumLayers; i++ { + fmt.Printf("\r Loading vision blocks... %d/%d", i+1, cfg.VisionNumLayers) + block, err := newVisionBlock(weights, int(i), cfg) + if err != nil { + return fmt.Errorf("failed to load vision block %d: %w", i, err) + } + m.VisionBlocks[i] = block + } + fmt.Printf("\r Loading vision blocks... ✓ [%d blocks] \n", cfg.VisionNumLayers) + + fmt.Print(" Loading vision merger... ") + m.VisionMerger, err = newVisionMerger(weights, cfg) + if err != nil { + return fmt.Errorf("vision merger: %w", err) + } + fmt.Println("✓") + + m.HasVision = true + } else { + fmt.Println(" (No vision tower - text-only mode)") + } + } else { + fmt.Println(" (Skipping vision tower)") + } + + weights.ReleaseAll() + return nil +} + +// EncodePrompt encodes a text prompt for image generation (text-only mode) +// Uses the Qwen-Image template and drops the first 34 tokens (system prefix) +func (m *Qwen25VL) EncodePrompt(tok *tokenizer.Tokenizer, prompt string) *mlx.Array { + cfg := m.Config + + // Template from Python: prompt_template_encode (for image generation) + template := "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n%s<|im_end|>\n<|im_start|>assistant\n" + formattedPrompt := fmt.Sprintf(template, prompt) + + // Tokenize + tokens := tok.Encode(formattedPrompt, false) + + // Create token array + seqLen := int32(len(tokens)) + tokenArr := mlx.NewArrayInt32(tokens, []int32{1, seqLen}) + + // Get text embeddings + textEmbed := mlx.EmbeddingLookup(m.Embedding, tokenArr) + + // Compute RoPE + cossin := m.computeTextRoPE(seqLen, 1) + + // Forward through ALL text blocks + x := textEmbed + for _, block := range m.Blocks { + x = block.Forward(x, cossin) + } + + // Apply final norm + x = mlx.RMSNorm(x, m.FinalNorm, cfg.RMSNormEps) + + // Drop first 34 tokens (system prefix) + // prompt_template_encode_start_idx = 34 + dropIdx := int32(34) + if x.Shape()[1] > dropIdx { + x = mlx.Slice(x, []int32{0, dropIdx, 0}, []int32{1, x.Shape()[1], cfg.HiddenSize}) + } + + return x +} + +// EncodePromptWithImage encodes a text prompt with an image +// Returns: embeddings [B, L, hidden_size], mask [B, L], error +func (m *Qwen25VL) EncodePromptWithImage(tok *tokenizer.Tokenizer, prompt string, image *mlx.Array) (*mlx.Array, *mlx.Array, error) { + if !m.HasVision { + return nil, nil, errors.New("EncodePromptWithImage called on text-only model") + } + + cfg := m.Config + + // Template from Python diffusers pipeline: prompt_template_encode + // Python's _get_qwen_prompt_embeds adds "Picture 1: " before vision tokens + template := "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\nPicture 1: <|vision_start|><|image_pad|><|vision_end|>%s<|im_end|>\n<|im_start|>assistant\n" + formattedPrompt := fmt.Sprintf(template, prompt) + + // Tokenize + tokens := tok.Encode(formattedPrompt, false) + + // Process vision if image provided + var visionEmbeddings *mlx.Array + var numImageTokens int32 + var visionH, visionW int32 // Grid dims in patches (before spatial merge) + if image != nil { + visionEmbeddings = m.encodeVision(image) + numImageTokens = visionEmbeddings.Shape()[1] + // Get original grid dimensions from image shape + imgShape := image.Shape() + visionH = imgShape[2] / cfg.VisionPatchSize // Height in patches + visionW = imgShape[3] / cfg.VisionPatchSize // Width in patches + } + + // Find image token position and expand + expandedTokens := make([]int32, 0, len(tokens)+int(numImageTokens)) + imageTokenPos := int32(-1) + textAfterCount := int32(0) + for i, t := range tokens { + if t == cfg.ImageTokenID { + imageTokenPos = int32(len(expandedTokens)) + // Insert placeholder tokens for image + for j := int32(0); j < numImageTokens; j++ { + expandedTokens = append(expandedTokens, cfg.ImageTokenID) + } + // Count remaining tokens after image + textAfterCount = int32(len(tokens) - i - 1) + } else { + expandedTokens = append(expandedTokens, t) + } + } + + // Create token array + seqLen := int32(len(expandedTokens)) + tokenArr := mlx.NewArrayInt32(expandedTokens, []int32{1, seqLen}) + + // Get text embeddings + textEmbed := mlx.EmbeddingLookup(m.Embedding, tokenArr) // [1, L, hidden] + + // Replace image token embeddings with vision embeddings + if visionEmbeddings != nil && imageTokenPos >= 0 { + // Split, replace, concat + before := mlx.Slice(textEmbed, []int32{0, 0, 0}, []int32{1, imageTokenPos, cfg.HiddenSize}) + after := mlx.Slice(textEmbed, []int32{0, imageTokenPos + numImageTokens, 0}, []int32{1, seqLen, cfg.HiddenSize}) + textEmbed = mlx.Concatenate([]*mlx.Array{before, visionEmbeddings, after}, 1) + } + + // Compute RoPE - use multimodal RoPE when image is present + var cossin [2]*mlx.Array + if image != nil && imageTokenPos >= 0 { + cossin = m.ComputeMultimodalRoPE(imageTokenPos, visionH, visionW, textAfterCount, cfg.VisionSpatialMerge) + } else { + cossin = m.computeTextRoPE(seqLen, 1) + } + + // Forward through ALL text blocks + // Python uses hidden_states[-1] (LAST layer output, not second-to-last!) + x := textEmbed + for _, block := range m.Blocks { + x = block.Forward(x, cossin) + } + + // Apply final norm (Python DOES apply this for the output) + x = mlx.RMSNorm(x, m.FinalNorm, cfg.RMSNormEps) + + // Drop first N tokens (system prefix) + // prompt_template_encode_start_idx = 64 + dropIdx := int32(64) + if x.Shape()[1] > dropIdx { + x = mlx.Slice(x, []int32{0, dropIdx, 0}, []int32{1, x.Shape()[1], cfg.HiddenSize}) + } + + // Create attention mask (all ones for now) + mask := mlx.Ones(1, x.Shape()[1]) + + return x, mask, nil +} + +// EncodeVision encodes an image through the vision tower (exported for testing) +// image: [B, C, H, W] normalized image tensor +// Returns: [B, num_tokens, hidden_size] vision embeddings +func (m *Qwen25VL) EncodeVision(image *mlx.Array) *mlx.Array { + return m.encodeVision(image) +} + +// VisionRegion describes where vision embeddings are inserted in the sequence +type VisionRegion struct { + StartPos int32 // Position in sequence where vision tokens start + NumTokens int32 // Number of vision tokens + GridH int32 // Vision grid height (in patches, after spatial merge) + GridW int32 // Vision grid width (in patches, after spatial merge) +} + +// EncodePromptWithImages encodes a text prompt with multiple images +// Returns: embeddings [B, L, hidden_size], mask [B, L], regions []VisionRegion, error +func (m *Qwen25VL) EncodePromptWithImages(tok *tokenizer.Tokenizer, prompt string, images []*mlx.Array) (*mlx.Array, *mlx.Array, []VisionRegion, error) { + if !m.HasVision { + return nil, nil, nil, errors.New("EncodePromptWithImages called on text-only model") + } + if len(images) == 0 { + return nil, nil, nil, errors.New("EncodePromptWithImages called with no images") + } + + cfg := m.Config + + // Build image prompt prefix: "Picture 1: ...Picture N: ..." + imgPromptTemplate := "Picture %d: <|vision_start|><|image_pad|><|vision_end|>" + imgPrompt := "" + for i := range images { + imgPrompt += fmt.Sprintf(imgPromptTemplate, i+1) + } + + // Template from Python diffusers pipeline: prompt_template_encode + template := "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n%s%s<|im_end|>\n<|im_start|>assistant\n" + formattedPrompt := fmt.Sprintf(template, imgPrompt, prompt) + + // Tokenize + tokens := tok.Encode(formattedPrompt, false) + + // Process each image through vision tower + visionEmbeddings := make([]*mlx.Array, len(images)) + numImageTokens := make([]int32, len(images)) + visionGridH := make([]int32, len(images)) + visionGridW := make([]int32, len(images)) + + for i, image := range images { + visionEmbeddings[i] = m.encodeVision(image) + numImageTokens[i] = visionEmbeddings[i].Shape()[1] + // Get original grid dimensions from image shape + imgShape := image.Shape() + visionH := imgShape[2] / cfg.VisionPatchSize // Height in patches + visionW := imgShape[3] / cfg.VisionPatchSize // Width in patches + // After spatial merge, grid is halved + visionGridH[i] = visionH / cfg.VisionSpatialMerge + visionGridW[i] = visionW / cfg.VisionSpatialMerge + } + + // Find all image token positions and expand tokens + expandedTokens := make([]int32, 0, len(tokens)+int(sum(numImageTokens))) + imagePositions := make([]int32, 0, len(images)) // Start position for each image's tokens + imageIdx := 0 + + for _, t := range tokens { + if t == cfg.ImageTokenID { + if imageIdx < len(images) { + imagePositions = append(imagePositions, int32(len(expandedTokens))) + // Insert placeholder tokens for this image + for j := int32(0); j < numImageTokens[imageIdx]; j++ { + expandedTokens = append(expandedTokens, cfg.ImageTokenID) + } + imageIdx++ + } + } else { + expandedTokens = append(expandedTokens, t) + } + } + + // Create token array + seqLen := int32(len(expandedTokens)) + tokenArr := mlx.NewArrayInt32(expandedTokens, []int32{1, seqLen}) + + // Get text embeddings + textEmbed := mlx.EmbeddingLookup(m.Embedding, tokenArr) // [1, L, hidden] + + // Replace image token embeddings with vision embeddings + // Build list of segments to concatenate + segments := make([]*mlx.Array, 0, len(images)*2+1) + regions := make([]VisionRegion, len(images)) + lastEnd := int32(0) + + for i, imgPos := range imagePositions { + // Text segment before this image + if imgPos > lastEnd { + segments = append(segments, mlx.Slice(textEmbed, []int32{0, lastEnd, 0}, []int32{1, imgPos, cfg.HiddenSize})) + } + // Vision embeddings for this image + segments = append(segments, visionEmbeddings[i]) + regions[i] = VisionRegion{ + StartPos: imgPos, + NumTokens: numImageTokens[i], + GridH: visionGridH[i], + GridW: visionGridW[i], + } + lastEnd = imgPos + numImageTokens[i] + } + // Remaining text after last image + if lastEnd < seqLen { + segments = append(segments, mlx.Slice(textEmbed, []int32{0, lastEnd, 0}, []int32{1, seqLen, cfg.HiddenSize})) + } + + // Concatenate all segments + textEmbed = mlx.Concatenate(segments, 1) + + // Compute RoPE - use multimodal RoPE for multiple images + cossin, err := m.ComputeMultiImageRoPE(imagePositions, visionGridH, visionGridW, numImageTokens, seqLen) + if err != nil { + return nil, nil, nil, fmt.Errorf("computing RoPE: %w", err) + } + + // Forward through ALL text blocks + x := textEmbed + for _, block := range m.Blocks { + x = block.Forward(x, cossin) + } + + // Apply final norm + x = mlx.RMSNorm(x, m.FinalNorm, cfg.RMSNormEps) + + // Drop first N tokens (system prefix) + // prompt_template_encode_start_idx = 64 + dropIdx := int32(64) + if x.Shape()[1] > dropIdx { + x = mlx.Slice(x, []int32{0, dropIdx, 0}, []int32{1, x.Shape()[1], cfg.HiddenSize}) + // Adjust region positions + for i := range regions { + regions[i].StartPos -= dropIdx + } + } + + // Create attention mask (all ones) + mask := mlx.Ones(1, x.Shape()[1]) + + return x, mask, regions, nil +} + +// sum returns the sum of int32 slice +func sum(arr []int32) int32 { + var s int32 + for _, v := range arr { + s += v + } + return s +} + +// EncodeTextOnly encodes text tokens through all text blocks (exported for testing) +// tokens: array of token IDs +// Returns: [B, L, hidden_size] text embeddings after all blocks +func (m *Qwen25VL) EncodeTextOnly(tokens []int32) *mlx.Array { + seqLen := int32(len(tokens)) + tokenArr := mlx.NewArrayInt32(tokens, []int32{1, seqLen}) + + // Get text embeddings + textEmbed := mlx.EmbeddingLookup(m.Embedding, tokenArr) // [1, L, hidden] + + // Compute RoPE + cossin := m.computeTextRoPE(seqLen, 1) + + // Forward through ALL text blocks (unlike Encode which stops at second-to-last) + x := textEmbed + for _, block := range m.Blocks { + x = block.Forward(x, cossin) + } + + // Apply final norm + x = mlx.RMSNorm(x, m.FinalNorm, m.Config.RMSNormEps) + + return x +} + +// encodeVision encodes an image through the vision tower +// image: [B, C, H, W] normalized image tensor +// Returns: [B, num_tokens, hidden_size] vision embeddings +func (m *Qwen25VL) encodeVision(image *mlx.Array) *mlx.Array { + cfg := m.Config + + // Calculate grid dimensions from image + imgShape := image.Shape() + imgH := imgShape[2] + imgW := imgShape[3] + pH := imgH / cfg.VisionPatchSize // grid height in patches + pW := imgW / cfg.VisionPatchSize // grid width in patches + + // Patch embed + x := m.VisionPatchEmbed.Forward(image) + mlx.Eval(x) + + // Get window reordering info + winInfo := m.getWindowInfo(pH, pW) + + // Compute vision RoPE embeddings (already in 2x2-block order) + posEmb := m.computeVisionRoPE(pH, pW) + + shape := x.Shape() + B := shape[0] + L := shape[1] // num patches = pH * pW + D := shape[2] + spatialMergeUnit := winInfo.SpatialMergeUnit + spatialMerge := cfg.VisionSpatialMerge + + // Convert patch embed from row-major to 2x2-block order + // Row-major: (0,0), (0,1), (0,2), ..., (1,0), (1,1), ... + // 2x2-block: (0,0), (0,1), (1,0), (1,1), (0,2), (0,3), (1,2), (1,3), ... + llmGridH := pH / spatialMerge + llmGridW := pW / spatialMerge + blockReorderIdx := make([]int32, L) + idx := int32(0) + for hBlock := int32(0); hBlock < llmGridH; hBlock++ { + for wBlock := int32(0); wBlock < llmGridW; wBlock++ { + for dh := int32(0); dh < spatialMerge; dh++ { + for dw := int32(0); dw < spatialMerge; dw++ { + h := hBlock*spatialMerge + dh + w := wBlock*spatialMerge + dw + rowMajorIdx := h*pW + w + blockReorderIdx[idx] = rowMajorIdx + idx++ + } + } + } + } + blockIdxArr := mlx.NewArrayInt32(blockReorderIdx, []int32{L}) + x = mlx.Take(x, blockIdxArr, 1) // Reorder patches to 2x2-block order + + // Window reorder hidden states and RoPE before blocks + // Python: reshape to [L/4, 4, D], reorder dim 0, reshape back + // Reshape x: [B, L, D] -> [B, L/4, 4, D] + x = mlx.Reshape(x, B, L/spatialMergeUnit, spatialMergeUnit, D) + // Reorder using window index + winIdxArr := mlx.NewArrayInt32(winInfo.WindowIndex, []int32{int32(len(winInfo.WindowIndex))}) + x = mlx.Take(x, winIdxArr, 1) // Take along axis 1 + // Reshape back: [B, L/4, 4, D] -> [B, L, D] + x = mlx.Reshape(x, B, L, D) + + // Similarly reorder RoPE: [L, headDim] -> [L/4, 4, headDim] -> reorder -> [L, headDim] + cosShape := posEmb[0].Shape() + ropeL := cosShape[0] + ropeD := cosShape[1] + cos := mlx.Reshape(posEmb[0], ropeL/spatialMergeUnit, spatialMergeUnit, ropeD) + sin := mlx.Reshape(posEmb[1], ropeL/spatialMergeUnit, spatialMergeUnit, ropeD) + cos = mlx.Take(cos, winIdxArr, 0) + sin = mlx.Take(sin, winIdxArr, 0) + cos = mlx.Reshape(cos, ropeL, ropeD) + sin = mlx.Reshape(sin, ropeL, ropeD) + posEmb = [2]*mlx.Array{cos, sin} + + // Materialize to prevent freeing during block evaluations + mlx.Eval(x, posEmb[0], posEmb[1]) + + // Full sequence cu_seqlens for full attention blocks + cuSeqlensFull := []int32{0, L} + + // Vision blocks - use window attention except at full attention indices + for i, block := range m.VisionBlocks { + useFullAttention := false + for _, idx := range cfg.VisionFullAttIdx { + if int32(i) == idx { + useFullAttention = true + break + } + } + + var cuSeqlens []int32 + if useFullAttention { + cuSeqlens = cuSeqlensFull + } else { + cuSeqlens = winInfo.CuWindowSeqlens + } + + x = block.Forward(x, posEmb, cuSeqlens) + } + + // Spatial merge (2x2 -> 1) + x = m.VisionMerger.ForwardWithDims(x, pH, pW) + + // Reverse window reorder after merger + revIdxArr := mlx.NewArrayInt32(winInfo.ReverseIndex, []int32{int32(len(winInfo.ReverseIndex))}) + x = mlx.Take(x, revIdxArr, 1) + + return x +} + +// WindowInfo holds window reordering and attention boundary info +type WindowInfo struct { + WindowIndex []int32 // Reordering indices for merged tokens + ReverseIndex []int32 // Reverse reordering indices + CuWindowSeqlens []int32 // Cumulative window boundaries in UNMERGED sequence + SpatialMergeUnit int32 // Number of patches per merged token (4 = 2x2) +} + +// getWindowInfo computes window reordering indices and attention boundaries +// pH, pW: patch grid dimensions before 2x2 merge +func (m *Qwen25VL) getWindowInfo(pH, pW int32) *WindowInfo { + cfg := m.Config + spatialMergeUnit := cfg.VisionSpatialMerge * cfg.VisionSpatialMerge // 4 + + // After 2x2 merge + llmGridH := pH / cfg.VisionSpatialMerge + llmGridW := pW / cfg.VisionSpatialMerge + numTokens := llmGridH * llmGridW + + // Window size in merged tokens + // window_size=112, spatial_merge_size=2, patch_size=14 + // vit_merger_window_size = 112 / 2 / 14 = 4 + vitMergerWindowSize := cfg.VisionWindowSize / cfg.VisionSpatialMerge / cfg.VisionPatchSize + + // Calculate padding and number of windows + padH := vitMergerWindowSize - llmGridH%vitMergerWindowSize + if padH == vitMergerWindowSize { + padH = 0 + } + padW := vitMergerWindowSize - llmGridW%vitMergerWindowSize + if padW == vitMergerWindowSize { + padW = 0 + } + + numWindowsH := (llmGridH + padH) / vitMergerWindowSize + numWindowsW := (llmGridW + padW) / vitMergerWindowSize + + // Create padded grid with -1 for padding + paddedH := llmGridH + padH + paddedW := llmGridW + padW + grid := make([]int32, paddedH*paddedW) + for i := range grid { + grid[i] = -1 + } + for h := int32(0); h < llmGridH; h++ { + for w := int32(0); w < llmGridW; w++ { + grid[h*paddedW+w] = h*llmGridW + w + } + } + + // Reorder into windows and track window sizes + windowIndex := make([]int32, 0, numTokens) + windowSizes := make([]int32, 0, numWindowsH*numWindowsW) + ws := vitMergerWindowSize + + for wh := int32(0); wh < numWindowsH; wh++ { + for ww := int32(0); ww < numWindowsW; ww++ { + windowStart := len(windowIndex) + // Extract window + for h := int32(0); h < ws; h++ { + for w := int32(0); w < ws; w++ { + idx := (wh*ws+h)*paddedW + (ww*ws + w) + if grid[idx] >= 0 { + windowIndex = append(windowIndex, grid[idx]) + } + } + } + windowSize := int32(len(windowIndex) - windowStart) + windowSizes = append(windowSizes, windowSize) + } + } + + // Create reverse index (argsort of windowIndex) + reverseIndex := make([]int32, numTokens) + for i, idx := range windowIndex { + reverseIndex[idx] = int32(i) + } + + // Compute cumulative sequence lengths in UNMERGED sequence + // Each merged token corresponds to spatialMergeUnit patches + cuWindowSeqlens := make([]int32, len(windowSizes)+1) + cuWindowSeqlens[0] = 0 + for i, size := range windowSizes { + cuWindowSeqlens[i+1] = cuWindowSeqlens[i] + size*spatialMergeUnit + } + + return &WindowInfo{ + WindowIndex: windowIndex, + ReverseIndex: reverseIndex, + CuWindowSeqlens: cuWindowSeqlens, + SpatialMergeUnit: spatialMergeUnit, + } +} + +// ComputeMultiImageRoPE computes M-RoPE for combined text + multiple vision regions + text sequences +// This extends ComputeMultimodalRoPE to handle N images instead of just one. +// +// Parameters: +// - imagePositions: starting position of each image's tokens in the sequence +// - visionGridH, visionGridW: grid dimensions for each image (after spatial merge) +// - numImageTokens: number of tokens for each image +// - totalLen: total sequence length +func (m *Qwen25VL) ComputeMultiImageRoPE(imagePositions []int32, visionGridH, visionGridW, numImageTokens []int32, totalLen int32) ([2]*mlx.Array, error) { + numImages := len(imagePositions) + + // Build 3D position IDs: [3, 1, totalLen] + // Dimension 0: temporal, Dimension 1: height, Dimension 2: width + posIDs := make([]float32, 3*totalLen) + + // Process sequence in order + stIdx := int32(0) // Running text position counter + seqIdx := int32(0) + + for i := 0; i < numImages; i++ { + imgPos := imagePositions[i] + gridH := visionGridH[i] + gridW := visionGridW[i] + numTokens := numImageTokens[i] + + // Text segment before this image + for seqIdx < imgPos { + posIDs[0*totalLen+seqIdx] = float32(stIdx) + posIDs[1*totalLen+seqIdx] = float32(stIdx) + posIDs[2*totalLen+seqIdx] = float32(stIdx) + stIdx++ + seqIdx++ + } + + // Vision tokens for this image + // Python uses stIdx as base offset for all position dimensions + for h := int32(0); h < gridH; h++ { + for w := int32(0); w < gridW; w++ { + posIDs[0*totalLen+seqIdx] = float32(stIdx) // temporal: constant = stIdx + posIDs[1*totalLen+seqIdx] = float32(stIdx + h) // height: stIdx + row_index + posIDs[2*totalLen+seqIdx] = float32(stIdx + w) // width: stIdx + col_index + seqIdx++ + } + } + + // Verify we processed the expected number of tokens + if seqIdx != imgPos+numTokens { + return [2]*mlx.Array{}, fmt.Errorf("mismatch: processed %d but expected %d tokens for image %d", seqIdx-imgPos, numTokens, i) + } + + // Update stIdx for next text segment: max(temporal, height, width) + 1 + maxVisionPos := stIdx // temporal max + if stIdx+gridH-1 > maxVisionPos { + maxVisionPos = stIdx + gridH - 1 + } + if stIdx+gridW-1 > maxVisionPos { + maxVisionPos = stIdx + gridW - 1 + } + stIdx = maxVisionPos + 1 + } + + // Text after last image + for seqIdx < totalLen { + posIDs[0*totalLen+seqIdx] = float32(stIdx) + posIDs[1*totalLen+seqIdx] = float32(stIdx) + posIDs[2*totalLen+seqIdx] = float32(stIdx) + stIdx++ + seqIdx++ + } + + posIDsArr := mlx.NewArray(posIDs, []int32{3, 1, totalLen}) + return m.computeRoPEFromPositions(posIDsArr, totalLen, 1), nil +} + +// computeTextRoPE computes M-RoPE for text-only sequences +func (m *Qwen25VL) computeTextRoPE(L, B int32) [2]*mlx.Array { + // For text-only, all 3 dims use same positions [0, 1, 2, ..., L-1] + posArr := make([]float32, L*3) + for d := 0; d < 3; d++ { + for i := int32(0); i < L; i++ { + posArr[int32(d)*L+i] = float32(i) + } + } + posIDs := mlx.NewArray(posArr, []int32{3, 1, L}) + posIDs = mlx.Tile(posIDs, []int32{1, B, 1}) + return m.computeRoPEFromPositions(posIDs, L, B) +} + +// ComputeMultimodalRoPE computes M-RoPE for combined text + vision + text sequences +// This matches Python's get_rope_index behavior exactly. +// Exported for testing. +// +// Python pattern discovered from testing: +// +// Vision row 1: temporal=stIdx, height=stIdx, width=[stIdx, stIdx+1, ..., stIdx+gridW-1] +// Vision row 2: temporal=stIdx, height=stIdx+1, width=[stIdx, stIdx+1, ..., stIdx+gridW-1] +// Text after: temporal=stIdx+1+i, height=stIdx+gridH+i, width=stIdx+gridW+i +func (m *Qwen25VL) ComputeMultimodalRoPE(textBefore, visionH, visionW, textAfter int32, spatialMerge int32) [2]*mlx.Array { + // Vision grid after spatial merge + llmGridH := visionH / spatialMerge + llmGridW := visionW / spatialMerge + visionLen := llmGridH * llmGridW + totalLen := textBefore + visionLen + textAfter + + // Build 3D position IDs: [3, 1, totalLen] + // Dimension 0: temporal, Dimension 1: height, Dimension 2: width + posIDs := make([]float32, 3*totalLen) + + // Text before vision: all dims same [0, 1, 2, ..., textBefore-1] + for d := 0; d < 3; d++ { + for i := int32(0); i < textBefore; i++ { + posIDs[int32(d)*totalLen+i] = float32(i) + } + } + + // Vision tokens: 3D grid positions + // Python uses stIdx (textBefore) as base offset for all position dimensions + stIdx := textBefore + for h := int32(0); h < llmGridH; h++ { + for w := int32(0); w < llmGridW; w++ { + idx := stIdx + h*llmGridW + w + posIDs[0*totalLen+idx] = float32(stIdx) // temporal: constant = stIdx + posIDs[1*totalLen+idx] = float32(stIdx + h) // height: stIdx + row_index + posIDs[2*totalLen+idx] = float32(stIdx + w) // width: stIdx + col_index + } + } + + // Text after vision: ALL dimensions continue from max(temporal, height, width) + 1 + // max is max(stIdx, stIdx+llmGridH-1, stIdx+llmGridW-1) = stIdx + max(0, llmGridH-1, llmGridW-1) + // Then st_idx = max + 1 + maxVisionPos := stIdx // temporal max + if stIdx+llmGridH-1 > maxVisionPos { + maxVisionPos = stIdx + llmGridH - 1 + } + if stIdx+llmGridW-1 > maxVisionPos { + maxVisionPos = stIdx + llmGridW - 1 + } + textAfterStart := maxVisionPos + 1 + for i := int32(0); i < textAfter; i++ { + seqIdx := textBefore + visionLen + i + posIDs[0*totalLen+seqIdx] = float32(textAfterStart + i) // temporal + posIDs[1*totalLen+seqIdx] = float32(textAfterStart + i) // height + posIDs[2*totalLen+seqIdx] = float32(textAfterStart + i) // width + } + + posIDsArr := mlx.NewArray(posIDs, []int32{3, 1, totalLen}) + return m.computeRoPEFromPositions(posIDsArr, totalLen, 1) +} + +// computeRoPEFromPositions computes cos/sin from 3D position IDs +// posIDs: [3, B, L] where dim 0 is temporal, 1 is height, 2 is width +func (m *Qwen25VL) computeRoPEFromPositions(posIDs *mlx.Array, L, B int32) [2]*mlx.Array { + cfg := m.Config + half := cfg.HeadDim / 2 + + // Compute inv_freq + invFreqArr := make([]float32, half) + for i := int32(0); i < half; i++ { + invFreqArr[i] = float32(1.0 / math.Pow(float64(cfg.RopeTheta), 2.0*float64(i)/float64(cfg.HeadDim))) + } + invFreq := mlx.NewArray(invFreqArr, []int32{half}) + + // Process each position dimension + var cosAll, sinAll []*mlx.Array + for d := int32(0); d < 3; d++ { + // Get positions for this dimension: [B, L] + pos := mlx.Slice(posIDs, []int32{d, 0, 0}, []int32{d + 1, B, L}) + pos = mlx.Squeeze(pos, 0) // [B, L] + + posExp := mlx.ExpandDims(pos, 2) // [B, L, 1] + invFreqExp := mlx.Reshape(invFreq, 1, 1, half) // [1, 1, half] + freqs := mlx.Mul(posExp, invFreqExp) // [B, L, half] + emb := mlx.Tile(freqs, []int32{1, 1, 2}) // [B, L, D] + + cosAll = append(cosAll, mlx.ExpandDims(mlx.Cos(emb), 0)) + sinAll = append(sinAll, mlx.ExpandDims(mlx.Sin(emb), 0)) + } + + cos := mlx.Concatenate(cosAll, 0) // [3, B, L, D] + sin := mlx.Concatenate(sinAll, 0) + + return [2]*mlx.Array{cos, sin} +} + +// computeVisionRoPE computes RoPE embeddings for vision patches +// pH, pW: grid dimensions in patches +// Returns: [2]*mlx.Array containing (cos, sin) each of shape [numPatches, headDim] +func (m *Qwen25VL) computeVisionRoPE(pH, pW int32) [2]*mlx.Array { + cfg := m.Config + headDim := cfg.VisionHiddenSize / cfg.VisionNumHeads // 80 for 1280/16 + halfDim := headDim / 2 // 40 + quarterDim := halfDim / 2 // 20 + spatialMerge := cfg.VisionSpatialMerge // 2 + + // Python Qwen2_5_VisionRotaryEmbedding uses dim=head_dim/2=40 + // inv_freq = 1.0 / (theta ** (arange(0, dim, 2) / dim)) -> 20 elements + theta := float64(10000.0) + invFreqArr := make([]float32, quarterDim) + for i := int32(0); i < quarterDim; i++ { + invFreqArr[i] = float32(1.0 / math.Pow(theta, float64(2*i)/float64(halfDim))) + } + invFreq := mlx.NewArray(invFreqArr, []int32{quarterDim}) + + // Create position IDs matching Python's 2x2 block ordering: + // Python does: reshape(h//2, 2, w//2, 2), permute(0, 2, 1, 3), flatten + // This groups patches by 2x2 merged token blocks + numPatches := pH * pW + hPosArr := make([]float32, numPatches) + wPosArr := make([]float32, numPatches) + + // Number of merged token blocks + llmGridH := pH / spatialMerge + llmGridW := pW / spatialMerge + + idx := int32(0) + for hBlock := int32(0); hBlock < llmGridH; hBlock++ { + for wBlock := int32(0); wBlock < llmGridW; wBlock++ { + // Within each 2x2 block: (0,0), (0,1), (1,0), (1,1) + for dh := int32(0); dh < spatialMerge; dh++ { + for dw := int32(0); dw < spatialMerge; dw++ { + h := hBlock*spatialMerge + dh + w := wBlock*spatialMerge + dw + hPosArr[idx] = float32(h) + wPosArr[idx] = float32(w) + idx++ + } + } + } + } + + hPos := mlx.NewArray(hPosArr, []int32{numPatches, 1}) + wPos := mlx.NewArray(wPosArr, []int32{numPatches, 1}) + invFreqExp := mlx.Reshape(invFreq, 1, quarterDim) + + // Compute freqs: [numPatches, quarterDim] for each of h and w + hFreqs := mlx.Mul(hPos, invFreqExp) // [L, 20] + wFreqs := mlx.Mul(wPos, invFreqExp) // [L, 20] + + // Concatenate h and w freqs: [numPatches, halfDim] = [L, 40] + freqs := mlx.Concatenate([]*mlx.Array{hFreqs, wFreqs}, 1) + + // Double for cos/sin application: [L, 40] -> [L, 80] = [L, headDim] + emb := mlx.Concatenate([]*mlx.Array{freqs, freqs}, 1) + + cos := mlx.Cos(emb) + sin := mlx.Sin(emb) + + return [2]*mlx.Array{cos, sin} +} + +// VLTextBlock is a single Qwen2.5 transformer block (for VL model) +type VLTextBlock struct { + Attention *VLTextAttention + MLP *VLTextMLP + InputLayerNorm *mlx.Array + PostAttnLayerNorm *mlx.Array + NormEps float32 +} + +// newVLTextBlock creates a text block +func newVLTextBlock(weights *safetensors.ModelWeights, layerIdx int, cfg *Qwen25VLConfig) (*VLTextBlock, error) { + prefix := fmt.Sprintf("model.layers.%d", layerIdx) + + inputNorm, err := weights.Get(prefix + ".input_layernorm.weight") + if err != nil { + return nil, err + } + postAttnNorm, err := weights.Get(prefix + ".post_attention_layernorm.weight") + if err != nil { + return nil, err + } + + attention, err := newVLTextAttention(weights, prefix, cfg) + if err != nil { + return nil, err + } + + mlpLayer, err := newVLTextMLP(weights, prefix) + if err != nil { + return nil, err + } + + return &VLTextBlock{ + Attention: attention, + MLP: mlpLayer, + InputLayerNorm: inputNorm, + PostAttnLayerNorm: postAttnNorm, + NormEps: cfg.RMSNormEps, + }, nil +} + +// Forward applies the block +func (tb *VLTextBlock) Forward(x *mlx.Array, cossin [2]*mlx.Array) *mlx.Array { + h := mlx.RMSNorm(x, tb.InputLayerNorm, tb.NormEps) + attnOut := tb.Attention.Forward(h, cossin) + x = mlx.Add(x, attnOut) + + h = mlx.RMSNorm(x, tb.PostAttnLayerNorm, tb.NormEps) + mlpOut := tb.MLP.Forward(h) + x = mlx.Add(x, mlpOut) + + return x +} + +// VLTextAttention implements Qwen2.5 attention with M-RoPE +type VLTextAttention struct { + QProj *mlx.Array + KProj *mlx.Array + VProj *mlx.Array + OProj *mlx.Array + QBias *mlx.Array + KBias *mlx.Array + VBias *mlx.Array + NHeads int32 + NKVHeads int32 + HeadDim int32 + Scale float32 + MRoPESection []int32 +} + +// newVLTextAttention creates a text attention layer +func newVLTextAttention(weights *safetensors.ModelWeights, prefix string, cfg *Qwen25VLConfig) (*VLTextAttention, error) { + qProj, err := weights.Get(prefix + ".self_attn.q_proj.weight") + if err != nil { + return nil, err + } + kProj, err := weights.Get(prefix + ".self_attn.k_proj.weight") + if err != nil { + return nil, err + } + vProj, err := weights.Get(prefix + ".self_attn.v_proj.weight") + if err != nil { + return nil, err + } + oProj, err := weights.Get(prefix + ".self_attn.o_proj.weight") + if err != nil { + return nil, err + } + + qBias, _ := weights.Get(prefix + ".self_attn.q_proj.bias") + kBias, _ := weights.Get(prefix + ".self_attn.k_proj.bias") + vBias, _ := weights.Get(prefix + ".self_attn.v_proj.bias") + + return &VLTextAttention{ + QProj: mlx.Transpose(qProj, 1, 0), + KProj: mlx.Transpose(kProj, 1, 0), + VProj: mlx.Transpose(vProj, 1, 0), + OProj: mlx.Transpose(oProj, 1, 0), + QBias: qBias, + KBias: kBias, + VBias: vBias, + NHeads: cfg.NumAttentionHeads, + NKVHeads: cfg.NumKeyValueHeads, + HeadDim: cfg.HeadDim, + Scale: float32(1.0 / math.Sqrt(float64(cfg.HeadDim))), + MRoPESection: cfg.MRoPESection, + }, nil +} + +// Forward computes attention +func (attn *VLTextAttention) Forward(x *mlx.Array, cossin [2]*mlx.Array) *mlx.Array { + shape := x.Shape() + B := shape[0] + L := shape[1] + + q := mlx.Linear(x, attn.QProj) + if attn.QBias != nil { + q = mlx.Add(q, attn.QBias) + } + k := mlx.Linear(x, attn.KProj) + if attn.KBias != nil { + k = mlx.Add(k, attn.KBias) + } + v := mlx.Linear(x, attn.VProj) + if attn.VBias != nil { + v = mlx.Add(v, attn.VBias) + } + + q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim) + k = mlx.Reshape(k, B, L, attn.NKVHeads, attn.HeadDim) + v = mlx.Reshape(v, B, L, attn.NKVHeads, attn.HeadDim) + + q = mlx.Transpose(q, 0, 2, 1, 3) + k = mlx.Transpose(k, 0, 2, 1, 3) + v = mlx.Transpose(v, 0, 2, 1, 3) + + // Apply M-RoPE + if cossin[0] != nil && cossin[1] != nil { + q = applyMRoPE(q, cossin[0], cossin[1], attn.MRoPESection) + k = applyMRoPE(k, cossin[0], cossin[1], attn.MRoPESection) + } + + // Repeat KV for GQA + if attn.NKVHeads < attn.NHeads { + repeats := attn.NHeads / attn.NKVHeads + k = repeatKV(k, repeats) + v = repeatKV(v, repeats) + } + + out := mlx.ScaledDotProductAttention(q, k, v, attn.Scale, true) + + out = mlx.Transpose(out, 0, 2, 1, 3) + out = mlx.Reshape(out, B, L, attn.NHeads*attn.HeadDim) + + return mlx.Linear(out, attn.OProj) +} + +// applyMRoPE applies Multi-Resolution RoPE +func applyMRoPE(x *mlx.Array, cos, sin *mlx.Array, section []int32) *mlx.Array { + shape := x.Shape() + B := shape[0] + H := shape[1] + L := shape[2] + D := shape[3] + half := D / 2 + + fullSection := make([]int32, len(section)) + for i, s := range section { + fullSection[i] = s * 2 + } + + var cosParts, sinParts []*mlx.Array + offset := int32(0) + for i, size := range fullSection { + posDim := int32(i % 3) + cosSection := mlx.Slice(cos, []int32{posDim, 0, 0, offset}, []int32{posDim + 1, B, L, offset + size}) + sinSection := mlx.Slice(sin, []int32{posDim, 0, 0, offset}, []int32{posDim + 1, B, L, offset + size}) + cosSection = mlx.Squeeze(cosSection, 0) + sinSection = mlx.Squeeze(sinSection, 0) + cosParts = append(cosParts, cosSection) + sinParts = append(sinParts, sinSection) + offset += size + } + + cosFlat := mlx.Concatenate(cosParts, 2) + sinFlat := mlx.Concatenate(sinParts, 2) + + cosFlat = mlx.Reshape(cosFlat, B, 1, L, D) + sinFlat = mlx.Reshape(sinFlat, B, 1, L, D) + + x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, H, L, half}) + x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, H, L, D}) + negX2 := mlx.MulScalar(x2, -1) + rotatedX := mlx.Concatenate([]*mlx.Array{negX2, x1}, 3) + + return mlx.Add(mlx.Mul(x, cosFlat), mlx.Mul(rotatedX, sinFlat)) +} + +// repeatKV repeats key/value heads for GQA +func repeatKV(x *mlx.Array, repeats int32) *mlx.Array { + if repeats == 1 { + return x + } + shape := x.Shape() + x = mlx.ExpandDims(x, 2) + x = mlx.Tile(x, []int32{1, 1, repeats, 1, 1}) + return mlx.Reshape(x, shape[0], shape[1]*repeats, shape[2], shape[3]) +} + +// VLTextMLP implements Qwen2.5 SwiGLU MLP +type VLTextMLP struct { + GateProj *mlx.Array + UpProj *mlx.Array + DownProj *mlx.Array +} + +// newVLTextMLP creates a text MLP layer +func newVLTextMLP(weights *safetensors.ModelWeights, prefix string) (*VLTextMLP, error) { + gateProj, err := weights.Get(prefix + ".mlp.gate_proj.weight") + if err != nil { + return nil, err + } + upProj, err := weights.Get(prefix + ".mlp.up_proj.weight") + if err != nil { + return nil, err + } + downProj, err := weights.Get(prefix + ".mlp.down_proj.weight") + if err != nil { + return nil, err + } + + return &VLTextMLP{ + GateProj: mlx.Transpose(gateProj, 1, 0), + UpProj: mlx.Transpose(upProj, 1, 0), + DownProj: mlx.Transpose(downProj, 1, 0), + }, nil +} + +// Forward applies the SwiGLU MLP +func (mlp *VLTextMLP) Forward(x *mlx.Array) *mlx.Array { + gate := mlx.Linear(x, mlp.GateProj) + gate = mlx.SiLU(gate) + up := mlx.Linear(x, mlp.UpProj) + h := mlx.Mul(gate, up) + return mlx.Linear(h, mlp.DownProj) +} + +// VisionPatchEmbed embeds image patches +type VisionPatchEmbed struct { + ProjWeight *mlx.Array + ProjBias *mlx.Array + PatchSize int32 +} + +// newVisionPatchEmbed creates a vision patch embed layer +func newVisionPatchEmbed(weights *safetensors.ModelWeights, cfg *Qwen25VLConfig) (*VisionPatchEmbed, error) { + projWeight, err := weights.Get("visual.patch_embed.proj.weight") + if err != nil { + return nil, err + } + projBias, _ := weights.Get("visual.patch_embed.proj.bias") + + return &VisionPatchEmbed{ + ProjWeight: projWeight, + ProjBias: projBias, + PatchSize: cfg.VisionPatchSize, + }, nil +} + +// Forward embeds patches from an image +// image: [B, C, H, W] +// Returns: [B, num_patches, hidden_size] +func (pe *VisionPatchEmbed) Forward(image *mlx.Array) *mlx.Array { + // Qwen2.5-VL uses 3D conv for patch embedding to support video + // Weight shape is [O, I, kT, kH, kW] e.g. [1280, 3, 2, 14, 14] + // For single image, we duplicate the frame to match temporal_patch_size + + wShape := pe.ProjWeight.Shape() + if len(wShape) == 5 { + // 3D convolution case + temporalPatchSize := wShape[2] // kT from weight shape + + // Add temporal dimension: [B, C, H, W] -> [B, C, 1, H, W] + image = mlx.ExpandDims(image, 2) + + // Duplicate frame to match temporal_patch_size (Python does this for single images) + // [B, C, 1, H, W] -> [B, C, T, H, W] where T = temporal_patch_size + if temporalPatchSize > 1 { + image = mlx.Tile(image, []int32{1, 1, temporalPatchSize, 1, 1}) + } + + // Convert to channels-last: [B, C, T, H, W] -> [B, T, H, W, C] + image = mlx.Transpose(image, 0, 2, 3, 4, 1) + + // Weight is [O, I, kT, kH, kW] - keep as-is since patches are now in [I, kT, kH, kW] order + // (extractPatches3DStrided transposes each patch to [C, T, H, W] to match Python) + + // Apply 3D conv using manual patch extraction + // Strides: (temporal_patch_size, patch_size, patch_size) + x := conv3DStrided(image, pe.ProjWeight, temporalPatchSize, pe.PatchSize, pe.PatchSize) + + if pe.ProjBias != nil { + outC := pe.ProjBias.Dim(0) + bias := mlx.Reshape(pe.ProjBias, 1, 1, 1, 1, outC) + x = mlx.Add(x, bias) + } + + // x is [B, T', H', W', C], squeeze T' and flatten spatial + shape := x.Shape() + // T' should be 1 for single image (since we used stride=temporal_patch_size) + x = mlx.Reshape(x, shape[0], shape[2]*shape[3], shape[4]) + + return x + } + + // Original 2D case (fallback) + // Convert to channels-last for Conv2d + image = mlx.Transpose(image, 0, 2, 3, 1) // [B, H, W, C] + + // Apply conv with stride=patch_size using manual strided convolution + weight := mlx.Transpose(pe.ProjWeight, 0, 2, 3, 1) // [O, I, kH, kW] -> [O, kH, kW, I] + x := conv2DStrided(image, weight, pe.PatchSize) + if pe.ProjBias != nil { + bias := mlx.Reshape(pe.ProjBias, 1, 1, 1, pe.ProjBias.Dim(0)) + x = mlx.Add(x, bias) + } + + // Flatten patches: [B, pH, pW, C] -> [B, pH*pW, C] + shape := x.Shape() + x = mlx.Reshape(x, shape[0], shape[1]*shape[2], shape[3]) + + return x +} + +// VisionBlock is a single vision transformer block +type VisionBlock struct { + Norm1 *mlx.Array + Norm2 *mlx.Array + Attention *VisionAttention + MLP *VisionMLP +} + +// newVisionBlock creates a vision block +func newVisionBlock(weights *safetensors.ModelWeights, layerIdx int, cfg *Qwen25VLConfig) (*VisionBlock, error) { + prefix := fmt.Sprintf("visual.blocks.%d", layerIdx) + + norm1, err := weights.Get(prefix + ".norm1.weight") + if err != nil { + return nil, err + } + norm2, err := weights.Get(prefix + ".norm2.weight") + if err != nil { + return nil, err + } + + attention, err := newVisionAttention(weights, prefix, cfg) + if err != nil { + return nil, err + } + + mlpLayer, err := newVisionMLP(weights, prefix, cfg) + if err != nil { + return nil, err + } + + return &VisionBlock{ + Norm1: norm1, + Norm2: norm2, + Attention: attention, + MLP: mlpLayer, + }, nil +} + +// Forward applies the vision block +// posEmb: [2]*mlx.Array containing (cos, sin) for RoPE, can be nil +// cuSeqlens: cumulative sequence lengths for window attention +func (vb *VisionBlock) Forward(x *mlx.Array, posEmb [2]*mlx.Array, cuSeqlens []int32) *mlx.Array { + // Python uses RMSNorm, not LayerNorm! + h := mlx.RMSNormNoWeight(x, 1e-6) + h = mlx.Mul(h, vb.Norm1) + attnOut := vb.Attention.Forward(h, posEmb, cuSeqlens) + x = mlx.Add(x, attnOut) + + h = mlx.RMSNormNoWeight(x, 1e-6) + h = mlx.Mul(h, vb.Norm2) + mlpOut := vb.MLP.Forward(h) + x = mlx.Add(x, mlpOut) + + return x +} + +// VisionAttention implements vision attention +type VisionAttention struct { + QKVProj *mlx.Array + QKVBias *mlx.Array + OutProj *mlx.Array + OutBias *mlx.Array + NHeads int32 + HeadDim int32 + Scale float32 +} + +// newVisionAttention creates a vision attention layer +func newVisionAttention(weights *safetensors.ModelWeights, prefix string, cfg *Qwen25VLConfig) (*VisionAttention, error) { + qkvProj, err := weights.Get(prefix + ".attn.qkv.weight") + if err != nil { + return nil, err + } + qkvBias, _ := weights.Get(prefix + ".attn.qkv.bias") + outProj, err := weights.Get(prefix + ".attn.proj.weight") + if err != nil { + return nil, err + } + outBias, _ := weights.Get(prefix + ".attn.proj.bias") + + headDim := cfg.VisionHiddenSize / cfg.VisionNumHeads + + return &VisionAttention{ + QKVProj: mlx.Transpose(qkvProj, 1, 0), + QKVBias: qkvBias, + OutProj: mlx.Transpose(outProj, 1, 0), + OutBias: outBias, + NHeads: cfg.VisionNumHeads, + HeadDim: headDim, + Scale: float32(1.0 / math.Sqrt(float64(headDim))), + }, nil +} + +// Forward applies vision attention with optional RoPE and window attention +// posEmb: [2]*mlx.Array containing (cos, sin) for RoPE, can be nil +// cuSeqlens: cumulative sequence lengths for window boundaries +func (attn *VisionAttention) Forward(x *mlx.Array, posEmb [2]*mlx.Array, cuSeqlens []int32) *mlx.Array { + shape := x.Shape() + B := shape[0] + L := shape[1] + D := shape[2] + + qkv := mlx.Linear(x, attn.QKVProj) + if attn.QKVBias != nil { + qkv = mlx.Add(qkv, attn.QKVBias) + } + + // Split into Q, K, V + qkv = mlx.Reshape(qkv, B, L, 3, attn.NHeads, attn.HeadDim) + q := mlx.Slice(qkv, []int32{0, 0, 0, 0, 0}, []int32{B, L, 1, attn.NHeads, attn.HeadDim}) + k := mlx.Slice(qkv, []int32{0, 0, 1, 0, 0}, []int32{B, L, 2, attn.NHeads, attn.HeadDim}) + v := mlx.Slice(qkv, []int32{0, 0, 2, 0, 0}, []int32{B, L, 3, attn.NHeads, attn.HeadDim}) + + q = mlx.Squeeze(q, 2) // [B, L, H, D] + k = mlx.Squeeze(k, 2) + v = mlx.Squeeze(v, 2) + + // Apply RoPE if position embeddings provided + if posEmb[0] != nil && posEmb[1] != nil { + q, k = applyVisionRoPE(q, k, posEmb[0], posEmb[1]) + } + + q = mlx.Transpose(q, 0, 2, 1, 3) // [B, H, L, D] + k = mlx.Transpose(k, 0, 2, 1, 3) + v = mlx.Transpose(v, 0, 2, 1, 3) + + var out *mlx.Array + + // Check if we need window attention (more than 1 window) + numWindows := len(cuSeqlens) - 1 + if numWindows <= 1 { + // Full attention - single window covering entire sequence + out = mlx.ScaledDotProductAttention(q, k, v, attn.Scale, false) + } else { + // Window attention - process each window separately + attnOutputs := make([]*mlx.Array, numWindows) + + for w := 0; w < numWindows; w++ { + start := cuSeqlens[w] + end := cuSeqlens[w+1] + + // Slice Q, K, V for this window: [B, H, winLen, D] + qWin := mlx.Slice(q, []int32{0, 0, start, 0}, []int32{B, attn.NHeads, end, attn.HeadDim}) + kWin := mlx.Slice(k, []int32{0, 0, start, 0}, []int32{B, attn.NHeads, end, attn.HeadDim}) + vWin := mlx.Slice(v, []int32{0, 0, start, 0}, []int32{B, attn.NHeads, end, attn.HeadDim}) + + // Compute attention for this window + attnWin := mlx.ScaledDotProductAttention(qWin, kWin, vWin, attn.Scale, false) + attnOutputs[w] = attnWin + } + + // Concatenate all window outputs along sequence dimension + out = mlx.Concatenate(attnOutputs, 2) + } + + out = mlx.Transpose(out, 0, 2, 1, 3) // [B, L, H, D] + out = mlx.Reshape(out, B, L, D) + + out = mlx.Linear(out, attn.OutProj) + if attn.OutBias != nil { + out = mlx.Add(out, attn.OutBias) + } + + return out +} + +// applyVisionRoPE applies rotary position embedding to Q and K for vision +// q, k: [B, L, H, D], cos, sin: [L, D] (already doubled: D = head_dim) +// Returns: rotated q, k with same shape +// Note: Python does this computation in float32 for numerical stability +func applyVisionRoPE(q, k, cos, sin *mlx.Array) (*mlx.Array, *mlx.Array) { + // Convert to float32 for numerical stability (matches Python) + origDtype := q.Dtype() + q = mlx.AsType(q, mlx.DtypeFloat32) + k = mlx.AsType(k, mlx.DtypeFloat32) + cos = mlx.AsType(cos, mlx.DtypeFloat32) + sin = mlx.AsType(sin, mlx.DtypeFloat32) + + // Expand cos/sin to match q/k shape: [L, D] -> [1, L, 1, D] + cos = mlx.ExpandDims(cos, 0) + cos = mlx.ExpandDims(cos, 2) + sin = mlx.ExpandDims(sin, 0) + sin = mlx.ExpandDims(sin, 2) + + // rotate_half: split last dim in half and swap with negation + // q_rot = q * cos + rotate_half(q) * sin + qRotated := rotateHalf(q) + kRotated := rotateHalf(k) + + qOut := mlx.Add(mlx.Mul(q, cos), mlx.Mul(qRotated, sin)) + kOut := mlx.Add(mlx.Mul(k, cos), mlx.Mul(kRotated, sin)) + + // Convert back to original dtype + qOut = mlx.AsType(qOut, origDtype) + kOut = mlx.AsType(kOut, origDtype) + + return qOut, kOut +} + +// rotateHalf rotates the last dimension by splitting in half and swapping with negation +// x: [..., D] -> split to [..., D/2] and [..., D/2], then concat(-x2, x1) +func rotateHalf(x *mlx.Array) *mlx.Array { + shape := x.Shape() + lastDim := shape[len(shape)-1] + halfDim := lastDim / 2 + + // Split into two halves + x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{shape[0], shape[1], shape[2], halfDim}) + x2 := mlx.Slice(x, []int32{0, 0, 0, halfDim}, []int32{shape[0], shape[1], shape[2], lastDim}) + + // Negate x2 and concatenate + x2Neg := mlx.MulScalar(x2, -1.0) + return mlx.Concatenate([]*mlx.Array{x2Neg, x1}, 3) +} + +// VisionMLP implements vision SwiGLU MLP +type VisionMLP struct { + GateProj *mlx.Array + GateProjBias *mlx.Array + UpProj *mlx.Array + UpProjBias *mlx.Array + DownProj *mlx.Array + DownProjBias *mlx.Array +} + +// newVisionMLP creates a vision MLP layer +func newVisionMLP(weights *safetensors.ModelWeights, prefix string, cfg *Qwen25VLConfig) (*VisionMLP, error) { + gateProj, err := weights.Get(prefix + ".mlp.gate_proj.weight") + if err != nil { + return nil, err + } + gateProjBias, _ := weights.Get(prefix + ".mlp.gate_proj.bias") + upProj, err := weights.Get(prefix + ".mlp.up_proj.weight") + if err != nil { + return nil, err + } + upProjBias, _ := weights.Get(prefix + ".mlp.up_proj.bias") + downProj, err := weights.Get(prefix + ".mlp.down_proj.weight") + if err != nil { + return nil, err + } + downProjBias, _ := weights.Get(prefix + ".mlp.down_proj.bias") + + return &VisionMLP{ + GateProj: mlx.Transpose(gateProj, 1, 0), + GateProjBias: gateProjBias, + UpProj: mlx.Transpose(upProj, 1, 0), + UpProjBias: upProjBias, + DownProj: mlx.Transpose(downProj, 1, 0), + DownProjBias: downProjBias, + }, nil +} + +// Forward applies the vision SwiGLU MLP +func (m *VisionMLP) Forward(x *mlx.Array) *mlx.Array { + gate := mlx.Linear(x, m.GateProj) + if m.GateProjBias != nil { + gate = mlx.Add(gate, m.GateProjBias) + } + gate = mlx.SiLU(gate) + + up := mlx.Linear(x, m.UpProj) + if m.UpProjBias != nil { + up = mlx.Add(up, m.UpProjBias) + } + + h := mlx.Mul(gate, up) + h = mlx.Linear(h, m.DownProj) + if m.DownProjBias != nil { + h = mlx.Add(h, m.DownProjBias) + } + return h +} + +// VisionMerger merges spatial patches (2x2 -> 1) +type VisionMerger struct { + MLP0Weight *mlx.Array + MLP0Bias *mlx.Array + MLP2Weight *mlx.Array + MLP2Bias *mlx.Array + LNWeight *mlx.Array +} + +// newVisionMerger creates a vision merger +func newVisionMerger(weights *safetensors.ModelWeights, cfg *Qwen25VLConfig) (*VisionMerger, error) { + mlp0Weight, err := weights.Get("visual.merger.mlp.0.weight") + if err != nil { + return nil, err + } + mlp0Bias, _ := weights.Get("visual.merger.mlp.0.bias") + mlp2Weight, err := weights.Get("visual.merger.mlp.2.weight") + if err != nil { + return nil, err + } + mlp2Bias, _ := weights.Get("visual.merger.mlp.2.bias") + lnWeight, _ := weights.Get("visual.merger.ln_q.weight") + + return &VisionMerger{ + MLP0Weight: mlx.Transpose(mlp0Weight, 1, 0), + MLP0Bias: mlp0Bias, + MLP2Weight: mlx.Transpose(mlp2Weight, 1, 0), + MLP2Bias: mlp2Bias, + LNWeight: lnWeight, + }, nil +} + +// Forward merges 2x2 patches into 1 (assumes square grid - use ForwardWithDims for non-square) +func (m *VisionMerger) Forward(x *mlx.Array) *mlx.Array { + shape := x.Shape() + L := shape[1] + side := int32(math.Sqrt(float64(L))) + return m.ForwardWithDims(x, side, side) +} + +// ForwardWithDims merges 2x2 patches into 1 with explicit grid dimensions +// After window reordering, consecutive 4 patches form a 2x2 block, so we just +// reshape [B, L, D] -> [B, L/4, 4*D] without 2D spatial rearrangement. +func (m *VisionMerger) ForwardWithDims(x *mlx.Array, pH, pW int32) *mlx.Array { + shape := x.Shape() + B := shape[0] + L := shape[1] + D := shape[2] + + // RMSNorm BEFORE merge (applied to each token with D dimensions) + // Python: ln_q = Qwen2RMSNorm(context_dim, eps=1e-6) + if m.LNWeight != nil { + x = mlx.RMSNormNoWeight(x, 1e-6) + x = mlx.Mul(x, m.LNWeight) + } + + // After window reordering, consecutive 4 patches belong to a 2x2 block + // Just reshape to [B, L/4, 4*D] - no spatial rearrangement needed + newL := L / 4 + x = mlx.Reshape(x, B, newL, 4*D) + + // MLP + h := mlx.Linear(x, m.MLP0Weight) + if m.MLP0Bias != nil { + h = mlx.Add(h, m.MLP0Bias) + } + h = mlx.GELU(h) + h = mlx.Linear(h, m.MLP2Weight) + if m.MLP2Bias != nil { + h = mlx.Add(h, m.MLP2Bias) + } + + return h +} + +// LoadQwen25VLFromPath loads the encoder from path +func LoadQwen25VLFromPath(path string) (*Qwen25VL, error) { + m := &Qwen25VL{} + if err := m.Load(filepath.Join(path, "text_encoder")); err != nil { + return nil, err + } + return m, nil +} + +// conv2DStrided applies conv with stride > 1 using manual patch extraction +// x: [B, H, W, C] (channels-last), weight: [O, kH, kW, I] +func conv2DStrided(x, weight *mlx.Array, stride int32) *mlx.Array { + shape := x.Shape() + B := shape[0] + H := shape[1] + W := shape[2] + + wShape := weight.Shape() + Cout := wShape[0] + kH := wShape[1] + kW := wShape[2] + + outH := (H - kH) / stride + 1 + outW := (W - kW) / stride + 1 + + patches := extractPatches2DStrided(x, kH, kW, stride) + wFlat := mlx.Reshape(weight, Cout, -1) + patches = mlx.Reshape(patches, B*outH*outW, -1) + out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0)) + return mlx.Reshape(out, B, outH, outW, Cout) +} + +// conv3DStrided applies 3D conv with strides using manual patch extraction +// x: [B, T, H, W, C] (channels-last), weight: [O, I, kT, kH, kW] (PyTorch format) +// strideT, strideH, strideW are the strides for each dimension +// Patches are extracted in [C, T, H, W] order to match Python's preprocessing +func conv3DStrided(x, weight *mlx.Array, strideT, strideH, strideW int32) *mlx.Array { + shape := x.Shape() + B := shape[0] + T := shape[1] + H := shape[2] + W := shape[3] + C := shape[4] + + wShape := weight.Shape() + Cout := wShape[0] + // I := wShape[1] + kT := wShape[2] + kH := wShape[3] + kW := wShape[4] + + // For temporal: if T < kT, we need to repeat frames temporally + // For single image with T=1 and kT=2, we duplicate the frame to T=kT + // Python Qwen2.5-VL duplicates the frame, not zero-pads + if T < kT { + // Tile along T dimension: [B, T, H, W, C] -> [B, kT, H, W, C] + x = mlx.Tile(x, []int32{1, kT, 1, 1, 1}) + T = kT + } + + outT := (T - kT) / strideT + 1 + outH := (H - kH) / strideH + 1 + outW := (W - kW) / strideW + 1 + + // Extract 3D patches in [C, T, H, W] order to match Python + patches := extractPatches3DStrided(x, kT, kH, kW, strideT, strideH, strideW) + // patches shape: [B, outT, outH, outW, C*kT*kH*kW] + + // Weight is [O, I, kT, kH, kW] - flatten to [O, I*kT*kH*kW] to match patch order [C, T, H, W] + wFlat := mlx.Reshape(weight, Cout, -1) // [Cout, I*kT*kH*kW] + patches = mlx.Reshape(patches, B*outT*outH*outW, C*kT*kH*kW) + out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0)) + return mlx.Reshape(out, B, outT, outH, outW, Cout) +} + +// extractPatches3DStrided extracts 3D patches with given strides +// Returns patches with values in [C, T, H, W] order to match Python's preprocessing +func extractPatches3DStrided(x *mlx.Array, kT, kH, kW, strideT, strideH, strideW int32) *mlx.Array { + shape := x.Shape() + B := shape[0] + T := shape[1] + H := shape[2] + W := shape[3] + C := shape[4] + + outT := (T - kT) / strideT + 1 + outH := (H - kH) / strideH + 1 + outW := (W - kW) / strideW + 1 + + numPatches := outT * outH * outW + patches := make([]*mlx.Array, numPatches) + idx := 0 + for t := int32(0); t < outT; t++ { + for i := int32(0); i < outH; i++ { + for j := int32(0); j < outW; j++ { + startT := t * strideT + startH := i * strideH + startW := j * strideW + // Extract patch: [B, kT, kH, kW, C] + patch := mlx.Slice(x, + []int32{0, startT, startH, startW, 0}, + []int32{B, startT + kT, startH + kH, startW + kW, C}) + // Transpose from [B, T, H, W, C] to [B, C, T, H, W] to match Python's order + patch = mlx.Transpose(patch, 0, 4, 1, 2, 3) + // Flatten to [B, C*T*H*W] + patch = mlx.Reshape(patch, B, C*kT*kH*kW) + patches[idx] = patch + idx++ + } + } + } + + for i := range patches { + patches[i] = mlx.ExpandDims(patches[i], 1) + } + stacked := mlx.Concatenate(patches, 1) + return mlx.Reshape(stacked, B, outT, outH, outW, C*kT*kH*kW) +} + +// extractPatches2DStrided extracts patches with given stride +func extractPatches2DStrided(x *mlx.Array, kH, kW, stride int32) *mlx.Array { + shape := x.Shape() + B := shape[0] + H := shape[1] + W := shape[2] + C := shape[3] + + outH := (H - kH) / stride + 1 + outW := (W - kW) / stride + 1 + + patches := make([]*mlx.Array, outH*outW) + idx := 0 + for i := int32(0); i < outH; i++ { + for j := int32(0); j < outW; j++ { + startH := i * stride + startW := j * stride + patch := mlx.Slice(x, []int32{0, startH, startW, 0}, []int32{B, startH + kH, startW + kW, C}) + patch = mlx.Reshape(patch, B, kH*kW*C) + patches[idx] = patch + idx++ + } + } + + for i := range patches { + patches[i] = mlx.ExpandDims(patches[i], 1) + } + stacked := mlx.Concatenate(patches, 1) + return mlx.Reshape(stacked, B, outH, outW, kH*kW*C) +} diff --git a/x/imagegen/models/qwen_image/qwen_image.go b/x/imagegen/models/qwen_image/qwen_image.go new file mode 100644 index 000000000..97dbb089e --- /dev/null +++ b/x/imagegen/models/qwen_image/qwen_image.go @@ -0,0 +1,350 @@ +//go:build mlx + +// Package qwen_image implements the Qwen-Image diffusion transformer model. +package qwen_image + +import ( + "context" + "fmt" + "path/filepath" + "time" + + "github.com/ollama/ollama/x/imagegen/cache" + "github.com/ollama/ollama/x/imagegen/mlx" + "github.com/ollama/ollama/x/imagegen/tokenizer" +) + +// GenerateConfig holds all options for image generation. +type GenerateConfig struct { + Prompt string + NegativePrompt string // Empty = no CFG + CFGScale float32 // Only used if NegativePrompt is set (default: 4.0) + Width int32 // Image width (default: 1024) + Height int32 // Image height (default: 1024) + Steps int // Denoising steps (default: 30) + Seed int64 // Random seed + Progress ProgressFunc // Optional progress callback + + // Layer caching (DeepCache/Learning-to-Cache speedup) + LayerCache bool // Enable layer caching (default: false) + CacheInterval int // Refresh cache every N steps (default: 3) + CacheLayers int // Number of shallow layers to cache (default: 25) +} + +// ProgressFunc is called during generation with step progress. +type ProgressFunc func(step, totalSteps int) + +// Model represents a Qwen-Image diffusion model. +type Model struct { + ModelPath string + Tokenizer *tokenizer.Tokenizer + TextEncoder *Qwen25VL + Transformer *Transformer + VAEDecoder *VAEDecoder +} + +// Load loads the Qwen-Image model from a directory. +func (m *Model) Load(modelPath string) error { + fmt.Println("Loading Qwen-Image model...") + start := time.Now() + + if mlx.GPUIsAvailable() { + mlx.SetDefaultDeviceGPU() + mlx.EnableCompile() + } + + m.ModelPath = modelPath + + // Load tokenizer + fmt.Print(" Loading tokenizer... ") + tokenizerPath := filepath.Join(modelPath, "tokenizer") + tok, err := tokenizer.Load(tokenizerPath) + if err != nil { + return fmt.Errorf("tokenizer: %w", err) + } + m.Tokenizer = tok + fmt.Println("✓") + + // Load text encoder (Qwen2.5-VL in text-only mode - skip vision tower for efficiency) + m.TextEncoder = &Qwen25VL{} + if err := m.TextEncoder.LoadTextOnly(filepath.Join(modelPath, "text_encoder")); err != nil { + return fmt.Errorf("text encoder: %w", err) + } + mlx.Eval(mlx.Collect(m.TextEncoder)...) + fmt.Printf(" (%.1f GB, peak %.1f GB)\n", + float64(mlx.MetalGetActiveMemory())/(1024*1024*1024), + float64(mlx.MetalGetPeakMemory())/(1024*1024*1024)) + + // Load transformer + m.Transformer = &Transformer{} + if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil { + return fmt.Errorf("transformer: %w", err) + } + mlx.Eval(mlx.Collect(m.Transformer)...) + fmt.Printf(" (%.1f GB, peak %.1f GB)\n", + float64(mlx.MetalGetActiveMemory())/(1024*1024*1024), + float64(mlx.MetalGetPeakMemory())/(1024*1024*1024)) + + // Load VAE decoder + m.VAEDecoder = &VAEDecoder{} + if err := m.VAEDecoder.Load(filepath.Join(modelPath, "vae")); err != nil { + return fmt.Errorf("VAE decoder: %w", err) + } + mlx.Eval(mlx.Collect(m.VAEDecoder)...) + fmt.Printf(" (%.1f GB, peak %.1f GB)\n", + float64(mlx.MetalGetActiveMemory())/(1024*1024*1024), + float64(mlx.MetalGetPeakMemory())/(1024*1024*1024)) + + mem := mlx.MetalGetActiveMemory() + peak := mlx.MetalGetPeakMemory() + fmt.Printf(" Loaded in %.2fs (%.1f GB active, %.1f GB peak)\n", + time.Since(start).Seconds(), + float64(mem)/(1024*1024*1024), + float64(peak)/(1024*1024*1024)) + + return nil +} + +// Generate creates an image from a prompt. +func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) { + return m.GenerateFromConfig(&GenerateConfig{ + Prompt: prompt, + Width: width, + Height: height, + Steps: steps, + Seed: seed, + }) +} + +// GenerateWithProgress creates an image with progress callback. +func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) { + return m.GenerateFromConfig(&GenerateConfig{ + Prompt: prompt, + Width: width, + Height: height, + Steps: steps, + Seed: seed, + Progress: progress, + }) +} + +// GenerateWithCFG creates an image with classifier-free guidance. +func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress ProgressFunc) (*mlx.Array, error) { + return m.GenerateFromConfig(&GenerateConfig{ + Prompt: prompt, + NegativePrompt: negativePrompt, + CFGScale: cfgScale, + Width: width, + Height: height, + Steps: steps, + Seed: seed, + Progress: progress, + }) +} + +// GenerateFromConfig generates an image using the unified config struct. +func (m *Model) GenerateFromConfig(cfg *GenerateConfig) (*mlx.Array, error) { + start := time.Now() + result, err := m.generate(cfg) + if err != nil { + return nil, err + } + if cfg.NegativePrompt != "" { + fmt.Printf("Generated with CFG (scale=%.1f) in %.2fs (%d steps)\n", cfg.CFGScale, time.Since(start).Seconds(), cfg.Steps) + } else { + fmt.Printf("Generated in %.2fs (%d steps)\n", time.Since(start).Seconds(), cfg.Steps) + } + return result, nil +} + +// GenerateImage implements model.ImageModel interface. +func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) { + return m.Generate(prompt, width, height, steps, seed) +} + +// generate is the internal denoising pipeline. +func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) { + // Apply defaults + if cfg.Width <= 0 { + cfg.Width = 1024 + } + if cfg.Height <= 0 { + cfg.Height = 1024 + } + if cfg.Steps <= 0 { + cfg.Steps = 30 + } + if cfg.CFGScale <= 0 { + cfg.CFGScale = 4.0 + } + if cfg.CacheInterval <= 0 { + cfg.CacheInterval = 3 + } + if cfg.CacheLayers <= 0 { + cfg.CacheLayers = 25 // ~42% of 60 layers (similar ratio to Z-Image's 15/38) + } + + useCFG := cfg.NegativePrompt != "" + tcfg := m.Transformer.Config + latentH := cfg.Height / 8 + latentW := cfg.Width / 8 + pH := latentH / tcfg.PatchSize + pW := latentW / tcfg.PatchSize + imgSeqLen := pH * pW + + // Text encoding + var posEmb, negEmb *mlx.Array + { + posEmb = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt) + if useCFG { + negEmb = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.NegativePrompt) + mlx.Keep(posEmb, negEmb) + mlx.Eval(posEmb, negEmb) + } else { + mlx.Keep(posEmb) + mlx.Eval(posEmb) + } + } + + // Pad sequences to same length for CFG + txtLen := posEmb.Shape()[1] + if useCFG { + negLen := negEmb.Shape()[1] + if negLen > txtLen { + txtLen = negLen + } + if posEmb.Shape()[1] < txtLen { + posEmb = padSequence(posEmb, txtLen) + } + if negEmb.Shape()[1] < txtLen { + negEmb = padSequence(negEmb, txtLen) + } + mlx.Keep(posEmb, negEmb) + } + + // Scheduler + scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig()) + scheduler.SetTimesteps(cfg.Steps, imgSeqLen) + + // Init latents [B, C, T, H, W] + var latents *mlx.Array + { + latents = scheduler.InitNoise([]int32{1, tcfg.OutChannels, 1, latentH, latentW}, cfg.Seed) + mlx.Eval(latents) + } + + // RoPE cache + var ropeCache *RoPECache + { + ropeCache = PrepareRoPE(pH, pW, txtLen, tcfg.AxesDimsRope) + mlx.Keep(ropeCache.ImgFreqs, ropeCache.TxtFreqs) + mlx.Eval(ropeCache.ImgFreqs) + } + + // Layer cache for DeepCache/Learning-to-Cache speedup + var stepCache *cache.StepCache + if cfg.LayerCache { + stepCache = cache.NewStepCache(cfg.CacheLayers) + fmt.Printf(" Layer caching: %d layers, refresh every %d steps\n", cfg.CacheLayers, cfg.CacheInterval) + } + + // Denoising loop + for i := 0; i < cfg.Steps; i++ { + stepStart := time.Now() + if cfg.Progress != nil { + cfg.Progress(i+1, cfg.Steps) + } + + t := scheduler.Timesteps[i] + timestep := mlx.ToBFloat16(mlx.NewArray([]float32{t}, []int32{1})) + + // Squeeze temporal dim: [B, C, T, H, W] -> [B, C, H, W] + latents2D := mlx.Squeeze(latents, 2) + patches := PackLatents(latents2D, tcfg.PatchSize) + + var output *mlx.Array + if useCFG { + // True CFG: run twice and combine with norm rescaling + // Note: layer caching with CFG is not supported yet (would need 2 caches) + posOutput := m.Transformer.Forward(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs) + negOutput := m.Transformer.Forward(patches, negEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs) + + diff := mlx.Sub(posOutput, negOutput) + scaledDiff := mlx.MulScalar(diff, cfg.CFGScale) + combPred := mlx.Add(negOutput, scaledDiff) + + // Norm rescaling: rescale combined prediction to match conditional prediction's norm + condNorm := mlx.Sqrt(mlx.Sum(mlx.Square(posOutput), -1, true)) + combNorm := mlx.Sqrt(mlx.Sum(mlx.Square(combPred), -1, true)) + output = mlx.Mul(combPred, mlx.Div(condNorm, combNorm)) + } else if stepCache != nil { + output = m.Transformer.ForwardWithCache(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs, + stepCache, i, cfg.CacheInterval, cfg.CacheLayers) + } else { + output = m.Transformer.Forward(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs) + } + + noisePred := UnpackLatents(output, latentH, latentW, tcfg.PatchSize) + oldLatents := latents + latents = scheduler.Step(noisePred, latents, i) + + // Keep cached arrays alive across cleanup + if stepCache != nil { + mlx.Keep(stepCache.Arrays()...) + } + mlx.Eval(latents) + oldLatents.Free() + + activeMem := float64(mlx.MetalGetActiveMemory()) / (1024 * 1024 * 1024) + peakMem := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024) + fmt.Printf(" Step %d/%d: t=%.4f (%.2fs) [%.1f GB active, %.1f GB peak]\n", i+1, cfg.Steps, t, time.Since(stepStart).Seconds(), activeMem, peakMem) + } + + // Free denoising temporaries before VAE decode + posEmb.Free() + if negEmb != nil { + negEmb.Free() + } + ropeCache.ImgFreqs.Free() + ropeCache.TxtFreqs.Free() + if stepCache != nil { + stepCache.Free() + } + + // VAE decode (Decode manages its own pools for staged memory) + decoded := m.VAEDecoder.Decode(latents) + latents.Free() + // Post-process: squeeze temporal dim and rescale to [0, 1] + { + decoded = mlx.Squeeze(decoded, 2) + decoded = mlx.AddScalar(decoded, 1.0) + decoded = mlx.DivScalar(decoded, 2.0) + mlx.Eval(decoded) + } + + fmt.Printf(" Peak memory: %.2f GB\n", float64(mlx.MetalGetPeakMemory())/(1024*1024*1024)) + + return decoded, nil +} + +// padSequence pads a sequence tensor to the target length with zeros +func padSequence(x *mlx.Array, targetLen int32) *mlx.Array { + shape := x.Shape() + currentLen := shape[1] + if currentLen >= targetLen { + return x + } + padLen := targetLen - currentLen + // Pad on sequence dimension (axis 1) + return mlx.Pad(x, []int32{0, 0, 0, padLen, 0, 0}) +} + +// LoadPersistent is an alias for backward compatibility. +// Use m := &Model{}; m.Load(path) instead. +func LoadPersistent(modelPath string) (*Model, error) { + m := &Model{} + if err := m.Load(modelPath); err != nil { + return nil, err + } + return m, nil +} diff --git a/x/imagegen/models/qwen_image/scheduler.go b/x/imagegen/models/qwen_image/scheduler.go new file mode 100644 index 000000000..d1f0da049 --- /dev/null +++ b/x/imagegen/models/qwen_image/scheduler.go @@ -0,0 +1,218 @@ +//go:build mlx + +package qwen_image + +import ( + "math" + + "github.com/ollama/ollama/x/imagegen/mlx" +) + +// SchedulerConfig holds FlowMatchEulerDiscreteScheduler configuration +type SchedulerConfig struct { + NumTrainTimesteps int32 `json:"num_train_timesteps"` // 1000 + BaseShift float32 `json:"base_shift"` // 0.5 + MaxShift float32 `json:"max_shift"` // 0.9 + BaseImageSeqLen int32 `json:"base_image_seq_len"` // 256 + MaxImageSeqLen int32 `json:"max_image_seq_len"` // 8192 + ShiftTerminal float32 `json:"shift_terminal"` // 0.02 + UseDynamicShift bool `json:"use_dynamic_shifting"` // true +} + +// DefaultSchedulerConfig returns config for FlowMatchEulerDiscreteScheduler +func DefaultSchedulerConfig() *SchedulerConfig { + return &SchedulerConfig{ + NumTrainTimesteps: 1000, + BaseShift: 0.5, + MaxShift: 0.9, // Matches scheduler_config.json + BaseImageSeqLen: 256, + MaxImageSeqLen: 8192, + ShiftTerminal: 0.02, + UseDynamicShift: true, + } +} + +// FlowMatchScheduler implements the Flow Match Euler discrete scheduler +type FlowMatchScheduler struct { + Config *SchedulerConfig + Timesteps []float32 + Sigmas []float32 + NumSteps int +} + +// NewFlowMatchScheduler creates a new scheduler +func NewFlowMatchScheduler(cfg *SchedulerConfig) *FlowMatchScheduler { + return &FlowMatchScheduler{ + Config: cfg, + } +} + +// CalculateShift computes the dynamic shift based on image sequence length +// This matches Python's calculate_shift function +func CalculateShift(imageSeqLen int32, baseSeqLen int32, maxSeqLen int32, baseShift float32, maxShift float32) float32 { + m := (maxShift - baseShift) / float32(maxSeqLen-baseSeqLen) + b := baseShift - m*float32(baseSeqLen) + mu := float32(imageSeqLen)*m + b + return mu +} + +// SetTimesteps sets up the scheduler for the given number of inference steps +// Matches Python diffusers FlowMatchEulerDiscreteScheduler behavior: +// 1. Create sigmas from sigma_max to sigma_min (linspace) +// 2. Apply time_shift with mu (if dynamic shifting) +// 3. Apply stretch_shift_to_terminal to make final value = shift_terminal +func (s *FlowMatchScheduler) SetTimesteps(numSteps int, imageSeqLen int32) { + s.NumSteps = numSteps + + // Calculate mu for dynamic shifting + var mu float32 + if s.Config.UseDynamicShift { + mu = CalculateShift( + imageSeqLen, + s.Config.BaseImageSeqLen, + s.Config.MaxImageSeqLen, + s.Config.BaseShift, + s.Config.MaxShift, + ) + } + + // Step 1: Create sigmas from 1.0 to 1/num_steps + // Python (pipeline_qwenimage.py:639): + // sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + // This gives sigmas from 1.0 to 1/30 = 0.033 for 30 steps + sigmas := make([]float32, numSteps) + sigmaMax := float32(1.0) + sigmaMin := 1.0 / float32(numSteps) // 1/30 = 0.033 for 30 steps + if numSteps == 1 { + sigmas[0] = sigmaMax + } else { + for i := 0; i < numSteps; i++ { + sigmas[i] = sigmaMax + float32(i)*(sigmaMin-sigmaMax)/float32(numSteps-1) + } + } + + // Step 2: Apply time shift if using dynamic shifting + if s.Config.UseDynamicShift && mu != 0 { + for i := range sigmas { + sigmas[i] = s.timeShift(mu, sigmas[i]) + } + } + + // Step 3: Apply stretch_shift_to_terminal + if s.Config.ShiftTerminal > 0 { + sigmas = s.stretchShiftToTerminal(sigmas) + } + + // Step 4: Append terminal sigma (0) and store + // Note: Python's scheduler.timesteps are sigmas*1000, but the pipeline divides by 1000 + // before passing to transformer. We skip both steps and just use sigmas directly. + s.Sigmas = make([]float32, numSteps+1) + s.Timesteps = make([]float32, numSteps+1) + for i := 0; i < numSteps; i++ { + s.Sigmas[i] = sigmas[i] + s.Timesteps[i] = sigmas[i] + } + s.Sigmas[numSteps] = 0.0 + s.Timesteps[numSteps] = 0.0 +} + +// stretchShiftToTerminal stretches and shifts the timestep schedule +// so the final value equals shift_terminal (matches Python behavior) +func (s *FlowMatchScheduler) stretchShiftToTerminal(sigmas []float32) []float32 { + if len(sigmas) == 0 { + return sigmas + } + + // one_minus_z = 1 - t + // scale_factor = one_minus_z[-1] / (1 - shift_terminal) + // stretched_t = 1 - (one_minus_z / scale_factor) + lastSigma := sigmas[len(sigmas)-1] + scaleFactor := (1.0 - lastSigma) / (1.0 - s.Config.ShiftTerminal) + + // Handle edge case: if scaleFactor is 0 or near 0, skip stretch + // This happens when lastSigma ≈ 1.0 (e.g., single step with timeshift) + if scaleFactor < 1e-6 { + return sigmas + } + + result := make([]float32, len(sigmas)) + for i, t := range sigmas { + oneMinusZ := 1.0 - t + result[i] = 1.0 - (oneMinusZ / scaleFactor) + } + return result +} + +// timeShift applies the dynamic time shift (exponential) +// exp(mu) / (exp(mu) + (1/t - 1)) +func (s *FlowMatchScheduler) timeShift(mu float32, t float32) float32 { + if t <= 0 { + return 0 + } + expMu := float32(math.Exp(float64(mu))) + return expMu / (expMu + (1.0/t - 1.0)) +} + +// Step performs one denoising step +// modelOutput: predicted velocity from the transformer +// sample: current noisy sample +// timestepIdx: current timestep index +func (s *FlowMatchScheduler) Step(modelOutput, sample *mlx.Array, timestepIdx int) *mlx.Array { + // Get current and next sigma + sigma := s.Sigmas[timestepIdx] + sigmaNext := s.Sigmas[timestepIdx+1] + + // Euler step: x_{t-dt} = x_t + (sigma_next - sigma) * v_t + dt := sigmaNext - sigma + + // Upcast to float32 to avoid precision issues (matches Python diffusers) + sampleF32 := mlx.AsType(sample, mlx.DtypeFloat32) + modelOutputF32 := mlx.AsType(modelOutput, mlx.DtypeFloat32) + + scaledOutput := mlx.MulScalar(modelOutputF32, dt) + result := mlx.Add(sampleF32, scaledOutput) + + // Cast back to original dtype + return mlx.ToBFloat16(result) +} + +// GetTimestep returns the timestep value at the given index +func (s *FlowMatchScheduler) GetTimestep(idx int) float32 { + if idx < len(s.Timesteps) { + return s.Timesteps[idx] + } + return 0.0 +} + +// InitNoise creates initial noise for sampling in unpacked format [B, C, T, H, W] +func (s *FlowMatchScheduler) InitNoise(shape []int32, seed int64) *mlx.Array { + return mlx.RandomNormal(shape, uint64(seed)) +} + +// InitNoisePacked creates initial noise directly in packed format [B, L, C*4] +// This matches how Python diffusers generates noise - directly in packed space. +// Generating in unpacked format and then packing produces different spatial +// correlation structure, which affects model output quality. +func (s *FlowMatchScheduler) InitNoisePacked(batchSize, seqLen, channels int32, seed int64) *mlx.Array { + shape := []int32{batchSize, seqLen, channels} + return mlx.RandomNormal(shape, uint64(seed)) +} + +// GetLatentShape returns the latent shape for a given image size +// For qwen_image: VAE downscale is 8x (spatial), latent has 16 channels +func GetLatentShape(batchSize, height, width int32) []int32 { + latentH := height / 8 + latentW := width / 8 + return []int32{batchSize, 16, 1, latentH, latentW} // [B, C, T, H, W] +} + +// GetPatchedLatentShape returns the patchified latent shape +// After patchification: [B, L, C*patch_size^2] where L = H/2 * W/2 +func GetPatchedLatentShape(batchSize, height, width, patchSize int32) []int32 { + latentH := height / 8 + latentW := width / 8 + pH := latentH / patchSize + pW := latentW / patchSize + inChannels := int32(64) // 16 * patch_size^2 + return []int32{batchSize, pH * pW, inChannels} +} diff --git a/x/imagegen/models/qwen_image/scheduler_test.go b/x/imagegen/models/qwen_image/scheduler_test.go new file mode 100644 index 000000000..46adeb99a --- /dev/null +++ b/x/imagegen/models/qwen_image/scheduler_test.go @@ -0,0 +1,135 @@ +//go:build mlx + +package qwen_image + +import ( + "math" + "testing" +) + +// TestSchedulerSetTimesteps verifies scheduler sigmas match Python diffusers reference. +// Golden values generated via: +// +// python3 -c " +// from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +// import numpy as np +// s = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, base_shift=0.5, max_shift=0.9, +// base_image_seq_len=256, max_image_seq_len=8192, shift_terminal=0.02, use_dynamic_shifting=True) +// mu = 4096 * (0.9-0.5)/(8192-256) + 0.5 - (0.9-0.5)/(8192-256)*256 +// sigmas = np.linspace(1.0, 1.0/30, 30) +// s.set_timesteps(sigmas=sigmas, mu=mu) +// print(s.sigmas.numpy())" +func TestSchedulerSetTimesteps(t *testing.T) { + cfg := DefaultSchedulerConfig() + scheduler := NewFlowMatchScheduler(cfg) + scheduler.SetTimesteps(30, 4096) + + // Golden values from Python diffusers (first 3, last 3 before terminal) + wantFirst := []float32{1.000000, 0.982251, 0.963889} + wantLast := []float32{0.142924, 0.083384, 0.020000} + + // Check first 3 + for i, want := range wantFirst { + got := scheduler.Sigmas[i] + if abs32(got-want) > 1e-4 { + t.Errorf("sigma[%d]: got %v, want %v", i, got, want) + } + } + + // Check last 3 (indices 27, 28, 29) + for i, want := range wantLast { + idx := 27 + i + got := scheduler.Sigmas[idx] + if abs32(got-want) > 1e-4 { + t.Errorf("sigma[%d]: got %v, want %v", idx, got, want) + } + } + + // Check terminal is 0 + if scheduler.Sigmas[30] != 0.0 { + t.Errorf("terminal sigma: got %v, want 0", scheduler.Sigmas[30]) + } + + // Check length + if len(scheduler.Sigmas) != 31 { + t.Errorf("sigmas length: got %d, want 31", len(scheduler.Sigmas)) + } +} + +// TestSchedulerProperties tests mathematical invariants of the scheduler. +func TestSchedulerProperties(t *testing.T) { + cfg := DefaultSchedulerConfig() + scheduler := NewFlowMatchScheduler(cfg) + scheduler.SetTimesteps(30, 4096) + + // Property: sigmas monotonically decreasing + for i := 1; i < len(scheduler.Sigmas); i++ { + if scheduler.Sigmas[i] > scheduler.Sigmas[i-1] { + t.Errorf("sigmas not monotonically decreasing at %d: %v > %v", + i, scheduler.Sigmas[i], scheduler.Sigmas[i-1]) + } + } + + // Property: first sigma should be ~1.0 (with time shift) + if scheduler.Sigmas[0] < 0.9 || scheduler.Sigmas[0] > 1.01 { + t.Errorf("first sigma out of expected range [0.9, 1.01]: %v", scheduler.Sigmas[0]) + } + + // Property: terminal sigma should be exactly 0 + if scheduler.Sigmas[len(scheduler.Sigmas)-1] != 0.0 { + t.Errorf("terminal sigma should be 0, got %v", scheduler.Sigmas[len(scheduler.Sigmas)-1]) + } + + // Property: last non-terminal sigma should be shift_terminal (0.02) + lastNonTerminal := scheduler.Sigmas[len(scheduler.Sigmas)-2] + if abs32(lastNonTerminal-0.02) > 1e-5 { + t.Errorf("last non-terminal sigma should be 0.02, got %v", lastNonTerminal) + } + + // Property: length = steps + 1 + if len(scheduler.Sigmas) != scheduler.NumSteps+1 { + t.Errorf("sigmas length should be steps+1: got %d, want %d", + len(scheduler.Sigmas), scheduler.NumSteps+1) + } +} + +// TestCalculateShift verifies the mu calculation against Python reference. +// Golden values from: mu = img_seq_len * m + b where m = (max_shift - base_shift) / (max_seq_len - base_seq_len) +func TestCalculateShift(t *testing.T) { + cases := []struct { + imgSeqLen int32 + want float32 + }{ + {256, 0.5}, // base case + {8192, 0.9}, // max case + {4096, 0.6935}, // middle case (rounded) + } + + for _, c := range cases { + got := CalculateShift(c.imgSeqLen, 256, 8192, 0.5, 0.9) + if abs32(got-c.want) > 0.001 { + t.Errorf("CalculateShift(%d): got %v, want %v", c.imgSeqLen, got, c.want) + } + } +} + +// TestSchedulerStep verifies the Euler step formula. +func TestSchedulerStep(t *testing.T) { + cfg := DefaultSchedulerConfig() + scheduler := NewFlowMatchScheduler(cfg) + scheduler.SetTimesteps(30, 4096) + + // Verify dt calculation for first step + sigma0 := scheduler.Sigmas[0] + sigma1 := scheduler.Sigmas[1] + expectedDt := sigma1 - sigma0 + + // dt should be negative (sigmas decrease) + if expectedDt >= 0 { + t.Errorf("expected negative dt, got %v (sigma0=%v, sigma1=%v)", expectedDt, sigma0, sigma1) + } +} + +func abs32(x float32) float32 { + return float32(math.Abs(float64(x))) +} diff --git a/x/imagegen/models/qwen_image/text_encoder_test.go b/x/imagegen/models/qwen_image/text_encoder_test.go new file mode 100644 index 000000000..7704513c8 --- /dev/null +++ b/x/imagegen/models/qwen_image/text_encoder_test.go @@ -0,0 +1,174 @@ +//go:build mlx + +package qwen_image + +import ( + "encoding/json" + "math" + "os" + "path/filepath" + "slices" + "testing" + + "github.com/ollama/ollama/x/imagegen/mlx" + "github.com/ollama/ollama/x/imagegen/safetensors" +) + +// TinyTextEncoderConfig holds config for the tiny test text encoder +type TinyTextEncoderConfig struct { + HiddenSize int32 `json:"hidden_size"` + NumHiddenLayers int32 `json:"num_hidden_layers"` + IntermediateSize int32 `json:"intermediate_size"` + NumAttentionHeads int32 `json:"num_attention_heads"` + NumKeyValueHeads int32 `json:"num_key_value_heads"` + VocabSize int32 `json:"vocab_size"` + RMSNormEps float32 `json:"rms_norm_eps"` + RopeTheta float32 `json:"rope_theta"` + HeadDim int32 `json:"head_dim"` + MRoPESection []int32 `json:"mrope_section"` +} + +// loadTinyTextEncoder loads the tiny text encoder from testdata +func loadTinyTextEncoder(t *testing.T) (*Qwen25VL, *TinyTextEncoderConfig) { + t.Helper() + + testdataDir := filepath.Join("testdata", "tiny_text_encoder") + + // Load config + configData, err := os.ReadFile(filepath.Join(testdataDir, "config.json")) + if err != nil { + t.Skipf("Skipping: tiny weights not found. Regenerate with Python (see models/CLAUDE.md)") + } + + var tinyCfg TinyTextEncoderConfig + if err := json.Unmarshal(configData, &tinyCfg); err != nil { + t.Fatalf("Failed to parse config: %v", err) + } + + // Create encoder config (using Qwen25VLConfig) + cfg := &Qwen25VLConfig{ + HiddenSize: tinyCfg.HiddenSize, + NumHiddenLayers: tinyCfg.NumHiddenLayers, + IntermediateSize: tinyCfg.IntermediateSize, + NumAttentionHeads: tinyCfg.NumAttentionHeads, + NumKeyValueHeads: tinyCfg.NumKeyValueHeads, + VocabSize: tinyCfg.VocabSize, + RMSNormEps: tinyCfg.RMSNormEps, + RopeTheta: tinyCfg.RopeTheta, + HeadDim: tinyCfg.HeadDim, + MRoPESection: tinyCfg.MRoPESection, + } + + // Load weights + weights, err := safetensors.LoadModelWeights(testdataDir) + if err != nil { + t.Fatalf("Failed to load weights: %v", err) + } + + if err := weights.Load(mlx.DtypeBFloat16); err != nil { + t.Fatalf("Failed to bulk load weights: %v", err) + } + + // Build encoder + embedding, err := weights.Get("model.embed_tokens.weight") + if err != nil { + t.Fatalf("Failed to get embedding: %v", err) + } + + blocks := make([]*VLTextBlock, cfg.NumHiddenLayers) + for i := int32(0); i < cfg.NumHiddenLayers; i++ { + block, err := newVLTextBlock(weights, int(i), cfg) + if err != nil { + t.Fatalf("Failed to load block %d: %v", i, err) + } + blocks[i] = block + } + + finalNorm, err := weights.Get("model.norm.weight") + if err != nil { + t.Fatalf("Failed to get final norm: %v", err) + } + + encoder := &Qwen25VL{ + Config: cfg, + Embedding: embedding, + Blocks: blocks, + FinalNorm: finalNorm, + HasVision: false, // Text-only mode + } + + return encoder, &tinyCfg +} + +// TestTextEncoderForward verifies the text encoder forward pass with tiny weights. +func TestTextEncoderForward(t *testing.T) { + encoder, cfg := loadTinyTextEncoder(t) + + // Create test tokens (within vocab range) + tokens := []int32{1, 2, 3, 4, 5} + + // Forward pass using EncodeTextOnly + out := encoder.EncodeTextOnly(tokens) + mlx.Eval(out) + + // Verify output shape: [batch, seq_len, hidden_size] + wantShape := []int32{1, 5, cfg.HiddenSize} + if !slices.Equal(out.Shape(), wantShape) { + t.Errorf("output shape: got %v, want %v", out.Shape(), wantShape) + } + + // Verify output is finite (not NaN or Inf) + data := out.Data() + for i, v := range data { + if math.IsNaN(float64(v)) || math.IsInf(float64(v), 0) { + t.Errorf("output[%d] is not finite: %v", i, v) + break + } + } +} + +// TestTextEncoderBatch tests batch processing. +func TestTextEncoderBatch(t *testing.T) { + encoder, cfg := loadTinyTextEncoder(t) + + // For batch test, we'll use EncodeTextOnly with a single sequence + // (EncodeTextOnly doesn't support batch, but we can verify single sequence works) + tokens := []int32{1, 2, 3} + + out := encoder.EncodeTextOnly(tokens) + mlx.Eval(out) + + wantShape := []int32{1, 3, cfg.HiddenSize} + if !slices.Equal(out.Shape(), wantShape) { + t.Errorf("shape: got %v, want %v", out.Shape(), wantShape) + } +} + +// TestMRoPEComputation verifies M-RoPE frequency computation produces valid values. +func TestMRoPEComputation(t *testing.T) { + encoder, cfg := loadTinyTextEncoder(t) + + cossin := encoder.computeTextRoPE(10, 1) + mlx.Eval(cossin[0], cossin[1]) + + // Verify shapes: [3, B, L, head_dim] + wantShape := []int32{3, 1, 10, cfg.HeadDim} + if !slices.Equal(cossin[0].Shape(), wantShape) { + t.Errorf("cos shape: got %v, want %v", cossin[0].Shape(), wantShape) + } + if !slices.Equal(cossin[1].Shape(), wantShape) { + t.Errorf("sin shape: got %v, want %v", cossin[1].Shape(), wantShape) + } + + // Verify cos/sin values are in valid range [-1, 1] + cosData := cossin[0].Data() + sinData := cossin[1].Data() + for i := 0; i < min(100, len(cosData)); i++ { + if cosData[i] < -1.01 || cosData[i] > 1.01 { + t.Errorf("cos[%d] out of range: %v", i, cosData[i]) + } + if sinData[i] < -1.01 || sinData[i] > 1.01 { + t.Errorf("sin[%d] out of range: %v", i, sinData[i]) + } + } +} diff --git a/x/imagegen/models/qwen_image/transformer.go b/x/imagegen/models/qwen_image/transformer.go new file mode 100644 index 000000000..06e677619 --- /dev/null +++ b/x/imagegen/models/qwen_image/transformer.go @@ -0,0 +1,868 @@ +//go:build mlx + +package qwen_image + +import ( + "fmt" + "math" + "path/filepath" + + "github.com/ollama/ollama/x/imagegen/cache" + "github.com/ollama/ollama/x/imagegen/mlx" + "github.com/ollama/ollama/x/imagegen/safetensors" +) + +// TransformerConfig holds Qwen-Image transformer configuration +type TransformerConfig struct { + HiddenDim int32 `json:"hidden_dim"` // 3072 (24 * 128) + NHeads int32 `json:"num_attention_heads"` // 24 + HeadDim int32 `json:"attention_head_dim"` // 128 + NLayers int32 `json:"num_layers"` // 60 + InChannels int32 `json:"in_channels"` // 64 + OutChannels int32 `json:"out_channels"` // 16 + PatchSize int32 `json:"patch_size"` // 2 + JointAttentionDim int32 `json:"joint_attention_dim"` // 3584 (text encoder dim) + NormEps float32 `json:"norm_eps"` // 1e-6 + AxesDimsRope []int32 `json:"axes_dims_rope"` // [16, 56, 56] + GuidanceEmbeds bool `json:"guidance_embeds"` // false +} + +// defaultTransformerConfig returns config for Qwen-Image transformer +func defaultTransformerConfig() *TransformerConfig { + return &TransformerConfig{ + HiddenDim: 3072, // 24 * 128 + NHeads: 24, + HeadDim: 128, + NLayers: 60, + InChannels: 64, + OutChannels: 16, + PatchSize: 2, + JointAttentionDim: 3584, + NormEps: 1e-6, + AxesDimsRope: []int32{16, 56, 56}, + GuidanceEmbeds: false, + } +} + +// TimestepEmbedder creates timestep embeddings +type TimestepEmbedder struct { + Linear1Weight *mlx.Array // [256, hidden_dim] + Linear1Bias *mlx.Array + Linear2Weight *mlx.Array // [hidden_dim, hidden_dim] + Linear2Bias *mlx.Array +} + +// newTimestepEmbedder creates a timestep embedder from weights +func newTimestepEmbedder(weights *safetensors.ModelWeights) (*TimestepEmbedder, error) { + linear1Weight, err := weights.Get("time_text_embed.timestep_embedder.linear_1.weight") + if err != nil { + return nil, err + } + linear1Bias, err := weights.Get("time_text_embed.timestep_embedder.linear_1.bias") + if err != nil { + return nil, err + } + linear2Weight, err := weights.Get("time_text_embed.timestep_embedder.linear_2.weight") + if err != nil { + return nil, err + } + linear2Bias, err := weights.Get("time_text_embed.timestep_embedder.linear_2.bias") + if err != nil { + return nil, err + } + + return &TimestepEmbedder{ + Linear1Weight: mlx.Transpose(linear1Weight, 1, 0), + Linear1Bias: linear1Bias, + Linear2Weight: mlx.Transpose(linear2Weight, 1, 0), + Linear2Bias: linear2Bias, + }, nil +} + +// Forward computes timestep embeddings +// t: [B] timesteps (normalized 0-1, will be scaled by 1000 internally) +func (te *TimestepEmbedder) Forward(t *mlx.Array) *mlx.Array { + half := int32(128) // embedding_dim / 2 + + // Sinusoidal embedding with flip_sin_to_cos=True, scale=1000 + freqs := make([]float32, half) + for i := int32(0); i < half; i++ { + freqs[i] = float32(math.Exp(-math.Log(10000.0) * float64(i) / float64(half))) + } + freqsArr := mlx.NewArray(freqs, []int32{1, half}) + + tExpanded := mlx.ExpandDims(t, 1) + args := mlx.Mul(tExpanded, freqsArr) + args = mlx.MulScalar(args, 1000.0) // scale + + // [cos, sin] (flip_sin_to_cos=True) + sinArgs := mlx.Sin(args) + cosArgs := mlx.Cos(args) + embedding := mlx.Concatenate([]*mlx.Array{cosArgs, sinArgs}, 1) // [B, 256] + + // MLP: linear1 -> silu -> linear2 + h := mlx.Linear(embedding, te.Linear1Weight) + h = mlx.Add(h, te.Linear1Bias) + h = mlx.SiLU(h) + h = mlx.Linear(h, te.Linear2Weight) + h = mlx.Add(h, te.Linear2Bias) + + return h +} + +// JointAttention implements dual-stream joint attention +type JointAttention struct { + // Image projections + ToQ *mlx.Array + ToQB *mlx.Array + ToK *mlx.Array + ToKB *mlx.Array + ToV *mlx.Array + ToVB *mlx.Array + ToOut *mlx.Array + ToOutB *mlx.Array + NormQ *mlx.Array + NormK *mlx.Array + + // Text (added) projections + AddQProj *mlx.Array + AddQProjB *mlx.Array + AddKProj *mlx.Array + AddKProjB *mlx.Array + AddVProj *mlx.Array + AddVProjB *mlx.Array + ToAddOut *mlx.Array + ToAddOutB *mlx.Array + NormAddQ *mlx.Array + NormAddK *mlx.Array + + NHeads int32 + HeadDim int32 + Scale float32 +} + +// newJointAttention creates a joint attention layer +func newJointAttention(weights *safetensors.ModelWeights, prefix string, cfg *TransformerConfig) (*JointAttention, error) { + toQ, _ := weights.Get(prefix + ".attn.to_q.weight") + toQB, _ := weights.Get(prefix + ".attn.to_q.bias") + toK, _ := weights.Get(prefix + ".attn.to_k.weight") + toKB, _ := weights.Get(prefix + ".attn.to_k.bias") + toV, _ := weights.Get(prefix + ".attn.to_v.weight") + toVB, _ := weights.Get(prefix + ".attn.to_v.bias") + toOut, _ := weights.Get(prefix + ".attn.to_out.0.weight") + toOutB, _ := weights.Get(prefix + ".attn.to_out.0.bias") + normQ, _ := weights.Get(prefix + ".attn.norm_q.weight") + normK, _ := weights.Get(prefix + ".attn.norm_k.weight") + + addQProj, _ := weights.Get(prefix + ".attn.add_q_proj.weight") + addQProjB, _ := weights.Get(prefix + ".attn.add_q_proj.bias") + addKProj, _ := weights.Get(prefix + ".attn.add_k_proj.weight") + addKProjB, _ := weights.Get(prefix + ".attn.add_k_proj.bias") + addVProj, _ := weights.Get(prefix + ".attn.add_v_proj.weight") + addVProjB, _ := weights.Get(prefix + ".attn.add_v_proj.bias") + toAddOut, _ := weights.Get(prefix + ".attn.to_add_out.weight") + toAddOutB, _ := weights.Get(prefix + ".attn.to_add_out.bias") + normAddQ, _ := weights.Get(prefix + ".attn.norm_added_q.weight") + normAddK, _ := weights.Get(prefix + ".attn.norm_added_k.weight") + + return &JointAttention{ + ToQ: mlx.Transpose(toQ, 1, 0), + ToQB: toQB, + ToK: mlx.Transpose(toK, 1, 0), + ToKB: toKB, + ToV: mlx.Transpose(toV, 1, 0), + ToVB: toVB, + ToOut: mlx.Transpose(toOut, 1, 0), + ToOutB: toOutB, + NormQ: normQ, + NormK: normK, + AddQProj: mlx.Transpose(addQProj, 1, 0), + AddQProjB: addQProjB, + AddKProj: mlx.Transpose(addKProj, 1, 0), + AddKProjB: addKProjB, + AddVProj: mlx.Transpose(addVProj, 1, 0), + AddVProjB: addVProjB, + ToAddOut: mlx.Transpose(toAddOut, 1, 0), + ToAddOutB: toAddOutB, + NormAddQ: normAddQ, + NormAddK: normAddK, + NHeads: cfg.NHeads, + HeadDim: cfg.HeadDim, + Scale: float32(1.0 / math.Sqrt(float64(cfg.HeadDim))), + }, nil +} + +// Forward computes joint attention +// img: [B, L_img, D], txt: [B, L_txt, D] +// imgFreqs, txtFreqs: complex RoPE frequencies [L, head_dim/2] as interleaved real/imag +func (attn *JointAttention) Forward(img, txt *mlx.Array, imgFreqs, txtFreqs *mlx.Array) (*mlx.Array, *mlx.Array) { + imgShape := img.Shape() + B := imgShape[0] + Limg := imgShape[1] + D := imgShape[2] + + txtShape := txt.Shape() + Ltxt := txtShape[1] + + // === Image Q/K/V === + imgFlat := mlx.Reshape(img, B*Limg, D) + qImg := mlx.Add(mlx.Linear(imgFlat, attn.ToQ), attn.ToQB) + kImg := mlx.Add(mlx.Linear(imgFlat, attn.ToK), attn.ToKB) + vImg := mlx.Add(mlx.Linear(imgFlat, attn.ToV), attn.ToVB) + + qImg = mlx.Reshape(qImg, B, Limg, attn.NHeads, attn.HeadDim) + kImg = mlx.Reshape(kImg, B, Limg, attn.NHeads, attn.HeadDim) + vImg = mlx.Reshape(vImg, B, Limg, attn.NHeads, attn.HeadDim) + + // QK norm (RMSNorm per head) + qImg = mlx.RMSNorm(qImg, attn.NormQ, 1e-6) + kImg = mlx.RMSNorm(kImg, attn.NormK, 1e-6) + + // Apply RoPE + if imgFreqs != nil { + qImg = applyRoPE(qImg, imgFreqs) + kImg = applyRoPE(kImg, imgFreqs) + } + + // === Text Q/K/V === + txtFlat := mlx.Reshape(txt, B*Ltxt, D) + qTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddQProj), attn.AddQProjB) + kTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddKProj), attn.AddKProjB) + vTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddVProj), attn.AddVProjB) + + qTxt = mlx.Reshape(qTxt, B, Ltxt, attn.NHeads, attn.HeadDim) + kTxt = mlx.Reshape(kTxt, B, Ltxt, attn.NHeads, attn.HeadDim) + vTxt = mlx.Reshape(vTxt, B, Ltxt, attn.NHeads, attn.HeadDim) + + qTxt = mlx.RMSNorm(qTxt, attn.NormAddQ, 1e-6) + kTxt = mlx.RMSNorm(kTxt, attn.NormAddK, 1e-6) + + if txtFreqs != nil { + qTxt = applyRoPE(qTxt, txtFreqs) + kTxt = applyRoPE(kTxt, txtFreqs) + } + + // Concatenate for joint attention: [txt, img] order + qJoint := mlx.Concatenate([]*mlx.Array{qTxt, qImg}, 1) + kJoint := mlx.Concatenate([]*mlx.Array{kTxt, kImg}, 1) + vJoint := mlx.Concatenate([]*mlx.Array{vTxt, vImg}, 1) + + // Transpose to [B, nheads, L, head_dim] + qJoint = mlx.Transpose(qJoint, 0, 2, 1, 3) + kJoint = mlx.Transpose(kJoint, 0, 2, 1, 3) + vJoint = mlx.Transpose(vJoint, 0, 2, 1, 3) + + // SDPA + outJoint := mlx.ScaledDotProductAttention(qJoint, kJoint, vJoint, attn.Scale, false) + + // Transpose back and split + outJoint = mlx.Transpose(outJoint, 0, 2, 1, 3) // [B, L, nheads, head_dim] + outJoint = mlx.Reshape(outJoint, B, Ltxt+Limg, D) + + outTxt := mlx.Slice(outJoint, []int32{0, 0, 0}, []int32{B, Ltxt, D}) + outImg := mlx.Slice(outJoint, []int32{0, Ltxt, 0}, []int32{B, Ltxt + Limg, D}) + + // Output projections + outImg = mlx.Reshape(outImg, B*Limg, D) + outImg = mlx.Add(mlx.Linear(outImg, attn.ToOut), attn.ToOutB) + outImg = mlx.Reshape(outImg, B, Limg, D) + + outTxt = mlx.Reshape(outTxt, B*Ltxt, D) + outTxt = mlx.Add(mlx.Linear(outTxt, attn.ToAddOut), attn.ToAddOutB) + outTxt = mlx.Reshape(outTxt, B, Ltxt, D) + + return outImg, outTxt +} + +// applyRoPE applies rotary embeddings using complex multiplication +// x: [B, L, nheads, head_dim] +// freqs: [L, head_dim] as complex (interleaved real/imag pairs) +func applyRoPE(x *mlx.Array, freqs *mlx.Array) *mlx.Array { + shape := x.Shape() + B := shape[0] + L := shape[1] + nheads := shape[2] + headDim := shape[3] + halfDim := headDim / 2 + + // Reshape x to pairs: [B, L, nheads, half, 2] + xPairs := mlx.Reshape(x, B, L, nheads, halfDim, 2) + + // freqs: [L, head_dim] -> [1, L, 1, half, 2] + freqsExp := mlx.Reshape(freqs, 1, L, 1, halfDim, 2) + + // Extract real/imag parts + xReal := mlx.SliceStride(xPairs, []int32{0, 0, 0, 0, 0}, []int32{B, L, nheads, halfDim, 1}, []int32{1, 1, 1, 1, 1}) + xImag := mlx.SliceStride(xPairs, []int32{0, 0, 0, 0, 1}, []int32{B, L, nheads, halfDim, 2}, []int32{1, 1, 1, 1, 1}) + xReal = mlx.Squeeze(xReal, 4) + xImag = mlx.Squeeze(xImag, 4) + + freqReal := mlx.SliceStride(freqsExp, []int32{0, 0, 0, 0, 0}, []int32{1, L, 1, halfDim, 1}, []int32{1, 1, 1, 1, 1}) + freqImag := mlx.SliceStride(freqsExp, []int32{0, 0, 0, 0, 1}, []int32{1, L, 1, halfDim, 2}, []int32{1, 1, 1, 1, 1}) + freqReal = mlx.Squeeze(freqReal, 4) + freqImag = mlx.Squeeze(freqImag, 4) + + // Complex multiplication: (a + bi) * (c + di) = (ac - bd) + (ad + bc)i + outReal := mlx.Sub(mlx.Mul(xReal, freqReal), mlx.Mul(xImag, freqImag)) + outImag := mlx.Add(mlx.Mul(xReal, freqImag), mlx.Mul(xImag, freqReal)) + + // Interleave back + outReal = mlx.ExpandDims(outReal, 4) + outImag = mlx.ExpandDims(outImag, 4) + out := mlx.Concatenate([]*mlx.Array{outReal, outImag}, 4) + + return mlx.Reshape(out, B, L, nheads, headDim) +} + +// MLP implements GELU MLP (not GEGLU) +type MLP struct { + ProjWeight *mlx.Array + ProjBias *mlx.Array + OutWeight *mlx.Array + OutBias *mlx.Array +} + +// newMLP creates a GELU MLP +func newMLP(weights *safetensors.ModelWeights, prefix string) (*MLP, error) { + projWeight, _ := weights.Get(prefix + ".net.0.proj.weight") + projBias, _ := weights.Get(prefix + ".net.0.proj.bias") + outWeight, _ := weights.Get(prefix + ".net.2.weight") + outBias, _ := weights.Get(prefix + ".net.2.bias") + + return &MLP{ + ProjWeight: mlx.Transpose(projWeight, 1, 0), + ProjBias: projBias, + OutWeight: mlx.Transpose(outWeight, 1, 0), + OutBias: outBias, + }, nil +} + +// Forward applies GELU MLP +func (m *MLP) Forward(x *mlx.Array) *mlx.Array { + shape := x.Shape() + B := shape[0] + L := shape[1] + D := shape[2] + + xFlat := mlx.Reshape(x, B*L, D) + h := mlx.Add(mlx.Linear(xFlat, m.ProjWeight), m.ProjBias) + h = geluApprox(h) + h = mlx.Add(mlx.Linear(h, m.OutWeight), m.OutBias) + return mlx.Reshape(h, B, L, m.OutBias.Dim(0)) +} + +// geluApprox implements approximate GELU +func geluApprox(x *mlx.Array) *mlx.Array { + sqrt2OverPi := float32(math.Sqrt(2.0 / math.Pi)) + x3 := mlx.Mul(mlx.Mul(x, x), x) + inner := mlx.Add(x, mlx.MulScalar(x3, 0.044715)) + inner = mlx.MulScalar(inner, sqrt2OverPi) + return mlx.Mul(mlx.MulScalar(x, 0.5), mlx.AddScalar(mlx.Tanh(inner), 1.0)) +} + +// TransformerBlock is a single dual-stream transformer block +type TransformerBlock struct { + Attention *JointAttention + ImgMLP *MLP + TxtMLP *MLP + + ImgModWeight *mlx.Array + ImgModBias *mlx.Array + TxtModWeight *mlx.Array + TxtModBias *mlx.Array + + HiddenDim int32 + NormEps float32 +} + +// newTransformerBlock creates a transformer block +func newTransformerBlock(weights *safetensors.ModelWeights, prefix string, cfg *TransformerConfig) (*TransformerBlock, error) { + attn, err := newJointAttention(weights, prefix, cfg) + if err != nil { + return nil, err + } + + imgMLP, _ := newMLP(weights, prefix+".img_mlp") + txtMLP, _ := newMLP(weights, prefix+".txt_mlp") + + imgModWeight, _ := weights.Get(prefix + ".img_mod.1.weight") + imgModBias, _ := weights.Get(prefix + ".img_mod.1.bias") + txtModWeight, _ := weights.Get(prefix + ".txt_mod.1.weight") + txtModBias, _ := weights.Get(prefix + ".txt_mod.1.bias") + + return &TransformerBlock{ + Attention: attn, + ImgMLP: imgMLP, + TxtMLP: txtMLP, + ImgModWeight: mlx.Transpose(imgModWeight, 1, 0), + ImgModBias: imgModBias, + TxtModWeight: mlx.Transpose(txtModWeight, 1, 0), + TxtModBias: txtModBias, + HiddenDim: cfg.HiddenDim, + NormEps: cfg.NormEps, + }, nil +} + +// Forward applies the transformer block +func (tb *TransformerBlock) Forward(img, txt, temb *mlx.Array, imgFreqs, txtFreqs *mlx.Array) (*mlx.Array, *mlx.Array) { + // Compute modulation: silu(temb) -> linear -> [B, 6*D] + siluT := mlx.SiLU(temb) + imgMod := mlx.Add(mlx.Linear(siluT, tb.ImgModWeight), tb.ImgModBias) + txtMod := mlx.Add(mlx.Linear(siluT, tb.TxtModWeight), tb.TxtModBias) + + // Split into 6 parts: shift1, scale1, gate1, shift2, scale2, gate2 + imgModParts := splitMod6(imgMod, tb.HiddenDim) + txtModParts := splitMod6(txtMod, tb.HiddenDim) + + // Pre-attention: norm + modulate + imgNorm := layerNormNoAffine(img, tb.NormEps) + imgNorm = mlx.Add(mlx.Mul(imgNorm, mlx.AddScalar(imgModParts[1], 1.0)), imgModParts[0]) + + txtNorm := layerNormNoAffine(txt, tb.NormEps) + txtNorm = mlx.Add(mlx.Mul(txtNorm, mlx.AddScalar(txtModParts[1], 1.0)), txtModParts[0]) + + // Joint attention + attnImg, attnTxt := tb.Attention.Forward(imgNorm, txtNorm, imgFreqs, txtFreqs) + + // Residual with gate + img = mlx.Add(img, mlx.Mul(imgModParts[2], attnImg)) + txt = mlx.Add(txt, mlx.Mul(txtModParts[2], attnTxt)) + + // Pre-MLP: norm + modulate + imgNorm2 := layerNormNoAffine(img, tb.NormEps) + imgNorm2 = mlx.Add(mlx.Mul(imgNorm2, mlx.AddScalar(imgModParts[4], 1.0)), imgModParts[3]) + + txtNorm2 := layerNormNoAffine(txt, tb.NormEps) + txtNorm2 = mlx.Add(mlx.Mul(txtNorm2, mlx.AddScalar(txtModParts[4], 1.0)), txtModParts[3]) + + // MLP + mlpImg := tb.ImgMLP.Forward(imgNorm2) + mlpTxt := tb.TxtMLP.Forward(txtNorm2) + + // Residual with gate + img = mlx.Add(img, mlx.Mul(imgModParts[5], mlpImg)) + txt = mlx.Add(txt, mlx.Mul(txtModParts[5], mlpTxt)) + + return img, txt +} + +// splitMod6 splits modulation into 6 parts each [B, 1, D] +func splitMod6(mod *mlx.Array, hiddenDim int32) []*mlx.Array { + shape := mod.Shape() + B := shape[0] + parts := make([]*mlx.Array, 6) + for i := int32(0); i < 6; i++ { + part := mlx.Slice(mod, []int32{0, i * hiddenDim}, []int32{B, (i + 1) * hiddenDim}) + parts[i] = mlx.ExpandDims(part, 1) + } + return parts +} + +// layerNormNoAffine applies layer norm without learnable parameters +func layerNormNoAffine(x *mlx.Array, eps float32) *mlx.Array { + ndim := x.Ndim() + lastAxis := ndim - 1 + mean := mlx.Mean(x, lastAxis, true) + xCentered := mlx.Sub(x, mean) + variance := mlx.Mean(mlx.Square(xCentered), lastAxis, true) + return mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, eps))) +} + +// Transformer is the full Qwen-Image transformer model +type Transformer struct { + Config *TransformerConfig + + ImgIn *mlx.Array + ImgInBias *mlx.Array + TxtIn *mlx.Array + TxtInBias *mlx.Array + TxtNorm *mlx.Array + + TEmbed *TimestepEmbedder + Layers []*TransformerBlock + + NormOutWeight *mlx.Array + NormOutBias *mlx.Array + ProjOut *mlx.Array + ProjOutBias *mlx.Array +} + +// Load loads the transformer from a directory +func (m *Transformer) Load(path string) error { + fmt.Println("Loading Qwen-Image transformer...") + + cfg := defaultTransformerConfig() + m.Config = cfg + + weights, err := safetensors.LoadModelWeights(path) + if err != nil { + return fmt.Errorf("weights: %w", err) + } + + // Bulk load all weights as bf16 + fmt.Print(" Loading weights as bf16... ") + if err := weights.Load(mlx.DtypeBFloat16); err != nil { + return fmt.Errorf("load weights: %w", err) + } + fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024)) + + fmt.Print(" Loading input projections... ") + imgIn, _ := weights.Get("img_in.weight") + imgInBias, _ := weights.Get("img_in.bias") + txtIn, _ := weights.Get("txt_in.weight") + txtInBias, _ := weights.Get("txt_in.bias") + txtNorm, _ := weights.Get("txt_norm.weight") + m.ImgIn = mlx.Transpose(imgIn, 1, 0) + m.ImgInBias = imgInBias + m.TxtIn = mlx.Transpose(txtIn, 1, 0) + m.TxtInBias = txtInBias + m.TxtNorm = txtNorm + fmt.Println("✓") + + fmt.Print(" Loading timestep embedder... ") + m.TEmbed, err = newTimestepEmbedder(weights) + if err != nil { + return fmt.Errorf("timestep embedder: %w", err) + } + fmt.Println("✓") + + m.Layers = make([]*TransformerBlock, cfg.NLayers) + for i := int32(0); i < cfg.NLayers; i++ { + fmt.Printf("\r Loading transformer layers... %d/%d", i+1, cfg.NLayers) + prefix := fmt.Sprintf("transformer_blocks.%d", i) + m.Layers[i], err = newTransformerBlock(weights, prefix, cfg) + if err != nil { + return fmt.Errorf("layer %d: %w", i, err) + } + } + fmt.Printf("\r Loading transformer layers... ✓ [%d blocks] \n", cfg.NLayers) + + fmt.Print(" Loading output layers... ") + normOutWeight, _ := weights.Get("norm_out.linear.weight") + normOutBias, _ := weights.Get("norm_out.linear.bias") + projOut, _ := weights.Get("proj_out.weight") + projOutBias, _ := weights.Get("proj_out.bias") + m.NormOutWeight = mlx.Transpose(normOutWeight, 1, 0) + m.NormOutBias = normOutBias + m.ProjOut = mlx.Transpose(projOut, 1, 0) + m.ProjOutBias = projOutBias + fmt.Println("✓") + + weights.ReleaseAll() + return nil +} + +// LoadFromPath is a convenience function to load transformer from path +func LoadTransformerFromPath(path string) (*Transformer, error) { + m := &Transformer{} + if err := m.Load(filepath.Join(path, "transformer")); err != nil { + return nil, err + } + return m, nil +} + +// Forward runs the transformer +// img: [B, L_img, in_channels] patchified latents +// txt: [B, L_txt, joint_attention_dim] text embeddings +// t: [B] timesteps (0-1) +// imgFreqs, txtFreqs: RoPE frequencies +func (tr *Transformer) Forward(img, txt, t *mlx.Array, imgFreqs, txtFreqs *mlx.Array) *mlx.Array { + imgShape := img.Shape() + B := imgShape[0] + Limg := imgShape[1] + + txtShape := txt.Shape() + Ltxt := txtShape[1] + + // Timestep embedding + temb := tr.TEmbed.Forward(t) + + // Project image: [B, L, in_channels] -> [B, L, hidden_dim] + imgFlat := mlx.Reshape(img, B*Limg, tr.Config.InChannels) + imgH := mlx.Add(mlx.Linear(imgFlat, tr.ImgIn), tr.ImgInBias) + imgH = mlx.Reshape(imgH, B, Limg, tr.Config.HiddenDim) + + // Project text: RMSNorm then linear + txtFlat := mlx.Reshape(txt, B*Ltxt, tr.Config.JointAttentionDim) + txtNormed := mlx.RMSNorm(txtFlat, tr.TxtNorm, 1e-6) + txtH := mlx.Add(mlx.Linear(txtNormed, tr.TxtIn), tr.TxtInBias) + txtH = mlx.Reshape(txtH, B, Ltxt, tr.Config.HiddenDim) + + for _, layer := range tr.Layers { + imgH, txtH = layer.Forward(imgH, txtH, temb, imgFreqs, txtFreqs) + } + + // Final norm with modulation (AdaLayerNormContinuous) + // Python: scale, shift = torch.chunk(emb, 2, dim=1) + finalMod := mlx.Add(mlx.Linear(mlx.SiLU(temb), tr.NormOutWeight), tr.NormOutBias) + modShape := finalMod.Shape() + halfDim := modShape[1] / 2 + scale := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, 0}, []int32{B, halfDim}), 1) + shift := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, halfDim}, []int32{B, modShape[1]}), 1) + + imgH = layerNormNoAffine(imgH, tr.Config.NormEps) + imgH = mlx.Add(mlx.Mul(imgH, mlx.AddScalar(scale, 1.0)), shift) + + // Final projection: [B, L, hidden_dim] -> [B, L, patch_size^2 * out_channels] + imgFlat = mlx.Reshape(imgH, B*Limg, tr.Config.HiddenDim) + out := mlx.Add(mlx.Linear(imgFlat, tr.ProjOut), tr.ProjOutBias) + + outChannels := tr.Config.PatchSize * tr.Config.PatchSize * tr.Config.OutChannels + return mlx.Reshape(out, B, Limg, outChannels) +} + +// ForwardWithCache runs the transformer with layer caching for speedup. +// Based on DeepCache (CVPR 2024) / Learning-to-Cache (NeurIPS 2024): +// shallow layers change little between denoising steps, so we cache their +// outputs and reuse them on non-refresh steps. +// +// stepCache: cache for layer outputs (use cache.NewStepCache(cacheLayers)) +// step: current denoising step (0-indexed) +// cacheInterval: refresh cache every N steps (e.g., 3) +// cacheLayers: number of shallow layers to cache (e.g., 15) +func (tr *Transformer) ForwardWithCache( + img, txt, t *mlx.Array, + imgFreqs, txtFreqs *mlx.Array, + stepCache *cache.StepCache, + step, cacheInterval, cacheLayers int, +) *mlx.Array { + imgShape := img.Shape() + B := imgShape[0] + Limg := imgShape[1] + + txtShape := txt.Shape() + Ltxt := txtShape[1] + + // Timestep embedding + temb := tr.TEmbed.Forward(t) + + // Project image: [B, L, in_channels] -> [B, L, hidden_dim] + imgFlat := mlx.Reshape(img, B*Limg, tr.Config.InChannels) + imgH := mlx.Add(mlx.Linear(imgFlat, tr.ImgIn), tr.ImgInBias) + imgH = mlx.Reshape(imgH, B, Limg, tr.Config.HiddenDim) + + // Project text: RMSNorm then linear + txtFlat := mlx.Reshape(txt, B*Ltxt, tr.Config.JointAttentionDim) + txtNormed := mlx.RMSNorm(txtFlat, tr.TxtNorm, 1e-6) + txtH := mlx.Add(mlx.Linear(txtNormed, tr.TxtIn), tr.TxtInBias) + txtH = mlx.Reshape(txtH, B, Ltxt, tr.Config.HiddenDim) + + // Check if we should refresh the cache + refreshCache := stepCache.ShouldRefresh(step, cacheInterval) + + for i, layer := range tr.Layers { + if i < cacheLayers && !refreshCache && stepCache.Get(i) != nil { + // Use cached outputs for shallow layers + imgH = stepCache.Get(i) + txtH = stepCache.Get2(i) + } else { + // Compute layer + imgH, txtH = layer.Forward(imgH, txtH, temb, imgFreqs, txtFreqs) + // Cache shallow layers on refresh steps + if i < cacheLayers && refreshCache { + stepCache.Set(i, imgH) + stepCache.Set2(i, txtH) + } + } + } + + // Final norm with modulation (AdaLayerNormContinuous) + finalMod := mlx.Add(mlx.Linear(mlx.SiLU(temb), tr.NormOutWeight), tr.NormOutBias) + modShape := finalMod.Shape() + halfDim := modShape[1] / 2 + scale := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, 0}, []int32{B, halfDim}), 1) + shift := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, halfDim}, []int32{B, modShape[1]}), 1) + + imgH = layerNormNoAffine(imgH, tr.Config.NormEps) + imgH = mlx.Add(mlx.Mul(imgH, mlx.AddScalar(scale, 1.0)), shift) + + // Final projection: [B, L, hidden_dim] -> [B, L, patch_size^2 * out_channels] + imgFlat = mlx.Reshape(imgH, B*Limg, tr.Config.HiddenDim) + out := mlx.Add(mlx.Linear(imgFlat, tr.ProjOut), tr.ProjOutBias) + + outChannels := tr.Config.PatchSize * tr.Config.PatchSize * tr.Config.OutChannels + return mlx.Reshape(out, B, Limg, outChannels) +} + +// RoPECache holds precomputed RoPE frequencies +type RoPECache struct { + ImgFreqs *mlx.Array // [L_img, head_dim] + TxtFreqs *mlx.Array // [L_txt, head_dim] +} + +// PrepareRoPE computes RoPE for image and text sequences +// This matches Python's QwenEmbedRope with scale_rope=True +func PrepareRoPE(imgH, imgW int32, txtLen int32, axesDims []int32) *RoPECache { + theta := float64(10000) + maxIdx := int32(4096) + + // Compute base frequencies for each axis dimension + freqsT := ComputeAxisFreqs(axesDims[0], theta) + freqsH := ComputeAxisFreqs(axesDims[1], theta) + freqsW := ComputeAxisFreqs(axesDims[2], theta) + + // Build frequency lookup tables + posFreqsT := MakeFreqTable(maxIdx, freqsT, false) + posFreqsH := MakeFreqTable(maxIdx, freqsH, false) + posFreqsW := MakeFreqTable(maxIdx, freqsW, false) + negFreqsH := MakeFreqTable(maxIdx, freqsH, true) + negFreqsW := MakeFreqTable(maxIdx, freqsW, true) + + // Image frequencies with scale_rope=True + imgLen := imgH * imgW + headDim := int32(len(freqsT)+len(freqsH)+len(freqsW)) * 2 + imgFreqsData := make([]float32, imgLen*headDim) + + hHalf := imgH / 2 + wHalf := imgW / 2 + + idx := int32(0) + for y := int32(0); y < imgH; y++ { + for x := int32(0); x < imgW; x++ { + // Frame = 0 + for i := 0; i < len(freqsT)*2; i++ { + imgFreqsData[idx+int32(i)] = posFreqsT[0][i] + } + idx += int32(len(freqsT) * 2) + + // Height: scale_rope pattern + hNegCount := imgH - hHalf + if y < hNegCount { + negTableIdx := maxIdx - hNegCount + y + for i := 0; i < len(freqsH)*2; i++ { + imgFreqsData[idx+int32(i)] = negFreqsH[negTableIdx][i] + } + } else { + posIdx := y - hNegCount + for i := 0; i < len(freqsH)*2; i++ { + imgFreqsData[idx+int32(i)] = posFreqsH[posIdx][i] + } + } + idx += int32(len(freqsH) * 2) + + // Width: scale_rope pattern + wNegCount := imgW - wHalf + if x < wNegCount { + negTableIdx := maxIdx - wNegCount + x + for i := 0; i < len(freqsW)*2; i++ { + imgFreqsData[idx+int32(i)] = negFreqsW[negTableIdx][i] + } + } else { + posIdx := x - wNegCount + for i := 0; i < len(freqsW)*2; i++ { + imgFreqsData[idx+int32(i)] = posFreqsW[posIdx][i] + } + } + idx += int32(len(freqsW) * 2) + } + } + + imgFreqs := mlx.NewArray(imgFreqsData, []int32{imgLen, headDim}) + imgFreqs = mlx.ToBFloat16(imgFreqs) + + // Text frequencies + maxVidIdx := max(hHalf, wHalf) + txtFreqsData := make([]float32, txtLen*headDim) + + idx = 0 + for t := int32(0); t < txtLen; t++ { + pos := maxVidIdx + t + for i := 0; i < len(freqsT)*2; i++ { + txtFreqsData[idx+int32(i)] = posFreqsT[pos][i] + } + idx += int32(len(freqsT) * 2) + for i := 0; i < len(freqsH)*2; i++ { + txtFreqsData[idx+int32(i)] = posFreqsH[pos][i] + } + idx += int32(len(freqsH) * 2) + for i := 0; i < len(freqsW)*2; i++ { + txtFreqsData[idx+int32(i)] = posFreqsW[pos][i] + } + idx += int32(len(freqsW) * 2) + } + + txtFreqs := mlx.NewArray(txtFreqsData, []int32{txtLen, headDim}) + txtFreqs = mlx.ToBFloat16(txtFreqs) + + return &RoPECache{ + ImgFreqs: imgFreqs, + TxtFreqs: txtFreqs, + } +} + +// ComputeAxisFreqs computes RoPE base frequencies for a given dimension. +func ComputeAxisFreqs(dim int32, theta float64) []float64 { + halfDim := dim / 2 + freqs := make([]float64, halfDim) + for i := int32(0); i < halfDim; i++ { + freqs[i] = 1.0 / math.Pow(theta, float64(i)/float64(halfDim)) + } + return freqs +} + +// MakeFreqTable builds a table of cos/sin values for RoPE positions. +func MakeFreqTable(maxIdx int32, baseFreqs []float64, negative bool) [][]float32 { + table := make([][]float32, maxIdx) + for idx := int32(0); idx < maxIdx; idx++ { + var pos float64 + if negative { + pos = float64(-maxIdx + int32(idx)) + } else { + pos = float64(idx) + } + + row := make([]float32, len(baseFreqs)*2) + for i, f := range baseFreqs { + angle := pos * f + row[i*2] = float32(math.Cos(angle)) + row[i*2+1] = float32(math.Sin(angle)) + } + table[idx] = row + } + return table +} + +func max(a, b int32) int32 { + if a > b { + return a + } + return b +} + +// PackLatents converts [B, C, H, W] to [B, L, C*4] patches +func PackLatents(latents *mlx.Array, patchSize int32) *mlx.Array { + shape := latents.Shape() + B := shape[0] + C := shape[1] + H := shape[2] + W := shape[3] + + pH := H / patchSize + pW := W / patchSize + + // [B, C, H, W] -> [B, C, pH, 2, pW, 2] + x := mlx.Reshape(latents, B, C, pH, patchSize, pW, patchSize) + // -> [B, pH, pW, C, 2, 2] + x = mlx.Transpose(x, 0, 2, 4, 1, 3, 5) + // -> [B, pH*pW, C*4] + return mlx.Reshape(x, B, pH*pW, C*patchSize*patchSize) +} + +// UnpackLatents converts [B, L, C*4] back to [B, C, 1, H, W] (5D for VAE) +func UnpackLatents(patches *mlx.Array, H, W, patchSize int32) *mlx.Array { + shape := patches.Shape() + B := shape[0] + channels := shape[2] / (patchSize * patchSize) + + pH := H / patchSize + pW := W / patchSize + + // [B, L, C*4] -> [B, pH, pW, C, 2, 2] + x := mlx.Reshape(patches, B, pH, pW, channels, patchSize, patchSize) + // -> [B, C, pH, 2, pW, 2] + x = mlx.Transpose(x, 0, 3, 1, 4, 2, 5) + // -> [B, C, H, W] + x = mlx.Reshape(x, B, channels, pH*patchSize, pW*patchSize) + // Add temporal dimension for VAE: [B, C, 1, H, W] + return mlx.ExpandDims(x, 2) +} diff --git a/x/imagegen/models/qwen_image/transformer_test.go b/x/imagegen/models/qwen_image/transformer_test.go new file mode 100644 index 000000000..5eef53b1d --- /dev/null +++ b/x/imagegen/models/qwen_image/transformer_test.go @@ -0,0 +1,119 @@ +//go:build mlx + +package qwen_image + +import ( + "math" + "os" + "testing" + + "github.com/ollama/ollama/x/imagegen/mlx" +) + +// TestTransformerConfig tests configuration invariants. +func TestTransformerConfig(t *testing.T) { + cfg := defaultTransformerConfig() + + // Property: hidden_dim = n_heads * head_dim + if cfg.HiddenDim != cfg.NHeads*cfg.HeadDim { + t.Errorf("hidden_dim != n_heads * head_dim: %d != %d * %d", + cfg.HiddenDim, cfg.NHeads, cfg.HeadDim) + } + + // Property: axes_dims_rope sums to head_dim + var ropeSum int32 + for _, d := range cfg.AxesDimsRope { + ropeSum += d + } + if ropeSum != cfg.HeadDim { + t.Errorf("axes_dims_rope sum != head_dim: %d != %d", ropeSum, cfg.HeadDim) + } + + // Property: in_channels = out_channels * patch_size^2 + expectedIn := cfg.OutChannels * cfg.PatchSize * cfg.PatchSize + if cfg.InChannels != expectedIn { + t.Errorf("in_channels != out_channels * patch_size^2: %d != %d", cfg.InChannels, expectedIn) + } +} + +// TestTransformerRoPE tests RoPE frequency computation produces valid values. +func TestTransformerRoPE(t *testing.T) { + cfg := defaultTransformerConfig() + + // Test with small image dimensions + imgH, imgW := int32(4), int32(4) // 4x4 latent = 16 patches + txtLen := int32(5) + + ropeCache := PrepareRoPE(imgH, imgW, txtLen, cfg.AxesDimsRope) + mlx.Eval(ropeCache.ImgFreqs, ropeCache.TxtFreqs) + + // Verify shapes: [seq_len, head_dim] + imgSeqLen := imgH * imgW + if ropeCache.ImgFreqs.Shape()[0] != imgSeqLen { + t.Errorf("ImgFreqs seq_len: got %d, want %d", ropeCache.ImgFreqs.Shape()[0], imgSeqLen) + } + if ropeCache.ImgFreqs.Shape()[1] != cfg.HeadDim { + t.Errorf("ImgFreqs head_dim: got %d, want %d", ropeCache.ImgFreqs.Shape()[1], cfg.HeadDim) + } + + if ropeCache.TxtFreqs.Shape()[0] != txtLen { + t.Errorf("TxtFreqs seq_len: got %d, want %d", ropeCache.TxtFreqs.Shape()[0], txtLen) + } + + // Verify values are finite + imgData := ropeCache.ImgFreqs.Data() + for i := 0; i < min(100, len(imgData)); i++ { + if math.IsNaN(float64(imgData[i])) || math.IsInf(float64(imgData[i]), 0) { + t.Errorf("ImgFreqs[%d] not finite: %v", i, imgData[i]) + break + } + } +} + +// TestTransformerForward tests full forward pass (integration test). +// Skips if model weights are not available. +func TestTransformerForward(t *testing.T) { + weightsPath := "../../../weights/Qwen-Image-2512/transformer" + if _, err := os.Stat(weightsPath); os.IsNotExist(err) { + t.Skip("Skipping: model weights not found at " + weightsPath) + } + + transformer := &Transformer{} + if err := transformer.Load(weightsPath); err != nil { + t.Fatalf("Failed to load transformer: %v", err) + } + mlx.Keep(mlx.Collect(transformer)...) + cfg := transformer.Config + + // Small test inputs + batchSize := int32(1) + imgH, imgW := int32(4), int32(4) + imgSeqLen := imgH * imgW + txtSeqLen := int32(5) + + hiddenStates := mlx.RandomNormal([]int32{batchSize, imgSeqLen, cfg.InChannels}, 0) + encoderHiddenStates := mlx.RandomNormal([]int32{batchSize, txtSeqLen, cfg.JointAttentionDim}, 0) + timestep := mlx.NewArray([]float32{0.5}, []int32{batchSize}) + + ropeCache := PrepareRoPE(imgH, imgW, txtSeqLen, cfg.AxesDimsRope) + + // Forward pass + out := transformer.Forward(hiddenStates, encoderHiddenStates, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs) + mlx.Eval(out) + + // Verify output shape: [batch, img_seq_len, in_channels] + wantShape := []int32{batchSize, imgSeqLen, cfg.InChannels} + gotShape := out.Shape() + if gotShape[0] != wantShape[0] || gotShape[1] != wantShape[1] || gotShape[2] != wantShape[2] { + t.Errorf("output shape: got %v, want %v", gotShape, wantShape) + } + + // Verify output is finite + outData := out.Data() + for i := 0; i < min(100, len(outData)); i++ { + if math.IsNaN(float64(outData[i])) || math.IsInf(float64(outData[i]), 0) { + t.Errorf("output[%d] not finite: %v", i, outData[i]) + break + } + } +} diff --git a/x/imagegen/models/qwen_image/vae.go b/x/imagegen/models/qwen_image/vae.go new file mode 100644 index 000000000..e1c7f5255 --- /dev/null +++ b/x/imagegen/models/qwen_image/vae.go @@ -0,0 +1,854 @@ +//go:build mlx + +package qwen_image + +import ( + "fmt" + "math" + "path/filepath" + + "github.com/ollama/ollama/x/imagegen/mlx" + "github.com/ollama/ollama/x/imagegen/safetensors" +) + +// VAEConfig holds Qwen-Image VAE configuration +type VAEConfig struct { + ZDim int32 `json:"z_dim"` // 16 + BaseDim int32 `json:"base_dim"` // 96 + DimMult []int32 `json:"dim_mult"` // [1, 2, 4, 4] + NumResBlocks int32 `json:"num_res_blocks"` // 2 + LatentsMean []float32 `json:"latents_mean"` // 16 values + LatentsStd []float32 `json:"latents_std"` // 16 values + TemperalDownsample []bool `json:"temperal_downsample"` // [false, true, true] +} + +// defaultVAEConfig returns config for Qwen-Image VAE +func defaultVAEConfig() *VAEConfig { + return &VAEConfig{ + ZDim: 16, + BaseDim: 96, + DimMult: []int32{1, 2, 4, 4}, + NumResBlocks: 2, + LatentsMean: []float32{ + -0.7571, -0.7089, -0.9113, 0.1075, + -0.1745, 0.9653, -0.1517, 1.5508, + 0.4134, -0.0715, 0.5517, -0.3632, + -0.1922, -0.9497, 0.2503, -0.2921, + }, + LatentsStd: []float32{ + 2.8184, 1.4541, 2.3275, 2.6558, + 1.2196, 1.7708, 2.6052, 2.0743, + 3.2687, 2.1526, 2.8652, 1.5579, + 1.6382, 1.1253, 2.8251, 1.916, + }, + TemperalDownsample: []bool{false, true, true}, + } +} + +// CausalConv3d is a causal 3D convolution (for temporal causality) +type CausalConv3d struct { + Weight *mlx.Array + Bias *mlx.Array + BiasReshaped *mlx.Array // [1, C, 1, 1, 1] + KernelT int32 +} + +// newCausalConv3d creates a 3D causal conv +func newCausalConv3d(weights *safetensors.ModelWeights, prefix string) (*CausalConv3d, error) { + weight, err := weights.Get(prefix + ".weight") + if err != nil { + return nil, fmt.Errorf("weight not found: %s", prefix) + } + bias, _ := weights.Get(prefix + ".bias") + + kernelT := weight.Shape()[2] + outC := weight.Shape()[0] + + var biasReshaped *mlx.Array + if bias != nil { + biasReshaped = mlx.Reshape(bias, 1, outC, 1, 1, 1) + } + + return &CausalConv3d{ + Weight: weight, + Bias: bias, + BiasReshaped: biasReshaped, + KernelT: kernelT, + }, nil +} + +// Forward applies causal 3D convolution +// x: [B, T, H, W, C] (channels-last, MLX format) +func (c *CausalConv3d) Forward(x *mlx.Array) *mlx.Array { + shape := c.Weight.Shape() // PyTorch format: [O, I, kT, kH, kW] + kernelT := shape[2] + kernelH := shape[3] + kernelW := shape[4] + + // Causal temporal padding, same spatial padding + // Input is channels-last: [B, T, H, W, C] + padT := kernelT - 1 + padH := kernelH / 2 + padW := kernelW / 2 + + // Stage 1: Pad + { + x = pad3DChannelsLast(x, padT, 0, padH, padH, padW, padW) + mlx.Eval(x) + } + + // Stage 2: Conv + bias + var out *mlx.Array + { + prev := x + weight := mlx.Transpose(c.Weight, 0, 2, 3, 4, 1) + out = mlx.Conv3d(x, weight, 1, 1, 1, 0, 0, 0) + if c.Bias != nil { + bias := mlx.Reshape(c.Bias, 1, 1, 1, 1, c.Bias.Dim(0)) + out = mlx.Add(out, bias) + } + prev.Free() + mlx.Eval(out) + } + + return out +} + +// RMSNorm3D applies RMS normalization over channels +// Works with channels-last [B, T, H, W, C] format +type RMSNorm3D struct { + Gamma *mlx.Array // [1, 1, 1, 1, C] for broadcasting +} + +// newRMSNorm3D creates an RMS norm +func newRMSNorm3D(weights *safetensors.ModelWeights, prefix string, dim int32) (*RMSNorm3D, error) { + gamma, err := weights.Get(prefix + ".gamma") + if err != nil { + return nil, err + } + // Reshape for channels-last broadcasting: [1, 1, 1, 1, C] + gamma = mlx.Reshape(gamma, 1, 1, 1, 1, gamma.Dim(0)) + return &RMSNorm3D{Gamma: gamma}, nil +} + +// Forward applies RMS norm to channels-last input [B, T, H, W, C] +func (n *RMSNorm3D) Forward(x *mlx.Array) *mlx.Array { + // RMSNorm: x * rsqrt(mean(x^2) + eps) * gamma + normalized := mlx.RMSNormNoWeight(x, 1e-6) + return mlx.Mul(normalized, n.Gamma) +} + +// ResBlock is a residual block with RMS norm and causal convs +type ResBlock struct { + Norm1 *RMSNorm3D + Conv1 *CausalConv3d + Norm2 *RMSNorm3D + Conv2 *CausalConv3d + Shortcut *CausalConv3d +} + +// newResBlock creates a residual block +func newResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*ResBlock, error) { + norm1, err := newRMSNorm3D(weights, prefix+".norm1", inDim) + if err != nil { + return nil, err + } + conv1, err := newCausalConv3d(weights, prefix+".conv1") + if err != nil { + return nil, err + } + norm2, err := newRMSNorm3D(weights, prefix+".norm2", outDim) + if err != nil { + return nil, err + } + conv2, err := newCausalConv3d(weights, prefix+".conv2") + if err != nil { + return nil, err + } + + var shortcut *CausalConv3d + if inDim != outDim { + shortcut, err = newCausalConv3d(weights, prefix+".conv_shortcut") + if err != nil { + return nil, err + } + } + + return &ResBlock{ + Norm1: norm1, + Conv1: conv1, + Norm2: norm2, + Conv2: conv2, + Shortcut: shortcut, + }, nil +} + +// Forward applies the residual block +func (r *ResBlock) Forward(x *mlx.Array) *mlx.Array { + // Use h as working variable, keep x intact for residual (caller will free x) + // Conv handles its own pools, so we just need pools for non-conv operations + var h *mlx.Array + + // Keep x so it survives Eval() cleanup - needed for residual connection + mlx.Keep(x) + + // Stage 1: norm1 + silu + { + h = r.Norm1.Forward(x) + h = silu3D(h) + mlx.Eval(h) + } + + // Stage 2: conv1 (handles its own pools) + { + prev := h + h = r.Conv1.Forward(h) + prev.Free() + } + + // Stage 3: norm2 + silu + { + prev := h + h = r.Norm2.Forward(h) + h = silu3D(h) + prev.Free() + mlx.Eval(h) + } + + // Stage 4: conv2 (handles its own pools) + { + prev := h + h = r.Conv2.Forward(h) + prev.Free() + } + + // Residual connection (shortcut handles its own pools if present) + if r.Shortcut != nil { + shortcut := r.Shortcut.Forward(x) + h = mlx.Add(h, shortcut) + mlx.Eval(h) + } else { + h = mlx.Add(h, x) + mlx.Eval(h) + } + + return h +} + +// AttentionBlock is a 2D attention block +type AttentionBlock struct { + Norm *RMSNorm3D + ToQKV *mlx.Array + ToQKVBias *mlx.Array + Proj *mlx.Array + ProjBias *mlx.Array + Dim int32 +} + +// newAttentionBlock creates an attention block +func newAttentionBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*AttentionBlock, error) { + norm, err := newRMSNorm3D(weights, prefix+".norm", dim) + if err != nil { + return nil, err + } + toQKV, _ := weights.Get(prefix + ".to_qkv.weight") + toQKVBias, _ := weights.Get(prefix + ".to_qkv.bias") + proj, _ := weights.Get(prefix + ".proj.weight") + projBias, _ := weights.Get(prefix + ".proj.bias") + + return &AttentionBlock{ + Norm: norm, + ToQKV: toQKV, + ToQKVBias: toQKVBias, + Proj: proj, + ProjBias: projBias, + Dim: dim, + }, nil +} + +// Forward applies 2D attention +// Input: [B, T, H, W, C] (channels-last) +func (a *AttentionBlock) Forward(x *mlx.Array) *mlx.Array { + shape := x.Shape() + B := shape[0] + T := shape[1] + H := shape[2] + W := shape[3] + C := shape[4] + + identity := x + + // Flatten to [B*T, 1, H, W, C] for norm + x = mlx.Reshape(x, B*T, 1, H, W, C) + x = a.Norm.Forward(x) + x = mlx.Reshape(x, B*T, H, W, C) + + // Flatten spatial to [B*T, H*W, C] + x = mlx.Reshape(x, B*T, H*W, C) + + // Linear to get Q, K, V: [B*T, H*W, 3*C] + // Weight is [outC, inC] or [outC, inC, 1, 1] + wShape := a.ToQKV.Shape() + var w *mlx.Array + if len(wShape) == 4 { + w = mlx.Reshape(a.ToQKV, wShape[0], wShape[1]) + } else { + w = a.ToQKV + } + w = mlx.Transpose(w, 1, 0) // [inC, outC] + + qkv := mlx.Linear(x, w) // [B*T, H*W, 3*C] + if a.ToQKVBias != nil { + qkv = mlx.Add(qkv, a.ToQKVBias) + } + qkv = mlx.Reshape(qkv, B*T, 1, H*W, 3*C) + + q := mlx.Slice(qkv, []int32{0, 0, 0, 0}, []int32{B * T, 1, H * W, C}) + k := mlx.Slice(qkv, []int32{0, 0, 0, C}, []int32{B * T, 1, H * W, 2 * C}) + v := mlx.Slice(qkv, []int32{0, 0, 0, 2 * C}, []int32{B * T, 1, H * W, 3 * C}) + + scale := float32(1.0 / math.Sqrt(float64(C))) + out := mlx.ScaledDotProductAttention(q, k, v, scale, false) + + // out: [B*T, 1, H*W, C] + out = mlx.Reshape(out, B*T, H*W, C) + + // Project back + pShape := a.Proj.Shape() + var p *mlx.Array + if len(pShape) == 4 { + p = mlx.Reshape(a.Proj, pShape[0], pShape[1]) + } else { + p = a.Proj + } + p = mlx.Transpose(p, 1, 0) // [inC, outC] + out = mlx.Linear(out, p) // [B*T, H*W, C] + if a.ProjBias != nil { + out = mlx.Add(out, a.ProjBias) + } + + out = mlx.Reshape(out, B, T, H, W, C) + return mlx.Add(out, identity) +} + +// UpBlock handles upsampling in decoder +type UpBlock struct { + ResBlocks []*ResBlock + Upsampler *Upsample +} + +// newUpBlock creates an up block +func newUpBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, upsampleMode string) (*UpBlock, error) { + resBlocks := make([]*ResBlock, numBlocks+1) + + currentDim := inDim + for i := int32(0); i <= numBlocks; i++ { + resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i) + block, err := newResBlock(weights, resPrefix, currentDim, outDim) + if err != nil { + return nil, err + } + resBlocks[i] = block + currentDim = outDim + } + + var upsampler *Upsample + if upsampleMode != "" { + upsampler = newUpsample(weights, prefix+".upsamplers.0", outDim, upsampleMode) + } + + return &UpBlock{ + ResBlocks: resBlocks, + Upsampler: upsampler, + }, nil +} + +// Forward applies up block with staged memory management +func (u *UpBlock) Forward(x *mlx.Array) *mlx.Array { + // ResBlocks handle their own pools + for _, block := range u.ResBlocks { + prev := x + x = block.Forward(x) + prev.Free() + } + + // Upsampler handles its own pools + if u.Upsampler != nil { + prev := x + x = u.Upsampler.Forward(x) + prev.Free() + } + return x +} + +// Upsample handles spatial upsampling +type Upsample struct { + Conv *mlx.Array + Bias *mlx.Array + Mode string +} + +// newUpsample creates an upsampler +func newUpsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Upsample { + conv, _ := weights.Get(prefix + ".resample.1.weight") + bias, _ := weights.Get(prefix + ".resample.1.bias") + return &Upsample{ + Conv: conv, + Bias: bias, + Mode: mode, + } +} + +// Forward applies upsampling to channels-last input [B, T, H, W, C] +// Uses staged pools to reduce peak memory during 2x upsampling +func (u *Upsample) Forward(x *mlx.Array) *mlx.Array { + shape := x.Shape() + B := shape[0] + T := shape[1] + H := shape[2] + W := shape[3] + C := shape[4] + outC := u.Conv.Shape()[0] + + // Stage 1: 2x nearest neighbor upsample + { + x = mlx.Reshape(x, B*T, H, W, C) + x = upsample2xChannelsLast(x) + mlx.Eval(x) + } + + // Stage 2: Conv + bias + { + prev := x + weight := mlx.Transpose(u.Conv, 0, 2, 3, 1) + x = conv2D3x3PaddedChannelsLast(x, weight) + if u.Bias != nil { + bias := mlx.Reshape(u.Bias, 1, 1, 1, outC) + x = mlx.Add(x, bias) + } + x = mlx.Reshape(x, B, T, H*2, W*2, outC) + prev.Free() + mlx.Eval(x) + } + + return x +} + +// MidBlock is the middle block of decoder +type MidBlock struct { + ResBlock1 *ResBlock + Attention *AttentionBlock + ResBlock2 *ResBlock +} + +// newMidBlock creates a mid block +func newMidBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*MidBlock, error) { + res1, err := newResBlock(weights, prefix+".resnets.0", dim, dim) + if err != nil { + return nil, err + } + attn, err := newAttentionBlock(weights, prefix+".attentions.0", dim) + if err != nil { + return nil, err + } + res2, err := newResBlock(weights, prefix+".resnets.1", dim, dim) + if err != nil { + return nil, err + } + + return &MidBlock{ + ResBlock1: res1, + Attention: attn, + ResBlock2: res2, + }, nil +} + +// Forward applies mid block +func (m *MidBlock) Forward(x *mlx.Array) *mlx.Array { + // Each component handles its own pools; we just free inputs + prev := x + x = m.ResBlock1.Forward(x) + prev.Free() + + prev = x + x = m.Attention.Forward(x) + prev.Free() + + prev = x + x = m.ResBlock2.Forward(x) + prev.Free() + + return x +} + +// VAEDecoder is the full VAE decoder +type VAEDecoder struct { + Config *VAEConfig + + PostQuantConv *CausalConv3d + ConvIn *CausalConv3d + MidBlock *MidBlock + UpBlocks []*UpBlock + NormOut *RMSNorm3D + ConvOut *CausalConv3d +} + +// Load loads the VAE decoder from a directory +func (m *VAEDecoder) Load(path string) error { + fmt.Println("Loading Qwen-Image VAE decoder...") + + cfg := defaultVAEConfig() + m.Config = cfg + + weights, err := safetensors.LoadModelWeights(path) + if err != nil { + return fmt.Errorf("weights: %w", err) + } + + // Bulk load all weights as bf16 + fmt.Print(" Loading weights as bf16... ") + if err := weights.Load(mlx.DtypeBFloat16); err != nil { + return fmt.Errorf("failed to load weights: %w", err) + } + fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024)) + + fmt.Print(" Loading post_quant_conv... ") + postQuantConv, err := newCausalConv3d(weights, "post_quant_conv") + if err != nil { + return err + } + m.PostQuantConv = postQuantConv + fmt.Println("✓") + + fmt.Print(" Loading conv_in... ") + convIn, err := newCausalConv3d(weights, "decoder.conv_in") + if err != nil { + return err + } + m.ConvIn = convIn + fmt.Println("✓") + + // Mid block (dim = base_dim * dim_mult[-1] = 96 * 4 = 384) + fmt.Print(" Loading mid_block... ") + midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1] + midBlock, err := newMidBlock(weights, "decoder.mid_block", midDim) + if err != nil { + return err + } + m.MidBlock = midBlock + fmt.Println("✓") + + // Up blocks (reversed dim_mult) + fmt.Print(" Loading up_blocks... ") + numUpBlocks := len(cfg.DimMult) + m.UpBlocks = make([]*UpBlock, numUpBlocks) + + dimsMult := make([]int32, numUpBlocks+1) + dimsMult[0] = cfg.DimMult[numUpBlocks-1] + for i := 0; i < numUpBlocks; i++ { + dimsMult[i+1] = cfg.DimMult[numUpBlocks-1-i] + } + + temporalUpsample := make([]bool, len(cfg.TemperalDownsample)) + for i := range cfg.TemperalDownsample { + temporalUpsample[i] = cfg.TemperalDownsample[len(cfg.TemperalDownsample)-1-i] + } + + for i := 0; i < numUpBlocks; i++ { + inDim := cfg.BaseDim * dimsMult[i] + outDim := cfg.BaseDim * dimsMult[i+1] + + if i > 0 { + inDim = inDim / 2 + } + + upsampleMode := "" + if i < numUpBlocks-1 { + if temporalUpsample[i] { + upsampleMode = "upsample3d" + } else { + upsampleMode = "upsample2d" + } + } + + prefix := fmt.Sprintf("decoder.up_blocks.%d", i) + upBlock, err := newUpBlock(weights, prefix, inDim, outDim, cfg.NumResBlocks, upsampleMode) + if err != nil { + return err + } + m.UpBlocks[i] = upBlock + } + fmt.Printf("✓ [%d blocks]\n", numUpBlocks) + + fmt.Print(" Loading output layers... ") + normOut, err := newRMSNorm3D(weights, "decoder.norm_out", cfg.BaseDim) + if err != nil { + return err + } + m.NormOut = normOut + convOut, err := newCausalConv3d(weights, "decoder.conv_out") + if err != nil { + return err + } + m.ConvOut = convOut + fmt.Println("✓") + + weights.ReleaseAll() + return nil +} + +// LoadVAEDecoderFromPath is a convenience function to load VAE from path +func LoadVAEDecoderFromPath(path string) (*VAEDecoder, error) { + m := &VAEDecoder{} + if err := m.Load(filepath.Join(path, "vae")); err != nil { + return nil, err + } + return m, nil +} + +// Decode converts latents to image +// z: [B, C, T, H, W] normalized latents +// Uses staged pools to free intermediate arrays and reduce peak memory. +func (vae *VAEDecoder) Decode(z *mlx.Array) *mlx.Array { + var x *mlx.Array + + // Stage 1a: Denormalize and transpose + { + z = vae.Denormalize(z) + // Convert from channels-first [N, C, T, H, W] to channels-last [N, T, H, W, C] + z = mlx.Contiguous(mlx.Transpose(z, 0, 2, 3, 4, 1)) + mlx.Eval(z) + } + + // Stage 1b: PostQuantConv (handles its own pools) + x = vae.PostQuantConv.Forward(z) + z.Free() + + // Stage 1c: ConvIn (handles its own pools) + { + prev := x + x = vae.ConvIn.Forward(x) + prev.Free() + } + + // Stage 2: Mid block (handles its own pools) + x = vae.MidBlock.Forward(x) + + // Stage 3: Up blocks (each handles its own pools) + for _, upBlock := range vae.UpBlocks { + x = upBlock.Forward(x) + } + + // Stage 4a: NormOut + silu + { + prev := x + x = vae.NormOut.Forward(x) + x = silu3D(x) + prev.Free() + mlx.Eval(x) + } + + // Stage 4b: ConvOut (handles its own pools) + { + prev := x + x = vae.ConvOut.Forward(x) + prev.Free() + } + + // Stage 4c: Post-processing + { + prev := x + // Clamp to [-1, 1] + x = mlx.ClipScalar(x, -1.0, 1.0, true, true) + // Convert back from channels-last to channels-first + x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3)) + prev.Free() + mlx.Eval(x) + } + + return x +} + +// Denormalize reverses the normalization applied during encoding +func (vae *VAEDecoder) Denormalize(z *mlx.Array) *mlx.Array { + shape := z.Shape() + C := shape[1] + + mean := mlx.NewArray(vae.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1}) + std := mlx.NewArray(vae.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1}) + + mean = mlx.ToBFloat16(mean) + std = mlx.ToBFloat16(std) + + return mlx.Add(mlx.Mul(z, std), mean) +} + +// Helper functions + +func silu3D(x *mlx.Array) *mlx.Array { + return mlx.Mul(x, mlx.Sigmoid(x)) +} + +// pad3DChannelsLast pads a channels-last [B, T, H, W, C] tensor +func pad3DChannelsLast(x *mlx.Array, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array { + if tBefore == 0 && tAfter == 0 && hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 { + return x + } + // Pad dims: [B before, B after, T before, T after, H before, H after, W before, W after, C before, C after] + return mlx.Pad(x, []int32{0, 0, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter, 0, 0}) +} + +func pad2D(x *mlx.Array, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array { + if hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 { + return x + } + return mlx.Pad(x, []int32{0, 0, 0, 0, hBefore, hAfter, wBefore, wAfter}) +} + +func conv2D1x1(x, weight *mlx.Array) *mlx.Array { + shape := x.Shape() + B := shape[0] + H := shape[2] + W := shape[3] + + x = mlx.Transpose(x, 0, 2, 3, 1) + x = mlx.Reshape(x, B*H*W, shape[1]) + + wShape := weight.Shape() + var w *mlx.Array + if len(wShape) == 4 { + w = mlx.Reshape(weight, wShape[0], wShape[1]) + } else { + w = weight + } + w = mlx.Transpose(w, 1, 0) + + out := mlx.Linear(x, w) + outC := w.Dim(1) + out = mlx.Reshape(out, B, H, W, outC) + return mlx.Transpose(out, 0, 3, 1, 2) +} + +func conv2D3x3Padded(x, weight *mlx.Array) *mlx.Array { + x = pad2D(x, 1, 1, 1, 1) + return conv2D(x, weight, 1, 1) +} + +func conv2D(x, w *mlx.Array, strideH, strideW int32) *mlx.Array { + x = mlx.Transpose(x, 0, 2, 3, 1) + w = mlx.Transpose(w, 0, 2, 3, 1) + + shape := x.Shape() + B := shape[0] + H := shape[1] + W := shape[2] + + wShape := w.Shape() + Cout := wShape[0] + kH := wShape[1] + kW := wShape[2] + + outH := (H-kH)/strideH + 1 + outW := (W-kW)/strideW + 1 + + patches := extractPatches2D(x, kH, kW, strideH, strideW) + wFlat := mlx.Reshape(w, Cout, -1) + patches = mlx.Reshape(patches, B*outH*outW, -1) + out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0)) + out = mlx.Reshape(out, B, outH, outW, Cout) + return mlx.Transpose(out, 0, 3, 1, 2) +} + +func extractPatches2D(x *mlx.Array, kH, kW, strideH, strideW int32) *mlx.Array { + shape := x.Shape() + B := shape[0] + H := shape[1] + W := shape[2] + C := shape[3] + + outH := (H-kH)/strideH + 1 + outW := (W-kW)/strideW + 1 + + patches := make([]*mlx.Array, outH*outW) + idx := 0 + for i := int32(0); i < outH; i++ { + for j := int32(0); j < outW; j++ { + startH := i * strideH + startW := j * strideW + patch := mlx.Slice(x, []int32{0, startH, startW, 0}, []int32{B, startH + kH, startW + kW, C}) + patch = mlx.Reshape(patch, B, kH*kW*C) + patches[idx] = patch + idx++ + } + } + + for i := range patches { + patches[i] = mlx.ExpandDims(patches[i], 1) + } + stacked := mlx.Concatenate(patches, 1) + return mlx.Reshape(stacked, B, outH, outW, kH*kW*C) +} + +func upsample2x(x *mlx.Array) *mlx.Array { + shape := x.Shape() + H := shape[2] + W := shape[3] + + rowIdxData := make([]int32, H*2) + for i := int32(0); i < H; i++ { + rowIdxData[i*2] = i + rowIdxData[i*2+1] = i + } + rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2}) + + colIdxData := make([]int32, W*2) + for i := int32(0); i < W; i++ { + colIdxData[i*2] = i + colIdxData[i*2+1] = i + } + colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2}) + + x = mlx.Take(x, rowIdx, 2) + x = mlx.Take(x, colIdx, 3) + + return x +} + +// upsample2xChannelsLast upsamples channels-last input [B, H, W, C] by 2x +func upsample2xChannelsLast(x *mlx.Array) *mlx.Array { + shape := x.Shape() + H := shape[1] + W := shape[2] + + // Create repeat indices for rows + rowIdxData := make([]int32, H*2) + for i := int32(0); i < H; i++ { + rowIdxData[i*2] = i + rowIdxData[i*2+1] = i + } + rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2}) + + // Create repeat indices for columns + colIdxData := make([]int32, W*2) + for i := int32(0); i < W; i++ { + colIdxData[i*2] = i + colIdxData[i*2+1] = i + } + colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2}) + + // Take along H (axis 1) then W (axis 2) + x = mlx.Take(x, rowIdx, 1) + x = mlx.Take(x, colIdx, 2) + + return x +} + +// conv2D3x3PaddedChannelsLast applies 3x3 conv with padding to channels-last input [B, H, W, C] +// weight: [outC, kH, kW, inC] (MLX channels-last format) +func conv2D3x3PaddedChannelsLast(x, weight *mlx.Array) *mlx.Array { + // Pad spatial dims: [B, H, W, C] -> pad H and W by 1 each side + x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0}) + // Conv2d expects: input [B, H, W, inC], weight [outC, kH, kW, inC] + // stride=1, padding=0 (we already padded manually) + return mlx.Conv2d(x, weight, 1, 0) +} diff --git a/x/imagegen/models/qwen_image/vae_test.go b/x/imagegen/models/qwen_image/vae_test.go new file mode 100644 index 000000000..f15a1134b --- /dev/null +++ b/x/imagegen/models/qwen_image/vae_test.go @@ -0,0 +1,114 @@ +//go:build mlx + +package qwen_image + +import ( + "math" + "os" + "testing" + + "github.com/ollama/ollama/x/imagegen/mlx" +) + +// TestVAEConfig tests configuration invariants. +func TestVAEConfig(t *testing.T) { + cfg := defaultVAEConfig() + + // Property: latents_mean and latents_std have z_dim elements + if int32(len(cfg.LatentsMean)) != cfg.ZDim { + t.Errorf("latents_mean length != z_dim: %d != %d", len(cfg.LatentsMean), cfg.ZDim) + } + if int32(len(cfg.LatentsStd)) != cfg.ZDim { + t.Errorf("latents_std length != z_dim: %d != %d", len(cfg.LatentsStd), cfg.ZDim) + } + + // Property: dim_mult defines 4 stages + if len(cfg.DimMult) != 4 { + t.Errorf("dim_mult should have 4 stages: got %d", len(cfg.DimMult)) + } + + // Property: temperal_downsample has 3 elements (for 3 transitions) + if len(cfg.TemperalDownsample) != 3 { + t.Errorf("temperal_downsample should have 3 elements: got %d", len(cfg.TemperalDownsample)) + } +} + +// TestVAELatentsNormalization tests the latent denormalization values. +func TestVAELatentsNormalization(t *testing.T) { + cfg := defaultVAEConfig() + + // Verify latents_std values are all positive + for i, std := range cfg.LatentsStd { + if std <= 0 { + t.Errorf("latents_std[%d] should be positive: %v", i, std) + } + } + + // Verify values are in reasonable range (from actual model) + for i, mean := range cfg.LatentsMean { + if math.Abs(float64(mean)) > 5 { + t.Errorf("latents_mean[%d] seems too large: %v", i, mean) + } + } + for i, std := range cfg.LatentsStd { + if std > 10 { + t.Errorf("latents_std[%d] seems too large: %v", i, std) + } + } +} + +// TestVAEDecoderForward tests full forward pass (integration test). +// Skips if model weights are not available. +func TestVAEDecoderForward(t *testing.T) { + weightsPath := "../../../weights/Qwen-Image-2512/vae" + if _, err := os.Stat(weightsPath); os.IsNotExist(err) { + t.Skip("Skipping: model weights not found at " + weightsPath) + } + + vae := &VAEDecoder{} + if err := vae.Load(weightsPath); err != nil { + t.Fatalf("Failed to load VAE decoder: %v", err) + } + mlx.Keep(mlx.Collect(vae)...) + + // Small test input: [B, C, T, H, W] + // After 4 upsampling stages (2x each), H/W multiply by 16 + batchSize := int32(1) + channels := int32(16) + frames := int32(1) + latentH := int32(4) + latentW := int32(4) + + latents := mlx.RandomNormal([]int32{batchSize, channels, frames, latentH, latentW}, 0) + + // Decode + out := vae.Decode(latents) + mlx.Eval(out) + + // Verify output shape: [B, 3, T, H*16, W*16] + outShape := out.Shape() + if outShape[0] != batchSize { + t.Errorf("batch size: got %d, want %d", outShape[0], batchSize) + } + if outShape[1] != 3 { + t.Errorf("channels: got %d, want 3", outShape[1]) + } + if outShape[2] != frames { + t.Errorf("frames: got %d, want %d", outShape[2], frames) + } + expectedH := latentH * 16 // 4 stages of 2x upsampling + expectedW := latentW * 16 + if outShape[3] != expectedH || outShape[4] != expectedW { + t.Errorf("spatial dims: got [%d, %d], want [%d, %d]", + outShape[3], outShape[4], expectedH, expectedW) + } + + // Verify output is in valid range (should be clamped to [0, 1] by decode) + outData := out.Data() + for i := 0; i < min(100, len(outData)); i++ { + if math.IsNaN(float64(outData[i])) || math.IsInf(float64(outData[i]), 0) { + t.Errorf("output[%d] not finite: %v", i, outData[i]) + break + } + } +} diff --git a/x/imagegen/models/qwen_image_edit/layers.go b/x/imagegen/models/qwen_image_edit/layers.go new file mode 100644 index 000000000..04c192077 --- /dev/null +++ b/x/imagegen/models/qwen_image_edit/layers.go @@ -0,0 +1,682 @@ +//go:build mlx + +package qwen_image_edit + +import ( + "fmt" + "math" + + "github.com/ollama/ollama/x/imagegen/mlx" + "github.com/ollama/ollama/x/imagegen/safetensors" +) + +// CausalConv3d is a causal 3D convolution (for temporal causality) +type CausalConv3d struct { + Weight *mlx.Array + Bias *mlx.Array + BiasReshaped *mlx.Array // [1, C, 1, 1, 1] + KernelT int32 +} + +// newCausalConv3d creates a 3D causal conv +func newCausalConv3d(weights *safetensors.ModelWeights, prefix string) (*CausalConv3d, error) { + weight, err := weights.Get(prefix + ".weight") + if err != nil { + return nil, fmt.Errorf("weight not found: %s", prefix) + } + bias, _ := weights.Get(prefix + ".bias") + + kernelT := weight.Shape()[2] + outC := weight.Shape()[0] + + var biasReshaped *mlx.Array + if bias != nil { + biasReshaped = mlx.Reshape(bias, 1, outC, 1, 1, 1) + } + + return &CausalConv3d{ + Weight: weight, + Bias: bias, + BiasReshaped: biasReshaped, + KernelT: kernelT, + }, nil +} + +// Forward applies causal 3D convolution (or 2D if weight is 4D) +// x: [B, T, H, W, C] (channels-last, MLX format) +func (c *CausalConv3d) Forward(x *mlx.Array) *mlx.Array { + shape := c.Weight.Shape() + + // Handle both 5D (3D conv) and 4D (2D conv) weights + if len(shape) == 4 { + // 2D conv: [O, I, kH, kW] - need to apply per-frame + return c.forward2D(x) + } + + // 3D conv: [O, I, kT, kH, kW] + kernelT := shape[2] + kernelH := shape[3] + kernelW := shape[4] + + // Causal temporal padding, same spatial padding + padT := kernelT - 1 + padH := kernelH / 2 + padW := kernelW / 2 + + // Stage 1: Pad + { + x = pad3DChannelsLast(x, padT, 0, padH, padH, padW, padW) + mlx.Eval(x) + } + + // Stage 2: Conv + bias + var out *mlx.Array + { + prev := x + weight := mlx.Transpose(c.Weight, 0, 2, 3, 4, 1) + out = mlx.Conv3d(x, weight, 1, 1, 1, 0, 0, 0) + if c.Bias != nil { + bias := mlx.Reshape(c.Bias, 1, 1, 1, 1, c.Bias.Dim(0)) + out = mlx.Add(out, bias) + } + prev.Free() + mlx.Eval(out) + } + + return out +} + +// forward2D applies 2D conv per-frame for [B, T, H, W, C] input +func (c *CausalConv3d) forward2D(x *mlx.Array) *mlx.Array { + xShape := x.Shape() + B := xShape[0] + T := xShape[1] + H := xShape[2] + W := xShape[3] + C := xShape[4] + + wShape := c.Weight.Shape() // [O, I, kH, kW] + kernelH := wShape[2] + kernelW := wShape[3] + outC := wShape[0] + + padH := kernelH / 2 + padW := kernelW / 2 + + // Reshape to [B*T, H, W, C] for 2D conv + x = mlx.Reshape(x, B*T, H, W, C) + + // Pad spatially + x = mlx.Pad(x, []int32{0, 0, padH, padH, padW, padW, 0, 0}) + + // Apply 2D conv + weight := mlx.Transpose(c.Weight, 0, 2, 3, 1) // [O, I, kH, kW] -> [O, kH, kW, I] + x = mlx.Conv2d(x, weight, 1, 0) + + if c.Bias != nil { + bias := mlx.Reshape(c.Bias, 1, 1, 1, outC) + x = mlx.Add(x, bias) + } + + // Get output spatial dims + outH := H + outW := W + + // Reshape back to [B, T, H, W, C] + x = mlx.Reshape(x, B, T, outH, outW, outC) + mlx.Eval(x) + + return x +} + +// RMSNorm3D applies RMS normalization over channels +type RMSNorm3D struct { + Gamma *mlx.Array // [1, 1, 1, 1, C] for broadcasting +} + +// newRMSNorm3D creates an RMS norm +func newRMSNorm3D(weights *safetensors.ModelWeights, prefix string, dim int32) (*RMSNorm3D, error) { + gamma, err := weights.Get(prefix + ".gamma") + if err != nil { + return nil, err + } + gamma = mlx.Reshape(gamma, 1, 1, 1, 1, gamma.Dim(0)) + return &RMSNorm3D{Gamma: gamma}, nil +} + +// Forward applies RMS norm to channels-last input [B, T, H, W, C] +func (n *RMSNorm3D) Forward(x *mlx.Array) *mlx.Array { + normalized := mlx.RMSNormNoWeight(x, 1e-6) + return mlx.Mul(normalized, n.Gamma) +} + +// ResBlock is a residual block with RMS norm and causal convs +type ResBlock struct { + Norm1 *RMSNorm3D + Conv1 *CausalConv3d + Norm2 *RMSNorm3D + Conv2 *CausalConv3d + Shortcut *CausalConv3d +} + +// newResBlock creates a residual block +func newResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*ResBlock, error) { + norm1, err := newRMSNorm3D(weights, prefix+".norm1", inDim) + if err != nil { + return nil, err + } + conv1, err := newCausalConv3d(weights, prefix+".conv1") + if err != nil { + return nil, err + } + norm2, err := newRMSNorm3D(weights, prefix+".norm2", outDim) + if err != nil { + return nil, err + } + conv2, err := newCausalConv3d(weights, prefix+".conv2") + if err != nil { + return nil, err + } + + var shortcut *CausalConv3d + if inDim != outDim { + shortcut, err = newCausalConv3d(weights, prefix+".conv_shortcut") + if err != nil { + return nil, err + } + } + + return &ResBlock{ + Norm1: norm1, + Conv1: conv1, + Norm2: norm2, + Conv2: conv2, + Shortcut: shortcut, + }, nil +} + +// Forward applies the residual block +func (r *ResBlock) Forward(x *mlx.Array) *mlx.Array { + var h *mlx.Array + + mlx.Keep(x) + + // Stage 1: norm1 + silu + { + h = r.Norm1.Forward(x) + h = silu3D(h) + mlx.Eval(h) + } + + // Stage 2: conv1 + { + prev := h + h = r.Conv1.Forward(h) + prev.Free() + } + + // Stage 3: norm2 + silu + { + prev := h + h = r.Norm2.Forward(h) + h = silu3D(h) + prev.Free() + mlx.Eval(h) + } + + // Stage 4: conv2 + { + prev := h + h = r.Conv2.Forward(h) + prev.Free() + } + + // Residual connection + if r.Shortcut != nil { + shortcut := r.Shortcut.Forward(x) + h = mlx.Add(h, shortcut) + mlx.Eval(h) + } else { + h = mlx.Add(h, x) + mlx.Eval(h) + } + + return h +} + +// AttentionBlock is a 2D attention block +type AttentionBlock struct { + Norm *RMSNorm3D + ToQKV *mlx.Array + ToQKVBias *mlx.Array + Proj *mlx.Array + ProjBias *mlx.Array + Dim int32 +} + +// newAttentionBlock creates an attention block +func newAttentionBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*AttentionBlock, error) { + norm, err := newRMSNorm3D(weights, prefix+".norm", dim) + if err != nil { + return nil, err + } + toQKV, _ := weights.Get(prefix + ".to_qkv.weight") + toQKVBias, _ := weights.Get(prefix + ".to_qkv.bias") + proj, _ := weights.Get(prefix + ".proj.weight") + projBias, _ := weights.Get(prefix + ".proj.bias") + + return &AttentionBlock{ + Norm: norm, + ToQKV: toQKV, + ToQKVBias: toQKVBias, + Proj: proj, + ProjBias: projBias, + Dim: dim, + }, nil +} + +// Forward applies 2D attention +// Input: [B, T, H, W, C] (channels-last) +func (a *AttentionBlock) Forward(x *mlx.Array) *mlx.Array { + shape := x.Shape() + B := shape[0] + T := shape[1] + H := shape[2] + W := shape[3] + C := shape[4] + + identity := x + + // Flatten to [B*T, 1, H, W, C] for norm + x = mlx.Reshape(x, B*T, 1, H, W, C) + x = a.Norm.Forward(x) + x = mlx.Reshape(x, B*T, H, W, C) + + // Flatten spatial to [B*T, H*W, C] + x = mlx.Reshape(x, B*T, H*W, C) + + // Linear to get Q, K, V + wShape := a.ToQKV.Shape() + var w *mlx.Array + if len(wShape) == 4 { + w = mlx.Reshape(a.ToQKV, wShape[0], wShape[1]) + } else { + w = a.ToQKV + } + w = mlx.Transpose(w, 1, 0) + + qkv := mlx.Linear(x, w) + if a.ToQKVBias != nil { + qkv = mlx.Add(qkv, a.ToQKVBias) + } + qkv = mlx.Reshape(qkv, B*T, 1, H*W, 3*C) + + q := mlx.Slice(qkv, []int32{0, 0, 0, 0}, []int32{B * T, 1, H * W, C}) + k := mlx.Slice(qkv, []int32{0, 0, 0, C}, []int32{B * T, 1, H * W, 2 * C}) + v := mlx.Slice(qkv, []int32{0, 0, 0, 2 * C}, []int32{B * T, 1, H * W, 3 * C}) + + scale := float32(1.0 / math.Sqrt(float64(C))) + out := mlx.ScaledDotProductAttention(q, k, v, scale, false) + + out = mlx.Reshape(out, B*T, H*W, C) + + // Project back + pShape := a.Proj.Shape() + var p *mlx.Array + if len(pShape) == 4 { + p = mlx.Reshape(a.Proj, pShape[0], pShape[1]) + } else { + p = a.Proj + } + p = mlx.Transpose(p, 1, 0) + out = mlx.Linear(out, p) + if a.ProjBias != nil { + out = mlx.Add(out, a.ProjBias) + } + + out = mlx.Reshape(out, B, T, H, W, C) + return mlx.Add(out, identity) +} + +// UpBlock handles upsampling in decoder +type UpBlock struct { + ResBlocks []*ResBlock + Upsampler *Upsample +} + +// newUpBlock creates an up block +func newUpBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, upsampleMode string) (*UpBlock, error) { + resBlocks := make([]*ResBlock, numBlocks+1) + + currentDim := inDim + for i := int32(0); i <= numBlocks; i++ { + resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i) + block, err := newResBlock(weights, resPrefix, currentDim, outDim) + if err != nil { + return nil, err + } + resBlocks[i] = block + currentDim = outDim + } + + var upsampler *Upsample + if upsampleMode != "" { + upsampler = newUpsample(weights, prefix+".upsamplers.0", outDim, upsampleMode) + } + + return &UpBlock{ + ResBlocks: resBlocks, + Upsampler: upsampler, + }, nil +} + +// Forward applies up block +func (u *UpBlock) Forward(x *mlx.Array) *mlx.Array { + for _, block := range u.ResBlocks { + prev := x + x = block.Forward(x) + prev.Free() + } + + if u.Upsampler != nil { + prev := x + x = u.Upsampler.Forward(x) + prev.Free() + } + return x +} + +// Upsample handles spatial upsampling +type Upsample struct { + Conv *mlx.Array + Bias *mlx.Array + Mode string +} + +// newUpsample creates an upsampler +func newUpsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Upsample { + conv, _ := weights.Get(prefix + ".resample.1.weight") + bias, _ := weights.Get(prefix + ".resample.1.bias") + return &Upsample{ + Conv: conv, + Bias: bias, + Mode: mode, + } +} + +// Forward applies upsampling to channels-last input [B, T, H, W, C] +func (u *Upsample) Forward(x *mlx.Array) *mlx.Array { + shape := x.Shape() + B := shape[0] + T := shape[1] + H := shape[2] + W := shape[3] + C := shape[4] + outC := u.Conv.Shape()[0] + + // Stage 1: 2x nearest neighbor upsample + { + x = mlx.Reshape(x, B*T, H, W, C) + x = upsample2xChannelsLast(x) + mlx.Eval(x) + } + + // Stage 2: Conv + bias + { + prev := x + weight := mlx.Transpose(u.Conv, 0, 2, 3, 1) + x = conv2D3x3PaddedChannelsLast(x, weight) + if u.Bias != nil { + bias := mlx.Reshape(u.Bias, 1, 1, 1, outC) + x = mlx.Add(x, bias) + } + x = mlx.Reshape(x, B, T, H*2, W*2, outC) + prev.Free() + mlx.Eval(x) + } + + return x +} + +// MidBlock is the middle block +type MidBlock struct { + ResBlock1 *ResBlock + Attention *AttentionBlock + ResBlock2 *ResBlock +} + +// newMidBlock creates a mid block +func newMidBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*MidBlock, error) { + res1, err := newResBlock(weights, prefix+".resnets.0", dim, dim) + if err != nil { + return nil, err + } + attn, err := newAttentionBlock(weights, prefix+".attentions.0", dim) + if err != nil { + return nil, err + } + res2, err := newResBlock(weights, prefix+".resnets.1", dim, dim) + if err != nil { + return nil, err + } + + return &MidBlock{ + ResBlock1: res1, + Attention: attn, + ResBlock2: res2, + }, nil +} + +// Forward applies mid block +func (m *MidBlock) Forward(x *mlx.Array) *mlx.Array { + prev := x + x = m.ResBlock1.Forward(x) + prev.Free() + + prev = x + x = m.Attention.Forward(x) + prev.Free() + + prev = x + x = m.ResBlock2.Forward(x) + prev.Free() + + return x +} + +// Helper functions + +func silu3D(x *mlx.Array) *mlx.Array { + return mlx.Mul(x, mlx.Sigmoid(x)) +} + +// pad3DChannelsLast pads a channels-last [B, T, H, W, C] tensor +func pad3DChannelsLast(x *mlx.Array, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array { + if tBefore == 0 && tAfter == 0 && hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 { + return x + } + return mlx.Pad(x, []int32{0, 0, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter, 0, 0}) +} + +// upsample2xChannelsLast upsamples channels-last input [B, H, W, C] by 2x +func upsample2xChannelsLast(x *mlx.Array) *mlx.Array { + shape := x.Shape() + H := shape[1] + W := shape[2] + + rowIdxData := make([]int32, H*2) + for i := int32(0); i < H; i++ { + rowIdxData[i*2] = i + rowIdxData[i*2+1] = i + } + rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2}) + + colIdxData := make([]int32, W*2) + for i := int32(0); i < W; i++ { + colIdxData[i*2] = i + colIdxData[i*2+1] = i + } + colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2}) + + x = mlx.Take(x, rowIdx, 1) + x = mlx.Take(x, colIdx, 2) + + return x +} + +// conv2D3x3PaddedChannelsLast applies 3x3 conv with padding to channels-last input [B, H, W, C] +func conv2D3x3PaddedChannelsLast(x, weight *mlx.Array) *mlx.Array { + x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0}) + return mlx.Conv2d(x, weight, 1, 0) +} + +// conv2DStrided applies conv with stride > 1 using manual patch extraction +// x: [B, H, W, C] (channels-last), weight: [O, kH, kW, I] +func conv2DStrided(x, weight *mlx.Array, stride int32) *mlx.Array { + shape := x.Shape() + B := shape[0] + H := shape[1] + W := shape[2] + + wShape := weight.Shape() + Cout := wShape[0] + kH := wShape[1] + kW := wShape[2] + + outH := (H - kH) / stride + 1 + outW := (W - kW) / stride + 1 + + patches := extractPatches2DStrided(x, kH, kW, stride) + wFlat := mlx.Reshape(weight, Cout, -1) + patches = mlx.Reshape(patches, B*outH*outW, -1) + out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0)) + return mlx.Reshape(out, B, outH, outW, Cout) +} + +// conv3DStrided applies 3D conv with strides using manual patch extraction +// x: [B, T, H, W, C] (channels-last), weight: [O, I, kT, kH, kW] (PyTorch format) +// strideT, strideH, strideW are the strides for each dimension +// Patches are extracted in [C, T, H, W] order to match Python's preprocessing +func conv3DStrided(x, weight *mlx.Array, strideT, strideH, strideW int32) *mlx.Array { + shape := x.Shape() + B := shape[0] + T := shape[1] + H := shape[2] + W := shape[3] + C := shape[4] + + wShape := weight.Shape() + Cout := wShape[0] + // I := wShape[1] + kT := wShape[2] + kH := wShape[3] + kW := wShape[4] + + // For temporal: if T < kT, we need to repeat frames temporally + // For single image with T=1 and kT=2, we duplicate the frame to T=kT + // Python Qwen2.5-VL duplicates the frame, not zero-pads + if T < kT { + // Tile along T dimension: [B, T, H, W, C] -> [B, kT, H, W, C] + x = mlx.Tile(x, []int32{1, kT, 1, 1, 1}) + T = kT + } + + outT := (T - kT) / strideT + 1 + outH := (H - kH) / strideH + 1 + outW := (W - kW) / strideW + 1 + + // Extract 3D patches in [C, T, H, W] order to match Python + patches := extractPatches3DStrided(x, kT, kH, kW, strideT, strideH, strideW) + // patches shape: [B, outT, outH, outW, C*kT*kH*kW] + + // Weight is [O, I, kT, kH, kW] - flatten to [O, I*kT*kH*kW] to match patch order [C, T, H, W] + wFlat := mlx.Reshape(weight, Cout, -1) // [Cout, I*kT*kH*kW] + patches = mlx.Reshape(patches, B*outT*outH*outW, C*kT*kH*kW) + out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0)) + return mlx.Reshape(out, B, outT, outH, outW, Cout) +} + +// extractPatches3DStrided extracts 3D patches with given strides +// Returns patches with values in [C, T, H, W] order to match Python's preprocessing +func extractPatches3DStrided(x *mlx.Array, kT, kH, kW, strideT, strideH, strideW int32) *mlx.Array { + shape := x.Shape() + B := shape[0] + T := shape[1] + H := shape[2] + W := shape[3] + C := shape[4] + + outT := (T - kT) / strideT + 1 + outH := (H - kH) / strideH + 1 + outW := (W - kW) / strideW + 1 + + numPatches := outT * outH * outW + patches := make([]*mlx.Array, numPatches) + idx := 0 + for t := int32(0); t < outT; t++ { + for i := int32(0); i < outH; i++ { + for j := int32(0); j < outW; j++ { + startT := t * strideT + startH := i * strideH + startW := j * strideW + // Extract patch: [B, kT, kH, kW, C] + patch := mlx.Slice(x, + []int32{0, startT, startH, startW, 0}, + []int32{B, startT + kT, startH + kH, startW + kW, C}) + // Transpose from [B, T, H, W, C] to [B, C, T, H, W] to match Python's order + patch = mlx.Transpose(patch, 0, 4, 1, 2, 3) + // Flatten to [B, C*T*H*W] + patch = mlx.Reshape(patch, B, C*kT*kH*kW) + patches[idx] = patch + idx++ + } + } + } + + for i := range patches { + patches[i] = mlx.ExpandDims(patches[i], 1) + } + stacked := mlx.Concatenate(patches, 1) + return mlx.Reshape(stacked, B, outT, outH, outW, C*kT*kH*kW) +} + +// extractPatches2DStrided extracts patches with given stride +func extractPatches2DStrided(x *mlx.Array, kH, kW, stride int32) *mlx.Array { + shape := x.Shape() + B := shape[0] + H := shape[1] + W := shape[2] + C := shape[3] + + outH := (H - kH) / stride + 1 + outW := (W - kW) / stride + 1 + + patches := make([]*mlx.Array, outH*outW) + idx := 0 + for i := int32(0); i < outH; i++ { + for j := int32(0); j < outW; j++ { + startH := i * stride + startW := j * stride + patch := mlx.Slice(x, []int32{0, startH, startW, 0}, []int32{B, startH + kH, startW + kW, C}) + patch = mlx.Reshape(patch, B, kH*kW*C) + patches[idx] = patch + idx++ + } + } + + for i := range patches { + patches[i] = mlx.ExpandDims(patches[i], 1) + } + stacked := mlx.Concatenate(patches, 1) + return mlx.Reshape(stacked, B, outH, outW, kH*kW*C) +} + +// layerNormNoAffine applies layer norm without learnable parameters +func layerNormNoAffine(x *mlx.Array, eps float32) *mlx.Array { + ndim := x.Ndim() + lastAxis := ndim - 1 + mean := mlx.Mean(x, lastAxis, true) + xCentered := mlx.Sub(x, mean) + variance := mlx.Mean(mlx.Square(xCentered), lastAxis, true) + return mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, eps))) +} diff --git a/x/imagegen/models/qwen_image_edit/processor.go b/x/imagegen/models/qwen_image_edit/processor.go new file mode 100644 index 000000000..c80f5a3b1 --- /dev/null +++ b/x/imagegen/models/qwen_image_edit/processor.go @@ -0,0 +1,475 @@ +//go:build mlx + +package qwen_image_edit + +import ( + "fmt" + "image" + "image/color" + _ "image/jpeg" + _ "image/png" + "math" + "os" + + "github.com/ollama/ollama/x/imagegen/mlx" + "golang.org/x/image/draw" + _ "golang.org/x/image/webp" +) + +// loadImageFile loads an image from disk +func loadImageFile(path string) (image.Image, error) { + f, err := os.Open(path) + if err != nil { + return nil, fmt.Errorf("open image: %w", err) + } + defer f.Close() + + img, _, err := image.Decode(f) + if err != nil { + return nil, fmt.Errorf("decode image: %w", err) + } + return img, nil +} + +// imageToFloat32Pixels converts an image to a float32 pixel array [H, W, C] in [0, 1] range +func imageToFloat32Pixels(img image.Image, width, height int) []float32 { + pixels := make([]float32, width*height*3) + idx := 0 + for y := 0; y < height; y++ { + for x := 0; x < width; x++ { + r, g, b, _ := img.At(x, y).RGBA() + pixels[idx] = float32(r) / 65535.0 + pixels[idx+1] = float32(g) / 65535.0 + pixels[idx+2] = float32(b) / 65535.0 + idx += 3 + } + } + return pixels +} + +// normalizeImageNet applies ImageNet normalization to an image tensor +func (p *Processor) normalizeImageNet(arr *mlx.Array) *mlx.Array { + mean := mlx.NewArray(p.Config.ImageMean, []int32{1, 1, 3}) + std := mlx.NewArray(p.Config.ImageStd, []int32{1, 1, 3}) + return mlx.Div(mlx.Sub(arr, mean), std) +} + +// prepareImageTensor transforms [H, W, C] to [B, C, H, W] and converts to bf16 +func prepareImageTensor(arr *mlx.Array) *mlx.Array { + // Transpose to [C, H, W] and make contiguous + arr = mlx.Contiguous(mlx.Transpose(arr, 2, 0, 1)) + // Add batch dimension [1, C, H, W] + arr = mlx.ExpandDims(arr, 0) + // Convert to bf16 + arr = mlx.ToBFloat16(arr) + mlx.Eval(arr) + return arr +} + +// clampFloat clamps a value to [0, 255] and returns uint8 +func clampFloat(v, weightSum float64) uint8 { + v /= weightSum + if v < 0 { + v = 0 + } + if v > 255 { + v = 255 + } + return uint8(math.Round(v)) +} + +// ImageDims holds dimensions for a preprocessed image +type ImageDims struct { + // Original image dimensions + OrigW, OrigH int32 + // Condition image dimensions (for vision encoder) + CondW, CondH int32 + // VAE image dimensions + VaeW, VaeH int32 + // Latent dimensions (VAE dims / vae_scale_factor) + LatentW, LatentH int32 + // Patch dimensions (latent dims / patch_size) + PatchW, PatchH int32 +} + +// ProcessorConfig holds image processor configuration +type ProcessorConfig struct { + // Condition image size (target pixel area for vision encoder input) + // Python: CONDITION_IMAGE_SIZE = 384 * 384 = 147456 + // Pipeline resizes image to this area before passing to encode_prompt + ConditionImageSize int32 + + // VAE image size (target pixel area) + // Python: VAE_IMAGE_SIZE = 1024 * 1024 = 1048576 + VAEImageSize int32 + + // Image normalization (ImageNet stats for vision encoder) + ImageMean []float32 + ImageStd []float32 +} + +// defaultProcessorConfig returns default processor config +func defaultProcessorConfig() *ProcessorConfig { + return &ProcessorConfig{ + ConditionImageSize: 384 * 384, // 147456 - matches Python CONDITION_IMAGE_SIZE + VAEImageSize: 1024 * 1024, // 1048576 - matches Python VAE_IMAGE_SIZE + ImageMean: []float32{0.48145466, 0.4578275, 0.40821073}, + ImageStd: []float32{0.26862954, 0.26130258, 0.27577711}, + } +} + +// Processor handles image preprocessing for Qwen-Image-Edit +type Processor struct { + Config *ProcessorConfig +} + +// Load loads the processor config +func (p *Processor) Load(path string) error { + p.Config = defaultProcessorConfig() + return nil +} + +// LoadAndPreprocess loads an image and preprocesses it for both paths +// Returns: condImage (for vision encoder), vaeImage (for VAE encoding) +func (p *Processor) LoadAndPreprocess(imagePath string) (*mlx.Array, *mlx.Array, error) { + img, err := loadImageFile(imagePath) + if err != nil { + return nil, nil, err + } + + bounds := img.Bounds() + origW := bounds.Dx() + origH := bounds.Dy() + ratio := float64(origW) / float64(origH) + + // Calculate dimensions for condition image (vision encoder) + // Python pipeline does TWO resizes: + // 1. VaeImageProcessor.resize with Lanczos to CONDITION_IMAGE_SIZE (384x384 area) + // 2. Qwen2VLProcessor's smart_resize with Bicubic to multiple of 28 + intermediateW, intermediateH := calculateDimensions(p.Config.ConditionImageSize, ratio, 32) + finalH, finalW := smartResize(intermediateH, intermediateW, 28, 56*56, 28*28*1280) + + // Calculate dimensions for VAE image (1024x1024 area) + // Use multiple of 32 (vae_scale_factor * patch_size * 2 = 8 * 2 * 2 = 32) + vaeW, vaeH := calculateDimensions(p.Config.VAEImageSize, ratio, 32) + + // Preprocess for condition (vision encoder) - two-step resize + condImage := p.preprocessImageTwoStep(img, intermediateW, intermediateH, finalW, finalH) + + // Preprocess for VAE ([-1, 1] range, 5D tensor) + vaeImage := p.preprocessImageForVAE(img, vaeW, vaeH) + + return condImage, vaeImage, nil +} + +// preprocessImageLanczos does single-step Lanczos resize for vision encoder +// Matches Python VaeImageProcessor.resize with resample='lanczos' (the default) +// Used by edit_plus pipeline for multi-image input +// Returns: [B, C, H, W] normalized tensor +func (p *Processor) preprocessImageLanczos(img image.Image, width, height int32) *mlx.Array { + resized := resizeImageLanczos(img, int(width), int(height)) + pixels := imageToFloat32Pixels(resized, int(width), int(height)) + arr := mlx.NewArray(pixels, []int32{height, width, 3}) + arr = p.normalizeImageNet(arr) + return prepareImageTensor(arr) +} + +// preprocessImageTwoStep does two-step resize for vision encoder to match Python pipeline +// Step 1: Lanczos resize from original to intermediate size (VaeImageProcessor.resize) +// Step 2: Bicubic resize from intermediate to final size (Qwen2VLProcessor smart_resize) +// Returns: [B, C, H, W] normalized tensor +func (p *Processor) preprocessImageTwoStep(img image.Image, intermediateW, intermediateH, finalW, finalH int32) *mlx.Array { + intermediate := resizeImageLanczos(img, int(intermediateW), int(intermediateH)) + resized := resizeImageBicubic(intermediate, int(finalW), int(finalH)) + pixels := imageToFloat32Pixels(resized, int(finalW), int(finalH)) + arr := mlx.NewArray(pixels, []int32{finalH, finalW, 3}) + arr = p.normalizeImageNet(arr) + return prepareImageTensor(arr) +} + +// preprocessImage converts image to tensor for vision encoder +// Returns: [B, C, H, W] normalized tensor +func (p *Processor) preprocessImage(img image.Image, width, height int32, normalize bool) *mlx.Array { + resized := resizeImageBicubic(img, int(width), int(height)) + pixels := imageToFloat32Pixels(resized, int(width), int(height)) + arr := mlx.NewArray(pixels, []int32{height, width, 3}) + if normalize { + arr = p.normalizeImageNet(arr) + } + return prepareImageTensor(arr) +} + +// preprocessImageForVAE converts image to tensor for VAE encoding +// Returns: [B, C, T, H, W] tensor in [-1, 1] range +func (p *Processor) preprocessImageForVAE(img image.Image, width, height int32) *mlx.Array { + resized := resizeImageLanczos(img, int(width), int(height)) + pixels := imageToFloat32Pixels(resized, int(width), int(height)) + arr := mlx.NewArray(pixels, []int32{height, width, 3}) + + // Scale to [-1, 1]: arr * 2 - 1 + arr = mlx.MulScalar(arr, 2.0) + arr = mlx.AddScalar(arr, -1.0) + + // Transpose to [C, H, W] and make contiguous + arr = mlx.Contiguous(mlx.Transpose(arr, 2, 0, 1)) + + // Add batch and temporal dimensions [1, C, 1, H, W] + arr = mlx.ExpandDims(arr, 0) // [1, C, H, W] + arr = mlx.ExpandDims(arr, 2) // [1, C, 1, H, W] + + arr = mlx.ToBFloat16(arr) + mlx.Eval(arr) + return arr +} + +// smartResize implements Python Qwen2VL processor's smart_resize logic +// Returns (resizedHeight, resizedWidth) that fit within min/max pixel constraints +func smartResize(height, width, factor, minPixels, maxPixels int32) (int32, int32) { + // Round to factor + hBar := int32(math.Round(float64(height)/float64(factor))) * factor + wBar := int32(math.Round(float64(width)/float64(factor))) * factor + + // Ensure minimum factor size + if hBar < factor { + hBar = factor + } + if wBar < factor { + wBar = factor + } + + // Check pixel constraints + total := hBar * wBar + if total > maxPixels { + // Scale down + beta := math.Sqrt(float64(maxPixels) / float64(total)) + hBar = int32(math.Floor(float64(height)*beta/float64(factor))) * factor + wBar = int32(math.Floor(float64(width)*beta/float64(factor))) * factor + } else if total < minPixels { + // Scale up + beta := math.Sqrt(float64(minPixels) / float64(total)) + hBar = int32(math.Ceil(float64(height)*beta/float64(factor))) * factor + wBar = int32(math.Ceil(float64(width)*beta/float64(factor))) * factor + } + + return hBar, wBar +} + +// calculateDimensions calculates width and height for a target area while maintaining ratio +// multiple: the value to round dimensions to (e.g., 28 for vision encoder with patch 14 and 2x2 merge) +func calculateDimensions(targetArea int32, ratio float64, multiple int32) (int32, int32) { + width := math.Sqrt(float64(targetArea) * ratio) + height := width / ratio + + m := float64(multiple) + width = math.Round(width/m) * m + height = math.Round(height/m) * m + + // Ensure minimum dimensions + if width < m { + width = m + } + if height < m { + height = m + } + + return int32(width), int32(height) +} + +// resizeImageLanczos resizes an image using Lanczos3 interpolation (matches PIL.LANCZOS) +func resizeImageLanczos(img image.Image, width, height int) image.Image { + bounds := img.Bounds() + dst := image.NewRGBA(image.Rect(0, 0, width, height)) + + // Lanczos3 kernel (a=3) to match PIL.LANCZOS + lanczos3 := &draw.Kernel{ + Support: 3.0, + At: func(t float64) float64 { + if t == 0 { + return 1.0 + } + if t < 0 { + t = -t + } + if t >= 3.0 { + return 0.0 + } + // sinc(t) * sinc(t/3) + piT := math.Pi * t + return (math.Sin(piT) / piT) * (math.Sin(piT/3) / (piT / 3)) + }, + } + lanczos3.Scale(dst, dst.Bounds(), img, bounds, draw.Over, nil) + + return dst +} + +// resizeImageBicubic resizes an image using bicubic interpolation (matches PIL.BICUBIC) +// Uses separable interpolation with PIL's coordinate mapping for exact match +func resizeImageBicubic(img image.Image, width, height int) image.Image { + bounds := img.Bounds() + srcW := bounds.Dx() + srcH := bounds.Dy() + + // Convert to RGBA if needed + var src *image.RGBA + if rgba, ok := img.(*image.RGBA); ok { + src = rgba + } else { + src = image.NewRGBA(bounds) + for y := bounds.Min.Y; y < bounds.Max.Y; y++ { + for x := bounds.Min.X; x < bounds.Max.X; x++ { + src.Set(x, y, img.At(x, y)) + } + } + } + + // Keys cubic with a=-0.5 (PIL BICUBIC) + cubic := func(x float64) float64 { + if x < 0 { + x = -x + } + if x < 1 { + return 1.5*x*x*x - 2.5*x*x + 1 + } + if x < 2 { + return -0.5*x*x*x + 2.5*x*x - 4*x + 2 + } + return 0 + } + + // Horizontal pass: srcW -> width, keep srcH rows + temp := image.NewRGBA(image.Rect(0, 0, width, srcH)) + for y := 0; y < srcH; y++ { + for dstX := 0; dstX < width; dstX++ { + // PIL coordinate mapping: center-to-center + srcXf := (float64(dstX)+0.5)*(float64(srcW)/float64(width)) - 0.5 + baseX := int(math.Floor(srcXf)) + + var sumR, sumG, sumB, sumA, weightSum float64 + for i := -1; i <= 2; i++ { + sx := baseX + i + if sx < 0 { + sx = 0 + } + if sx >= srcW { + sx = srcW - 1 + } + + w := cubic(math.Abs(srcXf - float64(baseX+i))) + c := src.RGBAAt(sx, y) + sumR += float64(c.R) * w + sumG += float64(c.G) * w + sumB += float64(c.B) * w + sumA += float64(c.A) * w + weightSum += w + } + + temp.SetRGBA(dstX, y, color.RGBA{ + clampFloat(sumR, weightSum), + clampFloat(sumG, weightSum), + clampFloat(sumB, weightSum), + clampFloat(sumA, weightSum), + }) + } + } + + // Vertical pass: srcH -> height + dst := image.NewRGBA(image.Rect(0, 0, width, height)) + for x := 0; x < width; x++ { + for dstY := 0; dstY < height; dstY++ { + srcYf := (float64(dstY)+0.5)*(float64(srcH)/float64(height)) - 0.5 + baseY := int(math.Floor(srcYf)) + + var sumR, sumG, sumB, sumA, weightSum float64 + for j := -1; j <= 2; j++ { + sy := baseY + j + if sy < 0 { + sy = 0 + } + if sy >= srcH { + sy = srcH - 1 + } + + w := cubic(math.Abs(srcYf - float64(baseY+j))) + c := temp.RGBAAt(x, sy) + sumR += float64(c.R) * w + sumG += float64(c.G) * w + sumB += float64(c.B) * w + sumA += float64(c.A) * w + weightSum += w + } + + dst.SetRGBA(x, dstY, color.RGBA{ + clampFloat(sumR, weightSum), + clampFloat(sumG, weightSum), + clampFloat(sumB, weightSum), + clampFloat(sumA, weightSum), + }) + } + } + + return dst +} + +// LoadAndPreprocessMultiple loads multiple images and preprocesses them +// Returns: condImages (for vision encoder), vaeImages (for VAE encoding), dims (per-image dimensions) +func (p *Processor) LoadAndPreprocessMultiple(imagePaths []string) ([]*mlx.Array, []*mlx.Array, []ImageDims, error) { + const vaeScaleFactor int32 = 8 + const patchSize int32 = 2 + + condImages := make([]*mlx.Array, len(imagePaths)) + vaeImages := make([]*mlx.Array, len(imagePaths)) + dims := make([]ImageDims, len(imagePaths)) + + for i, imagePath := range imagePaths { + img, err := loadImageFile(imagePath) + if err != nil { + return nil, nil, nil, fmt.Errorf("image %d: %w", i, err) + } + + bounds := img.Bounds() + origW := int32(bounds.Dx()) + origH := int32(bounds.Dy()) + ratio := float64(origW) / float64(origH) + + // Calculate dimensions for condition image (vision encoder) + // Python pipeline does TWO resizes: + // 1. VaeImageProcessor.resize with Lanczos to CONDITION_IMAGE_SIZE (384x384 area) + // 2. Qwen2VLProcessor's smart_resize with Bicubic to multiple of 28 + intermediateW, intermediateH := calculateDimensions(p.Config.ConditionImageSize, ratio, 32) + condH, condW := smartResize(intermediateH, intermediateW, 28, 56*56, 28*28*1280) + + // Calculate dimensions for VAE image (1024x1024 area) + vaeW, vaeH := calculateDimensions(p.Config.VAEImageSize, ratio, 32) + + // Calculate derived dimensions + latentW := vaeW / vaeScaleFactor + latentH := vaeH / vaeScaleFactor + patchW := latentW / patchSize + patchH := latentH / patchSize + + dims[i] = ImageDims{ + OrigW: origW, + OrigH: origH, + CondW: condW, + CondH: condH, + VaeW: vaeW, + VaeH: vaeH, + LatentW: latentW, + LatentH: latentH, + PatchW: patchW, + PatchH: patchH, + } + + fmt.Printf(" Image %d: orig=%dx%d, cond=%dx%d, vae=%dx%d, latent=%dx%d, patch=%dx%d\n", + i+1, origW, origH, condW, condH, vaeW, vaeH, latentW, latentH, patchW, patchH) + + // Preprocess for condition (vision encoder) - two-step resize to match Python pipeline + condImages[i] = p.preprocessImageTwoStep(img, intermediateW, intermediateH, condW, condH) + + // Preprocess for VAE ([-1, 1] range, 5D tensor) + vaeImages[i] = p.preprocessImageForVAE(img, vaeW, vaeH) + } + + return condImages, vaeImages, dims, nil +} diff --git a/x/imagegen/models/qwen_image_edit/qwen_image_edit.go b/x/imagegen/models/qwen_image_edit/qwen_image_edit.go new file mode 100644 index 000000000..991205c96 --- /dev/null +++ b/x/imagegen/models/qwen_image_edit/qwen_image_edit.go @@ -0,0 +1,610 @@ +//go:build mlx + +// Package qwen_image_edit implements the Qwen-Image-Edit diffusion model for image editing. +// It reuses components from qwen_image where possible. +package qwen_image_edit + +import ( + "context" + "fmt" + "path/filepath" + "time" + + "github.com/ollama/ollama/x/imagegen/mlx" + "github.com/ollama/ollama/x/imagegen/models/qwen_image" + "github.com/ollama/ollama/x/imagegen/tokenizer" +) + +// GenerateConfig holds all options for image editing. +type GenerateConfig struct { + Prompt string + NegativePrompt string // Unconditional prompt for CFG (empty string "" is valid) + CFGScale float32 // CFG enabled when > 1.0 (default: 4.0) + Width int32 // Output width (default: from input image) + Height int32 // Output height (default: from input image) + Steps int // Denoising steps (default: 50) + Seed int64 // Random seed + Progress ProgressFunc // Optional progress callback +} + +// ProgressFunc is called during generation with step progress. +type ProgressFunc func(step, totalSteps int) + +// Model represents a Qwen-Image-Edit diffusion model. +type Model struct { + ModelPath string + Tokenizer *tokenizer.Tokenizer + Processor *Processor // Image processor for vision encoder + TextEncoder *qwen_image.Qwen25VL // Qwen2.5-VL vision-language encoder (from qwen_image) + Transformer *qwen_image.Transformer // Reuse qwen_image transformer + VAE *VAE // Combined encoder + decoder +} + +// Load loads the Qwen-Image-Edit model from a directory. +func (m *Model) Load(modelPath string) error { + fmt.Println("Loading Qwen-Image-Edit model...") + start := time.Now() + + if mlx.GPUIsAvailable() { + mlx.SetDefaultDeviceGPU() + mlx.EnableCompile() + } + + m.ModelPath = modelPath + + // Load tokenizer from processor directory + fmt.Print(" Loading tokenizer... ") + processorPath := filepath.Join(modelPath, "processor") + tok, err := tokenizer.Load(processorPath) + if err != nil { + // Fallback to tokenizer directory + tokenizerPath := filepath.Join(modelPath, "tokenizer") + tok, err = tokenizer.Load(tokenizerPath) + if err != nil { + return fmt.Errorf("tokenizer: %w", err) + } + } + m.Tokenizer = tok + fmt.Println("✓") + + // Load processor (image preprocessing config) + fmt.Print(" Loading processor... ") + m.Processor = &Processor{} + if err := m.Processor.Load(processorPath); err != nil { + return fmt.Errorf("processor: %w", err) + } + fmt.Println("✓") + + // Load vision-language text encoder (Qwen2.5-VL from qwen_image package) + m.TextEncoder = &qwen_image.Qwen25VL{} + if err := m.TextEncoder.Load(filepath.Join(modelPath, "text_encoder")); err != nil { + return fmt.Errorf("text encoder: %w", err) + } + mlx.Eval(mlx.Collect(m.TextEncoder)...) + fmt.Printf(" (%.1f GB, peak %.1f GB)\n", + float64(mlx.MetalGetActiveMemory())/(1024*1024*1024), + float64(mlx.MetalGetPeakMemory())/(1024*1024*1024)) + + // Load transformer (reuse qwen_image) + m.Transformer = &qwen_image.Transformer{} + if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil { + return fmt.Errorf("transformer: %w", err) + } + mlx.Eval(mlx.Collect(m.Transformer)...) + fmt.Printf(" (%.1f GB, peak %.1f GB)\n", + float64(mlx.MetalGetActiveMemory())/(1024*1024*1024), + float64(mlx.MetalGetPeakMemory())/(1024*1024*1024)) + + // Load VAE (encoder + decoder) + m.VAE = &VAE{} + if err := m.VAE.Load(filepath.Join(modelPath, "vae")); err != nil { + return fmt.Errorf("VAE: %w", err) + } + mlx.Eval(mlx.Collect(m.VAE)...) + fmt.Printf(" (%.1f GB, peak %.1f GB)\n", + float64(mlx.MetalGetActiveMemory())/(1024*1024*1024), + float64(mlx.MetalGetPeakMemory())/(1024*1024*1024)) + + mem := mlx.MetalGetActiveMemory() + peak := mlx.MetalGetPeakMemory() + fmt.Printf(" Loaded in %.2fs (%.1f GB active, %.1f GB peak)\n", + time.Since(start).Seconds(), + float64(mem)/(1024*1024*1024), + float64(peak)/(1024*1024*1024)) + + return nil +} + +// Edit edits an image based on a text prompt. +// inputImagePath: path to input image +// prompt: text description of desired edit +func (m *Model) Edit(inputImagePath string, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) { + return m.EditFromConfig([]string{inputImagePath}, &GenerateConfig{ + Prompt: prompt, + Width: width, + Height: height, + Steps: steps, + Seed: seed, + }) +} + +// EditFromConfig edits images using the unified config struct. +// Accepts one or more input images. +func (m *Model) EditFromConfig(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array, error) { + if len(inputImagePaths) == 0 { + return nil, fmt.Errorf("no input images provided") + } + + start := time.Now() + result, err := m.edit(inputImagePaths, cfg) + if err != nil { + return nil, err + } + + if cfg.NegativePrompt != "" { + fmt.Printf("Edited %d image(s) with CFG (scale=%.1f) in %.2fs (%d steps)\n", + len(inputImagePaths), cfg.CFGScale, time.Since(start).Seconds(), cfg.Steps) + } else { + fmt.Printf("Edited %d image(s) in %.2fs (%d steps)\n", + len(inputImagePaths), time.Since(start).Seconds(), cfg.Steps) + } + return result, nil +} + +// EditImage implements model.ImageEditModel interface. +func (m *Model) EditImage(ctx context.Context, inputImagePath, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) { + return m.Edit(inputImagePath, prompt, width, height, steps, seed) +} + +// EditMultiImage edits using multiple source images. +// This matches diffusers' QwenImageEditPlusPipeline behavior. +func (m *Model) EditMultiImage(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array, error) { + return m.EditFromConfig(inputImagePaths, cfg) +} + +// edit is the internal editing pipeline that handles one or more images. +func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array, error) { + // Apply defaults + if cfg.Steps <= 0 { + cfg.Steps = 50 + } + if cfg.CFGScale <= 0 { + cfg.CFGScale = 4.0 + } + + // Load and preprocess all input images + fmt.Printf("Loading %d image(s)...\n", len(inputImagePaths)) + condImages, vaeImages, inputDims, err := m.Processor.LoadAndPreprocessMultiple(inputImagePaths) + if err != nil { + return nil, fmt.Errorf("preprocess images: %w", err) + } + for _, img := range condImages { + mlx.Keep(img) + } + for _, img := range vaeImages { + mlx.Keep(img) + } + mlx.Eval(append(condImages, vaeImages...)...) + + useCFG := cfg.NegativePrompt != "" + tcfg := m.Transformer.Config + vaeScaleFactor := int32(8) + + // Output dimensions - if not specified, use first input image dimensions + if cfg.Width <= 0 { + cfg.Width = inputDims[0].VaeW + } + if cfg.Height <= 0 { + cfg.Height = inputDims[0].VaeH + } + + // Output (noise) latent dimensions + outLatentH := cfg.Height / vaeScaleFactor + outLatentW := cfg.Width / vaeScaleFactor + outPH := outLatentH / tcfg.PatchSize + outPW := outLatentW / tcfg.PatchSize + noiseSeqLen := outPH * outPW + imgSeqLen := noiseSeqLen + + // Encode prompt with all images for conditioning + posEmb, _, _, err := m.TextEncoder.EncodePromptWithImages(m.Tokenizer, cfg.Prompt, condImages) + if err != nil { + return nil, fmt.Errorf("encoding prompt: %w", err) + } + mlx.Keep(posEmb) + mlx.Eval(posEmb) + + var negEmb *mlx.Array + if useCFG { + negEmb, _, _, err = m.TextEncoder.EncodePromptWithImages(m.Tokenizer, cfg.NegativePrompt, condImages) + if err != nil { + return nil, fmt.Errorf("encoding negative prompt: %w", err) + } + mlx.Keep(negEmb) + mlx.Eval(negEmb) + } + + // Pad sequences to same length for CFG + txtLen := posEmb.Shape()[1] + if useCFG { + negLen := negEmb.Shape()[1] + if negLen > txtLen { + txtLen = negLen + } + if posEmb.Shape()[1] < txtLen { + posEmb = padSequence(posEmb, txtLen) + } + if negEmb.Shape()[1] < txtLen { + negEmb = padSequence(negEmb, txtLen) + } + mlx.Keep(posEmb, negEmb) + mlx.Eval(posEmb, negEmb) + } + + // Encode all input images to latents and concatenate + fmt.Println("Encoding images to latents...") + allImageLatentsPacked := make([]*mlx.Array, len(vaeImages)) + for i, vaeImage := range vaeImages { + imageLatents := m.VAE.Encode(vaeImage) + imageLatents = m.VAE.Normalize(imageLatents) + imageLatents2D := mlx.Squeeze(imageLatents, 2) + packed := qwen_image.PackLatents(imageLatents2D, tcfg.PatchSize) + mlx.Keep(packed) + mlx.Eval(packed) + allImageLatentsPacked[i] = packed + } + + imageLatentsPacked := mlx.Concatenate(allImageLatentsPacked, 1) + mlx.Keep(imageLatentsPacked) + mlx.Eval(imageLatentsPacked) + + // Scheduler + scheduler := qwen_image.NewFlowMatchScheduler(qwen_image.DefaultSchedulerConfig()) + scheduler.SetTimesteps(cfg.Steps, noiseSeqLen) + + // Init noise latents in packed format + packedChannels := tcfg.OutChannels * tcfg.PatchSize * tcfg.PatchSize + packedNoise := scheduler.InitNoisePacked(1, noiseSeqLen, packedChannels, cfg.Seed) + latents := qwen_image.UnpackLatents(packedNoise, outLatentH, outLatentW, tcfg.PatchSize) + mlx.Eval(latents) + + // RoPE cache + ropeCache := PrepareRoPEMultiImage(outPH, outPW, inputDims, txtLen, tcfg.AxesDimsRope) + mlx.Keep(ropeCache.ImgFreqs, ropeCache.TxtFreqs) + mlx.Eval(ropeCache.ImgFreqs, ropeCache.TxtFreqs) + + // Denoising loop + fmt.Printf("Running denoising (%d steps)...\n", cfg.Steps) + for i := 0; i < cfg.Steps; i++ { + stepStart := time.Now() + if cfg.Progress != nil { + cfg.Progress(i+1, cfg.Steps) + } + + t := scheduler.Timesteps[i] + timestep := mlx.ToBFloat16(mlx.NewArray([]float32{t}, []int32{1})) + mlx.Eval(timestep) + + latents2D := mlx.Squeeze(latents, 2) + patches := qwen_image.PackLatents(latents2D, tcfg.PatchSize) + latentInput := mlx.Concatenate([]*mlx.Array{patches, imageLatentsPacked}, 1) + + var output *mlx.Array + if useCFG { + posOutput := m.Transformer.Forward(latentInput, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs) + negOutput := m.Transformer.Forward(latentInput, negEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs) + + posOutput = mlx.Slice(posOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, posOutput.Shape()[2]}) + negOutput = mlx.Slice(negOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, negOutput.Shape()[2]}) + + output = applyCFGWithNormRescale(posOutput, negOutput, cfg.CFGScale) + } else { + output = m.Transformer.Forward(latentInput, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs) + output = mlx.Slice(output, []int32{0, 0, 0}, []int32{1, imgSeqLen, output.Shape()[2]}) + } + + noisePred := qwen_image.UnpackLatents(output, outLatentH, outLatentW, tcfg.PatchSize) + oldLatents := latents + latents = scheduler.Step(noisePred, latents, i) + mlx.Eval(latents) + oldLatents.Free() + + fmt.Printf(" Step %d/%d: t=%.4f (%.2fs)\n", i+1, cfg.Steps, t, time.Since(stepStart).Seconds()) + } + + // Free denoising temporaries + posEmb.Free() + if negEmb != nil { + negEmb.Free() + } + ropeCache.ImgFreqs.Free() + ropeCache.TxtFreqs.Free() + imageLatentsPacked.Free() + + // Decode latents + decoded := m.decodeAndPostprocess(latents) + latents.Free() + + fmt.Printf(" Peak memory: %.2f GB\n", float64(mlx.MetalGetPeakMemory())/(1024*1024*1024)) + return decoded, nil +} + +// applyCFGWithNormRescale applies classifier-free guidance with norm rescaling. +// This prevents CFG from inflating magnitude too much. +func applyCFGWithNormRescale(posOutput, negOutput *mlx.Array, scale float32) *mlx.Array { + // Upcast to float32 for precision + posF32 := mlx.AsType(posOutput, mlx.DtypeFloat32) + negF32 := mlx.AsType(negOutput, mlx.DtypeFloat32) + + // CFG: pred = neg + scale * (pos - neg) + diff := mlx.Sub(posF32, negF32) + scaledDiff := mlx.MulScalar(diff, scale) + combPred := mlx.Add(negF32, scaledDiff) + + // Norm rescaling: rescale combined prediction to match conditional norm + condNorm := mlx.Sqrt(mlx.Sum(mlx.Square(posF32), -1, true)) + combNorm := mlx.Sqrt(mlx.Sum(mlx.Square(combPred), -1, true)) + output := mlx.Mul(combPred, mlx.Div(condNorm, combNorm)) + + mlx.Eval(output) + return mlx.ToBFloat16(output) +} + +// decodeAndPostprocess denormalizes latents, decodes through VAE, and scales to [0,1]. +func (m *Model) decodeAndPostprocess(latents *mlx.Array) *mlx.Array { + latents = m.VAE.Denormalize(latents) + decoded := m.VAE.Decode(latents) + + // Post-process: squeeze temporal dim and rescale to [0, 1] + decoded = mlx.Squeeze(decoded, 2) + decoded = mlx.AddScalar(decoded, 1.0) + decoded = mlx.DivScalar(decoded, 2.0) + decoded = mlx.ClipScalar(decoded, 0.0, 1.0, true, true) + mlx.Eval(decoded) + return decoded +} + +// padSequence pads a sequence tensor to the target length with zeros +func padSequence(x *mlx.Array, targetLen int32) *mlx.Array { + shape := x.Shape() + currentLen := shape[1] + if currentLen >= targetLen { + return x + } + padLen := targetLen - currentLen + // Pad on sequence dimension (axis 1) + return mlx.Pad(x, []int32{0, 0, 0, padLen, 0, 0}) +} + +// LoadPersistent is an alias for backward compatibility. +func LoadPersistent(modelPath string) (*Model, error) { + m := &Model{} + if err := m.Load(modelPath); err != nil { + return nil, err + } + return m, nil +} + +// PrepareRoPEMultiImage computes RoPE with interpolation for image editing. +// Handles single or multiple input images with different resolutions. +// +// Parameters: +// - outPH, outPW: output patch dimensions (noise latent resolution) +// - inputDims: patch dimensions for each input image [(pH1, pW1), (pH2, pW2), ...] +// - txtLen: text sequence length +// - axesDims: RoPE axis dimensions [16, 56, 56] +// +// Returns RoPE cache where: +// - ImgFreqs has (outPH*outPW + sum(inPH*inPW for each image)) positions +// - First outPH*outPW positions are for noise latents (standard RoPE at output res) +// - Following positions are for each input image (interpolated from output res) +func PrepareRoPEMultiImage(outPH, outPW int32, inputDims []ImageDims, txtLen int32, axesDims []int32) *qwen_image.RoPECache { + theta := float64(10000) + maxIdx := int32(4096) + + // Compute base frequencies for each axis dimension + freqsT := qwen_image.ComputeAxisFreqs(axesDims[0], theta) + freqsH := qwen_image.ComputeAxisFreqs(axesDims[1], theta) + freqsW := qwen_image.ComputeAxisFreqs(axesDims[2], theta) + + // Build frequency lookup tables + posFreqsT := qwen_image.MakeFreqTable(maxIdx, freqsT, false) + posFreqsH := qwen_image.MakeFreqTable(maxIdx, freqsH, false) + posFreqsW := qwen_image.MakeFreqTable(maxIdx, freqsW, false) + negFreqsT := qwen_image.MakeFreqTable(maxIdx, freqsT, true) // For frame -1 on last condition image + negFreqsH := qwen_image.MakeFreqTable(maxIdx, freqsH, true) + negFreqsW := qwen_image.MakeFreqTable(maxIdx, freqsW, true) + + headDim := int32(len(freqsT)+len(freqsH)+len(freqsW)) * 2 + + // Helper to compute RoPE for a single position at output resolution with scale_rope + computePosFreqs := func(framePos, y, x int32) []float32 { + row := make([]float32, headDim) + idx := 0 + + // Frame position + for i := 0; i < len(freqsT)*2; i++ { + row[idx+i] = posFreqsT[framePos][i] + } + idx += len(freqsT) * 2 + + // Height with scale_rope centering (using OUTPUT dimensions) + outHHalf := outPH / 2 + hNegCount := outPH - outHHalf + if y < hNegCount { + negTableIdx := maxIdx - hNegCount + y + for i := 0; i < len(freqsH)*2; i++ { + row[idx+i] = negFreqsH[negTableIdx][i] + } + } else { + posIdx := y - hNegCount + for i := 0; i < len(freqsH)*2; i++ { + row[idx+i] = posFreqsH[posIdx][i] + } + } + idx += len(freqsH) * 2 + + // Width with scale_rope centering (using OUTPUT dimensions) + outWHalf := outPW / 2 + wNegCount := outPW - outWHalf + if x < wNegCount { + negTableIdx := maxIdx - wNegCount + x + for i := 0; i < len(freqsW)*2; i++ { + row[idx+i] = negFreqsW[negTableIdx][i] + } + } else { + posIdx := x - wNegCount + for i := 0; i < len(freqsW)*2; i++ { + row[idx+i] = posFreqsW[posIdx][i] + } + } + + return row + } + + // Helper to compute RoPE for frame -1 (used for last condition image) + // This matches Python's _compute_condition_freqs which uses freqs_neg[0][-1:] + computeNegFrameFreqs := func(y, x int32) []float32 { + row := make([]float32, headDim) + idx := 0 + + // Frame -1: use last row of negative frame frequencies + negFrameIdx := maxIdx - 1 + for i := 0; i < len(freqsT)*2; i++ { + row[idx+i] = negFreqsT[negFrameIdx][i] + } + idx += len(freqsT) * 2 + + // Height with scale_rope centering (using OUTPUT dimensions) + outHHalf := outPH / 2 + hNegCount := outPH - outHHalf + if y < hNegCount { + negTableIdx := maxIdx - hNegCount + y + for i := 0; i < len(freqsH)*2; i++ { + row[idx+i] = negFreqsH[negTableIdx][i] + } + } else { + posIdx := y - hNegCount + for i := 0; i < len(freqsH)*2; i++ { + row[idx+i] = posFreqsH[posIdx][i] + } + } + idx += len(freqsH) * 2 + + // Width with scale_rope centering (using OUTPUT dimensions) + outWHalf := outPW / 2 + wNegCount := outPW - outWHalf + if x < wNegCount { + negTableIdx := maxIdx - wNegCount + x + for i := 0; i < len(freqsW)*2; i++ { + row[idx+i] = negFreqsW[negTableIdx][i] + } + } else { + posIdx := x - wNegCount + for i := 0; i < len(freqsW)*2; i++ { + row[idx+i] = posFreqsW[posIdx][i] + } + } + + return row + } + + // Total image sequence length: noise + all input images + noiseSeqLen := outPH * outPW + totalImgLen := noiseSeqLen + for _, dims := range inputDims { + totalImgLen += dims.PatchH * dims.PatchW + } + + imgFreqsData := make([]float32, totalImgLen*headDim) + idx := int32(0) + + // Segment 0: Noise latents - standard RoPE at output resolution (frame 0) + for y := int32(0); y < outPH; y++ { + for x := int32(0); x < outPW; x++ { + row := computePosFreqs(0, y, x) + copy(imgFreqsData[idx:], row) + idx += headDim + } + } + + // Segments 1..N: Edit image latents - INTERPOLATED RoPE + // For single image: use frame 1 (matches original PrepareRoPEInterpolated) + // For multiple images: Python uses frame -1 for the LAST condition image + // (_compute_condition_freqs), positive indices for others. + numImages := len(inputDims) + lastImgIdx := numImages - 1 + for imgIdx, dims := range inputDims { + inPH := dims.PatchH + inPW := dims.PatchW + + // Determine frame index for this image + // Single image case: use frame 1 (like original PrepareRoPEInterpolated) + // Multi-image case: last image uses frame -1, others use frame 1, 2, etc. + useNegFrame := numImages > 1 && imgIdx == lastImgIdx + + // Map each input position to an output position using linear interpolation + for y := int32(0); y < inPH; y++ { + for x := int32(0); x < inPW; x++ { + // Interpolate: map input (y, x) to output grid position + // This is the key fix from DiffSynth's forward_sampling + var yOut, xOut int32 + if inPH == 1 { + yOut = 0 + } else { + // Linear interpolation: y_out = y * (outPH - 1) / (inPH - 1) + yOut = y * (outPH - 1) / (inPH - 1) + } + if inPW == 1 { + xOut = 0 + } else { + xOut = x * (outPW - 1) / (inPW - 1) + } + + var row []float32 + if useNegFrame { + // Last image in multi-image uses frame -1 + row = computeNegFrameFreqs(yOut, xOut) + } else { + // Single image uses frame 1, multi-image uses frame 1, 2, etc. + frameIdx := int32(imgIdx + 1) + row = computePosFreqs(frameIdx, yOut, xOut) + } + copy(imgFreqsData[idx:], row) + idx += headDim + } + } + } + + imgFreqs := mlx.NewArray(imgFreqsData, []int32{totalImgLen, headDim}) + imgFreqs = mlx.ToBFloat16(imgFreqs) + + // Text frequencies - start after max video index + maxVidIdx := max(outPH/2, outPW/2) + + txtFreqsData := make([]float32, txtLen*headDim) + idx = 0 + for t := int32(0); t < txtLen; t++ { + pos := maxVidIdx + t + for i := 0; i < len(freqsT)*2; i++ { + txtFreqsData[idx+int32(i)] = posFreqsT[pos][i] + } + idx += int32(len(freqsT) * 2) + for i := 0; i < len(freqsH)*2; i++ { + txtFreqsData[idx+int32(i)] = posFreqsH[pos][i] + } + idx += int32(len(freqsH) * 2) + for i := 0; i < len(freqsW)*2; i++ { + txtFreqsData[idx+int32(i)] = posFreqsW[pos][i] + } + idx += int32(len(freqsW) * 2) + } + + txtFreqs := mlx.NewArray(txtFreqsData, []int32{txtLen, headDim}) + txtFreqs = mlx.ToBFloat16(txtFreqs) + + return &qwen_image.RoPECache{ + ImgFreqs: imgFreqs, + TxtFreqs: txtFreqs, + } +} diff --git a/x/imagegen/models/qwen_image_edit/rope_test.go b/x/imagegen/models/qwen_image_edit/rope_test.go new file mode 100644 index 000000000..7da4eaa3c --- /dev/null +++ b/x/imagegen/models/qwen_image_edit/rope_test.go @@ -0,0 +1,227 @@ +//go:build mlx + +package qwen_image_edit + +import ( + "math" + "testing" + + "github.com/ollama/ollama/x/imagegen/mlx" + "github.com/ollama/ollama/x/imagegen/models/qwen_image" +) + +// TestComputeAxisFreqs verifies frequency computation matches Python reference +func TestComputeAxisFreqs(t *testing.T) { + theta := float64(10000) + + // Expected values from Python: + // freqs = 1.0 / (theta ** (np.arange(0, half_dim) / half_dim)) + expectedFreqsT := []float64{ + 1.000000000000000, 0.316227766016838, 0.100000000000000, 0.031622776601684, + 0.010000000000000, 0.003162277660168, 0.001000000000000, 0.000316227766017, + } + + expectedFreqsH_first4 := []float64{ + 1.000000000000000, 0.719685673001152, 0.517947467923121, 0.372759372031494, + } + + expectedFreqsH_last4 := []float64{ + 0.000372759372031, 0.000268269579528, 0.000193069772888, 0.000138949549437, + } + + // Test temporal frequencies (dim=16) + freqsT := qwen_image.ComputeAxisFreqs(16, theta) + if len(freqsT) != 8 { + t.Fatalf("expected 8 temporal frequencies, got %d", len(freqsT)) + } + for i, expected := range expectedFreqsT { + if diff := math.Abs(freqsT[i] - expected); diff > 1e-10 { + t.Errorf("freqsT[%d]: expected %.15f, got %.15f, diff %.2e", i, expected, freqsT[i], diff) + } + } + + // Test height/width frequencies (dim=56) + freqsH := qwen_image.ComputeAxisFreqs(56, theta) + if len(freqsH) != 28 { + t.Fatalf("expected 28 height frequencies, got %d", len(freqsH)) + } + for i, expected := range expectedFreqsH_first4 { + if diff := math.Abs(freqsH[i] - expected); diff > 1e-10 { + t.Errorf("freqsH[%d]: expected %.15f, got %.15f, diff %.2e", i, expected, freqsH[i], diff) + } + } + for i, expected := range expectedFreqsH_last4 { + idx := 24 + i // last 4 of 28 + if diff := math.Abs(freqsH[idx] - expected); diff > 1e-10 { + t.Errorf("freqsH[%d]: expected %.15f, got %.15f, diff %.2e", idx, expected, freqsH[idx], diff) + } + } +} + +// TestMakeFreqTable verifies the frequency lookup table for both positive and negative positions +func TestMakeFreqTable(t *testing.T) { + theta := float64(10000) + freqsT := qwen_image.ComputeAxisFreqs(16, theta) + maxIdx := int32(4096) + + // Test positive table + posTable := qwen_image.MakeFreqTable(maxIdx, freqsT, false) + + // Position 0 should give cos=1, sin=0 for all frequencies + for i := 0; i < len(freqsT)*2; i += 2 { + if posTable[0][i] != 1.0 { + t.Errorf("posTable[0][%d] (cos): expected 1.0, got %f", i, posTable[0][i]) + } + if posTable[0][i+1] != 0.0 { + t.Errorf("posTable[0][%d] (sin): expected 0.0, got %f", i+1, posTable[0][i+1]) + } + } + + // Position 1, first frequency (1.0): angle = 1*1 = 1 + // cos(1) = 0.5403, sin(1) = 0.8415 + if diff := math.Abs(float64(posTable[1][0]) - 0.5403023058681398); diff > 1e-6 { + t.Errorf("posTable[1][0] (cos): expected 0.5403, got %f", posTable[1][0]) + } + if diff := math.Abs(float64(posTable[1][1]) - 0.8414709848078965); diff > 1e-6 { + t.Errorf("posTable[1][1] (sin): expected 0.8415, got %f", posTable[1][1]) + } + + // Test negative table + negTable := qwen_image.MakeFreqTable(maxIdx, freqsT, true) + + // negTable[4095] corresponds to position -1 + // cos(-1) = cos(1), sin(-1) = -sin(1) + if diff := math.Abs(float64(negTable[4095][0]) - 0.5403023058681398); diff > 1e-6 { + t.Errorf("negTable[4095][0] (cos(-1)): expected 0.5403, got %f", negTable[4095][0]) + } + if diff := math.Abs(float64(negTable[4095][1]) - (-0.8414709848078965)); diff > 1e-6 { + t.Errorf("negTable[4095][1] (sin(-1)): expected -0.8415, got %f", negTable[4095][1]) + } + + // negTable[4094] corresponds to position -2 + // cos(-2) = cos(2), sin(-2) = -sin(2) + cos2 := math.Cos(2.0) + sin2 := math.Sin(2.0) + if diff := math.Abs(float64(negTable[4094][0]) - cos2); diff > 1e-6 { + t.Errorf("negTable[4094][0] (cos(-2)): expected %f, got %f", cos2, negTable[4094][0]) + } + if diff := math.Abs(float64(negTable[4094][1]) - (-sin2)); diff > 1e-6 { + t.Errorf("negTable[4094][1] (sin(-2)): expected %f, got %f", -sin2, negTable[4094][1]) + } +} + +// TestPrepareRoPE_QwenImage verifies qwen_image.PrepareRoPE for single-segment case +func TestPrepareRoPE_QwenImage(t *testing.T) { + if !mlx.GPUIsAvailable() { + t.Skip("GPU not available") + } + + mlx.SetDefaultDeviceCPU() + + // 4x4 patch grid, single image + imgH, imgW := int32(4), int32(4) + txtLen := int32(5) + axesDims := []int32{16, 56, 56} + + cache := qwen_image.PrepareRoPE(imgH, imgW, txtLen, axesDims) + mlx.Eval(cache.ImgFreqs, cache.TxtFreqs) + + // Check shapes + imgShape := cache.ImgFreqs.Shape() + if imgShape[0] != 16 { // 4*4 patches + t.Errorf("ImgFreqs seq len: expected 16, got %d", imgShape[0]) + } + + // For single image (frame=0), all temporal values should be cos=1, sin=0 + imgFreqsCPU := mlx.AsType(cache.ImgFreqs, mlx.DtypeFloat32) + mlx.Eval(imgFreqsCPU) + imgData := imgFreqsCPU.Data() + + // Check first 16 values of patch 0 (temporal cos/sin pairs) + for i := 0; i < 16; i += 2 { + cosVal := imgData[i] + sinVal := imgData[i+1] + if diff := math.Abs(float64(cosVal - 1.0)); diff > 1e-5 { + t.Errorf("ImgFreqs[0][%d] (cos): expected 1.0, got %f", i, cosVal) + } + if diff := math.Abs(float64(sinVal - 0.0)); diff > 1e-5 { + t.Errorf("ImgFreqs[0][%d] (sin): expected 0.0, got %f", i+1, sinVal) + } + } + + cache.ImgFreqs.Free() + cache.TxtFreqs.Free() +} + +// TestScaleRopePositions verifies the centered position calculation for scale_rope=True +func TestScaleRopePositions(t *testing.T) { + // For a 4x4 grid with scale_rope=True: + // hHalf = 2, wHalf = 2 + // hNegCount = 4 - 2 = 2 (positions 0,1 are negative) + // wNegCount = 4 - 2 = 2 (positions 0,1 are negative) + // + // Height positions: + // y=0: -(4-2) + 0 = -2 + // y=1: -(4-2) + 1 = -1 + // y=2: 2 - 2 = 0 + // y=3: 3 - 2 = 1 + // + // Same for width + + pH, pW := int32(4), int32(4) + hHalf := pH / 2 + wHalf := pW / 2 + hNegCount := pH - hHalf + wNegCount := pW - wHalf + + expectedH := []int32{-2, -1, 0, 1} + expectedW := []int32{-2, -1, 0, 1} + + for y := int32(0); y < pH; y++ { + var hPos int32 + if y < hNegCount { + hPos = -(pH - hHalf) + y + } else { + hPos = y - hNegCount + } + if hPos != expectedH[y] { + t.Errorf("y=%d: expected h_pos=%d, got %d", y, expectedH[y], hPos) + } + } + + for x := int32(0); x < pW; x++ { + var wPos int32 + if x < wNegCount { + wPos = -(pW - wHalf) + x + } else { + wPos = x - wNegCount + } + if wPos != expectedW[x] { + t.Errorf("x=%d: expected w_pos=%d, got %d", x, expectedW[x], wPos) + } + } +} + +// TestRoPEHeadDimensions verifies the head dimension breakdown +func TestRoPEHeadDimensions(t *testing.T) { + // axes_dims_rope = [16, 56, 56] + // Each dimension uses half the values for frequencies + // So we get: 8 + 28 + 28 = 64 frequency values + // Each frequency produces cos + sin, so: 64 * 2 = 128 total values per position + + axesDims := []int32{16, 56, 56} + expectedFreqs := (axesDims[0]/2 + axesDims[1]/2 + axesDims[2]/2) + expectedHeadDim := expectedFreqs * 2 + + if expectedFreqs != 64 { + t.Errorf("expected 64 frequency values, got %d", expectedFreqs) + } + if expectedHeadDim != 128 { + t.Errorf("expected head_dim=128, got %d", expectedHeadDim) + } + + // This should match the transformer's attention head dimension + // hidden_size = 3072, num_heads = 24 + // head_dim = 3072 / 24 = 128 +} + diff --git a/x/imagegen/models/qwen_image_edit/vae.go b/x/imagegen/models/qwen_image_edit/vae.go new file mode 100644 index 000000000..3dbe7ef3c --- /dev/null +++ b/x/imagegen/models/qwen_image_edit/vae.go @@ -0,0 +1,642 @@ +//go:build mlx + +package qwen_image_edit + +import ( + "fmt" + + "github.com/ollama/ollama/x/imagegen/mlx" + "github.com/ollama/ollama/x/imagegen/safetensors" +) + +// VAEConfig holds Qwen-Image VAE configuration +type VAEConfig struct { + ZDim int32 `json:"z_dim"` // 16 + BaseDim int32 `json:"base_dim"` // 96 + DimMult []int32 `json:"dim_mult"` // [1, 2, 4, 4] + NumResBlocks int32 `json:"num_res_blocks"` // 2 + LatentsMean []float32 `json:"latents_mean"` // 16 values + LatentsStd []float32 `json:"latents_std"` // 16 values + TemperalDownsample []bool `json:"temperal_downsample"` // [false, true, true] +} + +// defaultVAEConfig returns config for Qwen-Image VAE +func defaultVAEConfig() *VAEConfig { + return &VAEConfig{ + ZDim: 16, + BaseDim: 96, + DimMult: []int32{1, 2, 4, 4}, + NumResBlocks: 2, + LatentsMean: []float32{ + -0.7571, -0.7089, -0.9113, 0.1075, + -0.1745, 0.9653, -0.1517, 1.5508, + 0.4134, -0.0715, 0.5517, -0.3632, + -0.1922, -0.9497, 0.2503, -0.2921, + }, + LatentsStd: []float32{ + 2.8184, 1.4541, 2.3275, 2.6558, + 1.2196, 1.7708, 2.6052, 2.0743, + 3.2687, 2.1526, 2.8652, 1.5579, + 1.6382, 1.1253, 2.8251, 1.916, + }, + TemperalDownsample: []bool{false, true, true}, + } +} + +// VAE is the full VAE with encoder and decoder +type VAE struct { + Config *VAEConfig + Encoder *VAEEncoder + Decoder *VAEDecoder +} + +// Load loads the VAE from a directory +func (m *VAE) Load(path string) error { + fmt.Println("Loading Qwen-Image-Edit VAE (encoder + decoder)...") + + cfg := defaultVAEConfig() + m.Config = cfg + + weights, err := safetensors.LoadModelWeights(path) + if err != nil { + return fmt.Errorf("weights: %w", err) + } + + // Load weights as f32 for quality (matches Python default behavior) + // VAE decoder precision is critical for final image quality + fmt.Print(" Loading weights as f32... ") + if err := weights.Load(mlx.DtypeFloat32); err != nil { + return fmt.Errorf("failed to load weights: %w", err) + } + fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024)) + + // Load encoder + fmt.Print(" Loading encoder... ") + m.Encoder = &VAEEncoder{} + if err := m.Encoder.loadFromWeights(weights, cfg); err != nil { + return fmt.Errorf("encoder: %w", err) + } + fmt.Println("✓") + + // Load decoder + fmt.Print(" Loading decoder... ") + m.Decoder = &VAEDecoder{} + if err := m.Decoder.loadFromWeights(weights, cfg); err != nil { + return fmt.Errorf("decoder: %w", err) + } + fmt.Println("✓") + + weights.ReleaseAll() + return nil +} + +// Encode encodes an image to latents +// x: [B, C, T, H, W] image tensor in [-1, 1] range +// Returns: [B, C, T, H/8, W/8] latents (unnormalized) +func (m *VAE) Encode(x *mlx.Array) *mlx.Array { + return m.Encoder.Encode(x) +} + +// Decode decodes latents to image +// z: [B, C, T, H, W] latents (denormalized) +// Returns: [B, C, T, H*8, W*8] image in [-1, 1] +func (m *VAE) Decode(z *mlx.Array) *mlx.Array { + return m.Decoder.Decode(z) +} + +// Normalize applies latent normalization +// Input z should be f32 (from VAE encoder), output is f32 for transformer +func (m *VAE) Normalize(z *mlx.Array) *mlx.Array { + shape := z.Shape() + C := shape[1] + + mean := mlx.NewArray(m.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1}) + std := mlx.NewArray(m.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1}) + + // Mean/std are f32, will match z dtype through broadcasting + return mlx.Div(mlx.Sub(z, mean), std) +} + +// Denormalize reverses latent normalization +// Input z is bf16 (from transformer), output converted to f32 for VAE decoder +func (m *VAE) Denormalize(z *mlx.Array) *mlx.Array { + shape := z.Shape() + C := shape[1] + + // Convert latents to f32 for VAE decoder quality + z = mlx.AsType(z, mlx.DtypeFloat32) + + mean := mlx.NewArray(m.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1}) + std := mlx.NewArray(m.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1}) + + return mlx.Add(mlx.Mul(z, std), mean) +} + +// VAEEncoder is the encoder part of the VAE +// The encoder uses a flat structure where down_blocks contains a mix of ResBlocks and Downsamplers: +// - Blocks 0,1: ResBlocks (base_dim) +// - Block 2: Downsample +// - Blocks 3,4: ResBlocks (base_dim*2) +// - Block 5: Downsample + temporal +// - Blocks 6,7: ResBlocks (base_dim*4) +// - Block 8: Downsample + temporal +// - Blocks 9,10: ResBlocks (base_dim*4) +type VAEEncoder struct { + Config *VAEConfig + + ConvIn *CausalConv3d + Blocks []EncoderBlock // Flat list of ResBlocks and Downsamplers + MidBlock *MidBlock + NormOut *RMSNorm3D + ConvOut *CausalConv3d + QuantConv *CausalConv3d +} + +// EncoderBlock is either a ResBlock or a Downsample +type EncoderBlock interface { + Forward(x *mlx.Array) *mlx.Array + IsDownsample() bool +} + +// EncoderResBlock wraps ResBlock +type EncoderResBlock struct { + *ResBlock +} + +func (b *EncoderResBlock) IsDownsample() bool { return false } + +// EncoderDownsample is a downsample layer +type EncoderDownsample struct { + Resample *CausalConv3d + TimeConv *CausalConv3d // Optional temporal downsample +} + +func (d *EncoderDownsample) IsDownsample() bool { return true } + +func (d *EncoderDownsample) Forward(x *mlx.Array) *mlx.Array { + // Spatial downsample with stride 2 + // WAN VAE uses: ZeroPad2d(0,1,0,1) + Conv2d(3x3, stride=2) + x = d.forwardSpatialDownsample(x) + + // NOTE: In WAN VAE, time_conv is ONLY used in streaming/chunked mode + // with feat_cache. For single-frame encoding (T=1), time_conv is skipped. + // The Python forward checks: if feat_cache is not None ... then use time_conv + // Since we don't support streaming, we skip time_conv entirely. + return x +} + +// forwardSpatialDownsample applies 2D conv with stride 2 for spatial downsampling +func (d *EncoderDownsample) forwardSpatialDownsample(x *mlx.Array) *mlx.Array { + xShape := x.Shape() + B := xShape[0] + T := xShape[1] + H := xShape[2] + W := xShape[3] + C := xShape[4] + + wShape := d.Resample.Weight.Shape() + outC := wShape[0] + + // Reshape to [B*T, H, W, C] for 2D conv + x = mlx.Reshape(x, B*T, H, W, C) + + // Asymmetric padding: pad right and bottom by 1 (WAN VAE style) + // ZeroPad2d(0, 1, 0, 1) means (left=0, right=1, top=0, bottom=1) + x = mlx.Pad(x, []int32{0, 0, 0, 1, 0, 1, 0, 0}) // [B, H, W, C] -> pad H and W + + // Apply 2D conv with stride 2 + weight := mlx.Transpose(d.Resample.Weight, 0, 2, 3, 1) // [O, I, kH, kW] -> [O, kH, kW, I] + x = conv2DStrided(x, weight, 2) + + if d.Resample.Bias != nil { + bias := mlx.Reshape(d.Resample.Bias, 1, 1, 1, outC) + x = mlx.Add(x, bias) + } + + // Output dims after stride 2: (H+1)/2, (W+1)/2 + outH := (H + 1) / 2 + outW := (W + 1) / 2 + + // Reshape back to [B, T, H', W', C] + x = mlx.Reshape(x, B, T, outH, outW, outC) + mlx.Eval(x) + + return x +} + +// loadFromWeights loads the encoder from pre-loaded weights +func (e *VAEEncoder) loadFromWeights(weights *safetensors.ModelWeights, cfg *VAEConfig) error { + e.Config = cfg + + // Conv in + convIn, err := newCausalConv3d(weights, "encoder.conv_in") + if err != nil { + return err + } + e.ConvIn = convIn + + // Encoder uses flat block structure: + // dim_mult = [1, 2, 4, 4], num_res_blocks = 2, temporal_downsample = [false, true, true] + // Block layout: res,res,down, res,res,down+t, res,res,down+t, res,res + // That's 11 blocks: 0,1=res, 2=down, 3,4=res, 5=down+t, 6,7=res, 8=down+t, 9,10=res + e.Blocks = make([]EncoderBlock, 0, 11) + + // Track dimensions + dims := []int32{cfg.BaseDim, cfg.BaseDim * 2, cfg.BaseDim * 4, cfg.BaseDim * 4} + blockIdx := 0 + + for stage := 0; stage < len(cfg.DimMult); stage++ { + inDim := cfg.BaseDim + if stage > 0 { + inDim = dims[stage-1] + } + outDim := dims[stage] + + // ResBlocks for this stage (num_res_blocks per stage) + for r := int32(0); r < cfg.NumResBlocks; r++ { + prefix := fmt.Sprintf("encoder.down_blocks.%d", blockIdx) + currentInDim := inDim + if r > 0 { + currentInDim = outDim + } + block, err := newEncoderResBlock(weights, prefix, currentInDim, outDim) + if err != nil { + return fmt.Errorf("encoder res block %d: %w", blockIdx, err) + } + e.Blocks = append(e.Blocks, block) + blockIdx++ + } + + // Downsample after each stage except the last + if stage < len(cfg.DimMult)-1 { + prefix := fmt.Sprintf("encoder.down_blocks.%d", blockIdx) + down, err := newEncoderDownsample(weights, prefix, cfg.TemperalDownsample[stage]) + if err != nil { + return fmt.Errorf("encoder downsample %d: %w", blockIdx, err) + } + e.Blocks = append(e.Blocks, down) + blockIdx++ + } + } + + // Mid block + midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1] + midBlock, err := newMidBlock(weights, "encoder.mid_block", midDim) + if err != nil { + return err + } + e.MidBlock = midBlock + + // Norm out + normOut, err := newRMSNorm3D(weights, "encoder.norm_out", midDim) + if err != nil { + return err + } + e.NormOut = normOut + + // Conv out + convOut, err := newCausalConv3d(weights, "encoder.conv_out") + if err != nil { + return err + } + e.ConvOut = convOut + + // Quant conv + quantConv, err := newCausalConv3d(weights, "quant_conv") + if err != nil { + return err + } + e.QuantConv = quantConv + + return nil +} + +// newEncoderResBlock creates a ResBlock for the encoder (flat structure) +func newEncoderResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*EncoderResBlock, error) { + block, err := newResBlock(weights, prefix, inDim, outDim) + if err != nil { + return nil, err + } + return &EncoderResBlock{block}, nil +} + +// newEncoderDownsample creates a downsample layer for the encoder +func newEncoderDownsample(weights *safetensors.ModelWeights, prefix string, temporal bool) (*EncoderDownsample, error) { + resample, err := newCausalConv3d(weights, prefix+".resample.1") + if err != nil { + return nil, err + } + + var timeConv *CausalConv3d + if temporal { + timeConv, _ = newCausalConv3d(weights, prefix+".time_conv") + } + + return &EncoderDownsample{ + Resample: resample, + TimeConv: timeConv, + }, nil +} + +// Encode encodes an image to latents +// x: [B, C, T, H, W] image tensor (channels-first) +// Returns: [B, latent_C, T, H/8, W/8] latent distribution mode +func (e *VAEEncoder) Encode(x *mlx.Array) *mlx.Array { + // Convert from channels-first [N, C, T, H, W] to channels-last [N, T, H, W, C] + x = mlx.Contiguous(mlx.Transpose(x, 0, 2, 3, 4, 1)) + mlx.Eval(x) + + // Conv in + x = e.ConvIn.Forward(x) + + // Encoder blocks (mix of ResBlocks and Downsamplers) + for _, block := range e.Blocks { + prev := x + x = block.Forward(x) + prev.Free() + } + + // Mid block + x = e.MidBlock.Forward(x) + + // Norm + silu + { + prev := x + x = e.NormOut.Forward(x) + x = silu3D(x) + prev.Free() + mlx.Eval(x) + } + + // Conv out + { + prev := x + x = e.ConvOut.Forward(x) + prev.Free() + } + + // Quant conv + { + prev := x + x = e.QuantConv.Forward(x) + prev.Free() + } + + // Get mode from distribution (first half of channels = mean) + // Output is [B, T, H, W, 2*latent_C], we take first latent_C channels + shape := x.Shape() + latentC := shape[4] / 2 + x = mlx.Slice(x, []int32{0, 0, 0, 0, 0}, []int32{shape[0], shape[1], shape[2], shape[3], latentC}) + + // Convert back to channels-first [N, C, T, H, W] + x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3)) + mlx.Eval(x) + + return x +} + +// VAEDecoder is the decoder part of the VAE +type VAEDecoder struct { + Config *VAEConfig + + PostQuantConv *CausalConv3d + ConvIn *CausalConv3d + MidBlock *MidBlock + UpBlocks []*UpBlock + NormOut *RMSNorm3D + ConvOut *CausalConv3d +} + +// loadFromWeights loads the decoder from pre-loaded weights +func (d *VAEDecoder) loadFromWeights(weights *safetensors.ModelWeights, cfg *VAEConfig) error { + d.Config = cfg + + postQuantConv, err := newCausalConv3d(weights, "post_quant_conv") + if err != nil { + return err + } + d.PostQuantConv = postQuantConv + + convIn, err := newCausalConv3d(weights, "decoder.conv_in") + if err != nil { + return err + } + d.ConvIn = convIn + + // Mid block + midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1] + midBlock, err := newMidBlock(weights, "decoder.mid_block", midDim) + if err != nil { + return err + } + d.MidBlock = midBlock + + // Up blocks (reversed dim_mult) + numUpBlocks := len(cfg.DimMult) + d.UpBlocks = make([]*UpBlock, numUpBlocks) + + dimsMult := make([]int32, numUpBlocks+1) + dimsMult[0] = cfg.DimMult[numUpBlocks-1] + for i := 0; i < numUpBlocks; i++ { + dimsMult[i+1] = cfg.DimMult[numUpBlocks-1-i] + } + + temporalUpsample := make([]bool, len(cfg.TemperalDownsample)) + for i := range cfg.TemperalDownsample { + temporalUpsample[i] = cfg.TemperalDownsample[len(cfg.TemperalDownsample)-1-i] + } + + for i := 0; i < numUpBlocks; i++ { + inDim := cfg.BaseDim * dimsMult[i] + outDim := cfg.BaseDim * dimsMult[i+1] + + if i > 0 { + inDim = inDim / 2 + } + + upsampleMode := "" + if i < numUpBlocks-1 { + if temporalUpsample[i] { + upsampleMode = "upsample3d" + } else { + upsampleMode = "upsample2d" + } + } + + prefix := fmt.Sprintf("decoder.up_blocks.%d", i) + upBlock, err := newUpBlock(weights, prefix, inDim, outDim, cfg.NumResBlocks, upsampleMode) + if err != nil { + return err + } + d.UpBlocks[i] = upBlock + } + + normOut, err := newRMSNorm3D(weights, "decoder.norm_out", cfg.BaseDim) + if err != nil { + return err + } + d.NormOut = normOut + + convOut, err := newCausalConv3d(weights, "decoder.conv_out") + if err != nil { + return err + } + d.ConvOut = convOut + + return nil +} + +// Decode converts latents to image +// z: [B, C, T, H, W] denormalized latents +func (d *VAEDecoder) Decode(z *mlx.Array) *mlx.Array { + var x *mlx.Array + + // Convert from channels-first to channels-last + { + z = mlx.Contiguous(mlx.Transpose(z, 0, 2, 3, 4, 1)) + mlx.Eval(z) + } + + // PostQuantConv + x = d.PostQuantConv.Forward(z) + z.Free() + + // ConvIn + { + prev := x + x = d.ConvIn.Forward(x) + prev.Free() + } + + // Mid block + x = d.MidBlock.Forward(x) + + // Up blocks + for _, upBlock := range d.UpBlocks { + x = upBlock.Forward(x) + } + + // NormOut + silu + { + prev := x + x = d.NormOut.Forward(x) + x = silu3D(x) + prev.Free() + mlx.Eval(x) + } + + // ConvOut + { + prev := x + x = d.ConvOut.Forward(x) + prev.Free() + } + + // Post-processing: clamp and convert back to channels-first + { + prev := x + x = mlx.ClipScalar(x, -1.0, 1.0, true, true) + x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3)) + prev.Free() + mlx.Eval(x) + } + + return x +} + +// DownBlock handles downsampling in encoder +type DownBlock struct { + ResBlocks []*ResBlock + Downsampler *Downsample +} + +// newDownBlock creates a down block +func newDownBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, downsampleMode string) (*DownBlock, error) { + resBlocks := make([]*ResBlock, numBlocks+1) + + currentDim := inDim + for i := int32(0); i <= numBlocks; i++ { + resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i) + block, err := newResBlock(weights, resPrefix, currentDim, outDim) + if err != nil { + return nil, err + } + resBlocks[i] = block + currentDim = outDim + } + + var downsampler *Downsample + if downsampleMode != "" { + downsampler = newDownsample(weights, prefix+".downsamplers.0", outDim, downsampleMode) + } + + return &DownBlock{ + ResBlocks: resBlocks, + Downsampler: downsampler, + }, nil +} + +// Forward applies down block +func (d *DownBlock) Forward(x *mlx.Array) *mlx.Array { + for _, block := range d.ResBlocks { + prev := x + x = block.Forward(x) + prev.Free() + } + + if d.Downsampler != nil { + prev := x + x = d.Downsampler.Forward(x) + prev.Free() + } + return x +} + +// Downsample handles spatial downsampling +type Downsample struct { + Conv *mlx.Array + Bias *mlx.Array + Mode string +} + +// newDownsample creates a downsampler +func newDownsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Downsample { + conv, _ := weights.Get(prefix + ".resample.1.weight") + bias, _ := weights.Get(prefix + ".resample.1.bias") + return &Downsample{ + Conv: conv, + Bias: bias, + Mode: mode, + } +} + +// Forward applies downsampling to channels-last input [B, T, H, W, C] +func (d *Downsample) Forward(x *mlx.Array) *mlx.Array { + shape := x.Shape() + B := shape[0] + T := shape[1] + H := shape[2] + W := shape[3] + C := shape[4] + outC := d.Conv.Shape()[0] + + // Reshape to [B*T, H, W, C] for 2D conv + x = mlx.Reshape(x, B*T, H, W, C) + + // Pad for stride-2 conv: need (3-1)/2 = 1 on each side, but for stride 2 we need specific padding + // For 3x3 stride 2: pad 1 on all sides + x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0}) + + // Conv with stride 2 using manual strided patching + weight := mlx.Transpose(d.Conv, 0, 2, 3, 1) + x = conv2DStrided(x, weight, 2) + if d.Bias != nil { + bias := mlx.Reshape(d.Bias, 1, 1, 1, outC) + x = mlx.Add(x, bias) + } + + x = mlx.Reshape(x, B, T, H/2, W/2, outC) + mlx.Eval(x) + + return x +} diff --git a/x/imagegen/models/zimage/scheduler.go b/x/imagegen/models/zimage/scheduler.go new file mode 100644 index 000000000..3c55fc62f --- /dev/null +++ b/x/imagegen/models/zimage/scheduler.go @@ -0,0 +1,148 @@ +//go:build mlx + +package zimage + +import ( + "math" + + "github.com/ollama/ollama/x/imagegen/mlx" +) + +// FlowMatchSchedulerConfig holds scheduler configuration +type FlowMatchSchedulerConfig struct { + NumTrainTimesteps int32 `json:"num_train_timesteps"` // 1000 + Shift float32 `json:"shift"` // 3.0 + UseDynamicShifting bool `json:"use_dynamic_shifting"` // false +} + +// DefaultFlowMatchSchedulerConfig returns default config +func DefaultFlowMatchSchedulerConfig() *FlowMatchSchedulerConfig { + return &FlowMatchSchedulerConfig{ + NumTrainTimesteps: 1000, + Shift: 3.0, + UseDynamicShifting: true, // Z-Image-Turbo uses dynamic shifting + } +} + +// FlowMatchEulerScheduler implements the Flow Match Euler discrete scheduler +// This is used in Z-Image-Turbo for fast sampling +type FlowMatchEulerScheduler struct { + Config *FlowMatchSchedulerConfig + Timesteps []float32 // Discretized timesteps + Sigmas []float32 // Noise levels at each timestep + NumSteps int // Number of inference steps +} + +// NewFlowMatchEulerScheduler creates a new scheduler +func NewFlowMatchEulerScheduler(cfg *FlowMatchSchedulerConfig) *FlowMatchEulerScheduler { + return &FlowMatchEulerScheduler{ + Config: cfg, + } +} + +// SetTimesteps sets up the scheduler for the given number of inference steps +func (s *FlowMatchEulerScheduler) SetTimesteps(numSteps int) { + s.SetTimestepsWithMu(numSteps, 0) +} + +// SetTimestepsWithMu sets up the scheduler with dynamic mu shift +func (s *FlowMatchEulerScheduler) SetTimestepsWithMu(numSteps int, mu float32) { + s.NumSteps = numSteps + + // Create evenly spaced timesteps from 1.0 to 0.0 (flow matching goes t=1 to t=0) + // Match Python: np.linspace(1.0, 0.0, num_inference_steps + 1) + s.Timesteps = make([]float32, numSteps+1) + s.Sigmas = make([]float32, numSteps+1) + + for i := 0; i <= numSteps; i++ { + t := 1.0 - float32(i)/float32(numSteps) + + // Apply time shift if using dynamic shifting + if s.Config.UseDynamicShifting && mu != 0 { + t = s.timeShift(mu, t) + } + + s.Timesteps[i] = t + s.Sigmas[i] = t + } +} + +// timeShift applies the dynamic time shift (match Python) +func (s *FlowMatchEulerScheduler) timeShift(mu float32, t float32) float32 { + if t <= 0 { + return 0 + } + // exp(mu) / (exp(mu) + (1/t - 1)) + expMu := float32(math.Exp(float64(mu))) + return expMu / (expMu + (1.0/t - 1.0)) +} + +// Step performs one denoising step +// modelOutput: predicted velocity/noise from the model +// timestepIdx: current timestep index +// sample: current noisy sample +// Returns: denoised sample for next step +func (s *FlowMatchEulerScheduler) Step(modelOutput, sample *mlx.Array, timestepIdx int) *mlx.Array { + // Get current and next sigma + sigma := s.Sigmas[timestepIdx] + sigmaNext := s.Sigmas[timestepIdx+1] + + // Euler step: x_{t-dt} = x_t + (sigma_next - sigma) * v_t + // where v_t is the velocity predicted by the model + dt := sigmaNext - sigma // This is negative (going from noise to clean) + + // x_next = x + dt * velocity + scaledOutput := mlx.MulScalar(modelOutput, dt) + return mlx.Add(sample, scaledOutput) +} + +// ScaleSample scales the sample for model input (identity for flow matching) +func (s *FlowMatchEulerScheduler) ScaleSample(sample *mlx.Array, timestepIdx int) *mlx.Array { + // Flow matching doesn't need scaling + return sample +} + +// GetTimestep returns the timestep value at the given index +func (s *FlowMatchEulerScheduler) GetTimestep(idx int) float32 { + if idx < len(s.Timesteps) { + return s.Timesteps[idx] + } + return 0.0 +} + +// GetTimesteps returns all timesteps (implements Scheduler interface) +func (s *FlowMatchEulerScheduler) GetTimesteps() []float32 { + return s.Timesteps +} + +// AddNoise adds noise to clean samples for a given timestep +// Used for img2img or inpainting +func (s *FlowMatchEulerScheduler) AddNoise(cleanSample, noise *mlx.Array, timestepIdx int) *mlx.Array { + // In flow matching: x_t = (1-t) * x_0 + t * noise + t := s.Timesteps[timestepIdx] + oneMinusT := 1.0 - t + + scaledClean := mlx.MulScalar(cleanSample, oneMinusT) + scaledNoise := mlx.MulScalar(noise, t) + + return mlx.Add(scaledClean, scaledNoise) +} + +// InitNoise creates initial noise for sampling +func (s *FlowMatchEulerScheduler) InitNoise(shape []int32, seed int64) *mlx.Array { + return RandomNormal(shape, seed) +} + +// RandomNormal creates a random normal tensor using MLX +func RandomNormal(shape []int32, seed int64) *mlx.Array { + return mlx.RandomNormal(shape, uint64(seed)) +} + +// GetLatentShape returns the latent shape for a given image size +func GetLatentShape(batchSize, height, width, latentChannels int32, patchSize int32) []int32 { + // Latent is 8x smaller than image (VAE downscale) + latentH := height / 8 + latentW := width / 8 + + return []int32{batchSize, latentChannels, latentH, latentW} +} diff --git a/x/imagegen/models/zimage/text_encoder.go b/x/imagegen/models/zimage/text_encoder.go new file mode 100644 index 000000000..2f2cc897c --- /dev/null +++ b/x/imagegen/models/zimage/text_encoder.go @@ -0,0 +1,296 @@ +//go:build mlx + +package zimage + +import ( + "encoding/json" + "fmt" + "math" + "os" + "path/filepath" + + "github.com/ollama/ollama/x/imagegen/mlx" + "github.com/ollama/ollama/x/imagegen/nn" + "github.com/ollama/ollama/x/imagegen/safetensors" + "github.com/ollama/ollama/x/imagegen/tokenizer" +) + +// Qwen3Config holds Qwen3 text encoder configuration +type Qwen3Config struct { + HiddenSize int32 `json:"hidden_size"` + NumHiddenLayers int32 `json:"num_hidden_layers"` + IntermediateSize int32 `json:"intermediate_size"` + NumAttentionHeads int32 `json:"num_attention_heads"` + NumKeyValueHeads int32 `json:"num_key_value_heads"` + VocabSize int32 `json:"vocab_size"` + RMSNormEps float32 `json:"rms_norm_eps"` + RopeTheta float32 `json:"rope_theta"` + HeadDim int32 `json:"head_dim"` +} + +// loadQwen3Config loads text encoder config from a JSON file +func loadQwen3Config(path string) (*Qwen3Config, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("read config: %w", err) + } + var cfg Qwen3Config + if err := json.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("parse config: %w", err) + } + return &cfg, nil +} + +// Qwen3Attention implements Qwen3 attention with QK norms +type Qwen3Attention struct { + QProj *nn.Linear `weight:"q_proj"` + KProj *nn.Linear `weight:"k_proj"` + VProj *nn.Linear `weight:"v_proj"` + OProj *nn.Linear `weight:"o_proj"` + QNorm *nn.RMSNorm `weight:"q_norm"` + KNorm *nn.RMSNorm `weight:"k_norm"` + // Computed fields + NHeads int32 + NKVHeads int32 + HeadDim int32 + Scale float32 + RopeTheta float32 +} + +// applyRoPEQwen3 applies the custom RoPE for Qwen3 text encoder +func applyRoPEQwen3(x *mlx.Array, seqLen int32, theta float32) *mlx.Array { + shape := x.Shape() + B := shape[0] + L := shape[1] + H := shape[2] + D := shape[3] + half := D / 2 + + freqsArr := make([]float32, half) + logTheta := float32(math.Log(float64(theta))) + for i := int32(0); i < half; i++ { + freqsArr[i] = float32(math.Exp(float64(-logTheta * float32(i) / float32(half)))) + } + freqs := mlx.NewArray(freqsArr, []int32{half}) + + posArr := make([]float32, seqLen) + for i := int32(0); i < seqLen; i++ { + posArr[i] = float32(i) + } + pos := mlx.NewArray(posArr, []int32{seqLen}) + + posExpanded := mlx.Reshape(pos, seqLen, 1) + freqsExpanded := mlx.Reshape(freqs, 1, half) + args := mlx.Mul(posExpanded, freqsExpanded) + + cosVals := mlx.Cos(args) + sinVals := mlx.Sin(args) + cosVals = mlx.Reshape(cosVals, seqLen, 1, half) + sinVals = mlx.Reshape(sinVals, seqLen, 1, half) + + x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half}) + x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D}) + + part1 := mlx.Sub(mlx.Mul(x1, cosVals), mlx.Mul(x2, sinVals)) + part2 := mlx.Add(mlx.Mul(x1, sinVals), mlx.Mul(x2, cosVals)) + + return mlx.Concatenate([]*mlx.Array{part1, part2}, 3) +} + +// Forward computes attention with causal masking +func (attn *Qwen3Attention) Forward(x *mlx.Array) *mlx.Array { + shape := x.Shape() + B := shape[0] + L := shape[1] + + q := attn.QProj.Forward(x) + k := attn.KProj.Forward(x) + v := attn.VProj.Forward(x) + + q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim) + k = mlx.Reshape(k, B, L, attn.NKVHeads, attn.HeadDim) + v = mlx.Reshape(v, B, L, attn.NKVHeads, attn.HeadDim) + + // QK norm uses 1e-6 hardcoded (Qwen3 specific) + q = attn.QNorm.Forward(q, 1e-6) + k = attn.KNorm.Forward(k, 1e-6) + + q = applyRoPEQwen3(q, L, attn.RopeTheta) + k = applyRoPEQwen3(k, L, attn.RopeTheta) + + q = mlx.Transpose(q, 0, 2, 1, 3) + k = mlx.Transpose(k, 0, 2, 1, 3) + v = mlx.Transpose(v, 0, 2, 1, 3) + + if attn.NKVHeads < attn.NHeads { + repeats := attn.NHeads / attn.NKVHeads + k = repeatKV(k, repeats) + v = repeatKV(v, repeats) + } + + out := mlx.ScaledDotProductAttention(q, k, v, attn.Scale, true) + + out = mlx.Transpose(out, 0, 2, 1, 3) + out = mlx.Reshape(out, B, L, attn.NHeads*attn.HeadDim) + + out = attn.OProj.Forward(out) + + return out +} + +// repeatKV repeats key/value heads for GQA +func repeatKV(x *mlx.Array, repeats int32) *mlx.Array { + if repeats == 1 { + return x + } + shape := x.Shape() + x = mlx.ExpandDims(x, 2) + x = mlx.Tile(x, []int32{1, 1, repeats, 1, 1}) + return mlx.Reshape(x, shape[0], shape[1]*repeats, shape[2], shape[3]) +} + +// Qwen3MLP implements Qwen3 SwiGLU MLP +type Qwen3MLP struct { + GateProj *nn.Linear `weight:"gate_proj"` + UpProj *nn.Linear `weight:"up_proj"` + DownProj *nn.Linear `weight:"down_proj"` +} + +// Forward applies the MLP +func (m *Qwen3MLP) Forward(x *mlx.Array) *mlx.Array { + gate := m.GateProj.Forward(x) + gate = mlx.SiLU(gate) + up := m.UpProj.Forward(x) + h := mlx.Mul(gate, up) + return m.DownProj.Forward(h) +} + +// Qwen3Block represents a single Qwen3 transformer block +type Qwen3Block struct { + Attention *Qwen3Attention `weight:"self_attn"` + MLP *Qwen3MLP `weight:"mlp"` + InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"` + PostAttnLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"` +} + +// Forward applies the Qwen3 block +func (qb *Qwen3Block) Forward(x *mlx.Array, eps float32) *mlx.Array { + h := qb.InputLayerNorm.Forward(x, eps) + attnOut := qb.Attention.Forward(h) + x = mlx.Add(x, attnOut) + + h = qb.PostAttnLayerNorm.Forward(x, eps) + mlpOut := qb.MLP.Forward(h) + x = mlx.Add(x, mlpOut) + + return x +} + +// Qwen3TextEncoder is the full Qwen3 encoder for Z-Image +type Qwen3TextEncoder struct { + EmbedTokens *nn.Embedding `weight:"model.embed_tokens"` + Layers []*Qwen3Block `weight:"model.layers"` + FinalNorm *nn.RMSNorm `weight:"model.norm"` + *Qwen3Config +} + +// Load loads the Qwen3 text encoder from a directory +func (m *Qwen3TextEncoder) Load(path string) error { + fmt.Println("Loading Qwen3 text encoder...") + + // Load config + cfg, err := loadQwen3Config(filepath.Join(path, "config.json")) + if err != nil { + return fmt.Errorf("config: %w", err) + } + m.Qwen3Config = cfg + + // Pre-allocate layers slice + m.Layers = make([]*Qwen3Block, cfg.NumHiddenLayers) + + // Load weights + weights, err := safetensors.LoadModelWeights(path) + if err != nil { + return fmt.Errorf("weights: %w", err) + } + + fmt.Print(" Loading weights via struct tags... ") + if err := safetensors.LoadModule(m, weights, ""); err != nil { + return fmt.Errorf("load module: %w", err) + } + fmt.Println("✓") + + // Initialize computed fields + m.FinalNorm.Eps = cfg.RMSNormEps + for _, block := range m.Layers { + // Attention + block.Attention.NHeads = cfg.NumAttentionHeads + block.Attention.NKVHeads = cfg.NumKeyValueHeads + block.Attention.HeadDim = cfg.HeadDim + block.Attention.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim))) + block.Attention.RopeTheta = cfg.RopeTheta + block.Attention.QNorm.Eps = cfg.RMSNormEps + block.Attention.KNorm.Eps = cfg.RMSNormEps + // Block norms + block.InputLayerNorm.Eps = cfg.RMSNormEps + block.PostAttnLayerNorm.Eps = cfg.RMSNormEps + } + + weights.ReleaseAll() + return nil +} + +// Forward encodes text tokens +func (te *Qwen3TextEncoder) Forward(tokens *mlx.Array) *mlx.Array { + h := te.EmbedTokens.Forward(tokens) + eps := te.RMSNormEps + + for _, layer := range te.Layers { + h = layer.Forward(h, eps) + } + + // Apply final RMS norm + h = te.FinalNorm.Forward(h, eps) + + return h +} + +// ApplyChatTemplate wraps prompt in Qwen3 chat format +func ApplyChatTemplate(prompt string) string { + return "<|im_start|>user\n" + prompt + "<|im_end|>\n<|im_start|>assistant\n" +} + +// EncodePrompt encodes a text prompt using the tokenizer and encoder +func (te *Qwen3TextEncoder) EncodePrompt(tok *tokenizer.Tokenizer, prompt string, maxLen int) (*mlx.Array, *mlx.Array) { + formattedPrompt := ApplyChatTemplate(prompt) + + tokens := tok.Encode(formattedPrompt, false) + + if len(tokens) > maxLen { + tokens = tokens[:maxLen] + } + + maskData := make([]float32, maxLen) + for i := 0; i < len(tokens); i++ { + maskData[i] = 1.0 + } + + // Get PAD token (different from EOS for Qwen3) + padToken := tok.PAD() + if padToken < 0 { + padToken = tok.EOS() // fallback + } + + paddedTokens := make([]int32, maxLen) + copy(paddedTokens, tokens) + for i := len(tokens); i < maxLen; i++ { + paddedTokens[i] = padToken + } + + tokensArr := mlx.NewArrayInt32(paddedTokens, []int32{1, int32(maxLen)}) + maskArr := mlx.NewArray(maskData, []int32{1, int32(maxLen)}) + + embeddings := te.Forward(tokensArr) + + return embeddings, maskArr +} diff --git a/x/imagegen/models/zimage/transformer.go b/x/imagegen/models/zimage/transformer.go new file mode 100644 index 000000000..63314be00 --- /dev/null +++ b/x/imagegen/models/zimage/transformer.go @@ -0,0 +1,692 @@ +//go:build mlx + +// Package zimage implements the Z-Image diffusion transformer model. +package zimage + +import ( + "encoding/json" + "fmt" + "math" + "os" + "path/filepath" + + "github.com/ollama/ollama/x/imagegen/cache" + "github.com/ollama/ollama/x/imagegen/mlx" + "github.com/ollama/ollama/x/imagegen/nn" + "github.com/ollama/ollama/x/imagegen/safetensors" +) + +// TransformerConfig holds Z-Image transformer configuration +type TransformerConfig struct { + Dim int32 `json:"dim"` + NHeads int32 `json:"n_heads"` + NKVHeads int32 `json:"n_kv_heads"` + NLayers int32 `json:"n_layers"` + NRefinerLayers int32 `json:"n_refiner_layers"` + InChannels int32 `json:"in_channels"` + PatchSize int32 `json:"-"` // Computed from AllPatchSize + CapFeatDim int32 `json:"cap_feat_dim"` + NormEps float32 `json:"norm_eps"` + RopeTheta float32 `json:"rope_theta"` + TScale float32 `json:"t_scale"` + QKNorm bool `json:"qk_norm"` + AxesDims []int32 `json:"axes_dims"` + AxesLens []int32 `json:"axes_lens"` + AllPatchSize []int32 `json:"all_patch_size"` // JSON array, PatchSize = first element +} + +// TimestepEmbedder creates sinusoidal timestep embeddings +// Output dimension is 256 (fixed), used for AdaLN modulation +type TimestepEmbedder struct { + Linear1 *nn.Linear `weight:"mlp.0"` + Linear2 *nn.Linear `weight:"mlp.2"` + FreqEmbedSize int32 // 256 (computed) +} + +// Forward computes timestep embeddings -> [B, 256] +func (te *TimestepEmbedder) Forward(t *mlx.Array) *mlx.Array { + // t: [B] timesteps + + // Create sinusoidal embedding + half := te.FreqEmbedSize / 2 + + // freqs = exp(-log(10000) * arange(half) / half) + freqs := make([]float32, half) + for i := int32(0); i < half; i++ { + freqs[i] = float32(math.Exp(-math.Log(10000.0) * float64(i) / float64(half))) + } + freqsArr := mlx.NewArray(freqs, []int32{1, half}) + + // t[:, None] * freqs[None, :] -> [B, half] + tExpanded := mlx.ExpandDims(t, 1) // [B, 1] + args := mlx.Mul(tExpanded, freqsArr) + + // embedding = [cos(args), sin(args)] -> [B, 256] + cosArgs := mlx.Cos(args) + sinArgs := mlx.Sin(args) + embedding := mlx.Concatenate([]*mlx.Array{cosArgs, sinArgs}, 1) + + // MLP: linear1 -> silu -> linear2 + h := te.Linear1.Forward(embedding) + h = mlx.SiLU(h) + h = te.Linear2.Forward(h) + + return h +} + +// XEmbedder embeds image patches to model dimension +type XEmbedder struct { + Linear *nn.Linear `weight:"2-1"` +} + +// Forward embeds patchified image latents +func (xe *XEmbedder) Forward(x *mlx.Array) *mlx.Array { + // x: [B, L, in_channels * 4] -> [B, L, dim] + return xe.Linear.Forward(x) +} + +// CapEmbedder projects caption features to model dimension +type CapEmbedder struct { + Norm *nn.RMSNorm `weight:"0"` + Linear *nn.Linear `weight:"1"` + PadToken *mlx.Array // loaded separately at root level +} + +// Forward projects caption embeddings: [B, L, cap_feat_dim] -> [B, L, dim] +func (ce *CapEmbedder) Forward(capFeats *mlx.Array) *mlx.Array { + // RMSNorm on last axis (uses 1e-6) + h := ce.Norm.Forward(capFeats, 1e-6) + // Linear projection + return ce.Linear.Forward(h) +} + +// FeedForward implements SwiGLU FFN +type FeedForward struct { + W1 *nn.Linear `weight:"w1"` // gate projection + W2 *nn.Linear `weight:"w2"` // down projection + W3 *nn.Linear `weight:"w3"` // up projection + OutDim int32 // computed from W2 +} + +// Forward applies SwiGLU: silu(W1(x)) * W3(x), then W2 +func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array { + shape := x.Shape() + B := shape[0] + L := shape[1] + D := shape[2] + + // Reshape for matmul + x = mlx.Reshape(x, B*L, D) + gate := ff.W1.Forward(x) + gate = mlx.SiLU(gate) + up := ff.W3.Forward(x) + h := mlx.Mul(gate, up) + out := ff.W2.Forward(h) + + return mlx.Reshape(out, B, L, ff.OutDim) +} + +// Attention implements multi-head attention with QK norm +type Attention struct { + ToQ *nn.Linear `weight:"to_q"` + ToK *nn.Linear `weight:"to_k"` + ToV *nn.Linear `weight:"to_v"` + ToOut *nn.Linear `weight:"to_out.0"` + NormQ *mlx.Array `weight:"norm_q.weight"` // [head_dim] for per-head RMSNorm + NormK *mlx.Array `weight:"norm_k.weight"` + // Computed fields + NHeads int32 + HeadDim int32 + Dim int32 + Scale float32 +} + +// Forward computes attention +func (attn *Attention) Forward(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array { + shape := x.Shape() + B := shape[0] + L := shape[1] + D := shape[2] + + // Project Q, K, V + xFlat := mlx.Reshape(x, B*L, D) + q := attn.ToQ.Forward(xFlat) + k := attn.ToK.Forward(xFlat) + v := attn.ToV.Forward(xFlat) + + // Reshape to [B, L, nheads, head_dim] + q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim) + k = mlx.Reshape(k, B, L, attn.NHeads, attn.HeadDim) + v = mlx.Reshape(v, B, L, attn.NHeads, attn.HeadDim) + + // QK norm + q = mlx.RMSNorm(q, attn.NormQ, 1e-5) + k = mlx.RMSNorm(k, attn.NormK, 1e-5) + + // Apply RoPE if provided + if cos != nil && sin != nil { + q = applyRoPE3D(q, cos, sin) + k = applyRoPE3D(k, cos, sin) + } + + // Transpose to [B, nheads, L, head_dim] + q = mlx.Transpose(q, 0, 2, 1, 3) + k = mlx.Transpose(k, 0, 2, 1, 3) + v = mlx.Transpose(v, 0, 2, 1, 3) + + // SDPA + out := mlx.ScaledDotProductAttention(q, k, v, attn.Scale, false) + + // Transpose back and reshape + out = mlx.Transpose(out, 0, 2, 1, 3) + out = mlx.Reshape(out, B*L, attn.Dim) + out = attn.ToOut.Forward(out) + + return mlx.Reshape(out, B, L, attn.Dim) +} + +// applyRoPE3D applies 3-axis rotary position embeddings +// x: [B, L, nheads, head_dim] +// cos, sin: [B, L, 1, head_dim/2] +func applyRoPE3D(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array { + shape := x.Shape() + B := shape[0] + L := shape[1] + nheads := shape[2] + headDim := shape[3] + half := headDim / 2 + + // Create even/odd index arrays + evenIdx := make([]int32, half) + oddIdx := make([]int32, half) + for i := int32(0); i < half; i++ { + evenIdx[i] = i * 2 + oddIdx[i] = i*2 + 1 + } + evenIndices := mlx.NewArrayInt32(evenIdx, []int32{half}) + oddIndices := mlx.NewArrayInt32(oddIdx, []int32{half}) + + // Extract x1 (even indices) and x2 (odd indices) along last axis + x1 := mlx.Take(x, evenIndices, 3) // [B, L, nheads, half] + x2 := mlx.Take(x, oddIndices, 3) // [B, L, nheads, half] + + // Apply rotation: [x1*cos - x2*sin, x1*sin + x2*cos] + r1 := mlx.Sub(mlx.Mul(x1, cos), mlx.Mul(x2, sin)) + r2 := mlx.Add(mlx.Mul(x1, sin), mlx.Mul(x2, cos)) + + // Stack and reshape to interleave: [r1_0, r2_0, r1_1, r2_1, ...] + r1 = mlx.ExpandDims(r1, 4) // [B, L, nheads, half, 1] + r2 = mlx.ExpandDims(r2, 4) // [B, L, nheads, half, 1] + stacked := mlx.Concatenate([]*mlx.Array{r1, r2}, 4) // [B, L, nheads, half, 2] + return mlx.Reshape(stacked, B, L, nheads, headDim) +} + +// TransformerBlock is a single transformer block with optional AdaLN modulation +type TransformerBlock struct { + Attention *Attention `weight:"attention"` + FeedForward *FeedForward `weight:"feed_forward"` + AttentionNorm1 *nn.RMSNorm `weight:"attention_norm1"` + AttentionNorm2 *nn.RMSNorm `weight:"attention_norm2"` + FFNNorm1 *nn.RMSNorm `weight:"ffn_norm1"` + FFNNorm2 *nn.RMSNorm `weight:"ffn_norm2"` + AdaLN *nn.Linear `weight:"adaLN_modulation.0,optional"` // only if modulation + // Computed fields + HasModulation bool + Dim int32 +} + +// Forward applies the transformer block +func (tb *TransformerBlock) Forward(x *mlx.Array, adaln *mlx.Array, cos, sin *mlx.Array, eps float32) *mlx.Array { + if tb.AdaLN != nil && adaln != nil { + // Compute modulation: [B, 256] -> [B, 4*dim] + chunks := tb.AdaLN.Forward(adaln) + + // Split into 4 parts: scale_msa, gate_msa, scale_mlp, gate_mlp + chunkShape := chunks.Shape() + chunkDim := chunkShape[1] / 4 + + scaleMSA := mlx.Slice(chunks, []int32{0, 0}, []int32{chunkShape[0], chunkDim}) + gateMSA := mlx.Slice(chunks, []int32{0, chunkDim}, []int32{chunkShape[0], chunkDim * 2}) + scaleMLP := mlx.Slice(chunks, []int32{0, chunkDim * 2}, []int32{chunkShape[0], chunkDim * 3}) + gateMLP := mlx.Slice(chunks, []int32{0, chunkDim * 3}, []int32{chunkShape[0], chunkDim * 4}) + + // Expand for broadcasting: [B, 1, dim] + scaleMSA = mlx.ExpandDims(scaleMSA, 1) + gateMSA = mlx.ExpandDims(gateMSA, 1) + scaleMLP = mlx.ExpandDims(scaleMLP, 1) + gateMLP = mlx.ExpandDims(gateMLP, 1) + + // Attention with modulation + normX := tb.AttentionNorm1.Forward(x, eps) + normX = mlx.Mul(normX, mlx.AddScalar(scaleMSA, 1.0)) + attnOut := tb.Attention.Forward(normX, cos, sin) + attnOut = tb.AttentionNorm2.Forward(attnOut, eps) + x = mlx.Add(x, mlx.Mul(mlx.Tanh(gateMSA), attnOut)) + + // FFN with modulation + normFFN := tb.FFNNorm1.Forward(x, eps) + normFFN = mlx.Mul(normFFN, mlx.AddScalar(scaleMLP, 1.0)) + ffnOut := tb.FeedForward.Forward(normFFN) + ffnOut = tb.FFNNorm2.Forward(ffnOut, eps) + x = mlx.Add(x, mlx.Mul(mlx.Tanh(gateMLP), ffnOut)) + } else { + // No modulation (context refiner) + attnOut := tb.Attention.Forward(tb.AttentionNorm1.Forward(x, eps), cos, sin) + x = mlx.Add(x, tb.AttentionNorm2.Forward(attnOut, eps)) + + ffnOut := tb.FeedForward.Forward(tb.FFNNorm1.Forward(x, eps)) + x = mlx.Add(x, tb.FFNNorm2.Forward(ffnOut, eps)) + } + + return x +} + +// FinalLayer outputs the denoised patches +type FinalLayer struct { + AdaLN *nn.Linear `weight:"adaLN_modulation.1"` // [256] -> [dim] + Output *nn.Linear `weight:"linear"` // [dim] -> [out_channels] + OutDim int32 // computed from Output +} + +// Forward computes final output +func (fl *FinalLayer) Forward(x *mlx.Array, c *mlx.Array) *mlx.Array { + // c: [B, 256] -> scale: [B, dim] + scale := mlx.SiLU(c) + scale = fl.AdaLN.Forward(scale) + scale = mlx.ExpandDims(scale, 1) // [B, 1, dim] + + // LayerNorm (affine=False) then scale + x = layerNormNoAffine(x, 1e-6) + x = mlx.Mul(x, mlx.AddScalar(scale, 1.0)) + + // Output projection + shape := x.Shape() + B := shape[0] + L := shape[1] + D := shape[2] + x = mlx.Reshape(x, B*L, D) + x = fl.Output.Forward(x) + + return mlx.Reshape(x, B, L, fl.OutDim) +} + +// layerNormNoAffine applies layer norm without learnable parameters +func layerNormNoAffine(x *mlx.Array, eps float32) *mlx.Array { + ndim := x.Ndim() + lastAxis := ndim - 1 + + mean := mlx.Mean(x, lastAxis, true) + xCentered := mlx.Sub(x, mean) + variance := mlx.Mean(mlx.Square(xCentered), lastAxis, true) + return mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, eps))) +} + +// Transformer is the full Z-Image DiT model +type Transformer struct { + TEmbed *TimestepEmbedder `weight:"t_embedder"` + XEmbed *XEmbedder `weight:"all_x_embedder"` + CapEmbed *CapEmbedder `weight:"cap_embedder"` + NoiseRefiners []*TransformerBlock `weight:"noise_refiner"` + ContextRefiners []*TransformerBlock `weight:"context_refiner"` + Layers []*TransformerBlock `weight:"layers"` + FinalLayer *FinalLayer `weight:"all_final_layer.2-1"` + XPadToken *mlx.Array `weight:"x_pad_token"` + CapPadToken *mlx.Array `weight:"cap_pad_token"` + *TransformerConfig +} + +// Load loads the Z-Image transformer from a directory +func (m *Transformer) Load(path string) error { + fmt.Println("Loading Z-Image transformer...") + + // Load config + cfg, err := loadTransformerConfig(filepath.Join(path, "config.json")) + if err != nil { + return fmt.Errorf("config: %w", err) + } + m.TransformerConfig = cfg + + // Pre-allocate slices for loader + m.NoiseRefiners = make([]*TransformerBlock, cfg.NRefinerLayers) + m.ContextRefiners = make([]*TransformerBlock, cfg.NRefinerLayers) + m.Layers = make([]*TransformerBlock, cfg.NLayers) + + // Load weights + weights, err := safetensors.LoadModelWeights(path) + if err != nil { + return fmt.Errorf("weights: %w", err) + } + + fmt.Print(" Loading weights as bf16... ") + if err := weights.Load(mlx.DtypeBFloat16); err != nil { + return fmt.Errorf("load weights: %w", err) + } + fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024)) + + fmt.Print(" Loading weights via struct tags... ") + if err := safetensors.LoadModule(m, weights, ""); err != nil { + return fmt.Errorf("load module: %w", err) + } + fmt.Println("✓") + + // Initialize computed fields + m.TEmbed.FreqEmbedSize = 256 + m.FinalLayer.OutDim = m.FinalLayer.Output.Weight.Shape()[0] + m.CapEmbed.Norm.Eps = 1e-6 + + for _, block := range m.NoiseRefiners { + initTransformerBlock(block, cfg) + } + for _, block := range m.ContextRefiners { + initTransformerBlock(block, cfg) + } + for _, block := range m.Layers { + initTransformerBlock(block, cfg) + } + + weights.ReleaseAll() + return nil +} + +// loadTransformerConfig loads transformer config from a JSON file +func loadTransformerConfig(path string) (*TransformerConfig, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("read config: %w", err) + } + var cfg TransformerConfig + if err := json.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("parse config: %w", err) + } + // Extract PatchSize from array + if len(cfg.AllPatchSize) > 0 { + cfg.PatchSize = cfg.AllPatchSize[0] + } + return &cfg, nil +} + +// initTransformerBlock sets computed fields on a transformer block +func initTransformerBlock(block *TransformerBlock, cfg *TransformerConfig) { + block.Dim = cfg.Dim + block.HasModulation = block.AdaLN != nil + + // Init attention computed fields + attn := block.Attention + attn.NHeads = cfg.NHeads + attn.HeadDim = cfg.Dim / cfg.NHeads + attn.Dim = cfg.Dim + attn.Scale = float32(1.0 / math.Sqrt(float64(attn.HeadDim))) + + // Init feedforward OutDim + block.FeedForward.OutDim = block.FeedForward.W2.Weight.Shape()[0] + + // Set eps on all RMSNorm layers + block.AttentionNorm1.Eps = cfg.NormEps + block.AttentionNorm2.Eps = cfg.NormEps + block.FFNNorm1.Eps = cfg.NormEps + block.FFNNorm2.Eps = cfg.NormEps +} + +// RoPECache holds precomputed RoPE values +type RoPECache struct { + ImgCos *mlx.Array + ImgSin *mlx.Array + CapCos *mlx.Array + CapSin *mlx.Array + UnifiedCos *mlx.Array + UnifiedSin *mlx.Array + ImgLen int32 + CapLen int32 +} + +// PrepareRoPECache precomputes RoPE values for the given image and caption lengths. +// hTok and wTok are the number of tokens in each dimension (latentH/patchSize, latentW/patchSize). +func (m *Transformer) PrepareRoPECache(hTok, wTok, capLen int32) *RoPECache { + imgLen := hTok * wTok + + // Image positions: grid over (1, H, W) starting at (capLen+1, 0, 0) + imgPos := createCoordinateGrid(1, hTok, wTok, capLen+1, 0, 0) + imgPos = mlx.ToBFloat16(imgPos) + // Caption positions: grid over (capLen, 1, 1) starting at (1, 0, 0) + capPos := createCoordinateGrid(capLen, 1, 1, 1, 0, 0) + capPos = mlx.ToBFloat16(capPos) + + // Compute RoPE from UNIFIED positions + unifiedPos := mlx.Concatenate([]*mlx.Array{imgPos, capPos}, 1) + unifiedCos, unifiedSin := prepareRoPE3D(unifiedPos, m.TransformerConfig.AxesDims) + + // Slice RoPE for image and caption parts + imgCos := mlx.Slice(unifiedCos, []int32{0, 0, 0, 0}, []int32{1, imgLen, 1, 64}) + imgSin := mlx.Slice(unifiedSin, []int32{0, 0, 0, 0}, []int32{1, imgLen, 1, 64}) + capCos := mlx.Slice(unifiedCos, []int32{0, imgLen, 0, 0}, []int32{1, imgLen + capLen, 1, 64}) + capSin := mlx.Slice(unifiedSin, []int32{0, imgLen, 0, 0}, []int32{1, imgLen + capLen, 1, 64}) + + return &RoPECache{ + ImgCos: imgCos, + ImgSin: imgSin, + CapCos: capCos, + CapSin: capSin, + UnifiedCos: unifiedCos, + UnifiedSin: unifiedSin, + ImgLen: imgLen, + CapLen: capLen, + } +} + +// Forward runs the Z-Image transformer with precomputed RoPE +func (m *Transformer) Forward(x *mlx.Array, t *mlx.Array, capFeats *mlx.Array, rope *RoPECache) *mlx.Array { + imgLen := rope.ImgLen + + // Timestep embedding -> [B, 256] + temb := m.TEmbed.Forward(mlx.MulScalar(t, m.TransformerConfig.TScale)) + + // Embed image patches -> [B, L_img, dim] + x = m.XEmbed.Forward(x) + + // Embed caption features -> [B, L_cap, dim] + capEmb := m.CapEmbed.Forward(capFeats) + + eps := m.NormEps + + // Noise refiner: refine image patches with modulation + for _, refiner := range m.NoiseRefiners { + x = refiner.Forward(x, temb, rope.ImgCos, rope.ImgSin, eps) + } + + // Context refiner: refine caption (no modulation) + for _, refiner := range m.ContextRefiners { + capEmb = refiner.Forward(capEmb, nil, rope.CapCos, rope.CapSin, eps) + } + + // Concatenate image and caption for joint attention + unified := mlx.Concatenate([]*mlx.Array{x, capEmb}, 1) + + // Main transformer layers use full unified RoPE + for _, layer := range m.Layers { + unified = layer.Forward(unified, temb, rope.UnifiedCos, rope.UnifiedSin, eps) + } + + // Extract image tokens only + unifiedShape := unified.Shape() + B := unifiedShape[0] + imgOut := mlx.Slice(unified, []int32{0, 0, 0}, []int32{B, imgLen, unifiedShape[2]}) + + // Final layer + return m.FinalLayer.Forward(imgOut, temb) +} + +// ForwardWithCache runs the transformer with layer caching for faster inference. +// On refresh steps (step % cacheInterval == 0), all layers are computed and cached. +// On other steps, shallow layers (0 to cacheLayers-1) reuse cached outputs. +func (m *Transformer) ForwardWithCache( + x *mlx.Array, + t *mlx.Array, + capFeats *mlx.Array, + rope *RoPECache, + stepCache *cache.StepCache, + step int, + cacheInterval int, +) *mlx.Array { + imgLen := rope.ImgLen + cacheLayers := stepCache.NumLayers() + eps := m.NormEps + + // Timestep embedding -> [B, 256] + temb := m.TEmbed.Forward(mlx.MulScalar(t, m.TransformerConfig.TScale)) + + // Embed image patches -> [B, L_img, dim] + x = m.XEmbed.Forward(x) + + // Context refiners: compute once on step 0, reuse forever + // (caption embedding doesn't depend on timestep or latents) + var capEmb *mlx.Array + if stepCache.GetConstant() != nil { + capEmb = stepCache.GetConstant() + } else { + capEmb = m.CapEmbed.Forward(capFeats) + for _, refiner := range m.ContextRefiners { + capEmb = refiner.Forward(capEmb, nil, rope.CapCos, rope.CapSin, eps) + } + stepCache.SetConstant(capEmb) + } + + // Noise refiners: always compute (depend on x which changes each step) + for _, refiner := range m.NoiseRefiners { + x = refiner.Forward(x, temb, rope.ImgCos, rope.ImgSin, eps) + } + + // Concatenate image and caption for joint attention + unified := mlx.Concatenate([]*mlx.Array{x, capEmb}, 1) + + // Determine if this is a cache refresh step + refreshCache := stepCache.ShouldRefresh(step, cacheInterval) + + // Main transformer layers with caching + for i, layer := range m.Layers { + if i < cacheLayers && !refreshCache && stepCache.Get(i) != nil { + // Use cached output for shallow layers + unified = stepCache.Get(i) + } else { + // Compute layer + unified = layer.Forward(unified, temb, rope.UnifiedCos, rope.UnifiedSin, eps) + // Cache shallow layer outputs on refresh steps + if i < cacheLayers && refreshCache { + stepCache.Set(i, unified) + } + } + } + + // Extract image tokens only + unifiedShape := unified.Shape() + B := unifiedShape[0] + imgOut := mlx.Slice(unified, []int32{0, 0, 0}, []int32{B, imgLen, unifiedShape[2]}) + + // Final layer + return m.FinalLayer.Forward(imgOut, temb) +} + +// createCoordinateGrid creates 3D position grid [1, d0*d1*d2, 3] +func createCoordinateGrid(d0, d1, d2, s0, s1, s2 int32) *mlx.Array { + // Create meshgrid and stack + total := d0 * d1 * d2 + coords := make([]float32, total*3) + + idx := 0 + for i := int32(0); i < d0; i++ { + for j := int32(0); j < d1; j++ { + for k := int32(0); k < d2; k++ { + coords[idx*3+0] = float32(s0 + i) + coords[idx*3+1] = float32(s1 + j) + coords[idx*3+2] = float32(s2 + k) + idx++ + } + } + } + + return mlx.NewArray(coords, []int32{1, total, 3}) +} + +// prepareRoPE3D computes cos/sin for 3-axis RoPE +// positions: [B, L, 3] with (h, w, t) coordinates +// axesDims: [32, 48, 48] - dimensions for each axis +// Returns: cos, sin each [B, L, 1, head_dim/2] +func prepareRoPE3D(positions *mlx.Array, axesDims []int32) (*mlx.Array, *mlx.Array) { + // Compute frequencies for each axis + // dims = [32, 48, 48], so halves = [16, 24, 24] + ropeTheta := float32(256.0) + + freqs := make([]*mlx.Array, 3) + for axis := 0; axis < 3; axis++ { + half := axesDims[axis] / 2 + f := make([]float32, half) + for i := int32(0); i < half; i++ { + f[i] = float32(math.Exp(-math.Log(float64(ropeTheta)) * float64(i) / float64(half))) + } + freqs[axis] = mlx.NewArray(f, []int32{1, 1, 1, half}) + } + + // Extract position coordinates + shape := positions.Shape() + B := shape[0] + L := shape[1] + + // positions[:, :, 0] -> h positions + posH := mlx.Slice(positions, []int32{0, 0, 0}, []int32{B, L, 1}) + posW := mlx.Slice(positions, []int32{0, 0, 1}, []int32{B, L, 2}) + posT := mlx.Slice(positions, []int32{0, 0, 2}, []int32{B, L, 3}) + + // Compute args: pos * freqs for each axis + posH = mlx.ExpandDims(posH, 3) // [B, L, 1, 1] + posW = mlx.ExpandDims(posW, 3) + posT = mlx.ExpandDims(posT, 3) + + argsH := mlx.Mul(posH, freqs[0]) // [B, L, 1, 16] + argsW := mlx.Mul(posW, freqs[1]) // [B, L, 1, 24] + argsT := mlx.Mul(posT, freqs[2]) // [B, L, 1, 24] + + // Concatenate: [B, L, 1, 16+24+24=64] + args := mlx.Concatenate([]*mlx.Array{argsH, argsW, argsT}, 3) + + // Compute cos and sin + return mlx.Cos(args), mlx.Sin(args) +} + +// PatchifyLatents converts latents [B, C, H, W] to patches [B, L, C*patch^2] +// Matches Python: x.reshape(C, 1, 1, H_tok, 2, W_tok, 2).transpose(1,2,3,5,4,6,0).reshape(1,-1,C*4) +func PatchifyLatents(latents *mlx.Array, patchSize int32) *mlx.Array { + shape := latents.Shape() + C := shape[1] + H := shape[2] + W := shape[3] + + pH := H / patchSize // H_tok + pW := W / patchSize // W_tok + + // Match Python exactly: reshape treating B=1 as part of contiguous data + // [1, C, H, W] -> [C, 1, 1, pH, 2, pW, 2] + x := mlx.Reshape(latents, C, 1, 1, pH, patchSize, pW, patchSize) + + // Python: transpose(1, 2, 3, 5, 4, 6, 0) + // [C, 1, 1, pH, 2, pW, 2] -> [1, 1, pH, pW, 2, 2, C] + x = mlx.Transpose(x, 1, 2, 3, 5, 4, 6, 0) + + // [1, 1, pH, pW, 2, 2, C] -> [1, pH*pW, C*4] + return mlx.Reshape(x, 1, pH*pW, C*patchSize*patchSize) +} + +// UnpatchifyLatents converts patches [B, L, C*patch^2] back to [B, C, H, W] +// Matches Python: out.reshape(1,1,H_tok,W_tok,2,2,C).transpose(6,0,1,2,4,3,5).reshape(1,C,H,W) +func UnpatchifyLatents(patches *mlx.Array, patchSize, H, W, C int32) *mlx.Array { + pH := H / patchSize + pW := W / patchSize + + // [1, L, C*4] -> [1, 1, pH, pW, 2, 2, C] + x := mlx.Reshape(patches, 1, 1, pH, pW, patchSize, patchSize, C) + + // Python: transpose(6, 0, 1, 2, 4, 3, 5) + // [1, 1, pH, pW, 2, 2, C] -> [C, 1, 1, pH, 2, pW, 2] + x = mlx.Transpose(x, 6, 0, 1, 2, 4, 3, 5) + + // [C, 1, 1, pH, 2, pW, 2] -> [1, C, H, W] + return mlx.Reshape(x, 1, C, H, W) +} diff --git a/x/imagegen/models/zimage/vae.go b/x/imagegen/models/zimage/vae.go new file mode 100644 index 000000000..fda0f63c1 --- /dev/null +++ b/x/imagegen/models/zimage/vae.go @@ -0,0 +1,652 @@ +//go:build mlx + +package zimage + +import ( + "encoding/json" + "fmt" + "math" + "os" + "path/filepath" + + "github.com/ollama/ollama/x/imagegen/mlx" + "github.com/ollama/ollama/x/imagegen/safetensors" +) + +// VAEConfig holds VAE decoder configuration +type VAEConfig struct { + InChannels int32 `json:"in_channels"` + OutChannels int32 `json:"out_channels"` + LatentChannels int32 `json:"latent_channels"` + BlockOutChannels []int32 `json:"block_out_channels"` + LayersPerBlock int32 `json:"layers_per_block"` + NormNumGroups int32 `json:"norm_num_groups"` + ScalingFactor float32 `json:"scaling_factor"` + ShiftFactor float32 `json:"shift_factor"` +} + +// loadVAEConfig loads VAE config from a JSON file +func loadVAEConfig(path string) (*VAEConfig, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("read config: %w", err) + } + var cfg VAEConfig + if err := json.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("parse config: %w", err) + } + return &cfg, nil +} + +// GroupNormLayer implements group normalization +type GroupNormLayer struct { + Weight *mlx.Array + Bias *mlx.Array + NumGroups int32 + Eps float32 +} + +// NewGroupNorm creates a group norm layer +func NewGroupNorm(weight, bias *mlx.Array, numGroups int32) *GroupNormLayer { + return &GroupNormLayer{ + Weight: weight, + Bias: bias, + NumGroups: numGroups, + Eps: 1e-5, + } +} + +// Forward applies group normalization +func (gn *GroupNormLayer) Forward(x *mlx.Array) *mlx.Array { + // x: [B, C, H, W] + shape := x.Shape() + B := shape[0] + C := shape[1] + H := shape[2] + W := shape[3] + + // Reshape to [B, groups, C/groups, H, W] + groupSize := C / gn.NumGroups + x = mlx.Reshape(x, B, gn.NumGroups, groupSize, H, W) + + // Compute mean and variance per group + mean := mlx.Mean(x, 2, true) + mean = mlx.Mean(mean, 3, true) + mean = mlx.Mean(mean, 4, true) + + xCentered := mlx.Sub(x, mean) + variance := mlx.Mean(mlx.Square(xCentered), 2, true) + variance = mlx.Mean(variance, 3, true) + variance = mlx.Mean(variance, 4, true) + + // Normalize + xNorm := mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, gn.Eps))) + + // Reshape back to [B, C, H, W] + xNorm = mlx.Reshape(xNorm, B, C, H, W) + + // Scale and shift (weight and bias are [C]) + if gn.Weight != nil { + weight := mlx.Reshape(gn.Weight, 1, C, 1, 1) + xNorm = mlx.Mul(xNorm, weight) + } + if gn.Bias != nil { + bias := mlx.Reshape(gn.Bias, 1, C, 1, 1) + xNorm = mlx.Add(xNorm, bias) + } + + return xNorm +} + +// Conv2D represents a 2D convolution layer +// MLX uses NHWC format, but we store weights in OHWI format for MLX conv +type Conv2D struct { + Weight *mlx.Array // [out_channels, kH, kW, in_channels] (OHWI for MLX) + Bias *mlx.Array // [out_channels] + Stride int32 + Padding int32 +} + +// NewConv2D creates a Conv2D layer +// weight comes in as [out_channels, in_channels, kH, kW] (OIHW from PyTorch) +// we transpose to [out_channels, kH, kW, in_channels] (OHWI for MLX) +func NewConv2D(weight, bias *mlx.Array, stride, padding int32) *Conv2D { + // Transpose weight from OIHW to OHWI + // [O, I, H, W] -> [O, H, W, I] + weightOHWI := mlx.Transpose(weight, 0, 2, 3, 1) + return &Conv2D{ + Weight: weightOHWI, + Bias: bias, + Stride: stride, + Padding: padding, + } +} + +// Forward applies convolution +// Input x is in NCHW format, we convert to NHWC for MLX, then back to NCHW +func (conv *Conv2D) Forward(x *mlx.Array) *mlx.Array { + // x: [N, C, H, W] -> [N, H, W, C] + xNHWC := mlx.Transpose(x, 0, 2, 3, 1) + + // Conv in NHWC format + outNHWC := mlx.Conv2d(xNHWC, conv.Weight, conv.Stride, conv.Padding) + + // Convert back to NCHW: [N, H, W, C] -> [N, C, H, W] + out := mlx.Transpose(outNHWC, 0, 3, 1, 2) + + if conv.Bias != nil { + bias := mlx.Reshape(conv.Bias, 1, conv.Bias.Dim(0), 1, 1) + out = mlx.Add(out, bias) + } + return out +} + +// ResnetBlock2D implements a ResNet block for VAE +type ResnetBlock2D struct { + Norm1 *GroupNormLayer + Conv1 *Conv2D + Norm2 *GroupNormLayer + Conv2 *Conv2D + ConvShortcut *Conv2D // nil if in_channels == out_channels +} + +// NewResnetBlock2D creates a ResNet block +func NewResnetBlock2D(weights *safetensors.ModelWeights, prefix string, numGroups int32) (*ResnetBlock2D, error) { + norm1Weight, err := weights.GetTensor(prefix + ".norm1.weight") + if err != nil { + return nil, err + } + norm1Bias, err := weights.GetTensor(prefix + ".norm1.bias") + if err != nil { + return nil, err + } + + conv1Weight, err := weights.GetTensor(prefix + ".conv1.weight") + if err != nil { + return nil, err + } + conv1Bias, err := weights.GetTensor(prefix + ".conv1.bias") + if err != nil { + return nil, err + } + + norm2Weight, err := weights.GetTensor(prefix + ".norm2.weight") + if err != nil { + return nil, err + } + norm2Bias, err := weights.GetTensor(prefix + ".norm2.bias") + if err != nil { + return nil, err + } + + conv2Weight, err := weights.GetTensor(prefix + ".conv2.weight") + if err != nil { + return nil, err + } + conv2Bias, err := weights.GetTensor(prefix + ".conv2.bias") + if err != nil { + return nil, err + } + + block := &ResnetBlock2D{ + Norm1: NewGroupNorm(norm1Weight, norm1Bias, numGroups), + Conv1: NewConv2D(conv1Weight, conv1Bias, 1, 1), + Norm2: NewGroupNorm(norm2Weight, norm2Bias, numGroups), + Conv2: NewConv2D(conv2Weight, conv2Bias, 1, 1), + } + + if weights.HasTensor(prefix + ".conv_shortcut.weight") { + shortcutWeight, err := weights.GetTensor(prefix + ".conv_shortcut.weight") + if err != nil { + return nil, err + } + shortcutBias, err := weights.GetTensor(prefix + ".conv_shortcut.bias") + if err != nil { + return nil, err + } + block.ConvShortcut = NewConv2D(shortcutWeight, shortcutBias, 1, 0) + } + + return block, nil +} + +// Forward applies the ResNet block with staged evaluation +func (rb *ResnetBlock2D) Forward(x *mlx.Array) *mlx.Array { + var h *mlx.Array + + // Stage 1: norm1 + { + h = rb.Norm1.Forward(x) + mlx.Eval(h) + } + + // Stage 2: silu + conv1 + { + prev := h + h = mlx.SiLU(h) + h = rb.Conv1.Forward(h) + prev.Free() + mlx.Eval(h) + } + + // Stage 3: norm2 + { + prev := h + h = rb.Norm2.Forward(h) + prev.Free() + mlx.Eval(h) + } + + // Stage 4: silu + conv2 + { + prev := h + h = mlx.SiLU(h) + h = rb.Conv2.Forward(h) + prev.Free() + mlx.Eval(h) + } + + // Residual connection + { + prev := h + if rb.ConvShortcut != nil { + shortcut := rb.ConvShortcut.Forward(x) + h = mlx.Add(h, shortcut) + } else { + h = mlx.Add(h, x) + } + prev.Free() + mlx.Eval(h) + } + + return h +} + +// VAEAttentionBlock implements self-attention for VAE +type VAEAttentionBlock struct { + GroupNorm *GroupNormLayer + ToQWeight *mlx.Array + ToQBias *mlx.Array + ToKWeight *mlx.Array + ToKBias *mlx.Array + ToVWeight *mlx.Array + ToVBias *mlx.Array + ToOutWeight *mlx.Array + ToOutBias *mlx.Array + NumHeads int32 +} + +// NewVAEAttentionBlock creates an attention block +func NewVAEAttentionBlock(weights *safetensors.ModelWeights, prefix string, numGroups int32) (*VAEAttentionBlock, error) { + normWeight, err := weights.GetTensor(prefix + ".group_norm.weight") + if err != nil { + return nil, err + } + normBias, err := weights.GetTensor(prefix + ".group_norm.bias") + if err != nil { + return nil, err + } + + toQWeight, err := weights.GetTensor(prefix + ".to_q.weight") + if err != nil { + return nil, err + } + toQBias, err := weights.GetTensor(prefix + ".to_q.bias") + if err != nil { + return nil, err + } + + toKWeight, err := weights.GetTensor(prefix + ".to_k.weight") + if err != nil { + return nil, err + } + toKBias, err := weights.GetTensor(prefix + ".to_k.bias") + if err != nil { + return nil, err + } + + toVWeight, err := weights.GetTensor(prefix + ".to_v.weight") + if err != nil { + return nil, err + } + toVBias, err := weights.GetTensor(prefix + ".to_v.bias") + if err != nil { + return nil, err + } + + toOutWeight, err := weights.GetTensor(prefix + ".to_out.0.weight") + if err != nil { + return nil, err + } + toOutBias, err := weights.GetTensor(prefix + ".to_out.0.bias") + if err != nil { + return nil, err + } + + return &VAEAttentionBlock{ + GroupNorm: NewGroupNorm(normWeight, normBias, numGroups), + ToQWeight: mlx.Transpose(toQWeight, 1, 0), + ToQBias: toQBias, + ToKWeight: mlx.Transpose(toKWeight, 1, 0), + ToKBias: toKBias, + ToVWeight: mlx.Transpose(toVWeight, 1, 0), + ToVBias: toVBias, + ToOutWeight: mlx.Transpose(toOutWeight, 1, 0), + ToOutBias: toOutBias, + NumHeads: 1, + }, nil +} + +// Forward applies attention with staged evaluation +func (ab *VAEAttentionBlock) Forward(x *mlx.Array) *mlx.Array { + residual := x + shape := x.Shape() + B := shape[0] + C := shape[1] + H := shape[2] + W := shape[3] + + var h *mlx.Array + + // Stage 1: GroupNorm + reshape + { + h = ab.GroupNorm.Forward(x) + h = mlx.Transpose(h, 0, 2, 3, 1) + h = mlx.Reshape(h, B, H*W, C) + mlx.Eval(h) + } + + var out *mlx.Array + + // Stage 2: Q, K, V projections + attention + { + q := mlx.Linear(h, ab.ToQWeight) + q = mlx.Add(q, ab.ToQBias) + k := mlx.Linear(h, ab.ToKWeight) + k = mlx.Add(k, ab.ToKBias) + v := mlx.Linear(h, ab.ToVWeight) + v = mlx.Add(v, ab.ToVBias) + h.Free() + + q = mlx.ExpandDims(q, 1) + k = mlx.ExpandDims(k, 1) + v = mlx.ExpandDims(v, 1) + + scale := float32(1.0 / math.Sqrt(float64(C))) + out = mlx.ScaledDotProductAttention(q, k, v, scale, false) + out = mlx.Squeeze(out, 1) + mlx.Eval(out) + } + + // Stage 3: Output projection + reshape + residual + { + prev := out + out = mlx.Linear(out, ab.ToOutWeight) + out = mlx.Add(out, ab.ToOutBias) + out = mlx.Reshape(out, B, H, W, C) + out = mlx.Transpose(out, 0, 3, 1, 2) + out = mlx.Add(out, residual) + prev.Free() + mlx.Eval(out) + } + + return out +} + +// UpDecoderBlock2D implements an upsampling decoder block +type UpDecoderBlock2D struct { + ResnetBlocks []*ResnetBlock2D + Upsample *Conv2D +} + +// NewUpDecoderBlock2D creates an up decoder block +func NewUpDecoderBlock2D(weights *safetensors.ModelWeights, prefix string, numLayers, numGroups int32, hasUpsample bool) (*UpDecoderBlock2D, error) { + resnets := make([]*ResnetBlock2D, numLayers) + for i := int32(0); i < numLayers; i++ { + resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i) + resnet, err := NewResnetBlock2D(weights, resPrefix, numGroups) + if err != nil { + return nil, err + } + resnets[i] = resnet + } + + var upsample *Conv2D + if hasUpsample { + upWeight, err := weights.GetTensor(prefix + ".upsamplers.0.conv.weight") + if err != nil { + return nil, err + } + upBias, err := weights.GetTensor(prefix + ".upsamplers.0.conv.bias") + if err != nil { + return nil, err + } + upsample = NewConv2D(upWeight, upBias, 1, 1) + } + + return &UpDecoderBlock2D{ + ResnetBlocks: resnets, + Upsample: upsample, + }, nil +} + +// Forward applies the up decoder block with staged evaluation to reduce peak memory +func (ub *UpDecoderBlock2D) Forward(x *mlx.Array) *mlx.Array { + for _, resnet := range ub.ResnetBlocks { + prev := x + x = resnet.Forward(x) // ResNet handles its own pools + prev.Free() + } + + if ub.Upsample != nil { + // Stage 1: Upsample2x (nearest neighbor) + { + prev := x + x = Upsample2x(x) + prev.Free() + mlx.Eval(x) + } + + // Stage 2: Upsample conv + { + prev := x + x = ub.Upsample.Forward(x) + prev.Free() + mlx.Eval(x) + } + } + + return x +} + +// VAEMidBlock is the middle block with attention +type VAEMidBlock struct { + Resnet1 *ResnetBlock2D + Attention *VAEAttentionBlock + Resnet2 *ResnetBlock2D +} + +// NewVAEMidBlock creates the mid block +func NewVAEMidBlock(weights *safetensors.ModelWeights, prefix string, numGroups int32) (*VAEMidBlock, error) { + resnet1, err := NewResnetBlock2D(weights, prefix+".resnets.0", numGroups) + if err != nil { + return nil, err + } + + attention, err := NewVAEAttentionBlock(weights, prefix+".attentions.0", numGroups) + if err != nil { + return nil, err + } + + resnet2, err := NewResnetBlock2D(weights, prefix+".resnets.1", numGroups) + if err != nil { + return nil, err + } + + return &VAEMidBlock{ + Resnet1: resnet1, + Attention: attention, + Resnet2: resnet2, + }, nil +} + +// Forward applies the mid block with staged evaluation +func (mb *VAEMidBlock) Forward(x *mlx.Array) *mlx.Array { + prev := x + x = mb.Resnet1.Forward(x) // ResNet handles its own pools + prev.Free() + + // Attention handles its own pools + prev = x + x = mb.Attention.Forward(x) + prev.Free() + + prev = x + x = mb.Resnet2.Forward(x) // ResNet handles its own pools + prev.Free() + + return x +} + +// VAEDecoder is the full VAE decoder +type VAEDecoder struct { + Config *VAEConfig + ConvIn *Conv2D + MidBlock *VAEMidBlock + UpBlocks []*UpDecoderBlock2D + ConvNormOut *GroupNormLayer + ConvOut *Conv2D +} + +// Load loads the VAE decoder from a directory +func (m *VAEDecoder) Load(path string) error { + fmt.Println("Loading VAE decoder...") + + // Load config + cfg, err := loadVAEConfig(filepath.Join(path, "config.json")) + if err != nil { + return fmt.Errorf("config: %w", err) + } + m.Config = cfg + + // Load weights + weights, err := safetensors.LoadModelWeights(path) + if err != nil { + return fmt.Errorf("weights: %w", err) + } + + // Load conv_in + fmt.Print(" Loading conv_in... ") + convInWeight, err := weights.GetTensor("decoder.conv_in.weight") + if err != nil { + return err + } + convInBias, err := weights.GetTensor("decoder.conv_in.bias") + if err != nil { + return err + } + m.ConvIn = NewConv2D(convInWeight, convInBias, 1, 1) + fmt.Println("✓") + + // Load mid block + fmt.Print(" Loading mid block... ") + m.MidBlock, err = NewVAEMidBlock(weights, "decoder.mid_block", cfg.NormNumGroups) + if err != nil { + return err + } + fmt.Println("✓") + + // Load up blocks + fmt.Print(" Loading up blocks... ") + numBlocks := len(cfg.BlockOutChannels) + m.UpBlocks = make([]*UpDecoderBlock2D, numBlocks) + for i := 0; i < numBlocks; i++ { + prefix := fmt.Sprintf("decoder.up_blocks.%d", i) + hasUpsample := i < numBlocks-1 + m.UpBlocks[i], err = NewUpDecoderBlock2D(weights, prefix, cfg.LayersPerBlock+1, cfg.NormNumGroups, hasUpsample) + if err != nil { + return err + } + } + fmt.Printf("✓ [%d blocks]\n", numBlocks) + + // Load conv_norm_out + fmt.Print(" Loading conv_norm_out... ") + normWeight, err := weights.GetTensor("decoder.conv_norm_out.weight") + if err != nil { + return err + } + normBias, err := weights.GetTensor("decoder.conv_norm_out.bias") + if err != nil { + return err + } + m.ConvNormOut = NewGroupNorm(normWeight, normBias, cfg.NormNumGroups) + fmt.Println("✓") + + // Load conv_out + fmt.Print(" Loading conv_out... ") + convOutWeight, err := weights.GetTensor("decoder.conv_out.weight") + if err != nil { + return err + } + convOutBias, err := weights.GetTensor("decoder.conv_out.bias") + if err != nil { + return err + } + m.ConvOut = NewConv2D(convOutWeight, convOutBias, 1, 1) + fmt.Println("✓") + + weights.ReleaseAll() + return nil +} + +// Decode decodes latents to images. +// Uses staged pools to free intermediate arrays and reduce peak memory. +func (vae *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array { + var h *mlx.Array + { + z := mlx.DivScalar(latents, vae.Config.ScalingFactor) + z = mlx.AddScalar(z, vae.Config.ShiftFactor) + h = vae.ConvIn.Forward(z) + mlx.Eval(h) + } + + h = vae.MidBlock.Forward(h) + + for _, upBlock := range vae.UpBlocks { + h = upBlock.Forward(h) + } + + { + prev := h + h = vae.ConvNormOut.Forward(h) + h = mlx.SiLU(h) + h = vae.ConvOut.Forward(h) + // VAE outputs [-1, 1], convert to [0, 1] + h = mlx.AddScalar(mlx.MulScalar(h, 0.5), 0.5) + h = mlx.ClipScalar(h, 0.0, 1.0, true, true) + prev.Free() + mlx.Eval(h) + } + + return h +} + +// Upsample2x performs 2x nearest neighbor upsampling using broadcast. +// x: [B, C, H, W] -> [B, C, H*2, W*2] +func Upsample2x(x *mlx.Array) *mlx.Array { + shape := x.Shape() + B := shape[0] + C := shape[1] + H := shape[2] + W := shape[3] + + // [B, C, H, W] -> [B, C, H, 1, W, 1] + x = mlx.Reshape(x, B, C, H, 1, W, 1) + // Broadcast to [B, C, H, 2, W, 2] + x = mlx.BroadcastTo(x, []int32{B, C, H, 2, W, 2}) + // Reshape to [B, C, H*2, W*2] + x = mlx.Reshape(x, B, C, H*2, W*2) + + return x +} diff --git a/x/imagegen/models/zimage/zimage.go b/x/imagegen/models/zimage/zimage.go new file mode 100644 index 000000000..44b1fbdc4 --- /dev/null +++ b/x/imagegen/models/zimage/zimage.go @@ -0,0 +1,363 @@ +//go:build mlx + +// Package zimage implements the Z-Image diffusion transformer model. +package zimage + +import ( + "context" + "fmt" + "path/filepath" + "time" + + "github.com/ollama/ollama/x/imagegen/cache" + "github.com/ollama/ollama/x/imagegen/mlx" + "github.com/ollama/ollama/x/imagegen/tokenizer" +) + +// GenerateConfig holds all options for image generation. +type GenerateConfig struct { + Prompt string + NegativePrompt string // Empty = no CFG + CFGScale float32 // Only used if NegativePrompt is set (default: 4.0) + Width int32 // Image width (default: 1024) + Height int32 // Image height (default: 1024) + Steps int // Denoising steps (default: 9 for turbo) + Seed int64 // Random seed + Progress ProgressFunc // Optional progress callback + CapturePath string // GPU capture path (debug) + + // Layer caching options (speedup via shallow layer reuse) + LayerCache bool // Enable layer caching (default: false) + CacheInterval int // Refresh cache every N steps (default: 3) + CacheLayers int // Number of shallow layers to cache (default: 15) +} + +// ProgressFunc is called during generation with step progress. +type ProgressFunc func(step, totalSteps int) + +// Model represents a Z-Image diffusion model. +type Model struct { + ModelPath string + Tokenizer *tokenizer.Tokenizer + TextEncoder *Qwen3TextEncoder + Transformer *Transformer + VAEDecoder *VAEDecoder +} + +// Load loads the Z-Image model from a directory. +func (m *Model) Load(modelPath string) error { + fmt.Println("Loading Z-Image model...") + start := time.Now() + + if mlx.GPUIsAvailable() { + mlx.SetDefaultDeviceGPU() + mlx.EnableCompile() + } + + m.ModelPath = modelPath + + // Load tokenizer + fmt.Print(" Loading tokenizer... ") + tokenizerPath := filepath.Join(modelPath, "tokenizer", "tokenizer.json") + tok, err := tokenizer.Load(tokenizerPath) + if err != nil { + return fmt.Errorf("tokenizer: %w", err) + } + m.Tokenizer = tok + fmt.Println("✓") + + // Load text encoder + m.TextEncoder = &Qwen3TextEncoder{} + if err := m.TextEncoder.Load(filepath.Join(modelPath, "text_encoder")); err != nil { + return fmt.Errorf("text encoder: %w", err) + } + mlx.Eval(mlx.Collect(m.TextEncoder)...) + fmt.Printf(" (%.1f GB, peak %.1f GB)\n", + float64(mlx.MetalGetActiveMemory())/(1024*1024*1024), + float64(mlx.MetalGetPeakMemory())/(1024*1024*1024)) + + // Load transformer + m.Transformer = &Transformer{} + if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil { + return fmt.Errorf("transformer: %w", err) + } + mlx.Eval(mlx.Collect(m.Transformer)...) + fmt.Printf(" (%.1f GB, peak %.1f GB)\n", + float64(mlx.MetalGetActiveMemory())/(1024*1024*1024), + float64(mlx.MetalGetPeakMemory())/(1024*1024*1024)) + + // Load VAE decoder + m.VAEDecoder = &VAEDecoder{} + if err := m.VAEDecoder.Load(filepath.Join(modelPath, "vae")); err != nil { + return fmt.Errorf("VAE decoder: %w", err) + } + mlx.Eval(mlx.Collect(m.VAEDecoder)...) + fmt.Printf(" (%.1f GB, peak %.1f GB)\n", + float64(mlx.MetalGetActiveMemory())/(1024*1024*1024), + float64(mlx.MetalGetPeakMemory())/(1024*1024*1024)) + + mem := mlx.MetalGetActiveMemory() + fmt.Printf(" Loaded in %.2fs (%.1f GB VRAM)\n", time.Since(start).Seconds(), float64(mem)/(1024*1024*1024)) + + return nil +} + +// Generate creates an image from a prompt. +func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) { + return m.GenerateFromConfig(&GenerateConfig{ + Prompt: prompt, + Width: width, + Height: height, + Steps: steps, + Seed: seed, + }) +} + +// GenerateWithProgress creates an image with progress callback. +func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) { + return m.GenerateFromConfig(&GenerateConfig{ + Prompt: prompt, + Width: width, + Height: height, + Steps: steps, + Seed: seed, + Progress: progress, + }) +} + +// GenerateWithCFG creates an image with classifier-free guidance. +func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress ProgressFunc) (*mlx.Array, error) { + return m.GenerateFromConfig(&GenerateConfig{ + Prompt: prompt, + NegativePrompt: negativePrompt, + CFGScale: cfgScale, + Width: width, + Height: height, + Steps: steps, + Seed: seed, + Progress: progress, + }) +} + +// GenerateFromConfig generates an image using the unified config struct. +func (m *Model) GenerateFromConfig(cfg *GenerateConfig) (*mlx.Array, error) { + start := time.Now() + result, err := m.generate(cfg) + if err != nil { + return nil, err + } + if cfg.NegativePrompt != "" { + fmt.Printf("Generated with CFG (scale=%.1f) in %.2fs (%d steps)\n", cfg.CFGScale, time.Since(start).Seconds(), cfg.Steps) + } else { + fmt.Printf("Generated in %.2fs (%d steps)\n", time.Since(start).Seconds(), cfg.Steps) + } + return result, nil +} + +// GenerateImage implements model.ImageModel interface. +func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) { + return m.Generate(prompt, width, height, steps, seed) +} + +// generate is the internal denoising pipeline. +func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) { + // Apply defaults + if cfg.Width <= 0 { + cfg.Width = 1024 + } + if cfg.Height <= 0 { + cfg.Height = 1024 + } + if cfg.Steps <= 0 { + cfg.Steps = 9 // Turbo default + } + if cfg.CFGScale <= 0 { + cfg.CFGScale = 4.0 + } + if cfg.LayerCache { + if cfg.CacheInterval <= 0 { + cfg.CacheInterval = 3 + } + if cfg.CacheLayers <= 0 { + cfg.CacheLayers = 15 // Half of 30 layers + } + } + + useCFG := cfg.NegativePrompt != "" + tcfg := m.Transformer.TransformerConfig + latentH := cfg.Height / 8 + latentW := cfg.Width / 8 + hTok := latentH / tcfg.PatchSize + wTok := latentW / tcfg.PatchSize + + // Text encoding with padding to multiple of 32 + var posEmb, negEmb *mlx.Array + { + posEmb, _ = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt, 512) + if useCFG { + negEmb, _ = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.NegativePrompt, 512) + } + + // Pad both to same length (multiple of 32) + maxLen := posEmb.Shape()[1] + if useCFG && negEmb.Shape()[1] > maxLen { + maxLen = negEmb.Shape()[1] + } + if pad := (32 - (maxLen % 32)) % 32; pad > 0 { + maxLen += pad + } + + posEmb = padToLength(posEmb, maxLen) + if useCFG { + negEmb = padToLength(negEmb, maxLen) + mlx.Keep(posEmb, negEmb) + mlx.Eval(posEmb, negEmb) + } else { + mlx.Keep(posEmb) + mlx.Eval(posEmb) + } + } + + // Scheduler + scheduler := NewFlowMatchEulerScheduler(DefaultFlowMatchSchedulerConfig()) + scheduler.SetTimestepsWithMu(cfg.Steps, CalculateShift(hTok*wTok)) + + // Init latents [B, C, H, W] + var latents *mlx.Array + { + latents = scheduler.InitNoise([]int32{1, tcfg.InChannels, latentH, latentW}, cfg.Seed) + mlx.Eval(latents) + } + + // RoPE cache + var ropeCache *RoPECache + { + ropeCache = m.Transformer.PrepareRoPECache(hTok, wTok, posEmb.Shape()[1]) + mlx.Keep(ropeCache.ImgCos, ropeCache.ImgSin, ropeCache.CapCos, ropeCache.CapSin, + ropeCache.UnifiedCos, ropeCache.UnifiedSin) + mlx.Eval(ropeCache.UnifiedCos) + } + + // Step cache for shallow layer reuse (DeepCache/Learning-to-Cache style) + var stepCache *cache.StepCache + if cfg.LayerCache { + stepCache = cache.NewStepCache(cfg.CacheLayers) + fmt.Printf(" Layer caching enabled: %d layers, refresh every %d steps\n", + cfg.CacheLayers, cfg.CacheInterval) + } + + // Denoising loop + for i := 0; i < cfg.Steps; i++ { + stepStart := time.Now() + if cfg.Progress != nil { + cfg.Progress(i+1, cfg.Steps) + } + + // GPU capture on step 2 if requested + if cfg.CapturePath != "" && i == 1 { + mlx.MetalStartCapture(cfg.CapturePath) + } + + tCurr := scheduler.Timesteps[i] + timestep := mlx.ToBFloat16(mlx.NewArray([]float32{1.0 - tCurr}, []int32{1})) + + patches := PatchifyLatents(latents, tcfg.PatchSize) + + var output *mlx.Array + if stepCache != nil { + // Use layer caching for faster inference + if useCFG { + posOutput := m.Transformer.ForwardWithCache(patches, timestep, posEmb, ropeCache, + stepCache, i, cfg.CacheInterval) + // Note: CFG with layer cache shares the cache between pos/neg + // This is approximate but fast - neg prompt uses same cached shallow layers + negOutput := m.Transformer.ForwardWithCache(patches, timestep, negEmb, ropeCache, + stepCache, i, cfg.CacheInterval) + diff := mlx.Sub(posOutput, negOutput) + scaledDiff := mlx.MulScalar(diff, cfg.CFGScale) + output = mlx.Add(negOutput, scaledDiff) + } else { + output = m.Transformer.ForwardWithCache(patches, timestep, posEmb, ropeCache, + stepCache, i, cfg.CacheInterval) + } + } else { + // Standard forward without caching + if useCFG { + posOutput := m.Transformer.Forward(patches, timestep, posEmb, ropeCache) + negOutput := m.Transformer.Forward(patches, timestep, negEmb, ropeCache) + diff := mlx.Sub(posOutput, negOutput) + scaledDiff := mlx.MulScalar(diff, cfg.CFGScale) + output = mlx.Add(negOutput, scaledDiff) + } else { + output = m.Transformer.Forward(patches, timestep, posEmb, ropeCache) + } + } + + noisePred := UnpatchifyLatents(output, tcfg.PatchSize, latentH, latentW, tcfg.InChannels) + noisePred = mlx.Neg(noisePred) + oldLatents := latents + latents = scheduler.Step(noisePred, latents, i) + + // Keep latents and any cached arrays + if stepCache != nil { + mlx.Keep(stepCache.Arrays()...) + } + mlx.Eval(latents) + oldLatents.Free() + + if cfg.CapturePath != "" && i == 1 { + mlx.MetalStopCapture() + } + + activeMem := float64(mlx.MetalGetActiveMemory()) / (1024 * 1024 * 1024) + peakMem := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024) + fmt.Printf(" Step %d/%d: t=%.4f (%.2fs) [%.1f GB active, %.1f GB peak]\n", + i+1, cfg.Steps, tCurr, time.Since(stepStart).Seconds(), activeMem, peakMem) + } + + // Free denoising temporaries before VAE decode + posEmb.Free() + if negEmb != nil { + negEmb.Free() + } + ropeCache.ImgCos.Free() + ropeCache.ImgSin.Free() + ropeCache.CapCos.Free() + ropeCache.CapSin.Free() + ropeCache.UnifiedCos.Free() + ropeCache.UnifiedSin.Free() + if stepCache != nil { + stepCache.Free() + } + + // VAE decode + decoded := m.VAEDecoder.Decode(latents) + latents.Free() + + return decoded, nil +} + +// padToLength pads a sequence tensor to the target length by repeating the last token. +func padToLength(x *mlx.Array, targetLen int32) *mlx.Array { + shape := x.Shape() + currentLen := shape[1] + if currentLen >= targetLen { + return x + } + padLen := targetLen - currentLen + lastToken := mlx.Slice(x, []int32{0, currentLen - 1, 0}, []int32{shape[0], currentLen, shape[2]}) + padding := mlx.Tile(lastToken, []int32{1, padLen, 1}) + return mlx.Concatenate([]*mlx.Array{x, padding}, 1) +} + +// CalculateShift computes the mu shift value for dynamic scheduling +func CalculateShift(imgSeqLen int32) float32 { + baseSeqLen := float32(256) + maxSeqLen := float32(4096) + baseShift := float32(0.5) + maxShift := float32(1.15) + + m := (maxShift - baseShift) / (maxSeqLen - baseSeqLen) + b := baseShift - m*baseSeqLen + return float32(imgSeqLen)*m + b +} diff --git a/x/imagegen/nn/nn.go b/x/imagegen/nn/nn.go new file mode 100644 index 000000000..c61e59939 --- /dev/null +++ b/x/imagegen/nn/nn.go @@ -0,0 +1,203 @@ +//go:build mlx + +// Package nn provides neural network layer types. +package nn + +import "github.com/ollama/ollama/x/imagegen/mlx" + +// Layer is the interface for neural network layers with a Forward method. +type Layer interface { + Forward(x *mlx.Array) *mlx.Array +} + +// Linear applies an affine transformation: y = x @ W.T + b +// Weight is stored as [out_features, in_features], matching PyTorch/MLX convention. +type Linear struct { + Weight *mlx.Array `weight:"weight"` // [out_features, in_features] + Bias *mlx.Array `weight:"bias,optional"` // [out_features] or nil +} + +// NewLinear creates a linear layer. +// Weight should be [out_features, in_features]. +func NewLinear(weight *mlx.Array, bias *mlx.Array) *Linear { + return &Linear{Weight: weight, Bias: bias} +} + +// NewQuantizedLinear creates a quantized linear layer directly from bf16 weights. +// Quantizes the weight immediately and evaluates to break lazy dependencies. +func NewQuantizedLinear(weight *mlx.Array, bias *mlx.Array, groupSize, bits int, mode string) *QuantizedLinear { + qw, scales, qbiases := mlx.Quantize(weight, groupSize, bits, mode) + // Eval immediately so bf16 weight can be freed + mlx.Eval(qw, scales, qbiases) + return &QuantizedLinear{ + Weight: qw, + Scales: scales, + QBiases: qbiases, + Bias: bias, + GroupSize: groupSize, + Bits: bits, + Mode: mode, + } +} + +// Forward applies the linear transformation: x @ W.T + bias +func (l *Linear) Forward(x *mlx.Array) *mlx.Array { + w := mlx.Transpose(l.Weight, 1, 0) + if l.Bias != nil { + return mlx.AddMM(l.Bias, x, w, 1.0, 1.0) + } + return mlx.Linear(x, w) +} + +// ToQuantized converts this Linear to a QuantizedLinear. +func (l *Linear) ToQuantized(groupSize, bits int, mode string) *QuantizedLinear { + qw, scales, qbiases := mlx.Quantize(l.Weight, groupSize, bits, mode) + return &QuantizedLinear{ + Weight: qw, + Scales: scales, + QBiases: qbiases, + Bias: l.Bias, + GroupSize: groupSize, + Bits: bits, + Mode: mode, + } +} + +// QuantizedLinear applies an affine transformation using quantized weights. +// Equivalent to mlx.nn.QuantizedLinear. +type QuantizedLinear struct { + Weight *mlx.Array // Quantized weight data + Scales *mlx.Array // Scale factors for dequantization + QBiases *mlx.Array // Quantization biases (NOT layer bias) + Bias *mlx.Array // Layer bias [output_dims] or nil + GroupSize int + Bits int + Mode string +} + +// Forward applies the quantized linear transformation. +func (ql *QuantizedLinear) Forward(x *mlx.Array) *mlx.Array { + out := mlx.QuantizedMatmul(x, ql.Weight, ql.Scales, ql.QBiases, true, ql.GroupSize, ql.Bits, ql.Mode) + if ql.Bias != nil { + out = mlx.Add(out, ql.Bias) + } + return out +} + +// RMSNorm represents an RMS normalization layer. +type RMSNorm struct { + Weight *mlx.Array `weight:"weight"` + Eps float32 // optional: used if Forward called with eps=0 +} + +// NewRMSNorm creates an RMSNorm layer (for models not using weight loader). +func NewRMSNorm(weight *mlx.Array, eps float32) *RMSNorm { + return &RMSNorm{Weight: weight, Eps: eps} +} + +// Forward applies RMS normalization. If eps=0, uses stored Eps. +func (rn *RMSNorm) Forward(x *mlx.Array, eps float32) *mlx.Array { + if eps == 0 { + eps = rn.Eps + } + return mlx.RMSNorm(x, rn.Weight, eps) +} + +// Embedding represents an embedding layer. +type Embedding struct { + Weight *mlx.Array `weight:"weight"` +} + +// NewEmbedding creates an embedding layer. +func NewEmbedding(weight *mlx.Array) *Embedding { + return &Embedding{Weight: weight} +} + +// Forward looks up embeddings by indices. +func (e *Embedding) Forward(indices *mlx.Array) *mlx.Array { + return mlx.Take(e.Weight, indices, 0) +} + +// RepeatKV repeats K/V tensors for grouped query attention +// x: [B, num_kv_heads, S, head_dim] -> [B, num_heads, S, head_dim] +func RepeatKV(x *mlx.Array, repeatFactor int32) *mlx.Array { + if repeatFactor == 1 { + return x + } + shape := x.Shape() + // [B, num_kv_heads, S, head_dim] -> [B, num_kv_heads, 1, S, head_dim] + x = mlx.ExpandDims(x, 2) + // Repeat along the new axis + reps := []int32{1, 1, repeatFactor, 1, 1} + x = mlx.Tile(x, reps) + // Reshape: [B, num_kv_heads, repeat, S, head_dim] -> [B, num_kv_heads * repeat, S, head_dim] + return mlx.Reshape(x, shape[0], shape[1]*repeatFactor, shape[2], shape[3]) +} + +// ApplyCausalMask applies causal (lower triangular) mask to attention scores +func ApplyCausalMask(scores *mlx.Array) *mlx.Array { + // scores: [B, num_heads, S, S] + shape := scores.Shape() + seqLen := shape[2] + + // Create causal mask: 1 for positions to keep, 0 for positions to mask + mask := mlx.Tri(seqLen, seqLen, 0) + + // Where mask is 0, set score to -inf + negInf := mlx.NewScalarArray(float32(-1e9)) + + // Broadcast mask to match scores shape + mask = mlx.ExpandDims(mlx.ExpandDims(mask, 0), 0) // [1, 1, S, S] + + // Use where: if mask > 0, keep scores, else -inf + return mlx.Where(mask, scores, negInf) +} + +// ApplyCausalMaskWithOffset applies causal mask for cached attention +// scores: [B, num_heads, queryLen, keyLen] where keyLen = cacheLen + queryLen +// offset: the starting position of the new queries (i.e., cache length) +func ApplyCausalMaskWithOffset(scores *mlx.Array, offset int32) *mlx.Array { + if offset == 0 { + return ApplyCausalMask(scores) + } + + shape := scores.Shape() + queryLen := shape[2] + keyLen := shape[3] + + // For cached attention, new queries can attend to all cached keys plus + // new keys up to and including their position. + mask := mlx.Tri(queryLen, keyLen, int(offset)) + + negInf := mlx.NewScalarArray(float32(-1e9)) + mask = mlx.ExpandDims(mlx.ExpandDims(mask, 0), 0) // [1, 1, queryLen, keyLen] + + return mlx.Where(mask, scores, negInf) +} + +// LayerNorm represents a standard layer normalization layer (with bias). +type LayerNorm struct { + Weight *mlx.Array `weight:"weight"` + Bias *mlx.Array `weight:"bias"` + Eps float32 +} + +// Forward applies layer normalization: (x - mean) / sqrt(var + eps) * weight + bias +func (ln *LayerNorm) Forward(x *mlx.Array) *mlx.Array { + eps := ln.Eps + if eps == 0 { + eps = 1e-5 + } + // Compute mean and variance along last dimension + mean := mlx.Mean(x, -1, true) + centered := mlx.Sub(x, mean) + variance := mlx.Mean(mlx.Mul(centered, centered), -1, true) + normalized := mlx.Mul(centered, mlx.RSqrt(mlx.AddScalar(variance, eps))) + + // Scale and shift + out := mlx.Mul(normalized, ln.Weight) + if ln.Bias != nil { + out = mlx.Add(out, ln.Bias) + } + return out +} diff --git a/x/imagegen/nn/nn_test.go b/x/imagegen/nn/nn_test.go new file mode 100644 index 000000000..2f8c04762 --- /dev/null +++ b/x/imagegen/nn/nn_test.go @@ -0,0 +1,356 @@ +//go:build mlx + +package nn + +import ( + "math" + "testing" + + "github.com/ollama/ollama/x/imagegen/mlx" +) + +// TestLinearNoBias verifies Linear without bias computes x @ w.T correctly. +func TestLinearNoBias(t *testing.T) { + // Weight: [out=2, in=3] -> transposed at forward time + weight := mlx.NewArrayFloat32([]float32{ + 1, 2, 3, // row 0 + 4, 5, 6, // row 1 + }, []int32{2, 3}) + mlx.Eval(weight) + + linear := NewLinear(weight, nil) + + // Input: [1, 3] + x := mlx.NewArrayFloat32([]float32{1, 1, 1}, []int32{1, 3}) + mlx.Eval(x) + + out := linear.Forward(x) + mlx.Eval(out) + + // Expected: [1,1,1] @ [[1,4],[2,5],[3,6]] = [6, 15] + data := out.Data() + if len(data) != 2 || data[0] != 6 || data[1] != 15 { + t.Errorf("expected [6, 15], got %v", data) + } +} + +// TestLinearWithBias verifies Linear with bias computes x @ w.T + b correctly. +func TestLinearWithBias(t *testing.T) { + weight := mlx.NewArrayFloat32([]float32{ + 1, 2, 3, + 4, 5, 6, + }, []int32{2, 3}) + bias := mlx.NewArrayFloat32([]float32{10, 20}, []int32{2}) + mlx.Eval(weight, bias) + + linear := NewLinear(weight, bias) + + x := mlx.NewArrayFloat32([]float32{1, 1, 1}, []int32{1, 3}) + mlx.Eval(x) + + out := linear.Forward(x) + mlx.Eval(out) + + // Expected: [6, 15] + [10, 20] = [16, 35] + data := out.Data() + if len(data) != 2 || data[0] != 16 || data[1] != 35 { + t.Errorf("expected [16, 35], got %v", data) + } +} + +// TestLinearBatched verifies Linear works with batched input. +func TestLinearBatched(t *testing.T) { + weight := mlx.NewArrayFloat32([]float32{ + 1, 0, + 0, 1, + }, []int32{2, 2}) // Identity + mlx.Eval(weight) + + linear := NewLinear(weight, nil) + + // Batch of 3 inputs + x := mlx.NewArrayFloat32([]float32{ + 1, 2, + 3, 4, + 5, 6, + }, []int32{3, 2}) + mlx.Eval(x) + + out := linear.Forward(x) + mlx.Eval(out) + + // Identity should return same values + data := out.Data() + expected := []float32{1, 2, 3, 4, 5, 6} + for i, v := range expected { + if data[i] != v { + t.Errorf("at %d: expected %f, got %f", i, v, data[i]) + } + } +} + +// TestRMSNorm verifies RMSNorm computation. +func TestRMSNorm(t *testing.T) { + weight := mlx.NewArrayFloat32([]float32{1, 1, 1, 1}, []int32{4}) + mlx.Eval(weight) + + norm := NewRMSNorm(weight, 1e-5) + + // Input with known RMS + x := mlx.NewArrayFloat32([]float32{2, 2, 2, 2}, []int32{1, 4}) + mlx.Eval(x) + + out := norm.Forward(x, 0) // eps=0 uses stored Eps + mlx.Eval(out) + + // RMS of [2,2,2,2] = 2, so normalized = [1,1,1,1] + data := out.Data() + for i, v := range data { + if math.Abs(float64(v-1.0)) > 1e-4 { + t.Errorf("at %d: expected ~1.0, got %f", i, v) + } + } +} + +// TestRMSNormWithScale verifies RMSNorm applies weight scaling. +func TestRMSNormWithScale(t *testing.T) { + weight := mlx.NewArrayFloat32([]float32{2, 2, 2, 2}, []int32{4}) + mlx.Eval(weight) + + norm := NewRMSNorm(weight, 1e-5) + + x := mlx.NewArrayFloat32([]float32{2, 2, 2, 2}, []int32{1, 4}) + mlx.Eval(x) + + out := norm.Forward(x, 0) // eps=0 uses stored Eps + mlx.Eval(out) + + // Normalized [1,1,1,1] * weight [2,2,2,2] = [2,2,2,2] + data := out.Data() + for i, v := range data { + if math.Abs(float64(v-2.0)) > 1e-4 { + t.Errorf("at %d: expected ~2.0, got %f", i, v) + } + } +} + +// TestEmbedding verifies embedding lookup. +func TestEmbedding(t *testing.T) { + // Embedding table: 4 tokens, dim 3 + weight := mlx.NewArrayFloat32([]float32{ + 0, 0, 0, // token 0 + 1, 1, 1, // token 1 + 2, 2, 2, // token 2 + 3, 3, 3, // token 3 + }, []int32{4, 3}) + mlx.Eval(weight) + + emb := NewEmbedding(weight) + + // Look up tokens [1, 3, 0] + indices := mlx.NewArrayInt32([]int32{1, 3, 0}, []int32{3}) + mlx.Eval(indices) + + out := emb.Forward(indices) + mlx.Eval(out) + + data := out.Data() + expected := []float32{1, 1, 1, 3, 3, 3, 0, 0, 0} + for i, v := range expected { + if data[i] != v { + t.Errorf("at %d: expected %f, got %f", i, v, data[i]) + } + } +} + +// TestRepeatKV verifies K/V repetition for GQA. +func TestRepeatKV(t *testing.T) { + // [B=1, num_kv_heads=2, S=2, head_dim=2] + x := mlx.NewArrayFloat32([]float32{ + // head 0 + 1, 2, // pos 0 + 3, 4, // pos 1 + // head 1 + 5, 6, // pos 0 + 7, 8, // pos 1 + }, []int32{1, 2, 2, 2}) + mlx.Eval(x) + + // Repeat factor 2: 2 kv heads -> 4 heads + out := RepeatKV(x, 2) + mlx.Eval(out) + + shape := out.Shape() + if shape[0] != 1 || shape[1] != 4 || shape[2] != 2 || shape[3] != 2 { + t.Errorf("expected shape [1,4,2,2], got %v", shape) + } + + data := out.Data() + // After repeat: head0, head0, head1, head1 + expected := []float32{ + 1, 2, 3, 4, // head 0 (original) + 1, 2, 3, 4, // head 0 (repeat) + 5, 6, 7, 8, // head 1 (original) + 5, 6, 7, 8, // head 1 (repeat) + } + for i, v := range expected { + if data[i] != v { + t.Errorf("at %d: expected %f, got %f", i, v, data[i]) + } + } +} + +// TestRepeatKVNoOp verifies RepeatKV with factor 1 returns input unchanged. +func TestRepeatKVNoOp(t *testing.T) { + x := mlx.NewArrayFloat32([]float32{1, 2, 3, 4}, []int32{1, 1, 2, 2}) + mlx.Eval(x) + + out := RepeatKV(x, 1) + // Should return same pointer + if out != x { + t.Error("RepeatKV with factor 1 should return input unchanged") + } +} + +// TestApplyCausalMask verifies causal masking. +func TestApplyCausalMask(t *testing.T) { + // [B=1, heads=1, S=3, S=3] - all ones + scores := mlx.Ones(1, 1, 3, 3) + mlx.Eval(scores) + + out := ApplyCausalMask(scores) + mlx.Eval(out) + + data := out.Data() + // Lower triangular should be 1, upper should be -1e9 + // Row 0: [1, -inf, -inf] + // Row 1: [1, 1, -inf] + // Row 2: [1, 1, 1] + if data[0] != 1 || data[1] >= 0 || data[2] >= 0 { + t.Errorf("row 0 wrong: %v", data[0:3]) + } + if data[3] != 1 || data[4] != 1 || data[5] >= 0 { + t.Errorf("row 1 wrong: %v", data[3:6]) + } + if data[6] != 1 || data[7] != 1 || data[8] != 1 { + t.Errorf("row 2 wrong: %v", data[6:9]) + } +} + +// TestApplyCausalMaskWithOffset verifies causal masking with cache offset. +func TestApplyCausalMaskWithOffset(t *testing.T) { + // Simulating: cache has 2 tokens, adding 1 new query + // scores: [B=1, heads=1, queryLen=1, keyLen=3] + scores := mlx.Ones(1, 1, 1, 3) + mlx.Eval(scores) + + out := ApplyCausalMaskWithOffset(scores, 2) + mlx.Eval(out) + + data := out.Data() + // With offset=2, query at position 2 can attend to all 3 positions + if data[0] != 1 || data[1] != 1 || data[2] != 1 { + t.Errorf("expected [1, 1, 1], got %v", data) + } +} + +// TestApplyCausalMaskWithOffsetZero verifies offset=0 falls back to regular causal. +func TestApplyCausalMaskWithOffsetZero(t *testing.T) { + scores := mlx.Ones(1, 1, 2, 2) + mlx.Eval(scores) + + out := ApplyCausalMaskWithOffset(scores, 0) + mlx.Eval(out) + + data := out.Data() + // Standard causal: [1, -inf], [1, 1] + if data[0] != 1 || data[1] >= 0 { + t.Errorf("row 0 wrong: %v", data[0:2]) + } + if data[2] != 1 || data[3] != 1 { + t.Errorf("row 1 wrong: %v", data[2:4]) + } +} + +// BenchmarkLinearSmall benchmarks small Linear forward pass. +func BenchmarkLinearSmall(b *testing.B) { + weight := mlx.RandomNormal([]int32{256, 256}, 42) + mlx.Eval(weight) + + linear := NewLinear(weight, nil) + + x := mlx.RandomNormal([]int32{1, 256}, 43) + mlx.Eval(x) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + out := linear.Forward(x) + mlx.Eval(out) + } +} + +// BenchmarkLinearLarge benchmarks larger Linear forward pass. +func BenchmarkLinearLarge(b *testing.B) { + weight := mlx.RandomNormal([]int32{4096, 4096}, 42) + mlx.Eval(weight) + + linear := NewLinear(weight, nil) + + x := mlx.RandomNormal([]int32{1, 4096}, 43) + mlx.Eval(x) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + out := linear.Forward(x) + mlx.Eval(out) + } +} + +// BenchmarkRMSNorm benchmarks RMSNorm forward pass. +func BenchmarkRMSNorm(b *testing.B) { + weight := mlx.Ones(4096) + mlx.Eval(weight) + + norm := NewRMSNorm(weight, 1e-5) + + x := mlx.RandomNormal([]int32{1, 4096}, 42) + mlx.Eval(x) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + out := norm.Forward(x, 0) + mlx.Eval(out) + } +} + +// BenchmarkEmbedding benchmarks embedding lookup. +func BenchmarkEmbedding(b *testing.B) { + // Typical vocab size + weight := mlx.RandomNormal([]int32{32000, 4096}, 42) + mlx.Eval(weight) + + emb := NewEmbedding(weight) + + // Single token lookup + indices := mlx.NewArrayInt32([]int32{1000}, []int32{1}) + mlx.Eval(indices) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + out := emb.Forward(indices) + mlx.Eval(out) + } +} + +// BenchmarkRepeatKV benchmarks K/V repetition. +func BenchmarkRepeatKV(b *testing.B) { + // Typical GQA setup: 8 kv heads -> 32 heads + x := mlx.RandomNormal([]int32{1, 8, 512, 128}, 42) + mlx.Eval(x) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + out := RepeatKV(x, 4) + mlx.Eval(out) + } +} diff --git a/x/imagegen/safetensors/loader.go b/x/imagegen/safetensors/loader.go new file mode 100644 index 000000000..cbf2d2848 --- /dev/null +++ b/x/imagegen/safetensors/loader.go @@ -0,0 +1,170 @@ +//go:build mlx + +package safetensors + +import ( + "fmt" + "reflect" + "strings" + + "github.com/ollama/ollama/x/imagegen/mlx" +) + +// LoadModule loads weights into a struct using reflection and struct tags. +// +// Struct tags use the format: `weight:"path[,optional]"` +// - path: the weight name suffix (appended to prefix) +// - optional: if present, missing weights don't cause errors +// - "-": skip this field entirely +// - no tag on struct pointer: recurse with current prefix +// - no tag on *mlx.Array: skip (computed fields don't need loading) +// +// For slices of struct pointers, the loader iterates with .0, .1, .2... suffixes. +// The slice must be pre-allocated to the correct length. +// +// Example: +// +// type Attention struct { +// QProj *nn.Linear `weight:"self_attn.q_proj"` +// KProj *nn.Linear `weight:"self_attn.k_proj"` +// Cache *mlx.Array // no tag = skipped (computed field) +// } +// +// err := LoadModule(&attn, weights, "model.layers.0") +func LoadModule(dst any, weights *ModelWeights, prefix string) error { + v := reflect.ValueOf(dst) + if v.Kind() != reflect.Ptr || v.IsNil() { + return fmt.Errorf("LoadModule: dst must be a non-nil pointer") + } + v = v.Elem() + if v.Kind() != reflect.Struct { + return fmt.Errorf("LoadModule: dst must be a pointer to struct, got %v", v.Kind()) + } + + var errs []string + loadStruct(v, weights, prefix, &errs, false) + + if len(errs) > 0 { + return fmt.Errorf("LoadModule: missing weights:\n %s", strings.Join(errs, "\n ")) + } + return nil +} + +// loadStruct recursively loads weights into a struct value. +func loadStruct(v reflect.Value, weights *ModelWeights, prefix string, errs *[]string, parentOptional bool) { + t := v.Type() + + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + fieldVal := v.Field(i) + + // Skip unexported fields + if !fieldVal.CanSet() { + continue + } + + // Parse tag + tag, hasTag := field.Tag.Lookup("weight") + if tag == "-" { + continue + } + + // Parse tag options + optional := parentOptional + weightPath := tag + if idx := strings.Index(tag, ","); idx != -1 { + weightPath = tag[:idx] + if strings.Contains(tag[idx+1:], "optional") { + optional = true + } + } + + // Build full path + fullPath := joinPath(prefix, weightPath) + + // For struct pointers without a tag, recurse with current prefix + if !hasTag && fieldVal.Kind() == reflect.Ptr { + elemType := fieldVal.Type().Elem() + if elemType.Kind() == reflect.Struct && elemType != reflect.TypeOf(mlx.Array{}) { + if fieldVal.IsNil() { + fieldVal.Set(reflect.New(elemType)) + } + loadStruct(fieldVal.Elem(), weights, prefix, errs, optional) + continue + } + } + + // Handle by kind + switch fieldVal.Kind() { + case reflect.Ptr: + elemType := fieldVal.Type().Elem() + + // *mlx.Array - load directly (but skip if no tag - computed fields) + if fieldVal.Type() == reflect.TypeOf((*mlx.Array)(nil)) { + if !hasTag { + continue // no tag on *mlx.Array = computed field, skip + } + arr, err := weights.GetTensor(fullPath) + if err != nil { + if !optional { + *errs = append(*errs, fullPath) + } + continue + } + fieldVal.Set(reflect.ValueOf(arr)) + continue + } + + // Pointer to struct - allocate and recurse + if elemType.Kind() == reflect.Struct { + if optional && !hasWeightsWithPrefix(weights, fullPath) { + continue + } + if fieldVal.IsNil() { + fieldVal.Set(reflect.New(elemType)) + } + loadStruct(fieldVal.Elem(), weights, fullPath, errs, optional) + } + + case reflect.Slice: + elemType := fieldVal.Type().Elem() + if elemType.Kind() == reflect.Ptr && elemType.Elem().Kind() == reflect.Struct { + loadSlice(fieldVal, weights, fullPath, errs) + } + } + } +} + +// hasWeightsWithPrefix checks if any weights exist with the given prefix. +func hasWeightsWithPrefix(weights *ModelWeights, prefix string) bool { + for _, name := range weights.ListTensors() { + if strings.HasPrefix(name, prefix+".") || name == prefix { + return true + } + } + return false +} + +// loadSlice loads weights into each element of a slice of struct pointers. +func loadSlice(v reflect.Value, weights *ModelWeights, prefix string, errs *[]string) { + elemStructType := v.Type().Elem().Elem() + + for i := 0; i < v.Len(); i++ { + elem := v.Index(i) + if elem.IsNil() { + elem.Set(reflect.New(elemStructType)) + } + loadStruct(elem.Elem(), weights, fmt.Sprintf("%s.%d", prefix, i), errs, false) + } +} + +// joinPath joins path segments with dots, handling empty segments. +func joinPath(prefix, suffix string) string { + if prefix == "" { + return suffix + } + if suffix == "" { + return prefix + } + return prefix + "." + suffix +} diff --git a/x/imagegen/safetensors/safetensors.go b/x/imagegen/safetensors/safetensors.go new file mode 100644 index 000000000..e2dfb0960 --- /dev/null +++ b/x/imagegen/safetensors/safetensors.go @@ -0,0 +1,280 @@ +//go:build mlx + +package safetensors + +import ( + "encoding/binary" + "encoding/json" + "fmt" + "os" + "path/filepath" + "sort" + "strings" + + "github.com/ollama/ollama/x/imagegen/mlx" +) + +// SafetensorHeader represents the JSON header of a safetensors file +type SafetensorHeader map[string]TensorInfo + +// TensorInfo contains metadata about a tensor +type TensorInfo struct { + Dtype string `json:"dtype"` + Shape []int32 `json:"shape"` + DataOffsets [2]int `json:"data_offsets"` +} + +// parseSafetensorHeader reads only the JSON header from a safetensors file. +func parseSafetensorHeader(path string) (SafetensorHeader, error) { + f, err := os.Open(path) + if err != nil { + return nil, fmt.Errorf("failed to open file: %w", err) + } + defer f.Close() + + var headerSize uint64 + if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil { + return nil, fmt.Errorf("failed to read header size: %w", err) + } + + headerBytes := make([]byte, headerSize) + if _, err := f.Read(headerBytes); err != nil { + return nil, fmt.Errorf("failed to read header: %w", err) + } + + var header SafetensorHeader + if err := json.Unmarshal(headerBytes, &header); err != nil { + return nil, fmt.Errorf("failed to parse header: %w", err) + } + + delete(header, "__metadata__") + return header, nil +} + +// dtypeFromString converts safetensors dtype string to mlx.Dtype +func dtypeFromString(s string) mlx.Dtype { + switch strings.ToUpper(s) { + case "F32", "FLOAT32": + return mlx.DtypeFloat32 + case "F16", "FLOAT16": + return mlx.DtypeFloat16 + case "BF16", "BFLOAT16": + return mlx.DtypeBFloat16 + case "I32", "INT32": + return mlx.DtypeInt32 + case "I64", "INT64": + return mlx.DtypeInt64 + case "U8", "UINT8": + return mlx.DtypeUint8 + default: + return mlx.DtypeFloat32 + } +} + +// ModelWeights manages weights from multiple safetensor files. +type ModelWeights struct { + dir string // Model directory + tensorFiles map[string]string // tensor name -> file path + tensorInfo map[string]TensorInfo // tensor name -> metadata + nativeCache map[string]*mlx.SafetensorsFile // file path -> loaded native handle + cache map[string]*mlx.Array // tensor name -> array (after Load) +} + +// LoadModelWeights scans safetensor files and builds a tensor index. +// This only reads JSON headers, not tensor data. +func LoadModelWeights(dir string) (*ModelWeights, error) { + mw := &ModelWeights{ + dir: dir, + tensorFiles: make(map[string]string), + tensorInfo: make(map[string]TensorInfo), + nativeCache: make(map[string]*mlx.SafetensorsFile), + } + + entries, err := os.ReadDir(dir) + if err != nil { + return nil, fmt.Errorf("failed to read directory: %w", err) + } + + for _, entry := range entries { + if strings.HasSuffix(entry.Name(), ".safetensors") { + path := filepath.Join(dir, entry.Name()) + + header, err := parseSafetensorHeader(path) + if err != nil { + return nil, fmt.Errorf("failed to parse %s: %w", entry.Name(), err) + } + + for name, info := range header { + mw.tensorFiles[name] = path + mw.tensorInfo[name] = info + } + } + } + + if len(mw.tensorFiles) == 0 { + return nil, fmt.Errorf("no safetensor files found in %s", dir) + } + + return mw, nil +} + +// Load loads all tensors into cache with the specified dtype. +// If dtype is 0, tensors are loaded in their original dtype. +// Automatically uses streaming (memory-efficient) when dtype conversion is needed, +// or native loading when tensors are already in the target dtype. +func (mw *ModelWeights) Load(dtype mlx.Dtype) error { + if dtype == 0 { + return mw.loadNative() + } + + // Check if any tensor needs conversion + needsConversion := false + for name := range mw.tensorFiles { + info := mw.tensorInfo[name] + if dtypeFromString(info.Dtype) != dtype { + needsConversion = true + break + } + } + + if needsConversion { + return mw.loadStreaming(dtype) + } + return mw.loadNative() +} + +// loadNative loads all tensors using the native memory-mapped loader. +func (mw *ModelWeights) loadNative() error { + mw.cache = make(map[string]*mlx.Array) + + fileToTensors := make(map[string][]string) + for name, path := range mw.tensorFiles { + fileToTensors[path] = append(fileToTensors[path], name) + } + + for path, names := range fileToTensors { + native, err := mlx.LoadSafetensorsNative(path) + if err != nil { + return fmt.Errorf("failed to load %s: %w", path, err) + } + + for _, name := range names { + arr := native.Get(name) + if arr == nil { + native.Free() + return fmt.Errorf("tensor %q not found in %s", name, path) + } + mw.cache[name] = arr + } + + mw.nativeCache[path] = native + } + + return nil +} + +// loadStreaming loads tensors with dtype conversion. +// Uses the same pattern as Python: replace each entry in the map after conversion, +// so the original tensor loses its reference and can be freed. +func (mw *ModelWeights) loadStreaming(dtype mlx.Dtype) error { + mw.cache = make(map[string]*mlx.Array) + + fileToTensors := make(map[string][]string) + for name, path := range mw.tensorFiles { + fileToTensors[path] = append(fileToTensors[path], name) + } + + for path, names := range fileToTensors { + native, err := mlx.LoadSafetensorsNative(path) + if err != nil { + return fmt.Errorf("failed to load %s: %w", path, err) + } + + for _, name := range names { + src := native.Get(name) + if src == nil { + native.Free() + return fmt.Errorf("tensor %q not found in %s", name, path) + } + + dst := mlx.AsType(src, dtype) + mlx.Eval(dst) + native.Set(name, dst) + mw.cache[name] = dst + } + + native.Free() + } + + return nil +} + +// Get returns a tensor from cache. Call Load() first. +func (mw *ModelWeights) Get(name string) (*mlx.Array, error) { + if mw.cache == nil { + return nil, fmt.Errorf("cache not initialized: call Load() first") + } + arr, ok := mw.cache[name] + if !ok { + return nil, fmt.Errorf("tensor %q not found in cache", name) + } + return arr, nil +} + +// GetTensor loads a tensor using the native loader without caching. +// For bulk loading, use Load() + Get() instead. +func (mw *ModelWeights) GetTensor(name string) (*mlx.Array, error) { + if mw.cache != nil { + if arr, ok := mw.cache[name]; ok { + return arr, nil + } + } + + path, ok := mw.tensorFiles[name] + if !ok { + return nil, fmt.Errorf("tensor %q not found", name) + } + + native, ok := mw.nativeCache[path] + if !ok { + var err error + native, err = mlx.LoadSafetensorsNative(path) + if err != nil { + return nil, fmt.Errorf("failed to load %s: %w", path, err) + } + mw.nativeCache[path] = native + } + + return native.Get(name), nil +} + +// GetTensorInfo returns metadata about a tensor without loading it. +func (mw *ModelWeights) GetTensorInfo(name string) (TensorInfo, bool) { + info, ok := mw.tensorInfo[name] + return info, ok +} + +// ListTensors returns all tensor names. +func (mw *ModelWeights) ListTensors() []string { + names := make([]string, 0, len(mw.tensorFiles)) + for name := range mw.tensorFiles { + names = append(names, name) + } + sort.Strings(names) + return names +} + +// HasTensor checks if a tensor exists. +func (mw *ModelWeights) HasTensor(name string) bool { + _, ok := mw.tensorFiles[name] + return ok +} + +// ReleaseAll releases all cached native file handles. +func (mw *ModelWeights) ReleaseAll() { + for path, native := range mw.nativeCache { + native.Free() + delete(mw.nativeCache, path) + } +} + diff --git a/x/imagegen/safetensors/safetensors_test.go b/x/imagegen/safetensors/safetensors_test.go new file mode 100644 index 000000000..f00268751 --- /dev/null +++ b/x/imagegen/safetensors/safetensors_test.go @@ -0,0 +1,167 @@ +//go:build mlx + +package safetensors + +import ( + "os" + "path/filepath" + "testing" + + "github.com/ollama/ollama/x/imagegen/mlx" +) + +func TestLoadModelWeights(t *testing.T) { + // Skip if no model available + modelDir := "../weights/gpt-oss-20b" + if _, err := os.Stat(modelDir); os.IsNotExist(err) { + t.Skip("model weights not available") + } + + mw, err := LoadModelWeights(modelDir) + if err != nil { + t.Fatalf("LoadModelWeights: %v", err) + } + defer mw.ReleaseAll() + + // Check we found tensors + tensors := mw.ListTensors() + if len(tensors) == 0 { + t.Fatal("no tensors found") + } + t.Logf("found %d tensors", len(tensors)) + + // Check HasTensor + if !mw.HasTensor(tensors[0]) { + t.Errorf("HasTensor(%q) = false", tensors[0]) + } + if mw.HasTensor("nonexistent.weight") { + t.Error("HasTensor returned true for nonexistent tensor") + } +} + +func TestGetTensor(t *testing.T) { + modelDir := "../weights/gpt-oss-20b" + if _, err := os.Stat(modelDir); os.IsNotExist(err) { + t.Skip("model weights not available") + } + + mw, err := LoadModelWeights(modelDir) + if err != nil { + t.Fatalf("LoadModelWeights: %v", err) + } + defer mw.ReleaseAll() + + tensors := mw.ListTensors() + if len(tensors) == 0 { + t.Skip("no tensors") + } + + // Load first tensor + arr, err := mw.GetTensor(tensors[0]) + if err != nil { + t.Fatalf("GetTensor(%q): %v", tensors[0], err) + } + + // Verify it has a shape + shape := arr.Shape() + if len(shape) == 0 { + t.Error("tensor has no shape") + } + t.Logf("%s: shape=%v dtype=%v", tensors[0], shape, arr.Dtype()) +} + +func TestLoadWithDtype(t *testing.T) { + modelDir := "../weights/gpt-oss-20b" + if _, err := os.Stat(modelDir); os.IsNotExist(err) { + t.Skip("model weights not available") + } + + mw, err := LoadModelWeights(modelDir) + if err != nil { + t.Fatalf("LoadModelWeights: %v", err) + } + defer mw.ReleaseAll() + + // Load all tensors as bfloat16 + if err := mw.Load(mlx.DtypeBFloat16); err != nil { + t.Fatalf("Load: %v", err) + } + + // Get a tensor from cache + tensors := mw.ListTensors() + arr, err := mw.Get(tensors[0]) + if err != nil { + t.Fatalf("Get: %v", err) + } + + // Verify dtype (unless it was already bf16) + t.Logf("%s: dtype=%v", tensors[0], arr.Dtype()) +} + +func TestLookupTensor(t *testing.T) { + modelDir := "../weights/gpt-oss-20b" + if _, err := os.Stat(modelDir); os.IsNotExist(err) { + t.Skip("model weights not available") + } + + mw, err := LoadModelWeights(modelDir) + if err != nil { + t.Fatalf("LoadModelWeights: %v", err) + } + defer mw.ReleaseAll() + + // HasTensor returns false for nonexistent + if mw.HasTensor("nonexistent") { + t.Error("HasTensor should return false for nonexistent") + } + + // HasTensor returns true for existing tensor + tensors := mw.ListTensors() + if !mw.HasTensor(tensors[0]) { + t.Error("HasTensor should return true for existing tensor") + } +} + +func TestParseSafetensorHeader(t *testing.T) { + modelDir := "../weights/gpt-oss-20b" + if _, err := os.Stat(modelDir); os.IsNotExist(err) { + t.Skip("model weights not available") + } + + // Find a safetensors file + entries, err := os.ReadDir(modelDir) + if err != nil { + t.Fatal(err) + } + + var stFile string + for _, e := range entries { + if filepath.Ext(e.Name()) == ".safetensors" { + stFile = filepath.Join(modelDir, e.Name()) + break + } + } + if stFile == "" { + t.Skip("no safetensors file found") + } + + header, err := parseSafetensorHeader(stFile) + if err != nil { + t.Fatalf("parseSafetensorHeader: %v", err) + } + + if len(header) == 0 { + t.Error("header is empty") + } + + // Check a tensor has valid info + for name, info := range header { + if info.Dtype == "" { + t.Errorf("%s: empty dtype", name) + } + if len(info.Shape) == 0 { + t.Errorf("%s: empty shape", name) + } + break // just check one + } +} diff --git a/x/imagegen/tokenizer/README.md b/x/imagegen/tokenizer/README.md new file mode 100644 index 000000000..03de51510 --- /dev/null +++ b/x/imagegen/tokenizer/README.md @@ -0,0 +1,85 @@ +# Tokenizer + +Tokenizer for LLM inference supporting BPE, SentencePiece, and WordPiece algorithms. The goal of this package is to see if a pure Go tokenizer can be fast and correct. It primarily supports the `imagegen` models however it (or parts of it) could be considered to replace Ollama's tokenizer in the `model` package. + +## Features + +- **BPE (Byte Pair Encoding)** - GPT-2/Llama style with byte-level encoding +- **SentencePiece** - Gemma style with `▁` space handling +- **WordPiece** - BERT style with `##` continuation tokens +- **Parallel encoding** - Automatic parallelization for inputs >4KB +- **HuggingFace compatible** - Loads `tokenizer.json` directly + +## Usage + +```go +import "github.com/ollama/ollama/x/imagegen/tokenizer" + +// Load from HuggingFace model directory +tok, err := tokenizer.Load("./weights/Llama-3.2-1B") +if err != nil { + log.Fatal(err) +} + +// Encode text to token IDs +ids := tok.Encode("Hello, world!", false) // false = don't add BOS + +// Decode back to text +text := tok.Decode(ids) + +// Check special tokens +if tok.IsEOS(ids[len(ids)-1]) { + // End of sequence +} +``` + +## Performance + +Benchmarks on Apple M3 Max: + +| Input Size | Encode | Decode | Tokens | +|------------|--------|--------|--------| +| 1 KB | 14.5 MB/s | 267 MB/s | 231 | +| 10 KB | 10.9 MB/s | 321 MB/s | 2,301 | +| 100 KB | 8.9 MB/s | 311 MB/s | 23,001 | +| 1 MB | 9.6 MB/s | 321 MB/s | 230,001 | + +Comparison with other implementations (10 MB input): + +| Implementation | Encode Speed | Notes | +|----------------|--------------|-------| +| Engine (this) | ~10 MB/s | stdlib RE2, parallel >4KB | +| tiktoken (Rust) | ~17 MB/s | Highly optimized regex | +| Ollama (Go) | ~2-3 MB/s | regexp2 backtracking | + +## Performance Opportunities + +Potential optimizations not yet implemented: + +| Optimization | Expected Gain | Complexity | +|--------------|---------------|------------| +| Aho-Corasick for special tokens | 2-3x for many special tokens | Medium | +| Custom regex engine (like tiktoken) | 1.5-2x | High | +| SIMD byte scanning | 1.3-1.5x for pretokenizer | Medium | +| Assembly BPE merge loop | 1.2-1.5x | High | +| Memoization for repeated substrings | Variable | Low | + +Current bottleneck is the pretokenizer regex (~60% of encode time). tiktoken achieves ~17 MB/s with a hand-tuned Rust regex engine. + +## Not Yet Implemented + +| Feature | Used By | Notes | +|---------|---------|-------| +| Unigram tokenizer | T5, ALBERT, mBART | Different algorithm (not BPE) | +| Unicode normalizers | Some multilingual models | NFD, NFKC, lowercase, etc. | +| Custom pretokenizers | Model-specific | Beyond standard patterns | + +Most HuggingFace models use BPE or SentencePiece, which are fully supported. WordPiece (BERT-style) is also supported with standard `[UNK]` fallback for out-of-vocabulary characters. + +## Files + +| File | Description | +|------|-------------| +| `tokenizer.go` | Main implementation (~1000 lines) | +| `tokenizer_test.go` | Tests and benchmarks | +| `testdata/` | Mini tokenizer for unit tests | diff --git a/x/imagegen/tokenizer/testdata/mini_llama.json b/x/imagegen/tokenizer/testdata/mini_llama.json new file mode 100644 index 000000000..05c609767 --- /dev/null +++ b/x/imagegen/tokenizer/testdata/mini_llama.json @@ -0,0 +1 @@ +{"model": {"type": "BPE", "vocab": {"!": 0, "\"": 1, "#": 2, "$": 3, "%": 4, "&": 5, "'": 6, "(": 7, ")": 8, "*": 9, "+": 10, ",": 11, "-": 12, ".": 13, "/": 14, "0": 15, "1": 16, "2": 17, "3": 18, "4": 19, "5": 20, "6": 21, "7": 22, "8": 23, "9": 24, ":": 25, ";": 26, "<": 27, "=": 28, ">": 29, "?": 30, "@": 31, "A": 32, "B": 33, "C": 34, "D": 35, "E": 36, "F": 37, "G": 38, "H": 39, "I": 40, "J": 41, "K": 42, "L": 43, "M": 44, "N": 45, "O": 46, "P": 47, "Q": 48, "R": 49, "S": 50, "T": 51, "U": 52, "V": 53, "fé": 59958, "W": 54, "X": 55, "Y": 56, "Z": 57, "[": 58, "\\": 59, "]": 60, "^": 61, "_": 62, "`": 63, "a": 64, "b": 65, "c": 66, "d": 67, "e": 68, "f": 69, "g": 70, "h": 71, "i": 72, "j": 73, "k": 74, "l": 75, "m": 76, "n": 77, "o": 78, "p": 79, "r": 81, "q": 80, "s": 82, "t": 83, "u": 84, "v": 85, "w": 86, "x": 87, "y": 88, "z": 89, "{": 90, "|": 91, "}": 92, "~": 93, "¡": 94, "¢": 95, "£": 96, "¤": 97, "¥": 98, "¦": 99, "§": 100, "¨": 101, "World": 10343, "©": 102, "ª": 103, "«": 104, "¬": 105, "®": 106, "world": 14957, "¯": 107, "°": 108, "±": 109, "²": 110, "³": 111, "´": 112, "µ": 113, "¶": 114, "·": 115, "¸": 116, "¹": 117, "º": 118, "»": 119, "¼": 120, "½": 121, "¾": 122, "¿": 123, "À": 124, "Á": 125, "Â": 126, "Ã": 127, "Ä": 128, "Å": 129, "Æ": 130, "Ç": 131, "È": 132, "É": 133, "Ê": 134, "Ë": 135, "Ì": 136, "Í": 137, "Î": 138, "Ï": 139, "Ð": 140, "Ñ": 141, "Ò": 142, "Ó": 143, "Ô": 144, "Õ": 145, "Ö": 146, "×": 147, "Ø": 148, "Ù": 149, "Ú": 150, "Û": 151, "Ü": 152, "Ý": 153, "Þ": 154, "ß": 155, "à": 156, "á": 157, "â": 158, "ã": 159, "ä": 160, "å": 161, "æ": 162, "ç": 163, "è": 164, "é": 165, "ê": 166, "ë": 167, "ì": 168, "Ġhello": 24748, "í": 169, "î": 170, "ï": 171, "ð": 172, "ñ": 173, "Hello": 9906, "ò": 174, "ó": 175, "ô": 176, "õ": 177, "ö": 178, "Ġ{}": 4792, "÷": 179, "ø": 180, "ù": 181, "ú": 182, "û": 183, "ü": 184, "ý": 185, "þ": 186, "ÿ": 187, "Ā": 188, "ā": 189, "Ă": 190, "ă": 191, "Ċ": 198, "Ą": 192, "ą": 193, "Ć": 194, "ć": 195, "Ĉ": 196, "ĉ": 197, "ċ": 199, "Č": 200, "č": 201, "Ď": 202, "ď": 203, "Đ": 204, "đ": 205, "Ē": 206, "ē": 207, "Ĕ": 208, "ĕ": 209, "Ė": 210, "ė": 211, "Ę": 212, "ę": 213, "Ġ": 220, "Ě": 214, "ě": 215, "Ĝ": 216, "ĝ": 217, "Ğ": 218, "ğ": 219, "ġ": 221, "Ģ": 222, "ģ": 223, "Ĥ": 224, "ĥ": 225, "Ħ": 226, "ħ": 227, "Ĩ": 228, "ĩ": 229, "Ī": 230, "ī": 231, "Ĭ": 232, "ĭ": 233, "Į": 234, "į": 235, "İ": 236, "ı": 237, "IJ": 238, "ij": 239, "Ĵ": 240, "ĵ": 241, "Ķ": 242, "ķ": 243, "ĸ": 244, "Ĺ": 245, "ĺ": 246, "Ļ": 247, "ļ": 248, "Ľ": 249, "ĠĠ": 256, "ľ": 250, "Ŀ": 251, "ŀ": 252, "Ł": 253, "rer": 38149, "ĠĠĠ": 262, "ł": 254, "Ń": 255, "'m": 2846, "'re": 2351, "can": 4919, "func": 2900, "()": 368, "Ġworld": 1917, "Ġmain": 1925, "00": 410, "123": 4513, "000": 931, "ca": 936, "'t": 956, "é": 978, "hello": 15339, "Ġw": 289, "orld": 1410, "Ġwor": 4191, "ld": 509, "main": 3902, "Ġm": 296, "ain": 467, "Ġma": 7643, "in": 258, "Ġmai": 17154, "re": 265, "'r": 97670, "unc": 1371, "fun": 12158, "fu": 33721, "nc": 1031, "ma": 1764, "mai": 77585, "wor": 50810, "or": 269, "Ġwo": 24670, "23": 1419, "12": 717, "{}": 6390, "Ġ{": 314, "an": 276, "ello": 4896, "Hel": 33813, "lo": 385, "Hell": 81394, "un": 359, "hel": 50222, "hell": 57195, "ai": 2192, "wo": 1146, "Ġh": 305, "Ġhel": 11591, "Ġhell": 15123, "el": 301, "He": 1548, "er": 261, "he": 383, "ell": 616, "ll": 657}, "merges": ["Ġ Ġ", "Ġ ĠĠ", "ĠĠ Ġ", "( )", "0 0", "0 00", "00 0", "c a", "' t", "à ©", "Ġ world", "Ġw orld", "Ġwor ld", "Ġ main", "Ġm ain", "Ġma in", "Ġmai n", "' re", "'r e", "' m", "f unc", "fun c", "fu nc", "m ain", "ma in", "mai n", "Ġ wor", "Ġw or", "Ġwo r", "1 23", "12 3", "Ġ {}", "Ġ{ }", "c an", "ca n", "{ }", "Ġ ma", "Ġm a", "H ello", "Hel lo", "Hell o", "W orld", "f un", "fu n", "w orld", "wor ld", "h ello", "hel lo", "hell o", "Ġ mai", "Ġm ai", "Ġma i", "Ġ wo", "Ġw o", "Ġ hello", "Ġh ello", "Ġhel lo", "Ġhell o", "f u", "H el", "He l", "r er", "re r", "h el", "he l", "w or", "wo r", "h ell", "he ll", "hel l", "f é", "m ai", "ma i", "H ell", "He ll", "Hel l", "' r"]}, "pre_tokenizer": {"type": "Sequence", "pretokenizers": [{"type": "Split", "pattern": {"Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"}, "behavior": "Isolated", "invert": false}, {"type": "ByteLevel", "add_prefix_space": false, "trim_offsets": true, "use_regex": false}]}, "decoder": {"type": "ByteLevel", "add_prefix_space": true, "trim_offsets": true, "use_regex": true}, "added_tokens": [{"id": 128000, "content": "<|begin_of_text|>", "special": true}, {"id": 128001, "content": "<|end_of_text|>", "special": true}]} \ No newline at end of file diff --git a/x/imagegen/tokenizer/tokenizer.go b/x/imagegen/tokenizer/tokenizer.go new file mode 100644 index 000000000..8628f273b --- /dev/null +++ b/x/imagegen/tokenizer/tokenizer.go @@ -0,0 +1,1013 @@ +//go:build mlx + +// tokenizer.go - BPE and SentencePiece tokenizer for HuggingFace models +// +// Based on standard BPE algorithm (Sennrich et al. 2015) with: +// - GPT-2 byte-level encoding (OpenAI tiktoken) +// - HuggingFace tokenizer.json pretokenizer patterns +// - SentencePiece ▁-style space handling + +package tokenizer + +import ( + "encoding/json" + "fmt" + "os" + "regexp" + "runtime" + "sort" + "strconv" + "strings" + "sync" + "unicode" + "unicode/utf8" +) + +// TokenizerType identifies the tokenization algorithm +type TokenizerType int + +const ( + TokenizerBPE TokenizerType = iota // GPT-2 style byte-level BPE + TokenizerSentencePiece // SentencePiece with ▁ for spaces + TokenizerWordPiece // BERT style with ## continuations +) + +// Vocabulary holds the tokenizer vocabulary and merges +type Vocabulary struct { + Values []string + Reverse map[string]int32 + Merges map[string]int + + BOS int32 + EOS []int32 // Multiple EOS tokens supported (e.g., Gemma has and ) + PAD int32 // Padding token (often <|endoftext|> or ) + AddBOS bool + AddEOS bool + + // Precomputed byte token IDs for <0xNN> fallback (256 entries, -1 if not found) + byteTokens [256]int32 +} + +// Tokenizer handles BPE, SentencePiece, and WordPiece tokenization +type Tokenizer struct { + vocab *Vocabulary + pretokenizer *regexp.Regexp + specialTokens map[string]int32 // Special tokens for direct lookup + typ TokenizerType // Algorithm type + unkToken int32 // [UNK] token ID for WordPiece fallback +} + +// Precomputed GPT-2 byte-level encoding table +// Maps byte values to their encoded rune equivalents +var byteToRune [256]rune + +func init() { + for b := 0; b < 256; b++ { + r := rune(b) + switch { + case r == 0x00ad: + r = 0x0143 + case r <= 0x0020: + r = r + 0x0100 + case r >= 0x007f && r <= 0x00a0: + r = r + 0x00a2 + } + byteToRune[b] = r + } +} + +// loadSpecialTokenConfig loads special token configuration from HuggingFace companion files. +// +// Loading priority for EOS tokens (can be single int or []int): +// 1. generation_config.json - eos_token_id (preferred, matches HuggingFace generation) +// 2. config.json - eos_token_id (model config fallback) +// 3. tokenizer_config.json - eos_token string + add_bos/add_eos flags +// 4. special_tokens_map.json - final fallback +func loadSpecialTokenConfig(dir string, t *Tokenizer) { + // Helper to parse eos_token_id which can be int or []int + parseTokenIDs := func(v interface{}) []int32 { + switch val := v.(type) { + case float64: + return []int32{int32(val)} + case []interface{}: + ids := make([]int32, 0, len(val)) + for _, id := range val { + if f, ok := id.(float64); ok { + ids = append(ids, int32(f)) + } + } + return ids + } + return nil + } + + // Priority 1: generation_config.json (eos_token_id can be int or []int) + if data, err := os.ReadFile(dir + "generation_config.json"); err == nil { + var config struct { + EOSTokenID interface{} `json:"eos_token_id"` + BOSTokenID interface{} `json:"bos_token_id"` + } + if err := json.Unmarshal(data, &config); err == nil { + if ids := parseTokenIDs(config.EOSTokenID); len(ids) > 0 { + t.vocab.EOS = ids + } + if ids := parseTokenIDs(config.BOSTokenID); len(ids) > 0 { + t.vocab.BOS = ids[0] + } + } + } + + // Priority 2: config.json (model config, same format) + if len(t.vocab.EOS) == 0 || t.vocab.BOS < 0 { + if data, err := os.ReadFile(dir + "config.json"); err == nil { + var config struct { + EOSTokenID interface{} `json:"eos_token_id"` + BOSTokenID interface{} `json:"bos_token_id"` + } + if err := json.Unmarshal(data, &config); err == nil { + if len(t.vocab.EOS) == 0 { + if ids := parseTokenIDs(config.EOSTokenID); len(ids) > 0 { + t.vocab.EOS = ids + } + } + if t.vocab.BOS < 0 { + if ids := parseTokenIDs(config.BOSTokenID); len(ids) > 0 { + t.vocab.BOS = ids[0] + } + } + } + } + } + + // Priority 3: tokenizer_config.json (token strings + add_bos/add_eos flags) + if data, err := os.ReadFile(dir + "tokenizer_config.json"); err == nil { + var config struct { + BOSToken interface{} `json:"bos_token"` + EOSToken interface{} `json:"eos_token"` + PADToken interface{} `json:"pad_token"` + AddBOSToken *bool `json:"add_bos_token"` + AddEOSToken *bool `json:"add_eos_token"` + } + if err := json.Unmarshal(data, &config); err == nil { + if t.vocab.BOS < 0 { + if bosStr := extractTokenString(config.BOSToken); bosStr != "" { + if id, ok := t.specialTokens[bosStr]; ok { + t.vocab.BOS = id + } + } + } + if len(t.vocab.EOS) == 0 { + if eosStr := extractTokenString(config.EOSToken); eosStr != "" { + if id, ok := t.specialTokens[eosStr]; ok { + t.vocab.EOS = []int32{id} + } + } + } + if t.vocab.PAD < 0 { + if padStr := extractTokenString(config.PADToken); padStr != "" { + if id, ok := t.specialTokens[padStr]; ok { + t.vocab.PAD = id + } + } + } + if config.AddBOSToken != nil { + t.vocab.AddBOS = *config.AddBOSToken + } + if config.AddEOSToken != nil { + t.vocab.AddEOS = *config.AddEOSToken + } + } + } + + // Priority 4: special_tokens_map.json (final fallback) + if data, err := os.ReadFile(dir + "special_tokens_map.json"); err == nil { + var tokensMap map[string]interface{} + if err := json.Unmarshal(data, &tokensMap); err == nil { + if t.vocab.BOS < 0 { + if bosStr := extractTokenString(tokensMap["bos_token"]); bosStr != "" { + if id, ok := t.specialTokens[bosStr]; ok { + t.vocab.BOS = id + } + } + } + if len(t.vocab.EOS) == 0 { + if eosStr := extractTokenString(tokensMap["eos_token"]); eosStr != "" { + if id, ok := t.specialTokens[eosStr]; ok { + t.vocab.EOS = []int32{id} + } + } + } + if t.vocab.PAD < 0 { + if padStr := extractTokenString(tokensMap["pad_token"]); padStr != "" { + if id, ok := t.specialTokens[padStr]; ok { + t.vocab.PAD = id + } + } + } + } + } +} + +// extractTokenString extracts the token string from various formats used in HuggingFace configs. +// Tokens can be represented as: +// - string: "token" +// - object: {"content": "token", ...} +func extractTokenString(v interface{}) string { + if v == nil { + return "" + } + // Direct string + if s, ok := v.(string); ok { + return s + } + // Object with content field + if m, ok := v.(map[string]interface{}); ok { + if content, ok := m["content"].(string); ok { + return content + } + } + return "" +} + +// rewritePatternForRE2 rewrites HuggingFace pretokenizer regex patterns to be +// compatible with Go's regexp package (RE2). HuggingFace patterns use PCRE features: +// - (?!\S) negative lookahead - RE2 doesn't support this +// - (?i:...) inline case-insensitive groups - RE2 doesn't support this +// +// We replace \s+(?!\S)|\s+ with \s+ and fix whitespace boundaries in encodeWithRegex(). +// The lookahead version splits "a b" into ["a", " ", " b"] (space prepended to word). +// Simple \s+ would give ["a", " ", "b"]. We post-process to match Python's behavior. +func rewritePatternForRE2(pattern string) string { + // Replace lookahead pattern with simple \s+ - we fix boundaries in encodeWithRegex() + pattern = strings.ReplaceAll(pattern, `\s+(?!\S)|\s+`, `\s+`) + + // Handle the pattern when it appears with a ? suffix (optional contractions in GPT-4o style) + // IMPORTANT: Must be done before the non-optional version to avoid partial replacement + pattern = strings.ReplaceAll(pattern, + `(?i:'s|'t|'re|'ve|'m|'ll|'d)?`, + `(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?`) + + // Expand case-insensitive contraction pattern to explicit alternations + // (?i:'s|'t|'re|'ve|'m|'ll|'d) -> '[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD] + pattern = strings.ReplaceAll(pattern, + `(?i:'s|'t|'re|'ve|'m|'ll|'d)`, + `(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])`) + + return pattern +} + +// Load loads a tokenizer from a path which can be: +// - A tokenizer.json file +// - A directory containing tokenizer.json or vocab.json + merges.txt +func Load(path string) (*Tokenizer, error) { + // Check if path is a directory + if info, err := os.Stat(path); err == nil && info.IsDir() { + dir := strings.TrimSuffix(path, "/") + "/" + // Try tokenizer.json first + if data, err := os.ReadFile(dir + "tokenizer.json"); err == nil { + return loadFromTokenizerJSON(data, dir) + } + // Fall back to vocab.json + merges.txt + return LoadVocabMerges(path) + } + + // It's a file - read it directly + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("failed to read tokenizer: %w", err) + } + + // Get directory for loading companion files + dir := "" + if idx := strings.LastIndex(path, "/"); idx >= 0 { + dir = path[:idx+1] + } + return loadFromTokenizerJSON(data, dir) +} + +// loadFromTokenizerJSON parses a tokenizer.json file +func loadFromTokenizerJSON(data []byte, dir string) (*Tokenizer, error) { + + var raw struct { + Model struct { + Type string `json:"type"` // "BPE" or "WordPiece" + Vocab map[string]int32 `json:"vocab"` + Merges json.RawMessage `json:"merges"` // Can be []string or [][]string (BPE only) + } `json:"model"` + PreTokenizer json.RawMessage `json:"pre_tokenizer"` + Decoder json.RawMessage `json:"decoder"` + AddedTokens []struct { + ID int32 `json:"id"` + Content string `json:"content"` + Special bool `json:"special"` + } `json:"added_tokens"` + } + + if err := json.Unmarshal(data, &raw); err != nil { + return nil, fmt.Errorf("failed to parse tokenizer: %w", err) + } + + // Parse merges - can be []string (Llama) or [][]string (GPT-OSS) + // WordPiece models don't have merges + var mergesStrings []string + if raw.Model.Type != "WordPiece" && raw.Model.Merges != nil { + var mergesArrays [][]string + if err := json.Unmarshal(raw.Model.Merges, &mergesStrings); err != nil { + // Try array of arrays format + if err := json.Unmarshal(raw.Model.Merges, &mergesArrays); err != nil { + return nil, fmt.Errorf("failed to parse merges: %w", err) + } + // Convert [][]string to []string + mergesStrings = make([]string, len(mergesArrays)) + for i, pair := range mergesArrays { + mergesStrings[i] = pair[0] + " " + pair[1] + } + } + } + + // Build tokenizer + t := &Tokenizer{ + vocab: &Vocabulary{ + Values: make([]string, len(raw.Model.Vocab)), + Reverse: raw.Model.Vocab, + Merges: make(map[string]int, len(mergesStrings)), + BOS: -1, + PAD: -1, + }, + specialTokens: make(map[string]int32), + } + + // Build values array + for token, id := range raw.Model.Vocab { + if int(id) >= len(t.vocab.Values) { + newValues := make([]string, id+1) + copy(newValues, t.vocab.Values) + t.vocab.Values = newValues + } + t.vocab.Values[id] = token + } + + // Build merges map + for i, merge := range mergesStrings { + t.vocab.Merges[merge] = i + } + + // Add special tokens to vocabulary + for _, tok := range raw.AddedTokens { + if int(tok.ID) >= len(t.vocab.Values) { + newValues := make([]string, tok.ID+1) + copy(newValues, t.vocab.Values) + t.vocab.Values = newValues + } + t.vocab.Values[tok.ID] = tok.Content + if tok.Special { + t.specialTokens[tok.Content] = tok.ID + } + } + + // Load special token configuration from companion files + loadSpecialTokenConfig(dir, t) + + // Precompute byte token IDs for <0xNN> fallback + initByteTokens(t) + + // Determine tokenizer type + switch { + case raw.Model.Type == "WordPiece": + t.typ = TokenizerWordPiece + case detectSentencePiece(raw.Decoder): + t.typ = TokenizerSentencePiece + default: + t.typ = TokenizerBPE + } + + // Parse and compile pretokenizer pattern (BPE only - SentencePiece doesn't use pretokenizer) + if t.typ == TokenizerBPE { + pattern := extractPretokenizer(raw.PreTokenizer) + if pattern == "" { + pattern = `'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+` + } + re, err := regexp.Compile(rewritePatternForRE2(pattern)) + if err != nil { + return nil, fmt.Errorf("failed to compile pretokenizer regex %q: %w", pattern, err) + } + t.pretokenizer = re + } + + return t, nil +} + +// detectSentencePiece checks if the decoder uses SentencePiece-style (▁ for spaces) +// vs GPT-2 byte-level encoding +func detectSentencePiece(data json.RawMessage) bool { + if data == nil { + return false + } + + // Check for Sequence decoder with Replace step (SentencePiece style) + var seq struct { + Type string `json:"type"` + Decoders []struct { + Type string `json:"type"` + Pattern struct { + String string `json:"String"` + } `json:"pattern"` + } `json:"decoders"` + } + if err := json.Unmarshal(data, &seq); err == nil { + if seq.Type == "Sequence" { + for _, dec := range seq.Decoders { + // Look for Replace decoder that converts ▁ to space + if dec.Type == "Replace" && dec.Pattern.String == "▁" { + return true + } + } + } + } + + // Check for direct ByteLevel decoder (GPT-2 style) + var simple struct { + Type string `json:"type"` + } + if err := json.Unmarshal(data, &simple); err == nil { + if simple.Type == "ByteLevel" { + return false + } + } + + return false +} + +// initByteTokens precomputes byte token IDs for <0xNN> fallback encoding +func initByteTokens(t *Tokenizer) { + for i := range t.vocab.byteTokens { + t.vocab.byteTokens[i] = -1 + } + for b := 0; b < 256; b++ { + token := fmt.Sprintf("<0x%02X>", b) + if id, ok := t.vocab.Reverse[token]; ok { + t.vocab.byteTokens[b] = id + } + } +} + +// extractPretokenizer extracts the regex pattern from the pre_tokenizer config +func extractPretokenizer(data json.RawMessage) string { + if data == nil { + return "" + } + + // Try to parse as a single Split pretokenizer + var single struct { + Type string `json:"type"` + Pattern struct { + Regex string `json:"Regex"` + } `json:"pattern"` + } + if err := json.Unmarshal(data, &single); err == nil && single.Pattern.Regex != "" { + return single.Pattern.Regex + } + + // Try to parse as Sequence of pretokenizers - use first Split pattern + var seq struct { + Type string `json:"type"` + Pretokenizers []struct { + Type string `json:"type"` + Pattern struct { + Regex string `json:"Regex"` + } `json:"pattern"` + } `json:"pretokenizers"` + } + if err := json.Unmarshal(data, &seq); err == nil && seq.Type == "Sequence" { + for _, pt := range seq.Pretokenizers { + if pt.Type == "Split" && pt.Pattern.Regex != "" { + return pt.Pattern.Regex + } + } + } + + return "" +} + +// isNonNewlineWhitespace returns true if s contains only whitespace characters (no newlines) +func isNonNewlineWhitespace(s string) bool { + if s == "" { + return false + } + for _, r := range s { + if r == '\n' || r == '\r' { + return false + } + if !unicode.IsSpace(r) { + return false + } + } + return true +} + +// splitBySpecialTokens splits text into parts, keeping special tokens as separate elements +func (t *Tokenizer) splitBySpecialTokens(s string) []string { + if len(t.specialTokens) == 0 { + return []string{s} + } + + // Sort special tokens by length (longest first) to match greedily + tokens := make([]string, 0, len(t.specialTokens)) + for tok := range t.specialTokens { + tokens = append(tokens, tok) + } + sort.Slice(tokens, func(i, j int) bool { + return len(tokens[i]) > len(tokens[j]) + }) + + var result []string + remaining := s + + for len(remaining) > 0 { + found := false + for _, tok := range tokens { + if strings.HasPrefix(remaining, tok) { + result = append(result, tok) + remaining = remaining[len(tok):] + found = true + break + } + } + if !found { + // Find next special token position + nextPos := len(remaining) + for _, tok := range tokens { + if idx := strings.Index(remaining, tok); idx != -1 && idx < nextPos { + nextPos = idx + } + } + if nextPos > 0 { + result = append(result, remaining[:nextPos]) + } + remaining = remaining[nextPos:] + } + } + + return result +} + +// Encode tokenizes text to token IDs. Parallelizes for large inputs (>10KB). +func (t *Tokenizer) Encode(s string, addBOS bool) []int32 { + // First: split by special tokens + parts := t.splitBySpecialTokens(s) + + // Second: collect all pretokenizer chunks + type chunk struct { + text string + isSpecial bool + } + var allChunks []chunk + + if t.pretokenizer != nil { + re := t.pretokenizer + for _, part := range parts { + if _, ok := t.specialTokens[part]; ok { + allChunks = append(allChunks, chunk{part, true}) + continue + } + + // Split by pretokenizer regex + type match struct{ start, end int } + var matches []match + offset := 0 + for offset < len(part) { + loc := re.FindStringIndex(part[offset:]) + if loc == nil { + break + } + matches = append(matches, match{offset + loc[0], offset + loc[1]}) + offset += loc[1] + } + + // Apply whitespace boundary fix for Python regex compatibility + for i := 0; i < len(matches)-1; i++ { + m := part[matches[i].start:matches[i].end] + next := part[matches[i+1].start:matches[i+1].end] + + if isNonNewlineWhitespace(m) && len(next) > 0 { + firstRune, _ := utf8.DecodeRuneInString(next) + if unicode.IsLetter(firstRune) { + lastSpaceStart := matches[i].end + for j := matches[i].end; j > matches[i].start; { + r, size := utf8.DecodeLastRuneInString(part[matches[i].start:j]) + if unicode.IsSpace(r) { + lastSpaceStart = j - size + break + } + j -= size + } + if lastSpaceStart > matches[i].start { + matches[i].end = lastSpaceStart + matches[i+1].start = lastSpaceStart + } else { + matches[i+1].start = matches[i].start + matches[i].end = matches[i].start + } + } + } + } + + for _, m := range matches { + if m.end > m.start { + allChunks = append(allChunks, chunk{part[m.start:m.end], false}) + } + } + } + } else { + // No pretokenizer - treat each part as a single chunk + for _, part := range parts { + if _, ok := t.specialTokens[part]; ok { + allChunks = append(allChunks, chunk{part, true}) + } else { + allChunks = append(allChunks, chunk{part, false}) + } + } + } + + // Encode chunks - parallel for large inputs (>4KB), sequential otherwise + var ids []int32 + if len(s) < 4096 { + for _, c := range allChunks { + if c.isSpecial { + if id, ok := t.specialTokens[c.text]; ok { + ids = append(ids, id) + } + } else { + ids = t.encodeChunkInto(c.text, ids) + } + } + } else { + numWorkers := runtime.GOMAXPROCS(0) + if numWorkers > len(allChunks) { + numWorkers = len(allChunks) + } + + chunksPer := (len(allChunks) + numWorkers - 1) / numWorkers + results := make([][]int32, numWorkers) + var wg sync.WaitGroup + + for i := 0; i < numWorkers; i++ { + start := i * chunksPer + end := start + chunksPer + if end > len(allChunks) { + end = len(allChunks) + } + if start >= end { + continue + } + + wg.Add(1) + go func(i int, chunks []chunk) { + defer wg.Done() + var r []int32 + for _, c := range chunks { + if c.isSpecial { + if id, ok := t.specialTokens[c.text]; ok { + r = append(r, id) + } + } else { + r = t.encodeChunkInto(c.text, r) + } + } + results[i] = r + }(i, allChunks[start:end]) + } + wg.Wait() + + for _, r := range results { + ids = append(ids, r...) + } + } + + if addBOS && t.vocab.BOS >= 0 { + ids = append([]int32{t.vocab.BOS}, ids...) + } + return ids +} + +// encodeChunkInto appends encoded tokens to ids and returns the extended slice +// Uses BPE merge algorithm when merges are available, otherwise longest-match +func (t *Tokenizer) encodeChunkInto(s string, ids []int32) []int32 { + if t.typ == TokenizerWordPiece { + return t.encodeWordPieceInto(s, ids) + } + + if s == "" { + return ids + } + + // Apply encoding transformation + // SentencePiece: replace space with ▁ + // BPE: convert bytes using precomputed table (GPT-2 byte-level encoding) + var encoded string + if t.typ == TokenizerSentencePiece { + encoded = strings.ReplaceAll(s, " ", "▁") + } else { + var sb strings.Builder + sb.Grow(len(s) * 2) + for i := 0; i < len(s); i++ { + sb.WriteRune(byteToRune[s[i]]) + } + encoded = sb.String() + } + + // Fast path: check if entire chunk is a single token + if id, ok := t.vocab.Reverse[encoded]; ok { + return append(ids, id) + } + + return t.encodeBPEMerge(encoded, ids) +} + +// encodeBPEMerge encodes using BPE merge algorithm. +// Repeatedly merges the pair with lowest rank until no more merges possible. +// Works correctly with empty merges (falls back to individual rune/byte encoding). +func (t *Tokenizer) encodeBPEMerge(encoded string, ids []int32) []int32 { + // Start with individual runes as parts + runes := []rune(encoded) + parts := make([]string, len(runes)) + for i, r := range runes { + parts[i] = string(r) + } + + // Repeatedly merge lowest-rank pair + for len(parts) > 1 { + minRank := int(0x7FFFFFFF) + minIdx := -1 + + for i := 0; i < len(parts)-1; i++ { + // Merge key format: "token1 token2" (space-separated) + mergeKey := parts[i] + " " + parts[i+1] + if rank, ok := t.vocab.Merges[mergeKey]; ok { + if rank < minRank { + minRank = rank + minIdx = i + } + } + } + + if minIdx < 0 { + break // No more merges possible + } + + // Merge the pair + parts[minIdx] = parts[minIdx] + parts[minIdx+1] + parts = append(parts[:minIdx+1], parts[minIdx+2:]...) + } + + // Convert parts to token IDs + for _, part := range parts { + if id, ok := t.vocab.Reverse[part]; ok { + ids = append(ids, id) + } else { + // Byte fallback for unknown parts + for _, b := range []byte(part) { + if id := t.vocab.byteTokens[b]; id >= 0 { + ids = append(ids, id) + } + } + } + } + + return ids +} + +// encodeWordPieceInto appends WordPiece tokens to ids and returns extended slice +// Uses greedy longest-match with ## prefix for continuation tokens +func (t *Tokenizer) encodeWordPieceInto(s string, ids []int32) []int32 { + if s == "" { + return ids + } + + // Check if entire string is in vocabulary (common case) + if id, ok := t.vocab.Reverse[s]; ok { + return append(ids, id) + } + + runes := []rune(s) + start := 0 + + for start < len(runes) { + end := len(runes) + found := false + + // Greedy longest-match + for end > start { + substr := string(runes[start:end]) + if start > 0 { + // Continuation token: prefix with ## + substr = "##" + substr + } + + if id, ok := t.vocab.Reverse[substr]; ok { + ids = append(ids, id) + found = true + start = end + break + } + end-- + } + + if !found { + // No match found - use [UNK] token or skip + if t.unkToken >= 0 { + ids = append(ids, t.unkToken) + } + start++ + } + } + + return ids +} + +// Decode converts token IDs back to text +func (t *Tokenizer) Decode(ids []int32) string { + var sb strings.Builder + + for _, id := range ids { + if int(id) >= len(t.vocab.Values) { + continue + } + + token := t.vocab.Values[id] + + switch t.typ { + case TokenizerWordPiece: + // WordPiece style: strip ## prefix from continuation tokens + if strings.HasPrefix(token, "##") { + sb.WriteString(token[2:]) + } else { + sb.WriteString(token) + } + case TokenizerSentencePiece: + // SentencePiece style: replace ▁ with space, decode byte tokens + token = strings.ReplaceAll(token, "▁", " ") + // Handle byte fallback tokens like <0x0D> + if len(token) == 6 && token[0] == '<' && token[1] == '0' && token[2] == 'x' && token[5] == '>' { + if v, err := strconv.ParseUint(token[3:5], 16, 8); err == nil { + sb.WriteByte(byte(v)) + continue + } + } + sb.WriteString(token) + default: + // GPT-2 BPE style: decode byte-level encoding + for _, r := range token { + switch { + case r == 0x0100: + // NULL byte (0x00 encoded as 0x0100) + sb.WriteByte(0) + continue + case r == 0x0143: + r = 0x00ad + case r > 0x0100 && r <= 0x0120: + r = r - 0x0100 + case r > 0x0120 && r <= 0x0142: + r = r - 0x00a2 + } + + // Write as byte, not UTF-8 encoded rune + sb.WriteByte(byte(r)) + } + } + } + + return sb.String() +} + +// VocabSize returns the vocabulary size +func (t *Tokenizer) VocabSize() int { + return len(t.vocab.Values) +} + +// BOS returns the beginning of sequence token ID +func (t *Tokenizer) BOS() int32 { + return t.vocab.BOS +} + +// EOS returns the first end of sequence token ID (for backwards compatibility) +func (t *Tokenizer) EOS() int32 { + if len(t.vocab.EOS) > 0 { + return t.vocab.EOS[0] + } + return -1 +} + +// EOSTokens returns all end of sequence token IDs +func (t *Tokenizer) EOSTokens() []int32 { + return t.vocab.EOS +} + +// PAD returns the padding token ID, or -1 if not set +func (t *Tokenizer) PAD() int32 { + return t.vocab.PAD +} + +// IsEOS returns true if the token ID is an end of sequence token +func (t *Tokenizer) IsEOS(id int32) bool { + for _, eos := range t.vocab.EOS { + if id == eos { + return true + } + } + return false +} + +// GetSpecialToken returns the token ID for a special token string +func (t *Tokenizer) GetSpecialToken(name string) (int32, bool) { + id, ok := t.specialTokens[name] + return id, ok +} + +// LoadVocabMerges loads a tokenizer from vocab.json + merges.txt format (GPT-style) +func LoadVocabMerges(dir string) (*Tokenizer, error) { + vocabPath := dir + "/vocab.json" + mergesPath := dir + "/merges.txt" + addedTokensPath := dir + "/added_tokens.json" + + // Load vocab + vocabData, err := os.ReadFile(vocabPath) + if err != nil { + return nil, fmt.Errorf("failed to read vocab.json: %w", err) + } + + vocabMap := make(map[string]int32) + if err := json.Unmarshal(vocabData, &vocabMap); err != nil { + return nil, fmt.Errorf("failed to parse vocab.json: %w", err) + } + + // Load merges + mergesData, err := os.ReadFile(mergesPath) + if err != nil { + return nil, fmt.Errorf("failed to read merges.txt: %w", err) + } + + mergesLines := strings.Split(string(mergesData), "\n") + var mergesStrings []string + for _, line := range mergesLines { + line = strings.TrimSpace(line) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + mergesStrings = append(mergesStrings, line) + } + + // Build tokenizer + t := &Tokenizer{ + vocab: &Vocabulary{ + Values: make([]string, len(vocabMap)), + Reverse: vocabMap, + Merges: make(map[string]int, len(mergesStrings)), + BOS: -1, + PAD: -1, + }, + specialTokens: make(map[string]int32), + } + + // Load added tokens if exists + if addedData, err := os.ReadFile(addedTokensPath); err == nil { + addedMap := make(map[string]int32) + if err := json.Unmarshal(addedData, &addedMap); err == nil { + for token, id := range addedMap { + vocabMap[token] = id + t.specialTokens[token] = id + } + } + } + + // Build values array + for token, id := range vocabMap { + if int(id) >= len(t.vocab.Values) { + newValues := make([]string, id+1) + copy(newValues, t.vocab.Values) + t.vocab.Values = newValues + } + t.vocab.Values[id] = token + } + + // Build merges map + for i, merge := range mergesStrings { + t.vocab.Merges[merge] = i + } + + // Load special token configuration from companion files + loadSpecialTokenConfig(dir+"/", t) + + // Precompute byte token IDs for <0xNN> fallback + initByteTokens(t) + + // GPT-2/tiktoken pretokenizer pattern + pattern := `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+` + re, err := regexp.Compile(rewritePatternForRE2(pattern)) + if err != nil { + return nil, fmt.Errorf("failed to compile pretokenizer regex: %w", err) + } + t.pretokenizer = re + + return t, nil +} diff --git a/x/imagegen/tokenizer/tokenizer_test.go b/x/imagegen/tokenizer/tokenizer_test.go new file mode 100644 index 000000000..2ac79ab1e --- /dev/null +++ b/x/imagegen/tokenizer/tokenizer_test.go @@ -0,0 +1,785 @@ +//go:build mlx + +package tokenizer + +import ( + "bytes" + "os" + "os/exec" + "path/filepath" + "regexp" + "strings" + "sync" + "testing" +) + +// TestPatternCompilation validates that HuggingFace pretokenizer patterns +// can be rewritten for Go's RE2 regexp engine and compiled successfully. +func TestPatternCompilation(t *testing.T) { + patterns := []struct { + name string + pattern string + }{ + {"llama3", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`}, + {"qwen2", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`}, + {"gpt4o", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`}, + {"gpt2", `'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}, + {"deepseek_cjk", `[一-龥\x{3040}-ゟ゠-ヿ]+`}, + } + + for _, p := range patterns { + t.Run(p.name, func(t *testing.T) { + rewritten := rewritePatternForRE2(p.pattern) + if _, err := regexp.Compile(rewritten); err != nil { + t.Errorf("failed to compile pattern: %v\noriginal: %s\nrewritten: %s", err, p.pattern, rewritten) + } + }) + } +} + +// TestRoundtrip verifies the fundamental property: encode(text) -> decode -> text +// This is the key invariant from tiktoken's test suite. +func TestRoundtrip(t *testing.T) { + tok, err := Load("testdata/mini_llama.json") + if err != nil { + t.Fatalf("failed to load tokenizer: %v", err) + } + + // Test cases covering key edge cases from tiktoken + inputs := []string{ + // Empty and simple + "", + "a", + "hello", + "hello world", + + // Whitespace edge cases + " ", + " ", + " ", + " hello", + "hello ", + " hello ", + "hello world", + "hello world", + "\t", + "\n", + "\r\n", + "hello\nworld", + "hello\n\nworld", + + // Contractions + "don't", + "I'm", + "we'll", + "they're", + "it's", + "DON'T", // uppercase + + // Numbers + "123", + "1234567890", + "3.14159", + "$100", + "50%", + + // Unicode + "こんにちは", // Japanese + "你好", // Chinese + "مرحبا", // Arabic (RTL) + "🎉", // Emoji + "Hello 世界", // Mixed + "café", // Accented + "naïve", // Diaeresis + "Ω≈ç√∫", // Math symbols + + // Code + "func main() {}", + "if (x == 0) { return; }", + "import \"fmt\"", + "x := 42", + "// comment", + "/* block */", + + // Repetition (tiktoken specifically tests this) + "aaaa", + "aaaaaaaaaaaa", + strings.Repeat("a", 100), + strings.Repeat("hello ", 50), + + // Punctuation + "...", + "!!!", + "???", + "hello, world!", + "(parentheses)", + "[brackets]", + "{braces}", + + // Mixed complexity + "The quick brown fox jumps over the lazy dog.", + "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + "func TestRoundtrip(t *testing.T) { t.Run(\"test\", func(t *testing.T) {}) }", + } + + for _, input := range inputs { + name := input + if len(name) > 30 { + name = name[:30] + "..." + } + if name == "" { + name = "" + } + name = strings.ReplaceAll(name, "\n", "\\n") + name = strings.ReplaceAll(name, "\t", "\\t") + + t.Run(name, func(t *testing.T) { + tokens := tok.Encode(input, false) + decoded := tok.Decode(tokens) + if decoded != input { + t.Errorf("roundtrip failed:\n input: %q\n tokens: %v\n decoded: %q", input, tokens, decoded) + } + }) + } +} + +// TestSpecialTokens verifies that special tokens are handled correctly +func TestSpecialTokens(t *testing.T) { + tok, err := Load("testdata/mini_llama.json") + if err != nil { + t.Fatalf("failed to load tokenizer: %v", err) + } + + // Special tokens should be preserved through encode/decode + t.Run("bos_preserved", func(t *testing.T) { + if tok.BOS() < 0 { + t.Skip("no BOS token") + } + tokens := tok.Encode("hello", true) + if len(tokens) == 0 || tokens[0] != tok.BOS() { + t.Errorf("BOS not prepended: got %v, want first token to be %d", tokens, tok.BOS()) + } + }) + + t.Run("special_token_split", func(t *testing.T) { + // If we have special tokens, verify they're split correctly + for tokenStr, tokenID := range tok.specialTokens { + input := "before" + tokenStr + "after" + tokens := tok.Encode(input, false) + + found := false + for _, id := range tokens { + if id == tokenID { + found = true + break + } + } + if !found { + t.Errorf("special token %q (id=%d) not found in encoding of %q: %v", + tokenStr, tokenID, input, tokens) + } + } + }) +} + +// TestConcurrency verifies thread-safe encoding +func TestConcurrency(t *testing.T) { + tok, err := Load("testdata/mini_llama.json") + if err != nil { + t.Fatalf("failed to load tokenizer: %v", err) + } + + input := "The quick brown fox jumps over the lazy dog." + expected := tok.Encode(input, false) + + var wg sync.WaitGroup + errors := make(chan error, 100) + + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + got := tok.Encode(input, false) + if len(got) != len(expected) { + errors <- nil // just signal error + return + } + for j := range got { + if got[j] != expected[j] { + errors <- nil + return + } + } + }() + } + + wg.Wait() + close(errors) + + if len(errors) > 0 { + t.Errorf("concurrent encoding produced inconsistent results") + } +} + +// TestIntegration runs against real model directories, comparing with Python transformers. +// Skips if model weights are not available. +func TestIntegration(t *testing.T) { + models := []string{ + "../weights/Llama-3.2-1B", + "../weights/gemma-3-1b-it", + "../weights/gpt-oss-20b", + } + + // Test inputs covering various edge cases + inputs := []string{ + "Hello, world!", + "The quick brown fox jumps over the lazy dog.", + "こんにちは世界", + "def fibonacci(n):\n if n <= 1:\n return n\n return fibonacci(n-1) + fibonacci(n-2)", + "1234567890", + " spaces ", + "don't won't can't", + } + + for _, modelPath := range models { + modelName := filepath.Base(modelPath) + + t.Run(modelName, func(t *testing.T) { + tokenizerPath := filepath.Join(modelPath, "tokenizer.json") + if _, err := os.Stat(tokenizerPath); err != nil { + t.Skipf("skipping: %s not found", tokenizerPath) + } + + tok, err := Load(tokenizerPath) + if err != nil { + t.Fatalf("failed to load tokenizer: %v", err) + } + + for _, input := range inputs { + t.Run(truncate(input, 20), func(t *testing.T) { + // Test roundtrip + tokens := tok.Encode(input, false) + decoded := tok.Decode(tokens) + if decoded != input { + t.Errorf("roundtrip failed:\n input: %q\n decoded: %q", input, decoded) + } + + // Compare with Python if available + if pythonTokens, err := pythonEncode(modelPath, input); err == nil { + if !equalInt32Slice(tokens, pythonTokens) { + t.Errorf("mismatch with Python:\n go: %v\n python: %v", tokens, pythonTokens) + } + } + }) + } + }) + } +} + +// pythonEncode calls Python transformers to encode text, for comparison +func pythonEncode(modelPath, text string) ([]int32, error) { + script := ` +import sys, json +from transformers import AutoTokenizer +tok = AutoTokenizer.from_pretrained(sys.argv[1]) +tokens = tok.encode(sys.argv[2], add_special_tokens=False) +print(json.dumps(tokens)) +` + cmd := exec.Command("python3", "-c", script, modelPath, text) + var out bytes.Buffer + cmd.Stdout = &out + cmd.Stderr = nil + + if err := cmd.Run(); err != nil { + return nil, err + } + + // Parse JSON array + var tokens []int32 + output := strings.TrimSpace(out.String()) + if output == "" || output == "[]" { + return []int32{}, nil + } + + // Simple parsing for [1, 2, 3] format + output = strings.Trim(output, "[]") + if output == "" { + return []int32{}, nil + } + + for _, s := range strings.Split(output, ",") { + s = strings.TrimSpace(s) + var v int32 + if _, err := parseIntSimple(s, &v); err == nil { + tokens = append(tokens, v) + } + } + + return tokens, nil +} + +func parseIntSimple(s string, v *int32) (bool, error) { + var n int64 + for _, c := range s { + if c >= '0' && c <= '9' { + n = n*10 + int64(c-'0') + } + } + *v = int32(n) + return true, nil +} + +func equalInt32Slice(a, b []int32) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + +func truncate(s string, n int) string { + if len(s) <= n { + return s + } + return s[:n] + "..." +} + +// TestBPEPretokenizer verifies BPE pretokenizer splits text correctly +// using the GPT-2 style regex pattern (no dependency on tokenizer files) +func TestBPEPretokenizer(t *testing.T) { + pattern := `'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+` + re := regexp.MustCompile(rewritePatternForRE2(pattern)) + + tests := []struct { + input string + expected []string + }{ + {"Hello", []string{"Hello"}}, + {"Hello world", []string{"Hello", " world"}}, + {"Hello, world!", []string{"Hello", ",", " world", "!"}}, + {"don't", []string{"don", "'t"}}, + {"I'm", []string{"I", "'m"}}, + {"123", []string{"123"}}, + {"12345", []string{"12345"}}, // GPT-2 pattern matches any digit sequence + {"a b", []string{"a", " ", " b"}}, // whitespace boundary: last space prepends to word + {" ", []string{" "}}, // pure whitespace stays together + {"\n\n", []string{"\n\n"}}, // newlines stay together + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + // Get regex matches + matches := re.FindAllStringIndex(tt.input, -1) + var chunks []string + for _, m := range matches { + chunks = append(chunks, tt.input[m[0]:m[1]]) + } + + // Apply whitespace boundary fix (same logic as Encode) + for i := 0; i < len(chunks)-1; i++ { + if isNonNewlineWhitespace(chunks[i]) && len(chunks[i+1]) > 0 { + r, _ := []rune(chunks[i+1])[0], 0 + if r >= 'A' && r <= 'z' { // simplified letter check + // Move last space to next chunk + if len(chunks[i]) > 0 { + lastSpace := chunks[i][len(chunks[i])-1:] + chunks[i] = chunks[i][:len(chunks[i])-1] + chunks[i+1] = lastSpace + chunks[i+1] + } + } + } + } + + // Filter empty chunks + var result []string + for _, c := range chunks { + if c != "" { + result = append(result, c) + } + } + + if len(result) != len(tt.expected) { + t.Errorf("got %v, want %v", result, tt.expected) + return + } + for i := range result { + if result[i] != tt.expected[i] { + t.Errorf("chunk %d: got %q, want %q", i, result[i], tt.expected[i]) + } + } + }) + } +} + +// TestSentencePiecePretokenizer verifies SentencePiece doesn't use pretokenizer +// and correctly replaces spaces with ▁ (no dependency on tokenizer files) +func TestSentencePiecePretokenizer(t *testing.T) { + // SentencePiece has no pretokenizer - whole text is one chunk + // Spaces are replaced with ▁ during encoding + + tests := []struct { + input string + expected string // after space replacement + }{ + {"Hello", "Hello"}, + {"Hello world", "Hello▁world"}, + {"Hello, world!", "Hello,▁world!"}, + {" spaces ", "▁▁▁spaces▁▁▁"}, + {" Hello", "▁Hello"}, + {"Hello ", "Hello▁"}, + {"a b c", "a▁b▁c"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + // SentencePiece encoding: replace space with ▁ + result := strings.ReplaceAll(tt.input, " ", "▁") + if result != tt.expected { + t.Errorf("got %q, want %q", result, tt.expected) + } + }) + } +} + +// TestWordPiecePretokenizer verifies WordPiece (BERT) pretokenizer splits correctly +// BertPreTokenizer splits on whitespace and punctuation +func TestWordPiecePretokenizer(t *testing.T) { + // BertPreTokenizer behavior: split on whitespace and punctuation + // Whitespace is stripped, punctuation becomes separate tokens + + tests := []struct { + input string + expected []string + }{ + {"Hello", []string{"Hello"}}, + {"Hello world", []string{"Hello", "world"}}, // whitespace stripped + {"Hello, world!", []string{"Hello", ",", "world", "!"}}, // punct separate + {"don't", []string{"don", "'", "t"}}, // apostrophe separate (unlike BPE) + {" spaces ", []string{"spaces"}}, // whitespace stripped + {"Hello.World", []string{"Hello", ".", "World"}}, // punct splits + {"test@email.com", []string{"test", "@", "email", ".", "com"}}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result := splitBertStyle(tt.input) + if len(result) != len(tt.expected) { + t.Errorf("got %v, want %v", result, tt.expected) + return + } + for i := range result { + if result[i] != tt.expected[i] { + t.Errorf("token %d: got %q, want %q", i, result[i], tt.expected[i]) + } + } + }) + } +} + +// splitBertStyle mimics BertPreTokenizer: split on whitespace and punctuation +func splitBertStyle(s string) []string { + var result []string + var current strings.Builder + + for _, r := range s { + if r == ' ' || r == '\t' || r == '\n' || r == '\r' { + // Whitespace: flush current token, don't add whitespace + if current.Len() > 0 { + result = append(result, current.String()) + current.Reset() + } + } else if isPunct(r) { + // Punctuation: flush current, add punct as separate token + if current.Len() > 0 { + result = append(result, current.String()) + current.Reset() + } + result = append(result, string(r)) + } else { + current.WriteRune(r) + } + } + if current.Len() > 0 { + result = append(result, current.String()) + } + return result +} + +func isPunct(r rune) bool { + // Common ASCII punctuation + return (r >= '!' && r <= '/') || (r >= ':' && r <= '@') || + (r >= '[' && r <= '`') || (r >= '{' && r <= '~') +} + +// TestRepeatedDigits verifies correct tokenization of repeated digit sequences. +// Llama-style tokenizers split digits in groups of 1-3 due to the \p{N}{1,3} pattern. +func TestRepeatedDigits(t *testing.T) { + tok, err := Load("./testdata/mini_llama.json") + if err != nil { + t.Skipf("mini_llama.json not available: %v", err) + } + + // Pattern: 1 digit, 2 digits, 3 digits, then repeats + // "0" -> [single], "00" -> [double], "000" -> [triple] + // "0000" -> [triple, single], etc. + tests := []struct { + input string + count int // expected token count + }{ + {"0", 1}, + {"00", 1}, + {"000", 1}, + {"0000", 2}, // 3 + 1 + {"00000", 2}, // 3 + 2 + {"000000", 2}, // 3 + 3 + {"0000000", 3}, + {"00000000", 3}, + {"000000000", 3}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + ids := tok.Encode(tt.input, false) + if len(ids) != tt.count { + t.Errorf("Encode(%q) = %d tokens, want %d", tt.input, len(ids), tt.count) + } + // Verify roundtrip + decoded := tok.Decode(ids) + if decoded != tt.input { + t.Errorf("Decode(Encode(%q)) = %q", tt.input, decoded) + } + }) + } +} + +// TestNullByte verifies that null bytes roundtrip correctly +func TestNullByte(t *testing.T) { + tok, err := Load("./testdata/mini_llama.json") + if err != nil { + t.Skipf("mini_llama.json not available: %v", err) + } + + ids := tok.Encode("\x00", false) + decoded := tok.Decode(ids) + if decoded != "\x00" { + t.Errorf("null byte roundtrip failed: got %q, want %q", decoded, "\x00") + } +} + +// TestTokenizerTypeDetection verifies correct detection of tokenizer types +func TestTokenizerTypeDetection(t *testing.T) { + tests := []struct { + name string + decoder string + expected TokenizerType + }{ + { + name: "ByteLevel decoder (BPE)", + decoder: `{"type": "ByteLevel"}`, + expected: TokenizerBPE, + }, + { + name: "Sequence with Replace ▁ (SentencePiece)", + decoder: `{ + "type": "Sequence", + "decoders": [ + {"type": "Replace", "pattern": {"String": "▁"}, "content": " "} + ] + }`, + expected: TokenizerSentencePiece, + }, + { + name: "null decoder (BPE default)", + decoder: `null`, + expected: TokenizerBPE, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + isSPM := detectSentencePiece([]byte(tt.decoder)) + var got TokenizerType + if isSPM { + got = TokenizerSentencePiece + } else { + got = TokenizerBPE + } + if got != tt.expected { + t.Errorf("got %v, want %v", got, tt.expected) + } + }) + } +} + +// TestPADTokenDefault verifies PAD() returns -1 when not configured +func TestPADTokenDefault(t *testing.T) { + tok, err := Load("testdata/mini_llama.json") + if err != nil { + t.Fatalf("failed to load tokenizer: %v", err) + } + + // mini_llama.json has no PAD token configured, should return -1 + if got := tok.PAD(); got != -1 { + t.Errorf("PAD() = %d, want -1 (not configured)", got) + } +} + +// TestPADTokenFromConfig verifies PAD token is loaded from tokenizer_config.json +func TestPADTokenFromConfig(t *testing.T) { + // Create temp directory with tokenizer files + dir := t.TempDir() + + // Write minimal tokenizer.json + tokenizerJSON := `{ + "model": { + "type": "BPE", + "vocab": {"<|endoftext|>": 0, "hello": 1, "world": 2}, + "merges": [] + }, + "added_tokens": [ + {"id": 0, "content": "<|endoftext|>", "special": true} + ] + }` + if err := os.WriteFile(filepath.Join(dir, "tokenizer.json"), []byte(tokenizerJSON), 0o644); err != nil { + t.Fatalf("failed to write tokenizer.json: %v", err) + } + + // Write tokenizer_config.json with pad_token + configJSON := `{ + "pad_token": "<|endoftext|>" + }` + if err := os.WriteFile(filepath.Join(dir, "tokenizer_config.json"), []byte(configJSON), 0o644); err != nil { + t.Fatalf("failed to write tokenizer_config.json: %v", err) + } + + tok, err := Load(dir) + if err != nil { + t.Fatalf("failed to load tokenizer: %v", err) + } + + if got := tok.PAD(); got != 0 { + t.Errorf("PAD() = %d, want 0 (<|endoftext|>)", got) + } +} + +// TestPADTokenFromSpecialTokensMap verifies PAD falls back to special_tokens_map.json +func TestPADTokenFromSpecialTokensMap(t *testing.T) { + dir := t.TempDir() + + // Write minimal tokenizer.json + tokenizerJSON := `{ + "model": { + "type": "BPE", + "vocab": {"": 0, "hello": 1, "world": 2}, + "merges": [] + }, + "added_tokens": [ + {"id": 0, "content": "", "special": true} + ] + }` + if err := os.WriteFile(filepath.Join(dir, "tokenizer.json"), []byte(tokenizerJSON), 0o644); err != nil { + t.Fatalf("failed to write tokenizer.json: %v", err) + } + + // Write special_tokens_map.json with pad_token + mapJSON := `{ + "pad_token": "" + }` + if err := os.WriteFile(filepath.Join(dir, "special_tokens_map.json"), []byte(mapJSON), 0o644); err != nil { + t.Fatalf("failed to write special_tokens_map.json: %v", err) + } + + tok, err := Load(dir) + if err != nil { + t.Fatalf("failed to load tokenizer: %v", err) + } + + if got := tok.PAD(); got != 0 { + t.Errorf("PAD() = %d, want 0 ()", got) + } +} + +// TestPADTokenWithContentObject verifies PAD token works with {"content": "..."} format +func TestPADTokenWithContentObject(t *testing.T) { + dir := t.TempDir() + + // Write minimal tokenizer.json + tokenizerJSON := `{ + "model": { + "type": "BPE", + "vocab": {"[PAD]": 0, "hello": 1}, + "merges": [] + }, + "added_tokens": [ + {"id": 0, "content": "[PAD]", "special": true} + ] + }` + if err := os.WriteFile(filepath.Join(dir, "tokenizer.json"), []byte(tokenizerJSON), 0o644); err != nil { + t.Fatalf("failed to write tokenizer.json: %v", err) + } + + // Write tokenizer_config.json with pad_token as object (HuggingFace format) + configJSON := `{ + "pad_token": {"content": "[PAD]", "lstrip": false, "normalized": false} + }` + if err := os.WriteFile(filepath.Join(dir, "tokenizer_config.json"), []byte(configJSON), 0o644); err != nil { + t.Fatalf("failed to write tokenizer_config.json: %v", err) + } + + tok, err := Load(dir) + if err != nil { + t.Fatalf("failed to load tokenizer: %v", err) + } + + if got := tok.PAD(); got != 0 { + t.Errorf("PAD() = %d, want 0 ([PAD])", got) + } +} + +// Benchmarks + +func BenchmarkEncode(b *testing.B) { + tok, err := Load("testdata/mini_llama.json") + if err != nil { + b.Fatalf("failed to load tokenizer: %v", err) + } + + inputs := []struct { + name string + text string + }{ + {"short", "Hello, world!"}, + {"medium", "The quick brown fox jumps over the lazy dog. " + strings.Repeat("This is a test. ", 10)}, + {"long", strings.Repeat("The quick brown fox jumps over the lazy dog. ", 100)}, + } + + for _, input := range inputs { + b.Run(input.name, func(b *testing.B) { + b.SetBytes(int64(len(input.text))) + for i := 0; i < b.N; i++ { + tok.Encode(input.text, false) + } + }) + } +} + +func BenchmarkDecode(b *testing.B) { + tok, err := Load("testdata/mini_llama.json") + if err != nil { + b.Fatalf("failed to load tokenizer: %v", err) + } + + text := strings.Repeat("The quick brown fox jumps over the lazy dog. ", 100) + tokens := tok.Encode(text, false) + + b.SetBytes(int64(len(text))) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + tok.Decode(tokens) + } +} diff --git a/x/kvcache/cache.go b/x/kvcache/cache.go new file mode 100644 index 000000000..f0627584a --- /dev/null +++ b/x/kvcache/cache.go @@ -0,0 +1,77 @@ +package kvcache + +import ( + "errors" + + "github.com/ollama/ollama/x/ml" + "github.com/ollama/ollama/x/model/input" +) + +var ( + ErrKvCacheFull = errors.New("could not find a kv cache slot") + ErrNotSupported = errors.New("model does not support operation") +) + +type Cache interface { + // ** used by model implementations ** + + // SetLayer sets the active layer of the cache + SetLayer(layer int) + + // Get returns the history of key and value tensors plus a mask + // + // The shape of the tensors is documented in the specific + // cache implementation used. + Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) + + // Put stores a batch of key and value in the cache + // + // The shape of the tensors is documented in the specific + // cache implementation used. + Put(ctx ml.Context, key, value ml.Tensor) + + // SetConfig controls optimizations (mostly backend-specific) that may transform + // the output of the cache to work better with specific kernels. If not called, + // the backend settings will be used. This works well when calling Attention. + // + // The config can be overridden by models, especially if they require vanilla + // output when implementing their own version of attention. To do this, pass + // an empty ml.CacheConfig. + // + // Most models will not need to use this. + SetConfig(ml.CacheConfig) + + // ** cache management ** + + // Init sets up runtime parameters. + // backend: Used to allocate cache data storage and execute management operations (such as defrag) + // dtype: The data type for storing cache entries + // maxSequences: The maximum number of sequences stored in the cache - across all batches + // capacity: The number of cache entries to store, per sequence + // maxBatch: The maximum number of tokens that can occur in a single batch + Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) + + // Close closes the cache and frees resources associated with it + Close() + + // StartForward is called before the start of the model's forward pass. + // For each token in the coming batch, there must be a corresponding + // entry in positions and seqs. reserve is to preallocate memory + // without actually storing data in the cache. + StartForward(ctx ml.Context, batch input.Batch, reserve bool) error + + // CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq + CopyPrefix(srcSeq, dstSeq int, len int32) + + // CanResume returns true if the cache can continue with the next token at + // the given position and sequence. Assumes that the caller has already + // verified the contents of the cache. + CanResume(seq int, pos int32) bool + + // Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set + // endIndex to math.MaxInt32 to remove everything starting at beginIndex. + // + // If an error occurs, the entire context for the sequence should be + // removed by calling Remove(seq, 0, math.MaxInt32) + Remove(seq int, beginIndex, endIndex int32) error +} diff --git a/x/kvcache/causal.go b/x/kvcache/causal.go new file mode 100644 index 000000000..967fed674 --- /dev/null +++ b/x/kvcache/causal.go @@ -0,0 +1,797 @@ +package kvcache + +// import ( +// "errors" +// "fmt" +// "log/slog" +// "math" +// "slices" + +// "github.com/ollama/ollama/ml" +// "github.com/ollama/ollama/model/input" +// ) + +// type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) + +// // Causal cache stores K and V tensors according to their position in the +// // sequence. Returns the history and a mask for attending to past tokens +// // +// // The tensors are of shape embed dim, kv heads, batch size +// // The mask is of shape history size, batch size +// type Causal struct { +// DType ml.DType + +// // swaWindowSize is the number of tokens that will be included in the mask +// // during attention operations. swaMemorySize is the number of tokens that +// // will be retained in memory for partial prefix caching. Set to math.MaxInt32 +// // for unlimited or if sliding window attention is not being used. +// swaWindowSize int32 +// swaMemorySize int32 + +// chunkSize int32 + +// opts CausalOptions + +// // maxBatch is the largest batch that we might receive +// maxBatch int + +// // config controls mostly backend-specific optimizations +// config *ml.CacheConfig + +// // ** current forward pass ** + +// // size of the current batch +// curBatchSize int + +// // locations for data storage for this batch +// curLoc ml.Tensor + +// // mask of the cache as used by this batch +// curMask ml.Tensor + +// // the active layer for Get and Put +// curLayer int + +// // locations in the cache that are needed for this batch +// curCellRange cellRange + +// // curSequences is the sequences corresponding to this pass's entries in the cache +// curSequences []int + +// // curPositions is the positions corresponding to this pass's entries in the cache +// curPositions []int32 + +// // ** cache metadata ** + +// // for each possible location in the cache, stores the position and set of sequences +// // that reference the data there +// cells []cacheCell + +// // maps from sequence to the range of locations where it is stored in the cache +// cellRanges map[int]cellRange + +// // ** cache data storage ** + +// shiftFn shiftFn +// backend ml.Backend +// ctxs map[int]ml.Context +// keys, values map[int]ml.Tensor + +// kHeadDims, vHeadDims, numKVHeads map[int]int +// } + +// type cacheCell struct { +// pos int32 +// sequences []int +// } + +// type cellRange struct { +// min int +// max int +// } + +// func NewCausalCache(shift shiftFn) *Causal { +// return &Causal{ +// shiftFn: shift, +// ctxs: make(map[int]ml.Context), +// keys: make(map[int]ml.Tensor), +// values: make(map[int]ml.Tensor), +// kHeadDims: make(map[int]int), +// vHeadDims: make(map[int]int), +// numKVHeads: make(map[int]int), +// } +// } + +// func NewSWACache(windowSize int32, shift shiftFn) *Causal { +// return &Causal{ +// swaWindowSize: windowSize, +// shiftFn: shift, +// ctxs: make(map[int]ml.Context), +// keys: make(map[int]ml.Tensor), +// values: make(map[int]ml.Tensor), +// kHeadDims: make(map[int]int), +// vHeadDims: make(map[int]int), +// numKVHeads: make(map[int]int), +// } +// } + +// func NewSWAMemCache(windowSize int32, memorySize int32, shift shiftFn) *Causal { +// return &Causal{ +// swaWindowSize: windowSize, +// swaMemorySize: memorySize, +// shiftFn: shift, +// ctxs: make(map[int]ml.Context), +// keys: make(map[int]ml.Tensor), +// values: make(map[int]ml.Tensor), +// kHeadDims: make(map[int]int), +// vHeadDims: make(map[int]int), +// numKVHeads: make(map[int]int), +// } +// } + +// func NewChunkedAttentionCache(chunkSize int32, shift shiftFn) *Causal { +// return &Causal{ +// chunkSize: chunkSize, +// shiftFn: shift, +// ctxs: make(map[int]ml.Context), +// keys: make(map[int]ml.Tensor), +// values: make(map[int]ml.Tensor), +// kHeadDims: make(map[int]int), +// vHeadDims: make(map[int]int), +// numKVHeads: make(map[int]int), +// } +// } + +// func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) { +// if c.config == nil { +// var config ml.CacheConfig +// if cc, ok := backend.(ml.BackendCacheConfig); ok { +// config = cc.CacheConfig() +// } +// c.config = &config +// } + +// if c.config.CachePadding == 0 { +// c.config.CachePadding = 1 +// } + +// if c.config.MaskBatchPadding == 0 { +// c.config.MaskBatchPadding = 1 +// } + +// // TODO what types do we handle here? +// // if c.config.MaskDType == ml.DTypeOther { +// // c.config.MaskDType = ml.DTypeFloat32 +// // } + +// if c.swaWindowSize == 0 { +// c.swaWindowSize = math.MaxInt32 +// } +// if c.swaMemorySize == 0 { +// c.swaMemorySize = c.swaWindowSize +// } +// // We will allocate space in the cache for the stop token, which won't be part of a follow on +// // sequence, so allocate an extra token of storage to ensure that we can jump back without +// // causing a cache break. As an optimization, only do this when we have parallel sequences +// // because the extra token will live in the batch buffer and won't get overwritten if we +// // only have a single sequence. +// if c.swaMemorySize != math.MaxInt32 && maxSequences > 1 { +// c.swaMemorySize = max(c.swaMemorySize, c.swaWindowSize+1) +// } +// if int(c.swaMemorySize) >= capacity { +// c.swaMemorySize = math.MaxInt32 +// } + +// if c.swaMemorySize < c.swaWindowSize { +// panic(fmt.Errorf("sliding window memory (%v) must be at least as large as the window (%v)", c.swaMemorySize, c.swaWindowSize)) +// } + +// var cacheSize int +// if c.swaMemorySize == math.MaxInt32 { +// cacheSize = maxSequences * capacity +// } else { +// cacheSize = (maxSequences * int(c.swaMemorySize)) + maxBatch +// } +// cacheSize = roundUp(cacheSize, c.config.CachePadding) +// c.cells = make([]cacheCell, cacheSize) + +// c.DType = dtype +// c.cellRanges = make(map[int]cellRange) +// c.backend = backend +// c.maxBatch = maxBatch +// } + +// func (c *Causal) SetConfig(config ml.CacheConfig) { +// if c.config != nil { +// panic("config cannot be changed after being previously set, either by the model or backend") +// } + +// c.config = &config +// } + +// func (c *Causal) Close() { +// slog.Info("XXX Causal.Close called", "number of contexts", len(c.ctxs)) +// for _, ctx := range c.ctxs { +// ctx.Close() +// } +// } + +// func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error { +// slog.Info("XXX Causal.StartForward", "cell count", len(c.cells), "prior batch size", c.curBatchSize, "positions", len(batch.Positions), "reserve", reserve, "batch", batch) +// // panic("XXX Causal.StartForward") +// c.curBatchSize = len(batch.Positions) +// c.curSequences = batch.Sequences +// c.curPositions = batch.Positions +// c.opts.Except = nil + +// var locs []int32 +// if !reserve { +// c.updateSlidingWindow() + +// var err error +// locs, err = c.findLocs() +// if err != nil { +// return err +// } +// slog.Info("XXX Causal.StartForward", "findLocs len", len(locs)) + +// for i, pos := range batch.Positions { +// seq := batch.Sequences[i] +// loc := int(locs[i]) + +// c.cells[loc] = cacheCell{pos: pos, sequences: []int{seq}} + +// seqRange, ok := c.cellRanges[seq] +// if !ok { +// seqRange = newRange() +// } + +// seqRange.min = min(seqRange.min, loc) +// c.curCellRange.min = min(c.curCellRange.min, loc) + +// seqRange.max = max(seqRange.max, loc) +// c.curCellRange.max = max(c.curCellRange.max, loc) + +// c.cellRanges[seq] = seqRange +// } +// } else { +// // If we are reserving memory, don't update any of the cache metadata but set the size +// // to the worst case. +// locs = make([]int32, c.curBatchSize) +// for i := range locs { +// locs[i] = int32(i) +// } +// c.curCellRange.min = 0 +// c.curCellRange.max = len(c.cells) - 1 +// } + +// // XXX Building up the locs for what's already processed (if any) +// dummyLocs := []int{} +// c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding) +// c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1 + +// for i := range c.curBatchSize { +// enabled := !slices.Contains(c.opts.Except, i) +// for j := c.curCellRange.min; j <= c.curCellRange.max; j++ { +// if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) || +// (enabled && c.cells[j].pos > c.curPositions[i]) || +// c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize || +// c.cells[j].pos < c.curPositions[i]-c.swaWindowSize { +// // mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1)) +// } else { +// if len(dummyLocs) == 0 || dummyLocs[len(dummyLocs)-1] != i { +// dummyLocs = append(dummyLocs, i) +// } +// } +// } +// } +// slog.Info("XXX Causa.StartForward calculated locations", "locs", dummyLocs) + +// slog.Info("XXX Causal.StartForward", "locs", locs) +// c.curLoc = ctx.Input().FromInts(locs, len(locs)) +// c.curMask = c.buildMask(ctx) + +// return nil +// } + +// func newRange() cellRange { +// return cellRange{ +// min: math.MaxInt, +// max: 0, +// } +// } + +// // Returns a slice of locations where each token in the batch should be stored +// func (c *Causal) findLocs() ([]int32, error) { +// loc := make([]int32, 0, c.curBatchSize) + +// for i := range c.cells { +// if len(c.cells[i].sequences) == 0 { +// loc = append(loc, int32(i)) +// if len(loc) >= c.curBatchSize { +// return loc, nil +// } +// } +// } + +// return nil, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize) +// } + +// func (c *Causal) updateSlidingWindow() { +// c.curCellRange = newRange() + +// if c.swaMemorySize == math.MaxInt32 { +// for _, seq := range c.curSequences { +// if seqRange, ok := c.cellRanges[seq]; ok { +// c.curCellRange.min = min(c.curCellRange.min, seqRange.min) +// c.curCellRange.max = max(c.curCellRange.max, seqRange.max) +// } +// } + +// return +// } + +// type lowestPosition struct { +// pos int32 +// curBatch bool +// } + +// // create a map of unique sequences to the lowest position in that sequence +// lowestPos := make(map[int]lowestPosition) +// for i := range c.curPositions { +// seq := c.curSequences[i] + +// lowest, ok := lowestPos[seq] +// if !ok { +// lowest = lowestPosition{pos: c.curPositions[i], curBatch: true} +// } else if c.curPositions[i] < lowest.pos { +// lowest.pos = c.curPositions[i] +// } + +// lowestPos[seq] = lowest +// } + +// // for any sequences are not part of this batch, clean up any tokens +// // that are no longer needed after the processing of the previous +// // batch +// for seq, seqRange := range c.cellRanges { +// if _, ok := lowestPos[seq]; !ok { +// var last int32 +// for i := seqRange.min; i <= seqRange.max; i++ { +// if slices.Contains(c.cells[i].sequences, seq) { +// last = max(last, c.cells[i].pos) +// } +// } + +// lowestPos[seq] = lowestPosition{pos: last + 1, curBatch: false} +// } +// } + +// // delete any entries that are beyond the window of the oldest position in the sequence +// for seq, lowest := range lowestPos { +// oldRange, ok := c.cellRanges[seq] +// if !ok { +// continue +// } + +// newRange := newRange() + +// for i := oldRange.min; i <= oldRange.max; i++ { +// if slices.Contains(c.cells[i].sequences, seq) { +// if c.cells[i].pos < lowest.pos-c.swaMemorySize { +// c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq }) +// } else { +// newRange.min = min(newRange.min, i) +// newRange.max = max(newRange.max, i) +// } +// if lowest.curBatch && c.cells[i].pos >= lowest.pos-c.swaWindowSize { +// c.curCellRange.min = min(c.curCellRange.min, i) +// c.curCellRange.max = max(c.curCellRange.max, i) +// } +// } +// } + +// c.cellRanges[seq] = newRange +// } +// } + +// func roundDown(length, pad int) int { +// return (length / pad) * pad +// } + +// func roundUp(length, pad int) int { +// return ((length + pad - 1) / pad) * pad +// } + +// // Builds a mask of history x batch indicating whether for each token in the batch the +// // token in the history should apply. This is based on both the sequence and causality (the +// // position of the history is not ahead of the token in the batch). +// func (c *Causal) buildMask(ctx ml.Context) ml.Tensor { +// // Align and pad the two dimensions as required by the backend +// batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding) + +// c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding) +// c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1 + +// length := c.curCellRange.max - c.curCellRange.min + 1 + +// mask := make([]float32, batchSize*length) + +// for i := range c.curBatchSize { +// enabled := !slices.Contains(c.opts.Except, i) +// for j := c.curCellRange.min; j <= c.curCellRange.max; j++ { +// if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) || +// (enabled && c.cells[j].pos > c.curPositions[i]) || +// c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize || +// c.cells[j].pos < c.curPositions[i]-c.swaWindowSize { +// mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1)) +// } +// } +// } + +// // Mask out any padding tokens we added. For padding that we added to the cache history, this +// // has already been masked out because the sequence doesn't match. +// for i := c.curBatchSize * length; i < len(mask); i++ { +// mask[i] = float32(math.Inf(-1)) +// } + +// maskTensor := ctx.Input().FromFloats(mask, batchSize, length) + +// // if c.config.MaskDType != ml.DTypeFloat32 { +// // maskTensor = maskTensor.Cast(ctx, c.config.MaskDType) +// // } + +// slog.Info("XXX Causal.buildMask", "c.curBatchSize", c.curBatchSize, "c.config.MaskBatchPadding", c.config.MaskBatchPadding, "c.curCellRange.min", c.curCellRange.min, "c.curCellRange.max", c.curCellRange.max, "size", len(mask), "shape", []int{1, batchSize, length}) + +// return maskTensor +// } + +// func (c *Causal) SetLayer(layer int) { +// c.curLayer = layer +// } + +// type CausalOptions struct { +// // Enabled controls whether the causal mask is generated for a particular index in a batch +// Except []int +// } + +// // SetCausal disables causal mask generation for a particular range of indicies in +// // the current batch for subsequent calls to Get. The state resets for the next forward pass. +// func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) { +// if !slices.Equal(c.opts.Except, opts.Except) { +// c.opts = opts +// if ctx != nil { +// c.curMask = c.buildMask(ctx) +// } +// } +// } + +// func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { +// key := c.keys[c.curLayer] +// value := c.values[c.curLayer] + +// kHeadDim := c.kHeadDims[c.curLayer] +// vHeadDim := c.vHeadDims[c.curLayer] +// numKVHeads := c.numKVHeads[c.curLayer] +// // rowSize := numKVHeads * c.curBatchSize +// // cachedSize := c.curMask.Dim(1) +// cachedSize := c.curLoc.Dim(0) +// // kCellSize := kHeadDim * numKVHeads +// // vCellSize := vHeadDim * numKVHeads + +// slog.Info("XXX Causal.Get full cache", "key", key) +// slog.Info("XXX Causal.Get full cache", "value", value) +// slog.Info("XXX Causal.Get full cache", "curloc", c.curLoc) +// slog.Info("XXX Causal.Get", "curMask", c.curMask) +// slog.Info("XXX Causal.Get", "kHeadDim", kHeadDim, "numKVHeads", numKVHeads, "cachedSize", cachedSize, "kHeadDim", kHeadDim) +// // panic("XXX") + +// // fmt.Fprintln(os.Stderr, key.ToString()) +// // panic("full cache value") + +// // TODO we should use TakeAxes to gather the cells from curLoc, but for now to be consistent with GGML, just grab a larger chunk and mask +// key = key.TakeAxes(ctx, c.curLoc, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim) +// // key = key.AsStrided(ctx, []int{1, numKVHeads, cachedSize, kHeadDim}, []int{}, rowSize*c.curCellRange.min) + +// // slog.Info("XXX Causal.Get after AsStrided", "key", key) +// // panic("XXX") + +// // if c.config.PermutedV { +// // panic("permuted") +// // // TODO not converted +// // vHeadDim := value.Dim(1) +// // elemSize := value.Stride(2) + +// // value = value.AsStrided(ctx, +// // []int{numKVHeads, vHeadDim, cachedSize}, +// // []int{value.Stride(0), value.Stride(1)}, +// // elemSize*c.curCellRange.min, +// // ) +// // } else { +// // vHeadDim := c.vHeadDims[c.curLayer] +// // rowSize := value.Stride(2) +// // slog.Info("XXX Causal.Get before AsStrided", "vHeadDim", vHeadDim, "rowSize", rowSize) +// // panic("XXX") + +// // TODO we should use TakeAxes to gather the cells from curLoc, but for now to be consistent with GGML, just grab a larger chunk and mask +// value = value.TakeAxes(ctx, c.curLoc, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim) +// // value = value.AsStrided(ctx, []int{1, numKVHeads, cachedSize, vHeadDim}, []int{}, rowSize*c.curCellRange.min) + +// // slog.Info("XXX Causal.Get after AsStrided", "value", value) +// // panic("XXX") + +// // } + +// // // TODO The mask changes from X,X to 1,X, and with the Row-order change +// // // the 1 becomes trailing and messes up later operations +// // // This isn't the right solution, but works around it... +// // if c.curMask.Dim(1) == 1 { +// // return key, value, c.curMask.Transpose(ctx, 1, 0, 2, 3) +// // } +// // fmt.Fprintln(os.Stderr, key.ToString()) +// // fmt.Fprintln(os.Stderr, value.ToString()) +// // panic("XXX") +// slog.Info("XXX Mask", "curLayer", c.curLayer, "shape", c.curMask.Shape()) + +// return key, value, c.curMask +// } + +// func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) { +// kHeadDim := key.Dim(3) +// vHeadDim := value.Dim(3) +// numKVHeads := key.Dim(1) +// batchSize := key.Dim(2) +// kCellSize := kHeadDim * numKVHeads +// vCellSize := vHeadDim * numKVHeads + +// // slog.Info("XXX Causal.Put", "key", key, "value", value) +// slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize) +// // panic("XXX") + +// if c.curBatchSize != batchSize { +// panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize)) +// } + +// // slog.Info("XXX", "c.ctxs", c.ctxs, "c.curLayer", c.curLayer, "backend", c.backend) +// if _, ok := c.ctxs[c.curLayer]; !ok { +// slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer) +// c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer) +// } + +// if _, ok := c.keys[c.curLayer]; !ok { +// slog.Info("XXX Causal.Put allocating keys", "c.curLayer", c.curLayer, "shape", []int{len(c.cells), kCellSize}) + +// c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), kCellSize) +// c.kHeadDims[c.curLayer] = kHeadDim +// c.vHeadDims[c.curLayer] = vHeadDim +// c.numKVHeads[c.curLayer] = numKVHeads +// } + +// if _, ok := c.values[c.curLayer]; !ok { +// // if c.config.PermutedV { +// // c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, numKVHeads, vHeadDim, len(c.cells)) +// // } else { +// c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), vCellSize) +// // } +// } + +// key = key.Reshape(ctx, batchSize, 1, kCellSize) //.Contiguous(ctx, false) // TODO contiguous may not be needed + +// // slog.Info("XXX Causal.Put after reshape", "keyCache", keyCache) +// // panic("XXX") +// // curLoc := 0 // TODO c.curLoc is now a tensor +// // kSize := numKVHeads * kHeadDim +// // vSize := numKVHeads * vHeadDim +// // start := []int{int(curLoc), 0} +// // kStop := []int{int(curLoc + batchSize), int(kSize)} +// // vStop := []int{int(curLoc + batchSize), int(vSize)} +// // strides := []int{1, 1} + +// // slog.Info("XXX Causal.Put Key SliceUpdate", "keyCache", keyCache) +// // slog.Info("XXX Causal.Put Key SliceUpdate", "key", key) + +// // slog.Info("XXX Causal.Put Key SliceUpdate", "start", start, "kStop", kStop, "strides", strides) + +// // ctx.Forward(c.keys[c.curLayer].SliceUpdate(ctx, key, start, kStop, strides)) +// ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLoc}, key, []int{0})) +// // fmt.Fprintln(os.Stderr, keyCache.ToString()) +// // panic("input value") + +// // fmt.Fprintln(os.Stderr, t.ToString()) +// // panic("XXX") + +// // if c.config.PermutedV { +// // panic("permuted") +// // // TODO not adjusted +// // value = value.Reshape(ctx, vHeadDim*numKVHeads, 1, batchSize) +// // value = value.Transpose(ctx, 2, 0, 1, 3) + +// // valueCache := c.values[c.curLayer] +// // valueCache = valueCache.Reshape(ctx, 1, len(c.cells), vHeadDim*numKVHeads) + +// // ctx.Forward(valueCache.SliceUpdate(ctx, value, start, vStop, strides)) +// // } else { +// value = value.Reshape(ctx, batchSize, 1, vCellSize) //.Contiguous(ctx, false) // TODO contiguous may not be needed +// // slog.Info("XXX Causal.Put Value SliceUpdate", "valueCache", valueCache) +// // slog.Info("XXX Causal.Put Value SliceUpdate", "value", value) +// // slog.Info("XXX Causal.Put Value SliceUpdate", "start", start, "vStop", vStop, "strides", strides) + +// ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLoc}, value, []int{0})) +// // } +// // fmt.Fprintln(os.Stderr, c.keys[c.curLayer].ToString()) +// // fmt.Fprintln(os.Stderr, c.values[c.curLayer].ToString()) +// // panic("XXX") + +// } + +// func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) { +// seqRange := newRange() + +// for i := range c.cells { +// // Remove the contents of dstSeq so that we only have the copied prefix, metadata will be reset at the end +// if slices.Contains(c.cells[i].sequences, dstSeq) { +// c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == dstSeq }) +// } + +// if slices.Contains(c.cells[i].sequences, srcSeq) && c.cells[i].pos < len { +// c.cells[i].sequences = append(c.cells[i].sequences, dstSeq) +// if i < seqRange.min { +// seqRange.min = i +// } +// if i > seqRange.max { +// seqRange.max = i +// } +// } +// } + +// c.cellRanges[dstSeq] = seqRange +// } + +// func (c *Causal) CanResume(seq int, pos int32) bool { +// if c.swaMemorySize == math.MaxInt32 { +// return true +// } + +// seqRange, ok := c.cellRanges[seq] +// if !ok { +// return false +// } + +// // for sliding window, check that the window of the new sequence is contained in +// // the window of what we are storing +// var first int32 = math.MaxInt32 +// var last int32 = -1 +// for i := seqRange.min; i <= seqRange.max; i++ { +// if slices.Contains(c.cells[i].sequences, seq) { +// first = min(first, c.cells[i].pos) +// last = max(last, c.cells[i].pos) +// } +// } + +// if last == -1 { +// return false +// } + +// posWindowStart := max(0, pos-c.swaWindowSize) +// return posWindowStart >= first && pos <= last+1 +// } + +// func (c *Causal) shift(seq int, beginIndex, offset int32) error { +// if c.shiftFn == nil { +// return ErrNotSupported +// } + +// seqRange := c.cellRanges[seq] + +// for start := seqRange.min; start <= seqRange.max; start += c.maxBatch { +// size := min(seqRange.max-start+1, c.maxBatch) +// offsets := make([]int32, size) + +// var batchFirst, batchLast int + +// batchFirst = -1 +// for i := range offsets { +// cell := c.cells[start+i] + +// if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex { +// offsets[i] = offset +// if batchFirst < 0 { +// batchFirst = i +// } +// batchLast = i +// } +// } + +// if batchFirst < 0 { +// continue +// } + +// offsets = offsets[batchFirst : batchLast+1] + +// slog.Info("XXX Causal.shift creating new temporary context") +// ctx := c.backend.NewContext() +// kShift := ctx.Input().FromInts(offsets, len(offsets)) + +// for i, key := range c.keys { +// if key == nil { +// continue +// } + +// kHeadDim := key.Dim(2) +// numKVHeads := key.Dim(1) +// rowSize := key.Stride(0) + +// key = key.AsStrided(ctx, +// []int{len(offsets), numKVHeads, kHeadDim}, +// []int{key.Stride(0), key.Stride(1)}, +// rowSize*(start+batchFirst), +// ) + +// roped, err := c.shiftFn(ctx, i, key, kShift) +// if err != nil { +// ctx.Close() +// return err +// } + +// ctx.Forward(roped.Copy(ctx, key)) +// } + +// ctx.Compute() +// ctx.Close() +// } + +// return nil +// } + +// func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error { +// // TODO(jessegross): We should check to see if removing the middle of the sequence will +// // cause the sliding window to encompass tokens that we no longer have. If so, then we +// // should return an error, which will trigger the runner to evaluate the full history and +// // rebuild the window. However, if we have multimodal inputs in our history, this reuse +// // results in use after free, so we don't do it for now. + +// var offset int32 +// if endIndex != math.MaxInt32 { +// offset = beginIndex - endIndex +// } + +// seqRange := newRange() + +// for i := range c.cells { +// if slices.Contains(c.cells[i].sequences, seq) { +// if c.cells[i].pos >= beginIndex && c.cells[i].pos < endIndex { +// c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq }) +// } else { +// if c.cells[i].pos >= endIndex { +// if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) { +// return errors.New("shifting cells shared by multiple sequences not supported") +// } + +// c.cells[i].pos += offset +// } +// if i < seqRange.min { +// seqRange.min = i +// } +// if i > seqRange.max { +// seqRange.max = i +// } +// } +// } +// } + +// if seqRange == newRange() { +// delete(c.cellRanges, seq) +// return nil +// } + +// c.cellRanges[seq] = seqRange + +// if endIndex != math.MaxInt32 { +// err := c.shift(seq, endIndex+offset, offset) +// if err != nil { +// return err +// } +// } + +// return nil +// } diff --git a/x/kvcache/causal_test.go b/x/kvcache/causal_test.go new file mode 100644 index 000000000..d7ac430b1 --- /dev/null +++ b/x/kvcache/causal_test.go @@ -0,0 +1,973 @@ +package kvcache + +// import ( +// "fmt" +// "math" +// "slices" +// "testing" + +// "github.com/ollama/ollama/ml" +// "github.com/ollama/ollama/model/input" +// ) + +// type testCase struct { +// name string +// in []float32 +// inShape []int +// seqs []int +// pos []int32 +// expected []float32 +// expectedShape []int +// expectedMask []float32 +// } + +// func runPermutedVariants(t *testing.T, fn func(t *testing.T, backend *testBackend)) { +// t.Helper() +// for _, permuted := range []bool{false, true} { +// t.Run(fmt.Sprintf("PermutedV=%t", permuted), func(t *testing.T) { +// fn(t, &testBackend{permutedV: permuted}) +// }) +// } +// } + +// func TestStore(t *testing.T) { +// runPermutedVariants(t, func(t *testing.T, backend *testBackend) { +// cache := NewCausalCache(nil) +// defer cache.Close() + +// cache.Init(backend, ml.DTypeF16, 1, 16, 16) + +// tests := []testCase{ +// { +// name: "FirstBatch", +// in: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234}, +// inShape: []int{2, 3, 4}, +// seqs: []int{0, 0, 0, 0}, +// pos: []int32{0, 1, 2, 3}, +// expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234}, +// expectedShape: []int{2, 3, 4}, +// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0}, +// }, +// { +// name: "SecondBatch", +// in: []float32{115, 215, 125, 225, 135, 235}, +// inShape: []int{2, 3, 1}, +// seqs: []int{0}, +// pos: []int32{4}, +// expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234, 115, 215, 125, 225, 135, 235}, +// expectedShape: []int{2, 3, 5}, +// expectedMask: []float32{0, 0, 0, 0, 0}, +// }, +// } + +// testCache(t, backend, cache, tests) +// }) +// } + +// func TestSWA(t *testing.T) { +// runPermutedVariants(t, func(t *testing.T, backend *testBackend) { +// cache := NewSWACache(1, nil) +// defer cache.Close() + +// cache.Init(backend, ml.DTypeF16, 1, 16, 16) + +// x := float32(math.Inf(-1)) + +// tests := []testCase{ +// { +// name: "FirstBatch", +// in: []float32{1, 2, 3, 4}, +// inShape: []int{1, 1, 4}, +// seqs: []int{0, 0, 0, 0}, +// pos: []int32{0, 1, 2, 3}, +// expected: []float32{1, 2, 3, 4}, +// expectedShape: []int{1, 1, 4}, +// expectedMask: []float32{ +// 0, x, x, x, +// 0, 0, x, x, +// x, 0, 0, x, +// x, x, 0, 0, +// }, +// }, +// { +// name: "SecondBatch", +// in: []float32{5, 6}, +// inShape: []int{1, 1, 2}, +// seqs: []int{0, 0}, +// pos: []int32{4, 5}, +// expected: []float32{5, 6, 3, 4}, +// expectedShape: []int{1, 1, 4}, +// expectedMask: []float32{ +// 0, x, x, 0, +// 0, 0, x, x, +// }, +// }, +// } + +// testCache(t, backend, cache, tests) +// }) +// } + +// func TestSWASeparateBatches(t *testing.T) { +// runPermutedVariants(t, func(t *testing.T, backend *testBackend) { +// cache := NewSWACache(1, nil) +// defer cache.Close() + +// cache.Init(backend, ml.DTypeF16, 2, 16, 2) + +// x := float32(math.Inf(-1)) + +// tests := []testCase{ +// { +// name: "First seq 0", +// in: []float32{1, 2}, +// inShape: []int{1, 1, 2}, +// seqs: []int{0, 0}, +// pos: []int32{0, 1}, +// expected: []float32{1, 2}, +// expectedShape: []int{1, 1, 2}, +// expectedMask: []float32{ +// 0, x, +// 0, 0, +// }, +// }, +// { +// name: "Second seq 0", +// in: []float32{3, 4}, +// inShape: []int{1, 1, 2}, +// seqs: []int{0, 0}, +// pos: []int32{2, 3}, +// expected: []float32{2, 3, 4}, +// expectedShape: []int{1, 1, 3}, +// expectedMask: []float32{ +// 0, 0, x, +// x, 0, 0, +// }, +// }, +// { +// name: "First seq 1", +// in: []float32{5, 6}, +// inShape: []int{1, 1, 2}, +// seqs: []int{1, 1}, +// pos: []int32{0, 1}, +// expected: []float32{5, 6}, +// expectedShape: []int{1, 1, 2}, +// expectedMask: []float32{ +// 0, x, +// 0, 0, +// }, +// }, +// { +// name: "Second seq 1", +// in: []float32{7, 8}, +// inShape: []int{1, 1, 2}, +// seqs: []int{1, 1}, +// pos: []int32{2, 3}, +// expected: []float32{6, 3, 4, 7, 8}, +// expectedShape: []int{1, 1, 5}, +// expectedMask: []float32{ +// 0, x, x, 0, x, +// x, x, x, 0, 0, +// }, +// }, +// { +// name: "Third seq 0", +// in: []float32{9, 10}, +// inShape: []int{1, 1, 2}, +// seqs: []int{0, 0}, +// pos: []int32{4, 5}, +// expected: []float32{9, 10, 3, 4}, +// expectedShape: []int{1, 1, 4}, +// expectedMask: []float32{ +// 0, x, x, 0, +// 0, 0, x, x, +// }, +// }, +// } + +// testCache(t, backend, cache, tests) +// }) +// } + +// func TestSWAMem(t *testing.T) { +// runPermutedVariants(t, func(t *testing.T, backend *testBackend) { +// cache := NewSWAMemCache(1, 3, nil) +// defer cache.Close() + +// cache.Init(backend, ml.DTypeF16, 1, 16, 16) + +// x := float32(math.Inf(-1)) + +// tests := []testCase{ +// { +// name: "FirstBatch", +// in: []float32{1, 2, 3, 4}, +// inShape: []int{1, 1, 4}, +// seqs: []int{0, 0, 0, 0}, +// pos: []int32{0, 1, 2, 3}, +// expected: []float32{1, 2, 3, 4}, +// expectedShape: []int{1, 1, 4}, +// expectedMask: []float32{ +// 0, x, x, x, +// 0, 0, x, x, +// x, 0, 0, x, +// x, x, 0, 0, +// }, +// }, +// { +// name: "SecondBatch", +// in: []float32{5, 6}, +// inShape: []int{1, 1, 2}, +// seqs: []int{0, 0}, +// pos: []int32{4, 5}, +// expected: []float32{5, 2, 3, 4, 6}, +// expectedShape: []int{1, 1, 5}, +// expectedMask: []float32{ +// 0, x, x, 0, x, +// 0, x, x, x, 0, +// }, +// }, +// } + +// testCache(t, backend, cache, tests) +// }) +// } + +// func TestChunkedAttention(t *testing.T) { +// runPermutedVariants(t, func(t *testing.T, backend *testBackend) { +// cache := NewChunkedAttentionCache(2, nil) +// defer cache.Close() + +// cache.Init(backend, ml.DTypeF16, 1, 16, 16) + +// x := float32(math.Inf(-1)) + +// testCache( +// t, backend, cache, +// []testCase{ +// { +// name: "FirstBatch", +// in: []float32{1, 2, 3, 4}, +// inShape: []int{1, 1, 4}, +// seqs: []int{0, 0, 0, 0}, +// pos: []int32{0, 1, 2, 3}, +// expected: []float32{1, 2, 3, 4}, +// expectedShape: []int{1, 1, 4}, +// expectedMask: []float32{ +// 0, x, x, x, +// 0, 0, x, x, +// x, x, 0, x, +// x, x, 0, 0, +// }, +// }, +// { +// name: "SecondBatch", +// in: []float32{5, 6, 7}, +// inShape: []int{1, 1, 3}, +// seqs: []int{0, 0, 0}, +// pos: []int32{4, 5, 6}, +// expected: []float32{1, 2, 3, 4, 5, 6, 7}, +// expectedShape: []int{1, 1, 7}, +// expectedMask: []float32{ +// x, x, x, x, 0, x, x, +// x, x, x, x, 0, 0, x, +// x, x, x, x, x, x, 0, +// }, +// }, +// { +// name: "ThirdBatch", +// in: []float32{8, 9}, +// inShape: []int{1, 1, 2}, +// seqs: []int{0, 0}, +// pos: []int32{7, 8}, +// expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9}, +// expectedShape: []int{1, 1, 9}, +// expectedMask: []float32{ +// x, x, x, x, x, x, 0, 0, x, +// x, x, x, x, x, x, x, x, 0, +// }, +// }, +// }, +// ) +// }) +// } + +// func TestSequences(t *testing.T) { +// runPermutedVariants(t, func(t *testing.T, backend *testBackend) { +// cache := NewCausalCache(nil) +// defer cache.Close() + +// cache.Init(backend, ml.DTypeF16, 1, 16, 16) + +// tests := []testCase{ +// { +// name: "FirstBatch", +// in: []float32{1, 2, 3, 4}, +// inShape: []int{1, 1, 4}, +// seqs: []int{0, 0, 1, 1}, +// pos: []int32{0, 1, 0, 1}, +// expected: []float32{1, 2, 3, 4}, +// expectedShape: []int{1, 1, 4}, +// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0}, +// }, +// { +// name: "SecondBatch", +// in: []float32{5, 6}, +// inShape: []int{1, 1, 2}, +// seqs: []int{0, 1}, +// pos: []int32{2, 2}, +// expected: []float32{1, 2, 3, 4, 5, 6}, +// expectedShape: []int{1, 1, 6}, +// expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0}, +// }, +// } + +// testCache(t, backend, cache, tests) +// }) +// } + +// func TestRemove(t *testing.T) { +// runPermutedVariants(t, func(t *testing.T, backend *testBackend) { +// cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { +// return key.Add(ctx, shift), nil +// }) +// defer cache.Close() + +// cache.Init(backend, ml.DTypeF16, 1, 16, 16) + +// x := float32(math.Inf(-1)) + +// tests := []testCase{ +// { +// name: "FirstBatch", +// in: []float32{1, 2, 3, 4}, +// inShape: []int{1, 1, 4}, +// seqs: []int{0, 0, 1, 1}, +// pos: []int32{0, 1, 0, 1}, +// expected: []float32{1, 2, 3, 4}, +// expectedShape: []int{1, 1, 4}, +// expectedMask: []float32{ +// 0, x, x, x, +// 0, 0, x, x, +// x, x, 0, x, +// x, x, 0, 0, +// }, +// }, +// } + +// testCache(t, backend, cache, tests) + +// err := cache.Remove(0, 1, math.MaxInt32) +// if err != nil { +// panic(err) +// } + +// tests = []testCase{ +// { +// name: "RemoveEnd", +// in: []float32{5, 6}, +// inShape: []int{1, 1, 2}, +// seqs: []int{0, 1}, +// pos: []int32{1, 2}, +// expected: []float32{1, 5, 3, 4, 6}, +// expectedShape: []int{1, 1, 5}, +// expectedMask: []float32{ +// 0, 0, x, x, x, +// x, x, 0, 0, 0, +// }, +// }, +// } + +// testCache(t, backend, cache, tests) + +// err = cache.Remove(0, 0, 1) +// if err != nil { +// panic(err) +// } + +// tests = []testCase{ +// { +// name: "RemoveMiddle", +// in: []float32{7, 8}, +// inShape: []int{1, 1, 2}, +// seqs: []int{0, 0}, +// pos: []int32{1, 2}, +// expected: []float32{7, 4, 3, 4, 6, 8}, +// expectedShape: []int{1, 1, 6}, +// expectedMask: []float32{ +// 0, 0, x, x, x, x, +// 0, 0, x, x, x, 0, +// }, +// }, +// } + +// testCache(t, backend, cache, tests) +// }) +// } + +// func TestCopy(t *testing.T) { +// runPermutedVariants(t, func(t *testing.T, backend *testBackend) { +// cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil }) +// defer cache.Close() + +// cache.Init(backend, ml.DTypeF16, 1, 16, 16) + +// tests := []testCase{ +// { +// name: "FirstBatch", +// in: []float32{1, 2, 3, 4}, +// inShape: []int{1, 1, 4}, +// seqs: []int{0, 0, 0, 0}, +// pos: []int32{0, 1, 2, 3}, +// expected: []float32{1, 2, 3, 4}, +// expectedShape: []int{1, 1, 4}, +// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0}, +// }, +// } + +// testCache(t, backend, cache, tests) + +// cache.CopyPrefix(0, 1, 2) + +// tests = []testCase{ +// { +// name: "Copy", +// in: []float32{5, 6}, +// inShape: []int{1, 1, 2}, +// seqs: []int{1, 1}, +// pos: []int32{3, 4}, +// expected: []float32{1, 2, 3, 4, 5, 6}, +// expectedShape: []int{1, 1, 6}, +// expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0}, +// }, +// } + +// testCache(t, backend, cache, tests) +// }) +// } + +// func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) { +// for _, test := range tests { +// t.Run(test.name, func(t *testing.T) { +// context := backend.NewContext() +// defer context.Close() + +// err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs}, false) +// if err != nil { +// panic(err) +// } + +// cache.SetLayer(0) +// tensor := context.FromFloats(test.in, test.inShape...) +// cache.Put(context, tensor, tensor) + +// out, _, mask := cache.Get(context) + +// context.Forward(out, mask).Compute(out, mask) + +// if !slices.Equal(out.Floats(), test.expected) { +// t.Errorf("TestCache: have %v; want %v", out.Floats(), test.expected) +// } + +// if !slices.Equal(out.Shape(), test.expectedShape) { +// t.Errorf("TestCache: has shape %v; want %v", out.Shape(), test.expectedShape) +// } + +// if !slices.Equal(mask.Floats(), test.expectedMask) { +// t.Errorf("TestCache: have mask: have %v want %v", mask.Floats(), test.expectedMask) +// } +// }) +// } +// } + +// func TestCanResume(t *testing.T) { +// runPermutedVariants(t, func(t *testing.T, backend *testBackend) { +// windowSize := int32(4) +// cache := NewSWACache(windowSize, nil) +// defer cache.Close() + +// cache.Init(backend, ml.DTypeF16, 1, 16, 16) + +// context := backend.NewContext() +// defer context.Close() + +// err := cache.StartForward(context, input.Batch{ +// Positions: []int32{0, 1, 2, 3, 4}, +// Sequences: []int{0, 0, 0, 0, 0}, +// }, false) +// if err != nil { +// t.Fatalf("StartForward failed: %v", err) +// } + +// cache.SetLayer(0) +// tensor := context.FromFloats([]float32{1, 2, 3, 4, 5}, 1, 1, 5) +// cache.Put(context, tensor, tensor) + +// // with window size 4, nothing has slid out of the window yet +// if !cache.CanResume(0, 0) { +// t.Errorf("CanResume(0, 0) = false, want true (within window)") +// } +// if !cache.CanResume(0, 1) { +// t.Errorf("CanResume(0, 1) = false, want true (within window)") +// } +// if !cache.CanResume(0, 2) { +// t.Errorf("CanResume(0, 2) = false, want true (within window)") +// } +// if !cache.CanResume(0, 3) { +// t.Errorf("CanResume(0, 3) = false, want true (latest position)") +// } +// if !cache.CanResume(0, 4) { +// t.Errorf("CanResume(0, 4) = false, want true (latest position)") +// } + +// // shift window by adding position 5 +// err = cache.StartForward(context, input.Batch{ +// Positions: []int32{5}, +// Sequences: []int{0}, +// }, false) +// if err != nil { +// t.Fatalf("StartForward failed: %v", err) +// } + +// cache.SetLayer(0) +// tensor = context.FromFloats([]float32{6}, 1, 1, 1) +// cache.Put(context, tensor, tensor) + +// // only the latest position has overlapping windows +// if cache.CanResume(0, 0) { +// t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)") +// } +// if cache.CanResume(0, 1) { +// t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)") +// } +// if cache.CanResume(0, 2) { +// t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)") +// } +// if cache.CanResume(0, 3) { +// t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)") +// } +// if cache.CanResume(0, 4) { +// t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)") +// } +// if !cache.CanResume(0, 5) { +// t.Errorf("after shift: CanResume(0, 5) = false, want true (latest position)") +// } +// }) +// } + +// func TestCanResumeSWAMem(t *testing.T) { +// runPermutedVariants(t, func(t *testing.T, backend *testBackend) { +// windowSize := int32(4) +// memSize := int32(5) +// cache := NewSWAMemCache(windowSize, memSize, nil) +// defer cache.Close() + +// cache.Init(backend, ml.DTypeF16, 1, 16, 16) + +// context := backend.NewContext() +// defer context.Close() + +// err := cache.StartForward(context, input.Batch{ +// Positions: []int32{0, 1, 2, 3, 4, 5, 6}, +// Sequences: []int{0, 0, 0, 0, 0, 0, 0}, +// }, false) +// if err != nil { +// t.Fatalf("StartForward failed: %v", err) +// } + +// cache.SetLayer(0) +// tensor := context.FromFloats([]float32{1, 2, 3, 4, 5, 6, 7}, 1, 1, 7) +// cache.Put(context, tensor, tensor) + +// // shift window by adding position 7 +// err = cache.StartForward(context, input.Batch{ +// Positions: []int32{7}, +// Sequences: []int{0}, +// }, false) +// if err != nil { +// t.Fatalf("StartForward failed: %v", err) +// } + +// cache.SetLayer(0) +// tensor = context.FromFloats([]float32{8}, 1, 1, 1) +// cache.Put(context, tensor, tensor) + +// // only the latest position has overlapping windows +// if cache.CanResume(0, 0) { +// t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)") +// } +// if cache.CanResume(0, 1) { +// t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)") +// } +// if cache.CanResume(0, 2) { +// t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)") +// } +// if cache.CanResume(0, 3) { +// t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)") +// } +// if cache.CanResume(0, 4) { +// t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)") +// } +// if cache.CanResume(0, 5) { +// t.Errorf("after shift: CanResume(0, 5) = true, want false (outside window)") +// } +// if !cache.CanResume(0, 6) { +// t.Errorf("after shift: CanResume(0, 6) = false, want true (inside window)") +// } +// if !cache.CanResume(0, 7) { +// t.Errorf("after shift: CanResume(0, 7) = false, want true (latest position)") +// } +// }) +// } + +// type testBackend struct { +// ml.Backend +// permutedV bool +// } + +// func (b *testBackend) NewContext() ml.Context { +// return &testContext{} +// } + +// func (b *testBackend) NewContextSize(int) ml.Context { +// return &testContext{} +// } + +// func (b *testBackend) CacheConfig() ml.CacheConfig { +// return ml.CacheConfig{PermutedV: b.permutedV} +// } + +// type testContext struct { +// ml.Context +// } + +// func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor { +// total := 0 + +// if len(shape) > 0 { +// total = 1 +// for _, s := range shape { +// total *= s +// } +// } + +// return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape} +// } + +// func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor { +// return c.Empty(dtype, shape...) +// } + +// func (c *testContext) FromFloats(s []float32, shape ...int) ml.Tensor { +// t := c.Empty(ml.DTypeF32, shape...).(*testTensor) + +// copy(t.data, s) + +// return t +// } + +// func (c *testContext) FromInts(s []int32, shape ...int) ml.Tensor { +// f := make([]float32, len(s)) +// for i := range f { +// f[i] = float32(s[i]) +// } + +// out := c.FromFloats(f, shape...) +// out.(*testTensor).dtype = ml.DTypeI32 + +// return out +// } + +// func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor { +// s := make([]float32, 0, int((stop-start)/step)) +// for i := start; i < stop; i += step { +// s = append(s, i) +// } + +// out := c.FromFloats(s, len(s)) +// out.(*testTensor).dtype = dtype +// return out +// } + +// func (c *testContext) Input() ml.Context { return c } +// func (c *testContext) Layer(int) ml.Context { return c } + +// func (c *testContext) Forward(...ml.Tensor) ml.Context { return c } + +// func (c *testContext) Compute(...ml.Tensor) {} + +// func (c *testContext) Reserve() {} + +// func (c *testContext) MaxGraphNodes() int { +// return 10 +// } + +// func (c *testContext) Close() {} + +// type testTensor struct { +// ml.Tensor + +// dtype ml.DType +// elementSize int +// data []float32 +// shape []int +// } + +// func (t *testTensor) Dim(n int) int { +// return t.shape[n] +// } + +// func (t *testTensor) Stride(n int) int { +// stride := t.elementSize +// for i := range n { +// stride *= t.shape[i] +// } + +// return stride +// } + +// func (t *testTensor) Shape() []int { +// return t.shape +// } + +// func (t *testTensor) DType() ml.DType { +// return t.dtype +// } + +// func (t *testTensor) Floats() []float32 { +// out := make([]float32, len(t.data)) +// copy(out, t.data) +// return out +// } + +// func (t *testTensor) Neg(ctx ml.Context) ml.Tensor { +// out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor) +// for i := range out.data { +// out.data[i] = -t.data[i] +// } +// return out +// } + +// func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor { +// out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor) + +// for i := range out.data { +// out.data[i] = t.data[i] + t2.(*testTensor).data[i] +// } + +// return out +// } + +// func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor { +// return &testTensor{ +// dtype: t.dtype, +// elementSize: t.elementSize, +// data: t.data, +// shape: shape, +// } +// } + +// func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor { +// offset /= t.elementSize + +// var s []int + +// switch len(shape) { +// case 1: +// s = []int{shape[0]} +// case 3: +// s = []int{shape[0], shape[2]} +// case 5: +// s = []int{shape[0], shape[2], shape[4]} +// default: +// panic("unsupported number of dimensions") +// } + +// context := &testContext{} + +// view := context.Empty(t.dtype, s...).(*testTensor) +// view.data = t.data[offset : offset+len(view.data)] + +// return view +// } + +// func (t *testTensor) Permute(ctx ml.Context, order ...int) ml.Tensor { +// if len(t.shape) > 4 || len(order) > 4 { +// panic("permute only supports up to 4 dimensions") +// } + +// if len(order) != len(t.shape) && len(order) != 4 { +// panic("invalid number of dimensions for permute") +// } + +// // ggml_permute expects 4 axes, so fill in any missing dimensions. +// orderFull := append(make([]int, 0, 4), order...) +// for len(orderFull) < 4 { +// orderFull = append(orderFull, len(orderFull)) +// } + +// seen := [4]bool{} + +// shape4 := [4]int{1, 1, 1, 1} +// for i := 0; i < len(t.shape) && i < 4; i++ { +// shape4[i] = t.shape[i] +// } + +// newShape4 := [4]int{1, 1, 1, 1} +// for axis := range 4 { +// dst := orderFull[axis] +// if dst < 0 || dst >= 4 { +// panic("invalid axis for permute") +// } +// if seen[dst] { +// panic("duplicate axis for permute") +// } +// seen[dst] = true +// newShape4[dst] = shape4[axis] +// } + +// total := len(t.data) +// newData := make([]float32, total) + +// if total > 0 { +// oldDims := shape4 +// newDims := newShape4 + +// oldStride := [4]int{1, 1, 1, 1} +// newStride := [4]int{1, 1, 1, 1} +// for i := 1; i < 4; i++ { +// oldStride[i] = oldStride[i-1] * oldDims[i-1] +// newStride[i] = newStride[i-1] * newDims[i-1] +// } + +// var coords [4]int +// var newCoords [4]int + +// for idx := range total { +// remainder := idx +// for axis := range 4 { +// dim := oldDims[axis] +// if dim == 0 { +// coords[axis] = 0 +// continue +// } +// coords[axis] = remainder % dim +// remainder /= dim +// } + +// for axis := range 4 { +// newCoords[orderFull[axis]] = coords[axis] +// } + +// newIndex := 0 +// for axis := range 4 { +// if newDims[axis] == 0 { +// continue +// } +// newIndex += newCoords[axis] * newStride[axis] +// } + +// newData[newIndex] = t.data[idx] +// } +// } + +// numDims := 4 +// for numDims > 1 && newShape4[numDims-1] <= 1 { +// numDims-- +// } + +// newShape := make([]int, numDims) +// copy(newShape, newShape4[:numDims]) + +// return &testTensor{ +// dtype: t.dtype, +// elementSize: t.elementSize, +// data: newData, +// shape: newShape, +// } +// } + +// func (t *testTensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tensor { +// dst := t +// srcTensor := src.(*testTensor) +// idxTensor := idxs.(*testTensor) + +// shapeTo4D := func(shape []int) [4]int { +// out := [4]int{1, 1, 1, 1} +// for i := 0; i < len(shape) && i < 4; i++ { +// out[i] = shape[i] +// } +// return out +// } + +// computeStrides := func(shape [4]int) [4]int { +// out := [4]int{1, 1, 1, 1} +// for i := 1; i < 4; i++ { +// out[i] = out[i-1] * shape[i-1] +// } +// return out +// } + +// dstShape4D := shapeTo4D(dst.shape) +// srcShape4D := shapeTo4D(srcTensor.shape) +// idxShape4D := shapeTo4D(idxTensor.shape) + +// if dstShape4D[0] != srcShape4D[0] || dstShape4D[2] != srcShape4D[2] || dstShape4D[3] != srcShape4D[3] { +// panic("SetRows requires matching tensor shapes") +// } + +// if srcShape4D[1] != idxShape4D[0] { +// panic("SetRows rows/index mismatch") +// } + +// if srcShape4D[2]%idxShape4D[1] != 0 || srcShape4D[3]%idxShape4D[2] != 0 { +// panic("SetRows cannot broadcast indices") +// } + +// if idxShape4D[3] != 1 { +// panic("SetRows expects 1D or 2D index tensors") +// } + +// dstStride := computeStrides(dstShape4D) +// srcStride := computeStrides(srcShape4D) +// idxStride := computeStrides(idxShape4D) + +// numColumns := srcShape4D[0] +// numRows := srcShape4D[1] + +// for dim3Index := range dstShape4D[3] { +// for dim2Index := range dstShape4D[2] { +// idxDim2 := 0 +// idxDim3 := 0 +// if idxShape4D[1] > 0 { +// idxDim2 = dim2Index % idxShape4D[1] +// } +// if idxShape4D[2] > 0 { +// idxDim3 = dim3Index % idxShape4D[2] +// } + +// idxBase := idxDim3*idxStride[2] + idxDim2*idxStride[1] +// srcBase := dim3Index*srcStride[3] + dim2Index*srcStride[2] +// dstBase := dim3Index*dstStride[3] + dim2Index*dstStride[2] + +// for row := range numRows { +// idx := int(idxTensor.data[idxBase+row*idxStride[0]]) +// if idx < 0 || idx >= dstShape4D[1] { +// panic("SetRows index out of range") +// } + +// srcOffset := srcBase + row*srcStride[1] +// dstOffset := dstBase + idx*dstStride[1] + +// copy(dst.data[dstOffset:dstOffset+numColumns], srcTensor.data[srcOffset:srcOffset+numColumns]) +// } +// } +// } + +// return dst +// } + +// func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor { +// copy(t2.(*testTensor).data, t.data) +// return nil +// } diff --git a/x/kvcache/encoder.go b/x/kvcache/encoder.go new file mode 100644 index 000000000..19a3839ce --- /dev/null +++ b/x/kvcache/encoder.go @@ -0,0 +1,156 @@ +package kvcache + +// import ( +// "fmt" + +// "github.com/ollama/ollama/ml" +// "github.com/ollama/ollama/model/input" +// ) + +// // Encoder cache stores K and V tensors that are position independent +// // +// // The tensors can be of any shape and will be returned as they were stored +// // The mask is currently always nil +// // +// // Not currently safe for multiple sequences +// type EncoderCache struct { +// // config controls mostly backend-specific optimizations +// config *ml.CacheConfig + +// // ** current forward pass ** + +// // the active layer for Get and Put +// curLayer int + +// // if something is stored during this pass, this +// // will be the position (but there is no guarantee +// // anything will be stored) +// curPos int32 + +// // curReserve indicates that this forward pass is only for +// // memory reservation and we should not update our metadata +// // based on it. +// curReserve bool + +// // ** cache metadata ** + +// // was something stored in the cache? +// encoderCached bool + +// // position of the cached data +// encoderPos int32 + +// // ** cache data storage ** +// backend ml.Backend +// ctxs map[int]ml.Context +// keys, values map[int]ml.Tensor +// } + +// func NewEncoderCache() *EncoderCache { +// return &EncoderCache{ +// ctxs: make(map[int]ml.Context), +// keys: make(map[int]ml.Tensor), +// values: make(map[int]ml.Tensor), +// } +// } + +// func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) { +// if c.config == nil { +// var config ml.CacheConfig +// if cc, ok := backend.(ml.BackendCacheConfig); ok { +// config = cc.CacheConfig() +// } +// c.config = &config +// } + +// if maxSequences > 1 { +// panic(fmt.Errorf("encoder cache does not support multiple sequences; requested: %v", maxSequences)) +// } + +// if c.config.CachePadding != 0 && c.config.CachePadding != 1 { +// panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding)) +// } + +// c.backend = backend +// } + +// func (c *EncoderCache) SetConfig(config ml.CacheConfig) { +// if c.config != nil { +// panic("config cannot be changed after being previously set, either by the model or backend") +// } + +// c.config = &config +// } + +// func (c *EncoderCache) Close() { +// for _, ctx := range c.ctxs { +// ctx.Close() +// } +// } + +// func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error { +// // We work with the most recent image +// if len(batch.Multimodal) > 0 { +// c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index] +// } + +// c.curReserve = reserve + +// return nil +// } + +// func (c *EncoderCache) SetLayer(layer int) { +// c.curLayer = layer +// } + +// func (c *EncoderCache) EncoderCached() bool { +// return c.encoderCached +// } + +// func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { +// return c.keys[c.curLayer], c.values[c.curLayer], nil +// } + +// func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) { +// if !c.curReserve { +// c.encoderPos = c.curPos +// c.encoderCached = true +// } + +// if c.config.PermutedV { +// value = value.Transpose(ctx, 1, 2, 0, 3) +// } + +// if _, ok := c.ctxs[c.curLayer]; !ok { +// c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer) +// } + +// if _, ok := c.keys[c.curLayer]; !ok { +// c.keys[c.curLayer] = c.ctxs[c.curLayer].Empty(key.DType(), key.Shape()...) +// } + +// if _, ok := c.values[c.curLayer]; !ok { +// c.values[c.curLayer] = c.ctxs[c.curLayer].Empty(value.DType(), value.Shape()...) +// } + +// ctx.Forward( +// key.Copy(ctx, c.keys[c.curLayer]), +// value.Copy(ctx, c.values[c.curLayer]), +// ) +// } + +// func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) { +// panic("encoder cache does not support multiple sequences") +// } + +// func (c *EncoderCache) CanResume(seq int, pos int32) bool { +// return true +// } + +// func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error { +// if c.encoderPos >= beginIndex && c.encoderPos < endIndex { +// c.encoderCached = false +// } + +// return nil +// } diff --git a/x/kvcache/mlx.go b/x/kvcache/mlx.go new file mode 100644 index 000000000..fa3865104 --- /dev/null +++ b/x/kvcache/mlx.go @@ -0,0 +1,144 @@ +//go:build mlx + +package kvcache + +import ( + "github.com/ollama/ollama/x/ml" + "github.com/ollama/ollama/x/model/input" +) + +// Causal cache stores K and V tensors according to their position in the +// sequence. Returns the history and a mask for attending to past tokens +type MLXCausal struct { + DType ml.DType + + // locations for data storage for this batch + curLocPut ml.Tensor + + // locations for data storage for this batch + curLocGet ml.Tensor + + // the active layer for Get and Put + curLayer int + + capacity int + + offset int + + backend ml.Backend + ctxs map[int]ml.Context + keys, values map[int]ml.Tensor + + // TODO is this needed per layer, or will it always be consistent? + kHeadDims, vHeadDims, numKVHeads map[int]int +} + +func NewMLXCausalCache() *MLXCausal { + return &MLXCausal{ + ctxs: make(map[int]ml.Context), + keys: make(map[int]ml.Tensor), + values: make(map[int]ml.Tensor), + kHeadDims: make(map[int]int), + vHeadDims: make(map[int]int), + numKVHeads: make(map[int]int), + } +} + +func (c *MLXCausal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) { + c.DType = dtype + c.capacity = capacity + c.backend = backend +} + +func (c *MLXCausal) SetConfig(config ml.CacheConfig) {} + +func (c *MLXCausal) SetLayer(layer int) { + c.curLayer = layer +} + +func (c *MLXCausal) Close() { + // slog.Info("XXX MLXCausal.Close called", "number of contexts", len(c.ctxs)) + for _, ctx := range c.ctxs { + ctx.Close() + } +} + +func (c *MLXCausal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error { + locsPut := make([]int32, len(batch.Positions)) + for i := c.offset; i < len(batch.Positions); i++ { + locsPut[i-c.offset] = int32(i) + } + c.offset += len(batch.Positions) + locsGet := make([]int32, c.offset) + for i := range c.offset { + locsGet[i] = int32(i) + } + c.curLocGet = ctx.Input().FromInts(locsGet, len(locsGet)) + c.curLocPut = ctx.Input().FromInts(locsPut, len(locsPut)) + // slog.Info("XXX MLXCausal.StartForward", "offset", c.offset, "put", locsPut, "get", locsGet) + + return nil +} +func (c *MLXCausal) Put(ctx ml.Context, key, value ml.Tensor) { + kHeadDim := key.Dim(3) + vHeadDim := value.Dim(3) + numKVHeads := key.Dim(1) + batchSize := key.Dim(2) + kCellSize := kHeadDim * numKVHeads + vCellSize := vHeadDim * numKVHeads + // slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize, "kCellSize", kCellSize, "vCellSize", vCellSize) + + if _, ok := c.ctxs[c.curLayer]; !ok { + // slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer) + c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer) + } + + if _, ok := c.keys[c.curLayer]; !ok { + // slog.Info("XXX MLXCausal.Put allocating keys and values", "c.curLayer", c.curLayer, "shape", []int{c.capacity, kCellSize}) + c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, kCellSize) + c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, vCellSize) + c.kHeadDims[c.curLayer] = kHeadDim + c.vHeadDims[c.curLayer] = vHeadDim + c.numKVHeads[c.curLayer] = numKVHeads + } + key = key.Reshape(ctx, batchSize, 1, kCellSize) + + // slog.Info("XXX MLXCausal.Put ", "c.keys[c.curLayer]", c.keys[c.curLayer]) + // slog.Info("XXX MLXCausal.Put ", "c.curLocPut", c.curLocPut) + // slog.Info("XXX MLXCausal.Put ", "key", key) + ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, key, []int{0})) + value = value.Reshape(ctx, batchSize, 1, vCellSize) + ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, value, []int{0})) + +} + +func (c *MLXCausal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { + key := c.keys[c.curLayer] + value := c.values[c.curLayer] + + kHeadDim := c.kHeadDims[c.curLayer] + vHeadDim := c.vHeadDims[c.curLayer] + numKVHeads := c.numKVHeads[c.curLayer] + // rowSize := numKVHeads * c.curBatchSize + // cachedSize := c.curMask.Dim(1) + cachedSize := c.curLocGet.Dim(0) + // kCellSize := kHeadDim * numKVHeads + // vCellSize := vHeadDim * numKVHeads + // slog.Info("XXX MLXCausal.Get", "shape", []int{1, numKVHeads, cachedSize, kHeadDim}) + + key = key.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim) + value = value.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim) + return key, value, nil +} + +func (c *MLXCausal) CopyPrefix(srcSeq, dstSeq int, len int32) { + panic("not implemented") +} + +func (c *MLXCausal) CanResume(seq int, pos int32) bool { + panic("not implemented") +} + +func (c *MLXCausal) Remove(seq int, beginIndex, endIndex int32) error { + panic("not implemented") +} diff --git a/x/kvcache/wrapper.go b/x/kvcache/wrapper.go new file mode 100644 index 000000000..69e07dc96 --- /dev/null +++ b/x/kvcache/wrapper.go @@ -0,0 +1,110 @@ +package kvcache + +// import ( +// "math" + +// "github.com/ollama/ollama/ml" +// "github.com/ollama/ollama/model/input" +// ) + +// // Wrapper cache is a container for multiple types of caches, +// // such as for the encoding and decoding portions of a model. +// type WrapperCache struct { +// // caches we are wrapping +// caches []Cache + +// // cache to be used for this layer +// curType int +// } + +// func NewWrapperCache(caches ...Cache) *WrapperCache { +// return &WrapperCache{ +// caches: caches, +// } +// } + +// func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) { +// for _, cache := range c.caches { +// cache.Init(backend, dtype, maxSequences, capacity, maxBatch) +// } +// } + +// func (c *WrapperCache) SetConfig(config ml.CacheConfig) { +// for _, cache := range c.caches { +// cache.SetConfig(config) +// } +// } + +// func (c *WrapperCache) Close() { +// for _, cache := range c.caches { +// cache.Close() +// } +// } + +// func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error { +// for i, cache := range c.caches { +// err := cache.StartForward(ctx, batch, reserve) +// if err != nil { +// // unwind on error - Remove with endIndex set to math.MaxInt32 does not fail +// for j := i - 1; j >= 0; j-- { +// for k := range batch.Positions { +// _ = c.caches[j].Remove(batch.Sequences[k], batch.Positions[k], math.MaxInt32) +// } +// } +// return err +// } +// } + +// c.curType = 0 +// return nil +// } + +// func (c *WrapperCache) SetLayer(layer int) { +// for _, cache := range c.caches { +// cache.SetLayer(layer) +// } +// } + +// func (c *WrapperCache) SetLayerType(layerType int) { +// c.curType = layerType +// } + +// func (c *WrapperCache) UnderlyingCache() Cache { +// return c.caches[c.curType] +// } + +// func (c *WrapperCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { +// return c.caches[c.curType].Get(ctx) +// } + +// func (c *WrapperCache) Put(ctx ml.Context, key, value ml.Tensor) { +// c.caches[c.curType].Put(ctx, key, value) +// } + +// func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) { +// for _, cache := range c.caches { +// cache.CopyPrefix(srcSeq, dstSeq, len) +// } +// } + +// func (c *WrapperCache) CanResume(seq int, pos int32) bool { +// for _, cache := range c.caches { +// if !cache.CanResume(seq, pos) { +// return false +// } +// } + +// return true +// } + +// func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error { +// // If the one of these fails, the caller is supposed to retry with endIndex set to math.MaxInt32, which should not fail +// for _, cache := range c.caches { +// err := cache.Remove(seq, beginIndex, endIndex) +// if err != nil { +// return err +// } +// } + +// return nil +// } diff --git a/x/ml/backend.go b/x/ml/backend.go new file mode 100644 index 000000000..31ff3541e --- /dev/null +++ b/x/ml/backend.go @@ -0,0 +1,433 @@ +package ml + +import ( + "fmt" + "log/slog" + "os" + + "github.com/ollama/ollama/fs" +) + +type Backend interface { + // Close frees all memory associated with this backend + // Close() + + // Load(ctx context.Context, progress func(float32)) error + + // BackendMemory returns the memory allocations that were made for this model + // BackendMemory() BackendMemory + + Config() fs.Config + Get(name string) Tensor + NewContext() Context + // NewContextSize(size int) Context + + // Enumerate the devices available for inference via this backend + // BackendDevices() []DeviceInfo +} + +// BackendCacheConfig should be implemented by backends that need special output +// from the cache to meet specific requirements. It is frequently implemented in +// conjunction with ScaledDotProductAttention. +type BackendCacheConfig interface { + CacheConfig() CacheConfig +} + +// CacheConfig controls optimizations (mostly backend-specific) that may transform +// the output the cache to work better with specific kernels. +type CacheConfig struct { + // CachePadding specifies the multiple for the number of tokens of cache history + // that will be returned from cache Get for k, v and mask. The capacity of the + // cache itself will also be increased to a multiple of this size if needed. + CachePadding int + + // PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put + // and return the permuted version via Get. This uses the cache copy operation + // to avoid a Contiguous call on the permuted tensor. + PermutedV bool + + // MaskDType specifies the data type for generating the mask. If unset it will + // default to DTypeF32. + MaskDType DType + + // MaskBatchPadding specifies the multiple for the batch size dimension in the mask. + // Any position that does not correspond to an actual token will be filled with -Inf. + MaskBatchPadding int +} + +// BackendParams controls how the backend loads and executes models +type BackendParams struct { + // AllocMemory causes the backend to allocate memory for the model. If + // false, this is only being used for discovering the required amount of + // memory and cannot load the model for running. + AllocMemory bool + + // NumThreads sets the number of threads to use if running on the CPU + NumThreads int + + // GPULayers is the set of layers to offload to GPUs + GPULayers GPULayersList + + // FlashAttention indicates that we should use a fused flash attention kernel + FlashAttention bool +} + +var backends = make(map[string]func(string, BackendParams) (Backend, error)) + +func RegisterBackend(name string, f func(string, BackendParams) (Backend, error)) { + if _, ok := backends[name]; ok { + panic("backend: backend already registered") + } + + backends[name] = f +} + +func NewBackend(modelPath string, params BackendParams) (Backend, error) { + be := os.Getenv("OLLAMA_BACKEND") + if be == "" { + be = "mlx" + slog.Info("Defaulting to " + be + ". Set OLLAMA_BACKEND to override") + } + slog.Info("Loading new engine", "backend", be) + if backend, ok := backends[be]; ok { + return backend(modelPath, params) + } + + return nil, fmt.Errorf("unsupported backend") +} + +type Context interface { + Empty(dtype DType, shape ...int) Tensor + Zeros(dtype DType, shape ...int) Tensor + // FromBytes(dtype DType, s []byte, shape ...int) Tensor + FromFloats(s []float32, shape ...int) Tensor + FromInts(s []int32, shape ...int) Tensor + RandomNormal(shape []int, dtype DType, loc, scale float32, key Tensor) Tensor + + // Arange creates a 1D tensor with values within an interval (start, stop] increased by step. + Arange(start, stop, step float32, dtype DType) Tensor + + Forward(...Tensor) Context + + // SetBatchSize provides a hint on the batch size to optimize processing + // Uses heuristics if not set + // SetBatchSize(int) + + Compute(...Tensor) + // ComputeWithNotify(func(), ...Tensor) // notify callback once compute has begun + + // Reserve is analogous to Compute but rather than executing a + // graph, simply preallocates memory. Typically called with a + // worst case graph to ensure all resources are available for + // for future inference. + // Reserve() + + // MaxGraphNodes() int + Close() + + // Input returns a context appropriate for creating tensors that are + // inputs to the model (which includes things like output locations) + Input() Context + + // Layer returns a context appropriate for creating intermediate tensors + Layer(int) Context + + // Load a tensor from "filename" safetensors file, and compare with the input tensor + // Returns error if the shape is inconsistent, or similarity measures are below 99% + CompareWith(filename string, tensors map[string]Tensor, abortOnError bool) error +} + +type RoPEOptions struct { + Base *float32 + Freqs Tensor +} + +func WithRoPEBase(base float32) func(*RoPEOptions) { + return func(opts *RoPEOptions) { + opts.Base = &base + } +} + +func WithRoPEFreqs(freqs Tensor) func(*RoPEOptions) { + return func(opts *RoPEOptions) { + opts.Freqs = freqs + } +} + +type Tensor interface { + ToString() string + RoPE(ctx Context, dims int, traditional bool, scale float32, offset int, options ...func(*RoPEOptions)) Tensor + ScaledDotProductAttention(ctx Context, keys, values Tensor, scale float64, maskMode string, mask Tensor, sinks Tensor) Tensor + TakeAxes(ctx Context, indicies Tensor, axes int) Tensor + // TakeAxes(ctx Context, axes int, indicies ...int) Tensor + + Dim(n int) int + Stride(n int) int + + Shape() []int + DType() DType + // Cast(ctx Context, dtype DType) Tensor + + // Bytes() []byte + Floats() []float32 + Ints() []int32 + + // FromBytes([]byte) + // FromFloats([]float32) + // FromInts([]int32) + + Add(ctx Context, t2 Tensor) Tensor + Sub(ctx Context, t2 Tensor) Tensor + // Mul(ctx Context, t2 Tensor) Tensor + // Div(ctx Context, t2 Tensor) Tensor + + Max(ctx Context, axes []int, keepDims bool) Tensor + Min(ctx Context, axes []int, keepDims bool) Tensor + + Matmul(ctx Context, a2 Tensor) Tensor + // Mulmat(ctx Context, t2 Tensor) Tensor + // MulmatFullPrec(ctx Context, t2 Tensor) Tensor + // MulmatID(ctx Context, t2, ids Tensor) Tensor + // AddID(ctx Context, t2, ids Tensor) Tensor + + Softmax(ctx Context) Tensor + L2Norm(ctx Context, eps float32) Tensor + LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor + RMSNorm(ctx Context, weight Tensor, eps float32) Tensor + Scale(ctx Context, s float64) Tensor + // SumRows(ctx Context) Tensor + + AvgPool2D(ctx Context, k, s int, p float32) Tensor + Conv2D(ctx Context, weight Tensor, stride0, stride1, padding0, padding1, dilation0, dilation1, groups int) Tensor + Conv3D(ctx Context, weight Tensor, stride0, stride1, stride2, padding0, padding1, padding2, dilation0, dilation1, dilation2, groups int) Tensor + + // IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor + + // Sin(ctx Context) Tensor + // Cos(ctx Context) Tensor + // Tanh(ctx Context) Tensor + GELU(ctx Context, up ...Tensor) Tensor + // QuickGELU(ctx Context, up ...Tensor) Tensor + // SILU(ctx Context, up ...Tensor) Tensor + // RELU(ctx Context, up ...Tensor) Tensor + // Sigmoid(ctx Context) Tensor + + // AlphaLimitSILU is a variant of SILU that clamps the input to the range [-limit, limit] + // SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor + + Reshape(ctx Context, shape ...int) Tensor + AsStrided(ctx Context, shape, strides []int, offset int) Tensor + Transpose(ctx Context, shape ...int) Tensor + Contiguous(ctx Context, allowColMajor bool) Tensor + + // Pad(ctx Context, shape ...int) Tensor + + // Stack(ctx Context, dim int, s ...Tensor) Tensor + + // Repeat repeats the tensor n times along dimension dim + // Repeat(ctx Context, dim, n int) Tensor + // Concat(ctx Context, t2 Tensor, dim int) Tensor + // Rows(ctx Context, t2 Tensor) Tensor + + // TODO these probably aren't actually needed - false starts on trying to wire up cache + // SliceUpdate(ctx Context, update Tensor, start, stop, strides []int) Tensor + // SliceUpdateDynamic(ctx Context, update, start Tensor, axes []int) Tensor + // PutAlongAxis(ctx Context, indicies, values Tensor, axis int) Tensor + + Scatter(ctx Context, indicies []Tensor, updates Tensor, axes []int) Tensor + + Copy(ctx Context, t2 Tensor) Tensor + // Duplicate(ctx Context) Tensor + + // Slice(ctx Context, dim, low, high, step int) Tensor + // Chunk(ctx Context, dim int, size int) []Tensor + // ChunkSections(ctx Context, dim int, sections ...int) []Tensor + + // TopK(ctx Context, k int) Tensor + // Argsort(ctx Context) Tensor + // Mean(ctx Context) Tensor + // Variance(ctx Context) Tensor + // Stddev(ctx Context) Tensor + // Sqr(ctx Context) Tensor + // Sqrt(ctx Context) Tensor + + // Interpolate(ctx Context, dims [4]int, samplingMode SamplingMode) Tensor +} + +// ScaledDotProductAttention implements a fused attention +// operation equivalent to following code on a tensor named +// query: +// +// query = query.Permute(ctx, 0, 2, 1, 3) +// key = key.Permute(ctx, 0, 2, 1, 3) +// value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) +// +// kq := key.MulmatFullPrec(ctx, query) +// +// kq = kq.Scale(ctx, scale) +// +// if mask != nil { +// kq = kq.Add(ctx, mask) +// } +// +// kq = kq.Softmax(ctx) +// +// kqv := value.Mulmat(ctx, kq) +// return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) +// type ScaledDotProductAttention interface { +// ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64) Tensor +// } + +// type number interface { +// ~int | ~int8 | ~int16 | ~int32 | ~int64 | +// ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | +// ~float32 | ~float64 | +// ~complex64 | ~complex128 +// } + +// func mul[T number](s ...T) T { +// p := T(1) +// for _, v := range s { +// p *= v +// } + +// return p +// } + +// type DumpOptions func(*dumpOptions) + +// // DumpWithPrecision sets the number of decimal places to print. Applies to float32 and float64. +// func DumpWithPrecision(n int) DumpOptions { +// return func(opts *dumpOptions) { +// opts.Precision = n +// } +// } + +// // DumpWithThreshold sets the threshold for printing the entire tensor. If the number of elements +// // is less than or equal to this value, the entire tensor will be printed. Otherwise, only the +// // beginning and end of each dimension will be printed. +// func DumpWithThreshold(n int) DumpOptions { +// return func(opts *dumpOptions) { +// opts.Threshold = n +// } +// } + +// // DumpWithEdgeItems sets the number of elements to print at the beginning and end of each dimension. +// func DumpWithEdgeItems(n int) DumpOptions { +// return func(opts *dumpOptions) { +// opts.EdgeItems = n +// } +// } + +// type dumpOptions struct { +// Precision, Threshold, EdgeItems int +// } + +// func Dump(ctx Context, t Tensor, optsFuncs ...DumpOptions) string { +// opts := dumpOptions{Precision: 4, Threshold: 1000, EdgeItems: 3} +// for _, optsFunc := range optsFuncs { +// optsFunc(&opts) +// } + +// if mul(t.Shape()...) <= opts.Threshold { +// opts.EdgeItems = math.MaxInt +// } + +// switch t.DType() { +// case DTypeFloat32: +// return dump[[]float32](ctx, t, opts.EdgeItems, func(f float32) string { +// return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32) +// }) +// case DTypeFloat16: // TODO other types... +// f32 := ctx.Input().Empty(DTypeFloat32, t.Shape()...) +// f32 = t.Copy(ctx, f32) +// return dump[[]float32](ctx, f32, opts.EdgeItems, func(f float32) string { +// return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32) +// }) +// case DTypeInt32: +// return dump[[]int32](ctx, t, opts.EdgeItems, func(i int32) string { +// return strconv.FormatInt(int64(i), 10) +// }) +// default: +// return "" +// } +// } + +// func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string { +// if t.Bytes() == nil { +// ctx.Compute(t) +// } + +// s := make(S, mul(t.Shape()...)) +// if err := binary.Read(bytes.NewBuffer(t.Bytes()), binary.LittleEndian, &s); err != nil { +// panic(err) +// } + +// shape := t.Shape() +// slices.Reverse(shape) + +// var sb strings.Builder +// var f func([]int, int) +// f = func(dims []int, stride int) { +// prefix := strings.Repeat(" ", len(shape)-len(dims)+1) +// sb.WriteString("[") +// defer func() { sb.WriteString("]") }() +// for i := 0; i < dims[0]; i++ { +// if i >= items && i < dims[0]-items { +// sb.WriteString("..., ") +// // skip to next printable element +// skip := dims[0] - 2*items +// if len(dims) > 1 { +// stride += mul(append(dims[1:], skip)...) +// fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix) +// } +// i += skip - 1 +// } else if len(dims) > 1 { +// f(dims[1:], stride) +// stride += mul(dims[1:]...) +// if i < dims[0]-1 { +// fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix) +// } +// } else { +// text := fn(s[stride+i]) +// if len(text) > 0 && text[0] != '-' { +// sb.WriteString(" ") +// } + +// sb.WriteString(text) +// if i < dims[0]-1 { +// sb.WriteString(", ") +// } +// } +// } +// } +// f(shape, 0) + +// return sb.String() +// } + +type DType int + +const ( + DTypeBool DType = iota + DTypeUint8 + DTypeUint16 + DTypeUint32 + DTypeUint64 + DTypeInt8 + DTypeInt16 + DTypeInt32 + DTypeInt64 + DTypeFloat16 + DTypeFloat32 + DTypeFloat64 + DTypeBfloat16 + DTypeComplex64 +) + +type SamplingMode int + +const ( + SamplingModeNearest SamplingMode = iota + SamplingModeBilinear +) diff --git a/x/ml/backend/backend.go b/x/ml/backend/backend.go new file mode 100644 index 000000000..b9dd4a13b --- /dev/null +++ b/x/ml/backend/backend.go @@ -0,0 +1,3 @@ +package backend + +// _ "github.com/ollama/ollama/x/ml/backend/mlx" diff --git a/x/ml/backend/mlx/CMakeLists.txt b/x/ml/backend/mlx/CMakeLists.txt new file mode 100644 index 000000000..e71a6567a --- /dev/null +++ b/x/ml/backend/mlx/CMakeLists.txt @@ -0,0 +1,57 @@ +include(FetchContent) + +set(MLX_C_BUILD_EXAMPLES OFF) + +set(MLX_BUILD_GGUF OFF) +set(MLX_BUILD_SAFETENSORS ON) + +function(set_target_output_directory _target) + if(TARGET ${_target}) + set_target_properties(${_target} PROPERTIES + RUNTIME_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR} + LIBRARY_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR} + ARCHIVE_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR} + ) + endif() +endfunction() + +# Check for Metal support (macOS only) +if(CMAKE_SYSTEM_NAME MATCHES "Darwin") + execute_process( + COMMAND + zsh "-c" + "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'" + OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY) + + if(NOT MLX_METAL_VERSION) + message(STATUS "`xcrun metal` error. Setting MLX_BUILD_METAL=OFF") + set(MLX_BUILD_METAL OFF) + endif() +else() + # On Linux, disable Metal backend + message(STATUS "Non-macOS platform detected. Setting MLX_BUILD_METAL=OFF") + set(MLX_BUILD_METAL OFF) +endif() + +# Map CMAKE_CUDA_ARCHITECTURES to MLX_CUDA_ARCHITECTURES if not explicitly set +if(NOT MLX_CUDA_ARCHITECTURES AND CMAKE_CUDA_ARCHITECTURES) + set(MLX_CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES}) + message(STATUS "Using CMAKE_CUDA_ARCHITECTURES for MLX: ${MLX_CUDA_ARCHITECTURES}") +endif() + +# Enable CUDA backend if CUDA architectures are specified and CUDA compiler is available +if(MLX_CUDA_ARCHITECTURES AND CMAKE_CUDA_COMPILER) + set(MLX_BUILD_CUDA ON CACHE BOOL "Build CUDA backend for MLX" FORCE) + message(STATUS "Enabling MLX CUDA backend with architectures: ${MLX_CUDA_ARCHITECTURES}") +elseif(MLX_CUDA_ARCHITECTURES) + message(WARNING "MLX_CUDA_ARCHITECTURES specified but CUDA compiler not found, CUDA backend will be disabled") +endif() + +FetchContent_Declare( + mlx-c + GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git" + GIT_TAG v0.4.1) +FetchContent_MakeAvailable(mlx-c) + +set_target_output_directory(mlx) +set_target_output_directory(mlxc) diff --git a/x/ml/backend/mlx/mlx.go b/x/ml/backend/mlx/mlx.go new file mode 100644 index 000000000..1b647685e --- /dev/null +++ b/x/ml/backend/mlx/mlx.go @@ -0,0 +1,1278 @@ +//go:build mlx + +package mlx + +/* +#cgo CPPFLAGS: -I${SRCDIR}/../../../../build/_deps/mlx-c-src +#cgo LDFLAGS: -L${SRCDIR}/../../../../build/lib/ollama/ -lmlxc -lmlx +#cgo LDFLAGS: -framework Accelerate +#cgo LDFLAGS: -Wl,-rpath,${SRCDIR}/../../../../build/lib/ollama/ +#include +#include "mlx/c/mlx.h" +static inline size_t stride(const mlx_array a, int i) {return mlx_array_strides(a)[i];} + +extern void goStackTrace(); +static void error_handler(const char *msg, void* data) { + fprintf(stderr, "MLX error: %s\n", msg); + goStackTrace(); + exit(-1); // TODO adjust so this can become a return code on the current thread instead of exit +} +static void set_error_handler() {mlx_set_error_handler(&error_handler, NULL, NULL);} +static void* mlx_array_data_float16_asvoid(const mlx_array a) {return (void*)mlx_array_data_float16(a);} +typedef const char cchar_t; +*/ +import "C" + +import ( + "encoding/json" + "fmt" + "log/slog" + "math" + "os" + "path/filepath" + "reflect" + "runtime" + "runtime/debug" + "sync" + "unsafe" + + "github.com/ollama/ollama/convert" + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/x/ml" + "github.com/x448/float16" +) + +func init() { + ml.RegisterBackend("mlx", New) + C.set_error_handler() +} + +//export goStackTrace +func goStackTrace() { + debug.PrintStack() +} + +type SafetensorsIndexMetadata struct { + TotalSize uint64 `json:"total_size"` +} +type SafetensorsIndex struct { + Metadata SafetensorsIndexMetadata `json:"metadata"` + WeightMap map[string]string `json:"weight_map"` +} + +type Backend struct { + meta fs.Config + tensors map[string]*Array +} + +func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { + // TODO assumes modelPath is actually a directory for now... + kv, tokenizer, err := convert.LoadModelMetadata(os.DirFS(modelPath)) + if err != nil { + return nil, fmt.Errorf("unable to load model: %w", err) + } + + b := &Backend{ + meta: kv.KV(tokenizer), + } + + err = b.LoadSafeTensors(modelPath) + if err != nil { + return nil, fmt.Errorf("safetensors load failed: %w", err) + } + return b, nil +} + +func (b *Backend) LoadSafeTensors(dir string) error { + if _, err := os.Stat(dir); err != nil { + return fmt.Errorf("failed to stat dir: %w", err) + } + // other variations to try? + stFilename := filepath.Join(dir, "model.safetensors.index.json") + if _, err := os.Stat(stFilename); err != nil { + return fmt.Errorf("failed to stat %s: %w", stFilename, err) + } + + fp, err := os.Open(stFilename) + if err != nil { + return fmt.Errorf("failed to open safetensor index: %s: %w", stFilename, err) + } + decoder := json.NewDecoder(fp) + var index SafetensorsIndex + if err := decoder.Decode(&index); err != nil { + return fmt.Errorf("decode error: %s: %w", stFilename, err) + } + slog.Info("XXX parsed metadata", "size", index.Metadata.TotalSize, "weights", len(index.WeightMap)) + filenames := map[string]struct{}{} + for _, filename := range index.WeightMap { + filenames[filename] = struct{}{} + } + stream := C.mlx_default_cpu_stream_new() + + b.tensors = map[string]*Array{} + + for filename := range filenames { + filepath := filepath.Join(dir, filename) + if _, err := os.Stat(filepath); err != nil { + return fmt.Errorf("failed to stat %s: %w", filepath, err) + } + slog.Info("Loading tensors from", "filename", filename) + cFilename := C.CString(filepath) + defer C.free(unsafe.Pointer(cFilename)) + data := C.mlx_map_string_to_array_new() // TODO is this needed or just var it? + metadata := C.mlx_map_string_to_string_new() + defer C.mlx_map_string_to_array_free(data) + defer C.mlx_map_string_to_string_free(metadata) + + if C.mlx_load_safetensors(&data, &metadata, cFilename, stream) != 0 { + // TODO with the current error handling, this will never happen + return fmt.Errorf("load failed") + } + + it := C.mlx_map_string_to_array_iterator_new(data) + // defer C.mlx_array_free(shaped) + // TODO confusing how memory management works with this... + for { + var key *C.cchar_t + var value C.mlx_array + if C.mlx_map_string_to_array_iterator_next(&key, &value, it) != 0 { + break + } + k := C.GoString((*C.char)(key)) + b.tensors[k] = &Array{ + name: k, + a: value, + } + // slog.Info("XXX read", "tensor", b.tensors[k], "type", b.tensors[k].TypeString()) + } + } + + return nil +} + +func (b *Backend) Get(name string) ml.Tensor { + var t ml.Tensor + var ok bool + if t, ok = b.tensors[name]; !ok { + // slog.Warn("unable to locate", "tensor", name) + return nil + } + // slog.Info("Fetching", "tensor", name, "type", b.tensors[name].TypeString()) + return t +} + +func (b *Backend) NewContext() ml.Context { + // slog.Info("MLX.NewContext") + return &Context{ + stream: C.mlx_default_gpu_stream_new(), + } +} + +func (b *Backend) Config() fs.Config { + return b.meta +} + +type Context struct { + stream C.mlx_stream + + mu sync.Mutex + arrays []C.mlx_array // TODO should we do some bookkeeping to ensure none of these Arrays are still lingering? +} + +func (c *Context) Close() { + // C.mlx_synchronize(c.stream) // ??? + C.mlx_stream_free(c.stream) + + c.mu.Lock() + defer c.mu.Unlock() + for _, a := range c.arrays { + slog.Info("XXX freeing", "array", a) + C.mlx_array_free(a) + } +} + +func (c *Context) Compute(tensors ...ml.Tensor) { + // TODO - for the zero tensor case this feels like it might not be correct... + needSync := true + sync := func() { + if needSync { + C.mlx_synchronize(c.stream) + needSync = false + } + } + + vec := C.mlx_vector_array_new() + defer C.mlx_vector_array_free(vec) + for _, t := range tensors { + C.mlx_vector_array_append_value(vec, t.(*Array).a) + t.(*Array).sync = sync + } + C.mlx_async_eval(vec) +} + +func (c *Context) Forward(tensors ...ml.Tensor) ml.Context { + vec := C.mlx_vector_array_new() + defer C.mlx_vector_array_free(vec) + needSync := true + sync := func() { + if needSync { + C.mlx_synchronize(c.stream) + needSync = false + } + } + + for _, t := range tensors { + t.(*Array).sync = sync + C.mlx_vector_array_append_value(vec, t.(*Array).a) + } + C.mlx_async_eval(vec) + return c +} + +func (c *Context) Input() ml.Context { + return c +} + +// func (c *Context) Output() ml.Context { +// return c +// } + +func (c *Context) Layer(_ int) ml.Context { + return c +} + +func (c *Context) RandomNormal(shape []int, dtype ml.DType, loc, scale float32, key ml.Tensor) ml.Tensor { + var r C.mlx_array + var k C.mlx_array + if key != nil { + k = key.(*Array).a + } + sh := make([]C.int, len(shape)) + for i := range shape { + sh[i] = C.int(shape[i]) + } + C.mlx_random_normal( + &r, + &sh[0], + C.size_t(len(shape)), + C.mlx_dtype(dtype), + C.float(loc), + C.float(scale), + k, + c.stream, + ) + return newArray(c, r) +} + +func (c *Context) CompareWith(filepath string, tensors map[string]ml.Tensor, abortOnError bool) (err error) { + minCosine := float32(0.96) // TODO too low... + fileTensors := map[string]*Array{} + defer func() { + if err != nil { + for k, v := range tensors { + fmt.Fprintln(os.Stderr, "input tensor "+k+"\n"+v.ToString()) + if fv, ok := fileTensors[k]; ok { + fmt.Fprintln(os.Stderr, " file tensor "+k+"\n"+fv.ToString()) + } else { + fmt.Fprintln(os.Stderr, " file tensor "+k+" missing!\n") + } + } + } + if abortOnError { + if err != nil { + panic(fmt.Sprintf("%s", err)) + } + } + }() + if _, err = os.Stat(filepath); err != nil { + filepath += ".safetensors" + if _, err = os.Stat(filepath); err != nil { + err = fmt.Errorf("failed to stat %s: %w", filepath, err) + return + } + err = nil + } + // slog.Info("Loading tensors from", "filename", filepath) + cFilename := C.CString(filepath) + defer C.free(unsafe.Pointer(cFilename)) + data := C.mlx_map_string_to_array_new() // TODO is this needed or just var it? + metadata := C.mlx_map_string_to_string_new() + defer C.mlx_map_string_to_array_free(data) + defer C.mlx_map_string_to_string_free(metadata) + + stream := C.mlx_default_cpu_stream_new() + + if C.mlx_load_safetensors(&data, &metadata, cFilename, stream) != 0 { + // TODO with the current error handling, this will never happen + err = fmt.Errorf("load failed") + return + } + + it := C.mlx_map_string_to_array_iterator_new(data) + allTensors := []ml.Tensor{} + for _, t := range tensors { + allTensors = append(allTensors, t) + } + + for { + var key *C.cchar_t + var value C.mlx_array + defer C.mlx_array_free(value) + if C.mlx_map_string_to_array_iterator_next(&key, &value, it) != 0 { + break + } + k := C.GoString((*C.char)(key)) + var r C.mlx_array + defer C.mlx_array_free(r) + C.mlx_astype( + &r, + value, + C.MLX_FLOAT32, + stream, + ) + + fileTensors[k] = &Array{ + name: k, + a: r, + } + // slog.Info("XXX read", "tensor", t, "type", t.TypeString()) + allTensors = append(allTensors, fileTensors[k]) + } + c.Forward(allTensors...) + for k, t := range tensors { + a, ok := fileTensors[k] + if !ok { + err = fmt.Errorf("tensor named %s not found in file", k) + return + } + if !reflect.DeepEqual(a.Shape(), t.Shape()) { + err = fmt.Errorf("mismatched shapes: file: %v vs. input %v", a.Shape(), t.Shape()) + return + } + // slog.Info("XXX shapes match", "shape", t.Shape()) + // TODO handle int types... + tDType := t.DType() + if tDType != ml.DTypeFloat16 && tDType != ml.DTypeFloat32 { + var r C.mlx_array + defer C.mlx_array_free(r) + C.mlx_astype( + &r, + t.(*Array).a, + C.MLX_FLOAT32, + stream, + ) + t = &Array{ + a: r, + } + c.Forward(t) + } + + af := a.Floats() + tf := t.Floats() + cos := cosineSimilarity(af, tf) + diff := a.Sub(c, t) + min := diff.Min(c, nil, true) + max := diff.Max(c, nil, true) + c.Forward(min, max) + minf := min.Floats() + maxf := max.Floats() + if cos < minCosine { + err = fmt.Errorf("%s shapes match, but not similar enough: %v min_difference=%v max_difference=%v", k, cos, minf, maxf) + return + } + + slog.Info("XXX tensors are similar", k, cos, "shape", t.Shape(), "min_difference", minf, "max_difference", maxf) + } + err = nil + + return +} + +func dotProduct[V float32 | float64](v1, v2 []V) V { + var result V = 0 + if len(v1) != len(v2) { + return result + } + + for i := 0; i < len(v1); i++ { + result += v1[i] * v2[i] + } + return result +} + +func magnitude[V float32 | float64](v []V) V { + var result V = 0 + for _, val := range v { + result += val * val + } + return V(math.Sqrt(float64(result))) +} + +func cosineSimilarity[V float32 | float64](v1, v2 []V) V { + mag1 := magnitude(v1) + mag2 := magnitude(v2) + + if mag1 == 0 || mag2 == 0 { + return 0 + } + + return dotProduct(v1, v2) / (magnitude(v1) * magnitude(v2)) +} + +func euclideanDistance[V float32 | float64](v1, v2 []V) V { + if len(v1) != len(v2) { + return V(math.Inf(1)) + } + + var sum V = 0 + for i := 0; i < len(v1); i++ { + diff := v1[i] - v2[i] + sum += diff * diff + } + + return V(math.Sqrt(float64(sum))) +} + +func manhattanDistance[V float32 | float64](v1, v2 []V) V { + if len(v1) != len(v2) { + return V(math.Inf(1)) + } + + var sum V = 0 + for i := 0; i < len(v1); i++ { + sum += V(math.Abs(float64(v1[i] - v2[i]))) + } + + return sum +} + +type Array struct { + name string + a C.mlx_array + c *Context + + sync func() +} + +func newArray(ctx *Context, a C.mlx_array) *Array { + // TODO measure impact and if this slows things down, make it conditional on some debugging flag at load time + var name string + _, f, l, ok := runtime.Caller(2) + if ok { + name = fmt.Sprintf("%s:%d", f, l) + } + + t := &Array{ + name: name, + a: a, + c: ctx, + } + // DEBUG memory allocation problems... + // slog.Info("XXX Allocated", "array", t, "a", a) + ctx.mu.Lock() + defer ctx.mu.Unlock() + ctx.arrays = append(ctx.arrays, a) + return t +} + +// FromFloats implements ml.Context. +func (c *Context) FromFloats(s []float32, shape ...int) ml.Tensor { + u16s := make([]float16.Float16, len(s)) + for i := range u16s { + u16s[i] = float16.Fromfloat32(s[i]) + } + cshape := make([]C.int, len(shape)) + for i, dim := range shape { + cshape[i] = C.int(dim) + } + return newArray(c, + C.mlx_array_new_data( + unsafe.Pointer(&u16s[0]), + &cshape[0], + C.int(len(cshape)), + C.MLX_FLOAT16, + ), + ) +} + +func (a *Array) Floats() []float32 { + if a.sync != nil { + a.sync() + } + l := (int)(C.mlx_array_size(a.a)) + + switch C.mlx_array_dtype(a.a) { + case C.MLX_BFLOAT16: + panic("bfloat16 not yet implemented") + case C.MLX_FLOAT16: + data := C.mlx_array_data_float16_asvoid(a.a) + if data == nil { + panic("nil data, wasn't eval'd") + } + u16s := unsafe.Slice((*uint16)(data), l) + f32s := make([]float32, len(u16s)) + for i := range u16s { + f32s[i] = float16.Frombits(u16s[i]).Float32() + } + return f32s + case C.MLX_FLOAT32: + data := C.mlx_array_data_float32(a.a) + if data == nil { + panic("nil data, wasn't eval'd") + } + f32s := unsafe.Slice((*float32)(data), l) + return f32s + default: + panic(fmt.Sprintf("unsupported dtype for Floats: %d", C.mlx_array_dtype(a.a))) + } +} + +// FromInts implements ml.Context. +func (c *Context) FromInts(s []int32, shape ...int) ml.Tensor { + cshape := make([]C.int, len(shape)) + for i, dim := range shape { + cshape[i] = C.int(dim) + } + return newArray(c, + C.mlx_array_new_data( + unsafe.Pointer(&s[0]), + &cshape[0], + C.int(len(cshape)), + C.MLX_INT32, + ), + ) +} + +func (a *Array) Ints() []int32 { + if a.sync != nil { + a.sync() + } + l := (int)(C.mlx_array_size(a.a)) + + switch C.mlx_array_dtype(a.a) { + case C.MLX_INT32: + data := C.mlx_array_data_int32(a.a) + if data == nil { + panic("nil data, wasn't eval'd") + } + i32s := unsafe.Slice((*int32)(data), l) + return i32s + + // TODO other types via conversion? + default: + panic(fmt.Sprintf("unsupported dtype for Ints: %d", C.mlx_array_dtype(a.a))) + } +} + +func (c *Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor { + sh := make([]C.int, len(shape)) + for i, s := range shape { + sh[i] = (C.int)(s) + } + + var r C.mlx_array + C.mlx_zeros( + &r, + &sh[0], + (C.size_t)(len(sh)), + C.mlx_dtype(dtype), + c.stream, + ) + return newArray(c, r) +} + +func (c *Context) Empty(dtype ml.DType, shape ...int) ml.Tensor { + // TODO more efficient impl? + return c.Zeros(dtype, shape...) +} + +func (a *Array) DType() ml.DType { + return (ml.DType)(C.mlx_array_dtype(a.a)) +} + +func (a *Array) Dim(n int) int { + return int(C.mlx_array_dim(a.a, C.int(n))) +} + +func (a *Array) Stride(n int) int { + return (int)(C.stride(a.a, (C.int)(n))) +} + +func (c *Context) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor { + var r C.mlx_array + C.mlx_arange( + &r, + C.double(start), + C.double(stop), + C.double(step), + (C.mlx_dtype)(dtype), + c.stream, + ) + + return newArray(c, r) +} + +// Scale implements ml.Tensor. +func (a *Array) Scale(ctx ml.Context, s float64) ml.Tensor { + scale := C.mlx_array_new_float(C.float(s)) + var r C.mlx_array + C.mlx_multiply( + &r, + a.a, + scale, + ctx.(*Context).stream, + ) + return newArray(ctx.(*Context), r) +} + +func (a *Array) Softmax(ctx ml.Context) ml.Tensor { + var r C.mlx_array + C.mlx_softmax( + &r, + a.a, + false, // TODO - precise? + ctx.(*Context).stream, + ) + return newArray(ctx.(*Context), r) +} + +func (a *Array) SliceUpdate(ctx ml.Context, update ml.Tensor, start, stop, strides []int) ml.Tensor { + cStart := make([]C.int, len(start)) + for i := range start { + cStart[i] = C.int(start[i]) + } + cStop := make([]C.int, len(stop)) + for i := range stop { + cStop[i] = C.int(stop[i]) + } + cStrides := make([]C.int, len(strides)) + for i := range strides { + cStrides[i] = C.int(strides[i]) + } + var r C.mlx_array + C.mlx_slice_update( + &r, + a.a, + update.(*Array).a, + (*C.int)(unsafe.Pointer(&cStart[0])), + C.size_t(len(cStart)), + (*C.int)(unsafe.Pointer(&cStop[0])), + C.size_t(len(cStop)), + (*C.int)(unsafe.Pointer(&cStrides[0])), + C.size_t(len(cStrides)), + ctx.(*Context).stream, + ) + // Release the old array and replace with the new one to ensure the same underlying buffer is used + a.c.mu.Lock() + defer a.c.mu.Unlock() + for i := range a.c.arrays { + if a.c.arrays[i] == a.a { + C.mlx_array_free(a.a) + a.a = r + a.c.arrays = append(a.c.arrays[:i], a.c.arrays[i+1:]...) + return a + } + } + panic("unable to locate array in context") +} + +func (a *Array) SliceUpdateDynamic(ctx ml.Context, update, start ml.Tensor, axes []int) ml.Tensor { + cAxes := make([]C.int, len(axes)) + for i := range axes { + cAxes[i] = C.int(axes[i]) + } + + var r C.mlx_array + C.mlx_slice_update_dynamic( + &r, + a.a, + update.(*Array).a, + start.(*Array).a, + (*C.int)(unsafe.Pointer(&cAxes[0])), + C.size_t(len(cAxes)), + ctx.(*Context).stream, + ) + // Release the old array and replace with the new one to ensure the same underlying buffer is used + a.c.mu.Lock() + defer a.c.mu.Unlock() + for i := range a.c.arrays { + if a.c.arrays[i] == a.a { + C.mlx_array_free(a.a) + a.a = r + a.c.arrays = append(a.c.arrays[:i], a.c.arrays[i+1:]...) + return a + } + } + panic("unable to locate array in context") + +} + +func (a *Array) PutAlongAxis(ctx ml.Context, indicies, values ml.Tensor, axis int) ml.Tensor { + var r C.mlx_array + C.mlx_put_along_axis( + &r, + a.a, + indicies.(*Array).a, + values.(*Array).a, + C.int(axis), + ctx.(*Context).stream, + ) + // Release the old array and replace with the new one to ensure the same underlying buffer is used + a.c.mu.Lock() + defer a.c.mu.Unlock() + for i := range a.c.arrays { + if a.c.arrays[i] == a.a { + C.mlx_array_free(a.a) + a.a = r + a.c.arrays = append(a.c.arrays[:i], a.c.arrays[i+1:]...) + return a + } + } + panic("unable to locate array in context") +} + +func (a *Array) Scatter(ctx ml.Context, indicies []ml.Tensor, updates ml.Tensor, axes []int) ml.Tensor { + + cAxes := make([]C.int, len(axes)) + for i := range axes { + cAxes[i] = C.int(axes[i]) + } + var cAxes0 *C.int + if len(cAxes) > 0 { + cAxes0 = (*C.int)(unsafe.Pointer(&cAxes[0])) + } + indiciesVec := C.mlx_vector_array_new() + defer C.mlx_vector_array_free(indiciesVec) + for _, ind := range indicies { + C.mlx_vector_array_append_value(indiciesVec, ind.(*Array).a) + } + + var r C.mlx_array + C.mlx_scatter( + &r, + a.a, + indiciesVec, + updates.(*Array).a, + cAxes0, + C.size_t(len(cAxes)), + ctx.(*Context).stream, + ) + // Release the old array and replace with the new one to ensure the same underlying buffer is used + a.c.mu.Lock() + defer a.c.mu.Unlock() + for i := range a.c.arrays { + if a.c.arrays[i] == a.a { + C.mlx_array_free(a.a) + a.a = r + a.c.arrays[i] = r + return a + } + } + panic("unable to locate array in context") + +} + +func (a *Array) Copy(ctx ml.Context, a2 ml.Tensor) ml.Tensor { + C.mlx_copy( + &a2.(*Array).a, + a.a, + ctx.(*Context).stream, + ) + // TODO - view? + return newArray(ctx.(*Context), a2.(*Array).a) +} + +func (a *Array) Add(ctx ml.Context, a2 ml.Tensor) ml.Tensor { + var r C.mlx_array + C.mlx_add( + &r, + a.a, + a2.(*Array).a, + ctx.(*Context).stream, + ) + return newArray(ctx.(*Context), r) +} + +func (a *Array) Sub(ctx ml.Context, a2 ml.Tensor) ml.Tensor { + var r C.mlx_array + C.mlx_subtract( + &r, + a.a, + a2.(*Array).a, + ctx.(*Context).stream, + ) + return newArray(ctx.(*Context), r) +} + +func (a *Array) Max(ctx ml.Context, axes []int, keepDims bool) ml.Tensor { + var r C.mlx_array + cAxes := make([]C.int, len(axes)) + for i := range axes { + cAxes[i] = C.int(axes[i]) + } + var cAxes0 *C.int + if len(cAxes) > 0 { + cAxes0 = (*C.int)(unsafe.Pointer(&cAxes[0])) + C.mlx_max_axes( + &r, + a.a, + cAxes0, + C.size_t(len(cAxes)), + C._Bool(keepDims), + ctx.(*Context).stream, + ) + } else { + C.mlx_max( + &r, + a.a, + C._Bool(keepDims), + ctx.(*Context).stream, + ) + + } + + return newArray(ctx.(*Context), r) +} + +func (a *Array) Min(ctx ml.Context, axes []int, keepDims bool) ml.Tensor { + var r C.mlx_array + cAxes := make([]C.int, len(axes)) + for i := range axes { + cAxes[i] = C.int(axes[i]) + } + var cAxes0 *C.int + if len(cAxes) > 0 { + cAxes0 = (*C.int)(unsafe.Pointer(&cAxes[0])) + C.mlx_min_axes( + &r, + a.a, + cAxes0, + C.size_t(len(cAxes)), + C._Bool(keepDims), + ctx.(*Context).stream, + ) + } else { + C.mlx_min( + &r, + a.a, + C._Bool(keepDims), + ctx.(*Context).stream, + ) + } + + return newArray(ctx.(*Context), r) +} + +func (a *Array) Matmul(ctx ml.Context, a2 ml.Tensor) ml.Tensor { + var r C.mlx_array + C.mlx_matmul( + &r, + a.a, + a2.(*Array).a, + ctx.(*Context).stream, + ) + return newArray(ctx.(*Context), r) +} + +func (a *Array) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor { + // slog.Info("MLX.RMSNorm", "a", a, "w", w) + var r C.mlx_array + C.mlx_fast_rms_norm( + &r, + a.a, + w.(*Array).a, + C.float(eps), + ctx.(*Context).stream, + ) + return newArray(ctx.(*Context), r) +} + +func (a *Array) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor { + var r C.mlx_array + C.mlx_fast_layer_norm( + &r, + a.a, + w.(*Array).a, + b.(*Array).a, + C.float(eps), + ctx.(*Context).stream, + ) + return newArray(ctx.(*Context), r) +} + +func (a *Array) L2Norm(ctx ml.Context, eps float32) ml.Tensor { + // TODO implement + panic("NOT YET IMPLEMENTED") +} + +func (t Array) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor { + panic("NOT YET IMPLEMENTED") +} + +// RoPE implements Rotary Positional Encoding +// +// dims (int) – The feature dimensions to be rotated. If the input feature is larger than dims then the rest is left unchanged. +// traditional (bool) – If set to True choose the traditional implementation which rotates consecutive dimensions. +// scale (float) – The scale used to scale the positions. +// offset (int) – The position offset to start at. TODO MLX-C does not yet expose Offset as an Array +// WithBase (float, optional) – The base used to compute angular frequency for each dimension in the positional encodings. Exactly one of base and freqs must be None. +// WithFreqs (array, optional) – Optional frequencies to use with RoPE. If set, the base parameter must be None. Default: None. +func (a *Array) RoPE(ctx ml.Context, dims int, traditional bool, scale float32, offset int, options ...func(*ml.RoPEOptions)) ml.Tensor { + opts := ml.RoPEOptions{} + + // Apply any provided options + for _, option := range options { + option(&opts) + } + var r C.mlx_array + var base C.mlx_optional_float + var freqs C.mlx_array + + if opts.Base != nil { + base.value = C.float(*opts.Base) + base.has_value = true + } + if opts.Freqs != nil { + freqs = opts.Freqs.(*Array).a + } + C.mlx_fast_rope( + &r, + a.a, + C.int(dims), + C._Bool(traditional), + base, + C.float(scale), + C.int(offset), + freqs, + ctx.(*Context).stream, + ) + return newArray(ctx.(*Context), r) +} + +// A fast implementation of multi-head attention: O = softmax(Q @ K.T, dim=-1) @ V. +// +// Supports: +// - Multi-Head Attention +// - Grouped Query Attention +// - Multi-Query Attention +// +// Note: +// - The softmax operation is performed in float32 regardless of the input precision. +// - For Grouped Query Attention and Multi-Query Attention, the k and v inputs should not be pre-tiled to match q. +// +// In the following the dimensions are given by: +// - B: The batch size. +// - N_q: The number of query heads. +// - N_kv: The number of key and value heads. +// - T_q: The number of queries per example. +// - T_kv: The number of keys and values per example. +// - D: The per-head dimension. +// +// Parameters: +// - [subject array] queries (array) – Queries with shape [B, N_q, T_q, D]. +// - keys (array) – with shape [B, N_kv, T_kv, D]. +// - values (array) – with shape [B, N_kv, T_kv, D]. +// - scale (float) – Scale for queries (typically 1.0 / sqrt(q.shape(-1)). +// - mask (str or array, optional) – The mask to apply to the query-key scores. +// The mask can be an array or a string indicating the mask type. The only supported string type is "causal". +// If the mask is an array it can be a boolean or additive mask. The mask can have at most 4 dimensions and +// must be broadcast-compatible with the shape [B, N, T_q, T_kv]. If an additive mask is given its type must +// promote to the promoted type of q, k, and v. +// - sinks (array, optional) – An optional array of attention sinks. Default: None. + +func (queries *Array) ScaledDotProductAttention(ctx ml.Context, keys, values ml.Tensor, scale float64, maskMode string, mask ml.Tensor, sinks ml.Tensor) ml.Tensor { + var r C.mlx_array + var s C.mlx_array + if sinks != nil { + s = sinks.(*Array).a + } + maskModeC := C.CString(maskMode) + defer C.free(unsafe.Pointer(maskModeC)) + var maskArr C.mlx_array + if mask != nil { + maskArr = mask.(*Array).a + } + + C.mlx_fast_scaled_dot_product_attention( + &r, + queries.a, + keys.(*Array).a, + values.(*Array).a, + C.float(scale), + maskModeC, + maskArr, + s, + ctx.(*Context).stream, + ) + return newArray(ctx.(*Context), r) +} + +func (a *Array) TakeAxes(ctx ml.Context, indicies ml.Tensor, axes int) ml.Tensor { + var r C.mlx_array + + C.mlx_take_axis(&r, a.a, indicies.(*Array).a, C.int(axes), ctx.(*Context).stream) + return newArray(ctx.(*Context), r) + +} + +// TODO not sure if we'll want this variation taking raw ints instead of a tensor... +// func (a *Array) TakeAxes(ctx ml.Context, axes int, indicies ...int) ml.Tensor { +// var i C.mlx_array +// var r C.mlx_array + +// if indicies != nil { +// shape := []C.int{C.int(len(indicies))} +// cindicies := make([]int32, len(indicies)) +// for i, v := range indicies { +// cindicies[i] = int32(v) +// } +// i = C.mlx_array_new_data( +// unsafe.Pointer(&cindicies[0]), +// &shape[0], +// C.int(len(shape)), +// C.MLX_INT32, +// ) +// } +// C.mlx_take_axis(&r, a.a, i, C.int(axes), ctx.(*Context).stream) +// return newArray(ctx.(*Context), r) + +// } + +func (a *Array) GELU(ctx ml.Context, up ...ml.Tensor) ml.Tensor { + // TODO precise vs fast, and compile + // x * mx.sigmoid(1.702 * x) + u16s := []float16.Float16{float16.Fromfloat32(1.702)} + cshape := []C.int{1} + f := C.mlx_array_new_data(unsafe.Pointer(&u16s[0]), &cshape[0], 1, C.MLX_FLOAT16) + defer C.mlx_array_free(f) + var r1, r2, r3 C.mlx_array + C.mlx_multiply(&r1, a.a, f, ctx.(*Context).stream) + defer C.mlx_array_free(r1) + C.mlx_sigmoid(&r2, r1, ctx.(*Context).stream) + defer C.mlx_array_free(r2) + C.mlx_multiply(&r3, a.a, r2, ctx.(*Context).stream) + + if len(up) > 0 { + var r4 C.mlx_array + defer C.mlx_array_free(r3) + C.mlx_multiply(&r4, r3, up[0].(*Array).a, ctx.(*Context).stream) + return newArray(ctx.(*Context), r4) + } + + return newArray(ctx.(*Context), r3) +} + +// Create a view into the array with the given shape and strides. +// +// The resulting array will always be as if the provided array was row +// contiguous regardless of the provided arrays storage order and current +// strides. +// +// Note that this function should be used with caution as it changes the shape +// and strides of the array directly. This can lead to the resulting array +// pointing to invalid memory locations which can result into crashes. +// +// Parameters: +// - shape (list(int), optional) – The shape of the resulting array. If None it defaults to a.shape(). +// - strides (list(int), optional) – The strides of the resulting array. If None it defaults to the +// reverse exclusive cumulative product of a.shape(). +// - offset (int) – Skip that many elements from the beginning of the input array. +func (a *Array) AsStrided(ctx ml.Context, shape, strides []int, offset int) ml.Tensor { + var r C.mlx_array + sh := make([]C.int, len(shape)) + st := make([]C.int64_t, len(strides)) + var sh0 *C.int + var st0 *C.int64_t + for i, s := range shape { + sh[i] = C.int(s) + } + for i, s := range strides { + st[i] = C.int64_t(s) + } + if len(sh) > 0 { + sh0 = (*C.int)(unsafe.Pointer(&sh[0])) + } + if len(st) > 0 { + st0 = (*C.int64_t)(unsafe.Pointer(&st[0])) + } + + C.mlx_as_strided( + &r, + a.a, + sh0, + C.size_t(len(sh)), + st0, + C.size_t(len(st)), + C.size_t(offset), + ctx.(*Context).stream, + ) + return newArray(ctx.(*Context), r) + +} + +func (a *Array) Reshape(ctx ml.Context, shape ...int) ml.Tensor { + cshape := make([]C.int, len(shape)) + for i, dim := range shape { + cshape[i] = C.int(dim) + } + var r C.mlx_array + C.mlx_reshape(&r, a.a, &cshape[0], C.size_t(len(cshape)), ctx.(*Context).stream) + return newArray(ctx.(*Context), r) +} + +func (a *Array) Transpose(ctx ml.Context, shape ...int) ml.Tensor { + ndim := min(C.mlx_array_ndim(a.a), C.size_t(len(shape))) + var r C.mlx_array + sh := make([]C.int, ndim) + for i := range ndim { + sh[i] = (C.int)(shape[i]) + if int(sh[i]) >= int(ndim) { + slog.Error("Permute error", "tensor", a, "shape", shape) + panic("invalid pemute call") + } + } + if len(sh) > 0 { + C.mlx_transpose_axes( + &r, + a.a, + &sh[0], + ndim, + ctx.(*Context).stream, + ) + } else { + C.mlx_transpose( + &r, + a.a, + ctx.(*Context).stream, + ) + } + return newArray(ctx.(*Context), r) +} + +func (a *Array) Contiguous(ctx ml.Context, allowColMajor bool) ml.Tensor { + var r C.mlx_array + C.mlx_contiguous( + &r, + a.a, + (C._Bool)(allowColMajor), + ctx.(*Context).stream, + ) + return newArray(ctx.(*Context), r) +} + +// Conv2D implements ml.Tensor. +// GGML API +// input: [N, IC, IH, IW] +// weight: [OC,IC, KH, KW] +// result: [N, OC, OH, OW] +// +// MLX: +// input: (N, KH, KW, C_in) +// weight: (C_out, IH, IW, C_in) +// result: XXX + +func (input *Array) Conv2D(ctx ml.Context, weight ml.Tensor, stride0, stride1, padding0, padding1, dilation0, dilation1, groups int) ml.Tensor { + var r C.mlx_array + C.mlx_conv2d( + &r, + input.a, + weight.(*Array).a, + C.int(stride0), + C.int(stride1), + C.int(padding0), + C.int(padding1), + C.int(dilation0), + C.int(dilation1), + C.int(groups), + ctx.(*Context).stream, + ) + return newArray(ctx.(*Context), r) +} + +func (input *Array) Conv3D(ctx ml.Context, weight ml.Tensor, stride0, stride1, stride2, padding0, padding1, padding2, dilation0, dilation1, dilation2, groups int) ml.Tensor { + var r C.mlx_array + C.mlx_conv3d( + &r, + input.a, + weight.(*Array).a, + C.int(stride0), + C.int(stride1), + C.int(stride2), + C.int(padding0), + C.int(padding1), + C.int(padding2), + C.int(dilation0), + C.int(dilation1), + C.int(dilation2), + C.int(groups), + ctx.(*Context).stream, + ) + return newArray(ctx.(*Context), r) +} + +func (a *Array) ToString() string { + str := C.mlx_string_new() + C.mlx_array_tostring(&str, a.a) + s := C.mlx_string_data(str) + defer C.mlx_string_free(str) + return C.GoString(s) +} + +func (a *Array) LogValue() slog.Value { + + dims := int(C.mlx_array_ndim(a.a)) + strides := make([]int, dims) + for i := range strides { + strides[i] = int(C.stride(a.a, (C.int)(i))) + } + + return slog.GroupValue( + slog.String("name", a.name), + slog.String("type", a.TypeString()), + slog.Any("shape", a.Shape()), + slog.Any("strides", strides), + // slog.String("values", C.GoString(s)), + ) +} + +func (a *Array) Shape() []int { + shape := make([]int, C.mlx_array_ndim(a.a)) + for i := range shape { + shape[i] = int(C.mlx_array_dim(a.a, C.int(i))) + } + + return shape +} + +func (a *Array) TypeString() string { + switch C.mlx_array_dtype(a.a) { + case C.MLX_BOOL: + return "bool" + case C.MLX_UINT8: + return "uint8" + case C.MLX_UINT16: + return "uint16" + case C.MLX_UINT32: + return "uint32" + case C.MLX_UINT64: + return "uint64" + case C.MLX_INT8: + return "int8" + case C.MLX_INT16: + return "int16" + case C.MLX_INT32: + return "int32" + case C.MLX_INT64: + return "int64" + case C.MLX_FLOAT16: + return "float16" + case C.MLX_FLOAT32: + return "float32" + case C.MLX_BFLOAT16: + return "bfloat16" + case C.MLX_COMPLEX64: + return "complex64" + default: + return "unknown" + } +} diff --git a/x/ml/backend/mlx/mlx_test.go b/x/ml/backend/mlx/mlx_test.go new file mode 100644 index 000000000..7699c1524 --- /dev/null +++ b/x/ml/backend/mlx/mlx_test.go @@ -0,0 +1,314 @@ +//go:build mlx + +package mlx + +import ( + "log/slog" + "os" + "reflect" + "strings" + "testing" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/runner/common" + "github.com/ollama/ollama/sample" + "github.com/ollama/ollama/x/ml" + "github.com/ollama/ollama/x/model" + "github.com/ollama/ollama/x/model/input" + _ "github.com/ollama/ollama/x/model/models/gemma3" +) + +func init() { + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + slog.SetDefault(logger) +} + +func TestLoadModel(t *testing.T) { + dir := "/Users/daniel/Models/gemma-3-4b-it/" + b := &Backend{} + err := b.LoadSafeTensors(dir) + if err != nil { + t.Fatalf("load failed: %s", err) + } +} + +func TestFromInts(t *testing.T) { + b := &Backend{} + c := b.NewContext() + defer c.Close() + data := []int32{1, 2, 3, 4, 5, 6} + a := c.FromInts(data, 2, 3) + slog.Info("", "array", a) + t.Log(a.ToString()) + if !reflect.DeepEqual(a.Shape(), []int{2, 3}) { + t.Fatalf("incorrect shape: %v", a.Shape()) + } +} + +func TestFromFloats(t *testing.T) { + b := &Backend{} + c := b.NewContext() + defer c.Close() + data := []float32{1, 2, 3, 4, 5, 6} + a := c.FromFloats(data, 2, 3) + slog.Info("", "array", a) + t.Log(a.ToString()) + if !reflect.DeepEqual(a.Shape(), []int{2, 3}) { + t.Fatalf("incorrect shape: %v", a.Shape()) + } + res := a.Floats() + if !reflect.DeepEqual(res, data) { + t.Fatalf("incorrect results: %v", res) + } +} + +func TestAdd(t *testing.T) { + b := &Backend{} + c := b.NewContext() + defer c.Close() + t1 := c.Arange(0, 24, 1, ml.DTypeFloat16) + t2 := c.Arange(0, 24, 1, ml.DTypeFloat16) + exp := c.Arange(0, 48, 2, ml.DTypeFloat16) + t3 := t1.Add(c, t2) + c.Compute(t3, exp) + t3f := t3.Floats() + if !reflect.DeepEqual(t3f, exp.Floats()) { + t.Fatalf("incorrect result: %v", t3f) + } +} + +func TestReshapeTranspose(t *testing.T) { + b := &Backend{} + c := b.NewContext() + defer c.Close() + t1 := c.Arange(0, 24, 1, ml.DTypeFloat16).Reshape(c, 2, 3, 4).Transpose(c, 0, 2, 1).Contiguous(c, false) + c.Compute(t1) + t1f := t1.Floats() + exp := []float32{ + 0, 4, 8, + 1, 5, 9, + 2, 6, 10, + 3, 7, 11, + 12, 16, 20, + 13, 17, 21, + 14, 18, 22, + 15, 19, 23, + } + if !reflect.DeepEqual(t1f, exp) { + t.Fatalf("incorrect results: %v", t1f) + } +} + +func prod(vals ...int) int { + r := 1 + for _, v := range vals { + r *= v + } + return r +} +func TestMatmul(t *testing.T) { + // TODO create scenarios... + b := &Backend{} + c := b.NewContext() + defer c.Close() + s1 := []int{1, 3, 2, 4} + t1 := c.Arange(0, float32(prod(s1...)), 1, ml.DTypeFloat16).Reshape(c, s1...) + s2 := []int{4, 2} + t2 := c.Arange(0, float32(prod(s2...)), 1, ml.DTypeFloat16).Reshape(c, s2...) + t3 := t1.Matmul(c, t2) + exp := []float32{ + 28, 34, + 76, 98, + + 124, 162, + 172, 226, + + 220, 290, + 268, 354, + } + c.Compute(t3) + t3f := t3.Floats() + if !reflect.DeepEqual(t3f, exp) { + t.Fatalf("incorrect result: %v", t3f) + } +} + +func TestRows(t *testing.T) { + b := &Backend{} + c := b.NewContext() + defer c.Close() + t1 := c.Arange(0, 12, 1, ml.DTypeFloat32).Reshape(c, 1, 4, 3) + outputs := c.Zeros(ml.DTypeInt32, 1) + t2 := t1.TakeAxes(c, outputs, 1) + c.Forward(t1, t2).Compute(t1, t2) + t.Log(t1.ToString()) + t.Log(t2.ToString()) + f := t2.Floats() + t.Logf("Result: %v", f) +} + +func TestCaching(t *testing.T) { + // Validate the caching algorithm + b := &Backend{} + c := b.NewContext() + defer c.Close() + batchSize := 3 + headDim := 4 + numKVHeads := 2 + // Make cache twice the size of one test batch + cells := batchSize * 2 + cellSize := numKVHeads * headDim + shape := []int{1, numKVHeads, batchSize, headDim} + stop := float32(1) + for _, x := range shape { + stop *= float32(x) + } + // Create the cache + cache := c.Zeros(ml.DTypeFloat16, cells, cellSize) + t.Logf("Empty Cache shape%v\n"+cache.ToString(), []int{cells, cellSize}) + + // Input tensor + t1 := c.Arange(0, stop, 1, ml.DTypeFloat16).Reshape(c, shape...) + t.Logf("Initial Data shape%v\n"+t1.ToString(), shape) + + // Reshape to copy into the cache + /* + From MLX python/src/indexing.cpp mlx_scatter_args_array + // The update shape must broadcast with indices.shape + [1] + src.shape[1:] + auto up_shape = indices.shape(); + up_shape.insert(up_shape.end(), src.shape().begin() + 1, src.shape().end()); + up = broadcast_to(up, up_shape); + up_shape.insert(up_shape.begin() + indices.ndim(), 1); + up = reshape(up, up_shape); + */ + numRows := 3 + up := t1.Reshape(c, numRows, 1, cellSize) // The shape has to look like this for scatter to work properly + t.Logf("Data reshaped for cache input shape%v\n"+up.ToString(), []int{batchSize, numKVHeads * headDim}) + + // Simulate cells 1,3,5 are available + indicies := []ml.Tensor{c.FromInts([]int32{1, 3, 5}, numRows)} + t.Logf("Indicies shape%v\n"+indicies[0].ToString(), []int{numRows}) + axis := []int{0} // The 1,3,5 of the indicies are in reference to axis 0 in the cache shape + cache.Scatter(c, indicies, up, axis) + + c.Forward(cache) + // Cache should contain the data now + t.Log("Cache after put\n" + cache.ToString()) + + // Retrieve cache content and verify it matches + out := cache.TakeAxes(c, indicies[0], 0).Reshape(c, shape...) + t.Logf("Output shape%v\n"+out.ToString(), out.Shape()) + + t1f := t1.Floats() + outf := out.Floats() + if !reflect.DeepEqual(t1f, outf) { + t.Fatalf("mismatched in->out\n%v\n ->\n%v", t1f, outf) + } +} + +func TestGemma3(t *testing.T) { + // Why is the sky blue + inputs := []int32{2, 105, 2364, 107, 36425, 563, 506, 7217, 3730, 106, 107, 105, 4368} + limit := 50 + + // TODO generalize this + dir := "/Users/daniel/Models/gemma-3-4b-it/" + + m, err := model.New(dir, ml.BackendParams{}) + if err != nil { + t.Fatalf("unable to load model: %s", err) + } + b := m.Backend() + ctx := b.NewContext() + defer ctx.Close() + + batch := input.Batch{ + Inputs: ctx.FromInts(inputs[:], 1, len(inputs)), + Positions: make([]int32, len(inputs)), + Sequences: make([]int, len(inputs)), + Outputs: ctx.FromInts([]int32{int32(len(inputs) - 1)}, 1), + Offset: 0, + } + for i := range len(inputs) { + batch.Positions[i] = int32(i) + } + offset := len(inputs) + + cache := m.Config().Cache + if cache != nil { + numSlots := 1 + batchSize := 512 + numCtx := 4096 + + // Note: this is inconsistent with mlx-py, but trying to be consistent with the GGML cache impl to get things working + // cache.SetConfig(ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeBfloat16, MaskBatchPadding: 64}) + cache.SetConfig(ml.CacheConfig{CachePadding: 0, MaskDType: ml.DTypeBfloat16, MaskBatchPadding: 0}) + + cache.Init(b, ml.DTypeBfloat16, numSlots, int(numCtx), batchSize) + err := cache.StartForward(ctx, batch, false) + if err != nil { + t.Fatalf("failed cache.StartForward: %s", err) + } + } + opts := api.DefaultOptions() + var grammar *sample.GrammarSampler + sampler := sample.NewSampler( + opts.Temperature, + opts.TopK, + opts.TopP, + opts.MinP, + opts.Seed, + grammar, + ) + + t.Log("Starting Forward pass loop") + pendingResponses := []string{} + for { + out, err := m.Forward(ctx, batch) + if err != nil { + t.Fatalf("failed forward pass: %s", err) + } + ctx.Forward(out) + outputs := out.Floats() + t.Logf("finished forward pass! length:%d", len(outputs)) + // sample a token + logits := outputs + token, err := sampler.Sample(logits) + if err != nil { + t.Fatalf("unable to sample token: %s", err) + } + t.Logf("Sampled token: %v", token) + if m.(model.TextProcessor).Is(token, model.SpecialEOS) { + t.Log("hit EOS") + break + } + piece, err := m.(model.TextProcessor).Decode([]int32{token}) + if err != nil { + t.Fatalf("unable to decode token: %s", err) + } + + pendingResponses = append(pendingResponses, piece) + sequence := strings.Join(pendingResponses, "") + if ok, stop := common.FindStop(sequence, opts.Stop); ok { + t.Logf("hit stop token: %v", stop) + break + } + t.Logf("RESULTS: %s", sequence) + batch = input.Batch{ + Inputs: ctx.FromInts([]int32{token}, 1, 1), + Positions: make([]int32, 1), + Sequences: make([]int, 1), + Outputs: ctx.FromInts([]int32{0}, 1), + Offset: offset, + } + offset++ + batch.Positions[0] = 0 + err = cache.StartForward(ctx, batch, false) + if err != nil { + t.Fatalf("failed cache.StartForward: %s", err) + } + if offset > limit { + break + } + } +} diff --git a/x/ml/backend/mlx/quant.go b/x/ml/backend/mlx/quant.go new file mode 100644 index 000000000..724f43253 --- /dev/null +++ b/x/ml/backend/mlx/quant.go @@ -0,0 +1,335 @@ +//go:build mlx + +package mlx + +/* +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/ops.h" + +// Derived from https://github.com/ml-explore/mlx/blob/main/mlx/io/gguf_quants.cpp + +void unpack_32_4(uint8_t* data, int8_t* dst) { + memset(dst, 0, 16); + for (int j = 0; j < 16; ++j) { + uint8_t x = (data[j + 2] & 0x0F); // j+2 to skip scale bytes. + if (j % 2 != 0) { + x <<= 4; + } + dst[j / 2] += x; + } + // Last 16 weights are in the higher bits + for (int j = 0; j < 16; ++j) { + uint8_t x = (data[j + 2] >> 4); + if (j % 2 != 0) { + x <<= 4; + } + dst[8 + j / 2] += x; + } +} + +// Extracts (weight, scales, biases) from Q4_0 tensors. +// Data layout is: |16 bit scale|32 x 4bit weights|. +void extract_q4_0_data( + uint8_t* data, + mlx_array* weights_arr, + mlx_array* scales_arr, + mlx_array* biases_arr) { + const uint64_t bytes_per_block = 18; // 2 bytes scale, 32x0.5 byte weights + uint8_t* weights = mlx_array_data_uint8(*weights_arr); + float16_t* scales = mlx_array_data_float16(*scales_arr); + float16_t* biases = mlx_array_data_float16(*biases_arr); + for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) { + scales[i] = *((float16_t*)data); + biases[i] = -8 * scales[i]; + unpack_32_4(data, weights); + weights += 16; + data += bytes_per_block; + } +} + +// Extracts (weight, scales, biases) from Q4_1 tensors. +// Data layout is: |16 bit scale|16 bit bias|32 x 4bit weights|. +void extract_q4_1_data( + uint8_t* data, + mlx_array* weights_arr, + mlx_array* scales_arr, + mlx_array* biases_arr) { + const uint64_t bytes_per_block = 20; // 2 bytes scale, 2 bytes bias, 32x0.5 byte weights + uint8_t* weights = mlx_array_data_uint8(*weights_arr); + float16_t* scales = mlx_array_data_float16(*scales_arr); + float16_t* biases = mlx_array_data_float16(*biases_arr); + for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) { + scales[i] = *((float16_t*)data); + biases[i] = *((float16_t*)(data) + 1); + unpack_32_4(data, weights); + weights += 16; + data += bytes_per_block; + } +} + +// Extracts (weight, scales, biases) from Q8_0 tensors. +// Data layout is: |16 bit scale|32 x 8bit weights|. +void extract_q8_0_data( + uint8_t* data, + mlx_array* weights_arr, + mlx_array* scales_arr, + mlx_array* biases_arr) { + const uint64_t weights_per_block = 32; + const uint64_t bytes_per_block = 34; // 2 bytes scale, 32x1 byte weights + uint8_t* weights = mlx_array_data_uint8(*weights_arr); + float16_t* scales = mlx_array_data_float16(*scales_arr); + float16_t* biases = mlx_array_data_float16(*biases_arr); + for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) { + uint8_t* block_data = data + i * bytes_per_block; + scales[i] = *((float16_t*)block_data); + biases[i] = -128 * scales[i]; + for (int64_t j = 0; j < weights_per_block; ++j) { + uint8_t x = block_data[j + 2]; // j+2 to skip the scale bytes. + // Original data is in int8_t, so we add a bias of -128 and invert the + // first bit. + x ^= 1 << 7; + weights[i * weights_per_block + j] = x; + } + } +} + +// Drived from ggml-quants.c + +#define QK_K 256 + +// 6-bit quantization +// weight is represented as x = a * q +// 16 blocks of 16 elements each +// Effectively 6.5625 bits per weight +typedef struct { + uint8_t ql[QK_K/2]; // quants, lower 4 bits + uint8_t qh[QK_K/4]; // quants, upper 2 bits + int8_t scales[QK_K/16]; // scales, quantized with 8 bits + uint16_t d; // super-block scale +} block_q6_K; + +void dequant_row_q6_K(const void * restrict vx, void * restrict vy, int k) { + const int64_t nb = k / QK_K; + block_q6_K *x = (block_q6_K *)vx; + float16_t* y = (float16_t *)vy; + + for (int i = 0; i < nb; i++) { + float16_t d = 0.0; + memcpy(&d, &x[i].d, sizeof(d)); + + const uint8_t * restrict ql = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict sc = x[i].scales; + + for (int n = 0; n < QK_K; n += 128) { + for (int l = 0; l < 32; ++l) { + int is = l/16; + const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + y[l + 0] = d * sc[is + 0] * q1; + y[l + 32] = d * sc[is + 2] * q2; + y[l + 64] = d * sc[is + 4] * q3; + y[l + 96] = d * sc[is + 6] * q4; + } + y += 128; + ql += 64; + qh += 32; + sc += 8; + } + } +} + +#define K_SCALE_SIZE 12 +#define GGML_COMMON_AGGR_U +#define GGML_COMMON_AGGR_S + +// 4-bit quantization +// 8 blocks of 32 elements each +// weight is represented as x = a * q + b +// Effectively 4.5 bits per weight +typedef struct { + union { + struct { + uint16_t d; // super-block scale for quantized scales + uint16_t dmin; // super-block scale for quantized mins + } GGML_COMMON_AGGR_S; + uint16_t dm; + } GGML_COMMON_AGGR_U; + uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uint8_t qs[QK_K/2]; // 4--bit quants +} block_q4_K; + +static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) { + if (j < 4) { + *d = q[j] & 63; *m = q[j + 4] & 63; + } else { + *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4); + *m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); + } +} + +void dequant_row_q4_K(const void * restrict vx, void * restrict vy, int k) { + block_q4_K *x = (block_q4_K *)vx; + float16_t* y = (float16_t *)vy; + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + const uint8_t * q = x[i].qs; + float16_t d = 0.0; + memcpy(&d, &x[i].d, sizeof(d)); + float16_t min = 0.0; + memcpy(&min, &x[i].dmin, sizeof(d)); + + int is = 0; + uint8_t sc, m; + for (int j = 0; j < QK_K; j += 64) { + get_scale_min_k4(is + 0, x[i].scales, &sc, &m); + const float16_t d1 = d * sc; const float16_t m1 = min * m; + get_scale_min_k4(is + 1, x[i].scales, &sc, &m); + const float16_t d2 = d * sc; const float16_t m2 = min * m; + for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1; + for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2; + q += 32; is += 2; + } + } +} + + + +*/ +import "C" + +import ( + "fmt" + "unsafe" + + "github.com/x448/float16" +) + +func gguf_load_quantized(data unsafe.Pointer, name string, final_shape []C.int, dtype uint32, stream C.mlx_stream) (r C.mlx_array, err error) { + shape := append([]C.int{}, final_shape...) + var weights_per_byte C.int + if dtype == 2 || dtype == 3 { + weights_per_byte = 2 + } else if dtype == 8 { + weights_per_byte = 1 + } else { + return r, fmt.Errorf("unsupported tensor type %d", dtype) + } + + weights_per_block := C.int(32) + if shape[len(shape)-1]%weights_per_block != 0 { + return r, fmt.Errorf("[load_gguf] tensor has incompatible last dim shape: %d", shape[len(shape)-1]) + } + + weights_shape := append([]C.int{}, shape...) + weights_shape[len(weights_shape)-1] /= (weights_per_byte * 4) + w_nbytes := C.int(unsafe.Sizeof(uint32(0))) + for i := range weights_shape { + w_nbytes *= weights_shape[i] + } + w_data := make([]byte, w_nbytes) + cbytes := C.CBytes(w_data) + defer C.free(cbytes) + weights := C.mlx_array_new_data( + cbytes, + &weights_shape[0], + C.int(len(weights_shape)), + C.MLX_UINT32, + ) + + // For scales and bias + shape[len(shape)-1] = shape[len(shape)-1] / weights_per_block + sb_nbytes := C.int(unsafe.Sizeof(float16.Float16(0))) + for i := range shape { + sb_nbytes *= shape[i] + } + + s_data := make([]byte, sb_nbytes) + cbytes = C.CBytes(s_data) + defer C.free(cbytes) + scales := C.mlx_array_new_data( + cbytes, + &shape[0], + C.int(len(shape)), + C.MLX_FLOAT16, + ) + b_data := make([]byte, sb_nbytes) + cbytes = C.CBytes(b_data) + defer C.free(cbytes) + biases := C.mlx_array_new_data( + cbytes, + &shape[0], + C.int(len(shape)), + C.MLX_FLOAT16, + ) + var bits C.int + switch dtype { + case 2: + C.extract_q4_0_data((*C.uint8_t)(data), &weights, &scales, &biases) + bits = 4 + case 3: + C.extract_q4_1_data((*C.uint8_t)(data), &weights, &scales, &biases) + bits = 4 + case 8: + C.extract_q8_0_data((*C.uint8_t)(data), &weights, &scales, &biases) + bits = 8 + } + groupSize := C.mlx_optional_int{value: 32, has_value: true} + bitsOpt := C.mlx_optional_int{value: bits, has_value: true} + var dtypeOpt C.mlx_optional_dtype // has_value defaults to false + C.mlx_dequantize( + &r, + weights, + scales, + biases, + groupSize, + bitsOpt, + nil, // TODO mode + dtypeOpt, + stream, + ) + C.mlx_array_free(weights) + C.mlx_array_free(scales) + C.mlx_array_free(biases) + + return r, nil +} + +func load_k_quantized(data unsafe.Pointer, name string, shape []C.int, dtype uint32, stream C.mlx_stream) (r C.mlx_array, err error) { + size := 1 + for _, d := range shape { + size *= int(d) + } + fdata := make([]float16.Float16, size) + switch dtype { + case 14: + C.dequant_row_q6_K( + data, + unsafe.Pointer(&fdata[0]), + C.int(size), + ) + + case 12: + C.dequant_row_q4_K( + data, + unsafe.Pointer(&fdata[0]), + C.int(size), + ) + default: + return r, fmt.Errorf("unsupported K quant") + } + + r = C.mlx_array_new_data( + unsafe.Pointer(&fdata[0]), + &shape[0], + C.int(len(shape)), + C.MLX_FLOAT16, + ) + return r, nil +} diff --git a/x/ml/device.go b/x/ml/device.go new file mode 100644 index 000000000..f892b512d --- /dev/null +++ b/x/ml/device.go @@ -0,0 +1,643 @@ +package ml + +import ( + "context" + "encoding/binary" + "encoding/json" + "fmt" + "hash/maphash" + "io" + "log/slog" + "math" + "net/http" + "runtime" + "slices" + "sort" + "strconv" + "strings" + "time" + + "github.com/ollama/ollama/format" + "github.com/ollama/ollama/logutil" +) + +// GPULayers is a set of layers to be allocated on a single GPU +type GPULayers struct { + DeviceID + + // Layers is a set of layer indicies to load + Layers []int +} + +// FirstLayer returns the smallest layer index scheduled on this GPU, or MaxInt when empty. +func (g GPULayers) FirstLayer() int { + if len(g.Layers) == 0 { + return math.MaxInt + } + + first := g.Layers[0] + for i := 1; i < len(g.Layers); i++ { + if g.Layers[i] < first { + first = g.Layers[i] + } + } + + return first +} + +func (g GPULayers) String() string { + if len(g.Layers) == 0 { + return "" + } + + slices.Sort(g.Layers) + + contiguous := true + base := g.Layers[0] + for i := range g.Layers { + if g.Layers[i] != base+i { + contiguous = false + break + } + } + + if contiguous { + return fmt.Sprintf("ID:%v Layers:%v(%v..%v)", g.ID, len(g.Layers), g.Layers[0], g.Layers[len(g.Layers)-1]) + } else { + return fmt.Sprintf("ID:%v Layers:%v%v", g.ID, len(g.Layers), g.Layers) + } +} + +// GPULayersList is a set of layer allocations across multiple GPUs +type GPULayersList []GPULayers + +func (l GPULayersList) Len() int { return len(l) } +func (l GPULayersList) Swap(i, j int) { l[i], l[j] = l[j], l[i] } + +// Sort by the ordering of the layers offloaded +func (l GPULayersList) Less(i, j int) bool { + li := l[i].FirstLayer() + lj := l[j].FirstLayer() + + return li < lj +} + +func (l GPULayersList) String() string { + if l.Sum() > 0 { + return fmt.Sprintf("%v%v", l.Sum(), []GPULayers(l)) + } else { + return fmt.Sprintf("%v", []GPULayers(l)) + } +} + +// Sum is the total number of layers assigned across all GPUs +func (l GPULayersList) Sum() int { + var sum int + + for _, g := range l { + sum += len(g.Layers) + } + + return sum +} + +var h maphash.Hash + +// Hash is an identifier of this layer assignment +func (l GPULayersList) Hash() uint64 { + h.Reset() + for _, g := range l { + if len(g.Layers) > 0 { + h.WriteString(g.ID + g.Library) + for _, l := range g.Layers { + binary.Write(&h, binary.NativeEndian, int64(l)) + } + } + } + + return h.Sum64() +} + +// ErrNoMem is returned when panicing due to insufficient memory. It includes +// the attempted memory allocation. +type ErrNoMem struct { + BackendMemory +} + +func (e ErrNoMem) Error() string { + return fmt.Sprintf("insufficient memory - required allocations: %+v", e.BackendMemory) +} + +// Minimal unique device identification +type DeviceID struct { + // ID is an identifier for the device for matching with system + // management libraries. The ID is only unique for other devices + // using the same Library. + // This ID represents a "post filtered" view of the enumerated devices + // if the ID is numeric + ID string `json:"id"` + + // Library identifies which library is used for the device (e.g. CUDA, ROCm, etc.) + Library string `json:"backend,omitempty"` +} + +// DeviceMemory provides a breakdown of the memory needed +// per device, such as a CPU or GPU. +type DeviceMemory struct { + DeviceID + + // Name is the name of the device as labeled by the backend. It + // may not be persistent across instances of the runner. + Name string + + // Weights is the per-layer memory needed for the model weights. + Weights []uint64 + + // Cache is the per-layer memory needed for the KV cache. + Cache []uint64 + + // Graph is the size of the compute graph. It is not per-layer. + Graph uint64 +} + +func sumMemory(mem []uint64) uint64 { + var sum uint64 + + for _, m := range mem { + sum += m + } + + return sum +} + +// Size returns the total size of the memory required by this device +func (m DeviceMemory) Size() uint64 { + return sumMemory(m.Weights) + sumMemory(m.Cache) + m.Graph +} + +func memoryPresent(mem []uint64) bool { + return slices.ContainsFunc(mem, func(m uint64) bool { return m != 0 }) +} + +func (m DeviceMemory) LogValue() slog.Value { + var attrs []slog.Attr + if memoryPresent(m.Weights) { + attrs = append(attrs, slog.Any("Weights", m.Weights)) + } + + if memoryPresent(m.Cache) { + attrs = append(attrs, slog.Any("Cache", m.Cache)) + } + + if m.Graph != 0 { + attrs = append(attrs, slog.Any("Graph", m.Graph)) + } + + if len(attrs) > 0 && m.ID != "" { + attrs = append([]slog.Attr{slog.String("ID", m.ID)}, attrs...) + } + + return slog.GroupValue(attrs...) +} + +// BackendMemory provides the amount of memory required to load the model +// per device based on the BackendParams. In some cases, not all required +// allocations will be known at this point. However, the size of the most recent +// allocation is guaranteed to be provided so that if it failed, the caller can +// accommodate that to make forward progress. +type BackendMemory struct { + // InputWeights are always located on the CPU and cannot be moved + InputWeights uint64 + + // CPU model components are located in system memory. This does not + // include unified memory allocated through the GPU. + CPU DeviceMemory + + // GPU model components are located on one or more GPUs. + GPUs []DeviceMemory +} + +func (m BackendMemory) LogValue() slog.Value { + var attrs []slog.Attr + if m.InputWeights != 0 { + attrs = append(attrs, slog.Any("InputWeights", m.InputWeights)) + } + + attrs = append(attrs, slog.Any(m.CPU.Name, m.CPU)) + for _, g := range m.GPUs { + attrs = append(attrs, slog.Any(g.Name, g)) + } + + return slog.GroupValue(attrs...) +} + +// Log prints a high level summary of the memory +func (m BackendMemory) Log(level slog.Level) { + var total uint64 + + for _, gpu := range m.GPUs { + if sum := sumMemory(gpu.Weights); sum > 0 { + slog.Log(context.TODO(), level, "model weights", "device", gpu.Name, "size", format.HumanBytes2(sum)) + total += sum + } + } + if sum := m.InputWeights + sumMemory(m.CPU.Weights); sum > 0 { + slog.Log(context.TODO(), level, "model weights", "device", m.CPU.Name, "size", format.HumanBytes2(sum)) + total += sum + } + + for _, gpu := range m.GPUs { + if sum := sumMemory(gpu.Cache); sum > 0 { + slog.Log(context.TODO(), level, "kv cache", "device", gpu.Name, "size", format.HumanBytes2(sum)) + total += sum + } + } + if sum := sumMemory(m.CPU.Cache); sum > 0 { + slog.Log(context.TODO(), level, "kv cache", "device", m.CPU.Name, "size", format.HumanBytes2(sum)) + total += sum + } + + for _, gpu := range m.GPUs { + if sum := gpu.Graph; sum > 0 { + slog.Log(context.TODO(), level, "compute graph", "device", gpu.Name, "size", format.HumanBytes2(sum)) + total += sum + } + } + if sum := m.CPU.Graph; sum > 0 { + slog.Log(context.TODO(), level, "compute graph", "device", m.CPU.Name, "size", format.HumanBytes2(sum)) + total += sum + } + + if total > 0 { + slog.Log(context.TODO(), level, "total memory", "size", format.HumanBytes2(total)) + } +} + +type DeviceInfo struct { + DeviceID + + // Name is the name of the device as labeled by the backend. It + // may not be persistent across instances of the runner. + Name string `json:"name"` + + // Description is the longer user-friendly identification of the device + Description string `json:"description"` + + // FilterID is populated with the unfiltered device ID if a numeric ID is used + // so the device can be included. + FilterID string `json:"filter_id,omitempty"` + + // Integrated is set true for integrated GPUs, false for Discrete GPUs + Integrated bool `json:"integration,omitempty"` + + // PCIID is the bus, device and domain ID of the device for deduplication + // when discovered by multiple backends + PCIID string `json:"pci_id,omitempty"` + + // TotalMemory is the total amount of memory the device can use for loading models + TotalMemory uint64 `json:"total_memory"` + + // FreeMemory is the amount of memory currently available on the device for loading models + FreeMemory uint64 `json:"free_memory,omitempty"` + + // ComputeMajor is the major version of capabilities of the device + // if unsupported by the backend, -1 will be returned + ComputeMajor int + + // ComputeMinor is the minor version of capabilities of the device + // if unsupported by the backend, -1 will be returned + ComputeMinor int + + // Driver Information + DriverMajor int `json:"driver_major,omitempty"` + DriverMinor int `json:"driver_minor,omitempty"` + + // Where backends were loaded from + LibraryPath []string +} + +type SystemInfo struct { + // ThreadCount is the optimal number of threads to use for inference + ThreadCount int `json:"threads,omitempty"` + + // TotalMemory is the total amount of system memory + TotalMemory uint64 `json:"total_memory,omitempty"` + + // FreeMemory is the amount of memory currently available on the system for loading models + FreeMemory uint64 `json:"free_memory,omitempty"` + + // FreeSwap is the amount of system swap space reported as available + FreeSwap uint64 `json:"free_swap,omitempty"` +} + +func (d DeviceInfo) Compute() string { + // AMD gfx is encoded into the major minor in hex form + if strings.EqualFold(d.Library, "ROCm") { + return fmt.Sprintf("gfx%x%02x", d.ComputeMajor, d.ComputeMinor) + } + return strconv.Itoa(d.ComputeMajor) + "." + strconv.Itoa(d.ComputeMinor) +} + +func (d DeviceInfo) Driver() string { + return strconv.Itoa(d.DriverMajor) + "." + strconv.Itoa(d.DriverMinor) +} + +// MinimumMemory reports the amount of memory that should be set aside +// on the device for overhead (e.g. VRAM consumed by context structures independent +// of model allocations) +func (d DeviceInfo) MinimumMemory() uint64 { + if d.Library == "Metal" { + return 512 * format.MebiByte + } + return 457 * format.MebiByte +} + +// Sort by Free Space. +// iGPUs are reported first, thus Reverse() yields the largest discrete GPU first +type ByFreeMemory []DeviceInfo + +func (a ByFreeMemory) Len() int { return len(a) } +func (a ByFreeMemory) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a ByFreeMemory) Less(i, j int) bool { + if a[i].Integrated && !a[j].Integrated { + return true + } else if !a[i].Integrated && a[j].Integrated { + return false + } + return a[i].FreeMemory < a[j].FreeMemory +} + +// ByPerformance groups devices by similar speed +func ByPerformance(l []DeviceInfo) [][]DeviceInfo { + resp := [][]DeviceInfo{} + scores := []bool{} + for _, info := range l { + found := false + requested := info.Integrated + for i, score := range scores { + if score == requested { + resp[i] = append(resp[i], info) + found = true + break + } + } + if !found { + scores = append(scores, requested) + resp = append(resp, []DeviceInfo{info}) + } + } + return resp +} + +func ByLibrary(l []DeviceInfo) [][]DeviceInfo { + resp := [][]DeviceInfo{} + libs := []string{} + for _, info := range l { + found := false + requested := info.Library + for i, lib := range libs { + if lib == requested { + resp[i] = append(resp[i], info) + found = true + break + } + } + if !found { + libs = append(libs, requested) + resp = append(resp, []DeviceInfo{info}) + } + } + return resp +} + +func LibraryPaths(l []DeviceInfo) []string { + gpuLibs := []string{LibOllamaPath} + for _, gpu := range l { + for _, dir := range gpu.LibraryPath { + needed := true + for _, existing := range gpuLibs { + if dir == existing { + needed = false + break + } + } + if needed { + gpuLibs = append(gpuLibs, dir) + } + } + } + return gpuLibs +} + +type DeviceComparison int + +const ( + UniqueDevice DeviceComparison = iota + SameBackendDevice // The device is the same, and the library/backend is the same + DuplicateDevice // The same physical device but different library/backend (overlapping device) +) + +func (a DeviceInfo) Compare(b DeviceInfo) DeviceComparison { + if a.PCIID != b.PCIID { + return UniqueDevice + } + // If PCIID is empty, we have to use ID + library for uniqueness + if a.PCIID == "" && a.DeviceID != b.DeviceID { + return UniqueDevice + } + if a.Library == b.Library { + return SameBackendDevice + } + return DuplicateDevice +} + +// For a SameBackendDevice, return true if b is better than a +// e.g. newer GPU library version +func (a DeviceInfo) IsBetter(b DeviceInfo) bool { + aLib := a.LibraryPath[len(a.LibraryPath)-1] + bLib := b.LibraryPath[len(b.LibraryPath)-1] + if aLib == bLib { + return false + } + aLibSplit := strings.SplitN(aLib, "_", 2) + bLibSplit := strings.SplitN(bLib, "_", 2) + if len(aLibSplit) < 2 || len(bLibSplit) < 2 { + return false + } + if aLibSplit[0] != bLibSplit[0] { + slog.Debug("unexpected libraries", "a", aLib, "b", bLib) + return false + } + if aLibSplit[1] == bLibSplit[1] { + return false + } + cmp := []string{aLibSplit[1], bLibSplit[1]} + sort.Sort(sort.Reverse(sort.StringSlice(cmp))) + return cmp[0] == bLibSplit[1] +} + +// For each GPU, check if it does NOT support flash attention +func FlashAttentionSupported(l []DeviceInfo) bool { + for _, gpu := range l { + supportsFA := gpu.Library == "cpu" || + gpu.Name == "Metal" || gpu.Library == "Metal" || + (gpu.Library == "CUDA" && gpu.DriverMajor >= 7 && !(gpu.ComputeMajor == 7 && gpu.ComputeMinor == 2)) || + gpu.Library == "ROCm" || + gpu.Library == "Vulkan" + + if !supportsFA { + return false + } + } + return true +} + +// Given the list of GPUs this instantiation is targeted for, +// figure out the visible devices environment variables +// Set mustFilter true to enable filtering of CUDA devices +func GetVisibleDevicesEnv(l []DeviceInfo, mustFilter bool) map[string]string { + if len(l) == 0 { + return nil + } + env := map[string]string{} + for _, d := range l { + d.updateVisibleDevicesEnv(env, mustFilter) + } + return env +} + +// NeedsInitValidation returns true if the device in question has the potential +// to crash at inference time and requires deeper validation before we include +// it in the supported devices list. +func (d DeviceInfo) NeedsInitValidation() bool { + // ROCm: rocblas will crash on unsupported devices. + // CUDA: verify CC is supported by the version of the library + return d.Library == "ROCm" || d.Library == "CUDA" +} + +// Set the init validation environment variable +func (d DeviceInfo) AddInitValidation(env map[string]string) { + env["GGML_CUDA_INIT"] = "1" // force deep initialization to trigger crash on unsupported GPUs +} + +// PreferredLibrary returns true if this library is preferred over the other input +// library +// Used to filter out Vulkan in favor of CUDA or ROCm +func (d DeviceInfo) PreferredLibrary(other DeviceInfo) bool { + // TODO in the future if we find Vulkan is better than ROCm on some devices + // that implementation can live here. + + if d.Library == "CUDA" || d.Library == "ROCm" { + return true + } + return false +} + +func (d DeviceInfo) updateVisibleDevicesEnv(env map[string]string, mustFilter bool) { + var envVar string + switch d.Library { + case "ROCm": + // ROCm must be filtered as it can crash the runner on unsupported devices + envVar = "ROCR_VISIBLE_DEVICES" + if runtime.GOOS != "linux" { + envVar = "HIP_VISIBLE_DEVICES" + } + case "CUDA": + if !mustFilter { + // By default we try to avoid filtering CUDA devices because ROCm also + // looks at the CUDA env var, and gets confused in mixed vendor environments. + return + } + envVar = "CUDA_VISIBLE_DEVICES" + default: + // Vulkan is not filtered via env var, but via scheduling decisions + return + } + v, existing := env[envVar] + if existing { + v = v + "," + } + if d.FilterID != "" { + v = v + d.FilterID + } else { + v = v + d.ID + } + env[envVar] = v +} + +type BaseRunner interface { + // GetPort returns the localhost port number the runner is running on + GetPort() int + + // HasExited indicates if the runner is no longer running. This can be used during + // bootstrap to detect if a given filtered device is incompatible and triggered an assert + HasExited() bool +} + +type RunnerDiscovery interface { + BaseRunner + + // GetDeviceInfos will perform a query of the underlying device libraries + // for device identification and free VRAM information + // During bootstrap scenarios, this routine may take seconds to complete + GetDeviceInfos(ctx context.Context) []DeviceInfo +} + +type FilteredRunnerDiscovery interface { + RunnerDiscovery + + // GetActiveDeviceIDs returns the filtered set of devices actively in + // use by this runner for running models. If the runner is a bootstrap runner, no devices + // will be active yet so no device IDs are returned. + // This routine will not query the underlying device and will return immediately + GetActiveDeviceIDs() []DeviceID +} + +func GetDevicesFromRunner(ctx context.Context, runner BaseRunner) ([]DeviceInfo, error) { + var moreDevices []DeviceInfo + port := runner.GetPort() + tick := time.Tick(10 * time.Millisecond) + for { + select { + case <-ctx.Done(): + return nil, fmt.Errorf("failed to finish discovery before timeout") + case <-tick: + r, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/info", port), nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + r.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(r) + if err != nil { + // slog.Warn("failed to send request", "error", err) + if runner.HasExited() { + return nil, fmt.Errorf("runner crashed") + } + continue + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusNotFound { + // old runner, fall back to bootstrapping model + return nil, fmt.Errorf("llamarunner free vram reporting not supported") + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + slog.Warn("failed to read response", "error", err) + continue + } + if resp.StatusCode != 200 { + logutil.Trace("runner failed to discover free VRAM", "status", resp.StatusCode, "response", body) + return nil, fmt.Errorf("runner error: %s", string(body)) + } + + if err := json.Unmarshal(body, &moreDevices); err != nil { + slog.Warn("unmarshal encode response", "error", err) + continue + } + return moreDevices, nil + } + } +} diff --git a/x/ml/nn/attention.go b/x/ml/nn/attention.go new file mode 100644 index 000000000..c4a16a302 --- /dev/null +++ b/x/ml/nn/attention.go @@ -0,0 +1,103 @@ +package nn + +import ( + "fmt" + + "github.com/ollama/ollama/x/kvcache" + "github.com/ollama/ollama/x/ml" +) + +// Attention implements scaled dot-product attention for transformer models: +// Attention(Q, K, V) = softmax(QK^T/√d_k)V +// +// Parameters: +// - ctx: Context for tensor operations +// - query: Query tensor (Q) with shape [d_k, heads, seq_len_q] +// - key: Key tensor (K) with shape [d_k, kv_heads, seq_len_k], can be nil to read from cache only +// - value: Value tensor (V) with shape [d_v, kv_heads, seq_len_k], can be nil to read from cache only +// - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension +// - cache: KV cache to store key/value and get past history, can be nil to only use provided key/value +// +// Returns: +// +// Attention output with shape [d_v, heads, seq_len_q] +func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor { + return AttentionWithVMLA(ctx, query, key, value, nil, nil, scale, cache) +} + +func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor { + return AttentionWithVMLA(ctx, query, key, value, sinks, nil, scale, cache) +} + +func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor { + ctx.Forward(query) + + if key != nil && value != nil { + if query.Dim(0) != key.Dim(0) { + panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0))) + } + + if key.Dim(1) != value.Dim(1) { + panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1))) + } + + if key.Dim(2) != value.Dim(2) { + panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2))) + } + + ctx.Forward(key, value) + if cache != nil { + cache.Put(ctx, key, value) + } + } else if cache == nil { + panic("key & value tensors must be provided if cache is nil") + } + + // ctx.CompareWith("/tmp/test", map[string]ml.Tensor{"q": query, "k": key, "v": value}, true) + // panic("after cache get") // + // 2025/12/10 16:02:33 INFO XXX tensors are similar q=0.9999869465827942 shape="[1 8 13 256]" min_difference=[-0.07926178] max_difference=[0.07012844] + // 2025/12/10 16:02:33 INFO XXX tensors are similar k=0.9999891519546509 shape="[1 4 13 256]" min_difference=[-0.21365738] max_difference=[0.19916534] + // 2025/12/10 16:02:33 INFO XXX tensors are similar v=0.9999960660934448 shape="[1 4 13 256]" min_difference=[-0.32923126] max_difference=[0.32646942] + + // var mask ml.Tensor + if cache != nil { + key, value, _ = cache.Get(ctx) + } + // ctx.CompareWith("/tmp/test", map[string]ml.Tensor{"q": query.Contiguous(ctx, false), "k": key.Contiguous(ctx, false), "v": value.Contiguous(ctx, false)}, true) + // panic("after cache get") // + // 2025/12/10 15:34:03 INFO XXX tensors are similar q=0.9999869465827942 shape="[1 8 13 256]" min_difference=[-0.07926178] max_difference=[0.07012844] + // 2025/12/10 15:34:03 INFO XXX tensors are similar k=0.9999881982803345 shape="[1 4 13 256]" min_difference=[-0.25] max_difference=[0.25] + // 2025/12/10 15:34:03 INFO XXX tensors are similar v=0.9999913573265076 shape="[1 4 13 256]" min_difference=[-0.5] max_difference=[0.5] + + // Only use the fast SDPA implementation if we have a cache, since that's what + // will do any expected backend-specific transformations for us + + if cache != nil { + // TODO what to do with vmla? + // return query.Transpose(ctx, 0, 2, 1, 3).ScaledDotProductAttention(ctx, key.Transpose(ctx, 0, 2, 1, 3), value.Transpose(ctx, 0, 2, 1, 3), scale, "array", mask, sinks) + return query.ScaledDotProductAttention(ctx, key, value, scale, "causal", nil, sinks) + + // TODO these two produce identical output, but not similar enough - 92.9% - should be 99.999% + } else { + panic("else case not supported") + // TODO transpose shapes are wrong + // key = key.Transpose(ctx, 0, 2, 1, 3) + // value = value.Transpose(ctx, 1, 2, 0, 3).Contiguous(ctx, false) + + // kq := query.Matmul(ctx, key) + + // kq = kq.Scale(ctx, scale) + // if mask != nil { + // kq = kq.Add(ctx, mask) + // } + // kq = kq.Softmax(ctx) + + // kqv := kq.Matmul(ctx, value) + + // if vmla != nil { + // kqv = kqv.Matmul(ctx, vmla) + // } + + // return kqv.Transpose(ctx, 0, 2, 1, 3).Contiguous(ctx, false) + } +} diff --git a/x/ml/nn/convolution.go b/x/ml/nn/convolution.go new file mode 100644 index 000000000..7c4b5a520 --- /dev/null +++ b/x/ml/nn/convolution.go @@ -0,0 +1,30 @@ +package nn + +import "github.com/ollama/ollama/x/ml" + +type Conv2D struct { + Weight ml.Tensor `gguf:"weight"` + Bias ml.Tensor `gguf:"bias"` +} + +func (m *Conv2D) Forward(ctx ml.Context, t ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor { + t = m.Weight.Conv2D(ctx, t, s0, s1, p0, p1, d0, d1, 1) + if m.Bias != nil { + // Bias shape is (out_channels,) while t shape is (width, height, out_channels, batch) + t = t.Add(ctx, m.Bias.Reshape(ctx, 1, 1, -1)) + } + return t +} + +type Conv3D struct { + Weight ml.Tensor `gguf:"weight"` + Bias ml.Tensor `gguf:"bias"` +} + +func (m *Conv3D) Forward(ctx ml.Context, t ml.Tensor, s0, s1, s2, p0, p1, p2, d0, d1, d2, g int) ml.Tensor { + t = m.Weight.Conv3D(ctx, t, s0, s1, s2, p0, p1, p2, d0, d1, d2, g) + if m.Bias != nil { + t = t.Add(ctx, m.Bias) + } + return t +} diff --git a/x/ml/nn/embedding.go b/x/ml/nn/embedding.go new file mode 100644 index 000000000..b00aa2ff1 --- /dev/null +++ b/x/ml/nn/embedding.go @@ -0,0 +1,11 @@ +package nn + +import "github.com/ollama/ollama/x/ml" + +type Embedding struct { + Weight ml.Tensor `gguf:"weight"` +} + +func (m *Embedding) Forward(ctx ml.Context, hiddenState ml.Tensor) ml.Tensor { + return m.Weight.TakeAxes(ctx, hiddenState, 0) +} diff --git a/x/ml/nn/linear.go b/x/ml/nn/linear.go new file mode 100644 index 000000000..6d108e095 --- /dev/null +++ b/x/ml/nn/linear.go @@ -0,0 +1,32 @@ +package nn + +import "github.com/ollama/ollama/x/ml" + +type Linear struct { + Weight ml.Tensor `gguf:"weight"` + Bias ml.Tensor `gguf:"bias"` +} + +func (m *Linear) Forward(ctx ml.Context, t ml.Tensor) ml.Tensor { + t = t.Matmul(ctx, m.Weight.Transpose(ctx)) + if m.Bias != nil { + t = t.Add(ctx, m.Bias) + } + + return t +} + +type LinearBatch struct { + Weight ml.Tensor `gguf:"weight"` + Bias ml.Tensor `gguf:"bias"` +} + +func (m *LinearBatch) Forward(ctx ml.Context, t, indices ml.Tensor) ml.Tensor { + panic("not yet ported") + // t = m.Weight.MulmatID(ctx, t, indices) + // if m.Bias != nil { + // t = t.AddID(ctx, m.Bias, indices) + // } + + // return t +} diff --git a/x/ml/nn/normalization.go b/x/ml/nn/normalization.go new file mode 100644 index 000000000..621245ab4 --- /dev/null +++ b/x/ml/nn/normalization.go @@ -0,0 +1,29 @@ +package nn + +import ( + "github.com/ollama/ollama/x/ml" +) + +type LayerNorm struct { + Weight ml.Tensor `gguf:"weight"` + Bias ml.Tensor `gguf:"bias"` +} + +func (m *LayerNorm) Forward(ctx ml.Context, t ml.Tensor, eps float32) ml.Tensor { + return t.LayerNorm(ctx, m.Weight, m.Bias, eps) +} + +type RMSNorm struct { + Weight ml.Tensor `gguf:"weight"` +} + +func (m *RMSNorm) Forward(ctx ml.Context, t ml.Tensor, eps float32) ml.Tensor { + // slog.Info("RMSNorm", "eps", eps) + // fmt.Fprintln(os.Stderr, t.ToString()) + // fmt.Fprintln(os.Stderr, m.Weight.ToString()) + + // TODO this is probably model specific, not generalized... + w := m.Weight.Add(ctx, ctx.FromFloats([]float32{1.0}, 1)) + + return t.RMSNorm(ctx, w, eps) +} diff --git a/x/ml/nn/pooling/pooling.go b/x/ml/nn/pooling/pooling.go new file mode 100644 index 000000000..2dae6dc43 --- /dev/null +++ b/x/ml/nn/pooling/pooling.go @@ -0,0 +1,41 @@ +package pooling + +import ( + "github.com/ollama/ollama/x/ml" +) + +type Type uint32 + +const ( + TypeNone Type = iota + TypeMean + TypeCLS + TypeLast +) + +func (t Type) String() string { + switch t { + case TypeMean: + return "Mean" + case TypeCLS: + return "CLS" + case TypeLast: + return "Last" + default: + return "Unknown" + } +} + +func (t Type) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor { + switch t { + // case TypeMean: + // hiddenStates = hiddenStates.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false).Mean(ctx) + // return hiddenStates.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false) + // case TypeCLS: + // return hiddenStates.Slice(ctx, 1, 0, 1, 1) + // case TypeLast: + // return hiddenStates.Slice(ctx, 1, hiddenStates.Dim(1)-1, hiddenStates.Dim(1), 1) + default: + panic("unknown pooling type") + } +} diff --git a/x/ml/nn/rope/rope.go b/x/ml/nn/rope/rope.go new file mode 100644 index 000000000..e868aa614 --- /dev/null +++ b/x/ml/nn/rope/rope.go @@ -0,0 +1,72 @@ +package rope + +import "github.com/ollama/ollama/x/ml" + +// Options contains optional parameters for RoPE function +type Options struct { + Type int + Factors ml.Tensor + + // YaRN options + YaRN struct { + OriginalContextLength int + ExtrapolationFactor, + AttentionFactor, + BetaFast, + BetaSlow float32 + } + + // MRoPE options + MRoPE struct { + Sections []int + } +} + +// WithTypeNeoX sets RoPE type to NeoX +func WithTypeNeoX() func(*Options) { + return func(opts *Options) { + opts.Type = 2 + } +} + +// WithFactors sets custom rope factors +func WithFactors(factors ml.Tensor) func(*Options) { + return func(opts *Options) { + if factors != nil { + opts.Factors = factors + } + } +} + +// WithOriginalContextLength sets a custom context length +func WithOriginalContextLength(n int) func(*Options) { + return func(opts *Options) { + opts.YaRN.OriginalContextLength = n + } +} + +func WithExtrapolationFactor(extrapolationFactor float32) func(*Options) { + return func(opts *Options) { + opts.YaRN.ExtrapolationFactor = extrapolationFactor + } +} + +func WithAttentionFactor(attentionFactor float32) func(*Options) { + return func(opts *Options) { + opts.YaRN.AttentionFactor = attentionFactor + } +} + +func WithMRoPE(sections []int) func(*Options) { + return func(opts *Options) { + opts.Type |= 1 << 3 + opts.MRoPE.Sections = sections + } +} + +func WithInterleaveMRoPE(sections []int) func(*Options) { + return func(opts *Options) { + opts.Type |= 1<<3 | 1<<5 + opts.MRoPE.Sections = sections + } +} diff --git a/x/ml/path.go b/x/ml/path.go new file mode 100644 index 000000000..ac93af403 --- /dev/null +++ b/x/ml/path.go @@ -0,0 +1,56 @@ +package ml + +import ( + "os" + "path/filepath" + "runtime" +) + +// LibPath is a path to lookup dynamic libraries +// in development it's usually 'build/lib/ollama' +// in distribution builds it's 'lib/ollama' on Windows +// '../lib/ollama' on Linux and the executable's directory on macOS +// note: distribution builds, additional GPU-specific libraries are +// found in subdirectories of the returned path, such as +// 'cuda_v12', 'rocm', etc. +var LibOllamaPath string = func() string { + exe, err := os.Executable() + if err != nil { + return "" + } + + if eval, err := filepath.EvalSymlinks(exe); err == nil { + exe = eval + } + + var libPath string + switch runtime.GOOS { + case "windows": + libPath = filepath.Join(filepath.Dir(exe), "lib", "ollama") + case "linux": + libPath = filepath.Join(filepath.Dir(exe), "..", "lib", "ollama") + case "darwin": + libPath = filepath.Dir(exe) + } + + cwd, err := os.Getwd() + if err != nil { + return "" + } + + paths := []string{ + libPath, + + // build paths for development + filepath.Join(filepath.Dir(exe), "build", "lib", "ollama"), + filepath.Join(cwd, "build", "lib", "ollama"), + } + + for _, p := range paths { + if _, err := os.Stat(p); err == nil { + return p + } + } + + return filepath.Dir(exe) +}() diff --git a/x/model/bytepairencoding.go b/x/model/bytepairencoding.go new file mode 100644 index 000000000..acb58743b --- /dev/null +++ b/x/model/bytepairencoding.go @@ -0,0 +1,282 @@ +package model + +import ( + "cmp" + "fmt" + "iter" + "log/slog" + "slices" + "strings" + + "github.com/dlclark/regexp2" + heap "github.com/emirpasic/gods/v2/trees/binaryheap" + "github.com/ollama/ollama/logutil" +) + +type BytePairEncoding struct { + vocab *Vocabulary + regexps []*regexp2.Regexp +} + +var _ TextProcessor = (*BytePairEncoding)(nil) + +func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding { + if len(pretokenizers) == 0 { + // set default byte-level pretokenizer if none provided, e.g. + // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44 + pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`} + } + + return BytePairEncoding{ + vocab: vocab, + regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) { + for _, p := range pretokenizers { + if !yield(regexp2.MustCompile(p, regexp2.RE2)) { + return + } + } + }), + } +} + +func (bpe BytePairEncoding) Vocabulary() *Vocabulary { + return bpe.vocab +} + +func (bpe BytePairEncoding) Is(id int32, special Special) bool { + return bpe.vocab.Is(id, special) +} + +func (bpe *BytePairEncoding) split(s string) iter.Seq[string] { + parts := []string{s} + for _, re := range bpe.regexps { + parts = slices.Collect(func(yield func(string) bool) { + for _, part := range parts { + r := []rune(part) + var offset int + for m, _ := re.FindRunesMatch(r); m != nil; m, _ = re.FindNextMatch(m) { + if offset-m.Index != 0 { + if !yield(string(r[:m.Index])) { + return + } + } + + if !yield(m.String()) { + return + } + + offset = m.Index + m.Length + } + + if offset < len(r) { + if !yield(string(r[offset:])) { + return + } + } + } + }) + } + + return slices.Values(parts) +} + +// fragment is a string fragment and their corresponding token IDs +type fragment struct { + value string + ids []int32 +} + +// pair is a pair of runes and its rank +type pair struct { + a, b int + rank int + value string +} + +type merge struct { + p, n int + runes []rune +} + +func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) { + fragments := []fragment{{value: s}} + for _, special := range bpe.vocab.SpecialVocabulary() { + // TODO: process special tokens concurrently + id := bpe.vocab.Encode(special) + for i := 0; i < len(fragments); i++ { + frag := fragments[i] + if len(frag.ids) > 0 { + continue + } + + var middle []fragment + switch i := strings.Index(frag.value, special); { + case i < 0: + middle = append(middle, frag) + case i > 0: + middle = append(middle, fragment{value: frag.value[:i]}) + fallthrough + default: + middle = append(middle, fragment{value: special, ids: []int32{id}}) + if rest := frag.value[i+len(special):]; rest != "" { + middle = append(middle, fragment{value: rest}) + } + } + + fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...) + } + } + + var ids []int32 + for _, frag := range fragments { + if len(frag.ids) > 0 { + ids = append(ids, frag.ids...) + continue + } + + for split := range bpe.split(frag.value) { + // TODO: process splits concurrently + var sb strings.Builder + for _, b := range []byte(split) { + r := rune(b) + switch { + case r == 0x00ad: + r = 0x0143 + case r <= 0x0020: + r = r + 0x0100 + case r >= 0x007f && r <= 0x00a0: + r = r + 0x00a2 + } + + sb.WriteRune(r) + } + + // short circuit if the fragment is in the vocabulary + if id := bpe.vocab.Encode(sb.String()); id >= 0 { + ids = append(ids, id) + continue + } + + runes := []rune(sb.String()) + merges := make([]merge, len(runes)) + for r := range runes { + merges[r] = merge{ + p: r - 1, + n: r + 1, + runes: []rune{runes[r]}, + } + } + + pairwise := func(a, b int) *pair { + if a < 0 || b >= len(runes) { + return nil + } + + left, right := string(merges[a].runes), string(merges[b].runes) + rank := bpe.vocab.Merge(left, right) + if rank < 0 { + return nil + } + + return &pair{ + a: a, + b: b, + rank: rank, + value: left + right, + } + } + + pairs := heap.NewWith(func(i, j *pair) int { + return cmp.Compare(i.rank, j.rank) + }) + + for i := range len(runes) - 1 { + if pair := pairwise(i, i+1); pair != nil { + pairs.Push(pair) + } + } + + for !pairs.Empty() { + pair, _ := pairs.Pop() + + left, right := merges[pair.a], merges[pair.b] + if len(left.runes) == 0 || len(right.runes) == 0 || + string(left.runes)+string(right.runes) != pair.value { + continue + } + + if id := bpe.vocab.Encode(pair.value); id < 0 { + continue + } + + merges[pair.a].runes = append(left.runes, right.runes...) + merges[pair.b].runes = nil + + merges[pair.a].n = right.n + if right.n < len(merges) { + merges[right.n].p = pair.a + } + + if pair := pairwise(merges[pair.a].p, pair.a); pair != nil { + pairs.Push(pair) + } + + if pair := pairwise(pair.a, merges[pair.a].n); pair != nil { + pairs.Push(pair) + } + } + + for _, merge := range merges { + if len(merge.runes) > 0 { + // TODO: handle the edge case where the rune isn't in the vocabulary + if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 { + ids = append(ids, id) + } + } + } + } + } + + if addSpecial { + ids = bpe.vocab.addSpecials(ids) + } + + logutil.Trace("encoded", "string", s, "ids", ids) + return ids, nil +} + +type lazyIdsString struct { + ids []int32 +} + +func (l lazyIdsString) LogValue() slog.Value { + return slog.AnyValue(fmt.Sprint(l.ids)) +} + +func (bpe BytePairEncoding) Decode(ids []int32) (string, error) { + var sb strings.Builder + for _, id := range ids { + for _, r := range bpe.vocab.Decode(id) { + switch { + case r == 0x0100: + // this produces 0x00 aka NULL + continue + case r == 0x0143: + r = 0x00ad + case r > 0x0100 && r <= 0x0120: + r = r - 0x0100 + case r > 0x0120 && r <= 0x0142: + r = r - 0x00a2 + } + + // NOTE: not using WriteRune here because it writes the UTF-8 + // encoding of the rune which is _not_ what we want + if err := sb.WriteByte(byte(r)); err != nil { + return "", err + } + } + } + + logutil.Trace("decoded", "string", sb.String(), "from", lazyIdsString{ids: ids}) + return sb.String(), nil +} diff --git a/x/model/bytepairencoding_test.go b/x/model/bytepairencoding_test.go new file mode 100644 index 000000000..2a7041284 --- /dev/null +++ b/x/model/bytepairencoding_test.go @@ -0,0 +1,322 @@ +package model + +import ( + "bufio" + "encoding/json" + "math" + "os" + "path/filepath" + "slices" + "strconv" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func llama(t testing.TB) BytePairEncoding { + t.Helper() + + f, err := os.Open(filepath.Join("..", "..", "model", "testdata", "llama3.2", "encoder.json")) + if err != nil { + t.Fatal(err) + } + defer f.Close() + + vocab := make(map[string]int32) + if err := json.NewDecoder(f).Decode(&vocab); err != nil { + t.Fatal(err) + } + + types := make([]int32, len(vocab)) + tokens := make([]string, len(vocab)) + for token, id := range vocab { + tokens[id] = token + types[id] = 1 + } + + for _, token := range []string{"<|begin_of_text|>", "<|end_of_text|>"} { + if _, ok := vocab[token]; !ok { + tokens = append(tokens, token) //nolint:makezero + types = append(types, 3) //nolint:makezero + vocab[token] = int32(len(vocab)) + } + } + + f, err = os.Open(filepath.Join("..", "..", "model", "testdata", "llama3.2", "vocab.bpe")) + if err != nil { + t.Fatal(err) + } + defer f.Close() + + merges := make([]string, 0, 50000) + + scanner := bufio.NewScanner(f) + for scanner.Scan() { + if !strings.HasPrefix(scanner.Text(), "#") { + merges = append(merges, scanner.Text()) + } + } + + return NewBytePairEncoding( + &Vocabulary{ + Values: tokens, + Types: types, + Merges: merges, + }, + "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + ) +} + +func TestLlama(t *testing.T) { + tokenizer := llama(t) + + t.Run("simple", func(t *testing.T) { + t.Parallel() + + ids, err := tokenizer.Encode("hello world", true) + if err != nil { + t.Error(err) + } + + if diff := cmp.Diff([]int32{15339, 1917}, ids); diff != "" { + t.Errorf("no match (-theirs +ours):\n%s", diff) + } + + s, err := tokenizer.Decode([]int32{15339, 1917}) + if err != nil { + t.Fatal(err) + } + + if s != "hello world" { + t.Errorf("got %q, want hello world", s) + } + + ids, err = tokenizer.Encode("hello <|end_of_text|>", true) + if err != nil { + t.Error(err) + } + + if diff := cmp.Diff([]int32{15339, 220, 128001}, ids); diff != "" { + t.Errorf("no match (-theirs +ours):\n%s", diff) + } + }) + + t.Run("simple repeated", func(t *testing.T) { + t.Parallel() + + cases := map[string][]int32{ + strings.Repeat("0", 1): {15}, + strings.Repeat("0", 2): {410}, + strings.Repeat("0", 3): {931}, + strings.Repeat("0", 4): {931, 15}, + strings.Repeat("0", 5): {931, 410}, + strings.Repeat("0", 6): {931, 931}, + strings.Repeat("0", 7): {931, 931, 15}, + strings.Repeat("0", 8): {931, 931, 410}, + strings.Repeat("0", 9): {931, 931, 931}, + strings.Repeat("0", 10): {931, 931, 931, 15}, + strings.Repeat("0", 11): {931, 931, 931, 410}, + strings.Repeat("0", 12): {931, 931, 931, 931}, + strings.Repeat("0", 13): {931, 931, 931, 931, 15}, + strings.Repeat("0", 14): {931, 931, 931, 931, 410}, + strings.Repeat("0", 15): {931, 931, 931, 931, 931}, + strings.Repeat("0", 16): {931, 931, 931, 931, 931, 15}, + strings.Repeat("0", 17): {931, 931, 931, 931, 931, 410}, + } + + for s, want := range cases { + ids, err := tokenizer.Encode(s, true) + if err != nil { + t.Error(err) + } + + if diff := cmp.Diff(want, ids); diff != "" { + t.Errorf("%q no match (-theirs +ours):\n%s", s, diff) + } + } + }) + + t.Run("basic roundtrip", func(t *testing.T) { + t.Parallel() + + cases := []string{ + "hello", + "hello ", + "hello ", + " hello", + " hello ", + " hello ", + "hello world", + "请考试我的软件!12345", + } + + for _, want := range cases { + ids, err := tokenizer.Encode(want, true) + if err != nil { + t.Error(err) + } + + if got, err := tokenizer.Decode(ids); err != nil { + t.Fatal(err) + } else if got != want { + t.Errorf("got %q, want %q", got, want) + } + } + }) + + t.Run("special", func(t *testing.T) { + t.Parallel() + + cases := map[string][]int32{ + "<|begin_of_text|>A B!": {128000, 32, 426, 0}, + "<|begin_of_text|>A<|end_of_text|>B!": {128000, 32, 128001, 33, 0}, + "<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!": {128000, 32, 128001, 33, 128000, 0}, + "<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!<|end_of_text|>": {128000, 32, 128001, 33, 128000, 0, 128001}, + } + + for s, want := range cases { + ids, err := tokenizer.Encode(s, true) + if err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(want, ids); diff != "" { + t.Errorf("no match (-theirs +ours):\n%s", diff) + } + } + }) + + t.Run("split", func(t *testing.T) { + t.Parallel() + + cases := map[string][]string{ + "Hello World!": {"Hello", " World", "!"}, + "I'm don't won't": {"I", "'m", " don", "'t", " won", "'t"}, + "In 2024 there are 366 days": {"In", " ", "202", "4", " there", " are", " ", "366", " days"}, + "Hello!! ...world": {"Hello", "!!", " ...", "world"}, + "Hello World": {"Hello", " ", " World"}, + "Hello\nWorld": {"Hello", "\n", "World"}, + "Hello, WORLD!! How's it going?": {"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?"}, + } + + for s, want := range cases { + got := slices.Collect(tokenizer.split(s)) + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("no match (-theirs +ours):\n%s", diff) + } + } + }) + + t.Run("roundtriping 0x00-0xFF", func(t *testing.T) { + t.Parallel() + + for b := 0x00; b <= 0xFF; b++ { + input := string(rune(b)) + ids, err := tokenizer.Encode(input, false) + if err != nil { + t.Errorf("failed to encode rune 0x%02X: %v", b, err) + continue + } + + decoded, err := tokenizer.Decode(ids) + if err != nil { + t.Errorf("failed to decode rune 0x%02X: %v", b, err) + continue + } + + if b == 0x00 { + if len(decoded) != 0 { + t.Errorf("Decode(Encode(0x00)) should be empty, got %v", ids) + } + continue + } + + if decoded != input { + t.Errorf("rune 0x%02X failed roundtrip: got %q, want %q", b, decoded, input) + } + } + }) +} + +func BenchmarkBytePairEncoding(b *testing.B) { + tokenizer := llama(b) + bts, err := os.ReadFile(filepath.Join("testdata", "war-and-peace.txt")) + if err != nil { + b.Fatal(err) + } + + for i := range 8 { + n := min(int(math.Pow10(i)), len(bts)) + bts := bts[:n] + b.Run("encode"+strconv.Itoa(n), func(b *testing.B) { + b.ResetTimer() + for b.Loop() { + _, err := tokenizer.Encode(string(bts), true) + if err != nil { + b.Fatal(err) + } + } + }) + + b.Run("decode"+strconv.Itoa(n), func(b *testing.B) { + ids, err := tokenizer.Encode(string(bts), true) + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + for b.Loop() { + _, err := tokenizer.Decode(ids) + if err != nil { + b.Fatal(err) + } + } + }) + + b.Run("split"+strconv.Itoa(n), func(b *testing.B) { + b.ResetTimer() + for b.Loop() { + slices.Collect(tokenizer.split(string(bts))) + } + }) + } +} + +func TestSplit(t *testing.T) { + cases := []struct { + name string + patterns, + want []string + }{ + { + name: "default", + want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " 123", " 一二三"}, + }, + { + name: "unicode", + patterns: []string{ + "\\p{N}{1,3}", + `[一-龥぀-ゟ゠-ヿ]+`, + "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+", + }, + want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "123", " ", "一二三"}, + }, + { + name: "individual digits", + patterns: []string{ + "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }, + want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "1", "2", "3", " 一二三"}, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + tokenizer := NewBytePairEncoding(nil, tt.patterns...) + if diff := cmp.Diff(tt.want, slices.Collect(tokenizer.split("Hello, WORLD!! How's it going? 123 一二三"))); diff != "" { + t.Errorf("no match (-theirs +ours):\n%s", diff) + } + }) + } +} diff --git a/x/model/input/input.go b/x/model/input/input.go new file mode 100644 index 000000000..05857e20a --- /dev/null +++ b/x/model/input/input.go @@ -0,0 +1,76 @@ +package input + +import "github.com/ollama/ollama/x/ml" + +// Multimodal is a multimodal embedding or a component of one. +// For example, it could be a row of an image that can be processed +// independently. +type Multimodal struct { + // Tensor is the embedding data. Implementations may chose what to + // store here or it may be nil if not needed. However, any ml.Tensor + // objects must be stored here and not in Data. + Tensor ml.Tensor + + // Data is implementation-specific opaque data, such as metadata on how + // to layout Tensor. It may be nil if not needed. It may also store larger + // objects such as complete images if they are to be processed later. + Data any +} + +// Input represents one token in the input stream +type Input struct { + // Token is a single element of text. + Token int32 + + // Multimodal is represents a non-text element such as an + // image (or part of one if the image can be processed in pieces). + // It may be used either together with Token or on its own. + Multimodal []Multimodal + + // MultimodalHash is a unique representation of the data + // stored in Multimodal, used for caching and comparing + // equality. + MultimodalHash uint64 + + // SameBatch forces the following number of tokens to be processed + // in a single batch, breaking and extending batches as needed. + // Useful for things like images that must be processed in one + // shot. + SameBatch int +} + +// MultimodalIndex is a multimodal element (such as an image) +// together with an index into the slice of Inputs with the +// corresponding token. Note that the index is not the same +// as the position - to find that use the index with the +// Positions slice. +type MultimodalIndex struct { + Index int + Multimodal []Multimodal +} + +// Batch contains the inputs for a model forward pass +type Batch struct { + // Inputs is the input tokens, including placeholders for multimodal inputs. + Inputs ml.Tensor + + // Outputs are the set of indicies into Inputs for which output data should + // be returned. + Outputs ml.Tensor + + // TODO maybe not the optimal way to handle this + // Offset of final tensor in the final batch + Offset int + + // Positions is the position for each Input, relative to its sequence. Equal + // in length to Inputs. + Positions []int32 + + // Sequences is the sequence for each Input. Equal in length to Inputs. + Sequences []int + + // Multimodal is a set of multimodal embeddings previously created by + // EncodeMultimodal, along with an index into Inputs. Unused for text-only + // models or for batches without multimodal elements. + Multimodal []MultimodalIndex +} diff --git a/x/model/model.go b/x/model/model.go new file mode 100644 index 000000000..60c3d1487 --- /dev/null +++ b/x/model/model.go @@ -0,0 +1,333 @@ +package model + +import ( + "errors" + "fmt" + _ "image/jpeg" + _ "image/png" + "log/slog" + "os" + "reflect" + "strconv" + "strings" + + _ "golang.org/x/image/bmp" + _ "golang.org/x/image/tiff" + _ "golang.org/x/image/webp" + + "github.com/ollama/ollama/fs" + fsggml "github.com/ollama/ollama/fs/ggml" + "github.com/ollama/ollama/logutil" + "github.com/ollama/ollama/x/kvcache" + "github.com/ollama/ollama/x/ml" + _ "github.com/ollama/ollama/x/ml/backend" + "github.com/ollama/ollama/x/ml/nn/pooling" + "github.com/ollama/ollama/x/model/input" +) + +var ( + ErrNoVisionModel = errors.New("this model is missing data required for image input") + ErrUnsupportedModel = errors.New("model not supported") + ErrUnsupportedTokenizer = errors.New("tokenizer not supported") +) + +// Model implements a specific model architecture, defining the forward pass and any model-specific configuration +type Model interface { + Forward(ml.Context, input.Batch) (ml.Tensor, error) + + Backend() ml.Backend + Config() config +} + +// MultimodalProcessor must be implemented by multimodal models. +type MultimodalProcessor interface { + // EncodeMultimodal processes a single input (such as an image) and + // generates an output (typically an embedding) that can be used by the model. + // + // The return value is one or more tensors, each with optional model-specific + // opaque metadata. Typically, the tensors might be views into an embedding + // with each view representing a chunk of data that can be processed independently + // in different batches. + // + // The result may be cached by the runner. + EncodeMultimodal(ml.Context, []byte) ([]input.Multimodal, error) + + // PostTokenize is called after tokenization to allow the model to edit the + // input stream to correctly arrange multimodal elements. + // + // The input is a slice of tokens with the results of EncodeMultimodal interleaved + // in the order that the user provided them. Each element of the slice will be + // either a single token or single multimodal object. + // + // The model must ensure that inputs are stored according to how they will be + // processed and stored in the cache. For example, Llava-style models should insert + // placeholder tokens equal to the feature size of the corresponding image with + // the image itself attached to and split across these tokens. When Forward is called + // a partial subset of these tokens may be submitted according to the batch size. + // + // This function is also responsible for updating MultimodalHash for any Multimodal + // that is modified to ensure that there is a unique hash value that accurately + // represents the contents. + PostTokenize([]*input.Input) ([]*input.Input, error) +} + +// Base implements the common fields and methods for all models +type Base struct { + b ml.Backend + config +} + +type config struct { + Cache kvcache.Cache +} + +// Backend returns the underlying backend that will run the model +func (m *Base) Backend() ml.Backend { + return m.b +} + +func (m *Base) Config() config { + return m.config +} + +var models = make(map[string]func(fs.Config) (Model, error)) + +// Register registers a model constructor for the given architecture +func Register(name string, f func(fs.Config) (Model, error)) { + if _, ok := models[name]; ok { + panic("model: model already registered") + } + + models[name] = f +} + +// New initializes a new model instance with the provided configuration based on the metadata in the model file +func New(modelPath string, params ml.BackendParams) (Model, error) { + b, err := ml.NewBackend(modelPath, params) + if err != nil { + return nil, err + } + + m, err := modelForArch(b.Config()) + if err != nil { + return nil, err + } + + base := Base{b: b, config: m.Config()} + v := reflect.ValueOf(m) + v.Elem().Set(populateFields(base, v.Elem())) + return m, nil +} + +func NewTextProcessor(s string) (TextProcessor, error) { + r, err := os.Open(s) + if err != nil { + return nil, err + } + defer r.Close() + + meta, err := fsggml.Decode(r, -1) + if err != nil { + return nil, err + } + + m, err := modelForArch(meta.KV()) + if err != nil { + return nil, err + } + + tp, ok := m.(TextProcessor) + if !ok { + return nil, ErrUnsupportedTokenizer + } + return tp, nil +} + +func modelForArch(c fs.Config) (Model, error) { + arch := c.Architecture() + if pooling.Type(c.Uint("pooling_type")) != pooling.TypeNone { + arch = arch + "_embed" + } + + f, ok := models[arch] + if !ok { + return nil, ErrUnsupportedModel + } + + return f(c) +} + +func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value { + t := v.Type() + + if t.Kind() == reflect.Struct { + allNil := true + for i := range t.NumField() { + tt := t.Field(i).Type + vv := v.Field(i) + if !vv.CanSet() { + continue + } + + // make a copy + tagsCopy := tags + if tag := t.Field(i).Tag.Get("gguf"); tag != "" { + tagsCopy = append(tagsCopy, parseTag(tag)) + } + + if tt == reflect.TypeOf((*Base)(nil)).Elem() { + vv.Set(reflect.ValueOf(base)) + } else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() { + var fn func([]Tag, string, string) [][]string + fn = func(tags []Tag, prefix, suffix string) (fullNames [][]string) { + if len(tags) > 0 { + var names []string + if tags[0].name != "" { + for _, n := range append([]string{tags[0].name}, tags[0].alternatives...) { + names = append(names, prefix+n+suffix) + } + } + childNames := fn(tags[1:], tags[0].prefix, tags[0].suffix) + if len(names) == 0 { + // current tag has no name, use child names only + fullNames = append(fullNames, childNames...) + } else if len(childNames) == 0 { + // current tag has names but no children, create branches for each name + for _, name := range names { + fullNames = append(fullNames, []string{name}) + } + } else { + // merge each name with each child + for _, name := range names { + for _, childName := range childNames { + fullNames = append(fullNames, append([]string{name}, childName...)) + } + } + } + } + + return fullNames + } + + names := fn(tagsCopy, "", "") + for _, name := range names { + if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil { + logutil.Trace("found tensor", "", tensor) + vv.Set(reflect.ValueOf(tensor)) + break + } + } + } else if tt.Kind() == reflect.Pointer || tt.Kind() == reflect.Interface { + setPointer(base, vv, tagsCopy) + } else if tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array { + for i := range vv.Len() { + vvv := vv.Index(i) + if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface { + setPointer(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)})) + } else { + vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)})...)) + } + } + } + + if !canNil(tt) || !vv.IsNil() { + allNil = false + } + } + + if allNil { + return reflect.Zero(t) + } + } + + return v +} + +func setPointer(base Base, v reflect.Value, tags []Tag) { + vv := v + if v.Kind() == reflect.Interface { + if v.IsNil() { + return + } + + vv = vv.Elem() + } + + vv = reflect.Indirect(vv) + if v.IsNil() { + vv = reflect.New(v.Type().Elem()).Elem() + } + + if f := populateFields(base, vv, tags...); f.CanAddr() { + v.Set(f.Addr()) + } +} + +type Tag struct { + name, + // prefix and suffix are applied to child tags + prefix, + suffix string + alternatives []string +} + +func parseTag(s string) (tag Tag) { + parts := strings.Split(s, ",") + if len(parts) > 0 { + tag.name = parts[0] + + for _, part := range parts[1:] { + if value, ok := strings.CutPrefix(part, "alt:"); ok && tag.name == "" { + // elevate alternative to primary if no primary given + tag.name = value + slog.Warn("gguf tag has alt: but no primary name", "tag", s) + } else if ok { + tag.alternatives = append(tag.alternatives, value) + } + if value, ok := strings.CutPrefix(part, "pre:"); ok { + tag.prefix = value + } + if value, ok := strings.CutPrefix(part, "suf:"); ok { + tag.suffix = value + } + } + } + + return +} + +func canNil(t reflect.Type) bool { + return t.Kind() == reflect.Chan || + t.Kind() == reflect.Func || + t.Kind() == reflect.Interface || + t.Kind() == reflect.Map || + t.Kind() == reflect.Pointer || + t.Kind() == reflect.Slice +} + +func Forward(ctx ml.Context, m Model, batch input.Batch) (ml.Tensor, error) { + if len(batch.Positions) != len(batch.Sequences) { + return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences)) + } + + if len(batch.Positions) < 1 { + return nil, errors.New("batch size cannot be less than 1") + } + + cache := m.Config().Cache + if cache != nil { + err := cache.StartForward(ctx, batch, false) + if err != nil { + return nil, err + } + } + + t, err := m.Forward(ctx, batch) + if err != nil { + return nil, err + } + + ctx.Forward(t) + + return t, nil +} diff --git a/x/model/models/gemma3/embed.go b/x/model/models/gemma3/embed.go new file mode 100644 index 000000000..229cbcb50 --- /dev/null +++ b/x/model/models/gemma3/embed.go @@ -0,0 +1,58 @@ +//go:build mlx + +package gemma3 + +import ( + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/x/ml" + "github.com/ollama/ollama/x/ml/nn" + "github.com/ollama/ollama/x/ml/nn/pooling" + "github.com/ollama/ollama/x/model" + "github.com/ollama/ollama/x/model/input" +) + +type embedModel struct { + model.Base + model.SentencePiece + + *TextModel + poolingType pooling.Type + + Dense [2]*nn.Linear `gguf:"dense"` +} + +func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { + hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache) + hiddenStates = m.poolingType.Forward(ctx, hiddenStates) + for _, dense := range m.Dense { + hiddenStates = dense.Forward(ctx, hiddenStates) + } + hiddenStates = hiddenStates.L2Norm(ctx, 1e-12) + return hiddenStates, nil +} + +func newEmbedModel(c fs.Config) (model.Model, error) { + m := &embedModel{ + SentencePiece: model.NewSentencePiece( + &model.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Scores: c.Floats("tokenizer.ggml.scores"), + Types: c.Ints("tokenizer.ggml.token_type"), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{ + int32(c.Uint("tokenizer.ggml.eos_token_id")), + int32(c.Uint("tokenizer.ggml.eot_token_id", 106)), + }, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), + }, + ), + TextModel: newTextModel(c), + poolingType: pooling.Type(c.Uint("pooling_type", 0)), + } + + return m, nil +} diff --git a/x/model/models/gemma3/model.go b/x/model/models/gemma3/model.go new file mode 100644 index 000000000..23f78f207 --- /dev/null +++ b/x/model/models/gemma3/model.go @@ -0,0 +1,157 @@ +//go:build mlx + +package gemma3 + +import ( + "bytes" + "image" + "math" + "slices" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/x/kvcache" + "github.com/ollama/ollama/x/ml" + "github.com/ollama/ollama/x/ml/nn" + "github.com/ollama/ollama/x/model" + "github.com/ollama/ollama/x/model/input" +) + +type Model struct { + model.Base + model.SentencePiece + + *VisionModel `gguf:"vision_tower.vision_model"` + *TextModel `gguf:"language_model.model"` + + *MultiModalProjector `gguf:"multi_modal_projector"` + + ImageProcessor +} + +var _ model.MultimodalProcessor = (*Model)(nil) + +type MultiModalProjector struct { + SoftEmbNorm *nn.RMSNorm `gguf:"mm_soft_emb_norm"` + InputProjection *nn.Linear `gguf:"mm_input_projection_weight"` // TODO .weight vs _weight + + tokensPerImage int +} + +func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, imageSize, patchSize int, eps float32) ml.Tensor { + l := visionOutputs.Dim(0) + + visionOutputs = visionOutputs.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false) + patchesPerImage := imageSize / patchSize + visionOutputs = visionOutputs.Reshape(ctx, patchesPerImage, patchesPerImage, l) + + kernelSize := patchesPerImage / int(math.Sqrt(float64(p.tokensPerImage))) + visionOutputs = visionOutputs.AvgPool2D(ctx, kernelSize, kernelSize, 0) + visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0)*visionOutputs.Dim(1), l) + visionOutputs = visionOutputs.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false) + visionOutputs = p.SoftEmbNorm.Forward(ctx, visionOutputs, eps) + + // TODO: inputProjection must be transposed since they're incompatible with visionOutputs + visionOutputs = visionOutputs.Matmul(ctx, p.InputProjection.Weight.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false)) + return visionOutputs +} + +func New(c fs.Config) (model.Model, error) { + // slog.Info("XXX Config", "c", c) + m := Model{ + SentencePiece: model.NewSentencePiece( + &model.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Scores: c.Floats("tokenizer.ggml.scores"), + Types: c.Ints("tokenizer.ggml.token_type"), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{ + int32(c.Uint("tokenizer.ggml.eos_token_id")), + int32(c.Uint("tokenizer.ggml.eot_token_id", 106)), + }, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), + }, + ), + ImageProcessor: newImageProcessor(c), + VisionModel: newVisionModel(c), + TextModel: newTextModel(c), + MultiModalProjector: &MultiModalProjector{ + tokensPerImage: int(c.Uint("mm_tokens_per_image", 256)), + }, + } + + // slidingWindowLen := int32(c.Uint("attention.sliding_window")) + // m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift)) + + // TODO need to implement sliding window... + m.Cache = kvcache.NewMLXCausalCache() + + return &m, nil +} + +func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) { + if len(m.VisionModel.Layers) == 0 { + return nil, model.ErrNoVisionModel + } + + image, _, err := image.Decode(bytes.NewReader(multimodalData)) + if err != nil { + return nil, err + } + + f32s, err := m.ImageProcessor.ProcessImage(image) + if err != nil { + return nil, err + } + + pixelValues := ctx.Input().FromFloats(f32s, + m.ImageProcessor.imageSize, + m.ImageProcessor.imageSize, + m.ImageProcessor.numChannels, + ) + + visionOutputs := m.VisionModel.Forward(ctx, pixelValues) + visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.imageSize, m.patchSize, m.VisionModel.eps) + return []input.Multimodal{{Tensor: visionOutputs}}, nil +} + +func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) { + var result []*input.Input + + for _, inp := range inputs { + if len(inp.Multimodal) == 0 { + result = append(result, inp) + } else { + inputMultimodal := inp.Multimodal[0].Tensor + + result = append(result, + &input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n" + &input.Input{Token: 255999}, // """ + &input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder + ) + + // add image token placeholders + result = append(result, slices.Repeat([]*input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...) + + result = append(result, + &input.Input{Token: 256000}, // + &input.Input{Token: 108}, // "\n\n" + ) + } + } + + return result, nil +} + +func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { + hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache) + return m.Output.Forward(ctx, hiddenStates), nil +} + +func init() { + model.Register("gemma3", New) + model.Register("gemma3_embed", newEmbedModel) +} diff --git a/x/model/models/gemma3/model_text.go b/x/model/models/gemma3/model_text.go new file mode 100644 index 000000000..d7686542a --- /dev/null +++ b/x/model/models/gemma3/model_text.go @@ -0,0 +1,211 @@ +//go:build mlx + +package gemma3 + +import ( + "math" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/x/kvcache" + "github.com/ollama/ollama/x/ml" + "github.com/ollama/ollama/x/ml/nn" + "github.com/ollama/ollama/x/model/input" +) + +type TextConfig struct { + hiddenSize, numHeads, numKVHeads int + attnKeyLen int + eps, ropeScale float32 + ropeLocalBase, ropeGlobalBase float32 + largeModelScaling bool +} + +type TextModel struct { + TokenEmbedding *nn.Embedding `gguf:"embed_tokens"` + Layers []TextLayer `gguf:"layers"` + OutputNorm *nn.RMSNorm `gguf:"norm"` + Output *nn.Linear `gguf:"embed_tokens"` + + *TextConfig +} + +const ( + gemmaGlobalCacheCount = 6 + gemma27BLayerCount = 62 +) + +// const ( +// cacheTypeSWA = iota +// cacheTypeCausal +// ) + +func newTextModel(c fs.Config) *TextModel { + numBlocks := int(c.Uint("block_count")) + + m := TextModel{ + Layers: make([]TextLayer, numBlocks), + TextConfig: &TextConfig{ + hiddenSize: int(c.Uint("embedding_length")), // 2560 -- config.json: text_config.hidden_size + numHeads: int(c.Uint("attention.head_count")), // 8 -- hard coded in python implementation for the model, 4 in some places, then overridden as 8 + numKVHeads: int(c.Uint("attention.head_count_kv")), // 4 -- same as above + attnKeyLen: int(c.Uint("attention.key_length", 256)), //256 -- rope settings, hardcoded in model definition python + eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06), // 1e-06 - hardcoded in model definition python + ropeLocalBase: c.Float("rope.local.freq_base", 10000.0), // 10000 - hardcoded in python + ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0), // 1e+06 - hardcoded in python + ropeScale: 1, // 1 - default is 1, implied in python code + // vocabSize: vocabSize, // 262144 + // attnValLen: int(c.Uint("attention.value_length", 256)), //256 + // NOTE: the rope.scaling.factor is set incorrectly in the official QAT weights + // (8 instead of 1) + // ropeScale: c.Float("rope.scaling.factor", 1.0), + }, + } + if numBlocks == gemma27BLayerCount { + m.largeModelScaling = true + } + + return &m +} + +type TextSelfAttention struct { + Query *nn.Linear `gguf:"q_proj"` + QueryNorm *nn.RMSNorm `gguf:"q_norm"` + Key *nn.Linear `gguf:"k_proj"` + KeyNorm *nn.RMSNorm `gguf:"k_norm"` + Value *nn.Linear `gguf:"v_proj"` + Output *nn.Linear `gguf:"o_proj"` +} + +func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState ml.Tensor, offset int, cache kvcache.Cache, opts *TextConfig) ml.Tensor { + B := hiddenState.Dim(0) + L := hiddenState.Dim(1) + ropeBase := opts.ropeLocalBase + if (layer+1)%gemmaGlobalCacheCount == 0 { + ropeBase = opts.ropeGlobalBase + } + + q := sa.Query.Forward(ctx, hiddenState) + k := sa.Key.Forward(ctx, hiddenState) + v := sa.Value.Forward(ctx, hiddenState) + q = q.Reshape(ctx, B, L, opts.numHeads, -1).Transpose(ctx, 0, 2, 1, 3) + k = k.Reshape(ctx, B, L, opts.numKVHeads, -1).Transpose(ctx, 0, 2, 1, 3) + v = v.Reshape(ctx, B, L, opts.numKVHeads, -1).Transpose(ctx, 0, 2, 1, 3).Contiguous(ctx, false) + q = sa.QueryNorm.Forward(ctx, q, opts.eps) + k = sa.KeyNorm.Forward(ctx, k, opts.eps) + traditional := false + q = q.RoPE(ctx, opts.attnKeyLen, traditional, opts.ropeScale, offset, ml.WithRoPEBase(ropeBase)) + k = k.RoPE(ctx, opts.attnKeyLen, traditional, opts.ropeScale, offset, ml.WithRoPEBase(ropeBase)) + + // TODO - this is wrong somehow so commenting out + // if opts.largeModelScaling { + // q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads))) + // } else { + // q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen))) + // } + + scaleFactor := math.Pow(256, -0.5) + + kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache) + kqv = kqv.Transpose(ctx, 0, 2, 1, 3).Reshape(ctx, B, L, -1) + return sa.Output.Forward(ctx, kqv) +} + +func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { + // ropeBase := m.TextConfig.ropeLocalBase + // if (layer+1)%gemmaGlobalCacheCount == 0 { + // ropeBase = m.TextConfig.ropeGlobalBase + // } + // q = q.RoPE(ctx, opts.attnKeyLen, traditional, opts.ropeScale, offset, ml.WithRoPEBase(ropeBase)) + panic("not yet implemented") + // return key.RoPE(ctx, shift, m.TextConfig.attnKeyLen, ropeBase, 1/m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil +} + +type TextMLP struct { + Up *nn.Linear `gguf:"up_proj"` + Down *nn.Linear `gguf:"down_proj"` + Gate *nn.Linear `gguf:"gate_proj"` +} + +func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor { + hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState)) + return mlp.Down.Forward(ctx, hiddenState) +} + +type TextLayer struct { + AttentionNorm *nn.RMSNorm `gguf:"input_layernorm"` + SelfAttention *TextSelfAttention `gguf:"self_attn"` + PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_layernorm"` + MLPNorm *nn.RMSNorm `gguf:"pre_feedforward_layernorm"` + MLP *TextMLP `gguf:"mlp"` + PostMLPNorm *nn.RMSNorm `gguf:"post_feedforward_layernorm"` +} + +func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, outputs ml.Tensor, offset int, cache kvcache.Cache, opts *TextConfig) ml.Tensor { + residual := hiddenState + hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps) + hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, offset, cache, opts) + hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps) + + // In the final layer (outputs != nil), optimize by pruning to just the token positions + // we need logits for. + if outputs != nil { + hiddenState = hiddenState.TakeAxes(ctx, outputs, 1) + residual = residual.TakeAxes(ctx, outputs, 1) + } + + hiddenState = hiddenState.Add(ctx, residual) + residual = hiddenState + hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps) + hiddenState = l.MLP.Forward(ctx, hiddenState, opts) // TODO this is where it goes bad most likely... + hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps) + return hiddenState.Add(ctx, residual) +} + +func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor { + hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs) + hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize))) + + // set image embeddings + // var except []int + // for _, image := range batch.Multimodal { + // visionOutputs := image.Multimodal[0].Tensor + // ctx.Forward(visionOutputs.Copy(ctx, hiddenState.AsStrided(ctx, + // []int{visionOutputs.Dim(0) * visionOutputs.Dim(1)}, + // []int{image.Index * hiddenState.Stride(1)}, 0))) + + // for i := range visionOutputs.Dim(1) { + // except = append(except, image.Index+i) + // } + // } + + for i, layer := range m.Layers { + // gemma alternates between the sliding window (local) and causal (global) + // kv cache every 6 layers + if cache != nil { + // cacheType := cacheTypeSWA + // if (i+1)%gemmaGlobalCacheCount == 0 { + // cacheType = cacheTypeCausal + // } + cache.SetLayer(i) + + // TODO this needs to come back + // wc := cache.(*kvcache.WrapperCache) + // wc.SetLayerType(cacheType) + + // if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok { + // causal.SetCausal(ctx, kvcache.CausalOptions{Except: except}) + // } + } + + var offset int + var lastLayerOutputs ml.Tensor + if i == len(m.Layers)-1 { + offset = batch.Offset + lastLayerOutputs = batch.Outputs + } + + hiddenState = layer.Forward(ctx, i, hiddenState, lastLayerOutputs, offset, cache, m.TextConfig) + } + hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) + return hiddenState +} diff --git a/x/model/models/gemma3/model_vision.go b/x/model/models/gemma3/model_vision.go new file mode 100644 index 000000000..bffb3cb58 --- /dev/null +++ b/x/model/models/gemma3/model_vision.go @@ -0,0 +1,121 @@ +//go:build mlx + +package gemma3 + +import ( + "math" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/x/ml" + "github.com/ollama/ollama/x/ml/nn" +) + +var batchSize int = 1 + +type VisionSelfAttention struct { + Query *nn.Linear `gguf:"self_attn.q_proj"` + Key *nn.Linear `gguf:"self_attn.k_proj"` + Value *nn.Linear `gguf:"self_attn.v_proj"` + Output *nn.Linear `gguf:"self_attn.out_proj"` +} + +func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor { + headDim := opts.hiddenSize / opts.numHeads + + query := sa.Query.Forward(ctx, hiddenState) + key := sa.Key.Forward(ctx, hiddenState) + value := sa.Value.Forward(ctx, hiddenState) + + query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize) + key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize) + value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize) + + attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil) + attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize) + + hiddenState = sa.Output.Forward(ctx, attention) + return hiddenState +} + +type VisionMLP struct { + FC1 *nn.Linear `gguf:"fc1"` + FC2 *nn.Linear `gguf:"fc2"` +} + +func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor { + hiddenState = mlp.FC1.Forward(ctx, hiddenState).GELU(ctx) + hiddenState = mlp.FC2.Forward(ctx, hiddenState) + return hiddenState +} + +type VisionEncoderLayer struct { + LayerNorm1 *nn.LayerNorm `gguf:"layer_norm1"` + SelfAttention *VisionSelfAttention + + LayerNorm2 *nn.LayerNorm `gguf:"layer_norm2"` + MLP *VisionMLP `gguf:"mlp"` +} + +func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor { + residual := hiddenState + + // self attention + hiddenState = e.LayerNorm1.Forward(ctx, hiddenState, opts.eps) + hiddenState = e.SelfAttention.Forward(ctx, hiddenState, opts) + hiddenState = hiddenState.Add(ctx, residual) + residual = hiddenState + + // feed forward + hiddenState = e.LayerNorm2.Forward(ctx, hiddenState, opts.eps) + hiddenState = e.MLP.Forward(ctx, hiddenState, opts) + return hiddenState.Add(ctx, residual) +} + +type VisionModelOptions struct { + hiddenSize, numHeads int + imageSize, patchSize int + eps float32 +} + +type VisionModel struct { + PatchEmbedding *nn.Conv2D `gguf:"embeddings.patch_embedding"` + PositionEmbedding *nn.Embedding `gguf:"embeddings.position_embedding"` + PostLayerNorm *nn.LayerNorm `gguf:"post_layernorm"` + + Layers []VisionEncoderLayer `gguf:"encoder.layers"` + + *VisionModelOptions +} + +func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor { + numPatches := (m.imageSize / m.patchSize) * (m.imageSize / m.patchSize) + + hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1) + hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize) + hiddenState = hiddenState.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false) + + positionIDs := ctx.Arange(0, float32(numPatches), 1, ml.DTypeInt32) + hiddenState = hiddenState.Add(ctx, m.PositionEmbedding.Forward(ctx, positionIDs)) + + for _, layer := range m.Layers { + hiddenState = layer.Forward(ctx, hiddenState, m.VisionModelOptions) + } + + hiddenState = m.PostLayerNorm.Forward(ctx, hiddenState, m.eps) + return hiddenState +} + +func newVisionModel(c fs.Config) *VisionModel { + return &VisionModel{ + Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count")), + VisionModelOptions: &VisionModelOptions{ + hiddenSize: int(c.Uint("vision.embedding_length")), + numHeads: int(c.Uint("vision.attention.head_count")), + + imageSize: int(c.Uint("vision.image_size")), + patchSize: int(c.Uint("vision.patch_size")), + + eps: c.Float("vision.attention.layer_norm_epsilon"), + }, + } +} diff --git a/x/model/models/gemma3/process_image.go b/x/model/models/gemma3/process_image.go new file mode 100644 index 000000000..09d0727d0 --- /dev/null +++ b/x/model/models/gemma3/process_image.go @@ -0,0 +1,60 @@ +//go:build mlx + +package gemma3 + +import ( + "image" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/model/imageproc" +) + +type ImageProcessor struct { + imageSize, patchSize, numChannels int +} + +func newImageProcessor(c fs.Config) ImageProcessor { + return ImageProcessor{ + imageSize: int(c.Uint("vision.image_size")), + patchSize: int(c.Uint("vision.patch_size")), + numChannels: int(c.Uint("vision.num_channels")), + } +} + +func (p *ImageProcessor) pack(img image.Image, mean, std [3]float32) []float32 { + var pixelVals, rVals, gVals, bVals []float32 + + bounds := img.Bounds() + for y := bounds.Min.Y; y < bounds.Max.Y; y++ { + for x := bounds.Min.X; x < bounds.Max.X; x++ { + c := img.At(x, y) + r, g, b, _ := c.RGBA() + rVal := float32(r>>8) / 255.0 + gVal := float32(g>>8) / 255.0 + bVal := float32(b>>8) / 255.0 + + rVal = (rVal - mean[0]) / std[0] + gVal = (gVal - mean[1]) / std[1] + bVal = (bVal - mean[2]) / std[2] + + rVals = append(rVals, rVal) + gVals = append(gVals, gVal) + bVals = append(bVals, bVal) + } + } + + pixelVals = append(pixelVals, rVals...) + pixelVals = append(pixelVals, gVals...) + pixelVals = append(pixelVals, bVals...) + + return pixelVals +} + +func (p ImageProcessor) ProcessImage(img image.Image) ([]float32, error) { + outputSize := image.Point{p.imageSize, p.imageSize} + newImage := imageproc.Composite(img) + newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBilinear) + + data := p.pack(newImage, imageproc.ImageNetStandardMean, imageproc.ImageNetStandardSTD) + return data, nil +} diff --git a/x/model/models/models.go b/x/model/models/models.go new file mode 100644 index 000000000..a2542707f --- /dev/null +++ b/x/model/models/models.go @@ -0,0 +1,3 @@ +package models + +// _ "github.com/ollama/ollama/x/model/models/gemma3" diff --git a/x/model/sentencepiece.go b/x/model/sentencepiece.go new file mode 100644 index 000000000..2c178ec0c --- /dev/null +++ b/x/model/sentencepiece.go @@ -0,0 +1,249 @@ +package model + +import ( + "container/heap" + "fmt" + "log/slog" + "strconv" + "strings" + + "github.com/ollama/ollama/logutil" +) + +const spmWhitespaceSep = "▁" + +type SentencePiece struct { + maxTokenLen int + vocab *Vocabulary +} + +var _ TextProcessor = (*SentencePiece)(nil) + +func (spm SentencePiece) Vocabulary() *Vocabulary { + return spm.vocab +} + +func NewSentencePiece(vocab *Vocabulary) SentencePiece { + logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5]) + + counter := map[int]int{} + var maxTokenLen int + for cnt := range vocab.Types { + switch vocab.Types[cnt] { + case TOKEN_TYPE_NORMAL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_UNUSED: + maxTokenLen = max(maxTokenLen, len(vocab.Values[cnt])) + fallthrough + default: + counter[int(vocab.Types[cnt])] += 1 + } + } + + logutil.Trace("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL], + "user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE], + "max token len", maxTokenLen) + + return SentencePiece{ + maxTokenLen: maxTokenLen, + vocab: vocab, + } +} + +func (spm SentencePiece) Is(id int32, special Special) bool { + return spm.vocab.Is(id, special) +} + +func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) { + fragments := []fragment{{value: s}} + for _, special := range spm.vocab.SpecialVocabulary() { + id := spm.vocab.Encode(special) + for i := 0; i < len(fragments); i++ { + frag := fragments[i] + if len(frag.ids) > 0 { + continue + } + + var middle []fragment + switch i := strings.Index(frag.value, special); { + case i < 0: + middle = append(middle, frag) + case i > 0: + middle = append(middle, fragment{value: frag.value[:i]}) + fallthrough + default: + middle = append(middle, fragment{value: special, ids: []int32{id}}) + if rest := frag.value[i+len(special):]; rest != "" { + middle = append(middle, fragment{value: rest}) + } + } + + fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...) + } + } + + var ids []int32 + for _, frag := range fragments { + if len(frag.ids) > 0 { + ids = append(ids, frag.ids...) + continue + } + + text := strings.ReplaceAll(frag.value, " ", spmWhitespaceSep) + + if id := spm.vocab.Encode(text); id >= 0 { + ids = append(ids, id) + continue + } + + q := &queue{} + heap.Init(q) + + runes := []rune(text) + merges := make([]merge, len(runes)) + for r := range runes { + merges[r] = merge{ + p: r - 1, + n: r + 1, + runes: []rune{runes[r]}, + } + } + + pairwise := func(a, b int) *candidate { + if a < 0 || b >= len(runes) { + return nil + } + + left, right := string(merges[a].runes), string(merges[b].runes) + if id := spm.vocab.Encode(left + right); id >= 0 { + return &candidate{ + a: a, + b: b, + score: spm.vocab.Scores[id], + size: len(left) + len(right), + } + } + + return nil + } + + for i := range len(runes) - 1 { + if pair := pairwise(i, i+1); pair != nil { + heap.Push(q, pair) + } + } + + for q.Len() > 0 { + pair := heap.Pop(q).(*candidate) + left, right := merges[pair.a], merges[pair.b] + + if string(left.runes) == "" || string(right.runes) == "" || len(string(left.runes))+len(string(right.runes)) != pair.size { + continue + } + + merges[pair.a].runes = append(left.runes, right.runes...) + merges[pair.b].runes = nil + merges[pair.a].n = right.n + if right.n < len(merges) { + merges[right.n].p = pair.a + } + + if pair := pairwise(merges[pair.a].p, pair.a); pair != nil { + heap.Push(q, pair) + } + + if pair := pairwise(pair.a, merges[pair.a].n); pair != nil { + heap.Push(q, pair) + } + } + + for _, merge := range merges { + if token := string(merge.runes); token != "" { + id := spm.vocab.Encode(token) + + if id >= 0 { + ids = append(ids, id) + continue + } + + // Fallback to byte tokenization + var result []int32 + for _, b := range []byte(token) { + byteToken := fmt.Sprintf("<0x%02X>", b) + unknownID := spm.vocab.Encode(byteToken) + if unknownID >= 0 { + result = append(result, unknownID) + } else { + slog.Debug("unknown byte token", "byte", b, "token", byteToken) + } + } + + ids = append(ids, result...) + } + } + } + + if addSpecial { + ids = spm.vocab.addSpecials(ids) + } + + logutil.Trace("encoded", "string", s, "ids", ids) + return ids, nil +} + +type candidate struct { + a, b int + score float32 + size int +} + +type queue []*candidate + +func (q queue) Len() int { return len(q) } + +func (q queue) Less(i, j int) bool { + return (q[i].score > q[j].score) || (q[i].score == q[j].score && q[i].a < q[j].a) +} + +func (q queue) Swap(i, j int) { q[i], q[j] = q[j], q[i] } + +func (q *queue) Push(x interface{}) { + item := x.(*candidate) + *q = append(*q, item) +} + +func (q *queue) Pop() interface{} { + old := *q + n := len(old) + item := old[n-1] + *q = old[0 : n-1] + return item +} + +func (spm SentencePiece) Decode(ids []int32) (string, error) { + var sb strings.Builder + for _, id := range ids { + data := spm.vocab.Decode(id) + data = strings.ReplaceAll(data, spmWhitespaceSep, " ") + + // For tokenizers that use byte tokens like "<0xEA>" + // convert them to the partial unicode character + // so they are buffered correctly by the runner instead + // of being sent back to the api as "<0xEA>" + if len(data) == 6 && strings.HasPrefix(data, "<0x") && strings.HasSuffix(data, ">") { + byteVal, err := strconv.ParseUint(data[1:5], 0, 8) + if err != nil { + return "", fmt.Errorf("failed to parse hex byte: %v", err) + } + + if err := sb.WriteByte(byte(byteVal)); err != nil { + return "", err + } + } else { + if _, err := sb.WriteString(data); err != nil { + return "", err + } + } + } + + logutil.Trace("decoded", "ids", ids, "string", sb.String()) + return sb.String(), nil +} diff --git a/x/model/sentencepiece_test.go b/x/model/sentencepiece_test.go new file mode 100644 index 000000000..7ab158af7 --- /dev/null +++ b/x/model/sentencepiece_test.go @@ -0,0 +1,172 @@ +package model + +import ( + "log/slog" + "os" + "path/filepath" + "slices" + "testing" + + "google.golang.org/protobuf/proto" + + "github.com/ollama/ollama/convert/sentencepiece" +) + +func loadSentencePieceVocab(t *testing.T) SentencePiece { + t.Helper() + + bts, err := os.ReadFile(filepath.Join("..", "..", "model", "testdata", "gemma2", "tokenizer.model")) + if err != nil { + t.Fatal(err) + } + + var spm sentencepiece.ModelProto + if err := proto.Unmarshal(bts, &spm); err != nil { + t.Fatal(err) + } + + var v Vocabulary + + for _, piece := range spm.GetPieces() { + v.Values = append(v.Values, piece.GetPiece()) + v.Scores = append(v.Scores, piece.GetScore()) + switch t := piece.GetType(); t { + case sentencepiece.ModelProto_SentencePiece_UNKNOWN, + sentencepiece.ModelProto_SentencePiece_CONTROL, + sentencepiece.ModelProto_SentencePiece_UNUSED, + sentencepiece.ModelProto_SentencePiece_BYTE: + v.Types = append(v.Types, int32(t)) + default: + tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL) + // todo parse the special tokens file + // - this will roundtrip correctly but the and + // tokens aren't processed + v.Types = append(v.Types, tt) + } + } + + return NewSentencePiece(&v) +} + +func TestSentencePieceEncode(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + slog.SetDefault(logger) + + tokenizer := loadSentencePieceVocab(t) + + t.Run("basic roundtrip", func(t *testing.T) { + t.Parallel() + + cases := []string{ + "hello", + "hello ", + "hello ", + " hello", + " hello ", + " hello ", + "hello world", + "请考试我的软件!12345", + "你好", + "Hello 你好 world!", + "Special characters: !@#$%^&*()_+-=[]{}|;':\",./<>?", + "Multilingual: 你好 こんにちは Привет Hola مرحبا", + "Numbers and symbols: 123456789 +- */", + "Special tokens: text ", + "Code snippets: func main() { fmt.Println(\"Hello World\") }", + "Long text: " + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " + + "Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " + + "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris.", + } + + for _, want := range cases { + ids, err := tokenizer.Encode(want, true) + if err != nil { + t.Fatal(err) + } + + if got, err := tokenizer.Decode(ids); err != nil { + t.Fatal(err) + } else if got != want { + t.Errorf("got %q, want %q [%#v]", got, want, ids) + } + } + }) + + t.Run("special tokens", func(t *testing.T) { + type candidate struct { + token string + ids []int32 + } + + cases := []candidate{ + {"", []int32{2}}, + {"", []int32{1}}, + } + + for _, want := range cases { + ids, err := tokenizer.Encode(want.token, true) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(ids, want.ids) { + t.Errorf("got %#v, want %#v", ids, want.ids) + } + } + }) +} + +func TestSentencePieceDecodeByteTokens(t *testing.T) { + vocab := &Vocabulary{ + Values: []string{ + "normal", + "<0xEA>", + "<0x41>", + "<0xC3>", + "<0xA3>", + }, + Types: []int32{ + TOKEN_TYPE_NORMAL, + TOKEN_TYPE_BYTE, + TOKEN_TYPE_BYTE, + TOKEN_TYPE_BYTE, + TOKEN_TYPE_BYTE, + }, + Scores: []float32{0, 0, 0, 0, 0}, + } + + spm := NewSentencePiece(vocab) + + tests := []struct { + name string + ids []int32 + expected string + }{ + { + name: "single byte token", + ids: []int32{1}, + expected: "\xea", + }, + { + name: "ASCII byte token", + ids: []int32{2}, + expected: "A", + }, + { + name: "multiple byte tokens forming UTF-8 character", + ids: []int32{3, 4}, + expected: "ã", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := spm.Decode(tt.ids) + if err != nil { + t.Errorf("failed to decode token IDs %v: %v", tt.ids, err) + } + if result != tt.expected { + t.Errorf("got %q, want %q", result, tt.expected) + } + }) + } +} diff --git a/x/model/textprocessor.go b/x/model/textprocessor.go new file mode 100644 index 000000000..4a36f2352 --- /dev/null +++ b/x/model/textprocessor.go @@ -0,0 +1,17 @@ +package model + +const ( + TOKEN_TYPE_NORMAL = iota + 1 + TOKEN_TYPE_UNKNOWN + TOKEN_TYPE_CONTROL + TOKEN_TYPE_USER_DEFINED + TOKEN_TYPE_UNUSED + TOKEN_TYPE_BYTE +) + +type TextProcessor interface { + Encode(s string, addSpecial bool) ([]int32, error) + Decode([]int32) (string, error) + Is(int32, Special) bool + Vocabulary() *Vocabulary +} diff --git a/x/model/vocabulary.go b/x/model/vocabulary.go new file mode 100644 index 000000000..d977c4957 --- /dev/null +++ b/x/model/vocabulary.go @@ -0,0 +1,112 @@ +package model + +import ( + "log/slog" + "slices" + "sync" +) + +type Special int32 + +const ( + SpecialBOS Special = iota + SpecialEOS +) + +type Vocabulary struct { + Values []string + Types []int32 + Scores []float32 + Merges []string + + BOS, EOS []int32 + AddBOS, AddEOS bool + + specialOnce sync.Once + special []string + + valuesOnce sync.Once + values map[string]int32 + + mergeOnce sync.Once + merge map[string]int32 +} + +func (v *Vocabulary) Is(id int32, special Special) bool { + switch special { + case SpecialBOS: + return slices.Contains(v.BOS, id) + case SpecialEOS: + return slices.Contains(v.EOS, id) + default: + return false + } +} + +func (v *Vocabulary) addSpecials(ids []int32) []int32 { + if v.AddBOS && len(v.BOS) > 0 { + if len(ids) > 0 && slices.Contains(v.BOS, ids[0]) { + slog.Warn("adding bos token to prompt which already has it", "id", v.BOS) + } + + slog.Debug("adding bos token to prompt", "id", v.BOS[0]) + ids = append([]int32{v.BOS[0]}, ids...) + } + + if v.AddEOS && len(v.EOS) > 0 { + if len(ids) > 0 && slices.Contains(v.BOS, ids[len(ids)-1]) { + slog.Warn("adding eos token to prompt which already has it", "id", v.EOS) + } + + slog.Debug("adding eos token to prompt", "id", v.EOS[0]) + ids = append(ids, v.EOS[0]) + } + + return ids +} + +func (v *Vocabulary) Encode(s string) int32 { + v.valuesOnce.Do(func() { + v.values = make(map[string]int32, len(v.Values)) + for i, value := range v.Values { + v.values[value] = int32(i) + } + }) + + if id, ok := v.values[s]; ok { + return id + } + + return -1 +} + +func (v *Vocabulary) Decode(id int32) string { + return v.Values[id] +} + +func (v *Vocabulary) SpecialVocabulary() []string { + v.specialOnce.Do(func() { + for i := range v.Values { + if v.Types[i] == TOKEN_TYPE_CONTROL || v.Types[i] == TOKEN_TYPE_USER_DEFINED { + v.special = append(v.special, v.Values[i]) + } + } + }) + + return v.special +} + +func (v *Vocabulary) Merge(left, right string) int { + v.mergeOnce.Do(func() { + v.merge = make(map[string]int32, len(v.Merges)) + for i, merge := range v.Merges { + v.merge[merge] = int32(i) + } + }) + + if id, ok := v.merge[left+" "+right]; ok { + return int(id) + } + + return -1 +} diff --git a/x/model/vocabulary_test.go b/x/model/vocabulary_test.go new file mode 100644 index 000000000..ccfc39e69 --- /dev/null +++ b/x/model/vocabulary_test.go @@ -0,0 +1,107 @@ +package model + +import ( + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestSpecialVocabulary(t *testing.T) { + vocab := &Vocabulary{ + Values: []string{"<|startoftext|>", "<|endoftext|>", "<|tool_call_start|>", "<|tool_call_end|>", "hi"}, + Types: []int32{TOKEN_TYPE_CONTROL, TOKEN_TYPE_CONTROL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_NORMAL}, + } + + specialVocab := vocab.SpecialVocabulary() + + if len(specialVocab) != 4 { + t.Errorf("expected 4 special tokens, got %d", len(specialVocab)) + } +} + +func TestAddSpecialVocabulary(t *testing.T) { + cases := []struct { + name string + vocab *Vocabulary + input []int32 + want []int32 + }{ + { + name: "add bos", + vocab: &Vocabulary{ + BOS: []int32{0}, + EOS: []int32{1}, + AddBOS: true, + AddEOS: false, + }, + input: []int32{2, 3, 4}, + want: []int32{0, 2, 3, 4}, + }, + { + // TODO(mxyng): this is to match previous behaviour + name: "add bos when already present", + vocab: &Vocabulary{ + BOS: []int32{0}, + EOS: []int32{1}, + AddBOS: true, + AddEOS: false, + }, + input: []int32{0, 2, 3, 4}, + want: []int32{0, 0, 2, 3, 4}, + }, + { + name: "add eos", + vocab: &Vocabulary{ + BOS: []int32{0}, + EOS: []int32{1}, + AddBOS: false, + AddEOS: true, + }, + input: []int32{2, 3, 4}, + want: []int32{2, 3, 4, 1}, + }, + { + // TODO(mxyng): this is to match previous behaviour + name: "add eos when already present", + vocab: &Vocabulary{ + BOS: []int32{0}, + EOS: []int32{1}, + AddBOS: false, + AddEOS: true, + }, + input: []int32{2, 3, 4, 1}, + want: []int32{2, 3, 4, 1, 1}, + }, + { + name: "add both", + vocab: &Vocabulary{ + BOS: []int32{0}, + EOS: []int32{1}, + AddBOS: true, + AddEOS: true, + }, + input: []int32{2, 3, 4}, + want: []int32{0, 2, 3, 4, 1}, + }, + { + name: "add bos to empty inputs", + vocab: &Vocabulary{ + BOS: []int32{0}, + EOS: []int32{1}, + AddBOS: true, + AddEOS: false, + }, + input: []int32{}, + want: []int32{0}, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + got := tt.vocab.addSpecials(tt.input) + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Errorf("no match (-want +got):\n%s", diff) + } + }) + } +} diff --git a/x/model/wordpiece.go b/x/model/wordpiece.go new file mode 100644 index 000000000..e552bce0d --- /dev/null +++ b/x/model/wordpiece.go @@ -0,0 +1,171 @@ +package model + +import ( + "fmt" + "iter" + "strings" + "unicode" + + "github.com/ollama/ollama/logutil" +) + +type WordPiece struct { + vocab *Vocabulary + lowercase bool +} + +// ggmlPrefix is the prefix used by GGML vocabularies to indicate word boundaries. +// this differs from original word piece which uses "##" to indicate subwords. +const ggmlPrefix = "▁" + +var wordPieceReplacer = strings.NewReplacer( + " .", ".", + " ?", "?", + " !", "!", + " ,", ",", + " ' ", "'", + " n't", "n't", + " 'm", "'m", + " do not", " don't", + " 's", "'s", + " 've", "'ve", + " 're", "'re", +) + +// Decode implements TextProcessor. +func (wpm WordPiece) Decode(ids []int32) (string, error) { + var sb strings.Builder + for i, id := range ids { + if id < 0 || int(id) >= len(wpm.vocab.Values) { + return "", fmt.Errorf("invalid token id: %d", id) + } + + var separator string + piece := wpm.vocab.Values[id] + if i > 0 && + (strings.HasPrefix(piece, ggmlPrefix) || + (strings.HasPrefix(piece, "[") && strings.HasSuffix(piece, "]"))) { + separator = " " + } + + sb.WriteString(wordPieceReplacer.Replace(separator + strings.TrimPrefix(piece, ggmlPrefix))) + } + + return sb.String(), nil +} + +// words splits a string into words, treating CJK characters as separate words. +// TODO: this is specifically for BERT and may need to be adjusted or refactored for other models. +func (wpm WordPiece) words(s string) iter.Seq[string] { + return func(yield func(string) bool) { + runes := make([]rune, 0, len(s)*3) + for _, r := range s { + switch { + case r >= 0x4E00 && r <= 0x9FFF, + r >= 0x3400 && r <= 0x4DBF, + r >= 0x20000 && r <= 0x2A6DF, + r >= 0x2A700 && r <= 0x2B73F, + r >= 0x2B740 && r <= 0x2B81F, + r >= 0x2B820 && r <= 0x2CEAF, + r >= 0xF900 && r <= 0xFAFF, + r >= 0x2F800 && r <= 0x2FA1F: + runes = append(runes, ' ', r, ' ') + default: + runes = append(runes, r) + } + } + + for w := range strings.FieldsFuncSeq(string(runes), unicode.IsSpace) { + // split on but keep punctuation + var start int + for start < len(w) { + end := strings.IndexFunc(w[start:], unicode.IsPunct) + if end < 0 { + end = len(w) - start + } else if end == 0 { + end = 1 + } + + if !yield(w[start : start+end]) { + return + } + + start += end + } + } + } +} + +// Encode implements TextProcessor. +func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) { + var ids []int32 + + // TODO: use [UNK] from config + unk := wpm.vocab.Encode("[UNK]") + for word := range wpm.words(s) { + var start int + var pieces []int32 + for start < len(word) { + end := len(word) + + var piece int32 + for start < end { + subword := word[start:end] + if start == 0 { + subword = ggmlPrefix + subword + } + + if wpm.lowercase { + subword = strings.ToLower(subword) + } + piece = wpm.vocab.Encode(subword) + if piece >= 0 { + break + } + + end-- + } + + if piece < 0 { + // Unknown token + pieces = pieces[:0] + break + } + + pieces = append(pieces, piece) + start = end + } + + if len(pieces) > 0 { + ids = append(ids, pieces...) + } else { + ids = append(ids, unk) + } + } + + if addSpecial { + ids = wpm.vocab.addSpecials(ids) + } + + logutil.Trace("encoded", "string", s, "ids", ids) + return ids, nil +} + +// Is implements TextProcessor. +func (wpm WordPiece) Is(id int32, special Special) bool { + return wpm.vocab.Is(id, special) +} + +// Vocabulary implements TextProcessor. +func (wpm WordPiece) Vocabulary() *Vocabulary { + return wpm.vocab +} + +var _ TextProcessor = (*WordPiece)(nil) + +func NewWordPiece(vocab *Vocabulary, lowercase bool) WordPiece { + return WordPiece{ + vocab: vocab, + lowercase: lowercase, + } +} diff --git a/x/model/wordpiece_test.go b/x/model/wordpiece_test.go new file mode 100644 index 000000000..c03bb17a7 --- /dev/null +++ b/x/model/wordpiece_test.go @@ -0,0 +1,53 @@ +package model + +import ( + "slices" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestWordPiece(t *testing.T) { + wpm := NewWordPiece( + &Vocabulary{ + Values: []string{"[UNK]", "[CLS]", "[SEP]", "▁hello", "▁world", "s", "▁!", "▁@", "▁#"}, + AddBOS: true, + AddEOS: true, + BOS: []int32{1}, + EOS: []int32{2}, + }, + true, // lowercase + ) + + ids, err := wpm.Encode("Hello world!", true) + if err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff([]int32{1, 3, 4, 6, 2}, ids); diff != "" { + t.Errorf("unexpected ids (-want +got):\n%s", diff) + } + + words, err := wpm.Decode(ids) + if err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff("[CLS] hello world! [SEP]", words); diff != "" { + t.Errorf("unexpected words (-want +got):\n%s", diff) + } +} + +func TestWordPieceWords(t *testing.T) { + var wpm WordPiece + + basic := slices.Collect(wpm.words("Hey friend! How are you?!?")) + if diff := cmp.Diff([]string{"Hey", "friend", "!", "How", "are", "you", "?", "!", "?"}, basic); diff != "" { + t.Errorf("unexpected words (-want +got):\n%s", diff) + } + + chinese := slices.Collect(wpm.words("野口里佳 Noguchi Rika")) + if diff := cmp.Diff([]string{"野", "口", "里", "佳", "Noguchi", "Rika"}, chinese); diff != "" { + t.Errorf("unexpected words (-want +got):\n%s", diff) + } +}