mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
feat: add get models by provider
This commit is contained in:
+10
-9
@@ -3,13 +3,14 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
// "github.com/memohai/memoh/internal/channel"
|
||||
"github.com/memohai/memoh/internal/chat"
|
||||
"github.com/memohai/memoh/internal/channel"
|
||||
"github.com/memohai/memoh/internal/config"
|
||||
ctr "github.com/memohai/memoh/internal/containerd"
|
||||
"github.com/memohai/memoh/internal/db"
|
||||
@@ -176,18 +177,18 @@ func main() {
|
||||
|
||||
// Initialize providers and models handlers
|
||||
providersService := providers.NewService(logger.L, queries)
|
||||
providersHandler := handlers.NewProvidersHandler(logger.L, providersService)
|
||||
providersHandler := handlers.NewProvidersHandler(logger.L, providersService, modelsService)
|
||||
settingsService := settings.NewService(logger.L, queries)
|
||||
settingsHandler := handlers.NewSettingsHandler(logger.L, settingsService)
|
||||
modelsHandler := handlers.NewModelsHandler(logger.L, modelsService, settingsService)
|
||||
historyService := history.NewService(logger.L, queries)
|
||||
historyHandler := handlers.NewHistoryHandler(logger.L, historyService)
|
||||
channelService := channel.NewService(queries)
|
||||
channelManager := channel.NewManager(channelService, chatResolver)
|
||||
channelManager.RegisterAdapter(channel.NewTelegramAdapter())
|
||||
channelManager.RegisterAdapter(channel.NewFeishuAdapter())
|
||||
channelManager.Start(ctx)
|
||||
channelHandler := handlers.NewChannelHandler(channelService, channelManager)
|
||||
// channelService := channel.NewService(queries)
|
||||
// channelManager := channel.NewManager(channelService, chatResolver)
|
||||
// channelManager.RegisterAdapter(channel.NewTelegramAdapter())
|
||||
// channelManager.RegisterAdapter(channel.NewFeishuAdapter())
|
||||
// channelManager.Start(ctx)
|
||||
// channelHandler := handlers.NewChannelHandler(channelService, channelManager)
|
||||
scheduleService := schedule.NewService(logger.L, queries, chatResolver, cfg.Auth.JWTSecret)
|
||||
if err := scheduleService.Bootstrap(ctx); err != nil {
|
||||
logger.Error("schedule bootstrap", slog.Any("error", err))
|
||||
@@ -196,7 +197,7 @@ func main() {
|
||||
scheduleHandler := handlers.NewScheduleHandler(logger.L, scheduleService)
|
||||
subagentService := subagent.NewService(logger.L, queries)
|
||||
subagentHandler := handlers.NewSubagentHandler(logger.L, subagentService)
|
||||
srv := server.NewServer(logger.L, addr, cfg.Auth.JWTSecret, pingHandler, authHandler, memoryHandler, embeddingsHandler, chatHandler, swaggerHandler, providersHandler, modelsHandler, settingsHandler, historyHandler, scheduleHandler, subagentHandler, containerdHandler, channelHandler)
|
||||
srv := server.NewServer(logger.L, addr, cfg.Auth.JWTSecret, pingHandler, authHandler, memoryHandler, embeddingsHandler, chatHandler, swaggerHandler, providersHandler, modelsHandler, settingsHandler, historyHandler, scheduleHandler, subagentHandler, containerdHandler, /*channelHandler*/)
|
||||
|
||||
if err := srv.Start(); err != nil {
|
||||
logger.Error("server failed", slog.Any("error", err))
|
||||
|
||||
@@ -78,6 +78,17 @@ JOIN llm_providers AS p ON p.id = m.llm_provider_id
|
||||
WHERE p.client_type = sqlc.arg(client_type)
|
||||
ORDER BY m.created_at DESC;
|
||||
|
||||
-- name: ListModelsByProviderID :many
|
||||
SELECT * FROM models
|
||||
WHERE llm_provider_id = sqlc.arg(llm_provider_id)
|
||||
ORDER BY created_at DESC;
|
||||
|
||||
-- name: ListModelsByProviderIDAndType :many
|
||||
SELECT * FROM models
|
||||
WHERE llm_provider_id = sqlc.arg(llm_provider_id)
|
||||
AND type = sqlc.arg(type)
|
||||
ORDER BY created_at DESC;
|
||||
|
||||
-- name: UpdateModel :one
|
||||
UPDATE models
|
||||
SET
|
||||
|
||||
@@ -1828,6 +1828,59 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"/providers/{id}/models": {
|
||||
"get": {
|
||||
"description": "Get models for a provider by id, optionally filtered by type",
|
||||
"tags": [
|
||||
"providers"
|
||||
],
|
||||
"summary": "List provider models",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Provider ID (UUID)",
|
||||
"name": "id",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Model type (chat, embedding)",
|
||||
"name": "type",
|
||||
"in": "query"
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/models.GetResponse"
|
||||
}
|
||||
}
|
||||
},
|
||||
"400": {
|
||||
"description": "Bad Request",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/handlers.ErrorResponse"
|
||||
}
|
||||
},
|
||||
"404": {
|
||||
"description": "Not Found",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/handlers.ErrorResponse"
|
||||
}
|
||||
},
|
||||
"500": {
|
||||
"description": "Internal Server Error",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/handlers.ErrorResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/schedule": {
|
||||
"get": {
|
||||
"description": "List schedules for current user",
|
||||
|
||||
@@ -1819,6 +1819,59 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/providers/{id}/models": {
|
||||
"get": {
|
||||
"description": "Get models for a provider by id, optionally filtered by type",
|
||||
"tags": [
|
||||
"providers"
|
||||
],
|
||||
"summary": "List provider models",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Provider ID (UUID)",
|
||||
"name": "id",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Model type (chat, embedding)",
|
||||
"name": "type",
|
||||
"in": "query"
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/models.GetResponse"
|
||||
}
|
||||
}
|
||||
},
|
||||
"400": {
|
||||
"description": "Bad Request",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/handlers.ErrorResponse"
|
||||
}
|
||||
},
|
||||
"404": {
|
||||
"description": "Not Found",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/handlers.ErrorResponse"
|
||||
}
|
||||
},
|
||||
"500": {
|
||||
"description": "Internal Server Error",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/handlers.ErrorResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/schedule": {
|
||||
"get": {
|
||||
"description": "List schedules for current user",
|
||||
|
||||
@@ -1939,6 +1939,41 @@ paths:
|
||||
summary: Update provider
|
||||
tags:
|
||||
- providers
|
||||
/providers/{id}/models:
|
||||
get:
|
||||
description: Get models for a provider by id, optionally filtered by type
|
||||
parameters:
|
||||
- description: Provider ID (UUID)
|
||||
in: path
|
||||
name: id
|
||||
required: true
|
||||
type: string
|
||||
- description: Model type (chat, embedding)
|
||||
in: query
|
||||
name: type
|
||||
type: string
|
||||
responses:
|
||||
"200":
|
||||
description: OK
|
||||
schema:
|
||||
items:
|
||||
$ref: '#/definitions/models.GetResponse'
|
||||
type: array
|
||||
"400":
|
||||
description: Bad Request
|
||||
schema:
|
||||
$ref: '#/definitions/handlers.ErrorResponse'
|
||||
"404":
|
||||
description: Not Found
|
||||
schema:
|
||||
$ref: '#/definitions/handlers.ErrorResponse'
|
||||
"500":
|
||||
description: Internal Server Error
|
||||
schema:
|
||||
$ref: '#/definitions/handlers.ErrorResponse'
|
||||
summary: List provider models
|
||||
tags:
|
||||
- providers
|
||||
/providers/count:
|
||||
get:
|
||||
consumes:
|
||||
|
||||
@@ -8,25 +8,6 @@ import (
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
type ChannelConfig struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
ChannelType string `json:"channel_type"`
|
||||
Config []byte `json:"config"`
|
||||
UserID pgtype.UUID `json:"user_id"`
|
||||
IsGlobal bool `json:"is_global"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
|
||||
}
|
||||
|
||||
type ChannelUserConfig struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
ChannelType string `json:"channel_type"`
|
||||
UserID pgtype.UUID `json:"user_id"`
|
||||
Config []byte `json:"config"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
|
||||
}
|
||||
|
||||
type Container struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
UserID pgtype.UUID `json:"user_id"`
|
||||
|
||||
@@ -527,6 +527,84 @@ func (q *Queries) ListModelsByClientType(ctx context.Context, clientType string)
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listModelsByProviderID = `-- name: ListModelsByProviderID :many
|
||||
SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at FROM models
|
||||
WHERE llm_provider_id = $1
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListModelsByProviderID(ctx context.Context, llmProviderID pgtype.UUID) ([]Model, error) {
|
||||
rows, err := q.db.Query(ctx, listModelsByProviderID, llmProviderID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Model
|
||||
for rows.Next() {
|
||||
var i Model
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.ModelID,
|
||||
&i.Name,
|
||||
&i.LlmProviderID,
|
||||
&i.Dimensions,
|
||||
&i.IsMultimodal,
|
||||
&i.Type,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listModelsByProviderIDAndType = `-- name: ListModelsByProviderIDAndType :many
|
||||
SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at FROM models
|
||||
WHERE llm_provider_id = $1
|
||||
AND type = $2
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
type ListModelsByProviderIDAndTypeParams struct {
|
||||
LlmProviderID pgtype.UUID `json:"llm_provider_id"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
func (q *Queries) ListModelsByProviderIDAndType(ctx context.Context, arg ListModelsByProviderIDAndTypeParams) ([]Model, error) {
|
||||
rows, err := q.db.Query(ctx, listModelsByProviderIDAndType, arg.LlmProviderID, arg.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Model
|
||||
for rows.Next() {
|
||||
var i Model
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.ModelID,
|
||||
&i.Name,
|
||||
&i.LlmProviderID,
|
||||
&i.Dimensions,
|
||||
&i.IsMultimodal,
|
||||
&i.Type,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listModelsByType = `-- name: ListModelsByType :many
|
||||
SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at FROM models
|
||||
WHERE type = $1
|
||||
|
||||
@@ -3,21 +3,25 @@ package handlers
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
"github.com/memohai/memoh/internal/providers"
|
||||
)
|
||||
|
||||
type ProvidersHandler struct {
|
||||
service *providers.Service
|
||||
logger *slog.Logger
|
||||
service *providers.Service
|
||||
modelsService *models.Service
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func NewProvidersHandler(log *slog.Logger, service *providers.Service) *ProvidersHandler {
|
||||
func NewProvidersHandler(log *slog.Logger, service *providers.Service, modelsService *models.Service) *ProvidersHandler {
|
||||
return &ProvidersHandler{
|
||||
service: service,
|
||||
logger: log.With(slog.String("handler", "providers")),
|
||||
service: service,
|
||||
modelsService: modelsService,
|
||||
logger: log.With(slog.String("handler", "providers")),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,6 +30,7 @@ func (h *ProvidersHandler) Register(e *echo.Echo) {
|
||||
group.POST("", h.Create)
|
||||
group.GET("", h.List)
|
||||
group.GET("/:id", h.Get)
|
||||
group.GET("/:id/models", h.ListModelsByProvider)
|
||||
group.GET("/name/:name", h.GetByName)
|
||||
group.PUT("/:id", h.Update)
|
||||
group.DELETE("/:id", h.Delete)
|
||||
@@ -124,6 +129,44 @@ func (h *ProvidersHandler) Get(c echo.Context) error {
|
||||
return c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// ListModelsByProvider godoc
|
||||
// @Summary List provider models
|
||||
// @Description Get models for a provider by id, optionally filtered by type
|
||||
// @Tags providers
|
||||
// @Param id path string true "Provider ID (UUID)"
|
||||
// @Param type query string false "Model type (chat, embedding)"
|
||||
// @Success 200 {array} models.GetResponse
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 404 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /providers/{id}/models [get]
|
||||
func (h *ProvidersHandler) ListModelsByProvider(c echo.Context) error {
|
||||
if h.modelsService == nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "models service not configured")
|
||||
}
|
||||
id := c.Param("id")
|
||||
if strings.TrimSpace(id) == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "id is required")
|
||||
}
|
||||
modelType := strings.TrimSpace(c.QueryParam("type"))
|
||||
var (
|
||||
resp []models.GetResponse
|
||||
err error
|
||||
)
|
||||
if modelType == "" {
|
||||
resp, err = h.modelsService.ListByProviderID(c.Request().Context(), id)
|
||||
} else {
|
||||
resp, err = h.modelsService.ListByProviderIDAndType(c.Request().Context(), id, models.ModelType(modelType))
|
||||
}
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "invalid") {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
return echo.NewHTTPError(http.StatusNotFound, err.Error())
|
||||
}
|
||||
return c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// GetByName godoc
|
||||
// @Summary Get provider by name
|
||||
// @Description Get a provider configuration by its name
|
||||
@@ -236,4 +279,3 @@ func (h *ProvidersHandler) Count(c echo.Context) error {
|
||||
|
||||
return c.JSON(http.StatusOK, providers.CountResponse{Count: count})
|
||||
}
|
||||
|
||||
|
||||
@@ -143,6 +143,44 @@ func (s *Service) ListByClientType(ctx context.Context, clientType ClientType) (
|
||||
return convertToGetResponseList(dbModels), nil
|
||||
}
|
||||
|
||||
// ListByProviderID returns models filtered by provider ID.
|
||||
func (s *Service) ListByProviderID(ctx context.Context, providerID string) ([]GetResponse, error) {
|
||||
if strings.TrimSpace(providerID) == "" {
|
||||
return nil, fmt.Errorf("provider id is required")
|
||||
}
|
||||
uuid, err := parseUUID(providerID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid provider id: %w", err)
|
||||
}
|
||||
dbModels, err := s.queries.ListModelsByProviderID(ctx, uuid)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list models by provider: %w", err)
|
||||
}
|
||||
return convertToGetResponseList(dbModels), nil
|
||||
}
|
||||
|
||||
// ListByProviderIDAndType returns models filtered by provider ID and type.
|
||||
func (s *Service) ListByProviderIDAndType(ctx context.Context, providerID string, modelType ModelType) ([]GetResponse, error) {
|
||||
if modelType != ModelTypeChat && modelType != ModelTypeEmbedding {
|
||||
return nil, fmt.Errorf("invalid model type: %s", modelType)
|
||||
}
|
||||
if strings.TrimSpace(providerID) == "" {
|
||||
return nil, fmt.Errorf("provider id is required")
|
||||
}
|
||||
uuid, err := parseUUID(providerID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid provider id: %w", err)
|
||||
}
|
||||
dbModels, err := s.queries.ListModelsByProviderIDAndType(ctx, sqlc.ListModelsByProviderIDAndTypeParams{
|
||||
LlmProviderID: uuid,
|
||||
Type: string(modelType),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list models by provider and type: %w", err)
|
||||
}
|
||||
return convertToGetResponseList(dbModels), nil
|
||||
}
|
||||
|
||||
// UpdateByID updates a model by its internal UUID
|
||||
func (s *Service) UpdateByID(ctx context.Context, id string, req UpdateRequest) (GetResponse, error) {
|
||||
uuid, err := parseUUID(id)
|
||||
|
||||
@@ -17,7 +17,7 @@ type Server struct {
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func NewServer(log *slog.Logger, addr string, jwtSecret string, pingHandler *handlers.PingHandler, authHandler *handlers.AuthHandler, memoryHandler *handlers.MemoryHandler, embeddingsHandler *handlers.EmbeddingsHandler, chatHandler *handlers.ChatHandler, swaggerHandler *handlers.SwaggerHandler, providersHandler *handlers.ProvidersHandler, modelsHandler *handlers.ModelsHandler, settingsHandler *handlers.SettingsHandler, historyHandler *handlers.HistoryHandler, scheduleHandler *handlers.ScheduleHandler, subagentHandler *handlers.SubagentHandler, containerdHandler *handlers.ContainerdHandler, channelHandler *handlers.ChannelHandler) *Server {
|
||||
func NewServer(log *slog.Logger, addr string, jwtSecret string, pingHandler *handlers.PingHandler, authHandler *handlers.AuthHandler, memoryHandler *handlers.MemoryHandler, embeddingsHandler *handlers.EmbeddingsHandler, chatHandler *handlers.ChatHandler, swaggerHandler *handlers.SwaggerHandler, providersHandler *handlers.ProvidersHandler, modelsHandler *handlers.ModelsHandler, settingsHandler *handlers.SettingsHandler, historyHandler *handlers.HistoryHandler, scheduleHandler *handlers.ScheduleHandler, subagentHandler *handlers.SubagentHandler, containerdHandler *handlers.ContainerdHandler, /* channelHandler handlers.ChannelHandler*/) *Server {
|
||||
if addr == "" {
|
||||
addr = ":8080"
|
||||
}
|
||||
@@ -90,9 +90,9 @@ func NewServer(log *slog.Logger, addr string, jwtSecret string, pingHandler *han
|
||||
if containerdHandler != nil {
|
||||
containerdHandler.Register(e)
|
||||
}
|
||||
if channelHandler != nil {
|
||||
channelHandler.Register(e)
|
||||
}
|
||||
// if channelHandler != nil {
|
||||
// channelHandler.Register(e)
|
||||
// }
|
||||
|
||||
return &Server{
|
||||
echo: e,
|
||||
|
||||
Reference in New Issue
Block a user