refactor: unify SDK model factories into internal/models

Move CreateModel, BuildReasoningOptions, ReasoningBudgetTokens and
related types from internal/agent to internal/models as NewSDKChatModel,
SDKModelConfig, etc. This eliminates duplicate ClientType constants and
centralises all Twilight AI SDK instance creation in a single package.

NewSDKEmbeddingModel now accepts a clientType parameter and dispatches
to the native Google embedding provider for google-generative-ai,
instead of always using the OpenAI-compatible endpoint.
This commit is contained in:
Acbox
2026-03-26 20:08:35 +08:00
parent 03ba13e7e5
commit 65b2797626
12 changed files with 96 additions and 86 deletions
+4 -20
View File
@@ -11,6 +11,7 @@ import (
sdk "github.com/memohai/twilight-ai/sdk" sdk "github.com/memohai/twilight-ai/sdk"
"github.com/memohai/memoh/internal/agent/tools" "github.com/memohai/memoh/internal/agent/tools"
"github.com/memohai/memoh/internal/models"
"github.com/memohai/memoh/internal/workspace/bridge" "github.com/memohai/memoh/internal/workspace/bridge"
) )
@@ -371,9 +372,9 @@ func (*Agent) buildGenerateOptions(cfg RunConfig, tools []sdk.Tool, prepareStep
if prepareStep != nil { if prepareStep != nil {
opts = append(opts, sdk.WithPrepareStep(prepareStep)) opts = append(opts, sdk.WithPrepareStep(prepareStep))
} }
opts = append(opts, BuildReasoningOptions(ModelConfig{ opts = append(opts, models.BuildReasoningOptions(models.SDKModelConfig{
ClientType: resolveClientType(cfg.Model), ClientType: models.ResolveClientType(cfg.Model),
ReasoningConfig: &ReasoningConfig{ ReasoningConfig: &models.ReasoningConfig{
Enabled: cfg.ReasoningEffort != "", Enabled: cfg.ReasoningEffort != "",
Effort: cfg.ReasoningEffort, Effort: cfg.ReasoningEffort,
}, },
@@ -381,23 +382,6 @@ func (*Agent) buildGenerateOptions(cfg RunConfig, tools []sdk.Tool, prepareStep
return opts return opts
} }
func resolveClientType(model *sdk.Model) string {
if model == nil || model.Provider == nil {
return ClientTypeOpenAICompletions
}
name := model.Provider.Name()
switch {
case strings.Contains(name, "anthropic"):
return ClientTypeAnthropicMessages
case strings.Contains(name, "google"):
return ClientTypeGoogleGenerativeAI
case strings.Contains(name, "responses"):
return ClientTypeOpenAIResponses
default:
return ClientTypeOpenAICompletions
}
}
// assembleTools collects tools from all registered ToolProviders. // assembleTools collects tools from all registered ToolProviders.
func (a *Agent) assembleTools(ctx context.Context, cfg RunConfig) ([]sdk.Tool, error) { func (a *Agent) assembleTools(ctx context.Context, cfg RunConfig) ([]sdk.Tool, error) {
if len(a.toolProviders) == 0 { if len(a.toolProviders) == 0 {
+3 -2
View File
@@ -7,6 +7,7 @@ import (
sdk "github.com/memohai/twilight-ai/sdk" sdk "github.com/memohai/twilight-ai/sdk"
agenttools "github.com/memohai/memoh/internal/agent/tools" agenttools "github.com/memohai/memoh/internal/agent/tools"
"github.com/memohai/memoh/internal/models"
) )
func decorateReadMediaTools(model *sdk.Model, tools []sdk.Tool) ([]sdk.Tool, *readMediaDecorationState) { func decorateReadMediaTools(model *sdk.Model, tools []sdk.Tool) ([]sdk.Tool, *readMediaDecorationState) {
@@ -14,7 +15,7 @@ func decorateReadMediaTools(model *sdk.Model, tools []sdk.Tool) ([]sdk.Tool, *re
return tools, nil return tools, nil
} }
clientType := resolveClientType(model) clientType := models.ResolveClientType(model)
state := &readMediaDecorationState{ state := &readMediaDecorationState{
pendingImages: make(map[string]sdk.ImagePart), pendingImages: make(map[string]sdk.ImagePart),
} }
@@ -164,7 +165,7 @@ func buildReadMediaImagePart(clientType, imageBase64, mediaType string) sdk.Imag
} }
image := imageBase64 image := imageBase64
if clientType != ClientTypeAnthropicMessages { if clientType != string(models.ClientTypeAnthropicMessages) {
image = fmt.Sprintf("data:%s;base64,%s", mediaType, imageBase64) image = fmt.Sprintf("data:%s;base64,%s", mediaType, imageBase64)
} }
return sdk.ImagePart{ return sdk.ImagePart{
+3 -2
View File
@@ -6,6 +6,7 @@ import (
sdk "github.com/memohai/twilight-ai/sdk" sdk "github.com/memohai/twilight-ai/sdk"
"github.com/memohai/memoh/internal/agent/tools" "github.com/memohai/memoh/internal/agent/tools"
"github.com/memohai/memoh/internal/models"
) )
// SpawnAdapter wraps *Agent to satisfy tools.SpawnAgent without creating // SpawnAdapter wraps *Agent to satisfy tools.SpawnAgent without creating
@@ -68,10 +69,10 @@ func SpawnSystemPrompt(sessionType string) string {
}) })
} }
// SpawnModelCreatorFunc returns a tools.ModelCreator that delegates to agent.CreateModel. // SpawnModelCreatorFunc returns a tools.ModelCreator that delegates to models.NewSDKChatModel.
func SpawnModelCreatorFunc() tools.ModelCreator { func SpawnModelCreatorFunc() tools.ModelCreator {
return func(modelID, clientType, apiKey, baseURL string) *sdk.Model { return func(modelID, clientType, apiKey, baseURL string) *sdk.Model {
return CreateModel(ModelConfig{ return models.NewSDKChatModel(models.SDKModelConfig{
ModelID: modelID, ModelID: modelID,
ClientType: clientType, ClientType: clientType,
APIKey: apiKey, APIKey: apiKey,
-15
View File
@@ -95,21 +95,6 @@ type SystemFile struct {
Content string Content string
} }
// ModelConfig holds provider and model information resolved from DB.
type ModelConfig struct {
ModelID string
ClientType string
APIKey string //nolint:gosec // carries provider credential material at runtime
BaseURL string
ReasoningConfig *ReasoningConfig
}
// ReasoningConfig controls extended thinking/reasoning behavior.
type ReasoningConfig struct {
Enabled bool
Effort string
}
func mustMarshal(v any) json.RawMessage { func mustMarshal(v any) json.RawMessage {
data, err := json.Marshal(v) data, err := json.Marshal(v)
if err != nil { if err != nil {
+2 -2
View File
@@ -11,9 +11,9 @@ import (
"github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgtype"
sdk "github.com/memohai/twilight-ai/sdk" sdk "github.com/memohai/twilight-ai/sdk"
"github.com/memohai/memoh/internal/agent"
"github.com/memohai/memoh/internal/db" "github.com/memohai/memoh/internal/db"
"github.com/memohai/memoh/internal/db/sqlc" "github.com/memohai/memoh/internal/db/sqlc"
"github.com/memohai/memoh/internal/models"
) )
// Service manages context compaction for bot conversations. // Service manages context compaction for bot conversations.
@@ -103,7 +103,7 @@ func (s *Service) doCompaction(ctx context.Context, logID pgtype.UUID, sessionUU
userPrompt := buildUserPrompt(priorSummaries, entries) userPrompt := buildUserPrompt(priorSummaries, entries)
model := agent.CreateModel(agent.ModelConfig{ model := models.NewSDKChatModel(models.SDKModelConfig{
ClientType: cfg.ClientType, ClientType: cfg.ClientType,
BaseURL: cfg.BaseURL, BaseURL: cfg.BaseURL,
APIKey: cfg.APIKey, APIKey: cfg.APIKey,
+4 -4
View File
@@ -270,15 +270,15 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r
reasoningEffort = botSettings.ReasoningEffort reasoningEffort = botSettings.ReasoningEffort
} }
var reasoningConfig *agentpkg.ReasoningConfig var reasoningConfig *models.ReasoningConfig
if reasoningEffort != "" { if reasoningEffort != "" {
reasoningConfig = &agentpkg.ReasoningConfig{ reasoningConfig = &models.ReasoningConfig{
Enabled: true, Enabled: true,
Effort: reasoningEffort, Effort: reasoningEffort,
} }
} }
modelCfg := agentpkg.ModelConfig{ modelCfg := models.SDKModelConfig{
ModelID: chatModel.ModelID, ModelID: chatModel.ModelID,
ClientType: clientType, ClientType: clientType,
APIKey: provider.ApiKey, APIKey: provider.ApiKey,
@@ -286,7 +286,7 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r
ReasoningConfig: reasoningConfig, ReasoningConfig: reasoningConfig,
} }
sdkModel := agentpkg.CreateModel(modelCfg) sdkModel := models.NewSDKChatModel(modelCfg)
sdkMessages := modelMessagesToSDKMessages(nonNilModelMessages(messages)) sdkMessages := modelMessagesToSDKMessages(nonNilModelMessages(messages))
runCfg := agentpkg.RunConfig{ runCfg := agentpkg.RunConfig{
+2 -3
View File
@@ -9,7 +9,6 @@ import (
sdk "github.com/memohai/twilight-ai/sdk" sdk "github.com/memohai/twilight-ai/sdk"
agentpkg "github.com/memohai/memoh/internal/agent"
"github.com/memohai/memoh/internal/conversation" "github.com/memohai/memoh/internal/conversation"
"github.com/memohai/memoh/internal/db/sqlc" "github.com/memohai/memoh/internal/db/sqlc"
messageevent "github.com/memohai/memoh/internal/message/event" messageevent "github.com/memohai/memoh/internal/message/event"
@@ -105,13 +104,13 @@ func (r *Resolver) generateTitle(ctx context.Context, model models.GetResponse,
"Return ONLY the title text, nothing else.\n\n" + "Return ONLY the title text, nothing else.\n\n" +
"User: " + userSnippet "User: " + userSnippet
modelCfg := agentpkg.ModelConfig{ modelCfg := models.SDKModelConfig{
ModelID: model.ModelID, ModelID: model.ModelID,
ClientType: provider.ClientType, ClientType: provider.ClientType,
APIKey: provider.ApiKey, APIKey: provider.ApiKey,
BaseURL: provider.BaseUrl, BaseURL: provider.BaseUrl,
} }
sdkModel := agentpkg.CreateModel(modelCfg) sdkModel := models.NewSDKChatModel(modelCfg)
genCtx, cancel := context.WithTimeout(ctx, titleGenerateTimeout) genCtx, cancel := context.WithTimeout(ctx, titleGenerateTimeout)
defer cancel() defer cancel()
@@ -35,6 +35,7 @@ type denseRuntime struct {
type denseModelSpec struct { type denseModelSpec struct {
modelID string modelID string
clientType string
baseURL string baseURL string
apiKey string apiKey string
dimensions int dimensions int
@@ -74,7 +75,7 @@ func newDenseRuntime(providerConfig map[string]any, queries *dbsqlc.Queries, cfg
return nil, fmt.Errorf("dense runtime: %w", err) return nil, fmt.Errorf("dense runtime: %w", err)
} }
embedModel := models.NewSDKEmbeddingModel(spec.baseURL, spec.apiKey, spec.modelID, denseEmbedTimeout) embedModel := models.NewSDKEmbeddingModel(spec.clientType, spec.baseURL, spec.apiKey, spec.modelID, denseEmbedTimeout)
return &denseRuntime{ return &denseRuntime{
qdrant: qClient, qdrant: qClient,
@@ -565,6 +566,7 @@ func resolveDenseEmbeddingModel(ctx context.Context, queries *dbsqlc.Queries, mo
} }
return denseModelSpec{ return denseModelSpec{
modelID: strings.TrimSpace(row.ModelID), modelID: strings.TrimSpace(row.ModelID),
clientType: strings.TrimSpace(provider.ClientType),
baseURL: strings.TrimSpace(provider.BaseUrl), baseURL: strings.TrimSpace(provider.BaseUrl),
apiKey: strings.TrimSpace(provider.ApiKey), apiKey: strings.TrimSpace(provider.ApiKey),
dimensions: *cfg.Dimensions, dimensions: *cfg.Dimensions,
+2 -1
View File
@@ -11,6 +11,7 @@ import (
"github.com/memohai/memoh/internal/agent" "github.com/memohai/memoh/internal/agent"
adapters "github.com/memohai/memoh/internal/memory/adapters" adapters "github.com/memohai/memoh/internal/memory/adapters"
"github.com/memohai/memoh/internal/models"
) )
const ( const (
@@ -42,7 +43,7 @@ func New(cfg Config) *Client {
} }
func (c *Client) model() *sdk.Model { func (c *Client) model() *sdk.Model {
return agent.CreateModel(agent.ModelConfig{ return models.NewSDKChatModel(models.SDKModelConfig{
ModelID: c.cfg.ModelID, ModelID: c.cfg.ModelID,
ClientType: c.cfg.ClientType, ClientType: c.cfg.ClientType,
APIKey: c.cfg.APIKey, APIKey: c.cfg.APIKey,
+19 -7
View File
@@ -4,22 +4,34 @@ import (
"net/http" "net/http"
"time" "time"
googleembedding "github.com/memohai/twilight-ai/provider/google/embedding"
openaiembedding "github.com/memohai/twilight-ai/provider/openai/embedding" openaiembedding "github.com/memohai/twilight-ai/provider/openai/embedding"
sdk "github.com/memohai/twilight-ai/sdk" sdk "github.com/memohai/twilight-ai/sdk"
) )
// NewSDKEmbeddingModel creates a Twilight AI SDK EmbeddingModel for the given // NewSDKEmbeddingModel creates a Twilight AI SDK EmbeddingModel for the given
// provider configuration. Currently all embedding providers use the // provider configuration. It dispatches to the native Google embedding provider
// OpenAI-compatible /embeddings endpoint (including Google-hosted models that // when clientType is "google-generative-ai", and falls back to the
// expose the same wire format), so we route everything through the OpenAI // OpenAI-compatible /embeddings endpoint for all other provider types.
// embedding provider. If a future provider requires a different wire protocol, func NewSDKEmbeddingModel(clientType, baseURL, apiKey, modelID string, timeout time.Duration) *sdk.EmbeddingModel {
// add a branch here.
func NewSDKEmbeddingModel(baseURL, apiKey, modelID string, timeout time.Duration) *sdk.EmbeddingModel {
if timeout <= 0 { if timeout <= 0 {
timeout = 30 * time.Second timeout = 30 * time.Second
} }
httpClient := &http.Client{Timeout: timeout} httpClient := &http.Client{Timeout: timeout}
switch ClientType(clientType) {
case ClientTypeGoogleGenerativeAI:
opts := []googleembedding.Option{
googleembedding.WithAPIKey(apiKey),
googleembedding.WithHTTPClient(httpClient),
}
if baseURL != "" {
opts = append(opts, googleembedding.WithBaseURL(baseURL))
}
p := googleembedding.New(opts...)
return p.EmbeddingModel(modelID)
default:
opts := []openaiembedding.Option{ opts := []openaiembedding.Option{
openaiembedding.WithAPIKey(apiKey), openaiembedding.WithAPIKey(apiKey),
openaiembedding.WithHTTPClient(httpClient), openaiembedding.WithHTTPClient(httpClient),
@@ -27,7 +39,7 @@ func NewSDKEmbeddingModel(baseURL, apiKey, modelID string, timeout time.Duration
if baseURL != "" { if baseURL != "" {
opts = append(opts, openaiembedding.WithBaseURL(baseURL)) opts = append(opts, openaiembedding.WithBaseURL(baseURL))
} }
p := openaiembedding.New(opts...) p := openaiembedding.New(opts...)
return p.EmbeddingModel(modelID) return p.EmbeddingModel(modelID)
} }
}
+3 -3
View File
@@ -43,7 +43,7 @@ func (s *Service) Test(ctx context.Context, id string) (TestResponse, error) {
// Embedding models don't have a chat Provider in the SDK — probe // Embedding models don't have a chat Provider in the SDK — probe
// the /embeddings endpoint directly. // the /embeddings endpoint directly.
if model.Type == string(ModelTypeEmbedding) { if model.Type == string(ModelTypeEmbedding) {
return s.testEmbeddingModel(ctx, baseURL, apiKey, model.ModelID) return s.testEmbeddingModel(ctx, string(clientType), baseURL, apiKey, model.ModelID)
} }
sdkProvider := NewSDKProvider(baseURL, apiKey, clientType, probeTimeout) sdkProvider := NewSDKProvider(baseURL, apiKey, clientType, probeTimeout)
@@ -100,11 +100,11 @@ func (s *Service) Test(ctx context.Context, id string) (TestResponse, error) {
// testEmbeddingModel probes an embedding model by performing a minimal // testEmbeddingModel probes an embedding model by performing a minimal
// embedding request via the Twilight SDK, verifying that the model is // embedding request via the Twilight SDK, verifying that the model is
// reachable and functional rather than merely checking HTTP connectivity. // reachable and functional rather than merely checking HTTP connectivity.
func (*Service) testEmbeddingModel(ctx context.Context, baseURL, apiKey, modelID string) (TestResponse, error) { func (*Service) testEmbeddingModel(ctx context.Context, clientType, baseURL, apiKey, modelID string) (TestResponse, error) {
ctx, cancel := context.WithTimeout(ctx, probeTimeout) ctx, cancel := context.WithTimeout(ctx, probeTimeout)
defer cancel() defer cancel()
model := NewSDKEmbeddingModel(baseURL, apiKey, modelID, probeTimeout) model := NewSDKEmbeddingModel(clientType, baseURL, apiKey, modelID, probeTimeout)
client := sdk.NewClient() client := sdk.NewClient()
start := time.Now() start := time.Now()
@@ -1,6 +1,8 @@
package agent package models
import ( import (
"strings"
anthropicmessages "github.com/memohai/twilight-ai/provider/anthropic/messages" anthropicmessages "github.com/memohai/twilight-ai/provider/anthropic/messages"
googlegenerative "github.com/memohai/twilight-ai/provider/google/generativeai" googlegenerative "github.com/memohai/twilight-ai/provider/google/generativeai"
openaicompletions "github.com/memohai/twilight-ai/provider/openai/completions" openaicompletions "github.com/memohai/twilight-ai/provider/openai/completions"
@@ -8,23 +10,30 @@ import (
sdk "github.com/memohai/twilight-ai/sdk" sdk "github.com/memohai/twilight-ai/sdk"
) )
// ClientType constants matching the database model configuration. // SDKModelConfig holds provider and model information resolved from DB,
const ( // used to construct a Twilight AI SDK Model instance.
ClientTypeOpenAICompletions = "openai-completions" type SDKModelConfig struct {
ClientTypeOpenAIResponses = "openai-responses" ModelID string
ClientTypeAnthropicMessages = "anthropic-messages" ClientType string
ClientTypeGoogleGenerativeAI = "google-generative-ai" APIKey string //nolint:gosec // carries provider credential material at runtime
) BaseURL string
ReasoningConfig *ReasoningConfig
}
// ReasoningConfig controls extended thinking/reasoning behavior.
type ReasoningConfig struct {
Enabled bool
Effort string
}
// Reasoning budget maps per client type.
var ( var (
anthropicBudget = map[string]int{"low": 5000, "medium": 16000, "high": 50000} anthropicBudget = map[string]int{"low": 5000, "medium": 16000, "high": 50000}
googleBudget = map[string]int{"low": 5000, "medium": 16000, "high": 50000} googleBudget = map[string]int{"low": 5000, "medium": 16000, "high": 50000}
) )
// CreateModel builds a Twilight AI SDK Model from the resolved model config. // NewSDKChatModel builds a Twilight AI SDK Model from the resolved model config.
func CreateModel(cfg ModelConfig) *sdk.Model { func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model {
switch cfg.ClientType { switch ClientType(cfg.ClientType) {
case ClientTypeOpenAICompletions: case ClientTypeOpenAICompletions:
opts := []openaicompletions.Option{ opts := []openaicompletions.Option{
openaicompletions.WithAPIKey(cfg.APIKey), openaicompletions.WithAPIKey(cfg.APIKey),
@@ -53,7 +62,7 @@ func CreateModel(cfg ModelConfig) *sdk.Model {
opts = append(opts, anthropicmessages.WithBaseURL(cfg.BaseURL)) opts = append(opts, anthropicmessages.WithBaseURL(cfg.BaseURL))
} }
if cfg.ReasoningConfig != nil && cfg.ReasoningConfig.Enabled { if cfg.ReasoningConfig != nil && cfg.ReasoningConfig.Enabled {
budget := ReasoningBudgetTokens(ClientTypeAnthropicMessages, cfg.ReasoningConfig.Effort) budget := ReasoningBudgetTokens(cfg.ClientType, cfg.ReasoningConfig.Effort)
opts = append(opts, anthropicmessages.WithThinking(anthropicmessages.ThinkingConfig{ opts = append(opts, anthropicmessages.WithThinking(anthropicmessages.ThinkingConfig{
Type: "enabled", Type: "enabled",
BudgetTokens: budget, BudgetTokens: budget,
@@ -73,7 +82,6 @@ func CreateModel(cfg ModelConfig) *sdk.Model {
return p.ChatModel(cfg.ModelID) return p.ChatModel(cfg.ModelID)
default: default:
// OpenAI-compatible fallback
opts := []openaicompletions.Option{ opts := []openaicompletions.Option{
openaicompletions.WithAPIKey(cfg.APIKey), openaicompletions.WithAPIKey(cfg.APIKey),
} }
@@ -86,7 +94,7 @@ func CreateModel(cfg ModelConfig) *sdk.Model {
} }
// BuildReasoningOptions returns SDK generation options for reasoning/thinking. // BuildReasoningOptions returns SDK generation options for reasoning/thinking.
func BuildReasoningOptions(cfg ModelConfig) []sdk.GenerateOption { func BuildReasoningOptions(cfg SDKModelConfig) []sdk.GenerateOption {
if cfg.ReasoningConfig == nil || !cfg.ReasoningConfig.Enabled { if cfg.ReasoningConfig == nil || !cfg.ReasoningConfig.Enabled {
return nil return nil
} }
@@ -95,9 +103,8 @@ func BuildReasoningOptions(cfg ModelConfig) []sdk.GenerateOption {
effort = "medium" effort = "medium"
} }
switch cfg.ClientType { switch ClientType(cfg.ClientType) {
case ClientTypeAnthropicMessages: case ClientTypeAnthropicMessages:
// Anthropic uses thinking budget — no SDK option, handled by provider
return nil return nil
case ClientTypeOpenAIResponses, ClientTypeOpenAICompletions: case ClientTypeOpenAIResponses, ClientTypeOpenAICompletions:
return []sdk.GenerateOption{sdk.WithReasoningEffort(effort)} return []sdk.GenerateOption{sdk.WithReasoningEffort(effort)}
@@ -113,7 +120,7 @@ func ReasoningBudgetTokens(clientType, effort string) int {
if effort == "" { if effort == "" {
effort = "medium" effort = "medium"
} }
switch clientType { switch ClientType(clientType) {
case ClientTypeAnthropicMessages: case ClientTypeAnthropicMessages:
if b, ok := anthropicBudget[effort]; ok { if b, ok := anthropicBudget[effort]; ok {
return b return b
@@ -128,3 +135,21 @@ func ReasoningBudgetTokens(clientType, effort string) int {
return 0 return 0
} }
} }
// ResolveClientType infers the client type string from an SDK Model's provider name.
// A nil model or nil provider falls back to the OpenAI-completions client type.
func ResolveClientType(model *sdk.Model) string {
	if model == nil || model.Provider == nil {
		return string(ClientTypeOpenAICompletions)
	}
	providerName := model.Provider.Name()
	// Ordered substring matches: more specific provider names are checked first.
	for _, candidate := range []struct {
		substr string
		ct     ClientType
	}{
		{"anthropic", ClientTypeAnthropicMessages},
		{"google", ClientTypeGoogleGenerativeAI},
		{"responses", ClientTypeOpenAIResponses},
	} {
		if strings.Contains(providerName, candidate.substr) {
			return string(candidate.ct)
		}
	}
	return string(ClientTypeOpenAICompletions)
}