Normalize Bedrock tool results
Some checks failed
ci / test (macos-latest) (push) Has been cancelled
ci / test (ubuntu-latest) (push) Has been cancelled
ci / test (windows-latest) (push) Has been cancelled

This commit is contained in:
patriceckhart 2026-06-05 16:05:38 +02:00
parent 7a7bf0b52c
commit 84fd98ea74
2 changed files with 125 additions and 1 deletions

View file

@ -216,6 +216,61 @@ type bedrockRequest struct {
} `json:"toolConfig,omitempty"`
}
func normalizeBedrockToolResults(msgs []Message) []Message {
resultByID := map[string]ToolResultBlock{}
for _, m := range msgs {
for _, c := range m.Content {
if tr, ok := c.(ToolResultBlock); ok {
if _, exists := resultByID[tr.CallID]; !exists {
resultByID[tr.CallID] = tr
}
}
}
}
out := make([]Message, 0, len(msgs))
for _, m := range msgs {
copy := m
copy.Content = nil
var toolCalls []ToolCallBlock
for _, c := range m.Content {
switch v := c.(type) {
case ToolResultBlock:
// Bedrock requires toolResult blocks immediately after the
// assistant toolUse they answer. Reinsert them from the
// assistant pass below instead of preserving their original
// location, which may be separated by user text in active
// sessions.
continue
case ToolCallBlock:
toolCalls = append(toolCalls, v)
}
copy.Content = append(copy.Content, c)
}
if len(copy.Content) > 0 {
out = append(out, copy)
}
if m.Role != RoleAssistant || len(toolCalls) == 0 {
continue
}
results := make([]Content, 0, len(toolCalls))
for _, tc := range toolCalls {
if tr, ok := resultByID[tc.ID]; ok {
results = append(results, tr)
continue
}
results = append(results, ToolResultBlock{
CallID: tc.ID,
IsError: true,
Content: []Content{TextBlock{Text: "tool call did not complete before the next user message"}},
})
}
out = append(out, Message{Role: RoleTool, Content: results, Time: m.Time})
}
return out
}
func (c *bedrockClient) buildRequest(req Request) (*bedrockRequest, error) {
out := &bedrockRequest{}
if req.System != "" {
@ -226,7 +281,7 @@ func (c *bedrockClient) buildRequest(req Request) (*bedrockRequest, error) {
if out.InferenceConfig.MaxTokens == 0 {
out.InferenceConfig.MaxTokens = 4096
}
for _, m := range req.Messages {
for _, m := range normalizeBedrockToolResults(req.Messages) {
role := string(m.Role)
if role == "tool" {
role = "user"
@ -267,6 +322,9 @@ func (c *bedrockClient) buildRequest(req Request) (*bedrockRequest, error) {
})
}
}
if len(bm.Content) == 0 {
continue
}
out.Messages = append(out.Messages, bm)
}
if len(req.Tools) > 0 {

View file

@ -2,6 +2,7 @@ package provider
import (
"bytes"
"encoding/json"
"net/http"
"strings"
"testing"
@ -82,6 +83,71 @@ func TestReadAWSCredentialsFile(t *testing.T) {
}
}
func TestNormalizeBedrockToolResultsMovesResultsAdjacentToToolUse(t *testing.T) {
msgs := []Message{
{Role: RoleAssistant, Content: []Content{ToolCallBlock{ID: "tool-1", Name: "edit", Arguments: json.RawMessage(`{"path":"a"}`)}}},
{Role: RoleUser, Content: []Content{TextBlock{Text: "what's wrong"}}},
{Role: RoleTool, Content: []Content{ToolResultBlock{CallID: "tool-1", Content: []Content{TextBlock{Text: "edited"}}}}},
}
out := normalizeBedrockToolResults(msgs)
if len(out) != 3 {
t.Fatalf("got %d messages, want 3: %+v", len(out), out)
}
if out[0].Role != RoleAssistant {
t.Fatalf("first role = %s, want assistant", out[0].Role)
}
if out[1].Role != RoleTool {
t.Fatalf("second role = %s, want tool", out[1].Role)
}
tr, ok := out[1].Content[0].(ToolResultBlock)
if !ok {
t.Fatalf("second message content = %T, want ToolResultBlock", out[1].Content[0])
}
if tr.CallID != "tool-1" || tr.IsError {
t.Fatalf("unexpected tool result: %+v", tr)
}
if out[2].Role != RoleUser {
t.Fatalf("third role = %s, want user", out[2].Role)
}
}
func TestNormalizeBedrockToolResultsInjectsMissingResult(t *testing.T) {
msgs := []Message{
{Role: RoleAssistant, Content: []Content{ToolCallBlock{ID: "tool-1", Name: "edit", Arguments: json.RawMessage(`{}`)}}},
{Role: RoleUser, Content: []Content{TextBlock{Text: "what's wrong"}}},
}
out := normalizeBedrockToolResults(msgs)
if len(out) != 3 {
t.Fatalf("got %d messages, want 3: %+v", len(out), out)
}
tr, ok := out[1].Content[0].(ToolResultBlock)
if !ok {
t.Fatalf("second message content = %T, want ToolResultBlock", out[1].Content[0])
}
if tr.CallID != "tool-1" || !tr.IsError {
t.Fatalf("unexpected synthetic result: %+v", tr)
}
}
func TestBedrockBuildRequestSkipsEmptyToolMessages(t *testing.T) {
client := &bedrockClient{}
req, err := client.buildRequest(Request{Messages: []Message{
{Role: RoleTool, Content: []Content{ToolResultBlock{CallID: "missing", Content: []Content{TextBlock{Text: "orphan"}}}}},
{Role: RoleUser, Content: []Content{TextBlock{Text: "hello"}}},
}})
if err != nil {
t.Fatal(err)
}
if len(req.Messages) != 1 {
t.Fatalf("got %d bedrock messages, want 1: %+v", len(req.Messages), req.Messages)
}
if req.Messages[0].Role != "user" {
t.Fatalf("role = %s, want user", req.Messages[0].Role)
}
}
func TestBedrockEventPayloadHelpers(t *testing.T) {
wrapped := []byte(`{"contentBlockDelta":{"contentBlockIndex":0,"delta":{"text":"Hello"}}}`)
if got := bedrockEventTypeFromPayload(wrapped); got != "contentBlockDelta" {