diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 18ed4b77..29bddcce 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -676,7 +676,8 @@ func (c *lazyLLMClient) resolve(ctx context.Context) (memory.LLM, error) { if c.modelsService == nil || c.queries == nil { return nil, fmt.Errorf("models service not configured") } - memoryModel, memoryProvider, err := models.SelectMemoryModel(ctx, c.modelsService, c.queries) + botID := memory.BotIDFromContext(ctx) + memoryModel, memoryProvider, err := models.SelectMemoryModelForBot(ctx, c.modelsService, c.queries, botID) if err != nil { return nil, err } diff --git a/internal/memory/context.go b/internal/memory/context.go new file mode 100644 index 00000000..fdb5d835 --- /dev/null +++ b/internal/memory/context.go @@ -0,0 +1,28 @@ +package memory + +import ( + "context" + "strings" +) + +type contextKey string + +const memoryBotIDContextKey contextKey = "memory_bot_id" + +// WithBotID attaches bot ID to context so model selection can honor bot settings. +func WithBotID(ctx context.Context, botID string) context.Context { + botID = strings.TrimSpace(botID) + if botID == "" { + return ctx + } + return context.WithValue(ctx, memoryBotIDContextKey, botID) +} + +// BotIDFromContext returns bot ID carried by WithBotID. +func BotIDFromContext(ctx context.Context) string { + if ctx == nil { + return "" + } + botID, _ := ctx.Value(memoryBotIDContextKey).(string) + return strings.TrimSpace(botID) +} diff --git a/internal/memory/service.go b/internal/memory/service.go index c1b9b530..6b45d423 100644 --- a/internal/memory/service.go +++ b/internal/memory/service.go @@ -51,6 +51,7 @@ func (s *Service) Add(ctx context.Context, req AddRequest) (SearchResponse, erro messages := normalizeMessages(req) filters := buildFilters(req) + ctx = WithBotID(ctx, resolveBotID(req.BotID, filters)) embeddingEnabled := req.EmbeddingEnabled != nil && *req.EmbeddingEnabled if req.Infer != nil && !*req.Infer { @@ -142,6 +143,7 @@ func (s *Service) Search(ctx context.Context, req SearchRequest) (SearchResponse return SearchResponse{}, fmt.Errorf("qdrant store not configured") } filters := buildSearchFilters(req) + ctx = WithBotID(ctx, resolveBotID(req.BotID, filters)) modality := "" if raw, ok := filters["modality"].(string); ok { modality = strings.ToLower(strings.TrimSpace(raw)) @@ -359,6 +361,7 @@ func (s *Service) Update(ctx context.Context, req UpdateRequest) (MemoryItem, er if existing == nil { return MemoryItem{}, fmt.Errorf("memory not found") } + ctx = WithBotID(ctx, resolveBotID("", existing.Payload)) payload := existing.Payload oldText := fmt.Sprint(payload["data"]) @@ -530,6 +533,7 @@ func (s *Service) Compact(ctx context.Context, filters map[string]any, ratio flo if ratio <= 0 || ratio > 1 { ratio = 0.5 } + ctx = WithBotID(ctx, resolveBotID("", filters)) // Fetch all existing memories. points, err := s.store.List(ctx, 0, filters, false) @@ -1007,6 +1011,24 @@ func buildFilters(req AddRequest) map[string]any { return filters } +func resolveBotID(explicitBotID string, filters map[string]any) string { + if botID := strings.TrimSpace(explicitBotID); botID != "" { + return botID + } + if len(filters) == 0 { + return "" + } + if raw, ok := filters["bot_id"].(string); ok { + if botID := strings.TrimSpace(raw); botID != "" { + return botID + } + } + if raw, ok := filters["scopeId"].(string); ok { + return strings.TrimSpace(raw) + } + return "" +} + func buildSearchFilters(req SearchRequest) map[string]any { filters := map[string]any{} for key, value := range req.Filters { diff --git a/internal/models/models.go b/internal/models/models.go index 623c0705..7ac5b27f 100644 --- a/internal/models/models.go +++ b/internal/models/models.go @@ -391,6 +391,9 @@ func SelectMemoryModel(ctx context.Context, modelsService *Service, queries *sql if modelsService == nil { return GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("models service not configured") } + if queries == nil { + return GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("queries not configured") + } candidates, err := modelsService.ListByType(ctx, ModelTypeChat) if err != nil || len(candidates) == 0 { return GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("no chat models available for memory operations") @@ -403,6 +406,41 @@ func SelectMemoryModel(ctx context.Context, modelsService *Service, queries *sql return selected, provider, nil } +// SelectMemoryModelForBot selects memory model by bot settings first, then falls back to SelectMemoryModel. +func SelectMemoryModelForBot(ctx context.Context, modelsService *Service, queries *sqlc.Queries, botID string) (GetResponse, sqlc.LlmProvider, error) { + botID = strings.TrimSpace(botID) + if botID == "" { + return SelectMemoryModel(ctx, modelsService, queries) + } + if queries == nil { + return SelectMemoryModel(ctx, modelsService, queries) + } + pgBotID, err := db.ParseUUID(botID) + if err != nil { + return SelectMemoryModel(ctx, modelsService, queries) + } + bot, err := queries.GetBotByID(ctx, pgBotID) + if err != nil { + return SelectMemoryModel(ctx, modelsService, queries) + } + if !bot.MemoryModelID.Valid { + return SelectMemoryModel(ctx, modelsService, queries) + } + dbModel, err := queries.GetModelByID(ctx, bot.MemoryModelID) + if err != nil { + return SelectMemoryModel(ctx, modelsService, queries) + } + selected := convertToGetResponse(dbModel) + if selected.Type != ModelTypeChat { + return SelectMemoryModel(ctx, modelsService, queries) + } + provider, err := FetchProviderByID(ctx, queries, selected.LlmProviderID) + if err != nil { + return GetResponse{}, sqlc.LlmProvider{}, err + } + return selected, provider, nil +} + // FetchProviderByID fetches a provider by ID. func FetchProviderByID(ctx context.Context, queries *sqlc.Queries, providerID string) (sqlc.LlmProvider, error) { if strings.TrimSpace(providerID) == "" {