mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
refactor: move client_type key from provider to model
This commit is contained in:
@@ -213,10 +213,8 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r
|
||||
if err != nil {
|
||||
return resolvedContext{}, err
|
||||
}
|
||||
clientType, err := normalizeClientType(provider.ClientType)
|
||||
if err != nil {
|
||||
return resolvedContext{}, err
|
||||
}
|
||||
clientType := string(chatModel.ClientType)
|
||||
|
||||
maxCtx := coalescePositiveInt(req.MaxContextLoadTime, botSettings.MaxContextLoadTime, defaultMaxContextMinutes)
|
||||
maxTokens := botSettings.MaxContextTokens
|
||||
|
||||
@@ -306,7 +304,7 @@ func (r *Resolver) Chat(ctx context.Context, req conversation.ChatRequest) (conv
|
||||
Messages: resp.Messages,
|
||||
Skills: resp.Skills,
|
||||
Model: rc.model.ModelID,
|
||||
Provider: rc.provider.ClientType,
|
||||
Provider: string(rc.model.ClientType),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1235,8 +1233,7 @@ func (r *Resolver) loadBotSettings(ctx context.Context, botID string) (settings.
|
||||
func normalizeClientType(clientType string) (string, error) {
|
||||
ct := strings.ToLower(strings.TrimSpace(clientType))
|
||||
switch ct {
|
||||
case "openai", "openai-compat", "anthropic", "google",
|
||||
"azure", "bedrock", "mistral", "xai", "ollama", "dashscope":
|
||||
case "openai-responses", "openai-completions", "anthropic-messages", "google-generative-ai":
|
||||
return ct, nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported agent gateway client type: %s", clientType)
|
||||
|
||||
@@ -166,14 +166,13 @@ type LifecycleEvent struct {
|
||||
}
|
||||
|
||||
type LlmProvider struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
Name string `json:"name"`
|
||||
ClientType string `json:"client_type"`
|
||||
BaseUrl string `json:"base_url"`
|
||||
ApiKey string `json:"api_key"`
|
||||
Metadata []byte `json:"metadata"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
Name string `json:"name"`
|
||||
BaseUrl string `json:"base_url"`
|
||||
ApiKey string `json:"api_key"`
|
||||
Metadata []byte `json:"metadata"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
|
||||
}
|
||||
|
||||
type McpConnection struct {
|
||||
@@ -209,6 +208,7 @@ type Model struct {
|
||||
ModelID string `json:"model_id"`
|
||||
Name pgtype.Text `json:"name"`
|
||||
LlmProviderID pgtype.UUID `json:"llm_provider_id"`
|
||||
ClientType pgtype.Text `json:"client_type"`
|
||||
Dimensions pgtype.Int4 `json:"dimensions"`
|
||||
InputModalities []string `json:"input_modalities"`
|
||||
Type string `json:"type"`
|
||||
|
||||
+62
-101
@@ -22,17 +22,6 @@ func (q *Queries) CountLlmProviders(ctx context.Context) (int64, error) {
|
||||
return count, err
|
||||
}
|
||||
|
||||
const countLlmProvidersByClientType = `-- name: CountLlmProvidersByClientType :one
|
||||
SELECT COUNT(*) FROM llm_providers WHERE client_type = $1
|
||||
`
|
||||
|
||||
func (q *Queries) CountLlmProvidersByClientType(ctx context.Context, clientType string) (int64, error) {
|
||||
row := q.db.QueryRow(ctx, countLlmProvidersByClientType, clientType)
|
||||
var count int64
|
||||
err := row.Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
const countModels = `-- name: CountModels :one
|
||||
SELECT COUNT(*) FROM models
|
||||
`
|
||||
@@ -56,29 +45,26 @@ func (q *Queries) CountModelsByType(ctx context.Context, type_ string) (int64, e
|
||||
}
|
||||
|
||||
const createLlmProvider = `-- name: CreateLlmProvider :one
|
||||
INSERT INTO llm_providers (name, client_type, base_url, api_key, metadata)
|
||||
INSERT INTO llm_providers (name, base_url, api_key, metadata)
|
||||
VALUES (
|
||||
$1,
|
||||
$2,
|
||||
$3,
|
||||
$4,
|
||||
$5
|
||||
$4
|
||||
)
|
||||
RETURNING id, name, client_type, base_url, api_key, metadata, created_at, updated_at
|
||||
RETURNING id, name, base_url, api_key, metadata, created_at, updated_at
|
||||
`
|
||||
|
||||
type CreateLlmProviderParams struct {
|
||||
Name string `json:"name"`
|
||||
ClientType string `json:"client_type"`
|
||||
BaseUrl string `json:"base_url"`
|
||||
ApiKey string `json:"api_key"`
|
||||
Metadata []byte `json:"metadata"`
|
||||
Name string `json:"name"`
|
||||
BaseUrl string `json:"base_url"`
|
||||
ApiKey string `json:"api_key"`
|
||||
Metadata []byte `json:"metadata"`
|
||||
}
|
||||
|
||||
func (q *Queries) CreateLlmProvider(ctx context.Context, arg CreateLlmProviderParams) (LlmProvider, error) {
|
||||
row := q.db.QueryRow(ctx, createLlmProvider,
|
||||
arg.Name,
|
||||
arg.ClientType,
|
||||
arg.BaseUrl,
|
||||
arg.ApiKey,
|
||||
arg.Metadata,
|
||||
@@ -87,7 +73,6 @@ func (q *Queries) CreateLlmProvider(ctx context.Context, arg CreateLlmProviderPa
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Name,
|
||||
&i.ClientType,
|
||||
&i.BaseUrl,
|
||||
&i.ApiKey,
|
||||
&i.Metadata,
|
||||
@@ -98,22 +83,24 @@ func (q *Queries) CreateLlmProvider(ctx context.Context, arg CreateLlmProviderPa
|
||||
}
|
||||
|
||||
const createModel = `-- name: CreateModel :one
|
||||
INSERT INTO models (model_id, name, llm_provider_id, dimensions, input_modalities, type)
|
||||
INSERT INTO models (model_id, name, llm_provider_id, client_type, dimensions, input_modalities, type)
|
||||
VALUES (
|
||||
$1,
|
||||
$2,
|
||||
$3,
|
||||
$4,
|
||||
$5,
|
||||
$6
|
||||
$6,
|
||||
$7
|
||||
)
|
||||
RETURNING id, model_id, name, llm_provider_id, dimensions, input_modalities, type, created_at, updated_at
|
||||
RETURNING id, model_id, name, llm_provider_id, client_type, dimensions, input_modalities, type, created_at, updated_at
|
||||
`
|
||||
|
||||
type CreateModelParams struct {
|
||||
ModelID string `json:"model_id"`
|
||||
Name pgtype.Text `json:"name"`
|
||||
LlmProviderID pgtype.UUID `json:"llm_provider_id"`
|
||||
ClientType pgtype.Text `json:"client_type"`
|
||||
Dimensions pgtype.Int4 `json:"dimensions"`
|
||||
InputModalities []string `json:"input_modalities"`
|
||||
Type string `json:"type"`
|
||||
@@ -124,6 +111,7 @@ func (q *Queries) CreateModel(ctx context.Context, arg CreateModelParams) (Model
|
||||
arg.ModelID,
|
||||
arg.Name,
|
||||
arg.LlmProviderID,
|
||||
arg.ClientType,
|
||||
arg.Dimensions,
|
||||
arg.InputModalities,
|
||||
arg.Type,
|
||||
@@ -134,6 +122,7 @@ func (q *Queries) CreateModel(ctx context.Context, arg CreateModelParams) (Model
|
||||
&i.ModelID,
|
||||
&i.Name,
|
||||
&i.LlmProviderID,
|
||||
&i.ClientType,
|
||||
&i.Dimensions,
|
||||
&i.InputModalities,
|
||||
&i.Type,
|
||||
@@ -209,7 +198,7 @@ func (q *Queries) DeleteModelByModelID(ctx context.Context, modelID string) erro
|
||||
}
|
||||
|
||||
const getLlmProviderByID = `-- name: GetLlmProviderByID :one
|
||||
SELECT id, name, client_type, base_url, api_key, metadata, created_at, updated_at FROM llm_providers WHERE id = $1
|
||||
SELECT id, name, base_url, api_key, metadata, created_at, updated_at FROM llm_providers WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetLlmProviderByID(ctx context.Context, id pgtype.UUID) (LlmProvider, error) {
|
||||
@@ -218,7 +207,6 @@ func (q *Queries) GetLlmProviderByID(ctx context.Context, id pgtype.UUID) (LlmPr
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Name,
|
||||
&i.ClientType,
|
||||
&i.BaseUrl,
|
||||
&i.ApiKey,
|
||||
&i.Metadata,
|
||||
@@ -229,7 +217,7 @@ func (q *Queries) GetLlmProviderByID(ctx context.Context, id pgtype.UUID) (LlmPr
|
||||
}
|
||||
|
||||
const getLlmProviderByName = `-- name: GetLlmProviderByName :one
|
||||
SELECT id, name, client_type, base_url, api_key, metadata, created_at, updated_at FROM llm_providers WHERE name = $1
|
||||
SELECT id, name, base_url, api_key, metadata, created_at, updated_at FROM llm_providers WHERE name = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetLlmProviderByName(ctx context.Context, name string) (LlmProvider, error) {
|
||||
@@ -238,7 +226,6 @@ func (q *Queries) GetLlmProviderByName(ctx context.Context, name string) (LlmPro
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Name,
|
||||
&i.ClientType,
|
||||
&i.BaseUrl,
|
||||
&i.ApiKey,
|
||||
&i.Metadata,
|
||||
@@ -249,7 +236,7 @@ func (q *Queries) GetLlmProviderByName(ctx context.Context, name string) (LlmPro
|
||||
}
|
||||
|
||||
const getModelByID = `-- name: GetModelByID :one
|
||||
SELECT id, model_id, name, llm_provider_id, dimensions, input_modalities, type, created_at, updated_at FROM models WHERE id = $1
|
||||
SELECT id, model_id, name, llm_provider_id, client_type, dimensions, input_modalities, type, created_at, updated_at FROM models WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetModelByID(ctx context.Context, id pgtype.UUID) (Model, error) {
|
||||
@@ -260,6 +247,7 @@ func (q *Queries) GetModelByID(ctx context.Context, id pgtype.UUID) (Model, erro
|
||||
&i.ModelID,
|
||||
&i.Name,
|
||||
&i.LlmProviderID,
|
||||
&i.ClientType,
|
||||
&i.Dimensions,
|
||||
&i.InputModalities,
|
||||
&i.Type,
|
||||
@@ -270,7 +258,7 @@ func (q *Queries) GetModelByID(ctx context.Context, id pgtype.UUID) (Model, erro
|
||||
}
|
||||
|
||||
const getModelByModelID = `-- name: GetModelByModelID :one
|
||||
SELECT id, model_id, name, llm_provider_id, dimensions, input_modalities, type, created_at, updated_at FROM models WHERE model_id = $1
|
||||
SELECT id, model_id, name, llm_provider_id, client_type, dimensions, input_modalities, type, created_at, updated_at FROM models WHERE model_id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetModelByModelID(ctx context.Context, modelID string) (Model, error) {
|
||||
@@ -281,6 +269,7 @@ func (q *Queries) GetModelByModelID(ctx context.Context, modelID string) (Model,
|
||||
&i.ModelID,
|
||||
&i.Name,
|
||||
&i.LlmProviderID,
|
||||
&i.ClientType,
|
||||
&i.Dimensions,
|
||||
&i.InputModalities,
|
||||
&i.Type,
|
||||
@@ -291,7 +280,7 @@ func (q *Queries) GetModelByModelID(ctx context.Context, modelID string) (Model,
|
||||
}
|
||||
|
||||
const listLlmProviders = `-- name: ListLlmProviders :many
|
||||
SELECT id, name, client_type, base_url, api_key, metadata, created_at, updated_at FROM llm_providers
|
||||
SELECT id, name, base_url, api_key, metadata, created_at, updated_at FROM llm_providers
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
@@ -307,42 +296,6 @@ func (q *Queries) ListLlmProviders(ctx context.Context) ([]LlmProvider, error) {
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.Name,
|
||||
&i.ClientType,
|
||||
&i.BaseUrl,
|
||||
&i.ApiKey,
|
||||
&i.Metadata,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listLlmProvidersByClientType = `-- name: ListLlmProvidersByClientType :many
|
||||
SELECT id, name, client_type, base_url, api_key, metadata, created_at, updated_at FROM llm_providers
|
||||
WHERE client_type = $1
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListLlmProvidersByClientType(ctx context.Context, clientType string) ([]LlmProvider, error) {
|
||||
rows, err := q.db.Query(ctx, listLlmProvidersByClientType, clientType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []LlmProvider
|
||||
for rows.Next() {
|
||||
var i LlmProvider
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.Name,
|
||||
&i.ClientType,
|
||||
&i.BaseUrl,
|
||||
&i.ApiKey,
|
||||
&i.Metadata,
|
||||
@@ -394,7 +347,7 @@ func (q *Queries) ListModelVariantsByModelUUID(ctx context.Context, modelUuid pg
|
||||
}
|
||||
|
||||
const listModels = `-- name: ListModels :many
|
||||
SELECT id, model_id, name, llm_provider_id, dimensions, input_modalities, type, created_at, updated_at FROM models
|
||||
SELECT id, model_id, name, llm_provider_id, client_type, dimensions, input_modalities, type, created_at, updated_at FROM models
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
@@ -412,6 +365,7 @@ func (q *Queries) ListModels(ctx context.Context) ([]Model, error) {
|
||||
&i.ModelID,
|
||||
&i.Name,
|
||||
&i.LlmProviderID,
|
||||
&i.ClientType,
|
||||
&i.Dimensions,
|
||||
&i.InputModalities,
|
||||
&i.Type,
|
||||
@@ -429,13 +383,12 @@ func (q *Queries) ListModels(ctx context.Context) ([]Model, error) {
|
||||
}
|
||||
|
||||
const listModelsByClientType = `-- name: ListModelsByClientType :many
|
||||
SELECT m.id, m.model_id, m.name, m.llm_provider_id, m.dimensions, m.input_modalities, m.type, m.created_at, m.updated_at FROM models AS m
|
||||
JOIN llm_providers AS p ON p.id = m.llm_provider_id
|
||||
WHERE p.client_type = $1
|
||||
ORDER BY m.created_at DESC
|
||||
SELECT id, model_id, name, llm_provider_id, client_type, dimensions, input_modalities, type, created_at, updated_at FROM models
|
||||
WHERE client_type = $1
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListModelsByClientType(ctx context.Context, clientType string) ([]Model, error) {
|
||||
func (q *Queries) ListModelsByClientType(ctx context.Context, clientType pgtype.Text) ([]Model, error) {
|
||||
rows, err := q.db.Query(ctx, listModelsByClientType, clientType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -449,6 +402,7 @@ func (q *Queries) ListModelsByClientType(ctx context.Context, clientType string)
|
||||
&i.ModelID,
|
||||
&i.Name,
|
||||
&i.LlmProviderID,
|
||||
&i.ClientType,
|
||||
&i.Dimensions,
|
||||
&i.InputModalities,
|
||||
&i.Type,
|
||||
@@ -466,7 +420,7 @@ func (q *Queries) ListModelsByClientType(ctx context.Context, clientType string)
|
||||
}
|
||||
|
||||
const listModelsByProviderID = `-- name: ListModelsByProviderID :many
|
||||
SELECT id, model_id, name, llm_provider_id, dimensions, input_modalities, type, created_at, updated_at FROM models
|
||||
SELECT id, model_id, name, llm_provider_id, client_type, dimensions, input_modalities, type, created_at, updated_at FROM models
|
||||
WHERE llm_provider_id = $1
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
@@ -485,6 +439,7 @@ func (q *Queries) ListModelsByProviderID(ctx context.Context, llmProviderID pgty
|
||||
&i.ModelID,
|
||||
&i.Name,
|
||||
&i.LlmProviderID,
|
||||
&i.ClientType,
|
||||
&i.Dimensions,
|
||||
&i.InputModalities,
|
||||
&i.Type,
|
||||
@@ -502,7 +457,7 @@ func (q *Queries) ListModelsByProviderID(ctx context.Context, llmProviderID pgty
|
||||
}
|
||||
|
||||
const listModelsByProviderIDAndType = `-- name: ListModelsByProviderIDAndType :many
|
||||
SELECT id, model_id, name, llm_provider_id, dimensions, input_modalities, type, created_at, updated_at FROM models
|
||||
SELECT id, model_id, name, llm_provider_id, client_type, dimensions, input_modalities, type, created_at, updated_at FROM models
|
||||
WHERE llm_provider_id = $1
|
||||
AND type = $2
|
||||
ORDER BY created_at DESC
|
||||
@@ -527,6 +482,7 @@ func (q *Queries) ListModelsByProviderIDAndType(ctx context.Context, arg ListMod
|
||||
&i.ModelID,
|
||||
&i.Name,
|
||||
&i.LlmProviderID,
|
||||
&i.ClientType,
|
||||
&i.Dimensions,
|
||||
&i.InputModalities,
|
||||
&i.Type,
|
||||
@@ -544,7 +500,7 @@ func (q *Queries) ListModelsByProviderIDAndType(ctx context.Context, arg ListMod
|
||||
}
|
||||
|
||||
const listModelsByType = `-- name: ListModelsByType :many
|
||||
SELECT id, model_id, name, llm_provider_id, dimensions, input_modalities, type, created_at, updated_at FROM models
|
||||
SELECT id, model_id, name, llm_provider_id, client_type, dimensions, input_modalities, type, created_at, updated_at FROM models
|
||||
WHERE type = $1
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
@@ -563,6 +519,7 @@ func (q *Queries) ListModelsByType(ctx context.Context, type_ string) ([]Model,
|
||||
&i.ModelID,
|
||||
&i.Name,
|
||||
&i.LlmProviderID,
|
||||
&i.ClientType,
|
||||
&i.Dimensions,
|
||||
&i.InputModalities,
|
||||
&i.Type,
|
||||
@@ -583,28 +540,25 @@ const updateLlmProvider = `-- name: UpdateLlmProvider :one
|
||||
UPDATE llm_providers
|
||||
SET
|
||||
name = $1,
|
||||
client_type = $2,
|
||||
base_url = $3,
|
||||
api_key = $4,
|
||||
metadata = $5,
|
||||
base_url = $2,
|
||||
api_key = $3,
|
||||
metadata = $4,
|
||||
updated_at = now()
|
||||
WHERE id = $6
|
||||
RETURNING id, name, client_type, base_url, api_key, metadata, created_at, updated_at
|
||||
WHERE id = $5
|
||||
RETURNING id, name, base_url, api_key, metadata, created_at, updated_at
|
||||
`
|
||||
|
||||
type UpdateLlmProviderParams struct {
|
||||
Name string `json:"name"`
|
||||
ClientType string `json:"client_type"`
|
||||
BaseUrl string `json:"base_url"`
|
||||
ApiKey string `json:"api_key"`
|
||||
Metadata []byte `json:"metadata"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
Name string `json:"name"`
|
||||
BaseUrl string `json:"base_url"`
|
||||
ApiKey string `json:"api_key"`
|
||||
Metadata []byte `json:"metadata"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateLlmProvider(ctx context.Context, arg UpdateLlmProviderParams) (LlmProvider, error) {
|
||||
row := q.db.QueryRow(ctx, updateLlmProvider,
|
||||
arg.Name,
|
||||
arg.ClientType,
|
||||
arg.BaseUrl,
|
||||
arg.ApiKey,
|
||||
arg.Metadata,
|
||||
@@ -614,7 +568,6 @@ func (q *Queries) UpdateLlmProvider(ctx context.Context, arg UpdateLlmProviderPa
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Name,
|
||||
&i.ClientType,
|
||||
&i.BaseUrl,
|
||||
&i.ApiKey,
|
||||
&i.Metadata,
|
||||
@@ -629,17 +582,19 @@ UPDATE models
|
||||
SET
|
||||
name = $1,
|
||||
llm_provider_id = $2,
|
||||
dimensions = $3,
|
||||
input_modalities = $4,
|
||||
type = $5,
|
||||
client_type = $3,
|
||||
dimensions = $4,
|
||||
input_modalities = $5,
|
||||
type = $6,
|
||||
updated_at = now()
|
||||
WHERE id = $6
|
||||
RETURNING id, model_id, name, llm_provider_id, dimensions, input_modalities, type, created_at, updated_at
|
||||
WHERE id = $7
|
||||
RETURNING id, model_id, name, llm_provider_id, client_type, dimensions, input_modalities, type, created_at, updated_at
|
||||
`
|
||||
|
||||
type UpdateModelParams struct {
|
||||
Name pgtype.Text `json:"name"`
|
||||
LlmProviderID pgtype.UUID `json:"llm_provider_id"`
|
||||
ClientType pgtype.Text `json:"client_type"`
|
||||
Dimensions pgtype.Int4 `json:"dimensions"`
|
||||
InputModalities []string `json:"input_modalities"`
|
||||
Type string `json:"type"`
|
||||
@@ -650,6 +605,7 @@ func (q *Queries) UpdateModel(ctx context.Context, arg UpdateModelParams) (Model
|
||||
row := q.db.QueryRow(ctx, updateModel,
|
||||
arg.Name,
|
||||
arg.LlmProviderID,
|
||||
arg.ClientType,
|
||||
arg.Dimensions,
|
||||
arg.InputModalities,
|
||||
arg.Type,
|
||||
@@ -661,6 +617,7 @@ func (q *Queries) UpdateModel(ctx context.Context, arg UpdateModelParams) (Model
|
||||
&i.ModelID,
|
||||
&i.Name,
|
||||
&i.LlmProviderID,
|
||||
&i.ClientType,
|
||||
&i.Dimensions,
|
||||
&i.InputModalities,
|
||||
&i.Type,
|
||||
@@ -676,18 +633,20 @@ SET
|
||||
model_id = $1,
|
||||
name = $2,
|
||||
llm_provider_id = $3,
|
||||
dimensions = $4,
|
||||
input_modalities = $5,
|
||||
type = $6,
|
||||
client_type = $4,
|
||||
dimensions = $5,
|
||||
input_modalities = $6,
|
||||
type = $7,
|
||||
updated_at = now()
|
||||
WHERE model_id = $7
|
||||
RETURNING id, model_id, name, llm_provider_id, dimensions, input_modalities, type, created_at, updated_at
|
||||
WHERE model_id = $8
|
||||
RETURNING id, model_id, name, llm_provider_id, client_type, dimensions, input_modalities, type, created_at, updated_at
|
||||
`
|
||||
|
||||
type UpdateModelByModelIDParams struct {
|
||||
NewModelID string `json:"new_model_id"`
|
||||
Name pgtype.Text `json:"name"`
|
||||
LlmProviderID pgtype.UUID `json:"llm_provider_id"`
|
||||
ClientType pgtype.Text `json:"client_type"`
|
||||
Dimensions pgtype.Int4 `json:"dimensions"`
|
||||
InputModalities []string `json:"input_modalities"`
|
||||
Type string `json:"type"`
|
||||
@@ -699,6 +658,7 @@ func (q *Queries) UpdateModelByModelID(ctx context.Context, arg UpdateModelByMod
|
||||
arg.NewModelID,
|
||||
arg.Name,
|
||||
arg.LlmProviderID,
|
||||
arg.ClientType,
|
||||
arg.Dimensions,
|
||||
arg.InputModalities,
|
||||
arg.Type,
|
||||
@@ -710,6 +670,7 @@ func (q *Queries) UpdateModelByModelID(ctx context.Context, arg UpdateModelByMod
|
||||
&i.ModelID,
|
||||
&i.Name,
|
||||
&i.LlmProviderID,
|
||||
&i.ClientType,
|
||||
&i.Dimensions,
|
||||
&i.InputModalities,
|
||||
&i.Type,
|
||||
|
||||
@@ -17,10 +17,6 @@ import (
|
||||
const (
|
||||
TypeText = "text"
|
||||
TypeMultimodal = "multimodal"
|
||||
|
||||
ProviderOpenAI = "openai"
|
||||
ProviderBedrock = "bedrock"
|
||||
ProviderDashScope = "dashscope"
|
||||
)
|
||||
|
||||
type Request struct {
|
||||
@@ -81,16 +77,10 @@ func (r *Resolver) Embed(ctx context.Context, req Request) (Result, error) {
|
||||
}
|
||||
switch req.Type {
|
||||
case TypeText:
|
||||
if req.Provider != "" && req.Provider != ProviderOpenAI {
|
||||
return Result{}, errors.New("invalid provider for text embeddings")
|
||||
}
|
||||
if req.Input.Text == "" {
|
||||
return Result{}, errors.New("text input is required")
|
||||
}
|
||||
case TypeMultimodal:
|
||||
if req.Provider != "" && req.Provider != ProviderBedrock && req.Provider != ProviderDashScope {
|
||||
return Result{}, errors.New("invalid provider for multimodal embeddings")
|
||||
}
|
||||
if req.Input.Text == "" && req.Input.ImageURL == "" && req.Input.VideoURL == "" {
|
||||
return Result{}, errors.New("multimodal input is required")
|
||||
}
|
||||
@@ -109,7 +99,9 @@ func (r *Resolver) Embed(ctx context.Context, req Request) (Result, error) {
|
||||
|
||||
req.Model = selected.ModelID
|
||||
req.Dimensions = selected.Dimensions
|
||||
req.Provider = strings.ToLower(strings.TrimSpace(provider.ClientType))
|
||||
if selected.ClientType != "" {
|
||||
req.Provider = string(selected.ClientType)
|
||||
}
|
||||
if req.Model == "" {
|
||||
return Result{}, errors.New("embedding model id not configured")
|
||||
}
|
||||
@@ -122,11 +114,9 @@ func (r *Resolver) Embed(ctx context.Context, req Request) (Result, error) {
|
||||
timeout = 10 * time.Second
|
||||
}
|
||||
|
||||
// OpenAI-compatible embeddings work for both openai-responses and openai-completions
|
||||
switch req.Type {
|
||||
case TypeText:
|
||||
if req.Provider != ProviderOpenAI {
|
||||
return Result{}, errors.New("provider not implemented")
|
||||
}
|
||||
embedder, err := NewOpenAIEmbedder(r.logger, provider.ApiKey, provider.BaseUrl, req.Model, req.Dimensions, timeout)
|
||||
if err != nil {
|
||||
return Result{}, err
|
||||
@@ -143,29 +133,7 @@ func (r *Resolver) Embed(ctx context.Context, req Request) (Result, error) {
|
||||
Embedding: vector,
|
||||
}, nil
|
||||
case TypeMultimodal:
|
||||
if req.Provider == ProviderDashScope {
|
||||
if strings.TrimSpace(provider.ApiKey) == "" {
|
||||
return Result{}, errors.New("dashscope api key is required")
|
||||
}
|
||||
dashscope := NewDashScopeEmbedder(r.logger, provider.ApiKey, provider.BaseUrl, req.Model, timeout)
|
||||
vector, usage, err := dashscope.Embed(ctx, req.Input.Text, req.Input.ImageURL, req.Input.VideoURL)
|
||||
if err != nil {
|
||||
return Result{}, err
|
||||
}
|
||||
return Result{
|
||||
Type: req.Type,
|
||||
Provider: req.Provider,
|
||||
Model: req.Model,
|
||||
Dimensions: req.Dimensions,
|
||||
Embedding: vector,
|
||||
Usage: Usage{
|
||||
InputTokens: usage.InputTokens,
|
||||
ImageTokens: usage.ImageTokens,
|
||||
Duration: usage.Duration,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
return Result{}, errors.New("provider not implemented")
|
||||
return Result{}, errors.New("multimodal embeddings not supported for current provider types")
|
||||
default:
|
||||
return Result{}, errors.New("invalid embeddings type")
|
||||
}
|
||||
|
||||
@@ -62,7 +62,7 @@ func (h *ModelsHandler) Create(c echo.Context) error {
|
||||
// @Description Get a list of all configured models, optionally filtered by type or client type
|
||||
// @Tags models
|
||||
// @Param type query string false "Model type (chat, embedding)"
|
||||
// @Param client_type query string false "Client type (openai, openai-compat, anthropic, google, azure, bedrock, mistral, xai, ollama, dashscope)"
|
||||
// @Param client_type query string false "Client type (openai-responses, openai-completions, anthropic-messages, google-generative-ai)"
|
||||
// @Success 200 {array} models.GetResponse
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
|
||||
@@ -58,9 +58,6 @@ func (h *ProvidersHandler) Create(c echo.Context) error {
|
||||
if req.Name == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "name is required")
|
||||
}
|
||||
if req.ClientType == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "client_type is required")
|
||||
}
|
||||
if req.BaseURL == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "base_url is required")
|
||||
}
|
||||
@@ -75,27 +72,15 @@ func (h *ProvidersHandler) Create(c echo.Context) error {
|
||||
|
||||
// List godoc
|
||||
// @Summary List all LLM providers
|
||||
// @Description Get a list of all configured LLM providers, optionally filtered by client type
|
||||
// @Description Get a list of all configured LLM providers
|
||||
// @Tags providers
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param client_type query string false "Client type filter (openai, openai-compat, anthropic, google, azure, bedrock, mistral, xai, ollama, dashscope)"
|
||||
// @Success 200 {array} providers.GetResponse
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /providers [get]
|
||||
func (h *ProvidersHandler) List(c echo.Context) error {
|
||||
clientType := c.QueryParam("client_type")
|
||||
|
||||
var resp []providers.GetResponse
|
||||
var err error
|
||||
|
||||
if clientType != "" {
|
||||
resp, err = h.service.ListByClientType(c.Request().Context(), providers.ClientType(clientType))
|
||||
} else {
|
||||
resp, err = h.service.List(c.Request().Context())
|
||||
}
|
||||
|
||||
resp, err := h.service.List(c.Request().Context())
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
@@ -252,27 +237,15 @@ func (h *ProvidersHandler) Delete(c echo.Context) error {
|
||||
|
||||
// Count godoc
|
||||
// @Summary Count providers
|
||||
// @Description Get the total count of providers, optionally filtered by client type
|
||||
// @Description Get the total count of providers
|
||||
// @Tags providers
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param client_type query string false "Client type filter (openai, openai-compat, anthropic, google, azure, bedrock, mistral, xai, ollama, dashscope)"
|
||||
// @Success 200 {object} providers.CountResponse
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /providers/count [get]
|
||||
func (h *ProvidersHandler) Count(c echo.Context) error {
|
||||
clientType := c.QueryParam("client_type")
|
||||
|
||||
var count int64
|
||||
var err error
|
||||
|
||||
if clientType != "" {
|
||||
count, err = h.service.CountByClientType(c.Request().Context(), providers.ClientType(clientType))
|
||||
} else {
|
||||
count, err = h.service.Count(c.Request().Context())
|
||||
}
|
||||
|
||||
count, err := h.service.Count(c.Request().Context())
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
+17
-11
@@ -49,6 +49,9 @@ func (s *Service) Create(ctx context.Context, req AddRequest) (AddResponse, erro
|
||||
InputModalities: inputMod,
|
||||
Type: string(model.Type),
|
||||
}
|
||||
if model.ClientType != "" {
|
||||
params.ClientType = pgtype.Text{String: string(model.ClientType), Valid: true}
|
||||
}
|
||||
|
||||
// Handle optional name field
|
||||
if model.Name != "" {
|
||||
@@ -140,7 +143,7 @@ func (s *Service) ListByClientType(ctx context.Context, clientType ClientType) (
|
||||
return nil, fmt.Errorf("invalid client type: %s", clientType)
|
||||
}
|
||||
|
||||
dbModels, err := s.queries.ListModelsByClientType(ctx, string(clientType))
|
||||
dbModels, err := s.queries.ListModelsByClientType(ctx, pgtype.Text{String: string(clientType), Valid: true})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list models by client type: %w", err)
|
||||
}
|
||||
@@ -207,6 +210,9 @@ func (s *Service) UpdateByID(ctx context.Context, id string, req UpdateRequest)
|
||||
InputModalities: inputMod,
|
||||
Type: string(model.Type),
|
||||
}
|
||||
if model.ClientType != "" {
|
||||
params.ClientType = pgtype.Text{String: string(model.ClientType), Valid: true}
|
||||
}
|
||||
|
||||
llmProviderID, err := db.ParseUUID(model.LlmProviderID)
|
||||
if err != nil {
|
||||
@@ -251,6 +257,9 @@ func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req Updat
|
||||
InputModalities: inputMod,
|
||||
Type: string(model.Type),
|
||||
}
|
||||
if model.ClientType != "" {
|
||||
params.ClientType = pgtype.Text{String: string(model.ClientType), Valid: true}
|
||||
}
|
||||
|
||||
llmProviderID, err := db.ParseUUID(model.LlmProviderID)
|
||||
if err != nil {
|
||||
@@ -333,6 +342,9 @@ func convertToGetResponse(dbModel sqlc.Model) GetResponse {
|
||||
Type: ModelType(dbModel.Type),
|
||||
},
|
||||
}
|
||||
if dbModel.ClientType.Valid {
|
||||
resp.Model.ClientType = ClientType(dbModel.ClientType.String)
|
||||
}
|
||||
if resp.Model.Type == ModelTypeChat {
|
||||
resp.Model.InputModalities = normalizeModalities(dbModel.InputModalities, []string{ModelInputText})
|
||||
}
|
||||
@@ -370,16 +382,10 @@ func normalizeModalities(modalities []string, fallback []string) []string {
|
||||
|
||||
func isValidClientType(clientType ClientType) bool {
|
||||
switch clientType {
|
||||
case ClientTypeOpenAI,
|
||||
ClientTypeOpenAICompat,
|
||||
ClientTypeAnthropic,
|
||||
ClientTypeGoogle,
|
||||
ClientTypeAzure,
|
||||
ClientTypeBedrock,
|
||||
ClientTypeMistral,
|
||||
ClientTypeXAI,
|
||||
ClientTypeOllama,
|
||||
ClientTypeDashscope:
|
||||
case ClientTypeOpenAIResponses,
|
||||
ClientTypeOpenAICompletions,
|
||||
ClientTypeAnthropicMessages,
|
||||
ClientTypeGoogleGenerativeAI:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
|
||||
@@ -110,6 +110,7 @@ func TestModel_Validate(t *testing.T) {
|
||||
ModelID: "gpt-4",
|
||||
Name: "GPT-4",
|
||||
LlmProviderID: "11111111-1111-1111-1111-111111111111",
|
||||
ClientType: models.ClientTypeOpenAIResponses,
|
||||
Type: models.ModelTypeChat,
|
||||
},
|
||||
wantErr: false,
|
||||
@@ -120,11 +121,33 @@ func TestModel_Validate(t *testing.T) {
|
||||
ModelID: "gpt-4o",
|
||||
Name: "GPT-4o",
|
||||
LlmProviderID: "11111111-1111-1111-1111-111111111111",
|
||||
ClientType: models.ClientTypeOpenAIResponses,
|
||||
InputModalities: []string{"text", "image", "audio"},
|
||||
Type: models.ModelTypeChat,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "chat model missing client_type",
|
||||
model: models.Model{
|
||||
ModelID: "gpt-4",
|
||||
Name: "GPT-4",
|
||||
LlmProviderID: "11111111-1111-1111-1111-111111111111",
|
||||
Type: models.ModelTypeChat,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "embedding model without client_type is valid",
|
||||
model: models.Model{
|
||||
ModelID: "text-embedding-3-small",
|
||||
Name: "Embedding",
|
||||
LlmProviderID: "11111111-1111-1111-1111-111111111111",
|
||||
Type: models.ModelTypeEmbedding,
|
||||
Dimensions: 1536,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid embedding model",
|
||||
model: models.Model{
|
||||
@@ -185,6 +208,7 @@ func TestModel_Validate(t *testing.T) {
|
||||
model: models.Model{
|
||||
ModelID: "gpt-4",
|
||||
LlmProviderID: "11111111-1111-1111-1111-111111111111",
|
||||
ClientType: models.ClientTypeOpenAIResponses,
|
||||
Type: models.ModelTypeChat,
|
||||
InputModalities: []string{"text", "smell"},
|
||||
},
|
||||
@@ -262,15 +286,9 @@ func TestModelTypes(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("ClientType constants", func(t *testing.T) {
|
||||
assert.Equal(t, models.ClientType("openai"), models.ClientTypeOpenAI)
|
||||
assert.Equal(t, models.ClientType("openai-compat"), models.ClientTypeOpenAICompat)
|
||||
assert.Equal(t, models.ClientType("anthropic"), models.ClientTypeAnthropic)
|
||||
assert.Equal(t, models.ClientType("google"), models.ClientTypeGoogle)
|
||||
assert.Equal(t, models.ClientType("azure"), models.ClientTypeAzure)
|
||||
assert.Equal(t, models.ClientType("bedrock"), models.ClientTypeBedrock)
|
||||
assert.Equal(t, models.ClientType("mistral"), models.ClientTypeMistral)
|
||||
assert.Equal(t, models.ClientType("xai"), models.ClientTypeXAI)
|
||||
assert.Equal(t, models.ClientType("ollama"), models.ClientTypeOllama)
|
||||
assert.Equal(t, models.ClientType("dashscope"), models.ClientTypeDashscope)
|
||||
assert.Equal(t, models.ClientType("openai-responses"), models.ClientTypeOpenAIResponses)
|
||||
assert.Equal(t, models.ClientType("openai-completions"), models.ClientTypeOpenAICompletions)
|
||||
assert.Equal(t, models.ClientType("anthropic-messages"), models.ClientTypeAnthropicMessages)
|
||||
assert.Equal(t, models.ClientType("google-generative-ai"), models.ClientTypeGoogleGenerativeAI)
|
||||
})
|
||||
}
|
||||
|
||||
+19
-17
@@ -25,25 +25,20 @@ const (
|
||||
type ClientType string
|
||||
|
||||
const (
|
||||
ClientTypeOpenAI ClientType = "openai"
|
||||
ClientTypeOpenAICompat ClientType = "openai-compat"
|
||||
ClientTypeAnthropic ClientType = "anthropic"
|
||||
ClientTypeGoogle ClientType = "google"
|
||||
ClientTypeAzure ClientType = "azure"
|
||||
ClientTypeBedrock ClientType = "bedrock"
|
||||
ClientTypeMistral ClientType = "mistral"
|
||||
ClientTypeXAI ClientType = "xai"
|
||||
ClientTypeOllama ClientType = "ollama"
|
||||
ClientTypeDashscope ClientType = "dashscope"
|
||||
ClientTypeOpenAIResponses ClientType = "openai-responses"
|
||||
ClientTypeOpenAICompletions ClientType = "openai-completions"
|
||||
ClientTypeAnthropicMessages ClientType = "anthropic-messages"
|
||||
ClientTypeGoogleGenerativeAI ClientType = "google-generative-ai"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
ModelID string `json:"model_id"`
|
||||
Name string `json:"name"`
|
||||
LlmProviderID string `json:"llm_provider_id"`
|
||||
InputModalities []string `json:"input_modalities,omitempty"`
|
||||
Type ModelType `json:"type"`
|
||||
Dimensions int `json:"dimensions"`
|
||||
ModelID string `json:"model_id"`
|
||||
Name string `json:"name"`
|
||||
LlmProviderID string `json:"llm_provider_id"`
|
||||
ClientType ClientType `json:"client_type,omitempty"`
|
||||
InputModalities []string `json:"input_modalities,omitempty"`
|
||||
Type ModelType `json:"type"`
|
||||
Dimensions int `json:"dimensions"`
|
||||
}
|
||||
|
||||
// validInputModalities is the set of recognised input modality tokens.
|
||||
@@ -65,10 +60,17 @@ func (m *Model) Validate() error {
|
||||
if m.Type != ModelTypeChat && m.Type != ModelTypeEmbedding {
|
||||
return errors.New("invalid model type")
|
||||
}
|
||||
if m.Type == ModelTypeChat {
|
||||
if m.ClientType == "" {
|
||||
return errors.New("client_type is required for chat models")
|
||||
}
|
||||
if !isValidClientType(m.ClientType) {
|
||||
return fmt.Errorf("invalid client_type: %s", m.ClientType)
|
||||
}
|
||||
}
|
||||
if m.Type == ModelTypeEmbedding && m.Dimensions <= 0 {
|
||||
return errors.New("dimensions must be greater than 0")
|
||||
}
|
||||
// Input modalities only apply to chat models.
|
||||
if m.Type == ModelTypeChat {
|
||||
for _, mod := range m.InputModalities {
|
||||
if _, ok := validInputModalities[mod]; !ok {
|
||||
|
||||
@@ -27,11 +27,6 @@ func NewService(log *slog.Logger, queries *sqlc.Queries) *Service {
|
||||
|
||||
// Create creates a new LLM provider
|
||||
func (s *Service) Create(ctx context.Context, req CreateRequest) (GetResponse, error) {
|
||||
// Validate client type
|
||||
if !isValidClientType(req.ClientType) {
|
||||
return GetResponse{}, fmt.Errorf("invalid client_type: %s", req.ClientType)
|
||||
}
|
||||
|
||||
// Marshal metadata
|
||||
metadataJSON, err := json.Marshal(req.Metadata)
|
||||
if err != nil {
|
||||
@@ -40,11 +35,10 @@ func (s *Service) Create(ctx context.Context, req CreateRequest) (GetResponse, e
|
||||
|
||||
// Create provider
|
||||
provider, err := s.queries.CreateLlmProvider(ctx, sqlc.CreateLlmProviderParams{
|
||||
Name: req.Name,
|
||||
ClientType: string(req.ClientType),
|
||||
BaseUrl: req.BaseURL,
|
||||
ApiKey: req.APIKey,
|
||||
Metadata: metadataJSON,
|
||||
Name: req.Name,
|
||||
BaseUrl: req.BaseURL,
|
||||
ApiKey: req.APIKey,
|
||||
Metadata: metadataJSON,
|
||||
})
|
||||
if err != nil {
|
||||
return GetResponse{}, fmt.Errorf("create provider: %w", err)
|
||||
@@ -92,24 +86,6 @@ func (s *Service) List(ctx context.Context) ([]GetResponse, error) {
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// ListByClientType retrieves providers by client type
|
||||
func (s *Service) ListByClientType(ctx context.Context, clientType ClientType) ([]GetResponse, error) {
|
||||
if !isValidClientType(clientType) {
|
||||
return nil, fmt.Errorf("invalid client_type: %s", clientType)
|
||||
}
|
||||
|
||||
providers, err := s.queries.ListLlmProvidersByClientType(ctx, string(clientType))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list providers by client type: %w", err)
|
||||
}
|
||||
|
||||
results := make([]GetResponse, 0, len(providers))
|
||||
for _, p := range providers {
|
||||
results = append(results, s.toGetResponse(p))
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// Update updates an existing provider
|
||||
func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (GetResponse, error) {
|
||||
providerID, err := db.ParseUUID(id)
|
||||
@@ -129,14 +105,6 @@ func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Get
|
||||
name = *req.Name
|
||||
}
|
||||
|
||||
clientType := existing.ClientType
|
||||
if req.ClientType != nil {
|
||||
if !isValidClientType(*req.ClientType) {
|
||||
return GetResponse{}, fmt.Errorf("invalid client_type: %s", *req.ClientType)
|
||||
}
|
||||
clientType = string(*req.ClientType)
|
||||
}
|
||||
|
||||
baseURL := existing.BaseUrl
|
||||
if req.BaseURL != nil {
|
||||
baseURL = *req.BaseURL
|
||||
@@ -155,12 +123,11 @@ func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Get
|
||||
|
||||
// Update provider
|
||||
updated, err := s.queries.UpdateLlmProvider(ctx, sqlc.UpdateLlmProviderParams{
|
||||
ID: providerID,
|
||||
Name: name,
|
||||
ClientType: clientType,
|
||||
BaseUrl: baseURL,
|
||||
ApiKey: apiKey,
|
||||
Metadata: metadata,
|
||||
ID: providerID,
|
||||
Name: name,
|
||||
BaseUrl: baseURL,
|
||||
ApiKey: apiKey,
|
||||
Metadata: metadata,
|
||||
})
|
||||
if err != nil {
|
||||
return GetResponse{}, fmt.Errorf("update provider: %w", err)
|
||||
@@ -191,19 +158,6 @@ func (s *Service) Count(ctx context.Context) (int64, error) {
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// CountByClientType returns the count of providers by client type
|
||||
func (s *Service) CountByClientType(ctx context.Context, clientType ClientType) (int64, error) {
|
||||
if !isValidClientType(clientType) {
|
||||
return 0, fmt.Errorf("invalid client_type: %s", clientType)
|
||||
}
|
||||
|
||||
count, err := s.queries.CountLlmProvidersByClientType(ctx, string(clientType))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("count providers by client type: %w", err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// toGetResponse converts a database provider to a response
|
||||
func (s *Service) toGetResponse(provider sqlc.LlmProvider) GetResponse {
|
||||
var metadata map[string]any
|
||||
@@ -217,26 +171,13 @@ func (s *Service) toGetResponse(provider sqlc.LlmProvider) GetResponse {
|
||||
maskedAPIKey := maskAPIKey(provider.ApiKey)
|
||||
|
||||
return GetResponse{
|
||||
ID: provider.ID.String(),
|
||||
Name: provider.Name,
|
||||
ClientType: provider.ClientType,
|
||||
BaseURL: provider.BaseUrl,
|
||||
APIKey: maskedAPIKey,
|
||||
Metadata: metadata,
|
||||
CreatedAt: provider.CreatedAt.Time,
|
||||
UpdatedAt: provider.UpdatedAt.Time,
|
||||
}
|
||||
}
|
||||
|
||||
// isValidClientType checks if a client type is valid
|
||||
func isValidClientType(clientType ClientType) bool {
|
||||
switch clientType {
|
||||
case ClientTypeOpenAI, ClientTypeOpenAICompat, ClientTypeAnthropic, ClientTypeGoogle,
|
||||
ClientTypeAzure, ClientTypeBedrock, ClientTypeMistral, ClientTypeXAI,
|
||||
ClientTypeOllama, ClientTypeDashscope:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
ID: provider.ID.String(),
|
||||
Name: provider.Name,
|
||||
BaseURL: provider.BaseUrl,
|
||||
APIKey: maskedAPIKey,
|
||||
Metadata: metadata,
|
||||
CreatedAt: provider.CreatedAt.Time,
|
||||
UpdatedAt: provider.UpdatedAt.Time,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+18
-38
@@ -2,50 +2,31 @@ package providers
|
||||
|
||||
import "time"
|
||||
|
||||
// ClientType represents the type of LLM provider client
|
||||
type ClientType string
|
||||
|
||||
const (
|
||||
ClientTypeOpenAI ClientType = "openai"
|
||||
ClientTypeOpenAICompat ClientType = "openai-compat"
|
||||
ClientTypeAnthropic ClientType = "anthropic"
|
||||
ClientTypeGoogle ClientType = "google"
|
||||
ClientTypeAzure ClientType = "azure"
|
||||
ClientTypeBedrock ClientType = "bedrock"
|
||||
ClientTypeMistral ClientType = "mistral"
|
||||
ClientTypeXAI ClientType = "xai"
|
||||
ClientTypeOllama ClientType = "ollama"
|
||||
ClientTypeDashscope ClientType = "dashscope"
|
||||
)
|
||||
|
||||
// CreateRequest represents a request to create a new LLM provider
|
||||
type CreateRequest struct {
|
||||
Name string `json:"name" validate:"required"`
|
||||
ClientType ClientType `json:"client_type" validate:"required"`
|
||||
BaseURL string `json:"base_url" validate:"required,url"`
|
||||
APIKey string `json:"api_key"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
Name string `json:"name" validate:"required"`
|
||||
BaseURL string `json:"base_url" validate:"required,url"`
|
||||
APIKey string `json:"api_key"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// UpdateRequest represents a request to update an existing LLM provider
|
||||
type UpdateRequest struct {
|
||||
Name *string `json:"name,omitempty"`
|
||||
ClientType *ClientType `json:"client_type,omitempty"`
|
||||
BaseURL *string `json:"base_url,omitempty"`
|
||||
APIKey *string `json:"api_key,omitempty"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
Name *string `json:"name,omitempty"`
|
||||
BaseURL *string `json:"base_url,omitempty"`
|
||||
APIKey *string `json:"api_key,omitempty"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// GetResponse represents the response for getting a provider
|
||||
type GetResponse struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
ClientType string `json:"client_type"`
|
||||
BaseURL string `json:"base_url"`
|
||||
APIKey string `json:"api_key,omitempty"` // masked in response
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
BaseURL string `json:"base_url"`
|
||||
APIKey string `json:"api_key,omitempty"` // masked in response
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// ListResponse represents the response for listing providers
|
||||
@@ -61,10 +42,9 @@ type CountResponse struct {
|
||||
|
||||
// TestRequest represents a request to test provider connection
|
||||
type TestRequest struct {
|
||||
ClientType ClientType `json:"client_type" validate:"required"`
|
||||
BaseURL string `json:"base_url" validate:"required,url"`
|
||||
APIKey string `json:"api_key"`
|
||||
Model string `json:"model"` // optional test model
|
||||
BaseURL string `json:"base_url" validate:"required,url"`
|
||||
APIKey string `json:"api_key"`
|
||||
Model string `json:"model"` // optional test model
|
||||
}
|
||||
|
||||
// TestResponse represents the result of testing a provider
|
||||
|
||||
Reference in New Issue
Block a user