mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
feat: add context compaction to automatically summarize old messages (#compaction) (#276)
When input tokens exceed a configurable threshold after a conversation round, the system asynchronously compacts older messages into a summary. Cascading compactions reference prior summaries via <prior_context> tags to maintain conversational continuity without duplicating content. - Add bot_history_message_compacts table and compact_id on messages - Add compaction_enabled, compaction_threshold, compaction_model_id to bots - Implement compaction service (internal/compaction) with LLM summarization - Integrate into conversation flow: replace compacted messages with summaries wrapped in <summary> tags during context loading - Add REST API endpoints (GET/DELETE /bots/:bot_id/compaction/logs) - Add frontend Compaction tab with settings and log viewer - Wire compaction service into both dev (cmd/agent) and prod (cmd/memoh) entry points - Update test mocks to include new GetBotByID columns
This commit is contained in:
@@ -11,6 +11,7 @@ import (
|
||||
sdk "github.com/memohai/twilight-ai/sdk"
|
||||
|
||||
agentpkg "github.com/memohai/memoh/internal/agent"
|
||||
"github.com/memohai/memoh/internal/compaction"
|
||||
"github.com/memohai/memoh/internal/conversation"
|
||||
"github.com/memohai/memoh/internal/db/sqlc"
|
||||
memprovider "github.com/memohai/memoh/internal/memory/adapters"
|
||||
@@ -49,19 +50,20 @@ type gatewayAssetLoader interface {
|
||||
|
||||
// Resolver orchestrates chat with the internal agent.
|
||||
type Resolver struct {
|
||||
agent *agentpkg.Agent
|
||||
modelsService *models.Service
|
||||
queries *sqlc.Queries
|
||||
memoryRegistry *memprovider.Registry
|
||||
conversationSvc ConversationSettingsReader
|
||||
messageService messagepkg.Service
|
||||
settingsService *settings.Service
|
||||
sessionService SessionService
|
||||
eventPublisher messageevent.Publisher
|
||||
skillLoader SkillLoader
|
||||
assetLoader gatewayAssetLoader
|
||||
timeout time.Duration
|
||||
logger *slog.Logger
|
||||
agent *agentpkg.Agent
|
||||
modelsService *models.Service
|
||||
queries *sqlc.Queries
|
||||
memoryRegistry *memprovider.Registry
|
||||
conversationSvc ConversationSettingsReader
|
||||
messageService messagepkg.Service
|
||||
settingsService *settings.Service
|
||||
sessionService SessionService
|
||||
compactionService *compaction.Service
|
||||
eventPublisher messageevent.Publisher
|
||||
skillLoader SkillLoader
|
||||
assetLoader gatewayAssetLoader
|
||||
timeout time.Duration
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewResolver creates a Resolver that uses the internal agent directly.
|
||||
@@ -106,6 +108,11 @@ func (r *Resolver) SetGatewayAssetLoader(loader gatewayAssetLoader) {
|
||||
r.assetLoader = loader
|
||||
}
|
||||
|
||||
// SetCompactionService configures the compaction service for context compaction.
|
||||
func (r *Resolver) SetCompactionService(s *compaction.Service) {
|
||||
r.compactionService = s
|
||||
}
|
||||
|
||||
type usageInfo struct {
|
||||
InputTokens *int `json:"inputTokens"`
|
||||
OutputTokens *int `json:"outputTokens"`
|
||||
@@ -199,6 +206,7 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r
|
||||
}
|
||||
loaded = pruneHistoryForGateway(loaded)
|
||||
loaded = dedupePersistedCurrentUserMessage(loaded, req)
|
||||
loaded = r.replaceCompactedMessages(ctx, loaded)
|
||||
messages = trimMessagesByTokens(r.logger, loaded, historyBudget)
|
||||
r.logger.Debug("context trim result",
|
||||
slog.Int("loaded_messages", len(loaded)),
|
||||
@@ -318,6 +326,11 @@ func (r *Resolver) Chat(ctx context.Context, req conversation.ChatRequest) (conv
|
||||
if err := r.storeRound(ctx, req, roundMessages, rc.model.ID); err != nil {
|
||||
return conversation.ChatResponse{}, err
|
||||
}
|
||||
|
||||
if result.Usage != nil {
|
||||
go r.maybeCompact(context.WithoutCancel(ctx), req, rc, result.Usage.InputTokens)
|
||||
}
|
||||
|
||||
return conversation.ChatResponse{
|
||||
Messages: outputMessages,
|
||||
Model: rc.model.ModelID,
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
package flow
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
|
||||
"github.com/memohai/memoh/internal/compaction"
|
||||
"github.com/memohai/memoh/internal/conversation"
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
)
|
||||
|
||||
func (r *Resolver) maybeCompact(ctx context.Context, req conversation.ChatRequest, rc resolvedContext, inputTokens int) {
|
||||
if r.compactionService == nil || r.settingsService == nil {
|
||||
return
|
||||
}
|
||||
settings, err := r.settingsService.GetBot(ctx, req.BotID)
|
||||
if err != nil {
|
||||
r.logger.Warn("compaction: failed to load settings", slog.Any("error", err))
|
||||
return
|
||||
}
|
||||
if !settings.CompactionEnabled || settings.CompactionThreshold <= 0 {
|
||||
return
|
||||
}
|
||||
if !compaction.ShouldCompact(inputTokens, settings.CompactionThreshold) {
|
||||
return
|
||||
}
|
||||
|
||||
modelID := settings.CompactionModelID
|
||||
if modelID == "" {
|
||||
modelID = rc.model.ID
|
||||
}
|
||||
|
||||
cfg := compaction.TriggerConfig{
|
||||
BotID: req.BotID,
|
||||
SessionID: req.SessionID,
|
||||
}
|
||||
|
||||
model, err := r.modelsService.GetByID(ctx, modelID)
|
||||
if err != nil {
|
||||
r.logger.Warn("compaction: failed to resolve model", slog.Any("error", err))
|
||||
return
|
||||
}
|
||||
cfg.ModelID = model.ModelID
|
||||
cfg.ClientType = string(model.ClientType)
|
||||
|
||||
provider, err := models.FetchProviderByID(ctx, r.queries, model.LlmProviderID)
|
||||
if err != nil {
|
||||
r.logger.Warn("compaction: failed to fetch provider", slog.Any("error", err))
|
||||
return
|
||||
}
|
||||
cfg.APIKey = provider.ApiKey
|
||||
cfg.BaseURL = provider.BaseUrl
|
||||
|
||||
r.compactionService.TriggerCompaction(ctx, cfg)
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/memohai/memoh/internal/conversation"
|
||||
"github.com/memohai/memoh/internal/db"
|
||||
messagepkg "github.com/memohai/memoh/internal/message"
|
||||
)
|
||||
|
||||
@@ -19,6 +20,7 @@ type messageWithUsage struct {
|
||||
ExternalMessageID string
|
||||
Platform string
|
||||
SenderChannelID string
|
||||
CompactID string
|
||||
}
|
||||
|
||||
func (r *Resolver) loadMessages(ctx context.Context, chatID string, sessionID string, maxContextMinutes int) ([]messageWithUsage, error) {
|
||||
@@ -63,6 +65,7 @@ func (r *Resolver) loadMessages(ctx context.Context, chatID string, sessionID st
|
||||
ExternalMessageID: strings.TrimSpace(m.ExternalMessageID),
|
||||
Platform: strings.TrimSpace(m.Platform),
|
||||
SenderChannelID: strings.TrimSpace(m.SenderChannelIdentityID),
|
||||
CompactID: strings.TrimSpace(m.CompactID),
|
||||
})
|
||||
}
|
||||
return result, nil
|
||||
@@ -165,3 +168,62 @@ func trimMessagesByTokens(log *slog.Logger, messages []messageWithUsage, maxToke
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (r *Resolver) replaceCompactedMessages(ctx context.Context, messages []messageWithUsage) []messageWithUsage {
|
||||
if r.queries == nil {
|
||||
return messages
|
||||
}
|
||||
|
||||
compactGroups := make(map[string][]int) // compact_id -> indices
|
||||
for i, m := range messages {
|
||||
if m.CompactID != "" {
|
||||
compactGroups[m.CompactID] = append(compactGroups[m.CompactID], i)
|
||||
}
|
||||
}
|
||||
if len(compactGroups) == 0 {
|
||||
return messages
|
||||
}
|
||||
|
||||
summaries := make(map[string]string)
|
||||
for compactID := range compactGroups {
|
||||
cUUID, err := db.ParseUUID(compactID)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
log, err := r.queries.GetCompactionLogByID(ctx, cUUID)
|
||||
if err != nil {
|
||||
r.logger.Warn("replaceCompactedMessages: failed to load compact log", slog.String("compact_id", compactID), slog.Any("error", err))
|
||||
continue
|
||||
}
|
||||
if log.Status == "ok" && log.Summary != "" {
|
||||
summaries[compactID] = log.Summary
|
||||
}
|
||||
}
|
||||
|
||||
var result []messageWithUsage
|
||||
replaced := make(map[string]bool)
|
||||
for _, m := range messages {
|
||||
if m.CompactID == "" {
|
||||
result = append(result, m)
|
||||
continue
|
||||
}
|
||||
if replaced[m.CompactID] {
|
||||
continue
|
||||
}
|
||||
replaced[m.CompactID] = true
|
||||
summary, ok := summaries[m.CompactID]
|
||||
if !ok || summary == "" {
|
||||
for _, idx := range compactGroups[m.CompactID] {
|
||||
result = append(result, messages[idx])
|
||||
}
|
||||
continue
|
||||
}
|
||||
result = append(result, messageWithUsage{
|
||||
Message: conversation.ModelMessage{
|
||||
Role: "user",
|
||||
Content: json.RawMessage(`"<summary>\n` + summary + `\n</summary>"`),
|
||||
},
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -63,7 +63,7 @@ func (r *Resolver) StreamChat(ctx context.Context, req conversation.ChatRequest)
|
||||
continue
|
||||
}
|
||||
if !stored && event.IsTerminal() && len(event.Messages) > 0 {
|
||||
if _, storeErr := r.tryStoreStream(ctx, streamReq, data, rc.model.ID); storeErr != nil {
|
||||
if _, storeErr := r.tryStoreStream(ctx, streamReq, data, rc.model.ID, rc); storeErr != nil {
|
||||
r.logger.Error("stream persist failed", slog.Any("error", storeErr))
|
||||
} else {
|
||||
stored = true
|
||||
@@ -124,7 +124,7 @@ func (r *Resolver) StreamChatWS(
|
||||
}
|
||||
|
||||
if !stored && event.IsTerminal() && len(event.Messages) > 0 {
|
||||
if _, storeErr := r.tryStoreStream(ctx, req, data, modelID); storeErr != nil {
|
||||
if _, storeErr := r.tryStoreStream(ctx, req, data, modelID, rc); storeErr != nil {
|
||||
r.logger.Error("ws persist failed", slog.Any("error", storeErr))
|
||||
} else {
|
||||
stored = true
|
||||
@@ -142,10 +142,11 @@ func (r *Resolver) StreamChatWS(
|
||||
}
|
||||
|
||||
// tryStoreStream attempts to extract final messages from a stream event and persist them.
|
||||
func (r *Resolver) tryStoreStream(ctx context.Context, req conversation.ChatRequest, data []byte, modelID string) (bool, error) {
|
||||
func (r *Resolver) tryStoreStream(ctx context.Context, req conversation.ChatRequest, data []byte, modelID string, rc resolvedContext) (bool, error) {
|
||||
var envelope struct {
|
||||
Type string `json:"type"`
|
||||
Messages json.RawMessage `json:"messages"`
|
||||
Usage json.RawMessage `json:"usage,omitempty"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &envelope); err != nil {
|
||||
return false, nil
|
||||
@@ -161,5 +162,26 @@ func (r *Resolver) tryStoreStream(ctx context.Context, req conversation.ChatRequ
|
||||
outputMessages := sdkMessagesToModelMessages(sdkMsgs)
|
||||
roundMessages := prependUserMessage(req.Query, outputMessages)
|
||||
|
||||
return true, r.storeRound(ctx, req, roundMessages, modelID)
|
||||
if err := r.storeRound(ctx, req, roundMessages, modelID); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if inputTokens := extractInputTokensFromUsage(envelope.Usage); inputTokens > 0 {
|
||||
go r.maybeCompact(context.WithoutCancel(ctx), req, rc, inputTokens)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func extractInputTokensFromUsage(raw json.RawMessage) int {
|
||||
if len(raw) == 0 {
|
||||
return 0
|
||||
}
|
||||
var u struct {
|
||||
InputTokens int `json:"inputTokens"`
|
||||
}
|
||||
if json.Unmarshal(raw, &u) != nil {
|
||||
return 0
|
||||
}
|
||||
return u.InputTokens
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user