mirror of
https://github.com/patriceckhart/zot.git
synced 2026-06-26 21:36:31 +02:00
provider/bedrock: add prompt caching via cachePoint markers
Place Bedrock Converse API cachePoint blocks at the system prompt
boundary and after the last user message on every request to Claude
models (those with PriceCacheWrite > 0 in the catalog).
This adapts the existing Anthropic provider strategy (cache_control:
ephemeral on system, tools, and last user message) to Bedrock's
equivalent syntax: a {"cachePoint":{"type":"default"}} content block
appended to the system array and to the last user message. Unlike the
Anthropic provider, this change does not place a marker on the tools.
Changes:
- bedrockRequest.System widened from []map[string]string to
[]map[string]interface{} to accommodate mixed text/cachePoint blocks
- bedrockCachePoint: shared sentinel content block var
- bedrockModelSupportsCaching: gates on PriceCacheWrite > 0; strips
geo prefixes before catalog lookup; for models not found in the
catalog, falls back to checking for the "anthropic.claude-" ID prefix
(cachePoint is silently ignored by the API if unsupported)
- buildRequest: resolves model ID before caching check; injects
cachePoint into system array and calls bedrockTagLastUserCache
- bedrockTagLastUserCache: appends cachePoint to last user message
Nova models (PriceCacheWrite == 0) are excluded — they use Bedrock's
automatic caching and don't need explicit markers.
Tests: 8 new cases covering model detection, Claude vs Nova presence/
absence, multi-turn last-message targeting, no-system safety,
nil/empty panic safety, and JSON wire shape.
This commit is contained in:
parent
a7ef8c22a1
commit
cc03a4c18a
2 changed files with 290 additions and 3 deletions
|
|
@ -194,6 +194,13 @@ type bedrockMessage struct {
|
|||
Content []map[string]interface{} `json:"content"`
|
||||
}
|
||||
|
||||
// bedrockCachePoint is the content block used to mark a cache checkpoint
// in a Converse API request. It serialises to
// {"cachePoint":{"type":"default"}} and is meant to be appended to a
// system array or a message content array, marking the end of the
// prompt prefix that should be cached.
var bedrockCachePoint = map[string]interface{}{
	"cachePoint": map[string]interface{}{
		"type": "default",
	},
}
|
||||
|
||||
type bedrockToolSpec struct {
|
||||
ToolSpec struct {
|
||||
Name string `json:"name"`
|
||||
|
|
@ -205,8 +212,8 @@ type bedrockToolSpec struct {
|
|||
}
|
||||
|
||||
type bedrockRequest struct {
|
||||
Messages []bedrockMessage `json:"messages"`
|
||||
System []map[string]string `json:"system,omitempty"`
|
||||
Messages []bedrockMessage `json:"messages"`
|
||||
System []map[string]interface{} `json:"system,omitempty"`
|
||||
InferenceConfig struct {
|
||||
MaxTokens int `json:"maxTokens,omitempty"`
|
||||
Temperature *float32 `json:"temperature,omitempty"`
|
||||
|
|
@ -271,10 +278,44 @@ func normalizeBedrockToolResults(msgs []Message) []Message {
|
|||
return out
|
||||
}
|
||||
|
||||
// bedrockModelSupportsCaching reports whether the resolved model ID
|
||||
// supports explicit prompt caching via cachePoint markers on Bedrock.
|
||||
// We use PriceCacheWrite > 0 as a proxy: every Bedrock-hosted Claude
|
||||
// model with a write price in the catalog supports cachePoint markers.
|
||||
// Nova models use automatic caching and don't need explicit markers.
|
||||
func bedrockModelSupportsCaching(modelID string) bool {
|
||||
// Strip geo prefix (us./eu./apac./au./global.) before catalog lookup.
|
||||
for _, p := range bedrockGeoPrefixes {
|
||||
if strings.HasPrefix(modelID, p+".") {
|
||||
modelID = modelID[len(p)+1:]
|
||||
break
|
||||
}
|
||||
}
|
||||
if m, err := FindModel("amazon-bedrock", modelID); err == nil {
|
||||
return m.PriceCacheWrite > 0
|
||||
}
|
||||
// Unknown model: enable for Anthropic Claude families — cachePoint is
|
||||
// silently ignored by the API if the model doesn't support it.
|
||||
return strings.HasPrefix(modelID, "anthropic.claude-")
|
||||
}
|
||||
|
||||
func (c *bedrockClient) buildRequest(req Request) (*bedrockRequest, error) {
|
||||
out := &bedrockRequest{}
|
||||
|
||||
// Resolve the model ID as it will appear on the wire so the caching
|
||||
// check operates on the same ID used for FindModel.
|
||||
resolvedModel := resolveBedrockInferenceProfileID(req.Model, c.region)
|
||||
caching := bedrockModelSupportsCaching(resolvedModel)
|
||||
|
||||
if req.System != "" {
|
||||
out.System = []map[string]string{{"text": req.System}}
|
||||
sysBlock := map[string]interface{}{"text": req.System}
|
||||
if caching {
|
||||
// Append cachePoint after the system text so the stable system
|
||||
// prompt is cached as the first breakpoint.
|
||||
out.System = []map[string]interface{}{sysBlock, bedrockCachePoint}
|
||||
} else {
|
||||
out.System = []map[string]interface{}{sysBlock}
|
||||
}
|
||||
}
|
||||
out.InferenceConfig.Temperature = req.Temperature
|
||||
out.InferenceConfig.MaxTokens = req.MaxTokens
|
||||
|
|
@ -340,9 +381,29 @@ func (c *bedrockClient) buildRequest(req Request) (*bedrockRequest, error) {
|
|||
}
|
||||
out.ToolConfig = &tc
|
||||
}
|
||||
|
||||
if caching {
|
||||
// Tag the last user message with a cachePoint. This extends the
|
||||
// cached prefix to cover the full conversation history up to the
|
||||
// current turn, so the next turn reads that history cheaply.
|
||||
bedrockTagLastUserCache(out.Messages)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// bedrockTagLastUserCache appends a cachePoint block to the last user
|
||||
// message in the Bedrock message list. It is the Bedrock equivalent of
|
||||
// Anthropic's cache_control:{type:"ephemeral"} on the last user message.
|
||||
func bedrockTagLastUserCache(msgs []bedrockMessage) {
|
||||
for i := len(msgs) - 1; i >= 0; i-- {
|
||||
if msgs[i].Role == "user" {
|
||||
msgs[i].Content = append(msgs[i].Content, bedrockCachePoint)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// resolveBedrockInferenceProfileID maps a bare foundation-model ID to
|
||||
// its region-matched cross-region inference-profile ID.
|
||||
//
|
||||
|
|
|
|||
|
|
@ -131,6 +131,232 @@ func TestNormalizeBedrockToolResultsInjectsMissingResult(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestBedrockModelSupportsCaching(t *testing.T) {
|
||||
cases := []struct {
|
||||
model string
|
||||
want bool
|
||||
}{
|
||||
// Bare Claude IDs (as they come from catalog_builtin)
|
||||
{"anthropic.claude-sonnet-4-5-20250929-v1:0", true},
|
||||
{"anthropic.claude-haiku-4-5-20251001-v1:0", true},
|
||||
{"anthropic.claude-opus-4-5-20251101-v1:0", true},
|
||||
// Geo-prefixed (resolved form that arrives at bedrockModelSupportsCaching)
|
||||
{"us.anthropic.claude-sonnet-4-5-20250929-v1:0", true},
|
||||
{"eu.anthropic.claude-haiku-4-5-20251001-v1:0", true},
|
||||
{"global.anthropic.claude-opus-4-5-20251101-v1:0", true},
|
||||
// Nova models have PriceCacheWrite==0 — they use automatic caching
|
||||
{"amazon.nova-pro-v1:0", false},
|
||||
{"amazon.nova-lite-v1:0", false},
|
||||
{"amazon.nova-micro-v1:0", false},
|
||||
// DeepSeek — no cache write price
|
||||
{"deepseek.r1-v1:0", false},
|
||||
// Unknown model with claude prefix
|
||||
{"anthropic.claude-future-v99:0", true},
|
||||
// Unknown non-claude model
|
||||
{"some.unknown-model-v1:0", false},
|
||||
}
|
||||
for _, c := range cases {
|
||||
got := bedrockModelSupportsCaching(c.model)
|
||||
if got != c.want {
|
||||
t.Errorf("bedrockModelSupportsCaching(%q) = %v, want %v", c.model, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBedrockBuildRequestCachingClaudeModel(t *testing.T) {
|
||||
// A Claude model (PriceCacheWrite > 0) should get cachePoint markers
|
||||
// in the system array and on the last user message.
|
||||
client := &bedrockClient{region: "us-east-1"}
|
||||
req, err := client.buildRequest(Request{
|
||||
Model: "anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||
System: "You are a helpful assistant.",
|
||||
Messages: []Message{
|
||||
{Role: RoleUser, Content: []Content{TextBlock{Text: "hello"}}},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// system should be: [{text: ...}, {cachePoint: {type: default}}]
|
||||
if len(req.System) != 2 {
|
||||
t.Fatalf("system len = %d, want 2 (text + cachePoint)", len(req.System))
|
||||
}
|
||||
if _, ok := req.System[0]["text"]; !ok {
|
||||
t.Errorf("system[0] missing text key: %v", req.System[0])
|
||||
}
|
||||
if _, ok := req.System[1]["cachePoint"]; !ok {
|
||||
t.Errorf("system[1] missing cachePoint key: %v", req.System[1])
|
||||
}
|
||||
|
||||
// last user message should end with a cachePoint block
|
||||
if len(req.Messages) == 0 {
|
||||
t.Fatal("no messages")
|
||||
}
|
||||
lastMsg := req.Messages[len(req.Messages)-1]
|
||||
if lastMsg.Role != "user" {
|
||||
t.Fatalf("last message role = %q, want user", lastMsg.Role)
|
||||
}
|
||||
lastBlock := lastMsg.Content[len(lastMsg.Content)-1]
|
||||
if _, ok := lastBlock["cachePoint"]; !ok {
|
||||
t.Errorf("last user message final block missing cachePoint: %v", lastBlock)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBedrockBuildRequestNoCachingNovaModel(t *testing.T) {
|
||||
// A Nova model (PriceCacheWrite == 0) should NOT get cachePoint markers.
|
||||
client := &bedrockClient{region: "us-east-1"}
|
||||
req, err := client.buildRequest(Request{
|
||||
Model: "amazon.nova-pro-v1:0",
|
||||
System: "You are helpful.",
|
||||
Messages: []Message{
|
||||
{Role: RoleUser, Content: []Content{TextBlock{Text: "hi"}}},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// system should be plain: [{text: ...}] with no cachePoint
|
||||
if len(req.System) != 1 {
|
||||
t.Fatalf("system len = %d, want 1 (text only)", len(req.System))
|
||||
}
|
||||
if _, ok := req.System[0]["cachePoint"]; ok {
|
||||
t.Errorf("Nova system unexpectedly contains cachePoint")
|
||||
}
|
||||
|
||||
// last user message should NOT end with a cachePoint block
|
||||
if len(req.Messages) == 0 {
|
||||
t.Fatal("no messages")
|
||||
}
|
||||
lastMsg := req.Messages[len(req.Messages)-1]
|
||||
for _, block := range lastMsg.Content {
|
||||
if _, ok := block["cachePoint"]; ok {
|
||||
t.Errorf("Nova user message unexpectedly contains cachePoint: %v", block)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBedrockBuildRequestCachingMultiTurn(t *testing.T) {
|
||||
// In a multi-turn conversation the cachePoint should be on the LAST
|
||||
// user message only, not on earlier ones.
|
||||
client := &bedrockClient{region: "us-east-1"}
|
||||
req, err := client.buildRequest(Request{
|
||||
Model: "anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||
Messages: []Message{
|
||||
{Role: RoleUser, Content: []Content{TextBlock{Text: "turn 1"}}},
|
||||
{Role: RoleAssistant, Content: []Content{TextBlock{Text: "response 1"}}},
|
||||
{Role: RoleUser, Content: []Content{TextBlock{Text: "turn 2"}}},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Find all user messages and verify only the last has a cachePoint.
|
||||
var userMsgs []bedrockMessage
|
||||
for _, m := range req.Messages {
|
||||
if m.Role == "user" {
|
||||
userMsgs = append(userMsgs, m)
|
||||
}
|
||||
}
|
||||
if len(userMsgs) != 2 {
|
||||
t.Fatalf("expected 2 user messages, got %d", len(userMsgs))
|
||||
}
|
||||
|
||||
// First user message: no cachePoint
|
||||
for _, block := range userMsgs[0].Content {
|
||||
if _, ok := block["cachePoint"]; ok {
|
||||
t.Errorf("first user message should not have cachePoint: %v", block)
|
||||
}
|
||||
}
|
||||
|
||||
// Last user message: ends with cachePoint
|
||||
lastContent := userMsgs[len(userMsgs)-1].Content
|
||||
lastBlock := lastContent[len(lastContent)-1]
|
||||
if _, ok := lastBlock["cachePoint"]; !ok {
|
||||
t.Errorf("last user message should end with cachePoint, got: %v", lastBlock)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBedrockBuildRequestCachingNoSystemNoCrash(t *testing.T) {
|
||||
// No system prompt: system array should be nil, not a bare cachePoint.
|
||||
client := &bedrockClient{region: "us-east-1"}
|
||||
req, err := client.buildRequest(Request{
|
||||
Model: "anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||
Messages: []Message{{Role: RoleUser, Content: []Content{TextBlock{Text: "hi"}}}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(req.System) != 0 {
|
||||
t.Errorf("empty system prompt should produce nil system, got: %v", req.System)
|
||||
}
|
||||
// Message cachePoint still present
|
||||
lastBlock := req.Messages[0].Content[len(req.Messages[0].Content)-1]
|
||||
if _, ok := lastBlock["cachePoint"]; !ok {
|
||||
t.Errorf("user message should still have cachePoint when system is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBedrockTagLastUserCache(t *testing.T) {
|
||||
msgs := []bedrockMessage{
|
||||
{Role: "user", Content: []map[string]interface{}{{"text": "hi"}}},
|
||||
{Role: "assistant", Content: []map[string]interface{}{{"text": "hello"}}},
|
||||
{Role: "user", Content: []map[string]interface{}{{"text": "followup"}}},
|
||||
}
|
||||
bedrockTagLastUserCache(msgs)
|
||||
|
||||
// Only the last user message (index 2) should have a cachePoint appended.
|
||||
last := msgs[2].Content
|
||||
if _, ok := last[len(last)-1]["cachePoint"]; !ok {
|
||||
t.Errorf("last user message should end with cachePoint")
|
||||
}
|
||||
// The first user message should be untouched.
|
||||
if len(msgs[0].Content) != 1 {
|
||||
t.Errorf("first user message content len = %d, want 1", len(msgs[0].Content))
|
||||
}
|
||||
if _, ok := msgs[0].Content[0]["cachePoint"]; ok {
|
||||
t.Errorf("first user message should not have cachePoint")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBedrockTagLastUserCacheEmpty(t *testing.T) {
|
||||
// Should not panic on empty or assistant-only history.
|
||||
bedrockTagLastUserCache(nil)
|
||||
bedrockTagLastUserCache([]bedrockMessage{})
|
||||
msgs := []bedrockMessage{
|
||||
{Role: "assistant", Content: []map[string]interface{}{{"text": "hi"}}},
|
||||
}
|
||||
bedrockTagLastUserCache(msgs) // should not panic, no user message to tag
|
||||
}
|
||||
|
||||
func TestBedrockBuildRequestCachingWireJSON(t *testing.T) {
|
||||
// Verify the JSON shape Bedrock actually receives has the right keys.
|
||||
client := &bedrockClient{region: "us-east-1"}
|
||||
breq, err := client.buildRequest(Request{
|
||||
Model: "anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||
System: "be helpful",
|
||||
Messages: []Message{
|
||||
{Role: RoleUser, Content: []Content{TextBlock{Text: "hello"}}},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
b, err := json.Marshal(breq)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s := string(b)
|
||||
if !strings.Contains(s, `"cachePoint"`) {
|
||||
t.Errorf("serialised request missing cachePoint key: %s", s)
|
||||
}
|
||||
if !strings.Contains(s, `"type":"default"`) {
|
||||
t.Errorf("serialised request missing cachePoint type:default: %s", s)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBedrockBuildRequestSkipsEmptyToolMessages(t *testing.T) {
|
||||
client := &bedrockClient{}
|
||||
req, err := client.buildRequest(Request{Messages: []Message{
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue