mirror of
https://github.com/patriceckhart/zot.git
synced 2026-06-26 21:36:31 +02:00
feat: auto-refresh OAuth tokens before each API call
Wraps OAuth clients with a RefreshingClient that checks token expiry before every Stream call. Refreshes transparently and rebuilds the underlying client with the fresh token. Fixes sessions silently dying after the 1-hour token lifetime.
This commit is contained in:
parent
cb9de10ec6
commit
8841800acd
3 changed files with 110 additions and 2 deletions
|
|
@ -1,6 +1,7 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
|
@ -337,17 +338,47 @@ func (r Resolved) NewClient() provider.Client {
|
|||
return provider.NewOpenAI(r.Credential, r.BaseURL)
|
||||
case "openai":
|
||||
if r.AuthMethod == "oauth" {
|
||||
return provider.NewOpenAICodex(r.Credential, r.AccountID, r.BaseURL)
|
||||
inner := provider.NewOpenAICodex(r.Credential, r.AccountID, r.BaseURL)
|
||||
return r.wrapWithRefresh(inner)
|
||||
}
|
||||
return provider.NewOpenAI(r.Credential, r.BaseURL)
|
||||
default:
|
||||
if r.AuthMethod == "oauth" {
|
||||
return provider.NewAnthropicOAuth(r.Credential, r.BaseURL)
|
||||
inner := provider.NewAnthropicOAuth(r.Credential, r.BaseURL)
|
||||
return r.wrapWithRefresh(inner)
|
||||
}
|
||||
return provider.NewAnthropic(r.Credential, r.BaseURL)
|
||||
}
|
||||
}
|
||||
|
||||
// wrapWithRefresh wraps an OAuth client so the access token is
|
||||
// refreshed automatically before each API call. Without this, long
|
||||
// sessions (hours) silently fail when the 1-hour token expires.
|
||||
func (r Resolved) wrapWithRefresh(inner provider.Client) provider.Client {
|
||||
provName := r.Provider
|
||||
baseURL := r.BaseURL
|
||||
accountID := r.AccountID
|
||||
|
||||
refreshFn := func(ctx context.Context) (string, error) {
|
||||
tok, err := refreshIfExpired(provName, loadOAuthToken(provName))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return tok.AccessToken, nil
|
||||
}
|
||||
|
||||
factory := func(token string) provider.Client {
|
||||
switch provName {
|
||||
case "openai":
|
||||
return provider.NewOpenAICodex(token, accountID, baseURL)
|
||||
default:
|
||||
return provider.NewAnthropicOAuth(token, baseURL)
|
||||
}
|
||||
}
|
||||
|
||||
return provider.NewRefreshingClient(inner, refreshFn, factory)
|
||||
}
|
||||
|
||||
// UseSandbox replaces the sandbox pointer that every tool in r's
|
||||
// registry references. Used to keep the /jail state stable across
|
||||
// agent rebuilds (e.g. /login, /model switching providers).
|
||||
|
|
|
|||
|
|
@ -155,6 +155,26 @@ func ResolveCredentialFull(provider, explicit string) (cred, method, accountID s
|
|||
return "", "", "", fmt.Errorf("no credential for %s", provider)
|
||||
}
|
||||
|
||||
// loadOAuthToken reads the current OAuth token from auth.json for the
|
||||
// given provider. Returns nil if no token is stored.
|
||||
func loadOAuthToken(providerName string) *auth.OAuthToken {
|
||||
c, err := AuthStoreFor().Load()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
switch providerName {
|
||||
case "anthropic":
|
||||
if c.Anthropic.OAuth != nil {
|
||||
return c.Anthropic.OAuth
|
||||
}
|
||||
case "openai":
|
||||
if c.OpenAI.OAuth != nil {
|
||||
return c.OpenAI.OAuth
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// refreshIfExpired returns a usable OAuth token for the given provider,
|
||||
// refreshing it synchronously when it's past (or near) expiry. The
|
||||
// refreshed token is persisted to auth.json.
|
||||
|
|
|
|||
57
internal/provider/refreshing.go
Normal file
57
internal/provider/refreshing.go
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// TokenRefresher is a callback that checks whether the current token
|
||||
// is still valid and returns a fresh one if needed. The returned
|
||||
// string is the new access token; if empty the old one is still fine.
|
||||
// An error means refresh failed (network down, refresh token expired,
|
||||
// etc.) — the caller should proceed with the stale token and let the
|
||||
// API return 401 naturally.
|
||||
type TokenRefresher func(ctx context.Context) (newToken string, err error)
|
||||
|
||||
// RefreshingClient wraps a Client and calls a TokenRefresher before
|
||||
// every Stream call. When the refresher returns a new token, a fresh
|
||||
// underlying client is built via the factory function.
|
||||
type RefreshingClient struct {
|
||||
mu sync.Mutex
|
||||
inner Client
|
||||
refresh TokenRefresher
|
||||
factory func(token string) Client
|
||||
}
|
||||
|
||||
// NewRefreshingClient wraps inner with automatic token refresh.
|
||||
// refreshFn is called before each Stream; if it returns a non-empty
|
||||
// token the factory rebuilds the underlying client with the new token.
|
||||
func NewRefreshingClient(inner Client, refreshFn TokenRefresher, factory func(token string) Client) *RefreshingClient {
|
||||
return &RefreshingClient{
|
||||
inner: inner,
|
||||
refresh: refreshFn,
|
||||
factory: factory,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RefreshingClient) Name() string {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.inner.Name()
|
||||
}
|
||||
|
||||
func (c *RefreshingClient) Stream(ctx context.Context, req Request) (<-chan Event, error) {
|
||||
if c.refresh != nil {
|
||||
if newToken, err := c.refresh(ctx); err == nil && newToken != "" {
|
||||
c.mu.Lock()
|
||||
c.inner = c.factory(newToken)
|
||||
c.mu.Unlock()
|
||||
}
|
||||
// On refresh error: proceed with the current client.
|
||||
// The stale token will 401 and the user sees a clear error.
|
||||
}
|
||||
c.mu.Lock()
|
||||
inner := c.inner
|
||||
c.mu.Unlock()
|
||||
return inner.Stream(ctx, req)
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue