refactor: move client_type key from provider to model

This commit is contained in:
Acbox
2026-02-18 18:30:27 +08:00
parent 77e9f585a1
commit d6c47472b2
43 changed files with 552 additions and 1015 deletions
+4 -7
View File
@@ -213,10 +213,8 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r
if err != nil {
return resolvedContext{}, err
}
clientType, err := normalizeClientType(provider.ClientType)
if err != nil {
return resolvedContext{}, err
}
clientType := string(chatModel.ClientType)
maxCtx := coalescePositiveInt(req.MaxContextLoadTime, botSettings.MaxContextLoadTime, defaultMaxContextMinutes)
maxTokens := botSettings.MaxContextTokens
@@ -306,7 +304,7 @@ func (r *Resolver) Chat(ctx context.Context, req conversation.ChatRequest) (conv
Messages: resp.Messages,
Skills: resp.Skills,
Model: rc.model.ModelID,
Provider: rc.provider.ClientType,
Provider: string(rc.model.ClientType),
}, nil
}
@@ -1235,8 +1233,7 @@ func (r *Resolver) loadBotSettings(ctx context.Context, botID string) (settings.
func normalizeClientType(clientType string) (string, error) {
ct := strings.ToLower(strings.TrimSpace(clientType))
switch ct {
case "openai", "openai-compat", "anthropic", "google",
"azure", "bedrock", "mistral", "xai", "ollama", "dashscope":
case "openai-responses", "openai-completions", "anthropic-messages", "google-generative-ai":
return ct, nil
default:
return "", fmt.Errorf("unsupported agent gateway client type: %s", clientType)
+8 -8
View File
@@ -166,14 +166,13 @@ type LifecycleEvent struct {
}
type LlmProvider struct {
ID pgtype.UUID `json:"id"`
Name string `json:"name"`
ClientType string `json:"client_type"`
BaseUrl string `json:"base_url"`
ApiKey string `json:"api_key"`
Metadata []byte `json:"metadata"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
ID pgtype.UUID `json:"id"`
Name string `json:"name"`
BaseUrl string `json:"base_url"`
ApiKey string `json:"api_key"`
Metadata []byte `json:"metadata"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
}
type McpConnection struct {
@@ -209,6 +208,7 @@ type Model struct {
ModelID string `json:"model_id"`
Name pgtype.Text `json:"name"`
LlmProviderID pgtype.UUID `json:"llm_provider_id"`
ClientType pgtype.Text `json:"client_type"`
Dimensions pgtype.Int4 `json:"dimensions"`
InputModalities []string `json:"input_modalities"`
Type string `json:"type"`
+62 -101
View File
@@ -22,17 +22,6 @@ func (q *Queries) CountLlmProviders(ctx context.Context) (int64, error) {
return count, err
}
const countLlmProvidersByClientType = `-- name: CountLlmProvidersByClientType :one
SELECT COUNT(*) FROM llm_providers WHERE client_type = $1
`
func (q *Queries) CountLlmProvidersByClientType(ctx context.Context, clientType string) (int64, error) {
row := q.db.QueryRow(ctx, countLlmProvidersByClientType, clientType)
var count int64
err := row.Scan(&count)
return count, err
}
const countModels = `-- name: CountModels :one
SELECT COUNT(*) FROM models
`
@@ -56,29 +45,26 @@ func (q *Queries) CountModelsByType(ctx context.Context, type_ string) (int64, e
}
const createLlmProvider = `-- name: CreateLlmProvider :one
INSERT INTO llm_providers (name, client_type, base_url, api_key, metadata)
INSERT INTO llm_providers (name, base_url, api_key, metadata)
VALUES (
$1,
$2,
$3,
$4,
$5
$4
)
RETURNING id, name, client_type, base_url, api_key, metadata, created_at, updated_at
RETURNING id, name, base_url, api_key, metadata, created_at, updated_at
`
type CreateLlmProviderParams struct {
Name string `json:"name"`
ClientType string `json:"client_type"`
BaseUrl string `json:"base_url"`
ApiKey string `json:"api_key"`
Metadata []byte `json:"metadata"`
Name string `json:"name"`
BaseUrl string `json:"base_url"`
ApiKey string `json:"api_key"`
Metadata []byte `json:"metadata"`
}
func (q *Queries) CreateLlmProvider(ctx context.Context, arg CreateLlmProviderParams) (LlmProvider, error) {
row := q.db.QueryRow(ctx, createLlmProvider,
arg.Name,
arg.ClientType,
arg.BaseUrl,
arg.ApiKey,
arg.Metadata,
@@ -87,7 +73,6 @@ func (q *Queries) CreateLlmProvider(ctx context.Context, arg CreateLlmProviderPa
err := row.Scan(
&i.ID,
&i.Name,
&i.ClientType,
&i.BaseUrl,
&i.ApiKey,
&i.Metadata,
@@ -98,22 +83,24 @@ func (q *Queries) CreateLlmProvider(ctx context.Context, arg CreateLlmProviderPa
}
const createModel = `-- name: CreateModel :one
INSERT INTO models (model_id, name, llm_provider_id, dimensions, input_modalities, type)
INSERT INTO models (model_id, name, llm_provider_id, client_type, dimensions, input_modalities, type)
VALUES (
$1,
$2,
$3,
$4,
$5,
$6
$6,
$7
)
RETURNING id, model_id, name, llm_provider_id, dimensions, input_modalities, type, created_at, updated_at
RETURNING id, model_id, name, llm_provider_id, client_type, dimensions, input_modalities, type, created_at, updated_at
`
type CreateModelParams struct {
ModelID string `json:"model_id"`
Name pgtype.Text `json:"name"`
LlmProviderID pgtype.UUID `json:"llm_provider_id"`
ClientType pgtype.Text `json:"client_type"`
Dimensions pgtype.Int4 `json:"dimensions"`
InputModalities []string `json:"input_modalities"`
Type string `json:"type"`
@@ -124,6 +111,7 @@ func (q *Queries) CreateModel(ctx context.Context, arg CreateModelParams) (Model
arg.ModelID,
arg.Name,
arg.LlmProviderID,
arg.ClientType,
arg.Dimensions,
arg.InputModalities,
arg.Type,
@@ -134,6 +122,7 @@ func (q *Queries) CreateModel(ctx context.Context, arg CreateModelParams) (Model
&i.ModelID,
&i.Name,
&i.LlmProviderID,
&i.ClientType,
&i.Dimensions,
&i.InputModalities,
&i.Type,
@@ -209,7 +198,7 @@ func (q *Queries) DeleteModelByModelID(ctx context.Context, modelID string) erro
}
const getLlmProviderByID = `-- name: GetLlmProviderByID :one
SELECT id, name, client_type, base_url, api_key, metadata, created_at, updated_at FROM llm_providers WHERE id = $1
SELECT id, name, base_url, api_key, metadata, created_at, updated_at FROM llm_providers WHERE id = $1
`
func (q *Queries) GetLlmProviderByID(ctx context.Context, id pgtype.UUID) (LlmProvider, error) {
@@ -218,7 +207,6 @@ func (q *Queries) GetLlmProviderByID(ctx context.Context, id pgtype.UUID) (LlmPr
err := row.Scan(
&i.ID,
&i.Name,
&i.ClientType,
&i.BaseUrl,
&i.ApiKey,
&i.Metadata,
@@ -229,7 +217,7 @@ func (q *Queries) GetLlmProviderByID(ctx context.Context, id pgtype.UUID) (LlmPr
}
const getLlmProviderByName = `-- name: GetLlmProviderByName :one
SELECT id, name, client_type, base_url, api_key, metadata, created_at, updated_at FROM llm_providers WHERE name = $1
SELECT id, name, base_url, api_key, metadata, created_at, updated_at FROM llm_providers WHERE name = $1
`
func (q *Queries) GetLlmProviderByName(ctx context.Context, name string) (LlmProvider, error) {
@@ -238,7 +226,6 @@ func (q *Queries) GetLlmProviderByName(ctx context.Context, name string) (LlmPro
err := row.Scan(
&i.ID,
&i.Name,
&i.ClientType,
&i.BaseUrl,
&i.ApiKey,
&i.Metadata,
@@ -249,7 +236,7 @@ func (q *Queries) GetLlmProviderByName(ctx context.Context, name string) (LlmPro
}
const getModelByID = `-- name: GetModelByID :one
SELECT id, model_id, name, llm_provider_id, dimensions, input_modalities, type, created_at, updated_at FROM models WHERE id = $1
SELECT id, model_id, name, llm_provider_id, client_type, dimensions, input_modalities, type, created_at, updated_at FROM models WHERE id = $1
`
func (q *Queries) GetModelByID(ctx context.Context, id pgtype.UUID) (Model, error) {
@@ -260,6 +247,7 @@ func (q *Queries) GetModelByID(ctx context.Context, id pgtype.UUID) (Model, erro
&i.ModelID,
&i.Name,
&i.LlmProviderID,
&i.ClientType,
&i.Dimensions,
&i.InputModalities,
&i.Type,
@@ -270,7 +258,7 @@ func (q *Queries) GetModelByID(ctx context.Context, id pgtype.UUID) (Model, erro
}
const getModelByModelID = `-- name: GetModelByModelID :one
SELECT id, model_id, name, llm_provider_id, dimensions, input_modalities, type, created_at, updated_at FROM models WHERE model_id = $1
SELECT id, model_id, name, llm_provider_id, client_type, dimensions, input_modalities, type, created_at, updated_at FROM models WHERE model_id = $1
`
func (q *Queries) GetModelByModelID(ctx context.Context, modelID string) (Model, error) {
@@ -281,6 +269,7 @@ func (q *Queries) GetModelByModelID(ctx context.Context, modelID string) (Model,
&i.ModelID,
&i.Name,
&i.LlmProviderID,
&i.ClientType,
&i.Dimensions,
&i.InputModalities,
&i.Type,
@@ -291,7 +280,7 @@ func (q *Queries) GetModelByModelID(ctx context.Context, modelID string) (Model,
}
const listLlmProviders = `-- name: ListLlmProviders :many
SELECT id, name, client_type, base_url, api_key, metadata, created_at, updated_at FROM llm_providers
SELECT id, name, base_url, api_key, metadata, created_at, updated_at FROM llm_providers
ORDER BY created_at DESC
`
@@ -307,42 +296,6 @@ func (q *Queries) ListLlmProviders(ctx context.Context) ([]LlmProvider, error) {
if err := rows.Scan(
&i.ID,
&i.Name,
&i.ClientType,
&i.BaseUrl,
&i.ApiKey,
&i.Metadata,
&i.CreatedAt,
&i.UpdatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listLlmProvidersByClientType = `-- name: ListLlmProvidersByClientType :many
SELECT id, name, client_type, base_url, api_key, metadata, created_at, updated_at FROM llm_providers
WHERE client_type = $1
ORDER BY created_at DESC
`
func (q *Queries) ListLlmProvidersByClientType(ctx context.Context, clientType string) ([]LlmProvider, error) {
rows, err := q.db.Query(ctx, listLlmProvidersByClientType, clientType)
if err != nil {
return nil, err
}
defer rows.Close()
var items []LlmProvider
for rows.Next() {
var i LlmProvider
if err := rows.Scan(
&i.ID,
&i.Name,
&i.ClientType,
&i.BaseUrl,
&i.ApiKey,
&i.Metadata,
@@ -394,7 +347,7 @@ func (q *Queries) ListModelVariantsByModelUUID(ctx context.Context, modelUuid pg
}
const listModels = `-- name: ListModels :many
SELECT id, model_id, name, llm_provider_id, dimensions, input_modalities, type, created_at, updated_at FROM models
SELECT id, model_id, name, llm_provider_id, client_type, dimensions, input_modalities, type, created_at, updated_at FROM models
ORDER BY created_at DESC
`
@@ -412,6 +365,7 @@ func (q *Queries) ListModels(ctx context.Context) ([]Model, error) {
&i.ModelID,
&i.Name,
&i.LlmProviderID,
&i.ClientType,
&i.Dimensions,
&i.InputModalities,
&i.Type,
@@ -429,13 +383,12 @@ func (q *Queries) ListModels(ctx context.Context) ([]Model, error) {
}
const listModelsByClientType = `-- name: ListModelsByClientType :many
SELECT m.id, m.model_id, m.name, m.llm_provider_id, m.dimensions, m.input_modalities, m.type, m.created_at, m.updated_at FROM models AS m
JOIN llm_providers AS p ON p.id = m.llm_provider_id
WHERE p.client_type = $1
ORDER BY m.created_at DESC
SELECT id, model_id, name, llm_provider_id, client_type, dimensions, input_modalities, type, created_at, updated_at FROM models
WHERE client_type = $1
ORDER BY created_at DESC
`
func (q *Queries) ListModelsByClientType(ctx context.Context, clientType string) ([]Model, error) {
func (q *Queries) ListModelsByClientType(ctx context.Context, clientType pgtype.Text) ([]Model, error) {
rows, err := q.db.Query(ctx, listModelsByClientType, clientType)
if err != nil {
return nil, err
@@ -449,6 +402,7 @@ func (q *Queries) ListModelsByClientType(ctx context.Context, clientType string)
&i.ModelID,
&i.Name,
&i.LlmProviderID,
&i.ClientType,
&i.Dimensions,
&i.InputModalities,
&i.Type,
@@ -466,7 +420,7 @@ func (q *Queries) ListModelsByClientType(ctx context.Context, clientType string)
}
const listModelsByProviderID = `-- name: ListModelsByProviderID :many
SELECT id, model_id, name, llm_provider_id, dimensions, input_modalities, type, created_at, updated_at FROM models
SELECT id, model_id, name, llm_provider_id, client_type, dimensions, input_modalities, type, created_at, updated_at FROM models
WHERE llm_provider_id = $1
ORDER BY created_at DESC
`
@@ -485,6 +439,7 @@ func (q *Queries) ListModelsByProviderID(ctx context.Context, llmProviderID pgty
&i.ModelID,
&i.Name,
&i.LlmProviderID,
&i.ClientType,
&i.Dimensions,
&i.InputModalities,
&i.Type,
@@ -502,7 +457,7 @@ func (q *Queries) ListModelsByProviderID(ctx context.Context, llmProviderID pgty
}
const listModelsByProviderIDAndType = `-- name: ListModelsByProviderIDAndType :many
SELECT id, model_id, name, llm_provider_id, dimensions, input_modalities, type, created_at, updated_at FROM models
SELECT id, model_id, name, llm_provider_id, client_type, dimensions, input_modalities, type, created_at, updated_at FROM models
WHERE llm_provider_id = $1
AND type = $2
ORDER BY created_at DESC
@@ -527,6 +482,7 @@ func (q *Queries) ListModelsByProviderIDAndType(ctx context.Context, arg ListMod
&i.ModelID,
&i.Name,
&i.LlmProviderID,
&i.ClientType,
&i.Dimensions,
&i.InputModalities,
&i.Type,
@@ -544,7 +500,7 @@ func (q *Queries) ListModelsByProviderIDAndType(ctx context.Context, arg ListMod
}
const listModelsByType = `-- name: ListModelsByType :many
SELECT id, model_id, name, llm_provider_id, dimensions, input_modalities, type, created_at, updated_at FROM models
SELECT id, model_id, name, llm_provider_id, client_type, dimensions, input_modalities, type, created_at, updated_at FROM models
WHERE type = $1
ORDER BY created_at DESC
`
@@ -563,6 +519,7 @@ func (q *Queries) ListModelsByType(ctx context.Context, type_ string) ([]Model,
&i.ModelID,
&i.Name,
&i.LlmProviderID,
&i.ClientType,
&i.Dimensions,
&i.InputModalities,
&i.Type,
@@ -583,28 +540,25 @@ const updateLlmProvider = `-- name: UpdateLlmProvider :one
UPDATE llm_providers
SET
name = $1,
client_type = $2,
base_url = $3,
api_key = $4,
metadata = $5,
base_url = $2,
api_key = $3,
metadata = $4,
updated_at = now()
WHERE id = $6
RETURNING id, name, client_type, base_url, api_key, metadata, created_at, updated_at
WHERE id = $5
RETURNING id, name, base_url, api_key, metadata, created_at, updated_at
`
type UpdateLlmProviderParams struct {
Name string `json:"name"`
ClientType string `json:"client_type"`
BaseUrl string `json:"base_url"`
ApiKey string `json:"api_key"`
Metadata []byte `json:"metadata"`
ID pgtype.UUID `json:"id"`
Name string `json:"name"`
BaseUrl string `json:"base_url"`
ApiKey string `json:"api_key"`
Metadata []byte `json:"metadata"`
ID pgtype.UUID `json:"id"`
}
func (q *Queries) UpdateLlmProvider(ctx context.Context, arg UpdateLlmProviderParams) (LlmProvider, error) {
row := q.db.QueryRow(ctx, updateLlmProvider,
arg.Name,
arg.ClientType,
arg.BaseUrl,
arg.ApiKey,
arg.Metadata,
@@ -614,7 +568,6 @@ func (q *Queries) UpdateLlmProvider(ctx context.Context, arg UpdateLlmProviderPa
err := row.Scan(
&i.ID,
&i.Name,
&i.ClientType,
&i.BaseUrl,
&i.ApiKey,
&i.Metadata,
@@ -629,17 +582,19 @@ UPDATE models
SET
name = $1,
llm_provider_id = $2,
dimensions = $3,
input_modalities = $4,
type = $5,
client_type = $3,
dimensions = $4,
input_modalities = $5,
type = $6,
updated_at = now()
WHERE id = $6
RETURNING id, model_id, name, llm_provider_id, dimensions, input_modalities, type, created_at, updated_at
WHERE id = $7
RETURNING id, model_id, name, llm_provider_id, client_type, dimensions, input_modalities, type, created_at, updated_at
`
type UpdateModelParams struct {
Name pgtype.Text `json:"name"`
LlmProviderID pgtype.UUID `json:"llm_provider_id"`
ClientType pgtype.Text `json:"client_type"`
Dimensions pgtype.Int4 `json:"dimensions"`
InputModalities []string `json:"input_modalities"`
Type string `json:"type"`
@@ -650,6 +605,7 @@ func (q *Queries) UpdateModel(ctx context.Context, arg UpdateModelParams) (Model
row := q.db.QueryRow(ctx, updateModel,
arg.Name,
arg.LlmProviderID,
arg.ClientType,
arg.Dimensions,
arg.InputModalities,
arg.Type,
@@ -661,6 +617,7 @@ func (q *Queries) UpdateModel(ctx context.Context, arg UpdateModelParams) (Model
&i.ModelID,
&i.Name,
&i.LlmProviderID,
&i.ClientType,
&i.Dimensions,
&i.InputModalities,
&i.Type,
@@ -676,18 +633,20 @@ SET
model_id = $1,
name = $2,
llm_provider_id = $3,
dimensions = $4,
input_modalities = $5,
type = $6,
client_type = $4,
dimensions = $5,
input_modalities = $6,
type = $7,
updated_at = now()
WHERE model_id = $7
RETURNING id, model_id, name, llm_provider_id, dimensions, input_modalities, type, created_at, updated_at
WHERE model_id = $8
RETURNING id, model_id, name, llm_provider_id, client_type, dimensions, input_modalities, type, created_at, updated_at
`
type UpdateModelByModelIDParams struct {
NewModelID string `json:"new_model_id"`
Name pgtype.Text `json:"name"`
LlmProviderID pgtype.UUID `json:"llm_provider_id"`
ClientType pgtype.Text `json:"client_type"`
Dimensions pgtype.Int4 `json:"dimensions"`
InputModalities []string `json:"input_modalities"`
Type string `json:"type"`
@@ -699,6 +658,7 @@ func (q *Queries) UpdateModelByModelID(ctx context.Context, arg UpdateModelByMod
arg.NewModelID,
arg.Name,
arg.LlmProviderID,
arg.ClientType,
arg.Dimensions,
arg.InputModalities,
arg.Type,
@@ -710,6 +670,7 @@ func (q *Queries) UpdateModelByModelID(ctx context.Context, arg UpdateModelByMod
&i.ModelID,
&i.Name,
&i.LlmProviderID,
&i.ClientType,
&i.Dimensions,
&i.InputModalities,
&i.Type,
+5 -37
View File
@@ -17,10 +17,6 @@ import (
const (
TypeText = "text"
TypeMultimodal = "multimodal"
ProviderOpenAI = "openai"
ProviderBedrock = "bedrock"
ProviderDashScope = "dashscope"
)
type Request struct {
@@ -81,16 +77,10 @@ func (r *Resolver) Embed(ctx context.Context, req Request) (Result, error) {
}
switch req.Type {
case TypeText:
if req.Provider != "" && req.Provider != ProviderOpenAI {
return Result{}, errors.New("invalid provider for text embeddings")
}
if req.Input.Text == "" {
return Result{}, errors.New("text input is required")
}
case TypeMultimodal:
if req.Provider != "" && req.Provider != ProviderBedrock && req.Provider != ProviderDashScope {
return Result{}, errors.New("invalid provider for multimodal embeddings")
}
if req.Input.Text == "" && req.Input.ImageURL == "" && req.Input.VideoURL == "" {
return Result{}, errors.New("multimodal input is required")
}
@@ -109,7 +99,9 @@ func (r *Resolver) Embed(ctx context.Context, req Request) (Result, error) {
req.Model = selected.ModelID
req.Dimensions = selected.Dimensions
req.Provider = strings.ToLower(strings.TrimSpace(provider.ClientType))
if selected.ClientType != "" {
req.Provider = string(selected.ClientType)
}
if req.Model == "" {
return Result{}, errors.New("embedding model id not configured")
}
@@ -122,11 +114,9 @@ func (r *Resolver) Embed(ctx context.Context, req Request) (Result, error) {
timeout = 10 * time.Second
}
// OpenAI-compatible embeddings work for both openai-responses and openai-completions
switch req.Type {
case TypeText:
if req.Provider != ProviderOpenAI {
return Result{}, errors.New("provider not implemented")
}
embedder, err := NewOpenAIEmbedder(r.logger, provider.ApiKey, provider.BaseUrl, req.Model, req.Dimensions, timeout)
if err != nil {
return Result{}, err
@@ -143,29 +133,7 @@ func (r *Resolver) Embed(ctx context.Context, req Request) (Result, error) {
Embedding: vector,
}, nil
case TypeMultimodal:
if req.Provider == ProviderDashScope {
if strings.TrimSpace(provider.ApiKey) == "" {
return Result{}, errors.New("dashscope api key is required")
}
dashscope := NewDashScopeEmbedder(r.logger, provider.ApiKey, provider.BaseUrl, req.Model, timeout)
vector, usage, err := dashscope.Embed(ctx, req.Input.Text, req.Input.ImageURL, req.Input.VideoURL)
if err != nil {
return Result{}, err
}
return Result{
Type: req.Type,
Provider: req.Provider,
Model: req.Model,
Dimensions: req.Dimensions,
Embedding: vector,
Usage: Usage{
InputTokens: usage.InputTokens,
ImageTokens: usage.ImageTokens,
Duration: usage.Duration,
},
}, nil
}
return Result{}, errors.New("provider not implemented")
return Result{}, errors.New("multimodal embeddings not supported for current provider types")
default:
return Result{}, errors.New("invalid embeddings type")
}
+1 -1
View File
@@ -62,7 +62,7 @@ func (h *ModelsHandler) Create(c echo.Context) error {
// @Description Get a list of all configured models, optionally filtered by type or client type
// @Tags models
// @Param type query string false "Model type (chat, embedding)"
// @Param client_type query string false "Client type (openai, openai-compat, anthropic, google, azure, bedrock, mistral, xai, ollama, dashscope)"
// @Param client_type query string false "Client type (openai-responses, openai-completions, anthropic-messages, google-generative-ai)"
// @Success 200 {array} models.GetResponse
// @Failure 400 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
+4 -31
View File
@@ -58,9 +58,6 @@ func (h *ProvidersHandler) Create(c echo.Context) error {
if req.Name == "" {
return echo.NewHTTPError(http.StatusBadRequest, "name is required")
}
if req.ClientType == "" {
return echo.NewHTTPError(http.StatusBadRequest, "client_type is required")
}
if req.BaseURL == "" {
return echo.NewHTTPError(http.StatusBadRequest, "base_url is required")
}
@@ -75,27 +72,15 @@ func (h *ProvidersHandler) Create(c echo.Context) error {
// List godoc
// @Summary List all LLM providers
// @Description Get a list of all configured LLM providers, optionally filtered by client type
// @Description Get a list of all configured LLM providers
// @Tags providers
// @Accept json
// @Produce json
// @Param client_type query string false "Client type filter (openai, openai-compat, anthropic, google, azure, bedrock, mistral, xai, ollama, dashscope)"
// @Success 200 {array} providers.GetResponse
// @Failure 400 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Router /providers [get]
func (h *ProvidersHandler) List(c echo.Context) error {
clientType := c.QueryParam("client_type")
var resp []providers.GetResponse
var err error
if clientType != "" {
resp, err = h.service.ListByClientType(c.Request().Context(), providers.ClientType(clientType))
} else {
resp, err = h.service.List(c.Request().Context())
}
resp, err := h.service.List(c.Request().Context())
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
@@ -252,27 +237,15 @@ func (h *ProvidersHandler) Delete(c echo.Context) error {
// Count godoc
// @Summary Count providers
// @Description Get the total count of providers, optionally filtered by client type
// @Description Get the total count of providers
// @Tags providers
// @Accept json
// @Produce json
// @Param client_type query string false "Client type filter (openai, openai-compat, anthropic, google, azure, bedrock, mistral, xai, ollama, dashscope)"
// @Success 200 {object} providers.CountResponse
// @Failure 400 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Router /providers/count [get]
func (h *ProvidersHandler) Count(c echo.Context) error {
clientType := c.QueryParam("client_type")
var count int64
var err error
if clientType != "" {
count, err = h.service.CountByClientType(c.Request().Context(), providers.ClientType(clientType))
} else {
count, err = h.service.Count(c.Request().Context())
}
count, err := h.service.Count(c.Request().Context())
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
+17 -11
View File
@@ -49,6 +49,9 @@ func (s *Service) Create(ctx context.Context, req AddRequest) (AddResponse, erro
InputModalities: inputMod,
Type: string(model.Type),
}
if model.ClientType != "" {
params.ClientType = pgtype.Text{String: string(model.ClientType), Valid: true}
}
// Handle optional name field
if model.Name != "" {
@@ -140,7 +143,7 @@ func (s *Service) ListByClientType(ctx context.Context, clientType ClientType) (
return nil, fmt.Errorf("invalid client type: %s", clientType)
}
dbModels, err := s.queries.ListModelsByClientType(ctx, string(clientType))
dbModels, err := s.queries.ListModelsByClientType(ctx, pgtype.Text{String: string(clientType), Valid: true})
if err != nil {
return nil, fmt.Errorf("failed to list models by client type: %w", err)
}
@@ -207,6 +210,9 @@ func (s *Service) UpdateByID(ctx context.Context, id string, req UpdateRequest)
InputModalities: inputMod,
Type: string(model.Type),
}
if model.ClientType != "" {
params.ClientType = pgtype.Text{String: string(model.ClientType), Valid: true}
}
llmProviderID, err := db.ParseUUID(model.LlmProviderID)
if err != nil {
@@ -251,6 +257,9 @@ func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req Updat
InputModalities: inputMod,
Type: string(model.Type),
}
if model.ClientType != "" {
params.ClientType = pgtype.Text{String: string(model.ClientType), Valid: true}
}
llmProviderID, err := db.ParseUUID(model.LlmProviderID)
if err != nil {
@@ -333,6 +342,9 @@ func convertToGetResponse(dbModel sqlc.Model) GetResponse {
Type: ModelType(dbModel.Type),
},
}
if dbModel.ClientType.Valid {
resp.Model.ClientType = ClientType(dbModel.ClientType.String)
}
if resp.Model.Type == ModelTypeChat {
resp.Model.InputModalities = normalizeModalities(dbModel.InputModalities, []string{ModelInputText})
}
@@ -370,16 +382,10 @@ func normalizeModalities(modalities []string, fallback []string) []string {
func isValidClientType(clientType ClientType) bool {
switch clientType {
case ClientTypeOpenAI,
ClientTypeOpenAICompat,
ClientTypeAnthropic,
ClientTypeGoogle,
ClientTypeAzure,
ClientTypeBedrock,
ClientTypeMistral,
ClientTypeXAI,
ClientTypeOllama,
ClientTypeDashscope:
case ClientTypeOpenAIResponses,
ClientTypeOpenAICompletions,
ClientTypeAnthropicMessages,
ClientTypeGoogleGenerativeAI:
return true
default:
return false
+28 -10
View File
@@ -110,6 +110,7 @@ func TestModel_Validate(t *testing.T) {
ModelID: "gpt-4",
Name: "GPT-4",
LlmProviderID: "11111111-1111-1111-1111-111111111111",
ClientType: models.ClientTypeOpenAIResponses,
Type: models.ModelTypeChat,
},
wantErr: false,
@@ -120,11 +121,33 @@ func TestModel_Validate(t *testing.T) {
ModelID: "gpt-4o",
Name: "GPT-4o",
LlmProviderID: "11111111-1111-1111-1111-111111111111",
ClientType: models.ClientTypeOpenAIResponses,
InputModalities: []string{"text", "image", "audio"},
Type: models.ModelTypeChat,
},
wantErr: false,
},
{
name: "chat model missing client_type",
model: models.Model{
ModelID: "gpt-4",
Name: "GPT-4",
LlmProviderID: "11111111-1111-1111-1111-111111111111",
Type: models.ModelTypeChat,
},
wantErr: true,
},
{
name: "embedding model without client_type is valid",
model: models.Model{
ModelID: "text-embedding-3-small",
Name: "Embedding",
LlmProviderID: "11111111-1111-1111-1111-111111111111",
Type: models.ModelTypeEmbedding,
Dimensions: 1536,
},
wantErr: false,
},
{
name: "valid embedding model",
model: models.Model{
@@ -185,6 +208,7 @@ func TestModel_Validate(t *testing.T) {
model: models.Model{
ModelID: "gpt-4",
LlmProviderID: "11111111-1111-1111-1111-111111111111",
ClientType: models.ClientTypeOpenAIResponses,
Type: models.ModelTypeChat,
InputModalities: []string{"text", "smell"},
},
@@ -262,15 +286,9 @@ func TestModelTypes(t *testing.T) {
})
t.Run("ClientType constants", func(t *testing.T) {
assert.Equal(t, models.ClientType("openai"), models.ClientTypeOpenAI)
assert.Equal(t, models.ClientType("openai-compat"), models.ClientTypeOpenAICompat)
assert.Equal(t, models.ClientType("anthropic"), models.ClientTypeAnthropic)
assert.Equal(t, models.ClientType("google"), models.ClientTypeGoogle)
assert.Equal(t, models.ClientType("azure"), models.ClientTypeAzure)
assert.Equal(t, models.ClientType("bedrock"), models.ClientTypeBedrock)
assert.Equal(t, models.ClientType("mistral"), models.ClientTypeMistral)
assert.Equal(t, models.ClientType("xai"), models.ClientTypeXAI)
assert.Equal(t, models.ClientType("ollama"), models.ClientTypeOllama)
assert.Equal(t, models.ClientType("dashscope"), models.ClientTypeDashscope)
assert.Equal(t, models.ClientType("openai-responses"), models.ClientTypeOpenAIResponses)
assert.Equal(t, models.ClientType("openai-completions"), models.ClientTypeOpenAICompletions)
assert.Equal(t, models.ClientType("anthropic-messages"), models.ClientTypeAnthropicMessages)
assert.Equal(t, models.ClientType("google-generative-ai"), models.ClientTypeGoogleGenerativeAI)
})
}
+19 -17
View File
@@ -25,25 +25,20 @@ const (
type ClientType string
const (
ClientTypeOpenAI ClientType = "openai"
ClientTypeOpenAICompat ClientType = "openai-compat"
ClientTypeAnthropic ClientType = "anthropic"
ClientTypeGoogle ClientType = "google"
ClientTypeAzure ClientType = "azure"
ClientTypeBedrock ClientType = "bedrock"
ClientTypeMistral ClientType = "mistral"
ClientTypeXAI ClientType = "xai"
ClientTypeOllama ClientType = "ollama"
ClientTypeDashscope ClientType = "dashscope"
ClientTypeOpenAIResponses ClientType = "openai-responses"
ClientTypeOpenAICompletions ClientType = "openai-completions"
ClientTypeAnthropicMessages ClientType = "anthropic-messages"
ClientTypeGoogleGenerativeAI ClientType = "google-generative-ai"
)
type Model struct {
ModelID string `json:"model_id"`
Name string `json:"name"`
LlmProviderID string `json:"llm_provider_id"`
InputModalities []string `json:"input_modalities,omitempty"`
Type ModelType `json:"type"`
Dimensions int `json:"dimensions"`
ModelID string `json:"model_id"`
Name string `json:"name"`
LlmProviderID string `json:"llm_provider_id"`
ClientType ClientType `json:"client_type,omitempty"`
InputModalities []string `json:"input_modalities,omitempty"`
Type ModelType `json:"type"`
Dimensions int `json:"dimensions"`
}
// validInputModalities is the set of recognised input modality tokens.
@@ -65,10 +60,17 @@ func (m *Model) Validate() error {
if m.Type != ModelTypeChat && m.Type != ModelTypeEmbedding {
return errors.New("invalid model type")
}
if m.Type == ModelTypeChat {
if m.ClientType == "" {
return errors.New("client_type is required for chat models")
}
if !isValidClientType(m.ClientType) {
return fmt.Errorf("invalid client_type: %s", m.ClientType)
}
}
if m.Type == ModelTypeEmbedding && m.Dimensions <= 0 {
return errors.New("dimensions must be greater than 0")
}
// Input modalities only apply to chat models.
if m.Type == ModelTypeChat {
for _, mod := range m.InputModalities {
if _, ok := validInputModalities[mod]; !ok {
+16 -75
View File
@@ -27,11 +27,6 @@ func NewService(log *slog.Logger, queries *sqlc.Queries) *Service {
// Create creates a new LLM provider
func (s *Service) Create(ctx context.Context, req CreateRequest) (GetResponse, error) {
// Validate client type
if !isValidClientType(req.ClientType) {
return GetResponse{}, fmt.Errorf("invalid client_type: %s", req.ClientType)
}
// Marshal metadata
metadataJSON, err := json.Marshal(req.Metadata)
if err != nil {
@@ -40,11 +35,10 @@ func (s *Service) Create(ctx context.Context, req CreateRequest) (GetResponse, e
// Create provider
provider, err := s.queries.CreateLlmProvider(ctx, sqlc.CreateLlmProviderParams{
Name: req.Name,
ClientType: string(req.ClientType),
BaseUrl: req.BaseURL,
ApiKey: req.APIKey,
Metadata: metadataJSON,
Name: req.Name,
BaseUrl: req.BaseURL,
ApiKey: req.APIKey,
Metadata: metadataJSON,
})
if err != nil {
return GetResponse{}, fmt.Errorf("create provider: %w", err)
@@ -92,24 +86,6 @@ func (s *Service) List(ctx context.Context) ([]GetResponse, error) {
return results, nil
}
// ListByClientType retrieves providers by client type
func (s *Service) ListByClientType(ctx context.Context, clientType ClientType) ([]GetResponse, error) {
if !isValidClientType(clientType) {
return nil, fmt.Errorf("invalid client_type: %s", clientType)
}
providers, err := s.queries.ListLlmProvidersByClientType(ctx, string(clientType))
if err != nil {
return nil, fmt.Errorf("list providers by client type: %w", err)
}
results := make([]GetResponse, 0, len(providers))
for _, p := range providers {
results = append(results, s.toGetResponse(p))
}
return results, nil
}
// Update updates an existing provider
func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (GetResponse, error) {
providerID, err := db.ParseUUID(id)
@@ -129,14 +105,6 @@ func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Get
name = *req.Name
}
clientType := existing.ClientType
if req.ClientType != nil {
if !isValidClientType(*req.ClientType) {
return GetResponse{}, fmt.Errorf("invalid client_type: %s", *req.ClientType)
}
clientType = string(*req.ClientType)
}
baseURL := existing.BaseUrl
if req.BaseURL != nil {
baseURL = *req.BaseURL
@@ -155,12 +123,11 @@ func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Get
// Update provider
updated, err := s.queries.UpdateLlmProvider(ctx, sqlc.UpdateLlmProviderParams{
ID: providerID,
Name: name,
ClientType: clientType,
BaseUrl: baseURL,
ApiKey: apiKey,
Metadata: metadata,
ID: providerID,
Name: name,
BaseUrl: baseURL,
ApiKey: apiKey,
Metadata: metadata,
})
if err != nil {
return GetResponse{}, fmt.Errorf("update provider: %w", err)
@@ -191,19 +158,6 @@ func (s *Service) Count(ctx context.Context) (int64, error) {
return count, nil
}
// CountByClientType returns the count of providers by client type
func (s *Service) CountByClientType(ctx context.Context, clientType ClientType) (int64, error) {
if !isValidClientType(clientType) {
return 0, fmt.Errorf("invalid client_type: %s", clientType)
}
count, err := s.queries.CountLlmProvidersByClientType(ctx, string(clientType))
if err != nil {
return 0, fmt.Errorf("count providers by client type: %w", err)
}
return count, nil
}
// toGetResponse converts a database provider to a response
func (s *Service) toGetResponse(provider sqlc.LlmProvider) GetResponse {
var metadata map[string]any
@@ -217,26 +171,13 @@ func (s *Service) toGetResponse(provider sqlc.LlmProvider) GetResponse {
maskedAPIKey := maskAPIKey(provider.ApiKey)
return GetResponse{
ID: provider.ID.String(),
Name: provider.Name,
ClientType: provider.ClientType,
BaseURL: provider.BaseUrl,
APIKey: maskedAPIKey,
Metadata: metadata,
CreatedAt: provider.CreatedAt.Time,
UpdatedAt: provider.UpdatedAt.Time,
}
}
// isValidClientType checks if a client type is valid
func isValidClientType(clientType ClientType) bool {
switch clientType {
case ClientTypeOpenAI, ClientTypeOpenAICompat, ClientTypeAnthropic, ClientTypeGoogle,
ClientTypeAzure, ClientTypeBedrock, ClientTypeMistral, ClientTypeXAI,
ClientTypeOllama, ClientTypeDashscope:
return true
default:
return false
ID: provider.ID.String(),
Name: provider.Name,
BaseURL: provider.BaseUrl,
APIKey: maskedAPIKey,
Metadata: metadata,
CreatedAt: provider.CreatedAt.Time,
UpdatedAt: provider.UpdatedAt.Time,
}
}
+18 -38
View File
@@ -2,50 +2,31 @@ package providers
import "time"
// ClientType represents the type of LLM provider client
type ClientType string
const (
ClientTypeOpenAI ClientType = "openai"
ClientTypeOpenAICompat ClientType = "openai-compat"
ClientTypeAnthropic ClientType = "anthropic"
ClientTypeGoogle ClientType = "google"
ClientTypeAzure ClientType = "azure"
ClientTypeBedrock ClientType = "bedrock"
ClientTypeMistral ClientType = "mistral"
ClientTypeXAI ClientType = "xai"
ClientTypeOllama ClientType = "ollama"
ClientTypeDashscope ClientType = "dashscope"
)
// CreateRequest represents a request to create a new LLM provider
type CreateRequest struct {
Name string `json:"name" validate:"required"`
ClientType ClientType `json:"client_type" validate:"required"`
BaseURL string `json:"base_url" validate:"required,url"`
APIKey string `json:"api_key"`
Metadata map[string]any `json:"metadata,omitempty"`
Name string `json:"name" validate:"required"`
BaseURL string `json:"base_url" validate:"required,url"`
APIKey string `json:"api_key"`
Metadata map[string]any `json:"metadata,omitempty"`
}
// UpdateRequest represents a request to update an existing LLM provider
type UpdateRequest struct {
Name *string `json:"name,omitempty"`
ClientType *ClientType `json:"client_type,omitempty"`
BaseURL *string `json:"base_url,omitempty"`
APIKey *string `json:"api_key,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
Name *string `json:"name,omitempty"`
BaseURL *string `json:"base_url,omitempty"`
APIKey *string `json:"api_key,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
}
// GetResponse represents the response for getting a provider
type GetResponse struct {
ID string `json:"id"`
Name string `json:"name"`
ClientType string `json:"client_type"`
BaseURL string `json:"base_url"`
APIKey string `json:"api_key,omitempty"` // masked in response
Metadata map[string]any `json:"metadata,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
ID string `json:"id"`
Name string `json:"name"`
BaseURL string `json:"base_url"`
APIKey string `json:"api_key,omitempty"` // masked in response
Metadata map[string]any `json:"metadata,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// ListResponse represents the response for listing providers
@@ -61,10 +42,9 @@ type CountResponse struct {
// TestRequest represents a request to test provider connection
type TestRequest struct {
ClientType ClientType `json:"client_type" validate:"required"`
BaseURL string `json:"base_url" validate:"required,url"`
APIKey string `json:"api_key"`
Model string `json:"model"` // optional test model
BaseURL string `json:"base_url" validate:"required,url"`
APIKey string `json:"api_key"`
Model string `json:"model"` // optional test model
}
// TestResponse represents the result of testing a provider