diff --git a/internal/auth/jwt.go b/internal/auth/jwt.go index 2454aacc..8c837b43 100644 --- a/internal/auth/jwt.go +++ b/internal/auth/jwt.go @@ -52,17 +52,9 @@ func UserIDFromContext(c echo.Context) (string, error) { if userID := claimString(claims, claimSubject); userID != "" { return userID, nil } - if legacyChannelIdentityID := claimString(claims, claimChannelIdentityID); legacyChannelIdentityID != "" { - return legacyChannelIdentityID, nil - } return "", echo.NewHTTPError(http.StatusUnauthorized, "user id missing") } -// ChannelIdentityIDFromContext is kept as compatibility alias and returns user id. -func ChannelIdentityIDFromContext(c echo.Context) (string, error) { - return UserIDFromContext(c) -} - // GenerateToken creates a signed JWT for the user. func GenerateToken(userID, secret string, expiresIn time.Duration) (string, time.Time, error) { if strings.TrimSpace(userID) == "" { @@ -78,11 +70,10 @@ func GenerateToken(userID, secret string, expiresIn time.Duration) (string, time now := time.Now().UTC() expiresAt := now.Add(expiresIn) claims := jwt.MapClaims{ - claimSubject: userID, - claimUserID: userID, - claimChannelIdentityID: userID, // legacy compatibility for handlers still reading channel_identity_id - "iat": now.Unix(), - "exp": expiresAt.Unix(), + claimSubject: userID, + claimUserID: userID, + "iat": now.Unix(), + "exp": expiresAt.Unix(), } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) signed, err := token.SignedString([]byte(secret)) diff --git a/internal/bind/service_integration_test.go b/internal/bind/service_integration_test.go index 48a2c145..02648e2c 100644 --- a/internal/bind/service_integration_test.go +++ b/internal/bind/service_integration_test.go @@ -15,8 +15,8 @@ import ( "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgxpool" - "github.com/memohai/memoh/internal/channelidentities" "github.com/memohai/memoh/internal/bind" + "github.com/memohai/memoh/internal/channelidentities" "github.com/memohai/memoh/internal/db" "github.com/memohai/memoh/internal/db/sqlc" ) @@ -47,7 +47,7 @@ func setupBindIntegrationTest(t *testing.T) (*sqlc.Queries, *channelidentities.S return queries, channelIdentitySvc, bindSvc, func() { pool.Close() } } -func createUser(ctx context.Context, queries *sqlc.Queries) (string, error) { +func createUserForBindTest(ctx context.Context, queries *sqlc.Queries) (string, error) { row, err := queries.CreateUser(ctx, sqlc.CreateUserParams{ IsActive: true, Metadata: []byte("{}"), @@ -58,12 +58,15 @@ func createUser(ctx context.Context, queries *sqlc.Queries) (string, error) { return db.UUIDToString(row.ID), nil } -func createBot(ctx context.Context, queries *sqlc.Queries, ownerUserID string) (string, error) { +func createBotForBindTest(ctx context.Context, queries *sqlc.Queries, ownerUserID string) (string, error) { pgOwnerID, err := db.ParseUUID(ownerUserID) if err != nil { return "", err } - meta, _ := json.Marshal(map[string]any{"source": "bind-integration-test"}) + meta, err := json.Marshal(map[string]any{"source": "bind-integration-test"}) + if err != nil { + return "", err + } row, err := queries.CreateBot(ctx, sqlc.CreateBotParams{ OwnerUserID: pgOwnerID, Type: "personal", @@ -82,15 +85,15 @@ func TestIntegrationConsumeBindCodeSuccessAndSingleUse(t *testing.T) { defer cleanup() ctx := context.Background() - ownerUserID, err := createUser(ctx, queries) + ownerUserID, err := createUserForBindTest(ctx, queries) if err != nil { t.Fatalf("create owner user failed: %v", err) } sourceChannelIdentity, err := channelIdentitySvc.Create(ctx, channelidentities.KindChannel) if err != nil { - t.Fatalf("create source channelIdentity failed: %v", err) + t.Fatalf("create source channel identity failed: %v", err) } - botID, err := createBot(ctx, queries, ownerUserID) + botID, err := createBotForBindTest(ctx, queries, ownerUserID) if err != nil { t.Fatalf("create bot failed: %v", err) } @@ -127,177 +130,6 @@ func TestIntegrationConsumeBindCodeSuccessAndSingleUse(t *testing.T) { } } -func TestIntegrationConsumeBindCodeRollbackOnLinkConflict(t *testing.T) { - queries, channelIdentitySvc, bindSvc, cleanup := setupBindIntegrationTest(t) - defer cleanup() - - ctx := context.Background() - ownerUserID, err := createUser(ctx, queries) - if err != nil { - t.Fatalf("create owner user failed: %v", err) - } - otherUserID, err := createUser(ctx, queries) - if err != nil { - t.Fatalf("create other user failed: %v", err) - } - sourceChannelIdentity, err := channelIdentitySvc.Create(ctx, channelidentities.KindChannel) - if err != nil { - t.Fatalf("create source channelIdentity failed: %v", err) - } - if err := channelIdentitySvc.LinkChannelIdentityToUser(ctx, sourceChannelIdentity.ID, otherUserID); err != nil { - t.Fatalf("pre-link source channelIdentity failed: %v", err) - } - botID, err := createBot(ctx, queries, ownerUserID) - if err != nil { - t.Fatalf("create bot failed: %v", err) - } - - code, err := bindSvc.Issue(ctx, botID, ownerUserID, 10*time.Minute) - if err != nil { - t.Fatalf("issue bind code failed: %v", err) - } - if err := bindSvc.Consume(ctx, code, sourceChannelIdentity.ID); !errors.Is(err, bind.ErrLinkConflict) { - t.Fatalf("expected ErrLinkConflict, got %v", err) - } - - after, err := bindSvc.Get(ctx, code.Token) - if err != nil { - t.Fatalf("get bind code failed: %v", err) - } - if !after.UsedAt.IsZero() { - t.Fatal("expected used_at to remain empty when consume fails") - } -} -package bind_test - -import ( - "context" - "encoding/json" - "errors" - "log/slog" - "os" - "testing" - "time" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgxpool" - - "github.com/memohai/memoh/internal/channelidentities" - "github.com/memohai/memoh/internal/bind" - "github.com/memohai/memoh/internal/db" - "github.com/memohai/memoh/internal/db/sqlc" -) - -func setupBindIntegrationTest(t *testing.T) (*sqlc.Queries, *channelidentities.Service, *bind.Service, func()) { - t.Helper() - - dsn := os.Getenv("TEST_POSTGRES_DSN") - if dsn == "" { - t.Skip("skip integration test: TEST_POSTGRES_DSN is not set") - } - - ctx := context.Background() - pool, err := pgxpool.New(ctx, dsn) - if err != nil { - t.Skipf("skip integration test: cannot connect to database: %v", err) - } - if err := pool.Ping(ctx); err != nil { - pool.Close() - t.Skipf("skip integration test: database ping failed: %v", err) - } - - queries := sqlc.New(pool) - logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) - channelIdentitySvc := channelidentities.NewService(logger, queries) - bindSvc := bind.NewService(logger, pool, queries) - - cleanup := func() { - pool.Close() - } - return queries, channelIdentitySvc, bindSvc, cleanup -} - -func createUserForBindTest(ctx context.Context, queries *sqlc.Queries) (string, error) { - row, err := queries.CreateUser(ctx, sqlc.CreateUserParams{ - IsActive: true, - Metadata: []byte("{}"), - }) - if err != nil { - return "", err - } - return db.UUIDToString(row.ID), nil -} - -func createBotForBindTest(ctx context.Context, queries *sqlc.Queries, ownerUserID string) (string, error) { - pgOwnerID, err := db.ParseUUID(ownerUserID) - if err != nil { - return "", err - } - meta, _ := json.Marshal(map[string]any{"source": "bind-integration-test"}) - row, err := queries.CreateBot(ctx, sqlc.CreateBotParams{ - OwnerUserID: pgOwnerID, - Type: "personal", - DisplayName: pgtype.Text{String: "bind-test-bot", Valid: true}, - AvatarUrl: pgtype.Text{}, - IsActive: true, - Metadata: meta, - }) - if err != nil { - return "", err - } - return db.UUIDToString(row.ID), nil -} - -func TestIntegrationConsumeBindCodeSuccessAndSingleUse(t *testing.T) { - queries, channelIdentitySvc, bindSvc, cleanup := setupBindIntegrationTest(t) - defer cleanup() - - ctx := context.Background() - ownerUserID, err := createUserForBindTest(ctx, queries) - if err != nil { - t.Fatalf("create owner user failed: %v", err) - } - sourceChannelIdentity, err := channelIdentitySvc.Create(ctx, channelidentities.KindChannel) - if err != nil { - t.Fatalf("create source channelIdentity failed: %v", err) - } - botID, err := createBotForBindTest(ctx, queries, ownerUserID) - if err != nil { - t.Fatalf("create bot failed: %v", err) - } - - code, err := bindSvc.Issue(ctx, botID, ownerUserID, 10*time.Minute) - if err != nil { - t.Fatalf("issue bind code failed: %v", err) - } - if err := bindSvc.Consume(ctx, code, sourceChannelIdentity.ID); err != nil { - t.Fatalf("consume bind code failed: %v", err) - } - - after, err := bindSvc.Get(ctx, code.Token) - if err != nil { - t.Fatalf("get bind code failed: %v", err) - } - if after.UsedAt.IsZero() { - t.Fatal("expected used_at to be set after successful consume") - } - if after.UsedByChannelIdentityID != sourceChannelIdentity.ID { - t.Fatalf("expected used_by_channel_identity_id=%s, got %s", sourceChannelIdentity.ID, after.UsedByChannelIdentityID) - } - - linkedUserID, err := channelIdentitySvc.GetLinkedUserID(ctx, sourceChannelIdentity.ID) - if err != nil { - t.Fatalf("get linked user failed: %v", err) - } - if linkedUserID != ownerUserID { - t.Fatalf("expected linked user=%s, got %s", ownerUserID, linkedUserID) - } - - if err := bindSvc.Consume(ctx, code, sourceChannelIdentity.ID); !errors.Is(err, bind.ErrCodeUsed) { - t.Fatalf("expected ErrCodeUsed on second consume, got %v", err) - } -} - func TestIntegrationConsumeBindCodeRollbackOnLinkConflict(t *testing.T) { queries, channelIdentitySvc, bindSvc, cleanup := setupBindIntegrationTest(t) defer cleanup() @@ -313,10 +145,10 @@ func TestIntegrationConsumeBindCodeRollbackOnLinkConflict(t *testing.T) { } sourceChannelIdentity, err := channelIdentitySvc.Create(ctx, channelidentities.KindChannel) if err != nil { - t.Fatalf("create source channelIdentity failed: %v", err) + t.Fatalf("create source channel identity failed: %v", err) } if err := channelIdentitySvc.LinkChannelIdentityToUser(ctx, sourceChannelIdentity.ID, otherUserID); err != nil { - t.Fatalf("pre-link source channelIdentity failed: %v", err) + t.Fatalf("pre-link source channel identity failed: %v", err) } botID, err := createBotForBindTest(ctx, queries, ownerUserID) if err != nil { @@ -339,224 +171,3 @@ func TestIntegrationConsumeBindCodeRollbackOnLinkConflict(t *testing.T) { t.Fatal("expected used_at to remain empty when consume fails") } } -package bind_test - -import ( - "context" - "encoding/json" - "errors" - "log/slog" - "os" - "testing" - "time" - - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgxpool" - - "github.com/memohai/memoh/internal/channelidentities" - "github.com/memohai/memoh/internal/bind" - "github.com/memohai/memoh/internal/db" - "github.com/memohai/memoh/internal/db/sqlc" -) - -func setupBindIntegrationTest(t *testing.T) (*sqlc.Queries, *channelidentities.Service, *bind.Service, func()) { - t.Helper() - - dsn := os.Getenv("TEST_POSTGRES_DSN") - if dsn == "" { - t.Skip("skip integration test: TEST_POSTGRES_DSN is not set") - } - - ctx := context.Background() - pool, err := pgxpool.New(ctx, dsn) - if err != nil { - t.Skipf("skip integration test: cannot connect to database: %v", err) - } - if err := pool.Ping(ctx); err != nil { - pool.Close() - t.Skipf("skip integration test: database ping failed: %v", err) - } - - queries := sqlc.New(pool) - logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) - channelIdentitySvc := channelidentities.NewService(logger, queries) - bindSvc := bind.NewService(logger, pool, queries) - - cleanup := func() { - pool.Close() - } - return queries, channelIdentitySvc, bindSvc, cleanup -} - -func createBotForBindTest(ctx context.Context, queries *sqlc.Queries, ownerChannelIdentityID string) (string, error) { - pgOwnerID, err := db.ParseUUID(ownerChannelIdentityID) - if err != nil { - return "", err - } - meta, _ := json.Marshal(map[string]any{"source": "bind-integration-test"}) - row, err := queries.CreateBot(ctx, sqlc.CreateBotParams{ - OwnerChannelIdentityID: pgOwnerID, - Type: "personal", - DisplayName: pgtype.Text{String: "bind-test-bot", Valid: true}, - AvatarUrl: pgtype.Text{}, - IsActive: true, - Metadata: meta, - }) - if err != nil { - return "", err - } - return db.UUIDToString(row.ID), nil -} - -func createChatForBindTest(ctx context.Context, queries *sqlc.Queries, botID, channelIdentityID string) (string, error) { - pgBotID, err := db.ParseUUID(botID) - if err != nil { - return "", err - } - pgChannelIdentityID, err := db.ParseUUID(channelIdentityID) - if err != nil { - return "", err - } - row, err := queries.CreateChat(ctx, sqlc.CreateChatParams{ - BotID: pgBotID, - Kind: "direct", - ParentChatID: pgtype.UUID{}, - Title: pgtype.Text{}, - CreatedBy: pgChannelIdentityID, - Metadata: []byte("{}"), - }) - if err != nil { - return "", err - } - return db.UUIDToString(row.ID), nil -} - -func TestIntegrationConsumeBindCodeSuccessAndSingleUse(t *testing.T) { - queries, channelIdentitySvc, bindSvc, cleanup := setupBindIntegrationTest(t) - defer cleanup() - - ctx := context.Background() - human, err := channelIdentitySvc.Create(ctx, channelidentities.KindHuman) - if err != nil { - t.Fatalf("create human failed: %v", err) - } - shadow, err := channelIdentitySvc.Create(ctx, channelidentities.KindShadow) - if err != nil { - t.Fatalf("create shadow failed: %v", err) - } - botID, err := createBotForBindTest(ctx, queries, human.ID) - if err != nil { - t.Fatalf("create bot failed: %v", err) - } - chatID, err := createChatForBindTest(ctx, queries, botID, human.ID) - if err != nil { - t.Fatalf("create chat failed: %v", err) - } - pgChatID, err := db.ParseUUID(chatID) - if err != nil { - t.Fatalf("parse chat id failed: %v", err) - } - pgShadowID, err := db.ParseUUID(shadow.ID) - if err != nil { - t.Fatalf("parse shadow id failed: %v", err) - } - pgHumanID, err := db.ParseUUID(human.ID) - if err != nil { - t.Fatalf("parse human id failed: %v", err) - } - if _, err := queries.AddChatParticipant(ctx, sqlc.AddChatParticipantParams{ - ChatID: pgChatID, - ChannelIdentityID: pgShadowID, - Role: "member", - }); err != nil { - t.Fatalf("add shadow participant failed: %v", err) - } - - code, err := bindSvc.Issue(ctx, botID, human.ID, 10*time.Minute) - if err != nil { - t.Fatalf("issue bind code failed: %v", err) - } - if err := bindSvc.Consume(ctx, code, shadow.ID); err != nil { - t.Fatalf("consume bind code failed: %v", err) - } - - after, err := bindSvc.Get(ctx, code.Token) - if err != nil { - t.Fatalf("get bind code failed: %v", err) - } - if after.UsedAt.IsZero() { - t.Fatal("expected used_at to be set after successful consume") - } - if after.UsedByChannelIdentityID != shadow.ID { - t.Fatalf("expected used_by_channel_identity_id=%s, got %s", shadow.ID, after.UsedByChannelIdentityID) - } - - canonical, err := channelIdentitySvc.Canonicalize(ctx, shadow.ID) - if err != nil { - t.Fatalf("canonicalize failed: %v", err) - } - if canonical != human.ID { - t.Fatalf("expected canonical=%s, got %s", human.ID, canonical) - } - if _, err := queries.GetChatParticipant(ctx, sqlc.GetChatParticipantParams{ - ChatID: pgChatID, - ChannelIdentityID: pgHumanID, - }); err != nil { - t.Fatalf("expected human participant after bind, got error: %v", err) - } - if _, err := queries.GetChatParticipant(ctx, sqlc.GetChatParticipantParams{ - ChatID: pgChatID, - ChannelIdentityID: pgShadowID, - }); !errors.Is(err, pgx.ErrNoRows) { - t.Fatalf("expected shadow participant removed after bind, got %v", err) - } - - if err := bindSvc.Consume(ctx, code, shadow.ID); !errors.Is(err, bind.ErrCodeUsed) { - t.Fatalf("expected ErrCodeUsed on second consume, got %v", err) - } -} - -func TestIntegrationConsumeBindCodeRollbackOnLinkConflict(t *testing.T) { - queries, channelIdentitySvc, bindSvc, cleanup := setupBindIntegrationTest(t) - defer cleanup() - - ctx := context.Background() - humanA, err := channelIdentitySvc.Create(ctx, channelidentities.KindHuman) - if err != nil { - t.Fatalf("create humanA failed: %v", err) - } - humanB, err := channelIdentitySvc.Create(ctx, channelidentities.KindHuman) - if err != nil { - t.Fatalf("create humanB failed: %v", err) - } - shadow, err := channelIdentitySvc.Create(ctx, channelidentities.KindShadow) - if err != nil { - t.Fatalf("create shadow failed: %v", err) - } - botID, err := createBotForBindTest(ctx, queries, humanA.ID) - if err != nil { - t.Fatalf("create bot failed: %v", err) - } - - // Pre-link shadow to another user so bind consume hits link conflict. - if err := channelIdentitySvc.LinkChannelIdentityToUser(ctx, shadow.ID, humanB.ID); err != nil { - t.Fatalf("pre link shadow->humanB failed: %v", err) - } - - code, err := bindSvc.Issue(ctx, botID, humanA.ID, 10*time.Minute) - if err != nil { - t.Fatalf("issue bind code failed: %v", err) - } - if err := bindSvc.Consume(ctx, code, shadow.ID); !errors.Is(err, bind.ErrLinkConflict) { - t.Fatalf("expected ErrLinkConflict, got %v", err) - } - - after, err := bindSvc.Get(ctx, code.Token) - if err != nil { - t.Fatalf("get bind code failed: %v", err) - } - if !after.UsedAt.IsZero() { - t.Fatal("expected used_at to remain empty when consume fails") - } -} diff --git a/internal/bind/service_link_integration_test.go b/internal/bind/service_link_integration_test.go index 19b8f78a..05e88b63 100644 --- a/internal/bind/service_link_integration_test.go +++ b/internal/bind/service_link_integration_test.go @@ -7,7 +7,6 @@ import ( "fmt" "log/slog" "os" - "strings" "testing" "time" @@ -61,7 +60,10 @@ func createBotForBind(ctx context.Context, queries *sqlc.Queries, ownerUserID st if err != nil { return "", err } - meta, _ := json.Marshal(map[string]any{"source": "bind-integration-test"}) + meta, err := json.Marshal(map[string]any{"source": "bind-integration-test"}) + if err != nil { + return "", err + } row, err := queries.CreateBot(ctx, sqlc.CreateBotParams{ OwnerUserID: pgOwnerID, Type: "personal", @@ -75,14 +77,6 @@ func createBotForBind(ctx context.Context, queries *sqlc.Queries, ownerUserID st return db.UUIDToString(row.ID), nil } -func isLegacyBindSchemaError(err error) bool { - if err == nil { - return false - } - msg := strings.ToLower(err.Error()) - return strings.Contains(msg, "relation \"users\" does not exist") -} - func TestBindConsumeLinksChannelIdentityToIssuerUser(t *testing.T) { queries, channelIdentitySvc, bindSvc, cleanup := setupBindLinkIntegrationTest(t) defer cleanup() @@ -90,9 +84,6 @@ func TestBindConsumeLinksChannelIdentityToIssuerUser(t *testing.T) { ctx := context.Background() ownerUserID, err := createUserForBind(ctx, queries) if err != nil { - if isLegacyBindSchemaError(err) { - t.Skipf("skip integration test on legacy schema: %v", err) - } t.Fatalf("create owner user failed: %v", err) } sourceChannelIdentity, err := channelIdentitySvc.ResolveByChannelIdentity(ctx, "feishu", fmt.Sprintf("bind-src-%d", time.Now().UnixNano()), "source") @@ -134,9 +125,6 @@ func TestBindConsumeConflictDoesNotMarkUsed(t *testing.T) { ctx := context.Background() issuerUserID, err := createUserForBind(ctx, queries) if err != nil { - if isLegacyBindSchemaError(err) { - t.Skipf("skip integration test on legacy schema: %v", err) - } t.Fatalf("create issuer user failed: %v", err) } otherUserID, err := createUserForBind(ctx, queries) diff --git a/internal/channelidentities/service_identity_integration_test.go b/internal/channelidentities/service_identity_integration_test.go index da24c05c..827ebcb0 100644 --- a/internal/channelidentities/service_identity_integration_test.go +++ b/internal/channelidentities/service_identity_integration_test.go @@ -5,7 +5,6 @@ import ( "fmt" "log/slog" "os" - "strings" "testing" "time" @@ -43,16 +42,6 @@ func formatUUID(bytes [16]byte) string { return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x", bytes[0:4], bytes[4:6], bytes[6:8], bytes[8:10], bytes[10:16]) } -func isLegacyChannelIdentitySchemaError(err error) bool { - if err == nil { - return false - } - msg := strings.ToLower(err.Error()) - return strings.Contains(msg, "channelidentities_kind_check") || - strings.Contains(msg, "column \"user_id\" of relation \"channelidentities\" does not exist") || - strings.Contains(msg, "column \"channel_subject_id\" of relation \"channelidentities\" does not exist") -} - func TestChannelIdentityResolveChannelIdentityStable(t *testing.T) { svc, _, cleanup := setupChannelIdentityIdentityIntegrationTest(t) defer cleanup() @@ -61,9 +50,6 @@ func TestChannelIdentityResolveChannelIdentityStable(t *testing.T) { externalID := fmt.Sprintf("stable_%d", time.Now().UnixNano()) first, err := svc.ResolveByChannelIdentity(ctx, "feishu", externalID, "first") if err != nil { - if isLegacyChannelIdentitySchemaError(err) { - t.Skipf("skip integration test on legacy schema: %v", err) - } t.Fatalf("first resolve failed: %v", err) } second, err := svc.ResolveByChannelIdentity(ctx, "feishu", externalID, "second") @@ -82,9 +68,6 @@ func TestChannelIdentityLinkToUser(t *testing.T) { ctx := context.Background() channelIdentity, err := svc.ResolveByChannelIdentity(ctx, "telegram", fmt.Sprintf("link_%d", time.Now().UnixNano()), "tg") if err != nil { - if isLegacyChannelIdentitySchemaError(err) { - t.Skipf("skip integration test on legacy schema: %v", err) - } t.Fatalf("resolve channelIdentity failed: %v", err) } user, err := queries.CreateUser(ctx, sqlc.CreateUserParams{ diff --git a/internal/chat/resolver.go b/internal/chat/resolver.go index 6e07202d..12452e23 100644 --- a/internal/chat/resolver.go +++ b/internal/chat/resolver.go @@ -208,7 +208,7 @@ func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContex r.enforceGroupMemoryPolicy(ctx, req.ChatID, &chatSettings) } - userSettings, err := r.loadUserSettings(ctx, req.ChannelIdentityID) + userSettings, err := r.loadUserSettings(ctx, req.UserID) if err != nil { return resolvedContext{}, err } @@ -297,7 +297,7 @@ func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContex BotID: req.BotID, SessionID: req.ChatID, ContainerID: containerID, - ChannelIdentityID: firstNonEmpty(req.ChannelIdentityID, req.BotID), + ChannelIdentityID: firstNonEmpty(req.SourceChannelIdentityID, req.UserID), DisplayName: firstNonEmpty(req.DisplayName, "User"), CurrentPlatform: req.CurrentChannel, ReplyTarget: "", @@ -348,11 +348,11 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload sc chatID = "schedule-" + payload.ID } req := ChatRequest{ - BotID: botID, - ChatID: chatID, - Query: payload.Command, - ChannelIdentityID: payload.OwnerUserID, - Token: token, + BotID: botID, + ChatID: chatID, + Query: payload.Command, + UserID: payload.OwnerUserID, + Token: token, } rc, err := r.resolve(ctx, req) if err != nil { @@ -373,7 +373,7 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload sc BotID: rc.payload.Identity.BotID, SessionID: rc.payload.Identity.SessionID, ContainerID: rc.payload.Identity.ContainerID, - ChannelIdentityID: firstNonEmpty(payload.OwnerUserID, botID), + ChannelIdentityID: strings.TrimSpace(payload.OwnerUserID), DisplayName: "Scheduler", }, Attachments: rc.payload.Attachments, @@ -676,8 +676,8 @@ func (r *Resolver) loadMemoryContextMessage(ctx context.Context, req ChatRequest if settings.EnableChatMemory { scopes = append(scopes, memoryScope{Namespace: "chat", ScopeID: req.ChatID}) } - if settings.EnablePrivateMemory && strings.TrimSpace(req.ChannelIdentityID) != "" { - scopes = append(scopes, memoryScope{Namespace: "private", ScopeID: req.ChannelIdentityID}) + if settings.EnablePrivateMemory && strings.TrimSpace(req.UserID) != "" { + scopes = append(scopes, memoryScope{Namespace: "private", ScopeID: req.UserID}) } if settings.EnablePublicMemory { scopes = append(scopes, memoryScope{Namespace: "public", ScopeID: req.BotID}) @@ -778,7 +778,7 @@ func (r *Resolver) storeRound(ctx context.Context, req ChatRequest, messages []M fullRound = append(fullRound, messages...) r.storeMessages(ctx, req, fullRound) - r.storeMemory(ctx, req.BotID, req.ChatID, req.ChannelIdentityID, req.Query, fullRound) + r.storeMemory(ctx, req.BotID, req.ChatID, req.UserID, req.Query, fullRound) return nil } @@ -805,10 +805,12 @@ func (r *Resolver) storeMessages(ctx context.Context, req ChatRequest, messages if err != nil { continue } - senderID := "" + senderChannelIdentityID := "" + senderUserID := "" externalMessageID := "" if msg.Role == "user" { - senderID = req.ChannelIdentityID + senderChannelIdentityID = req.SourceChannelIdentityID + senderUserID = req.UserID externalMessageID = req.ExternalMessageID } if _, err := r.chatService.PersistMessage( @@ -816,8 +818,8 @@ func (r *Resolver) storeMessages(ctx context.Context, req ChatRequest, messages req.ChatID, req.BotID, req.RouteID, - "", - senderID, + senderChannelIdentityID, + senderUserID, req.CurrentChannel, externalMessageID, msg.Role, @@ -829,7 +831,7 @@ func (r *Resolver) storeMessages(ctx context.Context, req ChatRequest, messages } } -func (r *Resolver) storeMemory(ctx context.Context, botID, chatID, channelIdentityID, query string, messages []ModelMessage) { +func (r *Resolver) storeMemory(ctx context.Context, botID, chatID, userID, query string, messages []ModelMessage) { if r.memoryService == nil { return } @@ -869,9 +871,9 @@ func (r *Resolver) storeMemory(ctx context.Context, botID, chatID, channelIdenti r.addMemory(ctx, botID, memMsgs, "chat", chatID) } - // Write to private namespace if enabled and channel identity is known. - if cs.EnablePrivateMemory && strings.TrimSpace(channelIdentityID) != "" { - r.addMemory(ctx, botID, memMsgs, "private", channelIdentityID) + // Write to private namespace if enabled and user id is known. + if cs.EnablePrivateMemory && strings.TrimSpace(userID) != "" { + r.addMemory(ctx, botID, memMsgs, "private", userID) } // Write to public namespace if enabled. diff --git a/internal/chat/service.go b/internal/chat/service.go index 11ebbf8b..863db8c8 100644 --- a/internal/chat/service.go +++ b/internal/chat/service.go @@ -70,7 +70,10 @@ func (s *Service) Create(ctx context.Context, botID, channelIdentityID string, r } } - metadata, _ := json.Marshal(nonNilMap(req.Metadata)) + metadata, err := json.Marshal(nonNilMap(req.Metadata)) + if err != nil { + return Chat{}, fmt.Errorf("marshal chat metadata: %w", err) + } row, err := s.queries.CreateChat(ctx, sqlc.CreateChatParams{ BotID: pgBotID, @@ -393,7 +396,10 @@ func (s *Service) CreateRoute(ctx context.Context, chatID string, r Route) (Rout return Route{}, err } } - metadata, _ := json.Marshal(nonNilMap(r.Metadata)) + metadata, err := json.Marshal(nonNilMap(r.Metadata)) + if err != nil { + return Route{}, fmt.Errorf("marshal route metadata: %w", err) + } row, err := s.queries.CreateChatRoute(ctx, sqlc.CreateChatRouteParams{ ChatID: pgChatID, BotID: pgBotID, @@ -488,7 +494,10 @@ func (s *Service) ResolveChat(ctx context.Context, botID, platform, conversation if err == nil { // Route found, ensure the sender identity is a participant. if strings.TrimSpace(channelIdentityID) != "" { - ok, _ := s.IsParticipant(ctx, route.ChatID, channelIdentityID) + ok, checkErr := s.IsParticipant(ctx, route.ChatID, channelIdentityID) + if checkErr != nil { + return ResolveChatResult{}, fmt.Errorf("check chat participant: %w", checkErr) + } if !ok { if _, err := s.AddParticipant(ctx, route.ChatID, channelIdentityID, RoleMember); err != nil { s.logger.Warn("auto-add participant failed", slog.Any("error", err)) @@ -497,9 +506,17 @@ func (s *Service) ResolveChat(ctx context.Context, botID, platform, conversation } // Update reply target if changed. if strings.TrimSpace(replyTarget) != "" && replyTarget != route.ReplyTarget { - _ = s.UpdateRouteReplyTarget(ctx, route.ID, replyTarget) + if err := s.UpdateRouteReplyTarget(ctx, route.ID, replyTarget); err != nil && s.logger != nil { + s.logger.Warn("update route reply target failed", slog.Any("error", err)) + } + } + pgRouteChatID, parseErr := parseUUID(route.ChatID) + if parseErr != nil { + return ResolveChatResult{}, fmt.Errorf("parse route chat id: %w", parseErr) + } + if err := s.queries.TouchChat(ctx, pgRouteChatID); err != nil && s.logger != nil { + s.logger.Warn("touch chat failed", slog.Any("error", err)) } - _ = s.queries.TouchChat(ctx, mustParseUUID(route.ChatID)) return ResolveChatResult{ChatID: route.ChatID, RouteID: route.ID, Created: false}, nil } @@ -564,13 +581,22 @@ func (s *Service) PersistMessage(ctx context.Context, chatID, botID, routeID, se } var pgSender pgtype.UUID if strings.TrimSpace(senderChannelIdentityID) != "" { - pgSender, _ = parseUUID(senderChannelIdentityID) + pgSender, err = parseUUID(senderChannelIdentityID) + if err != nil { + return Message{}, fmt.Errorf("invalid sender channel identity id: %w", err) + } } var pgSenderUser pgtype.UUID if strings.TrimSpace(senderUserID) != "" { - pgSenderUser, _ = parseUUID(senderUserID) + pgSenderUser, err = parseUUID(senderUserID) + if err != nil { + return Message{}, fmt.Errorf("invalid sender user id: %w", err) + } + } + metaBytes, err := json.Marshal(nonNilMap(metadata)) + if err != nil { + return Message{}, fmt.Errorf("marshal message metadata: %w", err) } - metaBytes, _ := json.Marshal(nonNilMap(metadata)) if len(content) == 0 { content = []byte("{}") } @@ -813,11 +839,6 @@ func pgTimePtr(ts pgtype.Timestamptz) *time.Time { return &value } -func mustParseUUID(id string) pgtype.UUID { - pgID, _ := parseUUID(id) - return pgID -} - func nonNilMap(m map[string]any) map[string]any { if m == nil { return map[string]any{} diff --git a/internal/chat/service_presence_integration_test.go b/internal/chat/service_presence_integration_test.go index c8856b54..69ced70b 100644 --- a/internal/chat/service_presence_integration_test.go +++ b/internal/chat/service_presence_integration_test.go @@ -7,7 +7,6 @@ import ( "fmt" "log/slog" "os" - "strings" "testing" "time" @@ -56,17 +55,6 @@ func setupChatPresenceIntegrationTest(t *testing.T) chatPresenceFixture { } } -func isLegacyChatPresenceSchemaError(err error) bool { - if err == nil { - return false - } - msg := strings.ToLower(err.Error()) - return strings.Contains(msg, "relation \"chat_channelIdentity_presence\" does not exist") || - strings.Contains(msg, "column \"user_id\" of relation \"channelidentities\" does not exist") || - strings.Contains(msg, "column \"sender_user_id\" of relation \"chat_messages\" does not exist") || - strings.Contains(msg, "column \"created_by_user_id\" of relation \"chats\" does not exist") -} - func createUserForChatPresence(ctx context.Context, queries *sqlc.Queries) (string, error) { row, err := queries.CreateUser(ctx, sqlc.CreateUserParams{ IsActive: true, @@ -83,7 +71,10 @@ func createBotForChatPresence(ctx context.Context, queries *sqlc.Queries, ownerU if err != nil { return "", err } - meta, _ := json.Marshal(map[string]any{"source": "chat-presence-integration-test"}) + meta, err := json.Marshal(map[string]any{"source": "chat-presence-integration-test"}) + if err != nil { + return "", err + } row, err := queries.CreateBot(ctx, sqlc.CreateBotParams{ OwnerUserID: pgOwnerID, Type: "personal", @@ -105,10 +96,6 @@ func setupObservedChatScenario(t *testing.T) (chatPresenceFixture, string, strin ownerUserID, err := createUserForChatPresence(ctx, fixture.queries) if err != nil { - if isLegacyChatPresenceSchemaError(err) { - fixture.cleanup() - t.Skipf("skip integration test on legacy schema: %v", err) - } fixture.cleanup() t.Fatalf("create owner user failed: %v", err) } @@ -128,10 +115,6 @@ func setupObservedChatScenario(t *testing.T) (chatPresenceFixture, string, strin Title: "presence-observed", }) if err != nil { - if isLegacyChatPresenceSchemaError(err) { - fixture.cleanup() - t.Skipf("skip integration test on legacy schema: %v", err) - } fixture.cleanup() t.Fatalf("create chat failed: %v", err) } @@ -143,10 +126,6 @@ func setupObservedChatScenario(t *testing.T) (chatPresenceFixture, string, strin "presence-observer", ) if err != nil { - if isLegacyChatPresenceSchemaError(err) { - fixture.cleanup() - t.Skipf("skip integration test on legacy schema: %v", err) - } fixture.cleanup() t.Fatalf("resolve channelIdentity failed: %v", err) } @@ -165,10 +144,6 @@ func setupObservedChatScenario(t *testing.T) (chatPresenceFixture, string, strin nil, ) if err != nil { - if isLegacyChatPresenceSchemaError(err) { - fixture.cleanup() - t.Skipf("skip integration test on legacy schema: %v", err) - } fixture.cleanup() t.Fatalf("persist message failed: %v", err) } diff --git a/internal/chat/types.go b/internal/chat/types.go index 77cda9be..4cd9d66d 100644 --- a/internal/chat/types.go +++ b/internal/chat/types.go @@ -234,15 +234,16 @@ type ToolCallFunction struct { // ChatRequest is the input for Chat and StreamChat. type ChatRequest struct { - BotID string `json:"-"` - ChatID string `json:"-"` - Token string `json:"-"` - ChannelIdentityID string `json:"-"` - ContainerID string `json:"-"` - DisplayName string `json:"-"` - RouteID string `json:"-"` - ChatToken string `json:"-"` - ExternalMessageID string `json:"-"` + BotID string `json:"-"` + ChatID string `json:"-"` + Token string `json:"-"` + UserID string `json:"-"` + SourceChannelIdentityID string `json:"-"` + ContainerID string `json:"-"` + DisplayName string `json:"-"` + RouteID string `json:"-"` + ChatToken string `json:"-"` + ExternalMessageID string `json:"-"` Query string `json:"query"` Model string `json:"model,omitempty"` diff --git a/internal/handlers/bind.go b/internal/handlers/bind.go index 00fc9933..0c106e08 100644 --- a/internal/handlers/bind.go +++ b/internal/handlers/bind.go @@ -53,7 +53,7 @@ func (h *BindHandler) Issue(c echo.Context) error { if h.service == nil { return echo.NewHTTPError(http.StatusServiceUnavailable, "bind service not available") } - channelIdentityID, err := h.requireChannelIdentityID(c) + userID, err := h.requireUserID(c) if err != nil { return err } @@ -68,7 +68,7 @@ func (h *BindHandler) Issue(c echo.Context) error { ttl = time.Duration(req.TTLSeconds) * time.Second } - code, err := h.service.Issue(c.Request().Context(), channelIdentityID, strings.TrimSpace(req.Platform), ttl) + code, err := h.service.Issue(c.Request().Context(), userID, strings.TrimSpace(req.Platform), ttl) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } @@ -79,13 +79,13 @@ func (h *BindHandler) Issue(c echo.Context) error { }) } -func (h *BindHandler) requireChannelIdentityID(c echo.Context) (string, error) { - channelIdentityID, err := auth.ChannelIdentityIDFromContext(c) +func (h *BindHandler) requireUserID(c echo.Context) (string, error) { + userID, err := auth.UserIDFromContext(c) if err != nil { return "", err } - if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil { + if err := identity.ValidateChannelIdentityID(userID); err != nil { return "", echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - return channelIdentityID, nil + return userID, nil } diff --git a/internal/handlers/channel.go b/internal/handlers/channel.go index 74655dd1..9f8ec869 100644 --- a/internal/handlers/channel.go +++ b/internal/handlers/channel.go @@ -161,7 +161,7 @@ func (h *ChannelHandler) GetChannel(c echo.Context) error { } func (h *ChannelHandler) requireChannelIdentityID(c echo.Context) (string, error) { - channelIdentityID, err := auth.ChannelIdentityIDFromContext(c) + channelIdentityID, err := auth.UserIDFromContext(c) if err != nil { return "", err } diff --git a/internal/handlers/chat.go b/internal/handlers/chat.go index b8fc91aa..ec4b50da 100644 --- a/internal/handlers/chat.go +++ b/internal/handlers/chat.go @@ -189,7 +189,8 @@ func (h *ChatHandler) SendMessage(c echo.Context) error { req.BotID = chatObj.BotID req.ChatID = chatID req.Token = c.Request().Header.Get("Authorization") - req.ChannelIdentityID = channelIdentityID + req.UserID = channelIdentityID + req.SourceChannelIdentityID = channelIdentityID resp, err := h.resolver.Chat(c.Request().Context(), req) if err != nil { @@ -224,7 +225,8 @@ func (h *ChatHandler) StreamMessage(c echo.Context) error { req.BotID = chatObj.BotID req.ChatID = chatID req.Token = c.Request().Header.Get("Authorization") - req.ChannelIdentityID = channelIdentityID + req.UserID = channelIdentityID + req.SourceChannelIdentityID = channelIdentityID c.Response().Header().Set(echo.HeaderContentType, "text/event-stream") c.Response().Header().Set(echo.HeaderCacheControl, "no-cache") @@ -258,7 +260,10 @@ func (h *ChatHandler) StreamMessage(c echo.Context) error { if err != nil { h.logger.Error("chat stream failed", slog.Any("error", err)) errData := map[string]string{"error": err.Error()} - data, _ := json.Marshal(errData) + data, marshalErr := json.Marshal(errData) + if marshalErr != nil { + return echo.NewHTTPError(http.StatusInternalServerError, marshalErr.Error()) + } writer.WriteString(fmt.Sprintf("data: %s\n\n", string(data))) writer.Flush() flusher.Flush() @@ -500,7 +505,7 @@ func (h *ChatHandler) ListThreads(c echo.Context) error { // --- helpers --- func (h *ChatHandler) requireChannelIdentityID(c echo.Context) (string, error) { - channelIdentityID, err := auth.ChannelIdentityIDFromContext(c) + channelIdentityID, err := auth.UserIDFromContext(c) if err != nil { return "", err } @@ -558,7 +563,10 @@ func (h *ChatHandler) requireParticipant(ctx context.Context, chatID, channelIde } // Admin bypass. if h.accountService != nil { - isAdmin, _ := h.accountService.IsAdmin(ctx, channelIdentityID) + isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } if isAdmin { return nil } @@ -579,7 +587,10 @@ func (h *ChatHandler) requireReadable(ctx context.Context, chatID, channelIdenti } // Admin bypass. if h.accountService != nil { - isAdmin, _ := h.accountService.IsAdmin(ctx, channelIdentityID) + isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } if isAdmin { return nil } @@ -600,7 +611,10 @@ func (h *ChatHandler) requireRole(ctx context.Context, chatID, channelIdentityID } // Admin bypass. if h.accountService != nil { - isAdmin, _ := h.accountService.IsAdmin(ctx, channelIdentityID) + isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } if isAdmin { return nil } diff --git a/internal/handlers/containerd.go b/internal/handlers/containerd.go index 603efa13..6abdbdcb 100644 --- a/internal/handlers/containerd.go +++ b/internal/handlers/containerd.go @@ -627,7 +627,7 @@ func (h *ContainerdHandler) requireBotAccess(c echo.Context) (string, error) { } func (h *ContainerdHandler) requireChannelIdentityID(c echo.Context) (string, error) { - channelIdentityID, err := auth.ChannelIdentityIDFromContext(c) + channelIdentityID, err := auth.UserIDFromContext(c) if err != nil { return "", err } diff --git a/internal/handlers/embeddings.go b/internal/handlers/embeddings.go index ab3ea517..63b788e6 100644 --- a/internal/handlers/embeddings.go +++ b/internal/handlers/embeddings.go @@ -85,10 +85,10 @@ func (h *EmbeddingsHandler) Embed(c echo.Context) error { req.Input.ImageURL = strings.TrimSpace(req.Input.ImageURL) req.Input.VideoURL = strings.TrimSpace(req.Input.VideoURL) - channelIdentityID := "" + userID := "" if c.Get("user") != nil { - if value, err := auth.ChannelIdentityIDFromContext(c); err == nil { - channelIdentityID = value + if value, err := auth.UserIDFromContext(c); err == nil { + userID = value } } result, err := h.resolver.Embed(c.Request().Context(), embeddings.Request{ @@ -101,7 +101,7 @@ func (h *EmbeddingsHandler) Embed(c echo.Context) error { ImageURL: req.Input.ImageURL, VideoURL: req.Input.VideoURL, }, - ChannelIdentityID: channelIdentityID, + ChannelIdentityID: userID, }) if err != nil { message := err.Error() diff --git a/internal/handlers/local_channel.go b/internal/handlers/local_channel.go index d80a5ae9..c5c66958 100644 --- a/internal/handlers/local_channel.go +++ b/internal/handlers/local_channel.go @@ -228,7 +228,7 @@ func (h *LocalChannelHandler) ensureChatParticipant(ctx context.Context, chatID, } func (h *LocalChannelHandler) requireChannelIdentityID(c echo.Context) (string, error) { - channelIdentityID, err := auth.ChannelIdentityIDFromContext(c) + channelIdentityID, err := auth.UserIDFromContext(c) if err != nil { return "", err } diff --git a/internal/handlers/mcp.go b/internal/handlers/mcp.go index 0a512db2..784150d4 100644 --- a/internal/handlers/mcp.go +++ b/internal/handlers/mcp.go @@ -218,7 +218,7 @@ func (h *MCPHandler) Delete(c echo.Context) error { } func (h *MCPHandler) requireChannelIdentityID(c echo.Context) (string, error) { - userID, err := auth.ChannelIdentityIDFromContext(c) + userID, err := auth.UserIDFromContext(c) if err != nil { return "", err } diff --git a/internal/handlers/memory.go b/internal/handlers/memory.go index b63cbebb..fff1ee0d 100644 --- a/internal/handlers/memory.go +++ b/internal/handlers/memory.go @@ -364,7 +364,10 @@ func (h *MemoryHandler) requireChatParticipant(ctx context.Context, chatID, chan return echo.NewHTTPError(http.StatusInternalServerError, "chat service not configured") } if h.accountService != nil { - isAdmin, _ := h.accountService.IsAdmin(ctx, channelIdentityID) + isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } if isAdmin { return nil } @@ -380,7 +383,7 @@ func (h *MemoryHandler) requireChatParticipant(ctx context.Context, chatID, chan } func (h *MemoryHandler) requireChannelIdentityID(c echo.Context) (string, error) { - channelIdentityID, err := auth.ChannelIdentityIDFromContext(c) + channelIdentityID, err := auth.UserIDFromContext(c) if err != nil { return "", err } diff --git a/internal/handlers/models.go b/internal/handlers/models.go index 357845f1..0dbd8af4 100644 --- a/internal/handlers/models.go +++ b/internal/handlers/models.go @@ -164,7 +164,7 @@ func (h *ModelsHandler) Enable(c echo.Context) error { if h.settingsService == nil { return echo.NewHTTPError(http.StatusInternalServerError, "settings service not configured") } - userID, err := auth.ChannelIdentityIDFromContext(c) + userID, err := auth.UserIDFromContext(c) if err != nil { return err } diff --git a/internal/handlers/preauth.go b/internal/handlers/preauth.go index 5213859f..2f5ed413 100644 --- a/internal/handlers/preauth.go +++ b/internal/handlers/preauth.go @@ -40,7 +40,7 @@ type preauthIssueRequest struct { } func (h *PreauthHandler) Issue(c echo.Context) error { - channelIdentityID, err := h.requireChannelIdentityID(c) + userID, err := h.requireUserID(c) if err != nil { return err } @@ -48,7 +48,7 @@ func (h *PreauthHandler) Issue(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { return err } var req preauthIssueRequest @@ -59,33 +59,33 @@ func (h *PreauthHandler) Issue(c echo.Context) error { if req.TTLSeconds > 0 { ttl = time.Duration(req.TTLSeconds) * time.Second } - key, err := h.service.Issue(c.Request().Context(), botID, channelIdentityID, ttl) + key, err := h.service.Issue(c.Request().Context(), botID, userID, ttl) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } return c.JSON(http.StatusOK, key) } -func (h *PreauthHandler) requireChannelIdentityID(c echo.Context) (string, error) { - channelIdentityID, err := auth.ChannelIdentityIDFromContext(c) +func (h *PreauthHandler) requireUserID(c echo.Context) (string, error) { + userID, err := auth.UserIDFromContext(c) if err != nil { return "", err } - if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil { + if err := identity.ValidateChannelIdentityID(userID); err != nil { return "", echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - return channelIdentityID, nil + return userID, nil } -func (h *PreauthHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) { +func (h *PreauthHandler) authorizeBotAccess(ctx context.Context, userID, botID string) (bots.Bot, error) { if h.botService == nil || h.accountService == nil { return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured") } - isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID) + isAdmin, err := h.accountService.IsAdmin(ctx, userID) if err != nil { return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) + bot, err := h.botService.AuthorizeAccess(ctx, userID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) if err != nil { if errors.Is(err, bots.ErrBotNotFound) { return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found") diff --git a/internal/handlers/schedule.go b/internal/handlers/schedule.go index 4784346a..42643408 100644 --- a/internal/handlers/schedule.go +++ b/internal/handlers/schedule.go @@ -51,7 +51,7 @@ func (h *ScheduleHandler) Register(e *echo.Echo) { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/schedule [post] func (h *ScheduleHandler) Create(c echo.Context) error { - channelIdentityID, err := h.requireChannelIdentityID(c) + userID, err := h.requireUserID(c) if err != nil { return err } @@ -59,7 +59,7 @@ func (h *ScheduleHandler) Create(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { return err } var req schedule.CreateRequest @@ -82,7 +82,7 @@ func (h *ScheduleHandler) Create(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/schedule [get] func (h *ScheduleHandler) List(c echo.Context) error { - channelIdentityID, err := h.requireChannelIdentityID(c) + userID, err := h.requireUserID(c) if err != nil { return err } @@ -90,7 +90,7 @@ func (h *ScheduleHandler) List(c echo.Context) error { if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") } - if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { return err } items, err := h.service.List(c.Request().Context(), botID) @@ -111,7 +111,7 @@ func (h *ScheduleHandler) List(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/schedule/{id} [get] func (h *ScheduleHandler) Get(c echo.Context) error { - channelIdentityID, err := h.requireChannelIdentityID(c) + userID, err := h.requireUserID(c) if err != nil { return err } @@ -130,7 +130,7 @@ func (h *ScheduleHandler) Get(c echo.Context) error { if item.BotID != botID { return echo.NewHTTPError(http.StatusForbidden, "bot mismatch") } - if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { return err } return c.JSON(http.StatusOK, item) @@ -147,7 +147,7 @@ func (h *ScheduleHandler) Get(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/schedule/{id} [put] func (h *ScheduleHandler) Update(c echo.Context) error { - channelIdentityID, err := h.requireChannelIdentityID(c) + userID, err := h.requireUserID(c) if err != nil { return err } @@ -170,7 +170,7 @@ func (h *ScheduleHandler) Update(c echo.Context) error { if item.BotID != botID { return echo.NewHTTPError(http.StatusForbidden, "bot mismatch") } - if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { return err } resp, err := h.service.Update(c.Request().Context(), id, req) @@ -190,7 +190,7 @@ func (h *ScheduleHandler) Update(c echo.Context) error { // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/schedule/{id} [delete] func (h *ScheduleHandler) Delete(c echo.Context) error { - channelIdentityID, err := h.requireChannelIdentityID(c) + userID, err := h.requireUserID(c) if err != nil { return err } @@ -209,7 +209,7 @@ func (h *ScheduleHandler) Delete(c echo.Context) error { if item.BotID != botID { return echo.NewHTTPError(http.StatusForbidden, "bot mismatch") } - if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil { + if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil { return err } if err := h.service.Delete(c.Request().Context(), id); err != nil { @@ -218,26 +218,26 @@ func (h *ScheduleHandler) Delete(c echo.Context) error { return c.NoContent(http.StatusNoContent) } -func (h *ScheduleHandler) requireChannelIdentityID(c echo.Context) (string, error) { - channelIdentityID, err := auth.ChannelIdentityIDFromContext(c) +func (h *ScheduleHandler) requireUserID(c echo.Context) (string, error) { + userID, err := auth.UserIDFromContext(c) if err != nil { return "", err } - if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil { + if err := identity.ValidateChannelIdentityID(userID); err != nil { return "", echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - return channelIdentityID, nil + return userID, nil } -func (h *ScheduleHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) { +func (h *ScheduleHandler) authorizeBotAccess(ctx context.Context, userID, botID string) (bots.Bot, error) { if h.botService == nil || h.accountService == nil { return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured") } - isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID) + isAdmin, err := h.accountService.IsAdmin(ctx, userID) if err != nil { return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) + bot, err := h.botService.AuthorizeAccess(ctx, userID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false}) if err != nil { if errors.Is(err, bots.ErrBotNotFound) { return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found") diff --git a/internal/handlers/settings.go b/internal/handlers/settings.go index 78e7572c..7c3d57d0 100644 --- a/internal/handlers/settings.go +++ b/internal/handlers/settings.go @@ -130,7 +130,7 @@ func (h *SettingsHandler) Delete(c echo.Context) error { } func (h *SettingsHandler) requireChannelIdentityID(c echo.Context) (string, error) { - channelIdentityID, err := auth.ChannelIdentityIDFromContext(c) + channelIdentityID, err := auth.UserIDFromContext(c) if err != nil { return "", err } diff --git a/internal/handlers/subagent.go b/internal/handlers/subagent.go index 0d67ad95..40d86c15 100644 --- a/internal/handlers/subagent.go +++ b/internal/handlers/subagent.go @@ -433,7 +433,7 @@ func (h *SubagentHandler) AddSkills(c echo.Context) error { } func (h *SubagentHandler) requireChannelIdentityID(c echo.Context) (string, error) { - channelIdentityID, err := auth.ChannelIdentityIDFromContext(c) + channelIdentityID, err := auth.UserIDFromContext(c) if err != nil { return "", err } diff --git a/internal/handlers/users.go b/internal/handlers/users.go index 973931b7..688af741 100644 --- a/internal/handlers/users.go +++ b/internal/handlers/users.go @@ -907,7 +907,7 @@ func (h *UsersHandler) authorizeBotAccess(ctx context.Context, channelIdentityID } func (h *UsersHandler) requireChannelIdentityID(c echo.Context) (string, error) { - channelIdentityID, err := auth.ChannelIdentityIDFromContext(c) + channelIdentityID, err := auth.UserIDFromContext(c) if err != nil { return "", err } diff --git a/internal/router/channel.go b/internal/router/channel.go index e86474ed..b5c23e40 100644 --- a/internal/router/channel.go +++ b/internal/router/channel.go @@ -167,17 +167,18 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel desc, _ = p.registry.GetDescriptor(msg.Channel) } resp, err := p.chat.Chat(ctx, chat.ChatRequest{ - BotID: identity.BotID, - ChatID: resolved.ChatID, - Token: token, - ChannelIdentityID: identity.UserID, - DisplayName: identity.DisplayName, - RouteID: resolved.RouteID, - ChatToken: chatToken, - ExternalMessageID: strings.TrimSpace(msg.Message.ID), - Query: text, - CurrentChannel: msg.Channel.String(), - Channels: []string{msg.Channel.String()}, + BotID: identity.BotID, + ChatID: resolved.ChatID, + Token: token, + UserID: identity.UserID, + SourceChannelIdentityID: identity.ChannelIdentityID, + DisplayName: identity.DisplayName, + RouteID: resolved.RouteID, + ChatToken: chatToken, + ExternalMessageID: strings.TrimSpace(msg.Message.ID), + Query: text, + CurrentChannel: msg.Channel.String(), + Channels: []string{msg.Channel.String()}, }) if err != nil { if p.logger != nil { diff --git a/internal/router/channel_test.go b/internal/router/channel_test.go index 6d8c130b..0b7b666c 100644 --- a/internal/router/channel_test.go +++ b/internal/router/channel_test.go @@ -97,8 +97,11 @@ func TestChannelInboundProcessorWithIdentity(t *testing.T) { if gateway.gotReq.Query != "hello" { t.Errorf("expected query 'hello', got: %s", gateway.gotReq.Query) } - if gateway.gotReq.ChannelIdentityID != "channelIdentity-1" { - t.Errorf("expected channel_identity_id 'channelIdentity-1', got: %s", gateway.gotReq.ChannelIdentityID) + if gateway.gotReq.UserID != "channelIdentity-1" { + t.Errorf("expected user_id 'channelIdentity-1', got: %s", gateway.gotReq.UserID) + } + if gateway.gotReq.SourceChannelIdentityID != "channelIdentity-1" { + t.Errorf("expected source_channel_identity_id 'channelIdentity-1', got: %s", gateway.gotReq.SourceChannelIdentityID) } if gateway.gotReq.ChatID != "chat-1" { t.Errorf("expected chat_id 'chat-1', got: %s", gateway.gotReq.ChatID) diff --git a/internal/router/identity.go b/internal/router/identity.go index f9e3fcca..e4aea0b1 100644 --- a/internal/router/identity.go +++ b/internal/router/identity.go @@ -209,7 +209,9 @@ func (r *IdentityResolver) Resolve(ctx context.Context, cfg channel.ChannelConfi state.Decision = &decision return state, err } - if r.policy != nil && isGroupConversationType(msg.Conversation.Type) { + + // Personal bots are owner-only and must not depend on member/guest/preauth bypass. + if r.policy != nil { botType, err := r.policy.BotType(ctx, botID) if err != nil { return state, err @@ -219,20 +221,35 @@ func (r *IdentityResolver) Resolve(ctx context.Context, cfg channel.ChannelConfi if err != nil { return state, err } - if strings.TrimSpace(state.Identity.UserID) == "" || strings.TrimSpace(ownerUserID) != strings.TrimSpace(state.Identity.UserID) { - // Personal bots in group chats only answer owner messages. - state.Decision = &IdentityDecision{Stop: true} + isOwner := strings.TrimSpace(state.Identity.UserID) != "" && + strings.TrimSpace(ownerUserID) == strings.TrimSpace(state.Identity.UserID) + if !isOwner { + if isGroupConversationType(msg.Conversation.Type) { + // Ignore non-owner group messages for personal bots. + state.Decision = &IdentityDecision{Stop: true} + return state, nil + } + state.Decision = &IdentityDecision{ + Stop: true, + Reply: channel.Message{Text: r.unboundReply}, + } return state, nil } - // Owner can chat normally in group for personal bots. - state.Identity.ForceReply = true + if isGroupConversationType(msg.Conversation.Type) { + // Owner can chat in group for personal bots. + state.Identity.ForceReply = true + } + return state, nil } } // Phase 2: Authorization (bot membership check). if r.members != nil { if strings.TrimSpace(state.Identity.UserID) != "" { - isMember, _ := r.members.IsMember(ctx, botID, state.Identity.UserID) + isMember, err := r.members.IsMember(ctx, botID, state.Identity.UserID) + if err != nil { + return state, fmt.Errorf("check bot membership: %w", err) + } if isMember { return state, nil } @@ -256,9 +273,6 @@ func (r *IdentityResolver) Resolve(ctx context.Context, cfg channel.ChannelConfi return state, err } if allowed { - if r.members != nil && strings.TrimSpace(state.Identity.UserID) != "" { - _ = r.members.UpsertMemberSimple(ctx, botID, state.Identity.UserID, "member") - } return state, nil } } @@ -312,9 +326,13 @@ func (r *IdentityResolver) tryHandlePreauthKey(ctx context.Context, msg channel. return true, reply("Current channel account is not linked to a user."), nil } if r.members != nil { - _ = r.members.UpsertMemberSimple(ctx, botID, userID, "member") + if err := r.members.UpsertMemberSimple(ctx, botID, userID, "member"); err != nil { + return true, IdentityDecision{}, fmt.Errorf("upsert preauth member: %w", err) + } + } + if _, err := r.preauth.MarkUsed(ctx, key.ID); err != nil { + return true, IdentityDecision{}, fmt.Errorf("mark preauth key used: %w", err) } - _, _ = r.preauth.MarkUsed(ctx, key.ID) return true, reply(r.preauthReply), nil } @@ -365,7 +383,11 @@ func (r *IdentityResolver) tryHandleBindCode(ctx context.Context, msg channel.In // Resolve linked user after binding. newUserID := code.IssuedByUserID if r.channelIdentities != nil { - if linkedUserID, err := r.channelIdentities.GetLinkedUserID(ctx, channelIdentityID); err == nil && strings.TrimSpace(linkedUserID) != "" { + linkedUserID, err := r.channelIdentities.GetLinkedUserID(ctx, channelIdentityID) + if err != nil { + return true, IdentityDecision{}, "", fmt.Errorf("resolve linked user after bind: %w", err) + } + if strings.TrimSpace(linkedUserID) != "" { newUserID = linkedUserID } } diff --git a/internal/router/identity_test.go b/internal/router/identity_test.go index 1e2b8d17..45549492 100644 --- a/internal/router/identity_test.go +++ b/internal/router/identity_test.go @@ -145,7 +145,7 @@ func (f *fakeBindService) Consume(ctx context.Context, code bind.Code, channelCh return f.consumeErr } -func TestIdentityResolverAllowGuestUpsertsMember(t *testing.T) { +func TestIdentityResolverAllowGuestWithoutMembershipSideEffect(t *testing.T) { channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-1"}} memberSvc := &fakeMemberService{isMember: false} policySvc := &fakePolicyService{allow: true, botType: "public"} @@ -165,8 +165,8 @@ func TestIdentityResolverAllowGuestUpsertsMember(t *testing.T) { if state.Identity.ChannelIdentityID != "channelIdentity-1" { t.Fatalf("expected channelIdentity-1, got: %s", state.Identity.ChannelIdentityID) } - if !memberSvc.upsertCalled { - t.Fatal("expected UpsertMemberSimple to be called") + if memberSvc.upsertCalled { + t.Fatal("guest allow should not upsert membership") } if state.Decision != nil { t.Fatal("expected no decision for allowed guest") @@ -376,6 +376,35 @@ func TestIdentityResolverPersonalBotAllowsOwnerDirectWithoutMembership(t *testin } } +func TestIdentityResolverPersonalBotRejectsNonOwnerDirectEvenIfMember(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: channelidentities.ChannelIdentity{ID: "channelIdentity-non-owner"}} + memberSvc := &fakeMemberService{isMember: true} + policySvc := &fakePolicyService{allow: true, botType: "personal", ownerUserID: "channelIdentity-owner"} + resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, memberSvc, policySvc, nil, nil, "Access denied.", "") + + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("feishu"), + Message: channel.Message{Text: "hello from non-owner"}, + Sender: channel.Identity{SubjectID: "ext-non-owner"}, + Conversation: channel.Conversation{ + ID: "p2p-2", + Type: "p2p", + }, + } + + state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if state.Decision == nil || !state.Decision.Stop { + t.Fatal("non-owner direct message should be rejected for personal bot") + } + if state.Decision.Reply.Text != "Access denied." { + t.Fatalf("unexpected reject message: %s", state.Decision.Reply.Text) + } +} + func TestIdentityResolverBindRunsBeforeMembershipCheck(t *testing.T) { shadowID := "channelIdentity-shadow" humanID := "channelIdentity-human" diff --git a/internal/schedule/service_test.go b/internal/schedule/service_test.go index aca684b6..15ec87fc 100644 --- a/internal/schedule/service_test.go +++ b/internal/schedule/service_test.go @@ -55,8 +55,8 @@ func TestGenerateTriggerToken(t *testing.T) { if sub, _ := claims["sub"].(string); sub != userID { t.Errorf("expected sub=%s, got=%s", userID, sub) } - if uid, _ := claims["channel_identity_id"].(string); uid != userID { - t.Errorf("expected channel_identity_id=%s, got=%s", userID, uid) + if uid, _ := claims["user_id"].(string); uid != userID { + t.Errorf("expected user_id=%s, got=%s", userID, uid) } exp, _ := claims["exp"].(float64) if exp == 0 {