feat(provider): add github copilot device flow provider (#364)

This commit is contained in:
LiBr
2026-04-13 19:38:33 +08:00
committed by GitHub
parent a40207ab6d
commit df8fbd8859
36 changed files with 2659 additions and 246 deletions
+3 -1
View File
@@ -28,6 +28,7 @@ import (
messagepkg "github.com/memohai/memoh/internal/message"
messageevent "github.com/memohai/memoh/internal/message/event"
"github.com/memohai/memoh/internal/models"
"github.com/memohai/memoh/internal/oauthctx"
pipelinepkg "github.com/memohai/memoh/internal/pipeline"
"github.com/memohai/memoh/internal/providers"
"github.com/memohai/memoh/internal/settings"
@@ -499,7 +500,8 @@ func (r *Resolver) buildBaseRunConfig(ctx context.Context, p baseRunConfigParams
}
authResolver := providers.NewService(nil, r.queries, "")
creds, err := authResolver.ResolveModelCredentials(ctx, provider)
authCtx := oauthctx.WithUserID(ctx, p.UserID)
creds, err := authResolver.ResolveModelCredentials(authCtx, provider)
if err != nil {
return agentpkg.RunConfig{}, models.GetResponse{}, sqlc.Provider{}, fmt.Errorf("resolve provider credentials: %w", err)
}
@@ -7,6 +7,7 @@ import (
"github.com/memohai/memoh/internal/compaction"
"github.com/memohai/memoh/internal/conversation"
"github.com/memohai/memoh/internal/models"
"github.com/memohai/memoh/internal/oauthctx"
"github.com/memohai/memoh/internal/providers"
"github.com/memohai/memoh/internal/settings"
)
@@ -115,7 +116,8 @@ func (r *Resolver) buildCompactionConfig(ctx context.Context, req conversation.C
return compaction.TriggerConfig{}, err
}
authResolver := providers.NewService(nil, r.queries, "")
creds, err := authResolver.ResolveModelCredentials(ctx, compactProvider)
authCtx := oauthctx.WithUserID(ctx, req.UserID)
creds, err := authResolver.ResolveModelCredentials(authCtx, compactProvider)
if err != nil {
return compaction.TriggerConfig{}, err
}
@@ -137,7 +139,6 @@ func (r *Resolver) buildCompactionConfig(ctx context.Context, req conversation.C
if compactModel.Config.ContextWindow != nil && *compactModel.Config.ContextWindow > 0 {
cfg.MaxCompactTokens = *compactModel.Config.ContextWindow * 90 / 100
}
// For sync compaction: keep only the last few messages (~2000 tokens ≈ 3 messages).
// The summary provides reference context; if the LLM needs details,
// it will use tools (memory_read, search) to look them up.
+5 -3
View File
@@ -13,6 +13,7 @@ import (
"github.com/memohai/memoh/internal/db/sqlc"
messageevent "github.com/memohai/memoh/internal/message/event"
"github.com/memohai/memoh/internal/models"
"github.com/memohai/memoh/internal/oauthctx"
"github.com/memohai/memoh/internal/providers"
"github.com/memohai/memoh/internal/session"
)
@@ -82,7 +83,7 @@ func (r *Resolver) maybeGenerateSessionTitle(ctx context.Context, req conversati
return
}
title := r.generateTitle(ctx, titleModel, provider, userQuery)
title := r.generateTitle(ctx, req.UserID, titleModel, provider, userQuery)
if title == "" {
return
}
@@ -95,7 +96,7 @@ func (r *Resolver) maybeGenerateSessionTitle(ctx context.Context, req conversati
}
}
func (r *Resolver) generateTitle(ctx context.Context, model models.GetResponse, provider sqlc.Provider, userQuery string) string {
func (r *Resolver) generateTitle(ctx context.Context, userID string, model models.GetResponse, provider sqlc.Provider, userQuery string) string {
userSnippet := truncate(strings.TrimSpace(userQuery), titlePromptMaxInputChars)
if userSnippet == "" {
return ""
@@ -106,7 +107,8 @@ func (r *Resolver) generateTitle(ctx context.Context, model models.GetResponse,
"User: " + userSnippet
authResolver := providers.NewService(nil, r.queries, "")
creds, err := authResolver.ResolveModelCredentials(ctx, provider)
authCtx := oauthctx.WithUserID(ctx, userID)
creds, err := authResolver.ResolveModelCredentials(authCtx, provider)
if err != nil {
r.logger.Warn("title gen: failed to resolve provider credentials", slog.Any("error", err))
return ""
+176
View File
@@ -0,0 +1,176 @@
package copilot
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"sync"
"time"
)
const (
GitHubOAuthClientID = "Iv1.b507a08c87ecfe98"
GitHubOAuthScope = "read:user user:email"
DefaultAPIBaseURL = "https://api.githubcopilot.com"
copilotTokenURL = "https://api.github.com/copilot_internal/v2/token" //nolint:gosec // Fixed GitHub API endpoint, not a credential.
copilotEditorVersion = "vscode/1.110.1"
copilotPluginVersion = "copilot-chat/0.38.2"
copilotUserAgent = "GitHubCopilotChat/0.38.2"
copilotAPIVersion = "2025-10-01"
copilotIntegrationID = "vscode-chat"
copilotTokenRefreshSkew = time.Minute
defaultHTTPClientTimeout = 15 * time.Second
)
type cachedToken struct {
Token string
ExpiresAt time.Time
}
var tokenCache = struct {
mu sync.Mutex
entries map[string]cachedToken
}{
entries: map[string]cachedToken{},
}
type tokenResponse struct {
Token string `json:"token"`
ExpiresAt int64 `json:"expires_at"`
}
func ResolveToken(ctx context.Context, githubToken string) (string, error) {
githubToken = strings.TrimSpace(githubToken)
if githubToken == "" {
return "", errors.New("github token is required")
}
if token, ok := loadCachedToken(githubToken); ok {
return token, nil
}
token, expiresAt, err := FetchCopilotToken(ctx, githubToken)
if err != nil {
return "", err
}
storeCachedToken(githubToken, token, expiresAt)
return token, nil
}
func FetchCopilotToken(ctx context.Context, githubToken string) (string, time.Time, error) {
githubToken = strings.TrimSpace(githubToken)
if githubToken == "" {
return "", time.Time{}, errors.New("github token is required")
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, copilotTokenURL, nil)
if err != nil {
return "", time.Time{}, fmt.Errorf("create copilot token request: %w", err)
}
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "token "+githubToken)
req.Header.Set("Editor-Version", copilotEditorVersion)
req.Header.Set("Editor-Plugin-Version", copilotPluginVersion)
req.Header.Set("User-Agent", copilotUserAgent)
req.Header.Set("X-GitHub-Api-Version", copilotAPIVersion)
resp, err := defaultHTTPClient(nil).Do(req) //nolint:gosec // Request targets a fixed GitHub API endpoint.
if err != nil {
return "", time.Time{}, fmt.Errorf("fetch copilot token: %w", err)
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", time.Time{}, fmt.Errorf("read copilot token response: %w", err)
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return "", time.Time{}, fmt.Errorf("copilot token request failed: %s", strings.TrimSpace(string(body)))
}
var parsed tokenResponse
if err := json.Unmarshal(body, &parsed); err != nil {
return "", time.Time{}, fmt.Errorf("decode copilot token response: %w", err)
}
if strings.TrimSpace(parsed.Token) == "" {
return "", time.Time{}, errors.New("copilot token response did not include a token")
}
var expiresAt time.Time
if parsed.ExpiresAt > 0 {
expiresAt = time.Unix(parsed.ExpiresAt, 0).UTC()
}
return parsed.Token, expiresAt, nil
}
func NewHTTPClient(base *http.Client) *http.Client {
client := defaultHTTPClient(base)
client.Transport = &headerRoundTripper{
base: client.Transport,
headers: map[string]string{
"Copilot-Integration-Id": copilotIntegrationID,
"Editor-Version": copilotEditorVersion,
"Editor-Plugin-Version": copilotPluginVersion,
"User-Agent": copilotUserAgent,
"X-GitHub-Api-Version": copilotAPIVersion,
},
}
return client
}
type headerRoundTripper struct {
base http.RoundTripper
headers map[string]string
}
func (rt *headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
clone := req.Clone(req.Context())
clone.Header = req.Header.Clone()
for key, value := range rt.headers {
clone.Header.Set(key, value)
}
if rt.base == nil {
rt.base = http.DefaultTransport
}
return rt.base.RoundTrip(clone)
}
func loadCachedToken(githubToken string) (string, bool) {
tokenCache.mu.Lock()
defer tokenCache.mu.Unlock()
entry, ok := tokenCache.entries[githubToken]
if !ok {
return "", false
}
if !entry.ExpiresAt.IsZero() && !time.Now().Add(copilotTokenRefreshSkew).Before(entry.ExpiresAt) {
delete(tokenCache.entries, githubToken)
return "", false
}
return entry.Token, true
}
func storeCachedToken(githubToken, token string, expiresAt time.Time) {
tokenCache.mu.Lock()
defer tokenCache.mu.Unlock()
tokenCache.entries[githubToken] = cachedToken{
Token: token,
ExpiresAt: expiresAt,
}
}
func defaultHTTPClient(base *http.Client) *http.Client {
if base != nil {
clone := *base
if clone.Timeout == 0 {
clone.Timeout = defaultHTTPClientTimeout
}
return &clone
}
return &http.Client{Timeout: defaultHTTPClientTimeout}
}
+80
View File
@@ -0,0 +1,80 @@
package copilot
import (
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
type roundTripperFunc func(*http.Request) (*http.Response, error)
func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}
func TestNewHTTPClientAddsCopilotHeaders(t *testing.T) {
t.Parallel()
base := &http.Client{
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
if got := req.Header.Get("Copilot-Integration-Id"); got != copilotIntegrationID {
t.Fatalf("expected integration id %q, got %q", copilotIntegrationID, got)
}
if got := req.Header.Get("Editor-Version"); got != copilotEditorVersion {
t.Fatalf("expected editor version %q, got %q", copilotEditorVersion, got)
}
if got := req.Header.Get("Editor-Plugin-Version"); got != copilotPluginVersion {
t.Fatalf("expected plugin version %q, got %q", copilotPluginVersion, got)
}
if got := req.Header.Get("User-Agent"); got != copilotUserAgent {
t.Fatalf("expected user agent %q, got %q", copilotUserAgent, got)
}
if got := req.Header.Get("X-GitHub-Api-Version"); got != copilotAPIVersion {
t.Fatalf("expected api version %q, got %q", copilotAPIVersion, got)
}
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(`ok`)),
Header: make(http.Header),
}, nil
}),
}
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://api.githubcopilot.com/chat/completions", nil)
if err != nil {
t.Fatalf("create request: %v", err)
}
resp, err := NewHTTPClient(base).Do(req) //nolint:gosec // Test request targets a fixed Copilot API endpoint.
if err != nil {
t.Fatalf("execute request: %v", err)
}
_ = resp.Body.Close()
}
func TestNewHTTPClientWithNilBaseDoesNotPanic(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if got := req.Header.Get("Copilot-Integration-Id"); got != copilotIntegrationID {
t.Fatalf("expected integration id %q, got %q", copilotIntegrationID, got)
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`ok`))
}))
defer server.Close()
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, nil)
if err != nil {
t.Fatalf("create request: %v", err)
}
resp, err := NewHTTPClient(nil).Do(req) //nolint:gosec // Test request targets an httptest server URL.
if err != nil {
t.Fatalf("execute request with nil base client: %v", err)
}
_ = resp.Body.Close()
}
+30
View File
@@ -0,0 +1,30 @@
package copilot
import (
"net/http"
"strings"
githubcopilot "github.com/memohai/twilight-ai/provider/github/copilot"
sdk "github.com/memohai/twilight-ai/sdk"
)
func NewProvider(copilotToken string, baseClient *http.Client) sdk.Provider {
options := []githubcopilot.Option{
githubcopilot.WithGitHubToken(strings.TrimSpace(copilotToken)),
githubcopilot.WithBaseURL(DefaultAPIBaseURL),
githubcopilot.WithHTTPClient(NewHTTPClient(baseClient)),
}
return githubcopilot.New(options...)
}
func NewModel(copilotToken, modelID string, baseClient *http.Client) *sdk.Model {
options := []githubcopilot.Option{
githubcopilot.WithGitHubToken(strings.TrimSpace(copilotToken)),
githubcopilot.WithBaseURL(DefaultAPIBaseURL),
githubcopilot.WithHTTPClient(NewHTTPClient(baseClient)),
}
if strings.TrimSpace(modelID) == "" {
modelID = githubcopilot.AutoModel
}
return githubcopilot.New(options...).ChatModel(modelID)
}
+16
View File
@@ -515,3 +515,19 @@ type UserChannelBinding struct {
CreatedAt pgtype.Timestamptz `json:"created_at"`
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
}
type UserProviderOauthToken struct {
ID pgtype.UUID `json:"id"`
ProviderID pgtype.UUID `json:"provider_id"`
UserID pgtype.UUID `json:"user_id"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
Scope string `json:"scope"`
TokenType string `json:"token_type"`
State string `json:"state"`
PkceCodeVerifier string `json:"pkce_code_verifier"`
Metadata []byte `json:"metadata"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
}
+205
View File
@@ -0,0 +1,205 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: user_provider_oauth.sql
package sqlc
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
const deleteUserProviderOAuthToken = `-- name: DeleteUserProviderOAuthToken :exec
DELETE FROM user_provider_oauth_tokens
WHERE provider_id = $1
AND user_id = $2
`
type DeleteUserProviderOAuthTokenParams struct {
ProviderID pgtype.UUID `json:"provider_id"`
UserID pgtype.UUID `json:"user_id"`
}
func (q *Queries) DeleteUserProviderOAuthToken(ctx context.Context, arg DeleteUserProviderOAuthTokenParams) error {
_, err := q.db.Exec(ctx, deleteUserProviderOAuthToken, arg.ProviderID, arg.UserID)
return err
}
const getUserProviderOAuthToken = `-- name: GetUserProviderOAuthToken :one
SELECT id, provider_id, user_id, access_token, refresh_token, expires_at, scope, token_type, state, pkce_code_verifier, metadata, created_at, updated_at FROM user_provider_oauth_tokens
WHERE provider_id = $1
AND user_id = $2
`
type GetUserProviderOAuthTokenParams struct {
ProviderID pgtype.UUID `json:"provider_id"`
UserID pgtype.UUID `json:"user_id"`
}
func (q *Queries) GetUserProviderOAuthToken(ctx context.Context, arg GetUserProviderOAuthTokenParams) (UserProviderOauthToken, error) {
row := q.db.QueryRow(ctx, getUserProviderOAuthToken, arg.ProviderID, arg.UserID)
var i UserProviderOauthToken
err := row.Scan(
&i.ID,
&i.ProviderID,
&i.UserID,
&i.AccessToken,
&i.RefreshToken,
&i.ExpiresAt,
&i.Scope,
&i.TokenType,
&i.State,
&i.PkceCodeVerifier,
&i.Metadata,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const getUserProviderOAuthTokenByState = `-- name: GetUserProviderOAuthTokenByState :one
SELECT id, provider_id, user_id, access_token, refresh_token, expires_at, scope, token_type, state, pkce_code_verifier, metadata, created_at, updated_at FROM user_provider_oauth_tokens
WHERE state = $1
AND state != ''
`
func (q *Queries) GetUserProviderOAuthTokenByState(ctx context.Context, state string) (UserProviderOauthToken, error) {
row := q.db.QueryRow(ctx, getUserProviderOAuthTokenByState, state)
var i UserProviderOauthToken
err := row.Scan(
&i.ID,
&i.ProviderID,
&i.UserID,
&i.AccessToken,
&i.RefreshToken,
&i.ExpiresAt,
&i.Scope,
&i.TokenType,
&i.State,
&i.PkceCodeVerifier,
&i.Metadata,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const updateUserProviderOAuthState = `-- name: UpdateUserProviderOAuthState :exec
INSERT INTO user_provider_oauth_tokens (provider_id, user_id, state, pkce_code_verifier, metadata)
VALUES (
$1,
$2,
$3,
$4,
$5
)
ON CONFLICT (provider_id, user_id) DO UPDATE SET
state = EXCLUDED.state,
pkce_code_verifier = EXCLUDED.pkce_code_verifier,
metadata = EXCLUDED.metadata,
updated_at = now()
`
type UpdateUserProviderOAuthStateParams struct {
ProviderID pgtype.UUID `json:"provider_id"`
UserID pgtype.UUID `json:"user_id"`
State string `json:"state"`
PkceCodeVerifier string `json:"pkce_code_verifier"`
Metadata []byte `json:"metadata"`
}
func (q *Queries) UpdateUserProviderOAuthState(ctx context.Context, arg UpdateUserProviderOAuthStateParams) error {
_, err := q.db.Exec(ctx, updateUserProviderOAuthState,
arg.ProviderID,
arg.UserID,
arg.State,
arg.PkceCodeVerifier,
arg.Metadata,
)
return err
}
const upsertUserProviderOAuthToken = `-- name: UpsertUserProviderOAuthToken :one
INSERT INTO user_provider_oauth_tokens (
provider_id,
user_id,
access_token,
refresh_token,
expires_at,
scope,
token_type,
state,
pkce_code_verifier,
metadata
)
VALUES (
$1,
$2,
$3,
$4,
$5,
$6,
$7,
$8,
$9,
$10
)
ON CONFLICT (provider_id, user_id) DO UPDATE SET
access_token = EXCLUDED.access_token,
refresh_token = EXCLUDED.refresh_token,
expires_at = EXCLUDED.expires_at,
scope = EXCLUDED.scope,
token_type = EXCLUDED.token_type,
state = EXCLUDED.state,
pkce_code_verifier = EXCLUDED.pkce_code_verifier,
metadata = EXCLUDED.metadata,
updated_at = now()
RETURNING id, provider_id, user_id, access_token, refresh_token, expires_at, scope, token_type, state, pkce_code_verifier, metadata, created_at, updated_at
`
type UpsertUserProviderOAuthTokenParams struct {
ProviderID pgtype.UUID `json:"provider_id"`
UserID pgtype.UUID `json:"user_id"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
Scope string `json:"scope"`
TokenType string `json:"token_type"`
State string `json:"state"`
PkceCodeVerifier string `json:"pkce_code_verifier"`
Metadata []byte `json:"metadata"`
}
func (q *Queries) UpsertUserProviderOAuthToken(ctx context.Context, arg UpsertUserProviderOAuthTokenParams) (UserProviderOauthToken, error) {
row := q.db.QueryRow(ctx, upsertUserProviderOAuthToken,
arg.ProviderID,
arg.UserID,
arg.AccessToken,
arg.RefreshToken,
arg.ExpiresAt,
arg.Scope,
arg.TokenType,
arg.State,
arg.PkceCodeVerifier,
arg.Metadata,
)
var i UserProviderOauthToken
err := row.Scan(
&i.ID,
&i.ProviderID,
&i.UserID,
&i.AccessToken,
&i.RefreshToken,
&i.ExpiresAt,
&i.Scope,
&i.TokenType,
&i.State,
&i.PkceCodeVerifier,
&i.Metadata,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
+8 -1
View File
@@ -10,7 +10,9 @@ import (
"github.com/jackc/pgx/v5"
"github.com/labstack/echo/v4"
"github.com/memohai/memoh/internal/auth"
"github.com/memohai/memoh/internal/models"
"github.com/memohai/memoh/internal/oauthctx"
)
type ModelsHandler struct {
@@ -301,7 +303,12 @@ func (h *ModelsHandler) Test(c echo.Context) error {
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
}
resp, err := h.service.Test(c.Request().Context(), id)
ctx := c.Request().Context()
if userID, err := auth.UserIDFromContext(c); err == nil {
ctx = oauthctx.WithUserID(ctx, userID)
}
resp, err := h.service.Test(ctx, id)
if err != nil {
if strings.Contains(err.Error(), "invalid") {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
+47 -8
View File
@@ -7,6 +7,8 @@ import (
"github.com/labstack/echo/v4"
"github.com/memohai/memoh/internal/auth"
"github.com/memohai/memoh/internal/oauthctx"
"github.com/memohai/memoh/internal/providers"
)
@@ -20,6 +22,7 @@ func NewProviderOAuthHandler(service *providers.Service) *ProviderOAuthHandler {
func (h *ProviderOAuthHandler) Register(e *echo.Echo) {
e.GET("/providers/:id/oauth/authorize", h.Authorize)
e.POST("/providers/:id/oauth/poll", h.Poll)
e.GET("/providers/:id/oauth/status", h.Status)
e.DELETE("/providers/:id/oauth/token", h.Revoke)
e.GET("/auth/callback", h.Callback)
@@ -30,7 +33,7 @@ func (h *ProviderOAuthHandler) Register(e *echo.Echo) {
// @Summary Start OAuth2 authorization for an LLM provider
// @Tags providers-oauth
// @Param id path string true "Provider ID (UUID)"
// @Success 200 {object} map[string]string
// @Success 200 {object} providers.OAuthAuthorizeResponse
// @Failure 400 {object} ErrorResponse
// @Failure 404 {object} ErrorResponse
// @Router /providers/{id}/oauth/authorize [get].
@@ -39,11 +42,39 @@ func (h *ProviderOAuthHandler) Authorize(c echo.Context) error {
if providerID == "" {
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
}
authURL, err := h.service.StartOAuthAuthorization(c.Request().Context(), providerID)
ctx := c.Request().Context()
if userID, err := auth.UserIDFromContext(c); err == nil {
ctx = oauthctx.WithUserID(ctx, userID)
}
resp, err := h.service.StartOAuthAuthorization(ctx, providerID)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
return c.JSON(http.StatusOK, map[string]string{"auth_url": authURL})
return c.JSON(http.StatusOK, resp)
}
// Poll godoc
// @Summary Poll OAuth device authorization for an LLM provider
// @Tags providers-oauth
// @Param id path string true "Provider ID (UUID)"
// @Success 200 {object} providers.OAuthStatus
// @Failure 400 {object} ErrorResponse
// @Failure 404 {object} ErrorResponse
// @Router /providers/{id}/oauth/poll [post].
func (h *ProviderOAuthHandler) Poll(c echo.Context) error {
providerID := strings.TrimSpace(c.Param("id"))
if providerID == "" {
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
}
ctx := c.Request().Context()
if userID, err := auth.UserIDFromContext(c); err == nil {
ctx = oauthctx.WithUserID(ctx, userID)
}
status, err := h.service.PollOAuthAuthorization(ctx, providerID)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
return c.JSON(http.StatusOK, status)
}
// Status godoc
@@ -59,7 +90,11 @@ func (h *ProviderOAuthHandler) Status(c echo.Context) error {
if providerID == "" {
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
}
status, err := h.service.GetOAuthStatus(c.Request().Context(), providerID)
ctx := c.Request().Context()
if userID, err := auth.UserIDFromContext(c); err == nil {
ctx = oauthctx.WithUserID(ctx, userID)
}
status, err := h.service.GetOAuthStatus(ctx, providerID)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
@@ -79,7 +114,11 @@ func (h *ProviderOAuthHandler) Revoke(c echo.Context) error {
if providerID == "" {
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
}
if err := h.service.RevokeOAuthToken(c.Request().Context(), providerID); err != nil {
ctx := c.Request().Context()
if userID, err := auth.UserIDFromContext(c); err == nil {
ctx = oauthctx.WithUserID(ctx, userID)
}
if err := h.service.RevokeOAuthToken(ctx, providerID); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
return c.NoContent(http.StatusNoContent)
@@ -111,11 +150,11 @@ func (h *ProviderOAuthHandler) Callback(c echo.Context) error {
<html>
<head>
<meta charset="utf-8">
<title>OpenAI OAuth Connected</title>
<title>Provider Connected</title>
</head>
<body style="font-family: sans-serif; padding: 24px;">
<h2>OpenAI OAuth connected</h2>
<p>You can close this window and return to Memoh.</p>
<h2>Provider connected</h2>
<p>Your current Memoh account is now connected.</p>
<script>
window.opener?.postMessage({ type: "memoh-provider-oauth-success", providerId: "{{.ProviderID}}" }, "*");
window.close();
+14 -2
View File
@@ -9,7 +9,9 @@ import (
"github.com/labstack/echo/v4"
"github.com/memohai/memoh/internal/auth"
"github.com/memohai/memoh/internal/models"
"github.com/memohai/memoh/internal/oauthctx"
"github.com/memohai/memoh/internal/providers"
)
@@ -272,7 +274,12 @@ func (h *ProvidersHandler) Test(c echo.Context) error {
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
}
resp, err := h.service.Test(c.Request().Context(), id)
ctx := c.Request().Context()
if userID, err := auth.UserIDFromContext(c); err == nil {
ctx = oauthctx.WithUserID(ctx, userID)
}
resp, err := h.service.Test(ctx, id)
if err != nil {
if strings.Contains(err.Error(), "invalid") {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
@@ -301,7 +308,12 @@ func (h *ProvidersHandler) ImportModels(c echo.Context) error {
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
}
remoteModels, err := h.service.FetchRemoteModels(c.Request().Context(), id)
ctx := c.Request().Context()
if userID, err := auth.UserIDFromContext(c); err == nil {
ctx = oauthctx.WithUserID(ctx, userID)
}
remoteModels, err := h.service.FetchRemoteModels(ctx, id)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("fetch remote models: %v", err))
}
@@ -10,6 +10,7 @@ import (
"github.com/memohai/memoh/internal/healthcheck"
"github.com/memohai/memoh/internal/models"
"github.com/memohai/memoh/internal/oauthctx"
)
const (
@@ -25,6 +26,7 @@ type BotModelLookup interface {
// BotModels holds the model UUIDs associated with a bot.
type BotModels struct {
OwnerUserID string
ChatModelID string
MemoryModelID string
EmbeddingModelID string
@@ -115,7 +117,8 @@ func (c *Checker) ListChecks(ctx context.Context, botID string) []healthcheck.Ch
wg.Add(1)
go func(idx int, s modelSlot) {
defer wg.Done()
results[idx] = c.probeSlot(probeCtx, s)
slotCtx := oauthctx.WithUserID(probeCtx, botModels.OwnerUserID)
results[idx] = c.probeSlot(slotCtx, s)
}(i, slot)
}
wg.Wait()
@@ -0,0 +1,59 @@
package modelchecker
import (
"context"
"testing"
"github.com/memohai/memoh/internal/healthcheck"
"github.com/memohai/memoh/internal/models"
"github.com/memohai/memoh/internal/oauthctx"
)
type testLookup struct {
models BotModels
}
func (l testLookup) GetBotModelIDs(context.Context, string) (BotModels, error) {
return l.models, nil
}
type testProber struct {
t *testing.T
wantUserID string
}
func (p testProber) Test(ctx context.Context, id string) (models.TestResponse, error) {
if got := oauthctx.UserIDFromContext(ctx); got != p.wantUserID {
p.t.Fatalf("expected oauth user id %q, got %q", p.wantUserID, got)
}
if id != "model-chat-1" {
p.t.Fatalf("expected model id %q, got %q", "model-chat-1", id)
}
return models.TestResponse{
Status: models.TestStatusOK,
Reachable: true,
Message: "ok",
}, nil
}
func TestCheckerListChecksInjectsOwnerUserIDIntoProbeContext(t *testing.T) {
t.Parallel()
checker := NewChecker(nil, testLookup{
models: BotModels{
OwnerUserID: "user-123",
ChatModelID: "model-chat-1",
},
}, testProber{
t: t,
wantUserID: "user-123",
})
items := checker.ListChecks(context.Background(), "bot-1")
if len(items) != 1 {
t.Fatalf("expected 1 check result, got %d", len(items))
}
if items[0].Status != healthcheck.StatusOK {
t.Fatalf("expected status %q, got %q", healthcheck.StatusOK, items[0].Status)
}
}
@@ -36,6 +36,7 @@ func (l *QueriesLookup) GetBotModelIDs(ctx context.Context, botID string) (BotMo
}
var m BotModels
m.OwnerUserID = bot.OwnerUserID.String()
if bot.ChatModelID.Valid {
m.ChatModelID = bot.ChatModelID.String()
}
+1
View File
@@ -429,6 +429,7 @@ func IsValidClientType(clientType ClientType) bool {
ClientTypeAnthropicMessages,
ClientTypeGoogleGenerativeAI,
ClientTypeOpenAICodex,
ClientTypeGitHubCopilot,
ClientTypeEdgeSpeech:
return true
default:
+53 -12
View File
@@ -17,8 +17,10 @@ import (
openairesponses "github.com/memohai/twilight-ai/provider/openai/responses"
sdk "github.com/memohai/twilight-ai/sdk"
memohcopilot "github.com/memohai/memoh/internal/copilot"
"github.com/memohai/memoh/internal/db"
"github.com/memohai/memoh/internal/db/sqlc"
"github.com/memohai/memoh/internal/oauthctx"
)
const probeTimeout = 15 * time.Second
@@ -162,6 +164,9 @@ func NewSDKProvider(baseURL, apiKey, codexAccountID string, clientType ClientTyp
}
return openaicodex.New(opts...)
case ClientTypeGitHubCopilot:
return memohcopilot.NewProvider(apiKey, httpClient)
case ClientTypeAnthropicMessages:
opts := []anthropicmessages.Option{
anthropicmessages.WithAPIKey(apiKey),
@@ -202,26 +207,62 @@ type modelCredentials struct {
func (s *Service) resolveModelCredentials(ctx context.Context, provider sqlc.Provider) (modelCredentials, error) {
apiKey := providerConfigString(provider.Config, "api_key")
if ClientType(provider.ClientType) != ClientTypeOpenAICodex {
switch ClientType(provider.ClientType) {
case ClientTypeGitHubCopilot:
token, err := s.resolveGitHubCopilotAccessToken(ctx, provider)
if err != nil {
return modelCredentials{}, err
}
return modelCredentials{APIKey: token}, nil
case ClientTypeOpenAICodex:
tokenRow, err := s.queries.GetProviderOAuthTokenByProvider(ctx, provider.ID)
if err != nil {
return modelCredentials{}, err
}
accessToken := strings.TrimSpace(tokenRow.AccessToken)
if accessToken == "" {
return modelCredentials{}, errors.New("oauth token is missing access token")
}
accountID, err := codexAccountIDFromToken(accessToken)
if err != nil {
return modelCredentials{}, err
}
return modelCredentials{
APIKey: accessToken,
CodexAccountID: accountID,
}, nil
default:
return modelCredentials{APIKey: apiKey}, nil
}
}
tokenRow, err := s.queries.GetProviderOAuthTokenByProvider(ctx, provider.ID)
if err != nil {
return modelCredentials{}, err
func (s *Service) resolveGitHubCopilotAccessToken(ctx context.Context, provider sqlc.Provider) (string, error) {
userID := oauthctx.UserIDFromContext(ctx)
if userID == "" {
return "", errors.New("github copilot requires a current user")
}
accessToken := strings.TrimSpace(tokenRow.AccessToken)
userUUID, err := db.ParseUUID(userID)
if err != nil {
return "", err
}
row, err := s.queries.GetUserProviderOAuthToken(ctx, sqlc.GetUserProviderOAuthTokenParams{
ProviderID: provider.ID,
UserID: userUUID,
})
if err != nil {
return "", err
}
accessToken := strings.TrimSpace(row.AccessToken)
if accessToken == "" {
return modelCredentials{}, errors.New("oauth token is missing access token")
return "", errors.New("oauth token is missing access token")
}
accountID, err := codexAccountIDFromToken(accessToken)
copilotToken, err := memohcopilot.ResolveToken(ctx, accessToken)
if err != nil {
return modelCredentials{}, err
return "", err
}
return modelCredentials{
APIKey: accessToken,
CodexAccountID: accountID,
}, nil
return copilotToken, nil
}
func codexAccountIDFromToken(token string) (string, error) {
+7
View File
@@ -10,6 +10,8 @@ import (
openaicompletions "github.com/memohai/twilight-ai/provider/openai/completions"
openairesponses "github.com/memohai/twilight-ai/provider/openai/responses"
sdk "github.com/memohai/twilight-ai/sdk"
memohcopilot "github.com/memohai/memoh/internal/copilot"
)
// SDKModelConfig holds provider and model information resolved from DB,
@@ -76,6 +78,9 @@ func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model {
}
return openaicodex.New(opts...).ChatModel(cfg.ModelID)
case ClientTypeGitHubCopilot:
return memohcopilot.NewModel(cfg.APIKey, cfg.ModelID, cfg.HTTPClient)
case ClientTypeAnthropicMessages:
opts := []anthropicmessages.Option{
anthropicmessages.WithAPIKey(cfg.APIKey),
@@ -178,6 +183,8 @@ func ResolveClientType(model *sdk.Model) string {
return string(ClientTypeAnthropicMessages)
case strings.Contains(name, "google"):
return string(ClientTypeGoogleGenerativeAI)
case strings.Contains(name, "github-copilot"), strings.Contains(name, "copilot"):
return string(ClientTypeGitHubCopilot)
case strings.Contains(name, "codex"):
return string(ClientTypeOpenAICodex)
case strings.Contains(name, "responses"):
+1
View File
@@ -22,6 +22,7 @@ const (
ClientTypeAnthropicMessages ClientType = "anthropic-messages"
ClientTypeGoogleGenerativeAI ClientType = "google-generative-ai"
ClientTypeOpenAICodex ClientType = "openai-codex"
ClientTypeGitHubCopilot ClientType = "github-copilot"
ClientTypeEdgeSpeech ClientType = "edge-speech"
)
+21
View File
@@ -0,0 +1,21 @@
package oauthctx
import (
"context"
"strings"
)
type userIDContextKey struct{}
func WithUserID(ctx context.Context, userID string) context.Context {
userID = strings.TrimSpace(userID)
if userID == "" {
return ctx
}
return context.WithValue(ctx, userIDContextKey{}, userID)
}
func UserIDFromContext(ctx context.Context) string {
userID, _ := ctx.Value(userIDContextKey{}).(string)
return strings.TrimSpace(userID)
}
+28 -14
View File
@@ -8,6 +8,7 @@ import (
"fmt"
"strings"
memohcopilot "github.com/memohai/memoh/internal/copilot"
"github.com/memohai/memoh/internal/db/sqlc"
"github.com/memohai/memoh/internal/models"
)
@@ -24,25 +25,38 @@ func SupportsOpenAICodexOAuth(provider sqlc.Provider) bool {
}
func (s *Service) ResolveModelCredentials(ctx context.Context, provider sqlc.Provider) (ModelCredentials, error) {
if models.ClientType(provider.ClientType) != models.ClientTypeOpenAICodex {
switch models.ClientType(provider.ClientType) {
case models.ClientTypeGitHubCopilot:
githubToken, err := s.GetValidAccessToken(ctx, provider.ID.String())
if err != nil {
return ModelCredentials{}, err
}
copilotToken, err := memohcopilot.ResolveToken(ctx, githubToken)
if err != nil {
return ModelCredentials{}, err
}
return ModelCredentials{APIKey: copilotToken}, nil
case models.ClientTypeOpenAICodex:
token, err := s.GetValidAccessToken(ctx, provider.ID.String())
if err != nil {
return ModelCredentials{}, err
}
accountID, err := codexAccountIDFromToken(token)
if err != nil {
return ModelCredentials{}, err
}
return ModelCredentials{
APIKey: token,
CodexAccountID: accountID,
}, nil
default:
apiKey := ProviderConfigString(provider, "api_key")
return ModelCredentials{
APIKey: apiKey,
}, nil
}
token, err := s.GetValidAccessToken(ctx, provider.ID.String())
if err != nil {
return ModelCredentials{}, err
}
accountID, err := codexAccountIDFromToken(token)
if err != nil {
return ModelCredentials{}, err
}
return ModelCredentials{
APIKey: token,
CodexAccountID: accountID,
}, nil
}
func codexAccountIDFromToken(token string) (string, error) {
File diff suppressed because it is too large Load Diff
+81 -16
View File
@@ -11,9 +11,11 @@ import (
"time"
"github.com/jackc/pgx/v5/pgtype"
githubcopilot "github.com/memohai/twilight-ai/provider/github/copilot"
openaicodex "github.com/memohai/twilight-ai/provider/openai/codex"
sdk "github.com/memohai/twilight-ai/sdk"
memohcopilot "github.com/memohai/memoh/internal/copilot"
"github.com/memohai/memoh/internal/db"
"github.com/memohai/memoh/internal/db/sqlc"
"github.com/memohai/memoh/internal/models"
@@ -47,15 +49,14 @@ func (s *Service) Create(ctx context.Context, req CreateRequest) (GetResponse, e
return GetResponse{}, fmt.Errorf("marshal metadata: %w", err)
}
configJSON, err := json.Marshal(req.Config)
if err != nil {
return GetResponse{}, fmt.Errorf("marshal config: %w", err)
}
clientType := req.ClientType
if clientType == "" {
clientType = string(models.ClientTypeOpenAICompletions)
}
configJSON, err := json.Marshal(normalizeProviderConfig(clientType, req.Config))
if err != nil {
return GetResponse{}, fmt.Errorf("marshal config: %w", err)
}
var icon pgtype.Text
if req.Icon != "" {
@@ -150,12 +151,11 @@ func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Get
existingConfig := providerConfig(existing.Config)
if req.Config != nil {
existingAPIKey := configString(existingConfig, "api_key")
newAPIKey := configString(req.Config, "api_key")
if newAPIKey != "" && newAPIKey == maskAPIKey(existingAPIKey) {
req.Config["api_key"] = existingAPIKey
}
existingConfig = req.Config
mergedConfig := mergeProviderConfig(existingConfig, req.Config)
preserveMaskedConfigSecret(mergedConfig, existingConfig, req.Config, "api_key")
existingConfig = normalizeProviderConfig(clientType, mergedConfig)
} else {
existingConfig = normalizeProviderConfig(clientType, existingConfig)
}
configJSON, err := json.Marshal(existingConfig)
if err != nil {
@@ -257,6 +257,34 @@ func (s *Service) FetchRemoteModels(ctx context.Context, id string) ([]RemoteMod
if err != nil {
return nil, fmt.Errorf("get provider: %w", err)
}
if models.ClientType(provider.ClientType) == models.ClientTypeGitHubCopilot {
creds, err := s.ResolveModelCredentials(ctx, provider)
if err != nil {
return nil, err
}
sdkProvider := memohcopilot.NewProvider(creds.APIKey, nil)
if result := sdkProvider.Test(ctx); result.Status != sdk.ProviderStatusOK {
return nil, fmt.Errorf("github copilot provider test failed: %s", result.Message)
}
catalog := githubcopilot.Catalog()
remoteModels := make([]RemoteModel, 0, len(catalog))
for _, model := range catalog {
remoteModels = append(remoteModels, RemoteModel{
ID: model.ID,
Name: model.DisplayName,
Object: "model",
OwnedBy: "github-copilot",
Type: "chat",
Compatibilities: []string{
models.CompatVision,
models.CompatToolCall,
models.CompatReasoning,
},
})
}
return remoteModels, nil
}
if supportsOAuth(provider) {
catalog := openaicodex.Catalog()
remoteModels := make([]RemoteModel, 0, len(catalog))
@@ -329,7 +357,7 @@ func (s *Service) toGetResponse(provider sqlc.Provider) GetResponse {
}
cfg := providerConfig(provider.Config)
maskedCfg := maskConfigAPIKey(cfg)
maskedCfg := maskConfigSecrets(provider.ClientType, cfg)
var icon string
if provider.Icon.Valid {
@@ -378,14 +406,51 @@ func ProviderConfigString(provider sqlc.Provider, key string) string {
return configString(providerConfig(provider.Config), key)
}
// maskConfigAPIKey returns a copy of config with api_key masked.
func maskConfigAPIKey(cfg map[string]any) map[string]any {
func cloneConfig(cfg map[string]any) map[string]any {
result := make(map[string]any, len(cfg))
for k, v := range cfg {
result[k] = v
}
if apiKey, _ := result["api_key"].(string); apiKey != "" {
result["api_key"] = maskAPIKey(apiKey)
return result
}
func mergeProviderConfig(existing, incoming map[string]any) map[string]any {
result := cloneConfig(existing)
for k, v := range incoming {
result[k] = v
}
return result
}
func preserveMaskedConfigSecret(merged, existing, incoming map[string]any, key string) {
existingValue := strings.TrimSpace(configString(existing, key))
newValue := strings.TrimSpace(configString(incoming, key))
if existingValue == "" || newValue == "" {
return
}
if newValue == maskAPIKey(existingValue) {
merged[key] = existingValue
}
}
// normalizeProviderConfig keeps provider-specific secrets under stable keys while
// preserving backward compatibility for legacy stored configs.
func normalizeProviderConfig(clientType string, cfg map[string]any) map[string]any {
result := cloneConfig(cfg)
if models.ClientType(clientType) == models.ClientTypeGitHubCopilot {
delete(result, "api_key")
delete(result, configOAuthClientSecretKey)
}
return result
}
// maskConfigSecrets returns a copy of config with all known secret fields masked.
func maskConfigSecrets(clientType string, cfg map[string]any) map[string]any {
result := normalizeProviderConfig(clientType, cfg)
for _, key := range []string{"api_key", configOAuthClientSecretKey} {
if value, _ := result[key].(string); value != "" {
result[key] = maskAPIKey(value)
}
}
return result
}
+170 -1
View File
@@ -1,6 +1,12 @@
package providers
import "testing"
import (
"testing"
"time"
"github.com/memohai/memoh/internal/db/sqlc"
"github.com/memohai/memoh/internal/models"
)
func TestMaskAPIKey(t *testing.T) {
t.Parallel()
@@ -34,3 +40,166 @@ func TestMaskAPIKey(t *testing.T) {
}
})
}
func TestNormalizeProviderConfig(t *testing.T) {
t.Parallel()
t.Run("github copilot drops legacy secrets", func(t *testing.T) {
t.Parallel()
cfg := normalizeProviderConfig("github-copilot", map[string]any{
"api_key": "gh-secret",
configOAuthClientSecretKey: "oauth-secret",
"base_url": "ignored",
})
if _, exists := cfg[configOAuthClientSecretKey]; exists {
t.Fatalf("expected oauth client secret to be removed, got %#v", cfg[configOAuthClientSecretKey])
}
if _, exists := cfg["api_key"]; exists {
t.Fatalf("expected legacy api_key to be removed, got %#v", cfg["api_key"])
}
})
t.Run("non copilot providers keep api key key", func(t *testing.T) {
t.Parallel()
cfg := normalizeProviderConfig("openai-completions", map[string]any{
"api_key": "sk-live",
})
if got, ok := cfg["api_key"].(string); !ok || got != "sk-live" {
t.Fatalf("expected api_key to remain untouched, got %#v", cfg["api_key"])
}
})
}
func TestMaskConfigSecrets(t *testing.T) {
t.Parallel()
cfg := maskConfigSecrets("openai-completions", map[string]any{
"api_key": "sk-secret-123456",
})
masked, _ := cfg["api_key"].(string)
if masked == "" || masked == "sk-secret-123456" {
t.Fatalf("expected api key to be masked, got %q", masked)
}
}
func TestPreserveMaskedConfigSecret(t *testing.T) {
t.Parallel()
merged := map[string]any{
configOAuthClientSecretKey: "*************",
}
existing := map[string]any{
configOAuthClientSecretKey: "gh-secret-1234",
}
incoming := map[string]any{
configOAuthClientSecretKey: maskAPIKey("gh-secret-1234"),
}
preserveMaskedConfigSecret(merged, existing, incoming, configOAuthClientSecretKey)
if got, _ := merged[configOAuthClientSecretKey].(string); got != "gh-secret-1234" {
t.Fatalf("expected masked value to be restored to original secret, got %q", got)
}
}
func TestDeviceMetadataRoundTrip(t *testing.T) {
t.Parallel()
expiresAt := time.Date(2026, time.April, 11, 12, 0, 0, 0, time.UTC)
device := oauthDeviceMetadata{
DeviceCode: "device-code",
UserCode: "ABCD-EFGH",
VerificationURI: "https://github.com/login/device",
ExpiresAt: expiresAt,
IntervalSeconds: 5,
}
parsed := deviceMetadataFromMap(device.toMetadata())
if parsed.DeviceCode != device.DeviceCode {
t.Fatalf("expected device code %q, got %q", device.DeviceCode, parsed.DeviceCode)
}
if parsed.UserCode != device.UserCode {
t.Fatalf("expected user code %q, got %q", device.UserCode, parsed.UserCode)
}
if parsed.VerificationURI != device.VerificationURI {
t.Fatalf("expected verification uri %q, got %q", device.VerificationURI, parsed.VerificationURI)
}
if !parsed.ExpiresAt.Equal(expiresAt) {
t.Fatalf("expected expiresAt %s, got %s", expiresAt, parsed.ExpiresAt)
}
if parsed.IntervalSeconds != device.IntervalSeconds {
t.Fatalf("expected interval %d, got %d", device.IntervalSeconds, parsed.IntervalSeconds)
}
status := parsed.toStatus()
if status == nil || !status.Pending {
t.Fatalf("expected pending device status, got %#v", status)
}
}
func TestAccountMetadataRoundTrip(t *testing.T) {
t.Parallel()
account := oauthAccountMetadata{
Label: "octocat",
Login: "octocat",
Name: "The Octocat",
Email: "octocat@github.com",
AvatarURL: "https://avatars.githubusercontent.com/u/1?v=4",
ProfileURL: "https://github.com/octocat",
}
parsed := accountMetadataFromMap(account.toMetadata())
if parsed.Label != account.Label {
t.Fatalf("expected label %q, got %q", account.Label, parsed.Label)
}
if parsed.Login != account.Login {
t.Fatalf("expected login %q, got %q", account.Login, parsed.Login)
}
if parsed.Name != account.Name {
t.Fatalf("expected name %q, got %q", account.Name, parsed.Name)
}
if parsed.Email != account.Email {
t.Fatalf("expected email %q, got %q", account.Email, parsed.Email)
}
if parsed.AvatarURL != account.AvatarURL {
t.Fatalf("expected avatar url %q, got %q", account.AvatarURL, parsed.AvatarURL)
}
if parsed.ProfileURL != account.ProfileURL {
t.Fatalf("expected profile url %q, got %q", account.ProfileURL, parsed.ProfileURL)
}
status := parsed.toStatus()
if status == nil {
t.Fatal("expected account status")
}
if status.Label != account.Label {
t.Fatalf("expected status label %q, got %q", account.Label, status.Label)
}
}
func TestOAuthConfigForGitHubCopilotUsesFixedDeviceFlowSettings(t *testing.T) {
t.Parallel()
service := &Service{}
cfg := service.oauthConfigForProvider(sqlc.Provider{
ClientType: string(models.ClientTypeGitHubCopilot),
Config: []byte(`{"api_key":"legacy","oauth_client_secret":"legacy-secret"}`),
Metadata: []byte(`{"oauth_client_id":"custom","oauth_scopes":"repo"}`),
})
if cfg.ClientID != "Iv1.b507a08c87ecfe98" {
t.Fatalf("expected fixed client id, got %q", cfg.ClientID)
}
if cfg.ClientSecret != "" {
t.Fatalf("expected empty client secret, got %q", cfg.ClientSecret)
}
if cfg.Scopes != "read:user user:email" {
t.Fatalf("expected fixed scope, got %q", cfg.Scopes)
}
}
+31 -5
View File
@@ -54,11 +54,37 @@ type TestResponse struct {
// OAuthStatus is returned by GET /providers/:id/oauth/status.
type OAuthStatus struct {
Configured bool `json:"configured"`
HasToken bool `json:"has_token"`
Expired bool `json:"expired"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
CallbackURL string `json:"callback_url"`
Configured bool `json:"configured"`
Mode string `json:"mode,omitempty"`
HasToken bool `json:"has_token"`
Expired bool `json:"expired"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
CallbackURL string `json:"callback_url"`
Device *OAuthDeviceStatus `json:"device,omitempty"`
Account *OAuthAccount `json:"account,omitempty"`
}
type OAuthDeviceStatus struct {
Pending bool `json:"pending"`
UserCode string `json:"user_code,omitempty"`
VerificationURI string `json:"verification_uri,omitempty"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
IntervalSeconds int64 `json:"interval_seconds,omitempty"`
}
type OAuthAccount struct {
Label string `json:"label,omitempty"`
Login string `json:"login,omitempty"`
Name string `json:"name,omitempty"`
Email string `json:"email,omitempty"`
AvatarURL string `json:"avatar_url,omitempty"`
ProfileURL string `json:"profile_url,omitempty"`
}
type OAuthAuthorizeResponse struct {
Mode string `json:"mode,omitempty"`
AuthURL string `json:"auth_url,omitempty"`
Device *OAuthDeviceStatus `json:"device,omitempty"`
}
// RemoteModel represents a model returned by the provider's /v1/models endpoint.