feat: models

This commit is contained in:
Acbox
2026-01-23 18:53:20 +08:00
parent 0edaba4e74
commit c332ce7749
17 changed files with 2765 additions and 39 deletions
+302
View File
@@ -0,0 +1,302 @@
package models
import (
"context"
"fmt"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgtype"
"github.com/memohai/memoh/internal/db/sqlc"
)
// Service provides CRUD operations for models
type Service struct {
queries *sqlc.Queries
}
// NewService creates a new models service
func NewService(queries *sqlc.Queries) *Service {
return &Service{
queries: queries,
}
}
// Create adds a new model to the database
func (s *Service) Create(ctx context.Context, req AddRequest) (AddResponse, error) {
model := Model(req)
if err := model.Validate(); err != nil {
return AddResponse{}, fmt.Errorf("validation failed: %w", err)
}
// Convert to sqlc params
params := sqlc.CreateModelParams{
ModelID: model.ModelID,
BaseUrl: model.BaseURL,
ApiKey: model.APIKey,
ClientType: string(model.ClientType),
Type: string(model.Type),
}
// Handle optional name field
if model.Name != "" {
params.Name = pgtype.Text{String: model.Name, Valid: true}
}
// Handle optional dimensions field (only for embedding models)
if model.Type == ModelTypeEmbedding && model.Dimensions > 0 {
params.Dimensions = pgtype.Int4{Int32: int32(model.Dimensions), Valid: true}
}
created, err := s.queries.CreateModel(ctx, params)
if err != nil {
return AddResponse{}, fmt.Errorf("failed to create model: %w", err)
}
// Convert pgtype.UUID to string
var idStr string
if created.ID.Valid {
id, err := uuid.FromBytes(created.ID.Bytes[:])
if err != nil {
return AddResponse{}, fmt.Errorf("failed to convert UUID: %w", err)
}
idStr = id.String()
}
return AddResponse{
ID: idStr,
ModelID: created.ModelID,
}, nil
}
// GetByID retrieves a model by its internal UUID
func (s *Service) GetByID(ctx context.Context, id string) (GetResponse, error) {
uuid, err := parseUUID(id)
if err != nil {
return GetResponse{}, fmt.Errorf("invalid ID: %w", err)
}
dbModel, err := s.queries.GetModelByID(ctx, uuid)
if err != nil {
return GetResponse{}, fmt.Errorf("failed to get model: %w", err)
}
return convertToGetResponse(dbModel), nil
}
// GetByModelID retrieves a model by its model_id field
func (s *Service) GetByModelID(ctx context.Context, modelID string) (GetResponse, error) {
if modelID == "" {
return GetResponse{}, fmt.Errorf("model_id is required")
}
dbModel, err := s.queries.GetModelByModelID(ctx, modelID)
if err != nil {
return GetResponse{}, fmt.Errorf("failed to get model: %w", err)
}
return convertToGetResponse(dbModel), nil
}
// List returns all models
func (s *Service) List(ctx context.Context) ([]GetResponse, error) {
dbModels, err := s.queries.ListModels(ctx)
if err != nil {
return nil, fmt.Errorf("failed to list models: %w", err)
}
return convertToGetResponseList(dbModels), nil
}
// ListByType returns models filtered by type (chat or embedding)
func (s *Service) ListByType(ctx context.Context, modelType ModelType) ([]GetResponse, error) {
if modelType != ModelTypeChat && modelType != ModelTypeEmbedding {
return nil, fmt.Errorf("invalid model type: %s", modelType)
}
dbModels, err := s.queries.ListModelsByType(ctx, string(modelType))
if err != nil {
return nil, fmt.Errorf("failed to list models by type: %w", err)
}
return convertToGetResponseList(dbModels), nil
}
// ListByClientType returns models filtered by client type
func (s *Service) ListByClientType(ctx context.Context, clientType ClientType) ([]GetResponse, error) {
if clientType != ClientTypeOpenAI && clientType != ClientTypeAnthropic && clientType != ClientTypeGoogle {
return nil, fmt.Errorf("invalid client type: %s", clientType)
}
dbModels, err := s.queries.ListModelsByClientType(ctx, string(clientType))
if err != nil {
return nil, fmt.Errorf("failed to list models by client type: %w", err)
}
return convertToGetResponseList(dbModels), nil
}
// UpdateByID updates a model by its internal UUID
func (s *Service) UpdateByID(ctx context.Context, id string, req UpdateRequest) (GetResponse, error) {
uuid, err := parseUUID(id)
if err != nil {
return GetResponse{}, fmt.Errorf("invalid ID: %w", err)
}
model := Model(req)
if err := model.Validate(); err != nil {
return GetResponse{}, fmt.Errorf("validation failed: %w", err)
}
params := sqlc.UpdateModelParams{
ID: uuid,
BaseUrl: model.BaseURL,
ApiKey: model.APIKey,
ClientType: string(model.ClientType),
Type: string(model.Type),
}
if model.Name != "" {
params.Name = pgtype.Text{String: model.Name, Valid: true}
}
if model.Type == ModelTypeEmbedding && model.Dimensions > 0 {
params.Dimensions = pgtype.Int4{Int32: int32(model.Dimensions), Valid: true}
}
updated, err := s.queries.UpdateModel(ctx, params)
if err != nil {
return GetResponse{}, fmt.Errorf("failed to update model: %w", err)
}
return convertToGetResponse(updated), nil
}
// UpdateByModelID updates a model by its model_id field
func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req UpdateRequest) (GetResponse, error) {
if modelID == "" {
return GetResponse{}, fmt.Errorf("model_id is required")
}
model := Model(req)
if err := model.Validate(); err != nil {
return GetResponse{}, fmt.Errorf("validation failed: %w", err)
}
params := sqlc.UpdateModelByModelIDParams{
ModelID: modelID,
BaseUrl: model.BaseURL,
ApiKey: model.APIKey,
ClientType: string(model.ClientType),
Type: string(model.Type),
}
if model.Name != "" {
params.Name = pgtype.Text{String: model.Name, Valid: true}
}
if model.Type == ModelTypeEmbedding && model.Dimensions > 0 {
params.Dimensions = pgtype.Int4{Int32: int32(model.Dimensions), Valid: true}
}
updated, err := s.queries.UpdateModelByModelID(ctx, params)
if err != nil {
return GetResponse{}, fmt.Errorf("failed to update model: %w", err)
}
return convertToGetResponse(updated), nil
}
// DeleteByID deletes a model by its internal UUID
func (s *Service) DeleteByID(ctx context.Context, id string) error {
uuid, err := parseUUID(id)
if err != nil {
return fmt.Errorf("invalid ID: %w", err)
}
if err := s.queries.DeleteModel(ctx, uuid); err != nil {
return fmt.Errorf("failed to delete model: %w", err)
}
return nil
}
// DeleteByModelID deletes a model by its model_id field
func (s *Service) DeleteByModelID(ctx context.Context, modelID string) error {
if modelID == "" {
return fmt.Errorf("model_id is required")
}
if err := s.queries.DeleteModelByModelID(ctx, modelID); err != nil {
return fmt.Errorf("failed to delete model: %w", err)
}
return nil
}
// Count returns the total number of models
func (s *Service) Count(ctx context.Context) (int64, error) {
count, err := s.queries.CountModels(ctx)
if err != nil {
return 0, fmt.Errorf("failed to count models: %w", err)
}
return count, nil
}
// CountByType returns the number of models of a specific type
func (s *Service) CountByType(ctx context.Context, modelType ModelType) (int64, error) {
if modelType != ModelTypeChat && modelType != ModelTypeEmbedding {
return 0, fmt.Errorf("invalid model type: %s", modelType)
}
count, err := s.queries.CountModelsByType(ctx, string(modelType))
if err != nil {
return 0, fmt.Errorf("failed to count models by type: %w", err)
}
return count, nil
}
// Helper functions
func parseUUID(id string) (pgtype.UUID, error) {
parsed, err := uuid.Parse(id)
if err != nil {
return pgtype.UUID{}, fmt.Errorf("invalid UUID format: %w", err)
}
var pgUUID pgtype.UUID
copy(pgUUID.Bytes[:], parsed[:])
pgUUID.Valid = true
return pgUUID, nil
}
func convertToGetResponse(dbModel sqlc.Model) GetResponse {
resp := GetResponse{
ModelId: dbModel.ModelID,
Model: Model{
ModelID: dbModel.ModelID,
BaseURL: dbModel.BaseUrl,
APIKey: dbModel.ApiKey,
ClientType: ClientType(dbModel.ClientType),
Type: ModelType(dbModel.Type),
},
}
if dbModel.Name.Valid {
resp.Model.Name = dbModel.Name.String
}
if dbModel.Dimensions.Valid {
resp.Model.Dimensions = int(dbModel.Dimensions.Int32)
}
return resp
}
func convertToGetResponseList(dbModels []sqlc.Model) []GetResponse {
responses := make([]GetResponse, 0, len(dbModels))
for _, dbModel := range dbModels {
responses = append(responses, convertToGetResponse(dbModel))
}
return responses
}
+297
View File
@@ -0,0 +1,297 @@
package models_test
import (
"testing"
"github.com/memohai/memoh/internal/models"
"github.com/stretchr/testify/assert"
)
// This is an example test file demonstrating how to use the models service
// Actual tests would require database setup and mocking
func ExampleService_Create() {
// Example usage - in real code, you would initialize with actual database connection
// service := models.NewService(queries)
// ctx := context.Background()
// req := models.AddRequest{
// ModelID: "gpt-4",
// Name: "GPT-4",
// BaseURL: "https://api.openai.com/v1",
// APIKey: "sk-...",
// ClientType: models.ClientTypeOpenAI,
// Type: models.ModelTypeChat,
// }
// resp, err := service.Create(ctx, req)
// if err != nil {
// // handle error
// }
// fmt.Printf("Created model with ID: %s\n", resp.ID)
}
func ExampleService_GetByModelID() {
// Example usage
// service := models.NewService(queries)
// ctx := context.Background()
// resp, err := service.GetByModelID(ctx, "gpt-4")
// if err != nil {
// // handle error
// }
// fmt.Printf("Model: %+v\n", resp.Model)
}
func ExampleService_List() {
// Example usage
// service := models.NewService(queries)
// ctx := context.Background()
// models, err := service.List(ctx)
// if err != nil {
// // handle error
// }
// for _, model := range models {
// fmt.Printf("Model ID: %s, Type: %s\n", model.ModelID, model.Type)
// }
}
func ExampleService_ListByType() {
// Example usage
// service := models.NewService(queries)
// ctx := context.Background()
// chatModels, err := service.ListByType(ctx, models.ModelTypeChat)
// if err != nil {
// // handle error
// }
// fmt.Printf("Found %d chat models\n", len(chatModels))
}
func ExampleService_UpdateByModelID() {
// Example usage
// service := models.NewService(queries)
// ctx := context.Background()
// req := models.UpdateRequest{
// ModelID: "gpt-4",
// Name: "GPT-4 Turbo",
// BaseURL: "https://api.openai.com/v1",
// APIKey: "sk-...",
// ClientType: models.ClientTypeOpenAI,
// Type: models.ModelTypeChat,
// }
// resp, err := service.UpdateByModelID(ctx, "gpt-4", req)
// if err != nil {
// // handle error
// }
// fmt.Printf("Updated model: %s\n", resp.ModelId)
}
func ExampleService_DeleteByModelID() {
// Example usage
// service := models.NewService(queries)
// ctx := context.Background()
// err := service.DeleteByModelID(ctx, "gpt-4")
// if err != nil {
// // handle error
// }
// fmt.Println("Model deleted successfully")
}
func TestModel_Validate(t *testing.T) {
tests := []struct {
name string
model models.Model
wantErr bool
}{
{
name: "valid chat model",
model: models.Model{
ModelID: "gpt-4",
Name: "GPT-4",
BaseURL: "https://api.openai.com/v1",
APIKey: "sk-test",
ClientType: models.ClientTypeOpenAI,
Type: models.ModelTypeChat,
},
wantErr: false,
},
{
name: "valid embedding model",
model: models.Model{
ModelID: "text-embedding-ada-002",
Name: "Ada Embeddings",
BaseURL: "https://api.openai.com/v1",
APIKey: "sk-test",
ClientType: models.ClientTypeOpenAI,
Type: models.ModelTypeEmbedding,
Dimensions: 1536,
},
wantErr: false,
},
{
name: "missing model_id",
model: models.Model{
BaseURL: "https://api.openai.com/v1",
APIKey: "sk-test",
ClientType: models.ClientTypeOpenAI,
Type: models.ModelTypeChat,
},
wantErr: true,
},
{
name: "missing base_url",
model: models.Model{
ModelID: "gpt-4",
APIKey: "sk-test",
ClientType: models.ClientTypeOpenAI,
Type: models.ModelTypeChat,
},
wantErr: true,
},
{
name: "missing api_key",
model: models.Model{
ModelID: "gpt-4",
BaseURL: "https://api.openai.com/v1",
ClientType: models.ClientTypeOpenAI,
Type: models.ModelTypeChat,
},
wantErr: true,
},
{
name: "invalid client type",
model: models.Model{
ModelID: "gpt-4",
BaseURL: "https://api.openai.com/v1",
APIKey: "sk-test",
ClientType: "invalid",
Type: models.ModelTypeChat,
},
wantErr: true,
},
{
name: "invalid model type",
model: models.Model{
ModelID: "gpt-4",
BaseURL: "https://api.openai.com/v1",
APIKey: "sk-test",
ClientType: models.ClientTypeOpenAI,
Type: "invalid",
},
wantErr: true,
},
{
name: "embedding model missing dimensions",
model: models.Model{
ModelID: "text-embedding-ada-002",
BaseURL: "https://api.openai.com/v1",
APIKey: "sk-test",
ClientType: models.ClientTypeOpenAI,
Type: models.ModelTypeEmbedding,
Dimensions: 0,
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.model.Validate()
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
func TestModelTypes(t *testing.T) {
t.Run("ModelType constants", func(t *testing.T) {
assert.Equal(t, models.ModelType("chat"), models.ModelTypeChat)
assert.Equal(t, models.ModelType("embedding"), models.ModelTypeEmbedding)
})
t.Run("ClientType constants", func(t *testing.T) {
assert.Equal(t, models.ClientType("openai"), models.ClientTypeOpenAI)
assert.Equal(t, models.ClientType("anthropic"), models.ClientTypeAnthropic)
assert.Equal(t, models.ClientType("google"), models.ClientTypeGoogle)
})
}
// Integration test example (requires actual database)
// func TestService_Integration(t *testing.T) {
// if testing.Short() {
// t.Skip("Skipping integration test")
// }
//
// ctx := context.Background()
//
// // Setup database connection
// pool, err := db.Open(ctx, config.PostgresConfig{
// Host: "localhost",
// Port: 5432,
// User: "test",
// Password: "test",
// Database: "test_db",
// SSLMode: "disable",
// })
// require.NoError(t, err)
// defer pool.Close()
//
// queries := sqlc.New(pool)
// service := models.NewService(queries)
//
// // Test Create
// createReq := models.AddRequest{
// ModelID: "test-gpt-4",
// Name: "Test GPT-4",
// BaseURL: "https://api.openai.com/v1",
// APIKey: "sk-test",
// ClientType: models.ClientTypeOpenAI,
// Type: models.ModelTypeChat,
// }
// createResp, err := service.Create(ctx, createReq)
// require.NoError(t, err)
// assert.NotEmpty(t, createResp.ID)
// assert.Equal(t, "test-gpt-4", createResp.ModelID)
//
// // Test GetByModelID
// getResp, err := service.GetByModelID(ctx, "test-gpt-4")
// require.NoError(t, err)
// assert.Equal(t, "test-gpt-4", getResp.ModelID)
// assert.Equal(t, "Test GPT-4", getResp.Name)
//
// // Test List
// models, err := service.List(ctx)
// require.NoError(t, err)
// assert.NotEmpty(t, models)
//
// // Test Update
// updateReq := models.UpdateRequest{
// ModelID: "test-gpt-4",
// Name: "Updated GPT-4",
// BaseURL: "https://api.openai.com/v1",
// APIKey: "sk-test-updated",
// ClientType: models.ClientTypeOpenAI,
// Type: models.ModelTypeChat,
// }
// updateResp, err := service.UpdateByModelID(ctx, "test-gpt-4", updateReq)
// require.NoError(t, err)
// assert.Equal(t, "Updated GPT-4", updateResp.Name)
//
// // Test Count
// count, err := service.Count(ctx)
// require.NoError(t, err)
// assert.Greater(t, count, int64(0))
//
// // Test Delete
// err = service.DeleteByModelID(ctx, "test-gpt-4")
// require.NoError(t, err)
// }
+91
View File
@@ -0,0 +1,91 @@
package models
import (
"errors"
)
type ModelType string
const (
ModelTypeChat = "chat"
ModelTypeEmbedding = "embedding"
)
type ClientType string
const (
ClientTypeOpenAI ClientType = "openai"
ClientTypeAnthropic ClientType = "anthropic"
ClientTypeGoogle ClientType = "google"
)
type Model struct {
ModelID string `json:"model_id"`
Name string `json:"name"`
BaseURL string `json:"base_url"`
APIKey string `json:"api_key"`
ClientType ClientType `json:"client_type"`
Type ModelType `json:"type"`
Dimensions int `json:"dimensions"`
}
func (m *Model) Validate() error {
if m.ModelID == "" {
return errors.New("model ID is required")
}
if m.BaseURL == "" {
return errors.New("base URL is required")
}
if m.APIKey == "" {
return errors.New("API key is required")
}
if m.ClientType == "" {
return errors.New("client type is required")
}
if m.Type != ModelTypeChat && m.Type != ModelTypeEmbedding {
return errors.New("invalid model type")
}
if m.ClientType != ClientTypeOpenAI && m.ClientType != ClientTypeAnthropic && m.ClientType != ClientTypeGoogle {
return errors.New("invalid client type")
}
if m.Type == ModelTypeEmbedding && m.Dimensions <= 0 {
return errors.New("dimensions must be greater than 0")
}
return nil
}
type AddRequest Model
type AddResponse struct {
ID string `json:"id"`
ModelID string `json:"model_id"`
}
type GetRequest struct {
ID string `json:"id"`
}
type GetResponse struct {
ModelId string `json:"model_id"`
Model
}
type UpdateRequest Model
type ListRequest struct {
Type ModelType `json:"type,omitempty"`
ClientType ClientType `json:"client_type,omitempty"`
}
type DeleteRequest struct {
ID string `json:"id,omitempty"`
ModelID string `json:"model_id,omitempty"`
}
type DeleteResponse struct {
Message string `json:"message"`
}
type CountResponse struct {
Count int64 `json:"count"`
}