mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
1657 lines
49 KiB
Go
1657 lines
49 KiB
Go
package flow
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"net/http"
|
|
"sort"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/jackc/pgx/v5/pgtype"
|
|
|
|
attachmentpkg "github.com/memohai/memoh/internal/attachment"
|
|
"github.com/memohai/memoh/internal/conversation"
|
|
"github.com/memohai/memoh/internal/db"
|
|
"github.com/memohai/memoh/internal/db/sqlc"
|
|
"github.com/memohai/memoh/internal/memory"
|
|
messagepkg "github.com/memohai/memoh/internal/message"
|
|
"github.com/memohai/memoh/internal/models"
|
|
"github.com/memohai/memoh/internal/schedule"
|
|
"github.com/memohai/memoh/internal/settings"
|
|
)
|
|
|
|
// Tunables for context assembly and gateway payload construction.
const (
	// defaultMaxContextMinutes is the last-resort history window (24 hours),
	// used when neither the request nor the bot settings supply a positive one.
	defaultMaxContextMinutes = 24 * 60

	// Limits and namespace for the memory context message.
	// NOTE(review): the consumers of these four values are outside this
	// excerpt; confirm exact semantics there.
	memoryContextLimitPerScope = 4
	memoryContextMaxItems      = 8
	memoryContextItemMaxChars  = 220
	sharedMemoryNamespace      = "bot"

	// gatewayInlineAttachmentMaxBytes (20 MiB) keeps the gateway payload
	// bounded when binary attachments are inlined as data URLs.
	gatewayInlineAttachmentMaxBytes int64 = 20 << 20
)
|
|
|
|
// SkillEntry describes a single skill as produced by a SkillLoader.
type SkillEntry struct {
	// Name, Description and Content are the textual parts of the skill.
	Name, Description, Content string
	// Metadata holds free-form key/value data attached to the skill.
	Metadata map[string]any
}
|
|
|
|
// SkillLoader loads skills for a given bot from its container.
|
|
type SkillLoader interface {
|
|
LoadSkills(ctx context.Context, botID string) ([]SkillEntry, error)
|
|
}
|
|
|
|
// ConversationSettingsReader defines settings lookup behavior needed by flow resolution.
|
|
type ConversationSettingsReader interface {
|
|
GetSettings(ctx context.Context, conversationID string) (conversation.Settings, error)
|
|
}
|
|
|
|
// gatewayAssetLoader turns a content_hash reference into the underlying
// binary payload so it can be inlined before gateway dispatch.
type gatewayAssetLoader interface {
	// OpenForGateway opens the asset stored under contentHash for botID and
	// reports its MIME type. The caller is responsible for closing reader.
	OpenForGateway(ctx context.Context, botID, contentHash string) (reader io.ReadCloser, mime string, err error)
}
|
|
|
|
// Resolver orchestrates chat with the agent gateway.
|
|
type Resolver struct {
|
|
modelsService *models.Service
|
|
queries *sqlc.Queries
|
|
memoryService *memory.Service
|
|
conversationSvc ConversationSettingsReader
|
|
messageService messagepkg.Service
|
|
settingsService *settings.Service
|
|
skillLoader SkillLoader
|
|
assetLoader gatewayAssetLoader
|
|
gatewayBaseURL string
|
|
timeout time.Duration
|
|
logger *slog.Logger
|
|
httpClient *http.Client
|
|
streamingClient *http.Client
|
|
}
|
|
|
|
// NewResolver creates a Resolver that communicates with the agent gateway.
|
|
func NewResolver(
|
|
log *slog.Logger,
|
|
modelsService *models.Service,
|
|
queries *sqlc.Queries,
|
|
memoryService *memory.Service,
|
|
conversationSvc ConversationSettingsReader,
|
|
messageService messagepkg.Service,
|
|
settingsService *settings.Service,
|
|
gatewayBaseURL string,
|
|
timeout time.Duration,
|
|
) *Resolver {
|
|
if strings.TrimSpace(gatewayBaseURL) == "" {
|
|
gatewayBaseURL = "http://127.0.0.1:8081"
|
|
}
|
|
gatewayBaseURL = strings.TrimRight(gatewayBaseURL, "/")
|
|
if timeout <= 0 {
|
|
timeout = 60 * time.Second
|
|
}
|
|
return &Resolver{
|
|
modelsService: modelsService,
|
|
queries: queries,
|
|
memoryService: memoryService,
|
|
conversationSvc: conversationSvc,
|
|
messageService: messageService,
|
|
settingsService: settingsService,
|
|
gatewayBaseURL: gatewayBaseURL,
|
|
timeout: timeout,
|
|
logger: log.With(slog.String("service", "conversation_resolver")),
|
|
httpClient: &http.Client{Timeout: timeout},
|
|
streamingClient: &http.Client{},
|
|
}
|
|
}
|
|
|
|
// SetSkillLoader sets the skill loader used to populate usable skills in gateway requests.
|
|
func (r *Resolver) SetSkillLoader(sl SkillLoader) {
|
|
r.skillLoader = sl
|
|
}
|
|
|
|
// SetGatewayAssetLoader configures optional asset loading used to inline
|
|
// attachments before calling the agent gateway.
|
|
func (r *Resolver) SetGatewayAssetLoader(loader gatewayAssetLoader) {
|
|
r.assetLoader = loader
|
|
}
|
|
|
|
// --- gateway payload ---
|
|
|
|
// gatewayModelConfig is the "model" section of a gateway request: which model
// to call, through which client type, and with which provider credentials.
type gatewayModelConfig struct {
	ModelID    string `json:"modelId"`
	ClientType string `json:"clientType"`
	// Input lists the input modalities of the selected model.
	Input   []string `json:"input"`
	APIKey  string   `json:"apiKey"`
	BaseURL string   `json:"baseUrl"`
}
|
|
|
|
// gatewayIdentity is the "identity" section of a gateway request. It tells
// the gateway which bot and container are acting and on whose behalf.
type gatewayIdentity struct {
	BotID             string `json:"botId"`
	ContainerID       string `json:"containerId"`
	ChannelIdentityID string `json:"channelIdentityId"`
	DisplayName       string `json:"displayName"`
	// The remaining fields are omitted from the JSON when empty.
	CurrentPlatform  string `json:"currentPlatform,omitempty"`
	ConversationType string `json:"conversationType,omitempty"`
	SessionToken     string `json:"sessionToken,omitempty"`
}
|
|
|
|
// gatewaySkill is the wire form of a skill offered to the gateway in the
// request's "usableSkills" list.
type gatewaySkill struct {
	Name        string `json:"name"`
	Description string `json:"description"`
	Content     string `json:"content"`
	// Metadata is dropped from the JSON when nil or empty.
	Metadata map[string]any `json:"metadata,omitempty"`
}
|
|
|
|
type gatewayRequest struct {
|
|
Model gatewayModelConfig `json:"model"`
|
|
ActiveContextTime int `json:"activeContextTime"`
|
|
Channels []string `json:"channels"`
|
|
CurrentChannel string `json:"currentChannel"`
|
|
AllowedActions []string `json:"allowedActions,omitempty"`
|
|
Messages []conversation.ModelMessage `json:"messages"`
|
|
Skills []string `json:"skills"`
|
|
UsableSkills []gatewaySkill `json:"usableSkills"`
|
|
Query string `json:"query"`
|
|
Identity gatewayIdentity `json:"identity"`
|
|
Attachments []any `json:"attachments"`
|
|
}
|
|
|
|
type gatewayResponse struct {
|
|
Messages []conversation.ModelMessage `json:"messages"`
|
|
Skills []string `json:"skills"`
|
|
Usage json.RawMessage `json:"usage,omitempty"`
|
|
}
|
|
|
|
// gatewayUsage is the decoded form of a usage payload. Pointer fields
// distinguish "absent" (nil) from an explicit zero.
type gatewayUsage struct {
	InputTokens  *int `json:"inputTokens"`
	OutputTokens *int `json:"outputTokens"`
}
|
|
|
|
// gatewaySchedule mirrors the agent gateway's ScheduleModel and is sent as
// the "schedule" member of a /chat/trigger-schedule request.
type gatewaySchedule struct {
	ID          string `json:"id"`
	Name        string `json:"name"`
	Description string `json:"description"`
	Pattern     string `json:"pattern"`
	// MaxCalls is omitted from the JSON when nil.
	MaxCalls *int   `json:"maxCalls,omitempty"`
	Command  string `json:"command"`
}
|
|
|
|
// triggerScheduleRequest is the payload for POST /chat/trigger-schedule.
|
|
// It omits "query" from JSON so the trigger-schedule endpoint does not receive it.
|
|
type triggerScheduleRequest struct {
|
|
gatewayRequest
|
|
Schedule gatewaySchedule `json:"schedule"`
|
|
}
|
|
|
|
// MarshalJSON marshals the request without the "query" field for trigger-schedule.
|
|
func (t triggerScheduleRequest) MarshalJSON() ([]byte, error) {
|
|
type alias struct {
|
|
gatewayRequest
|
|
Schedule gatewaySchedule `json:"schedule"`
|
|
}
|
|
raw, err := json.Marshal(alias(t))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var m map[string]json.RawMessage
|
|
if err := json.Unmarshal(raw, &m); err != nil {
|
|
return nil, err
|
|
}
|
|
delete(m, "query")
|
|
return json.Marshal(m)
|
|
}
|
|
|
|
// --- resolved context (shared by Chat / StreamChat / TriggerSchedule) ---
|
|
|
|
type resolvedContext struct {
|
|
payload gatewayRequest
|
|
model models.GetResponse
|
|
provider sqlc.LlmProvider
|
|
}
|
|
|
|
func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (resolvedContext, error) {
|
|
if strings.TrimSpace(req.Query) == "" && len(req.Attachments) == 0 {
|
|
return resolvedContext{}, fmt.Errorf("query or attachments is required")
|
|
}
|
|
if strings.TrimSpace(req.BotID) == "" {
|
|
return resolvedContext{}, fmt.Errorf("bot id is required")
|
|
}
|
|
if strings.TrimSpace(req.ChatID) == "" {
|
|
return resolvedContext{}, fmt.Errorf("chat id is required")
|
|
}
|
|
|
|
skipHistory := req.MaxContextLoadTime < 0
|
|
|
|
botSettings, err := r.loadBotSettings(ctx, req.BotID)
|
|
if err != nil {
|
|
return resolvedContext{}, err
|
|
}
|
|
|
|
// Check chat-level model override.
|
|
var chatSettings conversation.Settings
|
|
if r.conversationSvc != nil {
|
|
chatSettings, err = r.conversationSvc.GetSettings(ctx, req.ChatID)
|
|
if err != nil {
|
|
return resolvedContext{}, err
|
|
}
|
|
}
|
|
|
|
chatModel, provider, err := r.selectChatModel(ctx, req, botSettings, chatSettings)
|
|
if err != nil {
|
|
return resolvedContext{}, err
|
|
}
|
|
clientType := string(chatModel.ClientType)
|
|
|
|
maxCtx := coalescePositiveInt(req.MaxContextLoadTime, botSettings.MaxContextLoadTime, defaultMaxContextMinutes)
|
|
maxTokens := botSettings.MaxContextTokens
|
|
|
|
// Build non-history parts first so we can reserve their token cost before
|
|
// trimming history messages.
|
|
memoryMsg := r.loadMemoryContextMessage(ctx, req)
|
|
var overhead int
|
|
if memoryMsg != nil {
|
|
overhead += estimateMessageTokens(*memoryMsg)
|
|
}
|
|
for _, m := range req.Messages {
|
|
overhead += estimateMessageTokens(m)
|
|
}
|
|
// Reserve space for the system prompt built by the agent gateway
|
|
// (IDENTITY.md, SOUL.md, TOOLS.md, skills, boilerplate, user prompt, etc.).
|
|
const systemPromptReserve = 4096
|
|
overhead += systemPromptReserve
|
|
|
|
historyBudget := maxTokens - overhead
|
|
if historyBudget < 0 {
|
|
historyBudget = 0
|
|
}
|
|
|
|
var messages []conversation.ModelMessage
|
|
if !skipHistory && r.conversationSvc != nil {
|
|
loaded, loadErr := r.loadMessages(ctx, req.ChatID, maxCtx)
|
|
if loadErr != nil {
|
|
return resolvedContext{}, loadErr
|
|
}
|
|
messages = trimMessagesByTokens(loaded, historyBudget)
|
|
}
|
|
if memoryMsg != nil {
|
|
messages = append(messages, *memoryMsg)
|
|
}
|
|
messages = append(messages, req.Messages...)
|
|
messages = sanitizeMessages(messages)
|
|
skills := dedup(req.Skills)
|
|
containerID := r.resolveContainerID(ctx, req.BotID, req.ContainerID)
|
|
|
|
var usableSkills []gatewaySkill
|
|
if r.skillLoader != nil {
|
|
entries, err := r.skillLoader.LoadSkills(ctx, req.BotID)
|
|
if err != nil {
|
|
r.logger.Warn("failed to load usable skills", slog.String("bot_id", req.BotID), slog.Any("error", err))
|
|
} else {
|
|
usableSkills = make([]gatewaySkill, 0, len(entries))
|
|
for _, e := range entries {
|
|
skill, ok := normalizeGatewaySkill(e)
|
|
if !ok {
|
|
continue
|
|
}
|
|
usableSkills = append(usableSkills, skill)
|
|
}
|
|
}
|
|
}
|
|
if usableSkills == nil {
|
|
usableSkills = []gatewaySkill{}
|
|
}
|
|
|
|
payload := gatewayRequest{
|
|
Model: gatewayModelConfig{
|
|
ModelID: chatModel.ModelID,
|
|
ClientType: clientType,
|
|
Input: chatModel.InputModalities,
|
|
APIKey: provider.ApiKey,
|
|
BaseURL: provider.BaseUrl,
|
|
},
|
|
ActiveContextTime: maxCtx,
|
|
Channels: nonNilStrings(req.Channels),
|
|
CurrentChannel: req.CurrentChannel,
|
|
AllowedActions: req.AllowedActions,
|
|
Messages: nonNilModelMessages(messages),
|
|
Skills: nonNilStrings(skills),
|
|
UsableSkills: usableSkills,
|
|
Query: req.Query,
|
|
Identity: gatewayIdentity{
|
|
BotID: req.BotID,
|
|
ContainerID: containerID,
|
|
ChannelIdentityID: strings.TrimSpace(req.SourceChannelIdentityID),
|
|
DisplayName: r.resolveDisplayName(ctx, req),
|
|
CurrentPlatform: req.CurrentChannel,
|
|
ConversationType: strings.TrimSpace(req.ConversationType),
|
|
SessionToken: req.ChatToken,
|
|
},
|
|
Attachments: r.routeAndMergeAttachments(ctx, chatModel, req),
|
|
}
|
|
|
|
return resolvedContext{payload: payload, model: chatModel, provider: provider}, nil
|
|
}
|
|
|
|
// --- Chat ---
|
|
|
|
// Chat sends a synchronous chat request to the agent gateway and stores the result.
|
|
func (r *Resolver) Chat(ctx context.Context, req conversation.ChatRequest) (conversation.ChatResponse, error) {
|
|
rc, err := r.resolve(ctx, req)
|
|
if err != nil {
|
|
return conversation.ChatResponse{}, err
|
|
}
|
|
resp, err := r.postChat(ctx, rc.payload, req.Token)
|
|
if err != nil {
|
|
return conversation.ChatResponse{}, err
|
|
}
|
|
if err := r.storeRound(ctx, req, resp.Messages, resp.Usage); err != nil {
|
|
return conversation.ChatResponse{}, err
|
|
}
|
|
return conversation.ChatResponse{
|
|
Messages: resp.Messages,
|
|
Skills: resp.Skills,
|
|
Model: rc.model.ModelID,
|
|
Provider: string(rc.model.ClientType),
|
|
}, nil
|
|
}
|
|
|
|
// --- TriggerSchedule ---
|
|
|
|
// TriggerSchedule executes a scheduled command through the agent gateway trigger-schedule endpoint.
|
|
func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload schedule.TriggerPayload, token string) error {
|
|
if strings.TrimSpace(botID) == "" {
|
|
return fmt.Errorf("bot id is required")
|
|
}
|
|
if strings.TrimSpace(payload.Command) == "" {
|
|
return fmt.Errorf("schedule command is required")
|
|
}
|
|
|
|
req := conversation.ChatRequest{
|
|
BotID: botID,
|
|
ChatID: botID,
|
|
Query: payload.Command,
|
|
UserID: payload.OwnerUserID,
|
|
Token: token,
|
|
}
|
|
rc, err := r.resolve(ctx, req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
schedulePayload := rc.payload
|
|
schedulePayload.Identity.ChannelIdentityID = strings.TrimSpace(payload.OwnerUserID)
|
|
schedulePayload.Identity.DisplayName = "Scheduler"
|
|
|
|
triggerReq := triggerScheduleRequest{
|
|
gatewayRequest: schedulePayload,
|
|
Schedule: gatewaySchedule{
|
|
ID: payload.ID,
|
|
Name: payload.Name,
|
|
Description: payload.Description,
|
|
Pattern: payload.Pattern,
|
|
MaxCalls: payload.MaxCalls,
|
|
Command: payload.Command,
|
|
},
|
|
}
|
|
|
|
resp, err := r.postTriggerSchedule(ctx, triggerReq, token)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return r.storeRound(ctx, req, resp.Messages, resp.Usage)
|
|
}
|
|
|
|
// --- StreamChat ---
|
|
|
|
// StreamChat sends a streaming chat request to the agent gateway.
|
|
func (r *Resolver) StreamChat(ctx context.Context, req conversation.ChatRequest) (<-chan conversation.StreamChunk, <-chan error) {
|
|
chunkCh := make(chan conversation.StreamChunk)
|
|
errCh := make(chan error, 1)
|
|
r.logger.Info("gateway stream start",
|
|
slog.String("bot_id", req.BotID),
|
|
slog.String("chat_id", req.ChatID),
|
|
)
|
|
|
|
go func() {
|
|
defer close(chunkCh)
|
|
defer close(errCh)
|
|
|
|
streamReq := req
|
|
rc, err := r.resolve(ctx, streamReq)
|
|
if err != nil {
|
|
r.logger.Error("gateway stream resolve failed",
|
|
slog.String("bot_id", streamReq.BotID),
|
|
slog.String("chat_id", streamReq.ChatID),
|
|
slog.Any("error", err),
|
|
)
|
|
errCh <- err
|
|
return
|
|
}
|
|
if !streamReq.UserMessagePersisted {
|
|
if err := r.persistUserMessage(ctx, streamReq); err != nil {
|
|
r.logger.Error("gateway stream persist user message failed",
|
|
slog.String("bot_id", streamReq.BotID),
|
|
slog.String("chat_id", streamReq.ChatID),
|
|
slog.Any("error", err),
|
|
)
|
|
errCh <- err
|
|
return
|
|
}
|
|
streamReq.UserMessagePersisted = true
|
|
}
|
|
if err := r.streamChat(ctx, rc.payload, streamReq, chunkCh); err != nil {
|
|
r.logger.Error("gateway stream request failed",
|
|
slog.String("bot_id", streamReq.BotID),
|
|
slog.String("chat_id", streamReq.ChatID),
|
|
slog.Any("error", err),
|
|
)
|
|
errCh <- err
|
|
}
|
|
}()
|
|
return chunkCh, errCh
|
|
}
|
|
|
|
// --- HTTP helpers ---
|
|
|
|
func (r *Resolver) postChat(ctx context.Context, payload gatewayRequest, token string) (gatewayResponse, error) {
|
|
url := r.gatewayBaseURL + "/chat/"
|
|
r.logger.Info(
|
|
"gateway request",
|
|
slog.String("url", url),
|
|
slog.Int("messages", len(payload.Messages)),
|
|
slog.Int("attachments", len(payload.Attachments)),
|
|
)
|
|
|
|
httpReq, err := newJSONRequestWithContext(ctx, http.MethodPost, url, payload)
|
|
if err != nil {
|
|
return gatewayResponse{}, err
|
|
}
|
|
if strings.TrimSpace(token) != "" {
|
|
httpReq.Header.Set("Authorization", token)
|
|
}
|
|
|
|
resp, err := r.httpClient.Do(httpReq)
|
|
if err != nil {
|
|
return gatewayResponse{}, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
respBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return gatewayResponse{}, err
|
|
}
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
r.logger.Error("gateway error", slog.String("url", url), slog.Int("status", resp.StatusCode), slog.String("body_prefix", truncate(string(respBody), 300)))
|
|
return gatewayResponse{}, fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(respBody)))
|
|
}
|
|
|
|
var parsed gatewayResponse
|
|
if err := json.Unmarshal(respBody, &parsed); err != nil {
|
|
r.logger.Error("gateway response parse failed", slog.String("body_prefix", truncate(string(respBody), 300)), slog.Any("error", err))
|
|
return gatewayResponse{}, fmt.Errorf("failed to parse gateway response: %w", err)
|
|
}
|
|
return parsed, nil
|
|
}
|
|
|
|
// postTriggerSchedule sends a trigger-schedule request to the agent gateway.
|
|
func (r *Resolver) postTriggerSchedule(ctx context.Context, payload triggerScheduleRequest, token string) (gatewayResponse, error) {
|
|
url := r.gatewayBaseURL + "/chat/trigger-schedule"
|
|
r.logger.Info("gateway trigger-schedule request", slog.String("url", url), slog.String("schedule_id", payload.Schedule.ID))
|
|
|
|
httpReq, err := newJSONRequestWithContext(ctx, http.MethodPost, url, payload)
|
|
if err != nil {
|
|
return gatewayResponse{}, err
|
|
}
|
|
if strings.TrimSpace(token) != "" {
|
|
httpReq.Header.Set("Authorization", token)
|
|
}
|
|
|
|
resp, err := r.httpClient.Do(httpReq)
|
|
if err != nil {
|
|
return gatewayResponse{}, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
respBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return gatewayResponse{}, err
|
|
}
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
r.logger.Error("gateway trigger-schedule error", slog.String("url", url), slog.Int("status", resp.StatusCode), slog.String("body_prefix", truncate(string(respBody), 300)))
|
|
return gatewayResponse{}, fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(respBody)))
|
|
}
|
|
|
|
var parsed gatewayResponse
|
|
if err := json.Unmarshal(respBody, &parsed); err != nil {
|
|
r.logger.Error("gateway trigger-schedule response parse failed", slog.String("body_prefix", truncate(string(respBody), 300)), slog.Any("error", err))
|
|
return gatewayResponse{}, fmt.Errorf("failed to parse gateway response: %w", err)
|
|
}
|
|
return parsed, nil
|
|
}
|
|
|
|
func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, req conversation.ChatRequest, chunkCh chan<- conversation.StreamChunk) error {
|
|
url := r.gatewayBaseURL + "/chat/stream"
|
|
r.logger.Info(
|
|
"gateway stream request",
|
|
slog.String("url", url),
|
|
slog.Int("messages", len(payload.Messages)),
|
|
slog.Int("attachments", len(payload.Attachments)),
|
|
)
|
|
httpReq, err := newJSONRequestWithContext(ctx, http.MethodPost, url, payload)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
httpReq.Header.Set("Accept", "text/event-stream")
|
|
if strings.TrimSpace(req.Token) != "" {
|
|
httpReq.Header.Set("Authorization", req.Token)
|
|
}
|
|
|
|
resp, err := r.streamingClient.Do(httpReq)
|
|
if err != nil {
|
|
r.logger.Error("gateway stream connect failed", slog.String("url", url), slog.Any("error", err))
|
|
return err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
errBody, _ := io.ReadAll(resp.Body)
|
|
r.logger.Error("gateway stream error", slog.String("url", url), slog.Int("status", resp.StatusCode), slog.String("body_prefix", truncate(string(errBody), 300)))
|
|
return fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(errBody)))
|
|
}
|
|
|
|
scanner := bufio.NewScanner(resp.Body)
|
|
scanner.Buffer(make([]byte, 0, 64*1024), 2*1024*1024)
|
|
|
|
currentEvent := ""
|
|
stored := false
|
|
for scanner.Scan() {
|
|
line := strings.TrimSpace(scanner.Text())
|
|
if line == "" {
|
|
continue
|
|
}
|
|
if strings.HasPrefix(line, "event:") {
|
|
currentEvent = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
|
|
continue
|
|
}
|
|
if !strings.HasPrefix(line, "data:") {
|
|
continue
|
|
}
|
|
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
|
if data == "" || data == "[DONE]" {
|
|
continue
|
|
}
|
|
chunkCh <- conversation.StreamChunk([]byte(data))
|
|
|
|
if stored {
|
|
continue
|
|
}
|
|
if handled, storeErr := r.tryStoreStream(ctx, req, currentEvent, data); storeErr != nil {
|
|
return storeErr
|
|
} else if handled {
|
|
stored = true
|
|
}
|
|
}
|
|
return scanner.Err()
|
|
}
|
|
|
|
// newJSONRequestWithContext builds an HTTP request whose body is payload
// encoded as JSON and whose Content-Type is application/json.
//
// The body is streamed through an in-memory pipe: a goroutine encodes
// payload while the HTTP transport reads, so the encoded form is never held
// in memory as a whole. An encoding failure therefore surfaces as a read
// error on the request body rather than from this function. The goroutine
// ends once the body has been fully read or closed.
func newJSONRequestWithContext(ctx context.Context, method, url string, payload any) (*http.Request, error) {
	bodyReader, bodyWriter := io.Pipe()

	httpReq, err := http.NewRequestWithContext(ctx, method, url, bodyReader)
	if err != nil {
		_ = bodyReader.Close()
		return nil, err
	}
	httpReq.Header.Set("Content-Type", "application/json")

	go func() {
		// CloseWithError(nil) reads as a clean EOF on the other end.
		encodeErr := json.NewEncoder(bodyWriter).Encode(payload)
		_ = bodyWriter.CloseWithError(encodeErr)
	}()

	return httpReq, nil
}
|
|
|
|
// tryStoreStream attempts to extract final messages from a stream event and persist them.
|
|
func (r *Resolver) tryStoreStream(ctx context.Context, req conversation.ChatRequest, eventType, data string) (bool, error) {
|
|
// event: done + data: {messages: [...]}
|
|
if eventType == "done" {
|
|
var resp gatewayResponse
|
|
if err := json.Unmarshal([]byte(data), &resp); err == nil && len(resp.Messages) > 0 {
|
|
return true, r.storeRound(ctx, req, resp.Messages, resp.Usage)
|
|
}
|
|
}
|
|
|
|
// data: {"type":"text_delta"|"agent_end"|"done", ...}
|
|
var envelope struct {
|
|
Type string `json:"type"`
|
|
Data json.RawMessage `json:"data"`
|
|
Messages []conversation.ModelMessage `json:"messages"`
|
|
Skills []string `json:"skills"`
|
|
Usage json.RawMessage `json:"usage,omitempty"`
|
|
}
|
|
if err := json.Unmarshal([]byte(data), &envelope); err == nil {
|
|
if (envelope.Type == "agent_end" || envelope.Type == "done") && len(envelope.Messages) > 0 {
|
|
return true, r.storeRound(ctx, req, envelope.Messages, envelope.Usage)
|
|
}
|
|
if envelope.Type == "done" && len(envelope.Data) > 0 {
|
|
var resp gatewayResponse
|
|
if err := json.Unmarshal(envelope.Data, &resp); err == nil && len(resp.Messages) > 0 {
|
|
return true, r.storeRound(ctx, req, resp.Messages, resp.Usage)
|
|
}
|
|
}
|
|
}
|
|
|
|
// fallback: data: {messages: [...]}
|
|
var resp gatewayResponse
|
|
if err := json.Unmarshal([]byte(data), &resp); err == nil && len(resp.Messages) > 0 {
|
|
return true, r.storeRound(ctx, req, resp.Messages, resp.Usage)
|
|
}
|
|
return false, nil
|
|
}
|
|
|
|
// routeAndMergeAttachments applies CapabilityFallbackPolicy to split
|
|
// request attachments by model input modalities, then merges the results
|
|
// into a single []any for the gateway request.
|
|
func (r *Resolver) routeAndMergeAttachments(ctx context.Context, model models.GetResponse, req conversation.ChatRequest) []any {
|
|
if len(req.Attachments) == 0 {
|
|
return []any{}
|
|
}
|
|
typed := r.prepareGatewayAttachments(ctx, req)
|
|
routed := routeAttachmentsByCapability(model.InputModalities, typed)
|
|
// Convert unsupported attachments to tool file references.
|
|
for i := range routed.Fallback {
|
|
fallbackPath := strings.TrimSpace(routed.Fallback[i].FallbackPath)
|
|
if fallbackPath == "" {
|
|
// Cannot downgrade non-file payloads to tool file references.
|
|
// Drop them explicitly to keep gateway contract deterministic.
|
|
if r != nil && r.logger != nil {
|
|
r.logger.Warn(
|
|
"drop attachment without fallback path",
|
|
slog.String("type", strings.TrimSpace(routed.Fallback[i].Type)),
|
|
slog.String("transport", strings.TrimSpace(routed.Fallback[i].Transport)),
|
|
slog.String("content_hash", strings.TrimSpace(routed.Fallback[i].ContentHash)),
|
|
slog.Bool("has_payload", strings.TrimSpace(routed.Fallback[i].Payload) != ""),
|
|
)
|
|
}
|
|
routed.Fallback[i] = gatewayAttachment{}
|
|
continue
|
|
}
|
|
routed.Fallback[i].Type = "file"
|
|
routed.Fallback[i].Transport = gatewayTransportToolFileRef
|
|
routed.Fallback[i].Payload = fallbackPath
|
|
}
|
|
merged := make([]any, 0, len(routed.Native)+len(routed.Fallback))
|
|
merged = append(merged, attachmentsToAny(routed.Native)...)
|
|
for _, fb := range routed.Fallback {
|
|
if fb.Type == "" || strings.TrimSpace(fb.Transport) == "" || strings.TrimSpace(fb.Payload) == "" {
|
|
continue
|
|
}
|
|
merged = append(merged, fb)
|
|
}
|
|
if len(merged) == 0 {
|
|
return []any{}
|
|
}
|
|
return merged
|
|
}
|
|
|
|
func (r *Resolver) prepareGatewayAttachments(ctx context.Context, req conversation.ChatRequest) []gatewayAttachment {
|
|
if len(req.Attachments) == 0 {
|
|
return nil
|
|
}
|
|
prepared := make([]gatewayAttachment, 0, len(req.Attachments))
|
|
for _, raw := range req.Attachments {
|
|
attachmentType := strings.ToLower(strings.TrimSpace(raw.Type))
|
|
payload := strings.TrimSpace(raw.Base64)
|
|
transport := ""
|
|
fallbackPath := strings.TrimSpace(raw.Path)
|
|
if payload != "" {
|
|
transport = gatewayTransportInlineDataURL
|
|
} else {
|
|
rawURL := strings.TrimSpace(raw.URL)
|
|
if isDataURL(rawURL) {
|
|
payload = rawURL
|
|
transport = gatewayTransportInlineDataURL
|
|
} else if isLikelyPublicURL(rawURL) {
|
|
payload = rawURL
|
|
transport = gatewayTransportPublicURL
|
|
} else if rawURL != "" && fallbackPath == "" {
|
|
fallbackPath = rawURL
|
|
}
|
|
}
|
|
item := gatewayAttachment{
|
|
ContentHash: strings.TrimSpace(raw.ContentHash),
|
|
Type: attachmentType,
|
|
Mime: strings.TrimSpace(raw.Mime),
|
|
Size: raw.Size,
|
|
Name: strings.TrimSpace(raw.Name),
|
|
Transport: transport,
|
|
Payload: payload,
|
|
Metadata: raw.Metadata,
|
|
FallbackPath: fallbackPath,
|
|
}
|
|
item = normalizeGatewayAttachmentPayload(item)
|
|
item = r.inlineImageAttachmentAssetIfNeeded(ctx, strings.TrimSpace(req.BotID), item)
|
|
prepared = append(prepared, item)
|
|
}
|
|
return prepared
|
|
}
|
|
|
|
func normalizeGatewayAttachmentPayload(item gatewayAttachment) gatewayAttachment {
|
|
if item.Transport != gatewayTransportInlineDataURL {
|
|
return item
|
|
}
|
|
payload := strings.TrimSpace(item.Payload)
|
|
if payload == "" {
|
|
return item
|
|
}
|
|
lower := strings.ToLower(payload)
|
|
if strings.HasPrefix(lower, "data:") {
|
|
if strings.TrimSpace(item.Mime) == "" || strings.EqualFold(strings.TrimSpace(item.Mime), "application/octet-stream") {
|
|
if start := strings.Index(payload, ":"); start >= 0 {
|
|
rest := payload[start+1:]
|
|
if end := strings.Index(rest, ";"); end > 0 {
|
|
mime := strings.TrimSpace(rest[:end])
|
|
if mime != "" {
|
|
item.Mime = mime
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return item
|
|
}
|
|
mime := strings.TrimSpace(item.Mime)
|
|
if mime == "" {
|
|
mime = "application/octet-stream"
|
|
}
|
|
item.Payload = attachmentpkg.NormalizeBase64DataURL(payload, mime)
|
|
return item
|
|
}
|
|
|
|
// isLikelyPublicURL reports whether raw, ignoring surrounding whitespace and
// letter case, starts with an http:// or https:// scheme. It does not
// validate the rest of the URL.
func isLikelyPublicURL(raw string) bool {
	candidate := strings.ToLower(strings.TrimSpace(raw))
	for _, scheme := range [...]string{"http://", "https://"} {
		if strings.HasPrefix(candidate, scheme) {
			return true
		}
	}
	return false
}
|
|
|
|
// isDataURL reports whether raw, ignoring surrounding whitespace and letter
// case, starts with the "data:" scheme.
func isDataURL(raw string) bool {
	const scheme = "data:"
	normalized := strings.TrimSpace(raw)
	normalized = strings.ToLower(normalized)
	return strings.HasPrefix(normalized, scheme)
}
|
|
|
|
func (r *Resolver) inlineImageAttachmentAssetIfNeeded(ctx context.Context, botID string, item gatewayAttachment) gatewayAttachment {
|
|
if item.Type != "image" {
|
|
return item
|
|
}
|
|
if strings.TrimSpace(item.Payload) != "" &&
|
|
(item.Transport == gatewayTransportInlineDataURL || item.Transport == gatewayTransportPublicURL) {
|
|
return item
|
|
}
|
|
contentHash := strings.TrimSpace(item.ContentHash)
|
|
if contentHash == "" {
|
|
return item
|
|
}
|
|
dataURL, mime, err := r.inlineAssetAsDataURL(ctx, botID, contentHash, item.Type, item.Mime)
|
|
if err != nil {
|
|
if r != nil && r.logger != nil {
|
|
r.logger.Warn(
|
|
"inline gateway image attachment failed",
|
|
slog.Any("error", err),
|
|
slog.String("bot_id", botID),
|
|
slog.String("content_hash", contentHash),
|
|
)
|
|
}
|
|
return item
|
|
}
|
|
item.Transport = gatewayTransportInlineDataURL
|
|
item.Payload = dataURL
|
|
if strings.TrimSpace(item.Mime) == "" {
|
|
item.Mime = mime
|
|
}
|
|
return item
|
|
}
|
|
|
|
func (r *Resolver) inlineAssetAsDataURL(ctx context.Context, botID, contentHash, attachmentType, fallbackMime string) (string, string, error) {
|
|
if r == nil || r.assetLoader == nil {
|
|
return "", "", fmt.Errorf("gateway asset loader not configured")
|
|
}
|
|
reader, assetMime, err := r.assetLoader.OpenForGateway(ctx, botID, contentHash)
|
|
if err != nil {
|
|
return "", "", fmt.Errorf("open asset: %w", err)
|
|
}
|
|
defer func() {
|
|
_ = reader.Close()
|
|
}()
|
|
mime := strings.TrimSpace(fallbackMime)
|
|
if mime == "" {
|
|
mime = strings.TrimSpace(assetMime)
|
|
}
|
|
dataURL, resolvedMime, err := encodeReaderAsDataURL(reader, gatewayInlineAttachmentMaxBytes, attachmentType, mime)
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
return dataURL, resolvedMime, nil
|
|
}
|
|
|
|
// encodeReaderAsDataURL reads reader to the end and returns its content as a
// "data:<mime>;base64,<data>" URL together with the MIME type used.
//
// maxBytes is the largest raw (pre-encoding) size accepted; larger inputs
// yield an error. At most maxBytes+1 bytes are ever read from reader.
//
// The MIME type is fallbackMime, except that for attachmentType "image" an
// empty or "application/octet-stream" fallback is replaced by the type
// sniffed from the first bytes of the content when that type is an image.
// If no MIME type can be determined, "application/octet-stream" is used.
//
// The sniffing prefix is read with io.ReadFull so that readers which deliver
// data in small pieces still provide up to 512 bytes for detection.
func encodeReaderAsDataURL(reader io.Reader, maxBytes int64, attachmentType, fallbackMime string) (string, string, error) {
	if reader == nil {
		return "", "", fmt.Errorf("reader is required")
	}
	if maxBytes <= 0 {
		return "", "", fmt.Errorf("max bytes must be greater than 0")
	}

	// Read one byte past the limit so that oversize input can be detected.
	limit := maxBytes + 1
	if limit <= 0 {
		// maxBytes+1 overflowed; the input cannot exceed maxBytes anyway.
		limit = maxBytes
	}
	limited := &io.LimitedReader{R: reader, N: limit}

	// http.DetectContentType considers at most the first 512 bytes.
	head := make([]byte, 512)
	n, err := io.ReadFull(limited, head)
	if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
		return "", "", fmt.Errorf("read asset: %w", err)
	}
	head = head[:n]

	mime := strings.TrimSpace(fallbackMime)
	if strings.EqualFold(strings.TrimSpace(attachmentType), "image") &&
		(mime == "" || strings.EqualFold(mime, "application/octet-stream")) {
		detected := strings.TrimSpace(http.DetectContentType(head))
		if strings.HasPrefix(strings.ToLower(detected), "image/") {
			mime = detected
		}
	}
	if mime == "" {
		mime = "application/octet-stream"
	}

	var encoded strings.Builder
	encoded.Grow(len("data:") + len(mime) + len(";base64,"))
	encoded.WriteString("data:")
	encoded.WriteString(mime)
	encoded.WriteString(";base64,")

	encoder := base64.NewEncoder(base64.StdEncoding, &encoded)
	if len(head) > 0 {
		if _, err := encoder.Write(head); err != nil {
			_ = encoder.Close()
			return "", "", fmt.Errorf("encode asset head: %w", err)
		}
	}
	copied, err := io.Copy(encoder, limited)
	if err != nil {
		_ = encoder.Close()
		return "", "", fmt.Errorf("encode asset body: %w", err)
	}
	// Close flushes any partially written base64 quantum.
	if err := encoder.Close(); err != nil {
		return "", "", fmt.Errorf("finalize asset encoding: %w", err)
	}

	total := int64(len(head)) + copied
	if total > maxBytes {
		return "", "", fmt.Errorf(
			"asset too large to inline: %d > %d",
			total,
			maxBytes,
		)
	}
	return encoded.String(), mime, nil
}
|
|
|
|
// --- container resolution ---
|
|
|
|
func (r *Resolver) resolveContainerID(ctx context.Context, botID, explicit string) string {
|
|
if strings.TrimSpace(explicit) != "" {
|
|
return explicit
|
|
}
|
|
if r.queries != nil {
|
|
pgBotID, err := parseResolverUUID(botID)
|
|
if err == nil {
|
|
row, err := r.queries.GetContainerByBotID(ctx, pgBotID)
|
|
if err == nil && strings.TrimSpace(row.ContainerID) != "" {
|
|
return row.ContainerID
|
|
}
|
|
}
|
|
}
|
|
r.logger.Warn("no container found for bot, using fallback", slog.String("bot_id", botID))
|
|
return "mcp-" + botID
|
|
}
|
|
|
|
// --- message loading ---
|
|
|
|
type messageWithUsage struct {
|
|
Message conversation.ModelMessage
|
|
UsageInputTokens *int
|
|
}
|
|
|
|
func (r *Resolver) loadMessages(ctx context.Context, chatID string, maxContextMinutes int) ([]messageWithUsage, error) {
|
|
if r.messageService == nil {
|
|
return nil, nil
|
|
}
|
|
since := time.Now().UTC().Add(-time.Duration(maxContextMinutes) * time.Minute)
|
|
msgs, err := r.messageService.ListSince(ctx, chatID, since)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var result []messageWithUsage
|
|
for _, m := range msgs {
|
|
var mm conversation.ModelMessage
|
|
if err := json.Unmarshal(m.Content, &mm); err != nil {
|
|
r.logger.Warn("loadMessages: content unmarshal failed, treating as raw text",
|
|
slog.String("chat_id", chatID), slog.Any("error", err))
|
|
mm = conversation.ModelMessage{Role: m.Role, Content: m.Content}
|
|
} else {
|
|
mm.Role = m.Role
|
|
}
|
|
var inputTokens *int
|
|
if len(m.Usage) > 0 {
|
|
var u gatewayUsage
|
|
if json.Unmarshal(m.Usage, &u) == nil {
|
|
inputTokens = u.InputTokens
|
|
}
|
|
}
|
|
result = append(result, messageWithUsage{Message: mm, UsageInputTokens: inputTokens})
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func estimateMessageTokens(msg conversation.ModelMessage) int {
|
|
text := msg.TextContent()
|
|
if len(text) == 0 {
|
|
data, _ := json.Marshal(msg.Content)
|
|
return len(data) / 4
|
|
}
|
|
return len(text) / 4
|
|
}
|
|
|
|
func trimMessagesByTokens(messages []messageWithUsage, maxTokens int) []conversation.ModelMessage {
|
|
if maxTokens <= 0 || len(messages) == 0 {
|
|
result := make([]conversation.ModelMessage, len(messages))
|
|
for i, m := range messages {
|
|
result[i] = m.Message
|
|
}
|
|
return result
|
|
}
|
|
|
|
// Scan backwards. When a message with UsageInputTokens is found, that value
|
|
// represents the cumulative input tokens for all messages up to and including
|
|
// that message. Messages after it are estimated with chars/4.
|
|
totalTokens := 0
|
|
anchorFound := false
|
|
cutoff := 0
|
|
|
|
tailEstimate := 0
|
|
for i := len(messages) - 1; i >= 0; i-- {
|
|
if !anchorFound && messages[i].UsageInputTokens != nil {
|
|
anchorFound = true
|
|
totalTokens = *messages[i].UsageInputTokens + tailEstimate
|
|
if totalTokens > maxTokens {
|
|
cutoff = i + 1
|
|
break
|
|
}
|
|
continue
|
|
}
|
|
est := estimateMessageTokens(messages[i].Message)
|
|
if anchorFound {
|
|
totalTokens += est
|
|
if totalTokens > maxTokens {
|
|
cutoff = i + 1
|
|
break
|
|
}
|
|
} else {
|
|
tailEstimate += est
|
|
}
|
|
}
|
|
|
|
if !anchorFound {
|
|
totalTokens = 0
|
|
for i := len(messages) - 1; i >= 0; i-- {
|
|
totalTokens += estimateMessageTokens(messages[i].Message)
|
|
if totalTokens > maxTokens {
|
|
cutoff = i + 1
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
// Keep provider-valid message order: a "tool" message must follow a preceding
|
|
// assistant tool call. When history is head-trimmed, a leading tool message
|
|
// may become orphaned and cause provider 400 errors.
|
|
for cutoff < len(messages) && strings.EqualFold(strings.TrimSpace(messages[cutoff].Message.Role), "tool") {
|
|
cutoff++
|
|
}
|
|
|
|
result := make([]conversation.ModelMessage, 0, len(messages)-cutoff)
|
|
for _, m := range messages[cutoff:] {
|
|
result = append(result, m.Message)
|
|
}
|
|
return result
|
|
}
|
|
|
|
type memoryContextItem struct {
|
|
Namespace string
|
|
Item memory.MemoryItem
|
|
}
|
|
|
|
func (r *Resolver) loadMemoryContextMessage(ctx context.Context, req conversation.ChatRequest) *conversation.ModelMessage {
|
|
if r.memoryService == nil {
|
|
return nil
|
|
}
|
|
if strings.TrimSpace(req.Query) == "" || strings.TrimSpace(req.BotID) == "" || strings.TrimSpace(req.ChatID) == "" {
|
|
return nil
|
|
}
|
|
|
|
results := make([]memoryContextItem, 0, memoryContextLimitPerScope)
|
|
seen := map[string]struct{}{}
|
|
resp, err := r.memoryService.Search(ctx, memory.SearchRequest{
|
|
Query: req.Query,
|
|
BotID: req.BotID,
|
|
Limit: memoryContextLimitPerScope,
|
|
Filters: map[string]any{
|
|
"namespace": sharedMemoryNamespace,
|
|
"scopeId": req.BotID,
|
|
"bot_id": req.BotID,
|
|
},
|
|
NoStats: true,
|
|
})
|
|
if err != nil {
|
|
r.logger.Warn("memory search for context failed",
|
|
slog.String("namespace", sharedMemoryNamespace),
|
|
slog.Any("error", err),
|
|
)
|
|
return nil
|
|
}
|
|
for _, item := range resp.Results {
|
|
key := strings.TrimSpace(item.ID)
|
|
if key == "" {
|
|
key = sharedMemoryNamespace + ":" + strings.TrimSpace(item.Memory)
|
|
}
|
|
if key == "" {
|
|
continue
|
|
}
|
|
if _, ok := seen[key]; ok {
|
|
continue
|
|
}
|
|
seen[key] = struct{}{}
|
|
results = append(results, memoryContextItem{Namespace: sharedMemoryNamespace, Item: item})
|
|
}
|
|
if len(results) == 0 {
|
|
return nil
|
|
}
|
|
|
|
sort.Slice(results, func(i, j int) bool {
|
|
return results[i].Item.Score > results[j].Item.Score
|
|
})
|
|
if len(results) > memoryContextMaxItems {
|
|
results = results[:memoryContextMaxItems]
|
|
}
|
|
|
|
var sb strings.Builder
|
|
sb.WriteString("Relevant memory context (use when helpful):\n")
|
|
for _, entry := range results {
|
|
text := strings.TrimSpace(entry.Item.Memory)
|
|
if text == "" {
|
|
continue
|
|
}
|
|
sb.WriteString("- [")
|
|
sb.WriteString(entry.Namespace)
|
|
sb.WriteString("] ")
|
|
sb.WriteString(truncateMemorySnippet(text, memoryContextItemMaxChars))
|
|
sb.WriteString("\n")
|
|
}
|
|
payload := strings.TrimSpace(sb.String())
|
|
if payload == "" {
|
|
return nil
|
|
}
|
|
msg := conversation.ModelMessage{
|
|
Role: "user",
|
|
Content: conversation.NewTextContent(payload),
|
|
}
|
|
return &msg
|
|
}
|
|
|
|
// --- store helpers ---
|
|
|
|
func (r *Resolver) persistUserMessage(ctx context.Context, req conversation.ChatRequest) error {
|
|
if r.messageService == nil {
|
|
return nil
|
|
}
|
|
if strings.TrimSpace(req.BotID) == "" {
|
|
return fmt.Errorf("bot id is required for persistence")
|
|
}
|
|
text := strings.TrimSpace(req.Query)
|
|
if text == "" && len(req.Attachments) == 0 {
|
|
return nil
|
|
}
|
|
|
|
message := conversation.ModelMessage{
|
|
Role: "user",
|
|
Content: conversation.NewTextContent(text),
|
|
}
|
|
content, err := json.Marshal(message)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
senderChannelIdentityID, senderUserID := r.resolvePersistSenderIDs(ctx, req)
|
|
_, err = r.messageService.Persist(ctx, messagepkg.PersistInput{
|
|
BotID: req.BotID,
|
|
RouteID: req.RouteID,
|
|
SenderChannelIdentityID: senderChannelIdentityID,
|
|
SenderUserID: senderUserID,
|
|
Platform: req.CurrentChannel,
|
|
ExternalMessageID: req.ExternalMessageID,
|
|
Role: "user",
|
|
Content: content,
|
|
Metadata: buildRouteMetadata(req),
|
|
Assets: chatAttachmentsToAssetRefs(req.Attachments),
|
|
})
|
|
return err
|
|
}
|
|
|
|
func (r *Resolver) storeRound(ctx context.Context, req conversation.ChatRequest, messages []conversation.ModelMessage, usage json.RawMessage) error {
|
|
fullRound := make([]conversation.ModelMessage, 0, len(messages))
|
|
for _, m := range messages {
|
|
if req.UserMessagePersisted && m.Role == "user" && strings.TrimSpace(m.TextContent()) == strings.TrimSpace(req.Query) {
|
|
continue
|
|
}
|
|
fullRound = append(fullRound, m)
|
|
}
|
|
if len(fullRound) == 0 {
|
|
return nil
|
|
}
|
|
|
|
r.storeMessages(ctx, req, fullRound, usage)
|
|
go r.storeMemory(context.WithoutCancel(ctx), req.BotID, fullRound)
|
|
return nil
|
|
}
|
|
|
|
func (r *Resolver) storeMessages(ctx context.Context, req conversation.ChatRequest, messages []conversation.ModelMessage, usage json.RawMessage) {
|
|
if r.messageService == nil {
|
|
return
|
|
}
|
|
if strings.TrimSpace(req.BotID) == "" {
|
|
return
|
|
}
|
|
meta := buildRouteMetadata(req)
|
|
senderChannelIdentityID, senderUserID := r.resolvePersistSenderIDs(ctx, req)
|
|
|
|
// Determine the last assistant message index for outbound asset attachment.
|
|
lastAssistantIdx := -1
|
|
if req.OutboundAssetCollector != nil {
|
|
for i := len(messages) - 1; i >= 0; i-- {
|
|
if messages[i].Role == "assistant" {
|
|
lastAssistantIdx = i
|
|
break
|
|
}
|
|
}
|
|
}
|
|
var outboundAssets []messagepkg.AssetRef
|
|
if lastAssistantIdx >= 0 {
|
|
outboundAssets = outboundAssetRefsToMessageRefs(req.OutboundAssetCollector())
|
|
}
|
|
|
|
for i, msg := range messages {
|
|
content, err := json.Marshal(msg)
|
|
if err != nil {
|
|
r.logger.Warn("storeMessages: marshal failed", slog.Any("error", err))
|
|
continue
|
|
}
|
|
messageSenderChannelIdentityID := ""
|
|
messageSenderUserID := ""
|
|
externalMessageID := ""
|
|
sourceReplyToMessageID := ""
|
|
assets := []messagepkg.AssetRef(nil)
|
|
if msg.Role == "user" {
|
|
messageSenderChannelIdentityID = senderChannelIdentityID
|
|
messageSenderUserID = senderUserID
|
|
externalMessageID = req.ExternalMessageID
|
|
if strings.TrimSpace(msg.TextContent()) == strings.TrimSpace(req.Query) {
|
|
assets = chatAttachmentsToAssetRefs(req.Attachments)
|
|
}
|
|
} else if strings.TrimSpace(req.ExternalMessageID) != "" {
|
|
sourceReplyToMessageID = req.ExternalMessageID
|
|
}
|
|
if i == lastAssistantIdx && len(outboundAssets) > 0 {
|
|
assets = append(assets, outboundAssets...)
|
|
}
|
|
var msgUsage json.RawMessage
|
|
if i == len(messages)-1 && len(usage) > 0 {
|
|
msgUsage = usage
|
|
}
|
|
if _, err := r.messageService.Persist(ctx, messagepkg.PersistInput{
|
|
BotID: req.BotID,
|
|
RouteID: req.RouteID,
|
|
SenderChannelIdentityID: messageSenderChannelIdentityID,
|
|
SenderUserID: messageSenderUserID,
|
|
Platform: req.CurrentChannel,
|
|
ExternalMessageID: externalMessageID,
|
|
SourceReplyToMessageID: sourceReplyToMessageID,
|
|
Role: msg.Role,
|
|
Content: content,
|
|
Metadata: meta,
|
|
Usage: msgUsage,
|
|
Assets: assets,
|
|
}); err != nil {
|
|
r.logger.Warn("persist message failed", slog.Any("error", err))
|
|
}
|
|
}
|
|
}
|
|
|
|
// outboundAssetRefsToMessageRefs converts outbound asset refs from the streaming
|
|
// collector into message-level asset refs for persistence.
|
|
func outboundAssetRefsToMessageRefs(refs []conversation.OutboundAssetRef) []messagepkg.AssetRef {
|
|
if len(refs) == 0 {
|
|
return nil
|
|
}
|
|
result := make([]messagepkg.AssetRef, 0, len(refs))
|
|
for _, ref := range refs {
|
|
contentHash := strings.TrimSpace(ref.ContentHash)
|
|
if contentHash == "" {
|
|
continue
|
|
}
|
|
role := ref.Role
|
|
if strings.TrimSpace(role) == "" {
|
|
role = "attachment"
|
|
}
|
|
result = append(result, messagepkg.AssetRef{
|
|
ContentHash: contentHash,
|
|
Role: role,
|
|
Ordinal: ref.Ordinal,
|
|
Mime: ref.Mime,
|
|
SizeBytes: ref.SizeBytes,
|
|
StorageKey: ref.StorageKey,
|
|
})
|
|
}
|
|
return result
|
|
}
|
|
|
|
// chatAttachmentsToAssetRefs converts ChatAttachment slice to message AssetRef slice.
|
|
// Only attachments that carry a content_hash are included.
|
|
func chatAttachmentsToAssetRefs(attachments []conversation.ChatAttachment) []messagepkg.AssetRef {
|
|
if len(attachments) == 0 {
|
|
return nil
|
|
}
|
|
refs := make([]messagepkg.AssetRef, 0, len(attachments))
|
|
for i, att := range attachments {
|
|
contentHash := strings.TrimSpace(att.ContentHash)
|
|
if contentHash == "" {
|
|
continue
|
|
}
|
|
ref := messagepkg.AssetRef{
|
|
ContentHash: contentHash,
|
|
Role: "attachment",
|
|
Ordinal: i,
|
|
Mime: strings.TrimSpace(att.Mime),
|
|
SizeBytes: att.Size,
|
|
}
|
|
if att.Metadata != nil {
|
|
if sk, ok := att.Metadata["storage_key"].(string); ok {
|
|
ref.StorageKey = sk
|
|
}
|
|
}
|
|
refs = append(refs, ref)
|
|
}
|
|
return refs
|
|
}
|
|
|
|
func buildRouteMetadata(req conversation.ChatRequest) map[string]any {
|
|
if strings.TrimSpace(req.RouteID) == "" && strings.TrimSpace(req.CurrentChannel) == "" {
|
|
return nil
|
|
}
|
|
meta := map[string]any{}
|
|
if strings.TrimSpace(req.RouteID) != "" {
|
|
meta["route_id"] = req.RouteID
|
|
}
|
|
if strings.TrimSpace(req.CurrentChannel) != "" {
|
|
meta["platform"] = req.CurrentChannel
|
|
}
|
|
return meta
|
|
}
|
|
|
|
func (r *Resolver) resolvePersistSenderIDs(ctx context.Context, req conversation.ChatRequest) (string, string) {
|
|
channelIdentityID := strings.TrimSpace(req.SourceChannelIdentityID)
|
|
userID := strings.TrimSpace(req.UserID)
|
|
|
|
senderChannelIdentityID := ""
|
|
if r.isExistingChannelIdentityID(ctx, channelIdentityID) {
|
|
senderChannelIdentityID = channelIdentityID
|
|
}
|
|
|
|
senderUserID := ""
|
|
if r.isExistingUserID(ctx, userID) {
|
|
senderUserID = userID
|
|
}
|
|
if senderUserID == "" && senderChannelIdentityID != "" {
|
|
if linked := r.linkedUserIDFromChannelIdentity(ctx, senderChannelIdentityID); linked != "" {
|
|
senderUserID = linked
|
|
}
|
|
}
|
|
return senderChannelIdentityID, senderUserID
|
|
}
|
|
|
|
func (r *Resolver) isExistingChannelIdentityID(ctx context.Context, id string) bool {
|
|
if r.queries == nil {
|
|
return false
|
|
}
|
|
pgID, err := parseResolverUUID(id)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
_, err = r.queries.GetChannelIdentityByID(ctx, pgID)
|
|
return err == nil
|
|
}
|
|
|
|
func (r *Resolver) isExistingUserID(ctx context.Context, id string) bool {
|
|
if r.queries == nil {
|
|
return false
|
|
}
|
|
pgID, err := parseResolverUUID(id)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
_, err = r.queries.GetUserByID(ctx, pgID)
|
|
return err == nil
|
|
}
|
|
|
|
func (r *Resolver) linkedUserIDFromChannelIdentity(ctx context.Context, channelIdentityID string) string {
|
|
if r.queries == nil {
|
|
return ""
|
|
}
|
|
pgID, err := parseResolverUUID(channelIdentityID)
|
|
if err != nil {
|
|
return ""
|
|
}
|
|
row, err := r.queries.GetChannelIdentityByID(ctx, pgID)
|
|
if err != nil || !row.UserID.Valid {
|
|
return ""
|
|
}
|
|
return row.UserID.String()
|
|
}
|
|
|
|
// resolveDisplayName returns the best available display name for the request identity:
|
|
// req.DisplayName if set, else channel identity's display_name, else linked user's display_name, else "User".
|
|
func (r *Resolver) resolveDisplayName(ctx context.Context, req conversation.ChatRequest) string {
|
|
if name := strings.TrimSpace(req.DisplayName); name != "" {
|
|
return name
|
|
}
|
|
if r.queries == nil {
|
|
return "User"
|
|
}
|
|
channelIdentityID := strings.TrimSpace(req.SourceChannelIdentityID)
|
|
if channelIdentityID == "" {
|
|
return "User"
|
|
}
|
|
pgID, err := parseResolverUUID(channelIdentityID)
|
|
if err != nil {
|
|
return "User"
|
|
}
|
|
ci, err := r.queries.GetChannelIdentityByID(ctx, pgID)
|
|
if err == nil && ci.DisplayName.Valid {
|
|
if name := strings.TrimSpace(ci.DisplayName.String); name != "" {
|
|
return name
|
|
}
|
|
}
|
|
linkedUserID := r.linkedUserIDFromChannelIdentity(ctx, channelIdentityID)
|
|
if linkedUserID == "" {
|
|
return "User"
|
|
}
|
|
userPgID, err := parseResolverUUID(linkedUserID)
|
|
if err != nil {
|
|
return "User"
|
|
}
|
|
u, err := r.queries.GetUserByID(ctx, userPgID)
|
|
if err != nil || !u.DisplayName.Valid {
|
|
return "User"
|
|
}
|
|
if name := strings.TrimSpace(u.DisplayName.String); name != "" {
|
|
return name
|
|
}
|
|
return "User"
|
|
}
|
|
|
|
func (r *Resolver) storeMemory(ctx context.Context, botID string, messages []conversation.ModelMessage) {
|
|
if r.memoryService == nil {
|
|
return
|
|
}
|
|
if strings.TrimSpace(botID) == "" {
|
|
return
|
|
}
|
|
memMsgs := make([]memory.Message, 0, len(messages))
|
|
for _, msg := range messages {
|
|
text := strings.TrimSpace(msg.TextContent())
|
|
if text == "" {
|
|
continue
|
|
}
|
|
role := msg.Role
|
|
if strings.TrimSpace(role) == "" {
|
|
role = "assistant"
|
|
}
|
|
memMsgs = append(memMsgs, memory.Message{Role: role, Content: text})
|
|
}
|
|
if len(memMsgs) == 0 {
|
|
return
|
|
}
|
|
r.addMemory(ctx, botID, memMsgs, sharedMemoryNamespace, botID)
|
|
}
|
|
|
|
func (r *Resolver) addMemory(ctx context.Context, botID string, msgs []memory.Message, namespace, scopeID string) {
|
|
filters := map[string]any{
|
|
"namespace": namespace,
|
|
"scopeId": scopeID,
|
|
"bot_id": botID,
|
|
}
|
|
if _, err := r.memoryService.Add(ctx, memory.AddRequest{
|
|
Messages: msgs,
|
|
BotID: botID,
|
|
Filters: filters,
|
|
}); err != nil {
|
|
r.logger.Warn("store memory failed",
|
|
slog.String("namespace", namespace),
|
|
slog.String("scope_id", scopeID),
|
|
slog.Any("error", err),
|
|
)
|
|
}
|
|
}
|
|
|
|
// --- model selection ---
|
|
|
|
func (r *Resolver) selectChatModel(ctx context.Context, req conversation.ChatRequest, botSettings settings.Settings, cs conversation.Settings) (models.GetResponse, sqlc.LlmProvider, error) {
|
|
if r.modelsService == nil {
|
|
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("models service not configured")
|
|
}
|
|
modelID := strings.TrimSpace(req.Model)
|
|
providerFilter := strings.TrimSpace(req.Provider)
|
|
|
|
// Priority: request model > chat settings > bot settings.
|
|
if modelID == "" && providerFilter == "" {
|
|
if value := strings.TrimSpace(cs.ModelID); value != "" {
|
|
modelID = value
|
|
} else if value := strings.TrimSpace(botSettings.ChatModelID); value != "" {
|
|
modelID = value
|
|
}
|
|
}
|
|
|
|
if modelID == "" {
|
|
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model not configured: specify model in request or bot settings")
|
|
}
|
|
|
|
if providerFilter == "" {
|
|
return r.fetchChatModel(ctx, modelID)
|
|
}
|
|
|
|
candidates, err := r.listCandidates(ctx, providerFilter)
|
|
if err != nil {
|
|
return models.GetResponse{}, sqlc.LlmProvider{}, err
|
|
}
|
|
for _, m := range candidates {
|
|
if m.ModelID == modelID {
|
|
prov, err := models.FetchProviderByID(ctx, r.queries, m.LlmProviderID)
|
|
if err != nil {
|
|
return models.GetResponse{}, sqlc.LlmProvider{}, err
|
|
}
|
|
return m, prov, nil
|
|
}
|
|
}
|
|
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model %q not found for provider %q", modelID, providerFilter)
|
|
}
|
|
|
|
func (r *Resolver) fetchChatModel(ctx context.Context, modelID string) (models.GetResponse, sqlc.LlmProvider, error) {
|
|
model, err := r.modelsService.GetByModelID(ctx, modelID)
|
|
if err != nil {
|
|
return models.GetResponse{}, sqlc.LlmProvider{}, err
|
|
}
|
|
if model.Type != models.ModelTypeChat {
|
|
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("model is not a chat model")
|
|
}
|
|
prov, err := models.FetchProviderByID(ctx, r.queries, model.LlmProviderID)
|
|
if err != nil {
|
|
return models.GetResponse{}, sqlc.LlmProvider{}, err
|
|
}
|
|
return model, prov, nil
|
|
}
|
|
|
|
func (r *Resolver) listCandidates(ctx context.Context, providerFilter string) ([]models.GetResponse, error) {
|
|
var all []models.GetResponse
|
|
var err error
|
|
if providerFilter != "" {
|
|
all, err = r.modelsService.ListByClientType(ctx, models.ClientType(providerFilter))
|
|
} else {
|
|
all, err = r.modelsService.ListByType(ctx, models.ModelTypeChat)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
filtered := make([]models.GetResponse, 0, len(all))
|
|
for _, m := range all {
|
|
if m.Type == models.ModelTypeChat {
|
|
filtered = append(filtered, m)
|
|
}
|
|
}
|
|
return filtered, nil
|
|
}
|
|
|
|
// --- settings ---
|
|
|
|
func (r *Resolver) loadBotSettings(ctx context.Context, botID string) (settings.Settings, error) {
|
|
if r.settingsService == nil {
|
|
return settings.Settings{}, fmt.Errorf("settings service not configured")
|
|
}
|
|
return r.settingsService.GetBot(ctx, botID)
|
|
}
|
|
|
|
// --- utility ---
|
|
|
|
// normalizeClientType lower-cases and trims clientType and returns it when it
// names a client type the agent gateway supports. Any other value yields an
// error that quotes the input exactly as it was given.
func normalizeClientType(clientType string) (string, error) {
	normalized := strings.ToLower(strings.TrimSpace(clientType))
	for _, supported := range [...]string{
		"openai-responses",
		"openai-completions",
		"anthropic-messages",
		"google-generative-ai",
	} {
		if normalized == supported {
			return normalized, nil
		}
	}
	return "", fmt.Errorf("unsupported agent gateway client type: %s", clientType)
}
|
|
|
|
func sanitizeMessages(messages []conversation.ModelMessage) []conversation.ModelMessage {
|
|
cleaned := make([]conversation.ModelMessage, 0, len(messages))
|
|
for _, msg := range messages {
|
|
if strings.TrimSpace(msg.Role) == "" {
|
|
continue
|
|
}
|
|
if !msg.HasContent() && strings.TrimSpace(msg.ToolCallID) == "" {
|
|
continue
|
|
}
|
|
cleaned = append(cleaned, msg)
|
|
}
|
|
return cleaned
|
|
}
|
|
|
|
func normalizeGatewaySkill(entry SkillEntry) (gatewaySkill, bool) {
|
|
name := strings.TrimSpace(entry.Name)
|
|
if name == "" {
|
|
return gatewaySkill{}, false
|
|
}
|
|
description := strings.TrimSpace(entry.Description)
|
|
if description == "" {
|
|
description = name
|
|
}
|
|
content := strings.TrimSpace(entry.Content)
|
|
if content == "" {
|
|
content = description
|
|
}
|
|
return gatewaySkill{
|
|
Name: name,
|
|
Description: description,
|
|
Content: content,
|
|
Metadata: entry.Metadata,
|
|
}, true
|
|
}
|
|
|
|
// dedup trims every item, drops blanks, and removes repeats while keeping the
// first occurrence of each value in its original position. The result is
// never nil.
func dedup(items []string) []string {
	unique := make([]string, 0, len(items))
	visited := make(map[string]struct{}, len(items))
	for idx := range items {
		candidate := strings.TrimSpace(items[idx])
		if candidate == "" {
			continue
		}
		if _, repeated := visited[candidate]; !repeated {
			visited[candidate] = struct{}{}
			unique = append(unique, candidate)
		}
	}
	return unique
}
|
|
|
|
// firstNonEmpty returns the first value that is not blank, exactly as given
// (it is not trimmed), or "" when every value is blank.
func firstNonEmpty(values ...string) string {
	for idx := 0; idx < len(values); idx++ {
		if candidate := values[idx]; strings.TrimSpace(candidate) != "" {
			return candidate
		}
	}
	return ""
}
|
|
|
|
func coalescePositiveInt(values ...int) int {
|
|
for _, v := range values {
|
|
if v > 0 {
|
|
return v
|
|
}
|
|
}
|
|
return defaultMaxContextMinutes
|
|
}
|
|
|
|
// nonNilStrings returns s unchanged, substituting an empty non-nil slice when
// s is nil.
func nonNilStrings(s []string) []string {
	if s != nil {
		return s
	}
	return make([]string, 0)
}
|
|
|
|
func nonNilModelMessages(m []conversation.ModelMessage) []conversation.ModelMessage {
|
|
if m == nil {
|
|
return []conversation.ModelMessage{}
|
|
}
|
|
return m
|
|
}
|
|
|
|
// truncate shortens s to at most n bytes and appends "..." when anything was
// cut off. Strings that already fit are returned unchanged.
//
// The cut never lands inside a multi-byte UTF-8 sequence: it is moved back to
// the start of the rune that straddles byte n, so the result stays valid
// UTF-8 whenever s is. A negative n is treated as zero.
func truncate(s string, n int) string {
	if len(s) <= n {
		return s
	}
	cut := n
	if cut < 0 {
		cut = 0
	}
	// Bytes of the form 10xxxxxx are UTF-8 continuation bytes; step back until
	// the cut sits on the first byte of a rune. cut < len(s) holds here.
	for cut > 0 && s[cut]&0xC0 == 0x80 {
		cut--
	}
	return s[:cut] + "..."
}
|
|
|
|
// truncateMemorySnippet trims s and, when the trimmed text is longer than n
// bytes, shortens it to at most n bytes, trims the cut text again, and
// appends "...". Text that fits is returned trimmed with no suffix.
//
// The cut never lands inside a multi-byte UTF-8 sequence: it is moved back to
// the start of the rune that straddles byte n, so the snippet stays valid
// UTF-8 whenever s is. A negative n is treated as zero.
func truncateMemorySnippet(s string, n int) string {
	trimmed := strings.TrimSpace(s)
	if len(trimmed) <= n {
		return trimmed
	}
	cut := n
	if cut < 0 {
		cut = 0
	}
	// Bytes of the form 10xxxxxx are UTF-8 continuation bytes; step back until
	// the cut sits on the first byte of a rune. cut < len(trimmed) holds here.
	for cut > 0 && trimmed[cut]&0xC0 == 0x80 {
		cut--
	}
	return strings.TrimSpace(trimmed[:cut]) + "..."
}
|
|
|
|
func parseResolverUUID(id string) (pgtype.UUID, error) {
|
|
if strings.TrimSpace(id) == "" {
|
|
return pgtype.UUID{}, fmt.Errorf("empty id")
|
|
}
|
|
return db.ParseUUID(id)
|
|
}
|