mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
364 lines
9.8 KiB
Go
364 lines
9.8 KiB
Go
package compaction
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"log/slog"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/jackc/pgx/v5/pgtype"
|
|
sdk "github.com/memohai/twilight-ai/sdk"
|
|
|
|
"github.com/memohai/memoh/internal/db"
|
|
"github.com/memohai/memoh/internal/db/sqlc"
|
|
"github.com/memohai/memoh/internal/models"
|
|
)
|
|
|
|
// Service manages context compaction for bot conversations.
type Service struct {
	queries *sqlc.Queries // generated sqlc query layer for compaction log / message tables
	logger  *slog.Logger  // structured logger; used to report background compaction failures
}
|
|
|
|
// NewService creates a new compaction Service.
|
|
func NewService(log *slog.Logger, queries *sqlc.Queries) *Service {
|
|
return &Service{
|
|
queries: queries,
|
|
logger: log,
|
|
}
|
|
}
|
|
|
|
// ShouldCompact reports whether inputTokens has reached threshold.
// A non-positive threshold disables compaction entirely.
func ShouldCompact(inputTokens, threshold int) bool {
	if threshold <= 0 {
		return false
	}
	return inputTokens >= threshold
}
|
|
|
|
// TriggerCompaction runs compaction in the background.
|
|
func (s *Service) TriggerCompaction(ctx context.Context, cfg TriggerConfig) {
|
|
go func() {
|
|
bgCtx := context.WithoutCancel(ctx)
|
|
if err := s.runCompaction(bgCtx, cfg); err != nil {
|
|
s.logger.Error("compaction failed", slog.String("bot_id", cfg.BotID), slog.String("session_id", cfg.SessionID), slog.String("error", err.Error()))
|
|
}
|
|
}()
|
|
}
|
|
|
|
// RunCompactionSync runs compaction synchronously and returns any error.
|
|
func (s *Service) RunCompactionSync(ctx context.Context, cfg TriggerConfig) error {
|
|
return s.runCompaction(ctx, cfg)
|
|
}
|
|
|
|
func (s *Service) runCompaction(ctx context.Context, cfg TriggerConfig) error {
|
|
botUUID, err := db.ParseUUID(cfg.BotID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
sessionUUID, err := db.ParseUUID(cfg.SessionID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
logRow, err := s.queries.CreateCompactionLog(ctx, sqlc.CreateCompactionLogParams{
|
|
BotID: botUUID,
|
|
SessionID: sessionUUID,
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
compactErr := s.doCompaction(ctx, logRow.ID, sessionUUID, cfg)
|
|
if compactErr != nil {
|
|
s.completeLog(ctx, logRow.ID, "error", "", compactErr.Error(), 0, nil, pgtype.UUID{})
|
|
}
|
|
return compactErr
|
|
}
|
|
|
|
// doCompaction selects the oldest uncompacted messages of a session,
// asks the configured model to summarize them (folding in summaries from
// earlier compactions), marks those messages as compacted, and records
// the outcome on the compaction log row identified by logID.
func (s *Service) doCompaction(ctx context.Context, logID pgtype.UUID, sessionUUID pgtype.UUID, cfg TriggerConfig) error {
	messages, err := s.queries.ListUncompactedMessagesBySession(ctx, sessionUUID)
	if err != nil {
		return err
	}
	// Nothing to compact: close the log as a successful no-op.
	if len(messages) == 0 {
		s.completeLog(ctx, logID, "ok", "", "", 0, nil, pgtype.UUID{})
		return nil
	}

	var toCompact []sqlc.ListUncompactedMessagesBySessionRow
	if cfg.TargetTokens > 0 {
		// Sync compaction: compress enough messages to bring context
		// down to TargetTokens. Calculate how many tokens to keep
		// (newest messages) and compact everything older.
		toCompact = splitByTarget(messages, cfg.TargetTokens)
	} else {
		// Background compaction: compact roughly the oldest Ratio% of
		// TotalInputTokens worth of messages.
		toCompact = splitByRatio(messages, cfg.TotalInputTokens, cfg.Ratio)
	}
	if len(toCompact) == 0 {
		s.completeLog(ctx, logID, "ok", "", "", 0, nil, pgtype.UUID{})
		return nil
	}

	// Cap the compaction input to avoid exceeding the compaction model's
	// context window. MaxCompactTokens is typically set to 90% of the model's
	// window. If not set, use a conservative default of 30K tokens.
	maxCompactTokens := cfg.MaxCompactTokens
	if maxCompactTokens <= 0 {
		maxCompactTokens = 30000
	}
	s.logger.Info("compaction: before trim",
		slog.Int("messages", len(toCompact)),
		slog.Int("total_uncompacted", len(messages)),
		slog.Int("max_compact_tokens", maxCompactTokens),
	)
	toCompact = trimCompactMessages(toCompact, maxCompactTokens)
	s.logger.Info("compaction: after trim",
		slog.Int("messages", len(toCompact)),
	)

	// Collect summaries from earlier compactions of this session so the
	// prompt can carry prior context into the new summary.
	priorLogs, err := s.queries.ListCompactionLogsBySession(ctx, sessionUUID)
	if err != nil {
		return err
	}
	var priorSummaries []string
	for _, l := range priorLogs {
		if l.Summary != "" {
			priorSummaries = append(priorSummaries, l.Summary)
		}
	}

	// Build the prompt entries and remember which message IDs are covered,
	// so exactly the summarized messages get marked compacted below.
	entries := make([]messageEntry, 0, len(toCompact))
	messageIDs := make([]pgtype.UUID, 0, len(toCompact))
	for _, m := range toCompact {
		entries = append(entries, messageEntry{
			Role:    m.Role,
			Content: string(m.Content),
		})
		messageIDs = append(messageIDs, m.ID)
	}

	userPrompt := buildUserPrompt(priorSummaries, entries)

	model := models.NewSDKChatModel(models.SDKModelConfig{
		ClientType:     cfg.ClientType,
		BaseURL:        cfg.BaseURL,
		APIKey:         cfg.APIKey,
		CodexAccountID: cfg.CodexAccountID,
		ModelID:        cfg.ModelID,
		HTTPClient:     cfg.HTTPClient,
	})

	result, err := sdk.GenerateTextResult(ctx,
		sdk.WithModel(model),
		sdk.WithSystem(systemPrompt),
		sdk.WithMessages([]sdk.Message{sdk.UserMessage(userPrompt)}),
	)
	if err != nil {
		return err
	}

	// Usage is recorded best-effort; a marshal failure just leaves it empty.
	usageJSON, _ := json.Marshal(result.Usage)

	modelUUID := db.ParseUUIDOrEmpty(cfg.ModelID)

	// Link the summarized messages to this compaction log row.
	if err := s.queries.MarkMessagesCompacted(ctx, sqlc.MarkMessagesCompactedParams{
		CompactID: logID,
		Column2:   messageIDs,
	}); err != nil {
		return err
	}

	s.completeLog(ctx, logID, "ok", result.Text, "", len(messageIDs), usageJSON, modelUUID)
	return nil
}
|
|
|
|
func (s *Service) completeLog(ctx context.Context, logID pgtype.UUID, status, summary, errMsg string, messageCount int, usage []byte, modelID pgtype.UUID) {
|
|
if _, err := s.queries.CompleteCompactionLog(ctx, sqlc.CompleteCompactionLogParams{
|
|
ID: logID,
|
|
Status: status,
|
|
Summary: summary,
|
|
MessageCount: int32(messageCount), //nolint:gosec // count always small
|
|
ErrorMessage: errMsg,
|
|
Usage: usage,
|
|
ModelID: modelID,
|
|
}); err != nil {
|
|
s.logger.Error("failed to complete compaction log", slog.String("error", err.Error()))
|
|
}
|
|
}
|
|
|
|
// ListLogs returns paginated compaction logs for a bot.
|
|
func (s *Service) ListLogs(ctx context.Context, botID string, limit, offset int) ([]Log, int64, error) {
|
|
botUUID, err := db.ParseUUID(botID)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
if limit <= 0 || limit > 100 {
|
|
limit = 50
|
|
}
|
|
if offset < 0 {
|
|
offset = 0
|
|
}
|
|
|
|
total, err := s.queries.CountCompactionLogsByBot(ctx, botUUID)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
rows, err := s.queries.ListCompactionLogsByBot(ctx, sqlc.ListCompactionLogsByBotParams{
|
|
BotID: botUUID,
|
|
Limit: int32(limit), //nolint:gosec // clamped above
|
|
Offset: int32(offset), //nolint:gosec // validated above
|
|
})
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
logs := make([]Log, len(rows))
|
|
for i, r := range rows {
|
|
logs[i] = toLog(r)
|
|
}
|
|
return logs, total, nil
|
|
}
|
|
|
|
// DeleteLogs deletes all compaction logs for a bot.
|
|
func (s *Service) DeleteLogs(ctx context.Context, botID string) error {
|
|
botUUID, err := db.ParseUUID(botID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return s.queries.DeleteCompactionLogsByBot(ctx, botUUID)
|
|
}
|
|
|
|
func toLog(r sqlc.BotHistoryMessageCompact) Log {
|
|
l := Log{
|
|
ID: formatUUID(r.ID),
|
|
BotID: formatUUID(r.BotID),
|
|
SessionID: formatUUID(r.SessionID),
|
|
Status: r.Status,
|
|
Summary: r.Summary,
|
|
MessageCount: int(r.MessageCount),
|
|
ErrorMessage: r.ErrorMessage,
|
|
ModelID: formatUUID(r.ModelID),
|
|
StartedAt: r.StartedAt.Time,
|
|
}
|
|
if r.CompletedAt.Valid {
|
|
t := r.CompletedAt.Time
|
|
l.CompletedAt = &t
|
|
}
|
|
if len(r.Usage) > 0 {
|
|
var u any
|
|
if json.Unmarshal(r.Usage, &u) == nil {
|
|
l.Usage = u
|
|
}
|
|
}
|
|
return l
|
|
}
|
|
|
|
func formatUUID(id pgtype.UUID) string {
|
|
if !id.Valid {
|
|
return ""
|
|
}
|
|
return uuid.UUID(id.Bytes).String()
|
|
}
|
|
|
|
// splitByRatio splits messages so that roughly the first ratio% (by token weight)
|
|
// are returned for compaction, and the rest are kept as-is.
|
|
// When ratio >= 100, all messages are returned for compaction.
|
|
// When ratio <= 0 or totalInputTokens <= 0 or messages is empty, nil is returned (no compaction).
|
|
func splitByRatio(messages []sqlc.ListUncompactedMessagesBySessionRow, totalInputTokens, ratio int) []sqlc.ListUncompactedMessagesBySessionRow {
|
|
if ratio <= 0 || totalInputTokens <= 0 || len(messages) == 0 {
|
|
return nil
|
|
}
|
|
if ratio >= 100 {
|
|
return messages
|
|
}
|
|
|
|
keepTokens := totalInputTokens * (100 - ratio) / 100
|
|
if keepTokens <= 0 {
|
|
return messages
|
|
}
|
|
|
|
accumulated := 0
|
|
cutoff := len(messages)
|
|
for i := len(messages) - 1; i >= 0; i-- {
|
|
accumulated += estimateRowTokens(messages[i])
|
|
if accumulated >= keepTokens {
|
|
cutoff = i + 1
|
|
break
|
|
}
|
|
}
|
|
|
|
if cutoff <= 0 {
|
|
return nil
|
|
}
|
|
if cutoff >= len(messages) {
|
|
return messages
|
|
}
|
|
return messages[:cutoff]
|
|
}
|
|
|
|
// splitByTarget returns the oldest messages to compact so that the remaining
|
|
// newest messages fit within targetTokens. This is used for synchronous
|
|
// compaction where the goal is to reduce context to a specific size.
|
|
func splitByTarget(messages []sqlc.ListUncompactedMessagesBySessionRow, targetTokens int) []sqlc.ListUncompactedMessagesBySessionRow {
|
|
if targetTokens <= 0 || len(messages) == 0 {
|
|
return nil
|
|
}
|
|
// Scan from newest to oldest, keeping messages that fit within target.
|
|
accumulated := 0
|
|
cutoff := 0
|
|
for i := len(messages) - 1; i >= 0; i-- {
|
|
accumulated += estimateRowTokens(messages[i])
|
|
if accumulated > targetTokens {
|
|
cutoff = i + 1
|
|
break
|
|
}
|
|
}
|
|
if cutoff <= 0 {
|
|
return nil
|
|
}
|
|
return messages[:cutoff]
|
|
}
|
|
|
|
// usagePayload is the subset of a message's stored usage JSON that the
// token estimator reads.
type usagePayload struct {
	OutputTokens *int `json:"output_tokens"` // nil when the field is absent
}
|
|
|
|
func estimateRowTokens(m sqlc.ListUncompactedMessagesBySessionRow) int {
|
|
if len(m.Usage) > 0 {
|
|
var u usagePayload
|
|
if json.Unmarshal(m.Usage, &u) == nil && u.OutputTokens != nil && *u.OutputTokens > 0 {
|
|
return *u.OutputTokens
|
|
}
|
|
}
|
|
return len(m.Content) / 4
|
|
}
|
|
|
|
// trimCompactMessages trims the compaction input from the tail (oldest)
|
|
// so the total estimated tokens stay within maxTokens.
|
|
func trimCompactMessages(messages []sqlc.ListUncompactedMessagesBySessionRow, maxTokens int) []sqlc.ListUncompactedMessagesBySessionRow {
|
|
if len(messages) == 0 || maxTokens <= 0 {
|
|
return messages
|
|
}
|
|
total := 0
|
|
for _, m := range messages {
|
|
total += estimateRowTokens(m)
|
|
}
|
|
if total <= maxTokens {
|
|
return messages
|
|
}
|
|
// Drop oldest messages from the tail until within budget.
|
|
accumulated := 0
|
|
cutoff := len(messages)
|
|
for i := len(messages) - 1; i >= 0; i-- {
|
|
accumulated += estimateRowTokens(messages[i])
|
|
if accumulated > maxTokens {
|
|
cutoff = i + 1
|
|
break
|
|
}
|
|
}
|
|
if cutoff >= len(messages) {
|
|
return messages
|
|
}
|
|
return messages[cutoff:]
|
|
}
|