diff --git a/cmd/agent/main.go b/cmd/agent/main.go index acd4f649..3ad8df3d 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "fmt" "log" "os" "strings" @@ -10,12 +11,68 @@ import ( "github.com/memohai/memoh/internal/config" ctr "github.com/memohai/memoh/internal/containerd" "github.com/memohai/memoh/internal/db" + dbsqlc "github.com/memohai/memoh/internal/db/sqlc" + "github.com/memohai/memoh/internal/embeddings" "github.com/memohai/memoh/internal/handlers" "github.com/memohai/memoh/internal/mcp" "github.com/memohai/memoh/internal/memory" + "github.com/memohai/memoh/internal/models" "github.com/memohai/memoh/internal/server" ) +type resolverTextEmbedder struct { + resolver *embeddings.Resolver + modelID string + dims int +} + +func (e *resolverTextEmbedder) Embed(ctx context.Context, input string) ([]float32, error) { + result, err := e.resolver.Embed(ctx, embeddings.Request{ + Type: embeddings.TypeText, + Model: e.modelID, + Input: embeddings.Input{Text: input}, + }) + if err != nil { + return nil, err + } + return result.Embedding, nil +} + +func (e *resolverTextEmbedder) Dimensions() int { + return e.dims +} + +func collectEmbeddingVectors(ctx context.Context, service *models.Service) (map[string]int, models.GetResponse, models.GetResponse, error) { + candidates, err := service.ListByType(ctx, models.ModelTypeEmbedding) + if err != nil { + return nil, models.GetResponse{}, models.GetResponse{}, err + } + vectors := map[string]int{} + var textModel models.GetResponse + var multimodalModel models.GetResponse + for _, model := range candidates { + if model.Dimensions > 0 && model.ModelID != "" { + vectors[model.ModelID] = model.Dimensions + } + if model.IsMultimodal { + if multimodalModel.ModelID == "" { + multimodalModel = model + } + continue + } + if textModel.ModelID == "" { + textModel = model + } + } + if textModel.ModelID == "" { + return vectors, textModel, multimodalModel, fmt.Errorf("no text embedding 
model configured") + } + if multimodalModel.ModelID == "" { + return vectors, textModel, multimodalModel, fmt.Errorf("no multimodal embedding model configured") + } + return vectors, textModel, multimodalModel, nil +} + func main() { ctx := context.Background() cfgPath := os.Getenv("CONFIG_PATH") @@ -53,6 +110,8 @@ func main() { } defer conn.Close() manager.WithDB(conn) + queries := dbsqlc.New(conn) + modelsService := models.NewService(queries) pingHandler := handlers.NewPingHandler() authHandler := handlers.NewAuthHandler(conn, cfg.Auth.JWTSecret, jwtExpiresIn) @@ -62,28 +121,49 @@ func main() { cfg.Memory.Model, time.Duration(cfg.Memory.TimeoutSeconds)*time.Second, ) - embedder := memory.NewOpenAIEmbedder( - cfg.Embeddings.OpenAIAPIKey, - cfg.Embeddings.OpenAIBaseURL, - cfg.Embeddings.Model, - cfg.Embeddings.Dimensions, - time.Duration(cfg.Embeddings.TimeoutSeconds)*time.Second, - ) - store, err := memory.NewQdrantStore( - cfg.Qdrant.BaseURL, - cfg.Qdrant.APIKey, - cfg.Qdrant.Collection, - cfg.Embeddings.Dimensions, - time.Duration(cfg.Qdrant.TimeoutSeconds)*time.Second, - ) + resolver := embeddings.NewResolver(modelsService, queries, 10*time.Second) + vectors, textModel, multimodalModel, err := collectEmbeddingVectors(ctx, modelsService) if err != nil { - log.Fatalf("qdrant init: %v", err) + log.Fatalf("embedding models: %v", err) } - memoryService := memory.NewService(llmClient, embedder, store) + if textModel.Dimensions <= 0 { + log.Fatalf("text embedding dimensions not configured") + } + textEmbedder := &resolverTextEmbedder{ + resolver: resolver, + modelID: textModel.ModelID, + dims: textModel.Dimensions, + } + var store *memory.QdrantStore + if len(vectors) > 0 { + store, err = memory.NewQdrantStoreWithVectors( + cfg.Qdrant.BaseURL, + cfg.Qdrant.APIKey, + cfg.Qdrant.Collection, + vectors, + time.Duration(cfg.Qdrant.TimeoutSeconds)*time.Second, + ) + if err != nil { + log.Fatalf("qdrant named vectors init: %v", err) + } + } else { + store, err = 
memory.NewQdrantStore( + cfg.Qdrant.BaseURL, + cfg.Qdrant.APIKey, + cfg.Qdrant.Collection, + textModel.Dimensions, + time.Duration(cfg.Qdrant.TimeoutSeconds)*time.Second, + ) + if err != nil { + log.Fatalf("qdrant init: %v", err) + } + } + memoryService := memory.NewService(llmClient, textEmbedder, store, resolver, textModel.ModelID, multimodalModel.ModelID) memoryHandler := handlers.NewMemoryHandler(memoryService) + embeddingsHandler := handlers.NewEmbeddingsHandler(modelsService, queries) fsHandler := handlers.NewFSHandler(service, manager, cfg.MCP, cfg.Containerd.Namespace) swaggerHandler := handlers.NewSwaggerHandler() - srv := server.NewServer(addr, cfg.Auth.JWTSecret, pingHandler, authHandler, memoryHandler, fsHandler, swaggerHandler) + srv := server.NewServer(addr, cfg.Auth.JWTSecret, pingHandler, authHandler, memoryHandler, embeddingsHandler, fsHandler, swaggerHandler) if err := srv.Start(); err != nil { log.Fatalf("server failed: %v", err) diff --git a/config.toml.example b/config.toml.example index a2dcd28b..1fac56b6 100644 --- a/config.toml.example +++ b/config.toml.example @@ -39,14 +39,5 @@ timeout_seconds = 10 [qdrant] base_url = "http://127.0.0.1:6334" api_key = "" -collection = "mem0" -timeout_seconds = 10 - -## Embeddings configuration -[embeddings] -provider = "openai" -openai_api_key = "" -openai_base_url = "https://api.openai.com/v1" -model = "text-embedding-3-small" -dimensions = 1536 +collection = "memory" timeout_seconds = 10 \ No newline at end of file diff --git a/db/migrations/0001_init.up.sql b/db/migrations/0001_init.up.sql index 48639ab7..eebaccd8 100644 --- a/db/migrations/0001_init.up.sql +++ b/db/migrations/0001_init.up.sql @@ -55,25 +55,50 @@ CREATE TABLE IF NOT EXISTS snapshots ( created_at TIMESTAMPTZ NOT NULL DEFAULT now() ); +CREATE INDEX IF NOT EXISTS idx_snapshots_container_id ON snapshots(container_id); +CREATE INDEX IF NOT EXISTS idx_snapshots_parent_id ON snapshots(parent_snapshot_id); + +CREATE TABLE IF NOT EXISTS 
llm_providers ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + name TEXT NOT NULL, + client_type TEXT NOT NULL, + base_url TEXT NOT NULL, + api_key TEXT NOT NULL, + metadata JSONB NOT NULL DEFAULT '{}'::jsonb, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + CONSTRAINT llm_providers_name_unique UNIQUE (name), + CONSTRAINT llm_providers_client_type_check CHECK (client_type IN ('openai', 'anthropic', 'google', 'bedrock', 'ollama', 'azure', 'dashscope', 'other')) +); + CREATE TABLE IF NOT EXISTS models ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), model_id TEXT NOT NULL, name TEXT, - base_url TEXT NOT NULL, - api_key TEXT NOT NULL, - client_type TEXT NOT NULL, + llm_provider_id UUID NOT NULL REFERENCES llm_providers(id) ON DELETE CASCADE, dimensions INTEGER, + is_multimodal BOOLEAN NOT NULL DEFAULT false, type TEXT NOT NULL DEFAULT 'chat', created_at TIMESTAMPTZ NOT NULL DEFAULT now(), updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), CONSTRAINT models_model_id_unique UNIQUE (model_id), CONSTRAINT models_type_check CHECK (type IN ('chat', 'embedding')), - CONSTRAINT models_client_type_check CHECK (client_type IN ('openai', 'anthropic', 'google')), CONSTRAINT models_dimensions_check CHECK (type != 'embedding' OR dimensions IS NOT NULL) ); -CREATE INDEX IF NOT EXISTS idx_snapshots_container_id ON snapshots(container_id); -CREATE INDEX IF NOT EXISTS idx_snapshots_parent_id ON snapshots(parent_snapshot_id); +CREATE TABLE IF NOT EXISTS model_variants ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + model_uuid UUID NOT NULL REFERENCES models(id) ON DELETE CASCADE, + variant_id TEXT NOT NULL, + weight INTEGER NOT NULL, + metadata JSONB NOT NULL DEFAULT '{}'::jsonb, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE INDEX IF NOT EXISTS idx_model_variants_model_uuid ON model_variants(model_uuid); +CREATE INDEX IF NOT EXISTS idx_model_variants_variant_id ON 
model_variants(variant_id); + CREATE TABLE IF NOT EXISTS container_versions ( id TEXT PRIMARY KEY, diff --git a/db/queries/models.sql b/db/queries/models.sql index 6d21f972..fe444bbb 100644 --- a/db/queries/models.sql +++ b/db/queries/models.sql @@ -1,12 +1,58 @@ +-- name: CreateLlmProvider :one +INSERT INTO llm_providers (name, client_type, base_url, api_key, metadata) +VALUES ( + sqlc.arg(name), + sqlc.arg(client_type), + sqlc.arg(base_url), + sqlc.arg(api_key), + sqlc.arg(metadata) +) +RETURNING *; + +-- name: GetLlmProviderByID :one +SELECT * FROM llm_providers WHERE id = sqlc.arg(id); + +-- name: GetLlmProviderByName :one +SELECT * FROM llm_providers WHERE name = sqlc.arg(name); + +-- name: ListLlmProviders :many +SELECT * FROM llm_providers +ORDER BY created_at DESC; + +-- name: ListLlmProvidersByClientType :many +SELECT * FROM llm_providers +WHERE client_type = sqlc.arg(client_type) +ORDER BY created_at DESC; + +-- name: UpdateLlmProvider :one +UPDATE llm_providers +SET + name = sqlc.arg(name), + client_type = sqlc.arg(client_type), + base_url = sqlc.arg(base_url), + api_key = sqlc.arg(api_key), + metadata = sqlc.arg(metadata), + updated_at = now() +WHERE id = sqlc.arg(id) +RETURNING *; + +-- name: DeleteLlmProvider :exec +DELETE FROM llm_providers WHERE id = sqlc.arg(id); + +-- name: CountLlmProviders :one +SELECT COUNT(*) FROM llm_providers; + +-- name: CountLlmProvidersByClientType :one +SELECT COUNT(*) FROM llm_providers WHERE client_type = sqlc.arg(client_type); + -- name: CreateModel :one -INSERT INTO models (model_id, name, base_url, api_key, client_type, dimensions, type) +INSERT INTO models (model_id, name, llm_provider_id, dimensions, is_multimodal, type) VALUES ( sqlc.arg(model_id), sqlc.arg(name), - sqlc.arg(base_url), - sqlc.arg(api_key), - sqlc.arg(client_type), + sqlc.arg(llm_provider_id), sqlc.arg(dimensions), + sqlc.arg(is_multimodal), sqlc.arg(type) ) RETURNING *; @@ -27,18 +73,18 @@ WHERE type = sqlc.arg(type) ORDER BY created_at DESC; -- 
name: ListModelsByClientType :many -SELECT * FROM models -WHERE client_type = sqlc.arg(client_type) -ORDER BY created_at DESC; +SELECT m.* FROM models AS m +JOIN llm_providers AS p ON p.id = m.llm_provider_id +WHERE p.client_type = sqlc.arg(client_type) +ORDER BY m.created_at DESC; -- name: UpdateModel :one UPDATE models SET name = sqlc.arg(name), - base_url = sqlc.arg(base_url), - api_key = sqlc.arg(api_key), - client_type = sqlc.arg(client_type), + llm_provider_id = sqlc.arg(llm_provider_id), dimensions = sqlc.arg(dimensions), + is_multimodal = sqlc.arg(is_multimodal), type = sqlc.arg(type), updated_at = now() WHERE id = sqlc.arg(id) @@ -48,10 +94,9 @@ RETURNING *; UPDATE models SET name = sqlc.arg(name), - base_url = sqlc.arg(base_url), - api_key = sqlc.arg(api_key), - client_type = sqlc.arg(client_type), + llm_provider_id = sqlc.arg(llm_provider_id), dimensions = sqlc.arg(dimensions), + is_multimodal = sqlc.arg(is_multimodal), type = sqlc.arg(type), updated_at = now() WHERE model_id = sqlc.arg(model_id) @@ -69,3 +114,39 @@ SELECT COUNT(*) FROM models; -- name: CountModelsByType :one SELECT COUNT(*) FROM models WHERE type = sqlc.arg(type); +-- name: CreateModelVariant :one +INSERT INTO model_variants (model_uuid, variant_id, weight, metadata) +VALUES ( + sqlc.arg(model_uuid), + sqlc.arg(variant_id), + sqlc.arg(weight), + sqlc.arg(metadata) +) +RETURNING *; + +-- name: GetModelVariantByID :one +SELECT * FROM model_variants WHERE id = sqlc.arg(id); + +-- name: ListModelVariantsByModelUUID :many +SELECT * FROM model_variants +WHERE model_uuid = sqlc.arg(model_uuid) +ORDER BY weight DESC, created_at DESC; + +-- name: ListModelVariantsByVariantID :many +SELECT * FROM model_variants +WHERE variant_id = sqlc.arg(variant_id) +ORDER BY created_at DESC; + +-- name: UpdateModelVariant :one +UPDATE model_variants +SET + variant_id = sqlc.arg(variant_id), + weight = sqlc.arg(weight), + metadata = sqlc.arg(metadata), + updated_at = now() +WHERE id = sqlc.arg(id) +RETURNING *; 
+ +-- name: DeleteModelVariant :exec +DELETE FROM model_variants WHERE id = sqlc.arg(id); + diff --git a/docs/docs.go b/docs/docs.go index 2d46cf22..4d7b476d 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -61,6 +61,52 @@ const docTemplate = `{ } } }, + "/embeddings": { + "post": { + "description": "Create text or multimodal embeddings", + "tags": [ + "embeddings" + ], + "summary": "Create embeddings", + "parameters": [ + { + "description": "Embeddings request", + "name": "payload", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/handlers.EmbeddingsRequest" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/handlers.EmbeddingsResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "501": { + "description": "Not Implemented", + "schema": { + "$ref": "#/definitions/handlers.EmbeddingsResponse" + } + } + } + } + }, "/fs/apply_patch": { "post": { "description": "Apply a unified diff patch to a file under the user data mount", @@ -364,6 +410,46 @@ const docTemplate = `{ } } }, + "/memory/embed": { + "post": { + "description": "Embed text or multimodal input and upsert into memory store", + "tags": [ + "memory" + ], + "summary": "Embed and upsert memory", + "parameters": [ + { + "description": "Embed upsert request", + "name": "payload", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/memory.EmbedUpsertRequest" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/memory.EmbedUpsertResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": 
"#/definitions/handlers.ErrorResponse" + } + } + } + } + }, "/memory/memories": { "get": { "description": "List memories for a user via memory", @@ -1044,6 +1130,83 @@ const docTemplate = `{ } } }, + "handlers.EmbeddingsInput": { + "type": "object", + "properties": { + "image_url": { + "type": "string" + }, + "text": { + "type": "string" + }, + "video_url": { + "type": "string" + } + } + }, + "handlers.EmbeddingsRequest": { + "type": "object", + "properties": { + "dimensions": { + "type": "integer" + }, + "input": { + "$ref": "#/definitions/handlers.EmbeddingsInput" + }, + "model": { + "type": "string" + }, + "provider": { + "type": "string" + }, + "type": { + "type": "string" + } + } + }, + "handlers.EmbeddingsResponse": { + "type": "object", + "properties": { + "dimensions": { + "type": "integer" + }, + "embedding": { + "type": "array", + "items": { + "type": "number" + } + }, + "message": { + "type": "string" + }, + "model": { + "type": "string" + }, + "provider": { + "type": "string" + }, + "type": { + "type": "string" + }, + "usage": { + "$ref": "#/definitions/handlers.EmbeddingsUsage" + } + } + }, + "handlers.EmbeddingsUsage": { + "type": "object", + "properties": { + "image_tokens": { + "type": "integer" + }, + "input_tokens": { + "type": "integer" + }, + "video_tokens": { + "type": "integer" + } + } + }, "handlers.ErrorResponse": { "type": "object", "properties": { @@ -1216,6 +1379,74 @@ const docTemplate = `{ } } }, + "memory.EmbedInput": { + "type": "object", + "properties": { + "image_url": { + "type": "string" + }, + "text": { + "type": "string" + }, + "video_url": { + "type": "string" + } + } + }, + "memory.EmbedUpsertRequest": { + "type": "object", + "properties": { + "agent_id": { + "type": "string" + }, + "filters": { + "type": "object", + "additionalProperties": true + }, + "input": { + "$ref": "#/definitions/memory.EmbedInput" + }, + "metadata": { + "type": "object", + "additionalProperties": true + }, + "model": { + "type": "string" + }, + 
"provider": { + "type": "string" + }, + "run_id": { + "type": "string" + }, + "source": { + "type": "string" + }, + "type": { + "type": "string" + }, + "user_id": { + "type": "string" + } + } + }, + "memory.EmbedUpsertResponse": { + "type": "object", + "properties": { + "dimensions": { + "type": "integer" + }, + "item": { + "$ref": "#/definitions/memory.MemoryItem" + }, + "model": { + "type": "string" + }, + "provider": { + "type": "string" + } + } + }, "memory.MemoryItem": { "type": "object", "properties": { @@ -1282,6 +1513,12 @@ const docTemplate = `{ "run_id": { "type": "string" }, + "sources": { + "type": "array", + "items": { + "type": "string" + } + }, "user_id": { "type": "string" } @@ -1316,18 +1553,15 @@ const docTemplate = `{ "models.AddRequest": { "type": "object", "properties": { - "api_key": { - "type": "string" - }, - "base_url": { - "type": "string" - }, - "client_type": { - "$ref": "#/definitions/models.ClientType" - }, "dimensions": { "type": "integer" }, + "is_multimodal": { + "type": "boolean" + }, + "llm_provider_id": { + "type": "string" + }, "model_id": { "type": "string" }, @@ -1335,7 +1569,7 @@ const docTemplate = `{ "type": "string" }, "type": { - "type": "string" + "$ref": "#/definitions/models.ModelType" } } }, @@ -1350,19 +1584,6 @@ const docTemplate = `{ } } }, - "models.ClientType": { - "type": "string", - "enum": [ - "openai", - "anthropic", - "google" - ], - "x-enum-varnames": [ - "ClientTypeOpenAI", - "ClientTypeAnthropic", - "ClientTypeGoogle" - ] - }, "models.CountResponse": { "type": "object", "properties": { @@ -1374,18 +1595,15 @@ const docTemplate = `{ "models.GetResponse": { "type": "object", "properties": { - "api_key": { - "type": "string" - }, - "base_url": { - "type": "string" - }, - "client_type": { - "$ref": "#/definitions/models.ClientType" - }, "dimensions": { "type": "integer" }, + "is_multimodal": { + "type": "boolean" + }, + "llm_provider_id": { + "type": "string" + }, "model_id": { "type": "string" }, @@ -1393,25 
+1611,33 @@ const docTemplate = `{ "type": "string" }, "type": { - "type": "string" + "$ref": "#/definitions/models.ModelType" } } }, + "models.ModelType": { + "type": "string", + "enum": [ + "chat", + "embedding" + ], + "x-enum-varnames": [ + "ModelTypeChat", + "ModelTypeEmbedding" + ] + }, "models.UpdateRequest": { "type": "object", "properties": { - "api_key": { - "type": "string" - }, - "base_url": { - "type": "string" - }, - "client_type": { - "$ref": "#/definitions/models.ClientType" - }, "dimensions": { "type": "integer" }, + "is_multimodal": { + "type": "boolean" + }, + "llm_provider_id": { + "type": "string" + }, "model_id": { "type": "string" }, @@ -1419,7 +1645,7 @@ const docTemplate = `{ "type": "string" }, "type": { - "type": "string" + "$ref": "#/definitions/models.ModelType" } } } diff --git a/docs/swagger.json b/docs/swagger.json index 0c058f78..ec81d649 100644 --- a/docs/swagger.json +++ b/docs/swagger.json @@ -50,6 +50,52 @@ } } }, + "/embeddings": { + "post": { + "description": "Create text or multimodal embeddings", + "tags": [ + "embeddings" + ], + "summary": "Create embeddings", + "parameters": [ + { + "description": "Embeddings request", + "name": "payload", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/handlers.EmbeddingsRequest" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/handlers.EmbeddingsResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "501": { + "description": "Not Implemented", + "schema": { + "$ref": "#/definitions/handlers.EmbeddingsResponse" + } + } + } + } + }, "/fs/apply_patch": { "post": { "description": "Apply a unified diff patch to a file under the user data mount", @@ -353,6 +399,46 @@ } } }, + "/memory/embed": { + "post": { + 
"description": "Embed text or multimodal input and upsert into memory store", + "tags": [ + "memory" + ], + "summary": "Embed and upsert memory", + "parameters": [ + { + "description": "Embed upsert request", + "name": "payload", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/memory.EmbedUpsertRequest" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/memory.EmbedUpsertResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, "/memory/memories": { "get": { "description": "List memories for a user via memory", @@ -1033,6 +1119,83 @@ } } }, + "handlers.EmbeddingsInput": { + "type": "object", + "properties": { + "image_url": { + "type": "string" + }, + "text": { + "type": "string" + }, + "video_url": { + "type": "string" + } + } + }, + "handlers.EmbeddingsRequest": { + "type": "object", + "properties": { + "dimensions": { + "type": "integer" + }, + "input": { + "$ref": "#/definitions/handlers.EmbeddingsInput" + }, + "model": { + "type": "string" + }, + "provider": { + "type": "string" + }, + "type": { + "type": "string" + } + } + }, + "handlers.EmbeddingsResponse": { + "type": "object", + "properties": { + "dimensions": { + "type": "integer" + }, + "embedding": { + "type": "array", + "items": { + "type": "number" + } + }, + "message": { + "type": "string" + }, + "model": { + "type": "string" + }, + "provider": { + "type": "string" + }, + "type": { + "type": "string" + }, + "usage": { + "$ref": "#/definitions/handlers.EmbeddingsUsage" + } + } + }, + "handlers.EmbeddingsUsage": { + "type": "object", + "properties": { + "image_tokens": { + "type": "integer" + }, + "input_tokens": { + "type": "integer" + }, + "video_tokens": { + "type": "integer" + } + } + }, 
"handlers.ErrorResponse": { "type": "object", "properties": { @@ -1205,6 +1368,74 @@ } } }, + "memory.EmbedInput": { + "type": "object", + "properties": { + "image_url": { + "type": "string" + }, + "text": { + "type": "string" + }, + "video_url": { + "type": "string" + } + } + }, + "memory.EmbedUpsertRequest": { + "type": "object", + "properties": { + "agent_id": { + "type": "string" + }, + "filters": { + "type": "object", + "additionalProperties": true + }, + "input": { + "$ref": "#/definitions/memory.EmbedInput" + }, + "metadata": { + "type": "object", + "additionalProperties": true + }, + "model": { + "type": "string" + }, + "provider": { + "type": "string" + }, + "run_id": { + "type": "string" + }, + "source": { + "type": "string" + }, + "type": { + "type": "string" + }, + "user_id": { + "type": "string" + } + } + }, + "memory.EmbedUpsertResponse": { + "type": "object", + "properties": { + "dimensions": { + "type": "integer" + }, + "item": { + "$ref": "#/definitions/memory.MemoryItem" + }, + "model": { + "type": "string" + }, + "provider": { + "type": "string" + } + } + }, "memory.MemoryItem": { "type": "object", "properties": { @@ -1271,6 +1502,12 @@ "run_id": { "type": "string" }, + "sources": { + "type": "array", + "items": { + "type": "string" + } + }, "user_id": { "type": "string" } @@ -1305,18 +1542,15 @@ "models.AddRequest": { "type": "object", "properties": { - "api_key": { - "type": "string" - }, - "base_url": { - "type": "string" - }, - "client_type": { - "$ref": "#/definitions/models.ClientType" - }, "dimensions": { "type": "integer" }, + "is_multimodal": { + "type": "boolean" + }, + "llm_provider_id": { + "type": "string" + }, "model_id": { "type": "string" }, @@ -1324,7 +1558,7 @@ "type": "string" }, "type": { - "type": "string" + "$ref": "#/definitions/models.ModelType" } } }, @@ -1339,19 +1573,6 @@ } } }, - "models.ClientType": { - "type": "string", - "enum": [ - "openai", - "anthropic", - "google" - ], - "x-enum-varnames": [ - 
"ClientTypeOpenAI", - "ClientTypeAnthropic", - "ClientTypeGoogle" - ] - }, "models.CountResponse": { "type": "object", "properties": { @@ -1363,18 +1584,15 @@ "models.GetResponse": { "type": "object", "properties": { - "api_key": { - "type": "string" - }, - "base_url": { - "type": "string" - }, - "client_type": { - "$ref": "#/definitions/models.ClientType" - }, "dimensions": { "type": "integer" }, + "is_multimodal": { + "type": "boolean" + }, + "llm_provider_id": { + "type": "string" + }, "model_id": { "type": "string" }, @@ -1382,25 +1600,33 @@ "type": "string" }, "type": { - "type": "string" + "$ref": "#/definitions/models.ModelType" } } }, + "models.ModelType": { + "type": "string", + "enum": [ + "chat", + "embedding" + ], + "x-enum-varnames": [ + "ModelTypeChat", + "ModelTypeEmbedding" + ] + }, "models.UpdateRequest": { "type": "object", "properties": { - "api_key": { - "type": "string" - }, - "base_url": { - "type": "string" - }, - "client_type": { - "$ref": "#/definitions/models.ClientType" - }, "dimensions": { "type": "integer" }, + "is_multimodal": { + "type": "boolean" + }, + "llm_provider_id": { + "type": "string" + }, "model_id": { "type": "string" }, @@ -1408,7 +1634,7 @@ "type": "string" }, "type": { - "type": "string" + "$ref": "#/definitions/models.ModelType" } } } diff --git a/docs/swagger.yaml b/docs/swagger.yaml index af179497..e7a88ec5 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -26,6 +26,56 @@ definitions: version: type: integer type: object + handlers.EmbeddingsInput: + properties: + image_url: + type: string + text: + type: string + video_url: + type: string + type: object + handlers.EmbeddingsRequest: + properties: + dimensions: + type: integer + input: + $ref: '#/definitions/handlers.EmbeddingsInput' + model: + type: string + provider: + type: string + type: + type: string + type: object + handlers.EmbeddingsResponse: + properties: + dimensions: + type: integer + embedding: + items: + type: number + type: array + message: + 
type: string + model: + type: string + provider: + type: string + type: + type: string + usage: + $ref: '#/definitions/handlers.EmbeddingsUsage' + type: object + handlers.EmbeddingsUsage: + properties: + image_tokens: + type: integer + input_tokens: + type: integer + video_tokens: + type: integer + type: object handlers.ErrorResponse: properties: message: @@ -138,6 +188,51 @@ definitions: message: type: string type: object + memory.EmbedInput: + properties: + image_url: + type: string + text: + type: string + video_url: + type: string + type: object + memory.EmbedUpsertRequest: + properties: + agent_id: + type: string + filters: + additionalProperties: true + type: object + input: + $ref: '#/definitions/memory.EmbedInput' + metadata: + additionalProperties: true + type: object + model: + type: string + provider: + type: string + run_id: + type: string + source: + type: string + type: + type: string + user_id: + type: string + type: object + memory.EmbedUpsertResponse: + properties: + dimensions: + type: integer + item: + $ref: '#/definitions/memory.MemoryItem' + model: + type: string + provider: + type: string + type: object memory.MemoryItem: properties: agentId: @@ -182,6 +277,10 @@ definitions: type: string run_id: type: string + sources: + items: + type: string + type: array user_id: type: string type: object @@ -204,20 +303,18 @@ definitions: type: object models.AddRequest: properties: - api_key: - type: string - base_url: - type: string - client_type: - $ref: '#/definitions/models.ClientType' dimensions: type: integer + is_multimodal: + type: boolean + llm_provider_id: + type: string model_id: type: string name: type: string type: - type: string + $ref: '#/definitions/models.ModelType' type: object models.AddResponse: properties: @@ -226,16 +323,6 @@ definitions: model_id: type: string type: object - models.ClientType: - enum: - - openai - - anthropic - - google - type: string - x-enum-varnames: - - ClientTypeOpenAI - - ClientTypeAnthropic - - 
ClientTypeGoogle models.CountResponse: properties: count: @@ -243,37 +330,41 @@ definitions: type: object models.GetResponse: properties: - api_key: - type: string - base_url: - type: string - client_type: - $ref: '#/definitions/models.ClientType' dimensions: type: integer + is_multimodal: + type: boolean + llm_provider_id: + type: string model_id: type: string name: type: string type: - type: string + $ref: '#/definitions/models.ModelType' type: object + models.ModelType: + enum: + - chat + - embedding + type: string + x-enum-varnames: + - ModelTypeChat + - ModelTypeEmbedding models.UpdateRequest: properties: - api_key: - type: string - base_url: - type: string - client_type: - $ref: '#/definitions/models.ClientType' dimensions: type: integer + is_multimodal: + type: boolean + llm_provider_id: + type: string model_id: type: string name: type: string type: - type: string + $ref: '#/definitions/models.ModelType' type: object info: contact: {} @@ -308,6 +399,36 @@ paths: summary: Login tags: - auth + /embeddings: + post: + description: Create text or multimodal embeddings + parameters: + - description: Embeddings request + in: body + name: payload + required: true + schema: + $ref: '#/definitions/handlers.EmbeddingsRequest' + responses: + "200": + description: OK + schema: + $ref: '#/definitions/handlers.EmbeddingsResponse' + "400": + description: Bad Request + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "501": + description: Not Implemented + schema: + $ref: '#/definitions/handlers.EmbeddingsResponse' + summary: Create embeddings + tags: + - embeddings /fs/apply_patch: post: description: Apply a unified diff patch to a file under the user data mount @@ -506,6 +627,32 @@ paths: summary: Add memory tags: - memory + /memory/embed: + post: + description: Embed text or multimodal input and upsert into memory store + parameters: + - description: Embed upsert 
request + in: body + name: payload + required: true + schema: + $ref: '#/definitions/memory.EmbedUpsertRequest' + responses: + "200": + description: OK + schema: + $ref: '#/definitions/memory.EmbedUpsertResponse' + "400": + description: Bad Request + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/handlers.ErrorResponse' + summary: Embed and upsert memory + tags: + - memory /memory/memories: delete: description: Delete all memories for a user via memory diff --git a/internal/config/config.go b/internal/config/config.go index d8a2361b..ea2c6e2e 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -24,8 +24,6 @@ const ( DefaultMemoryTimeout = 10 DefaultQdrantURL = "http://127.0.0.1:6334" DefaultQdrantCollection = "memory" - DefaultEmbeddingModel = "text-embedding-3-small" - DefaultEmbeddingDims = 1536 ) type Config struct { @@ -36,7 +34,6 @@ type Config struct { Postgres PostgresConfig `toml:"postgres"` Memory MemoryConfig `toml:"memory"` Qdrant QdrantConfig `toml:"qdrant"` - Embeddings EmbeddingsConfig `toml:"embeddings"` } type ServerConfig struct { @@ -83,15 +80,6 @@ type QdrantConfig struct { TimeoutSeconds int `toml:"timeout_seconds"` } -type EmbeddingsConfig struct { - Provider string `toml:"provider"` - OpenAIAPIKey string `toml:"openai_api_key"` - OpenAIBaseURL string `toml:"openai_base_url"` - Model string `toml:"model"` - Dimensions int `toml:"dimensions"` - TimeoutSeconds int `toml:"timeout_seconds"` -} - func Load(path string) (Config, error) { cfg := Config{ Server: ServerConfig{ @@ -125,11 +113,6 @@ func Load(path string) (Config, error) { BaseURL: DefaultQdrantURL, Collection: DefaultQdrantCollection, }, - Embeddings: EmbeddingsConfig{ - Provider: "openai", - Model: DefaultEmbeddingModel, - Dimensions: DefaultEmbeddingDims, - }, } if path == "" { diff --git a/internal/db/sqlc/models.go b/internal/db/sqlc/models.go index 63b9d97a..85eee040 100644 
--- a/internal/db/sqlc/models.go +++ b/internal/db/sqlc/models.go @@ -41,19 +41,39 @@ type LifecycleEvent struct { CreatedAt pgtype.Timestamptz `json:"created_at"` } -type Model struct { +type LlmProvider struct { ID pgtype.UUID `json:"id"` - ModelID string `json:"model_id"` - Name pgtype.Text `json:"name"` + Name string `json:"name"` + ClientType string `json:"client_type"` BaseUrl string `json:"base_url"` ApiKey string `json:"api_key"` - ClientType string `json:"client_type"` - Dimensions pgtype.Int4 `json:"dimensions"` - Type string `json:"type"` + Metadata []byte `json:"metadata"` CreatedAt pgtype.Timestamptz `json:"created_at"` UpdatedAt pgtype.Timestamptz `json:"updated_at"` } +type Model struct { + ID pgtype.UUID `json:"id"` + ModelID string `json:"model_id"` + Name pgtype.Text `json:"name"` + LlmProviderID pgtype.UUID `json:"llm_provider_id"` + Dimensions pgtype.Int4 `json:"dimensions"` + IsMultimodal bool `json:"is_multimodal"` + Type string `json:"type"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` +} + +type ModelVariant struct { + ID pgtype.UUID `json:"id"` + ModelUuid pgtype.UUID `json:"model_uuid"` + VariantID string `json:"variant_id"` + Weight int32 `json:"weight"` + Metadata []byte `json:"metadata"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` +} + type Snapshot struct { ID string `json:"id"` ContainerID string `json:"container_id"` diff --git a/internal/db/sqlc/models.sql.go b/internal/db/sqlc/models.sql.go index 9c4d23be..e66c3ce2 100644 --- a/internal/db/sqlc/models.sql.go +++ b/internal/db/sqlc/models.sql.go @@ -11,6 +11,28 @@ import ( "github.com/jackc/pgx/v5/pgtype" ) +const countLlmProviders = `-- name: CountLlmProviders :one +SELECT COUNT(*) FROM llm_providers +` + +func (q *Queries) CountLlmProviders(ctx context.Context) (int64, error) { + row := q.db.QueryRow(ctx, countLlmProviders) + var count int64 + err := row.Scan(&count) 
+ return count, err +} + +const countLlmProvidersByClientType = `-- name: CountLlmProvidersByClientType :one +SELECT COUNT(*) FROM llm_providers WHERE client_type = $1 +` + +func (q *Queries) CountLlmProvidersByClientType(ctx context.Context, clientType string) (int64, error) { + row := q.db.QueryRow(ctx, countLlmProvidersByClientType, clientType) + var count int64 + err := row.Scan(&count) + return count, err +} + const countModels = `-- name: CountModels :one SELECT COUNT(*) FROM models ` @@ -33,38 +55,77 @@ func (q *Queries) CountModelsByType(ctx context.Context, type_ string) (int64, e return count, err } +const createLlmProvider = `-- name: CreateLlmProvider :one +INSERT INTO llm_providers (name, client_type, base_url, api_key, metadata) +VALUES ( + $1, + $2, + $3, + $4, + $5 +) +RETURNING id, name, client_type, base_url, api_key, metadata, created_at, updated_at +` + +type CreateLlmProviderParams struct { + Name string `json:"name"` + ClientType string `json:"client_type"` + BaseUrl string `json:"base_url"` + ApiKey string `json:"api_key"` + Metadata []byte `json:"metadata"` +} + +func (q *Queries) CreateLlmProvider(ctx context.Context, arg CreateLlmProviderParams) (LlmProvider, error) { + row := q.db.QueryRow(ctx, createLlmProvider, + arg.Name, + arg.ClientType, + arg.BaseUrl, + arg.ApiKey, + arg.Metadata, + ) + var i LlmProvider + err := row.Scan( + &i.ID, + &i.Name, + &i.ClientType, + &i.BaseUrl, + &i.ApiKey, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + const createModel = `-- name: CreateModel :one -INSERT INTO models (model_id, name, base_url, api_key, client_type, dimensions, type) +INSERT INTO models (model_id, name, llm_provider_id, dimensions, is_multimodal, type) VALUES ( $1, $2, $3, $4, $5, - $6, - $7 + $6 ) -RETURNING id, model_id, name, base_url, api_key, client_type, dimensions, type, created_at, updated_at +RETURNING id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at ` type 
CreateModelParams struct { - ModelID string `json:"model_id"` - Name pgtype.Text `json:"name"` - BaseUrl string `json:"base_url"` - ApiKey string `json:"api_key"` - ClientType string `json:"client_type"` - Dimensions pgtype.Int4 `json:"dimensions"` - Type string `json:"type"` + ModelID string `json:"model_id"` + Name pgtype.Text `json:"name"` + LlmProviderID pgtype.UUID `json:"llm_provider_id"` + Dimensions pgtype.Int4 `json:"dimensions"` + IsMultimodal bool `json:"is_multimodal"` + Type string `json:"type"` } func (q *Queries) CreateModel(ctx context.Context, arg CreateModelParams) (Model, error) { row := q.db.QueryRow(ctx, createModel, arg.ModelID, arg.Name, - arg.BaseUrl, - arg.ApiKey, - arg.ClientType, + arg.LlmProviderID, arg.Dimensions, + arg.IsMultimodal, arg.Type, ) var i Model @@ -72,10 +133,9 @@ func (q *Queries) CreateModel(ctx context.Context, arg CreateModelParams) (Model &i.ID, &i.ModelID, &i.Name, - &i.BaseUrl, - &i.ApiKey, - &i.ClientType, + &i.LlmProviderID, &i.Dimensions, + &i.IsMultimodal, &i.Type, &i.CreatedAt, &i.UpdatedAt, @@ -83,6 +143,53 @@ func (q *Queries) CreateModel(ctx context.Context, arg CreateModelParams) (Model return i, err } +const createModelVariant = `-- name: CreateModelVariant :one +INSERT INTO model_variants (model_uuid, variant_id, weight, metadata) +VALUES ( + $1, + $2, + $3, + $4 +) +RETURNING id, model_uuid, variant_id, weight, metadata, created_at, updated_at +` + +type CreateModelVariantParams struct { + ModelUuid pgtype.UUID `json:"model_uuid"` + VariantID string `json:"variant_id"` + Weight int32 `json:"weight"` + Metadata []byte `json:"metadata"` +} + +func (q *Queries) CreateModelVariant(ctx context.Context, arg CreateModelVariantParams) (ModelVariant, error) { + row := q.db.QueryRow(ctx, createModelVariant, + arg.ModelUuid, + arg.VariantID, + arg.Weight, + arg.Metadata, + ) + var i ModelVariant + err := row.Scan( + &i.ID, + &i.ModelUuid, + &i.VariantID, + &i.Weight, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + 
) + return i, err +} + +const deleteLlmProvider = `-- name: DeleteLlmProvider :exec +DELETE FROM llm_providers WHERE id = $1 +` + +func (q *Queries) DeleteLlmProvider(ctx context.Context, id pgtype.UUID) error { + _, err := q.db.Exec(ctx, deleteLlmProvider, id) + return err +} + const deleteModel = `-- name: DeleteModel :exec DELETE FROM models WHERE id = $1 ` @@ -101,8 +208,57 @@ func (q *Queries) DeleteModelByModelID(ctx context.Context, modelID string) erro return err } +const deleteModelVariant = `-- name: DeleteModelVariant :exec +DELETE FROM model_variants WHERE id = $1 +` + +func (q *Queries) DeleteModelVariant(ctx context.Context, id pgtype.UUID) error { + _, err := q.db.Exec(ctx, deleteModelVariant, id) + return err +} + +const getLlmProviderByID = `-- name: GetLlmProviderByID :one +SELECT id, name, client_type, base_url, api_key, metadata, created_at, updated_at FROM llm_providers WHERE id = $1 +` + +func (q *Queries) GetLlmProviderByID(ctx context.Context, id pgtype.UUID) (LlmProvider, error) { + row := q.db.QueryRow(ctx, getLlmProviderByID, id) + var i LlmProvider + err := row.Scan( + &i.ID, + &i.Name, + &i.ClientType, + &i.BaseUrl, + &i.ApiKey, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const getLlmProviderByName = `-- name: GetLlmProviderByName :one +SELECT id, name, client_type, base_url, api_key, metadata, created_at, updated_at FROM llm_providers WHERE name = $1 +` + +func (q *Queries) GetLlmProviderByName(ctx context.Context, name string) (LlmProvider, error) { + row := q.db.QueryRow(ctx, getLlmProviderByName, name) + var i LlmProvider + err := row.Scan( + &i.ID, + &i.Name, + &i.ClientType, + &i.BaseUrl, + &i.ApiKey, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + const getModelByID = `-- name: GetModelByID :one -SELECT id, model_id, name, base_url, api_key, client_type, dimensions, type, created_at, updated_at FROM models WHERE id = $1 +SELECT id, model_id, name, llm_provider_id, dimensions, 
is_multimodal, type, created_at, updated_at FROM models WHERE id = $1 ` func (q *Queries) GetModelByID(ctx context.Context, id pgtype.UUID) (Model, error) { @@ -112,10 +268,9 @@ func (q *Queries) GetModelByID(ctx context.Context, id pgtype.UUID) (Model, erro &i.ID, &i.ModelID, &i.Name, - &i.BaseUrl, - &i.ApiKey, - &i.ClientType, + &i.LlmProviderID, &i.Dimensions, + &i.IsMultimodal, &i.Type, &i.CreatedAt, &i.UpdatedAt, @@ -124,7 +279,7 @@ func (q *Queries) GetModelByID(ctx context.Context, id pgtype.UUID) (Model, erro } const getModelByModelID = `-- name: GetModelByModelID :one -SELECT id, model_id, name, base_url, api_key, client_type, dimensions, type, created_at, updated_at FROM models WHERE model_id = $1 +SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at FROM models WHERE model_id = $1 ` func (q *Queries) GetModelByModelID(ctx context.Context, modelID string) (Model, error) { @@ -134,10 +289,9 @@ func (q *Queries) GetModelByModelID(ctx context.Context, modelID string) (Model, &i.ID, &i.ModelID, &i.Name, - &i.BaseUrl, - &i.ApiKey, - &i.ClientType, + &i.LlmProviderID, &i.Dimensions, + &i.IsMultimodal, &i.Type, &i.CreatedAt, &i.UpdatedAt, @@ -145,8 +299,164 @@ func (q *Queries) GetModelByModelID(ctx context.Context, modelID string) (Model, return i, err } +const getModelVariantByID = `-- name: GetModelVariantByID :one +SELECT id, model_uuid, variant_id, weight, metadata, created_at, updated_at FROM model_variants WHERE id = $1 +` + +func (q *Queries) GetModelVariantByID(ctx context.Context, id pgtype.UUID) (ModelVariant, error) { + row := q.db.QueryRow(ctx, getModelVariantByID, id) + var i ModelVariant + err := row.Scan( + &i.ID, + &i.ModelUuid, + &i.VariantID, + &i.Weight, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const listLlmProviders = `-- name: ListLlmProviders :many +SELECT id, name, client_type, base_url, api_key, metadata, created_at, updated_at FROM llm_providers +ORDER BY 
created_at DESC +` + +func (q *Queries) ListLlmProviders(ctx context.Context) ([]LlmProvider, error) { + rows, err := q.db.Query(ctx, listLlmProviders) + if err != nil { + return nil, err + } + defer rows.Close() + var items []LlmProvider + for rows.Next() { + var i LlmProvider + if err := rows.Scan( + &i.ID, + &i.Name, + &i.ClientType, + &i.BaseUrl, + &i.ApiKey, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listLlmProvidersByClientType = `-- name: ListLlmProvidersByClientType :many +SELECT id, name, client_type, base_url, api_key, metadata, created_at, updated_at FROM llm_providers +WHERE client_type = $1 +ORDER BY created_at DESC +` + +func (q *Queries) ListLlmProvidersByClientType(ctx context.Context, clientType string) ([]LlmProvider, error) { + rows, err := q.db.Query(ctx, listLlmProvidersByClientType, clientType) + if err != nil { + return nil, err + } + defer rows.Close() + var items []LlmProvider + for rows.Next() { + var i LlmProvider + if err := rows.Scan( + &i.ID, + &i.Name, + &i.ClientType, + &i.BaseUrl, + &i.ApiKey, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listModelVariantsByModelUUID = `-- name: ListModelVariantsByModelUUID :many +SELECT id, model_uuid, variant_id, weight, metadata, created_at, updated_at FROM model_variants +WHERE model_uuid = $1 +ORDER BY weight DESC, created_at DESC +` + +func (q *Queries) ListModelVariantsByModelUUID(ctx context.Context, modelUuid pgtype.UUID) ([]ModelVariant, error) { + rows, err := q.db.Query(ctx, listModelVariantsByModelUUID, modelUuid) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ModelVariant + for rows.Next() { + var i ModelVariant + if err := rows.Scan( 
+ &i.ID, + &i.ModelUuid, + &i.VariantID, + &i.Weight, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listModelVariantsByVariantID = `-- name: ListModelVariantsByVariantID :many +SELECT id, model_uuid, variant_id, weight, metadata, created_at, updated_at FROM model_variants +WHERE variant_id = $1 +ORDER BY created_at DESC +` + +func (q *Queries) ListModelVariantsByVariantID(ctx context.Context, variantID string) ([]ModelVariant, error) { + rows, err := q.db.Query(ctx, listModelVariantsByVariantID, variantID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ModelVariant + for rows.Next() { + var i ModelVariant + if err := rows.Scan( + &i.ID, + &i.ModelUuid, + &i.VariantID, + &i.Weight, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const listModels = `-- name: ListModels :many -SELECT id, model_id, name, base_url, api_key, client_type, dimensions, type, created_at, updated_at FROM models +SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at FROM models ORDER BY created_at DESC ` @@ -163,10 +473,9 @@ func (q *Queries) ListModels(ctx context.Context) ([]Model, error) { &i.ID, &i.ModelID, &i.Name, - &i.BaseUrl, - &i.ApiKey, - &i.ClientType, + &i.LlmProviderID, &i.Dimensions, + &i.IsMultimodal, &i.Type, &i.CreatedAt, &i.UpdatedAt, @@ -182,9 +491,10 @@ func (q *Queries) ListModels(ctx context.Context) ([]Model, error) { } const listModelsByClientType = `-- name: ListModelsByClientType :many -SELECT id, model_id, name, base_url, api_key, client_type, dimensions, type, created_at, updated_at FROM models -WHERE client_type = $1 -ORDER BY created_at DESC +SELECT m.id, m.model_id, m.name, 
m.llm_provider_id, m.dimensions, m.is_multimodal, m.type, m.created_at, m.updated_at FROM models AS m +JOIN llm_providers AS p ON p.id = m.llm_provider_id +WHERE p.client_type = $1 +ORDER BY m.created_at DESC ` func (q *Queries) ListModelsByClientType(ctx context.Context, clientType string) ([]Model, error) { @@ -200,10 +510,9 @@ func (q *Queries) ListModelsByClientType(ctx context.Context, clientType string) &i.ID, &i.ModelID, &i.Name, - &i.BaseUrl, - &i.ApiKey, - &i.ClientType, + &i.LlmProviderID, &i.Dimensions, + &i.IsMultimodal, &i.Type, &i.CreatedAt, &i.UpdatedAt, @@ -219,7 +528,7 @@ func (q *Queries) ListModelsByClientType(ctx context.Context, clientType string) } const listModelsByType = `-- name: ListModelsByType :many -SELECT id, model_id, name, base_url, api_key, client_type, dimensions, type, created_at, updated_at FROM models +SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at FROM models WHERE type = $1 ORDER BY created_at DESC ` @@ -237,10 +546,9 @@ func (q *Queries) ListModelsByType(ctx context.Context, type_ string) ([]Model, &i.ID, &i.ModelID, &i.Name, - &i.BaseUrl, - &i.ApiKey, - &i.ClientType, + &i.LlmProviderID, &i.Dimensions, + &i.IsMultimodal, &i.Type, &i.CreatedAt, &i.UpdatedAt, @@ -255,37 +563,79 @@ func (q *Queries) ListModelsByType(ctx context.Context, type_ string) ([]Model, return items, nil } +const updateLlmProvider = `-- name: UpdateLlmProvider :one +UPDATE llm_providers +SET + name = $1, + client_type = $2, + base_url = $3, + api_key = $4, + metadata = $5, + updated_at = now() +WHERE id = $6 +RETURNING id, name, client_type, base_url, api_key, metadata, created_at, updated_at +` + +type UpdateLlmProviderParams struct { + Name string `json:"name"` + ClientType string `json:"client_type"` + BaseUrl string `json:"base_url"` + ApiKey string `json:"api_key"` + Metadata []byte `json:"metadata"` + ID pgtype.UUID `json:"id"` +} + +func (q *Queries) UpdateLlmProvider(ctx context.Context, arg 
UpdateLlmProviderParams) (LlmProvider, error) { + row := q.db.QueryRow(ctx, updateLlmProvider, + arg.Name, + arg.ClientType, + arg.BaseUrl, + arg.ApiKey, + arg.Metadata, + arg.ID, + ) + var i LlmProvider + err := row.Scan( + &i.ID, + &i.Name, + &i.ClientType, + &i.BaseUrl, + &i.ApiKey, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + const updateModel = `-- name: UpdateModel :one UPDATE models SET name = $1, - base_url = $2, - api_key = $3, - client_type = $4, - dimensions = $5, - type = $6, + llm_provider_id = $2, + dimensions = $3, + is_multimodal = $4, + type = $5, updated_at = now() -WHERE id = $7 -RETURNING id, model_id, name, base_url, api_key, client_type, dimensions, type, created_at, updated_at +WHERE id = $6 +RETURNING id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at ` type UpdateModelParams struct { - Name pgtype.Text `json:"name"` - BaseUrl string `json:"base_url"` - ApiKey string `json:"api_key"` - ClientType string `json:"client_type"` - Dimensions pgtype.Int4 `json:"dimensions"` - Type string `json:"type"` - ID pgtype.UUID `json:"id"` + Name pgtype.Text `json:"name"` + LlmProviderID pgtype.UUID `json:"llm_provider_id"` + Dimensions pgtype.Int4 `json:"dimensions"` + IsMultimodal bool `json:"is_multimodal"` + Type string `json:"type"` + ID pgtype.UUID `json:"id"` } func (q *Queries) UpdateModel(ctx context.Context, arg UpdateModelParams) (Model, error) { row := q.db.QueryRow(ctx, updateModel, arg.Name, - arg.BaseUrl, - arg.ApiKey, - arg.ClientType, + arg.LlmProviderID, arg.Dimensions, + arg.IsMultimodal, arg.Type, arg.ID, ) @@ -294,10 +644,9 @@ func (q *Queries) UpdateModel(ctx context.Context, arg UpdateModelParams) (Model &i.ID, &i.ModelID, &i.Name, - &i.BaseUrl, - &i.ApiKey, - &i.ClientType, + &i.LlmProviderID, &i.Dimensions, + &i.IsMultimodal, &i.Type, &i.CreatedAt, &i.UpdatedAt, @@ -309,33 +658,30 @@ const updateModelByModelID = `-- name: UpdateModelByModelID :one UPDATE models SET 
name = $1, - base_url = $2, - api_key = $3, - client_type = $4, - dimensions = $5, - type = $6, + llm_provider_id = $2, + dimensions = $3, + is_multimodal = $4, + type = $5, updated_at = now() -WHERE model_id = $7 -RETURNING id, model_id, name, base_url, api_key, client_type, dimensions, type, created_at, updated_at +WHERE model_id = $6 +RETURNING id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at ` type UpdateModelByModelIDParams struct { - Name pgtype.Text `json:"name"` - BaseUrl string `json:"base_url"` - ApiKey string `json:"api_key"` - ClientType string `json:"client_type"` - Dimensions pgtype.Int4 `json:"dimensions"` - Type string `json:"type"` - ModelID string `json:"model_id"` + Name pgtype.Text `json:"name"` + LlmProviderID pgtype.UUID `json:"llm_provider_id"` + Dimensions pgtype.Int4 `json:"dimensions"` + IsMultimodal bool `json:"is_multimodal"` + Type string `json:"type"` + ModelID string `json:"model_id"` } func (q *Queries) UpdateModelByModelID(ctx context.Context, arg UpdateModelByModelIDParams) (Model, error) { row := q.db.QueryRow(ctx, updateModelByModelID, arg.Name, - arg.BaseUrl, - arg.ApiKey, - arg.ClientType, + arg.LlmProviderID, arg.Dimensions, + arg.IsMultimodal, arg.Type, arg.ModelID, ) @@ -344,13 +690,50 @@ func (q *Queries) UpdateModelByModelID(ctx context.Context, arg UpdateModelByMod &i.ID, &i.ModelID, &i.Name, - &i.BaseUrl, - &i.ApiKey, - &i.ClientType, + &i.LlmProviderID, &i.Dimensions, + &i.IsMultimodal, &i.Type, &i.CreatedAt, &i.UpdatedAt, ) return i, err } + +const updateModelVariant = `-- name: UpdateModelVariant :one +UPDATE model_variants +SET + variant_id = $1, + weight = $2, + metadata = $3, + updated_at = now() +WHERE id = $4 +RETURNING id, model_uuid, variant_id, weight, metadata, created_at, updated_at +` + +type UpdateModelVariantParams struct { + VariantID string `json:"variant_id"` + Weight int32 `json:"weight"` + Metadata []byte `json:"metadata"` + ID pgtype.UUID `json:"id"` +} + 
+func (q *Queries) UpdateModelVariant(ctx context.Context, arg UpdateModelVariantParams) (ModelVariant, error) { + row := q.db.QueryRow(ctx, updateModelVariant, + arg.VariantID, + arg.Weight, + arg.Metadata, + arg.ID, + ) + var i ModelVariant + err := row.Scan( + &i.ID, + &i.ModelUuid, + &i.VariantID, + &i.Weight, + &i.Metadata, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} diff --git a/internal/embeddings/dashscope.go b/internal/embeddings/dashscope.go new file mode 100644 index 00000000..939e8822 --- /dev/null +++ b/internal/embeddings/dashscope.go @@ -0,0 +1,142 @@ +package embeddings + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +const ( + DefaultDashScopeBaseURL = "https://dashscope.aliyuncs.com" + DashScopeEmbeddingPath = "/api/v1/services/embeddings/multimodal-embedding/multimodal-embedding" +) + +type DashScopeEmbedder struct { + apiKey string + baseURL string + model string + http *http.Client +} + +type DashScopeUsage struct { + InputTokens int `json:"input_tokens"` + ImageTokens int `json:"image_tokens"` + ImageCount int `json:"image_count,omitempty"` + Duration int `json:"duration,omitempty"` +} + +type dashScopeRequest struct { + Model string `json:"model"` + Input dashScopeRequestInput `json:"input"` +} + +type dashScopeRequestInput struct { + Contents []map[string]string `json:"contents"` +} + +type dashScopeResponse struct { + Output struct { + Embeddings []struct { + Index int `json:"index"` + Embedding []float32 `json:"embedding"` + Type string `json:"type"` + } `json:"embeddings"` + } `json:"output"` + Usage DashScopeUsage `json:"usage"` + RequestID string `json:"request_id"` + Code string `json:"code"` + Message string `json:"message"` +} + +func NewDashScopeEmbedder(apiKey, baseURL, model string, timeout time.Duration) *DashScopeEmbedder { + if baseURL == "" { + baseURL = DefaultDashScopeBaseURL + } + if timeout <= 0 { + timeout = 10 * time.Second + } + return 
&DashScopeEmbedder{ + apiKey: apiKey, + baseURL: strings.TrimRight(baseURL, "/"), + model: model, + http: &http.Client{ + Timeout: timeout, + }, + } +} + +func (e *DashScopeEmbedder) Embed(ctx context.Context, text string, imageURL string, videoURL string) ([]float32, DashScopeUsage, error) { + contents := make([]map[string]string, 0, 3) + if strings.TrimSpace(text) != "" { + contents = append(contents, map[string]string{"text": text}) + } + if strings.TrimSpace(imageURL) != "" { + contents = append(contents, map[string]string{"image": imageURL}) + } + if strings.TrimSpace(videoURL) != "" { + contents = append(contents, map[string]string{"video": videoURL}) + } + if len(contents) == 0 { + return nil, DashScopeUsage{}, fmt.Errorf("dashscope input is required") + } + + payload, err := json.Marshal(dashScopeRequest{ + Model: e.model, + Input: dashScopeRequestInput{Contents: contents}, + }) + if err != nil { + return nil, DashScopeUsage{}, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, e.baseURL+DashScopeEmbeddingPath, bytes.NewReader(payload)) + if err != nil { + return nil, DashScopeUsage{}, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+e.apiKey) + + resp, err := e.http.Do(req) + if err != nil { + return nil, DashScopeUsage{}, err + } + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, DashScopeUsage{}, fmt.Errorf("dashscope embeddings error: %s", strings.TrimSpace(string(body))) + } + + var parsed dashScopeResponse + if err := json.Unmarshal(body, &parsed); err != nil { + return nil, DashScopeUsage{}, err + } + if parsed.Code != "" { + return nil, parsed.Usage, fmt.Errorf("dashscope embeddings error: %s", parsed.Message) + } + if len(parsed.Output.Embeddings) == 0 { + return nil, parsed.Usage, fmt.Errorf("dashscope embeddings empty response") + } + + preferredType := "" + if strings.TrimSpace(text) != "" { 
+ preferredType = "text" + } else if strings.TrimSpace(imageURL) != "" { + preferredType = "image" + } else if strings.TrimSpace(videoURL) != "" { + preferredType = "video" + } + + if preferredType != "" { + for _, item := range parsed.Output.Embeddings { + if strings.EqualFold(item.Type, preferredType) && len(item.Embedding) > 0 { + return item.Embedding, parsed.Usage, nil + } + } + } + + return parsed.Output.Embeddings[0].Embedding, parsed.Usage, nil +} diff --git a/internal/memory/embeddings.go b/internal/embeddings/embeddings.go similarity index 99% rename from internal/memory/embeddings.go rename to internal/embeddings/embeddings.go index 92b44eef..808c9dfb 100644 --- a/internal/memory/embeddings.go +++ b/internal/embeddings/embeddings.go @@ -1,4 +1,4 @@ -package memory +package embeddings import ( "bytes" diff --git a/internal/embeddings/resolver.go b/internal/embeddings/resolver.go new file mode 100644 index 00000000..24af20cf --- /dev/null +++ b/internal/embeddings/resolver.go @@ -0,0 +1,228 @@ +package embeddings + +import ( + "context" + "errors" + "strings" + "time" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgtype" + + "github.com/memohai/memoh/internal/db/sqlc" + "github.com/memohai/memoh/internal/models" +) + +const ( + TypeText = "text" + TypeMultimodal = "multimodal" + + ProviderOpenAI = "openai" + ProviderBedrock = "bedrock" + ProviderDashScope = "dashscope" +) + +type Request struct { + Type string + Provider string + Model string + Dimensions int + Input Input +} + +type Input struct { + Text string + ImageURL string + VideoURL string +} + +type Usage struct { + InputTokens int + ImageTokens int + VideoTokens int +} + +type Result struct { + Type string + Provider string + Model string + Dimensions int + Embedding []float32 + Usage Usage +} + +type Resolver struct { + modelsService *models.Service + queries *sqlc.Queries + timeout time.Duration +} + +func NewResolver(modelsService *models.Service, queries *sqlc.Queries, timeout 
time.Duration) *Resolver { + return &Resolver{ + modelsService: modelsService, + queries: queries, + timeout: timeout, + } +} + +func (r *Resolver) Embed(ctx context.Context, req Request) (Result, error) { + req.Type = strings.ToLower(strings.TrimSpace(req.Type)) + req.Provider = strings.ToLower(strings.TrimSpace(req.Provider)) + req.Model = strings.TrimSpace(req.Model) + req.Input.Text = strings.TrimSpace(req.Input.Text) + req.Input.ImageURL = strings.TrimSpace(req.Input.ImageURL) + req.Input.VideoURL = strings.TrimSpace(req.Input.VideoURL) + + if req.Type == "" { + return Result{}, errors.New("type is required") + } + switch req.Type { + case TypeText: + if req.Provider != "" && req.Provider != ProviderOpenAI { + return Result{}, errors.New("invalid provider for text embeddings") + } + if req.Input.Text == "" { + return Result{}, errors.New("text input is required") + } + case TypeMultimodal: + if req.Provider != "" && req.Provider != ProviderBedrock && req.Provider != ProviderDashScope { + return Result{}, errors.New("invalid provider for multimodal embeddings") + } + if req.Input.Text == "" && req.Input.ImageURL == "" && req.Input.VideoURL == "" { + return Result{}, errors.New("multimodal input is required") + } + default: + return Result{}, errors.New("invalid embeddings type") + } + + selected, err := r.selectEmbeddingModel(ctx, req) + if err != nil { + return Result{}, err + } + provider, err := r.fetchProvider(ctx, selected.LlmProviderID) + if err != nil { + return Result{}, err + } + + req.Model = selected.ModelID + req.Dimensions = selected.Dimensions + req.Provider = strings.ToLower(strings.TrimSpace(provider.ClientType)) + if req.Model == "" { + return Result{}, errors.New("embedding model id not configured") + } + if req.Dimensions <= 0 { + return Result{}, errors.New("embedding model dimensions not configured") + } + + timeout := r.timeout + if timeout <= 0 { + timeout = 10 * time.Second + } + + switch req.Type { + case TypeText: + if req.Provider != 
ProviderOpenAI { + return Result{}, errors.New("provider not implemented") + } + if strings.TrimSpace(provider.ApiKey) == "" { + return Result{}, errors.New("openai api key is required") + } + embedder := NewOpenAIEmbedder(provider.ApiKey, provider.BaseUrl, req.Model, req.Dimensions, timeout) + vector, err := embedder.Embed(ctx, req.Input.Text) + if err != nil { + return Result{}, err + } + return Result{ + Type: req.Type, + Provider: req.Provider, + Model: req.Model, + Dimensions: req.Dimensions, + Embedding: vector, + }, nil + case TypeMultimodal: + if req.Provider == ProviderDashScope { + if strings.TrimSpace(provider.ApiKey) == "" { + return Result{}, errors.New("dashscope api key is required") + } + dashscope := NewDashScopeEmbedder(provider.ApiKey, provider.BaseUrl, req.Model, timeout) + vector, usage, err := dashscope.Embed(ctx, req.Input.Text, req.Input.ImageURL, req.Input.VideoURL) + if err != nil { + return Result{}, err + } + return Result{ + Type: req.Type, + Provider: req.Provider, + Model: req.Model, + Dimensions: req.Dimensions, + Embedding: vector, + Usage: Usage{ + InputTokens: usage.InputTokens, + ImageTokens: usage.ImageTokens, + VideoTokens: usage.Duration, // NOTE(review): DashScope reports video usage as a duration, not a token count — confirm this mapping against the DashScope billing/usage docs + }, + }, nil + } + return Result{}, errors.New("provider not implemented") + default: + return Result{}, errors.New("invalid embeddings type") + } +} + +func (r *Resolver) selectEmbeddingModel(ctx context.Context, req Request) (models.GetResponse, error) { + if r.modelsService == nil { + return models.GetResponse{}, errors.New("models service not configured") + } + + var candidates []models.GetResponse + var err error + if req.Provider != "" { + candidates, err = r.modelsService.ListByClientType(ctx, models.ClientType(req.Provider)) + } else { + candidates, err = r.modelsService.ListByType(ctx, models.ModelTypeEmbedding) + } + if err != nil { + return models.GetResponse{}, err + } + + filtered := make([]models.GetResponse, 0, len(candidates)) + for _, model := range candidates { + if
model.Type != models.ModelTypeEmbedding { + continue + } + if req.Type == TypeMultimodal && !model.IsMultimodal { + continue + } + if req.Type == TypeText && model.IsMultimodal { + continue + } + filtered = append(filtered, model) + } + if len(filtered) == 0 { + return models.GetResponse{}, errors.New("no embedding models available") + } + if req.Model != "" { + for _, model := range filtered { + if model.ModelID == req.Model { + return model, nil + } + } + return models.GetResponse{}, errors.New("embedding model not found") + } + return filtered[0], nil +} + +func (r *Resolver) fetchProvider(ctx context.Context, providerID string) (sqlc.LlmProvider, error) { + if r.queries == nil { + return sqlc.LlmProvider{}, errors.New("llm provider queries not configured") + } + if strings.TrimSpace(providerID) == "" { + return sqlc.LlmProvider{}, errors.New("llm provider id missing") + } + parsed, err := uuid.Parse(providerID) + if err != nil { + return sqlc.LlmProvider{}, err + } + pgID := pgtype.UUID{Valid: true} + copy(pgID.Bytes[:], parsed[:]) + return r.queries.GetLlmProviderByID(ctx, pgID) +} diff --git a/internal/handlers/embeddings.go b/internal/handlers/embeddings.go new file mode 100644 index 00000000..6ab5f2fa --- /dev/null +++ b/internal/handlers/embeddings.go @@ -0,0 +1,136 @@ +package handlers + +import ( + "net/http" + "strings" + "time" + + "github.com/labstack/echo/v4" + + "github.com/memohai/memoh/internal/db/sqlc" + "github.com/memohai/memoh/internal/embeddings" + "github.com/memohai/memoh/internal/models" +) + +const DefaultEmbeddingTimeout = 10 * time.Second + +type EmbeddingsHandler struct { + resolver *embeddings.Resolver +} + +type EmbeddingsRequest struct { + Type string `json:"type"` + Provider string `json:"provider,omitempty"` + Model string `json:"model,omitempty"` + Dimensions int `json:"dimensions,omitempty"` + Input EmbeddingsInput `json:"input"` +} + +type EmbeddingsInput struct { + Text string `json:"text,omitempty"` + ImageURL string 
`json:"image_url,omitempty"` + VideoURL string `json:"video_url,omitempty"` +} + +type EmbeddingsResponse struct { + Type string `json:"type"` + Provider string `json:"provider"` + Model string `json:"model"` + Dimensions int `json:"dimensions"` + Embedding []float32 `json:"embedding"` + Usage EmbeddingsUsage `json:"usage,omitempty"` + Message string `json:"message,omitempty"` +} + +type EmbeddingsUsage struct { + InputTokens int `json:"input_tokens,omitempty"` + ImageTokens int `json:"image_tokens,omitempty"` + VideoTokens int `json:"video_tokens,omitempty"` +} + +func NewEmbeddingsHandler(modelsService *models.Service, queries *sqlc.Queries) *EmbeddingsHandler { + return &EmbeddingsHandler{ + resolver: embeddings.NewResolver(modelsService, queries, DefaultEmbeddingTimeout), + } +} + +func (h *EmbeddingsHandler) Register(e *echo.Echo) { + e.POST("/embeddings", h.Embed) +} + +// Embed godoc +// @Summary Create embeddings +// @Description Create text or multimodal embeddings +// @Tags embeddings +// @Param payload body EmbeddingsRequest true "Embeddings request" +// @Success 200 {object} EmbeddingsResponse +// @Failure 400 {object} ErrorResponse +// @Failure 501 {object} EmbeddingsResponse +// @Failure 500 {object} ErrorResponse +// @Router /embeddings [post] +func (h *EmbeddingsHandler) Embed(c echo.Context) error { + var req EmbeddingsRequest + if err := c.Bind(&req); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + + req.Type = normalizeEmbeddingValue(req.Type) + req.Provider = normalizeEmbeddingValue(req.Provider) + req.Model = strings.TrimSpace(req.Model) + req.Input.Text = strings.TrimSpace(req.Input.Text) + req.Input.ImageURL = strings.TrimSpace(req.Input.ImageURL) + req.Input.VideoURL = strings.TrimSpace(req.Input.VideoURL) + + result, err := h.resolver.Embed(c.Request().Context(), embeddings.Request{ + Type: req.Type, + Provider: req.Provider, + Model: req.Model, + Dimensions: req.Dimensions, + Input: embeddings.Input{ + 
Text: req.Input.Text, + ImageURL: req.Input.ImageURL, + VideoURL: req.Input.VideoURL, + }, + }) + if err != nil { + message := err.Error() + switch message { + case "no embedding models available": + return echo.NewHTTPError(http.StatusNotFound, message) + case "embedding model not found": + return echo.NewHTTPError(http.StatusBadRequest, message) + case "provider not implemented": + resp := EmbeddingsResponse{ + Type: req.Type, + Provider: req.Provider, + Model: req.Model, + Dimensions: req.Dimensions, + Embedding: []float32{}, + Message: "embeddings provider not implemented", + } + return c.JSON(http.StatusNotImplemented, resp) + default: + if strings.Contains(message, "required") || strings.Contains(message, "invalid") { + return echo.NewHTTPError(http.StatusBadRequest, message) + } + return echo.NewHTTPError(http.StatusInternalServerError, message) + } + } + + return c.JSON(http.StatusOK, EmbeddingsResponse{ + Type: result.Type, + Provider: result.Provider, + Model: result.Model, + Dimensions: result.Dimensions, + Embedding: result.Embedding, + Usage: EmbeddingsUsage{ + InputTokens: result.Usage.InputTokens, + ImageTokens: result.Usage.ImageTokens, + VideoTokens: result.Usage.VideoTokens, + }, + }) +} + +func normalizeEmbeddingValue(value string) string { + return strings.ToLower(strings.TrimSpace(value)) +} diff --git a/internal/handlers/memory.go b/internal/handlers/memory.go index 8671ff53..8218ca38 100644 --- a/internal/handlers/memory.go +++ b/internal/handlers/memory.go @@ -22,6 +22,7 @@ func NewMemoryHandler(service *memory.Service) *MemoryHandler { func (h *MemoryHandler) Register(e *echo.Echo) { group := e.Group("/memory") group.POST("/add", h.Add) + group.POST("/embed", h.EmbedUpsert) group.POST("/search", h.Search) group.POST("/update", h.Update) group.GET("/memories/:memoryId", h.Get) @@ -30,6 +31,37 @@ func (h *MemoryHandler) Register(e *echo.Echo) { group.DELETE("/memories", h.DeleteAll) } +// EmbedUpsert godoc +// @Summary Embed and upsert memory 
+// @Description Embed text or multimodal input and upsert into memory store +// @Tags memory +// @Param payload body memory.EmbedUpsertRequest true "Embed upsert request" +// @Success 200 {object} memory.EmbedUpsertResponse +// @Failure 400 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /memory/embed [post] +func (h *MemoryHandler) EmbedUpsert(c echo.Context) error { + userID, err := h.requireUserID(c) + if err != nil { + return err + } + + var req memory.EmbedUpsertRequest + if err := c.Bind(&req); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + if req.UserID != "" && req.UserID != userID { + return echo.NewHTTPError(http.StatusForbidden, "user mismatch") + } + req.UserID = userID + + resp, err := h.service.EmbedUpsert(c.Request().Context(), req) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.JSON(http.StatusOK, resp) +} + // Add godoc // @Summary Add memory // @Description Add memory for a user via memory diff --git a/internal/memory/qdrant_store.go b/internal/memory/qdrant_store.go index 9f2e4104..b425ad99 100644 --- a/internal/memory/qdrant_store.go +++ b/internal/memory/qdrant_store.go @@ -15,12 +15,18 @@ type QdrantStore struct { client *qdrant.Client collection string dimension int + baseURL string + apiKey string + timeout time.Duration + vectorNames map[string]int + usesNamedVectors bool } type qdrantPoint struct { - ID string `json:"id"` - Vector []float32 `json:"vector"` - Payload map[string]interface{} `json:"payload,omitempty"` + ID string `json:"id"` + Vector []float32 `json:"vector"` + VectorName string `json:"vector_name,omitempty"` + Payload map[string]interface{} `json:"payload,omitempty"` } func NewQdrantStore(baseURL, apiKey, collection string, dimension int, timeout time.Duration) (*QdrantStore, error) { @@ -50,11 +56,59 @@ func NewQdrantStore(baseURL, apiKey, collection string, dimension int, timeout t client: client, 
collection: collection, dimension: dimension, + baseURL: baseURL, + apiKey: apiKey, + timeout: timeoutOrDefault(timeout), } ctx, cancel := context.WithTimeout(context.Background(), timeoutOrDefault(timeout)) defer cancel() - if err := store.ensureCollection(ctx); err != nil { + if err := store.ensureCollection(ctx, nil); err != nil { + return nil, err + } + return store, nil +} + +func (s *QdrantStore) NewSibling(collection string, dimension int) (*QdrantStore, error) { + return NewQdrantStore(s.baseURL, s.apiKey, collection, dimension, s.timeout) +} + +func NewQdrantStoreWithVectors(baseURL, apiKey, collection string, vectors map[string]int, timeout time.Duration) (*QdrantStore, error) { + host, port, useTLS, err := parseQdrantEndpoint(baseURL) + if err != nil { + return nil, err + } + if collection == "" { + collection = "memory" + } + if len(vectors) == 0 { + return nil, fmt.Errorf("vectors map is required") + } + + cfg := &qdrant.Config{ + Host: host, + Port: port, + APIKey: apiKey, + UseTLS: useTLS, + } + client, err := qdrant.NewClient(cfg) + if err != nil { + return nil, err + } + + store := &QdrantStore{ + client: client, + collection: collection, + baseURL: baseURL, + apiKey: apiKey, + timeout: timeoutOrDefault(timeout), + vectorNames: vectors, + usesNamedVectors: true, + } + + ctx, cancel := context.WithTimeout(context.Background(), timeoutOrDefault(timeout)) + defer cancel() + if err := store.ensureCollection(ctx, vectors); err != nil { return nil, err } return store, nil @@ -70,9 +124,17 @@ func (s *QdrantStore) Upsert(ctx context.Context, points []qdrantPoint) error { if err != nil { return err } + var vectors *qdrant.Vectors + if point.VectorName != "" && s.usesNamedVectors { + vectors = qdrant.NewVectorsMap(map[string]*qdrant.Vector{ + point.VectorName: qdrant.NewVectorDense(point.Vector), + }) + } else { + vectors = qdrant.NewVectorsDense(point.Vector) + } qPoints = append(qPoints, &qdrant.PointStruct{ Id: qdrant.NewIDUUID(point.ID), - Vectors: 
qdrant.NewVectorsDense(point.Vector), + Vectors: vectors, Payload: payload, }) } @@ -84,14 +146,19 @@ func (s *QdrantStore) Upsert(ctx context.Context, points []qdrantPoint) error { return err } -func (s *QdrantStore) Search(ctx context.Context, vector []float32, limit int, filters map[string]interface{}) ([]qdrantPoint, []float64, error) { +func (s *QdrantStore) Search(ctx context.Context, vector []float32, limit int, filters map[string]interface{}, vectorName string) ([]qdrantPoint, []float64, error) { if limit <= 0 { limit = 10 } filter := buildQdrantFilter(filters) + var using *string + if vectorName != "" && s.usesNamedVectors { + using = qdrant.PtrOf(vectorName) + } results, err := s.client.Query(ctx, &qdrant.QueryPoints{ CollectionName: s.collection, Query: qdrant.NewQueryDense(vector), + Using: using, Limit: qdrant.PtrOf(uint64(limit)), Filter: filter, WithPayload: qdrant.NewWithPayload(true), @@ -112,6 +179,27 @@ func (s *QdrantStore) Search(ctx context.Context, vector []float32, limit int, f return points, scores, nil } +func (s *QdrantStore) SearchBySources(ctx context.Context, vector []float32, limit int, filters map[string]interface{}, sources []string, vectorName string) (map[string][]qdrantPoint, map[string][]float64, error) { + pointsBySource := make(map[string][]qdrantPoint, len(sources)) + scoresBySource := make(map[string][]float64, len(sources)) + if len(sources) == 0 { + return pointsBySource, scoresBySource, nil + } + for _, source := range sources { + merged := cloneFilters(filters) + if source != "" { + merged["source"] = source + } + points, scores, err := s.Search(ctx, vector, limit, merged, vectorName) + if err != nil { + return nil, nil, err + } + pointsBySource[source] = points + scoresBySource[source] = scores + } + return pointsBySource, scoresBySource, nil +} + func (s *QdrantStore) Get(ctx context.Context, id string) (*qdrantPoint, error) { result, err := s.client.Get(ctx, &qdrant.GetPoints{ CollectionName: s.collection, @@ -178,13 
+266,26 @@ func (s *QdrantStore) DeleteAll(ctx context.Context, filters map[string]interfac return err } -func (s *QdrantStore) ensureCollection(ctx context.Context) error { +func (s *QdrantStore) ensureCollection(ctx context.Context, vectors map[string]int) error { exists, err := s.client.CollectionExists(ctx, s.collection) if err != nil { return err } if exists { - return nil + return s.refreshCollectionSchema(ctx, vectors) + } + if len(vectors) > 0 { + params := make(map[string]*qdrant.VectorParams, len(vectors)) + for name, dim := range vectors { + params[name] = &qdrant.VectorParams{ + Size: uint64(dim), + Distance: qdrant.Distance_Cosine, + } + } + return s.client.CreateCollection(ctx, &qdrant.CreateCollection{ + CollectionName: s.collection, + VectorsConfig: qdrant.NewVectorsConfigMap(params), + }) } return s.client.CreateCollection(ctx, &qdrant.CreateCollection{ CollectionName: s.collection, @@ -195,6 +296,40 @@ func (s *QdrantStore) ensureCollection(ctx context.Context) error { }) } +func (s *QdrantStore) refreshCollectionSchema(ctx context.Context, vectors map[string]int) error { + info, err := s.client.GetCollectionInfo(ctx, s.collection) + if err != nil { + return err + } + config := info.GetConfig() + if config == nil || config.GetParams() == nil || config.GetParams().GetVectorsConfig() == nil { + return nil + } + vectorsConfig := config.GetParams().GetVectorsConfig() + if vectorsConfig.GetParamsMap() != nil { + s.usesNamedVectors = true + s.vectorNames = map[string]int{} + for name, vec := range vectorsConfig.GetParamsMap().GetMap() { + if vec != nil { + s.vectorNames[name] = int(vec.GetSize()) + } + } + if len(vectors) == 0 { + return nil + } + for name, dim := range vectors { + if existing, ok := s.vectorNames[name]; ok && existing == dim { + continue + } + return fmt.Errorf("collection missing vector %s (dim %d); migration required", name, dim) + } + return nil + } + s.usesNamedVectors = false + s.vectorNames = nil + return nil +} + func 
parseQdrantEndpoint(endpoint string) (string, int, bool, error) { if endpoint == "" { return "127.0.0.1", 6334, false, nil @@ -247,6 +382,17 @@ func buildQdrantFilter(filters map[string]interface{}) *qdrant.Filter { } } +func cloneFilters(filters map[string]interface{}) map[string]interface{} { + if len(filters) == 0 { + return map[string]interface{}{} + } + clone := make(map[string]interface{}, len(filters)) + for key, value := range filters { + clone[key] = value + } + return clone +} + func buildQdrantCondition(key string, value interface{}) *qdrant.Condition { switch typed := value.(type) { case string: diff --git a/internal/memory/service.go b/internal/memory/service.go index c94d2289..f8bb0e7b 100644 --- a/internal/memory/service.go +++ b/internal/memory/service.go @@ -5,23 +5,33 @@ import ( "crypto/md5" "encoding/hex" "fmt" + "math" + "sort" "strings" "time" "github.com/google/uuid" + + "github.com/memohai/memoh/internal/embeddings" ) type Service struct { - llm *LLMClient - embedder Embedder - store *QdrantStore + llm *LLMClient + embedder embeddings.Embedder + store *QdrantStore + resolver *embeddings.Resolver + defaultTextModelID string + defaultMultimodalModelID string } -func NewService(llm *LLMClient, embedder Embedder, store *QdrantStore) *Service { +func NewService(llm *LLMClient, embedder embeddings.Embedder, store *QdrantStore, resolver *embeddings.Resolver, defaultTextModelID, defaultMultimodalModelID string) *Service { return &Service{ - llm: llm, - embedder: embedder, - store: store, + llm: llm, + embedder: embedder, + store: store, + resolver: resolver, + defaultTextModelID: defaultTextModelID, + defaultMultimodalModelID: defaultMultimodalModelID, } } @@ -122,27 +132,128 @@ func (s *Service) Search(ctx context.Context, req SearchRequest) (SearchResponse return SearchResponse{}, fmt.Errorf("query is required") } filters := buildSearchFilters(req) - vector, err := s.embedder.Embed(ctx, req.Query) - if err != nil { - return SearchResponse{}, err + 
modality := "" + if raw, ok := filters["modality"].(string); ok { + modality = strings.ToLower(strings.TrimSpace(raw)) } - points, scores, err := s.store.Search(ctx, vector, req.Limit, filters) - if err != nil { - return SearchResponse{}, err - } - - results := make([]MemoryItem, 0, len(points)) - for idx, point := range points { - item := payloadToMemoryItem(point.ID, point.Payload) - if idx < len(scores) { - item.Score = scores[idx] + var ( + vector []float32 + store *QdrantStore + vectorName string + err error + ) + if modality == embeddings.TypeMultimodal { + if s.resolver == nil { + return SearchResponse{}, fmt.Errorf("embeddings resolver not configured") } - results = append(results, item) + result, err := s.resolver.Embed(ctx, embeddings.Request{ + Type: embeddings.TypeMultimodal, + Input: embeddings.Input{ + Text: req.Query, + }, + }) + if err != nil { + return SearchResponse{}, err + } + vector = result.Embedding + store = s.store + vectorName = s.vectorNameForMultimodal() + } else { + vector, err = s.embedder.Embed(ctx, req.Query) + if err != nil { + return SearchResponse{}, err + } + store = s.store + vectorName = s.vectorNameForText() } + + if len(req.Sources) == 0 { + points, scores, err := store.Search(ctx, vector, req.Limit, filters, vectorName) + if err != nil { + return SearchResponse{}, err + } + + results := make([]MemoryItem, 0, len(points)) + for idx, point := range points { + item := payloadToMemoryItem(point.ID, point.Payload) + if idx < len(scores) { + item.Score = scores[idx] + } + results = append(results, item) + } + return SearchResponse{Results: results}, nil + } + + pointsBySource, scoresBySource, err := store.SearchBySources(ctx, vector, req.Limit, filters, req.Sources, vectorName) + if err != nil { + return SearchResponse{}, err + } + results := fuseByRankFusion(pointsBySource, scoresBySource) return SearchResponse{Results: results}, nil } +func (s *Service) EmbedUpsert(ctx context.Context, req EmbedUpsertRequest) 
(EmbedUpsertResponse, error) { + if s.resolver == nil { + return EmbedUpsertResponse{}, fmt.Errorf("embeddings resolver not configured") + } + if req.UserID == "" && req.AgentID == "" && req.RunID == "" { + return EmbedUpsertResponse{}, fmt.Errorf("user_id, agent_id or run_id is required") + } + req.Type = strings.TrimSpace(req.Type) + req.Provider = strings.TrimSpace(req.Provider) + req.Model = strings.TrimSpace(req.Model) + req.Input.Text = strings.TrimSpace(req.Input.Text) + req.Input.ImageURL = strings.TrimSpace(req.Input.ImageURL) + req.Input.VideoURL = strings.TrimSpace(req.Input.VideoURL) + + result, err := s.resolver.Embed(ctx, embeddings.Request{ + Type: req.Type, + Provider: req.Provider, + Model: req.Model, + Input: embeddings.Input{ + Text: req.Input.Text, + ImageURL: req.Input.ImageURL, + VideoURL: req.Input.VideoURL, + }, + }) + if err != nil { + return EmbedUpsertResponse{}, err + } + + if s.store == nil { + return EmbedUpsertResponse{}, fmt.Errorf("qdrant store not configured") + } + + vectorName := "" + if s.store != nil && s.store.usesNamedVectors { + vectorName = result.Model + } + + id := uuid.NewString() + filters := buildEmbedFilters(req) + payload := buildEmbeddingPayload(req, filters) + if metadata, ok := payload["metadata"].(map[string]interface{}); ok && result.Model != "" { + metadata["model_id"] = result.Model + } + if err := s.store.Upsert(ctx, []qdrantPoint{{ + ID: id, + Vector: result.Embedding, + VectorName: vectorName, + Payload: payload, + }}); err != nil { + return EmbedUpsertResponse{}, err + } + + item := payloadToMemoryItem(id, payload) + return EmbedUpsertResponse{ + Item: item, + Provider: result.Provider, + Model: result.Model, + Dimensions: result.Dimensions, + }, nil +} + func (s *Service) Update(ctx context.Context, req UpdateRequest) (MemoryItem, error) { if strings.TrimSpace(req.MemoryID) == "" { return MemoryItem{}, fmt.Errorf("memory_id is required") @@ -171,6 +282,7 @@ func (s *Service) Update(ctx context.Context, 
req UpdateRequest) (MemoryItem, er if err := s.store.Upsert(ctx, []qdrantPoint{{ ID: req.MemoryID, Vector: vector, + VectorName: s.vectorNameForText(), Payload: payload, }}); err != nil { return MemoryItem{}, err @@ -270,7 +382,7 @@ func (s *Service) collectCandidates(ctx context.Context, facts []string, filters if err != nil { return nil, err } - points, _, err := s.store.Search(ctx, vector, 5, filters) + points, _, err := s.store.Search(ctx, vector, 5, filters, s.vectorNameForText()) if err != nil { return nil, err } @@ -301,6 +413,7 @@ func (s *Service) applyAdd(ctx context.Context, text string, filters map[string] if err := s.store.Upsert(ctx, []qdrantPoint{{ ID: id, Vector: vector, + VectorName: s.vectorNameForText(), Payload: payload, }}); err != nil { return MemoryItem{}, err @@ -337,6 +450,7 @@ func (s *Service) applyUpdate(ctx context.Context, id, text string, filters map[ if err := s.store.Upsert(ctx, []qdrantPoint{{ ID: id, Vector: vector, + VectorName: s.vectorNameForText(), Payload: payload, }}); err != nil { return MemoryItem{}, err @@ -403,6 +517,68 @@ func buildSearchFilters(req SearchRequest) map[string]interface{} { return filters } +func buildEmbedFilters(req EmbedUpsertRequest) map[string]interface{} { + filters := map[string]interface{}{} + for key, value := range req.Filters { + filters[key] = value + } + if req.UserID != "" { + filters["userId"] = req.UserID + } + if req.AgentID != "" { + filters["agentId"] = req.AgentID + } + if req.RunID != "" { + filters["runId"] = req.RunID + } + return filters +} + +func buildEmbeddingPayload(req EmbedUpsertRequest, filters map[string]interface{}) map[string]interface{} { + text := req.Input.Text + payload := buildPayload(text, filters, req.Metadata, "") + payload["hash"] = hashEmbeddingInput(req.Input.Text, req.Input.ImageURL, req.Input.VideoURL) + if req.Source != "" { + payload["source"] = req.Source + } + modality := "text" + if req.Type != "" { + modality = strings.ToLower(req.Type) + } + 
payload["modality"] = modality + + if payload["metadata"] == nil { + payload["metadata"] = map[string]interface{}{} + } + if metadata, ok := payload["metadata"].(map[string]interface{}); ok { + if req.Source != "" { + metadata["source"] = req.Source + } + metadata["modality"] = modality + if req.Input.ImageURL != "" { + metadata["image_url"] = req.Input.ImageURL + } + if req.Input.VideoURL != "" { + metadata["video_url"] = req.Input.VideoURL + } + } + return payload +} + +func (s *Service) vectorNameForText() string { + if s.store == nil || !s.store.usesNamedVectors { + return "" + } + return strings.TrimSpace(s.defaultTextModelID) +} + +func (s *Service) vectorNameForMultimodal() string { + if s.store == nil || !s.store.usesNamedVectors { + return "" + } + return strings.TrimSpace(s.defaultMultimodalModelID) +} + func buildPayload(text string, filters map[string]interface{}, metadata map[string]interface{}, createdAt string) map[string]interface{} { if createdAt == "" { createdAt = time.Now().UTC().Format(time.RFC3339) @@ -450,6 +626,16 @@ func payloadToMemoryItem(id string, payload map[string]interface{}) MemoryItem { } if meta, ok := payload["metadata"].(map[string]interface{}); ok { item.Metadata = meta + } else if payload["metadata"] == nil { + item.Metadata = map[string]interface{}{} + } + if item.Metadata != nil { + if source, ok := payload["source"].(string); ok && source != "" { + item.Metadata["source"] = source + } + if modality, ok := payload["modality"].(string); ok && modality != "" { + item.Metadata["modality"] = modality + } } return item } @@ -459,6 +645,16 @@ func hashMemory(text string) string { return hex.EncodeToString(sum[:]) } +func hashEmbeddingInput(text, imageURL, videoURL string) string { + combined := strings.Join([]string{ + strings.TrimSpace(text), + strings.TrimSpace(imageURL), + strings.TrimSpace(videoURL), + }, "|") + sum := md5.Sum([]byte(combined)) + return hex.EncodeToString(sum[:]) +} + func mergeMetadata(base interface{}, extra 
map[string]interface{}) map[string]interface{} { merged := map[string]interface{}{} if baseMap, ok := base.(map[string]interface{}); ok { @@ -471,3 +667,93 @@ func mergeMetadata(base interface{}, extra map[string]interface{}) map[string]in } return merged } + +type rerankCandidate struct { + ID string + Payload map[string]interface{} + Score float64 + Source string + Rank int +} + +const ( + fusionModeRRF = "rrf" + fusionModeCombMNZ = "combmnz" + fusionMode = fusionModeRRF + rrfK = 60.0 +) + +func fuseByRankFusion(pointsBySource map[string][]qdrantPoint, scoresBySource map[string][]float64) []MemoryItem { + candidates := map[string]*rerankCandidate{} + rrfScores := map[string]float64{} + combScores := map[string]float64{} + combCounts := map[string]int{} + + for source, points := range pointsBySource { + scores := scoresBySource[source] + minScore := math.MaxFloat64 + maxScore := -math.MaxFloat64 + for idx, point := range points { + if idx >= len(scores) { + continue + } + score := scores[idx] + if score < minScore { + minScore = score + } + if score > maxScore { + maxScore = score + } + if _, ok := candidates[point.ID]; !ok { + candidates[point.ID] = &rerankCandidate{ + ID: point.ID, + Payload: point.Payload, + } + } + } + if minScore == math.MaxFloat64 { + minScore = 0 + } + if maxScore == -math.MaxFloat64 { + maxScore = minScore + } + + for idx, point := range points { + if idx >= len(scores) { + continue + } + score := scores[idx] + rank := float64(idx + 1) + rrfScores[point.ID] += 1.0 / (rrfK + rank) + + scoreNorm := normalizeScore(score, minScore, maxScore) + combScores[point.ID] += scoreNorm + combCounts[point.ID]++ + } + } + + items := make([]MemoryItem, 0, len(candidates)) + for id, candidate := range candidates { + item := payloadToMemoryItem(candidate.ID, candidate.Payload) + switch fusionMode { + case fusionModeCombMNZ: + item.Score = combScores[id] * float64(combCounts[id]) + default: + item.Score = rrfScores[id] + } + items = append(items, item) + } + 
+ sort.Slice(items, func(i, j int) bool { + return items[i].Score > items[j].Score + }) + return items +} + +func normalizeScore(score, minScore, maxScore float64) float64 { + if maxScore <= minScore { + return 1 + } + return (score - minScore) / (maxScore - minScore) +} + diff --git a/internal/memory/types.go b/internal/memory/types.go index 50d41215..0a20f61e 100644 --- a/internal/memory/types.go +++ b/internal/memory/types.go @@ -23,6 +23,7 @@ type SearchRequest struct { RunID string `json:"run_id,omitempty"` Limit int `json:"limit,omitempty"` Filters map[string]interface{} `json:"filters,omitempty"` + Sources []string `json:"sources,omitempty"` } type UpdateRequest struct { @@ -43,6 +44,32 @@ type DeleteAllRequest struct { RunID string `json:"run_id,omitempty"` } +type EmbedInput struct { + Text string `json:"text,omitempty"` + ImageURL string `json:"image_url,omitempty"` + VideoURL string `json:"video_url,omitempty"` +} + +type EmbedUpsertRequest struct { + Type string `json:"type"` + Provider string `json:"provider,omitempty"` + Model string `json:"model,omitempty"` + Input EmbedInput `json:"input"` + Source string `json:"source,omitempty"` + UserID string `json:"user_id,omitempty"` + AgentID string `json:"agent_id,omitempty"` + RunID string `json:"run_id,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` + Filters map[string]interface{} `json:"filters,omitempty"` +} + +type EmbedUpsertResponse struct { + Item MemoryItem `json:"item"` + Provider string `json:"provider"` + Model string `json:"model"` + Dimensions int `json:"dimensions"` +} + type MemoryItem struct { ID string `json:"id"` Memory string `json:"memory"` diff --git a/internal/models/models.go b/internal/models/models.go index 31771eec..a3c8f713 100644 --- a/internal/models/models.go +++ b/internal/models/models.go @@ -29,12 +29,16 @@ func (s *Service) Create(ctx context.Context, req AddRequest) (AddResponse, erro } // Convert to sqlc params + llmProviderID, err := 
parseUUID(model.LlmProviderID) + if err != nil { + return AddResponse{}, fmt.Errorf("invalid llm provider ID: %w", err) + } + params := sqlc.CreateModelParams{ - ModelID: model.ModelID, - BaseUrl: model.BaseURL, - ApiKey: model.APIKey, - ClientType: string(model.ClientType), - Type: string(model.Type), + ModelID: model.ModelID, + LlmProviderID: llmProviderID, + IsMultimodal: model.IsMultimodal, + Type: string(model.Type), } // Handle optional name field @@ -123,7 +127,7 @@ func (s *Service) ListByType(ctx context.Context, modelType ModelType) ([]GetRes // ListByClientType returns models filtered by client type func (s *Service) ListByClientType(ctx context.Context, clientType ClientType) ([]GetResponse, error) { - if clientType != ClientTypeOpenAI && clientType != ClientTypeAnthropic && clientType != ClientTypeGoogle { + if !isValidClientType(clientType) { return nil, fmt.Errorf("invalid client type: %s", clientType) } @@ -148,13 +152,17 @@ func (s *Service) UpdateByID(ctx context.Context, id string, req UpdateRequest) } params := sqlc.UpdateModelParams{ - ID: uuid, - BaseUrl: model.BaseURL, - ApiKey: model.APIKey, - ClientType: string(model.ClientType), - Type: string(model.Type), + ID: uuid, + IsMultimodal: model.IsMultimodal, + Type: string(model.Type), } + llmProviderID, err := parseUUID(model.LlmProviderID) + if err != nil { + return GetResponse{}, fmt.Errorf("invalid llm provider ID: %w", err) + } + params.LlmProviderID = llmProviderID + if model.Name != "" { params.Name = pgtype.Text{String: model.Name, Valid: true} } @@ -183,13 +191,17 @@ func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req Updat } params := sqlc.UpdateModelByModelIDParams{ - ModelID: modelID, - BaseUrl: model.BaseURL, - ApiKey: model.APIKey, - ClientType: string(model.ClientType), - Type: string(model.Type), + ModelID: modelID, + IsMultimodal: model.IsMultimodal, + Type: string(model.Type), } + llmProviderID, err := parseUUID(model.LlmProviderID) + if err != nil { + 
return GetResponse{}, fmt.Errorf("invalid llm provider ID: %w", err) + } + params.LlmProviderID = llmProviderID + if model.Name != "" { params.Name = pgtype.Text{String: model.Name, Valid: true} } @@ -274,14 +286,16 @@ func convertToGetResponse(dbModel sqlc.Model) GetResponse { resp := GetResponse{ ModelId: dbModel.ModelID, Model: Model{ - ModelID: dbModel.ModelID, - BaseURL: dbModel.BaseUrl, - APIKey: dbModel.ApiKey, - ClientType: ClientType(dbModel.ClientType), - Type: ModelType(dbModel.Type), + ModelID: dbModel.ModelID, + IsMultimodal: dbModel.IsMultimodal, + Type: ModelType(dbModel.Type), }, } + if llmProviderID, ok := uuidStringFromPgUUID(dbModel.LlmProviderID); ok { + resp.Model.LlmProviderID = llmProviderID + } + if dbModel.Name.Valid { resp.Model.Name = dbModel.Name.String } @@ -300,3 +314,30 @@ func convertToGetResponseList(dbModels []sqlc.Model) []GetResponse { } return responses } + +func isValidClientType(clientType ClientType) bool { + switch clientType { + case ClientTypeOpenAI, + ClientTypeAnthropic, + ClientTypeGoogle, + ClientTypeBedrock, + ClientTypeOllama, + ClientTypeAzure, + ClientTypeDashscope, + ClientTypeOther: + return true + default: + return false + } +} + +func uuidStringFromPgUUID(value pgtype.UUID) (string, bool) { + if !value.Valid { + return "", false + } + id, err := uuid.FromBytes(value.Bytes[:]) + if err != nil { + return "", false + } + return id.String(), true +} diff --git a/internal/models/models_test.go b/internal/models/models_test.go index 315e669b..42573164 100644 --- a/internal/models/models_test.go +++ b/internal/models/models_test.go @@ -18,9 +18,7 @@ func ExampleService_Create() { // req := models.AddRequest{ // ModelID: "gpt-4", // Name: "GPT-4", - // BaseURL: "https://api.openai.com/v1", - // APIKey: "sk-...", - // ClientType: models.ClientTypeOpenAI, + // LlmProviderID: "11111111-1111-1111-1111-111111111111", // Type: models.ModelTypeChat, // } @@ -77,9 +75,7 @@ func ExampleService_UpdateByModelID() { // req := 
models.UpdateRequest{ // ModelID: "gpt-4", // Name: "GPT-4 Turbo", - // BaseURL: "https://api.openai.com/v1", - // APIKey: "sk-...", - // ClientType: models.ClientTypeOpenAI, + // LlmProviderID: "11111111-1111-1111-1111-111111111111", // Type: models.ModelTypeChat, // } @@ -111,89 +107,65 @@ func TestModel_Validate(t *testing.T) { { name: "valid chat model", model: models.Model{ - ModelID: "gpt-4", - Name: "GPT-4", - BaseURL: "https://api.openai.com/v1", - APIKey: "sk-test", - ClientType: models.ClientTypeOpenAI, - Type: models.ModelTypeChat, + ModelID: "gpt-4", + Name: "GPT-4", + LlmProviderID: "11111111-1111-1111-1111-111111111111", + Type: models.ModelTypeChat, }, wantErr: false, }, { name: "valid embedding model", model: models.Model{ - ModelID: "text-embedding-ada-002", - Name: "Ada Embeddings", - BaseURL: "https://api.openai.com/v1", - APIKey: "sk-test", - ClientType: models.ClientTypeOpenAI, - Type: models.ModelTypeEmbedding, - Dimensions: 1536, + ModelID: "text-embedding-ada-002", + Name: "Ada Embeddings", + LlmProviderID: "11111111-1111-1111-1111-111111111111", + Type: models.ModelTypeEmbedding, + Dimensions: 1536, }, wantErr: false, }, { name: "missing model_id", model: models.Model{ - BaseURL: "https://api.openai.com/v1", - APIKey: "sk-test", - ClientType: models.ClientTypeOpenAI, - Type: models.ModelTypeChat, + LlmProviderID: "11111111-1111-1111-1111-111111111111", + Type: models.ModelTypeChat, }, wantErr: true, }, { - name: "missing base_url", + name: "missing llm_provider_id", model: models.Model{ - ModelID: "gpt-4", - APIKey: "sk-test", - ClientType: models.ClientTypeOpenAI, - Type: models.ModelTypeChat, + ModelID: "gpt-4", + Type: models.ModelTypeChat, }, wantErr: true, }, { - name: "missing api_key", + name: "invalid llm_provider_id", model: models.Model{ - ModelID: "gpt-4", - BaseURL: "https://api.openai.com/v1", - ClientType: models.ClientTypeOpenAI, - Type: models.ModelTypeChat, - }, - wantErr: true, - }, - { - name: "invalid client type", - 
model: models.Model{ - ModelID: "gpt-4", - BaseURL: "https://api.openai.com/v1", - APIKey: "sk-test", - ClientType: "invalid", - Type: models.ModelTypeChat, + ModelID: "gpt-4", + LlmProviderID: "not-a-uuid", + Type: models.ModelTypeChat, }, wantErr: true, }, { name: "invalid model type", model: models.Model{ - ModelID: "gpt-4", - BaseURL: "https://api.openai.com/v1", - APIKey: "sk-test", - ClientType: models.ClientTypeOpenAI, - Type: "invalid", + ModelID: "gpt-4", + LlmProviderID: "11111111-1111-1111-1111-111111111111", + Type: "invalid", }, wantErr: true, }, { name: "embedding model missing dimensions", model: models.Model{ - ModelID: "text-embedding-ada-002", - BaseURL: "https://api.openai.com/v1", - APIKey: "sk-test", - ClientType: models.ClientTypeOpenAI, - Type: models.ModelTypeEmbedding, - Dimensions: 0, + ModelID: "text-embedding-ada-002", + LlmProviderID: "11111111-1111-1111-1111-111111111111", + Type: models.ModelTypeEmbedding, + Dimensions: 0, }, wantErr: true, }, @@ -221,6 +193,11 @@ func TestModelTypes(t *testing.T) { assert.Equal(t, models.ClientType("openai"), models.ClientTypeOpenAI) assert.Equal(t, models.ClientType("anthropic"), models.ClientTypeAnthropic) assert.Equal(t, models.ClientType("google"), models.ClientTypeGoogle) + assert.Equal(t, models.ClientType("bedrock"), models.ClientTypeBedrock) + assert.Equal(t, models.ClientType("ollama"), models.ClientTypeOllama) + assert.Equal(t, models.ClientType("azure"), models.ClientTypeAzure) + assert.Equal(t, models.ClientType("dashscope"), models.ClientTypeDashscope) + assert.Equal(t, models.ClientType("other"), models.ClientTypeOther) }) } diff --git a/internal/models/types.go b/internal/models/types.go index 3256bf07..5066ede3 100644 --- a/internal/models/types.go +++ b/internal/models/types.go @@ -2,13 +2,15 @@ package models import ( "errors" + + "github.com/google/uuid" ) type ModelType string const ( - ModelTypeChat = "chat" - ModelTypeEmbedding = "embedding" + ModelTypeChat ModelType = "chat" + 
ModelTypeEmbedding ModelType = "embedding" ) type ClientType string @@ -17,37 +19,35 @@ const ( ClientTypeOpenAI ClientType = "openai" ClientTypeAnthropic ClientType = "anthropic" ClientTypeGoogle ClientType = "google" + ClientTypeBedrock ClientType = "bedrock" + ClientTypeOllama ClientType = "ollama" + ClientTypeAzure ClientType = "azure" + ClientTypeDashscope ClientType = "dashscope" + ClientTypeOther ClientType = "other" ) type Model struct { - ModelID string `json:"model_id"` - Name string `json:"name"` - BaseURL string `json:"base_url"` - APIKey string `json:"api_key"` - ClientType ClientType `json:"client_type"` - Type ModelType `json:"type"` - Dimensions int `json:"dimensions"` + ModelID string `json:"model_id"` + Name string `json:"name"` + LlmProviderID string `json:"llm_provider_id"` + IsMultimodal bool `json:"is_multimodal"` + Type ModelType `json:"type"` + Dimensions int `json:"dimensions"` } func (m *Model) Validate() error { if m.ModelID == "" { return errors.New("model ID is required") } - if m.BaseURL == "" { - return errors.New("base URL is required") + if m.LlmProviderID == "" { + return errors.New("llm provider ID is required") } - if m.APIKey == "" { - return errors.New("API key is required") - } - if m.ClientType == "" { - return errors.New("client type is required") + if _, err := uuid.Parse(m.LlmProviderID); err != nil { + return errors.New("llm provider ID must be a valid UUID") } if m.Type != ModelTypeChat && m.Type != ModelTypeEmbedding { return errors.New("invalid model type") } - if m.ClientType != ClientTypeOpenAI && m.ClientType != ClientTypeAnthropic && m.ClientType != ClientTypeGoogle { - return errors.New("invalid client type") - } if m.Type == ModelTypeEmbedding && m.Dimensions <= 0 { return errors.New("dimensions must be greater than 0") } diff --git a/internal/server/server.go b/internal/server/server.go index d24852e1..cdf3509b 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -15,7 +15,7 @@ type Server 
struct { addr string } -func NewServer(addr string, jwtSecret string, pingHandler *handlers.PingHandler, authHandler *handlers.AuthHandler, memoryHandler *handlers.MemoryHandler, fsHandler *handlers.FSHandler, swaggerHandler *handlers.SwaggerHandler) *Server { +func NewServer(addr string, jwtSecret string, pingHandler *handlers.PingHandler, authHandler *handlers.AuthHandler, memoryHandler *handlers.MemoryHandler, embeddingsHandler *handlers.EmbeddingsHandler, fsHandler *handlers.FSHandler, swaggerHandler *handlers.SwaggerHandler) *Server { if addr == "" { addr = ":8080" } @@ -44,6 +44,9 @@ func NewServer(addr string, jwtSecret string, pingHandler *handlers.PingHandler, if memoryHandler != nil { memoryHandler.Register(e) } + if embeddingsHandler != nil { + embeddingsHandler.Register(e) + } if fsHandler != nil { fsHandler.Register(e) }