diff --git a/packages/provider/openai.go b/packages/provider/openai.go index 22c0d6e..79a9f2f 100644 --- a/packages/provider/openai.go +++ b/packages/provider/openai.go @@ -8,12 +8,32 @@ import ( "fmt" "io" "net/http" + "regexp" "strings" "time" ) const openaiDefaultBaseURL = "https://api.openai.com" +// versionSegmentSuffix matches a trailing API version segment such as +// "/v1" or Z.AI's "/v4". +var versionSegmentSuffix = regexp.MustCompile(`/v\d+$`) + +// chatCompletionsURL builds the chat-completions endpoint for an +// OpenAI-compatible base URL. A base that already carries an API +// version segment gets "/chat/completions" appended directly; a bare +// host (e.g. api.openai.com) gets the conventional "/v1/chat/completions". +// +// Matching any "/vN" segment (not just "/v1") keeps Z.AI's coding-plan +// base, which ends in "/paas/v4", from getting a spurious "/v1" that +// yields ".../paas/v4/v1/chat/completions" and a 404. +func chatCompletionsURL(baseURL string) string { + if versionSegmentSuffix.MatchString(baseURL) { + return baseURL + "/chat/completions" + } + return baseURL + "/v1/chat/completions" +} + type openaiClient struct { apiKey string baseURL string @@ -412,10 +432,7 @@ func buildOAIContentBlocks(blocks []Content, isError bool) []interface{} { // ---- streaming ---- func (c *openaiClient) Stream(ctx context.Context, req Request) (<-chan Event, error) { - apiPath := "/v1/chat/completions" - if strings.HasSuffix(c.baseURL, "/v1") { - apiPath = "/chat/completions" - } + endpoint := chatCompletionsURL(c.baseURL) wire, err := c.buildRequest(req) if err != nil { return nil, err @@ -425,7 +442,7 @@ func (c *openaiClient) Stream(ctx context.Context, req Request) (<-chan Event, e return nil, err } newReq := func() (*http.Request, error) { - httpReq, err := http.NewRequestWithContext(ctx, "POST", c.baseURL+apiPath, bytes.NewReader(body)) + httpReq, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewReader(body)) if err != nil { return nil, err } diff --git a/packages/provider/openai_url_test.go b/packages/provider/openai_url_test.go new file mode 100644 index 0000000..ef89bd3 --- /dev/null +++ b/packages/provider/openai_url_test.go @@ -0,0 +1,40 @@ +package provider + +import "testing" + +// TestChatCompletionsURL pins the endpoint built for OpenAI-compatible +// providers. The regression that motivated it: Z.AI's coding-plan base +// carries a "/v4" version segment, so blindly appending +// "/v1/chat/completions" produced ".../paas/v4/v1/chat/completions" +// and a 404. Any base that already ends in a version segment must get +// "/chat/completions" appended directly. +func TestChatCompletionsURL(t *testing.T) { + cases := []struct { + name string + baseURL string + want string + }{ + { + name: "bare host gets conventional /v1 path", + baseURL: "https://api.openai.com", + want: "https://api.openai.com/v1/chat/completions", + }, + { + name: "v1 base is not doubled", + baseURL: "https://api.moonshot.ai/v1", + want: "https://api.moonshot.ai/v1/chat/completions", + }, + { + name: "zai coding plan v4 base is not given a spurious /v1", + baseURL: "https://api.z.ai/api/coding/paas/v4", + want: "https://api.z.ai/api/coding/paas/v4/chat/completions", + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if got := chatCompletionsURL(tc.baseURL); got != tc.want { + t.Errorf("chatCompletionsURL(%q) = %q, want %q", tc.baseURL, got, tc.want) + } + }) + } +}