diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 3ad8df3d..f3eadaf5 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -2,7 +2,6 @@ package main import ( "context" - "fmt" "log" "os" "strings" @@ -42,10 +41,10 @@ func (e *resolverTextEmbedder) Dimensions() int { return e.dims } -func collectEmbeddingVectors(ctx context.Context, service *models.Service) (map[string]int, models.GetResponse, models.GetResponse, error) { +func collectEmbeddingVectors(ctx context.Context, service *models.Service) (map[string]int, models.GetResponse, models.GetResponse, bool, error) { candidates, err := service.ListByType(ctx, models.ModelTypeEmbedding) if err != nil { - return nil, models.GetResponse{}, models.GetResponse{}, err + return nil, models.GetResponse{}, models.GetResponse{}, false, err } vectors := map[string]int{} var textModel models.GetResponse @@ -64,13 +63,12 @@ func collectEmbeddingVectors(ctx context.Context, service *models.Service) (map[ textModel = model } } - if textModel.ModelID == "" { - return vectors, textModel, multimodalModel, fmt.Errorf("no text embedding model configured") - } - if multimodalModel.ModelID == "" { - return vectors, textModel, multimodalModel, fmt.Errorf("no multimodal embedding model configured") - } - return vectors, textModel, multimodalModel, nil + + hasTextModel := textModel.ModelID != "" + hasMultimodalModel := multimodalModel.ModelID != "" + hasAnyModel := hasTextModel || hasMultimodalModel + + return vectors, textModel, multimodalModel, hasAnyModel, nil } func main() { @@ -122,44 +120,64 @@ func main() { time.Duration(cfg.Memory.TimeoutSeconds)*time.Second, ) resolver := embeddings.NewResolver(modelsService, queries, 10*time.Second) - vectors, textModel, multimodalModel, err := collectEmbeddingVectors(ctx, modelsService) + vectors, textModel, multimodalModel, hasModels, err := collectEmbeddingVectors(ctx, modelsService) if err != nil { log.Fatalf("embedding models: %v", err) } - if textModel.Dimensions <= 0 { - log.Fatalf("text embedding dimensions not configured") - } - textEmbedder := &resolverTextEmbedder{ - resolver: resolver, - modelID: textModel.ModelID, - dims: textModel.Dimensions, - } - var store *memory.QdrantStore - if len(vectors) > 0 { - store, err = memory.NewQdrantStoreWithVectors( - cfg.Qdrant.BaseURL, - cfg.Qdrant.APIKey, - cfg.Qdrant.Collection, - vectors, - time.Duration(cfg.Qdrant.TimeoutSeconds)*time.Second, - ) - if err != nil { - log.Fatalf("qdrant named vectors init: %v", err) - } + + var memoryService *memory.Service + var memoryHandler *handlers.MemoryHandler + + if !hasModels { + log.Println("WARNING: No embedding models configured. Memory service will not be available.") + log.Println("You can add embedding models via the /models API endpoint.") + memoryHandler = handlers.NewMemoryHandler(nil) } else { - store, err = memory.NewQdrantStore( - cfg.Qdrant.BaseURL, - cfg.Qdrant.APIKey, - cfg.Qdrant.Collection, - textModel.Dimensions, - time.Duration(cfg.Qdrant.TimeoutSeconds)*time.Second, - ) - if err != nil { - log.Fatalf("qdrant init: %v", err) + if textModel.ModelID == "" { + log.Println("WARNING: No text embedding model configured. Text embedding features will be limited.") } + if multimodalModel.ModelID == "" { + log.Println("WARNING: No multimodal embedding model configured. Multimodal embedding features will be limited.") + } + + var textEmbedder embeddings.Embedder + var store *memory.QdrantStore + + if textModel.ModelID != "" && textModel.Dimensions > 0 { + textEmbedder = &resolverTextEmbedder{ + resolver: resolver, + modelID: textModel.ModelID, + dims: textModel.Dimensions, + } + + if len(vectors) > 0 { + store, err = memory.NewQdrantStoreWithVectors( + cfg.Qdrant.BaseURL, + cfg.Qdrant.APIKey, + cfg.Qdrant.Collection, + vectors, + time.Duration(cfg.Qdrant.TimeoutSeconds)*time.Second, + ) + if err != nil { + log.Fatalf("qdrant named vectors init: %v", err) + } + } else { + store, err = memory.NewQdrantStore( + cfg.Qdrant.BaseURL, + cfg.Qdrant.APIKey, + cfg.Qdrant.Collection, + textModel.Dimensions, + time.Duration(cfg.Qdrant.TimeoutSeconds)*time.Second, + ) + if err != nil { + log.Fatalf("qdrant init: %v", err) + } + } + } + + memoryService = memory.NewService(llmClient, textEmbedder, store, resolver, textModel.ModelID, multimodalModel.ModelID) + memoryHandler = handlers.NewMemoryHandler(memoryService) } - memoryService := memory.NewService(llmClient, textEmbedder, store, resolver, textModel.ModelID, multimodalModel.ModelID) - memoryHandler := handlers.NewMemoryHandler(memoryService) embeddingsHandler := handlers.NewEmbeddingsHandler(modelsService, queries) fsHandler := handlers.NewFSHandler(service, manager, cfg.MCP, cfg.Containerd.Namespace) swaggerHandler := handlers.NewSwaggerHandler() diff --git a/db/migrations/0001_init.down.sql b/db/migrations/0001_init.down.sql index 0117350a..1eb747e0 100644 --- a/db/migrations/0001_init.down.sql +++ b/db/migrations/0001_init.down.sql @@ -1,5 +1,6 @@ DROP TABLE IF EXISTS lifecycle_events; DROP TABLE IF EXISTS container_versions; +DROP TABLE IF EXISTS models; DROP TABLE IF EXISTS snapshots; DROP TABLE IF EXISTS containers; DROP TABLE IF EXISTS users; diff --git a/db/migrations/0001_init.up.sql b/db/migrations/0001_init.up.sql index eebaccd8..4bd351d5 100644 --- a/db/migrations/0001_init.up.sql +++ b/db/migrations/0001_init.up.sql @@ -79,9 +79,14 @@ CREATE TABLE IF NOT EXISTS models ( dimensions INTEGER, is_multimodal BOOLEAN NOT NULL DEFAULT false, type TEXT NOT NULL DEFAULT 'chat', + enable_as TEXT, created_at TIMESTAMPTZ NOT NULL DEFAULT now(), updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), CONSTRAINT models_model_id_unique UNIQUE (model_id), + CONSTRAINT models_enable_as_check CHECK ( + (type = 'embedding' AND (enable_as = 'embedding' OR enable_as IS NULL)) OR + (type = 'chat' AND (enable_as IN ('chat', 'memory') OR enable_as IS NULL)) + ), CONSTRAINT models_type_check CHECK (type IN ('chat', 'embedding')), CONSTRAINT models_dimensions_check CHECK (type != 'embedding' OR dimensions IS NOT NULL) ); @@ -99,6 +104,10 @@ CREATE TABLE IF NOT EXISTS model_variants ( CREATE INDEX IF NOT EXISTS idx_model_variants_model_uuid ON model_variants(model_uuid); CREATE INDEX IF NOT EXISTS idx_model_variants_variant_id ON model_variants(variant_id); +CREATE UNIQUE INDEX IF NOT EXISTS idx_models_enable_as_unique ON models(enable_as) WHERE enable_as IS NOT NULL; + +CREATE INDEX IF NOT EXISTS idx_snapshots_container_id ON snapshots(container_id); +CREATE INDEX IF NOT EXISTS idx_snapshots_parent_id ON snapshots(parent_snapshot_id); CREATE TABLE IF NOT EXISTS container_versions ( id TEXT PRIMARY KEY, diff --git a/db/queries/models.sql b/db/queries/models.sql index fe444bbb..bc829107 100644 --- a/db/queries/models.sql +++ b/db/queries/models.sql @@ -46,14 +46,15 @@ SELECT COUNT(*) FROM llm_providers; SELECT COUNT(*) FROM llm_providers WHERE client_type = sqlc.arg(client_type); -- name: CreateModel :one -INSERT INTO models (model_id, name, llm_provider_id, dimensions, is_multimodal, type) +INSERT INTO models (model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as) VALUES ( sqlc.arg(model_id), sqlc.arg(name), sqlc.arg(llm_provider_id), sqlc.arg(dimensions), sqlc.arg(is_multimodal), - sqlc.arg(type) + sqlc.arg(type), + sqlc.arg(enable_as) ) RETURNING *; @@ -86,6 +87,7 @@ SET dimensions = sqlc.arg(dimensions), is_multimodal = sqlc.arg(is_multimodal), type = sqlc.arg(type), + enable_as = sqlc.arg(enable_as), updated_at = now() WHERE id = sqlc.arg(id) RETURNING *; @@ -98,6 +100,7 @@ SET dimensions = sqlc.arg(dimensions), is_multimodal = sqlc.arg(is_multimodal), type = sqlc.arg(type), + enable_as = sqlc.arg(enable_as), updated_at = now() WHERE model_id = sqlc.arg(model_id) RETURNING *; @@ -114,6 +117,14 @@ SELECT COUNT(*) FROM models; -- name: CountModelsByType :one SELECT COUNT(*) FROM models WHERE type = sqlc.arg(type); +-- name: GetModelByEnableAs :one +SELECT * FROM models WHERE enable_as = sqlc.arg(enable_as) LIMIT 1; + +-- name: ClearEnableAs :exec +UPDATE models +SET enable_as = NULL, updated_at = now() +WHERE enable_as = sqlc.arg(enable_as); + -- name: CreateModelVariant :one INSERT INTO model_variants (model_uuid, variant_id, weight, metadata) VALUES ( diff --git a/docs/docs.go b/docs/docs.go index 4d7b476d..4c57d603 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -818,6 +818,50 @@ const docTemplate = `{ } } }, + "/models/enable-as/{enableAs}": { + "get": { + "description": "Get the model that is enabled for a specific purpose (chat, memory, embedding)", + "tags": [ + "models" + ], + "summary": "Get model by enable_as", + "parameters": [ + { + "type": "string", + "description": "Enable as value (chat, memory, embedding)", + "name": "enableAs", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/models.GetResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, "/models/model/{modelId}": { "get": { "description": "Get a model configuration by its model_id field (e.g., gpt-4)", @@ -1556,6 +1600,9 @@ const docTemplate = `{ "dimensions": { "type": "integer" }, + "enable_as": { + "$ref": "#/definitions/models.EnableAs" + }, "is_multimodal": { "type": "boolean" }, @@ -1592,12 +1639,28 @@ const docTemplate = `{ } } }, + "models.EnableAs": { + "type": "string", + "enum": [ + "chat", + "memory", + "embedding" + ], + "x-enum-varnames": [ + "EnableAsChat", + "EnableAsMemory", + "EnableAsEmbedding" + ] + }, "models.GetResponse": { "type": "object", "properties": { "dimensions": { "type": "integer" }, + "enable_as": { + "$ref": "#/definitions/models.EnableAs" + }, "is_multimodal": { "type": "boolean" }, @@ -1632,6 +1695,9 @@ const docTemplate = `{ "dimensions": { "type": "integer" }, + "enable_as": { + "$ref": "#/definitions/models.EnableAs" + }, "is_multimodal": { "type": "boolean" }, diff --git a/docs/swagger.json b/docs/swagger.json index ec81d649..c68a99e4 100644 --- a/docs/swagger.json +++ b/docs/swagger.json @@ -807,6 +807,50 @@ } } }, + "/models/enable-as/{enableAs}": { + "get": { + "description": "Get the model that is enabled for a specific purpose (chat, memory, embedding)", + "tags": [ + "models" + ], + "summary": "Get model by enable_as", + "parameters": [ + { + "type": "string", + "description": "Enable as value (chat, memory, embedding)", + "name": "enableAs", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/models.GetResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, "/models/model/{modelId}": { "get": { "description": "Get a model configuration by its model_id field (e.g., gpt-4)", @@ -1545,6 +1589,9 @@ "dimensions": { "type": "integer" }, + "enable_as": { + "$ref": "#/definitions/models.EnableAs" + }, "is_multimodal": { "type": "boolean" }, @@ -1581,12 +1628,28 @@ } } }, + "models.EnableAs": { + "type": "string", + "enum": [ + "chat", + "memory", + "embedding" + ], + "x-enum-varnames": [ + "EnableAsChat", + "EnableAsMemory", + "EnableAsEmbedding" + ] + }, "models.GetResponse": { "type": "object", "properties": { "dimensions": { "type": "integer" }, + "enable_as": { + "$ref": "#/definitions/models.EnableAs" + }, "is_multimodal": { "type": "boolean" }, @@ -1621,6 +1684,9 @@ "dimensions": { "type": "integer" }, + "enable_as": { + "$ref": "#/definitions/models.EnableAs" + }, "is_multimodal": { "type": "boolean" }, diff --git a/docs/swagger.yaml b/docs/swagger.yaml index e7a88ec5..1a103d34 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -305,6 +305,8 @@ definitions: properties: dimensions: type: integer + enable_as: + $ref: '#/definitions/models.EnableAs' is_multimodal: type: boolean llm_provider_id: @@ -328,10 +330,22 @@ definitions: count: type: integer type: object + models.EnableAs: + enum: + - chat + - memory + - embedding + type: string + x-enum-varnames: + - EnableAsChat + - EnableAsMemory + - EnableAsEmbedding models.GetResponse: properties: dimensions: type: integer + enable_as: + $ref: '#/definitions/models.EnableAs' is_multimodal: type: boolean llm_provider_id: @@ -355,6 +369,8 @@ definitions: properties: dimensions: type: integer + enable_as: + $ref: '#/definitions/models.EnableAs' is_multimodal: type: boolean llm_provider_id: @@ -984,6 +1000,36 @@ paths: summary: Get model count tags: - models + /models/enable-as/{enableAs}: + get: + description: Get the model that is enabled for a specific purpose (chat, memory, + embedding) + parameters: + - description: Enable as value (chat, memory, embedding) + in: path + name: enableAs + required: true + type: string + responses: + "200": + description: OK + schema: + $ref: '#/definitions/models.GetResponse' + "400": + description: Bad Request + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "404": + description: Not Found + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/handlers.ErrorResponse' + summary: Get model by enable_as + tags: + - models /models/model/{modelId}: delete: description: Delete a model configuration by its model_id field (e.g., gpt-4) diff --git a/internal/db/sqlc/models.go b/internal/db/sqlc/models.go index 85eee040..3714b95c 100644 --- a/internal/db/sqlc/models.go +++ b/internal/db/sqlc/models.go @@ -60,6 +60,7 @@ type Model struct { Dimensions pgtype.Int4 `json:"dimensions"` IsMultimodal bool `json:"is_multimodal"` Type string `json:"type"` + EnableAs pgtype.Text `json:"enable_as"` CreatedAt pgtype.Timestamptz `json:"created_at"` UpdatedAt pgtype.Timestamptz `json:"updated_at"` } diff --git a/internal/db/sqlc/models.sql.go b/internal/db/sqlc/models.sql.go index e66c3ce2..b5b36426 100644 --- a/internal/db/sqlc/models.sql.go +++ b/internal/db/sqlc/models.sql.go @@ -11,6 +11,17 @@ import ( "github.com/jackc/pgx/v5/pgtype" ) +const clearEnableAs = `-- name: ClearEnableAs :exec +UPDATE models +SET enable_as = NULL, updated_at = now() +WHERE enable_as = $1 +` + +func (q *Queries) ClearEnableAs(ctx context.Context, enableAs pgtype.Text) error { + _, err := q.db.Exec(ctx, clearEnableAs, enableAs) + return err +} + const countLlmProviders = `-- name: CountLlmProviders :one SELECT COUNT(*) FROM llm_providers ` @@ -98,16 +109,17 @@ func (q *Queries) CreateLlmProvider(ctx context.Context, arg CreateLlmProviderPa } const createModel = `-- name: CreateModel :one -INSERT INTO models (model_id, name, llm_provider_id, dimensions, is_multimodal, type) +INSERT INTO models (model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as) VALUES ( $1, $2, $3, $4, $5, - $6 + $6, + $7 ) -RETURNING id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at +RETURNING id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at ` type CreateModelParams struct { @@ -117,6 +129,7 @@ type CreateModelParams struct { Dimensions pgtype.Int4 `json:"dimensions"` IsMultimodal bool `json:"is_multimodal"` Type string `json:"type"` + EnableAs pgtype.Text `json:"enable_as"` } func (q *Queries) CreateModel(ctx context.Context, arg CreateModelParams) (Model, error) { @@ -127,6 +140,7 @@ func (q *Queries) CreateModel(ctx context.Context, arg CreateModelParams) (Model arg.Dimensions, arg.IsMultimodal, arg.Type, + arg.EnableAs, ) var i Model err := row.Scan( @@ -137,6 +151,7 @@ func (q *Queries) CreateModel(ctx context.Context, arg CreateModelParams) (Model &i.Dimensions, &i.IsMultimodal, &i.Type, + &i.EnableAs, &i.CreatedAt, &i.UpdatedAt, ) @@ -257,8 +272,30 @@ func (q *Queries) GetLlmProviderByName(ctx context.Context, name string) (LlmPro return i, err } +const getModelByEnableAs = `-- name: GetModelByEnableAs :one +SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at FROM models WHERE enable_as = $1 LIMIT 1 +` + +func (q *Queries) GetModelByEnableAs(ctx context.Context, enableAs pgtype.Text) (Model, error) { + row := q.db.QueryRow(ctx, getModelByEnableAs, enableAs) + var i Model + err := row.Scan( + &i.ID, + &i.ModelID, + &i.Name, + &i.LlmProviderID, + &i.Dimensions, + &i.IsMultimodal, + &i.Type, + &i.EnableAs, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + const getModelByID = `-- name: GetModelByID :one -SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at FROM models WHERE id = $1 +SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at FROM models WHERE id = $1 ` func (q *Queries) GetModelByID(ctx context.Context, id pgtype.UUID) (Model, error) { @@ -272,6 +309,7 @@ func (q *Queries) GetModelByID(ctx context.Context, id pgtype.UUID) (Model, erro &i.Dimensions, &i.IsMultimodal, &i.Type, + &i.EnableAs, &i.CreatedAt, &i.UpdatedAt, ) @@ -279,7 +317,7 @@ func (q *Queries) GetModelByID(ctx context.Context, id pgtype.UUID) (Model, erro } const getModelByModelID = `-- name: GetModelByModelID :one -SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at FROM models WHERE model_id = $1 +SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at FROM models WHERE model_id = $1 ` func (q *Queries) GetModelByModelID(ctx context.Context, modelID string) (Model, error) { @@ -293,6 +331,7 @@ func (q *Queries) GetModelByModelID(ctx context.Context, modelID string) (Model, &i.Dimensions, &i.IsMultimodal, &i.Type, + &i.EnableAs, &i.CreatedAt, &i.UpdatedAt, ) @@ -456,7 +495,7 @@ func (q *Queries) ListModelVariantsByVariantID(ctx context.Context, variantID st } const listModels = `-- name: ListModels :many -SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at FROM models +SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at FROM models ORDER BY created_at DESC ` @@ -477,6 +516,7 @@ func (q *Queries) ListModels(ctx context.Context) ([]Model, error) { &i.Dimensions, &i.IsMultimodal, &i.Type, + &i.EnableAs, &i.CreatedAt, &i.UpdatedAt, ); err != nil { @@ -491,7 +531,7 @@ func (q *Queries) ListModels(ctx context.Context) ([]Model, error) { } const listModelsByClientType = `-- name: ListModelsByClientType :many -SELECT m.id, m.model_id, m.name, m.llm_provider_id, m.dimensions, m.is_multimodal, m.type, m.created_at, m.updated_at FROM models AS m +SELECT m.id, m.model_id, m.name, m.llm_provider_id, m.dimensions, m.is_multimodal, m.type, m.enable_as, m.created_at, m.updated_at FROM models AS m JOIN llm_providers AS p ON p.id = m.llm_provider_id WHERE p.client_type = $1 ORDER BY m.created_at DESC @@ -514,6 +554,7 @@ func (q *Queries) ListModelsByClientType(ctx context.Context, clientType string) &i.Dimensions, &i.IsMultimodal, &i.Type, + &i.EnableAs, &i.CreatedAt, &i.UpdatedAt, ); err != nil { @@ -528,7 +569,7 @@ func (q *Queries) ListModelsByClientType(ctx context.Context, clientType string) } const listModelsByType = `-- name: ListModelsByType :many -SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at FROM models +SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at FROM models WHERE type = $1 ORDER BY created_at DESC ` @@ -550,6 +591,7 @@ func (q *Queries) ListModelsByType(ctx context.Context, type_ string) ([]Model, &i.Dimensions, &i.IsMultimodal, &i.Type, + &i.EnableAs, &i.CreatedAt, &i.UpdatedAt, ); err != nil { @@ -616,9 +658,10 @@ SET dimensions = $3, is_multimodal = $4, type = $5, + enable_as = $6, updated_at = now() -WHERE id = $6 -RETURNING id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at +WHERE id = $7 +RETURNING id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at ` type UpdateModelParams struct { @@ -627,6 +670,7 @@ type UpdateModelParams struct { Dimensions pgtype.Int4 `json:"dimensions"` IsMultimodal bool `json:"is_multimodal"` Type string `json:"type"` + EnableAs pgtype.Text `json:"enable_as"` ID pgtype.UUID `json:"id"` } @@ -637,6 +681,7 @@ func (q *Queries) UpdateModel(ctx context.Context, arg UpdateModelParams) (Model arg.Dimensions, arg.IsMultimodal, arg.Type, + arg.EnableAs, arg.ID, ) var i Model @@ -648,6 +693,7 @@ func (q *Queries) UpdateModel(ctx context.Context, arg UpdateModelParams) (Model &i.Dimensions, &i.IsMultimodal, &i.Type, + &i.EnableAs, &i.CreatedAt, &i.UpdatedAt, ) @@ -662,9 +708,10 @@ SET dimensions = $3, is_multimodal = $4, type = $5, + enable_as = $6, updated_at = now() -WHERE model_id = $6 -RETURNING id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at +WHERE model_id = $7 +RETURNING id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at ` type UpdateModelByModelIDParams struct { @@ -673,6 +720,7 @@ type UpdateModelByModelIDParams struct { Dimensions pgtype.Int4 `json:"dimensions"` IsMultimodal bool `json:"is_multimodal"` Type string `json:"type"` + EnableAs pgtype.Text `json:"enable_as"` ModelID string `json:"model_id"` } @@ -683,6 +731,7 @@ func (q *Queries) UpdateModelByModelID(ctx context.Context, arg UpdateModelByMod arg.Dimensions, arg.IsMultimodal, arg.Type, + arg.EnableAs, arg.ModelID, ) var i Model @@ -694,6 +743,7 @@ func (q *Queries) UpdateModelByModelID(ctx context.Context, arg UpdateModelByMod &i.Dimensions, &i.IsMultimodal, &i.Type, + &i.EnableAs, &i.CreatedAt, &i.UpdatedAt, ) diff --git a/internal/embeddings/resolver.go b/internal/embeddings/resolver.go index 24af20cf..44373afb 100644 --- a/internal/embeddings/resolver.go +++ b/internal/embeddings/resolver.go @@ -173,6 +173,23 @@ func (r *Resolver) selectEmbeddingModel(ctx context.Context, req Request) (model return models.GetResponse{}, errors.New("models service not configured") } + // If no model specified and no provider specified, try to get default embedding model + if req.Model == "" && req.Provider == "" { + defaultModel, err := r.modelsService.GetByEnableAs(ctx, models.EnableAsEmbedding) + if err == nil { + // Found default model, check if it matches the type requirement + if req.Type == TypeMultimodal && !defaultModel.IsMultimodal { + // Default is text, but need multimodal - continue to search + } else if req.Type == TypeText && defaultModel.IsMultimodal { + // Default is multimodal, but need text - continue to search + } else { + // Default model matches requirements + return defaultModel, nil + } + } + // No default model or doesn't match requirements, continue to search + } + var candidates []models.GetResponse var err error if req.Provider != "" { diff --git a/internal/handlers/memory.go b/internal/handlers/memory.go index 8218ca38..0eac6ad5 100644 --- a/internal/handlers/memory.go +++ b/internal/handlers/memory.go @@ -31,6 +31,13 @@ func (h *MemoryHandler) Register(e *echo.Echo) { group.DELETE("/memories", h.DeleteAll) } +func (h *MemoryHandler) checkService() error { + if h.service == nil { + return echo.NewHTTPError(http.StatusServiceUnavailable, "memory service not available: no embedding models configured") + } + return nil +} + // EmbedUpsert godoc // @Summary Embed and upsert memory // @Description Embed text or multimodal input and upsert into memory store @@ -41,6 +48,10 @@ func (h *MemoryHandler) Register(e *echo.Echo) { // @Failure 500 {object} ErrorResponse // @Router /memory/embed [post] func (h *MemoryHandler) EmbedUpsert(c echo.Context) error { + if err := h.checkService(); err != nil { + return err + } + userID, err := h.requireUserID(c) if err != nil { return err @@ -72,6 +83,10 @@ func (h *MemoryHandler) EmbedUpsert(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /memory/add [post] func (h *MemoryHandler) Add(c echo.Context) error { + if err := h.checkService(); err != nil { + return err + } + userID, err := h.requireUserID(c) if err != nil { return err @@ -103,6 +118,10 @@ func (h *MemoryHandler) Add(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /memory/search [post] func (h *MemoryHandler) Search(c echo.Context) error { + if err := h.checkService(); err != nil { + return err + } + userID, err := h.requireUserID(c) if err != nil { return err @@ -134,6 +153,10 @@ func (h *MemoryHandler) Search(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /memory/update [post] func (h *MemoryHandler) Update(c echo.Context) error { + if err := h.checkService(); err != nil { + return err + } + userID, err := h.requireUserID(c) if err != nil { return err @@ -170,6 +193,10 @@ func (h *MemoryHandler) Update(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /memory/memories/{memoryId} [get] func (h *MemoryHandler) Get(c echo.Context) error { + if err := h.checkService(); err != nil { + return err + } + userID, err := h.requireUserID(c) if err != nil { return err @@ -203,6 +230,10 @@ func (h *MemoryHandler) Get(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /memory/memories [get] func (h *MemoryHandler) GetAll(c echo.Context) error { + if err := h.checkService(); err != nil { + return err + } + userID, err := h.requireUserID(c) if err != nil { return err @@ -240,6 +271,10 @@ func (h *MemoryHandler) GetAll(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /memory/memories/{memoryId} [delete] func (h *MemoryHandler) Delete(c echo.Context) error { + if err := h.checkService(); err != nil { + return err + } + userID, err := h.requireUserID(c) if err != nil { return err @@ -275,6 +310,10 @@ func (h *MemoryHandler) Delete(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /memory/memories [delete] func (h *MemoryHandler) DeleteAll(c echo.Context) error { + if err := h.checkService(); err != nil { + return err + } + userID, err := h.requireUserID(c) if err != nil { return err diff --git a/internal/handlers/models.go b/internal/handlers/models.go index 49484f2e..6c654cad 100644 --- a/internal/handlers/models.go +++ b/internal/handlers/models.go @@ -22,6 +22,7 @@ func (h *ModelsHandler) Register(e *echo.Echo) { group.GET("", h.List) group.GET("/:id", h.GetByID) group.GET("/model/:modelId", h.GetByModelID) + group.GET("/enable-as/:enableAs", h.GetByEnableAs) group.PUT("/:id", h.UpdateByID) group.PUT("/model/:modelId", h.UpdateByModelID) group.DELETE("/:id", h.DeleteByID) @@ -230,6 +231,38 @@ func (h *ModelsHandler) DeleteByModelID(c echo.Context) error { return c.NoContent(http.StatusNoContent) } +// GetByEnableAs godoc +// @Summary Get model by enable_as +// @Description Get the model that is enabled for a specific purpose (chat, memory, embedding) +// @Tags models +// @Param enableAs path string true "Enable as value (chat, memory, embedding)" +// @Success 200 {object} models.GetResponse +// @Failure 400 {object} ErrorResponse +// @Failure 404 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /models/enable-as/{enableAs} [get] +// GetByEnableAs godoc +// @Summary Get default model by enable_as +// @Description Get the default model configured for a specific purpose (chat, memory, or embedding) +// @Tags models +// @Param enableAs path string true "Enable as value (chat, memory, embedding)" +// @Success 200 {object} models.GetResponse +// @Failure 400 {object} ErrorResponse +// @Failure 404 {object} ErrorResponse +// @Router /models/enable-as/{enableAs} [get] +func (h *ModelsHandler) GetByEnableAs(c echo.Context) error { + enableAs := c.Param("enableAs") + if enableAs == "" { + return echo.NewHTTPError(http.StatusBadRequest, "enableAs is required") + } + + resp, err := h.service.GetByEnableAs(c.Request().Context(), models.EnableAs(enableAs)) + if err != nil { + return echo.NewHTTPError(http.StatusNotFound, err.Error()) + } + return c.JSON(http.StatusOK, resp) +} + // Count godoc // @Summary Get model count // @Description Get the total count of models, optionally filtered by type diff --git a/internal/models/models.go b/internal/models/models.go index a3c8f713..6095522e 100644 --- a/internal/models/models.go +++ b/internal/models/models.go @@ -28,6 +28,13 @@ func (s *Service) Create(ctx context.Context, req AddRequest) (AddResponse, erro return AddResponse{}, fmt.Errorf("validation failed: %w", err) } + // If enable_as is set, clear any existing model with the same enable_as + if model.EnableAs != nil { + if err := s.queries.ClearEnableAs(ctx, pgtype.Text{String: string(*model.EnableAs), Valid: true}); err != nil { + return AddResponse{}, fmt.Errorf("failed to clear existing enable_as: %w", err) + } + } + // Convert to sqlc params llmProviderID, err := parseUUID(model.LlmProviderID) if err != nil { @@ -51,6 +58,11 @@ func (s *Service) Create(ctx context.Context, req AddRequest) (AddResponse, erro params.Dimensions = pgtype.Int4{Int32: int32(model.Dimensions), Valid: true} } + // Handle optional enable_as field + if model.EnableAs != nil { + params.EnableAs = pgtype.Text{String: string(*model.EnableAs), Valid: true} + } + created, err := s.queries.CreateModel(ctx, params) if err != nil { return AddResponse{}, fmt.Errorf("failed to create model: %w", err) @@ -151,6 +163,13 @@ func (s *Service) UpdateByID(ctx context.Context, id string, req UpdateRequest) return GetResponse{}, fmt.Errorf("validation failed: %w", err) } + // If enable_as is being set, clear any existing model with the same enable_as + if model.EnableAs != nil { + if err := s.queries.ClearEnableAs(ctx, pgtype.Text{String: string(*model.EnableAs), Valid: true}); err != nil { + return GetResponse{}, fmt.Errorf("failed to clear existing enable_as: %w", err) + } + } + params := sqlc.UpdateModelParams{ ID: uuid, IsMultimodal: model.IsMultimodal, @@ -171,6 +190,11 @@ func (s *Service) UpdateByID(ctx context.Context, id string, req UpdateRequest) params.Dimensions = pgtype.Int4{Int32: int32(model.Dimensions), Valid: true} } + // Handle optional enable_as field + if model.EnableAs != nil { + params.EnableAs = pgtype.Text{String: string(*model.EnableAs), Valid: true} + } + updated, err := s.queries.UpdateModel(ctx, params) if err != nil { return GetResponse{}, fmt.Errorf("failed to update model: %w", err) @@ -190,10 +214,17 @@ func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req Updat return GetResponse{}, fmt.Errorf("validation failed: %w", err) } + // If enable_as is being set, clear any existing model with the same enable_as + if model.EnableAs != nil { + if err := s.queries.ClearEnableAs(ctx, pgtype.Text{String: string(*model.EnableAs), Valid: true}); err != nil { + return GetResponse{}, fmt.Errorf("failed to clear existing enable_as: %w", err) + } + } + params := sqlc.UpdateModelByModelIDParams{ - ModelID: modelID, - IsMultimodal: model.IsMultimodal, - Type: string(model.Type), + ModelID: modelID, + IsMultimodal: model.IsMultimodal, + Type: string(model.Type), } llmProviderID, err := parseUUID(model.LlmProviderID) @@ -210,6 +241,11 @@ func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req Updat params.Dimensions = pgtype.Int4{Int32: int32(model.Dimensions), Valid: true} } + // Handle optional enable_as field + if model.EnableAs != nil { + params.EnableAs = pgtype.Text{String: string(*model.EnableAs), Valid: true} + } + updated, err := s.queries.UpdateModelByModelID(ctx, params) if err != nil { return GetResponse{}, fmt.Errorf("failed to update model: %w", err) @@ -267,6 +303,20 @@ func (s *Service) CountByType(ctx context.Context, modelType ModelType) (int64, return count, nil } +// GetByEnableAs retrieves the model that has the specified enable_as value +func (s *Service) GetByEnableAs(ctx context.Context, enableAs EnableAs) (GetResponse, error) { + if enableAs != EnableAsChat && enableAs != EnableAsMemory && enableAs != EnableAsEmbedding { + return GetResponse{}, fmt.Errorf("invalid enable_as value: %s", enableAs) + } + + dbModel, err := s.queries.GetModelByEnableAs(ctx, pgtype.Text{String: string(enableAs), Valid: true}) + if err != nil { + return GetResponse{}, fmt.Errorf("failed to get model by enable_as: %w", err) + } + + return convertToGetResponse(dbModel), nil +} + // Helper functions func parseUUID(id string) (pgtype.UUID, error) { @@ -304,6 +354,11 @@ func convertToGetResponse(dbModel sqlc.Model) GetResponse { resp.Model.Dimensions = int(dbModel.Dimensions.Int32) } + if dbModel.EnableAs.Valid { + enableAs := EnableAs(dbModel.EnableAs.String) + resp.Model.EnableAs = &enableAs + } + return resp } diff --git a/internal/models/types.go b/internal/models/types.go index 5066ede3..7bed6c09 100644 --- a/internal/models/types.go +++ b/internal/models/types.go @@ -13,6 +13,14 @@ const ( ModelTypeEmbedding ModelType = "embedding" ) +type EnableAs string + +const ( + EnableAsChat EnableAs = "chat" + EnableAsMemory EnableAs = "memory" + EnableAsEmbedding EnableAs = "embedding" +) + type ClientType string const ( @@ -33,6 +41,7 @@ type Model struct { IsMultimodal bool `json:"is_multimodal"` Type ModelType `json:"type"` Dimensions int `json:"dimensions"` + EnableAs *EnableAs `json:"enable_as,omitempty"` } func (m *Model) Validate() error { @@ -51,6 +60,21 @@ func (m *Model) Validate() error { if m.Type == ModelTypeEmbedding && m.Dimensions <= 0 { return errors.New("dimensions must be greater than 0") } + + // Validate enable_as based on type + if m.EnableAs != nil { + switch m.Type { + case ModelTypeEmbedding: + if *m.EnableAs != EnableAsEmbedding { + return errors.New("embedding models can only have enable_as set to 'embedding'") + } + case ModelTypeChat: + if *m.EnableAs != EnableAsChat && *m.EnableAs != EnableAsMemory { + return errors.New("chat models can only have enable_as set to 'chat' or 'memory'") + } + } + } + return nil }