mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
feat: move default model into user settings
This commit is contained in:
@@ -1,57 +0,0 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/memohai/memoh/internal/db/sqlc"
|
||||
)
|
||||
|
||||
// SelectMemoryModel selects a chat model for memory operations.
|
||||
func SelectMemoryModel(ctx context.Context, modelsService *Service, queries *sqlc.Queries) (GetResponse, sqlc.LlmProvider, error) {
|
||||
// First try to get the memory-enabled model.
|
||||
memoryModel, err := modelsService.GetByEnableAs(ctx, EnableAsMemory)
|
||||
if err == nil {
|
||||
provider, err := FetchProviderByID(ctx, queries, memoryModel.LlmProviderID)
|
||||
if err != nil {
|
||||
return GetResponse{}, sqlc.LlmProvider{}, err
|
||||
}
|
||||
return memoryModel, provider, nil
|
||||
}
|
||||
|
||||
// Fallback to chat model.
|
||||
chatModel, err := modelsService.GetByEnableAs(ctx, EnableAsChat)
|
||||
if err == nil {
|
||||
provider, err := FetchProviderByID(ctx, queries, chatModel.LlmProviderID)
|
||||
if err != nil {
|
||||
return GetResponse{}, sqlc.LlmProvider{}, err
|
||||
}
|
||||
return chatModel, provider, nil
|
||||
}
|
||||
|
||||
// If no enabled models, try to find any chat model.
|
||||
candidates, err := modelsService.ListByType(ctx, ModelTypeChat)
|
||||
if err != nil || len(candidates) == 0 {
|
||||
return GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("no chat models available for memory operations")
|
||||
}
|
||||
|
||||
selected := candidates[0]
|
||||
provider, err := FetchProviderByID(ctx, queries, selected.LlmProviderID)
|
||||
if err != nil {
|
||||
return GetResponse{}, sqlc.LlmProvider{}, err
|
||||
}
|
||||
return selected, provider, nil
|
||||
}
|
||||
|
||||
// FetchProviderByID fetches a provider by ID.
|
||||
func FetchProviderByID(ctx context.Context, queries *sqlc.Queries, providerID string) (sqlc.LlmProvider, error) {
|
||||
if strings.TrimSpace(providerID) == "" {
|
||||
return sqlc.LlmProvider{}, fmt.Errorf("provider id missing")
|
||||
}
|
||||
parsed, err := parseUUID(providerID)
|
||||
if err != nil {
|
||||
return sqlc.LlmProvider{}, err
|
||||
}
|
||||
return queries.GetLlmProviderByID(ctx, parsed)
|
||||
}
|
||||
+30
-55
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
@@ -31,13 +32,6 @@ func (s *Service) Create(ctx context.Context, req AddRequest) (AddResponse, erro
|
||||
return AddResponse{}, fmt.Errorf("validation failed: %w", err)
|
||||
}
|
||||
|
||||
// If enable_as is set, clear any existing model with the same enable_as
|
||||
if model.EnableAs != nil {
|
||||
if err := s.queries.ClearEnableAs(ctx, pgtype.Text{String: string(*model.EnableAs), Valid: true}); err != nil {
|
||||
return AddResponse{}, fmt.Errorf("failed to clear existing enable_as: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Convert to sqlc params
|
||||
llmProviderID, err := parseUUID(model.LlmProviderID)
|
||||
if err != nil {
|
||||
@@ -61,11 +55,6 @@ func (s *Service) Create(ctx context.Context, req AddRequest) (AddResponse, erro
|
||||
params.Dimensions = pgtype.Int4{Int32: int32(model.Dimensions), Valid: true}
|
||||
}
|
||||
|
||||
// Handle optional enable_as field
|
||||
if model.EnableAs != nil {
|
||||
params.EnableAs = pgtype.Text{String: string(*model.EnableAs), Valid: true}
|
||||
}
|
||||
|
||||
created, err := s.queries.CreateModel(ctx, params)
|
||||
if err != nil {
|
||||
return AddResponse{}, fmt.Errorf("failed to create model: %w", err)
|
||||
@@ -166,13 +155,6 @@ func (s *Service) UpdateByID(ctx context.Context, id string, req UpdateRequest)
|
||||
return GetResponse{}, fmt.Errorf("validation failed: %w", err)
|
||||
}
|
||||
|
||||
// If enable_as is being set, clear any existing model with the same enable_as
|
||||
if model.EnableAs != nil {
|
||||
if err := s.queries.ClearEnableAs(ctx, pgtype.Text{String: string(*model.EnableAs), Valid: true}); err != nil {
|
||||
return GetResponse{}, fmt.Errorf("failed to clear existing enable_as: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
params := sqlc.UpdateModelParams{
|
||||
ID: uuid,
|
||||
IsMultimodal: model.IsMultimodal,
|
||||
@@ -193,11 +175,6 @@ func (s *Service) UpdateByID(ctx context.Context, id string, req UpdateRequest)
|
||||
params.Dimensions = pgtype.Int4{Int32: int32(model.Dimensions), Valid: true}
|
||||
}
|
||||
|
||||
// Handle optional enable_as field
|
||||
if model.EnableAs != nil {
|
||||
params.EnableAs = pgtype.Text{String: string(*model.EnableAs), Valid: true}
|
||||
}
|
||||
|
||||
updated, err := s.queries.UpdateModel(ctx, params)
|
||||
if err != nil {
|
||||
return GetResponse{}, fmt.Errorf("failed to update model: %w", err)
|
||||
@@ -217,13 +194,6 @@ func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req Updat
|
||||
return GetResponse{}, fmt.Errorf("validation failed: %w", err)
|
||||
}
|
||||
|
||||
// If enable_as is being set, clear any existing model with the same enable_as
|
||||
if model.EnableAs != nil {
|
||||
if err := s.queries.ClearEnableAs(ctx, pgtype.Text{String: string(*model.EnableAs), Valid: true}); err != nil {
|
||||
return GetResponse{}, fmt.Errorf("failed to clear existing enable_as: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
params := sqlc.UpdateModelByModelIDParams{
|
||||
ModelID: modelID,
|
||||
IsMultimodal: model.IsMultimodal,
|
||||
@@ -244,11 +214,6 @@ func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req Updat
|
||||
params.Dimensions = pgtype.Int4{Int32: int32(model.Dimensions), Valid: true}
|
||||
}
|
||||
|
||||
// Handle optional enable_as field
|
||||
if model.EnableAs != nil {
|
||||
params.EnableAs = pgtype.Text{String: string(*model.EnableAs), Valid: true}
|
||||
}
|
||||
|
||||
updated, err := s.queries.UpdateModelByModelID(ctx, params)
|
||||
if err != nil {
|
||||
return GetResponse{}, fmt.Errorf("failed to update model: %w", err)
|
||||
@@ -306,20 +271,6 @@ func (s *Service) CountByType(ctx context.Context, modelType ModelType) (int64,
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// GetByEnableAs retrieves the model that has the specified enable_as value
|
||||
func (s *Service) GetByEnableAs(ctx context.Context, enableAs EnableAs) (GetResponse, error) {
|
||||
if enableAs != EnableAsChat && enableAs != EnableAsMemory && enableAs != EnableAsEmbedding {
|
||||
return GetResponse{}, fmt.Errorf("invalid enable_as value: %s", enableAs)
|
||||
}
|
||||
|
||||
dbModel, err := s.queries.GetModelByEnableAs(ctx, pgtype.Text{String: string(enableAs), Valid: true})
|
||||
if err != nil {
|
||||
return GetResponse{}, fmt.Errorf("failed to get model by enable_as: %w", err)
|
||||
}
|
||||
|
||||
return convertToGetResponse(dbModel), nil
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func parseUUID(id string) (pgtype.UUID, error) {
|
||||
@@ -357,11 +308,6 @@ func convertToGetResponse(dbModel sqlc.Model) GetResponse {
|
||||
resp.Model.Dimensions = int(dbModel.Dimensions.Int32)
|
||||
}
|
||||
|
||||
if dbModel.EnableAs.Valid {
|
||||
enableAs := EnableAs(dbModel.EnableAs.String)
|
||||
resp.Model.EnableAs = &enableAs
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
@@ -399,3 +345,32 @@ func uuidStringFromPgUUID(value pgtype.UUID) (string, bool) {
|
||||
}
|
||||
return id.String(), true
|
||||
}
|
||||
|
||||
// SelectMemoryModel selects a chat model for memory operations.
|
||||
func SelectMemoryModel(ctx context.Context, modelsService *Service, queries *sqlc.Queries) (GetResponse, sqlc.LlmProvider, error) {
|
||||
if modelsService == nil {
|
||||
return GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("models service not configured")
|
||||
}
|
||||
candidates, err := modelsService.ListByType(ctx, ModelTypeChat)
|
||||
if err != nil || len(candidates) == 0 {
|
||||
return GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("no chat models available for memory operations")
|
||||
}
|
||||
selected := candidates[0]
|
||||
provider, err := FetchProviderByID(ctx, queries, selected.LlmProviderID)
|
||||
if err != nil {
|
||||
return GetResponse{}, sqlc.LlmProvider{}, err
|
||||
}
|
||||
return selected, provider, nil
|
||||
}
|
||||
|
||||
// FetchProviderByID fetches a provider by ID.
|
||||
func FetchProviderByID(ctx context.Context, queries *sqlc.Queries, providerID string) (sqlc.LlmProvider, error) {
|
||||
if strings.TrimSpace(providerID) == "" {
|
||||
return sqlc.LlmProvider{}, fmt.Errorf("provider id missing")
|
||||
}
|
||||
parsed, err := parseUUID(providerID)
|
||||
if err != nil {
|
||||
return sqlc.LlmProvider{}, err
|
||||
}
|
||||
return queries.GetLlmProviderByID(ctx, parsed)
|
||||
}
|
||||
|
||||
@@ -13,14 +13,6 @@ const (
|
||||
ModelTypeEmbedding ModelType = "embedding"
|
||||
)
|
||||
|
||||
type EnableAs string
|
||||
|
||||
const (
|
||||
EnableAsChat EnableAs = "chat"
|
||||
EnableAsMemory EnableAs = "memory"
|
||||
EnableAsEmbedding EnableAs = "embedding"
|
||||
)
|
||||
|
||||
type ClientType string
|
||||
|
||||
const (
|
||||
@@ -41,7 +33,6 @@ type Model struct {
|
||||
IsMultimodal bool `json:"is_multimodal"`
|
||||
Type ModelType `json:"type"`
|
||||
Dimensions int `json:"dimensions"`
|
||||
EnableAs *EnableAs `json:"enable_as,omitempty"`
|
||||
}
|
||||
|
||||
func (m *Model) Validate() error {
|
||||
@@ -60,21 +51,7 @@ func (m *Model) Validate() error {
|
||||
if m.Type == ModelTypeEmbedding && m.Dimensions <= 0 {
|
||||
return errors.New("dimensions must be greater than 0")
|
||||
}
|
||||
|
||||
// Validate enable_as based on type
|
||||
if m.EnableAs != nil {
|
||||
switch m.Type {
|
||||
case ModelTypeEmbedding:
|
||||
if *m.EnableAs != EnableAsEmbedding {
|
||||
return errors.New("embedding models can only have enable_as set to 'embedding'")
|
||||
}
|
||||
case ModelTypeChat:
|
||||
if *m.EnableAs != EnableAsChat && *m.EnableAs != EnableAsMemory {
|
||||
return errors.New("chat models can only have enable_as set to 'chat' or 'memory'")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user