feat(agent): relax provider http timeouts (#348)

This commit is contained in:
Fodesu
2026-04-15 00:07:41 +08:00
committed by GitHub
parent 38ac907361
commit 8e1ed3683f
10 changed files with 95 additions and 28 deletions
+2 -2
View File
@@ -15,10 +15,10 @@ import (
// OpenAI-compatible /embeddings endpoint for all other provider types.
func NewSDKEmbeddingModel(clientType, baseURL, apiKey, modelID string, timeout time.Duration, httpClient *http.Client) *sdk.EmbeddingModel {
if timeout <= 0 {
timeout = 30 * time.Second
timeout = DefaultProviderRequestTimeout
}
if httpClient == nil {
httpClient = &http.Client{Timeout: timeout}
httpClient = NewProviderHTTPClient(timeout)
}
switch ClientType(clientType) {
+39
View File
@@ -0,0 +1,39 @@
package models
import (
"net/http"
"time"
)
const (
DefaultProviderRequestTimeout = 2 * time.Minute
DefaultProviderProbeTimeout = 60 * time.Second
DefaultProviderTLSHandshakeTimeout = 30 * time.Second
)
var defaultProviderTransport = newDefaultProviderTransport()
// NewProviderHTTPClient returns an HTTP client for model/provider traffic.
// When timeout is zero or negative, the caller is expected to enforce limits
// via context deadlines, which keeps streaming responses unbounded by the
// client's global timeout while still using the relaxed TLS handshake window.
func NewProviderHTTPClient(timeout time.Duration) *http.Client {
client := &http.Client{Transport: defaultProviderTransport}
if timeout > 0 {
client.Timeout = timeout
}
return client
}
func newDefaultProviderTransport() *http.Transport {
base, ok := http.DefaultTransport.(*http.Transport)
if !ok {
return &http.Transport{TLSHandshakeTimeout: DefaultProviderTLSHandshakeTimeout}
}
transport := base.Clone()
if transport.TLSHandshakeTimeout < DefaultProviderTLSHandshakeTimeout {
transport.TLSHandshakeTimeout = DefaultProviderTLSHandshakeTimeout
}
return transport
}
+36
View File
@@ -0,0 +1,36 @@
package models
import (
"net/http"
"testing"
"time"
)
func TestNewProviderHTTPClientWithoutTimeoutKeepsStreamingFriendlyBehavior(t *testing.T) {
client := NewProviderHTTPClient(0)
if client == nil {
t.Fatal("expected client")
}
if client.Timeout != 0 {
t.Fatalf("expected no client timeout, got %s", client.Timeout)
}
transport, ok := client.Transport.(*http.Transport)
if !ok {
t.Fatalf("expected *http.Transport, got %T", client.Transport)
}
if transport.TLSHandshakeTimeout < DefaultProviderTLSHandshakeTimeout {
t.Fatalf("expected TLS handshake timeout >= %s, got %s", DefaultProviderTLSHandshakeTimeout, transport.TLSHandshakeTimeout)
}
}
func TestNewProviderHTTPClientWithTimeout(t *testing.T) {
timeout := 45 * time.Second
client := NewProviderHTTPClient(timeout)
if client == nil {
t.Fatal("expected client")
}
if client.Timeout != timeout {
t.Fatalf("expected timeout %s, got %s", timeout, client.Timeout)
}
}
+2 -2
View File
@@ -23,7 +23,7 @@ import (
"github.com/memohai/memoh/internal/oauthctx"
)
const probeTimeout = 15 * time.Second
const probeTimeout = DefaultProviderProbeTimeout
// Test probes a model's provider endpoint using the Twilight AI SDK
// to verify connectivity, authentication, and model availability.
@@ -140,7 +140,7 @@ func (*Service) testEmbeddingModel(ctx context.Context, baseURL, apiKey, modelID
// It is exported so that other packages (e.g. providers) can reuse it for testing.
func NewSDKProvider(baseURL, apiKey, codexAccountID string, clientType ClientType, timeout time.Duration, httpClient *http.Client) sdk.Provider {
if httpClient == nil {
httpClient = &http.Client{Timeout: timeout}
httpClient = NewProviderHTTPClient(timeout)
}
switch clientType {
+10 -18
View File
@@ -39,14 +39,16 @@ var (
// NewSDKChatModel builds a Twilight AI SDK Model from the resolved model config.
func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model {
if cfg.HTTPClient == nil {
cfg.HTTPClient = NewProviderHTTPClient(0)
}
switch ClientType(cfg.ClientType) {
case ClientTypeOpenAICompletions:
opts := []openaicompletions.Option{
openaicompletions.WithAPIKey(cfg.APIKey),
}
if cfg.HTTPClient != nil {
opts = append(opts, openaicompletions.WithHTTPClient(cfg.HTTPClient))
}
opts = append(opts, openaicompletions.WithHTTPClient(cfg.HTTPClient))
if cfg.BaseURL != "" {
opts = append(opts, openaicompletions.WithBaseURL(cfg.BaseURL))
}
@@ -57,9 +59,7 @@ func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model {
opts := []openairesponses.Option{
openairesponses.WithAPIKey(cfg.APIKey),
}
if cfg.HTTPClient != nil {
opts = append(opts, openairesponses.WithHTTPClient(cfg.HTTPClient))
}
opts = append(opts, openairesponses.WithHTTPClient(cfg.HTTPClient))
if cfg.BaseURL != "" {
opts = append(opts, openairesponses.WithBaseURL(cfg.BaseURL))
}
@@ -70,9 +70,7 @@ func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model {
opts := []openaicodex.Option{
openaicodex.WithAccessToken(cfg.APIKey),
}
if cfg.HTTPClient != nil {
opts = append(opts, openaicodex.WithHTTPClient(cfg.HTTPClient))
}
opts = append(opts, openaicodex.WithHTTPClient(cfg.HTTPClient))
if cfg.CodexAccountID != "" {
opts = append(opts, openaicodex.WithAccountID(cfg.CodexAccountID))
}
@@ -85,9 +83,7 @@ func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model {
opts := []anthropicmessages.Option{
anthropicmessages.WithAPIKey(cfg.APIKey),
}
if cfg.HTTPClient != nil {
opts = append(opts, anthropicmessages.WithHTTPClient(cfg.HTTPClient))
}
opts = append(opts, anthropicmessages.WithHTTPClient(cfg.HTTPClient))
if cfg.BaseURL != "" {
opts = append(opts, anthropicmessages.WithBaseURL(cfg.BaseURL))
}
@@ -105,9 +101,7 @@ func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model {
opts := []googlegenerative.Option{
googlegenerative.WithAPIKey(cfg.APIKey),
}
if cfg.HTTPClient != nil {
opts = append(opts, googlegenerative.WithHTTPClient(cfg.HTTPClient))
}
opts = append(opts, googlegenerative.WithHTTPClient(cfg.HTTPClient))
if cfg.BaseURL != "" {
opts = append(opts, googlegenerative.WithBaseURL(cfg.BaseURL))
}
@@ -118,9 +112,7 @@ func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model {
opts := []openaicompletions.Option{
openaicompletions.WithAPIKey(cfg.APIKey),
}
if cfg.HTTPClient != nil {
opts = append(opts, openaicompletions.WithHTTPClient(cfg.HTTPClient))
}
opts = append(opts, openaicompletions.WithHTTPClient(cfg.HTTPClient))
if cfg.BaseURL != "" {
opts = append(opts, openaicompletions.WithBaseURL(cfg.BaseURL))
}