feat: add context compaction to automatically summarize old messages (#compaction) (#276)

When input tokens exceed a configurable threshold after a conversation round,
the system asynchronously compacts older messages into a summary. Cascading
compactions reference prior summaries via <prior_context> tags to maintain
conversational continuity without duplicating content.

- Add bot_history_message_compacts table and compact_id on messages
- Add compaction_enabled, compaction_threshold, compaction_model_id to bots
- Implement compaction service (internal/compaction) with LLM summarization
- Integrate into conversation flow: replace compacted messages with summaries
  wrapped in <summary> tags during context loading
- Add REST API endpoints (GET/DELETE /bots/:bot_id/compaction/logs)
- Add frontend Compaction tab with settings and log viewer
- Wire compaction service into both dev (cmd/agent) and prod (cmd/memoh) entry points
- Update test mocks to include new GetBotByID columns
This commit is contained in:
Acbox Liu
2026-03-22 14:26:00 +08:00
committed by GitHub
parent 91e5e44509
commit de62f94315
40 changed files with 2375 additions and 197 deletions
+26 -13
View File
@@ -11,6 +11,7 @@ import (
sdk "github.com/memohai/twilight-ai/sdk"
agentpkg "github.com/memohai/memoh/internal/agent"
"github.com/memohai/memoh/internal/compaction"
"github.com/memohai/memoh/internal/conversation"
"github.com/memohai/memoh/internal/db/sqlc"
memprovider "github.com/memohai/memoh/internal/memory/adapters"
@@ -49,19 +50,20 @@ type gatewayAssetLoader interface {
// Resolver orchestrates chat with the internal agent.
type Resolver struct {
agent *agentpkg.Agent
modelsService *models.Service
queries *sqlc.Queries
memoryRegistry *memprovider.Registry
conversationSvc ConversationSettingsReader
messageService messagepkg.Service
settingsService *settings.Service
sessionService SessionService
eventPublisher messageevent.Publisher
skillLoader SkillLoader
assetLoader gatewayAssetLoader
timeout time.Duration
logger *slog.Logger
agent *agentpkg.Agent
modelsService *models.Service
queries *sqlc.Queries
memoryRegistry *memprovider.Registry
conversationSvc ConversationSettingsReader
messageService messagepkg.Service
settingsService *settings.Service
sessionService SessionService
compactionService *compaction.Service
eventPublisher messageevent.Publisher
skillLoader SkillLoader
assetLoader gatewayAssetLoader
timeout time.Duration
logger *slog.Logger
}
// NewResolver creates a Resolver that uses the internal agent directly.
@@ -106,6 +108,11 @@ func (r *Resolver) SetGatewayAssetLoader(loader gatewayAssetLoader) {
r.assetLoader = loader
}
// SetCompactionService wires in the service used for asynchronous context
// compaction. A nil service leaves compaction disabled.
func (r *Resolver) SetCompactionService(svc *compaction.Service) {
	r.compactionService = svc
}
type usageInfo struct {
InputTokens *int `json:"inputTokens"`
OutputTokens *int `json:"outputTokens"`
@@ -199,6 +206,7 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r
}
loaded = pruneHistoryForGateway(loaded)
loaded = dedupePersistedCurrentUserMessage(loaded, req)
loaded = r.replaceCompactedMessages(ctx, loaded)
messages = trimMessagesByTokens(r.logger, loaded, historyBudget)
r.logger.Debug("context trim result",
slog.Int("loaded_messages", len(loaded)),
@@ -318,6 +326,11 @@ func (r *Resolver) Chat(ctx context.Context, req conversation.ChatRequest) (conv
if err := r.storeRound(ctx, req, roundMessages, rc.model.ID); err != nil {
return conversation.ChatResponse{}, err
}
if result.Usage != nil {
go r.maybeCompact(context.WithoutCancel(ctx), req, rc, result.Usage.InputTokens)
}
return conversation.ChatResponse{
Messages: outputMessages,
Model: rc.model.ModelID,
@@ -0,0 +1,55 @@
package flow
import (
"context"
"log/slog"
"github.com/memohai/memoh/internal/compaction"
"github.com/memohai/memoh/internal/conversation"
"github.com/memohai/memoh/internal/models"
)
// maybeCompact triggers asynchronous context compaction for the bot when the
// round's input token count crosses the bot's configured threshold. It is
// invoked in a detached goroutine, so every failure is logged and swallowed
// rather than surfaced to the chat response.
func (r *Resolver) maybeCompact(ctx context.Context, req conversation.ChatRequest, rc resolvedContext, inputTokens int) {
	// Guard every dependency used below; this runs on a goroutine, so a
	// nil-pointer dereference here would crash the whole process instead of
	// failing a single request. (The original only checked the first two.)
	if r.compactionService == nil || r.settingsService == nil || r.modelsService == nil || r.queries == nil {
		return
	}
	settings, err := r.settingsService.GetBot(ctx, req.BotID)
	if err != nil {
		r.logger.Warn("compaction: failed to load settings", slog.Any("error", err))
		return
	}
	if !settings.CompactionEnabled || settings.CompactionThreshold <= 0 {
		return
	}
	if !compaction.ShouldCompact(inputTokens, settings.CompactionThreshold) {
		return
	}
	// Prefer the bot's dedicated compaction model; fall back to the model
	// that served the conversation round itself.
	modelID := settings.CompactionModelID
	if modelID == "" {
		modelID = rc.model.ID
	}
	model, err := r.modelsService.GetByID(ctx, modelID)
	if err != nil {
		r.logger.Warn("compaction: failed to resolve model", slog.Any("error", err))
		return
	}
	provider, err := models.FetchProviderByID(ctx, r.queries, model.LlmProviderID)
	if err != nil {
		r.logger.Warn("compaction: failed to fetch provider", slog.Any("error", err))
		return
	}
	r.compactionService.TriggerCompaction(ctx, compaction.TriggerConfig{
		BotID:      req.BotID,
		SessionID:  req.SessionID,
		ModelID:    model.ModelID,
		ClientType: string(model.ClientType),
		APIKey:     provider.ApiKey,
		BaseURL:    provider.BaseUrl,
	})
}
@@ -8,6 +8,7 @@ import (
"time"
"github.com/memohai/memoh/internal/conversation"
"github.com/memohai/memoh/internal/db"
messagepkg "github.com/memohai/memoh/internal/message"
)
@@ -19,6 +20,7 @@ type messageWithUsage struct {
ExternalMessageID string
Platform string
SenderChannelID string
CompactID string
}
func (r *Resolver) loadMessages(ctx context.Context, chatID string, sessionID string, maxContextMinutes int) ([]messageWithUsage, error) {
@@ -63,6 +65,7 @@ func (r *Resolver) loadMessages(ctx context.Context, chatID string, sessionID st
ExternalMessageID: strings.TrimSpace(m.ExternalMessageID),
Platform: strings.TrimSpace(m.Platform),
SenderChannelID: strings.TrimSpace(m.SenderChannelIdentityID),
CompactID: strings.TrimSpace(m.CompactID),
})
}
return result, nil
@@ -165,3 +168,62 @@ func trimMessagesByTokens(log *slog.Logger, messages []messageWithUsage, maxToke
}
return result
}
// replaceCompactedMessages collapses every run of messages that belongs to a
// completed compaction into one synthetic user message carrying the stored
// summary wrapped in <summary> tags. Groups whose compaction log is missing,
// failed, or has an empty summary are passed through unchanged. Message order
// is preserved; a summary is emitted at the position of the group's first
// message.
func (r *Resolver) replaceCompactedMessages(ctx context.Context, messages []messageWithUsage) []messageWithUsage {
	if r.queries == nil {
		return messages
	}
	// Group message indices by compact_id so an entire compacted span can be
	// swapped for a single summary message.
	compactGroups := make(map[string][]int) // compact_id -> indices
	for i, m := range messages {
		if m.CompactID != "" {
			compactGroups[m.CompactID] = append(compactGroups[m.CompactID], i)
		}
	}
	if len(compactGroups) == 0 {
		return messages
	}
	// Resolve each referenced compaction log; only successful logs with a
	// non-empty summary are eligible for substitution.
	summaries := make(map[string]string, len(compactGroups))
	for compactID := range compactGroups {
		cUUID, err := db.ParseUUID(compactID)
		if err != nil {
			continue
		}
		log, err := r.queries.GetCompactionLogByID(ctx, cUUID)
		if err != nil {
			r.logger.Warn("replaceCompactedMessages: failed to load compact log", slog.String("compact_id", compactID), slog.Any("error", err))
			continue
		}
		if log.Status == "ok" && log.Summary != "" {
			summaries[compactID] = log.Summary
		}
	}
	var result []messageWithUsage
	replaced := make(map[string]bool)
	for _, m := range messages {
		if m.CompactID == "" {
			result = append(result, m)
			continue
		}
		if replaced[m.CompactID] {
			continue
		}
		replaced[m.CompactID] = true
		summary, ok := summaries[m.CompactID]
		if !ok || summary == "" {
			// No usable summary: keep the original messages of this group.
			for _, idx := range compactGroups[m.CompactID] {
				result = append(result, messages[idx])
			}
			continue
		}
		// Marshal the summary so quotes, backslashes, and newlines inside it
		// are escaped correctly. The previous raw concatenation produced
		// invalid JSON whenever the summary contained a double quote.
		// json.Marshal of a plain string cannot fail, so the error is ignored.
		content, _ := json.Marshal("<summary>\n" + summary + "\n</summary>")
		result = append(result, messageWithUsage{
			Message: conversation.ModelMessage{
				Role:    "user",
				Content: json.RawMessage(content),
			},
		})
	}
	return result
}
+26 -4
View File
@@ -63,7 +63,7 @@ func (r *Resolver) StreamChat(ctx context.Context, req conversation.ChatRequest)
continue
}
if !stored && event.IsTerminal() && len(event.Messages) > 0 {
if _, storeErr := r.tryStoreStream(ctx, streamReq, data, rc.model.ID); storeErr != nil {
if _, storeErr := r.tryStoreStream(ctx, streamReq, data, rc.model.ID, rc); storeErr != nil {
r.logger.Error("stream persist failed", slog.Any("error", storeErr))
} else {
stored = true
@@ -124,7 +124,7 @@ func (r *Resolver) StreamChatWS(
}
if !stored && event.IsTerminal() && len(event.Messages) > 0 {
if _, storeErr := r.tryStoreStream(ctx, req, data, modelID); storeErr != nil {
if _, storeErr := r.tryStoreStream(ctx, req, data, modelID, rc); storeErr != nil {
r.logger.Error("ws persist failed", slog.Any("error", storeErr))
} else {
stored = true
@@ -142,10 +142,11 @@ func (r *Resolver) StreamChatWS(
}
// tryStoreStream attempts to extract final messages from a stream event and persist them.
func (r *Resolver) tryStoreStream(ctx context.Context, req conversation.ChatRequest, data []byte, modelID string) (bool, error) {
func (r *Resolver) tryStoreStream(ctx context.Context, req conversation.ChatRequest, data []byte, modelID string, rc resolvedContext) (bool, error) {
var envelope struct {
Type string `json:"type"`
Messages json.RawMessage `json:"messages"`
Usage json.RawMessage `json:"usage,omitempty"`
}
if err := json.Unmarshal(data, &envelope); err != nil {
return false, nil
@@ -161,5 +162,26 @@ func (r *Resolver) tryStoreStream(ctx context.Context, req conversation.ChatRequ
outputMessages := sdkMessagesToModelMessages(sdkMsgs)
roundMessages := prependUserMessage(req.Query, outputMessages)
return true, r.storeRound(ctx, req, roundMessages, modelID)
if err := r.storeRound(ctx, req, roundMessages, modelID); err != nil {
return false, err
}
if inputTokens := extractInputTokensFromUsage(envelope.Usage); inputTokens > 0 {
go r.maybeCompact(context.WithoutCancel(ctx), req, rc, inputTokens)
}
return true, nil
}
func extractInputTokensFromUsage(raw json.RawMessage) int {
if len(raw) == 0 {
return 0
}
var u struct {
InputTokens int `json:"inputTokens"`
}
if json.Unmarshal(raw, &u) != nil {
return 0
}
return u.InputTokens
}