feat: models import (#164)

This commit is contained in:
Acbox Liu
2026-03-03 15:53:52 +08:00
committed by GitHub
parent 450cc30a9f
commit 5982bc6a42
18 changed files with 669 additions and 32 deletions
+68
View File
@@ -1,6 +1,8 @@
package handlers
import (
"errors"
"fmt"
"log/slog"
"net/http"
"strings"
@@ -36,6 +38,7 @@ func (h *ProvidersHandler) Register(e *echo.Echo) {
group.DELETE("/:id", h.Delete)
group.GET("/count", h.Count)
group.POST("/:id/test", h.Test)
group.POST("/:id/import-models", h.ImportModels)
}
// Create godoc
@@ -282,3 +285,68 @@ func (h *ProvidersHandler) Test(c echo.Context) error {
return c.JSON(http.StatusOK, resp)
}
// ImportModels godoc
// @Summary Import models from provider
// @Description Fetch models from provider's /v1/models endpoint and import them
// @Tags providers
// @Accept json
// @Produce json
// @Param id path string true "Provider ID (UUID)"
// @Param request body providers.ImportModelsRequest true "Import configuration"
// @Success 200 {object} providers.ImportModelsResponse
// @Failure 400 {object} ErrorResponse
// @Failure 404 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Router /providers/{id}/import-models [post]
func (h *ProvidersHandler) ImportModels(c echo.Context) error {
id := c.Param("id")
if id == "" {
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
}
var req providers.ImportModelsRequest
if err := c.Bind(&req); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
if req.ClientType == "" {
req.ClientType = string(models.ClientTypeOpenAICompletions)
}
remoteModels, err := h.service.FetchRemoteModels(c.Request().Context(), id)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("fetch remote models: %v", err))
}
resp := providers.ImportModelsResponse{
Models: make([]string, 0),
}
for _, m := range remoteModels {
// Try to create the model
_, err := h.modelsService.Create(c.Request().Context(), models.AddRequest{
ModelID: m.ID,
Name: m.ID,
LlmProviderID: id,
ClientType: models.ClientType(req.ClientType),
Type: models.ModelTypeChat,
InputModalities: []string{models.ModelInputText},
})
if err != nil {
if errors.Is(err, models.ErrModelIDAlreadyExists) {
resp.Skipped++
continue
}
// Log error but continue with other models
h.logger.Warn("failed to import model", slog.String("model_id", m.ID), slog.Any("error", err))
continue
}
resp.Created++
resp.Models = append(resp.Models, m.ID)
}
return c.JSON(http.StatusOK, resp)
}
+46
View File
@@ -188,6 +188,52 @@ func (s *Service) Test(ctx context.Context, id string) (TestResponse, error) {
}, nil
}
// FetchRemoteModels fetches models from the provider's /v1/models endpoint.
func (s *Service) FetchRemoteModels(ctx context.Context, id string) ([]RemoteModel, error) {
providerID, err := db.ParseUUID(id)
if err != nil {
return nil, err
}
provider, err := s.queries.GetLlmProviderByID(ctx, providerID)
if err != nil {
return nil, fmt.Errorf("get provider: %w", err)
}
baseURL := strings.TrimRight(provider.BaseUrl, "/")
modelsURL := fmt.Sprintf("%s/models", baseURL)
ctx, cancel := context.WithTimeout(ctx, probeTimeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, modelsURL, nil)
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
if provider.ApiKey != "" {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", provider.ApiKey))
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("execute request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body))
}
var fetchResp FetchModelsResponse
if err := json.NewDecoder(resp.Body).Decode(&fetchResp); err != nil {
return nil, fmt.Errorf("decode response: %w", err)
}
return fetchResp.Data, nil
}
func probeReachable(ctx context.Context, baseURL string) (bool, string) {
ctx, cancel := context.WithTimeout(ctx, probeTimeout)
defer cancel()
+26
View File
@@ -46,3 +46,29 @@ type TestResponse struct {
LatencyMs int64 `json:"latency_ms,omitempty"`
Message string `json:"message,omitempty"`
}
// RemoteModel represents a model returned by the provider's /v1/models endpoint
type RemoteModel struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
OwnedBy string `json:"owned_by"`
}
// FetchModelsResponse represents the response from the provider's /v1/models endpoint
type FetchModelsResponse struct {
Object string `json:"object"`
Data []RemoteModel `json:"data"`
}
// ImportModelsRequest represents a request to import models from a provider
type ImportModelsRequest struct {
ClientType string `json:"client_type"`
}
// ImportModelsResponse represents the response for importing models
type ImportModelsResponse struct {
Created int `json:"created"`
Skipped int `json:"skipped"`
Models []string `json:"models"`
}
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
+51
View File
@@ -1040,6 +1040,16 @@ export type ProvidersGetResponse = {
updated_at?: string;
};
export type ProvidersImportModelsRequest = {
client_type?: string;
};
export type ProvidersImportModelsResponse = {
created?: number;
models?: Array<string>;
skipped?: number;
};
export type ProvidersTestResponse = {
latency_ms?: number;
message?: string;
@@ -6151,6 +6161,47 @@ export type PutProvidersByIdResponses = {
export type PutProvidersByIdResponse = PutProvidersByIdResponses[keyof PutProvidersByIdResponses];
export type PostProvidersByIdImportModelsData = {
/**
* Import configuration
*/
body: ProvidersImportModelsRequest;
path: {
/**
* Provider ID (UUID)
*/
id: string;
};
query?: never;
url: '/providers/{id}/import-models';
};
export type PostProvidersByIdImportModelsErrors = {
/**
* Bad Request
*/
400: HandlersErrorResponse;
/**
* Not Found
*/
404: HandlersErrorResponse;
/**
* Internal Server Error
*/
500: HandlersErrorResponse;
};
export type PostProvidersByIdImportModelsError = PostProvidersByIdImportModelsErrors[keyof PostProvidersByIdImportModelsErrors];
export type PostProvidersByIdImportModelsResponses = {
/**
* OK
*/
200: ProvidersImportModelsResponse;
};
export type PostProvidersByIdImportModelsResponse = PostProvidersByIdImportModelsResponses[keyof PostProvidersByIdImportModelsResponses];
export type GetProvidersByIdModelsData = {
body?: never;
path: {
@@ -90,6 +90,53 @@
</FormControl>
</FormItem>
</FormField>
<Separator />
<FormField
v-slot="{ value, handleChange }"
name="auto_import"
>
<FormItem class="flex flex-row items-center justify-between rounded-lg border p-3 shadow-sm">
<div class="space-y-0.5">
<Label class="text-base">
{{ $t('provider.autoImport') }}
</Label>
<p class="text-[0.8rem] text-muted-foreground">
{{ $t('provider.autoImportHint') }}
</p>
</div>
<FormControl>
<Switch
:model-value="value"
@update:model-value="handleChange"
/>
</FormControl>
</FormItem>
</FormField>
<FormField
v-if="form.values.auto_import"
v-slot="{ value, handleChange }"
name="client_type"
>
<FormItem>
<Label class="mb-2">
{{ $t('models.importClientType') }}
</Label>
<FormControl>
<SearchableSelectPopover
:model-value="value"
:options="CLIENT_TYPE_LIST"
:placeholder="$t('models.clientTypePlaceholder')"
@update:model-value="handleChange"
/>
</FormControl>
<p class="text-[0.8rem] text-muted-foreground">
{{ $t('models.importClientTypeHint') }}
</p>
</FormItem>
</FormField>
</div>
</template>
</FormDialogShell>
@@ -103,26 +150,52 @@ import {
FormControl,
FormItem,
Label,
Switch,
Separator,
} from '@memoh/ui'
import { toTypedSchema } from '@vee-validate/zod'
import z from 'zod'
import { useForm,Form,Field } from 'vee-validate'
import { useMutation, useQueryCache } from '@pinia/colada'
import { postProviders } from '@memoh/sdk'
import { postProviders, postProvidersByIdImportModels } from '@memoh/sdk'
import { useI18n } from 'vue-i18n'
import FormDialogShell from '@/components/form-dialog-shell/index.vue'
import { useDialogMutation } from '@/composables/useDialogMutation'
import SearchableSelectPopover from '@/components/searchable-select-popover/index.vue'
import { CLIENT_TYPE_LIST } from '@/constants/client-types'
import { toast } from 'vue-sonner'
const open = defineModel<boolean>('open')
const { t } = useI18n()
const { run } = useDialogMutation()
const queryCache = useQueryCache()
const { mutateAsync: createProviderMutation, isLoading } = useMutation({
mutation: async (data: Record<string, unknown>) => {
const { data: result } = await postProviders({ body: data as any, throwOnError: true })
const payload = {
...data,
metadata: { additionalProp1: {} },
}
const { data: result } = await postProviders({ body: payload as any, throwOnError: true })
if (data.auto_import && result?.id) {
try {
const { data: importResult } = await postProvidersByIdImportModels({
path: { id: result.id },
body: { client_type: data.client_type as string },
throwOnError: true,
})
if (importResult) {
toast.success(t('models.importSuccess', {
created: importResult.created,
skipped: importResult.skipped,
}))
}
}
catch (e) {
console.error('Auto import failed:', e)
toast.error(t('models.importFailed'))
}
}
return result
},
onSettled: () => queryCache.invalidateQueries({ key: ['providers'] }),
@@ -132,13 +205,16 @@ const providerSchema = toTypedSchema(z.object({
api_key: z.string().min(1),
base_url: z.string().min(1),
name: z.string().min(1),
metadata: z.object({
additionalProp1: z.object({}),
}),
auto_import: z.boolean().optional(),
client_type: z.string().optional(),
}))
const form = useForm({
validationSchema: providerSchema,
initialValues: {
auto_import: false,
client_type: 'openai-completions',
},
})
const createProvider = form.handleSubmit(async (value) => {
@@ -148,6 +224,7 @@ const createProvider = form.handleSubmit(async (value) => {
fallbackMessage: t('common.saveFailed'),
onSuccess: () => {
open.value = false
form.resetForm()
},
},
)
@@ -0,0 +1,102 @@
<template>
<FormDialogShell
v-model:open="open"
:title="$t('models.importModels')"
:cancel-text="$t('common.cancel')"
:submit-text="$t('common.import')"
:submit-disabled="!clientType"
:loading="isLoading"
@submit="handleImport"
>
<template #trigger>
<Button
variant="outline"
class="flex items-center gap-2"
>
<FontAwesomeIcon :icon="['fas', 'file-import']" />
{{ $t('models.importModels') }}
</Button>
</template>
<template #body>
<div class="flex flex-col gap-3 mt-4">
<Label class="mb-2">
{{ $t('models.importClientType') }}
</Label>
<SearchableSelectPopover
v-model="clientType"
:options="clientTypeOptions"
:placeholder="$t('models.clientTypePlaceholder')"
class="w-full"
/>
<p class="text-[0.8rem] text-muted-foreground">
{{ $t('models.importClientTypeHint') }}
</p>
</div>
</template>
</FormDialogShell>
</template>
<script setup lang="ts">
import { ref, computed } from 'vue'
import { useI18n } from 'vue-i18n'
import { useMutation, useQueryCache } from '@pinia/colada'
import { postProvidersByIdImportModels } from '@memoh/sdk'
import { toast } from 'vue-sonner'
import { Button, Label } from '@memoh/ui'
import FormDialogShell from '@/components/form-dialog-shell/index.vue'
import SearchableSelectPopover from '@/components/searchable-select-popover/index.vue'
import { CLIENT_TYPE_LIST, CLIENT_TYPE_META } from '@/constants/client-types'
import { useDialogMutation } from '@/composables/useDialogMutation'
const props = defineProps<{
providerId: string
}>()
const open = ref(false)
const { t } = useI18n()
const { run } = useDialogMutation()
const queryCache = useQueryCache()
const clientType = ref('openai-completions')
const clientTypeOptions = computed(() =>
CLIENT_TYPE_LIST.map((ct) => ({
value: ct.value,
label: ct.label,
description: ct.hint,
keywords: [ct.label, ct.hint, CLIENT_TYPE_META[ct.value]?.value ?? ct.value],
})),
)
const { mutateAsync: importModelsMutation, isLoading } = useMutation({
mutation: async () => {
const { data } = await postProvidersByIdImportModels({
path: { id: props.providerId },
body: { client_type: clientType.value },
throwOnError: true,
})
return data
},
onSettled: () => {
queryCache.invalidateQueries({ key: ['provider-models'] })
},
})
async function handleImport() {
await run(
() => importModelsMutation(),
{
fallbackMessage: t('models.importFailed'),
onSuccess: (data) => {
if (data) {
toast.success(t('models.importSuccess', {
created: data.created,
skipped: data.skipped,
}))
}
open.value = false
},
},
)
}
</script>
@@ -42,8 +42,8 @@
>
</div>
<ScrollArea
class="max-h-64"
<div
class="max-h-64 overflow-y-auto"
role="listbox"
>
<div
@@ -112,7 +112,7 @@
</slot>
</button>
</div>
</ScrollArea>
</div>
</PopoverContent>
</Popover>
</template>
@@ -123,7 +123,6 @@ import {
PopoverTrigger,
PopoverContent,
Button,
ScrollArea,
} from '@memoh/ui'
import { computed, ref, watch } from 'vue'
+9 -2
View File
@@ -169,7 +169,12 @@
"testOk": "OK",
"testAuthError": "Auth Error",
"testError": "Error",
"testFailed": "Test failed"
"testFailed": "Test failed",
"importModels": "Import Models",
"importSuccess": "Successfully imported {created} models, skipped {skipped}",
"importFailed": "Failed to import models",
"importClientType": "Model Client Type",
"importClientTypeHint": "Set default client type for imported models"
},
"provider": {
"add": "Add Provider",
@@ -185,7 +190,9 @@
"testConnection": "Test Connection",
"reachable": "Reachable",
"unreachable": "Unreachable",
"testFailed": "Test failed"
"testFailed": "Test failed",
"autoImport": "Auto Import Models",
"autoImportHint": "Automatically fetch and import models from the provider after creation"
},
"searchProvider": {
"title": "Search Providers",
+9 -2
View File
@@ -165,7 +165,12 @@
"testOk": "正常",
"testAuthError": "认证失败",
"testError": "异常",
"testFailed": "测试失败"
"testFailed": "测试失败",
"importModels": "导入模型",
"importSuccess": "成功导入 {created} 个模型,跳过 {skipped} 个",
"importFailed": "导入模型失败",
"importClientType": "模型客户端类型",
"importClientTypeHint": "为导入的模型设置默认客户端类型"
},
"provider": {
"add": "添加服务商",
@@ -181,7 +186,9 @@
"testConnection": "测试连接",
"reachable": "可连接",
"unreachable": "不可连接",
"testFailed": "测试失败"
"testFailed": "测试失败",
"autoImport": "自动导入模型",
"autoImportHint": "创建后自动从服务商获取并导入模型"
},
"searchProvider": {
"title": "搜索提供方",
@@ -91,7 +91,7 @@ import {
import ConfirmPopover from '@/components/confirm-popover/index.vue'
import { postModelsByIdTest } from '@memoh/sdk'
import type { ModelsGetResponse, ModelsTestResponse } from '@memoh/sdk'
import { ref, computed, onMounted } from 'vue'
import { ref, computed } from 'vue'
const props = defineProps<{
model: ModelsGetResponse
@@ -132,7 +132,4 @@ async function runTest() {
}
}
onMounted(() => {
runTest()
})
</script>
@@ -4,10 +4,13 @@
<h4 class="scroll-m-20 font-semibold tracking-tight">
{{ $t('models.title') }}
</h4>
<CreateModel
<div
v-if="providerId"
:id="providerId"
/>
class="flex items-center gap-2 ml-auto"
>
<ImportModelsDialog :provider-id="providerId" />
<CreateModel :id="providerId" />
</div>
</section>
<section
@@ -50,6 +53,7 @@ import {
EmptyTitle,
} from '@memoh/ui'
import CreateModel from '@/components/create-model/index.vue'
import ImportModelsDialog from '@/components/import-models-dialog/index.vue'
import ModelItem from './model-item.vue'
import type { ModelsGetResponse } from '@memoh/sdk'
@@ -73,9 +73,8 @@
:disabled="!props.provider?.id"
@click="runTest"
>
<Spinner v-if="testLoading" />
<FontAwesomeIcon
v-else
v-if="!testLoading"
:icon="['fas', 'rotate']"
/>
{{ $t('provider.testConnection') }}
@@ -197,9 +196,10 @@ async function runTest() {
}
}
watch(() => props.provider?.id, (newId) => {
if (newId) runTest()
}, { immediate: true })
watch(() => props.provider?.id, () => {
testResult.value = null
testError.value = ''
})
const providerSchema = toTypedSchema(z.object({
name: z.string().min(1),
+84
View File
@@ -6002,6 +6002,65 @@ const docTemplate = `{
}
}
},
"/providers/{id}/import-models": {
"post": {
"description": "Fetch models from provider's /v1/models endpoint and import them",
"consumes": [
"application/json"
],
"produces": [
"application/json"
],
"tags": [
"providers"
],
"summary": "Import models from provider",
"parameters": [
{
"type": "string",
"description": "Provider ID (UUID)",
"name": "id",
"in": "path",
"required": true
},
{
"description": "Import configuration",
"name": "request",
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/providers.ImportModelsRequest"
}
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/providers.ImportModelsResponse"
}
},
"400": {
"description": "Bad Request",
"schema": {
"$ref": "#/definitions/handlers.ErrorResponse"
}
},
"404": {
"description": "Not Found",
"schema": {
"$ref": "#/definitions/handlers.ErrorResponse"
}
},
"500": {
"description": "Internal Server Error",
"schema": {
"$ref": "#/definitions/handlers.ErrorResponse"
}
}
}
}
},
"/providers/{id}/models": {
"get": {
"description": "Get models for a provider by id, optionally filtered by type",
@@ -9378,6 +9437,31 @@ const docTemplate = `{
}
}
},
"providers.ImportModelsRequest": {
"type": "object",
"properties": {
"client_type": {
"type": "string"
}
}
},
"providers.ImportModelsResponse": {
"type": "object",
"properties": {
"created": {
"type": "integer"
},
"models": {
"type": "array",
"items": {
"type": "string"
}
},
"skipped": {
"type": "integer"
}
}
},
"providers.TestResponse": {
"type": "object",
"properties": {
+84
View File
@@ -5993,6 +5993,65 @@
}
}
},
"/providers/{id}/import-models": {
"post": {
"description": "Fetch models from provider's /v1/models endpoint and import them",
"consumes": [
"application/json"
],
"produces": [
"application/json"
],
"tags": [
"providers"
],
"summary": "Import models from provider",
"parameters": [
{
"type": "string",
"description": "Provider ID (UUID)",
"name": "id",
"in": "path",
"required": true
},
{
"description": "Import configuration",
"name": "request",
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/providers.ImportModelsRequest"
}
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/providers.ImportModelsResponse"
}
},
"400": {
"description": "Bad Request",
"schema": {
"$ref": "#/definitions/handlers.ErrorResponse"
}
},
"404": {
"description": "Not Found",
"schema": {
"$ref": "#/definitions/handlers.ErrorResponse"
}
},
"500": {
"description": "Internal Server Error",
"schema": {
"$ref": "#/definitions/handlers.ErrorResponse"
}
}
}
}
},
"/providers/{id}/models": {
"get": {
"description": "Get models for a provider by id, optionally filtered by type",
@@ -9369,6 +9428,31 @@
}
}
},
"providers.ImportModelsRequest": {
"type": "object",
"properties": {
"client_type": {
"type": "string"
}
}
},
"providers.ImportModelsResponse": {
"type": "object",
"properties": {
"created": {
"type": "integer"
},
"models": {
"type": "array",
"items": {
"type": "string"
}
},
"skipped": {
"type": "integer"
}
}
},
"providers.TestResponse": {
"type": "object",
"properties": {
+55
View File
@@ -1699,6 +1699,22 @@ definitions:
updated_at:
type: string
type: object
providers.ImportModelsRequest:
properties:
client_type:
type: string
type: object
providers.ImportModelsResponse:
properties:
created:
type: integer
models:
items:
type: string
type: array
skipped:
type: integer
type: object
providers.TestResponse:
properties:
latency_ms:
@@ -5991,6 +6007,45 @@ paths:
summary: Update provider
tags:
- providers
/providers/{id}/import-models:
post:
consumes:
- application/json
description: Fetch models from provider's /v1/models endpoint and import them
parameters:
- description: Provider ID (UUID)
in: path
name: id
required: true
type: string
- description: Import configuration
in: body
name: request
required: true
schema:
$ref: '#/definitions/providers.ImportModelsRequest'
produces:
- application/json
responses:
"200":
description: OK
schema:
$ref: '#/definitions/providers.ImportModelsResponse'
"400":
description: Bad Request
schema:
$ref: '#/definitions/handlers.ErrorResponse'
"404":
description: Not Found
schema:
$ref: '#/definitions/handlers.ErrorResponse'
"500":
description: Internal Server Error
schema:
$ref: '#/definitions/handlers.ErrorResponse'
summary: Import models from provider
tags:
- providers
/providers/{id}/models:
get:
description: Get models for a provider by id, optionally filtered by type