mirror of
https://github.com/patriceckhart/zot.git
synced 2026-06-26 13:26:33 +02:00
Normalize Bedrock tool results
This commit is contained in:
parent
7a7bf0b52c
commit
84fd98ea74
2 changed files with 125 additions and 1 deletions
|
|
@ -216,6 +216,61 @@ type bedrockRequest struct {
|
|||
} `json:"toolConfig,omitempty"`
|
||||
}
|
||||
|
||||
func normalizeBedrockToolResults(msgs []Message) []Message {
|
||||
resultByID := map[string]ToolResultBlock{}
|
||||
for _, m := range msgs {
|
||||
for _, c := range m.Content {
|
||||
if tr, ok := c.(ToolResultBlock); ok {
|
||||
if _, exists := resultByID[tr.CallID]; !exists {
|
||||
resultByID[tr.CallID] = tr
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
out := make([]Message, 0, len(msgs))
|
||||
for _, m := range msgs {
|
||||
copy := m
|
||||
copy.Content = nil
|
||||
var toolCalls []ToolCallBlock
|
||||
for _, c := range m.Content {
|
||||
switch v := c.(type) {
|
||||
case ToolResultBlock:
|
||||
// Bedrock requires toolResult blocks immediately after the
|
||||
// assistant toolUse they answer. Reinsert them from the
|
||||
// assistant pass below instead of preserving their original
|
||||
// location, which may be separated by user text in active
|
||||
// sessions.
|
||||
continue
|
||||
case ToolCallBlock:
|
||||
toolCalls = append(toolCalls, v)
|
||||
}
|
||||
copy.Content = append(copy.Content, c)
|
||||
}
|
||||
if len(copy.Content) > 0 {
|
||||
out = append(out, copy)
|
||||
}
|
||||
if m.Role != RoleAssistant || len(toolCalls) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
results := make([]Content, 0, len(toolCalls))
|
||||
for _, tc := range toolCalls {
|
||||
if tr, ok := resultByID[tc.ID]; ok {
|
||||
results = append(results, tr)
|
||||
continue
|
||||
}
|
||||
results = append(results, ToolResultBlock{
|
||||
CallID: tc.ID,
|
||||
IsError: true,
|
||||
Content: []Content{TextBlock{Text: "tool call did not complete before the next user message"}},
|
||||
})
|
||||
}
|
||||
out = append(out, Message{Role: RoleTool, Content: results, Time: m.Time})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (c *bedrockClient) buildRequest(req Request) (*bedrockRequest, error) {
|
||||
out := &bedrockRequest{}
|
||||
if req.System != "" {
|
||||
|
|
@ -226,7 +281,7 @@ func (c *bedrockClient) buildRequest(req Request) (*bedrockRequest, error) {
|
|||
if out.InferenceConfig.MaxTokens == 0 {
|
||||
out.InferenceConfig.MaxTokens = 4096
|
||||
}
|
||||
for _, m := range req.Messages {
|
||||
for _, m := range normalizeBedrockToolResults(req.Messages) {
|
||||
role := string(m.Role)
|
||||
if role == "tool" {
|
||||
role = "user"
|
||||
|
|
@ -267,6 +322,9 @@ func (c *bedrockClient) buildRequest(req Request) (*bedrockRequest, error) {
|
|||
})
|
||||
}
|
||||
}
|
||||
if len(bm.Content) == 0 {
|
||||
continue
|
||||
}
|
||||
out.Messages = append(out.Messages, bm)
|
||||
}
|
||||
if len(req.Tools) > 0 {
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package provider
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
|
@ -82,6 +83,71 @@ func TestReadAWSCredentialsFile(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestNormalizeBedrockToolResultsMovesResultsAdjacentToToolUse(t *testing.T) {
|
||||
msgs := []Message{
|
||||
{Role: RoleAssistant, Content: []Content{ToolCallBlock{ID: "tool-1", Name: "edit", Arguments: json.RawMessage(`{"path":"a"}`)}}},
|
||||
{Role: RoleUser, Content: []Content{TextBlock{Text: "what's wrong"}}},
|
||||
{Role: RoleTool, Content: []Content{ToolResultBlock{CallID: "tool-1", Content: []Content{TextBlock{Text: "edited"}}}}},
|
||||
}
|
||||
|
||||
out := normalizeBedrockToolResults(msgs)
|
||||
if len(out) != 3 {
|
||||
t.Fatalf("got %d messages, want 3: %+v", len(out), out)
|
||||
}
|
||||
if out[0].Role != RoleAssistant {
|
||||
t.Fatalf("first role = %s, want assistant", out[0].Role)
|
||||
}
|
||||
if out[1].Role != RoleTool {
|
||||
t.Fatalf("second role = %s, want tool", out[1].Role)
|
||||
}
|
||||
tr, ok := out[1].Content[0].(ToolResultBlock)
|
||||
if !ok {
|
||||
t.Fatalf("second message content = %T, want ToolResultBlock", out[1].Content[0])
|
||||
}
|
||||
if tr.CallID != "tool-1" || tr.IsError {
|
||||
t.Fatalf("unexpected tool result: %+v", tr)
|
||||
}
|
||||
if out[2].Role != RoleUser {
|
||||
t.Fatalf("third role = %s, want user", out[2].Role)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeBedrockToolResultsInjectsMissingResult(t *testing.T) {
|
||||
msgs := []Message{
|
||||
{Role: RoleAssistant, Content: []Content{ToolCallBlock{ID: "tool-1", Name: "edit", Arguments: json.RawMessage(`{}`)}}},
|
||||
{Role: RoleUser, Content: []Content{TextBlock{Text: "what's wrong"}}},
|
||||
}
|
||||
|
||||
out := normalizeBedrockToolResults(msgs)
|
||||
if len(out) != 3 {
|
||||
t.Fatalf("got %d messages, want 3: %+v", len(out), out)
|
||||
}
|
||||
tr, ok := out[1].Content[0].(ToolResultBlock)
|
||||
if !ok {
|
||||
t.Fatalf("second message content = %T, want ToolResultBlock", out[1].Content[0])
|
||||
}
|
||||
if tr.CallID != "tool-1" || !tr.IsError {
|
||||
t.Fatalf("unexpected synthetic result: %+v", tr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBedrockBuildRequestSkipsEmptyToolMessages(t *testing.T) {
|
||||
client := &bedrockClient{}
|
||||
req, err := client.buildRequest(Request{Messages: []Message{
|
||||
{Role: RoleTool, Content: []Content{ToolResultBlock{CallID: "missing", Content: []Content{TextBlock{Text: "orphan"}}}}},
|
||||
{Role: RoleUser, Content: []Content{TextBlock{Text: "hello"}}},
|
||||
}})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(req.Messages) != 1 {
|
||||
t.Fatalf("got %d bedrock messages, want 1: %+v", len(req.Messages), req.Messages)
|
||||
}
|
||||
if req.Messages[0].Role != "user" {
|
||||
t.Fatalf("role = %s, want user", req.Messages[0].Role)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBedrockEventPayloadHelpers(t *testing.T) {
|
||||
wrapped := []byte(`{"contentBlockDelta":{"contentBlockIndex":0,"delta":{"text":"Hello"}}}`)
|
||||
if got := bedrockEventTypeFromPayload(wrapped); got != "contentBlockDelta" {
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue