diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 49ccc0a2..84178072 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -107,69 +107,27 @@ func main() { } resolver := embeddings.NewResolver(logger.L, modelsService, queries, 10*time.Second) - vectors, textModel, multimodalModel, hasModels, err := embeddings.CollectEmbeddingVectors(ctx, modelsService) + vectors, textModel, multimodalModel, hasEmbeddingModels, err := embeddings.CollectEmbeddingVectors(ctx, modelsService) if err != nil { logger.Error("embedding models", slog.Any("error", err)) os.Exit(1) } - var memoryService *memory.Service - var memoryHandler *handlers.MemoryHandler - - if !hasModels { - logger.Warn("No embedding models configured. Memory service will not be available.") - logger.Warn("You can add embedding models via the /models API endpoint.") - memoryHandler = handlers.NewMemoryHandler(logger.L, nil) - } else { - if textModel.ModelID == "" { - logger.Warn("No text embedding model configured. Text embedding features will be limited.") - } - if multimodalModel.ModelID == "" { - logger.Warn("No multimodal embedding model configured. Multimodal embedding features will be limited.") - } - - var textEmbedder embeddings.Embedder - var store *memory.QdrantStore - - if textModel.ModelID != "" && textModel.Dimensions > 0 { - textEmbedder = &embeddings.ResolverTextEmbedder{ - Resolver: resolver, - ModelID: textModel.ModelID, - Dims: textModel.Dimensions, - } - - if len(vectors) > 0 { - store, err = memory.NewQdrantStoreWithVectors( - logger.L, - cfg.Qdrant.BaseURL, - cfg.Qdrant.APIKey, - cfg.Qdrant.Collection, - vectors, - time.Duration(cfg.Qdrant.TimeoutSeconds)*time.Second, - ) - if err != nil { - logger.Error("qdrant named vectors init", slog.Any("error", err)) - os.Exit(1) - } - } else { - store, err = memory.NewQdrantStore( - logger.L, - cfg.Qdrant.BaseURL, - cfg.Qdrant.APIKey, - cfg.Qdrant.Collection, - textModel.Dimensions, - time.Duration(cfg.Qdrant.TimeoutSeconds)*time.Second, - ) - if err != nil { - logger.Error("qdrant init", slog.Any("error", err)) - os.Exit(1) - } - } - } - - memoryService = memory.NewService(logger.L, llmClient, textEmbedder, store, resolver, textModel.ModelID, multimodalModel.ModelID) - memoryHandler = handlers.NewMemoryHandler(logger.L, memoryService) + textEmbedder := buildTextEmbedder(resolver, textModel, hasEmbeddingModels, logger.L) + if hasEmbeddingModels && multimodalModel.ModelID == "" { + logger.Warn("No multimodal embedding model configured. Multimodal embedding features will be limited.") } + + store := buildQdrantStore(logger.L, cfg.Qdrant, vectors, hasEmbeddingModels, textModel.Dimensions) + + bm25Indexer := memory.NewBM25Indexer(logger.L) + memoryService := memory.NewService(logger.L, llmClient, textEmbedder, store, resolver, bm25Indexer, textModel.ModelID, multimodalModel.ModelID) + memoryHandler := handlers.NewMemoryHandler(logger.L, memoryService) + go func() { + if err := memoryService.WarmupBM25(ctx, 200); err != nil { + logger.Warn("bm25 warmup failed", slog.Any("error", err)) + } + }() chatResolver = chat.NewResolver(logger.L, modelsService, queries, memoryService, cfg.AgentGateway.BaseURL(), 30*time.Second) embeddingsHandler := handlers.NewEmbeddingsHandler(logger.L, modelsService, queries) swaggerHandler := handlers.NewSwaggerHandler(logger.L) @@ -197,7 +155,7 @@ func main() { scheduleHandler := handlers.NewScheduleHandler(logger.L, scheduleService) subagentService := subagent.NewService(logger.L, queries) subagentHandler := handlers.NewSubagentHandler(logger.L, subagentService) - srv := server.NewServer(logger.L, addr, cfg.Auth.JWTSecret, pingHandler, authHandler, memoryHandler, embeddingsHandler, chatHandler, swaggerHandler, providersHandler, modelsHandler, settingsHandler, historyHandler, scheduleHandler, subagentHandler, containerdHandler, /*channelHandler*/) + srv := server.NewServer(logger.L, addr, cfg.Auth.JWTSecret, pingHandler, authHandler, memoryHandler, embeddingsHandler, chatHandler, swaggerHandler, providersHandler, modelsHandler, settingsHandler, historyHandler, scheduleHandler, subagentHandler, containerdHandler /*channelHandler*/) if err := srv.Start(); err != nil { logger.Error("server failed", slog.Any("error", err)) @@ -205,6 +163,55 @@ func main() { } } +func buildTextEmbedder(resolver *embeddings.Resolver, textModel models.GetResponse, hasModels bool, log *slog.Logger) embeddings.Embedder { + if !hasModels { + return nil + } + if textModel.ModelID == "" || textModel.Dimensions <= 0 { + log.Warn("No text embedding model configured. Text embedding features will be limited.") + return nil + } + return &embeddings.ResolverTextEmbedder{ + Resolver: resolver, + ModelID: textModel.ModelID, + Dims: textModel.Dimensions, + } +} + +func buildQdrantStore(log *slog.Logger, cfg config.QdrantConfig, vectors map[string]int, hasModels bool, textDims int) *memory.QdrantStore { + timeout := time.Duration(cfg.TimeoutSeconds) * time.Second + if hasModels && len(vectors) > 0 { + store, err := memory.NewQdrantStoreWithVectors( + log, + cfg.BaseURL, + cfg.APIKey, + cfg.Collection, + vectors, + "sparse_hash", + timeout, + ) + if err != nil { + log.Error("qdrant named vectors init", slog.Any("error", err)) + os.Exit(1) + } + return store + } + store, err := memory.NewQdrantStore( + log, + cfg.BaseURL, + cfg.APIKey, + cfg.Collection, + textDims, + "sparse_hash", + timeout, + ) + if err != nil { + log.Error("qdrant init", slog.Any("error", err)) + os.Exit(1) + } + return store +} + func ensureAdminUser(ctx context.Context, log *slog.Logger, queries *dbsqlc.Queries, cfg config.Config) error { if queries == nil { return fmt.Errorf("db queries not configured") @@ -279,6 +286,14 @@ func (c *lazyLLMClient) Decide(ctx context.Context, req memory.DecideRequest) (m return client.Decide(ctx, req) } +func (c *lazyLLMClient) DetectLanguage(ctx context.Context, text string) (string, error) { + client, err := c.resolve(ctx) + if err != nil { + return "", err + } + return client.DetectLanguage(ctx, text) +} + func (c *lazyLLMClient) resolve(ctx context.Context) (memory.LLM, error) { if c.modelsService == nil || c.queries == nil { return nil, fmt.Errorf("models service not configured") diff --git a/docs/docs.go b/docs/docs.go index d33daa77..f8c3127a 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -135,11 +135,7 @@ const docTemplate = `{ "200": { "description": "OK", "schema": { - "type": "array", - "items": { - "type": "integer", - "format": "int32" - } + "type": "string" } }, "400": { @@ -827,12 +823,6 @@ const docTemplate = `{ ], "summary": "List memories", "parameters": [ - { - "type": "string", - "description": "Agent ID", - "name": "agent_id", - "in": "query" - }, { "type": "string", "description": "Run ID", @@ -3061,8 +3051,8 @@ const docTemplate = `{ "handlers.memoryAddPayload": { "type": "object", "properties": { - "agent_id": { - "type": "string" + "embedding_enabled": { + "type": "boolean" }, "filters": { "type": "object", @@ -3092,9 +3082,6 @@ const docTemplate = `{ "handlers.memoryDeleteAllPayload": { "type": "object", "properties": { - "agent_id": { - "type": "string" - }, "run_id": { "type": "string" } @@ -3103,9 +3090,6 @@ const docTemplate = `{ "handlers.memoryEmbedUpsertPayload": { "type": "object", "properties": { - "agent_id": { - "type": "string" - }, "filters": { "type": "object", "additionalProperties": true @@ -3137,8 +3121,8 @@ const docTemplate = `{ "handlers.memorySearchPayload": { "type": "object", "properties": { - "agent_id": { - "type": "string" + "embedding_enabled": { + "type": "boolean" }, "filters": { "type": "object", @@ -3267,9 +3251,6 @@ const docTemplate = `{ "memory.MemoryItem": { "type": "object", "properties": { - "agentId": { - "type": "string" - }, "createdAt": { "type": "string" }, @@ -3329,6 +3310,9 @@ const docTemplate = `{ "memory.UpdateRequest": { "type": "object", "properties": { + "embedding_enabled": { + "type": "boolean" + }, "memory": { "type": "string" }, diff --git a/docs/swagger.json b/docs/swagger.json index 1b255791..256abe1d 100644 --- a/docs/swagger.json +++ b/docs/swagger.json @@ -126,11 +126,7 @@ "200": { "description": "OK", "schema": { - "type": "array", - "items": { - "type": "integer", - "format": "int32" - } + "type": "string" } }, "400": { @@ -818,12 +814,6 @@ ], "summary": "List memories", "parameters": [ - { - "type": "string", - "description": "Agent ID", - "name": "agent_id", - "in": "query" - }, { "type": "string", "description": "Run ID", @@ -3052,8 +3042,8 @@ "handlers.memoryAddPayload": { "type": "object", "properties": { - "agent_id": { - "type": "string" + "embedding_enabled": { + "type": "boolean" }, "filters": { "type": "object", @@ -3083,9 +3073,6 @@ "handlers.memoryDeleteAllPayload": { "type": "object", "properties": { - "agent_id": { - "type": "string" - }, "run_id": { "type": "string" } @@ -3094,9 +3081,6 @@ "handlers.memoryEmbedUpsertPayload": { "type": "object", "properties": { - "agent_id": { - "type": "string" - }, "filters": { "type": "object", "additionalProperties": true @@ -3128,8 +3112,8 @@ "handlers.memorySearchPayload": { "type": "object", "properties": { - "agent_id": { - "type": "string" + "embedding_enabled": { + "type": "boolean" }, "filters": { "type": "object", @@ -3258,9 +3242,6 @@ "memory.MemoryItem": { "type": "object", "properties": { - "agentId": { - "type": "string" - }, "createdAt": { "type": "string" }, @@ -3320,6 +3301,9 @@ "memory.UpdateRequest": { "type": "object", "properties": { + "embedding_enabled": { + "type": "boolean" + }, "memory": { "type": "string" }, diff --git a/docs/swagger.yaml b/docs/swagger.yaml index ecca7bd3..9b152852 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -269,8 +269,8 @@ definitions: type: object handlers.memoryAddPayload: properties: - agent_id: - type: string + embedding_enabled: + type: boolean filters: additionalProperties: true type: object @@ -290,15 +290,11 @@ definitions: type: object handlers.memoryDeleteAllPayload: properties: - agent_id: - type: string run_id: type: string type: object handlers.memoryEmbedUpsertPayload: properties: - agent_id: - type: string filters: additionalProperties: true type: object @@ -320,8 +316,8 @@ definitions: type: object handlers.memorySearchPayload: properties: - agent_id: - type: string + embedding_enabled: + type: boolean filters: additionalProperties: true type: object @@ -405,8 +401,6 @@ definitions: type: object memory.MemoryItem: properties: - agentId: - type: string createdAt: type: string hash: @@ -446,6 +440,8 @@ definitions: type: object memory.UpdateRequest: properties: + embedding_enabled: + type: boolean memory: type: string memory_id: @@ -872,10 +868,7 @@ paths: "200": description: OK schema: - items: - format: int32 - type: integer - type: array + type: string "400": description: Bad Request schema: @@ -1365,10 +1358,6 @@ paths: description: 'List memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id).' parameters: - - description: Agent ID - in: query - name: agent_id - type: string - description: Run ID in: query name: run_id diff --git a/go.mod b/go.mod index 79e71034..8436cd68 100644 --- a/go.mod +++ b/go.mod @@ -4,15 +4,19 @@ go 1.25.2 require ( github.com/BurntSushi/toml v1.6.0 + github.com/blevesearch/bleve/v2 v2.5.7 github.com/containerd/containerd/api v1.10.0 github.com/containerd/containerd/v2 v2.2.1 github.com/containerd/errdefs v1.0.0 + github.com/containerd/platforms v1.0.0-rc.2 github.com/golang-jwt/jwt/v5 v5.3.0 github.com/google/uuid v1.6.0 github.com/jackc/pgx/v5 v5.8.0 github.com/labstack/echo-jwt/v4 v4.4.0 github.com/labstack/echo/v4 v4.15.0 github.com/modelcontextprotocol/go-sdk v1.2.0 + github.com/opencontainers/go-digest v1.0.0 + github.com/opencontainers/image-spec v1.1.1 github.com/opencontainers/runtime-spec v1.3.0 github.com/qdrant/go-client v1.16.2 github.com/robfig/cron/v3 v3.0.1 @@ -25,13 +29,20 @@ require ( github.com/KyleBanks/depth v1.2.1 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect github.com/Microsoft/hcsshim v0.14.0-rc.1 // indirect + github.com/bits-and-blooms/bitset v1.22.0 // indirect + github.com/blevesearch/bleve_index_api v1.2.11 // indirect + github.com/blevesearch/geo v0.2.4 // indirect + github.com/blevesearch/go-porterstemmer v1.0.3 // indirect + github.com/blevesearch/segment v0.9.1 // indirect + github.com/blevesearch/snowballstem v0.9.0 // indirect + github.com/blevesearch/stempel v0.2.0 // indirect + github.com/blevesearch/upsidedown_store_api v1.0.2 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/containerd/cgroups/v3 v3.1.2 // indirect github.com/containerd/continuity v0.4.5 // indirect github.com/containerd/errdefs/pkg v0.3.0 // indirect github.com/containerd/fifo v1.1.0 // indirect github.com/containerd/log v0.1.0 // indirect - github.com/containerd/platforms v1.0.0-rc.2 // indirect github.com/containerd/plugin v1.0.0 // indirect github.com/containerd/ttrpc v1.2.7 // indirect github.com/containerd/typeurl/v2 v2.2.3 // indirect @@ -51,7 +62,6 @@ require ( github.com/go-openapi/swag/stringutils v0.25.4 // indirect github.com/go-openapi/swag/typeutils v0.25.4 // indirect github.com/go-openapi/swag/yamlutils v0.25.4 // indirect - github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect github.com/google/go-cmp v0.7.0 // indirect @@ -59,9 +69,9 @@ require ( github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect + github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/compress v1.18.3 // indirect github.com/labstack/gommon v0.4.2 // indirect - github.com/larksuite/oapi-sdk-go/v3 v3.5.3 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/moby/locker v1.0.1 // indirect @@ -70,8 +80,8 @@ require ( github.com/moby/sys/signal v0.7.1 // indirect github.com/moby/sys/user v0.4.0 // indirect github.com/moby/sys/userns v0.1.0 // indirect - github.com/opencontainers/go-digest v1.0.0 // indirect - github.com/opencontainers/image-spec v1.1.1 // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee // indirect github.com/opencontainers/selinux v1.13.1 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect diff --git a/go.sum b/go.sum index e8a5606d..567ccd8e 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,24 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/Microsoft/hcsshim v0.14.0-rc.1 h1:qAPXKwGOkVn8LlqgBN8GS0bxZ83hOJpcjxzmlQKxKsQ= github.com/Microsoft/hcsshim v0.14.0-rc.1/go.mod h1:hTKFGbnDtQb1wHiOWv4v0eN+7boSWAHyK/tNAaYZL0c= +github.com/bits-and-blooms/bitset v1.22.0 h1:Tquv9S8+SGaS3EhyA+up3FXzmkhxPGjQQCkcs2uw7w4= +github.com/bits-and-blooms/bitset v1.22.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= +github.com/blevesearch/bleve/v2 v2.5.7 h1:2d9YrL5zrX5EBBW++GOaEKjE+NPWeZGaX77IM26m1Z8= +github.com/blevesearch/bleve/v2 v2.5.7/go.mod h1:yj0NlS7ocGC4VOSAedqDDMktdh2935v2CSWOCDMHdSA= +github.com/blevesearch/bleve_index_api v1.2.11 h1:bXQ54kVuwP8hdrXUSOnvTQfgK0KI1+f9A0ITJT8tX1s= +github.com/blevesearch/bleve_index_api v1.2.11/go.mod h1:rKQDl4u51uwafZxFrPD1R7xFOwKnzZW7s/LSeK4lgo0= +github.com/blevesearch/geo v0.2.4 h1:ECIGQhw+QALCZaDcogRTNSJYQXRtC8/m8IKiA706cqk= +github.com/blevesearch/geo v0.2.4/go.mod h1:K56Q33AzXt2YExVHGObtmRSFYZKYGv0JEN5mdacJJR8= +github.com/blevesearch/go-porterstemmer v1.0.3 h1:GtmsqID0aZdCSNiY8SkuPJ12pD4jI+DdXTAn4YRcHCo= +github.com/blevesearch/go-porterstemmer v1.0.3/go.mod h1:angGc5Ht+k2xhJdZi511LtmxuEf0OVpvUUNrwmM1P7M= +github.com/blevesearch/segment v0.9.1 h1:+dThDy+Lvgj5JMxhmOVlgFfkUtZV2kw49xax4+jTfSU= +github.com/blevesearch/segment v0.9.1/go.mod h1:zN21iLm7+GnBHWTao9I+Au/7MBiL8pPFtJBJTsk6kQw= +github.com/blevesearch/snowballstem v0.9.0 h1:lMQ189YspGP6sXvZQ4WZ+MLawfV8wOmPoD/iWeNXm8s= +github.com/blevesearch/snowballstem v0.9.0/go.mod h1:PivSj3JMc8WuaFkTSRDW2SlrulNWPl4ABg1tC/hlgLs= +github.com/blevesearch/stempel v0.2.0 h1:CYzVPaScODMvgE9o+kf6D4RJ/VRomyi9uHF+PtB+Afc= +github.com/blevesearch/stempel v0.2.0/go.mod h1:wjeTHqQv+nQdbPuJ/YcvOjTInA2EIc6Ks1FoSUzSLvc= +github.com/blevesearch/upsidedown_store_api v1.0.2 h1:U53Q6YoWEARVLd1OYNc9kvhBMGZzVrdmaozG2MfoB+A= +github.com/blevesearch/upsidedown_store_api v1.0.2/go.mod h1:M01mh3Gpfy56Ps/UXHjEO/knbqyQ1Oamg8If49gRwrQ= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= @@ -85,8 +103,6 @@ github.com/go-openapi/testify/enable/yaml/v2 v2.0.2 h1:0+Y41Pz1NkbTHz8NngxTuAXxE github.com/go-openapi/testify/enable/yaml/v2 v2.0.2/go.mod h1:kme83333GCtJQHXQ8UKX3IBZu6z8T5Dvy5+CW3NLUUg= github.com/go-openapi/testify/v2 v2.0.2 h1:X999g3jeLcoY8qctY/c/Z8iBHTbwLz7R2WXd6Ub6wls= github.com/go-openapi/testify/v2 v2.0.2/go.mod h1:HCPmvFFnheKK2BuwSA0TbbdxJ3I16pjwMkYkP4Ywn54= -github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1 h1:wG8n/XJQ07TmjbITcGiUaOtXxdrINDz1b0J1w0SzqDc= -github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1/go.mod h1:A2S0CWkNylc2phvKXWBBdD3K0iGnDBGbzRpISP2zBl8= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= @@ -115,12 +131,12 @@ github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q= github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= @@ -129,6 +145,8 @@ github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo= github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.18.3 h1:9PJRvfbmTabkOX8moIpXPbMMbYN60bWImDDU7L+/6zw= @@ -143,8 +161,6 @@ github.com/labstack/echo/v4 v4.15.0 h1:hoRTKWcnR5STXZFe9BmYun9AMTNeSbjHi2vtDuADJ github.com/labstack/echo/v4 v4.15.0/go.mod h1:xmw1clThob0BSVRX1CRQkGQ/vjwcpOMjQZSZa9fKA/c= github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0= github.com/labstack/gommon v0.4.2/go.mod h1:QlUFxVM+SNXhDL/Z7YhocGIBYOiwB0mXm1+1bAPHPyU= -github.com/larksuite/oapi-sdk-go/v3 v3.5.3 h1:xvf8Dv29kBXC5/DNDCLhHkAFW8l/0LlQJimO5Zn+JUk= -github.com/larksuite/oapi-sdk-go/v3 v3.5.3/go.mod h1:ZEplY+kwuIrj/nqw5uSCINNATcH3KdxSN7y+UxYY5fI= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= @@ -163,6 +179,12 @@ github.com/moby/sys/userns v0.1.0 h1:tVLXkFOxVu9A64/yh59slHVv9ahO9UIev4JZusOLG/g github.com/moby/sys/userns v0.1.0/go.mod h1:IHUYgu/kao6N8YZlp9Cf444ySSvCmDlmzUcYfDHOl28= github.com/modelcontextprotocol/go-sdk v1.2.0 h1:Y23co09300CEk8iZ/tMxIX1dVmKZkzoSBZOpJwUnc/s= github.com/modelcontextprotocol/go-sdk v1.2.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee h1:W5t00kpgFdJifH4BDsTlE89Zl93FEloxaWZfGcifgq8= +github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= @@ -267,6 +289,7 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= diff --git a/internal/handlers/chat.go b/internal/handlers/chat.go index 0dad3bc0..cb98ea1b 100644 --- a/internal/handlers/chat.go +++ b/internal/handlers/chat.go @@ -76,7 +76,7 @@ func (h *ChatHandler) Chat(c echo.Context) error { // @Accept json // @Produce text/event-stream // @Param request body chat.ChatRequest true "Chat request" -// @Success 200 {object} chat.StreamChunk +// @Success 200 {string} string // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Router /chat/stream [post] diff --git a/internal/handlers/memory.go b/internal/handlers/memory.go index acccfa72..4f8cea9d 100644 --- a/internal/handlers/memory.go +++ b/internal/handlers/memory.go @@ -18,22 +18,22 @@ type MemoryHandler struct { } type memoryAddPayload struct { - Message string `json:"message,omitempty"` - Messages []memory.Message `json:"messages,omitempty"` - AgentID string `json:"agent_id,omitempty"` - RunID string `json:"run_id,omitempty"` - Metadata map[string]interface{} `json:"metadata,omitempty"` - Filters map[string]interface{} `json:"filters,omitempty"` - Infer *bool `json:"infer,omitempty"` + Message string `json:"message,omitempty"` + Messages []memory.Message `json:"messages,omitempty"` + RunID string `json:"run_id,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` + Filters map[string]interface{} `json:"filters,omitempty"` + Infer *bool `json:"infer,omitempty"` + EmbeddingEnabled *bool `json:"embedding_enabled,omitempty"` } type memorySearchPayload struct { - Query string `json:"query"` - AgentID string `json:"agent_id,omitempty"` - RunID string `json:"run_id,omitempty"` - Limit int `json:"limit,omitempty"` - Filters map[string]interface{} `json:"filters,omitempty"` - Sources []string `json:"sources,omitempty"` + Query string `json:"query"` + RunID string `json:"run_id,omitempty"` + Limit int `json:"limit,omitempty"` + Filters map[string]interface{} `json:"filters,omitempty"` + Sources []string `json:"sources,omitempty"` + EmbeddingEnabled *bool `json:"embedding_enabled,omitempty"` } type memoryEmbedUpsertPayload struct { @@ -42,15 +42,13 @@ type memoryEmbedUpsertPayload struct { Model string `json:"model,omitempty"` Input memory.EmbedInput `json:"input"` Source string `json:"source,omitempty"` - AgentID string `json:"agent_id,omitempty"` RunID string `json:"run_id,omitempty"` Metadata map[string]interface{} `json:"metadata,omitempty"` Filters map[string]interface{} `json:"filters,omitempty"` } type memoryDeleteAllPayload struct { - AgentID string `json:"agent_id,omitempty"` - RunID string `json:"run_id,omitempty"` + RunID string `json:"run_id,omitempty"` } func NewMemoryHandler(log *slog.Logger, service *memory.Service) *MemoryHandler { @@ -74,7 +72,7 @@ func (h *MemoryHandler) Register(e *echo.Echo) { func (h *MemoryHandler) checkService() error { if h.service == nil { - return echo.NewHTTPError(http.StatusServiceUnavailable, "memory service not available: no embedding models configured") + return echo.NewHTTPError(http.StatusServiceUnavailable, "memory service not available") } return nil } @@ -109,7 +107,6 @@ func (h *MemoryHandler) EmbedUpsert(c echo.Context) error { Input: payload.Input, Source: payload.Source, UserID: userID, - AgentID: payload.AgentID, RunID: payload.RunID, Metadata: payload.Metadata, Filters: payload.Filters, @@ -146,14 +143,14 @@ func (h *MemoryHandler) Add(c echo.Context) error { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } req := memory.AddRequest{ - Message: payload.Message, - Messages: payload.Messages, - UserID: userID, - AgentID: payload.AgentID, - RunID: payload.RunID, - Metadata: payload.Metadata, - Filters: payload.Filters, - Infer: payload.Infer, + Message: payload.Message, + Messages: payload.Messages, + UserID: userID, + RunID: payload.RunID, + Metadata: payload.Metadata, + Filters: payload.Filters, + Infer: payload.Infer, + EmbeddingEnabled: payload.EmbeddingEnabled, } resp, err := h.service.Add(c.Request().Context(), req) @@ -187,13 +184,13 @@ func (h *MemoryHandler) Search(c echo.Context) error { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } req := memory.SearchRequest{ - Query: payload.Query, - UserID: userID, - AgentID: payload.AgentID, - RunID: payload.RunID, - Limit: payload.Limit, - Filters: payload.Filters, - Sources: payload.Sources, + Query: payload.Query, + UserID: userID, + RunID: payload.RunID, + Limit: payload.Limit, + Filters: payload.Filters, + Sources: payload.Sources, + EmbeddingEnabled: payload.EmbeddingEnabled, } resp, err := h.service.Search(c.Request().Context(), req) @@ -281,7 +278,6 @@ func (h *MemoryHandler) Get(c echo.Context) error { // @Summary List memories // @Description List memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id). // @Tags memory -// @Param agent_id query string false "Agent ID" // @Param run_id query string false "Run ID" // @Param limit query int false "Limit" // @Success 200 {object} memory.SearchResponse @@ -299,9 +295,8 @@ func (h *MemoryHandler) GetAll(c echo.Context) error { } req := memory.GetAllRequest{ - UserID: userID, - AgentID: c.QueryParam("agent_id"), - RunID: c.QueryParam("run_id"), + UserID: userID, + RunID: c.QueryParam("run_id"), } if limit := c.QueryParam("limit"); limit != "" { var parsed int @@ -380,9 +375,8 @@ func (h *MemoryHandler) DeleteAll(c echo.Context) error { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } req := memory.DeleteAllRequest{ - UserID: userID, - AgentID: payload.AgentID, - RunID: payload.RunID, + UserID: userID, + RunID: payload.RunID, } resp, err := h.service.DeleteAll(c.Request().Context(), req) diff --git a/internal/memory/indexer.go b/internal/memory/indexer.go new file mode 100644 index 00000000..86180a3f --- /dev/null +++ b/internal/memory/indexer.go @@ -0,0 +1,252 @@ +package memory + +import ( + "fmt" + "hash/fnv" + "log/slog" + "math" + "sort" + "strings" + "sync" + + "github.com/blevesearch/bleve/v2/registry" + + _ "github.com/blevesearch/bleve/v2/analysis/analyzer/standard" + _ "github.com/blevesearch/bleve/v2/analysis/lang/ar" + _ "github.com/blevesearch/bleve/v2/analysis/lang/bg" + _ "github.com/blevesearch/bleve/v2/analysis/lang/ca" + _ "github.com/blevesearch/bleve/v2/analysis/lang/cjk" + _ "github.com/blevesearch/bleve/v2/analysis/lang/ckb" + _ "github.com/blevesearch/bleve/v2/analysis/lang/da" + _ "github.com/blevesearch/bleve/v2/analysis/lang/de" + _ "github.com/blevesearch/bleve/v2/analysis/lang/el" + _ "github.com/blevesearch/bleve/v2/analysis/lang/en" + _ "github.com/blevesearch/bleve/v2/analysis/lang/es" + _ "github.com/blevesearch/bleve/v2/analysis/lang/eu" + _ "github.com/blevesearch/bleve/v2/analysis/lang/fa" + _ "github.com/blevesearch/bleve/v2/analysis/lang/fi" + _ "github.com/blevesearch/bleve/v2/analysis/lang/fr" + _ "github.com/blevesearch/bleve/v2/analysis/lang/ga" + _ "github.com/blevesearch/bleve/v2/analysis/lang/gl" + _ "github.com/blevesearch/bleve/v2/analysis/lang/hi" + _ "github.com/blevesearch/bleve/v2/analysis/lang/hr" + _ "github.com/blevesearch/bleve/v2/analysis/lang/hu" + _ "github.com/blevesearch/bleve/v2/analysis/lang/hy" + _ "github.com/blevesearch/bleve/v2/analysis/lang/id" + _ "github.com/blevesearch/bleve/v2/analysis/lang/it" + _ "github.com/blevesearch/bleve/v2/analysis/lang/nl" + _ "github.com/blevesearch/bleve/v2/analysis/lang/no" + _ "github.com/blevesearch/bleve/v2/analysis/lang/pl" + _ "github.com/blevesearch/bleve/v2/analysis/lang/pt" + _ "github.com/blevesearch/bleve/v2/analysis/lang/ro" + _ "github.com/blevesearch/bleve/v2/analysis/lang/ru" + _ "github.com/blevesearch/bleve/v2/analysis/lang/sv" + _ "github.com/blevesearch/bleve/v2/analysis/lang/tr" +) + +const ( + defaultBM25K1 = 1.2 + defaultBM25B = 0.75 + sparseDimBits = 20 + sparseDimSize = 1 << sparseDimBits + sparseDimMask = sparseDimSize - 1 +) + +type BM25Indexer struct { + cache *registry.Cache + logger *slog.Logger + k1 float64 + b float64 + + mu sync.RWMutex + stats map[string]*bm25Stats +} + +type bm25Stats struct { + DocCount int + AvgDocLen float64 + DocFreq map[string]int +} + +func NewBM25Indexer(log *slog.Logger) *BM25Indexer { + if log == nil { + log = slog.Default() + } + return &BM25Indexer{ + cache: registry.NewCache(), + logger: log.With(slog.String("indexer", "bm25")), + k1: defaultBM25K1, + b: defaultBM25B, + stats: map[string]*bm25Stats{}, + } +} + +func (b *BM25Indexer) TermFrequencies(lang, text string) (map[string]int, int, error) { + analyzerName, err := b.normalizeAnalyzer(lang) + if err != nil { + return nil, 0, err + } + analyzer, err := b.cache.AnalyzerNamed(analyzerName) + if err != nil { + return nil, 0, fmt.Errorf("bm25 analyzer %s: %w", analyzerName, err) + } + tokens := analyzer.Analyze([]byte(text)) + freq := map[string]int{} + docLen := 0 + for _, token := range tokens { + term := strings.TrimSpace(string(token.Term)) + if term == "" { + continue + } + freq[term]++ + docLen++ + } + return freq, docLen, nil +} + +func (b *BM25Indexer) AddDocument(lang string, termFreq map[string]int, docLen int) (indices []uint32, values []float32) { + b.mu.Lock() + stats := b.ensureStatsLocked(lang) + b.updateStatsAddLocked(stats, termFreq, docLen) + indices, values = b.buildDocVectorLocked(stats, termFreq, docLen) + b.mu.Unlock() + return indices, values +} + +func (b *BM25Indexer) RemoveDocument(lang string, termFreq map[string]int, docLen int) { + b.mu.Lock() + stats := b.ensureStatsLocked(lang) + b.updateStatsRemoveLocked(stats, termFreq, docLen) + b.mu.Unlock() +} + +func (b *BM25Indexer) BuildQueryVector(lang string, termFreq map[string]int) (indices []uint32, values []float32) { + b.mu.RLock() + stats := b.ensureStatsLocked(lang) + indices, values = b.buildQueryVectorLocked(stats, termFreq) + b.mu.RUnlock() + return indices, values +} + +func (b *BM25Indexer) normalizeAnalyzer(lang string) (string, error) { + normalized := strings.ToLower(strings.TrimSpace(lang)) + switch normalized { + case "": + return "standard", nil + case "in": + normalized = "id" + } + return normalized, nil +} + +func (b *BM25Indexer) ensureStatsLocked(lang string) *bm25Stats { + name, _ := b.normalizeAnalyzer(lang) + stats := b.stats[name] + if stats == nil { + stats = &bm25Stats{ + DocFreq: map[string]int{}, + } + b.stats[name] = stats + } + return stats +} + +func (b *BM25Indexer) updateStatsAddLocked(stats *bm25Stats, termFreq map[string]int, docLen int) { + totalDocs := stats.DocCount + stats.DocCount++ + totalLen := stats.AvgDocLen * float64(totalDocs) + stats.AvgDocLen = (totalLen + float64(docLen)) / float64(stats.DocCount) + for term := range termFreq { + stats.DocFreq[term]++ + } +} + +func (b *BM25Indexer) updateStatsRemoveLocked(stats *bm25Stats, termFreq map[string]int, docLen int) { + if stats.DocCount <= 0 { + return + } + totalDocs := stats.DocCount + totalLen := stats.AvgDocLen * float64(totalDocs) + stats.DocCount-- + if stats.DocCount > 0 { + stats.AvgDocLen = (totalLen - float64(docLen)) / float64(stats.DocCount) + } else { + stats.AvgDocLen = 0 + } + for term := range termFreq { + if stats.DocFreq[term] > 1 { + stats.DocFreq[term]-- + } else { + delete(stats.DocFreq, term) + } + } +} + +func (b *BM25Indexer) buildDocVectorLocked(stats *bm25Stats, termFreq map[string]int, docLen int) ([]uint32, []float32) { + if stats.DocCount == 0 || docLen == 0 { + return nil, nil + } + avgDocLen := stats.AvgDocLen + if avgDocLen <= 0 { + avgDocLen = 1 + } + weights := map[uint32]float32{} + for term, tf := range termFreq { + df := stats.DocFreq[term] + if df == 0 { + continue + } + idf := math.Log(1 + (float64(stats.DocCount)-float64(df)+0.5)/(float64(df)+0.5)) + numerator := float64(tf) * (b.k1 + 1) + denominator := float64(tf) + b.k1*(1-b.b+b.b*float64(docLen)/avgDocLen) + tfNorm := numerator / denominator + weight := float32(tfNorm * idf) + if weight == 0 { + continue + } + index := termHash(term) + weights[index] += weight + } + return sparseWeightsToVector(weights) +} + +func (b *BM25Indexer) buildQueryVectorLocked(stats *bm25Stats, termFreq map[string]int) ([]uint32, []float32) { + if stats.DocCount == 0 { + return nil, nil + } + weights := map[uint32]float32{} + for term, tf := range termFreq { + if stats.DocFreq[term] == 0 { + continue + } + weight := float32(tf) + if weight == 0 { + continue + } + index := termHash(term) + weights[index] += weight + } + return sparseWeightsToVector(weights) +} + +func sparseWeightsToVector(weights map[uint32]float32) ([]uint32, []float32) { + if len(weights) == 0 { + return nil, nil + } + indices := make([]uint32, 0, len(weights)) + for idx := range weights { + indices = append(indices, idx) + } + sort.Slice(indices, func(i, j int) bool { return indices[i] < indices[j] }) + values := make([]float32, 0, len(indices)) + for _, idx := range indices { + values = append(values, weights[idx]) + } + return indices, values +} + +func termHash(term string) uint32 { + hasher := fnv.New32a() + _, _ = hasher.Write([]byte(term)) + return hasher.Sum32() & sparseDimMask +} diff --git a/internal/memory/llm_client.go b/internal/memory/llm_client.go index f1f80fe6..8197898b 100644 --- a/internal/memory/llm_client.go +++ b/internal/memory/llm_client.go @@ -29,7 +29,7 @@ func NewLLMClient(log *slog.Logger, baseURL, apiKey, model string, timeout time. } baseURL = strings.TrimRight(baseURL, "/") if model == "" { - model = "gpt-4.1-nano-2025-04-14" + model = "gpt-4.1-nano" } if timeout <= 0 { timeout = 10 * time.Second @@ -119,6 +119,31 @@ func (c *LLMClient) Decide(ctx context.Context, req DecideRequest) (DecideRespon return DecideResponse{Actions: actions}, nil } +func (c *LLMClient) DetectLanguage(ctx context.Context, text string) (string, error) { + if strings.TrimSpace(text) == "" { + return "", fmt.Errorf("text is required") + } + systemPrompt, userPrompt := getLanguageDetectionMessages(text) + content, err := c.callChat(ctx, []chatMessage{ + {Role: "system", Content: systemPrompt}, + {Role: "user", Content: userPrompt}, + }) + if err != nil { + return "", err + } + var parsed struct { + Language string `json:"language"` + } + if err := json.Unmarshal([]byte(removeCodeBlocks(content)), &parsed); err != nil { + return "", err + } + lang := strings.ToLower(strings.TrimSpace(parsed.Language)) + if !isAllowedLanguageCode(lang) { + return "", fmt.Errorf("unsupported language code: %s", lang) + } + return lang, nil +} + type chatMessage struct { Role string `json:"role"` Content string `json:"content"` @@ -247,3 +272,14 @@ func normalizeMemoryItems(value interface{}) []map[string]interface{} { return nil } } + +func isAllowedLanguageCode(code string) bool { + switch strings.ToLower(strings.TrimSpace(code)) { + case "ar", "bg", "ca", "cjk", "ckb", "da", "de", "el", "en", "es", "eu", + "fa", "fi", "fr", "ga", "gl", "hi", "hr", "hu", "hy", "id", "in", + "it", "nl", "no", "pl", "pt", "ro", "ru", "sv", "tr": + return true + default: + return false + } +} diff --git a/internal/memory/prompts.go b/internal/memory/prompts.go index 610e4904..b312e673 100644 --- a/internal/memory/prompts.go +++ b/internal/memory/prompts.go @@ -106,6 +106,20 @@ Follow the instruction mentioned below: Do not return anything except the JSON format.`, toJSON(retrievedOldMemory), toJSON(newRetrievedFacts), "```json", "```") } +func getLanguageDetectionMessages(text string) (string, string) { + systemPrompt := `You are a language classifier for the given input text. +Return a JSON object with a single key "language" whose value is one of the allowed codes. +Allowed codes: ar, bg, ca, cjk, ckb, da, de, el, en, es, eu, fa, fi, fr, ga, gl, hi, hr, hu, hy, id, in, it, nl, no, pl, pt, ro, ru, sv, tr. +Use "cjk" for Chinese/Japanese/Korean text, ckb=Kurdish(Sorani), ga=Irish(Gaelic), gl=Galician, eu=Basque, hy=Armenian, fa=Persian, hr=Croatian, hu=Hungarian, ro=Romanian, bg=Bulgarian. If unsure between id/in, use id. +If multiple languages appear, choose the dominant language. +Do not include any extra keys, comments, or formatting. Output must be valid JSON only. +If the text is Chinese, Japanese, or Korean, output exactly {"language":"cjk"}. +Never output "zh", "zh-cn", "zh-tw", "ja", "ko", or any code not in the allowed list. +Before finalizing, verify the value is one of the allowed codes.` + userPrompt := fmt.Sprintf("Text:\n%s", text) + return systemPrompt, userPrompt +} + func parseMessages(messages []string) string { return strings.Join(messages, "\n") } diff --git a/internal/memory/qdrant_store.go b/internal/memory/qdrant_store.go index 9159d0bc..ab856ee2 100644 --- a/internal/memory/qdrant_store.go +++ b/internal/memory/qdrant_store.go @@ -12,34 +12,47 @@ import ( "github.com/qdrant/go-client/qdrant" ) +const ( + sparseHashVectorName = "sparse_hash" + sparseVocabVectorName = "sparse_vocab" +) + type QdrantStore struct { - client *qdrant.Client - collection string - dimension int - baseURL string - apiKey string - timeout time.Duration - logger *slog.Logger - vectorNames map[string]int - usesNamedVectors bool + client *qdrant.Client + collection string + dimension int + baseURL string + apiKey string + timeout time.Duration + logger *slog.Logger + vectorNames map[string]int + usesNamedVectors bool + sparseVectorName string + usesSparseVectors bool } type qdrantPoint struct { - ID string `json:"id"` - Vector []float32 `json:"vector"` - VectorName string `json:"vector_name,omitempty"` - Payload map[string]interface{} `json:"payload,omitempty"` + ID string `json:"id"` + Vector []float32 `json:"vector"` + VectorName string `json:"vector_name,omitempty"` + SparseIndices []uint32 `json:"sparse_indices,omitempty"` + SparseValues []float32 `json:"sparse_values,omitempty"` + SparseVectorName string `json:"sparse_vector_name,omitempty"` + Payload map[string]interface{} `json:"payload,omitempty"` } -func NewQdrantStore(log *slog.Logger, baseURL, apiKey, collection string, dimension int, timeout time.Duration) (*QdrantStore, error) { +func NewQdrantStore(log *slog.Logger, baseURL, apiKey, collection string, dimension int, sparseVectorName string, timeout time.Duration) (*QdrantStore, error) { host, port, useTLS, err := parseQdrantEndpoint(baseURL) if err != nil { return nil, err } + if strings.TrimSpace(sparseVectorName) == "" { + sparseVectorName = sparseHashVectorName + } if collection == "" { collection = "memory" } - if dimension <= 0 { + if dimension <= 0 && strings.TrimSpace(sparseVectorName) == "" { dimension = 1536 } @@ -55,13 +68,15 @@ func NewQdrantStore(log *slog.Logger, baseURL, apiKey, collection string, dimens } store := &QdrantStore{ - client: client, - collection: collection, - dimension: dimension, - baseURL: baseURL, - apiKey: apiKey, - timeout: timeoutOrDefault(timeout), - logger: log.With(slog.String("store", "qdrant")), + client: client, + collection: collection, + dimension: dimension, + baseURL: baseURL, + apiKey: apiKey, + timeout: timeoutOrDefault(timeout), + logger: log.With(slog.String("store", "qdrant")), + sparseVectorName: strings.TrimSpace(sparseVectorName), + usesSparseVectors: strings.TrimSpace(sparseVectorName) != "", } ctx, cancel := context.WithTimeout(context.Background(), timeoutOrDefault(timeout)) @@ -73,14 +88,17 @@ func NewQdrantStore(log *slog.Logger, baseURL, apiKey, collection string, dimens } func (s *QdrantStore) NewSibling(collection string, dimension int) (*QdrantStore, error) { - return NewQdrantStore(s.logger, s.baseURL, s.apiKey, collection, dimension, s.timeout) + return NewQdrantStore(s.logger, s.baseURL, s.apiKey, collection, dimension, s.sparseVectorName, s.timeout) } -func NewQdrantStoreWithVectors(log *slog.Logger, baseURL, apiKey, collection string, vectors map[string]int, timeout time.Duration) (*QdrantStore, error) { +func NewQdrantStoreWithVectors(log *slog.Logger, baseURL, apiKey, collection string, vectors map[string]int, sparseVectorName string, timeout time.Duration) (*QdrantStore, error) { host, port, useTLS, err := parseQdrantEndpoint(baseURL) if err != nil { return nil, err } + if strings.TrimSpace(sparseVectorName) == "" { + sparseVectorName = sparseHashVectorName + } if collection == "" { collection = "memory" } @@ -100,14 +118,16 @@ func NewQdrantStoreWithVectors(log *slog.Logger, baseURL, apiKey, collection str } store := &QdrantStore{ - client: client, - collection: collection, - baseURL: baseURL, - apiKey: apiKey, - timeout: timeoutOrDefault(timeout), - logger: log.With(slog.String("store", "qdrant")), - vectorNames: vectors, - usesNamedVectors: true, + client: client, + collection: collection, + baseURL: baseURL, + apiKey: apiKey, + timeout: timeoutOrDefault(timeout), + logger: log.With(slog.String("store", "qdrant")), + vectorNames: vectors, + usesNamedVectors: true, + sparseVectorName: strings.TrimSpace(sparseVectorName), + usesSparseVectors: strings.TrimSpace(sparseVectorName) != "", } ctx, cancel := context.WithTimeout(context.Background(), timeoutOrDefault(timeout)) @@ -129,12 +149,31 @@ func (s *QdrantStore) Upsert(ctx context.Context, points []qdrantPoint) error { return err } var vectors *qdrant.Vectors - if point.VectorName != "" && s.usesNamedVectors { - vectors = qdrant.NewVectorsMap(map[string]*qdrant.Vector{ - point.VectorName: qdrant.NewVectorDense(point.Vector), - }) - } else { - vectors = qdrant.NewVectorsDense(point.Vector) + vectorMap := map[string]*qdrant.Vector{} + if len(point.Vector) > 0 { + if point.VectorName != "" && s.usesNamedVectors { + vectorMap[point.VectorName] = qdrant.NewVectorDense(point.Vector) + } else if !s.usesNamedVectors && len(point.SparseIndices) == 0 { + vectors = qdrant.NewVectorsDense(point.Vector) + } else if point.VectorName != "" { + vectorMap[point.VectorName] = qdrant.NewVectorDense(point.Vector) + } + } + if len(point.SparseIndices) > 0 && len(point.SparseValues) > 0 { + sparseName := strings.TrimSpace(point.SparseVectorName) + if sparseName == "" { + sparseName = s.sparseVectorName + } + if sparseName == "" { + return fmt.Errorf("sparse vector name is required") + } + vectorMap[sparseName] = qdrant.NewVectorSparse(point.SparseIndices, point.SparseValues) + } + if vectors == nil { + if len(vectorMap) == 0 { + return fmt.Errorf("no vector data provided for point %s", point.ID) + } + vectors = qdrant.NewVectorsMap(vectorMap) } qPoints = append(qPoints, &qdrant.PointStruct{ Id: qdrant.NewIDUUID(point.ID), @@ -183,6 +222,41 @@ func (s *QdrantStore) Search(ctx context.Context, vector []float32, limit int, f return points, scores, nil } +func (s *QdrantStore) SearchSparse(ctx context.Context, indices []uint32, values []float32, limit int, filters map[string]interface{}) ([]qdrantPoint, []float64, error) { + if limit <= 0 { + limit = 10 + } + if len(indices) == 0 || len(values) == 0 { + return nil, nil, nil + } + if s.sparseVectorName == "" { + return nil, nil, fmt.Errorf("sparse vector name not configured") + } + filter := buildQdrantFilter(filters) + using := qdrant.PtrOf(s.sparseVectorName) + results, err := s.client.Query(ctx, &qdrant.QueryPoints{ + CollectionName: s.collection, + Query: qdrant.NewQuerySparse(indices, values), + Using: using, + Limit: qdrant.PtrOf(uint64(limit)), + Filter: filter, + WithPayload: qdrant.NewWithPayload(true), + }) + if err != nil { + return nil, nil, err + } + points := make([]qdrantPoint, 0, len(results)) + scores := make([]float64, 0, len(results)) + for _, scored := range results { + points = append(points, qdrantPoint{ + ID: pointIDToString(scored.GetId()), + Payload: valueMapToInterface(scored.GetPayload()), + }) + scores = append(scores, float64(scored.GetScore())) + } + return points, scores, nil +} + func (s *QdrantStore) SearchBySources(ctx context.Context, vector []float32, limit int, filters map[string]interface{}, sources []string, vectorName string) (map[string][]qdrantPoint, map[string][]float64, error) { pointsBySource := make(map[string][]qdrantPoint, len(sources)) scoresBySource := make(map[string][]float64, len(sources)) @@ -204,6 +278,27 @@ func (s *QdrantStore) SearchBySources(ctx context.Context, vector []float32, lim return pointsBySource, scoresBySource, nil } +func (s *QdrantStore) SearchSparseBySources(ctx context.Context, indices []uint32, values []float32, limit int, filters map[string]interface{}, sources []string) (map[string][]qdrantPoint, map[string][]float64, error) { + pointsBySource := make(map[string][]qdrantPoint, len(sources)) + scoresBySource := make(map[string][]float64, len(sources)) + if len(sources) == 0 { + return pointsBySource, scoresBySource, nil + } + for _, source := range sources { + merged := cloneFilters(filters) + if source != "" { + merged["source"] = source + } + points, scores, err := s.SearchSparse(ctx, indices, values, limit, merged) + if err != nil { + return nil, nil, err + } + pointsBySource[source] = points + scoresBySource[source] = scores + } + return pointsBySource, scoresBySource, nil +} + func (s *QdrantStore) Get(ctx context.Context, id string) (*qdrantPoint, error) { result, err := s.client.Get(ctx, &qdrant.GetPoints{ CollectionName: s.collection, @@ -257,6 +352,31 @@ func (s *QdrantStore) List(ctx context.Context, limit int, filters map[string]in return result, nil } +func (s *QdrantStore) Scroll(ctx context.Context, limit int, filters map[string]interface{}, offset *qdrant.PointId) ([]qdrantPoint, *qdrant.PointId, error) { + if limit <= 0 { + limit = 100 + } + filter := buildQdrantFilter(filters) + points, nextOffset, err := s.client.ScrollAndOffset(ctx, &qdrant.ScrollPoints{ + CollectionName: s.collection, + Limit: qdrant.PtrOf(uint32(limit)), + Filter: filter, + Offset: offset, + WithPayload: qdrant.NewWithPayload(true), + }) + if err != nil { + return nil, nil, err + } + result := make([]qdrantPoint, 0, len(points)) + for _, point := range points { + result = append(result, qdrantPoint{ + ID: pointIDToString(point.GetId()), + Payload: valueMapToInterface(point.GetPayload()), + }) + } + return result, nextOffset, nil +} + func (s *QdrantStore) DeleteAll(ctx context.Context, filters map[string]interface{}) error { filter := buildQdrantFilter(filters) if filter == nil { @@ -278,6 +398,7 @@ func (s *QdrantStore) ensureCollection(ctx context.Context, vectors map[string]i if exists { return s.refreshCollectionSchema(ctx, vectors) } + var vectorsConfig *qdrant.VectorsConfig if len(vectors) > 0 { params := make(map[string]*qdrant.VectorParams, len(vectors)) for name, dim := range vectors { @@ -286,17 +407,24 @@ func (s *QdrantStore) ensureCollection(ctx context.Context, vectors map[string]i Distance: qdrant.Distance_Cosine, } } - return s.client.CreateCollection(ctx, &qdrant.CreateCollection{ - CollectionName: s.collection, - VectorsConfig: qdrant.NewVectorsConfigMap(params), + vectorsConfig = qdrant.NewVectorsConfigMap(params) + } else if s.dimension > 0 { + vectorsConfig = qdrant.NewVectorsConfig(&qdrant.VectorParams{ + Size: uint64(s.dimension), + Distance: qdrant.Distance_Cosine, + }) + } + var sparseConfig *qdrant.SparseVectorConfig + if s.sparseVectorName != "" { + sparseConfig = qdrant.NewSparseVectorsConfig(map[string]*qdrant.SparseVectorParams{ + s.sparseVectorName: {Modifier: qdrant.PtrOf(qdrant.Modifier_None)}, + sparseVocabVectorName: {Modifier: qdrant.PtrOf(qdrant.Modifier_None)}, }) } return s.client.CreateCollection(ctx, &qdrant.CreateCollection{ - CollectionName: s.collection, - VectorsConfig: qdrant.NewVectorsConfig(&qdrant.VectorParams{ - Size: uint64(s.dimension), - Distance: qdrant.Distance_Cosine, - }), + CollectionName: s.collection, + VectorsConfig: vectorsConfig, + SparseVectorsConfig: sparseConfig, }) } @@ -306,11 +434,12 @@ func (s *QdrantStore) refreshCollectionSchema(ctx context.Context, vectors map[s return err } config := info.GetConfig() - if config == nil || config.GetParams() == nil || config.GetParams().GetVectorsConfig() == nil { + if config == nil || config.GetParams() == nil { return nil } - vectorsConfig := config.GetParams().GetVectorsConfig() - if vectorsConfig.GetParamsMap() != nil { + params := config.GetParams() + vectorsConfig := params.GetVectorsConfig() + if vectorsConfig != nil && vectorsConfig.GetParamsMap() != nil { s.usesNamedVectors = true s.vectorNames = map[string]int{} for name, vec := range vectorsConfig.GetParamsMap().GetMap() { @@ -319,7 +448,7 @@ func (s *QdrantStore) refreshCollectionSchema(ctx context.Context, vectors map[s } } if len(vectors) == 0 { - return nil + goto sparseCheck } for name, dim := range vectors { if existing, ok := s.vectorNames[name]; ok && existing == dim { @@ -327,13 +456,58 @@ func (s *QdrantStore) refreshCollectionSchema(ctx context.Context, vectors map[s } return fmt.Errorf("collection missing vector %s (dim %d); migration required", name, dim) } + } + if vectorsConfig == nil || vectorsConfig.GetParamsMap() == nil { + s.usesNamedVectors = false + s.vectorNames = nil + } + +sparseCheck: + sparseConfig := params.GetSparseVectorsConfig() + if s.sparseVectorName != "" { + needsUpdate := false + if sparseConfig == nil || len(sparseConfig.GetMap()) == 0 { + needsUpdate = true + } else { + if _, ok := sparseConfig.GetMap()[s.sparseVectorName]; !ok { + needsUpdate = true + } + if _, ok := sparseConfig.GetMap()[sparseVocabVectorName]; !ok { + needsUpdate = true + } + } + if needsUpdate { + if err := s.ensureSparseVectors(ctx); err != nil { + return err + } + } + s.usesSparseVectors = true return nil } - s.usesNamedVectors = false - s.vectorNames = nil + if sparseConfig != nil && len(sparseConfig.GetMap()) > 0 { + s.usesSparseVectors = true + for name := range sparseConfig.GetMap() { + s.sparseVectorName = name + break + } + } return nil } +func (s *QdrantStore) ensureSparseVectors(ctx context.Context) error { + if s.sparseVectorName == "" { + return nil + } + err := s.client.UpdateCollection(ctx, &qdrant.UpdateCollection{ + CollectionName: s.collection, + SparseVectorsConfig: qdrant.NewSparseVectorsConfig(map[string]*qdrant.SparseVectorParams{ + s.sparseVectorName: {Modifier: qdrant.PtrOf(qdrant.Modifier_None)}, + sparseVocabVectorName: {Modifier: qdrant.PtrOf(qdrant.Modifier_None)}, + }), + }) + return err +} + func parseQdrantEndpoint(endpoint string) (string, int, bool, error) { if endpoint == "" { return "127.0.0.1", 6334, false, nil diff --git a/internal/memory/service.go b/internal/memory/service.go index 155f0176..883aca43 100644 --- a/internal/memory/service.go +++ b/internal/memory/service.go @@ -12,6 +12,7 @@ import ( "time" "github.com/google/uuid" + "github.com/qdrant/go-client/qdrant" "github.com/memohai/memoh/internal/embeddings" ) @@ -21,17 +22,19 @@ type Service struct { embedder embeddings.Embedder store *QdrantStore resolver *embeddings.Resolver + bm25 *BM25Indexer logger *slog.Logger defaultTextModelID string defaultMultimodalModelID string } -func NewService(log *slog.Logger, llm LLM, embedder embeddings.Embedder, store *QdrantStore, resolver *embeddings.Resolver, defaultTextModelID, defaultMultimodalModelID string) *Service { +func NewService(log *slog.Logger, llm LLM, embedder embeddings.Embedder, store *QdrantStore, resolver *embeddings.Resolver, bm25 *BM25Indexer, defaultTextModelID, defaultMultimodalModelID string) *Service { return &Service{ llm: llm, embedder: embedder, store: store, resolver: resolver, + bm25: bm25, logger: log.With(slog.String("service", "memory")), defaultTextModelID: defaultTextModelID, defaultMultimodalModelID: defaultMultimodalModelID, @@ -42,15 +45,16 @@ func (s *Service) Add(ctx context.Context, req AddRequest) (SearchResponse, erro if req.Message == "" && len(req.Messages) == 0 { return SearchResponse{}, fmt.Errorf("message or messages is required") } - if req.UserID == "" && req.AgentID == "" && req.RunID == "" { - return SearchResponse{}, fmt.Errorf("user_id, agent_id or run_id is required") + if req.UserID == "" { + return SearchResponse{}, fmt.Errorf("user_id is required") } messages := normalizeMessages(req) filters := buildFilters(req) + embeddingEnabled := req.EmbeddingEnabled != nil && *req.EmbeddingEnabled if req.Infer != nil && !*req.Infer { - return s.addRawMessages(ctx, messages, filters, req.Metadata) + return s.addRawMessages(ctx, messages, filters, req.Metadata, embeddingEnabled) } extractResp, err := s.llm.Extract(ctx, ExtractRequest{ @@ -95,7 +99,7 @@ func (s *Service) Add(ctx context.Context, req AddRequest) (SearchResponse, erro for _, action := range actions { switch strings.ToUpper(action.Event) { case "ADD": - item, err := s.applyAdd(ctx, action.Text, filters, req.Metadata) + item, err := s.applyAdd(ctx, action.Text, filters, req.Metadata, embeddingEnabled) if err != nil { return SearchResponse{}, err } @@ -104,7 +108,7 @@ func (s *Service) Add(ctx context.Context, req AddRequest) (SearchResponse, erro }) results = append(results, item) case "UPDATE": - item, err := s.applyUpdate(ctx, action.ID, action.Text, filters, req.Metadata) + item, err := s.applyUpdate(ctx, action.ID, action.Text, filters, req.Metadata, embeddingEnabled) if err != nil { return SearchResponse{}, err } @@ -134,19 +138,19 @@ func (s *Service) Search(ctx context.Context, req SearchRequest) (SearchResponse if strings.TrimSpace(req.Query) == "" { return SearchResponse{}, fmt.Errorf("query is required") } + if s.store == nil { + return SearchResponse{}, fmt.Errorf("qdrant store not configured") + } filters := buildSearchFilters(req) modality := "" if raw, ok := filters["modality"].(string); ok { modality = strings.ToLower(strings.TrimSpace(raw)) } - - var ( - vector []float32 - store *QdrantStore - vectorName string - err error - ) + embeddingEnabled := req.EmbeddingEnabled != nil && *req.EmbeddingEnabled if modality == embeddings.TypeMultimodal { + if !embeddingEnabled { + return SearchResponse{}, fmt.Errorf("embedding is disabled") + } if s.resolver == nil { return SearchResponse{}, fmt.Errorf("embeddings resolver not configured") } @@ -159,24 +163,79 @@ func (s *Service) Search(ctx context.Context, req SearchRequest) (SearchResponse if err != nil { return SearchResponse{}, err } - vector = result.Embedding - store = s.store - vectorName = s.vectorNameForMultimodal() - } else { - vector, err = s.embedder.Embed(ctx, req.Query) + vectorName := s.vectorNameForMultimodal() + if len(req.Sources) == 0 { + points, scores, err := s.store.Search(ctx, result.Embedding, req.Limit, filters, vectorName) + if err != nil { + return SearchResponse{}, err + } + results := make([]MemoryItem, 0, len(points)) + for idx, point := range points { + item := payloadToMemoryItem(point.ID, point.Payload) + if idx < len(scores) { + item.Score = scores[idx] + } + results = append(results, item) + } + return SearchResponse{Results: results}, nil + } + pointsBySource, scoresBySource, err := s.store.SearchBySources(ctx, result.Embedding, req.Limit, filters, req.Sources, vectorName) if err != nil { return SearchResponse{}, err } - store = s.store - vectorName = s.vectorNameForText() + results := fuseByRankFusion(pointsBySource, scoresBySource) + return SearchResponse{Results: results}, nil } - if len(req.Sources) == 0 { - points, scores, err := store.Search(ctx, vector, req.Limit, filters, vectorName) + if embeddingEnabled { + if s.embedder == nil { + return SearchResponse{}, fmt.Errorf("embedder not configured") + } + vector, err := s.embedder.Embed(ctx, req.Query) if err != nil { return SearchResponse{}, err } + vectorName := s.vectorNameForText() + if len(req.Sources) == 0 { + points, scores, err := s.store.Search(ctx, vector, req.Limit, filters, vectorName) + if err != nil { + return SearchResponse{}, err + } + results := make([]MemoryItem, 0, len(points)) + for idx, point := range points { + item := payloadToMemoryItem(point.ID, point.Payload) + if idx < len(scores) { + item.Score = scores[idx] + } + results = append(results, item) + } + return SearchResponse{Results: results}, nil + } + pointsBySource, scoresBySource, err := s.store.SearchBySources(ctx, vector, req.Limit, filters, req.Sources, vectorName) + if err != nil { + return SearchResponse{}, err + } + results := fuseByRankFusion(pointsBySource, scoresBySource) + return SearchResponse{Results: results}, nil + } + if s.bm25 == nil { + return SearchResponse{}, fmt.Errorf("bm25 indexer not configured") + } + lang, err := s.detectLanguage(ctx, req.Query) + if err != nil { + return SearchResponse{}, err + } + termFreq, _, err := s.bm25.TermFrequencies(lang, req.Query) + if err != nil { + return SearchResponse{}, err + } + indices, values := s.bm25.BuildQueryVector(lang, termFreq) + if len(req.Sources) == 0 { + points, scores, err := s.store.SearchSparse(ctx, indices, values, req.Limit, filters) + if err != nil { + return SearchResponse{}, err + } results := make([]MemoryItem, 0, len(points)) for idx, point := range points { item := payloadToMemoryItem(point.ID, point.Payload) @@ -187,8 +246,7 @@ func (s *Service) Search(ctx context.Context, req SearchRequest) (SearchResponse } return SearchResponse{Results: results}, nil } - - pointsBySource, scoresBySource, err := store.SearchBySources(ctx, vector, req.Limit, filters, req.Sources, vectorName) + pointsBySource, scoresBySource, err := s.store.SearchSparseBySources(ctx, indices, values, req.Limit, filters, req.Sources) if err != nil { return SearchResponse{}, err } @@ -200,8 +258,8 @@ func (s *Service) EmbedUpsert(ctx context.Context, req EmbedUpsertRequest) (Embe if s.resolver == nil { return EmbedUpsertResponse{}, fmt.Errorf("embeddings resolver not configured") } - if req.UserID == "" && req.AgentID == "" && req.RunID == "" { - return EmbedUpsertResponse{}, fmt.Errorf("user_id, agent_id or run_id is required") + if req.UserID == "" { + return EmbedUpsertResponse{}, fmt.Errorf("user_id is required") } req.Type = strings.TrimSpace(req.Type) req.Provider = strings.TrimSpace(req.Provider) @@ -264,6 +322,12 @@ func (s *Service) Update(ctx context.Context, req UpdateRequest) (MemoryItem, er if strings.TrimSpace(req.Memory) == "" { return MemoryItem{}, fmt.Errorf("memory is required") } + if s.store == nil { + return MemoryItem{}, fmt.Errorf("qdrant store not configured") + } + if s.bm25 == nil { + return MemoryItem{}, fmt.Errorf("bm25 indexer not configured") + } existing, err := s.store.Get(ctx, req.MemoryID) if err != nil { @@ -272,22 +336,58 @@ func (s *Service) Update(ctx context.Context, req UpdateRequest) (MemoryItem, er if existing == nil { return MemoryItem{}, fmt.Errorf("memory not found") } + if s.bm25 == nil { + return MemoryItem{}, fmt.Errorf("bm25 indexer not configured") + } payload := existing.Payload - payload["data"] = req.Memory - payload["hash"] = hashMemory(req.Memory) - payload["updatedAt"] = time.Now().UTC().Format(time.RFC3339) + oldText := fmt.Sprint(payload["data"]) + oldLang := fmt.Sprint(payload["lang"]) + if oldLang == "" && strings.TrimSpace(oldText) != "" { + oldLang, _ = s.detectLanguage(ctx, oldText) + } + if strings.TrimSpace(oldText) != "" && strings.TrimSpace(oldLang) != "" { + oldFreq, oldLen, err := s.bm25.TermFrequencies(oldLang, oldText) + if err == nil { + s.bm25.RemoveDocument(oldLang, oldFreq, oldLen) + } + } - vector, err := s.embedder.Embed(ctx, req.Memory) + newLang, err := s.detectLanguage(ctx, req.Memory) if err != nil { return MemoryItem{}, err } - if err := s.store.Upsert(ctx, []qdrantPoint{{ - ID: req.MemoryID, - Vector: vector, - VectorName: s.vectorNameForText(), - Payload: payload, - }}); err != nil { + newFreq, newLen, err := s.bm25.TermFrequencies(newLang, req.Memory) + if err != nil { + return MemoryItem{}, err + } + sparseIndices, sparseValues := s.bm25.AddDocument(newLang, newFreq, newLen) + + payload["data"] = req.Memory + payload["hash"] = hashMemory(req.Memory) + payload["updatedAt"] = time.Now().UTC().Format(time.RFC3339) + payload["lang"] = newLang + + embeddingEnabled := req.EmbeddingEnabled != nil && *req.EmbeddingEnabled + point := qdrantPoint{ + ID: req.MemoryID, + SparseIndices: sparseIndices, + SparseValues: sparseValues, + SparseVectorName: s.store.sparseVectorName, + Payload: payload, + } + if embeddingEnabled { + if s.embedder == nil { + return MemoryItem{}, fmt.Errorf("embedder not configured") + } + vector, err := s.embedder.Embed(ctx, req.Memory) + if err != nil { + return MemoryItem{}, err + } + point.Vector = vector + point.VectorName = s.vectorNameForText() + } + if err := s.store.Upsert(ctx, []qdrantPoint{point}); err != nil { return MemoryItem{}, err } return payloadToMemoryItem(req.MemoryID, payload), nil @@ -312,14 +412,11 @@ func (s *Service) GetAll(ctx context.Context, req GetAllRequest) (SearchResponse if req.UserID != "" { filters["userId"] = req.UserID } - if req.AgentID != "" { - filters["agentId"] = req.AgentID - } if req.RunID != "" { filters["runId"] = req.RunID } if len(filters) == 0 { - return SearchResponse{}, fmt.Errorf("user_id, agent_id or run_id is required") + return SearchResponse{}, fmt.Errorf("user_id is required") } points, err := s.store.List(ctx, req.Limit, filters) @@ -348,14 +445,11 @@ func (s *Service) DeleteAll(ctx context.Context, req DeleteAllRequest) (DeleteRe if req.UserID != "" { filters["userId"] = req.UserID } - if req.AgentID != "" { - filters["agentId"] = req.AgentID - } if req.RunID != "" { filters["runId"] = req.RunID } if len(filters) == 0 { - return DeleteResponse{}, fmt.Errorf("user_id, agent_id or run_id is required") + return DeleteResponse{}, fmt.Errorf("user_id is required") } if err := s.store.DeleteAll(ctx, filters); err != nil { return DeleteResponse{}, err @@ -363,10 +457,46 @@ func (s *Service) DeleteAll(ctx context.Context, req DeleteAllRequest) (DeleteRe return DeleteResponse{Message: "Memories deleted successfully!"}, nil } -func (s *Service) addRawMessages(ctx context.Context, messages []Message, filters map[string]interface{}, metadata map[string]interface{}) (SearchResponse, error) { +func (s *Service) WarmupBM25(ctx context.Context, batchSize int) error { + if s.bm25 == nil || s.store == nil { + return nil + } + var offset *qdrant.PointId + for { + points, next, err := s.store.Scroll(ctx, batchSize, nil, offset) + if err != nil { + return err + } + if len(points) == 0 { + break + } + for _, point := range points { + text := fmt.Sprint(point.Payload["data"]) + if strings.TrimSpace(text) == "" { + continue + } + lang := fmt.Sprint(point.Payload["lang"]) + if lang == "" { + lang = fallbackLanguageCode(text) + } + termFreq, docLen, err := s.bm25.TermFrequencies(lang, text) + if err != nil { + continue + } + s.bm25.AddDocument(lang, termFreq, docLen) + } + if next == nil { + break + } + offset = next + } + return nil +} + +func (s *Service) addRawMessages(ctx context.Context, messages []Message, filters map[string]interface{}, metadata map[string]interface{}, embeddingEnabled bool) (SearchResponse, error) { results := make([]MemoryItem, 0, len(messages)) for _, message := range messages { - item, err := s.applyAdd(ctx, message.Content, filters, metadata) + item, err := s.applyAdd(ctx, message.Content, filters, metadata, embeddingEnabled) if err != nil { return SearchResponse{}, err } @@ -381,11 +511,19 @@ func (s *Service) addRawMessages(ctx context.Context, messages []Message, filter func (s *Service) collectCandidates(ctx context.Context, facts []string, filters map[string]interface{}) ([]CandidateMemory, error) { unique := map[string]CandidateMemory{} for _, fact := range facts { - vector, err := s.embedder.Embed(ctx, fact) + if s.bm25 == nil { + return nil, fmt.Errorf("bm25 indexer not configured") + } + lang, err := s.detectLanguage(ctx, fact) if err != nil { return nil, err } - points, _, err := s.store.Search(ctx, vector, 5, filters, s.vectorNameForText()) + termFreq, _, err := s.bm25.TermFrequencies(lang, fact) + if err != nil { + return nil, err + } + indices, values := s.bm25.BuildQueryVector(lang, termFreq) + points, _, err := s.store.SearchSparse(ctx, indices, values, 5, filters) if err != nil { return nil, err } @@ -406,25 +544,50 @@ func (s *Service) collectCandidates(ctx context.Context, facts []string, filters return candidates, nil } -func (s *Service) applyAdd(ctx context.Context, text string, filters map[string]interface{}, metadata map[string]interface{}) (MemoryItem, error) { - vector, err := s.embedder.Embed(ctx, text) +func (s *Service) applyAdd(ctx context.Context, text string, filters map[string]interface{}, metadata map[string]interface{}, embeddingEnabled bool) (MemoryItem, error) { + if s.store == nil { + return MemoryItem{}, fmt.Errorf("qdrant store not configured") + } + if s.bm25 == nil { + return MemoryItem{}, fmt.Errorf("bm25 indexer not configured") + } + lang, err := s.detectLanguage(ctx, text) if err != nil { return MemoryItem{}, err } + termFreq, docLen, err := s.bm25.TermFrequencies(lang, text) + if err != nil { + return MemoryItem{}, err + } + sparseIndices, sparseValues := s.bm25.AddDocument(lang, termFreq, docLen) id := uuid.NewString() payload := buildPayload(text, filters, metadata, "") - if err := s.store.Upsert(ctx, []qdrantPoint{{ - ID: id, - Vector: vector, - VectorName: s.vectorNameForText(), - Payload: payload, - }}); err != nil { + payload["lang"] = lang + point := qdrantPoint{ + ID: id, + SparseIndices: sparseIndices, + SparseValues: sparseValues, + SparseVectorName: s.store.sparseVectorName, + Payload: payload, + } + if embeddingEnabled { + if s.embedder == nil { + return MemoryItem{}, fmt.Errorf("embedder not configured") + } + vector, err := s.embedder.Embed(ctx, text) + if err != nil { + return MemoryItem{}, err + } + point.Vector = vector + point.VectorName = s.vectorNameForText() + } + if err := s.store.Upsert(ctx, []qdrantPoint{point}); err != nil { return MemoryItem{}, err } return payloadToMemoryItem(id, payload), nil } -func (s *Service) applyUpdate(ctx context.Context, id, text string, filters map[string]interface{}, metadata map[string]interface{}) (MemoryItem, error) { +func (s *Service) applyUpdate(ctx context.Context, id, text string, filters map[string]interface{}, metadata map[string]interface{}, embeddingEnabled bool) (MemoryItem, error) { if strings.TrimSpace(id) == "" { return MemoryItem{}, fmt.Errorf("update action missing id") } @@ -437,25 +600,55 @@ func (s *Service) applyUpdate(ctx context.Context, id, text string, filters map[ } payload := existing.Payload + oldText := fmt.Sprint(payload["data"]) + oldLang := fmt.Sprint(payload["lang"]) + if oldLang == "" && strings.TrimSpace(oldText) != "" { + oldLang, _ = s.detectLanguage(ctx, oldText) + } + if strings.TrimSpace(oldText) != "" && strings.TrimSpace(oldLang) != "" { + oldFreq, oldLen, err := s.bm25.TermFrequencies(oldLang, oldText) + if err == nil { + s.bm25.RemoveDocument(oldLang, oldFreq, oldLen) + } + } + newLang, err := s.detectLanguage(ctx, text) + if err != nil { + return MemoryItem{}, err + } + newFreq, newLen, err := s.bm25.TermFrequencies(newLang, text) + if err != nil { + return MemoryItem{}, err + } + sparseIndices, sparseValues := s.bm25.AddDocument(newLang, newFreq, newLen) payload["data"] = text payload["hash"] = hashMemory(text) payload["updatedAt"] = time.Now().UTC().Format(time.RFC3339) + payload["lang"] = newLang if metadata != nil { payload["metadata"] = mergeMetadata(payload["metadata"], metadata) } if filters != nil { applyFiltersToPayload(payload, filters) } - vector, err := s.embedder.Embed(ctx, text) - if err != nil { - return MemoryItem{}, err + point := qdrantPoint{ + ID: id, + SparseIndices: sparseIndices, + SparseValues: sparseValues, + SparseVectorName: s.store.sparseVectorName, + Payload: payload, } - if err := s.store.Upsert(ctx, []qdrantPoint{{ - ID: id, - Vector: vector, - VectorName: s.vectorNameForText(), - Payload: payload, - }}); err != nil { + if embeddingEnabled { + if s.embedder == nil { + return MemoryItem{}, fmt.Errorf("embedder not configured") + } + vector, err := s.embedder.Embed(ctx, text) + if err != nil { + return MemoryItem{}, err + } + point.Vector = vector + point.VectorName = s.vectorNameForText() + } + if err := s.store.Upsert(ctx, []qdrantPoint{point}); err != nil { return MemoryItem{}, err } return payloadToMemoryItem(id, payload), nil @@ -473,6 +666,19 @@ func (s *Service) applyDelete(ctx context.Context, id string) (MemoryItem, error return MemoryItem{}, fmt.Errorf("memory not found") } item := payloadToMemoryItem(id, existing.Payload) + if s.bm25 != nil { + oldText := fmt.Sprint(existing.Payload["data"]) + oldLang := fmt.Sprint(existing.Payload["lang"]) + if oldLang == "" && strings.TrimSpace(oldText) != "" { + oldLang, _ = s.detectLanguage(ctx, oldText) + } + if strings.TrimSpace(oldText) != "" && strings.TrimSpace(oldLang) != "" { + oldFreq, oldLen, err := s.bm25.TermFrequencies(oldLang, oldText) + if err == nil { + s.bm25.RemoveDocument(oldLang, oldFreq, oldLen) + } + } + } if err := s.store.Delete(ctx, id); err != nil { return MemoryItem{}, err } @@ -486,6 +692,56 @@ func normalizeMessages(req AddRequest) []Message { return []Message{{Role: "user", Content: req.Message}} } +func (s *Service) detectLanguage(ctx context.Context, text string) (string, error) { + if s.llm == nil { + return "", fmt.Errorf("language detector not configured") + } + lang, err := s.llm.DetectLanguage(ctx, text) + if err == nil && lang != "" { + return lang, nil + } + fallback := fallbackLanguageCode(text) + if s.logger != nil { + s.logger.Warn("language detection failed; using fallback", slog.Any("error", err), slog.String("fallback", fallback)) + } + return fallback, nil +} + +func fallbackLanguageCode(text string) string { + for _, r := range text { + if isCJKRune(r) { + return "cjk" + } + } + return "en" +} + +func isCJKRune(r rune) bool { + switch { + case r >= 0x4E00 && r <= 0x9FFF: // CJK Unified Ideographs + return true + case r >= 0x3400 && r <= 0x4DBF: // CJK Unified Ideographs Extension A + return true + case r >= 0x20000 && r <= 0x2A6DF: // CJK Unified Ideographs Extension B + return true + case r >= 0x2A700 && r <= 0x2B73F: // CJK Unified Ideographs Extension C + return true + case r >= 0x2B740 && r <= 0x2B81F: // CJK Unified Ideographs Extension D + return true + case r >= 0x2B820 && r <= 0x2CEAF: // CJK Unified Ideographs Extension E + return true + case r >= 0x2CEB0 && r <= 0x2EBEF: // CJK Unified Ideographs Extension F + return true + case r >= 0x3000 && r <= 0x303F: // CJK Symbols and Punctuation + return true + case r >= 0x3040 && r <= 0x30FF: // Hiragana/Katakana + return true + case r >= 0xAC00 && r <= 0xD7AF: // Hangul Syllables + return true + } + return false +} + func buildFilters(req AddRequest) map[string]interface{} { filters := map[string]interface{}{} for key, value := range req.Filters { @@ -494,9 +750,6 @@ func buildFilters(req AddRequest) map[string]interface{} { if req.UserID != "" { filters["userId"] = req.UserID } - if req.AgentID != "" { - filters["agentId"] = req.AgentID - } if req.RunID != "" { filters["runId"] = req.RunID } @@ -511,9 +764,6 @@ func buildSearchFilters(req SearchRequest) map[string]interface{} { if req.UserID != "" { filters["userId"] = req.UserID } - if req.AgentID != "" { - filters["agentId"] = req.AgentID - } if req.RunID != "" { filters["runId"] = req.RunID } @@ -528,9 +778,6 @@ func buildEmbedFilters(req EmbedUpsertRequest) map[string]interface{} { if req.UserID != "" { filters["userId"] = req.UserID } - if req.AgentID != "" { - filters["agentId"] = req.AgentID - } if req.RunID != "" { filters["runId"] = req.RunID } @@ -621,9 +868,6 @@ func payloadToMemoryItem(id string, payload map[string]interface{}) MemoryItem { if v, ok := payload["userId"].(string); ok { item.UserID = v } - if v, ok := payload["agentId"].(string); ok { - item.AgentID = v - } if v, ok := payload["runId"].(string); ok { item.RunID = v } diff --git a/internal/memory/types.go b/internal/memory/types.go index 8c6ecf69..d0867f85 100644 --- a/internal/memory/types.go +++ b/internal/memory/types.go @@ -6,6 +6,7 @@ import "context" type LLM interface { Extract(ctx context.Context, req ExtractRequest) (ExtractResponse, error) Decide(ctx context.Context, req DecideRequest) (DecideResponse, error) + DetectLanguage(ctx context.Context, text string) (string, error) } type Message struct { @@ -14,42 +15,41 @@ type Message struct { } type AddRequest struct { - Message string `json:"message,omitempty"` - Messages []Message `json:"messages,omitempty"` - UserID string `json:"user_id,omitempty"` - AgentID string `json:"agent_id,omitempty"` - RunID string `json:"run_id,omitempty"` - Metadata map[string]interface{} `json:"metadata,omitempty"` - Filters map[string]interface{} `json:"filters,omitempty"` - Infer *bool `json:"infer,omitempty"` + Message string `json:"message,omitempty"` + Messages []Message `json:"messages,omitempty"` + UserID string `json:"user_id,omitempty"` + RunID string `json:"run_id,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` + Filters map[string]interface{} `json:"filters,omitempty"` + Infer *bool `json:"infer,omitempty"` + EmbeddingEnabled *bool `json:"embedding_enabled,omitempty"` } type SearchRequest struct { - Query string `json:"query"` - UserID string `json:"user_id,omitempty"` - AgentID string `json:"agent_id,omitempty"` - RunID string `json:"run_id,omitempty"` - Limit int `json:"limit,omitempty"` - Filters map[string]interface{} `json:"filters,omitempty"` - Sources []string `json:"sources,omitempty"` + Query string `json:"query"` + UserID string `json:"user_id,omitempty"` + RunID string `json:"run_id,omitempty"` + Limit int `json:"limit,omitempty"` + Filters map[string]interface{} `json:"filters,omitempty"` + Sources []string `json:"sources,omitempty"` + EmbeddingEnabled *bool `json:"embedding_enabled,omitempty"` } type UpdateRequest struct { - MemoryID string `json:"memory_id"` - Memory string `json:"memory"` + MemoryID string `json:"memory_id"` + Memory string `json:"memory"` + EmbeddingEnabled *bool `json:"embedding_enabled,omitempty"` } type GetAllRequest struct { - UserID string `json:"user_id,omitempty"` - AgentID string `json:"agent_id,omitempty"` - RunID string `json:"run_id,omitempty"` - Limit int `json:"limit,omitempty"` + UserID string `json:"user_id,omitempty"` + RunID string `json:"run_id,omitempty"` + Limit int `json:"limit,omitempty"` } type DeleteAllRequest struct { - UserID string `json:"user_id,omitempty"` - AgentID string `json:"agent_id,omitempty"` - RunID string `json:"run_id,omitempty"` + UserID string `json:"user_id,omitempty"` + RunID string `json:"run_id,omitempty"` } type EmbedInput struct { @@ -65,7 +65,6 @@ type EmbedUpsertRequest struct { Input EmbedInput `json:"input"` Source string `json:"source,omitempty"` UserID string `json:"user_id,omitempty"` - AgentID string `json:"agent_id,omitempty"` RunID string `json:"run_id,omitempty"` Metadata map[string]interface{} `json:"metadata,omitempty"` Filters map[string]interface{} `json:"filters,omitempty"` @@ -87,7 +86,6 @@ type MemoryItem struct { Score float64 `json:"score,omitempty"` Metadata map[string]interface{} `json:"metadata,omitempty"` UserID string `json:"userId,omitempty"` - AgentID string `json:"agentId,omitempty"` RunID string `json:"runId,omitempty"` }