mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
feat(provider): add github copilot device flow provider (#364)
This commit is contained in:
@@ -46,7 +46,7 @@
|
||||
</FormItem>
|
||||
</FormField>
|
||||
<FormField
|
||||
v-if="form.values.client_type !== 'openai-codex'"
|
||||
v-if="!['openai-codex', 'github-copilot'].includes(form.values.client_type)"
|
||||
v-slot="{ componentField }"
|
||||
name="api_key"
|
||||
>
|
||||
@@ -69,12 +69,13 @@
|
||||
</FormItem>
|
||||
</FormField>
|
||||
<div
|
||||
v-else
|
||||
v-else-if="['openai-codex', 'github-copilot'].includes(form.values.client_type)"
|
||||
class="rounded-lg border p-3 text-xs text-muted-foreground"
|
||||
>
|
||||
{{ $t('provider.oauth.createHint') }}
|
||||
{{ $t(form.values.client_type === 'github-copilot' ? 'provider.oauth.githubCreateHint' : 'provider.oauth.openaiCreateHint') }}
|
||||
</div>
|
||||
<FormField
|
||||
v-if="form.values.client_type !== 'github-copilot'"
|
||||
v-slot="{ componentField }"
|
||||
name="base_url"
|
||||
>
|
||||
@@ -188,12 +189,13 @@ const { mutateAsync: createProviderMutation, isLoading } = useMutation({
|
||||
mutation: async (data: Record<string, unknown>) => {
|
||||
const config: Record<string, unknown> = {}
|
||||
if (data.base_url) config.base_url = data.base_url
|
||||
if (data.api_key) config.api_key = data.api_key
|
||||
if (typeof data.api_key === 'string' && data.api_key.trim() !== '' && data.client_type !== 'github-copilot') {
|
||||
config.api_key = data.api_key.trim()
|
||||
}
|
||||
const payload = {
|
||||
name: data.name,
|
||||
client_type: data.client_type,
|
||||
config,
|
||||
metadata: { additionalProp1: {} },
|
||||
}
|
||||
const { data: result } = await postProviders({ body: payload as ProvidersCreateRequest, throwOnError: true })
|
||||
if (data.auto_import && result?.id) {
|
||||
@@ -221,18 +223,25 @@ const { mutateAsync: createProviderMutation, isLoading } = useMutation({
|
||||
|
||||
const providerSchema = toTypedSchema(z.object({
|
||||
api_key: z.string().optional(),
|
||||
base_url: z.string().min(1),
|
||||
base_url: z.string().optional(),
|
||||
name: z.string().min(1),
|
||||
client_type: z.string().min(1),
|
||||
auto_import: z.boolean().optional(),
|
||||
}).superRefine((value, ctx) => {
|
||||
if (value.client_type !== 'openai-codex' && !value.api_key?.trim()) {
|
||||
if (!['openai-codex', 'github-copilot'].includes(value.client_type) && !value.api_key?.trim()) {
|
||||
ctx.addIssue({
|
||||
code: z.ZodIssueCode.custom,
|
||||
path: ['api_key'],
|
||||
message: 'API key is required',
|
||||
})
|
||||
}
|
||||
if (value.client_type !== 'github-copilot' && !value.base_url?.trim()) {
|
||||
ctx.addIssue({
|
||||
code: z.ZodIssueCode.custom,
|
||||
path: ['base_url'],
|
||||
message: 'Base URL is required',
|
||||
})
|
||||
}
|
||||
}))
|
||||
|
||||
const form = useForm({
|
||||
@@ -240,14 +249,16 @@ const form = useForm({
|
||||
initialValues: {
|
||||
auto_import: false,
|
||||
client_type: 'openai-completions',
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
watch(() => form.values.client_type, (clientType) => {
|
||||
if (clientType !== 'openai-codex') return
|
||||
if (!form.values.base_url) {
|
||||
if (clientType === 'openai-codex' && !form.values.base_url) {
|
||||
form.setFieldValue('base_url', 'https://chatgpt.com/backend-api')
|
||||
}
|
||||
if (clientType === 'github-copilot') {
|
||||
form.setFieldValue('base_url', '')
|
||||
}
|
||||
})
|
||||
|
||||
const createProvider = form.handleSubmit(async (value) => {
|
||||
|
||||
@@ -20,6 +20,11 @@ export const CLIENT_TYPE_META: Record<string, ClientTypeMeta> = {
|
||||
label: 'OpenAI Codex',
|
||||
hint: 'Codex API (OAuth, coding-optimized)',
|
||||
},
|
||||
'github-copilot': {
|
||||
value: 'github-copilot',
|
||||
label: 'GitHub Copilot',
|
||||
hint: 'Device OAuth with GitHub account',
|
||||
},
|
||||
'anthropic-messages': {
|
||||
value: 'anthropic-messages',
|
||||
label: 'Anthropic Messages',
|
||||
|
||||
@@ -292,20 +292,35 @@
|
||||
"enable": "Enable",
|
||||
"enableHint": "Only models from enabled providers appear in the available model list",
|
||||
"oauth": {
|
||||
"title": "OpenAI OAuth",
|
||||
"description": "Authorize this provider with your ChatGPT account for Codex-compatible OpenAI access.",
|
||||
"createHint": "Save the provider first, then authorize it from the provider details panel.",
|
||||
"openaiTitle": "OpenAI OAuth",
|
||||
"openaiDescription": "Authorize this provider with your ChatGPT account for Codex-compatible OpenAI access.",
|
||||
"openaiCreateHint": "Save the provider first, then authorize it from the provider details panel.",
|
||||
"githubTitle": "GitHub Copilot OAuth",
|
||||
"githubDescription": "Connect the current Memoh account with GitHub Copilot.",
|
||||
"githubDeviceTitle": "GitHub Copilot Device Authorization",
|
||||
"githubDeviceDescription": "Start device authorization, open GitHub's verification page, and enter the user code shown below.",
|
||||
"githubCreateHint": "Save the provider first, then start device authorization from the provider details panel.",
|
||||
"githubDeviceHint": "Open the verification URL below and enter this user code to authorize the current Memoh account.",
|
||||
"authorize": "Authorize",
|
||||
"deviceAuthorize": "Start Device Authorization",
|
||||
"authorizeFailed": "Failed to start authorization",
|
||||
"authorizeSuccess": "Authorization successful",
|
||||
"revoke": "Revoke",
|
||||
"revokeFailed": "Failed to revoke authorization",
|
||||
"revokeSuccess": "Authorization revoked",
|
||||
"copyFailed": "Failed to copy device code",
|
||||
"connectedAccount": "Connected Account",
|
||||
"callback": "Callback URL",
|
||||
"deviceVerificationUri": "Verification URL",
|
||||
"deviceUserCode": "User Code",
|
||||
"deviceExpiresAt": "Expires At",
|
||||
"statusFailed": "Failed to load OAuth status",
|
||||
"status": {
|
||||
"checking": "Checking authorization status...",
|
||||
"authorized": "Authorized",
|
||||
"authorizedCurrent": "Current account connected",
|
||||
"oauthing": "OAuthing...",
|
||||
"pendingDevice": "Waiting for device authorization to complete...",
|
||||
"expired": "Authorization expired. Re-authorize to continue.",
|
||||
"missing": "Not authorized yet.",
|
||||
"notConfigured": "This provider is not configured for OAuth."
|
||||
|
||||
@@ -288,20 +288,35 @@
|
||||
"enable": "启用",
|
||||
"enableHint": "只有启用的供应商的模型才会出现在可用模型列表中",
|
||||
"oauth": {
|
||||
"title": "OpenAI OAuth",
|
||||
"description": "使用你的 ChatGPT 账号为该提供商授权,以启用 Codex 兼容的 OpenAI 访问。",
|
||||
"createHint": "请先保存提供商,再到详情面板完成授权。",
|
||||
"openaiTitle": "OpenAI OAuth",
|
||||
"openaiDescription": "使用你的 ChatGPT 账号为该提供商授权,以启用 Codex 兼容的 OpenAI 访问。",
|
||||
"openaiCreateHint": "请先保存提供商,再到详情面板完成授权。",
|
||||
"githubTitle": "GitHub Copilot OAuth",
|
||||
"githubDescription": "为当前 Memoh 账号连接 GitHub Copilot。",
|
||||
"githubDeviceTitle": "GitHub Copilot Device Authorization",
|
||||
"githubDeviceDescription": "启动设备授权后,打开 GitHub 验证页面并输入下方显示的用户代码。",
|
||||
"githubCreateHint": "请先保存提供商,再到详情面板启动设备授权。",
|
||||
"githubDeviceHint": "打开下方验证地址,并输入这个用户代码,为当前 Memoh 账号完成授权。",
|
||||
"authorize": "授权",
|
||||
"deviceAuthorize": "启动设备授权",
|
||||
"authorizeFailed": "启动授权失败",
|
||||
"authorizeSuccess": "授权成功",
|
||||
"revoke": "撤销授权",
|
||||
"revokeFailed": "撤销授权失败",
|
||||
"revokeSuccess": "授权已撤销",
|
||||
"copyFailed": "复制设备代码失败",
|
||||
"connectedAccount": "已连接账号",
|
||||
"callback": "回调地址",
|
||||
"deviceVerificationUri": "验证地址",
|
||||
"deviceUserCode": "用户代码",
|
||||
"deviceExpiresAt": "过期时间",
|
||||
"statusFailed": "加载 OAuth 状态失败",
|
||||
"status": {
|
||||
"checking": "正在检查授权状态...",
|
||||
"authorized": "已授权",
|
||||
"authorizedCurrent": "当前账号已连接",
|
||||
"oauthing": "授权中...",
|
||||
"pendingDevice": "正在等待设备授权完成...",
|
||||
"expired": "授权已过期,请重新授权。",
|
||||
"missing": "尚未授权。",
|
||||
"notConfigured": "当前提供商未正确配置 OAuth。"
|
||||
|
||||
@@ -21,7 +21,7 @@
|
||||
</section>
|
||||
|
||||
<section
|
||||
v-if="form.values.client_type !== 'openai-codex'"
|
||||
v-if="!['openai-codex', 'github-copilot'].includes(form.values.client_type)"
|
||||
class="space-y-2"
|
||||
>
|
||||
<FormField
|
||||
@@ -33,7 +33,7 @@
|
||||
<FormControl>
|
||||
<Input
|
||||
type="password"
|
||||
:placeholder="(providerWithAuth?.config as Record<string, unknown> | undefined)?.api_key as string || $t('provider.apiKeyPlaceholder')"
|
||||
:placeholder="getStoredSecret(props.provider?.config as Record<string, unknown> | undefined) || $t('provider.apiKeyPlaceholder')"
|
||||
:aria-label="$t('provider.apiKey')"
|
||||
v-bind="componentField"
|
||||
/>
|
||||
@@ -42,7 +42,10 @@
|
||||
</FormField>
|
||||
</section>
|
||||
|
||||
<section class="space-y-2">
|
||||
<section
|
||||
v-if="form.values.client_type !== 'github-copilot'"
|
||||
class="space-y-2"
|
||||
>
|
||||
<FormField
|
||||
v-slot="{ componentField }"
|
||||
name="base_url"
|
||||
@@ -81,15 +84,15 @@
|
||||
</section>
|
||||
|
||||
<section
|
||||
v-if="form.values.client_type === 'openai-codex'"
|
||||
v-if="['openai-codex', 'github-copilot'].includes(form.values.client_type)"
|
||||
class="rounded-lg border p-4 space-y-3 text-xs"
|
||||
>
|
||||
<div class="space-y-1">
|
||||
<div class="font-medium">
|
||||
{{ $t('provider.oauth.title') }}
|
||||
{{ $t(form.values.client_type === 'github-copilot' ? 'provider.oauth.githubDeviceTitle' : 'provider.oauth.openaiTitle') }}
|
||||
</div>
|
||||
<div class="text-muted-foreground">
|
||||
{{ $t('provider.oauth.description') }}
|
||||
{{ $t(form.values.client_type === 'github-copilot' ? 'provider.oauth.githubDeviceDescription' : 'provider.oauth.openaiDescription') }}
|
||||
</div>
|
||||
<div
|
||||
class="text-xs"
|
||||
@@ -105,7 +108,10 @@
|
||||
{{ $t('provider.oauth.status.expired') }}
|
||||
</template>
|
||||
<template v-else-if="oauthStatus?.has_token">
|
||||
{{ $t('provider.oauth.status.authorized') }}
|
||||
{{ $t(form.values.client_type === 'github-copilot' ? 'provider.oauth.status.authorizedCurrent' : 'provider.oauth.status.authorized') }}
|
||||
</template>
|
||||
<template v-else-if="oauthStatus?.device?.pending">
|
||||
{{ $t('provider.oauth.status.pendingDevice') }}
|
||||
</template>
|
||||
<template v-else>
|
||||
{{ $t('provider.oauth.status.missing') }}
|
||||
@@ -118,16 +124,88 @@
|
||||
{{ $t('provider.oauth.callback') }}: {{ oauthStatus.callback_url }}
|
||||
</div>
|
||||
</div>
|
||||
<div
|
||||
v-if="form.values.client_type === 'github-copilot'
|
||||
&& oauthStatus?.device?.pending
|
||||
&& !oauthStatus?.has_token
|
||||
&& oauthStatus?.device?.user_code
|
||||
&& oauthStatus?.device?.verification_uri"
|
||||
class="rounded-md bg-muted/40 p-3 space-y-2"
|
||||
>
|
||||
<div class="text-muted-foreground">
|
||||
{{ $t('provider.oauth.githubDeviceHint') }}
|
||||
</div>
|
||||
<div class="space-y-1">
|
||||
<div class="font-medium">
|
||||
{{ $t('provider.oauth.deviceVerificationUri') }}
|
||||
</div>
|
||||
<code class="block break-all rounded bg-background px-2 py-1 select-all">{{ oauthStatus?.device?.verification_uri }}</code>
|
||||
</div>
|
||||
<div class="space-y-1">
|
||||
<div class="font-medium">
|
||||
{{ $t('provider.oauth.deviceUserCode') }}
|
||||
</div>
|
||||
<div class="flex items-center gap-2">
|
||||
<code class="block flex-1 rounded bg-background px-2 py-1 text-sm tracking-[0.3em] select-all">{{ oauthStatus?.device?.user_code }}</code>
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
@click="handleCopyDeviceCode"
|
||||
>
|
||||
<Copy />
|
||||
{{ $t('common.copy') }}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
<div
|
||||
v-if="oauthStatus?.device?.expires_at"
|
||||
class="text-muted-foreground"
|
||||
>
|
||||
{{ $t('provider.oauth.deviceExpiresAt') }}: {{ oauthStatus.device.expires_at }}
|
||||
</div>
|
||||
<div class="flex items-center gap-2 text-foreground">
|
||||
<Spinner class="size-4" />
|
||||
<span>{{ $t('provider.oauth.status.oauthing') }}</span>
|
||||
</div>
|
||||
</div>
|
||||
<div
|
||||
v-if="form.values.client_type === 'github-copilot' && oauthStatus?.has_token && !oauthExpired"
|
||||
class="rounded-md bg-muted/40 p-3 space-y-1"
|
||||
>
|
||||
<div class="font-medium">
|
||||
{{ $t('provider.oauth.connectedAccount') }}
|
||||
</div>
|
||||
<div class="text-sm font-medium">
|
||||
{{ oauthStatus?.account?.email || oauthStatus?.account?.label || oauthStatus?.account?.name || oauthStatus?.account?.login || $t('provider.oauth.status.authorizedCurrent') }}
|
||||
</div>
|
||||
<div
|
||||
v-if="[oauthStatus?.account?.login?.trim() ? `@${oauthStatus.account.login.trim()}` : '', oauthStatus?.account?.email?.trim() ?? ''].filter(Boolean).join(' · ')"
|
||||
class="text-xs text-muted-foreground"
|
||||
>
|
||||
{{ [oauthStatus?.account?.login?.trim() ? `@${oauthStatus.account.login.trim()}` : '', oauthStatus?.account?.email?.trim() ?? ''].filter(Boolean).join(' · ') }}
|
||||
</div>
|
||||
</div>
|
||||
<div class="flex gap-2">
|
||||
<LoadingButton
|
||||
v-if="props.provider?.id
|
||||
&& ['openai-codex', 'github-copilot'].includes(form.values.client_type)
|
||||
&& !(
|
||||
form.values.client_type === 'github-copilot'
|
||||
&& oauthStatus?.device?.pending
|
||||
&& !oauthStatus?.has_token
|
||||
&& oauthStatus?.device?.user_code
|
||||
&& oauthStatus?.device?.verification_uri
|
||||
)
|
||||
&& (!oauthStatus?.has_token || oauthExpired)"
|
||||
type="button"
|
||||
variant="outline"
|
||||
:disabled="!canAuthorizeOAuth"
|
||||
:disabled="!props.provider?.id || !['openai-codex', 'github-copilot'].includes(form.values.client_type) || oauthStatusLoading"
|
||||
:loading="authorizeLoading"
|
||||
@click="handleAuthorize"
|
||||
>
|
||||
<KeyRound />
|
||||
{{ $t('provider.oauth.authorize') }}
|
||||
{{ $t(form.values.client_type === 'github-copilot' ? 'provider.oauth.deviceAuthorize' : 'provider.oauth.authorize') }}
|
||||
</LoadingButton>
|
||||
<LoadingButton
|
||||
v-if="oauthStatus?.has_token"
|
||||
@@ -225,14 +303,16 @@ import {
|
||||
FormField,
|
||||
FormLabel,
|
||||
FormItem,
|
||||
Spinner,
|
||||
} from '@memohai/ui'
|
||||
import { KeyRound, RefreshCw, Trash2 } from 'lucide-vue-next'
|
||||
import { Copy, KeyRound, RefreshCw, Trash2 } from 'lucide-vue-next'
|
||||
import ConfirmPopover from '@/components/confirm-popover/index.vue'
|
||||
import StatusDot from '@/components/status-dot/index.vue'
|
||||
import LoadingButton from '@/components/loading-button/index.vue'
|
||||
import SearchableSelectPopover from '@/components/searchable-select-popover/index.vue'
|
||||
import { useClipboard } from '@/composables/useClipboard'
|
||||
import { CLIENT_TYPE_LIST, CLIENT_TYPE_META } from '@/constants/client-types'
|
||||
import { computed, ref, watch } from 'vue'
|
||||
import { computed, onBeforeUnmount, ref, watch } from 'vue'
|
||||
import { toTypedSchema } from '@vee-validate/zod'
|
||||
import z from 'zod'
|
||||
import { useForm } from 'vee-validate'
|
||||
@@ -242,15 +322,44 @@ import { useI18n } from 'vue-i18n'
|
||||
import { toast } from 'vue-sonner'
|
||||
|
||||
const { t } = useI18n()
|
||||
const { copyText } = useClipboard()
|
||||
|
||||
type ProviderWithAuth = Partial<ProvidersGetResponse>
|
||||
|
||||
type ProviderOAuthStatus = {
|
||||
configured: boolean
|
||||
mode?: string
|
||||
has_token: boolean
|
||||
expired: boolean
|
||||
callback_url?: string
|
||||
expires_at?: string
|
||||
account?: {
|
||||
label?: string
|
||||
login?: string
|
||||
name?: string
|
||||
email?: string
|
||||
avatar_url?: string
|
||||
profile_url?: string
|
||||
}
|
||||
device?: {
|
||||
pending: boolean
|
||||
user_code?: string
|
||||
verification_uri?: string
|
||||
expires_at?: string
|
||||
interval_seconds?: number
|
||||
}
|
||||
}
|
||||
|
||||
type ProviderOAuthAuthorizeResponse = {
|
||||
mode?: string
|
||||
auth_url?: string
|
||||
device?: ProviderOAuthStatus['device']
|
||||
}
|
||||
|
||||
function getStoredSecret(config: Record<string, unknown> | undefined) {
|
||||
if (!config) return ''
|
||||
const apiKey = config.api_key
|
||||
return typeof apiKey === 'string' ? apiKey : ''
|
||||
}
|
||||
|
||||
const props = defineProps<{
|
||||
@@ -271,10 +380,9 @@ const oauthStatus = ref<ProviderOAuthStatus | null>(null)
|
||||
const oauthStatusLoading = ref(false)
|
||||
const authorizeLoading = ref(false)
|
||||
const revokeLoading = ref(false)
|
||||
const pollTimer = ref<number | null>(null)
|
||||
const apiBase = import.meta.env.VITE_API_URL?.trim() || '/api'
|
||||
|
||||
const providerWithAuth = computed(() => props.provider as ProviderWithAuth | undefined)
|
||||
|
||||
async function runTest() {
|
||||
if (!props.provider?.id) return
|
||||
testLoading.value = true
|
||||
@@ -310,20 +418,27 @@ const clientTypeOptions = computed(() =>
|
||||
const providerSchema = toTypedSchema(z.object({
|
||||
enable: z.boolean(),
|
||||
name: z.string().min(1),
|
||||
base_url: z.string().min(1),
|
||||
base_url: z.string().optional(),
|
||||
api_key: z.string().optional(),
|
||||
client_type: z.string().min(1),
|
||||
metadata: z.object({
|
||||
additionalProp1: z.object({}),
|
||||
}),
|
||||
}).superRefine((value, ctx) => {
|
||||
if (value.client_type !== 'openai-codex' && !value.api_key?.trim() && !(providerWithAuth.value?.config as Record<string, unknown> | undefined)?.api_key) {
|
||||
const existingSecret = getStoredSecret(
|
||||
props.provider?.config as Record<string, unknown> | undefined,
|
||||
)
|
||||
if (!['openai-codex', 'github-copilot'].includes(value.client_type) && !value.api_key?.trim() && !existingSecret.trim()) {
|
||||
ctx.addIssue({
|
||||
code: z.ZodIssueCode.custom,
|
||||
path: ['api_key'],
|
||||
message: 'API key is required',
|
||||
})
|
||||
}
|
||||
if (value.client_type !== 'github-copilot' && !value.base_url?.trim()) {
|
||||
ctx.addIssue({
|
||||
code: z.ZodIssueCode.custom,
|
||||
path: ['base_url'],
|
||||
message: 'Base URL is required',
|
||||
})
|
||||
}
|
||||
}))
|
||||
|
||||
const form = useForm({
|
||||
@@ -344,17 +459,19 @@ watch(() => props.provider, (newVal) => {
|
||||
}, { immediate: true })
|
||||
|
||||
watch(() => form.values.client_type, (clientType) => {
|
||||
if (clientType !== 'openai-codex') {
|
||||
if (!['openai-codex', 'github-copilot'].includes(clientType)) {
|
||||
oauthStatus.value = null
|
||||
return
|
||||
}
|
||||
if (!form.values.base_url) {
|
||||
if (clientType === 'openai-codex' && !form.values.base_url) {
|
||||
form.setFieldValue('base_url', 'https://chatgpt.com/backend-api')
|
||||
}
|
||||
if (clientType === 'github-copilot') {
|
||||
form.setFieldValue('base_url', '')
|
||||
}
|
||||
})
|
||||
|
||||
watch(() => [props.provider?.id, form.values.client_type] as const, async ([id, clientType]) => {
|
||||
if (!id || clientType !== 'openai-codex') {
|
||||
if (!id || (clientType !== 'openai-codex' && clientType !== 'github-copilot')) {
|
||||
oauthStatus.value = null
|
||||
return
|
||||
}
|
||||
@@ -369,13 +486,11 @@ const hasChanges = computed(() => {
|
||||
name: form.values.name,
|
||||
base_url: form.values.base_url,
|
||||
client_type: form.values.client_type,
|
||||
metadata: form.values.metadata,
|
||||
}) !== JSON.stringify({
|
||||
enable: raw?.enable ?? true,
|
||||
name: raw?.name,
|
||||
base_url: (cfg?.base_url as string) ?? '',
|
||||
client_type: raw?.client_type || 'openai-completions',
|
||||
metadata: { additionalProp1: {} },
|
||||
})
|
||||
|
||||
const apiKeyChanged = Boolean(form.values.api_key && form.values.api_key.trim() !== '')
|
||||
@@ -383,33 +498,47 @@ const hasChanges = computed(() => {
|
||||
})
|
||||
|
||||
const editProvider = form.handleSubmit(async (value) => {
|
||||
const config: Record<string, unknown> = { base_url: value.base_url }
|
||||
const config: Record<string, unknown> = {}
|
||||
if (value.base_url && value.base_url.trim() !== '') {
|
||||
config.base_url = value.base_url
|
||||
}
|
||||
if (value.api_key && value.api_key.trim() !== '') {
|
||||
config.api_key = value.api_key
|
||||
if (value.client_type !== 'github-copilot') {
|
||||
config.api_key = value.api_key.trim()
|
||||
}
|
||||
}
|
||||
const metadata = {
|
||||
...((props.provider?.metadata as Record<string, unknown> | undefined) ?? {}),
|
||||
}
|
||||
if (value.client_type === 'github-copilot') {
|
||||
delete metadata.oauth_client_id
|
||||
}
|
||||
const payload: Record<string, unknown> = {
|
||||
enable: value.enable,
|
||||
name: value.name,
|
||||
config,
|
||||
client_type: value.client_type,
|
||||
metadata: value.metadata,
|
||||
}
|
||||
if (Object.keys(metadata).length > 0 || value.client_type === 'github-copilot') {
|
||||
payload.metadata = metadata
|
||||
}
|
||||
emit('submit', payload)
|
||||
})
|
||||
|
||||
const oauthExpired = computed(() => Boolean(oauthStatus.value?.has_token && oauthStatus.value?.expired))
|
||||
const canAuthorizeOAuth = computed(() =>
|
||||
Boolean(
|
||||
props.provider?.id
|
||||
&& form.values.client_type === 'openai-codex',
|
||||
) && !oauthStatusLoading.value,
|
||||
)
|
||||
|
||||
function authHeaders(): Record<string, string> {
|
||||
const token = localStorage.getItem('token')
|
||||
return token ? { Authorization: `Bearer ${token}` } : {}
|
||||
}
|
||||
|
||||
function clearPollTimer() {
|
||||
if (pollTimer.value !== null) {
|
||||
window.clearTimeout(pollTimer.value)
|
||||
pollTimer.value = null
|
||||
}
|
||||
}
|
||||
|
||||
async function fetchOAuthStatus() {
|
||||
if (!props.provider?.id) return
|
||||
oauthStatusLoading.value = true
|
||||
@@ -427,6 +556,44 @@ async function fetchOAuthStatus() {
|
||||
}
|
||||
}
|
||||
|
||||
async function pollOAuthAuthorization(notifyOnSuccess = false) {
|
||||
if (!props.provider?.id || form.values.client_type !== 'github-copilot') return
|
||||
try {
|
||||
const response = await fetch(`${apiBase}/providers/${props.provider.id}/oauth/poll`, {
|
||||
method: 'POST',
|
||||
headers: authHeaders(),
|
||||
})
|
||||
if (!response.ok) throw new Error(t('provider.oauth.authorizeFailed'))
|
||||
const nextStatus = await response.json() as ProviderOAuthStatus
|
||||
const becameAuthorized = !oauthStatus.value?.has_token && Boolean(nextStatus.has_token)
|
||||
oauthStatus.value = nextStatus
|
||||
if (notifyOnSuccess && becameAuthorized) {
|
||||
toast.success(t('provider.oauth.authorizeSuccess'))
|
||||
}
|
||||
} catch (error) {
|
||||
clearPollTimer()
|
||||
toast.error(error instanceof Error ? error.message : t('provider.oauth.authorizeFailed'))
|
||||
}
|
||||
}
|
||||
|
||||
watch(oauthStatus, (status) => {
|
||||
clearPollTimer()
|
||||
if (form.values.client_type !== 'github-copilot') {
|
||||
return
|
||||
}
|
||||
if (!status?.device?.pending || status.has_token) {
|
||||
return
|
||||
}
|
||||
const intervalSeconds = Math.max(status.device.interval_seconds ?? 5, 1)
|
||||
pollTimer.value = window.setTimeout(() => {
|
||||
void pollOAuthAuthorization(true)
|
||||
}, intervalSeconds * 1000)
|
||||
})
|
||||
|
||||
onBeforeUnmount(() => {
|
||||
clearPollTimer()
|
||||
})
|
||||
|
||||
async function handleAuthorize() {
|
||||
if (!props.provider?.id) return
|
||||
authorizeLoading.value = true
|
||||
@@ -435,7 +602,18 @@ async function handleAuthorize() {
|
||||
headers: authHeaders(),
|
||||
})
|
||||
if (!response.ok) throw new Error(t('provider.oauth.authorizeFailed'))
|
||||
const data = await response.json() as { auth_url?: string }
|
||||
const data = await response.json() as ProviderOAuthAuthorizeResponse
|
||||
if (data.mode === 'device') {
|
||||
oauthStatus.value = {
|
||||
configured: true,
|
||||
mode: 'device',
|
||||
has_token: false,
|
||||
expired: false,
|
||||
callback_url: '',
|
||||
device: data.device,
|
||||
}
|
||||
return
|
||||
}
|
||||
if (!data.auth_url) throw new Error(t('provider.oauth.authorizeFailed'))
|
||||
const popup = window.open(data.auth_url, 'provider-oauth', 'width=600,height=720')
|
||||
const listener = async (event: MessageEvent) => {
|
||||
@@ -453,8 +631,34 @@ async function handleAuthorize() {
|
||||
}
|
||||
}
|
||||
|
||||
async function handleCopyDeviceCode() {
|
||||
const userCode = oauthStatus.value?.device?.user_code?.trim()
|
||||
const verificationUri = oauthStatus.value?.device?.verification_uri?.trim()
|
||||
if (!userCode || !verificationUri) return
|
||||
|
||||
const popup = window.open('', 'provider-device-oauth', 'width=960,height=720')
|
||||
const copied = await copyText(userCode)
|
||||
|
||||
if (!copied) {
|
||||
popup?.close()
|
||||
toast.error(t('provider.oauth.copyFailed'))
|
||||
return
|
||||
}
|
||||
|
||||
toast.success(t('common.copied'))
|
||||
|
||||
if (popup) {
|
||||
popup.location.href = verificationUri
|
||||
popup.focus()
|
||||
return
|
||||
}
|
||||
|
||||
window.open(verificationUri, '_blank', 'width=960,height=720')
|
||||
}
|
||||
|
||||
async function handleRevoke() {
|
||||
if (!props.provider?.id) return
|
||||
clearPollTimer()
|
||||
revokeLoading.value = true
|
||||
try {
|
||||
const response = await fetch(`${apiBase}/providers/${props.provider.id}/oauth/token`, {
|
||||
|
||||
@@ -0,0 +1,185 @@
|
||||
name: GitHub Copilot
|
||||
client_type: github-copilot
|
||||
|
||||
models:
|
||||
- model_id: claude-opus-4.6-1m
|
||||
name: Claude Opus 4.6 (1M context)(Internal only)
|
||||
type: chat
|
||||
config:
|
||||
context_window: 1000000
|
||||
compatibilities: [vision, tool-call, reasoning]
|
||||
reasoning_efforts: [low, medium, high]
|
||||
|
||||
- model_id: claude-opus-4.6
|
||||
name: Claude Opus 4.6
|
||||
type: chat
|
||||
config:
|
||||
context_window: 144000
|
||||
compatibilities: [vision, tool-call, reasoning]
|
||||
reasoning_efforts: [low, medium, high]
|
||||
|
||||
- model_id: claude-sonnet-4.6
|
||||
name: Claude Sonnet 4.6
|
||||
type: chat
|
||||
config:
|
||||
context_window: 200000
|
||||
compatibilities: [vision, tool-call, reasoning]
|
||||
reasoning_efforts: [low, medium, high]
|
||||
|
||||
- model_id: goldeneye-free-auto
|
||||
name: Goldeneye
|
||||
type: chat
|
||||
config:
|
||||
context_window: 400000
|
||||
compatibilities: [vision, tool-call]
|
||||
|
||||
- model_id: gpt-5.2-codex
|
||||
name: GPT-5.2-Codex
|
||||
type: chat
|
||||
config:
|
||||
context_window: 400000
|
||||
compatibilities: [vision, tool-call, reasoning]
|
||||
reasoning_efforts: [low, medium, high, xhigh]
|
||||
|
||||
- model_id: gpt-5.3-codex
|
||||
name: GPT-5.3-Codex
|
||||
type: chat
|
||||
config:
|
||||
context_window: 400000
|
||||
compatibilities: [vision, tool-call, reasoning]
|
||||
reasoning_efforts: [low, medium, high, xhigh]
|
||||
|
||||
- model_id: gpt-5.4-mini
|
||||
name: GPT-5.4 mini
|
||||
type: chat
|
||||
config:
|
||||
context_window: 400000
|
||||
compatibilities: [vision, tool-call, reasoning]
|
||||
reasoning_efforts: [none, low, medium, high, xhigh]
|
||||
|
||||
- model_id: gpt-5.4
|
||||
name: GPT-5.4
|
||||
type: chat
|
||||
config:
|
||||
context_window: 400000
|
||||
compatibilities: [vision, tool-call, reasoning]
|
||||
reasoning_efforts: [low, medium, high, xhigh]
|
||||
|
||||
- model_id: gpt-5-mini
|
||||
name: GPT-5 mini
|
||||
type: chat
|
||||
config:
|
||||
context_window: 264000
|
||||
compatibilities: [vision, tool-call, reasoning]
|
||||
reasoning_efforts: [low, medium, high]
|
||||
|
||||
- model_id: gpt-4o-mini-2024-07-18
|
||||
name: GPT-4o mini
|
||||
type: chat
|
||||
config:
|
||||
context_window: 128000
|
||||
compatibilities: [tool-call]
|
||||
|
||||
- model_id: grok-code-fast-1
|
||||
name: Grok Code Fast 1
|
||||
type: chat
|
||||
config:
|
||||
context_window: 128000
|
||||
compatibilities: [tool-call]
|
||||
|
||||
- model_id: gpt-5.1
|
||||
name: GPT-5.1
|
||||
type: chat
|
||||
config:
|
||||
context_window: 264000
|
||||
compatibilities: [vision, tool-call, reasoning]
|
||||
reasoning_efforts: [none, low, medium, high]
|
||||
|
||||
- model_id: text-embedding-3-small
|
||||
name: Embedding V3 small
|
||||
type: embedding
|
||||
config:
|
||||
dimensions: 1536
|
||||
|
||||
- model_id: text-embedding-3-small-inference
|
||||
name: Embedding V3 small (Inference)
|
||||
type: embedding
|
||||
config:
|
||||
dimensions: 1536
|
||||
|
||||
- model_id: claude-sonnet-4
|
||||
name: Claude Sonnet 4
|
||||
type: chat
|
||||
config:
|
||||
context_window: 216000
|
||||
compatibilities: [vision, tool-call, reasoning]
|
||||
|
||||
- model_id: claude-sonnet-4.5
|
||||
name: Claude Sonnet 4.5
|
||||
type: chat
|
||||
config:
|
||||
context_window: 144000
|
||||
compatibilities: [vision, tool-call, reasoning]
|
||||
|
||||
- model_id: claude-opus-4.5
|
||||
name: Claude Opus 4.5
|
||||
type: chat
|
||||
config:
|
||||
context_window: 160000
|
||||
compatibilities: [vision, tool-call, reasoning]
|
||||
|
||||
- model_id: claude-haiku-4.5
|
||||
name: Claude Haiku 4.5
|
||||
type: chat
|
||||
config:
|
||||
context_window: 144000
|
||||
compatibilities: [vision, tool-call, reasoning]
|
||||
|
||||
- model_id: gpt-4.1-2025-04-14
|
||||
name: GPT-4.1
|
||||
type: chat
|
||||
config:
|
||||
context_window: 128000
|
||||
compatibilities: [vision, tool-call]
|
||||
|
||||
- model_id: gpt-5.2
|
||||
name: GPT-5.2
|
||||
type: chat
|
||||
config:
|
||||
context_window: 264000
|
||||
compatibilities: [vision, tool-call, reasoning]
|
||||
reasoning_efforts: [low, medium, high, xhigh]
|
||||
|
||||
- model_id: gpt-3.5-turbo-0613
|
||||
name: GPT 3.5 Turbo
|
||||
type: chat
|
||||
config:
|
||||
context_window: 16384
|
||||
compatibilities: [tool-call]
|
||||
|
||||
- model_id: gpt-4.1
|
||||
name: GPT-4.1
|
||||
type: chat
|
||||
config:
|
||||
context_window: 128000
|
||||
compatibilities: [vision, tool-call]
|
||||
|
||||
- model_id: gpt-3.5-turbo
|
||||
name: GPT 3.5 Turbo
|
||||
type: chat
|
||||
config:
|
||||
context_window: 16384
|
||||
compatibilities: [tool-call]
|
||||
|
||||
- model_id: gpt-4o-mini
|
||||
name: GPT-4o mini
|
||||
type: chat
|
||||
config:
|
||||
context_window: 128000
|
||||
compatibilities: [tool-call]
|
||||
|
||||
- model_id: text-embedding-ada-002
|
||||
name: Embedding V2 Ada
|
||||
type: embedding
|
||||
config:
|
||||
dimensions: 1536
|
||||
@@ -68,7 +68,7 @@ CREATE TABLE IF NOT EXISTS providers (
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
CONSTRAINT providers_name_unique UNIQUE (name),
|
||||
CONSTRAINT providers_client_type_check CHECK (client_type IN ('openai-responses', 'openai-completions', 'anthropic-messages', 'google-generative-ai', 'openai-codex', 'edge-speech'))
|
||||
CONSTRAINT providers_client_type_check CHECK (client_type IN ('openai-responses', 'openai-completions', 'anthropic-messages', 'google-generative-ai', 'openai-codex', 'github-copilot', 'edge-speech'))
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS search_providers (
|
||||
@@ -644,3 +644,23 @@ CREATE TABLE IF NOT EXISTS provider_oauth_tokens (
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_provider_oauth_tokens_state ON provider_oauth_tokens(state) WHERE state != '';
|
||||
|
||||
-- user_provider_oauth_tokens: per-user OAuth2 tokens for providers with user-scoped auth (e.g. GitHub Copilot)
|
||||
CREATE TABLE IF NOT EXISTS user_provider_oauth_tokens (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
provider_id UUID NOT NULL REFERENCES providers(id) ON DELETE CASCADE,
|
||||
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
access_token TEXT NOT NULL DEFAULT '',
|
||||
refresh_token TEXT NOT NULL DEFAULT '',
|
||||
expires_at TIMESTAMPTZ,
|
||||
scope TEXT NOT NULL DEFAULT '',
|
||||
token_type TEXT NOT NULL DEFAULT '',
|
||||
state TEXT NOT NULL DEFAULT '',
|
||||
pkce_code_verifier TEXT NOT NULL DEFAULT '',
|
||||
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
CONSTRAINT user_provider_oauth_tokens_provider_user_unique UNIQUE (provider_id, user_id)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_user_provider_oauth_tokens_state ON user_provider_oauth_tokens(state) WHERE state != '';
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
-- 0062_github_copilot_user_oauth (rollback)
|
||||
-- Remove user-scoped provider OAuth tokens and github-copilot client type.
|
||||
|
||||
DROP INDEX IF EXISTS idx_user_provider_oauth_tokens_state;
|
||||
DROP TABLE IF EXISTS user_provider_oauth_tokens;
|
||||
|
||||
DELETE FROM providers WHERE client_type = 'github-copilot';
|
||||
|
||||
ALTER TABLE IF EXISTS providers DROP CONSTRAINT IF EXISTS providers_client_type_check;
|
||||
|
||||
ALTER TABLE IF EXISTS providers
|
||||
ADD CONSTRAINT providers_client_type_check CHECK (
|
||||
client_type IN (
|
||||
'openai-responses',
|
||||
'openai-completions',
|
||||
'anthropic-messages',
|
||||
'google-generative-ai',
|
||||
'openai-codex',
|
||||
'edge-speech'
|
||||
)
|
||||
);
|
||||
@@ -0,0 +1,38 @@
|
||||
-- 0062_github_copilot_user_oauth
|
||||
-- Add github-copilot as a provider client type and store OAuth tokens per user.
|
||||
|
||||
ALTER TABLE IF EXISTS providers DROP CONSTRAINT IF EXISTS providers_client_type_check;
|
||||
|
||||
ALTER TABLE IF EXISTS providers
|
||||
ADD CONSTRAINT providers_client_type_check CHECK (
|
||||
client_type IN (
|
||||
'openai-responses',
|
||||
'openai-completions',
|
||||
'anthropic-messages',
|
||||
'google-generative-ai',
|
||||
'openai-codex',
|
||||
'github-copilot',
|
||||
'edge-speech'
|
||||
)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS user_provider_oauth_tokens (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
provider_id UUID NOT NULL REFERENCES providers(id) ON DELETE CASCADE,
|
||||
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
access_token TEXT NOT NULL DEFAULT '',
|
||||
refresh_token TEXT NOT NULL DEFAULT '',
|
||||
expires_at TIMESTAMPTZ,
|
||||
scope TEXT NOT NULL DEFAULT '',
|
||||
token_type TEXT NOT NULL DEFAULT '',
|
||||
state TEXT NOT NULL DEFAULT '',
|
||||
pkce_code_verifier TEXT NOT NULL DEFAULT '',
|
||||
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
CONSTRAINT user_provider_oauth_tokens_provider_user_unique UNIQUE (provider_id, user_id)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_user_provider_oauth_tokens_state
|
||||
ON user_provider_oauth_tokens(state)
|
||||
WHERE state != '';
|
||||
@@ -0,0 +1,66 @@
|
||||
-- name: UpsertUserProviderOAuthToken :one
|
||||
INSERT INTO user_provider_oauth_tokens (
|
||||
provider_id,
|
||||
user_id,
|
||||
access_token,
|
||||
refresh_token,
|
||||
expires_at,
|
||||
scope,
|
||||
token_type,
|
||||
state,
|
||||
pkce_code_verifier,
|
||||
metadata
|
||||
)
|
||||
VALUES (
|
||||
sqlc.arg(provider_id),
|
||||
sqlc.arg(user_id),
|
||||
sqlc.arg(access_token),
|
||||
sqlc.arg(refresh_token),
|
||||
sqlc.arg(expires_at),
|
||||
sqlc.arg(scope),
|
||||
sqlc.arg(token_type),
|
||||
sqlc.arg(state),
|
||||
sqlc.arg(pkce_code_verifier),
|
||||
sqlc.arg(metadata)
|
||||
)
|
||||
ON CONFLICT (provider_id, user_id) DO UPDATE SET
|
||||
access_token = EXCLUDED.access_token,
|
||||
refresh_token = EXCLUDED.refresh_token,
|
||||
expires_at = EXCLUDED.expires_at,
|
||||
scope = EXCLUDED.scope,
|
||||
token_type = EXCLUDED.token_type,
|
||||
state = EXCLUDED.state,
|
||||
pkce_code_verifier = EXCLUDED.pkce_code_verifier,
|
||||
metadata = EXCLUDED.metadata,
|
||||
updated_at = now()
|
||||
RETURNING *;
|
||||
|
||||
-- name: GetUserProviderOAuthToken :one
|
||||
SELECT * FROM user_provider_oauth_tokens
|
||||
WHERE provider_id = sqlc.arg(provider_id)
|
||||
AND user_id = sqlc.arg(user_id);
|
||||
|
||||
-- name: GetUserProviderOAuthTokenByState :one
|
||||
SELECT * FROM user_provider_oauth_tokens
|
||||
WHERE state = sqlc.arg(state)
|
||||
AND state != '';
|
||||
|
||||
-- name: UpdateUserProviderOAuthState :exec
|
||||
INSERT INTO user_provider_oauth_tokens (provider_id, user_id, state, pkce_code_verifier, metadata)
|
||||
VALUES (
|
||||
sqlc.arg(provider_id),
|
||||
sqlc.arg(user_id),
|
||||
sqlc.arg(state),
|
||||
sqlc.arg(pkce_code_verifier),
|
||||
sqlc.arg(metadata)
|
||||
)
|
||||
ON CONFLICT (provider_id, user_id) DO UPDATE SET
|
||||
state = EXCLUDED.state,
|
||||
pkce_code_verifier = EXCLUDED.pkce_code_verifier,
|
||||
metadata = EXCLUDED.metadata,
|
||||
updated_at = now();
|
||||
|
||||
-- name: DeleteUserProviderOAuthToken :exec
|
||||
DELETE FROM user_provider_oauth_tokens
|
||||
WHERE provider_id = sqlc.arg(provider_id)
|
||||
AND user_id = sqlc.arg(user_id);
|
||||
@@ -27,8 +27,8 @@ require (
|
||||
github.com/mailgun/mailgun-go/v5 v5.14.0
|
||||
github.com/memohai/acgo v0.0.0-20260221232113-babac0d6acd7
|
||||
github.com/memohai/dingtalk-stream-sdk-go v0.0.0-20260405113102-87e23096b978
|
||||
github.com/memohai/twilight-ai v0.3.4-0.20260402160505-00db38ee4442
|
||||
github.com/modelcontextprotocol/go-sdk v1.4.1
|
||||
github.com/memohai/twilight-ai v0.3.4-0.20260412161211-dbedfe32c86f
|
||||
github.com/modelcontextprotocol/go-sdk v1.5.0
|
||||
github.com/opencontainers/image-spec v1.1.1
|
||||
github.com/opencontainers/runtime-spec v1.3.0
|
||||
github.com/qdrant/go-client v1.17.1
|
||||
@@ -40,11 +40,12 @@ require (
|
||||
github.com/yuin/goldmark v1.7.13
|
||||
go.uber.org/fx v1.24.0
|
||||
golang.org/x/crypto v0.48.0
|
||||
golang.org/x/oauth2 v0.35.0
|
||||
golang.org/x/oauth2 v0.36.0
|
||||
golang.org/x/time v0.14.0
|
||||
google.golang.org/grpc v1.78.0
|
||||
google.golang.org/protobuf v1.36.11
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
tags.cncf.io/container-device-interface v1.1.0
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -122,7 +123,7 @@ require (
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
github.com/sasha-s/go-deadlock v0.3.6 // indirect
|
||||
github.com/segmentio/asm v1.1.3 // indirect
|
||||
github.com/segmentio/asm v1.2.1 // indirect
|
||||
github.com/segmentio/encoding v0.5.4 // indirect
|
||||
github.com/sirupsen/logrus v1.9.4 // indirect
|
||||
github.com/spf13/pflag v1.0.9 // indirect
|
||||
@@ -143,11 +144,10 @@ require (
|
||||
golang.org/x/mod v0.33.0 // indirect
|
||||
golang.org/x/net v0.50.0 // indirect
|
||||
golang.org/x/sync v0.19.0 // indirect
|
||||
golang.org/x/sys v0.41.0 // indirect
|
||||
golang.org/x/sys v0.42.0 // indirect
|
||||
golang.org/x/text v0.34.0 // indirect
|
||||
golang.org/x/tools v0.42.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260209200024-4cfbd4190f57 // indirect
|
||||
sigs.k8s.io/yaml v1.6.0 // indirect
|
||||
tags.cncf.io/container-device-interface v1.1.0 // indirect
|
||||
tags.cncf.io/container-device-interface/specs-go v1.1.0 // indirect
|
||||
)
|
||||
|
||||
@@ -24,6 +24,8 @@ github.com/andybalholm/cascadia v1.3.3 h1:AG2YHrzJIm4BZ19iwJ/DAua6Btl3IwJX+VI4kk
|
||||
github.com/andybalholm/cascadia v1.3.3/go.mod h1:xNd9bqTn98Ln4DwST8/nG+H0yuB8Hmgu1YHNnWw0GeA=
|
||||
github.com/araddon/dateparse v0.0.0-20210429162001-6b43995a97de h1:FxWPpzIjnTlhPwqqXc4/vE0f7GvRjuAsbW+HOIe8KnA=
|
||||
github.com/araddon/dateparse v0.0.0-20210429162001-6b43995a97de/go.mod h1:DCaWoUhZrYW9p1lxo/cm8EmUOOzAPSEZNGF2DK1dJgw=
|
||||
github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM=
|
||||
github.com/blang/semver/v4 v4.0.0/go.mod h1:IbckMUScFkM3pff0VJDNKRiT6TG/YpiHIM2yvyW5YoQ=
|
||||
github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno=
|
||||
github.com/bwmarrin/discordgo v0.29.0/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY=
|
||||
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
|
||||
@@ -189,6 +191,10 @@ github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/ad
|
||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 h1:JeSE6pjso5THxAzdVpqr6/geYxZytqFMBCOtn/ujyeo=
|
||||
github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674/go.mod h1:r4w70xmWCQKmi1ONH4KIaBptdivuRPyosB9RmPlGEwA=
|
||||
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
|
||||
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
||||
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
|
||||
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
@@ -232,8 +238,8 @@ github.com/memohai/acgo v0.0.0-20260221232113-babac0d6acd7 h1:beehwOQperqGWj4m4E
|
||||
github.com/memohai/acgo v0.0.0-20260221232113-babac0d6acd7/go.mod h1:OvmxM7JmnXBmwJWWVqtreL3HSHSKuzPbtbhlg5MvBg0=
|
||||
github.com/memohai/dingtalk-stream-sdk-go v0.0.0-20260405113102-87e23096b978 h1:6gD8DvZkimGmU0e3PjlusJPyw55SyeoE12CZQoYUa8g=
|
||||
github.com/memohai/dingtalk-stream-sdk-go v0.0.0-20260405113102-87e23096b978/go.mod h1:2LMgK5QYFlTSvrGY+sI/j+jK2WK+YGHv4IMuiW+iPSc=
|
||||
github.com/memohai/twilight-ai v0.3.4-0.20260402160505-00db38ee4442 h1:mTy+OSkMCOvF1S6D5asKRdKx0A+icQvnu6A/f7aZolg=
|
||||
github.com/memohai/twilight-ai v0.3.4-0.20260402160505-00db38ee4442/go.mod h1:GZTT9GUT3uSs6zram/FcF24GLTZMFSpiybbYmjr+gH8=
|
||||
github.com/memohai/twilight-ai v0.3.4-0.20260412161211-dbedfe32c86f h1:9NAj+FyDJPi8RzD1PUwb6OxZx/OrBD2FJo4tVAlhpbs=
|
||||
github.com/memohai/twilight-ai v0.3.4-0.20260412161211-dbedfe32c86f/go.mod h1:1uNfZWc8du+HWJ3r3FLyeGAXGiUAniuSWV89A8gbcz0=
|
||||
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
|
||||
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
|
||||
github.com/moby/locker v1.0.1 h1:fOXqR41zeveg4fFODix+1Ch4mj/gT0NE1XJbp/epuBg=
|
||||
@@ -252,8 +258,8 @@ github.com/moby/sys/userns v0.1.0 h1:tVLXkFOxVu9A64/yh59slHVv9ahO9UIev4JZusOLG/g
|
||||
github.com/moby/sys/userns v0.1.0/go.mod h1:IHUYgu/kao6N8YZlp9Cf444ySSvCmDlmzUcYfDHOl28=
|
||||
github.com/moby/term v0.5.2 h1:6qk3FJAFDs6i/q3W/pQ97SX192qKfZgGjCQqfCJkgzQ=
|
||||
github.com/moby/term v0.5.2/go.mod h1:d3djjFCrjnB+fl8NJux+EJzu0msscUP+f8it8hPkFLc=
|
||||
github.com/modelcontextprotocol/go-sdk v1.4.1 h1:M4x9GyIPj+HoIlHNGpK2hq5o3BFhC+78PkEaldQRphc=
|
||||
github.com/modelcontextprotocol/go-sdk v1.4.1/go.mod h1:Bo/mS87hPQqHSRkMv4dQq1XCu6zv4INdXnFZabkNU6s=
|
||||
github.com/modelcontextprotocol/go-sdk v1.5.0 h1:CHU0FIX9kpueNkxuYtfYQn1Z0slhFzBZuq+x6IiblIU=
|
||||
github.com/modelcontextprotocol/go-sdk v1.5.0/go.mod h1:gggDIhoemhWs3BGkGwd1umzEXCEMMvAnhTrnbXJKKKA=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
@@ -303,8 +309,8 @@ github.com/sasha-s/go-deadlock v0.3.6/go.mod h1:CUqNyyvMxTyjFqDT7MRg9mb4Dv/btmGT
|
||||
github.com/scylladb/termtables v0.0.0-20191203121021-c4c0b6d42ff4/go.mod h1:C1a7PQSMz9NShzorzCiG2fk9+xuCgLkPeCvMHYR2OWg=
|
||||
github.com/sebdah/goldie/v2 v2.8.0 h1:dZb9wR8q5++oplmEiJT+U/5KyotVD+HNGCAc5gNr8rc=
|
||||
github.com/sebdah/goldie/v2 v2.8.0/go.mod h1:oZ9fp0+se1eapSRjfYbsV/0Hqhbuu3bJVvKI/NNtssI=
|
||||
github.com/segmentio/asm v1.1.3 h1:WM03sfUOENvvKexOLp+pCqgb/WDjsi7EK8gIsICtzhc=
|
||||
github.com/segmentio/asm v1.1.3/go.mod h1:Ld3L4ZXGNcSLRg4JBsZ3//1+f/TjYl0Mzen/DQy1EJg=
|
||||
github.com/segmentio/asm v1.2.1 h1:DTNbBqs57ioxAD4PrArqftgypG4/qNpXoJx8TVXxPR0=
|
||||
github.com/segmentio/asm v1.2.1/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs=
|
||||
github.com/segmentio/encoding v0.5.4 h1:OW1VRern8Nw6ITAtwSZ7Idrl3MXCFwXHPgqESYfvNt0=
|
||||
github.com/segmentio/encoding v0.5.4/go.mod h1:HS1ZKa3kSN32ZHVZ7ZLPLXWvOVIiZtyJnO1gPH1sKt0=
|
||||
github.com/sergi/go-diff v1.4.0 h1:n/SP9D5ad1fORl+llWyN+D6qoUETXNZARKjyY2/KVCw=
|
||||
@@ -337,6 +343,12 @@ github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zd
|
||||
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||
github.com/wneessen/go-mail v0.7.2 h1:xxPnhZ6IZLSgxShebmZ6DPKh1b6OJcoHfzy7UjOkzS8=
|
||||
github.com/wneessen/go-mail v0.7.2/go.mod h1:+TkW6QP3EVkgTEqHtVmnAE/1MRhmzb8Y9/W3pweuS+k=
|
||||
github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb h1:zGWFAtiMcyryUHoUjUJX0/lt1H2+i2Ka2n+D3DImSNo=
|
||||
github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU=
|
||||
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0=
|
||||
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ=
|
||||
github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74=
|
||||
github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
@@ -420,8 +432,8 @@ golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
|
||||
golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60=
|
||||
golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ=
|
||||
golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
|
||||
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
|
||||
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@@ -450,8 +462,8 @@ golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
||||
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
|
||||
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
|
||||
@@ -28,6 +28,7 @@ import (
|
||||
messagepkg "github.com/memohai/memoh/internal/message"
|
||||
messageevent "github.com/memohai/memoh/internal/message/event"
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
"github.com/memohai/memoh/internal/oauthctx"
|
||||
pipelinepkg "github.com/memohai/memoh/internal/pipeline"
|
||||
"github.com/memohai/memoh/internal/providers"
|
||||
"github.com/memohai/memoh/internal/settings"
|
||||
@@ -499,7 +500,8 @@ func (r *Resolver) buildBaseRunConfig(ctx context.Context, p baseRunConfigParams
|
||||
}
|
||||
|
||||
authResolver := providers.NewService(nil, r.queries, "")
|
||||
creds, err := authResolver.ResolveModelCredentials(ctx, provider)
|
||||
authCtx := oauthctx.WithUserID(ctx, p.UserID)
|
||||
creds, err := authResolver.ResolveModelCredentials(authCtx, provider)
|
||||
if err != nil {
|
||||
return agentpkg.RunConfig{}, models.GetResponse{}, sqlc.Provider{}, fmt.Errorf("resolve provider credentials: %w", err)
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"github.com/memohai/memoh/internal/compaction"
|
||||
"github.com/memohai/memoh/internal/conversation"
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
"github.com/memohai/memoh/internal/oauthctx"
|
||||
"github.com/memohai/memoh/internal/providers"
|
||||
"github.com/memohai/memoh/internal/settings"
|
||||
)
|
||||
@@ -115,7 +116,8 @@ func (r *Resolver) buildCompactionConfig(ctx context.Context, req conversation.C
|
||||
return compaction.TriggerConfig{}, err
|
||||
}
|
||||
authResolver := providers.NewService(nil, r.queries, "")
|
||||
creds, err := authResolver.ResolveModelCredentials(ctx, compactProvider)
|
||||
authCtx := oauthctx.WithUserID(ctx, req.UserID)
|
||||
creds, err := authResolver.ResolveModelCredentials(authCtx, compactProvider)
|
||||
if err != nil {
|
||||
return compaction.TriggerConfig{}, err
|
||||
}
|
||||
@@ -137,7 +139,6 @@ func (r *Resolver) buildCompactionConfig(ctx context.Context, req conversation.C
|
||||
if compactModel.Config.ContextWindow != nil && *compactModel.Config.ContextWindow > 0 {
|
||||
cfg.MaxCompactTokens = *compactModel.Config.ContextWindow * 90 / 100
|
||||
}
|
||||
|
||||
// For sync compaction: keep only the last few messages (~2000 tokens ≈ 3 messages).
|
||||
// The summary provides reference context; if the LLM needs details,
|
||||
// it will use tools (memory_read, search) to look them up.
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/memohai/memoh/internal/db/sqlc"
|
||||
messageevent "github.com/memohai/memoh/internal/message/event"
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
"github.com/memohai/memoh/internal/oauthctx"
|
||||
"github.com/memohai/memoh/internal/providers"
|
||||
"github.com/memohai/memoh/internal/session"
|
||||
)
|
||||
@@ -82,7 +83,7 @@ func (r *Resolver) maybeGenerateSessionTitle(ctx context.Context, req conversati
|
||||
return
|
||||
}
|
||||
|
||||
title := r.generateTitle(ctx, titleModel, provider, userQuery)
|
||||
title := r.generateTitle(ctx, req.UserID, titleModel, provider, userQuery)
|
||||
if title == "" {
|
||||
return
|
||||
}
|
||||
@@ -95,7 +96,7 @@ func (r *Resolver) maybeGenerateSessionTitle(ctx context.Context, req conversati
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Resolver) generateTitle(ctx context.Context, model models.GetResponse, provider sqlc.Provider, userQuery string) string {
|
||||
func (r *Resolver) generateTitle(ctx context.Context, userID string, model models.GetResponse, provider sqlc.Provider, userQuery string) string {
|
||||
userSnippet := truncate(strings.TrimSpace(userQuery), titlePromptMaxInputChars)
|
||||
if userSnippet == "" {
|
||||
return ""
|
||||
@@ -106,7 +107,8 @@ func (r *Resolver) generateTitle(ctx context.Context, model models.GetResponse,
|
||||
"User: " + userSnippet
|
||||
|
||||
authResolver := providers.NewService(nil, r.queries, "")
|
||||
creds, err := authResolver.ResolveModelCredentials(ctx, provider)
|
||||
authCtx := oauthctx.WithUserID(ctx, userID)
|
||||
creds, err := authResolver.ResolveModelCredentials(authCtx, provider)
|
||||
if err != nil {
|
||||
r.logger.Warn("title gen: failed to resolve provider credentials", slog.Any("error", err))
|
||||
return ""
|
||||
|
||||
@@ -0,0 +1,176 @@
|
||||
package copilot
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
GitHubOAuthClientID = "Iv1.b507a08c87ecfe98"
|
||||
GitHubOAuthScope = "read:user user:email"
|
||||
DefaultAPIBaseURL = "https://api.githubcopilot.com"
|
||||
|
||||
copilotTokenURL = "https://api.github.com/copilot_internal/v2/token" //nolint:gosec // Fixed GitHub API endpoint, not a credential.
|
||||
copilotEditorVersion = "vscode/1.110.1"
|
||||
copilotPluginVersion = "copilot-chat/0.38.2"
|
||||
copilotUserAgent = "GitHubCopilotChat/0.38.2"
|
||||
copilotAPIVersion = "2025-10-01"
|
||||
copilotIntegrationID = "vscode-chat"
|
||||
copilotTokenRefreshSkew = time.Minute
|
||||
defaultHTTPClientTimeout = 15 * time.Second
|
||||
)
|
||||
|
||||
type cachedToken struct {
|
||||
Token string
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
var tokenCache = struct {
|
||||
mu sync.Mutex
|
||||
entries map[string]cachedToken
|
||||
}{
|
||||
entries: map[string]cachedToken{},
|
||||
}
|
||||
|
||||
type tokenResponse struct {
|
||||
Token string `json:"token"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
}
|
||||
|
||||
func ResolveToken(ctx context.Context, githubToken string) (string, error) {
|
||||
githubToken = strings.TrimSpace(githubToken)
|
||||
if githubToken == "" {
|
||||
return "", errors.New("github token is required")
|
||||
}
|
||||
|
||||
if token, ok := loadCachedToken(githubToken); ok {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
token, expiresAt, err := FetchCopilotToken(ctx, githubToken)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
storeCachedToken(githubToken, token, expiresAt)
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func FetchCopilotToken(ctx context.Context, githubToken string) (string, time.Time, error) {
|
||||
githubToken = strings.TrimSpace(githubToken)
|
||||
if githubToken == "" {
|
||||
return "", time.Time{}, errors.New("github token is required")
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, copilotTokenURL, nil)
|
||||
if err != nil {
|
||||
return "", time.Time{}, fmt.Errorf("create copilot token request: %w", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "token "+githubToken)
|
||||
req.Header.Set("Editor-Version", copilotEditorVersion)
|
||||
req.Header.Set("Editor-Plugin-Version", copilotPluginVersion)
|
||||
req.Header.Set("User-Agent", copilotUserAgent)
|
||||
req.Header.Set("X-GitHub-Api-Version", copilotAPIVersion)
|
||||
|
||||
resp, err := defaultHTTPClient(nil).Do(req) //nolint:gosec // Request targets a fixed GitHub API endpoint.
|
||||
if err != nil {
|
||||
return "", time.Time{}, fmt.Errorf("fetch copilot token: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", time.Time{}, fmt.Errorf("read copilot token response: %w", err)
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return "", time.Time{}, fmt.Errorf("copilot token request failed: %s", strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
var parsed tokenResponse
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return "", time.Time{}, fmt.Errorf("decode copilot token response: %w", err)
|
||||
}
|
||||
if strings.TrimSpace(parsed.Token) == "" {
|
||||
return "", time.Time{}, errors.New("copilot token response did not include a token")
|
||||
}
|
||||
|
||||
var expiresAt time.Time
|
||||
if parsed.ExpiresAt > 0 {
|
||||
expiresAt = time.Unix(parsed.ExpiresAt, 0).UTC()
|
||||
}
|
||||
return parsed.Token, expiresAt, nil
|
||||
}
|
||||
|
||||
func NewHTTPClient(base *http.Client) *http.Client {
|
||||
client := defaultHTTPClient(base)
|
||||
client.Transport = &headerRoundTripper{
|
||||
base: client.Transport,
|
||||
headers: map[string]string{
|
||||
"Copilot-Integration-Id": copilotIntegrationID,
|
||||
"Editor-Version": copilotEditorVersion,
|
||||
"Editor-Plugin-Version": copilotPluginVersion,
|
||||
"User-Agent": copilotUserAgent,
|
||||
"X-GitHub-Api-Version": copilotAPIVersion,
|
||||
},
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
||||
type headerRoundTripper struct {
|
||||
base http.RoundTripper
|
||||
headers map[string]string
|
||||
}
|
||||
|
||||
func (rt *headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
clone := req.Clone(req.Context())
|
||||
clone.Header = req.Header.Clone()
|
||||
for key, value := range rt.headers {
|
||||
clone.Header.Set(key, value)
|
||||
}
|
||||
if rt.base == nil {
|
||||
rt.base = http.DefaultTransport
|
||||
}
|
||||
return rt.base.RoundTrip(clone)
|
||||
}
|
||||
|
||||
func loadCachedToken(githubToken string) (string, bool) {
|
||||
tokenCache.mu.Lock()
|
||||
defer tokenCache.mu.Unlock()
|
||||
|
||||
entry, ok := tokenCache.entries[githubToken]
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
if !entry.ExpiresAt.IsZero() && !time.Now().Add(copilotTokenRefreshSkew).Before(entry.ExpiresAt) {
|
||||
delete(tokenCache.entries, githubToken)
|
||||
return "", false
|
||||
}
|
||||
return entry.Token, true
|
||||
}
|
||||
|
||||
func storeCachedToken(githubToken, token string, expiresAt time.Time) {
|
||||
tokenCache.mu.Lock()
|
||||
defer tokenCache.mu.Unlock()
|
||||
tokenCache.entries[githubToken] = cachedToken{
|
||||
Token: token,
|
||||
ExpiresAt: expiresAt,
|
||||
}
|
||||
}
|
||||
|
||||
func defaultHTTPClient(base *http.Client) *http.Client {
|
||||
if base != nil {
|
||||
clone := *base
|
||||
if clone.Timeout == 0 {
|
||||
clone.Timeout = defaultHTTPClientTimeout
|
||||
}
|
||||
return &clone
|
||||
}
|
||||
return &http.Client{Timeout: defaultHTTPClientTimeout}
|
||||
}
|
||||
@@ -0,0 +1,80 @@
|
||||
package copilot
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type roundTripperFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
|
||||
func TestNewHTTPClientAddsCopilotHeaders(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
base := &http.Client{
|
||||
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
if got := req.Header.Get("Copilot-Integration-Id"); got != copilotIntegrationID {
|
||||
t.Fatalf("expected integration id %q, got %q", copilotIntegrationID, got)
|
||||
}
|
||||
if got := req.Header.Get("Editor-Version"); got != copilotEditorVersion {
|
||||
t.Fatalf("expected editor version %q, got %q", copilotEditorVersion, got)
|
||||
}
|
||||
if got := req.Header.Get("Editor-Plugin-Version"); got != copilotPluginVersion {
|
||||
t.Fatalf("expected plugin version %q, got %q", copilotPluginVersion, got)
|
||||
}
|
||||
if got := req.Header.Get("User-Agent"); got != copilotUserAgent {
|
||||
t.Fatalf("expected user agent %q, got %q", copilotUserAgent, got)
|
||||
}
|
||||
if got := req.Header.Get("X-GitHub-Api-Version"); got != copilotAPIVersion {
|
||||
t.Fatalf("expected api version %q, got %q", copilotAPIVersion, got)
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(`ok`)),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
}),
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://api.githubcopilot.com/chat/completions", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("create request: %v", err)
|
||||
}
|
||||
|
||||
resp, err := NewHTTPClient(base).Do(req) //nolint:gosec // Test request targets a fixed Copilot API endpoint.
|
||||
if err != nil {
|
||||
t.Fatalf("execute request: %v", err)
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
|
||||
func TestNewHTTPClientWithNilBaseDoesNotPanic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
if got := req.Header.Get("Copilot-Integration-Id"); got != copilotIntegrationID {
|
||||
t.Fatalf("expected integration id %q, got %q", copilotIntegrationID, got)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(`ok`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("create request: %v", err)
|
||||
}
|
||||
|
||||
resp, err := NewHTTPClient(nil).Do(req) //nolint:gosec // Test request targets an httptest server URL.
|
||||
if err != nil {
|
||||
t.Fatalf("execute request with nil base client: %v", err)
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
package copilot
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
githubcopilot "github.com/memohai/twilight-ai/provider/github/copilot"
|
||||
sdk "github.com/memohai/twilight-ai/sdk"
|
||||
)
|
||||
|
||||
func NewProvider(copilotToken string, baseClient *http.Client) sdk.Provider {
|
||||
options := []githubcopilot.Option{
|
||||
githubcopilot.WithGitHubToken(strings.TrimSpace(copilotToken)),
|
||||
githubcopilot.WithBaseURL(DefaultAPIBaseURL),
|
||||
githubcopilot.WithHTTPClient(NewHTTPClient(baseClient)),
|
||||
}
|
||||
return githubcopilot.New(options...)
|
||||
}
|
||||
|
||||
func NewModel(copilotToken, modelID string, baseClient *http.Client) *sdk.Model {
|
||||
options := []githubcopilot.Option{
|
||||
githubcopilot.WithGitHubToken(strings.TrimSpace(copilotToken)),
|
||||
githubcopilot.WithBaseURL(DefaultAPIBaseURL),
|
||||
githubcopilot.WithHTTPClient(NewHTTPClient(baseClient)),
|
||||
}
|
||||
if strings.TrimSpace(modelID) == "" {
|
||||
modelID = githubcopilot.AutoModel
|
||||
}
|
||||
return githubcopilot.New(options...).ChatModel(modelID)
|
||||
}
|
||||
@@ -515,3 +515,19 @@ type UserChannelBinding struct {
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
|
||||
}
|
||||
|
||||
type UserProviderOauthToken struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
ProviderID pgtype.UUID `json:"provider_id"`
|
||||
UserID pgtype.UUID `json:"user_id"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
|
||||
Scope string `json:"scope"`
|
||||
TokenType string `json:"token_type"`
|
||||
State string `json:"state"`
|
||||
PkceCodeVerifier string `json:"pkce_code_verifier"`
|
||||
Metadata []byte `json:"metadata"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
|
||||
}
|
||||
|
||||
@@ -0,0 +1,205 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// source: user_provider_oauth.sql
|
||||
|
||||
package sqlc
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
const deleteUserProviderOAuthToken = `-- name: DeleteUserProviderOAuthToken :exec
|
||||
DELETE FROM user_provider_oauth_tokens
|
||||
WHERE provider_id = $1
|
||||
AND user_id = $2
|
||||
`
|
||||
|
||||
type DeleteUserProviderOAuthTokenParams struct {
|
||||
ProviderID pgtype.UUID `json:"provider_id"`
|
||||
UserID pgtype.UUID `json:"user_id"`
|
||||
}
|
||||
|
||||
func (q *Queries) DeleteUserProviderOAuthToken(ctx context.Context, arg DeleteUserProviderOAuthTokenParams) error {
|
||||
_, err := q.db.Exec(ctx, deleteUserProviderOAuthToken, arg.ProviderID, arg.UserID)
|
||||
return err
|
||||
}
|
||||
|
||||
const getUserProviderOAuthToken = `-- name: GetUserProviderOAuthToken :one
|
||||
SELECT id, provider_id, user_id, access_token, refresh_token, expires_at, scope, token_type, state, pkce_code_verifier, metadata, created_at, updated_at FROM user_provider_oauth_tokens
|
||||
WHERE provider_id = $1
|
||||
AND user_id = $2
|
||||
`
|
||||
|
||||
type GetUserProviderOAuthTokenParams struct {
|
||||
ProviderID pgtype.UUID `json:"provider_id"`
|
||||
UserID pgtype.UUID `json:"user_id"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetUserProviderOAuthToken(ctx context.Context, arg GetUserProviderOAuthTokenParams) (UserProviderOauthToken, error) {
|
||||
row := q.db.QueryRow(ctx, getUserProviderOAuthToken, arg.ProviderID, arg.UserID)
|
||||
var i UserProviderOauthToken
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.ProviderID,
|
||||
&i.UserID,
|
||||
&i.AccessToken,
|
||||
&i.RefreshToken,
|
||||
&i.ExpiresAt,
|
||||
&i.Scope,
|
||||
&i.TokenType,
|
||||
&i.State,
|
||||
&i.PkceCodeVerifier,
|
||||
&i.Metadata,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getUserProviderOAuthTokenByState = `-- name: GetUserProviderOAuthTokenByState :one
|
||||
SELECT id, provider_id, user_id, access_token, refresh_token, expires_at, scope, token_type, state, pkce_code_verifier, metadata, created_at, updated_at FROM user_provider_oauth_tokens
|
||||
WHERE state = $1
|
||||
AND state != ''
|
||||
`
|
||||
|
||||
func (q *Queries) GetUserProviderOAuthTokenByState(ctx context.Context, state string) (UserProviderOauthToken, error) {
|
||||
row := q.db.QueryRow(ctx, getUserProviderOAuthTokenByState, state)
|
||||
var i UserProviderOauthToken
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.ProviderID,
|
||||
&i.UserID,
|
||||
&i.AccessToken,
|
||||
&i.RefreshToken,
|
||||
&i.ExpiresAt,
|
||||
&i.Scope,
|
||||
&i.TokenType,
|
||||
&i.State,
|
||||
&i.PkceCodeVerifier,
|
||||
&i.Metadata,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const updateUserProviderOAuthState = `-- name: UpdateUserProviderOAuthState :exec
|
||||
INSERT INTO user_provider_oauth_tokens (provider_id, user_id, state, pkce_code_verifier, metadata)
|
||||
VALUES (
|
||||
$1,
|
||||
$2,
|
||||
$3,
|
||||
$4,
|
||||
$5
|
||||
)
|
||||
ON CONFLICT (provider_id, user_id) DO UPDATE SET
|
||||
state = EXCLUDED.state,
|
||||
pkce_code_verifier = EXCLUDED.pkce_code_verifier,
|
||||
metadata = EXCLUDED.metadata,
|
||||
updated_at = now()
|
||||
`
|
||||
|
||||
type UpdateUserProviderOAuthStateParams struct {
|
||||
ProviderID pgtype.UUID `json:"provider_id"`
|
||||
UserID pgtype.UUID `json:"user_id"`
|
||||
State string `json:"state"`
|
||||
PkceCodeVerifier string `json:"pkce_code_verifier"`
|
||||
Metadata []byte `json:"metadata"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateUserProviderOAuthState(ctx context.Context, arg UpdateUserProviderOAuthStateParams) error {
|
||||
_, err := q.db.Exec(ctx, updateUserProviderOAuthState,
|
||||
arg.ProviderID,
|
||||
arg.UserID,
|
||||
arg.State,
|
||||
arg.PkceCodeVerifier,
|
||||
arg.Metadata,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
const upsertUserProviderOAuthToken = `-- name: UpsertUserProviderOAuthToken :one
|
||||
INSERT INTO user_provider_oauth_tokens (
|
||||
provider_id,
|
||||
user_id,
|
||||
access_token,
|
||||
refresh_token,
|
||||
expires_at,
|
||||
scope,
|
||||
token_type,
|
||||
state,
|
||||
pkce_code_verifier,
|
||||
metadata
|
||||
)
|
||||
VALUES (
|
||||
$1,
|
||||
$2,
|
||||
$3,
|
||||
$4,
|
||||
$5,
|
||||
$6,
|
||||
$7,
|
||||
$8,
|
||||
$9,
|
||||
$10
|
||||
)
|
||||
ON CONFLICT (provider_id, user_id) DO UPDATE SET
|
||||
access_token = EXCLUDED.access_token,
|
||||
refresh_token = EXCLUDED.refresh_token,
|
||||
expires_at = EXCLUDED.expires_at,
|
||||
scope = EXCLUDED.scope,
|
||||
token_type = EXCLUDED.token_type,
|
||||
state = EXCLUDED.state,
|
||||
pkce_code_verifier = EXCLUDED.pkce_code_verifier,
|
||||
metadata = EXCLUDED.metadata,
|
||||
updated_at = now()
|
||||
RETURNING id, provider_id, user_id, access_token, refresh_token, expires_at, scope, token_type, state, pkce_code_verifier, metadata, created_at, updated_at
|
||||
`
|
||||
|
||||
type UpsertUserProviderOAuthTokenParams struct {
|
||||
ProviderID pgtype.UUID `json:"provider_id"`
|
||||
UserID pgtype.UUID `json:"user_id"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
|
||||
Scope string `json:"scope"`
|
||||
TokenType string `json:"token_type"`
|
||||
State string `json:"state"`
|
||||
PkceCodeVerifier string `json:"pkce_code_verifier"`
|
||||
Metadata []byte `json:"metadata"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpsertUserProviderOAuthToken(ctx context.Context, arg UpsertUserProviderOAuthTokenParams) (UserProviderOauthToken, error) {
|
||||
row := q.db.QueryRow(ctx, upsertUserProviderOAuthToken,
|
||||
arg.ProviderID,
|
||||
arg.UserID,
|
||||
arg.AccessToken,
|
||||
arg.RefreshToken,
|
||||
arg.ExpiresAt,
|
||||
arg.Scope,
|
||||
arg.TokenType,
|
||||
arg.State,
|
||||
arg.PkceCodeVerifier,
|
||||
arg.Metadata,
|
||||
)
|
||||
var i UserProviderOauthToken
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.ProviderID,
|
||||
&i.UserID,
|
||||
&i.AccessToken,
|
||||
&i.RefreshToken,
|
||||
&i.ExpiresAt,
|
||||
&i.Scope,
|
||||
&i.TokenType,
|
||||
&i.State,
|
||||
&i.PkceCodeVerifier,
|
||||
&i.Metadata,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -10,7 +10,9 @@ import (
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/memohai/memoh/internal/auth"
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
"github.com/memohai/memoh/internal/oauthctx"
|
||||
)
|
||||
|
||||
type ModelsHandler struct {
|
||||
@@ -301,7 +303,12 @@ func (h *ModelsHandler) Test(c echo.Context) error {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
|
||||
}
|
||||
|
||||
resp, err := h.service.Test(c.Request().Context(), id)
|
||||
ctx := c.Request().Context()
|
||||
if userID, err := auth.UserIDFromContext(c); err == nil {
|
||||
ctx = oauthctx.WithUserID(ctx, userID)
|
||||
}
|
||||
|
||||
resp, err := h.service.Test(ctx, id)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "invalid") {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/memohai/memoh/internal/auth"
|
||||
"github.com/memohai/memoh/internal/oauthctx"
|
||||
"github.com/memohai/memoh/internal/providers"
|
||||
)
|
||||
|
||||
@@ -20,6 +22,7 @@ func NewProviderOAuthHandler(service *providers.Service) *ProviderOAuthHandler {
|
||||
|
||||
func (h *ProviderOAuthHandler) Register(e *echo.Echo) {
|
||||
e.GET("/providers/:id/oauth/authorize", h.Authorize)
|
||||
e.POST("/providers/:id/oauth/poll", h.Poll)
|
||||
e.GET("/providers/:id/oauth/status", h.Status)
|
||||
e.DELETE("/providers/:id/oauth/token", h.Revoke)
|
||||
e.GET("/auth/callback", h.Callback)
|
||||
@@ -30,7 +33,7 @@ func (h *ProviderOAuthHandler) Register(e *echo.Echo) {
|
||||
// @Summary Start OAuth2 authorization for an LLM provider
|
||||
// @Tags providers-oauth
|
||||
// @Param id path string true "Provider ID (UUID)"
|
||||
// @Success 200 {object} map[string]string
|
||||
// @Success 200 {object} providers.OAuthAuthorizeResponse
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 404 {object} ErrorResponse
|
||||
// @Router /providers/{id}/oauth/authorize [get].
|
||||
@@ -39,11 +42,39 @@ func (h *ProviderOAuthHandler) Authorize(c echo.Context) error {
|
||||
if providerID == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
|
||||
}
|
||||
authURL, err := h.service.StartOAuthAuthorization(c.Request().Context(), providerID)
|
||||
ctx := c.Request().Context()
|
||||
if userID, err := auth.UserIDFromContext(c); err == nil {
|
||||
ctx = oauthctx.WithUserID(ctx, userID)
|
||||
}
|
||||
resp, err := h.service.StartOAuthAuthorization(ctx, providerID)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
return c.JSON(http.StatusOK, map[string]string{"auth_url": authURL})
|
||||
return c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// Poll godoc
|
||||
// @Summary Poll OAuth device authorization for an LLM provider
|
||||
// @Tags providers-oauth
|
||||
// @Param id path string true "Provider ID (UUID)"
|
||||
// @Success 200 {object} providers.OAuthStatus
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 404 {object} ErrorResponse
|
||||
// @Router /providers/{id}/oauth/poll [post].
|
||||
func (h *ProviderOAuthHandler) Poll(c echo.Context) error {
|
||||
providerID := strings.TrimSpace(c.Param("id"))
|
||||
if providerID == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
|
||||
}
|
||||
ctx := c.Request().Context()
|
||||
if userID, err := auth.UserIDFromContext(c); err == nil {
|
||||
ctx = oauthctx.WithUserID(ctx, userID)
|
||||
}
|
||||
status, err := h.service.PollOAuthAuthorization(ctx, providerID)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
return c.JSON(http.StatusOK, status)
|
||||
}
|
||||
|
||||
// Status godoc
|
||||
@@ -59,7 +90,11 @@ func (h *ProviderOAuthHandler) Status(c echo.Context) error {
|
||||
if providerID == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
|
||||
}
|
||||
status, err := h.service.GetOAuthStatus(c.Request().Context(), providerID)
|
||||
ctx := c.Request().Context()
|
||||
if userID, err := auth.UserIDFromContext(c); err == nil {
|
||||
ctx = oauthctx.WithUserID(ctx, userID)
|
||||
}
|
||||
status, err := h.service.GetOAuthStatus(ctx, providerID)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
@@ -79,7 +114,11 @@ func (h *ProviderOAuthHandler) Revoke(c echo.Context) error {
|
||||
if providerID == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
|
||||
}
|
||||
if err := h.service.RevokeOAuthToken(c.Request().Context(), providerID); err != nil {
|
||||
ctx := c.Request().Context()
|
||||
if userID, err := auth.UserIDFromContext(c); err == nil {
|
||||
ctx = oauthctx.WithUserID(ctx, userID)
|
||||
}
|
||||
if err := h.service.RevokeOAuthToken(ctx, providerID); err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
return c.NoContent(http.StatusNoContent)
|
||||
@@ -111,11 +150,11 @@ func (h *ProviderOAuthHandler) Callback(c echo.Context) error {
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<title>OpenAI OAuth Connected</title>
|
||||
<title>Provider Connected</title>
|
||||
</head>
|
||||
<body style="font-family: sans-serif; padding: 24px;">
|
||||
<h2>OpenAI OAuth connected</h2>
|
||||
<p>You can close this window and return to Memoh.</p>
|
||||
<h2>Provider connected</h2>
|
||||
<p>Your current Memoh account is now connected.</p>
|
||||
<script>
|
||||
window.opener?.postMessage({ type: "memoh-provider-oauth-success", providerId: "{{.ProviderID}}" }, "*");
|
||||
window.close();
|
||||
|
||||
@@ -9,7 +9,9 @@ import (
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/memohai/memoh/internal/auth"
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
"github.com/memohai/memoh/internal/oauthctx"
|
||||
"github.com/memohai/memoh/internal/providers"
|
||||
)
|
||||
|
||||
@@ -272,7 +274,12 @@ func (h *ProvidersHandler) Test(c echo.Context) error {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
|
||||
}
|
||||
|
||||
resp, err := h.service.Test(c.Request().Context(), id)
|
||||
ctx := c.Request().Context()
|
||||
if userID, err := auth.UserIDFromContext(c); err == nil {
|
||||
ctx = oauthctx.WithUserID(ctx, userID)
|
||||
}
|
||||
|
||||
resp, err := h.service.Test(ctx, id)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "invalid") {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
@@ -301,7 +308,12 @@ func (h *ProvidersHandler) ImportModels(c echo.Context) error {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
|
||||
}
|
||||
|
||||
remoteModels, err := h.service.FetchRemoteModels(c.Request().Context(), id)
|
||||
ctx := c.Request().Context()
|
||||
if userID, err := auth.UserIDFromContext(c); err == nil {
|
||||
ctx = oauthctx.WithUserID(ctx, userID)
|
||||
}
|
||||
|
||||
remoteModels, err := h.service.FetchRemoteModels(ctx, id)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("fetch remote models: %v", err))
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/memohai/memoh/internal/healthcheck"
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
"github.com/memohai/memoh/internal/oauthctx"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -25,6 +26,7 @@ type BotModelLookup interface {
|
||||
|
||||
// BotModels holds the model UUIDs associated with a bot.
|
||||
type BotModels struct {
|
||||
OwnerUserID string
|
||||
ChatModelID string
|
||||
MemoryModelID string
|
||||
EmbeddingModelID string
|
||||
@@ -115,7 +117,8 @@ func (c *Checker) ListChecks(ctx context.Context, botID string) []healthcheck.Ch
|
||||
wg.Add(1)
|
||||
go func(idx int, s modelSlot) {
|
||||
defer wg.Done()
|
||||
results[idx] = c.probeSlot(probeCtx, s)
|
||||
slotCtx := oauthctx.WithUserID(probeCtx, botModels.OwnerUserID)
|
||||
results[idx] = c.probeSlot(slotCtx, s)
|
||||
}(i, slot)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
package modelchecker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/memohai/memoh/internal/healthcheck"
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
"github.com/memohai/memoh/internal/oauthctx"
|
||||
)
|
||||
|
||||
type testLookup struct {
|
||||
models BotModels
|
||||
}
|
||||
|
||||
func (l testLookup) GetBotModelIDs(context.Context, string) (BotModels, error) {
|
||||
return l.models, nil
|
||||
}
|
||||
|
||||
type testProber struct {
|
||||
t *testing.T
|
||||
wantUserID string
|
||||
}
|
||||
|
||||
func (p testProber) Test(ctx context.Context, id string) (models.TestResponse, error) {
|
||||
if got := oauthctx.UserIDFromContext(ctx); got != p.wantUserID {
|
||||
p.t.Fatalf("expected oauth user id %q, got %q", p.wantUserID, got)
|
||||
}
|
||||
if id != "model-chat-1" {
|
||||
p.t.Fatalf("expected model id %q, got %q", "model-chat-1", id)
|
||||
}
|
||||
return models.TestResponse{
|
||||
Status: models.TestStatusOK,
|
||||
Reachable: true,
|
||||
Message: "ok",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestCheckerListChecksInjectsOwnerUserIDIntoProbeContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
checker := NewChecker(nil, testLookup{
|
||||
models: BotModels{
|
||||
OwnerUserID: "user-123",
|
||||
ChatModelID: "model-chat-1",
|
||||
},
|
||||
}, testProber{
|
||||
t: t,
|
||||
wantUserID: "user-123",
|
||||
})
|
||||
|
||||
items := checker.ListChecks(context.Background(), "bot-1")
|
||||
if len(items) != 1 {
|
||||
t.Fatalf("expected 1 check result, got %d", len(items))
|
||||
}
|
||||
if items[0].Status != healthcheck.StatusOK {
|
||||
t.Fatalf("expected status %q, got %q", healthcheck.StatusOK, items[0].Status)
|
||||
}
|
||||
}
|
||||
@@ -36,6 +36,7 @@ func (l *QueriesLookup) GetBotModelIDs(ctx context.Context, botID string) (BotMo
|
||||
}
|
||||
|
||||
var m BotModels
|
||||
m.OwnerUserID = bot.OwnerUserID.String()
|
||||
if bot.ChatModelID.Valid {
|
||||
m.ChatModelID = bot.ChatModelID.String()
|
||||
}
|
||||
|
||||
@@ -429,6 +429,7 @@ func IsValidClientType(clientType ClientType) bool {
|
||||
ClientTypeAnthropicMessages,
|
||||
ClientTypeGoogleGenerativeAI,
|
||||
ClientTypeOpenAICodex,
|
||||
ClientTypeGitHubCopilot,
|
||||
ClientTypeEdgeSpeech:
|
||||
return true
|
||||
default:
|
||||
|
||||
+53
-12
@@ -17,8 +17,10 @@ import (
|
||||
openairesponses "github.com/memohai/twilight-ai/provider/openai/responses"
|
||||
sdk "github.com/memohai/twilight-ai/sdk"
|
||||
|
||||
memohcopilot "github.com/memohai/memoh/internal/copilot"
|
||||
"github.com/memohai/memoh/internal/db"
|
||||
"github.com/memohai/memoh/internal/db/sqlc"
|
||||
"github.com/memohai/memoh/internal/oauthctx"
|
||||
)
|
||||
|
||||
const probeTimeout = 15 * time.Second
|
||||
@@ -162,6 +164,9 @@ func NewSDKProvider(baseURL, apiKey, codexAccountID string, clientType ClientTyp
|
||||
}
|
||||
return openaicodex.New(opts...)
|
||||
|
||||
case ClientTypeGitHubCopilot:
|
||||
return memohcopilot.NewProvider(apiKey, httpClient)
|
||||
|
||||
case ClientTypeAnthropicMessages:
|
||||
opts := []anthropicmessages.Option{
|
||||
anthropicmessages.WithAPIKey(apiKey),
|
||||
@@ -202,26 +207,62 @@ type modelCredentials struct {
|
||||
func (s *Service) resolveModelCredentials(ctx context.Context, provider sqlc.Provider) (modelCredentials, error) {
|
||||
apiKey := providerConfigString(provider.Config, "api_key")
|
||||
|
||||
if ClientType(provider.ClientType) != ClientTypeOpenAICodex {
|
||||
switch ClientType(provider.ClientType) {
|
||||
case ClientTypeGitHubCopilot:
|
||||
token, err := s.resolveGitHubCopilotAccessToken(ctx, provider)
|
||||
if err != nil {
|
||||
return modelCredentials{}, err
|
||||
}
|
||||
return modelCredentials{APIKey: token}, nil
|
||||
|
||||
case ClientTypeOpenAICodex:
|
||||
tokenRow, err := s.queries.GetProviderOAuthTokenByProvider(ctx, provider.ID)
|
||||
if err != nil {
|
||||
return modelCredentials{}, err
|
||||
}
|
||||
accessToken := strings.TrimSpace(tokenRow.AccessToken)
|
||||
if accessToken == "" {
|
||||
return modelCredentials{}, errors.New("oauth token is missing access token")
|
||||
}
|
||||
accountID, err := codexAccountIDFromToken(accessToken)
|
||||
if err != nil {
|
||||
return modelCredentials{}, err
|
||||
}
|
||||
return modelCredentials{
|
||||
APIKey: accessToken,
|
||||
CodexAccountID: accountID,
|
||||
}, nil
|
||||
|
||||
default:
|
||||
return modelCredentials{APIKey: apiKey}, nil
|
||||
}
|
||||
}
|
||||
|
||||
tokenRow, err := s.queries.GetProviderOAuthTokenByProvider(ctx, provider.ID)
|
||||
if err != nil {
|
||||
return modelCredentials{}, err
|
||||
func (s *Service) resolveGitHubCopilotAccessToken(ctx context.Context, provider sqlc.Provider) (string, error) {
|
||||
userID := oauthctx.UserIDFromContext(ctx)
|
||||
if userID == "" {
|
||||
return "", errors.New("github copilot requires a current user")
|
||||
}
|
||||
accessToken := strings.TrimSpace(tokenRow.AccessToken)
|
||||
userUUID, err := db.ParseUUID(userID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
row, err := s.queries.GetUserProviderOAuthToken(ctx, sqlc.GetUserProviderOAuthTokenParams{
|
||||
ProviderID: provider.ID,
|
||||
UserID: userUUID,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
accessToken := strings.TrimSpace(row.AccessToken)
|
||||
if accessToken == "" {
|
||||
return modelCredentials{}, errors.New("oauth token is missing access token")
|
||||
return "", errors.New("oauth token is missing access token")
|
||||
}
|
||||
accountID, err := codexAccountIDFromToken(accessToken)
|
||||
copilotToken, err := memohcopilot.ResolveToken(ctx, accessToken)
|
||||
if err != nil {
|
||||
return modelCredentials{}, err
|
||||
return "", err
|
||||
}
|
||||
return modelCredentials{
|
||||
APIKey: accessToken,
|
||||
CodexAccountID: accountID,
|
||||
}, nil
|
||||
return copilotToken, nil
|
||||
}
|
||||
|
||||
func codexAccountIDFromToken(token string) (string, error) {
|
||||
|
||||
@@ -10,6 +10,8 @@ import (
|
||||
openaicompletions "github.com/memohai/twilight-ai/provider/openai/completions"
|
||||
openairesponses "github.com/memohai/twilight-ai/provider/openai/responses"
|
||||
sdk "github.com/memohai/twilight-ai/sdk"
|
||||
|
||||
memohcopilot "github.com/memohai/memoh/internal/copilot"
|
||||
)
|
||||
|
||||
// SDKModelConfig holds provider and model information resolved from DB,
|
||||
@@ -76,6 +78,9 @@ func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model {
|
||||
}
|
||||
return openaicodex.New(opts...).ChatModel(cfg.ModelID)
|
||||
|
||||
case ClientTypeGitHubCopilot:
|
||||
return memohcopilot.NewModel(cfg.APIKey, cfg.ModelID, cfg.HTTPClient)
|
||||
|
||||
case ClientTypeAnthropicMessages:
|
||||
opts := []anthropicmessages.Option{
|
||||
anthropicmessages.WithAPIKey(cfg.APIKey),
|
||||
@@ -178,6 +183,8 @@ func ResolveClientType(model *sdk.Model) string {
|
||||
return string(ClientTypeAnthropicMessages)
|
||||
case strings.Contains(name, "google"):
|
||||
return string(ClientTypeGoogleGenerativeAI)
|
||||
case strings.Contains(name, "github-copilot"), strings.Contains(name, "copilot"):
|
||||
return string(ClientTypeGitHubCopilot)
|
||||
case strings.Contains(name, "codex"):
|
||||
return string(ClientTypeOpenAICodex)
|
||||
case strings.Contains(name, "responses"):
|
||||
|
||||
@@ -22,6 +22,7 @@ const (
|
||||
ClientTypeAnthropicMessages ClientType = "anthropic-messages"
|
||||
ClientTypeGoogleGenerativeAI ClientType = "google-generative-ai"
|
||||
ClientTypeOpenAICodex ClientType = "openai-codex"
|
||||
ClientTypeGitHubCopilot ClientType = "github-copilot"
|
||||
ClientTypeEdgeSpeech ClientType = "edge-speech"
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
package oauthctx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type userIDContextKey struct{}
|
||||
|
||||
func WithUserID(ctx context.Context, userID string) context.Context {
|
||||
userID = strings.TrimSpace(userID)
|
||||
if userID == "" {
|
||||
return ctx
|
||||
}
|
||||
return context.WithValue(ctx, userIDContextKey{}, userID)
|
||||
}
|
||||
|
||||
func UserIDFromContext(ctx context.Context) string {
|
||||
userID, _ := ctx.Value(userIDContextKey{}).(string)
|
||||
return strings.TrimSpace(userID)
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
memohcopilot "github.com/memohai/memoh/internal/copilot"
|
||||
"github.com/memohai/memoh/internal/db/sqlc"
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
)
|
||||
@@ -24,25 +25,38 @@ func SupportsOpenAICodexOAuth(provider sqlc.Provider) bool {
|
||||
}
|
||||
|
||||
func (s *Service) ResolveModelCredentials(ctx context.Context, provider sqlc.Provider) (ModelCredentials, error) {
|
||||
if models.ClientType(provider.ClientType) != models.ClientTypeOpenAICodex {
|
||||
switch models.ClientType(provider.ClientType) {
|
||||
case models.ClientTypeGitHubCopilot:
|
||||
githubToken, err := s.GetValidAccessToken(ctx, provider.ID.String())
|
||||
if err != nil {
|
||||
return ModelCredentials{}, err
|
||||
}
|
||||
copilotToken, err := memohcopilot.ResolveToken(ctx, githubToken)
|
||||
if err != nil {
|
||||
return ModelCredentials{}, err
|
||||
}
|
||||
return ModelCredentials{APIKey: copilotToken}, nil
|
||||
|
||||
case models.ClientTypeOpenAICodex:
|
||||
token, err := s.GetValidAccessToken(ctx, provider.ID.String())
|
||||
if err != nil {
|
||||
return ModelCredentials{}, err
|
||||
}
|
||||
accountID, err := codexAccountIDFromToken(token)
|
||||
if err != nil {
|
||||
return ModelCredentials{}, err
|
||||
}
|
||||
return ModelCredentials{
|
||||
APIKey: token,
|
||||
CodexAccountID: accountID,
|
||||
}, nil
|
||||
|
||||
default:
|
||||
apiKey := ProviderConfigString(provider, "api_key")
|
||||
return ModelCredentials{
|
||||
APIKey: apiKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
token, err := s.GetValidAccessToken(ctx, provider.ID.String())
|
||||
if err != nil {
|
||||
return ModelCredentials{}, err
|
||||
}
|
||||
accountID, err := codexAccountIDFromToken(token)
|
||||
if err != nil {
|
||||
return ModelCredentials{}, err
|
||||
}
|
||||
return ModelCredentials{
|
||||
APIKey: token,
|
||||
CodexAccountID: accountID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func codexAccountIDFromToken(token string) (string, error) {
|
||||
|
||||
+956
-113
File diff suppressed because it is too large
Load Diff
@@ -11,9 +11,11 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
githubcopilot "github.com/memohai/twilight-ai/provider/github/copilot"
|
||||
openaicodex "github.com/memohai/twilight-ai/provider/openai/codex"
|
||||
sdk "github.com/memohai/twilight-ai/sdk"
|
||||
|
||||
memohcopilot "github.com/memohai/memoh/internal/copilot"
|
||||
"github.com/memohai/memoh/internal/db"
|
||||
"github.com/memohai/memoh/internal/db/sqlc"
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
@@ -47,15 +49,14 @@ func (s *Service) Create(ctx context.Context, req CreateRequest) (GetResponse, e
|
||||
return GetResponse{}, fmt.Errorf("marshal metadata: %w", err)
|
||||
}
|
||||
|
||||
configJSON, err := json.Marshal(req.Config)
|
||||
if err != nil {
|
||||
return GetResponse{}, fmt.Errorf("marshal config: %w", err)
|
||||
}
|
||||
|
||||
clientType := req.ClientType
|
||||
if clientType == "" {
|
||||
clientType = string(models.ClientTypeOpenAICompletions)
|
||||
}
|
||||
configJSON, err := json.Marshal(normalizeProviderConfig(clientType, req.Config))
|
||||
if err != nil {
|
||||
return GetResponse{}, fmt.Errorf("marshal config: %w", err)
|
||||
}
|
||||
|
||||
var icon pgtype.Text
|
||||
if req.Icon != "" {
|
||||
@@ -150,12 +151,11 @@ func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Get
|
||||
|
||||
existingConfig := providerConfig(existing.Config)
|
||||
if req.Config != nil {
|
||||
existingAPIKey := configString(existingConfig, "api_key")
|
||||
newAPIKey := configString(req.Config, "api_key")
|
||||
if newAPIKey != "" && newAPIKey == maskAPIKey(existingAPIKey) {
|
||||
req.Config["api_key"] = existingAPIKey
|
||||
}
|
||||
existingConfig = req.Config
|
||||
mergedConfig := mergeProviderConfig(existingConfig, req.Config)
|
||||
preserveMaskedConfigSecret(mergedConfig, existingConfig, req.Config, "api_key")
|
||||
existingConfig = normalizeProviderConfig(clientType, mergedConfig)
|
||||
} else {
|
||||
existingConfig = normalizeProviderConfig(clientType, existingConfig)
|
||||
}
|
||||
configJSON, err := json.Marshal(existingConfig)
|
||||
if err != nil {
|
||||
@@ -257,6 +257,34 @@ func (s *Service) FetchRemoteModels(ctx context.Context, id string) ([]RemoteMod
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get provider: %w", err)
|
||||
}
|
||||
if models.ClientType(provider.ClientType) == models.ClientTypeGitHubCopilot {
|
||||
creds, err := s.ResolveModelCredentials(ctx, provider)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sdkProvider := memohcopilot.NewProvider(creds.APIKey, nil)
|
||||
if result := sdkProvider.Test(ctx); result.Status != sdk.ProviderStatusOK {
|
||||
return nil, fmt.Errorf("github copilot provider test failed: %s", result.Message)
|
||||
}
|
||||
|
||||
catalog := githubcopilot.Catalog()
|
||||
remoteModels := make([]RemoteModel, 0, len(catalog))
|
||||
for _, model := range catalog {
|
||||
remoteModels = append(remoteModels, RemoteModel{
|
||||
ID: model.ID,
|
||||
Name: model.DisplayName,
|
||||
Object: "model",
|
||||
OwnedBy: "github-copilot",
|
||||
Type: "chat",
|
||||
Compatibilities: []string{
|
||||
models.CompatVision,
|
||||
models.CompatToolCall,
|
||||
models.CompatReasoning,
|
||||
},
|
||||
})
|
||||
}
|
||||
return remoteModels, nil
|
||||
}
|
||||
if supportsOAuth(provider) {
|
||||
catalog := openaicodex.Catalog()
|
||||
remoteModels := make([]RemoteModel, 0, len(catalog))
|
||||
@@ -329,7 +357,7 @@ func (s *Service) toGetResponse(provider sqlc.Provider) GetResponse {
|
||||
}
|
||||
|
||||
cfg := providerConfig(provider.Config)
|
||||
maskedCfg := maskConfigAPIKey(cfg)
|
||||
maskedCfg := maskConfigSecrets(provider.ClientType, cfg)
|
||||
|
||||
var icon string
|
||||
if provider.Icon.Valid {
|
||||
@@ -378,14 +406,51 @@ func ProviderConfigString(provider sqlc.Provider, key string) string {
|
||||
return configString(providerConfig(provider.Config), key)
|
||||
}
|
||||
|
||||
// maskConfigAPIKey returns a copy of config with api_key masked.
|
||||
func maskConfigAPIKey(cfg map[string]any) map[string]any {
|
||||
func cloneConfig(cfg map[string]any) map[string]any {
|
||||
result := make(map[string]any, len(cfg))
|
||||
for k, v := range cfg {
|
||||
result[k] = v
|
||||
}
|
||||
if apiKey, _ := result["api_key"].(string); apiKey != "" {
|
||||
result["api_key"] = maskAPIKey(apiKey)
|
||||
return result
|
||||
}
|
||||
|
||||
func mergeProviderConfig(existing, incoming map[string]any) map[string]any {
|
||||
result := cloneConfig(existing)
|
||||
for k, v := range incoming {
|
||||
result[k] = v
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func preserveMaskedConfigSecret(merged, existing, incoming map[string]any, key string) {
|
||||
existingValue := strings.TrimSpace(configString(existing, key))
|
||||
newValue := strings.TrimSpace(configString(incoming, key))
|
||||
if existingValue == "" || newValue == "" {
|
||||
return
|
||||
}
|
||||
if newValue == maskAPIKey(existingValue) {
|
||||
merged[key] = existingValue
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeProviderConfig keeps provider-specific secrets under stable keys while
|
||||
// preserving backward compatibility for legacy stored configs.
|
||||
func normalizeProviderConfig(clientType string, cfg map[string]any) map[string]any {
|
||||
result := cloneConfig(cfg)
|
||||
if models.ClientType(clientType) == models.ClientTypeGitHubCopilot {
|
||||
delete(result, "api_key")
|
||||
delete(result, configOAuthClientSecretKey)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// maskConfigSecrets returns a copy of config with all known secret fields masked.
|
||||
func maskConfigSecrets(clientType string, cfg map[string]any) map[string]any {
|
||||
result := normalizeProviderConfig(clientType, cfg)
|
||||
for _, key := range []string{"api_key", configOAuthClientSecretKey} {
|
||||
if value, _ := result[key].(string); value != "" {
|
||||
result[key] = maskAPIKey(value)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
package providers
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/memohai/memoh/internal/db/sqlc"
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
)
|
||||
|
||||
func TestMaskAPIKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
@@ -34,3 +40,166 @@ func TestMaskAPIKey(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestNormalizeProviderConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("github copilot drops legacy secrets", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg := normalizeProviderConfig("github-copilot", map[string]any{
|
||||
"api_key": "gh-secret",
|
||||
configOAuthClientSecretKey: "oauth-secret",
|
||||
"base_url": "ignored",
|
||||
})
|
||||
|
||||
if _, exists := cfg[configOAuthClientSecretKey]; exists {
|
||||
t.Fatalf("expected oauth client secret to be removed, got %#v", cfg[configOAuthClientSecretKey])
|
||||
}
|
||||
if _, exists := cfg["api_key"]; exists {
|
||||
t.Fatalf("expected legacy api_key to be removed, got %#v", cfg["api_key"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non copilot providers keep api key key", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg := normalizeProviderConfig("openai-completions", map[string]any{
|
||||
"api_key": "sk-live",
|
||||
})
|
||||
|
||||
if got, ok := cfg["api_key"].(string); !ok || got != "sk-live" {
|
||||
t.Fatalf("expected api_key to remain untouched, got %#v", cfg["api_key"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMaskConfigSecrets(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg := maskConfigSecrets("openai-completions", map[string]any{
|
||||
"api_key": "sk-secret-123456",
|
||||
})
|
||||
|
||||
masked, _ := cfg["api_key"].(string)
|
||||
if masked == "" || masked == "sk-secret-123456" {
|
||||
t.Fatalf("expected api key to be masked, got %q", masked)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreserveMaskedConfigSecret(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
merged := map[string]any{
|
||||
configOAuthClientSecretKey: "*************",
|
||||
}
|
||||
existing := map[string]any{
|
||||
configOAuthClientSecretKey: "gh-secret-1234",
|
||||
}
|
||||
incoming := map[string]any{
|
||||
configOAuthClientSecretKey: maskAPIKey("gh-secret-1234"),
|
||||
}
|
||||
|
||||
preserveMaskedConfigSecret(merged, existing, incoming, configOAuthClientSecretKey)
|
||||
|
||||
if got, _ := merged[configOAuthClientSecretKey].(string); got != "gh-secret-1234" {
|
||||
t.Fatalf("expected masked value to be restored to original secret, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceMetadataRoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
expiresAt := time.Date(2026, time.April, 11, 12, 0, 0, 0, time.UTC)
|
||||
device := oauthDeviceMetadata{
|
||||
DeviceCode: "device-code",
|
||||
UserCode: "ABCD-EFGH",
|
||||
VerificationURI: "https://github.com/login/device",
|
||||
ExpiresAt: expiresAt,
|
||||
IntervalSeconds: 5,
|
||||
}
|
||||
|
||||
parsed := deviceMetadataFromMap(device.toMetadata())
|
||||
if parsed.DeviceCode != device.DeviceCode {
|
||||
t.Fatalf("expected device code %q, got %q", device.DeviceCode, parsed.DeviceCode)
|
||||
}
|
||||
if parsed.UserCode != device.UserCode {
|
||||
t.Fatalf("expected user code %q, got %q", device.UserCode, parsed.UserCode)
|
||||
}
|
||||
if parsed.VerificationURI != device.VerificationURI {
|
||||
t.Fatalf("expected verification uri %q, got %q", device.VerificationURI, parsed.VerificationURI)
|
||||
}
|
||||
if !parsed.ExpiresAt.Equal(expiresAt) {
|
||||
t.Fatalf("expected expiresAt %s, got %s", expiresAt, parsed.ExpiresAt)
|
||||
}
|
||||
if parsed.IntervalSeconds != device.IntervalSeconds {
|
||||
t.Fatalf("expected interval %d, got %d", device.IntervalSeconds, parsed.IntervalSeconds)
|
||||
}
|
||||
|
||||
status := parsed.toStatus()
|
||||
if status == nil || !status.Pending {
|
||||
t.Fatalf("expected pending device status, got %#v", status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountMetadataRoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
account := oauthAccountMetadata{
|
||||
Label: "octocat",
|
||||
Login: "octocat",
|
||||
Name: "The Octocat",
|
||||
Email: "octocat@github.com",
|
||||
AvatarURL: "https://avatars.githubusercontent.com/u/1?v=4",
|
||||
ProfileURL: "https://github.com/octocat",
|
||||
}
|
||||
|
||||
parsed := accountMetadataFromMap(account.toMetadata())
|
||||
if parsed.Label != account.Label {
|
||||
t.Fatalf("expected label %q, got %q", account.Label, parsed.Label)
|
||||
}
|
||||
if parsed.Login != account.Login {
|
||||
t.Fatalf("expected login %q, got %q", account.Login, parsed.Login)
|
||||
}
|
||||
if parsed.Name != account.Name {
|
||||
t.Fatalf("expected name %q, got %q", account.Name, parsed.Name)
|
||||
}
|
||||
if parsed.Email != account.Email {
|
||||
t.Fatalf("expected email %q, got %q", account.Email, parsed.Email)
|
||||
}
|
||||
if parsed.AvatarURL != account.AvatarURL {
|
||||
t.Fatalf("expected avatar url %q, got %q", account.AvatarURL, parsed.AvatarURL)
|
||||
}
|
||||
if parsed.ProfileURL != account.ProfileURL {
|
||||
t.Fatalf("expected profile url %q, got %q", account.ProfileURL, parsed.ProfileURL)
|
||||
}
|
||||
|
||||
status := parsed.toStatus()
|
||||
if status == nil {
|
||||
t.Fatal("expected account status")
|
||||
}
|
||||
if status.Label != account.Label {
|
||||
t.Fatalf("expected status label %q, got %q", account.Label, status.Label)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthConfigForGitHubCopilotUsesFixedDeviceFlowSettings(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
service := &Service{}
|
||||
cfg := service.oauthConfigForProvider(sqlc.Provider{
|
||||
ClientType: string(models.ClientTypeGitHubCopilot),
|
||||
Config: []byte(`{"api_key":"legacy","oauth_client_secret":"legacy-secret"}`),
|
||||
Metadata: []byte(`{"oauth_client_id":"custom","oauth_scopes":"repo"}`),
|
||||
})
|
||||
|
||||
if cfg.ClientID != "Iv1.b507a08c87ecfe98" {
|
||||
t.Fatalf("expected fixed client id, got %q", cfg.ClientID)
|
||||
}
|
||||
if cfg.ClientSecret != "" {
|
||||
t.Fatalf("expected empty client secret, got %q", cfg.ClientSecret)
|
||||
}
|
||||
if cfg.Scopes != "read:user user:email" {
|
||||
t.Fatalf("expected fixed scope, got %q", cfg.Scopes)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,11 +54,37 @@ type TestResponse struct {
|
||||
|
||||
// OAuthStatus is returned by GET /providers/:id/oauth/status.
|
||||
type OAuthStatus struct {
|
||||
Configured bool `json:"configured"`
|
||||
HasToken bool `json:"has_token"`
|
||||
Expired bool `json:"expired"`
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
CallbackURL string `json:"callback_url"`
|
||||
Configured bool `json:"configured"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
HasToken bool `json:"has_token"`
|
||||
Expired bool `json:"expired"`
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
CallbackURL string `json:"callback_url"`
|
||||
Device *OAuthDeviceStatus `json:"device,omitempty"`
|
||||
Account *OAuthAccount `json:"account,omitempty"`
|
||||
}
|
||||
|
||||
type OAuthDeviceStatus struct {
|
||||
Pending bool `json:"pending"`
|
||||
UserCode string `json:"user_code,omitempty"`
|
||||
VerificationURI string `json:"verification_uri,omitempty"`
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
IntervalSeconds int64 `json:"interval_seconds,omitempty"`
|
||||
}
|
||||
|
||||
type OAuthAccount struct {
|
||||
Label string `json:"label,omitempty"`
|
||||
Login string `json:"login,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
AvatarURL string `json:"avatar_url,omitempty"`
|
||||
ProfileURL string `json:"profile_url,omitempty"`
|
||||
}
|
||||
|
||||
type OAuthAuthorizeResponse struct {
|
||||
Mode string `json:"mode,omitempty"`
|
||||
AuthURL string `json:"auth_url,omitempty"`
|
||||
Device *OAuthDeviceStatus `json:"device,omitempty"`
|
||||
}
|
||||
|
||||
// RemoteModel represents a model returned by the provider's /v1/models endpoint.
|
||||
|
||||
Reference in New Issue
Block a user