mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
feat: allow setting default model enable as chat, memory, embedding
This commit is contained in:
+35
-17
@@ -2,7 +2,6 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
@@ -42,10 +41,10 @@ func (e *resolverTextEmbedder) Dimensions() int {
|
||||
return e.dims
|
||||
}
|
||||
|
||||
func collectEmbeddingVectors(ctx context.Context, service *models.Service) (map[string]int, models.GetResponse, models.GetResponse, error) {
|
||||
func collectEmbeddingVectors(ctx context.Context, service *models.Service) (map[string]int, models.GetResponse, models.GetResponse, bool, error) {
|
||||
candidates, err := service.ListByType(ctx, models.ModelTypeEmbedding)
|
||||
if err != nil {
|
||||
return nil, models.GetResponse{}, models.GetResponse{}, err
|
||||
return nil, models.GetResponse{}, models.GetResponse{}, false, err
|
||||
}
|
||||
vectors := map[string]int{}
|
||||
var textModel models.GetResponse
|
||||
@@ -64,13 +63,12 @@ func collectEmbeddingVectors(ctx context.Context, service *models.Service) (map[
|
||||
textModel = model
|
||||
}
|
||||
}
|
||||
if textModel.ModelID == "" {
|
||||
return vectors, textModel, multimodalModel, fmt.Errorf("no text embedding model configured")
|
||||
}
|
||||
if multimodalModel.ModelID == "" {
|
||||
return vectors, textModel, multimodalModel, fmt.Errorf("no multimodal embedding model configured")
|
||||
}
|
||||
return vectors, textModel, multimodalModel, nil
|
||||
|
||||
hasTextModel := textModel.ModelID != ""
|
||||
hasMultimodalModel := multimodalModel.ModelID != ""
|
||||
hasAnyModel := hasTextModel || hasMultimodalModel
|
||||
|
||||
return vectors, textModel, multimodalModel, hasAnyModel, nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
@@ -122,19 +120,36 @@ func main() {
|
||||
time.Duration(cfg.Memory.TimeoutSeconds)*time.Second,
|
||||
)
|
||||
resolver := embeddings.NewResolver(modelsService, queries, 10*time.Second)
|
||||
vectors, textModel, multimodalModel, err := collectEmbeddingVectors(ctx, modelsService)
|
||||
vectors, textModel, multimodalModel, hasModels, err := collectEmbeddingVectors(ctx, modelsService)
|
||||
if err != nil {
|
||||
log.Fatalf("embedding models: %v", err)
|
||||
}
|
||||
if textModel.Dimensions <= 0 {
|
||||
log.Fatalf("text embedding dimensions not configured")
|
||||
|
||||
var memoryService *memory.Service
|
||||
var memoryHandler *handlers.MemoryHandler
|
||||
|
||||
if !hasModels {
|
||||
log.Println("WARNING: No embedding models configured. Memory service will not be available.")
|
||||
log.Println("You can add embedding models via the /models API endpoint.")
|
||||
memoryHandler = handlers.NewMemoryHandler(nil)
|
||||
} else {
|
||||
if textModel.ModelID == "" {
|
||||
log.Println("WARNING: No text embedding model configured. Text embedding features will be limited.")
|
||||
}
|
||||
textEmbedder := &resolverTextEmbedder{
|
||||
if multimodalModel.ModelID == "" {
|
||||
log.Println("WARNING: No multimodal embedding model configured. Multimodal embedding features will be limited.")
|
||||
}
|
||||
|
||||
var textEmbedder embeddings.Embedder
|
||||
var store *memory.QdrantStore
|
||||
|
||||
if textModel.ModelID != "" && textModel.Dimensions > 0 {
|
||||
textEmbedder = &resolverTextEmbedder{
|
||||
resolver: resolver,
|
||||
modelID: textModel.ModelID,
|
||||
dims: textModel.Dimensions,
|
||||
}
|
||||
var store *memory.QdrantStore
|
||||
|
||||
if len(vectors) > 0 {
|
||||
store, err = memory.NewQdrantStoreWithVectors(
|
||||
cfg.Qdrant.BaseURL,
|
||||
@@ -158,8 +173,11 @@ func main() {
|
||||
log.Fatalf("qdrant init: %v", err)
|
||||
}
|
||||
}
|
||||
memoryService := memory.NewService(llmClient, textEmbedder, store, resolver, textModel.ModelID, multimodalModel.ModelID)
|
||||
memoryHandler := handlers.NewMemoryHandler(memoryService)
|
||||
}
|
||||
|
||||
memoryService = memory.NewService(llmClient, textEmbedder, store, resolver, textModel.ModelID, multimodalModel.ModelID)
|
||||
memoryHandler = handlers.NewMemoryHandler(memoryService)
|
||||
}
|
||||
embeddingsHandler := handlers.NewEmbeddingsHandler(modelsService, queries)
|
||||
fsHandler := handlers.NewFSHandler(service, manager, cfg.MCP, cfg.Containerd.Namespace)
|
||||
swaggerHandler := handlers.NewSwaggerHandler()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
DROP TABLE IF EXISTS lifecycle_events;
|
||||
DROP TABLE IF EXISTS container_versions;
|
||||
DROP TABLE IF EXISTS models;
|
||||
DROP TABLE IF EXISTS snapshots;
|
||||
DROP TABLE IF EXISTS containers;
|
||||
DROP TABLE IF EXISTS users;
|
||||
|
||||
@@ -79,9 +79,14 @@ CREATE TABLE IF NOT EXISTS models (
|
||||
dimensions INTEGER,
|
||||
is_multimodal BOOLEAN NOT NULL DEFAULT false,
|
||||
type TEXT NOT NULL DEFAULT 'chat',
|
||||
enable_as TEXT,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
CONSTRAINT models_model_id_unique UNIQUE (model_id),
|
||||
CONSTRAINT models_enable_as_check CHECK (
|
||||
(type = 'embedding' AND (enable_as = 'embedding' OR enable_as IS NULL)) OR
|
||||
(type = 'chat' AND (enable_as IN ('chat', 'memory') OR enable_as IS NULL))
|
||||
),
|
||||
CONSTRAINT models_type_check CHECK (type IN ('chat', 'embedding')),
|
||||
CONSTRAINT models_dimensions_check CHECK (type != 'embedding' OR dimensions IS NOT NULL)
|
||||
);
|
||||
@@ -99,6 +104,10 @@ CREATE TABLE IF NOT EXISTS model_variants (
|
||||
CREATE INDEX IF NOT EXISTS idx_model_variants_model_uuid ON model_variants(model_uuid);
|
||||
CREATE INDEX IF NOT EXISTS idx_model_variants_variant_id ON model_variants(variant_id);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_models_enable_as_unique ON models(enable_as) WHERE enable_as IS NOT NULL;
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_snapshots_container_id ON snapshots(container_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_snapshots_parent_id ON snapshots(parent_snapshot_id);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS container_versions (
|
||||
id TEXT PRIMARY KEY,
|
||||
|
||||
+13
-2
@@ -46,14 +46,15 @@ SELECT COUNT(*) FROM llm_providers;
|
||||
SELECT COUNT(*) FROM llm_providers WHERE client_type = sqlc.arg(client_type);
|
||||
|
||||
-- name: CreateModel :one
|
||||
INSERT INTO models (model_id, name, llm_provider_id, dimensions, is_multimodal, type)
|
||||
INSERT INTO models (model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as)
|
||||
VALUES (
|
||||
sqlc.arg(model_id),
|
||||
sqlc.arg(name),
|
||||
sqlc.arg(llm_provider_id),
|
||||
sqlc.arg(dimensions),
|
||||
sqlc.arg(is_multimodal),
|
||||
sqlc.arg(type)
|
||||
sqlc.arg(type),
|
||||
sqlc.arg(enable_as)
|
||||
)
|
||||
RETURNING *;
|
||||
|
||||
@@ -86,6 +87,7 @@ SET
|
||||
dimensions = sqlc.arg(dimensions),
|
||||
is_multimodal = sqlc.arg(is_multimodal),
|
||||
type = sqlc.arg(type),
|
||||
enable_as = sqlc.arg(enable_as),
|
||||
updated_at = now()
|
||||
WHERE id = sqlc.arg(id)
|
||||
RETURNING *;
|
||||
@@ -98,6 +100,7 @@ SET
|
||||
dimensions = sqlc.arg(dimensions),
|
||||
is_multimodal = sqlc.arg(is_multimodal),
|
||||
type = sqlc.arg(type),
|
||||
enable_as = sqlc.arg(enable_as),
|
||||
updated_at = now()
|
||||
WHERE model_id = sqlc.arg(model_id)
|
||||
RETURNING *;
|
||||
@@ -114,6 +117,14 @@ SELECT COUNT(*) FROM models;
|
||||
-- name: CountModelsByType :one
|
||||
SELECT COUNT(*) FROM models WHERE type = sqlc.arg(type);
|
||||
|
||||
-- name: GetModelByEnableAs :one
|
||||
SELECT * FROM models WHERE enable_as = sqlc.arg(enable_as) LIMIT 1;
|
||||
|
||||
-- name: ClearEnableAs :exec
|
||||
UPDATE models
|
||||
SET enable_as = NULL, updated_at = now()
|
||||
WHERE enable_as = sqlc.arg(enable_as);
|
||||
|
||||
-- name: CreateModelVariant :one
|
||||
INSERT INTO model_variants (model_uuid, variant_id, weight, metadata)
|
||||
VALUES (
|
||||
|
||||
@@ -818,6 +818,50 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"/models/enable-as/{enableAs}": {
|
||||
"get": {
|
||||
"description": "Get the model that is enabled for a specific purpose (chat, memory, embedding)",
|
||||
"tags": [
|
||||
"models"
|
||||
],
|
||||
"summary": "Get model by enable_as",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Enable as value (chat, memory, embedding)",
|
||||
"name": "enableAs",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/models.GetResponse"
|
||||
}
|
||||
},
|
||||
"400": {
|
||||
"description": "Bad Request",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/handlers.ErrorResponse"
|
||||
}
|
||||
},
|
||||
"404": {
|
||||
"description": "Not Found",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/handlers.ErrorResponse"
|
||||
}
|
||||
},
|
||||
"500": {
|
||||
"description": "Internal Server Error",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/handlers.ErrorResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/models/model/{modelId}": {
|
||||
"get": {
|
||||
"description": "Get a model configuration by its model_id field (e.g., gpt-4)",
|
||||
@@ -1556,6 +1600,9 @@ const docTemplate = `{
|
||||
"dimensions": {
|
||||
"type": "integer"
|
||||
},
|
||||
"enable_as": {
|
||||
"$ref": "#/definitions/models.EnableAs"
|
||||
},
|
||||
"is_multimodal": {
|
||||
"type": "boolean"
|
||||
},
|
||||
@@ -1592,12 +1639,28 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"models.EnableAs": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"chat",
|
||||
"memory",
|
||||
"embedding"
|
||||
],
|
||||
"x-enum-varnames": [
|
||||
"EnableAsChat",
|
||||
"EnableAsMemory",
|
||||
"EnableAsEmbedding"
|
||||
]
|
||||
},
|
||||
"models.GetResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"dimensions": {
|
||||
"type": "integer"
|
||||
},
|
||||
"enable_as": {
|
||||
"$ref": "#/definitions/models.EnableAs"
|
||||
},
|
||||
"is_multimodal": {
|
||||
"type": "boolean"
|
||||
},
|
||||
@@ -1632,6 +1695,9 @@ const docTemplate = `{
|
||||
"dimensions": {
|
||||
"type": "integer"
|
||||
},
|
||||
"enable_as": {
|
||||
"$ref": "#/definitions/models.EnableAs"
|
||||
},
|
||||
"is_multimodal": {
|
||||
"type": "boolean"
|
||||
},
|
||||
|
||||
@@ -807,6 +807,50 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/models/enable-as/{enableAs}": {
|
||||
"get": {
|
||||
"description": "Get the model that is enabled for a specific purpose (chat, memory, embedding)",
|
||||
"tags": [
|
||||
"models"
|
||||
],
|
||||
"summary": "Get model by enable_as",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Enable as value (chat, memory, embedding)",
|
||||
"name": "enableAs",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/models.GetResponse"
|
||||
}
|
||||
},
|
||||
"400": {
|
||||
"description": "Bad Request",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/handlers.ErrorResponse"
|
||||
}
|
||||
},
|
||||
"404": {
|
||||
"description": "Not Found",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/handlers.ErrorResponse"
|
||||
}
|
||||
},
|
||||
"500": {
|
||||
"description": "Internal Server Error",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/handlers.ErrorResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/models/model/{modelId}": {
|
||||
"get": {
|
||||
"description": "Get a model configuration by its model_id field (e.g., gpt-4)",
|
||||
@@ -1545,6 +1589,9 @@
|
||||
"dimensions": {
|
||||
"type": "integer"
|
||||
},
|
||||
"enable_as": {
|
||||
"$ref": "#/definitions/models.EnableAs"
|
||||
},
|
||||
"is_multimodal": {
|
||||
"type": "boolean"
|
||||
},
|
||||
@@ -1581,12 +1628,28 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"models.EnableAs": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"chat",
|
||||
"memory",
|
||||
"embedding"
|
||||
],
|
||||
"x-enum-varnames": [
|
||||
"EnableAsChat",
|
||||
"EnableAsMemory",
|
||||
"EnableAsEmbedding"
|
||||
]
|
||||
},
|
||||
"models.GetResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"dimensions": {
|
||||
"type": "integer"
|
||||
},
|
||||
"enable_as": {
|
||||
"$ref": "#/definitions/models.EnableAs"
|
||||
},
|
||||
"is_multimodal": {
|
||||
"type": "boolean"
|
||||
},
|
||||
@@ -1621,6 +1684,9 @@
|
||||
"dimensions": {
|
||||
"type": "integer"
|
||||
},
|
||||
"enable_as": {
|
||||
"$ref": "#/definitions/models.EnableAs"
|
||||
},
|
||||
"is_multimodal": {
|
||||
"type": "boolean"
|
||||
},
|
||||
|
||||
@@ -305,6 +305,8 @@ definitions:
|
||||
properties:
|
||||
dimensions:
|
||||
type: integer
|
||||
enable_as:
|
||||
$ref: '#/definitions/models.EnableAs'
|
||||
is_multimodal:
|
||||
type: boolean
|
||||
llm_provider_id:
|
||||
@@ -328,10 +330,22 @@ definitions:
|
||||
count:
|
||||
type: integer
|
||||
type: object
|
||||
models.EnableAs:
|
||||
enum:
|
||||
- chat
|
||||
- memory
|
||||
- embedding
|
||||
type: string
|
||||
x-enum-varnames:
|
||||
- EnableAsChat
|
||||
- EnableAsMemory
|
||||
- EnableAsEmbedding
|
||||
models.GetResponse:
|
||||
properties:
|
||||
dimensions:
|
||||
type: integer
|
||||
enable_as:
|
||||
$ref: '#/definitions/models.EnableAs'
|
||||
is_multimodal:
|
||||
type: boolean
|
||||
llm_provider_id:
|
||||
@@ -355,6 +369,8 @@ definitions:
|
||||
properties:
|
||||
dimensions:
|
||||
type: integer
|
||||
enable_as:
|
||||
$ref: '#/definitions/models.EnableAs'
|
||||
is_multimodal:
|
||||
type: boolean
|
||||
llm_provider_id:
|
||||
@@ -984,6 +1000,36 @@ paths:
|
||||
summary: Get model count
|
||||
tags:
|
||||
- models
|
||||
/models/enable-as/{enableAs}:
|
||||
get:
|
||||
description: Get the model that is enabled for a specific purpose (chat, memory,
|
||||
embedding)
|
||||
parameters:
|
||||
- description: Enable as value (chat, memory, embedding)
|
||||
in: path
|
||||
name: enableAs
|
||||
required: true
|
||||
type: string
|
||||
responses:
|
||||
"200":
|
||||
description: OK
|
||||
schema:
|
||||
$ref: '#/definitions/models.GetResponse'
|
||||
"400":
|
||||
description: Bad Request
|
||||
schema:
|
||||
$ref: '#/definitions/handlers.ErrorResponse'
|
||||
"404":
|
||||
description: Not Found
|
||||
schema:
|
||||
$ref: '#/definitions/handlers.ErrorResponse'
|
||||
"500":
|
||||
description: Internal Server Error
|
||||
schema:
|
||||
$ref: '#/definitions/handlers.ErrorResponse'
|
||||
summary: Get model by enable_as
|
||||
tags:
|
||||
- models
|
||||
/models/model/{modelId}:
|
||||
delete:
|
||||
description: Delete a model configuration by its model_id field (e.g., gpt-4)
|
||||
|
||||
@@ -60,6 +60,7 @@ type Model struct {
|
||||
Dimensions pgtype.Int4 `json:"dimensions"`
|
||||
IsMultimodal bool `json:"is_multimodal"`
|
||||
Type string `json:"type"`
|
||||
EnableAs pgtype.Text `json:"enable_as"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
|
||||
}
|
||||
|
||||
@@ -11,6 +11,17 @@ import (
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
const clearEnableAs = `-- name: ClearEnableAs :exec
|
||||
UPDATE models
|
||||
SET enable_as = NULL, updated_at = now()
|
||||
WHERE enable_as = $1
|
||||
`
|
||||
|
||||
// ClearEnableAs executes the clearEnableAs statement (UPDATE models SET
// enable_as = NULL, updated_at = now() WHERE enable_as = $1) with enableAs
// bound as its single argument. The command result is discarded and only the
// error returned by Exec is reported.
func (q *Queries) ClearEnableAs(ctx context.Context, enableAs pgtype.Text) error {
|
||||
_, err := q.db.Exec(ctx, clearEnableAs, enableAs)
|
||||
return err
|
||||
}
|
||||
|
||||
const countLlmProviders = `-- name: CountLlmProviders :one
|
||||
SELECT COUNT(*) FROM llm_providers
|
||||
`
|
||||
@@ -98,16 +109,17 @@ func (q *Queries) CreateLlmProvider(ctx context.Context, arg CreateLlmProviderPa
|
||||
}
|
||||
|
||||
const createModel = `-- name: CreateModel :one
|
||||
INSERT INTO models (model_id, name, llm_provider_id, dimensions, is_multimodal, type)
|
||||
INSERT INTO models (model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as)
|
||||
VALUES (
|
||||
$1,
|
||||
$2,
|
||||
$3,
|
||||
$4,
|
||||
$5,
|
||||
$6
|
||||
$6,
|
||||
$7
|
||||
)
|
||||
RETURNING id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at
|
||||
RETURNING id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at
|
||||
`
|
||||
|
||||
type CreateModelParams struct {
|
||||
@@ -117,6 +129,7 @@ type CreateModelParams struct {
|
||||
Dimensions pgtype.Int4 `json:"dimensions"`
|
||||
IsMultimodal bool `json:"is_multimodal"`
|
||||
Type string `json:"type"`
|
||||
EnableAs pgtype.Text `json:"enable_as"`
|
||||
}
|
||||
|
||||
func (q *Queries) CreateModel(ctx context.Context, arg CreateModelParams) (Model, error) {
|
||||
@@ -127,6 +140,7 @@ func (q *Queries) CreateModel(ctx context.Context, arg CreateModelParams) (Model
|
||||
arg.Dimensions,
|
||||
arg.IsMultimodal,
|
||||
arg.Type,
|
||||
arg.EnableAs,
|
||||
)
|
||||
var i Model
|
||||
err := row.Scan(
|
||||
@@ -137,6 +151,7 @@ func (q *Queries) CreateModel(ctx context.Context, arg CreateModelParams) (Model
|
||||
&i.Dimensions,
|
||||
&i.IsMultimodal,
|
||||
&i.Type,
|
||||
&i.EnableAs,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
@@ -257,8 +272,30 @@ func (q *Queries) GetLlmProviderByName(ctx context.Context, name string) (LlmPro
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getModelByEnableAs = `-- name: GetModelByEnableAs :one
|
||||
SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at FROM models WHERE enable_as = $1 LIMIT 1
|
||||
`
|
||||
|
||||
// GetModelByEnableAs runs the getModelByEnableAs query (models rows whose
// enable_as equals the argument, LIMIT 1) and scans the ten selected columns,
// in SELECT order, into a Model. The Scan error is returned unchanged together
// with whatever was scanned.
// NOTE(review): presumably Scan reports a "no rows" error when nothing
// matches — confirm against the driver before relying on it.
func (q *Queries) GetModelByEnableAs(ctx context.Context, enableAs pgtype.Text) (Model, error) {
|
||||
row := q.db.QueryRow(ctx, getModelByEnableAs, enableAs)
|
||||
var i Model
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.ModelID,
|
||||
&i.Name,
|
||||
&i.LlmProviderID,
|
||||
&i.Dimensions,
|
||||
&i.IsMultimodal,
|
||||
&i.Type,
|
||||
&i.EnableAs,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getModelByID = `-- name: GetModelByID :one
|
||||
SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at FROM models WHERE id = $1
|
||||
SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at FROM models WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetModelByID(ctx context.Context, id pgtype.UUID) (Model, error) {
|
||||
@@ -272,6 +309,7 @@ func (q *Queries) GetModelByID(ctx context.Context, id pgtype.UUID) (Model, erro
|
||||
&i.Dimensions,
|
||||
&i.IsMultimodal,
|
||||
&i.Type,
|
||||
&i.EnableAs,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
@@ -279,7 +317,7 @@ func (q *Queries) GetModelByID(ctx context.Context, id pgtype.UUID) (Model, erro
|
||||
}
|
||||
|
||||
const getModelByModelID = `-- name: GetModelByModelID :one
|
||||
SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at FROM models WHERE model_id = $1
|
||||
SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at FROM models WHERE model_id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetModelByModelID(ctx context.Context, modelID string) (Model, error) {
|
||||
@@ -293,6 +331,7 @@ func (q *Queries) GetModelByModelID(ctx context.Context, modelID string) (Model,
|
||||
&i.Dimensions,
|
||||
&i.IsMultimodal,
|
||||
&i.Type,
|
||||
&i.EnableAs,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
@@ -456,7 +495,7 @@ func (q *Queries) ListModelVariantsByVariantID(ctx context.Context, variantID st
|
||||
}
|
||||
|
||||
const listModels = `-- name: ListModels :many
|
||||
SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at FROM models
|
||||
SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at FROM models
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
@@ -477,6 +516,7 @@ func (q *Queries) ListModels(ctx context.Context) ([]Model, error) {
|
||||
&i.Dimensions,
|
||||
&i.IsMultimodal,
|
||||
&i.Type,
|
||||
&i.EnableAs,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
); err != nil {
|
||||
@@ -491,7 +531,7 @@ func (q *Queries) ListModels(ctx context.Context) ([]Model, error) {
|
||||
}
|
||||
|
||||
const listModelsByClientType = `-- name: ListModelsByClientType :many
|
||||
SELECT m.id, m.model_id, m.name, m.llm_provider_id, m.dimensions, m.is_multimodal, m.type, m.created_at, m.updated_at FROM models AS m
|
||||
SELECT m.id, m.model_id, m.name, m.llm_provider_id, m.dimensions, m.is_multimodal, m.type, m.enable_as, m.created_at, m.updated_at FROM models AS m
|
||||
JOIN llm_providers AS p ON p.id = m.llm_provider_id
|
||||
WHERE p.client_type = $1
|
||||
ORDER BY m.created_at DESC
|
||||
@@ -514,6 +554,7 @@ func (q *Queries) ListModelsByClientType(ctx context.Context, clientType string)
|
||||
&i.Dimensions,
|
||||
&i.IsMultimodal,
|
||||
&i.Type,
|
||||
&i.EnableAs,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
); err != nil {
|
||||
@@ -528,7 +569,7 @@ func (q *Queries) ListModelsByClientType(ctx context.Context, clientType string)
|
||||
}
|
||||
|
||||
const listModelsByType = `-- name: ListModelsByType :many
|
||||
SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at FROM models
|
||||
SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at FROM models
|
||||
WHERE type = $1
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
@@ -550,6 +591,7 @@ func (q *Queries) ListModelsByType(ctx context.Context, type_ string) ([]Model,
|
||||
&i.Dimensions,
|
||||
&i.IsMultimodal,
|
||||
&i.Type,
|
||||
&i.EnableAs,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
); err != nil {
|
||||
@@ -616,9 +658,10 @@ SET
|
||||
dimensions = $3,
|
||||
is_multimodal = $4,
|
||||
type = $5,
|
||||
enable_as = $6,
|
||||
updated_at = now()
|
||||
WHERE id = $6
|
||||
RETURNING id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at
|
||||
WHERE id = $7
|
||||
RETURNING id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at
|
||||
`
|
||||
|
||||
type UpdateModelParams struct {
|
||||
@@ -627,6 +670,7 @@ type UpdateModelParams struct {
|
||||
Dimensions pgtype.Int4 `json:"dimensions"`
|
||||
IsMultimodal bool `json:"is_multimodal"`
|
||||
Type string `json:"type"`
|
||||
EnableAs pgtype.Text `json:"enable_as"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
}
|
||||
|
||||
@@ -637,6 +681,7 @@ func (q *Queries) UpdateModel(ctx context.Context, arg UpdateModelParams) (Model
|
||||
arg.Dimensions,
|
||||
arg.IsMultimodal,
|
||||
arg.Type,
|
||||
arg.EnableAs,
|
||||
arg.ID,
|
||||
)
|
||||
var i Model
|
||||
@@ -648,6 +693,7 @@ func (q *Queries) UpdateModel(ctx context.Context, arg UpdateModelParams) (Model
|
||||
&i.Dimensions,
|
||||
&i.IsMultimodal,
|
||||
&i.Type,
|
||||
&i.EnableAs,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
@@ -662,9 +708,10 @@ SET
|
||||
dimensions = $3,
|
||||
is_multimodal = $4,
|
||||
type = $5,
|
||||
enable_as = $6,
|
||||
updated_at = now()
|
||||
WHERE model_id = $6
|
||||
RETURNING id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at
|
||||
WHERE model_id = $7
|
||||
RETURNING id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, enable_as, created_at, updated_at
|
||||
`
|
||||
|
||||
type UpdateModelByModelIDParams struct {
|
||||
@@ -673,6 +720,7 @@ type UpdateModelByModelIDParams struct {
|
||||
Dimensions pgtype.Int4 `json:"dimensions"`
|
||||
IsMultimodal bool `json:"is_multimodal"`
|
||||
Type string `json:"type"`
|
||||
EnableAs pgtype.Text `json:"enable_as"`
|
||||
ModelID string `json:"model_id"`
|
||||
}
|
||||
|
||||
@@ -683,6 +731,7 @@ func (q *Queries) UpdateModelByModelID(ctx context.Context, arg UpdateModelByMod
|
||||
arg.Dimensions,
|
||||
arg.IsMultimodal,
|
||||
arg.Type,
|
||||
arg.EnableAs,
|
||||
arg.ModelID,
|
||||
)
|
||||
var i Model
|
||||
@@ -694,6 +743,7 @@ func (q *Queries) UpdateModelByModelID(ctx context.Context, arg UpdateModelByMod
|
||||
&i.Dimensions,
|
||||
&i.IsMultimodal,
|
||||
&i.Type,
|
||||
&i.EnableAs,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
|
||||
@@ -173,6 +173,23 @@ func (r *Resolver) selectEmbeddingModel(ctx context.Context, req Request) (model
|
||||
return models.GetResponse{}, errors.New("models service not configured")
|
||||
}
|
||||
|
||||
// If no model specified and no provider specified, try to get default embedding model
|
||||
if req.Model == "" && req.Provider == "" {
|
||||
defaultModel, err := r.modelsService.GetByEnableAs(ctx, models.EnableAsEmbedding)
|
||||
if err == nil {
|
||||
// Found default model, check if it matches the type requirement
|
||||
if req.Type == TypeMultimodal && !defaultModel.IsMultimodal {
|
||||
// Default is text, but need multimodal - continue to search
|
||||
} else if req.Type == TypeText && defaultModel.IsMultimodal {
|
||||
// Default is multimodal, but need text - continue to search
|
||||
} else {
|
||||
// Default model matches requirements
|
||||
return defaultModel, nil
|
||||
}
|
||||
}
|
||||
// No default model or doesn't match requirements, continue to search
|
||||
}
|
||||
|
||||
var candidates []models.GetResponse
|
||||
var err error
|
||||
if req.Provider != "" {
|
||||
|
||||
@@ -31,6 +31,13 @@ func (h *MemoryHandler) Register(e *echo.Echo) {
|
||||
group.DELETE("/memories", h.DeleteAll)
|
||||
}
|
||||
|
||||
func (h *MemoryHandler) checkService() error {
|
||||
if h.service == nil {
|
||||
return echo.NewHTTPError(http.StatusServiceUnavailable, "memory service not available: no embedding models configured")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// EmbedUpsert godoc
|
||||
// @Summary Embed and upsert memory
|
||||
// @Description Embed text or multimodal input and upsert into memory store
|
||||
@@ -41,6 +48,10 @@ func (h *MemoryHandler) Register(e *echo.Echo) {
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /memory/embed [post]
|
||||
func (h *MemoryHandler) EmbedUpsert(c echo.Context) error {
|
||||
if err := h.checkService(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
userID, err := h.requireUserID(c)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -72,6 +83,10 @@ func (h *MemoryHandler) EmbedUpsert(c echo.Context) error {
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /memory/add [post]
|
||||
func (h *MemoryHandler) Add(c echo.Context) error {
|
||||
if err := h.checkService(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
userID, err := h.requireUserID(c)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -103,6 +118,10 @@ func (h *MemoryHandler) Add(c echo.Context) error {
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /memory/search [post]
|
||||
func (h *MemoryHandler) Search(c echo.Context) error {
|
||||
if err := h.checkService(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
userID, err := h.requireUserID(c)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -134,6 +153,10 @@ func (h *MemoryHandler) Search(c echo.Context) error {
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /memory/update [post]
|
||||
func (h *MemoryHandler) Update(c echo.Context) error {
|
||||
if err := h.checkService(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
userID, err := h.requireUserID(c)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -170,6 +193,10 @@ func (h *MemoryHandler) Update(c echo.Context) error {
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /memory/memories/{memoryId} [get]
|
||||
func (h *MemoryHandler) Get(c echo.Context) error {
|
||||
if err := h.checkService(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
userID, err := h.requireUserID(c)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -203,6 +230,10 @@ func (h *MemoryHandler) Get(c echo.Context) error {
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /memory/memories [get]
|
||||
func (h *MemoryHandler) GetAll(c echo.Context) error {
|
||||
if err := h.checkService(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
userID, err := h.requireUserID(c)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -240,6 +271,10 @@ func (h *MemoryHandler) GetAll(c echo.Context) error {
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /memory/memories/{memoryId} [delete]
|
||||
func (h *MemoryHandler) Delete(c echo.Context) error {
|
||||
if err := h.checkService(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
userID, err := h.requireUserID(c)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -275,6 +310,10 @@ func (h *MemoryHandler) Delete(c echo.Context) error {
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /memory/memories [delete]
|
||||
func (h *MemoryHandler) DeleteAll(c echo.Context) error {
|
||||
if err := h.checkService(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
userID, err := h.requireUserID(c)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -22,6 +22,7 @@ func (h *ModelsHandler) Register(e *echo.Echo) {
|
||||
group.GET("", h.List)
|
||||
group.GET("/:id", h.GetByID)
|
||||
group.GET("/model/:modelId", h.GetByModelID)
|
||||
group.GET("/enable-as/:enableAs", h.GetByEnableAs)
|
||||
group.PUT("/:id", h.UpdateByID)
|
||||
group.PUT("/model/:modelId", h.UpdateByModelID)
|
||||
group.DELETE("/:id", h.DeleteByID)
|
||||
@@ -230,6 +231,38 @@ func (h *ModelsHandler) DeleteByModelID(c echo.Context) error {
|
||||
return c.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// GetByEnableAs godoc
|
||||
// @Summary Get model by enable_as
|
||||
// @Description Get the model that is enabled for a specific purpose (chat, memory, embedding)
|
||||
// @Tags models
|
||||
// @Param enableAs path string true "Enable as value (chat, memory, embedding)"
|
||||
// @Success 200 {object} models.GetResponse
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 404 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /models/enable-as/{enableAs} [get]
|
||||
// GetByEnableAs godoc
|
||||
// @Summary Get default model by enable_as
|
||||
// @Description Get the default model configured for a specific purpose (chat, memory, or embedding)
|
||||
// @Tags models
|
||||
// @Param enableAs path string true "Enable as value (chat, memory, embedding)"
|
||||
// @Success 200 {object} models.GetResponse
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 404 {object} ErrorResponse
|
||||
// @Router /models/enable-as/{enableAs} [get]
|
||||
func (h *ModelsHandler) GetByEnableAs(c echo.Context) error {
|
||||
enableAs := c.Param("enableAs")
|
||||
if enableAs == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "enableAs is required")
|
||||
}
|
||||
|
||||
resp, err := h.service.GetByEnableAs(c.Request().Context(), models.EnableAs(enableAs))
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusNotFound, err.Error())
|
||||
}
|
||||
return c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// Count godoc
|
||||
// @Summary Get model count
|
||||
// @Description Get the total count of models, optionally filtered by type
|
||||
|
||||
@@ -28,6 +28,13 @@ func (s *Service) Create(ctx context.Context, req AddRequest) (AddResponse, erro
|
||||
return AddResponse{}, fmt.Errorf("validation failed: %w", err)
|
||||
}
|
||||
|
||||
// If enable_as is set, clear any existing model with the same enable_as
|
||||
if model.EnableAs != nil {
|
||||
if err := s.queries.ClearEnableAs(ctx, pgtype.Text{String: string(*model.EnableAs), Valid: true}); err != nil {
|
||||
return AddResponse{}, fmt.Errorf("failed to clear existing enable_as: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Convert to sqlc params
|
||||
llmProviderID, err := parseUUID(model.LlmProviderID)
|
||||
if err != nil {
|
||||
@@ -51,6 +58,11 @@ func (s *Service) Create(ctx context.Context, req AddRequest) (AddResponse, erro
|
||||
params.Dimensions = pgtype.Int4{Int32: int32(model.Dimensions), Valid: true}
|
||||
}
|
||||
|
||||
// Handle optional enable_as field
|
||||
if model.EnableAs != nil {
|
||||
params.EnableAs = pgtype.Text{String: string(*model.EnableAs), Valid: true}
|
||||
}
|
||||
|
||||
created, err := s.queries.CreateModel(ctx, params)
|
||||
if err != nil {
|
||||
return AddResponse{}, fmt.Errorf("failed to create model: %w", err)
|
||||
@@ -151,6 +163,13 @@ func (s *Service) UpdateByID(ctx context.Context, id string, req UpdateRequest)
|
||||
return GetResponse{}, fmt.Errorf("validation failed: %w", err)
|
||||
}
|
||||
|
||||
// If enable_as is being set, clear any existing model with the same enable_as
|
||||
if model.EnableAs != nil {
|
||||
if err := s.queries.ClearEnableAs(ctx, pgtype.Text{String: string(*model.EnableAs), Valid: true}); err != nil {
|
||||
return GetResponse{}, fmt.Errorf("failed to clear existing enable_as: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
params := sqlc.UpdateModelParams{
|
||||
ID: uuid,
|
||||
IsMultimodal: model.IsMultimodal,
|
||||
@@ -171,6 +190,11 @@ func (s *Service) UpdateByID(ctx context.Context, id string, req UpdateRequest)
|
||||
params.Dimensions = pgtype.Int4{Int32: int32(model.Dimensions), Valid: true}
|
||||
}
|
||||
|
||||
// Handle optional enable_as field
|
||||
if model.EnableAs != nil {
|
||||
params.EnableAs = pgtype.Text{String: string(*model.EnableAs), Valid: true}
|
||||
}
|
||||
|
||||
updated, err := s.queries.UpdateModel(ctx, params)
|
||||
if err != nil {
|
||||
return GetResponse{}, fmt.Errorf("failed to update model: %w", err)
|
||||
@@ -190,6 +214,13 @@ func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req Updat
|
||||
return GetResponse{}, fmt.Errorf("validation failed: %w", err)
|
||||
}
|
||||
|
||||
// If enable_as is being set, clear any existing model with the same enable_as
|
||||
if model.EnableAs != nil {
|
||||
if err := s.queries.ClearEnableAs(ctx, pgtype.Text{String: string(*model.EnableAs), Valid: true}); err != nil {
|
||||
return GetResponse{}, fmt.Errorf("failed to clear existing enable_as: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
params := sqlc.UpdateModelByModelIDParams{
|
||||
ModelID: modelID,
|
||||
IsMultimodal: model.IsMultimodal,
|
||||
@@ -210,6 +241,11 @@ func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req Updat
|
||||
params.Dimensions = pgtype.Int4{Int32: int32(model.Dimensions), Valid: true}
|
||||
}
|
||||
|
||||
// Handle optional enable_as field
|
||||
if model.EnableAs != nil {
|
||||
params.EnableAs = pgtype.Text{String: string(*model.EnableAs), Valid: true}
|
||||
}
|
||||
|
||||
updated, err := s.queries.UpdateModelByModelID(ctx, params)
|
||||
if err != nil {
|
||||
return GetResponse{}, fmt.Errorf("failed to update model: %w", err)
|
||||
@@ -267,6 +303,20 @@ func (s *Service) CountByType(ctx context.Context, modelType ModelType) (int64,
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// GetByEnableAs retrieves the model that has the specified enable_as value
|
||||
func (s *Service) GetByEnableAs(ctx context.Context, enableAs EnableAs) (GetResponse, error) {
|
||||
if enableAs != EnableAsChat && enableAs != EnableAsMemory && enableAs != EnableAsEmbedding {
|
||||
return GetResponse{}, fmt.Errorf("invalid enable_as value: %s", enableAs)
|
||||
}
|
||||
|
||||
dbModel, err := s.queries.GetModelByEnableAs(ctx, pgtype.Text{String: string(enableAs), Valid: true})
|
||||
if err != nil {
|
||||
return GetResponse{}, fmt.Errorf("failed to get model by enable_as: %w", err)
|
||||
}
|
||||
|
||||
return convertToGetResponse(dbModel), nil
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func parseUUID(id string) (pgtype.UUID, error) {
|
||||
@@ -304,6 +354,11 @@ func convertToGetResponse(dbModel sqlc.Model) GetResponse {
|
||||
resp.Model.Dimensions = int(dbModel.Dimensions.Int32)
|
||||
}
|
||||
|
||||
if dbModel.EnableAs.Valid {
|
||||
enableAs := EnableAs(dbModel.EnableAs.String)
|
||||
resp.Model.EnableAs = &enableAs
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
|
||||
@@ -13,6 +13,14 @@ const (
|
||||
ModelTypeEmbedding ModelType = "embedding"
|
||||
)
|
||||
|
||||
// EnableAs names the purpose a model is enabled for.
type EnableAs string

// Recognized EnableAs values.
const (
	EnableAsChat      = EnableAs("chat")
	EnableAsMemory    = EnableAs("memory")
	EnableAsEmbedding = EnableAs("embedding")
)
|
||||
|
||||
type ClientType string
|
||||
|
||||
const (
|
||||
@@ -33,6 +41,7 @@ type Model struct {
|
||||
IsMultimodal bool `json:"is_multimodal"`
|
||||
Type ModelType `json:"type"`
|
||||
Dimensions int `json:"dimensions"`
|
||||
EnableAs *EnableAs `json:"enable_as,omitempty"`
|
||||
}
|
||||
|
||||
func (m *Model) Validate() error {
|
||||
@@ -51,6 +60,21 @@ func (m *Model) Validate() error {
|
||||
if m.Type == ModelTypeEmbedding && m.Dimensions <= 0 {
|
||||
return errors.New("dimensions must be greater than 0")
|
||||
}
|
||||
|
||||
// Validate enable_as based on type
|
||||
if m.EnableAs != nil {
|
||||
switch m.Type {
|
||||
case ModelTypeEmbedding:
|
||||
if *m.EnableAs != EnableAsEmbedding {
|
||||
return errors.New("embedding models can only have enable_as set to 'embedding'")
|
||||
}
|
||||
case ModelTypeChat:
|
||||
if *m.EnableAs != EnableAsChat && *m.EnableAs != EnableAsMemory {
|
||||
return errors.New("chat models can only have enable_as set to 'chat' or 'memory'")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user