Files
Memoh/internal/settings/service.go
T
Acbox 473d559042 feat(channel): structured tool-call IM display with edit-in-place
Introduce a new `show_tool_calls_in_im` bot setting plus a full overhaul of
how tool calls are surfaced in IM channels:

- Add per-bot setting + migration (0072) and expose through settings API /
  handlers / frontend SDK.
- Introduce a `toolCallDroppingStream` wrapper that filters tool_call_* events
  when the setting is off, keeping the rest of the stream intact.
- Add a shared `ToolCallPresentation` model (Header / Body blocks / Footer)
  with plain and Markdown renderers, and a per-tool formatter registry that
  produces rich output (e.g. `web_search` link lists, `list` directory
  previews, `exec` stdout/stderr tails) instead of raw JSON dumps.
- High-capability adapters (Telegram, Feishu, Matrix, Slack, Discord) now
  flush pre-text and then send ONE tool-call message per call, editing it
  in-place from `running` to `completed` / `failed`; mapping from callID to
  platform message ID is tracked per stream, with a fallback to a new
  message if the edit fails. Low-capability adapters (WeCom, QQ, DingTalk)
  keep posting a single final message, but now benefit from the same rich
  per-tool formatting.
- Suppress the early duplicate `EventToolCallStart` (from
  `sdk.ToolInputStartPart`) so that the SDK's final `StreamToolCallPart`
  remains the single source of truth for tool call start, preventing
  duplicated "running" bubbles in IM.
- Stop auto-populating `InputSummary` / `ResultSummary` after a per-tool
  formatter runs, which previously leaked the raw JSON result as a
  fallback footer underneath the formatted body.

Add regression tests for the formatters, the Markdown renderer, the
edit-in-place flow on Telegram/Matrix, and the JSON-leak guard on `list`.
2026-04-23 20:49:44 +08:00

465 lines
14 KiB
Go

package settings
import (
"context"
"errors"
"fmt"
"log/slog"
"strings"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/memohai/memoh/internal/acl"
"github.com/memohai/memoh/internal/db"
"github.com/memohai/memoh/internal/db/sqlc"
tzutil "github.com/memohai/memoh/internal/timezone"
)
type Service struct {
queries *sqlc.Queries
acl *acl.Service
logger *slog.Logger
}
var (
ErrModelIDAmbiguous = errors.New("model_id is ambiguous across providers")
ErrInvalidModelRef = errors.New("invalid model reference")
)
func NewService(log *slog.Logger, queries *sqlc.Queries, aclService *acl.Service) *Service {
return &Service{
queries: queries,
acl: aclService,
logger: log.With(slog.String("service", "settings")),
}
}
func (s *Service) GetBot(ctx context.Context, botID string) (Settings, error) {
pgID, err := db.ParseUUID(botID)
if err != nil {
return Settings{}, err
}
row, err := s.queries.GetSettingsByBotID(ctx, pgID)
if err != nil {
return Settings{}, err
}
settings := normalizeBotSettingsReadRow(row)
aclDefaultEffect, err := s.getDefaultEffect(ctx, botID)
if err != nil {
return Settings{}, err
}
settings.AclDefaultEffect = aclDefaultEffect
return settings, nil
}
func (s *Service) UpsertBot(ctx context.Context, botID string, req UpsertRequest) (Settings, error) {
if s.queries == nil {
return Settings{}, errors.New("settings queries not configured")
}
pgID, err := db.ParseUUID(botID)
if err != nil {
return Settings{}, err
}
botRow, err := s.queries.GetBotByID(ctx, pgID)
if err != nil {
return Settings{}, err
}
aclDefaultEffect, err := s.getDefaultEffect(ctx, botID)
if err != nil {
return Settings{}, err
}
current := normalizeBotSetting(botRow.Language, aclDefaultEffect, botRow.ReasoningEnabled, botRow.ReasoningEffort, botRow.HeartbeatEnabled, botRow.HeartbeatInterval, botRow.CompactionEnabled, botRow.CompactionThreshold, botRow.CompactionRatio)
if strings.TrimSpace(req.Language) != "" {
current.Language = strings.TrimSpace(req.Language)
}
if effect := strings.TrimSpace(req.AclDefaultEffect); effect != "" {
current.AclDefaultEffect = effect
}
if req.ReasoningEnabled != nil {
current.ReasoningEnabled = *req.ReasoningEnabled
}
if req.ReasoningEffort != nil && isValidReasoningEffort(*req.ReasoningEffort) {
current.ReasoningEffort = *req.ReasoningEffort
}
if req.HeartbeatEnabled != nil {
current.HeartbeatEnabled = *req.HeartbeatEnabled
}
if req.HeartbeatInterval != nil && *req.HeartbeatInterval > 0 {
current.HeartbeatInterval = *req.HeartbeatInterval
}
if req.CompactionEnabled != nil {
current.CompactionEnabled = *req.CompactionEnabled
}
if req.CompactionThreshold != nil && *req.CompactionThreshold >= 0 {
current.CompactionThreshold = *req.CompactionThreshold
}
if req.CompactionRatio != nil && *req.CompactionRatio >= 1 && *req.CompactionRatio <= 100 {
current.CompactionRatio = *req.CompactionRatio
}
if req.PersistFullToolResults != nil {
current.PersistFullToolResults = *req.PersistFullToolResults
}
if req.ShowToolCallsInIM != nil {
current.ShowToolCallsInIM = *req.ShowToolCallsInIM
}
timezoneValue := pgtype.Text{}
if req.Timezone != nil {
normalized, err := normalizeOptionalTimezone(*req.Timezone)
if err != nil {
return Settings{}, err
}
timezoneValue = normalized
}
chatModelUUID := pgtype.UUID{}
if value := strings.TrimSpace(req.ChatModelID); value != "" {
modelID, err := s.resolveModelUUID(ctx, value)
if err != nil {
return Settings{}, err
}
chatModelUUID = modelID
}
heartbeatModelUUID := pgtype.UUID{}
if value := strings.TrimSpace(req.HeartbeatModelID); value != "" {
modelID, err := s.resolveModelUUID(ctx, value)
if err != nil {
return Settings{}, err
}
heartbeatModelUUID = modelID
}
compactionModelUUID := pgtype.UUID{}
if req.CompactionModelID != nil {
if value := strings.TrimSpace(*req.CompactionModelID); value != "" {
modelID, err := s.resolveModelUUID(ctx, value)
if err != nil {
return Settings{}, err
}
compactionModelUUID = modelID
}
}
titleModelUUID := pgtype.UUID{}
if value := strings.TrimSpace(req.TitleModelID); value != "" {
modelID, err := s.resolveModelUUID(ctx, value)
if err != nil {
return Settings{}, err
}
titleModelUUID = modelID
}
imageModelUUID := pgtype.UUID{}
if value := strings.TrimSpace(req.ImageModelID); value != "" {
modelID, err := s.resolveModelUUID(ctx, value)
if err != nil {
return Settings{}, err
}
imageModelUUID = modelID
}
searchProviderUUID := pgtype.UUID{}
if value := strings.TrimSpace(req.SearchProviderID); value != "" {
providerID, err := db.ParseUUID(value)
if err != nil {
return Settings{}, err
}
searchProviderUUID = providerID
}
memoryProviderUUID := pgtype.UUID{}
if value := strings.TrimSpace(req.MemoryProviderID); value != "" {
providerID, err := db.ParseUUID(value)
if err != nil {
return Settings{}, err
}
memoryProviderUUID = providerID
}
ttsModelUUID := pgtype.UUID{}
if value := strings.TrimSpace(req.TtsModelID); value != "" {
modelID, err := db.ParseUUID(value)
if err != nil {
return Settings{}, err
}
ttsModelUUID = modelID
}
transcriptionModelUUID := pgtype.UUID{}
if value := strings.TrimSpace(req.TranscriptionModelID); value != "" {
modelID, err := db.ParseUUID(value)
if err != nil {
return Settings{}, err
}
transcriptionModelUUID = modelID
}
browserContextUUID := pgtype.UUID{}
if value := strings.TrimSpace(req.BrowserContextID); value != "" {
ctxID, err := db.ParseUUID(value)
if err != nil {
return Settings{}, err
}
browserContextUUID = ctxID
}
updated, err := s.queries.UpsertBotSettings(ctx, sqlc.UpsertBotSettingsParams{
ID: pgID,
Timezone: timezoneValue,
Language: current.Language,
ReasoningEnabled: current.ReasoningEnabled,
ReasoningEffort: current.ReasoningEffort,
HeartbeatEnabled: current.HeartbeatEnabled,
HeartbeatInterval: int32(current.HeartbeatInterval), //nolint:gosec // bounded by positive-only setter above
HeartbeatPrompt: "",
CompactionEnabled: current.CompactionEnabled,
CompactionThreshold: int32(current.CompactionThreshold), //nolint:gosec // bounded by non-negative setter above
CompactionRatio: int32(current.CompactionRatio), //nolint:gosec // bounded 1-100 above
ChatModelID: chatModelUUID,
HeartbeatModelID: heartbeatModelUUID,
CompactionModelID: compactionModelUUID,
TitleModelID: titleModelUUID,
ImageModelID: imageModelUUID,
SearchProviderID: searchProviderUUID,
MemoryProviderID: memoryProviderUUID,
TtsModelID: ttsModelUUID,
TranscriptionModelID: transcriptionModelUUID,
BrowserContextID: browserContextUUID,
PersistFullToolResults: current.PersistFullToolResults,
ShowToolCallsInIm: current.ShowToolCallsInIM,
})
if err != nil {
return Settings{}, err
}
createdByUserID := ""
if botRow.OwnerUserID.Valid {
createdByUserID = uuid.UUID(botRow.OwnerUserID.Bytes).String()
}
_ = createdByUserID
if err := s.setDefaultEffect(ctx, botID, current.AclDefaultEffect); err != nil {
return Settings{}, err
}
settings := normalizeBotSettingsWriteRow(updated)
settings.AclDefaultEffect = current.AclDefaultEffect
return settings, nil
}
func (s *Service) Delete(ctx context.Context, botID string) error {
if s.queries == nil {
return errors.New("settings queries not configured")
}
pgID, err := db.ParseUUID(botID)
if err != nil {
return err
}
if err := s.queries.DeleteSettingsByBotID(ctx, pgID); err != nil {
return err
}
return nil
}
func normalizeBotSetting(language string, aclDefaultEffect string, reasoningEnabled bool, reasoningEffort string, heartbeatEnabled bool, heartbeatInterval int32, compactionEnabled bool, compactionThreshold int32, compactionRatio int32) Settings {
settings := Settings{
Language: strings.TrimSpace(language),
AclDefaultEffect: strings.TrimSpace(aclDefaultEffect),
ReasoningEnabled: reasoningEnabled,
ReasoningEffort: strings.TrimSpace(reasoningEffort),
HeartbeatEnabled: heartbeatEnabled,
HeartbeatInterval: int(heartbeatInterval),
CompactionEnabled: compactionEnabled,
CompactionThreshold: int(compactionThreshold),
CompactionRatio: int(compactionRatio),
}
if settings.Language == "" {
settings.Language = DefaultLanguage
}
if settings.AclDefaultEffect == "" {
settings.AclDefaultEffect = "allow"
}
if !isValidReasoningEffort(settings.ReasoningEffort) {
settings.ReasoningEffort = DefaultReasoningEffort
}
if settings.HeartbeatInterval <= 0 {
settings.HeartbeatInterval = DefaultHeartbeatInterval
}
if settings.CompactionThreshold < 0 {
settings.CompactionThreshold = 0
}
if settings.CompactionRatio < 1 || settings.CompactionRatio > 100 {
settings.CompactionRatio = 80
}
return settings
}
func isValidReasoningEffort(effort string) bool {
switch effort {
case "low", "medium", "high":
return true
default:
return false
}
}
func normalizeBotSettingsReadRow(row sqlc.GetSettingsByBotIDRow) Settings {
return normalizeBotSettingsFields(
row.Language,
row.ReasoningEnabled,
row.ReasoningEffort,
row.HeartbeatEnabled,
row.HeartbeatInterval,
row.CompactionEnabled,
row.CompactionThreshold,
row.CompactionRatio,
row.Timezone,
row.ChatModelID,
row.HeartbeatModelID,
row.CompactionModelID,
row.TitleModelID,
row.ImageModelID,
row.SearchProviderID,
row.MemoryProviderID,
row.TtsModelID,
row.TranscriptionModelID,
row.BrowserContextID,
row.PersistFullToolResults,
row.ShowToolCallsInIm,
)
}
func normalizeBotSettingsWriteRow(row sqlc.UpsertBotSettingsRow) Settings {
return normalizeBotSettingsFields(
row.Language,
row.ReasoningEnabled,
row.ReasoningEffort,
row.HeartbeatEnabled,
row.HeartbeatInterval,
row.CompactionEnabled,
row.CompactionThreshold,
row.CompactionRatio,
row.Timezone,
row.ChatModelID,
row.HeartbeatModelID,
row.CompactionModelID,
row.TitleModelID,
row.ImageModelID,
row.SearchProviderID,
row.MemoryProviderID,
row.TtsModelID,
row.TranscriptionModelID,
row.BrowserContextID,
row.PersistFullToolResults,
row.ShowToolCallsInIm,
)
}
func normalizeBotSettingsFields(
language string,
reasoningEnabled bool,
reasoningEffort string,
heartbeatEnabled bool,
heartbeatInterval int32,
compactionEnabled bool,
compactionThreshold int32,
compactionRatio int32,
timezone pgtype.Text,
chatModelID pgtype.UUID,
heartbeatModelID pgtype.UUID,
compactionModelID pgtype.UUID,
titleModelID pgtype.UUID,
imageModelID pgtype.UUID,
searchProviderID pgtype.UUID,
memoryProviderID pgtype.UUID,
ttsModelID pgtype.UUID,
transcriptionModelID pgtype.UUID,
browserContextID pgtype.UUID,
persistFullToolResults bool,
showToolCallsInIM bool,
) Settings {
settings := normalizeBotSetting(language, "", reasoningEnabled, reasoningEffort, heartbeatEnabled, heartbeatInterval, compactionEnabled, compactionThreshold, compactionRatio)
if timezone.Valid {
settings.Timezone = timezone.String
}
if chatModelID.Valid {
settings.ChatModelID = uuid.UUID(chatModelID.Bytes).String()
}
if heartbeatModelID.Valid {
settings.HeartbeatModelID = uuid.UUID(heartbeatModelID.Bytes).String()
}
if compactionModelID.Valid {
settings.CompactionModelID = uuid.UUID(compactionModelID.Bytes).String()
}
if titleModelID.Valid {
settings.TitleModelID = uuid.UUID(titleModelID.Bytes).String()
}
if imageModelID.Valid {
settings.ImageModelID = uuid.UUID(imageModelID.Bytes).String()
}
if searchProviderID.Valid {
settings.SearchProviderID = uuid.UUID(searchProviderID.Bytes).String()
}
if memoryProviderID.Valid {
settings.MemoryProviderID = uuid.UUID(memoryProviderID.Bytes).String()
}
if ttsModelID.Valid {
settings.TtsModelID = uuid.UUID(ttsModelID.Bytes).String()
}
if transcriptionModelID.Valid {
settings.TranscriptionModelID = uuid.UUID(transcriptionModelID.Bytes).String()
}
if browserContextID.Valid {
settings.BrowserContextID = uuid.UUID(browserContextID.Bytes).String()
}
settings.PersistFullToolResults = persistFullToolResults
settings.ShowToolCallsInIM = showToolCallsInIM
return settings
}
func (s *Service) getDefaultEffect(ctx context.Context, botID string) (string, error) {
if s.acl == nil {
return "allow", nil
}
return s.acl.GetDefaultEffect(ctx, botID)
}
func (s *Service) setDefaultEffect(ctx context.Context, botID, effect string) error {
if s.acl == nil {
return nil
}
if effect == "" {
return nil
}
return s.acl.SetDefaultEffect(ctx, botID, effect)
}
func (s *Service) resolveModelUUID(ctx context.Context, modelID string) (pgtype.UUID, error) {
modelID = strings.TrimSpace(modelID)
if modelID == "" {
return pgtype.UUID{}, fmt.Errorf("%w: model_id is required", ErrInvalidModelRef)
}
// Preferred path: when caller already passes the model UUID.
if parsed, err := db.ParseUUID(modelID); err == nil {
if _, err := s.queries.GetModelByID(ctx, parsed); err == nil {
return parsed, nil
} else if !errors.Is(err, pgx.ErrNoRows) {
return pgtype.UUID{}, err
}
}
rows, err := s.queries.ListModelsByModelID(ctx, modelID)
if err != nil {
return pgtype.UUID{}, err
}
if len(rows) == 0 {
return pgtype.UUID{}, fmt.Errorf("%w: model not found: %s", ErrInvalidModelRef, modelID)
}
if len(rows) > 1 {
return pgtype.UUID{}, fmt.Errorf("%w: %s", ErrModelIDAmbiguous, modelID)
}
return rows[0].ID, nil
}
func normalizeOptionalTimezone(raw string) (pgtype.Text, error) {
normalized := strings.TrimSpace(raw)
if normalized == "" {
return pgtype.Text{}, nil
}
loc, _, err := tzutil.Resolve(normalized)
if err != nil {
return pgtype.Text{}, fmt.Errorf("invalid timezone: %w", err)
}
return pgtype.Text{String: loc.String(), Valid: true}, nil
}