// Mirror of https://github.com/memohai/Memoh.git (synced 2026-04-25 07:00:48 +09:00).
// 471 lines, 15 KiB, Go.

package acl
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"log/slog"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/jackc/pgx/v5/pgtype"
|
|
|
|
"github.com/memohai/memoh/internal/bots"
|
|
"github.com/memohai/memoh/internal/db"
|
|
"github.com/memohai/memoh/internal/db/sqlc"
|
|
)
|
|
|
|
// ErrInvalidRuleSubject is returned when a rule request names neither or both
// of the two possible subjects (user ID and channel identity ID).
var ErrInvalidRuleSubject = errors.New("exactly one of user_id or channel_identity_id is required")

// ErrInvalidSourceScope is returned when a source scope fails validation.
var ErrInvalidSourceScope = errors.New("invalid source scope")
|
|
|
|
type Service struct {
|
|
queries *sqlc.Queries
|
|
bots *bots.Service
|
|
logger *slog.Logger
|
|
}
|
|
|
|
func NewService(log *slog.Logger, queries *sqlc.Queries, botService *bots.Service) *Service {
|
|
if log == nil {
|
|
log = slog.Default()
|
|
}
|
|
return &Service{
|
|
queries: queries,
|
|
bots: botService,
|
|
logger: log.With(slog.String("service", "acl")),
|
|
}
|
|
}
|
|
|
|
func (s *Service) AllowGuestEnabled(ctx context.Context, botID string) (bool, error) {
|
|
if s == nil || s.queries == nil {
|
|
return false, errors.New("acl queries not configured")
|
|
}
|
|
pgBotID, err := db.ParseUUID(botID)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
return s.queries.HasBotACLGuestAllAllowRule(ctx, pgBotID)
|
|
}
|
|
|
|
func (s *Service) SetAllowGuest(ctx context.Context, botID, createdByUserID string, enabled bool) error {
|
|
if s == nil || s.queries == nil {
|
|
return errors.New("acl queries not configured")
|
|
}
|
|
pgBotID, err := db.ParseUUID(botID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if enabled {
|
|
_, err = s.queries.UpsertBotACLGuestAllAllowRule(ctx, sqlc.UpsertBotACLGuestAllAllowRuleParams{
|
|
BotID: pgBotID,
|
|
CreatedByUserID: optionalUUID(createdByUserID),
|
|
})
|
|
return err
|
|
}
|
|
return s.queries.DeleteBotACLGuestAllAllowRule(ctx, pgBotID)
|
|
}
|
|
|
|
func (s *Service) ListWhitelist(ctx context.Context, botID string) ([]Rule, error) {
|
|
return s.listByEffect(ctx, botID, EffectAllow)
|
|
}
|
|
|
|
func (s *Service) ListBlacklist(ctx context.Context, botID string) ([]Rule, error) {
|
|
return s.listByEffect(ctx, botID, EffectDeny)
|
|
}
|
|
|
|
func (s *Service) AddWhitelistEntry(ctx context.Context, botID, createdByUserID string, req UpsertRuleRequest) (Rule, error) {
|
|
return s.upsertEntry(ctx, botID, createdByUserID, EffectAllow, req)
|
|
}
|
|
|
|
func (s *Service) AddBlacklistEntry(ctx context.Context, botID, createdByUserID string, req UpsertRuleRequest) (Rule, error) {
|
|
return s.upsertEntry(ctx, botID, createdByUserID, EffectDeny, req)
|
|
}
|
|
|
|
func (s *Service) DeleteRule(ctx context.Context, ruleID string) error {
|
|
if s == nil || s.queries == nil {
|
|
return errors.New("acl queries not configured")
|
|
}
|
|
pgRuleID, err := db.ParseUUID(ruleID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return s.queries.DeleteBotACLRuleByID(ctx, pgRuleID)
|
|
}
|
|
|
|
func (s *Service) CanPerformChatTrigger(ctx context.Context, req ChatTriggerRequest) (bool, error) {
|
|
if s == nil {
|
|
return false, errors.New("acl service not configured")
|
|
}
|
|
botID := strings.TrimSpace(req.BotID)
|
|
userID := strings.TrimSpace(req.UserID)
|
|
channelIdentityID := strings.TrimSpace(req.ChannelIdentityID)
|
|
sourceScope, err := normalizeSourceScope(req.SourceScope)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
if s.queries == nil || s.bots == nil {
|
|
return false, errors.New("acl service not configured")
|
|
}
|
|
|
|
bot, err := s.bots.Get(ctx, botID)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
if userID != "" && strings.TrimSpace(bot.OwnerUserID) == userID {
|
|
return true, nil
|
|
}
|
|
|
|
pgBotID, err := db.ParseUUID(botID)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
if userID != "" {
|
|
matched, err := s.queries.HasBotACLUserRule(ctx, sqlc.HasBotACLUserRuleParams{
|
|
BotID: pgBotID,
|
|
Effect: EffectDeny,
|
|
UserID: optionalUUID(userID),
|
|
SourceChannel: optionalText(sourceScope.Channel),
|
|
SourceConversationType: optionalText(sourceScope.ConversationType),
|
|
SourceConversationID: optionalText(sourceScope.ConversationID),
|
|
SourceThreadID: optionalText(sourceScope.ThreadID),
|
|
})
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
if matched {
|
|
return false, nil
|
|
}
|
|
}
|
|
if channelIdentityID != "" {
|
|
matched, err := s.queries.HasBotACLChannelIdentityRule(ctx, sqlc.HasBotACLChannelIdentityRuleParams{
|
|
BotID: pgBotID,
|
|
Effect: EffectDeny,
|
|
ChannelIdentityID: optionalUUID(channelIdentityID),
|
|
SourceChannel: optionalText(sourceScope.Channel),
|
|
SourceConversationType: optionalText(sourceScope.ConversationType),
|
|
SourceConversationID: optionalText(sourceScope.ConversationID),
|
|
SourceThreadID: optionalText(sourceScope.ThreadID),
|
|
})
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
if matched {
|
|
return false, nil
|
|
}
|
|
}
|
|
if userID != "" {
|
|
matched, err := s.queries.HasBotACLUserRule(ctx, sqlc.HasBotACLUserRuleParams{
|
|
BotID: pgBotID,
|
|
Effect: EffectAllow,
|
|
UserID: optionalUUID(userID),
|
|
SourceChannel: optionalText(sourceScope.Channel),
|
|
SourceConversationType: optionalText(sourceScope.ConversationType),
|
|
SourceConversationID: optionalText(sourceScope.ConversationID),
|
|
SourceThreadID: optionalText(sourceScope.ThreadID),
|
|
})
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
if matched {
|
|
return true, nil
|
|
}
|
|
}
|
|
if channelIdentityID != "" {
|
|
matched, err := s.queries.HasBotACLChannelIdentityRule(ctx, sqlc.HasBotACLChannelIdentityRuleParams{
|
|
BotID: pgBotID,
|
|
Effect: EffectAllow,
|
|
ChannelIdentityID: optionalUUID(channelIdentityID),
|
|
SourceChannel: optionalText(sourceScope.Channel),
|
|
SourceConversationType: optionalText(sourceScope.ConversationType),
|
|
SourceConversationID: optionalText(sourceScope.ConversationID),
|
|
SourceThreadID: optionalText(sourceScope.ThreadID),
|
|
})
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
if matched {
|
|
return true, nil
|
|
}
|
|
}
|
|
return s.queries.HasBotACLGuestAllAllowRule(ctx, pgBotID)
|
|
}
|
|
|
|
func (s *Service) ListObservedConversationsByChannelIdentity(ctx context.Context, botID, channelIdentityID string) ([]ObservedConversationCandidate, error) {
|
|
if s == nil || s.queries == nil {
|
|
return nil, errors.New("acl queries not configured")
|
|
}
|
|
pgBotID, err := db.ParseUUID(botID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
rows, err := s.queries.ListObservedConversationsByChannelIdentity(ctx, sqlc.ListObservedConversationsByChannelIdentityParams{
|
|
BotID: pgBotID,
|
|
ChannelIdentityID: strings.TrimSpace(channelIdentityID),
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
items := make([]ObservedConversationCandidate, 0, len(rows))
|
|
for _, row := range rows {
|
|
items = append(items, ObservedConversationCandidate{
|
|
RouteID: row.RouteID.String(),
|
|
Channel: strings.TrimSpace(row.Channel),
|
|
ConversationType: strings.TrimSpace(row.ConversationType),
|
|
ConversationID: strings.TrimSpace(row.ConversationID),
|
|
ThreadID: strings.TrimSpace(row.ThreadID),
|
|
ConversationName: strings.TrimSpace(row.ConversationName),
|
|
LastObservedAt: timeFromPg(row.LastObservedAt),
|
|
})
|
|
}
|
|
return items, nil
|
|
}
|
|
|
|
func (s *Service) listByEffect(ctx context.Context, botID, effect string) ([]Rule, error) {
|
|
if s == nil || s.queries == nil {
|
|
return nil, errors.New("acl queries not configured")
|
|
}
|
|
pgBotID, err := db.ParseUUID(botID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
rows, err := s.queries.ListBotACLSubjectRulesByEffect(ctx, sqlc.ListBotACLSubjectRulesByEffectParams{
|
|
BotID: pgBotID,
|
|
Effect: effect,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
items := make([]Rule, 0, len(rows))
|
|
for _, row := range rows {
|
|
items = append(items, toRule(row))
|
|
}
|
|
return items, nil
|
|
}
|
|
|
|
func (s *Service) upsertEntry(ctx context.Context, botID, createdByUserID, effect string, req UpsertRuleRequest) (Rule, error) {
|
|
if s == nil || s.queries == nil {
|
|
return Rule{}, errors.New("acl queries not configured")
|
|
}
|
|
pgBotID, err := db.ParseUUID(botID)
|
|
if err != nil {
|
|
return Rule{}, err
|
|
}
|
|
userID := strings.TrimSpace(req.UserID)
|
|
channelIdentityID := strings.TrimSpace(req.ChannelIdentityID)
|
|
sourceScope, err := normalizeOptionalSourceScope(req.SourceScope)
|
|
if err != nil {
|
|
return Rule{}, err
|
|
}
|
|
if (userID == "" && channelIdentityID == "") || (userID != "" && channelIdentityID != "") {
|
|
return Rule{}, ErrInvalidRuleSubject
|
|
}
|
|
if userID != "" {
|
|
row, err := s.queries.UpsertBotACLUserRule(ctx, sqlc.UpsertBotACLUserRuleParams{
|
|
BotID: pgBotID,
|
|
Effect: effect,
|
|
UserID: optionalUUID(userID),
|
|
SourceChannel: optionalText(sourceScope.Channel),
|
|
SourceConversationType: optionalText(sourceScope.ConversationType),
|
|
SourceConversationID: optionalText(sourceScope.ConversationID),
|
|
SourceThreadID: optionalText(sourceScope.ThreadID),
|
|
CreatedByUserID: optionalUUID(createdByUserID),
|
|
})
|
|
if err != nil {
|
|
return Rule{}, err
|
|
}
|
|
return ruleFromWriteRow(
|
|
row.ID,
|
|
row.BotID,
|
|
row.Action,
|
|
row.Effect,
|
|
row.SubjectKind,
|
|
row.UserID,
|
|
row.ChannelIdentityID,
|
|
row.SourceChannel,
|
|
row.SourceConversationType,
|
|
row.SourceConversationID,
|
|
row.SourceThreadID,
|
|
row.CreatedAt,
|
|
row.UpdatedAt,
|
|
), nil
|
|
}
|
|
sourceScope, err = s.normalizeChannelIdentitySourceScope(ctx, channelIdentityID, sourceScope)
|
|
if err != nil {
|
|
return Rule{}, err
|
|
}
|
|
row, err := s.queries.UpsertBotACLChannelIdentityRule(ctx, sqlc.UpsertBotACLChannelIdentityRuleParams{
|
|
BotID: pgBotID,
|
|
Effect: effect,
|
|
ChannelIdentityID: optionalUUID(channelIdentityID),
|
|
SourceChannel: optionalText(sourceScope.Channel),
|
|
SourceConversationType: optionalText(sourceScope.ConversationType),
|
|
SourceConversationID: optionalText(sourceScope.ConversationID),
|
|
SourceThreadID: optionalText(sourceScope.ThreadID),
|
|
CreatedByUserID: optionalUUID(createdByUserID),
|
|
})
|
|
if err != nil {
|
|
return Rule{}, err
|
|
}
|
|
return ruleFromWriteRow(
|
|
row.ID,
|
|
row.BotID,
|
|
row.Action,
|
|
row.Effect,
|
|
row.SubjectKind,
|
|
row.UserID,
|
|
row.ChannelIdentityID,
|
|
row.SourceChannel,
|
|
row.SourceConversationType,
|
|
row.SourceConversationID,
|
|
row.SourceThreadID,
|
|
row.CreatedAt,
|
|
row.UpdatedAt,
|
|
), nil
|
|
}
|
|
|
|
func toRule(row sqlc.ListBotACLSubjectRulesByEffectRow) Rule {
|
|
rule := Rule{
|
|
ID: uuid.UUID(row.ID.Bytes).String(),
|
|
BotID: uuid.UUID(row.BotID.Bytes).String(),
|
|
Action: row.Action,
|
|
Effect: row.Effect,
|
|
SubjectKind: row.SubjectKind,
|
|
UserUsername: strings.TrimSpace(row.UserUsername.String),
|
|
UserDisplayName: strings.TrimSpace(row.UserDisplayName.String),
|
|
UserAvatarURL: strings.TrimSpace(row.UserAvatarUrl.String),
|
|
ChannelType: strings.TrimSpace(row.ChannelType.String),
|
|
ChannelSubjectID: strings.TrimSpace(row.ChannelSubjectID.String),
|
|
ChannelIdentityDisplayName: strings.TrimSpace(row.ChannelIdentityDisplayName.String),
|
|
ChannelIdentityAvatarURL: strings.TrimSpace(row.ChannelIdentityAvatarUrl.String),
|
|
LinkedUserUsername: strings.TrimSpace(row.LinkedUserUsername.String),
|
|
LinkedUserDisplayName: strings.TrimSpace(row.LinkedUserDisplayName.String),
|
|
LinkedUserAvatarURL: strings.TrimSpace(row.LinkedUserAvatarUrl.String),
|
|
CreatedAt: timeFromPg(row.CreatedAt),
|
|
UpdatedAt: timeFromPg(row.UpdatedAt),
|
|
}
|
|
rule.SourceScope = sourceScopeFromPg(row.SourceChannel, row.SourceConversationType, row.SourceConversationID, row.SourceThreadID)
|
|
if row.UserID.Valid {
|
|
rule.UserID = uuid.UUID(row.UserID.Bytes).String()
|
|
}
|
|
if row.ChannelIdentityID.Valid {
|
|
rule.ChannelIdentityID = uuid.UUID(row.ChannelIdentityID.Bytes).String()
|
|
}
|
|
if row.LinkedUserID.Valid {
|
|
rule.LinkedUserID = uuid.UUID(row.LinkedUserID.Bytes).String()
|
|
}
|
|
return rule
|
|
}
|
|
|
|
func optionalUUID(value string) pgtype.UUID {
|
|
parsed, err := db.ParseUUID(strings.TrimSpace(value))
|
|
if err != nil {
|
|
return pgtype.UUID{}
|
|
}
|
|
return parsed
|
|
}
|
|
|
|
func optionalText(value string) pgtype.Text {
|
|
value = strings.TrimSpace(value)
|
|
if value == "" {
|
|
return pgtype.Text{}
|
|
}
|
|
return pgtype.Text{String: value, Valid: true}
|
|
}
|
|
|
|
func normalizeSourceScope(scope SourceScope) (SourceScope, error) {
|
|
normalized := scope.Normalize()
|
|
if normalized.ThreadID != "" && normalized.ConversationID == "" {
|
|
return SourceScope{}, ErrInvalidSourceScope
|
|
}
|
|
if (normalized.ConversationID != "" || normalized.ThreadID != "") && normalized.Channel == "" {
|
|
return SourceScope{}, ErrInvalidSourceScope
|
|
}
|
|
return normalized, nil
|
|
}
|
|
|
|
func normalizeOptionalSourceScope(scope *SourceScope) (SourceScope, error) {
|
|
if scope == nil {
|
|
return SourceScope{}, nil
|
|
}
|
|
normalized, err := normalizeSourceScope(*scope)
|
|
if err != nil {
|
|
return SourceScope{}, err
|
|
}
|
|
return normalized, nil
|
|
}
|
|
|
|
func (s *Service) normalizeChannelIdentitySourceScope(ctx context.Context, channelIdentityID string, sourceScope SourceScope) (SourceScope, error) {
|
|
channelIdentityID = strings.TrimSpace(channelIdentityID)
|
|
if channelIdentityID == "" {
|
|
return sourceScope, nil
|
|
}
|
|
if s == nil || s.queries == nil {
|
|
return SourceScope{}, errors.New("acl queries not configured")
|
|
}
|
|
pgChannelIdentityID, err := db.ParseUUID(channelIdentityID)
|
|
if err != nil {
|
|
return SourceScope{}, err
|
|
}
|
|
identityRow, err := s.queries.GetChannelIdentityByID(ctx, pgChannelIdentityID)
|
|
if err != nil {
|
|
return SourceScope{}, err
|
|
}
|
|
sourceScope.Channel = strings.TrimSpace(identityRow.ChannelType)
|
|
return normalizeSourceScope(sourceScope)
|
|
}
|
|
|
|
func sourceScopeFromPg(channelValue, conversationTypeValue, conversationIDValue, threadIDValue pgtype.Text) *SourceScope {
|
|
scope := SourceScope{
|
|
Channel: strings.TrimSpace(channelValue.String),
|
|
ConversationType: strings.TrimSpace(conversationTypeValue.String),
|
|
ConversationID: strings.TrimSpace(conversationIDValue.String),
|
|
ThreadID: strings.TrimSpace(threadIDValue.String),
|
|
}
|
|
if scope.IsZero() {
|
|
return nil
|
|
}
|
|
return &scope
|
|
}
|
|
|
|
func ruleFromWriteRow(
|
|
id pgtype.UUID,
|
|
botID pgtype.UUID,
|
|
action string,
|
|
effect string,
|
|
subjectKind string,
|
|
userID pgtype.UUID,
|
|
channelIdentityID pgtype.UUID,
|
|
sourceChannel pgtype.Text,
|
|
sourceConversationType pgtype.Text,
|
|
sourceConversationID pgtype.Text,
|
|
sourceThreadID pgtype.Text,
|
|
createdAt pgtype.Timestamptz,
|
|
updatedAt pgtype.Timestamptz,
|
|
) Rule {
|
|
rule := Rule{
|
|
ID: uuid.UUID(id.Bytes).String(),
|
|
BotID: uuid.UUID(botID.Bytes).String(),
|
|
Action: action,
|
|
Effect: effect,
|
|
SubjectKind: subjectKind,
|
|
SourceScope: sourceScopeFromPg(sourceChannel, sourceConversationType, sourceConversationID, sourceThreadID),
|
|
CreatedAt: timeFromPg(createdAt),
|
|
UpdatedAt: timeFromPg(updatedAt),
|
|
}
|
|
if userID.Valid {
|
|
rule.UserID = uuid.UUID(userID.Bytes).String()
|
|
}
|
|
if channelIdentityID.Valid {
|
|
rule.ChannelIdentityID = uuid.UUID(channelIdentityID.Bytes).String()
|
|
}
|
|
return rule
|
|
}
|
|
|
|
func timeFromPg(value pgtype.Timestamptz) time.Time {
|
|
if value.Valid {
|
|
return value.Time
|
|
}
|
|
return time.Time{}
|
|
}
|