From c731e0ca1d32e3729d4395643fbf922f8c88d493 Mon Sep 17 00:00:00 2001 From: Acbox Date: Mon, 2 Feb 2026 01:39:21 +0800 Subject: [PATCH] feat: move default model into user settings --- cmd/agent/main.go | 6 +- db/migrations/0001_init.up.sql | 10 +- db/queries/models.sql | 14 +- db/queries/settings.sql | 11 +- docs/docs.go | 321 +++++++++++++++---------------- docs/swagger.json | 321 +++++++++++++++---------------- docs/swagger.yaml | 237 ++++++++++++----------- internal/chat/resolver.go | 97 ++++++---- internal/db/sqlc/models.go | 4 +- internal/db/sqlc/models.sql.go | 74 ++----- internal/db/sqlc/settings.sql.go | 41 +++- internal/embeddings/resolver.go | 67 +++++-- internal/handlers/embeddings.go | 8 + internal/handlers/models.go | 101 +++++++--- internal/models/bootstrap.go | 57 ------ internal/models/models.go | 85 +++----- internal/models/types.go | 25 +-- internal/settings/service.go | 21 ++ internal/settings/types.go | 7 +- packages/cli/src/cli/index.ts | 97 +++++----- 20 files changed, 793 insertions(+), 811 deletions(-) delete mode 100644 internal/models/bootstrap.go diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 72d89aaa..5ca1777c 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -11,20 +11,20 @@ import ( "github.com/memohai/memoh/internal/chat" "github.com/memohai/memoh/internal/channel" "github.com/memohai/memoh/internal/config" - "github.com/memohai/memoh/internal/logger" ctr "github.com/memohai/memoh/internal/containerd" "github.com/memohai/memoh/internal/db" dbsqlc "github.com/memohai/memoh/internal/db/sqlc" "github.com/memohai/memoh/internal/embeddings" "github.com/memohai/memoh/internal/handlers" "github.com/memohai/memoh/internal/history" + "github.com/memohai/memoh/internal/logger" "github.com/memohai/memoh/internal/mcp" "github.com/memohai/memoh/internal/memory" "github.com/memohai/memoh/internal/models" "github.com/memohai/memoh/internal/providers" "github.com/memohai/memoh/internal/schedule" - "github.com/memohai/memoh/internal/settings" "github.com/memohai/memoh/internal/server" + "github.com/memohai/memoh/internal/settings" "github.com/memohai/memoh/internal/subagent" "github.com/memohai/memoh/internal/version" @@ -177,9 +177,9 @@ func main() { // Initialize providers and models handlers providersService := providers.NewService(logger.L, queries) providersHandler := handlers.NewProvidersHandler(logger.L, providersService) - modelsHandler := handlers.NewModelsHandler(logger.L, modelsService) settingsService := settings.NewService(logger.L, queries) settingsHandler := handlers.NewSettingsHandler(logger.L, settingsService) + modelsHandler := handlers.NewModelsHandler(logger.L, modelsService, settingsService) historyService := history.NewService(logger.L, queries) historyHandler := handlers.NewHistoryHandler(logger.L, historyService) channelService := channel.NewService(queries) diff --git a/db/migrations/0001_init.up.sql b/db/migrations/0001_init.up.sql index 0ca9b491..bf6b4c30 100644 --- a/db/migrations/0001_init.up.sql +++ b/db/migrations/0001_init.up.sql @@ -79,14 +79,9 @@ CREATE TABLE IF NOT EXISTS models ( dimensions INTEGER, is_multimodal BOOLEAN NOT NULL DEFAULT false, type TEXT NOT NULL DEFAULT 'chat', - enable_as TEXT, created_at TIMESTAMPTZ NOT NULL DEFAULT now(), updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), CONSTRAINT models_model_id_unique UNIQUE (model_id), - CONSTRAINT models_enable_as_check CHECK ( - (type = 'embedding' AND (enable_as = 'embedding' OR enable_as IS NULL)) OR - (type = 'chat' AND (enable_as IN ('chat', 'memory') OR enable_as IS NULL)) - ), CONSTRAINT models_type_check CHECK (type IN ('chat', 'embedding')), CONSTRAINT models_dimensions_check CHECK (type != 'embedding' OR dimensions IS NOT NULL) ); @@ -104,8 +99,6 @@ CREATE TABLE IF NOT EXISTS model_variants ( CREATE INDEX IF NOT EXISTS idx_model_variants_model_uuid ON model_variants(model_uuid); CREATE INDEX IF NOT EXISTS idx_model_variants_variant_id ON model_variants(variant_id); -CREATE UNIQUE INDEX IF NOT EXISTS idx_models_enable_as_unique ON models(enable_as) WHERE enable_as IS NOT NULL; - CREATE INDEX IF NOT EXISTS idx_snapshots_container_id ON snapshots(container_id); CREATE INDEX IF NOT EXISTS idx_snapshots_parent_id ON snapshots(parent_snapshot_id); @@ -144,6 +137,9 @@ CREATE INDEX IF NOT EXISTS idx_history_timestamp ON history(timestamp); CREATE TABLE IF NOT EXISTS user_settings ( user_id UUID PRIMARY KEY REFERENCES users(id) ON DELETE CASCADE, + chat_model_id TEXT, + memory_model_id TEXT, + embedding_model_id TEXT, max_context_load_time INTEGER NOT NULL DEFAULT 1440, language TEXT NOT NULL DEFAULT 'Same as user input' ); diff --git a/db/queries/models.sql b/db/queries/models.sql index bc829107..1b7fefd2 100644 --- a/db/queries/models.sql +++ b/db/queries/models.sql @@ -46,15 +46,14 @@ SELECT COUNT(*) FROM llm_providers; SELECT COUNT(*) FROM llm_providers WHERE client_type = sqlc.arg(client_type); -- name: CreateModel :one -INSERT INTO models (model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as) +INSERT INTO models (model_id, name, llm_provider_id, dimensions, is_multimodal, type) VALUES ( sqlc.arg(model_id), sqlc.arg(name), sqlc.arg(llm_provider_id), sqlc.arg(dimensions), sqlc.arg(is_multimodal), - sqlc.arg(type), - sqlc.arg(enable_as) + sqlc.arg(type) ) RETURNING *; @@ -87,7 +86,6 @@ SET dimensions = sqlc.arg(dimensions), is_multimodal = sqlc.arg(is_multimodal), type = sqlc.arg(type), - enable_as = sqlc.arg(enable_as), updated_at = now() WHERE id = sqlc.arg(id) RETURNING *; @@ -100,7 +98,6 @@ SET dimensions = sqlc.arg(dimensions), is_multimodal = sqlc.arg(is_multimodal), type = sqlc.arg(type), - enable_as = sqlc.arg(enable_as), updated_at = now() WHERE model_id = sqlc.arg(model_id) RETURNING *; @@ -117,13 +114,6 @@ SELECT COUNT(*) FROM models; -- name: CountModelsByType :one SELECT COUNT(*) FROM models WHERE type = sqlc.arg(type); --- name: GetModelByEnableAs :one -SELECT * FROM models WHERE enable_as = sqlc.arg(enable_as) LIMIT 1; - --- name: ClearEnableAs :exec -UPDATE models -SET enable_as = NULL, updated_at = now() -WHERE enable_as = sqlc.arg(enable_as); -- name: CreateModelVariant :one INSERT INTO model_variants (model_uuid, variant_id, weight, metadata) diff --git a/db/queries/settings.sql b/db/queries/settings.sql index ec1df7d5..2ec129be 100644 --- a/db/queries/settings.sql +++ b/db/queries/settings.sql @@ -1,15 +1,18 @@ -- name: GetSettingsByUserID :one -SELECT user_id, max_context_load_time, language +SELECT user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language FROM user_settings WHERE user_id = $1; -- name: UpsertSettings :one -INSERT INTO user_settings (user_id, max_context_load_time, language) -VALUES ($1, $2, $3) +INSERT INTO user_settings (user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language) +VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT (user_id) DO UPDATE SET + chat_model_id = EXCLUDED.chat_model_id, + memory_model_id = EXCLUDED.memory_model_id, + embedding_model_id = EXCLUDED.embedding_model_id, max_context_load_time = EXCLUDED.max_context_load_time, language = EXCLUDED.language -RETURNING user_id, max_context_load_time, language; +RETURNING user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language; -- name: DeleteSettingsByUserID :exec DELETE FROM user_settings diff --git a/docs/docs.go b/docs/docs.go index d7c4f6c6..3e0efad1 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -739,7 +739,7 @@ const docTemplate = `{ }, "/memory/add": { "post": { - "description": "Add memory for a user via memory", + "description": "Add memory for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id).", "tags": [ "memory" ], @@ -751,7 +751,7 @@ const docTemplate = `{ "in": "body", "required": true, "schema": { - "$ref": "#/definitions/memory.AddRequest" + "$ref": "#/definitions/handlers.memoryAddPayload" } } ], @@ -779,7 +779,7 @@ const docTemplate = `{ }, "/memory/embed": { "post": { - "description": "Embed text or multimodal input and upsert into memory store", + "description": "Embed text or multimodal input and upsert into memory store. Auth: Bearer JWT determines user_id (sub or user_id).", "tags": [ "memory" ], @@ -791,7 +791,7 @@ const docTemplate = `{ "in": "body", "required": true, "schema": { - "$ref": "#/definitions/memory.EmbedUpsertRequest" + "$ref": "#/definitions/handlers.memoryEmbedUpsertPayload" } } ], @@ -819,18 +819,12 @@ const docTemplate = `{ }, "/memory/memories": { "get": { - "description": "List memories for a user via memory", + "description": "List memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id).", "tags": [ "memory" ], "summary": "List memories", "parameters": [ - { - "type": "string", - "description": "User ID", - "name": "user_id", - "in": "query" - }, { "type": "string", "description": "Agent ID", @@ -872,7 +866,7 @@ const docTemplate = `{ } }, "delete": { - "description": "Delete all memories for a user via memory", + "description": "Delete all memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id).", "tags": [ "memory" ], @@ -884,7 +878,7 @@ const docTemplate = `{ "in": "body", "required": true, "schema": { - "$ref": "#/definitions/memory.DeleteAllRequest" + "$ref": "#/definitions/handlers.memoryDeleteAllPayload" } } ], @@ -912,7 +906,7 @@ const docTemplate = `{ }, "/memory/memories/{memoryId}": { "get": { - "description": "Get a memory by ID via memory", + "description": "Get a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id).", "tags": [ "memory" ], @@ -948,7 +942,7 @@ const docTemplate = `{ } }, "delete": { - "description": "Delete a memory by ID via memory", + "description": "Delete a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id).", "tags": [ "memory" ], @@ -986,7 +980,7 @@ const docTemplate = `{ }, "/memory/search": { "post": { - "description": "Search memories for a user via memory", + "description": "Search memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id).", "tags": [ "memory" ], @@ -998,7 +992,7 @@ const docTemplate = `{ "in": "body", "required": true, "schema": { - "$ref": "#/definitions/memory.SearchRequest" + "$ref": "#/definitions/handlers.memorySearchPayload" } } ], @@ -1026,7 +1020,7 @@ const docTemplate = `{ }, "/memory/update": { "post": { - "description": "Update a memory by ID via memory", + "description": "Update a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id).", "tags": [ "memory" ], @@ -1185,27 +1179,29 @@ const docTemplate = `{ } } }, - "/models/enable-as/{enableAs}": { - "get": { - "description": "Get the model that is enabled for a specific purpose (chat, memory, embedding)", + "/models/enable": { + "post": { + "description": "Update the current user's settings to use the selected model", "tags": [ "models" ], - "summary": "Get model by enable_as", + "summary": "Enable model for chat/memory/embedding", "parameters": [ { - "type": "string", - "description": "Enable as value (chat, memory, embedding)", - "name": "enableAs", - "in": "path", - "required": true + "description": "Enable model payload", + "name": "payload", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/handlers.EnableModelRequest" + } } ], "responses": { "200": { "description": "OK", "schema": { - "$ref": "#/definitions/models.GetResponse" + "$ref": "#/definitions/settings.Settings" } }, "400": { @@ -2850,6 +2846,17 @@ const docTemplate = `{ } } }, + "handlers.EnableModelRequest": { + "type": "object", + "properties": { + "as": { + "type": "string" + }, + "model_id": { + "type": "string" + } + } + }, "handlers.ErrorResponse": { "type": "object", "properties": { @@ -2996,6 +3003,109 @@ const docTemplate = `{ } } }, + "handlers.memoryAddPayload": { + "type": "object", + "properties": { + "agent_id": { + "type": "string" + }, + "filters": { + "type": "object", + "additionalProperties": true + }, + "infer": { + "type": "boolean" + }, + "message": { + "type": "string" + }, + "messages": { + "type": "array", + "items": { + "$ref": "#/definitions/memory.Message" + } + }, + "metadata": { + "type": "object", + "additionalProperties": true + }, + "run_id": { + "type": "string" + } + } + }, + "handlers.memoryDeleteAllPayload": { + "type": "object", + "properties": { + "agent_id": { + "type": "string" + }, + "run_id": { + "type": "string" + } + } + }, + "handlers.memoryEmbedUpsertPayload": { + "type": "object", + "properties": { + "agent_id": { + "type": "string" + }, + "filters": { + "type": "object", + "additionalProperties": true + }, + "input": { + "$ref": "#/definitions/memory.EmbedInput" + }, + "metadata": { + "type": "object", + "additionalProperties": true + }, + "model": { + "type": "string" + }, + "provider": { + "type": "string" + }, + "run_id": { + "type": "string" + }, + "source": { + "type": "string" + }, + "type": { + "type": "string" + } + } + }, + "handlers.memorySearchPayload": { + "type": "object", + "properties": { + "agent_id": { + "type": "string" + }, + "filters": { + "type": "object", + "additionalProperties": true + }, + "limit": { + "type": "integer" + }, + "query": { + "type": "string" + }, + "run_id": { + "type": "string" + }, + "sources": { + "type": "array", + "items": { + "type": "string" + } + } + } + }, "handlers.skillsOpResponse": { "type": "object", "properties": { @@ -3060,54 +3170,6 @@ const docTemplate = `{ } } }, - "memory.AddRequest": { - "type": "object", - "properties": { - "agent_id": { - "type": "string" - }, - "filters": { - "type": "object", - "additionalProperties": true - }, - "infer": { - "type": "boolean" - }, - "message": { - "type": "string" - }, - "messages": { - "type": "array", - "items": { - "$ref": "#/definitions/memory.Message" - } - }, - "metadata": { - "type": "object", - "additionalProperties": true - }, - "run_id": { - "type": "string" - }, - "user_id": { - "type": "string" - } - } - }, - "memory.DeleteAllRequest": { - "type": "object", - "properties": { - "agent_id": { - "type": "string" - }, - "run_id": { - "type": "string" - }, - "user_id": { - "type": "string" - } - } - }, "memory.DeleteResponse": { "type": "object", "properties": { @@ -3130,43 +3192,6 @@ const docTemplate = `{ } } }, - "memory.EmbedUpsertRequest": { - "type": "object", - "properties": { - "agent_id": { - "type": "string" - }, - "filters": { - "type": "object", - "additionalProperties": true - }, - "input": { - "$ref": "#/definitions/memory.EmbedInput" - }, - "metadata": { - "type": "object", - "additionalProperties": true - }, - "model": { - "type": "string" - }, - "provider": { - "type": "string" - }, - "run_id": { - "type": "string" - }, - "source": { - "type": "string" - }, - "type": { - "type": "string" - }, - "user_id": { - "type": "string" - } - } - }, "memory.EmbedUpsertResponse": { "type": "object", "properties": { @@ -3231,36 +3256,6 @@ const docTemplate = `{ } } }, - "memory.SearchRequest": { - "type": "object", - "properties": { - "agent_id": { - "type": "string" - }, - "filters": { - "type": "object", - "additionalProperties": true - }, - "limit": { - "type": "integer" - }, - "query": { - "type": "string" - }, - "run_id": { - "type": "string" - }, - "sources": { - "type": "array", - "items": { - "type": "string" - } - }, - "user_id": { - "type": "string" - } - } - }, "memory.SearchResponse": { "type": "object", "properties": { @@ -3293,9 +3288,6 @@ const docTemplate = `{ "dimensions": { "type": "integer" }, - "enable_as": { - "$ref": "#/definitions/models.EnableAs" - }, "is_multimodal": { "type": "boolean" }, @@ -3332,28 +3324,12 @@ const docTemplate = `{ } } }, - "models.EnableAs": { - "type": "string", - "enum": [ - "chat", - "memory", - "embedding" - ], - "x-enum-varnames": [ - "EnableAsChat", - "EnableAsMemory", - "EnableAsEmbedding" - ] - }, "models.GetResponse": { "type": "object", "properties": { "dimensions": { "type": "integer" }, - "enable_as": { - "$ref": "#/definitions/models.EnableAs" - }, "is_multimodal": { "type": "boolean" }, @@ -3388,9 +3364,6 @@ const docTemplate = `{ "dimensions": { "type": "integer" }, - "enable_as": { - "$ref": "#/definitions/models.EnableAs" - }, "is_multimodal": { "type": "boolean" }, @@ -3609,22 +3582,40 @@ const docTemplate = `{ "settings.Settings": { "type": "object", "properties": { + "chat_model_id": { + "type": "string" + }, + "embedding_model_id": { + "type": "string" + }, "language": { "type": "string" }, "max_context_load_time": { "type": "integer" + }, + "memory_model_id": { + "type": "string" } } }, "settings.UpsertRequest": { "type": "object", "properties": { + "chat_model_id": { + "type": "string" + }, + "embedding_model_id": { + "type": "string" + }, "language": { "type": "string" }, "max_context_load_time": { "type": "integer" + }, + "memory_model_id": { + "type": "string" } } }, diff --git a/docs/swagger.json b/docs/swagger.json index 13920970..2adf19de 100644 --- a/docs/swagger.json +++ b/docs/swagger.json @@ -730,7 +730,7 @@ }, "/memory/add": { "post": { - "description": "Add memory for a user via memory", + "description": "Add memory for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id).", "tags": [ "memory" ], @@ -742,7 +742,7 @@ "in": "body", "required": true, "schema": { - "$ref": "#/definitions/memory.AddRequest" + "$ref": "#/definitions/handlers.memoryAddPayload" } } ], @@ -770,7 +770,7 @@ }, "/memory/embed": { "post": { - "description": "Embed text or multimodal input and upsert into memory store", + "description": "Embed text or multimodal input and upsert into memory store. Auth: Bearer JWT determines user_id (sub or user_id).", "tags": [ "memory" ], @@ -782,7 +782,7 @@ "in": "body", "required": true, "schema": { - "$ref": "#/definitions/memory.EmbedUpsertRequest" + "$ref": "#/definitions/handlers.memoryEmbedUpsertPayload" } } ], @@ -810,18 +810,12 @@ }, "/memory/memories": { "get": { - "description": "List memories for a user via memory", + "description": "List memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id).", "tags": [ "memory" ], "summary": "List memories", "parameters": [ - { - "type": "string", - "description": "User ID", - "name": "user_id", - "in": "query" - }, { "type": "string", "description": "Agent ID", @@ -863,7 +857,7 @@ } }, "delete": { - "description": "Delete all memories for a user via memory", + "description": "Delete all memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id).", "tags": [ "memory" ], @@ -875,7 +869,7 @@ "in": "body", "required": true, "schema": { - "$ref": "#/definitions/memory.DeleteAllRequest" + "$ref": "#/definitions/handlers.memoryDeleteAllPayload" } } ], @@ -903,7 +897,7 @@ }, "/memory/memories/{memoryId}": { "get": { - "description": "Get a memory by ID via memory", + "description": "Get a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id).", "tags": [ "memory" ], @@ -939,7 +933,7 @@ } }, "delete": { - "description": "Delete a memory by ID via memory", + "description": "Delete a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id).", "tags": [ "memory" ], @@ -977,7 +971,7 @@ }, "/memory/search": { "post": { - "description": "Search memories for a user via memory", + "description": "Search memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id).", "tags": [ "memory" ], @@ -989,7 +983,7 @@ "in": "body", "required": true, "schema": { - "$ref": "#/definitions/memory.SearchRequest" + "$ref": "#/definitions/handlers.memorySearchPayload" } } ], @@ -1017,7 +1011,7 @@ }, "/memory/update": { "post": { - "description": "Update a memory by ID via memory", + "description": "Update a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id).", "tags": [ "memory" ], @@ -1176,27 +1170,29 @@ } } }, - "/models/enable-as/{enableAs}": { - "get": { - "description": "Get the model that is enabled for a specific purpose (chat, memory, embedding)", + "/models/enable": { + "post": { + "description": "Update the current user's settings to use the selected model", "tags": [ "models" ], - "summary": "Get model by enable_as", + "summary": "Enable model for chat/memory/embedding", "parameters": [ { - "type": "string", - "description": "Enable as value (chat, memory, embedding)", - "name": "enableAs", - "in": "path", - "required": true + "description": "Enable model payload", + "name": "payload", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/handlers.EnableModelRequest" + } } ], "responses": { "200": { "description": "OK", "schema": { - "$ref": "#/definitions/models.GetResponse" + "$ref": "#/definitions/settings.Settings" } }, "400": { @@ -2841,6 +2837,17 @@ } } }, + "handlers.EnableModelRequest": { + "type": "object", + "properties": { + "as": { + "type": "string" + }, + "model_id": { + "type": "string" + } + } + }, "handlers.ErrorResponse": { "type": "object", "properties": { @@ -2987,6 +2994,109 @@ } } }, + "handlers.memoryAddPayload": { + "type": "object", + "properties": { + "agent_id": { + "type": "string" + }, + "filters": { + "type": "object", + "additionalProperties": true + }, + "infer": { + "type": "boolean" + }, + "message": { + "type": "string" + }, + "messages": { + "type": "array", + "items": { + "$ref": "#/definitions/memory.Message" + } + }, + "metadata": { + "type": "object", + "additionalProperties": true + }, + "run_id": { + "type": "string" + } + } + }, + "handlers.memoryDeleteAllPayload": { + "type": "object", + "properties": { + "agent_id": { + "type": "string" + }, + "run_id": { + "type": "string" + } + } + }, + "handlers.memoryEmbedUpsertPayload": { + "type": "object", + "properties": { + "agent_id": { + "type": "string" + }, + "filters": { + "type": "object", + "additionalProperties": true + }, + "input": { + "$ref": "#/definitions/memory.EmbedInput" + }, + "metadata": { + "type": "object", + "additionalProperties": true + }, + "model": { + "type": "string" + }, + "provider": { + "type": "string" + }, + "run_id": { + "type": "string" + }, + "source": { + "type": "string" + }, + "type": { + "type": "string" + } + } + }, + "handlers.memorySearchPayload": { + "type": "object", + "properties": { + "agent_id": { + "type": "string" + }, + "filters": { + "type": "object", + "additionalProperties": true + }, + "limit": { + "type": "integer" + }, + "query": { + "type": "string" + }, + "run_id": { + "type": "string" + }, + "sources": { + "type": "array", + "items": { + "type": "string" + } + } + } + }, "handlers.skillsOpResponse": { "type": "object", "properties": { @@ -3051,54 +3161,6 @@ } } }, - "memory.AddRequest": { - "type": "object", - "properties": { - "agent_id": { - "type": "string" - }, - "filters": { - "type": "object", - "additionalProperties": true - }, - "infer": { - "type": "boolean" - }, - "message": { - "type": "string" - }, - "messages": { - "type": "array", - "items": { - "$ref": "#/definitions/memory.Message" - } - }, - "metadata": { - "type": "object", - "additionalProperties": true - }, - "run_id": { - "type": "string" - }, - "user_id": { - "type": "string" - } - } - }, - "memory.DeleteAllRequest": { - "type": "object", - "properties": { - "agent_id": { - "type": "string" - }, - "run_id": { - "type": "string" - }, - "user_id": { - "type": "string" - } - } - }, "memory.DeleteResponse": { "type": "object", "properties": { @@ -3121,43 +3183,6 @@ } } }, - "memory.EmbedUpsertRequest": { - "type": "object", - "properties": { - "agent_id": { - "type": "string" - }, - "filters": { - "type": "object", - "additionalProperties": true - }, - "input": { - "$ref": "#/definitions/memory.EmbedInput" - }, - "metadata": { - "type": "object", - "additionalProperties": true - }, - "model": { - "type": "string" - }, - "provider": { - "type": "string" - }, - "run_id": { - "type": "string" - }, - "source": { - "type": "string" - }, - "type": { - "type": "string" - }, - "user_id": { - "type": "string" - } - } - }, "memory.EmbedUpsertResponse": { "type": "object", "properties": { @@ -3222,36 +3247,6 @@ } } }, - "memory.SearchRequest": { - "type": "object", - "properties": { - "agent_id": { - "type": "string" - }, - "filters": { - "type": "object", - "additionalProperties": true - }, - "limit": { - "type": "integer" - }, - "query": { - "type": "string" - }, - "run_id": { - "type": "string" - }, - "sources": { - "type": "array", - "items": { - "type": "string" - } - }, - "user_id": { - "type": "string" - } - } - }, "memory.SearchResponse": { "type": "object", "properties": { @@ -3284,9 +3279,6 @@ "dimensions": { "type": "integer" }, - "enable_as": { - "$ref": "#/definitions/models.EnableAs" - }, "is_multimodal": { "type": "boolean" }, @@ -3323,28 +3315,12 @@ } } }, - "models.EnableAs": { - "type": "string", - "enum": [ - "chat", - "memory", - "embedding" - ], - "x-enum-varnames": [ - "EnableAsChat", - "EnableAsMemory", - "EnableAsEmbedding" - ] - }, "models.GetResponse": { "type": "object", "properties": { "dimensions": { "type": "integer" }, - "enable_as": { - "$ref": "#/definitions/models.EnableAs" - }, "is_multimodal": { "type": "boolean" }, @@ -3379,9 +3355,6 @@ "dimensions": { "type": "integer" }, - "enable_as": { - "$ref": "#/definitions/models.EnableAs" - }, "is_multimodal": { "type": "boolean" }, @@ -3600,22 +3573,40 @@ "settings.Settings": { "type": "object", "properties": { + "chat_model_id": { + "type": "string" + }, + "embedding_model_id": { + "type": "string" + }, "language": { "type": "string" }, "max_context_load_time": { "type": "integer" + }, + "memory_model_id": { + "type": "string" } } }, "settings.UpsertRequest": { "type": "object", "properties": { + "chat_model_id": { + "type": "string" + }, + "embedding_model_id": { + "type": "string" + }, "language": { "type": "string" }, "max_context_load_time": { "type": "integer" + }, + "memory_model_id": { + "type": "string" } } }, diff --git a/docs/swagger.yaml b/docs/swagger.yaml index e6312f55..1997112c 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -166,6 +166,13 @@ definitions: video_tokens: type: integer type: object + handlers.EnableModelRequest: + properties: + as: + type: string + model_id: + type: string + type: object handlers.ErrorResponse: properties: message: @@ -260,6 +267,75 @@ definitions: updated_at: type: string type: object + handlers.memoryAddPayload: + properties: + agent_id: + type: string + filters: + additionalProperties: true + type: object + infer: + type: boolean + message: + type: string + messages: + items: + $ref: '#/definitions/memory.Message' + type: array + metadata: + additionalProperties: true + type: object + run_id: + type: string + type: object + handlers.memoryDeleteAllPayload: + properties: + agent_id: + type: string + run_id: + type: string + type: object + handlers.memoryEmbedUpsertPayload: + properties: + agent_id: + type: string + filters: + additionalProperties: true + type: object + input: + $ref: '#/definitions/memory.EmbedInput' + metadata: + additionalProperties: true + type: object + model: + type: string + provider: + type: string + run_id: + type: string + source: + type: string + type: + type: string + type: object + handlers.memorySearchPayload: + properties: + agent_id: + type: string + filters: + additionalProperties: true + type: object + limit: + type: integer + query: + type: string + run_id: + type: string + sources: + items: + type: string + type: array + type: object handlers.skillsOpResponse: properties: ok: @@ -302,38 +378,6 @@ definitions: user_id: type: string type: object - memory.AddRequest: - properties: - agent_id: - type: string - filters: - additionalProperties: true - type: object - infer: - type: boolean - message: - type: string - messages: - items: - $ref: '#/definitions/memory.Message' - type: array - metadata: - additionalProperties: true - type: object - run_id: - type: string - user_id: - type: string - type: object - memory.DeleteAllRequest: - properties: - agent_id: - type: string - run_id: - type: string - user_id: - type: string - type: object memory.DeleteResponse: properties: message: @@ -348,31 +392,6 @@ definitions: video_url: type: string type: object - memory.EmbedUpsertRequest: - properties: - agent_id: - type: string - filters: - additionalProperties: true - type: object - input: - $ref: '#/definitions/memory.EmbedInput' - metadata: - additionalProperties: true - type: object - model: - type: string - provider: - type: string - run_id: - type: string - source: - type: string - type: - type: string - user_id: - type: string - type: object memory.EmbedUpsertResponse: properties: dimensions: @@ -415,26 +434,6 @@ definitions: role: type: string type: object - memory.SearchRequest: - properties: - agent_id: - type: string - filters: - additionalProperties: true - type: object - limit: - type: integer - query: - type: string - run_id: - type: string - sources: - items: - type: string - type: array - user_id: - type: string - type: object memory.SearchResponse: properties: relations: @@ -456,8 +455,6 @@ definitions: properties: dimensions: type: integer - enable_as: - $ref: '#/definitions/models.EnableAs' is_multimodal: type: boolean llm_provider_id: @@ -481,22 +478,10 @@ definitions: count: type: integer type: object - models.EnableAs: - enum: - - chat - - memory - - embedding - type: string - x-enum-varnames: - - EnableAsChat - - EnableAsMemory - - EnableAsEmbedding models.GetResponse: properties: dimensions: type: integer - enable_as: - $ref: '#/definitions/models.EnableAs' is_multimodal: type: boolean llm_provider_id: @@ -520,8 +505,6 @@ definitions: properties: dimensions: type: integer - enable_as: - $ref: '#/definitions/models.EnableAs' is_multimodal: type: boolean llm_provider_id: @@ -669,17 +652,29 @@ definitions: type: object settings.Settings: properties: + chat_model_id: + type: string + embedding_model_id: + type: string language: type: string max_context_load_time: type: integer + memory_model_id: + type: string type: object settings.UpsertRequest: properties: + chat_model_id: + type: string + embedding_model_id: + type: string language: type: string max_context_load_time: type: integer + memory_model_id: + type: string type: object subagent.AddSkillsRequest: properties: @@ -1279,14 +1274,15 @@ paths: - containerd /memory/add: post: - description: Add memory for a user via memory + description: 'Add memory for a user via memory. Auth: Bearer JWT determines + user_id (sub or user_id).' parameters: - description: Add request in: body name: payload required: true schema: - $ref: '#/definitions/memory.AddRequest' + $ref: '#/definitions/handlers.memoryAddPayload' responses: "200": description: OK @@ -1305,14 +1301,15 @@ paths: - memory /memory/embed: post: - description: Embed text or multimodal input and upsert into memory store + description: 'Embed text or multimodal input and upsert into memory store. Auth: + Bearer JWT determines user_id (sub or user_id).' parameters: - description: Embed upsert request in: body name: payload required: true schema: - $ref: '#/definitions/memory.EmbedUpsertRequest' + $ref: '#/definitions/handlers.memoryEmbedUpsertPayload' responses: "200": description: OK @@ -1331,14 +1328,15 @@ paths: - memory /memory/memories: delete: - description: Delete all memories for a user via memory + description: 'Delete all memories for a user via memory. Auth: Bearer JWT determines + user_id (sub or user_id).' parameters: - description: Delete all request in: body name: payload required: true schema: - $ref: '#/definitions/memory.DeleteAllRequest' + $ref: '#/definitions/handlers.memoryDeleteAllPayload' responses: "200": description: OK @@ -1356,12 +1354,9 @@ paths: tags: - memory get: - description: List memories for a user via memory + description: 'List memories for a user via memory. Auth: Bearer JWT determines + user_id (sub or user_id).' parameters: - - description: User ID - in: query - name: user_id - type: string - description: Agent ID in: query name: agent_id @@ -1392,7 +1387,8 @@ paths: - memory /memory/memories/{memoryId}: delete: - description: Delete a memory by ID via memory + description: 'Delete a memory by ID via memory. Auth: Bearer JWT determines + user_id (sub or user_id).' parameters: - description: Memory ID in: path @@ -1416,7 +1412,8 @@ paths: tags: - memory get: - description: Get a memory by ID via memory + description: 'Get a memory by ID via memory. Auth: Bearer JWT determines user_id + (sub or user_id).' parameters: - description: Memory ID in: path @@ -1441,14 +1438,15 @@ paths: - memory /memory/search: post: - description: Search memories for a user via memory + description: 'Search memories for a user via memory. Auth: Bearer JWT determines + user_id (sub or user_id).' parameters: - description: Search request in: body name: payload required: true schema: - $ref: '#/definitions/memory.SearchRequest' + $ref: '#/definitions/handlers.memorySearchPayload' responses: "200": description: OK @@ -1467,7 +1465,8 @@ paths: - memory /memory/update: post: - description: Update a memory by ID via memory + description: 'Update a memory by ID via memory. Auth: Bearer JWT determines + user_id (sub or user_id).' parameters: - description: Update request in: body @@ -1660,21 +1659,21 @@ paths: summary: Get model count tags: - models - /models/enable-as/{enableAs}: - get: - description: Get the model that is enabled for a specific purpose (chat, memory, - embedding) + /models/enable: + post: + description: Update the current user's settings to use the selected model parameters: - - description: Enable as value (chat, memory, embedding) - in: path - name: enableAs + - description: Enable model payload + in: body + name: payload required: true - type: string + schema: + $ref: '#/definitions/handlers.EnableModelRequest' responses: "200": description: OK schema: - $ref: '#/definitions/models.GetResponse' + $ref: '#/definitions/settings.Settings' "400": description: Bad Request schema: @@ -1687,7 +1686,7 @@ paths: description: Internal Server Error schema: $ref: '#/definitions/handlers.ErrorResponse' - summary: Get model by enable_as + summary: Enable model for chat/memory/embedding tags: - models /models/model/{modelId}: diff --git a/internal/chat/resolver.go b/internal/chat/resolver.go index e51b95a8..79ecf9b7 100644 --- a/internal/chat/resolver.go +++ b/internal/chat/resolver.go @@ -36,6 +36,14 @@ type Resolver struct { streamingClient *http.Client } +type userSettings struct { + ChatModelID string + MemoryModelID string + EmbeddingModelID string + MaxContextLoadTime int + Language string +} + func NewResolver(log *slog.Logger, modelsService *models.Service, queries *sqlc.Queries, memoryService *memory.Service, gatewayBaseURL string, timeout time.Duration) *Resolver { if strings.TrimSpace(gatewayBaseURL) == "" { gatewayBaseURL = "http://127.0.0.1:8081" @@ -67,7 +75,11 @@ func (r *Resolver) Chat(ctx context.Context, req ChatRequest) (ChatResponse, err } skipHistory := req.MaxContextLoadTime < 0 - chatModel, provider, err := r.selectChatModel(ctx, req) + settings, err := r.loadUserSettings(ctx, req.UserID) + if err != nil { + return ChatResponse{}, err + } + chatModel, provider, err := r.selectChatModel(ctx, req, settings) if err != nil { return ChatResponse{}, err } @@ -76,10 +88,8 @@ func (r *Resolver) Chat(ctx context.Context, req ChatRequest) (ChatResponse, err return ChatResponse{}, err } - maxContextLoadTime, language, err := r.loadUserSettings(ctx, req.UserID) - if err != nil { - return ChatResponse{}, err - } + maxContextLoadTime := settings.MaxContextLoadTime + language := settings.Language if req.MaxContextLoadTime > 0 { maxContextLoadTime = req.MaxContextLoadTime } @@ -157,7 +167,11 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, userID string, schedule Locale: "", Language: "", } - chatModel, provider, err := r.selectChatModel(ctx, req) + settings, err := r.loadUserSettings(ctx, userID) + if err != nil { + return err + } + chatModel, provider, err := r.selectChatModel(ctx, req, settings) if err != nil { return err } @@ -166,10 +180,8 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, userID string, schedule return err } - maxContextLoadTime, language, err := r.loadUserSettings(ctx, userID) - if err != nil { - return err - } + maxContextLoadTime := settings.MaxContextLoadTime + language := settings.Language messages, err := r.loadHistoryMessages(ctx, userID, maxContextLoadTime) if err != nil { @@ -229,7 +241,12 @@ func (r *Resolver) StreamChat(ctx context.Context, req ChatRequest) (<-chan Stre } skipHistory := req.MaxContextLoadTime < 0 - chatModel, provider, err := r.selectChatModel(ctx, req) + settings, err := r.loadUserSettings(ctx, req.UserID) + if err != nil { + errChan <- err + return + } + chatModel, provider, err := r.selectChatModel(ctx, req, settings) if err != nil { errChan <- err return @@ -240,11 +257,8 @@ func (r *Resolver) StreamChat(ctx context.Context, req ChatRequest) (<-chan Stre return } - maxContextLoadTime, language, err := r.loadUserSettings(ctx, req.UserID) - if err != nil { - errChan <- err - return - } + maxContextLoadTime := settings.MaxContextLoadTime + language := settings.Language if req.MaxContextLoadTime > 0 { maxContextLoadTime = req.MaxContextLoadTime } @@ -797,7 +811,7 @@ func isEmptyValue(value interface{}) bool { } } -func (r *Resolver) selectChatModel(ctx context.Context, req ChatRequest) (models.GetResponse, sqlc.LlmProvider, error) { +func (r *Resolver) selectChatModel(ctx context.Context, req ChatRequest, settings userSettings) (models.GetResponse, sqlc.LlmProvider, error) { if r.modelsService == nil { return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("models service not configured") } @@ -819,15 +833,19 @@ func (r *Resolver) selectChatModel(ctx context.Context, req ChatRequest) (models return model, provider, nil } - if providerFilter == "" && modelID == "" { - defaultModel, err := r.modelsService.GetByEnableAs(ctx, models.EnableAsChat) - if err == nil { - provider, err := models.FetchProviderByID(ctx, r.queries, defaultModel.LlmProviderID) - if err != nil { - return models.GetResponse{}, sqlc.LlmProvider{}, err - } - return defaultModel, provider, nil + if providerFilter == "" && modelID == "" && strings.TrimSpace(settings.ChatModelID) != "" { + selected, err := r.modelsService.GetByModelID(ctx, settings.ChatModelID) + if err != nil { + return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("settings chat model not found: %w", err) } + if selected.Type != models.ModelTypeChat { + return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("settings chat model is not a chat model") + } + provider, err := models.FetchProviderByID(ctx, r.queries, selected.LlmProviderID) + if err != nil { + return models.GetResponse{}, sqlc.LlmProvider{}, err + } + return selected, provider, nil } var candidates []models.GetResponse @@ -880,26 +898,31 @@ func normalizeMaxContextLoad(value int) int { return value } -func (r *Resolver) loadUserSettings(ctx context.Context, userID string) (int, string, error) { +func (r *Resolver) loadUserSettings(ctx context.Context, userID string) (userSettings, error) { if r.queries == nil { - return defaultMaxContextMinutes, "Same as user input", nil + return userSettings{ + MaxContextLoadTime: defaultMaxContextMinutes, + Language: "Same as user input", + }, nil } pgUserID, err := parseUUID(userID) if err != nil { - return 0, "", err + return userSettings{}, err } settingsRow, err := r.queries.GetSettingsByUserID(ctx, pgUserID) if err != nil { if errors.Is(err, pgx.ErrNoRows) { - return defaultMaxContextMinutes, "Same as user input", nil + return userSettings{ + MaxContextLoadTime: defaultMaxContextMinutes, + Language: "Same as user input", + }, nil } - return 0, "", err + return userSettings{}, err } - maxLoad, language := normalizeUserSettingRow(settingsRow) - return maxLoad, language, nil + return normalizeUserSettingRow(settingsRow), nil } -func normalizeUserSettingRow(row sqlc.UserSetting) (int, string) { +func normalizeUserSettingRow(row sqlc.UserSetting) userSettings { maxLoad := int(row.MaxContextLoadTime) if maxLoad <= 0 { maxLoad = defaultMaxContextMinutes @@ -908,7 +931,13 @@ func normalizeUserSettingRow(row sqlc.UserSetting) (int, string) { if language == "" { language = "Same as user input" } - return maxLoad, language + return userSettings{ + ChatModelID: strings.TrimSpace(row.ChatModelID.String), + MemoryModelID: strings.TrimSpace(row.MemoryModelID.String), + EmbeddingModelID: strings.TrimSpace(row.EmbeddingModelID.String), + MaxContextLoadTime: maxLoad, + Language: language, + } } func normalizeClientType(clientType string) (string, error) { diff --git a/internal/db/sqlc/models.go b/internal/db/sqlc/models.go index 677c19a4..dd47dc1e 100644 --- a/internal/db/sqlc/models.go +++ b/internal/db/sqlc/models.go @@ -87,7 +87,6 @@ type Model struct { Dimensions pgtype.Int4 `json:"dimensions"` IsMultimodal bool `json:"is_multimodal"` Type string `json:"type"` - EnableAs pgtype.Text `json:"enable_as"` CreatedAt pgtype.Timestamptz `json:"created_at"` UpdatedAt pgtype.Timestamptz `json:"updated_at"` } @@ -156,6 +155,9 @@ type User struct { type UserSetting struct { UserID pgtype.UUID `json:"user_id"` + ChatModelID pgtype.Text `json:"chat_model_id"` + MemoryModelID pgtype.Text `json:"memory_model_id"` + EmbeddingModelID pgtype.Text `json:"embedding_model_id"` MaxContextLoadTime int32 `json:"max_context_load_time"` Language string `json:"language"` } diff --git a/internal/db/sqlc/models.sql.go b/internal/db/sqlc/models.sql.go index b5b36426..e66c3ce2 100644 --- a/internal/db/sqlc/models.sql.go +++ b/internal/db/sqlc/models.sql.go @@ -11,17 +11,6 @@ import ( "github.com/jackc/pgx/v5/pgtype" ) -const clearEnableAs = `-- name: ClearEnableAs :exec -UPDATE models -SET enable_as = NULL, updated_at = now() -WHERE enable_as = $1 -` - -func (q *Queries) ClearEnableAs(ctx context.Context, enableAs pgtype.Text) error { - _, err := q.db.Exec(ctx, clearEnableAs, enableAs) - return err -} - const countLlmProviders = `-- name: CountLlmProviders :one SELECT COUNT(*) FROM llm_providers ` @@ -109,17 +98,16 @@ func (q *Queries) CreateLlmProvider(ctx context.Context, arg CreateLlmProviderPa } const createModel = `-- name: CreateModel :one -INSERT INTO models (model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as) +INSERT INTO models (model_id, name, llm_provider_id, dimensions, is_multimodal, type) VALUES ( $1, $2, $3, $4, $5, - $6, - $7 + $6 ) -RETURNING id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at +RETURNING id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at ` type CreateModelParams struct { @@ -129,7 +117,6 @@ type CreateModelParams struct { Dimensions pgtype.Int4 `json:"dimensions"` IsMultimodal bool `json:"is_multimodal"` Type string `json:"type"` - EnableAs pgtype.Text `json:"enable_as"` } func (q *Queries) CreateModel(ctx context.Context, arg CreateModelParams) (Model, error) { @@ -140,7 +127,6 @@ func (q *Queries) CreateModel(ctx context.Context, arg CreateModelParams) (Model arg.Dimensions, arg.IsMultimodal, arg.Type, - arg.EnableAs, ) var i Model err := row.Scan( @@ -151,7 +137,6 @@ func (q *Queries) CreateModel(ctx context.Context, arg CreateModelParams) (Model &i.Dimensions, &i.IsMultimodal, &i.Type, - &i.EnableAs, &i.CreatedAt, &i.UpdatedAt, ) @@ -272,30 +257,8 @@ func (q *Queries) GetLlmProviderByName(ctx context.Context, name string) (LlmPro return i, err } -const getModelByEnableAs = `-- name: GetModelByEnableAs :one -SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at FROM models WHERE enable_as = $1 LIMIT 1 -` - -func (q *Queries) GetModelByEnableAs(ctx context.Context, enableAs pgtype.Text) (Model, error) { - row := q.db.QueryRow(ctx, getModelByEnableAs, enableAs) - var i Model - err := row.Scan( - &i.ID, - &i.ModelID, - &i.Name, - &i.LlmProviderID, - &i.Dimensions, - &i.IsMultimodal, - &i.Type, - &i.EnableAs, - &i.CreatedAt, - &i.UpdatedAt, - ) - return i, err -} - const getModelByID = `-- name: GetModelByID :one -SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at FROM models WHERE id = $1 +SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at FROM models WHERE id = $1 ` func (q *Queries) GetModelByID(ctx context.Context, id pgtype.UUID) (Model, error) { @@ -309,7 +272,6 @@ func (q *Queries) GetModelByID(ctx context.Context, id pgtype.UUID) (Model, erro &i.Dimensions, &i.IsMultimodal, &i.Type, - &i.EnableAs, &i.CreatedAt, &i.UpdatedAt, ) @@ -317,7 +279,7 @@ func (q *Queries) GetModelByID(ctx context.Context, id pgtype.UUID) (Model, erro } const getModelByModelID = `-- name: GetModelByModelID :one -SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at FROM models WHERE model_id = $1 +SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at FROM models WHERE model_id = $1 ` func (q *Queries) GetModelByModelID(ctx context.Context, modelID string) (Model, error) { @@ -331,7 +293,6 @@ func (q *Queries) GetModelByModelID(ctx context.Context, modelID string) (Model, &i.Dimensions, &i.IsMultimodal, &i.Type, - &i.EnableAs, &i.CreatedAt, &i.UpdatedAt, ) @@ -495,7 +456,7 @@ func (q *Queries) ListModelVariantsByVariantID(ctx context.Context, variantID st } const listModels = `-- name: ListModels :many -SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at FROM models +SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at FROM models ORDER BY created_at DESC ` @@ -516,7 +477,6 @@ func (q *Queries) ListModels(ctx context.Context) ([]Model, error) { &i.Dimensions, &i.IsMultimodal, &i.Type, - &i.EnableAs, &i.CreatedAt, &i.UpdatedAt, ); err != nil { @@ -531,7 +491,7 @@ func (q *Queries) ListModels(ctx context.Context) ([]Model, error) { } const listModelsByClientType = `-- name: ListModelsByClientType :many -SELECT m.id, m.model_id, m.name, m.llm_provider_id, m.dimensions, m.is_multimodal, m.type, m.enable_as, m.created_at, m.updated_at FROM models AS m +SELECT m.id, m.model_id, m.name, m.llm_provider_id, m.dimensions, m.is_multimodal, m.type, m.created_at, m.updated_at FROM models AS m JOIN llm_providers AS p ON p.id = m.llm_provider_id WHERE p.client_type = $1 ORDER BY m.created_at DESC @@ -554,7 +514,6 @@ func (q *Queries) ListModelsByClientType(ctx context.Context, clientType string) &i.Dimensions, &i.IsMultimodal, &i.Type, - &i.EnableAs, &i.CreatedAt, &i.UpdatedAt, ); err != nil { @@ -569,7 +528,7 @@ func (q *Queries) ListModelsByClientType(ctx context.Context, clientType string) } const listModelsByType = `-- name: ListModelsByType :many -SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at FROM models +SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at FROM models WHERE type = $1 ORDER BY created_at DESC ` @@ -591,7 +550,6 @@ func (q *Queries) ListModelsByType(ctx context.Context, type_ string) ([]Model, &i.Dimensions, &i.IsMultimodal, &i.Type, - &i.EnableAs, &i.CreatedAt, &i.UpdatedAt, ); err != nil { @@ -658,10 +616,9 @@ SET dimensions = $3, is_multimodal = $4, type = $5, - enable_as = $6, updated_at = now() -WHERE id = $7 -RETURNING id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at +WHERE id = $6 +RETURNING id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at ` type UpdateModelParams struct { @@ -670,7 +627,6 @@ type UpdateModelParams struct { Dimensions pgtype.Int4 `json:"dimensions"` IsMultimodal bool `json:"is_multimodal"` Type string `json:"type"` - EnableAs pgtype.Text `json:"enable_as"` ID pgtype.UUID `json:"id"` } @@ -681,7 +637,6 @@ func (q *Queries) UpdateModel(ctx context.Context, arg UpdateModelParams) (Model arg.Dimensions, arg.IsMultimodal, arg.Type, - arg.EnableAs, arg.ID, ) var i Model @@ -693,7 +648,6 @@ func (q *Queries) UpdateModel(ctx context.Context, arg UpdateModelParams) (Model &i.Dimensions, &i.IsMultimodal, &i.Type, - &i.EnableAs, &i.CreatedAt, &i.UpdatedAt, ) @@ -708,10 +662,9 @@ SET dimensions = $3, is_multimodal = $4, type = $5, - enable_as = $6, updated_at = now() -WHERE model_id = $7 -RETURNING id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at +WHERE model_id = $6 +RETURNING id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at ` type UpdateModelByModelIDParams struct { @@ -720,7 +673,6 @@ type UpdateModelByModelIDParams struct { Dimensions pgtype.Int4 `json:"dimensions"` IsMultimodal bool `json:"is_multimodal"` Type string `json:"type"` - EnableAs pgtype.Text `json:"enable_as"` ModelID string `json:"model_id"` } @@ -731,7 +683,6 @@ func (q *Queries) UpdateModelByModelID(ctx context.Context, arg UpdateModelByMod arg.Dimensions, arg.IsMultimodal, arg.Type, - arg.EnableAs, arg.ModelID, ) var i Model @@ -743,7 +694,6 @@ func (q *Queries) UpdateModelByModelID(ctx context.Context, arg UpdateModelByMod &i.Dimensions, &i.IsMultimodal, &i.Type, - &i.EnableAs, &i.CreatedAt, &i.UpdatedAt, ) diff --git a/internal/db/sqlc/settings.sql.go b/internal/db/sqlc/settings.sql.go index 14fb440c..d0eea932 100644 --- a/internal/db/sqlc/settings.sql.go +++ b/internal/db/sqlc/settings.sql.go @@ -22,7 +22,7 @@ func (q *Queries) DeleteSettingsByUserID(ctx context.Context, userID pgtype.UUID } const getSettingsByUserID = `-- name: GetSettingsByUserID :one -SELECT user_id, max_context_load_time, language +SELECT user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language FROM user_settings WHERE user_id = $1 ` @@ -30,28 +30,55 @@ WHERE user_id = $1 func (q *Queries) GetSettingsByUserID(ctx context.Context, userID pgtype.UUID) (UserSetting, error) { row := q.db.QueryRow(ctx, getSettingsByUserID, userID) var i UserSetting - err := row.Scan(&i.UserID, &i.MaxContextLoadTime, &i.Language) + err := row.Scan( + &i.UserID, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, + &i.MaxContextLoadTime, + &i.Language, + ) return i, err } const upsertSettings = `-- name: UpsertSettings :one -INSERT INTO user_settings (user_id, max_context_load_time, language) -VALUES ($1, $2, $3) +INSERT INTO user_settings (user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language) +VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT (user_id) DO UPDATE SET + chat_model_id = EXCLUDED.chat_model_id, + memory_model_id = EXCLUDED.memory_model_id, + embedding_model_id = EXCLUDED.embedding_model_id, max_context_load_time = EXCLUDED.max_context_load_time, language = EXCLUDED.language -RETURNING user_id, max_context_load_time, language +RETURNING user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language ` type UpsertSettingsParams struct { UserID pgtype.UUID `json:"user_id"` + ChatModelID pgtype.Text `json:"chat_model_id"` + MemoryModelID pgtype.Text `json:"memory_model_id"` + EmbeddingModelID pgtype.Text `json:"embedding_model_id"` MaxContextLoadTime int32 `json:"max_context_load_time"` Language string `json:"language"` } func (q *Queries) UpsertSettings(ctx context.Context, arg UpsertSettingsParams) (UserSetting, error) { - row := q.db.QueryRow(ctx, upsertSettings, arg.UserID, arg.MaxContextLoadTime, arg.Language) + row := q.db.QueryRow(ctx, upsertSettings, + arg.UserID, + arg.ChatModelID, + arg.MemoryModelID, + arg.EmbeddingModelID, + arg.MaxContextLoadTime, + arg.Language, + ) var i UserSetting - err := row.Scan(&i.UserID, &i.MaxContextLoadTime, &i.Language) + err := row.Scan( + &i.UserID, + &i.ChatModelID, + &i.MemoryModelID, + &i.EmbeddingModelID, + &i.MaxContextLoadTime, + &i.Language, + ) return i, err } diff --git a/internal/embeddings/resolver.go b/internal/embeddings/resolver.go index c8becb7f..07e3bfbb 100644 --- a/internal/embeddings/resolver.go +++ b/internal/embeddings/resolver.go @@ -3,11 +3,13 @@ package embeddings import ( "context" "errors" + "fmt" "log/slog" "strings" "time" "github.com/google/uuid" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" "github.com/memohai/memoh/internal/db/sqlc" @@ -29,6 +31,7 @@ type Request struct { Model string Dimensions int Input Input + UserID string } type Input struct { @@ -176,21 +179,28 @@ func (r *Resolver) selectEmbeddingModel(ctx context.Context, req Request) (model return models.GetResponse{}, errors.New("models service not configured") } - // If no model specified and no provider specified, try to get default embedding model - if req.Model == "" && req.Provider == "" { - defaultModel, err := r.modelsService.GetByEnableAs(ctx, models.EnableAsEmbedding) - if err == nil { - // Found default model, check if it matches the type requirement - if req.Type == TypeMultimodal && !defaultModel.IsMultimodal { - // Default is text, but need multimodal - continue to search - } else if req.Type == TypeText && defaultModel.IsMultimodal { - // Default is multimodal, but need text - continue to search - } else { - // Default model matches requirements - return defaultModel, nil - } + // If no model specified and no provider specified, try to get per-user embedding model. + if req.Model == "" && req.Provider == "" && strings.TrimSpace(req.UserID) != "" { + modelID, err := r.loadUserEmbeddingModelID(ctx, req.UserID) + if err != nil { + return models.GetResponse{}, err + } + if modelID != "" { + selected, err := r.modelsService.GetByModelID(ctx, modelID) + if err != nil { + return models.GetResponse{}, fmt.Errorf("settings embedding model not found: %w", err) + } + if selected.Type != models.ModelTypeEmbedding { + return models.GetResponse{}, errors.New("settings embedding model is not an embedding model") + } + if req.Type == TypeMultimodal && !selected.IsMultimodal { + return models.GetResponse{}, errors.New("settings embedding model does not support multimodal") + } + if req.Type == TypeText && selected.IsMultimodal { + return models.GetResponse{}, errors.New("settings embedding model does not support text embeddings") + } + return selected, nil } - // No default model or doesn't match requirements, continue to search } var candidates []models.GetResponse @@ -246,3 +256,32 @@ func (r *Resolver) fetchProvider(ctx context.Context, providerID string) (sqlc.L copy(pgID.Bytes[:], parsed[:]) return r.queries.GetLlmProviderByID(ctx, pgID) } + +func (r *Resolver) loadUserEmbeddingModelID(ctx context.Context, userID string) (string, error) { + if r.queries == nil { + return "", nil + } + pgUserID, err := parseUUID(userID) + if err != nil { + return "", err + } + row, err := r.queries.GetSettingsByUserID(ctx, pgUserID) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return "", nil + } + return "", err + } + return strings.TrimSpace(row.EmbeddingModelID.String), nil +} + +func parseUUID(id string) (pgtype.UUID, error) { + parsed, err := uuid.Parse(id) + if err != nil { + return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err) + } + var pgID pgtype.UUID + pgID.Valid = true + copy(pgID.Bytes[:], parsed[:]) + return pgID, nil +} diff --git a/internal/handlers/embeddings.go b/internal/handlers/embeddings.go index e74a9fc1..4f4d411c 100644 --- a/internal/handlers/embeddings.go +++ b/internal/handlers/embeddings.go @@ -8,6 +8,7 @@ import ( "github.com/labstack/echo/v4" + "github.com/memohai/memoh/internal/auth" "github.com/memohai/memoh/internal/db/sqlc" "github.com/memohai/memoh/internal/embeddings" "github.com/memohai/memoh/internal/models" @@ -84,6 +85,12 @@ func (h *EmbeddingsHandler) Embed(c echo.Context) error { req.Input.ImageURL = strings.TrimSpace(req.Input.ImageURL) req.Input.VideoURL = strings.TrimSpace(req.Input.VideoURL) + userID := "" + if c.Get("user") != nil { + if value, err := auth.UserIDFromContext(c); err == nil { + userID = value + } + } result, err := h.resolver.Embed(c.Request().Context(), embeddings.Request{ Type: req.Type, Provider: req.Provider, @@ -94,6 +101,7 @@ func (h *EmbeddingsHandler) Embed(c echo.Context) error { ImageURL: req.Input.ImageURL, VideoURL: req.Input.VideoURL, }, + UserID: userID, }) if err != nil { message := err.Error() diff --git a/internal/handlers/models.go b/internal/handlers/models.go index 9fafbca6..0dbd8af4 100644 --- a/internal/handlers/models.go +++ b/internal/handlers/models.go @@ -4,21 +4,26 @@ import ( "log/slog" "net/http" "net/url" + "strings" "github.com/labstack/echo/v4" + "github.com/memohai/memoh/internal/auth" "github.com/memohai/memoh/internal/models" + "github.com/memohai/memoh/internal/settings" ) type ModelsHandler struct { - service *models.Service - logger *slog.Logger + service *models.Service + settingsService *settings.Service + logger *slog.Logger } -func NewModelsHandler(log *slog.Logger, service *models.Service) *ModelsHandler { +func NewModelsHandler(log *slog.Logger, service *models.Service, settingsService *settings.Service) *ModelsHandler { return &ModelsHandler{ - service: service, - logger: log.With(slog.String("handler", "models")), + service: service, + settingsService: settingsService, + logger: log.With(slog.String("handler", "models")), } } @@ -28,7 +33,7 @@ func (h *ModelsHandler) Register(e *echo.Echo) { group.GET("", h.List) group.GET("/:id", h.GetByID) group.GET("/model/:modelId", h.GetByModelID) - group.GET("/enable-as/:enableAs", h.GetByEnableAs) + group.POST("/enable", h.Enable) group.PUT("/:id", h.UpdateByID) group.PUT("/model/:modelId", h.UpdateByModelID) group.DELETE("/:id", h.DeleteByID) @@ -140,6 +145,67 @@ func (h *ModelsHandler) GetByModelID(c echo.Context) error { return c.JSON(http.StatusOK, resp) } +type EnableModelRequest struct { + As string `json:"as"` + ModelID string `json:"model_id"` +} + +// Enable godoc +// @Summary Enable model for chat/memory/embedding +// @Description Update the current user's settings to use the selected model +// @Tags models +// @Param payload body handlers.EnableModelRequest true "Enable model payload" +// @Success 200 {object} settings.Settings +// @Failure 400 {object} ErrorResponse +// @Failure 404 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /models/enable [post] +func (h *ModelsHandler) Enable(c echo.Context) error { + if h.settingsService == nil { + return echo.NewHTTPError(http.StatusInternalServerError, "settings service not configured") + } + userID, err := auth.UserIDFromContext(c) + if err != nil { + return err + } + var req EnableModelRequest + if err := c.Bind(&req); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + req.As = strings.ToLower(strings.TrimSpace(req.As)) + req.ModelID = strings.TrimSpace(req.ModelID) + if req.As == "" || req.ModelID == "" { + return echo.NewHTTPError(http.StatusBadRequest, "as and model_id are required") + } + if req.As != "chat" && req.As != "memory" && req.As != "embedding" { + return echo.NewHTTPError(http.StatusBadRequest, "as must be one of chat, memory, embedding") + } + model, err := h.service.GetByModelID(c.Request().Context(), req.ModelID) + if err != nil { + return echo.NewHTTPError(http.StatusNotFound, err.Error()) + } + if req.As == "embedding" && model.Type != models.ModelTypeEmbedding { + return echo.NewHTTPError(http.StatusBadRequest, "model is not an embedding model") + } + if (req.As == "chat" || req.As == "memory") && model.Type != models.ModelTypeChat { + return echo.NewHTTPError(http.StatusBadRequest, "model is not a chat model") + } + upsert := settings.UpsertRequest{} + switch req.As { + case "chat": + upsert.ChatModelID = req.ModelID + case "memory": + upsert.MemoryModelID = req.ModelID + case "embedding": + upsert.EmbeddingModelID = req.ModelID + } + resp, err := h.settingsService.Upsert(c.Request().Context(), userID, upsert) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.JSON(http.StatusOK, resp) +} + // UpdateByID godoc // @Summary Update model by internal ID // @Description Update a model configuration by its internal UUID @@ -252,29 +318,6 @@ func (h *ModelsHandler) DeleteByModelID(c echo.Context) error { return c.NoContent(http.StatusNoContent) } -// GetByEnableAs godoc -// @Summary Get model by enable_as -// @Description Get the model that is enabled for a specific purpose (chat, memory, embedding) -// @Tags models -// @Param enableAs path string true "Enable as value (chat, memory, embedding)" -// @Success 200 {object} models.GetResponse -// @Failure 400 {object} ErrorResponse -// @Failure 404 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /models/enable-as/{enableAs} [get] -func (h *ModelsHandler) GetByEnableAs(c echo.Context) error { - enableAs := c.Param("enableAs") - if enableAs == "" { - return echo.NewHTTPError(http.StatusBadRequest, "enableAs is required") - } - - resp, err := h.service.GetByEnableAs(c.Request().Context(), models.EnableAs(enableAs)) - if err != nil { - return echo.NewHTTPError(http.StatusNotFound, err.Error()) - } - return c.JSON(http.StatusOK, resp) -} - // Count godoc // @Summary Get model count // @Description Get the total count of models, optionally filtered by type diff --git a/internal/models/bootstrap.go b/internal/models/bootstrap.go deleted file mode 100644 index 4c302f7a..00000000 --- a/internal/models/bootstrap.go +++ /dev/null @@ -1,57 +0,0 @@ -package models - -import ( - "context" - "fmt" - "strings" - - "github.com/memohai/memoh/internal/db/sqlc" -) - -// SelectMemoryModel selects a chat model for memory operations. -func SelectMemoryModel(ctx context.Context, modelsService *Service, queries *sqlc.Queries) (GetResponse, sqlc.LlmProvider, error) { - // First try to get the memory-enabled model. - memoryModel, err := modelsService.GetByEnableAs(ctx, EnableAsMemory) - if err == nil { - provider, err := FetchProviderByID(ctx, queries, memoryModel.LlmProviderID) - if err != nil { - return GetResponse{}, sqlc.LlmProvider{}, err - } - return memoryModel, provider, nil - } - - // Fallback to chat model. - chatModel, err := modelsService.GetByEnableAs(ctx, EnableAsChat) - if err == nil { - provider, err := FetchProviderByID(ctx, queries, chatModel.LlmProviderID) - if err != nil { - return GetResponse{}, sqlc.LlmProvider{}, err - } - return chatModel, provider, nil - } - - // If no enabled models, try to find any chat model. - candidates, err := modelsService.ListByType(ctx, ModelTypeChat) - if err != nil || len(candidates) == 0 { - return GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("no chat models available for memory operations") - } - - selected := candidates[0] - provider, err := FetchProviderByID(ctx, queries, selected.LlmProviderID) - if err != nil { - return GetResponse{}, sqlc.LlmProvider{}, err - } - return selected, provider, nil -} - -// FetchProviderByID fetches a provider by ID. -func FetchProviderByID(ctx context.Context, queries *sqlc.Queries, providerID string) (sqlc.LlmProvider, error) { - if strings.TrimSpace(providerID) == "" { - return sqlc.LlmProvider{}, fmt.Errorf("provider id missing") - } - parsed, err := parseUUID(providerID) - if err != nil { - return sqlc.LlmProvider{}, err - } - return queries.GetLlmProviderByID(ctx, parsed) -} diff --git a/internal/models/models.go b/internal/models/models.go index 92ddd546..2654bd8b 100644 --- a/internal/models/models.go +++ b/internal/models/models.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log/slog" + "strings" "github.com/google/uuid" "github.com/jackc/pgx/v5/pgtype" @@ -31,13 +32,6 @@ func (s *Service) Create(ctx context.Context, req AddRequest) (AddResponse, erro return AddResponse{}, fmt.Errorf("validation failed: %w", err) } - // If enable_as is set, clear any existing model with the same enable_as - if model.EnableAs != nil { - if err := s.queries.ClearEnableAs(ctx, pgtype.Text{String: string(*model.EnableAs), Valid: true}); err != nil { - return AddResponse{}, fmt.Errorf("failed to clear existing enable_as: %w", err) - } - } - // Convert to sqlc params llmProviderID, err := parseUUID(model.LlmProviderID) if err != nil { @@ -61,11 +55,6 @@ func (s *Service) Create(ctx context.Context, req AddRequest) (AddResponse, erro params.Dimensions = pgtype.Int4{Int32: int32(model.Dimensions), Valid: true} } - // Handle optional enable_as field - if model.EnableAs != nil { - params.EnableAs = pgtype.Text{String: string(*model.EnableAs), Valid: true} - } - created, err := s.queries.CreateModel(ctx, params) if err != nil { return AddResponse{}, fmt.Errorf("failed to create model: %w", err) @@ -166,13 +155,6 @@ func (s *Service) UpdateByID(ctx context.Context, id string, req UpdateRequest) return GetResponse{}, fmt.Errorf("validation failed: %w", err) } - // If enable_as is being set, clear any existing model with the same enable_as - if model.EnableAs != nil { - if err := s.queries.ClearEnableAs(ctx, pgtype.Text{String: string(*model.EnableAs), Valid: true}); err != nil { - return GetResponse{}, fmt.Errorf("failed to clear existing enable_as: %w", err) - } - } - params := sqlc.UpdateModelParams{ ID: uuid, IsMultimodal: model.IsMultimodal, @@ -193,11 +175,6 @@ func (s *Service) UpdateByID(ctx context.Context, id string, req UpdateRequest) params.Dimensions = pgtype.Int4{Int32: int32(model.Dimensions), Valid: true} } - // Handle optional enable_as field - if model.EnableAs != nil { - params.EnableAs = pgtype.Text{String: string(*model.EnableAs), Valid: true} - } - updated, err := s.queries.UpdateModel(ctx, params) if err != nil { return GetResponse{}, fmt.Errorf("failed to update model: %w", err) @@ -217,13 +194,6 @@ func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req Updat return GetResponse{}, fmt.Errorf("validation failed: %w", err) } - // If enable_as is being set, clear any existing model with the same enable_as - if model.EnableAs != nil { - if err := s.queries.ClearEnableAs(ctx, pgtype.Text{String: string(*model.EnableAs), Valid: true}); err != nil { - return GetResponse{}, fmt.Errorf("failed to clear existing enable_as: %w", err) - } - } - params := sqlc.UpdateModelByModelIDParams{ ModelID: modelID, IsMultimodal: model.IsMultimodal, @@ -244,11 +214,6 @@ func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req Updat params.Dimensions = pgtype.Int4{Int32: int32(model.Dimensions), Valid: true} } - // Handle optional enable_as field - if model.EnableAs != nil { - params.EnableAs = pgtype.Text{String: string(*model.EnableAs), Valid: true} - } - updated, err := s.queries.UpdateModelByModelID(ctx, params) if err != nil { return GetResponse{}, fmt.Errorf("failed to update model: %w", err) @@ -306,20 +271,6 @@ func (s *Service) CountByType(ctx context.Context, modelType ModelType) (int64, return count, nil } -// GetByEnableAs retrieves the model that has the specified enable_as value -func (s *Service) GetByEnableAs(ctx context.Context, enableAs EnableAs) (GetResponse, error) { - if enableAs != EnableAsChat && enableAs != EnableAsMemory && enableAs != EnableAsEmbedding { - return GetResponse{}, fmt.Errorf("invalid enable_as value: %s", enableAs) - } - - dbModel, err := s.queries.GetModelByEnableAs(ctx, pgtype.Text{String: string(enableAs), Valid: true}) - if err != nil { - return GetResponse{}, fmt.Errorf("failed to get model by enable_as: %w", err) - } - - return convertToGetResponse(dbModel), nil -} - // Helper functions func parseUUID(id string) (pgtype.UUID, error) { @@ -357,11 +308,6 @@ func convertToGetResponse(dbModel sqlc.Model) GetResponse { resp.Model.Dimensions = int(dbModel.Dimensions.Int32) } - if dbModel.EnableAs.Valid { - enableAs := EnableAs(dbModel.EnableAs.String) - resp.Model.EnableAs = &enableAs - } - return resp } @@ -399,3 +345,32 @@ func uuidStringFromPgUUID(value pgtype.UUID) (string, bool) { } return id.String(), true } + +// SelectMemoryModel selects a chat model for memory operations. +func SelectMemoryModel(ctx context.Context, modelsService *Service, queries *sqlc.Queries) (GetResponse, sqlc.LlmProvider, error) { + if modelsService == nil { + return GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("models service not configured") + } + candidates, err := modelsService.ListByType(ctx, ModelTypeChat) + if err != nil || len(candidates) == 0 { + return GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("no chat models available for memory operations") + } + selected := candidates[0] + provider, err := FetchProviderByID(ctx, queries, selected.LlmProviderID) + if err != nil { + return GetResponse{}, sqlc.LlmProvider{}, err + } + return selected, provider, nil +} + +// FetchProviderByID fetches a provider by ID. +func FetchProviderByID(ctx context.Context, queries *sqlc.Queries, providerID string) (sqlc.LlmProvider, error) { + if strings.TrimSpace(providerID) == "" { + return sqlc.LlmProvider{}, fmt.Errorf("provider id missing") + } + parsed, err := parseUUID(providerID) + if err != nil { + return sqlc.LlmProvider{}, err + } + return queries.GetLlmProviderByID(ctx, parsed) +} diff --git a/internal/models/types.go b/internal/models/types.go index 7bed6c09..cdd5a66f 100644 --- a/internal/models/types.go +++ b/internal/models/types.go @@ -13,14 +13,6 @@ const ( ModelTypeEmbedding ModelType = "embedding" ) -type EnableAs string - -const ( - EnableAsChat EnableAs = "chat" - EnableAsMemory EnableAs = "memory" - EnableAsEmbedding EnableAs = "embedding" -) - type ClientType string const ( @@ -41,7 +33,6 @@ type Model struct { IsMultimodal bool `json:"is_multimodal"` Type ModelType `json:"type"` Dimensions int `json:"dimensions"` - EnableAs *EnableAs `json:"enable_as,omitempty"` } func (m *Model) Validate() error { @@ -60,21 +51,7 @@ func (m *Model) Validate() error { if m.Type == ModelTypeEmbedding && m.Dimensions <= 0 { return errors.New("dimensions must be greater than 0") } - - // Validate enable_as based on type - if m.EnableAs != nil { - switch m.Type { - case ModelTypeEmbedding: - if *m.EnableAs != EnableAsEmbedding { - return errors.New("embedding models can only have enable_as set to 'embedding'") - } - case ModelTypeChat: - if *m.EnableAs != EnableAsChat && *m.EnableAs != EnableAsMemory { - return errors.New("chat models can only have enable_as set to 'chat' or 'memory'") - } - } - } - + return nil } diff --git a/internal/settings/service.go b/internal/settings/service.go index db4f9210..5e4c293e 100644 --- a/internal/settings/service.go +++ b/internal/settings/service.go @@ -35,6 +35,9 @@ func (s *Service) Get(ctx context.Context, userID string) (Settings, error) { if err != nil { if errors.Is(err, pgx.ErrNoRows) { return Settings{ + ChatModelID: "", + MemoryModelID: "", + EmbeddingModelID: "", MaxContextLoadTime: DefaultMaxContextLoadTime, Language: DefaultLanguage, }, nil @@ -54,6 +57,9 @@ func (s *Service) Upsert(ctx context.Context, userID string, req UpsertRequest) } current := Settings{ + ChatModelID: "", + MemoryModelID: "", + EmbeddingModelID: "", MaxContextLoadTime: DefaultMaxContextLoadTime, Language: DefaultLanguage, } @@ -65,6 +71,15 @@ func (s *Service) Upsert(ctx context.Context, userID string, req UpsertRequest) current = normalizeUserSetting(existing) } + if value := strings.TrimSpace(req.ChatModelID); value != "" { + current.ChatModelID = value + } + if value := strings.TrimSpace(req.MemoryModelID); value != "" { + current.MemoryModelID = value + } + if value := strings.TrimSpace(req.EmbeddingModelID); value != "" { + current.EmbeddingModelID = value + } if req.MaxContextLoadTime != nil && *req.MaxContextLoadTime > 0 { current.MaxContextLoadTime = *req.MaxContextLoadTime } @@ -74,6 +89,9 @@ func (s *Service) Upsert(ctx context.Context, userID string, req UpsertRequest) _, err = s.queries.UpsertSettings(ctx, sqlc.UpsertSettingsParams{ UserID: pgID, + ChatModelID: pgtype.Text{String: current.ChatModelID, Valid: current.ChatModelID != ""}, + MemoryModelID: pgtype.Text{String: current.MemoryModelID, Valid: current.MemoryModelID != ""}, + EmbeddingModelID: pgtype.Text{String: current.EmbeddingModelID, Valid: current.EmbeddingModelID != ""}, MaxContextLoadTime: int32(current.MaxContextLoadTime), Language: current.Language, }) @@ -96,6 +114,9 @@ func (s *Service) Delete(ctx context.Context, userID string) error { func normalizeUserSetting(row sqlc.UserSetting) Settings { settings := Settings{ + ChatModelID: strings.TrimSpace(row.ChatModelID.String), + MemoryModelID: strings.TrimSpace(row.MemoryModelID.String), + EmbeddingModelID: strings.TrimSpace(row.EmbeddingModelID.String), MaxContextLoadTime: int(row.MaxContextLoadTime), Language: strings.TrimSpace(row.Language), } diff --git a/internal/settings/types.go b/internal/settings/types.go index 808ff7ce..c2158e58 100644 --- a/internal/settings/types.go +++ b/internal/settings/types.go @@ -6,12 +6,17 @@ const ( ) type Settings struct { + ChatModelID string `json:"chat_model_id"` + MemoryModelID string `json:"memory_model_id"` + EmbeddingModelID string `json:"embedding_model_id"` MaxContextLoadTime int `json:"max_context_load_time"` Language string `json:"language"` } type UpsertRequest struct { + ChatModelID string `json:"chat_model_id,omitempty"` + MemoryModelID string `json:"memory_model_id,omitempty"` + EmbeddingModelID string `json:"embedding_model_id,omitempty"` MaxContextLoadTime *int `json:"max_context_load_time,omitempty"` Language string `json:"language,omitempty"` } - diff --git a/packages/cli/src/cli/index.ts b/packages/cli/src/cli/index.ts index 70ecc154..c4efda61 100755 --- a/packages/cli/src/cli/index.ts +++ b/packages/cli/src/cli/index.ts @@ -34,7 +34,6 @@ type Model = { is_multimodal: boolean type: 'chat' | 'embedding' dimensions?: number - enable_as?: 'chat' | 'memory' | 'embedding' } type ModelResponse = Partial & { @@ -61,6 +60,9 @@ type ScheduleListResponse = { } type Settings = { + chat_model_id: string + memory_model_id: string + embedding_model_id: string max_context_load_time: number language: string } @@ -105,7 +107,6 @@ const getModelId = (item: ModelResponse) => item.model?.model_id ?? item.model_i const getProviderId = (item: ModelResponse) => item.model?.llm_provider_id ?? item.llm_provider_id ?? '' const getModelType = (item: ModelResponse) => item.model?.type ?? item.type ?? 'chat' const getModelMultimodal = (item: ModelResponse) => item.model?.is_multimodal ?? item.is_multimodal ?? false -const getModelEnableAs = (item: ModelResponse) => item.model?.enable_as ?? item.enable_as const renderProvidersTable = (providers: Provider[], models: ModelResponse[]) => { const rows: string[][] = [['Provider', 'Type', 'Base URL', 'Models']] @@ -125,14 +126,13 @@ const renderProvidersTable = (providers: Provider[], models: ModelResponse[]) => const renderModelsTable = (models: ModelResponse[], providers: Provider[]) => { const providerMap = new Map(providers.map(p => [p.id, p.name])) - const rows: string[][] = [['Model ID', 'Type', 'Provider', 'Multimodal', 'Enable As']] + const rows: string[][] = [['Model ID', 'Type', 'Provider', 'Multimodal']] for (const item of models) { rows.push([ getModelId(item), getModelType(item), providerMap.get(getProviderId(item)) ?? getProviderId(item), getModelMultimodal(item) ? 'yes' : 'no', - getModelEnableAs(item) ?? '-', ]) } return table(rows) @@ -229,6 +229,9 @@ configCmd.action(async () => { if (!token?.access_token) return try { const settings = await apiRequest('/settings', {}, token) + console.log(`chat_model_id = "${settings.chat_model_id || ''}"`) + console.log(`memory_model_id = "${settings.memory_model_id || ''}"`) + console.log(`embedding_model_id = "${settings.embedding_model_id || ''}"`) console.log(`max_context_load_time = ${settings.max_context_load_time}`) console.log(`language = "${settings.language}"`) } catch (err: unknown) { @@ -241,6 +244,9 @@ configCmd .description('Update config') .option('--host ') .option('--port ') + .option('--chat_model_id ') + .option('--memory_model_id ') + .option('--embedding_model_id ') .option('--max_context_load_time ') .option('--language ') .action(async (opts) => { @@ -257,7 +263,11 @@ configCmd maxContextLoadTime = parsed } let language = opts.language - const hasSettingsInput = opts.max_context_load_time !== undefined || opts.language !== undefined + const hasSettingsInput = opts.max_context_load_time !== undefined + || opts.language !== undefined + || opts.chat_model_id !== undefined + || opts.memory_model_id !== undefined + || opts.embedding_model_id !== undefined const hasConfigInput = Boolean(host || port) if (!hasConfigInput && !hasSettingsInput) { @@ -282,6 +292,9 @@ configCmd language = String(language).trim() } const payload: Partial = {} + if (opts.chat_model_id) payload.chat_model_id = String(opts.chat_model_id).trim() + if (opts.memory_model_id) payload.memory_model_id = String(opts.memory_model_id).trim() + if (opts.embedding_model_id) payload.embedding_model_id = String(opts.embedding_model_id).trim() if (maxContextLoadTime !== undefined) payload.max_context_load_time = maxContextLoadTime if (language) payload.language = language const token = ensureAuth() @@ -393,7 +406,6 @@ model .option('--type ') .option('--dimensions ') .option('--multimodal', 'Is multimodal') - .option('--enable_as ') .action(async (opts) => { const token = ensureAuth() const providers = await apiRequest('/providers', {}, token) @@ -438,7 +450,6 @@ model is_multimodal: isMultimodal, type: modelType, dimensions, - enable_as: opts.enable_as, } const spinner = ora('Creating model...').start() try { @@ -472,10 +483,9 @@ model model .command('enable') - .description('Enable model') - .option('--as ') - .option('--model ') - .option('--provider ') + .description('Enable model for chat/memory/embedding') + .option('--as ') + .option('--model ') .action(async (opts) => { const token = ensureAuth() let enableAs = opts.as @@ -488,57 +498,40 @@ model }]) enableAs = answer.enable_as } - const [providers, models] = await Promise.all([ - apiRequest('/providers', {}, token), - apiRequest('/models', {}, token), - ]) - let providerName = opts.provider - if (!providerName) { - const answer = await inquirer.prompt([{ - type: 'list', - name: 'provider', - message: 'Select provider:', - choices: providers.map(p => p.name), - }]) - providerName = answer.provider - } - const provider = providers.find(p => p.name === providerName) - if (!provider) { - console.log(chalk.red('Provider not found.')) + enableAs = String(enableAs).trim() + if (!['chat', 'memory', 'embedding'].includes(enableAs)) { + console.log(chalk.red('Enable as must be one of chat, memory, embedding.')) process.exit(1) } - let modelName = opts.model - if (!modelName) { - const providerModels = models - .filter(m => getProviderId(m) === provider.id) - .map(m => getModelId(m)) - if (providerModels.length === 0) { - console.log(chalk.red('No models found for selected provider.')) - process.exit(1) - } + const models = await apiRequest('/models', {}, token) + const requiredType = enableAs === 'embedding' ? 'embedding' : 'chat' + const candidates = models.filter(m => getModelType(m) === requiredType) + if (candidates.length === 0) { + console.log(chalk.red(`No ${requiredType} models available.`)) + process.exit(1) + } + let modelId = opts.model + if (!modelId) { const answer = await inquirer.prompt([{ type: 'list', name: 'model', message: 'Select model:', - choices: providerModels, + choices: candidates.map(m => getModelId(m)), }]) - modelName = answer.model + modelId = answer.model } - const current = models.find(m => getModelId(m) === modelName && getProviderId(m) === provider.id) - ?? await apiRequest(`/models/model/${encodeURIComponent(modelName)}`, {}, token) - const modelPayload = current.model - ? { ...current.model, model_id: current.model.model_id } - : { ...current, model_id: current.model_id ?? modelName } - const payload = { - ...modelPayload, - enable_as: enableAs, + const selected = candidates.find(m => getModelId(m) === modelId) + if (!selected) { + console.log(chalk.red('Selected model not found.')) + process.exit(1) } - const spinner = ora('Updating model...').start() + const payload: Partial = {} + if (enableAs === 'chat') payload.chat_model_id = getModelId(selected) + if (enableAs === 'memory') payload.memory_model_id = getModelId(selected) + if (enableAs === 'embedding') payload.embedding_model_id = getModelId(selected) + const spinner = ora('Updating settings...').start() try { - await apiRequest(`/models/model/${encodeURIComponent(modelName)}`, { - method: 'PUT', - body: JSON.stringify(payload), - }, token) + await apiRequest('/settings', { method: 'PUT', body: JSON.stringify(payload) }, token) spinner.succeed('Model enabled') } catch (err: unknown) { spinner.fail(getErrorMessage(err) || 'Failed to enable model')