// Package compaction summarizes old conversation messages for a bot
// session so that the model context stays within its token budget.
package compaction
import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"

	"github.com/google/uuid"
	"github.com/jackc/pgx/v5/pgtype"

	"github.com/memohai/memoh/internal/db"
	"github.com/memohai/memoh/internal/db/sqlc"
	"github.com/memohai/memoh/internal/models"
	sdk "github.com/memohai/twilight-ai/sdk"
)
// Service manages context compaction for bot conversations.
type Service struct {
	queries *sqlc.Queries // generated sqlc query layer for compaction-log and message tables
	logger  *slog.Logger  // structured logger, used mainly for background-run failures
}
// NewService creates a new compaction Service backed by the given
// query layer and logger.
func NewService(log *slog.Logger, queries *sqlc.Queries) *Service {
	svc := &Service{
		logger:  log,
		queries: queries,
	}
	return svc
}
// ShouldCompact reports whether inputTokens has reached the compaction
// threshold. A threshold of zero or less disables compaction entirely.
func ShouldCompact(inputTokens, threshold int) bool {
	if threshold <= 0 {
		return false
	}
	return inputTokens >= threshold
}
// TriggerCompaction runs compaction in a background goroutine. The run is
// detached from the caller's cancellation (context values are preserved),
// and failures are logged rather than returned.
func (s *Service) TriggerCompaction(ctx context.Context, cfg TriggerConfig) {
	detached := context.WithoutCancel(ctx)
	go func() {
		err := s.runCompaction(detached, cfg)
		if err == nil {
			return
		}
		s.logger.Error("compaction failed",
			slog.String("bot_id", cfg.BotID),
			slog.String("session_id", cfg.SessionID),
			slog.String("error", err.Error()),
		)
	}()
}
// RunCompactionSync runs compaction synchronously and returns any error.
func (s *Service) RunCompactionSync(ctx context.Context, cfg TriggerConfig) error {
	err := s.runCompaction(ctx, cfg)
	return err
}
// runCompaction validates the bot/session identifiers, records a
// compaction-log row, and runs the compaction itself. On failure the log
// row is completed with status "error" (best-effort) before the error is
// returned. Errors are wrapped so callers can tell which step failed.
func (s *Service) runCompaction(ctx context.Context, cfg TriggerConfig) error {
	botUUID, err := db.ParseUUID(cfg.BotID)
	if err != nil {
		return fmt.Errorf("parsing bot id %q: %w", cfg.BotID, err)
	}
	sessionUUID, err := db.ParseUUID(cfg.SessionID)
	if err != nil {
		return fmt.Errorf("parsing session id %q: %w", cfg.SessionID, err)
	}
	logRow, err := s.queries.CreateCompactionLog(ctx, sqlc.CreateCompactionLogParams{
		BotID:     botUUID,
		SessionID: sessionUUID,
	})
	if err != nil {
		return fmt.Errorf("creating compaction log: %w", err)
	}
	if err := s.doCompaction(ctx, logRow.ID, sessionUUID, cfg); err != nil {
		// Record the failure on the log row; completeLog swallows its own
		// errors, so the original compaction error is what propagates.
		s.completeLog(ctx, logRow.ID, "error", "", err.Error(), 0, nil, pgtype.UUID{})
		return err
	}
	return nil
}
// doCompaction performs a single compaction pass for a session:
//  1. load uncompacted messages and pick the slice to summarize,
//  2. trim that slice to the compaction model's token budget,
//  3. ask the model for a summary (feeding in prior summaries),
//  4. mark the summarized messages compacted and complete the log row.
//
// logID is the compaction-log row created by runCompaction; it is completed
// with status "ok" here on success (error completion happens in the caller).
func (s *Service) doCompaction(ctx context.Context, logID pgtype.UUID, sessionUUID pgtype.UUID, cfg TriggerConfig) error {
	messages, err := s.queries.ListUncompactedMessagesBySession(ctx, sessionUUID)
	if err != nil {
		return fmt.Errorf("listing uncompacted messages: %w", err)
	}
	if len(messages) == 0 {
		// Nothing to do; close the log row as a successful no-op.
		s.completeLog(ctx, logID, "ok", "", "", 0, nil, pgtype.UUID{})
		return nil
	}
	var toCompact []sqlc.ListUncompactedMessagesBySessionRow
	if cfg.TargetTokens > 0 {
		// Sync compaction: compress enough messages to bring context
		// down to TargetTokens. Calculate how many tokens to keep
		// (newest messages) and compact everything older.
		toCompact = splitByTarget(messages, cfg.TargetTokens)
	} else {
		toCompact = splitByRatio(messages, cfg.TotalInputTokens, cfg.Ratio)
	}
	if len(toCompact) == 0 {
		s.completeLog(ctx, logID, "ok", "", "", 0, nil, pgtype.UUID{})
		return nil
	}
	// Cap the compaction input to avoid exceeding the compaction model's
	// context window. MaxCompactTokens is typically set to 90% of the model's
	// window. If not set, use a conservative default of 30K tokens.
	maxCompactTokens := cfg.MaxCompactTokens
	if maxCompactTokens <= 0 {
		maxCompactTokens = 30000
	}
	s.logger.Info("compaction: before trim",
		slog.Int("messages", len(toCompact)),
		slog.Int("total_uncompacted", len(messages)),
		slog.Int("max_compact_tokens", maxCompactTokens),
	)
	toCompact = trimCompactMessages(toCompact, maxCompactTokens)
	s.logger.Info("compaction: after trim",
		slog.Int("messages", len(toCompact)),
	)
	// Prior summaries are fed back into the prompt so the new summary
	// carries forward context from earlier compaction rounds.
	priorLogs, err := s.queries.ListCompactionLogsBySession(ctx, sessionUUID)
	if err != nil {
		return fmt.Errorf("listing prior compaction logs: %w", err)
	}
	var priorSummaries []string
	for _, l := range priorLogs {
		if l.Summary != "" {
			priorSummaries = append(priorSummaries, l.Summary)
		}
	}
	entries := make([]messageEntry, 0, len(toCompact))
	messageIDs := make([]pgtype.UUID, 0, len(toCompact))
	for _, m := range toCompact {
		entries = append(entries, messageEntry{
			Role:    m.Role,
			Content: string(m.Content),
		})
		messageIDs = append(messageIDs, m.ID)
	}
	userPrompt := buildUserPrompt(priorSummaries, entries)
	model := models.NewSDKChatModel(models.SDKModelConfig{
		ClientType:     cfg.ClientType,
		BaseURL:        cfg.BaseURL,
		APIKey:         cfg.APIKey,
		CodexAccountID: cfg.CodexAccountID,
		ModelID:        cfg.ModelID,
		HTTPClient:     cfg.HTTPClient,
	})
	result, err := sdk.GenerateTextResult(ctx,
		sdk.WithModel(model),
		sdk.WithSystem(systemPrompt),
		sdk.WithMessages([]sdk.Message{sdk.UserMessage(userPrompt)}),
	)
	if err != nil {
		return fmt.Errorf("generating compaction summary: %w", err)
	}
	// A marshal failure here only loses usage telemetry, so it is
	// deliberately ignored rather than failing the whole compaction.
	usageJSON, _ := json.Marshal(result.Usage)
	modelUUID := db.ParseUUIDOrEmpty(cfg.ModelID)
	if err := s.queries.MarkMessagesCompacted(ctx, sqlc.MarkMessagesCompactedParams{
		CompactID: logID,
		Column2:   messageIDs,
	}); err != nil {
		return fmt.Errorf("marking messages compacted: %w", err)
	}
	s.completeLog(ctx, logID, "ok", result.Text, "", len(messageIDs), usageJSON, modelUUID)
	return nil
}
// completeLog finalizes a compaction-log row with the given outcome.
// A failure to write the audit record is logged and swallowed: it must
// not fail the compaction that produced it.
func (s *Service) completeLog(ctx context.Context, logID pgtype.UUID, status, summary, errMsg string, messageCount int, usage []byte, modelID pgtype.UUID) {
	params := sqlc.CompleteCompactionLogParams{
		ID:           logID,
		Status:       status,
		Summary:      summary,
		MessageCount: int32(messageCount), //nolint:gosec // count always small
		ErrorMessage: errMsg,
		Usage:        usage,
		ModelID:      modelID,
	}
	_, err := s.queries.CompleteCompactionLog(ctx, params)
	if err != nil {
		s.logger.Error("failed to complete compaction log", slog.String("error", err.Error()))
	}
}
// ListLogs returns paginated compaction logs for a bot. Out-of-range
// limits fall back to 50 (valid range is 1..100) and negative offsets are
// treated as zero. The int64 result is the bot's total log count.
func (s *Service) ListLogs(ctx context.Context, botID string, limit, offset int) ([]Log, int64, error) {
	botUUID, err := db.ParseUUID(botID)
	if err != nil {
		return nil, 0, err
	}
	// Clamp pagination inputs before they reach the query layer.
	if limit < 1 || limit > 100 {
		limit = 50
	}
	if offset < 0 {
		offset = 0
	}
	total, err := s.queries.CountCompactionLogsByBot(ctx, botUUID)
	if err != nil {
		return nil, 0, err
	}
	rows, err := s.queries.ListCompactionLogsByBot(ctx, sqlc.ListCompactionLogsByBotParams{
		BotID:  botUUID,
		Limit:  int32(limit),  //nolint:gosec // clamped above
		Offset: int32(offset), //nolint:gosec // validated above
	})
	if err != nil {
		return nil, 0, err
	}
	logs := make([]Log, 0, len(rows))
	for _, row := range rows {
		logs = append(logs, toLog(row))
	}
	return logs, total, nil
}
// DeleteLogs deletes all compaction logs for a bot.
func (s *Service) DeleteLogs(ctx context.Context, botID string) error {
	id, err := db.ParseUUID(botID)
	if err != nil {
		return err
	}
	err = s.queries.DeleteCompactionLogsByBot(ctx, id)
	return err
}
// toLog converts a database row into the API-facing Log type. Invalid
// (NULL) UUIDs become empty strings, a missing completion time stays nil,
// and unparsable usage JSON is silently omitted.
func toLog(r sqlc.BotHistoryMessageCompact) Log {
	out := Log{
		ID:           formatUUID(r.ID),
		BotID:        formatUUID(r.BotID),
		SessionID:    formatUUID(r.SessionID),
		Status:       r.Status,
		Summary:      r.Summary,
		MessageCount: int(r.MessageCount),
		ErrorMessage: r.ErrorMessage,
		ModelID:      formatUUID(r.ModelID),
		StartedAt:    r.StartedAt.Time,
	}
	if r.CompletedAt.Valid {
		completed := r.CompletedAt.Time
		out.CompletedAt = &completed
	}
	if len(r.Usage) > 0 {
		var decoded any
		if err := json.Unmarshal(r.Usage, &decoded); err == nil {
			out.Usage = decoded
		}
	}
	return out
}
// formatUUID renders a pgtype.UUID in canonical string form, or "" when
// the value is invalid (NULL in the database).
func formatUUID(id pgtype.UUID) string {
	if id.Valid {
		return uuid.UUID(id.Bytes).String()
	}
	return ""
}
// splitByRatio splits messages so that roughly the first ratio% (by token weight)
// are returned for compaction, and the rest are kept as-is.
// When ratio >= 100, all messages are returned for compaction.
// When ratio <= 0 or totalInputTokens <= 0 or messages is empty, nil is returned (no compaction).
//
// NOTE(review): assumes messages are ordered oldest-first, so the returned
// prefix is the older portion — consistent with how doCompaction uses the
// result, but confirm against the SQL query's ORDER BY.
func splitByRatio(messages []sqlc.ListUncompactedMessagesBySessionRow, totalInputTokens, ratio int) []sqlc.ListUncompactedMessagesBySessionRow {
	if ratio <= 0 || totalInputTokens <= 0 || len(messages) == 0 {
		return nil
	}
	if ratio >= 100 {
		return messages
	}
	// Token budget for the messages that stay uncompacted (the newest ones).
	keepTokens := totalInputTokens * (100 - ratio) / 100
	if keepTokens <= 0 {
		return messages
	}
	// Walk newest-to-oldest accumulating estimated tokens; once the keep
	// budget is filled, messages[:i+1] (the message that filled it plus
	// everything older) go to compaction.
	accumulated := 0
	cutoff := len(messages)
	for i := len(messages) - 1; i >= 0; i-- {
		accumulated += estimateRowTokens(messages[i])
		if accumulated >= keepTokens {
			cutoff = i + 1
			break
		}
	}
	if cutoff <= 0 {
		// Everything fits in the keep budget: nothing to compact.
		return nil
	}
	if cutoff >= len(messages) {
		// The keep budget was never reached: compact everything.
		return messages
	}
	return messages[:cutoff]
}
// splitByTarget returns the oldest messages to compact so that the remaining
// newest messages fit within targetTokens. This is used for synchronous
// compaction where the goal is to reduce context to a specific size.
// Returns nil (compact nothing) when targetTokens <= 0, messages is empty,
// or every message already fits within the target.
func splitByTarget(messages []sqlc.ListUncompactedMessagesBySessionRow, targetTokens int) []sqlc.ListUncompactedMessagesBySessionRow {
	if targetTokens <= 0 || len(messages) == 0 {
		return nil
	}
	// Scan from newest to oldest, keeping messages that fit within target.
	accumulated := 0
	cutoff := 0
	for i := len(messages) - 1; i >= 0; i-- {
		accumulated += estimateRowTokens(messages[i])
		if accumulated > targetTokens {
			// messages[i] overflows the target; it and everything older
			// (messages[:i+1]) must be compacted.
			cutoff = i + 1
			break
		}
	}
	if cutoff <= 0 {
		// Loop finished without overflowing: all messages fit already.
		return nil
	}
	return messages[:cutoff]
}
// usagePayload is the subset of a message's stored usage JSON needed for
// token estimation. The pointer distinguishes an absent field from an
// explicit zero.
type usagePayload struct {
	OutputTokens *int `json:"output_tokens"`
}
// estimateRowTokens estimates the token weight of a stored message. It
// prefers a positive output_tokens value recorded in the message's usage
// JSON, falling back to a rough len/4 bytes-per-token heuristic.
func estimateRowTokens(m sqlc.ListUncompactedMessagesBySessionRow) int {
	fallback := len(m.Content) / 4
	if len(m.Usage) == 0 {
		return fallback
	}
	var u usagePayload
	if err := json.Unmarshal(m.Usage, &u); err != nil {
		return fallback
	}
	if u.OutputTokens == nil || *u.OutputTokens <= 0 {
		return fallback
	}
	return *u.OutputTokens
}
// trimCompactMessages drops the oldest messages from the front of the
// slice so the total estimated tokens stay within maxTokens. (The slice is
// assumed oldest-first, matching splitByRatio/splitByTarget; the original
// "from the tail" wording was misleading — messages[:cutoff], the head,
// is what gets dropped.)
func trimCompactMessages(messages []sqlc.ListUncompactedMessagesBySessionRow, maxTokens int) []sqlc.ListUncompactedMessagesBySessionRow {
	if len(messages) == 0 || maxTokens <= 0 {
		return messages
	}
	total := 0
	for _, m := range messages {
		total += estimateRowTokens(m)
	}
	if total <= maxTokens {
		// Already within budget: nothing to trim.
		return messages
	}
	// Walk newest-to-oldest; once the budget overflows at messages[i],
	// keep messages[i+1:] (the newest that fit) and drop the older head.
	accumulated := 0
	cutoff := len(messages)
	for i := len(messages) - 1; i >= 0; i-- {
		accumulated += estimateRowTokens(messages[i])
		if accumulated > maxTokens {
			cutoff = i + 1
			break
		}
	}
	if cutoff >= len(messages) {
		// Even the single newest message exceeds the budget; return the
		// input unchanged rather than an empty slice.
		return messages
	}
	return messages[cutoff:]
}