mirror of
https://github.com/patriceckhart/zot.git
synced 2026-06-26 21:36:31 +02:00
Merge 1ae05194e7 into b325477870
This commit is contained in:
commit
a436d36b3d
11 changed files with 888 additions and 517 deletions
|
|
@ -10,13 +10,14 @@ import (
|
|||
"os/exec"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/patriceckhart/zot/packages/agent/modes/bot"
|
||||
"github.com/patriceckhart/zot/packages/agent/modes/telegram"
|
||||
"github.com/patriceckhart/zot/packages/core"
|
||||
"github.com/patriceckhart/zot/packages/provider"
|
||||
)
|
||||
|
||||
// detachChild configures cmd to run in its own process group so tty
|
||||
|
|
@ -356,24 +357,35 @@ func botRun(rawTail []string, version string) error {
|
|||
s, _, serr := openOrCreateSessionForBot(args, resolved, agent, version)
|
||||
if serr == nil {
|
||||
sess = s
|
||||
agent.OnMessageAppended = func(msg provider.Message) {
|
||||
_ = sess.AppendMessage(msg)
|
||||
}
|
||||
agent.OnUsage = func(u provider.Usage) {
|
||||
_ = sess.AppendUsage(u, u)
|
||||
}
|
||||
agent.OnTranscriptCompacted = func(msgs []provider.Message) {
|
||||
_ = sess.AppendCompaction(msgs)
|
||||
}
|
||||
defer sess.Close()
|
||||
} else {
|
||||
fmt.Fprintln(os.Stderr, "session:", serr)
|
||||
}
|
||||
}
|
||||
|
||||
var b *telegram.Bot
|
||||
b = &telegram.Bot{
|
||||
Client: telegram.NewClient(cfg.BotToken),
|
||||
Agent: agent,
|
||||
Config: cfg,
|
||||
// Construct the Telegram adapter and generic runner.
|
||||
adapter := telegram.NewAdapter(
|
||||
telegram.NewClient(cfg.BotToken),
|
||||
&cfg,
|
||||
func(c telegram.Config) error {
|
||||
return telegram.SaveConfig(ZotHome(), c)
|
||||
},
|
||||
)
|
||||
var runner *bot.Runner
|
||||
runner = bot.NewRunner(adapter, agent, bot.Config{
|
||||
ZotHome: ZotHome(),
|
||||
Provider: resolved.Provider,
|
||||
AuthMethod: resolved.AuthMethod,
|
||||
CWD: args.CWD,
|
||||
Save: func(c telegram.Config) error {
|
||||
return telegram.SaveConfig(ZotHome(), c)
|
||||
},
|
||||
RefreshCreds: func() error {
|
||||
// Re-run the same resolver the tui uses so we pick up
|
||||
// refreshed oauth tokens, re-logins, and model switches.
|
||||
|
|
@ -385,12 +397,10 @@ func botRun(rawTail []string, version string) error {
|
|||
}
|
||||
agent.Client = next.NewClient()
|
||||
agent.Model = next.Model
|
||||
b.Provider = next.Provider
|
||||
b.AuthMethod = next.AuthMethod
|
||||
b.CWD = next.CWD
|
||||
runner.UpdateRuntimeConfig(next.Provider, next.AuthMethod, next.CWD)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
// Record our pid so `zot telegram-bot status` / `zot telegram-bot stop` can find us,
|
||||
// regardless of whether we were started directly or via `bot start`.
|
||||
|
|
@ -407,7 +417,7 @@ func botRun(rawTail []string, version string) error {
|
|||
cancel()
|
||||
}()
|
||||
defer cancel()
|
||||
return b.Run(ctx)
|
||||
return runner.Run(ctx)
|
||||
}
|
||||
|
||||
// openOrCreateSessionForBot reuses the same logic as interactive mode
|
||||
|
|
@ -444,6 +454,3 @@ func maskToken(tok string) string {
|
|||
}
|
||||
return tok[:i+1] + body[:3] + "..." + body[len(body)-3:]
|
||||
}
|
||||
|
||||
// _ compile-time hint so the strconv import stays if we later add numeric parsing.
|
||||
var _ = strconv.Itoa
|
||||
|
|
|
|||
51
packages/agent/modes/bot/adapter.go
Normal file
51
packages/agent/modes/bot/adapter.go
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
// Package bot provides a protocol-agnostic runner for long-running bot
|
||||
// modes. Concrete transports (Telegram, Discord, …) implement the
|
||||
// BotAdapter interface; the Runner handles turn queueing, agent
|
||||
// prompting, command dispatch, and credential refresh.
|
||||
package bot
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/patriceckhart/zot/packages/provider"
|
||||
)
|
||||
|
||||
// InboundMessage is a protocol-normalised message from a user.
|
||||
type InboundMessage struct {
|
||||
ChannelID string // opaque; adapter owns encoding (e.g. fmt.Sprintf("%d", chatID))
|
||||
MessageID string // optional reply anchor
|
||||
Text string
|
||||
Images []provider.ImageBlock
|
||||
}
|
||||
|
||||
// Command is a built-in bot command that bypasses the agent.
|
||||
type Command int
|
||||
|
||||
const (
|
||||
CmdStart Command = iota // first-time pairing / welcome
|
||||
CmdHelp // usage information
|
||||
CmdStatus // agent/provider state
|
||||
CmdStop // cancel the active turn
|
||||
)
|
||||
|
||||
// BotAdapter is the transport layer a concrete protocol must implement.
|
||||
// The Runner calls these methods; it never touches protocol types directly.
|
||||
type BotAdapter interface {
|
||||
// Run drives inbound polling; calls handler for normal messages and
|
||||
// commandHandler for built-in commands. Blocks until ctx is done.
|
||||
Run(ctx context.Context,
|
||||
handler func(InboundMessage),
|
||||
commandHandler func(Command, InboundMessage),
|
||||
) error
|
||||
|
||||
// Send delivers a reply. The adapter chunks to protocol limits.
|
||||
Send(ctx context.Context, channelID, text string) error
|
||||
|
||||
// IndicateWorking fires a "typing…" signal; returns a stop func.
|
||||
// Return a no-op if the protocol doesn't support it.
|
||||
IndicateWorking(ctx context.Context, channelID string) (stop func())
|
||||
|
||||
// StatusText appends protocol-specific info to /status replies
|
||||
// (e.g. "@botname"). Return "" if there is nothing to add.
|
||||
StatusText() string
|
||||
}
|
||||
11
packages/agent/modes/bot/commands.go
Normal file
11
packages/agent/modes/bot/commands.go
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
package bot
|
||||
|
||||
import "strings"
|
||||
|
||||
// IsStopCommand reports whether text should abort the active turn.
|
||||
// Users often type plain "stop" rather than bot-style "/stop"; keep
|
||||
// this intentionally narrow so normal prompts like "stop doing X"
|
||||
// still go to the agent.
|
||||
func IsStopCommand(text string) bool {
|
||||
return strings.EqualFold(strings.TrimSpace(text), "stop")
|
||||
}
|
||||
254
packages/agent/modes/bot/runner.go
Normal file
254
packages/agent/modes/bot/runner.go
Normal file
|
|
@ -0,0 +1,254 @@
|
|||
package bot
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/patriceckhart/zot/packages/core"
|
||||
"github.com/patriceckhart/zot/packages/provider"
|
||||
)
|
||||
|
||||
// stderr is a tiny hook so tests can redirect bot logging.
|
||||
var stderr = func() io.Writer { return os.Stderr }
|
||||
|
||||
// Config holds runner-level settings that are protocol-independent.
|
||||
type Config struct {
|
||||
ZotHome string
|
||||
Provider string
|
||||
AuthMethod string
|
||||
CWD string
|
||||
RefreshCreds func() error
|
||||
}
|
||||
|
||||
// queuedTurn is an inbound message waiting to become a prompt.
|
||||
type queuedTurn struct {
|
||||
channelID string
|
||||
messageID string
|
||||
prompt string
|
||||
images []provider.ImageBlock
|
||||
}
|
||||
|
||||
// Runner is the protocol-agnostic bot engine. It owns the turn queue,
|
||||
// dispatches prompts to the agent, and streams replies back through
|
||||
// the BotAdapter.
|
||||
type Runner struct {
|
||||
agent *core.Agent
|
||||
adapter BotAdapter
|
||||
cfg Config
|
||||
|
||||
mu sync.Mutex
|
||||
busy bool
|
||||
activeCtx context.CancelFunc
|
||||
queue []queuedTurn
|
||||
lastCtxInput int
|
||||
runCtx context.Context // set at Run entry; used by goroutines
|
||||
}
|
||||
|
||||
// NewRunner creates a Runner wired to the given adapter and agent.
|
||||
func NewRunner(adapter BotAdapter, agent *core.Agent, cfg Config) *Runner {
|
||||
return &Runner{
|
||||
agent: agent,
|
||||
adapter: adapter,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateRuntimeConfig updates provider/auth/cwd at runtime (e.g. after
|
||||
// credential refresh). This is thread-safe.
|
||||
func (r *Runner) UpdateRuntimeConfig(provider, authMethod, cwd string) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.cfg.Provider = provider
|
||||
r.cfg.AuthMethod = authMethod
|
||||
r.cfg.CWD = cwd
|
||||
}
|
||||
|
||||
// Run starts the adapter's polling loop and blocks until ctx cancels.
|
||||
func (r *Runner) Run(ctx context.Context) error {
|
||||
r.mu.Lock()
|
||||
r.runCtx = ctx
|
||||
r.mu.Unlock()
|
||||
|
||||
return r.adapter.Run(ctx, r.handleMessage, r.handleCommand)
|
||||
}
|
||||
|
||||
// handleMessage is called by the adapter for every normal inbound message.
|
||||
func (r *Runner) handleMessage(msg InboundMessage) {
|
||||
if msg.Text == "" && len(msg.Images) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
r.queue = append(r.queue, queuedTurn{
|
||||
channelID: msg.ChannelID,
|
||||
messageID: msg.MessageID,
|
||||
prompt: msg.Text,
|
||||
images: msg.Images,
|
||||
})
|
||||
idle := !r.busy
|
||||
if idle {
|
||||
r.busy = true
|
||||
}
|
||||
r.mu.Unlock()
|
||||
|
||||
if idle {
|
||||
go r.drainQueue()
|
||||
}
|
||||
}
|
||||
|
||||
// handleCommand is called by the adapter for built-in commands.
|
||||
func (r *Runner) handleCommand(cmd Command, msg InboundMessage) {
|
||||
switch cmd {
|
||||
case CmdStart, CmdHelp:
|
||||
_ = r.adapter.Send(context.Background(), msg.ChannelID,
|
||||
"send me any message and i'll forward it to zot. attach an image and i'll pass it to the model. commands: /status, /stop, or plain stop.")
|
||||
case CmdStatus:
|
||||
r.sendStatus(msg.ChannelID)
|
||||
case CmdStop:
|
||||
r.cancelActiveTurn(msg.ChannelID, msg.MessageID)
|
||||
}
|
||||
}
|
||||
|
||||
// drainQueue runs queued turns one at a time until the queue is empty.
|
||||
func (r *Runner) drainQueue() {
|
||||
r.mu.Lock()
|
||||
parent := r.runCtx
|
||||
r.mu.Unlock()
|
||||
|
||||
for {
|
||||
r.mu.Lock()
|
||||
if len(r.queue) == 0 {
|
||||
r.busy = false
|
||||
r.activeCtx = nil
|
||||
r.mu.Unlock()
|
||||
return
|
||||
}
|
||||
t := r.queue[0]
|
||||
r.queue = r.queue[1:]
|
||||
turnCtx, cancel := context.WithCancel(parent)
|
||||
r.activeCtx = cancel
|
||||
r.mu.Unlock()
|
||||
|
||||
if r.cfg.RefreshCreds != nil {
|
||||
if err := r.cfg.RefreshCreds(); err != nil {
|
||||
fmt.Fprintln(stderr(), "bot: refresh creds:", err)
|
||||
}
|
||||
}
|
||||
r.runTurn(turnCtx, t)
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// runTurn sends the queued prompt to the agent and streams the reply.
|
||||
func (r *Runner) runTurn(ctx context.Context, t queuedTurn) {
|
||||
stopWorking := r.adapter.IndicateWorking(ctx, t.channelID)
|
||||
defer stopWorking()
|
||||
|
||||
var replyBuilder strings.Builder
|
||||
var lastAssistantText string
|
||||
var turnErr error
|
||||
|
||||
sink := func(ev core.AgentEvent) {
|
||||
switch e := ev.(type) {
|
||||
case core.EvTextDelta:
|
||||
replyBuilder.WriteString(e.Delta)
|
||||
case core.EvUsage:
|
||||
r.mu.Lock()
|
||||
if e.Usage.InputTokens > 0 {
|
||||
r.lastCtxInput = e.Usage.InputTokens + e.Usage.CacheReadTokens + e.Usage.CacheWriteTokens
|
||||
}
|
||||
r.mu.Unlock()
|
||||
case core.EvAssistantMessage:
|
||||
var sb strings.Builder
|
||||
for _, c := range e.Message.Content {
|
||||
if tb, ok := c.(provider.TextBlock); ok {
|
||||
if sb.Len() > 0 {
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
sb.WriteString(tb.Text)
|
||||
}
|
||||
}
|
||||
if sb.Len() > 0 {
|
||||
lastAssistantText = sb.String()
|
||||
}
|
||||
replyBuilder.Reset()
|
||||
case core.EvTurnEnd:
|
||||
if e.Err != nil {
|
||||
turnErr = e.Err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.agent.Prompt(ctx, t.prompt, t.images, sink); err != nil {
|
||||
turnErr = err
|
||||
}
|
||||
|
||||
reply := strings.TrimSpace(lastAssistantText)
|
||||
if reply == "" {
|
||||
reply = strings.TrimSpace(replyBuilder.String())
|
||||
}
|
||||
if turnErr != nil && ctx.Err() == nil {
|
||||
reply = "error: " + turnErr.Error()
|
||||
}
|
||||
if reply == "" {
|
||||
reply = "(no reply)"
|
||||
}
|
||||
|
||||
// Adapter.Send is responsible for chunking to protocol limits.
|
||||
if err := r.adapter.Send(context.Background(), t.channelID, reply); err != nil {
|
||||
fmt.Fprintln(stderr(), "bot: send reply:", err)
|
||||
}
|
||||
}
|
||||
|
||||
// cancelActiveTurn aborts the currently running turn, if any.
|
||||
func (r *Runner) cancelActiveTurn(channelID, messageID string) {
|
||||
r.mu.Lock()
|
||||
cancel := r.activeCtx
|
||||
r.mu.Unlock()
|
||||
if cancel != nil {
|
||||
cancel()
|
||||
_ = r.adapter.Send(context.Background(), channelID, "cancelled the current turn.")
|
||||
} else {
|
||||
_ = r.adapter.Send(context.Background(), channelID, "nothing running.")
|
||||
}
|
||||
}
|
||||
|
||||
// sendStatus describes agent state to the user.
|
||||
func (r *Runner) sendStatus(channelID string) {
|
||||
r.mu.Lock()
|
||||
busy := r.busy
|
||||
queued := len(r.queue)
|
||||
ctxUsed := r.lastCtxInput
|
||||
providerName := r.cfg.Provider
|
||||
authMethod := r.cfg.AuthMethod
|
||||
cwd := r.cfg.CWD
|
||||
r.mu.Unlock()
|
||||
|
||||
model := r.agent.Model
|
||||
ctxMax := 0
|
||||
if m, err := provider.FindModel(providerName, model); err == nil {
|
||||
ctxMax = m.ContextWindow
|
||||
}
|
||||
|
||||
status := FormatStatus(StatusSnapshot{
|
||||
Provider: providerName,
|
||||
Model: model,
|
||||
CWD: cwd,
|
||||
Usage: r.agent.Cost(),
|
||||
Subscription: authMethod == "oauth",
|
||||
ContextUsed: ctxUsed,
|
||||
ContextMax: ctxMax,
|
||||
Busy: busy,
|
||||
Queued: queued,
|
||||
})
|
||||
|
||||
if extra := r.adapter.StatusText(); extra != "" {
|
||||
status += "\n" + extra
|
||||
}
|
||||
|
||||
_ = r.adapter.Send(context.Background(), channelID, status)
|
||||
}
|
||||
103
packages/agent/modes/bot/runner_test.go
Normal file
103
packages/agent/modes/bot/runner_test.go
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
package bot
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/patriceckhart/zot/packages/core"
|
||||
"github.com/patriceckhart/zot/packages/provider"
|
||||
)
|
||||
|
||||
type testAdapter struct{}
|
||||
|
||||
func (testAdapter) Run(context.Context, func(InboundMessage), func(Command, InboundMessage)) error {
|
||||
return nil
|
||||
}
|
||||
func (testAdapter) Send(context.Context, string, string) error { return nil }
|
||||
func (testAdapter) IndicateWorking(context.Context, string) func() { return func() {} }
|
||||
func (testAdapter) StatusText() string { return "" }
|
||||
|
||||
type blockingClient struct {
|
||||
started chan struct{}
|
||||
release chan struct{}
|
||||
|
||||
mu sync.Mutex
|
||||
active int
|
||||
maxActive int
|
||||
}
|
||||
|
||||
func (c *blockingClient) Name() string { return "test" }
|
||||
|
||||
func (c *blockingClient) Stream(ctx context.Context, req provider.Request) (<-chan provider.Event, error) {
|
||||
c.mu.Lock()
|
||||
c.active++
|
||||
if c.active > c.maxActive {
|
||||
c.maxActive = c.active
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
select {
|
||||
case c.started <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
|
||||
out := make(chan provider.Event, 1)
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
c.mu.Lock()
|
||||
c.active--
|
||||
c.mu.Unlock()
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
out <- provider.EventDone{Stop: provider.StopAborted, Err: ctx.Err()}
|
||||
case <-c.release:
|
||||
out <- provider.EventDone{Stop: provider.StopEnd, Message: provider.Message{
|
||||
Role: provider.RoleAssistant,
|
||||
Content: []provider.Content{provider.TextBlock{Text: "ok"}},
|
||||
}}
|
||||
}
|
||||
}()
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *blockingClient) max() int {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.maxActive
|
||||
}
|
||||
|
||||
func TestHandleMessageClaimsDrainSlotBeforeSpawningDrainer(t *testing.T) {
|
||||
client := &blockingClient{started: make(chan struct{}, 2), release: make(chan struct{})}
|
||||
r := NewRunner(testAdapter{}, core.NewAgent(client, "test-model", "", nil), Config{})
|
||||
r.runCtx = context.Background()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
r.handleMessage(InboundMessage{ChannelID: "c", Text: "one"})
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
r.handleMessage(InboundMessage{ChannelID: "c", Text: "two"})
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
select {
|
||||
case <-client.started:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("first turn did not start")
|
||||
}
|
||||
|
||||
// Give any accidentally spawned second drainer a chance to enter Stream.
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
if got := client.max(); got != 1 {
|
||||
t.Fatalf("concurrent provider streams = %d, want 1", got)
|
||||
}
|
||||
|
||||
close(client.release)
|
||||
}
|
||||
120
packages/agent/modes/bot/status.go
Normal file
120
packages/agent/modes/bot/status.go
Normal file
|
|
@ -0,0 +1,120 @@
|
|||
package bot
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/patriceckhart/zot/packages/provider"
|
||||
)
|
||||
|
||||
// StatusSnapshot is the small cross-host state bundle rendered for
|
||||
// /status replies.
|
||||
type StatusSnapshot struct {
|
||||
Provider string
|
||||
Model string
|
||||
CWD string
|
||||
Usage provider.Usage
|
||||
Subscription bool
|
||||
ContextUsed int
|
||||
ContextMax int
|
||||
Busy bool
|
||||
Queued int
|
||||
}
|
||||
|
||||
// FormatStatus renders the same compact model/usage/cost/context
|
||||
// information shown in the TUI status bar, plus the current directory.
|
||||
func FormatStatus(s StatusSnapshot) string {
|
||||
providerName := strings.TrimSpace(s.Provider)
|
||||
model := strings.TrimSpace(s.Model)
|
||||
if providerName == "" {
|
||||
providerName = "unknown"
|
||||
}
|
||||
if model == "" {
|
||||
model = "unknown"
|
||||
}
|
||||
|
||||
var stats []string
|
||||
if s.Usage.InputTokens > 0 {
|
||||
stats = append(stats, fmt.Sprintf("↑%s", formatTokens(s.Usage.InputTokens)))
|
||||
}
|
||||
if s.Usage.OutputTokens > 0 {
|
||||
stats = append(stats, fmt.Sprintf("↓%s", formatTokens(s.Usage.OutputTokens)))
|
||||
}
|
||||
if s.Usage.CacheReadTokens > 0 {
|
||||
stats = append(stats, fmt.Sprintf("R%s", formatTokens(s.Usage.CacheReadTokens)))
|
||||
}
|
||||
if s.Usage.CacheWriteTokens > 0 {
|
||||
stats = append(stats, fmt.Sprintf("W%s", formatTokens(s.Usage.CacheWriteTokens)))
|
||||
}
|
||||
if s.Usage.CostUSD > 0 || s.Subscription {
|
||||
cost := fmt.Sprintf("$%.3f", s.Usage.CostUSD)
|
||||
if s.Subscription {
|
||||
cost += " (sub)"
|
||||
}
|
||||
stats = append(stats, cost)
|
||||
}
|
||||
if ctx := contextUsage(s.ContextUsed, s.ContextMax); ctx != "" {
|
||||
stats = append(stats, ctx)
|
||||
}
|
||||
|
||||
line := fmt.Sprintf("(%s) %s", providerName, model)
|
||||
if len(stats) > 0 {
|
||||
line += " " + strings.Join(stats, " ")
|
||||
}
|
||||
|
||||
state := "idle"
|
||||
if s.Busy {
|
||||
state = "working"
|
||||
}
|
||||
lines := []string{line, "state: " + state}
|
||||
if s.Queued > 0 {
|
||||
lines = append(lines, fmt.Sprintf("queued: %d", s.Queued))
|
||||
}
|
||||
if cwd := shortenHome(strings.TrimSpace(s.CWD)); cwd != "" {
|
||||
lines = append(lines, "cwd: "+cwd)
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
func contextUsage(used, max int) string {
|
||||
if max <= 0 {
|
||||
if used <= 0 {
|
||||
return ""
|
||||
}
|
||||
return formatTokens(used)
|
||||
}
|
||||
pct := float64(used) / float64(max) * 100
|
||||
return fmt.Sprintf("%.1f%%/%s", pct, formatTokens(max))
|
||||
}
|
||||
|
||||
func formatTokens(n int) string {
|
||||
switch {
|
||||
case n < 0:
|
||||
return "0"
|
||||
case n < 1000:
|
||||
return fmt.Sprintf("%d", n)
|
||||
case n < 10000:
|
||||
return fmt.Sprintf("%.1fk", float64(n)/1000)
|
||||
case n < 1_000_000:
|
||||
return fmt.Sprintf("%dk", (n+500)/1000)
|
||||
case n < 10_000_000:
|
||||
return fmt.Sprintf("%.1fM", float64(n)/1_000_000)
|
||||
default:
|
||||
return fmt.Sprintf("%dM", (n+500_000)/1_000_000)
|
||||
}
|
||||
}
|
||||
|
||||
func shortenHome(path string) string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil || home == "" {
|
||||
return path
|
||||
}
|
||||
if path == home {
|
||||
return "~"
|
||||
}
|
||||
if strings.HasPrefix(path, home+string(os.PathSeparator)) {
|
||||
return "~" + strings.TrimPrefix(path, home)
|
||||
}
|
||||
return path
|
||||
}
|
||||
250
packages/agent/modes/telegram/adapter.go
Normal file
250
packages/agent/modes/telegram/adapter.go
Normal file
|
|
@ -0,0 +1,250 @@
|
|||
package telegram
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/patriceckhart/zot/packages/agent/modes/bot"
|
||||
"github.com/patriceckhart/zot/packages/provider"
|
||||
)
|
||||
|
||||
// Adapter implements bot.BotAdapter for Telegram.
|
||||
type Adapter struct {
|
||||
Client *Client
|
||||
Cfg *Config // pointer so Run can mutate and persist
|
||||
Save func(Config) error
|
||||
}
|
||||
|
||||
// NewAdapter creates a Telegram adapter.
|
||||
func NewAdapter(client *Client, cfg *Config, save func(Config) error) *Adapter {
|
||||
return &Adapter{Client: client, Cfg: cfg, Save: save}
|
||||
}
|
||||
|
||||
// Run drives the Telegram long-polling loop. It performs initial
|
||||
// GetMe, handles pairing, and dispatches inbound messages to the
|
||||
// generic handler / commandHandler callbacks.
|
||||
func (a *Adapter) Run(ctx context.Context,
|
||||
handler func(bot.InboundMessage),
|
||||
commandHandler func(bot.Command, bot.InboundMessage),
|
||||
) error {
|
||||
if a.Cfg.BotToken == "" {
|
||||
return fmt.Errorf("no bot token configured; run `zot bot setup` first")
|
||||
}
|
||||
me, err := a.Client.GetMe(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getMe: %w", err)
|
||||
}
|
||||
// Keep the stored username/id in sync with the actual bot.
|
||||
if a.Cfg.BotID != me.ID || a.Cfg.BotUsername != me.Username {
|
||||
a.Cfg.BotID = me.ID
|
||||
a.Cfg.BotUsername = me.Username
|
||||
_ = a.Save(*a.Cfg)
|
||||
}
|
||||
|
||||
fmt.Printf("telegram bridge online as @%s (id=%d)\n", me.Username, me.ID)
|
||||
if a.Cfg.AllowedUserID == 0 {
|
||||
fmt.Println("no user paired yet — send /start to the bot from Telegram to claim it")
|
||||
} else {
|
||||
fmt.Printf("paired with telegram user id %d\n", a.Cfg.AllowedUserID)
|
||||
}
|
||||
|
||||
return a.pollLoop(ctx, handler, commandHandler)
|
||||
}
|
||||
|
||||
// pollLoop long-polls Telegram for updates and dispatches them.
|
||||
func (a *Adapter) pollLoop(ctx context.Context,
|
||||
handler func(bot.InboundMessage),
|
||||
commandHandler func(bot.Command, bot.InboundMessage),
|
||||
) error {
|
||||
backoff := time.Second
|
||||
for {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
updates, err := a.Client.GetUpdates(ctx, a.Cfg.LastUpdateID+1, 30)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
fmt.Fprintln(stderr(), "telegram: getUpdates error:", err)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
if backoff < 30*time.Second {
|
||||
backoff *= 2
|
||||
}
|
||||
continue
|
||||
}
|
||||
backoff = time.Second
|
||||
for _, u := range updates {
|
||||
a.handleUpdate(ctx, u, handler, commandHandler)
|
||||
a.Cfg.LastUpdateID = u.UpdateID
|
||||
_ = a.Save(*a.Cfg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleUpdate processes a single Telegram update. Telegram-specific
|
||||
// concerns (pairing, user filtering, image download) live here; the
|
||||
// generic callbacks are called for normal messages and commands.
|
||||
func (a *Adapter) handleUpdate(ctx context.Context, u Update,
|
||||
handler func(bot.InboundMessage),
|
||||
commandHandler func(bot.Command, bot.InboundMessage),
|
||||
) {
|
||||
msg := u.Message
|
||||
if msg == nil {
|
||||
msg = u.Edited
|
||||
}
|
||||
if msg == nil || msg.From == nil || msg.From.IsBot {
|
||||
return
|
||||
}
|
||||
if msg.Chat.Type != "private" {
|
||||
return
|
||||
}
|
||||
|
||||
chanID := fmt.Sprintf("%d", msg.Chat.ID)
|
||||
msgID := fmt.Sprintf("%d", msg.MessageID)
|
||||
|
||||
// Pairing: first user who sends /start claims the bridge.
|
||||
text := strings.TrimSpace(msg.Text)
|
||||
if a.Cfg.AllowedUserID == 0 {
|
||||
if strings.HasPrefix(text, "/start") {
|
||||
a.Cfg.AllowedUserID = msg.From.ID
|
||||
_ = a.Save(*a.Cfg)
|
||||
_ = a.Client.SendMessage(ctx, msg.Chat.ID,
|
||||
fmt.Sprintf("paired with @%s. send any message and i'll forward it to zot.", msg.From.Username),
|
||||
msg.MessageID)
|
||||
return
|
||||
}
|
||||
_ = a.Client.SendMessage(ctx, msg.Chat.ID,
|
||||
"this bot isn't paired yet. send /start to claim it.",
|
||||
msg.MessageID)
|
||||
return
|
||||
}
|
||||
|
||||
// Enforce allowed user.
|
||||
if msg.From.ID != a.Cfg.AllowedUserID {
|
||||
_ = a.Client.SendMessage(ctx, msg.Chat.ID,
|
||||
"this bot is paired with a different user.",
|
||||
msg.MessageID)
|
||||
return
|
||||
}
|
||||
|
||||
inbound := bot.InboundMessage{
|
||||
ChannelID: chanID,
|
||||
MessageID: msgID,
|
||||
}
|
||||
|
||||
// Built-in commands that bypass the agent.
|
||||
switch text {
|
||||
case "/start":
|
||||
commandHandler(bot.CmdStart, inbound)
|
||||
return
|
||||
case "/help":
|
||||
commandHandler(bot.CmdHelp, inbound)
|
||||
return
|
||||
case "/status":
|
||||
commandHandler(bot.CmdStatus, inbound)
|
||||
return
|
||||
case "/stop":
|
||||
commandHandler(bot.CmdStop, inbound)
|
||||
return
|
||||
}
|
||||
if bot.IsStopCommand(text) {
|
||||
commandHandler(bot.CmdStop, inbound)
|
||||
return
|
||||
}
|
||||
|
||||
// Build the prompt: combine text + caption; download image attachments.
|
||||
prompt := strings.TrimSpace(msg.Text)
|
||||
if msg.Caption != "" {
|
||||
if prompt != "" {
|
||||
prompt += "\n"
|
||||
}
|
||||
prompt += msg.Caption
|
||||
}
|
||||
|
||||
var images []provider.ImageBlock
|
||||
if len(msg.Photo) > 0 {
|
||||
largest := msg.Photo[len(msg.Photo)-1]
|
||||
if data, mime, err := a.download(ctx, largest.FileID, ""); err == nil {
|
||||
images = append(images, provider.ImageBlock{MimeType: mime, Data: data})
|
||||
} else {
|
||||
fmt.Fprintln(stderr(), "telegram: download photo:", err)
|
||||
}
|
||||
}
|
||||
if msg.Document != nil && isImageMIME(msg.Document.MimeType) {
|
||||
if data, mime, err := a.download(ctx, msg.Document.FileID, msg.Document.MimeType); err == nil {
|
||||
images = append(images, provider.ImageBlock{MimeType: mime, Data: data})
|
||||
}
|
||||
}
|
||||
|
||||
inbound.Text = prompt
|
||||
inbound.Images = images
|
||||
handler(inbound)
|
||||
}
|
||||
|
||||
// Send delivers a reply to a Telegram chat. channelID is parsed back
|
||||
// to int64. Messages are chunked to 4000 runes (Telegram limit 4096).
|
||||
func (a *Adapter) Send(ctx context.Context, channelID, text string) error {
|
||||
chatID, err := strconv.ParseInt(channelID, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid channelID %q: %w", channelID, err)
|
||||
}
|
||||
for _, chunk := range chunkMessage(text, 4000) {
|
||||
if err := a.Client.SendMessage(ctx, chatID, chunk, 0); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IndicateWorking keeps Telegram's "typing..." indicator alive until
|
||||
// the returned stop function is called.
|
||||
func (a *Adapter) IndicateWorking(ctx context.Context, channelID string) (stop func()) {
|
||||
chatID, err := strconv.ParseInt(channelID, 10, 64)
|
||||
if err != nil {
|
||||
return func() {}
|
||||
}
|
||||
tctx, cancel := context.WithCancel(ctx)
|
||||
go func() {
|
||||
for {
|
||||
_ = a.Client.SendChatAction(tctx, chatID, "typing")
|
||||
select {
|
||||
case <-tctx.Done():
|
||||
return
|
||||
case <-time.After(4 * time.Second):
|
||||
}
|
||||
}
|
||||
}()
|
||||
return cancel
|
||||
}
|
||||
|
||||
// StatusText returns the bot's @username for inclusion in /status.
|
||||
func (a *Adapter) StatusText() string {
|
||||
if a.Cfg.BotUsername != "" {
|
||||
return "@" + a.Cfg.BotUsername
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// download fetches a file from Telegram and returns bytes + mime.
|
||||
func (a *Adapter) download(ctx context.Context, fileID, mime string) ([]byte, string, error) {
|
||||
f, err := a.Client.GetFile(ctx, fileID)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
data, err := a.Client.DownloadFile(ctx, f.FilePath)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if mime == "" {
|
||||
mime = guessImageMIME(f.FilePath)
|
||||
}
|
||||
return data, mime, nil
|
||||
}
|
||||
|
|
@ -1,408 +1,68 @@
|
|||
package telegram
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/patriceckhart/zot/packages/core"
|
||||
"github.com/patriceckhart/zot/packages/provider"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// Bot owns the Telegram polling loop and dispatches inbound DMs to
|
||||
// the agent. It is a long-running goroutine; Run blocks until ctx
|
||||
// cancels.
|
||||
type Bot struct {
|
||||
Client *Client
|
||||
Agent *core.Agent
|
||||
Config Config
|
||||
ZotHome string
|
||||
Provider string
|
||||
AuthMethod string
|
||||
CWD string
|
||||
// Save persists cfg to bot.json. Called whenever the bot pairs
|
||||
// with a new allowed user or advances LastUpdateID.
|
||||
Save func(Config) error
|
||||
// RefreshCreds is called before every turn to pick up newly
|
||||
// refreshed OAuth tokens. Optional; when nil, the bot uses the
|
||||
// credential it was built with. Implementations typically call
|
||||
// agent.ResolveCredentialFull which auto-refreshes expired tokens.
|
||||
RefreshCreds func() error
|
||||
|
||||
mu sync.Mutex
|
||||
busy bool
|
||||
activeCtx context.CancelFunc
|
||||
queue []queuedTurn
|
||||
lastCtxInput int
|
||||
}
|
||||
|
||||
// queuedTurn is an inbound DM waiting to become a prompt.
|
||||
type queuedTurn struct {
|
||||
chatID int64
|
||||
messageID int
|
||||
prompt string
|
||||
images []provider.ImageBlock
|
||||
}
|
||||
|
||||
// Run drives the bot. Returns when ctx is cancelled or GetMe fails.
|
||||
func (b *Bot) Run(ctx context.Context) error {
|
||||
if b.Config.BotToken == "" {
|
||||
return fmt.Errorf("no bot token configured; run `zot bot setup` first")
|
||||
}
|
||||
me, err := b.Client.GetMe(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getMe: %w", err)
|
||||
}
|
||||
// Keep the stored username/id in sync with the actual bot.
|
||||
if b.Config.BotID != me.ID || b.Config.BotUsername != me.Username {
|
||||
b.Config.BotID = me.ID
|
||||
b.Config.BotUsername = me.Username
|
||||
_ = b.Save(b.Config)
|
||||
}
|
||||
|
||||
fmt.Printf("telegram bridge online as @%s (id=%d)\n", me.Username, me.ID)
|
||||
if b.Config.AllowedUserID == 0 {
|
||||
fmt.Println("no user paired yet — send /start to the bot from Telegram to claim it")
|
||||
} else {
|
||||
fmt.Printf("paired with telegram user id %d\n", b.Config.AllowedUserID)
|
||||
}
|
||||
|
||||
return b.pollLoop(ctx)
|
||||
}
|
||||
|
||||
// pollLoop long-polls Telegram for updates and dispatches them.
|
||||
func (b *Bot) pollLoop(ctx context.Context) error {
|
||||
backoff := time.Second
|
||||
for {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
updates, err := b.Client.GetUpdates(ctx, b.Config.LastUpdateID+1, 30)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
fmt.Fprintln(stderr(), "telegram: getUpdates error:", err)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
if backoff < 30*time.Second {
|
||||
backoff *= 2
|
||||
}
|
||||
continue
|
||||
}
|
||||
backoff = time.Second
|
||||
for _, u := range updates {
|
||||
if err := b.handleUpdate(ctx, u); err != nil {
|
||||
fmt.Fprintln(stderr(), "telegram: handleUpdate:", err)
|
||||
}
|
||||
b.Config.LastUpdateID = u.UpdateID
|
||||
_ = b.Save(b.Config)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleUpdate processes a single Telegram update.
|
||||
func (b *Bot) handleUpdate(ctx context.Context, u Update) error {
|
||||
msg := u.Message
|
||||
if msg == nil {
|
||||
msg = u.Edited
|
||||
}
|
||||
if msg == nil || msg.From == nil || msg.From.IsBot {
|
||||
return nil
|
||||
}
|
||||
if msg.Chat.Type != "private" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Pairing: first user who sends /start claims the bridge.
|
||||
text := strings.TrimSpace(msg.Text)
|
||||
if b.Config.AllowedUserID == 0 {
|
||||
if strings.HasPrefix(text, "/start") {
|
||||
b.Config.AllowedUserID = msg.From.ID
|
||||
_ = b.Save(b.Config)
|
||||
_ = b.Client.SendMessage(ctx, msg.Chat.ID,
|
||||
fmt.Sprintf("paired with @%s. send any message and i'll forward it to zot.", msg.From.Username),
|
||||
msg.MessageID)
|
||||
return nil
|
||||
}
|
||||
_ = b.Client.SendMessage(ctx, msg.Chat.ID,
|
||||
"this bot isn't paired yet. send /start to claim it.",
|
||||
msg.MessageID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Enforce allowed user.
|
||||
if msg.From.ID != b.Config.AllowedUserID {
|
||||
_ = b.Client.SendMessage(ctx, msg.Chat.ID,
|
||||
"this bot is paired with a different user.",
|
||||
msg.MessageID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Built-in commands that bypass the agent.
|
||||
switch text {
|
||||
case "/start", "/help":
|
||||
_ = b.Client.SendMessage(ctx, msg.Chat.ID,
|
||||
"send me any message and i'll forward it to zot. attach an image and i'll pass it to the model. commands: /status, /stop, or plain stop.",
|
||||
msg.MessageID)
|
||||
return nil
|
||||
case "/status":
|
||||
return b.sendStatus(ctx, msg.Chat.ID, msg.MessageID)
|
||||
case "/stop":
|
||||
b.cancelActiveTurn(ctx, msg.Chat.ID, msg.MessageID)
|
||||
return nil
|
||||
}
|
||||
if isStopCommand(text) {
|
||||
b.cancelActiveTurn(ctx, msg.Chat.ID, msg.MessageID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build the prompt: combine text + caption; download image attachments.
|
||||
prompt := strings.TrimSpace(msg.Text)
|
||||
if msg.Caption != "" {
|
||||
if prompt != "" {
|
||||
prompt += "\n"
|
||||
}
|
||||
prompt += msg.Caption
|
||||
}
|
||||
|
||||
var images []provider.ImageBlock
|
||||
if len(msg.Photo) > 0 {
|
||||
// Photos arrive in multiple sizes; take the largest (last in the slice).
|
||||
largest := msg.Photo[len(msg.Photo)-1]
|
||||
if data, mime, err := b.download(ctx, largest.FileID, ""); err == nil {
|
||||
images = append(images, provider.ImageBlock{MimeType: mime, Data: data})
|
||||
} else {
|
||||
fmt.Fprintln(stderr(), "telegram: download photo:", err)
|
||||
}
|
||||
}
|
||||
if msg.Document != nil && isImageMIME(msg.Document.MimeType) {
|
||||
if data, mime, err := b.download(ctx, msg.Document.FileID, msg.Document.MimeType); err == nil {
|
||||
images = append(images, provider.ImageBlock{MimeType: mime, Data: data})
|
||||
}
|
||||
}
|
||||
|
||||
if prompt == "" && len(images) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
b.mu.Lock()
|
||||
b.queue = append(b.queue, queuedTurn{
|
||||
chatID: msg.Chat.ID,
|
||||
messageID: msg.MessageID,
|
||||
prompt: prompt,
|
||||
images: images,
|
||||
})
|
||||
idle := !b.busy
|
||||
b.mu.Unlock()
|
||||
|
||||
if idle {
|
||||
go b.drainQueue(ctx)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// drainQueue runs queued turns one at a time until the queue is empty.
|
||||
func (b *Bot) drainQueue(parent context.Context) {
|
||||
for {
|
||||
b.mu.Lock()
|
||||
if len(b.queue) == 0 {
|
||||
b.busy = false
|
||||
b.activeCtx = nil
|
||||
b.mu.Unlock()
|
||||
return
|
||||
}
|
||||
t := b.queue[0]
|
||||
b.queue = b.queue[1:]
|
||||
b.busy = true
|
||||
turnCtx, cancel := context.WithCancel(parent)
|
||||
b.activeCtx = cancel
|
||||
b.mu.Unlock()
|
||||
|
||||
if b.RefreshCreds != nil {
|
||||
if err := b.RefreshCreds(); err != nil {
|
||||
fmt.Fprintln(stderr(), "telegram: refresh creds:", err)
|
||||
}
|
||||
}
|
||||
b.runTurn(turnCtx, t)
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// runTurn sends the queued prompt to the agent and streams the reply.
|
||||
func (b *Bot) runTurn(ctx context.Context, t queuedTurn) {
|
||||
stopTyping := b.startTyping(ctx, t.chatID)
|
||||
defer stopTyping()
|
||||
|
||||
var replyBuilder strings.Builder
|
||||
var lastAssistantText string
|
||||
var turnErr error
|
||||
|
||||
sink := func(ev core.AgentEvent) {
|
||||
switch e := ev.(type) {
|
||||
case core.EvTextDelta:
|
||||
replyBuilder.WriteString(e.Delta)
|
||||
case core.EvUsage:
|
||||
b.mu.Lock()
|
||||
if e.Usage.InputTokens > 0 {
|
||||
b.lastCtxInput = e.Usage.InputTokens + e.Usage.CacheReadTokens + e.Usage.CacheWriteTokens
|
||||
}
|
||||
b.mu.Unlock()
|
||||
case core.EvAssistantMessage:
|
||||
var sb strings.Builder
|
||||
for _, c := range e.Message.Content {
|
||||
if tb, ok := c.(provider.TextBlock); ok {
|
||||
if sb.Len() > 0 {
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
sb.WriteString(tb.Text)
|
||||
}
|
||||
}
|
||||
if sb.Len() > 0 {
|
||||
lastAssistantText = sb.String()
|
||||
}
|
||||
replyBuilder.Reset()
|
||||
case core.EvTurnEnd:
|
||||
if e.Err != nil {
|
||||
turnErr = e.Err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := b.Agent.Prompt(ctx, t.prompt, t.images, sink); err != nil {
|
||||
turnErr = err
|
||||
}
|
||||
|
||||
reply := strings.TrimSpace(lastAssistantText)
|
||||
if reply == "" {
|
||||
reply = strings.TrimSpace(replyBuilder.String())
|
||||
}
|
||||
if turnErr != nil && ctx.Err() == nil {
|
||||
reply = "error: " + turnErr.Error()
|
||||
}
|
||||
if reply == "" {
|
||||
reply = "(no reply)"
|
||||
}
|
||||
// Telegram caps messages at 4096 chars. Chunk to be safe.
|
||||
for _, chunk := range chunkMessage(reply, 4000) {
|
||||
if err := b.Client.SendMessage(context.Background(), t.chatID, chunk, 0); err != nil {
|
||||
fmt.Fprintln(stderr(), "telegram: sendMessage:", err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// startTyping keeps Telegram's "typing..." indicator alive until the
|
||||
// returned stop function is called.
|
||||
func (b *Bot) startTyping(ctx context.Context, chatID int64) func() {
|
||||
tctx, cancel := context.WithCancel(ctx)
|
||||
go func() {
|
||||
for {
|
||||
_ = b.Client.SendChatAction(tctx, chatID, "typing")
|
||||
select {
|
||||
case <-tctx.Done():
|
||||
return
|
||||
case <-time.After(4 * time.Second):
|
||||
}
|
||||
}
|
||||
}()
|
||||
return cancel
|
||||
}
|
||||
|
||||
func (b *Bot) cancelActiveTurn(ctx context.Context, chatID int64, replyTo int) {
|
||||
b.mu.Lock()
|
||||
cancel := b.activeCtx
|
||||
b.mu.Unlock()
|
||||
if cancel != nil {
|
||||
cancel()
|
||||
_ = b.Client.SendMessage(ctx, chatID, "cancelled the current turn.", replyTo)
|
||||
} else {
|
||||
_ = b.Client.SendMessage(ctx, chatID, "nothing running.", replyTo)
|
||||
}
|
||||
}
|
||||
|
||||
// sendStatus describes agent state to the Telegram user.
|
||||
func (b *Bot) sendStatus(ctx context.Context, chatID int64, replyTo int) error {
|
||||
b.mu.Lock()
|
||||
busy := b.busy
|
||||
queued := len(b.queue)
|
||||
ctxUsed := b.lastCtxInput
|
||||
providerName := b.Provider
|
||||
authMethod := b.AuthMethod
|
||||
cwd := b.CWD
|
||||
b.mu.Unlock()
|
||||
|
||||
model := b.Agent.Model
|
||||
ctxMax := 0
|
||||
if m, err := provider.FindModel(providerName, model); err == nil {
|
||||
ctxMax = m.ContextWindow
|
||||
}
|
||||
return b.Client.SendMessage(ctx, chatID, FormatStatus(StatusSnapshot{
|
||||
Provider: providerName,
|
||||
Model: model,
|
||||
CWD: cwd,
|
||||
Usage: b.Agent.Cost(),
|
||||
Subscription: authMethod == "oauth",
|
||||
ContextUsed: ctxUsed,
|
||||
ContextMax: ctxMax,
|
||||
Busy: busy,
|
||||
Queued: queued,
|
||||
}), replyTo)
|
||||
}
|
||||
|
||||
// download fetches a file from Telegram and returns bytes + mime.
|
||||
func (b *Bot) download(ctx context.Context, fileID, mime string) ([]byte, string, error) {
|
||||
f, err := b.Client.GetFile(ctx, fileID)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
data, err := b.Client.DownloadFile(ctx, f.FilePath)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if mime == "" {
|
||||
mime = guessImageMIME(f.FilePath)
|
||||
}
|
||||
return data, mime, nil
|
||||
}
|
||||
|
||||
// chunkMessage splits s into chunks no larger than limit runes, on line
|
||||
// boundaries when possible.
|
||||
func chunkMessage(s string, limit int) []string {
|
||||
if len(s) <= limit {
|
||||
if limit <= 0 || utf8.RuneCountInString(s) <= limit {
|
||||
return []string{s}
|
||||
}
|
||||
var out []string
|
||||
lines := strings.Split(s, "\n")
|
||||
var cur strings.Builder
|
||||
curRunes := 0
|
||||
for _, l := range lines {
|
||||
if cur.Len()+len(l)+1 > limit && cur.Len() > 0 {
|
||||
lineRunes := utf8.RuneCountInString(l)
|
||||
sepRunes := 0
|
||||
if curRunes > 0 {
|
||||
sepRunes = 1
|
||||
}
|
||||
if curRunes+sepRunes+lineRunes > limit && curRunes > 0 {
|
||||
out = append(out, cur.String())
|
||||
cur.Reset()
|
||||
curRunes = 0
|
||||
sepRunes = 0
|
||||
}
|
||||
if len(l) > limit {
|
||||
// Line itself too long; hard-split.
|
||||
for len(l) > limit {
|
||||
out = append(out, l[:limit])
|
||||
l = l[limit:]
|
||||
if lineRunes > limit {
|
||||
// Line itself too long; hard-split on rune boundaries.
|
||||
for lineRunes > limit {
|
||||
i := byteIndexAfterRunes(l, limit)
|
||||
out = append(out, l[:i])
|
||||
l = l[i:]
|
||||
lineRunes = utf8.RuneCountInString(l)
|
||||
}
|
||||
}
|
||||
if cur.Len() > 0 {
|
||||
if curRunes > 0 {
|
||||
cur.WriteString("\n")
|
||||
curRunes++
|
||||
}
|
||||
cur.WriteString(l)
|
||||
curRunes += utf8.RuneCountInString(l)
|
||||
}
|
||||
if cur.Len() > 0 {
|
||||
if curRunes > 0 {
|
||||
out = append(out, cur.String())
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func byteIndexAfterRunes(s string, n int) int {
|
||||
if n <= 0 {
|
||||
return 0
|
||||
}
|
||||
count := 0
|
||||
for i := range s {
|
||||
if count == n {
|
||||
return i
|
||||
}
|
||||
count++
|
||||
}
|
||||
return len(s)
|
||||
}
|
||||
|
||||
// isImageMIME returns true for MIME types the model can probably ingest
|
||||
// as a vision input.
|
||||
func isImageMIME(m string) bool {
|
||||
|
|
|
|||
31
packages/agent/modes/telegram/bot_test.go
Normal file
31
packages/agent/modes/telegram/bot_test.go
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
package telegram
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
func TestChunkMessageUsesRuneLimit(t *testing.T) {
|
||||
s := strings.Repeat("界", 4001)
|
||||
chunks := chunkMessage(s, 4000)
|
||||
if len(chunks) != 2 {
|
||||
t.Fatalf("len(chunks) = %d, want 2", len(chunks))
|
||||
}
|
||||
if got := utf8.RuneCountInString(chunks[0]); got != 4000 {
|
||||
t.Fatalf("first chunk runes = %d, want 4000", got)
|
||||
}
|
||||
if got := utf8.RuneCountInString(chunks[1]); got != 1 {
|
||||
t.Fatalf("second chunk runes = %d, want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChunkMessageDoesNotSplitMultiByteRune(t *testing.T) {
|
||||
chunks := chunkMessage("🙂🙂🙂", 2)
|
||||
if len(chunks) != 2 {
|
||||
t.Fatalf("len(chunks) = %d, want 2", len(chunks))
|
||||
}
|
||||
if chunks[0] != "🙂🙂" || chunks[1] != "🙂" {
|
||||
t.Fatalf("chunks = %#v", chunks)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,11 +1,6 @@
|
|||
package telegram
|
||||
|
||||
import "strings"
|
||||
import "github.com/patriceckhart/zot/packages/agent/modes/bot"
|
||||
|
||||
// isStopCommand reports whether text should abort the active turn.
|
||||
// Telegram users often type plain "stop" rather than bot-style
|
||||
// "/stop"; keep this intentionally narrow so normal prompts like
|
||||
// "stop doing X" still go to the agent.
|
||||
func isStopCommand(text string) bool {
|
||||
return strings.EqualFold(strings.TrimSpace(text), "stop")
|
||||
}
|
||||
// isStopCommand is a shim to bot.IsStopCommand for backward compatibility.
|
||||
var isStopCommand = bot.IsStopCommand
|
||||
|
|
|
|||
|
|
@ -1,120 +1,9 @@
|
|||
package telegram
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
import "github.com/patriceckhart/zot/packages/agent/modes/bot"
|
||||
|
||||
"github.com/patriceckhart/zot/packages/provider"
|
||||
)
|
||||
// StatusSnapshot is an alias for bot.StatusSnapshot for backward compatibility.
|
||||
type StatusSnapshot = bot.StatusSnapshot
|
||||
|
||||
// StatusSnapshot is the small cross-host state bundle rendered for
|
||||
// Telegram /status replies.
|
||||
type StatusSnapshot struct {
|
||||
Provider string
|
||||
Model string
|
||||
CWD string
|
||||
Usage provider.Usage
|
||||
Subscription bool
|
||||
ContextUsed int
|
||||
ContextMax int
|
||||
Busy bool
|
||||
Queued int
|
||||
}
|
||||
|
||||
// FormatStatus renders the same compact model/usage/cost/context
|
||||
// information shown in the TUI status bar, plus the current directory.
|
||||
func FormatStatus(s StatusSnapshot) string {
|
||||
providerName := strings.TrimSpace(s.Provider)
|
||||
model := strings.TrimSpace(s.Model)
|
||||
if providerName == "" {
|
||||
providerName = "unknown"
|
||||
}
|
||||
if model == "" {
|
||||
model = "unknown"
|
||||
}
|
||||
|
||||
var stats []string
|
||||
if s.Usage.InputTokens > 0 {
|
||||
stats = append(stats, fmt.Sprintf("↑%s", formatTokens(s.Usage.InputTokens)))
|
||||
}
|
||||
if s.Usage.OutputTokens > 0 {
|
||||
stats = append(stats, fmt.Sprintf("↓%s", formatTokens(s.Usage.OutputTokens)))
|
||||
}
|
||||
if s.Usage.CacheReadTokens > 0 {
|
||||
stats = append(stats, fmt.Sprintf("R%s", formatTokens(s.Usage.CacheReadTokens)))
|
||||
}
|
||||
if s.Usage.CacheWriteTokens > 0 {
|
||||
stats = append(stats, fmt.Sprintf("W%s", formatTokens(s.Usage.CacheWriteTokens)))
|
||||
}
|
||||
if s.Usage.CostUSD > 0 || s.Subscription {
|
||||
cost := fmt.Sprintf("$%.3f", s.Usage.CostUSD)
|
||||
if s.Subscription {
|
||||
cost += " (sub)"
|
||||
}
|
||||
stats = append(stats, cost)
|
||||
}
|
||||
if ctx := contextUsage(s.ContextUsed, s.ContextMax); ctx != "" {
|
||||
stats = append(stats, ctx)
|
||||
}
|
||||
|
||||
line := fmt.Sprintf("(%s) %s", providerName, model)
|
||||
if len(stats) > 0 {
|
||||
line += " " + strings.Join(stats, " ")
|
||||
}
|
||||
|
||||
state := "idle"
|
||||
if s.Busy {
|
||||
state = "working"
|
||||
}
|
||||
lines := []string{line, "state: " + state}
|
||||
if s.Queued > 0 {
|
||||
lines = append(lines, fmt.Sprintf("queued: %d", s.Queued))
|
||||
}
|
||||
if cwd := shortenHome(strings.TrimSpace(s.CWD)); cwd != "" {
|
||||
lines = append(lines, "cwd: "+cwd)
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
func contextUsage(used, max int) string {
|
||||
if max <= 0 {
|
||||
if used <= 0 {
|
||||
return ""
|
||||
}
|
||||
return formatTokens(used)
|
||||
}
|
||||
pct := float64(used) / float64(max) * 100
|
||||
return fmt.Sprintf("%.1f%%/%s", pct, formatTokens(max))
|
||||
}
|
||||
|
||||
func formatTokens(n int) string {
|
||||
switch {
|
||||
case n < 0:
|
||||
return "0"
|
||||
case n < 1000:
|
||||
return fmt.Sprintf("%d", n)
|
||||
case n < 10000:
|
||||
return fmt.Sprintf("%.1fk", float64(n)/1000)
|
||||
case n < 1_000_000:
|
||||
return fmt.Sprintf("%dk", (n+500)/1000)
|
||||
case n < 10_000_000:
|
||||
return fmt.Sprintf("%.1fM", float64(n)/1_000_000)
|
||||
default:
|
||||
return fmt.Sprintf("%dM", (n+500_000)/1_000_000)
|
||||
}
|
||||
}
|
||||
|
||||
func shortenHome(path string) string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil || home == "" {
|
||||
return path
|
||||
}
|
||||
if path == home {
|
||||
return "~"
|
||||
}
|
||||
if strings.HasPrefix(path, home+string(os.PathSeparator)) {
|
||||
return "~" + strings.TrimPrefix(path, home)
|
||||
}
|
||||
return path
|
||||
}
|
||||
// FormatStatus is an alias for bot.FormatStatus for backward compatibility.
|
||||
var FormatStatus = bot.FormatStatus
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue