mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
0e646625bf
Allow users to configure what percentage of older messages to compact, keeping the most recent portion intact. Default ratio is 80%, meaning the oldest 80% of uncompacted messages are summarized while the newest 20% remain as-is for full-fidelity context.
396 lines
12 KiB
Go
396 lines
12 KiB
Go
package acl
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/pgconn"
|
|
"github.com/jackc/pgx/v5/pgtype"
|
|
|
|
"github.com/memohai/memoh/internal/bots"
|
|
"github.com/memohai/memoh/internal/db/sqlc"
|
|
)
|
|
|
|
// ---- fake DB infrastructure ----
|
|
|
|
type fakeDBTX struct {
|
|
queryRowFunc func(ctx context.Context, sql string, args ...any) pgx.Row
|
|
queryFunc func(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
|
|
execFunc func(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error)
|
|
}
|
|
|
|
func (f *fakeDBTX) Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) {
|
|
if f.execFunc != nil {
|
|
return f.execFunc(ctx, sql, args...)
|
|
}
|
|
return pgconn.CommandTag{}, nil
|
|
}
|
|
|
|
func (f *fakeDBTX) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) {
|
|
if f.queryFunc != nil {
|
|
return f.queryFunc(ctx, sql, args...)
|
|
}
|
|
return &fakeRows{}, nil
|
|
}
|
|
|
|
func (f *fakeDBTX) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row {
|
|
if f.queryRowFunc != nil {
|
|
return f.queryRowFunc(ctx, sql, args...)
|
|
}
|
|
return &fakeRow{scanFunc: func(_ ...any) error { return pgx.ErrNoRows }}
|
|
}
|
|
|
|
type fakeRow struct {
|
|
scanFunc func(dest ...any) error
|
|
}
|
|
|
|
func (r *fakeRow) Scan(dest ...any) error {
|
|
if r.scanFunc == nil {
|
|
return pgx.ErrNoRows
|
|
}
|
|
return r.scanFunc(dest...)
|
|
}
|
|
|
|
type fakeRows struct {
|
|
rows []func(dest ...any) error
|
|
idx int
|
|
lastErr error
|
|
}
|
|
|
|
func (*fakeRows) Close() {}
|
|
func (r *fakeRows) Err() error { return r.lastErr }
|
|
func (*fakeRows) CommandTag() pgconn.CommandTag { return pgconn.CommandTag{} }
|
|
func (*fakeRows) FieldDescriptions() []pgconn.FieldDescription { return nil }
|
|
func (r *fakeRows) Next() bool {
|
|
if r.idx >= len(r.rows) {
|
|
return false
|
|
}
|
|
r.idx++
|
|
return true
|
|
}
|
|
|
|
func (r *fakeRows) Scan(dest ...any) error {
|
|
if r.idx == 0 || r.idx > len(r.rows) {
|
|
return errors.New("scan called without next")
|
|
}
|
|
scan := r.rows[r.idx-1]
|
|
if scan == nil {
|
|
return nil
|
|
}
|
|
return scan(dest...)
|
|
}
|
|
func (*fakeRows) Values() ([]any, error) { return nil, nil }
|
|
func (*fakeRows) RawValues() [][]byte { return nil }
|
|
func (*fakeRows) Conn() *pgx.Conn { return nil }
|
|
|
|
// ---- helpers ----
|
|
|
|
func makeBotRow(botID, ownerUserID pgtype.UUID) *fakeRow {
|
|
return &fakeRow{
|
|
scanFunc: func(dest ...any) error {
|
|
if len(dest) < 22 {
|
|
return pgx.ErrNoRows
|
|
}
|
|
*dest[0].(*pgtype.UUID) = botID
|
|
*dest[1].(*pgtype.UUID) = ownerUserID
|
|
*dest[2].(*pgtype.Text) = pgtype.Text{String: "bot", Valid: true}
|
|
*dest[3].(*pgtype.Text) = pgtype.Text{}
|
|
*dest[4].(*pgtype.Text) = pgtype.Text{}
|
|
*dest[5].(*bool) = true
|
|
*dest[6].(*string) = bots.BotStatusReady
|
|
*dest[7].(*string) = "" // Language
|
|
*dest[8].(*bool) = false // ReasoningEnabled
|
|
*dest[9].(*string) = "medium" // ReasoningEffort
|
|
*dest[10].(*pgtype.UUID) = pgtype.UUID{} // ChatModelID
|
|
*dest[11].(*pgtype.UUID) = pgtype.UUID{} // SearchProviderID
|
|
*dest[12].(*pgtype.UUID) = pgtype.UUID{} // MemoryProviderID
|
|
*dest[13].(*bool) = false // HeartbeatEnabled
|
|
*dest[14].(*int32) = 30 // HeartbeatInterval
|
|
*dest[15].(*string) = "" // HeartbeatPrompt
|
|
*dest[16].(*bool) = false // CompactionEnabled
|
|
*dest[17].(*int32) = 100000 // CompactionThreshold
|
|
*dest[18].(*int32) = 80 // CompactionRatio
|
|
*dest[19].(*pgtype.UUID) = pgtype.UUID{} // CompactionModelID
|
|
*dest[20].(*[]byte) = []byte(`{}`)
|
|
*dest[21].(*pgtype.Timestamptz) = pgtype.Timestamptz{}
|
|
*dest[22].(*pgtype.Timestamptz) = pgtype.Timestamptz{}
|
|
return nil
|
|
},
|
|
}
|
|
}
|
|
|
|
func makeStringRow(value string) *fakeRow {
|
|
return &fakeRow{
|
|
scanFunc: func(dest ...any) error {
|
|
*dest[0].(*string) = value
|
|
return nil
|
|
},
|
|
}
|
|
}
|
|
|
|
func textFromArg(value any) string {
|
|
switch v := value.(type) {
|
|
case pgtype.Text:
|
|
return strings.TrimSpace(v.String)
|
|
case *pgtype.Text:
|
|
if v == nil {
|
|
return ""
|
|
}
|
|
return strings.TrimSpace(v.String)
|
|
case string:
|
|
return strings.TrimSpace(v)
|
|
default:
|
|
return ""
|
|
}
|
|
}
|
|
|
|
// matchedRule returns a fakeRow that scans the given effect string.
|
|
func matchedRule(effect string) *fakeRow {
|
|
return makeStringRow(effect)
|
|
}
|
|
|
|
// noRule returns a fakeRow that returns pgx.ErrNoRows (no matching rule).
|
|
func noRule() *fakeRow {
|
|
return &fakeRow{scanFunc: func(_ ...any) error { return pgx.ErrNoRows }}
|
|
}
|
|
|
|
// ---- Evaluate tests ----
|
|
|
|
func TestEvaluate(t *testing.T) {
|
|
botUUID := pgtype.UUID{Bytes: uuid.MustParse("11111111-1111-1111-1111-111111111111"), Valid: true}
|
|
ownerUUID := pgtype.UUID{Bytes: uuid.MustParse("22222222-2222-2222-2222-222222222222"), Valid: true}
|
|
|
|
tests := []struct {
|
|
name string
|
|
matchedEffect string // "" means no matching rule
|
|
defaultEffect string
|
|
wantAllowed bool
|
|
}{
|
|
{
|
|
name: "first rule allow",
|
|
matchedEffect: EffectAllow,
|
|
defaultEffect: EffectDeny,
|
|
wantAllowed: true,
|
|
},
|
|
{
|
|
name: "first rule deny",
|
|
matchedEffect: EffectDeny,
|
|
defaultEffect: EffectAllow,
|
|
wantAllowed: false,
|
|
},
|
|
{
|
|
name: "no matching rule - default allow",
|
|
matchedEffect: "",
|
|
defaultEffect: EffectAllow,
|
|
wantAllowed: true,
|
|
},
|
|
{
|
|
name: "no matching rule - default deny",
|
|
matchedEffect: "",
|
|
defaultEffect: EffectDeny,
|
|
wantAllowed: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
db := &fakeDBTX{
|
|
queryRowFunc: func(_ context.Context, sql string, _ ...any) pgx.Row {
|
|
switch {
|
|
case strings.Contains(sql, "FROM bots") && strings.Contains(sql, "owner_user_id"):
|
|
return makeBotRow(botUUID, ownerUUID)
|
|
case strings.Contains(sql, "FROM bot_acl_rules") && strings.Contains(sql, "LIMIT 1"):
|
|
// Evaluate query
|
|
if tt.matchedEffect == "" {
|
|
return noRule()
|
|
}
|
|
return matchedRule(tt.matchedEffect)
|
|
case strings.Contains(sql, "acl_default_effect"):
|
|
return makeStringRow(tt.defaultEffect)
|
|
default:
|
|
return noRule()
|
|
}
|
|
},
|
|
}
|
|
queries := sqlc.New(db)
|
|
botService := bots.NewService(nil, queries)
|
|
service := NewService(nil, queries, botService)
|
|
|
|
allowed, err := service.Evaluate(context.Background(), EvaluateRequest{
|
|
BotID: botUUID.String(),
|
|
ChannelIdentityID: "55555555-5555-5555-5555-555555555555",
|
|
ChannelType: "telegram",
|
|
SourceScope: SourceScope{
|
|
ConversationType: "group",
|
|
ConversationID: "group-1",
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if allowed != tt.wantAllowed {
|
|
t.Fatalf("expected allowed=%v, got %v", tt.wantAllowed, allowed)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestEvaluateRejectsInvalidScope(t *testing.T) {
|
|
service := NewService(nil, nil, nil)
|
|
_, err := service.Evaluate(context.Background(), EvaluateRequest{
|
|
BotID: "11111111-1111-1111-1111-111111111111",
|
|
SourceScope: SourceScope{
|
|
ThreadID: "thread-1",
|
|
// missing ConversationID - invalid
|
|
},
|
|
})
|
|
if !errors.Is(err, ErrInvalidSourceScope) {
|
|
t.Fatalf("expected ErrInvalidSourceScope, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestValidateSubject(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
kind string
|
|
channelIdentityID string
|
|
subjectChannelType string
|
|
wantErr bool
|
|
}{
|
|
{"all - no fields", SubjectKindAll, "", "", false},
|
|
{"all - with identity", SubjectKindAll, "some-id", "", true},
|
|
{"all - with channel type", SubjectKindAll, "", "telegram", true},
|
|
{"channel_identity - valid", SubjectKindChannelIdentity, "some-id", "", false},
|
|
{"channel_identity - missing id", SubjectKindChannelIdentity, "", "", true},
|
|
{"channel_identity - extra channel type", SubjectKindChannelIdentity, "some-id", "telegram", true},
|
|
{"channel_type - valid", SubjectKindChannelType, "", "telegram", false},
|
|
{"channel_type - missing channel type", SubjectKindChannelType, "", "", true},
|
|
{"channel_type - extra identity", SubjectKindChannelType, "some-id", "telegram", true},
|
|
{"unknown kind", "unknown", "", "", true},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
err := validateSubject(tt.kind, tt.channelIdentityID, tt.subjectChannelType)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Fatalf("validateSubject() error = %v, wantErr = %v", err, tt.wantErr)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestValidateEffect(t *testing.T) {
|
|
if err := validateEffect(EffectAllow); err != nil {
|
|
t.Fatalf("allow should be valid: %v", err)
|
|
}
|
|
if err := validateEffect(EffectDeny); err != nil {
|
|
t.Fatalf("deny should be valid: %v", err)
|
|
}
|
|
if err := validateEffect("unknown"); err == nil {
|
|
t.Fatal("expected error for unknown effect")
|
|
}
|
|
}
|
|
|
|
func TestSetDefaultEffect(t *testing.T) {
|
|
botUUID := pgtype.UUID{Bytes: uuid.MustParse("11111111-1111-1111-1111-111111111111"), Valid: true}
|
|
var capturedEffect string
|
|
db := &fakeDBTX{
|
|
execFunc: func(_ context.Context, sql string, args ...any) (pgconn.CommandTag, error) {
|
|
if strings.Contains(sql, "acl_default_effect") {
|
|
capturedEffect = args[1].(string)
|
|
}
|
|
return pgconn.CommandTag{}, nil
|
|
},
|
|
}
|
|
service := NewService(nil, sqlc.New(db), nil)
|
|
if err := service.SetDefaultEffect(context.Background(), botUUID.String(), EffectAllow); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if capturedEffect != EffectAllow {
|
|
t.Fatalf("expected effect %q, got %q", EffectAllow, capturedEffect)
|
|
}
|
|
if err := service.SetDefaultEffect(context.Background(), botUUID.String(), "invalid"); !errors.Is(err, ErrInvalidEffect) {
|
|
t.Fatalf("expected ErrInvalidEffect, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestListObservedConversationsByChannelIdentity(t *testing.T) {
|
|
botUUID := pgtype.UUID{Bytes: uuid.MustParse("11111111-1111-1111-1111-111111111111"), Valid: true}
|
|
channelIdentityUUID := pgtype.UUID{Bytes: uuid.MustParse("55555555-5555-5555-5555-555555555555"), Valid: true}
|
|
routeUUID := pgtype.UUID{Bytes: uuid.MustParse("66666666-6666-6666-6666-666666666666"), Valid: true}
|
|
now := time.Now().UTC()
|
|
|
|
db := &fakeDBTX{
|
|
queryFunc: func(_ context.Context, sql string, _ ...any) (pgx.Rows, error) {
|
|
if !strings.Contains(sql, "observed_routes") && !strings.Contains(sql, "bot_sessions") {
|
|
return &fakeRows{}, nil
|
|
}
|
|
return &fakeRows{
|
|
rows: []func(dest ...any) error{
|
|
func(dest ...any) error {
|
|
*dest[0].(*pgtype.UUID) = routeUUID
|
|
*dest[1].(*string) = "feishu"
|
|
*dest[2].(*string) = "group"
|
|
*dest[3].(*string) = "chat-1"
|
|
*dest[4].(*string) = "thread-1"
|
|
*dest[5].(*string) = "Team Chat"
|
|
*dest[6].(*pgtype.Timestamptz) = pgtype.Timestamptz{Time: now, Valid: true}
|
|
return nil
|
|
},
|
|
},
|
|
}, nil
|
|
},
|
|
}
|
|
|
|
service := NewService(nil, sqlc.New(db), nil)
|
|
items, err := service.ListObservedConversationsByChannelIdentity(context.Background(), botUUID.String(), channelIdentityUUID.String())
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if len(items) != 1 {
|
|
t.Fatalf("expected 1 item, got %d", len(items))
|
|
}
|
|
if items[0].RouteID != routeUUID.String() {
|
|
t.Fatalf("unexpected route id: %s", items[0].RouteID)
|
|
}
|
|
if items[0].ConversationID != "chat-1" || items[0].ThreadID != "thread-1" {
|
|
t.Fatalf("unexpected conversation scope: %+v", items[0])
|
|
}
|
|
}
|
|
|
|
func TestReorderRules(t *testing.T) {
|
|
ruleUUID := pgtype.UUID{Bytes: uuid.MustParse("77777777-7777-7777-7777-777777777777"), Valid: true}
|
|
var capturedPriority int32
|
|
db := &fakeDBTX{
|
|
execFunc: func(_ context.Context, sql string, args ...any) (pgconn.CommandTag, error) {
|
|
if strings.Contains(sql, "priority") {
|
|
capturedPriority = args[1].(int32)
|
|
}
|
|
return pgconn.CommandTag{}, nil
|
|
},
|
|
}
|
|
service := NewService(nil, sqlc.New(db), nil)
|
|
err := service.ReorderRules(context.Background(), []ReorderItem{
|
|
{ID: ruleUUID.String(), Priority: 42},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if capturedPriority != 42 {
|
|
t.Fatalf("expected priority 42, got %d", capturedPriority)
|
|
}
|
|
}
|
|
|
|
func TestTextFromArg(t *testing.T) {
|
|
if got := textFromArg(pgtype.Text{String: " hello ", Valid: true}); got != "hello" {
|
|
t.Fatalf("unexpected: %q", got)
|
|
}
|
|
if got := textFromArg("world"); got != "world" {
|
|
t.Fatalf("unexpected: %q", got)
|
|
}
|
|
}
|