diff --git a/apps/web/src/components/add-provider/index.vue b/apps/web/src/components/add-provider/index.vue index 6d45e351..d3bbdc20 100644 --- a/apps/web/src/components/add-provider/index.vue +++ b/apps/web/src/components/add-provider/index.vue @@ -47,6 +47,7 @@ @@ -68,6 +69,12 @@ +
+ {{ $t('provider.oauth.createHint') }} +
('open') const { t } = useI18n() @@ -208,11 +215,19 @@ const { mutateAsync: createProviderMutation, isLoading } = useMutation({ }) const providerSchema = toTypedSchema(z.object({ - api_key: z.string().min(1), + api_key: z.string().optional(), base_url: z.string().min(1), name: z.string().min(1), client_type: z.string().min(1), auto_import: z.boolean().optional(), +}).superRefine((value, ctx) => { + if (value.client_type !== 'openai-codex' && !value.api_key?.trim()) { + ctx.addIssue({ + code: z.ZodIssueCode.custom, + path: ['api_key'], + message: 'API key is required', + }) + } })) const form = useForm({ @@ -223,12 +238,19 @@ const form = useForm({ }, }) +watch(() => form.values.client_type, (clientType) => { + if (clientType !== 'openai-codex') return + if (!form.values.base_url) { + form.setFieldValue('base_url', 'https://chatgpt.com/backend-api') + } +}) + const createProvider = form.handleSubmit(async (value) => { await run( () => createProviderMutation(value), { fallbackMessage: t('common.saveFailed'), - onSuccess: () => { + onSuccess: () => { open.value = false form.resetForm() }, diff --git a/apps/web/src/constants/client-types.ts b/apps/web/src/constants/client-types.ts index ffcd651a..318a8d1c 100644 --- a/apps/web/src/constants/client-types.ts +++ b/apps/web/src/constants/client-types.ts @@ -15,6 +15,11 @@ export const CLIENT_TYPE_META: Record = { label: 'OpenAI Completions', hint: 'Chat Completions API (widely compatible)', }, + 'openai-codex': { + value: 'openai-codex', + label: 'OpenAI Codex', + hint: 'Codex API (OAuth, coding-optimized)', + }, 'anthropic-messages': { value: 'anthropic-messages', label: 'Anthropic Messages', diff --git a/apps/web/src/i18n/locales/en.json b/apps/web/src/i18n/locales/en.json index 3d8603bf..08737c77 100644 --- a/apps/web/src/i18n/locales/en.json +++ b/apps/web/src/i18n/locales/en.json @@ -34,9 +34,7 @@ "loadFailed": "Failed to load", "saveFailed": "Failed to save", "createdAt": "Created at", - "none": "None", - "searchTimezone": "Search timezones…", - "noTimezoneFound": "No timezone found." + "none": "None" }, "auth": { "welcome": "Welcome Back", @@ -230,7 +228,27 @@ "icon": "Icon", "iconPlaceholder": "Icon URL or identifier (optional)", "enable": "Enable", - "enableHint": "Only models from enabled providers appear in the available model list" + "enableHint": "Only models from enabled providers appear in the available model list", + "oauth": { + "title": "OpenAI OAuth", + "description": "Authorize this provider with your ChatGPT account for Codex-compatible OpenAI access.", + "createHint": "Save the provider first, then authorize it from the provider details panel.", + "authorize": "Authorize", + "authorizeFailed": "Failed to start authorization", + "authorizeSuccess": "Authorization successful", + "revoke": "Revoke", + "revokeFailed": "Failed to revoke authorization", + "revokeSuccess": "Authorization revoked", + "callback": "Callback URL", + "statusFailed": "Failed to load OAuth status", + "status": { + "checking": "Checking authorization status...", + "authorized": "Authorized", + "expired": "Authorization expired. Re-authorize to continue.", + "missing": "Not authorized yet.", + "notConfigured": "This provider is not configured for OAuth." + } + } }, "searchProvider": { "title": "Search Providers", @@ -784,9 +802,11 @@ "language": "Language", "reasoningEnabled": "Enable Reasoning", "reasoningEffort": "Reasoning Effort", + "reasoningEffortNone": "None", "reasoningEffortLow": "Low", "reasoningEffortMedium": "Medium", "reasoningEffortHigh": "High", + "reasoningEffortXHigh": "X-High", "heartbeatEnabled": "Enable Heartbeat", "heartbeatDescription": "Periodically trigger agent to check for items that need attention", "heartbeatInterval": "Heartbeat Interval (minutes)", diff --git a/apps/web/src/i18n/locales/zh.json b/apps/web/src/i18n/locales/zh.json index 4e48b425..723cda76 100644 --- a/apps/web/src/i18n/locales/zh.json +++ b/apps/web/src/i18n/locales/zh.json @@ -34,9 +34,7 @@ "loadFailed": "加载失败", "saveFailed": "保存失败", "createdAt": "创建时间", - "none": "无", - "searchTimezone": "搜索时区…", - "noTimezoneFound": "未找到时区" + "none": "无" }, "auth": { "welcome": "欢迎回来", @@ -226,7 +224,27 @@ "icon": "图标", "iconPlaceholder": "图标 URL 或标识(可选)", "enable": "启用", - "enableHint": "只有启用的供应商的模型才会出现在可用模型列表中" + "enableHint": "只有启用的供应商的模型才会出现在可用模型列表中", + "oauth": { + "title": "OpenAI OAuth", + "description": "使用你的 ChatGPT 账号为该提供商授权,以启用 Codex 兼容的 OpenAI 访问。", + "createHint": "请先保存提供商,再到详情面板完成授权。", + "authorize": "授权", + "authorizeFailed": "启动授权失败", + "authorizeSuccess": "授权成功", + "revoke": "撤销授权", + "revokeFailed": "撤销授权失败", + "revokeSuccess": "授权已撤销", + "callback": "回调地址", + "statusFailed": "加载 OAuth 状态失败", + "status": { + "checking": "正在检查授权状态...", + "authorized": "已授权", + "expired": "授权已过期,请重新授权。", + "missing": "尚未授权。", + "notConfigured": "当前提供商未正确配置 OAuth。" + } + } }, "searchProvider": { "title": "搜索提供方", @@ -780,9 +798,11 @@ "language": "语言", "reasoningEnabled": "启用推理", "reasoningEffort": "推理等级", + "reasoningEffortNone": "无", "reasoningEffortLow": "低", "reasoningEffortMedium": "中", "reasoningEffortHigh": "高", + "reasoningEffortXHigh": "超高", "heartbeatEnabled": "启用心跳", "heartbeatDescription": "定期触发 Agent 检查是否有需要关注的事项", "heartbeatInterval": "心跳间隔(分钟)", diff --git a/apps/web/src/pages/bots/components/bot-settings.vue b/apps/web/src/pages/bots/components/bot-settings.vue index 69ac8acc..d07dc28f 100644 --- a/apps/web/src/pages/bots/components/bot-settings.vue +++ b/apps/web/src/pages/bots/components/bot-settings.vue @@ -197,21 +197,6 @@ /> - -
- - -

- {{ $t('bots.timezoneInheritedHint') }} -

-
- @@ -272,15 +257,36 @@ - + + {{ $t('bots.settings.reasoningEffortNone') }} + + {{ $t('bots.settings.reasoningEffortLow') }} - + {{ $t('bots.settings.reasoningEffortMedium') }} - + {{ $t('bots.settings.reasoningEffortHigh') }} + + {{ $t('bots.settings.reasoningEffortXHigh') }} + @@ -349,20 +355,18 @@ import { SelectTrigger, SelectValue, } from '@memohai/ui' -import { reactive, computed, watch, ref } from 'vue' +import { reactive, computed, watch } from 'vue' import { useRouter } from 'vue-router' import { toast } from 'vue-sonner' import { useI18n } from 'vue-i18n' import ConfirmPopover from '@/components/confirm-popover/index.vue' -import TimezoneSelect from '@/components/timezone-select/index.vue' -import { emptyTimezoneValue } from '@/utils/timezones' import ModelSelect from './model-select.vue' import SearchProviderSelect from './search-provider-select.vue' import MemoryProviderSelect from './memory-provider-select.vue' import TtsModelSelect from './tts-model-select.vue' import BrowserContextSelect from './browser-context-select.vue' import { useQuery, useMutation, useQueryCache } from '@pinia/colada' -import { getBotsById, putBotsById, getBotsByBotIdSettings, putBotsByBotIdSettings, deleteBotsById, getModels, getProviders, getSearchProviders, getMemoryProviders, getTtsProviders, getBrowserContexts, getBotsByBotIdMemoryStatus, postBotsByBotIdMemoryRebuild } from '@memohai/sdk' +import { getBotsByBotIdSettings, putBotsByBotIdSettings, deleteBotsById, getModels, getProviders, getSearchProviders, getMemoryProviders, getTtsProviders, getBrowserContexts, getBotsByBotIdMemoryStatus, postBotsByBotIdMemoryRebuild } from '@memohai/sdk' import type { SettingsSettings } from '@memohai/sdk' import type { Ref } from 'vue' import { resolveApiErrorMessage } from '@/utils/api-error' @@ -379,15 +383,6 @@ const botIdRef = computed(() => props.botId) as Ref // ---- Data ---- const queryCache = useQueryCache() -const { data: bot } = useQuery({ - key: () => ['bot', botIdRef.value], - query: async () => { - const { data } = await getBotsById({ path: { id: botIdRef.value }, throwOnError: true }) - return data - }, - enabled: () => !!botIdRef.value, -}) - const { data: settings } = useQuery({ key: () => ['bot-settings', botIdRef.value], query: async () => { @@ -503,31 +498,6 @@ const form = reactive({ reasoning_effort: 'medium', }) -const timezone = ref('') - -const timezoneModel = computed(() => timezone.value || emptyTimezoneValue) - -function onTimezoneChange(value: string) { - timezone.value = value === emptyTimezoneValue ? '' : value -} - -watch(bot, (val) => { - if (val) { - timezone.value = val.timezone || '' - } -}, { immediate: true }) - -const { mutateAsync: updateBot } = useMutation({ - mutation: async ({ id, ...body }: Record & { id: string }) => { - const { data } = await putBotsById({ path: { id }, body, throwOnError: true }) - return data - }, - onSettled: () => { - queryCache.invalidateQueries({ key: ['bots'] }) - queryCache.invalidateQueries({ key: ['bot'] }) - }, -}) - const selectedMemoryProvider = computed(() => memoryProviders.value.find((provider) => provider.id === form.memory_provider_id), ) @@ -582,6 +552,20 @@ const chatModelSupportsReasoning = computed(() => { return !!m?.config?.compatibilities?.includes('reasoning') }) +const availableReasoningEfforts = computed(() => { + if (!form.chat_model_id) return ['low', 'medium', 'high'] + const model = models.value.find((m) => m.id === form.chat_model_id) + const efforts = ((model?.config as { reasoning_efforts?: string[] } | undefined)?.reasoning_efforts ?? []) + .filter((effort) => ['none', 'low', 'medium', 'high', 'xhigh'].includes(effort)) + return efforts.length > 0 ? efforts : ['low', 'medium', 'high'] +}) + +watch(availableReasoningEfforts, (efforts) => { + if (!efforts.includes(form.reasoning_effort)) { + form.reasoning_effort = efforts.includes('medium') ? 'medium' : efforts[0] ?? 'medium' + } +}, { immediate: true }) + const { data: memoryStatusData, isLoading: isMemoryStatusLoading } = useQuery({ key: () => ['bot-memory-status', botIdRef.value, persistedMemoryProviderID.value], query: async () => { @@ -641,12 +625,10 @@ watch(settings, (val) => { }, { immediate: true }) const hasChanges = computed(() => { - const timezoneChanged = timezone.value !== (bot.value?.timezone || '') if (!settings.value) return true const s = settings.value let changed = - timezoneChanged - || form.chat_model_id !== (s.chat_model_id ?? '') + form.chat_model_id !== (s.chat_model_id ?? '') || form.title_model_id !== (s.title_model_id ?? '') || form.search_provider_id !== (s.search_provider_id ?? '') || form.memory_provider_id !== (s.memory_provider_id ?? '') @@ -662,11 +644,7 @@ const hasChanges = computed(() => { async function handleSave() { try { - const promises: Promise[] = [updateSettings({ ...form })] - if (timezone.value !== (bot.value?.timezone || '')) { - promises.push(updateBot({ id: botIdRef.value, timezone: timezone.value })) - } - await Promise.all(promises) + await updateSettings({ ...form }) toast.success(t('bots.settings.saveSuccess')) } catch { return diff --git a/apps/web/src/pages/models/components/model-item.vue b/apps/web/src/pages/models/components/model-item.vue index 8e67f249..58f48685 100644 --- a/apps/web/src/pages/models/components/model-item.vue +++ b/apps/web/src/pages/models/components/model-item.vue @@ -30,6 +30,14 @@ > {{ $t(`models.compatibility.${cap}`, cap) }} + + {{ effort }} + (null) +const reasoningEfforts = computed(() => ((props.model.config as ModelConfigWithReasoning | undefined)?.reasoning_efforts ?? [])) const statusDotClass = computed(() => { switch (testResult.value?.status) { diff --git a/apps/web/src/pages/models/components/provider-form.vue b/apps/web/src/pages/models/components/provider-form.vue index 2e78a5dd..c19b9cc0 100644 --- a/apps/web/src/pages/models/components/provider-form.vue +++ b/apps/web/src/pages/models/components/provider-form.vue @@ -44,7 +44,10 @@
-
+

{{ $t('provider.apiKey') }}

@@ -56,7 +59,7 @@ @@ -106,6 +109,67 @@
+ +
+
+
+ {{ $t('provider.oauth.title') }} +
+
+ {{ $t('provider.oauth.description') }} +
+
+ + + + + +
+
+ {{ $t('provider.oauth.callback') }}: {{ oauthStatus.callback_url }} +
+
+
+ + + {{ $t('provider.oauth.authorize') }} + + + {{ $t('provider.oauth.revoke') }} + +
+
@@ -206,11 +270,22 @@ import { useForm } from 'vee-validate' import { postProvidersByIdTest } from '@memohai/sdk' import type { ProvidersGetResponse, ProvidersTestResponse } from '@memohai/sdk' import { useI18n } from 'vue-i18n' +import { toast } from 'vue-sonner' const { t } = useI18n() +type ProviderWithAuth = Partial + +type ProviderOAuthStatus = { + configured: boolean + has_token: boolean + expired: boolean + callback_url?: string + expires_at?: string +} + const props = defineProps<{ - provider: Partial | undefined + provider: ProviderWithAuth | undefined editLoading: boolean deleteLoading: boolean }>() @@ -223,6 +298,13 @@ const emit = defineEmits<{ const testLoading = ref(false) const testResult = ref(null) const testError = ref('') +const oauthStatus = ref(null) +const oauthStatusLoading = ref(false) +const authorizeLoading = ref(false) +const revokeLoading = ref(false) +const apiBase = import.meta.env.VITE_API_URL?.trim() || '/api' + +const providerWithAuth = computed(() => props.provider as ProviderWithAuth | undefined) async function runTest() { if (!props.provider?.id) return @@ -265,6 +347,14 @@ const providerSchema = toTypedSchema(z.object({ metadata: z.object({ additionalProp1: z.object({}), }), +}).superRefine((value, ctx) => { + if (value.client_type !== 'openai-codex' && !value.api_key?.trim() && !providerWithAuth.value?.api_key) { + ctx.addIssue({ + code: z.ZodIssueCode.custom, + path: ['api_key'], + message: 'API key is required', + }) + } })) const form = useForm({ @@ -283,6 +373,24 @@ watch(() => props.provider, (newVal) => { } }, { immediate: true }) +watch(() => form.values.client_type, (clientType) => { + if (clientType !== 'openai-codex') { + oauthStatus.value = null + return + } + if (!form.values.base_url) { + form.setFieldValue('base_url', 'https://chatgpt.com/backend-api') + } +}) + +watch(() => [props.provider?.id, form.values.client_type] as const, async ([id, clientType]) => { + if (!id || clientType !== 'openai-codex') { + oauthStatus.value = null + return + } + await fetchOAuthStatus() +}, { immediate: true }) + const hasChanges = computed(() => { const raw = props.provider const baseChanged = JSON.stringify({ @@ -316,4 +424,78 @@ const editProvider = form.handleSubmit(async (value) => { } emit('submit', payload) }) + +const oauthExpired = computed(() => Boolean(oauthStatus.value?.has_token && oauthStatus.value?.expired)) +const canAuthorizeOAuth = computed(() => + Boolean( + props.provider?.id + && form.values.client_type === 'openai-codex', + ) && !oauthStatusLoading.value, +) + +function authHeaders(): Record { + const token = localStorage.getItem('token') + return token ? { Authorization: `Bearer ${token}` } : {} +} + +async function fetchOAuthStatus() { + if (!props.provider?.id) return + oauthStatusLoading.value = true + try { + const response = await fetch(`${apiBase}/providers/${props.provider.id}/oauth/status`, { + headers: authHeaders(), + }) + if (!response.ok) throw new Error(t('provider.oauth.statusFailed')) + oauthStatus.value = await response.json() as ProviderOAuthStatus + } catch (error) { + oauthStatus.value = null + console.error('failed to load provider oauth status', error) + } finally { + oauthStatusLoading.value = false + } +} + +async function handleAuthorize() { + if (!props.provider?.id) return + authorizeLoading.value = true + try { + const response = await fetch(`${apiBase}/providers/${props.provider.id}/oauth/authorize`, { + headers: authHeaders(), + }) + if (!response.ok) throw new Error(t('provider.oauth.authorizeFailed')) + const data = await response.json() as { auth_url?: string } + if (!data.auth_url) throw new Error(t('provider.oauth.authorizeFailed')) + const popup = window.open(data.auth_url, 'provider-oauth', 'width=600,height=720') + const listener = async (event: MessageEvent) => { + if (event.data?.type !== 'memoh-provider-oauth-success') return + window.removeEventListener('message', listener) + popup?.close() + toast.success(t('provider.oauth.authorizeSuccess')) + await fetchOAuthStatus() + } + window.addEventListener('message', listener) + } catch (error) { + toast.error(error instanceof Error ? error.message : t('provider.oauth.authorizeFailed')) + } finally { + authorizeLoading.value = false + } +} + +async function handleRevoke() { + if (!props.provider?.id) return + revokeLoading.value = true + try { + const response = await fetch(`${apiBase}/providers/${props.provider.id}/oauth/token`, { + method: 'DELETE', + headers: authHeaders(), + }) + if (!response.ok) throw new Error(t('provider.oauth.revokeFailed')) + toast.success(t('provider.oauth.revokeSuccess')) + await fetchOAuthStatus() + } catch (error) { + toast.error(error instanceof Error ? error.message : t('provider.oauth.revokeFailed')) + } finally { + revokeLoading.value = false + } +} diff --git a/cmd/agent/main.go b/cmd/agent/main.go index a47129eb..9ba59d9d 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -165,7 +165,7 @@ func runServe() { accounts.NewService, acl.NewService, settings.NewService, - providers.NewService, + provideProvidersService, searchproviders.NewService, browsercontexts.NewService, policy.NewService, @@ -228,6 +228,7 @@ func runServe() { provideServerHandler(provideSessionHandler), provideServerHandler(handlers.NewSwaggerHandler), provideServerHandler(handlers.NewProvidersHandler), + provideServerHandler(handlers.NewProviderOAuthHandler), provideServerHandler(handlers.NewSearchProvidersHandler), provideServerHandler(handlers.NewModelsHandler), provideServerHandler(handlers.NewSettingsHandler), @@ -765,6 +766,15 @@ func provideEmailRegistry(log *slog.Logger, tokenStore *emailpkg.DBOAuthTokenSto return reg } +func provideProvidersService(log *slog.Logger, queries *dbsqlc.Queries, cfg config.Config) *providers.Service { + _ = cfg + return providers.NewService(log, queries, defaultProviderOAuthCallbackURL()) +} + +func defaultProviderOAuthCallbackURL() string { + return "http://localhost:1455/auth/callback" +} + func provideEmailOAuthHandler(log *slog.Logger, service *emailpkg.Service, tokenStore *emailpkg.DBOAuthTokenStore, cfg config.Config) *handlers.EmailOAuthHandler { addr := strings.TrimSpace(cfg.Server.Addr) if addr == "" { diff --git a/cmd/memoh/serve.go b/cmd/memoh/serve.go index b3d0bdfb..bc40aceb 100644 --- a/cmd/memoh/serve.go +++ b/cmd/memoh/serve.go @@ -106,7 +106,7 @@ func runServe() { accounts.NewService, acl.NewService, settings.NewService, - providers.NewService, + provideProvidersService, searchproviders.NewService, policy.NewService, mcp.NewConnectionService, @@ -154,6 +154,7 @@ func runServe() { provideServerHandler(provideSessionHandler), provideServerHandler(handlers.NewSwaggerHandler), provideServerHandler(handlers.NewProvidersHandler), + provideServerHandler(handlers.NewProviderOAuthHandler), provideServerHandler(handlers.NewSearchProvidersHandler), provideServerHandler(handlers.NewModelsHandler), provideServerHandler(handlers.NewSettingsHandler), @@ -858,6 +859,15 @@ func provideEmailRegistry(log *slog.Logger, tokenStore *emailpkg.DBOAuthTokenSto return reg } +func provideProvidersService(log *slog.Logger, queries *dbsqlc.Queries, cfg config.Config) *providers.Service { + _ = cfg + return providers.NewService(log, queries, defaultProviderOAuthCallbackURL()) +} + +func defaultProviderOAuthCallbackURL() string { + return "http://localhost:1455/auth/callback" +} + func provideEmailOAuthHandler(log *slog.Logger, service *emailpkg.Service, tokenStore *emailpkg.DBOAuthTokenStore, cfg config.Config) *handlers.EmailOAuthHandler { addr := strings.TrimSpace(cfg.Server.Addr) if addr == "" { diff --git a/conf/providers/codex.yaml b/conf/providers/codex.yaml new file mode 100644 index 00000000..618bfae3 --- /dev/null +++ b/conf/providers/codex.yaml @@ -0,0 +1,41 @@ +name: OpenAI Codex +client_type: openai-codex +icon: openai +base_url: https://chatgpt.com/backend-api + +models: + - model_id: gpt-5.2 + name: GPT-5.2 + type: chat + config: + compatibilities: [tool-call, reasoning] + + - model_id: gpt-5.2-codex + name: GPT-5.2 Codex + type: chat + config: + compatibilities: [tool-call, reasoning] + + - model_id: gpt-5.1-codex + name: GPT-5.1 Codex + type: chat + config: + compatibilities: [tool-call, reasoning] + + - model_id: gpt-5.1-codex-max + name: GPT-5.1 Codex Max + type: chat + config: + compatibilities: [tool-call, reasoning] + + - model_id: gpt-5.1-codex-mini + name: GPT-5.1 Codex Mini + type: chat + config: + compatibilities: [tool-call, reasoning] + + - model_id: gpt-5.1 + name: GPT-5.1 + type: chat + config: + compatibilities: [tool-call, reasoning] diff --git a/db/migrations/0001_init.up.sql b/db/migrations/0001_init.up.sql index 8069c044..6b652fa2 100644 --- a/db/migrations/0001_init.up.sql +++ b/db/migrations/0001_init.up.sql @@ -69,7 +69,7 @@ CREATE TABLE IF NOT EXISTS llm_providers ( created_at TIMESTAMPTZ NOT NULL DEFAULT now(), updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), CONSTRAINT llm_providers_name_unique UNIQUE (name), - CONSTRAINT llm_providers_client_type_check CHECK (client_type IN ('openai-responses', 'openai-completions', 'anthropic-messages', 'google-generative-ai')) + CONSTRAINT llm_providers_client_type_check CHECK (client_type IN ('openai-responses', 'openai-completions', 'anthropic-messages', 'google-generative-ai', 'openai-codex')) ); CREATE TABLE IF NOT EXISTS search_providers ( diff --git a/db/migrations/0046_llm_provider_oauth.down.sql b/db/migrations/0046_llm_provider_oauth.down.sql new file mode 100644 index 00000000..9d425405 --- /dev/null +++ b/db/migrations/0046_llm_provider_oauth.down.sql @@ -0,0 +1,5 @@ +-- 0046_llm_provider_oauth (rollback) +-- Remove OAuth token storage for LLM providers. + +DROP INDEX IF EXISTS idx_llm_provider_oauth_tokens_state; +DROP TABLE IF EXISTS llm_provider_oauth_tokens; diff --git a/db/migrations/0046_llm_provider_oauth.up.sql b/db/migrations/0046_llm_provider_oauth.up.sql new file mode 100644 index 00000000..e556976e --- /dev/null +++ b/db/migrations/0046_llm_provider_oauth.up.sql @@ -0,0 +1,18 @@ +-- 0046_llm_provider_oauth +-- Add OAuth token storage for LLM providers to support OpenAI Codex OAuth. + +CREATE TABLE IF NOT EXISTS llm_provider_oauth_tokens ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + llm_provider_id UUID NOT NULL UNIQUE REFERENCES llm_providers(id) ON DELETE CASCADE, + access_token TEXT NOT NULL DEFAULT '', + refresh_token TEXT NOT NULL DEFAULT '', + expires_at TIMESTAMPTZ, + scope TEXT NOT NULL DEFAULT '', + token_type TEXT NOT NULL DEFAULT '', + state TEXT NOT NULL DEFAULT '', + pkce_code_verifier TEXT NOT NULL DEFAULT '', + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE INDEX IF NOT EXISTS idx_llm_provider_oauth_tokens_state ON llm_provider_oauth_tokens(state) WHERE state != ''; diff --git a/db/migrations/0047_add_openai_codex_client_type.down.sql b/db/migrations/0047_add_openai_codex_client_type.down.sql new file mode 100644 index 00000000..6eb02162 --- /dev/null +++ b/db/migrations/0047_add_openai_codex_client_type.down.sql @@ -0,0 +1,11 @@ +-- 0047_add_openai_codex_client_type (rollback) +-- Revert openai-codex rows back to openai-responses and restore the old CHECK constraint. + +UPDATE llm_providers +SET client_type = 'openai-responses', + updated_at = now() +WHERE client_type = 'openai-codex'; + +ALTER TABLE llm_providers DROP CONSTRAINT IF EXISTS llm_providers_client_type_check; +ALTER TABLE llm_providers ADD CONSTRAINT llm_providers_client_type_check + CHECK (client_type IN ('openai-responses', 'openai-completions', 'anthropic-messages', 'google-generative-ai')); diff --git a/db/migrations/0047_add_openai_codex_client_type.up.sql b/db/migrations/0047_add_openai_codex_client_type.up.sql new file mode 100644 index 00000000..37d7b527 --- /dev/null +++ b/db/migrations/0047_add_openai_codex_client_type.up.sql @@ -0,0 +1,12 @@ +-- 0047_add_openai_codex_client_type +-- Add openai-codex as a first-class client_type and migrate existing codex-oauth providers. + +ALTER TABLE llm_providers DROP CONSTRAINT IF EXISTS llm_providers_client_type_check; +ALTER TABLE llm_providers ADD CONSTRAINT llm_providers_client_type_check + CHECK (client_type IN ('openai-responses', 'openai-completions', 'anthropic-messages', 'google-generative-ai', 'openai-codex')); + +UPDATE llm_providers +SET client_type = 'openai-codex', + updated_at = now() +WHERE client_type = 'openai-responses' + AND metadata->>'auth_type' = 'openai-codex-oauth'; diff --git a/db/queries/llm_provider_oauth.sql b/db/queries/llm_provider_oauth.sql new file mode 100644 index 00000000..48499714 --- /dev/null +++ b/db/queries/llm_provider_oauth.sql @@ -0,0 +1,52 @@ +-- name: UpsertLlmProviderOAuthToken :one +INSERT INTO llm_provider_oauth_tokens ( + llm_provider_id, + access_token, + refresh_token, + expires_at, + scope, + token_type, + state, + pkce_code_verifier +) +VALUES ( + sqlc.arg(llm_provider_id), + sqlc.arg(access_token), + sqlc.arg(refresh_token), + sqlc.arg(expires_at), + sqlc.arg(scope), + sqlc.arg(token_type), + sqlc.arg(state), + sqlc.arg(pkce_code_verifier) +) +ON CONFLICT (llm_provider_id) DO UPDATE SET + access_token = EXCLUDED.access_token, + refresh_token = EXCLUDED.refresh_token, + expires_at = EXCLUDED.expires_at, + scope = EXCLUDED.scope, + token_type = EXCLUDED.token_type, + state = EXCLUDED.state, + pkce_code_verifier = EXCLUDED.pkce_code_verifier, + updated_at = now() +RETURNING *; + +-- name: GetLlmProviderOAuthTokenByProvider :one +SELECT * FROM llm_provider_oauth_tokens WHERE llm_provider_id = sqlc.arg(llm_provider_id); + +-- name: GetLlmProviderOAuthTokenByState :one +SELECT * FROM llm_provider_oauth_tokens WHERE state = sqlc.arg(state) AND state != ''; + +-- name: UpdateLlmProviderOAuthState :exec +INSERT INTO llm_provider_oauth_tokens (llm_provider_id, state, pkce_code_verifier) +VALUES ( + sqlc.arg(llm_provider_id), + sqlc.arg(state), + sqlc.arg(pkce_code_verifier) +) +ON CONFLICT (llm_provider_id) DO UPDATE SET + state = EXCLUDED.state, + pkce_code_verifier = EXCLUDED.pkce_code_verifier, + updated_at = now(); + +-- name: DeleteLlmProviderOAuthToken :exec +DELETE FROM llm_provider_oauth_tokens WHERE llm_provider_id = sqlc.arg(llm_provider_id); diff --git a/devenv/Dockerfile.server b/devenv/Dockerfile.server index ad4f094e..60582207 100644 --- a/devenv/Dockerfile.server +++ b/devenv/Dockerfile.server @@ -59,6 +59,6 @@ RUN chmod +x /entrypoint.sh VOLUME ["/var/lib/containerd", "/opt/memoh/data"] -EXPOSE 8080 +EXPOSE 8080 1455 ENTRYPOINT ["/entrypoint.sh"] diff --git a/devenv/docker-compose.yml b/devenv/docker-compose.yml index 4a733467..18903646 100644 --- a/devenv/docker-compose.yml +++ b/devenv/docker-compose.yml @@ -95,6 +95,7 @@ services: - /etc/localtime:/etc/localtime:ro ports: - "${MEMOH_DEV_SERVER_PORT:-18080}:8080" + - "${MEMOH_DEV_OAUTH_PORT:-1455}:8080" healthcheck: test: ["CMD-SHELL", "wget --no-verbose --tries=1 --spider http://127.0.0.1:8080/health || exit 1"] interval: 5s diff --git a/docker-compose.yml b/docker-compose.yml index 1820e3da..b512db42 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -47,6 +47,7 @@ services: - /etc/localtime:/etc/localtime:ro ports: - "8080:8080" + - "1455:8080" depends_on: migrate: condition: service_completed_successfully diff --git a/docker/Dockerfile.server b/docker/Dockerfile.server index 8515b5dc..40484cd5 100644 --- a/docker/Dockerfile.server +++ b/docker/Dockerfile.server @@ -117,7 +117,7 @@ RUN mkdir -p /opt/memoh/data /run/containerd /var/lib/containerd VOLUME ["/var/lib/containerd", "/opt/memoh/data"] -EXPOSE 8080 +EXPOSE 8080 1455 HEALTHCHECK --interval=30s --timeout=3s --start-period=30s --retries=3 \ CMD wget --no-verbose --tries=1 --spider http://127.0.0.1:8080/health \ diff --git a/go.mod b/go.mod index faa51901..3d840ff7 100644 --- a/go.mod +++ b/go.mod @@ -26,7 +26,7 @@ require ( github.com/larksuite/oapi-sdk-go/v3 v3.5.3 github.com/mailgun/mailgun-go/v5 v5.14.0 github.com/memohai/acgo v0.0.0-20260221232113-babac0d6acd7 - github.com/memohai/twilight-ai v0.3.3-0.20260321100646-43c789b701dd + github.com/memohai/twilight-ai v0.3.4-0.20260326121718-a9628b948584 github.com/modelcontextprotocol/go-sdk v1.4.1 github.com/opencontainers/image-spec v1.1.1 github.com/opencontainers/runtime-spec v1.3.0 @@ -36,6 +36,7 @@ require ( github.com/stretchr/testify v1.11.1 github.com/swaggo/swag v1.16.6 github.com/wneessen/go-mail v0.7.2 + github.com/yuin/goldmark v1.7.13 go.uber.org/fx v1.24.0 golang.org/x/crypto v0.48.0 golang.org/x/oauth2 v0.35.0 @@ -123,7 +124,6 @@ require ( github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasttemplate v1.2.2 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect - github.com/yuin/goldmark v1.7.13 // indirect go.opencensus.io v0.24.0 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.65.0 // indirect @@ -143,3 +143,4 @@ require ( golang.org/x/tools v0.42.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20260209200024-4cfbd4190f57 // indirect ) + diff --git a/go.sum b/go.sum index 8d73e22e..3cac68c7 100644 --- a/go.sum +++ b/go.sum @@ -228,8 +228,8 @@ github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D github.com/mattn/go-runewidth v0.0.10/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk= github.com/memohai/acgo v0.0.0-20260221232113-babac0d6acd7 h1:beehwOQperqGWj4m4EhcPhnSZKtDiuHK/7ZMoTPaQjw= github.com/memohai/acgo v0.0.0-20260221232113-babac0d6acd7/go.mod h1:OvmxM7JmnXBmwJWWVqtreL3HSHSKuzPbtbhlg5MvBg0= -github.com/memohai/twilight-ai v0.3.3-0.20260321100646-43c789b701dd h1:uV7xsqYHYpEmT6xKvkOs5mHT5oEKnwV1F93ialqi78k= -github.com/memohai/twilight-ai v0.3.3-0.20260321100646-43c789b701dd/go.mod h1:vHNoRb6/quMacMAgIp838aoiNhsZbE0bFCnRRNyRwNc= +github.com/memohai/twilight-ai v0.3.4-0.20260326121718-a9628b948584 h1:zu4T54unBe8ziIlTr8gUtFoR16c6u2G1Qpx8OGsenxo= +github.com/memohai/twilight-ai v0.3.4-0.20260326121718-a9628b948584/go.mod h1:GZTT9GUT3uSs6zram/FcF24GLTZMFSpiybbYmjr+gH8= github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo= github.com/moby/locker v1.0.1 h1:fOXqR41zeveg4fFODix+1Ch4mj/gT0NE1XJbp/epuBg= diff --git a/internal/agent/spawn_adapter.go b/internal/agent/spawn_adapter.go index 03ff3ab4..61bfe2bc 100644 --- a/internal/agent/spawn_adapter.go +++ b/internal/agent/spawn_adapter.go @@ -2,6 +2,7 @@ package agent import ( "context" + "net/http" sdk "github.com/memohai/twilight-ai/sdk" @@ -69,14 +70,17 @@ func SpawnSystemPrompt(sessionType string) string { }) } -// SpawnModelCreatorFunc returns a tools.ModelCreator that delegates to models.NewSDKChatModel. +// SpawnModelCreatorFunc returns a tools.ModelCreator backed by the shared SDK model factory. +// This keeps subagent model creation aligned with the shared SDK model factory. func SpawnModelCreatorFunc() tools.ModelCreator { - return func(modelID, clientType, apiKey, baseURL string) *sdk.Model { + return func(modelID, clientType, apiKey, codexAccountID, baseURL string, httpClient *http.Client) *sdk.Model { return models.NewSDKChatModel(models.SDKModelConfig{ - ModelID: modelID, - ClientType: clientType, - APIKey: apiKey, - BaseURL: baseURL, + ModelID: modelID, + ClientType: clientType, + APIKey: apiKey, + CodexAccountID: codexAccountID, + BaseURL: baseURL, + HTTPClient: httpClient, }) } } diff --git a/internal/agent/tools/subagent.go b/internal/agent/tools/subagent.go index 03406489..7100d7dd 100644 --- a/internal/agent/tools/subagent.go +++ b/internal/agent/tools/subagent.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "log/slog" + "net/http" "strings" "sync" @@ -14,6 +15,7 @@ import ( "github.com/memohai/memoh/internal/db/sqlc" messagepkg "github.com/memohai/memoh/internal/message" "github.com/memohai/memoh/internal/models" + "github.com/memohai/memoh/internal/providers" sessionpkg "github.com/memohai/memoh/internal/session" "github.com/memohai/memoh/internal/settings" ) @@ -292,7 +294,7 @@ func (p *SpawnProvider) persistMessages( } // ModelCreator creates an sdk.Model from provider config. Set via SetModelCreator. -type ModelCreator func(modelID, clientType, apiKey, baseURL string) *sdk.Model +type ModelCreator func(modelID, clientType, apiKey, codexAccountID, baseURL string, httpClient *http.Client) *sdk.Model // SetModelCreator injects the function used to create SDK models // (typically agent.CreateModel wrapped to match the signature). @@ -323,7 +325,19 @@ func (p *SpawnProvider) resolveModel(ctx context.Context, botID string) (*sdk.Mo if p.modelCreator == nil { return nil, "", errors.New("model creator not configured") } - sdkModel := p.modelCreator(modelInfo.ModelID, provider.ClientType, provider.ApiKey, provider.BaseUrl) + authResolver := providers.NewService(nil, p.queries, "") + creds, err := authResolver.ResolveModelCredentials(ctx, provider) + if err != nil { + return nil, "", err + } + sdkModel := p.modelCreator( + modelInfo.ModelID, + provider.ClientType, + creds.APIKey, + creds.CodexAccountID, + provider.BaseUrl, + nil, + ) return sdkModel, modelInfo.ID, nil } diff --git a/internal/agent/types.go b/internal/agent/types.go index 14d0f2d9..c0db8404 100644 --- a/internal/agent/types.go +++ b/internal/agent/types.go @@ -2,6 +2,7 @@ package agent import ( "encoding/json" + "net/http" "time" sdk "github.com/memohai/twilight-ai/sdk" @@ -95,6 +96,23 @@ type SystemFile struct { Content string } +// ModelConfig holds provider and model information resolved from DB. +type ModelConfig struct { + ModelID string + ClientType string + APIKey string //nolint:gosec // carries provider credential material at runtime + CodexAccountID string + BaseURL string + HTTPClient *http.Client + ReasoningConfig *ReasoningConfig +} + +// ReasoningConfig controls extended thinking/reasoning behavior. +type ReasoningConfig struct { + Enabled bool + Effort string +} + func mustMarshal(v any) json.RawMessage { data, err := json.Marshal(v) if err != nil { diff --git a/internal/compaction/service.go b/internal/compaction/service.go index df16c1c6..61aabe82 100644 --- a/internal/compaction/service.go +++ b/internal/compaction/service.go @@ -104,10 +104,12 @@ func (s *Service) doCompaction(ctx context.Context, logID pgtype.UUID, sessionUU userPrompt := buildUserPrompt(priorSummaries, entries) model := models.NewSDKChatModel(models.SDKModelConfig{ - ClientType: cfg.ClientType, - BaseURL: cfg.BaseURL, - APIKey: cfg.APIKey, - ModelID: cfg.ModelID, + ClientType: cfg.ClientType, + BaseURL: cfg.BaseURL, + APIKey: cfg.APIKey, + CodexAccountID: cfg.CodexAccountID, + ModelID: cfg.ModelID, + HTTPClient: cfg.HTTPClient, }) result, err := sdk.GenerateTextResult(ctx, diff --git a/internal/compaction/types.go b/internal/compaction/types.go index 4d13b628..8a352b37 100644 --- a/internal/compaction/types.go +++ b/internal/compaction/types.go @@ -1,6 +1,9 @@ package compaction -import "time" +import ( + "net/http" + "time" +) // Log represents a compaction log entry. type Log struct { @@ -24,10 +27,12 @@ type ListLogsResponse struct { // TriggerConfig holds the parameters needed to trigger a compaction. type TriggerConfig struct { - BotID string - SessionID string - ModelID string - ClientType string - APIKey string //nolint:gosec // runtime credential, not a hardcoded secret - BaseURL string + BotID string + SessionID string + ModelID string + ClientType string + APIKey string //nolint:gosec // runtime credential, not a hardcoded secret + CodexAccountID string + BaseURL string + HTTPClient *http.Client } diff --git a/internal/conversation/flow/resolver.go b/internal/conversation/flow/resolver.go index 16f35527..6c3fac81 100644 --- a/internal/conversation/flow/resolver.go +++ b/internal/conversation/flow/resolver.go @@ -5,6 +5,7 @@ import ( "encoding/base64" "encoding/json" "errors" + "fmt" "io" "log/slog" "math" @@ -24,6 +25,7 @@ import ( messagepkg "github.com/memohai/memoh/internal/message" messageevent "github.com/memohai/memoh/internal/message/event" "github.com/memohai/memoh/internal/models" + "github.com/memohai/memoh/internal/providers" "github.com/memohai/memoh/internal/settings" ) @@ -278,10 +280,17 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r } } + authResolver := providers.NewService(nil, r.queries, "") + creds, err := authResolver.ResolveModelCredentials(ctx, provider) + if err != nil { + return resolvedContext{}, fmt.Errorf("resolve provider credentials: %w", err) + } + modelCfg := models.SDKModelConfig{ ModelID: chatModel.ModelID, ClientType: clientType, - APIKey: provider.ApiKey, + APIKey: creds.APIKey, + CodexAccountID: creds.CodexAccountID, BaseURL: provider.BaseUrl, ReasoningConfig: reasoningConfig, } diff --git a/internal/conversation/flow/resolver_compaction.go b/internal/conversation/flow/resolver_compaction.go index 98b19370..f09c7ab1 100644 --- a/internal/conversation/flow/resolver_compaction.go +++ b/internal/conversation/flow/resolver_compaction.go @@ -7,6 +7,7 @@ import ( "github.com/memohai/memoh/internal/compaction" "github.com/memohai/memoh/internal/conversation" "github.com/memohai/memoh/internal/models" + "github.com/memohai/memoh/internal/providers" ) func (r *Resolver) maybeCompact(ctx context.Context, req conversation.ChatRequest, rc resolvedContext, inputTokens int) { @@ -47,8 +48,15 @@ func (r *Resolver) maybeCompact(ctx context.Context, req conversation.ChatReques r.logger.Warn("compaction: failed to fetch provider", slog.Any("error", err)) return } + authResolver := providers.NewService(nil, r.queries, "") + creds, err := authResolver.ResolveModelCredentials(ctx, provider) + if err != nil { + r.logger.Warn("compaction: failed to resolve provider credentials", slog.Any("error", err)) + return + } cfg.ClientType = provider.ClientType - cfg.APIKey = provider.ApiKey + cfg.APIKey = creds.APIKey + cfg.CodexAccountID = creds.CodexAccountID cfg.BaseURL = provider.BaseUrl r.compactionService.TriggerCompaction(ctx, cfg) diff --git a/internal/conversation/flow/resolver_title.go b/internal/conversation/flow/resolver_title.go index f1453d40..e73aaf4c 100644 --- a/internal/conversation/flow/resolver_title.go +++ b/internal/conversation/flow/resolver_title.go @@ -13,6 +13,7 @@ import ( "github.com/memohai/memoh/internal/db/sqlc" messageevent "github.com/memohai/memoh/internal/message/event" "github.com/memohai/memoh/internal/models" + "github.com/memohai/memoh/internal/providers" "github.com/memohai/memoh/internal/session" ) @@ -104,11 +105,19 @@ func (r *Resolver) generateTitle(ctx context.Context, model models.GetResponse, "Return ONLY the title text, nothing else.\n\n" + "User: " + userSnippet + authResolver := providers.NewService(nil, r.queries, "") + creds, err := authResolver.ResolveModelCredentials(ctx, provider) + if err != nil { + r.logger.Warn("title gen: failed to resolve provider credentials", slog.Any("error", err)) + return "" + } + modelCfg := models.SDKModelConfig{ - ModelID: model.ModelID, - ClientType: provider.ClientType, - APIKey: provider.ApiKey, - BaseURL: provider.BaseUrl, + ModelID: model.ModelID, + ClientType: provider.ClientType, + APIKey: creds.APIKey, + CodexAccountID: creds.CodexAccountID, + BaseURL: provider.BaseUrl, } sdkModel := models.NewSDKChatModel(modelCfg) diff --git a/internal/db/sqlc/llm_provider_oauth.sql.go b/internal/db/sqlc/llm_provider_oauth.sql.go new file mode 100644 index 00000000..cd31ac5c --- /dev/null +++ b/internal/db/sqlc/llm_provider_oauth.sql.go @@ -0,0 +1,163 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: llm_provider_oauth.sql + +package sqlc + +import ( + "context" + + "github.com/jackc/pgx/v5/pgtype" +) + +const deleteLlmProviderOAuthToken = `-- name: DeleteLlmProviderOAuthToken :exec +DELETE FROM llm_provider_oauth_tokens WHERE llm_provider_id = $1 +` + +func (q *Queries) DeleteLlmProviderOAuthToken(ctx context.Context, llmProviderID pgtype.UUID) error { + _, err := q.db.Exec(ctx, deleteLlmProviderOAuthToken, llmProviderID) + return err +} + +const getLlmProviderOAuthTokenByProvider = `-- name: GetLlmProviderOAuthTokenByProvider :one +SELECT id, llm_provider_id, access_token, refresh_token, expires_at, scope, token_type, state, pkce_code_verifier, created_at, updated_at FROM llm_provider_oauth_tokens WHERE llm_provider_id = $1 +` + +func (q *Queries) GetLlmProviderOAuthTokenByProvider(ctx context.Context, llmProviderID pgtype.UUID) (LlmProviderOauthToken, error) { + row := q.db.QueryRow(ctx, getLlmProviderOAuthTokenByProvider, llmProviderID) + var i LlmProviderOauthToken + err := row.Scan( + &i.ID, + &i.LlmProviderID, + &i.AccessToken, + &i.RefreshToken, + &i.ExpiresAt, + &i.Scope, + &i.TokenType, + &i.State, + &i.PkceCodeVerifier, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const getLlmProviderOAuthTokenByState = `-- name: GetLlmProviderOAuthTokenByState :one +SELECT id, llm_provider_id, access_token, refresh_token, expires_at, scope, token_type, state, pkce_code_verifier, created_at, updated_at FROM llm_provider_oauth_tokens WHERE state = $1 AND state != '' +` + +func (q *Queries) GetLlmProviderOAuthTokenByState(ctx context.Context, state string) (LlmProviderOauthToken, error) { + row := q.db.QueryRow(ctx, getLlmProviderOAuthTokenByState, state) + var i LlmProviderOauthToken + err := row.Scan( + &i.ID, + &i.LlmProviderID, + &i.AccessToken, + &i.RefreshToken, + &i.ExpiresAt, + &i.Scope, + &i.TokenType, + &i.State, + &i.PkceCodeVerifier, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const updateLlmProviderOAuthState = `-- name: UpdateLlmProviderOAuthState :exec +INSERT INTO llm_provider_oauth_tokens (llm_provider_id, state, pkce_code_verifier) +VALUES ( + $1, + $2, + $3 +) +ON CONFLICT (llm_provider_id) DO UPDATE SET + state = EXCLUDED.state, + pkce_code_verifier = EXCLUDED.pkce_code_verifier, + updated_at = now() +` + +type UpdateLlmProviderOAuthStateParams struct { + LlmProviderID pgtype.UUID `json:"llm_provider_id"` + State string `json:"state"` + PkceCodeVerifier string `json:"pkce_code_verifier"` +} + +func (q *Queries) UpdateLlmProviderOAuthState(ctx context.Context, arg UpdateLlmProviderOAuthStateParams) error { + _, err := q.db.Exec(ctx, updateLlmProviderOAuthState, arg.LlmProviderID, arg.State, arg.PkceCodeVerifier) + return err +} + +const upsertLlmProviderOAuthToken = `-- name: UpsertLlmProviderOAuthToken :one +INSERT INTO llm_provider_oauth_tokens ( + llm_provider_id, + access_token, + refresh_token, + expires_at, + scope, + token_type, + state, + pkce_code_verifier +) +VALUES ( + $1, + $2, + $3, + $4, + $5, + $6, + $7, + $8 +) +ON CONFLICT (llm_provider_id) DO UPDATE SET + access_token = EXCLUDED.access_token, + refresh_token = EXCLUDED.refresh_token, + expires_at = EXCLUDED.expires_at, + scope = EXCLUDED.scope, + token_type = EXCLUDED.token_type, + state = EXCLUDED.state, + pkce_code_verifier = EXCLUDED.pkce_code_verifier, + updated_at = now() +RETURNING id, llm_provider_id, access_token, refresh_token, expires_at, scope, token_type, state, pkce_code_verifier, created_at, updated_at +` + +type UpsertLlmProviderOAuthTokenParams struct { + LlmProviderID pgtype.UUID `json:"llm_provider_id"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresAt pgtype.Timestamptz `json:"expires_at"` + Scope string `json:"scope"` + TokenType string `json:"token_type"` + State string `json:"state"` + PkceCodeVerifier string `json:"pkce_code_verifier"` +} + +func (q *Queries) UpsertLlmProviderOAuthToken(ctx context.Context, arg UpsertLlmProviderOAuthTokenParams) (LlmProviderOauthToken, error) { + row := q.db.QueryRow(ctx, upsertLlmProviderOAuthToken, + arg.LlmProviderID, + arg.AccessToken, + arg.RefreshToken, + arg.ExpiresAt, + arg.Scope, + arg.TokenType, + arg.State, + arg.PkceCodeVerifier, + ) + var i LlmProviderOauthToken + err := row.Scan( + &i.ID, + &i.LlmProviderID, + &i.AccessToken, + &i.RefreshToken, + &i.ExpiresAt, + &i.Scope, + &i.TokenType, + &i.State, + &i.PkceCodeVerifier, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} diff --git a/internal/db/sqlc/models.go b/internal/db/sqlc/models.go index 627a125d..342c602d 100644 --- a/internal/db/sqlc/models.go +++ b/internal/db/sqlc/models.go @@ -292,6 +292,20 @@ type LlmProvider struct { ClientType string `json:"client_type"` } +type LlmProviderOauthToken struct { + ID pgtype.UUID `json:"id"` + LlmProviderID pgtype.UUID `json:"llm_provider_id"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresAt pgtype.Timestamptz `json:"expires_at"` + Scope string `json:"scope"` + TokenType string `json:"token_type"` + State string `json:"state"` + PkceCodeVerifier string `json:"pkce_code_verifier"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` +} + type McpConnection struct { ID pgtype.UUID `json:"id"` BotID pgtype.UUID `json:"bot_id"` diff --git a/internal/handlers/provider_oauth.go b/internal/handlers/provider_oauth.go new file mode 100644 index 00000000..619ef071 --- /dev/null +++ b/internal/handlers/provider_oauth.go @@ -0,0 +1,132 @@ +package handlers + +import ( + "html/template" + "net/http" + "strings" + + "github.com/labstack/echo/v4" + + "github.com/memohai/memoh/internal/providers" +) + +type ProviderOAuthHandler struct { + service *providers.Service +} + +func NewProviderOAuthHandler(service *providers.Service) *ProviderOAuthHandler { + return &ProviderOAuthHandler{service: service} +} + +func (h *ProviderOAuthHandler) Register(e *echo.Echo) { + e.GET("/providers/:id/oauth/authorize", h.Authorize) + e.GET("/providers/:id/oauth/status", h.Status) + e.DELETE("/providers/:id/oauth/token", h.Revoke) + e.GET("/auth/callback", h.Callback) + e.GET("/providers/oauth/callback", h.Callback) +} + +// Authorize godoc +// @Summary Start OAuth2 authorization for an LLM provider +// @Tags providers-oauth +// @Param id path string true "Provider ID (UUID)" +// @Success 200 {object} map[string]string +// @Failure 400 {object} ErrorResponse +// @Failure 404 {object} ErrorResponse +// @Router /providers/{id}/oauth/authorize [get]. +func (h *ProviderOAuthHandler) Authorize(c echo.Context) error { + providerID := strings.TrimSpace(c.Param("id")) + if providerID == "" { + return echo.NewHTTPError(http.StatusBadRequest, "id is required") + } + authURL, err := h.service.StartOAuthAuthorization(c.Request().Context(), providerID) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + return c.JSON(http.StatusOK, map[string]string{"auth_url": authURL}) +} + +// Status godoc +// @Summary Get OAuth2 status for an LLM provider +// @Tags providers-oauth +// @Param id path string true "Provider ID (UUID)" +// @Success 200 {object} providers.OAuthStatus +// @Failure 400 {object} ErrorResponse +// @Failure 404 {object} ErrorResponse +// @Router /providers/{id}/oauth/status [get]. +func (h *ProviderOAuthHandler) Status(c echo.Context) error { + providerID := strings.TrimSpace(c.Param("id")) + if providerID == "" { + return echo.NewHTTPError(http.StatusBadRequest, "id is required") + } + status, err := h.service.GetOAuthStatus(c.Request().Context(), providerID) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + return c.JSON(http.StatusOK, status) +} + +// Revoke godoc +// @Summary Revoke stored OAuth2 tokens for an LLM provider +// @Tags providers-oauth +// @Param id path string true "Provider ID (UUID)" +// @Success 204 "No Content" +// @Failure 400 {object} ErrorResponse +// @Failure 404 {object} ErrorResponse +// @Router /providers/{id}/oauth/token [delete]. +func (h *ProviderOAuthHandler) Revoke(c echo.Context) error { + providerID := strings.TrimSpace(c.Param("id")) + if providerID == "" { + return echo.NewHTTPError(http.StatusBadRequest, "id is required") + } + if err := h.service.RevokeOAuthToken(c.Request().Context(), providerID); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + return c.NoContent(http.StatusNoContent) +} + +// Callback godoc +// @Summary OAuth2 callback for LLM providers +// @Tags providers-oauth +// @Param code query string true "Authorization code" +// @Param state query string true "State parameter" +// @Success 200 {string} string "HTML success page" +// @Failure 400 {object} ErrorResponse +// @Router /providers/oauth/callback [get]. +func (h *ProviderOAuthHandler) Callback(c echo.Context) error { + code := strings.TrimSpace(c.QueryParam("code")) + state := strings.TrimSpace(c.QueryParam("state")) + if code == "" { + return echo.NewHTTPError(http.StatusBadRequest, "code is required") + } + if state == "" { + return echo.NewHTTPError(http.StatusBadRequest, "state is required") + } + providerID, err := h.service.HandleOAuthCallback(c.Request().Context(), state, code) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + + page := template.Must(template.New("oauth-success").Parse(` + + + + OpenAI OAuth Connected + + +

OpenAI OAuth connected

+

You can close this window and return to Memoh.

+ + +`)) + return c.HTML(http.StatusOK, executeHTMLTemplate(page, map[string]string{"ProviderID": providerID})) +} + +func executeHTMLTemplate(tpl *template.Template, data any) string { + var b strings.Builder + _ = tpl.Execute(&b, data) + return b.String() +} diff --git a/internal/handlers/providers.go b/internal/handlers/providers.go index b53466bd..37eaecf4 100644 --- a/internal/handlers/providers.go +++ b/internal/handlers/providers.go @@ -309,20 +309,31 @@ func (h *ProvidersHandler) ImportModels(c echo.Context) error { return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("fetch remote models: %v", err)) } - defaultCompat := []string{models.CompatVision, models.CompatToolCall, models.CompatReasoning} - resp := providers.ImportModelsResponse{ Models: make([]string, 0), } for _, m := range remoteModels { + modelType := models.ModelTypeChat + if strings.TrimSpace(m.Type) == string(models.ModelTypeEmbedding) { + modelType = models.ModelTypeEmbedding + } + compatibilities := m.Compatibilities + if len(compatibilities) == 0 { + compatibilities = []string{models.CompatVision, models.CompatToolCall, models.CompatReasoning} + } + name := strings.TrimSpace(m.Name) + if name == "" { + name = m.ID + } _, err := h.modelsService.Create(c.Request().Context(), models.AddRequest{ ModelID: m.ID, - Name: m.ID, + Name: name, LlmProviderID: id, - Type: models.ModelTypeChat, + Type: modelType, Config: models.ModelConfig{ - Compatibilities: defaultCompat, + Compatibilities: compatibilities, + ReasoningEfforts: m.ReasoningEfforts, }, }) if err != nil { diff --git a/internal/memory/adapters/builtin/dense_runtime.go b/internal/memory/adapters/builtin/dense_runtime.go index d0006d0c..92d47c11 100644 --- a/internal/memory/adapters/builtin/dense_runtime.go +++ b/internal/memory/adapters/builtin/dense_runtime.go @@ -75,7 +75,7 @@ func newDenseRuntime(providerConfig map[string]any, queries *dbsqlc.Queries, cfg return nil, fmt.Errorf("dense runtime: %w", err) } - embedModel := models.NewSDKEmbeddingModel(spec.clientType, spec.baseURL, spec.apiKey, spec.modelID, denseEmbedTimeout) + embedModel := models.NewSDKEmbeddingModel(spec.clientType, spec.baseURL, spec.apiKey, spec.modelID, denseEmbedTimeout, nil) return &denseRuntime{ qdrant: qClient, diff --git a/internal/models/embedding.go b/internal/models/embedding.go index 6ff8fd0b..831b9e44 100644 --- a/internal/models/embedding.go +++ b/internal/models/embedding.go @@ -13,11 +13,13 @@ import ( // provider configuration. It dispatches to the native Google embedding provider // when clientType is "google-generative-ai", and falls back to the // OpenAI-compatible /embeddings endpoint for all other provider types. -func NewSDKEmbeddingModel(clientType, baseURL, apiKey, modelID string, timeout time.Duration) *sdk.EmbeddingModel { +func NewSDKEmbeddingModel(clientType, baseURL, apiKey, modelID string, timeout time.Duration, httpClient *http.Client) *sdk.EmbeddingModel { if timeout <= 0 { timeout = 30 * time.Second } - httpClient := &http.Client{Timeout: timeout} + if httpClient == nil { + httpClient = &http.Client{Timeout: timeout} + } switch ClientType(clientType) { case ClientTypeGoogleGenerativeAI: @@ -30,7 +32,6 @@ func NewSDKEmbeddingModel(clientType, baseURL, apiKey, modelID string, timeout t } p := googleembedding.New(opts...) return p.EmbeddingModel(modelID) - default: opts := []openaiembedding.Option{ openaiembedding.WithAPIKey(apiKey), diff --git a/internal/models/probe.go b/internal/models/probe.go index c37dc8fe..412ecf35 100644 --- a/internal/models/probe.go +++ b/internal/models/probe.go @@ -2,6 +2,9 @@ package models import ( "context" + "encoding/base64" + "encoding/json" + "errors" "fmt" "net/http" "strings" @@ -9,11 +12,13 @@ import ( anthropicmessages "github.com/memohai/twilight-ai/provider/anthropic/messages" googlegenerative "github.com/memohai/twilight-ai/provider/google/generativeai" + openaicodex "github.com/memohai/twilight-ai/provider/openai/codex" openaicompletions "github.com/memohai/twilight-ai/provider/openai/completions" openairesponses "github.com/memohai/twilight-ai/provider/openai/responses" sdk "github.com/memohai/twilight-ai/sdk" "github.com/memohai/memoh/internal/db" + "github.com/memohai/memoh/internal/db/sqlc" ) const probeTimeout = 15 * time.Second @@ -37,16 +42,17 @@ func (s *Service) Test(ctx context.Context, id string) (TestResponse, error) { } baseURL := strings.TrimRight(provider.BaseUrl, "/") - apiKey := provider.ApiKey clientType := ClientType(provider.ClientType) - - // Embedding models don't have a chat Provider in the SDK — probe - // the /embeddings endpoint directly. - if model.Type == string(ModelTypeEmbedding) { - return s.testEmbeddingModel(ctx, string(clientType), baseURL, apiKey, model.ModelID) + creds, err := s.resolveModelCredentials(ctx, provider) + if err != nil { + return TestResponse{}, err } - sdkProvider := NewSDKProvider(baseURL, apiKey, clientType, probeTimeout) + if model.Type == string(ModelTypeEmbedding) { + return s.testEmbeddingModel(ctx, baseURL, creds.APIKey, model.ModelID, nil) + } + + sdkProvider := NewSDKProvider(baseURL, creds.APIKey, creds.CodexAccountID, clientType, probeTimeout, nil) start := time.Now() @@ -100,11 +106,11 @@ func (s *Service) Test(ctx context.Context, id string) (TestResponse, error) { // testEmbeddingModel probes an embedding model by performing a minimal // embedding request via the Twilight SDK, verifying that the model is // reachable and functional rather than merely checking HTTP connectivity. -func (*Service) testEmbeddingModel(ctx context.Context, clientType, baseURL, apiKey, modelID string) (TestResponse, error) { +func (*Service) testEmbeddingModel(ctx context.Context, baseURL, apiKey, modelID string, httpClient *http.Client) (TestResponse, error) { ctx, cancel := context.WithTimeout(ctx, probeTimeout) defer cancel() - model := NewSDKEmbeddingModel(clientType, baseURL, apiKey, modelID, probeTimeout) + model := NewSDKEmbeddingModel(string(ClientTypeOpenAICompletions), baseURL, apiKey, modelID, probeTimeout, httpClient) client := sdk.NewClient() start := time.Now() @@ -130,8 +136,10 @@ func (*Service) testEmbeddingModel(ctx context.Context, clientType, baseURL, api // NewSDKProvider creates a Twilight AI SDK Provider for the given client type. // It is exported so that other packages (e.g. providers) can reuse it for testing. -func NewSDKProvider(baseURL, apiKey string, clientType ClientType, timeout time.Duration) sdk.Provider { - httpClient := &http.Client{Timeout: timeout} +func NewSDKProvider(baseURL, apiKey, codexAccountID string, clientType ClientType, timeout time.Duration, httpClient *http.Client) sdk.Provider { + if httpClient == nil { + httpClient = &http.Client{Timeout: timeout} + } switch clientType { case ClientTypeOpenAIResponses: @@ -144,6 +152,16 @@ func NewSDKProvider(baseURL, apiKey string, clientType ClientType, timeout time. } return openairesponses.New(opts...) + case ClientTypeOpenAICodex: + opts := []openaicodex.Option{ + openaicodex.WithAccessToken(apiKey), + openaicodex.WithHTTPClient(httpClient), + } + if codexAccountID != "" { + opts = append(opts, openaicodex.WithAccountID(codexAccountID)) + } + return openaicodex.New(opts...) + case ClientTypeAnthropicMessages: opts := []anthropicmessages.Option{ anthropicmessages.WithAPIKey(apiKey), @@ -175,3 +193,55 @@ func NewSDKProvider(baseURL, apiKey string, clientType ClientType, timeout time. return openaicompletions.New(opts...) } } + +type modelCredentials struct { + APIKey string //nolint:gosec // runtime credential material used to construct SDK providers + CodexAccountID string +} + +func (s *Service) resolveModelCredentials(ctx context.Context, provider sqlc.LlmProvider) (modelCredentials, error) { + if ClientType(provider.ClientType) != ClientTypeOpenAICodex { + return modelCredentials{APIKey: provider.ApiKey}, nil + } + + tokenRow, err := s.queries.GetLlmProviderOAuthTokenByProvider(ctx, provider.ID) + if err != nil { + return modelCredentials{}, err + } + accessToken := strings.TrimSpace(tokenRow.AccessToken) + if accessToken == "" { + return modelCredentials{}, errors.New("oauth token is missing access token") + } + accountID, err := codexAccountIDFromToken(accessToken) + if err != nil { + return modelCredentials{}, err + } + return modelCredentials{ + APIKey: accessToken, + CodexAccountID: accountID, + }, nil +} + +func codexAccountIDFromToken(token string) (string, error) { + parts := strings.Split(token, ".") + if len(parts) != 3 { + return "", errors.New("invalid oauth access token") + } + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return "", fmt.Errorf("decode oauth token payload: %w", err) + } + var claims struct { + OpenAIAuth struct { + ChatGPTAccountID string `json:"chatgpt_account_id"` + } `json:"https://api.openai.com/auth"` + } + if err := json.Unmarshal(payload, &claims); err != nil { + return "", fmt.Errorf("parse oauth token payload: %w", err) + } + accountID := strings.TrimSpace(claims.OpenAIAuth.ChatGPTAccountID) + if accountID == "" { + return "", errors.New("oauth access token missing chatgpt_account_id") + } + return accountID, nil +} diff --git a/internal/models/sdk.go b/internal/models/sdk.go index d5357c6a..edc307e5 100644 --- a/internal/models/sdk.go +++ b/internal/models/sdk.go @@ -1,10 +1,12 @@ package models import ( + "net/http" "strings" anthropicmessages "github.com/memohai/twilight-ai/provider/anthropic/messages" googlegenerative "github.com/memohai/twilight-ai/provider/google/generativeai" + openaicodex "github.com/memohai/twilight-ai/provider/openai/codex" openaicompletions "github.com/memohai/twilight-ai/provider/openai/completions" openairesponses "github.com/memohai/twilight-ai/provider/openai/responses" sdk "github.com/memohai/twilight-ai/sdk" @@ -16,7 +18,9 @@ type SDKModelConfig struct { ModelID string ClientType string APIKey string //nolint:gosec // carries provider credential material at runtime + CodexAccountID string BaseURL string + HTTPClient *http.Client ReasoningConfig *ReasoningConfig } @@ -38,6 +42,9 @@ func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model { opts := []openaicompletions.Option{ openaicompletions.WithAPIKey(cfg.APIKey), } + if cfg.HTTPClient != nil { + opts = append(opts, openaicompletions.WithHTTPClient(cfg.HTTPClient)) + } if cfg.BaseURL != "" { opts = append(opts, openaicompletions.WithBaseURL(cfg.BaseURL)) } @@ -48,16 +55,34 @@ func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model { opts := []openairesponses.Option{ openairesponses.WithAPIKey(cfg.APIKey), } + if cfg.HTTPClient != nil { + opts = append(opts, openairesponses.WithHTTPClient(cfg.HTTPClient)) + } if cfg.BaseURL != "" { opts = append(opts, openairesponses.WithBaseURL(cfg.BaseURL)) } p := openairesponses.New(opts...) return p.ChatModel(cfg.ModelID) + case ClientTypeOpenAICodex: + opts := []openaicodex.Option{ + openaicodex.WithAccessToken(cfg.APIKey), + } + if cfg.HTTPClient != nil { + opts = append(opts, openaicodex.WithHTTPClient(cfg.HTTPClient)) + } + if cfg.CodexAccountID != "" { + opts = append(opts, openaicodex.WithAccountID(cfg.CodexAccountID)) + } + return openaicodex.New(opts...).ChatModel(cfg.ModelID) + case ClientTypeAnthropicMessages: opts := []anthropicmessages.Option{ anthropicmessages.WithAPIKey(cfg.APIKey), } + if cfg.HTTPClient != nil { + opts = append(opts, anthropicmessages.WithHTTPClient(cfg.HTTPClient)) + } if cfg.BaseURL != "" { opts = append(opts, anthropicmessages.WithBaseURL(cfg.BaseURL)) } @@ -75,6 +100,9 @@ func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model { opts := []googlegenerative.Option{ googlegenerative.WithAPIKey(cfg.APIKey), } + if cfg.HTTPClient != nil { + opts = append(opts, googlegenerative.WithHTTPClient(cfg.HTTPClient)) + } if cfg.BaseURL != "" { opts = append(opts, googlegenerative.WithBaseURL(cfg.BaseURL)) } @@ -85,6 +113,9 @@ func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model { opts := []openaicompletions.Option{ openaicompletions.WithAPIKey(cfg.APIKey), } + if cfg.HTTPClient != nil { + opts = append(opts, openaicompletions.WithHTTPClient(cfg.HTTPClient)) + } if cfg.BaseURL != "" { opts = append(opts, openaicompletions.WithBaseURL(cfg.BaseURL)) } @@ -106,7 +137,7 @@ func BuildReasoningOptions(cfg SDKModelConfig) []sdk.GenerateOption { switch ClientType(cfg.ClientType) { case ClientTypeAnthropicMessages: return nil - case ClientTypeOpenAIResponses, ClientTypeOpenAICompletions: + case ClientTypeOpenAIResponses, ClientTypeOpenAICompletions, ClientTypeOpenAICodex: return []sdk.GenerateOption{sdk.WithReasoningEffort(effort)} case ClientTypeGoogleGenerativeAI: return nil @@ -147,6 +178,8 @@ func ResolveClientType(model *sdk.Model) string { return string(ClientTypeAnthropicMessages) case strings.Contains(name, "google"): return string(ClientTypeGoogleGenerativeAI) + case strings.Contains(name, "codex"): + return string(ClientTypeOpenAICodex) case strings.Contains(name, "responses"): return string(ClientTypeOpenAIResponses) default: diff --git a/internal/models/types.go b/internal/models/types.go index 0514a5cf..ff8b3c6e 100644 --- a/internal/models/types.go +++ b/internal/models/types.go @@ -20,6 +20,7 @@ const ( ClientTypeOpenAICompletions ClientType = "openai-completions" ClientTypeAnthropicMessages ClientType = "anthropic-messages" ClientTypeGoogleGenerativeAI ClientType = "google-generative-ai" + ClientTypeOpenAICodex ClientType = "openai-codex" ) const ( @@ -29,16 +30,33 @@ const ( CompatReasoning = "reasoning" ) +const ( + ReasoningEffortNone = "none" + ReasoningEffortLow = "low" + ReasoningEffortMedium = "medium" + ReasoningEffortHigh = "high" + ReasoningEffortXHigh = "xhigh" +) + // validCompatibilities enumerates accepted compatibility tokens. var validCompatibilities = map[string]struct{}{ CompatVision: {}, CompatToolCall: {}, CompatImageOutput: {}, CompatReasoning: {}, } +var validReasoningEfforts = map[string]struct{}{ + ReasoningEffortNone: {}, + ReasoningEffortLow: {}, + ReasoningEffortMedium: {}, + ReasoningEffortHigh: {}, + ReasoningEffortXHigh: {}, +} + // ModelConfig holds the JSONB config stored per model. type ModelConfig struct { - Dimensions *int `json:"dimensions,omitempty"` - Compatibilities []string `json:"compatibilities,omitempty"` - ContextWindow *int `json:"context_window,omitempty"` + Dimensions *int `json:"dimensions,omitempty"` + Compatibilities []string `json:"compatibilities,omitempty"` + ContextWindow *int `json:"context_window,omitempty"` + ReasoningEfforts []string `json:"reasoning_efforts,omitempty"` } type Model struct { @@ -72,6 +90,11 @@ func (m *Model) Validate() error { return errors.New("invalid compatibility: " + c) } } + for _, effort := range m.Config.ReasoningEfforts { + if _, ok := validReasoningEfforts[effort]; !ok { + return errors.New("invalid reasoning effort: " + effort) + } + } return nil } diff --git a/internal/providers/credentials.go b/internal/providers/credentials.go new file mode 100644 index 00000000..dc5f5e8a --- /dev/null +++ b/internal/providers/credentials.go @@ -0,0 +1,69 @@ +package providers + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "strings" + + "github.com/memohai/memoh/internal/db/sqlc" + "github.com/memohai/memoh/internal/models" +) + +const openAIAuthClaimPath = "https://api.openai.com/auth" + +type ModelCredentials struct { + APIKey string //nolint:gosec // runtime credential material used to construct SDK providers + CodexAccountID string +} + +func SupportsOpenAICodexOAuth(provider sqlc.LlmProvider) bool { + return supportsOAuth(provider) +} + +func (s *Service) ResolveModelCredentials(ctx context.Context, provider sqlc.LlmProvider) (ModelCredentials, error) { + if models.ClientType(provider.ClientType) != models.ClientTypeOpenAICodex { + return ModelCredentials{ + APIKey: provider.ApiKey, + }, nil + } + + token, err := s.GetValidAccessToken(ctx, provider.ID.String()) + if err != nil { + return ModelCredentials{}, err + } + accountID, err := codexAccountIDFromToken(token) + if err != nil { + return ModelCredentials{}, err + } + return ModelCredentials{ + APIKey: token, + CodexAccountID: accountID, + }, nil +} + +func codexAccountIDFromToken(token string) (string, error) { + parts := strings.Split(token, ".") + if len(parts) != 3 { + return "", errors.New("invalid oauth access token") + } + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return "", fmt.Errorf("decode oauth token payload: %w", err) + } + var claims struct { + OpenAIAuth struct { + ChatGPTAccountID string `json:"chatgpt_account_id"` + } `json:"https://api.openai.com/auth"` + } + if err := json.Unmarshal(payload, &claims); err != nil { + return "", fmt.Errorf("parse oauth token payload: %w", err) + } + accountID := strings.TrimSpace(claims.OpenAIAuth.ChatGPTAccountID) + if accountID == "" { + return "", fmt.Errorf("oauth access token missing %s.chatgpt_account_id", openAIAuthClaimPath) + } + return accountID, nil +} diff --git a/internal/providers/oauth.go b/internal/providers/oauth.go new file mode 100644 index 00000000..a03c16b2 --- /dev/null +++ b/internal/providers/oauth.go @@ -0,0 +1,468 @@ +package providers + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + + "github.com/memohai/memoh/internal/db" + "github.com/memohai/memoh/internal/db/sqlc" + "github.com/memohai/memoh/internal/models" +) + +const ( + defaultOpenAICodexClientID = "app_EMoamEEZ73f0CkXaXp7hrann" + defaultOpenAIAuthorizeURL = "https://auth.openai.com/oauth/authorize" + defaultOpenAITokenURL = "https://auth.openai.com/oauth/token" //nolint:gosec // OAuth endpoint URL, not a credential + defaultOpenAICallbackURL = "http://localhost:1455/auth/callback" + defaultOpenAIOAuthScopes = "openid profile email offline_access" + oauthExpirySkew = 30 * time.Second + providerOAuthHTTPTimeout = 15 * time.Second + metadataOAuthClientIDKey = "oauth_client_id" + metadataOAuthAuthorizeURLKey = "oauth_authorize_url" + metadataOAuthTokenURLKey = "oauth_token_url" //nolint:gosec // metadata key name, not a credential + metadataOAuthRedirectURIKey = "oauth_redirect_uri" + metadataOAuthScopesKey = "oauth_scopes" + metadataOAuthAudienceKey = "oauth_audience" + metadataOAuthUseIDOrgsFlagKey = "oauth_id_token_add_organizations" +) + +type providerOAuthToken struct { + ProviderID string `json:"provider_id"` + AccessToken string `json:"access_token"` //nolint:gosec // runtime credential storage + RefreshToken string `json:"refresh_token"` //nolint:gosec // runtime credential storage + ExpiresAt time.Time `json:"expires_at"` + Scope string `json:"scope"` + TokenType string `json:"token_type"` + State string `json:"state"` + PKCECodeVerifier string `json:"pkce_code_verifier"` +} + +type openAIOAuthConfig struct { + ClientID string + AuthorizeURL string + TokenURL string + RedirectURI string + Scopes string + IDTokenAddOrganizations bool +} + +func providerMetadata(raw []byte) map[string]any { + if len(raw) == 0 { + return map[string]any{} + } + var metadata map[string]any + if err := json.Unmarshal(raw, &metadata); err != nil { + return map[string]any{} + } + if metadata == nil { + return map[string]any{} + } + return metadata +} + +func (s *Service) oauthConfig(metadata map[string]any) openAIOAuthConfig { + cfg := openAIOAuthConfig{ + ClientID: defaultOpenAICodexClientID, + AuthorizeURL: defaultOpenAIAuthorizeURL, + TokenURL: defaultOpenAITokenURL, + RedirectURI: firstNonEmpty(s.callbackURL, defaultOpenAICallbackURL), + Scopes: defaultOpenAIOAuthScopes, + IDTokenAddOrganizations: true, + } + if v, _ := metadata[metadataOAuthClientIDKey].(string); strings.TrimSpace(v) != "" { + cfg.ClientID = strings.TrimSpace(v) + } + if v, _ := metadata[metadataOAuthAuthorizeURLKey].(string); strings.TrimSpace(v) != "" { + cfg.AuthorizeURL = strings.TrimSpace(v) + } + if v, _ := metadata[metadataOAuthTokenURLKey].(string); strings.TrimSpace(v) != "" { + cfg.TokenURL = strings.TrimSpace(v) + } + if v, _ := metadata[metadataOAuthRedirectURIKey].(string); strings.TrimSpace(v) != "" { + cfg.RedirectURI = strings.TrimSpace(v) + } + if v, _ := metadata[metadataOAuthScopesKey].(string); strings.TrimSpace(v) != "" { + cfg.Scopes = strings.TrimSpace(v) + } + if v, ok := metadata[metadataOAuthUseIDOrgsFlagKey].(bool); ok { + cfg.IDTokenAddOrganizations = v + } + return cfg +} + +func supportsOAuth(provider sqlc.LlmProvider) bool { + return models.ClientType(provider.ClientType) == models.ClientTypeOpenAICodex +} + +func (s *Service) StartOAuthAuthorization(ctx context.Context, providerID string) (string, error) { + providerUUID, err := db.ParseUUID(providerID) + if err != nil { + return "", err + } + provider, err := s.queries.GetLlmProviderByID(ctx, providerUUID) + if err != nil { + return "", fmt.Errorf("get provider: %w", err) + } + if !supportsOAuth(provider) { + return "", errors.New("provider does not support oauth") + } + + cfg := s.oauthConfig(providerMetadata(provider.Metadata)) + codeVerifier, err := generateCodeVerifier() + if err != nil { + return "", fmt.Errorf("generate code verifier: %w", err) + } + state, err := generateState() + if err != nil { + return "", fmt.Errorf("generate state: %w", err) + } + if err := s.updateOAuthState(ctx, providerID, state, codeVerifier); err != nil { + return "", err + } + + params := url.Values{ + "response_type": {"code"}, + "client_id": {cfg.ClientID}, + "redirect_uri": {cfg.RedirectURI}, + "scope": {cfg.Scopes}, + "code_challenge": {computeCodeChallenge(codeVerifier)}, + "code_challenge_method": {"S256"}, + "state": {state}, + } + if cfg.IDTokenAddOrganizations { + params.Set("id_token_add_organizations", "true") + } + params.Set("codex_cli_simplified_flow", "true") + + return cfg.AuthorizeURL + "?" + params.Encode(), nil +} + +func (s *Service) HandleOAuthCallback(ctx context.Context, state, code string) (string, error) { + token, err := s.getOAuthTokenByState(ctx, state) + if err != nil { + return "", err + } + providerUUID, err := db.ParseUUID(token.ProviderID) + if err != nil { + return "", err + } + provider, err := s.queries.GetLlmProviderByID(ctx, providerUUID) + if err != nil { + return "", fmt.Errorf("get provider: %w", err) + } + if !supportsOAuth(provider) { + return "", errors.New("provider does not support oauth") + } + + cfg := s.oauthConfig(providerMetadata(provider.Metadata)) + resp, err := s.exchangeCode(ctx, cfg, code, token.PKCECodeVerifier) + if err != nil { + return "", err + } + if err := s.saveOAuthToken(ctx, provider.ID.String(), providerOAuthToken{ + ProviderID: provider.ID.String(), + AccessToken: resp.AccessToken, + RefreshToken: firstNonEmpty(resp.RefreshToken, token.RefreshToken), + ExpiresAt: expiresAtFromNow(resp.ExpiresIn), + Scope: firstNonEmpty(resp.Scope, cfg.Scopes), + TokenType: firstNonEmpty(resp.TokenType, "Bearer"), + State: "", + PKCECodeVerifier: "", + }); err != nil { + return "", err + } + return provider.ID.String(), nil +} + +func (s *Service) GetOAuthStatus(ctx context.Context, providerID string) (*OAuthStatus, error) { + providerUUID, err := db.ParseUUID(providerID) + if err != nil { + return nil, err + } + provider, err := s.queries.GetLlmProviderByID(ctx, providerUUID) + if err != nil { + return nil, fmt.Errorf("get provider: %w", err) + } + status := &OAuthStatus{ + Configured: supportsOAuth(provider), + CallbackURL: s.oauthConfig(providerMetadata(provider.Metadata)).RedirectURI, + } + if !status.Configured { + return status, nil + } + + token, err := s.getOAuthToken(ctx, providerID) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return status, nil + } + return nil, err + } + status.HasToken = strings.TrimSpace(token.AccessToken) != "" + if !token.ExpiresAt.IsZero() { + expiresAt := token.ExpiresAt + status.ExpiresAt = &expiresAt + status.Expired = time.Now().After(token.ExpiresAt) + } + return status, nil +} + +func (s *Service) RevokeOAuthToken(ctx context.Context, providerID string) error { + providerUUID, err := db.ParseUUID(providerID) + if err != nil { + return err + } + provider, err := s.queries.GetLlmProviderByID(ctx, providerUUID) + if err != nil { + return fmt.Errorf("get provider: %w", err) + } + if !supportsOAuth(provider) { + return errors.New("provider does not support oauth") + } + return s.queries.DeleteLlmProviderOAuthToken(ctx, providerUUID) +} + +func (s *Service) GetValidAccessToken(ctx context.Context, providerID string) (string, error) { + token, err := s.getOAuthToken(ctx, providerID) + if err != nil { + return "", err + } + if strings.TrimSpace(token.AccessToken) == "" { + return "", errors.New("oauth token is missing access token") + } + if token.ExpiresAt.IsZero() || time.Now().Add(oauthExpirySkew).Before(token.ExpiresAt) { + return token.AccessToken, nil + } + if strings.TrimSpace(token.RefreshToken) == "" { + return "", errors.New("oauth token expired and no refresh token is available") + } + + providerUUID, err := db.ParseUUID(providerID) + if err != nil { + return "", err + } + provider, err := s.queries.GetLlmProviderByID(ctx, providerUUID) + if err != nil { + return "", fmt.Errorf("get provider: %w", err) + } + cfg := s.oauthConfig(providerMetadata(provider.Metadata)) + refreshed, err := s.refreshAccessToken(ctx, cfg, token.RefreshToken) + if err != nil { + return "", err + } + saved := providerOAuthToken{ + ProviderID: providerID, + AccessToken: refreshed.AccessToken, + RefreshToken: firstNonEmpty(refreshed.RefreshToken, token.RefreshToken), + ExpiresAt: expiresAtFromNow(refreshed.ExpiresIn), + Scope: firstNonEmpty(refreshed.Scope, token.Scope), + TokenType: firstNonEmpty(refreshed.TokenType, token.TokenType), + State: token.State, + PKCECodeVerifier: token.PKCECodeVerifier, + } + if err := s.saveOAuthToken(ctx, providerID, saved); err != nil { + return "", err + } + return saved.AccessToken, nil +} + +func (s *Service) getOAuthToken(ctx context.Context, providerID string) (*providerOAuthToken, error) { + providerUUID, err := db.ParseUUID(providerID) + if err != nil { + return nil, err + } + row, err := s.queries.GetLlmProviderOAuthTokenByProvider(ctx, providerUUID) + if err != nil { + return nil, err + } + return toProviderOAuthToken(row), nil +} + +func (s *Service) getOAuthTokenByState(ctx context.Context, state string) (*providerOAuthToken, error) { + row, err := s.queries.GetLlmProviderOAuthTokenByState(ctx, state) + if err != nil { + return nil, err + } + return toProviderOAuthToken(row), nil +} + +func (s *Service) updateOAuthState(ctx context.Context, providerID, state, codeVerifier string) error { + providerUUID, err := db.ParseUUID(providerID) + if err != nil { + return err + } + return s.queries.UpdateLlmProviderOAuthState(ctx, sqlc.UpdateLlmProviderOAuthStateParams{ + LlmProviderID: providerUUID, + State: state, + PkceCodeVerifier: codeVerifier, + }) +} + +func (s *Service) saveOAuthToken(ctx context.Context, providerID string, token providerOAuthToken) error { + providerUUID, err := db.ParseUUID(providerID) + if err != nil { + return err + } + var expiresAt pgtype.Timestamptz + if !token.ExpiresAt.IsZero() { + expiresAt = pgtype.Timestamptz{Time: token.ExpiresAt, Valid: true} + } + _, err = s.queries.UpsertLlmProviderOAuthToken(ctx, sqlc.UpsertLlmProviderOAuthTokenParams{ + LlmProviderID: providerUUID, + AccessToken: token.AccessToken, + RefreshToken: token.RefreshToken, + ExpiresAt: expiresAt, + Scope: token.Scope, + TokenType: token.TokenType, + State: token.State, + PkceCodeVerifier: token.PKCECodeVerifier, + }) + return err +} + +func toProviderOAuthToken(row sqlc.LlmProviderOauthToken) *providerOAuthToken { + token := &providerOAuthToken{ + ProviderID: row.LlmProviderID.String(), + AccessToken: row.AccessToken, + RefreshToken: row.RefreshToken, + Scope: row.Scope, + TokenType: row.TokenType, + State: row.State, + PKCECodeVerifier: row.PkceCodeVerifier, + } + if row.ExpiresAt.Valid { + token.ExpiresAt = row.ExpiresAt.Time + } + return token +} + +type openAITokenResponse struct { + AccessToken string `json:"access_token"` //nolint:gosec // OAuth response payload carries runtime access token + RefreshToken string `json:"refresh_token"` //nolint:gosec // OAuth response payload carries runtime refresh token + TokenType string `json:"token_type"` + Scope string `json:"scope"` + ExpiresIn int64 `json:"expires_in"` + Error string `json:"error"` + Description string `json:"error_description"` +} + +func (s *Service) exchangeCode(ctx context.Context, cfg openAIOAuthConfig, code, codeVerifier string) (*openAITokenResponse, error) { + values := url.Values{ + "grant_type": {"authorization_code"}, + "code": {code}, + "client_id": {cfg.ClientID}, + "redirect_uri": {cfg.RedirectURI}, + "code_verifier": {codeVerifier}, + } + return s.postTokenRequest(ctx, cfg.TokenURL, values) +} + +func (s *Service) refreshAccessToken(ctx context.Context, cfg openAIOAuthConfig, refreshToken string) (*openAITokenResponse, error) { + values := url.Values{ + "grant_type": {"refresh_token"}, + "refresh_token": {refreshToken}, + "client_id": {cfg.ClientID}, + } + return s.postTokenRequest(ctx, cfg.TokenURL, values) +} + +func (s *Service) postTokenRequest(ctx context.Context, tokenURL string, body url.Values) (*openAITokenResponse, error) { + if err := validateOAuthTokenURL(tokenURL); err != nil { + return nil, err + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(body.Encode())) + if err != nil { + return nil, fmt.Errorf("create oauth request: %w", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + //nolint:gosec // tokenURL is restricted to the fixed OpenAI OAuth host by validateOAuthTokenURL above + resp, err := s.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("execute oauth request: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + payload, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read oauth response: %w", err) + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("oauth token request failed: %s", strings.TrimSpace(string(payload))) + } + + var tokenResp openAITokenResponse + if err := json.Unmarshal(payload, &tokenResp); err != nil { + return nil, fmt.Errorf("decode oauth response: %w", err) + } + if tokenResp.Error != "" { + return nil, fmt.Errorf("oauth token request failed: %s", firstNonEmpty(tokenResp.Description, tokenResp.Error)) + } + return &tokenResp, nil +} + +func validateOAuthTokenURL(raw string) error { + parsed, err := url.Parse(strings.TrimSpace(raw)) + if err != nil { + return fmt.Errorf("invalid oauth token url: %w", err) + } + if !strings.EqualFold(parsed.Scheme, "https") { + return errors.New("oauth token url must use https") + } + if !strings.EqualFold(parsed.Hostname(), "auth.openai.com") { + return errors.New("oauth token url host must be auth.openai.com") + } + return nil +} + +func generateState() (string, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", err + } + return hex.EncodeToString(b), nil +} + +func generateCodeVerifier() (string, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +func computeCodeChallenge(verifier string) string { + sum := sha256.Sum256([]byte(verifier)) + return base64.RawURLEncoding.EncodeToString(sum[:]) +} + +func expiresAtFromNow(expiresIn int64) time.Time { + if expiresIn <= 0 { + return time.Time{} + } + return time.Now().Add(time.Duration(expiresIn) * time.Second) +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + if strings.TrimSpace(value) != "" { + return value + } + } + return "" +} diff --git a/internal/providers/service.go b/internal/providers/service.go index 4426f09e..b9940e46 100644 --- a/internal/providers/service.go +++ b/internal/providers/service.go @@ -11,6 +11,7 @@ import ( "time" "github.com/jackc/pgx/v5/pgtype" + openaicodex "github.com/memohai/twilight-ai/provider/openai/codex" sdk "github.com/memohai/twilight-ai/sdk" "github.com/memohai/memoh/internal/db" @@ -20,21 +21,27 @@ import ( // Service handles provider operations. type Service struct { - queries *sqlc.Queries - logger *slog.Logger + queries *sqlc.Queries + logger *slog.Logger + httpClient *http.Client + callbackURL string } // NewService creates a new provider service. -func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { +func NewService(log *slog.Logger, queries *sqlc.Queries, callbackURL string) *Service { + if log == nil { + log = slog.Default() + } return &Service{ - queries: queries, - logger: log.With(slog.String("service", "providers")), + queries: queries, + logger: log.With(slog.String("service", "providers")), + httpClient: &http.Client{Timeout: providerOAuthHTTPTimeout}, + callbackURL: callbackURL, } } // Create creates a new LLM provider. func (s *Service) Create(ctx context.Context, req CreateRequest) (GetResponse, error) { - // Marshal metadata metadataJSON, err := json.Marshal(req.Metadata) if err != nil { return GetResponse{}, fmt.Errorf("marshal metadata: %w", err) @@ -112,13 +119,11 @@ func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Get return GetResponse{}, err } - // Get existing provider existing, err := s.queries.GetLlmProviderByID(ctx, providerID) if err != nil { return GetResponse{}, fmt.Errorf("get provider: %w", err) } - // Apply updates name := existing.Name if req.Name != nil { name = *req.Name @@ -146,16 +151,15 @@ func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Get enable = *req.Enable } - metadata := existing.Metadata + metadataMap := providerMetadata(existing.Metadata) if req.Metadata != nil { - metadataJSON, err := json.Marshal(req.Metadata) - if err != nil { - return GetResponse{}, fmt.Errorf("marshal metadata: %w", err) - } - metadata = metadataJSON + metadataMap = req.Metadata + } + metadataJSON, err := json.Marshal(metadataMap) + if err != nil { + return GetResponse{}, fmt.Errorf("marshal metadata: %w", err) } - // Update provider updated, err := s.queries.UpdateLlmProvider(ctx, sqlc.UpdateLlmProviderParams{ ID: providerID, Name: name, @@ -164,7 +168,7 @@ func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Get ClientType: clientType, Icon: icon, Enable: enable, - Metadata: metadata, + Metadata: metadataJSON, }) if err != nil { return GetResponse{}, fmt.Errorf("update provider: %w", err) @@ -213,8 +217,12 @@ func (s *Service) Test(ctx context.Context, id string) (TestResponse, error) { baseURL := strings.TrimRight(provider.BaseUrl, "/") clientType := models.ClientType(provider.ClientType) + creds, err := s.ResolveModelCredentials(ctx, provider) + if err != nil { + return TestResponse{}, err + } - sdkProvider := models.NewSDKProvider(baseURL, provider.ApiKey, clientType, probeTimeout) + sdkProvider := models.NewSDKProvider(baseURL, creds.APIKey, creds.CodexAccountID, clientType, probeTimeout, nil) start := time.Now() result := sdkProvider.Test(ctx) @@ -238,6 +246,29 @@ func (s *Service) FetchRemoteModels(ctx context.Context, id string) ([]RemoteMod if err != nil { return nil, fmt.Errorf("get provider: %w", err) } + if supportsOAuth(provider) { + catalog := openaicodex.Catalog() + remoteModels := make([]RemoteModel, 0, len(catalog)) + for _, model := range catalog { + compatibilities := make([]string, 0, 2) + if model.SupportsToolCall { + compatibilities = append(compatibilities, models.CompatToolCall) + } + if model.SupportsReasoning { + compatibilities = append(compatibilities, models.CompatReasoning) + } + remoteModels = append(remoteModels, RemoteModel{ + ID: model.ID, + Name: model.DisplayName, + Object: "model", + OwnedBy: "openai-codex", + Type: "chat", + Compatibilities: compatibilities, + ReasoningEfforts: append([]string(nil), model.ReasoningEfforts...), + }) + } + return remoteModels, nil + } baseURL := strings.TrimRight(provider.BaseUrl, "/") modelsURL := fmt.Sprintf("%s/models", baseURL) @@ -250,7 +281,7 @@ func (s *Service) FetchRemoteModels(ctx context.Context, id string) ([]RemoteMod return nil, fmt.Errorf("create request: %w", err) } - if provider.ApiKey != "" { + if provider.ApiKey != "" && !supportsOAuth(provider) { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", provider.ApiKey)) } @@ -284,7 +315,6 @@ func (s *Service) toGetResponse(provider sqlc.LlmProvider) GetResponse { } } - // Mask API key (show only first 8 characters) maskedAPIKey := maskAPIKey(provider.ApiKey) var icon string @@ -318,7 +348,6 @@ func maskAPIKey(apiKey string) string { } // resolveUpdatedAPIKey keeps the original key when the request value matches the masked version. -// This prevents masked placeholder values from overwriting the real stored credential. func resolveUpdatedAPIKey(existing string, updated *string) string { if updated == nil { return existing diff --git a/internal/providers/types.go b/internal/providers/types.go index dc89eb57..81583407 100644 --- a/internal/providers/types.go +++ b/internal/providers/types.go @@ -55,12 +55,25 @@ type TestResponse struct { Message string `json:"message,omitempty"` } +// OAuthStatus is returned by GET /providers/:id/oauth/status. +type OAuthStatus struct { + Configured bool `json:"configured"` + HasToken bool `json:"has_token"` + Expired bool `json:"expired"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` + CallbackURL string `json:"callback_url"` +} + // RemoteModel represents a model returned by the provider's /v1/models endpoint. type RemoteModel struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - OwnedBy string `json:"owned_by"` + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + OwnedBy string `json:"owned_by"` + Name string `json:"name,omitempty"` + Type string `json:"type,omitempty"` + Compatibilities []string `json:"compatibilities,omitempty"` + ReasoningEfforts []string `json:"reasoning_efforts,omitempty"` } // FetchModelsResponse represents the response from the provider's /v1/models endpoint. diff --git a/internal/server/server.go b/internal/server/server.go index 3ad22601..10be9d3a 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -90,5 +90,11 @@ func shouldSkipJWT(path string) bool { if strings.HasPrefix(path, "/email/oauth/callback") { return true } + if strings.HasPrefix(path, "/providers/oauth/callback") { + return true + } + if strings.HasPrefix(path, "/auth/callback") { + return true + } return false }