From 8e1ed3683fd97ffe99b30779df0b898e516bd5b7 Mon Sep 17 00:00:00 2001 From: Fodesu <75713465+Fodesu@users.noreply.github.com> Date: Wed, 15 Apr 2026 00:07:41 +0800 Subject: [PATCH] feat(agent): relax provider http timeouts (#348) --- cmd/agent/app.go | 2 +- .../healthcheck/checkers/model/checker.go | 2 +- .../memory/adapters/builtin/dense_runtime.go | 2 +- internal/memory/memllm/client.go | 2 +- internal/models/embedding.go | 4 +- internal/models/http_client.go | 39 +++++++++++++++++++ internal/models/http_client_test.go | 36 +++++++++++++++++ internal/models/probe.go | 4 +- internal/models/sdk.go | 28 +++++-------- internal/providers/service.go | 4 +- 10 files changed, 95 insertions(+), 28 deletions(-) create mode 100644 internal/models/http_client.go create mode 100644 internal/models/http_client_test.go diff --git a/cmd/agent/app.go b/cmd/agent/app.go index 7efc79ad..91c3c543 100644 --- a/cmd/agent/app.go +++ b/cmd/agent/app.go @@ -146,7 +146,7 @@ func provideMemoryLLM(modelsService *models.Service, settingsService *settings.S modelsService: modelsService, settingsService: settingsService, queries: queries, - timeout: 30 * time.Second, + timeout: models.DefaultProviderRequestTimeout, logger: log, } } diff --git a/internal/healthcheck/checkers/model/checker.go b/internal/healthcheck/checkers/model/checker.go index 02848b98..0b3c243e 100644 --- a/internal/healthcheck/checkers/model/checker.go +++ b/internal/healthcheck/checkers/model/checker.go @@ -16,7 +16,7 @@ import ( const ( checkTypeModelConnection = "model.connection" titleKeyModelConnection = "bots.checks.titles.modelConnection" - defaultTimeout = 30 * time.Second + defaultTimeout = models.DefaultProviderProbeTimeout ) // BotModelLookup fetches model IDs configured for a bot. diff --git a/internal/memory/adapters/builtin/dense_runtime.go b/internal/memory/adapters/builtin/dense_runtime.go index 11940087..257f3f66 100644 --- a/internal/memory/adapters/builtin/dense_runtime.go +++ b/internal/memory/adapters/builtin/dense_runtime.go @@ -23,7 +23,7 @@ import ( "github.com/memohai/memoh/internal/models" ) -const denseEmbedTimeout = 30 * time.Second +const denseEmbedTimeout = models.DefaultProviderRequestTimeout type denseRuntime struct { qdrant *qdrantclient.Client diff --git a/internal/memory/memllm/client.go b/internal/memory/memllm/client.go index 417f6651..fbc9c885 100644 --- a/internal/memory/memllm/client.go +++ b/internal/memory/memllm/client.go @@ -15,7 +15,7 @@ import ( ) const ( - defaultTimeout = 30 * time.Second + defaultTimeout = models.DefaultProviderRequestTimeout maxExtractFacts = 10 maxDecideActions = 20 ) diff --git a/internal/models/embedding.go b/internal/models/embedding.go index 831b9e44..523a2584 100644 --- a/internal/models/embedding.go +++ b/internal/models/embedding.go @@ -15,10 +15,10 @@ import ( // OpenAI-compatible /embeddings endpoint for all other provider types. func NewSDKEmbeddingModel(clientType, baseURL, apiKey, modelID string, timeout time.Duration, httpClient *http.Client) *sdk.EmbeddingModel { if timeout <= 0 { - timeout = 30 * time.Second + timeout = DefaultProviderRequestTimeout } if httpClient == nil { - httpClient = &http.Client{Timeout: timeout} + httpClient = NewProviderHTTPClient(timeout) } switch ClientType(clientType) { diff --git a/internal/models/http_client.go b/internal/models/http_client.go new file mode 100644 index 00000000..d5285374 --- /dev/null +++ b/internal/models/http_client.go @@ -0,0 +1,39 @@ +package models + +import ( + "net/http" + "time" +) + +const ( + DefaultProviderRequestTimeout = 2 * time.Minute + DefaultProviderProbeTimeout = 60 * time.Second + DefaultProviderTLSHandshakeTimeout = 30 * time.Second +) + +var defaultProviderTransport = newDefaultProviderTransport() + +// NewProviderHTTPClient returns an HTTP client for model/provider traffic. +// When timeout is zero or negative, the caller is expected to enforce limits +// via context deadlines, which keeps streaming responses unbounded by the +// client's global timeout while still using the relaxed TLS handshake window. +func NewProviderHTTPClient(timeout time.Duration) *http.Client { + client := &http.Client{Transport: defaultProviderTransport} + if timeout > 0 { + client.Timeout = timeout + } + return client +} + +func newDefaultProviderTransport() *http.Transport { + base, ok := http.DefaultTransport.(*http.Transport) + if !ok { + return &http.Transport{TLSHandshakeTimeout: DefaultProviderTLSHandshakeTimeout} + } + + transport := base.Clone() + if transport.TLSHandshakeTimeout < DefaultProviderTLSHandshakeTimeout { + transport.TLSHandshakeTimeout = DefaultProviderTLSHandshakeTimeout + } + return transport +} diff --git a/internal/models/http_client_test.go b/internal/models/http_client_test.go new file mode 100644 index 00000000..b9ac1db7 --- /dev/null +++ b/internal/models/http_client_test.go @@ -0,0 +1,36 @@ +package models + +import ( + "net/http" + "testing" + "time" +) + +func TestNewProviderHTTPClientWithoutTimeoutKeepsStreamingFriendlyBehavior(t *testing.T) { + client := NewProviderHTTPClient(0) + if client == nil { + t.Fatal("expected client") + } + if client.Timeout != 0 { + t.Fatalf("expected no client timeout, got %s", client.Timeout) + } + + transport, ok := client.Transport.(*http.Transport) + if !ok { + t.Fatalf("expected *http.Transport, got %T", client.Transport) + } + if transport.TLSHandshakeTimeout < DefaultProviderTLSHandshakeTimeout { + t.Fatalf("expected TLS handshake timeout >= %s, got %s", DefaultProviderTLSHandshakeTimeout, transport.TLSHandshakeTimeout) + } +} + +func TestNewProviderHTTPClientWithTimeout(t *testing.T) { + timeout := 45 * time.Second + client := NewProviderHTTPClient(timeout) + if client == nil { + t.Fatal("expected client") + } + if client.Timeout != timeout { + t.Fatalf("expected timeout %s, got %s", timeout, client.Timeout) + } +} diff --git a/internal/models/probe.go b/internal/models/probe.go index 8f7fb018..e717f609 100644 --- a/internal/models/probe.go +++ b/internal/models/probe.go @@ -23,7 +23,7 @@ import ( "github.com/memohai/memoh/internal/oauthctx" ) -const probeTimeout = 15 * time.Second +const probeTimeout = DefaultProviderProbeTimeout // Test probes a model's provider endpoint using the Twilight AI SDK // to verify connectivity, authentication, and model availability. @@ -140,7 +140,7 @@ func (*Service) testEmbeddingModel(ctx context.Context, baseURL, apiKey, modelID // It is exported so that other packages (e.g. providers) can reuse it for testing. func NewSDKProvider(baseURL, apiKey, codexAccountID string, clientType ClientType, timeout time.Duration, httpClient *http.Client) sdk.Provider { if httpClient == nil { - httpClient = &http.Client{Timeout: timeout} + httpClient = NewProviderHTTPClient(timeout) } switch clientType { diff --git a/internal/models/sdk.go b/internal/models/sdk.go index 7d370dca..4c8a52f6 100644 --- a/internal/models/sdk.go +++ b/internal/models/sdk.go @@ -39,14 +39,16 @@ var ( // NewSDKChatModel builds a Twilight AI SDK Model from the resolved model config. func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model { + if cfg.HTTPClient == nil { + cfg.HTTPClient = NewProviderHTTPClient(0) + } + switch ClientType(cfg.ClientType) { case ClientTypeOpenAICompletions: opts := []openaicompletions.Option{ openaicompletions.WithAPIKey(cfg.APIKey), } - if cfg.HTTPClient != nil { - opts = append(opts, openaicompletions.WithHTTPClient(cfg.HTTPClient)) - } + opts = append(opts, openaicompletions.WithHTTPClient(cfg.HTTPClient)) if cfg.BaseURL != "" { opts = append(opts, openaicompletions.WithBaseURL(cfg.BaseURL)) } @@ -57,9 +59,7 @@ func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model { opts := []openairesponses.Option{ openairesponses.WithAPIKey(cfg.APIKey), } - if cfg.HTTPClient != nil { - opts = append(opts, openairesponses.WithHTTPClient(cfg.HTTPClient)) - } + opts = append(opts, openairesponses.WithHTTPClient(cfg.HTTPClient)) if cfg.BaseURL != "" { opts = append(opts, openairesponses.WithBaseURL(cfg.BaseURL)) } @@ -70,9 +70,7 @@ func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model { opts := []openaicodex.Option{ openaicodex.WithAccessToken(cfg.APIKey), } - if cfg.HTTPClient != nil { - opts = append(opts, openaicodex.WithHTTPClient(cfg.HTTPClient)) - } + opts = append(opts, openaicodex.WithHTTPClient(cfg.HTTPClient)) if cfg.CodexAccountID != "" { opts = append(opts, openaicodex.WithAccountID(cfg.CodexAccountID)) } @@ -85,9 +83,7 @@ func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model { opts := []anthropicmessages.Option{ anthropicmessages.WithAPIKey(cfg.APIKey), } - if cfg.HTTPClient != nil { - opts = append(opts, anthropicmessages.WithHTTPClient(cfg.HTTPClient)) - } + opts = append(opts, anthropicmessages.WithHTTPClient(cfg.HTTPClient)) if cfg.BaseURL != "" { opts = append(opts, anthropicmessages.WithBaseURL(cfg.BaseURL)) } @@ -105,9 +101,7 @@ func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model { opts := []googlegenerative.Option{ googlegenerative.WithAPIKey(cfg.APIKey), } - if cfg.HTTPClient != nil { - opts = append(opts, googlegenerative.WithHTTPClient(cfg.HTTPClient)) - } + opts = append(opts, googlegenerative.WithHTTPClient(cfg.HTTPClient)) if cfg.BaseURL != "" { opts = append(opts, googlegenerative.WithBaseURL(cfg.BaseURL)) } @@ -118,9 +112,7 @@ func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model { opts := []openaicompletions.Option{ openaicompletions.WithAPIKey(cfg.APIKey), } - if cfg.HTTPClient != nil { - opts = append(opts, openaicompletions.WithHTTPClient(cfg.HTTPClient)) - } + opts = append(opts, openaicompletions.WithHTTPClient(cfg.HTTPClient)) if cfg.BaseURL != "" { opts = append(opts, openaicompletions.WithBaseURL(cfg.BaseURL)) } diff --git a/internal/providers/service.go b/internal/providers/service.go index fc4dd656..f459aa39 100644 --- a/internal/providers/service.go +++ b/internal/providers/service.go @@ -209,7 +209,7 @@ func (s *Service) Count(ctx context.Context) (int64, error) { return count, nil } -const probeTimeout = 5 * time.Second +const probeTimeout = models.DefaultProviderProbeTimeout // Test probes the provider using the Twilight AI SDK to check // reachability and authentication. @@ -326,7 +326,7 @@ func (s *Service) FetchRemoteModels(ctx context.Context, id string) ([]RemoteMod req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey)) } - resp, err := http.DefaultClient.Do(req) //nolint:gosec // G704: URL is from operator-configured provider base URL + resp, err := models.NewProviderHTTPClient(probeTimeout).Do(req) //nolint:gosec // G704: URL is from operator-configured LLM provider base URL if err != nil { return nil, fmt.Errorf("execute request: %w", err) }