Files
Memoh/internal/acl/service.go
T

471 lines
15 KiB
Go

package acl
import (
"context"
"errors"
"log/slog"
"strings"
"time"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgtype"
"github.com/memohai/memoh/internal/bots"
"github.com/memohai/memoh/internal/db"
"github.com/memohai/memoh/internal/db/sqlc"
)
// Sentinel errors returned by the ACL service when rule input is rejected.
var (
// ErrInvalidRuleSubject is returned when a rule request names neither or both
// of user_id and channel_identity_id.
ErrInvalidRuleSubject = errors.New("exactly one of user_id or channel_identity_id is required")
// ErrInvalidSourceScope is returned when a source scope names a thread without
// a conversation, or a conversation/thread without a channel.
ErrInvalidSourceScope = errors.New("invalid source scope")
)
// Service reads and writes per-bot ACL rules and evaluates chat-trigger
// permission checks against them.
type Service struct {
// queries is the database query handle used for every rule read and write.
queries *sqlc.Queries
// bots is used to load the bot record during chat-trigger checks.
bots *bots.Service
// logger carries the service=acl attribute attached by NewService.
logger *slog.Logger
}
// NewService wires an ACL Service from its dependencies.
//
// A nil log falls back to slog.Default(). The stored logger is always tagged
// with service=acl. queries and botService are stored as given; methods that
// need them report an error when they are nil.
func NewService(log *slog.Logger, queries *sqlc.Queries, botService *bots.Service) *Service {
	base := log
	if base == nil {
		base = slog.Default()
	}

	svc := &Service{queries: queries, bots: botService}
	svc.logger = base.With(slog.String("service", "acl"))
	return svc
}
// AllowGuestEnabled reports whether the bot identified by botID has the
// guest-all allow rule.
//
// It returns an error when the service has no query handle or when botID is
// not a valid UUID; otherwise it returns the query's result unchanged.
func (s *Service) AllowGuestEnabled(ctx context.Context, botID string) (bool, error) {
	if s == nil || s.queries == nil {
		return false, errors.New("acl queries not configured")
	}

	parsedBotID, parseErr := db.ParseUUID(botID)
	if parseErr != nil {
		return false, parseErr
	}

	enabled, queryErr := s.queries.HasBotACLGuestAllAllowRule(ctx, parsedBotID)
	return enabled, queryErr
}
// SetAllowGuest turns the bot's guest-all allow rule on or off.
//
// When enabled is true the rule is upserted, recording createdByUserID as its
// creator (an unparsable creator ID is stored as a NULL UUID). When enabled is
// false the rule is deleted. It returns an error when the service has no query
// handle, when botID is not a valid UUID, or when the query fails.
func (s *Service) SetAllowGuest(ctx context.Context, botID, createdByUserID string, enabled bool) error {
	if s == nil || s.queries == nil {
		return errors.New("acl queries not configured")
	}

	parsedBotID, parseErr := db.ParseUUID(botID)
	if parseErr != nil {
		return parseErr
	}

	if !enabled {
		return s.queries.DeleteBotACLGuestAllAllowRule(ctx, parsedBotID)
	}

	params := sqlc.UpsertBotACLGuestAllAllowRuleParams{
		BotID:           parsedBotID,
		CreatedByUserID: optionalUUID(createdByUserID),
	}
	// The upserted row is not needed by callers; only the error is reported.
	_, upsertErr := s.queries.UpsertBotACLGuestAllAllowRule(ctx, params)
	return upsertErr
}
// ListWhitelist returns the bot's subject rules whose effect is EffectAllow.
// It delegates to listByEffect and returns its result and error unchanged.
func (s *Service) ListWhitelist(ctx context.Context, botID string) ([]Rule, error) {
return s.listByEffect(ctx, botID, EffectAllow)
}
// ListBlacklist returns the bot's subject rules whose effect is EffectDeny.
// It delegates to listByEffect and returns its result and error unchanged.
func (s *Service) ListBlacklist(ctx context.Context, botID string) ([]Rule, error) {
return s.listByEffect(ctx, botID, EffectDeny)
}
// AddWhitelistEntry upserts an EffectAllow rule for the subject described by
// req on the given bot, recording createdByUserID as the creator. It delegates
// to upsertEntry and returns its result and error unchanged.
func (s *Service) AddWhitelistEntry(ctx context.Context, botID, createdByUserID string, req UpsertRuleRequest) (Rule, error) {
return s.upsertEntry(ctx, botID, createdByUserID, EffectAllow, req)
}
// AddBlacklistEntry upserts an EffectDeny rule for the subject described by
// req on the given bot, recording createdByUserID as the creator. It delegates
// to upsertEntry and returns its result and error unchanged.
func (s *Service) AddBlacklistEntry(ctx context.Context, botID, createdByUserID string, req UpsertRuleRequest) (Rule, error) {
return s.upsertEntry(ctx, botID, createdByUserID, EffectDeny, req)
}
// DeleteRule removes the ACL rule identified by ruleID.
//
// It returns an error when the service has no query handle, when ruleID is not
// a valid UUID, or when the delete query fails.
func (s *Service) DeleteRule(ctx context.Context, ruleID string) error {
	if s == nil || s.queries == nil {
		return errors.New("acl queries not configured")
	}

	parsedRuleID, parseErr := db.ParseUUID(ruleID)
	if parseErr != nil {
		return parseErr
	}

	deleteErr := s.queries.DeleteBotACLRuleByID(ctx, parsedRuleID)
	return deleteErr
}
// CanPerformChatTrigger decides whether the subject in req may trigger a chat
// with the bot in req.
//
// Evaluation order (first match wins):
//  1. the bot owner (req.UserID equals the bot's owner) is always allowed;
//  2. a deny rule for the user, then for the channel identity, rejects;
//  3. an allow rule for the user, then for the channel identity, accepts;
//  4. otherwise the result is whether the bot has the guest-all allow rule.
//
// Rule lookups are skipped for a subject whose ID is empty after trimming.
// Errors from scope normalization, the bot lookup, bot ID parsing, or any
// query are returned with a false result.
func (s *Service) CanPerformChatTrigger(ctx context.Context, req ChatTriggerRequest) (bool, error) {
	if s == nil {
		return false, errors.New("acl service not configured")
	}

	botID := strings.TrimSpace(req.BotID)
	userID := strings.TrimSpace(req.UserID)
	channelIdentityID := strings.TrimSpace(req.ChannelIdentityID)

	scope, err := normalizeSourceScope(req.SourceScope)
	if err != nil {
		return false, err
	}
	if s.queries == nil || s.bots == nil {
		return false, errors.New("acl service not configured")
	}

	bot, err := s.bots.Get(ctx, botID)
	if err != nil {
		return false, err
	}
	// The owner bypasses every rule, including deny rules.
	if userID != "" && userID == strings.TrimSpace(bot.OwnerUserID) {
		return true, nil
	}

	pgBotID, err := db.ParseUUID(botID)
	if err != nil {
		return false, err
	}

	// Shared lookup parameters; only Effect differs between deny and allow.
	userDeny := sqlc.HasBotACLUserRuleParams{
		BotID:                  pgBotID,
		UserID:                 optionalUUID(userID),
		SourceChannel:          optionalText(scope.Channel),
		SourceConversationType: optionalText(scope.ConversationType),
		SourceConversationID:   optionalText(scope.ConversationID),
		SourceThreadID:         optionalText(scope.ThreadID),
	}
	userAllow := userDeny
	userDeny.Effect = EffectDeny
	userAllow.Effect = EffectAllow

	identityDeny := sqlc.HasBotACLChannelIdentityRuleParams{
		BotID:                  pgBotID,
		ChannelIdentityID:      optionalUUID(channelIdentityID),
		SourceChannel:          optionalText(scope.Channel),
		SourceConversationType: optionalText(scope.ConversationType),
		SourceConversationID:   optionalText(scope.ConversationID),
		SourceThreadID:         optionalText(scope.ThreadID),
	}
	identityAllow := identityDeny
	identityDeny.Effect = EffectDeny
	identityAllow.Effect = EffectAllow

	// matchUser and matchIdentity skip the query when the subject is absent.
	matchUser := func(params sqlc.HasBotACLUserRuleParams) (bool, error) {
		if userID == "" {
			return false, nil
		}
		return s.queries.HasBotACLUserRule(ctx, params)
	}
	matchIdentity := func(params sqlc.HasBotACLChannelIdentityRuleParams) (bool, error) {
		if channelIdentityID == "" {
			return false, nil
		}
		return s.queries.HasBotACLChannelIdentityRule(ctx, params)
	}

	// Deny rules are consulted before allow rules so a deny always wins.
	steps := []struct {
		match   func() (bool, error)
		verdict bool
	}{
		{func() (bool, error) { return matchUser(userDeny) }, false},
		{func() (bool, error) { return matchIdentity(identityDeny) }, false},
		{func() (bool, error) { return matchUser(userAllow) }, true},
		{func() (bool, error) { return matchIdentity(identityAllow) }, true},
	}
	for _, step := range steps {
		matched, err := step.match()
		if err != nil {
			return false, err
		}
		if matched {
			return step.verdict, nil
		}
	}

	// No subject-specific rule matched: fall back to the guest-all allow rule.
	return s.queries.HasBotACLGuestAllAllowRule(ctx, pgBotID)
}
// ListObservedConversationsByChannelIdentity returns the conversations the
// query layer reports for the given bot and channel identity.
//
// channelIdentityID is trimmed and passed to the query as a string. Every text
// field of each row is trimmed. The returned slice is non-nil even when empty.
// It returns an error when the service has no query handle, when botID is not
// a valid UUID, or when the query fails.
func (s *Service) ListObservedConversationsByChannelIdentity(ctx context.Context, botID, channelIdentityID string) ([]ObservedConversationCandidate, error) {
	if s == nil || s.queries == nil {
		return nil, errors.New("acl queries not configured")
	}

	parsedBotID, parseErr := db.ParseUUID(botID)
	if parseErr != nil {
		return nil, parseErr
	}

	params := sqlc.ListObservedConversationsByChannelIdentityParams{
		BotID:             parsedBotID,
		ChannelIdentityID: strings.TrimSpace(channelIdentityID),
	}
	rows, queryErr := s.queries.ListObservedConversationsByChannelIdentity(ctx, params)
	if queryErr != nil {
		return nil, queryErr
	}

	candidates := make([]ObservedConversationCandidate, len(rows))
	for i := range rows {
		row := rows[i]
		candidates[i] = ObservedConversationCandidate{
			RouteID:          row.RouteID.String(),
			Channel:          strings.TrimSpace(row.Channel),
			ConversationType: strings.TrimSpace(row.ConversationType),
			ConversationID:   strings.TrimSpace(row.ConversationID),
			ThreadID:         strings.TrimSpace(row.ThreadID),
			ConversationName: strings.TrimSpace(row.ConversationName),
			LastObservedAt:   timeFromPg(row.LastObservedAt),
		}
	}
	return candidates, nil
}
// listByEffect loads the bot's subject rules with the given effect and
// converts each row with toRule.
//
// The returned slice is non-nil even when no rules exist. It returns an error
// when the service has no query handle, when botID is not a valid UUID, or
// when the query fails.
func (s *Service) listByEffect(ctx context.Context, botID, effect string) ([]Rule, error) {
	if s == nil || s.queries == nil {
		return nil, errors.New("acl queries not configured")
	}

	parsedBotID, parseErr := db.ParseUUID(botID)
	if parseErr != nil {
		return nil, parseErr
	}

	params := sqlc.ListBotACLSubjectRulesByEffectParams{
		BotID:  parsedBotID,
		Effect: effect,
	}
	rows, queryErr := s.queries.ListBotACLSubjectRulesByEffect(ctx, params)
	if queryErr != nil {
		return nil, queryErr
	}

	rules := make([]Rule, len(rows))
	for i := range rows {
		rules[i] = toRule(rows[i])
	}
	return rules, nil
}
// upsertEntry creates or updates a subject rule with the given effect on a bot.
//
// req must name exactly one subject: a user (req.UserID) or a channel identity
// (req.ChannelIdentityID); otherwise ErrInvalidRuleSubject is returned. The
// optional source scope is normalized first and rejected with
// ErrInvalidSourceScope when inconsistent. For channel-identity rules the
// scope's channel is overwritten with the identity's own channel type.
//
// The subject ID must be a valid UUID. A malformed user ID is rejected with
// the parse error instead of being silently converted to a NULL UUID, matching
// how malformed channel identity IDs are already rejected.
//
// createdByUserID is optional: an empty or unparsable value is stored as NULL.
// On any error the returned Rule is the zero value.
func (s *Service) upsertEntry(ctx context.Context, botID, createdByUserID, effect string, req UpsertRuleRequest) (Rule, error) {
	if s == nil || s.queries == nil {
		return Rule{}, errors.New("acl queries not configured")
	}
	pgBotID, err := db.ParseUUID(botID)
	if err != nil {
		return Rule{}, err
	}

	userID := strings.TrimSpace(req.UserID)
	channelIdentityID := strings.TrimSpace(req.ChannelIdentityID)
	sourceScope, err := normalizeOptionalSourceScope(req.SourceScope)
	if err != nil {
		return Rule{}, err
	}
	// Exactly one of the two subject IDs must be present.
	if (userID == "") == (channelIdentityID == "") {
		return Rule{}, ErrInvalidRuleSubject
	}

	if userID != "" {
		// Parse strictly: optionalUUID would swallow the error and write a
		// rule whose subject is NULL.
		pgUserID, err := db.ParseUUID(userID)
		if err != nil {
			return Rule{}, err
		}
		row, err := s.queries.UpsertBotACLUserRule(ctx, sqlc.UpsertBotACLUserRuleParams{
			BotID:                  pgBotID,
			Effect:                 effect,
			UserID:                 pgUserID,
			SourceChannel:          optionalText(sourceScope.Channel),
			SourceConversationType: optionalText(sourceScope.ConversationType),
			SourceConversationID:   optionalText(sourceScope.ConversationID),
			SourceThreadID:         optionalText(sourceScope.ThreadID),
			CreatedByUserID:        optionalUUID(createdByUserID),
		})
		if err != nil {
			return Rule{}, err
		}
		return ruleFromWriteRow(
			row.ID,
			row.BotID,
			row.Action,
			row.Effect,
			row.SubjectKind,
			row.UserID,
			row.ChannelIdentityID,
			row.SourceChannel,
			row.SourceConversationType,
			row.SourceConversationID,
			row.SourceThreadID,
			row.CreatedAt,
			row.UpdatedAt,
		), nil
	}

	// This call also validates channelIdentityID as a UUID and confirms the
	// identity exists, so optionalUUID below cannot yield NULL.
	sourceScope, err = s.normalizeChannelIdentitySourceScope(ctx, channelIdentityID, sourceScope)
	if err != nil {
		return Rule{}, err
	}
	row, err := s.queries.UpsertBotACLChannelIdentityRule(ctx, sqlc.UpsertBotACLChannelIdentityRuleParams{
		BotID:                  pgBotID,
		Effect:                 effect,
		ChannelIdentityID:      optionalUUID(channelIdentityID),
		SourceChannel:          optionalText(sourceScope.Channel),
		SourceConversationType: optionalText(sourceScope.ConversationType),
		SourceConversationID:   optionalText(sourceScope.ConversationID),
		SourceThreadID:         optionalText(sourceScope.ThreadID),
		CreatedByUserID:        optionalUUID(createdByUserID),
	})
	if err != nil {
		return Rule{}, err
	}
	return ruleFromWriteRow(
		row.ID,
		row.BotID,
		row.Action,
		row.Effect,
		row.SubjectKind,
		row.UserID,
		row.ChannelIdentityID,
		row.SourceChannel,
		row.SourceConversationType,
		row.SourceConversationID,
		row.SourceThreadID,
		row.CreatedAt,
		row.UpdatedAt,
	), nil
}
// toRule converts a list-query row into the API-facing Rule.
//
// ID and BotID are always rendered as UUID strings. UserID, ChannelIdentityID
// and LinkedUserID are filled only when the corresponding column is valid and
// are otherwise left empty. All text columns are trimmed. SourceScope is nil
// when every scope column is blank.
func toRule(row sqlc.ListBotACLSubjectRulesByEffectRow) Rule {
	var out Rule

	out.ID = uuid.UUID(row.ID.Bytes).String()
	out.BotID = uuid.UUID(row.BotID.Bytes).String()
	out.Action = row.Action
	out.Effect = row.Effect
	out.SubjectKind = row.SubjectKind

	// Subject references are optional columns.
	if row.UserID.Valid {
		out.UserID = uuid.UUID(row.UserID.Bytes).String()
	}
	if row.ChannelIdentityID.Valid {
		out.ChannelIdentityID = uuid.UUID(row.ChannelIdentityID.Bytes).String()
	}
	if row.LinkedUserID.Valid {
		out.LinkedUserID = uuid.UUID(row.LinkedUserID.Bytes).String()
	}

	out.UserUsername = strings.TrimSpace(row.UserUsername.String)
	out.UserDisplayName = strings.TrimSpace(row.UserDisplayName.String)
	out.UserAvatarURL = strings.TrimSpace(row.UserAvatarUrl.String)

	out.ChannelType = strings.TrimSpace(row.ChannelType.String)
	out.ChannelSubjectID = strings.TrimSpace(row.ChannelSubjectID.String)
	out.ChannelIdentityDisplayName = strings.TrimSpace(row.ChannelIdentityDisplayName.String)
	out.ChannelIdentityAvatarURL = strings.TrimSpace(row.ChannelIdentityAvatarUrl.String)

	out.LinkedUserUsername = strings.TrimSpace(row.LinkedUserUsername.String)
	out.LinkedUserDisplayName = strings.TrimSpace(row.LinkedUserDisplayName.String)
	out.LinkedUserAvatarURL = strings.TrimSpace(row.LinkedUserAvatarUrl.String)

	out.SourceScope = sourceScopeFromPg(
		row.SourceChannel,
		row.SourceConversationType,
		row.SourceConversationID,
		row.SourceThreadID,
	)
	out.CreatedAt = timeFromPg(row.CreatedAt)
	out.UpdatedAt = timeFromPg(row.UpdatedAt)

	return out
}
// optionalUUID parses value (after trimming) as a UUID for a nullable column.
// Any value that fails to parse, including the empty string, yields the zero
// pgtype.UUID; the parse error is intentionally discarded.
func optionalUUID(value string) pgtype.UUID {
	if id, err := db.ParseUUID(strings.TrimSpace(value)); err == nil {
		return id
	}
	return pgtype.UUID{}
}
// optionalText converts value into a nullable text column value: the trimmed
// string marked valid, or the zero pgtype.Text when nothing remains after
// trimming.
func optionalText(value string) pgtype.Text {
	trimmed := strings.TrimSpace(value)
	return pgtype.Text{String: trimmed, Valid: trimmed != ""}
}
// normalizeSourceScope returns scope.Normalize() after checking that the
// result is internally consistent.
//
// It returns ErrInvalidSourceScope (with a zero SourceScope) when a thread is
// given without a conversation, or when a conversation or thread is given
// without a channel.
func normalizeSourceScope(scope SourceScope) (SourceScope, error) {
	n := scope.Normalize()
	hasConversation := n.ConversationID != ""
	hasThread := n.ThreadID != ""

	switch {
	case hasThread && !hasConversation:
		return SourceScope{}, ErrInvalidSourceScope
	case (hasConversation || hasThread) && n.Channel == "":
		return SourceScope{}, ErrInvalidSourceScope
	}
	return n, nil
}
// normalizeOptionalSourceScope is the pointer-accepting form of
// normalizeSourceScope: a nil scope yields a zero SourceScope and no error;
// otherwise the pointed-to scope is normalized and validated. On error the
// returned scope is the zero value.
func normalizeOptionalSourceScope(scope *SourceScope) (SourceScope, error) {
	var empty SourceScope
	if scope == nil {
		return empty, nil
	}

	result, err := normalizeSourceScope(*scope)
	if err != nil {
		return empty, err
	}
	return result, nil
}
// normalizeChannelIdentitySourceScope pins sourceScope to the channel of the
// given channel identity.
//
// With an empty (after trimming) channelIdentityID the scope is returned
// untouched. Otherwise the identity is loaded, its trimmed ChannelType
// replaces sourceScope.Channel, and the result is passed through
// normalizeSourceScope. It returns an error when the service has no query
// handle, when the ID is not a valid UUID, when the lookup fails, or when the
// resulting scope is invalid.
func (s *Service) normalizeChannelIdentitySourceScope(ctx context.Context, channelIdentityID string, sourceScope SourceScope) (SourceScope, error) {
	trimmedID := strings.TrimSpace(channelIdentityID)
	if trimmedID == "" {
		return sourceScope, nil
	}
	if s == nil || s.queries == nil {
		return SourceScope{}, errors.New("acl queries not configured")
	}

	parsedID, parseErr := db.ParseUUID(trimmedID)
	if parseErr != nil {
		return SourceScope{}, parseErr
	}
	identity, lookupErr := s.queries.GetChannelIdentityByID(ctx, parsedID)
	if lookupErr != nil {
		return SourceScope{}, lookupErr
	}

	// The identity's channel type overrides whatever channel the caller sent.
	pinned := sourceScope
	pinned.Channel = strings.TrimSpace(identity.ChannelType)
	return normalizeSourceScope(pinned)
}
// sourceScopeFromPg assembles a SourceScope from four nullable text columns,
// trimming each one. It returns nil when the assembled scope reports IsZero,
// so callers can tell "no scope" apart from an empty one.
func sourceScopeFromPg(channelValue, conversationTypeValue, conversationIDValue, threadIDValue pgtype.Text) *SourceScope {
	result := &SourceScope{
		Channel:          strings.TrimSpace(channelValue.String),
		ConversationType: strings.TrimSpace(conversationTypeValue.String),
		ConversationID:   strings.TrimSpace(conversationIDValue.String),
		ThreadID:         strings.TrimSpace(threadIDValue.String),
	}
	if !result.IsZero() {
		return result
	}
	return nil
}
// ruleFromWriteRow builds a Rule from the columns returned by an upsert.
//
// The columns are passed individually rather than as a row struct. id and
// botID are always rendered as UUID strings; userID and channelIdentityID are
// rendered only when valid and otherwise leave the corresponding Rule field
// empty. SourceScope is nil when all four scope columns are blank.
func ruleFromWriteRow(
	id pgtype.UUID,
	botID pgtype.UUID,
	action string,
	effect string,
	subjectKind string,
	userID pgtype.UUID,
	channelIdentityID pgtype.UUID,
	sourceChannel pgtype.Text,
	sourceConversationType pgtype.Text,
	sourceConversationID pgtype.Text,
	sourceThreadID pgtype.Text,
	createdAt pgtype.Timestamptz,
	updatedAt pgtype.Timestamptz,
) Rule {
	render := func(v pgtype.UUID) string {
		return uuid.UUID(v.Bytes).String()
	}

	var out Rule
	out.ID = render(id)
	out.BotID = render(botID)
	out.Action = action
	out.Effect = effect
	out.SubjectKind = subjectKind
	if userID.Valid {
		out.UserID = render(userID)
	}
	if channelIdentityID.Valid {
		out.ChannelIdentityID = render(channelIdentityID)
	}
	out.SourceScope = sourceScopeFromPg(sourceChannel, sourceConversationType, sourceConversationID, sourceThreadID)
	out.CreatedAt = timeFromPg(createdAt)
	out.UpdatedAt = timeFromPg(updatedAt)
	return out
}
// timeFromPg unwraps a nullable timestamp: the stored time when the value is
// valid, otherwise the zero time.Time.
func timeFromPg(value pgtype.Timestamptz) time.Time {
	if !value.Valid {
		var zero time.Time
		return zero
	}
	return value.Time
}