mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
feat: embedding router
This commit is contained in:
+62
-21
@@ -29,12 +29,16 @@ func (s *Service) Create(ctx context.Context, req AddRequest) (AddResponse, erro
|
||||
}
|
||||
|
||||
// Convert to sqlc params
|
||||
llmProviderID, err := parseUUID(model.LlmProviderID)
|
||||
if err != nil {
|
||||
return AddResponse{}, fmt.Errorf("invalid llm provider ID: %w", err)
|
||||
}
|
||||
|
||||
params := sqlc.CreateModelParams{
|
||||
ModelID: model.ModelID,
|
||||
BaseUrl: model.BaseURL,
|
||||
ApiKey: model.APIKey,
|
||||
ClientType: string(model.ClientType),
|
||||
Type: string(model.Type),
|
||||
ModelID: model.ModelID,
|
||||
LlmProviderID: llmProviderID,
|
||||
IsMultimodal: model.IsMultimodal,
|
||||
Type: string(model.Type),
|
||||
}
|
||||
|
||||
// Handle optional name field
|
||||
@@ -123,7 +127,7 @@ func (s *Service) ListByType(ctx context.Context, modelType ModelType) ([]GetRes
|
||||
|
||||
// ListByClientType returns models filtered by client type
|
||||
func (s *Service) ListByClientType(ctx context.Context, clientType ClientType) ([]GetResponse, error) {
|
||||
if clientType != ClientTypeOpenAI && clientType != ClientTypeAnthropic && clientType != ClientTypeGoogle {
|
||||
if !isValidClientType(clientType) {
|
||||
return nil, fmt.Errorf("invalid client type: %s", clientType)
|
||||
}
|
||||
|
||||
@@ -148,13 +152,17 @@ func (s *Service) UpdateByID(ctx context.Context, id string, req UpdateRequest)
|
||||
}
|
||||
|
||||
params := sqlc.UpdateModelParams{
|
||||
ID: uuid,
|
||||
BaseUrl: model.BaseURL,
|
||||
ApiKey: model.APIKey,
|
||||
ClientType: string(model.ClientType),
|
||||
Type: string(model.Type),
|
||||
ID: uuid,
|
||||
IsMultimodal: model.IsMultimodal,
|
||||
Type: string(model.Type),
|
||||
}
|
||||
|
||||
llmProviderID, err := parseUUID(model.LlmProviderID)
|
||||
if err != nil {
|
||||
return GetResponse{}, fmt.Errorf("invalid llm provider ID: %w", err)
|
||||
}
|
||||
params.LlmProviderID = llmProviderID
|
||||
|
||||
if model.Name != "" {
|
||||
params.Name = pgtype.Text{String: model.Name, Valid: true}
|
||||
}
|
||||
@@ -183,13 +191,17 @@ func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req Updat
|
||||
}
|
||||
|
||||
params := sqlc.UpdateModelByModelIDParams{
|
||||
ModelID: modelID,
|
||||
BaseUrl: model.BaseURL,
|
||||
ApiKey: model.APIKey,
|
||||
ClientType: string(model.ClientType),
|
||||
Type: string(model.Type),
|
||||
ModelID: modelID,
|
||||
IsMultimodal: model.IsMultimodal,
|
||||
Type: string(model.Type),
|
||||
}
|
||||
|
||||
llmProviderID, err := parseUUID(model.LlmProviderID)
|
||||
if err != nil {
|
||||
return GetResponse{}, fmt.Errorf("invalid llm provider ID: %w", err)
|
||||
}
|
||||
params.LlmProviderID = llmProviderID
|
||||
|
||||
if model.Name != "" {
|
||||
params.Name = pgtype.Text{String: model.Name, Valid: true}
|
||||
}
|
||||
@@ -274,14 +286,16 @@ func convertToGetResponse(dbModel sqlc.Model) GetResponse {
|
||||
resp := GetResponse{
|
||||
ModelId: dbModel.ModelID,
|
||||
Model: Model{
|
||||
ModelID: dbModel.ModelID,
|
||||
BaseURL: dbModel.BaseUrl,
|
||||
APIKey: dbModel.ApiKey,
|
||||
ClientType: ClientType(dbModel.ClientType),
|
||||
Type: ModelType(dbModel.Type),
|
||||
ModelID: dbModel.ModelID,
|
||||
IsMultimodal: dbModel.IsMultimodal,
|
||||
Type: ModelType(dbModel.Type),
|
||||
},
|
||||
}
|
||||
|
||||
if llmProviderID, ok := uuidStringFromPgUUID(dbModel.LlmProviderID); ok {
|
||||
resp.Model.LlmProviderID = llmProviderID
|
||||
}
|
||||
|
||||
if dbModel.Name.Valid {
|
||||
resp.Model.Name = dbModel.Name.String
|
||||
}
|
||||
@@ -300,3 +314,30 @@ func convertToGetResponseList(dbModels []sqlc.Model) []GetResponse {
|
||||
}
|
||||
return responses
|
||||
}
|
||||
|
||||
func isValidClientType(clientType ClientType) bool {
|
||||
switch clientType {
|
||||
case ClientTypeOpenAI,
|
||||
ClientTypeAnthropic,
|
||||
ClientTypeGoogle,
|
||||
ClientTypeBedrock,
|
||||
ClientTypeOllama,
|
||||
ClientTypeAzure,
|
||||
ClientTypeDashscope,
|
||||
ClientTypeOther:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func uuidStringFromPgUUID(value pgtype.UUID) (string, bool) {
|
||||
if !value.Valid {
|
||||
return "", false
|
||||
}
|
||||
id, err := uuid.FromBytes(value.Bytes[:])
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
return id.String(), true
|
||||
}
|
||||
|
||||
@@ -18,9 +18,7 @@ func ExampleService_Create() {
|
||||
// req := models.AddRequest{
|
||||
// ModelID: "gpt-4",
|
||||
// Name: "GPT-4",
|
||||
// BaseURL: "https://api.openai.com/v1",
|
||||
// APIKey: "sk-...",
|
||||
// ClientType: models.ClientTypeOpenAI,
|
||||
// LlmProviderID: "11111111-1111-1111-1111-111111111111",
|
||||
// Type: models.ModelTypeChat,
|
||||
// }
|
||||
|
||||
@@ -77,9 +75,7 @@ func ExampleService_UpdateByModelID() {
|
||||
// req := models.UpdateRequest{
|
||||
// ModelID: "gpt-4",
|
||||
// Name: "GPT-4 Turbo",
|
||||
// BaseURL: "https://api.openai.com/v1",
|
||||
// APIKey: "sk-...",
|
||||
// ClientType: models.ClientTypeOpenAI,
|
||||
// LlmProviderID: "11111111-1111-1111-1111-111111111111",
|
||||
// Type: models.ModelTypeChat,
|
||||
// }
|
||||
|
||||
@@ -111,89 +107,65 @@ func TestModel_Validate(t *testing.T) {
|
||||
{
|
||||
name: "valid chat model",
|
||||
model: models.Model{
|
||||
ModelID: "gpt-4",
|
||||
Name: "GPT-4",
|
||||
BaseURL: "https://api.openai.com/v1",
|
||||
APIKey: "sk-test",
|
||||
ClientType: models.ClientTypeOpenAI,
|
||||
Type: models.ModelTypeChat,
|
||||
ModelID: "gpt-4",
|
||||
Name: "GPT-4",
|
||||
LlmProviderID: "11111111-1111-1111-1111-111111111111",
|
||||
Type: models.ModelTypeChat,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid embedding model",
|
||||
model: models.Model{
|
||||
ModelID: "text-embedding-ada-002",
|
||||
Name: "Ada Embeddings",
|
||||
BaseURL: "https://api.openai.com/v1",
|
||||
APIKey: "sk-test",
|
||||
ClientType: models.ClientTypeOpenAI,
|
||||
Type: models.ModelTypeEmbedding,
|
||||
Dimensions: 1536,
|
||||
ModelID: "text-embedding-ada-002",
|
||||
Name: "Ada Embeddings",
|
||||
LlmProviderID: "11111111-1111-1111-1111-111111111111",
|
||||
Type: models.ModelTypeEmbedding,
|
||||
Dimensions: 1536,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing model_id",
|
||||
model: models.Model{
|
||||
BaseURL: "https://api.openai.com/v1",
|
||||
APIKey: "sk-test",
|
||||
ClientType: models.ClientTypeOpenAI,
|
||||
Type: models.ModelTypeChat,
|
||||
LlmProviderID: "11111111-1111-1111-1111-111111111111",
|
||||
Type: models.ModelTypeChat,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing base_url",
|
||||
name: "missing llm_provider_id",
|
||||
model: models.Model{
|
||||
ModelID: "gpt-4",
|
||||
APIKey: "sk-test",
|
||||
ClientType: models.ClientTypeOpenAI,
|
||||
Type: models.ModelTypeChat,
|
||||
ModelID: "gpt-4",
|
||||
Type: models.ModelTypeChat,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing api_key",
|
||||
name: "invalid llm_provider_id",
|
||||
model: models.Model{
|
||||
ModelID: "gpt-4",
|
||||
BaseURL: "https://api.openai.com/v1",
|
||||
ClientType: models.ClientTypeOpenAI,
|
||||
Type: models.ModelTypeChat,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid client type",
|
||||
model: models.Model{
|
||||
ModelID: "gpt-4",
|
||||
BaseURL: "https://api.openai.com/v1",
|
||||
APIKey: "sk-test",
|
||||
ClientType: "invalid",
|
||||
Type: models.ModelTypeChat,
|
||||
ModelID: "gpt-4",
|
||||
LlmProviderID: "not-a-uuid",
|
||||
Type: models.ModelTypeChat,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid model type",
|
||||
model: models.Model{
|
||||
ModelID: "gpt-4",
|
||||
BaseURL: "https://api.openai.com/v1",
|
||||
APIKey: "sk-test",
|
||||
ClientType: models.ClientTypeOpenAI,
|
||||
Type: "invalid",
|
||||
ModelID: "gpt-4",
|
||||
LlmProviderID: "11111111-1111-1111-1111-111111111111",
|
||||
Type: "invalid",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "embedding model missing dimensions",
|
||||
model: models.Model{
|
||||
ModelID: "text-embedding-ada-002",
|
||||
BaseURL: "https://api.openai.com/v1",
|
||||
APIKey: "sk-test",
|
||||
ClientType: models.ClientTypeOpenAI,
|
||||
Type: models.ModelTypeEmbedding,
|
||||
Dimensions: 0,
|
||||
ModelID: "text-embedding-ada-002",
|
||||
LlmProviderID: "11111111-1111-1111-1111-111111111111",
|
||||
Type: models.ModelTypeEmbedding,
|
||||
Dimensions: 0,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
@@ -221,6 +193,11 @@ func TestModelTypes(t *testing.T) {
|
||||
assert.Equal(t, models.ClientType("openai"), models.ClientTypeOpenAI)
|
||||
assert.Equal(t, models.ClientType("anthropic"), models.ClientTypeAnthropic)
|
||||
assert.Equal(t, models.ClientType("google"), models.ClientTypeGoogle)
|
||||
assert.Equal(t, models.ClientType("bedrock"), models.ClientTypeBedrock)
|
||||
assert.Equal(t, models.ClientType("ollama"), models.ClientTypeOllama)
|
||||
assert.Equal(t, models.ClientType("azure"), models.ClientTypeAzure)
|
||||
assert.Equal(t, models.ClientType("dashscope"), models.ClientTypeDashscope)
|
||||
assert.Equal(t, models.ClientType("other"), models.ClientTypeOther)
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
+19
-19
@@ -2,13 +2,15 @@ package models
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type ModelType string
|
||||
|
||||
const (
|
||||
ModelTypeChat = "chat"
|
||||
ModelTypeEmbedding = "embedding"
|
||||
ModelTypeChat ModelType = "chat"
|
||||
ModelTypeEmbedding ModelType = "embedding"
|
||||
)
|
||||
|
||||
type ClientType string
|
||||
@@ -17,37 +19,35 @@ const (
|
||||
ClientTypeOpenAI ClientType = "openai"
|
||||
ClientTypeAnthropic ClientType = "anthropic"
|
||||
ClientTypeGoogle ClientType = "google"
|
||||
ClientTypeBedrock ClientType = "bedrock"
|
||||
ClientTypeOllama ClientType = "ollama"
|
||||
ClientTypeAzure ClientType = "azure"
|
||||
ClientTypeDashscope ClientType = "dashscope"
|
||||
ClientTypeOther ClientType = "other"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
ModelID string `json:"model_id"`
|
||||
Name string `json:"name"`
|
||||
BaseURL string `json:"base_url"`
|
||||
APIKey string `json:"api_key"`
|
||||
ClientType ClientType `json:"client_type"`
|
||||
Type ModelType `json:"type"`
|
||||
Dimensions int `json:"dimensions"`
|
||||
ModelID string `json:"model_id"`
|
||||
Name string `json:"name"`
|
||||
LlmProviderID string `json:"llm_provider_id"`
|
||||
IsMultimodal bool `json:"is_multimodal"`
|
||||
Type ModelType `json:"type"`
|
||||
Dimensions int `json:"dimensions"`
|
||||
}
|
||||
|
||||
func (m *Model) Validate() error {
|
||||
if m.ModelID == "" {
|
||||
return errors.New("model ID is required")
|
||||
}
|
||||
if m.BaseURL == "" {
|
||||
return errors.New("base URL is required")
|
||||
if m.LlmProviderID == "" {
|
||||
return errors.New("llm provider ID is required")
|
||||
}
|
||||
if m.APIKey == "" {
|
||||
return errors.New("API key is required")
|
||||
}
|
||||
if m.ClientType == "" {
|
||||
return errors.New("client type is required")
|
||||
if _, err := uuid.Parse(m.LlmProviderID); err != nil {
|
||||
return errors.New("llm provider ID must be a valid UUID")
|
||||
}
|
||||
if m.Type != ModelTypeChat && m.Type != ModelTypeEmbedding {
|
||||
return errors.New("invalid model type")
|
||||
}
|
||||
if m.ClientType != ClientTypeOpenAI && m.ClientType != ClientTypeAnthropic && m.ClientType != ClientTypeGoogle {
|
||||
return errors.New("invalid client type")
|
||||
}
|
||||
if m.Type == ModelTypeEmbedding && m.Dimensions <= 0 {
|
||||
return errors.New("dimensions must be greater than 0")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user