mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
refactor(core): finalize user-centric identity and policy cleanup
Unify auth and chat identity semantics around user_id, enforce personal-bot owner-only authorization, and remove legacy compatibility branches in integration tests.
This commit is contained in:
+4
-13
@@ -52,17 +52,9 @@ func UserIDFromContext(c echo.Context) (string, error) {
|
|||||||
if userID := claimString(claims, claimSubject); userID != "" {
|
if userID := claimString(claims, claimSubject); userID != "" {
|
||||||
return userID, nil
|
return userID, nil
|
||||||
}
|
}
|
||||||
if legacyChannelIdentityID := claimString(claims, claimChannelIdentityID); legacyChannelIdentityID != "" {
|
|
||||||
return legacyChannelIdentityID, nil
|
|
||||||
}
|
|
||||||
return "", echo.NewHTTPError(http.StatusUnauthorized, "user id missing")
|
return "", echo.NewHTTPError(http.StatusUnauthorized, "user id missing")
|
||||||
}
|
}
|
||||||
|
|
||||||
// ChannelIdentityIDFromContext is kept as compatibility alias and returns user id.
|
|
||||||
func ChannelIdentityIDFromContext(c echo.Context) (string, error) {
|
|
||||||
return UserIDFromContext(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GenerateToken creates a signed JWT for the user.
|
// GenerateToken creates a signed JWT for the user.
|
||||||
func GenerateToken(userID, secret string, expiresIn time.Duration) (string, time.Time, error) {
|
func GenerateToken(userID, secret string, expiresIn time.Duration) (string, time.Time, error) {
|
||||||
if strings.TrimSpace(userID) == "" {
|
if strings.TrimSpace(userID) == "" {
|
||||||
@@ -78,11 +70,10 @@ func GenerateToken(userID, secret string, expiresIn time.Duration) (string, time
|
|||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
expiresAt := now.Add(expiresIn)
|
expiresAt := now.Add(expiresIn)
|
||||||
claims := jwt.MapClaims{
|
claims := jwt.MapClaims{
|
||||||
claimSubject: userID,
|
claimSubject: userID,
|
||||||
claimUserID: userID,
|
claimUserID: userID,
|
||||||
claimChannelIdentityID: userID, // legacy compatibility for handlers still reading channel_identity_id
|
"iat": now.Unix(),
|
||||||
"iat": now.Unix(),
|
"exp": expiresAt.Unix(),
|
||||||
"exp": expiresAt.Unix(),
|
|
||||||
}
|
}
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
signed, err := token.SignedString([]byte(secret))
|
signed, err := token.SignedString([]byte(secret))
|
||||||
|
|||||||
@@ -15,8 +15,8 @@ import (
|
|||||||
"github.com/jackc/pgx/v5/pgtype"
|
"github.com/jackc/pgx/v5/pgtype"
|
||||||
"github.com/jackc/pgx/v5/pgxpool"
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
|
|
||||||
"github.com/memohai/memoh/internal/channelidentities"
|
|
||||||
"github.com/memohai/memoh/internal/bind"
|
"github.com/memohai/memoh/internal/bind"
|
||||||
|
"github.com/memohai/memoh/internal/channelidentities"
|
||||||
"github.com/memohai/memoh/internal/db"
|
"github.com/memohai/memoh/internal/db"
|
||||||
"github.com/memohai/memoh/internal/db/sqlc"
|
"github.com/memohai/memoh/internal/db/sqlc"
|
||||||
)
|
)
|
||||||
@@ -47,7 +47,7 @@ func setupBindIntegrationTest(t *testing.T) (*sqlc.Queries, *channelidentities.S
|
|||||||
return queries, channelIdentitySvc, bindSvc, func() { pool.Close() }
|
return queries, channelIdentitySvc, bindSvc, func() { pool.Close() }
|
||||||
}
|
}
|
||||||
|
|
||||||
func createUser(ctx context.Context, queries *sqlc.Queries) (string, error) {
|
func createUserForBindTest(ctx context.Context, queries *sqlc.Queries) (string, error) {
|
||||||
row, err := queries.CreateUser(ctx, sqlc.CreateUserParams{
|
row, err := queries.CreateUser(ctx, sqlc.CreateUserParams{
|
||||||
IsActive: true,
|
IsActive: true,
|
||||||
Metadata: []byte("{}"),
|
Metadata: []byte("{}"),
|
||||||
@@ -58,12 +58,15 @@ func createUser(ctx context.Context, queries *sqlc.Queries) (string, error) {
|
|||||||
return db.UUIDToString(row.ID), nil
|
return db.UUIDToString(row.ID), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createBot(ctx context.Context, queries *sqlc.Queries, ownerUserID string) (string, error) {
|
func createBotForBindTest(ctx context.Context, queries *sqlc.Queries, ownerUserID string) (string, error) {
|
||||||
pgOwnerID, err := db.ParseUUID(ownerUserID)
|
pgOwnerID, err := db.ParseUUID(ownerUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
meta, _ := json.Marshal(map[string]any{"source": "bind-integration-test"})
|
meta, err := json.Marshal(map[string]any{"source": "bind-integration-test"})
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
row, err := queries.CreateBot(ctx, sqlc.CreateBotParams{
|
row, err := queries.CreateBot(ctx, sqlc.CreateBotParams{
|
||||||
OwnerUserID: pgOwnerID,
|
OwnerUserID: pgOwnerID,
|
||||||
Type: "personal",
|
Type: "personal",
|
||||||
@@ -82,15 +85,15 @@ func TestIntegrationConsumeBindCodeSuccessAndSingleUse(t *testing.T) {
|
|||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
ownerUserID, err := createUser(ctx, queries)
|
ownerUserID, err := createUserForBindTest(ctx, queries)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("create owner user failed: %v", err)
|
t.Fatalf("create owner user failed: %v", err)
|
||||||
}
|
}
|
||||||
sourceChannelIdentity, err := channelIdentitySvc.Create(ctx, channelidentities.KindChannel)
|
sourceChannelIdentity, err := channelIdentitySvc.Create(ctx, channelidentities.KindChannel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("create source channelIdentity failed: %v", err)
|
t.Fatalf("create source channel identity failed: %v", err)
|
||||||
}
|
}
|
||||||
botID, err := createBot(ctx, queries, ownerUserID)
|
botID, err := createBotForBindTest(ctx, queries, ownerUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("create bot failed: %v", err)
|
t.Fatalf("create bot failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -127,177 +130,6 @@ func TestIntegrationConsumeBindCodeSuccessAndSingleUse(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIntegrationConsumeBindCodeRollbackOnLinkConflict(t *testing.T) {
|
|
||||||
queries, channelIdentitySvc, bindSvc, cleanup := setupBindIntegrationTest(t)
|
|
||||||
defer cleanup()
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
ownerUserID, err := createUser(ctx, queries)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("create owner user failed: %v", err)
|
|
||||||
}
|
|
||||||
otherUserID, err := createUser(ctx, queries)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("create other user failed: %v", err)
|
|
||||||
}
|
|
||||||
sourceChannelIdentity, err := channelIdentitySvc.Create(ctx, channelidentities.KindChannel)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("create source channelIdentity failed: %v", err)
|
|
||||||
}
|
|
||||||
if err := channelIdentitySvc.LinkChannelIdentityToUser(ctx, sourceChannelIdentity.ID, otherUserID); err != nil {
|
|
||||||
t.Fatalf("pre-link source channelIdentity failed: %v", err)
|
|
||||||
}
|
|
||||||
botID, err := createBot(ctx, queries, ownerUserID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("create bot failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
code, err := bindSvc.Issue(ctx, botID, ownerUserID, 10*time.Minute)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("issue bind code failed: %v", err)
|
|
||||||
}
|
|
||||||
if err := bindSvc.Consume(ctx, code, sourceChannelIdentity.ID); !errors.Is(err, bind.ErrLinkConflict) {
|
|
||||||
t.Fatalf("expected ErrLinkConflict, got %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
after, err := bindSvc.Get(ctx, code.Token)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("get bind code failed: %v", err)
|
|
||||||
}
|
|
||||||
if !after.UsedAt.IsZero() {
|
|
||||||
t.Fatal("expected used_at to remain empty when consume fails")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
package bind_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"log/slog"
|
|
||||||
"os"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/v5/pgtype"
|
|
||||||
"github.com/jackc/pgx/v5/pgxpool"
|
|
||||||
|
|
||||||
"github.com/memohai/memoh/internal/channelidentities"
|
|
||||||
"github.com/memohai/memoh/internal/bind"
|
|
||||||
"github.com/memohai/memoh/internal/db"
|
|
||||||
"github.com/memohai/memoh/internal/db/sqlc"
|
|
||||||
)
|
|
||||||
|
|
||||||
func setupBindIntegrationTest(t *testing.T) (*sqlc.Queries, *channelidentities.Service, *bind.Service, func()) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
dsn := os.Getenv("TEST_POSTGRES_DSN")
|
|
||||||
if dsn == "" {
|
|
||||||
t.Skip("skip integration test: TEST_POSTGRES_DSN is not set")
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
pool, err := pgxpool.New(ctx, dsn)
|
|
||||||
if err != nil {
|
|
||||||
t.Skipf("skip integration test: cannot connect to database: %v", err)
|
|
||||||
}
|
|
||||||
if err := pool.Ping(ctx); err != nil {
|
|
||||||
pool.Close()
|
|
||||||
t.Skipf("skip integration test: database ping failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
queries := sqlc.New(pool)
|
|
||||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
|
||||||
channelIdentitySvc := channelidentities.NewService(logger, queries)
|
|
||||||
bindSvc := bind.NewService(logger, pool, queries)
|
|
||||||
|
|
||||||
cleanup := func() {
|
|
||||||
pool.Close()
|
|
||||||
}
|
|
||||||
return queries, channelIdentitySvc, bindSvc, cleanup
|
|
||||||
}
|
|
||||||
|
|
||||||
func createUserForBindTest(ctx context.Context, queries *sqlc.Queries) (string, error) {
|
|
||||||
row, err := queries.CreateUser(ctx, sqlc.CreateUserParams{
|
|
||||||
IsActive: true,
|
|
||||||
Metadata: []byte("{}"),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return db.UUIDToString(row.ID), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func createBotForBindTest(ctx context.Context, queries *sqlc.Queries, ownerUserID string) (string, error) {
|
|
||||||
pgOwnerID, err := db.ParseUUID(ownerUserID)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
meta, _ := json.Marshal(map[string]any{"source": "bind-integration-test"})
|
|
||||||
row, err := queries.CreateBot(ctx, sqlc.CreateBotParams{
|
|
||||||
OwnerUserID: pgOwnerID,
|
|
||||||
Type: "personal",
|
|
||||||
DisplayName: pgtype.Text{String: "bind-test-bot", Valid: true},
|
|
||||||
AvatarUrl: pgtype.Text{},
|
|
||||||
IsActive: true,
|
|
||||||
Metadata: meta,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return db.UUIDToString(row.ID), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIntegrationConsumeBindCodeSuccessAndSingleUse(t *testing.T) {
|
|
||||||
queries, channelIdentitySvc, bindSvc, cleanup := setupBindIntegrationTest(t)
|
|
||||||
defer cleanup()
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
ownerUserID, err := createUserForBindTest(ctx, queries)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("create owner user failed: %v", err)
|
|
||||||
}
|
|
||||||
sourceChannelIdentity, err := channelIdentitySvc.Create(ctx, channelidentities.KindChannel)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("create source channelIdentity failed: %v", err)
|
|
||||||
}
|
|
||||||
botID, err := createBotForBindTest(ctx, queries, ownerUserID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("create bot failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
code, err := bindSvc.Issue(ctx, botID, ownerUserID, 10*time.Minute)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("issue bind code failed: %v", err)
|
|
||||||
}
|
|
||||||
if err := bindSvc.Consume(ctx, code, sourceChannelIdentity.ID); err != nil {
|
|
||||||
t.Fatalf("consume bind code failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
after, err := bindSvc.Get(ctx, code.Token)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("get bind code failed: %v", err)
|
|
||||||
}
|
|
||||||
if after.UsedAt.IsZero() {
|
|
||||||
t.Fatal("expected used_at to be set after successful consume")
|
|
||||||
}
|
|
||||||
if after.UsedByChannelIdentityID != sourceChannelIdentity.ID {
|
|
||||||
t.Fatalf("expected used_by_channel_identity_id=%s, got %s", sourceChannelIdentity.ID, after.UsedByChannelIdentityID)
|
|
||||||
}
|
|
||||||
|
|
||||||
linkedUserID, err := channelIdentitySvc.GetLinkedUserID(ctx, sourceChannelIdentity.ID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("get linked user failed: %v", err)
|
|
||||||
}
|
|
||||||
if linkedUserID != ownerUserID {
|
|
||||||
t.Fatalf("expected linked user=%s, got %s", ownerUserID, linkedUserID)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := bindSvc.Consume(ctx, code, sourceChannelIdentity.ID); !errors.Is(err, bind.ErrCodeUsed) {
|
|
||||||
t.Fatalf("expected ErrCodeUsed on second consume, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIntegrationConsumeBindCodeRollbackOnLinkConflict(t *testing.T) {
|
func TestIntegrationConsumeBindCodeRollbackOnLinkConflict(t *testing.T) {
|
||||||
queries, channelIdentitySvc, bindSvc, cleanup := setupBindIntegrationTest(t)
|
queries, channelIdentitySvc, bindSvc, cleanup := setupBindIntegrationTest(t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
@@ -313,10 +145,10 @@ func TestIntegrationConsumeBindCodeRollbackOnLinkConflict(t *testing.T) {
|
|||||||
}
|
}
|
||||||
sourceChannelIdentity, err := channelIdentitySvc.Create(ctx, channelidentities.KindChannel)
|
sourceChannelIdentity, err := channelIdentitySvc.Create(ctx, channelidentities.KindChannel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("create source channelIdentity failed: %v", err)
|
t.Fatalf("create source channel identity failed: %v", err)
|
||||||
}
|
}
|
||||||
if err := channelIdentitySvc.LinkChannelIdentityToUser(ctx, sourceChannelIdentity.ID, otherUserID); err != nil {
|
if err := channelIdentitySvc.LinkChannelIdentityToUser(ctx, sourceChannelIdentity.ID, otherUserID); err != nil {
|
||||||
t.Fatalf("pre-link source channelIdentity failed: %v", err)
|
t.Fatalf("pre-link source channel identity failed: %v", err)
|
||||||
}
|
}
|
||||||
botID, err := createBotForBindTest(ctx, queries, ownerUserID)
|
botID, err := createBotForBindTest(ctx, queries, ownerUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -339,224 +171,3 @@ func TestIntegrationConsumeBindCodeRollbackOnLinkConflict(t *testing.T) {
|
|||||||
t.Fatal("expected used_at to remain empty when consume fails")
|
t.Fatal("expected used_at to remain empty when consume fails")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
package bind_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"log/slog"
|
|
||||||
"os"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/v5"
|
|
||||||
"github.com/jackc/pgx/v5/pgtype"
|
|
||||||
"github.com/jackc/pgx/v5/pgxpool"
|
|
||||||
|
|
||||||
"github.com/memohai/memoh/internal/channelidentities"
|
|
||||||
"github.com/memohai/memoh/internal/bind"
|
|
||||||
"github.com/memohai/memoh/internal/db"
|
|
||||||
"github.com/memohai/memoh/internal/db/sqlc"
|
|
||||||
)
|
|
||||||
|
|
||||||
func setupBindIntegrationTest(t *testing.T) (*sqlc.Queries, *channelidentities.Service, *bind.Service, func()) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
dsn := os.Getenv("TEST_POSTGRES_DSN")
|
|
||||||
if dsn == "" {
|
|
||||||
t.Skip("skip integration test: TEST_POSTGRES_DSN is not set")
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
pool, err := pgxpool.New(ctx, dsn)
|
|
||||||
if err != nil {
|
|
||||||
t.Skipf("skip integration test: cannot connect to database: %v", err)
|
|
||||||
}
|
|
||||||
if err := pool.Ping(ctx); err != nil {
|
|
||||||
pool.Close()
|
|
||||||
t.Skipf("skip integration test: database ping failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
queries := sqlc.New(pool)
|
|
||||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
|
||||||
channelIdentitySvc := channelidentities.NewService(logger, queries)
|
|
||||||
bindSvc := bind.NewService(logger, pool, queries)
|
|
||||||
|
|
||||||
cleanup := func() {
|
|
||||||
pool.Close()
|
|
||||||
}
|
|
||||||
return queries, channelIdentitySvc, bindSvc, cleanup
|
|
||||||
}
|
|
||||||
|
|
||||||
func createBotForBindTest(ctx context.Context, queries *sqlc.Queries, ownerChannelIdentityID string) (string, error) {
|
|
||||||
pgOwnerID, err := db.ParseUUID(ownerChannelIdentityID)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
meta, _ := json.Marshal(map[string]any{"source": "bind-integration-test"})
|
|
||||||
row, err := queries.CreateBot(ctx, sqlc.CreateBotParams{
|
|
||||||
OwnerChannelIdentityID: pgOwnerID,
|
|
||||||
Type: "personal",
|
|
||||||
DisplayName: pgtype.Text{String: "bind-test-bot", Valid: true},
|
|
||||||
AvatarUrl: pgtype.Text{},
|
|
||||||
IsActive: true,
|
|
||||||
Metadata: meta,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return db.UUIDToString(row.ID), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func createChatForBindTest(ctx context.Context, queries *sqlc.Queries, botID, channelIdentityID string) (string, error) {
|
|
||||||
pgBotID, err := db.ParseUUID(botID)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
pgChannelIdentityID, err := db.ParseUUID(channelIdentityID)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
row, err := queries.CreateChat(ctx, sqlc.CreateChatParams{
|
|
||||||
BotID: pgBotID,
|
|
||||||
Kind: "direct",
|
|
||||||
ParentChatID: pgtype.UUID{},
|
|
||||||
Title: pgtype.Text{},
|
|
||||||
CreatedBy: pgChannelIdentityID,
|
|
||||||
Metadata: []byte("{}"),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return db.UUIDToString(row.ID), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIntegrationConsumeBindCodeSuccessAndSingleUse(t *testing.T) {
|
|
||||||
queries, channelIdentitySvc, bindSvc, cleanup := setupBindIntegrationTest(t)
|
|
||||||
defer cleanup()
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
human, err := channelIdentitySvc.Create(ctx, channelidentities.KindHuman)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("create human failed: %v", err)
|
|
||||||
}
|
|
||||||
shadow, err := channelIdentitySvc.Create(ctx, channelidentities.KindShadow)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("create shadow failed: %v", err)
|
|
||||||
}
|
|
||||||
botID, err := createBotForBindTest(ctx, queries, human.ID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("create bot failed: %v", err)
|
|
||||||
}
|
|
||||||
chatID, err := createChatForBindTest(ctx, queries, botID, human.ID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("create chat failed: %v", err)
|
|
||||||
}
|
|
||||||
pgChatID, err := db.ParseUUID(chatID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("parse chat id failed: %v", err)
|
|
||||||
}
|
|
||||||
pgShadowID, err := db.ParseUUID(shadow.ID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("parse shadow id failed: %v", err)
|
|
||||||
}
|
|
||||||
pgHumanID, err := db.ParseUUID(human.ID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("parse human id failed: %v", err)
|
|
||||||
}
|
|
||||||
if _, err := queries.AddChatParticipant(ctx, sqlc.AddChatParticipantParams{
|
|
||||||
ChatID: pgChatID,
|
|
||||||
ChannelIdentityID: pgShadowID,
|
|
||||||
Role: "member",
|
|
||||||
}); err != nil {
|
|
||||||
t.Fatalf("add shadow participant failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
code, err := bindSvc.Issue(ctx, botID, human.ID, 10*time.Minute)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("issue bind code failed: %v", err)
|
|
||||||
}
|
|
||||||
if err := bindSvc.Consume(ctx, code, shadow.ID); err != nil {
|
|
||||||
t.Fatalf("consume bind code failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
after, err := bindSvc.Get(ctx, code.Token)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("get bind code failed: %v", err)
|
|
||||||
}
|
|
||||||
if after.UsedAt.IsZero() {
|
|
||||||
t.Fatal("expected used_at to be set after successful consume")
|
|
||||||
}
|
|
||||||
if after.UsedByChannelIdentityID != shadow.ID {
|
|
||||||
t.Fatalf("expected used_by_channel_identity_id=%s, got %s", shadow.ID, after.UsedByChannelIdentityID)
|
|
||||||
}
|
|
||||||
|
|
||||||
canonical, err := channelIdentitySvc.Canonicalize(ctx, shadow.ID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("canonicalize failed: %v", err)
|
|
||||||
}
|
|
||||||
if canonical != human.ID {
|
|
||||||
t.Fatalf("expected canonical=%s, got %s", human.ID, canonical)
|
|
||||||
}
|
|
||||||
if _, err := queries.GetChatParticipant(ctx, sqlc.GetChatParticipantParams{
|
|
||||||
ChatID: pgChatID,
|
|
||||||
ChannelIdentityID: pgHumanID,
|
|
||||||
}); err != nil {
|
|
||||||
t.Fatalf("expected human participant after bind, got error: %v", err)
|
|
||||||
}
|
|
||||||
if _, err := queries.GetChatParticipant(ctx, sqlc.GetChatParticipantParams{
|
|
||||||
ChatID: pgChatID,
|
|
||||||
ChannelIdentityID: pgShadowID,
|
|
||||||
}); !errors.Is(err, pgx.ErrNoRows) {
|
|
||||||
t.Fatalf("expected shadow participant removed after bind, got %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := bindSvc.Consume(ctx, code, shadow.ID); !errors.Is(err, bind.ErrCodeUsed) {
|
|
||||||
t.Fatalf("expected ErrCodeUsed on second consume, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIntegrationConsumeBindCodeRollbackOnLinkConflict(t *testing.T) {
|
|
||||||
queries, channelIdentitySvc, bindSvc, cleanup := setupBindIntegrationTest(t)
|
|
||||||
defer cleanup()
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
humanA, err := channelIdentitySvc.Create(ctx, channelidentities.KindHuman)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("create humanA failed: %v", err)
|
|
||||||
}
|
|
||||||
humanB, err := channelIdentitySvc.Create(ctx, channelidentities.KindHuman)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("create humanB failed: %v", err)
|
|
||||||
}
|
|
||||||
shadow, err := channelIdentitySvc.Create(ctx, channelidentities.KindShadow)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("create shadow failed: %v", err)
|
|
||||||
}
|
|
||||||
botID, err := createBotForBindTest(ctx, queries, humanA.ID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("create bot failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pre-link shadow to another user so bind consume hits link conflict.
|
|
||||||
if err := channelIdentitySvc.LinkChannelIdentityToUser(ctx, shadow.ID, humanB.ID); err != nil {
|
|
||||||
t.Fatalf("pre link shadow->humanB failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
code, err := bindSvc.Issue(ctx, botID, humanA.ID, 10*time.Minute)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("issue bind code failed: %v", err)
|
|
||||||
}
|
|
||||||
if err := bindSvc.Consume(ctx, code, shadow.ID); !errors.Is(err, bind.ErrLinkConflict) {
|
|
||||||
t.Fatalf("expected ErrLinkConflict, got %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
after, err := bindSvc.Get(ctx, code.Token)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("get bind code failed: %v", err)
|
|
||||||
}
|
|
||||||
if !after.UsedAt.IsZero() {
|
|
||||||
t.Fatal("expected used_at to remain empty when consume fails")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -61,7 +60,10 @@ func createBotForBind(ctx context.Context, queries *sqlc.Queries, ownerUserID st
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
meta, _ := json.Marshal(map[string]any{"source": "bind-integration-test"})
|
meta, err := json.Marshal(map[string]any{"source": "bind-integration-test"})
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
row, err := queries.CreateBot(ctx, sqlc.CreateBotParams{
|
row, err := queries.CreateBot(ctx, sqlc.CreateBotParams{
|
||||||
OwnerUserID: pgOwnerID,
|
OwnerUserID: pgOwnerID,
|
||||||
Type: "personal",
|
Type: "personal",
|
||||||
@@ -75,14 +77,6 @@ func createBotForBind(ctx context.Context, queries *sqlc.Queries, ownerUserID st
|
|||||||
return db.UUIDToString(row.ID), nil
|
return db.UUIDToString(row.ID), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func isLegacyBindSchemaError(err error) bool {
|
|
||||||
if err == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
msg := strings.ToLower(err.Error())
|
|
||||||
return strings.Contains(msg, "relation \"users\" does not exist")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBindConsumeLinksChannelIdentityToIssuerUser(t *testing.T) {
|
func TestBindConsumeLinksChannelIdentityToIssuerUser(t *testing.T) {
|
||||||
queries, channelIdentitySvc, bindSvc, cleanup := setupBindLinkIntegrationTest(t)
|
queries, channelIdentitySvc, bindSvc, cleanup := setupBindLinkIntegrationTest(t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
@@ -90,9 +84,6 @@ func TestBindConsumeLinksChannelIdentityToIssuerUser(t *testing.T) {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
ownerUserID, err := createUserForBind(ctx, queries)
|
ownerUserID, err := createUserForBind(ctx, queries)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if isLegacyBindSchemaError(err) {
|
|
||||||
t.Skipf("skip integration test on legacy schema: %v", err)
|
|
||||||
}
|
|
||||||
t.Fatalf("create owner user failed: %v", err)
|
t.Fatalf("create owner user failed: %v", err)
|
||||||
}
|
}
|
||||||
sourceChannelIdentity, err := channelIdentitySvc.ResolveByChannelIdentity(ctx, "feishu", fmt.Sprintf("bind-src-%d", time.Now().UnixNano()), "source")
|
sourceChannelIdentity, err := channelIdentitySvc.ResolveByChannelIdentity(ctx, "feishu", fmt.Sprintf("bind-src-%d", time.Now().UnixNano()), "source")
|
||||||
@@ -134,9 +125,6 @@ func TestBindConsumeConflictDoesNotMarkUsed(t *testing.T) {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
issuerUserID, err := createUserForBind(ctx, queries)
|
issuerUserID, err := createUserForBind(ctx, queries)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if isLegacyBindSchemaError(err) {
|
|
||||||
t.Skipf("skip integration test on legacy schema: %v", err)
|
|
||||||
}
|
|
||||||
t.Fatalf("create issuer user failed: %v", err)
|
t.Fatalf("create issuer user failed: %v", err)
|
||||||
}
|
}
|
||||||
otherUserID, err := createUserForBind(ctx, queries)
|
otherUserID, err := createUserForBind(ctx, queries)
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -43,16 +42,6 @@ func formatUUID(bytes [16]byte) string {
|
|||||||
return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x", bytes[0:4], bytes[4:6], bytes[6:8], bytes[8:10], bytes[10:16])
|
return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x", bytes[0:4], bytes[4:6], bytes[6:8], bytes[8:10], bytes[10:16])
|
||||||
}
|
}
|
||||||
|
|
||||||
func isLegacyChannelIdentitySchemaError(err error) bool {
|
|
||||||
if err == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
msg := strings.ToLower(err.Error())
|
|
||||||
return strings.Contains(msg, "channelidentities_kind_check") ||
|
|
||||||
strings.Contains(msg, "column \"user_id\" of relation \"channelidentities\" does not exist") ||
|
|
||||||
strings.Contains(msg, "column \"channel_subject_id\" of relation \"channelidentities\" does not exist")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestChannelIdentityResolveChannelIdentityStable(t *testing.T) {
|
func TestChannelIdentityResolveChannelIdentityStable(t *testing.T) {
|
||||||
svc, _, cleanup := setupChannelIdentityIdentityIntegrationTest(t)
|
svc, _, cleanup := setupChannelIdentityIdentityIntegrationTest(t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
@@ -61,9 +50,6 @@ func TestChannelIdentityResolveChannelIdentityStable(t *testing.T) {
|
|||||||
externalID := fmt.Sprintf("stable_%d", time.Now().UnixNano())
|
externalID := fmt.Sprintf("stable_%d", time.Now().UnixNano())
|
||||||
first, err := svc.ResolveByChannelIdentity(ctx, "feishu", externalID, "first")
|
first, err := svc.ResolveByChannelIdentity(ctx, "feishu", externalID, "first")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if isLegacyChannelIdentitySchemaError(err) {
|
|
||||||
t.Skipf("skip integration test on legacy schema: %v", err)
|
|
||||||
}
|
|
||||||
t.Fatalf("first resolve failed: %v", err)
|
t.Fatalf("first resolve failed: %v", err)
|
||||||
}
|
}
|
||||||
second, err := svc.ResolveByChannelIdentity(ctx, "feishu", externalID, "second")
|
second, err := svc.ResolveByChannelIdentity(ctx, "feishu", externalID, "second")
|
||||||
@@ -82,9 +68,6 @@ func TestChannelIdentityLinkToUser(t *testing.T) {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
channelIdentity, err := svc.ResolveByChannelIdentity(ctx, "telegram", fmt.Sprintf("link_%d", time.Now().UnixNano()), "tg")
|
channelIdentity, err := svc.ResolveByChannelIdentity(ctx, "telegram", fmt.Sprintf("link_%d", time.Now().UnixNano()), "tg")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if isLegacyChannelIdentitySchemaError(err) {
|
|
||||||
t.Skipf("skip integration test on legacy schema: %v", err)
|
|
||||||
}
|
|
||||||
t.Fatalf("resolve channelIdentity failed: %v", err)
|
t.Fatalf("resolve channelIdentity failed: %v", err)
|
||||||
}
|
}
|
||||||
user, err := queries.CreateUser(ctx, sqlc.CreateUserParams{
|
user, err := queries.CreateUser(ctx, sqlc.CreateUserParams{
|
||||||
|
|||||||
+21
-19
@@ -208,7 +208,7 @@ func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContex
|
|||||||
r.enforceGroupMemoryPolicy(ctx, req.ChatID, &chatSettings)
|
r.enforceGroupMemoryPolicy(ctx, req.ChatID, &chatSettings)
|
||||||
}
|
}
|
||||||
|
|
||||||
userSettings, err := r.loadUserSettings(ctx, req.ChannelIdentityID)
|
userSettings, err := r.loadUserSettings(ctx, req.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return resolvedContext{}, err
|
return resolvedContext{}, err
|
||||||
}
|
}
|
||||||
@@ -297,7 +297,7 @@ func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContex
|
|||||||
BotID: req.BotID,
|
BotID: req.BotID,
|
||||||
SessionID: req.ChatID,
|
SessionID: req.ChatID,
|
||||||
ContainerID: containerID,
|
ContainerID: containerID,
|
||||||
ChannelIdentityID: firstNonEmpty(req.ChannelIdentityID, req.BotID),
|
ChannelIdentityID: firstNonEmpty(req.SourceChannelIdentityID, req.UserID),
|
||||||
DisplayName: firstNonEmpty(req.DisplayName, "User"),
|
DisplayName: firstNonEmpty(req.DisplayName, "User"),
|
||||||
CurrentPlatform: req.CurrentChannel,
|
CurrentPlatform: req.CurrentChannel,
|
||||||
ReplyTarget: "",
|
ReplyTarget: "",
|
||||||
@@ -348,11 +348,11 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload sc
|
|||||||
chatID = "schedule-" + payload.ID
|
chatID = "schedule-" + payload.ID
|
||||||
}
|
}
|
||||||
req := ChatRequest{
|
req := ChatRequest{
|
||||||
BotID: botID,
|
BotID: botID,
|
||||||
ChatID: chatID,
|
ChatID: chatID,
|
||||||
Query: payload.Command,
|
Query: payload.Command,
|
||||||
ChannelIdentityID: payload.OwnerUserID,
|
UserID: payload.OwnerUserID,
|
||||||
Token: token,
|
Token: token,
|
||||||
}
|
}
|
||||||
rc, err := r.resolve(ctx, req)
|
rc, err := r.resolve(ctx, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -373,7 +373,7 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload sc
|
|||||||
BotID: rc.payload.Identity.BotID,
|
BotID: rc.payload.Identity.BotID,
|
||||||
SessionID: rc.payload.Identity.SessionID,
|
SessionID: rc.payload.Identity.SessionID,
|
||||||
ContainerID: rc.payload.Identity.ContainerID,
|
ContainerID: rc.payload.Identity.ContainerID,
|
||||||
ChannelIdentityID: firstNonEmpty(payload.OwnerUserID, botID),
|
ChannelIdentityID: strings.TrimSpace(payload.OwnerUserID),
|
||||||
DisplayName: "Scheduler",
|
DisplayName: "Scheduler",
|
||||||
},
|
},
|
||||||
Attachments: rc.payload.Attachments,
|
Attachments: rc.payload.Attachments,
|
||||||
@@ -676,8 +676,8 @@ func (r *Resolver) loadMemoryContextMessage(ctx context.Context, req ChatRequest
|
|||||||
if settings.EnableChatMemory {
|
if settings.EnableChatMemory {
|
||||||
scopes = append(scopes, memoryScope{Namespace: "chat", ScopeID: req.ChatID})
|
scopes = append(scopes, memoryScope{Namespace: "chat", ScopeID: req.ChatID})
|
||||||
}
|
}
|
||||||
if settings.EnablePrivateMemory && strings.TrimSpace(req.ChannelIdentityID) != "" {
|
if settings.EnablePrivateMemory && strings.TrimSpace(req.UserID) != "" {
|
||||||
scopes = append(scopes, memoryScope{Namespace: "private", ScopeID: req.ChannelIdentityID})
|
scopes = append(scopes, memoryScope{Namespace: "private", ScopeID: req.UserID})
|
||||||
}
|
}
|
||||||
if settings.EnablePublicMemory {
|
if settings.EnablePublicMemory {
|
||||||
scopes = append(scopes, memoryScope{Namespace: "public", ScopeID: req.BotID})
|
scopes = append(scopes, memoryScope{Namespace: "public", ScopeID: req.BotID})
|
||||||
@@ -778,7 +778,7 @@ func (r *Resolver) storeRound(ctx context.Context, req ChatRequest, messages []M
|
|||||||
fullRound = append(fullRound, messages...)
|
fullRound = append(fullRound, messages...)
|
||||||
|
|
||||||
r.storeMessages(ctx, req, fullRound)
|
r.storeMessages(ctx, req, fullRound)
|
||||||
r.storeMemory(ctx, req.BotID, req.ChatID, req.ChannelIdentityID, req.Query, fullRound)
|
r.storeMemory(ctx, req.BotID, req.ChatID, req.UserID, req.Query, fullRound)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -805,10 +805,12 @@ func (r *Resolver) storeMessages(ctx context.Context, req ChatRequest, messages
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
senderID := ""
|
senderChannelIdentityID := ""
|
||||||
|
senderUserID := ""
|
||||||
externalMessageID := ""
|
externalMessageID := ""
|
||||||
if msg.Role == "user" {
|
if msg.Role == "user" {
|
||||||
senderID = req.ChannelIdentityID
|
senderChannelIdentityID = req.SourceChannelIdentityID
|
||||||
|
senderUserID = req.UserID
|
||||||
externalMessageID = req.ExternalMessageID
|
externalMessageID = req.ExternalMessageID
|
||||||
}
|
}
|
||||||
if _, err := r.chatService.PersistMessage(
|
if _, err := r.chatService.PersistMessage(
|
||||||
@@ -816,8 +818,8 @@ func (r *Resolver) storeMessages(ctx context.Context, req ChatRequest, messages
|
|||||||
req.ChatID,
|
req.ChatID,
|
||||||
req.BotID,
|
req.BotID,
|
||||||
req.RouteID,
|
req.RouteID,
|
||||||
"",
|
senderChannelIdentityID,
|
||||||
senderID,
|
senderUserID,
|
||||||
req.CurrentChannel,
|
req.CurrentChannel,
|
||||||
externalMessageID,
|
externalMessageID,
|
||||||
msg.Role,
|
msg.Role,
|
||||||
@@ -829,7 +831,7 @@ func (r *Resolver) storeMessages(ctx context.Context, req ChatRequest, messages
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Resolver) storeMemory(ctx context.Context, botID, chatID, channelIdentityID, query string, messages []ModelMessage) {
|
func (r *Resolver) storeMemory(ctx context.Context, botID, chatID, userID, query string, messages []ModelMessage) {
|
||||||
if r.memoryService == nil {
|
if r.memoryService == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -869,9 +871,9 @@ func (r *Resolver) storeMemory(ctx context.Context, botID, chatID, channelIdenti
|
|||||||
r.addMemory(ctx, botID, memMsgs, "chat", chatID)
|
r.addMemory(ctx, botID, memMsgs, "chat", chatID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write to private namespace if enabled and channel identity is known.
|
// Write to private namespace if enabled and user id is known.
|
||||||
if cs.EnablePrivateMemory && strings.TrimSpace(channelIdentityID) != "" {
|
if cs.EnablePrivateMemory && strings.TrimSpace(userID) != "" {
|
||||||
r.addMemory(ctx, botID, memMsgs, "private", channelIdentityID)
|
r.addMemory(ctx, botID, memMsgs, "private", userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write to public namespace if enabled.
|
// Write to public namespace if enabled.
|
||||||
|
|||||||
+34
-13
@@ -70,7 +70,10 @@ func (s *Service) Create(ctx context.Context, botID, channelIdentityID string, r
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
metadata, _ := json.Marshal(nonNilMap(req.Metadata))
|
metadata, err := json.Marshal(nonNilMap(req.Metadata))
|
||||||
|
if err != nil {
|
||||||
|
return Chat{}, fmt.Errorf("marshal chat metadata: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
row, err := s.queries.CreateChat(ctx, sqlc.CreateChatParams{
|
row, err := s.queries.CreateChat(ctx, sqlc.CreateChatParams{
|
||||||
BotID: pgBotID,
|
BotID: pgBotID,
|
||||||
@@ -393,7 +396,10 @@ func (s *Service) CreateRoute(ctx context.Context, chatID string, r Route) (Rout
|
|||||||
return Route{}, err
|
return Route{}, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
metadata, _ := json.Marshal(nonNilMap(r.Metadata))
|
metadata, err := json.Marshal(nonNilMap(r.Metadata))
|
||||||
|
if err != nil {
|
||||||
|
return Route{}, fmt.Errorf("marshal route metadata: %w", err)
|
||||||
|
}
|
||||||
row, err := s.queries.CreateChatRoute(ctx, sqlc.CreateChatRouteParams{
|
row, err := s.queries.CreateChatRoute(ctx, sqlc.CreateChatRouteParams{
|
||||||
ChatID: pgChatID,
|
ChatID: pgChatID,
|
||||||
BotID: pgBotID,
|
BotID: pgBotID,
|
||||||
@@ -488,7 +494,10 @@ func (s *Service) ResolveChat(ctx context.Context, botID, platform, conversation
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
// Route found, ensure the sender identity is a participant.
|
// Route found, ensure the sender identity is a participant.
|
||||||
if strings.TrimSpace(channelIdentityID) != "" {
|
if strings.TrimSpace(channelIdentityID) != "" {
|
||||||
ok, _ := s.IsParticipant(ctx, route.ChatID, channelIdentityID)
|
ok, checkErr := s.IsParticipant(ctx, route.ChatID, channelIdentityID)
|
||||||
|
if checkErr != nil {
|
||||||
|
return ResolveChatResult{}, fmt.Errorf("check chat participant: %w", checkErr)
|
||||||
|
}
|
||||||
if !ok {
|
if !ok {
|
||||||
if _, err := s.AddParticipant(ctx, route.ChatID, channelIdentityID, RoleMember); err != nil {
|
if _, err := s.AddParticipant(ctx, route.ChatID, channelIdentityID, RoleMember); err != nil {
|
||||||
s.logger.Warn("auto-add participant failed", slog.Any("error", err))
|
s.logger.Warn("auto-add participant failed", slog.Any("error", err))
|
||||||
@@ -497,9 +506,17 @@ func (s *Service) ResolveChat(ctx context.Context, botID, platform, conversation
|
|||||||
}
|
}
|
||||||
// Update reply target if changed.
|
// Update reply target if changed.
|
||||||
if strings.TrimSpace(replyTarget) != "" && replyTarget != route.ReplyTarget {
|
if strings.TrimSpace(replyTarget) != "" && replyTarget != route.ReplyTarget {
|
||||||
_ = s.UpdateRouteReplyTarget(ctx, route.ID, replyTarget)
|
if err := s.UpdateRouteReplyTarget(ctx, route.ID, replyTarget); err != nil && s.logger != nil {
|
||||||
|
s.logger.Warn("update route reply target failed", slog.Any("error", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pgRouteChatID, parseErr := parseUUID(route.ChatID)
|
||||||
|
if parseErr != nil {
|
||||||
|
return ResolveChatResult{}, fmt.Errorf("parse route chat id: %w", parseErr)
|
||||||
|
}
|
||||||
|
if err := s.queries.TouchChat(ctx, pgRouteChatID); err != nil && s.logger != nil {
|
||||||
|
s.logger.Warn("touch chat failed", slog.Any("error", err))
|
||||||
}
|
}
|
||||||
_ = s.queries.TouchChat(ctx, mustParseUUID(route.ChatID))
|
|
||||||
return ResolveChatResult{ChatID: route.ChatID, RouteID: route.ID, Created: false}, nil
|
return ResolveChatResult{ChatID: route.ChatID, RouteID: route.ID, Created: false}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -564,13 +581,22 @@ func (s *Service) PersistMessage(ctx context.Context, chatID, botID, routeID, se
|
|||||||
}
|
}
|
||||||
var pgSender pgtype.UUID
|
var pgSender pgtype.UUID
|
||||||
if strings.TrimSpace(senderChannelIdentityID) != "" {
|
if strings.TrimSpace(senderChannelIdentityID) != "" {
|
||||||
pgSender, _ = parseUUID(senderChannelIdentityID)
|
pgSender, err = parseUUID(senderChannelIdentityID)
|
||||||
|
if err != nil {
|
||||||
|
return Message{}, fmt.Errorf("invalid sender channel identity id: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
var pgSenderUser pgtype.UUID
|
var pgSenderUser pgtype.UUID
|
||||||
if strings.TrimSpace(senderUserID) != "" {
|
if strings.TrimSpace(senderUserID) != "" {
|
||||||
pgSenderUser, _ = parseUUID(senderUserID)
|
pgSenderUser, err = parseUUID(senderUserID)
|
||||||
|
if err != nil {
|
||||||
|
return Message{}, fmt.Errorf("invalid sender user id: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
metaBytes, err := json.Marshal(nonNilMap(metadata))
|
||||||
|
if err != nil {
|
||||||
|
return Message{}, fmt.Errorf("marshal message metadata: %w", err)
|
||||||
}
|
}
|
||||||
metaBytes, _ := json.Marshal(nonNilMap(metadata))
|
|
||||||
if len(content) == 0 {
|
if len(content) == 0 {
|
||||||
content = []byte("{}")
|
content = []byte("{}")
|
||||||
}
|
}
|
||||||
@@ -813,11 +839,6 @@ func pgTimePtr(ts pgtype.Timestamptz) *time.Time {
|
|||||||
return &value
|
return &value
|
||||||
}
|
}
|
||||||
|
|
||||||
func mustParseUUID(id string) pgtype.UUID {
|
|
||||||
pgID, _ := parseUUID(id)
|
|
||||||
return pgID
|
|
||||||
}
|
|
||||||
|
|
||||||
func nonNilMap(m map[string]any) map[string]any {
|
func nonNilMap(m map[string]any) map[string]any {
|
||||||
if m == nil {
|
if m == nil {
|
||||||
return map[string]any{}
|
return map[string]any{}
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -56,17 +55,6 @@ func setupChatPresenceIntegrationTest(t *testing.T) chatPresenceFixture {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func isLegacyChatPresenceSchemaError(err error) bool {
|
|
||||||
if err == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
msg := strings.ToLower(err.Error())
|
|
||||||
return strings.Contains(msg, "relation \"chat_channelIdentity_presence\" does not exist") ||
|
|
||||||
strings.Contains(msg, "column \"user_id\" of relation \"channelidentities\" does not exist") ||
|
|
||||||
strings.Contains(msg, "column \"sender_user_id\" of relation \"chat_messages\" does not exist") ||
|
|
||||||
strings.Contains(msg, "column \"created_by_user_id\" of relation \"chats\" does not exist")
|
|
||||||
}
|
|
||||||
|
|
||||||
func createUserForChatPresence(ctx context.Context, queries *sqlc.Queries) (string, error) {
|
func createUserForChatPresence(ctx context.Context, queries *sqlc.Queries) (string, error) {
|
||||||
row, err := queries.CreateUser(ctx, sqlc.CreateUserParams{
|
row, err := queries.CreateUser(ctx, sqlc.CreateUserParams{
|
||||||
IsActive: true,
|
IsActive: true,
|
||||||
@@ -83,7 +71,10 @@ func createBotForChatPresence(ctx context.Context, queries *sqlc.Queries, ownerU
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
meta, _ := json.Marshal(map[string]any{"source": "chat-presence-integration-test"})
|
meta, err := json.Marshal(map[string]any{"source": "chat-presence-integration-test"})
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
row, err := queries.CreateBot(ctx, sqlc.CreateBotParams{
|
row, err := queries.CreateBot(ctx, sqlc.CreateBotParams{
|
||||||
OwnerUserID: pgOwnerID,
|
OwnerUserID: pgOwnerID,
|
||||||
Type: "personal",
|
Type: "personal",
|
||||||
@@ -105,10 +96,6 @@ func setupObservedChatScenario(t *testing.T) (chatPresenceFixture, string, strin
|
|||||||
|
|
||||||
ownerUserID, err := createUserForChatPresence(ctx, fixture.queries)
|
ownerUserID, err := createUserForChatPresence(ctx, fixture.queries)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if isLegacyChatPresenceSchemaError(err) {
|
|
||||||
fixture.cleanup()
|
|
||||||
t.Skipf("skip integration test on legacy schema: %v", err)
|
|
||||||
}
|
|
||||||
fixture.cleanup()
|
fixture.cleanup()
|
||||||
t.Fatalf("create owner user failed: %v", err)
|
t.Fatalf("create owner user failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -128,10 +115,6 @@ func setupObservedChatScenario(t *testing.T) (chatPresenceFixture, string, strin
|
|||||||
Title: "presence-observed",
|
Title: "presence-observed",
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if isLegacyChatPresenceSchemaError(err) {
|
|
||||||
fixture.cleanup()
|
|
||||||
t.Skipf("skip integration test on legacy schema: %v", err)
|
|
||||||
}
|
|
||||||
fixture.cleanup()
|
fixture.cleanup()
|
||||||
t.Fatalf("create chat failed: %v", err)
|
t.Fatalf("create chat failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -143,10 +126,6 @@ func setupObservedChatScenario(t *testing.T) (chatPresenceFixture, string, strin
|
|||||||
"presence-observer",
|
"presence-observer",
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if isLegacyChatPresenceSchemaError(err) {
|
|
||||||
fixture.cleanup()
|
|
||||||
t.Skipf("skip integration test on legacy schema: %v", err)
|
|
||||||
}
|
|
||||||
fixture.cleanup()
|
fixture.cleanup()
|
||||||
t.Fatalf("resolve channelIdentity failed: %v", err)
|
t.Fatalf("resolve channelIdentity failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -165,10 +144,6 @@ func setupObservedChatScenario(t *testing.T) (chatPresenceFixture, string, strin
|
|||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if isLegacyChatPresenceSchemaError(err) {
|
|
||||||
fixture.cleanup()
|
|
||||||
t.Skipf("skip integration test on legacy schema: %v", err)
|
|
||||||
}
|
|
||||||
fixture.cleanup()
|
fixture.cleanup()
|
||||||
t.Fatalf("persist message failed: %v", err)
|
t.Fatalf("persist message failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
+10
-9
@@ -234,15 +234,16 @@ type ToolCallFunction struct {
|
|||||||
|
|
||||||
// ChatRequest is the input for Chat and StreamChat.
|
// ChatRequest is the input for Chat and StreamChat.
|
||||||
type ChatRequest struct {
|
type ChatRequest struct {
|
||||||
BotID string `json:"-"`
|
BotID string `json:"-"`
|
||||||
ChatID string `json:"-"`
|
ChatID string `json:"-"`
|
||||||
Token string `json:"-"`
|
Token string `json:"-"`
|
||||||
ChannelIdentityID string `json:"-"`
|
UserID string `json:"-"`
|
||||||
ContainerID string `json:"-"`
|
SourceChannelIdentityID string `json:"-"`
|
||||||
DisplayName string `json:"-"`
|
ContainerID string `json:"-"`
|
||||||
RouteID string `json:"-"`
|
DisplayName string `json:"-"`
|
||||||
ChatToken string `json:"-"`
|
RouteID string `json:"-"`
|
||||||
ExternalMessageID string `json:"-"`
|
ChatToken string `json:"-"`
|
||||||
|
ExternalMessageID string `json:"-"`
|
||||||
|
|
||||||
Query string `json:"query"`
|
Query string `json:"query"`
|
||||||
Model string `json:"model,omitempty"`
|
Model string `json:"model,omitempty"`
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ func (h *BindHandler) Issue(c echo.Context) error {
|
|||||||
if h.service == nil {
|
if h.service == nil {
|
||||||
return echo.NewHTTPError(http.StatusServiceUnavailable, "bind service not available")
|
return echo.NewHTTPError(http.StatusServiceUnavailable, "bind service not available")
|
||||||
}
|
}
|
||||||
channelIdentityID, err := h.requireChannelIdentityID(c)
|
userID, err := h.requireUserID(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -68,7 +68,7 @@ func (h *BindHandler) Issue(c echo.Context) error {
|
|||||||
ttl = time.Duration(req.TTLSeconds) * time.Second
|
ttl = time.Duration(req.TTLSeconds) * time.Second
|
||||||
}
|
}
|
||||||
|
|
||||||
code, err := h.service.Issue(c.Request().Context(), channelIdentityID, strings.TrimSpace(req.Platform), ttl)
|
code, err := h.service.Issue(c.Request().Context(), userID, strings.TrimSpace(req.Platform), ttl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||||
}
|
}
|
||||||
@@ -79,13 +79,13 @@ func (h *BindHandler) Issue(c echo.Context) error {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *BindHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
func (h *BindHandler) requireUserID(c echo.Context) (string, error) {
|
||||||
channelIdentityID, err := auth.ChannelIdentityIDFromContext(c)
|
userID, err := auth.UserIDFromContext(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil {
|
if err := identity.ValidateChannelIdentityID(userID); err != nil {
|
||||||
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||||
}
|
}
|
||||||
return channelIdentityID, nil
|
return userID, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -161,7 +161,7 @@ func (h *ChannelHandler) GetChannel(c echo.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *ChannelHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
func (h *ChannelHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
||||||
channelIdentityID, err := auth.ChannelIdentityIDFromContext(c)
|
channelIdentityID, err := auth.UserIDFromContext(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -189,7 +189,8 @@ func (h *ChatHandler) SendMessage(c echo.Context) error {
|
|||||||
req.BotID = chatObj.BotID
|
req.BotID = chatObj.BotID
|
||||||
req.ChatID = chatID
|
req.ChatID = chatID
|
||||||
req.Token = c.Request().Header.Get("Authorization")
|
req.Token = c.Request().Header.Get("Authorization")
|
||||||
req.ChannelIdentityID = channelIdentityID
|
req.UserID = channelIdentityID
|
||||||
|
req.SourceChannelIdentityID = channelIdentityID
|
||||||
|
|
||||||
resp, err := h.resolver.Chat(c.Request().Context(), req)
|
resp, err := h.resolver.Chat(c.Request().Context(), req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -224,7 +225,8 @@ func (h *ChatHandler) StreamMessage(c echo.Context) error {
|
|||||||
req.BotID = chatObj.BotID
|
req.BotID = chatObj.BotID
|
||||||
req.ChatID = chatID
|
req.ChatID = chatID
|
||||||
req.Token = c.Request().Header.Get("Authorization")
|
req.Token = c.Request().Header.Get("Authorization")
|
||||||
req.ChannelIdentityID = channelIdentityID
|
req.UserID = channelIdentityID
|
||||||
|
req.SourceChannelIdentityID = channelIdentityID
|
||||||
|
|
||||||
c.Response().Header().Set(echo.HeaderContentType, "text/event-stream")
|
c.Response().Header().Set(echo.HeaderContentType, "text/event-stream")
|
||||||
c.Response().Header().Set(echo.HeaderCacheControl, "no-cache")
|
c.Response().Header().Set(echo.HeaderCacheControl, "no-cache")
|
||||||
@@ -258,7 +260,10 @@ func (h *ChatHandler) StreamMessage(c echo.Context) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Error("chat stream failed", slog.Any("error", err))
|
h.logger.Error("chat stream failed", slog.Any("error", err))
|
||||||
errData := map[string]string{"error": err.Error()}
|
errData := map[string]string{"error": err.Error()}
|
||||||
data, _ := json.Marshal(errData)
|
data, marshalErr := json.Marshal(errData)
|
||||||
|
if marshalErr != nil {
|
||||||
|
return echo.NewHTTPError(http.StatusInternalServerError, marshalErr.Error())
|
||||||
|
}
|
||||||
writer.WriteString(fmt.Sprintf("data: %s\n\n", string(data)))
|
writer.WriteString(fmt.Sprintf("data: %s\n\n", string(data)))
|
||||||
writer.Flush()
|
writer.Flush()
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
@@ -500,7 +505,7 @@ func (h *ChatHandler) ListThreads(c echo.Context) error {
|
|||||||
// --- helpers ---
|
// --- helpers ---
|
||||||
|
|
||||||
func (h *ChatHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
func (h *ChatHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
||||||
channelIdentityID, err := auth.ChannelIdentityIDFromContext(c)
|
channelIdentityID, err := auth.UserIDFromContext(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@@ -558,7 +563,10 @@ func (h *ChatHandler) requireParticipant(ctx context.Context, chatID, channelIde
|
|||||||
}
|
}
|
||||||
// Admin bypass.
|
// Admin bypass.
|
||||||
if h.accountService != nil {
|
if h.accountService != nil {
|
||||||
isAdmin, _ := h.accountService.IsAdmin(ctx, channelIdentityID)
|
isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID)
|
||||||
|
if err != nil {
|
||||||
|
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||||
|
}
|
||||||
if isAdmin {
|
if isAdmin {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -579,7 +587,10 @@ func (h *ChatHandler) requireReadable(ctx context.Context, chatID, channelIdenti
|
|||||||
}
|
}
|
||||||
// Admin bypass.
|
// Admin bypass.
|
||||||
if h.accountService != nil {
|
if h.accountService != nil {
|
||||||
isAdmin, _ := h.accountService.IsAdmin(ctx, channelIdentityID)
|
isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID)
|
||||||
|
if err != nil {
|
||||||
|
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||||
|
}
|
||||||
if isAdmin {
|
if isAdmin {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -600,7 +611,10 @@ func (h *ChatHandler) requireRole(ctx context.Context, chatID, channelIdentityID
|
|||||||
}
|
}
|
||||||
// Admin bypass.
|
// Admin bypass.
|
||||||
if h.accountService != nil {
|
if h.accountService != nil {
|
||||||
isAdmin, _ := h.accountService.IsAdmin(ctx, channelIdentityID)
|
isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID)
|
||||||
|
if err != nil {
|
||||||
|
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||||
|
}
|
||||||
if isAdmin {
|
if isAdmin {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -627,7 +627,7 @@ func (h *ContainerdHandler) requireBotAccess(c echo.Context) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *ContainerdHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
func (h *ContainerdHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
||||||
channelIdentityID, err := auth.ChannelIdentityIDFromContext(c)
|
channelIdentityID, err := auth.UserIDFromContext(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -85,10 +85,10 @@ func (h *EmbeddingsHandler) Embed(c echo.Context) error {
|
|||||||
req.Input.ImageURL = strings.TrimSpace(req.Input.ImageURL)
|
req.Input.ImageURL = strings.TrimSpace(req.Input.ImageURL)
|
||||||
req.Input.VideoURL = strings.TrimSpace(req.Input.VideoURL)
|
req.Input.VideoURL = strings.TrimSpace(req.Input.VideoURL)
|
||||||
|
|
||||||
channelIdentityID := ""
|
userID := ""
|
||||||
if c.Get("user") != nil {
|
if c.Get("user") != nil {
|
||||||
if value, err := auth.ChannelIdentityIDFromContext(c); err == nil {
|
if value, err := auth.UserIDFromContext(c); err == nil {
|
||||||
channelIdentityID = value
|
userID = value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
result, err := h.resolver.Embed(c.Request().Context(), embeddings.Request{
|
result, err := h.resolver.Embed(c.Request().Context(), embeddings.Request{
|
||||||
@@ -101,7 +101,7 @@ func (h *EmbeddingsHandler) Embed(c echo.Context) error {
|
|||||||
ImageURL: req.Input.ImageURL,
|
ImageURL: req.Input.ImageURL,
|
||||||
VideoURL: req.Input.VideoURL,
|
VideoURL: req.Input.VideoURL,
|
||||||
},
|
},
|
||||||
ChannelIdentityID: channelIdentityID,
|
ChannelIdentityID: userID,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
message := err.Error()
|
message := err.Error()
|
||||||
|
|||||||
@@ -228,7 +228,7 @@ func (h *LocalChannelHandler) ensureChatParticipant(ctx context.Context, chatID,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *LocalChannelHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
func (h *LocalChannelHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
||||||
channelIdentityID, err := auth.ChannelIdentityIDFromContext(c)
|
channelIdentityID, err := auth.UserIDFromContext(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -218,7 +218,7 @@ func (h *MCPHandler) Delete(c echo.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *MCPHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
func (h *MCPHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
||||||
userID, err := auth.ChannelIdentityIDFromContext(c)
|
userID, err := auth.UserIDFromContext(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -364,7 +364,10 @@ func (h *MemoryHandler) requireChatParticipant(ctx context.Context, chatID, chan
|
|||||||
return echo.NewHTTPError(http.StatusInternalServerError, "chat service not configured")
|
return echo.NewHTTPError(http.StatusInternalServerError, "chat service not configured")
|
||||||
}
|
}
|
||||||
if h.accountService != nil {
|
if h.accountService != nil {
|
||||||
isAdmin, _ := h.accountService.IsAdmin(ctx, channelIdentityID)
|
isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID)
|
||||||
|
if err != nil {
|
||||||
|
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||||
|
}
|
||||||
if isAdmin {
|
if isAdmin {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -380,7 +383,7 @@ func (h *MemoryHandler) requireChatParticipant(ctx context.Context, chatID, chan
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *MemoryHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
func (h *MemoryHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
||||||
channelIdentityID, err := auth.ChannelIdentityIDFromContext(c)
|
channelIdentityID, err := auth.UserIDFromContext(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -164,7 +164,7 @@ func (h *ModelsHandler) Enable(c echo.Context) error {
|
|||||||
if h.settingsService == nil {
|
if h.settingsService == nil {
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, "settings service not configured")
|
return echo.NewHTTPError(http.StatusInternalServerError, "settings service not configured")
|
||||||
}
|
}
|
||||||
userID, err := auth.ChannelIdentityIDFromContext(c)
|
userID, err := auth.UserIDFromContext(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ type preauthIssueRequest struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *PreauthHandler) Issue(c echo.Context) error {
|
func (h *PreauthHandler) Issue(c echo.Context) error {
|
||||||
channelIdentityID, err := h.requireChannelIdentityID(c)
|
userID, err := h.requireUserID(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -48,7 +48,7 @@ func (h *PreauthHandler) Issue(c echo.Context) error {
|
|||||||
if botID == "" {
|
if botID == "" {
|
||||||
return echo.NewHTTPError(http.StatusBadRequest, "bot id is required")
|
return echo.NewHTTPError(http.StatusBadRequest, "bot id is required")
|
||||||
}
|
}
|
||||||
if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil {
|
if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
var req preauthIssueRequest
|
var req preauthIssueRequest
|
||||||
@@ -59,33 +59,33 @@ func (h *PreauthHandler) Issue(c echo.Context) error {
|
|||||||
if req.TTLSeconds > 0 {
|
if req.TTLSeconds > 0 {
|
||||||
ttl = time.Duration(req.TTLSeconds) * time.Second
|
ttl = time.Duration(req.TTLSeconds) * time.Second
|
||||||
}
|
}
|
||||||
key, err := h.service.Issue(c.Request().Context(), botID, channelIdentityID, ttl)
|
key, err := h.service.Issue(c.Request().Context(), botID, userID, ttl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||||
}
|
}
|
||||||
return c.JSON(http.StatusOK, key)
|
return c.JSON(http.StatusOK, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *PreauthHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
func (h *PreauthHandler) requireUserID(c echo.Context) (string, error) {
|
||||||
channelIdentityID, err := auth.ChannelIdentityIDFromContext(c)
|
userID, err := auth.UserIDFromContext(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil {
|
if err := identity.ValidateChannelIdentityID(userID); err != nil {
|
||||||
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||||
}
|
}
|
||||||
return channelIdentityID, nil
|
return userID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *PreauthHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
|
func (h *PreauthHandler) authorizeBotAccess(ctx context.Context, userID, botID string) (bots.Bot, error) {
|
||||||
if h.botService == nil || h.accountService == nil {
|
if h.botService == nil || h.accountService == nil {
|
||||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured")
|
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured")
|
||||||
}
|
}
|
||||||
isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID)
|
isAdmin, err := h.accountService.IsAdmin(ctx, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||||
}
|
}
|
||||||
bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false})
|
bot, err := h.botService.AuthorizeAccess(ctx, userID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, bots.ErrBotNotFound) {
|
if errors.Is(err, bots.ErrBotNotFound) {
|
||||||
return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found")
|
return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found")
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ func (h *ScheduleHandler) Register(e *echo.Echo) {
|
|||||||
// @Failure 500 {object} ErrorResponse
|
// @Failure 500 {object} ErrorResponse
|
||||||
// @Router /bots/{bot_id}/schedule [post]
|
// @Router /bots/{bot_id}/schedule [post]
|
||||||
func (h *ScheduleHandler) Create(c echo.Context) error {
|
func (h *ScheduleHandler) Create(c echo.Context) error {
|
||||||
channelIdentityID, err := h.requireChannelIdentityID(c)
|
userID, err := h.requireUserID(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -59,7 +59,7 @@ func (h *ScheduleHandler) Create(c echo.Context) error {
|
|||||||
if botID == "" {
|
if botID == "" {
|
||||||
return echo.NewHTTPError(http.StatusBadRequest, "bot id is required")
|
return echo.NewHTTPError(http.StatusBadRequest, "bot id is required")
|
||||||
}
|
}
|
||||||
if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil {
|
if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
var req schedule.CreateRequest
|
var req schedule.CreateRequest
|
||||||
@@ -82,7 +82,7 @@ func (h *ScheduleHandler) Create(c echo.Context) error {
|
|||||||
// @Failure 500 {object} ErrorResponse
|
// @Failure 500 {object} ErrorResponse
|
||||||
// @Router /bots/{bot_id}/schedule [get]
|
// @Router /bots/{bot_id}/schedule [get]
|
||||||
func (h *ScheduleHandler) List(c echo.Context) error {
|
func (h *ScheduleHandler) List(c echo.Context) error {
|
||||||
channelIdentityID, err := h.requireChannelIdentityID(c)
|
userID, err := h.requireUserID(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -90,7 +90,7 @@ func (h *ScheduleHandler) List(c echo.Context) error {
|
|||||||
if botID == "" {
|
if botID == "" {
|
||||||
return echo.NewHTTPError(http.StatusBadRequest, "bot id is required")
|
return echo.NewHTTPError(http.StatusBadRequest, "bot id is required")
|
||||||
}
|
}
|
||||||
if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil {
|
if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
items, err := h.service.List(c.Request().Context(), botID)
|
items, err := h.service.List(c.Request().Context(), botID)
|
||||||
@@ -111,7 +111,7 @@ func (h *ScheduleHandler) List(c echo.Context) error {
|
|||||||
// @Failure 500 {object} ErrorResponse
|
// @Failure 500 {object} ErrorResponse
|
||||||
// @Router /bots/{bot_id}/schedule/{id} [get]
|
// @Router /bots/{bot_id}/schedule/{id} [get]
|
||||||
func (h *ScheduleHandler) Get(c echo.Context) error {
|
func (h *ScheduleHandler) Get(c echo.Context) error {
|
||||||
channelIdentityID, err := h.requireChannelIdentityID(c)
|
userID, err := h.requireUserID(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -130,7 +130,7 @@ func (h *ScheduleHandler) Get(c echo.Context) error {
|
|||||||
if item.BotID != botID {
|
if item.BotID != botID {
|
||||||
return echo.NewHTTPError(http.StatusForbidden, "bot mismatch")
|
return echo.NewHTTPError(http.StatusForbidden, "bot mismatch")
|
||||||
}
|
}
|
||||||
if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil {
|
if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return c.JSON(http.StatusOK, item)
|
return c.JSON(http.StatusOK, item)
|
||||||
@@ -147,7 +147,7 @@ func (h *ScheduleHandler) Get(c echo.Context) error {
|
|||||||
// @Failure 500 {object} ErrorResponse
|
// @Failure 500 {object} ErrorResponse
|
||||||
// @Router /bots/{bot_id}/schedule/{id} [put]
|
// @Router /bots/{bot_id}/schedule/{id} [put]
|
||||||
func (h *ScheduleHandler) Update(c echo.Context) error {
|
func (h *ScheduleHandler) Update(c echo.Context) error {
|
||||||
channelIdentityID, err := h.requireChannelIdentityID(c)
|
userID, err := h.requireUserID(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -170,7 +170,7 @@ func (h *ScheduleHandler) Update(c echo.Context) error {
|
|||||||
if item.BotID != botID {
|
if item.BotID != botID {
|
||||||
return echo.NewHTTPError(http.StatusForbidden, "bot mismatch")
|
return echo.NewHTTPError(http.StatusForbidden, "bot mismatch")
|
||||||
}
|
}
|
||||||
if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil {
|
if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
resp, err := h.service.Update(c.Request().Context(), id, req)
|
resp, err := h.service.Update(c.Request().Context(), id, req)
|
||||||
@@ -190,7 +190,7 @@ func (h *ScheduleHandler) Update(c echo.Context) error {
|
|||||||
// @Failure 500 {object} ErrorResponse
|
// @Failure 500 {object} ErrorResponse
|
||||||
// @Router /bots/{bot_id}/schedule/{id} [delete]
|
// @Router /bots/{bot_id}/schedule/{id} [delete]
|
||||||
func (h *ScheduleHandler) Delete(c echo.Context) error {
|
func (h *ScheduleHandler) Delete(c echo.Context) error {
|
||||||
channelIdentityID, err := h.requireChannelIdentityID(c)
|
userID, err := h.requireUserID(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -209,7 +209,7 @@ func (h *ScheduleHandler) Delete(c echo.Context) error {
|
|||||||
if item.BotID != botID {
|
if item.BotID != botID {
|
||||||
return echo.NewHTTPError(http.StatusForbidden, "bot mismatch")
|
return echo.NewHTTPError(http.StatusForbidden, "bot mismatch")
|
||||||
}
|
}
|
||||||
if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil {
|
if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := h.service.Delete(c.Request().Context(), id); err != nil {
|
if err := h.service.Delete(c.Request().Context(), id); err != nil {
|
||||||
@@ -218,26 +218,26 @@ func (h *ScheduleHandler) Delete(c echo.Context) error {
|
|||||||
return c.NoContent(http.StatusNoContent)
|
return c.NoContent(http.StatusNoContent)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ScheduleHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
func (h *ScheduleHandler) requireUserID(c echo.Context) (string, error) {
|
||||||
channelIdentityID, err := auth.ChannelIdentityIDFromContext(c)
|
userID, err := auth.UserIDFromContext(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil {
|
if err := identity.ValidateChannelIdentityID(userID); err != nil {
|
||||||
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||||
}
|
}
|
||||||
return channelIdentityID, nil
|
return userID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ScheduleHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
|
func (h *ScheduleHandler) authorizeBotAccess(ctx context.Context, userID, botID string) (bots.Bot, error) {
|
||||||
if h.botService == nil || h.accountService == nil {
|
if h.botService == nil || h.accountService == nil {
|
||||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured")
|
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured")
|
||||||
}
|
}
|
||||||
isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID)
|
isAdmin, err := h.accountService.IsAdmin(ctx, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||||
}
|
}
|
||||||
bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false})
|
bot, err := h.botService.AuthorizeAccess(ctx, userID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, bots.ErrBotNotFound) {
|
if errors.Is(err, bots.ErrBotNotFound) {
|
||||||
return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found")
|
return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found")
|
||||||
|
|||||||
@@ -130,7 +130,7 @@ func (h *SettingsHandler) Delete(c echo.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *SettingsHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
func (h *SettingsHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
||||||
channelIdentityID, err := auth.ChannelIdentityIDFromContext(c)
|
channelIdentityID, err := auth.UserIDFromContext(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -433,7 +433,7 @@ func (h *SubagentHandler) AddSkills(c echo.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *SubagentHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
func (h *SubagentHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
||||||
channelIdentityID, err := auth.ChannelIdentityIDFromContext(c)
|
channelIdentityID, err := auth.UserIDFromContext(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -907,7 +907,7 @@ func (h *UsersHandler) authorizeBotAccess(ctx context.Context, channelIdentityID
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *UsersHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
func (h *UsersHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
||||||
channelIdentityID, err := auth.ChannelIdentityIDFromContext(c)
|
channelIdentityID, err := auth.UserIDFromContext(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|||||||
+12
-11
@@ -167,17 +167,18 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel
|
|||||||
desc, _ = p.registry.GetDescriptor(msg.Channel)
|
desc, _ = p.registry.GetDescriptor(msg.Channel)
|
||||||
}
|
}
|
||||||
resp, err := p.chat.Chat(ctx, chat.ChatRequest{
|
resp, err := p.chat.Chat(ctx, chat.ChatRequest{
|
||||||
BotID: identity.BotID,
|
BotID: identity.BotID,
|
||||||
ChatID: resolved.ChatID,
|
ChatID: resolved.ChatID,
|
||||||
Token: token,
|
Token: token,
|
||||||
ChannelIdentityID: identity.UserID,
|
UserID: identity.UserID,
|
||||||
DisplayName: identity.DisplayName,
|
SourceChannelIdentityID: identity.ChannelIdentityID,
|
||||||
RouteID: resolved.RouteID,
|
DisplayName: identity.DisplayName,
|
||||||
ChatToken: chatToken,
|
RouteID: resolved.RouteID,
|
||||||
ExternalMessageID: strings.TrimSpace(msg.Message.ID),
|
ChatToken: chatToken,
|
||||||
Query: text,
|
ExternalMessageID: strings.TrimSpace(msg.Message.ID),
|
||||||
CurrentChannel: msg.Channel.String(),
|
Query: text,
|
||||||
Channels: []string{msg.Channel.String()},
|
CurrentChannel: msg.Channel.String(),
|
||||||
|
Channels: []string{msg.Channel.String()},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if p.logger != nil {
|
if p.logger != nil {
|
||||||
|
|||||||
@@ -97,8 +97,11 @@ func TestChannelInboundProcessorWithIdentity(t *testing.T) {
|
|||||||
if gateway.gotReq.Query != "hello" {
|
if gateway.gotReq.Query != "hello" {
|
||||||
t.Errorf("expected query 'hello', got: %s", gateway.gotReq.Query)
|
t.Errorf("expected query 'hello', got: %s", gateway.gotReq.Query)
|
||||||
}
|
}
|
||||||
if gateway.gotReq.ChannelIdentityID != "channelIdentity-1" {
|
if gateway.gotReq.UserID != "channelIdentity-1" {
|
||||||
t.Errorf("expected channel_identity_id 'channelIdentity-1', got: %s", gateway.gotReq.ChannelIdentityID)
|
t.Errorf("expected user_id 'channelIdentity-1', got: %s", gateway.gotReq.UserID)
|
||||||
|
}
|
||||||
|
if gateway.gotReq.SourceChannelIdentityID != "channelIdentity-1" {
|
||||||
|
t.Errorf("expected source_channel_identity_id 'channelIdentity-1', got: %s", gateway.gotReq.SourceChannelIdentityID)
|
||||||
}
|
}
|
||||||
if gateway.gotReq.ChatID != "chat-1" {
|
if gateway.gotReq.ChatID != "chat-1" {
|
||||||
t.Errorf("expected chat_id 'chat-1', got: %s", gateway.gotReq.ChatID)
|
t.Errorf("expected chat_id 'chat-1', got: %s", gateway.gotReq.ChatID)
|
||||||
|
|||||||
+35
-13
@@ -209,7 +209,9 @@ func (r *IdentityResolver) Resolve(ctx context.Context, cfg channel.ChannelConfi
|
|||||||
state.Decision = &decision
|
state.Decision = &decision
|
||||||
return state, err
|
return state, err
|
||||||
}
|
}
|
||||||
if r.policy != nil && isGroupConversationType(msg.Conversation.Type) {
|
|
||||||
|
// Personal bots are owner-only and must not depend on member/guest/preauth bypass.
|
||||||
|
if r.policy != nil {
|
||||||
botType, err := r.policy.BotType(ctx, botID)
|
botType, err := r.policy.BotType(ctx, botID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return state, err
|
return state, err
|
||||||
@@ -219,20 +221,35 @@ func (r *IdentityResolver) Resolve(ctx context.Context, cfg channel.ChannelConfi
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return state, err
|
return state, err
|
||||||
}
|
}
|
||||||
if strings.TrimSpace(state.Identity.UserID) == "" || strings.TrimSpace(ownerUserID) != strings.TrimSpace(state.Identity.UserID) {
|
isOwner := strings.TrimSpace(state.Identity.UserID) != "" &&
|
||||||
// Personal bots in group chats only answer owner messages.
|
strings.TrimSpace(ownerUserID) == strings.TrimSpace(state.Identity.UserID)
|
||||||
state.Decision = &IdentityDecision{Stop: true}
|
if !isOwner {
|
||||||
|
if isGroupConversationType(msg.Conversation.Type) {
|
||||||
|
// Ignore non-owner group messages for personal bots.
|
||||||
|
state.Decision = &IdentityDecision{Stop: true}
|
||||||
|
return state, nil
|
||||||
|
}
|
||||||
|
state.Decision = &IdentityDecision{
|
||||||
|
Stop: true,
|
||||||
|
Reply: channel.Message{Text: r.unboundReply},
|
||||||
|
}
|
||||||
return state, nil
|
return state, nil
|
||||||
}
|
}
|
||||||
// Owner can chat normally in group for personal bots.
|
if isGroupConversationType(msg.Conversation.Type) {
|
||||||
state.Identity.ForceReply = true
|
// Owner can chat in group for personal bots.
|
||||||
|
state.Identity.ForceReply = true
|
||||||
|
}
|
||||||
|
return state, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Phase 2: Authorization (bot membership check).
|
// Phase 2: Authorization (bot membership check).
|
||||||
if r.members != nil {
|
if r.members != nil {
|
||||||
if strings.TrimSpace(state.Identity.UserID) != "" {
|
if strings.TrimSpace(state.Identity.UserID) != "" {
|
||||||
isMember, _ := r.members.IsMember(ctx, botID, state.Identity.UserID)
|
isMember, err := r.members.IsMember(ctx, botID, state.Identity.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return state, fmt.Errorf("check bot membership: %w", err)
|
||||||
|
}
|
||||||
if isMember {
|
if isMember {
|
||||||
return state, nil
|
return state, nil
|
||||||
}
|
}
|
||||||
@@ -256,9 +273,6 @@ func (r *IdentityResolver) Resolve(ctx context.Context, cfg channel.ChannelConfi
|
|||||||
return state, err
|
return state, err
|
||||||
}
|
}
|
||||||
if allowed {
|
if allowed {
|
||||||
if r.members != nil && strings.TrimSpace(state.Identity.UserID) != "" {
|
|
||||||
_ = r.members.UpsertMemberSimple(ctx, botID, state.Identity.UserID, "member")
|
|
||||||
}
|
|
||||||
return state, nil
|
return state, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -312,9 +326,13 @@ func (r *IdentityResolver) tryHandlePreauthKey(ctx context.Context, msg channel.
|
|||||||
return true, reply("Current channel account is not linked to a user."), nil
|
return true, reply("Current channel account is not linked to a user."), nil
|
||||||
}
|
}
|
||||||
if r.members != nil {
|
if r.members != nil {
|
||||||
_ = r.members.UpsertMemberSimple(ctx, botID, userID, "member")
|
if err := r.members.UpsertMemberSimple(ctx, botID, userID, "member"); err != nil {
|
||||||
|
return true, IdentityDecision{}, fmt.Errorf("upsert preauth member: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _, err := r.preauth.MarkUsed(ctx, key.ID); err != nil {
|
||||||
|
return true, IdentityDecision{}, fmt.Errorf("mark preauth key used: %w", err)
|
||||||
}
|
}
|
||||||
_, _ = r.preauth.MarkUsed(ctx, key.ID)
|
|
||||||
return true, reply(r.preauthReply), nil
|
return true, reply(r.preauthReply), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -365,7 +383,11 @@ func (r *IdentityResolver) tryHandleBindCode(ctx context.Context, msg channel.In
|
|||||||
// Resolve linked user after binding.
|
// Resolve linked user after binding.
|
||||||
newUserID := code.IssuedByUserID
|
newUserID := code.IssuedByUserID
|
||||||
if r.channelIdentities != nil {
|
if r.channelIdentities != nil {
|
||||||
if linkedUserID, err := r.channelIdentities.GetLinkedUserID(ctx, channelIdentityID); err == nil && strings.TrimSpace(linkedUserID) != "" {
|
linkedUserID, err := r.channelIdentities.GetLinkedUserID(ctx, channelIdentityID)
|
||||||
|
if err != nil {
|
||||||
|
return true, IdentityDecision{}, "", fmt.Errorf("resolve linked user after bind: %w", err)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(linkedUserID) != "" {
|
||||||
newUserID = linkedUserID
|
newUserID = linkedUserID
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -145,7 +145,7 @@ func (f *fakeBindService) Consume(ctx context.Context, code bind.Code, channelCh
|
|||||||
return f.consumeErr
|
return f.consumeErr
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIdentityResolverAllowGuestUpsertsMember(t *testing.T) {
|
func TestIdentityResolverAllowGuestWithoutMembershipSideEffect(t *testing.T) {
|
||||||
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-1"}}
|
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-1"}}
|
||||||
memberSvc := &fakeMemberService{isMember: false}
|
memberSvc := &fakeMemberService{isMember: false}
|
||||||
policySvc := &fakePolicyService{allow: true, botType: "public"}
|
policySvc := &fakePolicyService{allow: true, botType: "public"}
|
||||||
@@ -165,8 +165,8 @@ func TestIdentityResolverAllowGuestUpsertsMember(t *testing.T) {
|
|||||||
if state.Identity.ChannelIdentityID != "channelIdentity-1" {
|
if state.Identity.ChannelIdentityID != "channelIdentity-1" {
|
||||||
t.Fatalf("expected channelIdentity-1, got: %s", state.Identity.ChannelIdentityID)
|
t.Fatalf("expected channelIdentity-1, got: %s", state.Identity.ChannelIdentityID)
|
||||||
}
|
}
|
||||||
if !memberSvc.upsertCalled {
|
if memberSvc.upsertCalled {
|
||||||
t.Fatal("expected UpsertMemberSimple to be called")
|
t.Fatal("guest allow should not upsert membership")
|
||||||
}
|
}
|
||||||
if state.Decision != nil {
|
if state.Decision != nil {
|
||||||
t.Fatal("expected no decision for allowed guest")
|
t.Fatal("expected no decision for allowed guest")
|
||||||
@@ -376,6 +376,35 @@ func TestIdentityResolverPersonalBotAllowsOwnerDirectWithoutMembership(t *testin
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIdentityResolverPersonalBotRejectsNonOwnerDirectEvenIfMember(t *testing.T) {
|
||||||
|
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-non-owner"}}
|
||||||
|
memberSvc := &fakeMemberService{isMember: true}
|
||||||
|
policySvc := &fakePolicyService{allow: true, botType: "personal", ownerUserID: "channelIdentity-owner"}
|
||||||
|
resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, memberSvc, policySvc, nil, nil, "Access denied.", "")
|
||||||
|
|
||||||
|
msg := channel.InboundMessage{
|
||||||
|
BotID: "bot-1",
|
||||||
|
Channel: channel.ChannelType("feishu"),
|
||||||
|
Message: channel.Message{Text: "hello from non-owner"},
|
||||||
|
Sender: channel.Identity{SubjectID: "ext-non-owner"},
|
||||||
|
Conversation: channel.Conversation{
|
||||||
|
ID: "p2p-2",
|
||||||
|
Type: "p2p",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if state.Decision == nil || !state.Decision.Stop {
|
||||||
|
t.Fatal("non-owner direct message should be rejected for personal bot")
|
||||||
|
}
|
||||||
|
if state.Decision.Reply.Text != "Access denied." {
|
||||||
|
t.Fatalf("unexpected reject message: %s", state.Decision.Reply.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestIdentityResolverBindRunsBeforeMembershipCheck(t *testing.T) {
|
func TestIdentityResolverBindRunsBeforeMembershipCheck(t *testing.T) {
|
||||||
shadowID := "channelIdentity-shadow"
|
shadowID := "channelIdentity-shadow"
|
||||||
humanID := "channelIdentity-human"
|
humanID := "channelIdentity-human"
|
||||||
|
|||||||
@@ -55,8 +55,8 @@ func TestGenerateTriggerToken(t *testing.T) {
|
|||||||
if sub, _ := claims["sub"].(string); sub != userID {
|
if sub, _ := claims["sub"].(string); sub != userID {
|
||||||
t.Errorf("expected sub=%s, got=%s", userID, sub)
|
t.Errorf("expected sub=%s, got=%s", userID, sub)
|
||||||
}
|
}
|
||||||
if uid, _ := claims["channel_identity_id"].(string); uid != userID {
|
if uid, _ := claims["user_id"].(string); uid != userID {
|
||||||
t.Errorf("expected channel_identity_id=%s, got=%s", userID, uid)
|
t.Errorf("expected user_id=%s, got=%s", userID, uid)
|
||||||
}
|
}
|
||||||
exp, _ := claims["exp"].(float64)
|
exp, _ := claims["exp"].(float64)
|
||||||
if exp == 0 {
|
if exp == 0 {
|
||||||
|
|||||||
Reference in New Issue
Block a user