diff --git a/packages/provider/amazon_bedrock.go b/packages/provider/amazon_bedrock.go index 58832d3..a3baaf6 100644 --- a/packages/provider/amazon_bedrock.go +++ b/packages/provider/amazon_bedrock.go @@ -194,6 +194,13 @@ type bedrockMessage struct { Content []map[string]interface{} `json:"content"` } +// bedrockCachePoint is the Converse API cache checkpoint content block. +// Appending this to a system or message content array tells Bedrock to +// create a cache boundary at that position in the prompt prefix. +var bedrockCachePoint = map[string]interface{}{ + "cachePoint": map[string]interface{}{"type": "default"}, +} + type bedrockToolSpec struct { ToolSpec struct { Name string `json:"name"` @@ -205,8 +212,8 @@ type bedrockToolSpec struct { } type bedrockRequest struct { - Messages []bedrockMessage `json:"messages"` - System []map[string]string `json:"system,omitempty"` + Messages []bedrockMessage `json:"messages"` + System []map[string]interface{} `json:"system,omitempty"` InferenceConfig struct { MaxTokens int `json:"maxTokens,omitempty"` Temperature *float32 `json:"temperature,omitempty"` @@ -271,10 +278,44 @@ func normalizeBedrockToolResults(msgs []Message) []Message { return out } +// bedrockModelSupportsCaching reports whether the resolved model ID +// supports explicit prompt caching via cachePoint markers on Bedrock. +// We use PriceCacheWrite > 0 as a proxy: every Bedrock-hosted Claude +// model with a write price in the catalog supports cachePoint markers. +// Nova models use automatic caching and don't need explicit markers. +func bedrockModelSupportsCaching(modelID string) bool { + // Strip geo prefix (us./eu./apac./au./global.) before catalog lookup. + for _, p := range bedrockGeoPrefixes { + if strings.HasPrefix(modelID, p+".") { + modelID = modelID[len(p)+1:] + break + } + } + if m, err := FindModel("amazon-bedrock", modelID); err == nil { + return m.PriceCacheWrite > 0 + } + // Unknown model: enable for Anthropic Claude families — cachePoint is + // silently ignored by the API if the model doesn't support it. + return strings.HasPrefix(modelID, "anthropic.claude-") +} + func (c *bedrockClient) buildRequest(req Request) (*bedrockRequest, error) { out := &bedrockRequest{} + + // Resolve the model ID as it will appear on the wire so the caching + // check operates on the same ID used for FindModel. + resolvedModel := resolveBedrockInferenceProfileID(req.Model, c.region) + caching := bedrockModelSupportsCaching(resolvedModel) + if req.System != "" { - out.System = []map[string]string{{"text": req.System}} + sysBlock := map[string]interface{}{"text": req.System} + if caching { + // Append cachePoint after the system text so the stable system + // prompt is cached as the first breakpoint. + out.System = []map[string]interface{}{sysBlock, bedrockCachePoint} + } else { + out.System = []map[string]interface{}{sysBlock} + } } out.InferenceConfig.Temperature = req.Temperature out.InferenceConfig.MaxTokens = req.MaxTokens @@ -340,9 +381,29 @@ func (c *bedrockClient) buildRequest(req Request) (*bedrockRequest, error) { } out.ToolConfig = &tc } + + if caching { + // Tag the last user message with a cachePoint. This extends the + // cached prefix to cover the full conversation history up to the + // current turn, so the next turn reads that history cheaply. + bedrockTagLastUserCache(out.Messages) + } + return out, nil } +// bedrockTagLastUserCache appends a cachePoint block to the last user +// message in the Bedrock message list. It is the Bedrock equivalent of +// Anthropic's cache_control:{type:"ephemeral"} on the last user message. +func bedrockTagLastUserCache(msgs []bedrockMessage) { + for i := len(msgs) - 1; i >= 0; i-- { + if msgs[i].Role == "user" { + msgs[i].Content = append(msgs[i].Content, bedrockCachePoint) + return + } + } +} + // resolveBedrockInferenceProfileID maps a bare foundation-model ID to // its region-matched cross-region inference-profile ID. // diff --git a/packages/provider/amazon_bedrock_test.go b/packages/provider/amazon_bedrock_test.go index f8a4c18..2c0038f 100644 --- a/packages/provider/amazon_bedrock_test.go +++ b/packages/provider/amazon_bedrock_test.go @@ -131,6 +131,232 @@ func TestNormalizeBedrockToolResultsInjectsMissingResult(t *testing.T) { } } +func TestBedrockModelSupportsCaching(t *testing.T) { + cases := []struct { + model string + want bool + }{ + // Bare Claude IDs (as they come from catalog_builtin) + {"anthropic.claude-sonnet-4-5-20250929-v1:0", true}, + {"anthropic.claude-haiku-4-5-20251001-v1:0", true}, + {"anthropic.claude-opus-4-5-20251101-v1:0", true}, + // Geo-prefixed (resolved form that arrives at bedrockModelSupportsCaching) + {"us.anthropic.claude-sonnet-4-5-20250929-v1:0", true}, + {"eu.anthropic.claude-haiku-4-5-20251001-v1:0", true}, + {"global.anthropic.claude-opus-4-5-20251101-v1:0", true}, + // Nova models have PriceCacheWrite==0 — they use automatic caching + {"amazon.nova-pro-v1:0", false}, + {"amazon.nova-lite-v1:0", false}, + {"amazon.nova-micro-v1:0", false}, + // DeepSeek — no cache write price + {"deepseek.r1-v1:0", false}, + // Unknown model with claude prefix + {"anthropic.claude-future-v99:0", true}, + // Unknown non-claude model + {"some.unknown-model-v1:0", false}, + } + for _, c := range cases { + got := bedrockModelSupportsCaching(c.model) + if got != c.want { + t.Errorf("bedrockModelSupportsCaching(%q) = %v, want %v", c.model, got, c.want) + } + } +} + +func TestBedrockBuildRequestCachingClaudeModel(t *testing.T) { + // A Claude model (PriceCacheWrite > 0) should get cachePoint markers + // in the system array and on the last user message. + client := &bedrockClient{region: "us-east-1"} + req, err := client.buildRequest(Request{ + Model: "anthropic.claude-sonnet-4-5-20250929-v1:0", + System: "You are a helpful assistant.", + Messages: []Message{ + {Role: RoleUser, Content: []Content{TextBlock{Text: "hello"}}}, + }, + }) + if err != nil { + t.Fatal(err) + } + + // system should be: [{text: ...}, {cachePoint: {type: default}}] + if len(req.System) != 2 { + t.Fatalf("system len = %d, want 2 (text + cachePoint)", len(req.System)) + } + if _, ok := req.System[0]["text"]; !ok { + t.Errorf("system[0] missing text key: %v", req.System[0]) + } + if _, ok := req.System[1]["cachePoint"]; !ok { + t.Errorf("system[1] missing cachePoint key: %v", req.System[1]) + } + + // last user message should end with a cachePoint block + if len(req.Messages) == 0 { + t.Fatal("no messages") + } + lastMsg := req.Messages[len(req.Messages)-1] + if lastMsg.Role != "user" { + t.Fatalf("last message role = %q, want user", lastMsg.Role) + } + lastBlock := lastMsg.Content[len(lastMsg.Content)-1] + if _, ok := lastBlock["cachePoint"]; !ok { + t.Errorf("last user message final block missing cachePoint: %v", lastBlock) + } +} + +func TestBedrockBuildRequestNoCachingNovaModel(t *testing.T) { + // A Nova model (PriceCacheWrite == 0) should NOT get cachePoint markers. + client := &bedrockClient{region: "us-east-1"} + req, err := client.buildRequest(Request{ + Model: "amazon.nova-pro-v1:0", + System: "You are helpful.", + Messages: []Message{ + {Role: RoleUser, Content: []Content{TextBlock{Text: "hi"}}}, + }, + }) + if err != nil { + t.Fatal(err) + } + + // system should be plain: [{text: ...}] with no cachePoint + if len(req.System) != 1 { + t.Fatalf("system len = %d, want 1 (text only)", len(req.System)) + } + if _, ok := req.System[0]["cachePoint"]; ok { + t.Errorf("Nova system unexpectedly contains cachePoint") + } + + // last user message should NOT end with a cachePoint block + if len(req.Messages) == 0 { + t.Fatal("no messages") + } + lastMsg := req.Messages[len(req.Messages)-1] + for _, block := range lastMsg.Content { + if _, ok := block["cachePoint"]; ok { + t.Errorf("Nova user message unexpectedly contains cachePoint: %v", block) + } + } +} + +func TestBedrockBuildRequestCachingMultiTurn(t *testing.T) { + // In a multi-turn conversation the cachePoint should be on the LAST + // user message only, not on earlier ones. + client := &bedrockClient{region: "us-east-1"} + req, err := client.buildRequest(Request{ + Model: "anthropic.claude-sonnet-4-5-20250929-v1:0", + Messages: []Message{ + {Role: RoleUser, Content: []Content{TextBlock{Text: "turn 1"}}}, + {Role: RoleAssistant, Content: []Content{TextBlock{Text: "response 1"}}}, + {Role: RoleUser, Content: []Content{TextBlock{Text: "turn 2"}}}, + }, + }) + if err != nil { + t.Fatal(err) + } + + // Find all user messages and verify only the last has a cachePoint. + var userMsgs []bedrockMessage + for _, m := range req.Messages { + if m.Role == "user" { + userMsgs = append(userMsgs, m) + } + } + if len(userMsgs) != 2 { + t.Fatalf("expected 2 user messages, got %d", len(userMsgs)) + } + + // First user message: no cachePoint + for _, block := range userMsgs[0].Content { + if _, ok := block["cachePoint"]; ok { + t.Errorf("first user message should not have cachePoint: %v", block) + } + } + + // Last user message: ends with cachePoint + lastContent := userMsgs[len(userMsgs)-1].Content + lastBlock := lastContent[len(lastContent)-1] + if _, ok := lastBlock["cachePoint"]; !ok { + t.Errorf("last user message should end with cachePoint, got: %v", lastBlock) + } +} + +func TestBedrockBuildRequestCachingNoSystemNoCrash(t *testing.T) { + // No system prompt: system array should be nil, not a bare cachePoint. + client := &bedrockClient{region: "us-east-1"} + req, err := client.buildRequest(Request{ + Model: "anthropic.claude-sonnet-4-5-20250929-v1:0", + Messages: []Message{{Role: RoleUser, Content: []Content{TextBlock{Text: "hi"}}}}, + }) + if err != nil { + t.Fatal(err) + } + if len(req.System) != 0 { + t.Errorf("empty system prompt should produce nil system, got: %v", req.System) + } + // Message cachePoint still present + lastBlock := req.Messages[0].Content[len(req.Messages[0].Content)-1] + if _, ok := lastBlock["cachePoint"]; !ok { + t.Errorf("user message should still have cachePoint when system is empty") + } +} + +func TestBedrockTagLastUserCache(t *testing.T) { + msgs := []bedrockMessage{ + {Role: "user", Content: []map[string]interface{}{{"text": "hi"}}}, + {Role: "assistant", Content: []map[string]interface{}{{"text": "hello"}}}, + {Role: "user", Content: []map[string]interface{}{{"text": "followup"}}}, + } + bedrockTagLastUserCache(msgs) + + // Only the last user message (index 2) should have a cachePoint appended. + last := msgs[2].Content + if _, ok := last[len(last)-1]["cachePoint"]; !ok { + t.Errorf("last user message should end with cachePoint") + } + // The first user message should be untouched. + if len(msgs[0].Content) != 1 { + t.Errorf("first user message content len = %d, want 1", len(msgs[0].Content)) + } + if _, ok := msgs[0].Content[0]["cachePoint"]; ok { + t.Errorf("first user message should not have cachePoint") + } +} + +func TestBedrockTagLastUserCacheEmpty(t *testing.T) { + // Should not panic on empty or assistant-only history. + bedrockTagLastUserCache(nil) + bedrockTagLastUserCache([]bedrockMessage{}) + msgs := []bedrockMessage{ + {Role: "assistant", Content: []map[string]interface{}{{"text": "hi"}}}, + } + bedrockTagLastUserCache(msgs) // should not panic, no user message to tag +} + +func TestBedrockBuildRequestCachingWireJSON(t *testing.T) { + // Verify the JSON shape Bedrock actually receives has the right keys. + client := &bedrockClient{region: "us-east-1"} + breq, err := client.buildRequest(Request{ + Model: "anthropic.claude-sonnet-4-5-20250929-v1:0", + System: "be helpful", + Messages: []Message{ + {Role: RoleUser, Content: []Content{TextBlock{Text: "hello"}}}, + }, + }) + if err != nil { + t.Fatal(err) + } + b, err := json.Marshal(breq) + if err != nil { + t.Fatal(err) + } + s := string(b) + if !strings.Contains(s, `"cachePoint"`) { + t.Errorf("serialised request missing cachePoint key: %s", s) + } + if !strings.Contains(s, `"type":"default"`) { + t.Errorf("serialised request missing cachePoint type:default: %s", s) + } +} + func TestBedrockBuildRequestSkipsEmptyToolMessages(t *testing.T) { client := &bedrockClient{} req, err := client.buildRequest(Request{Messages: []Message{