From 85a3c3b73ec7be46abfcadc1b696d349eb9f1e58 Mon Sep 17 00:00:00 2001 From: patriceckhart Date: Sun, 14 Jun 2026 11:42:31 +0200 Subject: [PATCH] Add temperature option --- packages/agent/args.go | 15 +++++++++++++++ packages/agent/args_test.go | 19 +++++++++++++++++++ packages/agent/build.go | 23 +++++++++++++++-------- packages/agent/config.go | 9 +++++---- packages/agent/sdk/sdk.go | 4 ++++ packages/core/agent.go | 22 ++++++++++++---------- packages/core/agent_retry_test.go | 14 ++++++++++++++ packages/core/compact.go | 7 ++++--- 8 files changed, 88 insertions(+), 25 deletions(-) create mode 100644 packages/agent/args_test.go diff --git a/packages/agent/args.go b/packages/agent/args.go index b82ce58..6ec93b9 100644 --- a/packages/agent/args.go +++ b/packages/agent/args.go @@ -3,6 +3,7 @@ package agent import ( "fmt" "os" + "strconv" "strings" "github.com/patriceckhart/zot/packages/tui" @@ -37,6 +38,7 @@ type Args struct { SystemPrompt string AppendSystemPrompt []string Reasoning string + Temperature *float32 Continue bool Resume bool @@ -201,6 +203,18 @@ func ParseArgs(in []string) (Args, error) { default: return a, fmt.Errorf("--reasoning must be off|minimum|low|medium|high|maximum") } + case "--temperature": + v, err := want(&i, arg) + if err != nil { + return a, err + } + f, err := strconv.ParseFloat(v, 32) + // ParseFloat accepts "NaN", and NaN fails every comparison, so require membership in [0, 2] instead of testing for values outside it. + if err != nil || !(f >= 0 && f <= 2) { + return a, fmt.Errorf("--temperature must be a number between 0 and 2") + } + t := float32(f) + a.Temperature = &t case "--session": v, err := want(&i, arg) if err != nil { @@ -361,6 +375,7 @@ func PrintHelp(version string) { row{"--api-key KEY", "api key for this run (env / auth.json fallback)"}, row{"--base-url URL", "override provider api base url"}, row{"--reasoning off|minimum|low|medium|high|maximum", "set thinking level on supported models"}, + row{"--temperature N", "sampling temperature, 0 to 2 (omit for provider default)"}, ) section("prompt and session flags", row{"--system-prompt TEXT", "replace the default 
system prompt"}, diff --git a/packages/agent/args_test.go b/packages/agent/args_test.go new file mode 100644 index 0000000..45ad356 --- /dev/null +++ b/packages/agent/args_test.go @@ -0,0 +1,19 @@ +package agent + +import "testing" + +func TestParseArgsTemperatureAllowsZero(t *testing.T) { + args, err := ParseArgs([]string{"--temperature", "0"}) + if err != nil { + t.Fatalf("ParseArgs returned %v", err) + } + if args.Temperature == nil || *args.Temperature != 0 { + t.Fatalf("Temperature = %v; want 0", args.Temperature) + } +} + +func TestParseArgsTemperatureRejectsOutOfRange(t *testing.T) { + if _, err := ParseArgs([]string{"--temperature", "2.1"}); err == nil { + t.Fatal("ParseArgs accepted out-of-range temperature") + } +} diff --git a/packages/agent/build.go b/packages/agent/build.go index 76aac8f..7b54b42 100644 --- a/packages/agent/build.go +++ b/packages/agent/build.go @@ -17,14 +17,15 @@ import ( // Resolved is the effective configuration after merging CLI, config, defaults. type Resolved struct { - Provider string - Model string - Credential string // api key or oauth access token - AuthMethod string // "apikey" | "oauth" | "" (no credential yet) - AccountID string // ChatGPT account id (for openai oauth), "" otherwise - BaseURL string - CWD string - Reasoning string + Provider string + Model string + Credential string // api key or oauth access token + AuthMethod string // "apikey" | "oauth" | "" (no credential yet) + AccountID string // ChatGPT account id (for openai oauth), "" otherwise + BaseURL string + CWD string + Reasoning string + Temperature *float32 ToolRegistry core.Registry ToolSummary []ToolSummary @@ -492,6 +493,10 @@ func Resolve(args Args, requireCred bool) (Resolved, error) { }) reasoning := provider.NormalizeReasoning(firstNonEmpty(args.Reasoning, cfg.Reasoning)) + temperature := args.Temperature + if temperature == nil { + temperature = cfg.Temperature + } max := args.MaxSteps // 0 = unlimited @@ -504,6 +509,7 @@ func Resolve(args Args, 
requireCred bool) (Resolved, error) { BaseURL: args.BaseURL, CWD: args.CWD, Reasoning: reasoning, + Temperature: temperature, ToolRegistry: reg, ToolSummary: summaries, SystemPrompt: sys, @@ -778,6 +784,7 @@ func (r Resolved) NewAgent() *core.Agent { a.MaxSteps = r.MaxSteps a.MaxTokens = r.MaxOutput a.Reasoning = r.Reasoning + a.Temperature = r.Temperature return a } diff --git a/packages/agent/config.go b/packages/agent/config.go index 4e1a359..ed6ac9f 100644 --- a/packages/agent/config.go +++ b/packages/agent/config.go @@ -16,10 +16,11 @@ import ( // Config is the persisted user configuration. type Config struct { - Provider string `json:"provider"` - Model string `json:"model"` - Reasoning string `json:"reasoning"` - Theme string `json:"theme"` + Provider string `json:"provider"` + Model string `json:"model"` + Reasoning string `json:"reasoning"` + Temperature *float32 `json:"temperature,omitempty"` + Theme string `json:"theme"` // InlineImagesEnabled controls whether zot draws screenshots inline // when the terminal supports an image protocol. nil/missing means diff --git a/packages/agent/sdk/sdk.go b/packages/agent/sdk/sdk.go index ca45c65..ea1bfb9 100644 --- a/packages/agent/sdk/sdk.go +++ b/packages/agent/sdk/sdk.go @@ -61,6 +61,9 @@ type Config struct { // ("low", "medium", "high"). Empty = no reasoning. Reasoning string + // Temperature sets the sampling temperature. Nil = provider default. + Temperature *float32 + // MaxSteps caps the agent loop iterations per Prompt call. // 0 uses the default (50). 
MaxSteps int @@ -112,6 +115,7 @@ func New(cfg Config) (*Runtime, error) { SystemPrompt: cfg.SystemPrompt, AppendSystemPrompt: cfg.AppendSystemPrompt, Reasoning: cfg.Reasoning, + Temperature: cfg.Temperature, MaxSteps: cfg.MaxSteps, Tools: cfg.Tools, NoTools: cfg.NoTools, diff --git a/packages/core/agent.go b/packages/core/agent.go index cf8731c..50acfa9 100644 --- a/packages/core/agent.go +++ b/packages/core/agent.go @@ -15,12 +15,13 @@ import ( // Agent is a stateful conversation bound to a provider client, a model, // and a set of tools. type Agent struct { - Client provider.Client - Model string - System string - Tools Registry - MaxSteps int - Reasoning string + Client provider.Client + Model string + System string + Tools Registry + MaxSteps int + Reasoning string + Temperature *float32 // MaxTokens caps the model's output tokens per turn. Zero leaves // the field unset on the provider request, letting each provider @@ -520,10 +521,11 @@ func (a *Agent) oneTurn(ctx context.Context, sink func(AgentEvent)) (provider.St // next in-process request is rejected by providers like Anthropic // with "tool_use ids were found without tool_result blocks". The // repair is pure and a no-op on already-valid transcripts. 
- Messages: repairToolUseResultPairs(a.Messages()), - Tools: a.Tools.Specs(), - Reasoning: a.Reasoning, - MaxTokens: a.MaxTokens, + Messages: repairToolUseResultPairs(a.Messages()), + Tools: a.Tools.Specs(), + Reasoning: a.Reasoning, + MaxTokens: a.MaxTokens, + Temperature: a.Temperature, } stream, err := a.Client.Stream(ctx, req) if err != nil { diff --git a/packages/core/agent_retry_test.go b/packages/core/agent_retry_test.go index d234817..2d6b4bd 100644 --- a/packages/core/agent_retry_test.go +++ b/packages/core/agent_retry_test.go @@ -144,3 +144,17 @@ func TestAgentPropagatesMaxTokens(t *testing.T) { t.Fatalf("request MaxTokens = %d; want 64000 (Agent.MaxTokens not propagated)", client.lastReq.MaxTokens) } } + +func TestAgentPropagatesTemperature(t *testing.T) { + client := &captureClient{} + a := NewAgent(client, "fake-model", "system", Registry{}) + temp := float32(0) + a.Temperature = &temp + + if err := a.Prompt(context.Background(), "hello", nil, nil); err != nil { + t.Fatalf("Prompt returned %v", err) + } + if client.lastReq.Temperature == nil || *client.lastReq.Temperature != temp { + t.Fatalf("request Temperature = %v; want %v", client.lastReq.Temperature, temp) + } +} diff --git a/packages/core/compact.go b/packages/core/compact.go index f73c2e1..a6234c6 100644 --- a/packages/core/compact.go +++ b/packages/core/compact.go @@ -47,9 +47,10 @@ func (a *Agent) Compact(ctx context.Context, keepTail int, sink func(delta strin prompt := "\n" + transcript + "\n\n\n" + compactionPrompt req := provider.Request{ - Model: a.Model, - System: summarizationSystem, - MaxTokens: 4096, + Model: a.Model, + System: summarizationSystem, + MaxTokens: 4096, + Temperature: a.Temperature, Messages: []provider.Message{ { Role: provider.RoleUser,