mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
65b2797626
Move CreateModel, BuildReasoningOptions, ReasoningBudgetTokens and related types from internal/agent to internal/models as NewSDKChatModel, SDKModelConfig, etc. This eliminates duplicate ClientType constants and centralises all Twilight AI SDK instance creation in a single package. NewSDKEmbeddingModel now accepts a clientType parameter and dispatches to the native Google embedding provider for google-generative-ai, instead of always using the OpenAI-compatible endpoint.
257 lines
6.3 KiB
Go
257 lines
6.3 KiB
Go
package compaction
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"log/slog"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/jackc/pgx/v5/pgtype"
|
|
sdk "github.com/memohai/twilight-ai/sdk"
|
|
|
|
"github.com/memohai/memoh/internal/db"
|
|
"github.com/memohai/memoh/internal/db/sqlc"
|
|
"github.com/memohai/memoh/internal/models"
|
|
)
|
|
|
|
// Service manages context compaction for bot conversations.
|
|
type Service struct {
|
|
queries *sqlc.Queries
|
|
logger *slog.Logger
|
|
}
|
|
|
|
// NewService creates a new compaction Service.
|
|
func NewService(log *slog.Logger, queries *sqlc.Queries) *Service {
|
|
return &Service{
|
|
queries: queries,
|
|
logger: log,
|
|
}
|
|
}
|
|
|
|
// ShouldCompact returns true if inputTokens exceeds the threshold.
|
|
func ShouldCompact(inputTokens, threshold int) bool {
|
|
return threshold > 0 && inputTokens >= threshold
|
|
}
|
|
|
|
// TriggerCompaction runs compaction in the background.
|
|
func (s *Service) TriggerCompaction(ctx context.Context, cfg TriggerConfig) {
|
|
go func() {
|
|
bgCtx := context.WithoutCancel(ctx)
|
|
if err := s.runCompaction(bgCtx, cfg); err != nil {
|
|
s.logger.Error("compaction failed", slog.String("bot_id", cfg.BotID), slog.String("session_id", cfg.SessionID), slog.String("error", err.Error()))
|
|
}
|
|
}()
|
|
}
|
|
|
|
func (s *Service) runCompaction(ctx context.Context, cfg TriggerConfig) error {
|
|
botUUID, err := db.ParseUUID(cfg.BotID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
sessionUUID, err := db.ParseUUID(cfg.SessionID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
logRow, err := s.queries.CreateCompactionLog(ctx, sqlc.CreateCompactionLogParams{
|
|
BotID: botUUID,
|
|
SessionID: sessionUUID,
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
compactErr := s.doCompaction(ctx, logRow.ID, sessionUUID, cfg)
|
|
if compactErr != nil {
|
|
s.completeLog(ctx, logRow.ID, "error", "", compactErr.Error(), nil, pgtype.UUID{})
|
|
}
|
|
return compactErr
|
|
}
|
|
|
|
func (s *Service) doCompaction(ctx context.Context, logID pgtype.UUID, sessionUUID pgtype.UUID, cfg TriggerConfig) error {
|
|
messages, err := s.queries.ListUncompactedMessagesBySession(ctx, sessionUUID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if len(messages) == 0 {
|
|
s.completeLog(ctx, logID, "ok", "", "", nil, pgtype.UUID{})
|
|
return nil
|
|
}
|
|
|
|
priorLogs, err := s.queries.ListCompactionLogsBySession(ctx, sessionUUID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
var priorSummaries []string
|
|
for _, l := range priorLogs {
|
|
if l.Summary != "" {
|
|
priorSummaries = append(priorSummaries, l.Summary)
|
|
}
|
|
}
|
|
|
|
entries := make([]messageEntry, 0, len(messages))
|
|
messageIDs := make([]pgtype.UUID, 0, len(messages))
|
|
for _, m := range messages {
|
|
entries = append(entries, messageEntry{
|
|
Role: m.Role,
|
|
Content: extractTextContent(m.Content),
|
|
})
|
|
messageIDs = append(messageIDs, m.ID)
|
|
}
|
|
|
|
userPrompt := buildUserPrompt(priorSummaries, entries)
|
|
|
|
model := models.NewSDKChatModel(models.SDKModelConfig{
|
|
ClientType: cfg.ClientType,
|
|
BaseURL: cfg.BaseURL,
|
|
APIKey: cfg.APIKey,
|
|
ModelID: cfg.ModelID,
|
|
})
|
|
|
|
result, err := sdk.GenerateTextResult(ctx,
|
|
sdk.WithModel(model),
|
|
sdk.WithSystem(systemPrompt),
|
|
sdk.WithMessages([]sdk.Message{sdk.UserMessage(userPrompt)}),
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
usageJSON, _ := json.Marshal(result.Usage)
|
|
|
|
modelUUID := db.ParseUUIDOrEmpty(cfg.ModelID)
|
|
|
|
if err := s.queries.MarkMessagesCompacted(ctx, sqlc.MarkMessagesCompactedParams{
|
|
CompactID: logID,
|
|
Column2: messageIDs,
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
|
|
s.completeLog(ctx, logID, "ok", result.Text, "", usageJSON, modelUUID)
|
|
return nil
|
|
}
|
|
|
|
func (s *Service) completeLog(ctx context.Context, logID pgtype.UUID, status, summary, errMsg string, usage []byte, modelID pgtype.UUID) {
|
|
if _, err := s.queries.CompleteCompactionLog(ctx, sqlc.CompleteCompactionLogParams{
|
|
ID: logID,
|
|
Status: status,
|
|
Summary: summary,
|
|
MessageCount: 0,
|
|
ErrorMessage: errMsg,
|
|
Usage: usage,
|
|
ModelID: modelID,
|
|
}); err != nil {
|
|
s.logger.Error("failed to complete compaction log", slog.String("error", err.Error()))
|
|
}
|
|
}
|
|
|
|
// ListLogs returns paginated compaction logs for a bot.
|
|
func (s *Service) ListLogs(ctx context.Context, botID string, before *time.Time, limit int) ([]Log, error) {
|
|
botUUID, err := db.ParseUUID(botID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var beforeTS pgtype.Timestamptz
|
|
if before != nil {
|
|
beforeTS = pgtype.Timestamptz{Time: *before, Valid: true}
|
|
}
|
|
|
|
clampedLimit := limit
|
|
if clampedLimit > 1000 {
|
|
clampedLimit = 1000
|
|
}
|
|
rows, err := s.queries.ListCompactionLogsByBot(ctx, sqlc.ListCompactionLogsByBotParams{
|
|
BotID: botUUID,
|
|
Column2: beforeTS,
|
|
Limit: int32(clampedLimit), //nolint:gosec // clamped above
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
logs := make([]Log, len(rows))
|
|
for i, r := range rows {
|
|
logs[i] = toLog(r)
|
|
}
|
|
return logs, nil
|
|
}
|
|
|
|
// DeleteLogs deletes all compaction logs for a bot.
|
|
func (s *Service) DeleteLogs(ctx context.Context, botID string) error {
|
|
botUUID, err := db.ParseUUID(botID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return s.queries.DeleteCompactionLogsByBot(ctx, botUUID)
|
|
}
|
|
|
|
func toLog(r sqlc.BotHistoryMessageCompact) Log {
|
|
l := Log{
|
|
ID: formatUUID(r.ID),
|
|
BotID: formatUUID(r.BotID),
|
|
SessionID: formatUUID(r.SessionID),
|
|
Status: r.Status,
|
|
Summary: r.Summary,
|
|
MessageCount: int(r.MessageCount),
|
|
ErrorMessage: r.ErrorMessage,
|
|
ModelID: formatUUID(r.ModelID),
|
|
StartedAt: r.StartedAt.Time,
|
|
}
|
|
if r.CompletedAt.Valid {
|
|
t := r.CompletedAt.Time
|
|
l.CompletedAt = &t
|
|
}
|
|
if len(r.Usage) > 0 {
|
|
var u any
|
|
if json.Unmarshal(r.Usage, &u) == nil {
|
|
l.Usage = u
|
|
}
|
|
}
|
|
return l
|
|
}
|
|
|
|
func formatUUID(id pgtype.UUID) string {
|
|
if !id.Valid {
|
|
return ""
|
|
}
|
|
return uuid.UUID(id.Bytes).String()
|
|
}
|
|
|
|
// extractTextContent extracts plain text from a message content JSONB field.
|
|
// The content may be a JSON string, an array of content parts, or raw bytes.
|
|
func extractTextContent(content []byte) string {
|
|
if len(content) == 0 {
|
|
return ""
|
|
}
|
|
|
|
var s string
|
|
if json.Unmarshal(content, &s) == nil {
|
|
return s
|
|
}
|
|
|
|
var parts []map[string]any
|
|
if json.Unmarshal(content, &parts) == nil {
|
|
var texts []string
|
|
for _, p := range parts {
|
|
if t, ok := p["type"].(string); ok && t == "text" {
|
|
if text, ok := p["text"].(string); ok {
|
|
texts = append(texts, text)
|
|
}
|
|
}
|
|
}
|
|
if len(texts) > 0 {
|
|
return joinTexts(texts)
|
|
}
|
|
}
|
|
|
|
return string(content)
|
|
}
|
|
|
|
func joinTexts(parts []string) string {
|
|
return strings.Join(parts, " ")
|
|
}
|