mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
1312 lines
41 KiB
Go
1312 lines
41 KiB
Go
package providers
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/pgtype"
|
|
|
|
memohcopilot "github.com/memohai/memoh/internal/copilot"
|
|
"github.com/memohai/memoh/internal/db"
|
|
"github.com/memohai/memoh/internal/db/sqlc"
|
|
"github.com/memohai/memoh/internal/models"
|
|
"github.com/memohai/memoh/internal/oauthctx"
|
|
)
|
|
|
|
// OAuth defaults, timing values and metadata key names used by the provider
// OAuth flows in this file.
const (
	// OpenAI Codex authorization-code flow defaults.
	defaultOpenAICodexClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
	defaultOpenAIAuthorizeURL  = "https://auth.openai.com/oauth/authorize"
	defaultOpenAITokenURL      = "https://auth.openai.com/oauth/token" //nolint:gosec // OAuth endpoint URL, not a credential
	defaultOpenAICallbackURL   = "http://localhost:1455/auth/callback"
	defaultOpenAIOAuthScopes   = "openid profile email offline_access"

	// GitHub device flow and account API defaults.
	defaultGitHubDeviceCodeURL = "https://github.com/login/device/code"        //nolint:gosec // OAuth endpoint URL, not a credential
	defaultGitHubTokenURL      = "https://github.com/login/oauth/access_token" //nolint:gosec // OAuth endpoint URL, not a credential
	defaultGitHubUserURL       = "https://api.github.com/user"                 //nolint:gosec // OAuth endpoint URL, not a credential
	defaultGitHubUserEmailsURL = "https://api.github.com/user/emails"          //nolint:gosec // OAuth endpoint URL, not a credential

	// Timing values.
	oauthExpirySkew          = 30 * time.Second
	providerOAuthHTTPTimeout = 15 * time.Second

	// Provider metadata keys that configure the OAuth client.
	metadataOAuthClientIDKey      = "oauth_client_id"
	metadataOAuthAuthorizeURLKey  = "oauth_authorize_url"
	metadataOAuthDeviceCodeURLKey = "oauth_device_code_url"
	metadataOAuthTokenURLKey      = "oauth_token_url" //nolint:gosec // metadata key name, not a credential
	metadataOAuthRedirectURIKey   = "oauth_redirect_uri"
	metadataOAuthScopesKey        = "oauth_scopes"
	metadataOAuthAudienceKey      = "oauth_audience"
	metadataOAuthUseIDOrgsFlagKey = "oauth_id_token_add_organizations"

	// Token metadata keys describing a pending device authorization.
	metadataDeviceCodeKey      = "device_code"
	metadataDeviceUserCodeKey  = "device_user_code"
	metadataDeviceVerifyURIKey = "device_verification_uri"
	metadataDeviceIntervalKey  = "device_interval_seconds"
	metadataDeviceExpiresAtKey = "device_expires_at"

	// Token metadata keys describing the authorized account.
	metadataAccountLabelKey      = "account_label"
	metadataAccountLoginKey      = "account_login"
	metadataAccountNameKey       = "account_name"
	metadataAccountEmailKey      = "account_email"
	metadataAccountAvatarURLKey  = "account_avatar_url"
	metadataAccountProfileURLKey = "account_profile_url"

	// Configuration key naming the OAuth client secret.
	configOAuthClientSecretKey = "oauth_client_secret" //nolint:gosec // Metadata key name, not a credential literal.
)
|
|
|
|
// oauthTokenRecord is the in-memory form of a stored OAuth token row. It is
// shared by provider-wide tokens and user-scoped tokens; provider-wide records
// leave UserID empty.
type oauthTokenRecord struct {
	// ProviderID and UserID identify who the token belongs to.
	ProviderID, UserID string

	// AccessToken and RefreshToken are the credentials issued by the provider.
	AccessToken, RefreshToken string //nolint:gosec // Runtime token payload persisted encrypted at rest.

	// ExpiresAt is the access-token expiry; the zero value means none is recorded.
	ExpiresAt time.Time

	// Scope and TokenType describe the issued access token.
	Scope, TokenType string

	// State and PKCECodeVerifier belong to an in-progress authorization-code
	// handshake and are stored blank once a token has been saved.
	State, PKCECodeVerifier string

	// Metadata holds free-form data attached to the token row.
	Metadata map[string]any
}
|
|
|
|
type oauthConfig struct {
|
|
ClientType models.ClientType
|
|
ClientID string
|
|
ClientSecret string //nolint:gosec // Runtime OAuth client secret from provider metadata.
|
|
AuthorizeURL string
|
|
DeviceCodeURL string
|
|
TokenURL string
|
|
RedirectURI string
|
|
Scopes string
|
|
Audience string
|
|
UsePKCE bool
|
|
IDTokenAddOrganizations bool
|
|
}
|
|
|
|
// deviceAuthorizationResponse is the JSON body returned by a device
// authorization endpoint. Either the device fields or the error fields are
// expected to be populated.
type deviceAuthorizationResponse struct {
	// DeviceCode is the code this service later exchanges for a token.
	DeviceCode string `json:"device_code"`
	// UserCode is the code the user types at VerificationURI.
	UserCode string `json:"user_code"`
	// VerificationURI is where the user approves the request.
	VerificationURI string `json:"verification_uri"`
	// ExpiresIn is the lifetime of the codes, in seconds.
	ExpiresIn int64 `json:"expires_in"`
	// Interval is the polling interval requested by the server, in seconds.
	Interval int64 `json:"interval"`

	// Error and Description carry an OAuth error code and its explanation.
	Error       string `json:"error"`
	Description string `json:"error_description"`
}
|
|
|
|
// providerMetadata decodes a JSON object into a map. It never returns nil:
// empty input, invalid JSON, a non-object value and JSON null all yield an
// empty map.
func providerMetadata(raw []byte) map[string]any {
	var decoded map[string]any
	if len(raw) > 0 {
		if err := json.Unmarshal(raw, &decoded); err != nil {
			// Discard anything from a failed decode.
			decoded = nil
		}
	}
	if decoded == nil {
		return map[string]any{}
	}
	return decoded
}
|
|
|
|
// oauthLogAttrs assembles slog attributes for OAuth log lines. The provider
// and user IDs are included only when they are not blank, and the error only
// when it is non-nil. The returned slice is never nil.
func oauthLogAttrs(providerID, userID string, err error) []any {
	attrs := make([]any, 0, 3)
	for _, field := range [...]struct{ key, value string }{
		{"provider_id", providerID},
		{"user_id", userID},
	} {
		if strings.TrimSpace(field.value) == "" {
			continue
		}
		attrs = append(attrs, slog.String(field.key, field.value))
	}
	if err != nil {
		attrs = append(attrs, slog.Any("error", err))
	}
	return attrs
}
|
|
|
|
func (s *Service) oauthConfigForProvider(provider sqlc.Provider) oauthConfig {
|
|
metadata := providerMetadata(provider.Metadata)
|
|
|
|
switch models.ClientType(provider.ClientType) {
|
|
case models.ClientTypeGitHubCopilot:
|
|
result := oauthConfig{
|
|
ClientType: models.ClientTypeGitHubCopilot,
|
|
ClientID: memohcopilot.GitHubOAuthClientID,
|
|
DeviceCodeURL: defaultGitHubDeviceCodeURL,
|
|
TokenURL: defaultGitHubTokenURL,
|
|
Scopes: memohcopilot.GitHubOAuthScope,
|
|
}
|
|
if v := strings.TrimSpace(stringValue(metadata, metadataOAuthDeviceCodeURLKey)); v != "" {
|
|
result.DeviceCodeURL = v
|
|
}
|
|
if v := strings.TrimSpace(stringValue(metadata, metadataOAuthTokenURLKey)); v != "" {
|
|
result.TokenURL = v
|
|
}
|
|
return result
|
|
|
|
default:
|
|
result := oauthConfig{
|
|
ClientType: models.ClientTypeOpenAICodex,
|
|
ClientID: defaultOpenAICodexClientID,
|
|
AuthorizeURL: defaultOpenAIAuthorizeURL,
|
|
TokenURL: defaultOpenAITokenURL,
|
|
RedirectURI: firstNonEmpty(s.callbackURL, defaultOpenAICallbackURL),
|
|
Scopes: defaultOpenAIOAuthScopes,
|
|
Audience: strings.TrimSpace(stringValue(metadata, metadataOAuthAudienceKey)),
|
|
UsePKCE: true,
|
|
IDTokenAddOrganizations: true,
|
|
}
|
|
if v := strings.TrimSpace(stringValue(metadata, metadataOAuthClientIDKey)); v != "" {
|
|
result.ClientID = v
|
|
}
|
|
if v := strings.TrimSpace(stringValue(metadata, metadataOAuthAuthorizeURLKey)); v != "" {
|
|
result.AuthorizeURL = v
|
|
}
|
|
if v := strings.TrimSpace(stringValue(metadata, metadataOAuthTokenURLKey)); v != "" {
|
|
result.TokenURL = v
|
|
}
|
|
if v := strings.TrimSpace(stringValue(metadata, metadataOAuthRedirectURIKey)); v != "" {
|
|
result.RedirectURI = v
|
|
}
|
|
if v := strings.TrimSpace(stringValue(metadata, metadataOAuthScopesKey)); v != "" {
|
|
result.Scopes = v
|
|
}
|
|
if v, ok := metadata[metadataOAuthUseIDOrgsFlagKey].(bool); ok {
|
|
result.IDTokenAddOrganizations = v
|
|
}
|
|
return result
|
|
}
|
|
}
|
|
|
|
func supportsOAuth(provider sqlc.Provider) bool {
|
|
switch models.ClientType(provider.ClientType) {
|
|
case models.ClientTypeOpenAICodex, models.ClientTypeGitHubCopilot:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
func isUserScopedOAuthProvider(provider sqlc.Provider) bool {
|
|
return models.ClientType(provider.ClientType) == models.ClientTypeGitHubCopilot
|
|
}
|
|
|
|
func (s *Service) StartOAuthAuthorization(ctx context.Context, providerID string) (*OAuthAuthorizeResponse, error) {
|
|
provider, err := s.loadOAuthProvider(ctx, providerID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
cfg := s.oauthConfigForProvider(provider)
|
|
|
|
if isUserScopedOAuthProvider(provider) {
|
|
userID := oauthctx.UserIDFromContext(ctx)
|
|
if userID == "" {
|
|
return nil, errors.New("github copilot oauth requires a current user")
|
|
}
|
|
device, err := s.startGitHubDeviceAuthorization(ctx, providerID, userID, cfg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &OAuthAuthorizeResponse{
|
|
Mode: "device",
|
|
Device: device,
|
|
}, nil
|
|
}
|
|
|
|
state, err := generateState()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("generate state: %w", err)
|
|
}
|
|
|
|
params := url.Values{
|
|
"response_type": {"code"},
|
|
"client_id": {cfg.ClientID},
|
|
"redirect_uri": {cfg.RedirectURI},
|
|
"state": {state},
|
|
}
|
|
if cfg.Scopes != "" {
|
|
params.Set("scope", cfg.Scopes)
|
|
}
|
|
if cfg.Audience != "" {
|
|
params.Set("audience", cfg.Audience)
|
|
}
|
|
|
|
codeVerifier, err := generateCodeVerifier()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("generate code verifier: %w", err)
|
|
}
|
|
if err := s.updateOAuthState(ctx, providerID, state, codeVerifier); err != nil {
|
|
return nil, err
|
|
}
|
|
params.Set("scope", cfg.Scopes)
|
|
params.Set("code_challenge", computeCodeChallenge(codeVerifier))
|
|
params.Set("code_challenge_method", "S256")
|
|
if cfg.IDTokenAddOrganizations {
|
|
params.Set("id_token_add_organizations", "true")
|
|
}
|
|
params.Set("codex_cli_simplified_flow", "true")
|
|
|
|
return &OAuthAuthorizeResponse{
|
|
Mode: "web",
|
|
AuthURL: cfg.AuthorizeURL + "?" + params.Encode(),
|
|
}, nil
|
|
}
|
|
|
|
func (s *Service) HandleOAuthCallback(ctx context.Context, state, code string) (string, error) {
|
|
if userToken, err := s.getUserOAuthTokenByState(ctx, state); err == nil {
|
|
return s.handleUserScopedOAuthCallback(ctx, userToken, code)
|
|
} else if !errors.Is(err, pgx.ErrNoRows) {
|
|
return "", err
|
|
}
|
|
|
|
token, err := s.getOAuthTokenByState(ctx, state)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
providerUUID, err := db.ParseUUID(token.ProviderID)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
provider, err := s.queries.GetProviderByID(ctx, providerUUID)
|
|
if err != nil {
|
|
return "", fmt.Errorf("get provider: %w", err)
|
|
}
|
|
if !supportsOAuth(provider) {
|
|
return "", errors.New("provider does not support oauth")
|
|
}
|
|
|
|
cfg := s.oauthConfigForProvider(provider)
|
|
resp, err := s.exchangeCode(ctx, cfg, code, token.PKCECodeVerifier)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if err := s.saveOAuthToken(ctx, provider.ID.String(), oauthTokenRecord{
|
|
ProviderID: provider.ID.String(),
|
|
AccessToken: resp.AccessToken,
|
|
RefreshToken: firstNonEmpty(resp.RefreshToken, token.RefreshToken),
|
|
ExpiresAt: expiresAtFromNow(resp.ExpiresIn),
|
|
Scope: firstNonEmpty(resp.Scope, cfg.Scopes),
|
|
TokenType: firstNonEmpty(resp.TokenType, "Bearer"),
|
|
State: "",
|
|
PKCECodeVerifier: "",
|
|
}); err != nil {
|
|
return "", err
|
|
}
|
|
return provider.ID.String(), nil
|
|
}
|
|
|
|
func (s *Service) handleUserScopedOAuthCallback(ctx context.Context, token *oauthTokenRecord, code string) (string, error) {
|
|
providerUUID, err := db.ParseUUID(token.ProviderID)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
provider, err := s.queries.GetProviderByID(ctx, providerUUID)
|
|
if err != nil {
|
|
return "", fmt.Errorf("get provider: %w", err)
|
|
}
|
|
if !isUserScopedOAuthProvider(provider) {
|
|
return "", errors.New("provider does not use user-scoped oauth")
|
|
}
|
|
|
|
cfg := s.oauthConfigForProvider(provider)
|
|
resp, err := s.exchangeCode(ctx, cfg, code, token.PKCECodeVerifier)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if err := s.saveUserOAuthToken(ctx, token.ProviderID, token.UserID, oauthTokenRecord{
|
|
ProviderID: token.ProviderID,
|
|
UserID: token.UserID,
|
|
AccessToken: resp.AccessToken,
|
|
RefreshToken: firstNonEmpty(
|
|
resp.RefreshToken,
|
|
token.RefreshToken,
|
|
),
|
|
ExpiresAt: expiresAtFromNow(resp.ExpiresIn),
|
|
Scope: firstNonEmpty(resp.Scope, cfg.Scopes),
|
|
TokenType: firstNonEmpty(resp.TokenType, "bearer"),
|
|
State: "",
|
|
PKCECodeVerifier: "",
|
|
Metadata: token.Metadata,
|
|
}); err != nil {
|
|
return "", err
|
|
}
|
|
return provider.ID.String(), nil
|
|
}
|
|
|
|
func (s *Service) startGitHubDeviceAuthorization(ctx context.Context, providerID, userID string, cfg oauthConfig) (*OAuthDeviceStatus, error) {
|
|
resp, err := s.requestDeviceAuthorization(ctx, cfg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
device := oauthDeviceMetadata{
|
|
DeviceCode: resp.DeviceCode,
|
|
UserCode: resp.UserCode,
|
|
VerificationURI: resp.VerificationURI,
|
|
ExpiresAt: expiresAtFromNow(resp.ExpiresIn),
|
|
IntervalSeconds: resp.Interval,
|
|
}
|
|
if err := s.updateUserOAuthState(ctx, providerID, userID, "", "", device.toMetadata()); err != nil {
|
|
return nil, err
|
|
}
|
|
return device.toStatus(), nil
|
|
}
|
|
|
|
func (s *Service) GetOAuthStatus(ctx context.Context, providerID string) (*OAuthStatus, error) {
|
|
provider, err := s.loadOAuthProvider(ctx, providerID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
status := &OAuthStatus{
|
|
Configured: supportsOAuth(provider),
|
|
Mode: "web",
|
|
CallbackURL: s.oauthConfigForProvider(provider).RedirectURI,
|
|
}
|
|
if !status.Configured {
|
|
return status, nil
|
|
}
|
|
if isUserScopedOAuthProvider(provider) {
|
|
status.Mode = "device"
|
|
status.CallbackURL = ""
|
|
}
|
|
|
|
userID := ""
|
|
var token *oauthTokenRecord
|
|
if isUserScopedOAuthProvider(provider) {
|
|
userID = oauthctx.UserIDFromContext(ctx)
|
|
if userID == "" {
|
|
return nil, errors.New("github copilot oauth requires a current user")
|
|
}
|
|
token, err = s.getUserOAuthToken(ctx, providerID, userID)
|
|
} else {
|
|
token, err = s.getOAuthToken(ctx, providerID)
|
|
}
|
|
if err != nil {
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
return status, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
status.HasToken = strings.TrimSpace(token.AccessToken) != ""
|
|
if !token.ExpiresAt.IsZero() {
|
|
expiresAt := token.ExpiresAt
|
|
status.ExpiresAt = &expiresAt
|
|
status.Expired = time.Now().After(token.ExpiresAt)
|
|
}
|
|
if isUserScopedOAuthProvider(provider) {
|
|
status.Device = deviceMetadataFromMap(token.Metadata).toStatus()
|
|
account, accountErr := s.resolveGitHubOAuthAccount(ctx, providerID, userID, token)
|
|
if accountErr != nil {
|
|
return nil, accountErr
|
|
}
|
|
status.Account = account
|
|
}
|
|
return status, nil
|
|
}
|
|
|
|
func (s *Service) PollOAuthAuthorization(ctx context.Context, providerID string) (*OAuthStatus, error) {
|
|
provider, err := s.loadOAuthProvider(ctx, providerID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if !isUserScopedOAuthProvider(provider) {
|
|
return nil, errors.New("device authorization is only supported for github copilot")
|
|
}
|
|
|
|
userID := oauthctx.UserIDFromContext(ctx)
|
|
if userID == "" {
|
|
return nil, errors.New("github copilot oauth requires a current user")
|
|
}
|
|
|
|
token, err := s.getUserOAuthToken(ctx, providerID, userID)
|
|
if err != nil {
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
return s.GetOAuthStatus(ctx, providerID)
|
|
}
|
|
return nil, err
|
|
}
|
|
device := deviceMetadataFromMap(token.Metadata)
|
|
if strings.TrimSpace(device.DeviceCode) == "" {
|
|
return s.GetOAuthStatus(ctx, providerID)
|
|
}
|
|
if !device.ExpiresAt.IsZero() && time.Now().After(device.ExpiresAt) {
|
|
if err := s.updateUserOAuthState(ctx, providerID, userID, "", "", nil); err != nil {
|
|
return nil, err
|
|
}
|
|
return s.GetOAuthStatus(ctx, providerID)
|
|
}
|
|
|
|
cfg := s.oauthConfigForProvider(provider)
|
|
resp, err := s.exchangeDeviceCode(ctx, cfg, device.DeviceCode)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if resp.Error != "" {
|
|
switch resp.Error {
|
|
case "authorization_pending":
|
|
return s.GetOAuthStatus(ctx, providerID)
|
|
case "slow_down":
|
|
if resp.Interval > 0 {
|
|
device.IntervalSeconds = resp.Interval
|
|
if err := s.updateUserOAuthState(ctx, providerID, userID, "", "", device.toMetadata()); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return s.GetOAuthStatus(ctx, providerID)
|
|
case "expired_token", "access_denied", "incorrect_device_code", "unsupported_grant_type":
|
|
if err := s.updateUserOAuthState(ctx, providerID, userID, "", "", nil); err != nil {
|
|
return nil, err
|
|
}
|
|
return s.GetOAuthStatus(ctx, providerID)
|
|
default:
|
|
return nil, fmt.Errorf("oauth device token request failed: %s", firstNonEmpty(resp.Description, resp.Error))
|
|
}
|
|
}
|
|
|
|
account, err := s.fetchGitHubOAuthAccount(ctx, resp.AccessToken)
|
|
if err != nil {
|
|
s.logger.Warn("fetch github oauth account failed", oauthLogAttrs(providerID, userID, err)...)
|
|
}
|
|
|
|
if err := s.saveUserOAuthToken(ctx, providerID, userID, oauthTokenRecord{
|
|
ProviderID: providerID,
|
|
UserID: userID,
|
|
AccessToken: resp.AccessToken,
|
|
RefreshToken: firstNonEmpty(resp.RefreshToken, token.RefreshToken),
|
|
ExpiresAt: expiresAtFromNow(resp.ExpiresIn),
|
|
Scope: firstNonEmpty(resp.Scope, token.Scope),
|
|
TokenType: firstNonEmpty(resp.TokenType, "bearer"),
|
|
Metadata: account.toMetadata(),
|
|
}); err != nil {
|
|
return nil, err
|
|
}
|
|
return s.GetOAuthStatus(ctx, providerID)
|
|
}
|
|
|
|
func (s *Service) RevokeOAuthToken(ctx context.Context, providerID string) error {
|
|
provider, err := s.loadOAuthProvider(ctx, providerID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !supportsOAuth(provider) {
|
|
return errors.New("provider does not support oauth")
|
|
}
|
|
|
|
if isUserScopedOAuthProvider(provider) {
|
|
userID := oauthctx.UserIDFromContext(ctx)
|
|
if userID == "" {
|
|
return errors.New("github copilot oauth requires a current user")
|
|
}
|
|
return s.deleteUserOAuthToken(ctx, providerID, userID)
|
|
}
|
|
return s.queries.DeleteProviderOAuthToken(ctx, provider.ID)
|
|
}
|
|
|
|
func (s *Service) GetValidAccessToken(ctx context.Context, providerID string) (string, error) {
|
|
provider, err := s.loadOAuthProvider(ctx, providerID)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
cfg := s.oauthConfigForProvider(provider)
|
|
|
|
if isUserScopedOAuthProvider(provider) {
|
|
userID := oauthctx.UserIDFromContext(ctx)
|
|
if userID == "" {
|
|
return "", errors.New("github copilot requires a current user")
|
|
}
|
|
token, err := s.getUserOAuthToken(ctx, providerID, userID)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return s.resolveValidUserOAuthToken(ctx, cfg, token)
|
|
}
|
|
|
|
token, err := s.getOAuthToken(ctx, providerID)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return s.resolveValidProviderOAuthToken(ctx, cfg, token)
|
|
}
|
|
|
|
func (s *Service) resolveValidProviderOAuthToken(ctx context.Context, cfg oauthConfig, token *oauthTokenRecord) (string, error) {
|
|
if strings.TrimSpace(token.AccessToken) == "" {
|
|
return "", errors.New("oauth token is missing access token")
|
|
}
|
|
if token.ExpiresAt.IsZero() || time.Now().Add(oauthExpirySkew).Before(token.ExpiresAt) {
|
|
return token.AccessToken, nil
|
|
}
|
|
if strings.TrimSpace(token.RefreshToken) == "" {
|
|
return "", errors.New("oauth token expired and no refresh token is available")
|
|
}
|
|
|
|
refreshed, err := s.refreshAccessToken(ctx, cfg, token.RefreshToken)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
saved := oauthTokenRecord{
|
|
ProviderID: token.ProviderID,
|
|
AccessToken: refreshed.AccessToken,
|
|
RefreshToken: firstNonEmpty(refreshed.RefreshToken, token.RefreshToken),
|
|
ExpiresAt: expiresAtFromNow(refreshed.ExpiresIn),
|
|
Scope: firstNonEmpty(refreshed.Scope, token.Scope),
|
|
TokenType: firstNonEmpty(refreshed.TokenType, token.TokenType),
|
|
State: token.State,
|
|
PKCECodeVerifier: token.PKCECodeVerifier,
|
|
Metadata: token.Metadata,
|
|
}
|
|
if err := s.saveOAuthToken(ctx, token.ProviderID, saved); err != nil {
|
|
return "", err
|
|
}
|
|
return saved.AccessToken, nil
|
|
}
|
|
|
|
func (s *Service) resolveValidUserOAuthToken(ctx context.Context, cfg oauthConfig, token *oauthTokenRecord) (string, error) {
|
|
if strings.TrimSpace(token.AccessToken) == "" {
|
|
return "", errors.New("oauth token is missing access token")
|
|
}
|
|
if token.ExpiresAt.IsZero() || time.Now().Add(oauthExpirySkew).Before(token.ExpiresAt) {
|
|
return token.AccessToken, nil
|
|
}
|
|
if strings.TrimSpace(token.RefreshToken) == "" {
|
|
return "", errors.New("oauth token expired and no refresh token is available")
|
|
}
|
|
|
|
refreshed, err := s.refreshAccessToken(ctx, cfg, token.RefreshToken)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
saved := oauthTokenRecord{
|
|
ProviderID: token.ProviderID,
|
|
UserID: token.UserID,
|
|
AccessToken: refreshed.AccessToken,
|
|
RefreshToken: firstNonEmpty(refreshed.RefreshToken, token.RefreshToken),
|
|
ExpiresAt: expiresAtFromNow(refreshed.ExpiresIn),
|
|
Scope: firstNonEmpty(refreshed.Scope, token.Scope),
|
|
TokenType: firstNonEmpty(refreshed.TokenType, token.TokenType),
|
|
State: token.State,
|
|
PKCECodeVerifier: token.PKCECodeVerifier,
|
|
Metadata: token.Metadata,
|
|
}
|
|
if err := s.saveUserOAuthToken(ctx, token.ProviderID, token.UserID, saved); err != nil {
|
|
return "", err
|
|
}
|
|
return saved.AccessToken, nil
|
|
}
|
|
|
|
func (s *Service) loadOAuthProvider(ctx context.Context, providerID string) (sqlc.Provider, error) {
|
|
providerUUID, err := db.ParseUUID(providerID)
|
|
if err != nil {
|
|
return sqlc.Provider{}, err
|
|
}
|
|
provider, err := s.queries.GetProviderByID(ctx, providerUUID)
|
|
if err != nil {
|
|
return sqlc.Provider{}, fmt.Errorf("get provider: %w", err)
|
|
}
|
|
if !supportsOAuth(provider) {
|
|
return sqlc.Provider{}, errors.New("provider does not support oauth")
|
|
}
|
|
return provider, nil
|
|
}
|
|
|
|
func (s *Service) getOAuthToken(ctx context.Context, providerID string) (*oauthTokenRecord, error) {
|
|
providerUUID, err := db.ParseUUID(providerID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
row, err := s.queries.GetProviderOAuthTokenByProvider(ctx, providerUUID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return toProviderOAuthToken(row), nil
|
|
}
|
|
|
|
func (s *Service) getOAuthTokenByState(ctx context.Context, state string) (*oauthTokenRecord, error) {
|
|
row, err := s.queries.GetProviderOAuthTokenByState(ctx, state)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return toProviderOAuthToken(row), nil
|
|
}
|
|
|
|
func (s *Service) updateOAuthState(ctx context.Context, providerID, state, codeVerifier string) error {
|
|
providerUUID, err := db.ParseUUID(providerID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return s.queries.UpdateProviderOAuthState(ctx, sqlc.UpdateProviderOAuthStateParams{
|
|
ProviderID: providerUUID,
|
|
State: state,
|
|
PkceCodeVerifier: codeVerifier,
|
|
})
|
|
}
|
|
|
|
func (s *Service) saveOAuthToken(ctx context.Context, providerID string, token oauthTokenRecord) error {
|
|
providerUUID, err := db.ParseUUID(providerID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
var expiresAt pgtype.Timestamptz
|
|
if !token.ExpiresAt.IsZero() {
|
|
expiresAt = pgtype.Timestamptz{Time: token.ExpiresAt, Valid: true}
|
|
}
|
|
_, err = s.queries.UpsertProviderOAuthToken(ctx, sqlc.UpsertProviderOAuthTokenParams{
|
|
ProviderID: providerUUID,
|
|
AccessToken: token.AccessToken,
|
|
RefreshToken: token.RefreshToken,
|
|
ExpiresAt: expiresAt,
|
|
Scope: token.Scope,
|
|
TokenType: token.TokenType,
|
|
State: token.State,
|
|
PkceCodeVerifier: token.PKCECodeVerifier,
|
|
})
|
|
return err
|
|
}
|
|
|
|
func (s *Service) getUserOAuthToken(ctx context.Context, providerID, userID string) (*oauthTokenRecord, error) {
|
|
providerUUID, err := db.ParseUUID(providerID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
userUUID, err := db.ParseUUID(userID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
row, err := s.queries.GetUserProviderOAuthToken(ctx, sqlc.GetUserProviderOAuthTokenParams{
|
|
ProviderID: providerUUID,
|
|
UserID: userUUID,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return toUserProviderOAuthToken(row), nil
|
|
}
|
|
|
|
func (s *Service) getUserOAuthTokenByState(ctx context.Context, state string) (*oauthTokenRecord, error) {
|
|
row, err := s.queries.GetUserProviderOAuthTokenByState(ctx, state)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return toUserProviderOAuthToken(row), nil
|
|
}
|
|
|
|
func (s *Service) updateUserOAuthState(ctx context.Context, providerID, userID, state, codeVerifier string, metadata map[string]any) error {
|
|
providerUUID, err := db.ParseUUID(providerID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
userUUID, err := db.ParseUUID(userID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return s.queries.UpdateUserProviderOAuthState(ctx, sqlc.UpdateUserProviderOAuthStateParams{
|
|
ProviderID: providerUUID,
|
|
UserID: userUUID,
|
|
State: state,
|
|
PkceCodeVerifier: codeVerifier,
|
|
Metadata: metadataJSON(metadata),
|
|
})
|
|
}
|
|
|
|
func (s *Service) saveUserOAuthToken(ctx context.Context, providerID, userID string, token oauthTokenRecord) error {
|
|
providerUUID, err := db.ParseUUID(providerID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
userUUID, err := db.ParseUUID(userID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
var expiresAt pgtype.Timestamptz
|
|
if !token.ExpiresAt.IsZero() {
|
|
expiresAt = pgtype.Timestamptz{Time: token.ExpiresAt, Valid: true}
|
|
}
|
|
_, err = s.queries.UpsertUserProviderOAuthToken(ctx, sqlc.UpsertUserProviderOAuthTokenParams{
|
|
ProviderID: providerUUID,
|
|
UserID: userUUID,
|
|
AccessToken: token.AccessToken,
|
|
RefreshToken: token.RefreshToken,
|
|
ExpiresAt: expiresAt,
|
|
Scope: token.Scope,
|
|
TokenType: token.TokenType,
|
|
State: token.State,
|
|
PkceCodeVerifier: token.PKCECodeVerifier,
|
|
Metadata: metadataJSON(token.Metadata),
|
|
})
|
|
return err
|
|
}
|
|
|
|
func (s *Service) deleteUserOAuthToken(ctx context.Context, providerID, userID string) error {
|
|
providerUUID, err := db.ParseUUID(providerID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
userUUID, err := db.ParseUUID(userID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return s.queries.DeleteUserProviderOAuthToken(ctx, sqlc.DeleteUserProviderOAuthTokenParams{
|
|
ProviderID: providerUUID,
|
|
UserID: userUUID,
|
|
})
|
|
}
|
|
|
|
func toProviderOAuthToken(row sqlc.ProviderOauthToken) *oauthTokenRecord {
|
|
token := &oauthTokenRecord{
|
|
ProviderID: row.ProviderID.String(),
|
|
AccessToken: row.AccessToken,
|
|
RefreshToken: row.RefreshToken,
|
|
Scope: row.Scope,
|
|
TokenType: row.TokenType,
|
|
State: row.State,
|
|
PKCECodeVerifier: row.PkceCodeVerifier,
|
|
Metadata: map[string]any{},
|
|
}
|
|
if row.ExpiresAt.Valid {
|
|
token.ExpiresAt = row.ExpiresAt.Time
|
|
}
|
|
return token
|
|
}
|
|
|
|
func toUserProviderOAuthToken(row sqlc.UserProviderOauthToken) *oauthTokenRecord {
|
|
token := &oauthTokenRecord{
|
|
ProviderID: row.ProviderID.String(),
|
|
UserID: row.UserID.String(),
|
|
AccessToken: row.AccessToken,
|
|
RefreshToken: row.RefreshToken,
|
|
Scope: row.Scope,
|
|
TokenType: row.TokenType,
|
|
State: row.State,
|
|
PKCECodeVerifier: row.PkceCodeVerifier,
|
|
Metadata: providerMetadata(row.Metadata),
|
|
}
|
|
if row.ExpiresAt.Valid {
|
|
token.ExpiresAt = row.ExpiresAt.Time
|
|
}
|
|
return token
|
|
}
|
|
|
|
// oauthTokenResponse is the JSON body returned by an OAuth token endpoint,
// covering both successful and error responses.
type oauthTokenResponse struct {
	// Issued credentials.
	AccessToken  string `json:"access_token"`  //nolint:gosec // OAuth response payload carries runtime access token
	RefreshToken string `json:"refresh_token"` //nolint:gosec // OAuth response payload carries runtime refresh token

	// TokenType and Scope describe the issued access token.
	TokenType string `json:"token_type"`
	Scope     string `json:"scope"`

	// ExpiresIn is the access-token lifetime in seconds.
	ExpiresIn int64 `json:"expires_in"`
	// Interval is a polling interval in seconds, as sent in device-flow responses.
	Interval int64 `json:"interval"`

	// Error and Description carry an OAuth error code and its explanation.
	Error       string `json:"error"`
	Description string `json:"error_description"`
}
|
|
|
|
// oauthDeviceMetadata describes a pending device authorization as stored in a
// token's metadata.
type oauthDeviceMetadata struct {
	// DeviceCode is exchanged for a token while polling; UserCode is what the
	// user enters at VerificationURI.
	DeviceCode, UserCode, VerificationURI string

	// ExpiresAt is when the device code stops being usable; zero means unknown.
	ExpiresAt time.Time

	// IntervalSeconds is the polling interval in seconds.
	IntervalSeconds int64
}
|
|
|
|
// oauthAccountMetadata describes the account behind an OAuth token as stored
// in the token's metadata. Label is the display name.
type oauthAccountMetadata struct {
	Label, Login, Name, Email, AvatarURL, ProfileURL string
}
|
|
|
|
func deviceMetadataFromMap(metadata map[string]any) oauthDeviceMetadata {
|
|
device := oauthDeviceMetadata{
|
|
DeviceCode: strings.TrimSpace(stringValue(metadata, metadataDeviceCodeKey)),
|
|
UserCode: strings.TrimSpace(stringValue(metadata, metadataDeviceUserCodeKey)),
|
|
VerificationURI: strings.TrimSpace(stringValue(metadata, metadataDeviceVerifyURIKey)),
|
|
IntervalSeconds: int64Value(metadata, metadataDeviceIntervalKey),
|
|
}
|
|
if raw := strings.TrimSpace(stringValue(metadata, metadataDeviceExpiresAtKey)); raw != "" {
|
|
if parsed, err := time.Parse(time.RFC3339, raw); err == nil {
|
|
device.ExpiresAt = parsed
|
|
}
|
|
}
|
|
return device
|
|
}
|
|
|
|
func (d oauthDeviceMetadata) toMetadata() map[string]any {
|
|
if strings.TrimSpace(d.DeviceCode) == "" {
|
|
return nil
|
|
}
|
|
metadata := map[string]any{
|
|
metadataDeviceCodeKey: d.DeviceCode,
|
|
metadataDeviceUserCodeKey: d.UserCode,
|
|
metadataDeviceVerifyURIKey: d.VerificationURI,
|
|
metadataDeviceIntervalKey: d.IntervalSeconds,
|
|
}
|
|
if !d.ExpiresAt.IsZero() {
|
|
metadata[metadataDeviceExpiresAtKey] = d.ExpiresAt.UTC().Format(time.RFC3339)
|
|
}
|
|
return metadata
|
|
}
|
|
|
|
func (d oauthDeviceMetadata) toStatus() *OAuthDeviceStatus {
|
|
if strings.TrimSpace(d.DeviceCode) == "" {
|
|
return nil
|
|
}
|
|
status := &OAuthDeviceStatus{
|
|
Pending: true,
|
|
UserCode: d.UserCode,
|
|
VerificationURI: d.VerificationURI,
|
|
IntervalSeconds: d.IntervalSeconds,
|
|
}
|
|
if !d.ExpiresAt.IsZero() {
|
|
expiresAt := d.ExpiresAt
|
|
status.ExpiresAt = &expiresAt
|
|
}
|
|
return status
|
|
}
|
|
|
|
func accountMetadataFromMap(metadata map[string]any) oauthAccountMetadata {
|
|
account := oauthAccountMetadata{
|
|
Label: strings.TrimSpace(stringValue(metadata, metadataAccountLabelKey)),
|
|
Login: strings.TrimSpace(stringValue(metadata, metadataAccountLoginKey)),
|
|
Name: strings.TrimSpace(stringValue(metadata, metadataAccountNameKey)),
|
|
Email: strings.TrimSpace(stringValue(metadata, metadataAccountEmailKey)),
|
|
AvatarURL: strings.TrimSpace(stringValue(metadata, metadataAccountAvatarURLKey)),
|
|
ProfileURL: strings.TrimSpace(stringValue(metadata, metadataAccountProfileURLKey)),
|
|
}
|
|
if account.Label == "" {
|
|
account.Label = firstNonEmpty(account.Name, account.Login, account.Email)
|
|
}
|
|
return account
|
|
}
|
|
|
|
func (a oauthAccountMetadata) toMetadata() map[string]any {
|
|
if a.isZero() {
|
|
return map[string]any{}
|
|
}
|
|
metadata := map[string]any{}
|
|
if a.Label != "" {
|
|
metadata[metadataAccountLabelKey] = a.Label
|
|
}
|
|
if a.Login != "" {
|
|
metadata[metadataAccountLoginKey] = a.Login
|
|
}
|
|
if a.Name != "" {
|
|
metadata[metadataAccountNameKey] = a.Name
|
|
}
|
|
if a.Email != "" {
|
|
metadata[metadataAccountEmailKey] = a.Email
|
|
}
|
|
if a.AvatarURL != "" {
|
|
metadata[metadataAccountAvatarURLKey] = a.AvatarURL
|
|
}
|
|
if a.ProfileURL != "" {
|
|
metadata[metadataAccountProfileURLKey] = a.ProfileURL
|
|
}
|
|
return metadata
|
|
}
|
|
|
|
func (a oauthAccountMetadata) toStatus() *OAuthAccount {
|
|
if a.isZero() {
|
|
return nil
|
|
}
|
|
return &OAuthAccount{
|
|
Label: a.Label,
|
|
Login: a.Login,
|
|
Name: a.Name,
|
|
Email: a.Email,
|
|
AvatarURL: a.AvatarURL,
|
|
ProfileURL: a.ProfileURL,
|
|
}
|
|
}
|
|
|
|
func (a oauthAccountMetadata) isZero() bool {
|
|
return a.Label == "" && a.Login == "" && a.Name == "" && a.Email == "" && a.AvatarURL == "" && a.ProfileURL == ""
|
|
}
|
|
|
|
func (s *Service) resolveGitHubOAuthAccount(ctx context.Context, providerID, userID string, token *oauthTokenRecord) (*OAuthAccount, error) {
|
|
account := accountMetadataFromMap(token.Metadata)
|
|
if status := account.toStatus(); status != nil {
|
|
return status, nil
|
|
}
|
|
if strings.TrimSpace(token.AccessToken) == "" {
|
|
return nil, nil
|
|
}
|
|
|
|
refreshedAccount, err := s.fetchGitHubOAuthAccount(ctx, token.AccessToken)
|
|
if err != nil {
|
|
s.logger.Warn("refresh github oauth account metadata failed", oauthLogAttrs(providerID, userID, err)...)
|
|
return nil, nil
|
|
}
|
|
|
|
updatedToken := *token
|
|
updatedToken.Metadata = refreshedAccount.toMetadata()
|
|
if err := s.saveUserOAuthToken(ctx, providerID, userID, updatedToken); err != nil {
|
|
return nil, err
|
|
}
|
|
return refreshedAccount.toStatus(), nil
|
|
}
|
|
|
|
func (s *Service) fetchGitHubOAuthAccount(ctx context.Context, accessToken string) (oauthAccountMetadata, error) {
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, defaultGitHubUserURL, nil)
|
|
if err != nil {
|
|
return oauthAccountMetadata{}, fmt.Errorf("create github oauth account request: %w", err)
|
|
}
|
|
req.Header.Set("Accept", "application/vnd.github+json")
|
|
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(accessToken))
|
|
req.Header.Set("X-GitHub-Api-Version", "2022-11-28")
|
|
|
|
resp, err := s.httpClient.Do(req) //nolint:gosec // Request targets a fixed GitHub API endpoint.
|
|
if err != nil {
|
|
return oauthAccountMetadata{}, fmt.Errorf("execute github oauth account request: %w", err)
|
|
}
|
|
defer func() { _ = resp.Body.Close() }()
|
|
|
|
payload, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return oauthAccountMetadata{}, fmt.Errorf("read github oauth account response: %w", err)
|
|
}
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
return oauthAccountMetadata{}, fmt.Errorf("github oauth account request failed: %s", strings.TrimSpace(string(payload)))
|
|
}
|
|
|
|
var profile struct {
|
|
Login string `json:"login"`
|
|
Name string `json:"name"`
|
|
Email string `json:"email"`
|
|
AvatarURL string `json:"avatar_url"`
|
|
HTMLURL string `json:"html_url"`
|
|
}
|
|
if err := json.Unmarshal(payload, &profile); err != nil {
|
|
return oauthAccountMetadata{}, fmt.Errorf("decode github oauth account response: %w", err)
|
|
}
|
|
|
|
account := oauthAccountMetadata{
|
|
Login: strings.TrimSpace(profile.Login),
|
|
Name: strings.TrimSpace(profile.Name),
|
|
Email: strings.TrimSpace(profile.Email),
|
|
AvatarURL: strings.TrimSpace(profile.AvatarURL),
|
|
ProfileURL: strings.TrimSpace(profile.HTMLURL),
|
|
}
|
|
if account.Email == "" {
|
|
email, err := s.fetchGitHubPrimaryEmail(ctx, accessToken)
|
|
if err != nil {
|
|
s.logger.Warn("fetch github oauth primary email failed", slog.Any("error", err))
|
|
} else {
|
|
account.Email = email
|
|
}
|
|
}
|
|
account.Label = firstNonEmpty(account.Email, account.Name, account.Login)
|
|
if account.Label == "" {
|
|
return oauthAccountMetadata{}, errors.New("github oauth account response did not include a usable account label")
|
|
}
|
|
return account, nil
|
|
}
|
|
|
|
func (s *Service) fetchGitHubPrimaryEmail(ctx context.Context, accessToken string) (string, error) {
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, defaultGitHubUserEmailsURL, nil)
|
|
if err != nil {
|
|
return "", fmt.Errorf("create github oauth emails request: %w", err)
|
|
}
|
|
req.Header.Set("Accept", "application/vnd.github+json")
|
|
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(accessToken))
|
|
req.Header.Set("X-GitHub-Api-Version", "2022-11-28")
|
|
|
|
resp, err := s.httpClient.Do(req) //nolint:gosec // Request targets a fixed GitHub API endpoint.
|
|
if err != nil {
|
|
return "", fmt.Errorf("execute github oauth emails request: %w", err)
|
|
}
|
|
defer func() { _ = resp.Body.Close() }()
|
|
|
|
payload, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return "", fmt.Errorf("read github oauth emails response: %w", err)
|
|
}
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
return "", fmt.Errorf("github oauth emails request failed: %s", strings.TrimSpace(string(payload)))
|
|
}
|
|
|
|
var emails []struct {
|
|
Email string `json:"email"`
|
|
Primary bool `json:"primary"`
|
|
Verified bool `json:"verified"`
|
|
}
|
|
if err := json.Unmarshal(payload, &emails); err != nil {
|
|
return "", fmt.Errorf("decode github oauth emails response: %w", err)
|
|
}
|
|
|
|
for _, candidate := range emails {
|
|
email := strings.TrimSpace(candidate.Email)
|
|
if candidate.Primary && candidate.Verified && email != "" {
|
|
return email, nil
|
|
}
|
|
}
|
|
for _, candidate := range emails {
|
|
email := strings.TrimSpace(candidate.Email)
|
|
if candidate.Primary && email != "" {
|
|
return email, nil
|
|
}
|
|
}
|
|
for _, candidate := range emails {
|
|
email := strings.TrimSpace(candidate.Email)
|
|
if email != "" {
|
|
return email, nil
|
|
}
|
|
}
|
|
|
|
return "", errors.New("github oauth emails response did not include a usable email")
|
|
}
|
|
|
|
func (s *Service) requestDeviceAuthorization(ctx context.Context, cfg oauthConfig) (*deviceAuthorizationResponse, error) {
|
|
if err := validateOAuthTokenURL(cfg.ClientType, cfg.DeviceCodeURL); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
values := url.Values{
|
|
"client_id": {cfg.ClientID},
|
|
}
|
|
if cfg.Scopes != "" {
|
|
values.Set("scope", cfg.Scopes)
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, cfg.DeviceCodeURL, strings.NewReader(values.Encode()))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create oauth device request: %w", err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
req.Header.Set("Accept", "application/json")
|
|
|
|
resp, err := s.httpClient.Do(req) //nolint:gosec // URL is validated by validateOAuthTokenURL before request execution.
|
|
if err != nil {
|
|
return nil, fmt.Errorf("execute oauth device request: %w", err)
|
|
}
|
|
defer func() { _ = resp.Body.Close() }()
|
|
|
|
payload, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read oauth device response: %w", err)
|
|
}
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
return nil, fmt.Errorf("oauth device request failed: %s", strings.TrimSpace(string(payload)))
|
|
}
|
|
|
|
var deviceResp deviceAuthorizationResponse
|
|
if err := json.Unmarshal(payload, &deviceResp); err != nil {
|
|
return nil, fmt.Errorf("decode oauth device response: %w", err)
|
|
}
|
|
if deviceResp.Error != "" {
|
|
return nil, fmt.Errorf("oauth device request failed: %s", firstNonEmpty(deviceResp.Description, deviceResp.Error))
|
|
}
|
|
if strings.TrimSpace(deviceResp.DeviceCode) == "" || strings.TrimSpace(deviceResp.UserCode) == "" || strings.TrimSpace(deviceResp.VerificationURI) == "" {
|
|
return nil, errors.New("oauth device request returned incomplete device authorization data")
|
|
}
|
|
if deviceResp.Interval <= 0 {
|
|
deviceResp.Interval = 5
|
|
}
|
|
return &deviceResp, nil
|
|
}
|
|
|
|
func (s *Service) exchangeDeviceCode(ctx context.Context, cfg oauthConfig, deviceCode string) (*oauthTokenResponse, error) {
|
|
if err := validateOAuthTokenURL(cfg.ClientType, cfg.TokenURL); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
values := url.Values{
|
|
"client_id": {cfg.ClientID},
|
|
"device_code": {deviceCode},
|
|
"grant_type": {"urn:ietf:params:oauth:grant-type:device_code"},
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, cfg.TokenURL, strings.NewReader(values.Encode()))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create oauth device token request: %w", err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
req.Header.Set("Accept", "application/json")
|
|
|
|
resp, err := s.httpClient.Do(req) //nolint:gosec // URL is validated by validateOAuthTokenURL before request execution.
|
|
if err != nil {
|
|
return nil, fmt.Errorf("execute oauth device token request: %w", err)
|
|
}
|
|
defer func() { _ = resp.Body.Close() }()
|
|
|
|
payload, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read oauth device token response: %w", err)
|
|
}
|
|
|
|
var tokenResp oauthTokenResponse
|
|
if err := json.Unmarshal(payload, &tokenResp); err != nil {
|
|
return nil, fmt.Errorf("decode oauth device token response: %w", err)
|
|
}
|
|
if tokenResp.Interval <= 0 {
|
|
tokenResp.Interval = 5
|
|
}
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
if tokenResp.Error != "" {
|
|
return &tokenResp, nil
|
|
}
|
|
return nil, fmt.Errorf("oauth device token request failed: %s", strings.TrimSpace(string(payload)))
|
|
}
|
|
return &tokenResp, nil
|
|
}
|
|
|
|
func (s *Service) exchangeCode(ctx context.Context, cfg oauthConfig, code, codeVerifier string) (*oauthTokenResponse, error) {
|
|
values := url.Values{
|
|
"code": {code},
|
|
"client_id": {cfg.ClientID},
|
|
"redirect_uri": {cfg.RedirectURI},
|
|
}
|
|
if cfg.UsePKCE {
|
|
values.Set("grant_type", "authorization_code")
|
|
values.Set("code_verifier", codeVerifier)
|
|
}
|
|
if cfg.ClientSecret != "" {
|
|
values.Set("client_secret", cfg.ClientSecret)
|
|
}
|
|
return s.postTokenRequest(ctx, cfg, values)
|
|
}
|
|
|
|
func (s *Service) refreshAccessToken(ctx context.Context, cfg oauthConfig, refreshToken string) (*oauthTokenResponse, error) {
|
|
values := url.Values{
|
|
"grant_type": {"refresh_token"},
|
|
"refresh_token": {refreshToken},
|
|
"client_id": {cfg.ClientID},
|
|
}
|
|
if cfg.ClientSecret != "" {
|
|
values.Set("client_secret", cfg.ClientSecret)
|
|
}
|
|
return s.postTokenRequest(ctx, cfg, values)
|
|
}
|
|
|
|
func (s *Service) postTokenRequest(ctx context.Context, cfg oauthConfig, body url.Values) (*oauthTokenResponse, error) {
|
|
if err := validateOAuthTokenURL(cfg.ClientType, cfg.TokenURL); err != nil {
|
|
return nil, err
|
|
}
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, cfg.TokenURL, strings.NewReader(body.Encode()))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create oauth request: %w", err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
req.Header.Set("Accept", "application/json")
|
|
|
|
resp, err := s.httpClient.Do(req) //nolint:gosec // URL is validated by validateOAuthTokenURL before request execution.
|
|
if err != nil {
|
|
return nil, fmt.Errorf("execute oauth request: %w", err)
|
|
}
|
|
defer func() { _ = resp.Body.Close() }()
|
|
|
|
payload, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read oauth response: %w", err)
|
|
}
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
return nil, fmt.Errorf("oauth token request failed: %s", strings.TrimSpace(string(payload)))
|
|
}
|
|
|
|
var tokenResp oauthTokenResponse
|
|
if err := json.Unmarshal(payload, &tokenResp); err != nil {
|
|
return nil, fmt.Errorf("decode oauth response: %w", err)
|
|
}
|
|
if tokenResp.Error != "" {
|
|
return nil, fmt.Errorf("oauth token request failed: %s", firstNonEmpty(tokenResp.Description, tokenResp.Error))
|
|
}
|
|
return &tokenResp, nil
|
|
}
|
|
|
|
func validateOAuthTokenURL(clientType models.ClientType, raw string) error {
|
|
parsed, err := url.Parse(strings.TrimSpace(raw))
|
|
if err != nil {
|
|
return fmt.Errorf("invalid oauth token url: %w", err)
|
|
}
|
|
if !strings.EqualFold(parsed.Scheme, "https") {
|
|
return errors.New("oauth token url must use https")
|
|
}
|
|
|
|
switch clientType {
|
|
case models.ClientTypeOpenAICodex:
|
|
if !strings.EqualFold(parsed.Hostname(), "auth.openai.com") {
|
|
return errors.New("oauth token url host must be auth.openai.com")
|
|
}
|
|
case models.ClientTypeGitHubCopilot:
|
|
if !strings.EqualFold(parsed.Hostname(), "github.com") {
|
|
return errors.New("oauth token url host must be github.com")
|
|
}
|
|
default:
|
|
return errors.New("unsupported oauth client type")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// stringValue returns input[key] when it holds a string. A nil map, a missing
// key or a value of any other type yields the empty string.
func stringValue(input map[string]any, key string) string {
	// Indexing a nil map is safe and simply reports the key as absent.
	if text, ok := input[key].(string); ok {
		return text
	}
	return ""
}
|
|
|
|
// int64Value returns input[key] as an int64 when it holds an int64, an int or
// a float64 (converted with int64(value)). A nil map, a missing key or a value
// of any other type yields 0.
func int64Value(input map[string]any, key string) int64 {
	raw, present := input[key]
	if !present {
		return 0
	}
	if n, ok := raw.(int64); ok {
		return n
	}
	if n, ok := raw.(int); ok {
		return int64(n)
	}
	if f, ok := raw.(float64); ok {
		return int64(f)
	}
	return 0
}
|
|
|
|
// metadataJSON encodes metadata as JSON. An empty or nil map, or a map that
// cannot be marshalled, is encoded as the empty object "{}".
func metadataJSON(metadata map[string]any) []byte {
	if len(metadata) > 0 {
		if encoded, err := json.Marshal(metadata); err == nil {
			return encoded
		}
	}
	return []byte("{}")
}
|
|
|
|
// generateState returns 32 bytes read from crypto/rand, encoded as a
// 64-character lowercase hex string. The read error is returned unchanged.
func generateState() (string, error) {
	var entropy [32]byte
	if _, err := rand.Read(entropy[:]); err != nil {
		return "", err
	}
	return hex.EncodeToString(entropy[:]), nil
}
|
|
|
|
// generateCodeVerifier returns 32 bytes read from crypto/rand, encoded with
// unpadded URL-safe base64 (43 characters). The read error is returned
// unchanged.
func generateCodeVerifier() (string, error) {
	var seed [32]byte
	if _, err := rand.Read(seed[:]); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(seed[:]), nil
}
|
|
|
|
// computeCodeChallenge returns the SHA-256 digest of verifier encoded with
// unpadded URL-safe base64.
func computeCodeChallenge(verifier string) string {
	hasher := sha256.New()
	hasher.Write([]byte(verifier)) // hash.Hash writes never return an error
	return base64.RawURLEncoding.EncodeToString(hasher.Sum(nil))
}
|
|
|
|
// expiresAtFromNow converts a lifetime in seconds into an absolute expiry
// time relative to time.Now. A non-positive lifetime yields the zero
// time.Time.
//
// Lifetimes too large to be expressed as a time.Duration are clamped to the
// largest representable whole number of seconds (roughly 292 years). Without
// the clamp the multiplication would silently overflow int64 and could produce
// an expiry that is not in the future.
func expiresAtFromNow(expiresIn int64) time.Time {
	if expiresIn <= 0 {
		return time.Time{}
	}

	// Largest number of seconds whose nanosecond count still fits in int64.
	const maxSeconds = int64(1<<63-1) / int64(time.Second)
	if expiresIn > maxSeconds {
		expiresIn = maxSeconds
	}
	return time.Now().Add(time.Duration(expiresIn) * time.Second)
}
|
|
|
|
// firstNonEmpty returns the first argument that is not blank, i.e. that still
// contains something after trimming whitespace. The value is returned exactly
// as passed (untrimmed). With no such argument the empty string is returned.
func firstNonEmpty(values ...string) string {
	for i := range values {
		if strings.TrimSpace(values[i]) == "" {
			continue
		}
		return values[i]
	}
	return ""
}
|