refactor: unify SDK model factories into internal/models
Move CreateModel, BuildReasoningOptions, ReasoningBudgetTokens and related types from internal/agent to internal/models as NewSDKChatModel, SDKModelConfig, etc. This eliminates duplicate ClientType constants and centralises all Twilight AI SDK instance creation in a single package. NewSDKEmbeddingModel now accepts a clientType parameter and dispatches to the native Google embedding provider for google-generative-ai, instead of always using the OpenAI-compatible endpoint.
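
For illustration, a sketch of the post-refactor call shape at a typical call site, assuming only the API visible in the diff below (the helper name newChatModel is hypothetical):

    package example

    import (
        sdk "github.com/memohai/twilight-ai/sdk"

        "github.com/memohai/memoh/internal/models"
    )

    // newChatModel shows the new unified constructor; previously callers
    // used agent.CreateModel(agent.ModelConfig{...}).
    func newChatModel(modelID, clientType, apiKey, baseURL string) *sdk.Model {
        return models.NewSDKChatModel(models.SDKModelConfig{
            ModelID:    modelID,
            ClientType: clientType,
            APIKey:     apiKey,
            BaseURL:    baseURL,
        })
    }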
@@ -11,6 +11,7 @@ import (
 	sdk "github.com/memohai/twilight-ai/sdk"
 
 	"github.com/memohai/memoh/internal/agent/tools"
+	"github.com/memohai/memoh/internal/models"
 	"github.com/memohai/memoh/internal/workspace/bridge"
 )
 
@@ -371,9 +372,9 @@ func (*Agent) buildGenerateOptions(cfg RunConfig, tools []sdk.Tool, prepareStep
 	if prepareStep != nil {
 		opts = append(opts, sdk.WithPrepareStep(prepareStep))
 	}
-	opts = append(opts, BuildReasoningOptions(ModelConfig{
-		ClientType: resolveClientType(cfg.Model),
-		ReasoningConfig: &ReasoningConfig{
+	opts = append(opts, models.BuildReasoningOptions(models.SDKModelConfig{
+		ClientType: models.ResolveClientType(cfg.Model),
+		ReasoningConfig: &models.ReasoningConfig{
 			Enabled: cfg.ReasoningEffort != "",
 			Effort:  cfg.ReasoningEffort,
 		},
@@ -381,23 +382,6 @@ func (*Agent) buildGenerateOptions(cfg RunConfig, tools []sdk.Tool, prepareStep
 	return opts
 }
 
-func resolveClientType(model *sdk.Model) string {
-	if model == nil || model.Provider == nil {
-		return ClientTypeOpenAICompletions
-	}
-	name := model.Provider.Name()
-	switch {
-	case strings.Contains(name, "anthropic"):
-		return ClientTypeAnthropicMessages
-	case strings.Contains(name, "google"):
-		return ClientTypeGoogleGenerativeAI
-	case strings.Contains(name, "responses"):
-		return ClientTypeOpenAIResponses
-	default:
-		return ClientTypeOpenAICompletions
-	}
-}
-
 // assembleTools collects tools from all registered ToolProviders.
 func (a *Agent) assembleTools(ctx context.Context, cfg RunConfig) ([]sdk.Tool, error) {
 	if len(a.toolProviders) == 0 {

@@ -7,6 +7,7 @@ import (
 	sdk "github.com/memohai/twilight-ai/sdk"
 
 	agenttools "github.com/memohai/memoh/internal/agent/tools"
+	"github.com/memohai/memoh/internal/models"
 )
 
 func decorateReadMediaTools(model *sdk.Model, tools []sdk.Tool) ([]sdk.Tool, *readMediaDecorationState) {
@@ -14,7 +15,7 @@ func decorateReadMediaTools(model *sdk.Model, tools []sdk.Tool) ([]sdk.Tool, *re
 		return tools, nil
 	}
 
-	clientType := resolveClientType(model)
+	clientType := models.ResolveClientType(model)
 	state := &readMediaDecorationState{
 		pendingImages: make(map[string]sdk.ImagePart),
 	}
@@ -164,7 +165,7 @@ func buildReadMediaImagePart(clientType, imageBase64, mediaType string) sdk.Imag
 	}
 
 	image := imageBase64
-	if clientType != ClientTypeAnthropicMessages {
+	if clientType != string(models.ClientTypeAnthropicMessages) {
 		image = fmt.Sprintf("data:%s;base64,%s", mediaType, imageBase64)
 	}
 	return sdk.ImagePart{

@@ -6,6 +6,7 @@ import (
 	sdk "github.com/memohai/twilight-ai/sdk"
 
 	"github.com/memohai/memoh/internal/agent/tools"
+	"github.com/memohai/memoh/internal/models"
 )
 
 // SpawnAdapter wraps *Agent to satisfy tools.SpawnAgent without creating
@@ -68,10 +69,10 @@ func SpawnSystemPrompt(sessionType string) string {
 	})
 }
 
-// SpawnModelCreatorFunc returns a tools.ModelCreator that delegates to agent.CreateModel.
+// SpawnModelCreatorFunc returns a tools.ModelCreator that delegates to models.NewSDKChatModel.
 func SpawnModelCreatorFunc() tools.ModelCreator {
 	return func(modelID, clientType, apiKey, baseURL string) *sdk.Model {
-		return CreateModel(ModelConfig{
+		return models.NewSDKChatModel(models.SDKModelConfig{
 			ModelID:    modelID,
 			ClientType: clientType,
 			APIKey:     apiKey,

@@ -95,21 +95,6 @@ type SystemFile struct {
 	Content string
 }
 
-// ModelConfig holds provider and model information resolved from DB.
-type ModelConfig struct {
-	ModelID         string
-	ClientType      string
-	APIKey          string //nolint:gosec // carries provider credential material at runtime
-	BaseURL         string
-	ReasoningConfig *ReasoningConfig
-}
-
-// ReasoningConfig controls extended thinking/reasoning behavior.
-type ReasoningConfig struct {
-	Enabled bool
-	Effort  string
-}
-
 func mustMarshal(v any) json.RawMessage {
 	data, err := json.Marshal(v)
 	if err != nil {

@@ -11,9 +11,9 @@ import (
 	"github.com/jackc/pgx/v5/pgtype"
 	sdk "github.com/memohai/twilight-ai/sdk"
 
-	"github.com/memohai/memoh/internal/agent"
 	"github.com/memohai/memoh/internal/db"
 	"github.com/memohai/memoh/internal/db/sqlc"
+	"github.com/memohai/memoh/internal/models"
 )
 
 // Service manages context compaction for bot conversations.
@@ -103,7 +103,7 @@ func (s *Service) doCompaction(ctx context.Context, logID pgtype.UUID, sessionUU
 
 	userPrompt := buildUserPrompt(priorSummaries, entries)
 
-	model := agent.CreateModel(agent.ModelConfig{
+	model := models.NewSDKChatModel(models.SDKModelConfig{
 		ClientType: cfg.ClientType,
 		BaseURL:    cfg.BaseURL,
 		APIKey:     cfg.APIKey,

@@ -270,15 +270,15 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r
 		reasoningEffort = botSettings.ReasoningEffort
 	}
 
-	var reasoningConfig *agentpkg.ReasoningConfig
+	var reasoningConfig *models.ReasoningConfig
 	if reasoningEffort != "" {
-		reasoningConfig = &agentpkg.ReasoningConfig{
+		reasoningConfig = &models.ReasoningConfig{
 			Enabled: true,
 			Effort:  reasoningEffort,
 		}
 	}
 
-	modelCfg := agentpkg.ModelConfig{
+	modelCfg := models.SDKModelConfig{
 		ModelID:    chatModel.ModelID,
 		ClientType: clientType,
 		APIKey:     provider.ApiKey,
@@ -286,7 +286,7 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r
 		ReasoningConfig: reasoningConfig,
 	}
 
-	sdkModel := agentpkg.CreateModel(modelCfg)
+	sdkModel := models.NewSDKChatModel(modelCfg)
 	sdkMessages := modelMessagesToSDKMessages(nonNilModelMessages(messages))
 
 	runCfg := agentpkg.RunConfig{

@@ -9,7 +9,6 @@ import (
 
 	sdk "github.com/memohai/twilight-ai/sdk"
 
-	agentpkg "github.com/memohai/memoh/internal/agent"
 	"github.com/memohai/memoh/internal/conversation"
 	"github.com/memohai/memoh/internal/db/sqlc"
 	messageevent "github.com/memohai/memoh/internal/message/event"
@@ -105,13 +104,13 @@ func (r *Resolver) generateTitle(ctx context.Context, model models.GetResponse,
 		"Return ONLY the title text, nothing else.\n\n" +
 		"User: " + userSnippet
 
-	modelCfg := agentpkg.ModelConfig{
+	modelCfg := models.SDKModelConfig{
 		ModelID:    model.ModelID,
 		ClientType: provider.ClientType,
 		APIKey:     provider.ApiKey,
 		BaseURL:    provider.BaseUrl,
 	}
-	sdkModel := agentpkg.CreateModel(modelCfg)
+	sdkModel := models.NewSDKChatModel(modelCfg)
 
 	genCtx, cancel := context.WithTimeout(ctx, titleGenerateTimeout)
 	defer cancel()

@@ -35,6 +35,7 @@ type denseRuntime struct {
 
 type denseModelSpec struct {
 	modelID    string
+	clientType string
 	baseURL    string
 	apiKey     string
 	dimensions int
@@ -74,7 +75,7 @@ func newDenseRuntime(providerConfig map[string]any, queries *dbsqlc.Queries, cfg
 		return nil, fmt.Errorf("dense runtime: %w", err)
 	}
 
-	embedModel := models.NewSDKEmbeddingModel(spec.baseURL, spec.apiKey, spec.modelID, denseEmbedTimeout)
+	embedModel := models.NewSDKEmbeddingModel(spec.clientType, spec.baseURL, spec.apiKey, spec.modelID, denseEmbedTimeout)
 
 	return &denseRuntime{
 		qdrant: qClient,
@@ -565,6 +566,7 @@ func resolveDenseEmbeddingModel(ctx context.Context, queries *dbsqlc.Queries, mo
 	}
 	return denseModelSpec{
 		modelID:    strings.TrimSpace(row.ModelID),
+		clientType: strings.TrimSpace(provider.ClientType),
 		baseURL:    strings.TrimSpace(provider.BaseUrl),
 		apiKey:     strings.TrimSpace(provider.ApiKey),
 		dimensions: *cfg.Dimensions,

@@ -11,6 +11,7 @@ import (
 
 	"github.com/memohai/memoh/internal/agent"
 	adapters "github.com/memohai/memoh/internal/memory/adapters"
+	"github.com/memohai/memoh/internal/models"
 )
 
 const (
@@ -42,7 +43,7 @@ func New(cfg Config) *Client {
 }
 
 func (c *Client) model() *sdk.Model {
-	return agent.CreateModel(agent.ModelConfig{
+	return models.NewSDKChatModel(models.SDKModelConfig{
 		ModelID:    c.cfg.ModelID,
 		ClientType: c.cfg.ClientType,
 		APIKey:     c.cfg.APIKey,

@@ -4,30 +4,42 @@ import (
 	"net/http"
 	"time"
 
+	googleembedding "github.com/memohai/twilight-ai/provider/google/embedding"
 	openaiembedding "github.com/memohai/twilight-ai/provider/openai/embedding"
 	sdk "github.com/memohai/twilight-ai/sdk"
 )
 
 // NewSDKEmbeddingModel creates a Twilight AI SDK EmbeddingModel for the given
-// provider configuration. Currently all embedding providers use the
-// OpenAI-compatible /embeddings endpoint (including Google-hosted models that
-// expose the same wire format), so we route everything through the OpenAI
-// embedding provider. If a future provider requires a different wire protocol,
-// add a branch here.
-func NewSDKEmbeddingModel(baseURL, apiKey, modelID string, timeout time.Duration) *sdk.EmbeddingModel {
+// provider configuration. It dispatches to the native Google embedding provider
+// when clientType is "google-generative-ai", and falls back to the
+// OpenAI-compatible /embeddings endpoint for all other provider types.
+func NewSDKEmbeddingModel(clientType, baseURL, apiKey, modelID string, timeout time.Duration) *sdk.EmbeddingModel {
 	if timeout <= 0 {
 		timeout = 30 * time.Second
 	}
 	httpClient := &http.Client{Timeout: timeout}
 
-	opts := []openaiembedding.Option{
-		openaiembedding.WithAPIKey(apiKey),
-		openaiembedding.WithHTTPClient(httpClient),
-	}
-	if baseURL != "" {
-		opts = append(opts, openaiembedding.WithBaseURL(baseURL))
-	}
-	p := openaiembedding.New(opts...)
-	return p.EmbeddingModel(modelID)
+	switch ClientType(clientType) {
+	case ClientTypeGoogleGenerativeAI:
+		opts := []googleembedding.Option{
+			googleembedding.WithAPIKey(apiKey),
+			googleembedding.WithHTTPClient(httpClient),
+		}
+		if baseURL != "" {
+			opts = append(opts, googleembedding.WithBaseURL(baseURL))
+		}
+		p := googleembedding.New(opts...)
+		return p.EmbeddingModel(modelID)
+	default:
+		opts := []openaiembedding.Option{
+			openaiembedding.WithAPIKey(apiKey),
+			openaiembedding.WithHTTPClient(httpClient),
+		}
+		if baseURL != "" {
+			opts = append(opts, openaiembedding.WithBaseURL(baseURL))
+		}
+		p := openaiembedding.New(opts...)
+		return p.EmbeddingModel(modelID)
+	}
 }

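As a usage sketch of the new signature, both dispatch paths go through the same constructor. All argument values here are hypothetical, including the model IDs:

    package example

    import (
        "time"

        "github.com/memohai/memoh/internal/models"
    )

    func buildEmbedders() {
        // "google-generative-ai" is routed to the native Google embedding provider.
        google := models.NewSDKEmbeddingModel(
            "google-generative-ai", "", "example-api-key", "text-embedding-004", 30*time.Second)

        // Any other client type falls back to the OpenAI-compatible /embeddings endpoint.
        openai := models.NewSDKEmbeddingModel(
            "openai-completions", "https://api.example.com/v1", "example-api-key", "text-embedding-3-small", 30*time.Second)

        _, _ = google, openai
    }
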
@@ -43,7 +43,7 @@ func (s *Service) Test(ctx context.Context, id string) (TestResponse, error) {
 	// Embedding models don't have a chat Provider in the SDK — probe
 	// the /embeddings endpoint directly.
 	if model.Type == string(ModelTypeEmbedding) {
-		return s.testEmbeddingModel(ctx, baseURL, apiKey, model.ModelID)
+		return s.testEmbeddingModel(ctx, string(clientType), baseURL, apiKey, model.ModelID)
 	}
 
 	sdkProvider := NewSDKProvider(baseURL, apiKey, clientType, probeTimeout)
@@ -100,11 +100,11 @@ func (s *Service) Test(ctx context.Context, id string) (TestResponse, error) {
 // testEmbeddingModel probes an embedding model by performing a minimal
 // embedding request via the Twilight SDK, verifying that the model is
 // reachable and functional rather than merely checking HTTP connectivity.
-func (*Service) testEmbeddingModel(ctx context.Context, baseURL, apiKey, modelID string) (TestResponse, error) {
+func (*Service) testEmbeddingModel(ctx context.Context, clientType, baseURL, apiKey, modelID string) (TestResponse, error) {
 	ctx, cancel := context.WithTimeout(ctx, probeTimeout)
 	defer cancel()
 
-	model := NewSDKEmbeddingModel(baseURL, apiKey, modelID, probeTimeout)
+	model := NewSDKEmbeddingModel(clientType, baseURL, apiKey, modelID, probeTimeout)
 	client := sdk.NewClient()
 
 	start := time.Now()

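The embedding request itself goes through the Twilight SDK client; that call is not shown in this diff, so only the timing pattern the probe relies on is sketched here, with the helper name probeLatency being hypothetical:

    package example

    import (
        "context"
        "time"
    )

    // probeLatency runs one minimal request supplied by the caller and
    // reports how long it took, the same measure-one-request pattern the
    // probe above uses around start := time.Now().
    func probeLatency(ctx context.Context, call func(context.Context) error) (time.Duration, error) {
        start := time.Now()
        err := call(ctx)
        return time.Since(start), err
    }
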
@@ -1,6 +1,8 @@
-package agent
+package models
 
 import (
+	"strings"
+
 	anthropicmessages "github.com/memohai/twilight-ai/provider/anthropic/messages"
 	googlegenerative "github.com/memohai/twilight-ai/provider/google/generativeai"
 	openaicompletions "github.com/memohai/twilight-ai/provider/openai/completions"
@@ -8,23 +10,30 @@ import (
 	sdk "github.com/memohai/twilight-ai/sdk"
 )
 
-// ClientType constants matching the database model configuration.
-const (
-	ClientTypeOpenAICompletions  = "openai-completions"
-	ClientTypeOpenAIResponses    = "openai-responses"
-	ClientTypeAnthropicMessages  = "anthropic-messages"
-	ClientTypeGoogleGenerativeAI = "google-generative-ai"
-)
+// SDKModelConfig holds provider and model information resolved from DB,
+// used to construct a Twilight AI SDK Model instance.
+type SDKModelConfig struct {
+	ModelID         string
+	ClientType      string
+	APIKey          string //nolint:gosec // carries provider credential material at runtime
+	BaseURL         string
+	ReasoningConfig *ReasoningConfig
+}
+
+// ReasoningConfig controls extended thinking/reasoning behavior.
+type ReasoningConfig struct {
+	Enabled bool
+	Effort  string
+}
 
 // Reasoning budget maps per client type.
 var (
 	anthropicBudget = map[string]int{"low": 5000, "medium": 16000, "high": 50000}
 	googleBudget    = map[string]int{"low": 5000, "medium": 16000, "high": 50000}
 )
 
-// CreateModel builds a Twilight AI SDK Model from the resolved model config.
-func CreateModel(cfg ModelConfig) *sdk.Model {
-	switch cfg.ClientType {
+// NewSDKChatModel builds a Twilight AI SDK Model from the resolved model config.
+func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model {
+	switch ClientType(cfg.ClientType) {
 	case ClientTypeOpenAICompletions:
 		opts := []openaicompletions.Option{
 			openaicompletions.WithAPIKey(cfg.APIKey),
@@ -53,7 +62,7 @@ func CreateModel(cfg ModelConfig) *sdk.Model {
 			opts = append(opts, anthropicmessages.WithBaseURL(cfg.BaseURL))
 		}
 		if cfg.ReasoningConfig != nil && cfg.ReasoningConfig.Enabled {
-			budget := ReasoningBudgetTokens(ClientTypeAnthropicMessages, cfg.ReasoningConfig.Effort)
+			budget := ReasoningBudgetTokens(cfg.ClientType, cfg.ReasoningConfig.Effort)
 			opts = append(opts, anthropicmessages.WithThinking(anthropicmessages.ThinkingConfig{
 				Type:         "enabled",
 				BudgetTokens: budget,
@@ -73,7 +82,6 @@ func CreateModel(cfg ModelConfig) *sdk.Model {
 		return p.ChatModel(cfg.ModelID)
 
 	default:
-		// OpenAI-compatible fallback
 		opts := []openaicompletions.Option{
 			openaicompletions.WithAPIKey(cfg.APIKey),
 		}
@@ -86,7 +94,7 @@ func CreateModel(cfg ModelConfig) *sdk.Model {
 }
 
 // BuildReasoningOptions returns SDK generation options for reasoning/thinking.
-func BuildReasoningOptions(cfg ModelConfig) []sdk.GenerateOption {
+func BuildReasoningOptions(cfg SDKModelConfig) []sdk.GenerateOption {
 	if cfg.ReasoningConfig == nil || !cfg.ReasoningConfig.Enabled {
 		return nil
 	}
@@ -95,9 +103,8 @@ func BuildReasoningOptions(cfg ModelConfig) []sdk.GenerateOption {
 		effort = "medium"
 	}
 
-	switch cfg.ClientType {
+	switch ClientType(cfg.ClientType) {
 	case ClientTypeAnthropicMessages:
-		// Anthropic uses thinking budget — no SDK option, handled by provider
 		return nil
 	case ClientTypeOpenAIResponses, ClientTypeOpenAICompletions:
 		return []sdk.GenerateOption{sdk.WithReasoningEffort(effort)}
@@ -113,7 +120,7 @@ func ReasoningBudgetTokens(clientType, effort string) int {
 	if effort == "" {
 		effort = "medium"
 	}
-	switch clientType {
+	switch ClientType(clientType) {
 	case ClientTypeAnthropicMessages:
 		if b, ok := anthropicBudget[effort]; ok {
 			return b
@@ -128,3 +135,21 @@ func ReasoningBudgetTokens(clientType, effort string) int {
 		return 0
 	}
 }
+
+// ResolveClientType infers the client type string from an SDK Model's provider name.
+func ResolveClientType(model *sdk.Model) string {
+	if model == nil || model.Provider == nil {
+		return string(ClientTypeOpenAICompletions)
+	}
+	name := model.Provider.Name()
+	switch {
+	case strings.Contains(name, "anthropic"):
+		return string(ClientTypeAnthropicMessages)
+	case strings.Contains(name, "google"):
+		return string(ClientTypeGoogleGenerativeAI)
+	case strings.Contains(name, "responses"):
+		return string(ClientTypeOpenAIResponses)
+	default:
+		return string(ClientTypeOpenAICompletions)
+	}
+}
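
Putting the moved helpers together, an illustrative sketch of their behavior; the expected values in the comments follow the budget maps and fallbacks defined in this file, and anything beyond that is an assumption:

    package example

    import "github.com/memohai/memoh/internal/models"

    func reasoningExamples() {
        // Anthropic efforts map through anthropicBudget; empty effort
        // defaults to "medium".
        _ = models.ReasoningBudgetTokens("anthropic-messages", "high") // 50000
        _ = models.ReasoningBudgetTokens("anthropic-messages", "")     // 16000 ("medium")

        // Client types without a budget map presumably hit the final
        // return 0 path.
        _ = models.ReasoningBudgetTokens("openai-completions", "high") // 0

        // A nil model resolves to the OpenAI-completions fallback.
        _ = models.ResolveClientType(nil) // "openai-completions"
    }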