mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
feat(models): add image model type support
Add a dedicated image model type so bots can use image API models without overloading chat model capabilities, while keeping existing chat-based image generation selectable.
This commit is contained in:
@@ -41,6 +41,9 @@
|
||||
<SelectItem value="embedding">
|
||||
Embedding
|
||||
</SelectItem>
|
||||
<SelectItem value="image">
|
||||
Image
|
||||
</SelectItem>
|
||||
</SelectGroup>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
@@ -107,14 +110,14 @@
|
||||
</FormItem>
|
||||
</FormField>
|
||||
|
||||
<!-- Compatibilities (chat only) -->
|
||||
<div v-if="selectedType === 'chat'">
|
||||
<!-- Compatibilities -->
|
||||
<div v-if="selectedCompatibilityOptions.length > 0">
|
||||
<Label class="mb-4">
|
||||
{{ $t('models.compatibilities') }}
|
||||
</Label>
|
||||
<div class="flex flex-wrap gap-3 mt-2">
|
||||
<label
|
||||
v-for="opt in COMPATIBILITY_OPTIONS"
|
||||
v-for="opt in selectedCompatibilityOptions"
|
||||
:key="opt.value"
|
||||
class="flex items-center gap-1.5 text-xs"
|
||||
>
|
||||
@@ -177,7 +180,7 @@ import { useMutation, useQueryCache } from '@pinia/colada'
|
||||
import { postModels, putModelsById, putModelsModelByModelId } from '@memohai/sdk'
|
||||
import type { ModelsGetResponse, ModelsAddRequest, ModelsUpdateRequest } from '@memohai/sdk'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import { COMPATIBILITY_OPTIONS } from '@/constants/compatibilities'
|
||||
import { CHAT_COMPATIBILITY_OPTIONS, IMAGE_COMPATIBILITY_OPTIONS } from '@/constants/compatibilities'
|
||||
import FormDialogShell from '@/components/form-dialog-shell/index.vue'
|
||||
import { useDialogMutation } from '@/composables/useDialogMutation'
|
||||
|
||||
@@ -201,6 +204,21 @@ const form = useForm({
|
||||
})
|
||||
|
||||
const selectedType = computed(() => form.values.type || 'chat')
|
||||
const selectedCompatibilityOptions = computed(() => {
|
||||
switch (selectedType.value) {
|
||||
case 'chat':
|
||||
return CHAT_COMPATIBILITY_OPTIONS
|
||||
case 'image':
|
||||
return IMAGE_COMPATIBILITY_OPTIONS
|
||||
default:
|
||||
return []
|
||||
}
|
||||
})
|
||||
|
||||
watch(selectedCompatibilityOptions, (options) => {
|
||||
const allowed = new Set(options.map(option => option.value))
|
||||
selectedCompat.value = selectedCompat.value.filter(value => allowed.has(value))
|
||||
}, { immediate: true })
|
||||
|
||||
const open = inject<Ref<boolean>>('openModel', ref(false))
|
||||
const title = inject<Ref<'edit' | 'title'>>('openModelTitle', ref('title'))
|
||||
@@ -288,8 +306,11 @@ async function addModel() {
|
||||
if (dim) config.dimensions = dim
|
||||
}
|
||||
|
||||
if (type === 'chat') {
|
||||
if (type === 'chat' || type === 'image') {
|
||||
config.compatibilities = selectedCompat.value
|
||||
}
|
||||
|
||||
if (type === 'chat') {
|
||||
const ctxWin = form.values.context_window ?? (isEdit ? fallback!.config?.context_window : undefined)
|
||||
if (ctxWin) config.context_window = ctxWin
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
|
||||
<script setup lang="ts">
|
||||
import type { Component } from 'vue'
|
||||
import { Wrench, Eye, Image, Brain } from 'lucide-vue-next'
|
||||
import { Wrench, Eye, Image, Brain, Pencil } from 'lucide-vue-next'
|
||||
|
||||
defineProps<{
|
||||
compatibilities: string[]
|
||||
@@ -25,6 +25,8 @@ const ICONS: Record<string, Component> = {
|
||||
'tool-call': Wrench,
|
||||
'vision': Eye,
|
||||
'image-output': Image,
|
||||
'generate': Image,
|
||||
'edit': Pencil,
|
||||
'reasoning': Brain,
|
||||
}
|
||||
|
||||
@@ -32,6 +34,8 @@ const CLASSES: Record<string, string> = {
|
||||
'tool-call': 'bg-blue-50 text-blue-700 dark:bg-blue-950 dark:text-blue-300',
|
||||
'vision': 'bg-purple-50 text-purple-700 dark:bg-purple-950 dark:text-purple-300',
|
||||
'image-output': 'bg-pink-50 text-pink-700 dark:bg-pink-950 dark:text-pink-300',
|
||||
'generate': 'bg-pink-50 text-pink-700 dark:bg-pink-950 dark:text-pink-300',
|
||||
'edit': 'bg-emerald-50 text-emerald-700 dark:bg-emerald-950 dark:text-emerald-300',
|
||||
'reasoning': 'bg-amber-50 text-amber-700 dark:bg-amber-950 dark:text-amber-300',
|
||||
}
|
||||
|
||||
|
||||
@@ -3,9 +3,19 @@ export interface CompatibilityMeta {
|
||||
label: string
|
||||
}
|
||||
|
||||
export const COMPATIBILITY_OPTIONS: CompatibilityMeta[] = [
|
||||
export const CHAT_COMPATIBILITY_OPTIONS: CompatibilityMeta[] = [
|
||||
{ value: 'vision', label: 'Vision' },
|
||||
{ value: 'tool-call', label: 'Tool Call' },
|
||||
{ value: 'image-output', label: 'Image Output' },
|
||||
{ value: 'reasoning', label: 'Reasoning' },
|
||||
]
|
||||
|
||||
export const IMAGE_COMPATIBILITY_OPTIONS: CompatibilityMeta[] = [
|
||||
{ value: 'generate', label: 'Generate' },
|
||||
{ value: 'edit', label: 'Edit' },
|
||||
]
|
||||
|
||||
export const COMPATIBILITY_OPTIONS: CompatibilityMeta[] = [
|
||||
...CHAT_COMPATIBILITY_OPTIONS,
|
||||
...IMAGE_COMPATIBILITY_OPTIONS,
|
||||
]
|
||||
|
||||
@@ -257,6 +257,8 @@
|
||||
"vision": "Vision",
|
||||
"tool-call": "Tool Call",
|
||||
"image-output": "Image Output",
|
||||
"generate": "Generate",
|
||||
"edit": "Edit",
|
||||
"reasoning": "Reasoning"
|
||||
},
|
||||
"contextWindow": "Context Window",
|
||||
@@ -916,7 +918,7 @@
|
||||
"ttsModel": "TTS Model",
|
||||
"ttsModelPlaceholder": "Select TTS model",
|
||||
"imageModel": "Image Generation Model",
|
||||
"imageModelDescription": "Model used for the generate_image tool. Must support image-output compatibility.",
|
||||
"imageModelDescription": "Model used for image tools. Supports chat models with image-output or image models with generate/edit compatibility.",
|
||||
"imageModelPlaceholder": "Select image model (optional)",
|
||||
"language": "Language",
|
||||
"reasoningEnabled": "Enable Reasoning",
|
||||
|
||||
@@ -253,6 +253,8 @@
|
||||
"vision": "视觉",
|
||||
"tool-call": "工具调用",
|
||||
"image-output": "图片生成",
|
||||
"generate": "生成",
|
||||
"edit": "编辑",
|
||||
"reasoning": "推理"
|
||||
},
|
||||
"contextWindow": "上下文窗口",
|
||||
@@ -912,7 +914,7 @@
|
||||
"ttsModel": "语音合成模型",
|
||||
"ttsModelPlaceholder": "选择语音合成模型",
|
||||
"imageModel": "图片生成模型",
|
||||
"imageModelDescription": "用于 generate_image 工具的模型,必须支持 image-output 兼容性。",
|
||||
"imageModelDescription": "用于图片工具的模型。可选择支持 image-output 的 chat 模型,或支持 generate/edit 的 image 模型。",
|
||||
"imageModelPlaceholder": "选择图片模型(可选)",
|
||||
"language": "语言",
|
||||
"reasoningEnabled": "启用推理",
|
||||
|
||||
@@ -197,7 +197,6 @@
|
||||
v-model="form.image_model_id"
|
||||
:models="imageCapableModels"
|
||||
:providers="providers"
|
||||
model-type="chat"
|
||||
:placeholder="$t('bots.settings.imageModelPlaceholder')"
|
||||
/>
|
||||
</div>
|
||||
@@ -488,7 +487,15 @@ const { mutateAsync: deleteBot, isLoading: deleteLoading } = useMutation({
|
||||
const models = computed(() => modelData.value ?? [])
|
||||
const providers = computed(() => providerData.value ?? [])
|
||||
const imageCapableModels = computed(() =>
|
||||
models.value.filter((m) => m.config?.compatibilities?.includes('image-output')),
|
||||
models.value.filter((m) => {
|
||||
if (m.type === 'chat') {
|
||||
return m.config?.compatibilities?.includes('image-output')
|
||||
}
|
||||
if (m.type === 'image') {
|
||||
return m.config?.compatibilities?.includes('generate') || m.config?.compatibilities?.includes('edit')
|
||||
}
|
||||
return false
|
||||
}),
|
||||
)
|
||||
const searchProviders = computed(() => (searchProviderData.value ?? []).filter((p) => p.enable !== false))
|
||||
const memoryProviders = computed(() => memoryProviderData.value ?? [])
|
||||
|
||||
@@ -93,7 +93,7 @@ export interface ModelOption {
|
||||
const props = defineProps<{
|
||||
models: ModelsGetResponse[]
|
||||
providers: ProvidersGetResponse[]
|
||||
modelType: 'chat' | 'embedding'
|
||||
modelType?: 'chat' | 'embedding' | 'image'
|
||||
open?: boolean
|
||||
}>()
|
||||
|
||||
@@ -118,7 +118,7 @@ const providerMap = computed(() => {
|
||||
})
|
||||
|
||||
const typeFilteredModels = computed(() =>
|
||||
props.models.filter((m) => m.type === props.modelType),
|
||||
props.modelType ? props.models.filter((m) => m.type === props.modelType) : props.models,
|
||||
)
|
||||
|
||||
const options = computed<ModelOption[]>(() =>
|
||||
|
||||
@@ -41,7 +41,7 @@ import ModelOptions from './model-options.vue'
|
||||
const props = defineProps<{
|
||||
models: ModelsGetResponse[]
|
||||
providers: ProvidersGetResponse[]
|
||||
modelType: 'chat' | 'embedding'
|
||||
modelType?: 'chat' | 'embedding' | 'image'
|
||||
placeholder?: string
|
||||
}>()
|
||||
|
||||
|
||||
@@ -101,7 +101,7 @@ import {
|
||||
Button,
|
||||
Spinner,
|
||||
} from '@memohai/ui'
|
||||
import { RefreshCw, Settings, Trash2, MessageSquare, Binary } from 'lucide-vue-next'
|
||||
import { RefreshCw, Settings, Trash2, MessageSquare, Binary, Image } from 'lucide-vue-next'
|
||||
import ConfirmPopover from '@/components/confirm-popover/index.vue'
|
||||
import ModelCapabilities from '@/components/model-capabilities/index.vue'
|
||||
import ContextWindowBadge from '@/components/context-window-badge/index.vue'
|
||||
@@ -128,7 +128,14 @@ const testResult = ref<ModelsTestResponse | null>(null)
|
||||
const reasoningEfforts = computed(() => ((props.model.config as ModelConfigWithReasoning | undefined)?.reasoning_efforts ?? []))
|
||||
|
||||
const typeIcon = computed(() => {
|
||||
return props.model.type === 'embedding' ? Binary : MessageSquare
|
||||
switch (props.model.type) {
|
||||
case 'embedding':
|
||||
return Binary
|
||||
case 'image':
|
||||
return Image
|
||||
default:
|
||||
return MessageSquare
|
||||
}
|
||||
})
|
||||
|
||||
const statusDotClass = computed(() => {
|
||||
|
||||
@@ -92,7 +92,7 @@ CREATE TABLE IF NOT EXISTS models (
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
CONSTRAINT models_provider_id_model_id_unique UNIQUE (provider_id, model_id),
|
||||
CONSTRAINT models_type_check CHECK (type IN ('chat', 'embedding', 'speech'))
|
||||
CONSTRAINT models_type_check CHECK (type IN ('chat', 'embedding', 'image', 'speech'))
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS model_variants (
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
-- 0067_add_image_model_type (rollback)
|
||||
-- Remove image from the supported model types.
|
||||
ALTER TABLE models DROP CONSTRAINT IF EXISTS models_type_check;
|
||||
|
||||
ALTER TABLE models
|
||||
ADD CONSTRAINT models_type_check
|
||||
CHECK (type IN ('chat', 'embedding', 'speech'));
|
||||
@@ -0,0 +1,7 @@
|
||||
-- 0067_add_image_model_type
|
||||
-- Add image as a supported model type.
|
||||
ALTER TABLE models DROP CONSTRAINT IF EXISTS models_type_check;
|
||||
|
||||
ALTER TABLE models
|
||||
ADD CONSTRAINT models_type_check
|
||||
CHECK (type IN ('chat', 'embedding', 'image', 'speech'));
|
||||
@@ -29,6 +29,11 @@ type ImageGenProvider struct {
|
||||
dataMount string
|
||||
}
|
||||
|
||||
type generatedImageFile struct {
|
||||
Data string
|
||||
MediaType string
|
||||
}
|
||||
|
||||
func NewImageGenProvider(
|
||||
log *slog.Logger,
|
||||
settingsSvc *settings.Service,
|
||||
@@ -65,6 +70,10 @@ func (p *ImageGenProvider) Tools(ctx context.Context, session SessionContext) ([
|
||||
if strings.TrimSpace(botSettings.ImageModelID) == "" {
|
||||
return nil, nil
|
||||
}
|
||||
modelResp, err := p.models.GetByID(ctx, botSettings.ImageModelID)
|
||||
if err != nil || !supportsImageGeneration(modelResp) {
|
||||
return nil, nil
|
||||
}
|
||||
sess := session
|
||||
return []sdk.Tool{
|
||||
{
|
||||
@@ -112,7 +121,7 @@ func (p *ImageGenProvider) execGenerateImage(ctx context.Context, session Sessio
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load image model: %w", err)
|
||||
}
|
||||
if !modelResp.HasCompatibility(models.CompatImageOutput) {
|
||||
if !supportsImageGeneration(modelResp) {
|
||||
return nil, errors.New("configured model does not support image generation")
|
||||
}
|
||||
|
||||
@@ -127,43 +136,9 @@ func (p *ImageGenProvider) execGenerateImage(ctx context.Context, session Sessio
|
||||
return nil, fmt.Errorf("failed to resolve provider credentials: %w", err)
|
||||
}
|
||||
|
||||
sdkModel := models.NewSDKChatModel(models.SDKModelConfig{
|
||||
ModelID: modelResp.ModelID,
|
||||
ClientType: provider.ClientType,
|
||||
APIKey: creds.APIKey,
|
||||
BaseURL: providers.ProviderConfigString(provider, "base_url"),
|
||||
})
|
||||
|
||||
userMsg := fmt.Sprintf("Generate an image with the following description. Size: %s\n\n%s", size, prompt)
|
||||
result, err := sdk.GenerateTextResult(ctx,
|
||||
sdk.WithModel(sdkModel),
|
||||
sdk.WithMessages([]sdk.Message{
|
||||
{Role: sdk.MessageRoleUser, Content: []sdk.MessagePart{sdk.TextPart{Text: userMsg}}},
|
||||
}),
|
||||
)
|
||||
file, imgBytes, ext, err := generateImage(ctx, modelResp, provider, creds, prompt, size)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("image generation failed: %w", err)
|
||||
}
|
||||
|
||||
if len(result.Files) == 0 {
|
||||
if result.Text != "" {
|
||||
return map[string]any{"error": "no image generated", "model_response": result.Text}, nil
|
||||
}
|
||||
return nil, errors.New("no image was generated by the model")
|
||||
}
|
||||
|
||||
file := result.Files[0]
|
||||
imgBytes, err := base64.StdEncoding.DecodeString(file.Data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode generated image: %w", err)
|
||||
}
|
||||
|
||||
ext := "png"
|
||||
switch {
|
||||
case strings.Contains(file.MediaType, "jpeg"), strings.Contains(file.MediaType, "jpg"):
|
||||
ext = "jpg"
|
||||
case strings.Contains(file.MediaType, "webp"):
|
||||
ext = "webp"
|
||||
return nil, err
|
||||
}
|
||||
|
||||
containerPath := fmt.Sprintf("%s/%d.%s", imageGenDir, time.Now().UnixMilli(), ext)
|
||||
@@ -196,3 +171,138 @@ func (p *ImageGenProvider) execGenerateImage(ctx context.Context, session Sessio
|
||||
"size_bytes": len(imgBytes),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func supportsImageGeneration(model models.GetResponse) bool {
|
||||
switch model.Type {
|
||||
case models.ModelTypeChat:
|
||||
return model.HasCompatibility(models.CompatImageOutput)
|
||||
case models.ModelTypeImage:
|
||||
return model.HasCompatibility(models.CompatGenerate)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func generateImage(
|
||||
ctx context.Context,
|
||||
modelResp models.GetResponse,
|
||||
provider sqlc.Provider,
|
||||
creds providers.ModelCredentials,
|
||||
prompt string,
|
||||
size string,
|
||||
) (generatedImageFile, []byte, string, error) {
|
||||
switch modelResp.Type {
|
||||
case models.ModelTypeChat:
|
||||
return generateImageFromChatModel(ctx, modelResp, provider, creds, prompt, size)
|
||||
case models.ModelTypeImage:
|
||||
return generateImageFromImageModel(ctx, modelResp, provider, creds, prompt, size)
|
||||
default:
|
||||
return generatedImageFile{}, nil, "", fmt.Errorf("unsupported image model type: %s", modelResp.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func generateImageFromChatModel(
|
||||
ctx context.Context,
|
||||
modelResp models.GetResponse,
|
||||
provider sqlc.Provider,
|
||||
creds providers.ModelCredentials,
|
||||
prompt string,
|
||||
size string,
|
||||
) (generatedImageFile, []byte, string, error) {
|
||||
sdkModel := models.NewSDKChatModel(models.SDKModelConfig{
|
||||
ModelID: modelResp.ModelID,
|
||||
ClientType: provider.ClientType,
|
||||
APIKey: creds.APIKey,
|
||||
BaseURL: providers.ProviderConfigString(provider, "base_url"),
|
||||
})
|
||||
|
||||
userMsg := fmt.Sprintf("Generate an image with the following description. Size: %s\n\n%s", size, prompt)
|
||||
result, err := sdk.GenerateTextResult(ctx,
|
||||
sdk.WithModel(sdkModel),
|
||||
sdk.WithMessages([]sdk.Message{
|
||||
{Role: sdk.MessageRoleUser, Content: []sdk.MessagePart{sdk.TextPart{Text: userMsg}}},
|
||||
}),
|
||||
)
|
||||
if err != nil {
|
||||
return generatedImageFile{}, nil, "", fmt.Errorf("image generation failed: %w", err)
|
||||
}
|
||||
|
||||
if len(result.Files) == 0 {
|
||||
if result.Text != "" {
|
||||
return generatedImageFile{}, nil, "", fmt.Errorf("no image generated: %s", result.Text)
|
||||
}
|
||||
return generatedImageFile{}, nil, "", errors.New("no image was generated by the model")
|
||||
}
|
||||
|
||||
file := generatedImageFile{
|
||||
Data: result.Files[0].Data,
|
||||
MediaType: result.Files[0].MediaType,
|
||||
}
|
||||
imgBytes, ext, err := decodeGeneratedImage(file)
|
||||
if err != nil {
|
||||
return generatedImageFile{}, nil, "", err
|
||||
}
|
||||
return file, imgBytes, ext, nil
|
||||
}
|
||||
|
||||
func generateImageFromImageModel(
|
||||
ctx context.Context,
|
||||
modelResp models.GetResponse,
|
||||
provider sqlc.Provider,
|
||||
creds providers.ModelCredentials,
|
||||
prompt string,
|
||||
size string,
|
||||
) (generatedImageFile, []byte, string, error) {
|
||||
imageModel := models.NewSDKImageGenerationModel(models.SDKModelConfig{
|
||||
ModelID: modelResp.ModelID,
|
||||
ClientType: provider.ClientType,
|
||||
APIKey: creds.APIKey,
|
||||
BaseURL: providers.ProviderConfigString(provider, "base_url"),
|
||||
})
|
||||
if imageModel == nil {
|
||||
return generatedImageFile{}, nil, "", errors.New("configured provider does not support image generation API")
|
||||
}
|
||||
|
||||
result, err := sdk.GenerateImage(ctx,
|
||||
sdk.WithImageGenerationModel(imageModel),
|
||||
sdk.WithImagePrompt(prompt),
|
||||
sdk.WithImageSize(size),
|
||||
sdk.WithImageResponseFormat("b64_json"),
|
||||
sdk.WithImageOutputFormat("png"),
|
||||
)
|
||||
if err != nil {
|
||||
return generatedImageFile{}, nil, "", fmt.Errorf("image generation failed: %w", err)
|
||||
}
|
||||
if len(result.Data) == 0 {
|
||||
return generatedImageFile{}, nil, "", errors.New("no image was generated by the model")
|
||||
}
|
||||
if strings.TrimSpace(result.Data[0].B64JSON) == "" {
|
||||
return generatedImageFile{}, nil, "", errors.New("image model did not return inline image data")
|
||||
}
|
||||
|
||||
file := generatedImageFile{
|
||||
Data: result.Data[0].B64JSON,
|
||||
MediaType: "image/png",
|
||||
}
|
||||
imgBytes, ext, err := decodeGeneratedImage(file)
|
||||
if err != nil {
|
||||
return generatedImageFile{}, nil, "", err
|
||||
}
|
||||
return file, imgBytes, ext, nil
|
||||
}
|
||||
|
||||
func decodeGeneratedImage(file generatedImageFile) ([]byte, string, error) {
|
||||
imgBytes, err := base64.StdEncoding.DecodeString(file.Data)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to decode generated image: %w", err)
|
||||
}
|
||||
|
||||
ext := "png"
|
||||
switch {
|
||||
case strings.Contains(file.MediaType, "jpeg"), strings.Contains(file.MediaType, "jpg"):
|
||||
ext = "jpg"
|
||||
case strings.Contains(file.MediaType, "webp"):
|
||||
ext = "webp"
|
||||
}
|
||||
return imgBytes, ext, nil
|
||||
}
|
||||
|
||||
@@ -70,7 +70,7 @@ func (h *ModelsHandler) Create(c echo.Context) error {
|
||||
// @Summary List all models
|
||||
// @Description Get a list of all configured models, optionally filtered by type or provider client type
|
||||
// @Tags models
|
||||
// @Param type query string false "Model type (chat, embedding)"
|
||||
// @Param type query string false "Model type (chat, embedding, image, speech)"
|
||||
// @Param client_type query string false "Provider client type (openai-responses, openai-completions, anthropic-messages, google-generative-ai)"
|
||||
// @Success 200 {array} models.GetResponse
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
|
||||
@@ -332,12 +332,20 @@ func (h *ProvidersHandler) ImportModels(c echo.Context) error {
|
||||
|
||||
for _, m := range remoteModels {
|
||||
modelType := models.ModelTypeChat
|
||||
if strings.TrimSpace(m.Type) == string(models.ModelTypeEmbedding) {
|
||||
switch strings.TrimSpace(m.Type) {
|
||||
case string(models.ModelTypeEmbedding):
|
||||
modelType = models.ModelTypeEmbedding
|
||||
case string(models.ModelTypeImage):
|
||||
modelType = models.ModelTypeImage
|
||||
}
|
||||
compatibilities := m.Compatibilities
|
||||
if len(compatibilities) == 0 {
|
||||
compatibilities = []string{models.CompatVision, models.CompatToolCall, models.CompatReasoning}
|
||||
switch modelType {
|
||||
case models.ModelTypeImage:
|
||||
compatibilities = []string{models.CompatGenerate, models.CompatEdit}
|
||||
case models.ModelTypeChat:
|
||||
compatibilities = []string{models.CompatVision, models.CompatToolCall, models.CompatReasoning}
|
||||
}
|
||||
}
|
||||
name := strings.TrimSpace(m.Name)
|
||||
if name == "" {
|
||||
|
||||
@@ -128,7 +128,7 @@ func (s *Service) List(ctx context.Context) ([]GetResponse, error) {
|
||||
|
||||
// ListByType returns models filtered by type (chat, embedding, or speech).
|
||||
func (s *Service) ListByType(ctx context.Context, modelType ModelType) ([]GetResponse, error) {
|
||||
if modelType != ModelTypeChat && modelType != ModelTypeEmbedding && modelType != ModelTypeSpeech {
|
||||
if !IsValidModelType(modelType) {
|
||||
return nil, fmt.Errorf("invalid model type: %s", modelType)
|
||||
}
|
||||
|
||||
@@ -165,7 +165,7 @@ func (s *Service) ListEnabled(ctx context.Context) ([]GetResponse, error) {
|
||||
|
||||
// ListEnabledByType returns models from enabled providers filtered by type.
|
||||
func (s *Service) ListEnabledByType(ctx context.Context, modelType ModelType) ([]GetResponse, error) {
|
||||
if modelType != ModelTypeChat && modelType != ModelTypeEmbedding && modelType != ModelTypeSpeech {
|
||||
if !IsValidModelType(modelType) {
|
||||
return nil, fmt.Errorf("invalid model type: %s", modelType)
|
||||
}
|
||||
dbModels, err := s.queries.ListEnabledModelsByType(ctx, string(modelType))
|
||||
@@ -206,7 +206,7 @@ func (s *Service) ListByProviderID(ctx context.Context, providerID string) ([]Ge
|
||||
|
||||
// ListByProviderIDAndType returns models filtered by provider ID and type.
|
||||
func (s *Service) ListByProviderIDAndType(ctx context.Context, providerID string, modelType ModelType) ([]GetResponse, error) {
|
||||
if modelType != ModelTypeChat && modelType != ModelTypeEmbedding && modelType != ModelTypeSpeech {
|
||||
if !IsValidModelType(modelType) {
|
||||
return nil, fmt.Errorf("invalid model type: %s", modelType)
|
||||
}
|
||||
if strings.TrimSpace(providerID) == "" {
|
||||
@@ -361,7 +361,7 @@ func (s *Service) Count(ctx context.Context) (int64, error) {
|
||||
|
||||
// CountByType returns the number of models of a specific type.
|
||||
func (s *Service) CountByType(ctx context.Context, modelType ModelType) (int64, error) {
|
||||
if modelType != ModelTypeChat && modelType != ModelTypeEmbedding && modelType != ModelTypeSpeech {
|
||||
if !IsValidModelType(modelType) {
|
||||
return 0, fmt.Errorf("invalid model type: %s", modelType)
|
||||
}
|
||||
|
||||
|
||||
@@ -50,6 +50,19 @@ func TestModel_Validate(t *testing.T) {
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid image model",
|
||||
model: models.Model{
|
||||
ModelID: "gpt-image-1",
|
||||
Name: "GPT Image 1",
|
||||
ProviderID: "11111111-1111-1111-1111-111111111111",
|
||||
Type: models.ModelTypeImage,
|
||||
Config: models.ModelConfig{
|
||||
Compatibilities: []string{models.CompatGenerate, models.CompatEdit},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing model_id",
|
||||
model: models.Model{
|
||||
@@ -129,12 +142,14 @@ func TestModel_HasCompatibility(t *testing.T) {
|
||||
assert.True(t, m.HasCompatibility("tool-call"))
|
||||
assert.True(t, m.HasCompatibility("reasoning"))
|
||||
assert.False(t, m.HasCompatibility("image-output"))
|
||||
assert.False(t, m.HasCompatibility("generate"))
|
||||
}
|
||||
|
||||
func TestModelTypes(t *testing.T) {
|
||||
t.Run("ModelType constants", func(t *testing.T) {
|
||||
assert.Equal(t, models.ModelTypeChat, models.ModelType("chat"))
|
||||
assert.Equal(t, models.ModelTypeEmbedding, models.ModelType("embedding"))
|
||||
assert.Equal(t, models.ModelTypeImage, models.ModelType("image"))
|
||||
})
|
||||
|
||||
t.Run("ClientType constants", func(t *testing.T) {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
googlegenerative "github.com/memohai/twilight-ai/provider/google/generativeai"
|
||||
openaicodex "github.com/memohai/twilight-ai/provider/openai/codex"
|
||||
openaicompletions "github.com/memohai/twilight-ai/provider/openai/completions"
|
||||
openaiimages "github.com/memohai/twilight-ai/provider/openai/images"
|
||||
openairesponses "github.com/memohai/twilight-ai/provider/openai/responses"
|
||||
sdk "github.com/memohai/twilight-ai/sdk"
|
||||
|
||||
@@ -121,6 +122,40 @@ func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model {
|
||||
}
|
||||
}
|
||||
|
||||
func NewSDKImageGenerationModel(cfg SDKModelConfig) *sdk.ImageGenerationModel {
|
||||
opts := imageProviderOptions(cfg)
|
||||
if opts == nil {
|
||||
return nil
|
||||
}
|
||||
return openaiimages.New(opts...).GenerationModel(cfg.ModelID)
|
||||
}
|
||||
|
||||
func NewSDKImageEditModel(cfg SDKModelConfig) *sdk.ImageEditModel {
|
||||
opts := imageProviderOptions(cfg)
|
||||
if opts == nil {
|
||||
return nil
|
||||
}
|
||||
return openaiimages.New(opts...).EditModel(cfg.ModelID)
|
||||
}
|
||||
|
||||
func imageProviderOptions(cfg SDKModelConfig) []openaiimages.Option {
|
||||
switch ClientType(cfg.ClientType) {
|
||||
case ClientTypeOpenAICompletions, ClientTypeOpenAIResponses:
|
||||
opts := []openaiimages.Option{
|
||||
openaiimages.WithAPIKey(cfg.APIKey),
|
||||
}
|
||||
if cfg.HTTPClient != nil {
|
||||
opts = append(opts, openaiimages.WithHTTPClient(cfg.HTTPClient))
|
||||
}
|
||||
if cfg.BaseURL != "" {
|
||||
opts = append(opts, openaiimages.WithBaseURL(cfg.BaseURL))
|
||||
}
|
||||
return opts
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// BuildReasoningOptions returns SDK generation options for reasoning/thinking.
|
||||
func BuildReasoningOptions(cfg SDKModelConfig) []sdk.GenerateOption {
|
||||
if cfg.ReasoningConfig == nil || !cfg.ReasoningConfig.Enabled {
|
||||
|
||||
@@ -11,6 +11,7 @@ type ModelType string
|
||||
const (
|
||||
ModelTypeChat ModelType = "chat"
|
||||
ModelTypeEmbedding ModelType = "embedding"
|
||||
ModelTypeImage ModelType = "image"
|
||||
ModelTypeSpeech ModelType = "speech"
|
||||
)
|
||||
|
||||
@@ -30,6 +31,8 @@ const (
|
||||
CompatVision = "vision"
|
||||
CompatToolCall = "tool-call"
|
||||
CompatImageOutput = "image-output"
|
||||
CompatGenerate = "generate"
|
||||
CompatEdit = "edit"
|
||||
CompatReasoning = "reasoning"
|
||||
)
|
||||
|
||||
@@ -43,7 +46,12 @@ const (
|
||||
|
||||
// validCompatibilities enumerates accepted compatibility tokens.
|
||||
var validCompatibilities = map[string]struct{}{
|
||||
CompatVision: {}, CompatToolCall: {}, CompatImageOutput: {}, CompatReasoning: {},
|
||||
CompatVision: {},
|
||||
CompatToolCall: {},
|
||||
CompatImageOutput: {},
|
||||
CompatGenerate: {},
|
||||
CompatEdit: {},
|
||||
CompatReasoning: {},
|
||||
}
|
||||
|
||||
var validReasoningEfforts = map[string]struct{}{
|
||||
@@ -70,6 +78,15 @@ type Model struct {
|
||||
Config ModelConfig `json:"config"`
|
||||
}
|
||||
|
||||
func IsValidModelType(modelType ModelType) bool {
|
||||
switch modelType {
|
||||
case ModelTypeChat, ModelTypeEmbedding, ModelTypeImage, ModelTypeSpeech:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) Validate() error {
|
||||
if m.ModelID == "" {
|
||||
return errors.New("model ID is required")
|
||||
@@ -80,7 +97,7 @@ func (m *Model) Validate() error {
|
||||
if _, err := uuid.Parse(m.ProviderID); err != nil {
|
||||
return errors.New("provider ID must be a valid UUID")
|
||||
}
|
||||
if m.Type != ModelTypeChat && m.Type != ModelTypeEmbedding && m.Type != ModelTypeSpeech {
|
||||
if !IsValidModelType(m.Type) {
|
||||
return errors.New("invalid model type")
|
||||
}
|
||||
if m.Type == ModelTypeEmbedding {
|
||||
|
||||
@@ -1354,7 +1354,7 @@ export type ModelsModelConfig = {
|
||||
reasoning_efforts?: Array<string>;
|
||||
};
|
||||
|
||||
export type ModelsModelType = 'chat' | 'embedding' | 'speech';
|
||||
export type ModelsModelType = 'chat' | 'embedding' | 'image' | 'speech';
|
||||
|
||||
export type ModelsTestResponse = {
|
||||
latency_ms?: number;
|
||||
@@ -7084,7 +7084,7 @@ export type GetModelsData = {
|
||||
path?: never;
|
||||
query?: {
|
||||
/**
|
||||
* Model type (chat, embedding)
|
||||
* Model type (chat, embedding, image, speech)
|
||||
*/
|
||||
type?: string;
|
||||
/**
|
||||
|
||||
+3
-1
@@ -6572,7 +6572,7 @@ const docTemplate = `{
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Model type (chat, embedding)",
|
||||
"description": "Model type (chat, embedding, image, speech)",
|
||||
"name": "type",
|
||||
"in": "query"
|
||||
},
|
||||
@@ -12188,11 +12188,13 @@ const docTemplate = `{
|
||||
"enum": [
|
||||
"chat",
|
||||
"embedding",
|
||||
"image",
|
||||
"speech"
|
||||
],
|
||||
"x-enum-varnames": [
|
||||
"ModelTypeChat",
|
||||
"ModelTypeEmbedding",
|
||||
"ModelTypeImage",
|
||||
"ModelTypeSpeech"
|
||||
]
|
||||
},
|
||||
|
||||
+3
-1
@@ -6563,7 +6563,7 @@
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Model type (chat, embedding)",
|
||||
"description": "Model type (chat, embedding, image, speech)",
|
||||
"name": "type",
|
||||
"in": "query"
|
||||
},
|
||||
@@ -12179,11 +12179,13 @@
|
||||
"enum": [
|
||||
"chat",
|
||||
"embedding",
|
||||
"image",
|
||||
"speech"
|
||||
],
|
||||
"x-enum-varnames": [
|
||||
"ModelTypeChat",
|
||||
"ModelTypeEmbedding",
|
||||
"ModelTypeImage",
|
||||
"ModelTypeSpeech"
|
||||
]
|
||||
},
|
||||
|
||||
+3
-1
@@ -2261,11 +2261,13 @@ definitions:
|
||||
enum:
|
||||
- chat
|
||||
- embedding
|
||||
- image
|
||||
- speech
|
||||
type: string
|
||||
x-enum-varnames:
|
||||
- ModelTypeChat
|
||||
- ModelTypeEmbedding
|
||||
- ModelTypeImage
|
||||
- ModelTypeSpeech
|
||||
models.TestResponse:
|
||||
properties:
|
||||
@@ -7231,7 +7233,7 @@ paths:
|
||||
description: Get a list of all configured models, optionally filtered by type
|
||||
or provider client type
|
||||
parameters:
|
||||
- description: Model type (chat, embedding)
|
||||
- description: Model type (chat, embedding, image, speech)
|
||||
in: query
|
||||
name: type
|
||||
type: string
|
||||
|
||||
Reference in New Issue
Block a user