diff --git a/Dockerfile b/Dockerfile index af8a37e68..21de0afb3 100644 --- a/Dockerfile +++ b/Dockerfile @@ -161,10 +161,6 @@ ARG GOFLAGS="'-ldflags=-w -s'" ENV CGO_ENABLED=1 ARG CGO_CFLAGS ARG CGO_CXXFLAGS -# TODO wire up the actual MLX engine here instead of building the main binary... -RUN mkdir -p dist/bin -RUN go build -tags mlx -trimpath -buildmode=pie -o dist/bin/imagegen ./x/imagegen/cmd/engine - FROM base AS build WORKDIR /go/src/github.com/ollama/ollama diff --git a/cmd/cmd.go b/cmd/cmd.go index 34a773375..187191be3 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -46,6 +46,8 @@ import ( "github.com/ollama/ollama/types/syncmap" "github.com/ollama/ollama/version" xcmd "github.com/ollama/ollama/x/cmd" + "github.com/ollama/ollama/x/imagegen" + imagegenclient "github.com/ollama/ollama/x/imagegen/client" ) const ConnectInstructions = "To sign in, navigate to:\n %s\n\n" @@ -96,6 +98,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error { filename, err := getModelfileName(cmd) if os.IsNotExist(err) { if filename == "" { + // No Modelfile found - check if current directory is an image gen model + if imagegen.IsTensorModelDir(".") { + return imagegenclient.CreateModel(args[0], ".", p) + } reader = strings.NewReader("FROM .\n") } else { return errModelfileNotFound @@ -457,6 +463,15 @@ func RunHandler(cmd *cobra.Command, args []string) error { } name := args[0] + + // Check if this is a known image generation model (skip Show/Pull) + if imagegen.HasTensorLayers(name) { + if opts.Prompt == "" && !interactive { + return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"") + } + return imagegen.RunCLI(cmd, name, opts.Prompt, interactive, opts.KeepAlive) + } + info, err := func() (*api.ShowResponse, error) { showReq := &api.ShowRequest{Name: name} info, err := client.Show(cmd.Context(), showReq) @@ -822,6 +837,11 @@ func DeleteHandler(cmd *cobra.Command, args []string) error { } func ShowHandler(cmd *cobra.Command, args []string) error { + // Check if this is an image generation model + if imagegen.HasTensorLayers(args[0]) { + return imagegen.Show(args[0], os.Stdout) + } + client, err := api.ClientFromEnvironment() if err != nil { return err @@ -1767,6 +1787,9 @@ func NewCLI() *cobra.Command { runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools") runCmd.Flags().Bool("experimental-yolo", false, "Skip all tool approval prompts (use with caution)") + // Image generation flags (width, height, steps, seed, etc.) + imagegen.RegisterFlags(runCmd) + stopCmd := &cobra.Command{ Use: "stop MODEL", Short: "Stop a running model", diff --git a/progress/stepbar.go b/progress/stepbar.go new file mode 100644 index 000000000..facbded78 --- /dev/null +++ b/progress/stepbar.go @@ -0,0 +1,33 @@ +package progress + +import ( + "fmt" + "strings" +) + +// StepBar displays step-based progress (e.g., for image generation steps). 
+type StepBar struct { + message string + current int + total int +} + +func NewStepBar(message string, total int) *StepBar { + return &StepBar{message: message, total: total} +} + +func (s *StepBar) Set(current int) { + s.current = current +} + +func (s *StepBar) String() string { + percent := float64(s.current) / float64(s.total) * 100 + barWidth := s.total + empty := barWidth - s.current + + // "Generating 0% ▕ ▏ 0/9" + return fmt.Sprintf("%s %3.0f%% ▕%s%s▏ %d/%d", + s.message, percent, + strings.Repeat("█", s.current), strings.Repeat(" ", empty), + s.current, s.total) +} diff --git a/runner/runner.go b/runner/runner.go index 500fdd72e..543410798 100644 --- a/runner/runner.go +++ b/runner/runner.go @@ -3,6 +3,7 @@ package runner import ( "github.com/ollama/ollama/runner/llamarunner" "github.com/ollama/ollama/runner/ollamarunner" + imagerunner "github.com/ollama/ollama/x/imagegen/runner" ) func Execute(args []string) error { @@ -11,12 +12,19 @@ func Execute(args []string) error { } var newRunner bool - if args[0] == "--ollama-engine" { + var imageRunner bool + if len(args) > 0 && args[0] == "--ollama-engine" { args = args[1:] newRunner = true } + if len(args) > 0 && args[0] == "--image-engine" { + args = args[1:] + imageRunner = true + } - if newRunner { + if imageRunner { + return imagerunner.Execute(args) + } else if newRunner { return ollamarunner.Execute(args) } else { return llamarunner.Execute(args) diff --git a/server/images.go b/server/images.go index 951f7ac6e..97ed3a3b9 100644 --- a/server/images.go +++ b/server/images.go @@ -30,6 +30,7 @@ import ( "github.com/ollama/ollama/thinking" "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" + "github.com/ollama/ollama/x/imagegen/transfer" ) var ( @@ -73,6 +74,11 @@ type Model struct { func (m *Model) Capabilities() []model.Capability { capabilities := []model.Capability{} + // Check for image generation model via config capabilities + if slices.Contains(m.Config.Capabilities, "image") { + return []model.Capability{model.CapabilityImageGeneration} + } + // Check for completion capability if m.ModelPath != "" { f, err := gguf.Open(m.ModelPath) @@ -555,6 +561,24 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu layers = append(layers, manifest.Config) } + // Use fast transfer for models with tensor layers (many small blobs) + if hasTensorLayers(layers) { + // Read raw manifest JSON to preserve tensor metadata fields + manifestPath, err := mp.GetManifestPath() + if err != nil { + return err + } + manifestJSON, err := os.ReadFile(manifestPath) + if err != nil { + return err + } + if err := pushWithTransfer(ctx, mp, layers, manifestJSON, regOpts, fn); err != nil { + return err + } + fn(api.ProgressResponse{Status: "success"}) + return nil + } + for _, layer := range layers { if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil { slog.Info(fmt.Sprintf("error uploading blob: %v", err)) @@ -620,6 +644,15 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu layers = append(layers, manifest.Config) } + // Use fast transfer for models with tensor layers (many small blobs) + if hasTensorLayers(layers) { + if err := pullWithTransfer(ctx, mp, layers, manifest, regOpts, fn); err != nil { + return err + } + fn(api.ProgressResponse{Status: "success"}) + return nil + } + skipVerify := make(map[string]bool) for _, layer := range layers { cacheHit, err := downloadBlob(ctx, downloadOpts{ @@ -634,7 +667,6 @@ func PullModel(ctx context.Context, name string, 
regOpts *registryOptions, fn fu skipVerify[layer.Digest] = cacheHit delete(deleteMap, layer.Digest) } - delete(deleteMap, manifest.Config.Digest) fn(api.ProgressResponse{Status: "verifying sha256 digest"}) for _, layer := range layers { @@ -643,13 +675,11 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu } if err := verifyBlob(layer.Digest); err != nil { if errors.Is(err, errDigestMismatch) { - // something went wrong, delete the blob fp, err := GetBlobsPath(layer.Digest) if err != nil { return err } if err := os.Remove(fp); err != nil { - // log this, but return the original error slog.Info(fmt.Sprintf("couldn't remove file with digest mismatch '%s': %v", fp, err)) } } @@ -657,6 +687,11 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu } } + for _, layer := range layers { + delete(deleteMap, layer.Digest) + } + delete(deleteMap, manifest.Config.Digest) + fn(api.ProgressResponse{Status: "writing manifest"}) manifestJSON, err := json.Marshal(manifest) @@ -690,6 +725,148 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu return nil } +// hasTensorLayers checks if any layer has tensor media type. +func hasTensorLayers(layers []Layer) bool { + for _, layer := range layers { + if layer.MediaType == MediaTypeImageTensor { + return true + } + } + return false +} + +// pullWithTransfer uses the simplified x/transfer package for downloading blobs. +func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifest *Manifest, regOpts *registryOptions, fn func(api.ProgressResponse)) error { + blobs := make([]transfer.Blob, len(layers)) + for i, layer := range layers { + blobs[i] = transfer.Blob{ + Digest: layer.Digest, + Size: layer.Size, + } + } + + destDir, err := GetBlobsPath("") + if err != nil { + return err + } + + base := mp.BaseURL() + if base.Scheme != "http" && regOpts != nil && regOpts.Insecure { + base.Scheme = "http" + } + baseURL := base.String() + + var totalSize int64 + for _, blob := range blobs { + totalSize += blob.Size + } + + progress := func(completed, total int64) { + fn(api.ProgressResponse{ + Status: "pulling model", + Digest: "sha256:model", + Total: total, + Completed: completed, + }) + } + + getToken := func(ctx context.Context, challenge transfer.AuthChallenge) (string, error) { + return getAuthorizationToken(ctx, registryChallenge{ + Realm: challenge.Realm, + Service: challenge.Service, + Scope: challenge.Scope, + }) + } + + if err := transfer.Download(ctx, transfer.DownloadOptions{ + Blobs: blobs, + BaseURL: baseURL, + DestDir: destDir, + Repository: mp.GetNamespaceRepository(), + Progress: progress, + Token: regOpts.Token, + GetToken: getToken, + Logger: slog.Default(), + }); err != nil { + return err + } + + // Write manifest + fn(api.ProgressResponse{Status: "writing manifest"}) + manifestJSON, err := json.Marshal(manifest) + if err != nil { + return err + } + + fp, err := mp.GetManifestPath() + if err != nil { + return err + } + if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil { + return err + } + + return os.WriteFile(fp, manifestJSON, 0o644) +} + +// pushWithTransfer uses the simplified x/transfer package for uploading blobs and manifest. 
+func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifestJSON []byte, regOpts *registryOptions, fn func(api.ProgressResponse)) error { + blobs := make([]transfer.Blob, len(layers)) + for i, layer := range layers { + blobs[i] = transfer.Blob{ + Digest: layer.Digest, + Size: layer.Size, + From: layer.From, + } + } + + srcDir, err := GetBlobsPath("") + if err != nil { + return err + } + + base := mp.BaseURL() + if base.Scheme != "http" && regOpts != nil && regOpts.Insecure { + base.Scheme = "http" + } + baseURL := base.String() + + var totalSize int64 + for _, blob := range blobs { + totalSize += blob.Size + } + + progress := func(completed, total int64) { + fn(api.ProgressResponse{ + Status: "pushing model", + Digest: "sha256:model", + Total: total, + Completed: completed, + }) + } + + getToken := func(ctx context.Context, challenge transfer.AuthChallenge) (string, error) { + return getAuthorizationToken(ctx, registryChallenge{ + Realm: challenge.Realm, + Service: challenge.Service, + Scope: challenge.Scope, + }) + } + + return transfer.Upload(ctx, transfer.UploadOptions{ + Blobs: blobs, + BaseURL: baseURL, + SrcDir: srcDir, + Progress: progress, + Token: regOpts.Token, + GetToken: getToken, + Logger: slog.Default(), + Manifest: manifestJSON, + ManifestRef: mp.Tag, + Repository: mp.GetNamespaceRepository(), + }) +} + func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) { requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag) diff --git a/server/images_test.go b/server/images_test.go index a2fba8d98..156914a07 100644 --- a/server/images_test.go +++ b/server/images_test.go @@ -47,6 +47,15 @@ func TestModelCapabilities(t *testing.T) { model Model expectedCaps []model.Capability }{ + { + name: "model with image generation capability via config", + model: Model{ + Config: model.ConfigV2{ + Capabilities: []string{"image"}, + }, + }, + expectedCaps: []model.Capability{model.CapabilityImageGeneration}, + }, { name: "model with completion capability", model: Model{ diff --git a/server/layer.go b/server/layer.go index f1fbabea0..4baabe35c 100644 --- a/server/layer.go +++ b/server/layer.go @@ -13,9 +13,14 @@ type Layer struct { Digest string `json:"digest"` Size int64 `json:"size"` From string `json:"from,omitempty"` + Name string `json:"name,omitempty"` // tensor name, e.g., "text_encoder/model.embed_tokens.weight" status string } +const ( + MediaTypeImageTensor = "application/vnd.ollama.image.tensor" +) + func NewLayer(r io.Reader, mediatype string) (Layer, error) { blobs, err := GetBlobsPath("") if err != nil { diff --git a/server/routes.go b/server/routes.go index 8e199bada..c58a3db51 100644 --- a/server/routes.go +++ b/server/routes.go @@ -50,6 +50,8 @@ import ( "github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" + "github.com/ollama/ollama/x/imagegen" + imagegenapi "github.com/ollama/ollama/x/imagegen/api" ) const signinURLStr = "https://ollama.com/connect?name=%s&key=%s" @@ -162,6 +164,29 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C return runner.llama, model, &opts, nil } +// ScheduleImageGenRunner schedules an image generation model runner. +// This implements the imagegenapi.RunnerScheduler interface. 
+func (s *Server) ScheduleImageGenRunner(c *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error) { + m := &Model{ + Name: modelName, + ShortName: modelName, + ModelPath: modelName, // For image gen, ModelPath is just the model name + Config: model.ConfigV2{ + Capabilities: []string{"image"}, + }, + } + + runnerCh, errCh := s.sched.GetRunner(c.Request.Context(), m, opts, keepAlive) + var runner *runnerRef + select { + case runner = <-runnerCh: + case err := <-errCh: + return nil, err + } + + return runner.llama, nil +} + func signinURL() (string, error) { pubKey, err := auth.GetPublicKey() if err != nil { @@ -189,6 +214,12 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } + // Check if this is a known image generation model + if imagegen.ResolveModelName(req.Model) != "" { + imagegenapi.HandleGenerateRequest(c, s, req.Model, req.Prompt, req.KeepAlive, streamResponse) + return + } + name := model.ParseName(req.Model) if !name.IsValid() { // Ideally this is "invalid model name" but we're keeping with @@ -1547,6 +1578,9 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) { // Inference (Anthropic compatibility) r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler) + // Experimental image generation support + imagegenapi.RegisterRoutes(r, s) + if rc != nil { // wrap old with new rs := ®istry.Local{ diff --git a/server/sched.go b/server/sched.go index c5bc6692d..df4fb2a2b 100644 --- a/server/sched.go +++ b/server/sched.go @@ -21,6 +21,7 @@ import ( "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/types/model" + "github.com/ollama/ollama/x/imagegen" ) type LlmRequest struct { @@ -194,6 +195,14 @@ func (s *Scheduler) processPending(ctx context.Context) { slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "gpu_count", len(gpus)) } + // Check for image generation model before attempting GGML load + if slices.Contains(pending.model.Config.Capabilities, "image") { + if s.loadImageGen(pending) { + break + } + continue + } + // Load model for fitting logutil.Trace("loading model metadata", "model", pending.model.ModelPath) ggml, err := llm.LoadModel(pending.model.ModelPath, 1024) @@ -543,6 +552,48 @@ iGPUScan: return false } +// loadImageGen loads an image generation model. 
+func (s *Scheduler) loadImageGen(req *LlmRequest) bool { + // Use model name for imagegen (it resolves manifests by name, not file path) + modelName := req.model.ShortName + server, err := imagegen.NewServer(modelName) + if err != nil { + req.errCh <- err + return true + } + + sessionDuration := envconfig.KeepAlive() + if req.sessionDuration != nil { + sessionDuration = req.sessionDuration.Duration + } + + runner := &runnerRef{ + model: req.model, + modelPath: req.model.ModelPath, + llama: server, + Options: &req.opts, + loading: false, + sessionDuration: sessionDuration, + refCount: 1, + } + + s.loadedMu.Lock() + s.loaded[req.model.ModelPath] = runner + s.loadedMu.Unlock() + + // Set up expiration timer + runner.refMu.Lock() + if sessionDuration > 0 { + runner.expireTimer = time.AfterFunc(sessionDuration, func() { + s.expiredCh <- runner + }) + } + runner.refMu.Unlock() + + req.useLoadedRunner(runner, s.finishedReqCh) + return true +} + func (s *Scheduler) updateFreeSpace(allGpus []ml.DeviceInfo) { if len(allGpus) == 0 { return diff --git a/server/sched_test.go b/server/sched_test.go index 480aafa4e..292b1635c 100644 --- a/server/sched_test.go +++ b/server/sched_test.go @@ -6,6 +6,7 @@ import ( "errors" "log/slog" "os" + "slices" "testing" "time" @@ -16,6 +17,7 @@ import ( "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/types/model" ) func TestMain(m *testing.M) { @@ -804,3 +806,61 @@ func (s *mockLlm) GetPort() int { return - func (s *mockLlm) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return nil } func (s *mockLlm) HasExited() bool { return false } func (s *mockLlm) GetActiveDeviceIDs() []ml.DeviceID { return nil } + +// TestImageGenCapabilityDetection verifies that models with "image" capability +// are correctly identified and routed differently from language models. +func TestImageGenCapabilityDetection(t *testing.T) { + // Model with image capability should be detected + imageModel := &Model{ + Config: model.ConfigV2{ + Capabilities: []string{"image"}, + }, + } + require.True(t, slices.Contains(imageModel.Config.Capabilities, "image")) + + // Model without image capability should not be detected + langModel := &Model{ + Config: model.ConfigV2{ + Capabilities: []string{"completion"}, + }, + } + require.False(t, slices.Contains(langModel.Config.Capabilities, "image")) + + // Empty capabilities should not match + emptyModel := &Model{} + require.False(t, slices.Contains(emptyModel.Config.Capabilities, "image")) +} + +// TestImageGenRunnerCanBeEvicted verifies that an image generation model +// loaded in the scheduler can be evicted by a language model request. 
+func TestImageGenRunnerCanBeEvicted(t *testing.T) { + ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond) + defer done() + + s := InitScheduler(ctx) + s.getGpuFn = getGpuFn + s.getSystemInfoFn = getSystemInfoFn + + // Simulate an image gen runner already loaded + imageGenRunner := &runnerRef{ + model: &Model{Name: "z-image", ModelPath: "/fake/image/model"}, + modelPath: "/fake/image/model", + llama: &mockLlm{vramSize: 21 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{}}, + sessionDuration: 5 * time.Millisecond, + refCount: 0, // idle + } + + s.loadedMu.Lock() + s.loaded["/fake/image/model"] = imageGenRunner + s.loadedMu.Unlock() + + // Verify the image gen runner is loaded + s.loadedMu.Lock() + require.Len(t, s.loaded, 1) + s.loadedMu.Unlock() + + // findRunnerToUnload should find the idle image gen runner + runner := s.findRunnerToUnload() + require.NotNil(t, runner) + require.Equal(t, "/fake/image/model", runner.modelPath) +} diff --git a/types/model/capability.go b/types/model/capability.go index cde23cee7..62e1abd8b 100644 --- a/types/model/capability.go +++ b/types/model/capability.go @@ -3,12 +3,13 @@ package model type Capability string const ( - CapabilityCompletion = Capability("completion") - CapabilityTools = Capability("tools") - CapabilityInsert = Capability("insert") - CapabilityVision = Capability("vision") - CapabilityEmbedding = Capability("embedding") - CapabilityThinking = Capability("thinking") + CapabilityCompletion = Capability("completion") + CapabilityTools = Capability("tools") + CapabilityInsert = Capability("insert") + CapabilityVision = Capability("vision") + CapabilityEmbedding = Capability("embedding") + CapabilityThinking = Capability("thinking") + CapabilityImageGeneration = Capability("image") ) func (c Capability) String() string { diff --git a/x/imagegen/README.md b/x/imagegen/README.md index e68f295b8..38abfc427 100644 --- a/x/imagegen/README.md +++ b/x/imagegen/README.md @@ -1,61 +1,236 @@ -# imagegen +# Image Generation in Ollama (Experimental) -This is a package that uses MLX to run image generation models, ahead of being integrated into Ollama's primary runner. -in `CMakeLists.txt` and rebuild. +Generate images from text prompts using local AI models. -### 1. Download a Model - -Download Llama 3.1 8B (or any compatible model) in safetensors format: +## Quick Start ```bash -mkdir -p ./weights - -# Example using huggingface-cli -hf download meta-llama/Llama-3.1-8B --local-dir ./weights/Llama-3.1-8B -hf download openai/gpt-oss-20b --local-dir ./weights/gpt-oss-20b +# Run with a prompt +ollama run z-image "a sunset over mountains" +Generating: step 30/30 +Image saved to: /tmp/ollama-image-1704067200.png ``` -### 2. Run Inference +On macOS, the generated image will automatically open in Preview. + +## Supported Models + +| Model | VRAM Required | Notes | +|-------|---------------|-------| +| z-image | ~12GB | Based on Flux architecture | + +## CLI Usage ```bash -# Build -go build ./cmd/engine +# Generate an image +ollama run z-image "a cat playing piano" -# Text generation -./engine -model ./weights/Llama-3.1-8B -prompt "Hello, world!" 
-max-tokens 250
+# Check if model is running
+ollama ps
-# Qwen-Image 2512 (text-to-image)
-./engine -qwen-image -model ./weights/Qwen-Image-2512 -prompt "A mountain landscape at sunset" \
-  -width 1024 -height 1024 -steps 20 -seed 42 -output landscape.png
-
-# Qwen-Image Edit (experimental) - 8 steps for speed, but model recommends 50
-./engine -qwen-image-edit -model ./weights/Qwen-Image-Edit-2511 \
-  -input-image input.png -prompt "Make it winter" -negative-prompt " " -cfg-scale 4.0 \
-  -steps 8 -seed 42 -output edited.png
+
+# Stop the model
+ollama stop z-image
 ```
-
-## Memory Management
-
-MLX Python/C++ uses scope-based memory management - arrays are freed when they go out of scope. Go's garbage collector is non-deterministic, so we can't rely on finalizers to free GPU memory promptly.
-
-Instead, arrays are automatically tracked and freed on `Eval()`:
-
-```go
-// All arrays are automatically tracked when created
-x := mlx.Add(a, b)
-y := mlx.Matmul(x, w)
-
-// Eval frees non-kept arrays, evaluates outputs (auto-kept)
-mlx.Eval(y)
-
-// After copying to CPU, free the array
-data := y.Data()
-y.Free()
-```
-
-Key points:
-
-- All created arrays are automatically tracked
-- `mlx.Eval(outputs...)` frees non-kept arrays, evaluates outputs (outputs auto-kept)
-- `mlx.Keep(arrays...)` marks arrays to survive multiple Eval cycles (for weights, caches)
-- Call `.Free()` when done with an array
+
+## API
+
+### OpenAI-Compatible Endpoint
+
+```bash
+POST /v1/images/generations
+```
+
+**Request:**
+```json
+{
+  "model": "z-image",
+  "prompt": "a sunset over mountains",
+  "size": "1024x1024",
+  "response_format": "b64_json"
+}
+```
+
+**Response:**
+```json
+{
+  "created": 1704067200,
+  "data": [
+    {
+      "b64_json": "iVBORw0KGgo..."
+    }
+  ]
+}
+```
+
+### Example: cURL
+
+```bash
+curl http://localhost:11434/v1/images/generations \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "z-image",
+    "prompt": "a white cat",
+    "size": "1024x1024"
+  }'
+```
+
+### Example: Save to File
+
+```bash
+curl -s http://localhost:11434/v1/images/generations \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "z-image",
+    "prompt": "a white cat",
+    "size": "1024x1024"
+  }' | jq -r '.data[0].b64_json' | base64 -d > image.png
+```
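+
+### Example: Go
+
+The same request can be made from Go using only the standard library. This is a minimal sketch against the endpoint documented above; the structs simply mirror that JSON and are not types exported by Ollama's Go API package:
+
+```go
+package main
+
+import (
+	"bytes"
+	"encoding/base64"
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"os"
+)
+
+type genRequest struct {
+	Model  string `json:"model"`
+	Prompt string `json:"prompt"`
+	Size   string `json:"size,omitempty"`
+}
+
+type genResponse struct {
+	Created int64 `json:"created"`
+	Data    []struct {
+		B64JSON string `json:"b64_json"`
+	} `json:"data"`
+}
+
+func main() {
+	body, _ := json.Marshal(genRequest{Model: "z-image", Prompt: "a white cat", Size: "1024x1024"})
+	resp, err := http.Post("http://localhost:11434/v1/images/generations", "application/json", bytes.NewReader(body))
+	if err != nil {
+		panic(err)
+	}
+	defer resp.Body.Close()
+
+	var out genResponse
+	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
+		panic(err)
+	}
+	if len(out.Data) == 0 {
+		panic("no image returned")
+	}
+	// The image comes back base64-encoded; decode and write it out.
+	png, err := base64.StdEncoding.DecodeString(out.Data[0].B64JSON)
+	if err != nil {
+		panic(err)
+	}
+	if err := os.WriteFile("image.png", png, 0o644); err != nil {
+		panic(err)
+	}
+	fmt.Println("wrote image.png")
+}
+```
+
+Run it with `go run main.go` while the server is up.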
+### Streaming Progress
+
+Enable streaming to receive progress updates via SSE:
+
+```bash
+curl http://localhost:11434/v1/images/generations \
+  -H "Content-Type: application/json" \
+  -d '{"model": "z-image", "prompt": "a sunset", "stream": true}'
+```
+
+Events:
+```
+event: progress
+data: {"step": 1, "total": 30}
+
+event: progress
+data: {"step": 2, "total": 30}
+...
+
+event: done
+data: {"created": 1704067200, "data": [{"b64_json": "..."}]}
+```
+
+## Parameters
+
+| Parameter | Type | Default | Description |
+|-----------|------|---------|-------------|
+| model | string | required | Model name |
+| prompt | string | required | Text description of the image |
+| size | string | "1024x1024" | Image dimensions (WxH) |
+| n | int | 1 | Number of images (currently only 1 supported) |
+| response_format | string | "b64_json" | "b64_json" or "url" |
+| stream | bool | false | Enable progress streaming |
+
+## Requirements
+
+- macOS with Apple Silicon (M1/M2/M3/M4)
+- CUDA: tested on CUDA 12 (Blackwell); broader testing coming soon
+- Sufficient VRAM (see the model table above)
+- Ollama built with MLX support
+
+## Limitations
+
+- macOS only for now (uses the MLX backend); CUDA support is still experimental
+- Single image per request
+- Step count is fixed over the HTTP API (the CLI exposes a hidden `--steps` flag)
+- Modelfiles not yet supported (use `ollama create` from a model directory)
+
+---
+
+# Tensor Model Storage Format
+
+Tensor models store each tensor as a separate blob, with metadata in the manifest. This enables faster downloads (parallel fetching) and deduplication (shared tensors are stored once).
+
+## Manifest Structure
+
+The manifest follows the standard ollama format with tensor-specific layer metadata:
+
+```json
+{
+  "schemaVersion": 2,
+  "mediaType": "application/vnd.docker.distribution.manifest.v2+json",
+  "config": { "digest": "sha256:...", "size": 1234 },
+  "layers": [
+    {
+      "mediaType": "application/vnd.ollama.image.tensor",
+      "digest": "sha256:25b36eed...",
+      "size": 49807448,
+      "name": "text_encoder/model.layers.0.mlp.down_proj.weight",
+      "dtype": "BF16",
+      "shape": [2560, 9728]
+    },
+    {
+      "mediaType": "application/vnd.ollama.image.json",
+      "digest": "sha256:abc123...",
+      "size": 512,
+      "name": "text_encoder/config.json"
+    }
+  ]
+}
+```
+
+Each tensor layer includes:
+- `name`: Path-style tensor name (e.g., `text_encoder/model.layers.0.mlp.down_proj.weight`)
+- `dtype`: Data type (BF16, F32, etc.)
+- `shape`: Tensor dimensions
+
+Config layers use the same path-style naming (e.g., `tokenizer/tokenizer.json`).
+
+## Blob Format
+
+Each tensor blob is a minimal safetensors file:
+
+```
+[8 bytes: header size (uint64 LE)]
+[~80 bytes: JSON header, padded to 8-byte alignment]
+[N bytes: raw tensor data]
+```
+
+The header describes a single tensor named `"data"`:
+
+```json
+{"data":{"dtype":"BF16","shape":[2560,9728],"data_offsets":[0,49807360]}}
+```
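+
+For illustration, a blob with this exact layout can be produced in a few lines of Go. This is a sketch, not the code Ollama uses; `wrapTensor` is a hypothetical helper:
+
+```go
+package main
+
+import (
+	"bytes"
+	"encoding/binary"
+	"encoding/json"
+	"fmt"
+	"os"
+)
+
+// wrapTensor wraps raw tensor bytes in the minimal safetensors container:
+// an 8-byte little-endian header size, the JSON header padded with spaces
+// to an 8-byte boundary, then the raw tensor data.
+func wrapTensor(dtype string, shape []int64, data []byte) ([]byte, error) {
+	shapeJSON, err := json.Marshal(shape)
+	if err != nil {
+		return nil, err
+	}
+	header := fmt.Sprintf(`{"data":{"dtype":%q,"shape":%s,"data_offsets":[0,%d]}}`, dtype, shapeJSON, len(data))
+	for len(header)%8 != 0 {
+		header += " " // pad so the data section starts 8-byte aligned
+	}
+	var buf bytes.Buffer
+	if err := binary.Write(&buf, binary.LittleEndian, uint64(len(header))); err != nil {
+		return nil, err
+	}
+	buf.WriteString(header)
+	buf.Write(data)
+	return buf.Bytes(), nil
+}
+
+func main() {
+	// A 2x4 BF16 tensor is 16 bytes of raw data.
+	blob, err := wrapTensor("BF16", []int64{2, 4}, make([]byte, 16))
+	if err != nil {
+		panic(err)
+	}
+	if err := os.WriteFile("tensor.blob", blob, 0o644); err != nil {
+		panic(err)
+	}
+}
+```
+
+The space padding matches the safetensors spec, which is why standard loaders accept these blobs as-is.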
+## Why Include the Header?
+
+The ~88-byte safetensors wrapper (the 8-byte size prefix plus the padded JSON header) enables MLX's native `mlx_load_safetensors` function, which:
+
+1. **Uses mmap** - Maps the file directly into memory, no copies
+2. **Zero-copy to GPU** - MLX reads directly from the mapped pages
+3. **No custom code** - Standard MLX API, battle-tested
+
+Without the header, we'd need custom C++ code to create MLX arrays from raw mmap'd data. MLX's public API doesn't expose this - it always copies when creating arrays from external pointers.
+
+The overhead is negligible: ~88 bytes per tensor comes to roughly 100KB for a 13GB model (about 0.0007%).
+
+## Why Per-Tensor Blobs?
+
+**Deduplication**: Blobs are content-addressed by SHA256. If two models share identical tensors (same weights, dtype, shape), they share the same blob file.
+
+Example: Model A and Model B both use the same text encoder. The text encoder's 400 tensors are stored once, referenced by both manifests.
+
+```
+~/.ollama/models/
+  blobs/
+    sha256-25b36eed...   <- shared by both models
+    sha256-abc123...
+  manifests/
+    library/model-a/latest   <- references sha256-25b36eed
+    library/model-b/latest   <- references sha256-25b36eed
+```
+
+## Import Flow
+
+```
+cd ./weights/Z-Image-Turbo
+ollama create z-image
+
+1. Scan component directories (text_encoder/, transformer/, vae/)
+2. For each .safetensors file:
+   - Extract individual tensors
+   - Wrap each in the minimal safetensors format (~88B header + data)
+   - Write to the blob store (SHA256 content-addressed)
+   - Add a layer entry to the manifest with a path-style name
+3. Copy config files (*.json) as config layers
+4. Write manifest
+```
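+
+The content-addressing in steps 2-3 is plain SHA-256 over the wrapped blob bytes. A small sketch of how a blob's digest and on-disk path could be derived (`blobPath` is a hypothetical helper, not the server's actual implementation):
+
+```go
+package main
+
+import (
+	"crypto/sha256"
+	"fmt"
+	"os"
+	"path/filepath"
+	"strings"
+)
+
+// blobPath returns the content-addressed location for a blob, following
+// the sha256-<hex> file naming shown in the tree above.
+func blobPath(modelsDir string, data []byte) (digest, path string) {
+	digest = fmt.Sprintf("sha256:%x", sha256.Sum256(data))
+	path = filepath.Join(modelsDir, "blobs", strings.ReplaceAll(digest, ":", "-"))
+	return digest, path
+}
+
+func main() {
+	data := []byte("tensor blob bytes")
+	digest, path := blobPath(filepath.Join(os.Getenv("HOME"), ".ollama", "models"), data)
+	fmt.Println(digest) // goes into the manifest layer entry
+	fmt.Println(path)   // where the blob is written on disk
+}
+```
+
+Because the digest is derived from the bytes themselves, re-importing an identical tensor is a no-op: the blob already exists at that path.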
diff --git a/x/imagegen/api/handler.go b/x/imagegen/api/handler.go
new file mode 100644
index 000000000..f66ed6d85
--- /dev/null
+++ b/x/imagegen/api/handler.go
@@ -0,0 +1,235 @@
+package api
+
+import (
+	"encoding/base64"
+	"fmt"
+	"net/http"
+	"os"
+	"strconv"
+	"strings"
+	"time"
+
+	"github.com/gin-gonic/gin"
+
+	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/x/imagegen"
+)
+
+// RunnerScheduler is the interface for scheduling a model runner.
+// This is implemented by server.Server to avoid circular imports.
+type RunnerScheduler interface {
+	ScheduleImageGenRunner(ctx *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error)
+}
+
+// RegisterRoutes registers the image generation API routes.
+func RegisterRoutes(r gin.IRouter, scheduler RunnerScheduler) {
+	r.POST("/v1/images/generations", func(c *gin.Context) {
+		ImageGenerationHandler(c, scheduler)
+	})
+}
+
+// ImageGenerationHandler handles OpenAI-compatible image generation requests.
+func ImageGenerationHandler(c *gin.Context, scheduler RunnerScheduler) {
+	var req ImageGenerationRequest
+	if err := c.BindJSON(&req); err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": err.Error()}})
+		return
+	}
+
+	// Validate required fields
+	if req.Model == "" {
+		c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "model is required"}})
+		return
+	}
+	if req.Prompt == "" {
+		c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "prompt is required"}})
+		return
+	}
+
+	// Apply defaults
+	if req.N == 0 {
+		req.N = 1
+	}
+	if req.Size == "" {
+		req.Size = "1024x1024"
+	}
+	if req.ResponseFormat == "" {
+		req.ResponseFormat = "b64_json"
+	}
+
+	// Verify model exists
+	if imagegen.ResolveModelName(req.Model) == "" {
+		c.JSON(http.StatusNotFound, gin.H{"error": gin.H{"message": fmt.Sprintf("model %q not found", req.Model)}})
+		return
+	}
+
+	// Parse size
+	width, height := parseSize(req.Size)
+
+	// Build options - we repurpose NumCtx/NumGPU for width/height
+	opts := api.Options{}
+	opts.NumCtx = int(width)
+	opts.NumGPU = int(height)
+
+	// Schedule runner
+	runner, err := scheduler.ScheduleImageGenRunner(c, req.Model, opts, nil)
+	if err != nil {
+		status := http.StatusInternalServerError
+		if strings.Contains(err.Error(), "not found") {
+			status = http.StatusNotFound
+		}
+		c.JSON(status, gin.H{"error": gin.H{"message": err.Error()}})
+		return
+	}
+
+	// Build completion request
+	completionReq := llm.CompletionRequest{
+		Prompt:  req.Prompt,
+		Options: &opts,
+	}
+
+	if req.Stream {
+		handleStreamingResponse(c, runner, completionReq, req.ResponseFormat)
+	} else {
+		handleNonStreamingResponse(c, runner, completionReq, req.ResponseFormat)
+	}
+}
+
+func handleStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
+	c.Header("Content-Type", "text/event-stream")
+	c.Header("Cache-Control", "no-cache")
+	c.Header("Connection", "keep-alive")
+
+	var imagePath string
+	err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
+		if resp.Done {
+			imagePath = extractPath(resp.Content)
+		} else {
+			progress := parseProgress(resp.Content)
+			if progress.Total > 0 {
+				c.SSEvent("progress", progress)
+				c.Writer.Flush()
+			}
+		}
+	})
+	if err != nil {
+		c.SSEvent("error", gin.H{"error": err.Error()})
+		return
+	}
+
+	c.SSEvent("done", buildResponse(imagePath, format))
+}
+
+func handleNonStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
+	var imagePath string
+	err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
+		if resp.Done {
+			imagePath = extractPath(resp.Content)
+		}
+	})
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{"message": err.Error()}})
+		return
+	}
+
+	c.JSON(http.StatusOK, buildResponse(imagePath, format))
+}
+
+func parseSize(size string) (int32, int32) {
+	parts := strings.Split(size, "x")
+	if len(parts) != 2 {
+		return 1024, 1024
+	}
+	w, _ := strconv.Atoi(parts[0])
+	h, _ := strconv.Atoi(parts[1])
+	if w == 0 {
+		w = 1024
+	}
+	if h == 0 {
+		h = 1024
+	}
+	return int32(w), int32(h)
+}
+
+func extractPath(content string) string {
+	if idx := strings.Index(content, "Image saved to: "); idx >= 0 {
+		return strings.TrimSpace(content[idx+len("Image saved to: "):])
+	}
+	return ""
+}
+
+func parseProgress(content string) ImageProgressEvent {
+	var step, total int
+	fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
+	return ImageProgressEvent{Step: step, Total: total}
+}
+
+func buildResponse(imagePath, format string) ImageGenerationResponse {
+	resp := ImageGenerationResponse{
+		Created: time.Now().Unix(),
+		Data:    make([]ImageData, 1),
+	}
+
+	if imagePath == "" {
+		return resp
+	}
+
+	if format == "url" {
+		resp.Data[0].URL = "file://" + imagePath
+	} else {
+		data, err := os.ReadFile(imagePath)
+		if err == nil {
+			resp.Data[0].B64JSON = base64.StdEncoding.EncodeToString(data)
+		}
+	}
+
+	return resp
+}
+
+// HandleGenerateRequest handles Ollama /api/generate requests for image gen models.
+// This allows routes.go to delegate image generation with minimal code.
+func HandleGenerateRequest(c *gin.Context, scheduler RunnerScheduler, modelName, prompt string, keepAlive *api.Duration, streamFn func(c *gin.Context, ch chan any)) {
+	opts := api.Options{}
+
+	// Schedule runner
+	runner, err := scheduler.ScheduleImageGenRunner(c, modelName, opts, keepAlive)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	// Build completion request
+	completionReq := llm.CompletionRequest{
+		Prompt:  prompt,
+		Options: &opts,
+	}
+
+	// Stream responses via channel
+	ch := make(chan any)
+	go func() {
+		defer close(ch)
+		err := runner.Completion(c.Request.Context(), completionReq, func(resp llm.CompletionResponse) {
+			ch <- GenerateResponse{
+				Model:     modelName,
+				CreatedAt: time.Now().UTC(),
+				Response:  resp.Content,
+				Done:      resp.Done,
+			}
+		})
+		if err != nil {
+			// The error is intentionally dropped: the response stream has
+			// already started and the channel is being consumed.
+			_ = err
+		}
+	}()
+
+	streamFn(c, ch)
+}
+
+// GenerateResponse matches api.GenerateResponse structure for streaming.
+type GenerateResponse struct { + Model string `json:"model"` + CreatedAt time.Time `json:"created_at"` + Response string `json:"response"` + Done bool `json:"done"` +} diff --git a/x/imagegen/api/types.go b/x/imagegen/api/types.go new file mode 100644 index 000000000..c0e67d12f --- /dev/null +++ b/x/imagegen/api/types.go @@ -0,0 +1,31 @@ +// Package api provides OpenAI-compatible image generation API types. +package api + +// ImageGenerationRequest is an OpenAI-compatible image generation request. +type ImageGenerationRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + N int `json:"n,omitempty"` + Size string `json:"size,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` + Stream bool `json:"stream,omitempty"` +} + +// ImageGenerationResponse is an OpenAI-compatible image generation response. +type ImageGenerationResponse struct { + Created int64 `json:"created"` + Data []ImageData `json:"data"` +} + +// ImageData contains the generated image data. +type ImageData struct { + URL string `json:"url,omitempty"` + B64JSON string `json:"b64_json,omitempty"` + RevisedPrompt string `json:"revised_prompt,omitempty"` +} + +// ImageProgressEvent is sent during streaming to indicate generation progress. +type ImageProgressEvent struct { + Step int `json:"step"` + Total int `json:"total"` +} diff --git a/x/imagegen/cli.go b/x/imagegen/cli.go new file mode 100644 index 000000000..1268f449a --- /dev/null +++ b/x/imagegen/cli.go @@ -0,0 +1,539 @@ +// cli.go provides CLI commands for image generation models. +// +// TODO (jmorganca): Integrate these commands into cmd/cmd.go when stable. +// Currently these are separate to keep experimental code isolated. + +package imagegen + +import ( + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "os" + "strconv" + "strings" + "time" + + "github.com/spf13/cobra" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/envconfig" + "github.com/ollama/ollama/progress" + "github.com/ollama/ollama/readline" +) + +// ImageGenOptions holds options for image generation. +// These can be set via environment variables or interactive commands. +type ImageGenOptions struct { + Width int + Height int + Steps int + Seed int + NegativePrompt string +} + +// DefaultOptions returns the default image generation options. +func DefaultOptions() ImageGenOptions { + return ImageGenOptions{ + Width: 1024, + Height: 1024, + Steps: 9, + Seed: 0, // 0 means random + } +} + +// Show displays information about an image generation model. 
+func Show(modelName string, w io.Writer) error {
+	manifest, err := LoadManifest(modelName)
+	if err != nil {
+		return fmt.Errorf("failed to load manifest: %w", err)
+	}
+
+	// Count total size of tensor layers
+	var totalSize int64
+	for _, layer := range manifest.Manifest.Layers {
+		if layer.MediaType == "application/vnd.ollama.image.tensor" {
+			totalSize += layer.Size
+		}
+	}
+
+	// Read model_index.json for architecture
+	var architecture string
+	if data, err := manifest.ReadConfig("model_index.json"); err == nil {
+		var index struct {
+			Architecture string `json:"architecture"`
+		}
+		if json.Unmarshal(data, &index) == nil {
+			architecture = index.Architecture
+		}
+	}
+
+	// Estimate parameter count from total size (assuming BF16 = 2 bytes per param)
+	paramCount := totalSize / 2
+	paramStr := formatParamCount(paramCount)
+
+	// Print Model info
+	fmt.Fprintln(w, "  Model")
+	if architecture != "" {
+		fmt.Fprintf(w, "    %-20s %s\n", "architecture", architecture)
+	}
+	fmt.Fprintf(w, "    %-20s %s\n", "parameters", paramStr)
+	fmt.Fprintf(w, "    %-20s %s\n", "quantization", "BF16")
+	fmt.Fprintln(w)
+
+	// Print Capabilities
+	fmt.Fprintln(w, "  Capabilities")
+	fmt.Fprintf(w, "    %s\n", "image")
+	fmt.Fprintln(w)
+
+	return nil
+}
+
+// formatParamCount formats a parameter count as a human-readable string.
+func formatParamCount(count int64) string {
+	if count >= 1_000_000_000 {
+		return fmt.Sprintf("%.1fB", float64(count)/1_000_000_000)
+	}
+	if count >= 1_000_000 {
+		return fmt.Sprintf("%.1fM", float64(count)/1_000_000)
+	}
+	return fmt.Sprintf("%d", count)
+}
+
+// RegisterFlags adds image generation flags to the given command.
+// The flags are hidden since they only apply to image generation models.
+func RegisterFlags(cmd *cobra.Command) {
+	cmd.Flags().Int("width", 1024, "Image width")
+	cmd.Flags().Int("height", 1024, "Image height")
+	cmd.Flags().Int("steps", 9, "Denoising steps")
+	cmd.Flags().Int("seed", 0, "Random seed (0 for random)")
+	cmd.Flags().String("negative", "", "Negative prompt")
+	cmd.Flags().MarkHidden("width")
+	cmd.Flags().MarkHidden("height")
+	cmd.Flags().MarkHidden("steps")
+	cmd.Flags().MarkHidden("seed")
+	cmd.Flags().MarkHidden("negative")
+}
+
+// RunCLI handles the CLI for image generation models.
+// Supports the hidden flags --width, --height, --steps, --seed, and --negative.
+func RunCLI(cmd *cobra.Command, name string, prompt string, interactive bool, keepAlive *api.Duration) error {
+	// Verify it's a valid image gen model
+	if ResolveModelName(name) == "" {
+		return fmt.Errorf("unknown image generation model: %s", name)
+	}
+
+	// Get options from flags
+	opts := DefaultOptions()
+	if cmd != nil && cmd.Flags() != nil {
+		if v, err := cmd.Flags().GetInt("width"); err == nil && v > 0 {
+			opts.Width = v
+		}
+		if v, err := cmd.Flags().GetInt("height"); err == nil && v > 0 {
+			opts.Height = v
+		}
+		if v, err := cmd.Flags().GetInt("steps"); err == nil && v > 0 {
+			opts.Steps = v
+		}
+		if v, err := cmd.Flags().GetInt("seed"); err == nil && v != 0 {
+			opts.Seed = v
+		}
+		if v, err := cmd.Flags().GetString("negative"); err == nil && v != "" {
+			opts.NegativePrompt = v
+		}
+	}
+
+	if interactive {
+		return runInteractive(cmd, name, keepAlive, opts)
+	}
+
+	// One-shot generation
+	return generateImageWithOptions(cmd, name, prompt, keepAlive, opts)
+}
+
+// generateImageWithOptions generates an image with the given options.
+func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keepAlive *api.Duration, opts ImageGenOptions) error { + client, err := api.ClientFromEnvironment() + if err != nil { + return err + } + + // Build request with image gen options encoded in Options fields + // NumCtx=width, NumGPU=height, NumPredict=steps, Seed=seed + req := &api.GenerateRequest{ + Model: modelName, + Prompt: prompt, + Options: map[string]any{ + "num_ctx": opts.Width, + "num_gpu": opts.Height, + "num_predict": opts.Steps, + "seed": opts.Seed, + }, + } + if keepAlive != nil { + req.KeepAlive = keepAlive + } + + // Show loading spinner until generation starts + p := progress.NewProgress(os.Stderr) + spinner := progress.NewSpinner("") + p.Add("", spinner) + + var stepBar *progress.StepBar + var imagePath string + + err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error { + content := resp.Response + + // Handle progress updates - parse step info and switch to step bar + if strings.HasPrefix(content, "\rGenerating:") { + var step, total int + fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total) + if stepBar == nil && total > 0 { + spinner.Stop() + stepBar = progress.NewStepBar("Generating", total) + p.Add("", stepBar) + } + if stepBar != nil { + stepBar.Set(step) + } + return nil + } + + // Handle final response with image path + if resp.Done && strings.Contains(content, "Image saved to:") { + if idx := strings.Index(content, "Image saved to: "); idx >= 0 { + imagePath = strings.TrimSpace(content[idx+16:]) + } + } + + return nil + }) + + p.Stop() + if err != nil { + return err + } + + if imagePath != "" { + displayImageInTerminal(imagePath) + fmt.Printf("Image saved to: %s\n", imagePath) + } + + return nil +} + +// runInteractive runs an interactive REPL for image generation. 
+func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duration, opts ImageGenOptions) error { + client, err := api.ClientFromEnvironment() + if err != nil { + return err + } + + scanner, err := readline.New(readline.Prompt{ + Prompt: ">>> ", + Placeholder: "Describe an image to generate (/help for commands)", + }) + if err != nil { + return err + } + + if envconfig.NoHistory() { + scanner.HistoryDisable() + } + + for { + line, err := scanner.Readline() + switch { + case errors.Is(err, io.EOF): + fmt.Println() + return nil + case errors.Is(err, readline.ErrInterrupt): + if line == "" { + fmt.Println("\nUse Ctrl + d or /bye to exit.") + } + continue + case err != nil: + return err + } + + line = strings.TrimSpace(line) + if line == "" { + continue + } + + // Handle commands + switch { + case strings.HasPrefix(line, "/bye"): + return nil + case strings.HasPrefix(line, "/?"), strings.HasPrefix(line, "/help"): + printInteractiveHelp(opts) + continue + case strings.HasPrefix(line, "/set "): + if err := handleSetCommand(line[5:], &opts); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + } + continue + case strings.HasPrefix(line, "/show"): + printCurrentSettings(opts) + continue + case strings.HasPrefix(line, "/"): + fmt.Fprintf(os.Stderr, "Unknown command: %s (try /help)\n", line) + continue + } + + // Generate image with current options + req := &api.GenerateRequest{ + Model: modelName, + Prompt: line, + Options: map[string]any{ + "num_ctx": opts.Width, + "num_gpu": opts.Height, + "num_predict": opts.Steps, + "seed": opts.Seed, + }, + } + if keepAlive != nil { + req.KeepAlive = keepAlive + } + + // Show loading spinner until generation starts + p := progress.NewProgress(os.Stderr) + spinner := progress.NewSpinner("") + p.Add("", spinner) + + var stepBar *progress.StepBar + var imagePath string + + err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error { + content := resp.Response + + // Handle progress updates - parse step info and switch to step bar + if strings.HasPrefix(content, "\rGenerating:") { + var step, total int + fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total) + if stepBar == nil && total > 0 { + spinner.Stop() + stepBar = progress.NewStepBar("Generating", total) + p.Add("", stepBar) + } + if stepBar != nil { + stepBar.Set(step) + } + return nil + } + + // Handle final response with image path + if resp.Done && strings.Contains(content, "Image saved to:") { + if idx := strings.Index(content, "Image saved to: "); idx >= 0 { + imagePath = strings.TrimSpace(content[idx+16:]) + } + } + + return nil + }) + + p.Stop() + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + continue + } + + // Copy image to current directory with descriptive name + if imagePath != "" { + // Create filename from prompt (sanitized) + safeName := sanitizeFilename(line) + if len(safeName) > 50 { + safeName = safeName[:50] + } + timestamp := time.Now().Format("20060102-150405") + newName := fmt.Sprintf("%s-%s.png", safeName, timestamp) + + // Copy file to CWD + if err := copyFile(imagePath, newName); err != nil { + fmt.Fprintf(os.Stderr, "Error saving to current directory: %v\n", err) + displayImageInTerminal(imagePath) + fmt.Printf("Image saved to: %s\n", imagePath) + } else { + displayImageInTerminal(newName) + fmt.Printf("Image saved to: %s\n", newName) + } + } + + fmt.Println() + } +} + +// sanitizeFilename removes characters that aren't safe for filenames. 
+func sanitizeFilename(s string) string { + s = strings.ToLower(s) + s = strings.ReplaceAll(s, " ", "-") + // Remove any character that's not alphanumeric or hyphen + var result strings.Builder + for _, r := range s { + if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '-' { + result.WriteRune(r) + } + } + return result.String() +} + +// copyFile copies a file from src to dst. +func copyFile(src, dst string) error { + sourceFile, err := os.Open(src) + if err != nil { + return err + } + defer sourceFile.Close() + + destFile, err := os.Create(dst) + if err != nil { + return err + } + defer destFile.Close() + + _, err = io.Copy(destFile, sourceFile) + return err +} + +// printInteractiveHelp prints help for interactive mode commands. +func printInteractiveHelp(opts ImageGenOptions) { + fmt.Fprintln(os.Stderr, "Commands:") + fmt.Fprintln(os.Stderr, " /set width Set image width (current:", opts.Width, ")") + fmt.Fprintln(os.Stderr, " /set height Set image height (current:", opts.Height, ")") + fmt.Fprintln(os.Stderr, " /set steps Set denoising steps (current:", opts.Steps, ")") + fmt.Fprintln(os.Stderr, " /set seed Set random seed (current:", opts.Seed, ", 0=random)") + fmt.Fprintln(os.Stderr, " /set negative Set negative prompt") + fmt.Fprintln(os.Stderr, " /show Show current settings") + fmt.Fprintln(os.Stderr, " /bye Exit") + fmt.Fprintln(os.Stderr) + fmt.Fprintln(os.Stderr, "Or type a prompt to generate an image.") + fmt.Fprintln(os.Stderr) +} + +// printCurrentSettings prints the current image generation settings. +func printCurrentSettings(opts ImageGenOptions) { + fmt.Fprintf(os.Stderr, "Current settings:\n") + fmt.Fprintf(os.Stderr, " width: %d\n", opts.Width) + fmt.Fprintf(os.Stderr, " height: %d\n", opts.Height) + fmt.Fprintf(os.Stderr, " steps: %d\n", opts.Steps) + fmt.Fprintf(os.Stderr, " seed: %d (0=random)\n", opts.Seed) + if opts.NegativePrompt != "" { + fmt.Fprintf(os.Stderr, " negative: %s\n", opts.NegativePrompt) + } + fmt.Fprintln(os.Stderr) +} + +// handleSetCommand handles /set commands to change options. +func handleSetCommand(args string, opts *ImageGenOptions) error { + parts := strings.SplitN(args, " ", 2) + if len(parts) < 2 { + return fmt.Errorf("usage: /set