Files
Memoh/internal/conversation/flow/resolver_history.go
T
Acbox Liu de62f94315 feat: add context compaction to automatically summarize old messages (#compaction) (#276)
When input tokens exceed a configurable threshold after a conversation round,
the system asynchronously compacts older messages into a summary. Cascading
compactions reference prior summaries via <prior_context> tags to maintain
conversational continuity without duplicating content.

- Add bot_history_message_compacts table and compact_id on messages
- Add compaction_enabled, compaction_threshold, compaction_model_id to bots
- Implement compaction service (internal/compaction) with LLM summarization
- Integrate into conversation flow: replace compacted messages with summaries
  wrapped in <summary> tags during context loading
- Add REST API endpoints (GET/DELETE /bots/:bot_id/compaction/logs)
- Add frontend Compaction tab with settings and log viewer
- Wire compaction service into both dev (cmd/agent) and prod (cmd/memoh) entry points
- Update test mocks to include new GetBotByID columns
2026-03-22 14:26:00 +08:00

230 lines
6.6 KiB
Go

package flow
import (
"context"
"encoding/json"
"log/slog"
"strings"
"time"
"github.com/memohai/memoh/internal/conversation"
"github.com/memohai/memoh/internal/db"
messagepkg "github.com/memohai/memoh/internal/message"
)
// messageWithUsage pairs a decoded model message with the persistence
// metadata used downstream for deduplication, token-based trimming, and
// compaction-summary replacement.
type messageWithUsage struct {
	Message conversation.ModelMessage
	// Token counts parsed from the stored usage JSON; nil when the row
	// carried no usage payload.
	UsageInputTokens  *int
	UsageOutputTokens *int
	// Whitespace-trimmed persistence metadata copied from the stored row.
	SessionID         string
	ExternalMessageID string
	Platform          string
	SenderChannelID   string
	// CompactID links the message to a compaction log entry; "" when the
	// message has not been compacted.
	CompactID string
}
// loadMessages fetches active messages newer than the context window (now
// minus maxContextMinutes) for the session when sessionID is non-blank,
// otherwise for the whole chat, and decodes each row into a model message
// plus its usage and routing metadata. Returns nil when no message service
// is configured.
func (r *Resolver) loadMessages(ctx context.Context, chatID string, sessionID string, maxContextMinutes int) ([]messageWithUsage, error) {
	if r.messageService == nil {
		return nil, nil
	}
	windowStart := time.Now().UTC().Add(-time.Duration(maxContextMinutes) * time.Minute)

	var (
		stored []messagepkg.Message
		err    error
	)
	if strings.TrimSpace(sessionID) == "" {
		stored, err = r.messageService.ListActiveSince(ctx, chatID, windowStart)
	} else {
		stored, err = r.messageService.ListActiveSinceBySession(ctx, sessionID, windowStart)
	}
	if err != nil {
		return nil, err
	}

	var out []messageWithUsage
	for _, row := range stored {
		var decoded conversation.ModelMessage
		if uerr := json.Unmarshal(row.Content, &decoded); uerr != nil {
			r.logger.Warn("loadMessages: content unmarshal failed, treating as raw text",
				slog.String("chat_id", chatID), slog.Any("error", uerr))
			decoded = conversation.ModelMessage{Role: row.Role, Content: row.Content}
		} else {
			// The stored role column wins over whatever role the JSON carried.
			decoded.Role = row.Role
		}

		var inTokens, outTokens *int
		if len(row.Usage) > 0 {
			var usage usageInfo
			if json.Unmarshal(row.Usage, &usage) == nil {
				inTokens = usage.InputTokens
				outTokens = usage.OutputTokens
			}
		}

		out = append(out, messageWithUsage{
			Message:           decoded,
			UsageInputTokens:  inTokens,
			UsageOutputTokens: outTokens,
			SessionID:         strings.TrimSpace(row.SessionID),
			ExternalMessageID: strings.TrimSpace(row.ExternalMessageID),
			Platform:          strings.TrimSpace(row.Platform),
			SenderChannelID:   strings.TrimSpace(row.SenderChannelIdentityID),
			CompactID:         strings.TrimSpace(row.CompactID),
		})
	}
	return out, nil
}
// dedupePersistedCurrentUserMessage removes the already-persisted copy of the
// request's current user message from the loaded history so the model does not
// see it twice. Matching walks newest-first and requires an exact external
// message ID match; session, platform, and sender-channel identity only
// disqualify a candidate when both sides carry a value and the values differ.
// Returns the input unchanged when the request's message was not persisted,
// the history is empty, or no match is found.
//
// When an element is dropped, a freshly allocated slice is returned so the
// caller's input slice is never mutated (an in-place
// append(messages[:i], messages[i+1:]...) would shift elements in the shared
// backing array).
func dedupePersistedCurrentUserMessage(messages []messageWithUsage, req conversation.ChatRequest) []messageWithUsage {
	if !req.UserMessagePersisted || len(messages) == 0 {
		return messages
	}
	targetSessionID := strings.TrimSpace(req.SessionID)
	targetExternalID := strings.TrimSpace(req.ExternalMessageID)
	targetPlatform := strings.TrimSpace(req.CurrentChannel)
	targetSenderChannelID := strings.TrimSpace(req.SourceChannelIdentityID)
	// Without an external message ID there is nothing reliable to match on.
	if targetExternalID == "" {
		return messages
	}
	for i := len(messages) - 1; i >= 0; i-- {
		item := messages[i]
		if !strings.EqualFold(strings.TrimSpace(item.Message.Role), "user") {
			continue
		}
		if strings.TrimSpace(item.ExternalMessageID) != targetExternalID {
			continue
		}
		if targetSessionID != "" && item.SessionID != "" && item.SessionID != targetSessionID {
			continue
		}
		if targetPlatform != "" && item.Platform != "" && !strings.EqualFold(item.Platform, targetPlatform) {
			continue
		}
		if targetSenderChannelID != "" && item.SenderChannelID != "" && item.SenderChannelID != targetSenderChannelID {
			continue
		}
		// Build a new slice rather than appending in place, so the caller's
		// backing array is left untouched.
		out := make([]messageWithUsage, 0, len(messages)-1)
		out = append(out, messages[:i]...)
		out = append(out, messages[i+1:]...)
		return out
	}
	return messages
}
// estimateMessageTokens approximates a message's token cost at roughly four
// characters per token. When the message has no plain-text content, the
// length of its JSON-encoded content is used instead.
func estimateMessageTokens(msg conversation.ModelMessage) int {
	if text := msg.TextContent(); len(text) > 0 {
		return len(text) / 4
	}
	encoded, _ := json.Marshal(msg.Content)
	return len(encoded) / 4
}
// trimMessagesByTokens drops the oldest messages so the kept tail of the
// history stays within maxTokens, then returns the remaining model messages.
// maxTokens == 0 disables trimming entirely. Per-message cost is the stored
// usage output-token count when available, otherwise a character-based
// estimate so user/tool messages are not free-passed.
func trimMessagesByTokens(log *slog.Logger, messages []messageWithUsage, maxTokens int) []conversation.ModelMessage {
	// toModel strips the usage/metadata wrapper from a run of messages.
	toModel := func(items []messageWithUsage) []conversation.ModelMessage {
		converted := make([]conversation.ModelMessage, 0, len(items))
		for _, item := range items {
			converted = append(converted, item.Message)
		}
		return converted
	}

	if maxTokens == 0 || len(messages) == 0 {
		return toModel(messages)
	}

	// Walk newest to oldest, accumulating cost until the budget is exceeded;
	// everything at or before that point is cut.
	spent := 0
	start := 0
	usageBacked := 0
	for i := len(messages) - 1; i >= 0; i-- {
		if tokens := messages[i].UsageOutputTokens; tokens != nil {
			spent += *tokens
			usageBacked++
		} else {
			spent += estimateMessageTokens(messages[i].Message)
		}
		if spent > maxTokens {
			start = i + 1
			break
		}
	}

	// Keep provider-valid ordering: a "tool" message must follow the assistant
	// tool call that produced it. Head-trimming can leave an orphaned leading
	// tool message, which providers reject with 400s — skip past any such run.
	for start < len(messages) && strings.EqualFold(strings.TrimSpace(messages[start].Message.Role), "tool") {
		start++
	}

	if log != nil {
		log.Debug("trimMessagesByTokens",
			slog.Int("total_messages", len(messages)),
			slog.Int("messages_with_usage", usageBacked),
			slog.Int("accumulated_output_tokens", spent),
			slog.Int("max_tokens", maxTokens),
			slog.Int("cutoff_index", start),
			slog.Int("kept_messages", len(messages)-start),
		)
	}
	return toModel(messages[start:])
}
// replaceCompactedMessages collapses each group of messages sharing a
// compact_id into a single synthetic user message containing that group's
// stored summary wrapped in <summary> tags, inserted at the position of the
// group's first message. Groups whose compaction log cannot be loaded, is not
// status "ok", or has an empty summary keep their original messages. Returns
// the input unchanged when no query handle is configured or no message is
// compacted.
func (r *Resolver) replaceCompactedMessages(ctx context.Context, messages []messageWithUsage) []messageWithUsage {
	if r.queries == nil {
		return messages
	}
	// Group message indices by compact ID.
	compactGroups := make(map[string][]int)
	for i, m := range messages {
		if m.CompactID != "" {
			compactGroups[m.CompactID] = append(compactGroups[m.CompactID], i)
		}
	}
	if len(compactGroups) == 0 {
		return messages
	}
	// Load the usable summary (status "ok", non-empty) for each group.
	summaries := make(map[string]string, len(compactGroups))
	for compactID := range compactGroups {
		cUUID, err := db.ParseUUID(compactID)
		if err != nil {
			continue
		}
		compactLog, err := r.queries.GetCompactionLogByID(ctx, cUUID)
		if err != nil {
			r.logger.Warn("replaceCompactedMessages: failed to load compact log", slog.String("compact_id", compactID), slog.Any("error", err))
			continue
		}
		if compactLog.Status == "ok" && compactLog.Summary != "" {
			summaries[compactID] = compactLog.Summary
		}
	}
	var result []messageWithUsage
	replaced := make(map[string]bool)
	for _, m := range messages {
		if m.CompactID == "" {
			result = append(result, m)
			continue
		}
		if replaced[m.CompactID] {
			continue
		}
		replaced[m.CompactID] = true
		summary, ok := summaries[m.CompactID]
		if !ok || summary == "" {
			// No usable summary: keep the group's original messages in order.
			for _, idx := range compactGroups[m.CompactID] {
				result = append(result, messages[idx])
			}
			continue
		}
		// json.Marshal properly escapes quotes, backslashes, and control
		// characters in the summary. The previous raw-string concatenation
		// produced invalid JSON whenever the summary contained a double quote
		// or a literal newline, corrupting the message content.
		content, err := json.Marshal("<summary>\n" + summary + "\n</summary>")
		if err != nil {
			// Marshaling a plain string should never fail; fall back to the
			// original messages rather than dropping context.
			r.logger.Warn("replaceCompactedMessages: failed to encode summary", slog.String("compact_id", m.CompactID), slog.Any("error", err))
			for _, idx := range compactGroups[m.CompactID] {
				result = append(result, messages[idx])
			}
			continue
		}
		result = append(result, messageWithUsage{
			Message: conversation.ModelMessage{
				Role:    "user",
				Content: json.RawMessage(content),
			},
		})
	}
	return result
}