mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
fix(models): guard bot memory model type and fallback for memory LLM (#61)
This commit is contained in:
+2
-1
@@ -676,7 +676,8 @@ func (c *lazyLLMClient) resolve(ctx context.Context) (memory.LLM, error) {
|
||||
if c.modelsService == nil || c.queries == nil {
|
||||
return nil, fmt.Errorf("models service not configured")
|
||||
}
|
||||
memoryModel, memoryProvider, err := models.SelectMemoryModel(ctx, c.modelsService, c.queries)
|
||||
botID := memory.BotIDFromContext(ctx)
|
||||
memoryModel, memoryProvider, err := models.SelectMemoryModelForBot(ctx, c.modelsService, c.queries, botID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
|
||||
const memoryBotIDContextKey contextKey = "memory_bot_id"
|
||||
|
||||
// WithBotID attaches bot ID to context so model selection can honor bot settings.
|
||||
func WithBotID(ctx context.Context, botID string) context.Context {
|
||||
botID = strings.TrimSpace(botID)
|
||||
if botID == "" {
|
||||
return ctx
|
||||
}
|
||||
return context.WithValue(ctx, memoryBotIDContextKey, botID)
|
||||
}
|
||||
|
||||
// BotIDFromContext returns bot ID carried by WithBotID.
|
||||
func BotIDFromContext(ctx context.Context) string {
|
||||
if ctx == nil {
|
||||
return ""
|
||||
}
|
||||
botID, _ := ctx.Value(memoryBotIDContextKey).(string)
|
||||
return strings.TrimSpace(botID)
|
||||
}
|
||||
@@ -51,6 +51,7 @@ func (s *Service) Add(ctx context.Context, req AddRequest) (SearchResponse, erro
|
||||
|
||||
messages := normalizeMessages(req)
|
||||
filters := buildFilters(req)
|
||||
ctx = WithBotID(ctx, resolveBotID(req.BotID, filters))
|
||||
|
||||
embeddingEnabled := req.EmbeddingEnabled != nil && *req.EmbeddingEnabled
|
||||
if req.Infer != nil && !*req.Infer {
|
||||
@@ -142,6 +143,7 @@ func (s *Service) Search(ctx context.Context, req SearchRequest) (SearchResponse
|
||||
return SearchResponse{}, fmt.Errorf("qdrant store not configured")
|
||||
}
|
||||
filters := buildSearchFilters(req)
|
||||
ctx = WithBotID(ctx, resolveBotID(req.BotID, filters))
|
||||
modality := ""
|
||||
if raw, ok := filters["modality"].(string); ok {
|
||||
modality = strings.ToLower(strings.TrimSpace(raw))
|
||||
@@ -359,6 +361,7 @@ func (s *Service) Update(ctx context.Context, req UpdateRequest) (MemoryItem, er
|
||||
if existing == nil {
|
||||
return MemoryItem{}, fmt.Errorf("memory not found")
|
||||
}
|
||||
ctx = WithBotID(ctx, resolveBotID("", existing.Payload))
|
||||
|
||||
payload := existing.Payload
|
||||
oldText := fmt.Sprint(payload["data"])
|
||||
@@ -530,6 +533,7 @@ func (s *Service) Compact(ctx context.Context, filters map[string]any, ratio flo
|
||||
if ratio <= 0 || ratio > 1 {
|
||||
ratio = 0.5
|
||||
}
|
||||
ctx = WithBotID(ctx, resolveBotID("", filters))
|
||||
|
||||
// Fetch all existing memories.
|
||||
points, err := s.store.List(ctx, 0, filters, false)
|
||||
@@ -1007,6 +1011,24 @@ func buildFilters(req AddRequest) map[string]any {
|
||||
return filters
|
||||
}
|
||||
|
||||
func resolveBotID(explicitBotID string, filters map[string]any) string {
|
||||
if botID := strings.TrimSpace(explicitBotID); botID != "" {
|
||||
return botID
|
||||
}
|
||||
if len(filters) == 0 {
|
||||
return ""
|
||||
}
|
||||
if raw, ok := filters["bot_id"].(string); ok {
|
||||
if botID := strings.TrimSpace(raw); botID != "" {
|
||||
return botID
|
||||
}
|
||||
}
|
||||
if raw, ok := filters["scopeId"].(string); ok {
|
||||
return strings.TrimSpace(raw)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func buildSearchFilters(req SearchRequest) map[string]any {
|
||||
filters := map[string]any{}
|
||||
for key, value := range req.Filters {
|
||||
|
||||
@@ -391,6 +391,9 @@ func SelectMemoryModel(ctx context.Context, modelsService *Service, queries *sql
|
||||
if modelsService == nil {
|
||||
return GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("models service not configured")
|
||||
}
|
||||
if queries == nil {
|
||||
return GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("queries not configured")
|
||||
}
|
||||
candidates, err := modelsService.ListByType(ctx, ModelTypeChat)
|
||||
if err != nil || len(candidates) == 0 {
|
||||
return GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("no chat models available for memory operations")
|
||||
@@ -403,6 +406,41 @@ func SelectMemoryModel(ctx context.Context, modelsService *Service, queries *sql
|
||||
return selected, provider, nil
|
||||
}
|
||||
|
||||
// SelectMemoryModelForBot selects memory model by bot settings first, then falls back to SelectMemoryModel.
|
||||
func SelectMemoryModelForBot(ctx context.Context, modelsService *Service, queries *sqlc.Queries, botID string) (GetResponse, sqlc.LlmProvider, error) {
|
||||
botID = strings.TrimSpace(botID)
|
||||
if botID == "" {
|
||||
return SelectMemoryModel(ctx, modelsService, queries)
|
||||
}
|
||||
if queries == nil {
|
||||
return SelectMemoryModel(ctx, modelsService, queries)
|
||||
}
|
||||
pgBotID, err := db.ParseUUID(botID)
|
||||
if err != nil {
|
||||
return SelectMemoryModel(ctx, modelsService, queries)
|
||||
}
|
||||
bot, err := queries.GetBotByID(ctx, pgBotID)
|
||||
if err != nil {
|
||||
return SelectMemoryModel(ctx, modelsService, queries)
|
||||
}
|
||||
if !bot.MemoryModelID.Valid {
|
||||
return SelectMemoryModel(ctx, modelsService, queries)
|
||||
}
|
||||
dbModel, err := queries.GetModelByID(ctx, bot.MemoryModelID)
|
||||
if err != nil {
|
||||
return SelectMemoryModel(ctx, modelsService, queries)
|
||||
}
|
||||
selected := convertToGetResponse(dbModel)
|
||||
if selected.Type != ModelTypeChat {
|
||||
return SelectMemoryModel(ctx, modelsService, queries)
|
||||
}
|
||||
provider, err := FetchProviderByID(ctx, queries, selected.LlmProviderID)
|
||||
if err != nil {
|
||||
return GetResponse{}, sqlc.LlmProvider{}, err
|
||||
}
|
||||
return selected, provider, nil
|
||||
}
|
||||
|
||||
// FetchProviderByID fetches a provider by ID.
|
||||
func FetchProviderByID(ctx context.Context, queries *sqlc.Queries, providerID string) (sqlc.LlmProvider, error) {
|
||||
if strings.TrimSpace(providerID) == "" {
|
||||
|
||||
Reference in New Issue
Block a user