mirror of
https://github.com/ollama/ollama.git
synced 2026-01-12 00:06:57 +08:00
This reverts commit 6a62b894c7.
This commit is contained in:
@@ -20,7 +20,7 @@ type tokenizeFunc func(context.Context, string) ([]int, error)
|
||||
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
|
||||
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
|
||||
// latest message and 2) system messages
|
||||
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *api.ThinkValue, truncate bool) (prompt string, images []llm.ImageData, _ error) {
|
||||
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (prompt string, images []llm.ImageData, _ error) {
|
||||
var system []api.Message
|
||||
|
||||
// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
|
||||
@@ -59,7 +59,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||
}
|
||||
}
|
||||
|
||||
if truncate && ctxLen > opts.NumCtx {
|
||||
if ctxLen > opts.NumCtx {
|
||||
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
|
||||
break
|
||||
} else {
|
||||
|
||||
@@ -27,18 +27,16 @@ func TestChatPrompt(t *testing.T) {
|
||||
visionModel := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
model Model
|
||||
limit int
|
||||
truncate bool
|
||||
msgs []api.Message
|
||||
name string
|
||||
model Model
|
||||
limit int
|
||||
msgs []api.Message
|
||||
expect
|
||||
}{
|
||||
{
|
||||
name: "messages",
|
||||
model: visionModel,
|
||||
limit: 64,
|
||||
truncate: true,
|
||||
name: "messages",
|
||||
model: visionModel,
|
||||
limit: 64,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry!"},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
@@ -49,10 +47,9 @@ func TestChatPrompt(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "truncate messages",
|
||||
model: visionModel,
|
||||
limit: 1,
|
||||
truncate: true,
|
||||
name: "truncate messages",
|
||||
model: visionModel,
|
||||
limit: 1,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry!"},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
@@ -63,10 +60,9 @@ func TestChatPrompt(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "truncate messages with image",
|
||||
model: visionModel,
|
||||
limit: 64,
|
||||
truncate: true,
|
||||
name: "truncate messages with image",
|
||||
model: visionModel,
|
||||
limit: 64,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry!"},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
@@ -80,10 +76,9 @@ func TestChatPrompt(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "truncate messages with images",
|
||||
model: visionModel,
|
||||
limit: 64,
|
||||
truncate: true,
|
||||
name: "truncate messages with images",
|
||||
model: visionModel,
|
||||
limit: 64,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
@@ -97,10 +92,9 @@ func TestChatPrompt(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "messages with images",
|
||||
model: visionModel,
|
||||
limit: 2048,
|
||||
truncate: true,
|
||||
name: "messages with images",
|
||||
model: visionModel,
|
||||
limit: 2048,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
@@ -115,10 +109,9 @@ func TestChatPrompt(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "message with image tag",
|
||||
model: visionModel,
|
||||
limit: 2048,
|
||||
truncate: true,
|
||||
name: "message with image tag",
|
||||
model: visionModel,
|
||||
limit: 2048,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry! [img]", Images: []api.ImageData{[]byte("something")}},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
@@ -133,10 +126,9 @@ func TestChatPrompt(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "messages with interleaved images",
|
||||
model: visionModel,
|
||||
limit: 2048,
|
||||
truncate: true,
|
||||
name: "messages with interleaved images",
|
||||
model: visionModel,
|
||||
limit: 2048,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry!"},
|
||||
{Role: "user", Images: []api.ImageData{[]byte("something")}},
|
||||
@@ -153,10 +145,9 @@ func TestChatPrompt(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "truncate message with interleaved images",
|
||||
model: visionModel,
|
||||
limit: 1024,
|
||||
truncate: true,
|
||||
name: "truncate message with interleaved images",
|
||||
model: visionModel,
|
||||
limit: 1024,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry!"},
|
||||
{Role: "user", Images: []api.ImageData{[]byte("something")}},
|
||||
@@ -172,10 +163,9 @@ func TestChatPrompt(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "message with system prompt",
|
||||
model: visionModel,
|
||||
limit: 2048,
|
||||
truncate: true,
|
||||
name: "message with system prompt",
|
||||
model: visionModel,
|
||||
limit: 2048,
|
||||
msgs: []api.Message{
|
||||
{Role: "system", Content: "You are the Test Who Lived."},
|
||||
{Role: "user", Content: "You're a test, Harry!"},
|
||||
@@ -187,10 +177,9 @@ func TestChatPrompt(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "out of order system",
|
||||
model: visionModel,
|
||||
limit: 2048,
|
||||
truncate: true,
|
||||
name: "out of order system",
|
||||
model: visionModel,
|
||||
limit: 2048,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry!"},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
@@ -202,10 +191,9 @@ func TestChatPrompt(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple images same prompt",
|
||||
model: visionModel,
|
||||
limit: 2048,
|
||||
truncate: true,
|
||||
name: "multiple images same prompt",
|
||||
model: visionModel,
|
||||
limit: 2048,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "Compare these two pictures of hotdogs", Images: []api.ImageData{[]byte("one hotdog"), []byte("two hotdogs")}},
|
||||
},
|
||||
@@ -214,20 +202,6 @@ func TestChatPrompt(t *testing.T) {
|
||||
images: [][]byte{[]byte("one hotdog"), []byte("two hotdogs")},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no truncate with limit exceeded",
|
||||
model: visionModel,
|
||||
limit: 10,
|
||||
truncate: false,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry!"},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
@@ -235,7 +209,7 @@ func TestChatPrompt(t *testing.T) {
|
||||
model := tt.model
|
||||
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
|
||||
think := false
|
||||
prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think}, tt.truncate)
|
||||
prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think})
|
||||
if tt.error == nil && err != nil {
|
||||
t.Fatal(err)
|
||||
} else if tt.error != nil && err != tt.error {
|
||||
|
||||
@@ -468,12 +468,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
var sb strings.Builder
|
||||
defer close(ch)
|
||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
Format: req.Format,
|
||||
Options: opts,
|
||||
Shift: req.Shift == nil || *req.Shift,
|
||||
Truncate: req.Truncate == nil || *req.Truncate,
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
Format: req.Format,
|
||||
Options: opts,
|
||||
}, func(cr llm.CompletionResponse) {
|
||||
res := api.GenerateResponse{
|
||||
Model: req.Model,
|
||||
@@ -535,7 +533,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
|
||||
ch <- res
|
||||
}); err != nil {
|
||||
ch <- err
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -549,11 +547,13 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
sbThinking.WriteString(t.Thinking)
|
||||
sbContent.WriteString(t.Response)
|
||||
r = t
|
||||
case api.StatusError:
|
||||
c.JSON(t.StatusCode, gin.H{"error": t.ErrorMessage})
|
||||
return
|
||||
case error:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": t.Error()})
|
||||
case gin.H:
|
||||
msg, ok := t["error"].(string)
|
||||
if !ok {
|
||||
msg = "unexpected error format in response"
|
||||
}
|
||||
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
|
||||
return
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
|
||||
@@ -1618,18 +1618,6 @@ func streamResponse(c *gin.Context, ch chan any) {
|
||||
return false
|
||||
}
|
||||
|
||||
if statusError, ok := val.(api.StatusError); ok {
|
||||
c.Header("Content-Type", "application/json")
|
||||
c.AbortWithStatusJSON(statusError.StatusCode, gin.H{"error": statusError.ErrorMessage})
|
||||
return false
|
||||
}
|
||||
|
||||
if err, ok := val.(error); ok {
|
||||
c.Header("Content-Type", "application/json")
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return false
|
||||
}
|
||||
|
||||
bts, err := json.Marshal(val)
|
||||
if err != nil {
|
||||
slog.Info(fmt.Sprintf("streamResponse: json.Marshal failed with %s", err))
|
||||
@@ -1947,8 +1935,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
truncate := req.Truncate == nil || *req.Truncate
|
||||
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
|
||||
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think)
|
||||
if err != nil {
|
||||
slog.Error("chat prompt error", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -1997,12 +1984,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
defer close(ch)
|
||||
|
||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
Format: req.Format,
|
||||
Options: opts,
|
||||
Shift: req.Shift == nil || *req.Shift,
|
||||
Truncate: truncate,
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
Format: req.Format,
|
||||
Options: opts,
|
||||
}, func(r llm.CompletionResponse) {
|
||||
res := api.ChatResponse{
|
||||
Model: req.Model,
|
||||
@@ -2075,7 +2060,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
|
||||
ch <- res
|
||||
}); err != nil {
|
||||
ch <- err
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -2093,11 +2078,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
if len(req.Tools) > 0 {
|
||||
toolCalls = append(toolCalls, t.Message.ToolCalls...)
|
||||
}
|
||||
case api.StatusError:
|
||||
c.JSON(t.StatusCode, gin.H{"error": t.ErrorMessage})
|
||||
return
|
||||
case error:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": t.Error()})
|
||||
case gin.H:
|
||||
msg, ok := t["error"].(string)
|
||||
if !ok {
|
||||
msg = "unexpected error format in response"
|
||||
}
|
||||
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
|
||||
return
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
|
||||
|
||||
@@ -594,58 +594,6 @@ func TestGenerateChat(t *testing.T) {
|
||||
t.Errorf("final tool call mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("status error non-streaming", func(t *testing.T) {
|
||||
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
||||
return api.StatusError{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Status: "Service Unavailable",
|
||||
ErrorMessage: "model is overloaded",
|
||||
}
|
||||
}
|
||||
|
||||
stream := false
|
||||
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "test",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("expected status 503, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(w.Body.String(), `{"error":"model is overloaded"}`); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("status error streaming", func(t *testing.T) {
|
||||
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
||||
return api.StatusError{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Status: "Too Many Requests",
|
||||
ErrorMessage: "rate limit exceeded",
|
||||
}
|
||||
}
|
||||
|
||||
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "test",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
})
|
||||
|
||||
if w.Code != http.StatusTooManyRequests {
|
||||
t.Errorf("expected status 429, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(w.Body.String(), `{"error":"rate limit exceeded"}`); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerate(t *testing.T) {
|
||||
@@ -1020,55 +968,6 @@ func TestGenerate(t *testing.T) {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("status error non-streaming", func(t *testing.T) {
|
||||
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
||||
return api.StatusError{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Status: "Service Unavailable",
|
||||
ErrorMessage: "model is overloaded",
|
||||
}
|
||||
}
|
||||
|
||||
streamRequest := false
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test",
|
||||
Prompt: "Hello!",
|
||||
Stream: &streamRequest,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("expected status 503, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(w.Body.String(), `{"error":"model is overloaded"}`); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("status error streaming", func(t *testing.T) {
|
||||
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
||||
return api.StatusError{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Status: "Too Many Requests",
|
||||
ErrorMessage: "rate limit exceeded",
|
||||
}
|
||||
}
|
||||
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test",
|
||||
Prompt: "Hello!",
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusTooManyRequests {
|
||||
t.Errorf("expected status 429, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(w.Body.String(), `{"error":"rate limit exceeded"}`); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestChatWithPromptEndingInThinkTag(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user