feat: openai codex support (#292)

* feat(web): add provider oauth management ui

* feat: add OAuth callback support on port 1455

* feat: enhance reasoning effort options and support for OpenAI Codex OAuth

* feat: update twilight-ai dependency to v0.3.4

* refactor: promote openai-codex to first-class client_type, remove auth_type

Replace the previous openai-responses + metadata auth_type=openai-codex-oauth
combo with a dedicated openai-codex client_type. OAuth requirement is now
determined solely by client_type, eliminating the auth_type concept from the
LLM provider domain entirely.

- Add openai-codex to DB CHECK constraint (migration 0047) with data migration
- Add ClientTypeOpenAICodex constant and dedicated SDK/probe branches
- Remove AuthType from SDKModelConfig, ModelCredentials, TriggerConfig, etc.
- Simplify supportsOAuth to check client_type == openai-codex
- Add conf/providers/codex.yaml preset with Codex catalog models
- Frontend: replace auth_type selector with client_type-driven OAuth UI

---------

Co-authored-by: Acbox <acbox0328@gmail.com>
This commit is contained in:
Yiming Qi
2026-03-27 19:30:45 +08:00
committed by GitHub
parent 44c92f198b
commit 64378d29ed
44 changed files with 1663 additions and 160 deletions
+25 -3
View File
@@ -47,6 +47,7 @@
</FormItem>
</FormField>
<FormField
v-if="form.values.client_type !== 'openai-codex'"
v-slot="{ componentField }"
name="api_key"
>
@@ -68,6 +69,12 @@
</FormControl>
</FormItem>
</FormField>
<div
v-else
class="rounded-lg border p-3 text-sm text-muted-foreground"
>
{{ $t('provider.oauth.createHint') }}
</div>
<FormField
v-slot="{ componentField }"
name="base_url"
@@ -161,7 +168,7 @@ import { useDialogMutation } from '@/composables/useDialogMutation'
import SearchableSelectPopover from '@/components/searchable-select-popover/index.vue'
import { CLIENT_TYPE_LIST, CLIENT_TYPE_META } from '@/constants/client-types'
import { toast } from 'vue-sonner'
import { computed } from 'vue'
import { computed, watch } from 'vue'
const open = defineModel<boolean>('open')
const { t } = useI18n()
@@ -208,11 +215,19 @@ const { mutateAsync: createProviderMutation, isLoading } = useMutation({
})
const providerSchema = toTypedSchema(z.object({
api_key: z.string().min(1),
api_key: z.string().optional(),
base_url: z.string().min(1),
name: z.string().min(1),
client_type: z.string().min(1),
auto_import: z.boolean().optional(),
}).superRefine((value, ctx) => {
if (value.client_type !== 'openai-codex' && !value.api_key?.trim()) {
ctx.addIssue({
code: z.ZodIssueCode.custom,
path: ['api_key'],
message: 'API key is required',
})
}
}))
const form = useForm({
@@ -223,12 +238,19 @@ const form = useForm({
},
})
watch(() => form.values.client_type, (clientType) => {
if (clientType !== 'openai-codex') return
if (!form.values.base_url) {
form.setFieldValue('base_url', 'https://chatgpt.com/backend-api')
}
})
const createProvider = form.handleSubmit(async (value) => {
await run(
() => createProviderMutation(value),
{
fallbackMessage: t('common.saveFailed'),
onSuccess: () => {
onSuccess: () => {
open.value = false
form.resetForm()
},
+5
View File
@@ -15,6 +15,11 @@ export const CLIENT_TYPE_META: Record<string, ClientTypeMeta> = {
label: 'OpenAI Completions',
hint: 'Chat Completions API (widely compatible)',
},
'openai-codex': {
value: 'openai-codex',
label: 'OpenAI Codex',
hint: 'Codex API (OAuth, coding-optimized)',
},
'anthropic-messages': {
value: 'anthropic-messages',
label: 'Anthropic Messages',
+24 -4
View File
@@ -34,9 +34,7 @@
"loadFailed": "Failed to load",
"saveFailed": "Failed to save",
"createdAt": "Created at",
"none": "None",
"searchTimezone": "Search timezones…",
"noTimezoneFound": "No timezone found."
"none": "None"
},
"auth": {
"welcome": "Welcome Back",
@@ -230,7 +228,27 @@
"icon": "Icon",
"iconPlaceholder": "Icon URL or identifier (optional)",
"enable": "Enable",
"enableHint": "Only models from enabled providers appear in the available model list"
"enableHint": "Only models from enabled providers appear in the available model list",
"oauth": {
"title": "OpenAI OAuth",
"description": "Authorize this provider with your ChatGPT account for Codex-compatible OpenAI access.",
"createHint": "Save the provider first, then authorize it from the provider details panel.",
"authorize": "Authorize",
"authorizeFailed": "Failed to start authorization",
"authorizeSuccess": "Authorization successful",
"revoke": "Revoke",
"revokeFailed": "Failed to revoke authorization",
"revokeSuccess": "Authorization revoked",
"callback": "Callback URL",
"statusFailed": "Failed to load OAuth status",
"status": {
"checking": "Checking authorization status...",
"authorized": "Authorized",
"expired": "Authorization expired. Re-authorize to continue.",
"missing": "Not authorized yet.",
"notConfigured": "This provider is not configured for OAuth."
}
}
},
"searchProvider": {
"title": "Search Providers",
@@ -784,9 +802,11 @@
"language": "Language",
"reasoningEnabled": "Enable Reasoning",
"reasoningEffort": "Reasoning Effort",
"reasoningEffortNone": "None",
"reasoningEffortLow": "Low",
"reasoningEffortMedium": "Medium",
"reasoningEffortHigh": "High",
"reasoningEffortXHigh": "X-High",
"heartbeatEnabled": "Enable Heartbeat",
"heartbeatDescription": "Periodically trigger agent to check for items that need attention",
"heartbeatInterval": "Heartbeat Interval (minutes)",
+24 -4
View File
@@ -34,9 +34,7 @@
"loadFailed": "加载失败",
"saveFailed": "保存失败",
"createdAt": "创建时间",
"none": "无",
"searchTimezone": "搜索时区…",
"noTimezoneFound": "未找到时区"
"none": "无"
},
"auth": {
"welcome": "欢迎回来",
@@ -226,7 +224,27 @@
"icon": "图标",
"iconPlaceholder": "图标 URL 或标识(可选)",
"enable": "启用",
"enableHint": "只有启用的供应商的模型才会出现在可用模型列表中"
"enableHint": "只有启用的供应商的模型才会出现在可用模型列表中",
"oauth": {
"title": "OpenAI OAuth",
"description": "使用你的 ChatGPT 账号为该提供商授权,以启用 Codex 兼容的 OpenAI 访问。",
"createHint": "请先保存提供商,再到详情面板完成授权。",
"authorize": "授权",
"authorizeFailed": "启动授权失败",
"authorizeSuccess": "授权成功",
"revoke": "撤销授权",
"revokeFailed": "撤销授权失败",
"revokeSuccess": "授权已撤销",
"callback": "回调地址",
"statusFailed": "加载 OAuth 状态失败",
"status": {
"checking": "正在检查授权状态...",
"authorized": "已授权",
"expired": "授权已过期,请重新授权。",
"missing": "尚未授权。",
"notConfigured": "当前提供商未正确配置 OAuth。"
}
}
},
"searchProvider": {
"title": "搜索提供方",
@@ -780,9 +798,11 @@
"language": "语言",
"reasoningEnabled": "启用推理",
"reasoningEffort": "推理等级",
"reasoningEffortNone": "无",
"reasoningEffortLow": "低",
"reasoningEffortMedium": "中",
"reasoningEffortHigh": "高",
"reasoningEffortXHigh": "超高",
"heartbeatEnabled": "启用心跳",
"heartbeatDescription": "定期触发 Agent 检查是否有需要关注的事项",
"heartbeatInterval": "心跳间隔(分钟)",
@@ -197,21 +197,6 @@
/>
</div>
<!-- Timezone -->
<div class="space-y-2">
<Label>{{ $t('bots.timezone') }}</Label>
<TimezoneSelect
:model-value="timezoneModel"
:placeholder="$t('bots.timezonePlaceholder')"
allow-empty
:empty-label="$t('bots.timezoneInherited')"
@update:model-value="onTimezoneChange"
/>
<p class="text-xs text-muted-foreground">
{{ $t('bots.timezoneInheritedHint') }}
</p>
</div>
<Separator />
<!-- Max Context Load Time -->
@@ -272,15 +257,36 @@
</SelectTrigger>
<SelectContent>
<SelectGroup>
<SelectItem value="low">
<SelectItem
v-if="availableReasoningEfforts.includes('none')"
value="none"
>
{{ $t('bots.settings.reasoningEffortNone') }}
</SelectItem>
<SelectItem
v-if="availableReasoningEfforts.includes('low')"
value="low"
>
{{ $t('bots.settings.reasoningEffortLow') }}
</SelectItem>
<SelectItem value="medium">
<SelectItem
v-if="availableReasoningEfforts.includes('medium')"
value="medium"
>
{{ $t('bots.settings.reasoningEffortMedium') }}
</SelectItem>
<SelectItem value="high">
<SelectItem
v-if="availableReasoningEfforts.includes('high')"
value="high"
>
{{ $t('bots.settings.reasoningEffortHigh') }}
</SelectItem>
<SelectItem
v-if="availableReasoningEfforts.includes('xhigh')"
value="xhigh"
>
{{ $t('bots.settings.reasoningEffortXHigh') }}
</SelectItem>
</SelectGroup>
</SelectContent>
</Select>
@@ -349,20 +355,18 @@ import {
SelectTrigger,
SelectValue,
} from '@memohai/ui'
import { reactive, computed, watch, ref } from 'vue'
import { reactive, computed, watch } from 'vue'
import { useRouter } from 'vue-router'
import { toast } from 'vue-sonner'
import { useI18n } from 'vue-i18n'
import ConfirmPopover from '@/components/confirm-popover/index.vue'
import TimezoneSelect from '@/components/timezone-select/index.vue'
import { emptyTimezoneValue } from '@/utils/timezones'
import ModelSelect from './model-select.vue'
import SearchProviderSelect from './search-provider-select.vue'
import MemoryProviderSelect from './memory-provider-select.vue'
import TtsModelSelect from './tts-model-select.vue'
import BrowserContextSelect from './browser-context-select.vue'
import { useQuery, useMutation, useQueryCache } from '@pinia/colada'
import { getBotsById, putBotsById, getBotsByBotIdSettings, putBotsByBotIdSettings, deleteBotsById, getModels, getProviders, getSearchProviders, getMemoryProviders, getTtsProviders, getBrowserContexts, getBotsByBotIdMemoryStatus, postBotsByBotIdMemoryRebuild } from '@memohai/sdk'
import { getBotsByBotIdSettings, putBotsByBotIdSettings, deleteBotsById, getModels, getProviders, getSearchProviders, getMemoryProviders, getTtsProviders, getBrowserContexts, getBotsByBotIdMemoryStatus, postBotsByBotIdMemoryRebuild } from '@memohai/sdk'
import type { SettingsSettings } from '@memohai/sdk'
import type { Ref } from 'vue'
import { resolveApiErrorMessage } from '@/utils/api-error'
@@ -379,15 +383,6 @@ const botIdRef = computed(() => props.botId) as Ref<string>
// ---- Data ----
const queryCache = useQueryCache()
const { data: bot } = useQuery({
key: () => ['bot', botIdRef.value],
query: async () => {
const { data } = await getBotsById({ path: { id: botIdRef.value }, throwOnError: true })
return data
},
enabled: () => !!botIdRef.value,
})
const { data: settings } = useQuery({
key: () => ['bot-settings', botIdRef.value],
query: async () => {
@@ -503,31 +498,6 @@ const form = reactive({
reasoning_effort: 'medium',
})
const timezone = ref('')
const timezoneModel = computed(() => timezone.value || emptyTimezoneValue)
function onTimezoneChange(value: string) {
timezone.value = value === emptyTimezoneValue ? '' : value
}
watch(bot, (val) => {
if (val) {
timezone.value = val.timezone || ''
}
}, { immediate: true })
const { mutateAsync: updateBot } = useMutation({
mutation: async ({ id, ...body }: Record<string, unknown> & { id: string }) => {
const { data } = await putBotsById({ path: { id }, body, throwOnError: true })
return data
},
onSettled: () => {
queryCache.invalidateQueries({ key: ['bots'] })
queryCache.invalidateQueries({ key: ['bot'] })
},
})
const selectedMemoryProvider = computed(() =>
memoryProviders.value.find((provider) => provider.id === form.memory_provider_id),
)
@@ -582,6 +552,20 @@ const chatModelSupportsReasoning = computed(() => {
return !!m?.config?.compatibilities?.includes('reasoning')
})
const availableReasoningEfforts = computed(() => {
if (!form.chat_model_id) return ['low', 'medium', 'high']
const model = models.value.find((m) => m.id === form.chat_model_id)
const efforts = ((model?.config as { reasoning_efforts?: string[] } | undefined)?.reasoning_efforts ?? [])
.filter((effort) => ['none', 'low', 'medium', 'high', 'xhigh'].includes(effort))
return efforts.length > 0 ? efforts : ['low', 'medium', 'high']
})
watch(availableReasoningEfforts, (efforts) => {
if (!efforts.includes(form.reasoning_effort)) {
form.reasoning_effort = efforts.includes('medium') ? 'medium' : efforts[0] ?? 'medium'
}
}, { immediate: true })
const { data: memoryStatusData, isLoading: isMemoryStatusLoading } = useQuery({
key: () => ['bot-memory-status', botIdRef.value, persistedMemoryProviderID.value],
query: async () => {
@@ -641,12 +625,10 @@ watch(settings, (val) => {
}, { immediate: true })
const hasChanges = computed(() => {
const timezoneChanged = timezone.value !== (bot.value?.timezone || '')
if (!settings.value) return true
const s = settings.value
let changed =
timezoneChanged
|| form.chat_model_id !== (s.chat_model_id ?? '')
form.chat_model_id !== (s.chat_model_id ?? '')
|| form.title_model_id !== (s.title_model_id ?? '')
|| form.search_provider_id !== (s.search_provider_id ?? '')
|| form.memory_provider_id !== (s.memory_provider_id ?? '')
@@ -662,11 +644,7 @@ const hasChanges = computed(() => {
async function handleSave() {
try {
const promises: Promise<unknown>[] = [updateSettings({ ...form })]
if (timezone.value !== (bot.value?.timezone || '')) {
promises.push(updateBot({ id: botIdRef.value, timezone: timezone.value }))
}
await Promise.all(promises)
await updateSettings({ ...form })
toast.success(t('bots.settings.saveSuccess'))
} catch {
return
@@ -30,6 +30,14 @@
>
{{ $t(`models.compatibility.${cap}`, cap) }}
</Badge>
<Badge
v-for="effort in reasoningEfforts"
:key="effort"
variant="secondary"
class="text-xs"
>
{{ effort }}
</Badge>
<span
v-if="model.config?.context_window"
class="text-xs text-muted-foreground"
@@ -101,6 +109,10 @@ import { postModelsByIdTest } from '@memohai/sdk'
import type { ModelsGetResponse, ModelsTestResponse } from '@memohai/sdk'
import { ref, computed } from 'vue'
type ModelConfigWithReasoning = {
reasoning_efforts?: string[]
}
const props = defineProps<{
model: ModelsGetResponse
deleteLoading: boolean
@@ -113,6 +125,7 @@ defineEmits<{
const testLoading = ref(false)
const testResult = ref<ModelsTestResponse | null>(null)
const reasoningEfforts = computed(() => ((props.model.config as ModelConfigWithReasoning | undefined)?.reasoning_efforts ?? []))
const statusDotClass = computed(() => {
switch (testResult.value?.status) {
@@ -44,7 +44,10 @@
</FormField>
</section>
<section class="space-y-2">
<section
v-if="form.values.client_type !== 'openai-codex'"
class="space-y-2"
>
<h4 class="scroll-m-20 font-semibold tracking-tight">
{{ $t('provider.apiKey') }}
</h4>
@@ -56,7 +59,7 @@
<FormControl>
<Input
type="password"
:placeholder="props.provider?.api_key || $t('provider.apiKeyPlaceholder')"
:placeholder="providerWithAuth?.api_key || $t('provider.apiKeyPlaceholder')"
:aria-label="$t('provider.apiKey')"
v-bind="componentField"
/>
@@ -106,6 +109,67 @@
</FormItem>
</FormField>
</section>
<section
v-if="form.values.client_type === 'openai-codex'"
class="rounded-lg border p-4 space-y-3 text-sm"
>
<div class="space-y-1">
<div class="font-medium">
{{ $t('provider.oauth.title') }}
</div>
<div class="text-muted-foreground">
{{ $t('provider.oauth.description') }}
</div>
<div
class="text-xs"
:class="oauthExpired ? 'text-destructive' : 'text-muted-foreground'"
>
<template v-if="oauthStatusLoading">
{{ $t('provider.oauth.status.checking') }}
</template>
<template v-else-if="oauthStatus && !oauthStatus.configured">
{{ $t('provider.oauth.status.notConfigured') }}
</template>
<template v-else-if="oauthExpired">
{{ $t('provider.oauth.status.expired') }}
</template>
<template v-else-if="oauthStatus?.has_token">
{{ $t('provider.oauth.status.authorized') }}
</template>
<template v-else>
{{ $t('provider.oauth.status.missing') }}
</template>
</div>
<div
v-if="oauthStatus?.callback_url"
class="text-xs text-muted-foreground"
>
{{ $t('provider.oauth.callback') }}: {{ oauthStatus.callback_url }}
</div>
</div>
<div class="flex gap-2">
<LoadingButton
type="button"
variant="outline"
:disabled="!canAuthorizeOAuth"
:loading="authorizeLoading"
@click="handleAuthorize"
>
<FontAwesomeIcon :icon="['fas', 'key']" />
{{ $t('provider.oauth.authorize') }}
</LoadingButton>
<LoadingButton
v-if="oauthStatus?.has_token"
type="button"
variant="ghost"
:loading="revokeLoading"
@click="handleRevoke"
>
{{ $t('provider.oauth.revoke') }}
</LoadingButton>
</div>
</section>
</div>
<section class="flex justify-between items-center mt-4">
@@ -206,11 +270,22 @@ import { useForm } from 'vee-validate'
import { postProvidersByIdTest } from '@memohai/sdk'
import type { ProvidersGetResponse, ProvidersTestResponse } from '@memohai/sdk'
import { useI18n } from 'vue-i18n'
import { toast } from 'vue-sonner'
const { t } = useI18n()
type ProviderWithAuth = Partial<ProvidersGetResponse>
type ProviderOAuthStatus = {
configured: boolean
has_token: boolean
expired: boolean
callback_url?: string
expires_at?: string
}
const props = defineProps<{
provider: Partial<ProvidersGetResponse> | undefined
provider: ProviderWithAuth | undefined
editLoading: boolean
deleteLoading: boolean
}>()
@@ -223,6 +298,13 @@ const emit = defineEmits<{
const testLoading = ref(false)
const testResult = ref<ProvidersTestResponse | null>(null)
const testError = ref('')
const oauthStatus = ref<ProviderOAuthStatus | null>(null)
const oauthStatusLoading = ref(false)
const authorizeLoading = ref(false)
const revokeLoading = ref(false)
const apiBase = import.meta.env.VITE_API_URL?.trim() || '/api'
const providerWithAuth = computed(() => props.provider as ProviderWithAuth | undefined)
async function runTest() {
if (!props.provider?.id) return
@@ -265,6 +347,14 @@ const providerSchema = toTypedSchema(z.object({
metadata: z.object({
additionalProp1: z.object({}),
}),
}).superRefine((value, ctx) => {
if (value.client_type !== 'openai-codex' && !value.api_key?.trim() && !providerWithAuth.value?.api_key) {
ctx.addIssue({
code: z.ZodIssueCode.custom,
path: ['api_key'],
message: 'API key is required',
})
}
}))
const form = useForm({
@@ -283,6 +373,24 @@ watch(() => props.provider, (newVal) => {
}
}, { immediate: true })
watch(() => form.values.client_type, (clientType) => {
if (clientType !== 'openai-codex') {
oauthStatus.value = null
return
}
if (!form.values.base_url) {
form.setFieldValue('base_url', 'https://chatgpt.com/backend-api')
}
})
watch(() => [props.provider?.id, form.values.client_type] as const, async ([id, clientType]) => {
if (!id || clientType !== 'openai-codex') {
oauthStatus.value = null
return
}
await fetchOAuthStatus()
}, { immediate: true })
const hasChanges = computed(() => {
const raw = props.provider
const baseChanged = JSON.stringify({
@@ -316,4 +424,78 @@ const editProvider = form.handleSubmit(async (value) => {
}
emit('submit', payload)
})
const oauthExpired = computed(() => Boolean(oauthStatus.value?.has_token && oauthStatus.value?.expired))
const canAuthorizeOAuth = computed(() =>
Boolean(
props.provider?.id
&& form.values.client_type === 'openai-codex',
) && !oauthStatusLoading.value,
)
function authHeaders(): Record<string, string> {
const token = localStorage.getItem('token')
return token ? { Authorization: `Bearer ${token}` } : {}
}
async function fetchOAuthStatus() {
if (!props.provider?.id) return
oauthStatusLoading.value = true
try {
const response = await fetch(`${apiBase}/providers/${props.provider.id}/oauth/status`, {
headers: authHeaders(),
})
if (!response.ok) throw new Error(t('provider.oauth.statusFailed'))
oauthStatus.value = await response.json() as ProviderOAuthStatus
} catch (error) {
oauthStatus.value = null
console.error('failed to load provider oauth status', error)
} finally {
oauthStatusLoading.value = false
}
}
async function handleAuthorize() {
if (!props.provider?.id) return
authorizeLoading.value = true
try {
const response = await fetch(`${apiBase}/providers/${props.provider.id}/oauth/authorize`, {
headers: authHeaders(),
})
if (!response.ok) throw new Error(t('provider.oauth.authorizeFailed'))
const data = await response.json() as { auth_url?: string }
if (!data.auth_url) throw new Error(t('provider.oauth.authorizeFailed'))
const popup = window.open(data.auth_url, 'provider-oauth', 'width=600,height=720')
const listener = async (event: MessageEvent) => {
if (event.data?.type !== 'memoh-provider-oauth-success') return
window.removeEventListener('message', listener)
popup?.close()
toast.success(t('provider.oauth.authorizeSuccess'))
await fetchOAuthStatus()
}
window.addEventListener('message', listener)
} catch (error) {
toast.error(error instanceof Error ? error.message : t('provider.oauth.authorizeFailed'))
} finally {
authorizeLoading.value = false
}
}
async function handleRevoke() {
if (!props.provider?.id) return
revokeLoading.value = true
try {
const response = await fetch(`${apiBase}/providers/${props.provider.id}/oauth/token`, {
method: 'DELETE',
headers: authHeaders(),
})
if (!response.ok) throw new Error(t('provider.oauth.revokeFailed'))
toast.success(t('provider.oauth.revokeSuccess'))
await fetchOAuthStatus()
} catch (error) {
toast.error(error instanceof Error ? error.message : t('provider.oauth.revokeFailed'))
} finally {
revokeLoading.value = false
}
}
</script>
+11 -1
View File
@@ -165,7 +165,7 @@ func runServe() {
accounts.NewService,
acl.NewService,
settings.NewService,
providers.NewService,
provideProvidersService,
searchproviders.NewService,
browsercontexts.NewService,
policy.NewService,
@@ -228,6 +228,7 @@ func runServe() {
provideServerHandler(provideSessionHandler),
provideServerHandler(handlers.NewSwaggerHandler),
provideServerHandler(handlers.NewProvidersHandler),
provideServerHandler(handlers.NewProviderOAuthHandler),
provideServerHandler(handlers.NewSearchProvidersHandler),
provideServerHandler(handlers.NewModelsHandler),
provideServerHandler(handlers.NewSettingsHandler),
@@ -765,6 +766,15 @@ func provideEmailRegistry(log *slog.Logger, tokenStore *emailpkg.DBOAuthTokenSto
return reg
}
func provideProvidersService(log *slog.Logger, queries *dbsqlc.Queries, cfg config.Config) *providers.Service {
_ = cfg
return providers.NewService(log, queries, defaultProviderOAuthCallbackURL())
}
func defaultProviderOAuthCallbackURL() string {
return "http://localhost:1455/auth/callback"
}
func provideEmailOAuthHandler(log *slog.Logger, service *emailpkg.Service, tokenStore *emailpkg.DBOAuthTokenStore, cfg config.Config) *handlers.EmailOAuthHandler {
addr := strings.TrimSpace(cfg.Server.Addr)
if addr == "" {
+11 -1
View File
@@ -106,7 +106,7 @@ func runServe() {
accounts.NewService,
acl.NewService,
settings.NewService,
providers.NewService,
provideProvidersService,
searchproviders.NewService,
policy.NewService,
mcp.NewConnectionService,
@@ -154,6 +154,7 @@ func runServe() {
provideServerHandler(provideSessionHandler),
provideServerHandler(handlers.NewSwaggerHandler),
provideServerHandler(handlers.NewProvidersHandler),
provideServerHandler(handlers.NewProviderOAuthHandler),
provideServerHandler(handlers.NewSearchProvidersHandler),
provideServerHandler(handlers.NewModelsHandler),
provideServerHandler(handlers.NewSettingsHandler),
@@ -858,6 +859,15 @@ func provideEmailRegistry(log *slog.Logger, tokenStore *emailpkg.DBOAuthTokenSto
return reg
}
func provideProvidersService(log *slog.Logger, queries *dbsqlc.Queries, cfg config.Config) *providers.Service {
_ = cfg
return providers.NewService(log, queries, defaultProviderOAuthCallbackURL())
}
func defaultProviderOAuthCallbackURL() string {
return "http://localhost:1455/auth/callback"
}
func provideEmailOAuthHandler(log *slog.Logger, service *emailpkg.Service, tokenStore *emailpkg.DBOAuthTokenStore, cfg config.Config) *handlers.EmailOAuthHandler {
addr := strings.TrimSpace(cfg.Server.Addr)
if addr == "" {
+41
View File
@@ -0,0 +1,41 @@
name: OpenAI Codex
client_type: openai-codex
icon: openai
base_url: https://chatgpt.com/backend-api
models:
- model_id: gpt-5.2
name: GPT-5.2
type: chat
config:
compatibilities: [tool-call, reasoning]
- model_id: gpt-5.2-codex
name: GPT-5.2 Codex
type: chat
config:
compatibilities: [tool-call, reasoning]
- model_id: gpt-5.1-codex
name: GPT-5.1 Codex
type: chat
config:
compatibilities: [tool-call, reasoning]
- model_id: gpt-5.1-codex-max
name: GPT-5.1 Codex Max
type: chat
config:
compatibilities: [tool-call, reasoning]
- model_id: gpt-5.1-codex-mini
name: GPT-5.1 Codex Mini
type: chat
config:
compatibilities: [tool-call, reasoning]
- model_id: gpt-5.1
name: GPT-5.1
type: chat
config:
compatibilities: [tool-call, reasoning]
+1 -1
View File
@@ -69,7 +69,7 @@ CREATE TABLE IF NOT EXISTS llm_providers (
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
CONSTRAINT llm_providers_name_unique UNIQUE (name),
CONSTRAINT llm_providers_client_type_check CHECK (client_type IN ('openai-responses', 'openai-completions', 'anthropic-messages', 'google-generative-ai'))
CONSTRAINT llm_providers_client_type_check CHECK (client_type IN ('openai-responses', 'openai-completions', 'anthropic-messages', 'google-generative-ai', 'openai-codex'))
);
CREATE TABLE IF NOT EXISTS search_providers (
@@ -0,0 +1,5 @@
-- 0046_llm_provider_oauth (rollback)
-- Remove OAuth token storage for LLM providers.
DROP INDEX IF EXISTS idx_llm_provider_oauth_tokens_state;
DROP TABLE IF EXISTS llm_provider_oauth_tokens;
@@ -0,0 +1,18 @@
-- 0046_llm_provider_oauth
-- Add OAuth token storage for LLM providers to support OpenAI Codex OAuth.
CREATE TABLE IF NOT EXISTS llm_provider_oauth_tokens (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
llm_provider_id UUID NOT NULL UNIQUE REFERENCES llm_providers(id) ON DELETE CASCADE,
access_token TEXT NOT NULL DEFAULT '',
refresh_token TEXT NOT NULL DEFAULT '',
expires_at TIMESTAMPTZ,
scope TEXT NOT NULL DEFAULT '',
token_type TEXT NOT NULL DEFAULT '',
state TEXT NOT NULL DEFAULT '',
pkce_code_verifier TEXT NOT NULL DEFAULT '',
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
CREATE INDEX IF NOT EXISTS idx_llm_provider_oauth_tokens_state ON llm_provider_oauth_tokens(state) WHERE state != '';
@@ -0,0 +1,11 @@
-- 0047_add_openai_codex_client_type (rollback)
-- Revert openai-codex rows back to openai-responses and restore the old CHECK constraint.
UPDATE llm_providers
SET client_type = 'openai-responses',
updated_at = now()
WHERE client_type = 'openai-codex';
ALTER TABLE llm_providers DROP CONSTRAINT IF EXISTS llm_providers_client_type_check;
ALTER TABLE llm_providers ADD CONSTRAINT llm_providers_client_type_check
CHECK (client_type IN ('openai-responses', 'openai-completions', 'anthropic-messages', 'google-generative-ai'));
@@ -0,0 +1,12 @@
-- 0047_add_openai_codex_client_type
-- Add openai-codex as a first-class client_type and migrate existing codex-oauth providers.
ALTER TABLE llm_providers DROP CONSTRAINT IF EXISTS llm_providers_client_type_check;
ALTER TABLE llm_providers ADD CONSTRAINT llm_providers_client_type_check
CHECK (client_type IN ('openai-responses', 'openai-completions', 'anthropic-messages', 'google-generative-ai', 'openai-codex'));
UPDATE llm_providers
SET client_type = 'openai-codex',
updated_at = now()
WHERE client_type = 'openai-responses'
AND metadata->>'auth_type' = 'openai-codex-oauth';
+52
View File
@@ -0,0 +1,52 @@
-- name: UpsertLlmProviderOAuthToken :one
INSERT INTO llm_provider_oauth_tokens (
llm_provider_id,
access_token,
refresh_token,
expires_at,
scope,
token_type,
state,
pkce_code_verifier
)
VALUES (
sqlc.arg(llm_provider_id),
sqlc.arg(access_token),
sqlc.arg(refresh_token),
sqlc.arg(expires_at),
sqlc.arg(scope),
sqlc.arg(token_type),
sqlc.arg(state),
sqlc.arg(pkce_code_verifier)
)
ON CONFLICT (llm_provider_id) DO UPDATE SET
access_token = EXCLUDED.access_token,
refresh_token = EXCLUDED.refresh_token,
expires_at = EXCLUDED.expires_at,
scope = EXCLUDED.scope,
token_type = EXCLUDED.token_type,
state = EXCLUDED.state,
pkce_code_verifier = EXCLUDED.pkce_code_verifier,
updated_at = now()
RETURNING *;
-- name: GetLlmProviderOAuthTokenByProvider :one
SELECT * FROM llm_provider_oauth_tokens WHERE llm_provider_id = sqlc.arg(llm_provider_id);
-- name: GetLlmProviderOAuthTokenByState :one
SELECT * FROM llm_provider_oauth_tokens WHERE state = sqlc.arg(state) AND state != '';
-- name: UpdateLlmProviderOAuthState :exec
INSERT INTO llm_provider_oauth_tokens (llm_provider_id, state, pkce_code_verifier)
VALUES (
sqlc.arg(llm_provider_id),
sqlc.arg(state),
sqlc.arg(pkce_code_verifier)
)
ON CONFLICT (llm_provider_id) DO UPDATE SET
state = EXCLUDED.state,
pkce_code_verifier = EXCLUDED.pkce_code_verifier,
updated_at = now();
-- name: DeleteLlmProviderOAuthToken :exec
DELETE FROM llm_provider_oauth_tokens WHERE llm_provider_id = sqlc.arg(llm_provider_id);
+1 -1
View File
@@ -59,6 +59,6 @@ RUN chmod +x /entrypoint.sh
VOLUME ["/var/lib/containerd", "/opt/memoh/data"]
EXPOSE 8080
EXPOSE 8080 1455
ENTRYPOINT ["/entrypoint.sh"]
+1
View File
@@ -95,6 +95,7 @@ services:
- /etc/localtime:/etc/localtime:ro
ports:
- "${MEMOH_DEV_SERVER_PORT:-18080}:8080"
- "${MEMOH_DEV_OAUTH_PORT:-1455}:8080"
healthcheck:
test: ["CMD-SHELL", "wget --no-verbose --tries=1 --spider http://127.0.0.1:8080/health || exit 1"]
interval: 5s
+1
View File
@@ -47,6 +47,7 @@ services:
- /etc/localtime:/etc/localtime:ro
ports:
- "8080:8080"
- "1455:8080"
depends_on:
migrate:
condition: service_completed_successfully
+1 -1
View File
@@ -117,7 +117,7 @@ RUN mkdir -p /opt/memoh/data /run/containerd /var/lib/containerd
VOLUME ["/var/lib/containerd", "/opt/memoh/data"]
EXPOSE 8080
EXPOSE 8080 1455
HEALTHCHECK --interval=30s --timeout=3s --start-period=30s --retries=3 \
CMD wget --no-verbose --tries=1 --spider http://127.0.0.1:8080/health \
+3 -2
View File
@@ -26,7 +26,7 @@ require (
github.com/larksuite/oapi-sdk-go/v3 v3.5.3
github.com/mailgun/mailgun-go/v5 v5.14.0
github.com/memohai/acgo v0.0.0-20260221232113-babac0d6acd7
github.com/memohai/twilight-ai v0.3.3-0.20260321100646-43c789b701dd
github.com/memohai/twilight-ai v0.3.4-0.20260326121718-a9628b948584
github.com/modelcontextprotocol/go-sdk v1.4.1
github.com/opencontainers/image-spec v1.1.1
github.com/opencontainers/runtime-spec v1.3.0
@@ -36,6 +36,7 @@ require (
github.com/stretchr/testify v1.11.1
github.com/swaggo/swag v1.16.6
github.com/wneessen/go-mail v0.7.2
github.com/yuin/goldmark v1.7.13
go.uber.org/fx v1.24.0
golang.org/x/crypto v0.48.0
golang.org/x/oauth2 v0.35.0
@@ -123,7 +124,6 @@ require (
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasttemplate v1.2.2 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
github.com/yuin/goldmark v1.7.13 // indirect
go.opencensus.io v0.24.0 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.65.0 // indirect
@@ -143,3 +143,4 @@ require (
golang.org/x/tools v0.42.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260209200024-4cfbd4190f57 // indirect
)
+2 -2
View File
@@ -228,8 +228,8 @@ github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D
github.com/mattn/go-runewidth v0.0.10/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
github.com/memohai/acgo v0.0.0-20260221232113-babac0d6acd7 h1:beehwOQperqGWj4m4EhcPhnSZKtDiuHK/7ZMoTPaQjw=
github.com/memohai/acgo v0.0.0-20260221232113-babac0d6acd7/go.mod h1:OvmxM7JmnXBmwJWWVqtreL3HSHSKuzPbtbhlg5MvBg0=
github.com/memohai/twilight-ai v0.3.3-0.20260321100646-43c789b701dd h1:uV7xsqYHYpEmT6xKvkOs5mHT5oEKnwV1F93ialqi78k=
github.com/memohai/twilight-ai v0.3.3-0.20260321100646-43c789b701dd/go.mod h1:vHNoRb6/quMacMAgIp838aoiNhsZbE0bFCnRRNyRwNc=
github.com/memohai/twilight-ai v0.3.4-0.20260326121718-a9628b948584 h1:zu4T54unBe8ziIlTr8gUtFoR16c6u2G1Qpx8OGsenxo=
github.com/memohai/twilight-ai v0.3.4-0.20260326121718-a9628b948584/go.mod h1:GZTT9GUT3uSs6zram/FcF24GLTZMFSpiybbYmjr+gH8=
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
github.com/moby/locker v1.0.1 h1:fOXqR41zeveg4fFODix+1Ch4mj/gT0NE1XJbp/epuBg=
+10 -6
View File
@@ -2,6 +2,7 @@ package agent
import (
"context"
"net/http"
sdk "github.com/memohai/twilight-ai/sdk"
@@ -69,14 +70,17 @@ func SpawnSystemPrompt(sessionType string) string {
})
}
// SpawnModelCreatorFunc returns a tools.ModelCreator that delegates to models.NewSDKChatModel.
// SpawnModelCreatorFunc returns a tools.ModelCreator backed by the shared SDK model factory.
// This keeps subagent model creation aligned with the shared SDK model factory.
func SpawnModelCreatorFunc() tools.ModelCreator {
return func(modelID, clientType, apiKey, baseURL string) *sdk.Model {
return func(modelID, clientType, apiKey, codexAccountID, baseURL string, httpClient *http.Client) *sdk.Model {
return models.NewSDKChatModel(models.SDKModelConfig{
ModelID: modelID,
ClientType: clientType,
APIKey: apiKey,
BaseURL: baseURL,
ModelID: modelID,
ClientType: clientType,
APIKey: apiKey,
CodexAccountID: codexAccountID,
BaseURL: baseURL,
HTTPClient: httpClient,
})
}
}
+16 -2
View File
@@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"log/slog"
"net/http"
"strings"
"sync"
@@ -14,6 +15,7 @@ import (
"github.com/memohai/memoh/internal/db/sqlc"
messagepkg "github.com/memohai/memoh/internal/message"
"github.com/memohai/memoh/internal/models"
"github.com/memohai/memoh/internal/providers"
sessionpkg "github.com/memohai/memoh/internal/session"
"github.com/memohai/memoh/internal/settings"
)
@@ -292,7 +294,7 @@ func (p *SpawnProvider) persistMessages(
}
// ModelCreator creates an sdk.Model from provider config. Set via SetModelCreator.
type ModelCreator func(modelID, clientType, apiKey, baseURL string) *sdk.Model
type ModelCreator func(modelID, clientType, apiKey, codexAccountID, baseURL string, httpClient *http.Client) *sdk.Model
// SetModelCreator injects the function used to create SDK models
// (typically agent.CreateModel wrapped to match the signature).
@@ -323,7 +325,19 @@ func (p *SpawnProvider) resolveModel(ctx context.Context, botID string) (*sdk.Mo
if p.modelCreator == nil {
return nil, "", errors.New("model creator not configured")
}
sdkModel := p.modelCreator(modelInfo.ModelID, provider.ClientType, provider.ApiKey, provider.BaseUrl)
authResolver := providers.NewService(nil, p.queries, "")
creds, err := authResolver.ResolveModelCredentials(ctx, provider)
if err != nil {
return nil, "", err
}
sdkModel := p.modelCreator(
modelInfo.ModelID,
provider.ClientType,
creds.APIKey,
creds.CodexAccountID,
provider.BaseUrl,
nil,
)
return sdkModel, modelInfo.ID, nil
}
+18
View File
@@ -2,6 +2,7 @@ package agent
import (
"encoding/json"
"net/http"
"time"
sdk "github.com/memohai/twilight-ai/sdk"
@@ -95,6 +96,23 @@ type SystemFile struct {
Content string
}
// ModelConfig holds provider and model information resolved from DB.
type ModelConfig struct {
ModelID string
ClientType string
APIKey string //nolint:gosec // carries provider credential material at runtime
CodexAccountID string
BaseURL string
HTTPClient *http.Client
ReasoningConfig *ReasoningConfig
}
// ReasoningConfig controls extended thinking/reasoning behavior.
type ReasoningConfig struct {
Enabled bool
Effort string
}
func mustMarshal(v any) json.RawMessage {
data, err := json.Marshal(v)
if err != nil {
+6 -4
View File
@@ -104,10 +104,12 @@ func (s *Service) doCompaction(ctx context.Context, logID pgtype.UUID, sessionUU
userPrompt := buildUserPrompt(priorSummaries, entries)
model := models.NewSDKChatModel(models.SDKModelConfig{
ClientType: cfg.ClientType,
BaseURL: cfg.BaseURL,
APIKey: cfg.APIKey,
ModelID: cfg.ModelID,
ClientType: cfg.ClientType,
BaseURL: cfg.BaseURL,
APIKey: cfg.APIKey,
CodexAccountID: cfg.CodexAccountID,
ModelID: cfg.ModelID,
HTTPClient: cfg.HTTPClient,
})
result, err := sdk.GenerateTextResult(ctx,
+12 -7
View File
@@ -1,6 +1,9 @@
package compaction
import "time"
import (
"net/http"
"time"
)
// Log represents a compaction log entry.
type Log struct {
@@ -24,10 +27,12 @@ type ListLogsResponse struct {
// TriggerConfig holds the parameters needed to trigger a compaction.
type TriggerConfig struct {
BotID string
SessionID string
ModelID string
ClientType string
APIKey string //nolint:gosec // runtime credential, not a hardcoded secret
BaseURL string
BotID string
SessionID string
ModelID string
ClientType string
APIKey string //nolint:gosec // runtime credential, not a hardcoded secret
CodexAccountID string
BaseURL string
HTTPClient *http.Client
}
+10 -1
View File
@@ -5,6 +5,7 @@ import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"math"
@@ -24,6 +25,7 @@ import (
messagepkg "github.com/memohai/memoh/internal/message"
messageevent "github.com/memohai/memoh/internal/message/event"
"github.com/memohai/memoh/internal/models"
"github.com/memohai/memoh/internal/providers"
"github.com/memohai/memoh/internal/settings"
)
@@ -278,10 +280,17 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r
}
}
authResolver := providers.NewService(nil, r.queries, "")
creds, err := authResolver.ResolveModelCredentials(ctx, provider)
if err != nil {
return resolvedContext{}, fmt.Errorf("resolve provider credentials: %w", err)
}
modelCfg := models.SDKModelConfig{
ModelID: chatModel.ModelID,
ClientType: clientType,
APIKey: provider.ApiKey,
APIKey: creds.APIKey,
CodexAccountID: creds.CodexAccountID,
BaseURL: provider.BaseUrl,
ReasoningConfig: reasoningConfig,
}
@@ -7,6 +7,7 @@ import (
"github.com/memohai/memoh/internal/compaction"
"github.com/memohai/memoh/internal/conversation"
"github.com/memohai/memoh/internal/models"
"github.com/memohai/memoh/internal/providers"
)
func (r *Resolver) maybeCompact(ctx context.Context, req conversation.ChatRequest, rc resolvedContext, inputTokens int) {
@@ -47,8 +48,15 @@ func (r *Resolver) maybeCompact(ctx context.Context, req conversation.ChatReques
r.logger.Warn("compaction: failed to fetch provider", slog.Any("error", err))
return
}
authResolver := providers.NewService(nil, r.queries, "")
creds, err := authResolver.ResolveModelCredentials(ctx, provider)
if err != nil {
r.logger.Warn("compaction: failed to resolve provider credentials", slog.Any("error", err))
return
}
cfg.ClientType = provider.ClientType
cfg.APIKey = provider.ApiKey
cfg.APIKey = creds.APIKey
cfg.CodexAccountID = creds.CodexAccountID
cfg.BaseURL = provider.BaseUrl
r.compactionService.TriggerCompaction(ctx, cfg)
+13 -4
View File
@@ -13,6 +13,7 @@ import (
"github.com/memohai/memoh/internal/db/sqlc"
messageevent "github.com/memohai/memoh/internal/message/event"
"github.com/memohai/memoh/internal/models"
"github.com/memohai/memoh/internal/providers"
"github.com/memohai/memoh/internal/session"
)
@@ -104,11 +105,19 @@ func (r *Resolver) generateTitle(ctx context.Context, model models.GetResponse,
"Return ONLY the title text, nothing else.\n\n" +
"User: " + userSnippet
authResolver := providers.NewService(nil, r.queries, "")
creds, err := authResolver.ResolveModelCredentials(ctx, provider)
if err != nil {
r.logger.Warn("title gen: failed to resolve provider credentials", slog.Any("error", err))
return ""
}
modelCfg := models.SDKModelConfig{
ModelID: model.ModelID,
ClientType: provider.ClientType,
APIKey: provider.ApiKey,
BaseURL: provider.BaseUrl,
ModelID: model.ModelID,
ClientType: provider.ClientType,
APIKey: creds.APIKey,
CodexAccountID: creds.CodexAccountID,
BaseURL: provider.BaseUrl,
}
sdkModel := models.NewSDKChatModel(modelCfg)
+163
View File
@@ -0,0 +1,163 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: llm_provider_oauth.sql
package sqlc
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
const deleteLlmProviderOAuthToken = `-- name: DeleteLlmProviderOAuthToken :exec
DELETE FROM llm_provider_oauth_tokens WHERE llm_provider_id = $1
`
func (q *Queries) DeleteLlmProviderOAuthToken(ctx context.Context, llmProviderID pgtype.UUID) error {
_, err := q.db.Exec(ctx, deleteLlmProviderOAuthToken, llmProviderID)
return err
}
const getLlmProviderOAuthTokenByProvider = `-- name: GetLlmProviderOAuthTokenByProvider :one
SELECT id, llm_provider_id, access_token, refresh_token, expires_at, scope, token_type, state, pkce_code_verifier, created_at, updated_at FROM llm_provider_oauth_tokens WHERE llm_provider_id = $1
`
func (q *Queries) GetLlmProviderOAuthTokenByProvider(ctx context.Context, llmProviderID pgtype.UUID) (LlmProviderOauthToken, error) {
row := q.db.QueryRow(ctx, getLlmProviderOAuthTokenByProvider, llmProviderID)
var i LlmProviderOauthToken
err := row.Scan(
&i.ID,
&i.LlmProviderID,
&i.AccessToken,
&i.RefreshToken,
&i.ExpiresAt,
&i.Scope,
&i.TokenType,
&i.State,
&i.PkceCodeVerifier,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const getLlmProviderOAuthTokenByState = `-- name: GetLlmProviderOAuthTokenByState :one
SELECT id, llm_provider_id, access_token, refresh_token, expires_at, scope, token_type, state, pkce_code_verifier, created_at, updated_at FROM llm_provider_oauth_tokens WHERE state = $1 AND state != ''
`
func (q *Queries) GetLlmProviderOAuthTokenByState(ctx context.Context, state string) (LlmProviderOauthToken, error) {
row := q.db.QueryRow(ctx, getLlmProviderOAuthTokenByState, state)
var i LlmProviderOauthToken
err := row.Scan(
&i.ID,
&i.LlmProviderID,
&i.AccessToken,
&i.RefreshToken,
&i.ExpiresAt,
&i.Scope,
&i.TokenType,
&i.State,
&i.PkceCodeVerifier,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const updateLlmProviderOAuthState = `-- name: UpdateLlmProviderOAuthState :exec
INSERT INTO llm_provider_oauth_tokens (llm_provider_id, state, pkce_code_verifier)
VALUES (
$1,
$2,
$3
)
ON CONFLICT (llm_provider_id) DO UPDATE SET
state = EXCLUDED.state,
pkce_code_verifier = EXCLUDED.pkce_code_verifier,
updated_at = now()
`
type UpdateLlmProviderOAuthStateParams struct {
LlmProviderID pgtype.UUID `json:"llm_provider_id"`
State string `json:"state"`
PkceCodeVerifier string `json:"pkce_code_verifier"`
}
func (q *Queries) UpdateLlmProviderOAuthState(ctx context.Context, arg UpdateLlmProviderOAuthStateParams) error {
_, err := q.db.Exec(ctx, updateLlmProviderOAuthState, arg.LlmProviderID, arg.State, arg.PkceCodeVerifier)
return err
}
const upsertLlmProviderOAuthToken = `-- name: UpsertLlmProviderOAuthToken :one
INSERT INTO llm_provider_oauth_tokens (
llm_provider_id,
access_token,
refresh_token,
expires_at,
scope,
token_type,
state,
pkce_code_verifier
)
VALUES (
$1,
$2,
$3,
$4,
$5,
$6,
$7,
$8
)
ON CONFLICT (llm_provider_id) DO UPDATE SET
access_token = EXCLUDED.access_token,
refresh_token = EXCLUDED.refresh_token,
expires_at = EXCLUDED.expires_at,
scope = EXCLUDED.scope,
token_type = EXCLUDED.token_type,
state = EXCLUDED.state,
pkce_code_verifier = EXCLUDED.pkce_code_verifier,
updated_at = now()
RETURNING id, llm_provider_id, access_token, refresh_token, expires_at, scope, token_type, state, pkce_code_verifier, created_at, updated_at
`
type UpsertLlmProviderOAuthTokenParams struct {
LlmProviderID pgtype.UUID `json:"llm_provider_id"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
Scope string `json:"scope"`
TokenType string `json:"token_type"`
State string `json:"state"`
PkceCodeVerifier string `json:"pkce_code_verifier"`
}
func (q *Queries) UpsertLlmProviderOAuthToken(ctx context.Context, arg UpsertLlmProviderOAuthTokenParams) (LlmProviderOauthToken, error) {
row := q.db.QueryRow(ctx, upsertLlmProviderOAuthToken,
arg.LlmProviderID,
arg.AccessToken,
arg.RefreshToken,
arg.ExpiresAt,
arg.Scope,
arg.TokenType,
arg.State,
arg.PkceCodeVerifier,
)
var i LlmProviderOauthToken
err := row.Scan(
&i.ID,
&i.LlmProviderID,
&i.AccessToken,
&i.RefreshToken,
&i.ExpiresAt,
&i.Scope,
&i.TokenType,
&i.State,
&i.PkceCodeVerifier,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
+14
View File
@@ -292,6 +292,20 @@ type LlmProvider struct {
ClientType string `json:"client_type"`
}
type LlmProviderOauthToken struct {
ID pgtype.UUID `json:"id"`
LlmProviderID pgtype.UUID `json:"llm_provider_id"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
Scope string `json:"scope"`
TokenType string `json:"token_type"`
State string `json:"state"`
PkceCodeVerifier string `json:"pkce_code_verifier"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
}
type McpConnection struct {
ID pgtype.UUID `json:"id"`
BotID pgtype.UUID `json:"bot_id"`
+132
View File
@@ -0,0 +1,132 @@
package handlers
import (
"html/template"
"net/http"
"strings"
"github.com/labstack/echo/v4"
"github.com/memohai/memoh/internal/providers"
)
type ProviderOAuthHandler struct {
service *providers.Service
}
func NewProviderOAuthHandler(service *providers.Service) *ProviderOAuthHandler {
return &ProviderOAuthHandler{service: service}
}
func (h *ProviderOAuthHandler) Register(e *echo.Echo) {
e.GET("/providers/:id/oauth/authorize", h.Authorize)
e.GET("/providers/:id/oauth/status", h.Status)
e.DELETE("/providers/:id/oauth/token", h.Revoke)
e.GET("/auth/callback", h.Callback)
e.GET("/providers/oauth/callback", h.Callback)
}
// Authorize godoc
// @Summary Start OAuth2 authorization for an LLM provider
// @Tags providers-oauth
// @Param id path string true "Provider ID (UUID)"
// @Success 200 {object} map[string]string
// @Failure 400 {object} ErrorResponse
// @Failure 404 {object} ErrorResponse
// @Router /providers/{id}/oauth/authorize [get].
func (h *ProviderOAuthHandler) Authorize(c echo.Context) error {
providerID := strings.TrimSpace(c.Param("id"))
if providerID == "" {
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
}
authURL, err := h.service.StartOAuthAuthorization(c.Request().Context(), providerID)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
return c.JSON(http.StatusOK, map[string]string{"auth_url": authURL})
}
// Status godoc
// @Summary Get OAuth2 status for an LLM provider
// @Tags providers-oauth
// @Param id path string true "Provider ID (UUID)"
// @Success 200 {object} providers.OAuthStatus
// @Failure 400 {object} ErrorResponse
// @Failure 404 {object} ErrorResponse
// @Router /providers/{id}/oauth/status [get].
func (h *ProviderOAuthHandler) Status(c echo.Context) error {
providerID := strings.TrimSpace(c.Param("id"))
if providerID == "" {
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
}
status, err := h.service.GetOAuthStatus(c.Request().Context(), providerID)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
return c.JSON(http.StatusOK, status)
}
// Revoke godoc
// @Summary Revoke stored OAuth2 tokens for an LLM provider
// @Tags providers-oauth
// @Param id path string true "Provider ID (UUID)"
// @Success 204 "No Content"
// @Failure 400 {object} ErrorResponse
// @Failure 404 {object} ErrorResponse
// @Router /providers/{id}/oauth/token [delete].
func (h *ProviderOAuthHandler) Revoke(c echo.Context) error {
providerID := strings.TrimSpace(c.Param("id"))
if providerID == "" {
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
}
if err := h.service.RevokeOAuthToken(c.Request().Context(), providerID); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
return c.NoContent(http.StatusNoContent)
}
// Callback godoc
// @Summary OAuth2 callback for LLM providers
// @Tags providers-oauth
// @Param code query string true "Authorization code"
// @Param state query string true "State parameter"
// @Success 200 {string} string "HTML success page"
// @Failure 400 {object} ErrorResponse
// @Router /providers/oauth/callback [get].
func (h *ProviderOAuthHandler) Callback(c echo.Context) error {
code := strings.TrimSpace(c.QueryParam("code"))
state := strings.TrimSpace(c.QueryParam("state"))
if code == "" {
return echo.NewHTTPError(http.StatusBadRequest, "code is required")
}
if state == "" {
return echo.NewHTTPError(http.StatusBadRequest, "state is required")
}
providerID, err := h.service.HandleOAuthCallback(c.Request().Context(), state, code)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
page := template.Must(template.New("oauth-success").Parse(`<!doctype html>
<html>
<head>
<meta charset="utf-8">
<title>OpenAI OAuth Connected</title>
</head>
<body style="font-family: sans-serif; padding: 24px;">
<h2>OpenAI OAuth connected</h2>
<p>You can close this window and return to Memoh.</p>
<script>
window.opener?.postMessage({ type: "memoh-provider-oauth-success", providerId: "{{.ProviderID}}" }, "*");
window.close();
</script>
</body>
</html>`))
return c.HTML(http.StatusOK, executeHTMLTemplate(page, map[string]string{"ProviderID": providerID}))
}
func executeHTMLTemplate(tpl *template.Template, data any) string {
var b strings.Builder
_ = tpl.Execute(&b, data)
return b.String()
}
+16 -5
View File
@@ -309,20 +309,31 @@ func (h *ProvidersHandler) ImportModels(c echo.Context) error {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("fetch remote models: %v", err))
}
defaultCompat := []string{models.CompatVision, models.CompatToolCall, models.CompatReasoning}
resp := providers.ImportModelsResponse{
Models: make([]string, 0),
}
for _, m := range remoteModels {
modelType := models.ModelTypeChat
if strings.TrimSpace(m.Type) == string(models.ModelTypeEmbedding) {
modelType = models.ModelTypeEmbedding
}
compatibilities := m.Compatibilities
if len(compatibilities) == 0 {
compatibilities = []string{models.CompatVision, models.CompatToolCall, models.CompatReasoning}
}
name := strings.TrimSpace(m.Name)
if name == "" {
name = m.ID
}
_, err := h.modelsService.Create(c.Request().Context(), models.AddRequest{
ModelID: m.ID,
Name: m.ID,
Name: name,
LlmProviderID: id,
Type: models.ModelTypeChat,
Type: modelType,
Config: models.ModelConfig{
Compatibilities: defaultCompat,
Compatibilities: compatibilities,
ReasoningEfforts: m.ReasoningEfforts,
},
})
if err != nil {
@@ -75,7 +75,7 @@ func newDenseRuntime(providerConfig map[string]any, queries *dbsqlc.Queries, cfg
return nil, fmt.Errorf("dense runtime: %w", err)
}
embedModel := models.NewSDKEmbeddingModel(spec.clientType, spec.baseURL, spec.apiKey, spec.modelID, denseEmbedTimeout)
embedModel := models.NewSDKEmbeddingModel(spec.clientType, spec.baseURL, spec.apiKey, spec.modelID, denseEmbedTimeout, nil)
return &denseRuntime{
qdrant: qClient,
+4 -3
View File
@@ -13,11 +13,13 @@ import (
// provider configuration. It dispatches to the native Google embedding provider
// when clientType is "google-generative-ai", and falls back to the
// OpenAI-compatible /embeddings endpoint for all other provider types.
func NewSDKEmbeddingModel(clientType, baseURL, apiKey, modelID string, timeout time.Duration) *sdk.EmbeddingModel {
func NewSDKEmbeddingModel(clientType, baseURL, apiKey, modelID string, timeout time.Duration, httpClient *http.Client) *sdk.EmbeddingModel {
if timeout <= 0 {
timeout = 30 * time.Second
}
httpClient := &http.Client{Timeout: timeout}
if httpClient == nil {
httpClient = &http.Client{Timeout: timeout}
}
switch ClientType(clientType) {
case ClientTypeGoogleGenerativeAI:
@@ -30,7 +32,6 @@ func NewSDKEmbeddingModel(clientType, baseURL, apiKey, modelID string, timeout t
}
p := googleembedding.New(opts...)
return p.EmbeddingModel(modelID)
default:
opts := []openaiembedding.Option{
openaiembedding.WithAPIKey(apiKey),
+81 -11
View File
@@ -2,6 +2,9 @@ package models
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
@@ -9,11 +12,13 @@ import (
anthropicmessages "github.com/memohai/twilight-ai/provider/anthropic/messages"
googlegenerative "github.com/memohai/twilight-ai/provider/google/generativeai"
openaicodex "github.com/memohai/twilight-ai/provider/openai/codex"
openaicompletions "github.com/memohai/twilight-ai/provider/openai/completions"
openairesponses "github.com/memohai/twilight-ai/provider/openai/responses"
sdk "github.com/memohai/twilight-ai/sdk"
"github.com/memohai/memoh/internal/db"
"github.com/memohai/memoh/internal/db/sqlc"
)
const probeTimeout = 15 * time.Second
@@ -37,16 +42,17 @@ func (s *Service) Test(ctx context.Context, id string) (TestResponse, error) {
}
baseURL := strings.TrimRight(provider.BaseUrl, "/")
apiKey := provider.ApiKey
clientType := ClientType(provider.ClientType)
// Embedding models don't have a chat Provider in the SDK — probe
// the /embeddings endpoint directly.
if model.Type == string(ModelTypeEmbedding) {
return s.testEmbeddingModel(ctx, string(clientType), baseURL, apiKey, model.ModelID)
creds, err := s.resolveModelCredentials(ctx, provider)
if err != nil {
return TestResponse{}, err
}
sdkProvider := NewSDKProvider(baseURL, apiKey, clientType, probeTimeout)
if model.Type == string(ModelTypeEmbedding) {
return s.testEmbeddingModel(ctx, baseURL, creds.APIKey, model.ModelID, nil)
}
sdkProvider := NewSDKProvider(baseURL, creds.APIKey, creds.CodexAccountID, clientType, probeTimeout, nil)
start := time.Now()
@@ -100,11 +106,11 @@ func (s *Service) Test(ctx context.Context, id string) (TestResponse, error) {
// testEmbeddingModel probes an embedding model by performing a minimal
// embedding request via the Twilight SDK, verifying that the model is
// reachable and functional rather than merely checking HTTP connectivity.
func (*Service) testEmbeddingModel(ctx context.Context, clientType, baseURL, apiKey, modelID string) (TestResponse, error) {
func (*Service) testEmbeddingModel(ctx context.Context, baseURL, apiKey, modelID string, httpClient *http.Client) (TestResponse, error) {
ctx, cancel := context.WithTimeout(ctx, probeTimeout)
defer cancel()
model := NewSDKEmbeddingModel(clientType, baseURL, apiKey, modelID, probeTimeout)
model := NewSDKEmbeddingModel(string(ClientTypeOpenAICompletions), baseURL, apiKey, modelID, probeTimeout, httpClient)
client := sdk.NewClient()
start := time.Now()
@@ -130,8 +136,10 @@ func (*Service) testEmbeddingModel(ctx context.Context, clientType, baseURL, api
// NewSDKProvider creates a Twilight AI SDK Provider for the given client type.
// It is exported so that other packages (e.g. providers) can reuse it for testing.
func NewSDKProvider(baseURL, apiKey string, clientType ClientType, timeout time.Duration) sdk.Provider {
httpClient := &http.Client{Timeout: timeout}
func NewSDKProvider(baseURL, apiKey, codexAccountID string, clientType ClientType, timeout time.Duration, httpClient *http.Client) sdk.Provider {
if httpClient == nil {
httpClient = &http.Client{Timeout: timeout}
}
switch clientType {
case ClientTypeOpenAIResponses:
@@ -144,6 +152,16 @@ func NewSDKProvider(baseURL, apiKey string, clientType ClientType, timeout time.
}
return openairesponses.New(opts...)
case ClientTypeOpenAICodex:
opts := []openaicodex.Option{
openaicodex.WithAccessToken(apiKey),
openaicodex.WithHTTPClient(httpClient),
}
if codexAccountID != "" {
opts = append(opts, openaicodex.WithAccountID(codexAccountID))
}
return openaicodex.New(opts...)
case ClientTypeAnthropicMessages:
opts := []anthropicmessages.Option{
anthropicmessages.WithAPIKey(apiKey),
@@ -175,3 +193,55 @@ func NewSDKProvider(baseURL, apiKey string, clientType ClientType, timeout time.
return openaicompletions.New(opts...)
}
}
type modelCredentials struct {
APIKey string //nolint:gosec // runtime credential material used to construct SDK providers
CodexAccountID string
}
func (s *Service) resolveModelCredentials(ctx context.Context, provider sqlc.LlmProvider) (modelCredentials, error) {
if ClientType(provider.ClientType) != ClientTypeOpenAICodex {
return modelCredentials{APIKey: provider.ApiKey}, nil
}
tokenRow, err := s.queries.GetLlmProviderOAuthTokenByProvider(ctx, provider.ID)
if err != nil {
return modelCredentials{}, err
}
accessToken := strings.TrimSpace(tokenRow.AccessToken)
if accessToken == "" {
return modelCredentials{}, errors.New("oauth token is missing access token")
}
accountID, err := codexAccountIDFromToken(accessToken)
if err != nil {
return modelCredentials{}, err
}
return modelCredentials{
APIKey: accessToken,
CodexAccountID: accountID,
}, nil
}
func codexAccountIDFromToken(token string) (string, error) {
parts := strings.Split(token, ".")
if len(parts) != 3 {
return "", errors.New("invalid oauth access token")
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return "", fmt.Errorf("decode oauth token payload: %w", err)
}
var claims struct {
OpenAIAuth struct {
ChatGPTAccountID string `json:"chatgpt_account_id"`
} `json:"https://api.openai.com/auth"`
}
if err := json.Unmarshal(payload, &claims); err != nil {
return "", fmt.Errorf("parse oauth token payload: %w", err)
}
accountID := strings.TrimSpace(claims.OpenAIAuth.ChatGPTAccountID)
if accountID == "" {
return "", errors.New("oauth access token missing chatgpt_account_id")
}
return accountID, nil
}
+34 -1
View File
@@ -1,10 +1,12 @@
package models
import (
"net/http"
"strings"
anthropicmessages "github.com/memohai/twilight-ai/provider/anthropic/messages"
googlegenerative "github.com/memohai/twilight-ai/provider/google/generativeai"
openaicodex "github.com/memohai/twilight-ai/provider/openai/codex"
openaicompletions "github.com/memohai/twilight-ai/provider/openai/completions"
openairesponses "github.com/memohai/twilight-ai/provider/openai/responses"
sdk "github.com/memohai/twilight-ai/sdk"
@@ -16,7 +18,9 @@ type SDKModelConfig struct {
ModelID string
ClientType string
APIKey string //nolint:gosec // carries provider credential material at runtime
CodexAccountID string
BaseURL string
HTTPClient *http.Client
ReasoningConfig *ReasoningConfig
}
@@ -38,6 +42,9 @@ func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model {
opts := []openaicompletions.Option{
openaicompletions.WithAPIKey(cfg.APIKey),
}
if cfg.HTTPClient != nil {
opts = append(opts, openaicompletions.WithHTTPClient(cfg.HTTPClient))
}
if cfg.BaseURL != "" {
opts = append(opts, openaicompletions.WithBaseURL(cfg.BaseURL))
}
@@ -48,16 +55,34 @@ func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model {
opts := []openairesponses.Option{
openairesponses.WithAPIKey(cfg.APIKey),
}
if cfg.HTTPClient != nil {
opts = append(opts, openairesponses.WithHTTPClient(cfg.HTTPClient))
}
if cfg.BaseURL != "" {
opts = append(opts, openairesponses.WithBaseURL(cfg.BaseURL))
}
p := openairesponses.New(opts...)
return p.ChatModel(cfg.ModelID)
case ClientTypeOpenAICodex:
opts := []openaicodex.Option{
openaicodex.WithAccessToken(cfg.APIKey),
}
if cfg.HTTPClient != nil {
opts = append(opts, openaicodex.WithHTTPClient(cfg.HTTPClient))
}
if cfg.CodexAccountID != "" {
opts = append(opts, openaicodex.WithAccountID(cfg.CodexAccountID))
}
return openaicodex.New(opts...).ChatModel(cfg.ModelID)
case ClientTypeAnthropicMessages:
opts := []anthropicmessages.Option{
anthropicmessages.WithAPIKey(cfg.APIKey),
}
if cfg.HTTPClient != nil {
opts = append(opts, anthropicmessages.WithHTTPClient(cfg.HTTPClient))
}
if cfg.BaseURL != "" {
opts = append(opts, anthropicmessages.WithBaseURL(cfg.BaseURL))
}
@@ -75,6 +100,9 @@ func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model {
opts := []googlegenerative.Option{
googlegenerative.WithAPIKey(cfg.APIKey),
}
if cfg.HTTPClient != nil {
opts = append(opts, googlegenerative.WithHTTPClient(cfg.HTTPClient))
}
if cfg.BaseURL != "" {
opts = append(opts, googlegenerative.WithBaseURL(cfg.BaseURL))
}
@@ -85,6 +113,9 @@ func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model {
opts := []openaicompletions.Option{
openaicompletions.WithAPIKey(cfg.APIKey),
}
if cfg.HTTPClient != nil {
opts = append(opts, openaicompletions.WithHTTPClient(cfg.HTTPClient))
}
if cfg.BaseURL != "" {
opts = append(opts, openaicompletions.WithBaseURL(cfg.BaseURL))
}
@@ -106,7 +137,7 @@ func BuildReasoningOptions(cfg SDKModelConfig) []sdk.GenerateOption {
switch ClientType(cfg.ClientType) {
case ClientTypeAnthropicMessages:
return nil
case ClientTypeOpenAIResponses, ClientTypeOpenAICompletions:
case ClientTypeOpenAIResponses, ClientTypeOpenAICompletions, ClientTypeOpenAICodex:
return []sdk.GenerateOption{sdk.WithReasoningEffort(effort)}
case ClientTypeGoogleGenerativeAI:
return nil
@@ -147,6 +178,8 @@ func ResolveClientType(model *sdk.Model) string {
return string(ClientTypeAnthropicMessages)
case strings.Contains(name, "google"):
return string(ClientTypeGoogleGenerativeAI)
case strings.Contains(name, "codex"):
return string(ClientTypeOpenAICodex)
case strings.Contains(name, "responses"):
return string(ClientTypeOpenAIResponses)
default:
+26 -3
View File
@@ -20,6 +20,7 @@ const (
ClientTypeOpenAICompletions ClientType = "openai-completions"
ClientTypeAnthropicMessages ClientType = "anthropic-messages"
ClientTypeGoogleGenerativeAI ClientType = "google-generative-ai"
ClientTypeOpenAICodex ClientType = "openai-codex"
)
const (
@@ -29,16 +30,33 @@ const (
CompatReasoning = "reasoning"
)
const (
ReasoningEffortNone = "none"
ReasoningEffortLow = "low"
ReasoningEffortMedium = "medium"
ReasoningEffortHigh = "high"
ReasoningEffortXHigh = "xhigh"
)
// validCompatibilities enumerates accepted compatibility tokens.
var validCompatibilities = map[string]struct{}{
CompatVision: {}, CompatToolCall: {}, CompatImageOutput: {}, CompatReasoning: {},
}
var validReasoningEfforts = map[string]struct{}{
ReasoningEffortNone: {},
ReasoningEffortLow: {},
ReasoningEffortMedium: {},
ReasoningEffortHigh: {},
ReasoningEffortXHigh: {},
}
// ModelConfig holds the JSONB config stored per model.
type ModelConfig struct {
Dimensions *int `json:"dimensions,omitempty"`
Compatibilities []string `json:"compatibilities,omitempty"`
ContextWindow *int `json:"context_window,omitempty"`
Dimensions *int `json:"dimensions,omitempty"`
Compatibilities []string `json:"compatibilities,omitempty"`
ContextWindow *int `json:"context_window,omitempty"`
ReasoningEfforts []string `json:"reasoning_efforts,omitempty"`
}
type Model struct {
@@ -72,6 +90,11 @@ func (m *Model) Validate() error {
return errors.New("invalid compatibility: " + c)
}
}
for _, effort := range m.Config.ReasoningEfforts {
if _, ok := validReasoningEfforts[effort]; !ok {
return errors.New("invalid reasoning effort: " + effort)
}
}
return nil
}
+69
View File
@@ -0,0 +1,69 @@
package providers
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"strings"
"github.com/memohai/memoh/internal/db/sqlc"
"github.com/memohai/memoh/internal/models"
)
const openAIAuthClaimPath = "https://api.openai.com/auth"
type ModelCredentials struct {
APIKey string //nolint:gosec // runtime credential material used to construct SDK providers
CodexAccountID string
}
func SupportsOpenAICodexOAuth(provider sqlc.LlmProvider) bool {
return supportsOAuth(provider)
}
func (s *Service) ResolveModelCredentials(ctx context.Context, provider sqlc.LlmProvider) (ModelCredentials, error) {
if models.ClientType(provider.ClientType) != models.ClientTypeOpenAICodex {
return ModelCredentials{
APIKey: provider.ApiKey,
}, nil
}
token, err := s.GetValidAccessToken(ctx, provider.ID.String())
if err != nil {
return ModelCredentials{}, err
}
accountID, err := codexAccountIDFromToken(token)
if err != nil {
return ModelCredentials{}, err
}
return ModelCredentials{
APIKey: token,
CodexAccountID: accountID,
}, nil
}
func codexAccountIDFromToken(token string) (string, error) {
parts := strings.Split(token, ".")
if len(parts) != 3 {
return "", errors.New("invalid oauth access token")
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return "", fmt.Errorf("decode oauth token payload: %w", err)
}
var claims struct {
OpenAIAuth struct {
ChatGPTAccountID string `json:"chatgpt_account_id"`
} `json:"https://api.openai.com/auth"`
}
if err := json.Unmarshal(payload, &claims); err != nil {
return "", fmt.Errorf("parse oauth token payload: %w", err)
}
accountID := strings.TrimSpace(claims.OpenAIAuth.ChatGPTAccountID)
if accountID == "" {
return "", fmt.Errorf("oauth access token missing %s.chatgpt_account_id", openAIAuthClaimPath)
}
return accountID, nil
}
+468
View File
@@ -0,0 +1,468 @@
package providers
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/memohai/memoh/internal/db"
"github.com/memohai/memoh/internal/db/sqlc"
"github.com/memohai/memoh/internal/models"
)
const (
defaultOpenAICodexClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
defaultOpenAIAuthorizeURL = "https://auth.openai.com/oauth/authorize"
defaultOpenAITokenURL = "https://auth.openai.com/oauth/token" //nolint:gosec // OAuth endpoint URL, not a credential
defaultOpenAICallbackURL = "http://localhost:1455/auth/callback"
defaultOpenAIOAuthScopes = "openid profile email offline_access"
oauthExpirySkew = 30 * time.Second
providerOAuthHTTPTimeout = 15 * time.Second
metadataOAuthClientIDKey = "oauth_client_id"
metadataOAuthAuthorizeURLKey = "oauth_authorize_url"
metadataOAuthTokenURLKey = "oauth_token_url" //nolint:gosec // metadata key name, not a credential
metadataOAuthRedirectURIKey = "oauth_redirect_uri"
metadataOAuthScopesKey = "oauth_scopes"
metadataOAuthAudienceKey = "oauth_audience"
metadataOAuthUseIDOrgsFlagKey = "oauth_id_token_add_organizations"
)
type providerOAuthToken struct {
ProviderID string `json:"provider_id"`
AccessToken string `json:"access_token"` //nolint:gosec // runtime credential storage
RefreshToken string `json:"refresh_token"` //nolint:gosec // runtime credential storage
ExpiresAt time.Time `json:"expires_at"`
Scope string `json:"scope"`
TokenType string `json:"token_type"`
State string `json:"state"`
PKCECodeVerifier string `json:"pkce_code_verifier"`
}
type openAIOAuthConfig struct {
ClientID string
AuthorizeURL string
TokenURL string
RedirectURI string
Scopes string
IDTokenAddOrganizations bool
}
func providerMetadata(raw []byte) map[string]any {
if len(raw) == 0 {
return map[string]any{}
}
var metadata map[string]any
if err := json.Unmarshal(raw, &metadata); err != nil {
return map[string]any{}
}
if metadata == nil {
return map[string]any{}
}
return metadata
}
func (s *Service) oauthConfig(metadata map[string]any) openAIOAuthConfig {
cfg := openAIOAuthConfig{
ClientID: defaultOpenAICodexClientID,
AuthorizeURL: defaultOpenAIAuthorizeURL,
TokenURL: defaultOpenAITokenURL,
RedirectURI: firstNonEmpty(s.callbackURL, defaultOpenAICallbackURL),
Scopes: defaultOpenAIOAuthScopes,
IDTokenAddOrganizations: true,
}
if v, _ := metadata[metadataOAuthClientIDKey].(string); strings.TrimSpace(v) != "" {
cfg.ClientID = strings.TrimSpace(v)
}
if v, _ := metadata[metadataOAuthAuthorizeURLKey].(string); strings.TrimSpace(v) != "" {
cfg.AuthorizeURL = strings.TrimSpace(v)
}
if v, _ := metadata[metadataOAuthTokenURLKey].(string); strings.TrimSpace(v) != "" {
cfg.TokenURL = strings.TrimSpace(v)
}
if v, _ := metadata[metadataOAuthRedirectURIKey].(string); strings.TrimSpace(v) != "" {
cfg.RedirectURI = strings.TrimSpace(v)
}
if v, _ := metadata[metadataOAuthScopesKey].(string); strings.TrimSpace(v) != "" {
cfg.Scopes = strings.TrimSpace(v)
}
if v, ok := metadata[metadataOAuthUseIDOrgsFlagKey].(bool); ok {
cfg.IDTokenAddOrganizations = v
}
return cfg
}
func supportsOAuth(provider sqlc.LlmProvider) bool {
return models.ClientType(provider.ClientType) == models.ClientTypeOpenAICodex
}
func (s *Service) StartOAuthAuthorization(ctx context.Context, providerID string) (string, error) {
providerUUID, err := db.ParseUUID(providerID)
if err != nil {
return "", err
}
provider, err := s.queries.GetLlmProviderByID(ctx, providerUUID)
if err != nil {
return "", fmt.Errorf("get provider: %w", err)
}
if !supportsOAuth(provider) {
return "", errors.New("provider does not support oauth")
}
cfg := s.oauthConfig(providerMetadata(provider.Metadata))
codeVerifier, err := generateCodeVerifier()
if err != nil {
return "", fmt.Errorf("generate code verifier: %w", err)
}
state, err := generateState()
if err != nil {
return "", fmt.Errorf("generate state: %w", err)
}
if err := s.updateOAuthState(ctx, providerID, state, codeVerifier); err != nil {
return "", err
}
params := url.Values{
"response_type": {"code"},
"client_id": {cfg.ClientID},
"redirect_uri": {cfg.RedirectURI},
"scope": {cfg.Scopes},
"code_challenge": {computeCodeChallenge(codeVerifier)},
"code_challenge_method": {"S256"},
"state": {state},
}
if cfg.IDTokenAddOrganizations {
params.Set("id_token_add_organizations", "true")
}
params.Set("codex_cli_simplified_flow", "true")
return cfg.AuthorizeURL + "?" + params.Encode(), nil
}
func (s *Service) HandleOAuthCallback(ctx context.Context, state, code string) (string, error) {
token, err := s.getOAuthTokenByState(ctx, state)
if err != nil {
return "", err
}
providerUUID, err := db.ParseUUID(token.ProviderID)
if err != nil {
return "", err
}
provider, err := s.queries.GetLlmProviderByID(ctx, providerUUID)
if err != nil {
return "", fmt.Errorf("get provider: %w", err)
}
if !supportsOAuth(provider) {
return "", errors.New("provider does not support oauth")
}
cfg := s.oauthConfig(providerMetadata(provider.Metadata))
resp, err := s.exchangeCode(ctx, cfg, code, token.PKCECodeVerifier)
if err != nil {
return "", err
}
if err := s.saveOAuthToken(ctx, provider.ID.String(), providerOAuthToken{
ProviderID: provider.ID.String(),
AccessToken: resp.AccessToken,
RefreshToken: firstNonEmpty(resp.RefreshToken, token.RefreshToken),
ExpiresAt: expiresAtFromNow(resp.ExpiresIn),
Scope: firstNonEmpty(resp.Scope, cfg.Scopes),
TokenType: firstNonEmpty(resp.TokenType, "Bearer"),
State: "",
PKCECodeVerifier: "",
}); err != nil {
return "", err
}
return provider.ID.String(), nil
}
func (s *Service) GetOAuthStatus(ctx context.Context, providerID string) (*OAuthStatus, error) {
providerUUID, err := db.ParseUUID(providerID)
if err != nil {
return nil, err
}
provider, err := s.queries.GetLlmProviderByID(ctx, providerUUID)
if err != nil {
return nil, fmt.Errorf("get provider: %w", err)
}
status := &OAuthStatus{
Configured: supportsOAuth(provider),
CallbackURL: s.oauthConfig(providerMetadata(provider.Metadata)).RedirectURI,
}
if !status.Configured {
return status, nil
}
token, err := s.getOAuthToken(ctx, providerID)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return status, nil
}
return nil, err
}
status.HasToken = strings.TrimSpace(token.AccessToken) != ""
if !token.ExpiresAt.IsZero() {
expiresAt := token.ExpiresAt
status.ExpiresAt = &expiresAt
status.Expired = time.Now().After(token.ExpiresAt)
}
return status, nil
}
func (s *Service) RevokeOAuthToken(ctx context.Context, providerID string) error {
providerUUID, err := db.ParseUUID(providerID)
if err != nil {
return err
}
provider, err := s.queries.GetLlmProviderByID(ctx, providerUUID)
if err != nil {
return fmt.Errorf("get provider: %w", err)
}
if !supportsOAuth(provider) {
return errors.New("provider does not support oauth")
}
return s.queries.DeleteLlmProviderOAuthToken(ctx, providerUUID)
}
func (s *Service) GetValidAccessToken(ctx context.Context, providerID string) (string, error) {
token, err := s.getOAuthToken(ctx, providerID)
if err != nil {
return "", err
}
if strings.TrimSpace(token.AccessToken) == "" {
return "", errors.New("oauth token is missing access token")
}
if token.ExpiresAt.IsZero() || time.Now().Add(oauthExpirySkew).Before(token.ExpiresAt) {
return token.AccessToken, nil
}
if strings.TrimSpace(token.RefreshToken) == "" {
return "", errors.New("oauth token expired and no refresh token is available")
}
providerUUID, err := db.ParseUUID(providerID)
if err != nil {
return "", err
}
provider, err := s.queries.GetLlmProviderByID(ctx, providerUUID)
if err != nil {
return "", fmt.Errorf("get provider: %w", err)
}
cfg := s.oauthConfig(providerMetadata(provider.Metadata))
refreshed, err := s.refreshAccessToken(ctx, cfg, token.RefreshToken)
if err != nil {
return "", err
}
saved := providerOAuthToken{
ProviderID: providerID,
AccessToken: refreshed.AccessToken,
RefreshToken: firstNonEmpty(refreshed.RefreshToken, token.RefreshToken),
ExpiresAt: expiresAtFromNow(refreshed.ExpiresIn),
Scope: firstNonEmpty(refreshed.Scope, token.Scope),
TokenType: firstNonEmpty(refreshed.TokenType, token.TokenType),
State: token.State,
PKCECodeVerifier: token.PKCECodeVerifier,
}
if err := s.saveOAuthToken(ctx, providerID, saved); err != nil {
return "", err
}
return saved.AccessToken, nil
}
func (s *Service) getOAuthToken(ctx context.Context, providerID string) (*providerOAuthToken, error) {
providerUUID, err := db.ParseUUID(providerID)
if err != nil {
return nil, err
}
row, err := s.queries.GetLlmProviderOAuthTokenByProvider(ctx, providerUUID)
if err != nil {
return nil, err
}
return toProviderOAuthToken(row), nil
}
func (s *Service) getOAuthTokenByState(ctx context.Context, state string) (*providerOAuthToken, error) {
row, err := s.queries.GetLlmProviderOAuthTokenByState(ctx, state)
if err != nil {
return nil, err
}
return toProviderOAuthToken(row), nil
}
func (s *Service) updateOAuthState(ctx context.Context, providerID, state, codeVerifier string) error {
providerUUID, err := db.ParseUUID(providerID)
if err != nil {
return err
}
return s.queries.UpdateLlmProviderOAuthState(ctx, sqlc.UpdateLlmProviderOAuthStateParams{
LlmProviderID: providerUUID,
State: state,
PkceCodeVerifier: codeVerifier,
})
}
func (s *Service) saveOAuthToken(ctx context.Context, providerID string, token providerOAuthToken) error {
providerUUID, err := db.ParseUUID(providerID)
if err != nil {
return err
}
var expiresAt pgtype.Timestamptz
if !token.ExpiresAt.IsZero() {
expiresAt = pgtype.Timestamptz{Time: token.ExpiresAt, Valid: true}
}
_, err = s.queries.UpsertLlmProviderOAuthToken(ctx, sqlc.UpsertLlmProviderOAuthTokenParams{
LlmProviderID: providerUUID,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
ExpiresAt: expiresAt,
Scope: token.Scope,
TokenType: token.TokenType,
State: token.State,
PkceCodeVerifier: token.PKCECodeVerifier,
})
return err
}
func toProviderOAuthToken(row sqlc.LlmProviderOauthToken) *providerOAuthToken {
token := &providerOAuthToken{
ProviderID: row.LlmProviderID.String(),
AccessToken: row.AccessToken,
RefreshToken: row.RefreshToken,
Scope: row.Scope,
TokenType: row.TokenType,
State: row.State,
PKCECodeVerifier: row.PkceCodeVerifier,
}
if row.ExpiresAt.Valid {
token.ExpiresAt = row.ExpiresAt.Time
}
return token
}
type openAITokenResponse struct {
AccessToken string `json:"access_token"` //nolint:gosec // OAuth response payload carries runtime access token
RefreshToken string `json:"refresh_token"` //nolint:gosec // OAuth response payload carries runtime refresh token
TokenType string `json:"token_type"`
Scope string `json:"scope"`
ExpiresIn int64 `json:"expires_in"`
Error string `json:"error"`
Description string `json:"error_description"`
}
func (s *Service) exchangeCode(ctx context.Context, cfg openAIOAuthConfig, code, codeVerifier string) (*openAITokenResponse, error) {
values := url.Values{
"grant_type": {"authorization_code"},
"code": {code},
"client_id": {cfg.ClientID},
"redirect_uri": {cfg.RedirectURI},
"code_verifier": {codeVerifier},
}
return s.postTokenRequest(ctx, cfg.TokenURL, values)
}
func (s *Service) refreshAccessToken(ctx context.Context, cfg openAIOAuthConfig, refreshToken string) (*openAITokenResponse, error) {
values := url.Values{
"grant_type": {"refresh_token"},
"refresh_token": {refreshToken},
"client_id": {cfg.ClientID},
}
return s.postTokenRequest(ctx, cfg.TokenURL, values)
}
func (s *Service) postTokenRequest(ctx context.Context, tokenURL string, body url.Values) (*openAITokenResponse, error) {
if err := validateOAuthTokenURL(tokenURL); err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(body.Encode()))
if err != nil {
return nil, fmt.Errorf("create oauth request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
//nolint:gosec // tokenURL is restricted to the fixed OpenAI OAuth host by validateOAuthTokenURL above
resp, err := s.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("execute oauth request: %w", err)
}
defer func() { _ = resp.Body.Close() }()
payload, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read oauth response: %w", err)
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("oauth token request failed: %s", strings.TrimSpace(string(payload)))
}
var tokenResp openAITokenResponse
if err := json.Unmarshal(payload, &tokenResp); err != nil {
return nil, fmt.Errorf("decode oauth response: %w", err)
}
if tokenResp.Error != "" {
return nil, fmt.Errorf("oauth token request failed: %s", firstNonEmpty(tokenResp.Description, tokenResp.Error))
}
return &tokenResp, nil
}
func validateOAuthTokenURL(raw string) error {
parsed, err := url.Parse(strings.TrimSpace(raw))
if err != nil {
return fmt.Errorf("invalid oauth token url: %w", err)
}
if !strings.EqualFold(parsed.Scheme, "https") {
return errors.New("oauth token url must use https")
}
if !strings.EqualFold(parsed.Hostname(), "auth.openai.com") {
return errors.New("oauth token url host must be auth.openai.com")
}
return nil
}
func generateState() (string, error) {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", err
}
return hex.EncodeToString(b), nil
}
func generateCodeVerifier() (string, error) {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
func computeCodeChallenge(verifier string) string {
sum := sha256.Sum256([]byte(verifier))
return base64.RawURLEncoding.EncodeToString(sum[:])
}
func expiresAtFromNow(expiresIn int64) time.Time {
if expiresIn <= 0 {
return time.Time{}
}
return time.Now().Add(time.Duration(expiresIn) * time.Second)
}
func firstNonEmpty(values ...string) string {
for _, value := range values {
if strings.TrimSpace(value) != "" {
return value
}
}
return ""
}
+49 -20
View File
@@ -11,6 +11,7 @@ import (
"time"
"github.com/jackc/pgx/v5/pgtype"
openaicodex "github.com/memohai/twilight-ai/provider/openai/codex"
sdk "github.com/memohai/twilight-ai/sdk"
"github.com/memohai/memoh/internal/db"
@@ -20,21 +21,27 @@ import (
// Service handles provider operations.
type Service struct {
queries *sqlc.Queries
logger *slog.Logger
queries *sqlc.Queries
logger *slog.Logger
httpClient *http.Client
callbackURL string
}
// NewService creates a new provider service.
func NewService(log *slog.Logger, queries *sqlc.Queries) *Service {
func NewService(log *slog.Logger, queries *sqlc.Queries, callbackURL string) *Service {
if log == nil {
log = slog.Default()
}
return &Service{
queries: queries,
logger: log.With(slog.String("service", "providers")),
queries: queries,
logger: log.With(slog.String("service", "providers")),
httpClient: &http.Client{Timeout: providerOAuthHTTPTimeout},
callbackURL: callbackURL,
}
}
// Create creates a new LLM provider.
func (s *Service) Create(ctx context.Context, req CreateRequest) (GetResponse, error) {
// Marshal metadata
metadataJSON, err := json.Marshal(req.Metadata)
if err != nil {
return GetResponse{}, fmt.Errorf("marshal metadata: %w", err)
@@ -112,13 +119,11 @@ func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Get
return GetResponse{}, err
}
// Get existing provider
existing, err := s.queries.GetLlmProviderByID(ctx, providerID)
if err != nil {
return GetResponse{}, fmt.Errorf("get provider: %w", err)
}
// Apply updates
name := existing.Name
if req.Name != nil {
name = *req.Name
@@ -146,16 +151,15 @@ func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Get
enable = *req.Enable
}
metadata := existing.Metadata
metadataMap := providerMetadata(existing.Metadata)
if req.Metadata != nil {
metadataJSON, err := json.Marshal(req.Metadata)
if err != nil {
return GetResponse{}, fmt.Errorf("marshal metadata: %w", err)
}
metadata = metadataJSON
metadataMap = req.Metadata
}
metadataJSON, err := json.Marshal(metadataMap)
if err != nil {
return GetResponse{}, fmt.Errorf("marshal metadata: %w", err)
}
// Update provider
updated, err := s.queries.UpdateLlmProvider(ctx, sqlc.UpdateLlmProviderParams{
ID: providerID,
Name: name,
@@ -164,7 +168,7 @@ func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Get
ClientType: clientType,
Icon: icon,
Enable: enable,
Metadata: metadata,
Metadata: metadataJSON,
})
if err != nil {
return GetResponse{}, fmt.Errorf("update provider: %w", err)
@@ -213,8 +217,12 @@ func (s *Service) Test(ctx context.Context, id string) (TestResponse, error) {
baseURL := strings.TrimRight(provider.BaseUrl, "/")
clientType := models.ClientType(provider.ClientType)
creds, err := s.ResolveModelCredentials(ctx, provider)
if err != nil {
return TestResponse{}, err
}
sdkProvider := models.NewSDKProvider(baseURL, provider.ApiKey, clientType, probeTimeout)
sdkProvider := models.NewSDKProvider(baseURL, creds.APIKey, creds.CodexAccountID, clientType, probeTimeout, nil)
start := time.Now()
result := sdkProvider.Test(ctx)
@@ -238,6 +246,29 @@ func (s *Service) FetchRemoteModels(ctx context.Context, id string) ([]RemoteMod
if err != nil {
return nil, fmt.Errorf("get provider: %w", err)
}
if supportsOAuth(provider) {
catalog := openaicodex.Catalog()
remoteModels := make([]RemoteModel, 0, len(catalog))
for _, model := range catalog {
compatibilities := make([]string, 0, 2)
if model.SupportsToolCall {
compatibilities = append(compatibilities, models.CompatToolCall)
}
if model.SupportsReasoning {
compatibilities = append(compatibilities, models.CompatReasoning)
}
remoteModels = append(remoteModels, RemoteModel{
ID: model.ID,
Name: model.DisplayName,
Object: "model",
OwnedBy: "openai-codex",
Type: "chat",
Compatibilities: compatibilities,
ReasoningEfforts: append([]string(nil), model.ReasoningEfforts...),
})
}
return remoteModels, nil
}
baseURL := strings.TrimRight(provider.BaseUrl, "/")
modelsURL := fmt.Sprintf("%s/models", baseURL)
@@ -250,7 +281,7 @@ func (s *Service) FetchRemoteModels(ctx context.Context, id string) ([]RemoteMod
return nil, fmt.Errorf("create request: %w", err)
}
if provider.ApiKey != "" {
if provider.ApiKey != "" && !supportsOAuth(provider) {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", provider.ApiKey))
}
@@ -284,7 +315,6 @@ func (s *Service) toGetResponse(provider sqlc.LlmProvider) GetResponse {
}
}
// Mask API key (show only first 8 characters)
maskedAPIKey := maskAPIKey(provider.ApiKey)
var icon string
@@ -318,7 +348,6 @@ func maskAPIKey(apiKey string) string {
}
// resolveUpdatedAPIKey keeps the original key when the request value matches the masked version.
// This prevents masked placeholder values from overwriting the real stored credential.
func resolveUpdatedAPIKey(existing string, updated *string) string {
if updated == nil {
return existing
+17 -4
View File
@@ -55,12 +55,25 @@ type TestResponse struct {
Message string `json:"message,omitempty"`
}
// OAuthStatus is returned by GET /providers/:id/oauth/status.
type OAuthStatus struct {
Configured bool `json:"configured"`
HasToken bool `json:"has_token"`
Expired bool `json:"expired"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
CallbackURL string `json:"callback_url"`
}
// RemoteModel represents a model returned by the provider's /v1/models endpoint.
type RemoteModel struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
OwnedBy string `json:"owned_by"`
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
OwnedBy string `json:"owned_by"`
Name string `json:"name,omitempty"`
Type string `json:"type,omitempty"`
Compatibilities []string `json:"compatibilities,omitempty"`
ReasoningEfforts []string `json:"reasoning_efforts,omitempty"`
}
// FetchModelsResponse represents the response from the provider's /v1/models endpoint.
+6
View File
@@ -90,5 +90,11 @@ func shouldSkipJWT(path string) bool {
if strings.HasPrefix(path, "/email/oauth/callback") {
return true
}
if strings.HasPrefix(path, "/providers/oauth/callback") {
return true
}
if strings.HasPrefix(path, "/auth/callback") {
return true
}
return false
}