mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
refactor: use sparse vector for memory
This commit is contained in:
+73
-58
@@ -107,69 +107,27 @@ func main() {
|
||||
}
|
||||
|
||||
resolver := embeddings.NewResolver(logger.L, modelsService, queries, 10*time.Second)
|
||||
vectors, textModel, multimodalModel, hasModels, err := embeddings.CollectEmbeddingVectors(ctx, modelsService)
|
||||
vectors, textModel, multimodalModel, hasEmbeddingModels, err := embeddings.CollectEmbeddingVectors(ctx, modelsService)
|
||||
if err != nil {
|
||||
logger.Error("embedding models", slog.Any("error", err))
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
var memoryService *memory.Service
|
||||
var memoryHandler *handlers.MemoryHandler
|
||||
|
||||
if !hasModels {
|
||||
logger.Warn("No embedding models configured. Memory service will not be available.")
|
||||
logger.Warn("You can add embedding models via the /models API endpoint.")
|
||||
memoryHandler = handlers.NewMemoryHandler(logger.L, nil)
|
||||
} else {
|
||||
if textModel.ModelID == "" {
|
||||
logger.Warn("No text embedding model configured. Text embedding features will be limited.")
|
||||
}
|
||||
if multimodalModel.ModelID == "" {
|
||||
logger.Warn("No multimodal embedding model configured. Multimodal embedding features will be limited.")
|
||||
}
|
||||
|
||||
var textEmbedder embeddings.Embedder
|
||||
var store *memory.QdrantStore
|
||||
|
||||
if textModel.ModelID != "" && textModel.Dimensions > 0 {
|
||||
textEmbedder = &embeddings.ResolverTextEmbedder{
|
||||
Resolver: resolver,
|
||||
ModelID: textModel.ModelID,
|
||||
Dims: textModel.Dimensions,
|
||||
}
|
||||
|
||||
if len(vectors) > 0 {
|
||||
store, err = memory.NewQdrantStoreWithVectors(
|
||||
logger.L,
|
||||
cfg.Qdrant.BaseURL,
|
||||
cfg.Qdrant.APIKey,
|
||||
cfg.Qdrant.Collection,
|
||||
vectors,
|
||||
time.Duration(cfg.Qdrant.TimeoutSeconds)*time.Second,
|
||||
)
|
||||
if err != nil {
|
||||
logger.Error("qdrant named vectors init", slog.Any("error", err))
|
||||
os.Exit(1)
|
||||
}
|
||||
} else {
|
||||
store, err = memory.NewQdrantStore(
|
||||
logger.L,
|
||||
cfg.Qdrant.BaseURL,
|
||||
cfg.Qdrant.APIKey,
|
||||
cfg.Qdrant.Collection,
|
||||
textModel.Dimensions,
|
||||
time.Duration(cfg.Qdrant.TimeoutSeconds)*time.Second,
|
||||
)
|
||||
if err != nil {
|
||||
logger.Error("qdrant init", slog.Any("error", err))
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
memoryService = memory.NewService(logger.L, llmClient, textEmbedder, store, resolver, textModel.ModelID, multimodalModel.ModelID)
|
||||
memoryHandler = handlers.NewMemoryHandler(logger.L, memoryService)
|
||||
textEmbedder := buildTextEmbedder(resolver, textModel, hasEmbeddingModels, logger.L)
|
||||
if hasEmbeddingModels && multimodalModel.ModelID == "" {
|
||||
logger.Warn("No multimodal embedding model configured. Multimodal embedding features will be limited.")
|
||||
}
|
||||
|
||||
store := buildQdrantStore(logger.L, cfg.Qdrant, vectors, hasEmbeddingModels, textModel.Dimensions)
|
||||
|
||||
bm25Indexer := memory.NewBM25Indexer(logger.L)
|
||||
memoryService := memory.NewService(logger.L, llmClient, textEmbedder, store, resolver, bm25Indexer, textModel.ModelID, multimodalModel.ModelID)
|
||||
memoryHandler := handlers.NewMemoryHandler(logger.L, memoryService)
|
||||
go func() {
|
||||
if err := memoryService.WarmupBM25(ctx, 200); err != nil {
|
||||
logger.Warn("bm25 warmup failed", slog.Any("error", err))
|
||||
}
|
||||
}()
|
||||
chatResolver = chat.NewResolver(logger.L, modelsService, queries, memoryService, cfg.AgentGateway.BaseURL(), 30*time.Second)
|
||||
embeddingsHandler := handlers.NewEmbeddingsHandler(logger.L, modelsService, queries)
|
||||
swaggerHandler := handlers.NewSwaggerHandler(logger.L)
|
||||
@@ -197,7 +155,7 @@ func main() {
|
||||
scheduleHandler := handlers.NewScheduleHandler(logger.L, scheduleService)
|
||||
subagentService := subagent.NewService(logger.L, queries)
|
||||
subagentHandler := handlers.NewSubagentHandler(logger.L, subagentService)
|
||||
srv := server.NewServer(logger.L, addr, cfg.Auth.JWTSecret, pingHandler, authHandler, memoryHandler, embeddingsHandler, chatHandler, swaggerHandler, providersHandler, modelsHandler, settingsHandler, historyHandler, scheduleHandler, subagentHandler, containerdHandler, /*channelHandler*/)
|
||||
srv := server.NewServer(logger.L, addr, cfg.Auth.JWTSecret, pingHandler, authHandler, memoryHandler, embeddingsHandler, chatHandler, swaggerHandler, providersHandler, modelsHandler, settingsHandler, historyHandler, scheduleHandler, subagentHandler, containerdHandler /*channelHandler*/)
|
||||
|
||||
if err := srv.Start(); err != nil {
|
||||
logger.Error("server failed", slog.Any("error", err))
|
||||
@@ -205,6 +163,55 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
func buildTextEmbedder(resolver *embeddings.Resolver, textModel models.GetResponse, hasModels bool, log *slog.Logger) embeddings.Embedder {
|
||||
if !hasModels {
|
||||
return nil
|
||||
}
|
||||
if textModel.ModelID == "" || textModel.Dimensions <= 0 {
|
||||
log.Warn("No text embedding model configured. Text embedding features will be limited.")
|
||||
return nil
|
||||
}
|
||||
return &embeddings.ResolverTextEmbedder{
|
||||
Resolver: resolver,
|
||||
ModelID: textModel.ModelID,
|
||||
Dims: textModel.Dimensions,
|
||||
}
|
||||
}
|
||||
|
||||
func buildQdrantStore(log *slog.Logger, cfg config.QdrantConfig, vectors map[string]int, hasModels bool, textDims int) *memory.QdrantStore {
|
||||
timeout := time.Duration(cfg.TimeoutSeconds) * time.Second
|
||||
if hasModels && len(vectors) > 0 {
|
||||
store, err := memory.NewQdrantStoreWithVectors(
|
||||
log,
|
||||
cfg.BaseURL,
|
||||
cfg.APIKey,
|
||||
cfg.Collection,
|
||||
vectors,
|
||||
"sparse_hash",
|
||||
timeout,
|
||||
)
|
||||
if err != nil {
|
||||
log.Error("qdrant named vectors init", slog.Any("error", err))
|
||||
os.Exit(1)
|
||||
}
|
||||
return store
|
||||
}
|
||||
store, err := memory.NewQdrantStore(
|
||||
log,
|
||||
cfg.BaseURL,
|
||||
cfg.APIKey,
|
||||
cfg.Collection,
|
||||
textDims,
|
||||
"sparse_hash",
|
||||
timeout,
|
||||
)
|
||||
if err != nil {
|
||||
log.Error("qdrant init", slog.Any("error", err))
|
||||
os.Exit(1)
|
||||
}
|
||||
return store
|
||||
}
|
||||
|
||||
func ensureAdminUser(ctx context.Context, log *slog.Logger, queries *dbsqlc.Queries, cfg config.Config) error {
|
||||
if queries == nil {
|
||||
return fmt.Errorf("db queries not configured")
|
||||
@@ -279,6 +286,14 @@ func (c *lazyLLMClient) Decide(ctx context.Context, req memory.DecideRequest) (m
|
||||
return client.Decide(ctx, req)
|
||||
}
|
||||
|
||||
func (c *lazyLLMClient) DetectLanguage(ctx context.Context, text string) (string, error) {
|
||||
client, err := c.resolve(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return client.DetectLanguage(ctx, text)
|
||||
}
|
||||
|
||||
func (c *lazyLLMClient) resolve(ctx context.Context) (memory.LLM, error) {
|
||||
if c.modelsService == nil || c.queries == nil {
|
||||
return nil, fmt.Errorf("models service not configured")
|
||||
|
||||
+8
-24
@@ -135,11 +135,7 @@ const docTemplate = `{
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "integer",
|
||||
"format": "int32"
|
||||
}
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"400": {
|
||||
@@ -827,12 +823,6 @@ const docTemplate = `{
|
||||
],
|
||||
"summary": "List memories",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Agent ID",
|
||||
"name": "agent_id",
|
||||
"in": "query"
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Run ID",
|
||||
@@ -3061,8 +3051,8 @@ const docTemplate = `{
|
||||
"handlers.memoryAddPayload": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_id": {
|
||||
"type": "string"
|
||||
"embedding_enabled": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"filters": {
|
||||
"type": "object",
|
||||
@@ -3092,9 +3082,6 @@ const docTemplate = `{
|
||||
"handlers.memoryDeleteAllPayload": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_id": {
|
||||
"type": "string"
|
||||
},
|
||||
"run_id": {
|
||||
"type": "string"
|
||||
}
|
||||
@@ -3103,9 +3090,6 @@ const docTemplate = `{
|
||||
"handlers.memoryEmbedUpsertPayload": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_id": {
|
||||
"type": "string"
|
||||
},
|
||||
"filters": {
|
||||
"type": "object",
|
||||
"additionalProperties": true
|
||||
@@ -3137,8 +3121,8 @@ const docTemplate = `{
|
||||
"handlers.memorySearchPayload": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_id": {
|
||||
"type": "string"
|
||||
"embedding_enabled": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"filters": {
|
||||
"type": "object",
|
||||
@@ -3267,9 +3251,6 @@ const docTemplate = `{
|
||||
"memory.MemoryItem": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agentId": {
|
||||
"type": "string"
|
||||
},
|
||||
"createdAt": {
|
||||
"type": "string"
|
||||
},
|
||||
@@ -3329,6 +3310,9 @@ const docTemplate = `{
|
||||
"memory.UpdateRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"embedding_enabled": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"memory": {
|
||||
"type": "string"
|
||||
},
|
||||
|
||||
+8
-24
@@ -126,11 +126,7 @@
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "integer",
|
||||
"format": "int32"
|
||||
}
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"400": {
|
||||
@@ -818,12 +814,6 @@
|
||||
],
|
||||
"summary": "List memories",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Agent ID",
|
||||
"name": "agent_id",
|
||||
"in": "query"
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Run ID",
|
||||
@@ -3052,8 +3042,8 @@
|
||||
"handlers.memoryAddPayload": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_id": {
|
||||
"type": "string"
|
||||
"embedding_enabled": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"filters": {
|
||||
"type": "object",
|
||||
@@ -3083,9 +3073,6 @@
|
||||
"handlers.memoryDeleteAllPayload": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_id": {
|
||||
"type": "string"
|
||||
},
|
||||
"run_id": {
|
||||
"type": "string"
|
||||
}
|
||||
@@ -3094,9 +3081,6 @@
|
||||
"handlers.memoryEmbedUpsertPayload": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_id": {
|
||||
"type": "string"
|
||||
},
|
||||
"filters": {
|
||||
"type": "object",
|
||||
"additionalProperties": true
|
||||
@@ -3128,8 +3112,8 @@
|
||||
"handlers.memorySearchPayload": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_id": {
|
||||
"type": "string"
|
||||
"embedding_enabled": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"filters": {
|
||||
"type": "object",
|
||||
@@ -3258,9 +3242,6 @@
|
||||
"memory.MemoryItem": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agentId": {
|
||||
"type": "string"
|
||||
},
|
||||
"createdAt": {
|
||||
"type": "string"
|
||||
},
|
||||
@@ -3320,6 +3301,9 @@
|
||||
"memory.UpdateRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"embedding_enabled": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"memory": {
|
||||
"type": "string"
|
||||
},
|
||||
|
||||
+7
-18
@@ -269,8 +269,8 @@ definitions:
|
||||
type: object
|
||||
handlers.memoryAddPayload:
|
||||
properties:
|
||||
agent_id:
|
||||
type: string
|
||||
embedding_enabled:
|
||||
type: boolean
|
||||
filters:
|
||||
additionalProperties: true
|
||||
type: object
|
||||
@@ -290,15 +290,11 @@ definitions:
|
||||
type: object
|
||||
handlers.memoryDeleteAllPayload:
|
||||
properties:
|
||||
agent_id:
|
||||
type: string
|
||||
run_id:
|
||||
type: string
|
||||
type: object
|
||||
handlers.memoryEmbedUpsertPayload:
|
||||
properties:
|
||||
agent_id:
|
||||
type: string
|
||||
filters:
|
||||
additionalProperties: true
|
||||
type: object
|
||||
@@ -320,8 +316,8 @@ definitions:
|
||||
type: object
|
||||
handlers.memorySearchPayload:
|
||||
properties:
|
||||
agent_id:
|
||||
type: string
|
||||
embedding_enabled:
|
||||
type: boolean
|
||||
filters:
|
||||
additionalProperties: true
|
||||
type: object
|
||||
@@ -405,8 +401,6 @@ definitions:
|
||||
type: object
|
||||
memory.MemoryItem:
|
||||
properties:
|
||||
agentId:
|
||||
type: string
|
||||
createdAt:
|
||||
type: string
|
||||
hash:
|
||||
@@ -446,6 +440,8 @@ definitions:
|
||||
type: object
|
||||
memory.UpdateRequest:
|
||||
properties:
|
||||
embedding_enabled:
|
||||
type: boolean
|
||||
memory:
|
||||
type: string
|
||||
memory_id:
|
||||
@@ -872,10 +868,7 @@ paths:
|
||||
"200":
|
||||
description: OK
|
||||
schema:
|
||||
items:
|
||||
format: int32
|
||||
type: integer
|
||||
type: array
|
||||
type: string
|
||||
"400":
|
||||
description: Bad Request
|
||||
schema:
|
||||
@@ -1365,10 +1358,6 @@ paths:
|
||||
description: 'List memories for a user via memory. Auth: Bearer JWT determines
|
||||
user_id (sub or user_id).'
|
||||
parameters:
|
||||
- description: Agent ID
|
||||
in: query
|
||||
name: agent_id
|
||||
type: string
|
||||
- description: Run ID
|
||||
in: query
|
||||
name: run_id
|
||||
|
||||
@@ -4,15 +4,19 @@ go 1.25.2
|
||||
|
||||
require (
|
||||
github.com/BurntSushi/toml v1.6.0
|
||||
github.com/blevesearch/bleve/v2 v2.5.7
|
||||
github.com/containerd/containerd/api v1.10.0
|
||||
github.com/containerd/containerd/v2 v2.2.1
|
||||
github.com/containerd/errdefs v1.0.0
|
||||
github.com/containerd/platforms v1.0.0-rc.2
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/jackc/pgx/v5 v5.8.0
|
||||
github.com/labstack/echo-jwt/v4 v4.4.0
|
||||
github.com/labstack/echo/v4 v4.15.0
|
||||
github.com/modelcontextprotocol/go-sdk v1.2.0
|
||||
github.com/opencontainers/go-digest v1.0.0
|
||||
github.com/opencontainers/image-spec v1.1.1
|
||||
github.com/opencontainers/runtime-spec v1.3.0
|
||||
github.com/qdrant/go-client v1.16.2
|
||||
github.com/robfig/cron/v3 v3.0.1
|
||||
@@ -25,13 +29,20 @@ require (
|
||||
github.com/KyleBanks/depth v1.2.1 // indirect
|
||||
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||
github.com/Microsoft/hcsshim v0.14.0-rc.1 // indirect
|
||||
github.com/bits-and-blooms/bitset v1.22.0 // indirect
|
||||
github.com/blevesearch/bleve_index_api v1.2.11 // indirect
|
||||
github.com/blevesearch/geo v0.2.4 // indirect
|
||||
github.com/blevesearch/go-porterstemmer v1.0.3 // indirect
|
||||
github.com/blevesearch/segment v0.9.1 // indirect
|
||||
github.com/blevesearch/snowballstem v0.9.0 // indirect
|
||||
github.com/blevesearch/stempel v0.2.0 // indirect
|
||||
github.com/blevesearch/upsidedown_store_api v1.0.2 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/containerd/cgroups/v3 v3.1.2 // indirect
|
||||
github.com/containerd/continuity v0.4.5 // indirect
|
||||
github.com/containerd/errdefs/pkg v0.3.0 // indirect
|
||||
github.com/containerd/fifo v1.1.0 // indirect
|
||||
github.com/containerd/log v0.1.0 // indirect
|
||||
github.com/containerd/platforms v1.0.0-rc.2 // indirect
|
||||
github.com/containerd/plugin v1.0.0 // indirect
|
||||
github.com/containerd/ttrpc v1.2.7 // indirect
|
||||
github.com/containerd/typeurl/v2 v2.2.3 // indirect
|
||||
@@ -51,7 +62,6 @@ require (
|
||||
github.com/go-openapi/swag/stringutils v0.25.4 // indirect
|
||||
github.com/go-openapi/swag/typeutils v0.25.4 // indirect
|
||||
github.com/go-openapi/swag/yamlutils v0.25.4 // indirect
|
||||
github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
@@ -59,9 +69,9 @@ require (
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/compress v1.18.3 // indirect
|
||||
github.com/labstack/gommon v0.4.2 // indirect
|
||||
github.com/larksuite/oapi-sdk-go/v3 v3.5.3 // indirect
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/moby/locker v1.0.1 // indirect
|
||||
@@ -70,8 +80,8 @@ require (
|
||||
github.com/moby/sys/signal v0.7.1 // indirect
|
||||
github.com/moby/sys/user v0.4.0 // indirect
|
||||
github.com/moby/sys/userns v0.1.0 // indirect
|
||||
github.com/opencontainers/go-digest v1.0.0 // indirect
|
||||
github.com/opencontainers/image-spec v1.1.1 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee // indirect
|
||||
github.com/opencontainers/selinux v1.13.1 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
|
||||
@@ -10,6 +10,24 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo
|
||||
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||
github.com/Microsoft/hcsshim v0.14.0-rc.1 h1:qAPXKwGOkVn8LlqgBN8GS0bxZ83hOJpcjxzmlQKxKsQ=
|
||||
github.com/Microsoft/hcsshim v0.14.0-rc.1/go.mod h1:hTKFGbnDtQb1wHiOWv4v0eN+7boSWAHyK/tNAaYZL0c=
|
||||
github.com/bits-and-blooms/bitset v1.22.0 h1:Tquv9S8+SGaS3EhyA+up3FXzmkhxPGjQQCkcs2uw7w4=
|
||||
github.com/bits-and-blooms/bitset v1.22.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8=
|
||||
github.com/blevesearch/bleve/v2 v2.5.7 h1:2d9YrL5zrX5EBBW++GOaEKjE+NPWeZGaX77IM26m1Z8=
|
||||
github.com/blevesearch/bleve/v2 v2.5.7/go.mod h1:yj0NlS7ocGC4VOSAedqDDMktdh2935v2CSWOCDMHdSA=
|
||||
github.com/blevesearch/bleve_index_api v1.2.11 h1:bXQ54kVuwP8hdrXUSOnvTQfgK0KI1+f9A0ITJT8tX1s=
|
||||
github.com/blevesearch/bleve_index_api v1.2.11/go.mod h1:rKQDl4u51uwafZxFrPD1R7xFOwKnzZW7s/LSeK4lgo0=
|
||||
github.com/blevesearch/geo v0.2.4 h1:ECIGQhw+QALCZaDcogRTNSJYQXRtC8/m8IKiA706cqk=
|
||||
github.com/blevesearch/geo v0.2.4/go.mod h1:K56Q33AzXt2YExVHGObtmRSFYZKYGv0JEN5mdacJJR8=
|
||||
github.com/blevesearch/go-porterstemmer v1.0.3 h1:GtmsqID0aZdCSNiY8SkuPJ12pD4jI+DdXTAn4YRcHCo=
|
||||
github.com/blevesearch/go-porterstemmer v1.0.3/go.mod h1:angGc5Ht+k2xhJdZi511LtmxuEf0OVpvUUNrwmM1P7M=
|
||||
github.com/blevesearch/segment v0.9.1 h1:+dThDy+Lvgj5JMxhmOVlgFfkUtZV2kw49xax4+jTfSU=
|
||||
github.com/blevesearch/segment v0.9.1/go.mod h1:zN21iLm7+GnBHWTao9I+Au/7MBiL8pPFtJBJTsk6kQw=
|
||||
github.com/blevesearch/snowballstem v0.9.0 h1:lMQ189YspGP6sXvZQ4WZ+MLawfV8wOmPoD/iWeNXm8s=
|
||||
github.com/blevesearch/snowballstem v0.9.0/go.mod h1:PivSj3JMc8WuaFkTSRDW2SlrulNWPl4ABg1tC/hlgLs=
|
||||
github.com/blevesearch/stempel v0.2.0 h1:CYzVPaScODMvgE9o+kf6D4RJ/VRomyi9uHF+PtB+Afc=
|
||||
github.com/blevesearch/stempel v0.2.0/go.mod h1:wjeTHqQv+nQdbPuJ/YcvOjTInA2EIc6Ks1FoSUzSLvc=
|
||||
github.com/blevesearch/upsidedown_store_api v1.0.2 h1:U53Q6YoWEARVLd1OYNc9kvhBMGZzVrdmaozG2MfoB+A=
|
||||
github.com/blevesearch/upsidedown_store_api v1.0.2/go.mod h1:M01mh3Gpfy56Ps/UXHjEO/knbqyQ1Oamg8If49gRwrQ=
|
||||
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
@@ -85,8 +103,6 @@ github.com/go-openapi/testify/enable/yaml/v2 v2.0.2 h1:0+Y41Pz1NkbTHz8NngxTuAXxE
|
||||
github.com/go-openapi/testify/enable/yaml/v2 v2.0.2/go.mod h1:kme83333GCtJQHXQ8UKX3IBZu6z8T5Dvy5+CW3NLUUg=
|
||||
github.com/go-openapi/testify/v2 v2.0.2 h1:X999g3jeLcoY8qctY/c/Z8iBHTbwLz7R2WXd6Ub6wls=
|
||||
github.com/go-openapi/testify/v2 v2.0.2/go.mod h1:HCPmvFFnheKK2BuwSA0TbbdxJ3I16pjwMkYkP4Ywn54=
|
||||
github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1 h1:wG8n/XJQ07TmjbITcGiUaOtXxdrINDz1b0J1w0SzqDc=
|
||||
github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1/go.mod h1:A2S0CWkNylc2phvKXWBBdD3K0iGnDBGbzRpISP2zBl8=
|
||||
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
||||
@@ -115,12 +131,12 @@ github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
||||
github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q=
|
||||
github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
|
||||
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
@@ -129,6 +145,8 @@ github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo=
|
||||
github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||
github.com/klauspost/compress v1.18.3 h1:9PJRvfbmTabkOX8moIpXPbMMbYN60bWImDDU7L+/6zw=
|
||||
@@ -143,8 +161,6 @@ github.com/labstack/echo/v4 v4.15.0 h1:hoRTKWcnR5STXZFe9BmYun9AMTNeSbjHi2vtDuADJ
|
||||
github.com/labstack/echo/v4 v4.15.0/go.mod h1:xmw1clThob0BSVRX1CRQkGQ/vjwcpOMjQZSZa9fKA/c=
|
||||
github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0=
|
||||
github.com/labstack/gommon v0.4.2/go.mod h1:QlUFxVM+SNXhDL/Z7YhocGIBYOiwB0mXm1+1bAPHPyU=
|
||||
github.com/larksuite/oapi-sdk-go/v3 v3.5.3 h1:xvf8Dv29kBXC5/DNDCLhHkAFW8l/0LlQJimO5Zn+JUk=
|
||||
github.com/larksuite/oapi-sdk-go/v3 v3.5.3/go.mod h1:ZEplY+kwuIrj/nqw5uSCINNATcH3KdxSN7y+UxYY5fI=
|
||||
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
|
||||
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
@@ -163,6 +179,12 @@ github.com/moby/sys/userns v0.1.0 h1:tVLXkFOxVu9A64/yh59slHVv9ahO9UIev4JZusOLG/g
|
||||
github.com/moby/sys/userns v0.1.0/go.mod h1:IHUYgu/kao6N8YZlp9Cf444ySSvCmDlmzUcYfDHOl28=
|
||||
github.com/modelcontextprotocol/go-sdk v1.2.0 h1:Y23co09300CEk8iZ/tMxIX1dVmKZkzoSBZOpJwUnc/s=
|
||||
github.com/modelcontextprotocol/go-sdk v1.2.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee h1:W5t00kpgFdJifH4BDsTlE89Zl93FEloxaWZfGcifgq8=
|
||||
github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||
@@ -267,6 +289,7 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
|
||||
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
|
||||
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
|
||||
|
||||
@@ -76,7 +76,7 @@ func (h *ChatHandler) Chat(c echo.Context) error {
|
||||
// @Accept json
|
||||
// @Produce text/event-stream
|
||||
// @Param request body chat.ChatRequest true "Chat request"
|
||||
// @Success 200 {object} chat.StreamChunk
|
||||
// @Success 200 {string} string
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /chat/stream [post]
|
||||
|
||||
+34
-40
@@ -18,22 +18,22 @@ type MemoryHandler struct {
|
||||
}
|
||||
|
||||
type memoryAddPayload struct {
|
||||
Message string `json:"message,omitempty"`
|
||||
Messages []memory.Message `json:"messages,omitempty"`
|
||||
AgentID string `json:"agent_id,omitempty"`
|
||||
RunID string `json:"run_id,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
Filters map[string]interface{} `json:"filters,omitempty"`
|
||||
Infer *bool `json:"infer,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Messages []memory.Message `json:"messages,omitempty"`
|
||||
RunID string `json:"run_id,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
Filters map[string]interface{} `json:"filters,omitempty"`
|
||||
Infer *bool `json:"infer,omitempty"`
|
||||
EmbeddingEnabled *bool `json:"embedding_enabled,omitempty"`
|
||||
}
|
||||
|
||||
type memorySearchPayload struct {
|
||||
Query string `json:"query"`
|
||||
AgentID string `json:"agent_id,omitempty"`
|
||||
RunID string `json:"run_id,omitempty"`
|
||||
Limit int `json:"limit,omitempty"`
|
||||
Filters map[string]interface{} `json:"filters,omitempty"`
|
||||
Sources []string `json:"sources,omitempty"`
|
||||
Query string `json:"query"`
|
||||
RunID string `json:"run_id,omitempty"`
|
||||
Limit int `json:"limit,omitempty"`
|
||||
Filters map[string]interface{} `json:"filters,omitempty"`
|
||||
Sources []string `json:"sources,omitempty"`
|
||||
EmbeddingEnabled *bool `json:"embedding_enabled,omitempty"`
|
||||
}
|
||||
|
||||
type memoryEmbedUpsertPayload struct {
|
||||
@@ -42,15 +42,13 @@ type memoryEmbedUpsertPayload struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Input memory.EmbedInput `json:"input"`
|
||||
Source string `json:"source,omitempty"`
|
||||
AgentID string `json:"agent_id,omitempty"`
|
||||
RunID string `json:"run_id,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
Filters map[string]interface{} `json:"filters,omitempty"`
|
||||
}
|
||||
|
||||
type memoryDeleteAllPayload struct {
|
||||
AgentID string `json:"agent_id,omitempty"`
|
||||
RunID string `json:"run_id,omitempty"`
|
||||
RunID string `json:"run_id,omitempty"`
|
||||
}
|
||||
|
||||
func NewMemoryHandler(log *slog.Logger, service *memory.Service) *MemoryHandler {
|
||||
@@ -74,7 +72,7 @@ func (h *MemoryHandler) Register(e *echo.Echo) {
|
||||
|
||||
func (h *MemoryHandler) checkService() error {
|
||||
if h.service == nil {
|
||||
return echo.NewHTTPError(http.StatusServiceUnavailable, "memory service not available: no embedding models configured")
|
||||
return echo.NewHTTPError(http.StatusServiceUnavailable, "memory service not available")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -109,7 +107,6 @@ func (h *MemoryHandler) EmbedUpsert(c echo.Context) error {
|
||||
Input: payload.Input,
|
||||
Source: payload.Source,
|
||||
UserID: userID,
|
||||
AgentID: payload.AgentID,
|
||||
RunID: payload.RunID,
|
||||
Metadata: payload.Metadata,
|
||||
Filters: payload.Filters,
|
||||
@@ -146,14 +143,14 @@ func (h *MemoryHandler) Add(c echo.Context) error {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
req := memory.AddRequest{
|
||||
Message: payload.Message,
|
||||
Messages: payload.Messages,
|
||||
UserID: userID,
|
||||
AgentID: payload.AgentID,
|
||||
RunID: payload.RunID,
|
||||
Metadata: payload.Metadata,
|
||||
Filters: payload.Filters,
|
||||
Infer: payload.Infer,
|
||||
Message: payload.Message,
|
||||
Messages: payload.Messages,
|
||||
UserID: userID,
|
||||
RunID: payload.RunID,
|
||||
Metadata: payload.Metadata,
|
||||
Filters: payload.Filters,
|
||||
Infer: payload.Infer,
|
||||
EmbeddingEnabled: payload.EmbeddingEnabled,
|
||||
}
|
||||
|
||||
resp, err := h.service.Add(c.Request().Context(), req)
|
||||
@@ -187,13 +184,13 @@ func (h *MemoryHandler) Search(c echo.Context) error {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
req := memory.SearchRequest{
|
||||
Query: payload.Query,
|
||||
UserID: userID,
|
||||
AgentID: payload.AgentID,
|
||||
RunID: payload.RunID,
|
||||
Limit: payload.Limit,
|
||||
Filters: payload.Filters,
|
||||
Sources: payload.Sources,
|
||||
Query: payload.Query,
|
||||
UserID: userID,
|
||||
RunID: payload.RunID,
|
||||
Limit: payload.Limit,
|
||||
Filters: payload.Filters,
|
||||
Sources: payload.Sources,
|
||||
EmbeddingEnabled: payload.EmbeddingEnabled,
|
||||
}
|
||||
|
||||
resp, err := h.service.Search(c.Request().Context(), req)
|
||||
@@ -281,7 +278,6 @@ func (h *MemoryHandler) Get(c echo.Context) error {
|
||||
// @Summary List memories
|
||||
// @Description List memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id).
|
||||
// @Tags memory
|
||||
// @Param agent_id query string false "Agent ID"
|
||||
// @Param run_id query string false "Run ID"
|
||||
// @Param limit query int false "Limit"
|
||||
// @Success 200 {object} memory.SearchResponse
|
||||
@@ -299,9 +295,8 @@ func (h *MemoryHandler) GetAll(c echo.Context) error {
|
||||
}
|
||||
|
||||
req := memory.GetAllRequest{
|
||||
UserID: userID,
|
||||
AgentID: c.QueryParam("agent_id"),
|
||||
RunID: c.QueryParam("run_id"),
|
||||
UserID: userID,
|
||||
RunID: c.QueryParam("run_id"),
|
||||
}
|
||||
if limit := c.QueryParam("limit"); limit != "" {
|
||||
var parsed int
|
||||
@@ -380,9 +375,8 @@ func (h *MemoryHandler) DeleteAll(c echo.Context) error {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
req := memory.DeleteAllRequest{
|
||||
UserID: userID,
|
||||
AgentID: payload.AgentID,
|
||||
RunID: payload.RunID,
|
||||
UserID: userID,
|
||||
RunID: payload.RunID,
|
||||
}
|
||||
|
||||
resp, err := h.service.DeleteAll(c.Request().Context(), req)
|
||||
|
||||
@@ -0,0 +1,252 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"log/slog"
|
||||
"math"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/blevesearch/bleve/v2/registry"
|
||||
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/analyzer/standard"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/ar"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/bg"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/ca"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/cjk"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/ckb"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/da"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/de"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/el"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/en"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/es"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/eu"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/fa"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/fi"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/fr"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/ga"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/gl"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/hi"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/hr"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/hu"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/hy"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/id"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/it"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/nl"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/no"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/pl"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/pt"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/ro"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/ru"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/sv"
|
||||
_ "github.com/blevesearch/bleve/v2/analysis/lang/tr"
|
||||
)
|
||||
|
||||
// BM25 tuning defaults and the size of the hashed sparse-vector index space.
const (
	// defaultBM25K1 is the k1 value installed by NewBM25Indexer.
	defaultBM25K1 = 1.2
	// defaultBM25B is the b value installed by NewBM25Indexer.
	defaultBM25B = 0.75

	// sparseDimBits is the number of hash bits kept for a sparse index.
	sparseDimBits = 20
	// sparseDimSize is the number of distinct sparse indices (2^sparseDimBits).
	sparseDimSize = 1 << sparseDimBits
	// sparseDimMask reduces a 32-bit hash to the range [0, sparseDimSize).
	sparseDimMask = sparseDimSize - 1
)
|
||||
|
||||
type BM25Indexer struct {
|
||||
cache *registry.Cache
|
||||
logger *slog.Logger
|
||||
k1 float64
|
||||
b float64
|
||||
|
||||
mu sync.RWMutex
|
||||
stats map[string]*bm25Stats
|
||||
}
|
||||
|
||||
// bm25Stats is the running corpus summary for one analyzer/language.
type bm25Stats struct {
	// DocCount is the number of documents currently counted.
	DocCount int
	// AvgDocLen is the mean token count over the counted documents.
	AvgDocLen float64
	// DocFreq maps a term to the number of counted documents containing it.
	DocFreq map[string]int
}
|
||||
|
||||
func NewBM25Indexer(log *slog.Logger) *BM25Indexer {
|
||||
if log == nil {
|
||||
log = slog.Default()
|
||||
}
|
||||
return &BM25Indexer{
|
||||
cache: registry.NewCache(),
|
||||
logger: log.With(slog.String("indexer", "bm25")),
|
||||
k1: defaultBM25K1,
|
||||
b: defaultBM25B,
|
||||
stats: map[string]*bm25Stats{},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BM25Indexer) TermFrequencies(lang, text string) (map[string]int, int, error) {
|
||||
analyzerName, err := b.normalizeAnalyzer(lang)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
analyzer, err := b.cache.AnalyzerNamed(analyzerName)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("bm25 analyzer %s: %w", analyzerName, err)
|
||||
}
|
||||
tokens := analyzer.Analyze([]byte(text))
|
||||
freq := map[string]int{}
|
||||
docLen := 0
|
||||
for _, token := range tokens {
|
||||
term := strings.TrimSpace(string(token.Term))
|
||||
if term == "" {
|
||||
continue
|
||||
}
|
||||
freq[term]++
|
||||
docLen++
|
||||
}
|
||||
return freq, docLen, nil
|
||||
}
|
||||
|
||||
func (b *BM25Indexer) AddDocument(lang string, termFreq map[string]int, docLen int) (indices []uint32, values []float32) {
|
||||
b.mu.Lock()
|
||||
stats := b.ensureStatsLocked(lang)
|
||||
b.updateStatsAddLocked(stats, termFreq, docLen)
|
||||
indices, values = b.buildDocVectorLocked(stats, termFreq, docLen)
|
||||
b.mu.Unlock()
|
||||
return indices, values
|
||||
}
|
||||
|
||||
func (b *BM25Indexer) RemoveDocument(lang string, termFreq map[string]int, docLen int) {
|
||||
b.mu.Lock()
|
||||
stats := b.ensureStatsLocked(lang)
|
||||
b.updateStatsRemoveLocked(stats, termFreq, docLen)
|
||||
b.mu.Unlock()
|
||||
}
|
||||
|
||||
func (b *BM25Indexer) BuildQueryVector(lang string, termFreq map[string]int) (indices []uint32, values []float32) {
|
||||
b.mu.RLock()
|
||||
stats := b.ensureStatsLocked(lang)
|
||||
indices, values = b.buildQueryVectorLocked(stats, termFreq)
|
||||
b.mu.RUnlock()
|
||||
return indices, values
|
||||
}
|
||||
|
||||
func (b *BM25Indexer) normalizeAnalyzer(lang string) (string, error) {
|
||||
normalized := strings.ToLower(strings.TrimSpace(lang))
|
||||
switch normalized {
|
||||
case "":
|
||||
return "standard", nil
|
||||
case "in":
|
||||
normalized = "id"
|
||||
}
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
func (b *BM25Indexer) ensureStatsLocked(lang string) *bm25Stats {
|
||||
name, _ := b.normalizeAnalyzer(lang)
|
||||
stats := b.stats[name]
|
||||
if stats == nil {
|
||||
stats = &bm25Stats{
|
||||
DocFreq: map[string]int{},
|
||||
}
|
||||
b.stats[name] = stats
|
||||
}
|
||||
return stats
|
||||
}
|
||||
|
||||
func (b *BM25Indexer) updateStatsAddLocked(stats *bm25Stats, termFreq map[string]int, docLen int) {
|
||||
totalDocs := stats.DocCount
|
||||
stats.DocCount++
|
||||
totalLen := stats.AvgDocLen * float64(totalDocs)
|
||||
stats.AvgDocLen = (totalLen + float64(docLen)) / float64(stats.DocCount)
|
||||
for term := range termFreq {
|
||||
stats.DocFreq[term]++
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BM25Indexer) updateStatsRemoveLocked(stats *bm25Stats, termFreq map[string]int, docLen int) {
|
||||
if stats.DocCount <= 0 {
|
||||
return
|
||||
}
|
||||
totalDocs := stats.DocCount
|
||||
totalLen := stats.AvgDocLen * float64(totalDocs)
|
||||
stats.DocCount--
|
||||
if stats.DocCount > 0 {
|
||||
stats.AvgDocLen = (totalLen - float64(docLen)) / float64(stats.DocCount)
|
||||
} else {
|
||||
stats.AvgDocLen = 0
|
||||
}
|
||||
for term := range termFreq {
|
||||
if stats.DocFreq[term] > 1 {
|
||||
stats.DocFreq[term]--
|
||||
} else {
|
||||
delete(stats.DocFreq, term)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BM25Indexer) buildDocVectorLocked(stats *bm25Stats, termFreq map[string]int, docLen int) ([]uint32, []float32) {
|
||||
if stats.DocCount == 0 || docLen == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
avgDocLen := stats.AvgDocLen
|
||||
if avgDocLen <= 0 {
|
||||
avgDocLen = 1
|
||||
}
|
||||
weights := map[uint32]float32{}
|
||||
for term, tf := range termFreq {
|
||||
df := stats.DocFreq[term]
|
||||
if df == 0 {
|
||||
continue
|
||||
}
|
||||
idf := math.Log(1 + (float64(stats.DocCount)-float64(df)+0.5)/(float64(df)+0.5))
|
||||
numerator := float64(tf) * (b.k1 + 1)
|
||||
denominator := float64(tf) + b.k1*(1-b.b+b.b*float64(docLen)/avgDocLen)
|
||||
tfNorm := numerator / denominator
|
||||
weight := float32(tfNorm * idf)
|
||||
if weight == 0 {
|
||||
continue
|
||||
}
|
||||
index := termHash(term)
|
||||
weights[index] += weight
|
||||
}
|
||||
return sparseWeightsToVector(weights)
|
||||
}
|
||||
|
||||
func (b *BM25Indexer) buildQueryVectorLocked(stats *bm25Stats, termFreq map[string]int) ([]uint32, []float32) {
|
||||
if stats.DocCount == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
weights := map[uint32]float32{}
|
||||
for term, tf := range termFreq {
|
||||
if stats.DocFreq[term] == 0 {
|
||||
continue
|
||||
}
|
||||
weight := float32(tf)
|
||||
if weight == 0 {
|
||||
continue
|
||||
}
|
||||
index := termHash(term)
|
||||
weights[index] += weight
|
||||
}
|
||||
return sparseWeightsToVector(weights)
|
||||
}
|
||||
|
||||
// sparseWeightsToVector flattens an index->weight map into two parallel
// slices ordered by ascending index. An empty or nil map yields nil slices.
func sparseWeightsToVector(weights map[uint32]float32) ([]uint32, []float32) {
	n := len(weights)
	if n == 0 {
		return nil, nil
	}

	indices := make([]uint32, n)
	pos := 0
	for idx := range weights {
		indices[pos] = idx
		pos++
	}
	// Map iteration order is random; sort so the output is deterministic.
	sort.Slice(indices, func(a, c int) bool { return indices[a] < indices[c] })

	values := make([]float32, n)
	for i, idx := range indices {
		values[i] = weights[idx]
	}
	return indices, values
}
|
||||
|
||||
func termHash(term string) uint32 {
|
||||
hasher := fnv.New32a()
|
||||
_, _ = hasher.Write([]byte(term))
|
||||
return hasher.Sum32() & sparseDimMask
|
||||
}
|
||||
@@ -29,7 +29,7 @@ func NewLLMClient(log *slog.Logger, baseURL, apiKey, model string, timeout time.
|
||||
}
|
||||
baseURL = strings.TrimRight(baseURL, "/")
|
||||
if model == "" {
|
||||
model = "gpt-4.1-nano-2025-04-14"
|
||||
model = "gpt-4.1-nano"
|
||||
}
|
||||
if timeout <= 0 {
|
||||
timeout = 10 * time.Second
|
||||
@@ -119,6 +119,31 @@ func (c *LLMClient) Decide(ctx context.Context, req DecideRequest) (DecideRespon
|
||||
return DecideResponse{Actions: actions}, nil
|
||||
}
|
||||
|
||||
func (c *LLMClient) DetectLanguage(ctx context.Context, text string) (string, error) {
|
||||
if strings.TrimSpace(text) == "" {
|
||||
return "", fmt.Errorf("text is required")
|
||||
}
|
||||
systemPrompt, userPrompt := getLanguageDetectionMessages(text)
|
||||
content, err := c.callChat(ctx, []chatMessage{
|
||||
{Role: "system", Content: systemPrompt},
|
||||
{Role: "user", Content: userPrompt},
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
var parsed struct {
|
||||
Language string `json:"language"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(removeCodeBlocks(content)), &parsed); err != nil {
|
||||
return "", err
|
||||
}
|
||||
lang := strings.ToLower(strings.TrimSpace(parsed.Language))
|
||||
if !isAllowedLanguageCode(lang) {
|
||||
return "", fmt.Errorf("unsupported language code: %s", lang)
|
||||
}
|
||||
return lang, nil
|
||||
}
|
||||
|
||||
type chatMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
@@ -247,3 +272,14 @@ func normalizeMemoryItems(value interface{}) []map[string]interface{} {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// isAllowedLanguageCode reports whether code, after trimming and
// lower-casing, is one of the language codes accepted from language detection.
func isAllowedLanguageCode(code string) bool {
	allowed := [...]string{
		"ar", "bg", "ca", "cjk", "ckb", "da", "de", "el", "en", "es", "eu",
		"fa", "fi", "fr", "ga", "gl", "hi", "hr", "hu", "hy", "id", "in",
		"it", "nl", "no", "pl", "pt", "ro", "ru", "sv", "tr",
	}
	normalized := strings.ToLower(strings.TrimSpace(code))
	for _, candidate := range allowed {
		if candidate == normalized {
			return true
		}
	}
	return false
}
|
||||
|
||||
@@ -106,6 +106,20 @@ Follow the instruction mentioned below:
|
||||
Do not return anything except the JSON format.`, toJSON(retrievedOldMemory), toJSON(newRetrievedFacts), "```json", "```")
|
||||
}
|
||||
|
||||
// getLanguageDetectionMessages returns the system and user prompts used to
// classify the language of text. The system prompt is fixed; the user prompt
// is text prefixed with "Text:" and a newline.
func getLanguageDetectionMessages(text string) (string, string) {
	const systemPrompt = `You are a language classifier for the given input text.` + "\n" +
		`Return a JSON object with a single key "language" whose value is one of the allowed codes.` + "\n" +
		`Allowed codes: ar, bg, ca, cjk, ckb, da, de, el, en, es, eu, fa, fi, fr, ga, gl, hi, hr, hu, hy, id, in, it, nl, no, pl, pt, ro, ru, sv, tr.` + "\n" +
		`Use "cjk" for Chinese/Japanese/Korean text, ckb=Kurdish(Sorani), ga=Irish(Gaelic), gl=Galician, eu=Basque, hy=Armenian, fa=Persian, hr=Croatian, hu=Hungarian, ro=Romanian, bg=Bulgarian. If unsure between id/in, use id.` + "\n" +
		`If multiple languages appear, choose the dominant language.` + "\n" +
		`Do not include any extra keys, comments, or formatting. Output must be valid JSON only.` + "\n" +
		`If the text is Chinese, Japanese, or Korean, output exactly {"language":"cjk"}.` + "\n" +
		`Never output "zh", "zh-cn", "zh-tw", "ja", "ko", or any code not in the allowed list.` + "\n" +
		`Before finalizing, verify the value is one of the allowed codes.`

	return systemPrompt, fmt.Sprintf("Text:\n%s", text)
}
|
||||
|
||||
// parseMessages concatenates messages into one string, separating
// consecutive entries with a single newline.
func parseMessages(messages []string) string {
	var sb strings.Builder
	for i, msg := range messages {
		if i > 0 {
			sb.WriteByte('\n')
		}
		sb.WriteString(msg)
	}
	return sb.String()
}
|
||||
|
||||
+226
-52
@@ -12,34 +12,47 @@ import (
|
||||
"github.com/qdrant/go-client/qdrant"
|
||||
)
|
||||
|
||||
const (
|
||||
sparseHashVectorName = "sparse_hash"
|
||||
sparseVocabVectorName = "sparse_vocab"
|
||||
)
|
||||
|
||||
type QdrantStore struct {
|
||||
client *qdrant.Client
|
||||
collection string
|
||||
dimension int
|
||||
baseURL string
|
||||
apiKey string
|
||||
timeout time.Duration
|
||||
logger *slog.Logger
|
||||
vectorNames map[string]int
|
||||
usesNamedVectors bool
|
||||
client *qdrant.Client
|
||||
collection string
|
||||
dimension int
|
||||
baseURL string
|
||||
apiKey string
|
||||
timeout time.Duration
|
||||
logger *slog.Logger
|
||||
vectorNames map[string]int
|
||||
usesNamedVectors bool
|
||||
sparseVectorName string
|
||||
usesSparseVectors bool
|
||||
}
|
||||
|
||||
type qdrantPoint struct {
|
||||
ID string `json:"id"`
|
||||
Vector []float32 `json:"vector"`
|
||||
VectorName string `json:"vector_name,omitempty"`
|
||||
Payload map[string]interface{} `json:"payload,omitempty"`
|
||||
ID string `json:"id"`
|
||||
Vector []float32 `json:"vector"`
|
||||
VectorName string `json:"vector_name,omitempty"`
|
||||
SparseIndices []uint32 `json:"sparse_indices,omitempty"`
|
||||
SparseValues []float32 `json:"sparse_values,omitempty"`
|
||||
SparseVectorName string `json:"sparse_vector_name,omitempty"`
|
||||
Payload map[string]interface{} `json:"payload,omitempty"`
|
||||
}
|
||||
|
||||
func NewQdrantStore(log *slog.Logger, baseURL, apiKey, collection string, dimension int, timeout time.Duration) (*QdrantStore, error) {
|
||||
func NewQdrantStore(log *slog.Logger, baseURL, apiKey, collection string, dimension int, sparseVectorName string, timeout time.Duration) (*QdrantStore, error) {
|
||||
host, port, useTLS, err := parseQdrantEndpoint(baseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(sparseVectorName) == "" {
|
||||
sparseVectorName = sparseHashVectorName
|
||||
}
|
||||
if collection == "" {
|
||||
collection = "memory"
|
||||
}
|
||||
if dimension <= 0 {
|
||||
if dimension <= 0 && strings.TrimSpace(sparseVectorName) == "" {
|
||||
dimension = 1536
|
||||
}
|
||||
|
||||
@@ -55,13 +68,15 @@ func NewQdrantStore(log *slog.Logger, baseURL, apiKey, collection string, dimens
|
||||
}
|
||||
|
||||
store := &QdrantStore{
|
||||
client: client,
|
||||
collection: collection,
|
||||
dimension: dimension,
|
||||
baseURL: baseURL,
|
||||
apiKey: apiKey,
|
||||
timeout: timeoutOrDefault(timeout),
|
||||
logger: log.With(slog.String("store", "qdrant")),
|
||||
client: client,
|
||||
collection: collection,
|
||||
dimension: dimension,
|
||||
baseURL: baseURL,
|
||||
apiKey: apiKey,
|
||||
timeout: timeoutOrDefault(timeout),
|
||||
logger: log.With(slog.String("store", "qdrant")),
|
||||
sparseVectorName: strings.TrimSpace(sparseVectorName),
|
||||
usesSparseVectors: strings.TrimSpace(sparseVectorName) != "",
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeoutOrDefault(timeout))
|
||||
@@ -73,14 +88,17 @@ func NewQdrantStore(log *slog.Logger, baseURL, apiKey, collection string, dimens
|
||||
}
|
||||
|
||||
func (s *QdrantStore) NewSibling(collection string, dimension int) (*QdrantStore, error) {
|
||||
return NewQdrantStore(s.logger, s.baseURL, s.apiKey, collection, dimension, s.timeout)
|
||||
return NewQdrantStore(s.logger, s.baseURL, s.apiKey, collection, dimension, s.sparseVectorName, s.timeout)
|
||||
}
|
||||
|
||||
func NewQdrantStoreWithVectors(log *slog.Logger, baseURL, apiKey, collection string, vectors map[string]int, timeout time.Duration) (*QdrantStore, error) {
|
||||
func NewQdrantStoreWithVectors(log *slog.Logger, baseURL, apiKey, collection string, vectors map[string]int, sparseVectorName string, timeout time.Duration) (*QdrantStore, error) {
|
||||
host, port, useTLS, err := parseQdrantEndpoint(baseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(sparseVectorName) == "" {
|
||||
sparseVectorName = sparseHashVectorName
|
||||
}
|
||||
if collection == "" {
|
||||
collection = "memory"
|
||||
}
|
||||
@@ -100,14 +118,16 @@ func NewQdrantStoreWithVectors(log *slog.Logger, baseURL, apiKey, collection str
|
||||
}
|
||||
|
||||
store := &QdrantStore{
|
||||
client: client,
|
||||
collection: collection,
|
||||
baseURL: baseURL,
|
||||
apiKey: apiKey,
|
||||
timeout: timeoutOrDefault(timeout),
|
||||
logger: log.With(slog.String("store", "qdrant")),
|
||||
vectorNames: vectors,
|
||||
usesNamedVectors: true,
|
||||
client: client,
|
||||
collection: collection,
|
||||
baseURL: baseURL,
|
||||
apiKey: apiKey,
|
||||
timeout: timeoutOrDefault(timeout),
|
||||
logger: log.With(slog.String("store", "qdrant")),
|
||||
vectorNames: vectors,
|
||||
usesNamedVectors: true,
|
||||
sparseVectorName: strings.TrimSpace(sparseVectorName),
|
||||
usesSparseVectors: strings.TrimSpace(sparseVectorName) != "",
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeoutOrDefault(timeout))
|
||||
@@ -129,12 +149,31 @@ func (s *QdrantStore) Upsert(ctx context.Context, points []qdrantPoint) error {
|
||||
return err
|
||||
}
|
||||
var vectors *qdrant.Vectors
|
||||
if point.VectorName != "" && s.usesNamedVectors {
|
||||
vectors = qdrant.NewVectorsMap(map[string]*qdrant.Vector{
|
||||
point.VectorName: qdrant.NewVectorDense(point.Vector),
|
||||
})
|
||||
} else {
|
||||
vectors = qdrant.NewVectorsDense(point.Vector)
|
||||
vectorMap := map[string]*qdrant.Vector{}
|
||||
if len(point.Vector) > 0 {
|
||||
if point.VectorName != "" && s.usesNamedVectors {
|
||||
vectorMap[point.VectorName] = qdrant.NewVectorDense(point.Vector)
|
||||
} else if !s.usesNamedVectors && len(point.SparseIndices) == 0 {
|
||||
vectors = qdrant.NewVectorsDense(point.Vector)
|
||||
} else if point.VectorName != "" {
|
||||
vectorMap[point.VectorName] = qdrant.NewVectorDense(point.Vector)
|
||||
}
|
||||
}
|
||||
if len(point.SparseIndices) > 0 && len(point.SparseValues) > 0 {
|
||||
sparseName := strings.TrimSpace(point.SparseVectorName)
|
||||
if sparseName == "" {
|
||||
sparseName = s.sparseVectorName
|
||||
}
|
||||
if sparseName == "" {
|
||||
return fmt.Errorf("sparse vector name is required")
|
||||
}
|
||||
vectorMap[sparseName] = qdrant.NewVectorSparse(point.SparseIndices, point.SparseValues)
|
||||
}
|
||||
if vectors == nil {
|
||||
if len(vectorMap) == 0 {
|
||||
return fmt.Errorf("no vector data provided for point %s", point.ID)
|
||||
}
|
||||
vectors = qdrant.NewVectorsMap(vectorMap)
|
||||
}
|
||||
qPoints = append(qPoints, &qdrant.PointStruct{
|
||||
Id: qdrant.NewIDUUID(point.ID),
|
||||
@@ -183,6 +222,41 @@ func (s *QdrantStore) Search(ctx context.Context, vector []float32, limit int, f
|
||||
return points, scores, nil
|
||||
}
|
||||
|
||||
func (s *QdrantStore) SearchSparse(ctx context.Context, indices []uint32, values []float32, limit int, filters map[string]interface{}) ([]qdrantPoint, []float64, error) {
|
||||
if limit <= 0 {
|
||||
limit = 10
|
||||
}
|
||||
if len(indices) == 0 || len(values) == 0 {
|
||||
return nil, nil, nil
|
||||
}
|
||||
if s.sparseVectorName == "" {
|
||||
return nil, nil, fmt.Errorf("sparse vector name not configured")
|
||||
}
|
||||
filter := buildQdrantFilter(filters)
|
||||
using := qdrant.PtrOf(s.sparseVectorName)
|
||||
results, err := s.client.Query(ctx, &qdrant.QueryPoints{
|
||||
CollectionName: s.collection,
|
||||
Query: qdrant.NewQuerySparse(indices, values),
|
||||
Using: using,
|
||||
Limit: qdrant.PtrOf(uint64(limit)),
|
||||
Filter: filter,
|
||||
WithPayload: qdrant.NewWithPayload(true),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
points := make([]qdrantPoint, 0, len(results))
|
||||
scores := make([]float64, 0, len(results))
|
||||
for _, scored := range results {
|
||||
points = append(points, qdrantPoint{
|
||||
ID: pointIDToString(scored.GetId()),
|
||||
Payload: valueMapToInterface(scored.GetPayload()),
|
||||
})
|
||||
scores = append(scores, float64(scored.GetScore()))
|
||||
}
|
||||
return points, scores, nil
|
||||
}
|
||||
|
||||
func (s *QdrantStore) SearchBySources(ctx context.Context, vector []float32, limit int, filters map[string]interface{}, sources []string, vectorName string) (map[string][]qdrantPoint, map[string][]float64, error) {
|
||||
pointsBySource := make(map[string][]qdrantPoint, len(sources))
|
||||
scoresBySource := make(map[string][]float64, len(sources))
|
||||
@@ -204,6 +278,27 @@ func (s *QdrantStore) SearchBySources(ctx context.Context, vector []float32, lim
|
||||
return pointsBySource, scoresBySource, nil
|
||||
}
|
||||
|
||||
func (s *QdrantStore) SearchSparseBySources(ctx context.Context, indices []uint32, values []float32, limit int, filters map[string]interface{}, sources []string) (map[string][]qdrantPoint, map[string][]float64, error) {
|
||||
pointsBySource := make(map[string][]qdrantPoint, len(sources))
|
||||
scoresBySource := make(map[string][]float64, len(sources))
|
||||
if len(sources) == 0 {
|
||||
return pointsBySource, scoresBySource, nil
|
||||
}
|
||||
for _, source := range sources {
|
||||
merged := cloneFilters(filters)
|
||||
if source != "" {
|
||||
merged["source"] = source
|
||||
}
|
||||
points, scores, err := s.SearchSparse(ctx, indices, values, limit, merged)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
pointsBySource[source] = points
|
||||
scoresBySource[source] = scores
|
||||
}
|
||||
return pointsBySource, scoresBySource, nil
|
||||
}
|
||||
|
||||
func (s *QdrantStore) Get(ctx context.Context, id string) (*qdrantPoint, error) {
|
||||
result, err := s.client.Get(ctx, &qdrant.GetPoints{
|
||||
CollectionName: s.collection,
|
||||
@@ -257,6 +352,31 @@ func (s *QdrantStore) List(ctx context.Context, limit int, filters map[string]in
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *QdrantStore) Scroll(ctx context.Context, limit int, filters map[string]interface{}, offset *qdrant.PointId) ([]qdrantPoint, *qdrant.PointId, error) {
|
||||
if limit <= 0 {
|
||||
limit = 100
|
||||
}
|
||||
filter := buildQdrantFilter(filters)
|
||||
points, nextOffset, err := s.client.ScrollAndOffset(ctx, &qdrant.ScrollPoints{
|
||||
CollectionName: s.collection,
|
||||
Limit: qdrant.PtrOf(uint32(limit)),
|
||||
Filter: filter,
|
||||
Offset: offset,
|
||||
WithPayload: qdrant.NewWithPayload(true),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
result := make([]qdrantPoint, 0, len(points))
|
||||
for _, point := range points {
|
||||
result = append(result, qdrantPoint{
|
||||
ID: pointIDToString(point.GetId()),
|
||||
Payload: valueMapToInterface(point.GetPayload()),
|
||||
})
|
||||
}
|
||||
return result, nextOffset, nil
|
||||
}
|
||||
|
||||
func (s *QdrantStore) DeleteAll(ctx context.Context, filters map[string]interface{}) error {
|
||||
filter := buildQdrantFilter(filters)
|
||||
if filter == nil {
|
||||
@@ -278,6 +398,7 @@ func (s *QdrantStore) ensureCollection(ctx context.Context, vectors map[string]i
|
||||
if exists {
|
||||
return s.refreshCollectionSchema(ctx, vectors)
|
||||
}
|
||||
var vectorsConfig *qdrant.VectorsConfig
|
||||
if len(vectors) > 0 {
|
||||
params := make(map[string]*qdrant.VectorParams, len(vectors))
|
||||
for name, dim := range vectors {
|
||||
@@ -286,17 +407,24 @@ func (s *QdrantStore) ensureCollection(ctx context.Context, vectors map[string]i
|
||||
Distance: qdrant.Distance_Cosine,
|
||||
}
|
||||
}
|
||||
return s.client.CreateCollection(ctx, &qdrant.CreateCollection{
|
||||
CollectionName: s.collection,
|
||||
VectorsConfig: qdrant.NewVectorsConfigMap(params),
|
||||
vectorsConfig = qdrant.NewVectorsConfigMap(params)
|
||||
} else if s.dimension > 0 {
|
||||
vectorsConfig = qdrant.NewVectorsConfig(&qdrant.VectorParams{
|
||||
Size: uint64(s.dimension),
|
||||
Distance: qdrant.Distance_Cosine,
|
||||
})
|
||||
}
|
||||
var sparseConfig *qdrant.SparseVectorConfig
|
||||
if s.sparseVectorName != "" {
|
||||
sparseConfig = qdrant.NewSparseVectorsConfig(map[string]*qdrant.SparseVectorParams{
|
||||
s.sparseVectorName: {Modifier: qdrant.PtrOf(qdrant.Modifier_None)},
|
||||
sparseVocabVectorName: {Modifier: qdrant.PtrOf(qdrant.Modifier_None)},
|
||||
})
|
||||
}
|
||||
return s.client.CreateCollection(ctx, &qdrant.CreateCollection{
|
||||
CollectionName: s.collection,
|
||||
VectorsConfig: qdrant.NewVectorsConfig(&qdrant.VectorParams{
|
||||
Size: uint64(s.dimension),
|
||||
Distance: qdrant.Distance_Cosine,
|
||||
}),
|
||||
CollectionName: s.collection,
|
||||
VectorsConfig: vectorsConfig,
|
||||
SparseVectorsConfig: sparseConfig,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -306,11 +434,12 @@ func (s *QdrantStore) refreshCollectionSchema(ctx context.Context, vectors map[s
|
||||
return err
|
||||
}
|
||||
config := info.GetConfig()
|
||||
if config == nil || config.GetParams() == nil || config.GetParams().GetVectorsConfig() == nil {
|
||||
if config == nil || config.GetParams() == nil {
|
||||
return nil
|
||||
}
|
||||
vectorsConfig := config.GetParams().GetVectorsConfig()
|
||||
if vectorsConfig.GetParamsMap() != nil {
|
||||
params := config.GetParams()
|
||||
vectorsConfig := params.GetVectorsConfig()
|
||||
if vectorsConfig != nil && vectorsConfig.GetParamsMap() != nil {
|
||||
s.usesNamedVectors = true
|
||||
s.vectorNames = map[string]int{}
|
||||
for name, vec := range vectorsConfig.GetParamsMap().GetMap() {
|
||||
@@ -319,7 +448,7 @@ func (s *QdrantStore) refreshCollectionSchema(ctx context.Context, vectors map[s
|
||||
}
|
||||
}
|
||||
if len(vectors) == 0 {
|
||||
return nil
|
||||
goto sparseCheck
|
||||
}
|
||||
for name, dim := range vectors {
|
||||
if existing, ok := s.vectorNames[name]; ok && existing == dim {
|
||||
@@ -327,13 +456,58 @@ func (s *QdrantStore) refreshCollectionSchema(ctx context.Context, vectors map[s
|
||||
}
|
||||
return fmt.Errorf("collection missing vector %s (dim %d); migration required", name, dim)
|
||||
}
|
||||
}
|
||||
if vectorsConfig == nil || vectorsConfig.GetParamsMap() == nil {
|
||||
s.usesNamedVectors = false
|
||||
s.vectorNames = nil
|
||||
}
|
||||
|
||||
sparseCheck:
|
||||
sparseConfig := params.GetSparseVectorsConfig()
|
||||
if s.sparseVectorName != "" {
|
||||
needsUpdate := false
|
||||
if sparseConfig == nil || len(sparseConfig.GetMap()) == 0 {
|
||||
needsUpdate = true
|
||||
} else {
|
||||
if _, ok := sparseConfig.GetMap()[s.sparseVectorName]; !ok {
|
||||
needsUpdate = true
|
||||
}
|
||||
if _, ok := sparseConfig.GetMap()[sparseVocabVectorName]; !ok {
|
||||
needsUpdate = true
|
||||
}
|
||||
}
|
||||
if needsUpdate {
|
||||
if err := s.ensureSparseVectors(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
s.usesSparseVectors = true
|
||||
return nil
|
||||
}
|
||||
s.usesNamedVectors = false
|
||||
s.vectorNames = nil
|
||||
if sparseConfig != nil && len(sparseConfig.GetMap()) > 0 {
|
||||
s.usesSparseVectors = true
|
||||
for name := range sparseConfig.GetMap() {
|
||||
s.sparseVectorName = name
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *QdrantStore) ensureSparseVectors(ctx context.Context) error {
|
||||
if s.sparseVectorName == "" {
|
||||
return nil
|
||||
}
|
||||
err := s.client.UpdateCollection(ctx, &qdrant.UpdateCollection{
|
||||
CollectionName: s.collection,
|
||||
SparseVectorsConfig: qdrant.NewSparseVectorsConfig(map[string]*qdrant.SparseVectorParams{
|
||||
s.sparseVectorName: {Modifier: qdrant.PtrOf(qdrant.Modifier_None)},
|
||||
sparseVocabVectorName: {Modifier: qdrant.PtrOf(qdrant.Modifier_None)},
|
||||
}),
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func parseQdrantEndpoint(endpoint string) (string, int, bool, error) {
|
||||
if endpoint == "" {
|
||||
return "127.0.0.1", 6334, false, nil
|
||||
|
||||
+322
-78
@@ -12,6 +12,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/qdrant/go-client/qdrant"
|
||||
|
||||
"github.com/memohai/memoh/internal/embeddings"
|
||||
)
|
||||
@@ -21,17 +22,19 @@ type Service struct {
|
||||
embedder embeddings.Embedder
|
||||
store *QdrantStore
|
||||
resolver *embeddings.Resolver
|
||||
bm25 *BM25Indexer
|
||||
logger *slog.Logger
|
||||
defaultTextModelID string
|
||||
defaultMultimodalModelID string
|
||||
}
|
||||
|
||||
func NewService(log *slog.Logger, llm LLM, embedder embeddings.Embedder, store *QdrantStore, resolver *embeddings.Resolver, defaultTextModelID, defaultMultimodalModelID string) *Service {
|
||||
func NewService(log *slog.Logger, llm LLM, embedder embeddings.Embedder, store *QdrantStore, resolver *embeddings.Resolver, bm25 *BM25Indexer, defaultTextModelID, defaultMultimodalModelID string) *Service {
|
||||
return &Service{
|
||||
llm: llm,
|
||||
embedder: embedder,
|
||||
store: store,
|
||||
resolver: resolver,
|
||||
bm25: bm25,
|
||||
logger: log.With(slog.String("service", "memory")),
|
||||
defaultTextModelID: defaultTextModelID,
|
||||
defaultMultimodalModelID: defaultMultimodalModelID,
|
||||
@@ -42,15 +45,16 @@ func (s *Service) Add(ctx context.Context, req AddRequest) (SearchResponse, erro
|
||||
if req.Message == "" && len(req.Messages) == 0 {
|
||||
return SearchResponse{}, fmt.Errorf("message or messages is required")
|
||||
}
|
||||
if req.UserID == "" && req.AgentID == "" && req.RunID == "" {
|
||||
return SearchResponse{}, fmt.Errorf("user_id, agent_id or run_id is required")
|
||||
if req.UserID == "" {
|
||||
return SearchResponse{}, fmt.Errorf("user_id is required")
|
||||
}
|
||||
|
||||
messages := normalizeMessages(req)
|
||||
filters := buildFilters(req)
|
||||
|
||||
embeddingEnabled := req.EmbeddingEnabled != nil && *req.EmbeddingEnabled
|
||||
if req.Infer != nil && !*req.Infer {
|
||||
return s.addRawMessages(ctx, messages, filters, req.Metadata)
|
||||
return s.addRawMessages(ctx, messages, filters, req.Metadata, embeddingEnabled)
|
||||
}
|
||||
|
||||
extractResp, err := s.llm.Extract(ctx, ExtractRequest{
|
||||
@@ -95,7 +99,7 @@ func (s *Service) Add(ctx context.Context, req AddRequest) (SearchResponse, erro
|
||||
for _, action := range actions {
|
||||
switch strings.ToUpper(action.Event) {
|
||||
case "ADD":
|
||||
item, err := s.applyAdd(ctx, action.Text, filters, req.Metadata)
|
||||
item, err := s.applyAdd(ctx, action.Text, filters, req.Metadata, embeddingEnabled)
|
||||
if err != nil {
|
||||
return SearchResponse{}, err
|
||||
}
|
||||
@@ -104,7 +108,7 @@ func (s *Service) Add(ctx context.Context, req AddRequest) (SearchResponse, erro
|
||||
})
|
||||
results = append(results, item)
|
||||
case "UPDATE":
|
||||
item, err := s.applyUpdate(ctx, action.ID, action.Text, filters, req.Metadata)
|
||||
item, err := s.applyUpdate(ctx, action.ID, action.Text, filters, req.Metadata, embeddingEnabled)
|
||||
if err != nil {
|
||||
return SearchResponse{}, err
|
||||
}
|
||||
@@ -134,19 +138,19 @@ func (s *Service) Search(ctx context.Context, req SearchRequest) (SearchResponse
|
||||
if strings.TrimSpace(req.Query) == "" {
|
||||
return SearchResponse{}, fmt.Errorf("query is required")
|
||||
}
|
||||
if s.store == nil {
|
||||
return SearchResponse{}, fmt.Errorf("qdrant store not configured")
|
||||
}
|
||||
filters := buildSearchFilters(req)
|
||||
modality := ""
|
||||
if raw, ok := filters["modality"].(string); ok {
|
||||
modality = strings.ToLower(strings.TrimSpace(raw))
|
||||
}
|
||||
|
||||
var (
|
||||
vector []float32
|
||||
store *QdrantStore
|
||||
vectorName string
|
||||
err error
|
||||
)
|
||||
embeddingEnabled := req.EmbeddingEnabled != nil && *req.EmbeddingEnabled
|
||||
if modality == embeddings.TypeMultimodal {
|
||||
if !embeddingEnabled {
|
||||
return SearchResponse{}, fmt.Errorf("embedding is disabled")
|
||||
}
|
||||
if s.resolver == nil {
|
||||
return SearchResponse{}, fmt.Errorf("embeddings resolver not configured")
|
||||
}
|
||||
@@ -159,24 +163,79 @@ func (s *Service) Search(ctx context.Context, req SearchRequest) (SearchResponse
|
||||
if err != nil {
|
||||
return SearchResponse{}, err
|
||||
}
|
||||
vector = result.Embedding
|
||||
store = s.store
|
||||
vectorName = s.vectorNameForMultimodal()
|
||||
} else {
|
||||
vector, err = s.embedder.Embed(ctx, req.Query)
|
||||
vectorName := s.vectorNameForMultimodal()
|
||||
if len(req.Sources) == 0 {
|
||||
points, scores, err := s.store.Search(ctx, result.Embedding, req.Limit, filters, vectorName)
|
||||
if err != nil {
|
||||
return SearchResponse{}, err
|
||||
}
|
||||
results := make([]MemoryItem, 0, len(points))
|
||||
for idx, point := range points {
|
||||
item := payloadToMemoryItem(point.ID, point.Payload)
|
||||
if idx < len(scores) {
|
||||
item.Score = scores[idx]
|
||||
}
|
||||
results = append(results, item)
|
||||
}
|
||||
return SearchResponse{Results: results}, nil
|
||||
}
|
||||
pointsBySource, scoresBySource, err := s.store.SearchBySources(ctx, result.Embedding, req.Limit, filters, req.Sources, vectorName)
|
||||
if err != nil {
|
||||
return SearchResponse{}, err
|
||||
}
|
||||
store = s.store
|
||||
vectorName = s.vectorNameForText()
|
||||
results := fuseByRankFusion(pointsBySource, scoresBySource)
|
||||
return SearchResponse{Results: results}, nil
|
||||
}
|
||||
|
||||
if len(req.Sources) == 0 {
|
||||
points, scores, err := store.Search(ctx, vector, req.Limit, filters, vectorName)
|
||||
if embeddingEnabled {
|
||||
if s.embedder == nil {
|
||||
return SearchResponse{}, fmt.Errorf("embedder not configured")
|
||||
}
|
||||
vector, err := s.embedder.Embed(ctx, req.Query)
|
||||
if err != nil {
|
||||
return SearchResponse{}, err
|
||||
}
|
||||
vectorName := s.vectorNameForText()
|
||||
if len(req.Sources) == 0 {
|
||||
points, scores, err := s.store.Search(ctx, vector, req.Limit, filters, vectorName)
|
||||
if err != nil {
|
||||
return SearchResponse{}, err
|
||||
}
|
||||
results := make([]MemoryItem, 0, len(points))
|
||||
for idx, point := range points {
|
||||
item := payloadToMemoryItem(point.ID, point.Payload)
|
||||
if idx < len(scores) {
|
||||
item.Score = scores[idx]
|
||||
}
|
||||
results = append(results, item)
|
||||
}
|
||||
return SearchResponse{Results: results}, nil
|
||||
}
|
||||
pointsBySource, scoresBySource, err := s.store.SearchBySources(ctx, vector, req.Limit, filters, req.Sources, vectorName)
|
||||
if err != nil {
|
||||
return SearchResponse{}, err
|
||||
}
|
||||
results := fuseByRankFusion(pointsBySource, scoresBySource)
|
||||
return SearchResponse{Results: results}, nil
|
||||
}
|
||||
|
||||
if s.bm25 == nil {
|
||||
return SearchResponse{}, fmt.Errorf("bm25 indexer not configured")
|
||||
}
|
||||
lang, err := s.detectLanguage(ctx, req.Query)
|
||||
if err != nil {
|
||||
return SearchResponse{}, err
|
||||
}
|
||||
termFreq, _, err := s.bm25.TermFrequencies(lang, req.Query)
|
||||
if err != nil {
|
||||
return SearchResponse{}, err
|
||||
}
|
||||
indices, values := s.bm25.BuildQueryVector(lang, termFreq)
|
||||
if len(req.Sources) == 0 {
|
||||
points, scores, err := s.store.SearchSparse(ctx, indices, values, req.Limit, filters)
|
||||
if err != nil {
|
||||
return SearchResponse{}, err
|
||||
}
|
||||
results := make([]MemoryItem, 0, len(points))
|
||||
for idx, point := range points {
|
||||
item := payloadToMemoryItem(point.ID, point.Payload)
|
||||
@@ -187,8 +246,7 @@ func (s *Service) Search(ctx context.Context, req SearchRequest) (SearchResponse
|
||||
}
|
||||
return SearchResponse{Results: results}, nil
|
||||
}
|
||||
|
||||
pointsBySource, scoresBySource, err := store.SearchBySources(ctx, vector, req.Limit, filters, req.Sources, vectorName)
|
||||
pointsBySource, scoresBySource, err := s.store.SearchSparseBySources(ctx, indices, values, req.Limit, filters, req.Sources)
|
||||
if err != nil {
|
||||
return SearchResponse{}, err
|
||||
}
|
||||
@@ -200,8 +258,8 @@ func (s *Service) EmbedUpsert(ctx context.Context, req EmbedUpsertRequest) (Embe
|
||||
if s.resolver == nil {
|
||||
return EmbedUpsertResponse{}, fmt.Errorf("embeddings resolver not configured")
|
||||
}
|
||||
if req.UserID == "" && req.AgentID == "" && req.RunID == "" {
|
||||
return EmbedUpsertResponse{}, fmt.Errorf("user_id, agent_id or run_id is required")
|
||||
if req.UserID == "" {
|
||||
return EmbedUpsertResponse{}, fmt.Errorf("user_id is required")
|
||||
}
|
||||
req.Type = strings.TrimSpace(req.Type)
|
||||
req.Provider = strings.TrimSpace(req.Provider)
|
||||
@@ -264,6 +322,12 @@ func (s *Service) Update(ctx context.Context, req UpdateRequest) (MemoryItem, er
|
||||
if strings.TrimSpace(req.Memory) == "" {
|
||||
return MemoryItem{}, fmt.Errorf("memory is required")
|
||||
}
|
||||
if s.store == nil {
|
||||
return MemoryItem{}, fmt.Errorf("qdrant store not configured")
|
||||
}
|
||||
if s.bm25 == nil {
|
||||
return MemoryItem{}, fmt.Errorf("bm25 indexer not configured")
|
||||
}
|
||||
|
||||
existing, err := s.store.Get(ctx, req.MemoryID)
|
||||
if err != nil {
|
||||
@@ -272,22 +336,58 @@ func (s *Service) Update(ctx context.Context, req UpdateRequest) (MemoryItem, er
|
||||
if existing == nil {
|
||||
return MemoryItem{}, fmt.Errorf("memory not found")
|
||||
}
|
||||
if s.bm25 == nil {
|
||||
return MemoryItem{}, fmt.Errorf("bm25 indexer not configured")
|
||||
}
|
||||
|
||||
payload := existing.Payload
|
||||
payload["data"] = req.Memory
|
||||
payload["hash"] = hashMemory(req.Memory)
|
||||
payload["updatedAt"] = time.Now().UTC().Format(time.RFC3339)
|
||||
oldText := fmt.Sprint(payload["data"])
|
||||
oldLang := fmt.Sprint(payload["lang"])
|
||||
if oldLang == "" && strings.TrimSpace(oldText) != "" {
|
||||
oldLang, _ = s.detectLanguage(ctx, oldText)
|
||||
}
|
||||
if strings.TrimSpace(oldText) != "" && strings.TrimSpace(oldLang) != "" {
|
||||
oldFreq, oldLen, err := s.bm25.TermFrequencies(oldLang, oldText)
|
||||
if err == nil {
|
||||
s.bm25.RemoveDocument(oldLang, oldFreq, oldLen)
|
||||
}
|
||||
}
|
||||
|
||||
vector, err := s.embedder.Embed(ctx, req.Memory)
|
||||
newLang, err := s.detectLanguage(ctx, req.Memory)
|
||||
if err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
if err := s.store.Upsert(ctx, []qdrantPoint{{
|
||||
ID: req.MemoryID,
|
||||
Vector: vector,
|
||||
VectorName: s.vectorNameForText(),
|
||||
Payload: payload,
|
||||
}}); err != nil {
|
||||
newFreq, newLen, err := s.bm25.TermFrequencies(newLang, req.Memory)
|
||||
if err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
sparseIndices, sparseValues := s.bm25.AddDocument(newLang, newFreq, newLen)
|
||||
|
||||
payload["data"] = req.Memory
|
||||
payload["hash"] = hashMemory(req.Memory)
|
||||
payload["updatedAt"] = time.Now().UTC().Format(time.RFC3339)
|
||||
payload["lang"] = newLang
|
||||
|
||||
embeddingEnabled := req.EmbeddingEnabled != nil && *req.EmbeddingEnabled
|
||||
point := qdrantPoint{
|
||||
ID: req.MemoryID,
|
||||
SparseIndices: sparseIndices,
|
||||
SparseValues: sparseValues,
|
||||
SparseVectorName: s.store.sparseVectorName,
|
||||
Payload: payload,
|
||||
}
|
||||
if embeddingEnabled {
|
||||
if s.embedder == nil {
|
||||
return MemoryItem{}, fmt.Errorf("embedder not configured")
|
||||
}
|
||||
vector, err := s.embedder.Embed(ctx, req.Memory)
|
||||
if err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
point.Vector = vector
|
||||
point.VectorName = s.vectorNameForText()
|
||||
}
|
||||
if err := s.store.Upsert(ctx, []qdrantPoint{point}); err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
return payloadToMemoryItem(req.MemoryID, payload), nil
|
||||
@@ -312,14 +412,11 @@ func (s *Service) GetAll(ctx context.Context, req GetAllRequest) (SearchResponse
|
||||
if req.UserID != "" {
|
||||
filters["userId"] = req.UserID
|
||||
}
|
||||
if req.AgentID != "" {
|
||||
filters["agentId"] = req.AgentID
|
||||
}
|
||||
if req.RunID != "" {
|
||||
filters["runId"] = req.RunID
|
||||
}
|
||||
if len(filters) == 0 {
|
||||
return SearchResponse{}, fmt.Errorf("user_id, agent_id or run_id is required")
|
||||
return SearchResponse{}, fmt.Errorf("user_id is required")
|
||||
}
|
||||
|
||||
points, err := s.store.List(ctx, req.Limit, filters)
|
||||
@@ -348,14 +445,11 @@ func (s *Service) DeleteAll(ctx context.Context, req DeleteAllRequest) (DeleteRe
|
||||
if req.UserID != "" {
|
||||
filters["userId"] = req.UserID
|
||||
}
|
||||
if req.AgentID != "" {
|
||||
filters["agentId"] = req.AgentID
|
||||
}
|
||||
if req.RunID != "" {
|
||||
filters["runId"] = req.RunID
|
||||
}
|
||||
if len(filters) == 0 {
|
||||
return DeleteResponse{}, fmt.Errorf("user_id, agent_id or run_id is required")
|
||||
return DeleteResponse{}, fmt.Errorf("user_id is required")
|
||||
}
|
||||
if err := s.store.DeleteAll(ctx, filters); err != nil {
|
||||
return DeleteResponse{}, err
|
||||
@@ -363,10 +457,46 @@ func (s *Service) DeleteAll(ctx context.Context, req DeleteAllRequest) (DeleteRe
|
||||
return DeleteResponse{Message: "Memories deleted successfully!"}, nil
|
||||
}
|
||||
|
||||
func (s *Service) addRawMessages(ctx context.Context, messages []Message, filters map[string]interface{}, metadata map[string]interface{}) (SearchResponse, error) {
|
||||
func (s *Service) WarmupBM25(ctx context.Context, batchSize int) error {
|
||||
if s.bm25 == nil || s.store == nil {
|
||||
return nil
|
||||
}
|
||||
var offset *qdrant.PointId
|
||||
for {
|
||||
points, next, err := s.store.Scroll(ctx, batchSize, nil, offset)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(points) == 0 {
|
||||
break
|
||||
}
|
||||
for _, point := range points {
|
||||
text := fmt.Sprint(point.Payload["data"])
|
||||
if strings.TrimSpace(text) == "" {
|
||||
continue
|
||||
}
|
||||
lang := fmt.Sprint(point.Payload["lang"])
|
||||
if lang == "" {
|
||||
lang = fallbackLanguageCode(text)
|
||||
}
|
||||
termFreq, docLen, err := s.bm25.TermFrequencies(lang, text)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
s.bm25.AddDocument(lang, termFreq, docLen)
|
||||
}
|
||||
if next == nil {
|
||||
break
|
||||
}
|
||||
offset = next
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) addRawMessages(ctx context.Context, messages []Message, filters map[string]interface{}, metadata map[string]interface{}, embeddingEnabled bool) (SearchResponse, error) {
|
||||
results := make([]MemoryItem, 0, len(messages))
|
||||
for _, message := range messages {
|
||||
item, err := s.applyAdd(ctx, message.Content, filters, metadata)
|
||||
item, err := s.applyAdd(ctx, message.Content, filters, metadata, embeddingEnabled)
|
||||
if err != nil {
|
||||
return SearchResponse{}, err
|
||||
}
|
||||
@@ -381,11 +511,19 @@ func (s *Service) addRawMessages(ctx context.Context, messages []Message, filter
|
||||
func (s *Service) collectCandidates(ctx context.Context, facts []string, filters map[string]interface{}) ([]CandidateMemory, error) {
|
||||
unique := map[string]CandidateMemory{}
|
||||
for _, fact := range facts {
|
||||
vector, err := s.embedder.Embed(ctx, fact)
|
||||
if s.bm25 == nil {
|
||||
return nil, fmt.Errorf("bm25 indexer not configured")
|
||||
}
|
||||
lang, err := s.detectLanguage(ctx, fact)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
points, _, err := s.store.Search(ctx, vector, 5, filters, s.vectorNameForText())
|
||||
termFreq, _, err := s.bm25.TermFrequencies(lang, fact)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
indices, values := s.bm25.BuildQueryVector(lang, termFreq)
|
||||
points, _, err := s.store.SearchSparse(ctx, indices, values, 5, filters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -406,25 +544,50 @@ func (s *Service) collectCandidates(ctx context.Context, facts []string, filters
|
||||
return candidates, nil
|
||||
}
|
||||
|
||||
func (s *Service) applyAdd(ctx context.Context, text string, filters map[string]interface{}, metadata map[string]interface{}) (MemoryItem, error) {
|
||||
vector, err := s.embedder.Embed(ctx, text)
|
||||
func (s *Service) applyAdd(ctx context.Context, text string, filters map[string]interface{}, metadata map[string]interface{}, embeddingEnabled bool) (MemoryItem, error) {
|
||||
if s.store == nil {
|
||||
return MemoryItem{}, fmt.Errorf("qdrant store not configured")
|
||||
}
|
||||
if s.bm25 == nil {
|
||||
return MemoryItem{}, fmt.Errorf("bm25 indexer not configured")
|
||||
}
|
||||
lang, err := s.detectLanguage(ctx, text)
|
||||
if err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
termFreq, docLen, err := s.bm25.TermFrequencies(lang, text)
|
||||
if err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
sparseIndices, sparseValues := s.bm25.AddDocument(lang, termFreq, docLen)
|
||||
id := uuid.NewString()
|
||||
payload := buildPayload(text, filters, metadata, "")
|
||||
if err := s.store.Upsert(ctx, []qdrantPoint{{
|
||||
ID: id,
|
||||
Vector: vector,
|
||||
VectorName: s.vectorNameForText(),
|
||||
Payload: payload,
|
||||
}}); err != nil {
|
||||
payload["lang"] = lang
|
||||
point := qdrantPoint{
|
||||
ID: id,
|
||||
SparseIndices: sparseIndices,
|
||||
SparseValues: sparseValues,
|
||||
SparseVectorName: s.store.sparseVectorName,
|
||||
Payload: payload,
|
||||
}
|
||||
if embeddingEnabled {
|
||||
if s.embedder == nil {
|
||||
return MemoryItem{}, fmt.Errorf("embedder not configured")
|
||||
}
|
||||
vector, err := s.embedder.Embed(ctx, text)
|
||||
if err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
point.Vector = vector
|
||||
point.VectorName = s.vectorNameForText()
|
||||
}
|
||||
if err := s.store.Upsert(ctx, []qdrantPoint{point}); err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
return payloadToMemoryItem(id, payload), nil
|
||||
}
|
||||
|
||||
func (s *Service) applyUpdate(ctx context.Context, id, text string, filters map[string]interface{}, metadata map[string]interface{}) (MemoryItem, error) {
|
||||
func (s *Service) applyUpdate(ctx context.Context, id, text string, filters map[string]interface{}, metadata map[string]interface{}, embeddingEnabled bool) (MemoryItem, error) {
|
||||
if strings.TrimSpace(id) == "" {
|
||||
return MemoryItem{}, fmt.Errorf("update action missing id")
|
||||
}
|
||||
@@ -437,25 +600,55 @@ func (s *Service) applyUpdate(ctx context.Context, id, text string, filters map[
|
||||
}
|
||||
|
||||
payload := existing.Payload
|
||||
oldText := fmt.Sprint(payload["data"])
|
||||
oldLang := fmt.Sprint(payload["lang"])
|
||||
if oldLang == "" && strings.TrimSpace(oldText) != "" {
|
||||
oldLang, _ = s.detectLanguage(ctx, oldText)
|
||||
}
|
||||
if strings.TrimSpace(oldText) != "" && strings.TrimSpace(oldLang) != "" {
|
||||
oldFreq, oldLen, err := s.bm25.TermFrequencies(oldLang, oldText)
|
||||
if err == nil {
|
||||
s.bm25.RemoveDocument(oldLang, oldFreq, oldLen)
|
||||
}
|
||||
}
|
||||
newLang, err := s.detectLanguage(ctx, text)
|
||||
if err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
newFreq, newLen, err := s.bm25.TermFrequencies(newLang, text)
|
||||
if err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
sparseIndices, sparseValues := s.bm25.AddDocument(newLang, newFreq, newLen)
|
||||
payload["data"] = text
|
||||
payload["hash"] = hashMemory(text)
|
||||
payload["updatedAt"] = time.Now().UTC().Format(time.RFC3339)
|
||||
payload["lang"] = newLang
|
||||
if metadata != nil {
|
||||
payload["metadata"] = mergeMetadata(payload["metadata"], metadata)
|
||||
}
|
||||
if filters != nil {
|
||||
applyFiltersToPayload(payload, filters)
|
||||
}
|
||||
vector, err := s.embedder.Embed(ctx, text)
|
||||
if err != nil {
|
||||
return MemoryItem{}, err
|
||||
point := qdrantPoint{
|
||||
ID: id,
|
||||
SparseIndices: sparseIndices,
|
||||
SparseValues: sparseValues,
|
||||
SparseVectorName: s.store.sparseVectorName,
|
||||
Payload: payload,
|
||||
}
|
||||
if err := s.store.Upsert(ctx, []qdrantPoint{{
|
||||
ID: id,
|
||||
Vector: vector,
|
||||
VectorName: s.vectorNameForText(),
|
||||
Payload: payload,
|
||||
}}); err != nil {
|
||||
if embeddingEnabled {
|
||||
if s.embedder == nil {
|
||||
return MemoryItem{}, fmt.Errorf("embedder not configured")
|
||||
}
|
||||
vector, err := s.embedder.Embed(ctx, text)
|
||||
if err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
point.Vector = vector
|
||||
point.VectorName = s.vectorNameForText()
|
||||
}
|
||||
if err := s.store.Upsert(ctx, []qdrantPoint{point}); err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
return payloadToMemoryItem(id, payload), nil
|
||||
@@ -473,6 +666,19 @@ func (s *Service) applyDelete(ctx context.Context, id string) (MemoryItem, error
|
||||
return MemoryItem{}, fmt.Errorf("memory not found")
|
||||
}
|
||||
item := payloadToMemoryItem(id, existing.Payload)
|
||||
if s.bm25 != nil {
|
||||
oldText := fmt.Sprint(existing.Payload["data"])
|
||||
oldLang := fmt.Sprint(existing.Payload["lang"])
|
||||
if oldLang == "" && strings.TrimSpace(oldText) != "" {
|
||||
oldLang, _ = s.detectLanguage(ctx, oldText)
|
||||
}
|
||||
if strings.TrimSpace(oldText) != "" && strings.TrimSpace(oldLang) != "" {
|
||||
oldFreq, oldLen, err := s.bm25.TermFrequencies(oldLang, oldText)
|
||||
if err == nil {
|
||||
s.bm25.RemoveDocument(oldLang, oldFreq, oldLen)
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := s.store.Delete(ctx, id); err != nil {
|
||||
return MemoryItem{}, err
|
||||
}
|
||||
@@ -486,6 +692,56 @@ func normalizeMessages(req AddRequest) []Message {
|
||||
return []Message{{Role: "user", Content: req.Message}}
|
||||
}
|
||||
|
||||
func (s *Service) detectLanguage(ctx context.Context, text string) (string, error) {
|
||||
if s.llm == nil {
|
||||
return "", fmt.Errorf("language detector not configured")
|
||||
}
|
||||
lang, err := s.llm.DetectLanguage(ctx, text)
|
||||
if err == nil && lang != "" {
|
||||
return lang, nil
|
||||
}
|
||||
fallback := fallbackLanguageCode(text)
|
||||
if s.logger != nil {
|
||||
s.logger.Warn("language detection failed; using fallback", slog.Any("error", err), slog.String("fallback", fallback))
|
||||
}
|
||||
return fallback, nil
|
||||
}
|
||||
|
||||
func fallbackLanguageCode(text string) string {
|
||||
for _, r := range text {
|
||||
if isCJKRune(r) {
|
||||
return "cjk"
|
||||
}
|
||||
}
|
||||
return "en"
|
||||
}
|
||||
|
||||
// isCJKRune reports whether r lies in one of the Unicode blocks used for
// Chinese, Japanese or Korean text.
//
// Besides the unified ideograph blocks, kana and precomposed Hangul syllables,
// it also recognises Hangul Jamo (used by decomposed/NFD Korean), Hangul
// Compatibility Jamo, CJK Compatibility Ideographs, halfwidth Katakana/Hangul
// and Extension G, so that text written only with those code points is not
// mistaken for non-CJK text.
func isCJKRune(r rune) bool {
	switch {
	case r >= 0x1100 && r <= 0x11FF: // Hangul Jamo
		return true
	case r >= 0x3000 && r <= 0x303F: // CJK Symbols and Punctuation
		return true
	case r >= 0x3040 && r <= 0x30FF: // Hiragana/Katakana
		return true
	case r >= 0x3130 && r <= 0x318F: // Hangul Compatibility Jamo
		return true
	case r >= 0x3400 && r <= 0x4DBF: // CJK Unified Ideographs Extension A
		return true
	case r >= 0x4E00 && r <= 0x9FFF: // CJK Unified Ideographs
		return true
	case r >= 0xAC00 && r <= 0xD7AF: // Hangul Syllables
		return true
	case r >= 0xF900 && r <= 0xFAFF: // CJK Compatibility Ideographs
		return true
	case r >= 0xFF66 && r <= 0xFFDC: // Halfwidth Katakana and Hangul variants
		return true
	case r >= 0x20000 && r <= 0x2A6DF: // CJK Unified Ideographs Extension B
		return true
	case r >= 0x2A700 && r <= 0x2B73F: // CJK Unified Ideographs Extension C
		return true
	case r >= 0x2B740 && r <= 0x2B81F: // CJK Unified Ideographs Extension D
		return true
	case r >= 0x2B820 && r <= 0x2CEAF: // CJK Unified Ideographs Extension E
		return true
	case r >= 0x2CEB0 && r <= 0x2EBEF: // CJK Unified Ideographs Extension F
		return true
	case r >= 0x30000 && r <= 0x3134F: // CJK Unified Ideographs Extension G
		return true
	}
	return false
}
|
||||
|
||||
func buildFilters(req AddRequest) map[string]interface{} {
|
||||
filters := map[string]interface{}{}
|
||||
for key, value := range req.Filters {
|
||||
@@ -494,9 +750,6 @@ func buildFilters(req AddRequest) map[string]interface{} {
|
||||
if req.UserID != "" {
|
||||
filters["userId"] = req.UserID
|
||||
}
|
||||
if req.AgentID != "" {
|
||||
filters["agentId"] = req.AgentID
|
||||
}
|
||||
if req.RunID != "" {
|
||||
filters["runId"] = req.RunID
|
||||
}
|
||||
@@ -511,9 +764,6 @@ func buildSearchFilters(req SearchRequest) map[string]interface{} {
|
||||
if req.UserID != "" {
|
||||
filters["userId"] = req.UserID
|
||||
}
|
||||
if req.AgentID != "" {
|
||||
filters["agentId"] = req.AgentID
|
||||
}
|
||||
if req.RunID != "" {
|
||||
filters["runId"] = req.RunID
|
||||
}
|
||||
@@ -528,9 +778,6 @@ func buildEmbedFilters(req EmbedUpsertRequest) map[string]interface{} {
|
||||
if req.UserID != "" {
|
||||
filters["userId"] = req.UserID
|
||||
}
|
||||
if req.AgentID != "" {
|
||||
filters["agentId"] = req.AgentID
|
||||
}
|
||||
if req.RunID != "" {
|
||||
filters["runId"] = req.RunID
|
||||
}
|
||||
@@ -621,9 +868,6 @@ func payloadToMemoryItem(id string, payload map[string]interface{}) MemoryItem {
|
||||
if v, ok := payload["userId"].(string); ok {
|
||||
item.UserID = v
|
||||
}
|
||||
if v, ok := payload["agentId"].(string); ok {
|
||||
item.AgentID = v
|
||||
}
|
||||
if v, ok := payload["runId"].(string); ok {
|
||||
item.RunID = v
|
||||
}
|
||||
|
||||
+24
-26
@@ -6,6 +6,7 @@ import "context"
|
||||
type LLM interface {
|
||||
Extract(ctx context.Context, req ExtractRequest) (ExtractResponse, error)
|
||||
Decide(ctx context.Context, req DecideRequest) (DecideResponse, error)
|
||||
DetectLanguage(ctx context.Context, text string) (string, error)
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
@@ -14,42 +15,41 @@ type Message struct {
|
||||
}
|
||||
|
||||
type AddRequest struct {
|
||||
Message string `json:"message,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
AgentID string `json:"agent_id,omitempty"`
|
||||
RunID string `json:"run_id,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
Filters map[string]interface{} `json:"filters,omitempty"`
|
||||
Infer *bool `json:"infer,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
RunID string `json:"run_id,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
Filters map[string]interface{} `json:"filters,omitempty"`
|
||||
Infer *bool `json:"infer,omitempty"`
|
||||
EmbeddingEnabled *bool `json:"embedding_enabled,omitempty"`
|
||||
}
|
||||
|
||||
type SearchRequest struct {
|
||||
Query string `json:"query"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
AgentID string `json:"agent_id,omitempty"`
|
||||
RunID string `json:"run_id,omitempty"`
|
||||
Limit int `json:"limit,omitempty"`
|
||||
Filters map[string]interface{} `json:"filters,omitempty"`
|
||||
Sources []string `json:"sources,omitempty"`
|
||||
Query string `json:"query"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
RunID string `json:"run_id,omitempty"`
|
||||
Limit int `json:"limit,omitempty"`
|
||||
Filters map[string]interface{} `json:"filters,omitempty"`
|
||||
Sources []string `json:"sources,omitempty"`
|
||||
EmbeddingEnabled *bool `json:"embedding_enabled,omitempty"`
|
||||
}
|
||||
|
||||
type UpdateRequest struct {
|
||||
MemoryID string `json:"memory_id"`
|
||||
Memory string `json:"memory"`
|
||||
MemoryID string `json:"memory_id"`
|
||||
Memory string `json:"memory"`
|
||||
EmbeddingEnabled *bool `json:"embedding_enabled,omitempty"`
|
||||
}
|
||||
|
||||
type GetAllRequest struct {
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
AgentID string `json:"agent_id,omitempty"`
|
||||
RunID string `json:"run_id,omitempty"`
|
||||
Limit int `json:"limit,omitempty"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
RunID string `json:"run_id,omitempty"`
|
||||
Limit int `json:"limit,omitempty"`
|
||||
}
|
||||
|
||||
type DeleteAllRequest struct {
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
AgentID string `json:"agent_id,omitempty"`
|
||||
RunID string `json:"run_id,omitempty"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
RunID string `json:"run_id,omitempty"`
|
||||
}
|
||||
|
||||
type EmbedInput struct {
|
||||
@@ -65,7 +65,6 @@ type EmbedUpsertRequest struct {
|
||||
Input EmbedInput `json:"input"`
|
||||
Source string `json:"source,omitempty"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
AgentID string `json:"agent_id,omitempty"`
|
||||
RunID string `json:"run_id,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
Filters map[string]interface{} `json:"filters,omitempty"`
|
||||
@@ -87,7 +86,6 @@ type MemoryItem struct {
|
||||
Score float64 `json:"score,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
UserID string `json:"userId,omitempty"`
|
||||
AgentID string `json:"agentId,omitempty"`
|
||||
RunID string `json:"runId,omitempty"`
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user