From 84fd98ea741d3b83a200353dd40f8f4e5219a79a Mon Sep 17 00:00:00 2001 From: patriceckhart Date: Fri, 5 Jun 2026 16:05:38 +0200 Subject: [PATCH] Normalize Bedrock tool results --- packages/provider/amazon_bedrock.go | 60 ++++++++++++++++++++- packages/provider/amazon_bedrock_test.go | 66 ++++++++++++++++++++++++ 2 files changed, 125 insertions(+), 1 deletion(-) diff --git a/packages/provider/amazon_bedrock.go b/packages/provider/amazon_bedrock.go index f5653f0..58832d3 100644 --- a/packages/provider/amazon_bedrock.go +++ b/packages/provider/amazon_bedrock.go @@ -216,6 +216,61 @@ type bedrockRequest struct { } `json:"toolConfig,omitempty"` } +func normalizeBedrockToolResults(msgs []Message) []Message { + resultByID := map[string]ToolResultBlock{} + for _, m := range msgs { + for _, c := range m.Content { + if tr, ok := c.(ToolResultBlock); ok { + if _, exists := resultByID[tr.CallID]; !exists { + resultByID[tr.CallID] = tr + } + } + } + } + + out := make([]Message, 0, len(msgs)) + for _, m := range msgs { + copy := m + copy.Content = nil + var toolCalls []ToolCallBlock + for _, c := range m.Content { + switch v := c.(type) { + case ToolResultBlock: + // Bedrock requires toolResult blocks immediately after the + // assistant toolUse they answer. Reinsert them from the + // assistant pass below instead of preserving their original + // location, which may be separated by user text in active + // sessions. + continue + case ToolCallBlock: + toolCalls = append(toolCalls, v) + } + copy.Content = append(copy.Content, c) + } + if len(copy.Content) > 0 { + out = append(out, copy) + } + if m.Role != RoleAssistant || len(toolCalls) == 0 { + continue + } + + results := make([]Content, 0, len(toolCalls)) + for _, tc := range toolCalls { + if tr, ok := resultByID[tc.ID]; ok { + results = append(results, tr) + continue + } + results = append(results, ToolResultBlock{ + CallID: tc.ID, + IsError: true, + Content: []Content{TextBlock{Text: "tool call did not complete before the next user message"}}, + }) + } + out = append(out, Message{Role: RoleTool, Content: results, Time: m.Time}) + } + return out +} + func (c *bedrockClient) buildRequest(req Request) (*bedrockRequest, error) { out := &bedrockRequest{} if req.System != "" { @@ -226,7 +281,7 @@ func (c *bedrockClient) buildRequest(req Request) (*bedrockRequest, error) { if out.InferenceConfig.MaxTokens == 0 { out.InferenceConfig.MaxTokens = 4096 } - for _, m := range req.Messages { + for _, m := range normalizeBedrockToolResults(req.Messages) { role := string(m.Role) if role == "tool" { role = "user" @@ -267,6 +322,9 @@ func (c *bedrockClient) buildRequest(req Request) (*bedrockRequest, error) { }) } } + if len(bm.Content) == 0 { + continue + } out.Messages = append(out.Messages, bm) } if len(req.Tools) > 0 { diff --git a/packages/provider/amazon_bedrock_test.go b/packages/provider/amazon_bedrock_test.go index b78c017..f8a4c18 100644 --- a/packages/provider/amazon_bedrock_test.go +++ b/packages/provider/amazon_bedrock_test.go @@ -2,6 +2,7 @@ package provider import ( "bytes" + "encoding/json" "net/http" "strings" "testing" @@ -82,6 +83,71 @@ func TestReadAWSCredentialsFile(t *testing.T) { } } +func TestNormalizeBedrockToolResultsMovesResultsAdjacentToToolUse(t *testing.T) { + msgs := []Message{ + {Role: RoleAssistant, Content: []Content{ToolCallBlock{ID: "tool-1", Name: "edit", Arguments: json.RawMessage(`{"path":"a"}`)}}}, + {Role: RoleUser, Content: []Content{TextBlock{Text: "what's wrong"}}}, + {Role: RoleTool, Content: []Content{ToolResultBlock{CallID: "tool-1", Content: []Content{TextBlock{Text: "edited"}}}}}, + } + + out := normalizeBedrockToolResults(msgs) + if len(out) != 3 { + t.Fatalf("got %d messages, want 3: %+v", len(out), out) + } + if out[0].Role != RoleAssistant { + t.Fatalf("first role = %s, want assistant", out[0].Role) + } + if out[1].Role != RoleTool { + t.Fatalf("second role = %s, want tool", out[1].Role) + } + tr, ok := out[1].Content[0].(ToolResultBlock) + if !ok { + t.Fatalf("second message content = %T, want ToolResultBlock", out[1].Content[0]) + } + if tr.CallID != "tool-1" || tr.IsError { + t.Fatalf("unexpected tool result: %+v", tr) + } + if out[2].Role != RoleUser { + t.Fatalf("third role = %s, want user", out[2].Role) + } +} + +func TestNormalizeBedrockToolResultsInjectsMissingResult(t *testing.T) { + msgs := []Message{ + {Role: RoleAssistant, Content: []Content{ToolCallBlock{ID: "tool-1", Name: "edit", Arguments: json.RawMessage(`{}`)}}}, + {Role: RoleUser, Content: []Content{TextBlock{Text: "what's wrong"}}}, + } + + out := normalizeBedrockToolResults(msgs) + if len(out) != 3 { + t.Fatalf("got %d messages, want 3: %+v", len(out), out) + } + tr, ok := out[1].Content[0].(ToolResultBlock) + if !ok { + t.Fatalf("second message content = %T, want ToolResultBlock", out[1].Content[0]) + } + if tr.CallID != "tool-1" || !tr.IsError { + t.Fatalf("unexpected synthetic result: %+v", tr) + } +} + +func TestBedrockBuildRequestSkipsEmptyToolMessages(t *testing.T) { + client := &bedrockClient{} + req, err := client.buildRequest(Request{Messages: []Message{ + {Role: RoleTool, Content: []Content{ToolResultBlock{CallID: "missing", Content: []Content{TextBlock{Text: "orphan"}}}}}, + {Role: RoleUser, Content: []Content{TextBlock{Text: "hello"}}}, + }}) + if err != nil { + t.Fatal(err) + } + if len(req.Messages) != 1 { + t.Fatalf("got %d bedrock messages, want 1: %+v", len(req.Messages), req.Messages) + } + if req.Messages[0].Role != "user" { + t.Fatalf("role = %s, want user", req.Messages[0].Role) + } +} + func TestBedrockEventPayloadHelpers(t *testing.T) { wrapped := []byte(`{"contentBlockDelta":{"contentBlockIndex":0,"delta":{"text":"Hello"}}}`) if got := bedrockEventTypeFromPayload(wrapped); got != "contentBlockDelta" {