diff --git a/internal/handlers/memory.go b/internal/handlers/memory.go index 001cb59f..acccfa72 100644 --- a/internal/handlers/memory.go +++ b/internal/handlers/memory.go @@ -17,6 +17,42 @@ type MemoryHandler struct { logger *slog.Logger } +type memoryAddPayload struct { + Message string `json:"message,omitempty"` + Messages []memory.Message `json:"messages,omitempty"` + AgentID string `json:"agent_id,omitempty"` + RunID string `json:"run_id,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` + Filters map[string]interface{} `json:"filters,omitempty"` + Infer *bool `json:"infer,omitempty"` +} + +type memorySearchPayload struct { + Query string `json:"query"` + AgentID string `json:"agent_id,omitempty"` + RunID string `json:"run_id,omitempty"` + Limit int `json:"limit,omitempty"` + Filters map[string]interface{} `json:"filters,omitempty"` + Sources []string `json:"sources,omitempty"` +} + +type memoryEmbedUpsertPayload struct { + Type string `json:"type"` + Provider string `json:"provider,omitempty"` + Model string `json:"model,omitempty"` + Input memory.EmbedInput `json:"input"` + Source string `json:"source,omitempty"` + AgentID string `json:"agent_id,omitempty"` + RunID string `json:"run_id,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` + Filters map[string]interface{} `json:"filters,omitempty"` +} + +type memoryDeleteAllPayload struct { + AgentID string `json:"agent_id,omitempty"` + RunID string `json:"run_id,omitempty"` +} + func NewMemoryHandler(log *slog.Logger, service *memory.Service) *MemoryHandler { return &MemoryHandler{ service: service, @@ -45,9 +81,9 @@ func (h *MemoryHandler) checkService() error { // EmbedUpsert godoc // @Summary Embed and upsert memory -// @Description Embed text or multimodal input and upsert into memory store +// @Description Embed text or multimodal input and upsert into memory store. Auth: Bearer JWT determines user_id (sub or user_id). // @Tags memory -// @Param payload body memory.EmbedUpsertRequest true "Embed upsert request" +// @Param payload body memoryEmbedUpsertPayload true "Embed upsert request" // @Success 200 {object} memory.EmbedUpsertResponse // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse @@ -62,14 +98,22 @@ func (h *MemoryHandler) EmbedUpsert(c echo.Context) error { return err } - var req memory.EmbedUpsertRequest - if err := c.Bind(&req); err != nil { + var payload memoryEmbedUpsertPayload + if err := c.Bind(&payload); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - if req.UserID != "" && req.UserID != userID { - return echo.NewHTTPError(http.StatusForbidden, "user mismatch") + req := memory.EmbedUpsertRequest{ + Type: payload.Type, + Provider: payload.Provider, + Model: payload.Model, + Input: payload.Input, + Source: payload.Source, + UserID: userID, + AgentID: payload.AgentID, + RunID: payload.RunID, + Metadata: payload.Metadata, + Filters: payload.Filters, } - req.UserID = userID resp, err := h.service.EmbedUpsert(c.Request().Context(), req) if err != nil { @@ -80,9 +124,9 @@ func (h *MemoryHandler) EmbedUpsert(c echo.Context) error { // Add godoc // @Summary Add memory -// @Description Add memory for a user via memory +// @Description Add memory for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id). // @Tags memory -// @Param payload body memory.AddRequest true "Add request" +// @Param payload body memoryAddPayload true "Add request" // @Success 200 {object} memory.SearchResponse // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse @@ -97,14 +141,20 @@ func (h *MemoryHandler) Add(c echo.Context) error { return err } - var req memory.AddRequest - if err := c.Bind(&req); err != nil { + var payload memoryAddPayload + if err := c.Bind(&payload); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - if req.UserID != "" && req.UserID != userID { - return echo.NewHTTPError(http.StatusForbidden, "user mismatch") + req := memory.AddRequest{ + Message: payload.Message, + Messages: payload.Messages, + UserID: userID, + AgentID: payload.AgentID, + RunID: payload.RunID, + Metadata: payload.Metadata, + Filters: payload.Filters, + Infer: payload.Infer, } - req.UserID = userID resp, err := h.service.Add(c.Request().Context(), req) if err != nil { @@ -115,9 +165,9 @@ func (h *MemoryHandler) Add(c echo.Context) error { // Search godoc // @Summary Search memories -// @Description Search memories for a user via memory +// @Description Search memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id). // @Tags memory -// @Param payload body memory.SearchRequest true "Search request" +// @Param payload body memorySearchPayload true "Search request" // @Success 200 {object} memory.SearchResponse // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse @@ -132,14 +182,19 @@ func (h *MemoryHandler) Search(c echo.Context) error { return err } - var req memory.SearchRequest - if err := c.Bind(&req); err != nil { + var payload memorySearchPayload + if err := c.Bind(&payload); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - if req.UserID != "" && req.UserID != userID { - return echo.NewHTTPError(http.StatusForbidden, "user mismatch") + req := memory.SearchRequest{ + Query: payload.Query, + UserID: userID, + AgentID: payload.AgentID, + RunID: payload.RunID, + Limit: payload.Limit, + Filters: payload.Filters, + Sources: payload.Sources, } - req.UserID = userID resp, err := h.service.Search(c.Request().Context(), req) if err != nil { @@ -150,7 +205,7 @@ func (h *MemoryHandler) Search(c echo.Context) error { // Update godoc // @Summary Update memory -// @Description Update a memory by ID via memory +// @Description Update a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id). // @Tags memory // @Param payload body memory.UpdateRequest true "Update request" // @Success 200 {object} memory.MemoryItem @@ -190,7 +245,7 @@ func (h *MemoryHandler) Update(c echo.Context) error { // Get godoc // @Summary Get memory -// @Description Get a memory by ID via memory +// @Description Get a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id). // @Tags memory // @Param memoryId path string true "Memory ID" // @Success 200 {object} memory.MemoryItem @@ -224,9 +279,8 @@ func (h *MemoryHandler) Get(c echo.Context) error { // GetAll godoc // @Summary List memories -// @Description List memories for a user via memory +// @Description List memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id). // @Tags memory -// @Param user_id query string false "User ID" // @Param agent_id query string false "Agent ID" // @Param run_id query string false "Run ID" // @Param limit query int false "Limit" @@ -244,9 +298,6 @@ func (h *MemoryHandler) GetAll(c echo.Context) error { return err } - if queryUserID := c.QueryParam("user_id"); queryUserID != "" && queryUserID != userID { - return echo.NewHTTPError(http.StatusForbidden, "user mismatch") - } req := memory.GetAllRequest{ UserID: userID, AgentID: c.QueryParam("agent_id"), @@ -268,7 +319,7 @@ func (h *MemoryHandler) GetAll(c echo.Context) error { // Delete godoc // @Summary Delete memory -// @Description Delete a memory by ID via memory +// @Description Delete a memory by ID via memory. Auth: Bearer JWT determines user_id (sub or user_id). // @Tags memory // @Param memoryId path string true "Memory ID" // @Success 200 {object} memory.DeleteResponse @@ -307,9 +358,9 @@ func (h *MemoryHandler) Delete(c echo.Context) error { // DeleteAll godoc // @Summary Delete memories -// @Description Delete all memories for a user via memory +// @Description Delete all memories for a user via memory. Auth: Bearer JWT determines user_id (sub or user_id). // @Tags memory -// @Param payload body memory.DeleteAllRequest true "Delete all request" +// @Param payload body memoryDeleteAllPayload true "Delete all request" // @Success 200 {object} memory.DeleteResponse // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse @@ -324,14 +375,15 @@ func (h *MemoryHandler) DeleteAll(c echo.Context) error { return err } - var req memory.DeleteAllRequest - if err := c.Bind(&req); err != nil { + var payload memoryDeleteAllPayload + if err := c.Bind(&payload); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - if req.UserID != "" && req.UserID != userID { - return echo.NewHTTPError(http.StatusForbidden, "user mismatch") + req := memory.DeleteAllRequest{ + UserID: userID, + AgentID: payload.AgentID, + RunID: payload.RunID, } - req.UserID = userID resp, err := h.service.DeleteAll(c.Request().Context(), req) if err != nil {