Mirror of https://github.com/ollama/ollama.git
Synced 2026-01-12 00:06:57 +08:00
* WIP - MLX backend with gemma3
* MLX: add cmake and go tag build toggles

  To build the new MLX backend code:

    cmake --preset MLX
    cmake --build --preset MLX --parallel
    cmake --install build --component MLX
    go build -tags mlx .

  Note: the main.go entrypoint for the MLX engine will change in a follow up commit.
* add experimental image generation runtime
* add experimental image generation runtime
* MLX: wire up cuda build for linux
* MLX: get dependencies correct and dedup

  This is still too large for a unified github artifact, but is now "correct" for the mlx_cuda_v13 directory.
* fix relative link bug in dedup
* Add darwin build and readme
* add go build tag for mlx dependent code and wire up build_darwin.sh
* lint cleanup
* macos: build mlx for x86

  This will be CPU only.
* cuda build instructions and fix drift from mlx bump
* stale comment
* Delete agent helper doc
* Clean up readme.md
* Revise README for tokenizer clarity and details

  Updated README to clarify tokenizer functionality and removed correctness section.

---------

Co-authored-by: jmorganca <jmorganca@gmail.com>
798 lines
25 KiB
Go
package kvcache

// import (
//     "errors"
//     "fmt"
//     "log/slog"
//     "math"
//     "slices"

//     "github.com/ollama/ollama/ml"
//     "github.com/ollama/ollama/model/input"
// )

// type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)

// // Causal cache stores K and V tensors according to their position in the
// // sequence. Returns the history and a mask for attending to past tokens
// //
// // The tensors are of shape embed dim, kv heads, batch size
// // The mask is of shape history size, batch size
// type Causal struct {
//     DType ml.DType

//     // swaWindowSize is the number of tokens that will be included in the mask
//     // during attention operations. swaMemorySize is the number of tokens that
//     // will be retained in memory for partial prefix caching. Set to math.MaxInt32
//     // for unlimited or if sliding window attention is not being used.
//     swaWindowSize int32
//     swaMemorySize int32

//     chunkSize int32

//     opts CausalOptions

//     // maxBatch is the largest batch that we might receive
//     maxBatch int

//     // config controls mostly backend-specific optimizations
//     config *ml.CacheConfig

//     // ** current forward pass **

//     // size of the current batch
//     curBatchSize int

//     // locations for data storage for this batch
//     curLoc ml.Tensor

//     // mask of the cache as used by this batch
//     curMask ml.Tensor

//     // the active layer for Get and Put
//     curLayer int

//     // locations in the cache that are needed for this batch
//     curCellRange cellRange

//     // curSequences is the sequences corresponding to this pass's entries in the cache
//     curSequences []int

//     // curPositions is the positions corresponding to this pass's entries in the cache
//     curPositions []int32

//     // ** cache metadata **

//     // for each possible location in the cache, stores the position and set of sequences
//     // that reference the data there
//     cells []cacheCell

//     // maps from sequence to the range of locations where it is stored in the cache
//     cellRanges map[int]cellRange

//     // ** cache data storage **

//     shiftFn shiftFn
//     backend ml.Backend
//     ctxs    map[int]ml.Context
//     keys, values map[int]ml.Tensor

//     kHeadDims, vHeadDims, numKVHeads map[int]int
// }

// type cacheCell struct {
//     pos       int32
//     sequences []int
// }

// type cellRange struct {
//     min int
//     max int
// }

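// Illustration of the bookkeeping above (not part of the implementation; the slot
// numbers and positions are made up). A cache that has stored positions 0..3 of
// sequence 0 in slots 2..5 would look like:
//
//     c.cells[2] = cacheCell{pos: 0, sequences: []int{0}}
//     c.cells[3] = cacheCell{pos: 1, sequences: []int{0}}
//     c.cells[4] = cacheCell{pos: 2, sequences: []int{0}}
//     c.cells[5] = cacheCell{pos: 3, sequences: []int{0}}
//     c.cellRanges[0] = cellRange{min: 2, max: 5}
//
// A slot whose sequences list is empty is free and can be handed out again by findLocs.
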
// func NewCausalCache(shift shiftFn) *Causal {
//     return &Causal{
//         shiftFn:    shift,
//         ctxs:       make(map[int]ml.Context),
//         keys:       make(map[int]ml.Tensor),
//         values:     make(map[int]ml.Tensor),
//         kHeadDims:  make(map[int]int),
//         vHeadDims:  make(map[int]int),
//         numKVHeads: make(map[int]int),
//     }
// }

// func NewSWACache(windowSize int32, shift shiftFn) *Causal {
//     return &Causal{
//         swaWindowSize: windowSize,
//         shiftFn:       shift,
//         ctxs:          make(map[int]ml.Context),
//         keys:          make(map[int]ml.Tensor),
//         values:        make(map[int]ml.Tensor),
//         kHeadDims:     make(map[int]int),
//         vHeadDims:     make(map[int]int),
//         numKVHeads:    make(map[int]int),
//     }
// }

// func NewSWAMemCache(windowSize int32, memorySize int32, shift shiftFn) *Causal {
//     return &Causal{
//         swaWindowSize: windowSize,
//         swaMemorySize: memorySize,
//         shiftFn:       shift,
//         ctxs:          make(map[int]ml.Context),
//         keys:          make(map[int]ml.Tensor),
//         values:        make(map[int]ml.Tensor),
//         kHeadDims:     make(map[int]int),
//         vHeadDims:     make(map[int]int),
//         numKVHeads:    make(map[int]int),
//     }
// }

// func NewChunkedAttentionCache(chunkSize int32, shift shiftFn) *Causal {
//     return &Causal{
//         chunkSize:  chunkSize,
//         shiftFn:    shift,
//         ctxs:       make(map[int]ml.Context),
//         keys:       make(map[int]ml.Tensor),
//         values:     make(map[int]ml.Tensor),
//         kHeadDims:  make(map[int]int),
//         vHeadDims:  make(map[int]int),
//         numKVHeads: make(map[int]int),
//     }
// }

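// Minimal usage sketch of the constructors above. The model type and its shift
// method below are hypothetical and not part of this package; the only thing
// taken from this file is the shiftFn signature.
//
//     type model struct {
//         cache *kvcache.Causal
//     }
//
//     // shift is typically where a model re-applies rotary position embeddings
//     // to keys that have been moved to new positions.
//     func (m *model) shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
//         return key, nil
//     }
//
//     // full causal attention:
//     //     m.cache = kvcache.NewCausalCache(m.shift)
//     // sliding window attention with a 1024-token window:
//     //     m.cache = kvcache.NewSWACache(1024, m.shift)
//     // chunked attention with 8192-token chunks:
//     //     m.cache = kvcache.NewChunkedAttentionCache(8192, m.shift)
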
// func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
//     if c.config == nil {
//         var config ml.CacheConfig
//         if cc, ok := backend.(ml.BackendCacheConfig); ok {
//             config = cc.CacheConfig()
//         }
//         c.config = &config
//     }

//     if c.config.CachePadding == 0 {
//         c.config.CachePadding = 1
//     }

//     if c.config.MaskBatchPadding == 0 {
//         c.config.MaskBatchPadding = 1
//     }

//     // TODO what types do we handle here?
//     // if c.config.MaskDType == ml.DTypeOther {
//     //     c.config.MaskDType = ml.DTypeFloat32
//     // }

//     if c.swaWindowSize == 0 {
//         c.swaWindowSize = math.MaxInt32
//     }
//     if c.swaMemorySize == 0 {
//         c.swaMemorySize = c.swaWindowSize
//     }
//     // We will allocate space in the cache for the stop token, which won't be part of a follow on
//     // sequence, so allocate an extra token of storage to ensure that we can jump back without
//     // causing a cache break. As an optimization, only do this when we have parallel sequences
//     // because the extra token will live in the batch buffer and won't get overwritten if we
//     // only have a single sequence.
//     if c.swaMemorySize != math.MaxInt32 && maxSequences > 1 {
//         c.swaMemorySize = max(c.swaMemorySize, c.swaWindowSize+1)
//     }
//     if int(c.swaMemorySize) >= capacity {
//         c.swaMemorySize = math.MaxInt32
//     }

//     if c.swaMemorySize < c.swaWindowSize {
//         panic(fmt.Errorf("sliding window memory (%v) must be at least as large as the window (%v)", c.swaMemorySize, c.swaWindowSize))
//     }

//     var cacheSize int
//     if c.swaMemorySize == math.MaxInt32 {
//         cacheSize = maxSequences * capacity
//     } else {
//         cacheSize = (maxSequences * int(c.swaMemorySize)) + maxBatch
//     }
//     cacheSize = roundUp(cacheSize, c.config.CachePadding)
//     c.cells = make([]cacheCell, cacheSize)

//     c.DType = dtype
//     c.cellRanges = make(map[int]cellRange)
//     c.backend = backend
//     c.maxBatch = maxBatch
// }

// func (c *Causal) SetConfig(config ml.CacheConfig) {
//     if c.config != nil {
//         panic("config cannot be changed after being previously set, either by the model or backend")
//     }

//     c.config = &config
// }

// func (c *Causal) Close() {
//     slog.Info("XXX Causal.Close called", "number of contexts", len(c.ctxs))
//     for _, ctx := range c.ctxs {
//         ctx.Close()
//     }
// }

// func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
//     slog.Info("XXX Causal.StartForward", "cell count", len(c.cells), "prior batch size", c.curBatchSize, "positions", len(batch.Positions), "reserve", reserve, "batch", batch)
//     // panic("XXX Causal.StartForward")
//     c.curBatchSize = len(batch.Positions)
//     c.curSequences = batch.Sequences
//     c.curPositions = batch.Positions
//     c.opts.Except = nil

//     var locs []int32
//     if !reserve {
//         c.updateSlidingWindow()

//         var err error
//         locs, err = c.findLocs()
//         if err != nil {
//             return err
//         }
//         slog.Info("XXX Causal.StartForward", "findLocs len", len(locs))

//         for i, pos := range batch.Positions {
//             seq := batch.Sequences[i]
//             loc := int(locs[i])

//             c.cells[loc] = cacheCell{pos: pos, sequences: []int{seq}}

//             seqRange, ok := c.cellRanges[seq]
//             if !ok {
//                 seqRange = newRange()
//             }

//             seqRange.min = min(seqRange.min, loc)
//             c.curCellRange.min = min(c.curCellRange.min, loc)

//             seqRange.max = max(seqRange.max, loc)
//             c.curCellRange.max = max(c.curCellRange.max, loc)

//             c.cellRanges[seq] = seqRange
//         }
//     } else {
//         // If we are reserving memory, don't update any of the cache metadata but set the size
//         // to the worst case.
//         locs = make([]int32, c.curBatchSize)
//         for i := range locs {
//             locs[i] = int32(i)
//         }
//         c.curCellRange.min = 0
//         c.curCellRange.max = len(c.cells) - 1
//     }

//     // XXX Building up the locs for what's already processed (if any)
//     dummyLocs := []int{}
//     c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
//     c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1

//     for i := range c.curBatchSize {
//         enabled := !slices.Contains(c.opts.Except, i)
//         for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
//             if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
//                 (enabled && c.cells[j].pos > c.curPositions[i]) ||
//                 c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize ||
//                 c.cells[j].pos < c.curPositions[i]-c.swaWindowSize {
//                 // mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
//             } else {
//                 if len(dummyLocs) == 0 || dummyLocs[len(dummyLocs)-1] != i {
//                     dummyLocs = append(dummyLocs, i)
//                 }
//             }
//         }
//     }
//     slog.Info("XXX Causal.StartForward calculated locations", "locs", dummyLocs)

//     slog.Info("XXX Causal.StartForward", "locs", locs)
//     c.curLoc = ctx.Input().FromInts(locs, len(locs))
//     c.curMask = c.buildMask(ctx)

//     return nil
// }

// func newRange() cellRange {
//     return cellRange{
//         min: math.MaxInt,
//         max: 0,
//     }
// }

// // Returns a slice of locations where each token in the batch should be stored
// func (c *Causal) findLocs() ([]int32, error) {
//     loc := make([]int32, 0, c.curBatchSize)

//     for i := range c.cells {
//         if len(c.cells[i].sequences) == 0 {
//             loc = append(loc, int32(i))
//             if len(loc) >= c.curBatchSize {
//                 return loc, nil
//             }
//         }
//     }

//     return nil, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
// }

// func (c *Causal) updateSlidingWindow() {
//     c.curCellRange = newRange()

//     if c.swaMemorySize == math.MaxInt32 {
//         for _, seq := range c.curSequences {
//             if seqRange, ok := c.cellRanges[seq]; ok {
//                 c.curCellRange.min = min(c.curCellRange.min, seqRange.min)
//                 c.curCellRange.max = max(c.curCellRange.max, seqRange.max)
//             }
//         }

//         return
//     }

//     type lowestPosition struct {
//         pos      int32
//         curBatch bool
//     }

//     // create a map of unique sequences to the lowest position in that sequence
//     lowestPos := make(map[int]lowestPosition)
//     for i := range c.curPositions {
//         seq := c.curSequences[i]

//         lowest, ok := lowestPos[seq]
//         if !ok {
//             lowest = lowestPosition{pos: c.curPositions[i], curBatch: true}
//         } else if c.curPositions[i] < lowest.pos {
//             lowest.pos = c.curPositions[i]
//         }

//         lowestPos[seq] = lowest
//     }

//     // for any sequences that are not part of this batch, clean up any tokens
//     // that are no longer needed after the processing of the previous
//     // batch
//     for seq, seqRange := range c.cellRanges {
//         if _, ok := lowestPos[seq]; !ok {
//             var last int32
//             for i := seqRange.min; i <= seqRange.max; i++ {
//                 if slices.Contains(c.cells[i].sequences, seq) {
//                     last = max(last, c.cells[i].pos)
//                 }
//             }

//             lowestPos[seq] = lowestPosition{pos: last + 1, curBatch: false}
//         }
//     }

//     // delete any entries that are beyond the window of the oldest position in the sequence
//     for seq, lowest := range lowestPos {
//         oldRange, ok := c.cellRanges[seq]
//         if !ok {
//             continue
//         }

//         newRange := newRange()

//         for i := oldRange.min; i <= oldRange.max; i++ {
//             if slices.Contains(c.cells[i].sequences, seq) {
//                 if c.cells[i].pos < lowest.pos-c.swaMemorySize {
//                     c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
//                 } else {
//                     newRange.min = min(newRange.min, i)
//                     newRange.max = max(newRange.max, i)
//                 }
//                 if lowest.curBatch && c.cells[i].pos >= lowest.pos-c.swaWindowSize {
//                     c.curCellRange.min = min(c.curCellRange.min, i)
//                     c.curCellRange.max = max(c.curCellRange.max, i)
//                 }
//             }
//         }

//         c.cellRanges[seq] = newRange
//     }
// }

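// Worked example of the pruning above (illustrative numbers only). Assuming
// swaWindowSize = 4, swaMemorySize = 4, and a new batch whose lowest position
// for sequence 0 is 10:
//
//     // before: sequence 0 occupies cells holding positions 5..9
//     // lowest.pos - c.swaMemorySize = 10 - 4 = 6
//     // cell with pos 5:    5 < 6 -> sequence 0 removed; the slot becomes free once
//     //                     no other sequence references it
//     // cells with pos 6..9: kept, and (being >= 10-4) folded into curCellRange so
//     //                     the current batch can attend to them
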
// func roundDown(length, pad int) int {
//     return (length / pad) * pad
// }

// func roundUp(length, pad int) int {
//     return ((length + pad - 1) / pad) * pad
// }

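// Quick arithmetic check of the padding helpers above:
//
//     roundDown(10, 4) == 8
//     roundUp(10, 4)   == 12
//     roundUp(8, 4)    == 8   // already aligned values are unchanged
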
// // Builds a mask of history x batch indicating whether for each token in the batch the
// // token in the history should apply. This is based on both the sequence and causality (the
// // position of the history is not ahead of the token in the batch).
// func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
//     // Align and pad the two dimensions as required by the backend
//     batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)

//     c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
//     c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1

//     length := c.curCellRange.max - c.curCellRange.min + 1

//     mask := make([]float32, batchSize*length)

//     for i := range c.curBatchSize {
//         enabled := !slices.Contains(c.opts.Except, i)
//         for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
//             if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
//                 (enabled && c.cells[j].pos > c.curPositions[i]) ||
//                 c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize ||
//                 c.cells[j].pos < c.curPositions[i]-c.swaWindowSize {
//                 mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
//             }
//         }
//     }

//     // Mask out any padding tokens we added. For padding that we added to the cache history, this
//     // has already been masked out because the sequence doesn't match.
//     for i := c.curBatchSize * length; i < len(mask); i++ {
//         mask[i] = float32(math.Inf(-1))
//     }

//     maskTensor := ctx.Input().FromFloats(mask, batchSize, length)

//     // if c.config.MaskDType != ml.DTypeFloat32 {
//     //     maskTensor = maskTensor.Cast(ctx, c.config.MaskDType)
//     // }

//     slog.Info("XXX Causal.buildMask", "c.curBatchSize", c.curBatchSize, "c.config.MaskBatchPadding", c.config.MaskBatchPadding, "c.curCellRange.min", c.curCellRange.min, "c.curCellRange.max", c.curCellRange.max, "size", len(mask), "shape", []int{1, batchSize, length})

//     return maskTensor
// }

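// Small worked example of the mask layout above (illustrative; assumes no padding,
// no sliding window and no chunking). With a single batch token at position 2 of
// sequence 0, and curCellRange covering four cells holding positions 0..3 of
// sequence 0, the one mask row is:
//
//     cell pos:  0     1     2     3
//     mask:      0     0     0    -Inf   // the cell at position 3 is in the future
//
// Rows are contiguous, so the entry for batch token i and cache column j
// (relative to curCellRange.min) lives at mask[i*length+j].
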
// func (c *Causal) SetLayer(layer int) {
//     c.curLayer = layer
// }

// type CausalOptions struct {
//     // Except lists batch indices for which the causal mask should not be generated
//     Except []int
// }

// // SetCausal disables causal mask generation for a particular range of indices in
// // the current batch for subsequent calls to Get. The state resets for the next forward pass.
// func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
//     if !slices.Equal(c.opts.Except, opts.Except) {
//         c.opts = opts
//         if ctx != nil {
//             c.curMask = c.buildMask(ctx)
//         }
//     }
// }

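// Hypothetical usage sketch for SetCausal (not taken from a real model): a model
// that wants bidirectional attention over the first three tokens of the batch,
// e.g. an image embedding span, might wrap its attention calls like this:
//
//     cache.SetCausal(ctx, kvcache.CausalOptions{Except: []int{0, 1, 2}})
//     // ... attention over the span that should be visible bidirectionally ...
//     cache.SetCausal(ctx, kvcache.CausalOptions{})
//
// Passing an empty CausalOptions restores normal causal masking; StartForward also
// clears Except at the beginning of every forward pass.
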
// func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
//     key := c.keys[c.curLayer]
//     value := c.values[c.curLayer]

//     kHeadDim := c.kHeadDims[c.curLayer]
//     vHeadDim := c.vHeadDims[c.curLayer]
//     numKVHeads := c.numKVHeads[c.curLayer]
//     // rowSize := numKVHeads * c.curBatchSize
//     // cachedSize := c.curMask.Dim(1)
//     cachedSize := c.curLoc.Dim(0)
//     // kCellSize := kHeadDim * numKVHeads
//     // vCellSize := vHeadDim * numKVHeads

//     slog.Info("XXX Causal.Get full cache", "key", key)
//     slog.Info("XXX Causal.Get full cache", "value", value)
//     slog.Info("XXX Causal.Get full cache", "curloc", c.curLoc)
//     slog.Info("XXX Causal.Get", "curMask", c.curMask)
//     slog.Info("XXX Causal.Get", "kHeadDim", kHeadDim, "numKVHeads", numKVHeads, "cachedSize", cachedSize, "kHeadDim", kHeadDim)
//     // panic("XXX")

//     // fmt.Fprintln(os.Stderr, key.ToString())
//     // panic("full cache value")

//     // TODO we should use TakeAxes to gather the cells from curLoc, but for now to be consistent with GGML, just grab a larger chunk and mask
//     key = key.TakeAxes(ctx, c.curLoc, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim)
//     // key = key.AsStrided(ctx, []int{1, numKVHeads, cachedSize, kHeadDim}, []int{}, rowSize*c.curCellRange.min)

//     // slog.Info("XXX Causal.Get after AsStrided", "key", key)
//     // panic("XXX")

//     // if c.config.PermutedV {
//     //     panic("permuted")
//     //     // TODO not converted
//     //     vHeadDim := value.Dim(1)
//     //     elemSize := value.Stride(2)

//     //     value = value.AsStrided(ctx,
//     //         []int{numKVHeads, vHeadDim, cachedSize},
//     //         []int{value.Stride(0), value.Stride(1)},
//     //         elemSize*c.curCellRange.min,
//     //     )
//     // } else {
//     //     vHeadDim := c.vHeadDims[c.curLayer]
//     //     rowSize := value.Stride(2)
//     //     slog.Info("XXX Causal.Get before AsStrided", "vHeadDim", vHeadDim, "rowSize", rowSize)
//     //     panic("XXX")

//     // TODO we should use TakeAxes to gather the cells from curLoc, but for now to be consistent with GGML, just grab a larger chunk and mask
//     value = value.TakeAxes(ctx, c.curLoc, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim)
//     // value = value.AsStrided(ctx, []int{1, numKVHeads, cachedSize, vHeadDim}, []int{}, rowSize*c.curCellRange.min)

//     // slog.Info("XXX Causal.Get after AsStrided", "value", value)
//     // panic("XXX")

//     // }

//     // // TODO The mask changes from X,X to 1,X, and with the Row-order change
//     // // the 1 becomes trailing and messes up later operations
//     // // This isn't the right solution, but works around it...
//     // if c.curMask.Dim(1) == 1 {
//     //     return key, value, c.curMask.Transpose(ctx, 1, 0, 2, 3)
//     // }
//     // fmt.Fprintln(os.Stderr, key.ToString())
//     // fmt.Fprintln(os.Stderr, value.ToString())
//     // panic("XXX")
//     slog.Info("XXX Mask", "curLayer", c.curLayer, "shape", c.curMask.Shape())

//     return key, value, c.curMask
// }

// func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
//     kHeadDim := key.Dim(3)
//     vHeadDim := value.Dim(3)
//     numKVHeads := key.Dim(1)
//     batchSize := key.Dim(2)
//     kCellSize := kHeadDim * numKVHeads
//     vCellSize := vHeadDim * numKVHeads

//     // slog.Info("XXX Causal.Put", "key", key, "value", value)
//     slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize)
//     // panic("XXX")

//     if c.curBatchSize != batchSize {
//         panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
//     }

//     // slog.Info("XXX", "c.ctxs", c.ctxs, "c.curLayer", c.curLayer, "backend", c.backend)
//     if _, ok := c.ctxs[c.curLayer]; !ok {
//         slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer)
//         c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
//     }

//     if _, ok := c.keys[c.curLayer]; !ok {
//         slog.Info("XXX Causal.Put allocating keys", "c.curLayer", c.curLayer, "shape", []int{len(c.cells), kCellSize})

//         c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), kCellSize)
//         c.kHeadDims[c.curLayer] = kHeadDim
//         c.vHeadDims[c.curLayer] = vHeadDim
//         c.numKVHeads[c.curLayer] = numKVHeads
//     }

//     if _, ok := c.values[c.curLayer]; !ok {
//         // if c.config.PermutedV {
//         //     c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, numKVHeads, vHeadDim, len(c.cells))
//         // } else {
//         c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), vCellSize)
//         // }
//     }

//     key = key.Reshape(ctx, batchSize, 1, kCellSize) //.Contiguous(ctx, false) // TODO contiguous may not be needed

//     // slog.Info("XXX Causal.Put after reshape", "keyCache", keyCache)
//     // panic("XXX")
//     // curLoc := 0 // TODO c.curLoc is now a tensor
//     // kSize := numKVHeads * kHeadDim
//     // vSize := numKVHeads * vHeadDim
//     // start := []int{int(curLoc), 0}
//     // kStop := []int{int(curLoc + batchSize), int(kSize)}
//     // vStop := []int{int(curLoc + batchSize), int(vSize)}
//     // strides := []int{1, 1}

//     // slog.Info("XXX Causal.Put Key SliceUpdate", "keyCache", keyCache)
//     // slog.Info("XXX Causal.Put Key SliceUpdate", "key", key)

//     // slog.Info("XXX Causal.Put Key SliceUpdate", "start", start, "kStop", kStop, "strides", strides)

//     // ctx.Forward(c.keys[c.curLayer].SliceUpdate(ctx, key, start, kStop, strides))
//     ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLoc}, key, []int{0}))
//     // fmt.Fprintln(os.Stderr, keyCache.ToString())
//     // panic("input value")

//     // fmt.Fprintln(os.Stderr, t.ToString())
//     // panic("XXX")

//     // if c.config.PermutedV {
//     //     panic("permuted")
//     //     // TODO not adjusted
//     //     value = value.Reshape(ctx, vHeadDim*numKVHeads, 1, batchSize)
//     //     value = value.Transpose(ctx, 2, 0, 1, 3)

//     //     valueCache := c.values[c.curLayer]
//     //     valueCache = valueCache.Reshape(ctx, 1, len(c.cells), vHeadDim*numKVHeads)

//     //     ctx.Forward(valueCache.SliceUpdate(ctx, value, start, vStop, strides))
//     // } else {
//     value = value.Reshape(ctx, batchSize, 1, vCellSize) //.Contiguous(ctx, false) // TODO contiguous may not be needed
//     // slog.Info("XXX Causal.Put Value SliceUpdate", "valueCache", valueCache)
//     // slog.Info("XXX Causal.Put Value SliceUpdate", "value", value)
//     // slog.Info("XXX Causal.Put Value SliceUpdate", "start", start, "vStop", vStop, "strides", strides)

//     ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLoc}, value, []int{0}))
//     // }
//     // fmt.Fprintln(os.Stderr, c.keys[c.curLayer].ToString())
//     // fmt.Fprintln(os.Stderr, c.values[c.curLayer].ToString())
//     // panic("XXX")

// }

// func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
//     seqRange := newRange()

//     for i := range c.cells {
//         // Remove the contents of dstSeq so that we only have the copied prefix, metadata will be reset at the end
//         if slices.Contains(c.cells[i].sequences, dstSeq) {
//             c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == dstSeq })
//         }

//         if slices.Contains(c.cells[i].sequences, srcSeq) && c.cells[i].pos < len {
//             c.cells[i].sequences = append(c.cells[i].sequences, dstSeq)
//             if i < seqRange.min {
//                 seqRange.min = i
//             }
//             if i > seqRange.max {
//                 seqRange.max = i
//             }
//         }
//     }

//     c.cellRanges[dstSeq] = seqRange
// }

// func (c *Causal) CanResume(seq int, pos int32) bool {
//     if c.swaMemorySize == math.MaxInt32 {
//         return true
//     }

//     seqRange, ok := c.cellRanges[seq]
//     if !ok {
//         return false
//     }

//     // for sliding window, check that the window of the new sequence is contained in
//     // the window of what we are storing
//     var first int32 = math.MaxInt32
//     var last int32 = -1
//     for i := seqRange.min; i <= seqRange.max; i++ {
//         if slices.Contains(c.cells[i].sequences, seq) {
//             first = min(first, c.cells[i].pos)
//             last = max(last, c.cells[i].pos)
//         }
//     }

//     if last == -1 {
//         return false
//     }

//     posWindowStart := max(0, pos-c.swaWindowSize)
//     return posWindowStart >= first && pos <= last+1
// }

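// Worked example for the check above (illustrative numbers): with swaWindowSize = 4
// and cached positions first = 6, last = 9 for this sequence:
//
//     CanResume(seq, 10) -> max(0, 10-4) = 6 >= 6 and 10 <= 9+1 -> true
//     CanResume(seq, 8)  -> max(0, 8-4)  = 4 <  6               -> false
//
// That is, decoding can continue immediately after the stored window, but resuming at
// a position whose window reaches before the stored history requires reprocessing.
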
// func (c *Causal) shift(seq int, beginIndex, offset int32) error {
//     if c.shiftFn == nil {
//         return ErrNotSupported
//     }

//     seqRange := c.cellRanges[seq]

//     for start := seqRange.min; start <= seqRange.max; start += c.maxBatch {
//         size := min(seqRange.max-start+1, c.maxBatch)
//         offsets := make([]int32, size)

//         var batchFirst, batchLast int

//         batchFirst = -1
//         for i := range offsets {
//             cell := c.cells[start+i]

//             if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
//                 offsets[i] = offset
//                 if batchFirst < 0 {
//                     batchFirst = i
//                 }
//                 batchLast = i
//             }
//         }

//         if batchFirst < 0 {
//             continue
//         }

//         offsets = offsets[batchFirst : batchLast+1]

//         slog.Info("XXX Causal.shift creating new temporary context")
//         ctx := c.backend.NewContext()
//         kShift := ctx.Input().FromInts(offsets, len(offsets))

//         for i, key := range c.keys {
//             if key == nil {
//                 continue
//             }

//             kHeadDim := key.Dim(2)
//             numKVHeads := key.Dim(1)
//             rowSize := key.Stride(0)

//             key = key.AsStrided(ctx,
//                 []int{len(offsets), numKVHeads, kHeadDim},
//                 []int{key.Stride(0), key.Stride(1)},
//                 rowSize*(start+batchFirst),
//             )

//             roped, err := c.shiftFn(ctx, i, key, kShift)
//             if err != nil {
//                 ctx.Close()
//                 return err
//             }

//             ctx.Forward(roped.Copy(ctx, key))
//         }

//         ctx.Compute()
//         ctx.Close()
//     }

//     return nil
// }

// func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
//     // TODO(jessegross): We should check to see if removing the middle of the sequence will
//     // cause the sliding window to encompass tokens that we no longer have. If so, then we
//     // should return an error, which will trigger the runner to evaluate the full history and
//     // rebuild the window. However, if we have multimodal inputs in our history, this reuse
//     // results in use after free, so we don't do it for now.

//     var offset int32
//     if endIndex != math.MaxInt32 {
//         offset = beginIndex - endIndex
//     }

//     seqRange := newRange()

//     for i := range c.cells {
//         if slices.Contains(c.cells[i].sequences, seq) {
//             if c.cells[i].pos >= beginIndex && c.cells[i].pos < endIndex {
//                 c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
//             } else {
//                 if c.cells[i].pos >= endIndex {
//                     if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) {
//                         return errors.New("shifting cells shared by multiple sequences not supported")
//                     }

//                     c.cells[i].pos += offset
//                 }
//                 if i < seqRange.min {
//                     seqRange.min = i
//                 }
//                 if i > seqRange.max {
//                     seqRange.max = i
//                 }
//             }
//         }
//     }

//     if seqRange == newRange() {
//         delete(c.cellRanges, seq)
//         return nil
//     }

//     c.cellRanges[seq] = seqRange

//     if endIndex != math.MaxInt32 {
//         err := c.shift(seq, endIndex+offset, offset)
//         if err != nil {
//             return err
//         }
//     }

//     return nil
// }
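// Worked example of the offset arithmetic in Remove (illustrative numbers):
// Remove(seq, 5, 8) drops positions 5..7 and computes offset = 5 - 8 = -3, so a
// surviving cell at position 8 becomes position 5, 9 becomes 6, and so on; shift()
// is then asked to adjust (via shiftFn) the keys of every cell whose new position
// is >= 8 + (-3) = 5.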