mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
feat: add auth to memory api
This commit is contained in:
+88
-36
@@ -17,6 +17,42 @@ type MemoryHandler struct {
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
type memoryAddPayload struct {
|
||||
Message string `json:"message,omitempty"`
|
||||
Messages []memory.Message `json:"messages,omitempty"`
|
||||
AgentID string `json:"agent_id,omitempty"`
|
||||
RunID string `json:"run_id,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
Filters map[string]interface{} `json:"filters,omitempty"`
|
||||
Infer *bool `json:"infer,omitempty"`
|
||||
}
|
||||
|
||||
type memorySearchPayload struct {
|
||||
Query string `json:"query"`
|
||||
AgentID string `json:"agent_id,omitempty"`
|
||||
RunID string `json:"run_id,omitempty"`
|
||||
Limit int `json:"limit,omitempty"`
|
||||
Filters map[string]interface{} `json:"filters,omitempty"`
|
||||
Sources []string `json:"sources,omitempty"`
|
||||
}
|
||||
|
||||
type memoryEmbedUpsertPayload struct {
|
||||
Type string `json:"type"`
|
||||
Provider string `json:"provider,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Input memory.EmbedInput `json:"input"`
|
||||
Source string `json:"source,omitempty"`
|
||||
AgentID string `json:"agent_id,omitempty"`
|
||||
RunID string `json:"run_id,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
Filters map[string]interface{} `json:"filters,omitempty"`
|
||||
}
|
||||
|
||||
type memoryDeleteAllPayload struct {
|
||||
AgentID string `json:"agent_id,omitempty"`
|
||||
RunID string `json:"run_id,omitempty"`
|
||||
}
|
||||
|
||||
func NewMemoryHandler(log *slog.Logger, service *memory.Service) *MemoryHandler {
|
||||
return &MemoryHandler{
|
||||
service: service,
|
||||
@@ -45,9 +81,9 @@ func (h *MemoryHandler) checkService() error {
|
||||
|
||||
// EmbedUpsert godoc
|
||||
// @Summary Embed and upsert memory
|
||||
// @Description Embed text or multimodal input and upsert into memory store
|
||||
// @Description Embed text or multimodal input and upsert into memory store. Auth: Bearer JWT determines user_id (sub or user_id).
|
||||
// @Tags memory
|
||||
// @Param payload body memory.EmbedUpsertRequest true "Embed upsert request"
|
||||
// @Param payload body memoryEmbedUpsertPayload true "Embed upsert request"
|
||||
// @Success 200 {object} memory.EmbedUpsertResponse
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
@@ -62,14 +98,22 @@ func (h *MemoryHandler) EmbedUpsert(c echo.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
var req memory.EmbedUpsertRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
var payload memoryEmbedUpsertPayload
|
||||
if err := c.Bind(&payload); err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
if req.UserID != "" && req.UserID != userID {
|
||||
return echo.NewHTTPError(http.StatusForbidden, "user mismatch")
|
||||
req := memory.EmbedUpsertRequest{
|
||||
Type: payload.Type,
|
||||
Provider: payload.Provider,
|
||||
Model: payload.Model,
|
||||
Input: payload.Input,
|
||||
Source: payload.Source,
|
||||
UserID: userID,
|
||||
AgentID: payload.AgentID,
|
||||
RunID: payload.RunID,
|
||||
Metadata: payload.Metadata,
|
||||
Filters: payload.Filters,
|
||||
}
|
||||
req.UserID = userID
|
||||
|
||||
resp, err := h.service.EmbedUpsert(c.Request().Context(), req)
|
||||
if err != nil {
|
||||
@@ -80,9 +124,9 @@ func (h *MemoryHandler) EmbedUpsert(c echo.Context) error {
|
||||
|
||||
// Add godoc
|
||||
// @Summary Add memory
|
||||
// @Description Add memory for a user via memory
|
||||
// @Description Add memory for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id).
|
||||
// @Tags memory
|
||||
// @Param payload body memory.AddRequest true "Add request"
|
||||
// @Param payload body memoryAddPayload true "Add request"
|
||||
// @Success 200 {object} memory.SearchResponse
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
@@ -97,14 +141,20 @@ func (h *MemoryHandler) Add(c echo.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
var req memory.AddRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
var payload memoryAddPayload
|
||||
if err := c.Bind(&payload); err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
if req.UserID != "" && req.UserID != userID {
|
||||
return echo.NewHTTPError(http.StatusForbidden, "user mismatch")
|
||||
req := memory.AddRequest{
|
||||
Message: payload.Message,
|
||||
Messages: payload.Messages,
|
||||
UserID: userID,
|
||||
AgentID: payload.AgentID,
|
||||
RunID: payload.RunID,
|
||||
Metadata: payload.Metadata,
|
||||
Filters: payload.Filters,
|
||||
Infer: payload.Infer,
|
||||
}
|
||||
req.UserID = userID
|
||||
|
||||
resp, err := h.service.Add(c.Request().Context(), req)
|
||||
if err != nil {
|
||||
@@ -115,9 +165,9 @@ func (h *MemoryHandler) Add(c echo.Context) error {
|
||||
|
||||
// Search godoc
|
||||
// @Summary Search memories
|
||||
// @Description Search memories for a user via memory
|
||||
// @Description Search memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id).
|
||||
// @Tags memory
|
||||
// @Param payload body memory.SearchRequest true "Search request"
|
||||
// @Param payload body memorySearchPayload true "Search request"
|
||||
// @Success 200 {object} memory.SearchResponse
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
@@ -132,14 +182,19 @@ func (h *MemoryHandler) Search(c echo.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
var req memory.SearchRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
var payload memorySearchPayload
|
||||
if err := c.Bind(&payload); err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
if req.UserID != "" && req.UserID != userID {
|
||||
return echo.NewHTTPError(http.StatusForbidden, "user mismatch")
|
||||
req := memory.SearchRequest{
|
||||
Query: payload.Query,
|
||||
UserID: userID,
|
||||
AgentID: payload.AgentID,
|
||||
RunID: payload.RunID,
|
||||
Limit: payload.Limit,
|
||||
Filters: payload.Filters,
|
||||
Sources: payload.Sources,
|
||||
}
|
||||
req.UserID = userID
|
||||
|
||||
resp, err := h.service.Search(c.Request().Context(), req)
|
||||
if err != nil {
|
||||
@@ -150,7 +205,7 @@ func (h *MemoryHandler) Search(c echo.Context) error {
|
||||
|
||||
// Update godoc
|
||||
// @Summary Update memory
|
||||
// @Description Update a memory by ID via memory
|
||||
// @Description Update a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id).
|
||||
// @Tags memory
|
||||
// @Param payload body memory.UpdateRequest true "Update request"
|
||||
// @Success 200 {object} memory.MemoryItem
|
||||
@@ -190,7 +245,7 @@ func (h *MemoryHandler) Update(c echo.Context) error {
|
||||
|
||||
// Get godoc
|
||||
// @Summary Get memory
|
||||
// @Description Get a memory by ID via memory
|
||||
// @Description Get a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id).
|
||||
// @Tags memory
|
||||
// @Param memoryId path string true "Memory ID"
|
||||
// @Success 200 {object} memory.MemoryItem
|
||||
@@ -224,9 +279,8 @@ func (h *MemoryHandler) Get(c echo.Context) error {
|
||||
|
||||
// GetAll godoc
|
||||
// @Summary List memories
|
||||
// @Description List memories for a user via memory
|
||||
// @Description List memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id).
|
||||
// @Tags memory
|
||||
// @Param user_id query string false "User ID"
|
||||
// @Param agent_id query string false "Agent ID"
|
||||
// @Param run_id query string false "Run ID"
|
||||
// @Param limit query int false "Limit"
|
||||
@@ -244,9 +298,6 @@ func (h *MemoryHandler) GetAll(c echo.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if queryUserID := c.QueryParam("user_id"); queryUserID != "" && queryUserID != userID {
|
||||
return echo.NewHTTPError(http.StatusForbidden, "user mismatch")
|
||||
}
|
||||
req := memory.GetAllRequest{
|
||||
UserID: userID,
|
||||
AgentID: c.QueryParam("agent_id"),
|
||||
@@ -268,7 +319,7 @@ func (h *MemoryHandler) GetAll(c echo.Context) error {
|
||||
|
||||
// Delete godoc
|
||||
// @Summary Delete memory
|
||||
// @Description Delete a memory by ID via memory
|
||||
// @Description Delete a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id).
|
||||
// @Tags memory
|
||||
// @Param memoryId path string true "Memory ID"
|
||||
// @Success 200 {object} memory.DeleteResponse
|
||||
@@ -307,9 +358,9 @@ func (h *MemoryHandler) Delete(c echo.Context) error {
|
||||
|
||||
// DeleteAll godoc
|
||||
// @Summary Delete memories
|
||||
// @Description Delete all memories for a user via memory
|
||||
// @Description Delete all memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id).
|
||||
// @Tags memory
|
||||
// @Param payload body memory.DeleteAllRequest true "Delete all request"
|
||||
// @Param payload body memoryDeleteAllPayload true "Delete all request"
|
||||
// @Success 200 {object} memory.DeleteResponse
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
@@ -324,14 +375,15 @@ func (h *MemoryHandler) DeleteAll(c echo.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
var req memory.DeleteAllRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
var payload memoryDeleteAllPayload
|
||||
if err := c.Bind(&payload); err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
if req.UserID != "" && req.UserID != userID {
|
||||
return echo.NewHTTPError(http.StatusForbidden, "user mismatch")
|
||||
req := memory.DeleteAllRequest{
|
||||
UserID: userID,
|
||||
AgentID: payload.AgentID,
|
||||
RunID: payload.RunID,
|
||||
}
|
||||
req.UserID = userID
|
||||
|
||||
resp, err := h.service.DeleteAll(c.Request().Context(), req)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user