feat: embedding router

This commit is contained in:
Ran
2026-01-26 05:10:53 +07:00
parent c332ce7749
commit 3ff0e2c4dd
22 changed files with 2572 additions and 392 deletions
+62 -21
View File
@@ -29,12 +29,16 @@ func (s *Service) Create(ctx context.Context, req AddRequest) (AddResponse, erro
}
// Convert to sqlc params
llmProviderID, err := parseUUID(model.LlmProviderID)
if err != nil {
return AddResponse{}, fmt.Errorf("invalid llm provider ID: %w", err)
}
params := sqlc.CreateModelParams{
ModelID: model.ModelID,
BaseUrl: model.BaseURL,
ApiKey: model.APIKey,
ClientType: string(model.ClientType),
Type: string(model.Type),
ModelID: model.ModelID,
LlmProviderID: llmProviderID,
IsMultimodal: model.IsMultimodal,
Type: string(model.Type),
}
// Handle optional name field
@@ -123,7 +127,7 @@ func (s *Service) ListByType(ctx context.Context, modelType ModelType) ([]GetRes
// ListByClientType returns models filtered by client type
func (s *Service) ListByClientType(ctx context.Context, clientType ClientType) ([]GetResponse, error) {
if clientType != ClientTypeOpenAI && clientType != ClientTypeAnthropic && clientType != ClientTypeGoogle {
if !isValidClientType(clientType) {
return nil, fmt.Errorf("invalid client type: %s", clientType)
}
@@ -148,13 +152,17 @@ func (s *Service) UpdateByID(ctx context.Context, id string, req UpdateRequest)
}
params := sqlc.UpdateModelParams{
ID: uuid,
BaseUrl: model.BaseURL,
ApiKey: model.APIKey,
ClientType: string(model.ClientType),
Type: string(model.Type),
ID: uuid,
IsMultimodal: model.IsMultimodal,
Type: string(model.Type),
}
llmProviderID, err := parseUUID(model.LlmProviderID)
if err != nil {
return GetResponse{}, fmt.Errorf("invalid llm provider ID: %w", err)
}
params.LlmProviderID = llmProviderID
if model.Name != "" {
params.Name = pgtype.Text{String: model.Name, Valid: true}
}
@@ -183,13 +191,17 @@ func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req Updat
}
params := sqlc.UpdateModelByModelIDParams{
ModelID: modelID,
BaseUrl: model.BaseURL,
ApiKey: model.APIKey,
ClientType: string(model.ClientType),
Type: string(model.Type),
ModelID: modelID,
IsMultimodal: model.IsMultimodal,
Type: string(model.Type),
}
llmProviderID, err := parseUUID(model.LlmProviderID)
if err != nil {
return GetResponse{}, fmt.Errorf("invalid llm provider ID: %w", err)
}
params.LlmProviderID = llmProviderID
if model.Name != "" {
params.Name = pgtype.Text{String: model.Name, Valid: true}
}
@@ -274,14 +286,16 @@ func convertToGetResponse(dbModel sqlc.Model) GetResponse {
resp := GetResponse{
ModelId: dbModel.ModelID,
Model: Model{
ModelID: dbModel.ModelID,
BaseURL: dbModel.BaseUrl,
APIKey: dbModel.ApiKey,
ClientType: ClientType(dbModel.ClientType),
Type: ModelType(dbModel.Type),
ModelID: dbModel.ModelID,
IsMultimodal: dbModel.IsMultimodal,
Type: ModelType(dbModel.Type),
},
}
if llmProviderID, ok := uuidStringFromPgUUID(dbModel.LlmProviderID); ok {
resp.Model.LlmProviderID = llmProviderID
}
if dbModel.Name.Valid {
resp.Model.Name = dbModel.Name.String
}
@@ -300,3 +314,30 @@ func convertToGetResponseList(dbModels []sqlc.Model) []GetResponse {
}
return responses
}
// isValidClientType reports whether clientType is one of the client types
// supported by this service. Keep this list in sync with the ClientType
// constants declared in the models package.
func isValidClientType(clientType ClientType) bool {
	supported := [...]ClientType{
		ClientTypeOpenAI,
		ClientTypeAnthropic,
		ClientTypeGoogle,
		ClientTypeBedrock,
		ClientTypeOllama,
		ClientTypeAzure,
		ClientTypeDashscope,
		ClientTypeOther,
	}
	for _, known := range supported {
		if clientType == known {
			return true
		}
	}
	return false
}
// uuidStringFromPgUUID converts a pgtype.UUID into its canonical string form.
// The second return value is false when the pg value is NULL (not Valid) or
// when its byte payload cannot be interpreted as a UUID.
func uuidStringFromPgUUID(value pgtype.UUID) (string, bool) {
	if value.Valid {
		if parsed, err := uuid.FromBytes(value.Bytes[:]); err == nil {
			return parsed.String(), true
		}
	}
	return "", false
}
+32 -55
View File
@@ -18,9 +18,7 @@ func ExampleService_Create() {
// req := models.AddRequest{
// ModelID: "gpt-4",
// Name: "GPT-4",
// BaseURL: "https://api.openai.com/v1",
// APIKey: "sk-...",
// ClientType: models.ClientTypeOpenAI,
// LlmProviderID: "11111111-1111-1111-1111-111111111111",
// Type: models.ModelTypeChat,
// }
@@ -77,9 +75,7 @@ func ExampleService_UpdateByModelID() {
// req := models.UpdateRequest{
// ModelID: "gpt-4",
// Name: "GPT-4 Turbo",
// BaseURL: "https://api.openai.com/v1",
// APIKey: "sk-...",
// ClientType: models.ClientTypeOpenAI,
// LlmProviderID: "11111111-1111-1111-1111-111111111111",
// Type: models.ModelTypeChat,
// }
@@ -111,89 +107,65 @@ func TestModel_Validate(t *testing.T) {
{
name: "valid chat model",
model: models.Model{
ModelID: "gpt-4",
Name: "GPT-4",
BaseURL: "https://api.openai.com/v1",
APIKey: "sk-test",
ClientType: models.ClientTypeOpenAI,
Type: models.ModelTypeChat,
ModelID: "gpt-4",
Name: "GPT-4",
LlmProviderID: "11111111-1111-1111-1111-111111111111",
Type: models.ModelTypeChat,
},
wantErr: false,
},
{
name: "valid embedding model",
model: models.Model{
ModelID: "text-embedding-ada-002",
Name: "Ada Embeddings",
BaseURL: "https://api.openai.com/v1",
APIKey: "sk-test",
ClientType: models.ClientTypeOpenAI,
Type: models.ModelTypeEmbedding,
Dimensions: 1536,
ModelID: "text-embedding-ada-002",
Name: "Ada Embeddings",
LlmProviderID: "11111111-1111-1111-1111-111111111111",
Type: models.ModelTypeEmbedding,
Dimensions: 1536,
},
wantErr: false,
},
{
name: "missing model_id",
model: models.Model{
BaseURL: "https://api.openai.com/v1",
APIKey: "sk-test",
ClientType: models.ClientTypeOpenAI,
Type: models.ModelTypeChat,
LlmProviderID: "11111111-1111-1111-1111-111111111111",
Type: models.ModelTypeChat,
},
wantErr: true,
},
{
name: "missing base_url",
name: "missing llm_provider_id",
model: models.Model{
ModelID: "gpt-4",
APIKey: "sk-test",
ClientType: models.ClientTypeOpenAI,
Type: models.ModelTypeChat,
ModelID: "gpt-4",
Type: models.ModelTypeChat,
},
wantErr: true,
},
{
name: "missing api_key",
name: "invalid llm_provider_id",
model: models.Model{
ModelID: "gpt-4",
BaseURL: "https://api.openai.com/v1",
ClientType: models.ClientTypeOpenAI,
Type: models.ModelTypeChat,
},
wantErr: true,
},
{
name: "invalid client type",
model: models.Model{
ModelID: "gpt-4",
BaseURL: "https://api.openai.com/v1",
APIKey: "sk-test",
ClientType: "invalid",
Type: models.ModelTypeChat,
ModelID: "gpt-4",
LlmProviderID: "not-a-uuid",
Type: models.ModelTypeChat,
},
wantErr: true,
},
{
name: "invalid model type",
model: models.Model{
ModelID: "gpt-4",
BaseURL: "https://api.openai.com/v1",
APIKey: "sk-test",
ClientType: models.ClientTypeOpenAI,
Type: "invalid",
ModelID: "gpt-4",
LlmProviderID: "11111111-1111-1111-1111-111111111111",
Type: "invalid",
},
wantErr: true,
},
{
name: "embedding model missing dimensions",
model: models.Model{
ModelID: "text-embedding-ada-002",
BaseURL: "https://api.openai.com/v1",
APIKey: "sk-test",
ClientType: models.ClientTypeOpenAI,
Type: models.ModelTypeEmbedding,
Dimensions: 0,
ModelID: "text-embedding-ada-002",
LlmProviderID: "11111111-1111-1111-1111-111111111111",
Type: models.ModelTypeEmbedding,
Dimensions: 0,
},
wantErr: true,
},
@@ -221,6 +193,11 @@ func TestModelTypes(t *testing.T) {
assert.Equal(t, models.ClientType("openai"), models.ClientTypeOpenAI)
assert.Equal(t, models.ClientType("anthropic"), models.ClientTypeAnthropic)
assert.Equal(t, models.ClientType("google"), models.ClientTypeGoogle)
assert.Equal(t, models.ClientType("bedrock"), models.ClientTypeBedrock)
assert.Equal(t, models.ClientType("ollama"), models.ClientTypeOllama)
assert.Equal(t, models.ClientType("azure"), models.ClientTypeAzure)
assert.Equal(t, models.ClientType("dashscope"), models.ClientTypeDashscope)
assert.Equal(t, models.ClientType("other"), models.ClientTypeOther)
})
}
+19 -19
View File
@@ -2,13 +2,15 @@ package models
import (
"errors"
"github.com/google/uuid"
)
type ModelType string
const (
ModelTypeChat = "chat"
ModelTypeEmbedding = "embedding"
ModelTypeChat ModelType = "chat"
ModelTypeEmbedding ModelType = "embedding"
)
type ClientType string
@@ -17,37 +19,35 @@ const (
ClientTypeOpenAI ClientType = "openai"
ClientTypeAnthropic ClientType = "anthropic"
ClientTypeGoogle ClientType = "google"
ClientTypeBedrock ClientType = "bedrock"
ClientTypeOllama ClientType = "ollama"
ClientTypeAzure ClientType = "azure"
ClientTypeDashscope ClientType = "dashscope"
ClientTypeOther ClientType = "other"
)
type Model struct {
ModelID string `json:"model_id"`
Name string `json:"name"`
BaseURL string `json:"base_url"`
APIKey string `json:"api_key"`
ClientType ClientType `json:"client_type"`
Type ModelType `json:"type"`
Dimensions int `json:"dimensions"`
ModelID string `json:"model_id"`
Name string `json:"name"`
LlmProviderID string `json:"llm_provider_id"`
IsMultimodal bool `json:"is_multimodal"`
Type ModelType `json:"type"`
Dimensions int `json:"dimensions"`
}
func (m *Model) Validate() error {
if m.ModelID == "" {
return errors.New("model ID is required")
}
if m.BaseURL == "" {
return errors.New("base URL is required")
if m.LlmProviderID == "" {
return errors.New("llm provider ID is required")
}
if m.APIKey == "" {
return errors.New("API key is required")
}
if m.ClientType == "" {
return errors.New("client type is required")
if _, err := uuid.Parse(m.LlmProviderID); err != nil {
return errors.New("llm provider ID must be a valid UUID")
}
if m.Type != ModelTypeChat && m.Type != ModelTypeEmbedding {
return errors.New("invalid model type")
}
if m.ClientType != ClientTypeOpenAI && m.ClientType != ClientTypeAnthropic && m.ClientType != ClientTypeGoogle {
return errors.New("invalid client type")
}
if m.Type == ModelTypeEmbedding && m.Dimensions <= 0 {
return errors.New("dimensions must be greater than 0")
}