mirror of
https://github.com/ollama/ollama.git
synced 2026-01-12 00:06:57 +08:00
108 lines
2.3 KiB
Go
108 lines
2.3 KiB
Go
//go:build mlx
|
|
|
|
package imagegen
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"image"
|
|
"image/png"
|
|
"os"
|
|
"path/filepath"
|
|
|
|
"github.com/ollama/ollama/x/imagegen/mlx"
|
|
)
|
|
|
|
// SaveImage saves an MLX array as a PNG image file.
|
|
// Expected format: [B, C, H, W] with values in [0, 1] range and C=3 (RGB).
|
|
func SaveImage(arr *mlx.Array, path string) error {
|
|
img, err := ArrayToImage(arr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if filepath.Ext(path) != ".png" {
|
|
path = path + ".png"
|
|
}
|
|
|
|
f, err := os.Create(path)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer f.Close()
|
|
|
|
return png.Encode(f, img)
|
|
}
|
|
|
|
// EncodeImageBase64 encodes an MLX array as a base64-encoded PNG.
|
|
// Expected format: [B, C, H, W] with values in [0, 1] range and C=3 (RGB).
|
|
func EncodeImageBase64(arr *mlx.Array) (string, error) {
|
|
img, err := ArrayToImage(arr)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
var buf bytes.Buffer
|
|
if err := png.Encode(&buf, img); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return base64.StdEncoding.EncodeToString(buf.Bytes()), nil
|
|
}
|
|
|
|
// ArrayToImage converts an MLX array to a Go image.RGBA.
|
|
// Expected format: [B, C, H, W] with values in [0, 1] range and C=3 (RGB).
|
|
func ArrayToImage(arr *mlx.Array) (*image.RGBA, error) {
|
|
shape := arr.Shape()
|
|
if len(shape) != 4 {
|
|
return nil, fmt.Errorf("expected 4D array [B, C, H, W], got %v", shape)
|
|
}
|
|
|
|
// Transform to [H, W, C] for image conversion
|
|
img := mlx.Squeeze(arr, 0)
|
|
img = mlx.Transpose(img, 1, 2, 0)
|
|
img = mlx.Contiguous(img)
|
|
mlx.Eval(img)
|
|
|
|
imgShape := img.Shape()
|
|
H := int(imgShape[0])
|
|
W := int(imgShape[1])
|
|
C := int(imgShape[2])
|
|
|
|
if C != 3 {
|
|
img.Free()
|
|
return nil, fmt.Errorf("expected 3 channels (RGB), got %d", C)
|
|
}
|
|
|
|
// Copy to CPU and free GPU memory
|
|
data := img.Data()
|
|
img.Free()
|
|
|
|
// Write directly to Pix slice (faster than SetRGBA)
|
|
goImg := image.NewRGBA(image.Rect(0, 0, W, H))
|
|
pix := goImg.Pix
|
|
for y := 0; y < H; y++ {
|
|
for x := 0; x < W; x++ {
|
|
srcIdx := (y*W + x) * C
|
|
dstIdx := (y*W + x) * 4
|
|
pix[dstIdx+0] = uint8(clampF(data[srcIdx+0]*255+0.5, 0, 255))
|
|
pix[dstIdx+1] = uint8(clampF(data[srcIdx+1]*255+0.5, 0, 255))
|
|
pix[dstIdx+2] = uint8(clampF(data[srcIdx+2]*255+0.5, 0, 255))
|
|
pix[dstIdx+3] = 255
|
|
}
|
|
}
|
|
|
|
return goImg, nil
|
|
}
|
|
|
|
// clampF limits v to the closed interval [min, max].
//
// Values below min yield min and values above max yield max. NaN yields min:
// NaN fails every ordered comparison, so without this case it would be
// returned unchanged, and converting NaN to an integer type gives an
// implementation-dependent result.
func clampF(v, min, max float32) float32 {
	switch {
	case v != v: // true only for NaN
		return min
	case v < min:
		return min
	case v > max:
		return max
	}
	return v
}
|