feat: add get models by provider

This commit is contained in:
Acbox
2026-02-02 13:48:22 +08:00
parent c10fbfee23
commit 3f0a0f8499
10 changed files with 330 additions and 38 deletions
+10 -9
View File
@@ -3,13 +3,14 @@ package main
import (
"context"
"fmt"
"log"
"log/slog"
"os"
"strings"
"time"
// "github.com/memohai/memoh/internal/channel"
"github.com/memohai/memoh/internal/chat"
"github.com/memohai/memoh/internal/channel"
"github.com/memohai/memoh/internal/config"
ctr "github.com/memohai/memoh/internal/containerd"
"github.com/memohai/memoh/internal/db"
@@ -176,18 +177,18 @@ func main() {
// Initialize providers and models handlers
providersService := providers.NewService(logger.L, queries)
providersHandler := handlers.NewProvidersHandler(logger.L, providersService)
providersHandler := handlers.NewProvidersHandler(logger.L, providersService, modelsService)
settingsService := settings.NewService(logger.L, queries)
settingsHandler := handlers.NewSettingsHandler(logger.L, settingsService)
modelsHandler := handlers.NewModelsHandler(logger.L, modelsService, settingsService)
historyService := history.NewService(logger.L, queries)
historyHandler := handlers.NewHistoryHandler(logger.L, historyService)
channelService := channel.NewService(queries)
channelManager := channel.NewManager(channelService, chatResolver)
channelManager.RegisterAdapter(channel.NewTelegramAdapter())
channelManager.RegisterAdapter(channel.NewFeishuAdapter())
channelManager.Start(ctx)
channelHandler := handlers.NewChannelHandler(channelService, channelManager)
// channelService := channel.NewService(queries)
// channelManager := channel.NewManager(channelService, chatResolver)
// channelManager.RegisterAdapter(channel.NewTelegramAdapter())
// channelManager.RegisterAdapter(channel.NewFeishuAdapter())
// channelManager.Start(ctx)
// channelHandler := handlers.NewChannelHandler(channelService, channelManager)
scheduleService := schedule.NewService(logger.L, queries, chatResolver, cfg.Auth.JWTSecret)
if err := scheduleService.Bootstrap(ctx); err != nil {
logger.Error("schedule bootstrap", slog.Any("error", err))
@@ -196,7 +197,7 @@ func main() {
scheduleHandler := handlers.NewScheduleHandler(logger.L, scheduleService)
subagentService := subagent.NewService(logger.L, queries)
subagentHandler := handlers.NewSubagentHandler(logger.L, subagentService)
srv := server.NewServer(logger.L, addr, cfg.Auth.JWTSecret, pingHandler, authHandler, memoryHandler, embeddingsHandler, chatHandler, swaggerHandler, providersHandler, modelsHandler, settingsHandler, historyHandler, scheduleHandler, subagentHandler, containerdHandler, channelHandler)
srv := server.NewServer(logger.L, addr, cfg.Auth.JWTSecret, pingHandler, authHandler, memoryHandler, embeddingsHandler, chatHandler, swaggerHandler, providersHandler, modelsHandler, settingsHandler, historyHandler, scheduleHandler, subagentHandler, containerdHandler, /*channelHandler*/)
if err := srv.Start(); err != nil {
logger.Error("server failed", slog.Any("error", err))
+11
View File
@@ -78,6 +78,17 @@ JOIN llm_providers AS p ON p.id = m.llm_provider_id
WHERE p.client_type = sqlc.arg(client_type)
ORDER BY m.created_at DESC;
-- name: ListModelsByProviderID :many
SELECT * FROM models
WHERE llm_provider_id = sqlc.arg(llm_provider_id)
ORDER BY created_at DESC;
-- name: ListModelsByProviderIDAndType :many
SELECT * FROM models
WHERE llm_provider_id = sqlc.arg(llm_provider_id)
AND type = sqlc.arg(type)
ORDER BY created_at DESC;
-- name: UpdateModel :one
UPDATE models
SET
+53
View File
@@ -1828,6 +1828,59 @@ const docTemplate = `{
}
}
},
"/providers/{id}/models": {
"get": {
"description": "Get models for a provider by id, optionally filtered by type",
"tags": [
"providers"
],
"summary": "List provider models",
"parameters": [
{
"type": "string",
"description": "Provider ID (UUID)",
"name": "id",
"in": "path",
"required": true
},
{
"type": "string",
"description": "Model type (chat, embedding)",
"name": "type",
"in": "query"
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"type": "array",
"items": {
"$ref": "#/definitions/models.GetResponse"
}
}
},
"400": {
"description": "Bad Request",
"schema": {
"$ref": "#/definitions/handlers.ErrorResponse"
}
},
"404": {
"description": "Not Found",
"schema": {
"$ref": "#/definitions/handlers.ErrorResponse"
}
},
"500": {
"description": "Internal Server Error",
"schema": {
"$ref": "#/definitions/handlers.ErrorResponse"
}
}
}
}
},
"/schedule": {
"get": {
"description": "List schedules for current user",
+53
View File
@@ -1819,6 +1819,59 @@
}
}
},
"/providers/{id}/models": {
"get": {
"description": "Get models for a provider by id, optionally filtered by type",
"tags": [
"providers"
],
"summary": "List provider models",
"parameters": [
{
"type": "string",
"description": "Provider ID (UUID)",
"name": "id",
"in": "path",
"required": true
},
{
"type": "string",
"description": "Model type (chat, embedding)",
"name": "type",
"in": "query"
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"type": "array",
"items": {
"$ref": "#/definitions/models.GetResponse"
}
}
},
"400": {
"description": "Bad Request",
"schema": {
"$ref": "#/definitions/handlers.ErrorResponse"
}
},
"404": {
"description": "Not Found",
"schema": {
"$ref": "#/definitions/handlers.ErrorResponse"
}
},
"500": {
"description": "Internal Server Error",
"schema": {
"$ref": "#/definitions/handlers.ErrorResponse"
}
}
}
}
},
"/schedule": {
"get": {
"description": "List schedules for current user",
+35
View File
@@ -1939,6 +1939,41 @@ paths:
summary: Update provider
tags:
- providers
/providers/{id}/models:
get:
description: Get models for a provider by id, optionally filtered by type
parameters:
- description: Provider ID (UUID)
in: path
name: id
required: true
type: string
- description: Model type (chat, embedding)
in: query
name: type
type: string
responses:
"200":
description: OK
schema:
items:
$ref: '#/definitions/models.GetResponse'
type: array
"400":
description: Bad Request
schema:
$ref: '#/definitions/handlers.ErrorResponse'
"404":
description: Not Found
schema:
$ref: '#/definitions/handlers.ErrorResponse'
"500":
description: Internal Server Error
schema:
$ref: '#/definitions/handlers.ErrorResponse'
summary: List provider models
tags:
- providers
/providers/count:
get:
consumes:
-19
View File
@@ -8,25 +8,6 @@ import (
"github.com/jackc/pgx/v5/pgtype"
)
type ChannelConfig struct {
ID pgtype.UUID `json:"id"`
ChannelType string `json:"channel_type"`
Config []byte `json:"config"`
UserID pgtype.UUID `json:"user_id"`
IsGlobal bool `json:"is_global"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
}
type ChannelUserConfig struct {
ID pgtype.UUID `json:"id"`
ChannelType string `json:"channel_type"`
UserID pgtype.UUID `json:"user_id"`
Config []byte `json:"config"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
}
type Container struct {
ID pgtype.UUID `json:"id"`
UserID pgtype.UUID `json:"user_id"`
+78
View File
@@ -527,6 +527,84 @@ func (q *Queries) ListModelsByClientType(ctx context.Context, clientType string)
return items, nil
}
const listModelsByProviderID = `-- name: ListModelsByProviderID :many
SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at FROM models
WHERE llm_provider_id = $1
ORDER BY created_at DESC
`
func (q *Queries) ListModelsByProviderID(ctx context.Context, llmProviderID pgtype.UUID) ([]Model, error) {
rows, err := q.db.Query(ctx, listModelsByProviderID, llmProviderID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Model
for rows.Next() {
var i Model
if err := rows.Scan(
&i.ID,
&i.ModelID,
&i.Name,
&i.LlmProviderID,
&i.Dimensions,
&i.IsMultimodal,
&i.Type,
&i.CreatedAt,
&i.UpdatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listModelsByProviderIDAndType = `-- name: ListModelsByProviderIDAndType :many
SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at FROM models
WHERE llm_provider_id = $1
AND type = $2
ORDER BY created_at DESC
`
type ListModelsByProviderIDAndTypeParams struct {
LlmProviderID pgtype.UUID `json:"llm_provider_id"`
Type string `json:"type"`
}
func (q *Queries) ListModelsByProviderIDAndType(ctx context.Context, arg ListModelsByProviderIDAndTypeParams) ([]Model, error) {
rows, err := q.db.Query(ctx, listModelsByProviderIDAndType, arg.LlmProviderID, arg.Type)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Model
for rows.Next() {
var i Model
if err := rows.Scan(
&i.ID,
&i.ModelID,
&i.Name,
&i.LlmProviderID,
&i.Dimensions,
&i.IsMultimodal,
&i.Type,
&i.CreatedAt,
&i.UpdatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listModelsByType = `-- name: ListModelsByType :many
SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at FROM models
WHERE type = $1
+48 -6
View File
@@ -3,21 +3,25 @@ package handlers
import (
"log/slog"
"net/http"
"strings"
"github.com/labstack/echo/v4"
"github.com/memohai/memoh/internal/models"
"github.com/memohai/memoh/internal/providers"
)
type ProvidersHandler struct {
service *providers.Service
logger *slog.Logger
service *providers.Service
modelsService *models.Service
logger *slog.Logger
}
func NewProvidersHandler(log *slog.Logger, service *providers.Service) *ProvidersHandler {
func NewProvidersHandler(log *slog.Logger, service *providers.Service, modelsService *models.Service) *ProvidersHandler {
return &ProvidersHandler{
service: service,
logger: log.With(slog.String("handler", "providers")),
service: service,
modelsService: modelsService,
logger: log.With(slog.String("handler", "providers")),
}
}
@@ -26,6 +30,7 @@ func (h *ProvidersHandler) Register(e *echo.Echo) {
group.POST("", h.Create)
group.GET("", h.List)
group.GET("/:id", h.Get)
group.GET("/:id/models", h.ListModelsByProvider)
group.GET("/name/:name", h.GetByName)
group.PUT("/:id", h.Update)
group.DELETE("/:id", h.Delete)
@@ -124,6 +129,44 @@ func (h *ProvidersHandler) Get(c echo.Context) error {
return c.JSON(http.StatusOK, resp)
}
// ListModelsByProvider godoc
// @Summary List provider models
// @Description Get models for a provider by id, optionally filtered by type
// @Tags providers
// @Param id path string true "Provider ID (UUID)"
// @Param type query string false "Model type (chat, embedding)"
// @Success 200 {array} models.GetResponse
// @Failure 400 {object} ErrorResponse
// @Failure 404 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Router /providers/{id}/models [get]
func (h *ProvidersHandler) ListModelsByProvider(c echo.Context) error {
if h.modelsService == nil {
return echo.NewHTTPError(http.StatusInternalServerError, "models service not configured")
}
id := c.Param("id")
if strings.TrimSpace(id) == "" {
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
}
modelType := strings.TrimSpace(c.QueryParam("type"))
var (
resp []models.GetResponse
err error
)
if modelType == "" {
resp, err = h.modelsService.ListByProviderID(c.Request().Context(), id)
} else {
resp, err = h.modelsService.ListByProviderIDAndType(c.Request().Context(), id, models.ModelType(modelType))
}
if err != nil {
if strings.Contains(err.Error(), "invalid") {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
return echo.NewHTTPError(http.StatusNotFound, err.Error())
}
return c.JSON(http.StatusOK, resp)
}
// GetByName godoc
// @Summary Get provider by name
// @Description Get a provider configuration by its name
@@ -236,4 +279,3 @@ func (h *ProvidersHandler) Count(c echo.Context) error {
return c.JSON(http.StatusOK, providers.CountResponse{Count: count})
}
+38
View File
@@ -143,6 +143,44 @@ func (s *Service) ListByClientType(ctx context.Context, clientType ClientType) (
return convertToGetResponseList(dbModels), nil
}
// ListByProviderID returns models filtered by provider ID.
func (s *Service) ListByProviderID(ctx context.Context, providerID string) ([]GetResponse, error) {
if strings.TrimSpace(providerID) == "" {
return nil, fmt.Errorf("provider id is required")
}
uuid, err := parseUUID(providerID)
if err != nil {
return nil, fmt.Errorf("invalid provider id: %w", err)
}
dbModels, err := s.queries.ListModelsByProviderID(ctx, uuid)
if err != nil {
return nil, fmt.Errorf("failed to list models by provider: %w", err)
}
return convertToGetResponseList(dbModels), nil
}
// ListByProviderIDAndType returns models filtered by provider ID and type.
func (s *Service) ListByProviderIDAndType(ctx context.Context, providerID string, modelType ModelType) ([]GetResponse, error) {
if modelType != ModelTypeChat && modelType != ModelTypeEmbedding {
return nil, fmt.Errorf("invalid model type: %s", modelType)
}
if strings.TrimSpace(providerID) == "" {
return nil, fmt.Errorf("provider id is required")
}
uuid, err := parseUUID(providerID)
if err != nil {
return nil, fmt.Errorf("invalid provider id: %w", err)
}
dbModels, err := s.queries.ListModelsByProviderIDAndType(ctx, sqlc.ListModelsByProviderIDAndTypeParams{
LlmProviderID: uuid,
Type: string(modelType),
})
if err != nil {
return nil, fmt.Errorf("failed to list models by provider and type: %w", err)
}
return convertToGetResponseList(dbModels), nil
}
// UpdateByID updates a model by its internal UUID
func (s *Service) UpdateByID(ctx context.Context, id string, req UpdateRequest) (GetResponse, error) {
uuid, err := parseUUID(id)
+4 -4
View File
@@ -17,7 +17,7 @@ type Server struct {
logger *slog.Logger
}
func NewServer(log *slog.Logger, addr string, jwtSecret string, pingHandler *handlers.PingHandler, authHandler *handlers.AuthHandler, memoryHandler *handlers.MemoryHandler, embeddingsHandler *handlers.EmbeddingsHandler, chatHandler *handlers.ChatHandler, swaggerHandler *handlers.SwaggerHandler, providersHandler *handlers.ProvidersHandler, modelsHandler *handlers.ModelsHandler, settingsHandler *handlers.SettingsHandler, historyHandler *handlers.HistoryHandler, scheduleHandler *handlers.ScheduleHandler, subagentHandler *handlers.SubagentHandler, containerdHandler *handlers.ContainerdHandler, channelHandler *handlers.ChannelHandler) *Server {
func NewServer(log *slog.Logger, addr string, jwtSecret string, pingHandler *handlers.PingHandler, authHandler *handlers.AuthHandler, memoryHandler *handlers.MemoryHandler, embeddingsHandler *handlers.EmbeddingsHandler, chatHandler *handlers.ChatHandler, swaggerHandler *handlers.SwaggerHandler, providersHandler *handlers.ProvidersHandler, modelsHandler *handlers.ModelsHandler, settingsHandler *handlers.SettingsHandler, historyHandler *handlers.HistoryHandler, scheduleHandler *handlers.ScheduleHandler, subagentHandler *handlers.SubagentHandler, containerdHandler *handlers.ContainerdHandler, /* channelHandler handlers.ChannelHandler*/) *Server {
if addr == "" {
addr = ":8080"
}
@@ -90,9 +90,9 @@ func NewServer(log *slog.Logger, addr string, jwtSecret string, pingHandler *han
if containerdHandler != nil {
containerdHandler.Register(e)
}
if channelHandler != nil {
channelHandler.Register(e)
}
// if channelHandler != nil {
// channelHandler.Register(e)
// }
return &Server{
echo: e,