Memoh/internal/models/probe.go
package models

import (
	"bytes"
	"context"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"github.com/memohai/memoh/internal/db"
)

const probeTimeout = 15 * time.Second

// Test probes a model's provider endpoint using the model's real model_id
// and client_type to verify that the configuration is valid.
func (s *Service) Test(ctx context.Context, id string) (TestResponse, error) {
	modelID, err := db.ParseUUID(id)
	if err != nil {
		return TestResponse{}, fmt.Errorf("invalid model id: %w", err)
	}
	model, err := s.queries.GetModelByID(ctx, modelID)
	if err != nil {
		return TestResponse{}, fmt.Errorf("get model: %w", err)
	}
	provider, err := s.queries.GetLlmProviderByID(ctx, model.LlmProviderID)
	if err != nil {
		return TestResponse{}, fmt.Errorf("get provider: %w", err)
	}
	baseURL := strings.TrimRight(provider.BaseUrl, "/")
	apiKey := provider.ApiKey

	// Reachability check
	reachable, reachMsg := probeReachable(ctx, baseURL)
	if !reachable {
		return TestResponse{
			Status:  TestStatusError,
			Message: reachMsg,
		}, nil
	}

	// Select probe by client type (chat) or model type (embedding)
	var result probeResult
	if model.Type == string(ModelTypeEmbedding) {
		result = probeEmbedding(ctx, baseURL, apiKey, model.ModelID)
	} else {
		result = probeChatModel(ctx, baseURL, apiKey, model.ModelID, ClientType(model.ClientType.String))
	}
	return TestResponse{
		Status:    classifyProbe(result.statusCode),
		Reachable: true,
		LatencyMs: result.latencyMs,
		Message:   result.message,
	}, nil
}
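
// probeResult captures the HTTP status code, request latency, and any
// error message produced by a single probe request.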
type probeResult struct {
	statusCode int
	latencyMs  int64
	message    string
}
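
// probeChatModel sends a minimal one-token chat request shaped for the
// model's client type, falling back to the OpenAI chat completions format
// for unknown client types.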
func probeChatModel(ctx context.Context, baseURL, apiKey, modelID string, clientType ClientType) probeResult {
	switch clientType {
	case ClientTypeOpenAIResponses:
		body := fmt.Sprintf(`{"model":%q,"input":"hi","max_output_tokens":1}`, modelID)
		return doProbe(ctx, http.MethodPost, baseURL+"/responses", openAIHeaders(apiKey), body)
	case ClientTypeOpenAICompletions:
		body := fmt.Sprintf(`{"model":%q,"messages":[{"role":"user","content":"hi"}],"max_tokens":1}`, modelID)
		return doProbe(ctx, http.MethodPost, baseURL+"/chat/completions", openAIHeaders(apiKey), body)
	case ClientTypeAnthropicMessages:
		body := fmt.Sprintf(`{"model":%q,"messages":[{"role":"user","content":"hi"}],"max_tokens":1}`, modelID)
		return doProbe(ctx, http.MethodPost, baseURL+"/messages", map[string]string{
			"x-api-key":         apiKey,
			"anthropic-version": "2023-06-01",
			"Content-Type":      "application/json",
		}, body)
	case ClientTypeGoogleGenerativeAI:
		body := `{"contents":[{"parts":[{"text":"hi"}]}],"generationConfig":{"maxOutputTokens":1}}`
		url := fmt.Sprintf("%s/models/%s:generateContent", baseURL, modelID)
		return doProbe(ctx, http.MethodPost, url, map[string]string{
			"x-goog-api-key": apiKey,
			"Content-Type":   "application/json",
		}, body)
	default:
		// Fallback: treat as OpenAI completions compatible
		body := fmt.Sprintf(`{"model":%q,"messages":[{"role":"user","content":"hi"}],"max_tokens":1}`, modelID)
		return doProbe(ctx, http.MethodPost, baseURL+"/chat/completions", openAIHeaders(apiKey), body)
	}
}
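
// probeEmbedding sends a minimal OpenAI-compatible embeddings request.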
func probeEmbedding(ctx context.Context, baseURL, apiKey, modelID string) probeResult {
	body := fmt.Sprintf(`{"model":%q,"input":"hello"}`, modelID)
	return doProbe(ctx, http.MethodPost, baseURL+"/embeddings", openAIHeaders(apiKey), body)
}
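
// openAIHeaders returns the bearer-auth headers used by OpenAI-compatible
// endpoints.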
func openAIHeaders(apiKey string) map[string]string {
	return map[string]string{
		"Authorization": "Bearer " + apiKey,
		"Content-Type":  "application/json",
	}
}
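
// probeReachable performs a plain GET against the provider base URL to check
// that the host answers at all; any HTTP response counts as reachable.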
func probeReachable(ctx context.Context, baseURL string) (bool, string) {
	ctx, cancel := context.WithTimeout(ctx, probeTimeout)
	defer cancel()
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL, nil)
	if err != nil {
		return false, err.Error()
	}
	resp, err := http.DefaultClient.Do(req) //nolint:gosec // G704: URL is from operator-configured LLM provider base URL
	if err != nil {
		return false, err.Error()
	}
	// Drain and close the body so the connection can be reused.
	_, _ = io.Copy(io.Discard, resp.Body)
	_ = resp.Body.Close()
	return true, ""
}
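
// doProbe issues a single HTTP request and reports the resulting status
// code, round-trip latency, and any transport-level error.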
func doProbe(ctx context.Context, method, url string, headers map[string]string, body string) probeResult {
	ctx, cancel := context.WithTimeout(ctx, probeTimeout)
	defer cancel()
	var bodyReader io.Reader
	if body != "" {
		bodyReader = bytes.NewBufferString(body)
	}
	req, err := http.NewRequestWithContext(ctx, method, url, bodyReader)
	if err != nil {
		return probeResult{message: err.Error()}
	}
	for k, v := range headers {
		req.Header.Set(k, v)
	}
	start := time.Now()
	resp, err := http.DefaultClient.Do(req) //nolint:gosec // G704: URL is from operator-configured LLM provider base URL
	latency := time.Since(start).Milliseconds()
	if err != nil {
		return probeResult{latencyMs: latency, message: err.Error()}
	}
	// Drain and close the body so the connection can be reused.
	_, _ = io.Copy(io.Discard, resp.Body)
	_ = resp.Body.Close()
	return probeResult{statusCode: resp.StatusCode, latencyMs: latency}
}
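
// classifyProbe maps an HTTP status code from a probe to a TestStatus.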
func classifyProbe(statusCode int) TestStatus {
	switch {
	case statusCode >= 200 && statusCode <= 299:
		return TestStatusOK
	case statusCode == 400 || statusCode == 422 || statusCode == 429:
		// 400/422 = endpoint works but request rejected; 429 = rate limited (model exists)
		return TestStatusOK
	case statusCode == 401 || statusCode == 403:
		return TestStatusAuthError
	default:
		return TestStatusError
	}
}