colibri/crates/colibri-daemon/src/cost.rs
Sam & Claude 34929a6a53
Some checks failed
CI / rust (pull_request) Has been cancelled
CI / markdown (pull_request) Has been cancelled
fix(headroom): harden sidecar protocol and timeout (Sam & Codex)
Keep the Python sidecar connection open for multiple newline-delimited requests, add daemon-side request timeout/fallback tests, and document the opt-in Headroom sidecar contract. Checks: ./scripts/check-format.sh; cargo fmt --check; python3 -m py_compile scripts/headroom-sidecar.py; git diff --check; cargo test -p colibri-daemon cost -- --nocapture; cargo test -p colibri-daemon session:: -- --nocapture; cargo test -p colibri-daemon --all-targets; cargo check -p colibri-daemon; manual sidecar two-request smoke using a headroom-capable Python env.
2026-06-14 01:30:45 +02:00

477 lines
15 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! Cost discipline — cache-first prompt assembly with three cost modes.
//!
//! Phase 5 (Reasonix ideas): deterministic 3-region prompt assembler,
//! cost modes (fast/smart/max), visible escalation, and large tool
//! result compaction.
//!
//! Region model (from session.rs):
//! 1. Immutable system prefix — byte-stable for DeepSeek cache hits
//! 2. Appendable conversation log — turns accumulate until compaction
//! 3. Volatile scratch — discarded per-turn, never persisted
use serde::{Deserialize, Serialize};
use std::time::Duration;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::net::UnixStream;
use tokio::time::timeout;
use tracing::{debug, info, warn};
// ---------------------------------------------------------------------------
// Cost mode
// ---------------------------------------------------------------------------
/// Cost mode governs how aggressively the daemon compacts context.
///
/// Serde (de)serializes each variant as its lowercase name (`"fast"`,
/// `"smart"`, `"max"`). The [`Default`] value is [`CostMode::Smart`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum CostMode {
/// Minimal context window, aggressive compaction. Cheapest per-turn.
Fast,
/// Balanced thresholds — the default. Standard compaction.
#[default]
Smart,
/// Full context, no compaction, large tool results preserved. Most expensive.
Max,
}
impl CostMode {
    /// Lowercase name of this mode: `"fast"`, `"smart"` or `"max"`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Max => "max",
            Self::Smart => "smart",
            Self::Fast => "fast",
        }
    }

    /// Inverse of [`CostMode::as_str`].
    ///
    /// The comparison is exact (case-sensitive, no trimming); any other
    /// input yields `None`.
    pub fn parse(s: &str) -> Option<Self> {
        [Self::Fast, Self::Smart, Self::Max]
            .iter()
            .copied()
            .find(|mode| mode.as_str() == s)
    }

    /// Session byte budget for this mode.
    pub fn session_max_bytes(&self) -> u64 {
        // Token figures are the original author's rough estimates.
        match *self {
            Self::Max => 8_000_000,   // ~200K tokens
            Self::Smart => 2_000_000, // ~50K tokens
            Self::Fast => 500_000,    // ~12K tokens
        }
    }

    /// Number of uncompacted turns tolerated before compaction triggers.
    pub fn max_uncompacted_turns(&self) -> usize {
        match *self {
            Self::Max => 100,
            Self::Smart => 20,
            Self::Fast => 5,
        }
    }

    /// Whether large tool results are compacted in this mode.
    ///
    /// Only `Max` keeps tool results untouched.
    pub fn compact_tool_results(&self) -> bool {
        match *self {
            Self::Fast | Self::Smart => true,
            Self::Max => false,
        }
    }

    /// Size in bytes above which a single tool result is compacted.
    ///
    /// `Max` returns `u64::MAX`, i.e. no result is ever over the limit.
    pub fn tool_result_max_bytes(&self) -> u64 {
        match *self {
            Self::Max => u64::MAX,
            Self::Smart => 16_000,
            Self::Fast => 4_000,
        }
    }
}
// ---------------------------------------------------------------------------
// Escalation
// ---------------------------------------------------------------------------
/// Move one step up the escalation ladder: Fast → Smart → Max.
///
/// `Max` is the ceiling and is returned unchanged without logging. Every
/// real step emits an `info` event so operators can see the cost increase.
pub fn escalate(current: CostMode) -> CostMode {
    let next = match current {
        // Already at the top: nothing to change, nothing to log.
        CostMode::Max => return CostMode::Max,
        CostMode::Fast => CostMode::Smart,
        CostMode::Smart => CostMode::Max,
    };
    info!(
        from = current.as_str(),
        to = next.as_str(),
        "cost mode escalated"
    );
    next
}
/// Reason an escalation is being considered.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EscalationTrigger {
    /// Session byte count exceeded current mode budget.
    BudgetExceeded {
        /// Bytes currently held by the session.
        current_bytes: u64,
        /// Byte budget the session is measured against.
        budget_bytes: u64,
    },
    /// Compaction didn't free enough space.
    CompactionInsufficient {
        /// Bytes that compaction actually released.
        freed_bytes: u64,
        /// Bytes that had to be released.
        needed_bytes: u64,
    },
}
/// Escalate one step when `trigger` shows the current mode is insufficient.
///
/// Returns `Some(new_mode)` when an escalation happened, and `None` when
/// `current` is already `Max` or the trigger does not cross its threshold:
/// - `BudgetExceeded` fires when `current_bytes > budget_bytes`;
/// - `CompactionInsufficient` fires when `freed_bytes < needed_bytes`.
pub fn auto_escalate(current: CostMode, trigger: &EscalationTrigger) -> Option<CostMode> {
    if current == CostMode::Max {
        // Already at the ceiling.
        return None;
    }

    let warranted = match *trigger {
        EscalationTrigger::BudgetExceeded {
            current_bytes,
            budget_bytes,
        } => current_bytes > budget_bytes,
        EscalationTrigger::CompactionInsufficient {
            freed_bytes,
            needed_bytes,
        } => freed_bytes < needed_bytes,
    };
    if !warranted {
        return None;
    }

    // `escalate` performs the step and logs it; only report a real change.
    Some(escalate(current)).filter(|next| *next != current)
}
// ---------------------------------------------------------------------------
// Tool result compaction
// ---------------------------------------------------------------------------
/// Truncate an oversized tool result and annotate what was dropped.
///
/// Returns `None` when `raw` is at most `max_bytes` long. Otherwise returns
/// a three-line string: a header naming `tool_name` with the original and
/// kept byte counts, the kept prefix of `raw`, and a footer with the number
/// of omitted bytes. The cut is moved back to the nearest UTF-8 character
/// boundary, so the kept prefix may be slightly shorter than `max_bytes`.
pub fn compact_tool_result(raw: &str, max_bytes: u64, tool_name: &str) -> Option<String> {
    let total = raw.len();
    if total as u64 <= max_bytes {
        return None;
    }

    // `max_bytes < total` on this path, so the cast to `usize` cannot truncate.
    let kept = raw.floor_char_boundary(max_bytes as usize);
    let (truncated, rest) = raw.split_at(kept);
    let omitted = rest.len();

    Some(format!(
        "[{tool_name} output truncated: {} bytes → {} bytes]\n{truncated}\n[... {} more bytes omitted]",
        total,
        kept,
        omitted,
    ))
}
// ---------------------------------------------------------------------------
// Headroom compression sidecar
// ---------------------------------------------------------------------------
/// Time limit for one sidecar request/response round trip.
///
/// `HeadroomSidecar::compress` passes this as the timeout for each request
/// and yields `None` once it elapses.
const HEADROOM_REQUEST_TIMEOUT: Duration = Duration::from_secs(5);
/// A long-lived connection to the headroom compression sidecar process.
///
/// The sidecar is a Python script (`scripts/headroom-sidecar.py`) that
/// listens on a Unix domain socket. Each request is a single JSON line;
/// each response is a single JSON line. The connection is kept open and
/// reused across tool results.
pub struct HeadroomSidecar {
/// Open Unix-socket connection that requests are written to and responses
/// are read from.
stream: UnixStream,
/// Socket path this connection was opened against.
socket_path: std::path::PathBuf,
}
impl HeadroomSidecar {
/// Connect to the headroom sidecar at the given socket path.
pub async fn connect(socket_path: &std::path::Path) -> std::io::Result<Self> {
let stream = UnixStream::connect(socket_path).await?;
info!(
socket = %socket_path.display(),
"connected to headroom sidecar"
);
Ok(Self {
stream,
socket_path: socket_path.to_path_buf(),
})
}
/// Compress a tool result through the sidecar.
///
/// Returns the compressed content on success, or `None` if compression
/// failed, timed out, or produced no savings (graceful degradation).
pub async fn compress(&mut self, raw: &str, tool_name: &str) -> Option<String> {
self.compress_with_timeout(raw, tool_name, HEADROOM_REQUEST_TIMEOUT)
.await
}
async fn compress_with_timeout(
&mut self,
raw: &str,
tool_name: &str,
timeout_duration: Duration,
) -> Option<String> {
match timeout(timeout_duration, self.compress_once(raw, tool_name)).await {
Ok(result) => result,
Err(_) => {
warn!(
tool = %tool_name,
timeout_ms = timeout_duration.as_millis(),
"headroom sidecar request timed out"
);
None
}
}
}
async fn compress_once(&mut self, raw: &str, tool_name: &str) -> Option<String> {
let request = serde_json::json!({
"id": tool_name,
"raw": raw,
"role": "tool",
});
let payload = serde_json::to_string(&request).unwrap_or_default() + "\n";
match self.stream.write_all(payload.as_bytes()).await {
Ok(()) => {}
Err(e) => {
warn!(error = %e, "headroom sidecar write failed");
return None;
}
}
let mut reader = BufReader::new(&mut self.stream);
let mut line = String::new();
match reader.read_line(&mut line).await {
Ok(0) => {
warn!("headroom sidecar closed connection");
return None;
}
Ok(_) => {}
Err(e) => {
warn!(error = %e, "headroom sidecar read failed");
return None;
}
}
let response: serde_json::Value = match serde_json::from_str(&line) {
Ok(v) => v,
Err(e) => {
warn!(error = %e, raw_line = %line, "headroom sidecar returned invalid JSON");
return None;
}
};
if response.get("error").is_some() {
warn!(error = %response["error"], "headroom sidecar error");
return None;
}
let tokens_before = response["tokens_before"].as_u64().unwrap_or(0);
let tokens_after = response["tokens_after"].as_u64().unwrap_or(0);
let compressed = response["compressed"].as_str().unwrap_or(raw);
if tokens_after >= tokens_before {
debug!(
tool = %tool_name,
tokens_before,
tokens_after,
"headroom: no savings, returning original"
);
return None;
}
info!(
tool = %tool_name,
tokens_before,
tokens_after,
saved = tokens_before.saturating_sub(tokens_after),
pct = 100 - (tokens_after * 100 / tokens_before.max(1)),
"headroom compression applied"
);
Some(compressed.to_string())
}
}
impl Drop for HeadroomSidecar {
    /// Emit a log line naming the socket path when the handle goes away.
    fn drop(&mut self) {
        let socket = self.socket_path.display();
        info!(socket = %socket, "headroom sidecar disconnected");
    }
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;

    /// Second line of `compact_tool_result` output, i.e. the kept prefix.
    ///
    /// Only meaningful for inputs that contain no newline of their own.
    fn truncated_body(compacted: &str) -> &str {
        compacted
            .lines()
            .nth(1)
            .expect("compacted output has a body line")
    }

    #[test]
    fn test_cost_mode_defaults() {
        assert_eq!(CostMode::default(), CostMode::Smart);
    }

    #[test]
    fn test_fast_thresholds() {
        assert_eq!(CostMode::Fast.session_max_bytes(), 500_000);
        assert_eq!(CostMode::Fast.max_uncompacted_turns(), 5);
        assert!(CostMode::Fast.compact_tool_results());
    }

    #[test]
    fn test_max_thresholds() {
        assert_eq!(CostMode::Max.session_max_bytes(), 8_000_000);
        assert_eq!(CostMode::Max.max_uncompacted_turns(), 100);
        assert!(!CostMode::Max.compact_tool_results());
    }

    #[test]
    fn test_escalation_fast_to_smart() {
        assert_eq!(escalate(CostMode::Fast), CostMode::Smart);
    }

    #[test]
    fn test_escalation_smart_to_max() {
        assert_eq!(escalate(CostMode::Smart), CostMode::Max);
    }

    #[test]
    fn test_escalation_max_ceiling() {
        assert_eq!(escalate(CostMode::Max), CostMode::Max);
    }

    #[test]
    fn test_parse_cost_mode() {
        assert_eq!(CostMode::parse("fast"), Some(CostMode::Fast));
        assert_eq!(CostMode::parse("smart"), Some(CostMode::Smart));
        assert_eq!(CostMode::parse("max"), Some(CostMode::Max));
        assert_eq!(CostMode::parse("unknown"), None);
    }

    #[test]
    fn test_tool_result_compaction_needed() {
        let big = "x".repeat(10_000);
        let result = compact_tool_result(&big, 4_000, "test_tool");
        assert!(result.is_some());
        let compacted = result.unwrap();
        assert!(compacted.contains("truncated"));
        assert!(compacted.contains("10000 bytes"));
    }

    #[test]
    fn test_tool_result_compaction_not_needed() {
        let small = "ok";
        assert!(compact_tool_result(small, 4_000, "test_tool").is_none());
    }

    #[test]
    fn test_tool_result_compaction_multibyte_no_panic() {
        let big = "äöü日本語".repeat(2_000);
        // Byte 50 is inside a multi-byte character, so slicing at the raw
        // limit would panic; the cut has to move back to a boundary.
        assert!(!big.is_char_boundary(50));
        let result = compact_tool_result(&big, 50, "unicode_tool");
        assert!(result.is_some());
        let compacted = result.unwrap();
        assert!(compacted.contains("truncated"));
        let body = truncated_body(&compacted);
        assert!(body.len() < 50);
        assert!(big.starts_with(body));
    }

    #[test]
    fn test_tool_result_compaction_slovenian_no_panic() {
        // š, č, ž are 2-byte UTF-8 — same family as German umlauts.
        // "Cene že še češnje je" = 24 bytes (16 ASCII + 4×2 multibyte), so
        // byte 54 (= 2×24 + 6) is the second byte of "ž".
        let big = "Cene že še češnje je".repeat(2_000);
        assert!(!big.is_char_boundary(54));
        let result = compact_tool_result(&big, 54, "cene");
        assert!(result.is_some());
        let compacted = result.unwrap();
        assert!(compacted.contains("truncated"));
        let body = truncated_body(&compacted);
        assert!(body.len() < 54);
        assert!(big.starts_with(body));
    }

    /// Fresh socket path inside a uniquely named temp directory.
    fn test_socket_path(name: &str) -> std::path::PathBuf {
        let dir =
            std::env::temp_dir().join(format!("colibri-headroom-{name}-{}", uuid::Uuid::new_v4()));
        std::fs::create_dir_all(&dir).unwrap();
        dir.join("headroom.sock")
    }

    #[tokio::test]
    async fn headroom_sidecar_reuses_one_connection_for_multiple_requests() {
        let socket = test_socket_path("reuse");
        let listener = tokio::net::UnixListener::bind(&socket).unwrap();
        // Fake sidecar: accept once, then answer two requests on that stream.
        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let mut reader = BufReader::new(stream);
            for i in 0..2 {
                let mut line = String::new();
                assert!(reader.read_line(&mut line).await.unwrap() > 0);
                let request: serde_json::Value = serde_json::from_str(&line).unwrap();
                assert_eq!(request["role"], "tool");
                let response = serde_json::json!({
                    "id": request["id"],
                    "compressed": format!("compressed-{i}"),
                    "tokens_before": 100,
                    "tokens_after": 10,
                });
                reader
                    .get_mut()
                    .write_all((response.to_string() + "\n").as_bytes())
                    .await
                    .unwrap();
            }
        });
        let mut sidecar = HeadroomSidecar::connect(&socket).await.unwrap();
        assert_eq!(
            sidecar.compress("raw one", "tool-a").await.as_deref(),
            Some("compressed-0")
        );
        assert_eq!(
            sidecar.compress("raw two", "tool-b").await.as_deref(),
            Some("compressed-1")
        );
        server.await.unwrap();
        let _ = std::fs::remove_file(&socket);
        let _ = std::fs::remove_dir(socket.parent().unwrap());
    }

    #[tokio::test]
    async fn headroom_sidecar_timeout_degrades_to_none() {
        let socket = test_socket_path("timeout");
        let listener = tokio::net::UnixListener::bind(&socket).unwrap();
        // Fake sidecar: read the request, then stall past the client timeout.
        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let mut reader = BufReader::new(stream);
            let mut line = String::new();
            assert!(reader.read_line(&mut line).await.unwrap() > 0);
            tokio::time::sleep(Duration::from_millis(200)).await;
        });
        let mut sidecar = HeadroomSidecar::connect(&socket).await.unwrap();
        assert_eq!(
            sidecar
                .compress_with_timeout("raw", "slow-tool", Duration::from_millis(25))
                .await,
            None
        );
        server.await.unwrap();
        let _ = std::fs::remove_file(&socket);
        let _ = std::fs::remove_dir(socket.parent().unwrap());
    }
}