mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
295 lines
7.9 KiB
Go
295 lines
7.9 KiB
Go
package heartbeat
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/jackc/pgx/v5/pgtype"
|
|
"github.com/robfig/cron/v3"
|
|
|
|
"github.com/memohai/memoh/internal/auth"
|
|
"github.com/memohai/memoh/internal/boot"
|
|
"github.com/memohai/memoh/internal/db"
|
|
"github.com/memohai/memoh/internal/db/sqlc"
|
|
)
|
|
|
|
const heartbeatTokenTTL = 10 * time.Minute
|
|
|
|
// heartbeatRunTimeout caps how long a single heartbeat execution may take.
|
|
// This prevents unbounded Generate() calls from hanging forever.
|
|
const heartbeatRunTimeout = 5 * time.Minute
|
|
|
|
// SessionCreator creates sessions for heartbeat runs.
|
|
type SessionCreator interface {
|
|
CreateSession(ctx context.Context, botID, sessionType string) (string, error)
|
|
}
|
|
|
|
type Service struct {
|
|
queries *sqlc.Queries
|
|
cron *cron.Cron
|
|
triggerer Triggerer
|
|
sessionCreator SessionCreator
|
|
jwtSecret string
|
|
logger *slog.Logger
|
|
mu sync.Mutex
|
|
jobs map[string]cron.EntryID
|
|
}
|
|
|
|
func NewService(log *slog.Logger, queries *sqlc.Queries, triggerer Triggerer, sessionCreator SessionCreator, runtimeConfig *boot.RuntimeConfig) *Service {
|
|
c := cron.New()
|
|
service := &Service{
|
|
queries: queries,
|
|
cron: c,
|
|
triggerer: triggerer,
|
|
sessionCreator: sessionCreator,
|
|
jwtSecret: runtimeConfig.JwtSecret,
|
|
logger: log.With(slog.String("service", "heartbeat")),
|
|
jobs: map[string]cron.EntryID{},
|
|
}
|
|
c.Start()
|
|
return service
|
|
}
|
|
|
|
func (s *Service) Bootstrap(ctx context.Context) error {
|
|
if s.queries == nil {
|
|
return errors.New("heartbeat queries not configured")
|
|
}
|
|
rows, err := s.queries.ListHeartbeatEnabledBots(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, row := range rows {
|
|
botID := row.ID.String()
|
|
ownerUserID := row.OwnerUserID.String()
|
|
cfg := Config{
|
|
BotID: botID,
|
|
OwnerUserID: ownerUserID,
|
|
Interval: int(row.HeartbeatInterval),
|
|
}
|
|
if err := s.scheduleJob(ctx, cfg); err != nil {
|
|
s.logger.Error("failed to schedule heartbeat", slog.String("bot_id", botID), slog.Any("error", err))
|
|
}
|
|
}
|
|
s.logger.Info("heartbeat bootstrap complete", slog.Int("count", len(rows)))
|
|
return nil
|
|
}
|
|
|
|
func (s *Service) Reschedule(ctx context.Context, botID string) error {
|
|
s.removeJob(botID)
|
|
|
|
pgID, err := db.ParseUUID(botID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
bot, err := s.queries.GetBotByID(ctx, pgID)
|
|
if err != nil {
|
|
return fmt.Errorf("get bot: %w", err)
|
|
}
|
|
if !bot.HeartbeatEnabled || bot.Status != "ready" {
|
|
return nil
|
|
}
|
|
cfg := Config{
|
|
BotID: botID,
|
|
OwnerUserID: bot.OwnerUserID.String(),
|
|
Interval: int(bot.HeartbeatInterval),
|
|
}
|
|
return s.scheduleJob(ctx, cfg)
|
|
}
|
|
|
|
func (s *Service) Stop(botID string) {
|
|
s.removeJob(botID)
|
|
}
|
|
|
|
func (s *Service) runHeartbeat(ctx context.Context, cfg Config) {
|
|
if s.triggerer == nil {
|
|
s.logger.Error("heartbeat triggerer not configured")
|
|
return
|
|
}
|
|
|
|
pgBotID, err := db.ParseUUID(cfg.BotID)
|
|
if err != nil {
|
|
s.logger.Error("invalid bot id", slog.String("bot_id", cfg.BotID), slog.Any("error", err))
|
|
return
|
|
}
|
|
|
|
var sessionID string
|
|
var pgSessionID pgtype.UUID
|
|
if s.sessionCreator != nil {
|
|
sid, err := s.sessionCreator.CreateSession(ctx, cfg.BotID, "heartbeat")
|
|
if err != nil {
|
|
s.logger.Error("create heartbeat session failed", slog.String("bot_id", cfg.BotID), slog.Any("error", err))
|
|
} else {
|
|
sessionID = sid
|
|
pgSessionID = db.ParseUUIDOrEmpty(sid)
|
|
}
|
|
}
|
|
|
|
var lastHeartbeatAt string
|
|
if prevLogs, listErr := s.queries.ListHeartbeatLogsByBot(ctx, sqlc.ListHeartbeatLogsByBotParams{
|
|
BotID: pgBotID,
|
|
Limit: 1,
|
|
}); listErr == nil && len(prevLogs) > 0 {
|
|
lastHeartbeatAt = prevLogs[0].StartedAt.Time.UTC().Format("2006-01-02T15:04:05Z")
|
|
}
|
|
|
|
logRow, err := s.queries.CreateHeartbeatLog(ctx, sqlc.CreateHeartbeatLogParams{
|
|
BotID: pgBotID,
|
|
SessionID: pgSessionID,
|
|
})
|
|
if err != nil {
|
|
s.logger.Error("create heartbeat log failed", slog.String("bot_id", cfg.BotID), slog.Any("error", err))
|
|
return
|
|
}
|
|
|
|
token, err := s.generateTriggerToken(cfg.OwnerUserID)
|
|
if err != nil {
|
|
s.completeLog(ctx, logRow.ID, "error", "", err.Error(), nil, pgtype.UUID{})
|
|
s.logger.Error("generate trigger token failed", slog.String("bot_id", cfg.BotID), slog.Any("error", err))
|
|
return
|
|
}
|
|
|
|
result, err := s.triggerer.TriggerHeartbeat(ctx, cfg.BotID, TriggerPayload{
|
|
BotID: cfg.BotID,
|
|
Interval: cfg.Interval,
|
|
OwnerUserID: cfg.OwnerUserID,
|
|
SessionID: sessionID,
|
|
LastHeartbeatAt: lastHeartbeatAt,
|
|
}, token)
|
|
if err != nil {
|
|
s.completeLog(ctx, logRow.ID, "error", "", err.Error(), nil, pgtype.UUID{})
|
|
s.logger.Error("heartbeat trigger failed", slog.String("bot_id", cfg.BotID), slog.Any("error", err))
|
|
return
|
|
}
|
|
|
|
modelID := db.ParseUUIDOrEmpty(result.ModelID)
|
|
s.completeLog(ctx, logRow.ID, result.Status, result.Text, "", result.UsageBytes, modelID)
|
|
s.logger.Info("heartbeat completed", slog.String("bot_id", cfg.BotID), slog.String("status", result.Status))
|
|
}
|
|
|
|
func (s *Service) completeLog(ctx context.Context, logID pgtype.UUID, status, resultText, errorMessage string, usageBytes []byte, modelID pgtype.UUID) {
|
|
_, err := s.queries.CompleteHeartbeatLog(ctx, sqlc.CompleteHeartbeatLogParams{
|
|
ID: logID,
|
|
Status: status,
|
|
ResultText: resultText,
|
|
ErrorMessage: errorMessage,
|
|
Usage: usageBytes,
|
|
ModelID: modelID,
|
|
})
|
|
if err != nil {
|
|
s.logger.Error("complete heartbeat log failed", slog.Any("error", err))
|
|
}
|
|
}
|
|
|
|
func (s *Service) ListLogs(ctx context.Context, botID string, limit, offset int) ([]Log, int64, error) {
|
|
pgBotID, err := db.ParseUUID(botID)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
if limit <= 0 || limit > 100 {
|
|
limit = 50
|
|
}
|
|
if offset < 0 {
|
|
offset = 0
|
|
}
|
|
|
|
total, err := s.queries.CountHeartbeatLogsByBot(ctx, pgBotID)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
rows, err := s.queries.ListHeartbeatLogsByBot(ctx, sqlc.ListHeartbeatLogsByBotParams{
|
|
BotID: pgBotID,
|
|
Limit: int32(limit), //nolint:gosec // capped to 100 above
|
|
Offset: int32(offset), //nolint:gosec // validated above
|
|
})
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
items := make([]Log, 0, len(rows))
|
|
for _, row := range rows {
|
|
items = append(items, toLog(row))
|
|
}
|
|
return items, total, nil
|
|
}
|
|
|
|
func (s *Service) DeleteLogs(ctx context.Context, botID string) error {
|
|
pgBotID, err := db.ParseUUID(botID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return s.queries.DeleteHeartbeatLogsByBot(ctx, pgBotID)
|
|
}
|
|
|
|
func (s *Service) generateTriggerToken(userID string) (string, error) {
|
|
if strings.TrimSpace(s.jwtSecret) == "" {
|
|
return "", errors.New("jwt secret not configured")
|
|
}
|
|
signed, _, err := auth.GenerateToken(userID, s.jwtSecret, heartbeatTokenTTL)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return "Bearer " + signed, nil
|
|
}
|
|
|
|
func (s *Service) scheduleJob(ctx context.Context, cfg Config) error {
|
|
if cfg.Interval <= 0 {
|
|
cfg.Interval = 30
|
|
}
|
|
spec := fmt.Sprintf("@every %dm", cfg.Interval)
|
|
job := func() {
|
|
runCtx, runCancel := context.WithTimeout(context.WithoutCancel(ctx), heartbeatRunTimeout)
|
|
defer runCancel()
|
|
s.runHeartbeat(runCtx, cfg)
|
|
}
|
|
entryID, err := s.cron.AddFunc(spec, job)
|
|
if err != nil {
|
|
return fmt.Errorf("add heartbeat cron job: %w", err)
|
|
}
|
|
s.mu.Lock()
|
|
s.jobs[cfg.BotID] = entryID
|
|
s.mu.Unlock()
|
|
s.logger.Info("heartbeat scheduled", slog.String("bot_id", cfg.BotID), slog.Int("interval_minutes", cfg.Interval))
|
|
return nil
|
|
}
|
|
|
|
func (s *Service) removeJob(botID string) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
entryID, ok := s.jobs[botID]
|
|
if ok {
|
|
s.cron.Remove(entryID)
|
|
delete(s.jobs, botID)
|
|
}
|
|
}
|
|
|
|
func toLog(row sqlc.ListHeartbeatLogsByBotRow) Log {
|
|
l := Log{
|
|
ID: row.ID.String(),
|
|
BotID: row.BotID.String(),
|
|
SessionID: row.SessionID.String(),
|
|
Status: row.Status,
|
|
ResultText: row.ResultText,
|
|
ErrorMessage: row.ErrorMessage,
|
|
}
|
|
if row.StartedAt.Valid {
|
|
l.StartedAt = row.StartedAt.Time
|
|
}
|
|
if row.CompletedAt.Valid {
|
|
t := row.CompletedAt.Time
|
|
l.CompletedAt = &t
|
|
}
|
|
if row.Usage != nil {
|
|
var usage any
|
|
if err := json.Unmarshal(row.Usage, &usage); err == nil {
|
|
l.Usage = usage
|
|
}
|
|
}
|
|
return l
|
|
}
|