refactor(agent): remove agent gateway in favor of twilight sdk (#264)

* refactor(agent): replace TypeScript agent gateway with in-process Go agent using twilight-ai SDK

- Remove apps/agent (Bun/Elysia gateway), packages/agent (@memoh/agent),
  internal/bun runtime manager, and all embedded agent/bun assets
- Add internal/agent package powered by twilight-ai SDK for LLM calls,
  tool execution, streaming, sequential logic, tag extraction, and prompts
- Integrate ToolGatewayService in-process for both built-in and user MCP
  tools, eliminating HTTP round-trips to the old gateway
- Update resolver to convert between sdk.Message and ModelMessage at the
  boundary (resolver_messages.go), keeping agent package free of
  persistence concerns
- Prepend user message before storeRound since SDK only returns output
  messages (assistant + tool)
- Clean up all Docker configs, TOML configs, nginx proxy, Dockerfile.agent,
  and Go config structs related to the removed agent gateway
- Update cmd/agent and cmd/memoh entry points with setter-based
  ToolGateway injection to avoid FX dependency cycles
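
The setter-based injection mentioned above can be sketched as follows. This is a minimal illustration of breaking an fx constructor cycle by wiring one direction late; the type and method names are stand-ins, not the real `cmd/agent` signatures.

```go
package main

import "fmt"

// ToolGateway is a stand-in for the in-process tool gateway interface.
type ToolGateway interface {
	Call(tool string) (string, error)
}

// Agent needs a ToolGateway, but if the gateway's constructor also needed
// the agent, fx would report a dependency cycle. A setter breaks the cycle:
// construct both independently, then wire one direction from an fx.Invoke.
type Agent struct {
	gateway ToolGateway
}

// SetToolGateway is called after both components exist, so neither
// constructor depends on the other.
func (a *Agent) SetToolGateway(gw ToolGateway) { a.gateway = gw }

type gateway struct{}

func (gateway) Call(tool string) (string, error) { return "ok:" + tool, nil }

func main() {
	a := &Agent{}
	a.SetToolGateway(gateway{}) // late wiring instead of constructor injection
	out, _ := a.gateway.Call("send")
	fmt.Println(out) // ok:send
}
```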

* fix(web): move form declaration before computed properties that reference it

The `form` reactive object was declared after computed properties such as
`selectedMemoryProvider` and `isSelectedMemoryProviderPersisted` that
reference it, causing a temporal dead zone (TDZ) ReferenceError during setup.

* fix: prevent UTF-8 character corruption in streaming text output

StreamTagExtractor.Push() used byte-level string slicing to hold back
buffer tails for tag detection, which could split multi-byte UTF-8
characters. After json.Marshal replaced invalid bytes with U+FFFD,
the corruption became permanent — causing garbled CJK characters (�)
in agent responses.

Add safeUTF8SplitIndex() to back up split points to valid character
boundaries. Also fix byte-level truncation in command/formatter.go
and command/fs.go to use rune-aware slicing.
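
The boundary-backoff technique described above can be sketched with `unicode/utf8`. This is an illustrative version, not the exact `StreamTagExtractor` code: it walks a byte index left until it lands on a rune start byte, so the held-back head never ends mid-character.

```go
package main

import (
	"fmt"
	"unicode/utf8"
)

// safeUTF8SplitIndex backs a byte-level split index up to the nearest rune
// boundary. UTF-8 continuation bytes (0b10xxxxxx) can never start a rune,
// so utf8.RuneStart identifies a safe split point.
func safeUTF8SplitIndex(s string, i int) int {
	if i >= len(s) {
		return len(s)
	}
	for i > 0 && !utf8.RuneStart(s[i]) {
		i--
	}
	return i
}

func main() {
	s := "héllo" // 'é' occupies bytes 1-2
	fmt.Println(safeUTF8SplitIndex(s, 2)) // 2 falls mid-'é', backs up to 1
	fmt.Println(safeUTF8SplitIndex(s, 3)) // 3 starts 'l', already safe: 3
}
```

Splitting at the adjusted index keeps both halves valid UTF-8, so a later `json.Marshal` never has to substitute U+FFFD.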

* fix: add agent error logging and fix Gemini tool schema validation

- Log agent stream errors in both SSE and WebSocket paths with bot/model context
- Fix send tool `attachments` parameter: empty `items` schema rejected by
  Google Gemini API (INVALID_ARGUMENT), now specifies `{"type": "string"}`
- Upgrade twilight-ai to d898f0b (includes raw body in API error messages)

* chore(ci): remove agent gateway from Docker build and release pipelines

Agent gateway has been replaced by in-process Go agent; remove the
obsolete Docker image matrix entry, Bun/UPX CI steps, and agent-binary
build logic from the release script.

* fix: preserve attachment filename, metadata, and container path through persistence

- Add `name` column to `bot_history_message_assets` (migration 0034) to
  persist original filenames across page refreshes.
- Add `metadata` JSONB column (migration 0035) to store source_path,
  source_url, and other context alongside each asset.
- Update SQL queries, sqlc-generated code, and all Go types (MessageAsset,
  AssetRef, OutboundAssetRef, FileAttachment) to carry name and metadata
  through the full lifecycle.
- Extract filenames from path/URL in AttachmentsResolver before clearing
  raw paths; enrich streaming event metadata with name, source_path, and
  source_url in both the WebSocket and channel inbound ingestion paths.
- Implement `LinkAssets` on message service and `LinkOutboundAssets` on
  flow resolver so WebSocket-streamed bot attachments are persisted to the
  correct assistant message after streaming completes.
- Frontend: update MessageAsset type with metadata field, pass metadata
  through to attachment items, and reorder attachment-block.vue template
  so container files (identified by metadata.source_path) open in the
  sidebar file manager instead of triggering a download.

* refactor(agent): decouple built-in tools from MCP, load via ToolProvider interface

Migrate all 13 built-in tool providers from internal/mcp/providers/ to
internal/agent/tools/ using the twilight-ai sdk.Tool structure. The agent
now loads tools through a ToolProvider interface instead of the MCP
ToolGatewayService, which is simplified to only manage external federation
sources. This enables selective tool loading and removes the coupling
between business tools and the MCP protocol layer.

* refactor(flow): split monolithic resolver.go into focused modules

Break the 1959-line resolver.go into 12 files organized by concern:
- resolver.go: core orchestration (Resolver struct, resolve, Chat, prepareRunConfig)
- resolver_stream.go: streaming (StreamChat, StreamChatWS, tryStoreStream)
- resolver_trigger.go: schedule/heartbeat triggers
- resolver_attachments.go: attachment routing, inlining, encoding
- resolver_history.go: message loading, deduplication, token trimming
- resolver_store.go: persistence (storeRound, storeMessages, asset linking)
- resolver_memory.go: memory provider integration
- resolver_model_selection.go: model selection and candidate matching
- resolver_identity.go: display name and channel identity resolution
- resolver_settings.go: bot settings, loop detection, inbox
- user_header.go: YAML front-matter formatting
- resolver_util.go: shared utilities (sanitize, normalize, dedup, UUID)

* fix(agent): enable Anthropic extended thinking by passing ReasoningConfig to provider

Anthropic's thinking requires WithThinking() at provider creation time,
unlike OpenAI which uses per-request ReasoningEffort. The config was
never wired through, so Claude models could not trigger thinking.
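
The creation-time vs per-request distinction can be illustrated with the functional-options pattern. All types here are stand-ins, assuming only the `WithThinking()` name mentioned above: the thinking budget is fixed when the provider is built, while an OpenAI-style reasoning effort rides on each request.

```go
package main

import "fmt"

type providerConfig struct{ thinkingBudget int }

type Option func(*providerConfig)

// WithThinking must be applied at construction; there is no per-request
// equivalent, which is why the config had to be wired through here.
func WithThinking(budget int) Option {
	return func(c *providerConfig) { c.thinkingBudget = budget }
}

func newProvider(opts ...Option) providerConfig {
	var c providerConfig
	for _, o := range opts {
		o(&c)
	}
	return c
}

// request carries the OpenAI-style per-request knob for contrast.
type request struct{ reasoningEffort string }

func main() {
	p := newProvider(WithThinking(2048))
	r := request{reasoningEffort: "high"}
	fmt.Println(p.thinkingBudget, r.reasoningEffort) // 2048 high
}
```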

* refactor(agent): extract prompts into embedded markdown templates

Move inline prompt strings from prompt.go into separate .md files under
internal/agent/prompts/, using {{key}} placeholders and a simple render
engine. Remove obsolete SystemPromptParams fields (Language,
MaxContextLoadTime, Channels, CurrentChannel) and their call-site usage.
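
The simple render engine described above can be sketched in a few lines. This is a minimal version under the stated assumption that placeholders are literal `{{key}}` tokens; the real templates are embedded from internal/agent/prompts/.

```go
package main

import (
	"fmt"
	"strings"
)

// renderPrompt substitutes {{key}} placeholders in a markdown template.
// Unmatched placeholders are left in place, which makes missing params
// visible in the rendered prompt.
func renderPrompt(tmpl string, params map[string]string) string {
	for k, v := range params {
		tmpl = strings.ReplaceAll(tmpl, "{{"+k+"}}", v)
	}
	return tmpl
}

func main() {
	tmpl := "You are {{name}}. Today is {{date}}."
	out := renderPrompt(tmpl, map[string]string{"name": "Memoh", "date": "2026-03-19"})
	fmt.Println(out) // You are Memoh. Today is 2026-03-19.
}
```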

* fix: lint

This commit is contained in:
Acbox Liu
2026-03-19 13:31:54 +08:00
committed by GitHub
parent ef333ae516
commit 1680316c7f
169 changed files with 7988 additions and 14436 deletions
File diff suppressed because it is too large
@@ -0,0 +1,255 @@
package flow
import (
"context"
"encoding/base64"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"strings"
attachmentpkg "github.com/memohai/memoh/internal/attachment"
"github.com/memohai/memoh/internal/conversation"
"github.com/memohai/memoh/internal/models"
)
const (
gatewayInlineAttachmentMaxBytes int64 = 20 * 1024 * 1024
)
// routeAndMergeAttachments applies CapabilityFallbackPolicy to split
// request attachments by model input modalities, then merges the results
// into a single []any for the gateway request.
func (r *Resolver) routeAndMergeAttachments(ctx context.Context, model models.GetResponse, req conversation.ChatRequest) []any {
if len(req.Attachments) == 0 {
return []any{}
}
typed := r.prepareGatewayAttachments(ctx, req)
routed := routeAttachmentsByCapability(model.InputModalities, typed)
for i := range routed.Fallback {
fallbackPath := strings.TrimSpace(routed.Fallback[i].FallbackPath)
if fallbackPath == "" {
if r != nil && r.logger != nil {
r.logger.Warn(
"drop attachment without fallback path",
slog.String("type", strings.TrimSpace(routed.Fallback[i].Type)),
slog.String("transport", strings.TrimSpace(routed.Fallback[i].Transport)),
slog.String("content_hash", strings.TrimSpace(routed.Fallback[i].ContentHash)),
slog.Bool("has_payload", strings.TrimSpace(routed.Fallback[i].Payload) != ""),
)
}
routed.Fallback[i] = gatewayAttachment{}
continue
}
routed.Fallback[i].Type = "file"
routed.Fallback[i].Transport = gatewayTransportToolFileRef
routed.Fallback[i].Payload = fallbackPath
}
merged := make([]any, 0, len(routed.Native)+len(routed.Fallback))
merged = append(merged, attachmentsToAny(routed.Native)...)
for _, fb := range routed.Fallback {
if fb.Type == "" || strings.TrimSpace(fb.Transport) == "" || strings.TrimSpace(fb.Payload) == "" {
continue
}
merged = append(merged, fb)
}
if len(merged) == 0 {
return []any{}
}
return merged
}
func (r *Resolver) prepareGatewayAttachments(ctx context.Context, req conversation.ChatRequest) []gatewayAttachment {
if len(req.Attachments) == 0 {
return nil
}
prepared := make([]gatewayAttachment, 0, len(req.Attachments))
for _, raw := range req.Attachments {
attachmentType := strings.ToLower(strings.TrimSpace(raw.Type))
payload := strings.TrimSpace(raw.Base64)
transport := ""
fallbackPath := strings.TrimSpace(raw.Path)
if payload != "" {
transport = gatewayTransportInlineDataURL
} else {
rawURL := strings.TrimSpace(raw.URL)
switch {
case isDataURL(rawURL):
payload = rawURL
transport = gatewayTransportInlineDataURL
case isLikelyPublicURL(rawURL):
payload = rawURL
transport = gatewayTransportPublicURL
case rawURL != "" && fallbackPath == "":
fallbackPath = rawURL
}
}
item := gatewayAttachment{
ContentHash: strings.TrimSpace(raw.ContentHash),
Type: attachmentType,
Mime: strings.TrimSpace(raw.Mime),
Size: raw.Size,
Name: strings.TrimSpace(raw.Name),
Transport: transport,
Payload: payload,
Metadata: raw.Metadata,
FallbackPath: fallbackPath,
}
item = normalizeGatewayAttachmentPayload(item)
item = r.inlineImageAttachmentAssetIfNeeded(ctx, strings.TrimSpace(req.BotID), item)
prepared = append(prepared, item)
}
return prepared
}
func normalizeGatewayAttachmentPayload(item gatewayAttachment) gatewayAttachment {
if item.Transport != gatewayTransportInlineDataURL {
return item
}
payload := strings.TrimSpace(item.Payload)
if payload == "" {
return item
}
if strings.HasPrefix(strings.ToLower(payload), "data:") {
mime := strings.TrimSpace(item.Mime)
if mime == "" || strings.EqualFold(mime, "application/octet-stream") {
if extracted := attachmentpkg.MimeFromDataURL(payload); extracted != "" {
item.Mime = extracted
}
}
item.Payload = payload
return item
}
mime := strings.TrimSpace(item.Mime)
if mime == "" {
mime = "application/octet-stream"
}
item.Payload = attachmentpkg.NormalizeBase64DataURL(payload, mime)
return item
}
func isLikelyPublicURL(raw string) bool {
trimmed := strings.ToLower(strings.TrimSpace(raw))
return strings.HasPrefix(trimmed, "http://") || strings.HasPrefix(trimmed, "https://")
}
func isDataURL(raw string) bool {
trimmed := strings.ToLower(strings.TrimSpace(raw))
return strings.HasPrefix(trimmed, "data:")
}
func (r *Resolver) inlineImageAttachmentAssetIfNeeded(ctx context.Context, botID string, item gatewayAttachment) gatewayAttachment {
if item.Type != "image" {
return item
}
if strings.TrimSpace(item.Payload) != "" &&
(item.Transport == gatewayTransportInlineDataURL || item.Transport == gatewayTransportPublicURL) {
return item
}
contentHash := strings.TrimSpace(item.ContentHash)
if contentHash == "" {
return item
}
dataURL, mime, err := r.inlineAssetAsDataURL(ctx, botID, contentHash, item.Type, item.Mime)
if err != nil {
if r != nil && r.logger != nil {
r.logger.Warn(
"inline gateway image attachment failed",
slog.Any("error", err),
slog.String("bot_id", botID),
slog.String("content_hash", contentHash),
)
}
return item
}
item.Transport = gatewayTransportInlineDataURL
item.Payload = dataURL
if strings.TrimSpace(item.Mime) == "" {
item.Mime = mime
}
return item
}
func (r *Resolver) inlineAssetAsDataURL(ctx context.Context, botID, contentHash, attachmentType, fallbackMime string) (string, string, error) {
if r == nil || r.assetLoader == nil {
return "", "", errors.New("gateway asset loader not configured")
}
reader, assetMime, err := r.assetLoader.OpenForGateway(ctx, botID, contentHash)
if err != nil {
return "", "", fmt.Errorf("open asset: %w", err)
}
defer func() {
_ = reader.Close()
}()
mime := strings.TrimSpace(fallbackMime)
if mime == "" {
mime = strings.TrimSpace(assetMime)
}
dataURL, resolvedMime, err := encodeReaderAsDataURL(reader, gatewayInlineAttachmentMaxBytes, attachmentType, mime)
if err != nil {
return "", "", err
}
return dataURL, resolvedMime, nil
}
func encodeReaderAsDataURL(reader io.Reader, maxBytes int64, attachmentType, fallbackMime string) (string, string, error) {
if reader == nil {
return "", "", errors.New("reader is required")
}
if maxBytes <= 0 {
return "", "", errors.New("max bytes must be greater than 0")
}
limited := &io.LimitedReader{R: reader, N: maxBytes + 1}
head := make([]byte, 512)
n, err := limited.Read(head)
if err != nil && !errors.Is(err, io.EOF) {
return "", "", fmt.Errorf("read asset: %w", err)
}
head = head[:n]
mime := strings.TrimSpace(fallbackMime)
if strings.EqualFold(strings.TrimSpace(attachmentType), "image") &&
(strings.TrimSpace(mime) == "" || strings.EqualFold(strings.TrimSpace(mime), "application/octet-stream")) {
detected := strings.TrimSpace(http.DetectContentType(head))
if strings.HasPrefix(strings.ToLower(detected), "image/") {
mime = detected
}
}
if mime == "" {
mime = "application/octet-stream"
}
var encoded strings.Builder
encoded.Grow(len("data:") + len(mime) + len(";base64,"))
encoded.WriteString("data:")
encoded.WriteString(mime)
encoded.WriteString(";base64,")
encoder := base64.NewEncoder(base64.StdEncoding, &encoded)
if len(head) > 0 {
if _, err := encoder.Write(head); err != nil {
_ = encoder.Close()
return "", "", fmt.Errorf("encode asset head: %w", err)
}
}
copied, err := io.Copy(encoder, limited)
if err != nil {
_ = encoder.Close()
return "", "", fmt.Errorf("encode asset body: %w", err)
}
if err := encoder.Close(); err != nil {
return "", "", fmt.Errorf("finalize asset encoding: %w", err)
}
total := int64(len(head)) + copied
if total > maxBytes {
return "", "", fmt.Errorf(
"asset too large to inline: %d > %d",
total,
maxBytes,
)
}
return encoded.String(), mime, nil
}
@@ -0,0 +1,160 @@
package flow
import (
"context"
"encoding/json"
"log/slog"
"strings"
"time"
"github.com/memohai/memoh/internal/conversation"
)
type messageWithUsage struct {
Message conversation.ModelMessage
UsageInputTokens *int
UsageOutputTokens *int
RouteID string
ExternalMessageID string
Platform string
SenderChannelID string
}
func (r *Resolver) loadMessages(ctx context.Context, chatID string, maxContextMinutes int) ([]messageWithUsage, error) {
if r.messageService == nil {
return nil, nil
}
since := time.Now().UTC().Add(-time.Duration(maxContextMinutes) * time.Minute)
msgs, err := r.messageService.ListActiveSince(ctx, chatID, since)
if err != nil {
return nil, err
}
var result []messageWithUsage
for _, m := range msgs {
var mm conversation.ModelMessage
if err := json.Unmarshal(m.Content, &mm); err != nil {
r.logger.Warn("loadMessages: content unmarshal failed, treating as raw text",
slog.String("chat_id", chatID), slog.Any("error", err))
mm = conversation.ModelMessage{Role: m.Role, Content: m.Content}
} else {
mm.Role = m.Role
}
var inputTokens *int
var outputTokens *int
if len(m.Usage) > 0 {
var u usageInfo
if json.Unmarshal(m.Usage, &u) == nil {
inputTokens = u.InputTokens
outputTokens = u.OutputTokens
}
}
result = append(result, messageWithUsage{
Message: mm,
UsageInputTokens: inputTokens,
UsageOutputTokens: outputTokens,
RouteID: strings.TrimSpace(m.RouteID),
ExternalMessageID: strings.TrimSpace(m.ExternalMessageID),
Platform: strings.TrimSpace(m.Platform),
SenderChannelID: strings.TrimSpace(m.SenderChannelIdentityID),
})
}
return result, nil
}
func dedupePersistedCurrentUserMessage(messages []messageWithUsage, req conversation.ChatRequest) []messageWithUsage {
if !req.UserMessagePersisted || len(messages) == 0 {
return messages
}
targetRouteID := strings.TrimSpace(req.RouteID)
targetExternalID := strings.TrimSpace(req.ExternalMessageID)
targetPlatform := strings.TrimSpace(req.CurrentChannel)
targetSenderChannelID := strings.TrimSpace(req.SourceChannelIdentityID)
if targetExternalID == "" {
return messages
}
for i := len(messages) - 1; i >= 0; i-- {
item := messages[i]
if !strings.EqualFold(strings.TrimSpace(item.Message.Role), "user") {
continue
}
if strings.TrimSpace(item.ExternalMessageID) != targetExternalID {
continue
}
if targetRouteID != "" && item.RouteID != "" && item.RouteID != targetRouteID {
continue
}
if targetPlatform != "" && item.Platform != "" && !strings.EqualFold(item.Platform, targetPlatform) {
continue
}
if targetSenderChannelID != "" && item.SenderChannelID != "" && item.SenderChannelID != targetSenderChannelID {
continue
}
return append(messages[:i], messages[i+1:]...)
}
return messages
}
func estimateMessageTokens(msg conversation.ModelMessage) int {
text := msg.TextContent()
if len(text) == 0 {
data, _ := json.Marshal(msg.Content)
return len(data) / 4
}
return len(text) / 4
}
func trimMessagesByTokens(log *slog.Logger, messages []messageWithUsage, maxTokens int) []conversation.ModelMessage {
if maxTokens == 0 || len(messages) == 0 {
result := make([]conversation.ModelMessage, len(messages))
for i, m := range messages {
result[i] = m.Message
}
return result
}
// Scan from newest to oldest, accumulating per-message token costs.
// Messages with stored usage data use that value; others fall back to a
// character-based estimate so that user/tool messages are not free-passed.
totalTokens := 0
cutoff := 0
messagesWithUsage := 0
for i := len(messages) - 1; i >= 0; i-- {
if messages[i].UsageOutputTokens != nil {
totalTokens += *messages[i].UsageOutputTokens
messagesWithUsage++
} else {
totalTokens += estimateMessageTokens(messages[i].Message)
}
if totalTokens > maxTokens {
cutoff = i + 1
break
}
}
// Keep provider-valid message order: a "tool" message must follow a preceding
// assistant tool call. When history is head-trimmed, a leading tool message
// may become orphaned and cause provider 400 errors.
for cutoff < len(messages) && strings.EqualFold(strings.TrimSpace(messages[cutoff].Message.Role), "tool") {
cutoff++
}
if log != nil {
log.Debug("trimMessagesByTokens",
slog.Int("total_messages", len(messages)),
slog.Int("messages_with_usage", messagesWithUsage),
slog.Int("accumulated_output_tokens", totalTokens),
slog.Int("max_tokens", maxTokens),
slog.Int("cutoff_index", cutoff),
slog.Int("kept_messages", len(messages)-cutoff),
)
}
result := make([]conversation.ModelMessage, 0, len(messages)-cutoff)
for _, m := range messages[cutoff:] {
result = append(result, m.Message)
}
return result
}
@@ -0,0 +1,88 @@
package flow
import (
"context"
"strings"
"github.com/memohai/memoh/internal/conversation"
)
// resolveDisplayName returns the best available display name for the request identity:
// req.DisplayName if set, else channel identity's display_name, else linked user's display_name, else "User".
func (r *Resolver) resolveDisplayName(ctx context.Context, req conversation.ChatRequest) string {
if name := strings.TrimSpace(req.DisplayName); name != "" {
return name
}
if r.queries == nil {
return "User"
}
channelIdentityID := strings.TrimSpace(req.SourceChannelIdentityID)
if channelIdentityID == "" {
return "User"
}
pgID, err := parseResolverUUID(channelIdentityID)
if err != nil {
return "User"
}
ci, err := r.queries.GetChannelIdentityByID(ctx, pgID)
if err == nil && ci.DisplayName.Valid {
if name := strings.TrimSpace(ci.DisplayName.String); name != "" {
return name
}
}
linkedUserID := r.linkedUserIDFromChannelIdentity(ctx, channelIdentityID)
if linkedUserID == "" {
return "User"
}
userPgID, err := parseResolverUUID(linkedUserID)
if err != nil {
return "User"
}
u, err := r.queries.GetUserByID(ctx, userPgID)
if err != nil || !u.DisplayName.Valid {
return "User"
}
if name := strings.TrimSpace(u.DisplayName.String); name != "" {
return name
}
return "User"
}
func (r *Resolver) isExistingChannelIdentityID(ctx context.Context, id string) bool {
if r.queries == nil {
return false
}
pgID, err := parseResolverUUID(id)
if err != nil {
return false
}
_, err = r.queries.GetChannelIdentityByID(ctx, pgID)
return err == nil
}
func (r *Resolver) isExistingUserID(ctx context.Context, id string) bool {
if r.queries == nil {
return false
}
pgID, err := parseResolverUUID(id)
if err != nil {
return false
}
_, err = r.queries.GetUserByID(ctx, pgID)
return err == nil
}
func (r *Resolver) linkedUserIDFromChannelIdentity(ctx context.Context, channelIdentityID string) string {
if r.queries == nil {
return ""
}
pgID, err := parseResolverUUID(channelIdentityID)
if err != nil {
return ""
}
row, err := r.queries.GetChannelIdentityByID(ctx, pgID)
if err != nil || !row.UserID.Valid {
return ""
}
return row.UserID.String()
}
@@ -0,0 +1,97 @@
package flow
import (
"context"
"log/slog"
"strings"
"github.com/memohai/memoh/internal/conversation"
memprovider "github.com/memohai/memoh/internal/memory/adapters"
)
func (r *Resolver) resolveMemoryProvider(ctx context.Context, botID string) memprovider.Provider {
if r.memoryRegistry == nil {
return nil
}
if r.settingsService == nil {
return nil
}
botSettings, err := r.settingsService.GetBot(ctx, botID)
if err != nil {
return nil
}
providerID := strings.TrimSpace(botSettings.MemoryProviderID)
if providerID == "" {
return nil
}
p, err := r.memoryRegistry.Get(providerID)
if err != nil {
r.logger.Warn("memory provider lookup failed", slog.String("provider_id", providerID), slog.Any("error", err))
return nil
}
return p
}
func (r *Resolver) loadMemoryContextMessage(ctx context.Context, req conversation.ChatRequest) *conversation.ModelMessage {
p := r.resolveMemoryProvider(ctx, req.BotID)
if p == nil {
return nil
}
result, err := p.OnBeforeChat(ctx, memprovider.BeforeChatRequest{
Query: req.Query,
BotID: req.BotID,
ChatID: req.ChatID,
})
if err != nil {
r.logger.Warn("memory provider OnBeforeChat failed", slog.Any("error", err))
return nil
}
if result == nil || strings.TrimSpace(result.ContextText) == "" {
return nil
}
return &conversation.ModelMessage{
Role: "user",
Content: conversation.NewTextContent(result.ContextText),
}
}
func (r *Resolver) storeMemory(ctx context.Context, req conversation.ChatRequest, messages []conversation.ModelMessage) {
botID := strings.TrimSpace(req.BotID)
if botID == "" {
return
}
memMsgs := toProviderMessages(messages)
if len(memMsgs) == 0 {
return
}
p := r.resolveMemoryProvider(ctx, botID)
if p == nil {
return
}
if err := p.OnAfterChat(ctx, memprovider.AfterChatRequest{
BotID: botID,
Messages: memMsgs,
UserID: strings.TrimSpace(req.UserID),
ChannelIdentityID: strings.TrimSpace(req.SourceChannelIdentityID),
DisplayName: r.resolveDisplayName(ctx, req),
}); err != nil {
r.logger.Warn("memory provider OnAfterChat failed", slog.String("bot_id", botID), slog.Any("error", err))
}
}
func toProviderMessages(messages []conversation.ModelMessage) []memprovider.Message {
out := make([]memprovider.Message, 0, len(messages))
for _, msg := range messages {
text := strings.TrimSpace(msg.TextContent())
if text == "" {
continue
}
role := strings.TrimSpace(msg.Role)
if role == "" {
role = "assistant"
}
out = append(out, memprovider.Message{Role: role, Content: text})
}
return out
}
@@ -0,0 +1,84 @@
package flow
import (
"encoding/json"
"strings"
sdk "github.com/memohai/twilight-ai/sdk"
"github.com/memohai/memoh/internal/conversation"
)
// sdkMessagesToModelMessages converts SDK messages to the persistence/API format
// at the resolver boundary. This is the only place where this conversion should happen.
func sdkMessagesToModelMessages(msgs []sdk.Message) []conversation.ModelMessage {
result := make([]conversation.ModelMessage, 0, len(msgs))
for _, msg := range msgs {
data, err := json.Marshal(msg)
if err != nil {
continue
}
var envelope struct {
Content json.RawMessage `json:"content"`
}
if err := json.Unmarshal(data, &envelope); err != nil {
continue
}
result = append(result, conversation.ModelMessage{
Role: string(msg.Role),
Content: envelope.Content,
})
}
return result
}
// modelMessageToSDKMessage converts a persistence format message to SDK message
// at the resolver boundary using sdk.Message's native JSON deserialization.
func modelMessageToSDKMessage(mm conversation.ModelMessage) sdk.Message {
var s string
if err := json.Unmarshal(mm.Content, &s); err == nil {
return sdk.Message{
Role: sdk.MessageRole(mm.Role),
Content: []sdk.MessagePart{sdk.TextPart{Text: s}},
}
}
// Try the full sdk.Message format (content is an array of typed parts)
envelope, _ := json.Marshal(struct {
Role string `json:"role"`
Content json.RawMessage `json:"content"`
}{
Role: mm.Role,
Content: mm.Content,
})
var msg sdk.Message
if err := json.Unmarshal(envelope, &msg); err == nil {
return msg
}
return sdk.Message{Role: sdk.MessageRole(mm.Role)}
}
// prependUserMessage prepends the user query as a ModelMessage to the output
// messages from the agent. The SDK only returns output messages (assistant + tool);
// user messages must be added back at the resolver boundary for persistence.
func prependUserMessage(query string, output []conversation.ModelMessage) []conversation.ModelMessage {
if strings.TrimSpace(query) == "" {
return output
}
round := make([]conversation.ModelMessage, 0, 1+len(output))
round = append(round, conversation.ModelMessage{
Role: "user",
Content: conversation.NewTextContent(query),
})
return append(round, output...)
}
// modelMessagesToSDKMessages converts a slice of persistence messages to SDK messages.
func modelMessagesToSDKMessages(msgs []conversation.ModelMessage) []sdk.Message {
result := make([]sdk.Message, 0, len(msgs))
for _, mm := range msgs {
result = append(result, modelMessageToSDKMessage(mm))
}
return result
}
@@ -0,0 +1,119 @@
package flow
import (
"context"
"errors"
"fmt"
"strings"
"github.com/jackc/pgx/v5"
"github.com/memohai/memoh/internal/conversation"
"github.com/memohai/memoh/internal/db"
"github.com/memohai/memoh/internal/db/sqlc"
"github.com/memohai/memoh/internal/models"
"github.com/memohai/memoh/internal/settings"
)
func (r *Resolver) selectChatModel(ctx context.Context, req conversation.ChatRequest, botSettings settings.Settings, cs conversation.Settings) (models.GetResponse, sqlc.LlmProvider, error) {
if r.modelsService == nil {
return models.GetResponse{}, sqlc.LlmProvider{}, errors.New("models service not configured")
}
modelID := strings.TrimSpace(req.Model)
providerFilter := strings.TrimSpace(req.Provider)
// Priority: request model > chat settings > bot settings.
if modelID == "" && providerFilter == "" {
if value := strings.TrimSpace(cs.ModelID); value != "" {
modelID = value
} else if value := strings.TrimSpace(botSettings.ChatModelID); value != "" {
modelID = value
}
}
if modelID == "" {
return models.GetResponse{}, sqlc.LlmProvider{}, errors.New("chat model not configured: specify model in request or bot settings")
}
if providerFilter == "" {
return r.fetchChatModel(ctx, modelID)
}
candidates, err := r.listCandidates(ctx, providerFilter)
if err != nil {
return models.GetResponse{}, sqlc.LlmProvider{}, err
}
for _, m := range candidates {
if matchesModelReference(m, modelID) {
prov, err := models.FetchProviderByID(ctx, r.queries, m.LlmProviderID)
if err != nil {
return models.GetResponse{}, sqlc.LlmProvider{}, err
}
return m, prov, nil
}
}
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model %q not found for provider %q", modelID, providerFilter)
}
func (r *Resolver) fetchChatModel(ctx context.Context, modelID string) (models.GetResponse, sqlc.LlmProvider, error) {
modelRef := strings.TrimSpace(modelID)
if modelRef == "" {
return models.GetResponse{}, sqlc.LlmProvider{}, errors.New("model id is required")
}
// Support both model UUID and model_id slug. UUID-formatted slugs still
// work because we fall back to GetByModelID when UUID lookup misses.
var model models.GetResponse
var err error
if _, parseErr := db.ParseUUID(modelRef); parseErr == nil {
model, err = r.modelsService.GetByID(ctx, modelRef)
if err == nil {
goto resolved
}
if !errors.Is(err, pgx.ErrNoRows) {
return models.GetResponse{}, sqlc.LlmProvider{}, err
}
}
model, err = r.modelsService.GetByModelID(ctx, modelRef)
if err != nil {
return models.GetResponse{}, sqlc.LlmProvider{}, err
}
resolved:
if model.Type != models.ModelTypeChat {
return models.GetResponse{}, sqlc.LlmProvider{}, errors.New("model is not a chat model")
}
prov, err := models.FetchProviderByID(ctx, r.queries, model.LlmProviderID)
if err != nil {
return models.GetResponse{}, sqlc.LlmProvider{}, err
}
return model, prov, nil
}
func matchesModelReference(model models.GetResponse, modelRef string) bool {
ref := strings.TrimSpace(modelRef)
if ref == "" {
return false
}
return model.ID == ref || model.ModelID == ref
}
func (r *Resolver) listCandidates(ctx context.Context, providerFilter string) ([]models.GetResponse, error) {
var all []models.GetResponse
var err error
if providerFilter != "" {
all, err = r.modelsService.ListByClientType(ctx, models.ClientType(providerFilter))
} else {
all, err = r.modelsService.ListByType(ctx, models.ModelTypeChat)
}
if err != nil {
return nil, err
}
filtered := make([]models.GetResponse, 0, len(all))
for _, m := range all {
if m.Type == models.ModelTypeChat {
filtered = append(filtered, m)
}
}
return filtered, nil
}
@@ -0,0 +1,69 @@
package flow
import (
"context"
"encoding/json"
"errors"
"log/slog"
"github.com/memohai/memoh/internal/db"
"github.com/memohai/memoh/internal/settings"
)
func (r *Resolver) loadBotSettings(ctx context.Context, botID string) (settings.Settings, error) {
if r.settingsService == nil {
return settings.Settings{}, errors.New("settings service not configured")
}
return r.settingsService.GetBot(ctx, botID)
}
func (r *Resolver) loadBotLoopDetectionEnabled(ctx context.Context, botID string) bool {
if r.queries == nil {
return false
}
botUUID, err := db.ParseUUID(botID)
if err != nil {
return false
}
row, err := r.queries.GetBotByID(ctx, botUUID)
if err != nil {
r.logger.Debug("failed to load bot metadata for loop detection",
slog.String("bot_id", botID),
slog.Any("error", err),
)
return false
}
return parseLoopDetectionEnabledFromMetadata(row.Metadata)
}
func parseLoopDetectionEnabledFromMetadata(payload []byte) bool {
if len(payload) == 0 {
return false
}
var metadata map[string]any
if err := json.Unmarshal(payload, &metadata); err != nil || metadata == nil {
return false
}
features, ok := metadata["features"].(map[string]any)
if !ok {
return false
}
loopDetection, ok := features["loop_detection"].(map[string]any)
if !ok {
return false
}
enabled, ok := loopDetection["enabled"].(bool)
if !ok {
return false
}
return enabled
}
func (r *Resolver) markInboxRead(ctx context.Context, botID string, ids []string) {
if r.inboxService == nil || len(ids) == 0 {
return
}
if err := r.inboxService.MarkRead(ctx, botID, ids); err != nil {
r.logger.Warn("failed to mark inbox items as read", slog.String("bot_id", botID), slog.Any("error", err))
}
}
@@ -1,32 +0,0 @@
package flow
import "testing"
func TestNormalizeGatewaySkill_Fallbacks(t *testing.T) {
got, ok := normalizeGatewaySkill(SkillEntry{
Name: " demo-skill ",
})
if !ok {
t.Fatal("expected valid skill")
}
if got.Name != "demo-skill" {
t.Fatalf("expected trimmed name demo-skill, got %q", got.Name)
}
if got.Description != "demo-skill" {
t.Fatalf("expected description fallback to name, got %q", got.Description)
}
if got.Content != "demo-skill" {
t.Fatalf("expected content fallback to description, got %q", got.Content)
}
}
func TestNormalizeGatewaySkill_RejectsEmptyName(t *testing.T) {
_, ok := normalizeGatewaySkill(SkillEntry{
Name: " ",
Description: "desc",
Content: "content",
})
if ok {
t.Fatal("expected invalid skill when name is empty")
}
}
@@ -0,0 +1,238 @@
package flow
import (
"bytes"
"context"
"encoding/json"
"log/slog"
"strings"
"github.com/memohai/memoh/internal/conversation"
messagepkg "github.com/memohai/memoh/internal/message"
)
func (r *Resolver) storeRound(ctx context.Context, req conversation.ChatRequest, messages []conversation.ModelMessage, usage json.RawMessage, usages []json.RawMessage, modelID string) error {
fullRound := make([]conversation.ModelMessage, 0, len(messages))
roundUsages := make([]json.RawMessage, 0, len(usages))
// When the user message was already persisted by a channel adapter, skip
// the duplicate from the round. Otherwise keep it so that user + assistant
// messages are written atomically (deferred persistence).
skipUserQuery := req.UserMessagePersisted
for i, m := range messages {
if skipUserQuery && m.Role == "user" && strings.TrimSpace(m.TextContent()) == strings.TrimSpace(req.Query) {
skipUserQuery = false // only skip the first matching user message
continue
}
fullRound = append(fullRound, m)
if i < len(usages) {
roundUsages = append(roundUsages, usages[i])
}
}
if len(fullRound) == 0 {
return nil
}
r.storeMessages(ctx, req, fullRound, usage, roundUsages, modelID)
go r.storeMemory(context.WithoutCancel(ctx), req, fullRound)
return nil
}
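The skip-once dedup described in the comment above (drop only the first user message whose text matches `req.Query` when a channel adapter already persisted it) can be sketched in isolation. `msg` and `dropFirstMatchingUser` are simplified stand-ins for the real types, not part of the codebase:

```go
package main

import (
	"fmt"
	"strings"
)

// msg is a simplified stand-in for conversation.ModelMessage.
type msg struct {
	Role, Text string
}

// dropFirstMatchingUser removes only the first user message whose text
// equals query, mirroring the skip-once logic in storeRound.
func dropFirstMatchingUser(messages []msg, query string, alreadyPersisted bool) []msg {
	out := make([]msg, 0, len(messages))
	skip := alreadyPersisted
	for _, m := range messages {
		if skip && m.Role == "user" && strings.TrimSpace(m.Text) == strings.TrimSpace(query) {
			skip = false // only the first match is skipped
			continue
		}
		out = append(out, m)
	}
	return out
}

func main() {
	round := []msg{
		{Role: "user", Text: "hi"},
		{Role: "assistant", Text: "hello"},
		{Role: "user", Text: "hi"}, // a later identical message is kept
	}
	kept := dropFirstMatchingUser(round, "hi", true)
	fmt.Println(len(kept)) // 2
}
```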
func (r *Resolver) storeMessages(ctx context.Context, req conversation.ChatRequest, messages []conversation.ModelMessage, usage json.RawMessage, usages []json.RawMessage, modelID string) {
if r.messageService == nil {
return
}
if strings.TrimSpace(req.BotID) == "" {
return
}
meta := buildRouteMetadata(req)
senderChannelIdentityID, senderUserID := r.resolvePersistSenderIDs(ctx, req)
// Determine the last assistant message index for outbound asset attachment.
lastAssistantIdx := -1
if req.OutboundAssetCollector != nil {
for i := len(messages) - 1; i >= 0; i-- {
if messages[i].Role == "assistant" {
lastAssistantIdx = i
break
}
}
}
var outboundAssets []messagepkg.AssetRef
if lastAssistantIdx >= 0 {
outboundAssets = outboundAssetRefsToMessageRefs(req.OutboundAssetCollector())
}
for i, msg := range messages {
content, err := json.Marshal(msg)
if err != nil {
r.logger.Warn("storeMessages: marshal failed", slog.Any("error", err))
continue
}
messageSenderChannelIdentityID := ""
messageSenderUserID := ""
externalMessageID := ""
sourceReplyToMessageID := ""
assets := []messagepkg.AssetRef(nil)
if msg.Role == "user" {
messageSenderChannelIdentityID = senderChannelIdentityID
messageSenderUserID = senderUserID
externalMessageID = req.ExternalMessageID
if strings.TrimSpace(msg.TextContent()) == strings.TrimSpace(req.Query) {
assets = chatAttachmentsToAssetRefs(req.Attachments)
}
} else if strings.TrimSpace(req.ExternalMessageID) != "" {
sourceReplyToMessageID = req.ExternalMessageID
}
if i == lastAssistantIdx && len(outboundAssets) > 0 {
assets = append(assets, outboundAssets...)
}
var msgUsage json.RawMessage
if i < len(usages) && len(usages[i]) > 0 && !isJSONNull(usages[i]) {
msgUsage = usages[i]
} else if i == len(messages)-1 && len(usage) > 0 {
msgUsage = usage
}
if _, err := r.messageService.Persist(ctx, messagepkg.PersistInput{
BotID: req.BotID,
RouteID: req.RouteID,
SenderChannelIdentityID: messageSenderChannelIdentityID,
SenderUserID: messageSenderUserID,
Platform: req.CurrentChannel,
ExternalMessageID: externalMessageID,
SourceReplyToMessageID: sourceReplyToMessageID,
Role: msg.Role,
Content: content,
Metadata: meta,
Usage: msgUsage,
Assets: assets,
ModelID: modelID,
}); err != nil {
r.logger.Warn("persist message failed", slog.Any("error", err))
}
}
}
func isJSONNull(data json.RawMessage) bool {
return len(data) == 0 || bytes.Equal(bytes.TrimSpace(data), []byte("null"))
}
// outboundAssetRefsToMessageRefs converts outbound asset refs from the streaming
// collector into message-level asset refs for persistence.
func outboundAssetRefsToMessageRefs(refs []conversation.OutboundAssetRef) []messagepkg.AssetRef {
if len(refs) == 0 {
return nil
}
result := make([]messagepkg.AssetRef, 0, len(refs))
for _, ref := range refs {
contentHash := strings.TrimSpace(ref.ContentHash)
if contentHash == "" {
continue
}
role := ref.Role
if strings.TrimSpace(role) == "" {
role = "attachment"
}
result = append(result, messagepkg.AssetRef{
ContentHash: contentHash,
Role: role,
Ordinal: ref.Ordinal,
Mime: ref.Mime,
SizeBytes: ref.SizeBytes,
StorageKey: ref.StorageKey,
Name: ref.Name,
Metadata: ref.Metadata,
})
}
return result
}
// chatAttachmentsToAssetRefs converts ChatAttachment slice to message AssetRef slice.
// Only attachments that carry a content_hash are included.
func chatAttachmentsToAssetRefs(attachments []conversation.ChatAttachment) []messagepkg.AssetRef {
if len(attachments) == 0 {
return nil
}
refs := make([]messagepkg.AssetRef, 0, len(attachments))
for i, att := range attachments {
contentHash := strings.TrimSpace(att.ContentHash)
if contentHash == "" {
continue
}
ref := messagepkg.AssetRef{
ContentHash: contentHash,
Role: "attachment",
Ordinal: i,
Mime: strings.TrimSpace(att.Mime),
SizeBytes: att.Size,
Name: strings.TrimSpace(att.Name),
Metadata: att.Metadata,
}
if att.Metadata != nil {
if sk, ok := att.Metadata["storage_key"].(string); ok {
ref.StorageKey = sk
}
}
refs = append(refs, ref)
}
return refs
}
func buildRouteMetadata(req conversation.ChatRequest) map[string]any {
if strings.TrimSpace(req.RouteID) == "" && strings.TrimSpace(req.CurrentChannel) == "" {
return nil
}
meta := map[string]any{}
if strings.TrimSpace(req.RouteID) != "" {
meta["route_id"] = req.RouteID
}
if strings.TrimSpace(req.CurrentChannel) != "" {
meta["platform"] = req.CurrentChannel
}
return meta
}
func (r *Resolver) resolvePersistSenderIDs(ctx context.Context, req conversation.ChatRequest) (string, string) {
channelIdentityID := strings.TrimSpace(req.SourceChannelIdentityID)
userID := strings.TrimSpace(req.UserID)
senderChannelIdentityID := ""
if r.isExistingChannelIdentityID(ctx, channelIdentityID) {
senderChannelIdentityID = channelIdentityID
}
senderUserID := ""
if r.isExistingUserID(ctx, userID) {
senderUserID = userID
}
if senderUserID == "" && senderChannelIdentityID != "" {
if linked := r.linkedUserIDFromChannelIdentity(ctx, senderChannelIdentityID); linked != "" {
senderUserID = linked
}
}
return senderChannelIdentityID, senderUserID
}
// LinkOutboundAssets links bot-generated assets to the latest assistant
// message for the given bot. Used by the WebSocket path where attachment
// ingestion happens after message persistence.
func (r *Resolver) LinkOutboundAssets(ctx context.Context, botID string, assets []messagepkg.AssetRef) {
if r.messageService == nil || len(assets) == 0 || strings.TrimSpace(botID) == "" {
return
}
// ListLatest returns messages in DESC order (newest first).
msgs, err := r.messageService.ListLatest(ctx, botID, 5)
if err != nil {
r.logger.Warn("LinkOutboundAssets: list latest failed", slog.Any("error", err))
return
}
for _, msg := range msgs {
if msg.Role == "assistant" {
if linkErr := r.messageService.LinkAssets(ctx, msg.ID, assets); linkErr != nil {
r.logger.Warn("LinkOutboundAssets: link failed", slog.Any("error", linkErr))
}
return
}
}
r.logger.Warn("LinkOutboundAssets: no assistant message found", slog.String("bot_id", botID))
}
@@ -0,0 +1,170 @@
package flow
import (
"context"
"encoding/json"
"fmt"
"log/slog"
sdk "github.com/memohai/twilight-ai/sdk"
agentpkg "github.com/memohai/memoh/internal/agent"
"github.com/memohai/memoh/internal/conversation"
)
// WSStreamEvent represents a raw JSON event forwarded from the agent.
type WSStreamEvent = json.RawMessage
// StreamChat runs a streaming chat via the internal agent.
func (r *Resolver) StreamChat(ctx context.Context, req conversation.ChatRequest) (<-chan conversation.StreamChunk, <-chan error) {
chunkCh := make(chan conversation.StreamChunk)
errCh := make(chan error, 1)
r.logger.Info("agent stream start",
slog.String("bot_id", req.BotID),
slog.String("chat_id", req.ChatID),
)
go func() {
defer close(chunkCh)
defer close(errCh)
streamReq := req
rc, err := r.resolve(ctx, streamReq)
if err != nil {
r.logger.Error("agent stream resolve failed",
slog.String("bot_id", streamReq.BotID),
slog.String("chat_id", streamReq.ChatID),
slog.Any("error", err),
)
errCh <- err
return
}
streamReq.Query = rc.query
cfg := rc.runConfig
cfg = r.prepareRunConfig(ctx, cfg)
eventCh := r.agent.Stream(ctx, cfg)
stored := false
for event := range eventCh {
if event.Type == agentpkg.EventError {
r.logger.Error("agent stream error",
slog.String("bot_id", streamReq.BotID),
slog.String("chat_id", streamReq.ChatID),
slog.String("model_id", rc.model.ID),
slog.String("error", event.Error),
)
}
data, err := json.Marshal(event)
if err != nil {
continue
}
if !stored && event.IsTerminal() && len(event.Messages) > 0 {
if _, storeErr := r.tryStoreStream(ctx, streamReq, data, rc.model.ID); storeErr != nil {
r.logger.Error("stream persist failed", slog.Any("error", storeErr))
} else {
stored = true
}
}
chunkCh <- conversation.StreamChunk(data)
}
r.markInboxRead(ctx, streamReq.BotID, rc.inboxItemIDs)
}()
return chunkCh, errCh
}
// StreamChatWS resolves the agent context and streams agent events.
// Events are sent on eventCh. When abortCh is closed, the context is cancelled.
func (r *Resolver) StreamChatWS(
ctx context.Context,
req conversation.ChatRequest,
eventCh chan<- WSStreamEvent,
abortCh <-chan struct{},
) error {
rc, err := r.resolve(ctx, req)
if err != nil {
return fmt.Errorf("resolve: %w", err)
}
req.Query = rc.query
streamCtx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
select {
case <-abortCh:
cancel()
case <-streamCtx.Done():
}
}()
cfg := rc.runConfig
cfg = r.prepareRunConfig(streamCtx, cfg)
agentEventCh := r.agent.Stream(streamCtx, cfg)
modelID := rc.model.ID
stored := false
for event := range agentEventCh {
if event.Type == agentpkg.EventError {
r.logger.Error("agent stream error",
slog.String("bot_id", req.BotID),
slog.String("chat_id", req.ChatID),
slog.String("model_id", modelID),
slog.String("error", event.Error),
)
}
data, err := json.Marshal(event)
if err != nil {
continue
}
if !stored && event.IsTerminal() && len(event.Messages) > 0 {
if _, storeErr := r.tryStoreStream(ctx, req, data, modelID); storeErr != nil {
r.logger.Error("ws persist failed", slog.Any("error", storeErr))
} else {
stored = true
}
}
select {
case eventCh <- json.RawMessage(data):
case <-ctx.Done():
return ctx.Err()
}
}
r.markInboxRead(ctx, req.BotID, rc.inboxItemIDs)
return nil
}
// tryStoreStream attempts to extract final messages from a stream event and persist them.
func (r *Resolver) tryStoreStream(ctx context.Context, req conversation.ChatRequest, data []byte, modelID string) (bool, error) {
var envelope struct {
Type string `json:"type"`
Messages json.RawMessage `json:"messages"`
Usage json.RawMessage `json:"usage,omitempty"`
Usages json.RawMessage `json:"usages,omitempty"`
}
if err := json.Unmarshal(data, &envelope); err != nil {
return false, nil
}
if len(envelope.Messages) == 0 {
return false, nil
}
var sdkMsgs []sdk.Message
if err := json.Unmarshal(envelope.Messages, &sdkMsgs); err != nil || len(sdkMsgs) == 0 {
return false, nil
}
outputMessages := sdkMessagesToModelMessages(sdkMsgs)
roundMessages := prependUserMessage(req.Query, outputMessages)
var usages []json.RawMessage
if len(envelope.Usages) > 0 {
_ = json.Unmarshal(envelope.Usages, &usages)
}
return true, r.storeRound(ctx, req, roundMessages, envelope.Usage, usages, modelID)
}
@@ -1,142 +0,0 @@
package flow
import (
"context"
"encoding/json"
"log/slog"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/memohai/memoh/internal/conversation"
messagepkg "github.com/memohai/memoh/internal/message"
)
type blockingMessageService struct {
persistCalled chan struct{}
persistContinue chan struct{}
}
func (s *blockingMessageService) Persist(_ context.Context, _ messagepkg.PersistInput) (messagepkg.Message, error) {
select {
case <-s.persistCalled:
default:
close(s.persistCalled)
}
<-s.persistContinue
return messagepkg.Message{}, nil
}
func (*blockingMessageService) List(_ context.Context, _ string) ([]messagepkg.Message, error) {
return nil, nil
}
func (*blockingMessageService) ListSince(_ context.Context, _ string, _ time.Time) ([]messagepkg.Message, error) {
return nil, nil
}
func (*blockingMessageService) ListActiveSince(_ context.Context, _ string, _ time.Time) ([]messagepkg.Message, error) {
return nil, nil
}
func (*blockingMessageService) ListLatest(_ context.Context, _ string, _ int32) ([]messagepkg.Message, error) {
return nil, nil
}
func (*blockingMessageService) ListBefore(_ context.Context, _ string, _ time.Time, _ int32) ([]messagepkg.Message, error) {
return nil, nil
}
func (*blockingMessageService) DeleteByBot(_ context.Context, _ string) error {
return nil
}
func TestStreamChat_PersistsFinalMessagesBeforeForwardingDoneEvent(t *testing.T) {
t.Parallel()
msgSvc := &blockingMessageService{
persistCalled: make(chan struct{}),
persistContinue: make(chan struct{}),
}
doneResp := gatewayResponse{
Messages: []conversation.ModelMessage{
{Role: "assistant", Content: conversation.NewTextContent("ok")},
},
}
doneData, err := json.Marshal(doneResp)
if err != nil {
t.Fatalf("marshal done response: %v", err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/chat/stream" {
http.NotFound(w, r)
return
}
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
_, _ = w.Write([]byte("event: done\n"))
_, _ = w.Write([]byte("data: "))
_, _ = w.Write(doneData)
_, _ = w.Write([]byte("\n\n"))
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
}))
t.Cleanup(srv.Close)
r := &Resolver{
messageService: msgSvc,
gatewayBaseURL: srv.URL,
logger: slog.New(slog.DiscardHandler),
streamingClient: srv.Client(),
httpClient: srv.Client(),
}
chunkCh := make(chan conversation.StreamChunk, 10)
req := conversation.ChatRequest{BotID: "bot-test", ChatID: "chat-test"}
payload := gatewayRequest{}
streamDone := make(chan error, 1)
go func() {
streamDone <- r.streamChat(context.Background(), payload, req, chunkCh, "model-test")
close(chunkCh)
}()
select {
case <-msgSvc.persistCalled:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for Persist to be called")
}
select {
case got := <-chunkCh:
t.Fatalf("done event forwarded before persistence finished: %s", string(got))
default:
}
close(msgSvc.persistContinue)
select {
case err := <-streamDone:
if err != nil {
t.Fatalf("streamChat returned error: %v", err)
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for streamChat to finish")
}
select {
case got := <-chunkCh:
if len(got) == 0 {
t.Fatal("expected forwarded done event data")
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for forwarded done event data")
}
}
@@ -7,170 +7,13 @@ import (
"encoding/json"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/memohai/memoh/internal/conversation"
"github.com/memohai/memoh/internal/models"
)
func TestPostTriggerSchedule_Endpoint(t *testing.T) {
var capturedPath string
var capturedBody []byte
var capturedAuth string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedPath = r.URL.Path
capturedAuth = r.Header.Get("Authorization")
capturedBody, _ = io.ReadAll(r.Body)
resp := gatewayResponse{
Messages: []conversation.ModelMessage{{Role: "assistant", Content: conversation.NewTextContent("ok")}},
}
w.Header().Set("Content-Type", "application/json")
require.NoError(t, json.NewEncoder(w).Encode(resp))
}))
defer srv.Close()
resolver := &Resolver{
gatewayBaseURL: srv.URL,
httpClient: &http.Client{Timeout: 5 * time.Second},
logger: slog.Default(),
}
maxCalls := 5
req := triggerScheduleRequest{
gatewayRequest: gatewayRequest{
Model: gatewayModelConfig{
ModelID: "gpt-4",
ClientType: "openai",
APIKey: "sk-test",
BaseURL: "https://api.openai.com",
},
ActiveContextTime: 1440,
Channels: []string{},
Messages: []conversation.ModelMessage{},
Skills: []string{},
Identity: gatewayIdentity{
BotID: "bot-123",
ChannelIdentityID: "owner-user-1",
DisplayName: "Scheduler",
},
Attachments: []any{},
},
Schedule: gatewaySchedule{
ID: "sched-1",
Name: "daily report",
Description: "generate daily report",
Pattern: "0 9 * * *",
MaxCalls: &maxCalls,
Command: "generate the daily report",
},
}
resp, err := resolver.postTriggerSchedule(context.Background(), req, "Bearer test-token")
if err != nil {
t.Fatalf("postTriggerSchedule returned error: %v", err)
}
if capturedPath != "/chat/trigger-schedule" {
t.Errorf("expected path /chat/trigger-schedule, got %s", capturedPath)
}
if capturedAuth != "Bearer test-token" {
t.Errorf("expected Authorization header 'Bearer test-token', got %s", capturedAuth)
}
if len(resp.Messages) != 1 {
t.Errorf("expected 1 message, got %d", len(resp.Messages))
}
var body map[string]any
if err := json.Unmarshal(capturedBody, &body); err != nil {
t.Fatalf("failed to parse captured body: %v", err)
}
schedule, ok := body["schedule"].(map[string]any)
if !ok {
t.Fatal("expected 'schedule' field in request body")
}
if schedule["id"] != "sched-1" {
t.Errorf("expected schedule.id=sched-1, got %v", schedule["id"])
}
if schedule["command"] != "generate the daily report" {
t.Errorf("expected schedule.command, got %v", schedule["command"])
}
if _, hasQuery := body["query"]; hasQuery {
t.Error("trigger-schedule request should not contain 'query' field")
}
}
func TestPostTriggerSchedule_NoAuth(t *testing.T) {
var capturedAuth string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedAuth = r.Header.Get("Authorization")
resp := gatewayResponse{Messages: []conversation.ModelMessage{}}
require.NoError(t, json.NewEncoder(w).Encode(resp))
}))
defer srv.Close()
resolver := &Resolver{
gatewayBaseURL: srv.URL,
httpClient: &http.Client{Timeout: 5 * time.Second},
logger: slog.Default(),
}
req := triggerScheduleRequest{
gatewayRequest: gatewayRequest{
Channels: []string{},
Messages: []conversation.ModelMessage{},
Skills: []string{},
Attachments: []any{},
},
Schedule: gatewaySchedule{ID: "s1", Command: "test"},
}
_, err := resolver.postTriggerSchedule(context.Background(), req, "")
if err != nil {
t.Fatalf("postTriggerSchedule returned error: %v", err)
}
if capturedAuth != "" {
t.Errorf("expected no Authorization header, got %s", capturedAuth)
}
}
func TestPostTriggerSchedule_GatewayError(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
_, err := w.Write([]byte("internal error"))
require.NoError(t, err)
}))
defer srv.Close()
resolver := &Resolver{
gatewayBaseURL: srv.URL,
httpClient: &http.Client{Timeout: 5 * time.Second},
logger: slog.Default(),
}
req := triggerScheduleRequest{
gatewayRequest: gatewayRequest{
Channels: []string{},
Messages: []conversation.ModelMessage{},
Skills: []string{},
Attachments: []any{},
},
Schedule: gatewaySchedule{ID: "s1", Command: "test"},
}
_, err := resolver.postTriggerSchedule(context.Background(), req, "Bearer tok")
if err == nil {
t.Fatal("expected error for 500 response")
}
}
type fakeGatewayAssetLoader struct {
openFn func(ctx context.Context, botID, contentHash string) (io.ReadCloser, string, error)
}
@@ -245,108 +88,6 @@ func TestPrepareGatewayAttachments_DataURLFromURLFieldIsNativeInline(t *testing.
}
}
func TestStreamChat_AllowsLargeSSEDataLines(t *testing.T) {
const overOldScannerLimit = 3 * 1024 * 1024
hugeDelta := strings.Repeat("a", overOldScannerLimit)
dataJSON, err := json.Marshal(map[string]any{
"type": "text_delta",
"delta": hugeDelta,
})
if err != nil {
t.Fatalf("failed to marshal test payload: %v", err)
}
dataStr := string(dataJSON)
parts := make([]string, 0, (len(dataStr)/8192)+1)
for i := 0; i < len(dataStr); i += 8192 {
end := i + 8192
if end > len(dataStr) {
end = len(dataStr)
}
parts = append(parts, dataStr[i:end])
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/chat/stream" {
w.WriteHeader(http.StatusNotFound)
return
}
w.Header().Set("Content-Type", "text/event-stream")
_, _ = io.WriteString(w, "event: message\n")
for _, part := range parts {
_, _ = io.WriteString(w, "data:")
_, _ = io.WriteString(w, part)
_, _ = io.WriteString(w, "\n")
}
_, _ = io.WriteString(w, "\n")
}))
defer srv.Close()
resolver := &Resolver{
gatewayBaseURL: srv.URL,
streamingClient: srv.Client(),
logger: slog.Default(),
}
chunkCh := make(chan conversation.StreamChunk, 1)
err = resolver.streamChat(
context.Background(),
gatewayRequest{},
conversation.ChatRequest{},
chunkCh,
"model-test",
)
if err != nil {
t.Fatalf("streamChat returned error: %v", err)
}
select {
case chunk := <-chunkCh:
if !bytes.Equal(chunk, dataJSON) {
t.Fatalf("unexpected reconstructed payload: got prefix %q", string(chunk[:minInt(len(chunk), 80)]))
}
default:
t.Fatalf("expected at least one streamed chunk")
}
}
func TestStreamChat_RejectsOverLimitSSELine(t *testing.T) {
tooLong := strings.Repeat("x", gatewaySSEMaxLineBytes+10)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/chat/stream" {
w.WriteHeader(http.StatusNotFound)
return
}
w.Header().Set("Content-Type", "text/event-stream")
_, _ = io.WriteString(w, "event: message\n")
_, _ = io.WriteString(w, "data:")
_, _ = io.WriteString(w, tooLong)
_, _ = io.WriteString(w, "\n\n")
}))
defer srv.Close()
resolver := &Resolver{
gatewayBaseURL: srv.URL,
streamingClient: srv.Client(),
logger: slog.Default(),
}
chunkCh := make(chan conversation.StreamChunk, 1)
err := resolver.streamChat(context.Background(), gatewayRequest{}, conversation.ChatRequest{}, chunkCh, "model-test")
if err == nil {
t.Fatalf("expected streamChat to error on oversized SSE line")
}
if !strings.Contains(err.Error(), "sse line too long") {
t.Fatalf("expected line-too-long error, got: %v", err)
}
}
func minInt(a, b int) int {
if a < b {
return a
}
return b
}
func TestPrepareGatewayAttachments_PublicURLFromURLFieldIsNativePublic(t *testing.T) {
resolver := &Resolver{logger: slog.Default()}
req := conversation.ChatRequest{
@@ -0,0 +1,126 @@
package flow
import (
"context"
"encoding/json"
"errors"
"strings"
sdk "github.com/memohai/twilight-ai/sdk"
agentpkg "github.com/memohai/memoh/internal/agent"
"github.com/memohai/memoh/internal/conversation"
"github.com/memohai/memoh/internal/heartbeat"
"github.com/memohai/memoh/internal/schedule"
)
// TriggerSchedule executes a scheduled command via the internal agent.
func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload schedule.TriggerPayload, token string) error {
if strings.TrimSpace(botID) == "" {
return errors.New("bot id is required")
}
if strings.TrimSpace(payload.Command) == "" {
return errors.New("schedule command is required")
}
req := conversation.ChatRequest{
BotID: botID,
ChatID: botID,
Query: payload.Command,
UserID: payload.OwnerUserID,
Token: token,
}
rc, err := r.resolve(ctx, req)
if err != nil {
return err
}
cfg := rc.runConfig
cfg.Identity.ChannelIdentityID = strings.TrimSpace(payload.OwnerUserID)
cfg.Identity.DisplayName = "Scheduler"
schedulePrompt := agentpkg.GenerateSchedulePrompt(agentpkg.Schedule{
ID: payload.ID,
Name: payload.Name,
Description: payload.Description,
Pattern: payload.Pattern,
MaxCalls: payload.MaxCalls,
Command: payload.Command,
})
cfg.Messages = append(cfg.Messages, sdk.UserMessage(schedulePrompt))
cfg = r.prepareRunConfig(ctx, cfg)
result, err := r.agent.Generate(ctx, cfg)
if err != nil {
return err
}
outputMessages := sdkMessagesToModelMessages(result.Messages)
roundMessages := prependUserMessage(req.Query, outputMessages)
usageJSON, _ := json.Marshal(result.Usage)
return r.storeRound(ctx, req, roundMessages, usageJSON, nil, rc.model.ID)
}
// TriggerHeartbeat executes a heartbeat check via the internal agent.
func (r *Resolver) TriggerHeartbeat(ctx context.Context, botID string, payload heartbeat.TriggerPayload, token string) (heartbeat.TriggerResult, error) {
if strings.TrimSpace(botID) == "" {
return heartbeat.TriggerResult{}, errors.New("bot id is required")
}
var heartbeatModel string
if botSettings, err := r.loadBotSettings(ctx, botID); err == nil {
heartbeatModel = strings.TrimSpace(botSettings.HeartbeatModelID)
}
req := conversation.ChatRequest{
BotID: botID,
ChatID: botID,
Query: "heartbeat",
UserID: payload.OwnerUserID,
Token: token,
Model: heartbeatModel,
}
rc, err := r.resolve(ctx, req)
if err != nil {
return heartbeat.TriggerResult{}, err
}
cfg := rc.runConfig
cfg.Identity.ChannelIdentityID = strings.TrimSpace(payload.OwnerUserID)
cfg.Identity.DisplayName = "Heartbeat"
var checklist string
if r.agent != nil {
fs := agentpkg.NewFSClient(nil, botID)
checklist = fs.ReadTextSafe(ctx, "/data/HEARTBEAT.md")
}
heartbeatPrompt := agentpkg.GenerateHeartbeatPrompt(payload.Interval, checklist)
cfg.Messages = append(cfg.Messages, sdk.UserMessage(heartbeatPrompt))
cfg = r.prepareRunConfig(ctx, cfg)
result, err := r.agent.Generate(ctx, cfg)
if err != nil {
return heartbeat.TriggerResult{}, err
}
status := "alert"
text := strings.TrimSpace(result.Text)
if isHeartbeatOK(text) {
status = "ok"
}
usageJSON, _ := json.Marshal(result.Usage)
return heartbeat.TriggerResult{
Status: status,
Text: text,
Usage: usageJSON,
UsageBytes: usageJSON,
ModelID: rc.model.ID,
}, nil
}
func isHeartbeatOK(text string) bool {
t := strings.TrimSpace(text)
return strings.HasPrefix(t, "HEARTBEAT_OK") || strings.HasSuffix(t, "HEARTBEAT_OK")
}
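The sentinel check can be exercised standalone; `okStatus` is an illustrative mirror of `isHeartbeatOK` folded into the status mapping from `TriggerHeartbeat`, not the production function:

```go
package main

import (
	"fmt"
	"strings"
)

// okStatus mirrors isHeartbeatOK: the HEARTBEAT_OK sentinel may lead or
// trail the model's text; anything else is treated as an alert.
func okStatus(text string) string {
	t := strings.TrimSpace(text)
	if strings.HasPrefix(t, "HEARTBEAT_OK") || strings.HasSuffix(t, "HEARTBEAT_OK") {
		return "ok"
	}
	return "alert"
}

func main() {
	fmt.Println(okStatus("HEARTBEAT_OK nothing to report")) // ok
	fmt.Println(okStatus("disk almost full"))               // alert
}
```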
@@ -0,0 +1,200 @@
package flow
import (
"encoding/base64"
"encoding/json"
"errors"
"sort"
"strconv"
"strings"
"github.com/jackc/pgx/v5/pgtype"
"github.com/memohai/memoh/internal/conversation"
"github.com/memohai/memoh/internal/db"
)
func sanitizeMessages(messages []conversation.ModelMessage) []conversation.ModelMessage {
cleaned := make([]conversation.ModelMessage, 0, len(messages))
for _, msg := range messages {
if normalized, ok := normalizeImagePartsToDataURL(msg); ok {
msg = normalized
}
if strings.TrimSpace(msg.Role) == "" {
continue
}
if !msg.HasContent() && strings.TrimSpace(msg.ToolCallID) == "" {
continue
}
cleaned = append(cleaned, msg)
}
return cleaned
}
func normalizeImagePartsToDataURL(msg conversation.ModelMessage) (conversation.ModelMessage, bool) {
if len(msg.Content) == 0 {
return msg, false
}
var parts []map[string]json.RawMessage
if err := json.Unmarshal(msg.Content, &parts); err != nil || len(parts) == 0 {
return msg, false
}
changed := false
for i := range parts {
partTypeRaw, ok := parts[i]["type"]
if !ok {
continue
}
var partType string
if err := json.Unmarshal(partTypeRaw, &partType); err != nil || !strings.EqualFold(partType, "image") {
continue
}
imageRaw, ok := parts[i]["image"]
if !ok || len(imageRaw) == 0 {
continue
}
// Skip parts whose image is already a string (URL or data URL).
var tmp string
if json.Unmarshal(imageRaw, &tmp) == nil {
continue
}
var payload []byte
if b, ok := decodeIndexedByteObject(imageRaw); ok {
payload = b
} else if b, ok := decodeByteArray(imageRaw); ok {
payload = b
} else {
continue
}
if len(payload) == 0 {
continue
}
// Fall back to a generic media type when the part does not declare one.
mediaType := "application/octet-stream"
if mediaTypeRaw, ok := parts[i]["mediaType"]; ok {
var mt string
if err := json.Unmarshal(mediaTypeRaw, &mt); err == nil && strings.TrimSpace(mt) != "" {
mediaType = strings.TrimSpace(mt)
}
}
dataURL := "data:" + mediaType + ";base64," + base64.StdEncoding.EncodeToString(payload)
rebuilt, err := json.Marshal(dataURL)
if err != nil {
continue
}
parts[i]["image"] = rebuilt
changed = true
}
if !changed {
return msg, false
}
rebuiltContent, err := json.Marshal(parts)
if err != nil {
return msg, false
}
msg.Content = rebuiltContent
return msg, true
}
func decodeByteArray(raw json.RawMessage) ([]byte, bool) {
var arr []int
if err := json.Unmarshal(raw, &arr); err != nil {
return nil, false
}
if len(arr) == 0 {
return nil, false
}
out := make([]byte, len(arr))
for i, v := range arr {
if v < 0 || v > 255 {
return nil, false
}
out[i] = byte(v)
}
return out, true
}
func decodeIndexedByteObject(raw json.RawMessage) ([]byte, bool) {
var obj map[string]json.RawMessage
if err := json.Unmarshal(raw, &obj); err != nil || len(obj) == 0 {
return nil, false
}
type indexedByte struct {
idx int
val byte
}
items := make([]indexedByte, 0, len(obj))
for k, vRaw := range obj {
idx, err := strconv.Atoi(k)
if err != nil || idx < 0 {
return nil, false
}
var val int
if err := json.Unmarshal(vRaw, &val); err != nil || val < 0 || val > 255 {
return nil, false
}
items = append(items, indexedByte{idx: idx, val: byte(val)})
}
sort.Slice(items, func(i, j int) bool { return items[i].idx < items[j].idx })
for i := range items {
if items[i].idx != i {
return nil, false
}
}
out := make([]byte, len(items))
for i := range items {
out[i] = items[i].val
}
return out, true
}
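`decodeIndexedByteObject` exists because a JS `Uint8Array` often serializes to JSON as an object keyed by stringified indices rather than an array. A simplified re-sketch of the same dense-index decoding; `decodeIndexed` is illustrative, not the production function:

```go
package main

import (
	"encoding/json"
	"fmt"
	"sort"
	"strconv"
)

// decodeIndexed converts a JSON object with stringified indices
// (e.g. {"0":72,"1":105}) back into a byte slice. Indices must be
// dense (0,1,2,...) and values must fit in a byte.
func decodeIndexed(raw []byte) ([]byte, bool) {
	var obj map[string]int
	if err := json.Unmarshal(raw, &obj); err != nil || len(obj) == 0 {
		return nil, false
	}
	keys := make([]int, 0, len(obj))
	for k := range obj {
		idx, err := strconv.Atoi(k)
		if err != nil || idx < 0 {
			return nil, false
		}
		keys = append(keys, idx)
	}
	sort.Ints(keys)
	out := make([]byte, 0, len(keys))
	for i, idx := range keys {
		if idx != i { // reject sparse or duplicated indices
			return nil, false
		}
		v := obj[strconv.Itoa(idx)]
		if v < 0 || v > 255 {
			return nil, false
		}
		out = append(out, byte(v))
	}
	return out, true
}

func main() {
	b, ok := decodeIndexed([]byte(`{"0":72,"1":105}`))
	fmt.Println(ok, string(b)) // true Hi
}
```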
func dedup(items []string) []string {
seen := make(map[string]struct{}, len(items))
result := make([]string, 0, len(items))
for _, s := range items {
trimmed := strings.TrimSpace(s)
if trimmed == "" {
continue
}
if _, ok := seen[trimmed]; ok {
continue
}
seen[trimmed] = struct{}{}
result = append(result, trimmed)
}
return result
}
func coalescePositiveInt(values ...int) int {
for _, v := range values {
if v > 0 {
return v
}
}
return defaultMaxContextMinutes
}
func nonNilStrings(s []string) []string {
if s == nil {
return []string{}
}
return s
}
func nonNilModelMessages(m []conversation.ModelMessage) []conversation.ModelMessage {
if m == nil {
return []conversation.ModelMessage{}
}
return m
}
func parseResolverUUID(id string) (pgtype.UUID, error) {
if strings.TrimSpace(id) == "" {
return pgtype.UUID{}, errors.New("empty id")
}
return db.ParseUUID(id)
}
@@ -0,0 +1,123 @@
package flow
import (
"strings"
"time"
)
// UserMessageMeta holds the structured metadata attached to every user
// message. It is the single source of truth shared by the YAML header
// (sent to the LLM) and the inbox content JSONB.
type UserMessageMeta struct {
MessageID string `json:"message-id,omitempty"`
ChannelIdentityID string `json:"channel-identity-id"`
DisplayName string `json:"display-name"`
Channel string `json:"channel"`
ConversationType string `json:"conversation-type"`
ConversationName string `json:"conversation-name,omitempty"`
Time string `json:"time"`
AttachmentPaths []string `json:"attachments"`
}
// BuildUserMessageMeta constructs a UserMessageMeta from the inbound
// parameters. Both FormatUserHeader and inbox content use this.
func BuildUserMessageMeta(messageID, channelIdentityID, displayName, channel, conversationType, conversationName string, attachmentPaths []string) UserMessageMeta {
if attachmentPaths == nil {
attachmentPaths = []string{}
}
return UserMessageMeta{
MessageID: messageID,
ChannelIdentityID: channelIdentityID,
DisplayName: displayName,
Channel: channel,
ConversationType: conversationType,
ConversationName: conversationName,
Time: time.Now().UTC().Format(time.RFC3339),
AttachmentPaths: attachmentPaths,
}
}
// ToMap returns the metadata as a map with the same keys used in the YAML
// header, suitable for storing as inbox content JSONB.
func (m UserMessageMeta) ToMap() map[string]any {
result := map[string]any{
"channel-identity-id": m.ChannelIdentityID,
"display-name": m.DisplayName,
"channel": m.Channel,
"conversation-type": m.ConversationType,
"time": m.Time,
"attachments": m.AttachmentPaths,
}
if m.MessageID != "" {
result["message-id"] = m.MessageID
}
if m.ConversationName != "" {
result["conversation-name"] = m.ConversationName
}
return result
}
// FormatUserHeader wraps a user query with YAML front-matter metadata so
// the LLM sees structured context (sender, channel, time, attachments)
// alongside the raw message. This is the single source of truth for
// user-message formatting; no other layer should add its own header.
func FormatUserHeader(messageID, channelIdentityID, displayName, channel, conversationType, conversationName string, attachmentPaths []string, query string) string {
meta := BuildUserMessageMeta(messageID, channelIdentityID, displayName, channel, conversationType, conversationName, attachmentPaths)
return FormatUserHeaderFromMeta(meta, query)
}
// FormatUserHeaderFromMeta formats a pre-built UserMessageMeta into the
// YAML front-matter string sent to the LLM.
func FormatUserHeaderFromMeta(meta UserMessageMeta, query string) string {
var sb strings.Builder
sb.WriteString("---\n")
if meta.MessageID != "" {
writeYAMLString(&sb, "message-id", meta.MessageID)
}
writeYAMLString(&sb, "channel-identity-id", meta.ChannelIdentityID)
writeYAMLString(&sb, "display-name", meta.DisplayName)
writeYAMLString(&sb, "channel", meta.Channel)
writeYAMLString(&sb, "conversation-type", meta.ConversationType)
if meta.ConversationName != "" {
writeYAMLString(&sb, "conversation-name", meta.ConversationName)
}
writeYAMLString(&sb, "time", meta.Time)
if len(meta.AttachmentPaths) > 0 {
sb.WriteString("attachments:\n")
for _, p := range meta.AttachmentPaths {
sb.WriteString(" - ")
sb.WriteString(p)
sb.WriteByte('\n')
}
} else {
sb.WriteString("attachments: []\n")
}
sb.WriteString("---\n")
sb.WriteString(query)
return sb.String()
}
func writeYAMLString(sb *strings.Builder, key, value string) {
sb.WriteString(key)
sb.WriteString(": ")
if value == "" || needsYAMLQuote(value) {
sb.WriteByte('"')
sb.WriteString(strings.ReplaceAll(value, `"`, `\"`))
sb.WriteByte('"')
} else {
sb.WriteString(value)
}
sb.WriteByte('\n')
}
func needsYAMLQuote(s string) bool {
if s == "" {
return true
}
for _, c := range s {
if c == ':' || c == '#' || c == '"' || c == '\'' || c == '{' || c == '}' || c == '[' || c == ']' || c == ',' || c == '\n' {
return true
}
}
return false
}