mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
180 lines
4.9 KiB
Go
180 lines
4.9 KiB
Go
package provider
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log/slog"
|
|
"strings"
|
|
|
|
"github.com/memohai/memoh/internal/db"
|
|
"github.com/memohai/memoh/internal/db/sqlc"
|
|
)
|
|
|
|
type Service struct {
|
|
queries *sqlc.Queries
|
|
logger *slog.Logger
|
|
}
|
|
|
|
func NewService(log *slog.Logger, queries *sqlc.Queries) *Service {
|
|
return &Service{
|
|
queries: queries,
|
|
logger: log.With(slog.String("service", "memory_providers")),
|
|
}
|
|
}
|
|
|
|
func (*Service) ListMeta(_ context.Context) []ProviderMeta {
|
|
return []ProviderMeta{
|
|
{
|
|
Provider: string(ProviderBuiltin),
|
|
DisplayName: "Built-in",
|
|
ConfigSchema: ProviderConfigSchema{
|
|
Fields: map[string]ProviderFieldSchema{
|
|
"memory_model_id": {
|
|
Type: "model_select",
|
|
Title: "Memory Model",
|
|
Description: "LLM model used for memory extraction and decision",
|
|
Required: false,
|
|
},
|
|
"embedding_model_id": {
|
|
Type: "model_select",
|
|
Title: "Embedding Model",
|
|
Description: "Embedding model for dense vector search",
|
|
Required: false,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
func (s *Service) Create(ctx context.Context, req ProviderCreateRequest) (ProviderGetResponse, error) {
|
|
if !isValidProviderType(req.Provider) {
|
|
return ProviderGetResponse{}, fmt.Errorf("invalid provider type: %s", req.Provider)
|
|
}
|
|
configJSON, err := json.Marshal(req.Config)
|
|
if err != nil {
|
|
return ProviderGetResponse{}, fmt.Errorf("marshal config: %w", err)
|
|
}
|
|
row, err := s.queries.CreateMemoryProvider(ctx, sqlc.CreateMemoryProviderParams{
|
|
Name: strings.TrimSpace(req.Name),
|
|
Provider: string(req.Provider),
|
|
Config: configJSON,
|
|
IsDefault: false,
|
|
})
|
|
if err != nil {
|
|
return ProviderGetResponse{}, fmt.Errorf("create memory provider: %w", err)
|
|
}
|
|
return s.toGetResponse(row), nil
|
|
}
|
|
|
|
func (s *Service) Get(ctx context.Context, id string) (ProviderGetResponse, error) {
|
|
pgID, err := db.ParseUUID(id)
|
|
if err != nil {
|
|
return ProviderGetResponse{}, err
|
|
}
|
|
row, err := s.queries.GetMemoryProviderByID(ctx, pgID)
|
|
if err != nil {
|
|
return ProviderGetResponse{}, fmt.Errorf("get memory provider: %w", err)
|
|
}
|
|
return s.toGetResponse(row), nil
|
|
}
|
|
|
|
func (s *Service) List(ctx context.Context) ([]ProviderGetResponse, error) {
|
|
rows, err := s.queries.ListMemoryProviders(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list memory providers: %w", err)
|
|
}
|
|
items := make([]ProviderGetResponse, 0, len(rows))
|
|
for _, row := range rows {
|
|
items = append(items, s.toGetResponse(row))
|
|
}
|
|
return items, nil
|
|
}
|
|
|
|
func (s *Service) Update(ctx context.Context, id string, req ProviderUpdateRequest) (ProviderGetResponse, error) {
|
|
pgID, err := db.ParseUUID(id)
|
|
if err != nil {
|
|
return ProviderGetResponse{}, err
|
|
}
|
|
current, err := s.queries.GetMemoryProviderByID(ctx, pgID)
|
|
if err != nil {
|
|
return ProviderGetResponse{}, fmt.Errorf("get memory provider: %w", err)
|
|
}
|
|
name := current.Name
|
|
if req.Name != nil {
|
|
name = strings.TrimSpace(*req.Name)
|
|
}
|
|
config := current.Config
|
|
if req.Config != nil {
|
|
configJSON, marshalErr := json.Marshal(req.Config)
|
|
if marshalErr != nil {
|
|
return ProviderGetResponse{}, fmt.Errorf("marshal config: %w", marshalErr)
|
|
}
|
|
config = configJSON
|
|
}
|
|
updated, err := s.queries.UpdateMemoryProvider(ctx, sqlc.UpdateMemoryProviderParams{
|
|
ID: pgID,
|
|
Name: name,
|
|
Config: config,
|
|
})
|
|
if err != nil {
|
|
return ProviderGetResponse{}, fmt.Errorf("update memory provider: %w", err)
|
|
}
|
|
return s.toGetResponse(updated), nil
|
|
}
|
|
|
|
func (s *Service) Delete(ctx context.Context, id string) error {
|
|
pgID, err := db.ParseUUID(id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return s.queries.DeleteMemoryProvider(ctx, pgID)
|
|
}
|
|
|
|
// EnsureDefault creates a default builtin provider if none exists.
|
|
func (s *Service) EnsureDefault(ctx context.Context) (ProviderGetResponse, error) {
|
|
row, err := s.queries.GetDefaultMemoryProvider(ctx)
|
|
if err == nil {
|
|
return s.toGetResponse(row), nil
|
|
}
|
|
configJSON, _ := json.Marshal(map[string]any{})
|
|
created, err := s.queries.CreateMemoryProvider(ctx, sqlc.CreateMemoryProviderParams{
|
|
Name: "Built-in Memory",
|
|
Provider: string(ProviderBuiltin),
|
|
Config: configJSON,
|
|
IsDefault: true,
|
|
})
|
|
if err != nil {
|
|
return ProviderGetResponse{}, fmt.Errorf("create default memory provider: %w", err)
|
|
}
|
|
return s.toGetResponse(created), nil
|
|
}
|
|
|
|
func (s *Service) toGetResponse(row sqlc.MemoryProvider) ProviderGetResponse {
|
|
var cfg map[string]any
|
|
if len(row.Config) > 0 {
|
|
if err := json.Unmarshal(row.Config, &cfg); err != nil {
|
|
s.logger.Warn("memory provider config unmarshal failed", slog.String("id", row.ID.String()), slog.Any("error", err))
|
|
}
|
|
}
|
|
return ProviderGetResponse{
|
|
ID: row.ID.String(),
|
|
Name: row.Name,
|
|
Provider: row.Provider,
|
|
Config: cfg,
|
|
IsDefault: row.IsDefault,
|
|
CreatedAt: row.CreatedAt.Time,
|
|
UpdatedAt: row.UpdatedAt.Time,
|
|
}
|
|
}
|
|
|
|
func isValidProviderType(t ProviderType) bool {
|
|
switch t {
|
|
case ProviderBuiltin:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|