feat: allow setting default model enable as chat, memory, embedding

This commit is contained in:
Acbox
2026-01-26 21:54:46 +08:00
parent 3ff0e2c4dd
commit 35a8927a79
14 changed files with 495 additions and 59 deletions
+35 -17
View File
@@ -2,7 +2,6 @@ package main
import (
"context"
"fmt"
"log"
"os"
"strings"
@@ -42,10 +41,10 @@ func (e *resolverTextEmbedder) Dimensions() int {
return e.dims
}
func collectEmbeddingVectors(ctx context.Context, service *models.Service) (map[string]int, models.GetResponse, models.GetResponse, error) {
func collectEmbeddingVectors(ctx context.Context, service *models.Service) (map[string]int, models.GetResponse, models.GetResponse, bool, error) {
candidates, err := service.ListByType(ctx, models.ModelTypeEmbedding)
if err != nil {
return nil, models.GetResponse{}, models.GetResponse{}, err
return nil, models.GetResponse{}, models.GetResponse{}, false, err
}
vectors := map[string]int{}
var textModel models.GetResponse
@@ -64,13 +63,12 @@ func collectEmbeddingVectors(ctx context.Context, service *models.Service) (map[
textModel = model
}
}
if textModel.ModelID == "" {
return vectors, textModel, multimodalModel, fmt.Errorf("no text embedding model configured")
}
if multimodalModel.ModelID == "" {
return vectors, textModel, multimodalModel, fmt.Errorf("no multimodal embedding model configured")
}
return vectors, textModel, multimodalModel, nil
hasTextModel := textModel.ModelID != ""
hasMultimodalModel := multimodalModel.ModelID != ""
hasAnyModel := hasTextModel || hasMultimodalModel
return vectors, textModel, multimodalModel, hasAnyModel, nil
}
func main() {
@@ -122,19 +120,36 @@ func main() {
time.Duration(cfg.Memory.TimeoutSeconds)*time.Second,
)
resolver := embeddings.NewResolver(modelsService, queries, 10*time.Second)
vectors, textModel, multimodalModel, err := collectEmbeddingVectors(ctx, modelsService)
vectors, textModel, multimodalModel, hasModels, err := collectEmbeddingVectors(ctx, modelsService)
if err != nil {
log.Fatalf("embedding models: %v", err)
}
if textModel.Dimensions <= 0 {
log.Fatalf("text embedding dimensions not configured")
var memoryService *memory.Service
var memoryHandler *handlers.MemoryHandler
if !hasModels {
log.Println("WARNING: No embedding models configured. Memory service will not be available.")
log.Println("You can add embedding models via the /models API endpoint.")
memoryHandler = handlers.NewMemoryHandler(nil)
} else {
if textModel.ModelID == "" {
log.Println("WARNING: No text embedding model configured. Text embedding features will be limited.")
}
textEmbedder := &resolverTextEmbedder{
if multimodalModel.ModelID == "" {
log.Println("WARNING: No multimodal embedding model configured. Multimodal embedding features will be limited.")
}
var textEmbedder embeddings.Embedder
var store *memory.QdrantStore
if textModel.ModelID != "" && textModel.Dimensions > 0 {
textEmbedder = &resolverTextEmbedder{
resolver: resolver,
modelID: textModel.ModelID,
dims: textModel.Dimensions,
}
var store *memory.QdrantStore
if len(vectors) > 0 {
store, err = memory.NewQdrantStoreWithVectors(
cfg.Qdrant.BaseURL,
@@ -158,8 +173,11 @@ func main() {
log.Fatalf("qdrant init: %v", err)
}
}
memoryService := memory.NewService(llmClient, textEmbedder, store, resolver, textModel.ModelID, multimodalModel.ModelID)
memoryHandler := handlers.NewMemoryHandler(memoryService)
}
memoryService = memory.NewService(llmClient, textEmbedder, store, resolver, textModel.ModelID, multimodalModel.ModelID)
memoryHandler = handlers.NewMemoryHandler(memoryService)
}
embeddingsHandler := handlers.NewEmbeddingsHandler(modelsService, queries)
fsHandler := handlers.NewFSHandler(service, manager, cfg.MCP, cfg.Containerd.Namespace)
swaggerHandler := handlers.NewSwaggerHandler()
+1
View File
@@ -1,5 +1,6 @@
DROP TABLE IF EXISTS lifecycle_events;
DROP TABLE IF EXISTS container_versions;
DROP TABLE IF EXISTS models;
DROP TABLE IF EXISTS snapshots;
DROP TABLE IF EXISTS containers;
DROP TABLE IF EXISTS users;
+9
View File
@@ -79,9 +79,14 @@ CREATE TABLE IF NOT EXISTS models (
dimensions INTEGER,
is_multimodal BOOLEAN NOT NULL DEFAULT false,
type TEXT NOT NULL DEFAULT 'chat',
enable_as TEXT,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
CONSTRAINT models_model_id_unique UNIQUE (model_id),
CONSTRAINT models_enable_as_check CHECK (
(type = 'embedding' AND (enable_as = 'embedding' OR enable_as IS NULL)) OR
(type = 'chat' AND (enable_as IN ('chat', 'memory') OR enable_as IS NULL))
),
CONSTRAINT models_type_check CHECK (type IN ('chat', 'embedding')),
CONSTRAINT models_dimensions_check CHECK (type != 'embedding' OR dimensions IS NOT NULL)
);
@@ -99,6 +104,10 @@ CREATE TABLE IF NOT EXISTS model_variants (
CREATE INDEX IF NOT EXISTS idx_model_variants_model_uuid ON model_variants(model_uuid);
CREATE INDEX IF NOT EXISTS idx_model_variants_variant_id ON model_variants(variant_id);
-- Partial unique index: at most one row in models may hold each non-NULL enable_as value.
CREATE UNIQUE INDEX IF NOT EXISTS idx_models_enable_as_unique ON models(enable_as) WHERE enable_as IS NOT NULL;
CREATE INDEX IF NOT EXISTS idx_snapshots_container_id ON snapshots(container_id);
CREATE INDEX IF NOT EXISTS idx_snapshots_parent_id ON snapshots(parent_snapshot_id);
CREATE TABLE IF NOT EXISTS container_versions (
id TEXT PRIMARY KEY,
+13 -2
View File
@@ -46,14 +46,15 @@ SELECT COUNT(*) FROM llm_providers;
SELECT COUNT(*) FROM llm_providers WHERE client_type = sqlc.arg(client_type);
-- name: CreateModel :one
INSERT INTO models (model_id, name, llm_provider_id, dimensions, is_multimodal, type)
INSERT INTO models (model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as)
VALUES (
sqlc.arg(model_id),
sqlc.arg(name),
sqlc.arg(llm_provider_id),
sqlc.arg(dimensions),
sqlc.arg(is_multimodal),
sqlc.arg(type)
sqlc.arg(type),
sqlc.arg(enable_as)
)
RETURNING *;
@@ -86,6 +87,7 @@ SET
dimensions = sqlc.arg(dimensions),
is_multimodal = sqlc.arg(is_multimodal),
type = sqlc.arg(type),
enable_as = sqlc.arg(enable_as),
updated_at = now()
WHERE id = sqlc.arg(id)
RETURNING *;
@@ -98,6 +100,7 @@ SET
dimensions = sqlc.arg(dimensions),
is_multimodal = sqlc.arg(is_multimodal),
type = sqlc.arg(type),
enable_as = sqlc.arg(enable_as),
updated_at = now()
WHERE model_id = sqlc.arg(model_id)
RETURNING *;
@@ -114,6 +117,14 @@ SELECT COUNT(*) FROM models;
-- name: CountModelsByType :one
SELECT COUNT(*) FROM models WHERE type = sqlc.arg(type);
-- name: GetModelByEnableAs :one
-- Returns one model whose enable_as equals the given value.
SELECT * FROM models WHERE enable_as = sqlc.arg(enable_as) LIMIT 1;
-- name: ClearEnableAs :exec
-- Resets enable_as to NULL (and bumps updated_at) on every model currently holding the given value.
UPDATE models
SET enable_as = NULL, updated_at = now()
WHERE enable_as = sqlc.arg(enable_as);
-- name: CreateModelVariant :one
INSERT INTO model_variants (model_uuid, variant_id, weight, metadata)
VALUES (
+66
View File
@@ -818,6 +818,50 @@ const docTemplate = `{
}
}
},
"/models/enable-as/{enableAs}": {
"get": {
"description": "Get the model that is enabled for a specific purpose (chat, memory, embedding)",
"tags": [
"models"
],
"summary": "Get model by enable_as",
"parameters": [
{
"type": "string",
"description": "Enable as value (chat, memory, embedding)",
"name": "enableAs",
"in": "path",
"required": true
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/models.GetResponse"
}
},
"400": {
"description": "Bad Request",
"schema": {
"$ref": "#/definitions/handlers.ErrorResponse"
}
},
"404": {
"description": "Not Found",
"schema": {
"$ref": "#/definitions/handlers.ErrorResponse"
}
},
"500": {
"description": "Internal Server Error",
"schema": {
"$ref": "#/definitions/handlers.ErrorResponse"
}
}
}
}
},
"/models/model/{modelId}": {
"get": {
"description": "Get a model configuration by its model_id field (e.g., gpt-4)",
@@ -1556,6 +1600,9 @@ const docTemplate = `{
"dimensions": {
"type": "integer"
},
"enable_as": {
"$ref": "#/definitions/models.EnableAs"
},
"is_multimodal": {
"type": "boolean"
},
@@ -1592,12 +1639,28 @@ const docTemplate = `{
}
}
},
"models.EnableAs": {
"type": "string",
"enum": [
"chat",
"memory",
"embedding"
],
"x-enum-varnames": [
"EnableAsChat",
"EnableAsMemory",
"EnableAsEmbedding"
]
},
"models.GetResponse": {
"type": "object",
"properties": {
"dimensions": {
"type": "integer"
},
"enable_as": {
"$ref": "#/definitions/models.EnableAs"
},
"is_multimodal": {
"type": "boolean"
},
@@ -1632,6 +1695,9 @@ const docTemplate = `{
"dimensions": {
"type": "integer"
},
"enable_as": {
"$ref": "#/definitions/models.EnableAs"
},
"is_multimodal": {
"type": "boolean"
},
+66
View File
@@ -807,6 +807,50 @@
}
}
},
"/models/enable-as/{enableAs}": {
"get": {
"description": "Get the model that is enabled for a specific purpose (chat, memory, embedding)",
"tags": [
"models"
],
"summary": "Get model by enable_as",
"parameters": [
{
"type": "string",
"description": "Enable as value (chat, memory, embedding)",
"name": "enableAs",
"in": "path",
"required": true
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/models.GetResponse"
}
},
"400": {
"description": "Bad Request",
"schema": {
"$ref": "#/definitions/handlers.ErrorResponse"
}
},
"404": {
"description": "Not Found",
"schema": {
"$ref": "#/definitions/handlers.ErrorResponse"
}
},
"500": {
"description": "Internal Server Error",
"schema": {
"$ref": "#/definitions/handlers.ErrorResponse"
}
}
}
}
},
"/models/model/{modelId}": {
"get": {
"description": "Get a model configuration by its model_id field (e.g., gpt-4)",
@@ -1545,6 +1589,9 @@
"dimensions": {
"type": "integer"
},
"enable_as": {
"$ref": "#/definitions/models.EnableAs"
},
"is_multimodal": {
"type": "boolean"
},
@@ -1581,12 +1628,28 @@
}
}
},
"models.EnableAs": {
"type": "string",
"enum": [
"chat",
"memory",
"embedding"
],
"x-enum-varnames": [
"EnableAsChat",
"EnableAsMemory",
"EnableAsEmbedding"
]
},
"models.GetResponse": {
"type": "object",
"properties": {
"dimensions": {
"type": "integer"
},
"enable_as": {
"$ref": "#/definitions/models.EnableAs"
},
"is_multimodal": {
"type": "boolean"
},
@@ -1621,6 +1684,9 @@
"dimensions": {
"type": "integer"
},
"enable_as": {
"$ref": "#/definitions/models.EnableAs"
},
"is_multimodal": {
"type": "boolean"
},
+46
View File
@@ -305,6 +305,8 @@ definitions:
properties:
dimensions:
type: integer
enable_as:
$ref: '#/definitions/models.EnableAs'
is_multimodal:
type: boolean
llm_provider_id:
@@ -328,10 +330,22 @@ definitions:
count:
type: integer
type: object
models.EnableAs:
enum:
- chat
- memory
- embedding
type: string
x-enum-varnames:
- EnableAsChat
- EnableAsMemory
- EnableAsEmbedding
models.GetResponse:
properties:
dimensions:
type: integer
enable_as:
$ref: '#/definitions/models.EnableAs'
is_multimodal:
type: boolean
llm_provider_id:
@@ -355,6 +369,8 @@ definitions:
properties:
dimensions:
type: integer
enable_as:
$ref: '#/definitions/models.EnableAs'
is_multimodal:
type: boolean
llm_provider_id:
@@ -984,6 +1000,36 @@ paths:
summary: Get model count
tags:
- models
/models/enable-as/{enableAs}:
get:
description: Get the model that is enabled for a specific purpose (chat, memory,
embedding)
parameters:
- description: Enable as value (chat, memory, embedding)
in: path
name: enableAs
required: true
type: string
responses:
"200":
description: OK
schema:
$ref: '#/definitions/models.GetResponse'
"400":
description: Bad Request
schema:
$ref: '#/definitions/handlers.ErrorResponse'
"404":
description: Not Found
schema:
$ref: '#/definitions/handlers.ErrorResponse'
"500":
description: Internal Server Error
schema:
$ref: '#/definitions/handlers.ErrorResponse'
summary: Get model by enable_as
tags:
- models
/models/model/{modelId}:
delete:
description: Delete a model configuration by its model_id field (e.g., gpt-4)
+1
View File
@@ -60,6 +60,7 @@ type Model struct {
Dimensions pgtype.Int4 `json:"dimensions"`
IsMultimodal bool `json:"is_multimodal"`
Type string `json:"type"`
EnableAs pgtype.Text `json:"enable_as"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
}
+62 -12
View File
@@ -11,6 +11,17 @@ import (
"github.com/jackc/pgx/v5/pgtype"
)
const clearEnableAs = `-- name: ClearEnableAs :exec
UPDATE models
SET enable_as = NULL, updated_at = now()
WHERE enable_as = $1
`
// ClearEnableAs runs the clearEnableAs statement, which sets enable_as to NULL
// and refreshes updated_at on every models row whose enable_as equals enableAs.
// Only the error from Exec is returned; the command result is discarded.
func (q *Queries) ClearEnableAs(ctx context.Context, enableAs pgtype.Text) error {
_, err := q.db.Exec(ctx, clearEnableAs, enableAs)
return err
}
const countLlmProviders = `-- name: CountLlmProviders :one
SELECT COUNT(*) FROM llm_providers
`
@@ -98,16 +109,17 @@ func (q *Queries) CreateLlmProvider(ctx context.Context, arg CreateLlmProviderPa
}
const createModel = `-- name: CreateModel :one
INSERT INTO models (model_id, name, llm_provider_id, dimensions, is_multimodal, type)
INSERT INTO models (model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as)
VALUES (
$1,
$2,
$3,
$4,
$5,
$6
$6,
$7
)
RETURNING id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at
RETURNING id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at
`
type CreateModelParams struct {
@@ -117,6 +129,7 @@ type CreateModelParams struct {
Dimensions pgtype.Int4 `json:"dimensions"`
IsMultimodal bool `json:"is_multimodal"`
Type string `json:"type"`
EnableAs pgtype.Text `json:"enable_as"`
}
func (q *Queries) CreateModel(ctx context.Context, arg CreateModelParams) (Model, error) {
@@ -127,6 +140,7 @@ func (q *Queries) CreateModel(ctx context.Context, arg CreateModelParams) (Model
arg.Dimensions,
arg.IsMultimodal,
arg.Type,
arg.EnableAs,
)
var i Model
err := row.Scan(
@@ -137,6 +151,7 @@ func (q *Queries) CreateModel(ctx context.Context, arg CreateModelParams) (Model
&i.Dimensions,
&i.IsMultimodal,
&i.Type,
&i.EnableAs,
&i.CreatedAt,
&i.UpdatedAt,
)
@@ -257,8 +272,30 @@ func (q *Queries) GetLlmProviderByName(ctx context.Context, name string) (LlmPro
return i, err
}
const getModelByEnableAs = `-- name: GetModelByEnableAs :one
SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at FROM models WHERE enable_as = $1 LIMIT 1
`
// GetModelByEnableAs runs the getModelByEnableAs query (at most one row, via
// LIMIT 1) for models whose enable_as equals enableAs and scans the row into a
// Model. The error from Scan is returned unchanged together with whatever was
// scanned into the Model.
func (q *Queries) GetModelByEnableAs(ctx context.Context, enableAs pgtype.Text) (Model, error) {
row := q.db.QueryRow(ctx, getModelByEnableAs, enableAs)
var i Model
// Scan targets follow the column order of the SELECT list in getModelByEnableAs.
err := row.Scan(
&i.ID,
&i.ModelID,
&i.Name,
&i.LlmProviderID,
&i.Dimensions,
&i.IsMultimodal,
&i.Type,
&i.EnableAs,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const getModelByID = `-- name: GetModelByID :one
SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at FROM models WHERE id = $1
SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at FROM models WHERE id = $1
`
func (q *Queries) GetModelByID(ctx context.Context, id pgtype.UUID) (Model, error) {
@@ -272,6 +309,7 @@ func (q *Queries) GetModelByID(ctx context.Context, id pgtype.UUID) (Model, erro
&i.Dimensions,
&i.IsMultimodal,
&i.Type,
&i.EnableAs,
&i.CreatedAt,
&i.UpdatedAt,
)
@@ -279,7 +317,7 @@ func (q *Queries) GetModelByID(ctx context.Context, id pgtype.UUID) (Model, erro
}
const getModelByModelID = `-- name: GetModelByModelID :one
SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at FROM models WHERE model_id = $1
SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at FROM models WHERE model_id = $1
`
func (q *Queries) GetModelByModelID(ctx context.Context, modelID string) (Model, error) {
@@ -293,6 +331,7 @@ func (q *Queries) GetModelByModelID(ctx context.Context, modelID string) (Model,
&i.Dimensions,
&i.IsMultimodal,
&i.Type,
&i.EnableAs,
&i.CreatedAt,
&i.UpdatedAt,
)
@@ -456,7 +495,7 @@ func (q *Queries) ListModelVariantsByVariantID(ctx context.Context, variantID st
}
const listModels = `-- name: ListModels :many
SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at FROM models
SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at FROM models
ORDER BY created_at DESC
`
@@ -477,6 +516,7 @@ func (q *Queries) ListModels(ctx context.Context) ([]Model, error) {
&i.Dimensions,
&i.IsMultimodal,
&i.Type,
&i.EnableAs,
&i.CreatedAt,
&i.UpdatedAt,
); err != nil {
@@ -491,7 +531,7 @@ func (q *Queries) ListModels(ctx context.Context) ([]Model, error) {
}
const listModelsByClientType = `-- name: ListModelsByClientType :many
SELECT m.id, m.model_id, m.name, m.llm_provider_id, m.dimensions, m.is_multimodal, m.type, m.created_at, m.updated_at FROM models AS m
SELECT m.id, m.model_id, m.name, m.llm_provider_id, m.dimensions, m.is_multimodal, m.type, m.enable_as, m.created_at, m.updated_at FROM models AS m
JOIN llm_providers AS p ON p.id = m.llm_provider_id
WHERE p.client_type = $1
ORDER BY m.created_at DESC
@@ -514,6 +554,7 @@ func (q *Queries) ListModelsByClientType(ctx context.Context, clientType string)
&i.Dimensions,
&i.IsMultimodal,
&i.Type,
&i.EnableAs,
&i.CreatedAt,
&i.UpdatedAt,
); err != nil {
@@ -528,7 +569,7 @@ func (q *Queries) ListModelsByClientType(ctx context.Context, clientType string)
}
const listModelsByType = `-- name: ListModelsByType :many
SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at FROM models
SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at FROM models
WHERE type = $1
ORDER BY created_at DESC
`
@@ -550,6 +591,7 @@ func (q *Queries) ListModelsByType(ctx context.Context, type_ string) ([]Model,
&i.Dimensions,
&i.IsMultimodal,
&i.Type,
&i.EnableAs,
&i.CreatedAt,
&i.UpdatedAt,
); err != nil {
@@ -616,9 +658,10 @@ SET
dimensions = $3,
is_multimodal = $4,
type = $5,
enable_as = $6,
updated_at = now()
WHERE id = $6
RETURNING id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at
WHERE id = $7
RETURNING id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at
`
type UpdateModelParams struct {
@@ -627,6 +670,7 @@ type UpdateModelParams struct {
Dimensions pgtype.Int4 `json:"dimensions"`
IsMultimodal bool `json:"is_multimodal"`
Type string `json:"type"`
EnableAs pgtype.Text `json:"enable_as"`
ID pgtype.UUID `json:"id"`
}
@@ -637,6 +681,7 @@ func (q *Queries) UpdateModel(ctx context.Context, arg UpdateModelParams) (Model
arg.Dimensions,
arg.IsMultimodal,
arg.Type,
arg.EnableAs,
arg.ID,
)
var i Model
@@ -648,6 +693,7 @@ func (q *Queries) UpdateModel(ctx context.Context, arg UpdateModelParams) (Model
&i.Dimensions,
&i.IsMultimodal,
&i.Type,
&i.EnableAs,
&i.CreatedAt,
&i.UpdatedAt,
)
@@ -662,9 +708,10 @@ SET
dimensions = $3,
is_multimodal = $4,
type = $5,
enable_as = $6,
updated_at = now()
WHERE model_id = $6
RETURNING id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at
WHERE model_id = $7
RETURNING id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at
`
type UpdateModelByModelIDParams struct {
@@ -673,6 +720,7 @@ type UpdateModelByModelIDParams struct {
Dimensions pgtype.Int4 `json:"dimensions"`
IsMultimodal bool `json:"is_multimodal"`
Type string `json:"type"`
EnableAs pgtype.Text `json:"enable_as"`
ModelID string `json:"model_id"`
}
@@ -683,6 +731,7 @@ func (q *Queries) UpdateModelByModelID(ctx context.Context, arg UpdateModelByMod
arg.Dimensions,
arg.IsMultimodal,
arg.Type,
arg.EnableAs,
arg.ModelID,
)
var i Model
@@ -694,6 +743,7 @@ func (q *Queries) UpdateModelByModelID(ctx context.Context, arg UpdateModelByMod
&i.Dimensions,
&i.IsMultimodal,
&i.Type,
&i.EnableAs,
&i.CreatedAt,
&i.UpdatedAt,
)
+17
View File
@@ -173,6 +173,23 @@ func (r *Resolver) selectEmbeddingModel(ctx context.Context, req Request) (model
return models.GetResponse{}, errors.New("models service not configured")
}
// If no model specified and no provider specified, try to get default embedding model
if req.Model == "" && req.Provider == "" {
defaultModel, err := r.modelsService.GetByEnableAs(ctx, models.EnableAsEmbedding)
if err == nil {
// Found default model, check if it matches the type requirement
if req.Type == TypeMultimodal && !defaultModel.IsMultimodal {
// Default is text, but need multimodal - continue to search
} else if req.Type == TypeText && defaultModel.IsMultimodal {
// Default is multimodal, but need text - continue to search
} else {
// Default model matches requirements
return defaultModel, nil
}
}
// No default model or doesn't match requirements, continue to search
}
var candidates []models.GetResponse
var err error
if req.Provider != "" {
+39
View File
@@ -31,6 +31,13 @@ func (h *MemoryHandler) Register(e *echo.Echo) {
group.DELETE("/memories", h.DeleteAll)
}
// checkService reports whether the handler has a backing memory service.
// It returns nil when a service is attached and a 503 Service Unavailable
// HTTP error otherwise, so endpoints can bail out before touching the service.
func (h *MemoryHandler) checkService() error {
	if h.service != nil {
		return nil
	}
	return echo.NewHTTPError(http.StatusServiceUnavailable, "memory service not available: no embedding models configured")
}
// EmbedUpsert godoc
// @Summary Embed and upsert memory
// @Description Embed text or multimodal input and upsert into memory store
@@ -41,6 +48,10 @@ func (h *MemoryHandler) Register(e *echo.Echo) {
// @Failure 500 {object} ErrorResponse
// @Router /memory/embed [post]
func (h *MemoryHandler) EmbedUpsert(c echo.Context) error {
if err := h.checkService(); err != nil {
return err
}
userID, err := h.requireUserID(c)
if err != nil {
return err
@@ -72,6 +83,10 @@ func (h *MemoryHandler) EmbedUpsert(c echo.Context) error {
// @Failure 500 {object} ErrorResponse
// @Router /memory/add [post]
func (h *MemoryHandler) Add(c echo.Context) error {
if err := h.checkService(); err != nil {
return err
}
userID, err := h.requireUserID(c)
if err != nil {
return err
@@ -103,6 +118,10 @@ func (h *MemoryHandler) Add(c echo.Context) error {
// @Failure 500 {object} ErrorResponse
// @Router /memory/search [post]
func (h *MemoryHandler) Search(c echo.Context) error {
if err := h.checkService(); err != nil {
return err
}
userID, err := h.requireUserID(c)
if err != nil {
return err
@@ -134,6 +153,10 @@ func (h *MemoryHandler) Search(c echo.Context) error {
// @Failure 500 {object} ErrorResponse
// @Router /memory/update [post]
func (h *MemoryHandler) Update(c echo.Context) error {
if err := h.checkService(); err != nil {
return err
}
userID, err := h.requireUserID(c)
if err != nil {
return err
@@ -170,6 +193,10 @@ func (h *MemoryHandler) Update(c echo.Context) error {
// @Failure 500 {object} ErrorResponse
// @Router /memory/memories/{memoryId} [get]
func (h *MemoryHandler) Get(c echo.Context) error {
if err := h.checkService(); err != nil {
return err
}
userID, err := h.requireUserID(c)
if err != nil {
return err
@@ -203,6 +230,10 @@ func (h *MemoryHandler) Get(c echo.Context) error {
// @Failure 500 {object} ErrorResponse
// @Router /memory/memories [get]
func (h *MemoryHandler) GetAll(c echo.Context) error {
if err := h.checkService(); err != nil {
return err
}
userID, err := h.requireUserID(c)
if err != nil {
return err
@@ -240,6 +271,10 @@ func (h *MemoryHandler) GetAll(c echo.Context) error {
// @Failure 500 {object} ErrorResponse
// @Router /memory/memories/{memoryId} [delete]
func (h *MemoryHandler) Delete(c echo.Context) error {
if err := h.checkService(); err != nil {
return err
}
userID, err := h.requireUserID(c)
if err != nil {
return err
@@ -275,6 +310,10 @@ func (h *MemoryHandler) Delete(c echo.Context) error {
// @Failure 500 {object} ErrorResponse
// @Router /memory/memories [delete]
func (h *MemoryHandler) DeleteAll(c echo.Context) error {
if err := h.checkService(); err != nil {
return err
}
userID, err := h.requireUserID(c)
if err != nil {
return err
+33
View File
@@ -22,6 +22,7 @@ func (h *ModelsHandler) Register(e *echo.Echo) {
group.GET("", h.List)
group.GET("/:id", h.GetByID)
group.GET("/model/:modelId", h.GetByModelID)
group.GET("/enable-as/:enableAs", h.GetByEnableAs)
group.PUT("/:id", h.UpdateByID)
group.PUT("/model/:modelId", h.UpdateByModelID)
group.DELETE("/:id", h.DeleteByID)
@@ -230,6 +231,38 @@ func (h *ModelsHandler) DeleteByModelID(c echo.Context) error {
return c.NoContent(http.StatusNoContent)
}
// GetByEnableAs godoc
// @Summary Get model by enable_as
// @Description Get the model that is enabled for a specific purpose (chat, memory, embedding)
// @Tags models
// @Param enableAs path string true "Enable as value (chat, memory, embedding)"
// @Success 200 {object} models.GetResponse
// @Failure 400 {object} ErrorResponse
// @Failure 404 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Router /models/enable-as/{enableAs} [get]
func (h *ModelsHandler) GetByEnableAs(c echo.Context) error {
	enableAs := models.EnableAs(c.Param("enableAs"))

	// Reject missing or unrecognized values here so the client gets the
	// documented 400 instead of a 404 produced by the lookup error path.
	switch enableAs {
	case "":
		return echo.NewHTTPError(http.StatusBadRequest, "enableAs is required")
	case models.EnableAsChat, models.EnableAsMemory, models.EnableAsEmbedding:
		// Recognized purpose; continue to the lookup below.
	default:
		return echo.NewHTTPError(http.StatusBadRequest, "enableAs must be one of: chat, memory, embedding")
	}

	resp, err := h.service.GetByEnableAs(c.Request().Context(), enableAs)
	if err != nil {
		// NOTE(review): every service error is reported as 404, including
		// database failures. The service wraps the query error with %w, so a
		// follow-up could use errors.Is against the driver's no-rows error
		// (presumably pgx.ErrNoRows — confirm) and return 500 for the rest.
		return echo.NewHTTPError(http.StatusNotFound, err.Error())
	}
	return c.JSON(http.StatusOK, resp)
}
// Count godoc
// @Summary Get model count
// @Description Get the total count of models, optionally filtered by type
+55
View File
@@ -28,6 +28,13 @@ func (s *Service) Create(ctx context.Context, req AddRequest) (AddResponse, erro
return AddResponse{}, fmt.Errorf("validation failed: %w", err)
}
// If enable_as is set, clear any existing model with the same enable_as
if model.EnableAs != nil {
if err := s.queries.ClearEnableAs(ctx, pgtype.Text{String: string(*model.EnableAs), Valid: true}); err != nil {
return AddResponse{}, fmt.Errorf("failed to clear existing enable_as: %w", err)
}
}
// Convert to sqlc params
llmProviderID, err := parseUUID(model.LlmProviderID)
if err != nil {
@@ -51,6 +58,11 @@ func (s *Service) Create(ctx context.Context, req AddRequest) (AddResponse, erro
params.Dimensions = pgtype.Int4{Int32: int32(model.Dimensions), Valid: true}
}
// Handle optional enable_as field
if model.EnableAs != nil {
params.EnableAs = pgtype.Text{String: string(*model.EnableAs), Valid: true}
}
created, err := s.queries.CreateModel(ctx, params)
if err != nil {
return AddResponse{}, fmt.Errorf("failed to create model: %w", err)
@@ -151,6 +163,13 @@ func (s *Service) UpdateByID(ctx context.Context, id string, req UpdateRequest)
return GetResponse{}, fmt.Errorf("validation failed: %w", err)
}
// If enable_as is being set, clear any existing model with the same enable_as
if model.EnableAs != nil {
if err := s.queries.ClearEnableAs(ctx, pgtype.Text{String: string(*model.EnableAs), Valid: true}); err != nil {
return GetResponse{}, fmt.Errorf("failed to clear existing enable_as: %w", err)
}
}
params := sqlc.UpdateModelParams{
ID: uuid,
IsMultimodal: model.IsMultimodal,
@@ -171,6 +190,11 @@ func (s *Service) UpdateByID(ctx context.Context, id string, req UpdateRequest)
params.Dimensions = pgtype.Int4{Int32: int32(model.Dimensions), Valid: true}
}
// Handle optional enable_as field
if model.EnableAs != nil {
params.EnableAs = pgtype.Text{String: string(*model.EnableAs), Valid: true}
}
updated, err := s.queries.UpdateModel(ctx, params)
if err != nil {
return GetResponse{}, fmt.Errorf("failed to update model: %w", err)
@@ -190,6 +214,13 @@ func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req Updat
return GetResponse{}, fmt.Errorf("validation failed: %w", err)
}
// If enable_as is being set, clear any existing model with the same enable_as
if model.EnableAs != nil {
if err := s.queries.ClearEnableAs(ctx, pgtype.Text{String: string(*model.EnableAs), Valid: true}); err != nil {
return GetResponse{}, fmt.Errorf("failed to clear existing enable_as: %w", err)
}
}
params := sqlc.UpdateModelByModelIDParams{
ModelID: modelID,
IsMultimodal: model.IsMultimodal,
@@ -210,6 +241,11 @@ func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req Updat
params.Dimensions = pgtype.Int4{Int32: int32(model.Dimensions), Valid: true}
}
// Handle optional enable_as field
if model.EnableAs != nil {
params.EnableAs = pgtype.Text{String: string(*model.EnableAs), Valid: true}
}
updated, err := s.queries.UpdateModelByModelID(ctx, params)
if err != nil {
return GetResponse{}, fmt.Errorf("failed to update model: %w", err)
@@ -267,6 +303,20 @@ func (s *Service) CountByType(ctx context.Context, modelType ModelType) (int64,
return count, nil
}
// GetByEnableAs looks up the model currently assigned to the given enable_as
// purpose. Values other than chat, memory, or embedding are rejected before
// any query runs; a query failure is wrapped and returned with a zero
// GetResponse.
func (s *Service) GetByEnableAs(ctx context.Context, enableAs EnableAs) (GetResponse, error) {
	switch enableAs {
	case EnableAsChat, EnableAsMemory, EnableAsEmbedding:
		// Recognized purpose; proceed with the lookup.
	default:
		return GetResponse{}, fmt.Errorf("invalid enable_as value: %s", enableAs)
	}

	arg := pgtype.Text{String: string(enableAs), Valid: true}
	row, err := s.queries.GetModelByEnableAs(ctx, arg)
	if err != nil {
		return GetResponse{}, fmt.Errorf("failed to get model by enable_as: %w", err)
	}
	return convertToGetResponse(row), nil
}
// Helper functions
func parseUUID(id string) (pgtype.UUID, error) {
@@ -304,6 +354,11 @@ func convertToGetResponse(dbModel sqlc.Model) GetResponse {
resp.Model.Dimensions = int(dbModel.Dimensions.Int32)
}
if dbModel.EnableAs.Valid {
enableAs := EnableAs(dbModel.EnableAs.String)
resp.Model.EnableAs = &enableAs
}
return resp
}
+24
View File
@@ -13,6 +13,14 @@ const (
ModelTypeEmbedding ModelType = "embedding"
)
// EnableAs names the purpose a model is enabled for. Model.Validate restricts
// which values a model may carry based on its Type.
type EnableAs string
// Recognized EnableAs values.
const (
// EnableAsChat is accepted by Model.Validate only on chat-type models.
EnableAsChat EnableAs = "chat"
// EnableAsMemory is accepted by Model.Validate only on chat-type models.
EnableAsMemory EnableAs = "memory"
// EnableAsEmbedding is the only value Model.Validate accepts on embedding-type models.
EnableAsEmbedding EnableAs = "embedding"
)
type ClientType string
const (
@@ -33,6 +41,7 @@ type Model struct {
IsMultimodal bool `json:"is_multimodal"`
Type ModelType `json:"type"`
Dimensions int `json:"dimensions"`
EnableAs *EnableAs `json:"enable_as,omitempty"`
}
func (m *Model) Validate() error {
@@ -51,6 +60,21 @@ func (m *Model) Validate() error {
if m.Type == ModelTypeEmbedding && m.Dimensions <= 0 {
return errors.New("dimensions must be greater than 0")
}
// Validate enable_as based on type
if m.EnableAs != nil {
switch m.Type {
case ModelTypeEmbedding:
if *m.EnableAs != EnableAsEmbedding {
return errors.New("embedding models can only have enable_as set to 'embedding'")
}
case ModelTypeChat:
if *m.EnableAs != EnableAsChat && *m.EnableAs != EnableAsMemory {
return errors.New("chat models can only have enable_as set to 'chat' or 'memory'")
}
}
}
return nil
}