refactor: use Twilight AI SDK for model and provider connectivity testing (#273)

* refactor: use Twilight AI SDK for model and provider connectivity testing

Replace hand-rolled HTTP probing with the Twilight AI SDK's built-in
Provider.Test() and TestModel() methods.

- Model test now runs Provider.Test() (connectivity + auth) followed by
  TestModel() (model availability) via the SDK
- Provider test auto-detects client_type from associated models and
  creates the correct SDK provider for accurate auth header handling
- Embedding models use a dedicated /embeddings endpoint probe since
  the SDK's chat Provider doesn't cover embedding APIs
- Latency measurement now covers the full test lifecycle
- Add TestStatusModelNotSupported for models not found by the provider
- Upgrade twilight-ai to v0.3.3-0.20260321100646-43c789b701dd which
  includes fallback probing for providers without GET /models/{id}
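
For reference, the new model-test flow reduces to this two-step shape (a condensed sketch of the code in this diff; response mapping and error handling are elided):

    sdkProvider := NewSDKProvider(baseURL, apiKey, clientType, probeTimeout)

    // Step 1: provider-level check (connectivity + auth).
    providerResult := sdkProvider.Test(ctx)
    switch providerResult.Status {
    case sdk.ProviderStatusUnreachable:
        // network failure -> TestStatusError, Reachable: false
    case sdk.ProviderStatusUnhealthy:
        // reachable but rejected -> TestStatusAuthError
    }

    // Step 2: model-level check; the upgraded SDK falls back to probing
    // when the provider exposes no GET /models/{id}.
    modelResult, err := sdkProvider.TestModel(ctx, model.ModelID)
    if err == nil && !modelResult.Supported {
        // reported as TestStatusModelNotSupported
    }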

* fix: lint
Acbox Liu
2026-03-21 19:14:50 +08:00
committed by GitHub
parent 80b36f79f3
commit a7a36df705
8 changed files with 186 additions and 137 deletions
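
Before the per-file diffs, a caller-side sketch of how the four test statuses introduced here are meant to be read (the constants and TestResponse fields come from this commit; the switch itself is illustrative only):

    resp, _ := svc.Test(ctx, modelID)
    switch resp.Status {
    case TestStatusOK:                // reachable, authenticated, model available
    case TestStatusAuthError:         // reachable but credentials rejected
    case TestStatusModelNotSupported: // provider healthy, model not found/supported
    case TestStatusError:             // network or other failure; see resp.Message
    }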
@@ -63,11 +63,11 @@
       class="mt-2"
       :show-group-headers="false"
     >
-      <template #trigger="{ open, displayLabel }">
+      <template #trigger="{ open: isOpen, displayLabel }">
       <Button
         variant="outline"
         role="combobox"
-        :aria-expanded="open"
+        :aria-expanded="isOpen"
         class="w-full justify-between font-normal mt-2"
       >
         <span class="truncate">
+1 -1
@@ -9,7 +9,7 @@
 import { computed } from 'vue'
 const props = withDefaults(defineProps<{
-  status: 'success' | 'error' | 'warning' | 'idle'
+  status?: 'success' | 'error' | 'warning' | 'idle'
 }>(), {
   status: 'idle',
 })
+10
@@ -43,6 +43,7 @@ export default [
     quotes: ['error', 'single'],
     semi: ['error', 'never'],
     'vue/multi-word-component-names': 'off',
+    'vue/require-default-prop': 'off',
     '@typescript-eslint/no-unused-vars': ['error', {
       argsIgnorePattern: '^_',
       varsIgnorePattern: '^_',
@@ -50,4 +51,13 @@ export default [
     }],
   },
 },
+{
+  files: [
+    'apps/web/src/pages/chat/components/tool-call-edit.vue',
+    'apps/web/src/pages/chat/components/tool-call-write.vue',
+  ],
+  rules: {
+    'vue/no-v-html': 'off',
+  },
+},
 ]
+1 -1
@@ -26,7 +26,7 @@ require (
   github.com/larksuite/oapi-sdk-go/v3 v3.5.3
   github.com/mailgun/mailgun-go/v5 v5.14.0
   github.com/memohai/acgo v0.0.0-20260221232113-babac0d6acd7
-  github.com/memohai/twilight-ai v0.3.2
+  github.com/memohai/twilight-ai v0.3.3-0.20260321100646-43c789b701dd
   github.com/modelcontextprotocol/go-sdk v1.4.1
   github.com/opencontainers/image-spec v1.1.1
   github.com/opencontainers/runtime-spec v1.3.0
+2 -2
@@ -228,8 +228,8 @@ github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D
 github.com/mattn/go-runewidth v0.0.10/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
 github.com/memohai/acgo v0.0.0-20260221232113-babac0d6acd7 h1:beehwOQperqGWj4m4EhcPhnSZKtDiuHK/7ZMoTPaQjw=
 github.com/memohai/acgo v0.0.0-20260221232113-babac0d6acd7/go.mod h1:OvmxM7JmnXBmwJWWVqtreL3HSHSKuzPbtbhlg5MvBg0=
-github.com/memohai/twilight-ai v0.3.2 h1:0C5U9W8s/6I0T5YWoCgKLPZ/hVmM2Q65IHifNqK444Y=
-github.com/memohai/twilight-ai v0.3.2/go.mod h1:vHNoRb6/quMacMAgIp838aoiNhsZbE0bFCnRRNyRwNc=
+github.com/memohai/twilight-ai v0.3.3-0.20260321100646-43c789b701dd h1:uV7xsqYHYpEmT6xKvkOs5mHT5oEKnwV1F93ialqi78k=
+github.com/memohai/twilight-ai v0.3.3-0.20260321100646-43c789b701dd/go.mod h1:vHNoRb6/quMacMAgIp838aoiNhsZbE0bFCnRRNyRwNc=
 github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
 github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
 github.com/moby/locker v1.0.1 h1:fOXqR41zeveg4fFODix+1Ch4mj/gT0NE1XJbp/epuBg=
+140 -107
@@ -3,19 +3,26 @@ package models
 import (
   "bytes"
   "context"
+  "encoding/json"
   "fmt"
   "io"
   "net/http"
   "strings"
   "time"
+  anthropicmessages "github.com/memohai/twilight-ai/provider/anthropic/messages"
+  googlegenerative "github.com/memohai/twilight-ai/provider/google/generativeai"
+  openaicompletions "github.com/memohai/twilight-ai/provider/openai/completions"
+  openairesponses "github.com/memohai/twilight-ai/provider/openai/responses"
+  sdk "github.com/memohai/twilight-ai/sdk"
   "github.com/memohai/memoh/internal/db"
 )
 const probeTimeout = 15 * time.Second
-// Test probes a model's provider endpoint using the model's real model_id
-// and client_type to verify that the configuration is valid.
+// Test probes a model's provider endpoint using the Twilight AI SDK
+// to verify connectivity, authentication, and model availability.
 func (s *Service) Test(ctx context.Context, id string) (TestResponse, error) {
   modelID, err := db.ParseUUID(id)
   if err != nil {
@@ -34,139 +41,165 @@ func (s *Service) Test(ctx context.Context, id string) (TestResponse, error) {
   baseURL := strings.TrimRight(provider.BaseUrl, "/")
   apiKey := provider.ApiKey
+  clientType := ClientType(model.ClientType.String)
-  // Reachability check
-  reachable, reachMsg := probeReachable(ctx, baseURL)
-  if !reachable {
+  // Embedding models don't have a chat Provider in the SDK — probe
+  // the /embeddings endpoint directly.
+  if model.Type == string(ModelTypeEmbedding) {
+    return s.testEmbeddingModel(ctx, baseURL, apiKey, model.ModelID)
+  }
+  sdkProvider := NewSDKProvider(baseURL, apiKey, clientType, probeTimeout)
+  start := time.Now()
+  providerResult := sdkProvider.Test(ctx)
+  switch providerResult.Status {
+  case sdk.ProviderStatusUnreachable:
     return TestResponse{
-      Status:  TestStatusError,
-      Message: reachMsg,
+      Status:    TestStatusError,
+      Reachable: false,
+      LatencyMs: time.Since(start).Milliseconds(),
+      Message:   providerResult.Message,
     }, nil
+  case sdk.ProviderStatusUnhealthy:
+    return TestResponse{
+      Status:    TestStatusAuthError,
+      Reachable: true,
+      LatencyMs: time.Since(start).Milliseconds(),
+      Message:   providerResult.Message,
+    }, nil
   }
-  // Select probe by client type (chat) or model type (embedding)
-  var result probeResult
-  if model.Type == string(ModelTypeEmbedding) {
-    result = probeEmbedding(ctx, baseURL, apiKey, model.ModelID)
-  } else {
-    result = probeChatModel(ctx, baseURL, apiKey, model.ModelID, ClientType(model.ClientType.String))
+  modelResult, err := sdkProvider.TestModel(ctx, model.ModelID)
+  latency := time.Since(start).Milliseconds()
+  if err != nil {
+    return TestResponse{
+      Status:    TestStatusError,
+      Reachable: true,
+      LatencyMs: latency,
+      Message:   err.Error(),
+    }, nil
   }
+  if !modelResult.Supported {
+    return TestResponse{
+      Status:    TestStatusModelNotSupported,
+      Reachable: true,
+      LatencyMs: latency,
+      Message:   modelResult.Message,
+    }, nil
+  }
   return TestResponse{
-    Status:    classifyProbe(result.statusCode),
+    Status:    TestStatusOK,
     Reachable: true,
-    LatencyMs: result.latencyMs,
-    Message:   result.message,
+    LatencyMs: latency,
+    Message:   modelResult.Message,
   }, nil
 }
-type probeResult struct {
-  statusCode int
-  latencyMs  int64
-  message    string
-}
+// testEmbeddingModel probes an embedding model by sending a minimal
+// request to the /embeddings endpoint.
+func (*Service) testEmbeddingModel(ctx context.Context, baseURL, apiKey, modelID string) (TestResponse, error) {
+  body, _ := json.Marshal(map[string]any{"model": modelID, "input": "hello"})
-func probeChatModel(ctx context.Context, baseURL, apiKey, modelID string, clientType ClientType) probeResult {
-  switch clientType {
-  case ClientTypeOpenAIResponses:
-    body := fmt.Sprintf(`{"model":%q,"input":"hi","max_output_tokens":1}`, modelID)
-    return doProbe(ctx, http.MethodPost, baseURL+"/responses", openAIHeaders(apiKey), body)
-  case ClientTypeOpenAICompletions:
-    body := fmt.Sprintf(`{"model":%q,"messages":[{"role":"user","content":"hi"}],"max_tokens":1}`, modelID)
-    return doProbe(ctx, http.MethodPost, baseURL+"/chat/completions", openAIHeaders(apiKey), body)
-  case ClientTypeAnthropicMessages:
-    body := fmt.Sprintf(`{"model":%q,"messages":[{"role":"user","content":"hi"}],"max_tokens":1}`, modelID)
-    return doProbe(ctx, http.MethodPost, baseURL+"/messages", map[string]string{
-      "x-api-key":         apiKey,
-      "anthropic-version": "2023-06-01",
-      "Content-Type":      "application/json",
-    }, body)
-  case ClientTypeGoogleGenerativeAI:
-    body := `{"contents":[{"parts":[{"text":"hi"}]}],"generationConfig":{"maxOutputTokens":1}}`
-    url := fmt.Sprintf("%s/models/%s:generateContent", baseURL, modelID)
-    return doProbe(ctx, http.MethodPost, url, map[string]string{
-      "x-goog-api-key": apiKey,
-      "Content-Type":   "application/json",
-    }, body)
-  default:
-    // Fallback: treat as OpenAI completions compatible
-    body := fmt.Sprintf(`{"model":%q,"messages":[{"role":"user","content":"hi"}],"max_tokens":1}`, modelID)
-    return doProbe(ctx, http.MethodPost, baseURL+"/chat/completions", openAIHeaders(apiKey), body)
-  }
-}
-func probeEmbedding(ctx context.Context, baseURL, apiKey, modelID string) probeResult {
-  body := fmt.Sprintf(`{"model":%q,"input":"hello"}`, modelID)
-  return doProbe(ctx, http.MethodPost, baseURL+"/embeddings", openAIHeaders(apiKey), body)
-}
-func openAIHeaders(apiKey string) map[string]string {
-  return map[string]string{
-    "Authorization": "Bearer " + apiKey,
-    "Content-Type":  "application/json",
-  }
-}
-func probeReachable(ctx context.Context, baseURL string) (bool, string) {
   ctx, cancel := context.WithTimeout(ctx, probeTimeout)
   defer cancel()
-  req, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL, nil)
+  req, err := http.NewRequestWithContext(ctx, http.MethodPost,
+    baseURL+"/embeddings", bytes.NewReader(body))
   if err != nil {
-    return false, err.Error()
-  }
-  resp, err := http.DefaultClient.Do(req) //nolint:gosec // G704: URL is from operator-configured LLM provider base URL
-  if err != nil {
-    return false, err.Error()
-  }
-  _, _ = io.Copy(io.Discard, resp.Body)
-  _ = resp.Body.Close()
-  return true, ""
-}
-func doProbe(ctx context.Context, method, url string, headers map[string]string, body string) probeResult {
-  ctx, cancel := context.WithTimeout(ctx, probeTimeout)
-  defer cancel()
-  var bodyReader io.Reader
-  if body != "" {
-    bodyReader = bytes.NewBufferString(body)
-  }
-  req, err := http.NewRequestWithContext(ctx, method, url, bodyReader)
-  if err != nil {
-    return probeResult{message: err.Error()}
-  }
-  for k, v := range headers {
-    req.Header.Set(k, v)
+    return TestResponse{Status: TestStatusError, Message: err.Error()}, nil
   }
+  req.Header.Set("Authorization", "Bearer "+apiKey)
+  req.Header.Set("Content-Type", "application/json")
   start := time.Now()
-  resp, err := http.DefaultClient.Do(req) //nolint:gosec // G704: URL is from operator-configured LLM provider base URL
+  httpClient := &http.Client{Timeout: probeTimeout}
+  // #nosec G704 -- baseURL comes from the configured provider endpoint that this health probe is expected to test.
+  resp, err := httpClient.Do(req)
   latency := time.Since(start).Milliseconds()
   if err != nil {
-    return probeResult{latencyMs: latency, message: err.Error()}
+    return TestResponse{
+      Status:    TestStatusError,
+      Reachable: false,
+      LatencyMs: latency,
+      Message:   err.Error(),
+    }, nil
   }
   _, _ = io.Copy(io.Discard, resp.Body)
   _ = resp.Body.Close()
-  return probeResult{statusCode: resp.StatusCode, latencyMs: latency}
+  result, classifyErr := sdk.ClassifyProbeStatus(resp.StatusCode)
+  if classifyErr != nil {
+    return TestResponse{
+      Status:    TestStatusError,
+      Reachable: true,
+      LatencyMs: latency,
+      Message:   classifyErr.Error(),
+    }, nil
+  }
+  tr := TestResponse{
+    Reachable: true,
+    LatencyMs: latency,
+    Message:   result.Message,
+  }
+  if result.Supported {
+    tr.Status = TestStatusOK
+  } else {
+    tr.Status = TestStatusModelNotSupported
+  }
+  return tr, nil
 }
-func classifyProbe(statusCode int) TestStatus {
-  switch {
-  case statusCode >= 200 && statusCode <= 299:
-    return TestStatusOK
-  case statusCode == 400 || statusCode == 422 || statusCode == 429:
-    // 400/422 = endpoint works but request rejected; 429 = rate limited (model exists)
-    return TestStatusOK
-  case statusCode == 401 || statusCode == 403:
-    return TestStatusAuthError
+// NewSDKProvider creates a Twilight AI SDK Provider for the given client type.
+// It is exported so that other packages (e.g. providers) can reuse it for testing.
+func NewSDKProvider(baseURL, apiKey string, clientType ClientType, timeout time.Duration) sdk.Provider {
+  httpClient := &http.Client{Timeout: timeout}
+  switch clientType {
+  case ClientTypeOpenAIResponses:
+    opts := []openairesponses.Option{
+      openairesponses.WithAPIKey(apiKey),
+      openairesponses.WithHTTPClient(httpClient),
+    }
+    if baseURL != "" {
+      opts = append(opts, openairesponses.WithBaseURL(baseURL))
+    }
+    return openairesponses.New(opts...)
+  case ClientTypeAnthropicMessages:
+    opts := []anthropicmessages.Option{
+      anthropicmessages.WithAPIKey(apiKey),
+      anthropicmessages.WithHTTPClient(httpClient),
+    }
+    if baseURL != "" {
+      opts = append(opts, anthropicmessages.WithBaseURL(baseURL))
+    }
+    return anthropicmessages.New(opts...)
+  case ClientTypeGoogleGenerativeAI:
+    opts := []googlegenerative.Option{
+      googlegenerative.WithAPIKey(apiKey),
+      googlegenerative.WithHTTPClient(httpClient),
+    }
+    if baseURL != "" {
+      opts = append(opts, googlegenerative.WithBaseURL(baseURL))
+    }
+    return googlegenerative.New(opts...)
   default:
-    return TestStatusError
+    opts := []openaicompletions.Option{
+      openaicompletions.WithAPIKey(apiKey),
+      openaicompletions.WithHTTPClient(httpClient),
+    }
+    if baseURL != "" {
+      opts = append(opts, openaicompletions.WithBaseURL(baseURL))
+    }
+    return openaicompletions.New(opts...)
   }
 }
+4 -3
@@ -143,9 +143,10 @@ type CountResponse struct {
 type TestStatus string
 const (
-  TestStatusOK        TestStatus = "ok"
-  TestStatusAuthError TestStatus = "auth_error"
-  TestStatusError     TestStatus = "error"
+  TestStatusOK                TestStatus = "ok"
+  TestStatusAuthError         TestStatus = "auth_error"
+  TestStatusModelNotSupported TestStatus = "model_not_supported"
+  TestStatusError             TestStatus = "error"
 )
 // TestResponse is returned by POST /models/:id/test.
+26 -21
@@ -10,8 +10,12 @@ import (
   "strings"
   "time"
+  "github.com/jackc/pgx/v5/pgtype"
+  sdk "github.com/memohai/twilight-ai/sdk"
   "github.com/memohai/memoh/internal/db"
   "github.com/memohai/memoh/internal/db/sqlc"
+  "github.com/memohai/memoh/internal/models"
 )
 // Service handles provider operations.
@@ -163,7 +167,8 @@ func (s *Service) Count(ctx context.Context) (int64, error) {
 const probeTimeout = 5 * time.Second
-// Test probes the provider's base URL to check reachability.
+// Test probes the provider using the Twilight AI SDK to check
+// reachability and authentication.
 func (s *Service) Test(ctx context.Context, id string) (TestResponse, error) {
   providerID, err := db.ParseUUID(id)
   if err != nil {
@@ -177,17 +182,34 @@ func (s *Service) Test(ctx context.Context, id string) (TestResponse, error) {
   baseURL := strings.TrimRight(provider.BaseUrl, "/")
+  // Determine client type from the first model using this provider.
+  clientType := resolveProviderClientType(ctx, s.queries, providerID)
+  sdkProvider := models.NewSDKProvider(baseURL, provider.ApiKey, clientType, probeTimeout)
   start := time.Now()
-  reachable, msg := probeReachable(ctx, baseURL)
+  result := sdkProvider.Test(ctx)
   latency := time.Since(start).Milliseconds()
   return TestResponse{
-    Reachable: reachable,
+    Reachable: result.Status != sdk.ProviderStatusUnreachable,
     LatencyMs: latency,
-    Message:   msg,
+    Message:   result.Message,
   }, nil
 }
+// resolveProviderClientType looks up models associated with the provider
+// and returns the first model's client_type. Falls back to openai-completions.
+func resolveProviderClientType(ctx context.Context, q *sqlc.Queries, providerID pgtype.UUID) models.ClientType {
+  rows, err := q.ListModelsByProviderID(ctx, providerID)
+  if err == nil && len(rows) > 0 {
+    if ct := rows[0].ClientType; ct.Valid && ct.String != "" {
+      return models.ClientType(ct.String)
+    }
+  }
+  return models.ClientTypeOpenAICompletions
+}
 // FetchRemoteModels fetches models from the provider's /v1/models endpoint.
 func (s *Service) FetchRemoteModels(ctx context.Context, id string) ([]RemoteModel, error) {
   providerID, err := db.ParseUUID(id)
@@ -234,23 +256,6 @@ func (s *Service) FetchRemoteModels(ctx context.Context, id string) ([]RemoteMod
   return fetchResp.Data, nil
 }
-func probeReachable(ctx context.Context, baseURL string) (bool, string) {
-  ctx, cancel := context.WithTimeout(ctx, probeTimeout)
-  defer cancel()
-  req, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL, nil)
-  if err != nil {
-    return false, err.Error()
-  }
-  resp, err := http.DefaultClient.Do(req) //nolint:gosec // G704: URL is from operator-configured LLM provider base URL
-  if err != nil {
-    return false, err.Error()
-  }
-  _, _ = io.Copy(io.Discard, resp.Body)
-  _ = resp.Body.Close()
-  return true, ""
-}
 // toGetResponse converts a database provider to a response.
 func (s *Service) toGetResponse(provider sqlc.LlmProvider) GetResponse {
   var metadata map[string]any