mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
390 lines
14 KiB
Go
390 lines
14 KiB
Go
package acl
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/pgconn"
|
|
"github.com/jackc/pgx/v5/pgtype"
|
|
|
|
"github.com/memohai/memoh/internal/bots"
|
|
"github.com/memohai/memoh/internal/db/sqlc"
|
|
)
|
|
|
|
type fakeDBTX struct {
|
|
queryRowFunc func(ctx context.Context, sql string, args ...any) pgx.Row
|
|
queryFunc func(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
|
|
execFunc func(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error)
|
|
}
|
|
|
|
func (f *fakeDBTX) Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) {
|
|
if f.execFunc != nil {
|
|
return f.execFunc(ctx, sql, args...)
|
|
}
|
|
return pgconn.CommandTag{}, nil
|
|
}
|
|
|
|
func (f *fakeDBTX) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) {
|
|
if f.queryFunc != nil {
|
|
return f.queryFunc(ctx, sql, args...)
|
|
}
|
|
return &fakeRows{}, nil
|
|
}
|
|
|
|
func (f *fakeDBTX) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row {
|
|
if f.queryRowFunc != nil {
|
|
return f.queryRowFunc(ctx, sql, args...)
|
|
}
|
|
return &fakeRow{scanFunc: func(_ ...any) error { return pgx.ErrNoRows }}
|
|
}
|
|
|
|
type fakeRow struct {
|
|
scanFunc func(dest ...any) error
|
|
}
|
|
|
|
func (r *fakeRow) Scan(dest ...any) error {
|
|
if r.scanFunc == nil {
|
|
return pgx.ErrNoRows
|
|
}
|
|
return r.scanFunc(dest...)
|
|
}
|
|
|
|
func makeBotRow(botID, ownerUserID pgtype.UUID) *fakeRow {
|
|
return &fakeRow{
|
|
scanFunc: func(dest ...any) error {
|
|
*dest[0].(*pgtype.UUID) = botID
|
|
*dest[1].(*pgtype.UUID) = ownerUserID
|
|
*dest[2].(*string) = bots.BotTypePublic
|
|
*dest[3].(*pgtype.Text) = pgtype.Text{String: "bot", Valid: true}
|
|
*dest[4].(*pgtype.Text) = pgtype.Text{}
|
|
*dest[5].(*bool) = true
|
|
*dest[6].(*string) = bots.BotStatusReady
|
|
*dest[7].(*int32) = 30
|
|
*dest[8].(*int32) = 0
|
|
*dest[9].(*int32) = 50
|
|
*dest[10].(*string) = "auto"
|
|
*dest[11].(*bool) = false
|
|
*dest[12].(*string) = "medium"
|
|
*dest[13].(*pgtype.UUID) = pgtype.UUID{}
|
|
*dest[14].(*pgtype.UUID) = pgtype.UUID{}
|
|
*dest[15].(*pgtype.UUID) = pgtype.UUID{}
|
|
*dest[16].(*bool) = false
|
|
*dest[17].(*int32) = 30
|
|
*dest[18].(*string) = ""
|
|
*dest[19].(*[]byte) = []byte(`{}`)
|
|
*dest[20].(*pgtype.Timestamptz) = pgtype.Timestamptz{}
|
|
*dest[21].(*pgtype.Timestamptz) = pgtype.Timestamptz{}
|
|
return nil
|
|
},
|
|
}
|
|
}
|
|
|
|
func makeBoolRow(value bool) *fakeRow {
|
|
return &fakeRow{
|
|
scanFunc: func(dest ...any) error {
|
|
*dest[0].(*bool) = value
|
|
return nil
|
|
},
|
|
}
|
|
}
|
|
|
|
type fakeRows struct {
|
|
rows []func(dest ...any) error
|
|
idx int
|
|
lastErr error
|
|
}
|
|
|
|
func (*fakeRows) Close() {}
|
|
func (r *fakeRows) Err() error { return r.lastErr }
|
|
func (*fakeRows) CommandTag() pgconn.CommandTag { return pgconn.CommandTag{} }
|
|
func (*fakeRows) FieldDescriptions() []pgconn.FieldDescription { return nil }
|
|
func (r *fakeRows) Next() bool {
|
|
if r.idx >= len(r.rows) {
|
|
return false
|
|
}
|
|
r.idx++
|
|
return true
|
|
}
|
|
|
|
func (r *fakeRows) Scan(dest ...any) error {
|
|
if r.idx == 0 || r.idx > len(r.rows) {
|
|
return errors.New("scan called without next")
|
|
}
|
|
scan := r.rows[r.idx-1]
|
|
if scan == nil {
|
|
return nil
|
|
}
|
|
return scan(dest...)
|
|
}
|
|
func (*fakeRows) Values() ([]any, error) { return nil, nil }
|
|
func (*fakeRows) RawValues() [][]byte { return nil }
|
|
func (*fakeRows) Conn() *pgx.Conn { return nil }
|
|
|
|
func textFromArg(value any) string {
|
|
switch v := value.(type) {
|
|
case pgtype.Text:
|
|
return strings.TrimSpace(v.String)
|
|
case *pgtype.Text:
|
|
if v == nil {
|
|
return ""
|
|
}
|
|
return strings.TrimSpace(v.String)
|
|
case string:
|
|
return strings.TrimSpace(v)
|
|
default:
|
|
return ""
|
|
}
|
|
}
|
|
|
|
func scopeMatches(rule *SourceScope, args ...any) bool {
|
|
if rule == nil {
|
|
return false
|
|
}
|
|
scope := rule.Normalize()
|
|
return (scope.Channel == "" || scope.Channel == textFromArg(args[3])) &&
|
|
(scope.ConversationType == "" || scope.ConversationType == textFromArg(args[4])) &&
|
|
(scope.ConversationID == "" || scope.ConversationID == textFromArg(args[5])) &&
|
|
(scope.ThreadID == "" || scope.ThreadID == textFromArg(args[6]))
|
|
}
|
|
|
|
func TestCanPerformChatTrigger(t *testing.T) {
|
|
botUUID := pgtype.UUID{Bytes: uuid.MustParse("11111111-1111-1111-1111-111111111111"), Valid: true}
|
|
ownerUUID := pgtype.UUID{Bytes: uuid.MustParse("22222222-2222-2222-2222-222222222222"), Valid: true}
|
|
userUUID := pgtype.UUID{Bytes: uuid.MustParse("44444444-4444-4444-4444-444444444444"), Valid: true}
|
|
channelIdentityUUID := pgtype.UUID{Bytes: uuid.MustParse("55555555-5555-5555-5555-555555555555"), Valid: true}
|
|
|
|
tests := []struct {
|
|
name string
|
|
userID string
|
|
channelIdentityID string
|
|
sourceScope SourceScope
|
|
denyUserScope *SourceScope
|
|
allowUserScope *SourceScope
|
|
denyChannelScope *SourceScope
|
|
allowChannelScope *SourceScope
|
|
allowGuestAll bool
|
|
wantAllowed bool
|
|
}{
|
|
{name: "owner bypass", userID: ownerUUID.String(), wantAllowed: true},
|
|
{name: "deny user wins", userID: userUUID.String(), denyUserScope: &SourceScope{}, allowGuestAll: true, wantAllowed: false},
|
|
{name: "allow user", userID: userUUID.String(), allowUserScope: &SourceScope{}, wantAllowed: true},
|
|
{name: "deny channel wins", channelIdentityID: channelIdentityUUID.String(), denyChannelScope: &SourceScope{}, allowGuestAll: true, wantAllowed: false},
|
|
{name: "allow channel identity", channelIdentityID: channelIdentityUUID.String(), allowChannelScope: &SourceScope{}, wantAllowed: true},
|
|
{
|
|
name: "scoped allow user private",
|
|
userID: userUUID.String(),
|
|
sourceScope: SourceScope{Channel: "feishu", ConversationType: "private", ConversationID: "chat-1"},
|
|
allowUserScope: &SourceScope{Channel: "feishu", ConversationType: "private", ConversationID: "chat-1"},
|
|
wantAllowed: true,
|
|
},
|
|
{
|
|
name: "scoped allow user does not match other conversation",
|
|
userID: userUUID.String(),
|
|
sourceScope: SourceScope{Channel: "feishu", ConversationType: "private", ConversationID: "chat-2"},
|
|
allowUserScope: &SourceScope{Channel: "feishu", ConversationType: "private", ConversationID: "chat-1"},
|
|
wantAllowed: false,
|
|
},
|
|
{
|
|
name: "scoped deny overrides guest fallback",
|
|
channelIdentityID: channelIdentityUUID.String(),
|
|
sourceScope: SourceScope{Channel: "telegram", ConversationType: "group", ConversationID: "group-1"},
|
|
denyChannelScope: &SourceScope{Channel: "telegram", ConversationType: "group", ConversationID: "group-1"},
|
|
allowGuestAll: true,
|
|
wantAllowed: false,
|
|
},
|
|
{
|
|
name: "scoped deny does not block different source",
|
|
channelIdentityID: channelIdentityUUID.String(),
|
|
sourceScope: SourceScope{Channel: "telegram", ConversationType: "group", ConversationID: "group-2"},
|
|
denyChannelScope: &SourceScope{Channel: "telegram", ConversationType: "group", ConversationID: "group-1"},
|
|
allowGuestAll: true,
|
|
wantAllowed: true,
|
|
},
|
|
{name: "guest_all fallback", allowGuestAll: true, wantAllowed: true},
|
|
{name: "default deny", wantAllowed: false},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
db := &fakeDBTX{
|
|
queryRowFunc: func(_ context.Context, sql string, args ...any) pgx.Row {
|
|
switch {
|
|
case strings.Contains(sql, "FROM bots"):
|
|
return makeBotRow(botUUID, ownerUUID)
|
|
case strings.Contains(sql, "subject_kind = 'user'"):
|
|
effect := args[1].(string)
|
|
if effect == EffectDeny {
|
|
return makeBoolRow(scopeMatches(tt.denyUserScope, args...))
|
|
}
|
|
return makeBoolRow(scopeMatches(tt.allowUserScope, args...))
|
|
case strings.Contains(sql, "subject_kind = 'channel_identity'"):
|
|
effect := args[1].(string)
|
|
if effect == EffectDeny {
|
|
return makeBoolRow(scopeMatches(tt.denyChannelScope, args...))
|
|
}
|
|
return makeBoolRow(scopeMatches(tt.allowChannelScope, args...))
|
|
case strings.Contains(sql, "subject_kind = 'guest_all'"):
|
|
return makeBoolRow(tt.allowGuestAll)
|
|
default:
|
|
return &fakeRow{scanFunc: func(_ ...any) error { return pgx.ErrNoRows }}
|
|
}
|
|
},
|
|
}
|
|
|
|
queries := sqlc.New(db)
|
|
botService := bots.NewService(nil, queries)
|
|
service := NewService(nil, queries, botService)
|
|
allowed, err := service.CanPerformChatTrigger(context.Background(), ChatTriggerRequest{
|
|
BotID: botUUID.String(),
|
|
UserID: tt.userID,
|
|
ChannelIdentityID: tt.channelIdentityID,
|
|
SourceScope: tt.sourceScope,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if allowed != tt.wantAllowed {
|
|
t.Fatalf("expected allowed=%v, got %v", tt.wantAllowed, allowed)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCanPerformChatTriggerRejectsInvalidScope(t *testing.T) {
|
|
service := NewService(nil, nil, nil)
|
|
_, err := service.CanPerformChatTrigger(context.Background(), ChatTriggerRequest{
|
|
BotID: "bot-1",
|
|
SourceScope: SourceScope{
|
|
Channel: "feishu",
|
|
ThreadID: "thread-1",
|
|
},
|
|
})
|
|
if !errors.Is(err, ErrInvalidSourceScope) {
|
|
t.Fatalf("expected invalid source scope error, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestListObservedConversationsByChannelIdentity(t *testing.T) {
|
|
botUUID := pgtype.UUID{Bytes: uuid.MustParse("11111111-1111-1111-1111-111111111111"), Valid: true}
|
|
channelIdentityUUID := pgtype.UUID{Bytes: uuid.MustParse("55555555-5555-5555-5555-555555555555"), Valid: true}
|
|
routeUUID := pgtype.UUID{Bytes: uuid.MustParse("66666666-6666-6666-6666-666666666666"), Valid: true}
|
|
now := time.Now().UTC()
|
|
|
|
db := &fakeDBTX{
|
|
queryFunc: func(_ context.Context, sql string, _ ...any) (pgx.Rows, error) {
|
|
if !strings.Contains(sql, "ListObservedConversationsByChannelIdentity") &&
|
|
!strings.Contains(sql, "FROM bot_history_messages m") {
|
|
return &fakeRows{}, nil
|
|
}
|
|
return &fakeRows{
|
|
rows: []func(dest ...any) error{
|
|
func(dest ...any) error {
|
|
*dest[0].(*pgtype.UUID) = routeUUID
|
|
*dest[1].(*string) = "feishu"
|
|
*dest[2].(*string) = "group"
|
|
*dest[3].(*string) = "chat-1"
|
|
*dest[4].(*string) = "thread-1"
|
|
*dest[5].(*string) = "Team Chat"
|
|
*dest[6].(*pgtype.Timestamptz) = pgtype.Timestamptz{Time: now, Valid: true}
|
|
return nil
|
|
},
|
|
},
|
|
}, nil
|
|
},
|
|
}
|
|
|
|
service := NewService(nil, sqlc.New(db), nil)
|
|
items, err := service.ListObservedConversationsByChannelIdentity(context.Background(), botUUID.String(), channelIdentityUUID.String())
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if len(items) != 1 {
|
|
t.Fatalf("expected one observed conversation, got %d", len(items))
|
|
}
|
|
if items[0].RouteID != routeUUID.String() {
|
|
t.Fatalf("unexpected route id: %s", items[0].RouteID)
|
|
}
|
|
if items[0].ConversationID != "chat-1" || items[0].ThreadID != "thread-1" {
|
|
t.Fatalf("unexpected conversation scope: %+v", items[0])
|
|
}
|
|
if items[0].ConversationName != "Team Chat" {
|
|
t.Fatalf("unexpected conversation name: %q", items[0].ConversationName)
|
|
}
|
|
}
|
|
|
|
func TestAddWhitelistEntryChannelIdentityForcesIdentityChannel(t *testing.T) {
|
|
botUUID := pgtype.UUID{Bytes: uuid.MustParse("11111111-1111-1111-1111-111111111111"), Valid: true}
|
|
ruleUUID := pgtype.UUID{Bytes: uuid.MustParse("77777777-7777-7777-7777-777777777777"), Valid: true}
|
|
channelIdentityUUID := pgtype.UUID{Bytes: uuid.MustParse("55555555-5555-5555-5555-555555555555"), Valid: true}
|
|
createdByUUID := pgtype.UUID{Bytes: uuid.MustParse("88888888-8888-8888-8888-888888888888"), Valid: true}
|
|
now := time.Now().UTC()
|
|
|
|
db := &fakeDBTX{
|
|
queryRowFunc: func(_ context.Context, sql string, args ...any) pgx.Row {
|
|
switch {
|
|
case strings.Contains(sql, "FROM channel_identities"):
|
|
return &fakeRow{
|
|
scanFunc: func(dest ...any) error {
|
|
*dest[0].(*pgtype.UUID) = channelIdentityUUID
|
|
*dest[1].(*pgtype.UUID) = pgtype.UUID{}
|
|
*dest[2].(*string) = "feishu"
|
|
*dest[3].(*string) = "ou_123"
|
|
*dest[4].(*pgtype.Text) = pgtype.Text{String: "Tester", Valid: true}
|
|
*dest[5].(*pgtype.Text) = pgtype.Text{}
|
|
*dest[6].(*[]byte) = []byte(`{}`)
|
|
*dest[7].(*pgtype.Timestamptz) = pgtype.Timestamptz{Time: now, Valid: true}
|
|
*dest[8].(*pgtype.Timestamptz) = pgtype.Timestamptz{Time: now, Valid: true}
|
|
return nil
|
|
},
|
|
}
|
|
case strings.Contains(sql, "INSERT INTO bot_acl_rules"):
|
|
if got := textFromArg(args[4]); got != "feishu" {
|
|
t.Fatalf("expected source_channel to be normalized to feishu, got %q", got)
|
|
}
|
|
return &fakeRow{
|
|
scanFunc: func(dest ...any) error {
|
|
*dest[0].(*pgtype.UUID) = ruleUUID
|
|
*dest[1].(*pgtype.UUID) = botUUID
|
|
*dest[2].(*string) = ActionChatTrigger
|
|
*dest[3].(*string) = EffectAllow
|
|
*dest[4].(*string) = SubjectKindChannelIdentity
|
|
*dest[5].(*pgtype.UUID) = pgtype.UUID{}
|
|
*dest[6].(*pgtype.UUID) = channelIdentityUUID
|
|
*dest[7].(*pgtype.Text) = pgtype.Text{String: "feishu", Valid: true}
|
|
*dest[8].(*pgtype.Text) = pgtype.Text{String: "group", Valid: true}
|
|
*dest[9].(*pgtype.Text) = pgtype.Text{String: "chat-1", Valid: true}
|
|
*dest[10].(*pgtype.Text) = pgtype.Text{}
|
|
*dest[11].(*pgtype.UUID) = createdByUUID
|
|
*dest[12].(*pgtype.Timestamptz) = pgtype.Timestamptz{Time: now, Valid: true}
|
|
*dest[13].(*pgtype.Timestamptz) = pgtype.Timestamptz{Time: now, Valid: true}
|
|
return nil
|
|
},
|
|
}
|
|
default:
|
|
return &fakeRow{scanFunc: func(_ ...any) error { return pgx.ErrNoRows }}
|
|
}
|
|
},
|
|
}
|
|
|
|
service := NewService(nil, sqlc.New(db), nil)
|
|
rule, err := service.AddWhitelistEntry(context.Background(), botUUID.String(), createdByUUID.String(), UpsertRuleRequest{
|
|
ChannelIdentityID: channelIdentityUUID.String(),
|
|
SourceScope: &SourceScope{
|
|
Channel: "telegram",
|
|
ConversationType: "group",
|
|
ConversationID: "chat-1",
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if rule.SourceScope == nil || rule.SourceScope.Channel != "feishu" {
|
|
t.Fatalf("expected normalized source scope channel feishu, got %+v", rule.SourceScope)
|
|
}
|
|
}
|