Files
Memoh/internal/settings/service.go
T
BBQ 60517bc2a6 feat(acl): add bot security policy presets
Initialize new bots with preset ACL templates and an allow-by-default fallback so common access setups can be selected during bot creation instead of being configured manually afterward.
2026-04-14 05:51:38 +08:00

458 lines
14 KiB
Go

package settings
import (
"context"
"errors"
"fmt"
"log/slog"
"math"
"strings"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/memohai/memoh/internal/acl"
"github.com/memohai/memoh/internal/db"
"github.com/memohai/memoh/internal/db/sqlc"
tzutil "github.com/memohai/memoh/internal/timezone"
)
type Service struct {
queries *sqlc.Queries
acl *acl.Service
logger *slog.Logger
}
var (
ErrModelIDAmbiguous = errors.New("model_id is ambiguous across providers")
ErrInvalidModelRef = errors.New("invalid model reference")
)
func NewService(log *slog.Logger, queries *sqlc.Queries, aclService *acl.Service) *Service {
return &Service{
queries: queries,
acl: aclService,
logger: log.With(slog.String("service", "settings")),
}
}
func (s *Service) GetBot(ctx context.Context, botID string) (Settings, error) {
pgID, err := db.ParseUUID(botID)
if err != nil {
return Settings{}, err
}
row, err := s.queries.GetSettingsByBotID(ctx, pgID)
if err != nil {
return Settings{}, err
}
settings := normalizeBotSettingsReadRow(row)
aclDefaultEffect, err := s.getDefaultEffect(ctx, botID)
if err != nil {
return Settings{}, err
}
settings.AclDefaultEffect = aclDefaultEffect
return settings, nil
}
func (s *Service) UpsertBot(ctx context.Context, botID string, req UpsertRequest) (Settings, error) {
if s.queries == nil {
return Settings{}, errors.New("settings queries not configured")
}
pgID, err := db.ParseUUID(botID)
if err != nil {
return Settings{}, err
}
botRow, err := s.queries.GetBotByID(ctx, pgID)
if err != nil {
return Settings{}, err
}
aclDefaultEffect, err := s.getDefaultEffect(ctx, botID)
if err != nil {
return Settings{}, err
}
current := normalizeBotSetting(botRow.Language, aclDefaultEffect, botRow.ReasoningEnabled, botRow.ReasoningEffort, botRow.HeartbeatEnabled, botRow.HeartbeatInterval, botRow.CompactionEnabled, botRow.CompactionThreshold, botRow.CompactionRatio)
if strings.TrimSpace(req.Language) != "" {
current.Language = strings.TrimSpace(req.Language)
}
if effect := strings.TrimSpace(req.AclDefaultEffect); effect != "" {
current.AclDefaultEffect = effect
}
if req.ReasoningEnabled != nil {
current.ReasoningEnabled = *req.ReasoningEnabled
}
if req.ReasoningEffort != nil && isValidReasoningEffort(*req.ReasoningEffort) {
current.ReasoningEffort = *req.ReasoningEffort
}
if req.HeartbeatEnabled != nil {
current.HeartbeatEnabled = *req.HeartbeatEnabled
}
if req.HeartbeatInterval != nil && *req.HeartbeatInterval > 0 {
current.HeartbeatInterval = *req.HeartbeatInterval
}
if req.CompactionEnabled != nil {
current.CompactionEnabled = *req.CompactionEnabled
}
if req.CompactionThreshold != nil && *req.CompactionThreshold >= 0 {
current.CompactionThreshold = *req.CompactionThreshold
}
if req.CompactionRatio != nil && *req.CompactionRatio >= 1 && *req.CompactionRatio <= 100 {
current.CompactionRatio = *req.CompactionRatio
}
if req.PersistFullToolResults != nil {
current.PersistFullToolResults = *req.PersistFullToolResults
}
timezoneValue := pgtype.Text{}
if req.Timezone != nil {
normalized, err := normalizeOptionalTimezone(*req.Timezone)
if err != nil {
return Settings{}, err
}
timezoneValue = normalized
}
chatModelUUID := pgtype.UUID{}
if value := strings.TrimSpace(req.ChatModelID); value != "" {
modelID, err := s.resolveModelUUID(ctx, value)
if err != nil {
return Settings{}, err
}
chatModelUUID = modelID
}
heartbeatModelUUID := pgtype.UUID{}
if value := strings.TrimSpace(req.HeartbeatModelID); value != "" {
modelID, err := s.resolveModelUUID(ctx, value)
if err != nil {
return Settings{}, err
}
heartbeatModelUUID = modelID
}
compactionModelUUID := pgtype.UUID{}
if req.CompactionModelID != nil {
if value := strings.TrimSpace(*req.CompactionModelID); value != "" {
modelID, err := s.resolveModelUUID(ctx, value)
if err != nil {
return Settings{}, err
}
compactionModelUUID = modelID
}
}
titleModelUUID := pgtype.UUID{}
if value := strings.TrimSpace(req.TitleModelID); value != "" {
modelID, err := s.resolveModelUUID(ctx, value)
if err != nil {
return Settings{}, err
}
titleModelUUID = modelID
}
imageModelUUID := pgtype.UUID{}
if value := strings.TrimSpace(req.ImageModelID); value != "" {
modelID, err := s.resolveModelUUID(ctx, value)
if err != nil {
return Settings{}, err
}
imageModelUUID = modelID
}
searchProviderUUID := pgtype.UUID{}
if value := strings.TrimSpace(req.SearchProviderID); value != "" {
providerID, err := db.ParseUUID(value)
if err != nil {
return Settings{}, err
}
searchProviderUUID = providerID
}
memoryProviderUUID := pgtype.UUID{}
if value := strings.TrimSpace(req.MemoryProviderID); value != "" {
providerID, err := db.ParseUUID(value)
if err != nil {
return Settings{}, err
}
memoryProviderUUID = providerID
}
ttsModelUUID := pgtype.UUID{}
if value := strings.TrimSpace(req.TtsModelID); value != "" {
modelID, err := db.ParseUUID(value)
if err != nil {
return Settings{}, err
}
ttsModelUUID = modelID
}
browserContextUUID := pgtype.UUID{}
if value := strings.TrimSpace(req.BrowserContextID); value != "" {
ctxID, err := db.ParseUUID(value)
if err != nil {
return Settings{}, err
}
browserContextUUID = ctxID
}
contextTokenBudgetValue := pgtype.Int4{}
if req.ContextTokenBudget != nil && *req.ContextTokenBudget >= 0 {
v := *req.ContextTokenBudget
if v > math.MaxInt32 {
v = math.MaxInt32
}
contextTokenBudgetValue = pgtype.Int4{Int32: int32(v), Valid: true} //nolint:gosec // G115: clamped above
}
updated, err := s.queries.UpsertBotSettings(ctx, sqlc.UpsertBotSettingsParams{
ID: pgID,
Timezone: timezoneValue,
Language: current.Language,
ReasoningEnabled: current.ReasoningEnabled,
ReasoningEffort: current.ReasoningEffort,
HeartbeatEnabled: current.HeartbeatEnabled,
HeartbeatInterval: int32(current.HeartbeatInterval), //nolint:gosec // bounded by positive-only setter above
HeartbeatPrompt: "",
CompactionEnabled: current.CompactionEnabled,
CompactionThreshold: int32(current.CompactionThreshold), //nolint:gosec // bounded by non-negative setter above
CompactionRatio: int32(current.CompactionRatio), //nolint:gosec // bounded 1-100 above
ChatModelID: chatModelUUID,
HeartbeatModelID: heartbeatModelUUID,
CompactionModelID: compactionModelUUID,
TitleModelID: titleModelUUID,
ImageModelID: imageModelUUID,
SearchProviderID: searchProviderUUID,
MemoryProviderID: memoryProviderUUID,
TtsModelID: ttsModelUUID,
BrowserContextID: browserContextUUID,
ContextTokenBudget: contextTokenBudgetValue,
PersistFullToolResults: current.PersistFullToolResults,
})
if err != nil {
return Settings{}, err
}
createdByUserID := ""
if botRow.OwnerUserID.Valid {
createdByUserID = uuid.UUID(botRow.OwnerUserID.Bytes).String()
}
_ = createdByUserID
if err := s.setDefaultEffect(ctx, botID, current.AclDefaultEffect); err != nil {
return Settings{}, err
}
settings := normalizeBotSettingsWriteRow(updated)
settings.AclDefaultEffect = current.AclDefaultEffect
return settings, nil
}
func (s *Service) Delete(ctx context.Context, botID string) error {
if s.queries == nil {
return errors.New("settings queries not configured")
}
pgID, err := db.ParseUUID(botID)
if err != nil {
return err
}
if err := s.queries.DeleteSettingsByBotID(ctx, pgID); err != nil {
return err
}
return nil
}
func normalizeBotSetting(language string, aclDefaultEffect string, reasoningEnabled bool, reasoningEffort string, heartbeatEnabled bool, heartbeatInterval int32, compactionEnabled bool, compactionThreshold int32, compactionRatio int32) Settings {
settings := Settings{
Language: strings.TrimSpace(language),
AclDefaultEffect: strings.TrimSpace(aclDefaultEffect),
ReasoningEnabled: reasoningEnabled,
ReasoningEffort: strings.TrimSpace(reasoningEffort),
HeartbeatEnabled: heartbeatEnabled,
HeartbeatInterval: int(heartbeatInterval),
CompactionEnabled: compactionEnabled,
CompactionThreshold: int(compactionThreshold),
CompactionRatio: int(compactionRatio),
}
if settings.Language == "" {
settings.Language = DefaultLanguage
}
if settings.AclDefaultEffect == "" {
settings.AclDefaultEffect = "allow"
}
if !isValidReasoningEffort(settings.ReasoningEffort) {
settings.ReasoningEffort = DefaultReasoningEffort
}
if settings.HeartbeatInterval <= 0 {
settings.HeartbeatInterval = DefaultHeartbeatInterval
}
if settings.CompactionThreshold < 0 {
settings.CompactionThreshold = 0
}
if settings.CompactionRatio < 1 || settings.CompactionRatio > 100 {
settings.CompactionRatio = 80
}
return settings
}
func isValidReasoningEffort(effort string) bool {
switch effort {
case "low", "medium", "high":
return true
default:
return false
}
}
func normalizeBotSettingsReadRow(row sqlc.GetSettingsByBotIDRow) Settings {
return normalizeBotSettingsFields(
row.Language,
row.ReasoningEnabled,
row.ReasoningEffort,
row.HeartbeatEnabled,
row.HeartbeatInterval,
row.CompactionEnabled,
row.CompactionThreshold,
row.CompactionRatio,
row.Timezone,
row.ChatModelID,
row.HeartbeatModelID,
row.CompactionModelID,
row.TitleModelID,
row.ImageModelID,
row.SearchProviderID,
row.MemoryProviderID,
row.TtsModelID,
row.BrowserContextID,
row.ContextTokenBudget,
row.PersistFullToolResults,
)
}
func normalizeBotSettingsWriteRow(row sqlc.UpsertBotSettingsRow) Settings {
return normalizeBotSettingsFields(
row.Language,
row.ReasoningEnabled,
row.ReasoningEffort,
row.HeartbeatEnabled,
row.HeartbeatInterval,
row.CompactionEnabled,
row.CompactionThreshold,
row.CompactionRatio,
row.Timezone,
row.ChatModelID,
row.HeartbeatModelID,
row.CompactionModelID,
row.TitleModelID,
row.ImageModelID,
row.SearchProviderID,
row.MemoryProviderID,
row.TtsModelID,
row.BrowserContextID,
row.ContextTokenBudget,
row.PersistFullToolResults,
)
}
func normalizeBotSettingsFields(
language string,
reasoningEnabled bool,
reasoningEffort string,
heartbeatEnabled bool,
heartbeatInterval int32,
compactionEnabled bool,
compactionThreshold int32,
compactionRatio int32,
timezone pgtype.Text,
chatModelID pgtype.UUID,
heartbeatModelID pgtype.UUID,
compactionModelID pgtype.UUID,
titleModelID pgtype.UUID,
imageModelID pgtype.UUID,
searchProviderID pgtype.UUID,
memoryProviderID pgtype.UUID,
ttsModelID pgtype.UUID,
browserContextID pgtype.UUID,
contextTokenBudget pgtype.Int4,
persistFullToolResults bool,
) Settings {
settings := normalizeBotSetting(language, "", reasoningEnabled, reasoningEffort, heartbeatEnabled, heartbeatInterval, compactionEnabled, compactionThreshold, compactionRatio)
if timezone.Valid {
settings.Timezone = timezone.String
}
if chatModelID.Valid {
settings.ChatModelID = uuid.UUID(chatModelID.Bytes).String()
}
if heartbeatModelID.Valid {
settings.HeartbeatModelID = uuid.UUID(heartbeatModelID.Bytes).String()
}
if compactionModelID.Valid {
settings.CompactionModelID = uuid.UUID(compactionModelID.Bytes).String()
}
if titleModelID.Valid {
settings.TitleModelID = uuid.UUID(titleModelID.Bytes).String()
}
if imageModelID.Valid {
settings.ImageModelID = uuid.UUID(imageModelID.Bytes).String()
}
if searchProviderID.Valid {
settings.SearchProviderID = uuid.UUID(searchProviderID.Bytes).String()
}
if memoryProviderID.Valid {
settings.MemoryProviderID = uuid.UUID(memoryProviderID.Bytes).String()
}
if ttsModelID.Valid {
settings.TtsModelID = uuid.UUID(ttsModelID.Bytes).String()
}
if browserContextID.Valid {
settings.BrowserContextID = uuid.UUID(browserContextID.Bytes).String()
}
if contextTokenBudget.Valid {
settings.ContextTokenBudget = int(contextTokenBudget.Int32)
}
settings.PersistFullToolResults = persistFullToolResults
return settings
}
func (s *Service) getDefaultEffect(ctx context.Context, botID string) (string, error) {
if s.acl == nil {
return "allow", nil
}
return s.acl.GetDefaultEffect(ctx, botID)
}
func (s *Service) setDefaultEffect(ctx context.Context, botID, effect string) error {
if s.acl == nil {
return nil
}
if effect == "" {
return nil
}
return s.acl.SetDefaultEffect(ctx, botID, effect)
}
func (s *Service) resolveModelUUID(ctx context.Context, modelID string) (pgtype.UUID, error) {
modelID = strings.TrimSpace(modelID)
if modelID == "" {
return pgtype.UUID{}, fmt.Errorf("%w: model_id is required", ErrInvalidModelRef)
}
// Preferred path: when caller already passes the model UUID.
if parsed, err := db.ParseUUID(modelID); err == nil {
if _, err := s.queries.GetModelByID(ctx, parsed); err == nil {
return parsed, nil
} else if !errors.Is(err, pgx.ErrNoRows) {
return pgtype.UUID{}, err
}
}
rows, err := s.queries.ListModelsByModelID(ctx, modelID)
if err != nil {
return pgtype.UUID{}, err
}
if len(rows) == 0 {
return pgtype.UUID{}, fmt.Errorf("%w: model not found: %s", ErrInvalidModelRef, modelID)
}
if len(rows) > 1 {
return pgtype.UUID{}, fmt.Errorf("%w: %s", ErrModelIDAmbiguous, modelID)
}
return rows[0].ID, nil
}
func normalizeOptionalTimezone(raw string) (pgtype.Text, error) {
normalized := strings.TrimSpace(raw)
if normalized == "" {
return pgtype.Text{}, nil
}
loc, _, err := tzutil.Resolve(normalized)
if err != nil {
return pgtype.Text{}, fmt.Errorf("invalid timezone: %w", err)
}
return pgtype.Text{String: loc.String(), Valid: true}, nil
}