//go:build mlx

package mlx

/*
#include <stdio.h>
#include <string.h>

#include "mlx/c/array.h"
#include "mlx/c/ops.h"

// Derived from https://github.com/ml-explore/mlx/blob/main/mlx/io/gguf_quants.cpp
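
// Note on layout (a reading of the code below and of mlx's gguf_quants.cpp; treat it
// as a sketch rather than a spec of either format): a GGUF Q4 block stores weight j
// in the low nibble and weight j+16 in the high nibble of quant byte j, while MLX's
// dequantize expects consecutive weights packed two per byte (low nibble first)
// inside little-endian uint32 words. unpack_32_4 repacks one 32-weight block into
// 16 bytes in that consecutive order.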
void unpack_32_4(uint8_t* data, int8_t* dst) {
  memset(dst, 0, 16);
  for (int j = 0; j < 16; ++j) {
    uint8_t x = (data[j + 2] & 0x0F); // j+2 to skip scale bytes.
    if (j % 2 != 0) {
      x <<= 4;
    }
    dst[j / 2] += x;
  }
  // Last 16 weights are in the higher bits
  for (int j = 0; j < 16; ++j) {
    uint8_t x = (data[j + 2] >> 4);
    if (j % 2 != 0) {
      x <<= 4;
    }
    dst[8 + j / 2] += x;
  }
}

// Extracts (weight, scales, biases) from Q4_0 tensors.
// Data layout is: |16 bit scale|32 x 4bit weights|.
void extract_q4_0_data(
    uint8_t* data,
    mlx_array* weights_arr,
    mlx_array* scales_arr,
    mlx_array* biases_arr) {
  const uint64_t bytes_per_block = 18; // 2 bytes scale, 32x0.5 byte weights
  uint8_t* weights = mlx_array_data_uint8(*weights_arr);
  float16_t* scales = mlx_array_data_float16(*scales_arr);
  float16_t* biases = mlx_array_data_float16(*biases_arr);
  for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) {
    scales[i] = *((float16_t*)data);
    biases[i] = -8 * scales[i];
    unpack_32_4(data, weights);
    weights += 16;
    data += bytes_per_block;
  }
}
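
// Illustrative sketch (not called anywhere): scalar dequantization of a single
// Q4_0 block, spelling out the convention the extractor above relies on. Each
// 4-bit value q in [0, 15] represents w = scale * (q - 8), which is the same as
// w = scale * q + bias with bias = -8 * scale, i.e. the (scales, biases) pair
// later handed to mlx_dequantize. For instance, scale = 0.5 and q = 11 decode to
// w = 0.5 * (11 - 8) = 1.5. Q4_1 differs only in storing the bias explicitly in
// the block header instead of deriving it from the scale.
void dequant_q4_0_block_reference(const uint8_t* block, float16_t* out) {
  float16_t scale = *((const float16_t*)block);
  for (int j = 0; j < 16; ++j) {
    uint8_t b = block[j + 2];                    // skip the 2-byte fp16 scale
    out[j]      = scale * ((int)(b & 0x0F) - 8); // low nibble -> weight j
    out[j + 16] = scale * ((int)(b >> 4) - 8);   // high nibble -> weight j+16
  }
}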

// Extracts (weight, scales, biases) from Q4_1 tensors.
// Data layout is: |16 bit scale|16 bit bias|32 x 4bit weights|.
void extract_q4_1_data(
    uint8_t* data,
    mlx_array* weights_arr,
    mlx_array* scales_arr,
    mlx_array* biases_arr) {
  const uint64_t bytes_per_block = 20; // 2 bytes scale, 2 bytes bias, 32x0.5 byte weights
  uint8_t* weights = mlx_array_data_uint8(*weights_arr);
  float16_t* scales = mlx_array_data_float16(*scales_arr);
  float16_t* biases = mlx_array_data_float16(*biases_arr);
  for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) {
    scales[i] = *((float16_t*)data);
    biases[i] = *((float16_t*)(data) + 1);
    // Pass data + 2 so the fixed 2-byte skip inside unpack_32_4 lands past the
    // 4-byte scale+bias header rather than on the bias bytes.
    unpack_32_4(data + 2, weights);
    weights += 16;
    data += bytes_per_block;
  }
}

// Extracts (weight, scales, biases) from Q8_0 tensors.
// Data layout is: |16 bit scale|32 x 8bit weights|.
void extract_q8_0_data(
    uint8_t* data,
    mlx_array* weights_arr,
    mlx_array* scales_arr,
    mlx_array* biases_arr) {
  const uint64_t weights_per_block = 32;
  const uint64_t bytes_per_block = 34; // 2 bytes scale, 32x1 byte weights
  uint8_t* weights = mlx_array_data_uint8(*weights_arr);
  float16_t* scales = mlx_array_data_float16(*scales_arr);
  float16_t* biases = mlx_array_data_float16(*biases_arr);
  for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) {
    uint8_t* block_data = data + i * bytes_per_block;
    scales[i] = *((float16_t*)block_data);
    biases[i] = -128 * scales[i];
    for (int64_t j = 0; j < weights_per_block; ++j) {
      uint8_t x = block_data[j + 2]; // j+2 to skip the scale bytes.
      // Original data is in int8_t, so we add a bias of -128 and invert the
      // first bit.
      x ^= 1 << 7;
      weights[i * weights_per_block + j] = x;
    }
  }
}
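
// Q8_0 note (summary of the loop above): each stored weight is a signed int8 q with
// w = scale * q. Flipping the sign bit maps q to the unsigned value x = q + 128, and
// with bias = -128 * scale the dequantized value is scale * x + bias = scale * q.
// For example, scale = 0.25 and q = -3 are stored as x = 125, and
// 0.25 * 125 - 32 = -0.75 = 0.25 * -3.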

// Derived from ggml-quants.c

#define QK_K 256

// 6-bit quantization
// weight is represented as x = a * q
// 16 blocks of 16 elements each
// Effectively 6.5625 bits per weight
typedef struct {
  uint8_t ql[QK_K/2];      // quants, lower 4 bits
  uint8_t qh[QK_K/4];      // quants, upper 2 bits
  int8_t  scales[QK_K/16]; // scales, quantized with 8 bits
  uint16_t d;              // super-block scale
} block_q6_K;
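
// Explanatory note (derived from the struct above): the "6.5625 bits per weight"
// figure follows from the block size. Per 256-weight super-block, ql is 128 bytes,
// qh is 64 bytes, scales is 16 bytes and d is 2 bytes, i.e. 210 bytes = 1680 bits
// for 256 weights = 6.5625 bits/weight. Each weight's six bits are split as 4 low
// bits in ql and 2 high bits in qh, and the loop below multiplies (q - 32) by the
// per-16-weight int8 scale and the fp16 super-block scale d.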
void dequant_row_q6_K(const void * restrict vx, void * restrict vy, int k) {
  const int64_t nb = k / QK_K;
  block_q6_K *x = (block_q6_K *)vx;
  float16_t* y = (float16_t *)vy;

  for (int i = 0; i < nb; i++) {
    float16_t d = 0.0;
    memcpy(&d, &x[i].d, sizeof(d));

    const uint8_t * restrict ql = x[i].ql;
    const uint8_t * restrict qh = x[i].qh;
    const int8_t * restrict sc = x[i].scales;

    for (int n = 0; n < QK_K; n += 128) {
      for (int l = 0; l < 32; ++l) {
        int is = l/16;
        const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
        const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
        const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
        const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
        y[l + 0] = d * sc[is + 0] * q1;
        y[l + 32] = d * sc[is + 2] * q2;
        y[l + 64] = d * sc[is + 4] * q3;
        y[l + 96] = d * sc[is + 6] * q4;
      }
      y += 128;
      ql += 64;
      qh += 32;
      sc += 8;
    }
  }
}

#define K_SCALE_SIZE 12
#define GGML_COMMON_AGGR_U
#define GGML_COMMON_AGGR_S

// 4-bit quantization
// 8 blocks of 32 elements each
// weight is represented as x = a * q + b
// Effectively 4.5 bits per weight
typedef struct {
  union {
    struct {
      uint16_t d;    // super-block scale for quantized scales
      uint16_t dmin; // super-block scale for quantized mins
    } GGML_COMMON_AGGR_S;
    uint16_t dm;
  } GGML_COMMON_AGGR_U;
  uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
  uint8_t qs[QK_K/2];           // 4-bit quants
} block_q4_K;
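
// Explanatory note (derived from the struct above): the "4.5 bits per weight"
// figure follows from the block size. Per 256-weight super-block, qs is 128 bytes,
// scales is 12 bytes and d/dmin are 2 bytes each, i.e. 144 bytes = 1152 bits for
// 256 weights = 4.5 bits/weight. The 12 scale bytes hold a 6-bit scale and a 6-bit
// min for each of the 8 sub-blocks of 32 weights (8 * 2 * 6 bits = 96 bits), which
// get_scale_min_k4 below unpacks.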
static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
  if (j < 4) {
    *d = q[j] & 63; *m = q[j + 4] & 63;
  } else {
    *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
    *m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
  }
}

void dequant_row_q4_K(const void * restrict vx, void * restrict vy, int k) {
  block_q4_K *x = (block_q4_K *)vx;
  float16_t* y = (float16_t *)vy;
  const int nb = k / QK_K;

  for (int i = 0; i < nb; i++) {
    const uint8_t * q = x[i].qs;
    float16_t d = 0.0;
    memcpy(&d, &x[i].d, sizeof(d));
    float16_t min = 0.0;
    memcpy(&min, &x[i].dmin, sizeof(d));

    int is = 0;
    uint8_t sc, m;
    for (int j = 0; j < QK_K; j += 64) {
      get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
      const float16_t d1 = d * sc; const float16_t m1 = min * m;
      get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
      const float16_t d2 = d * sc; const float16_t m2 = min * m;
      for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
      for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
      q += 32; is += 2;
    }
  }
}
*/
import "C"
|
|
|
|
import (
|
|
"fmt"
|
|
"unsafe"
|
|
|
|
"github.com/x448/float16"
|
|
)
|
|
|
|
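
// gguf_load_quantized converts a block-quantized GGUF tensor into MLX's packed
// representation (uint32 words of packed weights plus fp16 scales and biases) and
// dequantizes it on the given stream. The dtype values handled here follow the
// GGUF/GGML tensor-type codes: 2 = Q4_0, 3 = Q4_1, 8 = Q8_0; the K-quant codes are
// handled by load_k_quantized below. (Orientation note; see the C helpers above for
// the exact block layouts.)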
func gguf_load_quantized(data unsafe.Pointer, name string, final_shape []C.int, dtype uint32, stream C.mlx_stream) (r C.mlx_array, err error) {
	shape := append([]C.int{}, final_shape...)
	var weights_per_byte C.int
	if dtype == 2 || dtype == 3 {
		weights_per_byte = 2
	} else if dtype == 8 {
		weights_per_byte = 1
	} else {
		return r, fmt.Errorf("unsupported tensor type %d", dtype)
	}

	weights_per_block := C.int(32)
	if shape[len(shape)-1]%weights_per_block != 0 {
		return r, fmt.Errorf("[load_gguf] tensor has incompatible last dim shape: %d", shape[len(shape)-1])
	}
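
	// Each uint32 word of the packed weights holds weights_per_byte*4 weights
	// (8 at 4 bits, 4 at 8 bits), so the last dimension of the packed array shrinks
	// by that factor, while scales and biases get one entry per 32-weight block.
	// For example (illustrative numbers only), a [4096, 4096] Q4_0 tensor becomes
	// uint32 weights of shape [4096, 512] with fp16 scales and biases of shape
	// [4096, 128].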
	weights_shape := append([]C.int{}, shape...)
	weights_shape[len(weights_shape)-1] /= (weights_per_byte * 4)
	w_nbytes := C.int(unsafe.Sizeof(uint32(0)))
	for i := range weights_shape {
		w_nbytes *= weights_shape[i]
	}
	w_data := make([]byte, w_nbytes)
	cbytes := C.CBytes(w_data)
	defer C.free(cbytes)
	weights := C.mlx_array_new_data(
		cbytes,
		&weights_shape[0],
		C.int(len(weights_shape)),
		C.MLX_UINT32,
	)

	// For scales and bias
	shape[len(shape)-1] = shape[len(shape)-1] / weights_per_block
	sb_nbytes := C.int(unsafe.Sizeof(float16.Float16(0)))
	for i := range shape {
		sb_nbytes *= shape[i]
	}

	s_data := make([]byte, sb_nbytes)
	cbytes = C.CBytes(s_data)
	defer C.free(cbytes)
	scales := C.mlx_array_new_data(
		cbytes,
		&shape[0],
		C.int(len(shape)),
		C.MLX_FLOAT16,
	)
	b_data := make([]byte, sb_nbytes)
	cbytes = C.CBytes(b_data)
	defer C.free(cbytes)
	biases := C.mlx_array_new_data(
		cbytes,
		&shape[0],
		C.int(len(shape)),
		C.MLX_FLOAT16,
	)
	var bits C.int
	switch dtype {
	case 2:
		C.extract_q4_0_data((*C.uint8_t)(data), &weights, &scales, &biases)
		bits = 4
	case 3:
		C.extract_q4_1_data((*C.uint8_t)(data), &weights, &scales, &biases)
		bits = 4
	case 8:
		C.extract_q8_0_data((*C.uint8_t)(data), &weights, &scales, &biases)
		bits = 8
	}
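
	// Group size 32 matches the GGUF block size, so each block's (scale, bias) pair
	// applies to exactly its 32 weights; mlx_dequantize then reconstructs the
	// affine values, roughly w = q*scale + bias per group. (Descriptive note; the
	// precise semantics are those of mlx_dequantize.)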
	groupSize := C.mlx_optional_int{value: 32, has_value: true}
	bitsOpt := C.mlx_optional_int{value: bits, has_value: true}
	var dtypeOpt C.mlx_optional_dtype // has_value defaults to false
	C.mlx_dequantize(
		&r,
		weights,
		scales,
		biases,
		groupSize,
		bitsOpt,
		nil, // TODO mode
		dtypeOpt,
		stream,
	)
	C.mlx_array_free(weights)
	C.mlx_array_free(scales)
	C.mlx_array_free(biases)

	return r, nil
}
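
// load_k_quantized handles the K-quant GGUF types (dtype 12 = Q4_K, 14 = Q6_K).
// K-quant super-blocks carry per-sub-block scales and mins, which do not map onto
// the single scale/bias per 32-weight group that mlx_dequantize consumes, so this
// path dequantizes the whole tensor to float16 on the CPU and uploads it as a
// dense array instead. (Orientation note.)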
func load_k_quantized(data unsafe.Pointer, name string, shape []C.int, dtype uint32, stream C.mlx_stream) (r C.mlx_array, err error) {
	size := 1
	for _, d := range shape {
		size *= int(d)
	}
	fdata := make([]float16.Float16, size)
	switch dtype {
	case 14:
		C.dequant_row_q6_K(
			data,
			unsafe.Pointer(&fdata[0]),
			C.int(size),
		)

	case 12:
		C.dequant_row_q4_K(
			data,
			unsafe.Pointer(&fdata[0]),
			C.int(size),
		)
	default:
		return r, fmt.Errorf("unsupported K quant")
	}

	r = C.mlx_array_new_data(
		unsafe.Pointer(&fdata[0]),
		&shape[0],
		C.int(len(shape)),
		C.MLX_FLOAT16,
	)
	return r, nil
}