feat(provider): add github copilot device flow provider (#364)

This commit is contained in:
LiBr
2026-04-13 19:38:33 +08:00
committed by GitHub
parent a40207ab6d
commit df8fbd8859
36 changed files with 2659 additions and 246 deletions
+21 -10
View File
@@ -46,7 +46,7 @@
</FormItem>
</FormField>
<FormField
v-if="form.values.client_type !== 'openai-codex'"
v-if="!['openai-codex', 'github-copilot'].includes(form.values.client_type)"
v-slot="{ componentField }"
name="api_key"
>
@@ -69,12 +69,13 @@
</FormItem>
</FormField>
<div
v-else
v-else-if="['openai-codex', 'github-copilot'].includes(form.values.client_type)"
class="rounded-lg border p-3 text-xs text-muted-foreground"
>
{{ $t('provider.oauth.createHint') }}
{{ $t(form.values.client_type === 'github-copilot' ? 'provider.oauth.githubCreateHint' : 'provider.oauth.openaiCreateHint') }}
</div>
<FormField
v-if="form.values.client_type !== 'github-copilot'"
v-slot="{ componentField }"
name="base_url"
>
@@ -188,12 +189,13 @@ const { mutateAsync: createProviderMutation, isLoading } = useMutation({
mutation: async (data: Record<string, unknown>) => {
const config: Record<string, unknown> = {}
if (data.base_url) config.base_url = data.base_url
if (data.api_key) config.api_key = data.api_key
if (typeof data.api_key === 'string' && data.api_key.trim() !== '' && data.client_type !== 'github-copilot') {
config.api_key = data.api_key.trim()
}
const payload = {
name: data.name,
client_type: data.client_type,
config,
metadata: { additionalProp1: {} },
}
const { data: result } = await postProviders({ body: payload as ProvidersCreateRequest, throwOnError: true })
if (data.auto_import && result?.id) {
@@ -221,18 +223,25 @@ const { mutateAsync: createProviderMutation, isLoading } = useMutation({
const providerSchema = toTypedSchema(z.object({
api_key: z.string().optional(),
base_url: z.string().min(1),
base_url: z.string().optional(),
name: z.string().min(1),
client_type: z.string().min(1),
auto_import: z.boolean().optional(),
}).superRefine((value, ctx) => {
if (value.client_type !== 'openai-codex' && !value.api_key?.trim()) {
if (!['openai-codex', 'github-copilot'].includes(value.client_type) && !value.api_key?.trim()) {
ctx.addIssue({
code: z.ZodIssueCode.custom,
path: ['api_key'],
message: 'API key is required',
})
}
if (value.client_type !== 'github-copilot' && !value.base_url?.trim()) {
ctx.addIssue({
code: z.ZodIssueCode.custom,
path: ['base_url'],
message: 'Base URL is required',
})
}
}))
const form = useForm({
@@ -240,14 +249,16 @@ const form = useForm({
initialValues: {
auto_import: false,
client_type: 'openai-completions',
},
}
})
watch(() => form.values.client_type, (clientType) => {
if (clientType !== 'openai-codex') return
if (!form.values.base_url) {
if (clientType === 'openai-codex' && !form.values.base_url) {
form.setFieldValue('base_url', 'https://chatgpt.com/backend-api')
}
if (clientType === 'github-copilot') {
form.setFieldValue('base_url', '')
}
})
const createProvider = form.handleSubmit(async (value) => {
+5
View File
@@ -20,6 +20,11 @@ export const CLIENT_TYPE_META: Record<string, ClientTypeMeta> = {
label: 'OpenAI Codex',
hint: 'Codex API (OAuth, coding-optimized)',
},
'github-copilot': {
value: 'github-copilot',
label: 'GitHub Copilot',
hint: 'Device OAuth with GitHub account',
},
'anthropic-messages': {
value: 'anthropic-messages',
label: 'Anthropic Messages',
+18 -3
View File
@@ -292,20 +292,35 @@
"enable": "Enable",
"enableHint": "Only models from enabled providers appear in the available model list",
"oauth": {
"title": "OpenAI OAuth",
"description": "Authorize this provider with your ChatGPT account for Codex-compatible OpenAI access.",
"createHint": "Save the provider first, then authorize it from the provider details panel.",
"openaiTitle": "OpenAI OAuth",
"openaiDescription": "Authorize this provider with your ChatGPT account for Codex-compatible OpenAI access.",
"openaiCreateHint": "Save the provider first, then authorize it from the provider details panel.",
"githubTitle": "GitHub Copilot OAuth",
"githubDescription": "Connect the current Memoh account with GitHub Copilot.",
"githubDeviceTitle": "GitHub Copilot Device Authorization",
"githubDeviceDescription": "Start device authorization, open GitHub's verification page, and enter the user code shown below.",
"githubCreateHint": "Save the provider first, then start device authorization from the provider details panel.",
"githubDeviceHint": "Open the verification URL below and enter this user code to authorize the current Memoh account.",
"authorize": "Authorize",
"deviceAuthorize": "Start Device Authorization",
"authorizeFailed": "Failed to start authorization",
"authorizeSuccess": "Authorization successful",
"revoke": "Revoke",
"revokeFailed": "Failed to revoke authorization",
"revokeSuccess": "Authorization revoked",
"copyFailed": "Failed to copy device code",
"connectedAccount": "Connected Account",
"callback": "Callback URL",
"deviceVerificationUri": "Verification URL",
"deviceUserCode": "User Code",
"deviceExpiresAt": "Expires At",
"statusFailed": "Failed to load OAuth status",
"status": {
"checking": "Checking authorization status...",
"authorized": "Authorized",
"authorizedCurrent": "Current account connected",
"oauthing": "OAuthing...",
"pendingDevice": "Waiting for device authorization to complete...",
"expired": "Authorization expired. Re-authorize to continue.",
"missing": "Not authorized yet.",
"notConfigured": "This provider is not configured for OAuth."
+18 -3
View File
@@ -288,20 +288,35 @@
"enable": "启用",
"enableHint": "只有启用的供应商的模型才会出现在可用模型列表中",
"oauth": {
"title": "OpenAI OAuth",
"description": "使用你的 ChatGPT 账号为该提供商授权,以启用 Codex 兼容的 OpenAI 访问。",
"createHint": "请先保存提供商,再到详情面板完成授权。",
"openaiTitle": "OpenAI OAuth",
"openaiDescription": "使用你的 ChatGPT 账号为该提供商授权,以启用 Codex 兼容的 OpenAI 访问。",
"openaiCreateHint": "请先保存提供商,再到详情面板完成授权。",
"githubTitle": "GitHub Copilot OAuth",
"githubDescription": "为当前 Memoh 账号连接 GitHub Copilot。",
"githubDeviceTitle": "GitHub Copilot Device Authorization",
"githubDeviceDescription": "启动设备授权后,打开 GitHub 验证页面并输入下方显示的用户代码。",
"githubCreateHint": "请先保存提供商,再到详情面板启动设备授权。",
"githubDeviceHint": "打开下方验证地址,并输入这个用户代码,为当前 Memoh 账号完成授权。",
"authorize": "授权",
"deviceAuthorize": "启动设备授权",
"authorizeFailed": "启动授权失败",
"authorizeSuccess": "授权成功",
"revoke": "撤销授权",
"revokeFailed": "撤销授权失败",
"revokeSuccess": "授权已撤销",
"copyFailed": "复制设备代码失败",
"connectedAccount": "已连接账号",
"callback": "回调地址",
"deviceVerificationUri": "验证地址",
"deviceUserCode": "用户代码",
"deviceExpiresAt": "过期时间",
"statusFailed": "加载 OAuth 状态失败",
"status": {
"checking": "正在检查授权状态...",
"authorized": "已授权",
"authorizedCurrent": "当前账号已连接",
"oauthing": "授权中...",
"pendingDevice": "正在等待设备授权完成...",
"expired": "授权已过期,请重新授权。",
"missing": "尚未授权。",
"notConfigured": "当前提供商未正确配置 OAuth。"
@@ -21,7 +21,7 @@
</section>
<section
v-if="form.values.client_type !== 'openai-codex'"
v-if="!['openai-codex', 'github-copilot'].includes(form.values.client_type)"
class="space-y-2"
>
<FormField
@@ -33,7 +33,7 @@
<FormControl>
<Input
type="password"
:placeholder="(providerWithAuth?.config as Record<string, unknown> | undefined)?.api_key as string || $t('provider.apiKeyPlaceholder')"
:placeholder="getStoredSecret(props.provider?.config as Record<string, unknown> | undefined) || $t('provider.apiKeyPlaceholder')"
:aria-label="$t('provider.apiKey')"
v-bind="componentField"
/>
@@ -42,7 +42,10 @@
</FormField>
</section>
<section class="space-y-2">
<section
v-if="form.values.client_type !== 'github-copilot'"
class="space-y-2"
>
<FormField
v-slot="{ componentField }"
name="base_url"
@@ -81,15 +84,15 @@
</section>
<section
v-if="form.values.client_type === 'openai-codex'"
v-if="['openai-codex', 'github-copilot'].includes(form.values.client_type)"
class="rounded-lg border p-4 space-y-3 text-xs"
>
<div class="space-y-1">
<div class="font-medium">
{{ $t('provider.oauth.title') }}
{{ $t(form.values.client_type === 'github-copilot' ? 'provider.oauth.githubDeviceTitle' : 'provider.oauth.openaiTitle') }}
</div>
<div class="text-muted-foreground">
{{ $t('provider.oauth.description') }}
{{ $t(form.values.client_type === 'github-copilot' ? 'provider.oauth.githubDeviceDescription' : 'provider.oauth.openaiDescription') }}
</div>
<div
class="text-xs"
@@ -105,7 +108,10 @@
{{ $t('provider.oauth.status.expired') }}
</template>
<template v-else-if="oauthStatus?.has_token">
{{ $t('provider.oauth.status.authorized') }}
{{ $t(form.values.client_type === 'github-copilot' ? 'provider.oauth.status.authorizedCurrent' : 'provider.oauth.status.authorized') }}
</template>
<template v-else-if="oauthStatus?.device?.pending">
{{ $t('provider.oauth.status.pendingDevice') }}
</template>
<template v-else>
{{ $t('provider.oauth.status.missing') }}
@@ -118,16 +124,88 @@
{{ $t('provider.oauth.callback') }}: {{ oauthStatus.callback_url }}
</div>
</div>
<div
v-if="form.values.client_type === 'github-copilot'
&& oauthStatus?.device?.pending
&& !oauthStatus?.has_token
&& oauthStatus?.device?.user_code
&& oauthStatus?.device?.verification_uri"
class="rounded-md bg-muted/40 p-3 space-y-2"
>
<div class="text-muted-foreground">
{{ $t('provider.oauth.githubDeviceHint') }}
</div>
<div class="space-y-1">
<div class="font-medium">
{{ $t('provider.oauth.deviceVerificationUri') }}
</div>
<code class="block break-all rounded bg-background px-2 py-1 select-all">{{ oauthStatus?.device?.verification_uri }}</code>
</div>
<div class="space-y-1">
<div class="font-medium">
{{ $t('provider.oauth.deviceUserCode') }}
</div>
<div class="flex items-center gap-2">
<code class="block flex-1 rounded bg-background px-2 py-1 text-sm tracking-[0.3em] select-all">{{ oauthStatus?.device?.user_code }}</code>
<Button
type="button"
variant="ghost"
size="sm"
@click="handleCopyDeviceCode"
>
<Copy />
{{ $t('common.copy') }}
</Button>
</div>
</div>
<div
v-if="oauthStatus?.device?.expires_at"
class="text-muted-foreground"
>
{{ $t('provider.oauth.deviceExpiresAt') }}: {{ oauthStatus.device.expires_at }}
</div>
<div class="flex items-center gap-2 text-foreground">
<Spinner class="size-4" />
<span>{{ $t('provider.oauth.status.oauthing') }}</span>
</div>
</div>
<div
v-if="form.values.client_type === 'github-copilot' && oauthStatus?.has_token && !oauthExpired"
class="rounded-md bg-muted/40 p-3 space-y-1"
>
<div class="font-medium">
{{ $t('provider.oauth.connectedAccount') }}
</div>
<div class="text-sm font-medium">
{{ oauthStatus?.account?.email || oauthStatus?.account?.label || oauthStatus?.account?.name || oauthStatus?.account?.login || $t('provider.oauth.status.authorizedCurrent') }}
</div>
<div
v-if="[oauthStatus?.account?.login?.trim() ? `@${oauthStatus.account.login.trim()}` : '', oauthStatus?.account?.email?.trim() ?? ''].filter(Boolean).join(' · ')"
class="text-xs text-muted-foreground"
>
{{ [oauthStatus?.account?.login?.trim() ? `@${oauthStatus.account.login.trim()}` : '', oauthStatus?.account?.email?.trim() ?? ''].filter(Boolean).join(' · ') }}
</div>
</div>
<div class="flex gap-2">
<LoadingButton
v-if="props.provider?.id
&& ['openai-codex', 'github-copilot'].includes(form.values.client_type)
&& !(
form.values.client_type === 'github-copilot'
&& oauthStatus?.device?.pending
&& !oauthStatus?.has_token
&& oauthStatus?.device?.user_code
&& oauthStatus?.device?.verification_uri
)
&& (!oauthStatus?.has_token || oauthExpired)"
type="button"
variant="outline"
:disabled="!canAuthorizeOAuth"
:disabled="!props.provider?.id || !['openai-codex', 'github-copilot'].includes(form.values.client_type) || oauthStatusLoading"
:loading="authorizeLoading"
@click="handleAuthorize"
>
<KeyRound />
{{ $t('provider.oauth.authorize') }}
{{ $t(form.values.client_type === 'github-copilot' ? 'provider.oauth.deviceAuthorize' : 'provider.oauth.authorize') }}
</LoadingButton>
<LoadingButton
v-if="oauthStatus?.has_token"
@@ -225,14 +303,16 @@ import {
FormField,
FormLabel,
FormItem,
Spinner,
} from '@memohai/ui'
import { KeyRound, RefreshCw, Trash2 } from 'lucide-vue-next'
import { Copy, KeyRound, RefreshCw, Trash2 } from 'lucide-vue-next'
import ConfirmPopover from '@/components/confirm-popover/index.vue'
import StatusDot from '@/components/status-dot/index.vue'
import LoadingButton from '@/components/loading-button/index.vue'
import SearchableSelectPopover from '@/components/searchable-select-popover/index.vue'
import { useClipboard } from '@/composables/useClipboard'
import { CLIENT_TYPE_LIST, CLIENT_TYPE_META } from '@/constants/client-types'
import { computed, ref, watch } from 'vue'
import { computed, onBeforeUnmount, ref, watch } from 'vue'
import { toTypedSchema } from '@vee-validate/zod'
import z from 'zod'
import { useForm } from 'vee-validate'
@@ -242,15 +322,44 @@ import { useI18n } from 'vue-i18n'
import { toast } from 'vue-sonner'
const { t } = useI18n()
const { copyText } = useClipboard()
type ProviderWithAuth = Partial<ProvidersGetResponse>
type ProviderOAuthStatus = {
configured: boolean
mode?: string
has_token: boolean
expired: boolean
callback_url?: string
expires_at?: string
account?: {
label?: string
login?: string
name?: string
email?: string
avatar_url?: string
profile_url?: string
}
device?: {
pending: boolean
user_code?: string
verification_uri?: string
expires_at?: string
interval_seconds?: number
}
}
type ProviderOAuthAuthorizeResponse = {
mode?: string
auth_url?: string
device?: ProviderOAuthStatus['device']
}
function getStoredSecret(config: Record<string, unknown> | undefined) {
if (!config) return ''
const apiKey = config.api_key
return typeof apiKey === 'string' ? apiKey : ''
}
const props = defineProps<{
@@ -271,10 +380,9 @@ const oauthStatus = ref<ProviderOAuthStatus | null>(null)
const oauthStatusLoading = ref(false)
const authorizeLoading = ref(false)
const revokeLoading = ref(false)
const pollTimer = ref<number | null>(null)
const apiBase = import.meta.env.VITE_API_URL?.trim() || '/api'
const providerWithAuth = computed(() => props.provider as ProviderWithAuth | undefined)
async function runTest() {
if (!props.provider?.id) return
testLoading.value = true
@@ -310,20 +418,27 @@ const clientTypeOptions = computed(() =>
const providerSchema = toTypedSchema(z.object({
enable: z.boolean(),
name: z.string().min(1),
base_url: z.string().min(1),
base_url: z.string().optional(),
api_key: z.string().optional(),
client_type: z.string().min(1),
metadata: z.object({
additionalProp1: z.object({}),
}),
}).superRefine((value, ctx) => {
if (value.client_type !== 'openai-codex' && !value.api_key?.trim() && !(providerWithAuth.value?.config as Record<string, unknown> | undefined)?.api_key) {
const existingSecret = getStoredSecret(
props.provider?.config as Record<string, unknown> | undefined,
)
if (!['openai-codex', 'github-copilot'].includes(value.client_type) && !value.api_key?.trim() && !existingSecret.trim()) {
ctx.addIssue({
code: z.ZodIssueCode.custom,
path: ['api_key'],
message: 'API key is required',
})
}
if (value.client_type !== 'github-copilot' && !value.base_url?.trim()) {
ctx.addIssue({
code: z.ZodIssueCode.custom,
path: ['base_url'],
message: 'Base URL is required',
})
}
}))
const form = useForm({
@@ -344,17 +459,19 @@ watch(() => props.provider, (newVal) => {
}, { immediate: true })
watch(() => form.values.client_type, (clientType) => {
if (clientType !== 'openai-codex') {
if (!['openai-codex', 'github-copilot'].includes(clientType)) {
oauthStatus.value = null
return
}
if (!form.values.base_url) {
if (clientType === 'openai-codex' && !form.values.base_url) {
form.setFieldValue('base_url', 'https://chatgpt.com/backend-api')
}
if (clientType === 'github-copilot') {
form.setFieldValue('base_url', '')
}
})
watch(() => [props.provider?.id, form.values.client_type] as const, async ([id, clientType]) => {
if (!id || clientType !== 'openai-codex') {
if (!id || (clientType !== 'openai-codex' && clientType !== 'github-copilot')) {
oauthStatus.value = null
return
}
@@ -369,13 +486,11 @@ const hasChanges = computed(() => {
name: form.values.name,
base_url: form.values.base_url,
client_type: form.values.client_type,
metadata: form.values.metadata,
}) !== JSON.stringify({
enable: raw?.enable ?? true,
name: raw?.name,
base_url: (cfg?.base_url as string) ?? '',
client_type: raw?.client_type || 'openai-completions',
metadata: { additionalProp1: {} },
})
const apiKeyChanged = Boolean(form.values.api_key && form.values.api_key.trim() !== '')
@@ -383,33 +498,47 @@ const hasChanges = computed(() => {
})
const editProvider = form.handleSubmit(async (value) => {
const config: Record<string, unknown> = { base_url: value.base_url }
const config: Record<string, unknown> = {}
if (value.base_url && value.base_url.trim() !== '') {
config.base_url = value.base_url
}
if (value.api_key && value.api_key.trim() !== '') {
config.api_key = value.api_key
if (value.client_type !== 'github-copilot') {
config.api_key = value.api_key.trim()
}
}
const metadata = {
...((props.provider?.metadata as Record<string, unknown> | undefined) ?? {}),
}
if (value.client_type === 'github-copilot') {
delete metadata.oauth_client_id
}
const payload: Record<string, unknown> = {
enable: value.enable,
name: value.name,
config,
client_type: value.client_type,
metadata: value.metadata,
}
if (Object.keys(metadata).length > 0 || value.client_type === 'github-copilot') {
payload.metadata = metadata
}
emit('submit', payload)
})
const oauthExpired = computed(() => Boolean(oauthStatus.value?.has_token && oauthStatus.value?.expired))
const canAuthorizeOAuth = computed(() =>
Boolean(
props.provider?.id
&& form.values.client_type === 'openai-codex',
) && !oauthStatusLoading.value,
)
function authHeaders(): Record<string, string> {
const token = localStorage.getItem('token')
return token ? { Authorization: `Bearer ${token}` } : {}
}
function clearPollTimer() {
if (pollTimer.value !== null) {
window.clearTimeout(pollTimer.value)
pollTimer.value = null
}
}
async function fetchOAuthStatus() {
if (!props.provider?.id) return
oauthStatusLoading.value = true
@@ -427,6 +556,44 @@ async function fetchOAuthStatus() {
}
}
async function pollOAuthAuthorization(notifyOnSuccess = false) {
if (!props.provider?.id || form.values.client_type !== 'github-copilot') return
try {
const response = await fetch(`${apiBase}/providers/${props.provider.id}/oauth/poll`, {
method: 'POST',
headers: authHeaders(),
})
if (!response.ok) throw new Error(t('provider.oauth.authorizeFailed'))
const nextStatus = await response.json() as ProviderOAuthStatus
const becameAuthorized = !oauthStatus.value?.has_token && Boolean(nextStatus.has_token)
oauthStatus.value = nextStatus
if (notifyOnSuccess && becameAuthorized) {
toast.success(t('provider.oauth.authorizeSuccess'))
}
} catch (error) {
clearPollTimer()
toast.error(error instanceof Error ? error.message : t('provider.oauth.authorizeFailed'))
}
}
watch(oauthStatus, (status) => {
clearPollTimer()
if (form.values.client_type !== 'github-copilot') {
return
}
if (!status?.device?.pending || status.has_token) {
return
}
const intervalSeconds = Math.max(status.device.interval_seconds ?? 5, 1)
pollTimer.value = window.setTimeout(() => {
void pollOAuthAuthorization(true)
}, intervalSeconds * 1000)
})
onBeforeUnmount(() => {
clearPollTimer()
})
async function handleAuthorize() {
if (!props.provider?.id) return
authorizeLoading.value = true
@@ -435,7 +602,18 @@ async function handleAuthorize() {
headers: authHeaders(),
})
if (!response.ok) throw new Error(t('provider.oauth.authorizeFailed'))
const data = await response.json() as { auth_url?: string }
const data = await response.json() as ProviderOAuthAuthorizeResponse
if (data.mode === 'device') {
oauthStatus.value = {
configured: true,
mode: 'device',
has_token: false,
expired: false,
callback_url: '',
device: data.device,
}
return
}
if (!data.auth_url) throw new Error(t('provider.oauth.authorizeFailed'))
const popup = window.open(data.auth_url, 'provider-oauth', 'width=600,height=720')
const listener = async (event: MessageEvent) => {
@@ -453,8 +631,34 @@ async function handleAuthorize() {
}
}
async function handleCopyDeviceCode() {
const userCode = oauthStatus.value?.device?.user_code?.trim()
const verificationUri = oauthStatus.value?.device?.verification_uri?.trim()
if (!userCode || !verificationUri) return
const popup = window.open('', 'provider-device-oauth', 'width=960,height=720')
const copied = await copyText(userCode)
if (!copied) {
popup?.close()
toast.error(t('provider.oauth.copyFailed'))
return
}
toast.success(t('common.copied'))
if (popup) {
popup.location.href = verificationUri
popup.focus()
return
}
window.open(verificationUri, '_blank', 'width=960,height=720')
}
async function handleRevoke() {
if (!props.provider?.id) return
clearPollTimer()
revokeLoading.value = true
try {
const response = await fetch(`${apiBase}/providers/${props.provider.id}/oauth/token`, {
+185
View File
@@ -0,0 +1,185 @@
name: GitHub Copilot
client_type: github-copilot
models:
- model_id: claude-opus-4.6-1m
name: Claude Opus 4.6 (1M context)(Internal only)
type: chat
config:
context_window: 1000000
compatibilities: [vision, tool-call, reasoning]
reasoning_efforts: [low, medium, high]
- model_id: claude-opus-4.6
name: Claude Opus 4.6
type: chat
config:
context_window: 144000
compatibilities: [vision, tool-call, reasoning]
reasoning_efforts: [low, medium, high]
- model_id: claude-sonnet-4.6
name: Claude Sonnet 4.6
type: chat
config:
context_window: 200000
compatibilities: [vision, tool-call, reasoning]
reasoning_efforts: [low, medium, high]
- model_id: goldeneye-free-auto
name: Goldeneye
type: chat
config:
context_window: 400000
compatibilities: [vision, tool-call]
- model_id: gpt-5.2-codex
name: GPT-5.2-Codex
type: chat
config:
context_window: 400000
compatibilities: [vision, tool-call, reasoning]
reasoning_efforts: [low, medium, high, xhigh]
- model_id: gpt-5.3-codex
name: GPT-5.3-Codex
type: chat
config:
context_window: 400000
compatibilities: [vision, tool-call, reasoning]
reasoning_efforts: [low, medium, high, xhigh]
- model_id: gpt-5.4-mini
name: GPT-5.4 mini
type: chat
config:
context_window: 400000
compatibilities: [vision, tool-call, reasoning]
reasoning_efforts: [none, low, medium, high, xhigh]
- model_id: gpt-5.4
name: GPT-5.4
type: chat
config:
context_window: 400000
compatibilities: [vision, tool-call, reasoning]
reasoning_efforts: [low, medium, high, xhigh]
- model_id: gpt-5-mini
name: GPT-5 mini
type: chat
config:
context_window: 264000
compatibilities: [vision, tool-call, reasoning]
reasoning_efforts: [low, medium, high]
- model_id: gpt-4o-mini-2024-07-18
name: GPT-4o mini
type: chat
config:
context_window: 128000
compatibilities: [tool-call]
- model_id: grok-code-fast-1
name: Grok Code Fast 1
type: chat
config:
context_window: 128000
compatibilities: [tool-call]
- model_id: gpt-5.1
name: GPT-5.1
type: chat
config:
context_window: 264000
compatibilities: [vision, tool-call, reasoning]
reasoning_efforts: [none, low, medium, high]
- model_id: text-embedding-3-small
name: Embedding V3 small
type: embedding
config:
dimensions: 1536
- model_id: text-embedding-3-small-inference
name: Embedding V3 small (Inference)
type: embedding
config:
dimensions: 1536
- model_id: claude-sonnet-4
name: Claude Sonnet 4
type: chat
config:
context_window: 216000
compatibilities: [vision, tool-call, reasoning]
- model_id: claude-sonnet-4.5
name: Claude Sonnet 4.5
type: chat
config:
context_window: 144000
compatibilities: [vision, tool-call, reasoning]
- model_id: claude-opus-4.5
name: Claude Opus 4.5
type: chat
config:
context_window: 160000
compatibilities: [vision, tool-call, reasoning]
- model_id: claude-haiku-4.5
name: Claude Haiku 4.5
type: chat
config:
context_window: 144000
compatibilities: [vision, tool-call, reasoning]
- model_id: gpt-4.1-2025-04-14
name: GPT-4.1
type: chat
config:
context_window: 128000
compatibilities: [vision, tool-call]
- model_id: gpt-5.2
name: GPT-5.2
type: chat
config:
context_window: 264000
compatibilities: [vision, tool-call, reasoning]
reasoning_efforts: [low, medium, high, xhigh]
- model_id: gpt-3.5-turbo-0613
name: GPT 3.5 Turbo
type: chat
config:
context_window: 16384
compatibilities: [tool-call]
- model_id: gpt-4.1
name: GPT-4.1
type: chat
config:
context_window: 128000
compatibilities: [vision, tool-call]
- model_id: gpt-3.5-turbo
name: GPT 3.5 Turbo
type: chat
config:
context_window: 16384
compatibilities: [tool-call]
- model_id: gpt-4o-mini
name: GPT-4o mini
type: chat
config:
context_window: 128000
compatibilities: [tool-call]
- model_id: text-embedding-ada-002
name: Embedding V2 Ada
type: embedding
config:
dimensions: 1536
+21 -1
View File
@@ -68,7 +68,7 @@ CREATE TABLE IF NOT EXISTS providers (
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
CONSTRAINT providers_name_unique UNIQUE (name),
CONSTRAINT providers_client_type_check CHECK (client_type IN ('openai-responses', 'openai-completions', 'anthropic-messages', 'google-generative-ai', 'openai-codex', 'edge-speech'))
CONSTRAINT providers_client_type_check CHECK (client_type IN ('openai-responses', 'openai-completions', 'anthropic-messages', 'google-generative-ai', 'openai-codex', 'github-copilot', 'edge-speech'))
);
CREATE TABLE IF NOT EXISTS search_providers (
@@ -644,3 +644,23 @@ CREATE TABLE IF NOT EXISTS provider_oauth_tokens (
);
CREATE INDEX IF NOT EXISTS idx_provider_oauth_tokens_state ON provider_oauth_tokens(state) WHERE state != '';
-- user_provider_oauth_tokens: per-user OAuth2 tokens for providers with user-scoped auth (e.g. GitHub Copilot)
CREATE TABLE IF NOT EXISTS user_provider_oauth_tokens (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
provider_id UUID NOT NULL REFERENCES providers(id) ON DELETE CASCADE,
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
access_token TEXT NOT NULL DEFAULT '',
refresh_token TEXT NOT NULL DEFAULT '',
expires_at TIMESTAMPTZ,
scope TEXT NOT NULL DEFAULT '',
token_type TEXT NOT NULL DEFAULT '',
state TEXT NOT NULL DEFAULT '',
pkce_code_verifier TEXT NOT NULL DEFAULT '',
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
CONSTRAINT user_provider_oauth_tokens_provider_user_unique UNIQUE (provider_id, user_id)
);
CREATE INDEX IF NOT EXISTS idx_user_provider_oauth_tokens_state ON user_provider_oauth_tokens(state) WHERE state != '';
@@ -0,0 +1,21 @@
-- 0062_github_copilot_user_oauth (rollback)
-- Remove user-scoped provider OAuth tokens and github-copilot client type.
DROP INDEX IF EXISTS idx_user_provider_oauth_tokens_state;
DROP TABLE IF EXISTS user_provider_oauth_tokens;
DELETE FROM providers WHERE client_type = 'github-copilot';
ALTER TABLE IF EXISTS providers DROP CONSTRAINT IF EXISTS providers_client_type_check;
ALTER TABLE IF EXISTS providers
ADD CONSTRAINT providers_client_type_check CHECK (
client_type IN (
'openai-responses',
'openai-completions',
'anthropic-messages',
'google-generative-ai',
'openai-codex',
'edge-speech'
)
);
@@ -0,0 +1,38 @@
-- 0062_github_copilot_user_oauth
-- Add github-copilot as a provider client type and store OAuth tokens per user.
ALTER TABLE IF EXISTS providers DROP CONSTRAINT IF EXISTS providers_client_type_check;
ALTER TABLE IF EXISTS providers
ADD CONSTRAINT providers_client_type_check CHECK (
client_type IN (
'openai-responses',
'openai-completions',
'anthropic-messages',
'google-generative-ai',
'openai-codex',
'github-copilot',
'edge-speech'
)
);
CREATE TABLE IF NOT EXISTS user_provider_oauth_tokens (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
provider_id UUID NOT NULL REFERENCES providers(id) ON DELETE CASCADE,
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
access_token TEXT NOT NULL DEFAULT '',
refresh_token TEXT NOT NULL DEFAULT '',
expires_at TIMESTAMPTZ,
scope TEXT NOT NULL DEFAULT '',
token_type TEXT NOT NULL DEFAULT '',
state TEXT NOT NULL DEFAULT '',
pkce_code_verifier TEXT NOT NULL DEFAULT '',
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
CONSTRAINT user_provider_oauth_tokens_provider_user_unique UNIQUE (provider_id, user_id)
);
CREATE INDEX IF NOT EXISTS idx_user_provider_oauth_tokens_state
ON user_provider_oauth_tokens(state)
WHERE state != '';
+66
View File
@@ -0,0 +1,66 @@
-- name: UpsertUserProviderOAuthToken :one
INSERT INTO user_provider_oauth_tokens (
provider_id,
user_id,
access_token,
refresh_token,
expires_at,
scope,
token_type,
state,
pkce_code_verifier,
metadata
)
VALUES (
sqlc.arg(provider_id),
sqlc.arg(user_id),
sqlc.arg(access_token),
sqlc.arg(refresh_token),
sqlc.arg(expires_at),
sqlc.arg(scope),
sqlc.arg(token_type),
sqlc.arg(state),
sqlc.arg(pkce_code_verifier),
sqlc.arg(metadata)
)
ON CONFLICT (provider_id, user_id) DO UPDATE SET
access_token = EXCLUDED.access_token,
refresh_token = EXCLUDED.refresh_token,
expires_at = EXCLUDED.expires_at,
scope = EXCLUDED.scope,
token_type = EXCLUDED.token_type,
state = EXCLUDED.state,
pkce_code_verifier = EXCLUDED.pkce_code_verifier,
metadata = EXCLUDED.metadata,
updated_at = now()
RETURNING *;
-- name: GetUserProviderOAuthToken :one
SELECT * FROM user_provider_oauth_tokens
WHERE provider_id = sqlc.arg(provider_id)
AND user_id = sqlc.arg(user_id);
-- name: GetUserProviderOAuthTokenByState :one
SELECT * FROM user_provider_oauth_tokens
WHERE state = sqlc.arg(state)
AND state != '';
-- name: UpdateUserProviderOAuthState :exec
INSERT INTO user_provider_oauth_tokens (provider_id, user_id, state, pkce_code_verifier, metadata)
VALUES (
sqlc.arg(provider_id),
sqlc.arg(user_id),
sqlc.arg(state),
sqlc.arg(pkce_code_verifier),
sqlc.arg(metadata)
)
ON CONFLICT (provider_id, user_id) DO UPDATE SET
state = EXCLUDED.state,
pkce_code_verifier = EXCLUDED.pkce_code_verifier,
metadata = EXCLUDED.metadata,
updated_at = now();
-- name: DeleteUserProviderOAuthToken :exec
DELETE FROM user_provider_oauth_tokens
WHERE provider_id = sqlc.arg(provider_id)
AND user_id = sqlc.arg(user_id);
+6 -6
View File
@@ -27,8 +27,8 @@ require (
github.com/mailgun/mailgun-go/v5 v5.14.0
github.com/memohai/acgo v0.0.0-20260221232113-babac0d6acd7
github.com/memohai/dingtalk-stream-sdk-go v0.0.0-20260405113102-87e23096b978
github.com/memohai/twilight-ai v0.3.4-0.20260402160505-00db38ee4442
github.com/modelcontextprotocol/go-sdk v1.4.1
github.com/memohai/twilight-ai v0.3.4-0.20260412161211-dbedfe32c86f
github.com/modelcontextprotocol/go-sdk v1.5.0
github.com/opencontainers/image-spec v1.1.1
github.com/opencontainers/runtime-spec v1.3.0
github.com/qdrant/go-client v1.17.1
@@ -40,11 +40,12 @@ require (
github.com/yuin/goldmark v1.7.13
go.uber.org/fx v1.24.0
golang.org/x/crypto v0.48.0
golang.org/x/oauth2 v0.35.0
golang.org/x/oauth2 v0.36.0
golang.org/x/time v0.14.0
google.golang.org/grpc v1.78.0
google.golang.org/protobuf v1.36.11
gopkg.in/yaml.v3 v3.0.1
tags.cncf.io/container-device-interface v1.1.0
)
require (
@@ -122,7 +123,7 @@ require (
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/sasha-s/go-deadlock v0.3.6 // indirect
github.com/segmentio/asm v1.1.3 // indirect
github.com/segmentio/asm v1.2.1 // indirect
github.com/segmentio/encoding v0.5.4 // indirect
github.com/sirupsen/logrus v1.9.4 // indirect
github.com/spf13/pflag v1.0.9 // indirect
@@ -143,11 +144,10 @@ require (
golang.org/x/mod v0.33.0 // indirect
golang.org/x/net v0.50.0 // indirect
golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.41.0 // indirect
golang.org/x/sys v0.42.0 // indirect
golang.org/x/text v0.34.0 // indirect
golang.org/x/tools v0.42.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260209200024-4cfbd4190f57 // indirect
sigs.k8s.io/yaml v1.6.0 // indirect
tags.cncf.io/container-device-interface v1.1.0 // indirect
tags.cncf.io/container-device-interface/specs-go v1.1.0 // indirect
)
+22 -10
View File
@@ -24,6 +24,8 @@ github.com/andybalholm/cascadia v1.3.3 h1:AG2YHrzJIm4BZ19iwJ/DAua6Btl3IwJX+VI4kk
github.com/andybalholm/cascadia v1.3.3/go.mod h1:xNd9bqTn98Ln4DwST8/nG+H0yuB8Hmgu1YHNnWw0GeA=
github.com/araddon/dateparse v0.0.0-20210429162001-6b43995a97de h1:FxWPpzIjnTlhPwqqXc4/vE0f7GvRjuAsbW+HOIe8KnA=
github.com/araddon/dateparse v0.0.0-20210429162001-6b43995a97de/go.mod h1:DCaWoUhZrYW9p1lxo/cm8EmUOOzAPSEZNGF2DK1dJgw=
github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM=
github.com/blang/semver/v4 v4.0.0/go.mod h1:IbckMUScFkM3pff0VJDNKRiT6TG/YpiHIM2yvyW5YoQ=
github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno=
github.com/bwmarrin/discordgo v0.29.0/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
@@ -189,6 +191,10 @@ github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/ad
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 h1:JeSE6pjso5THxAzdVpqr6/geYxZytqFMBCOtn/ujyeo=
github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674/go.mod h1:r4w70xmWCQKmi1ONH4KIaBptdivuRPyosB9RmPlGEwA=
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
@@ -232,8 +238,8 @@ github.com/memohai/acgo v0.0.0-20260221232113-babac0d6acd7 h1:beehwOQperqGWj4m4E
github.com/memohai/acgo v0.0.0-20260221232113-babac0d6acd7/go.mod h1:OvmxM7JmnXBmwJWWVqtreL3HSHSKuzPbtbhlg5MvBg0=
github.com/memohai/dingtalk-stream-sdk-go v0.0.0-20260405113102-87e23096b978 h1:6gD8DvZkimGmU0e3PjlusJPyw55SyeoE12CZQoYUa8g=
github.com/memohai/dingtalk-stream-sdk-go v0.0.0-20260405113102-87e23096b978/go.mod h1:2LMgK5QYFlTSvrGY+sI/j+jK2WK+YGHv4IMuiW+iPSc=
github.com/memohai/twilight-ai v0.3.4-0.20260402160505-00db38ee4442 h1:mTy+OSkMCOvF1S6D5asKRdKx0A+icQvnu6A/f7aZolg=
github.com/memohai/twilight-ai v0.3.4-0.20260402160505-00db38ee4442/go.mod h1:GZTT9GUT3uSs6zram/FcF24GLTZMFSpiybbYmjr+gH8=
github.com/memohai/twilight-ai v0.3.4-0.20260412161211-dbedfe32c86f h1:9NAj+FyDJPi8RzD1PUwb6OxZx/OrBD2FJo4tVAlhpbs=
github.com/memohai/twilight-ai v0.3.4-0.20260412161211-dbedfe32c86f/go.mod h1:1uNfZWc8du+HWJ3r3FLyeGAXGiUAniuSWV89A8gbcz0=
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
github.com/moby/locker v1.0.1 h1:fOXqR41zeveg4fFODix+1Ch4mj/gT0NE1XJbp/epuBg=
@@ -252,8 +258,8 @@ github.com/moby/sys/userns v0.1.0 h1:tVLXkFOxVu9A64/yh59slHVv9ahO9UIev4JZusOLG/g
github.com/moby/sys/userns v0.1.0/go.mod h1:IHUYgu/kao6N8YZlp9Cf444ySSvCmDlmzUcYfDHOl28=
github.com/moby/term v0.5.2 h1:6qk3FJAFDs6i/q3W/pQ97SX192qKfZgGjCQqfCJkgzQ=
github.com/moby/term v0.5.2/go.mod h1:d3djjFCrjnB+fl8NJux+EJzu0msscUP+f8it8hPkFLc=
github.com/modelcontextprotocol/go-sdk v1.4.1 h1:M4x9GyIPj+HoIlHNGpK2hq5o3BFhC+78PkEaldQRphc=
github.com/modelcontextprotocol/go-sdk v1.4.1/go.mod h1:Bo/mS87hPQqHSRkMv4dQq1XCu6zv4INdXnFZabkNU6s=
github.com/modelcontextprotocol/go-sdk v1.5.0 h1:CHU0FIX9kpueNkxuYtfYQn1Z0slhFzBZuq+x6IiblIU=
github.com/modelcontextprotocol/go-sdk v1.5.0/go.mod h1:gggDIhoemhWs3BGkGwd1umzEXCEMMvAnhTrnbXJKKKA=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -303,8 +309,8 @@ github.com/sasha-s/go-deadlock v0.3.6/go.mod h1:CUqNyyvMxTyjFqDT7MRg9mb4Dv/btmGT
github.com/scylladb/termtables v0.0.0-20191203121021-c4c0b6d42ff4/go.mod h1:C1a7PQSMz9NShzorzCiG2fk9+xuCgLkPeCvMHYR2OWg=
github.com/sebdah/goldie/v2 v2.8.0 h1:dZb9wR8q5++oplmEiJT+U/5KyotVD+HNGCAc5gNr8rc=
github.com/sebdah/goldie/v2 v2.8.0/go.mod h1:oZ9fp0+se1eapSRjfYbsV/0Hqhbuu3bJVvKI/NNtssI=
github.com/segmentio/asm v1.1.3 h1:WM03sfUOENvvKexOLp+pCqgb/WDjsi7EK8gIsICtzhc=
github.com/segmentio/asm v1.1.3/go.mod h1:Ld3L4ZXGNcSLRg4JBsZ3//1+f/TjYl0Mzen/DQy1EJg=
github.com/segmentio/asm v1.2.1 h1:DTNbBqs57ioxAD4PrArqftgypG4/qNpXoJx8TVXxPR0=
github.com/segmentio/asm v1.2.1/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs=
github.com/segmentio/encoding v0.5.4 h1:OW1VRern8Nw6ITAtwSZ7Idrl3MXCFwXHPgqESYfvNt0=
github.com/segmentio/encoding v0.5.4/go.mod h1:HS1ZKa3kSN32ZHVZ7ZLPLXWvOVIiZtyJnO1gPH1sKt0=
github.com/sergi/go-diff v1.4.0 h1:n/SP9D5ad1fORl+llWyN+D6qoUETXNZARKjyY2/KVCw=
@@ -337,6 +343,12 @@ github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zd
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
github.com/wneessen/go-mail v0.7.2 h1:xxPnhZ6IZLSgxShebmZ6DPKh1b6OJcoHfzy7UjOkzS8=
github.com/wneessen/go-mail v0.7.2/go.mod h1:+TkW6QP3EVkgTEqHtVmnAE/1MRhmzb8Y9/W3pweuS+k=
github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb h1:zGWFAtiMcyryUHoUjUJX0/lt1H2+i2Ka2n+D3DImSNo=
github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU=
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0=
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ=
github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74=
github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y=
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
@@ -420,8 +432,8 @@ golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60=
golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ=
golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -450,8 +462,8 @@ golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
+3 -1
View File
@@ -28,6 +28,7 @@ import (
messagepkg "github.com/memohai/memoh/internal/message"
messageevent "github.com/memohai/memoh/internal/message/event"
"github.com/memohai/memoh/internal/models"
"github.com/memohai/memoh/internal/oauthctx"
pipelinepkg "github.com/memohai/memoh/internal/pipeline"
"github.com/memohai/memoh/internal/providers"
"github.com/memohai/memoh/internal/settings"
@@ -499,7 +500,8 @@ func (r *Resolver) buildBaseRunConfig(ctx context.Context, p baseRunConfigParams
}
authResolver := providers.NewService(nil, r.queries, "")
creds, err := authResolver.ResolveModelCredentials(ctx, provider)
authCtx := oauthctx.WithUserID(ctx, p.UserID)
creds, err := authResolver.ResolveModelCredentials(authCtx, provider)
if err != nil {
return agentpkg.RunConfig{}, models.GetResponse{}, sqlc.Provider{}, fmt.Errorf("resolve provider credentials: %w", err)
}
@@ -7,6 +7,7 @@ import (
"github.com/memohai/memoh/internal/compaction"
"github.com/memohai/memoh/internal/conversation"
"github.com/memohai/memoh/internal/models"
"github.com/memohai/memoh/internal/oauthctx"
"github.com/memohai/memoh/internal/providers"
"github.com/memohai/memoh/internal/settings"
)
@@ -115,7 +116,8 @@ func (r *Resolver) buildCompactionConfig(ctx context.Context, req conversation.C
return compaction.TriggerConfig{}, err
}
authResolver := providers.NewService(nil, r.queries, "")
creds, err := authResolver.ResolveModelCredentials(ctx, compactProvider)
authCtx := oauthctx.WithUserID(ctx, req.UserID)
creds, err := authResolver.ResolveModelCredentials(authCtx, compactProvider)
if err != nil {
return compaction.TriggerConfig{}, err
}
@@ -137,7 +139,6 @@ func (r *Resolver) buildCompactionConfig(ctx context.Context, req conversation.C
if compactModel.Config.ContextWindow != nil && *compactModel.Config.ContextWindow > 0 {
cfg.MaxCompactTokens = *compactModel.Config.ContextWindow * 90 / 100
}
// For sync compaction: keep only the last few messages (~2000 tokens ≈ 3 messages).
// The summary provides reference context; if the LLM needs details,
// it will use tools (memory_read, search) to look them up.
+5 -3
View File
@@ -13,6 +13,7 @@ import (
"github.com/memohai/memoh/internal/db/sqlc"
messageevent "github.com/memohai/memoh/internal/message/event"
"github.com/memohai/memoh/internal/models"
"github.com/memohai/memoh/internal/oauthctx"
"github.com/memohai/memoh/internal/providers"
"github.com/memohai/memoh/internal/session"
)
@@ -82,7 +83,7 @@ func (r *Resolver) maybeGenerateSessionTitle(ctx context.Context, req conversati
return
}
title := r.generateTitle(ctx, titleModel, provider, userQuery)
title := r.generateTitle(ctx, req.UserID, titleModel, provider, userQuery)
if title == "" {
return
}
@@ -95,7 +96,7 @@ func (r *Resolver) maybeGenerateSessionTitle(ctx context.Context, req conversati
}
}
func (r *Resolver) generateTitle(ctx context.Context, model models.GetResponse, provider sqlc.Provider, userQuery string) string {
func (r *Resolver) generateTitle(ctx context.Context, userID string, model models.GetResponse, provider sqlc.Provider, userQuery string) string {
userSnippet := truncate(strings.TrimSpace(userQuery), titlePromptMaxInputChars)
if userSnippet == "" {
return ""
@@ -106,7 +107,8 @@ func (r *Resolver) generateTitle(ctx context.Context, model models.GetResponse,
"User: " + userSnippet
authResolver := providers.NewService(nil, r.queries, "")
creds, err := authResolver.ResolveModelCredentials(ctx, provider)
authCtx := oauthctx.WithUserID(ctx, userID)
creds, err := authResolver.ResolveModelCredentials(authCtx, provider)
if err != nil {
r.logger.Warn("title gen: failed to resolve provider credentials", slog.Any("error", err))
return ""
+176
View File
@@ -0,0 +1,176 @@
package copilot
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"sync"
"time"
)
const (
GitHubOAuthClientID = "Iv1.b507a08c87ecfe98"
GitHubOAuthScope = "read:user user:email"
DefaultAPIBaseURL = "https://api.githubcopilot.com"
copilotTokenURL = "https://api.github.com/copilot_internal/v2/token" //nolint:gosec // Fixed GitHub API endpoint, not a credential.
copilotEditorVersion = "vscode/1.110.1"
copilotPluginVersion = "copilot-chat/0.38.2"
copilotUserAgent = "GitHubCopilotChat/0.38.2"
copilotAPIVersion = "2025-10-01"
copilotIntegrationID = "vscode-chat"
copilotTokenRefreshSkew = time.Minute
defaultHTTPClientTimeout = 15 * time.Second
)
type cachedToken struct {
Token string
ExpiresAt time.Time
}
var tokenCache = struct {
mu sync.Mutex
entries map[string]cachedToken
}{
entries: map[string]cachedToken{},
}
type tokenResponse struct {
Token string `json:"token"`
ExpiresAt int64 `json:"expires_at"`
}
func ResolveToken(ctx context.Context, githubToken string) (string, error) {
githubToken = strings.TrimSpace(githubToken)
if githubToken == "" {
return "", errors.New("github token is required")
}
if token, ok := loadCachedToken(githubToken); ok {
return token, nil
}
token, expiresAt, err := FetchCopilotToken(ctx, githubToken)
if err != nil {
return "", err
}
storeCachedToken(githubToken, token, expiresAt)
return token, nil
}
func FetchCopilotToken(ctx context.Context, githubToken string) (string, time.Time, error) {
githubToken = strings.TrimSpace(githubToken)
if githubToken == "" {
return "", time.Time{}, errors.New("github token is required")
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, copilotTokenURL, nil)
if err != nil {
return "", time.Time{}, fmt.Errorf("create copilot token request: %w", err)
}
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "token "+githubToken)
req.Header.Set("Editor-Version", copilotEditorVersion)
req.Header.Set("Editor-Plugin-Version", copilotPluginVersion)
req.Header.Set("User-Agent", copilotUserAgent)
req.Header.Set("X-GitHub-Api-Version", copilotAPIVersion)
resp, err := defaultHTTPClient(nil).Do(req) //nolint:gosec // Request targets a fixed GitHub API endpoint.
if err != nil {
return "", time.Time{}, fmt.Errorf("fetch copilot token: %w", err)
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", time.Time{}, fmt.Errorf("read copilot token response: %w", err)
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return "", time.Time{}, fmt.Errorf("copilot token request failed: %s", strings.TrimSpace(string(body)))
}
var parsed tokenResponse
if err := json.Unmarshal(body, &parsed); err != nil {
return "", time.Time{}, fmt.Errorf("decode copilot token response: %w", err)
}
if strings.TrimSpace(parsed.Token) == "" {
return "", time.Time{}, errors.New("copilot token response did not include a token")
}
var expiresAt time.Time
if parsed.ExpiresAt > 0 {
expiresAt = time.Unix(parsed.ExpiresAt, 0).UTC()
}
return parsed.Token, expiresAt, nil
}
func NewHTTPClient(base *http.Client) *http.Client {
client := defaultHTTPClient(base)
client.Transport = &headerRoundTripper{
base: client.Transport,
headers: map[string]string{
"Copilot-Integration-Id": copilotIntegrationID,
"Editor-Version": copilotEditorVersion,
"Editor-Plugin-Version": copilotPluginVersion,
"User-Agent": copilotUserAgent,
"X-GitHub-Api-Version": copilotAPIVersion,
},
}
return client
}
type headerRoundTripper struct {
base http.RoundTripper
headers map[string]string
}
func (rt *headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
clone := req.Clone(req.Context())
clone.Header = req.Header.Clone()
for key, value := range rt.headers {
clone.Header.Set(key, value)
}
if rt.base == nil {
rt.base = http.DefaultTransport
}
return rt.base.RoundTrip(clone)
}
func loadCachedToken(githubToken string) (string, bool) {
tokenCache.mu.Lock()
defer tokenCache.mu.Unlock()
entry, ok := tokenCache.entries[githubToken]
if !ok {
return "", false
}
if !entry.ExpiresAt.IsZero() && !time.Now().Add(copilotTokenRefreshSkew).Before(entry.ExpiresAt) {
delete(tokenCache.entries, githubToken)
return "", false
}
return entry.Token, true
}
func storeCachedToken(githubToken, token string, expiresAt time.Time) {
tokenCache.mu.Lock()
defer tokenCache.mu.Unlock()
tokenCache.entries[githubToken] = cachedToken{
Token: token,
ExpiresAt: expiresAt,
}
}
func defaultHTTPClient(base *http.Client) *http.Client {
if base != nil {
clone := *base
if clone.Timeout == 0 {
clone.Timeout = defaultHTTPClientTimeout
}
return &clone
}
return &http.Client{Timeout: defaultHTTPClientTimeout}
}
+80
View File
@@ -0,0 +1,80 @@
package copilot
import (
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
type roundTripperFunc func(*http.Request) (*http.Response, error)
func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}
func TestNewHTTPClientAddsCopilotHeaders(t *testing.T) {
t.Parallel()
base := &http.Client{
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
if got := req.Header.Get("Copilot-Integration-Id"); got != copilotIntegrationID {
t.Fatalf("expected integration id %q, got %q", copilotIntegrationID, got)
}
if got := req.Header.Get("Editor-Version"); got != copilotEditorVersion {
t.Fatalf("expected editor version %q, got %q", copilotEditorVersion, got)
}
if got := req.Header.Get("Editor-Plugin-Version"); got != copilotPluginVersion {
t.Fatalf("expected plugin version %q, got %q", copilotPluginVersion, got)
}
if got := req.Header.Get("User-Agent"); got != copilotUserAgent {
t.Fatalf("expected user agent %q, got %q", copilotUserAgent, got)
}
if got := req.Header.Get("X-GitHub-Api-Version"); got != copilotAPIVersion {
t.Fatalf("expected api version %q, got %q", copilotAPIVersion, got)
}
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(`ok`)),
Header: make(http.Header),
}, nil
}),
}
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://api.githubcopilot.com/chat/completions", nil)
if err != nil {
t.Fatalf("create request: %v", err)
}
resp, err := NewHTTPClient(base).Do(req) //nolint:gosec // Test request targets a fixed Copilot API endpoint.
if err != nil {
t.Fatalf("execute request: %v", err)
}
_ = resp.Body.Close()
}
func TestNewHTTPClientWithNilBaseDoesNotPanic(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if got := req.Header.Get("Copilot-Integration-Id"); got != copilotIntegrationID {
t.Fatalf("expected integration id %q, got %q", copilotIntegrationID, got)
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`ok`))
}))
defer server.Close()
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, nil)
if err != nil {
t.Fatalf("create request: %v", err)
}
resp, err := NewHTTPClient(nil).Do(req) //nolint:gosec // Test request targets an httptest server URL.
if err != nil {
t.Fatalf("execute request with nil base client: %v", err)
}
_ = resp.Body.Close()
}
+30
View File
@@ -0,0 +1,30 @@
package copilot
import (
"net/http"
"strings"
githubcopilot "github.com/memohai/twilight-ai/provider/github/copilot"
sdk "github.com/memohai/twilight-ai/sdk"
)
func NewProvider(copilotToken string, baseClient *http.Client) sdk.Provider {
options := []githubcopilot.Option{
githubcopilot.WithGitHubToken(strings.TrimSpace(copilotToken)),
githubcopilot.WithBaseURL(DefaultAPIBaseURL),
githubcopilot.WithHTTPClient(NewHTTPClient(baseClient)),
}
return githubcopilot.New(options...)
}
func NewModel(copilotToken, modelID string, baseClient *http.Client) *sdk.Model {
options := []githubcopilot.Option{
githubcopilot.WithGitHubToken(strings.TrimSpace(copilotToken)),
githubcopilot.WithBaseURL(DefaultAPIBaseURL),
githubcopilot.WithHTTPClient(NewHTTPClient(baseClient)),
}
if strings.TrimSpace(modelID) == "" {
modelID = githubcopilot.AutoModel
}
return githubcopilot.New(options...).ChatModel(modelID)
}
+16
View File
@@ -515,3 +515,19 @@ type UserChannelBinding struct {
CreatedAt pgtype.Timestamptz `json:"created_at"`
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
}
type UserProviderOauthToken struct {
ID pgtype.UUID `json:"id"`
ProviderID pgtype.UUID `json:"provider_id"`
UserID pgtype.UUID `json:"user_id"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
Scope string `json:"scope"`
TokenType string `json:"token_type"`
State string `json:"state"`
PkceCodeVerifier string `json:"pkce_code_verifier"`
Metadata []byte `json:"metadata"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
}
+205
View File
@@ -0,0 +1,205 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: user_provider_oauth.sql
package sqlc
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
const deleteUserProviderOAuthToken = `-- name: DeleteUserProviderOAuthToken :exec
DELETE FROM user_provider_oauth_tokens
WHERE provider_id = $1
AND user_id = $2
`
type DeleteUserProviderOAuthTokenParams struct {
ProviderID pgtype.UUID `json:"provider_id"`
UserID pgtype.UUID `json:"user_id"`
}
func (q *Queries) DeleteUserProviderOAuthToken(ctx context.Context, arg DeleteUserProviderOAuthTokenParams) error {
_, err := q.db.Exec(ctx, deleteUserProviderOAuthToken, arg.ProviderID, arg.UserID)
return err
}
const getUserProviderOAuthToken = `-- name: GetUserProviderOAuthToken :one
SELECT id, provider_id, user_id, access_token, refresh_token, expires_at, scope, token_type, state, pkce_code_verifier, metadata, created_at, updated_at FROM user_provider_oauth_tokens
WHERE provider_id = $1
AND user_id = $2
`
type GetUserProviderOAuthTokenParams struct {
ProviderID pgtype.UUID `json:"provider_id"`
UserID pgtype.UUID `json:"user_id"`
}
func (q *Queries) GetUserProviderOAuthToken(ctx context.Context, arg GetUserProviderOAuthTokenParams) (UserProviderOauthToken, error) {
row := q.db.QueryRow(ctx, getUserProviderOAuthToken, arg.ProviderID, arg.UserID)
var i UserProviderOauthToken
err := row.Scan(
&i.ID,
&i.ProviderID,
&i.UserID,
&i.AccessToken,
&i.RefreshToken,
&i.ExpiresAt,
&i.Scope,
&i.TokenType,
&i.State,
&i.PkceCodeVerifier,
&i.Metadata,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const getUserProviderOAuthTokenByState = `-- name: GetUserProviderOAuthTokenByState :one
SELECT id, provider_id, user_id, access_token, refresh_token, expires_at, scope, token_type, state, pkce_code_verifier, metadata, created_at, updated_at FROM user_provider_oauth_tokens
WHERE state = $1
AND state != ''
`
func (q *Queries) GetUserProviderOAuthTokenByState(ctx context.Context, state string) (UserProviderOauthToken, error) {
row := q.db.QueryRow(ctx, getUserProviderOAuthTokenByState, state)
var i UserProviderOauthToken
err := row.Scan(
&i.ID,
&i.ProviderID,
&i.UserID,
&i.AccessToken,
&i.RefreshToken,
&i.ExpiresAt,
&i.Scope,
&i.TokenType,
&i.State,
&i.PkceCodeVerifier,
&i.Metadata,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const updateUserProviderOAuthState = `-- name: UpdateUserProviderOAuthState :exec
INSERT INTO user_provider_oauth_tokens (provider_id, user_id, state, pkce_code_verifier, metadata)
VALUES (
$1,
$2,
$3,
$4,
$5
)
ON CONFLICT (provider_id, user_id) DO UPDATE SET
state = EXCLUDED.state,
pkce_code_verifier = EXCLUDED.pkce_code_verifier,
metadata = EXCLUDED.metadata,
updated_at = now()
`
type UpdateUserProviderOAuthStateParams struct {
ProviderID pgtype.UUID `json:"provider_id"`
UserID pgtype.UUID `json:"user_id"`
State string `json:"state"`
PkceCodeVerifier string `json:"pkce_code_verifier"`
Metadata []byte `json:"metadata"`
}
func (q *Queries) UpdateUserProviderOAuthState(ctx context.Context, arg UpdateUserProviderOAuthStateParams) error {
_, err := q.db.Exec(ctx, updateUserProviderOAuthState,
arg.ProviderID,
arg.UserID,
arg.State,
arg.PkceCodeVerifier,
arg.Metadata,
)
return err
}
const upsertUserProviderOAuthToken = `-- name: UpsertUserProviderOAuthToken :one
INSERT INTO user_provider_oauth_tokens (
provider_id,
user_id,
access_token,
refresh_token,
expires_at,
scope,
token_type,
state,
pkce_code_verifier,
metadata
)
VALUES (
$1,
$2,
$3,
$4,
$5,
$6,
$7,
$8,
$9,
$10
)
ON CONFLICT (provider_id, user_id) DO UPDATE SET
access_token = EXCLUDED.access_token,
refresh_token = EXCLUDED.refresh_token,
expires_at = EXCLUDED.expires_at,
scope = EXCLUDED.scope,
token_type = EXCLUDED.token_type,
state = EXCLUDED.state,
pkce_code_verifier = EXCLUDED.pkce_code_verifier,
metadata = EXCLUDED.metadata,
updated_at = now()
RETURNING id, provider_id, user_id, access_token, refresh_token, expires_at, scope, token_type, state, pkce_code_verifier, metadata, created_at, updated_at
`
type UpsertUserProviderOAuthTokenParams struct {
ProviderID pgtype.UUID `json:"provider_id"`
UserID pgtype.UUID `json:"user_id"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
Scope string `json:"scope"`
TokenType string `json:"token_type"`
State string `json:"state"`
PkceCodeVerifier string `json:"pkce_code_verifier"`
Metadata []byte `json:"metadata"`
}
func (q *Queries) UpsertUserProviderOAuthToken(ctx context.Context, arg UpsertUserProviderOAuthTokenParams) (UserProviderOauthToken, error) {
row := q.db.QueryRow(ctx, upsertUserProviderOAuthToken,
arg.ProviderID,
arg.UserID,
arg.AccessToken,
arg.RefreshToken,
arg.ExpiresAt,
arg.Scope,
arg.TokenType,
arg.State,
arg.PkceCodeVerifier,
arg.Metadata,
)
var i UserProviderOauthToken
err := row.Scan(
&i.ID,
&i.ProviderID,
&i.UserID,
&i.AccessToken,
&i.RefreshToken,
&i.ExpiresAt,
&i.Scope,
&i.TokenType,
&i.State,
&i.PkceCodeVerifier,
&i.Metadata,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
+8 -1
View File
@@ -10,7 +10,9 @@ import (
"github.com/jackc/pgx/v5"
"github.com/labstack/echo/v4"
"github.com/memohai/memoh/internal/auth"
"github.com/memohai/memoh/internal/models"
"github.com/memohai/memoh/internal/oauthctx"
)
type ModelsHandler struct {
@@ -301,7 +303,12 @@ func (h *ModelsHandler) Test(c echo.Context) error {
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
}
resp, err := h.service.Test(c.Request().Context(), id)
ctx := c.Request().Context()
if userID, err := auth.UserIDFromContext(c); err == nil {
ctx = oauthctx.WithUserID(ctx, userID)
}
resp, err := h.service.Test(ctx, id)
if err != nil {
if strings.Contains(err.Error(), "invalid") {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
+47 -8
View File
@@ -7,6 +7,8 @@ import (
"github.com/labstack/echo/v4"
"github.com/memohai/memoh/internal/auth"
"github.com/memohai/memoh/internal/oauthctx"
"github.com/memohai/memoh/internal/providers"
)
@@ -20,6 +22,7 @@ func NewProviderOAuthHandler(service *providers.Service) *ProviderOAuthHandler {
func (h *ProviderOAuthHandler) Register(e *echo.Echo) {
e.GET("/providers/:id/oauth/authorize", h.Authorize)
e.POST("/providers/:id/oauth/poll", h.Poll)
e.GET("/providers/:id/oauth/status", h.Status)
e.DELETE("/providers/:id/oauth/token", h.Revoke)
e.GET("/auth/callback", h.Callback)
@@ -30,7 +33,7 @@ func (h *ProviderOAuthHandler) Register(e *echo.Echo) {
// @Summary Start OAuth2 authorization for an LLM provider
// @Tags providers-oauth
// @Param id path string true "Provider ID (UUID)"
// @Success 200 {object} map[string]string
// @Success 200 {object} providers.OAuthAuthorizeResponse
// @Failure 400 {object} ErrorResponse
// @Failure 404 {object} ErrorResponse
// @Router /providers/{id}/oauth/authorize [get].
@@ -39,11 +42,39 @@ func (h *ProviderOAuthHandler) Authorize(c echo.Context) error {
if providerID == "" {
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
}
authURL, err := h.service.StartOAuthAuthorization(c.Request().Context(), providerID)
ctx := c.Request().Context()
if userID, err := auth.UserIDFromContext(c); err == nil {
ctx = oauthctx.WithUserID(ctx, userID)
}
resp, err := h.service.StartOAuthAuthorization(ctx, providerID)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
return c.JSON(http.StatusOK, map[string]string{"auth_url": authURL})
return c.JSON(http.StatusOK, resp)
}
// Poll godoc
// @Summary Poll OAuth device authorization for an LLM provider
// @Tags providers-oauth
// @Param id path string true "Provider ID (UUID)"
// @Success 200 {object} providers.OAuthStatus
// @Failure 400 {object} ErrorResponse
// @Failure 404 {object} ErrorResponse
// @Router /providers/{id}/oauth/poll [post].
func (h *ProviderOAuthHandler) Poll(c echo.Context) error {
providerID := strings.TrimSpace(c.Param("id"))
if providerID == "" {
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
}
ctx := c.Request().Context()
if userID, err := auth.UserIDFromContext(c); err == nil {
ctx = oauthctx.WithUserID(ctx, userID)
}
status, err := h.service.PollOAuthAuthorization(ctx, providerID)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
return c.JSON(http.StatusOK, status)
}
// Status godoc
@@ -59,7 +90,11 @@ func (h *ProviderOAuthHandler) Status(c echo.Context) error {
if providerID == "" {
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
}
status, err := h.service.GetOAuthStatus(c.Request().Context(), providerID)
ctx := c.Request().Context()
if userID, err := auth.UserIDFromContext(c); err == nil {
ctx = oauthctx.WithUserID(ctx, userID)
}
status, err := h.service.GetOAuthStatus(ctx, providerID)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
@@ -79,7 +114,11 @@ func (h *ProviderOAuthHandler) Revoke(c echo.Context) error {
if providerID == "" {
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
}
if err := h.service.RevokeOAuthToken(c.Request().Context(), providerID); err != nil {
ctx := c.Request().Context()
if userID, err := auth.UserIDFromContext(c); err == nil {
ctx = oauthctx.WithUserID(ctx, userID)
}
if err := h.service.RevokeOAuthToken(ctx, providerID); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
return c.NoContent(http.StatusNoContent)
@@ -111,11 +150,11 @@ func (h *ProviderOAuthHandler) Callback(c echo.Context) error {
<html>
<head>
<meta charset="utf-8">
<title>OpenAI OAuth Connected</title>
<title>Provider Connected</title>
</head>
<body style="font-family: sans-serif; padding: 24px;">
<h2>OpenAI OAuth connected</h2>
<p>You can close this window and return to Memoh.</p>
<h2>Provider connected</h2>
<p>Your current Memoh account is now connected.</p>
<script>
window.opener?.postMessage({ type: "memoh-provider-oauth-success", providerId: "{{.ProviderID}}" }, "*");
window.close();
+14 -2
View File
@@ -9,7 +9,9 @@ import (
"github.com/labstack/echo/v4"
"github.com/memohai/memoh/internal/auth"
"github.com/memohai/memoh/internal/models"
"github.com/memohai/memoh/internal/oauthctx"
"github.com/memohai/memoh/internal/providers"
)
@@ -272,7 +274,12 @@ func (h *ProvidersHandler) Test(c echo.Context) error {
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
}
resp, err := h.service.Test(c.Request().Context(), id)
ctx := c.Request().Context()
if userID, err := auth.UserIDFromContext(c); err == nil {
ctx = oauthctx.WithUserID(ctx, userID)
}
resp, err := h.service.Test(ctx, id)
if err != nil {
if strings.Contains(err.Error(), "invalid") {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
@@ -301,7 +308,12 @@ func (h *ProvidersHandler) ImportModels(c echo.Context) error {
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
}
remoteModels, err := h.service.FetchRemoteModels(c.Request().Context(), id)
ctx := c.Request().Context()
if userID, err := auth.UserIDFromContext(c); err == nil {
ctx = oauthctx.WithUserID(ctx, userID)
}
remoteModels, err := h.service.FetchRemoteModels(ctx, id)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("fetch remote models: %v", err))
}
@@ -10,6 +10,7 @@ import (
"github.com/memohai/memoh/internal/healthcheck"
"github.com/memohai/memoh/internal/models"
"github.com/memohai/memoh/internal/oauthctx"
)
const (
@@ -25,6 +26,7 @@ type BotModelLookup interface {
// BotModels holds the model UUIDs associated with a bot.
type BotModels struct {
OwnerUserID string
ChatModelID string
MemoryModelID string
EmbeddingModelID string
@@ -115,7 +117,8 @@ func (c *Checker) ListChecks(ctx context.Context, botID string) []healthcheck.Ch
wg.Add(1)
go func(idx int, s modelSlot) {
defer wg.Done()
results[idx] = c.probeSlot(probeCtx, s)
slotCtx := oauthctx.WithUserID(probeCtx, botModels.OwnerUserID)
results[idx] = c.probeSlot(slotCtx, s)
}(i, slot)
}
wg.Wait()
@@ -0,0 +1,59 @@
package modelchecker
import (
"context"
"testing"
"github.com/memohai/memoh/internal/healthcheck"
"github.com/memohai/memoh/internal/models"
"github.com/memohai/memoh/internal/oauthctx"
)
type testLookup struct {
models BotModels
}
func (l testLookup) GetBotModelIDs(context.Context, string) (BotModels, error) {
return l.models, nil
}
type testProber struct {
t *testing.T
wantUserID string
}
func (p testProber) Test(ctx context.Context, id string) (models.TestResponse, error) {
if got := oauthctx.UserIDFromContext(ctx); got != p.wantUserID {
p.t.Fatalf("expected oauth user id %q, got %q", p.wantUserID, got)
}
if id != "model-chat-1" {
p.t.Fatalf("expected model id %q, got %q", "model-chat-1", id)
}
return models.TestResponse{
Status: models.TestStatusOK,
Reachable: true,
Message: "ok",
}, nil
}
func TestCheckerListChecksInjectsOwnerUserIDIntoProbeContext(t *testing.T) {
t.Parallel()
checker := NewChecker(nil, testLookup{
models: BotModels{
OwnerUserID: "user-123",
ChatModelID: "model-chat-1",
},
}, testProber{
t: t,
wantUserID: "user-123",
})
items := checker.ListChecks(context.Background(), "bot-1")
if len(items) != 1 {
t.Fatalf("expected 1 check result, got %d", len(items))
}
if items[0].Status != healthcheck.StatusOK {
t.Fatalf("expected status %q, got %q", healthcheck.StatusOK, items[0].Status)
}
}
@@ -36,6 +36,7 @@ func (l *QueriesLookup) GetBotModelIDs(ctx context.Context, botID string) (BotMo
}
var m BotModels
m.OwnerUserID = bot.OwnerUserID.String()
if bot.ChatModelID.Valid {
m.ChatModelID = bot.ChatModelID.String()
}
+1
View File
@@ -429,6 +429,7 @@ func IsValidClientType(clientType ClientType) bool {
ClientTypeAnthropicMessages,
ClientTypeGoogleGenerativeAI,
ClientTypeOpenAICodex,
ClientTypeGitHubCopilot,
ClientTypeEdgeSpeech:
return true
default:
+53 -12
View File
@@ -17,8 +17,10 @@ import (
openairesponses "github.com/memohai/twilight-ai/provider/openai/responses"
sdk "github.com/memohai/twilight-ai/sdk"
memohcopilot "github.com/memohai/memoh/internal/copilot"
"github.com/memohai/memoh/internal/db"
"github.com/memohai/memoh/internal/db/sqlc"
"github.com/memohai/memoh/internal/oauthctx"
)
const probeTimeout = 15 * time.Second
@@ -162,6 +164,9 @@ func NewSDKProvider(baseURL, apiKey, codexAccountID string, clientType ClientTyp
}
return openaicodex.New(opts...)
case ClientTypeGitHubCopilot:
return memohcopilot.NewProvider(apiKey, httpClient)
case ClientTypeAnthropicMessages:
opts := []anthropicmessages.Option{
anthropicmessages.WithAPIKey(apiKey),
@@ -202,26 +207,62 @@ type modelCredentials struct {
func (s *Service) resolveModelCredentials(ctx context.Context, provider sqlc.Provider) (modelCredentials, error) {
apiKey := providerConfigString(provider.Config, "api_key")
if ClientType(provider.ClientType) != ClientTypeOpenAICodex {
switch ClientType(provider.ClientType) {
case ClientTypeGitHubCopilot:
token, err := s.resolveGitHubCopilotAccessToken(ctx, provider)
if err != nil {
return modelCredentials{}, err
}
return modelCredentials{APIKey: token}, nil
case ClientTypeOpenAICodex:
tokenRow, err := s.queries.GetProviderOAuthTokenByProvider(ctx, provider.ID)
if err != nil {
return modelCredentials{}, err
}
accessToken := strings.TrimSpace(tokenRow.AccessToken)
if accessToken == "" {
return modelCredentials{}, errors.New("oauth token is missing access token")
}
accountID, err := codexAccountIDFromToken(accessToken)
if err != nil {
return modelCredentials{}, err
}
return modelCredentials{
APIKey: accessToken,
CodexAccountID: accountID,
}, nil
default:
return modelCredentials{APIKey: apiKey}, nil
}
}
tokenRow, err := s.queries.GetProviderOAuthTokenByProvider(ctx, provider.ID)
if err != nil {
return modelCredentials{}, err
func (s *Service) resolveGitHubCopilotAccessToken(ctx context.Context, provider sqlc.Provider) (string, error) {
userID := oauthctx.UserIDFromContext(ctx)
if userID == "" {
return "", errors.New("github copilot requires a current user")
}
accessToken := strings.TrimSpace(tokenRow.AccessToken)
userUUID, err := db.ParseUUID(userID)
if err != nil {
return "", err
}
row, err := s.queries.GetUserProviderOAuthToken(ctx, sqlc.GetUserProviderOAuthTokenParams{
ProviderID: provider.ID,
UserID: userUUID,
})
if err != nil {
return "", err
}
accessToken := strings.TrimSpace(row.AccessToken)
if accessToken == "" {
return modelCredentials{}, errors.New("oauth token is missing access token")
return "", errors.New("oauth token is missing access token")
}
accountID, err := codexAccountIDFromToken(accessToken)
copilotToken, err := memohcopilot.ResolveToken(ctx, accessToken)
if err != nil {
return modelCredentials{}, err
return "", err
}
return modelCredentials{
APIKey: accessToken,
CodexAccountID: accountID,
}, nil
return copilotToken, nil
}
func codexAccountIDFromToken(token string) (string, error) {
+7
View File
@@ -10,6 +10,8 @@ import (
openaicompletions "github.com/memohai/twilight-ai/provider/openai/completions"
openairesponses "github.com/memohai/twilight-ai/provider/openai/responses"
sdk "github.com/memohai/twilight-ai/sdk"
memohcopilot "github.com/memohai/memoh/internal/copilot"
)
// SDKModelConfig holds provider and model information resolved from DB,
@@ -76,6 +78,9 @@ func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model {
}
return openaicodex.New(opts...).ChatModel(cfg.ModelID)
case ClientTypeGitHubCopilot:
return memohcopilot.NewModel(cfg.APIKey, cfg.ModelID, cfg.HTTPClient)
case ClientTypeAnthropicMessages:
opts := []anthropicmessages.Option{
anthropicmessages.WithAPIKey(cfg.APIKey),
@@ -178,6 +183,8 @@ func ResolveClientType(model *sdk.Model) string {
return string(ClientTypeAnthropicMessages)
case strings.Contains(name, "google"):
return string(ClientTypeGoogleGenerativeAI)
case strings.Contains(name, "github-copilot"), strings.Contains(name, "copilot"):
return string(ClientTypeGitHubCopilot)
case strings.Contains(name, "codex"):
return string(ClientTypeOpenAICodex)
case strings.Contains(name, "responses"):
+1
View File
@@ -22,6 +22,7 @@ const (
ClientTypeAnthropicMessages ClientType = "anthropic-messages"
ClientTypeGoogleGenerativeAI ClientType = "google-generative-ai"
ClientTypeOpenAICodex ClientType = "openai-codex"
ClientTypeGitHubCopilot ClientType = "github-copilot"
ClientTypeEdgeSpeech ClientType = "edge-speech"
)
+21
View File
@@ -0,0 +1,21 @@
package oauthctx
import (
"context"
"strings"
)
type userIDContextKey struct{}
func WithUserID(ctx context.Context, userID string) context.Context {
userID = strings.TrimSpace(userID)
if userID == "" {
return ctx
}
return context.WithValue(ctx, userIDContextKey{}, userID)
}
func UserIDFromContext(ctx context.Context) string {
userID, _ := ctx.Value(userIDContextKey{}).(string)
return strings.TrimSpace(userID)
}
+28 -14
View File
@@ -8,6 +8,7 @@ import (
"fmt"
"strings"
memohcopilot "github.com/memohai/memoh/internal/copilot"
"github.com/memohai/memoh/internal/db/sqlc"
"github.com/memohai/memoh/internal/models"
)
@@ -24,25 +25,38 @@ func SupportsOpenAICodexOAuth(provider sqlc.Provider) bool {
}
func (s *Service) ResolveModelCredentials(ctx context.Context, provider sqlc.Provider) (ModelCredentials, error) {
if models.ClientType(provider.ClientType) != models.ClientTypeOpenAICodex {
switch models.ClientType(provider.ClientType) {
case models.ClientTypeGitHubCopilot:
githubToken, err := s.GetValidAccessToken(ctx, provider.ID.String())
if err != nil {
return ModelCredentials{}, err
}
copilotToken, err := memohcopilot.ResolveToken(ctx, githubToken)
if err != nil {
return ModelCredentials{}, err
}
return ModelCredentials{APIKey: copilotToken}, nil
case models.ClientTypeOpenAICodex:
token, err := s.GetValidAccessToken(ctx, provider.ID.String())
if err != nil {
return ModelCredentials{}, err
}
accountID, err := codexAccountIDFromToken(token)
if err != nil {
return ModelCredentials{}, err
}
return ModelCredentials{
APIKey: token,
CodexAccountID: accountID,
}, nil
default:
apiKey := ProviderConfigString(provider, "api_key")
return ModelCredentials{
APIKey: apiKey,
}, nil
}
token, err := s.GetValidAccessToken(ctx, provider.ID.String())
if err != nil {
return ModelCredentials{}, err
}
accountID, err := codexAccountIDFromToken(token)
if err != nil {
return ModelCredentials{}, err
}
return ModelCredentials{
APIKey: token,
CodexAccountID: accountID,
}, nil
}
func codexAccountIDFromToken(token string) (string, error) {
File diff suppressed because it is too large Load Diff
+81 -16
View File
@@ -11,9 +11,11 @@ import (
"time"
"github.com/jackc/pgx/v5/pgtype"
githubcopilot "github.com/memohai/twilight-ai/provider/github/copilot"
openaicodex "github.com/memohai/twilight-ai/provider/openai/codex"
sdk "github.com/memohai/twilight-ai/sdk"
memohcopilot "github.com/memohai/memoh/internal/copilot"
"github.com/memohai/memoh/internal/db"
"github.com/memohai/memoh/internal/db/sqlc"
"github.com/memohai/memoh/internal/models"
@@ -47,15 +49,14 @@ func (s *Service) Create(ctx context.Context, req CreateRequest) (GetResponse, e
return GetResponse{}, fmt.Errorf("marshal metadata: %w", err)
}
configJSON, err := json.Marshal(req.Config)
if err != nil {
return GetResponse{}, fmt.Errorf("marshal config: %w", err)
}
clientType := req.ClientType
if clientType == "" {
clientType = string(models.ClientTypeOpenAICompletions)
}
configJSON, err := json.Marshal(normalizeProviderConfig(clientType, req.Config))
if err != nil {
return GetResponse{}, fmt.Errorf("marshal config: %w", err)
}
var icon pgtype.Text
if req.Icon != "" {
@@ -150,12 +151,11 @@ func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Get
existingConfig := providerConfig(existing.Config)
if req.Config != nil {
existingAPIKey := configString(existingConfig, "api_key")
newAPIKey := configString(req.Config, "api_key")
if newAPIKey != "" && newAPIKey == maskAPIKey(existingAPIKey) {
req.Config["api_key"] = existingAPIKey
}
existingConfig = req.Config
mergedConfig := mergeProviderConfig(existingConfig, req.Config)
preserveMaskedConfigSecret(mergedConfig, existingConfig, req.Config, "api_key")
existingConfig = normalizeProviderConfig(clientType, mergedConfig)
} else {
existingConfig = normalizeProviderConfig(clientType, existingConfig)
}
configJSON, err := json.Marshal(existingConfig)
if err != nil {
@@ -257,6 +257,34 @@ func (s *Service) FetchRemoteModels(ctx context.Context, id string) ([]RemoteMod
if err != nil {
return nil, fmt.Errorf("get provider: %w", err)
}
if models.ClientType(provider.ClientType) == models.ClientTypeGitHubCopilot {
creds, err := s.ResolveModelCredentials(ctx, provider)
if err != nil {
return nil, err
}
sdkProvider := memohcopilot.NewProvider(creds.APIKey, nil)
if result := sdkProvider.Test(ctx); result.Status != sdk.ProviderStatusOK {
return nil, fmt.Errorf("github copilot provider test failed: %s", result.Message)
}
catalog := githubcopilot.Catalog()
remoteModels := make([]RemoteModel, 0, len(catalog))
for _, model := range catalog {
remoteModels = append(remoteModels, RemoteModel{
ID: model.ID,
Name: model.DisplayName,
Object: "model",
OwnedBy: "github-copilot",
Type: "chat",
Compatibilities: []string{
models.CompatVision,
models.CompatToolCall,
models.CompatReasoning,
},
})
}
return remoteModels, nil
}
if supportsOAuth(provider) {
catalog := openaicodex.Catalog()
remoteModels := make([]RemoteModel, 0, len(catalog))
@@ -329,7 +357,7 @@ func (s *Service) toGetResponse(provider sqlc.Provider) GetResponse {
}
cfg := providerConfig(provider.Config)
maskedCfg := maskConfigAPIKey(cfg)
maskedCfg := maskConfigSecrets(provider.ClientType, cfg)
var icon string
if provider.Icon.Valid {
@@ -378,14 +406,51 @@ func ProviderConfigString(provider sqlc.Provider, key string) string {
return configString(providerConfig(provider.Config), key)
}
// maskConfigAPIKey returns a copy of config with api_key masked.
func maskConfigAPIKey(cfg map[string]any) map[string]any {
func cloneConfig(cfg map[string]any) map[string]any {
result := make(map[string]any, len(cfg))
for k, v := range cfg {
result[k] = v
}
if apiKey, _ := result["api_key"].(string); apiKey != "" {
result["api_key"] = maskAPIKey(apiKey)
return result
}
func mergeProviderConfig(existing, incoming map[string]any) map[string]any {
result := cloneConfig(existing)
for k, v := range incoming {
result[k] = v
}
return result
}
func preserveMaskedConfigSecret(merged, existing, incoming map[string]any, key string) {
existingValue := strings.TrimSpace(configString(existing, key))
newValue := strings.TrimSpace(configString(incoming, key))
if existingValue == "" || newValue == "" {
return
}
if newValue == maskAPIKey(existingValue) {
merged[key] = existingValue
}
}
// normalizeProviderConfig keeps provider-specific secrets under stable keys while
// preserving backward compatibility for legacy stored configs.
func normalizeProviderConfig(clientType string, cfg map[string]any) map[string]any {
result := cloneConfig(cfg)
if models.ClientType(clientType) == models.ClientTypeGitHubCopilot {
delete(result, "api_key")
delete(result, configOAuthClientSecretKey)
}
return result
}
// maskConfigSecrets returns a copy of config with all known secret fields masked.
func maskConfigSecrets(clientType string, cfg map[string]any) map[string]any {
result := normalizeProviderConfig(clientType, cfg)
for _, key := range []string{"api_key", configOAuthClientSecretKey} {
if value, _ := result[key].(string); value != "" {
result[key] = maskAPIKey(value)
}
}
return result
}
+170 -1
View File
@@ -1,6 +1,12 @@
package providers
import "testing"
import (
"testing"
"time"
"github.com/memohai/memoh/internal/db/sqlc"
"github.com/memohai/memoh/internal/models"
)
func TestMaskAPIKey(t *testing.T) {
t.Parallel()
@@ -34,3 +40,166 @@ func TestMaskAPIKey(t *testing.T) {
}
})
}
func TestNormalizeProviderConfig(t *testing.T) {
t.Parallel()
t.Run("github copilot drops legacy secrets", func(t *testing.T) {
t.Parallel()
cfg := normalizeProviderConfig("github-copilot", map[string]any{
"api_key": "gh-secret",
configOAuthClientSecretKey: "oauth-secret",
"base_url": "ignored",
})
if _, exists := cfg[configOAuthClientSecretKey]; exists {
t.Fatalf("expected oauth client secret to be removed, got %#v", cfg[configOAuthClientSecretKey])
}
if _, exists := cfg["api_key"]; exists {
t.Fatalf("expected legacy api_key to be removed, got %#v", cfg["api_key"])
}
})
t.Run("non copilot providers keep api key key", func(t *testing.T) {
t.Parallel()
cfg := normalizeProviderConfig("openai-completions", map[string]any{
"api_key": "sk-live",
})
if got, ok := cfg["api_key"].(string); !ok || got != "sk-live" {
t.Fatalf("expected api_key to remain untouched, got %#v", cfg["api_key"])
}
})
}
func TestMaskConfigSecrets(t *testing.T) {
t.Parallel()
cfg := maskConfigSecrets("openai-completions", map[string]any{
"api_key": "sk-secret-123456",
})
masked, _ := cfg["api_key"].(string)
if masked == "" || masked == "sk-secret-123456" {
t.Fatalf("expected api key to be masked, got %q", masked)
}
}
func TestPreserveMaskedConfigSecret(t *testing.T) {
t.Parallel()
merged := map[string]any{
configOAuthClientSecretKey: "*************",
}
existing := map[string]any{
configOAuthClientSecretKey: "gh-secret-1234",
}
incoming := map[string]any{
configOAuthClientSecretKey: maskAPIKey("gh-secret-1234"),
}
preserveMaskedConfigSecret(merged, existing, incoming, configOAuthClientSecretKey)
if got, _ := merged[configOAuthClientSecretKey].(string); got != "gh-secret-1234" {
t.Fatalf("expected masked value to be restored to original secret, got %q", got)
}
}
func TestDeviceMetadataRoundTrip(t *testing.T) {
t.Parallel()
expiresAt := time.Date(2026, time.April, 11, 12, 0, 0, 0, time.UTC)
device := oauthDeviceMetadata{
DeviceCode: "device-code",
UserCode: "ABCD-EFGH",
VerificationURI: "https://github.com/login/device",
ExpiresAt: expiresAt,
IntervalSeconds: 5,
}
parsed := deviceMetadataFromMap(device.toMetadata())
if parsed.DeviceCode != device.DeviceCode {
t.Fatalf("expected device code %q, got %q", device.DeviceCode, parsed.DeviceCode)
}
if parsed.UserCode != device.UserCode {
t.Fatalf("expected user code %q, got %q", device.UserCode, parsed.UserCode)
}
if parsed.VerificationURI != device.VerificationURI {
t.Fatalf("expected verification uri %q, got %q", device.VerificationURI, parsed.VerificationURI)
}
if !parsed.ExpiresAt.Equal(expiresAt) {
t.Fatalf("expected expiresAt %s, got %s", expiresAt, parsed.ExpiresAt)
}
if parsed.IntervalSeconds != device.IntervalSeconds {
t.Fatalf("expected interval %d, got %d", device.IntervalSeconds, parsed.IntervalSeconds)
}
status := parsed.toStatus()
if status == nil || !status.Pending {
t.Fatalf("expected pending device status, got %#v", status)
}
}
func TestAccountMetadataRoundTrip(t *testing.T) {
t.Parallel()
account := oauthAccountMetadata{
Label: "octocat",
Login: "octocat",
Name: "The Octocat",
Email: "octocat@github.com",
AvatarURL: "https://avatars.githubusercontent.com/u/1?v=4",
ProfileURL: "https://github.com/octocat",
}
parsed := accountMetadataFromMap(account.toMetadata())
if parsed.Label != account.Label {
t.Fatalf("expected label %q, got %q", account.Label, parsed.Label)
}
if parsed.Login != account.Login {
t.Fatalf("expected login %q, got %q", account.Login, parsed.Login)
}
if parsed.Name != account.Name {
t.Fatalf("expected name %q, got %q", account.Name, parsed.Name)
}
if parsed.Email != account.Email {
t.Fatalf("expected email %q, got %q", account.Email, parsed.Email)
}
if parsed.AvatarURL != account.AvatarURL {
t.Fatalf("expected avatar url %q, got %q", account.AvatarURL, parsed.AvatarURL)
}
if parsed.ProfileURL != account.ProfileURL {
t.Fatalf("expected profile url %q, got %q", account.ProfileURL, parsed.ProfileURL)
}
status := parsed.toStatus()
if status == nil {
t.Fatal("expected account status")
}
if status.Label != account.Label {
t.Fatalf("expected status label %q, got %q", account.Label, status.Label)
}
}
func TestOAuthConfigForGitHubCopilotUsesFixedDeviceFlowSettings(t *testing.T) {
t.Parallel()
service := &Service{}
cfg := service.oauthConfigForProvider(sqlc.Provider{
ClientType: string(models.ClientTypeGitHubCopilot),
Config: []byte(`{"api_key":"legacy","oauth_client_secret":"legacy-secret"}`),
Metadata: []byte(`{"oauth_client_id":"custom","oauth_scopes":"repo"}`),
})
if cfg.ClientID != "Iv1.b507a08c87ecfe98" {
t.Fatalf("expected fixed client id, got %q", cfg.ClientID)
}
if cfg.ClientSecret != "" {
t.Fatalf("expected empty client secret, got %q", cfg.ClientSecret)
}
if cfg.Scopes != "read:user user:email" {
t.Fatalf("expected fixed scope, got %q", cfg.Scopes)
}
}
+31 -5
View File
@@ -54,11 +54,37 @@ type TestResponse struct {
// OAuthStatus is returned by GET /providers/:id/oauth/status.
type OAuthStatus struct {
Configured bool `json:"configured"`
HasToken bool `json:"has_token"`
Expired bool `json:"expired"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
CallbackURL string `json:"callback_url"`
Configured bool `json:"configured"`
Mode string `json:"mode,omitempty"`
HasToken bool `json:"has_token"`
Expired bool `json:"expired"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
CallbackURL string `json:"callback_url"`
Device *OAuthDeviceStatus `json:"device,omitempty"`
Account *OAuthAccount `json:"account,omitempty"`
}
type OAuthDeviceStatus struct {
Pending bool `json:"pending"`
UserCode string `json:"user_code,omitempty"`
VerificationURI string `json:"verification_uri,omitempty"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
IntervalSeconds int64 `json:"interval_seconds,omitempty"`
}
type OAuthAccount struct {
Label string `json:"label,omitempty"`
Login string `json:"login,omitempty"`
Name string `json:"name,omitempty"`
Email string `json:"email,omitempty"`
AvatarURL string `json:"avatar_url,omitempty"`
ProfileURL string `json:"profile_url,omitempty"`
}
type OAuthAuthorizeResponse struct {
Mode string `json:"mode,omitempty"`
AuthURL string `json:"auth_url,omitempty"`
Device *OAuthDeviceStatus `json:"device,omitempty"`
}
// RemoteModel represents a model returned by the provider's /v1/models endpoint.