feat: auto-refresh OAuth tokens before each API call

Wraps OAuth clients with a RefreshingClient that checks token expiry before every Stream call. Refreshes transparently and rebuilds the underlying client with the fresh token. Fixes sessions silently dying after the 1-hour token lifetime.
This commit is contained in:
patriceckhart 2026-04-24 19:37:44 +02:00
parent cb9de10ec6
commit 8841800acd
3 changed files with 110 additions and 2 deletions

View file

@ -1,6 +1,7 @@
package agent
import (
"context"
"fmt"
"os"
"path/filepath"
@ -337,17 +338,47 @@ func (r Resolved) NewClient() provider.Client {
return provider.NewOpenAI(r.Credential, r.BaseURL)
case "openai":
if r.AuthMethod == "oauth" {
return provider.NewOpenAICodex(r.Credential, r.AccountID, r.BaseURL)
inner := provider.NewOpenAICodex(r.Credential, r.AccountID, r.BaseURL)
return r.wrapWithRefresh(inner)
}
return provider.NewOpenAI(r.Credential, r.BaseURL)
default:
if r.AuthMethod == "oauth" {
return provider.NewAnthropicOAuth(r.Credential, r.BaseURL)
inner := provider.NewAnthropicOAuth(r.Credential, r.BaseURL)
return r.wrapWithRefresh(inner)
}
return provider.NewAnthropic(r.Credential, r.BaseURL)
}
}
// wrapWithRefresh wraps an OAuth client so the access token is
// refreshed automatically before each API call. Without this, long
// sessions (hours) silently fail when the 1-hour token expires.
func (r Resolved) wrapWithRefresh(inner provider.Client) provider.Client {
provName := r.Provider
baseURL := r.BaseURL
accountID := r.AccountID
refreshFn := func(ctx context.Context) (string, error) {
tok, err := refreshIfExpired(provName, loadOAuthToken(provName))
if err != nil {
return "", err
}
return tok.AccessToken, nil
}
factory := func(token string) provider.Client {
switch provName {
case "openai":
return provider.NewOpenAICodex(token, accountID, baseURL)
default:
return provider.NewAnthropicOAuth(token, baseURL)
}
}
return provider.NewRefreshingClient(inner, refreshFn, factory)
}
// UseSandbox replaces the sandbox pointer that every tool in r's
// registry references. Used to keep the /jail state stable across
// agent rebuilds (e.g. /login, /model switching providers).

View file

@ -155,6 +155,26 @@ func ResolveCredentialFull(provider, explicit string) (cred, method, accountID s
return "", "", "", fmt.Errorf("no credential for %s", provider)
}
// loadOAuthToken reads the current OAuth token from auth.json for the
// given provider. Returns nil if no token is stored.
func loadOAuthToken(providerName string) *auth.OAuthToken {
c, err := AuthStoreFor().Load()
if err != nil {
return nil
}
switch providerName {
case "anthropic":
if c.Anthropic.OAuth != nil {
return c.Anthropic.OAuth
}
case "openai":
if c.OpenAI.OAuth != nil {
return c.OpenAI.OAuth
}
}
return nil
}
// refreshIfExpired returns a usable OAuth token for the given provider,
// refreshing it synchronously when it's past (or near) expiry. The
// refreshed token is persisted to auth.json.

View file

@ -0,0 +1,57 @@
package provider
import (
"context"
"sync"
)
// TokenRefresher is a callback that checks whether the current token
// is still valid and returns a fresh one if needed. The returned
// string is the new access token; if empty the old one is still fine.
// An error means refresh failed (network down, refresh token expired,
// etc.) — the caller should proceed with the stale token and let the
// API return 401 naturally.
type TokenRefresher func(ctx context.Context) (newToken string, err error)
// RefreshingClient wraps a Client and calls a TokenRefresher before
// every Stream call. When the refresher returns a new token, a fresh
// underlying client is built via the factory function.
type RefreshingClient struct {
mu sync.Mutex
inner Client
refresh TokenRefresher
factory func(token string) Client
}
// NewRefreshingClient wraps inner with automatic token refresh.
// refreshFn is called before each Stream; if it returns a non-empty
// token the factory rebuilds the underlying client with the new token.
func NewRefreshingClient(inner Client, refreshFn TokenRefresher, factory func(token string) Client) *RefreshingClient {
return &RefreshingClient{
inner: inner,
refresh: refreshFn,
factory: factory,
}
}
func (c *RefreshingClient) Name() string {
c.mu.Lock()
defer c.mu.Unlock()
return c.inner.Name()
}
func (c *RefreshingClient) Stream(ctx context.Context, req Request) (<-chan Event, error) {
if c.refresh != nil {
if newToken, err := c.refresh(ctx); err == nil && newToken != "" {
c.mu.Lock()
c.inner = c.factory(newToken)
c.mu.Unlock()
}
// On refresh error: proceed with the current client.
// The stale token will 401 and the user sees a clear error.
}
c.mu.Lock()
inner := c.inner
c.mu.Unlock()
return inner.Stream(ctx, req)
}