diff --git a/apps/web/src/components/add-provider/index.vue b/apps/web/src/components/add-provider/index.vue
index 6d45e351..d3bbdc20 100644
--- a/apps/web/src/components/add-provider/index.vue
+++ b/apps/web/src/components/add-provider/index.vue
@@ -47,6 +47,7 @@
- {{ $t('bots.timezoneInheritedHint') }} -
-You can close this window and return to Memoh.
+ + +`)) + return c.HTML(http.StatusOK, executeHTMLTemplate(page, map[string]string{"ProviderID": providerID})) +} + +func executeHTMLTemplate(tpl *template.Template, data any) string { + var b strings.Builder + _ = tpl.Execute(&b, data) + return b.String() +} diff --git a/internal/handlers/providers.go b/internal/handlers/providers.go index b53466bd..37eaecf4 100644 --- a/internal/handlers/providers.go +++ b/internal/handlers/providers.go @@ -309,20 +309,31 @@ func (h *ProvidersHandler) ImportModels(c echo.Context) error { return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("fetch remote models: %v", err)) } - defaultCompat := []string{models.CompatVision, models.CompatToolCall, models.CompatReasoning} - resp := providers.ImportModelsResponse{ Models: make([]string, 0), } for _, m := range remoteModels { + modelType := models.ModelTypeChat + if strings.TrimSpace(m.Type) == string(models.ModelTypeEmbedding) { + modelType = models.ModelTypeEmbedding + } + compatibilities := m.Compatibilities + if len(compatibilities) == 0 { + compatibilities = []string{models.CompatVision, models.CompatToolCall, models.CompatReasoning} + } + name := strings.TrimSpace(m.Name) + if name == "" { + name = m.ID + } _, err := h.modelsService.Create(c.Request().Context(), models.AddRequest{ ModelID: m.ID, - Name: m.ID, + Name: name, LlmProviderID: id, - Type: models.ModelTypeChat, + Type: modelType, Config: models.ModelConfig{ - Compatibilities: defaultCompat, + Compatibilities: compatibilities, + ReasoningEfforts: m.ReasoningEfforts, }, }) if err != nil { diff --git a/internal/memory/adapters/builtin/dense_runtime.go b/internal/memory/adapters/builtin/dense_runtime.go index d0006d0c..92d47c11 100644 --- a/internal/memory/adapters/builtin/dense_runtime.go +++ b/internal/memory/adapters/builtin/dense_runtime.go @@ -75,7 +75,7 @@ func newDenseRuntime(providerConfig map[string]any, queries *dbsqlc.Queries, cfg return nil, fmt.Errorf("dense runtime: %w", err) } - embedModel := models.NewSDKEmbeddingModel(spec.clientType, spec.baseURL, spec.apiKey, spec.modelID, denseEmbedTimeout) + embedModel := models.NewSDKEmbeddingModel(spec.clientType, spec.baseURL, spec.apiKey, spec.modelID, denseEmbedTimeout, nil) return &denseRuntime{ qdrant: qClient, diff --git a/internal/models/embedding.go b/internal/models/embedding.go index 6ff8fd0b..831b9e44 100644 --- a/internal/models/embedding.go +++ b/internal/models/embedding.go @@ -13,11 +13,13 @@ import ( // provider configuration. It dispatches to the native Google embedding provider // when clientType is "google-generative-ai", and falls back to the // OpenAI-compatible /embeddings endpoint for all other provider types. -func NewSDKEmbeddingModel(clientType, baseURL, apiKey, modelID string, timeout time.Duration) *sdk.EmbeddingModel { +func NewSDKEmbeddingModel(clientType, baseURL, apiKey, modelID string, timeout time.Duration, httpClient *http.Client) *sdk.EmbeddingModel { if timeout <= 0 { timeout = 30 * time.Second } - httpClient := &http.Client{Timeout: timeout} + if httpClient == nil { + httpClient = &http.Client{Timeout: timeout} + } switch ClientType(clientType) { case ClientTypeGoogleGenerativeAI: @@ -30,7 +32,6 @@ func NewSDKEmbeddingModel(clientType, baseURL, apiKey, modelID string, timeout t } p := googleembedding.New(opts...) return p.EmbeddingModel(modelID) - default: opts := []openaiembedding.Option{ openaiembedding.WithAPIKey(apiKey), diff --git a/internal/models/probe.go b/internal/models/probe.go index c37dc8fe..412ecf35 100644 --- a/internal/models/probe.go +++ b/internal/models/probe.go @@ -2,6 +2,9 @@ package models import ( "context" + "encoding/base64" + "encoding/json" + "errors" "fmt" "net/http" "strings" @@ -9,11 +12,13 @@ import ( anthropicmessages "github.com/memohai/twilight-ai/provider/anthropic/messages" googlegenerative "github.com/memohai/twilight-ai/provider/google/generativeai" + openaicodex "github.com/memohai/twilight-ai/provider/openai/codex" openaicompletions "github.com/memohai/twilight-ai/provider/openai/completions" openairesponses "github.com/memohai/twilight-ai/provider/openai/responses" sdk "github.com/memohai/twilight-ai/sdk" "github.com/memohai/memoh/internal/db" + "github.com/memohai/memoh/internal/db/sqlc" ) const probeTimeout = 15 * time.Second @@ -37,16 +42,17 @@ func (s *Service) Test(ctx context.Context, id string) (TestResponse, error) { } baseURL := strings.TrimRight(provider.BaseUrl, "/") - apiKey := provider.ApiKey clientType := ClientType(provider.ClientType) - - // Embedding models don't have a chat Provider in the SDK — probe - // the /embeddings endpoint directly. - if model.Type == string(ModelTypeEmbedding) { - return s.testEmbeddingModel(ctx, string(clientType), baseURL, apiKey, model.ModelID) + creds, err := s.resolveModelCredentials(ctx, provider) + if err != nil { + return TestResponse{}, err } - sdkProvider := NewSDKProvider(baseURL, apiKey, clientType, probeTimeout) + if model.Type == string(ModelTypeEmbedding) { + return s.testEmbeddingModel(ctx, baseURL, creds.APIKey, model.ModelID, nil) + } + + sdkProvider := NewSDKProvider(baseURL, creds.APIKey, creds.CodexAccountID, clientType, probeTimeout, nil) start := time.Now() @@ -100,11 +106,11 @@ func (s *Service) Test(ctx context.Context, id string) (TestResponse, error) { // testEmbeddingModel probes an embedding model by performing a minimal // embedding request via the Twilight SDK, verifying that the model is // reachable and functional rather than merely checking HTTP connectivity. -func (*Service) testEmbeddingModel(ctx context.Context, clientType, baseURL, apiKey, modelID string) (TestResponse, error) { +func (*Service) testEmbeddingModel(ctx context.Context, baseURL, apiKey, modelID string, httpClient *http.Client) (TestResponse, error) { ctx, cancel := context.WithTimeout(ctx, probeTimeout) defer cancel() - model := NewSDKEmbeddingModel(clientType, baseURL, apiKey, modelID, probeTimeout) + model := NewSDKEmbeddingModel(string(ClientTypeOpenAICompletions), baseURL, apiKey, modelID, probeTimeout, httpClient) client := sdk.NewClient() start := time.Now() @@ -130,8 +136,10 @@ func (*Service) testEmbeddingModel(ctx context.Context, clientType, baseURL, api // NewSDKProvider creates a Twilight AI SDK Provider for the given client type. // It is exported so that other packages (e.g. providers) can reuse it for testing. -func NewSDKProvider(baseURL, apiKey string, clientType ClientType, timeout time.Duration) sdk.Provider { - httpClient := &http.Client{Timeout: timeout} +func NewSDKProvider(baseURL, apiKey, codexAccountID string, clientType ClientType, timeout time.Duration, httpClient *http.Client) sdk.Provider { + if httpClient == nil { + httpClient = &http.Client{Timeout: timeout} + } switch clientType { case ClientTypeOpenAIResponses: @@ -144,6 +152,16 @@ func NewSDKProvider(baseURL, apiKey string, clientType ClientType, timeout time. } return openairesponses.New(opts...) + case ClientTypeOpenAICodex: + opts := []openaicodex.Option{ + openaicodex.WithAccessToken(apiKey), + openaicodex.WithHTTPClient(httpClient), + } + if codexAccountID != "" { + opts = append(opts, openaicodex.WithAccountID(codexAccountID)) + } + return openaicodex.New(opts...) + case ClientTypeAnthropicMessages: opts := []anthropicmessages.Option{ anthropicmessages.WithAPIKey(apiKey), @@ -175,3 +193,55 @@ func NewSDKProvider(baseURL, apiKey string, clientType ClientType, timeout time. return openaicompletions.New(opts...) } } + +type modelCredentials struct { + APIKey string //nolint:gosec // runtime credential material used to construct SDK providers + CodexAccountID string +} + +func (s *Service) resolveModelCredentials(ctx context.Context, provider sqlc.LlmProvider) (modelCredentials, error) { + if ClientType(provider.ClientType) != ClientTypeOpenAICodex { + return modelCredentials{APIKey: provider.ApiKey}, nil + } + + tokenRow, err := s.queries.GetLlmProviderOAuthTokenByProvider(ctx, provider.ID) + if err != nil { + return modelCredentials{}, err + } + accessToken := strings.TrimSpace(tokenRow.AccessToken) + if accessToken == "" { + return modelCredentials{}, errors.New("oauth token is missing access token") + } + accountID, err := codexAccountIDFromToken(accessToken) + if err != nil { + return modelCredentials{}, err + } + return modelCredentials{ + APIKey: accessToken, + CodexAccountID: accountID, + }, nil +} + +func codexAccountIDFromToken(token string) (string, error) { + parts := strings.Split(token, ".") + if len(parts) != 3 { + return "", errors.New("invalid oauth access token") + } + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return "", fmt.Errorf("decode oauth token payload: %w", err) + } + var claims struct { + OpenAIAuth struct { + ChatGPTAccountID string `json:"chatgpt_account_id"` + } `json:"https://api.openai.com/auth"` + } + if err := json.Unmarshal(payload, &claims); err != nil { + return "", fmt.Errorf("parse oauth token payload: %w", err) + } + accountID := strings.TrimSpace(claims.OpenAIAuth.ChatGPTAccountID) + if accountID == "" { + return "", errors.New("oauth access token missing chatgpt_account_id") + } + return accountID, nil +} diff --git a/internal/models/sdk.go b/internal/models/sdk.go index d5357c6a..edc307e5 100644 --- a/internal/models/sdk.go +++ b/internal/models/sdk.go @@ -1,10 +1,12 @@ package models import ( + "net/http" "strings" anthropicmessages "github.com/memohai/twilight-ai/provider/anthropic/messages" googlegenerative "github.com/memohai/twilight-ai/provider/google/generativeai" + openaicodex "github.com/memohai/twilight-ai/provider/openai/codex" openaicompletions "github.com/memohai/twilight-ai/provider/openai/completions" openairesponses "github.com/memohai/twilight-ai/provider/openai/responses" sdk "github.com/memohai/twilight-ai/sdk" @@ -16,7 +18,9 @@ type SDKModelConfig struct { ModelID string ClientType string APIKey string //nolint:gosec // carries provider credential material at runtime + CodexAccountID string BaseURL string + HTTPClient *http.Client ReasoningConfig *ReasoningConfig } @@ -38,6 +42,9 @@ func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model { opts := []openaicompletions.Option{ openaicompletions.WithAPIKey(cfg.APIKey), } + if cfg.HTTPClient != nil { + opts = append(opts, openaicompletions.WithHTTPClient(cfg.HTTPClient)) + } if cfg.BaseURL != "" { opts = append(opts, openaicompletions.WithBaseURL(cfg.BaseURL)) } @@ -48,16 +55,34 @@ func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model { opts := []openairesponses.Option{ openairesponses.WithAPIKey(cfg.APIKey), } + if cfg.HTTPClient != nil { + opts = append(opts, openairesponses.WithHTTPClient(cfg.HTTPClient)) + } if cfg.BaseURL != "" { opts = append(opts, openairesponses.WithBaseURL(cfg.BaseURL)) } p := openairesponses.New(opts...) return p.ChatModel(cfg.ModelID) + case ClientTypeOpenAICodex: + opts := []openaicodex.Option{ + openaicodex.WithAccessToken(cfg.APIKey), + } + if cfg.HTTPClient != nil { + opts = append(opts, openaicodex.WithHTTPClient(cfg.HTTPClient)) + } + if cfg.CodexAccountID != "" { + opts = append(opts, openaicodex.WithAccountID(cfg.CodexAccountID)) + } + return openaicodex.New(opts...).ChatModel(cfg.ModelID) + case ClientTypeAnthropicMessages: opts := []anthropicmessages.Option{ anthropicmessages.WithAPIKey(cfg.APIKey), } + if cfg.HTTPClient != nil { + opts = append(opts, anthropicmessages.WithHTTPClient(cfg.HTTPClient)) + } if cfg.BaseURL != "" { opts = append(opts, anthropicmessages.WithBaseURL(cfg.BaseURL)) } @@ -75,6 +100,9 @@ func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model { opts := []googlegenerative.Option{ googlegenerative.WithAPIKey(cfg.APIKey), } + if cfg.HTTPClient != nil { + opts = append(opts, googlegenerative.WithHTTPClient(cfg.HTTPClient)) + } if cfg.BaseURL != "" { opts = append(opts, googlegenerative.WithBaseURL(cfg.BaseURL)) } @@ -85,6 +113,9 @@ func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model { opts := []openaicompletions.Option{ openaicompletions.WithAPIKey(cfg.APIKey), } + if cfg.HTTPClient != nil { + opts = append(opts, openaicompletions.WithHTTPClient(cfg.HTTPClient)) + } if cfg.BaseURL != "" { opts = append(opts, openaicompletions.WithBaseURL(cfg.BaseURL)) } @@ -106,7 +137,7 @@ func BuildReasoningOptions(cfg SDKModelConfig) []sdk.GenerateOption { switch ClientType(cfg.ClientType) { case ClientTypeAnthropicMessages: return nil - case ClientTypeOpenAIResponses, ClientTypeOpenAICompletions: + case ClientTypeOpenAIResponses, ClientTypeOpenAICompletions, ClientTypeOpenAICodex: return []sdk.GenerateOption{sdk.WithReasoningEffort(effort)} case ClientTypeGoogleGenerativeAI: return nil @@ -147,6 +178,8 @@ func ResolveClientType(model *sdk.Model) string { return string(ClientTypeAnthropicMessages) case strings.Contains(name, "google"): return string(ClientTypeGoogleGenerativeAI) + case strings.Contains(name, "codex"): + return string(ClientTypeOpenAICodex) case strings.Contains(name, "responses"): return string(ClientTypeOpenAIResponses) default: diff --git a/internal/models/types.go b/internal/models/types.go index 0514a5cf..ff8b3c6e 100644 --- a/internal/models/types.go +++ b/internal/models/types.go @@ -20,6 +20,7 @@ const ( ClientTypeOpenAICompletions ClientType = "openai-completions" ClientTypeAnthropicMessages ClientType = "anthropic-messages" ClientTypeGoogleGenerativeAI ClientType = "google-generative-ai" + ClientTypeOpenAICodex ClientType = "openai-codex" ) const ( @@ -29,16 +30,33 @@ const ( CompatReasoning = "reasoning" ) +const ( + ReasoningEffortNone = "none" + ReasoningEffortLow = "low" + ReasoningEffortMedium = "medium" + ReasoningEffortHigh = "high" + ReasoningEffortXHigh = "xhigh" +) + // validCompatibilities enumerates accepted compatibility tokens. var validCompatibilities = map[string]struct{}{ CompatVision: {}, CompatToolCall: {}, CompatImageOutput: {}, CompatReasoning: {}, } +var validReasoningEfforts = map[string]struct{}{ + ReasoningEffortNone: {}, + ReasoningEffortLow: {}, + ReasoningEffortMedium: {}, + ReasoningEffortHigh: {}, + ReasoningEffortXHigh: {}, +} + // ModelConfig holds the JSONB config stored per model. type ModelConfig struct { - Dimensions *int `json:"dimensions,omitempty"` - Compatibilities []string `json:"compatibilities,omitempty"` - ContextWindow *int `json:"context_window,omitempty"` + Dimensions *int `json:"dimensions,omitempty"` + Compatibilities []string `json:"compatibilities,omitempty"` + ContextWindow *int `json:"context_window,omitempty"` + ReasoningEfforts []string `json:"reasoning_efforts,omitempty"` } type Model struct { @@ -72,6 +90,11 @@ func (m *Model) Validate() error { return errors.New("invalid compatibility: " + c) } } + for _, effort := range m.Config.ReasoningEfforts { + if _, ok := validReasoningEfforts[effort]; !ok { + return errors.New("invalid reasoning effort: " + effort) + } + } return nil } diff --git a/internal/providers/credentials.go b/internal/providers/credentials.go new file mode 100644 index 00000000..dc5f5e8a --- /dev/null +++ b/internal/providers/credentials.go @@ -0,0 +1,69 @@ +package providers + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "strings" + + "github.com/memohai/memoh/internal/db/sqlc" + "github.com/memohai/memoh/internal/models" +) + +const openAIAuthClaimPath = "https://api.openai.com/auth" + +type ModelCredentials struct { + APIKey string //nolint:gosec // runtime credential material used to construct SDK providers + CodexAccountID string +} + +func SupportsOpenAICodexOAuth(provider sqlc.LlmProvider) bool { + return supportsOAuth(provider) +} + +func (s *Service) ResolveModelCredentials(ctx context.Context, provider sqlc.LlmProvider) (ModelCredentials, error) { + if models.ClientType(provider.ClientType) != models.ClientTypeOpenAICodex { + return ModelCredentials{ + APIKey: provider.ApiKey, + }, nil + } + + token, err := s.GetValidAccessToken(ctx, provider.ID.String()) + if err != nil { + return ModelCredentials{}, err + } + accountID, err := codexAccountIDFromToken(token) + if err != nil { + return ModelCredentials{}, err + } + return ModelCredentials{ + APIKey: token, + CodexAccountID: accountID, + }, nil +} + +func codexAccountIDFromToken(token string) (string, error) { + parts := strings.Split(token, ".") + if len(parts) != 3 { + return "", errors.New("invalid oauth access token") + } + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return "", fmt.Errorf("decode oauth token payload: %w", err) + } + var claims struct { + OpenAIAuth struct { + ChatGPTAccountID string `json:"chatgpt_account_id"` + } `json:"https://api.openai.com/auth"` + } + if err := json.Unmarshal(payload, &claims); err != nil { + return "", fmt.Errorf("parse oauth token payload: %w", err) + } + accountID := strings.TrimSpace(claims.OpenAIAuth.ChatGPTAccountID) + if accountID == "" { + return "", fmt.Errorf("oauth access token missing %s.chatgpt_account_id", openAIAuthClaimPath) + } + return accountID, nil +} diff --git a/internal/providers/oauth.go b/internal/providers/oauth.go new file mode 100644 index 00000000..a03c16b2 --- /dev/null +++ b/internal/providers/oauth.go @@ -0,0 +1,468 @@ +package providers + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + + "github.com/memohai/memoh/internal/db" + "github.com/memohai/memoh/internal/db/sqlc" + "github.com/memohai/memoh/internal/models" +) + +const ( + defaultOpenAICodexClientID = "app_EMoamEEZ73f0CkXaXp7hrann" + defaultOpenAIAuthorizeURL = "https://auth.openai.com/oauth/authorize" + defaultOpenAITokenURL = "https://auth.openai.com/oauth/token" //nolint:gosec // OAuth endpoint URL, not a credential + defaultOpenAICallbackURL = "http://localhost:1455/auth/callback" + defaultOpenAIOAuthScopes = "openid profile email offline_access" + oauthExpirySkew = 30 * time.Second + providerOAuthHTTPTimeout = 15 * time.Second + metadataOAuthClientIDKey = "oauth_client_id" + metadataOAuthAuthorizeURLKey = "oauth_authorize_url" + metadataOAuthTokenURLKey = "oauth_token_url" //nolint:gosec // metadata key name, not a credential + metadataOAuthRedirectURIKey = "oauth_redirect_uri" + metadataOAuthScopesKey = "oauth_scopes" + metadataOAuthAudienceKey = "oauth_audience" + metadataOAuthUseIDOrgsFlagKey = "oauth_id_token_add_organizations" +) + +type providerOAuthToken struct { + ProviderID string `json:"provider_id"` + AccessToken string `json:"access_token"` //nolint:gosec // runtime credential storage + RefreshToken string `json:"refresh_token"` //nolint:gosec // runtime credential storage + ExpiresAt time.Time `json:"expires_at"` + Scope string `json:"scope"` + TokenType string `json:"token_type"` + State string `json:"state"` + PKCECodeVerifier string `json:"pkce_code_verifier"` +} + +type openAIOAuthConfig struct { + ClientID string + AuthorizeURL string + TokenURL string + RedirectURI string + Scopes string + IDTokenAddOrganizations bool +} + +func providerMetadata(raw []byte) map[string]any { + if len(raw) == 0 { + return map[string]any{} + } + var metadata map[string]any + if err := json.Unmarshal(raw, &metadata); err != nil { + return map[string]any{} + } + if metadata == nil { + return map[string]any{} + } + return metadata +} + +func (s *Service) oauthConfig(metadata map[string]any) openAIOAuthConfig { + cfg := openAIOAuthConfig{ + ClientID: defaultOpenAICodexClientID, + AuthorizeURL: defaultOpenAIAuthorizeURL, + TokenURL: defaultOpenAITokenURL, + RedirectURI: firstNonEmpty(s.callbackURL, defaultOpenAICallbackURL), + Scopes: defaultOpenAIOAuthScopes, + IDTokenAddOrganizations: true, + } + if v, _ := metadata[metadataOAuthClientIDKey].(string); strings.TrimSpace(v) != "" { + cfg.ClientID = strings.TrimSpace(v) + } + if v, _ := metadata[metadataOAuthAuthorizeURLKey].(string); strings.TrimSpace(v) != "" { + cfg.AuthorizeURL = strings.TrimSpace(v) + } + if v, _ := metadata[metadataOAuthTokenURLKey].(string); strings.TrimSpace(v) != "" { + cfg.TokenURL = strings.TrimSpace(v) + } + if v, _ := metadata[metadataOAuthRedirectURIKey].(string); strings.TrimSpace(v) != "" { + cfg.RedirectURI = strings.TrimSpace(v) + } + if v, _ := metadata[metadataOAuthScopesKey].(string); strings.TrimSpace(v) != "" { + cfg.Scopes = strings.TrimSpace(v) + } + if v, ok := metadata[metadataOAuthUseIDOrgsFlagKey].(bool); ok { + cfg.IDTokenAddOrganizations = v + } + return cfg +} + +func supportsOAuth(provider sqlc.LlmProvider) bool { + return models.ClientType(provider.ClientType) == models.ClientTypeOpenAICodex +} + +func (s *Service) StartOAuthAuthorization(ctx context.Context, providerID string) (string, error) { + providerUUID, err := db.ParseUUID(providerID) + if err != nil { + return "", err + } + provider, err := s.queries.GetLlmProviderByID(ctx, providerUUID) + if err != nil { + return "", fmt.Errorf("get provider: %w", err) + } + if !supportsOAuth(provider) { + return "", errors.New("provider does not support oauth") + } + + cfg := s.oauthConfig(providerMetadata(provider.Metadata)) + codeVerifier, err := generateCodeVerifier() + if err != nil { + return "", fmt.Errorf("generate code verifier: %w", err) + } + state, err := generateState() + if err != nil { + return "", fmt.Errorf("generate state: %w", err) + } + if err := s.updateOAuthState(ctx, providerID, state, codeVerifier); err != nil { + return "", err + } + + params := url.Values{ + "response_type": {"code"}, + "client_id": {cfg.ClientID}, + "redirect_uri": {cfg.RedirectURI}, + "scope": {cfg.Scopes}, + "code_challenge": {computeCodeChallenge(codeVerifier)}, + "code_challenge_method": {"S256"}, + "state": {state}, + } + if cfg.IDTokenAddOrganizations { + params.Set("id_token_add_organizations", "true") + } + params.Set("codex_cli_simplified_flow", "true") + + return cfg.AuthorizeURL + "?" + params.Encode(), nil +} + +func (s *Service) HandleOAuthCallback(ctx context.Context, state, code string) (string, error) { + token, err := s.getOAuthTokenByState(ctx, state) + if err != nil { + return "", err + } + providerUUID, err := db.ParseUUID(token.ProviderID) + if err != nil { + return "", err + } + provider, err := s.queries.GetLlmProviderByID(ctx, providerUUID) + if err != nil { + return "", fmt.Errorf("get provider: %w", err) + } + if !supportsOAuth(provider) { + return "", errors.New("provider does not support oauth") + } + + cfg := s.oauthConfig(providerMetadata(provider.Metadata)) + resp, err := s.exchangeCode(ctx, cfg, code, token.PKCECodeVerifier) + if err != nil { + return "", err + } + if err := s.saveOAuthToken(ctx, provider.ID.String(), providerOAuthToken{ + ProviderID: provider.ID.String(), + AccessToken: resp.AccessToken, + RefreshToken: firstNonEmpty(resp.RefreshToken, token.RefreshToken), + ExpiresAt: expiresAtFromNow(resp.ExpiresIn), + Scope: firstNonEmpty(resp.Scope, cfg.Scopes), + TokenType: firstNonEmpty(resp.TokenType, "Bearer"), + State: "", + PKCECodeVerifier: "", + }); err != nil { + return "", err + } + return provider.ID.String(), nil +} + +func (s *Service) GetOAuthStatus(ctx context.Context, providerID string) (*OAuthStatus, error) { + providerUUID, err := db.ParseUUID(providerID) + if err != nil { + return nil, err + } + provider, err := s.queries.GetLlmProviderByID(ctx, providerUUID) + if err != nil { + return nil, fmt.Errorf("get provider: %w", err) + } + status := &OAuthStatus{ + Configured: supportsOAuth(provider), + CallbackURL: s.oauthConfig(providerMetadata(provider.Metadata)).RedirectURI, + } + if !status.Configured { + return status, nil + } + + token, err := s.getOAuthToken(ctx, providerID) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return status, nil + } + return nil, err + } + status.HasToken = strings.TrimSpace(token.AccessToken) != "" + if !token.ExpiresAt.IsZero() { + expiresAt := token.ExpiresAt + status.ExpiresAt = &expiresAt + status.Expired = time.Now().After(token.ExpiresAt) + } + return status, nil +} + +func (s *Service) RevokeOAuthToken(ctx context.Context, providerID string) error { + providerUUID, err := db.ParseUUID(providerID) + if err != nil { + return err + } + provider, err := s.queries.GetLlmProviderByID(ctx, providerUUID) + if err != nil { + return fmt.Errorf("get provider: %w", err) + } + if !supportsOAuth(provider) { + return errors.New("provider does not support oauth") + } + return s.queries.DeleteLlmProviderOAuthToken(ctx, providerUUID) +} + +func (s *Service) GetValidAccessToken(ctx context.Context, providerID string) (string, error) { + token, err := s.getOAuthToken(ctx, providerID) + if err != nil { + return "", err + } + if strings.TrimSpace(token.AccessToken) == "" { + return "", errors.New("oauth token is missing access token") + } + if token.ExpiresAt.IsZero() || time.Now().Add(oauthExpirySkew).Before(token.ExpiresAt) { + return token.AccessToken, nil + } + if strings.TrimSpace(token.RefreshToken) == "" { + return "", errors.New("oauth token expired and no refresh token is available") + } + + providerUUID, err := db.ParseUUID(providerID) + if err != nil { + return "", err + } + provider, err := s.queries.GetLlmProviderByID(ctx, providerUUID) + if err != nil { + return "", fmt.Errorf("get provider: %w", err) + } + cfg := s.oauthConfig(providerMetadata(provider.Metadata)) + refreshed, err := s.refreshAccessToken(ctx, cfg, token.RefreshToken) + if err != nil { + return "", err + } + saved := providerOAuthToken{ + ProviderID: providerID, + AccessToken: refreshed.AccessToken, + RefreshToken: firstNonEmpty(refreshed.RefreshToken, token.RefreshToken), + ExpiresAt: expiresAtFromNow(refreshed.ExpiresIn), + Scope: firstNonEmpty(refreshed.Scope, token.Scope), + TokenType: firstNonEmpty(refreshed.TokenType, token.TokenType), + State: token.State, + PKCECodeVerifier: token.PKCECodeVerifier, + } + if err := s.saveOAuthToken(ctx, providerID, saved); err != nil { + return "", err + } + return saved.AccessToken, nil +} + +func (s *Service) getOAuthToken(ctx context.Context, providerID string) (*providerOAuthToken, error) { + providerUUID, err := db.ParseUUID(providerID) + if err != nil { + return nil, err + } + row, err := s.queries.GetLlmProviderOAuthTokenByProvider(ctx, providerUUID) + if err != nil { + return nil, err + } + return toProviderOAuthToken(row), nil +} + +func (s *Service) getOAuthTokenByState(ctx context.Context, state string) (*providerOAuthToken, error) { + row, err := s.queries.GetLlmProviderOAuthTokenByState(ctx, state) + if err != nil { + return nil, err + } + return toProviderOAuthToken(row), nil +} + +func (s *Service) updateOAuthState(ctx context.Context, providerID, state, codeVerifier string) error { + providerUUID, err := db.ParseUUID(providerID) + if err != nil { + return err + } + return s.queries.UpdateLlmProviderOAuthState(ctx, sqlc.UpdateLlmProviderOAuthStateParams{ + LlmProviderID: providerUUID, + State: state, + PkceCodeVerifier: codeVerifier, + }) +} + +func (s *Service) saveOAuthToken(ctx context.Context, providerID string, token providerOAuthToken) error { + providerUUID, err := db.ParseUUID(providerID) + if err != nil { + return err + } + var expiresAt pgtype.Timestamptz + if !token.ExpiresAt.IsZero() { + expiresAt = pgtype.Timestamptz{Time: token.ExpiresAt, Valid: true} + } + _, err = s.queries.UpsertLlmProviderOAuthToken(ctx, sqlc.UpsertLlmProviderOAuthTokenParams{ + LlmProviderID: providerUUID, + AccessToken: token.AccessToken, + RefreshToken: token.RefreshToken, + ExpiresAt: expiresAt, + Scope: token.Scope, + TokenType: token.TokenType, + State: token.State, + PkceCodeVerifier: token.PKCECodeVerifier, + }) + return err +} + +func toProviderOAuthToken(row sqlc.LlmProviderOauthToken) *providerOAuthToken { + token := &providerOAuthToken{ + ProviderID: row.LlmProviderID.String(), + AccessToken: row.AccessToken, + RefreshToken: row.RefreshToken, + Scope: row.Scope, + TokenType: row.TokenType, + State: row.State, + PKCECodeVerifier: row.PkceCodeVerifier, + } + if row.ExpiresAt.Valid { + token.ExpiresAt = row.ExpiresAt.Time + } + return token +} + +type openAITokenResponse struct { + AccessToken string `json:"access_token"` //nolint:gosec // OAuth response payload carries runtime access token + RefreshToken string `json:"refresh_token"` //nolint:gosec // OAuth response payload carries runtime refresh token + TokenType string `json:"token_type"` + Scope string `json:"scope"` + ExpiresIn int64 `json:"expires_in"` + Error string `json:"error"` + Description string `json:"error_description"` +} + +func (s *Service) exchangeCode(ctx context.Context, cfg openAIOAuthConfig, code, codeVerifier string) (*openAITokenResponse, error) { + values := url.Values{ + "grant_type": {"authorization_code"}, + "code": {code}, + "client_id": {cfg.ClientID}, + "redirect_uri": {cfg.RedirectURI}, + "code_verifier": {codeVerifier}, + } + return s.postTokenRequest(ctx, cfg.TokenURL, values) +} + +func (s *Service) refreshAccessToken(ctx context.Context, cfg openAIOAuthConfig, refreshToken string) (*openAITokenResponse, error) { + values := url.Values{ + "grant_type": {"refresh_token"}, + "refresh_token": {refreshToken}, + "client_id": {cfg.ClientID}, + } + return s.postTokenRequest(ctx, cfg.TokenURL, values) +} + +func (s *Service) postTokenRequest(ctx context.Context, tokenURL string, body url.Values) (*openAITokenResponse, error) { + if err := validateOAuthTokenURL(tokenURL); err != nil { + return nil, err + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(body.Encode())) + if err != nil { + return nil, fmt.Errorf("create oauth request: %w", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + //nolint:gosec // tokenURL is restricted to the fixed OpenAI OAuth host by validateOAuthTokenURL above + resp, err := s.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("execute oauth request: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + payload, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read oauth response: %w", err) + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("oauth token request failed: %s", strings.TrimSpace(string(payload))) + } + + var tokenResp openAITokenResponse + if err := json.Unmarshal(payload, &tokenResp); err != nil { + return nil, fmt.Errorf("decode oauth response: %w", err) + } + if tokenResp.Error != "" { + return nil, fmt.Errorf("oauth token request failed: %s", firstNonEmpty(tokenResp.Description, tokenResp.Error)) + } + return &tokenResp, nil +} + +func validateOAuthTokenURL(raw string) error { + parsed, err := url.Parse(strings.TrimSpace(raw)) + if err != nil { + return fmt.Errorf("invalid oauth token url: %w", err) + } + if !strings.EqualFold(parsed.Scheme, "https") { + return errors.New("oauth token url must use https") + } + if !strings.EqualFold(parsed.Hostname(), "auth.openai.com") { + return errors.New("oauth token url host must be auth.openai.com") + } + return nil +} + +func generateState() (string, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", err + } + return hex.EncodeToString(b), nil +} + +func generateCodeVerifier() (string, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +func computeCodeChallenge(verifier string) string { + sum := sha256.Sum256([]byte(verifier)) + return base64.RawURLEncoding.EncodeToString(sum[:]) +} + +func expiresAtFromNow(expiresIn int64) time.Time { + if expiresIn <= 0 { + return time.Time{} + } + return time.Now().Add(time.Duration(expiresIn) * time.Second) +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + if strings.TrimSpace(value) != "" { + return value + } + } + return "" +} diff --git a/internal/providers/service.go b/internal/providers/service.go index 4426f09e..b9940e46 100644 --- a/internal/providers/service.go +++ b/internal/providers/service.go @@ -11,6 +11,7 @@ import ( "time" "github.com/jackc/pgx/v5/pgtype" + openaicodex "github.com/memohai/twilight-ai/provider/openai/codex" sdk "github.com/memohai/twilight-ai/sdk" "github.com/memohai/memoh/internal/db" @@ -20,21 +21,27 @@ import ( // Service handles provider operations. type Service struct { - queries *sqlc.Queries - logger *slog.Logger + queries *sqlc.Queries + logger *slog.Logger + httpClient *http.Client + callbackURL string } // NewService creates a new provider service. -func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { +func NewService(log *slog.Logger, queries *sqlc.Queries, callbackURL string) *Service { + if log == nil { + log = slog.Default() + } return &Service{ - queries: queries, - logger: log.With(slog.String("service", "providers")), + queries: queries, + logger: log.With(slog.String("service", "providers")), + httpClient: &http.Client{Timeout: providerOAuthHTTPTimeout}, + callbackURL: callbackURL, } } // Create creates a new LLM provider. func (s *Service) Create(ctx context.Context, req CreateRequest) (GetResponse, error) { - // Marshal metadata metadataJSON, err := json.Marshal(req.Metadata) if err != nil { return GetResponse{}, fmt.Errorf("marshal metadata: %w", err) @@ -112,13 +119,11 @@ func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Get return GetResponse{}, err } - // Get existing provider existing, err := s.queries.GetLlmProviderByID(ctx, providerID) if err != nil { return GetResponse{}, fmt.Errorf("get provider: %w", err) } - // Apply updates name := existing.Name if req.Name != nil { name = *req.Name @@ -146,16 +151,15 @@ func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Get enable = *req.Enable } - metadata := existing.Metadata + metadataMap := providerMetadata(existing.Metadata) if req.Metadata != nil { - metadataJSON, err := json.Marshal(req.Metadata) - if err != nil { - return GetResponse{}, fmt.Errorf("marshal metadata: %w", err) - } - metadata = metadataJSON + metadataMap = req.Metadata + } + metadataJSON, err := json.Marshal(metadataMap) + if err != nil { + return GetResponse{}, fmt.Errorf("marshal metadata: %w", err) } - // Update provider updated, err := s.queries.UpdateLlmProvider(ctx, sqlc.UpdateLlmProviderParams{ ID: providerID, Name: name, @@ -164,7 +168,7 @@ func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Get ClientType: clientType, Icon: icon, Enable: enable, - Metadata: metadata, + Metadata: metadataJSON, }) if err != nil { return GetResponse{}, fmt.Errorf("update provider: %w", err) @@ -213,8 +217,12 @@ func (s *Service) Test(ctx context.Context, id string) (TestResponse, error) { baseURL := strings.TrimRight(provider.BaseUrl, "/") clientType := models.ClientType(provider.ClientType) + creds, err := s.ResolveModelCredentials(ctx, provider) + if err != nil { + return TestResponse{}, err + } - sdkProvider := models.NewSDKProvider(baseURL, provider.ApiKey, clientType, probeTimeout) + sdkProvider := models.NewSDKProvider(baseURL, creds.APIKey, creds.CodexAccountID, clientType, probeTimeout, nil) start := time.Now() result := sdkProvider.Test(ctx) @@ -238,6 +246,29 @@ func (s *Service) FetchRemoteModels(ctx context.Context, id string) ([]RemoteMod if err != nil { return nil, fmt.Errorf("get provider: %w", err) } + if supportsOAuth(provider) { + catalog := openaicodex.Catalog() + remoteModels := make([]RemoteModel, 0, len(catalog)) + for _, model := range catalog { + compatibilities := make([]string, 0, 2) + if model.SupportsToolCall { + compatibilities = append(compatibilities, models.CompatToolCall) + } + if model.SupportsReasoning { + compatibilities = append(compatibilities, models.CompatReasoning) + } + remoteModels = append(remoteModels, RemoteModel{ + ID: model.ID, + Name: model.DisplayName, + Object: "model", + OwnedBy: "openai-codex", + Type: "chat", + Compatibilities: compatibilities, + ReasoningEfforts: append([]string(nil), model.ReasoningEfforts...), + }) + } + return remoteModels, nil + } baseURL := strings.TrimRight(provider.BaseUrl, "/") modelsURL := fmt.Sprintf("%s/models", baseURL) @@ -250,7 +281,7 @@ func (s *Service) FetchRemoteModels(ctx context.Context, id string) ([]RemoteMod return nil, fmt.Errorf("create request: %w", err) } - if provider.ApiKey != "" { + if provider.ApiKey != "" && !supportsOAuth(provider) { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", provider.ApiKey)) } @@ -284,7 +315,6 @@ func (s *Service) toGetResponse(provider sqlc.LlmProvider) GetResponse { } } - // Mask API key (show only first 8 characters) maskedAPIKey := maskAPIKey(provider.ApiKey) var icon string @@ -318,7 +348,6 @@ func maskAPIKey(apiKey string) string { } // resolveUpdatedAPIKey keeps the original key when the request value matches the masked version. -// This prevents masked placeholder values from overwriting the real stored credential. func resolveUpdatedAPIKey(existing string, updated *string) string { if updated == nil { return existing diff --git a/internal/providers/types.go b/internal/providers/types.go index dc89eb57..81583407 100644 --- a/internal/providers/types.go +++ b/internal/providers/types.go @@ -55,12 +55,25 @@ type TestResponse struct { Message string `json:"message,omitempty"` } +// OAuthStatus is returned by GET /providers/:id/oauth/status. +type OAuthStatus struct { + Configured bool `json:"configured"` + HasToken bool `json:"has_token"` + Expired bool `json:"expired"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` + CallbackURL string `json:"callback_url"` +} + // RemoteModel represents a model returned by the provider's /v1/models endpoint. type RemoteModel struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - OwnedBy string `json:"owned_by"` + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + OwnedBy string `json:"owned_by"` + Name string `json:"name,omitempty"` + Type string `json:"type,omitempty"` + Compatibilities []string `json:"compatibilities,omitempty"` + ReasoningEfforts []string `json:"reasoning_efforts,omitempty"` } // FetchModelsResponse represents the response from the provider's /v1/models endpoint. diff --git a/internal/server/server.go b/internal/server/server.go index 3ad22601..10be9d3a 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -90,5 +90,11 @@ func shouldSkipJWT(path string) bool { if strings.HasPrefix(path, "/email/oauth/callback") { return true } + if strings.HasPrefix(path, "/providers/oauth/callback") { + return true + } + if strings.HasPrefix(path, "/auth/callback") { + return true + } return false }