mirror of
https://github.com/ollama/ollama.git
synced 2026-01-12 00:06:57 +08:00
321 lines
7.1 KiB
Go
321 lines
7.1 KiB
Go
package transfer
|
|
|
|
import (
|
|
"cmp"
|
|
"context"
|
|
"crypto/sha256"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"path/filepath"
|
|
"slices"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"golang.org/x/sync/errgroup"
|
|
"golang.org/x/sync/semaphore"
|
|
)
|
|
|
|
// Sentinel errors produced by the watchdog in (*downloader).copy. The retry
// loop in (*downloader).download matches them with errors.Is to decide whether
// a failed attempt counts against the retry budget.
var (
	// errSlow is reported when, after the first five seconds of a transfer,
	// throughput is below 10% of the median speed of completed downloads.
	errSlow = errors.New("download too slow")

	// errStalled is reported when no bytes have been read for longer than the
	// downloader's stallTimeout.
	errStalled = errors.New("download stalled")
)
|
|
|
|
type downloader struct {
|
|
client *http.Client
|
|
baseURL string
|
|
destDir string
|
|
repository string // Repository path for blob URLs (e.g., "library/model")
|
|
token *string
|
|
getToken func(context.Context, AuthChallenge) (string, error)
|
|
userAgent string
|
|
stallTimeout time.Duration
|
|
progress *progressTracker
|
|
speeds *speedTracker
|
|
logger *slog.Logger
|
|
}
|
|
|
|
func download(ctx context.Context, opts DownloadOptions) error {
|
|
if len(opts.Blobs) == 0 {
|
|
return nil
|
|
}
|
|
|
|
// Filter existing
|
|
var blobs []Blob
|
|
var total int64
|
|
for _, b := range opts.Blobs {
|
|
if fi, _ := os.Stat(filepath.Join(opts.DestDir, digestToPath(b.Digest))); fi != nil && fi.Size() == b.Size {
|
|
if opts.Logger != nil {
|
|
opts.Logger.Debug("blob already exists", "digest", b.Digest, "size", b.Size)
|
|
}
|
|
continue
|
|
}
|
|
blobs = append(blobs, b)
|
|
total += b.Size
|
|
}
|
|
if len(blobs) == 0 {
|
|
return nil
|
|
}
|
|
|
|
token := opts.Token
|
|
d := &downloader{
|
|
client: cmp.Or(opts.Client, defaultClient),
|
|
baseURL: opts.BaseURL,
|
|
destDir: opts.DestDir,
|
|
repository: cmp.Or(opts.Repository, "library/_"),
|
|
token: &token,
|
|
getToken: opts.GetToken,
|
|
userAgent: cmp.Or(opts.UserAgent, defaultUserAgent),
|
|
stallTimeout: cmp.Or(opts.StallTimeout, defaultStallTimeout),
|
|
progress: newProgressTracker(total, opts.Progress),
|
|
speeds: &speedTracker{},
|
|
logger: opts.Logger,
|
|
}
|
|
|
|
concurrency := cmp.Or(opts.Concurrency, DefaultDownloadConcurrency)
|
|
sem := semaphore.NewWeighted(int64(concurrency))
|
|
|
|
g, ctx := errgroup.WithContext(ctx)
|
|
for _, blob := range blobs {
|
|
g.Go(func() error {
|
|
if err := sem.Acquire(ctx, 1); err != nil {
|
|
return err
|
|
}
|
|
defer sem.Release(1)
|
|
return d.download(ctx, blob)
|
|
})
|
|
}
|
|
return g.Wait()
|
|
}
|
|
|
|
func (d *downloader) download(ctx context.Context, blob Blob) error {
|
|
var lastErr error
|
|
var slowRetries int
|
|
attempt := 0
|
|
|
|
for attempt < maxRetries {
|
|
if attempt > 0 {
|
|
if err := backoff(ctx, attempt, time.Second<<uint(attempt-1)); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
start := time.Now()
|
|
n, err := d.downloadOnce(ctx, blob)
|
|
if err == nil {
|
|
if s := time.Since(start).Seconds(); s > 0 {
|
|
d.speeds.record(float64(blob.Size) / s)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
d.progress.add(-n) // rollback
|
|
|
|
switch {
|
|
case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
|
|
return err
|
|
case errors.Is(err, errStalled):
|
|
// Don't count stall retries against limit
|
|
case errors.Is(err, errSlow):
|
|
if slowRetries++; slowRetries >= 3 {
|
|
attempt++ // Only count after 3 slow retries
|
|
}
|
|
default:
|
|
attempt++
|
|
}
|
|
lastErr = err
|
|
}
|
|
return fmt.Errorf("%w: %v", errMaxRetriesExceeded, lastErr)
|
|
}
|
|
|
|
func (d *downloader) downloadOnce(ctx context.Context, blob Blob) (int64, error) {
|
|
if d.logger != nil {
|
|
d.logger.Debug("downloading blob", "digest", blob.Digest, "size", blob.Size)
|
|
}
|
|
|
|
baseURL, _ := url.Parse(d.baseURL)
|
|
u, err := d.resolve(ctx, fmt.Sprintf("%s/v2/%s/blobs/%s", d.baseURL, d.repository, blob.Digest))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
|
|
req.Header.Set("User-Agent", d.userAgent)
|
|
// Add auth only for same-host (not CDN)
|
|
if u.Host == baseURL.Host && *d.token != "" {
|
|
req.Header.Set("Authorization", "Bearer "+*d.token)
|
|
}
|
|
|
|
resp, err := d.client.Do(req)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return 0, fmt.Errorf("status %d", resp.StatusCode)
|
|
}
|
|
|
|
return d.save(ctx, blob, resp.Body)
|
|
}
|
|
|
|
func (d *downloader) save(ctx context.Context, blob Blob, r io.Reader) (int64, error) {
|
|
dest := filepath.Join(d.destDir, digestToPath(blob.Digest))
|
|
tmp := dest + ".tmp"
|
|
os.MkdirAll(filepath.Dir(dest), 0o755)
|
|
|
|
f, err := os.Create(tmp)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
defer f.Close()
|
|
setSparse(f)
|
|
|
|
h := sha256.New()
|
|
n, err := d.copy(ctx, f, r, h)
|
|
if err != nil {
|
|
os.Remove(tmp)
|
|
return n, err
|
|
}
|
|
f.Close()
|
|
|
|
if got := fmt.Sprintf("sha256:%x", h.Sum(nil)); got != blob.Digest {
|
|
os.Remove(tmp)
|
|
return n, fmt.Errorf("digest mismatch")
|
|
}
|
|
if n != blob.Size {
|
|
os.Remove(tmp)
|
|
return n, fmt.Errorf("size mismatch")
|
|
}
|
|
return n, os.Rename(tmp, dest)
|
|
}
|
|
|
|
func (d *downloader) copy(ctx context.Context, dst io.Writer, src io.Reader, h io.Writer) (int64, error) {
|
|
var n int64
|
|
var lastRead atomic.Int64
|
|
lastRead.Store(time.Now().UnixNano())
|
|
start := time.Now()
|
|
|
|
ctx, cancel := context.WithCancelCause(ctx)
|
|
defer cancel(nil)
|
|
|
|
go func() {
|
|
tick := time.NewTicker(time.Second)
|
|
defer tick.Stop()
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-tick.C:
|
|
if time.Since(time.Unix(0, lastRead.Load())) > d.stallTimeout {
|
|
cancel(errStalled)
|
|
return
|
|
}
|
|
if e := time.Since(start); e > 5*time.Second {
|
|
if m := d.speeds.median(); m > 0 && float64(n)/e.Seconds() < m*0.1 {
|
|
cancel(errSlow)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
|
|
buf := make([]byte, 32*1024)
|
|
for {
|
|
if err := ctx.Err(); err != nil {
|
|
if c := context.Cause(ctx); c != nil {
|
|
return n, c
|
|
}
|
|
return n, err
|
|
}
|
|
|
|
nr, err := src.Read(buf)
|
|
if nr > 0 {
|
|
lastRead.Store(time.Now().UnixNano())
|
|
dst.Write(buf[:nr])
|
|
h.Write(buf[:nr])
|
|
d.progress.add(int64(nr))
|
|
n += int64(nr)
|
|
}
|
|
if err == io.EOF {
|
|
return n, nil
|
|
}
|
|
if err != nil {
|
|
return n, err
|
|
}
|
|
}
|
|
}
|
|
|
|
func (d *downloader) resolve(ctx context.Context, rawURL string) (*url.URL, error) {
|
|
u, _ := url.Parse(rawURL)
|
|
for range 10 {
|
|
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
|
|
req.Header.Set("User-Agent", d.userAgent)
|
|
if *d.token != "" {
|
|
req.Header.Set("Authorization", "Bearer "+*d.token)
|
|
}
|
|
|
|
resp, err := d.client.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
resp.Body.Close()
|
|
|
|
switch resp.StatusCode {
|
|
case http.StatusOK:
|
|
return u, nil
|
|
case http.StatusUnauthorized:
|
|
if d.getToken == nil {
|
|
return nil, fmt.Errorf("unauthorized")
|
|
}
|
|
ch := parseAuthChallenge(resp.Header.Get("WWW-Authenticate"))
|
|
if *d.token, err = d.getToken(ctx, ch); err != nil {
|
|
return nil, err
|
|
}
|
|
case http.StatusTemporaryRedirect, http.StatusFound, http.StatusMovedPermanently:
|
|
loc, _ := resp.Location()
|
|
if loc.Host != u.Host {
|
|
return loc, nil
|
|
}
|
|
u = loc
|
|
default:
|
|
return nil, fmt.Errorf("status %d", resp.StatusCode)
|
|
}
|
|
}
|
|
return nil, fmt.Errorf("too many redirects")
|
|
}
|
|
|
|
// speedTracker keeps a sliding window of recently recorded download speeds and
// reports their median. All methods are guarded by mu.
type speedTracker struct {
	mu     sync.Mutex
	speeds []float64 // most recent samples, oldest first, at most 30
}

// record appends v to the window, discarding the oldest sample once the window
// holds more than 30 entries.
func (s *speedTracker) record(v float64) {
	const window = 30

	s.mu.Lock()
	defer s.mu.Unlock()

	s.speeds = append(s.speeds, v)
	if extra := len(s.speeds) - window; extra > 0 {
		// Exactly one sample is added per call, so extra is never more than 1.
		s.speeds = s.speeds[extra:]
	}
}

// median returns the median of the recorded speeds, or 0 while fewer than five
// samples are available. For an even number of samples the upper of the two
// middle values is returned. The stored window is left in recording order.
func (s *speedTracker) median() float64 {
	const minSamples = 5

	s.mu.Lock()
	defer s.mu.Unlock()

	count := len(s.speeds)
	if count < minSamples {
		return 0
	}

	// Sort a private copy so the window keeps its chronological order.
	ordered := append(make([]float64, 0, count), s.speeds...)
	slices.Sort(ordered)
	return ordered[count/2]
}
|
|
|
|
// defaultStallTimeout is the stall timeout used when DownloadOptions does not
// specify one: a transfer that delivers no bytes for this long is treated as
// stalled.
const defaultStallTimeout time.Duration = 10 * time.Second
|