Files
Memoh/internal/conversation/flow/resolver.go
T
BBQ faaadf14c5 feat(channel): pass conversation type through to agent gateway and persist in route
Propagate conversation type (direct/group/thread) from channel adapters
all the way to the agent prompt. Store conversation_type on bot_channel_routes
so the bot knows whether a message originates from a p2p chat, group, or thread.

Schema changes are folded into the 0001 init migration (destructive update).
2026-02-13 06:22:24 +08:00

1170 lines
35 KiB
Go

package flow
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"sort"
"strings"
"time"
"github.com/jackc/pgx/v5/pgtype"
"github.com/memohai/memoh/internal/conversation"
"github.com/memohai/memoh/internal/db"
"github.com/memohai/memoh/internal/db/sqlc"
"github.com/memohai/memoh/internal/memory"
messagepkg "github.com/memohai/memoh/internal/message"
"github.com/memohai/memoh/internal/models"
"github.com/memohai/memoh/internal/schedule"
"github.com/memohai/memoh/internal/settings"
)
const (
defaultMaxContextMinutes = 24 * 60
memoryContextLimitPerScope = 4
memoryContextMaxItems = 8
memoryContextItemMaxChars = 220
sharedMemoryNamespace = "bot"
)
// SkillEntry represents a skill loaded from the container.
type SkillEntry struct {
Name string
Description string
Content string
Metadata map[string]any
}
// SkillLoader loads skills for a given bot from its container.
type SkillLoader interface {
LoadSkills(ctx context.Context, botID string) ([]SkillEntry, error)
}
// ConversationSettingsReader defines settings lookup behavior needed by flow resolution.
type ConversationSettingsReader interface {
GetSettings(ctx context.Context, conversationID string) (conversation.Settings, error)
}
// Resolver orchestrates chat with the agent gateway.
type Resolver struct {
modelsService *models.Service
queries *sqlc.Queries
memoryService *memory.Service
conversationSvc ConversationSettingsReader
messageService messagepkg.Service
settingsService *settings.Service
skillLoader SkillLoader
gatewayBaseURL string
timeout time.Duration
logger *slog.Logger
httpClient *http.Client
streamingClient *http.Client
}
// NewResolver creates a Resolver that communicates with the agent gateway.
func NewResolver(
log *slog.Logger,
modelsService *models.Service,
queries *sqlc.Queries,
memoryService *memory.Service,
conversationSvc ConversationSettingsReader,
messageService messagepkg.Service,
settingsService *settings.Service,
gatewayBaseURL string,
timeout time.Duration,
) *Resolver {
if strings.TrimSpace(gatewayBaseURL) == "" {
gatewayBaseURL = "http://127.0.0.1:8081"
}
gatewayBaseURL = strings.TrimRight(gatewayBaseURL, "/")
if timeout <= 0 {
timeout = 60 * time.Second
}
return &Resolver{
modelsService: modelsService,
queries: queries,
memoryService: memoryService,
conversationSvc: conversationSvc,
messageService: messageService,
settingsService: settingsService,
gatewayBaseURL: gatewayBaseURL,
timeout: timeout,
logger: log.With(slog.String("service", "conversation_resolver")),
httpClient: &http.Client{Timeout: timeout},
streamingClient: &http.Client{},
}
}
// SetSkillLoader sets the skill loader used to populate usable skills in gateway requests.
func (r *Resolver) SetSkillLoader(sl SkillLoader) {
r.skillLoader = sl
}
// --- gateway payload ---
type gatewayModelConfig struct {
ModelID string `json:"modelId"`
ClientType string `json:"clientType"`
Input []string `json:"input"`
APIKey string `json:"apiKey"`
BaseURL string `json:"baseUrl"`
}
type gatewayIdentity struct {
BotID string `json:"botId"`
ContainerID string `json:"containerId"`
ChannelIdentityID string `json:"channelIdentityId"`
DisplayName string `json:"displayName"`
CurrentPlatform string `json:"currentPlatform,omitempty"`
ConversationType string `json:"conversationType,omitempty"`
SessionToken string `json:"sessionToken,omitempty"`
}
type gatewaySkill struct {
Name string `json:"name"`
Description string `json:"description"`
Content string `json:"content"`
Metadata map[string]any `json:"metadata,omitempty"`
}
type gatewayRequest struct {
Model gatewayModelConfig `json:"model"`
ActiveContextTime int `json:"activeContextTime"`
Channels []string `json:"channels"`
CurrentChannel string `json:"currentChannel"`
AllowedActions []string `json:"allowedActions,omitempty"`
Messages []conversation.ModelMessage `json:"messages"`
Skills []string `json:"skills"`
UsableSkills []gatewaySkill `json:"usableSkills"`
Query string `json:"query,omitempty"`
Identity gatewayIdentity `json:"identity"`
Attachments []any `json:"attachments"`
}
type gatewayResponse struct {
Messages []conversation.ModelMessage `json:"messages"`
Skills []string `json:"skills"`
}
// gatewaySchedule matches the agent gateway ScheduleModel for /chat/trigger-schedule.
type gatewaySchedule struct {
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Pattern string `json:"pattern"`
MaxCalls *int `json:"maxCalls,omitempty"`
Command string `json:"command"`
}
// triggerScheduleRequest is the payload for POST /chat/trigger-schedule.
type triggerScheduleRequest struct {
gatewayRequest
Schedule gatewaySchedule `json:"schedule"`
}
// --- resolved context (shared by Chat / StreamChat / TriggerSchedule) ---
type resolvedContext struct {
payload gatewayRequest
model models.GetResponse
provider sqlc.LlmProvider
}
func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (resolvedContext, error) {
if strings.TrimSpace(req.Query) == "" {
return resolvedContext{}, fmt.Errorf("query is required")
}
if strings.TrimSpace(req.BotID) == "" {
return resolvedContext{}, fmt.Errorf("bot id is required")
}
if strings.TrimSpace(req.ChatID) == "" {
return resolvedContext{}, fmt.Errorf("chat id is required")
}
skipHistory := req.MaxContextLoadTime < 0
botSettings, err := r.loadBotSettings(ctx, req.BotID)
if err != nil {
return resolvedContext{}, err
}
// Check chat-level model override.
var chatSettings conversation.Settings
if r.conversationSvc != nil {
chatSettings, err = r.conversationSvc.GetSettings(ctx, req.ChatID)
if err != nil {
return resolvedContext{}, err
}
}
chatModel, provider, err := r.selectChatModel(ctx, req, botSettings, chatSettings)
if err != nil {
return resolvedContext{}, err
}
clientType, err := normalizeClientType(provider.ClientType)
if err != nil {
return resolvedContext{}, err
}
maxCtx := coalescePositiveInt(req.MaxContextLoadTime, botSettings.MaxContextLoadTime, defaultMaxContextMinutes)
var messages []conversation.ModelMessage
if !skipHistory && r.conversationSvc != nil {
messages, err = r.loadMessages(ctx, req.ChatID, maxCtx)
if err != nil {
return resolvedContext{}, err
}
}
if memoryMsg := r.loadMemoryContextMessage(ctx, req); memoryMsg != nil {
messages = append(messages, *memoryMsg)
}
messages = append(messages, req.Messages...)
messages = sanitizeMessages(messages)
skills := dedup(req.Skills)
containerID := r.resolveContainerID(ctx, req.BotID, req.ContainerID)
var usableSkills []gatewaySkill
if r.skillLoader != nil {
entries, err := r.skillLoader.LoadSkills(ctx, req.BotID)
if err != nil {
r.logger.Warn("failed to load usable skills", slog.String("bot_id", req.BotID), slog.Any("error", err))
} else {
usableSkills = make([]gatewaySkill, 0, len(entries))
for _, e := range entries {
usableSkills = append(usableSkills, gatewaySkill{
Name: e.Name,
Description: e.Description,
Content: e.Content,
Metadata: e.Metadata,
})
}
}
}
if usableSkills == nil {
usableSkills = []gatewaySkill{}
}
payload := gatewayRequest{
Model: gatewayModelConfig{
ModelID: chatModel.ModelID,
ClientType: clientType,
Input: chatModel.Input,
APIKey: provider.ApiKey,
BaseURL: provider.BaseUrl,
},
ActiveContextTime: maxCtx,
Channels: nonNilStrings(req.Channels),
CurrentChannel: req.CurrentChannel,
AllowedActions: req.AllowedActions,
Messages: nonNilModelMessages(messages),
Skills: nonNilStrings(skills),
UsableSkills: usableSkills,
Query: req.Query,
Identity: gatewayIdentity{
BotID: req.BotID,
ContainerID: containerID,
ChannelIdentityID: strings.TrimSpace(req.SourceChannelIdentityID),
DisplayName: r.resolveDisplayName(ctx, req),
CurrentPlatform: req.CurrentChannel,
ConversationType: strings.TrimSpace(req.ConversationType),
SessionToken: req.ChatToken,
},
Attachments: []any{},
}
return resolvedContext{payload: payload, model: chatModel, provider: provider}, nil
}
// --- Chat ---
// Chat sends a synchronous chat request to the agent gateway and stores the result.
func (r *Resolver) Chat(ctx context.Context, req conversation.ChatRequest) (conversation.ChatResponse, error) {
rc, err := r.resolve(ctx, req)
if err != nil {
return conversation.ChatResponse{}, err
}
resp, err := r.postChat(ctx, rc.payload, req.Token)
if err != nil {
return conversation.ChatResponse{}, err
}
if err := r.storeRound(ctx, req, resp.Messages); err != nil {
return conversation.ChatResponse{}, err
}
return conversation.ChatResponse{
Messages: resp.Messages,
Skills: resp.Skills,
Model: rc.model.ModelID,
Provider: rc.provider.ClientType,
}, nil
}
// --- TriggerSchedule ---
// TriggerSchedule executes a scheduled command through the agent gateway trigger-schedule endpoint.
func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload schedule.TriggerPayload, token string) error {
if strings.TrimSpace(botID) == "" {
return fmt.Errorf("bot id is required")
}
if strings.TrimSpace(payload.Command) == "" {
return fmt.Errorf("schedule command is required")
}
chatID := "schedule-" + payload.ID
req := conversation.ChatRequest{
BotID: botID,
ChatID: chatID,
Query: payload.Command,
UserID: payload.OwnerUserID,
Token: token,
}
rc, err := r.resolve(ctx, req)
if err != nil {
return err
}
schedulePayload := rc.payload
schedulePayload.Identity.ChannelIdentityID = strings.TrimSpace(payload.OwnerUserID)
schedulePayload.Identity.DisplayName = "Scheduler"
triggerReq := triggerScheduleRequest{
gatewayRequest: schedulePayload,
Schedule: gatewaySchedule{
ID: payload.ID,
Name: payload.Name,
Description: payload.Description,
Pattern: payload.Pattern,
MaxCalls: payload.MaxCalls,
Command: payload.Command,
},
}
resp, err := r.postTriggerSchedule(ctx, triggerReq, token)
if err != nil {
return err
}
return r.storeRound(ctx, req, resp.Messages)
}
// --- StreamChat ---
// StreamChat sends a streaming chat request to the agent gateway.
func (r *Resolver) StreamChat(ctx context.Context, req conversation.ChatRequest) (<-chan conversation.StreamChunk, <-chan error) {
chunkCh := make(chan conversation.StreamChunk)
errCh := make(chan error, 1)
r.logger.Info("gateway stream start",
slog.String("bot_id", req.BotID),
slog.String("chat_id", req.ChatID),
)
go func() {
defer close(chunkCh)
defer close(errCh)
streamReq := req
rc, err := r.resolve(ctx, streamReq)
if err != nil {
r.logger.Error("gateway stream resolve failed",
slog.String("bot_id", streamReq.BotID),
slog.String("chat_id", streamReq.ChatID),
slog.Any("error", err),
)
errCh <- err
return
}
if !streamReq.UserMessagePersisted {
if err := r.persistUserMessage(ctx, streamReq); err != nil {
r.logger.Error("gateway stream persist user message failed",
slog.String("bot_id", streamReq.BotID),
slog.String("chat_id", streamReq.ChatID),
slog.Any("error", err),
)
errCh <- err
return
}
streamReq.UserMessagePersisted = true
}
if err := r.streamChat(ctx, rc.payload, streamReq, chunkCh); err != nil {
r.logger.Error("gateway stream request failed",
slog.String("bot_id", streamReq.BotID),
slog.String("chat_id", streamReq.ChatID),
slog.Any("error", err),
)
errCh <- err
}
}()
return chunkCh, errCh
}
// --- HTTP helpers ---
func (r *Resolver) postChat(ctx context.Context, payload gatewayRequest, token string) (gatewayResponse, error) {
body, err := json.Marshal(payload)
if err != nil {
return gatewayResponse{}, err
}
url := r.gatewayBaseURL + "/chat/"
r.logger.Info("gateway request", slog.String("url", url), slog.String("body_prefix", truncate(string(body), 200)))
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return gatewayResponse{}, err
}
httpReq.Header.Set("Content-Type", "application/json")
if strings.TrimSpace(token) != "" {
httpReq.Header.Set("Authorization", token)
}
resp, err := r.httpClient.Do(httpReq)
if err != nil {
return gatewayResponse{}, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return gatewayResponse{}, err
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
r.logger.Error("gateway error", slog.String("url", url), slog.Int("status", resp.StatusCode), slog.String("body_prefix", truncate(string(respBody), 300)))
return gatewayResponse{}, fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(respBody)))
}
var parsed gatewayResponse
if err := json.Unmarshal(respBody, &parsed); err != nil {
r.logger.Error("gateway response parse failed", slog.String("body_prefix", truncate(string(respBody), 300)), slog.Any("error", err))
return gatewayResponse{}, fmt.Errorf("failed to parse gateway response: %w", err)
}
return parsed, nil
}
// postTriggerSchedule sends a trigger-schedule request to the agent gateway.
func (r *Resolver) postTriggerSchedule(ctx context.Context, payload triggerScheduleRequest, token string) (gatewayResponse, error) {
body, err := json.Marshal(payload)
if err != nil {
return gatewayResponse{}, err
}
url := r.gatewayBaseURL + "/chat/trigger-schedule"
r.logger.Info("gateway trigger-schedule request", slog.String("url", url), slog.String("schedule_id", payload.Schedule.ID))
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return gatewayResponse{}, err
}
httpReq.Header.Set("Content-Type", "application/json")
if strings.TrimSpace(token) != "" {
httpReq.Header.Set("Authorization", token)
}
resp, err := r.httpClient.Do(httpReq)
if err != nil {
return gatewayResponse{}, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return gatewayResponse{}, err
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
r.logger.Error("gateway trigger-schedule error", slog.String("url", url), slog.Int("status", resp.StatusCode), slog.String("body_prefix", truncate(string(respBody), 300)))
return gatewayResponse{}, fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(respBody)))
}
var parsed gatewayResponse
if err := json.Unmarshal(respBody, &parsed); err != nil {
r.logger.Error("gateway trigger-schedule response parse failed", slog.String("body_prefix", truncate(string(respBody), 300)), slog.Any("error", err))
return gatewayResponse{}, fmt.Errorf("failed to parse gateway response: %w", err)
}
return parsed, nil
}
func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, req conversation.ChatRequest, chunkCh chan<- conversation.StreamChunk) error {
body, err := json.Marshal(payload)
if err != nil {
return err
}
url := r.gatewayBaseURL + "/chat/stream"
r.logger.Info("gateway stream request", slog.String("url", url), slog.String("body_prefix", truncate(string(body), 200)))
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return err
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Accept", "text/event-stream")
if strings.TrimSpace(req.Token) != "" {
httpReq.Header.Set("Authorization", req.Token)
}
resp, err := r.streamingClient.Do(httpReq)
if err != nil {
r.logger.Error("gateway stream connect failed", slog.String("url", url), slog.Any("error", err))
return err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
errBody, _ := io.ReadAll(resp.Body)
r.logger.Error("gateway stream error", slog.String("url", url), slog.Int("status", resp.StatusCode), slog.String("body_prefix", truncate(string(errBody), 300)))
return fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(errBody)))
}
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 0, 64*1024), 2*1024*1024)
currentEvent := ""
stored := false
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" {
continue
}
if strings.HasPrefix(line, "event:") {
currentEvent = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
continue
}
if !strings.HasPrefix(line, "data:") {
continue
}
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
if data == "" || data == "[DONE]" {
continue
}
chunkCh <- conversation.StreamChunk([]byte(data))
if stored {
continue
}
if handled, storeErr := r.tryStoreStream(ctx, req, currentEvent, data); storeErr != nil {
return storeErr
} else if handled {
stored = true
}
}
return scanner.Err()
}
// tryStoreStream attempts to extract final messages from a stream event and persist them.
func (r *Resolver) tryStoreStream(ctx context.Context, req conversation.ChatRequest, eventType, data string) (bool, error) {
// event: done + data: {messages: [...]}
if eventType == "done" {
var resp gatewayResponse
if err := json.Unmarshal([]byte(data), &resp); err == nil && len(resp.Messages) > 0 {
return true, r.storeRound(ctx, req, resp.Messages)
}
}
// data: {"type":"text_delta"|"agent_end"|"done", ...}
var envelope struct {
Type string `json:"type"`
Data json.RawMessage `json:"data"`
Messages []conversation.ModelMessage `json:"messages"`
Skills []string `json:"skills"`
}
if err := json.Unmarshal([]byte(data), &envelope); err == nil {
if (envelope.Type == "agent_end" || envelope.Type == "done") && len(envelope.Messages) > 0 {
return true, r.storeRound(ctx, req, envelope.Messages)
}
if envelope.Type == "done" && len(envelope.Data) > 0 {
var resp gatewayResponse
if err := json.Unmarshal(envelope.Data, &resp); err == nil && len(resp.Messages) > 0 {
return true, r.storeRound(ctx, req, resp.Messages)
}
}
}
// fallback: data: {messages: [...]}
var resp gatewayResponse
if err := json.Unmarshal([]byte(data), &resp); err == nil && len(resp.Messages) > 0 {
return true, r.storeRound(ctx, req, resp.Messages)
}
return false, nil
}
// --- container resolution ---
func (r *Resolver) resolveContainerID(ctx context.Context, botID, explicit string) string {
if strings.TrimSpace(explicit) != "" {
return explicit
}
if r.queries != nil {
pgBotID, err := parseResolverUUID(botID)
if err == nil {
row, err := r.queries.GetContainerByBotID(ctx, pgBotID)
if err == nil && strings.TrimSpace(row.ContainerID) != "" {
return row.ContainerID
}
}
}
r.logger.Warn("no container found for bot, using fallback", slog.String("bot_id", botID))
return "mcp-" + botID
}
// --- message loading ---
func (r *Resolver) loadMessages(ctx context.Context, chatID string, maxContextMinutes int) ([]conversation.ModelMessage, error) {
if r.messageService == nil {
return nil, nil
}
since := time.Now().UTC().Add(-time.Duration(maxContextMinutes) * time.Minute)
msgs, err := r.messageService.ListSince(ctx, chatID, since)
if err != nil {
return nil, err
}
var result []conversation.ModelMessage
for _, m := range msgs {
var mm conversation.ModelMessage
if err := json.Unmarshal(m.Content, &mm); err != nil {
r.logger.Warn("loadMessages: content unmarshal failed, treating as raw text",
slog.String("chat_id", chatID), slog.Any("error", err))
mm = conversation.ModelMessage{Role: m.Role, Content: m.Content}
} else {
mm.Role = m.Role
}
result = append(result, mm)
}
return result, nil
}
type memoryContextItem struct {
Namespace string
Item memory.MemoryItem
}
func (r *Resolver) loadMemoryContextMessage(ctx context.Context, req conversation.ChatRequest) *conversation.ModelMessage {
if r.memoryService == nil {
return nil
}
if strings.TrimSpace(req.Query) == "" || strings.TrimSpace(req.BotID) == "" || strings.TrimSpace(req.ChatID) == "" {
return nil
}
results := make([]memoryContextItem, 0, memoryContextLimitPerScope)
seen := map[string]struct{}{}
resp, err := r.memoryService.Search(ctx, memory.SearchRequest{
Query: req.Query,
BotID: req.BotID,
Limit: memoryContextLimitPerScope,
Filters: map[string]any{
"namespace": sharedMemoryNamespace,
"scopeId": req.BotID,
"bot_id": req.BotID,
},
})
if err != nil {
r.logger.Warn("memory search for context failed",
slog.String("namespace", sharedMemoryNamespace),
slog.Any("error", err),
)
return nil
}
for _, item := range resp.Results {
key := strings.TrimSpace(item.ID)
if key == "" {
key = sharedMemoryNamespace + ":" + strings.TrimSpace(item.Memory)
}
if key == "" {
continue
}
if _, ok := seen[key]; ok {
continue
}
seen[key] = struct{}{}
results = append(results, memoryContextItem{Namespace: sharedMemoryNamespace, Item: item})
}
if len(results) == 0 {
return nil
}
sort.Slice(results, func(i, j int) bool {
return results[i].Item.Score > results[j].Item.Score
})
if len(results) > memoryContextMaxItems {
results = results[:memoryContextMaxItems]
}
var sb strings.Builder
sb.WriteString("Relevant memory context (use when helpful):\n")
for _, entry := range results {
text := strings.TrimSpace(entry.Item.Memory)
if text == "" {
continue
}
sb.WriteString("- [")
sb.WriteString(entry.Namespace)
sb.WriteString("] ")
sb.WriteString(truncateMemorySnippet(text, memoryContextItemMaxChars))
sb.WriteString("\n")
}
payload := strings.TrimSpace(sb.String())
if payload == "" {
return nil
}
msg := conversation.ModelMessage{
Role: "system",
Content: conversation.NewTextContent(payload),
}
return &msg
}
// --- store helpers ---
func (r *Resolver) persistUserMessage(ctx context.Context, req conversation.ChatRequest) error {
if r.messageService == nil {
return nil
}
if strings.TrimSpace(req.BotID) == "" {
return fmt.Errorf("bot id is required for persistence")
}
text := strings.TrimSpace(req.Query)
if text == "" {
return nil
}
message := conversation.ModelMessage{
Role: "user",
Content: conversation.NewTextContent(text),
}
content, err := json.Marshal(message)
if err != nil {
return err
}
senderChannelIdentityID, senderUserID := r.resolvePersistSenderIDs(ctx, req)
_, err = r.messageService.Persist(ctx, messagepkg.PersistInput{
BotID: req.BotID,
RouteID: req.RouteID,
SenderChannelIdentityID: senderChannelIdentityID,
SenderUserID: senderUserID,
Platform: req.CurrentChannel,
ExternalMessageID: req.ExternalMessageID,
Role: "user",
Content: content,
Metadata: buildRouteMetadata(req),
})
return err
}
func (r *Resolver) storeRound(ctx context.Context, req conversation.ChatRequest, messages []conversation.ModelMessage) error {
// Add user query as the first message if not already present in the round.
// This ensures the user's prompt is persisted alongside the assistant's response.
fullRound := make([]conversation.ModelMessage, 0, len(messages)+1)
hasUserQuery := false
for _, m := range messages {
if m.Role == "user" && m.TextContent() == req.Query {
hasUserQuery = true
break
}
}
if !req.UserMessagePersisted && !hasUserQuery && strings.TrimSpace(req.Query) != "" {
fullRound = append(fullRound, conversation.ModelMessage{
Role: "user",
Content: conversation.NewTextContent(req.Query),
})
}
for _, m := range messages {
if req.UserMessagePersisted && m.Role == "user" && strings.TrimSpace(m.TextContent()) == strings.TrimSpace(req.Query) {
// User message was already persisted before streaming; skip duplicate copy in round payload.
continue
}
fullRound = append(fullRound, m)
}
if len(fullRound) == 0 {
return nil
}
r.storeMessages(ctx, req, fullRound)
r.storeMemory(ctx, req.BotID, fullRound)
return nil
}
func (r *Resolver) storeMessages(ctx context.Context, req conversation.ChatRequest, messages []conversation.ModelMessage) {
if r.messageService == nil {
return
}
if strings.TrimSpace(req.BotID) == "" {
return
}
meta := buildRouteMetadata(req)
senderChannelIdentityID, senderUserID := r.resolvePersistSenderIDs(ctx, req)
for _, msg := range messages {
content, err := json.Marshal(msg)
if err != nil {
r.logger.Warn("storeMessages: marshal failed", slog.Any("error", err))
continue
}
messageSenderChannelIdentityID := ""
messageSenderUserID := ""
externalMessageID := ""
sourceReplyToMessageID := ""
if msg.Role == "user" {
messageSenderChannelIdentityID = senderChannelIdentityID
messageSenderUserID = senderUserID
externalMessageID = req.ExternalMessageID
} else if strings.TrimSpace(req.ExternalMessageID) != "" {
// Assistant/tool/system outputs are linked to the inbound source message for cross-channel reply threading.
sourceReplyToMessageID = req.ExternalMessageID
}
if _, err := r.messageService.Persist(ctx, messagepkg.PersistInput{
BotID: req.BotID,
RouteID: req.RouteID,
SenderChannelIdentityID: messageSenderChannelIdentityID,
SenderUserID: messageSenderUserID,
Platform: req.CurrentChannel,
ExternalMessageID: externalMessageID,
SourceReplyToMessageID: sourceReplyToMessageID,
Role: msg.Role,
Content: content,
Metadata: meta,
}); err != nil {
r.logger.Warn("persist message failed", slog.Any("error", err))
}
}
}
func buildRouteMetadata(req conversation.ChatRequest) map[string]any {
if strings.TrimSpace(req.RouteID) == "" && strings.TrimSpace(req.CurrentChannel) == "" {
return nil
}
meta := map[string]any{}
if strings.TrimSpace(req.RouteID) != "" {
meta["route_id"] = req.RouteID
}
if strings.TrimSpace(req.CurrentChannel) != "" {
meta["platform"] = req.CurrentChannel
}
return meta
}
func (r *Resolver) resolvePersistSenderIDs(ctx context.Context, req conversation.ChatRequest) (string, string) {
channelIdentityID := strings.TrimSpace(req.SourceChannelIdentityID)
userID := strings.TrimSpace(req.UserID)
senderChannelIdentityID := ""
if r.isExistingChannelIdentityID(ctx, channelIdentityID) {
senderChannelIdentityID = channelIdentityID
}
senderUserID := ""
if r.isExistingUserID(ctx, userID) {
senderUserID = userID
}
if senderUserID == "" && senderChannelIdentityID != "" {
if linked := r.linkedUserIDFromChannelIdentity(ctx, senderChannelIdentityID); linked != "" {
senderUserID = linked
}
}
return senderChannelIdentityID, senderUserID
}
func (r *Resolver) isExistingChannelIdentityID(ctx context.Context, id string) bool {
if r.queries == nil {
return false
}
pgID, err := parseResolverUUID(id)
if err != nil {
return false
}
_, err = r.queries.GetChannelIdentityByID(ctx, pgID)
return err == nil
}
func (r *Resolver) isExistingUserID(ctx context.Context, id string) bool {
if r.queries == nil {
return false
}
pgID, err := parseResolverUUID(id)
if err != nil {
return false
}
_, err = r.queries.GetUserByID(ctx, pgID)
return err == nil
}
func (r *Resolver) linkedUserIDFromChannelIdentity(ctx context.Context, channelIdentityID string) string {
if r.queries == nil {
return ""
}
pgID, err := parseResolverUUID(channelIdentityID)
if err != nil {
return ""
}
row, err := r.queries.GetChannelIdentityByID(ctx, pgID)
if err != nil || !row.UserID.Valid {
return ""
}
return row.UserID.String()
}
// resolveDisplayName returns the best available display name for the request identity:
// req.DisplayName if set, else channel identity's display_name, else linked user's display_name, else "User".
func (r *Resolver) resolveDisplayName(ctx context.Context, req conversation.ChatRequest) string {
if name := strings.TrimSpace(req.DisplayName); name != "" {
return name
}
if r.queries == nil {
return "User"
}
channelIdentityID := strings.TrimSpace(req.SourceChannelIdentityID)
if channelIdentityID == "" {
return "User"
}
pgID, err := parseResolverUUID(channelIdentityID)
if err != nil {
return "User"
}
ci, err := r.queries.GetChannelIdentityByID(ctx, pgID)
if err == nil && ci.DisplayName.Valid {
if name := strings.TrimSpace(ci.DisplayName.String); name != "" {
return name
}
}
linkedUserID := r.linkedUserIDFromChannelIdentity(ctx, channelIdentityID)
if linkedUserID == "" {
return "User"
}
userPgID, err := parseResolverUUID(linkedUserID)
if err != nil {
return "User"
}
u, err := r.queries.GetUserByID(ctx, userPgID)
if err != nil || !u.DisplayName.Valid {
return "User"
}
if name := strings.TrimSpace(u.DisplayName.String); name != "" {
return name
}
return "User"
}
func (r *Resolver) storeMemory(ctx context.Context, botID string, messages []conversation.ModelMessage) {
if r.memoryService == nil {
return
}
if strings.TrimSpace(botID) == "" {
return
}
memMsgs := make([]memory.Message, 0, len(messages))
for _, msg := range messages {
text := strings.TrimSpace(msg.TextContent())
if text == "" {
continue
}
role := msg.Role
if strings.TrimSpace(role) == "" {
role = "assistant"
}
memMsgs = append(memMsgs, memory.Message{Role: role, Content: text})
}
if len(memMsgs) == 0 {
return
}
r.addMemory(ctx, botID, memMsgs, sharedMemoryNamespace, botID)
}
func (r *Resolver) addMemory(ctx context.Context, botID string, msgs []memory.Message, namespace, scopeID string) {
filters := map[string]any{
"namespace": namespace,
"scopeId": scopeID,
"bot_id": botID,
}
if _, err := r.memoryService.Add(ctx, memory.AddRequest{
Messages: msgs,
BotID: botID,
Filters: filters,
}); err != nil {
r.logger.Warn("store memory failed",
slog.String("namespace", namespace),
slog.String("scope_id", scopeID),
slog.Any("error", err),
)
}
}
// --- model selection ---
func (r *Resolver) selectChatModel(ctx context.Context, req conversation.ChatRequest, botSettings settings.Settings, cs conversation.Settings) (models.GetResponse, sqlc.LlmProvider, error) {
if r.modelsService == nil {
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("models service not configured")
}
modelID := strings.TrimSpace(req.Model)
providerFilter := strings.TrimSpace(req.Provider)
// Priority: request model > chat settings > bot settings.
if modelID == "" && providerFilter == "" {
if value := strings.TrimSpace(cs.ModelID); value != "" {
modelID = value
} else if value := strings.TrimSpace(botSettings.ChatModelID); value != "" {
modelID = value
}
}
if modelID == "" {
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model not configured: specify model in request or bot settings")
}
if providerFilter == "" {
return r.fetchChatModel(ctx, modelID)
}
candidates, err := r.listCandidates(ctx, providerFilter)
if err != nil {
return models.GetResponse{}, sqlc.LlmProvider{}, err
}
for _, m := range candidates {
if m.ModelID == modelID {
prov, err := models.FetchProviderByID(ctx, r.queries, m.LlmProviderID)
if err != nil {
return models.GetResponse{}, sqlc.LlmProvider{}, err
}
return m, prov, nil
}
}
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model %q not found for provider %q", modelID, providerFilter)
}
func (r *Resolver) fetchChatModel(ctx context.Context, modelID string) (models.GetResponse, sqlc.LlmProvider, error) {
model, err := r.modelsService.GetByModelID(ctx, modelID)
if err != nil {
return models.GetResponse{}, sqlc.LlmProvider{}, err
}
if model.Type != models.ModelTypeChat {
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("model is not a chat model")
}
prov, err := models.FetchProviderByID(ctx, r.queries, model.LlmProviderID)
if err != nil {
return models.GetResponse{}, sqlc.LlmProvider{}, err
}
return model, prov, nil
}
func (r *Resolver) listCandidates(ctx context.Context, providerFilter string) ([]models.GetResponse, error) {
var all []models.GetResponse
var err error
if providerFilter != "" {
all, err = r.modelsService.ListByClientType(ctx, models.ClientType(providerFilter))
} else {
all, err = r.modelsService.ListByType(ctx, models.ModelTypeChat)
}
if err != nil {
return nil, err
}
filtered := make([]models.GetResponse, 0, len(all))
for _, m := range all {
if m.Type == models.ModelTypeChat {
filtered = append(filtered, m)
}
}
return filtered, nil
}
// --- settings ---
func (r *Resolver) loadBotSettings(ctx context.Context, botID string) (settings.Settings, error) {
if r.settingsService == nil {
return settings.Settings{}, fmt.Errorf("settings service not configured")
}
return r.settingsService.GetBot(ctx, botID)
}
// --- utility ---
func normalizeClientType(clientType string) (string, error) {
ct := strings.ToLower(strings.TrimSpace(clientType))
switch ct {
case "openai", "openai-compat", "anthropic", "google",
"azure", "bedrock", "mistral", "xai", "ollama", "dashscope":
return ct, nil
default:
return "", fmt.Errorf("unsupported agent gateway client type: %s", clientType)
}
}
func sanitizeMessages(messages []conversation.ModelMessage) []conversation.ModelMessage {
cleaned := make([]conversation.ModelMessage, 0, len(messages))
for _, msg := range messages {
if strings.TrimSpace(msg.Role) == "" {
continue
}
if !msg.HasContent() && strings.TrimSpace(msg.ToolCallID) == "" {
continue
}
cleaned = append(cleaned, msg)
}
return cleaned
}
func dedup(items []string) []string {
seen := make(map[string]struct{}, len(items))
result := make([]string, 0, len(items))
for _, s := range items {
trimmed := strings.TrimSpace(s)
if trimmed == "" {
continue
}
if _, ok := seen[trimmed]; ok {
continue
}
seen[trimmed] = struct{}{}
result = append(result, trimmed)
}
return result
}
func firstNonEmpty(values ...string) string {
for _, v := range values {
if strings.TrimSpace(v) != "" {
return v
}
}
return ""
}
func coalescePositiveInt(values ...int) int {
for _, v := range values {
if v > 0 {
return v
}
}
return defaultMaxContextMinutes
}
func nonNilStrings(s []string) []string {
if s == nil {
return []string{}
}
return s
}
func nonNilModelMessages(m []conversation.ModelMessage) []conversation.ModelMessage {
if m == nil {
return []conversation.ModelMessage{}
}
return m
}
func truncate(s string, n int) string {
if len(s) <= n {
return s
}
return s[:n] + "..."
}
func truncateMemorySnippet(s string, n int) string {
trimmed := strings.TrimSpace(s)
if len(trimmed) <= n {
return trimmed
}
return strings.TrimSpace(trimmed[:n]) + "..."
}
func parseResolverUUID(id string) (pgtype.UUID, error) {
if strings.TrimSpace(id) == "" {
return pgtype.UUID{}, fmt.Errorf("empty id")
}
return db.ParseUUID(id)
}