fix(models): guard bot memory model type and fallback for memory LLM (#61)

This commit is contained in:
Ringo.Typowriter
2026-02-17 20:14:44 +08:00
committed by GitHub
parent cd8cb59236
commit daed9d2d95
4 changed files with 90 additions and 1 deletions
+2 -1
View File
@@ -676,7 +676,8 @@ func (c *lazyLLMClient) resolve(ctx context.Context) (memory.LLM, error) {
if c.modelsService == nil || c.queries == nil {
return nil, fmt.Errorf("models service not configured")
}
memoryModel, memoryProvider, err := models.SelectMemoryModel(ctx, c.modelsService, c.queries)
botID := memory.BotIDFromContext(ctx)
memoryModel, memoryProvider, err := models.SelectMemoryModelForBot(ctx, c.modelsService, c.queries, botID)
if err != nil {
return nil, err
}
+28
View File
@@ -0,0 +1,28 @@
package memory
import (
"context"
"strings"
)
type contextKey string
const memoryBotIDContextKey contextKey = "memory_bot_id"
// WithBotID attaches bot ID to context so model selection can honor bot settings.
func WithBotID(ctx context.Context, botID string) context.Context {
botID = strings.TrimSpace(botID)
if botID == "" {
return ctx
}
return context.WithValue(ctx, memoryBotIDContextKey, botID)
}
// BotIDFromContext returns bot ID carried by WithBotID.
func BotIDFromContext(ctx context.Context) string {
if ctx == nil {
return ""
}
botID, _ := ctx.Value(memoryBotIDContextKey).(string)
return strings.TrimSpace(botID)
}
+22
View File
@@ -51,6 +51,7 @@ func (s *Service) Add(ctx context.Context, req AddRequest) (SearchResponse, erro
messages := normalizeMessages(req)
filters := buildFilters(req)
ctx = WithBotID(ctx, resolveBotID(req.BotID, filters))
embeddingEnabled := req.EmbeddingEnabled != nil && *req.EmbeddingEnabled
if req.Infer != nil && !*req.Infer {
@@ -142,6 +143,7 @@ func (s *Service) Search(ctx context.Context, req SearchRequest) (SearchResponse
return SearchResponse{}, fmt.Errorf("qdrant store not configured")
}
filters := buildSearchFilters(req)
ctx = WithBotID(ctx, resolveBotID(req.BotID, filters))
modality := ""
if raw, ok := filters["modality"].(string); ok {
modality = strings.ToLower(strings.TrimSpace(raw))
@@ -359,6 +361,7 @@ func (s *Service) Update(ctx context.Context, req UpdateRequest) (MemoryItem, er
if existing == nil {
return MemoryItem{}, fmt.Errorf("memory not found")
}
ctx = WithBotID(ctx, resolveBotID("", existing.Payload))
payload := existing.Payload
oldText := fmt.Sprint(payload["data"])
@@ -530,6 +533,7 @@ func (s *Service) Compact(ctx context.Context, filters map[string]any, ratio flo
if ratio <= 0 || ratio > 1 {
ratio = 0.5
}
ctx = WithBotID(ctx, resolveBotID("", filters))
// Fetch all existing memories.
points, err := s.store.List(ctx, 0, filters, false)
@@ -1007,6 +1011,24 @@ func buildFilters(req AddRequest) map[string]any {
return filters
}
// resolveBotID determines the effective bot ID for a memory operation.
// An explicit, non-blank ID always wins; otherwise well-known filter keys
// are consulted in priority order ("bot_id" before "scopeId"), returning
// the first non-blank string value found, or "" when none is present.
func resolveBotID(explicitBotID string, filters map[string]any) string {
	if id := strings.TrimSpace(explicitBotID); id != "" {
		return id
	}
	// Lookups on a nil/empty map are safe and simply miss.
	for _, key := range []string{"bot_id", "scopeId"} {
		raw, ok := filters[key].(string)
		if !ok {
			continue
		}
		if id := strings.TrimSpace(raw); id != "" {
			return id
		}
	}
	return ""
}
func buildSearchFilters(req SearchRequest) map[string]any {
filters := map[string]any{}
for key, value := range req.Filters {
+38
View File
@@ -391,6 +391,9 @@ func SelectMemoryModel(ctx context.Context, modelsService *Service, queries *sql
if modelsService == nil {
return GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("models service not configured")
}
if queries == nil {
return GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("queries not configured")
}
candidates, err := modelsService.ListByType(ctx, ModelTypeChat)
if err != nil || len(candidates) == 0 {
return GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("no chat models available for memory operations")
@@ -403,6 +406,41 @@ func SelectMemoryModel(ctx context.Context, modelsService *Service, queries *sql
return selected, provider, nil
}
// SelectMemoryModelForBot selects memory model by bot settings first, then falls back to SelectMemoryModel.
// Every recoverable problem with the bot-specific lookup (blank/invalid bot ID,
// missing bot row, unset or non-chat memory model) degrades to the generic
// selection; only a provider fetch failure for a valid bot model is surfaced.
func SelectMemoryModelForBot(ctx context.Context, modelsService *Service, queries *sqlc.Queries, botID string) (GetResponse, sqlc.LlmProvider, error) {
	// Single fallback path shared by every bail-out below.
	fallback := func() (GetResponse, sqlc.LlmProvider, error) {
		return SelectMemoryModel(ctx, modelsService, queries)
	}

	botID = strings.TrimSpace(botID)
	if botID == "" || queries == nil {
		return fallback()
	}

	pgBotID, err := db.ParseUUID(botID)
	if err != nil {
		return fallback()
	}

	bot, err := queries.GetBotByID(ctx, pgBotID)
	if err != nil || !bot.MemoryModelID.Valid {
		return fallback()
	}

	dbModel, err := queries.GetModelByID(ctx, bot.MemoryModelID)
	if err != nil {
		return fallback()
	}

	// Guard against a bot pointing at a non-chat model; memory needs a chat model.
	selected := convertToGetResponse(dbModel)
	if selected.Type != ModelTypeChat {
		return fallback()
	}

	provider, err := FetchProviderByID(ctx, queries, selected.LlmProviderID)
	if err != nil {
		// The bot's configured model exists but its provider is broken:
		// surface the error rather than silently switching models.
		return GetResponse{}, sqlc.LlmProvider{}, err
	}
	return selected, provider, nil
}
// FetchProviderByID fetches a provider by ID.
func FetchProviderByID(ctx context.Context, queries *sqlc.Queries, providerID string) (sqlc.LlmProvider, error) {
if strings.TrimSpace(providerID) == "" {