mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
f9f968f13f
* feat(models): add per-model probe testing and auto-detect in UI Move health probes from provider level to model level for precise testing with real model_id and client_type. Provider test is now a simple reachability check. Backend: - Add POST /models/:id/test endpoint that probes the model's provider using its actual model_id and client_type - Add model healthcheck checker for bot health checks (chat/memory/embedding) - Simplify provider test to reachability-only Frontend: - Auto-probe models on mount with status indicator (green/yellow/red dot + latency) - Auto-probe provider reachability on load and on provider switch - Fix missing faBolt icon registration - Manual re-probe via refresh button Closes #117 * fix(models): increase probe timeout to 15s for slow providers Some providers (e.g. DashScope) exceed the 5s probe timeout, causing false-negative "context deadline exceeded" errors. Increase per-probe timeout to 15s and healthcheck overall timeout to 30s. * fix(sdk): regenerate exports after merge conflict Resolve duplicate SDK exports introduced by merge conflict resolution so the web build can compile again while preserving new model probe endpoints.
253 lines
6.3 KiB
Go
253 lines
6.3 KiB
Go
package providers
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/memohai/memoh/internal/db"
|
|
"github.com/memohai/memoh/internal/db/sqlc"
|
|
)
|
|
|
|
// Service handles provider operations
|
|
type Service struct {
|
|
queries *sqlc.Queries
|
|
logger *slog.Logger
|
|
}
|
|
|
|
// NewService creates a new provider service
|
|
func NewService(log *slog.Logger, queries *sqlc.Queries) *Service {
|
|
return &Service{
|
|
queries: queries,
|
|
logger: log.With(slog.String("service", "providers")),
|
|
}
|
|
}
|
|
|
|
// Create creates a new LLM provider
|
|
func (s *Service) Create(ctx context.Context, req CreateRequest) (GetResponse, error) {
|
|
// Marshal metadata
|
|
metadataJSON, err := json.Marshal(req.Metadata)
|
|
if err != nil {
|
|
return GetResponse{}, fmt.Errorf("marshal metadata: %w", err)
|
|
}
|
|
|
|
// Create provider
|
|
provider, err := s.queries.CreateLlmProvider(ctx, sqlc.CreateLlmProviderParams{
|
|
Name: req.Name,
|
|
BaseUrl: req.BaseURL,
|
|
ApiKey: req.APIKey,
|
|
Metadata: metadataJSON,
|
|
})
|
|
if err != nil {
|
|
return GetResponse{}, fmt.Errorf("create provider: %w", err)
|
|
}
|
|
|
|
return s.toGetResponse(provider), nil
|
|
}
|
|
|
|
// Get retrieves a provider by ID
|
|
func (s *Service) Get(ctx context.Context, id string) (GetResponse, error) {
|
|
providerID, err := db.ParseUUID(id)
|
|
if err != nil {
|
|
return GetResponse{}, err
|
|
}
|
|
|
|
provider, err := s.queries.GetLlmProviderByID(ctx, providerID)
|
|
if err != nil {
|
|
return GetResponse{}, fmt.Errorf("get provider: %w", err)
|
|
}
|
|
|
|
return s.toGetResponse(provider), nil
|
|
}
|
|
|
|
// GetByName retrieves a provider by name
|
|
func (s *Service) GetByName(ctx context.Context, name string) (GetResponse, error) {
|
|
provider, err := s.queries.GetLlmProviderByName(ctx, name)
|
|
if err != nil {
|
|
return GetResponse{}, fmt.Errorf("get provider by name: %w", err)
|
|
}
|
|
|
|
return s.toGetResponse(provider), nil
|
|
}
|
|
|
|
// List retrieves all providers
|
|
func (s *Service) List(ctx context.Context) ([]GetResponse, error) {
|
|
providers, err := s.queries.ListLlmProviders(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list providers: %w", err)
|
|
}
|
|
|
|
results := make([]GetResponse, 0, len(providers))
|
|
for _, p := range providers {
|
|
results = append(results, s.toGetResponse(p))
|
|
}
|
|
return results, nil
|
|
}
|
|
|
|
// Update updates an existing provider
|
|
func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (GetResponse, error) {
|
|
providerID, err := db.ParseUUID(id)
|
|
if err != nil {
|
|
return GetResponse{}, err
|
|
}
|
|
|
|
// Get existing provider
|
|
existing, err := s.queries.GetLlmProviderByID(ctx, providerID)
|
|
if err != nil {
|
|
return GetResponse{}, fmt.Errorf("get provider: %w", err)
|
|
}
|
|
|
|
// Apply updates
|
|
name := existing.Name
|
|
if req.Name != nil {
|
|
name = *req.Name
|
|
}
|
|
|
|
baseURL := existing.BaseUrl
|
|
if req.BaseURL != nil {
|
|
baseURL = *req.BaseURL
|
|
}
|
|
|
|
apiKey := resolveUpdatedAPIKey(existing.ApiKey, req.APIKey)
|
|
|
|
metadata := existing.Metadata
|
|
if req.Metadata != nil {
|
|
metadataJSON, err := json.Marshal(req.Metadata)
|
|
if err != nil {
|
|
return GetResponse{}, fmt.Errorf("marshal metadata: %w", err)
|
|
}
|
|
metadata = metadataJSON
|
|
}
|
|
|
|
// Update provider
|
|
updated, err := s.queries.UpdateLlmProvider(ctx, sqlc.UpdateLlmProviderParams{
|
|
ID: providerID,
|
|
Name: name,
|
|
BaseUrl: baseURL,
|
|
ApiKey: apiKey,
|
|
Metadata: metadata,
|
|
})
|
|
if err != nil {
|
|
return GetResponse{}, fmt.Errorf("update provider: %w", err)
|
|
}
|
|
|
|
return s.toGetResponse(updated), nil
|
|
}
|
|
|
|
// Delete deletes a provider by ID
|
|
func (s *Service) Delete(ctx context.Context, id string) error {
|
|
providerID, err := db.ParseUUID(id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := s.queries.DeleteLlmProvider(ctx, providerID); err != nil {
|
|
return fmt.Errorf("delete provider: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Count returns the total count of providers
|
|
func (s *Service) Count(ctx context.Context) (int64, error) {
|
|
count, err := s.queries.CountLlmProviders(ctx)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("count providers: %w", err)
|
|
}
|
|
return count, nil
|
|
}
|
|
|
|
// probeTimeout bounds a single provider reachability probe.
const probeTimeout time.Duration = 5 * time.Second
|
|
|
|
// Test probes the provider's base URL to check reachability.
|
|
func (s *Service) Test(ctx context.Context, id string) (TestResponse, error) {
|
|
providerID, err := db.ParseUUID(id)
|
|
if err != nil {
|
|
return TestResponse{}, err
|
|
}
|
|
|
|
provider, err := s.queries.GetLlmProviderByID(ctx, providerID)
|
|
if err != nil {
|
|
return TestResponse{}, fmt.Errorf("get provider: %w", err)
|
|
}
|
|
|
|
baseURL := strings.TrimRight(provider.BaseUrl, "/")
|
|
|
|
start := time.Now()
|
|
reachable, msg := probeReachable(ctx, baseURL)
|
|
latency := time.Since(start).Milliseconds()
|
|
|
|
return TestResponse{
|
|
Reachable: reachable,
|
|
LatencyMs: latency,
|
|
Message: msg,
|
|
}, nil
|
|
}
|
|
|
|
func probeReachable(ctx context.Context, baseURL string) (bool, string) {
|
|
ctx, cancel := context.WithTimeout(ctx, probeTimeout)
|
|
defer cancel()
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL, nil)
|
|
if err != nil {
|
|
return false, err.Error()
|
|
}
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
return false, err.Error()
|
|
}
|
|
io.Copy(io.Discard, resp.Body)
|
|
resp.Body.Close()
|
|
return true, ""
|
|
}
|
|
|
|
// toGetResponse converts a database provider to a response
|
|
func (s *Service) toGetResponse(provider sqlc.LlmProvider) GetResponse {
|
|
var metadata map[string]any
|
|
if len(provider.Metadata) > 0 {
|
|
if err := json.Unmarshal(provider.Metadata, &metadata); err != nil {
|
|
slog.Warn("provider metadata unmarshal failed", slog.String("id", provider.ID.String()), slog.Any("error", err))
|
|
}
|
|
}
|
|
|
|
// Mask API key (show only first 8 characters)
|
|
maskedAPIKey := maskAPIKey(provider.ApiKey)
|
|
|
|
return GetResponse{
|
|
ID: provider.ID.String(),
|
|
Name: provider.Name,
|
|
BaseURL: provider.BaseUrl,
|
|
APIKey: maskedAPIKey,
|
|
Metadata: metadata,
|
|
CreatedAt: provider.CreatedAt.Time,
|
|
UpdatedAt: provider.UpdatedAt.Time,
|
|
}
|
|
}
|
|
|
|
// maskAPIKey masks an API key for display.
//
// It keeps the first 8 characters (runes) and replaces every remaining
// character with '*'. Keys of 8 characters or fewer are masked completely,
// and an empty key stays empty.
//
// Counting runes instead of bytes keeps the result valid UTF-8. Cutting at
// byte 8 could split a multi-byte character. If the response is then encoded
// with encoding/json, the broken bytes are replaced by U+FFFD. A client that
// echoes the masked value back would then no longer match it in
// resolveUpdatedAPIKey. Output for ASCII keys is unchanged.
func maskAPIKey(apiKey string) string {
	const visible = 8
	if apiKey == "" {
		return ""
	}
	runes := []rune(apiKey)
	if len(runes) <= visible {
		return strings.Repeat("*", len(runes))
	}
	return string(runes[:visible]) + strings.Repeat("*", len(runes)-visible)
}
|
|
|
|
// resolveUpdatedAPIKey keeps the original key when the request value matches the masked version.
|
|
// This prevents masked placeholder values from overwriting the real stored credential.
|
|
func resolveUpdatedAPIKey(existing string, updated *string) string {
|
|
if updated == nil {
|
|
return existing
|
|
}
|
|
if *updated == maskAPIKey(existing) {
|
|
return existing
|
|
}
|
|
return *updated
|
|
}
|