refactor: use sparse vector for memory

This commit is contained in:
Ran
2026-02-04 11:45:10 +08:00
parent ecebe3c711
commit efd68d306d
14 changed files with 1049 additions and 332 deletions
+73 -58
View File
@@ -107,69 +107,27 @@ func main() {
}
resolver := embeddings.NewResolver(logger.L, modelsService, queries, 10*time.Second)
vectors, textModel, multimodalModel, hasModels, err := embeddings.CollectEmbeddingVectors(ctx, modelsService)
vectors, textModel, multimodalModel, hasEmbeddingModels, err := embeddings.CollectEmbeddingVectors(ctx, modelsService)
if err != nil {
logger.Error("embedding models", slog.Any("error", err))
os.Exit(1)
}
var memoryService *memory.Service
var memoryHandler *handlers.MemoryHandler
if !hasModels {
logger.Warn("No embedding models configured. Memory service will not be available.")
logger.Warn("You can add embedding models via the /models API endpoint.")
memoryHandler = handlers.NewMemoryHandler(logger.L, nil)
} else {
if textModel.ModelID == "" {
logger.Warn("No text embedding model configured. Text embedding features will be limited.")
}
if multimodalModel.ModelID == "" {
logger.Warn("No multimodal embedding model configured. Multimodal embedding features will be limited.")
}
var textEmbedder embeddings.Embedder
var store *memory.QdrantStore
if textModel.ModelID != "" && textModel.Dimensions > 0 {
textEmbedder = &embeddings.ResolverTextEmbedder{
Resolver: resolver,
ModelID: textModel.ModelID,
Dims: textModel.Dimensions,
}
if len(vectors) > 0 {
store, err = memory.NewQdrantStoreWithVectors(
logger.L,
cfg.Qdrant.BaseURL,
cfg.Qdrant.APIKey,
cfg.Qdrant.Collection,
vectors,
time.Duration(cfg.Qdrant.TimeoutSeconds)*time.Second,
)
if err != nil {
logger.Error("qdrant named vectors init", slog.Any("error", err))
os.Exit(1)
}
} else {
store, err = memory.NewQdrantStore(
logger.L,
cfg.Qdrant.BaseURL,
cfg.Qdrant.APIKey,
cfg.Qdrant.Collection,
textModel.Dimensions,
time.Duration(cfg.Qdrant.TimeoutSeconds)*time.Second,
)
if err != nil {
logger.Error("qdrant init", slog.Any("error", err))
os.Exit(1)
}
}
}
memoryService = memory.NewService(logger.L, llmClient, textEmbedder, store, resolver, textModel.ModelID, multimodalModel.ModelID)
memoryHandler = handlers.NewMemoryHandler(logger.L, memoryService)
textEmbedder := buildTextEmbedder(resolver, textModel, hasEmbeddingModels, logger.L)
if hasEmbeddingModels && multimodalModel.ModelID == "" {
logger.Warn("No multimodal embedding model configured. Multimodal embedding features will be limited.")
}
store := buildQdrantStore(logger.L, cfg.Qdrant, vectors, hasEmbeddingModels, textModel.Dimensions)
bm25Indexer := memory.NewBM25Indexer(logger.L)
memoryService := memory.NewService(logger.L, llmClient, textEmbedder, store, resolver, bm25Indexer, textModel.ModelID, multimodalModel.ModelID)
memoryHandler := handlers.NewMemoryHandler(logger.L, memoryService)
go func() {
if err := memoryService.WarmupBM25(ctx, 200); err != nil {
logger.Warn("bm25 warmup failed", slog.Any("error", err))
}
}()
chatResolver = chat.NewResolver(logger.L, modelsService, queries, memoryService, cfg.AgentGateway.BaseURL(), 30*time.Second)
embeddingsHandler := handlers.NewEmbeddingsHandler(logger.L, modelsService, queries)
swaggerHandler := handlers.NewSwaggerHandler(logger.L)
@@ -197,7 +155,7 @@ func main() {
scheduleHandler := handlers.NewScheduleHandler(logger.L, scheduleService)
subagentService := subagent.NewService(logger.L, queries)
subagentHandler := handlers.NewSubagentHandler(logger.L, subagentService)
srv := server.NewServer(logger.L, addr, cfg.Auth.JWTSecret, pingHandler, authHandler, memoryHandler, embeddingsHandler, chatHandler, swaggerHandler, providersHandler, modelsHandler, settingsHandler, historyHandler, scheduleHandler, subagentHandler, containerdHandler, /*channelHandler*/)
srv := server.NewServer(logger.L, addr, cfg.Auth.JWTSecret, pingHandler, authHandler, memoryHandler, embeddingsHandler, chatHandler, swaggerHandler, providersHandler, modelsHandler, settingsHandler, historyHandler, scheduleHandler, subagentHandler, containerdHandler /*channelHandler*/)
if err := srv.Start(); err != nil {
logger.Error("server failed", slog.Any("error", err))
@@ -205,6 +163,55 @@ func main() {
}
}
// buildTextEmbedder returns a text Embedder backed by the resolver, or nil when
// no usable text embedding model is available. When models exist but the text
// model is missing or has non-positive dimensions, it logs a warning and
// returns nil so callers degrade gracefully.
func buildTextEmbedder(resolver *embeddings.Resolver, textModel models.GetResponse, hasModels bool, log *slog.Logger) embeddings.Embedder {
	switch {
	case !hasModels:
		return nil
	case textModel.ModelID == "" || textModel.Dimensions <= 0:
		log.Warn("No text embedding model configured. Text embedding features will be limited.")
		return nil
	}
	embedder := &embeddings.ResolverTextEmbedder{
		Resolver: resolver,
		ModelID:  textModel.ModelID,
		Dims:     textModel.Dimensions,
	}
	return embedder
}
// buildQdrantStore constructs the Qdrant-backed memory store. When embedding
// models are configured and named vectors were collected it creates a
// multi-vector store; otherwise it falls back to a single dense vector sized
// by textDims. Both paths use the "sparse_hash" sparse-vector name. On any
// initialization error the process exits (same policy as the rest of main).
func buildQdrantStore(log *slog.Logger, cfg config.QdrantConfig, vectors map[string]int, hasModels bool, textDims int) *memory.QdrantStore {
	timeout := time.Duration(cfg.TimeoutSeconds) * time.Second

	var (
		store  *memory.QdrantStore
		err    error
		errMsg string
	)
	if hasModels && len(vectors) > 0 {
		errMsg = "qdrant named vectors init"
		store, err = memory.NewQdrantStoreWithVectors(
			log,
			cfg.BaseURL,
			cfg.APIKey,
			cfg.Collection,
			vectors,
			"sparse_hash",
			timeout,
		)
	} else {
		errMsg = "qdrant init"
		store, err = memory.NewQdrantStore(
			log,
			cfg.BaseURL,
			cfg.APIKey,
			cfg.Collection,
			textDims,
			"sparse_hash",
			timeout,
		)
	}
	if err != nil {
		log.Error(errMsg, slog.Any("error", err))
		os.Exit(1)
	}
	return store
}
func ensureAdminUser(ctx context.Context, log *slog.Logger, queries *dbsqlc.Queries, cfg config.Config) error {
if queries == nil {
return fmt.Errorf("db queries not configured")
@@ -279,6 +286,14 @@ func (c *lazyLLMClient) Decide(ctx context.Context, req memory.DecideRequest) (m
return client.Decide(ctx, req)
}
// DetectLanguage lazily resolves the underlying LLM client and delegates
// language detection of text to it.
func (c *lazyLLMClient) DetectLanguage(ctx context.Context, text string) (string, error) {
	llm, err := c.resolve(ctx)
	if err != nil {
		return "", err
	}
	return llm.DetectLanguage(ctx, text)
}
func (c *lazyLLMClient) resolve(ctx context.Context) (memory.LLM, error) {
if c.modelsService == nil || c.queries == nil {
return nil, fmt.Errorf("models service not configured")
+8 -24
View File
@@ -135,11 +135,7 @@ const docTemplate = `{
"200": {
"description": "OK",
"schema": {
"type": "array",
"items": {
"type": "integer",
"format": "int32"
}
"type": "string"
}
},
"400": {
@@ -827,12 +823,6 @@ const docTemplate = `{
],
"summary": "List memories",
"parameters": [
{
"type": "string",
"description": "Agent ID",
"name": "agent_id",
"in": "query"
},
{
"type": "string",
"description": "Run ID",
@@ -3061,8 +3051,8 @@ const docTemplate = `{
"handlers.memoryAddPayload": {
"type": "object",
"properties": {
"agent_id": {
"type": "string"
"embedding_enabled": {
"type": "boolean"
},
"filters": {
"type": "object",
@@ -3092,9 +3082,6 @@ const docTemplate = `{
"handlers.memoryDeleteAllPayload": {
"type": "object",
"properties": {
"agent_id": {
"type": "string"
},
"run_id": {
"type": "string"
}
@@ -3103,9 +3090,6 @@ const docTemplate = `{
"handlers.memoryEmbedUpsertPayload": {
"type": "object",
"properties": {
"agent_id": {
"type": "string"
},
"filters": {
"type": "object",
"additionalProperties": true
@@ -3137,8 +3121,8 @@ const docTemplate = `{
"handlers.memorySearchPayload": {
"type": "object",
"properties": {
"agent_id": {
"type": "string"
"embedding_enabled": {
"type": "boolean"
},
"filters": {
"type": "object",
@@ -3267,9 +3251,6 @@ const docTemplate = `{
"memory.MemoryItem": {
"type": "object",
"properties": {
"agentId": {
"type": "string"
},
"createdAt": {
"type": "string"
},
@@ -3329,6 +3310,9 @@ const docTemplate = `{
"memory.UpdateRequest": {
"type": "object",
"properties": {
"embedding_enabled": {
"type": "boolean"
},
"memory": {
"type": "string"
},
+8 -24
View File
@@ -126,11 +126,7 @@
"200": {
"description": "OK",
"schema": {
"type": "array",
"items": {
"type": "integer",
"format": "int32"
}
"type": "string"
}
},
"400": {
@@ -818,12 +814,6 @@
],
"summary": "List memories",
"parameters": [
{
"type": "string",
"description": "Agent ID",
"name": "agent_id",
"in": "query"
},
{
"type": "string",
"description": "Run ID",
@@ -3052,8 +3042,8 @@
"handlers.memoryAddPayload": {
"type": "object",
"properties": {
"agent_id": {
"type": "string"
"embedding_enabled": {
"type": "boolean"
},
"filters": {
"type": "object",
@@ -3083,9 +3073,6 @@
"handlers.memoryDeleteAllPayload": {
"type": "object",
"properties": {
"agent_id": {
"type": "string"
},
"run_id": {
"type": "string"
}
@@ -3094,9 +3081,6 @@
"handlers.memoryEmbedUpsertPayload": {
"type": "object",
"properties": {
"agent_id": {
"type": "string"
},
"filters": {
"type": "object",
"additionalProperties": true
@@ -3128,8 +3112,8 @@
"handlers.memorySearchPayload": {
"type": "object",
"properties": {
"agent_id": {
"type": "string"
"embedding_enabled": {
"type": "boolean"
},
"filters": {
"type": "object",
@@ -3258,9 +3242,6 @@
"memory.MemoryItem": {
"type": "object",
"properties": {
"agentId": {
"type": "string"
},
"createdAt": {
"type": "string"
},
@@ -3320,6 +3301,9 @@
"memory.UpdateRequest": {
"type": "object",
"properties": {
"embedding_enabled": {
"type": "boolean"
},
"memory": {
"type": "string"
},
+7 -18
View File
@@ -269,8 +269,8 @@ definitions:
type: object
handlers.memoryAddPayload:
properties:
agent_id:
type: string
embedding_enabled:
type: boolean
filters:
additionalProperties: true
type: object
@@ -290,15 +290,11 @@ definitions:
type: object
handlers.memoryDeleteAllPayload:
properties:
agent_id:
type: string
run_id:
type: string
type: object
handlers.memoryEmbedUpsertPayload:
properties:
agent_id:
type: string
filters:
additionalProperties: true
type: object
@@ -320,8 +316,8 @@ definitions:
type: object
handlers.memorySearchPayload:
properties:
agent_id:
type: string
embedding_enabled:
type: boolean
filters:
additionalProperties: true
type: object
@@ -405,8 +401,6 @@ definitions:
type: object
memory.MemoryItem:
properties:
agentId:
type: string
createdAt:
type: string
hash:
@@ -446,6 +440,8 @@ definitions:
type: object
memory.UpdateRequest:
properties:
embedding_enabled:
type: boolean
memory:
type: string
memory_id:
@@ -872,10 +868,7 @@ paths:
"200":
description: OK
schema:
items:
format: int32
type: integer
type: array
type: string
"400":
description: Bad Request
schema:
@@ -1365,10 +1358,6 @@ paths:
description: 'List memories for a user via memory. Auth: Bearer JWT determines
user_id (sub or user_id).'
parameters:
- description: Agent ID
in: query
name: agent_id
type: string
- description: Run ID
in: query
name: run_id
+15 -5
View File
@@ -4,15 +4,19 @@ go 1.25.2
require (
github.com/BurntSushi/toml v1.6.0
github.com/blevesearch/bleve/v2 v2.5.7
github.com/containerd/containerd/api v1.10.0
github.com/containerd/containerd/v2 v2.2.1
github.com/containerd/errdefs v1.0.0
github.com/containerd/platforms v1.0.0-rc.2
github.com/golang-jwt/jwt/v5 v5.3.0
github.com/google/uuid v1.6.0
github.com/jackc/pgx/v5 v5.8.0
github.com/labstack/echo-jwt/v4 v4.4.0
github.com/labstack/echo/v4 v4.15.0
github.com/modelcontextprotocol/go-sdk v1.2.0
github.com/opencontainers/go-digest v1.0.0
github.com/opencontainers/image-spec v1.1.1
github.com/opencontainers/runtime-spec v1.3.0
github.com/qdrant/go-client v1.16.2
github.com/robfig/cron/v3 v3.0.1
@@ -25,13 +29,20 @@ require (
github.com/KyleBanks/depth v1.2.1 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/Microsoft/hcsshim v0.14.0-rc.1 // indirect
github.com/bits-and-blooms/bitset v1.22.0 // indirect
github.com/blevesearch/bleve_index_api v1.2.11 // indirect
github.com/blevesearch/geo v0.2.4 // indirect
github.com/blevesearch/go-porterstemmer v1.0.3 // indirect
github.com/blevesearch/segment v0.9.1 // indirect
github.com/blevesearch/snowballstem v0.9.0 // indirect
github.com/blevesearch/stempel v0.2.0 // indirect
github.com/blevesearch/upsidedown_store_api v1.0.2 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/containerd/cgroups/v3 v3.1.2 // indirect
github.com/containerd/continuity v0.4.5 // indirect
github.com/containerd/errdefs/pkg v0.3.0 // indirect
github.com/containerd/fifo v1.1.0 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/containerd/platforms v1.0.0-rc.2 // indirect
github.com/containerd/plugin v1.0.0 // indirect
github.com/containerd/ttrpc v1.2.7 // indirect
github.com/containerd/typeurl/v2 v2.2.3 // indirect
@@ -51,7 +62,6 @@ require (
github.com/go-openapi/swag/stringutils v0.25.4 // indirect
github.com/go-openapi/swag/typeutils v0.25.4 // indirect
github.com/go-openapi/swag/yamlutils v0.25.4 // indirect
github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect
github.com/google/go-cmp v0.7.0 // indirect
@@ -59,9 +69,9 @@ require (
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.18.3 // indirect
github.com/labstack/gommon v0.4.2 // indirect
github.com/larksuite/oapi-sdk-go/v3 v3.5.3 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/moby/locker v1.0.1 // indirect
@@ -70,8 +80,8 @@ require (
github.com/moby/sys/signal v0.7.1 // indirect
github.com/moby/sys/user v0.4.0 // indirect
github.com/moby/sys/userns v0.1.0 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.1 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee // indirect
github.com/opencontainers/selinux v1.13.1 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
+28 -5
View File
@@ -10,6 +10,24 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/Microsoft/hcsshim v0.14.0-rc.1 h1:qAPXKwGOkVn8LlqgBN8GS0bxZ83hOJpcjxzmlQKxKsQ=
github.com/Microsoft/hcsshim v0.14.0-rc.1/go.mod h1:hTKFGbnDtQb1wHiOWv4v0eN+7boSWAHyK/tNAaYZL0c=
github.com/bits-and-blooms/bitset v1.22.0 h1:Tquv9S8+SGaS3EhyA+up3FXzmkhxPGjQQCkcs2uw7w4=
github.com/bits-and-blooms/bitset v1.22.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8=
github.com/blevesearch/bleve/v2 v2.5.7 h1:2d9YrL5zrX5EBBW++GOaEKjE+NPWeZGaX77IM26m1Z8=
github.com/blevesearch/bleve/v2 v2.5.7/go.mod h1:yj0NlS7ocGC4VOSAedqDDMktdh2935v2CSWOCDMHdSA=
github.com/blevesearch/bleve_index_api v1.2.11 h1:bXQ54kVuwP8hdrXUSOnvTQfgK0KI1+f9A0ITJT8tX1s=
github.com/blevesearch/bleve_index_api v1.2.11/go.mod h1:rKQDl4u51uwafZxFrPD1R7xFOwKnzZW7s/LSeK4lgo0=
github.com/blevesearch/geo v0.2.4 h1:ECIGQhw+QALCZaDcogRTNSJYQXRtC8/m8IKiA706cqk=
github.com/blevesearch/geo v0.2.4/go.mod h1:K56Q33AzXt2YExVHGObtmRSFYZKYGv0JEN5mdacJJR8=
github.com/blevesearch/go-porterstemmer v1.0.3 h1:GtmsqID0aZdCSNiY8SkuPJ12pD4jI+DdXTAn4YRcHCo=
github.com/blevesearch/go-porterstemmer v1.0.3/go.mod h1:angGc5Ht+k2xhJdZi511LtmxuEf0OVpvUUNrwmM1P7M=
github.com/blevesearch/segment v0.9.1 h1:+dThDy+Lvgj5JMxhmOVlgFfkUtZV2kw49xax4+jTfSU=
github.com/blevesearch/segment v0.9.1/go.mod h1:zN21iLm7+GnBHWTao9I+Au/7MBiL8pPFtJBJTsk6kQw=
github.com/blevesearch/snowballstem v0.9.0 h1:lMQ189YspGP6sXvZQ4WZ+MLawfV8wOmPoD/iWeNXm8s=
github.com/blevesearch/snowballstem v0.9.0/go.mod h1:PivSj3JMc8WuaFkTSRDW2SlrulNWPl4ABg1tC/hlgLs=
github.com/blevesearch/stempel v0.2.0 h1:CYzVPaScODMvgE9o+kf6D4RJ/VRomyi9uHF+PtB+Afc=
github.com/blevesearch/stempel v0.2.0/go.mod h1:wjeTHqQv+nQdbPuJ/YcvOjTInA2EIc6Ks1FoSUzSLvc=
github.com/blevesearch/upsidedown_store_api v1.0.2 h1:U53Q6YoWEARVLd1OYNc9kvhBMGZzVrdmaozG2MfoB+A=
github.com/blevesearch/upsidedown_store_api v1.0.2/go.mod h1:M01mh3Gpfy56Ps/UXHjEO/knbqyQ1Oamg8If49gRwrQ=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
@@ -85,8 +103,6 @@ github.com/go-openapi/testify/enable/yaml/v2 v2.0.2 h1:0+Y41Pz1NkbTHz8NngxTuAXxE
github.com/go-openapi/testify/enable/yaml/v2 v2.0.2/go.mod h1:kme83333GCtJQHXQ8UKX3IBZu6z8T5Dvy5+CW3NLUUg=
github.com/go-openapi/testify/v2 v2.0.2 h1:X999g3jeLcoY8qctY/c/Z8iBHTbwLz7R2WXd6Ub6wls=
github.com/go-openapi/testify/v2 v2.0.2/go.mod h1:HCPmvFFnheKK2BuwSA0TbbdxJ3I16pjwMkYkP4Ywn54=
github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1 h1:wG8n/XJQ07TmjbITcGiUaOtXxdrINDz1b0J1w0SzqDc=
github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1/go.mod h1:A2S0CWkNylc2phvKXWBBdD3K0iGnDBGbzRpISP2zBl8=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
@@ -115,12 +131,12 @@ github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q=
github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
@@ -129,6 +145,8 @@ github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo=
github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.18.3 h1:9PJRvfbmTabkOX8moIpXPbMMbYN60bWImDDU7L+/6zw=
@@ -143,8 +161,6 @@ github.com/labstack/echo/v4 v4.15.0 h1:hoRTKWcnR5STXZFe9BmYun9AMTNeSbjHi2vtDuADJ
github.com/labstack/echo/v4 v4.15.0/go.mod h1:xmw1clThob0BSVRX1CRQkGQ/vjwcpOMjQZSZa9fKA/c=
github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0=
github.com/labstack/gommon v0.4.2/go.mod h1:QlUFxVM+SNXhDL/Z7YhocGIBYOiwB0mXm1+1bAPHPyU=
github.com/larksuite/oapi-sdk-go/v3 v3.5.3 h1:xvf8Dv29kBXC5/DNDCLhHkAFW8l/0LlQJimO5Zn+JUk=
github.com/larksuite/oapi-sdk-go/v3 v3.5.3/go.mod h1:ZEplY+kwuIrj/nqw5uSCINNATcH3KdxSN7y+UxYY5fI=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
@@ -163,6 +179,12 @@ github.com/moby/sys/userns v0.1.0 h1:tVLXkFOxVu9A64/yh59slHVv9ahO9UIev4JZusOLG/g
github.com/moby/sys/userns v0.1.0/go.mod h1:IHUYgu/kao6N8YZlp9Cf444ySSvCmDlmzUcYfDHOl28=
github.com/modelcontextprotocol/go-sdk v1.2.0 h1:Y23co09300CEk8iZ/tMxIX1dVmKZkzoSBZOpJwUnc/s=
github.com/modelcontextprotocol/go-sdk v1.2.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee h1:W5t00kpgFdJifH4BDsTlE89Zl93FEloxaWZfGcifgq8=
github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
@@ -267,6 +289,7 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
+1 -1
View File
@@ -76,7 +76,7 @@ func (h *ChatHandler) Chat(c echo.Context) error {
// @Accept json
// @Produce text/event-stream
// @Param request body chat.ChatRequest true "Chat request"
// @Success 200 {object} chat.StreamChunk
// @Success 200 {string} string
// @Failure 400 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Router /chat/stream [post]
+34 -40
View File
@@ -18,22 +18,22 @@ type MemoryHandler struct {
}
type memoryAddPayload struct {
Message string `json:"message,omitempty"`
Messages []memory.Message `json:"messages,omitempty"`
AgentID string `json:"agent_id,omitempty"`
RunID string `json:"run_id,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
Filters map[string]interface{} `json:"filters,omitempty"`
Infer *bool `json:"infer,omitempty"`
Message string `json:"message,omitempty"`
Messages []memory.Message `json:"messages,omitempty"`
RunID string `json:"run_id,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
Filters map[string]interface{} `json:"filters,omitempty"`
Infer *bool `json:"infer,omitempty"`
EmbeddingEnabled *bool `json:"embedding_enabled,omitempty"`
}
type memorySearchPayload struct {
Query string `json:"query"`
AgentID string `json:"agent_id,omitempty"`
RunID string `json:"run_id,omitempty"`
Limit int `json:"limit,omitempty"`
Filters map[string]interface{} `json:"filters,omitempty"`
Sources []string `json:"sources,omitempty"`
Query string `json:"query"`
RunID string `json:"run_id,omitempty"`
Limit int `json:"limit,omitempty"`
Filters map[string]interface{} `json:"filters,omitempty"`
Sources []string `json:"sources,omitempty"`
EmbeddingEnabled *bool `json:"embedding_enabled,omitempty"`
}
type memoryEmbedUpsertPayload struct {
@@ -42,15 +42,13 @@ type memoryEmbedUpsertPayload struct {
Model string `json:"model,omitempty"`
Input memory.EmbedInput `json:"input"`
Source string `json:"source,omitempty"`
AgentID string `json:"agent_id,omitempty"`
RunID string `json:"run_id,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
Filters map[string]interface{} `json:"filters,omitempty"`
}
type memoryDeleteAllPayload struct {
AgentID string `json:"agent_id,omitempty"`
RunID string `json:"run_id,omitempty"`
RunID string `json:"run_id,omitempty"`
}
func NewMemoryHandler(log *slog.Logger, service *memory.Service) *MemoryHandler {
@@ -74,7 +72,7 @@ func (h *MemoryHandler) Register(e *echo.Echo) {
func (h *MemoryHandler) checkService() error {
if h.service == nil {
return echo.NewHTTPError(http.StatusServiceUnavailable, "memory service not available: no embedding models configured")
return echo.NewHTTPError(http.StatusServiceUnavailable, "memory service not available")
}
return nil
}
@@ -109,7 +107,6 @@ func (h *MemoryHandler) EmbedUpsert(c echo.Context) error {
Input: payload.Input,
Source: payload.Source,
UserID: userID,
AgentID: payload.AgentID,
RunID: payload.RunID,
Metadata: payload.Metadata,
Filters: payload.Filters,
@@ -146,14 +143,14 @@ func (h *MemoryHandler) Add(c echo.Context) error {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
req := memory.AddRequest{
Message: payload.Message,
Messages: payload.Messages,
UserID: userID,
AgentID: payload.AgentID,
RunID: payload.RunID,
Metadata: payload.Metadata,
Filters: payload.Filters,
Infer: payload.Infer,
Message: payload.Message,
Messages: payload.Messages,
UserID: userID,
RunID: payload.RunID,
Metadata: payload.Metadata,
Filters: payload.Filters,
Infer: payload.Infer,
EmbeddingEnabled: payload.EmbeddingEnabled,
}
resp, err := h.service.Add(c.Request().Context(), req)
@@ -187,13 +184,13 @@ func (h *MemoryHandler) Search(c echo.Context) error {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
req := memory.SearchRequest{
Query: payload.Query,
UserID: userID,
AgentID: payload.AgentID,
RunID: payload.RunID,
Limit: payload.Limit,
Filters: payload.Filters,
Sources: payload.Sources,
Query: payload.Query,
UserID: userID,
RunID: payload.RunID,
Limit: payload.Limit,
Filters: payload.Filters,
Sources: payload.Sources,
EmbeddingEnabled: payload.EmbeddingEnabled,
}
resp, err := h.service.Search(c.Request().Context(), req)
@@ -281,7 +278,6 @@ func (h *MemoryHandler) Get(c echo.Context) error {
// @Summary List memories
// @Description List memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id).
// @Tags memory
// @Param agent_id query string false "Agent ID"
// @Param run_id query string false "Run ID"
// @Param limit query int false "Limit"
// @Success 200 {object} memory.SearchResponse
@@ -299,9 +295,8 @@ func (h *MemoryHandler) GetAll(c echo.Context) error {
}
req := memory.GetAllRequest{
UserID: userID,
AgentID: c.QueryParam("agent_id"),
RunID: c.QueryParam("run_id"),
UserID: userID,
RunID: c.QueryParam("run_id"),
}
if limit := c.QueryParam("limit"); limit != "" {
var parsed int
@@ -380,9 +375,8 @@ func (h *MemoryHandler) DeleteAll(c echo.Context) error {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
req := memory.DeleteAllRequest{
UserID: userID,
AgentID: payload.AgentID,
RunID: payload.RunID,
UserID: userID,
RunID: payload.RunID,
}
resp, err := h.service.DeleteAll(c.Request().Context(), req)
+252
View File
@@ -0,0 +1,252 @@
package memory
import (
"fmt"
"hash/fnv"
"log/slog"
"math"
"sort"
"strings"
"sync"
"github.com/blevesearch/bleve/v2/registry"
_ "github.com/blevesearch/bleve/v2/analysis/analyzer/standard"
_ "github.com/blevesearch/bleve/v2/analysis/lang/ar"
_ "github.com/blevesearch/bleve/v2/analysis/lang/bg"
_ "github.com/blevesearch/bleve/v2/analysis/lang/ca"
_ "github.com/blevesearch/bleve/v2/analysis/lang/cjk"
_ "github.com/blevesearch/bleve/v2/analysis/lang/ckb"
_ "github.com/blevesearch/bleve/v2/analysis/lang/da"
_ "github.com/blevesearch/bleve/v2/analysis/lang/de"
_ "github.com/blevesearch/bleve/v2/analysis/lang/el"
_ "github.com/blevesearch/bleve/v2/analysis/lang/en"
_ "github.com/blevesearch/bleve/v2/analysis/lang/es"
_ "github.com/blevesearch/bleve/v2/analysis/lang/eu"
_ "github.com/blevesearch/bleve/v2/analysis/lang/fa"
_ "github.com/blevesearch/bleve/v2/analysis/lang/fi"
_ "github.com/blevesearch/bleve/v2/analysis/lang/fr"
_ "github.com/blevesearch/bleve/v2/analysis/lang/ga"
_ "github.com/blevesearch/bleve/v2/analysis/lang/gl"
_ "github.com/blevesearch/bleve/v2/analysis/lang/hi"
_ "github.com/blevesearch/bleve/v2/analysis/lang/hr"
_ "github.com/blevesearch/bleve/v2/analysis/lang/hu"
_ "github.com/blevesearch/bleve/v2/analysis/lang/hy"
_ "github.com/blevesearch/bleve/v2/analysis/lang/id"
_ "github.com/blevesearch/bleve/v2/analysis/lang/it"
_ "github.com/blevesearch/bleve/v2/analysis/lang/nl"
_ "github.com/blevesearch/bleve/v2/analysis/lang/no"
_ "github.com/blevesearch/bleve/v2/analysis/lang/pl"
_ "github.com/blevesearch/bleve/v2/analysis/lang/pt"
_ "github.com/blevesearch/bleve/v2/analysis/lang/ro"
_ "github.com/blevesearch/bleve/v2/analysis/lang/ru"
_ "github.com/blevesearch/bleve/v2/analysis/lang/sv"
_ "github.com/blevesearch/bleve/v2/analysis/lang/tr"
)
const (
	// defaultBM25K1 and defaultBM25B are the conventional BM25 free
	// parameters: k1 controls term-frequency saturation, b controls
	// document-length normalization.
	defaultBM25K1 = 1.2
	defaultBM25B  = 0.75

	// sparseDimBits fixes the sparse vector index space at 2^20 dimensions;
	// termHash masks FNV-1a hashes with sparseDimMask to stay within it.
	sparseDimBits = 20
	sparseDimSize = 1 << sparseDimBits
	sparseDimMask = sparseDimSize - 1
)
// BM25Indexer maintains per-language corpus statistics and produces sparse
// BM25 vectors for documents and queries. Access to stats is guarded by mu.
type BM25Indexer struct {
	cache  *registry.Cache // bleve analyzer cache, keyed by analyzer name
	logger *slog.Logger
	k1     float64 // BM25 k1 parameter (term-frequency saturation)
	b      float64 // BM25 b parameter (length normalization)

	mu    sync.RWMutex
	stats map[string]*bm25Stats // per-analyzer corpus statistics, keyed by analyzer name
}
// bm25Stats holds running corpus statistics for a single analyzer/language.
type bm25Stats struct {
	DocCount  int            // number of documents indexed
	AvgDocLen float64        // running average document length, in tokens
	DocFreq   map[string]int // per-term document frequency (documents containing the term)
}
// NewBM25Indexer constructs a BM25Indexer with the default k1/b parameters
// and an empty per-language statistics map. A nil log falls back to
// slog.Default().
func NewBM25Indexer(log *slog.Logger) *BM25Indexer {
	logger := log
	if logger == nil {
		logger = slog.Default()
	}
	indexer := &BM25Indexer{
		cache:  registry.NewCache(),
		logger: logger.With(slog.String("indexer", "bm25")),
		k1:     defaultBM25K1,
		b:      defaultBM25B,
		stats:  make(map[string]*bm25Stats),
	}
	return indexer
}
// TermFrequencies runs text through the bleve analyzer selected for lang and
// returns the per-term token counts plus the total token count (the document
// length). Empty terms after trimming are skipped. Returns an error when the
// analyzer cannot be resolved.
func (b *BM25Indexer) TermFrequencies(lang, text string) (map[string]int, int, error) {
	name, err := b.normalizeAnalyzer(lang)
	if err != nil {
		return nil, 0, err
	}
	analyzer, err := b.cache.AnalyzerNamed(name)
	if err != nil {
		return nil, 0, fmt.Errorf("bm25 analyzer %s: %w", name, err)
	}
	freq := make(map[string]int)
	total := 0
	for _, token := range analyzer.Analyze([]byte(text)) {
		term := strings.TrimSpace(string(token.Term))
		if term == "" {
			continue
		}
		freq[term]++
		total++
	}
	return freq, total, nil
}
// AddDocument folds one document into the per-language corpus statistics and
// returns its sparse BM25 vector as parallel index/value slices.
func (b *BM25Indexer) AddDocument(lang string, termFreq map[string]int, docLen int) (indices []uint32, values []float32) {
	b.mu.Lock()
	defer b.mu.Unlock()
	stats := b.ensureStatsLocked(lang)
	b.updateStatsAddLocked(stats, termFreq, docLen)
	return b.buildDocVectorLocked(stats, termFreq, docLen)
}
// RemoveDocument subtracts one document's contribution from the per-language
// corpus statistics.
func (b *BM25Indexer) RemoveDocument(lang string, termFreq map[string]int, docLen int) {
	b.mu.Lock()
	defer b.mu.Unlock()
	b.updateStatsRemoveLocked(b.ensureStatsLocked(lang), termFreq, docLen)
}
// BuildQueryVector turns a query's term frequencies into a sparse vector for
// the given language. Returns nil slices when no documents have been indexed
// for that language.
//
// BUG FIX: the previous implementation called ensureStatsLocked while holding
// only the read lock; ensureStatsLocked may INSERT into b.stats when the
// language bucket is missing, which is a concurrent map write under RLock.
// We now do a read-only lookup and treat a missing bucket as an empty corpus
// (behaviorally identical: a fresh bucket has DocCount == 0, for which
// buildQueryVectorLocked returns nil anyway).
func (b *BM25Indexer) BuildQueryVector(lang string, termFreq map[string]int) (indices []uint32, values []float32) {
	name, _ := b.normalizeAnalyzer(lang)
	b.mu.RLock()
	defer b.mu.RUnlock()
	stats := b.stats[name]
	if stats == nil {
		return nil, nil
	}
	return b.buildQueryVectorLocked(stats, termFreq)
}
// normalizeAnalyzer maps a language code onto a bleve analyzer name. An empty
// code selects the "standard" analyzer; the legacy "in" code (old ISO 639-1
// code for Indonesian) is aliased to "id". The error is always nil today; the
// signature leaves room for stricter validation later.
func (b *BM25Indexer) normalizeAnalyzer(lang string) (string, error) {
	name := strings.ToLower(strings.TrimSpace(lang))
	if name == "" {
		return "standard", nil
	}
	if name == "in" {
		name = "id"
	}
	return name, nil
}
// ensureStatsLocked returns the stats bucket for lang, creating an empty one
// if absent. The caller must hold b.mu for WRITING, since this may insert
// into b.stats.
func (b *BM25Indexer) ensureStatsLocked(lang string) *bm25Stats {
	name, _ := b.normalizeAnalyzer(lang)
	if existing := b.stats[name]; existing != nil {
		return existing
	}
	created := &bm25Stats{DocFreq: map[string]int{}}
	b.stats[name] = created
	return created
}
// updateStatsAddLocked folds one document into stats: bumps the document
// count, recomputes the running average document length, and increments the
// document frequency of every distinct term. Caller holds b.mu.
func (b *BM25Indexer) updateStatsAddLocked(stats *bm25Stats, termFreq map[string]int, docLen int) {
	// Recover the previous total length before bumping the count, then fold
	// in the new document's length.
	prevTotalLen := stats.AvgDocLen * float64(stats.DocCount)
	stats.DocCount++
	stats.AvgDocLen = (prevTotalLen + float64(docLen)) / float64(stats.DocCount)
	for term := range termFreq {
		stats.DocFreq[term]++
	}
}
// updateStatsRemoveLocked subtracts one document from stats: decrements the
// document count, rebalances the running average document length, and drops
// or decrements per-term document frequencies. A remove on an empty corpus is
// a no-op. Caller holds b.mu.
func (b *BM25Indexer) updateStatsRemoveLocked(stats *bm25Stats, termFreq map[string]int, docLen int) {
	if stats.DocCount <= 0 {
		return
	}
	prevTotalLen := stats.AvgDocLen * float64(stats.DocCount)
	stats.DocCount--
	if stats.DocCount == 0 {
		stats.AvgDocLen = 0
	} else {
		stats.AvgDocLen = (prevTotalLen - float64(docLen)) / float64(stats.DocCount)
	}
	for term := range termFreq {
		if count := stats.DocFreq[term]; count > 1 {
			stats.DocFreq[term] = count - 1
		} else {
			delete(stats.DocFreq, term)
		}
	}
}
// buildDocVectorLocked computes the BM25-weighted sparse vector for a
// document: each term's tf is saturated by the k1/b length normalization
// and scaled by a smoothed idf. Terms unseen in the corpus or with zero
// resulting weight are skipped; terms whose hashes collide accumulate
// into a single index. Callers must hold b.mu.
func (b *BM25Indexer) buildDocVectorLocked(stats *bm25Stats, termFreq map[string]int, docLen int) ([]uint32, []float32) {
	if stats.DocCount == 0 || docLen == 0 {
		return nil, nil
	}
	avg := stats.AvgDocLen
	if avg <= 0 {
		// Guard against division by zero on a degenerate average.
		avg = 1
	}
	weights := make(map[uint32]float32, len(termFreq))
	for term, tf := range termFreq {
		df := stats.DocFreq[term]
		if df == 0 {
			continue
		}
		// Smoothed idf: log(1 + (N - df + 0.5) / (df + 0.5)).
		idf := math.Log(1 + (float64(stats.DocCount)-float64(df)+0.5)/(float64(df)+0.5))
		// tf saturation with document-length normalization.
		saturated := float64(tf) * (b.k1 + 1) / (float64(tf) + b.k1*(1-b.b+b.b*float64(docLen)/avg))
		if w := float32(saturated * idf); w != 0 {
			weights[termHash(term)] += w
		}
	}
	return sparseWeightsToVector(weights)
}
// buildQueryVectorLocked computes a sparse query vector where each term's
// weight is its raw query frequency; idf weighting is applied on the
// document side. Terms absent from the corpus are dropped, and colliding
// hashes accumulate into one index. Callers must hold b.mu.
func (b *BM25Indexer) buildQueryVectorLocked(stats *bm25Stats, termFreq map[string]int) ([]uint32, []float32) {
	if stats.DocCount == 0 {
		return nil, nil
	}
	weights := make(map[uint32]float32, len(termFreq))
	for term, tf := range termFreq {
		if stats.DocFreq[term] == 0 || tf == 0 {
			continue
		}
		weights[termHash(term)] += float32(tf)
	}
	return sparseWeightsToVector(weights)
}
// sparseWeightsToVector flattens a weight map into parallel index/value
// slices sorted by ascending index — the representation sparse-vector
// stores expect. Empty input yields nil slices.
func sparseWeightsToVector(weights map[uint32]float32) ([]uint32, []float32) {
	if len(weights) == 0 {
		return nil, nil
	}
	type entry struct {
		index uint32
		value float32
	}
	entries := make([]entry, 0, len(weights))
	for idx, val := range weights {
		entries = append(entries, entry{index: idx, value: val})
	}
	sort.Slice(entries, func(i, j int) bool { return entries[i].index < entries[j].index })
	indices := make([]uint32, len(entries))
	values := make([]float32, len(entries))
	for i, e := range entries {
		indices[i] = e.index
		values[i] = e.value
	}
	return indices, values
}
// termHash maps a term to a sparse-vector index by hashing it with
// FNV-1a and masking the result into the configured sparse dimension
// range (sparseDimMask).
func termHash(term string) uint32 {
	h := fnv.New32a()
	// fnv's Write never returns an error; the blank assignment documents that.
	_, _ = h.Write([]byte(term))
	return h.Sum32() & sparseDimMask
}
+37 -1
View File
@@ -29,7 +29,7 @@ func NewLLMClient(log *slog.Logger, baseURL, apiKey, model string, timeout time.
}
baseURL = strings.TrimRight(baseURL, "/")
if model == "" {
model = "gpt-4.1-nano-2025-04-14"
model = "gpt-4.1-nano"
}
if timeout <= 0 {
timeout = 10 * time.Second
@@ -119,6 +119,31 @@ func (c *LLMClient) Decide(ctx context.Context, req DecideRequest) (DecideRespon
return DecideResponse{Actions: actions}, nil
}
func (c *LLMClient) DetectLanguage(ctx context.Context, text string) (string, error) {
if strings.TrimSpace(text) == "" {
return "", fmt.Errorf("text is required")
}
systemPrompt, userPrompt := getLanguageDetectionMessages(text)
content, err := c.callChat(ctx, []chatMessage{
{Role: "system", Content: systemPrompt},
{Role: "user", Content: userPrompt},
})
if err != nil {
return "", err
}
var parsed struct {
Language string `json:"language"`
}
if err := json.Unmarshal([]byte(removeCodeBlocks(content)), &parsed); err != nil {
return "", err
}
lang := strings.ToLower(strings.TrimSpace(parsed.Language))
if !isAllowedLanguageCode(lang) {
return "", fmt.Errorf("unsupported language code: %s", lang)
}
return lang, nil
}
type chatMessage struct {
Role string `json:"role"`
Content string `json:"content"`
@@ -247,3 +272,14 @@ func normalizeMemoryItems(value interface{}) []map[string]interface{} {
return nil
}
}
func isAllowedLanguageCode(code string) bool {
switch strings.ToLower(strings.TrimSpace(code)) {
case "ar", "bg", "ca", "cjk", "ckb", "da", "de", "el", "en", "es", "eu",
"fa", "fi", "fr", "ga", "gl", "hi", "hr", "hu", "hy", "id", "in",
"it", "nl", "no", "pl", "pt", "ro", "ru", "sv", "tr":
return true
default:
return false
}
}
+14
View File
@@ -106,6 +106,20 @@ Follow the instruction mentioned below:
Do not return anything except the JSON format.`, toJSON(retrievedOldMemory), toJSON(newRetrievedFacts), "```json", "```")
}
func getLanguageDetectionMessages(text string) (string, string) {
systemPrompt := `You are a language classifier for the given input text.
Return a JSON object with a single key "language" whose value is one of the allowed codes.
Allowed codes: ar, bg, ca, cjk, ckb, da, de, el, en, es, eu, fa, fi, fr, ga, gl, hi, hr, hu, hy, id, in, it, nl, no, pl, pt, ro, ru, sv, tr.
Use "cjk" for Chinese/Japanese/Korean text, ckb=Kurdish(Sorani), ga=Irish(Gaelic), gl=Galician, eu=Basque, hy=Armenian, fa=Persian, hr=Croatian, hu=Hungarian, ro=Romanian, bg=Bulgarian. If unsure between id/in, use id.
If multiple languages appear, choose the dominant language.
Do not include any extra keys, comments, or formatting. Output must be valid JSON only.
If the text is Chinese, Japanese, or Korean, output exactly {"language":"cjk"}.
Never output "zh", "zh-cn", "zh-tw", "ja", "ko", or any code not in the allowed list.
Before finalizing, verify the value is one of the allowed codes.`
userPrompt := fmt.Sprintf("Text:\n%s", text)
return systemPrompt, userPrompt
}
func parseMessages(messages []string) string {
return strings.Join(messages, "\n")
}
+226 -52
View File
@@ -12,34 +12,47 @@ import (
"github.com/qdrant/go-client/qdrant"
)
const (
sparseHashVectorName = "sparse_hash"
sparseVocabVectorName = "sparse_vocab"
)
type QdrantStore struct {
client *qdrant.Client
collection string
dimension int
baseURL string
apiKey string
timeout time.Duration
logger *slog.Logger
vectorNames map[string]int
usesNamedVectors bool
client *qdrant.Client
collection string
dimension int
baseURL string
apiKey string
timeout time.Duration
logger *slog.Logger
vectorNames map[string]int
usesNamedVectors bool
sparseVectorName string
usesSparseVectors bool
}
type qdrantPoint struct {
ID string `json:"id"`
Vector []float32 `json:"vector"`
VectorName string `json:"vector_name,omitempty"`
Payload map[string]interface{} `json:"payload,omitempty"`
ID string `json:"id"`
Vector []float32 `json:"vector"`
VectorName string `json:"vector_name,omitempty"`
SparseIndices []uint32 `json:"sparse_indices,omitempty"`
SparseValues []float32 `json:"sparse_values,omitempty"`
SparseVectorName string `json:"sparse_vector_name,omitempty"`
Payload map[string]interface{} `json:"payload,omitempty"`
}
func NewQdrantStore(log *slog.Logger, baseURL, apiKey, collection string, dimension int, timeout time.Duration) (*QdrantStore, error) {
func NewQdrantStore(log *slog.Logger, baseURL, apiKey, collection string, dimension int, sparseVectorName string, timeout time.Duration) (*QdrantStore, error) {
host, port, useTLS, err := parseQdrantEndpoint(baseURL)
if err != nil {
return nil, err
}
if strings.TrimSpace(sparseVectorName) == "" {
sparseVectorName = sparseHashVectorName
}
if collection == "" {
collection = "memory"
}
if dimension <= 0 {
if dimension <= 0 && strings.TrimSpace(sparseVectorName) == "" {
dimension = 1536
}
@@ -55,13 +68,15 @@ func NewQdrantStore(log *slog.Logger, baseURL, apiKey, collection string, dimens
}
store := &QdrantStore{
client: client,
collection: collection,
dimension: dimension,
baseURL: baseURL,
apiKey: apiKey,
timeout: timeoutOrDefault(timeout),
logger: log.With(slog.String("store", "qdrant")),
client: client,
collection: collection,
dimension: dimension,
baseURL: baseURL,
apiKey: apiKey,
timeout: timeoutOrDefault(timeout),
logger: log.With(slog.String("store", "qdrant")),
sparseVectorName: strings.TrimSpace(sparseVectorName),
usesSparseVectors: strings.TrimSpace(sparseVectorName) != "",
}
ctx, cancel := context.WithTimeout(context.Background(), timeoutOrDefault(timeout))
@@ -73,14 +88,17 @@ func NewQdrantStore(log *slog.Logger, baseURL, apiKey, collection string, dimens
}
func (s *QdrantStore) NewSibling(collection string, dimension int) (*QdrantStore, error) {
return NewQdrantStore(s.logger, s.baseURL, s.apiKey, collection, dimension, s.timeout)
return NewQdrantStore(s.logger, s.baseURL, s.apiKey, collection, dimension, s.sparseVectorName, s.timeout)
}
func NewQdrantStoreWithVectors(log *slog.Logger, baseURL, apiKey, collection string, vectors map[string]int, timeout time.Duration) (*QdrantStore, error) {
func NewQdrantStoreWithVectors(log *slog.Logger, baseURL, apiKey, collection string, vectors map[string]int, sparseVectorName string, timeout time.Duration) (*QdrantStore, error) {
host, port, useTLS, err := parseQdrantEndpoint(baseURL)
if err != nil {
return nil, err
}
if strings.TrimSpace(sparseVectorName) == "" {
sparseVectorName = sparseHashVectorName
}
if collection == "" {
collection = "memory"
}
@@ -100,14 +118,16 @@ func NewQdrantStoreWithVectors(log *slog.Logger, baseURL, apiKey, collection str
}
store := &QdrantStore{
client: client,
collection: collection,
baseURL: baseURL,
apiKey: apiKey,
timeout: timeoutOrDefault(timeout),
logger: log.With(slog.String("store", "qdrant")),
vectorNames: vectors,
usesNamedVectors: true,
client: client,
collection: collection,
baseURL: baseURL,
apiKey: apiKey,
timeout: timeoutOrDefault(timeout),
logger: log.With(slog.String("store", "qdrant")),
vectorNames: vectors,
usesNamedVectors: true,
sparseVectorName: strings.TrimSpace(sparseVectorName),
usesSparseVectors: strings.TrimSpace(sparseVectorName) != "",
}
ctx, cancel := context.WithTimeout(context.Background(), timeoutOrDefault(timeout))
@@ -129,12 +149,31 @@ func (s *QdrantStore) Upsert(ctx context.Context, points []qdrantPoint) error {
return err
}
var vectors *qdrant.Vectors
if point.VectorName != "" && s.usesNamedVectors {
vectors = qdrant.NewVectorsMap(map[string]*qdrant.Vector{
point.VectorName: qdrant.NewVectorDense(point.Vector),
})
} else {
vectors = qdrant.NewVectorsDense(point.Vector)
vectorMap := map[string]*qdrant.Vector{}
if len(point.Vector) > 0 {
if point.VectorName != "" && s.usesNamedVectors {
vectorMap[point.VectorName] = qdrant.NewVectorDense(point.Vector)
} else if !s.usesNamedVectors && len(point.SparseIndices) == 0 {
vectors = qdrant.NewVectorsDense(point.Vector)
} else if point.VectorName != "" {
vectorMap[point.VectorName] = qdrant.NewVectorDense(point.Vector)
}
}
if len(point.SparseIndices) > 0 && len(point.SparseValues) > 0 {
sparseName := strings.TrimSpace(point.SparseVectorName)
if sparseName == "" {
sparseName = s.sparseVectorName
}
if sparseName == "" {
return fmt.Errorf("sparse vector name is required")
}
vectorMap[sparseName] = qdrant.NewVectorSparse(point.SparseIndices, point.SparseValues)
}
if vectors == nil {
if len(vectorMap) == 0 {
return fmt.Errorf("no vector data provided for point %s", point.ID)
}
vectors = qdrant.NewVectorsMap(vectorMap)
}
qPoints = append(qPoints, &qdrant.PointStruct{
Id: qdrant.NewIDUUID(point.ID),
@@ -183,6 +222,41 @@ func (s *QdrantStore) Search(ctx context.Context, vector []float32, limit int, f
return points, scores, nil
}
func (s *QdrantStore) SearchSparse(ctx context.Context, indices []uint32, values []float32, limit int, filters map[string]interface{}) ([]qdrantPoint, []float64, error) {
if limit <= 0 {
limit = 10
}
if len(indices) == 0 || len(values) == 0 {
return nil, nil, nil
}
if s.sparseVectorName == "" {
return nil, nil, fmt.Errorf("sparse vector name not configured")
}
filter := buildQdrantFilter(filters)
using := qdrant.PtrOf(s.sparseVectorName)
results, err := s.client.Query(ctx, &qdrant.QueryPoints{
CollectionName: s.collection,
Query: qdrant.NewQuerySparse(indices, values),
Using: using,
Limit: qdrant.PtrOf(uint64(limit)),
Filter: filter,
WithPayload: qdrant.NewWithPayload(true),
})
if err != nil {
return nil, nil, err
}
points := make([]qdrantPoint, 0, len(results))
scores := make([]float64, 0, len(results))
for _, scored := range results {
points = append(points, qdrantPoint{
ID: pointIDToString(scored.GetId()),
Payload: valueMapToInterface(scored.GetPayload()),
})
scores = append(scores, float64(scored.GetScore()))
}
return points, scores, nil
}
func (s *QdrantStore) SearchBySources(ctx context.Context, vector []float32, limit int, filters map[string]interface{}, sources []string, vectorName string) (map[string][]qdrantPoint, map[string][]float64, error) {
pointsBySource := make(map[string][]qdrantPoint, len(sources))
scoresBySource := make(map[string][]float64, len(sources))
@@ -204,6 +278,27 @@ func (s *QdrantStore) SearchBySources(ctx context.Context, vector []float32, lim
return pointsBySource, scoresBySource, nil
}
func (s *QdrantStore) SearchSparseBySources(ctx context.Context, indices []uint32, values []float32, limit int, filters map[string]interface{}, sources []string) (map[string][]qdrantPoint, map[string][]float64, error) {
pointsBySource := make(map[string][]qdrantPoint, len(sources))
scoresBySource := make(map[string][]float64, len(sources))
if len(sources) == 0 {
return pointsBySource, scoresBySource, nil
}
for _, source := range sources {
merged := cloneFilters(filters)
if source != "" {
merged["source"] = source
}
points, scores, err := s.SearchSparse(ctx, indices, values, limit, merged)
if err != nil {
return nil, nil, err
}
pointsBySource[source] = points
scoresBySource[source] = scores
}
return pointsBySource, scoresBySource, nil
}
func (s *QdrantStore) Get(ctx context.Context, id string) (*qdrantPoint, error) {
result, err := s.client.Get(ctx, &qdrant.GetPoints{
CollectionName: s.collection,
@@ -257,6 +352,31 @@ func (s *QdrantStore) List(ctx context.Context, limit int, filters map[string]in
return result, nil
}
func (s *QdrantStore) Scroll(ctx context.Context, limit int, filters map[string]interface{}, offset *qdrant.PointId) ([]qdrantPoint, *qdrant.PointId, error) {
if limit <= 0 {
limit = 100
}
filter := buildQdrantFilter(filters)
points, nextOffset, err := s.client.ScrollAndOffset(ctx, &qdrant.ScrollPoints{
CollectionName: s.collection,
Limit: qdrant.PtrOf(uint32(limit)),
Filter: filter,
Offset: offset,
WithPayload: qdrant.NewWithPayload(true),
})
if err != nil {
return nil, nil, err
}
result := make([]qdrantPoint, 0, len(points))
for _, point := range points {
result = append(result, qdrantPoint{
ID: pointIDToString(point.GetId()),
Payload: valueMapToInterface(point.GetPayload()),
})
}
return result, nextOffset, nil
}
func (s *QdrantStore) DeleteAll(ctx context.Context, filters map[string]interface{}) error {
filter := buildQdrantFilter(filters)
if filter == nil {
@@ -278,6 +398,7 @@ func (s *QdrantStore) ensureCollection(ctx context.Context, vectors map[string]i
if exists {
return s.refreshCollectionSchema(ctx, vectors)
}
var vectorsConfig *qdrant.VectorsConfig
if len(vectors) > 0 {
params := make(map[string]*qdrant.VectorParams, len(vectors))
for name, dim := range vectors {
@@ -286,17 +407,24 @@ func (s *QdrantStore) ensureCollection(ctx context.Context, vectors map[string]i
Distance: qdrant.Distance_Cosine,
}
}
return s.client.CreateCollection(ctx, &qdrant.CreateCollection{
CollectionName: s.collection,
VectorsConfig: qdrant.NewVectorsConfigMap(params),
vectorsConfig = qdrant.NewVectorsConfigMap(params)
} else if s.dimension > 0 {
vectorsConfig = qdrant.NewVectorsConfig(&qdrant.VectorParams{
Size: uint64(s.dimension),
Distance: qdrant.Distance_Cosine,
})
}
var sparseConfig *qdrant.SparseVectorConfig
if s.sparseVectorName != "" {
sparseConfig = qdrant.NewSparseVectorsConfig(map[string]*qdrant.SparseVectorParams{
s.sparseVectorName: {Modifier: qdrant.PtrOf(qdrant.Modifier_None)},
sparseVocabVectorName: {Modifier: qdrant.PtrOf(qdrant.Modifier_None)},
})
}
return s.client.CreateCollection(ctx, &qdrant.CreateCollection{
CollectionName: s.collection,
VectorsConfig: qdrant.NewVectorsConfig(&qdrant.VectorParams{
Size: uint64(s.dimension),
Distance: qdrant.Distance_Cosine,
}),
CollectionName: s.collection,
VectorsConfig: vectorsConfig,
SparseVectorsConfig: sparseConfig,
})
}
@@ -306,11 +434,12 @@ func (s *QdrantStore) refreshCollectionSchema(ctx context.Context, vectors map[s
return err
}
config := info.GetConfig()
if config == nil || config.GetParams() == nil || config.GetParams().GetVectorsConfig() == nil {
if config == nil || config.GetParams() == nil {
return nil
}
vectorsConfig := config.GetParams().GetVectorsConfig()
if vectorsConfig.GetParamsMap() != nil {
params := config.GetParams()
vectorsConfig := params.GetVectorsConfig()
if vectorsConfig != nil && vectorsConfig.GetParamsMap() != nil {
s.usesNamedVectors = true
s.vectorNames = map[string]int{}
for name, vec := range vectorsConfig.GetParamsMap().GetMap() {
@@ -319,7 +448,7 @@ func (s *QdrantStore) refreshCollectionSchema(ctx context.Context, vectors map[s
}
}
if len(vectors) == 0 {
return nil
goto sparseCheck
}
for name, dim := range vectors {
if existing, ok := s.vectorNames[name]; ok && existing == dim {
@@ -327,13 +456,58 @@ func (s *QdrantStore) refreshCollectionSchema(ctx context.Context, vectors map[s
}
return fmt.Errorf("collection missing vector %s (dim %d); migration required", name, dim)
}
}
if vectorsConfig == nil || vectorsConfig.GetParamsMap() == nil {
s.usesNamedVectors = false
s.vectorNames = nil
}
sparseCheck:
sparseConfig := params.GetSparseVectorsConfig()
if s.sparseVectorName != "" {
needsUpdate := false
if sparseConfig == nil || len(sparseConfig.GetMap()) == 0 {
needsUpdate = true
} else {
if _, ok := sparseConfig.GetMap()[s.sparseVectorName]; !ok {
needsUpdate = true
}
if _, ok := sparseConfig.GetMap()[sparseVocabVectorName]; !ok {
needsUpdate = true
}
}
if needsUpdate {
if err := s.ensureSparseVectors(ctx); err != nil {
return err
}
}
s.usesSparseVectors = true
return nil
}
s.usesNamedVectors = false
s.vectorNames = nil
if sparseConfig != nil && len(sparseConfig.GetMap()) > 0 {
s.usesSparseVectors = true
for name := range sparseConfig.GetMap() {
s.sparseVectorName = name
break
}
}
return nil
}
func (s *QdrantStore) ensureSparseVectors(ctx context.Context) error {
if s.sparseVectorName == "" {
return nil
}
err := s.client.UpdateCollection(ctx, &qdrant.UpdateCollection{
CollectionName: s.collection,
SparseVectorsConfig: qdrant.NewSparseVectorsConfig(map[string]*qdrant.SparseVectorParams{
s.sparseVectorName: {Modifier: qdrant.PtrOf(qdrant.Modifier_None)},
sparseVocabVectorName: {Modifier: qdrant.PtrOf(qdrant.Modifier_None)},
}),
})
return err
}
func parseQdrantEndpoint(endpoint string) (string, int, bool, error) {
if endpoint == "" {
return "127.0.0.1", 6334, false, nil
+322 -78
View File
@@ -12,6 +12,7 @@ import (
"time"
"github.com/google/uuid"
"github.com/qdrant/go-client/qdrant"
"github.com/memohai/memoh/internal/embeddings"
)
@@ -21,17 +22,19 @@ type Service struct {
embedder embeddings.Embedder
store *QdrantStore
resolver *embeddings.Resolver
bm25 *BM25Indexer
logger *slog.Logger
defaultTextModelID string
defaultMultimodalModelID string
}
func NewService(log *slog.Logger, llm LLM, embedder embeddings.Embedder, store *QdrantStore, resolver *embeddings.Resolver, defaultTextModelID, defaultMultimodalModelID string) *Service {
func NewService(log *slog.Logger, llm LLM, embedder embeddings.Embedder, store *QdrantStore, resolver *embeddings.Resolver, bm25 *BM25Indexer, defaultTextModelID, defaultMultimodalModelID string) *Service {
return &Service{
llm: llm,
embedder: embedder,
store: store,
resolver: resolver,
bm25: bm25,
logger: log.With(slog.String("service", "memory")),
defaultTextModelID: defaultTextModelID,
defaultMultimodalModelID: defaultMultimodalModelID,
@@ -42,15 +45,16 @@ func (s *Service) Add(ctx context.Context, req AddRequest) (SearchResponse, erro
if req.Message == "" && len(req.Messages) == 0 {
return SearchResponse{}, fmt.Errorf("message or messages is required")
}
if req.UserID == "" && req.AgentID == "" && req.RunID == "" {
return SearchResponse{}, fmt.Errorf("user_id, agent_id or run_id is required")
if req.UserID == "" {
return SearchResponse{}, fmt.Errorf("user_id is required")
}
messages := normalizeMessages(req)
filters := buildFilters(req)
embeddingEnabled := req.EmbeddingEnabled != nil && *req.EmbeddingEnabled
if req.Infer != nil && !*req.Infer {
return s.addRawMessages(ctx, messages, filters, req.Metadata)
return s.addRawMessages(ctx, messages, filters, req.Metadata, embeddingEnabled)
}
extractResp, err := s.llm.Extract(ctx, ExtractRequest{
@@ -95,7 +99,7 @@ func (s *Service) Add(ctx context.Context, req AddRequest) (SearchResponse, erro
for _, action := range actions {
switch strings.ToUpper(action.Event) {
case "ADD":
item, err := s.applyAdd(ctx, action.Text, filters, req.Metadata)
item, err := s.applyAdd(ctx, action.Text, filters, req.Metadata, embeddingEnabled)
if err != nil {
return SearchResponse{}, err
}
@@ -104,7 +108,7 @@ func (s *Service) Add(ctx context.Context, req AddRequest) (SearchResponse, erro
})
results = append(results, item)
case "UPDATE":
item, err := s.applyUpdate(ctx, action.ID, action.Text, filters, req.Metadata)
item, err := s.applyUpdate(ctx, action.ID, action.Text, filters, req.Metadata, embeddingEnabled)
if err != nil {
return SearchResponse{}, err
}
@@ -134,19 +138,19 @@ func (s *Service) Search(ctx context.Context, req SearchRequest) (SearchResponse
if strings.TrimSpace(req.Query) == "" {
return SearchResponse{}, fmt.Errorf("query is required")
}
if s.store == nil {
return SearchResponse{}, fmt.Errorf("qdrant store not configured")
}
filters := buildSearchFilters(req)
modality := ""
if raw, ok := filters["modality"].(string); ok {
modality = strings.ToLower(strings.TrimSpace(raw))
}
var (
vector []float32
store *QdrantStore
vectorName string
err error
)
embeddingEnabled := req.EmbeddingEnabled != nil && *req.EmbeddingEnabled
if modality == embeddings.TypeMultimodal {
if !embeddingEnabled {
return SearchResponse{}, fmt.Errorf("embedding is disabled")
}
if s.resolver == nil {
return SearchResponse{}, fmt.Errorf("embeddings resolver not configured")
}
@@ -159,24 +163,79 @@ func (s *Service) Search(ctx context.Context, req SearchRequest) (SearchResponse
if err != nil {
return SearchResponse{}, err
}
vector = result.Embedding
store = s.store
vectorName = s.vectorNameForMultimodal()
} else {
vector, err = s.embedder.Embed(ctx, req.Query)
vectorName := s.vectorNameForMultimodal()
if len(req.Sources) == 0 {
points, scores, err := s.store.Search(ctx, result.Embedding, req.Limit, filters, vectorName)
if err != nil {
return SearchResponse{}, err
}
results := make([]MemoryItem, 0, len(points))
for idx, point := range points {
item := payloadToMemoryItem(point.ID, point.Payload)
if idx < len(scores) {
item.Score = scores[idx]
}
results = append(results, item)
}
return SearchResponse{Results: results}, nil
}
pointsBySource, scoresBySource, err := s.store.SearchBySources(ctx, result.Embedding, req.Limit, filters, req.Sources, vectorName)
if err != nil {
return SearchResponse{}, err
}
store = s.store
vectorName = s.vectorNameForText()
results := fuseByRankFusion(pointsBySource, scoresBySource)
return SearchResponse{Results: results}, nil
}
if len(req.Sources) == 0 {
points, scores, err := store.Search(ctx, vector, req.Limit, filters, vectorName)
if embeddingEnabled {
if s.embedder == nil {
return SearchResponse{}, fmt.Errorf("embedder not configured")
}
vector, err := s.embedder.Embed(ctx, req.Query)
if err != nil {
return SearchResponse{}, err
}
vectorName := s.vectorNameForText()
if len(req.Sources) == 0 {
points, scores, err := s.store.Search(ctx, vector, req.Limit, filters, vectorName)
if err != nil {
return SearchResponse{}, err
}
results := make([]MemoryItem, 0, len(points))
for idx, point := range points {
item := payloadToMemoryItem(point.ID, point.Payload)
if idx < len(scores) {
item.Score = scores[idx]
}
results = append(results, item)
}
return SearchResponse{Results: results}, nil
}
pointsBySource, scoresBySource, err := s.store.SearchBySources(ctx, vector, req.Limit, filters, req.Sources, vectorName)
if err != nil {
return SearchResponse{}, err
}
results := fuseByRankFusion(pointsBySource, scoresBySource)
return SearchResponse{Results: results}, nil
}
if s.bm25 == nil {
return SearchResponse{}, fmt.Errorf("bm25 indexer not configured")
}
lang, err := s.detectLanguage(ctx, req.Query)
if err != nil {
return SearchResponse{}, err
}
termFreq, _, err := s.bm25.TermFrequencies(lang, req.Query)
if err != nil {
return SearchResponse{}, err
}
indices, values := s.bm25.BuildQueryVector(lang, termFreq)
if len(req.Sources) == 0 {
points, scores, err := s.store.SearchSparse(ctx, indices, values, req.Limit, filters)
if err != nil {
return SearchResponse{}, err
}
results := make([]MemoryItem, 0, len(points))
for idx, point := range points {
item := payloadToMemoryItem(point.ID, point.Payload)
@@ -187,8 +246,7 @@ func (s *Service) Search(ctx context.Context, req SearchRequest) (SearchResponse
}
return SearchResponse{Results: results}, nil
}
pointsBySource, scoresBySource, err := store.SearchBySources(ctx, vector, req.Limit, filters, req.Sources, vectorName)
pointsBySource, scoresBySource, err := s.store.SearchSparseBySources(ctx, indices, values, req.Limit, filters, req.Sources)
if err != nil {
return SearchResponse{}, err
}
@@ -200,8 +258,8 @@ func (s *Service) EmbedUpsert(ctx context.Context, req EmbedUpsertRequest) (Embe
if s.resolver == nil {
return EmbedUpsertResponse{}, fmt.Errorf("embeddings resolver not configured")
}
if req.UserID == "" && req.AgentID == "" && req.RunID == "" {
return EmbedUpsertResponse{}, fmt.Errorf("user_id, agent_id or run_id is required")
if req.UserID == "" {
return EmbedUpsertResponse{}, fmt.Errorf("user_id is required")
}
req.Type = strings.TrimSpace(req.Type)
req.Provider = strings.TrimSpace(req.Provider)
@@ -264,6 +322,12 @@ func (s *Service) Update(ctx context.Context, req UpdateRequest) (MemoryItem, er
if strings.TrimSpace(req.Memory) == "" {
return MemoryItem{}, fmt.Errorf("memory is required")
}
if s.store == nil {
return MemoryItem{}, fmt.Errorf("qdrant store not configured")
}
if s.bm25 == nil {
return MemoryItem{}, fmt.Errorf("bm25 indexer not configured")
}
existing, err := s.store.Get(ctx, req.MemoryID)
if err != nil {
@@ -272,22 +336,58 @@ func (s *Service) Update(ctx context.Context, req UpdateRequest) (MemoryItem, er
if existing == nil {
return MemoryItem{}, fmt.Errorf("memory not found")
}
if s.bm25 == nil {
return MemoryItem{}, fmt.Errorf("bm25 indexer not configured")
}
payload := existing.Payload
payload["data"] = req.Memory
payload["hash"] = hashMemory(req.Memory)
payload["updatedAt"] = time.Now().UTC().Format(time.RFC3339)
oldText := fmt.Sprint(payload["data"])
oldLang := fmt.Sprint(payload["lang"])
if oldLang == "" && strings.TrimSpace(oldText) != "" {
oldLang, _ = s.detectLanguage(ctx, oldText)
}
if strings.TrimSpace(oldText) != "" && strings.TrimSpace(oldLang) != "" {
oldFreq, oldLen, err := s.bm25.TermFrequencies(oldLang, oldText)
if err == nil {
s.bm25.RemoveDocument(oldLang, oldFreq, oldLen)
}
}
vector, err := s.embedder.Embed(ctx, req.Memory)
newLang, err := s.detectLanguage(ctx, req.Memory)
if err != nil {
return MemoryItem{}, err
}
if err := s.store.Upsert(ctx, []qdrantPoint{{
ID: req.MemoryID,
Vector: vector,
VectorName: s.vectorNameForText(),
Payload: payload,
}}); err != nil {
newFreq, newLen, err := s.bm25.TermFrequencies(newLang, req.Memory)
if err != nil {
return MemoryItem{}, err
}
sparseIndices, sparseValues := s.bm25.AddDocument(newLang, newFreq, newLen)
payload["data"] = req.Memory
payload["hash"] = hashMemory(req.Memory)
payload["updatedAt"] = time.Now().UTC().Format(time.RFC3339)
payload["lang"] = newLang
embeddingEnabled := req.EmbeddingEnabled != nil && *req.EmbeddingEnabled
point := qdrantPoint{
ID: req.MemoryID,
SparseIndices: sparseIndices,
SparseValues: sparseValues,
SparseVectorName: s.store.sparseVectorName,
Payload: payload,
}
if embeddingEnabled {
if s.embedder == nil {
return MemoryItem{}, fmt.Errorf("embedder not configured")
}
vector, err := s.embedder.Embed(ctx, req.Memory)
if err != nil {
return MemoryItem{}, err
}
point.Vector = vector
point.VectorName = s.vectorNameForText()
}
if err := s.store.Upsert(ctx, []qdrantPoint{point}); err != nil {
return MemoryItem{}, err
}
return payloadToMemoryItem(req.MemoryID, payload), nil
@@ -312,14 +412,11 @@ func (s *Service) GetAll(ctx context.Context, req GetAllRequest) (SearchResponse
if req.UserID != "" {
filters["userId"] = req.UserID
}
if req.AgentID != "" {
filters["agentId"] = req.AgentID
}
if req.RunID != "" {
filters["runId"] = req.RunID
}
if len(filters) == 0 {
return SearchResponse{}, fmt.Errorf("user_id, agent_id or run_id is required")
return SearchResponse{}, fmt.Errorf("user_id is required")
}
points, err := s.store.List(ctx, req.Limit, filters)
@@ -348,14 +445,11 @@ func (s *Service) DeleteAll(ctx context.Context, req DeleteAllRequest) (DeleteRe
if req.UserID != "" {
filters["userId"] = req.UserID
}
if req.AgentID != "" {
filters["agentId"] = req.AgentID
}
if req.RunID != "" {
filters["runId"] = req.RunID
}
if len(filters) == 0 {
return DeleteResponse{}, fmt.Errorf("user_id, agent_id or run_id is required")
return DeleteResponse{}, fmt.Errorf("user_id is required")
}
if err := s.store.DeleteAll(ctx, filters); err != nil {
return DeleteResponse{}, err
@@ -363,10 +457,46 @@ func (s *Service) DeleteAll(ctx context.Context, req DeleteAllRequest) (DeleteRe
return DeleteResponse{Message: "Memories deleted successfully!"}, nil
}
func (s *Service) addRawMessages(ctx context.Context, messages []Message, filters map[string]interface{}, metadata map[string]interface{}) (SearchResponse, error) {
func (s *Service) WarmupBM25(ctx context.Context, batchSize int) error {
if s.bm25 == nil || s.store == nil {
return nil
}
var offset *qdrant.PointId
for {
points, next, err := s.store.Scroll(ctx, batchSize, nil, offset)
if err != nil {
return err
}
if len(points) == 0 {
break
}
for _, point := range points {
text := fmt.Sprint(point.Payload["data"])
if strings.TrimSpace(text) == "" {
continue
}
lang := fmt.Sprint(point.Payload["lang"])
if lang == "" {
lang = fallbackLanguageCode(text)
}
termFreq, docLen, err := s.bm25.TermFrequencies(lang, text)
if err != nil {
continue
}
s.bm25.AddDocument(lang, termFreq, docLen)
}
if next == nil {
break
}
offset = next
}
return nil
}
func (s *Service) addRawMessages(ctx context.Context, messages []Message, filters map[string]interface{}, metadata map[string]interface{}, embeddingEnabled bool) (SearchResponse, error) {
results := make([]MemoryItem, 0, len(messages))
for _, message := range messages {
item, err := s.applyAdd(ctx, message.Content, filters, metadata)
item, err := s.applyAdd(ctx, message.Content, filters, metadata, embeddingEnabled)
if err != nil {
return SearchResponse{}, err
}
@@ -381,11 +511,19 @@ func (s *Service) addRawMessages(ctx context.Context, messages []Message, filter
func (s *Service) collectCandidates(ctx context.Context, facts []string, filters map[string]interface{}) ([]CandidateMemory, error) {
unique := map[string]CandidateMemory{}
for _, fact := range facts {
vector, err := s.embedder.Embed(ctx, fact)
if s.bm25 == nil {
return nil, fmt.Errorf("bm25 indexer not configured")
}
lang, err := s.detectLanguage(ctx, fact)
if err != nil {
return nil, err
}
points, _, err := s.store.Search(ctx, vector, 5, filters, s.vectorNameForText())
termFreq, _, err := s.bm25.TermFrequencies(lang, fact)
if err != nil {
return nil, err
}
indices, values := s.bm25.BuildQueryVector(lang, termFreq)
points, _, err := s.store.SearchSparse(ctx, indices, values, 5, filters)
if err != nil {
return nil, err
}
@@ -406,25 +544,50 @@ func (s *Service) collectCandidates(ctx context.Context, facts []string, filters
return candidates, nil
}
func (s *Service) applyAdd(ctx context.Context, text string, filters map[string]interface{}, metadata map[string]interface{}) (MemoryItem, error) {
vector, err := s.embedder.Embed(ctx, text)
func (s *Service) applyAdd(ctx context.Context, text string, filters map[string]interface{}, metadata map[string]interface{}, embeddingEnabled bool) (MemoryItem, error) {
if s.store == nil {
return MemoryItem{}, fmt.Errorf("qdrant store not configured")
}
if s.bm25 == nil {
return MemoryItem{}, fmt.Errorf("bm25 indexer not configured")
}
lang, err := s.detectLanguage(ctx, text)
if err != nil {
return MemoryItem{}, err
}
termFreq, docLen, err := s.bm25.TermFrequencies(lang, text)
if err != nil {
return MemoryItem{}, err
}
sparseIndices, sparseValues := s.bm25.AddDocument(lang, termFreq, docLen)
id := uuid.NewString()
payload := buildPayload(text, filters, metadata, "")
if err := s.store.Upsert(ctx, []qdrantPoint{{
ID: id,
Vector: vector,
VectorName: s.vectorNameForText(),
Payload: payload,
}}); err != nil {
payload["lang"] = lang
point := qdrantPoint{
ID: id,
SparseIndices: sparseIndices,
SparseValues: sparseValues,
SparseVectorName: s.store.sparseVectorName,
Payload: payload,
}
if embeddingEnabled {
if s.embedder == nil {
return MemoryItem{}, fmt.Errorf("embedder not configured")
}
vector, err := s.embedder.Embed(ctx, text)
if err != nil {
return MemoryItem{}, err
}
point.Vector = vector
point.VectorName = s.vectorNameForText()
}
if err := s.store.Upsert(ctx, []qdrantPoint{point}); err != nil {
return MemoryItem{}, err
}
return payloadToMemoryItem(id, payload), nil
}
func (s *Service) applyUpdate(ctx context.Context, id, text string, filters map[string]interface{}, metadata map[string]interface{}) (MemoryItem, error) {
func (s *Service) applyUpdate(ctx context.Context, id, text string, filters map[string]interface{}, metadata map[string]interface{}, embeddingEnabled bool) (MemoryItem, error) {
if strings.TrimSpace(id) == "" {
return MemoryItem{}, fmt.Errorf("update action missing id")
}
@@ -437,25 +600,55 @@ func (s *Service) applyUpdate(ctx context.Context, id, text string, filters map[
}
payload := existing.Payload
oldText := fmt.Sprint(payload["data"])
oldLang := fmt.Sprint(payload["lang"])
if oldLang == "" && strings.TrimSpace(oldText) != "" {
oldLang, _ = s.detectLanguage(ctx, oldText)
}
if strings.TrimSpace(oldText) != "" && strings.TrimSpace(oldLang) != "" {
oldFreq, oldLen, err := s.bm25.TermFrequencies(oldLang, oldText)
if err == nil {
s.bm25.RemoveDocument(oldLang, oldFreq, oldLen)
}
}
newLang, err := s.detectLanguage(ctx, text)
if err != nil {
return MemoryItem{}, err
}
newFreq, newLen, err := s.bm25.TermFrequencies(newLang, text)
if err != nil {
return MemoryItem{}, err
}
sparseIndices, sparseValues := s.bm25.AddDocument(newLang, newFreq, newLen)
payload["data"] = text
payload["hash"] = hashMemory(text)
payload["updatedAt"] = time.Now().UTC().Format(time.RFC3339)
payload["lang"] = newLang
if metadata != nil {
payload["metadata"] = mergeMetadata(payload["metadata"], metadata)
}
if filters != nil {
applyFiltersToPayload(payload, filters)
}
vector, err := s.embedder.Embed(ctx, text)
if err != nil {
return MemoryItem{}, err
point := qdrantPoint{
ID: id,
SparseIndices: sparseIndices,
SparseValues: sparseValues,
SparseVectorName: s.store.sparseVectorName,
Payload: payload,
}
if err := s.store.Upsert(ctx, []qdrantPoint{{
ID: id,
Vector: vector,
VectorName: s.vectorNameForText(),
Payload: payload,
}}); err != nil {
if embeddingEnabled {
if s.embedder == nil {
return MemoryItem{}, fmt.Errorf("embedder not configured")
}
vector, err := s.embedder.Embed(ctx, text)
if err != nil {
return MemoryItem{}, err
}
point.Vector = vector
point.VectorName = s.vectorNameForText()
}
if err := s.store.Upsert(ctx, []qdrantPoint{point}); err != nil {
return MemoryItem{}, err
}
return payloadToMemoryItem(id, payload), nil
@@ -473,6 +666,19 @@ func (s *Service) applyDelete(ctx context.Context, id string) (MemoryItem, error
return MemoryItem{}, fmt.Errorf("memory not found")
}
item := payloadToMemoryItem(id, existing.Payload)
if s.bm25 != nil {
oldText := fmt.Sprint(existing.Payload["data"])
oldLang := fmt.Sprint(existing.Payload["lang"])
if oldLang == "" && strings.TrimSpace(oldText) != "" {
oldLang, _ = s.detectLanguage(ctx, oldText)
}
if strings.TrimSpace(oldText) != "" && strings.TrimSpace(oldLang) != "" {
oldFreq, oldLen, err := s.bm25.TermFrequencies(oldLang, oldText)
if err == nil {
s.bm25.RemoveDocument(oldLang, oldFreq, oldLen)
}
}
}
if err := s.store.Delete(ctx, id); err != nil {
return MemoryItem{}, err
}
@@ -486,6 +692,56 @@ func normalizeMessages(req AddRequest) []Message {
return []Message{{Role: "user", Content: req.Message}}
}
// detectLanguage resolves a language code for text. It first asks the
// configured LLM; when the LLM errors out or yields an empty code, it falls
// back to a local heuristic (fallbackLanguageCode) and logs a warning rather
// than failing the caller. An error is returned only when no LLM detector is
// configured at all.
func (s *Service) detectLanguage(ctx context.Context, text string) (string, error) {
	if s.llm == nil {
		return "", fmt.Errorf("language detector not configured")
	}
	detected, detectErr := s.llm.DetectLanguage(ctx, text)
	if detectErr == nil && detected != "" {
		return detected, nil
	}
	// Detection failed or came back empty: degrade gracefully to the
	// heuristic instead of propagating the error.
	guess := fallbackLanguageCode(text)
	if s.logger != nil {
		s.logger.Warn("language detection failed; using fallback", slog.Any("error", detectErr), slog.String("fallback", guess))
	}
	return guess, nil
}
// fallbackLanguageCode is the heuristic used when LLM language detection is
// unavailable: it classifies text as "cjk" when it contains any CJK rune and
// defaults to "en" otherwise.
func fallbackLanguageCode(text string) string {
	lang := "en"
	for _, r := range text {
		if isCJKRune(r) {
			lang = "cjk"
			break
		}
	}
	return lang
}
// isCJKRune reports whether r falls into one of the Unicode blocks treated
// here as CJK: the unified ideographs (plus extensions A–F), CJK symbols and
// punctuation, Hiragana/Katakana, and Hangul syllables.
func isCJKRune(r rune) bool {
	// Inclusive [lo, hi] code-point ranges, same blocks as before.
	ranges := [...][2]rune{
		{0x3000, 0x303F},   // CJK Symbols and Punctuation
		{0x3040, 0x30FF},   // Hiragana and Katakana
		{0x3400, 0x4DBF},   // CJK Unified Ideographs Extension A
		{0x4E00, 0x9FFF},   // CJK Unified Ideographs
		{0xAC00, 0xD7AF},   // Hangul Syllables
		{0x20000, 0x2A6DF}, // CJK Unified Ideographs Extension B
		{0x2A700, 0x2B73F}, // CJK Unified Ideographs Extension C
		{0x2B740, 0x2B81F}, // CJK Unified Ideographs Extension D
		{0x2B820, 0x2CEAF}, // CJK Unified Ideographs Extension E
		{0x2CEB0, 0x2EBEF}, // CJK Unified Ideographs Extension F
	}
	for _, bounds := range ranges {
		if r >= bounds[0] && r <= bounds[1] {
			return true
		}
	}
	return false
}
func buildFilters(req AddRequest) map[string]interface{} {
filters := map[string]interface{}{}
for key, value := range req.Filters {
@@ -494,9 +750,6 @@ func buildFilters(req AddRequest) map[string]interface{} {
if req.UserID != "" {
filters["userId"] = req.UserID
}
if req.AgentID != "" {
filters["agentId"] = req.AgentID
}
if req.RunID != "" {
filters["runId"] = req.RunID
}
@@ -511,9 +764,6 @@ func buildSearchFilters(req SearchRequest) map[string]interface{} {
if req.UserID != "" {
filters["userId"] = req.UserID
}
if req.AgentID != "" {
filters["agentId"] = req.AgentID
}
if req.RunID != "" {
filters["runId"] = req.RunID
}
@@ -528,9 +778,6 @@ func buildEmbedFilters(req EmbedUpsertRequest) map[string]interface{} {
if req.UserID != "" {
filters["userId"] = req.UserID
}
if req.AgentID != "" {
filters["agentId"] = req.AgentID
}
if req.RunID != "" {
filters["runId"] = req.RunID
}
@@ -621,9 +868,6 @@ func payloadToMemoryItem(id string, payload map[string]interface{}) MemoryItem {
if v, ok := payload["userId"].(string); ok {
item.UserID = v
}
if v, ok := payload["agentId"].(string); ok {
item.AgentID = v
}
if v, ok := payload["runId"].(string); ok {
item.RunID = v
}
+24 -26
View File
@@ -6,6 +6,7 @@ import "context"
// LLM is the language-model dependency of the memory service: fact
// extraction, memory-action decisions, and language detection.
type LLM interface {
	// Extract derives candidate memory facts from the request content
	// (presumably from the conversation messages — confirm against the
	// implementation).
	Extract(ctx context.Context, req ExtractRequest) (ExtractResponse, error)
	// Decide chooses which memory actions to apply for the request
	// (presumably add/update/delete — confirm against DecideResponse usage).
	Decide(ctx context.Context, req DecideRequest) (DecideResponse, error)
	// DetectLanguage returns a language code for text. Callers treat an
	// error or an empty code as failure and fall back to a heuristic.
	DetectLanguage(ctx context.Context, text string) (string, error)
}
type Message struct {
@@ -14,42 +15,41 @@ type Message struct {
}
// AddRequest is the payload for creating memories. Provide either a single
// Message or a Messages transcript; UserID and RunID scope the memory, and
// Metadata/Filters are stored alongside it.
type AddRequest struct {
	Message  string                 `json:"message,omitempty"`
	Messages []Message              `json:"messages,omitempty"`
	UserID   string                 `json:"user_id,omitempty"`
	RunID    string                 `json:"run_id,omitempty"`
	Metadata map[string]interface{} `json:"metadata,omitempty"`
	Filters  map[string]interface{} `json:"filters,omitempty"`
	// Infer toggles LLM-based fact inference; nil means "unset" so the
	// service applies its default (assumption — confirm against the handler).
	Infer *bool `json:"infer,omitempty"`
	// EmbeddingEnabled toggles dense-vector embedding of the new memory;
	// nil means "unset" (assumption — confirm against the handler).
	EmbeddingEnabled *bool `json:"embedding_enabled,omitempty"`
}
// SearchRequest is the payload for querying memories. Query is required;
// UserID/RunID/Filters narrow the scope, Limit caps the result count, and
// Sources selects which backends to search.
type SearchRequest struct {
	Query   string                 `json:"query"`
	UserID  string                 `json:"user_id,omitempty"`
	RunID   string                 `json:"run_id,omitempty"`
	Limit   int                    `json:"limit,omitempty"`
	Filters map[string]interface{} `json:"filters,omitempty"`
	Sources []string               `json:"sources,omitempty"`
	// EmbeddingEnabled toggles dense-vector search; nil means "unset" so the
	// service applies its default (assumption — confirm against the handler).
	EmbeddingEnabled *bool `json:"embedding_enabled,omitempty"`
}
// UpdateRequest replaces the text of an existing memory identified by MemoryID.
type UpdateRequest struct {
	MemoryID string `json:"memory_id"`
	Memory   string `json:"memory"`
	// EmbeddingEnabled toggles re-embedding of the updated text; nil means
	// "unset" so the service applies its default (assumption — confirm
	// against the handler).
	EmbeddingEnabled *bool `json:"embedding_enabled,omitempty"`
}
// GetAllRequest filters a listing of stored memories; zero values mean
// "no filter" and a zero Limit means no explicit cap from the caller.
type GetAllRequest struct {
	UserID string `json:"user_id,omitempty"`
	RunID  string `json:"run_id,omitempty"`
	Limit  int    `json:"limit,omitempty"`
}
// DeleteAllRequest scopes a bulk delete of memories; zero values mean
// "no filter".
type DeleteAllRequest struct {
	UserID string `json:"user_id,omitempty"`
	RunID  string `json:"run_id,omitempty"`
}
type EmbedInput struct {
@@ -65,7 +65,6 @@ type EmbedUpsertRequest struct {
Input EmbedInput `json:"input"`
Source string `json:"source,omitempty"`
UserID string `json:"user_id,omitempty"`
AgentID string `json:"agent_id,omitempty"`
RunID string `json:"run_id,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
Filters map[string]interface{} `json:"filters,omitempty"`
@@ -87,7 +86,6 @@ type MemoryItem struct {
Score float64 `json:"score,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
UserID string `json:"userId,omitempty"`
AgentID string `json:"agentId,omitempty"`
RunID string `json:"runId,omitempty"`
}