Files

1312 lines
41 KiB
Go

package providers
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"strings"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
memohcopilot "github.com/memohai/memoh/internal/copilot"
"github.com/memohai/memoh/internal/db"
"github.com/memohai/memoh/internal/db/sqlc"
"github.com/memohai/memoh/internal/models"
"github.com/memohai/memoh/internal/oauthctx"
)
// OAuth defaults, HTTP settings, and metadata key names shared by the
// provider OAuth flows in this file.
const (
// OpenAI Codex browser-flow (authorization code + PKCE) defaults. The client
// ID, authorize/token URLs, redirect URI and scopes can each be overridden per
// provider via the oauth_* metadata keys below; the redirect URI also yields to
// the service's configured callback URL.
defaultOpenAICodexClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
defaultOpenAIAuthorizeURL = "https://auth.openai.com/oauth/authorize"
defaultOpenAITokenURL = "https://auth.openai.com/oauth/token" //nolint:gosec // OAuth endpoint URL, not a credential
defaultOpenAICallbackURL = "http://localhost:1455/auth/callback"
defaultOpenAIOAuthScopes = "openid profile email offline_access"
// GitHub device-flow endpoints (device-code and token URLs are overridable via
// metadata) plus the REST endpoints used to look up the signed-in account.
defaultGitHubDeviceCodeURL = "https://github.com/login/device/code" //nolint:gosec // OAuth endpoint URL, not a credential
defaultGitHubTokenURL = "https://github.com/login/oauth/access_token" //nolint:gosec // OAuth endpoint URL, not a credential
defaultGitHubUserURL = "https://api.github.com/user" //nolint:gosec // OAuth endpoint URL, not a credential
defaultGitHubUserEmailsURL = "https://api.github.com/user/emails" //nolint:gosec // OAuth endpoint URL, not a credential
// A token is refreshed once the current time is within this window of its expiry.
oauthExpirySkew = 30 * time.Second
// NOTE(review): not referenced in this part of the file; presumably the
// timeout for the OAuth HTTP client — confirm at its use site.
providerOAuthHTTPTimeout = 15 * time.Second
// Provider metadata keys read by oauthConfigForProvider.
metadataOAuthClientIDKey = "oauth_client_id"
metadataOAuthAuthorizeURLKey = "oauth_authorize_url"
metadataOAuthDeviceCodeURLKey = "oauth_device_code_url"
metadataOAuthTokenURLKey = "oauth_token_url" //nolint:gosec // metadata key name, not a credential
metadataOAuthRedirectURIKey = "oauth_redirect_uri"
metadataOAuthScopesKey = "oauth_scopes"
metadataOAuthAudienceKey = "oauth_audience"
metadataOAuthUseIDOrgsFlagKey = "oauth_id_token_add_organizations"
// Keys of the pending device authorization kept in user token metadata.
metadataDeviceCodeKey = "device_code"
metadataDeviceUserCodeKey = "device_user_code"
metadataDeviceVerifyURIKey = "device_verification_uri"
metadataDeviceIntervalKey = "device_interval_seconds"
metadataDeviceExpiresAtKey = "device_expires_at"
// Keys of the cached GitHub account profile kept in user token metadata.
metadataAccountLabelKey = "account_label"
metadataAccountLoginKey = "account_login"
metadataAccountNameKey = "account_name"
metadataAccountEmailKey = "account_email"
metadataAccountAvatarURLKey = "account_avatar_url"
metadataAccountProfileURLKey = "account_profile_url"
// NOTE(review): not referenced in this part of the file; the name suggests the
// provider config key holding an OAuth client secret — confirm.
configOAuthClientSecretKey = "oauth_client_secret" //nolint:gosec // Metadata key name, not a credential literal.
)
// oauthTokenRecord is the in-memory form of a stored OAuth token row. It is
// used both for provider-wide tokens and for per-user tokens. State and
// PKCECodeVerifier describe an in-progress browser authorization; the callback
// handlers clear them once a code has been exchanged.
type oauthTokenRecord struct {
ProviderID string
// UserID is set only for user-scoped (GitHub Copilot) tokens.
UserID string
AccessToken string //nolint:gosec // Runtime token payload persisted encrypted at rest.
RefreshToken string //nolint:gosec // Runtime token payload persisted encrypted at rest.
// ExpiresAt is the zero time when no expiry is recorded; such tokens are
// returned as-is and never refreshed.
ExpiresAt time.Time
Scope string
TokenType string
State string
PKCECodeVerifier string
// Metadata carries pending device-authorization data and/or the cached
// account profile for user tokens. It is always empty for provider tokens,
// whose table path neither loads nor saves it.
Metadata map[string]any
}
// oauthConfig is the resolved OAuth client configuration for one provider. It
// is built by oauthConfigForProvider from built-in defaults plus provider
// metadata overrides.
type oauthConfig struct {
ClientType models.ClientType
ClientID string
// NOTE(review): oauthConfigForProvider never populates ClientSecret — confirm
// where (if anywhere) it is set.
ClientSecret string //nolint:gosec // Runtime OAuth client secret from provider metadata.
// AuthorizeURL and RedirectURI are only set for the browser (Codex) flow.
AuthorizeURL string
// DeviceCodeURL is only set for the GitHub device flow.
DeviceCodeURL string
TokenURL string
RedirectURI string
// Scopes is a space-separated scope string sent verbatim as "scope".
Scopes string
Audience string
UsePKCE bool
IDTokenAddOrganizations bool
}
// deviceAuthorizationResponse is the JSON body returned by the device-code
// endpoint at the start of a device flow. Error/Description are filled when
// the endpoint answers with an OAuth error object.
type deviceAuthorizationResponse struct {
DeviceCode string `json:"device_code"`
UserCode string `json:"user_code"`
VerificationURI string `json:"verification_uri"`
// ExpiresIn is expected to be seconds (RFC 8628) — passed to expiresAtFromNow.
ExpiresIn int64 `json:"expires_in"`
// Interval is the polling interval; requestDeviceAuthorization defaults it to 5 when absent.
Interval int64 `json:"interval"`
Error string `json:"error"`
Description string `json:"error_description"`
}
// providerMetadata decodes a JSON object into a generic map. It never returns
// nil: empty input, malformed JSON, a non-object value, or a JSON null all
// yield a fresh, writable empty map.
func providerMetadata(raw []byte) map[string]any {
	decoded := map[string]any{}
	if len(raw) == 0 {
		return decoded
	}
	// A JSON null resets the map to nil, and a failed decode may leave it
	// partially filled; both cases fall back to a clean empty map.
	if json.Unmarshal(raw, &decoded) != nil || decoded == nil {
		return map[string]any{}
	}
	return decoded
}
// oauthLogAttrs builds slog attributes for an OAuth log line. The provider and
// user IDs are included (untrimmed) only when they are not blank, and the
// error only when it is non-nil. The returned slice is never nil.
func oauthLogAttrs(providerID, userID string, err error) []any {
	out := make([]any, 0, 3)
	ids := [...]struct{ key, value string }{
		{"provider_id", providerID},
		{"user_id", userID},
	}
	for _, id := range ids {
		if strings.TrimSpace(id.value) == "" {
			continue
		}
		out = append(out, slog.String(id.key, id.value))
	}
	if err != nil {
		out = append(out, slog.Any("error", err))
	}
	return out
}
// oauthConfigForProvider resolves the OAuth client configuration for provider.
//
// GitHub Copilot providers get the device-flow configuration, using the GitHub
// OAuth client ID and scope from the copilot package. Only the device-code and
// token URLs can be overridden from provider metadata.
//
// Every other client type is treated as OpenAI Codex and gets the browser flow
// with PKCE. Its client ID, authorize/token/redirect URLs and scopes can be
// overridden from metadata. The audience is read from metadata as well, and a
// boolean metadata flag controls id_token_add_organizations.
//
// Blank (whitespace-only) metadata values are ignored, so the defaults stay in
// effect.
func (s *Service) oauthConfigForProvider(provider sqlc.Provider) oauthConfig {
	meta := providerMetadata(provider.Metadata)
	override := func(dst *string, key string) {
		if v := strings.TrimSpace(stringValue(meta, key)); v != "" {
			*dst = v
		}
	}

	if models.ClientType(provider.ClientType) == models.ClientTypeGitHubCopilot {
		cfg := oauthConfig{
			ClientType:    models.ClientTypeGitHubCopilot,
			ClientID:      memohcopilot.GitHubOAuthClientID,
			DeviceCodeURL: defaultGitHubDeviceCodeURL,
			TokenURL:      defaultGitHubTokenURL,
			Scopes:        memohcopilot.GitHubOAuthScope,
		}
		override(&cfg.DeviceCodeURL, metadataOAuthDeviceCodeURLKey)
		override(&cfg.TokenURL, metadataOAuthTokenURLKey)
		return cfg
	}

	// Fallback: any non-Copilot client type uses the OpenAI Codex settings.
	cfg := oauthConfig{
		ClientType:              models.ClientTypeOpenAICodex,
		ClientID:                defaultOpenAICodexClientID,
		AuthorizeURL:            defaultOpenAIAuthorizeURL,
		TokenURL:                defaultOpenAITokenURL,
		RedirectURI:             firstNonEmpty(s.callbackURL, defaultOpenAICallbackURL),
		Scopes:                  defaultOpenAIOAuthScopes,
		Audience:                strings.TrimSpace(stringValue(meta, metadataOAuthAudienceKey)),
		UsePKCE:                 true,
		IDTokenAddOrganizations: true,
	}
	override(&cfg.ClientID, metadataOAuthClientIDKey)
	override(&cfg.AuthorizeURL, metadataOAuthAuthorizeURLKey)
	override(&cfg.TokenURL, metadataOAuthTokenURLKey)
	override(&cfg.RedirectURI, metadataOAuthRedirectURIKey)
	override(&cfg.Scopes, metadataOAuthScopesKey)
	if flag, ok := meta[metadataOAuthUseIDOrgsFlagKey].(bool); ok {
		cfg.IDTokenAddOrganizations = flag
	}
	return cfg
}
// supportsOAuth reports whether provider's client type has an OAuth flow in
// this package: OpenAI Codex (browser flow) or GitHub Copilot (device flow).
func supportsOAuth(provider sqlc.Provider) bool {
	clientType := models.ClientType(provider.ClientType)
	return clientType == models.ClientTypeOpenAICodex || clientType == models.ClientTypeGitHubCopilot
}
// isUserScopedOAuthProvider reports whether provider keeps one OAuth token per
// user instead of one per provider. Only GitHub Copilot does.
func isUserScopedOAuthProvider(provider sqlc.Provider) bool {
	switch models.ClientType(provider.ClientType) {
	case models.ClientTypeGitHubCopilot:
		return true
	}
	return false
}
// StartOAuthAuthorization begins an OAuth sign-in for the given provider.
//
// GitHub Copilot providers are user-scoped and use the device flow. A device
// code is requested for the current user (taken from ctx), and the response has
// Mode "device".
//
// All other OAuth providers use the browser authorization-code flow. A fresh
// state is generated, plus a PKCE verifier when cfg.UsePKCE is set. Both are
// persisted on the provider token row so HandleOAuthCallback can match them,
// and the response has Mode "web" with the authorize URL to open.
func (s *Service) StartOAuthAuthorization(ctx context.Context, providerID string) (*OAuthAuthorizeResponse, error) {
	provider, err := s.loadOAuthProvider(ctx, providerID)
	if err != nil {
		return nil, err
	}
	cfg := s.oauthConfigForProvider(provider)
	if isUserScopedOAuthProvider(provider) {
		userID := oauthctx.UserIDFromContext(ctx)
		if userID == "" {
			return nil, errors.New("github copilot oauth requires a current user")
		}
		device, err := s.startGitHubDeviceAuthorization(ctx, providerID, userID, cfg)
		if err != nil {
			return nil, err
		}
		return &OAuthAuthorizeResponse{
			Mode:   "device",
			Device: device,
		}, nil
	}

	state, err := generateState()
	if err != nil {
		return nil, fmt.Errorf("generate state: %w", err)
	}
	params := url.Values{
		"response_type": {"code"},
		"client_id":     {cfg.ClientID},
		"redirect_uri":  {cfg.RedirectURI},
		"state":         {state},
	}
	// Optional parameters are only sent when configured; an empty "scope="
	// is not equivalent to omitting the parameter.
	if cfg.Scopes != "" {
		params.Set("scope", cfg.Scopes)
	}
	if cfg.Audience != "" {
		params.Set("audience", cfg.Audience)
	}
	codeVerifier := ""
	if cfg.UsePKCE {
		codeVerifier, err = generateCodeVerifier()
		if err != nil {
			return nil, fmt.Errorf("generate code verifier: %w", err)
		}
		params.Set("code_challenge", computeCodeChallenge(codeVerifier))
		params.Set("code_challenge_method", "S256")
	}
	// Persist state (and verifier) before handing out the URL so the
	// callback can find the pending authorization.
	if err := s.updateOAuthState(ctx, providerID, state, codeVerifier); err != nil {
		return nil, err
	}
	if cfg.IDTokenAddOrganizations {
		params.Set("id_token_add_organizations", "true")
	}
	params.Set("codex_cli_simplified_flow", "true")
	return &OAuthAuthorizeResponse{
		Mode:    "web",
		AuthURL: cfg.AuthorizeURL + "?" + params.Encode(),
	}, nil
}
// HandleOAuthCallback completes a browser authorization-code flow.
//
// state selects the pending authorization. A per-user token row carrying that
// state is tried first, then the provider-wide row. code is exchanged at the
// provider's token endpoint together with the stored PKCE verifier. The result
// is persisted with state and verifier cleared, and the provider ID is
// returned.
//
// A blank state or code is rejected up front. Completed and device-flow token
// rows are stored with an empty state, so looking up "" could otherwise match
// an unrelated row (assuming the state lookups are plain equality matches).
func (s *Service) HandleOAuthCallback(ctx context.Context, state, code string) (string, error) {
	if strings.TrimSpace(state) == "" {
		return "", errors.New("oauth callback is missing state")
	}
	if strings.TrimSpace(code) == "" {
		return "", errors.New("oauth callback is missing authorization code")
	}

	userToken, err := s.getUserOAuthTokenByState(ctx, state)
	switch {
	case err == nil:
		return s.handleUserScopedOAuthCallback(ctx, userToken, code)
	case !errors.Is(err, pgx.ErrNoRows):
		return "", err
	}

	token, err := s.getOAuthTokenByState(ctx, state)
	if err != nil {
		return "", err
	}
	providerUUID, err := db.ParseUUID(token.ProviderID)
	if err != nil {
		return "", err
	}
	provider, err := s.queries.GetProviderByID(ctx, providerUUID)
	if err != nil {
		return "", fmt.Errorf("get provider: %w", err)
	}
	if !supportsOAuth(provider) {
		return "", errors.New("provider does not support oauth")
	}
	cfg := s.oauthConfigForProvider(provider)
	resp, err := s.exchangeCode(ctx, cfg, code, token.PKCECodeVerifier)
	if err != nil {
		return "", err
	}
	// Empty State/PKCECodeVerifier retire the pending authorization; a
	// response without a refresh token keeps the previously stored one.
	if err := s.saveOAuthToken(ctx, provider.ID.String(), oauthTokenRecord{
		ProviderID:       provider.ID.String(),
		AccessToken:      resp.AccessToken,
		RefreshToken:     firstNonEmpty(resp.RefreshToken, token.RefreshToken),
		ExpiresAt:        expiresAtFromNow(resp.ExpiresIn),
		Scope:            firstNonEmpty(resp.Scope, cfg.Scopes),
		TokenType:        firstNonEmpty(resp.TokenType, "Bearer"),
		State:            "",
		PKCECodeVerifier: "",
	}); err != nil {
		return "", err
	}
	return provider.ID.String(), nil
}
// handleUserScopedOAuthCallback finishes an authorization-code callback whose
// state matched a per-user token row. It exchanges code, together with the
// row's PKCE verifier, at the provider's token endpoint. The result is stored
// on the same user row with state and verifier cleared and the row's metadata
// kept. A response without a refresh token keeps the previously stored one.
// It returns the provider ID.
func (s *Service) handleUserScopedOAuthCallback(ctx context.Context, token *oauthTokenRecord, code string) (string, error) {
	providerUUID, err := db.ParseUUID(token.ProviderID)
	if err != nil {
		return "", err
	}
	provider, err := s.queries.GetProviderByID(ctx, providerUUID)
	if err != nil {
		return "", fmt.Errorf("get provider: %w", err)
	}
	if !isUserScopedOAuthProvider(provider) {
		return "", errors.New("provider does not use user-scoped oauth")
	}

	cfg := s.oauthConfigForProvider(provider)
	granted, err := s.exchangeCode(ctx, cfg, code, token.PKCECodeVerifier)
	if err != nil {
		return "", err
	}
	// State and PKCECodeVerifier stay at their zero values to retire the
	// pending authorization.
	updated := oauthTokenRecord{
		ProviderID:   token.ProviderID,
		UserID:       token.UserID,
		AccessToken:  granted.AccessToken,
		RefreshToken: firstNonEmpty(granted.RefreshToken, token.RefreshToken),
		ExpiresAt:    expiresAtFromNow(granted.ExpiresIn),
		Scope:        firstNonEmpty(granted.Scope, cfg.Scopes),
		TokenType:    firstNonEmpty(granted.TokenType, "bearer"),
		Metadata:     token.Metadata,
	}
	if err := s.saveUserOAuthToken(ctx, token.ProviderID, token.UserID, updated); err != nil {
		return "", err
	}
	return provider.ID.String(), nil
}
// startGitHubDeviceAuthorization requests a new device code and records it as
// the pending device authorization in the user's token metadata. Any
// browser-flow state and PKCE verifier on that row are cleared. It returns the
// user code, verification URI, expiry and polling interval to show the user.
func (s *Service) startGitHubDeviceAuthorization(ctx context.Context, providerID, userID string, cfg oauthConfig) (*OAuthDeviceStatus, error) {
	issued, err := s.requestDeviceAuthorization(ctx, cfg)
	if err != nil {
		return nil, err
	}
	pending := oauthDeviceMetadata{
		DeviceCode:      issued.DeviceCode,
		UserCode:        issued.UserCode,
		VerificationURI: issued.VerificationURI,
		ExpiresAt:       expiresAtFromNow(issued.ExpiresIn),
		IntervalSeconds: issued.Interval,
	}
	if err = s.updateUserOAuthState(ctx, providerID, userID, "", "", pending.toMetadata()); err != nil {
		return nil, err
	}
	return pending.toStatus(), nil
}
// GetOAuthStatus reports the OAuth state of a provider for the caller.
//
// The provider must exist and support OAuth; loadOAuthProvider errors
// otherwise.
//
// Browser-flow providers report Mode "web", with the configured redirect URI
// as CallbackURL, and read the provider-wide token. GitHub Copilot providers
// report Mode "device" with no CallbackURL; they require a current user in ctx
// and read that user's token.
//
// A missing token row is not an error: the status is returned with HasToken
// false. For user-scoped providers the pending device authorization (if any)
// and the account profile are also reported. Resolving the profile may call
// the GitHub API and persist the result; only a persistence failure is
// returned as an error.
func (s *Service) GetOAuthStatus(ctx context.Context, providerID string) (*OAuthStatus, error) {
provider, err := s.loadOAuthProvider(ctx, providerID)
if err != nil {
return nil, err
}
status := &OAuthStatus{
Configured: supportsOAuth(provider),
Mode: "web",
CallbackURL: s.oauthConfigForProvider(provider).RedirectURI,
}
// NOTE(review): loadOAuthProvider already rejects providers for which
// supportsOAuth is false, so this early return looks unreachable.
if !status.Configured {
return status, nil
}
if isUserScopedOAuthProvider(provider) {
status.Mode = "device"
status.CallbackURL = ""
}
userID := ""
var token *oauthTokenRecord
if isUserScopedOAuthProvider(provider) {
userID = oauthctx.UserIDFromContext(ctx)
if userID == "" {
return nil, errors.New("github copilot oauth requires a current user")
}
token, err = s.getUserOAuthToken(ctx, providerID, userID)
} else {
token, err = s.getOAuthToken(ctx, providerID)
}
if err != nil {
// No stored token yet: report "configured, not connected".
if errors.Is(err, pgx.ErrNoRows) {
return status, nil
}
return nil, err
}
status.HasToken = strings.TrimSpace(token.AccessToken) != ""
// Expired compares against the raw expiry, without the oauthExpirySkew
// margin used when deciding whether to refresh.
if !token.ExpiresAt.IsZero() {
expiresAt := token.ExpiresAt
status.ExpiresAt = &expiresAt
status.Expired = time.Now().After(token.ExpiresAt)
}
if isUserScopedOAuthProvider(provider) {
status.Device = deviceMetadataFromMap(token.Metadata).toStatus()
account, accountErr := s.resolveGitHubOAuthAccount(ctx, providerID, userID, token)
if accountErr != nil {
return nil, accountErr
}
status.Account = account
}
return status, nil
}
// PollOAuthAuthorization advances the current user's pending GitHub device
// authorization by one poll and returns the resulting status.
//
// With no token row or no pending device code it only reports the current
// status. An expired device code is discarded. Otherwise the device code is
// exchanged once, and the outcome depends on the response:
//   - "authorization_pending" leaves everything unchanged.
//   - "slow_down" lengthens the stored polling interval.
//   - Terminal errors discard the device code.
//   - Any other error is returned.
//
// On success the access token is stored together with the account profile,
// which is fetched best effort.
func (s *Service) PollOAuthAuthorization(ctx context.Context, providerID string) (*OAuthStatus, error) {
	provider, err := s.loadOAuthProvider(ctx, providerID)
	if err != nil {
		return nil, err
	}
	if !isUserScopedOAuthProvider(provider) {
		return nil, errors.New("device authorization is only supported for github copilot")
	}
	userID := oauthctx.UserIDFromContext(ctx)
	if userID == "" {
		return nil, errors.New("github copilot oauth requires a current user")
	}
	token, err := s.getUserOAuthToken(ctx, providerID, userID)
	if err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			return s.GetOAuthStatus(ctx, providerID)
		}
		return nil, err
	}
	device := deviceMetadataFromMap(token.Metadata)
	if strings.TrimSpace(device.DeviceCode) == "" {
		return s.GetOAuthStatus(ctx, providerID)
	}
	if !device.ExpiresAt.IsZero() && time.Now().After(device.ExpiresAt) {
		// The code lapsed before the user approved it; forget it.
		if err := s.updateUserOAuthState(ctx, providerID, userID, "", "", nil); err != nil {
			return nil, err
		}
		return s.GetOAuthStatus(ctx, providerID)
	}

	cfg := s.oauthConfigForProvider(provider)
	resp, err := s.exchangeDeviceCode(ctx, cfg, device.DeviceCode)
	if err != nil {
		return nil, err
	}
	if resp.Error != "" {
		switch resp.Error {
		case "authorization_pending":
			return s.GetOAuthStatus(ctx, providerID)
		case "slow_down":
			// RFC 8628 §3.5: the interval must grow for this and all later
			// polls. Use the server's new interval when it sends one,
			// otherwise add the mandated 5 seconds.
			if resp.Interval > 0 {
				device.IntervalSeconds = resp.Interval
			} else {
				device.IntervalSeconds += 5
			}
			if err := s.updateUserOAuthState(ctx, providerID, userID, "", "", device.toMetadata()); err != nil {
				return nil, err
			}
			return s.GetOAuthStatus(ctx, providerID)
		case "expired_token", "access_denied", "incorrect_device_code", "unsupported_grant_type":
			// Terminal for this device code: discard it so the user can restart.
			if err := s.updateUserOAuthState(ctx, providerID, userID, "", "", nil); err != nil {
				return nil, err
			}
			return s.GetOAuthStatus(ctx, providerID)
		default:
			return nil, fmt.Errorf("oauth device token request failed: %s", firstNonEmpty(resp.Description, resp.Error))
		}
	}
	// A "success" without a token would otherwise be persisted as an empty
	// access token.
	if strings.TrimSpace(resp.AccessToken) == "" {
		return nil, errors.New("oauth device token response did not include an access token")
	}

	// Best effort: on failure account is the zero value and no profile is stored.
	account, err := s.fetchGitHubOAuthAccount(ctx, resp.AccessToken)
	if err != nil {
		s.logger.Warn("fetch github oauth account failed", oauthLogAttrs(providerID, userID, err)...)
	}
	if err := s.saveUserOAuthToken(ctx, providerID, userID, oauthTokenRecord{
		ProviderID:   providerID,
		UserID:       userID,
		AccessToken:  resp.AccessToken,
		RefreshToken: firstNonEmpty(resp.RefreshToken, token.RefreshToken),
		ExpiresAt:    expiresAtFromNow(resp.ExpiresIn),
		Scope:        firstNonEmpty(resp.Scope, token.Scope),
		TokenType:    firstNonEmpty(resp.TokenType, "bearer"),
		Metadata:     account.toMetadata(),
	}); err != nil {
		return nil, err
	}
	return s.GetOAuthStatus(ctx, providerID)
}
// RevokeOAuthToken deletes the locally stored OAuth token for a provider. For
// GitHub Copilot this is the current user's token; for other providers it is
// the provider-wide token. No revocation request is sent to the identity
// provider here.
func (s *Service) RevokeOAuthToken(ctx context.Context, providerID string) error {
	provider, err := s.loadOAuthProvider(ctx, providerID)
	if err != nil {
		return err
	}
	if !supportsOAuth(provider) {
		return errors.New("provider does not support oauth")
	}
	if !isUserScopedOAuthProvider(provider) {
		return s.queries.DeleteProviderOAuthToken(ctx, provider.ID)
	}
	userID := oauthctx.UserIDFromContext(ctx)
	if userID == "" {
		return errors.New("github copilot oauth requires a current user")
	}
	return s.deleteUserOAuthToken(ctx, providerID, userID)
}
// GetValidAccessToken returns an access token for the provider, refreshing and
// persisting it first if it is about to expire. GitHub Copilot tokens are
// looked up for the current user in ctx; other providers use the
// provider-wide token. A missing token row surfaces as the underlying query
// error.
func (s *Service) GetValidAccessToken(ctx context.Context, providerID string) (string, error) {
	provider, err := s.loadOAuthProvider(ctx, providerID)
	if err != nil {
		return "", err
	}
	cfg := s.oauthConfigForProvider(provider)
	if !isUserScopedOAuthProvider(provider) {
		shared, err := s.getOAuthToken(ctx, providerID)
		if err != nil {
			return "", err
		}
		return s.resolveValidProviderOAuthToken(ctx, cfg, shared)
	}
	userID := oauthctx.UserIDFromContext(ctx)
	if userID == "" {
		return "", errors.New("github copilot requires a current user")
	}
	personal, err := s.getUserOAuthToken(ctx, providerID, userID)
	if err != nil {
		return "", err
	}
	return s.resolveValidUserOAuthToken(ctx, cfg, personal)
}
// resolveValidProviderOAuthToken returns a usable access token from a
// provider-wide token record.
//
// Tokens without a recorded expiry, or expiring more than oauthExpirySkew from
// now, are returned unchanged. Otherwise the refresh token is exchanged and the
// result persisted before it is returned. Fields the refresh response omits
// (refresh token, scope, token type) keep their previous values.
func (s *Service) resolveValidProviderOAuthToken(ctx context.Context, cfg oauthConfig, token *oauthTokenRecord) (string, error) {
	if strings.TrimSpace(token.AccessToken) == "" {
		return "", errors.New("oauth token is missing access token")
	}
	refreshDue := !token.ExpiresAt.IsZero() && !time.Now().Add(oauthExpirySkew).Before(token.ExpiresAt)
	if !refreshDue {
		return token.AccessToken, nil
	}
	if strings.TrimSpace(token.RefreshToken) == "" {
		return "", errors.New("oauth token expired and no refresh token is available")
	}
	fresh, err := s.refreshAccessToken(ctx, cfg, token.RefreshToken)
	if err != nil {
		return "", err
	}
	next := oauthTokenRecord{
		ProviderID:       token.ProviderID,
		AccessToken:      fresh.AccessToken,
		RefreshToken:     firstNonEmpty(fresh.RefreshToken, token.RefreshToken),
		ExpiresAt:        expiresAtFromNow(fresh.ExpiresIn),
		Scope:            firstNonEmpty(fresh.Scope, token.Scope),
		TokenType:        firstNonEmpty(fresh.TokenType, token.TokenType),
		State:            token.State,
		PKCECodeVerifier: token.PKCECodeVerifier,
		Metadata:         token.Metadata,
	}
	if err := s.saveOAuthToken(ctx, token.ProviderID, next); err != nil {
		return "", err
	}
	return next.AccessToken, nil
}
// resolveValidUserOAuthToken is the per-user counterpart of
// resolveValidProviderOAuthToken.
//
// It returns token's access token unchanged while it is not within
// oauthExpirySkew of its expiry (or has none). Otherwise it refreshes the token
// and stores the result on the same user row before returning it.
func (s *Service) resolveValidUserOAuthToken(ctx context.Context, cfg oauthConfig, token *oauthTokenRecord) (string, error) {
	if strings.TrimSpace(token.AccessToken) == "" {
		return "", errors.New("oauth token is missing access token")
	}
	refreshDue := !token.ExpiresAt.IsZero() && !time.Now().Add(oauthExpirySkew).Before(token.ExpiresAt)
	if !refreshDue {
		return token.AccessToken, nil
	}
	if strings.TrimSpace(token.RefreshToken) == "" {
		return "", errors.New("oauth token expired and no refresh token is available")
	}
	fresh, err := s.refreshAccessToken(ctx, cfg, token.RefreshToken)
	if err != nil {
		return "", err
	}
	next := oauthTokenRecord{
		ProviderID:       token.ProviderID,
		UserID:           token.UserID,
		AccessToken:      fresh.AccessToken,
		RefreshToken:     firstNonEmpty(fresh.RefreshToken, token.RefreshToken),
		ExpiresAt:        expiresAtFromNow(fresh.ExpiresIn),
		Scope:            firstNonEmpty(fresh.Scope, token.Scope),
		TokenType:        firstNonEmpty(fresh.TokenType, token.TokenType),
		State:            token.State,
		PKCECodeVerifier: token.PKCECodeVerifier,
		Metadata:         token.Metadata,
	}
	if err := s.saveUserOAuthToken(ctx, token.ProviderID, token.UserID, next); err != nil {
		return "", err
	}
	return next.AccessToken, nil
}
// loadOAuthProvider fetches the provider with the given ID and checks that it
// supports OAuth. It fails on a malformed ID, a lookup error (wrapped as
// "get provider"), or a client type without an OAuth flow.
func (s *Service) loadOAuthProvider(ctx context.Context, providerID string) (sqlc.Provider, error) {
	var none sqlc.Provider
	id, err := db.ParseUUID(providerID)
	if err != nil {
		return none, err
	}
	provider, err := s.queries.GetProviderByID(ctx, id)
	switch {
	case err != nil:
		return none, fmt.Errorf("get provider: %w", err)
	case !supportsOAuth(provider):
		return none, errors.New("provider does not support oauth")
	}
	return provider, nil
}
// getOAuthToken loads the provider-wide token row for providerID. Query errors,
// including "no rows", are returned unwrapped so callers can test for
// pgx.ErrNoRows.
func (s *Service) getOAuthToken(ctx context.Context, providerID string) (*oauthTokenRecord, error) {
	id, parseErr := db.ParseUUID(providerID)
	if parseErr != nil {
		return nil, parseErr
	}
	stored, queryErr := s.queries.GetProviderOAuthTokenByProvider(ctx, id)
	if queryErr != nil {
		return nil, queryErr
	}
	return toProviderOAuthToken(stored), nil
}
// getOAuthTokenByState loads the provider-wide token row whose pending
// authorization state equals state. The query error is returned unwrapped.
func (s *Service) getOAuthTokenByState(ctx context.Context, state string) (*oauthTokenRecord, error) {
	stored, queryErr := s.queries.GetProviderOAuthTokenByState(ctx, state)
	if queryErr != nil {
		return nil, queryErr
	}
	return toProviderOAuthToken(stored), nil
}
// updateOAuthState records a pending browser authorization (state and PKCE
// verifier) on the provider-wide token row.
func (s *Service) updateOAuthState(ctx context.Context, providerID, state, codeVerifier string) error {
	id, err := db.ParseUUID(providerID)
	if err != nil {
		return err
	}
	params := sqlc.UpdateProviderOAuthStateParams{
		ProviderID:       id,
		State:            state,
		PkceCodeVerifier: codeVerifier,
	}
	return s.queries.UpdateProviderOAuthState(ctx, params)
}
// saveOAuthToken upserts the provider-wide token row for providerID from token.
// A zero ExpiresAt is written as an invalid (NULL) timestamp. token.UserID and
// token.Metadata are not persisted on this path.
func (s *Service) saveOAuthToken(ctx context.Context, providerID string, token oauthTokenRecord) error {
	id, err := db.ParseUUID(providerID)
	if err != nil {
		return err
	}
	params := sqlc.UpsertProviderOAuthTokenParams{
		ProviderID:       id,
		AccessToken:      token.AccessToken,
		RefreshToken:     token.RefreshToken,
		Scope:            token.Scope,
		TokenType:        token.TokenType,
		State:            token.State,
		PkceCodeVerifier: token.PKCECodeVerifier,
	}
	if !token.ExpiresAt.IsZero() {
		params.ExpiresAt = pgtype.Timestamptz{Time: token.ExpiresAt, Valid: true}
	}
	if _, err := s.queries.UpsertProviderOAuthToken(ctx, params); err != nil {
		return err
	}
	return nil
}
// getUserOAuthToken loads the token row belonging to userID for providerID.
// The provider ID is validated before the user ID. Query errors, including
// "no rows", are returned unwrapped.
func (s *Service) getUserOAuthToken(ctx context.Context, providerID, userID string) (*oauthTokenRecord, error) {
	pid, err := db.ParseUUID(providerID)
	if err != nil {
		return nil, err
	}
	uid, err := db.ParseUUID(userID)
	if err != nil {
		return nil, err
	}
	key := sqlc.GetUserProviderOAuthTokenParams{ProviderID: pid, UserID: uid}
	stored, err := s.queries.GetUserProviderOAuthToken(ctx, key)
	if err != nil {
		return nil, err
	}
	return toUserProviderOAuthToken(stored), nil
}
// getUserOAuthTokenByState loads the per-user token row whose pending
// authorization state equals state. The query error is returned unwrapped so
// callers can test for pgx.ErrNoRows.
func (s *Service) getUserOAuthTokenByState(ctx context.Context, state string) (*oauthTokenRecord, error) {
	stored, queryErr := s.queries.GetUserProviderOAuthTokenByState(ctx, state)
	if queryErr != nil {
		return nil, queryErr
	}
	return toUserProviderOAuthToken(stored), nil
}
// updateUserOAuthState writes state, PKCE verifier and metadata on the user's
// token row for providerID. The metadata map is serialized with metadataJSON.
func (s *Service) updateUserOAuthState(ctx context.Context, providerID, userID, state, codeVerifier string, metadata map[string]any) error {
	pid, err := db.ParseUUID(providerID)
	if err != nil {
		return err
	}
	uid, err := db.ParseUUID(userID)
	if err != nil {
		return err
	}
	params := sqlc.UpdateUserProviderOAuthStateParams{
		ProviderID:       pid,
		UserID:           uid,
		State:            state,
		PkceCodeVerifier: codeVerifier,
		Metadata:         metadataJSON(metadata),
	}
	return s.queries.UpdateUserProviderOAuthState(ctx, params)
}
// saveUserOAuthToken upserts the user's token row for providerID from token,
// including its metadata. A zero ExpiresAt is written as an invalid (NULL)
// timestamp.
func (s *Service) saveUserOAuthToken(ctx context.Context, providerID, userID string, token oauthTokenRecord) error {
	pid, err := db.ParseUUID(providerID)
	if err != nil {
		return err
	}
	uid, err := db.ParseUUID(userID)
	if err != nil {
		return err
	}
	params := sqlc.UpsertUserProviderOAuthTokenParams{
		ProviderID:       pid,
		UserID:           uid,
		AccessToken:      token.AccessToken,
		RefreshToken:     token.RefreshToken,
		Scope:            token.Scope,
		TokenType:        token.TokenType,
		State:            token.State,
		PkceCodeVerifier: token.PKCECodeVerifier,
		Metadata:         metadataJSON(token.Metadata),
	}
	if !token.ExpiresAt.IsZero() {
		params.ExpiresAt = pgtype.Timestamptz{Time: token.ExpiresAt, Valid: true}
	}
	if _, err := s.queries.UpsertUserProviderOAuthToken(ctx, params); err != nil {
		return err
	}
	return nil
}
// deleteUserOAuthToken removes the user's token row for providerID.
func (s *Service) deleteUserOAuthToken(ctx context.Context, providerID, userID string) error {
	pid, err := db.ParseUUID(providerID)
	if err != nil {
		return err
	}
	uid, err := db.ParseUUID(userID)
	if err != nil {
		return err
	}
	key := sqlc.DeleteUserProviderOAuthTokenParams{ProviderID: pid, UserID: uid}
	return s.queries.DeleteUserProviderOAuthToken(ctx, key)
}
// toProviderOAuthToken converts a provider-wide token row into an
// oauthTokenRecord. A NULL expiry becomes the zero time, and Metadata is always
// an empty map because this row carries none.
func toProviderOAuthToken(row sqlc.ProviderOauthToken) *oauthTokenRecord {
	var expiresAt time.Time
	if row.ExpiresAt.Valid {
		expiresAt = row.ExpiresAt.Time
	}
	return &oauthTokenRecord{
		ProviderID:       row.ProviderID.String(),
		AccessToken:      row.AccessToken,
		RefreshToken:     row.RefreshToken,
		ExpiresAt:        expiresAt,
		Scope:            row.Scope,
		TokenType:        row.TokenType,
		State:            row.State,
		PKCECodeVerifier: row.PkceCodeVerifier,
		Metadata:         map[string]any{},
	}
}
// toUserProviderOAuthToken converts a per-user token row into an
// oauthTokenRecord. A NULL expiry becomes the zero time, and the JSON metadata
// column is decoded leniently via providerMetadata (never nil).
func toUserProviderOAuthToken(row sqlc.UserProviderOauthToken) *oauthTokenRecord {
	var expiresAt time.Time
	if row.ExpiresAt.Valid {
		expiresAt = row.ExpiresAt.Time
	}
	return &oauthTokenRecord{
		ProviderID:       row.ProviderID.String(),
		UserID:           row.UserID.String(),
		AccessToken:      row.AccessToken,
		RefreshToken:     row.RefreshToken,
		ExpiresAt:        expiresAt,
		Scope:            row.Scope,
		TokenType:        row.TokenType,
		State:            row.State,
		PKCECodeVerifier: row.PkceCodeVerifier,
		Metadata:         providerMetadata(row.Metadata),
	}
}
// oauthTokenResponse is the JSON body of a token-endpoint response; in this
// part of the file it is returned by exchangeDeviceCode. Error/Description
// carry OAuth error codes such as "authorization_pending" or "slow_down".
type oauthTokenResponse struct {
AccessToken string `json:"access_token"` //nolint:gosec // OAuth response payload carries runtime access token
RefreshToken string `json:"refresh_token"` //nolint:gosec // OAuth response payload carries runtime refresh token
TokenType string `json:"token_type"`
Scope string `json:"scope"`
// ExpiresIn is passed to expiresAtFromNow; presumably seconds — confirm.
ExpiresIn int64 `json:"expires_in"`
// Interval is read on "slow_down" responses as the new polling interval.
Interval int64 `json:"interval"`
Error string `json:"error"`
Description string `json:"error_description"`
}
// oauthDeviceMetadata is a pending GitHub device authorization as kept in user
// token metadata (see toMetadata / deviceMetadataFromMap). An empty DeviceCode
// means nothing is pending.
type oauthDeviceMetadata struct {
DeviceCode string
UserCode string
VerificationURI string
// ExpiresAt is the zero time when unknown.
ExpiresAt time.Time
IntervalSeconds int64
}
// oauthAccountMetadata is the GitHub account profile cached in user token
// metadata. Label is the display name shown to the user.
type oauthAccountMetadata struct {
Label string
Login string
Name string
Email string
AvatarURL string
ProfileURL string
}
// deviceMetadataFromMap reads a pending device authorization from user token
// metadata. String fields are trimmed. An expiry that is missing or not valid
// RFC 3339 leaves ExpiresAt at the zero time.
func deviceMetadataFromMap(metadata map[string]any) oauthDeviceMetadata {
	text := func(key string) string {
		return strings.TrimSpace(stringValue(metadata, key))
	}
	var expiresAt time.Time
	if raw := text(metadataDeviceExpiresAtKey); raw != "" {
		if parsed, err := time.Parse(time.RFC3339, raw); err == nil {
			expiresAt = parsed
		}
	}
	return oauthDeviceMetadata{
		DeviceCode:      text(metadataDeviceCodeKey),
		UserCode:        text(metadataDeviceUserCodeKey),
		VerificationURI: text(metadataDeviceVerifyURIKey),
		ExpiresAt:       expiresAt,
		IntervalSeconds: int64Value(metadata, metadataDeviceIntervalKey),
	}
}
// toMetadata encodes the pending device authorization for storage in user
// token metadata. It returns nil when there is no device code. The expiry is
// written as RFC 3339 in UTC and omitted when unknown.
func (d oauthDeviceMetadata) toMetadata() map[string]any {
	if strings.TrimSpace(d.DeviceCode) == "" {
		return nil
	}
	out := make(map[string]any, 5)
	out[metadataDeviceCodeKey] = d.DeviceCode
	out[metadataDeviceUserCodeKey] = d.UserCode
	out[metadataDeviceVerifyURIKey] = d.VerificationURI
	out[metadataDeviceIntervalKey] = d.IntervalSeconds
	if !d.ExpiresAt.IsZero() {
		out[metadataDeviceExpiresAtKey] = d.ExpiresAt.UTC().Format(time.RFC3339)
	}
	return out
}
// toStatus converts the pending device authorization to its API shape, or
// returns nil when no device code is pending. Pending is always true in a
// non-nil result, and ExpiresAt is nil when the expiry is unknown.
func (d oauthDeviceMetadata) toStatus() *OAuthDeviceStatus {
	if strings.TrimSpace(d.DeviceCode) == "" {
		return nil
	}
	var expiresAt *time.Time
	if !d.ExpiresAt.IsZero() {
		at := d.ExpiresAt
		expiresAt = &at
	}
	return &OAuthDeviceStatus{
		Pending:         true,
		UserCode:        d.UserCode,
		VerificationURI: d.VerificationURI,
		ExpiresAt:       expiresAt,
		IntervalSeconds: d.IntervalSeconds,
	}
}
// accountMetadataFromMap reads the cached GitHub account profile from user
// token metadata, trimming every field.
//
// When no label was stored, one is derived with the same precedence that
// fetchGitHubOAuthAccount uses when it builds a label: e-mail, then name, then
// login. A profile therefore displays the same way whether its label was cached
// or reconstructed.
func accountMetadataFromMap(metadata map[string]any) oauthAccountMetadata {
	account := oauthAccountMetadata{
		Label:      strings.TrimSpace(stringValue(metadata, metadataAccountLabelKey)),
		Login:      strings.TrimSpace(stringValue(metadata, metadataAccountLoginKey)),
		Name:       strings.TrimSpace(stringValue(metadata, metadataAccountNameKey)),
		Email:      strings.TrimSpace(stringValue(metadata, metadataAccountEmailKey)),
		AvatarURL:  strings.TrimSpace(stringValue(metadata, metadataAccountAvatarURLKey)),
		ProfileURL: strings.TrimSpace(stringValue(metadata, metadataAccountProfileURLKey)),
	}
	if account.Label == "" {
		account.Label = firstNonEmpty(account.Email, account.Name, account.Login)
	}
	return account
}
// toMetadata encodes the account profile for storage in user token metadata.
// Only non-empty fields are written. The result is never nil; an empty profile
// yields an empty map.
func (a oauthAccountMetadata) toMetadata() map[string]any {
	out := map[string]any{}
	fields := []struct{ key, value string }{
		{metadataAccountLabelKey, a.Label},
		{metadataAccountLoginKey, a.Login},
		{metadataAccountNameKey, a.Name},
		{metadataAccountEmailKey, a.Email},
		{metadataAccountAvatarURLKey, a.AvatarURL},
		{metadataAccountProfileURLKey, a.ProfileURL},
	}
	for _, f := range fields {
		if f.value != "" {
			out[f.key] = f.value
		}
	}
	return out
}
// toStatus converts the cached profile into its API shape, or returns nil when
// every profile field is empty.
func (a oauthAccountMetadata) toStatus() *OAuthAccount {
	if a.isZero() {
		return nil
	}
	account := OAuthAccount{
		Label:      a.Label,
		Login:      a.Login,
		Name:       a.Name,
		Email:      a.Email,
		AvatarURL:  a.AvatarURL,
		ProfileURL: a.ProfileURL,
	}
	return &account
}
// isZero reports whether every profile field is empty. All fields are strings,
// so comparing against the zero value is the same as checking each one.
func (a oauthAccountMetadata) isZero() bool {
	return a == (oauthAccountMetadata{})
}
// resolveGitHubOAuthAccount returns the GitHub account profile for a user
// token. The profile cached in token.Metadata is used when present. Otherwise,
// if the token has an access token, the profile is fetched from GitHub and
// persisted on the user's token row.
//
// Fetch failures are logged and yield (nil, nil). Only a failure to persist
// the fetched profile is returned as an error.
//
// The fetched profile is merged into the existing metadata rather than
// replacing it. The same map may hold a pending device authorization (for
// example while an already-connected user re-authorizes), and that must
// survive. A new map is built so the caller's token.Metadata is not mutated.
func (s *Service) resolveGitHubOAuthAccount(ctx context.Context, providerID, userID string, token *oauthTokenRecord) (*OAuthAccount, error) {
	if cached := accountMetadataFromMap(token.Metadata).toStatus(); cached != nil {
		return cached, nil
	}
	if strings.TrimSpace(token.AccessToken) == "" {
		return nil, nil
	}
	fetched, err := s.fetchGitHubOAuthAccount(ctx, token.AccessToken)
	if err != nil {
		s.logger.Warn("refresh github oauth account metadata failed", oauthLogAttrs(providerID, userID, err)...)
		return nil, nil
	}
	accountFields := fetched.toMetadata()
	merged := make(map[string]any, len(token.Metadata)+len(accountFields))
	for k, v := range token.Metadata {
		merged[k] = v
	}
	for k, v := range accountFields {
		merged[k] = v
	}
	updatedToken := *token
	updatedToken.Metadata = merged
	if err := s.saveUserOAuthToken(ctx, providerID, userID, updatedToken); err != nil {
		return nil, err
	}
	return fetched.toStatus(), nil
}
// fetchGitHubOAuthAccount loads the profile of the account owning accessToken
// from the GitHub REST API (GET defaultGitHubUserURL).
//
// When the profile carries no e-mail, it falls back to
// fetchGitHubPrimaryEmail; a failure of that fallback is only logged. The
// label is the first non-empty of e-mail, name and login, and an error is
// returned if all three are empty.
//
// The response body is read through a size cap so an unexpected or oversized
// response cannot exhaust memory; anything larger fails to decode.
func (s *Service) fetchGitHubOAuthAccount(ctx context.Context, accessToken string) (oauthAccountMetadata, error) {
	const maxResponseBytes = 1 << 20 // 1 MiB, far above a normal profile payload

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, defaultGitHubUserURL, nil)
	if err != nil {
		return oauthAccountMetadata{}, fmt.Errorf("create github oauth account request: %w", err)
	}
	req.Header.Set("Accept", "application/vnd.github+json")
	req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(accessToken))
	req.Header.Set("X-GitHub-Api-Version", "2022-11-28")
	resp, err := s.httpClient.Do(req) //nolint:gosec // Request targets a fixed GitHub API endpoint.
	if err != nil {
		return oauthAccountMetadata{}, fmt.Errorf("execute github oauth account request: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()
	payload, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
	if err != nil {
		return oauthAccountMetadata{}, fmt.Errorf("read github oauth account response: %w", err)
	}
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return oauthAccountMetadata{}, fmt.Errorf("github oauth account request failed: %s", strings.TrimSpace(string(payload)))
	}
	var profile struct {
		Login     string `json:"login"`
		Name      string `json:"name"`
		Email     string `json:"email"`
		AvatarURL string `json:"avatar_url"`
		HTMLURL   string `json:"html_url"`
	}
	if err := json.Unmarshal(payload, &profile); err != nil {
		return oauthAccountMetadata{}, fmt.Errorf("decode github oauth account response: %w", err)
	}
	account := oauthAccountMetadata{
		Login:      strings.TrimSpace(profile.Login),
		Name:       strings.TrimSpace(profile.Name),
		Email:      strings.TrimSpace(profile.Email),
		AvatarURL:  strings.TrimSpace(profile.AvatarURL),
		ProfileURL: strings.TrimSpace(profile.HTMLURL),
	}
	if account.Email == "" {
		// Best effort: the profile is still usable without an e-mail.
		email, err := s.fetchGitHubPrimaryEmail(ctx, accessToken)
		if err != nil {
			s.logger.Warn("fetch github oauth primary email failed", slog.Any("error", err))
		} else {
			account.Email = email
		}
	}
	account.Label = firstNonEmpty(account.Email, account.Name, account.Login)
	if account.Label == "" {
		return oauthAccountMetadata{}, errors.New("github oauth account response did not include a usable account label")
	}
	return account, nil
}
// fetchGitHubPrimaryEmail lists the e-mail addresses of the account owning
// accessToken (GET defaultGitHubUserEmailsURL) and picks one. The preference
// order is: primary and verified, then primary, then the first non-blank
// address. It errors when the list has no usable address.
//
// The response body is read through a size cap so an oversized response
// cannot exhaust memory.
func (s *Service) fetchGitHubPrimaryEmail(ctx context.Context, accessToken string) (string, error) {
	const maxResponseBytes = 1 << 20 // 1 MiB, far above a normal e-mail list

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, defaultGitHubUserEmailsURL, nil)
	if err != nil {
		return "", fmt.Errorf("create github oauth emails request: %w", err)
	}
	req.Header.Set("Accept", "application/vnd.github+json")
	req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(accessToken))
	req.Header.Set("X-GitHub-Api-Version", "2022-11-28")
	resp, err := s.httpClient.Do(req) //nolint:gosec // Request targets a fixed GitHub API endpoint.
	if err != nil {
		return "", fmt.Errorf("execute github oauth emails request: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()
	payload, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
	if err != nil {
		return "", fmt.Errorf("read github oauth emails response: %w", err)
	}
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return "", fmt.Errorf("github oauth emails request failed: %s", strings.TrimSpace(string(payload)))
	}
	var emails []struct {
		Email    string `json:"email"`
		Primary  bool   `json:"primary"`
		Verified bool   `json:"verified"`
	}
	if err := json.Unmarshal(payload, &emails); err != nil {
		return "", fmt.Errorf("decode github oauth emails response: %w", err)
	}
	for _, candidate := range emails {
		email := strings.TrimSpace(candidate.Email)
		if candidate.Primary && candidate.Verified && email != "" {
			return email, nil
		}
	}
	for _, candidate := range emails {
		email := strings.TrimSpace(candidate.Email)
		if candidate.Primary && email != "" {
			return email, nil
		}
	}
	for _, candidate := range emails {
		email := strings.TrimSpace(candidate.Email)
		if email != "" {
			return email, nil
		}
	}
	return "", errors.New("github oauth emails response did not include a usable email")
}
// requestDeviceAuthorization begins an OAuth device authorization flow. It
// POSTs client_id (plus scope, when cfg.Scopes is set) as a form to
// cfg.DeviceCodeURL. That URL must first pass validateOAuthTokenURL for
// cfg.ClientType.
//
// It fails when:
//   - the request cannot be built or sent, or the body cannot be read;
//   - the status is outside 2xx (the message carries the trimmed raw body);
//   - the body is not valid JSON;
//   - the body carries an OAuth error (the message prefers Description over Error);
//   - any of DeviceCode, UserCode, or VerificationURI is blank.
//
// A missing or non-positive Interval is replaced with 5. The unit is
// presumably seconds, per RFC 8628; confirm against deviceAuthorizationResponse.
func (s *Service) requestDeviceAuthorization(ctx context.Context, cfg oauthConfig) (*deviceAuthorizationResponse, error) {
	if invalid := validateOAuthTokenURL(cfg.ClientType, cfg.DeviceCodeURL); invalid != nil {
		return nil, invalid
	}

	form := url.Values{}
	form.Set("client_id", cfg.ClientID)
	if scopes := cfg.Scopes; scopes != "" {
		form.Set("scope", scopes)
	}

	request, buildErr := http.NewRequestWithContext(ctx, http.MethodPost, cfg.DeviceCodeURL, strings.NewReader(form.Encode()))
	if buildErr != nil {
		return nil, fmt.Errorf("create oauth device request: %w", buildErr)
	}
	request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	request.Header.Set("Accept", "application/json")

	response, sendErr := s.httpClient.Do(request) //nolint:gosec // URL is validated by validateOAuthTokenURL before request execution.
	if sendErr != nil {
		return nil, fmt.Errorf("execute oauth device request: %w", sendErr)
	}
	defer func() { _ = response.Body.Close() }()

	raw, readErr := io.ReadAll(response.Body)
	if readErr != nil {
		return nil, fmt.Errorf("read oauth device response: %w", readErr)
	}
	if response.StatusCode < 200 || response.StatusCode > 299 {
		return nil, fmt.Errorf("oauth device request failed: %s", strings.TrimSpace(string(raw)))
	}

	authorization := new(deviceAuthorizationResponse)
	if decodeErr := json.Unmarshal(raw, authorization); decodeErr != nil {
		return nil, fmt.Errorf("decode oauth device response: %w", decodeErr)
	}
	if authorization.Error != "" {
		return nil, fmt.Errorf("oauth device request failed: %s", firstNonEmpty(authorization.Description, authorization.Error))
	}

	incomplete := strings.TrimSpace(authorization.DeviceCode) == "" ||
		strings.TrimSpace(authorization.UserCode) == "" ||
		strings.TrimSpace(authorization.VerificationURI) == ""
	if incomplete {
		return nil, errors.New("oauth device request returned incomplete device authorization data")
	}
	if authorization.Interval <= 0 {
		authorization.Interval = 5
	}
	return authorization, nil
}
// exchangeDeviceCode makes one token request for a pending device
// authorization, using the urn:ietf:params:oauth:grant-type:device_code grant
// against cfg.TokenURL. That URL must pass validateOAuthTokenURL for
// cfg.ClientType.
//
// OAuth error answers are NOT turned into Go errors. Whenever the body decodes
// as JSON with a non-empty Error field, the decoded response is returned with a
// nil error, whatever the HTTP status. The polling caller can then inspect
// codes such as authorization_pending or slow_down (RFC 8628 §3.5).
//
// Otherwise:
//   - a non-2xx status is reported with the trimmed raw body, including when
//     the body is not JSON (e.g. an HTML gateway page), instead of a less
//     useful JSON syntax error;
//   - a 2xx body that is not JSON is reported as a decode error.
//
// A missing or non-positive Interval on a decoded response is normalized to 5.
func (s *Service) exchangeDeviceCode(ctx context.Context, cfg oauthConfig, deviceCode string) (*oauthTokenResponse, error) {
	if err := validateOAuthTokenURL(cfg.ClientType, cfg.TokenURL); err != nil {
		return nil, err
	}
	values := url.Values{
		"client_id":   {cfg.ClientID},
		"device_code": {deviceCode},
		"grant_type":  {"urn:ietf:params:oauth:grant-type:device_code"},
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, cfg.TokenURL, strings.NewReader(values.Encode()))
	if err != nil {
		return nil, fmt.Errorf("create oauth device token request: %w", err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("Accept", "application/json")
	resp, err := s.httpClient.Do(req) //nolint:gosec // URL is validated by validateOAuthTokenURL before request execution.
	if err != nil {
		return nil, fmt.Errorf("execute oauth device token request: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()
	payload, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("read oauth device token response: %w", err)
	}

	var tokenResp oauthTokenResponse
	decodeErr := json.Unmarshal(payload, &tokenResp)
	if decodeErr == nil && tokenResp.Interval <= 0 {
		tokenResp.Interval = 5
	}

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		// Structured OAuth errors go back to the polling loop unchanged.
		if decodeErr == nil && tokenResp.Error != "" {
			return &tokenResp, nil
		}
		// Anything else, including non-JSON bodies, surfaces the server's payload.
		return nil, fmt.Errorf("oauth device token request failed: %s", strings.TrimSpace(string(payload)))
	}
	if decodeErr != nil {
		return nil, fmt.Errorf("decode oauth device token response: %w", decodeErr)
	}
	return &tokenResp, nil
}
// exchangeCode redeems an authorization code for tokens at cfg.TokenURL via
// postTokenRequest.
//
// grant_type=authorization_code is always sent, because RFC 6749 §4.1.3 makes
// it REQUIRED for this exchange regardless of PKCE. Previously it was only
// sent when cfg.UsePKCE was set, which produced a malformed request otherwise.
//
// When cfg.UsePKCE is set, codeVerifier is also sent as code_verifier
// (RFC 7636 §4.5). client_secret is included only when cfg.ClientSecret is
// non-empty.
func (s *Service) exchangeCode(ctx context.Context, cfg oauthConfig, code, codeVerifier string) (*oauthTokenResponse, error) {
	values := url.Values{
		"grant_type":   {"authorization_code"},
		"code":         {code},
		"client_id":    {cfg.ClientID},
		"redirect_uri": {cfg.RedirectURI},
	}
	if cfg.UsePKCE {
		values.Set("code_verifier", codeVerifier)
	}
	if cfg.ClientSecret != "" {
		values.Set("client_secret", cfg.ClientSecret)
	}
	return s.postTokenRequest(ctx, cfg, values)
}
// refreshAccessToken exchanges refreshToken for a fresh token set using the
// refresh_token grant against cfg.TokenURL. client_secret is attached only
// when one is configured. URL validation, transport, status, and OAuth error
// handling are all delegated to postTokenRequest.
func (s *Service) refreshAccessToken(ctx context.Context, cfg oauthConfig, refreshToken string) (*oauthTokenResponse, error) {
	form := make(url.Values, 4)
	form.Set("grant_type", "refresh_token")
	form.Set("refresh_token", refreshToken)
	form.Set("client_id", cfg.ClientID)
	if secret := cfg.ClientSecret; secret != "" {
		form.Set("client_secret", secret)
	}
	return s.postTokenRequest(ctx, cfg, form)
}
// postTokenRequest POSTs body, form-encoded, to cfg.TokenURL and decodes the
// JSON reply into an oauthTokenResponse. The URL must first pass
// validateOAuthTokenURL for cfg.ClientType.
//
// An error is returned for:
//   - a rejected token URL;
//   - request construction or transport failures, or a body read failure;
//   - any non-2xx status (the message carries the trimmed raw body);
//   - a body that is not valid JSON;
//   - a 2xx reply whose Error field is set (the message prefers Description).
func (s *Service) postTokenRequest(ctx context.Context, cfg oauthConfig, body url.Values) (*oauthTokenResponse, error) {
	if invalid := validateOAuthTokenURL(cfg.ClientType, cfg.TokenURL); invalid != nil {
		return nil, invalid
	}

	encodedForm := strings.NewReader(body.Encode())
	request, buildErr := http.NewRequestWithContext(ctx, http.MethodPost, cfg.TokenURL, encodedForm)
	if buildErr != nil {
		return nil, fmt.Errorf("create oauth request: %w", buildErr)
	}
	request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	request.Header.Set("Accept", "application/json")

	response, sendErr := s.httpClient.Do(request) //nolint:gosec // URL is validated by validateOAuthTokenURL before request execution.
	if sendErr != nil {
		return nil, fmt.Errorf("execute oauth request: %w", sendErr)
	}
	defer func() { _ = response.Body.Close() }()

	raw, readErr := io.ReadAll(response.Body)
	if readErr != nil {
		return nil, fmt.Errorf("read oauth response: %w", readErr)
	}
	if response.StatusCode < 200 || response.StatusCode > 299 {
		return nil, fmt.Errorf("oauth token request failed: %s", strings.TrimSpace(string(raw)))
	}

	decoded := new(oauthTokenResponse)
	if decodeErr := json.Unmarshal(raw, decoded); decodeErr != nil {
		return nil, fmt.Errorf("decode oauth response: %w", decodeErr)
	}
	if decoded.Error != "" {
		return nil, fmt.Errorf("oauth token request failed: %s", firstNonEmpty(decoded.Description, decoded.Error))
	}
	return decoded, nil
}
// validateOAuthTokenURL restricts where outbound OAuth requests may go. raw,
// after whitespace trimming, must parse as a URL with an https scheme
// (case-insensitive) whose hostname matches the one host allowed for
// clientType:
//   - auth.openai.com for OpenAI Codex;
//   - github.com for GitHub Copilot.
//
// Comparison is case-insensitive. Any other client type is rejected. Only the
// scheme and hostname are checked; port, path, and query are not.
func validateOAuthTokenURL(clientType models.ClientType, raw string) error {
	target, parseErr := url.Parse(strings.TrimSpace(raw))
	if parseErr != nil {
		return fmt.Errorf("invalid oauth token url: %w", parseErr)
	}
	if !strings.EqualFold(target.Scheme, "https") {
		return errors.New("oauth token url must use https")
	}

	host := target.Hostname()
	switch clientType {
	case models.ClientTypeOpenAICodex:
		if strings.EqualFold(host, "auth.openai.com") {
			return nil
		}
		return errors.New("oauth token url host must be auth.openai.com")
	case models.ClientTypeGitHubCopilot:
		if strings.EqualFold(host, "github.com") {
			return nil
		}
		return errors.New("oauth token url host must be github.com")
	default:
		return errors.New("unsupported oauth client type")
	}
}
// stringValue returns input[key] when it holds a string. It returns "" when
// the map is nil, the key is absent, or the value has any other dynamic type.
// Indexing a nil map is safe in Go, so no separate nil check is needed.
func stringValue(input map[string]any, key string) string {
	if text, ok := input[key].(string); ok {
		return text
	}
	return ""
}
// int64Value reads input[key] as an int64. It accepts three dynamic types:
//   - int64, returned as-is;
//   - int, converted;
//   - float64 (what encoding/json produces for numbers in map[string]any),
//     truncated toward zero by the conversion.
//
// A nil map, a missing key, or any other dynamic type yields 0.
func int64Value(input map[string]any, key string) int64 {
	switch number := input[key].(type) {
	case float64:
		return int64(number)
	case int:
		return int64(number)
	case int64:
		return number
	}
	return 0
}
// metadataJSON encodes metadata as a JSON object. Nil or empty maps, and maps
// that encoding/json cannot marshal (for example ones holding funcs or
// channels), produce the literal "{}" instead of an error. This is
// best-effort: marshal failures are deliberately swallowed.
func metadataJSON(metadata map[string]any) []byte {
	const emptyObject = "{}"
	if len(metadata) > 0 {
		if encoded, err := json.Marshal(metadata); err == nil {
			return encoded
		}
	}
	return []byte(emptyObject)
}
// generateState returns 32 random bytes from crypto/rand, hex-encoded as 64
// lowercase hex characters. The name suggests it is used as the OAuth "state"
// parameter. Any error from crypto/rand is returned unwrapped.
func generateState() (string, error) {
	var entropy [32]byte
	if _, err := rand.Read(entropy[:]); err != nil {
		return "", err
	}
	return hex.EncodeToString(entropy[:]), nil
}
// generateCodeVerifier returns 32 random bytes from crypto/rand, encoded with
// unpadded base64url. The result is 43 characters drawn from [A-Za-z0-9-_],
// which satisfies the 43–128 character bound RFC 7636 §4.1 places on a PKCE
// code_verifier. Any error from crypto/rand is returned unwrapped.
func generateCodeVerifier() (string, error) {
	var entropy [32]byte
	if _, err := rand.Read(entropy[:]); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(entropy[:]), nil
}
// computeCodeChallenge derives the PKCE S256 code_challenge for verifier:
// the unpadded base64url encoding of SHA-256 over the verifier's bytes
// (RFC 7636 §4.2).
func computeCodeChallenge(verifier string) string {
	digest := sha256.New()
	_, _ = digest.Write([]byte(verifier)) // hash.Hash.Write never returns an error
	return base64.RawURLEncoding.EncodeToString(digest.Sum(nil))
}
// expiresAtFromNow converts a relative lifetime in seconds (such as an OAuth
// expires_in value) into an absolute time measured from time.Now().
//
// A non-positive expiresIn yields the zero time.Time, which callers can detect
// with IsZero. Values too large to express as a time.Duration are clamped to
// the largest whole number of seconds that fits (about 292 years). Without the
// clamp, the multiplication would overflow int64 and wrap, which can produce
// an expiry in the past (e.g. MaxInt64 seconds wraps to -1s).
func expiresAtFromNow(expiresIn int64) time.Time {
	if expiresIn <= 0 {
		return time.Time{}
	}
	// time.Duration is int64 nanoseconds; expiresIn*time.Second overflows
	// above this bound.
	const maxExpiresIn = int64(time.Duration(1<<63-1) / time.Second)
	if expiresIn > maxExpiresIn {
		expiresIn = maxExpiresIn
	}
	return time.Now().Add(time.Duration(expiresIn) * time.Second)
}
// firstNonEmpty returns the first argument containing at least one
// non-whitespace character. The value is returned as given, so its surrounding
// whitespace is not trimmed. It returns "" when no arguments are passed or
// all of them are blank.
func firstNonEmpty(values ...string) string {
	for i := range values {
		if candidate := values[i]; strings.TrimSpace(candidate) != "" {
			return candidate
		}
	}
	return ""
}