mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
refactor(core): finalize user-centric identity and policy cleanup
Unify auth and chat identity semantics around user_id, enforce personal-bot owner-only authorization, and remove legacy compatibility branches in integration tests.
This commit is contained in:
+4
-13
@@ -52,17 +52,9 @@ func UserIDFromContext(c echo.Context) (string, error) {
|
||||
if userID := claimString(claims, claimSubject); userID != "" {
|
||||
return userID, nil
|
||||
}
|
||||
if legacyChannelIdentityID := claimString(claims, claimChannelIdentityID); legacyChannelIdentityID != "" {
|
||||
return legacyChannelIdentityID, nil
|
||||
}
|
||||
return "", echo.NewHTTPError(http.StatusUnauthorized, "user id missing")
|
||||
}
|
||||
|
||||
// ChannelIdentityIDFromContext is kept as compatibility alias and returns user id.
|
||||
func ChannelIdentityIDFromContext(c echo.Context) (string, error) {
|
||||
return UserIDFromContext(c)
|
||||
}
|
||||
|
||||
// GenerateToken creates a signed JWT for the user.
|
||||
func GenerateToken(userID, secret string, expiresIn time.Duration) (string, time.Time, error) {
|
||||
if strings.TrimSpace(userID) == "" {
|
||||
@@ -78,11 +70,10 @@ func GenerateToken(userID, secret string, expiresIn time.Duration) (string, time
|
||||
now := time.Now().UTC()
|
||||
expiresAt := now.Add(expiresIn)
|
||||
claims := jwt.MapClaims{
|
||||
claimSubject: userID,
|
||||
claimUserID: userID,
|
||||
claimChannelIdentityID: userID, // legacy compatibility for handlers still reading channel_identity_id
|
||||
"iat": now.Unix(),
|
||||
"exp": expiresAt.Unix(),
|
||||
claimSubject: userID,
|
||||
claimUserID: userID,
|
||||
"iat": now.Unix(),
|
||||
"exp": expiresAt.Unix(),
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
signed, err := token.SignedString([]byte(secret))
|
||||
|
||||
@@ -15,8 +15,8 @@ import (
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/memohai/memoh/internal/channelidentities"
|
||||
"github.com/memohai/memoh/internal/bind"
|
||||
"github.com/memohai/memoh/internal/channelidentities"
|
||||
"github.com/memohai/memoh/internal/db"
|
||||
"github.com/memohai/memoh/internal/db/sqlc"
|
||||
)
|
||||
@@ -47,7 +47,7 @@ func setupBindIntegrationTest(t *testing.T) (*sqlc.Queries, *channelidentities.S
|
||||
return queries, channelIdentitySvc, bindSvc, func() { pool.Close() }
|
||||
}
|
||||
|
||||
func createUser(ctx context.Context, queries *sqlc.Queries) (string, error) {
|
||||
func createUserForBindTest(ctx context.Context, queries *sqlc.Queries) (string, error) {
|
||||
row, err := queries.CreateUser(ctx, sqlc.CreateUserParams{
|
||||
IsActive: true,
|
||||
Metadata: []byte("{}"),
|
||||
@@ -58,12 +58,15 @@ func createUser(ctx context.Context, queries *sqlc.Queries) (string, error) {
|
||||
return db.UUIDToString(row.ID), nil
|
||||
}
|
||||
|
||||
func createBot(ctx context.Context, queries *sqlc.Queries, ownerUserID string) (string, error) {
|
||||
func createBotForBindTest(ctx context.Context, queries *sqlc.Queries, ownerUserID string) (string, error) {
|
||||
pgOwnerID, err := db.ParseUUID(ownerUserID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
meta, _ := json.Marshal(map[string]any{"source": "bind-integration-test"})
|
||||
meta, err := json.Marshal(map[string]any{"source": "bind-integration-test"})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
row, err := queries.CreateBot(ctx, sqlc.CreateBotParams{
|
||||
OwnerUserID: pgOwnerID,
|
||||
Type: "personal",
|
||||
@@ -82,15 +85,15 @@ func TestIntegrationConsumeBindCodeSuccessAndSingleUse(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
ownerUserID, err := createUser(ctx, queries)
|
||||
ownerUserID, err := createUserForBindTest(ctx, queries)
|
||||
if err != nil {
|
||||
t.Fatalf("create owner user failed: %v", err)
|
||||
}
|
||||
sourceChannelIdentity, err := channelIdentitySvc.Create(ctx, channelidentities.KindChannel)
|
||||
if err != nil {
|
||||
t.Fatalf("create source channelIdentity failed: %v", err)
|
||||
t.Fatalf("create source channel identity failed: %v", err)
|
||||
}
|
||||
botID, err := createBot(ctx, queries, ownerUserID)
|
||||
botID, err := createBotForBindTest(ctx, queries, ownerUserID)
|
||||
if err != nil {
|
||||
t.Fatalf("create bot failed: %v", err)
|
||||
}
|
||||
@@ -127,177 +130,6 @@ func TestIntegrationConsumeBindCodeSuccessAndSingleUse(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegrationConsumeBindCodeRollbackOnLinkConflict(t *testing.T) {
|
||||
queries, channelIdentitySvc, bindSvc, cleanup := setupBindIntegrationTest(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
ownerUserID, err := createUser(ctx, queries)
|
||||
if err != nil {
|
||||
t.Fatalf("create owner user failed: %v", err)
|
||||
}
|
||||
otherUserID, err := createUser(ctx, queries)
|
||||
if err != nil {
|
||||
t.Fatalf("create other user failed: %v", err)
|
||||
}
|
||||
sourceChannelIdentity, err := channelIdentitySvc.Create(ctx, channelidentities.KindChannel)
|
||||
if err != nil {
|
||||
t.Fatalf("create source channelIdentity failed: %v", err)
|
||||
}
|
||||
if err := channelIdentitySvc.LinkChannelIdentityToUser(ctx, sourceChannelIdentity.ID, otherUserID); err != nil {
|
||||
t.Fatalf("pre-link source channelIdentity failed: %v", err)
|
||||
}
|
||||
botID, err := createBot(ctx, queries, ownerUserID)
|
||||
if err != nil {
|
||||
t.Fatalf("create bot failed: %v", err)
|
||||
}
|
||||
|
||||
code, err := bindSvc.Issue(ctx, botID, ownerUserID, 10*time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("issue bind code failed: %v", err)
|
||||
}
|
||||
if err := bindSvc.Consume(ctx, code, sourceChannelIdentity.ID); !errors.Is(err, bind.ErrLinkConflict) {
|
||||
t.Fatalf("expected ErrLinkConflict, got %v", err)
|
||||
}
|
||||
|
||||
after, err := bindSvc.Get(ctx, code.Token)
|
||||
if err != nil {
|
||||
t.Fatalf("get bind code failed: %v", err)
|
||||
}
|
||||
if !after.UsedAt.IsZero() {
|
||||
t.Fatal("expected used_at to remain empty when consume fails")
|
||||
}
|
||||
}
|
||||
package bind_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/memohai/memoh/internal/channelidentities"
|
||||
"github.com/memohai/memoh/internal/bind"
|
||||
"github.com/memohai/memoh/internal/db"
|
||||
"github.com/memohai/memoh/internal/db/sqlc"
|
||||
)
|
||||
|
||||
func setupBindIntegrationTest(t *testing.T) (*sqlc.Queries, *channelidentities.Service, *bind.Service, func()) {
|
||||
t.Helper()
|
||||
|
||||
dsn := os.Getenv("TEST_POSTGRES_DSN")
|
||||
if dsn == "" {
|
||||
t.Skip("skip integration test: TEST_POSTGRES_DSN is not set")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
pool, err := pgxpool.New(ctx, dsn)
|
||||
if err != nil {
|
||||
t.Skipf("skip integration test: cannot connect to database: %v", err)
|
||||
}
|
||||
if err := pool.Ping(ctx); err != nil {
|
||||
pool.Close()
|
||||
t.Skipf("skip integration test: database ping failed: %v", err)
|
||||
}
|
||||
|
||||
queries := sqlc.New(pool)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
channelIdentitySvc := channelidentities.NewService(logger, queries)
|
||||
bindSvc := bind.NewService(logger, pool, queries)
|
||||
|
||||
cleanup := func() {
|
||||
pool.Close()
|
||||
}
|
||||
return queries, channelIdentitySvc, bindSvc, cleanup
|
||||
}
|
||||
|
||||
func createUserForBindTest(ctx context.Context, queries *sqlc.Queries) (string, error) {
|
||||
row, err := queries.CreateUser(ctx, sqlc.CreateUserParams{
|
||||
IsActive: true,
|
||||
Metadata: []byte("{}"),
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return db.UUIDToString(row.ID), nil
|
||||
}
|
||||
|
||||
func createBotForBindTest(ctx context.Context, queries *sqlc.Queries, ownerUserID string) (string, error) {
|
||||
pgOwnerID, err := db.ParseUUID(ownerUserID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
meta, _ := json.Marshal(map[string]any{"source": "bind-integration-test"})
|
||||
row, err := queries.CreateBot(ctx, sqlc.CreateBotParams{
|
||||
OwnerUserID: pgOwnerID,
|
||||
Type: "personal",
|
||||
DisplayName: pgtype.Text{String: "bind-test-bot", Valid: true},
|
||||
AvatarUrl: pgtype.Text{},
|
||||
IsActive: true,
|
||||
Metadata: meta,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return db.UUIDToString(row.ID), nil
|
||||
}
|
||||
|
||||
func TestIntegrationConsumeBindCodeSuccessAndSingleUse(t *testing.T) {
|
||||
queries, channelIdentitySvc, bindSvc, cleanup := setupBindIntegrationTest(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
ownerUserID, err := createUserForBindTest(ctx, queries)
|
||||
if err != nil {
|
||||
t.Fatalf("create owner user failed: %v", err)
|
||||
}
|
||||
sourceChannelIdentity, err := channelIdentitySvc.Create(ctx, channelidentities.KindChannel)
|
||||
if err != nil {
|
||||
t.Fatalf("create source channelIdentity failed: %v", err)
|
||||
}
|
||||
botID, err := createBotForBindTest(ctx, queries, ownerUserID)
|
||||
if err != nil {
|
||||
t.Fatalf("create bot failed: %v", err)
|
||||
}
|
||||
|
||||
code, err := bindSvc.Issue(ctx, botID, ownerUserID, 10*time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("issue bind code failed: %v", err)
|
||||
}
|
||||
if err := bindSvc.Consume(ctx, code, sourceChannelIdentity.ID); err != nil {
|
||||
t.Fatalf("consume bind code failed: %v", err)
|
||||
}
|
||||
|
||||
after, err := bindSvc.Get(ctx, code.Token)
|
||||
if err != nil {
|
||||
t.Fatalf("get bind code failed: %v", err)
|
||||
}
|
||||
if after.UsedAt.IsZero() {
|
||||
t.Fatal("expected used_at to be set after successful consume")
|
||||
}
|
||||
if after.UsedByChannelIdentityID != sourceChannelIdentity.ID {
|
||||
t.Fatalf("expected used_by_channel_identity_id=%s, got %s", sourceChannelIdentity.ID, after.UsedByChannelIdentityID)
|
||||
}
|
||||
|
||||
linkedUserID, err := channelIdentitySvc.GetLinkedUserID(ctx, sourceChannelIdentity.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("get linked user failed: %v", err)
|
||||
}
|
||||
if linkedUserID != ownerUserID {
|
||||
t.Fatalf("expected linked user=%s, got %s", ownerUserID, linkedUserID)
|
||||
}
|
||||
|
||||
if err := bindSvc.Consume(ctx, code, sourceChannelIdentity.ID); !errors.Is(err, bind.ErrCodeUsed) {
|
||||
t.Fatalf("expected ErrCodeUsed on second consume, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegrationConsumeBindCodeRollbackOnLinkConflict(t *testing.T) {
|
||||
queries, channelIdentitySvc, bindSvc, cleanup := setupBindIntegrationTest(t)
|
||||
defer cleanup()
|
||||
@@ -313,10 +145,10 @@ func TestIntegrationConsumeBindCodeRollbackOnLinkConflict(t *testing.T) {
|
||||
}
|
||||
sourceChannelIdentity, err := channelIdentitySvc.Create(ctx, channelidentities.KindChannel)
|
||||
if err != nil {
|
||||
t.Fatalf("create source channelIdentity failed: %v", err)
|
||||
t.Fatalf("create source channel identity failed: %v", err)
|
||||
}
|
||||
if err := channelIdentitySvc.LinkChannelIdentityToUser(ctx, sourceChannelIdentity.ID, otherUserID); err != nil {
|
||||
t.Fatalf("pre-link source channelIdentity failed: %v", err)
|
||||
t.Fatalf("pre-link source channel identity failed: %v", err)
|
||||
}
|
||||
botID, err := createBotForBindTest(ctx, queries, ownerUserID)
|
||||
if err != nil {
|
||||
@@ -339,224 +171,3 @@ func TestIntegrationConsumeBindCodeRollbackOnLinkConflict(t *testing.T) {
|
||||
t.Fatal("expected used_at to remain empty when consume fails")
|
||||
}
|
||||
}
|
||||
package bind_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/memohai/memoh/internal/channelidentities"
|
||||
"github.com/memohai/memoh/internal/bind"
|
||||
"github.com/memohai/memoh/internal/db"
|
||||
"github.com/memohai/memoh/internal/db/sqlc"
|
||||
)
|
||||
|
||||
func setupBindIntegrationTest(t *testing.T) (*sqlc.Queries, *channelidentities.Service, *bind.Service, func()) {
|
||||
t.Helper()
|
||||
|
||||
dsn := os.Getenv("TEST_POSTGRES_DSN")
|
||||
if dsn == "" {
|
||||
t.Skip("skip integration test: TEST_POSTGRES_DSN is not set")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
pool, err := pgxpool.New(ctx, dsn)
|
||||
if err != nil {
|
||||
t.Skipf("skip integration test: cannot connect to database: %v", err)
|
||||
}
|
||||
if err := pool.Ping(ctx); err != nil {
|
||||
pool.Close()
|
||||
t.Skipf("skip integration test: database ping failed: %v", err)
|
||||
}
|
||||
|
||||
queries := sqlc.New(pool)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
channelIdentitySvc := channelidentities.NewService(logger, queries)
|
||||
bindSvc := bind.NewService(logger, pool, queries)
|
||||
|
||||
cleanup := func() {
|
||||
pool.Close()
|
||||
}
|
||||
return queries, channelIdentitySvc, bindSvc, cleanup
|
||||
}
|
||||
|
||||
func createBotForBindTest(ctx context.Context, queries *sqlc.Queries, ownerChannelIdentityID string) (string, error) {
|
||||
pgOwnerID, err := db.ParseUUID(ownerChannelIdentityID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
meta, _ := json.Marshal(map[string]any{"source": "bind-integration-test"})
|
||||
row, err := queries.CreateBot(ctx, sqlc.CreateBotParams{
|
||||
OwnerChannelIdentityID: pgOwnerID,
|
||||
Type: "personal",
|
||||
DisplayName: pgtype.Text{String: "bind-test-bot", Valid: true},
|
||||
AvatarUrl: pgtype.Text{},
|
||||
IsActive: true,
|
||||
Metadata: meta,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return db.UUIDToString(row.ID), nil
|
||||
}
|
||||
|
||||
func createChatForBindTest(ctx context.Context, queries *sqlc.Queries, botID, channelIdentityID string) (string, error) {
|
||||
pgBotID, err := db.ParseUUID(botID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
pgChannelIdentityID, err := db.ParseUUID(channelIdentityID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
row, err := queries.CreateChat(ctx, sqlc.CreateChatParams{
|
||||
BotID: pgBotID,
|
||||
Kind: "direct",
|
||||
ParentChatID: pgtype.UUID{},
|
||||
Title: pgtype.Text{},
|
||||
CreatedBy: pgChannelIdentityID,
|
||||
Metadata: []byte("{}"),
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return db.UUIDToString(row.ID), nil
|
||||
}
|
||||
|
||||
func TestIntegrationConsumeBindCodeSuccessAndSingleUse(t *testing.T) {
|
||||
queries, channelIdentitySvc, bindSvc, cleanup := setupBindIntegrationTest(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
human, err := channelIdentitySvc.Create(ctx, channelidentities.KindHuman)
|
||||
if err != nil {
|
||||
t.Fatalf("create human failed: %v", err)
|
||||
}
|
||||
shadow, err := channelIdentitySvc.Create(ctx, channelidentities.KindShadow)
|
||||
if err != nil {
|
||||
t.Fatalf("create shadow failed: %v", err)
|
||||
}
|
||||
botID, err := createBotForBindTest(ctx, queries, human.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("create bot failed: %v", err)
|
||||
}
|
||||
chatID, err := createChatForBindTest(ctx, queries, botID, human.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("create chat failed: %v", err)
|
||||
}
|
||||
pgChatID, err := db.ParseUUID(chatID)
|
||||
if err != nil {
|
||||
t.Fatalf("parse chat id failed: %v", err)
|
||||
}
|
||||
pgShadowID, err := db.ParseUUID(shadow.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("parse shadow id failed: %v", err)
|
||||
}
|
||||
pgHumanID, err := db.ParseUUID(human.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("parse human id failed: %v", err)
|
||||
}
|
||||
if _, err := queries.AddChatParticipant(ctx, sqlc.AddChatParticipantParams{
|
||||
ChatID: pgChatID,
|
||||
ChannelIdentityID: pgShadowID,
|
||||
Role: "member",
|
||||
}); err != nil {
|
||||
t.Fatalf("add shadow participant failed: %v", err)
|
||||
}
|
||||
|
||||
code, err := bindSvc.Issue(ctx, botID, human.ID, 10*time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("issue bind code failed: %v", err)
|
||||
}
|
||||
if err := bindSvc.Consume(ctx, code, shadow.ID); err != nil {
|
||||
t.Fatalf("consume bind code failed: %v", err)
|
||||
}
|
||||
|
||||
after, err := bindSvc.Get(ctx, code.Token)
|
||||
if err != nil {
|
||||
t.Fatalf("get bind code failed: %v", err)
|
||||
}
|
||||
if after.UsedAt.IsZero() {
|
||||
t.Fatal("expected used_at to be set after successful consume")
|
||||
}
|
||||
if after.UsedByChannelIdentityID != shadow.ID {
|
||||
t.Fatalf("expected used_by_channel_identity_id=%s, got %s", shadow.ID, after.UsedByChannelIdentityID)
|
||||
}
|
||||
|
||||
canonical, err := channelIdentitySvc.Canonicalize(ctx, shadow.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("canonicalize failed: %v", err)
|
||||
}
|
||||
if canonical != human.ID {
|
||||
t.Fatalf("expected canonical=%s, got %s", human.ID, canonical)
|
||||
}
|
||||
if _, err := queries.GetChatParticipant(ctx, sqlc.GetChatParticipantParams{
|
||||
ChatID: pgChatID,
|
||||
ChannelIdentityID: pgHumanID,
|
||||
}); err != nil {
|
||||
t.Fatalf("expected human participant after bind, got error: %v", err)
|
||||
}
|
||||
if _, err := queries.GetChatParticipant(ctx, sqlc.GetChatParticipantParams{
|
||||
ChatID: pgChatID,
|
||||
ChannelIdentityID: pgShadowID,
|
||||
}); !errors.Is(err, pgx.ErrNoRows) {
|
||||
t.Fatalf("expected shadow participant removed after bind, got %v", err)
|
||||
}
|
||||
|
||||
if err := bindSvc.Consume(ctx, code, shadow.ID); !errors.Is(err, bind.ErrCodeUsed) {
|
||||
t.Fatalf("expected ErrCodeUsed on second consume, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegrationConsumeBindCodeRollbackOnLinkConflict(t *testing.T) {
|
||||
queries, channelIdentitySvc, bindSvc, cleanup := setupBindIntegrationTest(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
humanA, err := channelIdentitySvc.Create(ctx, channelidentities.KindHuman)
|
||||
if err != nil {
|
||||
t.Fatalf("create humanA failed: %v", err)
|
||||
}
|
||||
humanB, err := channelIdentitySvc.Create(ctx, channelidentities.KindHuman)
|
||||
if err != nil {
|
||||
t.Fatalf("create humanB failed: %v", err)
|
||||
}
|
||||
shadow, err := channelIdentitySvc.Create(ctx, channelidentities.KindShadow)
|
||||
if err != nil {
|
||||
t.Fatalf("create shadow failed: %v", err)
|
||||
}
|
||||
botID, err := createBotForBindTest(ctx, queries, humanA.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("create bot failed: %v", err)
|
||||
}
|
||||
|
||||
// Pre-link shadow to another user so bind consume hits link conflict.
|
||||
if err := channelIdentitySvc.LinkChannelIdentityToUser(ctx, shadow.ID, humanB.ID); err != nil {
|
||||
t.Fatalf("pre link shadow->humanB failed: %v", err)
|
||||
}
|
||||
|
||||
code, err := bindSvc.Issue(ctx, botID, humanA.ID, 10*time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("issue bind code failed: %v", err)
|
||||
}
|
||||
if err := bindSvc.Consume(ctx, code, shadow.ID); !errors.Is(err, bind.ErrLinkConflict) {
|
||||
t.Fatalf("expected ErrLinkConflict, got %v", err)
|
||||
}
|
||||
|
||||
after, err := bindSvc.Get(ctx, code.Token)
|
||||
if err != nil {
|
||||
t.Fatalf("get bind code failed: %v", err)
|
||||
}
|
||||
if !after.UsedAt.IsZero() {
|
||||
t.Fatal("expected used_at to remain empty when consume fails")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -61,7 +60,10 @@ func createBotForBind(ctx context.Context, queries *sqlc.Queries, ownerUserID st
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
meta, _ := json.Marshal(map[string]any{"source": "bind-integration-test"})
|
||||
meta, err := json.Marshal(map[string]any{"source": "bind-integration-test"})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
row, err := queries.CreateBot(ctx, sqlc.CreateBotParams{
|
||||
OwnerUserID: pgOwnerID,
|
||||
Type: "personal",
|
||||
@@ -75,14 +77,6 @@ func createBotForBind(ctx context.Context, queries *sqlc.Queries, ownerUserID st
|
||||
return db.UUIDToString(row.ID), nil
|
||||
}
|
||||
|
||||
func isLegacyBindSchemaError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
msg := strings.ToLower(err.Error())
|
||||
return strings.Contains(msg, "relation \"users\" does not exist")
|
||||
}
|
||||
|
||||
func TestBindConsumeLinksChannelIdentityToIssuerUser(t *testing.T) {
|
||||
queries, channelIdentitySvc, bindSvc, cleanup := setupBindLinkIntegrationTest(t)
|
||||
defer cleanup()
|
||||
@@ -90,9 +84,6 @@ func TestBindConsumeLinksChannelIdentityToIssuerUser(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ownerUserID, err := createUserForBind(ctx, queries)
|
||||
if err != nil {
|
||||
if isLegacyBindSchemaError(err) {
|
||||
t.Skipf("skip integration test on legacy schema: %v", err)
|
||||
}
|
||||
t.Fatalf("create owner user failed: %v", err)
|
||||
}
|
||||
sourceChannelIdentity, err := channelIdentitySvc.ResolveByChannelIdentity(ctx, "feishu", fmt.Sprintf("bind-src-%d", time.Now().UnixNano()), "source")
|
||||
@@ -134,9 +125,6 @@ func TestBindConsumeConflictDoesNotMarkUsed(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
issuerUserID, err := createUserForBind(ctx, queries)
|
||||
if err != nil {
|
||||
if isLegacyBindSchemaError(err) {
|
||||
t.Skipf("skip integration test on legacy schema: %v", err)
|
||||
}
|
||||
t.Fatalf("create issuer user failed: %v", err)
|
||||
}
|
||||
otherUserID, err := createUserForBind(ctx, queries)
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -43,16 +42,6 @@ func formatUUID(bytes [16]byte) string {
|
||||
return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x", bytes[0:4], bytes[4:6], bytes[6:8], bytes[8:10], bytes[10:16])
|
||||
}
|
||||
|
||||
func isLegacyChannelIdentitySchemaError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
msg := strings.ToLower(err.Error())
|
||||
return strings.Contains(msg, "channelidentities_kind_check") ||
|
||||
strings.Contains(msg, "column \"user_id\" of relation \"channelidentities\" does not exist") ||
|
||||
strings.Contains(msg, "column \"channel_subject_id\" of relation \"channelidentities\" does not exist")
|
||||
}
|
||||
|
||||
func TestChannelIdentityResolveChannelIdentityStable(t *testing.T) {
|
||||
svc, _, cleanup := setupChannelIdentityIdentityIntegrationTest(t)
|
||||
defer cleanup()
|
||||
@@ -61,9 +50,6 @@ func TestChannelIdentityResolveChannelIdentityStable(t *testing.T) {
|
||||
externalID := fmt.Sprintf("stable_%d", time.Now().UnixNano())
|
||||
first, err := svc.ResolveByChannelIdentity(ctx, "feishu", externalID, "first")
|
||||
if err != nil {
|
||||
if isLegacyChannelIdentitySchemaError(err) {
|
||||
t.Skipf("skip integration test on legacy schema: %v", err)
|
||||
}
|
||||
t.Fatalf("first resolve failed: %v", err)
|
||||
}
|
||||
second, err := svc.ResolveByChannelIdentity(ctx, "feishu", externalID, "second")
|
||||
@@ -82,9 +68,6 @@ func TestChannelIdentityLinkToUser(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
channelIdentity, err := svc.ResolveByChannelIdentity(ctx, "telegram", fmt.Sprintf("link_%d", time.Now().UnixNano()), "tg")
|
||||
if err != nil {
|
||||
if isLegacyChannelIdentitySchemaError(err) {
|
||||
t.Skipf("skip integration test on legacy schema: %v", err)
|
||||
}
|
||||
t.Fatalf("resolve channelIdentity failed: %v", err)
|
||||
}
|
||||
user, err := queries.CreateUser(ctx, sqlc.CreateUserParams{
|
||||
|
||||
+21
-19
@@ -208,7 +208,7 @@ func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContex
|
||||
r.enforceGroupMemoryPolicy(ctx, req.ChatID, &chatSettings)
|
||||
}
|
||||
|
||||
userSettings, err := r.loadUserSettings(ctx, req.ChannelIdentityID)
|
||||
userSettings, err := r.loadUserSettings(ctx, req.UserID)
|
||||
if err != nil {
|
||||
return resolvedContext{}, err
|
||||
}
|
||||
@@ -297,7 +297,7 @@ func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContex
|
||||
BotID: req.BotID,
|
||||
SessionID: req.ChatID,
|
||||
ContainerID: containerID,
|
||||
ChannelIdentityID: firstNonEmpty(req.ChannelIdentityID, req.BotID),
|
||||
ChannelIdentityID: firstNonEmpty(req.SourceChannelIdentityID, req.UserID),
|
||||
DisplayName: firstNonEmpty(req.DisplayName, "User"),
|
||||
CurrentPlatform: req.CurrentChannel,
|
||||
ReplyTarget: "",
|
||||
@@ -348,11 +348,11 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload sc
|
||||
chatID = "schedule-" + payload.ID
|
||||
}
|
||||
req := ChatRequest{
|
||||
BotID: botID,
|
||||
ChatID: chatID,
|
||||
Query: payload.Command,
|
||||
ChannelIdentityID: payload.OwnerUserID,
|
||||
Token: token,
|
||||
BotID: botID,
|
||||
ChatID: chatID,
|
||||
Query: payload.Command,
|
||||
UserID: payload.OwnerUserID,
|
||||
Token: token,
|
||||
}
|
||||
rc, err := r.resolve(ctx, req)
|
||||
if err != nil {
|
||||
@@ -373,7 +373,7 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload sc
|
||||
BotID: rc.payload.Identity.BotID,
|
||||
SessionID: rc.payload.Identity.SessionID,
|
||||
ContainerID: rc.payload.Identity.ContainerID,
|
||||
ChannelIdentityID: firstNonEmpty(payload.OwnerUserID, botID),
|
||||
ChannelIdentityID: strings.TrimSpace(payload.OwnerUserID),
|
||||
DisplayName: "Scheduler",
|
||||
},
|
||||
Attachments: rc.payload.Attachments,
|
||||
@@ -676,8 +676,8 @@ func (r *Resolver) loadMemoryContextMessage(ctx context.Context, req ChatRequest
|
||||
if settings.EnableChatMemory {
|
||||
scopes = append(scopes, memoryScope{Namespace: "chat", ScopeID: req.ChatID})
|
||||
}
|
||||
if settings.EnablePrivateMemory && strings.TrimSpace(req.ChannelIdentityID) != "" {
|
||||
scopes = append(scopes, memoryScope{Namespace: "private", ScopeID: req.ChannelIdentityID})
|
||||
if settings.EnablePrivateMemory && strings.TrimSpace(req.UserID) != "" {
|
||||
scopes = append(scopes, memoryScope{Namespace: "private", ScopeID: req.UserID})
|
||||
}
|
||||
if settings.EnablePublicMemory {
|
||||
scopes = append(scopes, memoryScope{Namespace: "public", ScopeID: req.BotID})
|
||||
@@ -778,7 +778,7 @@ func (r *Resolver) storeRound(ctx context.Context, req ChatRequest, messages []M
|
||||
fullRound = append(fullRound, messages...)
|
||||
|
||||
r.storeMessages(ctx, req, fullRound)
|
||||
r.storeMemory(ctx, req.BotID, req.ChatID, req.ChannelIdentityID, req.Query, fullRound)
|
||||
r.storeMemory(ctx, req.BotID, req.ChatID, req.UserID, req.Query, fullRound)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -805,10 +805,12 @@ func (r *Resolver) storeMessages(ctx context.Context, req ChatRequest, messages
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
senderID := ""
|
||||
senderChannelIdentityID := ""
|
||||
senderUserID := ""
|
||||
externalMessageID := ""
|
||||
if msg.Role == "user" {
|
||||
senderID = req.ChannelIdentityID
|
||||
senderChannelIdentityID = req.SourceChannelIdentityID
|
||||
senderUserID = req.UserID
|
||||
externalMessageID = req.ExternalMessageID
|
||||
}
|
||||
if _, err := r.chatService.PersistMessage(
|
||||
@@ -816,8 +818,8 @@ func (r *Resolver) storeMessages(ctx context.Context, req ChatRequest, messages
|
||||
req.ChatID,
|
||||
req.BotID,
|
||||
req.RouteID,
|
||||
"",
|
||||
senderID,
|
||||
senderChannelIdentityID,
|
||||
senderUserID,
|
||||
req.CurrentChannel,
|
||||
externalMessageID,
|
||||
msg.Role,
|
||||
@@ -829,7 +831,7 @@ func (r *Resolver) storeMessages(ctx context.Context, req ChatRequest, messages
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Resolver) storeMemory(ctx context.Context, botID, chatID, channelIdentityID, query string, messages []ModelMessage) {
|
||||
func (r *Resolver) storeMemory(ctx context.Context, botID, chatID, userID, query string, messages []ModelMessage) {
|
||||
if r.memoryService == nil {
|
||||
return
|
||||
}
|
||||
@@ -869,9 +871,9 @@ func (r *Resolver) storeMemory(ctx context.Context, botID, chatID, channelIdenti
|
||||
r.addMemory(ctx, botID, memMsgs, "chat", chatID)
|
||||
}
|
||||
|
||||
// Write to private namespace if enabled and channel identity is known.
|
||||
if cs.EnablePrivateMemory && strings.TrimSpace(channelIdentityID) != "" {
|
||||
r.addMemory(ctx, botID, memMsgs, "private", channelIdentityID)
|
||||
// Write to private namespace if enabled and user id is known.
|
||||
if cs.EnablePrivateMemory && strings.TrimSpace(userID) != "" {
|
||||
r.addMemory(ctx, botID, memMsgs, "private", userID)
|
||||
}
|
||||
|
||||
// Write to public namespace if enabled.
|
||||
|
||||
+34
-13
@@ -70,7 +70,10 @@ func (s *Service) Create(ctx context.Context, botID, channelIdentityID string, r
|
||||
}
|
||||
}
|
||||
|
||||
metadata, _ := json.Marshal(nonNilMap(req.Metadata))
|
||||
metadata, err := json.Marshal(nonNilMap(req.Metadata))
|
||||
if err != nil {
|
||||
return Chat{}, fmt.Errorf("marshal chat metadata: %w", err)
|
||||
}
|
||||
|
||||
row, err := s.queries.CreateChat(ctx, sqlc.CreateChatParams{
|
||||
BotID: pgBotID,
|
||||
@@ -393,7 +396,10 @@ func (s *Service) CreateRoute(ctx context.Context, chatID string, r Route) (Rout
|
||||
return Route{}, err
|
||||
}
|
||||
}
|
||||
metadata, _ := json.Marshal(nonNilMap(r.Metadata))
|
||||
metadata, err := json.Marshal(nonNilMap(r.Metadata))
|
||||
if err != nil {
|
||||
return Route{}, fmt.Errorf("marshal route metadata: %w", err)
|
||||
}
|
||||
row, err := s.queries.CreateChatRoute(ctx, sqlc.CreateChatRouteParams{
|
||||
ChatID: pgChatID,
|
||||
BotID: pgBotID,
|
||||
@@ -488,7 +494,10 @@ func (s *Service) ResolveChat(ctx context.Context, botID, platform, conversation
|
||||
if err == nil {
|
||||
// Route found, ensure the sender identity is a participant.
|
||||
if strings.TrimSpace(channelIdentityID) != "" {
|
||||
ok, _ := s.IsParticipant(ctx, route.ChatID, channelIdentityID)
|
||||
ok, checkErr := s.IsParticipant(ctx, route.ChatID, channelIdentityID)
|
||||
if checkErr != nil {
|
||||
return ResolveChatResult{}, fmt.Errorf("check chat participant: %w", checkErr)
|
||||
}
|
||||
if !ok {
|
||||
if _, err := s.AddParticipant(ctx, route.ChatID, channelIdentityID, RoleMember); err != nil {
|
||||
s.logger.Warn("auto-add participant failed", slog.Any("error", err))
|
||||
@@ -497,9 +506,17 @@ func (s *Service) ResolveChat(ctx context.Context, botID, platform, conversation
|
||||
}
|
||||
// Update reply target if changed.
|
||||
if strings.TrimSpace(replyTarget) != "" && replyTarget != route.ReplyTarget {
|
||||
_ = s.UpdateRouteReplyTarget(ctx, route.ID, replyTarget)
|
||||
if err := s.UpdateRouteReplyTarget(ctx, route.ID, replyTarget); err != nil && s.logger != nil {
|
||||
s.logger.Warn("update route reply target failed", slog.Any("error", err))
|
||||
}
|
||||
}
|
||||
pgRouteChatID, parseErr := parseUUID(route.ChatID)
|
||||
if parseErr != nil {
|
||||
return ResolveChatResult{}, fmt.Errorf("parse route chat id: %w", parseErr)
|
||||
}
|
||||
if err := s.queries.TouchChat(ctx, pgRouteChatID); err != nil && s.logger != nil {
|
||||
s.logger.Warn("touch chat failed", slog.Any("error", err))
|
||||
}
|
||||
_ = s.queries.TouchChat(ctx, mustParseUUID(route.ChatID))
|
||||
return ResolveChatResult{ChatID: route.ChatID, RouteID: route.ID, Created: false}, nil
|
||||
}
|
||||
|
||||
@@ -564,13 +581,22 @@ func (s *Service) PersistMessage(ctx context.Context, chatID, botID, routeID, se
|
||||
}
|
||||
var pgSender pgtype.UUID
|
||||
if strings.TrimSpace(senderChannelIdentityID) != "" {
|
||||
pgSender, _ = parseUUID(senderChannelIdentityID)
|
||||
pgSender, err = parseUUID(senderChannelIdentityID)
|
||||
if err != nil {
|
||||
return Message{}, fmt.Errorf("invalid sender channel identity id: %w", err)
|
||||
}
|
||||
}
|
||||
var pgSenderUser pgtype.UUID
|
||||
if strings.TrimSpace(senderUserID) != "" {
|
||||
pgSenderUser, _ = parseUUID(senderUserID)
|
||||
pgSenderUser, err = parseUUID(senderUserID)
|
||||
if err != nil {
|
||||
return Message{}, fmt.Errorf("invalid sender user id: %w", err)
|
||||
}
|
||||
}
|
||||
metaBytes, err := json.Marshal(nonNilMap(metadata))
|
||||
if err != nil {
|
||||
return Message{}, fmt.Errorf("marshal message metadata: %w", err)
|
||||
}
|
||||
metaBytes, _ := json.Marshal(nonNilMap(metadata))
|
||||
if len(content) == 0 {
|
||||
content = []byte("{}")
|
||||
}
|
||||
@@ -813,11 +839,6 @@ func pgTimePtr(ts pgtype.Timestamptz) *time.Time {
|
||||
return &value
|
||||
}
|
||||
|
||||
func mustParseUUID(id string) pgtype.UUID {
|
||||
pgID, _ := parseUUID(id)
|
||||
return pgID
|
||||
}
|
||||
|
||||
func nonNilMap(m map[string]any) map[string]any {
|
||||
if m == nil {
|
||||
return map[string]any{}
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -56,17 +55,6 @@ func setupChatPresenceIntegrationTest(t *testing.T) chatPresenceFixture {
|
||||
}
|
||||
}
|
||||
|
||||
func isLegacyChatPresenceSchemaError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
msg := strings.ToLower(err.Error())
|
||||
return strings.Contains(msg, "relation \"chat_channelIdentity_presence\" does not exist") ||
|
||||
strings.Contains(msg, "column \"user_id\" of relation \"channelidentities\" does not exist") ||
|
||||
strings.Contains(msg, "column \"sender_user_id\" of relation \"chat_messages\" does not exist") ||
|
||||
strings.Contains(msg, "column \"created_by_user_id\" of relation \"chats\" does not exist")
|
||||
}
|
||||
|
||||
func createUserForChatPresence(ctx context.Context, queries *sqlc.Queries) (string, error) {
|
||||
row, err := queries.CreateUser(ctx, sqlc.CreateUserParams{
|
||||
IsActive: true,
|
||||
@@ -83,7 +71,10 @@ func createBotForChatPresence(ctx context.Context, queries *sqlc.Queries, ownerU
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
meta, _ := json.Marshal(map[string]any{"source": "chat-presence-integration-test"})
|
||||
meta, err := json.Marshal(map[string]any{"source": "chat-presence-integration-test"})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
row, err := queries.CreateBot(ctx, sqlc.CreateBotParams{
|
||||
OwnerUserID: pgOwnerID,
|
||||
Type: "personal",
|
||||
@@ -105,10 +96,6 @@ func setupObservedChatScenario(t *testing.T) (chatPresenceFixture, string, strin
|
||||
|
||||
ownerUserID, err := createUserForChatPresence(ctx, fixture.queries)
|
||||
if err != nil {
|
||||
if isLegacyChatPresenceSchemaError(err) {
|
||||
fixture.cleanup()
|
||||
t.Skipf("skip integration test on legacy schema: %v", err)
|
||||
}
|
||||
fixture.cleanup()
|
||||
t.Fatalf("create owner user failed: %v", err)
|
||||
}
|
||||
@@ -128,10 +115,6 @@ func setupObservedChatScenario(t *testing.T) (chatPresenceFixture, string, strin
|
||||
Title: "presence-observed",
|
||||
})
|
||||
if err != nil {
|
||||
if isLegacyChatPresenceSchemaError(err) {
|
||||
fixture.cleanup()
|
||||
t.Skipf("skip integration test on legacy schema: %v", err)
|
||||
}
|
||||
fixture.cleanup()
|
||||
t.Fatalf("create chat failed: %v", err)
|
||||
}
|
||||
@@ -143,10 +126,6 @@ func setupObservedChatScenario(t *testing.T) (chatPresenceFixture, string, strin
|
||||
"presence-observer",
|
||||
)
|
||||
if err != nil {
|
||||
if isLegacyChatPresenceSchemaError(err) {
|
||||
fixture.cleanup()
|
||||
t.Skipf("skip integration test on legacy schema: %v", err)
|
||||
}
|
||||
fixture.cleanup()
|
||||
t.Fatalf("resolve channelIdentity failed: %v", err)
|
||||
}
|
||||
@@ -165,10 +144,6 @@ func setupObservedChatScenario(t *testing.T) (chatPresenceFixture, string, strin
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
if isLegacyChatPresenceSchemaError(err) {
|
||||
fixture.cleanup()
|
||||
t.Skipf("skip integration test on legacy schema: %v", err)
|
||||
}
|
||||
fixture.cleanup()
|
||||
t.Fatalf("persist message failed: %v", err)
|
||||
}
|
||||
|
||||
+10
-9
@@ -234,15 +234,16 @@ type ToolCallFunction struct {
|
||||
|
||||
// ChatRequest is the input for Chat and StreamChat.
|
||||
type ChatRequest struct {
|
||||
BotID string `json:"-"`
|
||||
ChatID string `json:"-"`
|
||||
Token string `json:"-"`
|
||||
ChannelIdentityID string `json:"-"`
|
||||
ContainerID string `json:"-"`
|
||||
DisplayName string `json:"-"`
|
||||
RouteID string `json:"-"`
|
||||
ChatToken string `json:"-"`
|
||||
ExternalMessageID string `json:"-"`
|
||||
BotID string `json:"-"`
|
||||
ChatID string `json:"-"`
|
||||
Token string `json:"-"`
|
||||
UserID string `json:"-"`
|
||||
SourceChannelIdentityID string `json:"-"`
|
||||
ContainerID string `json:"-"`
|
||||
DisplayName string `json:"-"`
|
||||
RouteID string `json:"-"`
|
||||
ChatToken string `json:"-"`
|
||||
ExternalMessageID string `json:"-"`
|
||||
|
||||
Query string `json:"query"`
|
||||
Model string `json:"model,omitempty"`
|
||||
|
||||
@@ -53,7 +53,7 @@ func (h *BindHandler) Issue(c echo.Context) error {
|
||||
if h.service == nil {
|
||||
return echo.NewHTTPError(http.StatusServiceUnavailable, "bind service not available")
|
||||
}
|
||||
channelIdentityID, err := h.requireChannelIdentityID(c)
|
||||
userID, err := h.requireUserID(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -68,7 +68,7 @@ func (h *BindHandler) Issue(c echo.Context) error {
|
||||
ttl = time.Duration(req.TTLSeconds) * time.Second
|
||||
}
|
||||
|
||||
code, err := h.service.Issue(c.Request().Context(), channelIdentityID, strings.TrimSpace(req.Platform), ttl)
|
||||
code, err := h.service.Issue(c.Request().Context(), userID, strings.TrimSpace(req.Platform), ttl)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
@@ -79,13 +79,13 @@ func (h *BindHandler) Issue(c echo.Context) error {
|
||||
})
|
||||
}
|
||||
|
||||
func (h *BindHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
||||
channelIdentityID, err := auth.ChannelIdentityIDFromContext(c)
|
||||
func (h *BindHandler) requireUserID(c echo.Context) (string, error) {
|
||||
userID, err := auth.UserIDFromContext(c)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil {
|
||||
if err := identity.ValidateChannelIdentityID(userID); err != nil {
|
||||
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
return channelIdentityID, nil
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
@@ -161,7 +161,7 @@ func (h *ChannelHandler) GetChannel(c echo.Context) error {
|
||||
}
|
||||
|
||||
func (h *ChannelHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
||||
channelIdentityID, err := auth.ChannelIdentityIDFromContext(c)
|
||||
channelIdentityID, err := auth.UserIDFromContext(c)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -189,7 +189,8 @@ func (h *ChatHandler) SendMessage(c echo.Context) error {
|
||||
req.BotID = chatObj.BotID
|
||||
req.ChatID = chatID
|
||||
req.Token = c.Request().Header.Get("Authorization")
|
||||
req.ChannelIdentityID = channelIdentityID
|
||||
req.UserID = channelIdentityID
|
||||
req.SourceChannelIdentityID = channelIdentityID
|
||||
|
||||
resp, err := h.resolver.Chat(c.Request().Context(), req)
|
||||
if err != nil {
|
||||
@@ -224,7 +225,8 @@ func (h *ChatHandler) StreamMessage(c echo.Context) error {
|
||||
req.BotID = chatObj.BotID
|
||||
req.ChatID = chatID
|
||||
req.Token = c.Request().Header.Get("Authorization")
|
||||
req.ChannelIdentityID = channelIdentityID
|
||||
req.UserID = channelIdentityID
|
||||
req.SourceChannelIdentityID = channelIdentityID
|
||||
|
||||
c.Response().Header().Set(echo.HeaderContentType, "text/event-stream")
|
||||
c.Response().Header().Set(echo.HeaderCacheControl, "no-cache")
|
||||
@@ -258,7 +260,10 @@ func (h *ChatHandler) StreamMessage(c echo.Context) error {
|
||||
if err != nil {
|
||||
h.logger.Error("chat stream failed", slog.Any("error", err))
|
||||
errData := map[string]string{"error": err.Error()}
|
||||
data, _ := json.Marshal(errData)
|
||||
data, marshalErr := json.Marshal(errData)
|
||||
if marshalErr != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, marshalErr.Error())
|
||||
}
|
||||
writer.WriteString(fmt.Sprintf("data: %s\n\n", string(data)))
|
||||
writer.Flush()
|
||||
flusher.Flush()
|
||||
@@ -500,7 +505,7 @@ func (h *ChatHandler) ListThreads(c echo.Context) error {
|
||||
// --- helpers ---
|
||||
|
||||
func (h *ChatHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
||||
channelIdentityID, err := auth.ChannelIdentityIDFromContext(c)
|
||||
channelIdentityID, err := auth.UserIDFromContext(c)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -558,7 +563,10 @@ func (h *ChatHandler) requireParticipant(ctx context.Context, chatID, channelIde
|
||||
}
|
||||
// Admin bypass.
|
||||
if h.accountService != nil {
|
||||
isAdmin, _ := h.accountService.IsAdmin(ctx, channelIdentityID)
|
||||
isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
if isAdmin {
|
||||
return nil
|
||||
}
|
||||
@@ -579,7 +587,10 @@ func (h *ChatHandler) requireReadable(ctx context.Context, chatID, channelIdenti
|
||||
}
|
||||
// Admin bypass.
|
||||
if h.accountService != nil {
|
||||
isAdmin, _ := h.accountService.IsAdmin(ctx, channelIdentityID)
|
||||
isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
if isAdmin {
|
||||
return nil
|
||||
}
|
||||
@@ -600,7 +611,10 @@ func (h *ChatHandler) requireRole(ctx context.Context, chatID, channelIdentityID
|
||||
}
|
||||
// Admin bypass.
|
||||
if h.accountService != nil {
|
||||
isAdmin, _ := h.accountService.IsAdmin(ctx, channelIdentityID)
|
||||
isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
if isAdmin {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -627,7 +627,7 @@ func (h *ContainerdHandler) requireBotAccess(c echo.Context) (string, error) {
|
||||
}
|
||||
|
||||
func (h *ContainerdHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
||||
channelIdentityID, err := auth.ChannelIdentityIDFromContext(c)
|
||||
channelIdentityID, err := auth.UserIDFromContext(c)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -85,10 +85,10 @@ func (h *EmbeddingsHandler) Embed(c echo.Context) error {
|
||||
req.Input.ImageURL = strings.TrimSpace(req.Input.ImageURL)
|
||||
req.Input.VideoURL = strings.TrimSpace(req.Input.VideoURL)
|
||||
|
||||
channelIdentityID := ""
|
||||
userID := ""
|
||||
if c.Get("user") != nil {
|
||||
if value, err := auth.ChannelIdentityIDFromContext(c); err == nil {
|
||||
channelIdentityID = value
|
||||
if value, err := auth.UserIDFromContext(c); err == nil {
|
||||
userID = value
|
||||
}
|
||||
}
|
||||
result, err := h.resolver.Embed(c.Request().Context(), embeddings.Request{
|
||||
@@ -101,7 +101,7 @@ func (h *EmbeddingsHandler) Embed(c echo.Context) error {
|
||||
ImageURL: req.Input.ImageURL,
|
||||
VideoURL: req.Input.VideoURL,
|
||||
},
|
||||
ChannelIdentityID: channelIdentityID,
|
||||
ChannelIdentityID: userID,
|
||||
})
|
||||
if err != nil {
|
||||
message := err.Error()
|
||||
|
||||
@@ -228,7 +228,7 @@ func (h *LocalChannelHandler) ensureChatParticipant(ctx context.Context, chatID,
|
||||
}
|
||||
|
||||
func (h *LocalChannelHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
||||
channelIdentityID, err := auth.ChannelIdentityIDFromContext(c)
|
||||
channelIdentityID, err := auth.UserIDFromContext(c)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -218,7 +218,7 @@ func (h *MCPHandler) Delete(c echo.Context) error {
|
||||
}
|
||||
|
||||
func (h *MCPHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
||||
userID, err := auth.ChannelIdentityIDFromContext(c)
|
||||
userID, err := auth.UserIDFromContext(c)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -364,7 +364,10 @@ func (h *MemoryHandler) requireChatParticipant(ctx context.Context, chatID, chan
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "chat service not configured")
|
||||
}
|
||||
if h.accountService != nil {
|
||||
isAdmin, _ := h.accountService.IsAdmin(ctx, channelIdentityID)
|
||||
isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
if isAdmin {
|
||||
return nil
|
||||
}
|
||||
@@ -380,7 +383,7 @@ func (h *MemoryHandler) requireChatParticipant(ctx context.Context, chatID, chan
|
||||
}
|
||||
|
||||
func (h *MemoryHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
||||
channelIdentityID, err := auth.ChannelIdentityIDFromContext(c)
|
||||
channelIdentityID, err := auth.UserIDFromContext(c)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -164,7 +164,7 @@ func (h *ModelsHandler) Enable(c echo.Context) error {
|
||||
if h.settingsService == nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "settings service not configured")
|
||||
}
|
||||
userID, err := auth.ChannelIdentityIDFromContext(c)
|
||||
userID, err := auth.UserIDFromContext(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ type preauthIssueRequest struct {
|
||||
}
|
||||
|
||||
func (h *PreauthHandler) Issue(c echo.Context) error {
|
||||
channelIdentityID, err := h.requireChannelIdentityID(c)
|
||||
userID, err := h.requireUserID(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -48,7 +48,7 @@ func (h *PreauthHandler) Issue(c echo.Context) error {
|
||||
if botID == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "bot id is required")
|
||||
}
|
||||
if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil {
|
||||
if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil {
|
||||
return err
|
||||
}
|
||||
var req preauthIssueRequest
|
||||
@@ -59,33 +59,33 @@ func (h *PreauthHandler) Issue(c echo.Context) error {
|
||||
if req.TTLSeconds > 0 {
|
||||
ttl = time.Duration(req.TTLSeconds) * time.Second
|
||||
}
|
||||
key, err := h.service.Issue(c.Request().Context(), botID, channelIdentityID, ttl)
|
||||
key, err := h.service.Issue(c.Request().Context(), botID, userID, ttl)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return c.JSON(http.StatusOK, key)
|
||||
}
|
||||
|
||||
func (h *PreauthHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
||||
channelIdentityID, err := auth.ChannelIdentityIDFromContext(c)
|
||||
func (h *PreauthHandler) requireUserID(c echo.Context) (string, error) {
|
||||
userID, err := auth.UserIDFromContext(c)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil {
|
||||
if err := identity.ValidateChannelIdentityID(userID); err != nil {
|
||||
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
return channelIdentityID, nil
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
func (h *PreauthHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
|
||||
func (h *PreauthHandler) authorizeBotAccess(ctx context.Context, userID, botID string) (bots.Bot, error) {
|
||||
if h.botService == nil || h.accountService == nil {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured")
|
||||
}
|
||||
isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID)
|
||||
isAdmin, err := h.accountService.IsAdmin(ctx, userID)
|
||||
if err != nil {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false})
|
||||
bot, err := h.botService.AuthorizeAccess(ctx, userID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false})
|
||||
if err != nil {
|
||||
if errors.Is(err, bots.ErrBotNotFound) {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found")
|
||||
|
||||
@@ -51,7 +51,7 @@ func (h *ScheduleHandler) Register(e *echo.Echo) {
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /bots/{bot_id}/schedule [post]
|
||||
func (h *ScheduleHandler) Create(c echo.Context) error {
|
||||
channelIdentityID, err := h.requireChannelIdentityID(c)
|
||||
userID, err := h.requireUserID(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -59,7 +59,7 @@ func (h *ScheduleHandler) Create(c echo.Context) error {
|
||||
if botID == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "bot id is required")
|
||||
}
|
||||
if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil {
|
||||
if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil {
|
||||
return err
|
||||
}
|
||||
var req schedule.CreateRequest
|
||||
@@ -82,7 +82,7 @@ func (h *ScheduleHandler) Create(c echo.Context) error {
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /bots/{bot_id}/schedule [get]
|
||||
func (h *ScheduleHandler) List(c echo.Context) error {
|
||||
channelIdentityID, err := h.requireChannelIdentityID(c)
|
||||
userID, err := h.requireUserID(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -90,7 +90,7 @@ func (h *ScheduleHandler) List(c echo.Context) error {
|
||||
if botID == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "bot id is required")
|
||||
}
|
||||
if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil {
|
||||
if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil {
|
||||
return err
|
||||
}
|
||||
items, err := h.service.List(c.Request().Context(), botID)
|
||||
@@ -111,7 +111,7 @@ func (h *ScheduleHandler) List(c echo.Context) error {
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /bots/{bot_id}/schedule/{id} [get]
|
||||
func (h *ScheduleHandler) Get(c echo.Context) error {
|
||||
channelIdentityID, err := h.requireChannelIdentityID(c)
|
||||
userID, err := h.requireUserID(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -130,7 +130,7 @@ func (h *ScheduleHandler) Get(c echo.Context) error {
|
||||
if item.BotID != botID {
|
||||
return echo.NewHTTPError(http.StatusForbidden, "bot mismatch")
|
||||
}
|
||||
if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil {
|
||||
if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.JSON(http.StatusOK, item)
|
||||
@@ -147,7 +147,7 @@ func (h *ScheduleHandler) Get(c echo.Context) error {
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /bots/{bot_id}/schedule/{id} [put]
|
||||
func (h *ScheduleHandler) Update(c echo.Context) error {
|
||||
channelIdentityID, err := h.requireChannelIdentityID(c)
|
||||
userID, err := h.requireUserID(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -170,7 +170,7 @@ func (h *ScheduleHandler) Update(c echo.Context) error {
|
||||
if item.BotID != botID {
|
||||
return echo.NewHTTPError(http.StatusForbidden, "bot mismatch")
|
||||
}
|
||||
if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil {
|
||||
if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := h.service.Update(c.Request().Context(), id, req)
|
||||
@@ -190,7 +190,7 @@ func (h *ScheduleHandler) Update(c echo.Context) error {
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /bots/{bot_id}/schedule/{id} [delete]
|
||||
func (h *ScheduleHandler) Delete(c echo.Context) error {
|
||||
channelIdentityID, err := h.requireChannelIdentityID(c)
|
||||
userID, err := h.requireUserID(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -209,7 +209,7 @@ func (h *ScheduleHandler) Delete(c echo.Context) error {
|
||||
if item.BotID != botID {
|
||||
return echo.NewHTTPError(http.StatusForbidden, "bot mismatch")
|
||||
}
|
||||
if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil {
|
||||
if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := h.service.Delete(c.Request().Context(), id); err != nil {
|
||||
@@ -218,26 +218,26 @@ func (h *ScheduleHandler) Delete(c echo.Context) error {
|
||||
return c.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (h *ScheduleHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
||||
channelIdentityID, err := auth.ChannelIdentityIDFromContext(c)
|
||||
func (h *ScheduleHandler) requireUserID(c echo.Context) (string, error) {
|
||||
userID, err := auth.UserIDFromContext(c)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil {
|
||||
if err := identity.ValidateChannelIdentityID(userID); err != nil {
|
||||
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
return channelIdentityID, nil
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
func (h *ScheduleHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
|
||||
func (h *ScheduleHandler) authorizeBotAccess(ctx context.Context, userID, botID string) (bots.Bot, error) {
|
||||
if h.botService == nil || h.accountService == nil {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured")
|
||||
}
|
||||
isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID)
|
||||
isAdmin, err := h.accountService.IsAdmin(ctx, userID)
|
||||
if err != nil {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false})
|
||||
bot, err := h.botService.AuthorizeAccess(ctx, userID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false})
|
||||
if err != nil {
|
||||
if errors.Is(err, bots.ErrBotNotFound) {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found")
|
||||
|
||||
@@ -130,7 +130,7 @@ func (h *SettingsHandler) Delete(c echo.Context) error {
|
||||
}
|
||||
|
||||
func (h *SettingsHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
||||
channelIdentityID, err := auth.ChannelIdentityIDFromContext(c)
|
||||
channelIdentityID, err := auth.UserIDFromContext(c)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -433,7 +433,7 @@ func (h *SubagentHandler) AddSkills(c echo.Context) error {
|
||||
}
|
||||
|
||||
func (h *SubagentHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
||||
channelIdentityID, err := auth.ChannelIdentityIDFromContext(c)
|
||||
channelIdentityID, err := auth.UserIDFromContext(c)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -907,7 +907,7 @@ func (h *UsersHandler) authorizeBotAccess(ctx context.Context, channelIdentityID
|
||||
}
|
||||
|
||||
func (h *UsersHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
||||
channelIdentityID, err := auth.ChannelIdentityIDFromContext(c)
|
||||
channelIdentityID, err := auth.UserIDFromContext(c)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
+12
-11
@@ -167,17 +167,18 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel
|
||||
desc, _ = p.registry.GetDescriptor(msg.Channel)
|
||||
}
|
||||
resp, err := p.chat.Chat(ctx, chat.ChatRequest{
|
||||
BotID: identity.BotID,
|
||||
ChatID: resolved.ChatID,
|
||||
Token: token,
|
||||
ChannelIdentityID: identity.UserID,
|
||||
DisplayName: identity.DisplayName,
|
||||
RouteID: resolved.RouteID,
|
||||
ChatToken: chatToken,
|
||||
ExternalMessageID: strings.TrimSpace(msg.Message.ID),
|
||||
Query: text,
|
||||
CurrentChannel: msg.Channel.String(),
|
||||
Channels: []string{msg.Channel.String()},
|
||||
BotID: identity.BotID,
|
||||
ChatID: resolved.ChatID,
|
||||
Token: token,
|
||||
UserID: identity.UserID,
|
||||
SourceChannelIdentityID: identity.ChannelIdentityID,
|
||||
DisplayName: identity.DisplayName,
|
||||
RouteID: resolved.RouteID,
|
||||
ChatToken: chatToken,
|
||||
ExternalMessageID: strings.TrimSpace(msg.Message.ID),
|
||||
Query: text,
|
||||
CurrentChannel: msg.Channel.String(),
|
||||
Channels: []string{msg.Channel.String()},
|
||||
})
|
||||
if err != nil {
|
||||
if p.logger != nil {
|
||||
|
||||
@@ -97,8 +97,11 @@ func TestChannelInboundProcessorWithIdentity(t *testing.T) {
|
||||
if gateway.gotReq.Query != "hello" {
|
||||
t.Errorf("expected query 'hello', got: %s", gateway.gotReq.Query)
|
||||
}
|
||||
if gateway.gotReq.ChannelIdentityID != "channelIdentity-1" {
|
||||
t.Errorf("expected channel_identity_id 'channelIdentity-1', got: %s", gateway.gotReq.ChannelIdentityID)
|
||||
if gateway.gotReq.UserID != "channelIdentity-1" {
|
||||
t.Errorf("expected user_id 'channelIdentity-1', got: %s", gateway.gotReq.UserID)
|
||||
}
|
||||
if gateway.gotReq.SourceChannelIdentityID != "channelIdentity-1" {
|
||||
t.Errorf("expected source_channel_identity_id 'channelIdentity-1', got: %s", gateway.gotReq.SourceChannelIdentityID)
|
||||
}
|
||||
if gateway.gotReq.ChatID != "chat-1" {
|
||||
t.Errorf("expected chat_id 'chat-1', got: %s", gateway.gotReq.ChatID)
|
||||
|
||||
+35
-13
@@ -209,7 +209,9 @@ func (r *IdentityResolver) Resolve(ctx context.Context, cfg channel.ChannelConfi
|
||||
state.Decision = &decision
|
||||
return state, err
|
||||
}
|
||||
if r.policy != nil && isGroupConversationType(msg.Conversation.Type) {
|
||||
|
||||
// Personal bots are owner-only and must not depend on member/guest/preauth bypass.
|
||||
if r.policy != nil {
|
||||
botType, err := r.policy.BotType(ctx, botID)
|
||||
if err != nil {
|
||||
return state, err
|
||||
@@ -219,20 +221,35 @@ func (r *IdentityResolver) Resolve(ctx context.Context, cfg channel.ChannelConfi
|
||||
if err != nil {
|
||||
return state, err
|
||||
}
|
||||
if strings.TrimSpace(state.Identity.UserID) == "" || strings.TrimSpace(ownerUserID) != strings.TrimSpace(state.Identity.UserID) {
|
||||
// Personal bots in group chats only answer owner messages.
|
||||
state.Decision = &IdentityDecision{Stop: true}
|
||||
isOwner := strings.TrimSpace(state.Identity.UserID) != "" &&
|
||||
strings.TrimSpace(ownerUserID) == strings.TrimSpace(state.Identity.UserID)
|
||||
if !isOwner {
|
||||
if isGroupConversationType(msg.Conversation.Type) {
|
||||
// Ignore non-owner group messages for personal bots.
|
||||
state.Decision = &IdentityDecision{Stop: true}
|
||||
return state, nil
|
||||
}
|
||||
state.Decision = &IdentityDecision{
|
||||
Stop: true,
|
||||
Reply: channel.Message{Text: r.unboundReply},
|
||||
}
|
||||
return state, nil
|
||||
}
|
||||
// Owner can chat normally in group for personal bots.
|
||||
state.Identity.ForceReply = true
|
||||
if isGroupConversationType(msg.Conversation.Type) {
|
||||
// Owner can chat in group for personal bots.
|
||||
state.Identity.ForceReply = true
|
||||
}
|
||||
return state, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 2: Authorization (bot membership check).
|
||||
if r.members != nil {
|
||||
if strings.TrimSpace(state.Identity.UserID) != "" {
|
||||
isMember, _ := r.members.IsMember(ctx, botID, state.Identity.UserID)
|
||||
isMember, err := r.members.IsMember(ctx, botID, state.Identity.UserID)
|
||||
if err != nil {
|
||||
return state, fmt.Errorf("check bot membership: %w", err)
|
||||
}
|
||||
if isMember {
|
||||
return state, nil
|
||||
}
|
||||
@@ -256,9 +273,6 @@ func (r *IdentityResolver) Resolve(ctx context.Context, cfg channel.ChannelConfi
|
||||
return state, err
|
||||
}
|
||||
if allowed {
|
||||
if r.members != nil && strings.TrimSpace(state.Identity.UserID) != "" {
|
||||
_ = r.members.UpsertMemberSimple(ctx, botID, state.Identity.UserID, "member")
|
||||
}
|
||||
return state, nil
|
||||
}
|
||||
}
|
||||
@@ -312,9 +326,13 @@ func (r *IdentityResolver) tryHandlePreauthKey(ctx context.Context, msg channel.
|
||||
return true, reply("Current channel account is not linked to a user."), nil
|
||||
}
|
||||
if r.members != nil {
|
||||
_ = r.members.UpsertMemberSimple(ctx, botID, userID, "member")
|
||||
if err := r.members.UpsertMemberSimple(ctx, botID, userID, "member"); err != nil {
|
||||
return true, IdentityDecision{}, fmt.Errorf("upsert preauth member: %w", err)
|
||||
}
|
||||
}
|
||||
if _, err := r.preauth.MarkUsed(ctx, key.ID); err != nil {
|
||||
return true, IdentityDecision{}, fmt.Errorf("mark preauth key used: %w", err)
|
||||
}
|
||||
_, _ = r.preauth.MarkUsed(ctx, key.ID)
|
||||
return true, reply(r.preauthReply), nil
|
||||
}
|
||||
|
||||
@@ -365,7 +383,11 @@ func (r *IdentityResolver) tryHandleBindCode(ctx context.Context, msg channel.In
|
||||
// Resolve linked user after binding.
|
||||
newUserID := code.IssuedByUserID
|
||||
if r.channelIdentities != nil {
|
||||
if linkedUserID, err := r.channelIdentities.GetLinkedUserID(ctx, channelIdentityID); err == nil && strings.TrimSpace(linkedUserID) != "" {
|
||||
linkedUserID, err := r.channelIdentities.GetLinkedUserID(ctx, channelIdentityID)
|
||||
if err != nil {
|
||||
return true, IdentityDecision{}, "", fmt.Errorf("resolve linked user after bind: %w", err)
|
||||
}
|
||||
if strings.TrimSpace(linkedUserID) != "" {
|
||||
newUserID = linkedUserID
|
||||
}
|
||||
}
|
||||
|
||||
@@ -145,7 +145,7 @@ func (f *fakeBindService) Consume(ctx context.Context, code bind.Code, channelCh
|
||||
return f.consumeErr
|
||||
}
|
||||
|
||||
func TestIdentityResolverAllowGuestUpsertsMember(t *testing.T) {
|
||||
func TestIdentityResolverAllowGuestWithoutMembershipSideEffect(t *testing.T) {
|
||||
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-1"}}
|
||||
memberSvc := &fakeMemberService{isMember: false}
|
||||
policySvc := &fakePolicyService{allow: true, botType: "public"}
|
||||
@@ -165,8 +165,8 @@ func TestIdentityResolverAllowGuestUpsertsMember(t *testing.T) {
|
||||
if state.Identity.ChannelIdentityID != "channelIdentity-1" {
|
||||
t.Fatalf("expected channelIdentity-1, got: %s", state.Identity.ChannelIdentityID)
|
||||
}
|
||||
if !memberSvc.upsertCalled {
|
||||
t.Fatal("expected UpsertMemberSimple to be called")
|
||||
if memberSvc.upsertCalled {
|
||||
t.Fatal("guest allow should not upsert membership")
|
||||
}
|
||||
if state.Decision != nil {
|
||||
t.Fatal("expected no decision for allowed guest")
|
||||
@@ -376,6 +376,35 @@ func TestIdentityResolverPersonalBotAllowsOwnerDirectWithoutMembership(t *testin
|
||||
}
|
||||
}
|
||||
|
||||
func TestIdentityResolverPersonalBotRejectsNonOwnerDirectEvenIfMember(t *testing.T) {
|
||||
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-non-owner"}}
|
||||
memberSvc := &fakeMemberService{isMember: true}
|
||||
policySvc := &fakePolicyService{allow: true, botType: "personal", ownerUserID: "channelIdentity-owner"}
|
||||
resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, memberSvc, policySvc, nil, nil, "Access denied.", "")
|
||||
|
||||
msg := channel.InboundMessage{
|
||||
BotID: "bot-1",
|
||||
Channel: channel.ChannelType("feishu"),
|
||||
Message: channel.Message{Text: "hello from non-owner"},
|
||||
Sender: channel.Identity{SubjectID: "ext-non-owner"},
|
||||
Conversation: channel.Conversation{
|
||||
ID: "p2p-2",
|
||||
Type: "p2p",
|
||||
},
|
||||
}
|
||||
|
||||
state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if state.Decision == nil || !state.Decision.Stop {
|
||||
t.Fatal("non-owner direct message should be rejected for personal bot")
|
||||
}
|
||||
if state.Decision.Reply.Text != "Access denied." {
|
||||
t.Fatalf("unexpected reject message: %s", state.Decision.Reply.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIdentityResolverBindRunsBeforeMembershipCheck(t *testing.T) {
|
||||
shadowID := "channelIdentity-shadow"
|
||||
humanID := "channelIdentity-human"
|
||||
|
||||
@@ -55,8 +55,8 @@ func TestGenerateTriggerToken(t *testing.T) {
|
||||
if sub, _ := claims["sub"].(string); sub != userID {
|
||||
t.Errorf("expected sub=%s, got=%s", userID, sub)
|
||||
}
|
||||
if uid, _ := claims["channel_identity_id"].(string); uid != userID {
|
||||
t.Errorf("expected channel_identity_id=%s, got=%s", userID, uid)
|
||||
if uid, _ := claims["user_id"].(string); uid != userID {
|
||||
t.Errorf("expected user_id=%s, got=%s", userID, uid)
|
||||
}
|
||||
exp, _ := claims["exp"].(float64)
|
||||
if exp == 0 {
|
||||
|
||||
Reference in New Issue
Block a user