mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
feat(access): add guest chat ACL (#235)
This commit is contained in:
@@ -0,0 +1,322 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/memohai/memoh/internal/accounts"
|
||||
"github.com/memohai/memoh/internal/acl"
|
||||
"github.com/memohai/memoh/internal/bots"
|
||||
"github.com/memohai/memoh/internal/channel/identities"
|
||||
identitypkg "github.com/memohai/memoh/internal/identity"
|
||||
)
|
||||
|
||||
type ACLHandler struct {
|
||||
service *acl.Service
|
||||
botService *bots.Service
|
||||
accountService *accounts.Service
|
||||
identityService *identities.Service
|
||||
}
|
||||
|
||||
func NewACLHandler(service *acl.Service, botService *bots.Service, accountService *accounts.Service, identityService *identities.Service) *ACLHandler {
|
||||
return &ACLHandler{
|
||||
service: service,
|
||||
botService: botService,
|
||||
accountService: accountService,
|
||||
identityService: identityService,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ACLHandler) Register(e *echo.Echo) {
|
||||
group := e.Group("/bots/:bot_id")
|
||||
group.GET("/whitelist", h.ListWhitelist)
|
||||
group.PUT("/whitelist", h.UpsertWhitelist)
|
||||
group.DELETE("/whitelist/:rule_id", h.DeleteWhitelist)
|
||||
group.GET("/blacklist", h.ListBlacklist)
|
||||
group.PUT("/blacklist", h.UpsertBlacklist)
|
||||
group.DELETE("/blacklist/:rule_id", h.DeleteBlacklist)
|
||||
group.GET("/access/users", h.SearchUsers)
|
||||
group.GET("/access/channel_identities", h.SearchChannelIdentities)
|
||||
group.GET("/access/channel_identities/:channel_identity_id/conversations", h.ListObservedConversationsByChannelIdentity)
|
||||
}
|
||||
|
||||
// ListWhitelist godoc
|
||||
// @Summary List bot whitelist
|
||||
// @Description List guest allow rules for chat trigger
|
||||
// @Tags bots
|
||||
// @Param bot_id path string true "Bot ID"
|
||||
// @Success 200 {object} acl.ListRulesResponse
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 403 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /bots/{bot_id}/whitelist [get].
|
||||
func (h *ACLHandler) ListWhitelist(c echo.Context) error {
|
||||
botID, _, err := h.requireManageAccess(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
items, err := h.service.ListWhitelist(c.Request().Context(), botID)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return c.JSON(http.StatusOK, acl.ListRulesResponse{Items: items})
|
||||
}
|
||||
|
||||
// UpsertWhitelist godoc
|
||||
// @Summary Upsert bot whitelist entry
|
||||
// @Description Add a guest allow rule for chat trigger
|
||||
// @Tags bots
|
||||
// @Param bot_id path string true "Bot ID"
|
||||
// @Param payload body acl.UpsertRuleRequest true "Whitelist payload"
|
||||
// @Success 200 {object} acl.Rule
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 403 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /bots/{bot_id}/whitelist [put].
|
||||
func (h *ACLHandler) UpsertWhitelist(c echo.Context) error {
|
||||
botID, actorID, err := h.requireManageAccess(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var req acl.UpsertRuleRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
item, err := h.service.AddWhitelistEntry(c.Request().Context(), botID, actorID, req)
|
||||
if err != nil {
|
||||
if errors.Is(err, acl.ErrInvalidRuleSubject) || errors.Is(err, acl.ErrInvalidSourceScope) {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return c.JSON(http.StatusOK, item)
|
||||
}
|
||||
|
||||
// DeleteWhitelist godoc
|
||||
// @Summary Delete bot whitelist entry
|
||||
// @Description Delete a guest allow rule by rule ID
|
||||
// @Tags bots
|
||||
// @Param bot_id path string true "Bot ID"
|
||||
// @Param rule_id path string true "Rule ID"
|
||||
// @Success 204 "No Content"
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 403 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /bots/{bot_id}/whitelist/{rule_id} [delete].
|
||||
func (h *ACLHandler) DeleteWhitelist(c echo.Context) error {
|
||||
if _, _, err := h.requireManageAccess(c); err != nil {
|
||||
return err
|
||||
}
|
||||
ruleID := strings.TrimSpace(c.Param("rule_id"))
|
||||
if ruleID == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "rule id is required")
|
||||
}
|
||||
if err := h.service.DeleteRule(c.Request().Context(), ruleID); err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return c.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// ListBlacklist godoc
|
||||
// @Summary List bot blacklist
|
||||
// @Description List guest deny rules for chat trigger
|
||||
// @Tags bots
|
||||
// @Param bot_id path string true "Bot ID"
|
||||
// @Success 200 {object} acl.ListRulesResponse
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 403 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /bots/{bot_id}/blacklist [get].
|
||||
func (h *ACLHandler) ListBlacklist(c echo.Context) error {
|
||||
botID, _, err := h.requireManageAccess(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
items, err := h.service.ListBlacklist(c.Request().Context(), botID)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return c.JSON(http.StatusOK, acl.ListRulesResponse{Items: items})
|
||||
}
|
||||
|
||||
// UpsertBlacklist godoc
|
||||
// @Summary Upsert bot blacklist entry
|
||||
// @Description Add a guest deny rule for chat trigger
|
||||
// @Tags bots
|
||||
// @Param bot_id path string true "Bot ID"
|
||||
// @Param payload body acl.UpsertRuleRequest true "Blacklist payload"
|
||||
// @Success 200 {object} acl.Rule
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 403 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /bots/{bot_id}/blacklist [put].
|
||||
func (h *ACLHandler) UpsertBlacklist(c echo.Context) error {
|
||||
botID, actorID, err := h.requireManageAccess(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var req acl.UpsertRuleRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
item, err := h.service.AddBlacklistEntry(c.Request().Context(), botID, actorID, req)
|
||||
if err != nil {
|
||||
if errors.Is(err, acl.ErrInvalidRuleSubject) || errors.Is(err, acl.ErrInvalidSourceScope) {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return c.JSON(http.StatusOK, item)
|
||||
}
|
||||
|
||||
// DeleteBlacklist godoc
|
||||
// @Summary Delete bot blacklist entry
|
||||
// @Description Delete a guest deny rule by rule ID
|
||||
// @Tags bots
|
||||
// @Param bot_id path string true "Bot ID"
|
||||
// @Param rule_id path string true "Rule ID"
|
||||
// @Success 204 "No Content"
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 403 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /bots/{bot_id}/blacklist/{rule_id} [delete].
|
||||
func (h *ACLHandler) DeleteBlacklist(c echo.Context) error {
|
||||
if _, _, err := h.requireManageAccess(c); err != nil {
|
||||
return err
|
||||
}
|
||||
ruleID := strings.TrimSpace(c.Param("rule_id"))
|
||||
if ruleID == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "rule id is required")
|
||||
}
|
||||
if err := h.service.DeleteRule(c.Request().Context(), ruleID); err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return c.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// SearchUsers godoc
|
||||
// @Summary Search access users
|
||||
// @Description Search user candidates for bot access control
|
||||
// @Tags bots
|
||||
// @Param bot_id path string true "Bot ID"
|
||||
// @Param q query string false "Search query"
|
||||
// @Param limit query int false "Max results"
|
||||
// @Success 200 {object} acl.UserCandidateListResponse
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 403 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /bots/{bot_id}/access/users [get].
|
||||
func (h *ACLHandler) SearchUsers(c echo.Context) error {
|
||||
if _, _, err := h.requireManageAccess(c); err != nil {
|
||||
return err
|
||||
}
|
||||
items, err := h.accountService.SearchAccounts(c.Request().Context(), strings.TrimSpace(c.QueryParam("q")), parseLimit(c.QueryParam("limit")))
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
result := make([]acl.UserCandidate, 0, len(items))
|
||||
for _, item := range items {
|
||||
result = append(result, acl.UserCandidate{
|
||||
ID: item.ID,
|
||||
Username: item.Username,
|
||||
DisplayName: item.DisplayName,
|
||||
AvatarURL: item.AvatarURL,
|
||||
Email: item.Email,
|
||||
})
|
||||
}
|
||||
return c.JSON(http.StatusOK, acl.UserCandidateListResponse{Items: result})
|
||||
}
|
||||
|
||||
// SearchChannelIdentities godoc
|
||||
// @Summary Search access channel identities
|
||||
// @Description Search locally observed channel identity candidates for bot access control
|
||||
// @Tags bots
|
||||
// @Param bot_id path string true "Bot ID"
|
||||
// @Param q query string false "Search query"
|
||||
// @Param limit query int false "Max results"
|
||||
// @Success 200 {object} acl.ChannelIdentityCandidateListResponse
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 403 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /bots/{bot_id}/access/channel_identities [get].
|
||||
func (h *ACLHandler) SearchChannelIdentities(c echo.Context) error {
|
||||
if _, _, err := h.requireManageAccess(c); err != nil {
|
||||
return err
|
||||
}
|
||||
items, err := h.identityService.Search(c.Request().Context(), strings.TrimSpace(c.QueryParam("q")), parseLimit(c.QueryParam("limit")))
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
result := make([]acl.ChannelIdentityCandidate, 0, len(items))
|
||||
for _, item := range items {
|
||||
result = append(result, acl.ChannelIdentityCandidate{
|
||||
ID: item.ID,
|
||||
UserID: item.UserID,
|
||||
Channel: item.Channel,
|
||||
ChannelSubjectID: item.ChannelSubjectID,
|
||||
DisplayName: item.DisplayName,
|
||||
AvatarURL: item.AvatarURL,
|
||||
LinkedUsername: item.LinkedUsername,
|
||||
LinkedDisplayName: item.LinkedDisplayName,
|
||||
LinkedAvatarURL: item.LinkedAvatarURL,
|
||||
})
|
||||
}
|
||||
return c.JSON(http.StatusOK, acl.ChannelIdentityCandidateListResponse{Items: result})
|
||||
}
|
||||
|
||||
// ListObservedConversationsByChannelIdentity godoc
|
||||
// @Summary List observed conversations for a channel identity
|
||||
// @Description List previously observed conversation candidates for a channel identity under a bot
|
||||
// @Tags bots
|
||||
// @Param bot_id path string true "Bot ID"
|
||||
// @Param channel_identity_id path string true "Channel Identity ID"
|
||||
// @Success 200 {object} acl.ObservedConversationCandidateListResponse
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 403 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /bots/{bot_id}/access/channel_identities/{channel_identity_id}/conversations [get].
|
||||
func (h *ACLHandler) ListObservedConversationsByChannelIdentity(c echo.Context) error {
|
||||
botID, _, err := h.requireManageAccess(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
channelIdentityID := strings.TrimSpace(c.Param("channel_identity_id"))
|
||||
if err := identitypkg.ValidateChannelIdentityID(channelIdentityID); err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
items, err := h.service.ListObservedConversationsByChannelIdentity(c.Request().Context(), botID, channelIdentityID)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return c.JSON(http.StatusOK, acl.ObservedConversationCandidateListResponse{Items: items})
|
||||
}
|
||||
|
||||
func (h *ACLHandler) requireManageAccess(c echo.Context) (string, string, error) {
|
||||
actorID, err := RequireChannelIdentityID(c)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
botID := strings.TrimSpace(c.Param("bot_id"))
|
||||
if botID == "" {
|
||||
return "", "", echo.NewHTTPError(http.StatusBadRequest, "bot id is required")
|
||||
}
|
||||
if _, err := AuthorizeBotAccess(c.Request().Context(), h.botService, h.accountService, actorID, botID, bots.AccessPolicy{}); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return botID, actorID, nil
|
||||
}
|
||||
|
||||
func parseLimit(raw string) int {
|
||||
value, err := strconv.Atoi(strings.TrimSpace(raw))
|
||||
if err != nil || value <= 0 {
|
||||
return 50
|
||||
}
|
||||
if value > 200 {
|
||||
return 200
|
||||
}
|
||||
return value
|
||||
}
|
||||
@@ -849,11 +849,11 @@ func (*ContainerdHandler) requireChannelIdentityID(c echo.Context) (string, erro
|
||||
}
|
||||
|
||||
func (h *ContainerdHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
|
||||
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{AllowPublicMember: false})
|
||||
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{})
|
||||
}
|
||||
|
||||
// requireBotAccessWithGuest is like requireBotAccess but also allows guest access
|
||||
// for public bots that have the allow_guest setting enabled.
|
||||
// for public bots when the caller explicitly opts into guest-compatible access.
|
||||
func (h *ContainerdHandler) requireBotAccessWithGuest(c echo.Context) (string, error) {
|
||||
channelIdentityID, err := h.requireChannelIdentityID(c)
|
||||
if err != nil {
|
||||
@@ -863,7 +863,7 @@ func (h *ContainerdHandler) requireBotAccessWithGuest(c echo.Context) (string, e
|
||||
if botID == "" {
|
||||
return "", echo.NewHTTPError(http.StatusBadRequest, "bot id is required")
|
||||
}
|
||||
policy := bots.AccessPolicy{AllowPublicMember: true, AllowGuest: true}
|
||||
policy := bots.AccessPolicy{AllowGuest: true}
|
||||
if _, err := AuthorizeBotAccess(c.Request().Context(), h.botService, h.accountService, channelIdentityID, botID, policy); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -115,5 +115,5 @@ func (*HeartbeatHandler) requireUserID(c echo.Context) (string, error) {
|
||||
}
|
||||
|
||||
func (h *HeartbeatHandler) authorizeBotAccess(ctx context.Context, userID, botID string) (bots.Bot, error) {
|
||||
return AuthorizeBotAccess(ctx, h.botService, h.accountService, userID, botID, bots.AccessPolicy{AllowPublicMember: false})
|
||||
return AuthorizeBotAccess(ctx, h.botService, h.accountService, userID, botID, bots.AccessPolicy{})
|
||||
}
|
||||
|
||||
@@ -252,7 +252,7 @@ func (h *InboxHandler) Count(c echo.Context) error {
|
||||
}
|
||||
|
||||
func (h *InboxHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
|
||||
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{AllowPublicMember: false})
|
||||
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{})
|
||||
}
|
||||
|
||||
func parseIntOr(s string, fallback int) int {
|
||||
|
||||
@@ -226,7 +226,7 @@ func (h *LocalChannelHandler) PostMessage(c echo.Context) error {
|
||||
},
|
||||
Conversation: channel.Conversation{
|
||||
ID: routeKey,
|
||||
Type: "p2p",
|
||||
Type: channel.ConversationTypePrivate,
|
||||
},
|
||||
ReceivedAt: time.Now().UTC(),
|
||||
Source: "local",
|
||||
@@ -404,7 +404,7 @@ func (h *LocalChannelHandler) HandleWebSocket(c echo.Context) error {
|
||||
Token: bearerToken,
|
||||
UserID: channelIdentityID,
|
||||
SourceChannelIdentityID: channelIdentityID,
|
||||
ConversationType: "p2p",
|
||||
ConversationType: channel.ConversationTypePrivate,
|
||||
Query: text,
|
||||
CurrentChannel: h.channelType.String(),
|
||||
Channels: []string{h.channelType.String()},
|
||||
@@ -453,7 +453,7 @@ func (*LocalChannelHandler) requireChannelIdentityID(c echo.Context) (string, er
|
||||
}
|
||||
|
||||
func (h *LocalChannelHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
|
||||
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{AllowPublicMember: true})
|
||||
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{AllowGuest: true})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
@@ -410,5 +410,5 @@ func (*MCPHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
||||
}
|
||||
|
||||
func (h *MCPHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
|
||||
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{AllowPublicMember: false})
|
||||
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{})
|
||||
}
|
||||
|
||||
@@ -245,5 +245,5 @@ func (*MCPOAuthHandler) requireChannelIdentityID(c echo.Context) (string, error)
|
||||
}
|
||||
|
||||
func (h *MCPOAuthHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
|
||||
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{AllowPublicMember: false})
|
||||
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{})
|
||||
}
|
||||
|
||||
@@ -664,7 +664,7 @@ func (h *MemoryHandler) requireBotAccess(c echo.Context) (string, error) {
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if _, err := AuthorizeBotAccess(c.Request().Context(), h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{AllowPublicMember: false}); err != nil {
|
||||
if _, err := AuthorizeBotAccess(c.Request().Context(), h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{}); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return botID, nil
|
||||
|
||||
@@ -354,11 +354,11 @@ func (*MessageHandler) requireChannelIdentityID(c echo.Context) (string, error)
|
||||
}
|
||||
|
||||
func (h *MessageHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
|
||||
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{AllowPublicMember: true})
|
||||
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{AllowGuest: true})
|
||||
}
|
||||
|
||||
func (h *MessageHandler) authorizeBotManage(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
|
||||
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{AllowPublicMember: false})
|
||||
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{})
|
||||
}
|
||||
|
||||
func (h *MessageHandler) requireReadable(ctx context.Context, conversationID, channelIdentityID string) error {
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/memohai/memoh/internal/accounts"
|
||||
"github.com/memohai/memoh/internal/bots"
|
||||
"github.com/memohai/memoh/internal/preauth"
|
||||
)
|
||||
|
||||
type PreauthHandler struct {
|
||||
service *preauth.Service
|
||||
botService *bots.Service
|
||||
accountService *accounts.Service
|
||||
}
|
||||
|
||||
func NewPreauthHandler(service *preauth.Service, botService *bots.Service, accountService *accounts.Service) *PreauthHandler {
|
||||
return &PreauthHandler{
|
||||
service: service,
|
||||
botService: botService,
|
||||
accountService: accountService,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *PreauthHandler) Register(e *echo.Echo) {
|
||||
group := e.Group("/bots/:bot_id/preauth_keys")
|
||||
group.POST("", h.Issue)
|
||||
}
|
||||
|
||||
type preauthIssueRequest struct {
|
||||
TTLSeconds int `json:"ttl_seconds"`
|
||||
}
|
||||
|
||||
func (h *PreauthHandler) Issue(c echo.Context) error {
|
||||
userID, err := h.requireUserID(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
botID := strings.TrimSpace(c.Param("bot_id"))
|
||||
if botID == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "bot id is required")
|
||||
}
|
||||
if _, err := h.authorizeBotAccess(c.Request().Context(), userID, botID); err != nil {
|
||||
return err
|
||||
}
|
||||
var req preauthIssueRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
ttl := 24 * time.Hour
|
||||
if req.TTLSeconds > 0 {
|
||||
ttl = time.Duration(req.TTLSeconds) * time.Second
|
||||
}
|
||||
key, err := h.service.Issue(c.Request().Context(), botID, userID, ttl)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return c.JSON(http.StatusOK, key)
|
||||
}
|
||||
|
||||
func (*PreauthHandler) requireUserID(c echo.Context) (string, error) {
|
||||
return RequireChannelIdentityID(c)
|
||||
}
|
||||
|
||||
func (h *PreauthHandler) authorizeBotAccess(ctx context.Context, userID, botID string) (bots.Bot, error) {
|
||||
return AuthorizeBotAccess(ctx, h.botService, h.accountService, userID, botID, bots.AccessPolicy{AllowPublicMember: false})
|
||||
}
|
||||
@@ -220,5 +220,5 @@ func (*ScheduleHandler) requireUserID(c echo.Context) (string, error) {
|
||||
}
|
||||
|
||||
func (h *ScheduleHandler) authorizeBotAccess(ctx context.Context, userID, botID string) (bots.Bot, error) {
|
||||
return AuthorizeBotAccess(ctx, h.botService, h.accountService, userID, botID, bots.AccessPolicy{AllowPublicMember: false})
|
||||
return AuthorizeBotAccess(ctx, h.botService, h.accountService, userID, botID, bots.AccessPolicy{})
|
||||
}
|
||||
|
||||
@@ -148,5 +148,5 @@ func (*SettingsHandler) requireChannelIdentityID(c echo.Context) (string, error)
|
||||
}
|
||||
|
||||
func (h *SettingsHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
|
||||
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{AllowPublicMember: false})
|
||||
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{})
|
||||
}
|
||||
|
||||
@@ -434,5 +434,5 @@ func (*SubagentHandler) requireChannelIdentityID(c echo.Context) (string, error)
|
||||
}
|
||||
|
||||
func (h *SubagentHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
|
||||
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{AllowPublicMember: false})
|
||||
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{})
|
||||
}
|
||||
|
||||
@@ -85,7 +85,7 @@ func (h *TokenUsageHandler) GetTokenUsage(c echo.Context) error {
|
||||
if botID == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "bot id is required")
|
||||
}
|
||||
if _, err := AuthorizeBotAccess(c.Request().Context(), h.botService, h.accountService, userID, botID, bots.AccessPolicy{AllowPublicMember: false}); err != nil {
|
||||
if _, err := AuthorizeBotAccess(c.Request().Context(), h.botService, h.accountService, userID, botID, bots.AccessPolicy{}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
+1
-107
@@ -75,9 +75,6 @@ func (h *UsersHandler) Register(e *echo.Echo) {
|
||||
botGroup.PUT("/:id", h.UpdateBot)
|
||||
botGroup.PUT("/:id/owner", h.TransferBotOwner)
|
||||
botGroup.DELETE("/:id", h.DeleteBot)
|
||||
botGroup.GET("/:id/members", h.ListBotMembers)
|
||||
botGroup.PUT("/:id/members", h.UpsertBotMember)
|
||||
botGroup.DELETE("/:id/members/:user_id", h.DeleteBotMember)
|
||||
botGroup.GET("/:id/channel/:platform", h.GetBotChannelConfig)
|
||||
botGroup.PUT("/:id/channel/:platform", h.UpsertBotChannelConfig)
|
||||
botGroup.PATCH("/:id/channel/:platform/status", h.UpdateBotChannelStatus)
|
||||
@@ -662,109 +659,6 @@ func (h *UsersHandler) DeleteBot(c echo.Context) error {
|
||||
})
|
||||
}
|
||||
|
||||
// ListBotMembers godoc
|
||||
// @Summary List bot members
|
||||
// @Description List members for a bot
|
||||
// @Tags bots
|
||||
// @Param id path string true "Bot ID"
|
||||
// @Success 200 {object} bots.ListMembersResponse
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 403 {object} ErrorResponse
|
||||
// @Failure 404 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /bots/{id}/members [get].
|
||||
func (h *UsersHandler) ListBotMembers(c echo.Context) error {
|
||||
channelIdentityID, err := h.requireChannelIdentityID(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
botID := strings.TrimSpace(c.Param("id"))
|
||||
if botID == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "bot id is required")
|
||||
}
|
||||
if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil {
|
||||
return err
|
||||
}
|
||||
items, err := h.botService.ListMembers(c.Request().Context(), botID)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return c.JSON(http.StatusOK, bots.ListMembersResponse{Items: items})
|
||||
}
|
||||
|
||||
// UpsertBotMember godoc
|
||||
// @Summary Upsert bot member
|
||||
// @Description Add or update bot member role
|
||||
// @Tags bots
|
||||
// @Param id path string true "Bot ID"
|
||||
// @Param payload body bots.UpsertMemberRequest true "Member payload"
|
||||
// @Success 200 {object} bots.BotMember
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 403 {object} ErrorResponse
|
||||
// @Failure 404 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /bots/{id}/members [put].
|
||||
func (h *UsersHandler) UpsertBotMember(c echo.Context) error {
|
||||
channelIdentityID, err := h.requireChannelIdentityID(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
botID := strings.TrimSpace(c.Param("id"))
|
||||
if botID == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "bot id is required")
|
||||
}
|
||||
if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil {
|
||||
return err
|
||||
}
|
||||
var req bots.UpsertMemberRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
req.UserID = strings.TrimSpace(req.UserID)
|
||||
if req.UserID == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "user_id is required")
|
||||
}
|
||||
resp, err := h.botService.UpsertMember(c.Request().Context(), botID, req)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// DeleteBotMember godoc
|
||||
// @Summary Delete bot member
|
||||
// @Description Remove a member from a bot
|
||||
// @Tags bots
|
||||
// @Param id path string true "Bot ID"
|
||||
// @Param user_id path string true "User ID"
|
||||
// @Success 204 "No Content"
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 403 {object} ErrorResponse
|
||||
// @Failure 404 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /bots/{id}/members/{user_id} [delete].
|
||||
func (h *UsersHandler) DeleteBotMember(c echo.Context) error {
|
||||
channelIdentityID, err := h.requireChannelIdentityID(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
botID := strings.TrimSpace(c.Param("id"))
|
||||
if botID == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "bot id is required")
|
||||
}
|
||||
memberUserID := strings.TrimSpace(c.Param("user_id"))
|
||||
if memberUserID == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "user id is required")
|
||||
}
|
||||
if _, err := h.authorizeBotAccess(c.Request().Context(), channelIdentityID, botID); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := h.botService.DeleteMember(c.Request().Context(), botID, memberUserID); err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return c.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// GetBotChannelConfig godoc
|
||||
// @Summary Get bot channel config
|
||||
// @Description Get bot channel configuration
|
||||
@@ -1044,7 +938,7 @@ func (h *UsersHandler) SendBotMessageSession(c echo.Context) error {
|
||||
}
|
||||
|
||||
func (h *UsersHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
|
||||
return AuthorizeBotAccess(ctx, h.botService, h.service, channelIdentityID, botID, bots.AccessPolicy{AllowPublicMember: false})
|
||||
return AuthorizeBotAccess(ctx, h.botService, h.service, channelIdentityID, botID, bots.AccessPolicy{})
|
||||
}
|
||||
|
||||
func (*UsersHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
||||
|
||||
Reference in New Issue
Block a user