mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
feat: embedding router
This commit is contained in:
+62
-21
@@ -29,12 +29,16 @@ func (s *Service) Create(ctx context.Context, req AddRequest) (AddResponse, erro
|
||||
}
|
||||
|
||||
// Convert to sqlc params
|
||||
llmProviderID, err := parseUUID(model.LlmProviderID)
|
||||
if err != nil {
|
||||
return AddResponse{}, fmt.Errorf("invalid llm provider ID: %w", err)
|
||||
}
|
||||
|
||||
params := sqlc.CreateModelParams{
|
||||
ModelID: model.ModelID,
|
||||
BaseUrl: model.BaseURL,
|
||||
ApiKey: model.APIKey,
|
||||
ClientType: string(model.ClientType),
|
||||
Type: string(model.Type),
|
||||
ModelID: model.ModelID,
|
||||
LlmProviderID: llmProviderID,
|
||||
IsMultimodal: model.IsMultimodal,
|
||||
Type: string(model.Type),
|
||||
}
|
||||
|
||||
// Handle optional name field
|
||||
@@ -123,7 +127,7 @@ func (s *Service) ListByType(ctx context.Context, modelType ModelType) ([]GetRes
|
||||
|
||||
// ListByClientType returns models filtered by client type
|
||||
func (s *Service) ListByClientType(ctx context.Context, clientType ClientType) ([]GetResponse, error) {
|
||||
if clientType != ClientTypeOpenAI && clientType != ClientTypeAnthropic && clientType != ClientTypeGoogle {
|
||||
if !isValidClientType(clientType) {
|
||||
return nil, fmt.Errorf("invalid client type: %s", clientType)
|
||||
}
|
||||
|
||||
@@ -148,13 +152,17 @@ func (s *Service) UpdateByID(ctx context.Context, id string, req UpdateRequest)
|
||||
}
|
||||
|
||||
params := sqlc.UpdateModelParams{
|
||||
ID: uuid,
|
||||
BaseUrl: model.BaseURL,
|
||||
ApiKey: model.APIKey,
|
||||
ClientType: string(model.ClientType),
|
||||
Type: string(model.Type),
|
||||
ID: uuid,
|
||||
IsMultimodal: model.IsMultimodal,
|
||||
Type: string(model.Type),
|
||||
}
|
||||
|
||||
llmProviderID, err := parseUUID(model.LlmProviderID)
|
||||
if err != nil {
|
||||
return GetResponse{}, fmt.Errorf("invalid llm provider ID: %w", err)
|
||||
}
|
||||
params.LlmProviderID = llmProviderID
|
||||
|
||||
if model.Name != "" {
|
||||
params.Name = pgtype.Text{String: model.Name, Valid: true}
|
||||
}
|
||||
@@ -183,13 +191,17 @@ func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req Updat
|
||||
}
|
||||
|
||||
params := sqlc.UpdateModelByModelIDParams{
|
||||
ModelID: modelID,
|
||||
BaseUrl: model.BaseURL,
|
||||
ApiKey: model.APIKey,
|
||||
ClientType: string(model.ClientType),
|
||||
Type: string(model.Type),
|
||||
ModelID: modelID,
|
||||
IsMultimodal: model.IsMultimodal,
|
||||
Type: string(model.Type),
|
||||
}
|
||||
|
||||
llmProviderID, err := parseUUID(model.LlmProviderID)
|
||||
if err != nil {
|
||||
return GetResponse{}, fmt.Errorf("invalid llm provider ID: %w", err)
|
||||
}
|
||||
params.LlmProviderID = llmProviderID
|
||||
|
||||
if model.Name != "" {
|
||||
params.Name = pgtype.Text{String: model.Name, Valid: true}
|
||||
}
|
||||
@@ -274,14 +286,16 @@ func convertToGetResponse(dbModel sqlc.Model) GetResponse {
|
||||
resp := GetResponse{
|
||||
ModelId: dbModel.ModelID,
|
||||
Model: Model{
|
||||
ModelID: dbModel.ModelID,
|
||||
BaseURL: dbModel.BaseUrl,
|
||||
APIKey: dbModel.ApiKey,
|
||||
ClientType: ClientType(dbModel.ClientType),
|
||||
Type: ModelType(dbModel.Type),
|
||||
ModelID: dbModel.ModelID,
|
||||
IsMultimodal: dbModel.IsMultimodal,
|
||||
Type: ModelType(dbModel.Type),
|
||||
},
|
||||
}
|
||||
|
||||
if llmProviderID, ok := uuidStringFromPgUUID(dbModel.LlmProviderID); ok {
|
||||
resp.Model.LlmProviderID = llmProviderID
|
||||
}
|
||||
|
||||
if dbModel.Name.Valid {
|
||||
resp.Model.Name = dbModel.Name.String
|
||||
}
|
||||
@@ -300,3 +314,30 @@ func convertToGetResponseList(dbModels []sqlc.Model) []GetResponse {
|
||||
}
|
||||
return responses
|
||||
}
|
||||
|
||||
func isValidClientType(clientType ClientType) bool {
|
||||
switch clientType {
|
||||
case ClientTypeOpenAI,
|
||||
ClientTypeAnthropic,
|
||||
ClientTypeGoogle,
|
||||
ClientTypeBedrock,
|
||||
ClientTypeOllama,
|
||||
ClientTypeAzure,
|
||||
ClientTypeDashscope,
|
||||
ClientTypeOther:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func uuidStringFromPgUUID(value pgtype.UUID) (string, bool) {
|
||||
if !value.Valid {
|
||||
return "", false
|
||||
}
|
||||
id, err := uuid.FromBytes(value.Bytes[:])
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
return id.String(), true
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user