Files
Memoh/internal/acl/service_test.go
T
Acbox 0e646625bf feat: add compaction ratio setting to control partial context compaction
Allow users to configure what percentage of older messages to compact,
keeping the most recent portion intact. Default ratio is 80%, meaning
the oldest 80% of uncompacted messages are summarized while the newest
20% remain as-is for full-fidelity context.
2026-03-29 19:14:43 +08:00

396 lines
12 KiB
Go

package acl
import (
"context"
"errors"
"strings"
"testing"
"time"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgtype"
"github.com/memohai/memoh/internal/bots"
"github.com/memohai/memoh/internal/db/sqlc"
)
// ---- fake DB infrastructure ----
// fakeDBTX is an in-memory stand-in for the database handle passed to
// sqlc.New. Each method forwards to its optional hook; with no hook set it
// returns a harmless default so a test only scripts the calls it cares about.
type fakeDBTX struct {
	queryRowFunc func(ctx context.Context, sql string, args ...any) pgx.Row
	queryFunc    func(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
	execFunc     func(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error)
}

// Exec runs execFunc when it is set; otherwise it reports an empty command
// tag and no error.
func (f *fakeDBTX) Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) {
	if hook := f.execFunc; hook != nil {
		return hook(ctx, sql, args...)
	}
	return pgconn.CommandTag{}, nil
}

// Query runs queryFunc when it is set; otherwise it yields an empty result set.
func (f *fakeDBTX) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) {
	if hook := f.queryFunc; hook != nil {
		return hook(ctx, sql, args...)
	}
	return &fakeRows{}, nil
}

// QueryRow runs queryRowFunc when it is set; otherwise it yields a row whose
// Scan reports pgx.ErrNoRows.
func (f *fakeDBTX) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row {
	if hook := f.queryRowFunc; hook != nil {
		return hook(ctx, sql, args...)
	}
	notFound := func(...any) error { return pgx.ErrNoRows }
	return &fakeRow{scanFunc: notFound}
}
// fakeRow is a single-row query result whose Scan behavior is supplied by the
// test. A zero-value fakeRow behaves like a query that matched nothing.
type fakeRow struct {
	scanFunc func(dest ...any) error
}

// Scan delegates to scanFunc, falling back to pgx.ErrNoRows when it is unset.
func (r *fakeRow) Scan(dest ...any) error {
	if fn := r.scanFunc; fn != nil {
		return fn(dest...)
	}
	return pgx.ErrNoRows
}
// fakeRows is a scripted multi-row result set. Each entry in rows fills the
// scan destinations for one row; a nil entry scans as a no-op. idx counts how
// many rows Next has advanced over, and lastErr is what Err reports.
type fakeRows struct {
	rows    []func(dest ...any) error
	idx     int
	lastErr error
}

// Close is a no-op; there is no underlying connection to release.
func (*fakeRows) Close() {}

// Err reports the scripted lastErr.
func (r *fakeRows) Err() error { return r.lastErr }

// CommandTag always reports an empty tag.
func (*fakeRows) CommandTag() pgconn.CommandTag { return pgconn.CommandTag{} }

// FieldDescriptions reports no column metadata.
func (*fakeRows) FieldDescriptions() []pgconn.FieldDescription { return nil }

// Next advances to the following scripted row, returning false once all rows
// have been visited.
func (r *fakeRows) Next() bool {
	if r.idx < len(r.rows) {
		r.idx++
		return true
	}
	return false
}

// Scan fills dest using the scanner of the row Next last advanced to. It
// fails if Next has not yet been called.
func (r *fakeRows) Scan(dest ...any) error {
	if r.idx == 0 || r.idx > len(r.rows) {
		return errors.New("scan called without next")
	}
	if scan := r.rows[r.idx-1]; scan != nil {
		return scan(dest...)
	}
	return nil
}

// Values reports no decoded values.
func (*fakeRows) Values() ([]any, error) { return nil, nil }

// RawValues reports no raw column bytes.
func (*fakeRows) RawValues() [][]byte { return nil }

// Conn reports no backing connection.
func (*fakeRows) Conn() *pgx.Conn { return nil }
// ---- helpers ----
// makeBotRow returns a fakeRow that scans a single bot record with the given
// botID and ownerUserID, status bots.BotStatusReady, and default settings.
//
// Destinations are written positionally, so the index order below must match
// the column order of the sqlc-generated bot query. A destination of the wrong
// type panics on its type assertion. If fewer than botRowColumns destinations
// are supplied, Scan reports pgx.ErrNoRows instead of indexing out of range.
func makeBotRow(botID, ownerUserID pgtype.UUID) *fakeRow {
	// botRowColumns is the number of destinations written below (indices 0-22).
	const botRowColumns = 23

	return &fakeRow{
		scanFunc: func(dest ...any) error {
			// The highest index written is 22, so 23 destinations are needed;
			// a `< 22` check would let a 22-element slice through and then
			// panic on dest[22].
			if len(dest) < botRowColumns {
				return pgx.ErrNoRows
			}

			// Identity and status columns. Indices 2-5 are presumably display
			// name / text fields and an enabled flag; TODO confirm against the
			// sqlc row struct.
			*dest[0].(*pgtype.UUID) = botID
			*dest[1].(*pgtype.UUID) = ownerUserID
			*dest[2].(*pgtype.Text) = pgtype.Text{String: "bot", Valid: true}
			*dest[3].(*pgtype.Text) = pgtype.Text{}
			*dest[4].(*pgtype.Text) = pgtype.Text{}
			*dest[5].(*bool) = true
			*dest[6].(*string) = bots.BotStatusReady

			// Settings columns.
			*dest[7].(*string) = ""                  // Language
			*dest[8].(*bool) = false                 // ReasoningEnabled
			*dest[9].(*string) = "medium"            // ReasoningEffort
			*dest[10].(*pgtype.UUID) = pgtype.UUID{} // ChatModelID
			*dest[11].(*pgtype.UUID) = pgtype.UUID{} // SearchProviderID
			*dest[12].(*pgtype.UUID) = pgtype.UUID{} // MemoryProviderID
			*dest[13].(*bool) = false                // HeartbeatEnabled
			*dest[14].(*int32) = 30                  // HeartbeatInterval
			*dest[15].(*string) = ""                 // HeartbeatPrompt
			*dest[16].(*bool) = false                // CompactionEnabled
			*dest[17].(*int32) = 100000              // CompactionThreshold
			*dest[18].(*int32) = 80                  // CompactionRatio
			*dest[19].(*pgtype.UUID) = pgtype.UUID{} // CompactionModelID

			// Trailing columns: a JSON document and two timestamps (presumably
			// metadata, created_at and updated_at; TODO confirm).
			*dest[20].(*[]byte) = []byte(`{}`)
			*dest[21].(*pgtype.Timestamptz) = pgtype.Timestamptz{}
			*dest[22].(*pgtype.Timestamptz) = pgtype.Timestamptz{}
			return nil
		},
	}
}
// makeStringRow returns a fakeRow whose Scan stores value into the first
// destination, which must be a *string.
func makeStringRow(value string) *fakeRow {
	scan := func(dest ...any) error {
		target := dest[0].(*string)
		*target = value
		return nil
	}
	return &fakeRow{scanFunc: scan}
}
// textFromArg extracts a whitespace-trimmed string from a query argument that
// may be a pgtype.Text (by value or by pointer) or a plain string. The Valid
// flag of pgtype.Text is not consulted. A nil *pgtype.Text or any other type
// yields "".
func textFromArg(value any) string {
	var raw string
	switch v := value.(type) {
	case pgtype.Text:
		raw = v.String
	case *pgtype.Text:
		if v != nil {
			raw = v.String
		}
	case string:
		raw = v
	}
	return strings.TrimSpace(raw)
}
// matchedRule simulates an ACL rule lookup that found a rule: scanning the
// returned row writes effect into the first (*string) destination.
func matchedRule(effect string) *fakeRow {
	return &fakeRow{scanFunc: func(dest ...any) error {
		*dest[0].(*string) = effect
		return nil
	}}
}
// noRule simulates an ACL rule lookup that found nothing: scanning the
// returned row always reports pgx.ErrNoRows.
func noRule() *fakeRow {
	noRows := func(...any) error { return pgx.ErrNoRows }
	return &fakeRow{scanFunc: noRows}
}
// ---- Evaluate tests ----
// TestEvaluate checks that Evaluate returns the effect of the matched ACL
// rule when one exists, and otherwise falls back to the bot's default effect.
func TestEvaluate(t *testing.T) {
	botUUID := pgtype.UUID{Bytes: uuid.MustParse("11111111-1111-1111-1111-111111111111"), Valid: true}
	ownerUUID := pgtype.UUID{Bytes: uuid.MustParse("22222222-2222-2222-2222-222222222222"), Valid: true}

	cases := []struct {
		name          string
		matchedEffect string // "" means no matching rule
		defaultEffect string
		wantAllowed   bool
	}{
		{
			name:          "first rule allow",
			matchedEffect: EffectAllow,
			defaultEffect: EffectDeny,
			wantAllowed:   true,
		},
		{
			name:          "first rule deny",
			matchedEffect: EffectDeny,
			defaultEffect: EffectAllow,
			wantAllowed:   false,
		},
		{
			name:          "no matching rule - default allow",
			matchedEffect: "",
			defaultEffect: EffectAllow,
			wantAllowed:   true,
		},
		{
			name:          "no matching rule - default deny",
			matchedEffect: "",
			defaultEffect: EffectDeny,
			wantAllowed:   false,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// ruleLookup answers the single-rule query used by Evaluate.
			ruleLookup := func() pgx.Row {
				if tc.matchedEffect == "" {
					return noRule()
				}
				return matchedRule(tc.matchedEffect)
			}

			// Route each single-row query by fragments of its SQL text. Only
			// the bot lookup, the rule lookup and the default-effect lookup
			// are answered; anything else reports no rows.
			db := &fakeDBTX{
				queryRowFunc: func(_ context.Context, sql string, _ ...any) pgx.Row {
					switch {
					case strings.Contains(sql, "FROM bots") && strings.Contains(sql, "owner_user_id"):
						return makeBotRow(botUUID, ownerUUID)
					case strings.Contains(sql, "FROM bot_acl_rules") && strings.Contains(sql, "LIMIT 1"):
						return ruleLookup()
					case strings.Contains(sql, "acl_default_effect"):
						return makeStringRow(tc.defaultEffect)
					}
					return noRule()
				},
			}

			queries := sqlc.New(db)
			svc := NewService(nil, queries, bots.NewService(nil, queries))

			req := EvaluateRequest{
				BotID:             botUUID.String(),
				ChannelIdentityID: "55555555-5555-5555-5555-555555555555",
				ChannelType:       "telegram",
				SourceScope: SourceScope{
					ConversationType: "group",
					ConversationID:   "group-1",
				},
			}
			got, err := svc.Evaluate(context.Background(), req)
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}
			if got != tc.wantAllowed {
				t.Fatalf("expected allowed=%v, got %v", tc.wantAllowed, got)
			}
		})
	}
}
func TestEvaluateRejectsInvalidScope(t *testing.T) {
service := NewService(nil, nil, nil)
_, err := service.Evaluate(context.Background(), EvaluateRequest{
BotID: "11111111-1111-1111-1111-111111111111",
SourceScope: SourceScope{
ThreadID: "thread-1",
// missing ConversationID - invalid
},
})
if !errors.Is(err, ErrInvalidSourceScope) {
t.Fatalf("expected ErrInvalidSourceScope, got %v", err)
}
}
// TestValidateSubject checks which field combinations validateSubject accepts
// for each subject kind. Each kind requires its own field and forbids the
// other, and an unknown kind is rejected.
func TestValidateSubject(t *testing.T) {
	cases := []struct {
		name, kind, channelIdentityID, subjectChannelType string
		wantErr                                           bool
	}{
		{"all - no fields", SubjectKindAll, "", "", false},
		{"all - with identity", SubjectKindAll, "some-id", "", true},
		{"all - with channel type", SubjectKindAll, "", "telegram", true},
		{"channel_identity - valid", SubjectKindChannelIdentity, "some-id", "", false},
		{"channel_identity - missing id", SubjectKindChannelIdentity, "", "", true},
		{"channel_identity - extra channel type", SubjectKindChannelIdentity, "some-id", "telegram", true},
		{"channel_type - valid", SubjectKindChannelType, "", "telegram", false},
		{"channel_type - missing channel type", SubjectKindChannelType, "", "", true},
		{"channel_type - extra identity", SubjectKindChannelType, "some-id", "telegram", true},
		{"unknown kind", "unknown", "", "", true},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			err := validateSubject(c.kind, c.channelIdentityID, c.subjectChannelType)
			if gotErr := err != nil; gotErr != c.wantErr {
				t.Fatalf("validateSubject() error = %v, wantErr = %v", err, c.wantErr)
			}
		})
	}
}
// TestValidateEffect checks that validateEffect accepts EffectAllow and
// EffectDeny and rejects an unknown value.
func TestValidateEffect(t *testing.T) {
	err := validateEffect(EffectAllow)
	if err != nil {
		t.Fatalf("allow should be valid: %v", err)
	}

	err = validateEffect(EffectDeny)
	if err != nil {
		t.Fatalf("deny should be valid: %v", err)
	}

	if validateEffect("unknown") == nil {
		t.Fatal("expected error for unknown effect")
	}
}
// TestSetDefaultEffect checks that a valid effect reaches the database
// statement touching acl_default_effect, and that an unknown effect is
// rejected with ErrInvalidEffect.
func TestSetDefaultEffect(t *testing.T) {
	botUUID := pgtype.UUID{Bytes: uuid.MustParse("11111111-1111-1111-1111-111111111111"), Valid: true}
	botID := botUUID.String()
	ctx := context.Background()

	var written string
	recordEffect := func(_ context.Context, sql string, args ...any) (pgconn.CommandTag, error) {
		// The effect is passed as the second bind argument (args[1]).
		if strings.Contains(sql, "acl_default_effect") {
			written = args[1].(string)
		}
		return pgconn.CommandTag{}, nil
	}
	svc := NewService(nil, sqlc.New(&fakeDBTX{execFunc: recordEffect}), nil)

	if err := svc.SetDefaultEffect(ctx, botID, EffectAllow); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if written != EffectAllow {
		t.Fatalf("expected effect %q, got %q", EffectAllow, written)
	}

	err := svc.SetDefaultEffect(ctx, botID, "invalid")
	if !errors.Is(err, ErrInvalidEffect) {
		t.Fatalf("expected ErrInvalidEffect, got %v", err)
	}
}
// TestListObservedConversationsByChannelIdentity feeds one scripted row
// through the fake DB. It checks that the row comes back as a single item
// with its route ID, conversation ID and thread ID intact.
func TestListObservedConversationsByChannelIdentity(t *testing.T) {
	botUUID := pgtype.UUID{Bytes: uuid.MustParse("11111111-1111-1111-1111-111111111111"), Valid: true}
	channelIdentityUUID := pgtype.UUID{Bytes: uuid.MustParse("55555555-5555-5555-5555-555555555555"), Valid: true}
	routeUUID := pgtype.UUID{Bytes: uuid.MustParse("66666666-6666-6666-6666-666666666666"), Valid: true}
	observedAt := pgtype.Timestamptz{Time: time.Now().UTC(), Valid: true}

	// scanRoute fills the seven destinations of one result row in the order
	// the listing query scans them.
	scanRoute := func(dest ...any) error {
		*dest[0].(*pgtype.UUID) = routeUUID
		*dest[1].(*string) = "feishu"
		*dest[2].(*string) = "group"
		*dest[3].(*string) = "chat-1"
		*dest[4].(*string) = "thread-1"
		*dest[5].(*string) = "Team Chat"
		*dest[6].(*pgtype.Timestamptz) = observedAt
		return nil
	}

	db := &fakeDBTX{
		queryFunc: func(_ context.Context, sql string, _ ...any) (pgx.Rows, error) {
			// Only the observed-conversation query receives the scripted row.
			if strings.Contains(sql, "observed_routes") || strings.Contains(sql, "bot_sessions") {
				return &fakeRows{rows: []func(dest ...any) error{scanRoute}}, nil
			}
			return &fakeRows{}, nil
		},
	}
	svc := NewService(nil, sqlc.New(db), nil)

	items, err := svc.ListObservedConversationsByChannelIdentity(context.Background(), botUUID.String(), channelIdentityUUID.String())
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if len(items) != 1 {
		t.Fatalf("expected 1 item, got %d", len(items))
	}

	first := items[0]
	if first.RouteID != routeUUID.String() {
		t.Fatalf("unexpected route id: %s", first.RouteID)
	}
	if first.ConversationID != "chat-1" || first.ThreadID != "thread-1" {
		t.Fatalf("unexpected conversation scope: %+v", first)
	}
}
// TestReorderRules checks that the requested priority reaches the database
// statement that mentions "priority".
func TestReorderRules(t *testing.T) {
	ruleUUID := pgtype.UUID{Bytes: uuid.MustParse("77777777-7777-7777-7777-777777777777"), Valid: true}

	var gotPriority int32
	recordPriority := func(_ context.Context, sql string, args ...any) (pgconn.CommandTag, error) {
		// The priority is passed as the second bind argument (args[1]).
		if strings.Contains(sql, "priority") {
			gotPriority = args[1].(int32)
		}
		return pgconn.CommandTag{}, nil
	}
	svc := NewService(nil, sqlc.New(&fakeDBTX{execFunc: recordPriority}), nil)

	items := []ReorderItem{{ID: ruleUUID.String(), Priority: 42}}
	if err := svc.ReorderRules(context.Background(), items); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if gotPriority != 42 {
		t.Fatalf("expected priority 42, got %d", gotPriority)
	}
}
func TestTextFromArg(t *testing.T) {
if got := textFromArg(pgtype.Text{String: " hello ", Valid: true}); got != "hello" {
t.Fatalf("unexpected: %q", got)
}
if got := textFromArg("world"); got != "world" {
t.Fatalf("unexpected: %q", got)
}
}