feat: move default model into user settings

This commit is contained in:
Acbox
2026-02-02 01:39:21 +08:00
parent ed5f41a87e
commit c731e0ca1d
20 changed files with 793 additions and 811 deletions
+3 -3
View File
@@ -11,20 +11,20 @@ import (
"github.com/memohai/memoh/internal/chat"
"github.com/memohai/memoh/internal/channel"
"github.com/memohai/memoh/internal/config"
"github.com/memohai/memoh/internal/logger"
ctr "github.com/memohai/memoh/internal/containerd"
"github.com/memohai/memoh/internal/db"
dbsqlc "github.com/memohai/memoh/internal/db/sqlc"
"github.com/memohai/memoh/internal/embeddings"
"github.com/memohai/memoh/internal/handlers"
"github.com/memohai/memoh/internal/history"
"github.com/memohai/memoh/internal/logger"
"github.com/memohai/memoh/internal/mcp"
"github.com/memohai/memoh/internal/memory"
"github.com/memohai/memoh/internal/models"
"github.com/memohai/memoh/internal/providers"
"github.com/memohai/memoh/internal/schedule"
"github.com/memohai/memoh/internal/settings"
"github.com/memohai/memoh/internal/server"
"github.com/memohai/memoh/internal/settings"
"github.com/memohai/memoh/internal/subagent"
"github.com/memohai/memoh/internal/version"
@@ -177,9 +177,9 @@ func main() {
// Initialize providers and models handlers
providersService := providers.NewService(logger.L, queries)
providersHandler := handlers.NewProvidersHandler(logger.L, providersService)
modelsHandler := handlers.NewModelsHandler(logger.L, modelsService)
settingsService := settings.NewService(logger.L, queries)
settingsHandler := handlers.NewSettingsHandler(logger.L, settingsService)
modelsHandler := handlers.NewModelsHandler(logger.L, modelsService, settingsService)
historyService := history.NewService(logger.L, queries)
historyHandler := handlers.NewHistoryHandler(logger.L, historyService)
channelService := channel.NewService(queries)
+3 -7
View File
@@ -79,14 +79,9 @@ CREATE TABLE IF NOT EXISTS models (
dimensions INTEGER,
is_multimodal BOOLEAN NOT NULL DEFAULT false,
type TEXT NOT NULL DEFAULT 'chat',
enable_as TEXT,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
CONSTRAINT models_model_id_unique UNIQUE (model_id),
CONSTRAINT models_enable_as_check CHECK (
(type = 'embedding' AND (enable_as = 'embedding' OR enable_as IS NULL)) OR
(type = 'chat' AND (enable_as IN ('chat', 'memory') OR enable_as IS NULL))
),
CONSTRAINT models_type_check CHECK (type IN ('chat', 'embedding')),
CONSTRAINT models_dimensions_check CHECK (type != 'embedding' OR dimensions IS NOT NULL)
);
@@ -104,8 +99,6 @@ CREATE TABLE IF NOT EXISTS model_variants (
CREATE INDEX IF NOT EXISTS idx_model_variants_model_uuid ON model_variants(model_uuid);
CREATE INDEX IF NOT EXISTS idx_model_variants_variant_id ON model_variants(variant_id);
CREATE UNIQUE INDEX IF NOT EXISTS idx_models_enable_as_unique ON models(enable_as) WHERE enable_as IS NOT NULL;
CREATE INDEX IF NOT EXISTS idx_snapshots_container_id ON snapshots(container_id);
CREATE INDEX IF NOT EXISTS idx_snapshots_parent_id ON snapshots(parent_snapshot_id);
@@ -144,6 +137,9 @@ CREATE INDEX IF NOT EXISTS idx_history_timestamp ON history(timestamp);
CREATE TABLE IF NOT EXISTS user_settings (
user_id UUID PRIMARY KEY REFERENCES users(id) ON DELETE CASCADE,
chat_model_id TEXT,
memory_model_id TEXT,
embedding_model_id TEXT,
max_context_load_time INTEGER NOT NULL DEFAULT 1440,
language TEXT NOT NULL DEFAULT 'Same as user input'
);
+2 -12
View File
@@ -46,15 +46,14 @@ SELECT COUNT(*) FROM llm_providers;
SELECT COUNT(*) FROM llm_providers WHERE client_type = sqlc.arg(client_type);
-- name: CreateModel :one
INSERT INTO models (model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as)
INSERT INTO models (model_id, name, llm_provider_id, dimensions, is_multimodal, type)
VALUES (
sqlc.arg(model_id),
sqlc.arg(name),
sqlc.arg(llm_provider_id),
sqlc.arg(dimensions),
sqlc.arg(is_multimodal),
sqlc.arg(type),
sqlc.arg(enable_as)
sqlc.arg(type)
)
RETURNING *;
@@ -87,7 +86,6 @@ SET
dimensions = sqlc.arg(dimensions),
is_multimodal = sqlc.arg(is_multimodal),
type = sqlc.arg(type),
enable_as = sqlc.arg(enable_as),
updated_at = now()
WHERE id = sqlc.arg(id)
RETURNING *;
@@ -100,7 +98,6 @@ SET
dimensions = sqlc.arg(dimensions),
is_multimodal = sqlc.arg(is_multimodal),
type = sqlc.arg(type),
enable_as = sqlc.arg(enable_as),
updated_at = now()
WHERE model_id = sqlc.arg(model_id)
RETURNING *;
@@ -117,13 +114,6 @@ SELECT COUNT(*) FROM models;
-- name: CountModelsByType :one
SELECT COUNT(*) FROM models WHERE type = sqlc.arg(type);
-- name: GetModelByEnableAs :one
SELECT * FROM models WHERE enable_as = sqlc.arg(enable_as) LIMIT 1;
-- name: ClearEnableAs :exec
UPDATE models
SET enable_as = NULL, updated_at = now()
WHERE enable_as = sqlc.arg(enable_as);
-- name: CreateModelVariant :one
INSERT INTO model_variants (model_uuid, variant_id, weight, metadata)
+7 -4
View File
@@ -1,15 +1,18 @@
-- name: GetSettingsByUserID :one
SELECT user_id, max_context_load_time, language
SELECT user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language
FROM user_settings
WHERE user_id = $1;
-- name: UpsertSettings :one
INSERT INTO user_settings (user_id, max_context_load_time, language)
VALUES ($1, $2, $3)
INSERT INTO user_settings (user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (user_id) DO UPDATE SET
chat_model_id = EXCLUDED.chat_model_id,
memory_model_id = EXCLUDED.memory_model_id,
embedding_model_id = EXCLUDED.embedding_model_id,
max_context_load_time = EXCLUDED.max_context_load_time,
language = EXCLUDED.language
RETURNING user_id, max_context_load_time, language;
RETURNING user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language;
-- name: DeleteSettingsByUserID :exec
DELETE FROM user_settings
+156 -165
View File
@@ -739,7 +739,7 @@ const docTemplate = `{
},
"/memory/add": {
"post": {
"description": "Add memory for a user via memory",
"description": "Add memory for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id).",
"tags": [
"memory"
],
@@ -751,7 +751,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/memory.AddRequest"
"$ref": "#/definitions/handlers.memoryAddPayload"
}
}
],
@@ -779,7 +779,7 @@ const docTemplate = `{
},
"/memory/embed": {
"post": {
"description": "Embed text or multimodal input and upsert into memory store",
"description": "Embed text or multimodal input and upsert into memory store. Auth: Bearer JWT determines user_id (sub or user_id).",
"tags": [
"memory"
],
@@ -791,7 +791,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/memory.EmbedUpsertRequest"
"$ref": "#/definitions/handlers.memoryEmbedUpsertPayload"
}
}
],
@@ -819,18 +819,12 @@ const docTemplate = `{
},
"/memory/memories": {
"get": {
"description": "List memories for a user via memory",
"description": "List memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id).",
"tags": [
"memory"
],
"summary": "List memories",
"parameters": [
{
"type": "string",
"description": "User ID",
"name": "user_id",
"in": "query"
},
{
"type": "string",
"description": "Agent ID",
@@ -872,7 +866,7 @@ const docTemplate = `{
}
},
"delete": {
"description": "Delete all memories for a user via memory",
"description": "Delete all memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id).",
"tags": [
"memory"
],
@@ -884,7 +878,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/memory.DeleteAllRequest"
"$ref": "#/definitions/handlers.memoryDeleteAllPayload"
}
}
],
@@ -912,7 +906,7 @@ const docTemplate = `{
},
"/memory/memories/{memoryId}": {
"get": {
"description": "Get a memory by ID via memory",
"description": "Get a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id).",
"tags": [
"memory"
],
@@ -948,7 +942,7 @@ const docTemplate = `{
}
},
"delete": {
"description": "Delete a memory by ID via memory",
"description": "Delete a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id).",
"tags": [
"memory"
],
@@ -986,7 +980,7 @@ const docTemplate = `{
},
"/memory/search": {
"post": {
"description": "Search memories for a user via memory",
"description": "Search memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id).",
"tags": [
"memory"
],
@@ -998,7 +992,7 @@ const docTemplate = `{
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/memory.SearchRequest"
"$ref": "#/definitions/handlers.memorySearchPayload"
}
}
],
@@ -1026,7 +1020,7 @@ const docTemplate = `{
},
"/memory/update": {
"post": {
"description": "Update a memory by ID via memory",
"description": "Update a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id).",
"tags": [
"memory"
],
@@ -1185,27 +1179,29 @@ const docTemplate = `{
}
}
},
"/models/enable-as/{enableAs}": {
"get": {
"description": "Get the model that is enabled for a specific purpose (chat, memory, embedding)",
"/models/enable": {
"post": {
"description": "Update the current user's settings to use the selected model",
"tags": [
"models"
],
"summary": "Get model by enable_as",
"summary": "Enable model for chat/memory/embedding",
"parameters": [
{
"type": "string",
"description": "Enable as value (chat, memory, embedding)",
"name": "enableAs",
"in": "path",
"required": true
"description": "Enable model payload",
"name": "payload",
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.EnableModelRequest"
}
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/models.GetResponse"
"$ref": "#/definitions/settings.Settings"
}
},
"400": {
@@ -2850,6 +2846,17 @@ const docTemplate = `{
}
}
},
"handlers.EnableModelRequest": {
"type": "object",
"properties": {
"as": {
"type": "string"
},
"model_id": {
"type": "string"
}
}
},
"handlers.ErrorResponse": {
"type": "object",
"properties": {
@@ -2996,6 +3003,109 @@ const docTemplate = `{
}
}
},
"handlers.memoryAddPayload": {
"type": "object",
"properties": {
"agent_id": {
"type": "string"
},
"filters": {
"type": "object",
"additionalProperties": true
},
"infer": {
"type": "boolean"
},
"message": {
"type": "string"
},
"messages": {
"type": "array",
"items": {
"$ref": "#/definitions/memory.Message"
}
},
"metadata": {
"type": "object",
"additionalProperties": true
},
"run_id": {
"type": "string"
}
}
},
"handlers.memoryDeleteAllPayload": {
"type": "object",
"properties": {
"agent_id": {
"type": "string"
},
"run_id": {
"type": "string"
}
}
},
"handlers.memoryEmbedUpsertPayload": {
"type": "object",
"properties": {
"agent_id": {
"type": "string"
},
"filters": {
"type": "object",
"additionalProperties": true
},
"input": {
"$ref": "#/definitions/memory.EmbedInput"
},
"metadata": {
"type": "object",
"additionalProperties": true
},
"model": {
"type": "string"
},
"provider": {
"type": "string"
},
"run_id": {
"type": "string"
},
"source": {
"type": "string"
},
"type": {
"type": "string"
}
}
},
"handlers.memorySearchPayload": {
"type": "object",
"properties": {
"agent_id": {
"type": "string"
},
"filters": {
"type": "object",
"additionalProperties": true
},
"limit": {
"type": "integer"
},
"query": {
"type": "string"
},
"run_id": {
"type": "string"
},
"sources": {
"type": "array",
"items": {
"type": "string"
}
}
}
},
"handlers.skillsOpResponse": {
"type": "object",
"properties": {
@@ -3060,54 +3170,6 @@ const docTemplate = `{
}
}
},
"memory.AddRequest": {
"type": "object",
"properties": {
"agent_id": {
"type": "string"
},
"filters": {
"type": "object",
"additionalProperties": true
},
"infer": {
"type": "boolean"
},
"message": {
"type": "string"
},
"messages": {
"type": "array",
"items": {
"$ref": "#/definitions/memory.Message"
}
},
"metadata": {
"type": "object",
"additionalProperties": true
},
"run_id": {
"type": "string"
},
"user_id": {
"type": "string"
}
}
},
"memory.DeleteAllRequest": {
"type": "object",
"properties": {
"agent_id": {
"type": "string"
},
"run_id": {
"type": "string"
},
"user_id": {
"type": "string"
}
}
},
"memory.DeleteResponse": {
"type": "object",
"properties": {
@@ -3130,43 +3192,6 @@ const docTemplate = `{
}
}
},
"memory.EmbedUpsertRequest": {
"type": "object",
"properties": {
"agent_id": {
"type": "string"
},
"filters": {
"type": "object",
"additionalProperties": true
},
"input": {
"$ref": "#/definitions/memory.EmbedInput"
},
"metadata": {
"type": "object",
"additionalProperties": true
},
"model": {
"type": "string"
},
"provider": {
"type": "string"
},
"run_id": {
"type": "string"
},
"source": {
"type": "string"
},
"type": {
"type": "string"
},
"user_id": {
"type": "string"
}
}
},
"memory.EmbedUpsertResponse": {
"type": "object",
"properties": {
@@ -3231,36 +3256,6 @@ const docTemplate = `{
}
}
},
"memory.SearchRequest": {
"type": "object",
"properties": {
"agent_id": {
"type": "string"
},
"filters": {
"type": "object",
"additionalProperties": true
},
"limit": {
"type": "integer"
},
"query": {
"type": "string"
},
"run_id": {
"type": "string"
},
"sources": {
"type": "array",
"items": {
"type": "string"
}
},
"user_id": {
"type": "string"
}
}
},
"memory.SearchResponse": {
"type": "object",
"properties": {
@@ -3293,9 +3288,6 @@ const docTemplate = `{
"dimensions": {
"type": "integer"
},
"enable_as": {
"$ref": "#/definitions/models.EnableAs"
},
"is_multimodal": {
"type": "boolean"
},
@@ -3332,28 +3324,12 @@ const docTemplate = `{
}
}
},
"models.EnableAs": {
"type": "string",
"enum": [
"chat",
"memory",
"embedding"
],
"x-enum-varnames": [
"EnableAsChat",
"EnableAsMemory",
"EnableAsEmbedding"
]
},
"models.GetResponse": {
"type": "object",
"properties": {
"dimensions": {
"type": "integer"
},
"enable_as": {
"$ref": "#/definitions/models.EnableAs"
},
"is_multimodal": {
"type": "boolean"
},
@@ -3388,9 +3364,6 @@ const docTemplate = `{
"dimensions": {
"type": "integer"
},
"enable_as": {
"$ref": "#/definitions/models.EnableAs"
},
"is_multimodal": {
"type": "boolean"
},
@@ -3609,22 +3582,40 @@ const docTemplate = `{
"settings.Settings": {
"type": "object",
"properties": {
"chat_model_id": {
"type": "string"
},
"embedding_model_id": {
"type": "string"
},
"language": {
"type": "string"
},
"max_context_load_time": {
"type": "integer"
},
"memory_model_id": {
"type": "string"
}
}
},
"settings.UpsertRequest": {
"type": "object",
"properties": {
"chat_model_id": {
"type": "string"
},
"embedding_model_id": {
"type": "string"
},
"language": {
"type": "string"
},
"max_context_load_time": {
"type": "integer"
},
"memory_model_id": {
"type": "string"
}
}
},
+156 -165
View File
@@ -730,7 +730,7 @@
},
"/memory/add": {
"post": {
"description": "Add memory for a user via memory",
"description": "Add memory for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id).",
"tags": [
"memory"
],
@@ -742,7 +742,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/memory.AddRequest"
"$ref": "#/definitions/handlers.memoryAddPayload"
}
}
],
@@ -770,7 +770,7 @@
},
"/memory/embed": {
"post": {
"description": "Embed text or multimodal input and upsert into memory store",
"description": "Embed text or multimodal input and upsert into memory store. Auth: Bearer JWT determines user_id (sub or user_id).",
"tags": [
"memory"
],
@@ -782,7 +782,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/memory.EmbedUpsertRequest"
"$ref": "#/definitions/handlers.memoryEmbedUpsertPayload"
}
}
],
@@ -810,18 +810,12 @@
},
"/memory/memories": {
"get": {
"description": "List memories for a user via memory",
"description": "List memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id).",
"tags": [
"memory"
],
"summary": "List memories",
"parameters": [
{
"type": "string",
"description": "User ID",
"name": "user_id",
"in": "query"
},
{
"type": "string",
"description": "Agent ID",
@@ -863,7 +857,7 @@
}
},
"delete": {
"description": "Delete all memories for a user via memory",
"description": "Delete all memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id).",
"tags": [
"memory"
],
@@ -875,7 +869,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/memory.DeleteAllRequest"
"$ref": "#/definitions/handlers.memoryDeleteAllPayload"
}
}
],
@@ -903,7 +897,7 @@
},
"/memory/memories/{memoryId}": {
"get": {
"description": "Get a memory by ID via memory",
"description": "Get a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id).",
"tags": [
"memory"
],
@@ -939,7 +933,7 @@
}
},
"delete": {
"description": "Delete a memory by ID via memory",
"description": "Delete a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id).",
"tags": [
"memory"
],
@@ -977,7 +971,7 @@
},
"/memory/search": {
"post": {
"description": "Search memories for a user via memory",
"description": "Search memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id).",
"tags": [
"memory"
],
@@ -989,7 +983,7 @@
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/memory.SearchRequest"
"$ref": "#/definitions/handlers.memorySearchPayload"
}
}
],
@@ -1017,7 +1011,7 @@
},
"/memory/update": {
"post": {
"description": "Update a memory by ID via memory",
"description": "Update a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id).",
"tags": [
"memory"
],
@@ -1176,27 +1170,29 @@
}
}
},
"/models/enable-as/{enableAs}": {
"get": {
"description": "Get the model that is enabled for a specific purpose (chat, memory, embedding)",
"/models/enable": {
"post": {
"description": "Update the current user's settings to use the selected model",
"tags": [
"models"
],
"summary": "Get model by enable_as",
"summary": "Enable model for chat/memory/embedding",
"parameters": [
{
"type": "string",
"description": "Enable as value (chat, memory, embedding)",
"name": "enableAs",
"in": "path",
"required": true
"description": "Enable model payload",
"name": "payload",
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.EnableModelRequest"
}
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/models.GetResponse"
"$ref": "#/definitions/settings.Settings"
}
},
"400": {
@@ -2841,6 +2837,17 @@
}
}
},
"handlers.EnableModelRequest": {
"type": "object",
"properties": {
"as": {
"type": "string"
},
"model_id": {
"type": "string"
}
}
},
"handlers.ErrorResponse": {
"type": "object",
"properties": {
@@ -2987,6 +2994,109 @@
}
}
},
"handlers.memoryAddPayload": {
"type": "object",
"properties": {
"agent_id": {
"type": "string"
},
"filters": {
"type": "object",
"additionalProperties": true
},
"infer": {
"type": "boolean"
},
"message": {
"type": "string"
},
"messages": {
"type": "array",
"items": {
"$ref": "#/definitions/memory.Message"
}
},
"metadata": {
"type": "object",
"additionalProperties": true
},
"run_id": {
"type": "string"
}
}
},
"handlers.memoryDeleteAllPayload": {
"type": "object",
"properties": {
"agent_id": {
"type": "string"
},
"run_id": {
"type": "string"
}
}
},
"handlers.memoryEmbedUpsertPayload": {
"type": "object",
"properties": {
"agent_id": {
"type": "string"
},
"filters": {
"type": "object",
"additionalProperties": true
},
"input": {
"$ref": "#/definitions/memory.EmbedInput"
},
"metadata": {
"type": "object",
"additionalProperties": true
},
"model": {
"type": "string"
},
"provider": {
"type": "string"
},
"run_id": {
"type": "string"
},
"source": {
"type": "string"
},
"type": {
"type": "string"
}
}
},
"handlers.memorySearchPayload": {
"type": "object",
"properties": {
"agent_id": {
"type": "string"
},
"filters": {
"type": "object",
"additionalProperties": true
},
"limit": {
"type": "integer"
},
"query": {
"type": "string"
},
"run_id": {
"type": "string"
},
"sources": {
"type": "array",
"items": {
"type": "string"
}
}
}
},
"handlers.skillsOpResponse": {
"type": "object",
"properties": {
@@ -3051,54 +3161,6 @@
}
}
},
"memory.AddRequest": {
"type": "object",
"properties": {
"agent_id": {
"type": "string"
},
"filters": {
"type": "object",
"additionalProperties": true
},
"infer": {
"type": "boolean"
},
"message": {
"type": "string"
},
"messages": {
"type": "array",
"items": {
"$ref": "#/definitions/memory.Message"
}
},
"metadata": {
"type": "object",
"additionalProperties": true
},
"run_id": {
"type": "string"
},
"user_id": {
"type": "string"
}
}
},
"memory.DeleteAllRequest": {
"type": "object",
"properties": {
"agent_id": {
"type": "string"
},
"run_id": {
"type": "string"
},
"user_id": {
"type": "string"
}
}
},
"memory.DeleteResponse": {
"type": "object",
"properties": {
@@ -3121,43 +3183,6 @@
}
}
},
"memory.EmbedUpsertRequest": {
"type": "object",
"properties": {
"agent_id": {
"type": "string"
},
"filters": {
"type": "object",
"additionalProperties": true
},
"input": {
"$ref": "#/definitions/memory.EmbedInput"
},
"metadata": {
"type": "object",
"additionalProperties": true
},
"model": {
"type": "string"
},
"provider": {
"type": "string"
},
"run_id": {
"type": "string"
},
"source": {
"type": "string"
},
"type": {
"type": "string"
},
"user_id": {
"type": "string"
}
}
},
"memory.EmbedUpsertResponse": {
"type": "object",
"properties": {
@@ -3222,36 +3247,6 @@
}
}
},
"memory.SearchRequest": {
"type": "object",
"properties": {
"agent_id": {
"type": "string"
},
"filters": {
"type": "object",
"additionalProperties": true
},
"limit": {
"type": "integer"
},
"query": {
"type": "string"
},
"run_id": {
"type": "string"
},
"sources": {
"type": "array",
"items": {
"type": "string"
}
},
"user_id": {
"type": "string"
}
}
},
"memory.SearchResponse": {
"type": "object",
"properties": {
@@ -3284,9 +3279,6 @@
"dimensions": {
"type": "integer"
},
"enable_as": {
"$ref": "#/definitions/models.EnableAs"
},
"is_multimodal": {
"type": "boolean"
},
@@ -3323,28 +3315,12 @@
}
}
},
"models.EnableAs": {
"type": "string",
"enum": [
"chat",
"memory",
"embedding"
],
"x-enum-varnames": [
"EnableAsChat",
"EnableAsMemory",
"EnableAsEmbedding"
]
},
"models.GetResponse": {
"type": "object",
"properties": {
"dimensions": {
"type": "integer"
},
"enable_as": {
"$ref": "#/definitions/models.EnableAs"
},
"is_multimodal": {
"type": "boolean"
},
@@ -3379,9 +3355,6 @@
"dimensions": {
"type": "integer"
},
"enable_as": {
"$ref": "#/definitions/models.EnableAs"
},
"is_multimodal": {
"type": "boolean"
},
@@ -3600,22 +3573,40 @@
"settings.Settings": {
"type": "object",
"properties": {
"chat_model_id": {
"type": "string"
},
"embedding_model_id": {
"type": "string"
},
"language": {
"type": "string"
},
"max_context_load_time": {
"type": "integer"
},
"memory_model_id": {
"type": "string"
}
}
},
"settings.UpsertRequest": {
"type": "object",
"properties": {
"chat_model_id": {
"type": "string"
},
"embedding_model_id": {
"type": "string"
},
"language": {
"type": "string"
},
"max_context_load_time": {
"type": "integer"
},
"memory_model_id": {
"type": "string"
}
}
},
+118 -119
View File
@@ -166,6 +166,13 @@ definitions:
video_tokens:
type: integer
type: object
handlers.EnableModelRequest:
properties:
as:
type: string
model_id:
type: string
type: object
handlers.ErrorResponse:
properties:
message:
@@ -260,6 +267,75 @@ definitions:
updated_at:
type: string
type: object
handlers.memoryAddPayload:
properties:
agent_id:
type: string
filters:
additionalProperties: true
type: object
infer:
type: boolean
message:
type: string
messages:
items:
$ref: '#/definitions/memory.Message'
type: array
metadata:
additionalProperties: true
type: object
run_id:
type: string
type: object
handlers.memoryDeleteAllPayload:
properties:
agent_id:
type: string
run_id:
type: string
type: object
handlers.memoryEmbedUpsertPayload:
properties:
agent_id:
type: string
filters:
additionalProperties: true
type: object
input:
$ref: '#/definitions/memory.EmbedInput'
metadata:
additionalProperties: true
type: object
model:
type: string
provider:
type: string
run_id:
type: string
source:
type: string
type:
type: string
type: object
handlers.memorySearchPayload:
properties:
agent_id:
type: string
filters:
additionalProperties: true
type: object
limit:
type: integer
query:
type: string
run_id:
type: string
sources:
items:
type: string
type: array
type: object
handlers.skillsOpResponse:
properties:
ok:
@@ -302,38 +378,6 @@ definitions:
user_id:
type: string
type: object
memory.AddRequest:
properties:
agent_id:
type: string
filters:
additionalProperties: true
type: object
infer:
type: boolean
message:
type: string
messages:
items:
$ref: '#/definitions/memory.Message'
type: array
metadata:
additionalProperties: true
type: object
run_id:
type: string
user_id:
type: string
type: object
memory.DeleteAllRequest:
properties:
agent_id:
type: string
run_id:
type: string
user_id:
type: string
type: object
memory.DeleteResponse:
properties:
message:
@@ -348,31 +392,6 @@ definitions:
video_url:
type: string
type: object
memory.EmbedUpsertRequest:
properties:
agent_id:
type: string
filters:
additionalProperties: true
type: object
input:
$ref: '#/definitions/memory.EmbedInput'
metadata:
additionalProperties: true
type: object
model:
type: string
provider:
type: string
run_id:
type: string
source:
type: string
type:
type: string
user_id:
type: string
type: object
memory.EmbedUpsertResponse:
properties:
dimensions:
@@ -415,26 +434,6 @@ definitions:
role:
type: string
type: object
memory.SearchRequest:
properties:
agent_id:
type: string
filters:
additionalProperties: true
type: object
limit:
type: integer
query:
type: string
run_id:
type: string
sources:
items:
type: string
type: array
user_id:
type: string
type: object
memory.SearchResponse:
properties:
relations:
@@ -456,8 +455,6 @@ definitions:
properties:
dimensions:
type: integer
enable_as:
$ref: '#/definitions/models.EnableAs'
is_multimodal:
type: boolean
llm_provider_id:
@@ -481,22 +478,10 @@ definitions:
count:
type: integer
type: object
models.EnableAs:
enum:
- chat
- memory
- embedding
type: string
x-enum-varnames:
- EnableAsChat
- EnableAsMemory
- EnableAsEmbedding
models.GetResponse:
properties:
dimensions:
type: integer
enable_as:
$ref: '#/definitions/models.EnableAs'
is_multimodal:
type: boolean
llm_provider_id:
@@ -520,8 +505,6 @@ definitions:
properties:
dimensions:
type: integer
enable_as:
$ref: '#/definitions/models.EnableAs'
is_multimodal:
type: boolean
llm_provider_id:
@@ -669,17 +652,29 @@ definitions:
type: object
settings.Settings:
properties:
chat_model_id:
type: string
embedding_model_id:
type: string
language:
type: string
max_context_load_time:
type: integer
memory_model_id:
type: string
type: object
settings.UpsertRequest:
properties:
chat_model_id:
type: string
embedding_model_id:
type: string
language:
type: string
max_context_load_time:
type: integer
memory_model_id:
type: string
type: object
subagent.AddSkillsRequest:
properties:
@@ -1279,14 +1274,15 @@ paths:
- containerd
/memory/add:
post:
description: Add memory for a user via memory
description: 'Add memory for a user via memory. Auth: Bearer JWT determines
user_id (sub or user_id).'
parameters:
- description: Add request
in: body
name: payload
required: true
schema:
$ref: '#/definitions/memory.AddRequest'
$ref: '#/definitions/handlers.memoryAddPayload'
responses:
"200":
description: OK
@@ -1305,14 +1301,15 @@ paths:
- memory
/memory/embed:
post:
description: Embed text or multimodal input and upsert into memory store
description: 'Embed text or multimodal input and upsert into memory store. Auth:
Bearer JWT determines user_id (sub or user_id).'
parameters:
- description: Embed upsert request
in: body
name: payload
required: true
schema:
$ref: '#/definitions/memory.EmbedUpsertRequest'
$ref: '#/definitions/handlers.memoryEmbedUpsertPayload'
responses:
"200":
description: OK
@@ -1331,14 +1328,15 @@ paths:
- memory
/memory/memories:
delete:
description: Delete all memories for a user via memory
description: 'Delete all memories for a user via memory. Auth: Bearer JWT determines
user_id (sub or user_id).'
parameters:
- description: Delete all request
in: body
name: payload
required: true
schema:
$ref: '#/definitions/memory.DeleteAllRequest'
$ref: '#/definitions/handlers.memoryDeleteAllPayload'
responses:
"200":
description: OK
@@ -1356,12 +1354,9 @@ paths:
tags:
- memory
get:
description: List memories for a user via memory
description: 'List memories for a user via memory. Auth: Bearer JWT determines
user_id (sub or user_id).'
parameters:
- description: User ID
in: query
name: user_id
type: string
- description: Agent ID
in: query
name: agent_id
@@ -1392,7 +1387,8 @@ paths:
- memory
/memory/memories/{memoryId}:
delete:
description: Delete a memory by ID via memory
description: 'Delete a memory by ID via memory. Auth: Bearer JWT determines
user_id (sub or user_id).'
parameters:
- description: Memory ID
in: path
@@ -1416,7 +1412,8 @@ paths:
tags:
- memory
get:
description: Get a memory by ID via memory
description: 'Get a memory by ID via memory. Auth: Bearer JWT determines user_id
(sub or user_id).'
parameters:
- description: Memory ID
in: path
@@ -1441,14 +1438,15 @@ paths:
- memory
/memory/search:
post:
description: Search memories for a user via memory
description: 'Search memories for a user via memory. Auth: Bearer JWT determines
user_id (sub or user_id).'
parameters:
- description: Search request
in: body
name: payload
required: true
schema:
$ref: '#/definitions/memory.SearchRequest'
$ref: '#/definitions/handlers.memorySearchPayload'
responses:
"200":
description: OK
@@ -1467,7 +1465,8 @@ paths:
- memory
/memory/update:
post:
description: Update a memory by ID via memory
description: 'Update a memory by ID via memory. Auth: Bearer JWT determines
user_id (sub or user_id).'
parameters:
- description: Update request
in: body
@@ -1660,21 +1659,21 @@ paths:
summary: Get model count
tags:
- models
/models/enable-as/{enableAs}:
get:
description: Get the model that is enabled for a specific purpose (chat, memory,
embedding)
/models/enable:
post:
description: Update the current user's settings to use the selected model
parameters:
- description: Enable as value (chat, memory, embedding)
in: path
name: enableAs
- description: Enable model payload
in: body
name: payload
required: true
type: string
schema:
$ref: '#/definitions/handlers.EnableModelRequest'
responses:
"200":
description: OK
schema:
$ref: '#/definitions/models.GetResponse'
$ref: '#/definitions/settings.Settings'
"400":
description: Bad Request
schema:
@@ -1687,7 +1686,7 @@ paths:
description: Internal Server Error
schema:
$ref: '#/definitions/handlers.ErrorResponse'
summary: Get model by enable_as
summary: Enable model for chat/memory/embedding
tags:
- models
/models/model/{modelId}:
+63 -34
View File
@@ -36,6 +36,14 @@ type Resolver struct {
streamingClient *http.Client
}
type userSettings struct {
ChatModelID string
MemoryModelID string
EmbeddingModelID string
MaxContextLoadTime int
Language string
}
func NewResolver(log *slog.Logger, modelsService *models.Service, queries *sqlc.Queries, memoryService *memory.Service, gatewayBaseURL string, timeout time.Duration) *Resolver {
if strings.TrimSpace(gatewayBaseURL) == "" {
gatewayBaseURL = "http://127.0.0.1:8081"
@@ -67,7 +75,11 @@ func (r *Resolver) Chat(ctx context.Context, req ChatRequest) (ChatResponse, err
}
skipHistory := req.MaxContextLoadTime < 0
chatModel, provider, err := r.selectChatModel(ctx, req)
settings, err := r.loadUserSettings(ctx, req.UserID)
if err != nil {
return ChatResponse{}, err
}
chatModel, provider, err := r.selectChatModel(ctx, req, settings)
if err != nil {
return ChatResponse{}, err
}
@@ -76,10 +88,8 @@ func (r *Resolver) Chat(ctx context.Context, req ChatRequest) (ChatResponse, err
return ChatResponse{}, err
}
maxContextLoadTime, language, err := r.loadUserSettings(ctx, req.UserID)
if err != nil {
return ChatResponse{}, err
}
maxContextLoadTime := settings.MaxContextLoadTime
language := settings.Language
if req.MaxContextLoadTime > 0 {
maxContextLoadTime = req.MaxContextLoadTime
}
@@ -157,7 +167,11 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, userID string, schedule
Locale: "",
Language: "",
}
chatModel, provider, err := r.selectChatModel(ctx, req)
settings, err := r.loadUserSettings(ctx, userID)
if err != nil {
return err
}
chatModel, provider, err := r.selectChatModel(ctx, req, settings)
if err != nil {
return err
}
@@ -166,10 +180,8 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, userID string, schedule
return err
}
maxContextLoadTime, language, err := r.loadUserSettings(ctx, userID)
if err != nil {
return err
}
maxContextLoadTime := settings.MaxContextLoadTime
language := settings.Language
messages, err := r.loadHistoryMessages(ctx, userID, maxContextLoadTime)
if err != nil {
@@ -229,7 +241,12 @@ func (r *Resolver) StreamChat(ctx context.Context, req ChatRequest) (<-chan Stre
}
skipHistory := req.MaxContextLoadTime < 0
chatModel, provider, err := r.selectChatModel(ctx, req)
settings, err := r.loadUserSettings(ctx, req.UserID)
if err != nil {
errChan <- err
return
}
chatModel, provider, err := r.selectChatModel(ctx, req, settings)
if err != nil {
errChan <- err
return
@@ -240,11 +257,8 @@ func (r *Resolver) StreamChat(ctx context.Context, req ChatRequest) (<-chan Stre
return
}
maxContextLoadTime, language, err := r.loadUserSettings(ctx, req.UserID)
if err != nil {
errChan <- err
return
}
maxContextLoadTime := settings.MaxContextLoadTime
language := settings.Language
if req.MaxContextLoadTime > 0 {
maxContextLoadTime = req.MaxContextLoadTime
}
@@ -797,7 +811,7 @@ func isEmptyValue(value interface{}) bool {
}
}
func (r *Resolver) selectChatModel(ctx context.Context, req ChatRequest) (models.GetResponse, sqlc.LlmProvider, error) {
func (r *Resolver) selectChatModel(ctx context.Context, req ChatRequest, settings userSettings) (models.GetResponse, sqlc.LlmProvider, error) {
if r.modelsService == nil {
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("models service not configured")
}
@@ -819,15 +833,19 @@ func (r *Resolver) selectChatModel(ctx context.Context, req ChatRequest) (models
return model, provider, nil
}
if providerFilter == "" && modelID == "" {
defaultModel, err := r.modelsService.GetByEnableAs(ctx, models.EnableAsChat)
if err == nil {
provider, err := models.FetchProviderByID(ctx, r.queries, defaultModel.LlmProviderID)
if err != nil {
return models.GetResponse{}, sqlc.LlmProvider{}, err
}
return defaultModel, provider, nil
if providerFilter == "" && modelID == "" && strings.TrimSpace(settings.ChatModelID) != "" {
selected, err := r.modelsService.GetByModelID(ctx, settings.ChatModelID)
if err != nil {
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("settings chat model not found: %w", err)
}
if selected.Type != models.ModelTypeChat {
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("settings chat model is not a chat model")
}
provider, err := models.FetchProviderByID(ctx, r.queries, selected.LlmProviderID)
if err != nil {
return models.GetResponse{}, sqlc.LlmProvider{}, err
}
return selected, provider, nil
}
var candidates []models.GetResponse
@@ -880,26 +898,31 @@ func normalizeMaxContextLoad(value int) int {
return value
}
func (r *Resolver) loadUserSettings(ctx context.Context, userID string) (int, string, error) {
func (r *Resolver) loadUserSettings(ctx context.Context, userID string) (userSettings, error) {
if r.queries == nil {
return defaultMaxContextMinutes, "Same as user input", nil
return userSettings{
MaxContextLoadTime: defaultMaxContextMinutes,
Language: "Same as user input",
}, nil
}
pgUserID, err := parseUUID(userID)
if err != nil {
return 0, "", err
return userSettings{}, err
}
settingsRow, err := r.queries.GetSettingsByUserID(ctx, pgUserID)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return defaultMaxContextMinutes, "Same as user input", nil
return userSettings{
MaxContextLoadTime: defaultMaxContextMinutes,
Language: "Same as user input",
}, nil
}
return 0, "", err
return userSettings{}, err
}
maxLoad, language := normalizeUserSettingRow(settingsRow)
return maxLoad, language, nil
return normalizeUserSettingRow(settingsRow), nil
}
func normalizeUserSettingRow(row sqlc.UserSetting) (int, string) {
func normalizeUserSettingRow(row sqlc.UserSetting) userSettings {
maxLoad := int(row.MaxContextLoadTime)
if maxLoad <= 0 {
maxLoad = defaultMaxContextMinutes
@@ -908,7 +931,13 @@ func normalizeUserSettingRow(row sqlc.UserSetting) (int, string) {
if language == "" {
language = "Same as user input"
}
return maxLoad, language
return userSettings{
ChatModelID: strings.TrimSpace(row.ChatModelID.String),
MemoryModelID: strings.TrimSpace(row.MemoryModelID.String),
EmbeddingModelID: strings.TrimSpace(row.EmbeddingModelID.String),
MaxContextLoadTime: maxLoad,
Language: language,
}
}
func normalizeClientType(clientType string) (string, error) {
+3 -1
View File
@@ -87,7 +87,6 @@ type Model struct {
Dimensions pgtype.Int4 `json:"dimensions"`
IsMultimodal bool `json:"is_multimodal"`
Type string `json:"type"`
EnableAs pgtype.Text `json:"enable_as"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
}
@@ -156,6 +155,9 @@ type User struct {
type UserSetting struct {
UserID pgtype.UUID `json:"user_id"`
ChatModelID pgtype.Text `json:"chat_model_id"`
MemoryModelID pgtype.Text `json:"memory_model_id"`
EmbeddingModelID pgtype.Text `json:"embedding_model_id"`
MaxContextLoadTime int32 `json:"max_context_load_time"`
Language string `json:"language"`
}
+12 -62
View File
@@ -11,17 +11,6 @@ import (
"github.com/jackc/pgx/v5/pgtype"
)
const clearEnableAs = `-- name: ClearEnableAs :exec
UPDATE models
SET enable_as = NULL, updated_at = now()
WHERE enable_as = $1
`
func (q *Queries) ClearEnableAs(ctx context.Context, enableAs pgtype.Text) error {
_, err := q.db.Exec(ctx, clearEnableAs, enableAs)
return err
}
const countLlmProviders = `-- name: CountLlmProviders :one
SELECT COUNT(*) FROM llm_providers
`
@@ -109,17 +98,16 @@ func (q *Queries) CreateLlmProvider(ctx context.Context, arg CreateLlmProviderPa
}
const createModel = `-- name: CreateModel :one
INSERT INTO models (model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as)
INSERT INTO models (model_id, name, llm_provider_id, dimensions, is_multimodal, type)
VALUES (
$1,
$2,
$3,
$4,
$5,
$6,
$7
$6
)
RETURNING id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at
RETURNING id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at
`
type CreateModelParams struct {
@@ -129,7 +117,6 @@ type CreateModelParams struct {
Dimensions pgtype.Int4 `json:"dimensions"`
IsMultimodal bool `json:"is_multimodal"`
Type string `json:"type"`
EnableAs pgtype.Text `json:"enable_as"`
}
func (q *Queries) CreateModel(ctx context.Context, arg CreateModelParams) (Model, error) {
@@ -140,7 +127,6 @@ func (q *Queries) CreateModel(ctx context.Context, arg CreateModelParams) (Model
arg.Dimensions,
arg.IsMultimodal,
arg.Type,
arg.EnableAs,
)
var i Model
err := row.Scan(
@@ -151,7 +137,6 @@ func (q *Queries) CreateModel(ctx context.Context, arg CreateModelParams) (Model
&i.Dimensions,
&i.IsMultimodal,
&i.Type,
&i.EnableAs,
&i.CreatedAt,
&i.UpdatedAt,
)
@@ -272,30 +257,8 @@ func (q *Queries) GetLlmProviderByName(ctx context.Context, name string) (LlmPro
return i, err
}
const getModelByEnableAs = `-- name: GetModelByEnableAs :one
SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at FROM models WHERE enable_as = $1 LIMIT 1
`
func (q *Queries) GetModelByEnableAs(ctx context.Context, enableAs pgtype.Text) (Model, error) {
row := q.db.QueryRow(ctx, getModelByEnableAs, enableAs)
var i Model
err := row.Scan(
&i.ID,
&i.ModelID,
&i.Name,
&i.LlmProviderID,
&i.Dimensions,
&i.IsMultimodal,
&i.Type,
&i.EnableAs,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const getModelByID = `-- name: GetModelByID :one
SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at FROM models WHERE id = $1
SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at FROM models WHERE id = $1
`
func (q *Queries) GetModelByID(ctx context.Context, id pgtype.UUID) (Model, error) {
@@ -309,7 +272,6 @@ func (q *Queries) GetModelByID(ctx context.Context, id pgtype.UUID) (Model, erro
&i.Dimensions,
&i.IsMultimodal,
&i.Type,
&i.EnableAs,
&i.CreatedAt,
&i.UpdatedAt,
)
@@ -317,7 +279,7 @@ func (q *Queries) GetModelByID(ctx context.Context, id pgtype.UUID) (Model, erro
}
const getModelByModelID = `-- name: GetModelByModelID :one
SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at FROM models WHERE model_id = $1
SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at FROM models WHERE model_id = $1
`
func (q *Queries) GetModelByModelID(ctx context.Context, modelID string) (Model, error) {
@@ -331,7 +293,6 @@ func (q *Queries) GetModelByModelID(ctx context.Context, modelID string) (Model,
&i.Dimensions,
&i.IsMultimodal,
&i.Type,
&i.EnableAs,
&i.CreatedAt,
&i.UpdatedAt,
)
@@ -495,7 +456,7 @@ func (q *Queries) ListModelVariantsByVariantID(ctx context.Context, variantID st
}
const listModels = `-- name: ListModels :many
SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at FROM models
SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at FROM models
ORDER BY created_at DESC
`
@@ -516,7 +477,6 @@ func (q *Queries) ListModels(ctx context.Context) ([]Model, error) {
&i.Dimensions,
&i.IsMultimodal,
&i.Type,
&i.EnableAs,
&i.CreatedAt,
&i.UpdatedAt,
); err != nil {
@@ -531,7 +491,7 @@ func (q *Queries) ListModels(ctx context.Context) ([]Model, error) {
}
const listModelsByClientType = `-- name: ListModelsByClientType :many
SELECT m.id, m.model_id, m.name, m.llm_provider_id, m.dimensions, m.is_multimodal, m.type, m.enable_as, m.created_at, m.updated_at FROM models AS m
SELECT m.id, m.model_id, m.name, m.llm_provider_id, m.dimensions, m.is_multimodal, m.type, m.created_at, m.updated_at FROM models AS m
JOIN llm_providers AS p ON p.id = m.llm_provider_id
WHERE p.client_type = $1
ORDER BY m.created_at DESC
@@ -554,7 +514,6 @@ func (q *Queries) ListModelsByClientType(ctx context.Context, clientType string)
&i.Dimensions,
&i.IsMultimodal,
&i.Type,
&i.EnableAs,
&i.CreatedAt,
&i.UpdatedAt,
); err != nil {
@@ -569,7 +528,7 @@ func (q *Queries) ListModelsByClientType(ctx context.Context, clientType string)
}
const listModelsByType = `-- name: ListModelsByType :many
SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at FROM models
SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at FROM models
WHERE type = $1
ORDER BY created_at DESC
`
@@ -591,7 +550,6 @@ func (q *Queries) ListModelsByType(ctx context.Context, type_ string) ([]Model,
&i.Dimensions,
&i.IsMultimodal,
&i.Type,
&i.EnableAs,
&i.CreatedAt,
&i.UpdatedAt,
); err != nil {
@@ -658,10 +616,9 @@ SET
dimensions = $3,
is_multimodal = $4,
type = $5,
enable_as = $6,
updated_at = now()
WHERE id = $7
RETURNING id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at
WHERE id = $6
RETURNING id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at
`
type UpdateModelParams struct {
@@ -670,7 +627,6 @@ type UpdateModelParams struct {
Dimensions pgtype.Int4 `json:"dimensions"`
IsMultimodal bool `json:"is_multimodal"`
Type string `json:"type"`
EnableAs pgtype.Text `json:"enable_as"`
ID pgtype.UUID `json:"id"`
}
@@ -681,7 +637,6 @@ func (q *Queries) UpdateModel(ctx context.Context, arg UpdateModelParams) (Model
arg.Dimensions,
arg.IsMultimodal,
arg.Type,
arg.EnableAs,
arg.ID,
)
var i Model
@@ -693,7 +648,6 @@ func (q *Queries) UpdateModel(ctx context.Context, arg UpdateModelParams) (Model
&i.Dimensions,
&i.IsMultimodal,
&i.Type,
&i.EnableAs,
&i.CreatedAt,
&i.UpdatedAt,
)
@@ -708,10 +662,9 @@ SET
dimensions = $3,
is_multimodal = $4,
type = $5,
enable_as = $6,
updated_at = now()
WHERE model_id = $7
RETURNING id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at
WHERE model_id = $6
RETURNING id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at
`
type UpdateModelByModelIDParams struct {
@@ -720,7 +673,6 @@ type UpdateModelByModelIDParams struct {
Dimensions pgtype.Int4 `json:"dimensions"`
IsMultimodal bool `json:"is_multimodal"`
Type string `json:"type"`
EnableAs pgtype.Text `json:"enable_as"`
ModelID string `json:"model_id"`
}
@@ -731,7 +683,6 @@ func (q *Queries) UpdateModelByModelID(ctx context.Context, arg UpdateModelByMod
arg.Dimensions,
arg.IsMultimodal,
arg.Type,
arg.EnableAs,
arg.ModelID,
)
var i Model
@@ -743,7 +694,6 @@ func (q *Queries) UpdateModelByModelID(ctx context.Context, arg UpdateModelByMod
&i.Dimensions,
&i.IsMultimodal,
&i.Type,
&i.EnableAs,
&i.CreatedAt,
&i.UpdatedAt,
)
+34 -7
View File
@@ -22,7 +22,7 @@ func (q *Queries) DeleteSettingsByUserID(ctx context.Context, userID pgtype.UUID
}
const getSettingsByUserID = `-- name: GetSettingsByUserID :one
SELECT user_id, max_context_load_time, language
SELECT user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language
FROM user_settings
WHERE user_id = $1
`
@@ -30,28 +30,55 @@ WHERE user_id = $1
func (q *Queries) GetSettingsByUserID(ctx context.Context, userID pgtype.UUID) (UserSetting, error) {
row := q.db.QueryRow(ctx, getSettingsByUserID, userID)
var i UserSetting
err := row.Scan(&i.UserID, &i.MaxContextLoadTime, &i.Language)
err := row.Scan(
&i.UserID,
&i.ChatModelID,
&i.MemoryModelID,
&i.EmbeddingModelID,
&i.MaxContextLoadTime,
&i.Language,
)
return i, err
}
const upsertSettings = `-- name: UpsertSettings :one
INSERT INTO user_settings (user_id, max_context_load_time, language)
VALUES ($1, $2, $3)
INSERT INTO user_settings (user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (user_id) DO UPDATE SET
chat_model_id = EXCLUDED.chat_model_id,
memory_model_id = EXCLUDED.memory_model_id,
embedding_model_id = EXCLUDED.embedding_model_id,
max_context_load_time = EXCLUDED.max_context_load_time,
language = EXCLUDED.language
RETURNING user_id, max_context_load_time, language
RETURNING user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language
`
type UpsertSettingsParams struct {
UserID pgtype.UUID `json:"user_id"`
ChatModelID pgtype.Text `json:"chat_model_id"`
MemoryModelID pgtype.Text `json:"memory_model_id"`
EmbeddingModelID pgtype.Text `json:"embedding_model_id"`
MaxContextLoadTime int32 `json:"max_context_load_time"`
Language string `json:"language"`
}
func (q *Queries) UpsertSettings(ctx context.Context, arg UpsertSettingsParams) (UserSetting, error) {
row := q.db.QueryRow(ctx, upsertSettings, arg.UserID, arg.MaxContextLoadTime, arg.Language)
row := q.db.QueryRow(ctx, upsertSettings,
arg.UserID,
arg.ChatModelID,
arg.MemoryModelID,
arg.EmbeddingModelID,
arg.MaxContextLoadTime,
arg.Language,
)
var i UserSetting
err := row.Scan(&i.UserID, &i.MaxContextLoadTime, &i.Language)
err := row.Scan(
&i.UserID,
&i.ChatModelID,
&i.MemoryModelID,
&i.EmbeddingModelID,
&i.MaxContextLoadTime,
&i.Language,
)
return i, err
}
+53 -14
View File
@@ -3,11 +3,13 @@ package embeddings
import (
"context"
"errors"
"fmt"
"log/slog"
"strings"
"time"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/memohai/memoh/internal/db/sqlc"
@@ -29,6 +31,7 @@ type Request struct {
Model string
Dimensions int
Input Input
UserID string
}
type Input struct {
@@ -176,21 +179,28 @@ func (r *Resolver) selectEmbeddingModel(ctx context.Context, req Request) (model
return models.GetResponse{}, errors.New("models service not configured")
}
// If no model specified and no provider specified, try to get default embedding model
if req.Model == "" && req.Provider == "" {
defaultModel, err := r.modelsService.GetByEnableAs(ctx, models.EnableAsEmbedding)
if err == nil {
// Found default model, check if it matches the type requirement
if req.Type == TypeMultimodal && !defaultModel.IsMultimodal {
// Default is text, but need multimodal - continue to search
} else if req.Type == TypeText && defaultModel.IsMultimodal {
// Default is multimodal, but need text - continue to search
} else {
// Default model matches requirements
return defaultModel, nil
}
// If no model specified and no provider specified, try to get per-user embedding model.
if req.Model == "" && req.Provider == "" && strings.TrimSpace(req.UserID) != "" {
modelID, err := r.loadUserEmbeddingModelID(ctx, req.UserID)
if err != nil {
return models.GetResponse{}, err
}
if modelID != "" {
selected, err := r.modelsService.GetByModelID(ctx, modelID)
if err != nil {
return models.GetResponse{}, fmt.Errorf("settings embedding model not found: %w", err)
}
if selected.Type != models.ModelTypeEmbedding {
return models.GetResponse{}, errors.New("settings embedding model is not an embedding model")
}
if req.Type == TypeMultimodal && !selected.IsMultimodal {
return models.GetResponse{}, errors.New("settings embedding model does not support multimodal")
}
if req.Type == TypeText && selected.IsMultimodal {
return models.GetResponse{}, errors.New("settings embedding model does not support text embeddings")
}
return selected, nil
}
// No default model or doesn't match requirements, continue to search
}
var candidates []models.GetResponse
@@ -246,3 +256,32 @@ func (r *Resolver) fetchProvider(ctx context.Context, providerID string) (sqlc.L
copy(pgID.Bytes[:], parsed[:])
return r.queries.GetLlmProviderByID(ctx, pgID)
}
// loadUserEmbeddingModelID returns the embedding model_id stored in the user's
// settings row. It returns "" (with a nil error) when queries are not
// configured, when the user has no settings row, or when no value is stored.
func (r *Resolver) loadUserEmbeddingModelID(ctx context.Context, userID string) (string, error) {
	if r.queries == nil {
		return "", nil
	}
	pgUserID, parseErr := parseUUID(userID)
	if parseErr != nil {
		return "", parseErr
	}
	settingsRow, lookupErr := r.queries.GetSettingsByUserID(ctx, pgUserID)
	switch {
	case lookupErr == nil:
		return strings.TrimSpace(settingsRow.EmbeddingModelID.String), nil
	case errors.Is(lookupErr, pgx.ErrNoRows):
		// A missing settings row just means no per-user model is configured.
		return "", nil
	default:
		return "", lookupErr
	}
}
// parseUUID converts a string UUID into a valid pgtype.UUID, wrapping any
// parse failure with an "invalid UUID" context.
func parseUUID(id string) (pgtype.UUID, error) {
	parsed, err := uuid.Parse(id)
	if err != nil {
		return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err)
	}
	out := pgtype.UUID{Valid: true}
	copy(out.Bytes[:], parsed[:])
	return out, nil
}
+8
View File
@@ -8,6 +8,7 @@ import (
"github.com/labstack/echo/v4"
"github.com/memohai/memoh/internal/auth"
"github.com/memohai/memoh/internal/db/sqlc"
"github.com/memohai/memoh/internal/embeddings"
"github.com/memohai/memoh/internal/models"
@@ -84,6 +85,12 @@ func (h *EmbeddingsHandler) Embed(c echo.Context) error {
req.Input.ImageURL = strings.TrimSpace(req.Input.ImageURL)
req.Input.VideoURL = strings.TrimSpace(req.Input.VideoURL)
userID := ""
if c.Get("user") != nil {
if value, err := auth.UserIDFromContext(c); err == nil {
userID = value
}
}
result, err := h.resolver.Embed(c.Request().Context(), embeddings.Request{
Type: req.Type,
Provider: req.Provider,
@@ -94,6 +101,7 @@ func (h *EmbeddingsHandler) Embed(c echo.Context) error {
ImageURL: req.Input.ImageURL,
VideoURL: req.Input.VideoURL,
},
UserID: userID,
})
if err != nil {
message := err.Error()
+72 -29
View File
@@ -4,21 +4,26 @@ import (
"log/slog"
"net/http"
"net/url"
"strings"
"github.com/labstack/echo/v4"
"github.com/memohai/memoh/internal/auth"
"github.com/memohai/memoh/internal/models"
"github.com/memohai/memoh/internal/settings"
)
type ModelsHandler struct {
service *models.Service
logger *slog.Logger
service *models.Service
settingsService *settings.Service
logger *slog.Logger
}
func NewModelsHandler(log *slog.Logger, service *models.Service) *ModelsHandler {
func NewModelsHandler(log *slog.Logger, service *models.Service, settingsService *settings.Service) *ModelsHandler {
return &ModelsHandler{
service: service,
logger: log.With(slog.String("handler", "models")),
service: service,
settingsService: settingsService,
logger: log.With(slog.String("handler", "models")),
}
}
@@ -28,7 +33,7 @@ func (h *ModelsHandler) Register(e *echo.Echo) {
group.GET("", h.List)
group.GET("/:id", h.GetByID)
group.GET("/model/:modelId", h.GetByModelID)
group.GET("/enable-as/:enableAs", h.GetByEnableAs)
group.POST("/enable", h.Enable)
group.PUT("/:id", h.UpdateByID)
group.PUT("/model/:modelId", h.UpdateByModelID)
group.DELETE("/:id", h.DeleteByID)
@@ -140,6 +145,67 @@ func (h *ModelsHandler) GetByModelID(c echo.Context) error {
return c.JSON(http.StatusOK, resp)
}
// EnableModelRequest is the JSON payload for POST /models/enable.
type EnableModelRequest struct {
// As selects the per-user slot to set: "chat", "memory", or "embedding".
As string `json:"as"`
// ModelID is the models.model_id value to enable for that slot.
ModelID string `json:"model_id"`
}
// Enable godoc
// @Summary Enable model for chat/memory/embedding
// @Description Update the current user's settings to use the selected model
// @Tags models
// @Param payload body handlers.EnableModelRequest true "Enable model payload"
// @Success 200 {object} settings.Settings
// @Failure 400 {object} ErrorResponse
// @Failure 404 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Router /models/enable [post]
func (h *ModelsHandler) Enable(c echo.Context) error {
	if h.settingsService == nil {
		return echo.NewHTTPError(http.StatusInternalServerError, "settings service not configured")
	}
	userID, err := auth.UserIDFromContext(c)
	if err != nil {
		return err
	}
	var payload EnableModelRequest
	if bindErr := c.Bind(&payload); bindErr != nil {
		return echo.NewHTTPError(http.StatusBadRequest, bindErr.Error())
	}
	// Normalize inputs before validating: slot name is case-insensitive.
	target := strings.ToLower(strings.TrimSpace(payload.As))
	modelID := strings.TrimSpace(payload.ModelID)
	if target == "" || modelID == "" {
		return echo.NewHTTPError(http.StatusBadRequest, "as and model_id are required")
	}
	switch target {
	case "chat", "memory", "embedding":
		// valid slot
	default:
		return echo.NewHTTPError(http.StatusBadRequest, "as must be one of chat, memory, embedding")
	}
	model, err := h.service.GetByModelID(c.Request().Context(), modelID)
	if err != nil {
		return echo.NewHTTPError(http.StatusNotFound, err.Error())
	}
	// The embedding slot requires an embedding model; chat and memory both
	// require a chat model.
	switch {
	case target == "embedding" && model.Type != models.ModelTypeEmbedding:
		return echo.NewHTTPError(http.StatusBadRequest, "model is not an embedding model")
	case target != "embedding" && model.Type != models.ModelTypeChat:
		return echo.NewHTTPError(http.StatusBadRequest, "model is not a chat model")
	}
	var upsert settings.UpsertRequest
	switch target {
	case "chat":
		upsert.ChatModelID = modelID
	case "memory":
		upsert.MemoryModelID = modelID
	case "embedding":
		upsert.EmbeddingModelID = modelID
	}
	resp, err := h.settingsService.Upsert(c.Request().Context(), userID, upsert)
	if err != nil {
		return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
	}
	return c.JSON(http.StatusOK, resp)
}
// UpdateByID godoc
// @Summary Update model by internal ID
// @Description Update a model configuration by its internal UUID
@@ -252,29 +318,6 @@ func (h *ModelsHandler) DeleteByModelID(c echo.Context) error {
return c.NoContent(http.StatusNoContent)
}
// GetByEnableAs godoc
// @Summary Get model by enable_as
// @Description Get the model that is enabled for a specific purpose (chat, memory, embedding)
// @Tags models
// @Param enableAs path string true "Enable as value (chat, memory, embedding)"
// @Success 200 {object} models.GetResponse
// @Failure 400 {object} ErrorResponse
// @Failure 404 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Router /models/enable-as/{enableAs} [get]
func (h *ModelsHandler) GetByEnableAs(c echo.Context) error {
enableAs := c.Param("enableAs")
if enableAs == "" {
return echo.NewHTTPError(http.StatusBadRequest, "enableAs is required")
}
resp, err := h.service.GetByEnableAs(c.Request().Context(), models.EnableAs(enableAs))
if err != nil {
return echo.NewHTTPError(http.StatusNotFound, err.Error())
}
return c.JSON(http.StatusOK, resp)
}
// Count godoc
// @Summary Get model count
// @Description Get the total count of models, optionally filtered by type
-57
View File
@@ -1,57 +0,0 @@
package models
import (
"context"
"fmt"
"strings"
"github.com/memohai/memoh/internal/db/sqlc"
)
// SelectMemoryModel selects a chat model for memory operations.
func SelectMemoryModel(ctx context.Context, modelsService *Service, queries *sqlc.Queries) (GetResponse, sqlc.LlmProvider, error) {
// First try to get the memory-enabled model.
memoryModel, err := modelsService.GetByEnableAs(ctx, EnableAsMemory)
if err == nil {
provider, err := FetchProviderByID(ctx, queries, memoryModel.LlmProviderID)
if err != nil {
return GetResponse{}, sqlc.LlmProvider{}, err
}
return memoryModel, provider, nil
}
// Fallback to chat model.
chatModel, err := modelsService.GetByEnableAs(ctx, EnableAsChat)
if err == nil {
provider, err := FetchProviderByID(ctx, queries, chatModel.LlmProviderID)
if err != nil {
return GetResponse{}, sqlc.LlmProvider{}, err
}
return chatModel, provider, nil
}
// If no enabled models, try to find any chat model.
candidates, err := modelsService.ListByType(ctx, ModelTypeChat)
if err != nil || len(candidates) == 0 {
return GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("no chat models available for memory operations")
}
selected := candidates[0]
provider, err := FetchProviderByID(ctx, queries, selected.LlmProviderID)
if err != nil {
return GetResponse{}, sqlc.LlmProvider{}, err
}
return selected, provider, nil
}
// FetchProviderByID fetches a provider by ID.
func FetchProviderByID(ctx context.Context, queries *sqlc.Queries, providerID string) (sqlc.LlmProvider, error) {
if strings.TrimSpace(providerID) == "" {
return sqlc.LlmProvider{}, fmt.Errorf("provider id missing")
}
parsed, err := parseUUID(providerID)
if err != nil {
return sqlc.LlmProvider{}, err
}
return queries.GetLlmProviderByID(ctx, parsed)
}
+30 -55
View File
@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"log/slog"
"strings"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgtype"
@@ -31,13 +32,6 @@ func (s *Service) Create(ctx context.Context, req AddRequest) (AddResponse, erro
return AddResponse{}, fmt.Errorf("validation failed: %w", err)
}
// If enable_as is set, clear any existing model with the same enable_as
if model.EnableAs != nil {
if err := s.queries.ClearEnableAs(ctx, pgtype.Text{String: string(*model.EnableAs), Valid: true}); err != nil {
return AddResponse{}, fmt.Errorf("failed to clear existing enable_as: %w", err)
}
}
// Convert to sqlc params
llmProviderID, err := parseUUID(model.LlmProviderID)
if err != nil {
@@ -61,11 +55,6 @@ func (s *Service) Create(ctx context.Context, req AddRequest) (AddResponse, erro
params.Dimensions = pgtype.Int4{Int32: int32(model.Dimensions), Valid: true}
}
// Handle optional enable_as field
if model.EnableAs != nil {
params.EnableAs = pgtype.Text{String: string(*model.EnableAs), Valid: true}
}
created, err := s.queries.CreateModel(ctx, params)
if err != nil {
return AddResponse{}, fmt.Errorf("failed to create model: %w", err)
@@ -166,13 +155,6 @@ func (s *Service) UpdateByID(ctx context.Context, id string, req UpdateRequest)
return GetResponse{}, fmt.Errorf("validation failed: %w", err)
}
// If enable_as is being set, clear any existing model with the same enable_as
if model.EnableAs != nil {
if err := s.queries.ClearEnableAs(ctx, pgtype.Text{String: string(*model.EnableAs), Valid: true}); err != nil {
return GetResponse{}, fmt.Errorf("failed to clear existing enable_as: %w", err)
}
}
params := sqlc.UpdateModelParams{
ID: uuid,
IsMultimodal: model.IsMultimodal,
@@ -193,11 +175,6 @@ func (s *Service) UpdateByID(ctx context.Context, id string, req UpdateRequest)
params.Dimensions = pgtype.Int4{Int32: int32(model.Dimensions), Valid: true}
}
// Handle optional enable_as field
if model.EnableAs != nil {
params.EnableAs = pgtype.Text{String: string(*model.EnableAs), Valid: true}
}
updated, err := s.queries.UpdateModel(ctx, params)
if err != nil {
return GetResponse{}, fmt.Errorf("failed to update model: %w", err)
@@ -217,13 +194,6 @@ func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req Updat
return GetResponse{}, fmt.Errorf("validation failed: %w", err)
}
// If enable_as is being set, clear any existing model with the same enable_as
if model.EnableAs != nil {
if err := s.queries.ClearEnableAs(ctx, pgtype.Text{String: string(*model.EnableAs), Valid: true}); err != nil {
return GetResponse{}, fmt.Errorf("failed to clear existing enable_as: %w", err)
}
}
params := sqlc.UpdateModelByModelIDParams{
ModelID: modelID,
IsMultimodal: model.IsMultimodal,
@@ -244,11 +214,6 @@ func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req Updat
params.Dimensions = pgtype.Int4{Int32: int32(model.Dimensions), Valid: true}
}
// Handle optional enable_as field
if model.EnableAs != nil {
params.EnableAs = pgtype.Text{String: string(*model.EnableAs), Valid: true}
}
updated, err := s.queries.UpdateModelByModelID(ctx, params)
if err != nil {
return GetResponse{}, fmt.Errorf("failed to update model: %w", err)
@@ -306,20 +271,6 @@ func (s *Service) CountByType(ctx context.Context, modelType ModelType) (int64,
return count, nil
}
// GetByEnableAs retrieves the model that has the specified enable_as value
func (s *Service) GetByEnableAs(ctx context.Context, enableAs EnableAs) (GetResponse, error) {
if enableAs != EnableAsChat && enableAs != EnableAsMemory && enableAs != EnableAsEmbedding {
return GetResponse{}, fmt.Errorf("invalid enable_as value: %s", enableAs)
}
dbModel, err := s.queries.GetModelByEnableAs(ctx, pgtype.Text{String: string(enableAs), Valid: true})
if err != nil {
return GetResponse{}, fmt.Errorf("failed to get model by enable_as: %w", err)
}
return convertToGetResponse(dbModel), nil
}
// Helper functions
func parseUUID(id string) (pgtype.UUID, error) {
@@ -357,11 +308,6 @@ func convertToGetResponse(dbModel sqlc.Model) GetResponse {
resp.Model.Dimensions = int(dbModel.Dimensions.Int32)
}
if dbModel.EnableAs.Valid {
enableAs := EnableAs(dbModel.EnableAs.String)
resp.Model.EnableAs = &enableAs
}
return resp
}
@@ -399,3 +345,32 @@ func uuidStringFromPgUUID(value pgtype.UUID) (string, bool) {
}
return id.String(), true
}
// SelectMemoryModel selects a chat model for memory operations.
//
// This is the global fallback used when no per-user model applies: it picks
// the first chat model returned by ListByType (ordered by created_at DESC,
// i.e. the most recently created) and resolves its provider.
func SelectMemoryModel(ctx context.Context, modelsService *Service, queries *sqlc.Queries) (GetResponse, sqlc.LlmProvider, error) {
	if modelsService == nil {
		return GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("models service not configured")
	}
	candidates, err := modelsService.ListByType(ctx, ModelTypeChat)
	if err != nil {
		// Surface the underlying lookup failure instead of masking it as
		// "no chat models available" (a DB error is not an empty catalog).
		return GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("failed to list chat models for memory operations: %w", err)
	}
	if len(candidates) == 0 {
		return GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("no chat models available for memory operations")
	}
	selected := candidates[0]
	provider, err := FetchProviderByID(ctx, queries, selected.LlmProviderID)
	if err != nil {
		return GetResponse{}, sqlc.LlmProvider{}, err
	}
	return selected, provider, nil
}
// FetchProviderByID fetches a provider by ID.
// It rejects blank IDs up front and otherwise delegates to the sqlc query.
func FetchProviderByID(ctx context.Context, queries *sqlc.Queries, providerID string) (sqlc.LlmProvider, error) {
	if strings.TrimSpace(providerID) == "" {
		return sqlc.LlmProvider{}, fmt.Errorf("provider id missing")
	}
	pgID, err := parseUUID(providerID)
	if err != nil {
		return sqlc.LlmProvider{}, err
	}
	return queries.GetLlmProviderByID(ctx, pgID)
}
-23
View File
@@ -13,14 +13,6 @@ const (
ModelTypeEmbedding ModelType = "embedding"
)
type EnableAs string
const (
EnableAsChat EnableAs = "chat"
EnableAsMemory EnableAs = "memory"
EnableAsEmbedding EnableAs = "embedding"
)
type ClientType string
const (
@@ -41,7 +33,6 @@ type Model struct {
IsMultimodal bool `json:"is_multimodal"`
Type ModelType `json:"type"`
Dimensions int `json:"dimensions"`
EnableAs *EnableAs `json:"enable_as,omitempty"`
}
func (m *Model) Validate() error {
@@ -61,20 +52,6 @@ func (m *Model) Validate() error {
return errors.New("dimensions must be greater than 0")
}
// Validate enable_as based on type
if m.EnableAs != nil {
switch m.Type {
case ModelTypeEmbedding:
if *m.EnableAs != EnableAsEmbedding {
return errors.New("embedding models can only have enable_as set to 'embedding'")
}
case ModelTypeChat:
if *m.EnableAs != EnableAsChat && *m.EnableAs != EnableAsMemory {
return errors.New("chat models can only have enable_as set to 'chat' or 'memory'")
}
}
}
return nil
}
+21
View File
@@ -35,6 +35,9 @@ func (s *Service) Get(ctx context.Context, userID string) (Settings, error) {
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return Settings{
ChatModelID: "",
MemoryModelID: "",
EmbeddingModelID: "",
MaxContextLoadTime: DefaultMaxContextLoadTime,
Language: DefaultLanguage,
}, nil
@@ -54,6 +57,9 @@ func (s *Service) Upsert(ctx context.Context, userID string, req UpsertRequest)
}
current := Settings{
ChatModelID: "",
MemoryModelID: "",
EmbeddingModelID: "",
MaxContextLoadTime: DefaultMaxContextLoadTime,
Language: DefaultLanguage,
}
@@ -65,6 +71,15 @@ func (s *Service) Upsert(ctx context.Context, userID string, req UpsertRequest)
current = normalizeUserSetting(existing)
}
if value := strings.TrimSpace(req.ChatModelID); value != "" {
current.ChatModelID = value
}
if value := strings.TrimSpace(req.MemoryModelID); value != "" {
current.MemoryModelID = value
}
if value := strings.TrimSpace(req.EmbeddingModelID); value != "" {
current.EmbeddingModelID = value
}
if req.MaxContextLoadTime != nil && *req.MaxContextLoadTime > 0 {
current.MaxContextLoadTime = *req.MaxContextLoadTime
}
@@ -74,6 +89,9 @@ func (s *Service) Upsert(ctx context.Context, userID string, req UpsertRequest)
_, err = s.queries.UpsertSettings(ctx, sqlc.UpsertSettingsParams{
UserID: pgID,
ChatModelID: pgtype.Text{String: current.ChatModelID, Valid: current.ChatModelID != ""},
MemoryModelID: pgtype.Text{String: current.MemoryModelID, Valid: current.MemoryModelID != ""},
EmbeddingModelID: pgtype.Text{String: current.EmbeddingModelID, Valid: current.EmbeddingModelID != ""},
MaxContextLoadTime: int32(current.MaxContextLoadTime),
Language: current.Language,
})
@@ -96,6 +114,9 @@ func (s *Service) Delete(ctx context.Context, userID string) error {
func normalizeUserSetting(row sqlc.UserSetting) Settings {
settings := Settings{
ChatModelID: strings.TrimSpace(row.ChatModelID.String),
MemoryModelID: strings.TrimSpace(row.MemoryModelID.String),
EmbeddingModelID: strings.TrimSpace(row.EmbeddingModelID.String),
MaxContextLoadTime: int(row.MaxContextLoadTime),
Language: strings.TrimSpace(row.Language),
}
+6 -1
View File
@@ -6,12 +6,17 @@ const (
)
// Settings is the effective per-user configuration returned by the settings
// service. Model IDs are empty strings when the user has not chosen one.
type Settings struct {
	// ChatModelID selects the default model for chat; "" means not set.
	ChatModelID string `json:"chat_model_id"`
	// MemoryModelID selects the model used for memory operations; "" means not set.
	MemoryModelID string `json:"memory_model_id"`
	// EmbeddingModelID selects the embedding model; "" means not set.
	EmbeddingModelID string `json:"embedding_model_id"`
	// MaxContextLoadTime bounds context loading, in minutes (per the CLI flag).
	MaxContextLoadTime int `json:"max_context_load_time"`
	// Language is the user's preferred language code.
	Language string `json:"language"`
}
// UpsertRequest carries a partial settings update. Zero values mean "leave
// unchanged": Service.Upsert applies only non-blank strings and, for
// MaxContextLoadTime, only a non-nil pointer to a positive value.
type UpsertRequest struct {
	ChatModelID string `json:"chat_model_id,omitempty"`
	MemoryModelID string `json:"memory_model_id,omitempty"`
	EmbeddingModelID string `json:"embedding_model_id,omitempty"`
	// Pointer distinguishes "not provided" (nil) from an explicit value.
	MaxContextLoadTime *int `json:"max_context_load_time,omitempty"`
	Language string `json:"language,omitempty"`
}
+45 -52
View File
@@ -34,7 +34,6 @@ type Model = {
is_multimodal: boolean
type: 'chat' | 'embedding'
dimensions?: number
enable_as?: 'chat' | 'memory' | 'embedding'
}
type ModelResponse = Partial<Model> & {
@@ -61,6 +60,9 @@ type ScheduleListResponse = {
}
/**
 * Per-user settings payload exchanged with the `/settings` endpoint.
 * Model IDs select which configured model serves each role; an empty
 * string means the user has not set one.
 */
type Settings = {
  chat_model_id: string
  memory_model_id: string
  embedding_model_id: string
  /** Maximum context load time in minutes (see `--max_context_load_time`). */
  max_context_load_time: number
  language: string
}
@@ -105,7 +107,6 @@ const getModelId = (item: ModelResponse) => item.model?.model_id ?? item.model_i
// Accessors that tolerate both response shapes the API may return: a wrapped
// object ({ model: {...} }) or a flat model, falling back to a safe default.
const getProviderId = (item: ModelResponse) => item.model?.llm_provider_id ?? item.llm_provider_id ?? ''
const getModelType = (item: ModelResponse) => item.model?.type ?? item.type ?? 'chat'
const getModelMultimodal = (item: ModelResponse) => item.model?.is_multimodal ?? item.is_multimodal ?? false
const getModelEnableAs = (item: ModelResponse) => item.model?.enable_as ?? item.enable_as
const renderProvidersTable = (providers: Provider[], models: ModelResponse[]) => {
const rows: string[][] = [['Provider', 'Type', 'Base URL', 'Models']]
@@ -125,14 +126,13 @@ const renderProvidersTable = (providers: Provider[], models: ModelResponse[]) =>
const renderModelsTable = (models: ModelResponse[], providers: Provider[]) => {
const providerMap = new Map(providers.map(p => [p.id, p.name]))
const rows: string[][] = [['Model ID', 'Type', 'Provider', 'Multimodal', 'Enable As']]
const rows: string[][] = [['Model ID', 'Type', 'Provider', 'Multimodal']]
for (const item of models) {
rows.push([
getModelId(item),
getModelType(item),
providerMap.get(getProviderId(item)) ?? getProviderId(item),
getModelMultimodal(item) ? 'yes' : 'no',
getModelEnableAs(item) ?? '-',
])
}
return table(rows)
@@ -229,6 +229,9 @@ configCmd.action(async () => {
if (!token?.access_token) return
try {
const settings = await apiRequest<Settings>('/settings', {}, token)
console.log(`chat_model_id = "${settings.chat_model_id || ''}"`)
console.log(`memory_model_id = "${settings.memory_model_id || ''}"`)
console.log(`embedding_model_id = "${settings.embedding_model_id || ''}"`)
console.log(`max_context_load_time = ${settings.max_context_load_time}`)
console.log(`language = "${settings.language}"`)
} catch (err: unknown) {
@@ -241,6 +244,9 @@ configCmd
.description('Update config')
.option('--host <host>')
.option('--port <port>')
.option('--chat_model_id <model_id>')
.option('--memory_model_id <model_id>')
.option('--embedding_model_id <model_id>')
.option('--max_context_load_time <minutes>')
.option('--language <language>')
.action(async (opts) => {
@@ -257,7 +263,11 @@ configCmd
maxContextLoadTime = parsed
}
let language = opts.language
const hasSettingsInput = opts.max_context_load_time !== undefined || opts.language !== undefined
const hasSettingsInput = opts.max_context_load_time !== undefined
|| opts.language !== undefined
|| opts.chat_model_id !== undefined
|| opts.memory_model_id !== undefined
|| opts.embedding_model_id !== undefined
const hasConfigInput = Boolean(host || port)
if (!hasConfigInput && !hasSettingsInput) {
@@ -282,6 +292,9 @@ configCmd
language = String(language).trim()
}
const payload: Partial<Settings> = {}
if (opts.chat_model_id) payload.chat_model_id = String(opts.chat_model_id).trim()
if (opts.memory_model_id) payload.memory_model_id = String(opts.memory_model_id).trim()
if (opts.embedding_model_id) payload.embedding_model_id = String(opts.embedding_model_id).trim()
if (maxContextLoadTime !== undefined) payload.max_context_load_time = maxContextLoadTime
if (language) payload.language = language
const token = ensureAuth()
@@ -393,7 +406,6 @@ model
.option('--type <type>')
.option('--dimensions <dimensions>')
.option('--multimodal', 'Is multimodal')
.option('--enable_as <enable_as>')
.action(async (opts) => {
const token = ensureAuth()
const providers = await apiRequest<Provider[]>('/providers', {}, token)
@@ -438,7 +450,6 @@ model
is_multimodal: isMultimodal,
type: modelType,
dimensions,
enable_as: opts.enable_as,
}
const spinner = ora('Creating model...').start()
try {
@@ -472,10 +483,9 @@ model
model
.command('enable')
.description('Enable model')
.option('--as <enable_as>')
.option('--model <model>')
.option('--provider <provider>')
.description('Enable model for chat/memory/embedding')
.option('--as <usage>')
.option('--model <model_id>')
.action(async (opts) => {
const token = ensureAuth()
let enableAs = opts.as
@@ -488,57 +498,40 @@ model
}])
enableAs = answer.enable_as
}
const [providers, models] = await Promise.all([
apiRequest<Provider[]>('/providers', {}, token),
apiRequest<ModelResponse[]>('/models', {}, token),
])
let providerName = opts.provider
if (!providerName) {
const answer = await inquirer.prompt([{
type: 'list',
name: 'provider',
message: 'Select provider:',
choices: providers.map(p => p.name),
}])
providerName = answer.provider
}
const provider = providers.find(p => p.name === providerName)
if (!provider) {
console.log(chalk.red('Provider not found.'))
enableAs = String(enableAs).trim()
if (!['chat', 'memory', 'embedding'].includes(enableAs)) {
console.log(chalk.red('Enable as must be one of chat, memory, embedding.'))
process.exit(1)
}
let modelName = opts.model
if (!modelName) {
const providerModels = models
.filter(m => getProviderId(m) === provider.id)
.map(m => getModelId(m))
if (providerModels.length === 0) {
console.log(chalk.red('No models found for selected provider.'))
process.exit(1)
}
const models = await apiRequest<ModelResponse[]>('/models', {}, token)
const requiredType = enableAs === 'embedding' ? 'embedding' : 'chat'
const candidates = models.filter(m => getModelType(m) === requiredType)
if (candidates.length === 0) {
console.log(chalk.red(`No ${requiredType} models available.`))
process.exit(1)
}
let modelId = opts.model
if (!modelId) {
const answer = await inquirer.prompt([{
type: 'list',
name: 'model',
message: 'Select model:',
choices: providerModels,
choices: candidates.map(m => getModelId(m)),
}])
modelName = answer.model
modelId = answer.model
}
const current = models.find(m => getModelId(m) === modelName && getProviderId(m) === provider.id)
?? await apiRequest<ModelResponse>(`/models/model/${encodeURIComponent(modelName)}`, {}, token)
const modelPayload = current.model
? { ...current.model, model_id: current.model.model_id }
: { ...current, model_id: current.model_id ?? modelName }
const payload = {
...modelPayload,
enable_as: enableAs,
const selected = candidates.find(m => getModelId(m) === modelId)
if (!selected) {
console.log(chalk.red('Selected model not found.'))
process.exit(1)
}
const spinner = ora('Updating model...').start()
const payload: Partial<Settings> = {}
if (enableAs === 'chat') payload.chat_model_id = getModelId(selected)
if (enableAs === 'memory') payload.memory_model_id = getModelId(selected)
if (enableAs === 'embedding') payload.embedding_model_id = getModelId(selected)
const spinner = ora('Updating settings...').start()
try {
await apiRequest(`/models/model/${encodeURIComponent(modelName)}`, {
method: 'PUT',
body: JSON.stringify(payload),
}, token)
await apiRequest('/settings', { method: 'PUT', body: JSON.stringify(payload) }, token)
spinner.succeed('Model enabled')
} catch (err: unknown) {
spinner.fail(getErrorMessage(err) || 'Failed to enable model')