refactor: unify SDK model factories into internal/models

Move CreateModel, BuildReasoningOptions, ReasoningBudgetTokens and
related types from internal/agent to internal/models as NewSDKChatModel,
SDKModelConfig, etc. This eliminates duplicate ClientType constants and
centralises all Twilight AI SDK instance creation in a single package.

NewSDKEmbeddingModel now accepts a clientType parameter and dispatches
to the native Google embedding provider for google-generative-ai,
instead of always using the OpenAI-compatible endpoint.
This commit is contained in:
Acbox
2026-03-26 20:08:35 +08:00
parent 03ba13e7e5
commit 65b2797626
12 changed files with 96 additions and 86 deletions
+27 -15
View File
@@ -4,30 +4,42 @@ import (
"net/http"
"time"
googleembedding "github.com/memohai/twilight-ai/provider/google/embedding"
openaiembedding "github.com/memohai/twilight-ai/provider/openai/embedding"
sdk "github.com/memohai/twilight-ai/sdk"
)
// NewSDKEmbeddingModel creates a Twilight AI SDK EmbeddingModel for the given
// provider configuration. Currently all embedding providers use the
// OpenAI-compatible /embeddings endpoint (including Google-hosted models that
// expose the same wire format), so we route everything through the OpenAI
// embedding provider. If a future provider requires a different wire protocol,
// add a branch here.
func NewSDKEmbeddingModel(baseURL, apiKey, modelID string, timeout time.Duration) *sdk.EmbeddingModel {
// provider configuration. It dispatches to the native Google embedding provider
// when clientType is "google-generative-ai", and falls back to the
// OpenAI-compatible /embeddings endpoint for all other provider types.
func NewSDKEmbeddingModel(clientType, baseURL, apiKey, modelID string, timeout time.Duration) *sdk.EmbeddingModel {
if timeout <= 0 {
timeout = 30 * time.Second
}
httpClient := &http.Client{Timeout: timeout}
opts := []openaiembedding.Option{
openaiembedding.WithAPIKey(apiKey),
openaiembedding.WithHTTPClient(httpClient),
}
if baseURL != "" {
opts = append(opts, openaiembedding.WithBaseURL(baseURL))
}
switch ClientType(clientType) {
case ClientTypeGoogleGenerativeAI:
opts := []googleembedding.Option{
googleembedding.WithAPIKey(apiKey),
googleembedding.WithHTTPClient(httpClient),
}
if baseURL != "" {
opts = append(opts, googleembedding.WithBaseURL(baseURL))
}
p := googleembedding.New(opts...)
return p.EmbeddingModel(modelID)
p := openaiembedding.New(opts...)
return p.EmbeddingModel(modelID)
default:
opts := []openaiembedding.Option{
openaiembedding.WithAPIKey(apiKey),
openaiembedding.WithHTTPClient(httpClient),
}
if baseURL != "" {
opts = append(opts, openaiembedding.WithBaseURL(baseURL))
}
p := openaiembedding.New(opts...)
return p.EmbeddingModel(modelID)
}
}
+3 -3
View File
@@ -43,7 +43,7 @@ func (s *Service) Test(ctx context.Context, id string) (TestResponse, error) {
// Embedding models don't have a chat Provider in the SDK — probe
// the /embeddings endpoint directly.
if model.Type == string(ModelTypeEmbedding) {
return s.testEmbeddingModel(ctx, baseURL, apiKey, model.ModelID)
return s.testEmbeddingModel(ctx, string(clientType), baseURL, apiKey, model.ModelID)
}
sdkProvider := NewSDKProvider(baseURL, apiKey, clientType, probeTimeout)
@@ -100,11 +100,11 @@ func (s *Service) Test(ctx context.Context, id string) (TestResponse, error) {
// testEmbeddingModel probes an embedding model by performing a minimal
// embedding request via the Twilight SDK, verifying that the model is
// reachable and functional rather than merely checking HTTP connectivity.
func (*Service) testEmbeddingModel(ctx context.Context, baseURL, apiKey, modelID string) (TestResponse, error) {
func (*Service) testEmbeddingModel(ctx context.Context, clientType, baseURL, apiKey, modelID string) (TestResponse, error) {
ctx, cancel := context.WithTimeout(ctx, probeTimeout)
defer cancel()
model := NewSDKEmbeddingModel(baseURL, apiKey, modelID, probeTimeout)
model := NewSDKEmbeddingModel(clientType, baseURL, apiKey, modelID, probeTimeout)
client := sdk.NewClient()
start := time.Now()
+155
View File
@@ -0,0 +1,155 @@
package models
import (
"strings"
anthropicmessages "github.com/memohai/twilight-ai/provider/anthropic/messages"
googlegenerative "github.com/memohai/twilight-ai/provider/google/generativeai"
openaicompletions "github.com/memohai/twilight-ai/provider/openai/completions"
openairesponses "github.com/memohai/twilight-ai/provider/openai/responses"
sdk "github.com/memohai/twilight-ai/sdk"
)
// SDKModelConfig holds provider and model information resolved from DB,
// used to construct a Twilight AI SDK Model instance.
type SDKModelConfig struct {
	// ModelID is the provider-side model identifier passed to ChatModel.
	ModelID string
	// ClientType selects the provider package / wire protocol; see the
	// ClientType constants in this package.
	ClientType string
	// APIKey authenticates requests against the provider.
	APIKey string //nolint:gosec // carries provider credential material at runtime
	// BaseURL, when non-empty, overrides the provider's default endpoint.
	BaseURL string
	// ReasoningConfig, when non-nil and Enabled, turns on extended
	// thinking/reasoning for providers that support it.
	ReasoningConfig *ReasoningConfig
}
// ReasoningConfig controls extended thinking/reasoning behavior.
type ReasoningConfig struct {
	// Enabled turns extended thinking on.
	Enabled bool
	// Effort is the requested effort level: "low", "medium" or "high".
	// Consumers in this package treat an empty value as "medium".
	Effort string
}
// Token budgets for extended thinking, keyed by effort level. The Anthropic
// and Google tables currently hold the same values but are kept separate so
// they can diverge per provider without touching call sites.
var (
	anthropicBudget = map[string]int{"low": 5000, "medium": 16000, "high": 50000}
	googleBudget    = map[string]int{"low": 5000, "medium": 16000, "high": 50000}
)
// NewSDKChatModel builds a Twilight AI SDK Model from the resolved model config.
// It selects the provider package matching cfg.ClientType; unknown client types
// fall back to the OpenAI-compatible chat completions provider. For Anthropic,
// an extended-thinking token budget is configured at construction time when
// cfg.ReasoningConfig is enabled.
func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model {
	switch ClientType(cfg.ClientType) {
	case ClientTypeOpenAIResponses:
		opts := []openairesponses.Option{
			openairesponses.WithAPIKey(cfg.APIKey),
		}
		if cfg.BaseURL != "" {
			opts = append(opts, openairesponses.WithBaseURL(cfg.BaseURL))
		}
		return openairesponses.New(opts...).ChatModel(cfg.ModelID)
	case ClientTypeAnthropicMessages:
		opts := []anthropicmessages.Option{
			anthropicmessages.WithAPIKey(cfg.APIKey),
		}
		if cfg.BaseURL != "" {
			opts = append(opts, anthropicmessages.WithBaseURL(cfg.BaseURL))
		}
		if cfg.ReasoningConfig != nil && cfg.ReasoningConfig.Enabled {
			// Anthropic expresses reasoning as an explicit token budget
			// rather than an effort level; translate via ReasoningBudgetTokens.
			budget := ReasoningBudgetTokens(cfg.ClientType, cfg.ReasoningConfig.Effort)
			opts = append(opts, anthropicmessages.WithThinking(anthropicmessages.ThinkingConfig{
				Type:         "enabled",
				BudgetTokens: budget,
			}))
		}
		return anthropicmessages.New(opts...).ChatModel(cfg.ModelID)
	case ClientTypeGoogleGenerativeAI:
		opts := []googlegenerative.Option{
			googlegenerative.WithAPIKey(cfg.APIKey),
		}
		if cfg.BaseURL != "" {
			opts = append(opts, googlegenerative.WithBaseURL(cfg.BaseURL))
		}
		return googlegenerative.New(opts...).ChatModel(cfg.ModelID)
	default:
		// ClientTypeOpenAICompletions and any unrecognized client type both
		// use the OpenAI-compatible chat completions provider; the original
		// explicit case duplicated this default branch verbatim, so the two
		// are merged here.
		opts := []openaicompletions.Option{
			openaicompletions.WithAPIKey(cfg.APIKey),
		}
		if cfg.BaseURL != "" {
			opts = append(opts, openaicompletions.WithBaseURL(cfg.BaseURL))
		}
		return openaicompletions.New(opts...).ChatModel(cfg.ModelID)
	}
}
// BuildReasoningOptions returns SDK generation options for reasoning/thinking.
// It produces a reasoning-effort option for OpenAI-style client types, and nil
// both when reasoning is disabled and for providers whose thinking is
// configured at model construction time instead (Anthropic, Google).
func BuildReasoningOptions(cfg SDKModelConfig) []sdk.GenerateOption {
	rc := cfg.ReasoningConfig
	if rc == nil || !rc.Enabled {
		return nil
	}
	effort := rc.Effort
	if effort == "" {
		effort = "medium"
	}
	switch ClientType(cfg.ClientType) {
	case ClientTypeAnthropicMessages, ClientTypeGoogleGenerativeAI:
		// These providers take their thinking configuration when the model
		// is built, not per generate call.
		return nil
	default:
		// OpenAI responses/completions and any unknown client type accept a
		// per-call reasoning-effort option.
		return []sdk.GenerateOption{sdk.WithReasoningEffort(effort)}
	}
}
// ReasoningBudgetTokens returns the token budget for extended thinking based
// on client type and effort. Unrecognized efforts fall back to the "medium"
// entry of the selected table; client types without a budget table yield 0.
func ReasoningBudgetTokens(clientType, effort string) int {
	if effort == "" {
		effort = "medium"
	}
	var table map[string]int
	switch ClientType(clientType) {
	case ClientTypeAnthropicMessages:
		table = anthropicBudget
	case ClientTypeGoogleGenerativeAI:
		table = googleBudget
	default:
		// No extended-thinking budget concept for other client types.
		return 0
	}
	if budget, ok := table[effort]; ok {
		return budget
	}
	return table["medium"]
}
// ResolveClientType infers the client type string from an SDK Model's provider
// name. A nil model or provider, and any unrecognized provider name, resolve
// to the OpenAI chat completions client type.
func ResolveClientType(model *sdk.Model) string {
	if model == nil || model.Provider == nil {
		return string(ClientTypeOpenAICompletions)
	}
	providerName := model.Provider.Name()
	// Substring checks run in fixed priority order: anthropic, google,
	// responses — matching the original cascade.
	checks := []struct {
		substr string
		ct     ClientType
	}{
		{"anthropic", ClientTypeAnthropicMessages},
		{"google", ClientTypeGoogleGenerativeAI},
		{"responses", ClientTypeOpenAIResponses},
	}
	for _, c := range checks {
		if strings.Contains(providerName, c.substr) {
			return string(c.ct)
		}
	}
	return string(ClientTypeOpenAICompletions)
}