diff --git a/internal/agent/build.go b/internal/agent/build.go index 26a51cf..253a9ca 100644 --- a/internal/agent/build.go +++ b/internal/agent/build.go @@ -1,6 +1,7 @@ package agent import ( + "context" "fmt" "os" "path/filepath" @@ -337,17 +338,47 @@ func (r Resolved) NewClient() provider.Client { return provider.NewOpenAI(r.Credential, r.BaseURL) case "openai": if r.AuthMethod == "oauth" { - return provider.NewOpenAICodex(r.Credential, r.AccountID, r.BaseURL) + inner := provider.NewOpenAICodex(r.Credential, r.AccountID, r.BaseURL) + return r.wrapWithRefresh(inner) } return provider.NewOpenAI(r.Credential, r.BaseURL) default: if r.AuthMethod == "oauth" { - return provider.NewAnthropicOAuth(r.Credential, r.BaseURL) + inner := provider.NewAnthropicOAuth(r.Credential, r.BaseURL) + return r.wrapWithRefresh(inner) } return provider.NewAnthropic(r.Credential, r.BaseURL) } } +// wrapWithRefresh wraps an OAuth client so the access token is +// refreshed automatically before each API call. Without this, long +// sessions (hours) silently fail when the 1-hour token expires. +func (r Resolved) wrapWithRefresh(inner provider.Client) provider.Client { + provName := r.Provider + baseURL := r.BaseURL + accountID := r.AccountID + + refreshFn := func(ctx context.Context) (string, error) { + tok, err := refreshIfExpired(provName, loadOAuthToken(provName)) + if err != nil { + return "", err + } + return tok.AccessToken, nil + } + + factory := func(token string) provider.Client { + switch provName { + case "openai": + return provider.NewOpenAICodex(token, accountID, baseURL) + default: + return provider.NewAnthropicOAuth(token, baseURL) + } + } + + return provider.NewRefreshingClient(inner, refreshFn, factory) +} + // UseSandbox replaces the sandbox pointer that every tool in r's // registry references. Used to keep the /jail state stable across // agent rebuilds (e.g. /login, /model switching providers). diff --git a/internal/agent/config.go b/internal/agent/config.go index 0da2cec..e8acce0 100644 --- a/internal/agent/config.go +++ b/internal/agent/config.go @@ -155,6 +155,26 @@ func ResolveCredentialFull(provider, explicit string) (cred, method, accountID s return "", "", "", fmt.Errorf("no credential for %s", provider) } +// loadOAuthToken reads the current OAuth token from auth.json for the +// given provider. Returns nil if no token is stored. +func loadOAuthToken(providerName string) *auth.OAuthToken { + c, err := AuthStoreFor().Load() + if err != nil { + return nil + } + switch providerName { + case "anthropic": + if c.Anthropic.OAuth != nil { + return c.Anthropic.OAuth + } + case "openai": + if c.OpenAI.OAuth != nil { + return c.OpenAI.OAuth + } + } + return nil +} + // refreshIfExpired returns a usable OAuth token for the given provider, // refreshing it synchronously when it's past (or near) expiry. The // refreshed token is persisted to auth.json. diff --git a/internal/provider/refreshing.go b/internal/provider/refreshing.go new file mode 100644 index 0000000..49848bf --- /dev/null +++ b/internal/provider/refreshing.go @@ -0,0 +1,57 @@ +package provider + +import ( + "context" + "sync" +) + +// TokenRefresher is a callback that checks whether the current token +// is still valid and returns a fresh one if needed. The returned +// string is the new access token; if empty the old one is still fine. +// An error means refresh failed (network down, refresh token expired, +// etc.) — the caller should proceed with the stale token and let the +// API return 401 naturally. +type TokenRefresher func(ctx context.Context) (newToken string, err error) + +// RefreshingClient wraps a Client and calls a TokenRefresher before +// every Stream call. When the refresher returns a new token, a fresh +// underlying client is built via the factory function. +type RefreshingClient struct { + mu sync.Mutex + inner Client + refresh TokenRefresher + factory func(token string) Client +} + +// NewRefreshingClient wraps inner with automatic token refresh. +// refreshFn is called before each Stream; if it returns a non-empty +// token the factory rebuilds the underlying client with the new token. +func NewRefreshingClient(inner Client, refreshFn TokenRefresher, factory func(token string) Client) *RefreshingClient { + return &RefreshingClient{ + inner: inner, + refresh: refreshFn, + factory: factory, + } +} + +func (c *RefreshingClient) Name() string { + c.mu.Lock() + defer c.mu.Unlock() + return c.inner.Name() +} + +func (c *RefreshingClient) Stream(ctx context.Context, req Request) (<-chan Event, error) { + if c.refresh != nil { + if newToken, err := c.refresh(ctx); err == nil && newToken != "" { + c.mu.Lock() + c.inner = c.factory(newToken) + c.mu.Unlock() + } + // On refresh error: proceed with the current client. + // The stale token will 401 and the user sees a clear error. + } + c.mu.Lock() + inner := c.inner + c.mu.Unlock() + return inner.Stream(ctx, req) +}