refactor: remove bot type

This commit is contained in:
Acbox
2026-03-15 00:42:09 +08:00
parent 6741f3f3f1
commit ac8a935545
45 changed files with 163 additions and 453 deletions
+16 -17
View File
@@ -59,26 +59,25 @@ func makeBotRow(botID, ownerUserID pgtype.UUID) *fakeRow {
scanFunc: func(dest ...any) error {
*dest[0].(*pgtype.UUID) = botID
*dest[1].(*pgtype.UUID) = ownerUserID
*dest[2].(*string) = bots.BotTypePublic
*dest[3].(*pgtype.Text) = pgtype.Text{String: "bot", Valid: true}
*dest[4].(*pgtype.Text) = pgtype.Text{}
*dest[5].(*bool) = true
*dest[6].(*string) = bots.BotStatusReady
*dest[7].(*int32) = 30
*dest[8].(*int32) = 0
*dest[9].(*int32) = 50
*dest[10].(*string) = "auto"
*dest[11].(*bool) = false
*dest[12].(*string) = "medium"
*dest[2].(*pgtype.Text) = pgtype.Text{String: "bot", Valid: true}
*dest[3].(*pgtype.Text) = pgtype.Text{}
*dest[4].(*bool) = true
*dest[5].(*string) = bots.BotStatusReady
*dest[6].(*int32) = 30
*dest[7].(*int32) = 0
*dest[8].(*int32) = 50
*dest[9].(*string) = "auto"
*dest[10].(*bool) = false
*dest[11].(*string) = "medium"
*dest[12].(*pgtype.UUID) = pgtype.UUID{}
*dest[13].(*pgtype.UUID) = pgtype.UUID{}
*dest[14].(*pgtype.UUID) = pgtype.UUID{}
*dest[15].(*pgtype.UUID) = pgtype.UUID{}
*dest[16].(*bool) = false
*dest[17].(*int32) = 30
*dest[18].(*string) = ""
*dest[19].(*[]byte) = []byte(`{}`)
*dest[15].(*bool) = false
*dest[16].(*int32) = 30
*dest[17].(*string) = ""
*dest[18].(*[]byte) = []byte(`{}`)
*dest[19].(*pgtype.Timestamptz) = pgtype.Timestamptz{}
*dest[20].(*pgtype.Timestamptz) = pgtype.Timestamptz{}
*dest[21].(*pgtype.Timestamptz) = pgtype.Timestamptz{}
return nil
},
}
+7 -36
View File
@@ -36,11 +36,6 @@ var (
ErrOwnerUserNotFound = errors.New("owner user not found")
)
// AccessPolicy controls bot access behavior.
type AccessPolicy struct {
AllowGuest bool
}
// NewService creates a new bot service.
func NewService(log *slog.Logger, queries *sqlc.Queries) *Service {
if log == nil {
@@ -70,8 +65,8 @@ func (s *Service) AddRuntimeChecker(c RuntimeChecker) {
}
}
// AuthorizeAccess checks whether userID may access the given bot.
func (s *Service) AuthorizeAccess(ctx context.Context, userID, botID string, isAdmin bool, policy AccessPolicy) (Bot, error) {
// AuthorizeAccess checks whether userID may access the given bot (owner or admin only).
func (s *Service) AuthorizeAccess(ctx context.Context, userID, botID string, isAdmin bool) (Bot, error) {
if s.queries == nil {
return Bot{}, errors.New("bot queries not configured")
}
@@ -85,11 +80,6 @@ func (s *Service) AuthorizeAccess(ctx context.Context, userID, botID string, isA
if isAdmin || bot.OwnerUserID == userID {
return bot, nil
}
if bot.Type == BotTypePublic {
if policy.AllowGuest {
return bot, nil
}
}
return Bot{}, ErrBotAccessDenied
}
@@ -109,10 +99,6 @@ func (s *Service) Create(ctx context.Context, ownerUserID string, req CreateBotR
if err := s.ensureUserExists(ctx, ownerUUID); err != nil {
return Bot{}, err
}
normalizedType, err := normalizeBotType(req.Type)
if err != nil {
return Bot{}, err
}
displayName := strings.TrimSpace(req.DisplayName)
if displayName == "" {
displayName = "bot-" + uuid.NewString()
@@ -132,7 +118,6 @@ func (s *Service) Create(ctx context.Context, ownerUserID string, req CreateBotR
}
row, err := s.queries.CreateBot(ctx, sqlc.CreateBotParams{
OwnerUserID: ownerUUID,
Type: normalizedType,
DisplayName: pgtype.Text{String: displayName, Valid: displayName != ""},
AvatarUrl: pgtype.Text{String: avatarURL, Valid: avatarURL != ""},
IsActive: isActive,
@@ -431,33 +416,20 @@ func (s *Service) ensureUserExists(ctx context.Context, userID pgtype.UUID) erro
return nil
}
func normalizeBotType(raw string) (string, error) {
normalized := strings.ToLower(strings.TrimSpace(raw))
if normalized == "" {
return BotTypePersonal, nil
}
switch normalized {
case BotTypePersonal, BotTypePublic:
return normalized, nil
default:
return "", fmt.Errorf("invalid bot type: %s", raw)
}
}
func asSQLCBot(v any) sqlc.Bot {
switch r := v.(type) {
case sqlc.Bot:
return r
case sqlc.CreateBotRow:
return sqlc.Bot{ID: r.ID, OwnerUserID: r.OwnerUserID, Type: r.Type, DisplayName: r.DisplayName, AvatarUrl: r.AvatarUrl, IsActive: r.IsActive, Status: r.Status, MaxContextLoadTime: r.MaxContextLoadTime, MaxContextTokens: r.MaxContextTokens, MaxInboxItems: r.MaxInboxItems, Language: r.Language, ReasoningEnabled: r.ReasoningEnabled, ReasoningEffort: r.ReasoningEffort, ChatModelID: r.ChatModelID, SearchProviderID: r.SearchProviderID, MemoryProviderID: r.MemoryProviderID, HeartbeatEnabled: r.HeartbeatEnabled, HeartbeatInterval: r.HeartbeatInterval, HeartbeatPrompt: r.HeartbeatPrompt, Metadata: r.Metadata, CreatedAt: r.CreatedAt, UpdatedAt: r.UpdatedAt}
return sqlc.Bot{ID: r.ID, OwnerUserID: r.OwnerUserID, DisplayName: r.DisplayName, AvatarUrl: r.AvatarUrl, IsActive: r.IsActive, Status: r.Status, MaxContextLoadTime: r.MaxContextLoadTime, MaxContextTokens: r.MaxContextTokens, MaxInboxItems: r.MaxInboxItems, Language: r.Language, ReasoningEnabled: r.ReasoningEnabled, ReasoningEffort: r.ReasoningEffort, ChatModelID: r.ChatModelID, SearchProviderID: r.SearchProviderID, MemoryProviderID: r.MemoryProviderID, HeartbeatEnabled: r.HeartbeatEnabled, HeartbeatInterval: r.HeartbeatInterval, HeartbeatPrompt: r.HeartbeatPrompt, Metadata: r.Metadata, CreatedAt: r.CreatedAt, UpdatedAt: r.UpdatedAt}
case sqlc.GetBotByIDRow:
return sqlc.Bot{ID: r.ID, OwnerUserID: r.OwnerUserID, Type: r.Type, DisplayName: r.DisplayName, AvatarUrl: r.AvatarUrl, IsActive: r.IsActive, Status: r.Status, MaxContextLoadTime: r.MaxContextLoadTime, MaxContextTokens: r.MaxContextTokens, MaxInboxItems: r.MaxInboxItems, Language: r.Language, ReasoningEnabled: r.ReasoningEnabled, ReasoningEffort: r.ReasoningEffort, ChatModelID: r.ChatModelID, SearchProviderID: r.SearchProviderID, MemoryProviderID: r.MemoryProviderID, HeartbeatEnabled: r.HeartbeatEnabled, HeartbeatInterval: r.HeartbeatInterval, HeartbeatPrompt: r.HeartbeatPrompt, Metadata: r.Metadata, CreatedAt: r.CreatedAt, UpdatedAt: r.UpdatedAt}
return sqlc.Bot{ID: r.ID, OwnerUserID: r.OwnerUserID, DisplayName: r.DisplayName, AvatarUrl: r.AvatarUrl, IsActive: r.IsActive, Status: r.Status, MaxContextLoadTime: r.MaxContextLoadTime, MaxContextTokens: r.MaxContextTokens, MaxInboxItems: r.MaxInboxItems, Language: r.Language, ReasoningEnabled: r.ReasoningEnabled, ReasoningEffort: r.ReasoningEffort, ChatModelID: r.ChatModelID, SearchProviderID: r.SearchProviderID, MemoryProviderID: r.MemoryProviderID, HeartbeatEnabled: r.HeartbeatEnabled, HeartbeatInterval: r.HeartbeatInterval, HeartbeatPrompt: r.HeartbeatPrompt, Metadata: r.Metadata, CreatedAt: r.CreatedAt, UpdatedAt: r.UpdatedAt}
case sqlc.ListBotsByOwnerRow:
return sqlc.Bot{ID: r.ID, OwnerUserID: r.OwnerUserID, Type: r.Type, DisplayName: r.DisplayName, AvatarUrl: r.AvatarUrl, IsActive: r.IsActive, Status: r.Status, MaxContextLoadTime: r.MaxContextLoadTime, MaxContextTokens: r.MaxContextTokens, MaxInboxItems: r.MaxInboxItems, Language: r.Language, ReasoningEnabled: r.ReasoningEnabled, ReasoningEffort: r.ReasoningEffort, ChatModelID: r.ChatModelID, SearchProviderID: r.SearchProviderID, MemoryProviderID: r.MemoryProviderID, HeartbeatEnabled: r.HeartbeatEnabled, HeartbeatInterval: r.HeartbeatInterval, HeartbeatPrompt: r.HeartbeatPrompt, Metadata: r.Metadata, CreatedAt: r.CreatedAt, UpdatedAt: r.UpdatedAt}
return sqlc.Bot{ID: r.ID, OwnerUserID: r.OwnerUserID, DisplayName: r.DisplayName, AvatarUrl: r.AvatarUrl, IsActive: r.IsActive, Status: r.Status, MaxContextLoadTime: r.MaxContextLoadTime, MaxContextTokens: r.MaxContextTokens, MaxInboxItems: r.MaxInboxItems, Language: r.Language, ReasoningEnabled: r.ReasoningEnabled, ReasoningEffort: r.ReasoningEffort, ChatModelID: r.ChatModelID, SearchProviderID: r.SearchProviderID, MemoryProviderID: r.MemoryProviderID, HeartbeatEnabled: r.HeartbeatEnabled, HeartbeatInterval: r.HeartbeatInterval, HeartbeatPrompt: r.HeartbeatPrompt, Metadata: r.Metadata, CreatedAt: r.CreatedAt, UpdatedAt: r.UpdatedAt}
case sqlc.UpdateBotProfileRow:
return sqlc.Bot{ID: r.ID, OwnerUserID: r.OwnerUserID, Type: r.Type, DisplayName: r.DisplayName, AvatarUrl: r.AvatarUrl, IsActive: r.IsActive, Status: r.Status, MaxContextLoadTime: r.MaxContextLoadTime, MaxContextTokens: r.MaxContextTokens, MaxInboxItems: r.MaxInboxItems, Language: r.Language, ReasoningEnabled: r.ReasoningEnabled, ReasoningEffort: r.ReasoningEffort, ChatModelID: r.ChatModelID, SearchProviderID: r.SearchProviderID, MemoryProviderID: r.MemoryProviderID, HeartbeatEnabled: r.HeartbeatEnabled, HeartbeatInterval: r.HeartbeatInterval, HeartbeatPrompt: r.HeartbeatPrompt, Metadata: r.Metadata, CreatedAt: r.CreatedAt, UpdatedAt: r.UpdatedAt}
return sqlc.Bot{ID: r.ID, OwnerUserID: r.OwnerUserID, DisplayName: r.DisplayName, AvatarUrl: r.AvatarUrl, IsActive: r.IsActive, Status: r.Status, MaxContextLoadTime: r.MaxContextLoadTime, MaxContextTokens: r.MaxContextTokens, MaxInboxItems: r.MaxInboxItems, Language: r.Language, ReasoningEnabled: r.ReasoningEnabled, ReasoningEffort: r.ReasoningEffort, ChatModelID: r.ChatModelID, SearchProviderID: r.SearchProviderID, MemoryProviderID: r.MemoryProviderID, HeartbeatEnabled: r.HeartbeatEnabled, HeartbeatInterval: r.HeartbeatInterval, HeartbeatPrompt: r.HeartbeatPrompt, Metadata: r.Metadata, CreatedAt: r.CreatedAt, UpdatedAt: r.UpdatedAt}
case sqlc.UpdateBotOwnerRow:
return sqlc.Bot{ID: r.ID, OwnerUserID: r.OwnerUserID, Type: r.Type, DisplayName: r.DisplayName, AvatarUrl: r.AvatarUrl, IsActive: r.IsActive, Status: r.Status, MaxContextLoadTime: r.MaxContextLoadTime, MaxContextTokens: r.MaxContextTokens, MaxInboxItems: r.MaxInboxItems, Language: r.Language, ReasoningEnabled: r.ReasoningEnabled, ReasoningEffort: r.ReasoningEffort, ChatModelID: r.ChatModelID, SearchProviderID: r.SearchProviderID, MemoryProviderID: r.MemoryProviderID, HeartbeatEnabled: r.HeartbeatEnabled, HeartbeatInterval: r.HeartbeatInterval, HeartbeatPrompt: r.HeartbeatPrompt, Metadata: r.Metadata, CreatedAt: r.CreatedAt, UpdatedAt: r.UpdatedAt}
return sqlc.Bot{ID: r.ID, OwnerUserID: r.OwnerUserID, DisplayName: r.DisplayName, AvatarUrl: r.AvatarUrl, IsActive: r.IsActive, Status: r.Status, MaxContextLoadTime: r.MaxContextLoadTime, MaxContextTokens: r.MaxContextTokens, MaxInboxItems: r.MaxInboxItems, Language: r.Language, ReasoningEnabled: r.ReasoningEnabled, ReasoningEffort: r.ReasoningEffort, ChatModelID: r.ChatModelID, SearchProviderID: r.SearchProviderID, MemoryProviderID: r.MemoryProviderID, HeartbeatEnabled: r.HeartbeatEnabled, HeartbeatInterval: r.HeartbeatInterval, HeartbeatPrompt: r.HeartbeatPrompt, Metadata: r.Metadata, CreatedAt: r.CreatedAt, UpdatedAt: r.UpdatedAt}
default:
return sqlc.Bot{}
}
@@ -487,7 +459,6 @@ func toBot(row sqlc.Bot) (Bot, error) {
return Bot{
ID: row.ID.String(),
OwnerUserID: row.OwnerUserID.String(),
Type: row.Type,
DisplayName: displayName,
AvatarURL: avatarURL,
IsActive: row.IsActive,
+25 -51
View File
@@ -40,40 +40,38 @@ func (d *fakeDBTX) QueryRow(ctx context.Context, sql string, args ...any) pgx.Ro
return &fakeRow{scanFunc: func(_ ...any) error { return pgx.ErrNoRows }}
}
// makeBotRow creates a fakeRow that populates a sqlc.Bot via Scan.
// Column order: id, owner_user_id, type, display_name, avatar_url, is_active, status,
// makeBotRow creates a fakeRow that populates a sqlc.GetBotByIDRow via Scan.
// Column order: id, owner_user_id, display_name, avatar_url, is_active, status,
// max_context_load_time, max_context_tokens, max_inbox_items, language,
// reasoning_enabled, reasoning_effort, chat_model_id, search_provider_id, memory_provider_id,
// heartbeat_enabled, heartbeat_interval, heartbeat_prompt, metadata, created_at, updated_at.
func makeBotRow(botID, ownerUserID pgtype.UUID, botType string, allowGuest bool) *fakeRow {
func makeBotRow(botID, ownerUserID pgtype.UUID) *fakeRow {
return &fakeRow{
scanFunc: func(dest ...any) error {
if len(dest) < 22 {
if len(dest) < 21 {
return pgx.ErrNoRows
}
*dest[0].(*pgtype.UUID) = botID
*dest[1].(*pgtype.UUID) = ownerUserID
*dest[2].(*string) = botType
*dest[3].(*pgtype.Text) = pgtype.Text{String: "test-bot", Valid: true}
*dest[4].(*pgtype.Text) = pgtype.Text{}
*dest[5].(*bool) = true
*dest[6].(*string) = BotStatusReady
*dest[7].(*int32) = 30 // MaxContextLoadTime
*dest[8].(*int32) = 4096 // MaxContextTokens
*dest[9].(*int32) = 10 // MaxInboxItems
*dest[10].(*string) = "en"
_ = allowGuest
*dest[11].(*bool) = false // ReasoningEnabled
*dest[12].(*string) = "medium" // ReasoningEffort
*dest[13].(*pgtype.UUID) = pgtype.UUID{} // ChatModelID
*dest[14].(*pgtype.UUID) = pgtype.UUID{} // SearchProviderID
*dest[15].(*pgtype.UUID) = pgtype.UUID{} // MemoryProviderID
*dest[16].(*bool) = false // HeartbeatEnabled
*dest[17].(*int32) = 30 // HeartbeatInterval
*dest[18].(*string) = "" // HeartbeatPrompt
*dest[19].(*[]byte) = []byte(`{}`)
*dest[2].(*pgtype.Text) = pgtype.Text{String: "test-bot", Valid: true}
*dest[3].(*pgtype.Text) = pgtype.Text{}
*dest[4].(*bool) = true
*dest[5].(*string) = BotStatusReady
*dest[6].(*int32) = 30 // MaxContextLoadTime
*dest[7].(*int32) = 4096 // MaxContextTokens
*dest[8].(*int32) = 10 // MaxInboxItems
*dest[9].(*string) = "en"
*dest[10].(*bool) = false // ReasoningEnabled
*dest[11].(*string) = "medium" // ReasoningEffort
*dest[12].(*pgtype.UUID) = pgtype.UUID{} // ChatModelID
*dest[13].(*pgtype.UUID) = pgtype.UUID{} // SearchProviderID
*dest[14].(*pgtype.UUID) = pgtype.UUID{} // MemoryProviderID
*dest[15].(*bool) = false // HeartbeatEnabled
*dest[16].(*int32) = 30 // HeartbeatInterval
*dest[17].(*string) = "" // HeartbeatPrompt
*dest[18].(*[]byte) = []byte(`{}`)
*dest[19].(*pgtype.Timestamptz) = pgtype.Timestamptz{}
*dest[20].(*pgtype.Timestamptz) = pgtype.Timestamptz{}
*dest[21].(*pgtype.Timestamptz) = pgtype.Timestamptz{}
return nil
},
}
@@ -97,47 +95,23 @@ func TestAuthorizeAccess(t *testing.T) {
name string
userID string
isAdmin bool
policy AccessPolicy
botType string
allowGst bool
wantErr bool
wantErrIs error
}{
{
name: "owner always allowed",
userID: ownerID,
policy: AccessPolicy{},
botType: BotTypePublic,
wantErr: false,
},
{
name: "admin always allowed",
userID: strangerID,
isAdmin: true,
policy: AccessPolicy{},
botType: BotTypePublic,
wantErr: false,
},
{
name: "stranger denied without guest on public bot",
name: "stranger denied",
userID: strangerID,
policy: AccessPolicy{AllowGuest: false},
botType: BotTypePublic,
wantErr: true,
wantErrIs: ErrBotAccessDenied,
},
{
name: "stranger allowed when policy allows guest",
userID: strangerID,
policy: AccessPolicy{AllowGuest: true},
botType: BotTypePublic,
wantErr: false,
},
{
name: "guest not allowed on personal bot",
userID: strangerID,
policy: AccessPolicy{AllowGuest: true},
botType: BotTypePersonal,
wantErr: true,
wantErrIs: ErrBotAccessDenied,
},
@@ -148,12 +122,12 @@ func TestAuthorizeAccess(t *testing.T) {
db := &fakeDBTX{
queryRowFunc: func(_ context.Context, _ string, args ...any) pgx.Row {
_ = args
return makeBotRow(botUUID, ownerUUID, tt.botType, tt.allowGst)
return makeBotRow(botUUID, ownerUUID)
},
}
svc := NewService(nil, sqlc.New(db))
_, err := svc.AuthorizeAccess(context.Background(), tt.userID, botID, tt.isAdmin, tt.policy)
_, err := svc.AuthorizeAccess(context.Background(), tt.userID, botID, tt.isAdmin)
if tt.wantErr {
if err == nil {
t.Fatal("expected error, got nil")
-6
View File
@@ -9,7 +9,6 @@ import (
type Bot struct {
ID string `json:"id"`
OwnerUserID string `json:"owner_user_id"`
Type string `json:"type"`
DisplayName string `json:"display_name"`
AvatarURL string `json:"avatar_url,omitempty"`
IsActive bool `json:"is_active"`
@@ -35,7 +34,6 @@ type BotCheck struct {
// CreateBotRequest is the input for creating a bot.
type CreateBotRequest struct {
Type string `json:"type"`
DisplayName string `json:"display_name,omitempty"`
AvatarURL string `json:"avatar_url,omitempty"`
IsActive *bool `json:"is_active,omitempty"`
@@ -77,10 +75,6 @@ type RuntimeChecker interface {
ListChecks(ctx context.Context, botID string) []BotCheck
}
const (
BotTypePersonal = "personal"
BotTypePublic = "public"
)
const (
BotStatusCreating = "creating"
+25 -25
View File
@@ -337,7 +337,7 @@ func (f *fakeChatService) Persist(_ context.Context, input messagepkg.PersistInp
func TestChannelInboundProcessorWithIdentity(t *testing.T) {
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-1"}}
policySvc := &fakePolicyService{allow: false}
policySvc := &fakePolicyService{}
chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-1", RouteID: "route-1"}}
gateway := &fakeChatGateway{
resp: conversation.ChatResponse{
@@ -383,11 +383,14 @@ func TestChannelInboundProcessorWithIdentity(t *testing.T) {
}
}
func TestChannelInboundProcessorDenied(t *testing.T) {
func TestChannelInboundProcessorDeniedByACL(t *testing.T) {
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-2"}}
chatSvc := &fakeChatService{}
policySvc := &fakePolicyService{}
chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-denied", RouteID: "route-denied"}}
gateway := &fakeChatGateway{}
processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, chatSvc, gateway, channelIdentitySvc, nil, nil, "", 0)
processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, chatSvc, gateway, channelIdentitySvc, policySvc, nil, "", 0)
aclSvc := &fakeChatACL{allowed: false}
processor.SetACLService(aclSvc)
sender := &fakeReplySender{}
cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("feishu")}
@@ -407,9 +410,6 @@ func TestChannelInboundProcessorDenied(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(sender.sent) != 1 || !strings.Contains(sender.sent[0].Message.PlainText(), "denied") {
t.Fatalf("expected access denied reply, got: %+v", sender.sent)
}
if gateway.gotReq.Query != "" {
t.Error("denied user should not trigger chat call")
}
@@ -417,7 +417,7 @@ func TestChannelInboundProcessorDenied(t *testing.T) {
func TestChannelInboundProcessorACLGuestDeniedDowngradesToNotify(t *testing.T) {
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-acl-deny"}}
policySvc := &fakePolicyService{botType: "public"}
policySvc := &fakePolicyService{}
chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-acl", RouteID: "route-acl"}}
gateway := &fakeChatGateway{}
processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, chatSvc, gateway, channelIdentitySvc, policySvc, nil, "", 0)
@@ -462,7 +462,7 @@ func TestChannelInboundProcessorACLGuestDeniedDowngradesToNotify(t *testing.T) {
func TestChannelInboundProcessorACLReceivesThreadScope(t *testing.T) {
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-thread-scope"}}
policySvc := &fakePolicyService{botType: "public"}
policySvc := &fakePolicyService{}
chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-thread", RouteID: "route-thread"}}
gateway := &fakeChatGateway{}
processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, chatSvc, gateway, channelIdentitySvc, policySvc, nil, "", 0)
@@ -501,7 +501,7 @@ func TestChannelInboundProcessorACLReceivesThreadScope(t *testing.T) {
func TestChannelInboundProcessorIgnoreEmpty(t *testing.T) {
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-3"}}
policySvc := &fakePolicyService{allow: false}
policySvc := &fakePolicyService{}
chatSvc := &fakeChatService{}
gateway := &fakeChatGateway{}
processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, chatSvc, gateway, channelIdentitySvc, policySvc, nil, "", 0)
@@ -570,7 +570,7 @@ func TestBuildInboundQueryAttachmentFallbackWithContainerRefs(t *testing.T) {
func TestChannelInboundProcessorAttachmentOnlyUsesFallbackQuery(t *testing.T) {
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-fallback"}}
policySvc := &fakePolicyService{botType: "public"}
policySvc := &fakePolicyService{}
chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-fallback", RouteID: "route-fallback"}}
gateway := &fakeChatGateway{
resp: conversation.ChatResponse{
@@ -614,7 +614,7 @@ func TestChannelInboundProcessorAttachmentOnlyUsesFallbackQuery(t *testing.T) {
func TestChannelInboundProcessorSilentReply(t *testing.T) {
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-4"}}
policySvc := &fakePolicyService{botType: "public"}
policySvc := &fakePolicyService{}
chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-4", RouteID: "route-4"}}
gateway := &fakeChatGateway{
resp: conversation.ChatResponse{
@@ -650,7 +650,7 @@ func TestChannelInboundProcessorSilentReply(t *testing.T) {
func TestChannelInboundProcessorGroupPassiveSync(t *testing.T) {
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-5"}}
policySvc := &fakePolicyService{botType: "public"}
policySvc := &fakePolicyService{}
chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-5", RouteID: "route-5"}}
gateway := &fakeChatGateway{
resp: conversation.ChatResponse{
@@ -692,7 +692,7 @@ func TestChannelInboundProcessorGroupPassiveSync(t *testing.T) {
func TestChannelInboundProcessorGroupMentionTriggersReply(t *testing.T) {
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-6"}}
policySvc := &fakePolicyService{botType: "public"}
policySvc := &fakePolicyService{}
chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-6", RouteID: "route-6"}}
gateway := &fakeChatGateway{
resp: conversation.ChatResponse{
@@ -752,7 +752,7 @@ func (s *failingOpenStreamSender) OpenStream(_ context.Context, _ string, _ chan
func TestChannelInboundProcessorDoesNotPersistBeforeOpenStream(t *testing.T) {
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-openstream"}}
policySvc := &fakePolicyService{botType: "public"}
policySvc := &fakePolicyService{}
chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-openstream", RouteID: "route-openstream"}}
gateway := &fakeChatGateway{}
processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, chatSvc, gateway, channelIdentitySvc, policySvc, nil, "", 0)
@@ -785,7 +785,7 @@ func TestChannelInboundProcessorDoesNotPersistBeforeOpenStream(t *testing.T) {
func TestChannelInboundProcessorPersistsAttachmentAssetRefs(t *testing.T) {
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-asset"}}
policySvc := &fakePolicyService{botType: "public"}
policySvc := &fakePolicyService{}
chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-asset", RouteID: "route-asset"}}
gateway := &fakeChatGateway{
resp: conversation.ChatResponse{
@@ -838,7 +838,7 @@ func TestChannelInboundProcessorPersistsAttachmentAssetRefs(t *testing.T) {
func TestChannelInboundProcessorIngestsPlatformKeyWithResolver(t *testing.T) {
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-resolver"}}
policySvc := &fakePolicyService{botType: "public"}
policySvc := &fakePolicyService{}
chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-resolver", RouteID: "route-resolver"}}
gateway := &fakeChatGateway{
resp: conversation.ChatResponse{
@@ -895,7 +895,7 @@ func TestChannelInboundProcessorIngestsPlatformKeyWithResolver(t *testing.T) {
func TestChannelInboundProcessorIngestsBase64Attachment(t *testing.T) {
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-base64"}}
policySvc := &fakePolicyService{botType: "public"}
policySvc := &fakePolicyService{}
chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-base64", RouteID: "route-base64"}}
gateway := &fakeChatGateway{
resp: conversation.ChatResponse{
@@ -967,7 +967,7 @@ func TestChannelInboundProcessorIngestsBase64Attachment(t *testing.T) {
func TestChannelInboundProcessorIngestsQQFileAttachmentKeepsOriginalExtWhenMimeGeneric(t *testing.T) {
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-qq-file"}}
policySvc := &fakePolicyService{botType: "public"}
policySvc := &fakePolicyService{}
chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-qq-file", RouteID: "route-qq-file"}}
gateway := &fakeChatGateway{
resp: conversation.ChatResponse{
@@ -1032,7 +1032,7 @@ func TestChannelInboundProcessorIngestsQQFileAttachmentKeepsOriginalExtWhenMimeG
func TestChannelInboundProcessorPersonalGroupNonOwnerIgnored(t *testing.T) {
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-member"}}
policySvc := &fakePolicyService{allow: false, botType: "personal", ownerUserID: "channelIdentity-owner"}
policySvc := &fakePolicyService{ownerUserID: "channelIdentity-owner"}
chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-personal-1", RouteID: "route-personal-1"}}
gateway := &fakeChatGateway{
resp: conversation.ChatResponse{
@@ -1074,7 +1074,7 @@ func TestChannelInboundProcessorPersonalGroupNonOwnerIgnored(t *testing.T) {
func TestChannelInboundProcessorPersonalGroupOwnerWithoutMentionUsesPassivePersistence(t *testing.T) {
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-owner"}}
policySvc := &fakePolicyService{allow: false, botType: "personal", ownerUserID: "channelIdentity-owner"}
policySvc := &fakePolicyService{ownerUserID: "channelIdentity-owner"}
chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-personal-2", RouteID: "route-personal-2"}}
gateway := &fakeChatGateway{
resp: conversation.ChatResponse{
@@ -1121,7 +1121,7 @@ func TestChannelInboundProcessorProcessingStatusSuccessLifecycle(t *testing.T) {
registry := channel.NewRegistry()
registry.MustRegister(&fakeProcessingStatusAdapter{notifier: notifier})
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-1"}}
policySvc := &fakePolicyService{botType: "public"}
policySvc := &fakePolicyService{}
chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-1", RouteID: "route-1"}}
gateway := &fakeChatGateway{
resp: conversation.ChatResponse{
@@ -1178,7 +1178,7 @@ func TestChannelInboundProcessorProcessingStatusFailureLifecycle(t *testing.T) {
registry := channel.NewRegistry()
registry.MustRegister(&fakeProcessingStatusAdapter{notifier: notifier})
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-2"}}
policySvc := &fakePolicyService{botType: "public"}
policySvc := &fakePolicyService{}
chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-2", RouteID: "route-2"}}
chatErr := errors.New("chat gateway unavailable")
gateway := &fakeChatGateway{err: chatErr}
@@ -1223,7 +1223,7 @@ func TestChannelInboundProcessorProcessingStatusErrorsAreBestEffort(t *testing.T
registry := channel.NewRegistry()
registry.MustRegister(&fakeProcessingStatusAdapter{notifier: notifier})
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-3"}}
policySvc := &fakePolicyService{botType: "public"}
policySvc := &fakePolicyService{}
chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-3", RouteID: "route-3"}}
gateway := &fakeChatGateway{
resp: conversation.ChatResponse{
@@ -1270,7 +1270,7 @@ func TestChannelInboundProcessorProcessingFailedNotifyErrorDoesNotOverrideChatEr
registry := channel.NewRegistry()
registry.MustRegister(&fakeProcessingStatusAdapter{notifier: notifier})
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-4"}}
policySvc := &fakePolicyService{botType: "public"}
policySvc := &fakePolicyService{}
chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-4", RouteID: "route-4"}}
chatErr := errors.New("chat failed")
gateway := &fakeChatGateway{err: chatErr}
+2 -43
View File
@@ -28,7 +28,6 @@ type InboundIdentity struct {
UserID string
DisplayName string
AvatarURL string
BotType string
ForceReply bool
}
@@ -68,7 +67,6 @@ type ChannelIdentityService interface {
// PolicyService resolves access policy for a bot.
type PolicyService interface {
BotType(ctx context.Context, botID string) (string, error)
BotOwnerUserID(ctx context.Context, botID string) (string, error)
}
@@ -182,57 +180,18 @@ func (r *IdentityResolver) Resolve(ctx context.Context, cfg channel.ChannelConfi
return state, err
}
// Personal bots are owner-only and must not depend on member/guest/preauth bypass.
if r.policy != nil {
botType, err := r.policy.BotType(ctx, botID)
if err != nil {
return state, err
}
state.Identity.BotType = botType
if strings.EqualFold(strings.TrimSpace(botType), "personal") {
ownerUserID, err := r.policy.BotOwnerUserID(ctx, botID)
if err != nil {
return state, err
}
isOwner := strings.TrimSpace(state.Identity.UserID) != "" &&
strings.TrimSpace(ownerUserID) == strings.TrimSpace(state.Identity.UserID)
if !isOwner {
// Ignore all non-owner messages for personal bots.
state.Decision = &IdentityDecision{Stop: true}
return state, nil
}
// Owner is authorized, but group trigger policy is still decided by
// shouldTriggerAssistantResponse in channel routing.
return state, nil
}
}
// Owner bypass — owner messages always pass identity resolution.
if r.policy != nil && strings.TrimSpace(state.Identity.UserID) != "" {
ownerUserID, err := r.policy.BotOwnerUserID(ctx, botID)
if err != nil {
return state, err
}
// Bot owner should not depend on bot_members linkage.
if strings.TrimSpace(ownerUserID) == strings.TrimSpace(state.Identity.UserID) {
return state, nil
}
}
if strings.EqualFold(strings.TrimSpace(state.Identity.BotType), "public") {
return state, nil
}
// In group conversations, silently drop unauthorized messages to avoid spamming
// the channel with "access denied" replies (same behavior as personal bot non-owner).
if isGroupConversationType(msg.Conversation.Type) {
state.Decision = &IdentityDecision{Stop: true}
return state, nil
}
state.Decision = &IdentityDecision{
Stop: true,
Reply: channel.Message{Text: r.unboundReply},
}
// Non-owner messages pass identity resolution; downstream ACL decides allow/deny.
return state, nil
}
+24 -55
View File
@@ -4,7 +4,6 @@ import (
"context"
"errors"
"log/slog"
"strings"
"testing"
"time"
@@ -69,29 +68,10 @@ func (f *fakeChannelIdentityService) LinkChannelIdentityToUser(_ context.Context
}
type fakePolicyService struct {
allow bool
botType string
ownerUserID string
err error
}
func (f *fakePolicyService) AllowGuest(_ context.Context, _ string) (bool, error) {
if f.err != nil {
return false, f.err
}
return f.allow, nil
}
func (f *fakePolicyService) BotType(_ context.Context, _ string) (string, error) {
if f.err != nil {
return "", f.err
}
if strings.TrimSpace(f.botType) == "" {
return "public", nil
}
return f.botType, nil
}
func (f *fakePolicyService) BotOwnerUserID(_ context.Context, _ string) (string, error) {
if f.err != nil {
return "", f.err
@@ -163,7 +143,7 @@ func (f *fakeDirectoryAdapter) ResolveEntry(ctx context.Context, cfg channel.Cha
func TestIdentityResolverAllowGuestWithoutMembershipSideEffect(t *testing.T) {
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-1"}}
policySvc := &fakePolicyService{allow: true, botType: "public"}
policySvc := &fakePolicyService{}
resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, policySvc, nil, "")
msg := channel.InboundMessage{
@@ -207,7 +187,7 @@ func TestIdentityResolverResolveDisplayNameFromDirectory(t *testing.T) {
}
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-directory"}}
policySvc := &fakePolicyService{allow: false, botType: "public"}
policySvc := &fakePolicyService{}
resolver := NewIdentityResolver(slog.Default(), registry, channelIdentitySvc, policySvc, nil, "")
msg := channel.InboundMessage{
@@ -247,7 +227,7 @@ func TestIdentityResolverDirectoryLookupFailureDoesNotFallbackToOpenID(t *testin
}
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-directory-fail"}}
policySvc := &fakePolicyService{allow: false, botType: "public"}
policySvc := &fakePolicyService{}
resolver := NewIdentityResolver(slog.Default(), registry, channelIdentitySvc, policySvc, nil, "")
msg := channel.InboundMessage{
@@ -282,7 +262,7 @@ func TestIdentityResolverFeishuUsesOpenIDAsCanonicalSubject(t *testing.T) {
"u-userid": {ID: "channelIdentity-userid"},
},
}
policySvc := &fakePolicyService{allow: false, botType: "public"}
policySvc := &fakePolicyService{}
resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, policySvc, nil, "")
msg := channel.InboundMessage{
@@ -327,7 +307,7 @@ func TestIdentityResolverDirectoryAvatarURLPropagated(t *testing.T) {
}
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-avatar"}}
policySvc := &fakePolicyService{allow: false, botType: "public"}
policySvc := &fakePolicyService{}
resolver := NewIdentityResolver(slog.Default(), registry, channelIdentitySvc, policySvc, nil, "")
msg := channel.InboundMessage{
@@ -360,7 +340,7 @@ func TestIdentityResolverDirectoryAvatarURLPropagated(t *testing.T) {
func TestIdentityResolverExistingMemberPasses(t *testing.T) {
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-2"}}
policySvc := &fakePolicyService{allow: false, botType: "public"}
policySvc := &fakePolicyService{}
resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, policySvc, nil, "")
msg := channel.InboundMessage{
@@ -381,7 +361,7 @@ func TestIdentityResolverExistingMemberPasses(t *testing.T) {
func TestIdentityResolverPublicBotGuestPasses(t *testing.T) {
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-5"}}
policySvc := &fakePolicyService{allow: false, botType: "public"}
policySvc := &fakePolicyService{}
resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, policySvc, nil, "Access denied.")
msg := channel.InboundMessage{
@@ -400,9 +380,9 @@ func TestIdentityResolverPublicBotGuestPasses(t *testing.T) {
}
}
func TestIdentityResolverPersonalBotRejectsGroupMessages(t *testing.T) {
func TestIdentityResolverNonOwnerGroupMessagePassesToACL(t *testing.T) {
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-group"}}
policySvc := &fakePolicyService{allow: false, botType: "personal", ownerUserID: "channelIdentity-owner"}
policySvc := &fakePolicyService{ownerUserID: "channelIdentity-owner"}
resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, policySvc, nil, "")
msg := channel.InboundMessage{
@@ -420,20 +400,14 @@ func TestIdentityResolverPersonalBotRejectsGroupMessages(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if state.Decision == nil || !state.Decision.Stop {
t.Fatal("personal bot should reject group messages")
}
if channelIdentitySvc.calls != 1 {
t.Fatalf("expected channelIdentity resolution once before owner check, got %d", channelIdentitySvc.calls)
}
if !state.Decision.Reply.IsEmpty() {
t.Fatal("non-owner group message should be silently ignored")
if state.Decision != nil {
t.Fatal("non-owner group message should pass identity resolution (ACL decides later)")
}
}
func TestIdentityResolverPersonalBotAllowsOwnerInGroup(t *testing.T) {
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-owner"}}
policySvc := &fakePolicyService{allow: false, botType: "personal", ownerUserID: "channelIdentity-owner"}
policySvc := &fakePolicyService{ownerUserID: "channelIdentity-owner"}
resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, policySvc, nil, "")
msg := channel.InboundMessage{
@@ -461,7 +435,7 @@ func TestIdentityResolverPersonalBotAllowsOwnerInGroup(t *testing.T) {
func TestIdentityResolverPersonalBotAllowsOwnerDirectWithoutMembership(t *testing.T) {
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-owner-direct"}}
policySvc := &fakePolicyService{allow: false, botType: "personal", ownerUserID: "channelIdentity-owner-direct"}
policySvc := &fakePolicyService{ownerUserID: "channelIdentity-owner-direct"}
resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, policySvc, nil, "")
msg := channel.InboundMessage{
@@ -487,7 +461,7 @@ func TestIdentityResolverPersonalBotAllowsOwnerDirectWithoutMembership(t *testin
}
}
func TestIdentityResolverPersonalBotDoesNotFallbackToFeishuUserID(t *testing.T) {
func TestIdentityResolverFeishuUnlinkedOpenIDPassesToACL(t *testing.T) {
channelIdentitySvc := &fakeChannelIdentityService{
bySubject: map[string]identities.ChannelIdentity{
"ou-open-owner": {ID: "channelIdentity-open-owner"},
@@ -497,7 +471,7 @@ func TestIdentityResolverPersonalBotDoesNotFallbackToFeishuUserID(t *testing.T)
"channelIdentity-user-owner": "owner-user-1",
},
}
policySvc := &fakePolicyService{allow: false, botType: "personal", ownerUserID: "owner-user-1"}
policySvc := &fakePolicyService{ownerUserID: "owner-user-1"}
resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, policySvc, nil, "")
msg := channel.InboundMessage{
@@ -521,11 +495,9 @@ func TestIdentityResolverPersonalBotDoesNotFallbackToFeishuUserID(t *testing.T)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if state.Decision == nil || !state.Decision.Stop {
t.Fatal("personal bot should deny when only feishu user_id is linked")
}
if state.Identity.UserID != "" {
t.Fatalf("expected no linked owner user via user_id fallback, got: %s", state.Identity.UserID)
// Without linked user, non-owner messages pass identity resolution; ACL decides later.
if state.Decision != nil {
t.Fatal("unlinked user should pass identity resolution (ACL decides later)")
}
if state.Identity.ChannelIdentityID != "channelIdentity-open-owner" {
t.Fatalf("expected open_id identity, got: %s", state.Identity.ChannelIdentityID)
@@ -535,9 +507,9 @@ func TestIdentityResolverPersonalBotDoesNotFallbackToFeishuUserID(t *testing.T)
}
}
func TestIdentityResolverPersonalBotRejectsNonOwnerDirectEvenIfMember(t *testing.T) {
func TestIdentityResolverNonOwnerDirectPassesToACL(t *testing.T) {
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-non-owner"}}
policySvc := &fakePolicyService{allow: true, botType: "personal", ownerUserID: "channelIdentity-owner"}
policySvc := &fakePolicyService{ownerUserID: "channelIdentity-owner"}
resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, policySvc, nil, "Access denied.")
msg := channel.InboundMessage{
@@ -555,11 +527,8 @@ func TestIdentityResolverPersonalBotRejectsNonOwnerDirectEvenIfMember(t *testing
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if state.Decision == nil || !state.Decision.Stop {
t.Fatal("non-owner direct message should be rejected for personal bot")
}
if !state.Decision.Reply.IsEmpty() {
t.Fatal("non-owner direct message should be silently ignored")
if state.Decision != nil {
t.Fatal("non-owner direct message should pass identity resolution (ACL decides later)")
}
}
@@ -685,7 +654,7 @@ func TestIdentityResolverBindCodeNotScopedToCurrentBot(t *testing.T) {
func TestIdentityResolverPublicBotGroupGuestPasses(t *testing.T) {
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-group-denied"}}
policySvc := &fakePolicyService{allow: false, botType: "public"}
policySvc := &fakePolicyService{}
resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, policySvc, nil, "Access denied.")
msg := channel.InboundMessage{
@@ -710,7 +679,7 @@ func TestIdentityResolverPublicBotGroupGuestPasses(t *testing.T) {
func TestIdentityResolverPublicBotDirectGuestPasses(t *testing.T) {
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-direct-denied"}}
policySvc := &fakePolicyService{allow: false, botType: "public"}
policySvc := &fakePolicyService{}
resolver := NewIdentityResolver(slog.Default(), nil, channelIdentitySvc, policySvc, nil, "Access denied.")
msg := channel.InboundMessage{
+7 -19
View File
@@ -12,14 +12,13 @@ import (
)
const createBot = `-- name: CreateBot :one
INSERT INTO bots (owner_user_id, type, display_name, avatar_url, is_active, metadata, status)
VALUES ($1, $2, $3, $4, $5, $6, $7)
RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, status, max_context_load_time, max_context_tokens, max_inbox_items, language, reasoning_enabled, reasoning_effort, chat_model_id, search_provider_id, memory_provider_id, heartbeat_enabled, heartbeat_interval, heartbeat_prompt, metadata, created_at, updated_at
INSERT INTO bots (owner_user_id, display_name, avatar_url, is_active, metadata, status)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id, owner_user_id, display_name, avatar_url, is_active, status, max_context_load_time, max_context_tokens, max_inbox_items, language, reasoning_enabled, reasoning_effort, chat_model_id, search_provider_id, memory_provider_id, heartbeat_enabled, heartbeat_interval, heartbeat_prompt, metadata, created_at, updated_at
`
type CreateBotParams struct {
OwnerUserID pgtype.UUID `json:"owner_user_id"`
Type string `json:"type"`
DisplayName pgtype.Text `json:"display_name"`
AvatarUrl pgtype.Text `json:"avatar_url"`
IsActive bool `json:"is_active"`
@@ -30,7 +29,6 @@ type CreateBotParams struct {
type CreateBotRow struct {
ID pgtype.UUID `json:"id"`
OwnerUserID pgtype.UUID `json:"owner_user_id"`
Type string `json:"type"`
DisplayName pgtype.Text `json:"display_name"`
AvatarUrl pgtype.Text `json:"avatar_url"`
IsActive bool `json:"is_active"`
@@ -55,7 +53,6 @@ type CreateBotRow struct {
func (q *Queries) CreateBot(ctx context.Context, arg CreateBotParams) (CreateBotRow, error) {
row := q.db.QueryRow(ctx, createBot,
arg.OwnerUserID,
arg.Type,
arg.DisplayName,
arg.AvatarUrl,
arg.IsActive,
@@ -66,7 +63,6 @@ func (q *Queries) CreateBot(ctx context.Context, arg CreateBotParams) (CreateBot
err := row.Scan(
&i.ID,
&i.OwnerUserID,
&i.Type,
&i.DisplayName,
&i.AvatarUrl,
&i.IsActive,
@@ -100,7 +96,7 @@ func (q *Queries) DeleteBotByID(ctx context.Context, id pgtype.UUID) error {
}
const getBotByID = `-- name: GetBotByID :one
SELECT id, owner_user_id, type, display_name, avatar_url, is_active, status, max_context_load_time, max_context_tokens, max_inbox_items, language, reasoning_enabled, reasoning_effort, chat_model_id, search_provider_id, memory_provider_id, heartbeat_enabled, heartbeat_interval, heartbeat_prompt, metadata, created_at, updated_at
SELECT id, owner_user_id, display_name, avatar_url, is_active, status, max_context_load_time, max_context_tokens, max_inbox_items, language, reasoning_enabled, reasoning_effort, chat_model_id, search_provider_id, memory_provider_id, heartbeat_enabled, heartbeat_interval, heartbeat_prompt, metadata, created_at, updated_at
FROM bots
WHERE id = $1
`
@@ -108,7 +104,6 @@ WHERE id = $1
type GetBotByIDRow struct {
ID pgtype.UUID `json:"id"`
OwnerUserID pgtype.UUID `json:"owner_user_id"`
Type string `json:"type"`
DisplayName pgtype.Text `json:"display_name"`
AvatarUrl pgtype.Text `json:"avatar_url"`
IsActive bool `json:"is_active"`
@@ -136,7 +131,6 @@ func (q *Queries) GetBotByID(ctx context.Context, id pgtype.UUID) (GetBotByIDRow
err := row.Scan(
&i.ID,
&i.OwnerUserID,
&i.Type,
&i.DisplayName,
&i.AvatarUrl,
&i.IsActive,
@@ -161,7 +155,7 @@ func (q *Queries) GetBotByID(ctx context.Context, id pgtype.UUID) (GetBotByIDRow
}
const listBotsByOwner = `-- name: ListBotsByOwner :many
SELECT id, owner_user_id, type, display_name, avatar_url, is_active, status, max_context_load_time, max_context_tokens, max_inbox_items, language, reasoning_enabled, reasoning_effort, chat_model_id, search_provider_id, memory_provider_id, heartbeat_enabled, heartbeat_interval, heartbeat_prompt, metadata, created_at, updated_at
SELECT id, owner_user_id, display_name, avatar_url, is_active, status, max_context_load_time, max_context_tokens, max_inbox_items, language, reasoning_enabled, reasoning_effort, chat_model_id, search_provider_id, memory_provider_id, heartbeat_enabled, heartbeat_interval, heartbeat_prompt, metadata, created_at, updated_at
FROM bots
WHERE owner_user_id = $1
ORDER BY created_at DESC
@@ -170,7 +164,6 @@ ORDER BY created_at DESC
type ListBotsByOwnerRow struct {
ID pgtype.UUID `json:"id"`
OwnerUserID pgtype.UUID `json:"owner_user_id"`
Type string `json:"type"`
DisplayName pgtype.Text `json:"display_name"`
AvatarUrl pgtype.Text `json:"avatar_url"`
IsActive bool `json:"is_active"`
@@ -204,7 +197,6 @@ func (q *Queries) ListBotsByOwner(ctx context.Context, ownerUserID pgtype.UUID)
if err := rows.Scan(
&i.ID,
&i.OwnerUserID,
&i.Type,
&i.DisplayName,
&i.AvatarUrl,
&i.IsActive,
@@ -280,7 +272,7 @@ UPDATE bots
SET owner_user_id = $2,
updated_at = now()
WHERE id = $1
RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, status, max_context_load_time, max_context_tokens, max_inbox_items, language, reasoning_enabled, reasoning_effort, chat_model_id, search_provider_id, memory_provider_id, heartbeat_enabled, heartbeat_interval, heartbeat_prompt, metadata, created_at, updated_at
RETURNING id, owner_user_id, display_name, avatar_url, is_active, status, max_context_load_time, max_context_tokens, max_inbox_items, language, reasoning_enabled, reasoning_effort, chat_model_id, search_provider_id, memory_provider_id, heartbeat_enabled, heartbeat_interval, heartbeat_prompt, metadata, created_at, updated_at
`
type UpdateBotOwnerParams struct {
@@ -291,7 +283,6 @@ type UpdateBotOwnerParams struct {
type UpdateBotOwnerRow struct {
ID pgtype.UUID `json:"id"`
OwnerUserID pgtype.UUID `json:"owner_user_id"`
Type string `json:"type"`
DisplayName pgtype.Text `json:"display_name"`
AvatarUrl pgtype.Text `json:"avatar_url"`
IsActive bool `json:"is_active"`
@@ -319,7 +310,6 @@ func (q *Queries) UpdateBotOwner(ctx context.Context, arg UpdateBotOwnerParams)
err := row.Scan(
&i.ID,
&i.OwnerUserID,
&i.Type,
&i.DisplayName,
&i.AvatarUrl,
&i.IsActive,
@@ -351,7 +341,7 @@ SET display_name = $2,
metadata = $5,
updated_at = now()
WHERE id = $1
RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, status, max_context_load_time, max_context_tokens, max_inbox_items, language, reasoning_enabled, reasoning_effort, chat_model_id, search_provider_id, memory_provider_id, heartbeat_enabled, heartbeat_interval, heartbeat_prompt, metadata, created_at, updated_at
RETURNING id, owner_user_id, display_name, avatar_url, is_active, status, max_context_load_time, max_context_tokens, max_inbox_items, language, reasoning_enabled, reasoning_effort, chat_model_id, search_provider_id, memory_provider_id, heartbeat_enabled, heartbeat_interval, heartbeat_prompt, metadata, created_at, updated_at
`
type UpdateBotProfileParams struct {
@@ -365,7 +355,6 @@ type UpdateBotProfileParams struct {
type UpdateBotProfileRow struct {
ID pgtype.UUID `json:"id"`
OwnerUserID pgtype.UUID `json:"owner_user_id"`
Type string `json:"type"`
DisplayName pgtype.Text `json:"display_name"`
AvatarUrl pgtype.Text `json:"avatar_url"`
IsActive bool `json:"is_active"`
@@ -399,7 +388,6 @@ func (q *Queries) UpdateBotProfile(ctx context.Context, arg UpdateBotProfilePara
err := row.Scan(
&i.ID,
&i.OwnerUserID,
&i.Type,
&i.DisplayName,
&i.AvatarUrl,
&i.IsActive,
+8 -8
View File
@@ -15,7 +15,7 @@ const createChat = `-- name: CreateChat :one
SELECT
b.id AS id,
b.id AS bot_id,
(COALESCE(NULLIF($1::text, ''), CASE WHEN b.type = 'public' THEN 'group' ELSE 'direct' END))::text AS kind,
(COALESCE(NULLIF($1::text, ''), 'direct'))::text AS kind,
CASE WHEN $1 = 'thread' THEN $2::uuid ELSE NULL::uuid END AS parent_chat_id,
COALESCE(NULLIF($3::text, ''), b.display_name) AS title,
COALESCE($4::uuid, b.owner_user_id) AS created_by_user_id,
@@ -94,7 +94,7 @@ const getChatByID = `-- name: GetChatByID :one
SELECT
b.id AS id,
b.id AS bot_id,
CASE WHEN b.type = 'public' THEN 'group' ELSE 'direct' END AS kind,
'direct'::text AS kind,
NULL::uuid AS parent_chat_id,
b.display_name AS title,
b.owner_user_id AS created_by_user_id,
@@ -264,7 +264,7 @@ const listChatsByBotAndUser = `-- name: ListChatsByBotAndUser :many
SELECT
b.id AS id,
b.id AS bot_id,
CASE WHEN b.type = 'public' THEN 'group' ELSE 'direct' END AS kind,
'direct'::text AS kind,
NULL::uuid AS parent_chat_id,
b.display_name AS title,
b.owner_user_id AS created_by_user_id,
@@ -332,7 +332,7 @@ const listThreadsByParent = `-- name: ListThreadsByParent :many
SELECT
b.id AS id,
b.id AS bot_id,
CASE WHEN b.type = 'public' THEN 'group' ELSE 'direct' END AS kind,
'direct'::text AS kind,
NULL::uuid AS parent_chat_id,
b.display_name AS title,
b.owner_user_id AS created_by_user_id,
@@ -394,7 +394,7 @@ const listVisibleChatsByBotAndUser = `-- name: ListVisibleChatsByBotAndUser :man
SELECT
b.id AS id,
b.id AS bot_id,
CASE WHEN b.type = 'public' THEN 'group' ELSE 'direct' END AS kind,
'direct'::text AS kind,
NULL::uuid AS parent_chat_id,
b.display_name AS title,
b.owner_user_id AS created_by_user_id,
@@ -405,7 +405,7 @@ SELECT
'participant'::text AS access_mode,
(CASE
WHEN b.owner_user_id = $1 THEN 'owner'
ELSE COALESCE(bm.role, ''::text)
ELSE ''::text
END)::text AS participant_role,
NULL::timestamptz AS last_observed_at
FROM bots b
@@ -507,12 +507,12 @@ WITH updated AS (
SET display_name = $1,
updated_at = now()
WHERE bots.id = $2
RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, status, max_context_load_time, max_context_tokens, language, reasoning_enabled, reasoning_effort, max_inbox_items, chat_model_id, search_provider_id, memory_provider_id, heartbeat_enabled, heartbeat_interval, heartbeat_prompt, heartbeat_model_id, tts_model_id, browser_context_id, metadata, created_at, updated_at
RETURNING id, owner_user_id, display_name, avatar_url, is_active, status, max_context_load_time, max_context_tokens, language, reasoning_enabled, reasoning_effort, max_inbox_items, chat_model_id, search_provider_id, memory_provider_id, heartbeat_enabled, heartbeat_interval, heartbeat_prompt, heartbeat_model_id, tts_model_id, browser_context_id, metadata, created_at, updated_at
)
SELECT
updated.id AS id,
updated.id AS bot_id,
CASE WHEN updated.type = 'public' THEN 'group' ELSE 'direct' END AS kind,
'direct'::text AS kind,
NULL::uuid AS parent_chat_id,
updated.display_name AS title,
updated.owner_user_id AS created_by_user_id,
-1
View File
@@ -11,7 +11,6 @@ import (
type Bot struct {
ID pgtype.UUID `json:"id"`
OwnerUserID pgtype.UUID `json:"owner_user_id"`
Type string `json:"type"`
DisplayName pgtype.Text `json:"display_name"`
AvatarUrl pgtype.Text `json:"avatar_url"`
IsActive bool `json:"is_active"`
+1 -1
View File
@@ -304,7 +304,7 @@ func (h *ACLHandler) requireManageAccess(c echo.Context) (string, string, error)
if botID == "" {
return "", "", echo.NewHTTPError(http.StatusBadRequest, "bot id is required")
}
if _, err := AuthorizeBotAccess(c.Request().Context(), h.botService, h.accountService, actorID, botID, bots.AccessPolicy{}); err != nil {
if _, err := AuthorizeBotAccess(c.Request().Context(), h.botService, h.accountService, actorID, botID); err != nil {
return "", "", err
}
return botID, actorID, nil
+3 -4
View File
@@ -849,11 +849,11 @@ func (*ContainerdHandler) requireChannelIdentityID(c echo.Context) (string, erro
}
func (h *ContainerdHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{})
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID)
}
// requireBotAccessWithGuest is like requireBotAccess but also allows guest access
// for public bots when the caller explicitly opts into guest-compatible access.
// via ACL when the caller explicitly opts into guest-compatible access.
func (h *ContainerdHandler) requireBotAccessWithGuest(c echo.Context) (string, error) {
channelIdentityID, err := h.requireChannelIdentityID(c)
if err != nil {
@@ -863,8 +863,7 @@ func (h *ContainerdHandler) requireBotAccessWithGuest(c echo.Context) (string, e
if botID == "" {
return "", echo.NewHTTPError(http.StatusBadRequest, "bot id is required")
}
policy := bots.AccessPolicy{AllowGuest: true}
if _, err := AuthorizeBotAccess(c.Request().Context(), h.botService, h.accountService, channelIdentityID, botID, policy); err != nil {
if _, err := AuthorizeBotAccess(c.Request().Context(), h.botService, h.accountService, channelIdentityID, botID); err != nil {
return "", err
}
return botID, nil
+3 -3
View File
@@ -25,8 +25,8 @@ func RequireChannelIdentityID(c echo.Context) (string, error) {
return channelIdentityID, nil
}
// AuthorizeBotAccess validates that the given identity has access to the specified bot.
func AuthorizeBotAccess(ctx context.Context, botService *bots.Service, accountService *accounts.Service, channelIdentityID, botID string, policy bots.AccessPolicy) (bots.Bot, error) {
// AuthorizeBotAccess validates that the given identity has owner/admin access to the specified bot.
func AuthorizeBotAccess(ctx context.Context, botService *bots.Service, accountService *accounts.Service, channelIdentityID, botID string) (bots.Bot, error) {
if botService == nil || accountService == nil {
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured")
}
@@ -34,7 +34,7 @@ func AuthorizeBotAccess(ctx context.Context, botService *bots.Service, accountSe
if err != nil {
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
bot, err := botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, policy)
bot, err := botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin)
if err != nil {
if errors.Is(err, bots.ErrBotNotFound) {
return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found")
+1 -1
View File
@@ -115,5 +115,5 @@ func (*HeartbeatHandler) requireUserID(c echo.Context) (string, error) {
}
func (h *HeartbeatHandler) authorizeBotAccess(ctx context.Context, userID, botID string) (bots.Bot, error) {
return AuthorizeBotAccess(ctx, h.botService, h.accountService, userID, botID, bots.AccessPolicy{})
return AuthorizeBotAccess(ctx, h.botService, h.accountService, userID, botID)
}
+1 -1
View File
@@ -252,7 +252,7 @@ func (h *InboxHandler) Count(c echo.Context) error {
}
func (h *InboxHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{})
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID)
}
func parseIntOr(s string, fallback int) int {
+1 -1
View File
@@ -453,7 +453,7 @@ func (*LocalChannelHandler) requireChannelIdentityID(c echo.Context) (string, er
}
func (h *LocalChannelHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{AllowGuest: true})
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID)
}
// ---------------------------------------------------------------------------
+1 -1
View File
@@ -410,5 +410,5 @@ func (*MCPHandler) requireChannelIdentityID(c echo.Context) (string, error) {
}
func (h *MCPHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{})
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID)
}
+1 -1
View File
@@ -245,5 +245,5 @@ func (*MCPOAuthHandler) requireChannelIdentityID(c echo.Context) (string, error)
}
func (h *MCPOAuthHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{})
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID)
}
+1 -1
View File
@@ -664,7 +664,7 @@ func (h *MemoryHandler) requireBotAccess(c echo.Context) (string, error) {
if err != nil {
return "", err
}
if _, err := AuthorizeBotAccess(c.Request().Context(), h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{}); err != nil {
if _, err := AuthorizeBotAccess(c.Request().Context(), h.botService, h.accountService, channelIdentityID, botID); err != nil {
return "", err
}
return botID, nil
+2 -2
View File
@@ -354,11 +354,11 @@ func (*MessageHandler) requireChannelIdentityID(c echo.Context) (string, error)
}
func (h *MessageHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{AllowGuest: true})
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID)
}
func (h *MessageHandler) authorizeBotManage(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{})
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID)
}
func (h *MessageHandler) requireReadable(ctx context.Context, conversationID, channelIdentityID string) error {
+1 -1
View File
@@ -220,5 +220,5 @@ func (*ScheduleHandler) requireUserID(c echo.Context) (string, error) {
}
func (h *ScheduleHandler) authorizeBotAccess(ctx context.Context, userID, botID string) (bots.Bot, error) {
return AuthorizeBotAccess(ctx, h.botService, h.accountService, userID, botID, bots.AccessPolicy{})
return AuthorizeBotAccess(ctx, h.botService, h.accountService, userID, botID)
}
+1 -4
View File
@@ -96,9 +96,6 @@ func (h *SettingsHandler) Upsert(c echo.Context) error {
}
resp, err := h.service.UpsertBot(c.Request().Context(), botID, req)
if err != nil {
if errors.Is(err, settings.ErrPersonalBotGuestAccessUnsupported) {
return echo.NewHTTPError(http.StatusBadRequest, "personal bot does not support guest access")
}
if errors.Is(err, settings.ErrInvalidModelRef) {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
@@ -148,5 +145,5 @@ func (*SettingsHandler) requireChannelIdentityID(c echo.Context) (string, error)
}
func (h *SettingsHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{})
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID)
}
+1 -1
View File
@@ -434,5 +434,5 @@ func (*SubagentHandler) requireChannelIdentityID(c echo.Context) (string, error)
}
func (h *SubagentHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{})
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID)
}
+1 -1
View File
@@ -85,7 +85,7 @@ func (h *TokenUsageHandler) GetTokenUsage(c echo.Context) error {
if botID == "" {
return echo.NewHTTPError(http.StatusBadRequest, "bot id is required")
}
if _, err := AuthorizeBotAccess(c.Request().Context(), h.botService, h.accountService, userID, botID, bots.AccessPolicy{}); err != nil {
if _, err := AuthorizeBotAccess(c.Request().Context(), h.botService, h.accountService, userID, botID); err != nil {
return err
}
+1 -1
View File
@@ -938,7 +938,7 @@ func (h *UsersHandler) SendBotMessageSession(c echo.Context) error {
}
func (h *UsersHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
return AuthorizeBotAccess(ctx, h.botService, h.service, channelIdentityID, botID, bots.AccessPolicy{})
return AuthorizeBotAccess(ctx, h.botService, h.service, channelIdentityID, botID)
}
func (*UsersHandler) requireChannelIdentityID(c echo.Context) (string, error) {
+3 -18
View File
@@ -10,8 +10,7 @@ import (
)
type Decision struct {
BotID string
BotType string
BotID string
}
type Service struct {
@@ -38,24 +37,10 @@ func (s *Service) Resolve(ctx context.Context, botID string) (Decision, error) {
if botID == "" {
return Decision{}, errors.New("bot id is required")
}
bot, err := s.bots.Get(ctx, botID)
if err != nil {
if _, err := s.bots.Get(ctx, botID); err != nil {
return Decision{}, err
}
decision := Decision{
BotID: botID,
BotType: strings.TrimSpace(bot.Type),
}
return decision, nil
}
// BotType returns the normalized bot type. Implements router.PolicyService.
func (s *Service) BotType(ctx context.Context, botID string) (string, error) {
decision, err := s.Resolve(ctx, botID)
if err != nil {
return "", err
}
return decision.BotType, nil
return Decision{BotID: botID}, nil
}
// BotOwnerUserID returns bot owner's user id. Implements router.PolicyService.
+2 -10
View File
@@ -24,8 +24,7 @@ type Service struct {
}
var (
ErrPersonalBotGuestAccessUnsupported = errors.New("personal bots do not support guest access")
ErrModelIDAmbiguous = errors.New("model_id is ambiguous across providers")
ErrModelIDAmbiguous = errors.New("model_id is ambiguous across providers")
ErrInvalidModelRef = errors.New("invalid model reference")
)
@@ -67,8 +66,6 @@ func (s *Service) UpsertBot(ctx context.Context, botID string, req UpsertRequest
if err != nil {
return Settings{}, err
}
isPersonalBot := strings.EqualFold(strings.TrimSpace(botRow.Type), "personal")
allowGuest, err := s.allowGuestEnabled(ctx, botID)
if err != nil {
return Settings{}, err
@@ -86,12 +83,7 @@ func (s *Service) UpsertBot(ctx context.Context, botID string, req UpsertRequest
if strings.TrimSpace(req.Language) != "" {
current.Language = strings.TrimSpace(req.Language)
}
if isPersonalBot {
if req.AllowGuest != nil && *req.AllowGuest {
return Settings{}, ErrPersonalBotGuestAccessUnsupported
}
current.AllowGuest = false
} else if req.AllowGuest != nil {
if req.AllowGuest != nil {
current.AllowGuest = *req.AllowGuest
}
if req.ReasoningEnabled != nil {