diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 4e4ecf18..253d88e5 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -11,6 +11,7 @@ import ( sdk "github.com/memohai/twilight-ai/sdk" "github.com/memohai/memoh/internal/agent/tools" + "github.com/memohai/memoh/internal/models" "github.com/memohai/memoh/internal/workspace/bridge" ) @@ -371,9 +372,9 @@ func (*Agent) buildGenerateOptions(cfg RunConfig, tools []sdk.Tool, prepareStep if prepareStep != nil { opts = append(opts, sdk.WithPrepareStep(prepareStep)) } - opts = append(opts, BuildReasoningOptions(ModelConfig{ - ClientType: resolveClientType(cfg.Model), - ReasoningConfig: &ReasoningConfig{ + opts = append(opts, models.BuildReasoningOptions(models.SDKModelConfig{ + ClientType: models.ResolveClientType(cfg.Model), + ReasoningConfig: &models.ReasoningConfig{ Enabled: cfg.ReasoningEffort != "", Effort: cfg.ReasoningEffort, }, @@ -381,23 +382,6 @@ func (*Agent) buildGenerateOptions(cfg RunConfig, tools []sdk.Tool, prepareStep return opts } -func resolveClientType(model *sdk.Model) string { - if model == nil || model.Provider == nil { - return ClientTypeOpenAICompletions - } - name := model.Provider.Name() - switch { - case strings.Contains(name, "anthropic"): - return ClientTypeAnthropicMessages - case strings.Contains(name, "google"): - return ClientTypeGoogleGenerativeAI - case strings.Contains(name, "responses"): - return ClientTypeOpenAIResponses - default: - return ClientTypeOpenAICompletions - } -} - // assembleTools collects tools from all registered ToolProviders. func (a *Agent) assembleTools(ctx context.Context, cfg RunConfig) ([]sdk.Tool, error) { if len(a.toolProviders) == 0 { diff --git a/internal/agent/read_media.go b/internal/agent/read_media.go index 4a5269a6..c16425b9 100644 --- a/internal/agent/read_media.go +++ b/internal/agent/read_media.go @@ -7,6 +7,7 @@ import ( sdk "github.com/memohai/twilight-ai/sdk" agenttools "github.com/memohai/memoh/internal/agent/tools" + "github.com/memohai/memoh/internal/models" ) func decorateReadMediaTools(model *sdk.Model, tools []sdk.Tool) ([]sdk.Tool, *readMediaDecorationState) { @@ -14,7 +15,7 @@ func decorateReadMediaTools(model *sdk.Model, tools []sdk.Tool) ([]sdk.Tool, *re return tools, nil } - clientType := resolveClientType(model) + clientType := models.ResolveClientType(model) state := &readMediaDecorationState{ pendingImages: make(map[string]sdk.ImagePart), } @@ -164,7 +165,7 @@ func buildReadMediaImagePart(clientType, imageBase64, mediaType string) sdk.Imag } image := imageBase64 - if clientType != ClientTypeAnthropicMessages { + if clientType != string(models.ClientTypeAnthropicMessages) { image = fmt.Sprintf("data:%s;base64,%s", mediaType, imageBase64) } return sdk.ImagePart{ diff --git a/internal/agent/spawn_adapter.go b/internal/agent/spawn_adapter.go index e411bacf..03ff3ab4 100644 --- a/internal/agent/spawn_adapter.go +++ b/internal/agent/spawn_adapter.go @@ -6,6 +6,7 @@ import ( sdk "github.com/memohai/twilight-ai/sdk" "github.com/memohai/memoh/internal/agent/tools" + "github.com/memohai/memoh/internal/models" ) // SpawnAdapter wraps *Agent to satisfy tools.SpawnAgent without creating @@ -68,10 +69,10 @@ func SpawnSystemPrompt(sessionType string) string { }) } -// SpawnModelCreatorFunc returns a tools.ModelCreator that delegates to agent.CreateModel. +// SpawnModelCreatorFunc returns a tools.ModelCreator that delegates to models.NewSDKChatModel. 
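+// The closure forwards modelID, clientType, apiKey, and baseURL into the config unchanged.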
func SpawnModelCreatorFunc() tools.ModelCreator { return func(modelID, clientType, apiKey, baseURL string) *sdk.Model { - return CreateModel(ModelConfig{ + return models.NewSDKChatModel(models.SDKModelConfig{ ModelID: modelID, ClientType: clientType, APIKey: apiKey, diff --git a/internal/agent/types.go b/internal/agent/types.go index 798cf84d..14d0f2d9 100644 --- a/internal/agent/types.go +++ b/internal/agent/types.go @@ -95,21 +95,6 @@ type SystemFile struct { Content string } -// ModelConfig holds provider and model information resolved from DB. -type ModelConfig struct { - ModelID string - ClientType string - APIKey string //nolint:gosec // carries provider credential material at runtime - BaseURL string - ReasoningConfig *ReasoningConfig -} - -// ReasoningConfig controls extended thinking/reasoning behavior. -type ReasoningConfig struct { - Enabled bool - Effort string -} - func mustMarshal(v any) json.RawMessage { data, err := json.Marshal(v) if err != nil { diff --git a/internal/compaction/service.go b/internal/compaction/service.go index 170c8f39..df16c1c6 100644 --- a/internal/compaction/service.go +++ b/internal/compaction/service.go @@ -11,9 +11,9 @@ import ( "github.com/jackc/pgx/v5/pgtype" sdk "github.com/memohai/twilight-ai/sdk" - "github.com/memohai/memoh/internal/agent" "github.com/memohai/memoh/internal/db" "github.com/memohai/memoh/internal/db/sqlc" + "github.com/memohai/memoh/internal/models" ) // Service manages context compaction for bot conversations. @@ -103,7 +103,7 @@ func (s *Service) doCompaction(ctx context.Context, logID pgtype.UUID, sessionUU userPrompt := buildUserPrompt(priorSummaries, entries) - model := agent.CreateModel(agent.ModelConfig{ + model := models.NewSDKChatModel(models.SDKModelConfig{ ClientType: cfg.ClientType, BaseURL: cfg.BaseURL, APIKey: cfg.APIKey, diff --git a/internal/conversation/flow/resolver.go b/internal/conversation/flow/resolver.go index c669274a..d9b9b69d 100644 --- a/internal/conversation/flow/resolver.go +++ b/internal/conversation/flow/resolver.go @@ -270,15 +270,15 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r reasoningEffort = botSettings.ReasoningEffort } - var reasoningConfig *agentpkg.ReasoningConfig + var reasoningConfig *models.ReasoningConfig if reasoningEffort != "" { - reasoningConfig = &agentpkg.ReasoningConfig{ + reasoningConfig = &models.ReasoningConfig{ Enabled: true, Effort: reasoningEffort, } } - modelCfg := agentpkg.ModelConfig{ + modelCfg := models.SDKModelConfig{ ModelID: chatModel.ModelID, ClientType: clientType, APIKey: provider.ApiKey, @@ -286,7 +286,7 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r ReasoningConfig: reasoningConfig, } - sdkModel := agentpkg.CreateModel(modelCfg) + sdkModel := models.NewSDKChatModel(modelCfg) sdkMessages := modelMessagesToSDKMessages(nonNilModelMessages(messages)) runCfg := agentpkg.RunConfig{ diff --git a/internal/conversation/flow/resolver_title.go b/internal/conversation/flow/resolver_title.go index 0d634243..f1453d40 100644 --- a/internal/conversation/flow/resolver_title.go +++ b/internal/conversation/flow/resolver_title.go @@ -9,7 +9,6 @@ import ( sdk "github.com/memohai/twilight-ai/sdk" - agentpkg "github.com/memohai/memoh/internal/agent" "github.com/memohai/memoh/internal/conversation" "github.com/memohai/memoh/internal/db/sqlc" messageevent "github.com/memohai/memoh/internal/message/event" @@ -105,13 +104,13 @@ func (r *Resolver) generateTitle(ctx context.Context, model models.GetResponse, 
"Return ONLY the title text, nothing else.\n\n" + "User: " + userSnippet - modelCfg := agentpkg.ModelConfig{ + modelCfg := models.SDKModelConfig{ ModelID: model.ModelID, ClientType: provider.ClientType, APIKey: provider.ApiKey, BaseURL: provider.BaseUrl, } - sdkModel := agentpkg.CreateModel(modelCfg) + sdkModel := models.NewSDKChatModel(modelCfg) genCtx, cancel := context.WithTimeout(ctx, titleGenerateTimeout) defer cancel() diff --git a/internal/memory/adapters/builtin/dense_runtime.go b/internal/memory/adapters/builtin/dense_runtime.go index a43551d8..d0006d0c 100644 --- a/internal/memory/adapters/builtin/dense_runtime.go +++ b/internal/memory/adapters/builtin/dense_runtime.go @@ -35,6 +35,7 @@ type denseRuntime struct { type denseModelSpec struct { modelID string + clientType string baseURL string apiKey string dimensions int @@ -74,7 +75,7 @@ func newDenseRuntime(providerConfig map[string]any, queries *dbsqlc.Queries, cfg return nil, fmt.Errorf("dense runtime: %w", err) } - embedModel := models.NewSDKEmbeddingModel(spec.baseURL, spec.apiKey, spec.modelID, denseEmbedTimeout) + embedModel := models.NewSDKEmbeddingModel(spec.clientType, spec.baseURL, spec.apiKey, spec.modelID, denseEmbedTimeout) return &denseRuntime{ qdrant: qClient, @@ -565,6 +566,7 @@ func resolveDenseEmbeddingModel(ctx context.Context, queries *dbsqlc.Queries, mo } return denseModelSpec{ modelID: strings.TrimSpace(row.ModelID), + clientType: strings.TrimSpace(provider.ClientType), baseURL: strings.TrimSpace(provider.BaseUrl), apiKey: strings.TrimSpace(provider.ApiKey), dimensions: *cfg.Dimensions, diff --git a/internal/memory/memllm/client.go b/internal/memory/memllm/client.go index 32121cc4..b88e82e3 100644 --- a/internal/memory/memllm/client.go +++ b/internal/memory/memllm/client.go @@ -11,6 +11,7 @@ import ( "github.com/memohai/memoh/internal/agent" adapters "github.com/memohai/memoh/internal/memory/adapters" + "github.com/memohai/memoh/internal/models" ) const ( @@ -42,7 +43,7 @@ func New(cfg Config) *Client { } func (c *Client) model() *sdk.Model { - return agent.CreateModel(agent.ModelConfig{ + return models.NewSDKChatModel(models.SDKModelConfig{ ModelID: c.cfg.ModelID, ClientType: c.cfg.ClientType, APIKey: c.cfg.APIKey, diff --git a/internal/models/embedding.go b/internal/models/embedding.go index 70ffb823..6ff8fd0b 100644 --- a/internal/models/embedding.go +++ b/internal/models/embedding.go @@ -4,30 +4,42 @@ import ( "net/http" "time" + googleembedding "github.com/memohai/twilight-ai/provider/google/embedding" openaiembedding "github.com/memohai/twilight-ai/provider/openai/embedding" sdk "github.com/memohai/twilight-ai/sdk" ) // NewSDKEmbeddingModel creates a Twilight AI SDK EmbeddingModel for the given -// provider configuration. Currently all embedding providers use the -// OpenAI-compatible /embeddings endpoint (including Google-hosted models that -// expose the same wire format), so we route everything through the OpenAI -// embedding provider. If a future provider requires a different wire protocol, -// add a branch here. -func NewSDKEmbeddingModel(baseURL, apiKey, modelID string, timeout time.Duration) *sdk.EmbeddingModel { +// provider configuration. It dispatches to the native Google embedding provider +// when clientType is "google-generative-ai", and falls back to the +// OpenAI-compatible /embeddings endpoint for all other provider types. 
+func NewSDKEmbeddingModel(clientType, baseURL, apiKey, modelID string, timeout time.Duration) *sdk.EmbeddingModel { if timeout <= 0 { timeout = 30 * time.Second } httpClient := &http.Client{Timeout: timeout} - opts := []openaiembedding.Option{ - openaiembedding.WithAPIKey(apiKey), - openaiembedding.WithHTTPClient(httpClient), - } - if baseURL != "" { - opts = append(opts, openaiembedding.WithBaseURL(baseURL)) - } + switch ClientType(clientType) { + case ClientTypeGoogleGenerativeAI: + opts := []googleembedding.Option{ + googleembedding.WithAPIKey(apiKey), + googleembedding.WithHTTPClient(httpClient), + } + if baseURL != "" { + opts = append(opts, googleembedding.WithBaseURL(baseURL)) + } + p := googleembedding.New(opts...) + return p.EmbeddingModel(modelID) - p := openaiembedding.New(opts...) - return p.EmbeddingModel(modelID) + default: + opts := []openaiembedding.Option{ + openaiembedding.WithAPIKey(apiKey), + openaiembedding.WithHTTPClient(httpClient), + } + if baseURL != "" { + opts = append(opts, openaiembedding.WithBaseURL(baseURL)) + } + p := openaiembedding.New(opts...) + return p.EmbeddingModel(modelID) + } } diff --git a/internal/models/probe.go b/internal/models/probe.go index eb6edd86..c37dc8fe 100644 --- a/internal/models/probe.go +++ b/internal/models/probe.go @@ -43,7 +43,7 @@ func (s *Service) Test(ctx context.Context, id string) (TestResponse, error) { // Embedding models don't have a chat Provider in the SDK — probe // the /embeddings endpoint directly. if model.Type == string(ModelTypeEmbedding) { - return s.testEmbeddingModel(ctx, baseURL, apiKey, model.ModelID) + return s.testEmbeddingModel(ctx, string(clientType), baseURL, apiKey, model.ModelID) } sdkProvider := NewSDKProvider(baseURL, apiKey, clientType, probeTimeout) @@ -100,11 +100,11 @@ func (s *Service) Test(ctx context.Context, id string) (TestResponse, error) { // testEmbeddingModel probes an embedding model by performing a minimal // embedding request via the Twilight SDK, verifying that the model is // reachable and functional rather than merely checking HTTP connectivity. -func (*Service) testEmbeddingModel(ctx context.Context, baseURL, apiKey, modelID string) (TestResponse, error) { +func (*Service) testEmbeddingModel(ctx context.Context, clientType, baseURL, apiKey, modelID string) (TestResponse, error) { ctx, cancel := context.WithTimeout(ctx, probeTimeout) defer cancel() - model := NewSDKEmbeddingModel(baseURL, apiKey, modelID, probeTimeout) + model := NewSDKEmbeddingModel(clientType, baseURL, apiKey, modelID, probeTimeout) client := sdk.NewClient() start := time.Now() diff --git a/internal/agent/model.go b/internal/models/sdk.go similarity index 68% rename from internal/agent/model.go rename to internal/models/sdk.go index a663df38..d5357c6a 100644 --- a/internal/agent/model.go +++ b/internal/models/sdk.go @@ -1,6 +1,8 @@ -package agent +package models import ( + "strings" + anthropicmessages "github.com/memohai/twilight-ai/provider/anthropic/messages" googlegenerative "github.com/memohai/twilight-ai/provider/google/generativeai" openaicompletions "github.com/memohai/twilight-ai/provider/openai/completions" @@ -8,23 +10,30 @@ import ( sdk "github.com/memohai/twilight-ai/sdk" ) -// ClientType constants matching the database model configuration. 
-const ( - ClientTypeOpenAICompletions = "openai-completions" - ClientTypeOpenAIResponses = "openai-responses" - ClientTypeAnthropicMessages = "anthropic-messages" - ClientTypeGoogleGenerativeAI = "google-generative-ai" -) +// SDKModelConfig holds provider and model information resolved from DB, +// used to construct a Twilight AI SDK Model instance. +type SDKModelConfig struct { + ModelID string + ClientType string + APIKey string //nolint:gosec // carries provider credential material at runtime + BaseURL string + ReasoningConfig *ReasoningConfig +} + +// ReasoningConfig controls extended thinking/reasoning behavior. +type ReasoningConfig struct { + Enabled bool + Effort string +} -// Reasoning budget maps per client type. var ( anthropicBudget = map[string]int{"low": 5000, "medium": 16000, "high": 50000} googleBudget = map[string]int{"low": 5000, "medium": 16000, "high": 50000} ) -// CreateModel builds a Twilight AI SDK Model from the resolved model config. -func CreateModel(cfg ModelConfig) *sdk.Model { - switch cfg.ClientType { +// NewSDKChatModel builds a Twilight AI SDK Model from the resolved model config. +func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model { + switch ClientType(cfg.ClientType) { case ClientTypeOpenAICompletions: opts := []openaicompletions.Option{ openaicompletions.WithAPIKey(cfg.APIKey), @@ -53,7 +62,7 @@ func CreateModel(cfg ModelConfig) *sdk.Model { opts = append(opts, anthropicmessages.WithBaseURL(cfg.BaseURL)) } if cfg.ReasoningConfig != nil && cfg.ReasoningConfig.Enabled { - budget := ReasoningBudgetTokens(ClientTypeAnthropicMessages, cfg.ReasoningConfig.Effort) + budget := ReasoningBudgetTokens(cfg.ClientType, cfg.ReasoningConfig.Effort) opts = append(opts, anthropicmessages.WithThinking(anthropicmessages.ThinkingConfig{ Type: "enabled", BudgetTokens: budget, @@ -73,7 +82,6 @@ func CreateModel(cfg ModelConfig) *sdk.Model { return p.ChatModel(cfg.ModelID) default: - // OpenAI-compatible fallback opts := []openaicompletions.Option{ openaicompletions.WithAPIKey(cfg.APIKey), } @@ -86,7 +94,7 @@ func CreateModel(cfg ModelConfig) *sdk.Model { } // BuildReasoningOptions returns SDK generation options for reasoning/thinking. -func BuildReasoningOptions(cfg ModelConfig) []sdk.GenerateOption { +func BuildReasoningOptions(cfg SDKModelConfig) []sdk.GenerateOption { if cfg.ReasoningConfig == nil || !cfg.ReasoningConfig.Enabled { return nil } @@ -95,9 +103,8 @@ func BuildReasoningOptions(cfg ModelConfig) []sdk.GenerateOption { effort = "medium" } - switch cfg.ClientType { + switch ClientType(cfg.ClientType) { case ClientTypeAnthropicMessages: - // Anthropic uses thinking budget — no SDK option, handled by provider return nil case ClientTypeOpenAIResponses, ClientTypeOpenAICompletions: return []sdk.GenerateOption{sdk.WithReasoningEffort(effort)} @@ -113,7 +120,7 @@ func ReasoningBudgetTokens(clientType, effort string) int { if effort == "" { effort = "medium" } - switch clientType { + switch ClientType(clientType) { case ClientTypeAnthropicMessages: if b, ok := anthropicBudget[effort]; ok { return b @@ -128,3 +135,21 @@ func ReasoningBudgetTokens(clientType, effort string) int { return 0 } } + +// ResolveClientType infers the client type string from an SDK Model's provider name. 
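+// Matching is substring-based; unrecognized provider names default to the
+// OpenAI completions client type.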
+func ResolveClientType(model *sdk.Model) string { + if model == nil || model.Provider == nil { + return string(ClientTypeOpenAICompletions) + } + name := model.Provider.Name() + switch { + case strings.Contains(name, "anthropic"): + return string(ClientTypeAnthropicMessages) + case strings.Contains(name, "google"): + return string(ClientTypeGoogleGenerativeAI) + case strings.Contains(name, "responses"): + return string(ClientTypeOpenAIResponses) + default: + return string(ClientTypeOpenAICompletions) + } +}
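
Usage sketch — a minimal, self-contained illustration of how callers construct chat and embedding models through the relocated internal/models API shown above. The package name, model IDs, and credentials are placeholders; the types, function signatures, and typed ClientType constants come from this diff.

package sketch

import (
	"time"

	"github.com/memohai/memoh/internal/models"
)

func buildModels() {
	// Chat model: ClientType selects the wire protocol. Unknown values fall
	// back to the OpenAI-completions provider, mirroring NewSDKChatModel's
	// default branch.
	chat := models.NewSDKChatModel(models.SDKModelConfig{
		ModelID:    "placeholder-chat-model", // placeholder
		ClientType: string(models.ClientTypeAnthropicMessages),
		APIKey:     "placeholder-key", // placeholder credential
		ReasoningConfig: &models.ReasoningConfig{
			Enabled: true,
			Effort:  "medium", // 16000-token Anthropic thinking budget via ReasoningBudgetTokens
		},
	})

	// Recover the client type when only the *sdk.Model is in scope,
	// as decorateReadMediaTools does.
	_ = models.ResolveClientType(chat)

	// Embedding model: the new clientType argument routes
	// "google-generative-ai" to the native Google provider; any other value
	// uses the OpenAI-compatible /embeddings endpoint.
	_ = models.NewSDKEmbeddingModel(
		string(models.ClientTypeGoogleGenerativeAI),
		"",                        // baseURL: empty keeps the provider default
		"placeholder-key",         // placeholder credential
		"placeholder-embed-model", // placeholder
		30*time.Second,
	)
}

The string conversions mirror read_media.go and probe.go above: models.ClientType is a typed constant set, while SDKModelConfig.ClientType and the NewSDKEmbeddingModel parameter stay plain strings resolved from the database.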