mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
feat(provider): add github copilot device flow provider (#364)
This commit is contained in:
@@ -429,6 +429,7 @@ func IsValidClientType(clientType ClientType) bool {
|
||||
ClientTypeAnthropicMessages,
|
||||
ClientTypeGoogleGenerativeAI,
|
||||
ClientTypeOpenAICodex,
|
||||
ClientTypeGitHubCopilot,
|
||||
ClientTypeEdgeSpeech:
|
||||
return true
|
||||
default:
|
||||
|
||||
+53
-12
@@ -17,8 +17,10 @@ import (
|
||||
openairesponses "github.com/memohai/twilight-ai/provider/openai/responses"
|
||||
sdk "github.com/memohai/twilight-ai/sdk"
|
||||
|
||||
memohcopilot "github.com/memohai/memoh/internal/copilot"
|
||||
"github.com/memohai/memoh/internal/db"
|
||||
"github.com/memohai/memoh/internal/db/sqlc"
|
||||
"github.com/memohai/memoh/internal/oauthctx"
|
||||
)
|
||||
|
||||
const probeTimeout = 15 * time.Second
|
||||
@@ -162,6 +164,9 @@ func NewSDKProvider(baseURL, apiKey, codexAccountID string, clientType ClientTyp
|
||||
}
|
||||
return openaicodex.New(opts...)
|
||||
|
||||
case ClientTypeGitHubCopilot:
|
||||
return memohcopilot.NewProvider(apiKey, httpClient)
|
||||
|
||||
case ClientTypeAnthropicMessages:
|
||||
opts := []anthropicmessages.Option{
|
||||
anthropicmessages.WithAPIKey(apiKey),
|
||||
@@ -202,26 +207,62 @@ type modelCredentials struct {
|
||||
func (s *Service) resolveModelCredentials(ctx context.Context, provider sqlc.Provider) (modelCredentials, error) {
|
||||
apiKey := providerConfigString(provider.Config, "api_key")
|
||||
|
||||
if ClientType(provider.ClientType) != ClientTypeOpenAICodex {
|
||||
switch ClientType(provider.ClientType) {
|
||||
case ClientTypeGitHubCopilot:
|
||||
token, err := s.resolveGitHubCopilotAccessToken(ctx, provider)
|
||||
if err != nil {
|
||||
return modelCredentials{}, err
|
||||
}
|
||||
return modelCredentials{APIKey: token}, nil
|
||||
|
||||
case ClientTypeOpenAICodex:
|
||||
tokenRow, err := s.queries.GetProviderOAuthTokenByProvider(ctx, provider.ID)
|
||||
if err != nil {
|
||||
return modelCredentials{}, err
|
||||
}
|
||||
accessToken := strings.TrimSpace(tokenRow.AccessToken)
|
||||
if accessToken == "" {
|
||||
return modelCredentials{}, errors.New("oauth token is missing access token")
|
||||
}
|
||||
accountID, err := codexAccountIDFromToken(accessToken)
|
||||
if err != nil {
|
||||
return modelCredentials{}, err
|
||||
}
|
||||
return modelCredentials{
|
||||
APIKey: accessToken,
|
||||
CodexAccountID: accountID,
|
||||
}, nil
|
||||
|
||||
default:
|
||||
return modelCredentials{APIKey: apiKey}, nil
|
||||
}
|
||||
}
|
||||
|
||||
tokenRow, err := s.queries.GetProviderOAuthTokenByProvider(ctx, provider.ID)
|
||||
if err != nil {
|
||||
return modelCredentials{}, err
|
||||
func (s *Service) resolveGitHubCopilotAccessToken(ctx context.Context, provider sqlc.Provider) (string, error) {
|
||||
userID := oauthctx.UserIDFromContext(ctx)
|
||||
if userID == "" {
|
||||
return "", errors.New("github copilot requires a current user")
|
||||
}
|
||||
accessToken := strings.TrimSpace(tokenRow.AccessToken)
|
||||
userUUID, err := db.ParseUUID(userID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
row, err := s.queries.GetUserProviderOAuthToken(ctx, sqlc.GetUserProviderOAuthTokenParams{
|
||||
ProviderID: provider.ID,
|
||||
UserID: userUUID,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
accessToken := strings.TrimSpace(row.AccessToken)
|
||||
if accessToken == "" {
|
||||
return modelCredentials{}, errors.New("oauth token is missing access token")
|
||||
return "", errors.New("oauth token is missing access token")
|
||||
}
|
||||
accountID, err := codexAccountIDFromToken(accessToken)
|
||||
copilotToken, err := memohcopilot.ResolveToken(ctx, accessToken)
|
||||
if err != nil {
|
||||
return modelCredentials{}, err
|
||||
return "", err
|
||||
}
|
||||
return modelCredentials{
|
||||
APIKey: accessToken,
|
||||
CodexAccountID: accountID,
|
||||
}, nil
|
||||
return copilotToken, nil
|
||||
}
|
||||
|
||||
func codexAccountIDFromToken(token string) (string, error) {
|
||||
|
||||
@@ -10,6 +10,8 @@ import (
|
||||
openaicompletions "github.com/memohai/twilight-ai/provider/openai/completions"
|
||||
openairesponses "github.com/memohai/twilight-ai/provider/openai/responses"
|
||||
sdk "github.com/memohai/twilight-ai/sdk"
|
||||
|
||||
memohcopilot "github.com/memohai/memoh/internal/copilot"
|
||||
)
|
||||
|
||||
// SDKModelConfig holds provider and model information resolved from DB,
|
||||
@@ -76,6 +78,9 @@ func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model {
|
||||
}
|
||||
return openaicodex.New(opts...).ChatModel(cfg.ModelID)
|
||||
|
||||
case ClientTypeGitHubCopilot:
|
||||
return memohcopilot.NewModel(cfg.APIKey, cfg.ModelID, cfg.HTTPClient)
|
||||
|
||||
case ClientTypeAnthropicMessages:
|
||||
opts := []anthropicmessages.Option{
|
||||
anthropicmessages.WithAPIKey(cfg.APIKey),
|
||||
@@ -178,6 +183,8 @@ func ResolveClientType(model *sdk.Model) string {
|
||||
return string(ClientTypeAnthropicMessages)
|
||||
case strings.Contains(name, "google"):
|
||||
return string(ClientTypeGoogleGenerativeAI)
|
||||
case strings.Contains(name, "github-copilot"), strings.Contains(name, "copilot"):
|
||||
return string(ClientTypeGitHubCopilot)
|
||||
case strings.Contains(name, "codex"):
|
||||
return string(ClientTypeOpenAICodex)
|
||||
case strings.Contains(name, "responses"):
|
||||
|
||||
@@ -22,6 +22,7 @@ const (
|
||||
ClientTypeAnthropicMessages ClientType = "anthropic-messages"
|
||||
ClientTypeGoogleGenerativeAI ClientType = "google-generative-ai"
|
||||
ClientTypeOpenAICodex ClientType = "openai-codex"
|
||||
ClientTypeGitHubCopilot ClientType = "github-copilot"
|
||||
ClientTypeEdgeSpeech ClientType = "edge-speech"
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user