mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
feat(provider): add github copilot device flow provider (#364)
This commit is contained in:
@@ -28,6 +28,7 @@ import (
|
||||
messagepkg "github.com/memohai/memoh/internal/message"
|
||||
messageevent "github.com/memohai/memoh/internal/message/event"
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
"github.com/memohai/memoh/internal/oauthctx"
|
||||
pipelinepkg "github.com/memohai/memoh/internal/pipeline"
|
||||
"github.com/memohai/memoh/internal/providers"
|
||||
"github.com/memohai/memoh/internal/settings"
|
||||
@@ -499,7 +500,8 @@ func (r *Resolver) buildBaseRunConfig(ctx context.Context, p baseRunConfigParams
|
||||
}
|
||||
|
||||
authResolver := providers.NewService(nil, r.queries, "")
|
||||
creds, err := authResolver.ResolveModelCredentials(ctx, provider)
|
||||
authCtx := oauthctx.WithUserID(ctx, p.UserID)
|
||||
creds, err := authResolver.ResolveModelCredentials(authCtx, provider)
|
||||
if err != nil {
|
||||
return agentpkg.RunConfig{}, models.GetResponse{}, sqlc.Provider{}, fmt.Errorf("resolve provider credentials: %w", err)
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"github.com/memohai/memoh/internal/compaction"
|
||||
"github.com/memohai/memoh/internal/conversation"
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
"github.com/memohai/memoh/internal/oauthctx"
|
||||
"github.com/memohai/memoh/internal/providers"
|
||||
"github.com/memohai/memoh/internal/settings"
|
||||
)
|
||||
@@ -115,7 +116,8 @@ func (r *Resolver) buildCompactionConfig(ctx context.Context, req conversation.C
|
||||
return compaction.TriggerConfig{}, err
|
||||
}
|
||||
authResolver := providers.NewService(nil, r.queries, "")
|
||||
creds, err := authResolver.ResolveModelCredentials(ctx, compactProvider)
|
||||
authCtx := oauthctx.WithUserID(ctx, req.UserID)
|
||||
creds, err := authResolver.ResolveModelCredentials(authCtx, compactProvider)
|
||||
if err != nil {
|
||||
return compaction.TriggerConfig{}, err
|
||||
}
|
||||
@@ -137,7 +139,6 @@ func (r *Resolver) buildCompactionConfig(ctx context.Context, req conversation.C
|
||||
if compactModel.Config.ContextWindow != nil && *compactModel.Config.ContextWindow > 0 {
|
||||
cfg.MaxCompactTokens = *compactModel.Config.ContextWindow * 90 / 100
|
||||
}
|
||||
|
||||
// For sync compaction: keep only the last few messages (~2000 tokens ≈ 3 messages).
|
||||
// The summary provides reference context; if the LLM needs details,
|
||||
// it will use tools (memory_read, search) to look them up.
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/memohai/memoh/internal/db/sqlc"
|
||||
messageevent "github.com/memohai/memoh/internal/message/event"
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
"github.com/memohai/memoh/internal/oauthctx"
|
||||
"github.com/memohai/memoh/internal/providers"
|
||||
"github.com/memohai/memoh/internal/session"
|
||||
)
|
||||
@@ -82,7 +83,7 @@ func (r *Resolver) maybeGenerateSessionTitle(ctx context.Context, req conversati
|
||||
return
|
||||
}
|
||||
|
||||
title := r.generateTitle(ctx, titleModel, provider, userQuery)
|
||||
title := r.generateTitle(ctx, req.UserID, titleModel, provider, userQuery)
|
||||
if title == "" {
|
||||
return
|
||||
}
|
||||
@@ -95,7 +96,7 @@ func (r *Resolver) maybeGenerateSessionTitle(ctx context.Context, req conversati
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Resolver) generateTitle(ctx context.Context, model models.GetResponse, provider sqlc.Provider, userQuery string) string {
|
||||
func (r *Resolver) generateTitle(ctx context.Context, userID string, model models.GetResponse, provider sqlc.Provider, userQuery string) string {
|
||||
userSnippet := truncate(strings.TrimSpace(userQuery), titlePromptMaxInputChars)
|
||||
if userSnippet == "" {
|
||||
return ""
|
||||
@@ -106,7 +107,8 @@ func (r *Resolver) generateTitle(ctx context.Context, model models.GetResponse,
|
||||
"User: " + userSnippet
|
||||
|
||||
authResolver := providers.NewService(nil, r.queries, "")
|
||||
creds, err := authResolver.ResolveModelCredentials(ctx, provider)
|
||||
authCtx := oauthctx.WithUserID(ctx, userID)
|
||||
creds, err := authResolver.ResolveModelCredentials(authCtx, provider)
|
||||
if err != nil {
|
||||
r.logger.Warn("title gen: failed to resolve provider credentials", slog.Any("error", err))
|
||||
return ""
|
||||
|
||||
@@ -0,0 +1,176 @@
|
||||
package copilot
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
GitHubOAuthClientID = "Iv1.b507a08c87ecfe98"
|
||||
GitHubOAuthScope = "read:user user:email"
|
||||
DefaultAPIBaseURL = "https://api.githubcopilot.com"
|
||||
|
||||
copilotTokenURL = "https://api.github.com/copilot_internal/v2/token" //nolint:gosec // Fixed GitHub API endpoint, not a credential.
|
||||
copilotEditorVersion = "vscode/1.110.1"
|
||||
copilotPluginVersion = "copilot-chat/0.38.2"
|
||||
copilotUserAgent = "GitHubCopilotChat/0.38.2"
|
||||
copilotAPIVersion = "2025-10-01"
|
||||
copilotIntegrationID = "vscode-chat"
|
||||
copilotTokenRefreshSkew = time.Minute
|
||||
defaultHTTPClientTimeout = 15 * time.Second
|
||||
)
|
||||
|
||||
type cachedToken struct {
|
||||
Token string
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
var tokenCache = struct {
|
||||
mu sync.Mutex
|
||||
entries map[string]cachedToken
|
||||
}{
|
||||
entries: map[string]cachedToken{},
|
||||
}
|
||||
|
||||
type tokenResponse struct {
|
||||
Token string `json:"token"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
}
|
||||
|
||||
func ResolveToken(ctx context.Context, githubToken string) (string, error) {
|
||||
githubToken = strings.TrimSpace(githubToken)
|
||||
if githubToken == "" {
|
||||
return "", errors.New("github token is required")
|
||||
}
|
||||
|
||||
if token, ok := loadCachedToken(githubToken); ok {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
token, expiresAt, err := FetchCopilotToken(ctx, githubToken)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
storeCachedToken(githubToken, token, expiresAt)
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func FetchCopilotToken(ctx context.Context, githubToken string) (string, time.Time, error) {
|
||||
githubToken = strings.TrimSpace(githubToken)
|
||||
if githubToken == "" {
|
||||
return "", time.Time{}, errors.New("github token is required")
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, copilotTokenURL, nil)
|
||||
if err != nil {
|
||||
return "", time.Time{}, fmt.Errorf("create copilot token request: %w", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "token "+githubToken)
|
||||
req.Header.Set("Editor-Version", copilotEditorVersion)
|
||||
req.Header.Set("Editor-Plugin-Version", copilotPluginVersion)
|
||||
req.Header.Set("User-Agent", copilotUserAgent)
|
||||
req.Header.Set("X-GitHub-Api-Version", copilotAPIVersion)
|
||||
|
||||
resp, err := defaultHTTPClient(nil).Do(req) //nolint:gosec // Request targets a fixed GitHub API endpoint.
|
||||
if err != nil {
|
||||
return "", time.Time{}, fmt.Errorf("fetch copilot token: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", time.Time{}, fmt.Errorf("read copilot token response: %w", err)
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return "", time.Time{}, fmt.Errorf("copilot token request failed: %s", strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
var parsed tokenResponse
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return "", time.Time{}, fmt.Errorf("decode copilot token response: %w", err)
|
||||
}
|
||||
if strings.TrimSpace(parsed.Token) == "" {
|
||||
return "", time.Time{}, errors.New("copilot token response did not include a token")
|
||||
}
|
||||
|
||||
var expiresAt time.Time
|
||||
if parsed.ExpiresAt > 0 {
|
||||
expiresAt = time.Unix(parsed.ExpiresAt, 0).UTC()
|
||||
}
|
||||
return parsed.Token, expiresAt, nil
|
||||
}
|
||||
|
||||
func NewHTTPClient(base *http.Client) *http.Client {
|
||||
client := defaultHTTPClient(base)
|
||||
client.Transport = &headerRoundTripper{
|
||||
base: client.Transport,
|
||||
headers: map[string]string{
|
||||
"Copilot-Integration-Id": copilotIntegrationID,
|
||||
"Editor-Version": copilotEditorVersion,
|
||||
"Editor-Plugin-Version": copilotPluginVersion,
|
||||
"User-Agent": copilotUserAgent,
|
||||
"X-GitHub-Api-Version": copilotAPIVersion,
|
||||
},
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
||||
type headerRoundTripper struct {
|
||||
base http.RoundTripper
|
||||
headers map[string]string
|
||||
}
|
||||
|
||||
func (rt *headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
clone := req.Clone(req.Context())
|
||||
clone.Header = req.Header.Clone()
|
||||
for key, value := range rt.headers {
|
||||
clone.Header.Set(key, value)
|
||||
}
|
||||
if rt.base == nil {
|
||||
rt.base = http.DefaultTransport
|
||||
}
|
||||
return rt.base.RoundTrip(clone)
|
||||
}
|
||||
|
||||
func loadCachedToken(githubToken string) (string, bool) {
|
||||
tokenCache.mu.Lock()
|
||||
defer tokenCache.mu.Unlock()
|
||||
|
||||
entry, ok := tokenCache.entries[githubToken]
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
if !entry.ExpiresAt.IsZero() && !time.Now().Add(copilotTokenRefreshSkew).Before(entry.ExpiresAt) {
|
||||
delete(tokenCache.entries, githubToken)
|
||||
return "", false
|
||||
}
|
||||
return entry.Token, true
|
||||
}
|
||||
|
||||
func storeCachedToken(githubToken, token string, expiresAt time.Time) {
|
||||
tokenCache.mu.Lock()
|
||||
defer tokenCache.mu.Unlock()
|
||||
tokenCache.entries[githubToken] = cachedToken{
|
||||
Token: token,
|
||||
ExpiresAt: expiresAt,
|
||||
}
|
||||
}
|
||||
|
||||
func defaultHTTPClient(base *http.Client) *http.Client {
|
||||
if base != nil {
|
||||
clone := *base
|
||||
if clone.Timeout == 0 {
|
||||
clone.Timeout = defaultHTTPClientTimeout
|
||||
}
|
||||
return &clone
|
||||
}
|
||||
return &http.Client{Timeout: defaultHTTPClientTimeout}
|
||||
}
|
||||
@@ -0,0 +1,80 @@
|
||||
package copilot
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type roundTripperFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
|
||||
func TestNewHTTPClientAddsCopilotHeaders(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
base := &http.Client{
|
||||
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
if got := req.Header.Get("Copilot-Integration-Id"); got != copilotIntegrationID {
|
||||
t.Fatalf("expected integration id %q, got %q", copilotIntegrationID, got)
|
||||
}
|
||||
if got := req.Header.Get("Editor-Version"); got != copilotEditorVersion {
|
||||
t.Fatalf("expected editor version %q, got %q", copilotEditorVersion, got)
|
||||
}
|
||||
if got := req.Header.Get("Editor-Plugin-Version"); got != copilotPluginVersion {
|
||||
t.Fatalf("expected plugin version %q, got %q", copilotPluginVersion, got)
|
||||
}
|
||||
if got := req.Header.Get("User-Agent"); got != copilotUserAgent {
|
||||
t.Fatalf("expected user agent %q, got %q", copilotUserAgent, got)
|
||||
}
|
||||
if got := req.Header.Get("X-GitHub-Api-Version"); got != copilotAPIVersion {
|
||||
t.Fatalf("expected api version %q, got %q", copilotAPIVersion, got)
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(`ok`)),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
}),
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://api.githubcopilot.com/chat/completions", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("create request: %v", err)
|
||||
}
|
||||
|
||||
resp, err := NewHTTPClient(base).Do(req) //nolint:gosec // Test request targets a fixed Copilot API endpoint.
|
||||
if err != nil {
|
||||
t.Fatalf("execute request: %v", err)
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
|
||||
func TestNewHTTPClientWithNilBaseDoesNotPanic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
if got := req.Header.Get("Copilot-Integration-Id"); got != copilotIntegrationID {
|
||||
t.Fatalf("expected integration id %q, got %q", copilotIntegrationID, got)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(`ok`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("create request: %v", err)
|
||||
}
|
||||
|
||||
resp, err := NewHTTPClient(nil).Do(req) //nolint:gosec // Test request targets an httptest server URL.
|
||||
if err != nil {
|
||||
t.Fatalf("execute request with nil base client: %v", err)
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
package copilot
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
githubcopilot "github.com/memohai/twilight-ai/provider/github/copilot"
|
||||
sdk "github.com/memohai/twilight-ai/sdk"
|
||||
)
|
||||
|
||||
func NewProvider(copilotToken string, baseClient *http.Client) sdk.Provider {
|
||||
options := []githubcopilot.Option{
|
||||
githubcopilot.WithGitHubToken(strings.TrimSpace(copilotToken)),
|
||||
githubcopilot.WithBaseURL(DefaultAPIBaseURL),
|
||||
githubcopilot.WithHTTPClient(NewHTTPClient(baseClient)),
|
||||
}
|
||||
return githubcopilot.New(options...)
|
||||
}
|
||||
|
||||
func NewModel(copilotToken, modelID string, baseClient *http.Client) *sdk.Model {
|
||||
options := []githubcopilot.Option{
|
||||
githubcopilot.WithGitHubToken(strings.TrimSpace(copilotToken)),
|
||||
githubcopilot.WithBaseURL(DefaultAPIBaseURL),
|
||||
githubcopilot.WithHTTPClient(NewHTTPClient(baseClient)),
|
||||
}
|
||||
if strings.TrimSpace(modelID) == "" {
|
||||
modelID = githubcopilot.AutoModel
|
||||
}
|
||||
return githubcopilot.New(options...).ChatModel(modelID)
|
||||
}
|
||||
@@ -515,3 +515,19 @@ type UserChannelBinding struct {
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
|
||||
}
|
||||
|
||||
type UserProviderOauthToken struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
ProviderID pgtype.UUID `json:"provider_id"`
|
||||
UserID pgtype.UUID `json:"user_id"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
|
||||
Scope string `json:"scope"`
|
||||
TokenType string `json:"token_type"`
|
||||
State string `json:"state"`
|
||||
PkceCodeVerifier string `json:"pkce_code_verifier"`
|
||||
Metadata []byte `json:"metadata"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
|
||||
}
|
||||
|
||||
@@ -0,0 +1,205 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// source: user_provider_oauth.sql
|
||||
|
||||
package sqlc
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
const deleteUserProviderOAuthToken = `-- name: DeleteUserProviderOAuthToken :exec
|
||||
DELETE FROM user_provider_oauth_tokens
|
||||
WHERE provider_id = $1
|
||||
AND user_id = $2
|
||||
`
|
||||
|
||||
type DeleteUserProviderOAuthTokenParams struct {
|
||||
ProviderID pgtype.UUID `json:"provider_id"`
|
||||
UserID pgtype.UUID `json:"user_id"`
|
||||
}
|
||||
|
||||
func (q *Queries) DeleteUserProviderOAuthToken(ctx context.Context, arg DeleteUserProviderOAuthTokenParams) error {
|
||||
_, err := q.db.Exec(ctx, deleteUserProviderOAuthToken, arg.ProviderID, arg.UserID)
|
||||
return err
|
||||
}
|
||||
|
||||
const getUserProviderOAuthToken = `-- name: GetUserProviderOAuthToken :one
|
||||
SELECT id, provider_id, user_id, access_token, refresh_token, expires_at, scope, token_type, state, pkce_code_verifier, metadata, created_at, updated_at FROM user_provider_oauth_tokens
|
||||
WHERE provider_id = $1
|
||||
AND user_id = $2
|
||||
`
|
||||
|
||||
type GetUserProviderOAuthTokenParams struct {
|
||||
ProviderID pgtype.UUID `json:"provider_id"`
|
||||
UserID pgtype.UUID `json:"user_id"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetUserProviderOAuthToken(ctx context.Context, arg GetUserProviderOAuthTokenParams) (UserProviderOauthToken, error) {
|
||||
row := q.db.QueryRow(ctx, getUserProviderOAuthToken, arg.ProviderID, arg.UserID)
|
||||
var i UserProviderOauthToken
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.ProviderID,
|
||||
&i.UserID,
|
||||
&i.AccessToken,
|
||||
&i.RefreshToken,
|
||||
&i.ExpiresAt,
|
||||
&i.Scope,
|
||||
&i.TokenType,
|
||||
&i.State,
|
||||
&i.PkceCodeVerifier,
|
||||
&i.Metadata,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getUserProviderOAuthTokenByState = `-- name: GetUserProviderOAuthTokenByState :one
|
||||
SELECT id, provider_id, user_id, access_token, refresh_token, expires_at, scope, token_type, state, pkce_code_verifier, metadata, created_at, updated_at FROM user_provider_oauth_tokens
|
||||
WHERE state = $1
|
||||
AND state != ''
|
||||
`
|
||||
|
||||
func (q *Queries) GetUserProviderOAuthTokenByState(ctx context.Context, state string) (UserProviderOauthToken, error) {
|
||||
row := q.db.QueryRow(ctx, getUserProviderOAuthTokenByState, state)
|
||||
var i UserProviderOauthToken
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.ProviderID,
|
||||
&i.UserID,
|
||||
&i.AccessToken,
|
||||
&i.RefreshToken,
|
||||
&i.ExpiresAt,
|
||||
&i.Scope,
|
||||
&i.TokenType,
|
||||
&i.State,
|
||||
&i.PkceCodeVerifier,
|
||||
&i.Metadata,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const updateUserProviderOAuthState = `-- name: UpdateUserProviderOAuthState :exec
|
||||
INSERT INTO user_provider_oauth_tokens (provider_id, user_id, state, pkce_code_verifier, metadata)
|
||||
VALUES (
|
||||
$1,
|
||||
$2,
|
||||
$3,
|
||||
$4,
|
||||
$5
|
||||
)
|
||||
ON CONFLICT (provider_id, user_id) DO UPDATE SET
|
||||
state = EXCLUDED.state,
|
||||
pkce_code_verifier = EXCLUDED.pkce_code_verifier,
|
||||
metadata = EXCLUDED.metadata,
|
||||
updated_at = now()
|
||||
`
|
||||
|
||||
type UpdateUserProviderOAuthStateParams struct {
|
||||
ProviderID pgtype.UUID `json:"provider_id"`
|
||||
UserID pgtype.UUID `json:"user_id"`
|
||||
State string `json:"state"`
|
||||
PkceCodeVerifier string `json:"pkce_code_verifier"`
|
||||
Metadata []byte `json:"metadata"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateUserProviderOAuthState(ctx context.Context, arg UpdateUserProviderOAuthStateParams) error {
|
||||
_, err := q.db.Exec(ctx, updateUserProviderOAuthState,
|
||||
arg.ProviderID,
|
||||
arg.UserID,
|
||||
arg.State,
|
||||
arg.PkceCodeVerifier,
|
||||
arg.Metadata,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
const upsertUserProviderOAuthToken = `-- name: UpsertUserProviderOAuthToken :one
|
||||
INSERT INTO user_provider_oauth_tokens (
|
||||
provider_id,
|
||||
user_id,
|
||||
access_token,
|
||||
refresh_token,
|
||||
expires_at,
|
||||
scope,
|
||||
token_type,
|
||||
state,
|
||||
pkce_code_verifier,
|
||||
metadata
|
||||
)
|
||||
VALUES (
|
||||
$1,
|
||||
$2,
|
||||
$3,
|
||||
$4,
|
||||
$5,
|
||||
$6,
|
||||
$7,
|
||||
$8,
|
||||
$9,
|
||||
$10
|
||||
)
|
||||
ON CONFLICT (provider_id, user_id) DO UPDATE SET
|
||||
access_token = EXCLUDED.access_token,
|
||||
refresh_token = EXCLUDED.refresh_token,
|
||||
expires_at = EXCLUDED.expires_at,
|
||||
scope = EXCLUDED.scope,
|
||||
token_type = EXCLUDED.token_type,
|
||||
state = EXCLUDED.state,
|
||||
pkce_code_verifier = EXCLUDED.pkce_code_verifier,
|
||||
metadata = EXCLUDED.metadata,
|
||||
updated_at = now()
|
||||
RETURNING id, provider_id, user_id, access_token, refresh_token, expires_at, scope, token_type, state, pkce_code_verifier, metadata, created_at, updated_at
|
||||
`
|
||||
|
||||
type UpsertUserProviderOAuthTokenParams struct {
|
||||
ProviderID pgtype.UUID `json:"provider_id"`
|
||||
UserID pgtype.UUID `json:"user_id"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
|
||||
Scope string `json:"scope"`
|
||||
TokenType string `json:"token_type"`
|
||||
State string `json:"state"`
|
||||
PkceCodeVerifier string `json:"pkce_code_verifier"`
|
||||
Metadata []byte `json:"metadata"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpsertUserProviderOAuthToken(ctx context.Context, arg UpsertUserProviderOAuthTokenParams) (UserProviderOauthToken, error) {
|
||||
row := q.db.QueryRow(ctx, upsertUserProviderOAuthToken,
|
||||
arg.ProviderID,
|
||||
arg.UserID,
|
||||
arg.AccessToken,
|
||||
arg.RefreshToken,
|
||||
arg.ExpiresAt,
|
||||
arg.Scope,
|
||||
arg.TokenType,
|
||||
arg.State,
|
||||
arg.PkceCodeVerifier,
|
||||
arg.Metadata,
|
||||
)
|
||||
var i UserProviderOauthToken
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.ProviderID,
|
||||
&i.UserID,
|
||||
&i.AccessToken,
|
||||
&i.RefreshToken,
|
||||
&i.ExpiresAt,
|
||||
&i.Scope,
|
||||
&i.TokenType,
|
||||
&i.State,
|
||||
&i.PkceCodeVerifier,
|
||||
&i.Metadata,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -10,7 +10,9 @@ import (
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/memohai/memoh/internal/auth"
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
"github.com/memohai/memoh/internal/oauthctx"
|
||||
)
|
||||
|
||||
type ModelsHandler struct {
|
||||
@@ -301,7 +303,12 @@ func (h *ModelsHandler) Test(c echo.Context) error {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
|
||||
}
|
||||
|
||||
resp, err := h.service.Test(c.Request().Context(), id)
|
||||
ctx := c.Request().Context()
|
||||
if userID, err := auth.UserIDFromContext(c); err == nil {
|
||||
ctx = oauthctx.WithUserID(ctx, userID)
|
||||
}
|
||||
|
||||
resp, err := h.service.Test(ctx, id)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "invalid") {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/memohai/memoh/internal/auth"
|
||||
"github.com/memohai/memoh/internal/oauthctx"
|
||||
"github.com/memohai/memoh/internal/providers"
|
||||
)
|
||||
|
||||
@@ -20,6 +22,7 @@ func NewProviderOAuthHandler(service *providers.Service) *ProviderOAuthHandler {
|
||||
|
||||
func (h *ProviderOAuthHandler) Register(e *echo.Echo) {
|
||||
e.GET("/providers/:id/oauth/authorize", h.Authorize)
|
||||
e.POST("/providers/:id/oauth/poll", h.Poll)
|
||||
e.GET("/providers/:id/oauth/status", h.Status)
|
||||
e.DELETE("/providers/:id/oauth/token", h.Revoke)
|
||||
e.GET("/auth/callback", h.Callback)
|
||||
@@ -30,7 +33,7 @@ func (h *ProviderOAuthHandler) Register(e *echo.Echo) {
|
||||
// @Summary Start OAuth2 authorization for an LLM provider
|
||||
// @Tags providers-oauth
|
||||
// @Param id path string true "Provider ID (UUID)"
|
||||
// @Success 200 {object} map[string]string
|
||||
// @Success 200 {object} providers.OAuthAuthorizeResponse
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 404 {object} ErrorResponse
|
||||
// @Router /providers/{id}/oauth/authorize [get].
|
||||
@@ -39,11 +42,39 @@ func (h *ProviderOAuthHandler) Authorize(c echo.Context) error {
|
||||
if providerID == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
|
||||
}
|
||||
authURL, err := h.service.StartOAuthAuthorization(c.Request().Context(), providerID)
|
||||
ctx := c.Request().Context()
|
||||
if userID, err := auth.UserIDFromContext(c); err == nil {
|
||||
ctx = oauthctx.WithUserID(ctx, userID)
|
||||
}
|
||||
resp, err := h.service.StartOAuthAuthorization(ctx, providerID)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
return c.JSON(http.StatusOK, map[string]string{"auth_url": authURL})
|
||||
return c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// Poll godoc
|
||||
// @Summary Poll OAuth device authorization for an LLM provider
|
||||
// @Tags providers-oauth
|
||||
// @Param id path string true "Provider ID (UUID)"
|
||||
// @Success 200 {object} providers.OAuthStatus
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 404 {object} ErrorResponse
|
||||
// @Router /providers/{id}/oauth/poll [post].
|
||||
func (h *ProviderOAuthHandler) Poll(c echo.Context) error {
|
||||
providerID := strings.TrimSpace(c.Param("id"))
|
||||
if providerID == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
|
||||
}
|
||||
ctx := c.Request().Context()
|
||||
if userID, err := auth.UserIDFromContext(c); err == nil {
|
||||
ctx = oauthctx.WithUserID(ctx, userID)
|
||||
}
|
||||
status, err := h.service.PollOAuthAuthorization(ctx, providerID)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
return c.JSON(http.StatusOK, status)
|
||||
}
|
||||
|
||||
// Status godoc
|
||||
@@ -59,7 +90,11 @@ func (h *ProviderOAuthHandler) Status(c echo.Context) error {
|
||||
if providerID == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
|
||||
}
|
||||
status, err := h.service.GetOAuthStatus(c.Request().Context(), providerID)
|
||||
ctx := c.Request().Context()
|
||||
if userID, err := auth.UserIDFromContext(c); err == nil {
|
||||
ctx = oauthctx.WithUserID(ctx, userID)
|
||||
}
|
||||
status, err := h.service.GetOAuthStatus(ctx, providerID)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
@@ -79,7 +114,11 @@ func (h *ProviderOAuthHandler) Revoke(c echo.Context) error {
|
||||
if providerID == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
|
||||
}
|
||||
if err := h.service.RevokeOAuthToken(c.Request().Context(), providerID); err != nil {
|
||||
ctx := c.Request().Context()
|
||||
if userID, err := auth.UserIDFromContext(c); err == nil {
|
||||
ctx = oauthctx.WithUserID(ctx, userID)
|
||||
}
|
||||
if err := h.service.RevokeOAuthToken(ctx, providerID); err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
return c.NoContent(http.StatusNoContent)
|
||||
@@ -111,11 +150,11 @@ func (h *ProviderOAuthHandler) Callback(c echo.Context) error {
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<title>OpenAI OAuth Connected</title>
|
||||
<title>Provider Connected</title>
|
||||
</head>
|
||||
<body style="font-family: sans-serif; padding: 24px;">
|
||||
<h2>OpenAI OAuth connected</h2>
|
||||
<p>You can close this window and return to Memoh.</p>
|
||||
<h2>Provider connected</h2>
|
||||
<p>Your current Memoh account is now connected.</p>
|
||||
<script>
|
||||
window.opener?.postMessage({ type: "memoh-provider-oauth-success", providerId: "{{.ProviderID}}" }, "*");
|
||||
window.close();
|
||||
|
||||
@@ -9,7 +9,9 @@ import (
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/memohai/memoh/internal/auth"
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
"github.com/memohai/memoh/internal/oauthctx"
|
||||
"github.com/memohai/memoh/internal/providers"
|
||||
)
|
||||
|
||||
@@ -272,7 +274,12 @@ func (h *ProvidersHandler) Test(c echo.Context) error {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
|
||||
}
|
||||
|
||||
resp, err := h.service.Test(c.Request().Context(), id)
|
||||
ctx := c.Request().Context()
|
||||
if userID, err := auth.UserIDFromContext(c); err == nil {
|
||||
ctx = oauthctx.WithUserID(ctx, userID)
|
||||
}
|
||||
|
||||
resp, err := h.service.Test(ctx, id)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "invalid") {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
@@ -301,7 +308,12 @@ func (h *ProvidersHandler) ImportModels(c echo.Context) error {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
|
||||
}
|
||||
|
||||
remoteModels, err := h.service.FetchRemoteModels(c.Request().Context(), id)
|
||||
ctx := c.Request().Context()
|
||||
if userID, err := auth.UserIDFromContext(c); err == nil {
|
||||
ctx = oauthctx.WithUserID(ctx, userID)
|
||||
}
|
||||
|
||||
remoteModels, err := h.service.FetchRemoteModels(ctx, id)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("fetch remote models: %v", err))
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/memohai/memoh/internal/healthcheck"
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
"github.com/memohai/memoh/internal/oauthctx"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -25,6 +26,7 @@ type BotModelLookup interface {
|
||||
|
||||
// BotModels holds the model UUIDs associated with a bot.
|
||||
type BotModels struct {
|
||||
OwnerUserID string
|
||||
ChatModelID string
|
||||
MemoryModelID string
|
||||
EmbeddingModelID string
|
||||
@@ -115,7 +117,8 @@ func (c *Checker) ListChecks(ctx context.Context, botID string) []healthcheck.Ch
|
||||
wg.Add(1)
|
||||
go func(idx int, s modelSlot) {
|
||||
defer wg.Done()
|
||||
results[idx] = c.probeSlot(probeCtx, s)
|
||||
slotCtx := oauthctx.WithUserID(probeCtx, botModels.OwnerUserID)
|
||||
results[idx] = c.probeSlot(slotCtx, s)
|
||||
}(i, slot)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
package modelchecker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/memohai/memoh/internal/healthcheck"
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
"github.com/memohai/memoh/internal/oauthctx"
|
||||
)
|
||||
|
||||
type testLookup struct {
|
||||
models BotModels
|
||||
}
|
||||
|
||||
func (l testLookup) GetBotModelIDs(context.Context, string) (BotModels, error) {
|
||||
return l.models, nil
|
||||
}
|
||||
|
||||
type testProber struct {
|
||||
t *testing.T
|
||||
wantUserID string
|
||||
}
|
||||
|
||||
func (p testProber) Test(ctx context.Context, id string) (models.TestResponse, error) {
|
||||
if got := oauthctx.UserIDFromContext(ctx); got != p.wantUserID {
|
||||
p.t.Fatalf("expected oauth user id %q, got %q", p.wantUserID, got)
|
||||
}
|
||||
if id != "model-chat-1" {
|
||||
p.t.Fatalf("expected model id %q, got %q", "model-chat-1", id)
|
||||
}
|
||||
return models.TestResponse{
|
||||
Status: models.TestStatusOK,
|
||||
Reachable: true,
|
||||
Message: "ok",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestCheckerListChecksInjectsOwnerUserIDIntoProbeContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
checker := NewChecker(nil, testLookup{
|
||||
models: BotModels{
|
||||
OwnerUserID: "user-123",
|
||||
ChatModelID: "model-chat-1",
|
||||
},
|
||||
}, testProber{
|
||||
t: t,
|
||||
wantUserID: "user-123",
|
||||
})
|
||||
|
||||
items := checker.ListChecks(context.Background(), "bot-1")
|
||||
if len(items) != 1 {
|
||||
t.Fatalf("expected 1 check result, got %d", len(items))
|
||||
}
|
||||
if items[0].Status != healthcheck.StatusOK {
|
||||
t.Fatalf("expected status %q, got %q", healthcheck.StatusOK, items[0].Status)
|
||||
}
|
||||
}
|
||||
@@ -36,6 +36,7 @@ func (l *QueriesLookup) GetBotModelIDs(ctx context.Context, botID string) (BotMo
|
||||
}
|
||||
|
||||
var m BotModels
|
||||
m.OwnerUserID = bot.OwnerUserID.String()
|
||||
if bot.ChatModelID.Valid {
|
||||
m.ChatModelID = bot.ChatModelID.String()
|
||||
}
|
||||
|
||||
@@ -429,6 +429,7 @@ func IsValidClientType(clientType ClientType) bool {
|
||||
ClientTypeAnthropicMessages,
|
||||
ClientTypeGoogleGenerativeAI,
|
||||
ClientTypeOpenAICodex,
|
||||
ClientTypeGitHubCopilot,
|
||||
ClientTypeEdgeSpeech:
|
||||
return true
|
||||
default:
|
||||
|
||||
+53
-12
@@ -17,8 +17,10 @@ import (
|
||||
openairesponses "github.com/memohai/twilight-ai/provider/openai/responses"
|
||||
sdk "github.com/memohai/twilight-ai/sdk"
|
||||
|
||||
memohcopilot "github.com/memohai/memoh/internal/copilot"
|
||||
"github.com/memohai/memoh/internal/db"
|
||||
"github.com/memohai/memoh/internal/db/sqlc"
|
||||
"github.com/memohai/memoh/internal/oauthctx"
|
||||
)
|
||||
|
||||
const probeTimeout = 15 * time.Second
|
||||
@@ -162,6 +164,9 @@ func NewSDKProvider(baseURL, apiKey, codexAccountID string, clientType ClientTyp
|
||||
}
|
||||
return openaicodex.New(opts...)
|
||||
|
||||
case ClientTypeGitHubCopilot:
|
||||
return memohcopilot.NewProvider(apiKey, httpClient)
|
||||
|
||||
case ClientTypeAnthropicMessages:
|
||||
opts := []anthropicmessages.Option{
|
||||
anthropicmessages.WithAPIKey(apiKey),
|
||||
@@ -202,26 +207,62 @@ type modelCredentials struct {
|
||||
func (s *Service) resolveModelCredentials(ctx context.Context, provider sqlc.Provider) (modelCredentials, error) {
|
||||
apiKey := providerConfigString(provider.Config, "api_key")
|
||||
|
||||
if ClientType(provider.ClientType) != ClientTypeOpenAICodex {
|
||||
switch ClientType(provider.ClientType) {
|
||||
case ClientTypeGitHubCopilot:
|
||||
token, err := s.resolveGitHubCopilotAccessToken(ctx, provider)
|
||||
if err != nil {
|
||||
return modelCredentials{}, err
|
||||
}
|
||||
return modelCredentials{APIKey: token}, nil
|
||||
|
||||
case ClientTypeOpenAICodex:
|
||||
tokenRow, err := s.queries.GetProviderOAuthTokenByProvider(ctx, provider.ID)
|
||||
if err != nil {
|
||||
return modelCredentials{}, err
|
||||
}
|
||||
accessToken := strings.TrimSpace(tokenRow.AccessToken)
|
||||
if accessToken == "" {
|
||||
return modelCredentials{}, errors.New("oauth token is missing access token")
|
||||
}
|
||||
accountID, err := codexAccountIDFromToken(accessToken)
|
||||
if err != nil {
|
||||
return modelCredentials{}, err
|
||||
}
|
||||
return modelCredentials{
|
||||
APIKey: accessToken,
|
||||
CodexAccountID: accountID,
|
||||
}, nil
|
||||
|
||||
default:
|
||||
return modelCredentials{APIKey: apiKey}, nil
|
||||
}
|
||||
}
|
||||
|
||||
tokenRow, err := s.queries.GetProviderOAuthTokenByProvider(ctx, provider.ID)
|
||||
if err != nil {
|
||||
return modelCredentials{}, err
|
||||
func (s *Service) resolveGitHubCopilotAccessToken(ctx context.Context, provider sqlc.Provider) (string, error) {
|
||||
userID := oauthctx.UserIDFromContext(ctx)
|
||||
if userID == "" {
|
||||
return "", errors.New("github copilot requires a current user")
|
||||
}
|
||||
accessToken := strings.TrimSpace(tokenRow.AccessToken)
|
||||
userUUID, err := db.ParseUUID(userID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
row, err := s.queries.GetUserProviderOAuthToken(ctx, sqlc.GetUserProviderOAuthTokenParams{
|
||||
ProviderID: provider.ID,
|
||||
UserID: userUUID,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
accessToken := strings.TrimSpace(row.AccessToken)
|
||||
if accessToken == "" {
|
||||
return modelCredentials{}, errors.New("oauth token is missing access token")
|
||||
return "", errors.New("oauth token is missing access token")
|
||||
}
|
||||
accountID, err := codexAccountIDFromToken(accessToken)
|
||||
copilotToken, err := memohcopilot.ResolveToken(ctx, accessToken)
|
||||
if err != nil {
|
||||
return modelCredentials{}, err
|
||||
return "", err
|
||||
}
|
||||
return modelCredentials{
|
||||
APIKey: accessToken,
|
||||
CodexAccountID: accountID,
|
||||
}, nil
|
||||
return copilotToken, nil
|
||||
}
|
||||
|
||||
func codexAccountIDFromToken(token string) (string, error) {
|
||||
|
||||
@@ -10,6 +10,8 @@ import (
|
||||
openaicompletions "github.com/memohai/twilight-ai/provider/openai/completions"
|
||||
openairesponses "github.com/memohai/twilight-ai/provider/openai/responses"
|
||||
sdk "github.com/memohai/twilight-ai/sdk"
|
||||
|
||||
memohcopilot "github.com/memohai/memoh/internal/copilot"
|
||||
)
|
||||
|
||||
// SDKModelConfig holds provider and model information resolved from DB,
|
||||
@@ -76,6 +78,9 @@ func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model {
|
||||
}
|
||||
return openaicodex.New(opts...).ChatModel(cfg.ModelID)
|
||||
|
||||
case ClientTypeGitHubCopilot:
|
||||
return memohcopilot.NewModel(cfg.APIKey, cfg.ModelID, cfg.HTTPClient)
|
||||
|
||||
case ClientTypeAnthropicMessages:
|
||||
opts := []anthropicmessages.Option{
|
||||
anthropicmessages.WithAPIKey(cfg.APIKey),
|
||||
@@ -178,6 +183,8 @@ func ResolveClientType(model *sdk.Model) string {
|
||||
return string(ClientTypeAnthropicMessages)
|
||||
case strings.Contains(name, "google"):
|
||||
return string(ClientTypeGoogleGenerativeAI)
|
||||
case strings.Contains(name, "github-copilot"), strings.Contains(name, "copilot"):
|
||||
return string(ClientTypeGitHubCopilot)
|
||||
case strings.Contains(name, "codex"):
|
||||
return string(ClientTypeOpenAICodex)
|
||||
case strings.Contains(name, "responses"):
|
||||
|
||||
@@ -22,6 +22,7 @@ const (
|
||||
ClientTypeAnthropicMessages ClientType = "anthropic-messages"
|
||||
ClientTypeGoogleGenerativeAI ClientType = "google-generative-ai"
|
||||
ClientTypeOpenAICodex ClientType = "openai-codex"
|
||||
ClientTypeGitHubCopilot ClientType = "github-copilot"
|
||||
ClientTypeEdgeSpeech ClientType = "edge-speech"
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
package oauthctx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type userIDContextKey struct{}
|
||||
|
||||
func WithUserID(ctx context.Context, userID string) context.Context {
|
||||
userID = strings.TrimSpace(userID)
|
||||
if userID == "" {
|
||||
return ctx
|
||||
}
|
||||
return context.WithValue(ctx, userIDContextKey{}, userID)
|
||||
}
|
||||
|
||||
func UserIDFromContext(ctx context.Context) string {
|
||||
userID, _ := ctx.Value(userIDContextKey{}).(string)
|
||||
return strings.TrimSpace(userID)
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
memohcopilot "github.com/memohai/memoh/internal/copilot"
|
||||
"github.com/memohai/memoh/internal/db/sqlc"
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
)
|
||||
@@ -24,25 +25,38 @@ func SupportsOpenAICodexOAuth(provider sqlc.Provider) bool {
|
||||
}
|
||||
|
||||
func (s *Service) ResolveModelCredentials(ctx context.Context, provider sqlc.Provider) (ModelCredentials, error) {
|
||||
if models.ClientType(provider.ClientType) != models.ClientTypeOpenAICodex {
|
||||
switch models.ClientType(provider.ClientType) {
|
||||
case models.ClientTypeGitHubCopilot:
|
||||
githubToken, err := s.GetValidAccessToken(ctx, provider.ID.String())
|
||||
if err != nil {
|
||||
return ModelCredentials{}, err
|
||||
}
|
||||
copilotToken, err := memohcopilot.ResolveToken(ctx, githubToken)
|
||||
if err != nil {
|
||||
return ModelCredentials{}, err
|
||||
}
|
||||
return ModelCredentials{APIKey: copilotToken}, nil
|
||||
|
||||
case models.ClientTypeOpenAICodex:
|
||||
token, err := s.GetValidAccessToken(ctx, provider.ID.String())
|
||||
if err != nil {
|
||||
return ModelCredentials{}, err
|
||||
}
|
||||
accountID, err := codexAccountIDFromToken(token)
|
||||
if err != nil {
|
||||
return ModelCredentials{}, err
|
||||
}
|
||||
return ModelCredentials{
|
||||
APIKey: token,
|
||||
CodexAccountID: accountID,
|
||||
}, nil
|
||||
|
||||
default:
|
||||
apiKey := ProviderConfigString(provider, "api_key")
|
||||
return ModelCredentials{
|
||||
APIKey: apiKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
token, err := s.GetValidAccessToken(ctx, provider.ID.String())
|
||||
if err != nil {
|
||||
return ModelCredentials{}, err
|
||||
}
|
||||
accountID, err := codexAccountIDFromToken(token)
|
||||
if err != nil {
|
||||
return ModelCredentials{}, err
|
||||
}
|
||||
return ModelCredentials{
|
||||
APIKey: token,
|
||||
CodexAccountID: accountID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func codexAccountIDFromToken(token string) (string, error) {
|
||||
|
||||
+956
-113
File diff suppressed because it is too large
Load Diff
@@ -11,9 +11,11 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
githubcopilot "github.com/memohai/twilight-ai/provider/github/copilot"
|
||||
openaicodex "github.com/memohai/twilight-ai/provider/openai/codex"
|
||||
sdk "github.com/memohai/twilight-ai/sdk"
|
||||
|
||||
memohcopilot "github.com/memohai/memoh/internal/copilot"
|
||||
"github.com/memohai/memoh/internal/db"
|
||||
"github.com/memohai/memoh/internal/db/sqlc"
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
@@ -47,15 +49,14 @@ func (s *Service) Create(ctx context.Context, req CreateRequest) (GetResponse, e
|
||||
return GetResponse{}, fmt.Errorf("marshal metadata: %w", err)
|
||||
}
|
||||
|
||||
configJSON, err := json.Marshal(req.Config)
|
||||
if err != nil {
|
||||
return GetResponse{}, fmt.Errorf("marshal config: %w", err)
|
||||
}
|
||||
|
||||
clientType := req.ClientType
|
||||
if clientType == "" {
|
||||
clientType = string(models.ClientTypeOpenAICompletions)
|
||||
}
|
||||
configJSON, err := json.Marshal(normalizeProviderConfig(clientType, req.Config))
|
||||
if err != nil {
|
||||
return GetResponse{}, fmt.Errorf("marshal config: %w", err)
|
||||
}
|
||||
|
||||
var icon pgtype.Text
|
||||
if req.Icon != "" {
|
||||
@@ -150,12 +151,11 @@ func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Get
|
||||
|
||||
existingConfig := providerConfig(existing.Config)
|
||||
if req.Config != nil {
|
||||
existingAPIKey := configString(existingConfig, "api_key")
|
||||
newAPIKey := configString(req.Config, "api_key")
|
||||
if newAPIKey != "" && newAPIKey == maskAPIKey(existingAPIKey) {
|
||||
req.Config["api_key"] = existingAPIKey
|
||||
}
|
||||
existingConfig = req.Config
|
||||
mergedConfig := mergeProviderConfig(existingConfig, req.Config)
|
||||
preserveMaskedConfigSecret(mergedConfig, existingConfig, req.Config, "api_key")
|
||||
existingConfig = normalizeProviderConfig(clientType, mergedConfig)
|
||||
} else {
|
||||
existingConfig = normalizeProviderConfig(clientType, existingConfig)
|
||||
}
|
||||
configJSON, err := json.Marshal(existingConfig)
|
||||
if err != nil {
|
||||
@@ -257,6 +257,34 @@ func (s *Service) FetchRemoteModels(ctx context.Context, id string) ([]RemoteMod
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get provider: %w", err)
|
||||
}
|
||||
if models.ClientType(provider.ClientType) == models.ClientTypeGitHubCopilot {
|
||||
creds, err := s.ResolveModelCredentials(ctx, provider)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sdkProvider := memohcopilot.NewProvider(creds.APIKey, nil)
|
||||
if result := sdkProvider.Test(ctx); result.Status != sdk.ProviderStatusOK {
|
||||
return nil, fmt.Errorf("github copilot provider test failed: %s", result.Message)
|
||||
}
|
||||
|
||||
catalog := githubcopilot.Catalog()
|
||||
remoteModels := make([]RemoteModel, 0, len(catalog))
|
||||
for _, model := range catalog {
|
||||
remoteModels = append(remoteModels, RemoteModel{
|
||||
ID: model.ID,
|
||||
Name: model.DisplayName,
|
||||
Object: "model",
|
||||
OwnedBy: "github-copilot",
|
||||
Type: "chat",
|
||||
Compatibilities: []string{
|
||||
models.CompatVision,
|
||||
models.CompatToolCall,
|
||||
models.CompatReasoning,
|
||||
},
|
||||
})
|
||||
}
|
||||
return remoteModels, nil
|
||||
}
|
||||
if supportsOAuth(provider) {
|
||||
catalog := openaicodex.Catalog()
|
||||
remoteModels := make([]RemoteModel, 0, len(catalog))
|
||||
@@ -329,7 +357,7 @@ func (s *Service) toGetResponse(provider sqlc.Provider) GetResponse {
|
||||
}
|
||||
|
||||
cfg := providerConfig(provider.Config)
|
||||
maskedCfg := maskConfigAPIKey(cfg)
|
||||
maskedCfg := maskConfigSecrets(provider.ClientType, cfg)
|
||||
|
||||
var icon string
|
||||
if provider.Icon.Valid {
|
||||
@@ -378,14 +406,51 @@ func ProviderConfigString(provider sqlc.Provider, key string) string {
|
||||
return configString(providerConfig(provider.Config), key)
|
||||
}
|
||||
|
||||
// maskConfigAPIKey returns a copy of config with api_key masked.
|
||||
func maskConfigAPIKey(cfg map[string]any) map[string]any {
|
||||
func cloneConfig(cfg map[string]any) map[string]any {
|
||||
result := make(map[string]any, len(cfg))
|
||||
for k, v := range cfg {
|
||||
result[k] = v
|
||||
}
|
||||
if apiKey, _ := result["api_key"].(string); apiKey != "" {
|
||||
result["api_key"] = maskAPIKey(apiKey)
|
||||
return result
|
||||
}
|
||||
|
||||
func mergeProviderConfig(existing, incoming map[string]any) map[string]any {
|
||||
result := cloneConfig(existing)
|
||||
for k, v := range incoming {
|
||||
result[k] = v
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func preserveMaskedConfigSecret(merged, existing, incoming map[string]any, key string) {
|
||||
existingValue := strings.TrimSpace(configString(existing, key))
|
||||
newValue := strings.TrimSpace(configString(incoming, key))
|
||||
if existingValue == "" || newValue == "" {
|
||||
return
|
||||
}
|
||||
if newValue == maskAPIKey(existingValue) {
|
||||
merged[key] = existingValue
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeProviderConfig keeps provider-specific secrets under stable keys while
|
||||
// preserving backward compatibility for legacy stored configs.
|
||||
func normalizeProviderConfig(clientType string, cfg map[string]any) map[string]any {
|
||||
result := cloneConfig(cfg)
|
||||
if models.ClientType(clientType) == models.ClientTypeGitHubCopilot {
|
||||
delete(result, "api_key")
|
||||
delete(result, configOAuthClientSecretKey)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// maskConfigSecrets returns a copy of config with all known secret fields masked.
|
||||
func maskConfigSecrets(clientType string, cfg map[string]any) map[string]any {
|
||||
result := normalizeProviderConfig(clientType, cfg)
|
||||
for _, key := range []string{"api_key", configOAuthClientSecretKey} {
|
||||
if value, _ := result[key].(string); value != "" {
|
||||
result[key] = maskAPIKey(value)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
package providers
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/memohai/memoh/internal/db/sqlc"
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
)
|
||||
|
||||
func TestMaskAPIKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
@@ -34,3 +40,166 @@ func TestMaskAPIKey(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestNormalizeProviderConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("github copilot drops legacy secrets", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg := normalizeProviderConfig("github-copilot", map[string]any{
|
||||
"api_key": "gh-secret",
|
||||
configOAuthClientSecretKey: "oauth-secret",
|
||||
"base_url": "ignored",
|
||||
})
|
||||
|
||||
if _, exists := cfg[configOAuthClientSecretKey]; exists {
|
||||
t.Fatalf("expected oauth client secret to be removed, got %#v", cfg[configOAuthClientSecretKey])
|
||||
}
|
||||
if _, exists := cfg["api_key"]; exists {
|
||||
t.Fatalf("expected legacy api_key to be removed, got %#v", cfg["api_key"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non copilot providers keep api key key", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg := normalizeProviderConfig("openai-completions", map[string]any{
|
||||
"api_key": "sk-live",
|
||||
})
|
||||
|
||||
if got, ok := cfg["api_key"].(string); !ok || got != "sk-live" {
|
||||
t.Fatalf("expected api_key to remain untouched, got %#v", cfg["api_key"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMaskConfigSecrets(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg := maskConfigSecrets("openai-completions", map[string]any{
|
||||
"api_key": "sk-secret-123456",
|
||||
})
|
||||
|
||||
masked, _ := cfg["api_key"].(string)
|
||||
if masked == "" || masked == "sk-secret-123456" {
|
||||
t.Fatalf("expected api key to be masked, got %q", masked)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreserveMaskedConfigSecret(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
merged := map[string]any{
|
||||
configOAuthClientSecretKey: "*************",
|
||||
}
|
||||
existing := map[string]any{
|
||||
configOAuthClientSecretKey: "gh-secret-1234",
|
||||
}
|
||||
incoming := map[string]any{
|
||||
configOAuthClientSecretKey: maskAPIKey("gh-secret-1234"),
|
||||
}
|
||||
|
||||
preserveMaskedConfigSecret(merged, existing, incoming, configOAuthClientSecretKey)
|
||||
|
||||
if got, _ := merged[configOAuthClientSecretKey].(string); got != "gh-secret-1234" {
|
||||
t.Fatalf("expected masked value to be restored to original secret, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceMetadataRoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
expiresAt := time.Date(2026, time.April, 11, 12, 0, 0, 0, time.UTC)
|
||||
device := oauthDeviceMetadata{
|
||||
DeviceCode: "device-code",
|
||||
UserCode: "ABCD-EFGH",
|
||||
VerificationURI: "https://github.com/login/device",
|
||||
ExpiresAt: expiresAt,
|
||||
IntervalSeconds: 5,
|
||||
}
|
||||
|
||||
parsed := deviceMetadataFromMap(device.toMetadata())
|
||||
if parsed.DeviceCode != device.DeviceCode {
|
||||
t.Fatalf("expected device code %q, got %q", device.DeviceCode, parsed.DeviceCode)
|
||||
}
|
||||
if parsed.UserCode != device.UserCode {
|
||||
t.Fatalf("expected user code %q, got %q", device.UserCode, parsed.UserCode)
|
||||
}
|
||||
if parsed.VerificationURI != device.VerificationURI {
|
||||
t.Fatalf("expected verification uri %q, got %q", device.VerificationURI, parsed.VerificationURI)
|
||||
}
|
||||
if !parsed.ExpiresAt.Equal(expiresAt) {
|
||||
t.Fatalf("expected expiresAt %s, got %s", expiresAt, parsed.ExpiresAt)
|
||||
}
|
||||
if parsed.IntervalSeconds != device.IntervalSeconds {
|
||||
t.Fatalf("expected interval %d, got %d", device.IntervalSeconds, parsed.IntervalSeconds)
|
||||
}
|
||||
|
||||
status := parsed.toStatus()
|
||||
if status == nil || !status.Pending {
|
||||
t.Fatalf("expected pending device status, got %#v", status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountMetadataRoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
account := oauthAccountMetadata{
|
||||
Label: "octocat",
|
||||
Login: "octocat",
|
||||
Name: "The Octocat",
|
||||
Email: "octocat@github.com",
|
||||
AvatarURL: "https://avatars.githubusercontent.com/u/1?v=4",
|
||||
ProfileURL: "https://github.com/octocat",
|
||||
}
|
||||
|
||||
parsed := accountMetadataFromMap(account.toMetadata())
|
||||
if parsed.Label != account.Label {
|
||||
t.Fatalf("expected label %q, got %q", account.Label, parsed.Label)
|
||||
}
|
||||
if parsed.Login != account.Login {
|
||||
t.Fatalf("expected login %q, got %q", account.Login, parsed.Login)
|
||||
}
|
||||
if parsed.Name != account.Name {
|
||||
t.Fatalf("expected name %q, got %q", account.Name, parsed.Name)
|
||||
}
|
||||
if parsed.Email != account.Email {
|
||||
t.Fatalf("expected email %q, got %q", account.Email, parsed.Email)
|
||||
}
|
||||
if parsed.AvatarURL != account.AvatarURL {
|
||||
t.Fatalf("expected avatar url %q, got %q", account.AvatarURL, parsed.AvatarURL)
|
||||
}
|
||||
if parsed.ProfileURL != account.ProfileURL {
|
||||
t.Fatalf("expected profile url %q, got %q", account.ProfileURL, parsed.ProfileURL)
|
||||
}
|
||||
|
||||
status := parsed.toStatus()
|
||||
if status == nil {
|
||||
t.Fatal("expected account status")
|
||||
}
|
||||
if status.Label != account.Label {
|
||||
t.Fatalf("expected status label %q, got %q", account.Label, status.Label)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthConfigForGitHubCopilotUsesFixedDeviceFlowSettings(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
service := &Service{}
|
||||
cfg := service.oauthConfigForProvider(sqlc.Provider{
|
||||
ClientType: string(models.ClientTypeGitHubCopilot),
|
||||
Config: []byte(`{"api_key":"legacy","oauth_client_secret":"legacy-secret"}`),
|
||||
Metadata: []byte(`{"oauth_client_id":"custom","oauth_scopes":"repo"}`),
|
||||
})
|
||||
|
||||
if cfg.ClientID != "Iv1.b507a08c87ecfe98" {
|
||||
t.Fatalf("expected fixed client id, got %q", cfg.ClientID)
|
||||
}
|
||||
if cfg.ClientSecret != "" {
|
||||
t.Fatalf("expected empty client secret, got %q", cfg.ClientSecret)
|
||||
}
|
||||
if cfg.Scopes != "read:user user:email" {
|
||||
t.Fatalf("expected fixed scope, got %q", cfg.Scopes)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,11 +54,37 @@ type TestResponse struct {
|
||||
|
||||
// OAuthStatus is returned by GET /providers/:id/oauth/status.
|
||||
type OAuthStatus struct {
|
||||
Configured bool `json:"configured"`
|
||||
HasToken bool `json:"has_token"`
|
||||
Expired bool `json:"expired"`
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
CallbackURL string `json:"callback_url"`
|
||||
Configured bool `json:"configured"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
HasToken bool `json:"has_token"`
|
||||
Expired bool `json:"expired"`
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
CallbackURL string `json:"callback_url"`
|
||||
Device *OAuthDeviceStatus `json:"device,omitempty"`
|
||||
Account *OAuthAccount `json:"account,omitempty"`
|
||||
}
|
||||
|
||||
type OAuthDeviceStatus struct {
|
||||
Pending bool `json:"pending"`
|
||||
UserCode string `json:"user_code,omitempty"`
|
||||
VerificationURI string `json:"verification_uri,omitempty"`
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
IntervalSeconds int64 `json:"interval_seconds,omitempty"`
|
||||
}
|
||||
|
||||
type OAuthAccount struct {
|
||||
Label string `json:"label,omitempty"`
|
||||
Login string `json:"login,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
AvatarURL string `json:"avatar_url,omitempty"`
|
||||
ProfileURL string `json:"profile_url,omitempty"`
|
||||
}
|
||||
|
||||
type OAuthAuthorizeResponse struct {
|
||||
Mode string `json:"mode,omitempty"`
|
||||
AuthURL string `json:"auth_url,omitempty"`
|
||||
Device *OAuthDeviceStatus `json:"device,omitempty"`
|
||||
}
|
||||
|
||||
// RemoteModel represents a model returned by the provider's /v1/models endpoint.
|
||||
|
||||
Reference in New Issue
Block a user