mirror of
https://github.com/patriceckhart/zot.git
synced 2026-06-26 13:26:33 +02:00
Add temperature option
This commit is contained in:
parent
798174c22c
commit
85a3c3b73e
8 changed files with 87 additions and 25 deletions
|
|
@ -3,6 +3,7 @@ package agent
|
|||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/patriceckhart/zot/packages/tui"
|
||||
|
|
@ -37,6 +38,7 @@ type Args struct {
|
|||
SystemPrompt string
|
||||
AppendSystemPrompt []string
|
||||
Reasoning string
|
||||
Temperature *float32
|
||||
|
||||
Continue bool
|
||||
Resume bool
|
||||
|
|
@ -201,6 +203,17 @@ func ParseArgs(in []string) (Args, error) {
|
|||
default:
|
||||
return a, fmt.Errorf("--reasoning must be off|minimum|low|medium|high|maximum")
|
||||
}
|
||||
case "--temperature":
|
||||
v, err := want(&i, arg)
|
||||
if err != nil {
|
||||
return a, err
|
||||
}
|
||||
f, err := strconv.ParseFloat(v, 32)
|
||||
if err != nil || f < 0 || f > 2 {
|
||||
return a, fmt.Errorf("--temperature must be a number between 0 and 2")
|
||||
}
|
||||
t := float32(f)
|
||||
a.Temperature = &t
|
||||
case "--session":
|
||||
v, err := want(&i, arg)
|
||||
if err != nil {
|
||||
|
|
@ -361,6 +374,7 @@ func PrintHelp(version string) {
|
|||
row{"--api-key KEY", "api key for this run (env / auth.json fallback)"},
|
||||
row{"--base-url URL", "override provider api base url"},
|
||||
row{"--reasoning off|minimum|low|medium|high|maximum", "set thinking level on supported models"},
|
||||
row{"--temperature N", "sampling temperature, 0 to 2 (omit for provider default)"},
|
||||
)
|
||||
section("prompt and session flags",
|
||||
row{"--system-prompt TEXT", "replace the default system prompt"},
|
||||
|
|
|
|||
19
packages/agent/args_test.go
Normal file
19
packages/agent/args_test.go
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
package agent
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParseArgsTemperatureAllowsZero(t *testing.T) {
|
||||
args, err := ParseArgs([]string{"--temperature", "0"})
|
||||
if err != nil {
|
||||
t.Fatalf("ParseArgs returned %v", err)
|
||||
}
|
||||
if args.Temperature == nil || *args.Temperature != 0 {
|
||||
t.Fatalf("Temperature = %v; want 0", args.Temperature)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseArgsTemperatureRejectsOutOfRange(t *testing.T) {
|
||||
if _, err := ParseArgs([]string{"--temperature", "2.1"}); err == nil {
|
||||
t.Fatal("ParseArgs accepted out-of-range temperature")
|
||||
}
|
||||
}
|
||||
|
|
@ -17,14 +17,15 @@ import (
|
|||
|
||||
// Resolved is the effective configuration after merging CLI, config, defaults.
|
||||
type Resolved struct {
|
||||
Provider string
|
||||
Model string
|
||||
Credential string // api key or oauth access token
|
||||
AuthMethod string // "apikey" | "oauth" | "" (no credential yet)
|
||||
AccountID string // ChatGPT account id (for openai oauth), "" otherwise
|
||||
BaseURL string
|
||||
CWD string
|
||||
Reasoning string
|
||||
Provider string
|
||||
Model string
|
||||
Credential string // api key or oauth access token
|
||||
AuthMethod string // "apikey" | "oauth" | "" (no credential yet)
|
||||
AccountID string // ChatGPT account id (for openai oauth), "" otherwise
|
||||
BaseURL string
|
||||
CWD string
|
||||
Reasoning string
|
||||
Temperature *float32
|
||||
|
||||
ToolRegistry core.Registry
|
||||
ToolSummary []ToolSummary
|
||||
|
|
@ -492,6 +493,10 @@ func Resolve(args Args, requireCred bool) (Resolved, error) {
|
|||
})
|
||||
|
||||
reasoning := provider.NormalizeReasoning(firstNonEmpty(args.Reasoning, cfg.Reasoning))
|
||||
temperature := args.Temperature
|
||||
if temperature == nil {
|
||||
temperature = cfg.Temperature
|
||||
}
|
||||
|
||||
max := args.MaxSteps // 0 = unlimited
|
||||
|
||||
|
|
@ -504,6 +509,7 @@ func Resolve(args Args, requireCred bool) (Resolved, error) {
|
|||
BaseURL: args.BaseURL,
|
||||
CWD: args.CWD,
|
||||
Reasoning: reasoning,
|
||||
Temperature: temperature,
|
||||
ToolRegistry: reg,
|
||||
ToolSummary: summaries,
|
||||
SystemPrompt: sys,
|
||||
|
|
@ -778,6 +784,7 @@ func (r Resolved) NewAgent() *core.Agent {
|
|||
a.MaxSteps = r.MaxSteps
|
||||
a.MaxTokens = r.MaxOutput
|
||||
a.Reasoning = r.Reasoning
|
||||
a.Temperature = r.Temperature
|
||||
return a
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -16,10 +16,11 @@ import (
|
|||
|
||||
// Config is the persisted user configuration.
|
||||
type Config struct {
|
||||
Provider string `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
Reasoning string `json:"reasoning"`
|
||||
Theme string `json:"theme"`
|
||||
Provider string `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
Reasoning string `json:"reasoning"`
|
||||
Temperature *float32 `json:"temperature,omitempty"`
|
||||
Theme string `json:"theme"`
|
||||
|
||||
// InlineImagesEnabled controls whether zot draws screenshots inline
|
||||
// when the terminal supports an image protocol. nil/missing means
|
||||
|
|
|
|||
|
|
@ -61,6 +61,9 @@ type Config struct {
|
|||
// ("low", "medium", "high"). Empty = no reasoning.
|
||||
Reasoning string
|
||||
|
||||
// Temperature sets the sampling temperature. Nil = provider default.
|
||||
Temperature *float32
|
||||
|
||||
// MaxSteps caps the agent loop iterations per Prompt call.
|
||||
// 0 uses the default (50).
|
||||
MaxSteps int
|
||||
|
|
@ -112,6 +115,7 @@ func New(cfg Config) (*Runtime, error) {
|
|||
SystemPrompt: cfg.SystemPrompt,
|
||||
AppendSystemPrompt: cfg.AppendSystemPrompt,
|
||||
Reasoning: cfg.Reasoning,
|
||||
Temperature: cfg.Temperature,
|
||||
MaxSteps: cfg.MaxSteps,
|
||||
Tools: cfg.Tools,
|
||||
NoTools: cfg.NoTools,
|
||||
|
|
|
|||
|
|
@ -15,12 +15,13 @@ import (
|
|||
// Agent is a stateful conversation bound to a provider client, a model,
|
||||
// and a set of tools.
|
||||
type Agent struct {
|
||||
Client provider.Client
|
||||
Model string
|
||||
System string
|
||||
Tools Registry
|
||||
MaxSteps int
|
||||
Reasoning string
|
||||
Client provider.Client
|
||||
Model string
|
||||
System string
|
||||
Tools Registry
|
||||
MaxSteps int
|
||||
Reasoning string
|
||||
Temperature *float32
|
||||
|
||||
// MaxTokens caps the model's output tokens per turn. Zero leaves
|
||||
// the field unset on the provider request, letting each provider
|
||||
|
|
@ -520,10 +521,11 @@ func (a *Agent) oneTurn(ctx context.Context, sink func(AgentEvent)) (provider.St
|
|||
// next in-process request is rejected by providers like Anthropic
|
||||
// with "tool_use ids were found without tool_result blocks". The
|
||||
// repair is pure and a no-op on already-valid transcripts.
|
||||
Messages: repairToolUseResultPairs(a.Messages()),
|
||||
Tools: a.Tools.Specs(),
|
||||
Reasoning: a.Reasoning,
|
||||
MaxTokens: a.MaxTokens,
|
||||
Messages: repairToolUseResultPairs(a.Messages()),
|
||||
Tools: a.Tools.Specs(),
|
||||
Reasoning: a.Reasoning,
|
||||
MaxTokens: a.MaxTokens,
|
||||
Temperature: a.Temperature,
|
||||
}
|
||||
stream, err := a.Client.Stream(ctx, req)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -144,3 +144,17 @@ func TestAgentPropagatesMaxTokens(t *testing.T) {
|
|||
t.Fatalf("request MaxTokens = %d; want 64000 (Agent.MaxTokens not propagated)", client.lastReq.MaxTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentPropagatesTemperature(t *testing.T) {
|
||||
client := &captureClient{}
|
||||
a := NewAgent(client, "fake-model", "system", Registry{})
|
||||
temp := float32(0)
|
||||
a.Temperature = &temp
|
||||
|
||||
if err := a.Prompt(context.Background(), "hello", nil, nil); err != nil {
|
||||
t.Fatalf("Prompt returned %v", err)
|
||||
}
|
||||
if client.lastReq.Temperature == nil || *client.lastReq.Temperature != temp {
|
||||
t.Fatalf("request Temperature = %v; want %v", client.lastReq.Temperature, temp)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -47,9 +47,10 @@ func (a *Agent) Compact(ctx context.Context, keepTail int, sink func(delta strin
|
|||
prompt := "<conversation>\n" + transcript + "\n</conversation>\n\n" + compactionPrompt
|
||||
|
||||
req := provider.Request{
|
||||
Model: a.Model,
|
||||
System: summarizationSystem,
|
||||
MaxTokens: 4096,
|
||||
Model: a.Model,
|
||||
System: summarizationSystem,
|
||||
MaxTokens: 4096,
|
||||
Temperature: a.Temperature,
|
||||
Messages: []provider.Message{
|
||||
{
|
||||
Role: provider.RoleUser,
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue