mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
feat: openai codex support (#292)
* feat(web): add provider oauth management ui * feat: add OAuth callback support on port 1455 * feat: enhance reasoning effort options and support for OpenAI Codex OAuth * feat: update twilight-ai dependency to v0.3.4 * refactor: promote openai-codex to first-class client_type, remove auth_type Replace the previous openai-responses + metadata auth_type=openai-codex-oauth combo with a dedicated openai-codex client_type. OAuth requirement is now determined solely by client_type, eliminating the auth_type concept from the LLM provider domain entirely. - Add openai-codex to DB CHECK constraint (migration 0047) with data migration - Add ClientTypeOpenAICodex constant and dedicated SDK/probe branches - Remove AuthType from SDKModelConfig, ModelCredentials, TriggerConfig, etc. - Simplify supportsOAuth to check client_type == openai-codex - Add conf/providers/codex.yaml preset with Codex catalog models - Frontend: replace auth_type selector with client_type-driven OAuth UI --------- Co-authored-by: Acbox <acbox0328@gmail.com>
This commit is contained in:
@@ -47,6 +47,7 @@
|
||||
</FormItem>
|
||||
</FormField>
|
||||
<FormField
|
||||
v-if="form.values.client_type !== 'openai-codex'"
|
||||
v-slot="{ componentField }"
|
||||
name="api_key"
|
||||
>
|
||||
@@ -68,6 +69,12 @@
|
||||
</FormControl>
|
||||
</FormItem>
|
||||
</FormField>
|
||||
<div
|
||||
v-else
|
||||
class="rounded-lg border p-3 text-sm text-muted-foreground"
|
||||
>
|
||||
{{ $t('provider.oauth.createHint') }}
|
||||
</div>
|
||||
<FormField
|
||||
v-slot="{ componentField }"
|
||||
name="base_url"
|
||||
@@ -161,7 +168,7 @@ import { useDialogMutation } from '@/composables/useDialogMutation'
|
||||
import SearchableSelectPopover from '@/components/searchable-select-popover/index.vue'
|
||||
import { CLIENT_TYPE_LIST, CLIENT_TYPE_META } from '@/constants/client-types'
|
||||
import { toast } from 'vue-sonner'
|
||||
import { computed } from 'vue'
|
||||
import { computed, watch } from 'vue'
|
||||
|
||||
const open = defineModel<boolean>('open')
|
||||
const { t } = useI18n()
|
||||
@@ -208,11 +215,19 @@ const { mutateAsync: createProviderMutation, isLoading } = useMutation({
|
||||
})
|
||||
|
||||
const providerSchema = toTypedSchema(z.object({
|
||||
api_key: z.string().min(1),
|
||||
api_key: z.string().optional(),
|
||||
base_url: z.string().min(1),
|
||||
name: z.string().min(1),
|
||||
client_type: z.string().min(1),
|
||||
auto_import: z.boolean().optional(),
|
||||
}).superRefine((value, ctx) => {
|
||||
if (value.client_type !== 'openai-codex' && !value.api_key?.trim()) {
|
||||
ctx.addIssue({
|
||||
code: z.ZodIssueCode.custom,
|
||||
path: ['api_key'],
|
||||
message: 'API key is required',
|
||||
})
|
||||
}
|
||||
}))
|
||||
|
||||
const form = useForm({
|
||||
@@ -223,12 +238,19 @@ const form = useForm({
|
||||
},
|
||||
})
|
||||
|
||||
watch(() => form.values.client_type, (clientType) => {
|
||||
if (clientType !== 'openai-codex') return
|
||||
if (!form.values.base_url) {
|
||||
form.setFieldValue('base_url', 'https://chatgpt.com/backend-api')
|
||||
}
|
||||
})
|
||||
|
||||
const createProvider = form.handleSubmit(async (value) => {
|
||||
await run(
|
||||
() => createProviderMutation(value),
|
||||
{
|
||||
fallbackMessage: t('common.saveFailed'),
|
||||
onSuccess: () => {
|
||||
onSuccess: () => {
|
||||
open.value = false
|
||||
form.resetForm()
|
||||
},
|
||||
|
||||
@@ -15,6 +15,11 @@ export const CLIENT_TYPE_META: Record<string, ClientTypeMeta> = {
|
||||
label: 'OpenAI Completions',
|
||||
hint: 'Chat Completions API (widely compatible)',
|
||||
},
|
||||
'openai-codex': {
|
||||
value: 'openai-codex',
|
||||
label: 'OpenAI Codex',
|
||||
hint: 'Codex API (OAuth, coding-optimized)',
|
||||
},
|
||||
'anthropic-messages': {
|
||||
value: 'anthropic-messages',
|
||||
label: 'Anthropic Messages',
|
||||
|
||||
@@ -34,9 +34,7 @@
|
||||
"loadFailed": "Failed to load",
|
||||
"saveFailed": "Failed to save",
|
||||
"createdAt": "Created at",
|
||||
"none": "None",
|
||||
"searchTimezone": "Search timezones…",
|
||||
"noTimezoneFound": "No timezone found."
|
||||
"none": "None"
|
||||
},
|
||||
"auth": {
|
||||
"welcome": "Welcome Back",
|
||||
@@ -230,7 +228,27 @@
|
||||
"icon": "Icon",
|
||||
"iconPlaceholder": "Icon URL or identifier (optional)",
|
||||
"enable": "Enable",
|
||||
"enableHint": "Only models from enabled providers appear in the available model list"
|
||||
"enableHint": "Only models from enabled providers appear in the available model list",
|
||||
"oauth": {
|
||||
"title": "OpenAI OAuth",
|
||||
"description": "Authorize this provider with your ChatGPT account for Codex-compatible OpenAI access.",
|
||||
"createHint": "Save the provider first, then authorize it from the provider details panel.",
|
||||
"authorize": "Authorize",
|
||||
"authorizeFailed": "Failed to start authorization",
|
||||
"authorizeSuccess": "Authorization successful",
|
||||
"revoke": "Revoke",
|
||||
"revokeFailed": "Failed to revoke authorization",
|
||||
"revokeSuccess": "Authorization revoked",
|
||||
"callback": "Callback URL",
|
||||
"statusFailed": "Failed to load OAuth status",
|
||||
"status": {
|
||||
"checking": "Checking authorization status...",
|
||||
"authorized": "Authorized",
|
||||
"expired": "Authorization expired. Re-authorize to continue.",
|
||||
"missing": "Not authorized yet.",
|
||||
"notConfigured": "This provider is not configured for OAuth."
|
||||
}
|
||||
}
|
||||
},
|
||||
"searchProvider": {
|
||||
"title": "Search Providers",
|
||||
@@ -784,9 +802,11 @@
|
||||
"language": "Language",
|
||||
"reasoningEnabled": "Enable Reasoning",
|
||||
"reasoningEffort": "Reasoning Effort",
|
||||
"reasoningEffortNone": "None",
|
||||
"reasoningEffortLow": "Low",
|
||||
"reasoningEffortMedium": "Medium",
|
||||
"reasoningEffortHigh": "High",
|
||||
"reasoningEffortXHigh": "X-High",
|
||||
"heartbeatEnabled": "Enable Heartbeat",
|
||||
"heartbeatDescription": "Periodically trigger agent to check for items that need attention",
|
||||
"heartbeatInterval": "Heartbeat Interval (minutes)",
|
||||
|
||||
@@ -34,9 +34,7 @@
|
||||
"loadFailed": "加载失败",
|
||||
"saveFailed": "保存失败",
|
||||
"createdAt": "创建时间",
|
||||
"none": "无",
|
||||
"searchTimezone": "搜索时区…",
|
||||
"noTimezoneFound": "未找到时区"
|
||||
"none": "无"
|
||||
},
|
||||
"auth": {
|
||||
"welcome": "欢迎回来",
|
||||
@@ -226,7 +224,27 @@
|
||||
"icon": "图标",
|
||||
"iconPlaceholder": "图标 URL 或标识(可选)",
|
||||
"enable": "启用",
|
||||
"enableHint": "只有启用的供应商的模型才会出现在可用模型列表中"
|
||||
"enableHint": "只有启用的供应商的模型才会出现在可用模型列表中",
|
||||
"oauth": {
|
||||
"title": "OpenAI OAuth",
|
||||
"description": "使用你的 ChatGPT 账号为该提供商授权,以启用 Codex 兼容的 OpenAI 访问。",
|
||||
"createHint": "请先保存提供商,再到详情面板完成授权。",
|
||||
"authorize": "授权",
|
||||
"authorizeFailed": "启动授权失败",
|
||||
"authorizeSuccess": "授权成功",
|
||||
"revoke": "撤销授权",
|
||||
"revokeFailed": "撤销授权失败",
|
||||
"revokeSuccess": "授权已撤销",
|
||||
"callback": "回调地址",
|
||||
"statusFailed": "加载 OAuth 状态失败",
|
||||
"status": {
|
||||
"checking": "正在检查授权状态...",
|
||||
"authorized": "已授权",
|
||||
"expired": "授权已过期,请重新授权。",
|
||||
"missing": "尚未授权。",
|
||||
"notConfigured": "当前提供商未正确配置 OAuth。"
|
||||
}
|
||||
}
|
||||
},
|
||||
"searchProvider": {
|
||||
"title": "搜索提供方",
|
||||
@@ -780,9 +798,11 @@
|
||||
"language": "语言",
|
||||
"reasoningEnabled": "启用推理",
|
||||
"reasoningEffort": "推理等级",
|
||||
"reasoningEffortNone": "无",
|
||||
"reasoningEffortLow": "低",
|
||||
"reasoningEffortMedium": "中",
|
||||
"reasoningEffortHigh": "高",
|
||||
"reasoningEffortXHigh": "超高",
|
||||
"heartbeatEnabled": "启用心跳",
|
||||
"heartbeatDescription": "定期触发 Agent 检查是否有需要关注的事项",
|
||||
"heartbeatInterval": "心跳间隔(分钟)",
|
||||
|
||||
@@ -197,21 +197,6 @@
|
||||
/>
|
||||
</div>
|
||||
|
||||
<!-- Timezone -->
|
||||
<div class="space-y-2">
|
||||
<Label>{{ $t('bots.timezone') }}</Label>
|
||||
<TimezoneSelect
|
||||
:model-value="timezoneModel"
|
||||
:placeholder="$t('bots.timezonePlaceholder')"
|
||||
allow-empty
|
||||
:empty-label="$t('bots.timezoneInherited')"
|
||||
@update:model-value="onTimezoneChange"
|
||||
/>
|
||||
<p class="text-xs text-muted-foreground">
|
||||
{{ $t('bots.timezoneInheritedHint') }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<Separator />
|
||||
|
||||
<!-- Max Context Load Time -->
|
||||
@@ -272,15 +257,36 @@
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectGroup>
|
||||
<SelectItem value="low">
|
||||
<SelectItem
|
||||
v-if="availableReasoningEfforts.includes('none')"
|
||||
value="none"
|
||||
>
|
||||
{{ $t('bots.settings.reasoningEffortNone') }}
|
||||
</SelectItem>
|
||||
<SelectItem
|
||||
v-if="availableReasoningEfforts.includes('low')"
|
||||
value="low"
|
||||
>
|
||||
{{ $t('bots.settings.reasoningEffortLow') }}
|
||||
</SelectItem>
|
||||
<SelectItem value="medium">
|
||||
<SelectItem
|
||||
v-if="availableReasoningEfforts.includes('medium')"
|
||||
value="medium"
|
||||
>
|
||||
{{ $t('bots.settings.reasoningEffortMedium') }}
|
||||
</SelectItem>
|
||||
<SelectItem value="high">
|
||||
<SelectItem
|
||||
v-if="availableReasoningEfforts.includes('high')"
|
||||
value="high"
|
||||
>
|
||||
{{ $t('bots.settings.reasoningEffortHigh') }}
|
||||
</SelectItem>
|
||||
<SelectItem
|
||||
v-if="availableReasoningEfforts.includes('xhigh')"
|
||||
value="xhigh"
|
||||
>
|
||||
{{ $t('bots.settings.reasoningEffortXHigh') }}
|
||||
</SelectItem>
|
||||
</SelectGroup>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
@@ -349,20 +355,18 @@ import {
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from '@memohai/ui'
|
||||
import { reactive, computed, watch, ref } from 'vue'
|
||||
import { reactive, computed, watch } from 'vue'
|
||||
import { useRouter } from 'vue-router'
|
||||
import { toast } from 'vue-sonner'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import ConfirmPopover from '@/components/confirm-popover/index.vue'
|
||||
import TimezoneSelect from '@/components/timezone-select/index.vue'
|
||||
import { emptyTimezoneValue } from '@/utils/timezones'
|
||||
import ModelSelect from './model-select.vue'
|
||||
import SearchProviderSelect from './search-provider-select.vue'
|
||||
import MemoryProviderSelect from './memory-provider-select.vue'
|
||||
import TtsModelSelect from './tts-model-select.vue'
|
||||
import BrowserContextSelect from './browser-context-select.vue'
|
||||
import { useQuery, useMutation, useQueryCache } from '@pinia/colada'
|
||||
import { getBotsById, putBotsById, getBotsByBotIdSettings, putBotsByBotIdSettings, deleteBotsById, getModels, getProviders, getSearchProviders, getMemoryProviders, getTtsProviders, getBrowserContexts, getBotsByBotIdMemoryStatus, postBotsByBotIdMemoryRebuild } from '@memohai/sdk'
|
||||
import { getBotsByBotIdSettings, putBotsByBotIdSettings, deleteBotsById, getModels, getProviders, getSearchProviders, getMemoryProviders, getTtsProviders, getBrowserContexts, getBotsByBotIdMemoryStatus, postBotsByBotIdMemoryRebuild } from '@memohai/sdk'
|
||||
import type { SettingsSettings } from '@memohai/sdk'
|
||||
import type { Ref } from 'vue'
|
||||
import { resolveApiErrorMessage } from '@/utils/api-error'
|
||||
@@ -379,15 +383,6 @@ const botIdRef = computed(() => props.botId) as Ref<string>
|
||||
// ---- Data ----
|
||||
const queryCache = useQueryCache()
|
||||
|
||||
const { data: bot } = useQuery({
|
||||
key: () => ['bot', botIdRef.value],
|
||||
query: async () => {
|
||||
const { data } = await getBotsById({ path: { id: botIdRef.value }, throwOnError: true })
|
||||
return data
|
||||
},
|
||||
enabled: () => !!botIdRef.value,
|
||||
})
|
||||
|
||||
const { data: settings } = useQuery({
|
||||
key: () => ['bot-settings', botIdRef.value],
|
||||
query: async () => {
|
||||
@@ -503,31 +498,6 @@ const form = reactive({
|
||||
reasoning_effort: 'medium',
|
||||
})
|
||||
|
||||
const timezone = ref('')
|
||||
|
||||
const timezoneModel = computed(() => timezone.value || emptyTimezoneValue)
|
||||
|
||||
function onTimezoneChange(value: string) {
|
||||
timezone.value = value === emptyTimezoneValue ? '' : value
|
||||
}
|
||||
|
||||
watch(bot, (val) => {
|
||||
if (val) {
|
||||
timezone.value = val.timezone || ''
|
||||
}
|
||||
}, { immediate: true })
|
||||
|
||||
const { mutateAsync: updateBot } = useMutation({
|
||||
mutation: async ({ id, ...body }: Record<string, unknown> & { id: string }) => {
|
||||
const { data } = await putBotsById({ path: { id }, body, throwOnError: true })
|
||||
return data
|
||||
},
|
||||
onSettled: () => {
|
||||
queryCache.invalidateQueries({ key: ['bots'] })
|
||||
queryCache.invalidateQueries({ key: ['bot'] })
|
||||
},
|
||||
})
|
||||
|
||||
const selectedMemoryProvider = computed(() =>
|
||||
memoryProviders.value.find((provider) => provider.id === form.memory_provider_id),
|
||||
)
|
||||
@@ -582,6 +552,20 @@ const chatModelSupportsReasoning = computed(() => {
|
||||
return !!m?.config?.compatibilities?.includes('reasoning')
|
||||
})
|
||||
|
||||
const availableReasoningEfforts = computed(() => {
|
||||
if (!form.chat_model_id) return ['low', 'medium', 'high']
|
||||
const model = models.value.find((m) => m.id === form.chat_model_id)
|
||||
const efforts = ((model?.config as { reasoning_efforts?: string[] } | undefined)?.reasoning_efforts ?? [])
|
||||
.filter((effort) => ['none', 'low', 'medium', 'high', 'xhigh'].includes(effort))
|
||||
return efforts.length > 0 ? efforts : ['low', 'medium', 'high']
|
||||
})
|
||||
|
||||
watch(availableReasoningEfforts, (efforts) => {
|
||||
if (!efforts.includes(form.reasoning_effort)) {
|
||||
form.reasoning_effort = efforts.includes('medium') ? 'medium' : efforts[0] ?? 'medium'
|
||||
}
|
||||
}, { immediate: true })
|
||||
|
||||
const { data: memoryStatusData, isLoading: isMemoryStatusLoading } = useQuery({
|
||||
key: () => ['bot-memory-status', botIdRef.value, persistedMemoryProviderID.value],
|
||||
query: async () => {
|
||||
@@ -641,12 +625,10 @@ watch(settings, (val) => {
|
||||
}, { immediate: true })
|
||||
|
||||
const hasChanges = computed(() => {
|
||||
const timezoneChanged = timezone.value !== (bot.value?.timezone || '')
|
||||
if (!settings.value) return true
|
||||
const s = settings.value
|
||||
let changed =
|
||||
timezoneChanged
|
||||
|| form.chat_model_id !== (s.chat_model_id ?? '')
|
||||
form.chat_model_id !== (s.chat_model_id ?? '')
|
||||
|| form.title_model_id !== (s.title_model_id ?? '')
|
||||
|| form.search_provider_id !== (s.search_provider_id ?? '')
|
||||
|| form.memory_provider_id !== (s.memory_provider_id ?? '')
|
||||
@@ -662,11 +644,7 @@ const hasChanges = computed(() => {
|
||||
|
||||
async function handleSave() {
|
||||
try {
|
||||
const promises: Promise<unknown>[] = [updateSettings({ ...form })]
|
||||
if (timezone.value !== (bot.value?.timezone || '')) {
|
||||
promises.push(updateBot({ id: botIdRef.value, timezone: timezone.value }))
|
||||
}
|
||||
await Promise.all(promises)
|
||||
await updateSettings({ ...form })
|
||||
toast.success(t('bots.settings.saveSuccess'))
|
||||
} catch {
|
||||
return
|
||||
|
||||
@@ -30,6 +30,14 @@
|
||||
>
|
||||
{{ $t(`models.compatibility.${cap}`, cap) }}
|
||||
</Badge>
|
||||
<Badge
|
||||
v-for="effort in reasoningEfforts"
|
||||
:key="effort"
|
||||
variant="secondary"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ effort }}
|
||||
</Badge>
|
||||
<span
|
||||
v-if="model.config?.context_window"
|
||||
class="text-xs text-muted-foreground"
|
||||
@@ -101,6 +109,10 @@ import { postModelsByIdTest } from '@memohai/sdk'
|
||||
import type { ModelsGetResponse, ModelsTestResponse } from '@memohai/sdk'
|
||||
import { ref, computed } from 'vue'
|
||||
|
||||
type ModelConfigWithReasoning = {
|
||||
reasoning_efforts?: string[]
|
||||
}
|
||||
|
||||
const props = defineProps<{
|
||||
model: ModelsGetResponse
|
||||
deleteLoading: boolean
|
||||
@@ -113,6 +125,7 @@ defineEmits<{
|
||||
|
||||
const testLoading = ref(false)
|
||||
const testResult = ref<ModelsTestResponse | null>(null)
|
||||
const reasoningEfforts = computed(() => ((props.model.config as ModelConfigWithReasoning | undefined)?.reasoning_efforts ?? []))
|
||||
|
||||
const statusDotClass = computed(() => {
|
||||
switch (testResult.value?.status) {
|
||||
|
||||
@@ -44,7 +44,10 @@
|
||||
</FormField>
|
||||
</section>
|
||||
|
||||
<section class="space-y-2">
|
||||
<section
|
||||
v-if="form.values.client_type !== 'openai-codex'"
|
||||
class="space-y-2"
|
||||
>
|
||||
<h4 class="scroll-m-20 font-semibold tracking-tight">
|
||||
{{ $t('provider.apiKey') }}
|
||||
</h4>
|
||||
@@ -56,7 +59,7 @@
|
||||
<FormControl>
|
||||
<Input
|
||||
type="password"
|
||||
:placeholder="props.provider?.api_key || $t('provider.apiKeyPlaceholder')"
|
||||
:placeholder="providerWithAuth?.api_key || $t('provider.apiKeyPlaceholder')"
|
||||
:aria-label="$t('provider.apiKey')"
|
||||
v-bind="componentField"
|
||||
/>
|
||||
@@ -106,6 +109,67 @@
|
||||
</FormItem>
|
||||
</FormField>
|
||||
</section>
|
||||
|
||||
<section
|
||||
v-if="form.values.client_type === 'openai-codex'"
|
||||
class="rounded-lg border p-4 space-y-3 text-sm"
|
||||
>
|
||||
<div class="space-y-1">
|
||||
<div class="font-medium">
|
||||
{{ $t('provider.oauth.title') }}
|
||||
</div>
|
||||
<div class="text-muted-foreground">
|
||||
{{ $t('provider.oauth.description') }}
|
||||
</div>
|
||||
<div
|
||||
class="text-xs"
|
||||
:class="oauthExpired ? 'text-destructive' : 'text-muted-foreground'"
|
||||
>
|
||||
<template v-if="oauthStatusLoading">
|
||||
{{ $t('provider.oauth.status.checking') }}
|
||||
</template>
|
||||
<template v-else-if="oauthStatus && !oauthStatus.configured">
|
||||
{{ $t('provider.oauth.status.notConfigured') }}
|
||||
</template>
|
||||
<template v-else-if="oauthExpired">
|
||||
{{ $t('provider.oauth.status.expired') }}
|
||||
</template>
|
||||
<template v-else-if="oauthStatus?.has_token">
|
||||
{{ $t('provider.oauth.status.authorized') }}
|
||||
</template>
|
||||
<template v-else>
|
||||
{{ $t('provider.oauth.status.missing') }}
|
||||
</template>
|
||||
</div>
|
||||
<div
|
||||
v-if="oauthStatus?.callback_url"
|
||||
class="text-xs text-muted-foreground"
|
||||
>
|
||||
{{ $t('provider.oauth.callback') }}: {{ oauthStatus.callback_url }}
|
||||
</div>
|
||||
</div>
|
||||
<div class="flex gap-2">
|
||||
<LoadingButton
|
||||
type="button"
|
||||
variant="outline"
|
||||
:disabled="!canAuthorizeOAuth"
|
||||
:loading="authorizeLoading"
|
||||
@click="handleAuthorize"
|
||||
>
|
||||
<FontAwesomeIcon :icon="['fas', 'key']" />
|
||||
{{ $t('provider.oauth.authorize') }}
|
||||
</LoadingButton>
|
||||
<LoadingButton
|
||||
v-if="oauthStatus?.has_token"
|
||||
type="button"
|
||||
variant="ghost"
|
||||
:loading="revokeLoading"
|
||||
@click="handleRevoke"
|
||||
>
|
||||
{{ $t('provider.oauth.revoke') }}
|
||||
</LoadingButton>
|
||||
</div>
|
||||
</section>
|
||||
</div>
|
||||
|
||||
<section class="flex justify-between items-center mt-4">
|
||||
@@ -206,11 +270,22 @@ import { useForm } from 'vee-validate'
|
||||
import { postProvidersByIdTest } from '@memohai/sdk'
|
||||
import type { ProvidersGetResponse, ProvidersTestResponse } from '@memohai/sdk'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import { toast } from 'vue-sonner'
|
||||
|
||||
const { t } = useI18n()
|
||||
|
||||
type ProviderWithAuth = Partial<ProvidersGetResponse>
|
||||
|
||||
type ProviderOAuthStatus = {
|
||||
configured: boolean
|
||||
has_token: boolean
|
||||
expired: boolean
|
||||
callback_url?: string
|
||||
expires_at?: string
|
||||
}
|
||||
|
||||
const props = defineProps<{
|
||||
provider: Partial<ProvidersGetResponse> | undefined
|
||||
provider: ProviderWithAuth | undefined
|
||||
editLoading: boolean
|
||||
deleteLoading: boolean
|
||||
}>()
|
||||
@@ -223,6 +298,13 @@ const emit = defineEmits<{
|
||||
const testLoading = ref(false)
|
||||
const testResult = ref<ProvidersTestResponse | null>(null)
|
||||
const testError = ref('')
|
||||
const oauthStatus = ref<ProviderOAuthStatus | null>(null)
|
||||
const oauthStatusLoading = ref(false)
|
||||
const authorizeLoading = ref(false)
|
||||
const revokeLoading = ref(false)
|
||||
const apiBase = import.meta.env.VITE_API_URL?.trim() || '/api'
|
||||
|
||||
const providerWithAuth = computed(() => props.provider as ProviderWithAuth | undefined)
|
||||
|
||||
async function runTest() {
|
||||
if (!props.provider?.id) return
|
||||
@@ -265,6 +347,14 @@ const providerSchema = toTypedSchema(z.object({
|
||||
metadata: z.object({
|
||||
additionalProp1: z.object({}),
|
||||
}),
|
||||
}).superRefine((value, ctx) => {
|
||||
if (value.client_type !== 'openai-codex' && !value.api_key?.trim() && !providerWithAuth.value?.api_key) {
|
||||
ctx.addIssue({
|
||||
code: z.ZodIssueCode.custom,
|
||||
path: ['api_key'],
|
||||
message: 'API key is required',
|
||||
})
|
||||
}
|
||||
}))
|
||||
|
||||
const form = useForm({
|
||||
@@ -283,6 +373,24 @@ watch(() => props.provider, (newVal) => {
|
||||
}
|
||||
}, { immediate: true })
|
||||
|
||||
watch(() => form.values.client_type, (clientType) => {
|
||||
if (clientType !== 'openai-codex') {
|
||||
oauthStatus.value = null
|
||||
return
|
||||
}
|
||||
if (!form.values.base_url) {
|
||||
form.setFieldValue('base_url', 'https://chatgpt.com/backend-api')
|
||||
}
|
||||
})
|
||||
|
||||
watch(() => [props.provider?.id, form.values.client_type] as const, async ([id, clientType]) => {
|
||||
if (!id || clientType !== 'openai-codex') {
|
||||
oauthStatus.value = null
|
||||
return
|
||||
}
|
||||
await fetchOAuthStatus()
|
||||
}, { immediate: true })
|
||||
|
||||
const hasChanges = computed(() => {
|
||||
const raw = props.provider
|
||||
const baseChanged = JSON.stringify({
|
||||
@@ -316,4 +424,78 @@ const editProvider = form.handleSubmit(async (value) => {
|
||||
}
|
||||
emit('submit', payload)
|
||||
})
|
||||
|
||||
const oauthExpired = computed(() => Boolean(oauthStatus.value?.has_token && oauthStatus.value?.expired))
|
||||
const canAuthorizeOAuth = computed(() =>
|
||||
Boolean(
|
||||
props.provider?.id
|
||||
&& form.values.client_type === 'openai-codex',
|
||||
) && !oauthStatusLoading.value,
|
||||
)
|
||||
|
||||
function authHeaders(): Record<string, string> {
|
||||
const token = localStorage.getItem('token')
|
||||
return token ? { Authorization: `Bearer ${token}` } : {}
|
||||
}
|
||||
|
||||
async function fetchOAuthStatus() {
|
||||
if (!props.provider?.id) return
|
||||
oauthStatusLoading.value = true
|
||||
try {
|
||||
const response = await fetch(`${apiBase}/providers/${props.provider.id}/oauth/status`, {
|
||||
headers: authHeaders(),
|
||||
})
|
||||
if (!response.ok) throw new Error(t('provider.oauth.statusFailed'))
|
||||
oauthStatus.value = await response.json() as ProviderOAuthStatus
|
||||
} catch (error) {
|
||||
oauthStatus.value = null
|
||||
console.error('failed to load provider oauth status', error)
|
||||
} finally {
|
||||
oauthStatusLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function handleAuthorize() {
|
||||
if (!props.provider?.id) return
|
||||
authorizeLoading.value = true
|
||||
try {
|
||||
const response = await fetch(`${apiBase}/providers/${props.provider.id}/oauth/authorize`, {
|
||||
headers: authHeaders(),
|
||||
})
|
||||
if (!response.ok) throw new Error(t('provider.oauth.authorizeFailed'))
|
||||
const data = await response.json() as { auth_url?: string }
|
||||
if (!data.auth_url) throw new Error(t('provider.oauth.authorizeFailed'))
|
||||
const popup = window.open(data.auth_url, 'provider-oauth', 'width=600,height=720')
|
||||
const listener = async (event: MessageEvent) => {
|
||||
if (event.data?.type !== 'memoh-provider-oauth-success') return
|
||||
window.removeEventListener('message', listener)
|
||||
popup?.close()
|
||||
toast.success(t('provider.oauth.authorizeSuccess'))
|
||||
await fetchOAuthStatus()
|
||||
}
|
||||
window.addEventListener('message', listener)
|
||||
} catch (error) {
|
||||
toast.error(error instanceof Error ? error.message : t('provider.oauth.authorizeFailed'))
|
||||
} finally {
|
||||
authorizeLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function handleRevoke() {
|
||||
if (!props.provider?.id) return
|
||||
revokeLoading.value = true
|
||||
try {
|
||||
const response = await fetch(`${apiBase}/providers/${props.provider.id}/oauth/token`, {
|
||||
method: 'DELETE',
|
||||
headers: authHeaders(),
|
||||
})
|
||||
if (!response.ok) throw new Error(t('provider.oauth.revokeFailed'))
|
||||
toast.success(t('provider.oauth.revokeSuccess'))
|
||||
await fetchOAuthStatus()
|
||||
} catch (error) {
|
||||
toast.error(error instanceof Error ? error.message : t('provider.oauth.revokeFailed'))
|
||||
} finally {
|
||||
revokeLoading.value = false
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
+11
-1
@@ -165,7 +165,7 @@ func runServe() {
|
||||
accounts.NewService,
|
||||
acl.NewService,
|
||||
settings.NewService,
|
||||
providers.NewService,
|
||||
provideProvidersService,
|
||||
searchproviders.NewService,
|
||||
browsercontexts.NewService,
|
||||
policy.NewService,
|
||||
@@ -228,6 +228,7 @@ func runServe() {
|
||||
provideServerHandler(provideSessionHandler),
|
||||
provideServerHandler(handlers.NewSwaggerHandler),
|
||||
provideServerHandler(handlers.NewProvidersHandler),
|
||||
provideServerHandler(handlers.NewProviderOAuthHandler),
|
||||
provideServerHandler(handlers.NewSearchProvidersHandler),
|
||||
provideServerHandler(handlers.NewModelsHandler),
|
||||
provideServerHandler(handlers.NewSettingsHandler),
|
||||
@@ -765,6 +766,15 @@ func provideEmailRegistry(log *slog.Logger, tokenStore *emailpkg.DBOAuthTokenSto
|
||||
return reg
|
||||
}
|
||||
|
||||
func provideProvidersService(log *slog.Logger, queries *dbsqlc.Queries, cfg config.Config) *providers.Service {
|
||||
_ = cfg
|
||||
return providers.NewService(log, queries, defaultProviderOAuthCallbackURL())
|
||||
}
|
||||
|
||||
func defaultProviderOAuthCallbackURL() string {
|
||||
return "http://localhost:1455/auth/callback"
|
||||
}
|
||||
|
||||
func provideEmailOAuthHandler(log *slog.Logger, service *emailpkg.Service, tokenStore *emailpkg.DBOAuthTokenStore, cfg config.Config) *handlers.EmailOAuthHandler {
|
||||
addr := strings.TrimSpace(cfg.Server.Addr)
|
||||
if addr == "" {
|
||||
|
||||
+11
-1
@@ -106,7 +106,7 @@ func runServe() {
|
||||
accounts.NewService,
|
||||
acl.NewService,
|
||||
settings.NewService,
|
||||
providers.NewService,
|
||||
provideProvidersService,
|
||||
searchproviders.NewService,
|
||||
policy.NewService,
|
||||
mcp.NewConnectionService,
|
||||
@@ -154,6 +154,7 @@ func runServe() {
|
||||
provideServerHandler(provideSessionHandler),
|
||||
provideServerHandler(handlers.NewSwaggerHandler),
|
||||
provideServerHandler(handlers.NewProvidersHandler),
|
||||
provideServerHandler(handlers.NewProviderOAuthHandler),
|
||||
provideServerHandler(handlers.NewSearchProvidersHandler),
|
||||
provideServerHandler(handlers.NewModelsHandler),
|
||||
provideServerHandler(handlers.NewSettingsHandler),
|
||||
@@ -858,6 +859,15 @@ func provideEmailRegistry(log *slog.Logger, tokenStore *emailpkg.DBOAuthTokenSto
|
||||
return reg
|
||||
}
|
||||
|
||||
func provideProvidersService(log *slog.Logger, queries *dbsqlc.Queries, cfg config.Config) *providers.Service {
|
||||
_ = cfg
|
||||
return providers.NewService(log, queries, defaultProviderOAuthCallbackURL())
|
||||
}
|
||||
|
||||
func defaultProviderOAuthCallbackURL() string {
|
||||
return "http://localhost:1455/auth/callback"
|
||||
}
|
||||
|
||||
func provideEmailOAuthHandler(log *slog.Logger, service *emailpkg.Service, tokenStore *emailpkg.DBOAuthTokenStore, cfg config.Config) *handlers.EmailOAuthHandler {
|
||||
addr := strings.TrimSpace(cfg.Server.Addr)
|
||||
if addr == "" {
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
name: OpenAI Codex
|
||||
client_type: openai-codex
|
||||
icon: openai
|
||||
base_url: https://chatgpt.com/backend-api
|
||||
|
||||
models:
|
||||
- model_id: gpt-5.2
|
||||
name: GPT-5.2
|
||||
type: chat
|
||||
config:
|
||||
compatibilities: [tool-call, reasoning]
|
||||
|
||||
- model_id: gpt-5.2-codex
|
||||
name: GPT-5.2 Codex
|
||||
type: chat
|
||||
config:
|
||||
compatibilities: [tool-call, reasoning]
|
||||
|
||||
- model_id: gpt-5.1-codex
|
||||
name: GPT-5.1 Codex
|
||||
type: chat
|
||||
config:
|
||||
compatibilities: [tool-call, reasoning]
|
||||
|
||||
- model_id: gpt-5.1-codex-max
|
||||
name: GPT-5.1 Codex Max
|
||||
type: chat
|
||||
config:
|
||||
compatibilities: [tool-call, reasoning]
|
||||
|
||||
- model_id: gpt-5.1-codex-mini
|
||||
name: GPT-5.1 Codex Mini
|
||||
type: chat
|
||||
config:
|
||||
compatibilities: [tool-call, reasoning]
|
||||
|
||||
- model_id: gpt-5.1
|
||||
name: GPT-5.1
|
||||
type: chat
|
||||
config:
|
||||
compatibilities: [tool-call, reasoning]
|
||||
@@ -69,7 +69,7 @@ CREATE TABLE IF NOT EXISTS llm_providers (
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
CONSTRAINT llm_providers_name_unique UNIQUE (name),
|
||||
CONSTRAINT llm_providers_client_type_check CHECK (client_type IN ('openai-responses', 'openai-completions', 'anthropic-messages', 'google-generative-ai'))
|
||||
CONSTRAINT llm_providers_client_type_check CHECK (client_type IN ('openai-responses', 'openai-completions', 'anthropic-messages', 'google-generative-ai', 'openai-codex'))
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS search_providers (
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
-- 0046_llm_provider_oauth (rollback)
|
||||
-- Remove OAuth token storage for LLM providers.
|
||||
|
||||
DROP INDEX IF EXISTS idx_llm_provider_oauth_tokens_state;
|
||||
DROP TABLE IF EXISTS llm_provider_oauth_tokens;
|
||||
@@ -0,0 +1,18 @@
|
||||
-- 0046_llm_provider_oauth
|
||||
-- Add OAuth token storage for LLM providers to support OpenAI Codex OAuth.
|
||||
|
||||
CREATE TABLE IF NOT EXISTS llm_provider_oauth_tokens (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
llm_provider_id UUID NOT NULL UNIQUE REFERENCES llm_providers(id) ON DELETE CASCADE,
|
||||
access_token TEXT NOT NULL DEFAULT '',
|
||||
refresh_token TEXT NOT NULL DEFAULT '',
|
||||
expires_at TIMESTAMPTZ,
|
||||
scope TEXT NOT NULL DEFAULT '',
|
||||
token_type TEXT NOT NULL DEFAULT '',
|
||||
state TEXT NOT NULL DEFAULT '',
|
||||
pkce_code_verifier TEXT NOT NULL DEFAULT '',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_llm_provider_oauth_tokens_state ON llm_provider_oauth_tokens(state) WHERE state != '';
|
||||
@@ -0,0 +1,11 @@
|
||||
-- 0047_add_openai_codex_client_type (rollback)
|
||||
-- Revert openai-codex rows back to openai-responses and restore the old CHECK constraint.
|
||||
|
||||
UPDATE llm_providers
|
||||
SET client_type = 'openai-responses',
|
||||
updated_at = now()
|
||||
WHERE client_type = 'openai-codex';
|
||||
|
||||
ALTER TABLE llm_providers DROP CONSTRAINT IF EXISTS llm_providers_client_type_check;
|
||||
ALTER TABLE llm_providers ADD CONSTRAINT llm_providers_client_type_check
|
||||
CHECK (client_type IN ('openai-responses', 'openai-completions', 'anthropic-messages', 'google-generative-ai'));
|
||||
@@ -0,0 +1,12 @@
|
||||
-- 0047_add_openai_codex_client_type
|
||||
-- Add openai-codex as a first-class client_type and migrate existing codex-oauth providers.
|
||||
|
||||
ALTER TABLE llm_providers DROP CONSTRAINT IF EXISTS llm_providers_client_type_check;
|
||||
ALTER TABLE llm_providers ADD CONSTRAINT llm_providers_client_type_check
|
||||
CHECK (client_type IN ('openai-responses', 'openai-completions', 'anthropic-messages', 'google-generative-ai', 'openai-codex'));
|
||||
|
||||
UPDATE llm_providers
|
||||
SET client_type = 'openai-codex',
|
||||
updated_at = now()
|
||||
WHERE client_type = 'openai-responses'
|
||||
AND metadata->>'auth_type' = 'openai-codex-oauth';
|
||||
@@ -0,0 +1,52 @@
|
||||
-- name: UpsertLlmProviderOAuthToken :one
|
||||
INSERT INTO llm_provider_oauth_tokens (
|
||||
llm_provider_id,
|
||||
access_token,
|
||||
refresh_token,
|
||||
expires_at,
|
||||
scope,
|
||||
token_type,
|
||||
state,
|
||||
pkce_code_verifier
|
||||
)
|
||||
VALUES (
|
||||
sqlc.arg(llm_provider_id),
|
||||
sqlc.arg(access_token),
|
||||
sqlc.arg(refresh_token),
|
||||
sqlc.arg(expires_at),
|
||||
sqlc.arg(scope),
|
||||
sqlc.arg(token_type),
|
||||
sqlc.arg(state),
|
||||
sqlc.arg(pkce_code_verifier)
|
||||
)
|
||||
ON CONFLICT (llm_provider_id) DO UPDATE SET
|
||||
access_token = EXCLUDED.access_token,
|
||||
refresh_token = EXCLUDED.refresh_token,
|
||||
expires_at = EXCLUDED.expires_at,
|
||||
scope = EXCLUDED.scope,
|
||||
token_type = EXCLUDED.token_type,
|
||||
state = EXCLUDED.state,
|
||||
pkce_code_verifier = EXCLUDED.pkce_code_verifier,
|
||||
updated_at = now()
|
||||
RETURNING *;
|
||||
|
||||
-- name: GetLlmProviderOAuthTokenByProvider :one
|
||||
SELECT * FROM llm_provider_oauth_tokens WHERE llm_provider_id = sqlc.arg(llm_provider_id);
|
||||
|
||||
-- name: GetLlmProviderOAuthTokenByState :one
|
||||
SELECT * FROM llm_provider_oauth_tokens WHERE state = sqlc.arg(state) AND state != '';
|
||||
|
||||
-- name: UpdateLlmProviderOAuthState :exec
|
||||
INSERT INTO llm_provider_oauth_tokens (llm_provider_id, state, pkce_code_verifier)
|
||||
VALUES (
|
||||
sqlc.arg(llm_provider_id),
|
||||
sqlc.arg(state),
|
||||
sqlc.arg(pkce_code_verifier)
|
||||
)
|
||||
ON CONFLICT (llm_provider_id) DO UPDATE SET
|
||||
state = EXCLUDED.state,
|
||||
pkce_code_verifier = EXCLUDED.pkce_code_verifier,
|
||||
updated_at = now();
|
||||
|
||||
-- name: DeleteLlmProviderOAuthToken :exec
|
||||
DELETE FROM llm_provider_oauth_tokens WHERE llm_provider_id = sqlc.arg(llm_provider_id);
|
||||
@@ -59,6 +59,6 @@ RUN chmod +x /entrypoint.sh
|
||||
|
||||
VOLUME ["/var/lib/containerd", "/opt/memoh/data"]
|
||||
|
||||
EXPOSE 8080
|
||||
EXPOSE 8080 1455
|
||||
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
|
||||
@@ -95,6 +95,7 @@ services:
|
||||
- /etc/localtime:/etc/localtime:ro
|
||||
ports:
|
||||
- "${MEMOH_DEV_SERVER_PORT:-18080}:8080"
|
||||
- "${MEMOH_DEV_OAUTH_PORT:-1455}:8080"
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "wget --no-verbose --tries=1 --spider http://127.0.0.1:8080/health || exit 1"]
|
||||
interval: 5s
|
||||
|
||||
@@ -47,6 +47,7 @@ services:
|
||||
- /etc/localtime:/etc/localtime:ro
|
||||
ports:
|
||||
- "8080:8080"
|
||||
- "1455:8080"
|
||||
depends_on:
|
||||
migrate:
|
||||
condition: service_completed_successfully
|
||||
|
||||
@@ -117,7 +117,7 @@ RUN mkdir -p /opt/memoh/data /run/containerd /var/lib/containerd
|
||||
|
||||
VOLUME ["/var/lib/containerd", "/opt/memoh/data"]
|
||||
|
||||
EXPOSE 8080
|
||||
EXPOSE 8080 1455
|
||||
|
||||
HEALTHCHECK --interval=30s --timeout=3s --start-period=30s --retries=3 \
|
||||
CMD wget --no-verbose --tries=1 --spider http://127.0.0.1:8080/health \
|
||||
|
||||
@@ -26,7 +26,7 @@ require (
|
||||
github.com/larksuite/oapi-sdk-go/v3 v3.5.3
|
||||
github.com/mailgun/mailgun-go/v5 v5.14.0
|
||||
github.com/memohai/acgo v0.0.0-20260221232113-babac0d6acd7
|
||||
github.com/memohai/twilight-ai v0.3.3-0.20260321100646-43c789b701dd
|
||||
github.com/memohai/twilight-ai v0.3.4-0.20260326121718-a9628b948584
|
||||
github.com/modelcontextprotocol/go-sdk v1.4.1
|
||||
github.com/opencontainers/image-spec v1.1.1
|
||||
github.com/opencontainers/runtime-spec v1.3.0
|
||||
@@ -36,6 +36,7 @@ require (
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/swaggo/swag v1.16.6
|
||||
github.com/wneessen/go-mail v0.7.2
|
||||
github.com/yuin/goldmark v1.7.13
|
||||
go.uber.org/fx v1.24.0
|
||||
golang.org/x/crypto v0.48.0
|
||||
golang.org/x/oauth2 v0.35.0
|
||||
@@ -123,7 +124,6 @@ require (
|
||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
github.com/valyala/fasttemplate v1.2.2 // indirect
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
github.com/yuin/goldmark v1.7.13 // indirect
|
||||
go.opencensus.io v0.24.0 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.65.0 // indirect
|
||||
@@ -143,3 +143,4 @@ require (
|
||||
golang.org/x/tools v0.42.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260209200024-4cfbd4190f57 // indirect
|
||||
)
|
||||
|
||||
|
||||
@@ -228,8 +228,8 @@ github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D
|
||||
github.com/mattn/go-runewidth v0.0.10/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
|
||||
github.com/memohai/acgo v0.0.0-20260221232113-babac0d6acd7 h1:beehwOQperqGWj4m4EhcPhnSZKtDiuHK/7ZMoTPaQjw=
|
||||
github.com/memohai/acgo v0.0.0-20260221232113-babac0d6acd7/go.mod h1:OvmxM7JmnXBmwJWWVqtreL3HSHSKuzPbtbhlg5MvBg0=
|
||||
github.com/memohai/twilight-ai v0.3.3-0.20260321100646-43c789b701dd h1:uV7xsqYHYpEmT6xKvkOs5mHT5oEKnwV1F93ialqi78k=
|
||||
github.com/memohai/twilight-ai v0.3.3-0.20260321100646-43c789b701dd/go.mod h1:vHNoRb6/quMacMAgIp838aoiNhsZbE0bFCnRRNyRwNc=
|
||||
github.com/memohai/twilight-ai v0.3.4-0.20260326121718-a9628b948584 h1:zu4T54unBe8ziIlTr8gUtFoR16c6u2G1Qpx8OGsenxo=
|
||||
github.com/memohai/twilight-ai v0.3.4-0.20260326121718-a9628b948584/go.mod h1:GZTT9GUT3uSs6zram/FcF24GLTZMFSpiybbYmjr+gH8=
|
||||
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
|
||||
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
|
||||
github.com/moby/locker v1.0.1 h1:fOXqR41zeveg4fFODix+1Ch4mj/gT0NE1XJbp/epuBg=
|
||||
|
||||
@@ -2,6 +2,7 @@ package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
sdk "github.com/memohai/twilight-ai/sdk"
|
||||
|
||||
@@ -69,14 +70,17 @@ func SpawnSystemPrompt(sessionType string) string {
|
||||
})
|
||||
}
|
||||
|
||||
// SpawnModelCreatorFunc returns a tools.ModelCreator that delegates to models.NewSDKChatModel.
|
||||
// SpawnModelCreatorFunc returns a tools.ModelCreator backed by the shared SDK model factory.
|
||||
// This keeps subagent model creation aligned with the shared SDK model factory.
|
||||
func SpawnModelCreatorFunc() tools.ModelCreator {
|
||||
return func(modelID, clientType, apiKey, baseURL string) *sdk.Model {
|
||||
return func(modelID, clientType, apiKey, codexAccountID, baseURL string, httpClient *http.Client) *sdk.Model {
|
||||
return models.NewSDKChatModel(models.SDKModelConfig{
|
||||
ModelID: modelID,
|
||||
ClientType: clientType,
|
||||
APIKey: apiKey,
|
||||
BaseURL: baseURL,
|
||||
ModelID: modelID,
|
||||
ClientType: clientType,
|
||||
APIKey: apiKey,
|
||||
CodexAccountID: codexAccountID,
|
||||
BaseURL: baseURL,
|
||||
HTTPClient: httpClient,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@@ -14,6 +15,7 @@ import (
|
||||
"github.com/memohai/memoh/internal/db/sqlc"
|
||||
messagepkg "github.com/memohai/memoh/internal/message"
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
"github.com/memohai/memoh/internal/providers"
|
||||
sessionpkg "github.com/memohai/memoh/internal/session"
|
||||
"github.com/memohai/memoh/internal/settings"
|
||||
)
|
||||
@@ -292,7 +294,7 @@ func (p *SpawnProvider) persistMessages(
|
||||
}
|
||||
|
||||
// ModelCreator creates an sdk.Model from provider config. Set via SetModelCreator.
|
||||
type ModelCreator func(modelID, clientType, apiKey, baseURL string) *sdk.Model
|
||||
type ModelCreator func(modelID, clientType, apiKey, codexAccountID, baseURL string, httpClient *http.Client) *sdk.Model
|
||||
|
||||
// SetModelCreator injects the function used to create SDK models
|
||||
// (typically agent.CreateModel wrapped to match the signature).
|
||||
@@ -323,7 +325,19 @@ func (p *SpawnProvider) resolveModel(ctx context.Context, botID string) (*sdk.Mo
|
||||
if p.modelCreator == nil {
|
||||
return nil, "", errors.New("model creator not configured")
|
||||
}
|
||||
sdkModel := p.modelCreator(modelInfo.ModelID, provider.ClientType, provider.ApiKey, provider.BaseUrl)
|
||||
authResolver := providers.NewService(nil, p.queries, "")
|
||||
creds, err := authResolver.ResolveModelCredentials(ctx, provider)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
sdkModel := p.modelCreator(
|
||||
modelInfo.ModelID,
|
||||
provider.ClientType,
|
||||
creds.APIKey,
|
||||
creds.CodexAccountID,
|
||||
provider.BaseUrl,
|
||||
nil,
|
||||
)
|
||||
return sdkModel, modelInfo.ID, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package agent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
sdk "github.com/memohai/twilight-ai/sdk"
|
||||
@@ -95,6 +96,23 @@ type SystemFile struct {
|
||||
Content string
|
||||
}
|
||||
|
||||
// ModelConfig holds provider and model information resolved from DB.
|
||||
type ModelConfig struct {
|
||||
ModelID string
|
||||
ClientType string
|
||||
APIKey string //nolint:gosec // carries provider credential material at runtime
|
||||
CodexAccountID string
|
||||
BaseURL string
|
||||
HTTPClient *http.Client
|
||||
ReasoningConfig *ReasoningConfig
|
||||
}
|
||||
|
||||
// ReasoningConfig controls extended thinking/reasoning behavior.
|
||||
type ReasoningConfig struct {
|
||||
Enabled bool
|
||||
Effort string
|
||||
}
|
||||
|
||||
func mustMarshal(v any) json.RawMessage {
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
|
||||
@@ -104,10 +104,12 @@ func (s *Service) doCompaction(ctx context.Context, logID pgtype.UUID, sessionUU
|
||||
userPrompt := buildUserPrompt(priorSummaries, entries)
|
||||
|
||||
model := models.NewSDKChatModel(models.SDKModelConfig{
|
||||
ClientType: cfg.ClientType,
|
||||
BaseURL: cfg.BaseURL,
|
||||
APIKey: cfg.APIKey,
|
||||
ModelID: cfg.ModelID,
|
||||
ClientType: cfg.ClientType,
|
||||
BaseURL: cfg.BaseURL,
|
||||
APIKey: cfg.APIKey,
|
||||
CodexAccountID: cfg.CodexAccountID,
|
||||
ModelID: cfg.ModelID,
|
||||
HTTPClient: cfg.HTTPClient,
|
||||
})
|
||||
|
||||
result, err := sdk.GenerateTextResult(ctx,
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package compaction
|
||||
|
||||
import "time"
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Log represents a compaction log entry.
|
||||
type Log struct {
|
||||
@@ -24,10 +27,12 @@ type ListLogsResponse struct {
|
||||
|
||||
// TriggerConfig holds the parameters needed to trigger a compaction.
|
||||
type TriggerConfig struct {
|
||||
BotID string
|
||||
SessionID string
|
||||
ModelID string
|
||||
ClientType string
|
||||
APIKey string //nolint:gosec // runtime credential, not a hardcoded secret
|
||||
BaseURL string
|
||||
BotID string
|
||||
SessionID string
|
||||
ModelID string
|
||||
ClientType string
|
||||
APIKey string //nolint:gosec // runtime credential, not a hardcoded secret
|
||||
CodexAccountID string
|
||||
BaseURL string
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math"
|
||||
@@ -24,6 +25,7 @@ import (
|
||||
messagepkg "github.com/memohai/memoh/internal/message"
|
||||
messageevent "github.com/memohai/memoh/internal/message/event"
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
"github.com/memohai/memoh/internal/providers"
|
||||
"github.com/memohai/memoh/internal/settings"
|
||||
)
|
||||
|
||||
@@ -278,10 +280,17 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r
|
||||
}
|
||||
}
|
||||
|
||||
authResolver := providers.NewService(nil, r.queries, "")
|
||||
creds, err := authResolver.ResolveModelCredentials(ctx, provider)
|
||||
if err != nil {
|
||||
return resolvedContext{}, fmt.Errorf("resolve provider credentials: %w", err)
|
||||
}
|
||||
|
||||
modelCfg := models.SDKModelConfig{
|
||||
ModelID: chatModel.ModelID,
|
||||
ClientType: clientType,
|
||||
APIKey: provider.ApiKey,
|
||||
APIKey: creds.APIKey,
|
||||
CodexAccountID: creds.CodexAccountID,
|
||||
BaseURL: provider.BaseUrl,
|
||||
ReasoningConfig: reasoningConfig,
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"github.com/memohai/memoh/internal/compaction"
|
||||
"github.com/memohai/memoh/internal/conversation"
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
"github.com/memohai/memoh/internal/providers"
|
||||
)
|
||||
|
||||
func (r *Resolver) maybeCompact(ctx context.Context, req conversation.ChatRequest, rc resolvedContext, inputTokens int) {
|
||||
@@ -47,8 +48,15 @@ func (r *Resolver) maybeCompact(ctx context.Context, req conversation.ChatReques
|
||||
r.logger.Warn("compaction: failed to fetch provider", slog.Any("error", err))
|
||||
return
|
||||
}
|
||||
authResolver := providers.NewService(nil, r.queries, "")
|
||||
creds, err := authResolver.ResolveModelCredentials(ctx, provider)
|
||||
if err != nil {
|
||||
r.logger.Warn("compaction: failed to resolve provider credentials", slog.Any("error", err))
|
||||
return
|
||||
}
|
||||
cfg.ClientType = provider.ClientType
|
||||
cfg.APIKey = provider.ApiKey
|
||||
cfg.APIKey = creds.APIKey
|
||||
cfg.CodexAccountID = creds.CodexAccountID
|
||||
cfg.BaseURL = provider.BaseUrl
|
||||
|
||||
r.compactionService.TriggerCompaction(ctx, cfg)
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/memohai/memoh/internal/db/sqlc"
|
||||
messageevent "github.com/memohai/memoh/internal/message/event"
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
"github.com/memohai/memoh/internal/providers"
|
||||
"github.com/memohai/memoh/internal/session"
|
||||
)
|
||||
|
||||
@@ -104,11 +105,19 @@ func (r *Resolver) generateTitle(ctx context.Context, model models.GetResponse,
|
||||
"Return ONLY the title text, nothing else.\n\n" +
|
||||
"User: " + userSnippet
|
||||
|
||||
authResolver := providers.NewService(nil, r.queries, "")
|
||||
creds, err := authResolver.ResolveModelCredentials(ctx, provider)
|
||||
if err != nil {
|
||||
r.logger.Warn("title gen: failed to resolve provider credentials", slog.Any("error", err))
|
||||
return ""
|
||||
}
|
||||
|
||||
modelCfg := models.SDKModelConfig{
|
||||
ModelID: model.ModelID,
|
||||
ClientType: provider.ClientType,
|
||||
APIKey: provider.ApiKey,
|
||||
BaseURL: provider.BaseUrl,
|
||||
ModelID: model.ModelID,
|
||||
ClientType: provider.ClientType,
|
||||
APIKey: creds.APIKey,
|
||||
CodexAccountID: creds.CodexAccountID,
|
||||
BaseURL: provider.BaseUrl,
|
||||
}
|
||||
sdkModel := models.NewSDKChatModel(modelCfg)
|
||||
|
||||
|
||||
@@ -0,0 +1,163 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// source: llm_provider_oauth.sql
|
||||
|
||||
package sqlc
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
const deleteLlmProviderOAuthToken = `-- name: DeleteLlmProviderOAuthToken :exec
|
||||
DELETE FROM llm_provider_oauth_tokens WHERE llm_provider_id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) DeleteLlmProviderOAuthToken(ctx context.Context, llmProviderID pgtype.UUID) error {
|
||||
_, err := q.db.Exec(ctx, deleteLlmProviderOAuthToken, llmProviderID)
|
||||
return err
|
||||
}
|
||||
|
||||
const getLlmProviderOAuthTokenByProvider = `-- name: GetLlmProviderOAuthTokenByProvider :one
|
||||
SELECT id, llm_provider_id, access_token, refresh_token, expires_at, scope, token_type, state, pkce_code_verifier, created_at, updated_at FROM llm_provider_oauth_tokens WHERE llm_provider_id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetLlmProviderOAuthTokenByProvider(ctx context.Context, llmProviderID pgtype.UUID) (LlmProviderOauthToken, error) {
|
||||
row := q.db.QueryRow(ctx, getLlmProviderOAuthTokenByProvider, llmProviderID)
|
||||
var i LlmProviderOauthToken
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.LlmProviderID,
|
||||
&i.AccessToken,
|
||||
&i.RefreshToken,
|
||||
&i.ExpiresAt,
|
||||
&i.Scope,
|
||||
&i.TokenType,
|
||||
&i.State,
|
||||
&i.PkceCodeVerifier,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getLlmProviderOAuthTokenByState = `-- name: GetLlmProviderOAuthTokenByState :one
|
||||
SELECT id, llm_provider_id, access_token, refresh_token, expires_at, scope, token_type, state, pkce_code_verifier, created_at, updated_at FROM llm_provider_oauth_tokens WHERE state = $1 AND state != ''
|
||||
`
|
||||
|
||||
func (q *Queries) GetLlmProviderOAuthTokenByState(ctx context.Context, state string) (LlmProviderOauthToken, error) {
|
||||
row := q.db.QueryRow(ctx, getLlmProviderOAuthTokenByState, state)
|
||||
var i LlmProviderOauthToken
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.LlmProviderID,
|
||||
&i.AccessToken,
|
||||
&i.RefreshToken,
|
||||
&i.ExpiresAt,
|
||||
&i.Scope,
|
||||
&i.TokenType,
|
||||
&i.State,
|
||||
&i.PkceCodeVerifier,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const updateLlmProviderOAuthState = `-- name: UpdateLlmProviderOAuthState :exec
|
||||
INSERT INTO llm_provider_oauth_tokens (llm_provider_id, state, pkce_code_verifier)
|
||||
VALUES (
|
||||
$1,
|
||||
$2,
|
||||
$3
|
||||
)
|
||||
ON CONFLICT (llm_provider_id) DO UPDATE SET
|
||||
state = EXCLUDED.state,
|
||||
pkce_code_verifier = EXCLUDED.pkce_code_verifier,
|
||||
updated_at = now()
|
||||
`
|
||||
|
||||
type UpdateLlmProviderOAuthStateParams struct {
|
||||
LlmProviderID pgtype.UUID `json:"llm_provider_id"`
|
||||
State string `json:"state"`
|
||||
PkceCodeVerifier string `json:"pkce_code_verifier"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateLlmProviderOAuthState(ctx context.Context, arg UpdateLlmProviderOAuthStateParams) error {
|
||||
_, err := q.db.Exec(ctx, updateLlmProviderOAuthState, arg.LlmProviderID, arg.State, arg.PkceCodeVerifier)
|
||||
return err
|
||||
}
|
||||
|
||||
const upsertLlmProviderOAuthToken = `-- name: UpsertLlmProviderOAuthToken :one
|
||||
INSERT INTO llm_provider_oauth_tokens (
|
||||
llm_provider_id,
|
||||
access_token,
|
||||
refresh_token,
|
||||
expires_at,
|
||||
scope,
|
||||
token_type,
|
||||
state,
|
||||
pkce_code_verifier
|
||||
)
|
||||
VALUES (
|
||||
$1,
|
||||
$2,
|
||||
$3,
|
||||
$4,
|
||||
$5,
|
||||
$6,
|
||||
$7,
|
||||
$8
|
||||
)
|
||||
ON CONFLICT (llm_provider_id) DO UPDATE SET
|
||||
access_token = EXCLUDED.access_token,
|
||||
refresh_token = EXCLUDED.refresh_token,
|
||||
expires_at = EXCLUDED.expires_at,
|
||||
scope = EXCLUDED.scope,
|
||||
token_type = EXCLUDED.token_type,
|
||||
state = EXCLUDED.state,
|
||||
pkce_code_verifier = EXCLUDED.pkce_code_verifier,
|
||||
updated_at = now()
|
||||
RETURNING id, llm_provider_id, access_token, refresh_token, expires_at, scope, token_type, state, pkce_code_verifier, created_at, updated_at
|
||||
`
|
||||
|
||||
type UpsertLlmProviderOAuthTokenParams struct {
|
||||
LlmProviderID pgtype.UUID `json:"llm_provider_id"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
|
||||
Scope string `json:"scope"`
|
||||
TokenType string `json:"token_type"`
|
||||
State string `json:"state"`
|
||||
PkceCodeVerifier string `json:"pkce_code_verifier"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpsertLlmProviderOAuthToken(ctx context.Context, arg UpsertLlmProviderOAuthTokenParams) (LlmProviderOauthToken, error) {
|
||||
row := q.db.QueryRow(ctx, upsertLlmProviderOAuthToken,
|
||||
arg.LlmProviderID,
|
||||
arg.AccessToken,
|
||||
arg.RefreshToken,
|
||||
arg.ExpiresAt,
|
||||
arg.Scope,
|
||||
arg.TokenType,
|
||||
arg.State,
|
||||
arg.PkceCodeVerifier,
|
||||
)
|
||||
var i LlmProviderOauthToken
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.LlmProviderID,
|
||||
&i.AccessToken,
|
||||
&i.RefreshToken,
|
||||
&i.ExpiresAt,
|
||||
&i.Scope,
|
||||
&i.TokenType,
|
||||
&i.State,
|
||||
&i.PkceCodeVerifier,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -292,6 +292,20 @@ type LlmProvider struct {
|
||||
ClientType string `json:"client_type"`
|
||||
}
|
||||
|
||||
type LlmProviderOauthToken struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
LlmProviderID pgtype.UUID `json:"llm_provider_id"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
|
||||
Scope string `json:"scope"`
|
||||
TokenType string `json:"token_type"`
|
||||
State string `json:"state"`
|
||||
PkceCodeVerifier string `json:"pkce_code_verifier"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
|
||||
}
|
||||
|
||||
type McpConnection struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
BotID pgtype.UUID `json:"bot_id"`
|
||||
|
||||
@@ -0,0 +1,132 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"html/template"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/memohai/memoh/internal/providers"
|
||||
)
|
||||
|
||||
type ProviderOAuthHandler struct {
|
||||
service *providers.Service
|
||||
}
|
||||
|
||||
func NewProviderOAuthHandler(service *providers.Service) *ProviderOAuthHandler {
|
||||
return &ProviderOAuthHandler{service: service}
|
||||
}
|
||||
|
||||
func (h *ProviderOAuthHandler) Register(e *echo.Echo) {
|
||||
e.GET("/providers/:id/oauth/authorize", h.Authorize)
|
||||
e.GET("/providers/:id/oauth/status", h.Status)
|
||||
e.DELETE("/providers/:id/oauth/token", h.Revoke)
|
||||
e.GET("/auth/callback", h.Callback)
|
||||
e.GET("/providers/oauth/callback", h.Callback)
|
||||
}
|
||||
|
||||
// Authorize godoc
|
||||
// @Summary Start OAuth2 authorization for an LLM provider
|
||||
// @Tags providers-oauth
|
||||
// @Param id path string true "Provider ID (UUID)"
|
||||
// @Success 200 {object} map[string]string
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 404 {object} ErrorResponse
|
||||
// @Router /providers/{id}/oauth/authorize [get].
|
||||
func (h *ProviderOAuthHandler) Authorize(c echo.Context) error {
|
||||
providerID := strings.TrimSpace(c.Param("id"))
|
||||
if providerID == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
|
||||
}
|
||||
authURL, err := h.service.StartOAuthAuthorization(c.Request().Context(), providerID)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
return c.JSON(http.StatusOK, map[string]string{"auth_url": authURL})
|
||||
}
|
||||
|
||||
// Status godoc
|
||||
// @Summary Get OAuth2 status for an LLM provider
|
||||
// @Tags providers-oauth
|
||||
// @Param id path string true "Provider ID (UUID)"
|
||||
// @Success 200 {object} providers.OAuthStatus
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 404 {object} ErrorResponse
|
||||
// @Router /providers/{id}/oauth/status [get].
|
||||
func (h *ProviderOAuthHandler) Status(c echo.Context) error {
|
||||
providerID := strings.TrimSpace(c.Param("id"))
|
||||
if providerID == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
|
||||
}
|
||||
status, err := h.service.GetOAuthStatus(c.Request().Context(), providerID)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
return c.JSON(http.StatusOK, status)
|
||||
}
|
||||
|
||||
// Revoke godoc
|
||||
// @Summary Revoke stored OAuth2 tokens for an LLM provider
|
||||
// @Tags providers-oauth
|
||||
// @Param id path string true "Provider ID (UUID)"
|
||||
// @Success 204 "No Content"
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 404 {object} ErrorResponse
|
||||
// @Router /providers/{id}/oauth/token [delete].
|
||||
func (h *ProviderOAuthHandler) Revoke(c echo.Context) error {
|
||||
providerID := strings.TrimSpace(c.Param("id"))
|
||||
if providerID == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
|
||||
}
|
||||
if err := h.service.RevokeOAuthToken(c.Request().Context(), providerID); err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
return c.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// Callback godoc
|
||||
// @Summary OAuth2 callback for LLM providers
|
||||
// @Tags providers-oauth
|
||||
// @Param code query string true "Authorization code"
|
||||
// @Param state query string true "State parameter"
|
||||
// @Success 200 {string} string "HTML success page"
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Router /providers/oauth/callback [get].
|
||||
func (h *ProviderOAuthHandler) Callback(c echo.Context) error {
|
||||
code := strings.TrimSpace(c.QueryParam("code"))
|
||||
state := strings.TrimSpace(c.QueryParam("state"))
|
||||
if code == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "code is required")
|
||||
}
|
||||
if state == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "state is required")
|
||||
}
|
||||
providerID, err := h.service.HandleOAuthCallback(c.Request().Context(), state, code)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
|
||||
page := template.Must(template.New("oauth-success").Parse(`<!doctype html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<title>OpenAI OAuth Connected</title>
|
||||
</head>
|
||||
<body style="font-family: sans-serif; padding: 24px;">
|
||||
<h2>OpenAI OAuth connected</h2>
|
||||
<p>You can close this window and return to Memoh.</p>
|
||||
<script>
|
||||
window.opener?.postMessage({ type: "memoh-provider-oauth-success", providerId: "{{.ProviderID}}" }, "*");
|
||||
window.close();
|
||||
</script>
|
||||
</body>
|
||||
</html>`))
|
||||
return c.HTML(http.StatusOK, executeHTMLTemplate(page, map[string]string{"ProviderID": providerID}))
|
||||
}
|
||||
|
||||
func executeHTMLTemplate(tpl *template.Template, data any) string {
|
||||
var b strings.Builder
|
||||
_ = tpl.Execute(&b, data)
|
||||
return b.String()
|
||||
}
|
||||
@@ -309,20 +309,31 @@ func (h *ProvidersHandler) ImportModels(c echo.Context) error {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("fetch remote models: %v", err))
|
||||
}
|
||||
|
||||
defaultCompat := []string{models.CompatVision, models.CompatToolCall, models.CompatReasoning}
|
||||
|
||||
resp := providers.ImportModelsResponse{
|
||||
Models: make([]string, 0),
|
||||
}
|
||||
|
||||
for _, m := range remoteModels {
|
||||
modelType := models.ModelTypeChat
|
||||
if strings.TrimSpace(m.Type) == string(models.ModelTypeEmbedding) {
|
||||
modelType = models.ModelTypeEmbedding
|
||||
}
|
||||
compatibilities := m.Compatibilities
|
||||
if len(compatibilities) == 0 {
|
||||
compatibilities = []string{models.CompatVision, models.CompatToolCall, models.CompatReasoning}
|
||||
}
|
||||
name := strings.TrimSpace(m.Name)
|
||||
if name == "" {
|
||||
name = m.ID
|
||||
}
|
||||
_, err := h.modelsService.Create(c.Request().Context(), models.AddRequest{
|
||||
ModelID: m.ID,
|
||||
Name: m.ID,
|
||||
Name: name,
|
||||
LlmProviderID: id,
|
||||
Type: models.ModelTypeChat,
|
||||
Type: modelType,
|
||||
Config: models.ModelConfig{
|
||||
Compatibilities: defaultCompat,
|
||||
Compatibilities: compatibilities,
|
||||
ReasoningEfforts: m.ReasoningEfforts,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -75,7 +75,7 @@ func newDenseRuntime(providerConfig map[string]any, queries *dbsqlc.Queries, cfg
|
||||
return nil, fmt.Errorf("dense runtime: %w", err)
|
||||
}
|
||||
|
||||
embedModel := models.NewSDKEmbeddingModel(spec.clientType, spec.baseURL, spec.apiKey, spec.modelID, denseEmbedTimeout)
|
||||
embedModel := models.NewSDKEmbeddingModel(spec.clientType, spec.baseURL, spec.apiKey, spec.modelID, denseEmbedTimeout, nil)
|
||||
|
||||
return &denseRuntime{
|
||||
qdrant: qClient,
|
||||
|
||||
@@ -13,11 +13,13 @@ import (
|
||||
// provider configuration. It dispatches to the native Google embedding provider
|
||||
// when clientType is "google-generative-ai", and falls back to the
|
||||
// OpenAI-compatible /embeddings endpoint for all other provider types.
|
||||
func NewSDKEmbeddingModel(clientType, baseURL, apiKey, modelID string, timeout time.Duration) *sdk.EmbeddingModel {
|
||||
func NewSDKEmbeddingModel(clientType, baseURL, apiKey, modelID string, timeout time.Duration, httpClient *http.Client) *sdk.EmbeddingModel {
|
||||
if timeout <= 0 {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
httpClient := &http.Client{Timeout: timeout}
|
||||
if httpClient == nil {
|
||||
httpClient = &http.Client{Timeout: timeout}
|
||||
}
|
||||
|
||||
switch ClientType(clientType) {
|
||||
case ClientTypeGoogleGenerativeAI:
|
||||
@@ -30,7 +32,6 @@ func NewSDKEmbeddingModel(clientType, baseURL, apiKey, modelID string, timeout t
|
||||
}
|
||||
p := googleembedding.New(opts...)
|
||||
return p.EmbeddingModel(modelID)
|
||||
|
||||
default:
|
||||
opts := []openaiembedding.Option{
|
||||
openaiembedding.WithAPIKey(apiKey),
|
||||
|
||||
+81
-11
@@ -2,6 +2,9 @@ package models
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -9,11 +12,13 @@ import (
|
||||
|
||||
anthropicmessages "github.com/memohai/twilight-ai/provider/anthropic/messages"
|
||||
googlegenerative "github.com/memohai/twilight-ai/provider/google/generativeai"
|
||||
openaicodex "github.com/memohai/twilight-ai/provider/openai/codex"
|
||||
openaicompletions "github.com/memohai/twilight-ai/provider/openai/completions"
|
||||
openairesponses "github.com/memohai/twilight-ai/provider/openai/responses"
|
||||
sdk "github.com/memohai/twilight-ai/sdk"
|
||||
|
||||
"github.com/memohai/memoh/internal/db"
|
||||
"github.com/memohai/memoh/internal/db/sqlc"
|
||||
)
|
||||
|
||||
const probeTimeout = 15 * time.Second
|
||||
@@ -37,16 +42,17 @@ func (s *Service) Test(ctx context.Context, id string) (TestResponse, error) {
|
||||
}
|
||||
|
||||
baseURL := strings.TrimRight(provider.BaseUrl, "/")
|
||||
apiKey := provider.ApiKey
|
||||
clientType := ClientType(provider.ClientType)
|
||||
|
||||
// Embedding models don't have a chat Provider in the SDK — probe
|
||||
// the /embeddings endpoint directly.
|
||||
if model.Type == string(ModelTypeEmbedding) {
|
||||
return s.testEmbeddingModel(ctx, string(clientType), baseURL, apiKey, model.ModelID)
|
||||
creds, err := s.resolveModelCredentials(ctx, provider)
|
||||
if err != nil {
|
||||
return TestResponse{}, err
|
||||
}
|
||||
|
||||
sdkProvider := NewSDKProvider(baseURL, apiKey, clientType, probeTimeout)
|
||||
if model.Type == string(ModelTypeEmbedding) {
|
||||
return s.testEmbeddingModel(ctx, baseURL, creds.APIKey, model.ModelID, nil)
|
||||
}
|
||||
|
||||
sdkProvider := NewSDKProvider(baseURL, creds.APIKey, creds.CodexAccountID, clientType, probeTimeout, nil)
|
||||
|
||||
start := time.Now()
|
||||
|
||||
@@ -100,11 +106,11 @@ func (s *Service) Test(ctx context.Context, id string) (TestResponse, error) {
|
||||
// testEmbeddingModel probes an embedding model by performing a minimal
|
||||
// embedding request via the Twilight SDK, verifying that the model is
|
||||
// reachable and functional rather than merely checking HTTP connectivity.
|
||||
func (*Service) testEmbeddingModel(ctx context.Context, clientType, baseURL, apiKey, modelID string) (TestResponse, error) {
|
||||
func (*Service) testEmbeddingModel(ctx context.Context, baseURL, apiKey, modelID string, httpClient *http.Client) (TestResponse, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, probeTimeout)
|
||||
defer cancel()
|
||||
|
||||
model := NewSDKEmbeddingModel(clientType, baseURL, apiKey, modelID, probeTimeout)
|
||||
model := NewSDKEmbeddingModel(string(ClientTypeOpenAICompletions), baseURL, apiKey, modelID, probeTimeout, httpClient)
|
||||
client := sdk.NewClient()
|
||||
|
||||
start := time.Now()
|
||||
@@ -130,8 +136,10 @@ func (*Service) testEmbeddingModel(ctx context.Context, clientType, baseURL, api
|
||||
|
||||
// NewSDKProvider creates a Twilight AI SDK Provider for the given client type.
|
||||
// It is exported so that other packages (e.g. providers) can reuse it for testing.
|
||||
func NewSDKProvider(baseURL, apiKey string, clientType ClientType, timeout time.Duration) sdk.Provider {
|
||||
httpClient := &http.Client{Timeout: timeout}
|
||||
func NewSDKProvider(baseURL, apiKey, codexAccountID string, clientType ClientType, timeout time.Duration, httpClient *http.Client) sdk.Provider {
|
||||
if httpClient == nil {
|
||||
httpClient = &http.Client{Timeout: timeout}
|
||||
}
|
||||
|
||||
switch clientType {
|
||||
case ClientTypeOpenAIResponses:
|
||||
@@ -144,6 +152,16 @@ func NewSDKProvider(baseURL, apiKey string, clientType ClientType, timeout time.
|
||||
}
|
||||
return openairesponses.New(opts...)
|
||||
|
||||
case ClientTypeOpenAICodex:
|
||||
opts := []openaicodex.Option{
|
||||
openaicodex.WithAccessToken(apiKey),
|
||||
openaicodex.WithHTTPClient(httpClient),
|
||||
}
|
||||
if codexAccountID != "" {
|
||||
opts = append(opts, openaicodex.WithAccountID(codexAccountID))
|
||||
}
|
||||
return openaicodex.New(opts...)
|
||||
|
||||
case ClientTypeAnthropicMessages:
|
||||
opts := []anthropicmessages.Option{
|
||||
anthropicmessages.WithAPIKey(apiKey),
|
||||
@@ -175,3 +193,55 @@ func NewSDKProvider(baseURL, apiKey string, clientType ClientType, timeout time.
|
||||
return openaicompletions.New(opts...)
|
||||
}
|
||||
}
|
||||
|
||||
type modelCredentials struct {
|
||||
APIKey string //nolint:gosec // runtime credential material used to construct SDK providers
|
||||
CodexAccountID string
|
||||
}
|
||||
|
||||
func (s *Service) resolveModelCredentials(ctx context.Context, provider sqlc.LlmProvider) (modelCredentials, error) {
|
||||
if ClientType(provider.ClientType) != ClientTypeOpenAICodex {
|
||||
return modelCredentials{APIKey: provider.ApiKey}, nil
|
||||
}
|
||||
|
||||
tokenRow, err := s.queries.GetLlmProviderOAuthTokenByProvider(ctx, provider.ID)
|
||||
if err != nil {
|
||||
return modelCredentials{}, err
|
||||
}
|
||||
accessToken := strings.TrimSpace(tokenRow.AccessToken)
|
||||
if accessToken == "" {
|
||||
return modelCredentials{}, errors.New("oauth token is missing access token")
|
||||
}
|
||||
accountID, err := codexAccountIDFromToken(accessToken)
|
||||
if err != nil {
|
||||
return modelCredentials{}, err
|
||||
}
|
||||
return modelCredentials{
|
||||
APIKey: accessToken,
|
||||
CodexAccountID: accountID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func codexAccountIDFromToken(token string) (string, error) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return "", errors.New("invalid oauth access token")
|
||||
}
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decode oauth token payload: %w", err)
|
||||
}
|
||||
var claims struct {
|
||||
OpenAIAuth struct {
|
||||
ChatGPTAccountID string `json:"chatgpt_account_id"`
|
||||
} `json:"https://api.openai.com/auth"`
|
||||
}
|
||||
if err := json.Unmarshal(payload, &claims); err != nil {
|
||||
return "", fmt.Errorf("parse oauth token payload: %w", err)
|
||||
}
|
||||
accountID := strings.TrimSpace(claims.OpenAIAuth.ChatGPTAccountID)
|
||||
if accountID == "" {
|
||||
return "", errors.New("oauth access token missing chatgpt_account_id")
|
||||
}
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
+34
-1
@@ -1,10 +1,12 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
anthropicmessages "github.com/memohai/twilight-ai/provider/anthropic/messages"
|
||||
googlegenerative "github.com/memohai/twilight-ai/provider/google/generativeai"
|
||||
openaicodex "github.com/memohai/twilight-ai/provider/openai/codex"
|
||||
openaicompletions "github.com/memohai/twilight-ai/provider/openai/completions"
|
||||
openairesponses "github.com/memohai/twilight-ai/provider/openai/responses"
|
||||
sdk "github.com/memohai/twilight-ai/sdk"
|
||||
@@ -16,7 +18,9 @@ type SDKModelConfig struct {
|
||||
ModelID string
|
||||
ClientType string
|
||||
APIKey string //nolint:gosec // carries provider credential material at runtime
|
||||
CodexAccountID string
|
||||
BaseURL string
|
||||
HTTPClient *http.Client
|
||||
ReasoningConfig *ReasoningConfig
|
||||
}
|
||||
|
||||
@@ -38,6 +42,9 @@ func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model {
|
||||
opts := []openaicompletions.Option{
|
||||
openaicompletions.WithAPIKey(cfg.APIKey),
|
||||
}
|
||||
if cfg.HTTPClient != nil {
|
||||
opts = append(opts, openaicompletions.WithHTTPClient(cfg.HTTPClient))
|
||||
}
|
||||
if cfg.BaseURL != "" {
|
||||
opts = append(opts, openaicompletions.WithBaseURL(cfg.BaseURL))
|
||||
}
|
||||
@@ -48,16 +55,34 @@ func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model {
|
||||
opts := []openairesponses.Option{
|
||||
openairesponses.WithAPIKey(cfg.APIKey),
|
||||
}
|
||||
if cfg.HTTPClient != nil {
|
||||
opts = append(opts, openairesponses.WithHTTPClient(cfg.HTTPClient))
|
||||
}
|
||||
if cfg.BaseURL != "" {
|
||||
opts = append(opts, openairesponses.WithBaseURL(cfg.BaseURL))
|
||||
}
|
||||
p := openairesponses.New(opts...)
|
||||
return p.ChatModel(cfg.ModelID)
|
||||
|
||||
case ClientTypeOpenAICodex:
|
||||
opts := []openaicodex.Option{
|
||||
openaicodex.WithAccessToken(cfg.APIKey),
|
||||
}
|
||||
if cfg.HTTPClient != nil {
|
||||
opts = append(opts, openaicodex.WithHTTPClient(cfg.HTTPClient))
|
||||
}
|
||||
if cfg.CodexAccountID != "" {
|
||||
opts = append(opts, openaicodex.WithAccountID(cfg.CodexAccountID))
|
||||
}
|
||||
return openaicodex.New(opts...).ChatModel(cfg.ModelID)
|
||||
|
||||
case ClientTypeAnthropicMessages:
|
||||
opts := []anthropicmessages.Option{
|
||||
anthropicmessages.WithAPIKey(cfg.APIKey),
|
||||
}
|
||||
if cfg.HTTPClient != nil {
|
||||
opts = append(opts, anthropicmessages.WithHTTPClient(cfg.HTTPClient))
|
||||
}
|
||||
if cfg.BaseURL != "" {
|
||||
opts = append(opts, anthropicmessages.WithBaseURL(cfg.BaseURL))
|
||||
}
|
||||
@@ -75,6 +100,9 @@ func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model {
|
||||
opts := []googlegenerative.Option{
|
||||
googlegenerative.WithAPIKey(cfg.APIKey),
|
||||
}
|
||||
if cfg.HTTPClient != nil {
|
||||
opts = append(opts, googlegenerative.WithHTTPClient(cfg.HTTPClient))
|
||||
}
|
||||
if cfg.BaseURL != "" {
|
||||
opts = append(opts, googlegenerative.WithBaseURL(cfg.BaseURL))
|
||||
}
|
||||
@@ -85,6 +113,9 @@ func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model {
|
||||
opts := []openaicompletions.Option{
|
||||
openaicompletions.WithAPIKey(cfg.APIKey),
|
||||
}
|
||||
if cfg.HTTPClient != nil {
|
||||
opts = append(opts, openaicompletions.WithHTTPClient(cfg.HTTPClient))
|
||||
}
|
||||
if cfg.BaseURL != "" {
|
||||
opts = append(opts, openaicompletions.WithBaseURL(cfg.BaseURL))
|
||||
}
|
||||
@@ -106,7 +137,7 @@ func BuildReasoningOptions(cfg SDKModelConfig) []sdk.GenerateOption {
|
||||
switch ClientType(cfg.ClientType) {
|
||||
case ClientTypeAnthropicMessages:
|
||||
return nil
|
||||
case ClientTypeOpenAIResponses, ClientTypeOpenAICompletions:
|
||||
case ClientTypeOpenAIResponses, ClientTypeOpenAICompletions, ClientTypeOpenAICodex:
|
||||
return []sdk.GenerateOption{sdk.WithReasoningEffort(effort)}
|
||||
case ClientTypeGoogleGenerativeAI:
|
||||
return nil
|
||||
@@ -147,6 +178,8 @@ func ResolveClientType(model *sdk.Model) string {
|
||||
return string(ClientTypeAnthropicMessages)
|
||||
case strings.Contains(name, "google"):
|
||||
return string(ClientTypeGoogleGenerativeAI)
|
||||
case strings.Contains(name, "codex"):
|
||||
return string(ClientTypeOpenAICodex)
|
||||
case strings.Contains(name, "responses"):
|
||||
return string(ClientTypeOpenAIResponses)
|
||||
default:
|
||||
|
||||
@@ -20,6 +20,7 @@ const (
|
||||
ClientTypeOpenAICompletions ClientType = "openai-completions"
|
||||
ClientTypeAnthropicMessages ClientType = "anthropic-messages"
|
||||
ClientTypeGoogleGenerativeAI ClientType = "google-generative-ai"
|
||||
ClientTypeOpenAICodex ClientType = "openai-codex"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -29,16 +30,33 @@ const (
|
||||
CompatReasoning = "reasoning"
|
||||
)
|
||||
|
||||
const (
|
||||
ReasoningEffortNone = "none"
|
||||
ReasoningEffortLow = "low"
|
||||
ReasoningEffortMedium = "medium"
|
||||
ReasoningEffortHigh = "high"
|
||||
ReasoningEffortXHigh = "xhigh"
|
||||
)
|
||||
|
||||
// validCompatibilities enumerates accepted compatibility tokens.
|
||||
var validCompatibilities = map[string]struct{}{
|
||||
CompatVision: {}, CompatToolCall: {}, CompatImageOutput: {}, CompatReasoning: {},
|
||||
}
|
||||
|
||||
var validReasoningEfforts = map[string]struct{}{
|
||||
ReasoningEffortNone: {},
|
||||
ReasoningEffortLow: {},
|
||||
ReasoningEffortMedium: {},
|
||||
ReasoningEffortHigh: {},
|
||||
ReasoningEffortXHigh: {},
|
||||
}
|
||||
|
||||
// ModelConfig holds the JSONB config stored per model.
|
||||
type ModelConfig struct {
|
||||
Dimensions *int `json:"dimensions,omitempty"`
|
||||
Compatibilities []string `json:"compatibilities,omitempty"`
|
||||
ContextWindow *int `json:"context_window,omitempty"`
|
||||
Dimensions *int `json:"dimensions,omitempty"`
|
||||
Compatibilities []string `json:"compatibilities,omitempty"`
|
||||
ContextWindow *int `json:"context_window,omitempty"`
|
||||
ReasoningEfforts []string `json:"reasoning_efforts,omitempty"`
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
@@ -72,6 +90,11 @@ func (m *Model) Validate() error {
|
||||
return errors.New("invalid compatibility: " + c)
|
||||
}
|
||||
}
|
||||
for _, effort := range m.Config.ReasoningEfforts {
|
||||
if _, ok := validReasoningEfforts[effort]; !ok {
|
||||
return errors.New("invalid reasoning effort: " + effort)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/memohai/memoh/internal/db/sqlc"
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
)
|
||||
|
||||
const openAIAuthClaimPath = "https://api.openai.com/auth"
|
||||
|
||||
type ModelCredentials struct {
|
||||
APIKey string //nolint:gosec // runtime credential material used to construct SDK providers
|
||||
CodexAccountID string
|
||||
}
|
||||
|
||||
func SupportsOpenAICodexOAuth(provider sqlc.LlmProvider) bool {
|
||||
return supportsOAuth(provider)
|
||||
}
|
||||
|
||||
func (s *Service) ResolveModelCredentials(ctx context.Context, provider sqlc.LlmProvider) (ModelCredentials, error) {
|
||||
if models.ClientType(provider.ClientType) != models.ClientTypeOpenAICodex {
|
||||
return ModelCredentials{
|
||||
APIKey: provider.ApiKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
token, err := s.GetValidAccessToken(ctx, provider.ID.String())
|
||||
if err != nil {
|
||||
return ModelCredentials{}, err
|
||||
}
|
||||
accountID, err := codexAccountIDFromToken(token)
|
||||
if err != nil {
|
||||
return ModelCredentials{}, err
|
||||
}
|
||||
return ModelCredentials{
|
||||
APIKey: token,
|
||||
CodexAccountID: accountID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func codexAccountIDFromToken(token string) (string, error) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return "", errors.New("invalid oauth access token")
|
||||
}
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decode oauth token payload: %w", err)
|
||||
}
|
||||
var claims struct {
|
||||
OpenAIAuth struct {
|
||||
ChatGPTAccountID string `json:"chatgpt_account_id"`
|
||||
} `json:"https://api.openai.com/auth"`
|
||||
}
|
||||
if err := json.Unmarshal(payload, &claims); err != nil {
|
||||
return "", fmt.Errorf("parse oauth token payload: %w", err)
|
||||
}
|
||||
accountID := strings.TrimSpace(claims.OpenAIAuth.ChatGPTAccountID)
|
||||
if accountID == "" {
|
||||
return "", fmt.Errorf("oauth access token missing %s.chatgpt_account_id", openAIAuthClaimPath)
|
||||
}
|
||||
return accountID, nil
|
||||
}
|
||||
@@ -0,0 +1,468 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"github.com/memohai/memoh/internal/db"
|
||||
"github.com/memohai/memoh/internal/db/sqlc"
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultOpenAICodexClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
defaultOpenAIAuthorizeURL = "https://auth.openai.com/oauth/authorize"
|
||||
defaultOpenAITokenURL = "https://auth.openai.com/oauth/token" //nolint:gosec // OAuth endpoint URL, not a credential
|
||||
defaultOpenAICallbackURL = "http://localhost:1455/auth/callback"
|
||||
defaultOpenAIOAuthScopes = "openid profile email offline_access"
|
||||
oauthExpirySkew = 30 * time.Second
|
||||
providerOAuthHTTPTimeout = 15 * time.Second
|
||||
metadataOAuthClientIDKey = "oauth_client_id"
|
||||
metadataOAuthAuthorizeURLKey = "oauth_authorize_url"
|
||||
metadataOAuthTokenURLKey = "oauth_token_url" //nolint:gosec // metadata key name, not a credential
|
||||
metadataOAuthRedirectURIKey = "oauth_redirect_uri"
|
||||
metadataOAuthScopesKey = "oauth_scopes"
|
||||
metadataOAuthAudienceKey = "oauth_audience"
|
||||
metadataOAuthUseIDOrgsFlagKey = "oauth_id_token_add_organizations"
|
||||
)
|
||||
|
||||
type providerOAuthToken struct {
|
||||
ProviderID string `json:"provider_id"`
|
||||
AccessToken string `json:"access_token"` //nolint:gosec // runtime credential storage
|
||||
RefreshToken string `json:"refresh_token"` //nolint:gosec // runtime credential storage
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
Scope string `json:"scope"`
|
||||
TokenType string `json:"token_type"`
|
||||
State string `json:"state"`
|
||||
PKCECodeVerifier string `json:"pkce_code_verifier"`
|
||||
}
|
||||
|
||||
type openAIOAuthConfig struct {
|
||||
ClientID string
|
||||
AuthorizeURL string
|
||||
TokenURL string
|
||||
RedirectURI string
|
||||
Scopes string
|
||||
IDTokenAddOrganizations bool
|
||||
}
|
||||
|
||||
func providerMetadata(raw []byte) map[string]any {
|
||||
if len(raw) == 0 {
|
||||
return map[string]any{}
|
||||
}
|
||||
var metadata map[string]any
|
||||
if err := json.Unmarshal(raw, &metadata); err != nil {
|
||||
return map[string]any{}
|
||||
}
|
||||
if metadata == nil {
|
||||
return map[string]any{}
|
||||
}
|
||||
return metadata
|
||||
}
|
||||
|
||||
func (s *Service) oauthConfig(metadata map[string]any) openAIOAuthConfig {
|
||||
cfg := openAIOAuthConfig{
|
||||
ClientID: defaultOpenAICodexClientID,
|
||||
AuthorizeURL: defaultOpenAIAuthorizeURL,
|
||||
TokenURL: defaultOpenAITokenURL,
|
||||
RedirectURI: firstNonEmpty(s.callbackURL, defaultOpenAICallbackURL),
|
||||
Scopes: defaultOpenAIOAuthScopes,
|
||||
IDTokenAddOrganizations: true,
|
||||
}
|
||||
if v, _ := metadata[metadataOAuthClientIDKey].(string); strings.TrimSpace(v) != "" {
|
||||
cfg.ClientID = strings.TrimSpace(v)
|
||||
}
|
||||
if v, _ := metadata[metadataOAuthAuthorizeURLKey].(string); strings.TrimSpace(v) != "" {
|
||||
cfg.AuthorizeURL = strings.TrimSpace(v)
|
||||
}
|
||||
if v, _ := metadata[metadataOAuthTokenURLKey].(string); strings.TrimSpace(v) != "" {
|
||||
cfg.TokenURL = strings.TrimSpace(v)
|
||||
}
|
||||
if v, _ := metadata[metadataOAuthRedirectURIKey].(string); strings.TrimSpace(v) != "" {
|
||||
cfg.RedirectURI = strings.TrimSpace(v)
|
||||
}
|
||||
if v, _ := metadata[metadataOAuthScopesKey].(string); strings.TrimSpace(v) != "" {
|
||||
cfg.Scopes = strings.TrimSpace(v)
|
||||
}
|
||||
if v, ok := metadata[metadataOAuthUseIDOrgsFlagKey].(bool); ok {
|
||||
cfg.IDTokenAddOrganizations = v
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func supportsOAuth(provider sqlc.LlmProvider) bool {
|
||||
return models.ClientType(provider.ClientType) == models.ClientTypeOpenAICodex
|
||||
}
|
||||
|
||||
func (s *Service) StartOAuthAuthorization(ctx context.Context, providerID string) (string, error) {
|
||||
providerUUID, err := db.ParseUUID(providerID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
provider, err := s.queries.GetLlmProviderByID(ctx, providerUUID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get provider: %w", err)
|
||||
}
|
||||
if !supportsOAuth(provider) {
|
||||
return "", errors.New("provider does not support oauth")
|
||||
}
|
||||
|
||||
cfg := s.oauthConfig(providerMetadata(provider.Metadata))
|
||||
codeVerifier, err := generateCodeVerifier()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("generate code verifier: %w", err)
|
||||
}
|
||||
state, err := generateState()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("generate state: %w", err)
|
||||
}
|
||||
if err := s.updateOAuthState(ctx, providerID, state, codeVerifier); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
params := url.Values{
|
||||
"response_type": {"code"},
|
||||
"client_id": {cfg.ClientID},
|
||||
"redirect_uri": {cfg.RedirectURI},
|
||||
"scope": {cfg.Scopes},
|
||||
"code_challenge": {computeCodeChallenge(codeVerifier)},
|
||||
"code_challenge_method": {"S256"},
|
||||
"state": {state},
|
||||
}
|
||||
if cfg.IDTokenAddOrganizations {
|
||||
params.Set("id_token_add_organizations", "true")
|
||||
}
|
||||
params.Set("codex_cli_simplified_flow", "true")
|
||||
|
||||
return cfg.AuthorizeURL + "?" + params.Encode(), nil
|
||||
}
|
||||
|
||||
func (s *Service) HandleOAuthCallback(ctx context.Context, state, code string) (string, error) {
|
||||
token, err := s.getOAuthTokenByState(ctx, state)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
providerUUID, err := db.ParseUUID(token.ProviderID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
provider, err := s.queries.GetLlmProviderByID(ctx, providerUUID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get provider: %w", err)
|
||||
}
|
||||
if !supportsOAuth(provider) {
|
||||
return "", errors.New("provider does not support oauth")
|
||||
}
|
||||
|
||||
cfg := s.oauthConfig(providerMetadata(provider.Metadata))
|
||||
resp, err := s.exchangeCode(ctx, cfg, code, token.PKCECodeVerifier)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := s.saveOAuthToken(ctx, provider.ID.String(), providerOAuthToken{
|
||||
ProviderID: provider.ID.String(),
|
||||
AccessToken: resp.AccessToken,
|
||||
RefreshToken: firstNonEmpty(resp.RefreshToken, token.RefreshToken),
|
||||
ExpiresAt: expiresAtFromNow(resp.ExpiresIn),
|
||||
Scope: firstNonEmpty(resp.Scope, cfg.Scopes),
|
||||
TokenType: firstNonEmpty(resp.TokenType, "Bearer"),
|
||||
State: "",
|
||||
PKCECodeVerifier: "",
|
||||
}); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return provider.ID.String(), nil
|
||||
}
|
||||
|
||||
func (s *Service) GetOAuthStatus(ctx context.Context, providerID string) (*OAuthStatus, error) {
|
||||
providerUUID, err := db.ParseUUID(providerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
provider, err := s.queries.GetLlmProviderByID(ctx, providerUUID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get provider: %w", err)
|
||||
}
|
||||
status := &OAuthStatus{
|
||||
Configured: supportsOAuth(provider),
|
||||
CallbackURL: s.oauthConfig(providerMetadata(provider.Metadata)).RedirectURI,
|
||||
}
|
||||
if !status.Configured {
|
||||
return status, nil
|
||||
}
|
||||
|
||||
token, err := s.getOAuthToken(ctx, providerID)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return status, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
status.HasToken = strings.TrimSpace(token.AccessToken) != ""
|
||||
if !token.ExpiresAt.IsZero() {
|
||||
expiresAt := token.ExpiresAt
|
||||
status.ExpiresAt = &expiresAt
|
||||
status.Expired = time.Now().After(token.ExpiresAt)
|
||||
}
|
||||
return status, nil
|
||||
}
|
||||
|
||||
func (s *Service) RevokeOAuthToken(ctx context.Context, providerID string) error {
|
||||
providerUUID, err := db.ParseUUID(providerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
provider, err := s.queries.GetLlmProviderByID(ctx, providerUUID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get provider: %w", err)
|
||||
}
|
||||
if !supportsOAuth(provider) {
|
||||
return errors.New("provider does not support oauth")
|
||||
}
|
||||
return s.queries.DeleteLlmProviderOAuthToken(ctx, providerUUID)
|
||||
}
|
||||
|
||||
func (s *Service) GetValidAccessToken(ctx context.Context, providerID string) (string, error) {
|
||||
token, err := s.getOAuthToken(ctx, providerID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if strings.TrimSpace(token.AccessToken) == "" {
|
||||
return "", errors.New("oauth token is missing access token")
|
||||
}
|
||||
if token.ExpiresAt.IsZero() || time.Now().Add(oauthExpirySkew).Before(token.ExpiresAt) {
|
||||
return token.AccessToken, nil
|
||||
}
|
||||
if strings.TrimSpace(token.RefreshToken) == "" {
|
||||
return "", errors.New("oauth token expired and no refresh token is available")
|
||||
}
|
||||
|
||||
providerUUID, err := db.ParseUUID(providerID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
provider, err := s.queries.GetLlmProviderByID(ctx, providerUUID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get provider: %w", err)
|
||||
}
|
||||
cfg := s.oauthConfig(providerMetadata(provider.Metadata))
|
||||
refreshed, err := s.refreshAccessToken(ctx, cfg, token.RefreshToken)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
saved := providerOAuthToken{
|
||||
ProviderID: providerID,
|
||||
AccessToken: refreshed.AccessToken,
|
||||
RefreshToken: firstNonEmpty(refreshed.RefreshToken, token.RefreshToken),
|
||||
ExpiresAt: expiresAtFromNow(refreshed.ExpiresIn),
|
||||
Scope: firstNonEmpty(refreshed.Scope, token.Scope),
|
||||
TokenType: firstNonEmpty(refreshed.TokenType, token.TokenType),
|
||||
State: token.State,
|
||||
PKCECodeVerifier: token.PKCECodeVerifier,
|
||||
}
|
||||
if err := s.saveOAuthToken(ctx, providerID, saved); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return saved.AccessToken, nil
|
||||
}
|
||||
|
||||
func (s *Service) getOAuthToken(ctx context.Context, providerID string) (*providerOAuthToken, error) {
|
||||
providerUUID, err := db.ParseUUID(providerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
row, err := s.queries.GetLlmProviderOAuthTokenByProvider(ctx, providerUUID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return toProviderOAuthToken(row), nil
|
||||
}
|
||||
|
||||
func (s *Service) getOAuthTokenByState(ctx context.Context, state string) (*providerOAuthToken, error) {
|
||||
row, err := s.queries.GetLlmProviderOAuthTokenByState(ctx, state)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return toProviderOAuthToken(row), nil
|
||||
}
|
||||
|
||||
func (s *Service) updateOAuthState(ctx context.Context, providerID, state, codeVerifier string) error {
|
||||
providerUUID, err := db.ParseUUID(providerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.queries.UpdateLlmProviderOAuthState(ctx, sqlc.UpdateLlmProviderOAuthStateParams{
|
||||
LlmProviderID: providerUUID,
|
||||
State: state,
|
||||
PkceCodeVerifier: codeVerifier,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) saveOAuthToken(ctx context.Context, providerID string, token providerOAuthToken) error {
|
||||
providerUUID, err := db.ParseUUID(providerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var expiresAt pgtype.Timestamptz
|
||||
if !token.ExpiresAt.IsZero() {
|
||||
expiresAt = pgtype.Timestamptz{Time: token.ExpiresAt, Valid: true}
|
||||
}
|
||||
_, err = s.queries.UpsertLlmProviderOAuthToken(ctx, sqlc.UpsertLlmProviderOAuthTokenParams{
|
||||
LlmProviderID: providerUUID,
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
ExpiresAt: expiresAt,
|
||||
Scope: token.Scope,
|
||||
TokenType: token.TokenType,
|
||||
State: token.State,
|
||||
PkceCodeVerifier: token.PKCECodeVerifier,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func toProviderOAuthToken(row sqlc.LlmProviderOauthToken) *providerOAuthToken {
|
||||
token := &providerOAuthToken{
|
||||
ProviderID: row.LlmProviderID.String(),
|
||||
AccessToken: row.AccessToken,
|
||||
RefreshToken: row.RefreshToken,
|
||||
Scope: row.Scope,
|
||||
TokenType: row.TokenType,
|
||||
State: row.State,
|
||||
PKCECodeVerifier: row.PkceCodeVerifier,
|
||||
}
|
||||
if row.ExpiresAt.Valid {
|
||||
token.ExpiresAt = row.ExpiresAt.Time
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
||||
type openAITokenResponse struct {
|
||||
AccessToken string `json:"access_token"` //nolint:gosec // OAuth response payload carries runtime access token
|
||||
RefreshToken string `json:"refresh_token"` //nolint:gosec // OAuth response payload carries runtime refresh token
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
Error string `json:"error"`
|
||||
Description string `json:"error_description"`
|
||||
}
|
||||
|
||||
func (s *Service) exchangeCode(ctx context.Context, cfg openAIOAuthConfig, code, codeVerifier string) (*openAITokenResponse, error) {
|
||||
values := url.Values{
|
||||
"grant_type": {"authorization_code"},
|
||||
"code": {code},
|
||||
"client_id": {cfg.ClientID},
|
||||
"redirect_uri": {cfg.RedirectURI},
|
||||
"code_verifier": {codeVerifier},
|
||||
}
|
||||
return s.postTokenRequest(ctx, cfg.TokenURL, values)
|
||||
}
|
||||
|
||||
func (s *Service) refreshAccessToken(ctx context.Context, cfg openAIOAuthConfig, refreshToken string) (*openAITokenResponse, error) {
|
||||
values := url.Values{
|
||||
"grant_type": {"refresh_token"},
|
||||
"refresh_token": {refreshToken},
|
||||
"client_id": {cfg.ClientID},
|
||||
}
|
||||
return s.postTokenRequest(ctx, cfg.TokenURL, values)
|
||||
}
|
||||
|
||||
func (s *Service) postTokenRequest(ctx context.Context, tokenURL string, body url.Values) (*openAITokenResponse, error) {
|
||||
if err := validateOAuthTokenURL(tokenURL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(body.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create oauth request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
//nolint:gosec // tokenURL is restricted to the fixed OpenAI OAuth host by validateOAuthTokenURL above
|
||||
resp, err := s.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("execute oauth request: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
payload, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read oauth response: %w", err)
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("oauth token request failed: %s", strings.TrimSpace(string(payload)))
|
||||
}
|
||||
|
||||
var tokenResp openAITokenResponse
|
||||
if err := json.Unmarshal(payload, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("decode oauth response: %w", err)
|
||||
}
|
||||
if tokenResp.Error != "" {
|
||||
return nil, fmt.Errorf("oauth token request failed: %s", firstNonEmpty(tokenResp.Description, tokenResp.Error))
|
||||
}
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
func validateOAuthTokenURL(raw string) error {
|
||||
parsed, err := url.Parse(strings.TrimSpace(raw))
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid oauth token url: %w", err)
|
||||
}
|
||||
if !strings.EqualFold(parsed.Scheme, "https") {
|
||||
return errors.New("oauth token url must use https")
|
||||
}
|
||||
if !strings.EqualFold(parsed.Hostname(), "auth.openai.com") {
|
||||
return errors.New("oauth token url host must be auth.openai.com")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func generateState() (string, error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func generateCodeVerifier() (string, error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func computeCodeChallenge(verifier string) string {
|
||||
sum := sha256.Sum256([]byte(verifier))
|
||||
return base64.RawURLEncoding.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func expiresAtFromNow(expiresIn int64) time.Time {
|
||||
if expiresIn <= 0 {
|
||||
return time.Time{}
|
||||
}
|
||||
return time.Now().Add(time.Duration(expiresIn) * time.Second)
|
||||
}
|
||||
|
||||
func firstNonEmpty(values ...string) string {
|
||||
for _, value := range values {
|
||||
if strings.TrimSpace(value) != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
openaicodex "github.com/memohai/twilight-ai/provider/openai/codex"
|
||||
sdk "github.com/memohai/twilight-ai/sdk"
|
||||
|
||||
"github.com/memohai/memoh/internal/db"
|
||||
@@ -20,21 +21,27 @@ import (
|
||||
|
||||
// Service handles provider operations.
|
||||
type Service struct {
|
||||
queries *sqlc.Queries
|
||||
logger *slog.Logger
|
||||
queries *sqlc.Queries
|
||||
logger *slog.Logger
|
||||
httpClient *http.Client
|
||||
callbackURL string
|
||||
}
|
||||
|
||||
// NewService creates a new provider service.
|
||||
func NewService(log *slog.Logger, queries *sqlc.Queries) *Service {
|
||||
func NewService(log *slog.Logger, queries *sqlc.Queries, callbackURL string) *Service {
|
||||
if log == nil {
|
||||
log = slog.Default()
|
||||
}
|
||||
return &Service{
|
||||
queries: queries,
|
||||
logger: log.With(slog.String("service", "providers")),
|
||||
queries: queries,
|
||||
logger: log.With(slog.String("service", "providers")),
|
||||
httpClient: &http.Client{Timeout: providerOAuthHTTPTimeout},
|
||||
callbackURL: callbackURL,
|
||||
}
|
||||
}
|
||||
|
||||
// Create creates a new LLM provider.
|
||||
func (s *Service) Create(ctx context.Context, req CreateRequest) (GetResponse, error) {
|
||||
// Marshal metadata
|
||||
metadataJSON, err := json.Marshal(req.Metadata)
|
||||
if err != nil {
|
||||
return GetResponse{}, fmt.Errorf("marshal metadata: %w", err)
|
||||
@@ -112,13 +119,11 @@ func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Get
|
||||
return GetResponse{}, err
|
||||
}
|
||||
|
||||
// Get existing provider
|
||||
existing, err := s.queries.GetLlmProviderByID(ctx, providerID)
|
||||
if err != nil {
|
||||
return GetResponse{}, fmt.Errorf("get provider: %w", err)
|
||||
}
|
||||
|
||||
// Apply updates
|
||||
name := existing.Name
|
||||
if req.Name != nil {
|
||||
name = *req.Name
|
||||
@@ -146,16 +151,15 @@ func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Get
|
||||
enable = *req.Enable
|
||||
}
|
||||
|
||||
metadata := existing.Metadata
|
||||
metadataMap := providerMetadata(existing.Metadata)
|
||||
if req.Metadata != nil {
|
||||
metadataJSON, err := json.Marshal(req.Metadata)
|
||||
if err != nil {
|
||||
return GetResponse{}, fmt.Errorf("marshal metadata: %w", err)
|
||||
}
|
||||
metadata = metadataJSON
|
||||
metadataMap = req.Metadata
|
||||
}
|
||||
metadataJSON, err := json.Marshal(metadataMap)
|
||||
if err != nil {
|
||||
return GetResponse{}, fmt.Errorf("marshal metadata: %w", err)
|
||||
}
|
||||
|
||||
// Update provider
|
||||
updated, err := s.queries.UpdateLlmProvider(ctx, sqlc.UpdateLlmProviderParams{
|
||||
ID: providerID,
|
||||
Name: name,
|
||||
@@ -164,7 +168,7 @@ func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Get
|
||||
ClientType: clientType,
|
||||
Icon: icon,
|
||||
Enable: enable,
|
||||
Metadata: metadata,
|
||||
Metadata: metadataJSON,
|
||||
})
|
||||
if err != nil {
|
||||
return GetResponse{}, fmt.Errorf("update provider: %w", err)
|
||||
@@ -213,8 +217,12 @@ func (s *Service) Test(ctx context.Context, id string) (TestResponse, error) {
|
||||
baseURL := strings.TrimRight(provider.BaseUrl, "/")
|
||||
|
||||
clientType := models.ClientType(provider.ClientType)
|
||||
creds, err := s.ResolveModelCredentials(ctx, provider)
|
||||
if err != nil {
|
||||
return TestResponse{}, err
|
||||
}
|
||||
|
||||
sdkProvider := models.NewSDKProvider(baseURL, provider.ApiKey, clientType, probeTimeout)
|
||||
sdkProvider := models.NewSDKProvider(baseURL, creds.APIKey, creds.CodexAccountID, clientType, probeTimeout, nil)
|
||||
|
||||
start := time.Now()
|
||||
result := sdkProvider.Test(ctx)
|
||||
@@ -238,6 +246,29 @@ func (s *Service) FetchRemoteModels(ctx context.Context, id string) ([]RemoteMod
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get provider: %w", err)
|
||||
}
|
||||
if supportsOAuth(provider) {
|
||||
catalog := openaicodex.Catalog()
|
||||
remoteModels := make([]RemoteModel, 0, len(catalog))
|
||||
for _, model := range catalog {
|
||||
compatibilities := make([]string, 0, 2)
|
||||
if model.SupportsToolCall {
|
||||
compatibilities = append(compatibilities, models.CompatToolCall)
|
||||
}
|
||||
if model.SupportsReasoning {
|
||||
compatibilities = append(compatibilities, models.CompatReasoning)
|
||||
}
|
||||
remoteModels = append(remoteModels, RemoteModel{
|
||||
ID: model.ID,
|
||||
Name: model.DisplayName,
|
||||
Object: "model",
|
||||
OwnedBy: "openai-codex",
|
||||
Type: "chat",
|
||||
Compatibilities: compatibilities,
|
||||
ReasoningEfforts: append([]string(nil), model.ReasoningEfforts...),
|
||||
})
|
||||
}
|
||||
return remoteModels, nil
|
||||
}
|
||||
|
||||
baseURL := strings.TrimRight(provider.BaseUrl, "/")
|
||||
modelsURL := fmt.Sprintf("%s/models", baseURL)
|
||||
@@ -250,7 +281,7 @@ func (s *Service) FetchRemoteModels(ctx context.Context, id string) ([]RemoteMod
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
if provider.ApiKey != "" {
|
||||
if provider.ApiKey != "" && !supportsOAuth(provider) {
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", provider.ApiKey))
|
||||
}
|
||||
|
||||
@@ -284,7 +315,6 @@ func (s *Service) toGetResponse(provider sqlc.LlmProvider) GetResponse {
|
||||
}
|
||||
}
|
||||
|
||||
// Mask API key (show only first 8 characters)
|
||||
maskedAPIKey := maskAPIKey(provider.ApiKey)
|
||||
|
||||
var icon string
|
||||
@@ -318,7 +348,6 @@ func maskAPIKey(apiKey string) string {
|
||||
}
|
||||
|
||||
// resolveUpdatedAPIKey keeps the original key when the request value matches the masked version.
|
||||
// This prevents masked placeholder values from overwriting the real stored credential.
|
||||
func resolveUpdatedAPIKey(existing string, updated *string) string {
|
||||
if updated == nil {
|
||||
return existing
|
||||
|
||||
@@ -55,12 +55,25 @@ type TestResponse struct {
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
// OAuthStatus is returned by GET /providers/:id/oauth/status.
|
||||
type OAuthStatus struct {
|
||||
Configured bool `json:"configured"`
|
||||
HasToken bool `json:"has_token"`
|
||||
Expired bool `json:"expired"`
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
CallbackURL string `json:"callback_url"`
|
||||
}
|
||||
|
||||
// RemoteModel represents a model returned by the provider's /v1/models endpoint.
|
||||
type RemoteModel struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Compatibilities []string `json:"compatibilities,omitempty"`
|
||||
ReasoningEfforts []string `json:"reasoning_efforts,omitempty"`
|
||||
}
|
||||
|
||||
// FetchModelsResponse represents the response from the provider's /v1/models endpoint.
|
||||
|
||||
@@ -90,5 +90,11 @@ func shouldSkipJWT(path string) bool {
|
||||
if strings.HasPrefix(path, "/email/oauth/callback") {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(path, "/providers/oauth/callback") {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(path, "/auth/callback") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user