colibri/crates/colibri-daemon/src/cost.rs

478 lines
15 KiB
Rust
Raw Normal View History

//! Cost discipline — cache-first prompt assembly with three cost modes.
//!
//! Phase 5 (Reasonix ideas): deterministic 3-region prompt assembler,
//! cost modes (fast/smart/max), visible escalation, compaction of large
//! tool results, and a client for the headroom compression sidecar.
//!
//! Region model (from session.rs):
//! 1. Immutable system prefix — byte-stable for DeepSeek cache hits
//! 2. Appendable conversation log — turns accumulate until compaction
//! 3. Volatile scratch — discarded per-turn, never persisted
use serde::{Deserialize, Serialize};
use std::time::Duration;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::net::UnixStream;
use tokio::time::timeout;
use tracing::{debug, info, warn};
// ---------------------------------------------------------------------------
// Cost mode
// ---------------------------------------------------------------------------
/// How much context the daemon is willing to pay for, which in turn decides
/// how aggressively it compacts.
///
/// Serialized in lower case (`"fast"`, `"smart"`, `"max"`); `Smart` is the
/// default.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum CostMode {
    /// Cheapest per turn: the smallest context budget and the most
    /// aggressive compaction.
    Fast,
    /// The balanced default, using standard compaction thresholds.
    #[default]
    Smart,
    /// Most expensive: the largest thresholds, and tool results are never
    /// compacted.
    Max,
}
impl CostMode {
pub fn as_str(&self) -> &'static str {
match self {
Self::Fast => "fast",
Self::Smart => "smart",
Self::Max => "max",
}
}
pub fn parse(s: &str) -> Option<Self> {
match s {
"fast" => Some(Self::Fast),
"smart" => Some(Self::Smart),
"max" => Some(Self::Max),
_ => None,
}
}
/// Session max bytes under this mode.
pub fn session_max_bytes(&self) -> u64 {
match self {
Self::Fast => 500_000, // ~12K tokens worth
Self::Smart => 2_000_000, // ~50K tokens
Self::Max => 8_000_000, // ~200K tokens
}
}
/// Max uncompacted turns before compaction triggers.
pub fn max_uncompacted_turns(&self) -> usize {
match self {
Self::Fast => 5,
Self::Smart => 20,
Self::Max => 100,
}
}
/// Whether to compact large tool results.
pub fn compact_tool_results(&self) -> bool {
match self {
Self::Fast => true,
Self::Smart => true,
Self::Max => false,
}
}
/// Max bytes for a single tool result before compaction kicks in.
pub fn tool_result_max_bytes(&self) -> u64 {
match self {
Self::Fast => 4_000,
Self::Smart => 16_000,
Self::Max => u64::MAX,
}
}
}
// ---------------------------------------------------------------------------
// Escalation
// ---------------------------------------------------------------------------
/// Move one rung up the cost ladder: Fast → Smart → Max.
///
/// `Max` is the ceiling and maps to itself. Every real step (the result
/// differs from `current`) is logged at `info` level so operators can see
/// cost increases; staying at `Max` logs nothing.
pub fn escalate(current: CostMode) -> CostMode {
    let next = match current {
        CostMode::Fast => CostMode::Smart,
        CostMode::Smart | CostMode::Max => CostMode::Max,
    };
    if next == current {
        return next;
    }
    info!(
        from = current.as_str(),
        to = next.as_str(),
        "cost mode escalated"
    );
    next
}
/// The condition that prompted a possible move to a more expensive cost mode.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum EscalationTrigger {
    /// The session's byte count went over the active mode's budget.
    BudgetExceeded { current_bytes: u64, budget_bytes: u64 },
    /// A compaction pass reclaimed fewer bytes than were required.
    CompactionInsufficient {
        freed_bytes: u64,
        needed_bytes: u64,
    },
}
/// Decide whether `trigger` justifies moving up from `current`.
///
/// Returns `Some(next_mode)` (after [`escalate`] has logged the step) when:
/// - `BudgetExceeded`: `current_bytes` is strictly greater than `budget_bytes`;
/// - `CompactionInsufficient`: `freed_bytes` is strictly less than `needed_bytes`.
///
/// Returns `None` when `current` is already `Max` or the trigger stays within
/// its limit.
pub fn auto_escalate(current: CostMode, trigger: &EscalationTrigger) -> Option<CostMode> {
    // Max is the ceiling; there is nowhere further to go.
    if matches!(current, CostMode::Max) {
        return None;
    }
    let over_threshold = match *trigger {
        EscalationTrigger::BudgetExceeded {
            current_bytes,
            budget_bytes,
        } => budget_bytes < current_bytes,
        EscalationTrigger::CompactionInsufficient {
            freed_bytes,
            needed_bytes,
        } => needed_bytes > freed_bytes,
    };
    if !over_threshold {
        return None;
    }
    let next = escalate(current);
    (next != current).then_some(next)
}
// ---------------------------------------------------------------------------
// Tool result compaction
// ---------------------------------------------------------------------------
/// Compact an oversized tool result into a summary.
///
/// If `raw` is longer than `max_bytes` bytes, returns the longest prefix of
/// `raw` that fits in `max_bytes` *and* ends on a UTF-8 character boundary,
/// wrapped in a header (original and kept byte counts) and a footer (omitted
/// byte count). Otherwise returns `None` (no compaction needed).
///
/// The byte counts in the header describe the payload only. The returned
/// string is longer than `max_bytes` by the size of the annotations.
pub fn compact_tool_result(raw: &str, max_bytes: u64, tool_name: &str) -> Option<String> {
    if raw.len() as u64 <= max_bytes {
        return None;
    }
    // Here max_bytes < raw.len() <= usize::MAX, so the conversion cannot fail.
    // The fallback only keeps the expression total.
    let mut cut = usize::try_from(max_bytes).unwrap_or(raw.len());
    // Back off to the start of the character straddling the limit. This is a
    // portable stand-in for `str::floor_char_boundary`, which sits behind the
    // unstable `round_char_boundary` feature on older toolchains. The loop
    // terminates because index 0 is always a boundary, and it runs at most
    // three times (UTF-8 sequences are at most four bytes).
    while !raw.is_char_boundary(cut) {
        cut -= 1;
    }
    let (kept, omitted) = raw.split_at(cut);
    Some(format!(
        "[{tool_name} output truncated: {} bytes → {} bytes]\n{kept}\n[... {} more bytes omitted]",
        raw.len(),
        cut,
        omitted.len(),
    ))
}
// ---------------------------------------------------------------------------
// Headroom compression sidecar
// ---------------------------------------------------------------------------
const HEADROOM_REQUEST_TIMEOUT: Duration = Duration::from_secs(5);
/// A long-lived connection to the headroom compression sidecar process.
///
/// The sidecar is a Python script (`scripts/headroom-sidecar.py`) that
/// listens on a Unix domain socket. Each request is a single JSON line;
/// each response is a single JSON line. The connection is kept open and
/// reused across tool results.
pub struct HeadroomSidecar {
stream: UnixStream,
socket_path: std::path::PathBuf,
}
impl HeadroomSidecar {
/// Connect to the headroom sidecar at the given socket path.
pub async fn connect(socket_path: &std::path::Path) -> std::io::Result<Self> {
let stream = UnixStream::connect(socket_path).await?;
info!(
socket = %socket_path.display(),
"connected to headroom sidecar"
);
Ok(Self {
stream,
socket_path: socket_path.to_path_buf(),
})
}
/// Compress a tool result through the sidecar.
///
/// Returns the compressed content on success, or `None` if compression
/// failed, timed out, or produced no savings (graceful degradation).
pub async fn compress(&mut self, raw: &str, tool_name: &str) -> Option<String> {
self.compress_with_timeout(raw, tool_name, HEADROOM_REQUEST_TIMEOUT)
.await
}
async fn compress_with_timeout(
&mut self,
raw: &str,
tool_name: &str,
timeout_duration: Duration,
) -> Option<String> {
match timeout(timeout_duration, self.compress_once(raw, tool_name)).await {
Ok(result) => result,
Err(_) => {
warn!(
tool = %tool_name,
timeout_ms = timeout_duration.as_millis(),
"headroom sidecar request timed out"
);
None
}
}
}
async fn compress_once(&mut self, raw: &str, tool_name: &str) -> Option<String> {
let request = serde_json::json!({
"id": tool_name,
"raw": raw,
"role": "tool",
});
let payload = serde_json::to_string(&request).unwrap_or_default() + "\n";
match self.stream.write_all(payload.as_bytes()).await {
Ok(()) => {}
Err(e) => {
warn!(error = %e, "headroom sidecar write failed");
return None;
}
}
let mut reader = BufReader::new(&mut self.stream);
let mut line = String::new();
match reader.read_line(&mut line).await {
Ok(0) => {
warn!("headroom sidecar closed connection");
return None;
}
Ok(_) => {}
Err(e) => {
warn!(error = %e, "headroom sidecar read failed");
return None;
}
}
let response: serde_json::Value = match serde_json::from_str(&line) {
Ok(v) => v,
Err(e) => {
warn!(error = %e, raw_line = %line, "headroom sidecar returned invalid JSON");
return None;
}
};
if response.get("error").is_some() {
warn!(error = %response["error"], "headroom sidecar error");
return None;
}
let tokens_before = response["tokens_before"].as_u64().unwrap_or(0);
let tokens_after = response["tokens_after"].as_u64().unwrap_or(0);
let compressed = response["compressed"].as_str().unwrap_or(raw);
if tokens_after >= tokens_before {
debug!(
tool = %tool_name,
tokens_before,
tokens_after,
"headroom: no savings, returning original"
);
return None;
}
info!(
tool = %tool_name,
tokens_before,
tokens_after,
saved = tokens_before.saturating_sub(tokens_after),
pct = 100 - (tokens_after * 100 / tokens_before.max(1)),
"headroom compression applied"
);
Some(compressed.to_string())
}
}
impl Drop for HeadroomSidecar {
    /// Emit an `info` log naming the socket when this sidecar handle is
    /// released.
    fn drop(&mut self) {
        let socket = self.socket_path.display();
        info!(socket = %socket, "headroom sidecar disconnected");
    }
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_cost_mode_defaults() {
        assert_eq!(CostMode::default(), CostMode::Smart);
    }
    #[test]
    fn test_fast_thresholds() {
        assert_eq!(CostMode::Fast.session_max_bytes(), 500_000);
        assert_eq!(CostMode::Fast.max_uncompacted_turns(), 5);
        assert!(CostMode::Fast.compact_tool_results());
    }
    #[test]
    fn test_max_thresholds() {
        assert_eq!(CostMode::Max.session_max_bytes(), 8_000_000);
        assert_eq!(CostMode::Max.max_uncompacted_turns(), 100);
        assert!(!CostMode::Max.compact_tool_results());
    }
    #[test]
    fn test_escalation_fast_to_smart() {
        assert_eq!(escalate(CostMode::Fast), CostMode::Smart);
    }
    #[test]
    fn test_escalation_smart_to_max() {
        assert_eq!(escalate(CostMode::Smart), CostMode::Max);
    }
    #[test]
    fn test_escalation_max_ceiling() {
        assert_eq!(escalate(CostMode::Max), CostMode::Max);
    }
    #[test]
    fn test_parse_cost_mode() {
        assert_eq!(CostMode::parse("fast"), Some(CostMode::Fast));
        assert_eq!(CostMode::parse("smart"), Some(CostMode::Smart));
        assert_eq!(CostMode::parse("max"), Some(CostMode::Max));
        assert_eq!(CostMode::parse("unknown"), None);
    }
    #[test]
    fn test_tool_result_compaction_needed() {
        let big = "x".repeat(10_000);
        let result = compact_tool_result(&big, 4_000, "test_tool");
        assert!(result.is_some());
        let compacted = result.unwrap();
        assert!(compacted.contains("truncated"));
        assert!(compacted.contains("10000 bytes"));
    }
    #[test]
    fn test_tool_result_compaction_not_needed() {
        let small = "ok";
        assert!(compact_tool_result(small, 4_000, "test_tool").is_none());
    }
    #[test]
    fn test_tool_result_compaction_exact_limit() {
        // A result exactly at the limit is left alone; one byte over is compacted.
        let exact = "x".repeat(10);
        assert!(compact_tool_result(&exact, 10, "edge").is_none());
        assert!(compact_tool_result(&exact, 9, "edge").is_some());
    }
    #[test]
    fn test_tool_result_compaction_multibyte_no_panic() {
        // Each "äöü日本語" unit is 15 bytes (3×2 + 3×3). A 50-byte limit lands
        // on the second byte of the fourth unit's "ü", so the cut must back
        // off to 49 bytes.
        let big = "äöü日本語".repeat(2_000);
        assert!(!big.is_char_boundary(50), "limit must fall mid-character");
        let result = compact_tool_result(&big, 50, "unicode_tool");
        assert!(result.is_some());
        let compacted = result.unwrap();
        let expected = format!(
            "[unicode_tool output truncated: 30000 bytes → 49 bytes]\n{}äö\n[... 29951 more bytes omitted]",
            "äöü日本語".repeat(3)
        );
        assert_eq!(compacted, expected);
    }
    #[test]
    fn test_tool_result_compaction_slovenian_no_panic() {
        // š, č, ž are 2-byte UTF-8 — same family as German umlauts.
        // "Cene že še češnje je" = 24 bytes (16 ASCII + 4×2 multibyte: ž, š, č, š).
        // A 30-byte limit lands on the second byte of the second unit's "ž"
        // (bytes 29..31), so the cut must back off to 29 bytes.
        let big = "Cene že še češnje je".repeat(2_000);
        assert!(!big.is_char_boundary(30), "limit must fall mid-character");
        let result = compact_tool_result(&big, 30, "cene");
        assert!(result.is_some());
        let compacted = result.unwrap();
        assert_eq!(
            compacted,
            "[cene output truncated: 48000 bytes → 29 bytes]\nCene že še češnje jeCene \n[... 47971 more bytes omitted]"
        );
    }
    fn test_socket_path(name: &str) -> std::path::PathBuf {
        let dir =
            std::env::temp_dir().join(format!("colibri-headroom-{name}-{}", uuid::Uuid::new_v4()));
        std::fs::create_dir_all(&dir).unwrap();
        dir.join("headroom.sock")
    }
    #[tokio::test]
    async fn headroom_sidecar_reuses_one_connection_for_multiple_requests() {
        let socket = test_socket_path("reuse");
        let listener = tokio::net::UnixListener::bind(&socket).unwrap();
        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let mut reader = BufReader::new(stream);
            for i in 0..2 {
                let mut line = String::new();
                assert!(reader.read_line(&mut line).await.unwrap() > 0);
                let request: serde_json::Value = serde_json::from_str(&line).unwrap();
                assert_eq!(request["role"], "tool");
                let response = serde_json::json!({
                    "id": request["id"],
                    "compressed": format!("compressed-{i}"),
                    "tokens_before": 100,
                    "tokens_after": 10,
                });
                reader
                    .get_mut()
                    .write_all((response.to_string() + "\n").as_bytes())
                    .await
                    .unwrap();
            }
        });
        let mut sidecar = HeadroomSidecar::connect(&socket).await.unwrap();
        assert_eq!(
            sidecar.compress("raw one", "tool-a").await.as_deref(),
            Some("compressed-0")
        );
        assert_eq!(
            sidecar.compress("raw two", "tool-b").await.as_deref(),
            Some("compressed-1")
        );
        server.await.unwrap();
        let _ = std::fs::remove_file(&socket);
        let _ = std::fs::remove_dir(socket.parent().unwrap());
    }
    #[tokio::test]
    async fn headroom_sidecar_timeout_degrades_to_none() {
        let socket = test_socket_path("timeout");
        let listener = tokio::net::UnixListener::bind(&socket).unwrap();
        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let mut reader = BufReader::new(stream);
            let mut line = String::new();
            assert!(reader.read_line(&mut line).await.unwrap() > 0);
            tokio::time::sleep(Duration::from_millis(200)).await;
        });
        let mut sidecar = HeadroomSidecar::connect(&socket).await.unwrap();
        assert_eq!(
            sidecar
                .compress_with_timeout("raw", "slow-tool", Duration::from_millis(25))
                .await,
            None
        );
        server.await.unwrap();
        let _ = std::fs::remove_file(&socket);
        let _ = std::fs::remove_dir(socket.parent().unwrap());
    }
}