refactor(core): finalize user-centric identity and policy cleanup

Unify auth and chat identity semantics around user_id, enforce personal-bot owner-only authorization, and remove legacy compatibility branches in integration tests.
This commit is contained in:
BBQ
2026-02-11 15:42:21 +08:00
parent 06e8619a37
commit 02b33c8e85
27 changed files with 246 additions and 602 deletions
-9
View File
@@ -52,17 +52,9 @@ func UserIDFromContext(c echo.Context) (string, error) {
if userID := claimString(claims, claimSubject); userID != "" {
return userID, nil
}
if legacyChannelIdentityID := claimString(claims, claimChannelIdentityID); legacyChannelIdentityID != "" {
return legacyChannelIdentityID, nil
}
return "", echo.NewHTTPError(http.StatusUnauthorized, "user id missing")
}
// ChannelIdentityIDFromContext is kept as compatibility alias and returns user id.
func ChannelIdentityIDFromContext(c echo.Context) (string, error) {
return UserIDFromContext(c)
}
// GenerateToken creates a signed JWT for the user.
func GenerateToken(userID, secret string, expiresIn time.Duration) (string, time.Time, error) {
if strings.TrimSpace(userID) == "" {
@@ -80,7 +72,6 @@ func GenerateToken(userID, secret string, expiresIn time.Duration) (string, time
claims := jwt.MapClaims{
claimSubject: userID,
claimUserID: userID,
claimChannelIdentityID: userID, // legacy compatibility for handlers still reading channel_identity_id
"iat": now.Unix(),
"exp": expiresAt.Unix(),
}
+12 -401
View File
@@ -15,8 +15,8 @@ import (
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/memohai/memoh/internal/channelidentities"
"github.com/memohai/memoh/internal/bind"
"github.com/memohai/memoh/internal/channelidentities"
"github.com/memohai/memoh/internal/db"
"github.com/memohai/memoh/internal/db/sqlc"
)
@@ -47,7 +47,7 @@ func setupBindIntegrationTest(t *testing.T) (*sqlc.Queries, *channelidentities.S
return queries, channelIdentitySvc, bindSvc, func() { pool.Close() }
}
func createUser(ctx context.Context, queries *sqlc.Queries) (string, error) {
func createUserForBindTest(ctx context.Context, queries *sqlc.Queries) (string, error) {
row, err := queries.CreateUser(ctx, sqlc.CreateUserParams{
IsActive: true,
Metadata: []byte("{}"),
@@ -58,12 +58,15 @@ func createUser(ctx context.Context, queries *sqlc.Queries) (string, error) {
return db.UUIDToString(row.ID), nil
}
func createBot(ctx context.Context, queries *sqlc.Queries, ownerUserID string) (string, error) {
func createBotForBindTest(ctx context.Context, queries *sqlc.Queries, ownerUserID string) (string, error) {
pgOwnerID, err := db.ParseUUID(ownerUserID)
if err != nil {
return "", err
}
meta, _ := json.Marshal(map[string]any{"source": "bind-integration-test"})
meta, err := json.Marshal(map[string]any{"source": "bind-integration-test"})
if err != nil {
return "", err
}
row, err := queries.CreateBot(ctx, sqlc.CreateBotParams{
OwnerUserID: pgOwnerID,
Type: "personal",
@@ -82,15 +85,15 @@ func TestIntegrationConsumeBindCodeSuccessAndSingleUse(t *testing.T) {
defer cleanup()
ctx := context.Background()
ownerUserID, err := createUser(ctx, queries)
ownerUserID, err := createUserForBindTest(ctx, queries)
if err != nil {
t.Fatalf("create owner user failed: %v", err)
}
sourceChannelIdentity, err := channelIdentitySvc.Create(ctx, channelidentities.KindChannel)
if err != nil {
t.Fatalf("create source channelIdentity failed: %v", err)
t.Fatalf("create source channel identity failed: %v", err)
}
botID, err := createBot(ctx, queries, ownerUserID)
botID, err := createBotForBindTest(ctx, queries, ownerUserID)
if err != nil {
t.Fatalf("create bot failed: %v", err)
}
@@ -127,177 +130,6 @@ func TestIntegrationConsumeBindCodeSuccessAndSingleUse(t *testing.T) {
}
}
func TestIntegrationConsumeBindCodeRollbackOnLinkConflict(t *testing.T) {
queries, channelIdentitySvc, bindSvc, cleanup := setupBindIntegrationTest(t)
defer cleanup()
ctx := context.Background()
ownerUserID, err := createUser(ctx, queries)
if err != nil {
t.Fatalf("create owner user failed: %v", err)
}
otherUserID, err := createUser(ctx, queries)
if err != nil {
t.Fatalf("create other user failed: %v", err)
}
sourceChannelIdentity, err := channelIdentitySvc.Create(ctx, channelidentities.KindChannel)
if err != nil {
t.Fatalf("create source channelIdentity failed: %v", err)
}
if err := channelIdentitySvc.LinkChannelIdentityToUser(ctx, sourceChannelIdentity.ID, otherUserID); err != nil {
t.Fatalf("pre-link source channelIdentity failed: %v", err)
}
botID, err := createBot(ctx, queries, ownerUserID)
if err != nil {
t.Fatalf("create bot failed: %v", err)
}
code, err := bindSvc.Issue(ctx, botID, ownerUserID, 10*time.Minute)
if err != nil {
t.Fatalf("issue bind code failed: %v", err)
}
if err := bindSvc.Consume(ctx, code, sourceChannelIdentity.ID); !errors.Is(err, bind.ErrLinkConflict) {
t.Fatalf("expected ErrLinkConflict, got %v", err)
}
after, err := bindSvc.Get(ctx, code.Token)
if err != nil {
t.Fatalf("get bind code failed: %v", err)
}
if !after.UsedAt.IsZero() {
t.Fatal("expected used_at to remain empty when consume fails")
}
}
package bind_test
import (
"context"
"encoding/json"
"errors"
"log/slog"
"os"
"testing"
"time"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/memohai/memoh/internal/channelidentities"
"github.com/memohai/memoh/internal/bind"
"github.com/memohai/memoh/internal/db"
"github.com/memohai/memoh/internal/db/sqlc"
)
func setupBindIntegrationTest(t *testing.T) (*sqlc.Queries, *channelidentities.Service, *bind.Service, func()) {
t.Helper()
dsn := os.Getenv("TEST_POSTGRES_DSN")
if dsn == "" {
t.Skip("skip integration test: TEST_POSTGRES_DSN is not set")
}
ctx := context.Background()
pool, err := pgxpool.New(ctx, dsn)
if err != nil {
t.Skipf("skip integration test: cannot connect to database: %v", err)
}
if err := pool.Ping(ctx); err != nil {
pool.Close()
t.Skipf("skip integration test: database ping failed: %v", err)
}
queries := sqlc.New(pool)
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
channelIdentitySvc := channelidentities.NewService(logger, queries)
bindSvc := bind.NewService(logger, pool, queries)
cleanup := func() {
pool.Close()
}
return queries, channelIdentitySvc, bindSvc, cleanup
}
func createUserForBindTest(ctx context.Context, queries *sqlc.Queries) (string, error) {
row, err := queries.CreateUser(ctx, sqlc.CreateUserParams{
IsActive: true,
Metadata: []byte("{}"),
})
if err != nil {
return "", err
}
return db.UUIDToString(row.ID), nil
}
func createBotForBindTest(ctx context.Context, queries *sqlc.Queries, ownerUserID string) (string, error) {
pgOwnerID, err := db.ParseUUID(ownerUserID)
if err != nil {
return "", err
}
meta, _ := json.Marshal(map[string]any{"source": "bind-integration-test"})
row, err := queries.CreateBot(ctx, sqlc.CreateBotParams{
OwnerUserID: pgOwnerID,
Type: "personal",
DisplayName: pgtype.Text{String: "bind-test-bot", Valid: true},
AvatarUrl: pgtype.Text{},
IsActive: true,
Metadata: meta,
})
if err != nil {
return "", err
}
return db.UUIDToString(row.ID), nil
}
func TestIntegrationConsumeBindCodeSuccessAndSingleUse(t *testing.T) {
queries, channelIdentitySvc, bindSvc, cleanup := setupBindIntegrationTest(t)
defer cleanup()
ctx := context.Background()
ownerUserID, err := createUserForBindTest(ctx, queries)
if err != nil {
t.Fatalf("create owner user failed: %v", err)
}
sourceChannelIdentity, err := channelIdentitySvc.Create(ctx, channelidentities.KindChannel)
if err != nil {
t.Fatalf("create source channelIdentity failed: %v", err)
}
botID, err := createBotForBindTest(ctx, queries, ownerUserID)
if err != nil {
t.Fatalf("create bot failed: %v", err)
}
code, err := bindSvc.Issue(ctx, botID, ownerUserID, 10*time.Minute)
if err != nil {
t.Fatalf("issue bind code failed: %v", err)
}
if err := bindSvc.Consume(ctx, code, sourceChannelIdentity.ID); err != nil {
t.Fatalf("consume bind code failed: %v", err)
}
after, err := bindSvc.Get(ctx, code.Token)
if err != nil {
t.Fatalf("get bind code failed: %v", err)
}
if after.UsedAt.IsZero() {
t.Fatal("expected used_at to be set after successful consume")
}
if after.UsedByChannelIdentityID != sourceChannelIdentity.ID {
t.Fatalf("expected used_by_channel_identity_id=%s, got %s", sourceChannelIdentity.ID, after.UsedByChannelIdentityID)
}
linkedUserID, err := channelIdentitySvc.GetLinkedUserID(ctx, sourceChannelIdentity.ID)
if err != nil {
t.Fatalf("get linked user failed: %v", err)
}
if linkedUserID != ownerUserID {
t.Fatalf("expected linked user=%s, got %s", ownerUserID, linkedUserID)
}
if err := bindSvc.Consume(ctx, code, sourceChannelIdentity.ID); !errors.Is(err, bind.ErrCodeUsed) {
t.Fatalf("expected ErrCodeUsed on second consume, got %v", err)
}
}
func TestIntegrationConsumeBindCodeRollbackOnLinkConflict(t *testing.T) {
queries, channelIdentitySvc, bindSvc, cleanup := setupBindIntegrationTest(t)
defer cleanup()
@@ -313,10 +145,10 @@ func TestIntegrationConsumeBindCodeRollbackOnLinkConflict(t *testing.T) {
}
sourceChannelIdentity, err := channelIdentitySvc.Create(ctx, channelidentities.KindChannel)
if err != nil {
t.Fatalf("create source channelIdentity failed: %v", err)
t.Fatalf("create source channel identity failed: %v", err)
}
if err := channelIdentitySvc.LinkChannelIdentityToUser(ctx, sourceChannelIdentity.ID, otherUserID); err != nil {
t.Fatalf("pre-link source channelIdentity failed: %v", err)
t.Fatalf("pre-link source channel identity failed: %v", err)
}
botID, err := createBotForBindTest(ctx, queries, ownerUserID)
if err != nil {
@@ -339,224 +171,3 @@ func TestIntegrationConsumeBindCodeRollbackOnLinkConflict(t *testing.T) {
t.Fatal("expected used_at to remain empty when consume fails")
}
}
package bind_test
import (
"context"
"encoding/json"
"errors"
"log/slog"
"os"
"testing"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/memohai/memoh/internal/channelidentities"
"github.com/memohai/memoh/internal/bind"
"github.com/memohai/memoh/internal/db"
"github.com/memohai/memoh/internal/db/sqlc"
)
func setupBindIntegrationTest(t *testing.T) (*sqlc.Queries, *channelidentities.Service, *bind.Service, func()) {
t.Helper()
dsn := os.Getenv("TEST_POSTGRES_DSN")
if dsn == "" {
t.Skip("skip integration test: TEST_POSTGRES_DSN is not set")
}
ctx := context.Background()
pool, err := pgxpool.New(ctx, dsn)
if err != nil {
t.Skipf("skip integration test: cannot connect to database: %v", err)
}
if err := pool.Ping(ctx); err != nil {
pool.Close()
t.Skipf("skip integration test: database ping failed: %v", err)
}
queries := sqlc.New(pool)
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
channelIdentitySvc := channelidentities.NewService(logger, queries)
bindSvc := bind.NewService(logger, pool, queries)
cleanup := func() {
pool.Close()
}
return queries, channelIdentitySvc, bindSvc, cleanup
}
func createBotForBindTest(ctx context.Context, queries *sqlc.Queries, ownerChannelIdentityID string) (string, error) {
pgOwnerID, err := db.ParseUUID(ownerChannelIdentityID)
if err != nil {
return "", err
}
meta, _ := json.Marshal(map[string]any{"source": "bind-integration-test"})
row, err := queries.CreateBot(ctx, sqlc.CreateBotParams{
OwnerChannelIdentityID: pgOwnerID,
Type: "personal",
DisplayName: pgtype.Text{String: "bind-test-bot", Valid: true},
AvatarUrl: pgtype.Text{},
IsActive: true,
Metadata: meta,
})
if err != nil {
return "", err
}
return db.UUIDToString(row.ID), nil
}
func createChatForBindTest(ctx context.Context, queries *sqlc.Queries, botID, channelIdentityID string) (string, error) {
pgBotID, err := db.ParseUUID(botID)
if err != nil {
return "", err
}
pgChannelIdentityID, err := db.ParseUUID(channelIdentityID)
if err != nil {
return "", err
}
row, err := queries.CreateChat(ctx, sqlc.CreateChatParams{
BotID: pgBotID,
Kind: "direct",
ParentChatID: pgtype.UUID{},
Title: pgtype.Text{},
CreatedBy: pgChannelIdentityID,
Metadata: []byte("{}"),
})
if err != nil {
return "", err
}
return db.UUIDToString(row.ID), nil
}
func TestIntegrationConsumeBindCodeSuccessAndSingleUse(t *testing.T) {
queries, channelIdentitySvc, bindSvc, cleanup := setupBindIntegrationTest(t)
defer cleanup()
ctx := context.Background()
human, err := channelIdentitySvc.Create(ctx, channelidentities.KindHuman)
if err != nil {
t.Fatalf("create human failed: %v", err)
}
shadow, err := channelIdentitySvc.Create(ctx, channelidentities.KindShadow)
if err != nil {
t.Fatalf("create shadow failed: %v", err)
}
botID, err := createBotForBindTest(ctx, queries, human.ID)
if err != nil {
t.Fatalf("create bot failed: %v", err)
}
chatID, err := createChatForBindTest(ctx, queries, botID, human.ID)
if err != nil {
t.Fatalf("create chat failed: %v", err)
}
pgChatID, err := db.ParseUUID(chatID)
if err != nil {
t.Fatalf("parse chat id failed: %v", err)
}
pgShadowID, err := db.ParseUUID(shadow.ID)
if err != nil {
t.Fatalf("parse shadow id failed: %v", err)
}
pgHumanID, err := db.ParseUUID(human.ID)
if err != nil {
t.Fatalf("parse human id failed: %v", err)
}
if _, err := queries.AddChatParticipant(ctx, sqlc.AddChatParticipantParams{
ChatID: pgChatID,
ChannelIdentityID: pgShadowID,
Role: "member",
}); err != nil {
t.Fatalf("add shadow participant failed: %v", err)
}
code, err := bindSvc.Issue(ctx, botID, human.ID, 10*time.Minute)
if err != nil {
t.Fatalf("issue bind code failed: %v", err)
}
if err := bindSvc.Consume(ctx, code, shadow.ID); err != nil {
t.Fatalf("consume bind code failed: %v", err)
}
after, err := bindSvc.Get(ctx, code.Token)
if err != nil {
t.Fatalf("get bind code failed: %v", err)
}
if after.UsedAt.IsZero() {
t.Fatal("expected used_at to be set after successful consume")
}
if after.UsedByChannelIdentityID != shadow.ID {
t.Fatalf("expected used_by_channel_identity_id=%s, got %s", shadow.ID, after.UsedByChannelIdentityID)
}
canonical, err := channelIdentitySvc.Canonicalize(ctx, shadow.ID)
if err != nil {
t.Fatalf("canonicalize failed: %v", err)
}
if canonical != human.ID {
t.Fatalf("expected canonical=%s, got %s", human.ID, canonical)
}
if _, err := queries.GetChatParticipant(ctx, sqlc.GetChatParticipantParams{
ChatID: pgChatID,
ChannelIdentityID: pgHumanID,
}); err != nil {
t.Fatalf("expected human participant after bind, got error: %v", err)
}
if _, err := queries.GetChatParticipant(ctx, sqlc.GetChatParticipantParams{
ChatID: pgChatID,
ChannelIdentityID: pgShadowID,
}); !errors.Is(err, pgx.ErrNoRows) {
t.Fatalf("expected shadow participant removed after bind, got %v", err)
}
if err := bindSvc.Consume(ctx, code, shadow.ID); !errors.Is(err, bind.ErrCodeUsed) {
t.Fatalf("expected ErrCodeUsed on second consume, got %v", err)
}
}
func TestIntegrationConsumeBindCodeRollbackOnLinkConflict(t *testing.T) {
queries, channelIdentitySvc, bindSvc, cleanup := setupBindIntegrationTest(t)
defer cleanup()
ctx := context.Background()
humanA, err := channelIdentitySvc.Create(ctx, channelidentities.KindHuman)
if err != nil {
t.Fatalf("create humanA failed: %v", err)
}
humanB, err := channelIdentitySvc.Create(ctx, channelidentities.KindHuman)
if err != nil {
t.Fatalf("create humanB failed: %v", err)
}
shadow, err := channelIdentitySvc.Create(ctx, channelidentities.KindShadow)
if err != nil {
t.Fatalf("create shadow failed: %v", err)
}
botID, err := createBotForBindTest(ctx, queries, humanA.ID)
if err != nil {
t.Fatalf("create bot failed: %v", err)
}
// Pre-link shadow to another user so bind consume hits link conflict.
if err := channelIdentitySvc.LinkChannelIdentityToUser(ctx, shadow.ID, humanB.ID); err != nil {
t.Fatalf("pre link shadow->humanB failed: %v", err)
}
code, err := bindSvc.Issue(ctx, botID, humanA.ID, 10*time.Minute)
if err != nil {
t.Fatalf("issue bind code failed: %v", err)
}
if err := bindSvc.Consume(ctx, code, shadow.ID); !errors.Is(err, bind.ErrLinkConflict) {
t.Fatalf("expected ErrLinkConflict, got %v", err)
}
after, err := bindSvc.Get(ctx, code.Token)
if err != nil {
t.Fatalf("get bind code failed: %v", err)
}
if !after.UsedAt.IsZero() {
t.Fatal("expected used_at to remain empty when consume fails")
}
}
+4 -16
View File
@@ -7,7 +7,6 @@ import (
"fmt"
"log/slog"
"os"
"strings"
"testing"
"time"
@@ -61,7 +60,10 @@ func createBotForBind(ctx context.Context, queries *sqlc.Queries, ownerUserID st
if err != nil {
return "", err
}
meta, _ := json.Marshal(map[string]any{"source": "bind-integration-test"})
meta, err := json.Marshal(map[string]any{"source": "bind-integration-test"})
if err != nil {
return "", err
}
row, err := queries.CreateBot(ctx, sqlc.CreateBotParams{
OwnerUserID: pgOwnerID,
Type: "personal",
@@ -75,14 +77,6 @@ func createBotForBind(ctx context.Context, queries *sqlc.Queries, ownerUserID st
return db.UUIDToString(row.ID), nil
}
func isLegacyBindSchemaError(err error) bool {
if err == nil {
return false
}
msg := strings.ToLower(err.Error())
return strings.Contains(msg, "relation \"users\" does not exist")
}
func TestBindConsumeLinksChannelIdentityToIssuerUser(t *testing.T) {
queries, channelIdentitySvc, bindSvc, cleanup := setupBindLinkIntegrationTest(t)
defer cleanup()
@@ -90,9 +84,6 @@ func TestBindConsumeLinksChannelIdentityToIssuerUser(t *testing.T) {
ctx := context.Background()
ownerUserID, err := createUserForBind(ctx, queries)
if err != nil {
if isLegacyBindSchemaError(err) {
t.Skipf("skip integration test on legacy schema: %v", err)
}
t.Fatalf("create owner user failed: %v", err)
}
sourceChannelIdentity, err := channelIdentitySvc.ResolveByChannelIdentity(ctx, "feishu", fmt.Sprintf("bind-src-%d", time.Now().UnixNano()), "source")
@@ -134,9 +125,6 @@ func TestBindConsumeConflictDoesNotMarkUsed(t *testing.T) {
ctx := context.Background()
issuerUserID, err := createUserForBind(ctx, queries)
if err != nil {
if isLegacyBindSchemaError(err) {
t.Skipf("skip integration test on legacy schema: %v", err)
}
t.Fatalf("create issuer user failed: %v", err)
}
otherUserID, err := createUserForBind(ctx, queries)
@@ -5,7 +5,6 @@ import (
"fmt"
"log/slog"
"os"
"strings"
"testing"
"time"
@@ -43,16 +42,6 @@ func formatUUID(bytes [16]byte) string {
return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x", bytes[0:4], bytes[4:6], bytes[6:8], bytes[8:10], bytes[10:16])
}
func isLegacyChannelIdentitySchemaError(err error) bool {
if err == nil {
return false
}
msg := strings.ToLower(err.Error())
return strings.Contains(msg, "channelidentities_kind_check") ||
strings.Contains(msg, "column \"user_id\" of relation \"channelidentities\" does not exist") ||
strings.Contains(msg, "column \"channel_subject_id\" of relation \"channelidentities\" does not exist")
}
func TestChannelIdentityResolveChannelIdentityStable(t *testing.T) {
svc, _, cleanup := setupChannelIdentityIdentityIntegrationTest(t)
defer cleanup()
@@ -61,9 +50,6 @@ func TestChannelIdentityResolveChannelIdentityStable(t *testing.T) {
externalID := fmt.Sprintf("stable_%d", time.Now().UnixNano())
first, err := svc.ResolveByChannelIdentity(ctx, "feishu", externalID, "first")
if err != nil {
if isLegacyChannelIdentitySchemaError(err) {
t.Skipf("skip integration test on legacy schema: %v", err)
}
t.Fatalf("first resolve failed: %v", err)
}
second, err := svc.ResolveByChannelIdentity(ctx, "feishu", externalID, "second")
@@ -82,9 +68,6 @@ func TestChannelIdentityLinkToUser(t *testing.T) {
ctx := context.Background()
channelIdentity, err := svc.ResolveByChannelIdentity(ctx, "telegram", fmt.Sprintf("link_%d", time.Now().UnixNano()), "tg")
if err != nil {
if isLegacyChannelIdentitySchemaError(err) {
t.Skipf("skip integration test on legacy schema: %v", err)
}
t.Fatalf("resolve channelIdentity failed: %v", err)
}
user, err := queries.CreateUser(ctx, sqlc.CreateUserParams{
+17 -15
View File
@@ -208,7 +208,7 @@ func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContex
r.enforceGroupMemoryPolicy(ctx, req.ChatID, &chatSettings)
}
userSettings, err := r.loadUserSettings(ctx, req.ChannelIdentityID)
userSettings, err := r.loadUserSettings(ctx, req.UserID)
if err != nil {
return resolvedContext{}, err
}
@@ -297,7 +297,7 @@ func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContex
BotID: req.BotID,
SessionID: req.ChatID,
ContainerID: containerID,
ChannelIdentityID: firstNonEmpty(req.ChannelIdentityID, req.BotID),
ChannelIdentityID: firstNonEmpty(req.SourceChannelIdentityID, req.UserID),
DisplayName: firstNonEmpty(req.DisplayName, "User"),
CurrentPlatform: req.CurrentChannel,
ReplyTarget: "",
@@ -351,7 +351,7 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload sc
BotID: botID,
ChatID: chatID,
Query: payload.Command,
ChannelIdentityID: payload.OwnerUserID,
UserID: payload.OwnerUserID,
Token: token,
}
rc, err := r.resolve(ctx, req)
@@ -373,7 +373,7 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload sc
BotID: rc.payload.Identity.BotID,
SessionID: rc.payload.Identity.SessionID,
ContainerID: rc.payload.Identity.ContainerID,
ChannelIdentityID: firstNonEmpty(payload.OwnerUserID, botID),
ChannelIdentityID: strings.TrimSpace(payload.OwnerUserID),
DisplayName: "Scheduler",
},
Attachments: rc.payload.Attachments,
@@ -676,8 +676,8 @@ func (r *Resolver) loadMemoryContextMessage(ctx context.Context, req ChatRequest
if settings.EnableChatMemory {
scopes = append(scopes, memoryScope{Namespace: "chat", ScopeID: req.ChatID})
}
if settings.EnablePrivateMemory && strings.TrimSpace(req.ChannelIdentityID) != "" {
scopes = append(scopes, memoryScope{Namespace: "private", ScopeID: req.ChannelIdentityID})
if settings.EnablePrivateMemory && strings.TrimSpace(req.UserID) != "" {
scopes = append(scopes, memoryScope{Namespace: "private", ScopeID: req.UserID})
}
if settings.EnablePublicMemory {
scopes = append(scopes, memoryScope{Namespace: "public", ScopeID: req.BotID})
@@ -778,7 +778,7 @@ func (r *Resolver) storeRound(ctx context.Context, req ChatRequest, messages []M
fullRound = append(fullRound, messages...)
r.storeMessages(ctx, req, fullRound)
r.storeMemory(ctx, req.BotID, req.ChatID, req.ChannelIdentityID, req.Query, fullRound)
r.storeMemory(ctx, req.BotID, req.ChatID, req.UserID, req.Query, fullRound)
return nil
}
@@ -805,10 +805,12 @@ func (r *Resolver) storeMessages(ctx context.Context, req ChatRequest, messages
if err != nil {
continue
}
senderID := ""
senderChannelIdentityID := ""
senderUserID := ""
externalMessageID := ""
if msg.Role == "user" {
senderID = req.ChannelIdentityID
senderChannelIdentityID = req.SourceChannelIdentityID
senderUserID = req.UserID
externalMessageID = req.ExternalMessageID
}
if _, err := r.chatService.PersistMessage(
@@ -816,8 +818,8 @@ func (r *Resolver) storeMessages(ctx context.Context, req ChatRequest, messages
req.ChatID,
req.BotID,
req.RouteID,
"",
senderID,
senderChannelIdentityID,
senderUserID,
req.CurrentChannel,
externalMessageID,
msg.Role,
@@ -829,7 +831,7 @@ func (r *Resolver) storeMessages(ctx context.Context, req ChatRequest, messages
}
}
func (r *Resolver) storeMemory(ctx context.Context, botID, chatID, channelIdentityID, query string, messages []ModelMessage) {
func (r *Resolver) storeMemory(ctx context.Context, botID, chatID, userID, query string, messages []ModelMessage) {
if r.memoryService == nil {
return
}
@@ -869,9 +871,9 @@ func (r *Resolver) storeMemory(ctx context.Context, botID, chatID, channelIdenti
r.addMemory(ctx, botID, memMsgs, "chat", chatID)
}
// Write to private namespace if enabled and channel identity is known.
if cs.EnablePrivateMemory && strings.TrimSpace(channelIdentityID) != "" {
r.addMemory(ctx, botID, memMsgs, "private", channelIdentityID)
// Write to private namespace if enabled and user id is known.
if cs.EnablePrivateMemory && strings.TrimSpace(userID) != "" {
r.addMemory(ctx, botID, memMsgs, "private", userID)
}
// Write to public namespace if enabled.
+34 -13
View File
@@ -70,7 +70,10 @@ func (s *Service) Create(ctx context.Context, botID, channelIdentityID string, r
}
}
metadata, _ := json.Marshal(nonNilMap(req.Metadata))
metadata, err := json.Marshal(nonNilMap(req.Metadata))
if err != nil {
return Chat{}, fmt.Errorf("marshal chat metadata: %w", err)
}
row, err := s.queries.CreateChat(ctx, sqlc.CreateChatParams{
BotID: pgBotID,
@@ -393,7 +396,10 @@ func (s *Service) CreateRoute(ctx context.Context, chatID string, r Route) (Rout
return Route{}, err
}
}
metadata, _ := json.Marshal(nonNilMap(r.Metadata))
metadata, err := json.Marshal(nonNilMap(r.Metadata))
if err != nil {
return Route{}, fmt.Errorf("marshal route metadata: %w", err)
}
row, err := s.queries.CreateChatRoute(ctx, sqlc.CreateChatRouteParams{
ChatID: pgChatID,
BotID: pgBotID,
@@ -488,7 +494,10 @@ func (s *Service) ResolveChat(ctx context.Context, botID, platform, conversation
if err == nil {
// Route found, ensure the sender identity is a participant.
if strings.TrimSpace(channelIdentityID) != "" {
ok, _ := s.IsParticipant(ctx, route.ChatID, channelIdentityID)
ok, checkErr := s.IsParticipant(ctx, route.ChatID, channelIdentityID)
if checkErr != nil {
return ResolveChatResult{}, fmt.Errorf("check chat participant: %w", checkErr)
}
if !ok {
if _, err := s.AddParticipant(ctx, route.ChatID, channelIdentityID, RoleMember); err != nil {
s.logger.Warn("auto-add participant failed", slog.Any("error", err))
@@ -497,9 +506,17 @@ func (s *Service) ResolveChat(ctx context.Context, botID, platform, conversation
}
// Update reply target if changed.
if strings.TrimSpace(replyTarget) != "" && replyTarget != route.ReplyTarget {
_ = s.UpdateRouteReplyTarget(ctx, route.ID, replyTarget)
if err := s.UpdateRouteReplyTarget(ctx, route.ID, replyTarget); err != nil && s.logger != nil {
s.logger.Warn("update route reply target failed", slog.Any("error", err))
}
}
pgRouteChatID, parseErr := parseUUID(route.ChatID)
if parseErr != nil {
return ResolveChatResult{}, fmt.Errorf("parse route chat id: %w", parseErr)
}
if err := s.queries.TouchChat(ctx, pgRouteChatID); err != nil && s.logger != nil {
s.logger.Warn("touch chat failed", slog.Any("error", err))
}
_ = s.queries.TouchChat(ctx, mustParseUUID(route.ChatID))
return ResolveChatResult{ChatID: route.ChatID, RouteID: route.ID, Created: false}, nil
}
@@ -564,13 +581,22 @@ func (s *Service) PersistMessage(ctx context.Context, chatID, botID, routeID, se
}
var pgSender pgtype.UUID
if strings.TrimSpace(senderChannelIdentityID) != "" {
pgSender, _ = parseUUID(senderChannelIdentityID)
pgSender, err = parseUUID(senderChannelIdentityID)
if err != nil {
return Message{}, fmt.Errorf("invalid sender channel identity id: %w", err)
}
}
var pgSenderUser pgtype.UUID
if strings.TrimSpace(senderUserID) != "" {
pgSenderUser, _ = parseUUID(senderUserID)
pgSenderUser, err = parseUUID(senderUserID)
if err != nil {
return Message{}, fmt.Errorf("invalid sender user id: %w", err)
}
}
metaBytes, err := json.Marshal(nonNilMap(metadata))
if err != nil {
return Message{}, fmt.Errorf("marshal message metadata: %w", err)
}
metaBytes, _ := json.Marshal(nonNilMap(metadata))
if len(content) == 0 {
content = []byte("{}")
}
@@ -813,11 +839,6 @@ func pgTimePtr(ts pgtype.Timestamptz) *time.Time {
return &value
}
func mustParseUUID(id string) pgtype.UUID {
pgID, _ := parseUUID(id)
return pgID
}
func nonNilMap(m map[string]any) map[string]any {
if m == nil {
return map[string]any{}
@@ -7,7 +7,6 @@ import (
"fmt"
"log/slog"
"os"
"strings"
"testing"
"time"
@@ -56,17 +55,6 @@ func setupChatPresenceIntegrationTest(t *testing.T) chatPresenceFixture {
}
}
func isLegacyChatPresenceSchemaError(err error) bool {
if err == nil {
return false
}
msg := strings.ToLower(err.Error())
return strings.Contains(msg, "relation \"chat_channelIdentity_presence\" does not exist") ||
strings.Contains(msg, "column \"user_id\" of relation \"channelidentities\" does not exist") ||
strings.Contains(msg, "column \"sender_user_id\" of relation \"chat_messages\" does not exist") ||
strings.Contains(msg, "column \"created_by_user_id\" of relation \"chats\" does not exist")
}
func createUserForChatPresence(ctx context.Context, queries *sqlc.Queries) (string, error) {
row, err := queries.CreateUser(ctx, sqlc.CreateUserParams{
IsActive: true,
@@ -83,7 +71,10 @@ func createBotForChatPresence(ctx context.Context, queries *sqlc.Queries, ownerU
if err != nil {
return "", err
}
meta, _ := json.Marshal(map[string]any{"source": "chat-presence-integration-test"})
meta, err := json.Marshal(map[string]any{"source": "chat-presence-integration-test"})
if err != nil {
return "", err
}
row, err := queries.CreateBot(ctx, sqlc.CreateBotParams{
OwnerUserID: pgOwnerID,
Type: "personal",
@@ -105,10 +96,6 @@ func setupObservedChatScenario(t *testing.T) (chatPresenceFixture, string, strin
ownerUserID, err := createUserForChatPresence(ctx, fixture.queries)
if err != nil {
if isLegacyChatPresenceSchemaError(err) {
fixture.cleanup()
t.Skipf("skip integration test on legacy schema: %v", err)
}
fixture.cleanup()
t.Fatalf("create owner user failed: %v", err)
}
@@ -128,10 +115,6 @@ func setupObservedChatScenario(t *testing.T) (chatPresenceFixture, string, strin
Title: "presence-observed",
})
if err != nil {
if isLegacyChatPresenceSchemaError(err) {
fixture.cleanup()
t.Skipf("skip integration test on legacy schema: %v", err)
}
fixture.cleanup()
t.Fatalf("create chat failed: %v", err)
}
@@ -143,10 +126,6 @@ func setupObservedChatScenario(t *testing.T) (chatPresenceFixture, string, strin
"presence-observer",
)
if err != nil {
if isLegacyChatPresenceSchemaError(err) {
fixture.cleanup()
t.Skipf("skip integration test on legacy schema: %v", err)
}
fixture.cleanup()
t.Fatalf("resolve channelIdentity failed: %v", err)
}
@@ -165,10 +144,6 @@ func setupObservedChatScenario(t *testing.T) (chatPresenceFixture, string, strin
nil,
)
if err != nil {
if isLegacyChatPresenceSchemaError(err) {
fixture.cleanup()
t.Skipf("skip integration test on legacy schema: %v", err)
}
fixture.cleanup()
t.Fatalf("persist message failed: %v", err)
}
+2 -1
View File
@@ -237,7 +237,8 @@ type ChatRequest struct {
BotID string `json:"-"`
ChatID string `json:"-"`
Token string `json:"-"`
ChannelIdentityID string `json:"-"`
UserID string `json:"-"`
SourceChannelIdentityID string `json:"-"`
ContainerID string `json:"-"`
DisplayName string `json:"-"`
RouteID string `json:"-"`
+6 -6
View File
@@ -53,7 +53,7 @@ func (h *BindHandler) Issue(c echo.Context) error {
if h.service == nil {
return echo.NewHTTPError(http.StatusServiceUnavailable, "bind service not available")
}
channelIdentityID, err := h.requireChannelIdentityID(c)
userID, err := h.requireUserID(c)
if err != nil {
return err
}
@@ -68,7 +68,7 @@ func (h *BindHandler) Issue(c echo.Context) error {
ttl = time.Duration(req.TTLSeconds) * time.Second
}
code, err := h.service.Issue(c.Request().Context(), channelIdentityID, strings.TrimSpace(req.Platform), ttl)
code, err := h.service.Issue(c.Request().Context(), userID, strings.TrimSpace(req.Platform), ttl)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
@@ -79,13 +79,13 @@ func (h *BindHandler) Issue(c echo.Context) error {
})
}
func (h *BindHandler) requireChannelIdentityID(c echo.Context) (string, error) {
channelIdentityID, err := auth.ChannelIdentityIDFromContext(c)
func (h *BindHandler) requireUserID(c echo.Context) (string, error) {
userID, err := auth.UserIDFromContext(c)
if err != nil {
return "", err
}
if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil {
if err := identity.ValidateChannelIdentityID(userID); err != nil {
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
return channelIdentityID, nil
return userID, nil
}
+1 -1
View File
@@ -161,7 +161,7 @@ func (h *ChannelHandler) GetChannel(c echo.Context) error {
}
func (h *ChannelHandler) requireChannelIdentityID(c echo.Context) (string, error) {
channelIdentityID, err := auth.ChannelIdentityIDFromContext(c)
channelIdentityID, err := auth.UserIDFromContext(c)
if err != nil {
return "", err
}
+21 -7
View File
@@ -189,7 +189,8 @@ func (h *ChatHandler) SendMessage(c echo.Context) error {
req.BotID = chatObj.BotID
req.ChatID = chatID
req.Token = c.Request().Header.Get("Authorization")
req.ChannelIdentityID = channelIdentityID
req.UserID = channelIdentityID
req.SourceChannelIdentityID = channelIdentityID
resp, err := h.resolver.Chat(c.Request().Context(), req)
if err != nil {
@@ -224,7 +225,8 @@ func (h *ChatHandler) StreamMessage(c echo.Context) error {
req.BotID = chatObj.BotID
req.ChatID = chatID
req.Token = c.Request().Header.Get("Authorization")
req.ChannelIdentityID = channelIdentityID
req.UserID = channelIdentityID
req.SourceChannelIdentityID = channelIdentityID
c.Response().Header().Set(echo.HeaderContentType, "text/event-stream")
c.Response().Header().Set(echo.HeaderCacheControl, "no-cache")
@@ -258,7 +260,10 @@ func (h *ChatHandler) StreamMessage(c echo.Context) error {
if err != nil {
h.logger.Error("chat stream failed", slog.Any("error", err))
errData := map[string]string{"error": err.Error()}
data, _ := json.Marshal(errData)
data, marshalErr := json.Marshal(errData)
if marshalErr != nil {
return echo.NewHTTPError(http.StatusInternalServerError, marshalErr.Error())
}
writer.WriteString(fmt.Sprintf("data: %s\n\n", string(data)))
writer.Flush()
flusher.Flush()
@@ -500,7 +505,7 @@ func (h *ChatHandler) ListThreads(c echo.Context) error {
// --- helpers ---
func (h *ChatHandler) requireChannelIdentityID(c echo.Context) (string, error) {
channelIdentityID, err := auth.ChannelIdentityIDFromContext(c)
channelIdentityID, err := auth.UserIDFromContext(c)
if err != nil {
return "", err
}
@@ -558,7 +563,10 @@ func (h *ChatHandler) requireParticipant(ctx context.Context, chatID, channelIde
}
// Admin bypass.
if h.accountService != nil {
isAdmin, _ := h.accountService.IsAdmin(ctx, channelIdentityID)
isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
if isAdmin {
return nil
}
@@ -579,7 +587,10 @@ func (h *ChatHandler) requireReadable(ctx context.Context, chatID, channelIdenti
}
// Admin bypass.
if h.accountService != nil {
isAdmin, _ := h.accountService.IsAdmin(ctx, channelIdentityID)
isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
if isAdmin {
return nil
}
@@ -600,7 +611,10 @@ func (h *ChatHandler) requireRole(ctx context.Context, chatID, channelIdentityID
}
// Admin bypass.
if h.accountService != nil {
isAdmin, _ := h.accountService.IsAdmin(ctx, channelIdentityID)
isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
if isAdmin {
return nil
}
+1 -1
View File
@@ -627,7 +627,7 @@ func (h *ContainerdHandler) requireBotAccess(c echo.Context) (string, error) {
}
func (h *ContainerdHandler) requireChannelIdentityID(c echo.Context) (string, error) {
channelIdentityID, err := auth.ChannelIdentityIDFromContext(c)
channelIdentityID, err := auth.UserIDFromContext(c)
if err != nil {
return "", err
}
+4 -4
View File
@@ -85,10 +85,10 @@ func (h *EmbeddingsHandler) Embed(c echo.Context) error {
req.Input.ImageURL = strings.TrimSpace(req.Input.ImageURL)
req.Input.VideoURL = strings.TrimSpace(req.Input.VideoURL)
channelIdentityID := ""
userID := ""
if c.Get("user") != nil {
if value, err := auth.ChannelIdentityIDFromContext(c); err == nil {
channelIdentityID = value
if value, err := auth.UserIDFromContext(c); err == nil {
userID = value
}
}
result, err := h.resolver.Embed(c.Request().Context(), embeddings.Request{
@@ -101,7 +101,7 @@ func (h *EmbeddingsHandler) Embed(c echo.Context) error {
ImageURL: req.Input.ImageURL,
VideoURL: req.Input.VideoURL,
},
ChannelIdentityID: channelIdentityID,
ChannelIdentityID: userID,
})
if err != nil {
message := err.Error()
+1 -1
View File
@@ -228,7 +228,7 @@ func (h *LocalChannelHandler) ensureChatParticipant(ctx context.Context, chatID,
}
func (h *LocalChannelHandler) requireChannelIdentityID(c echo.Context) (string, error) {
channelIdentityID, err := auth.ChannelIdentityIDFromContext(c)
channelIdentityID, err := auth.UserIDFromContext(c)
if err != nil {
return "", err
}
+1 -1
View File
@@ -218,7 +218,7 @@ func (h *MCPHandler) Delete(c echo.Context) error {
}
func (h *MCPHandler) requireChannelIdentityID(c echo.Context) (string, error) {
userID, err := auth.ChannelIdentityIDFromContext(c)
userID, err := auth.UserIDFromContext(c)
if err != nil {
return "", err
}
+5 -2
View File
@@ -364,7 +364,10 @@ func (h *MemoryHandler) requireChatParticipant(ctx context.Context, chatID, chan
return echo.NewHTTPError(http.StatusInternalServerError, "chat service not configured")
}
if h.accountService != nil {
isAdmin, _ := h.accountService.IsAdmin(ctx, channelIdentityID)
isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
if isAdmin {
return nil
}
@@ -380,7 +383,7 @@ func (h *MemoryHandler) requireChatParticipant(ctx context.Context, chatID, chan
}
func (h *MemoryHandler) requireChannelIdentityID(c echo.Context) (string, error) {
channelIdentityID, err := auth.ChannelIdentityIDFromContext(c)
channelIdentityID, err := auth.UserIDFromContext(c)
if err != nil {
return "", err
}
+1 -1
View File
@@ -164,7 +164,7 @@ func (h *ModelsHandler) Enable(c echo.Context) error {
if h.settingsService == nil {
return echo.NewHTTPError(http.StatusInternalServerError, "settings service not configured")
}
userID, err := auth.ChannelIdentityIDFromContext(c)
userID, err := auth.UserIDFromContext(c)
if err != nil {
return err
}
+10 -10
View File
@@ -40,7 +40,7 @@ type preauthIssueRequest struct {
}
func (h *PreauthHandler) Issue(c echo.Context) error {
channelIdentityID, err := h.requireChannelIdentityID(c)
userID, err := h.requireUserID(c)
if err != nil {
return err
}
@@ -48,7 +48,7 @@ func (h *PreauthHandler) Issue(c echo.Context) error {
if botID == "" {
return echo.NewHTTPError(http.StatusBadRequest, "bot id is required")
}
if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil {
if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil {
return err
}
var req preauthIssueRequest
@@ -59,33 +59,33 @@ func (h *PreauthHandler) Issue(c echo.Context) error {
if req.TTLSeconds > 0 {
ttl = time.Duration(req.TTLSeconds) * time.Second
}
key, err := h.service.Issue(c.Request().Context(), botID, channelIdentityID, ttl)
key, err := h.service.Issue(c.Request().Context(), botID, userID, ttl)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
return c.JSON(http.StatusOK, key)
}
func (h *PreauthHandler) requireChannelIdentityID(c echo.Context) (string, error) {
channelIdentityID, err := auth.ChannelIdentityIDFromContext(c)
func (h *PreauthHandler) requireUserID(c echo.Context) (string, error) {
userID, err := auth.UserIDFromContext(c)
if err != nil {
return "", err
}
if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil {
if err := identity.ValidateChannelIdentityID(userID); err != nil {
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
return channelIdentityID, nil
return userID, nil
}
func (h *PreauthHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
func (h *PreauthHandler) authorizeBotAccess(ctx context.Context, userID, botID string) (bots.Bot, error) {
if h.botService == nil || h.accountService == nil {
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured")
}
isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID)
isAdmin, err := h.accountService.IsAdmin(ctx, userID)
if err != nil {
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false})
bot, err := h.botService.AuthorizeAccess(ctx, userID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false})
if err != nil {
if errors.Is(err, bots.ErrBotNotFound) {
return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found")
+17 -17
View File
@@ -51,7 +51,7 @@ func (h *ScheduleHandler) Register(e *echo.Echo) {
// @Failure 500 {object} ErrorResponse
// @Router /bots/{bot_id}/schedule [post]
func (h *ScheduleHandler) Create(c echo.Context) error {
channelIdentityID, err := h.requireChannelIdentityID(c)
userID, err := h.requireUserID(c)
if err != nil {
return err
}
@@ -59,7 +59,7 @@ func (h *ScheduleHandler) Create(c echo.Context) error {
if botID == "" {
return echo.NewHTTPError(http.StatusBadRequest, "bot id is required")
}
if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil {
if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil {
return err
}
var req schedule.CreateRequest
@@ -82,7 +82,7 @@ func (h *ScheduleHandler) Create(c echo.Context) error {
// @Failure 500 {object} ErrorResponse
// @Router /bots/{bot_id}/schedule [get]
func (h *ScheduleHandler) List(c echo.Context) error {
channelIdentityID, err := h.requireChannelIdentityID(c)
userID, err := h.requireUserID(c)
if err != nil {
return err
}
@@ -90,7 +90,7 @@ func (h *ScheduleHandler) List(c echo.Context) error {
if botID == "" {
return echo.NewHTTPError(http.StatusBadRequest, "bot id is required")
}
if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil {
if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil {
return err
}
items, err := h.service.List(c.Request().Context(), botID)
@@ -111,7 +111,7 @@ func (h *ScheduleHandler) List(c echo.Context) error {
// @Failure 500 {object} ErrorResponse
// @Router /bots/{bot_id}/schedule/{id} [get]
func (h *ScheduleHandler) Get(c echo.Context) error {
channelIdentityID, err := h.requireChannelIdentityID(c)
userID, err := h.requireUserID(c)
if err != nil {
return err
}
@@ -130,7 +130,7 @@ func (h *ScheduleHandler) Get(c echo.Context) error {
if item.BotID != botID {
return echo.NewHTTPError(http.StatusForbidden, "bot mismatch")
}
if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil {
if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil {
return err
}
return c.JSON(http.StatusOK, item)
@@ -147,7 +147,7 @@ func (h *ScheduleHandler) Get(c echo.Context) error {
// @Failure 500 {object} ErrorResponse
// @Router /bots/{bot_id}/schedule/{id} [put]
func (h *ScheduleHandler) Update(c echo.Context) error {
channelIdentityID, err := h.requireChannelIdentityID(c)
userID, err := h.requireUserID(c)
if err != nil {
return err
}
@@ -170,7 +170,7 @@ func (h *ScheduleHandler) Update(c echo.Context) error {
if item.BotID != botID {
return echo.NewHTTPError(http.StatusForbidden, "bot mismatch")
}
if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil {
if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil {
return err
}
resp, err := h.service.Update(c.Request().Context(), id, req)
@@ -190,7 +190,7 @@ func (h *ScheduleHandler) Update(c echo.Context) error {
// @Failure 500 {object} ErrorResponse
// @Router /bots/{bot_id}/schedule/{id} [delete]
func (h *ScheduleHandler) Delete(c echo.Context) error {
channelIdentityID, err := h.requireChannelIdentityID(c)
userID, err := h.requireUserID(c)
if err != nil {
return err
}
@@ -209,7 +209,7 @@ func (h *ScheduleHandler) Delete(c echo.Context) error {
if item.BotID != botID {
return echo.NewHTTPError(http.StatusForbidden, "bot mismatch")
}
if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil {
if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil {
return err
}
if err := h.service.Delete(c.Request().Context(), id); err != nil {
@@ -218,26 +218,26 @@ func (h *ScheduleHandler) Delete(c echo.Context) error {
return c.NoContent(http.StatusNoContent)
}
func (h *ScheduleHandler) requireChannelIdentityID(c echo.Context) (string, error) {
channelIdentityID, err := auth.ChannelIdentityIDFromContext(c)
func (h *ScheduleHandler) requireUserID(c echo.Context) (string, error) {
userID, err := auth.UserIDFromContext(c)
if err != nil {
return "", err
}
if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil {
if err := identity.ValidateChannelIdentityID(userID); err != nil {
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
return channelIdentityID, nil
return userID, nil
}
func (h *ScheduleHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
func (h *ScheduleHandler) authorizeBotAccess(ctx context.Context, userID, botID string) (bots.Bot, error) {
if h.botService == nil || h.accountService == nil {
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured")
}
isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID)
isAdmin, err := h.accountService.IsAdmin(ctx, userID)
if err != nil {
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false})
bot, err := h.botService.AuthorizeAccess(ctx, userID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false})
if err != nil {
if errors.Is(err, bots.ErrBotNotFound) {
return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found")
+1 -1
View File
@@ -130,7 +130,7 @@ func (h *SettingsHandler) Delete(c echo.Context) error {
}
func (h *SettingsHandler) requireChannelIdentityID(c echo.Context) (string, error) {
channelIdentityID, err := auth.ChannelIdentityIDFromContext(c)
channelIdentityID, err := auth.UserIDFromContext(c)
if err != nil {
return "", err
}
+1 -1
View File
@@ -433,7 +433,7 @@ func (h *SubagentHandler) AddSkills(c echo.Context) error {
}
func (h *SubagentHandler) requireChannelIdentityID(c echo.Context) (string, error) {
channelIdentityID, err := auth.ChannelIdentityIDFromContext(c)
channelIdentityID, err := auth.UserIDFromContext(c)
if err != nil {
return "", err
}
+1 -1
View File
@@ -907,7 +907,7 @@ func (h *UsersHandler) authorizeBotAccess(ctx context.Context, channelIdentityID
}
func (h *UsersHandler) requireChannelIdentityID(c echo.Context) (string, error) {
channelIdentityID, err := auth.ChannelIdentityIDFromContext(c)
channelIdentityID, err := auth.UserIDFromContext(c)
if err != nil {
return "", err
}
+2 -1
View File
@@ -170,7 +170,8 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel
BotID: identity.BotID,
ChatID: resolved.ChatID,
Token: token,
ChannelIdentityID: identity.UserID,
UserID: identity.UserID,
SourceChannelIdentityID: identity.ChannelIdentityID,
DisplayName: identity.DisplayName,
RouteID: resolved.RouteID,
ChatToken: chatToken,
+5 -2
View File
@@ -97,8 +97,11 @@ func TestChannelInboundProcessorWithIdentity(t *testing.T) {
if gateway.gotReq.Query != "hello" {
t.Errorf("expected query 'hello', got: %s", gateway.gotReq.Query)
}
if gateway.gotReq.ChannelIdentityID != "channelIdentity-1" {
t.Errorf("expected channel_identity_id 'channelIdentity-1', got: %s", gateway.gotReq.ChannelIdentityID)
if gateway.gotReq.UserID != "channelIdentity-1" {
t.Errorf("expected user_id 'channelIdentity-1', got: %s", gateway.gotReq.UserID)
}
if gateway.gotReq.SourceChannelIdentityID != "channelIdentity-1" {
t.Errorf("expected source_channel_identity_id 'channelIdentity-1', got: %s", gateway.gotReq.SourceChannelIdentityID)
}
if gateway.gotReq.ChatID != "chat-1" {
t.Errorf("expected chat_id 'chat-1', got: %s", gateway.gotReq.ChatID)
+33 -11
View File
@@ -209,7 +209,9 @@ func (r *IdentityResolver) Resolve(ctx context.Context, cfg channel.ChannelConfi
state.Decision = &decision
return state, err
}
if r.policy != nil && isGroupConversationType(msg.Conversation.Type) {
// Personal bots are owner-only and must not depend on member/guest/preauth bypass.
if r.policy != nil {
botType, err := r.policy.BotType(ctx, botID)
if err != nil {
return state, err
@@ -219,20 +221,35 @@ func (r *IdentityResolver) Resolve(ctx context.Context, cfg channel.ChannelConfi
if err != nil {
return state, err
}
if strings.TrimSpace(state.Identity.UserID) == "" || strings.TrimSpace(ownerUserID) != strings.TrimSpace(state.Identity.UserID) {
// Personal bots in group chats only answer owner messages.
isOwner := strings.TrimSpace(state.Identity.UserID) != "" &&
strings.TrimSpace(ownerUserID) == strings.TrimSpace(state.Identity.UserID)
if !isOwner {
if isGroupConversationType(msg.Conversation.Type) {
// Ignore non-owner group messages for personal bots.
state.Decision = &IdentityDecision{Stop: true}
return state, nil
}
// Owner can chat normally in group for personal bots.
state.Decision = &IdentityDecision{
Stop: true,
Reply: channel.Message{Text: r.unboundReply},
}
return state, nil
}
if isGroupConversationType(msg.Conversation.Type) {
// Owner can chat in group for personal bots.
state.Identity.ForceReply = true
}
return state, nil
}
}
// Phase 2: Authorization (bot membership check).
if r.members != nil {
if strings.TrimSpace(state.Identity.UserID) != "" {
isMember, _ := r.members.IsMember(ctx, botID, state.Identity.UserID)
isMember, err := r.members.IsMember(ctx, botID, state.Identity.UserID)
if err != nil {
return state, fmt.Errorf("check bot membership: %w", err)
}
if isMember {
return state, nil
}
@@ -256,9 +273,6 @@ func (r *IdentityResolver) Resolve(ctx context.Context, cfg channel.ChannelConfi
return state, err
}
if allowed {
if r.members != nil && strings.TrimSpace(state.Identity.UserID) != "" {
_ = r.members.UpsertMemberSimple(ctx, botID, state.Identity.UserID, "member")
}
return state, nil
}
}
@@ -312,9 +326,13 @@ func (r *IdentityResolver) tryHandlePreauthKey(ctx context.Context, msg channel.
return true, reply("Current channel account is not linked to a user."), nil
}
if r.members != nil {
_ = r.members.UpsertMemberSimple(ctx, botID, userID, "member")
if err := r.members.UpsertMemberSimple(ctx, botID, userID, "member"); err != nil {
return true, IdentityDecision{}, fmt.Errorf("upsert preauth member: %w", err)
}
}
if _, err := r.preauth.MarkUsed(ctx, key.ID); err != nil {
return true, IdentityDecision{}, fmt.Errorf("mark preauth key used: %w", err)
}
_, _ = r.preauth.MarkUsed(ctx, key.ID)
return true, reply(r.preauthReply), nil
}
@@ -365,7 +383,11 @@ func (r *IdentityResolver) tryHandleBindCode(ctx context.Context, msg channel.In
// Resolve linked user after binding.
newUserID := code.IssuedByUserID
if r.channelIdentities != nil {
if linkedUserID, err := r.channelIdentities.GetLinkedUserID(ctx, channelIdentityID); err == nil && strings.TrimSpace(linkedUserID) != "" {
linkedUserID, err := r.channelIdentities.GetLinkedUserID(ctx, channelIdentityID)
if err != nil {
return true, IdentityDecision{}, "", fmt.Errorf("resolve linked user after bind: %w", err)
}
if strings.TrimSpace(linkedUserID) != "" {
newUserID = linkedUserID
}
}
+32 -3
View File
@@ -145,7 +145,7 @@ func (f *fakeBindService) Consume(ctx context.Context, code bind.Code, channelCh
return f.consumeErr
}
func TestIdentityResolverAllowGuestUpsertsMember(t *testing.T) {
func TestIdentityResolverAllowGuestWithoutMembershipSideEffect(t *testing.T) {
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-1"}}
memberSvc := &fakeMemberService{isMember: false}
policySvc := &fakePolicyService{allow: true, botType: "public"}
@@ -165,8 +165,8 @@ func TestIdentityResolverAllowGuestUpsertsMember(t *testing.T) {
if state.Identity.ChannelIdentityID != "channelIdentity-1" {
t.Fatalf("expected channelIdentity-1, got: %s", state.Identity.ChannelIdentityID)
}
if !memberSvc.upsertCalled {
t.Fatal("expected UpsertMemberSimple to be called")
if memberSvc.upsertCalled {
t.Fatal("guest allow should not upsert membership")
}
if state.Decision != nil {
t.Fatal("expected no decision for allowed guest")
@@ -376,6 +376,35 @@ func TestIdentityResolverPersonalBotAllowsOwnerDirectWithoutMembership(t *testin
}
}
func TestIdentityResolverPersonalBotRejectsNonOwnerDirectEvenIfMember(t *testing.T) {
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-non-owner"}}
memberSvc := &fakeMemberService{isMember: true}
policySvc := &fakePolicyService{allow: true, botType: "personal", ownerUserID: "channelIdentity-owner"}
resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, memberSvc, policySvc, nil, nil, "Access denied.", "")
msg := channel.InboundMessage{
BotID: "bot-1",
Channel: channel.ChannelType("feishu"),
Message: channel.Message{Text: "hello from non-owner"},
Sender: channel.Identity{SubjectID: "ext-non-owner"},
Conversation: channel.Conversation{
ID: "p2p-2",
Type: "p2p",
},
}
state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if state.Decision == nil || !state.Decision.Stop {
t.Fatal("non-owner direct message should be rejected for personal bot")
}
if state.Decision.Reply.Text != "Access denied." {
t.Fatalf("unexpected reject message: %s", state.Decision.Reply.Text)
}
}
func TestIdentityResolverBindRunsBeforeMembershipCheck(t *testing.T) {
shadowID := "channelIdentity-shadow"
humanID := "channelIdentity-human"
+2 -2
View File
@@ -55,8 +55,8 @@ func TestGenerateTriggerToken(t *testing.T) {
if sub, _ := claims["sub"].(string); sub != userID {
t.Errorf("expected sub=%s, got=%s", userID, sub)
}
if uid, _ := claims["channel_identity_id"].(string); uid != userID {
t.Errorf("expected channel_identity_id=%s, got=%s", userID, uid)
if uid, _ := claims["user_id"].(string); uid != userID {
t.Errorf("expected user_id=%s, got=%s", userID, uid)
}
exp, _ := claims["exp"].(float64)
if exp == 0 {