diff --git a/apps/web/src/i18n/locales/en.json b/apps/web/src/i18n/locales/en.json index d41b1d34..c5cd4034 100644 --- a/apps/web/src/i18n/locales/en.json +++ b/apps/web/src/i18n/locales/en.json @@ -843,6 +843,9 @@ "memoryHealthUnavailable": "Unavailable", "ttsModel": "TTS Model", "ttsModelPlaceholder": "Select TTS model", + "imageModel": "Image Generation Model", + "imageModelDescription": "Model used for the generate_image tool. Must support image-output compatibility.", + "imageModelPlaceholder": "Select image model (optional)", "language": "Language", "reasoningEnabled": "Enable Reasoning", "reasoningEffort": "Reasoning Effort", diff --git a/apps/web/src/i18n/locales/zh.json b/apps/web/src/i18n/locales/zh.json index bee49d79..1c20aa2b 100644 --- a/apps/web/src/i18n/locales/zh.json +++ b/apps/web/src/i18n/locales/zh.json @@ -839,6 +839,9 @@ "memoryHealthUnavailable": "暂不可用", "ttsModel": "语音合成模型", "ttsModelPlaceholder": "选择语音合成模型", + "imageModel": "图片生成模型", + "imageModelDescription": "用于 generate_image 工具的模型,必须支持 image-output 兼容性。", + "imageModelPlaceholder": "选择图片模型(可选)", "language": "语言", "reasoningEnabled": "启用推理", "reasoningEffort": "推理等级", diff --git a/apps/web/src/pages/bots/components/bot-settings.vue b/apps/web/src/pages/bots/components/bot-settings.vue index b0a2611b..96ab75a1 100644 --- a/apps/web/src/pages/bots/components/bot-settings.vue +++ b/apps/web/src/pages/bots/components/bot-settings.vue @@ -187,6 +187,21 @@ /> + +
+ +

+ {{ $t('bots.settings.imageModelDescription') }} +

+ +
+
@@ -426,6 +441,9 @@ const { mutateAsync: deleteBot, isLoading: deleteLoading } = useMutation({ const models = computed(() => modelData.value ?? []) const providers = computed(() => providerData.value ?? []) +const imageCapableModels = computed(() => + models.value.filter((m) => m.config?.compatibilities?.includes('image-output')), +) const searchProviders = computed(() => (searchProviderData.value ?? []).filter((p) => p.enable !== false)) const memoryProviders = computed(() => memoryProviderData.value ?? []) const ttsProviders = computed(() => (ttsProviderData.value ?? []).filter((p) => p.enable !== false)) @@ -437,6 +455,7 @@ const browserContexts = computed(() => browserContextData.value ?? []) const form = reactive({ chat_model_id: '', title_model_id: '', + image_model_id: '', search_provider_id: '', memory_provider_id: '', tts_model_id: '', @@ -574,6 +593,7 @@ watch(settings, (val) => { if (val) { form.chat_model_id = val.chat_model_id ?? '' form.title_model_id = val.title_model_id ?? '' + form.image_model_id = val.image_model_id ?? '' form.search_provider_id = val.search_provider_id ?? '' form.memory_provider_id = val.memory_provider_id ?? '' form.tts_model_id = val.tts_model_id ?? '' @@ -590,6 +610,7 @@ const hasChanges = computed(() => { let changed = form.chat_model_id !== (s.chat_model_id ?? '') || form.title_model_id !== (s.title_model_id ?? '') + || form.image_model_id !== (s.image_model_id ?? '') || form.search_provider_id !== (s.search_provider_id ?? '') || form.memory_provider_id !== (s.memory_provider_id ?? '') || form.tts_model_id !== (s.tts_model_id ?? 
'') diff --git a/cmd/agent/main.go b/cmd/agent/main.go index f80bda6d..d60f16be 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -647,6 +647,7 @@ func provideToolProviders(log *slog.Logger, cfg config.Config, channelManager *c agenttools.NewSkillProvider(log), agenttools.NewBrowserProvider(log, settingsService, browserContextService, manager, cfg.BrowserGateway), agenttools.NewTTSProvider(log, settingsService, ttsService, channelManager, registry), + agenttools.NewImageGenProvider(log, settingsService, modelsService, queries, manager, config.DefaultDataMount), agenttools.NewFederationProvider(log, fedSource), agenttools.NewHistoryProvider(log, sessionService, queries), } diff --git a/cmd/memoh/serve.go b/cmd/memoh/serve.go index ce45f56c..ed20ef40 100644 --- a/cmd/memoh/serve.go +++ b/cmd/memoh/serve.go @@ -547,6 +547,7 @@ func provideToolProviders(log *slog.Logger, cfg config.Config, channelManager *c agenttools.NewSkillProvider(log), agenttools.NewBrowserProvider(log, settingsService, browserContextService, manager, cfg.BrowserGateway), agenttools.NewTTSProvider(log, settingsService, ttsService, channelManager, registry), + agenttools.NewImageGenProvider(log, settingsService, modelsService, queries, manager, config.DefaultDataMount), agenttools.NewFederationProvider(log, fedSource), agenttools.NewHistoryProvider(log, sessionService, queries), } diff --git a/db/migrations/0001_init.up.sql b/db/migrations/0001_init.up.sql index 260f0d39..8f69b063 100644 --- a/db/migrations/0001_init.up.sql +++ b/db/migrations/0001_init.up.sql @@ -177,6 +177,7 @@ CREATE TABLE IF NOT EXISTS bots ( compaction_ratio INTEGER NOT NULL DEFAULT 80, compaction_model_id UUID REFERENCES models(id) ON DELETE SET NULL, title_model_id UUID REFERENCES models(id) ON DELETE SET NULL, + image_model_id UUID REFERENCES models(id) ON DELETE SET NULL, tts_model_id UUID REFERENCES tts_models(id) ON DELETE SET NULL, browser_context_id UUID REFERENCES browser_contexts(id) ON DELETE SET NULL, 
metadata JSONB NOT NULL DEFAULT '{}'::jsonb, diff --git a/db/migrations/0053_add_image_model.down.sql b/db/migrations/0053_add_image_model.down.sql new file mode 100644 index 00000000..12caa3d2 --- /dev/null +++ b/db/migrations/0053_add_image_model.down.sql @@ -0,0 +1,3 @@ +-- 0053_add_image_model (rollback) +-- Remove image_model_id column from bots table +ALTER TABLE bots DROP COLUMN IF EXISTS image_model_id; diff --git a/db/migrations/0053_add_image_model.up.sql b/db/migrations/0053_add_image_model.up.sql new file mode 100644 index 00000000..96dd6817 --- /dev/null +++ b/db/migrations/0053_add_image_model.up.sql @@ -0,0 +1,3 @@ +-- 0053_add_image_model +-- Add image_model_id column to bots table for image generation model configuration +ALTER TABLE bots ADD COLUMN IF NOT EXISTS image_model_id UUID REFERENCES models(id) ON DELETE SET NULL; diff --git a/db/queries/settings.sql b/db/queries/settings.sql index 4a67c2c7..75208f1d 100644 --- a/db/queries/settings.sql +++ b/db/queries/settings.sql @@ -16,6 +16,7 @@ SELECT title_models.id AS title_model_id, search_providers.id AS search_provider_id, memory_providers.id AS memory_provider_id, + image_models.id AS image_model_id, tts_models.id AS tts_model_id, browser_contexts.id AS browser_context_id FROM bots @@ -23,6 +24,7 @@ LEFT JOIN models AS chat_models ON chat_models.id = bots.chat_model_id LEFT JOIN models AS heartbeat_models ON heartbeat_models.id = bots.heartbeat_model_id LEFT JOIN models AS compaction_models ON compaction_models.id = bots.compaction_model_id LEFT JOIN models AS title_models ON title_models.id = bots.title_model_id +LEFT JOIN models AS image_models ON image_models.id = bots.image_model_id LEFT JOIN search_providers ON search_providers.id = bots.search_provider_id LEFT JOIN memory_providers ON memory_providers.id = bots.memory_provider_id LEFT JOIN tts_models ON tts_models.id = bots.tts_model_id @@ -47,11 +49,12 @@ WITH updated AS ( title_model_id = COALESCE(sqlc.narg(title_model_id)::uuid, 
bots.title_model_id), search_provider_id = COALESCE(sqlc.narg(search_provider_id)::uuid, bots.search_provider_id), memory_provider_id = COALESCE(sqlc.narg(memory_provider_id)::uuid, bots.memory_provider_id), + image_model_id = COALESCE(sqlc.narg(image_model_id)::uuid, bots.image_model_id), tts_model_id = COALESCE(sqlc.narg(tts_model_id)::uuid, bots.tts_model_id), browser_context_id = COALESCE(sqlc.narg(browser_context_id)::uuid, bots.browser_context_id), updated_at = now() WHERE bots.id = sqlc.arg(id) - RETURNING bots.id, bots.language, bots.reasoning_enabled, bots.reasoning_effort, bots.heartbeat_enabled, bots.heartbeat_interval, bots.heartbeat_prompt, bots.compaction_enabled, bots.compaction_threshold, bots.compaction_ratio, bots.chat_model_id, bots.heartbeat_model_id, bots.compaction_model_id, bots.title_model_id, bots.search_provider_id, bots.memory_provider_id, bots.tts_model_id, bots.browser_context_id + RETURNING bots.id, bots.language, bots.reasoning_enabled, bots.reasoning_effort, bots.heartbeat_enabled, bots.heartbeat_interval, bots.heartbeat_prompt, bots.compaction_enabled, bots.compaction_threshold, bots.compaction_ratio, bots.chat_model_id, bots.heartbeat_model_id, bots.compaction_model_id, bots.title_model_id, bots.image_model_id, bots.search_provider_id, bots.memory_provider_id, bots.tts_model_id, bots.browser_context_id ) SELECT updated.id AS bot_id, @@ -70,6 +73,7 @@ SELECT title_models.id AS title_model_id, search_providers.id AS search_provider_id, memory_providers.id AS memory_provider_id, + image_models.id AS image_model_id, tts_models.id AS tts_model_id, browser_contexts.id AS browser_context_id FROM updated @@ -77,6 +81,7 @@ LEFT JOIN models AS chat_models ON chat_models.id = updated.chat_model_id LEFT JOIN models AS heartbeat_models ON heartbeat_models.id = updated.heartbeat_model_id LEFT JOIN models AS compaction_models ON compaction_models.id = updated.compaction_model_id LEFT JOIN models AS title_models ON title_models.id = 
updated.title_model_id +LEFT JOIN models AS image_models ON image_models.id = updated.image_model_id LEFT JOIN search_providers ON search_providers.id = updated.search_provider_id LEFT JOIN memory_providers ON memory_providers.id = updated.memory_provider_id LEFT JOIN tts_models ON tts_models.id = updated.tts_model_id @@ -97,6 +102,7 @@ SET language = 'auto', heartbeat_model_id = NULL, compaction_model_id = NULL, title_model_id = NULL, + image_model_id = NULL, search_provider_id = NULL, memory_provider_id = NULL, tts_model_id = NULL,
diff --git a/internal/agent/tools/image_gen.go b/internal/agent/tools/image_gen.go
new file mode 100644
index 00000000..c0263578
--- /dev/null
+++ b/internal/agent/tools/image_gen.go
@@ -0,0 +1,250 @@
+package tools
+
+import (
+	"context"
+	"encoding/base64"
+	"errors"
+	"fmt"
+	"log/slog"
+	"strings"
+	"time"
+
+	sdk "github.com/memohai/twilight-ai/sdk"
+
+	"github.com/memohai/memoh/internal/db/sqlc"
+	"github.com/memohai/memoh/internal/models"
+	"github.com/memohai/memoh/internal/providers"
+	"github.com/memohai/memoh/internal/settings"
+	"github.com/memohai/memoh/internal/workspace/bridge"
+)
+
+const (
+	// imageGenDir is the directory, inside the bot's container, that generated
+	// images are written to.
+	imageGenDir = "/data/generated-images"
+
+	// imageGenDefaultMediaType is reported when the model returns a file without
+	// a media type; it matches the "png" extension used for unknown types.
+	imageGenDefaultMediaType = "image/png"
+)
+
+// ImageGenProvider exposes the generate_image tool to bots whose settings name
+// an image generation model.
+//
+// NOTE(review): dataMount is stored but never read in this file; output always
+// goes to imageGenDir. Confirm whether the directory should derive from it.
+type ImageGenProvider struct {
+	logger     *slog.Logger
+	settings   *settings.Service
+	models     *models.Service
+	queries    *sqlc.Queries
+	containers bridge.Provider
+	dataMount  string
+}
+
+// NewImageGenProvider wires the provider's dependencies. A nil log falls back
+// to slog.Default(); the logger is tagged with tool=image_gen.
+func NewImageGenProvider(
+	log *slog.Logger,
+	settingsSvc *settings.Service,
+	modelsSvc *models.Service,
+	queries *sqlc.Queries,
+	containers bridge.Provider,
+	dataMount string,
+) *ImageGenProvider {
+	if log == nil {
+		log = slog.Default()
+	}
+	return &ImageGenProvider{
+		logger:     log.With(slog.String("tool", "image_gen")),
+		settings:   settingsSvc,
+		models:     modelsSvc,
+		queries:    queries,
+		containers: containers,
+		dataMount:  dataMount,
+	}
+}
+
+// Tools returns the generate_image tool for top-level sessions whose bot has
+// an image model configured, and no tools otherwise. Failing to load the bot
+// settings is logged and yields no tools instead of an error.
+//
+// NOTE(review): unlike settings/models/queries, containers is not nil-checked
+// here, yet execGenerateImage calls p.containers.MCPClient. Confirm every
+// caller passes a usable bridge.Provider.
+func (p *ImageGenProvider) Tools(ctx context.Context, session SessionContext) ([]sdk.Tool, error) {
+	if session.IsSubagent || p.settings == nil || p.models == nil || p.queries == nil {
+		return nil, nil
+	}
+	botID := strings.TrimSpace(session.BotID)
+	if botID == "" {
+		return nil, nil
+	}
+	botSettings, err := p.settings.GetBot(ctx, botID)
+	if err != nil {
+		p.logger.Warn("load bot settings failed; generate_image not offered",
+			slog.String("bot_id", botID), slog.Any("error", err))
+		return nil, nil
+	}
+	if strings.TrimSpace(botSettings.ImageModelID) == "" {
+		return nil, nil
+	}
+	sess := session
+	return []sdk.Tool{
+		{
+			Name:        "generate_image",
+			Description: "Generate an image from a text description using the configured image generation model. Returns the file path of the generated image in the workspace.",
+			Parameters: map[string]any{
+				"type": "object",
+				"properties": map[string]any{
+					"prompt": map[string]any{"type": "string", "description": "Detailed description of the image to generate"},
+					"size":   map[string]any{"type": "string", "description": "Image size, e.g. 1024x1024, 1792x1024, 1024x1792. Defaults to 1024x1024."},
+				},
+				"required": []string{"prompt"},
+			},
+			Execute: func(execCtx *sdk.ToolExecContext, input any) (any, error) {
+				return p.execGenerateImage(execCtx.Context, sess, inputAsMap(input))
+			},
+		},
+	}, nil
+}
+
+// execGenerateImage handles one generate_image call. It re-reads the bot's
+// settings, checks that the configured model declares image-output
+// compatibility, asks the model for an image and writes the first returned
+// file under imageGenDir in the bot's container.
+//
+// On success the result holds "path", "media_type" and "size_bytes". If the
+// container cannot be reached or the write fails, the image is returned inline
+// (base64) instead of failing the call.
+func (p *ImageGenProvider) execGenerateImage(ctx context.Context, session SessionContext, args map[string]any) (any, error) {
+	botID := strings.TrimSpace(session.BotID)
+	if botID == "" {
+		return nil, errors.New("bot_id is required")
+	}
+	prompt := strings.TrimSpace(StringArg(args, "prompt"))
+	if prompt == "" {
+		return nil, errors.New("prompt is required")
+	}
+	size := strings.TrimSpace(StringArg(args, "size"))
+	if size == "" {
+		size = "1024x1024"
+	}
+
+	botSettings, err := p.settings.GetBot(ctx, botID)
+	if err != nil {
+		// Wrap the cause like every other failure path in this function.
+		return nil, fmt.Errorf("failed to load bot settings: %w", err)
+	}
+	imageModelID := strings.TrimSpace(botSettings.ImageModelID)
+	if imageModelID == "" {
+		return nil, errors.New("no image generation model configured")
+	}
+
+	modelResp, err := p.models.GetByID(ctx, imageModelID)
+	if err != nil {
+		return nil, fmt.Errorf("failed to load image model: %w", err)
+	}
+	if !modelResp.HasCompatibility(models.CompatImageOutput) {
+		return nil, errors.New("configured model does not support image generation")
+	}
+
+	provider, err := models.FetchProviderByID(ctx, p.queries, modelResp.LlmProviderID)
+	if err != nil {
+		return nil, fmt.Errorf("failed to load model provider: %w", err)
+	}
+
+	authResolver := providers.NewService(nil, p.queries, "")
+	creds, err := authResolver.ResolveModelCredentials(ctx, provider)
+	if err != nil {
+		return nil, fmt.Errorf("failed to resolve provider credentials: %w", err)
+	}
+
+	sdkModel := models.NewSDKChatModel(models.SDKModelConfig{
+		ModelID:    modelResp.ModelID,
+		ClientType: provider.ClientType,
+		APIKey:     creds.APIKey,
+		BaseURL:    provider.BaseUrl,
+	})
+
+	// The requested size is conveyed to the model through the prompt text only.
+	userMsg := fmt.Sprintf("Generate an image with the following description. Size: %s\n\n%s", size, prompt)
+	result, err := sdk.GenerateTextResult(ctx,
+		sdk.WithModel(sdkModel),
+		sdk.WithMessages([]sdk.Message{
+			{Role: sdk.MessageRoleUser, Content: []sdk.MessagePart{sdk.TextPart{Text: userMsg}}},
+		}),
+	)
+	if err != nil {
+		return nil, fmt.Errorf("image generation failed: %w", err)
+	}
+
+	if len(result.Files) == 0 {
+		if result.Text != "" {
+			return map[string]any{"error": "no image generated", "model_response": result.Text}, nil
+		}
+		return nil, errors.New("no image was generated by the model")
+	}
+
+	// Only the first returned file is used.
+	file := result.Files[0]
+	imgBytes, err := base64.StdEncoding.DecodeString(file.Data)
+	if err != nil {
+		return nil, fmt.Errorf("failed to decode generated image: %w", err)
+	}
+
+	mediaType := strings.TrimSpace(file.MediaType)
+	if mediaType == "" {
+		mediaType = imageGenDefaultMediaType
+	}
+	containerPath := fmt.Sprintf("%s/%d.%s", imageGenDir, time.Now().UnixMilli(), imageGenFileExt(mediaType))
+
+	client, clientErr := p.containers.MCPClient(ctx, botID)
+	if clientErr != nil {
+		p.logger.Warn("container not reachable; returning generated image inline",
+			slog.String("bot_id", botID), slog.Any("error", clientErr))
+		return imageGenInlineResult("Image generated (container not reachable, not saved to disk)", file.Data, mediaType), nil
+	}
+
+	// The mkdir result is deliberately ignored; a failure to write the file
+	// is detected and reported just below.
+	mkdirCmd := fmt.Sprintf("mkdir -p %s", imageGenDir)
+	_, _ = client.Exec(ctx, mkdirCmd, "/", 5)
+
+	if writeErr := client.WriteFile(ctx, containerPath, imgBytes); writeErr != nil {
+		p.logger.Warn("saving generated image failed; returning generated image inline",
+			slog.String("bot_id", botID), slog.String("path", containerPath), slog.Any("error", writeErr))
+		return imageGenInlineResult(fmt.Sprintf("Image generated (failed to save: %s)", writeErr.Error()), file.Data, mediaType), nil
+	}
+
+	return map[string]any{
+		"path":       containerPath,
+		"media_type": mediaType,
+		"size_bytes": len(imgBytes),
+	}, nil
+}
+
+// imageGenFileExt maps a media type to a file extension: "jpg" for JPEG, "webp"
+// for WebP and "png" for anything else. Matching is case-insensitive.
+func imageGenFileExt(mediaType string) string {
+	mt := strings.ToLower(mediaType)
+	switch {
+	case strings.Contains(mt, "jpeg"), strings.Contains(mt, "jpg"):
+		return "jpg"
+	case strings.Contains(mt, "webp"):
+		return "webp"
+	default:
+		return "png"
+	}
+}
+
+// imageGenInlineResult builds the fallback tool result that carries the image
+// as base64 content alongside a short explanatory note.
+func imageGenInlineResult(note, data, mediaType string) map[string]any {
+	return map[string]any{
+		"content": []map[string]any{
+			{"type": "text", "text": note},
+			{"type": "image", "data": data, "mimeType": mediaType},
+		},
+	}
+}
diff --git a/internal/db/sqlc/conversations.sql.go b/internal/db/sqlc/conversations.sql.go index 6b7f50bb..4df79315 100644 --- a/internal/db/sqlc/conversations.sql.go +++ b/internal/db/sqlc/conversations.sql.go @@ -511,7 +511,7 @@ WITH updated AS ( SET display_name = $1, updated_at = now() WHERE bots.id = $2 - RETURNING id, owner_user_id, display_name, avatar_url, timezone, is_active, status, language, reasoning_enabled, reasoning_effort, chat_model_id, search_provider_id, memory_provider_id, heartbeat_enabled, heartbeat_interval, heartbeat_prompt, heartbeat_model_id, compaction_enabled, compaction_threshold, compaction_ratio, compaction_model_id, title_model_id, tts_model_id, browser_context_id, metadata, created_at, updated_at, acl_default_effect + RETURNING id, owner_user_id, display_name, avatar_url, timezone, is_active, status, language, reasoning_enabled, reasoning_effort, chat_model_id, search_provider_id, memory_provider_id, heartbeat_enabled, heartbeat_interval, heartbeat_prompt, heartbeat_model_id, compaction_enabled, compaction_threshold, compaction_ratio, compaction_model_id, title_model_id, image_model_id, tts_model_id, browser_context_id, metadata, created_at, updated_at, acl_default_effect ) SELECT updated.id AS id, diff --git a/internal/db/sqlc/models.go 
b/internal/db/sqlc/models.go index 53204cfc..539a9c41 100644 --- a/internal/db/sqlc/models.go +++ b/internal/db/sqlc/models.go @@ -31,6 +31,7 @@ type Bot struct { CompactionRatio int32 `json:"compaction_ratio"` CompactionModelID pgtype.UUID `json:"compaction_model_id"` TitleModelID pgtype.UUID `json:"title_model_id"` + ImageModelID pgtype.UUID `json:"image_model_id"` TtsModelID pgtype.UUID `json:"tts_model_id"` BrowserContextID pgtype.UUID `json:"browser_context_id"` Metadata []byte `json:"metadata"` diff --git a/internal/db/sqlc/settings.sql.go b/internal/db/sqlc/settings.sql.go index 11b2ead8..28873e3a 100644 --- a/internal/db/sqlc/settings.sql.go +++ b/internal/db/sqlc/settings.sql.go @@ -26,6 +26,7 @@ SET language = 'auto', heartbeat_model_id = NULL, compaction_model_id = NULL, title_model_id = NULL, + image_model_id = NULL, search_provider_id = NULL, memory_provider_id = NULL, tts_model_id = NULL, @@ -57,6 +58,7 @@ SELECT title_models.id AS title_model_id, search_providers.id AS search_provider_id, memory_providers.id AS memory_provider_id, + image_models.id AS image_model_id, tts_models.id AS tts_model_id, browser_contexts.id AS browser_context_id FROM bots @@ -64,6 +66,7 @@ LEFT JOIN models AS chat_models ON chat_models.id = bots.chat_model_id LEFT JOIN models AS heartbeat_models ON heartbeat_models.id = bots.heartbeat_model_id LEFT JOIN models AS compaction_models ON compaction_models.id = bots.compaction_model_id LEFT JOIN models AS title_models ON title_models.id = bots.title_model_id +LEFT JOIN models AS image_models ON image_models.id = bots.image_model_id LEFT JOIN search_providers ON search_providers.id = bots.search_provider_id LEFT JOIN memory_providers ON memory_providers.id = bots.memory_provider_id LEFT JOIN tts_models ON tts_models.id = bots.tts_model_id @@ -88,6 +91,7 @@ type GetSettingsByBotIDRow struct { TitleModelID pgtype.UUID `json:"title_model_id"` SearchProviderID pgtype.UUID `json:"search_provider_id"` MemoryProviderID pgtype.UUID 
`json:"memory_provider_id"` + ImageModelID pgtype.UUID `json:"image_model_id"` TtsModelID pgtype.UUID `json:"tts_model_id"` BrowserContextID pgtype.UUID `json:"browser_context_id"` } @@ -112,6 +116,7 @@ func (q *Queries) GetSettingsByBotID(ctx context.Context, id pgtype.UUID) (GetSe &i.TitleModelID, &i.SearchProviderID, &i.MemoryProviderID, + &i.ImageModelID, &i.TtsModelID, &i.BrowserContextID, ) @@ -136,11 +141,12 @@ WITH updated AS ( title_model_id = COALESCE($13::uuid, bots.title_model_id), search_provider_id = COALESCE($14::uuid, bots.search_provider_id), memory_provider_id = COALESCE($15::uuid, bots.memory_provider_id), - tts_model_id = COALESCE($16::uuid, bots.tts_model_id), - browser_context_id = COALESCE($17::uuid, bots.browser_context_id), + image_model_id = COALESCE($16::uuid, bots.image_model_id), + tts_model_id = COALESCE($17::uuid, bots.tts_model_id), + browser_context_id = COALESCE($18::uuid, bots.browser_context_id), updated_at = now() - WHERE bots.id = $18 - RETURNING bots.id, bots.language, bots.reasoning_enabled, bots.reasoning_effort, bots.heartbeat_enabled, bots.heartbeat_interval, bots.heartbeat_prompt, bots.compaction_enabled, bots.compaction_threshold, bots.compaction_ratio, bots.chat_model_id, bots.heartbeat_model_id, bots.compaction_model_id, bots.title_model_id, bots.search_provider_id, bots.memory_provider_id, bots.tts_model_id, bots.browser_context_id + WHERE bots.id = $19 + RETURNING bots.id, bots.language, bots.reasoning_enabled, bots.reasoning_effort, bots.heartbeat_enabled, bots.heartbeat_interval, bots.heartbeat_prompt, bots.compaction_enabled, bots.compaction_threshold, bots.compaction_ratio, bots.chat_model_id, bots.heartbeat_model_id, bots.compaction_model_id, bots.title_model_id, bots.image_model_id, bots.search_provider_id, bots.memory_provider_id, bots.tts_model_id, bots.browser_context_id ) SELECT updated.id AS bot_id, @@ -159,6 +165,7 @@ SELECT title_models.id AS title_model_id, search_providers.id AS search_provider_id, 
memory_providers.id AS memory_provider_id, + image_models.id AS image_model_id, tts_models.id AS tts_model_id, browser_contexts.id AS browser_context_id FROM updated @@ -166,6 +173,7 @@ LEFT JOIN models AS chat_models ON chat_models.id = updated.chat_model_id LEFT JOIN models AS heartbeat_models ON heartbeat_models.id = updated.heartbeat_model_id LEFT JOIN models AS compaction_models ON compaction_models.id = updated.compaction_model_id LEFT JOIN models AS title_models ON title_models.id = updated.title_model_id +LEFT JOIN models AS image_models ON image_models.id = updated.image_model_id LEFT JOIN search_providers ON search_providers.id = updated.search_provider_id LEFT JOIN memory_providers ON memory_providers.id = updated.memory_provider_id LEFT JOIN tts_models ON tts_models.id = updated.tts_model_id @@ -188,6 +196,7 @@ type UpsertBotSettingsParams struct { TitleModelID pgtype.UUID `json:"title_model_id"` SearchProviderID pgtype.UUID `json:"search_provider_id"` MemoryProviderID pgtype.UUID `json:"memory_provider_id"` + ImageModelID pgtype.UUID `json:"image_model_id"` TtsModelID pgtype.UUID `json:"tts_model_id"` BrowserContextID pgtype.UUID `json:"browser_context_id"` ID pgtype.UUID `json:"id"` @@ -210,6 +219,7 @@ type UpsertBotSettingsRow struct { TitleModelID pgtype.UUID `json:"title_model_id"` SearchProviderID pgtype.UUID `json:"search_provider_id"` MemoryProviderID pgtype.UUID `json:"memory_provider_id"` + ImageModelID pgtype.UUID `json:"image_model_id"` TtsModelID pgtype.UUID `json:"tts_model_id"` BrowserContextID pgtype.UUID `json:"browser_context_id"` } @@ -231,6 +241,7 @@ func (q *Queries) UpsertBotSettings(ctx context.Context, arg UpsertBotSettingsPa arg.TitleModelID, arg.SearchProviderID, arg.MemoryProviderID, + arg.ImageModelID, arg.TtsModelID, arg.BrowserContextID, arg.ID, @@ -253,6 +264,7 @@ func (q *Queries) UpsertBotSettings(ctx context.Context, arg UpsertBotSettingsPa &i.TitleModelID, &i.SearchProviderID, &i.MemoryProviderID, + &i.ImageModelID, 
&i.TtsModelID, &i.BrowserContextID, ) diff --git a/internal/settings/service.go b/internal/settings/service.go index 8adf86df..df2f9635 100644 --- a/internal/settings/service.go +++ b/internal/settings/service.go @@ -131,6 +131,14 @@ func (s *Service) UpsertBot(ctx context.Context, botID string, req UpsertRequest } titleModelUUID = modelID } + imageModelUUID := pgtype.UUID{} + if value := strings.TrimSpace(req.ImageModelID); value != "" { + modelID, err := s.resolveModelUUID(ctx, value) + if err != nil { + return Settings{}, err + } + imageModelUUID = modelID + } searchProviderUUID := pgtype.UUID{} if value := strings.TrimSpace(req.SearchProviderID); value != "" { providerID, err := db.ParseUUID(value) @@ -178,6 +186,7 @@ func (s *Service) UpsertBot(ctx context.Context, botID string, req UpsertRequest HeartbeatModelID: heartbeatModelUUID, CompactionModelID: compactionModelUUID, TitleModelID: titleModelUUID, + ImageModelID: imageModelUUID, SearchProviderID: searchProviderUUID, MemoryProviderID: memoryProviderUUID, TtsModelID: ttsModelUUID, @@ -269,6 +278,7 @@ func normalizeBotSettingsReadRow(row sqlc.GetSettingsByBotIDRow) Settings { row.HeartbeatModelID, row.CompactionModelID, row.TitleModelID, + row.ImageModelID, row.SearchProviderID, row.MemoryProviderID, row.TtsModelID, @@ -290,6 +300,7 @@ func normalizeBotSettingsWriteRow(row sqlc.UpsertBotSettingsRow) Settings { row.HeartbeatModelID, row.CompactionModelID, row.TitleModelID, + row.ImageModelID, row.SearchProviderID, row.MemoryProviderID, row.TtsModelID, @@ -310,6 +321,7 @@ func normalizeBotSettingsFields( heartbeatModelID pgtype.UUID, compactionModelID pgtype.UUID, titleModelID pgtype.UUID, + imageModelID pgtype.UUID, searchProviderID pgtype.UUID, memoryProviderID pgtype.UUID, ttsModelID pgtype.UUID, @@ -328,6 +340,9 @@ func normalizeBotSettingsFields( if titleModelID.Valid { settings.TitleModelID = uuid.UUID(titleModelID.Bytes).String() } + if imageModelID.Valid { + settings.ImageModelID = 
uuid.UUID(imageModelID.Bytes).String() + } if searchProviderID.Valid { settings.SearchProviderID = uuid.UUID(searchProviderID.Bytes).String() } diff --git a/internal/settings/types.go b/internal/settings/types.go index 9cc9835b..fcabc4a5 100644 --- a/internal/settings/types.go +++ b/internal/settings/types.go @@ -8,6 +8,7 @@ const ( type Settings struct { ChatModelID string `json:"chat_model_id"` + ImageModelID string `json:"image_model_id"` SearchProviderID string `json:"search_provider_id"` MemoryProviderID string `json:"memory_provider_id"` TtsModelID string `json:"tts_model_id"` @@ -28,6 +29,7 @@ type Settings struct { type UpsertRequest struct { ChatModelID string `json:"chat_model_id,omitempty"` + ImageModelID string `json:"image_model_id,omitempty"` SearchProviderID string `json:"search_provider_id,omitempty"` MemoryProviderID string `json:"memory_provider_id,omitempty"` TtsModelID string `json:"tts_model_id,omitempty"` diff --git a/packages/sdk/src/types.gen.ts b/packages/sdk/src/types.gen.ts index 05c7873d..7d1736ae 100644 --- a/packages/sdk/src/types.gen.ts +++ b/packages/sdk/src/types.gen.ts @@ -1544,6 +1544,7 @@ export type SettingsSettings = { heartbeat_enabled?: boolean; heartbeat_interval?: number; heartbeat_model_id?: string; + image_model_id?: string; language?: string; memory_provider_id?: string; reasoning_effort?: string; @@ -1564,6 +1565,7 @@ export type SettingsUpsertRequest = { heartbeat_enabled?: boolean; heartbeat_interval?: number; heartbeat_model_id?: string; + image_model_id?: string; language?: string; memory_provider_id?: string; reasoning_effort?: string; diff --git a/spec/docs.go b/spec/docs.go index b69e1478..7f91e0dd 100644 --- a/spec/docs.go +++ b/spec/docs.go @@ -13029,6 +13029,9 @@ const docTemplate = `{ "heartbeat_model_id": { "type": "string" }, + "image_model_id": { + "type": "string" + }, "language": { "type": "string" }, @@ -13085,6 +13088,9 @@ const docTemplate = `{ "heartbeat_model_id": { "type": "string" }, + 
"image_model_id": { + "type": "string" + }, "language": { "type": "string" }, diff --git a/spec/swagger.json b/spec/swagger.json index 0892fc3c..1cb0ea99 100644 --- a/spec/swagger.json +++ b/spec/swagger.json @@ -13020,6 +13020,9 @@ "heartbeat_model_id": { "type": "string" }, + "image_model_id": { + "type": "string" + }, "language": { "type": "string" }, @@ -13076,6 +13079,9 @@ "heartbeat_model_id": { "type": "string" }, + "image_model_id": { + "type": "string" + }, "language": { "type": "string" }, diff --git a/spec/swagger.yaml b/spec/swagger.yaml index e23340a8..2b145b37 100644 --- a/spec/swagger.yaml +++ b/spec/swagger.yaml @@ -2571,6 +2571,8 @@ definitions: type: integer heartbeat_model_id: type: string + image_model_id: + type: string language: type: string memory_provider_id: @@ -2608,6 +2610,8 @@ definitions: type: integer heartbeat_model_id: type: string + image_model_id: + type: string language: type: string memory_provider_id: