mirror of https://github.com/memohai/Memoh.git (synced 2026-04-25 07:00:48 +09:00)
refactor: use Twilight AI SDK for model and provider connectivity testing (#273)
* refactor: use Twilight AI SDK for model and provider connectivity testing
Replace hand-rolled HTTP probing with the Twilight AI SDK's built-in
Provider.Test() and TestModel() methods.
- Model test now runs Provider.Test() (connectivity + auth) followed by
TestModel() (model availability) via the SDK
- Provider test auto-detects client_type from associated models and
creates the correct SDK provider for accurate auth header handling
- Embedding models use a dedicated /embeddings endpoint probe since
the SDK's chat Provider doesn't cover embedding APIs
- Latency measurement now covers the full test lifecycle
- Add TestStatusModelNotSupported for models not found by the provider
- Upgrade twilight-ai to v0.3.3-0.20260321100646-43c789b701dd which
includes fallback probing for providers without GET /models/{id}
* fix: lint
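
Below, a minimal end-to-end sketch of the two-step check this commit introduces. The SDK surface (NewSDKProvider, Test, TestModel, and the status constants) is taken from the diff that follows; the standalone checkModel wrapper, endpoint, key, and model ID are illustrative placeholders, and importing the repo-internal models package assumes the code lives inside this repository.

```go
package main

import (
	"context"
	"fmt"
	"time"

	sdk "github.com/memohai/twilight-ai/sdk"

	"github.com/memohai/memoh/internal/models"
)

// checkModel mirrors the service flow: provider-level Test() first,
// then TestModel() for model availability. Illustrative wrapper only.
func checkModel(ctx context.Context, baseURL, apiKey, modelID string) error {
	p := models.NewSDKProvider(baseURL, apiKey, models.ClientTypeOpenAICompletions, 15*time.Second)

	// Step 1: connectivity + auth against the provider itself.
	if res := p.Test(ctx); res.Status == sdk.ProviderStatusUnreachable ||
		res.Status == sdk.ProviderStatusUnhealthy {
		return fmt.Errorf("provider check failed: %s", res.Message)
	}

	// Step 2: model availability. The pinned SDK version falls back to a
	// probe request for providers without GET /models/{id}.
	mres, err := p.TestModel(ctx, modelID)
	if err != nil {
		return err
	}
	if !mres.Supported {
		return fmt.Errorf("model %q not supported: %s", modelID, mres.Message)
	}
	return nil
}

func main() {
	// Placeholder endpoint, key, and model ID.
	if err := checkModel(context.Background(), "https://api.openai.com/v1", "sk-placeholder", "gpt-4o-mini"); err != nil {
		fmt.Println(err)
	}
}
```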
@@ -63,11 +63,11 @@
     class="mt-2"
     :show-group-headers="false"
   >
-    <template #trigger="{ open, displayLabel }">
+    <template #trigger="{ open: isOpen, displayLabel }">
       <Button
         variant="outline"
         role="combobox"
-        :aria-expanded="open"
+        :aria-expanded="isOpen"
         class="w-full justify-between font-normal mt-2"
       >
         <span class="truncate">
@@ -9,7 +9,7 @@
 import { computed } from 'vue'
 
 const props = withDefaults(defineProps<{
-  status: 'success' | 'error' | 'warning' | 'idle'
+  status?: 'success' | 'error' | 'warning' | 'idle'
 }>(), {
   status: 'idle',
 })
@@ -43,6 +43,7 @@ export default [
       quotes: ['error', 'single'],
       semi: ['error', 'never'],
       'vue/multi-word-component-names': 'off',
+      'vue/require-default-prop': 'off',
       '@typescript-eslint/no-unused-vars': ['error', {
         argsIgnorePattern: '^_',
         varsIgnorePattern: '^_',
@@ -50,4 +51,13 @@
       }],
     },
   },
+  {
+    files: [
+      'apps/web/src/pages/chat/components/tool-call-edit.vue',
+      'apps/web/src/pages/chat/components/tool-call-write.vue',
+    ],
+    rules: {
+      'vue/no-v-html': 'off',
+    },
+  },
 ]
@@ -26,7 +26,7 @@ require (
 	github.com/larksuite/oapi-sdk-go/v3 v3.5.3
 	github.com/mailgun/mailgun-go/v5 v5.14.0
 	github.com/memohai/acgo v0.0.0-20260221232113-babac0d6acd7
-	github.com/memohai/twilight-ai v0.3.2
+	github.com/memohai/twilight-ai v0.3.3-0.20260321100646-43c789b701dd
 	github.com/modelcontextprotocol/go-sdk v1.4.1
 	github.com/opencontainers/image-spec v1.1.1
 	github.com/opencontainers/runtime-spec v1.3.0
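Note on the version bump: v0.3.3-0.20260321100646-43c789b701dd is a Go pseudo-version, pinning commit 43c789b701dd ahead of a tagged v0.3.3 release. It is the form produced by running `go get github.com/memohai/twilight-ai@43c789b701dd`, which pulls in the fallback probing mentioned in the commit message before the fix ships in a tag.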
@@ -228,8 +228,8 @@ github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D
 github.com/mattn/go-runewidth v0.0.10/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
 github.com/memohai/acgo v0.0.0-20260221232113-babac0d6acd7 h1:beehwOQperqGWj4m4EhcPhnSZKtDiuHK/7ZMoTPaQjw=
 github.com/memohai/acgo v0.0.0-20260221232113-babac0d6acd7/go.mod h1:OvmxM7JmnXBmwJWWVqtreL3HSHSKuzPbtbhlg5MvBg0=
-github.com/memohai/twilight-ai v0.3.2 h1:0C5U9W8s/6I0T5YWoCgKLPZ/hVmM2Q65IHifNqK444Y=
-github.com/memohai/twilight-ai v0.3.2/go.mod h1:vHNoRb6/quMacMAgIp838aoiNhsZbE0bFCnRRNyRwNc=
+github.com/memohai/twilight-ai v0.3.3-0.20260321100646-43c789b701dd h1:uV7xsqYHYpEmT6xKvkOs5mHT5oEKnwV1F93ialqi78k=
+github.com/memohai/twilight-ai v0.3.3-0.20260321100646-43c789b701dd/go.mod h1:vHNoRb6/quMacMAgIp838aoiNhsZbE0bFCnRRNyRwNc=
 github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
 github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
 github.com/moby/locker v1.0.1 h1:fOXqR41zeveg4fFODix+1Ch4mj/gT0NE1XJbp/epuBg=

+140 -107
@@ -3,19 +3,26 @@ package models
 
 import (
 	"bytes"
 	"context"
+	"encoding/json"
 	"fmt"
 	"io"
 	"net/http"
 	"strings"
 	"time"
 
+	anthropicmessages "github.com/memohai/twilight-ai/provider/anthropic/messages"
+	googlegenerative "github.com/memohai/twilight-ai/provider/google/generativeai"
+	openaicompletions "github.com/memohai/twilight-ai/provider/openai/completions"
+	openairesponses "github.com/memohai/twilight-ai/provider/openai/responses"
+	sdk "github.com/memohai/twilight-ai/sdk"
+
 	"github.com/memohai/memoh/internal/db"
 )
 
 const probeTimeout = 15 * time.Second
 
-// Test probes a model's provider endpoint using the model's real model_id
-// and client_type to verify that the configuration is valid.
+// Test probes a model's provider endpoint using the Twilight AI SDK
+// to verify connectivity, authentication, and model availability.
 func (s *Service) Test(ctx context.Context, id string) (TestResponse, error) {
 	modelID, err := db.ParseUUID(id)
 	if err != nil {
@@ -34,139 +41,165 @@ func (s *Service) Test(ctx context.Context, id string) (TestResponse, error) {
 
 	baseURL := strings.TrimRight(provider.BaseUrl, "/")
 	apiKey := provider.ApiKey
+	clientType := ClientType(model.ClientType.String)
 
-	// Reachability check
-	reachable, reachMsg := probeReachable(ctx, baseURL)
-	if !reachable {
-		return TestResponse{
-			Status:  TestStatusError,
-			Message: reachMsg,
-		}, nil
-	}
+	// Embedding models don't have a chat Provider in the SDK — probe
+	// the /embeddings endpoint directly.
+	if model.Type == string(ModelTypeEmbedding) {
+		return s.testEmbeddingModel(ctx, baseURL, apiKey, model.ModelID)
+	}
 
-	// Select probe by client type (chat) or model type (embedding)
-	var result probeResult
-	if model.Type == string(ModelTypeEmbedding) {
-		result = probeEmbedding(ctx, baseURL, apiKey, model.ModelID)
-	} else {
-		result = probeChatModel(ctx, baseURL, apiKey, model.ModelID, ClientType(model.ClientType.String))
-	}
+	sdkProvider := NewSDKProvider(baseURL, apiKey, clientType, probeTimeout)
+
+	start := time.Now()
+
+	providerResult := sdkProvider.Test(ctx)
+	switch providerResult.Status {
+	case sdk.ProviderStatusUnreachable:
+		return TestResponse{
+			Status:    TestStatusError,
+			Reachable: false,
+			LatencyMs: time.Since(start).Milliseconds(),
+			Message:   providerResult.Message,
+		}, nil
+	case sdk.ProviderStatusUnhealthy:
+		return TestResponse{
+			Status:    TestStatusAuthError,
+			Reachable: true,
+			LatencyMs: time.Since(start).Milliseconds(),
+			Message:   providerResult.Message,
+		}, nil
+	}
+
+	modelResult, err := sdkProvider.TestModel(ctx, model.ModelID)
+	latency := time.Since(start).Milliseconds()
+
+	if err != nil {
+		return TestResponse{
+			Status:    TestStatusError,
+			Reachable: true,
+			LatencyMs: latency,
+			Message:   err.Error(),
+		}, nil
+	}
+
+	if !modelResult.Supported {
+		return TestResponse{
+			Status:    TestStatusModelNotSupported,
+			Reachable: true,
+			LatencyMs: latency,
+			Message:   modelResult.Message,
+		}, nil
+	}
 
 	return TestResponse{
-		Status:    classifyProbe(result.statusCode),
+		Status:    TestStatusOK,
 		Reachable: true,
-		LatencyMs: result.latencyMs,
-		Message:   result.message,
+		LatencyMs: latency,
+		Message:   modelResult.Message,
 	}, nil
 }
 
-type probeResult struct {
-	statusCode int
-	latencyMs  int64
-	message    string
-}
-
-func probeChatModel(ctx context.Context, baseURL, apiKey, modelID string, clientType ClientType) probeResult {
-	switch clientType {
-	case ClientTypeOpenAIResponses:
-		body := fmt.Sprintf(`{"model":%q,"input":"hi","max_output_tokens":1}`, modelID)
-		return doProbe(ctx, http.MethodPost, baseURL+"/responses", openAIHeaders(apiKey), body)
-
-	case ClientTypeOpenAICompletions:
-		body := fmt.Sprintf(`{"model":%q,"messages":[{"role":"user","content":"hi"}],"max_tokens":1}`, modelID)
-		return doProbe(ctx, http.MethodPost, baseURL+"/chat/completions", openAIHeaders(apiKey), body)
-
-	case ClientTypeAnthropicMessages:
-		body := fmt.Sprintf(`{"model":%q,"messages":[{"role":"user","content":"hi"}],"max_tokens":1}`, modelID)
-		return doProbe(ctx, http.MethodPost, baseURL+"/messages", map[string]string{
-			"x-api-key":         apiKey,
-			"anthropic-version": "2023-06-01",
-			"Content-Type":      "application/json",
-		}, body)
-
-	case ClientTypeGoogleGenerativeAI:
-		body := `{"contents":[{"parts":[{"text":"hi"}]}],"generationConfig":{"maxOutputTokens":1}}`
-		url := fmt.Sprintf("%s/models/%s:generateContent", baseURL, modelID)
-		return doProbe(ctx, http.MethodPost, url, map[string]string{
-			"x-goog-api-key": apiKey,
-			"Content-Type":   "application/json",
-		}, body)
-
-	default:
-		// Fallback: treat as OpenAI completions compatible
-		body := fmt.Sprintf(`{"model":%q,"messages":[{"role":"user","content":"hi"}],"max_tokens":1}`, modelID)
-		return doProbe(ctx, http.MethodPost, baseURL+"/chat/completions", openAIHeaders(apiKey), body)
-	}
-}
-
-func probeEmbedding(ctx context.Context, baseURL, apiKey, modelID string) probeResult {
-	body := fmt.Sprintf(`{"model":%q,"input":"hello"}`, modelID)
-	return doProbe(ctx, http.MethodPost, baseURL+"/embeddings", openAIHeaders(apiKey), body)
-}
-
-func openAIHeaders(apiKey string) map[string]string {
-	return map[string]string{
-		"Authorization": "Bearer " + apiKey,
-		"Content-Type":  "application/json",
-	}
-}
-
-func probeReachable(ctx context.Context, baseURL string) (bool, string) {
+// testEmbeddingModel probes an embedding model by sending a minimal
+// request to the /embeddings endpoint.
+func (*Service) testEmbeddingModel(ctx context.Context, baseURL, apiKey, modelID string) (TestResponse, error) {
+	body, _ := json.Marshal(map[string]any{"model": modelID, "input": "hello"})
+
 	ctx, cancel := context.WithTimeout(ctx, probeTimeout)
 	defer cancel()
 
-	req, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL, nil)
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost,
+		baseURL+"/embeddings", bytes.NewReader(body))
 	if err != nil {
-		return false, err.Error()
-	}
-	resp, err := http.DefaultClient.Do(req) //nolint:gosec // G704: URL is from operator-configured LLM provider base URL
-	if err != nil {
-		return false, err.Error()
-	}
-	_, _ = io.Copy(io.Discard, resp.Body)
-	_ = resp.Body.Close()
-	return true, ""
-}
-
-func doProbe(ctx context.Context, method, url string, headers map[string]string, body string) probeResult {
-	ctx, cancel := context.WithTimeout(ctx, probeTimeout)
-	defer cancel()
-
-	var bodyReader io.Reader
-	if body != "" {
-		bodyReader = bytes.NewBufferString(body)
-	}
-
-	req, err := http.NewRequestWithContext(ctx, method, url, bodyReader)
-	if err != nil {
-		return probeResult{message: err.Error()}
-	}
-	for k, v := range headers {
-		req.Header.Set(k, v)
+		return TestResponse{Status: TestStatusError, Message: err.Error()}, nil
 	}
+	req.Header.Set("Authorization", "Bearer "+apiKey)
+	req.Header.Set("Content-Type", "application/json")
 
 	start := time.Now()
-	resp, err := http.DefaultClient.Do(req) //nolint:gosec // G704: URL is from operator-configured LLM provider base URL
+	httpClient := &http.Client{Timeout: probeTimeout}
+	// #nosec G704 -- baseURL comes from the configured provider endpoint that this health probe is expected to test.
+	resp, err := httpClient.Do(req)
 	latency := time.Since(start).Milliseconds()
 
 	if err != nil {
-		return probeResult{latencyMs: latency, message: err.Error()}
+		return TestResponse{
+			Status:    TestStatusError,
+			Reachable: false,
+			LatencyMs: latency,
+			Message:   err.Error(),
+		}, nil
 	}
 	_, _ = io.Copy(io.Discard, resp.Body)
 	_ = resp.Body.Close()
 
-	return probeResult{statusCode: resp.StatusCode, latencyMs: latency}
-}
-
-func classifyProbe(statusCode int) TestStatus {
-	switch {
-	case statusCode >= 200 && statusCode <= 299:
-		return TestStatusOK
-	case statusCode == 400 || statusCode == 422 || statusCode == 429:
-		// 400/422 = endpoint works but request rejected; 429 = rate limited (model exists)
-		return TestStatusOK
-	case statusCode == 401 || statusCode == 403:
-		return TestStatusAuthError
-	default:
-		return TestStatusError
-	}
-}
+	result, classifyErr := sdk.ClassifyProbeStatus(resp.StatusCode)
+	if classifyErr != nil {
+		return TestResponse{
+			Status:    TestStatusError,
+			Reachable: true,
+			LatencyMs: latency,
+			Message:   classifyErr.Error(),
+		}, nil
+	}
+
+	tr := TestResponse{
+		Reachable: true,
+		LatencyMs: latency,
+		Message:   result.Message,
+	}
+	if result.Supported {
+		tr.Status = TestStatusOK
+	} else {
+		tr.Status = TestStatusModelNotSupported
+	}
+	return tr, nil
+}
+
+// NewSDKProvider creates a Twilight AI SDK Provider for the given client type.
+// It is exported so that other packages (e.g. providers) can reuse it for testing.
+func NewSDKProvider(baseURL, apiKey string, clientType ClientType, timeout time.Duration) sdk.Provider {
+	httpClient := &http.Client{Timeout: timeout}
+
+	switch clientType {
+	case ClientTypeOpenAIResponses:
+		opts := []openairesponses.Option{
+			openairesponses.WithAPIKey(apiKey),
+			openairesponses.WithHTTPClient(httpClient),
+		}
+		if baseURL != "" {
+			opts = append(opts, openairesponses.WithBaseURL(baseURL))
+		}
+		return openairesponses.New(opts...)
+
+	case ClientTypeAnthropicMessages:
+		opts := []anthropicmessages.Option{
+			anthropicmessages.WithAPIKey(apiKey),
+			anthropicmessages.WithHTTPClient(httpClient),
+		}
+		if baseURL != "" {
+			opts = append(opts, anthropicmessages.WithBaseURL(baseURL))
+		}
+		return anthropicmessages.New(opts...)
+
+	case ClientTypeGoogleGenerativeAI:
+		opts := []googlegenerative.Option{
+			googlegenerative.WithAPIKey(apiKey),
+			googlegenerative.WithHTTPClient(httpClient),
+		}
+		if baseURL != "" {
+			opts = append(opts, googlegenerative.WithBaseURL(baseURL))
+		}
+		return googlegenerative.New(opts...)
+
+	default:
+		opts := []openaicompletions.Option{
+			openaicompletions.WithAPIKey(apiKey),
+			openaicompletions.WithHTTPClient(httpClient),
+		}
+		if baseURL != "" {
+			opts = append(opts, openaicompletions.WithBaseURL(baseURL))
+		}
+		return openaicompletions.New(opts...)
+	}
+}
@@ -143,9 +143,10 @@ type CountResponse struct {
 type TestStatus string
 
 const (
-	TestStatusOK        TestStatus = "ok"
-	TestStatusAuthError TestStatus = "auth_error"
-	TestStatusError     TestStatus = "error"
+	TestStatusOK                TestStatus = "ok"
+	TestStatusAuthError         TestStatus = "auth_error"
+	TestStatusModelNotSupported TestStatus = "model_not_supported"
+	TestStatusError             TestStatus = "error"
 )
 
 // TestResponse is returned by POST /models/:id/test.
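For API consumers, the only contract change is the extra status value. A small sketch of switching on the expanded set (the hintFor helper and its wording are illustrative, not part of the change):

```go
package main

import (
	"fmt"

	"github.com/memohai/memoh/internal/models"
)

// hintFor maps each TestStatus from POST /models/:id/test to a
// user-facing hint. Illustrative helper; wording is not from the PR.
func hintFor(status models.TestStatus) string {
	switch status {
	case models.TestStatusOK:
		return "model is reachable and usable"
	case models.TestStatusAuthError:
		return "provider rejected the API key"
	case models.TestStatusModelNotSupported: // new in this change
		return "provider is up but does not offer this model ID"
	default: // models.TestStatusError
		return "connectivity or request error; see message"
	}
}

func main() {
	fmt.Println(hintFor(models.TestStatusModelNotSupported))
}
```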
@@ -10,8 +10,12 @@ import (
 	"strings"
 	"time"
 
+	"github.com/jackc/pgx/v5/pgtype"
+	sdk "github.com/memohai/twilight-ai/sdk"
+
 	"github.com/memohai/memoh/internal/db"
 	"github.com/memohai/memoh/internal/db/sqlc"
+	"github.com/memohai/memoh/internal/models"
 )
 
 // Service handles provider operations.
@@ -163,7 +167,8 @@ func (s *Service) Count(ctx context.Context) (int64, error) {
 
 const probeTimeout = 5 * time.Second
 
-// Test probes the provider's base URL to check reachability.
+// Test probes the provider using the Twilight AI SDK to check
+// reachability and authentication.
 func (s *Service) Test(ctx context.Context, id string) (TestResponse, error) {
 	providerID, err := db.ParseUUID(id)
 	if err != nil {
@@ -177,17 +182,34 @@
 
 	baseURL := strings.TrimRight(provider.BaseUrl, "/")
 
+	// Determine client type from the first model using this provider.
+	clientType := resolveProviderClientType(ctx, s.queries, providerID)
+
+	sdkProvider := models.NewSDKProvider(baseURL, provider.ApiKey, clientType, probeTimeout)
+
 	start := time.Now()
-	reachable, msg := probeReachable(ctx, baseURL)
+	result := sdkProvider.Test(ctx)
 	latency := time.Since(start).Milliseconds()
 
 	return TestResponse{
-		Reachable: reachable,
+		Reachable: result.Status != sdk.ProviderStatusUnreachable,
 		LatencyMs: latency,
-		Message:   msg,
+		Message:   result.Message,
 	}, nil
 }
 
+// resolveProviderClientType looks up models associated with the provider
+// and returns the first model's client_type. Falls back to openai-completions.
+func resolveProviderClientType(ctx context.Context, q *sqlc.Queries, providerID pgtype.UUID) models.ClientType {
+	rows, err := q.ListModelsByProviderID(ctx, providerID)
+	if err == nil && len(rows) > 0 {
+		if ct := rows[0].ClientType; ct.Valid && ct.String != "" {
+			return models.ClientType(ct.String)
+		}
+	}
+	return models.ClientTypeOpenAICompletions
+}
+
 // FetchRemoteModels fetches models from the provider's /v1/models endpoint.
 func (s *Service) FetchRemoteModels(ctx context.Context, id string) ([]RemoteModel, error) {
 	providerID, err := db.ParseUUID(id)
@@ -234,23 +256,6 @@ func (s *Service) FetchRemoteModels(ctx context.Context, id string) ([]RemoteModel, error) {
 	return fetchResp.Data, nil
 }
 
-func probeReachable(ctx context.Context, baseURL string) (bool, string) {
-	ctx, cancel := context.WithTimeout(ctx, probeTimeout)
-	defer cancel()
-
-	req, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL, nil)
-	if err != nil {
-		return false, err.Error()
-	}
-	resp, err := http.DefaultClient.Do(req) //nolint:gosec // G704: URL is from operator-configured LLM provider base URL
-	if err != nil {
-		return false, err.Error()
-	}
-	_, _ = io.Copy(io.Discard, resp.Body)
-	_ = resp.Body.Close()
-	return true, ""
-}
-
 // toGetResponse converts a database provider to a response.
 func (s *Service) toGetResponse(provider sqlc.LlmProvider) GetResponse {
 	var metadata map[string]any