mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
feat: add get models by provider
This commit is contained in:
@@ -3,21 +3,25 @@ package handlers
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
"github.com/memohai/memoh/internal/providers"
|
||||
)
|
||||
|
||||
type ProvidersHandler struct {
|
||||
service *providers.Service
|
||||
logger *slog.Logger
|
||||
service *providers.Service
|
||||
modelsService *models.Service
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func NewProvidersHandler(log *slog.Logger, service *providers.Service) *ProvidersHandler {
|
||||
func NewProvidersHandler(log *slog.Logger, service *providers.Service, modelsService *models.Service) *ProvidersHandler {
|
||||
return &ProvidersHandler{
|
||||
service: service,
|
||||
logger: log.With(slog.String("handler", "providers")),
|
||||
service: service,
|
||||
modelsService: modelsService,
|
||||
logger: log.With(slog.String("handler", "providers")),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,6 +30,7 @@ func (h *ProvidersHandler) Register(e *echo.Echo) {
|
||||
group.POST("", h.Create)
|
||||
group.GET("", h.List)
|
||||
group.GET("/:id", h.Get)
|
||||
group.GET("/:id/models", h.ListModelsByProvider)
|
||||
group.GET("/name/:name", h.GetByName)
|
||||
group.PUT("/:id", h.Update)
|
||||
group.DELETE("/:id", h.Delete)
|
||||
@@ -124,6 +129,44 @@ func (h *ProvidersHandler) Get(c echo.Context) error {
|
||||
return c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// ListModelsByProvider godoc
|
||||
// @Summary List provider models
|
||||
// @Description Get models for a provider by id, optionally filtered by type
|
||||
// @Tags providers
|
||||
// @Param id path string true "Provider ID (UUID)"
|
||||
// @Param type query string false "Model type (chat, embedding)"
|
||||
// @Success 200 {array} models.GetResponse
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 404 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /providers/{id}/models [get]
|
||||
func (h *ProvidersHandler) ListModelsByProvider(c echo.Context) error {
|
||||
if h.modelsService == nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "models service not configured")
|
||||
}
|
||||
id := c.Param("id")
|
||||
if strings.TrimSpace(id) == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
|
||||
}
|
||||
modelType := strings.TrimSpace(c.QueryParam("type"))
|
||||
var (
|
||||
resp []models.GetResponse
|
||||
err error
|
||||
)
|
||||
if modelType == "" {
|
||||
resp, err = h.modelsService.ListByProviderID(c.Request().Context(), id)
|
||||
} else {
|
||||
resp, err = h.modelsService.ListByProviderIDAndType(c.Request().Context(), id, models.ModelType(modelType))
|
||||
}
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "invalid") {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
return echo.NewHTTPError(http.StatusNotFound, err.Error())
|
||||
}
|
||||
return c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// GetByName godoc
|
||||
// @Summary Get provider by name
|
||||
// @Description Get a provider configuration by its name
|
||||
@@ -236,4 +279,3 @@ func (h *ProvidersHandler) Count(c echo.Context) error {
|
||||
|
||||
return c.JSON(http.StatusOK, providers.CountResponse{Count: count})
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user