mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
refactor: unify SDK model factories into internal/models
Move CreateModel, BuildReasoningOptions, ReasoningBudgetTokens and related types from internal/agent to internal/models as NewSDKChatModel, SDKModelConfig, etc. This eliminates duplicate ClientType constants and centralises all Twilight AI SDK instance creation in a single package. NewSDKEmbeddingModel now accepts a clientType parameter and dispatches to the native Google embedding provider for google-generative-ai, instead of always using the OpenAI-compatible endpoint.
This commit is contained in:
+4
-20
@@ -11,6 +11,7 @@ import (
|
|||||||
sdk "github.com/memohai/twilight-ai/sdk"
|
sdk "github.com/memohai/twilight-ai/sdk"
|
||||||
|
|
||||||
"github.com/memohai/memoh/internal/agent/tools"
|
"github.com/memohai/memoh/internal/agent/tools"
|
||||||
|
"github.com/memohai/memoh/internal/models"
|
||||||
"github.com/memohai/memoh/internal/workspace/bridge"
|
"github.com/memohai/memoh/internal/workspace/bridge"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -371,9 +372,9 @@ func (*Agent) buildGenerateOptions(cfg RunConfig, tools []sdk.Tool, prepareStep
|
|||||||
if prepareStep != nil {
|
if prepareStep != nil {
|
||||||
opts = append(opts, sdk.WithPrepareStep(prepareStep))
|
opts = append(opts, sdk.WithPrepareStep(prepareStep))
|
||||||
}
|
}
|
||||||
opts = append(opts, BuildReasoningOptions(ModelConfig{
|
opts = append(opts, models.BuildReasoningOptions(models.SDKModelConfig{
|
||||||
ClientType: resolveClientType(cfg.Model),
|
ClientType: models.ResolveClientType(cfg.Model),
|
||||||
ReasoningConfig: &ReasoningConfig{
|
ReasoningConfig: &models.ReasoningConfig{
|
||||||
Enabled: cfg.ReasoningEffort != "",
|
Enabled: cfg.ReasoningEffort != "",
|
||||||
Effort: cfg.ReasoningEffort,
|
Effort: cfg.ReasoningEffort,
|
||||||
},
|
},
|
||||||
@@ -381,23 +382,6 @@ func (*Agent) buildGenerateOptions(cfg RunConfig, tools []sdk.Tool, prepareStep
|
|||||||
return opts
|
return opts
|
||||||
}
|
}
|
||||||
|
|
||||||
func resolveClientType(model *sdk.Model) string {
|
|
||||||
if model == nil || model.Provider == nil {
|
|
||||||
return ClientTypeOpenAICompletions
|
|
||||||
}
|
|
||||||
name := model.Provider.Name()
|
|
||||||
switch {
|
|
||||||
case strings.Contains(name, "anthropic"):
|
|
||||||
return ClientTypeAnthropicMessages
|
|
||||||
case strings.Contains(name, "google"):
|
|
||||||
return ClientTypeGoogleGenerativeAI
|
|
||||||
case strings.Contains(name, "responses"):
|
|
||||||
return ClientTypeOpenAIResponses
|
|
||||||
default:
|
|
||||||
return ClientTypeOpenAICompletions
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// assembleTools collects tools from all registered ToolProviders.
|
// assembleTools collects tools from all registered ToolProviders.
|
||||||
func (a *Agent) assembleTools(ctx context.Context, cfg RunConfig) ([]sdk.Tool, error) {
|
func (a *Agent) assembleTools(ctx context.Context, cfg RunConfig) ([]sdk.Tool, error) {
|
||||||
if len(a.toolProviders) == 0 {
|
if len(a.toolProviders) == 0 {
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
sdk "github.com/memohai/twilight-ai/sdk"
|
sdk "github.com/memohai/twilight-ai/sdk"
|
||||||
|
|
||||||
agenttools "github.com/memohai/memoh/internal/agent/tools"
|
agenttools "github.com/memohai/memoh/internal/agent/tools"
|
||||||
|
"github.com/memohai/memoh/internal/models"
|
||||||
)
|
)
|
||||||
|
|
||||||
func decorateReadMediaTools(model *sdk.Model, tools []sdk.Tool) ([]sdk.Tool, *readMediaDecorationState) {
|
func decorateReadMediaTools(model *sdk.Model, tools []sdk.Tool) ([]sdk.Tool, *readMediaDecorationState) {
|
||||||
@@ -14,7 +15,7 @@ func decorateReadMediaTools(model *sdk.Model, tools []sdk.Tool) ([]sdk.Tool, *re
|
|||||||
return tools, nil
|
return tools, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
clientType := resolveClientType(model)
|
clientType := models.ResolveClientType(model)
|
||||||
state := &readMediaDecorationState{
|
state := &readMediaDecorationState{
|
||||||
pendingImages: make(map[string]sdk.ImagePart),
|
pendingImages: make(map[string]sdk.ImagePart),
|
||||||
}
|
}
|
||||||
@@ -164,7 +165,7 @@ func buildReadMediaImagePart(clientType, imageBase64, mediaType string) sdk.Imag
|
|||||||
}
|
}
|
||||||
|
|
||||||
image := imageBase64
|
image := imageBase64
|
||||||
if clientType != ClientTypeAnthropicMessages {
|
if clientType != string(models.ClientTypeAnthropicMessages) {
|
||||||
image = fmt.Sprintf("data:%s;base64,%s", mediaType, imageBase64)
|
image = fmt.Sprintf("data:%s;base64,%s", mediaType, imageBase64)
|
||||||
}
|
}
|
||||||
return sdk.ImagePart{
|
return sdk.ImagePart{
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
sdk "github.com/memohai/twilight-ai/sdk"
|
sdk "github.com/memohai/twilight-ai/sdk"
|
||||||
|
|
||||||
"github.com/memohai/memoh/internal/agent/tools"
|
"github.com/memohai/memoh/internal/agent/tools"
|
||||||
|
"github.com/memohai/memoh/internal/models"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SpawnAdapter wraps *Agent to satisfy tools.SpawnAgent without creating
|
// SpawnAdapter wraps *Agent to satisfy tools.SpawnAgent without creating
|
||||||
@@ -68,10 +69,10 @@ func SpawnSystemPrompt(sessionType string) string {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// SpawnModelCreatorFunc returns a tools.ModelCreator that delegates to agent.CreateModel.
|
// SpawnModelCreatorFunc returns a tools.ModelCreator that delegates to models.NewSDKChatModel.
|
||||||
func SpawnModelCreatorFunc() tools.ModelCreator {
|
func SpawnModelCreatorFunc() tools.ModelCreator {
|
||||||
return func(modelID, clientType, apiKey, baseURL string) *sdk.Model {
|
return func(modelID, clientType, apiKey, baseURL string) *sdk.Model {
|
||||||
return CreateModel(ModelConfig{
|
return models.NewSDKChatModel(models.SDKModelConfig{
|
||||||
ModelID: modelID,
|
ModelID: modelID,
|
||||||
ClientType: clientType,
|
ClientType: clientType,
|
||||||
APIKey: apiKey,
|
APIKey: apiKey,
|
||||||
|
|||||||
@@ -95,21 +95,6 @@ type SystemFile struct {
|
|||||||
Content string
|
Content string
|
||||||
}
|
}
|
||||||
|
|
||||||
// ModelConfig holds provider and model information resolved from DB.
|
|
||||||
type ModelConfig struct {
|
|
||||||
ModelID string
|
|
||||||
ClientType string
|
|
||||||
APIKey string //nolint:gosec // carries provider credential material at runtime
|
|
||||||
BaseURL string
|
|
||||||
ReasoningConfig *ReasoningConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReasoningConfig controls extended thinking/reasoning behavior.
|
|
||||||
type ReasoningConfig struct {
|
|
||||||
Enabled bool
|
|
||||||
Effort string
|
|
||||||
}
|
|
||||||
|
|
||||||
func mustMarshal(v any) json.RawMessage {
|
func mustMarshal(v any) json.RawMessage {
|
||||||
data, err := json.Marshal(v)
|
data, err := json.Marshal(v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -11,9 +11,9 @@ import (
|
|||||||
"github.com/jackc/pgx/v5/pgtype"
|
"github.com/jackc/pgx/v5/pgtype"
|
||||||
sdk "github.com/memohai/twilight-ai/sdk"
|
sdk "github.com/memohai/twilight-ai/sdk"
|
||||||
|
|
||||||
"github.com/memohai/memoh/internal/agent"
|
|
||||||
"github.com/memohai/memoh/internal/db"
|
"github.com/memohai/memoh/internal/db"
|
||||||
"github.com/memohai/memoh/internal/db/sqlc"
|
"github.com/memohai/memoh/internal/db/sqlc"
|
||||||
|
"github.com/memohai/memoh/internal/models"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Service manages context compaction for bot conversations.
|
// Service manages context compaction for bot conversations.
|
||||||
@@ -103,7 +103,7 @@ func (s *Service) doCompaction(ctx context.Context, logID pgtype.UUID, sessionUU
|
|||||||
|
|
||||||
userPrompt := buildUserPrompt(priorSummaries, entries)
|
userPrompt := buildUserPrompt(priorSummaries, entries)
|
||||||
|
|
||||||
model := agent.CreateModel(agent.ModelConfig{
|
model := models.NewSDKChatModel(models.SDKModelConfig{
|
||||||
ClientType: cfg.ClientType,
|
ClientType: cfg.ClientType,
|
||||||
BaseURL: cfg.BaseURL,
|
BaseURL: cfg.BaseURL,
|
||||||
APIKey: cfg.APIKey,
|
APIKey: cfg.APIKey,
|
||||||
|
|||||||
@@ -270,15 +270,15 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r
|
|||||||
reasoningEffort = botSettings.ReasoningEffort
|
reasoningEffort = botSettings.ReasoningEffort
|
||||||
}
|
}
|
||||||
|
|
||||||
var reasoningConfig *agentpkg.ReasoningConfig
|
var reasoningConfig *models.ReasoningConfig
|
||||||
if reasoningEffort != "" {
|
if reasoningEffort != "" {
|
||||||
reasoningConfig = &agentpkg.ReasoningConfig{
|
reasoningConfig = &models.ReasoningConfig{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Effort: reasoningEffort,
|
Effort: reasoningEffort,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
modelCfg := agentpkg.ModelConfig{
|
modelCfg := models.SDKModelConfig{
|
||||||
ModelID: chatModel.ModelID,
|
ModelID: chatModel.ModelID,
|
||||||
ClientType: clientType,
|
ClientType: clientType,
|
||||||
APIKey: provider.ApiKey,
|
APIKey: provider.ApiKey,
|
||||||
@@ -286,7 +286,7 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r
|
|||||||
ReasoningConfig: reasoningConfig,
|
ReasoningConfig: reasoningConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
sdkModel := agentpkg.CreateModel(modelCfg)
|
sdkModel := models.NewSDKChatModel(modelCfg)
|
||||||
sdkMessages := modelMessagesToSDKMessages(nonNilModelMessages(messages))
|
sdkMessages := modelMessagesToSDKMessages(nonNilModelMessages(messages))
|
||||||
|
|
||||||
runCfg := agentpkg.RunConfig{
|
runCfg := agentpkg.RunConfig{
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
|
|
||||||
sdk "github.com/memohai/twilight-ai/sdk"
|
sdk "github.com/memohai/twilight-ai/sdk"
|
||||||
|
|
||||||
agentpkg "github.com/memohai/memoh/internal/agent"
|
|
||||||
"github.com/memohai/memoh/internal/conversation"
|
"github.com/memohai/memoh/internal/conversation"
|
||||||
"github.com/memohai/memoh/internal/db/sqlc"
|
"github.com/memohai/memoh/internal/db/sqlc"
|
||||||
messageevent "github.com/memohai/memoh/internal/message/event"
|
messageevent "github.com/memohai/memoh/internal/message/event"
|
||||||
@@ -105,13 +104,13 @@ func (r *Resolver) generateTitle(ctx context.Context, model models.GetResponse,
|
|||||||
"Return ONLY the title text, nothing else.\n\n" +
|
"Return ONLY the title text, nothing else.\n\n" +
|
||||||
"User: " + userSnippet
|
"User: " + userSnippet
|
||||||
|
|
||||||
modelCfg := agentpkg.ModelConfig{
|
modelCfg := models.SDKModelConfig{
|
||||||
ModelID: model.ModelID,
|
ModelID: model.ModelID,
|
||||||
ClientType: provider.ClientType,
|
ClientType: provider.ClientType,
|
||||||
APIKey: provider.ApiKey,
|
APIKey: provider.ApiKey,
|
||||||
BaseURL: provider.BaseUrl,
|
BaseURL: provider.BaseUrl,
|
||||||
}
|
}
|
||||||
sdkModel := agentpkg.CreateModel(modelCfg)
|
sdkModel := models.NewSDKChatModel(modelCfg)
|
||||||
|
|
||||||
genCtx, cancel := context.WithTimeout(ctx, titleGenerateTimeout)
|
genCtx, cancel := context.WithTimeout(ctx, titleGenerateTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ type denseRuntime struct {
|
|||||||
|
|
||||||
type denseModelSpec struct {
|
type denseModelSpec struct {
|
||||||
modelID string
|
modelID string
|
||||||
|
clientType string
|
||||||
baseURL string
|
baseURL string
|
||||||
apiKey string
|
apiKey string
|
||||||
dimensions int
|
dimensions int
|
||||||
@@ -74,7 +75,7 @@ func newDenseRuntime(providerConfig map[string]any, queries *dbsqlc.Queries, cfg
|
|||||||
return nil, fmt.Errorf("dense runtime: %w", err)
|
return nil, fmt.Errorf("dense runtime: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
embedModel := models.NewSDKEmbeddingModel(spec.baseURL, spec.apiKey, spec.modelID, denseEmbedTimeout)
|
embedModel := models.NewSDKEmbeddingModel(spec.clientType, spec.baseURL, spec.apiKey, spec.modelID, denseEmbedTimeout)
|
||||||
|
|
||||||
return &denseRuntime{
|
return &denseRuntime{
|
||||||
qdrant: qClient,
|
qdrant: qClient,
|
||||||
@@ -565,6 +566,7 @@ func resolveDenseEmbeddingModel(ctx context.Context, queries *dbsqlc.Queries, mo
|
|||||||
}
|
}
|
||||||
return denseModelSpec{
|
return denseModelSpec{
|
||||||
modelID: strings.TrimSpace(row.ModelID),
|
modelID: strings.TrimSpace(row.ModelID),
|
||||||
|
clientType: strings.TrimSpace(provider.ClientType),
|
||||||
baseURL: strings.TrimSpace(provider.BaseUrl),
|
baseURL: strings.TrimSpace(provider.BaseUrl),
|
||||||
apiKey: strings.TrimSpace(provider.ApiKey),
|
apiKey: strings.TrimSpace(provider.ApiKey),
|
||||||
dimensions: *cfg.Dimensions,
|
dimensions: *cfg.Dimensions,
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
|
|
||||||
"github.com/memohai/memoh/internal/agent"
|
"github.com/memohai/memoh/internal/agent"
|
||||||
adapters "github.com/memohai/memoh/internal/memory/adapters"
|
adapters "github.com/memohai/memoh/internal/memory/adapters"
|
||||||
|
"github.com/memohai/memoh/internal/models"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -42,7 +43,7 @@ func New(cfg Config) *Client {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) model() *sdk.Model {
|
func (c *Client) model() *sdk.Model {
|
||||||
return agent.CreateModel(agent.ModelConfig{
|
return models.NewSDKChatModel(models.SDKModelConfig{
|
||||||
ModelID: c.cfg.ModelID,
|
ModelID: c.cfg.ModelID,
|
||||||
ClientType: c.cfg.ClientType,
|
ClientType: c.cfg.ClientType,
|
||||||
APIKey: c.cfg.APIKey,
|
APIKey: c.cfg.APIKey,
|
||||||
|
|||||||
@@ -4,30 +4,42 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
googleembedding "github.com/memohai/twilight-ai/provider/google/embedding"
|
||||||
openaiembedding "github.com/memohai/twilight-ai/provider/openai/embedding"
|
openaiembedding "github.com/memohai/twilight-ai/provider/openai/embedding"
|
||||||
sdk "github.com/memohai/twilight-ai/sdk"
|
sdk "github.com/memohai/twilight-ai/sdk"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewSDKEmbeddingModel creates a Twilight AI SDK EmbeddingModel for the given
|
// NewSDKEmbeddingModel creates a Twilight AI SDK EmbeddingModel for the given
|
||||||
// provider configuration. Currently all embedding providers use the
|
// provider configuration. It dispatches to the native Google embedding provider
|
||||||
// OpenAI-compatible /embeddings endpoint (including Google-hosted models that
|
// when clientType is "google-generative-ai", and falls back to the
|
||||||
// expose the same wire format), so we route everything through the OpenAI
|
// OpenAI-compatible /embeddings endpoint for all other provider types.
|
||||||
// embedding provider. If a future provider requires a different wire protocol,
|
func NewSDKEmbeddingModel(clientType, baseURL, apiKey, modelID string, timeout time.Duration) *sdk.EmbeddingModel {
|
||||||
// add a branch here.
|
|
||||||
func NewSDKEmbeddingModel(baseURL, apiKey, modelID string, timeout time.Duration) *sdk.EmbeddingModel {
|
|
||||||
if timeout <= 0 {
|
if timeout <= 0 {
|
||||||
timeout = 30 * time.Second
|
timeout = 30 * time.Second
|
||||||
}
|
}
|
||||||
httpClient := &http.Client{Timeout: timeout}
|
httpClient := &http.Client{Timeout: timeout}
|
||||||
|
|
||||||
opts := []openaiembedding.Option{
|
switch ClientType(clientType) {
|
||||||
openaiembedding.WithAPIKey(apiKey),
|
case ClientTypeGoogleGenerativeAI:
|
||||||
openaiembedding.WithHTTPClient(httpClient),
|
opts := []googleembedding.Option{
|
||||||
}
|
googleembedding.WithAPIKey(apiKey),
|
||||||
if baseURL != "" {
|
googleembedding.WithHTTPClient(httpClient),
|
||||||
opts = append(opts, openaiembedding.WithBaseURL(baseURL))
|
}
|
||||||
}
|
if baseURL != "" {
|
||||||
|
opts = append(opts, googleembedding.WithBaseURL(baseURL))
|
||||||
|
}
|
||||||
|
p := googleembedding.New(opts...)
|
||||||
|
return p.EmbeddingModel(modelID)
|
||||||
|
|
||||||
p := openaiembedding.New(opts...)
|
default:
|
||||||
return p.EmbeddingModel(modelID)
|
opts := []openaiembedding.Option{
|
||||||
|
openaiembedding.WithAPIKey(apiKey),
|
||||||
|
openaiembedding.WithHTTPClient(httpClient),
|
||||||
|
}
|
||||||
|
if baseURL != "" {
|
||||||
|
opts = append(opts, openaiembedding.WithBaseURL(baseURL))
|
||||||
|
}
|
||||||
|
p := openaiembedding.New(opts...)
|
||||||
|
return p.EmbeddingModel(modelID)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ func (s *Service) Test(ctx context.Context, id string) (TestResponse, error) {
|
|||||||
// Embedding models don't have a chat Provider in the SDK — probe
|
// Embedding models don't have a chat Provider in the SDK — probe
|
||||||
// the /embeddings endpoint directly.
|
// the /embeddings endpoint directly.
|
||||||
if model.Type == string(ModelTypeEmbedding) {
|
if model.Type == string(ModelTypeEmbedding) {
|
||||||
return s.testEmbeddingModel(ctx, baseURL, apiKey, model.ModelID)
|
return s.testEmbeddingModel(ctx, string(clientType), baseURL, apiKey, model.ModelID)
|
||||||
}
|
}
|
||||||
|
|
||||||
sdkProvider := NewSDKProvider(baseURL, apiKey, clientType, probeTimeout)
|
sdkProvider := NewSDKProvider(baseURL, apiKey, clientType, probeTimeout)
|
||||||
@@ -100,11 +100,11 @@ func (s *Service) Test(ctx context.Context, id string) (TestResponse, error) {
|
|||||||
// testEmbeddingModel probes an embedding model by performing a minimal
|
// testEmbeddingModel probes an embedding model by performing a minimal
|
||||||
// embedding request via the Twilight SDK, verifying that the model is
|
// embedding request via the Twilight SDK, verifying that the model is
|
||||||
// reachable and functional rather than merely checking HTTP connectivity.
|
// reachable and functional rather than merely checking HTTP connectivity.
|
||||||
func (*Service) testEmbeddingModel(ctx context.Context, baseURL, apiKey, modelID string) (TestResponse, error) {
|
func (*Service) testEmbeddingModel(ctx context.Context, clientType, baseURL, apiKey, modelID string) (TestResponse, error) {
|
||||||
ctx, cancel := context.WithTimeout(ctx, probeTimeout)
|
ctx, cancel := context.WithTimeout(ctx, probeTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
model := NewSDKEmbeddingModel(baseURL, apiKey, modelID, probeTimeout)
|
model := NewSDKEmbeddingModel(clientType, baseURL, apiKey, modelID, probeTimeout)
|
||||||
client := sdk.NewClient()
|
client := sdk.NewClient()
|
||||||
|
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package agent
|
package models
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
anthropicmessages "github.com/memohai/twilight-ai/provider/anthropic/messages"
|
anthropicmessages "github.com/memohai/twilight-ai/provider/anthropic/messages"
|
||||||
googlegenerative "github.com/memohai/twilight-ai/provider/google/generativeai"
|
googlegenerative "github.com/memohai/twilight-ai/provider/google/generativeai"
|
||||||
openaicompletions "github.com/memohai/twilight-ai/provider/openai/completions"
|
openaicompletions "github.com/memohai/twilight-ai/provider/openai/completions"
|
||||||
@@ -8,23 +10,30 @@ import (
|
|||||||
sdk "github.com/memohai/twilight-ai/sdk"
|
sdk "github.com/memohai/twilight-ai/sdk"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ClientType constants matching the database model configuration.
|
// SDKModelConfig holds provider and model information resolved from DB,
|
||||||
const (
|
// used to construct a Twilight AI SDK Model instance.
|
||||||
ClientTypeOpenAICompletions = "openai-completions"
|
type SDKModelConfig struct {
|
||||||
ClientTypeOpenAIResponses = "openai-responses"
|
ModelID string
|
||||||
ClientTypeAnthropicMessages = "anthropic-messages"
|
ClientType string
|
||||||
ClientTypeGoogleGenerativeAI = "google-generative-ai"
|
APIKey string //nolint:gosec // carries provider credential material at runtime
|
||||||
)
|
BaseURL string
|
||||||
|
ReasoningConfig *ReasoningConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReasoningConfig controls extended thinking/reasoning behavior.
|
||||||
|
type ReasoningConfig struct {
|
||||||
|
Enabled bool
|
||||||
|
Effort string
|
||||||
|
}
|
||||||
|
|
||||||
// Reasoning budget maps per client type.
|
|
||||||
var (
|
var (
|
||||||
anthropicBudget = map[string]int{"low": 5000, "medium": 16000, "high": 50000}
|
anthropicBudget = map[string]int{"low": 5000, "medium": 16000, "high": 50000}
|
||||||
googleBudget = map[string]int{"low": 5000, "medium": 16000, "high": 50000}
|
googleBudget = map[string]int{"low": 5000, "medium": 16000, "high": 50000}
|
||||||
)
|
)
|
||||||
|
|
||||||
// CreateModel builds a Twilight AI SDK Model from the resolved model config.
|
// NewSDKChatModel builds a Twilight AI SDK Model from the resolved model config.
|
||||||
func CreateModel(cfg ModelConfig) *sdk.Model {
|
func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model {
|
||||||
switch cfg.ClientType {
|
switch ClientType(cfg.ClientType) {
|
||||||
case ClientTypeOpenAICompletions:
|
case ClientTypeOpenAICompletions:
|
||||||
opts := []openaicompletions.Option{
|
opts := []openaicompletions.Option{
|
||||||
openaicompletions.WithAPIKey(cfg.APIKey),
|
openaicompletions.WithAPIKey(cfg.APIKey),
|
||||||
@@ -53,7 +62,7 @@ func CreateModel(cfg ModelConfig) *sdk.Model {
|
|||||||
opts = append(opts, anthropicmessages.WithBaseURL(cfg.BaseURL))
|
opts = append(opts, anthropicmessages.WithBaseURL(cfg.BaseURL))
|
||||||
}
|
}
|
||||||
if cfg.ReasoningConfig != nil && cfg.ReasoningConfig.Enabled {
|
if cfg.ReasoningConfig != nil && cfg.ReasoningConfig.Enabled {
|
||||||
budget := ReasoningBudgetTokens(ClientTypeAnthropicMessages, cfg.ReasoningConfig.Effort)
|
budget := ReasoningBudgetTokens(cfg.ClientType, cfg.ReasoningConfig.Effort)
|
||||||
opts = append(opts, anthropicmessages.WithThinking(anthropicmessages.ThinkingConfig{
|
opts = append(opts, anthropicmessages.WithThinking(anthropicmessages.ThinkingConfig{
|
||||||
Type: "enabled",
|
Type: "enabled",
|
||||||
BudgetTokens: budget,
|
BudgetTokens: budget,
|
||||||
@@ -73,7 +82,6 @@ func CreateModel(cfg ModelConfig) *sdk.Model {
|
|||||||
return p.ChatModel(cfg.ModelID)
|
return p.ChatModel(cfg.ModelID)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
// OpenAI-compatible fallback
|
|
||||||
opts := []openaicompletions.Option{
|
opts := []openaicompletions.Option{
|
||||||
openaicompletions.WithAPIKey(cfg.APIKey),
|
openaicompletions.WithAPIKey(cfg.APIKey),
|
||||||
}
|
}
|
||||||
@@ -86,7 +94,7 @@ func CreateModel(cfg ModelConfig) *sdk.Model {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// BuildReasoningOptions returns SDK generation options for reasoning/thinking.
|
// BuildReasoningOptions returns SDK generation options for reasoning/thinking.
|
||||||
func BuildReasoningOptions(cfg ModelConfig) []sdk.GenerateOption {
|
func BuildReasoningOptions(cfg SDKModelConfig) []sdk.GenerateOption {
|
||||||
if cfg.ReasoningConfig == nil || !cfg.ReasoningConfig.Enabled {
|
if cfg.ReasoningConfig == nil || !cfg.ReasoningConfig.Enabled {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -95,9 +103,8 @@ func BuildReasoningOptions(cfg ModelConfig) []sdk.GenerateOption {
|
|||||||
effort = "medium"
|
effort = "medium"
|
||||||
}
|
}
|
||||||
|
|
||||||
switch cfg.ClientType {
|
switch ClientType(cfg.ClientType) {
|
||||||
case ClientTypeAnthropicMessages:
|
case ClientTypeAnthropicMessages:
|
||||||
// Anthropic uses thinking budget — no SDK option, handled by provider
|
|
||||||
return nil
|
return nil
|
||||||
case ClientTypeOpenAIResponses, ClientTypeOpenAICompletions:
|
case ClientTypeOpenAIResponses, ClientTypeOpenAICompletions:
|
||||||
return []sdk.GenerateOption{sdk.WithReasoningEffort(effort)}
|
return []sdk.GenerateOption{sdk.WithReasoningEffort(effort)}
|
||||||
@@ -113,7 +120,7 @@ func ReasoningBudgetTokens(clientType, effort string) int {
|
|||||||
if effort == "" {
|
if effort == "" {
|
||||||
effort = "medium"
|
effort = "medium"
|
||||||
}
|
}
|
||||||
switch clientType {
|
switch ClientType(clientType) {
|
||||||
case ClientTypeAnthropicMessages:
|
case ClientTypeAnthropicMessages:
|
||||||
if b, ok := anthropicBudget[effort]; ok {
|
if b, ok := anthropicBudget[effort]; ok {
|
||||||
return b
|
return b
|
||||||
@@ -128,3 +135,21 @@ func ReasoningBudgetTokens(clientType, effort string) int {
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ResolveClientType infers the client type string from an SDK Model's provider name.
|
||||||
|
func ResolveClientType(model *sdk.Model) string {
|
||||||
|
if model == nil || model.Provider == nil {
|
||||||
|
return string(ClientTypeOpenAICompletions)
|
||||||
|
}
|
||||||
|
name := model.Provider.Name()
|
||||||
|
switch {
|
||||||
|
case strings.Contains(name, "anthropic"):
|
||||||
|
return string(ClientTypeAnthropicMessages)
|
||||||
|
case strings.Contains(name, "google"):
|
||||||
|
return string(ClientTypeGoogleGenerativeAI)
|
||||||
|
case strings.Contains(name, "responses"):
|
||||||
|
return string(ClientTypeOpenAIResponses)
|
||||||
|
default:
|
||||||
|
return string(ClientTypeOpenAICompletions)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user