mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
feat: add media asset system, channel lifecycle refactor, and chat attachments (#54)
This commit is contained in:
@@ -4,7 +4,9 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -15,6 +17,7 @@ import (
|
||||
"github.com/memohai/memoh/internal/channel/route"
|
||||
"github.com/memohai/memoh/internal/conversation"
|
||||
"github.com/memohai/memoh/internal/conversation/flow"
|
||||
"github.com/memohai/memoh/internal/media"
|
||||
messagepkg "github.com/memohai/memoh/internal/message"
|
||||
)
|
||||
|
||||
@@ -33,11 +36,19 @@ type RouteResolver interface {
|
||||
ResolveConversation(ctx context.Context, input route.ResolveInput) (route.ResolveConversationResult, error)
|
||||
}
|
||||
|
||||
type mediaIngestor interface {
|
||||
Ingest(ctx context.Context, input media.IngestInput) (media.Asset, error)
|
||||
// AccessPath returns a consumer-accessible reference for a persisted asset.
|
||||
// The format depends on the storage backend (e.g. container path, URL).
|
||||
AccessPath(asset media.Asset) string
|
||||
}
|
||||
|
||||
// ChannelInboundProcessor routes channel inbound messages to the chat gateway.
|
||||
type ChannelInboundProcessor struct {
|
||||
runner flow.Runner
|
||||
routeResolver RouteResolver
|
||||
message messagepkg.Writer
|
||||
mediaService mediaIngestor
|
||||
registry *channel.Registry
|
||||
logger *slog.Logger
|
||||
jwtSecret string
|
||||
@@ -87,6 +98,14 @@ func (p *ChannelInboundProcessor) IdentityMiddleware() channel.Middleware {
|
||||
return p.identity.Middleware()
|
||||
}
|
||||
|
||||
// SetMediaService configures media ingestion support for inbound attachments.
|
||||
func (p *ChannelInboundProcessor) SetMediaService(mediaService mediaIngestor) {
|
||||
if p == nil {
|
||||
return
|
||||
}
|
||||
p.mediaService = mediaService
|
||||
}
|
||||
|
||||
// HandleInbound processes an inbound channel message through identity resolution and chat gateway.
|
||||
func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, sender channel.StreamReplySender) error {
|
||||
if p.runner == nil {
|
||||
@@ -96,7 +115,20 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel
|
||||
return fmt.Errorf("reply sender not configured")
|
||||
}
|
||||
text := buildInboundQuery(msg.Message)
|
||||
if strings.TrimSpace(text) == "" {
|
||||
if p.logger != nil {
|
||||
p.logger.Debug("inbound handle start",
|
||||
slog.String("channel", msg.Channel.String()),
|
||||
slog.String("message_id", strings.TrimSpace(msg.Message.ID)),
|
||||
slog.String("query", strings.TrimSpace(text)),
|
||||
slog.Int("attachments", len(msg.Message.Attachments)),
|
||||
slog.String("conversation_type", strings.TrimSpace(msg.Conversation.Type)),
|
||||
slog.String("conversation_id", strings.TrimSpace(msg.Conversation.ID)),
|
||||
)
|
||||
}
|
||||
if strings.TrimSpace(text) == "" && len(msg.Message.Attachments) == 0 {
|
||||
if p.logger != nil {
|
||||
p.logger.Debug("inbound dropped empty", slog.String("channel", msg.Channel.String()))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
state, err := p.requireIdentity(ctx, cfg, msg)
|
||||
@@ -123,6 +155,8 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel
|
||||
}
|
||||
|
||||
identity := state.Identity
|
||||
resolvedAttachments := p.ingestInboundAttachments(ctx, cfg, msg, strings.TrimSpace(identity.BotID), msg.Message.Attachments)
|
||||
attachments := mapChannelAttachments(resolvedAttachments)
|
||||
|
||||
// Resolve or create the route via channel_routes.
|
||||
if p.routeResolver == nil {
|
||||
@@ -157,12 +191,14 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel
|
||||
slog.Bool("is_mentioned", metadataBool(msg.Metadata, "is_mentioned")),
|
||||
slog.Bool("is_reply_to_bot", metadataBool(msg.Metadata, "is_reply_to_bot")),
|
||||
slog.String("conversation_type", strings.TrimSpace(msg.Conversation.Type)),
|
||||
slog.String("query", strings.TrimSpace(text)),
|
||||
slog.Int("attachments", len(attachments)),
|
||||
)
|
||||
}
|
||||
p.persistInboundUser(ctx, resolved.RouteID, identity, msg, text, "passive_sync")
|
||||
p.persistInboundUser(ctx, resolved.RouteID, identity, msg, text, attachments, "passive_sync")
|
||||
return nil
|
||||
}
|
||||
userMessagePersisted := p.persistInboundUser(ctx, resolved.RouteID, identity, msg, text, "active_chat")
|
||||
userMessagePersisted := p.persistInboundUser(ctx, resolved.RouteID, identity, msg, text, attachments, "active_chat")
|
||||
|
||||
// Issue chat token for reply routing.
|
||||
chatToken := ""
|
||||
@@ -284,6 +320,7 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel
|
||||
CurrentChannel: msg.Channel.String(),
|
||||
Channels: []string{msg.Channel.String()},
|
||||
UserMessagePersisted: userMessagePersisted,
|
||||
Attachments: attachments,
|
||||
})
|
||||
|
||||
var (
|
||||
@@ -507,7 +544,15 @@ func metadataBool(metadata map[string]any, key string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ChannelInboundProcessor) persistInboundUser(ctx context.Context, routeID string, identity InboundIdentity, msg channel.InboundMessage, query string, triggerMode string) bool {
|
||||
func (p *ChannelInboundProcessor) persistInboundUser(
|
||||
ctx context.Context,
|
||||
routeID string,
|
||||
identity InboundIdentity,
|
||||
msg channel.InboundMessage,
|
||||
query string,
|
||||
attachments []conversation.ChatAttachment,
|
||||
triggerMode string,
|
||||
) bool {
|
||||
if p.message == nil {
|
||||
return false
|
||||
}
|
||||
@@ -540,6 +585,7 @@ func (p *ChannelInboundProcessor) persistInboundUser(ctx context.Context, routeI
|
||||
Role: "user",
|
||||
Content: payload,
|
||||
Metadata: meta,
|
||||
Assets: chatAttachmentsToAssetRefs(attachments),
|
||||
}); err != nil && p.logger != nil {
|
||||
p.logger.Warn("persist inbound user message failed", slog.Any("error", err))
|
||||
return false
|
||||
@@ -651,8 +697,15 @@ type gatewayStreamEnvelope struct {
|
||||
Delta string `json:"delta"`
|
||||
Error string `json:"error"`
|
||||
Message string `json:"message"`
|
||||
Image string `json:"image"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
Messages []conversation.ModelMessage `json:"messages"`
|
||||
|
||||
ToolName string `json:"toolName"`
|
||||
ToolCallID string `json:"toolCallId"`
|
||||
Input json.RawMessage `json:"input"`
|
||||
Result json.RawMessage `json:"result"`
|
||||
Attachments json.RawMessage `json:"attachments"`
|
||||
}
|
||||
|
||||
type gatewayStreamDoneData struct {
|
||||
@@ -685,6 +738,7 @@ func mapStreamChunkToChannelEvents(chunk conversation.StreamChunk) ([]channel.St
|
||||
{
|
||||
Type: channel.StreamEventDelta,
|
||||
Delta: envelope.Delta,
|
||||
Phase: channel.StreamPhaseText,
|
||||
},
|
||||
}, finalMessages, nil
|
||||
case "reasoning_delta":
|
||||
@@ -695,11 +749,95 @@ func mapStreamChunkToChannelEvents(chunk conversation.StreamChunk) ([]channel.St
|
||||
{
|
||||
Type: channel.StreamEventDelta,
|
||||
Delta: envelope.Delta,
|
||||
Metadata: map[string]any{
|
||||
"phase": "reasoning",
|
||||
Phase: channel.StreamPhaseReasoning,
|
||||
},
|
||||
}, finalMessages, nil
|
||||
case "tool_call_start":
|
||||
return []channel.StreamEvent{
|
||||
{
|
||||
Type: channel.StreamEventToolCallStart,
|
||||
ToolCall: &channel.StreamToolCall{
|
||||
Name: strings.TrimSpace(envelope.ToolName),
|
||||
CallID: strings.TrimSpace(envelope.ToolCallID),
|
||||
Input: parseRawJSON(envelope.Input),
|
||||
},
|
||||
},
|
||||
}, finalMessages, nil
|
||||
case "tool_call_end":
|
||||
return []channel.StreamEvent{
|
||||
{
|
||||
Type: channel.StreamEventToolCallEnd,
|
||||
ToolCall: &channel.StreamToolCall{
|
||||
Name: strings.TrimSpace(envelope.ToolName),
|
||||
CallID: strings.TrimSpace(envelope.ToolCallID),
|
||||
Input: parseRawJSON(envelope.Input),
|
||||
Result: parseRawJSON(envelope.Result),
|
||||
},
|
||||
},
|
||||
}, finalMessages, nil
|
||||
case "reasoning_start":
|
||||
return []channel.StreamEvent{
|
||||
{Type: channel.StreamEventPhaseStart, Phase: channel.StreamPhaseReasoning},
|
||||
}, finalMessages, nil
|
||||
case "reasoning_end":
|
||||
return []channel.StreamEvent{
|
||||
{Type: channel.StreamEventPhaseEnd, Phase: channel.StreamPhaseReasoning},
|
||||
}, finalMessages, nil
|
||||
case "text_start":
|
||||
return []channel.StreamEvent{
|
||||
{Type: channel.StreamEventPhaseStart, Phase: channel.StreamPhaseText},
|
||||
}, finalMessages, nil
|
||||
case "text_end":
|
||||
return []channel.StreamEvent{
|
||||
{Type: channel.StreamEventPhaseEnd, Phase: channel.StreamPhaseText},
|
||||
}, finalMessages, nil
|
||||
case "attachment_delta":
|
||||
attachments := parseAttachmentDelta(envelope.Attachments)
|
||||
if len(attachments) == 0 {
|
||||
return nil, finalMessages, nil
|
||||
}
|
||||
return []channel.StreamEvent{
|
||||
{Type: channel.StreamEventAttachment, Attachments: attachments},
|
||||
}, finalMessages, nil
|
||||
case "agent_start":
|
||||
return []channel.StreamEvent{
|
||||
{
|
||||
Type: channel.StreamEventAgentStart,
|
||||
Metadata: map[string]any{
|
||||
"input": parseRawJSON(envelope.Input),
|
||||
"data": parseRawJSON(envelope.Data),
|
||||
},
|
||||
},
|
||||
}, finalMessages, nil
|
||||
case "agent_end":
|
||||
return []channel.StreamEvent{
|
||||
{
|
||||
Type: channel.StreamEventAgentEnd,
|
||||
Metadata: map[string]any{
|
||||
"result": parseRawJSON(envelope.Result),
|
||||
"data": parseRawJSON(envelope.Data),
|
||||
},
|
||||
},
|
||||
}, finalMessages, nil
|
||||
case "processing_started":
|
||||
return []channel.StreamEvent{
|
||||
{Type: channel.StreamEventProcessingStarted},
|
||||
}, finalMessages, nil
|
||||
case "processing_completed":
|
||||
return []channel.StreamEvent{
|
||||
{Type: channel.StreamEventProcessingCompleted},
|
||||
}, finalMessages, nil
|
||||
case "processing_failed":
|
||||
streamError := strings.TrimSpace(envelope.Error)
|
||||
if streamError == "" {
|
||||
streamError = strings.TrimSpace(envelope.Message)
|
||||
}
|
||||
return []channel.StreamEvent{
|
||||
{
|
||||
Type: channel.StreamEventProcessingFailed,
|
||||
Error: streamError,
|
||||
},
|
||||
}, finalMessages, nil
|
||||
case "error":
|
||||
streamError := strings.TrimSpace(envelope.Error)
|
||||
if streamError == "" {
|
||||
@@ -720,25 +858,7 @@ func mapStreamChunkToChannelEvents(chunk conversation.StreamChunk) ([]channel.St
|
||||
}
|
||||
|
||||
func buildInboundQuery(message channel.Message) string {
|
||||
text := strings.TrimSpace(message.PlainText())
|
||||
if len(message.Attachments) == 0 {
|
||||
return text
|
||||
}
|
||||
lines := make([]string, 0, len(message.Attachments)+1)
|
||||
if text != "" {
|
||||
lines = append(lines, text)
|
||||
}
|
||||
for _, att := range message.Attachments {
|
||||
label := strings.TrimSpace(att.Name)
|
||||
if label == "" {
|
||||
label = strings.TrimSpace(att.Reference())
|
||||
}
|
||||
if label == "" {
|
||||
label = "unknown"
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf("[attachment:%s] %s", att.Type, label))
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
return strings.TrimSpace(message.PlainText())
|
||||
}
|
||||
|
||||
func normalizeContentPartType(raw string) channel.MessagePartType {
|
||||
@@ -1043,3 +1163,298 @@ func (p *ChannelInboundProcessor) logProcessingStatusError(
|
||||
slog.Any("error", err),
|
||||
)
|
||||
}
|
||||
|
||||
// parseRawJSON converts raw JSON bytes to a typed value for StreamToolCall fields.
|
||||
func parseRawJSON(raw json.RawMessage) any {
|
||||
if len(raw) == 0 {
|
||||
return nil
|
||||
}
|
||||
var v any
|
||||
if err := json.Unmarshal(raw, &v); err != nil {
|
||||
return string(raw)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// mapChannelAttachments converts channel.Attachment slice to conversation.ChatAttachment slice.
|
||||
// When an attachment has been ingested (AssetID is set), the URL field contains
|
||||
// the container-internal path; it is mapped to Path for downstream consumers.
|
||||
func mapChannelAttachments(attachments []channel.Attachment) []conversation.ChatAttachment {
|
||||
if len(attachments) == 0 {
|
||||
return nil
|
||||
}
|
||||
result := make([]conversation.ChatAttachment, 0, len(attachments))
|
||||
for _, att := range attachments {
|
||||
ca := conversation.ChatAttachment{
|
||||
Type: string(att.Type),
|
||||
PlatformKey: att.PlatformKey,
|
||||
AssetID: att.AssetID,
|
||||
Name: att.Name,
|
||||
Mime: att.Mime,
|
||||
Size: att.Size,
|
||||
Metadata: att.Metadata,
|
||||
}
|
||||
if strings.TrimSpace(att.AssetID) != "" {
|
||||
ca.Path = att.URL
|
||||
ca.Base64 = att.Base64
|
||||
} else {
|
||||
ca.URL = att.URL
|
||||
}
|
||||
result = append(result, ca)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (p *ChannelInboundProcessor) ingestInboundAttachments(
|
||||
ctx context.Context,
|
||||
cfg channel.ChannelConfig,
|
||||
msg channel.InboundMessage,
|
||||
botID string,
|
||||
attachments []channel.Attachment,
|
||||
) []channel.Attachment {
|
||||
if len(attachments) == 0 || p == nil || p.mediaService == nil || strings.TrimSpace(botID) == "" {
|
||||
return attachments
|
||||
}
|
||||
result := make([]channel.Attachment, 0, len(attachments))
|
||||
for _, att := range attachments {
|
||||
item := att
|
||||
if strings.TrimSpace(item.AssetID) != "" {
|
||||
result = append(result, item)
|
||||
continue
|
||||
}
|
||||
payload, err := p.loadInboundAttachmentPayload(ctx, cfg, msg, item)
|
||||
if err != nil {
|
||||
if p.logger != nil {
|
||||
p.logger.Warn(
|
||||
"inbound attachment ingest skipped",
|
||||
slog.Any("error", err),
|
||||
slog.String("attachment_type", strings.TrimSpace(string(item.Type))),
|
||||
slog.String("attachment_url", strings.TrimSpace(item.URL)),
|
||||
slog.String("platform_key", strings.TrimSpace(item.PlatformKey)),
|
||||
)
|
||||
}
|
||||
result = append(result, item)
|
||||
continue
|
||||
}
|
||||
if strings.TrimSpace(item.Mime) == "" {
|
||||
item.Mime = strings.TrimSpace(payload.mime)
|
||||
}
|
||||
if strings.TrimSpace(item.Name) == "" {
|
||||
item.Name = strings.TrimSpace(payload.name)
|
||||
}
|
||||
if item.Size == 0 && payload.size > 0 {
|
||||
item.Size = payload.size
|
||||
}
|
||||
maxBytes := media.MaxAssetBytes
|
||||
asset, err := p.mediaService.Ingest(ctx, media.IngestInput{
|
||||
BotID: botID,
|
||||
MediaType: mapInboundAttachmentMediaType(string(item.Type)),
|
||||
Mime: strings.TrimSpace(item.Mime),
|
||||
OriginalName: strings.TrimSpace(item.Name),
|
||||
Metadata: item.Metadata,
|
||||
Reader: payload.reader,
|
||||
MaxBytes: maxBytes,
|
||||
})
|
||||
if payload.reader != nil {
|
||||
_ = payload.reader.Close()
|
||||
}
|
||||
if err != nil {
|
||||
if p.logger != nil {
|
||||
p.logger.Warn(
|
||||
"inbound attachment ingest failed",
|
||||
slog.Any("error", err),
|
||||
slog.String("attachment_type", strings.TrimSpace(string(item.Type))),
|
||||
slog.String("attachment_url", strings.TrimSpace(item.URL)),
|
||||
slog.String("platform_key", strings.TrimSpace(item.PlatformKey)),
|
||||
)
|
||||
}
|
||||
result = append(result, item)
|
||||
continue
|
||||
}
|
||||
item.AssetID = asset.ID
|
||||
item.URL = p.mediaService.AccessPath(asset)
|
||||
item.PlatformKey = ""
|
||||
if strings.TrimSpace(item.Mime) == "" {
|
||||
item.Mime = strings.TrimSpace(asset.Mime)
|
||||
}
|
||||
if item.Size == 0 && asset.SizeBytes > 0 {
|
||||
item.Size = asset.SizeBytes
|
||||
}
|
||||
result = append(result, item)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
type inboundAttachmentPayload struct {
|
||||
reader io.ReadCloser
|
||||
mime string
|
||||
name string
|
||||
size int64
|
||||
}
|
||||
|
||||
func (p *ChannelInboundProcessor) loadInboundAttachmentPayload(
|
||||
ctx context.Context,
|
||||
cfg channel.ChannelConfig,
|
||||
msg channel.InboundMessage,
|
||||
att channel.Attachment,
|
||||
) (inboundAttachmentPayload, error) {
|
||||
rawURL := strings.TrimSpace(att.URL)
|
||||
if rawURL != "" {
|
||||
payload, err := openInboundAttachmentURL(ctx, rawURL)
|
||||
if err == nil {
|
||||
if strings.TrimSpace(att.Mime) != "" {
|
||||
payload.mime = strings.TrimSpace(att.Mime)
|
||||
}
|
||||
if strings.TrimSpace(payload.name) == "" {
|
||||
payload.name = strings.TrimSpace(att.Name)
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
// When URL download fails and platform_key exists, attempt resolver fallback.
|
||||
if strings.TrimSpace(att.PlatformKey) == "" {
|
||||
return inboundAttachmentPayload{}, err
|
||||
}
|
||||
}
|
||||
platformKey := strings.TrimSpace(att.PlatformKey)
|
||||
if platformKey == "" {
|
||||
return inboundAttachmentPayload{}, fmt.Errorf("attachment has no ingestible payload")
|
||||
}
|
||||
resolver := p.resolveAttachmentResolver(msg.Channel)
|
||||
if resolver == nil {
|
||||
return inboundAttachmentPayload{}, fmt.Errorf("attachment resolver not supported for channel: %s", msg.Channel.String())
|
||||
}
|
||||
resolved, err := resolver.ResolveAttachment(ctx, cfg, att)
|
||||
if err != nil {
|
||||
return inboundAttachmentPayload{}, fmt.Errorf("resolve attachment by platform key: %w", err)
|
||||
}
|
||||
if resolved.Reader == nil {
|
||||
return inboundAttachmentPayload{}, fmt.Errorf("resolved attachment reader is nil")
|
||||
}
|
||||
mime := strings.TrimSpace(att.Mime)
|
||||
if mime == "" {
|
||||
mime = strings.TrimSpace(resolved.Mime)
|
||||
}
|
||||
name := strings.TrimSpace(att.Name)
|
||||
if name == "" {
|
||||
name = strings.TrimSpace(resolved.Name)
|
||||
}
|
||||
return inboundAttachmentPayload{
|
||||
reader: resolved.Reader,
|
||||
mime: mime,
|
||||
name: name,
|
||||
size: resolved.Size,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func openInboundAttachmentURL(ctx context.Context, rawURL string) (inboundAttachmentPayload, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
|
||||
if err != nil {
|
||||
return inboundAttachmentPayload{}, fmt.Errorf("build request: %w", err)
|
||||
}
|
||||
client := &http.Client{Timeout: 20 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return inboundAttachmentPayload{}, fmt.Errorf("download attachment: %w", err)
|
||||
}
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
_ = resp.Body.Close()
|
||||
return inboundAttachmentPayload{}, fmt.Errorf("download attachment status: %d", resp.StatusCode)
|
||||
}
|
||||
maxBytes := media.MaxAssetBytes
|
||||
if resp.ContentLength > maxBytes {
|
||||
_ = resp.Body.Close()
|
||||
return inboundAttachmentPayload{}, fmt.Errorf("%w: max %d bytes", media.ErrAssetTooLarge, maxBytes)
|
||||
}
|
||||
mime := strings.TrimSpace(resp.Header.Get("Content-Type"))
|
||||
if idx := strings.Index(mime, ";"); idx >= 0 {
|
||||
mime = strings.TrimSpace(mime[:idx])
|
||||
}
|
||||
return inboundAttachmentPayload{
|
||||
reader: resp.Body,
|
||||
mime: mime,
|
||||
size: resp.ContentLength,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *ChannelInboundProcessor) resolveAttachmentResolver(channelType channel.ChannelType) channel.AttachmentResolver {
|
||||
if p == nil || p.registry == nil {
|
||||
return nil
|
||||
}
|
||||
resolver, ok := p.registry.GetAttachmentResolver(channelType)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return resolver
|
||||
}
|
||||
|
||||
func mapInboundAttachmentMediaType(t string) media.MediaType {
|
||||
switch strings.ToLower(strings.TrimSpace(t)) {
|
||||
case "image", "gif":
|
||||
return media.MediaTypeImage
|
||||
case "audio", "voice":
|
||||
return media.MediaTypeAudio
|
||||
case "video":
|
||||
return media.MediaTypeVideo
|
||||
default:
|
||||
return media.MediaTypeFile
|
||||
}
|
||||
}
|
||||
|
||||
func chatAttachmentsToAssetRefs(attachments []conversation.ChatAttachment) []messagepkg.AssetRef {
|
||||
if len(attachments) == 0 {
|
||||
return nil
|
||||
}
|
||||
refs := make([]messagepkg.AssetRef, 0, len(attachments))
|
||||
for idx, att := range attachments {
|
||||
assetID := strings.TrimSpace(att.AssetID)
|
||||
if assetID == "" {
|
||||
continue
|
||||
}
|
||||
refs = append(refs, messagepkg.AssetRef{
|
||||
AssetID: assetID,
|
||||
Role: "attachment",
|
||||
Ordinal: idx,
|
||||
})
|
||||
}
|
||||
if len(refs) == 0 {
|
||||
return nil
|
||||
}
|
||||
return refs
|
||||
}
|
||||
|
||||
// parseAttachmentDelta converts raw JSON attachment data to channel Attachments.
|
||||
func parseAttachmentDelta(raw json.RawMessage) []channel.Attachment {
|
||||
if len(raw) == 0 {
|
||||
return nil
|
||||
}
|
||||
var items []struct {
|
||||
Type string `json:"type"`
|
||||
URL string `json:"url"`
|
||||
Path string `json:"path"`
|
||||
PlatformKey string `json:"platform_key"`
|
||||
AssetID string `json:"asset_id"`
|
||||
Name string `json:"name"`
|
||||
Mime string `json:"mime"`
|
||||
Size int64 `json:"size"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &items); err != nil {
|
||||
return nil
|
||||
}
|
||||
attachments := make([]channel.Attachment, 0, len(items))
|
||||
for _, item := range items {
|
||||
url := strings.TrimSpace(item.URL)
|
||||
if url == "" {
|
||||
url = strings.TrimSpace(item.Path)
|
||||
}
|
||||
attachments = append(attachments, channel.Attachment{
|
||||
Type: channel.AttachmentType(strings.TrimSpace(item.Type)),
|
||||
URL: url,
|
||||
PlatformKey: strings.TrimSpace(item.PlatformKey),
|
||||
AssetID: strings.TrimSpace(item.AssetID),
|
||||
Name: strings.TrimSpace(item.Name),
|
||||
Mime: strings.TrimSpace(item.Mime),
|
||||
Size: item.Size,
|
||||
})
|
||||
}
|
||||
return attachments
|
||||
}
|
||||
|
||||
@@ -4,7 +4,10 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -12,6 +15,7 @@ import (
|
||||
"github.com/memohai/memoh/internal/channel/identities"
|
||||
"github.com/memohai/memoh/internal/channel/route"
|
||||
"github.com/memohai/memoh/internal/conversation"
|
||||
"github.com/memohai/memoh/internal/media"
|
||||
messagepkg "github.com/memohai/memoh/internal/message"
|
||||
"github.com/memohai/memoh/internal/schedule"
|
||||
)
|
||||
@@ -168,6 +172,73 @@ type fakeChatService struct {
|
||||
resolveResult route.ResolveConversationResult
|
||||
resolveErr error
|
||||
persisted []messagepkg.Message
|
||||
persistedIn []messagepkg.PersistInput
|
||||
}
|
||||
|
||||
type fakeMediaIngestor struct {
|
||||
nextID string
|
||||
nextMime string
|
||||
ingestErr error
|
||||
calls int
|
||||
inputs []media.IngestInput
|
||||
}
|
||||
|
||||
func (f *fakeMediaIngestor) Ingest(ctx context.Context, input media.IngestInput) (media.Asset, error) {
|
||||
f.calls++
|
||||
f.inputs = append(f.inputs, input)
|
||||
if input.Reader != nil {
|
||||
_, _ = io.ReadAll(input.Reader)
|
||||
}
|
||||
if f.ingestErr != nil {
|
||||
return media.Asset{}, f.ingestErr
|
||||
}
|
||||
id := strings.TrimSpace(f.nextID)
|
||||
if id == "" {
|
||||
id = "asset-test-id"
|
||||
}
|
||||
mime := strings.TrimSpace(f.nextMime)
|
||||
if mime == "" {
|
||||
mime = strings.TrimSpace(input.Mime)
|
||||
}
|
||||
return media.Asset{
|
||||
ID: id,
|
||||
Mime: mime,
|
||||
StorageKey: input.BotID + "/" + string(input.MediaType) + "/test/" + id,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (f *fakeMediaIngestor) AccessPath(asset media.Asset) string {
|
||||
sub := asset.StorageKey
|
||||
if idx := strings.IndexByte(sub, '/'); idx >= 0 {
|
||||
sub = sub[idx+1:]
|
||||
}
|
||||
return "/data/media/" + sub
|
||||
}
|
||||
|
||||
type fakeAttachmentResolverAdapter struct{}
|
||||
|
||||
func (a *fakeAttachmentResolverAdapter) Type() channel.ChannelType {
|
||||
return channel.ChannelType("resolver-test")
|
||||
}
|
||||
|
||||
func (a *fakeAttachmentResolverAdapter) Descriptor() channel.Descriptor {
|
||||
return channel.Descriptor{
|
||||
Type: channel.ChannelType("resolver-test"),
|
||||
DisplayName: "ResolverTest",
|
||||
Capabilities: channel.ChannelCapabilities{
|
||||
Text: true,
|
||||
Attachments: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (a *fakeAttachmentResolverAdapter) ResolveAttachment(ctx context.Context, cfg channel.ChannelConfig, attachment channel.Attachment) (channel.AttachmentPayload, error) {
|
||||
return channel.AttachmentPayload{
|
||||
Reader: io.NopCloser(strings.NewReader("resolver-bytes")),
|
||||
Mime: "application/octet-stream",
|
||||
Name: "resolver.bin",
|
||||
Size: int64(len("resolver-bytes")),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (f *fakeChatService) ResolveConversation(ctx context.Context, input route.ResolveInput) (route.ResolveConversationResult, error) {
|
||||
@@ -178,6 +249,7 @@ func (f *fakeChatService) ResolveConversation(ctx context.Context, input route.R
|
||||
}
|
||||
|
||||
func (f *fakeChatService) Persist(ctx context.Context, input messagepkg.PersistInput) (messagepkg.Message, error) {
|
||||
f.persistedIn = append(f.persistedIn, input)
|
||||
msg := messagepkg.Message{
|
||||
BotID: input.BotID,
|
||||
RouteID: input.RouteID,
|
||||
@@ -432,6 +504,125 @@ func TestChannelInboundProcessorGroupMentionTriggersReply(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestChannelInboundProcessorPersistsAttachmentAssetRefs(t *testing.T) {
|
||||
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-asset"}}
|
||||
memberSvc := &fakeMemberService{isMember: true}
|
||||
chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-asset", RouteID: "route-asset"}}
|
||||
gateway := &fakeChatGateway{
|
||||
resp: conversation.ChatResponse{
|
||||
Messages: []conversation.ModelMessage{
|
||||
{Role: "assistant", Content: conversation.NewTextContent("ok")},
|
||||
},
|
||||
},
|
||||
}
|
||||
processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, nil, nil, nil, "", 0)
|
||||
sender := &fakeReplySender{}
|
||||
|
||||
cfg := channel.ChannelConfig{ID: "cfg-asset", BotID: "bot-1"}
|
||||
msg := channel.InboundMessage{
|
||||
BotID: "bot-1",
|
||||
Channel: channel.ChannelType("feishu"),
|
||||
Message: channel.Message{
|
||||
ID: "msg-asset-1",
|
||||
Text: "attachment test",
|
||||
Attachments: []channel.Attachment{
|
||||
{
|
||||
Type: channel.AttachmentImage,
|
||||
URL: "https://example.com/img.png",
|
||||
AssetID: "asset-1",
|
||||
Name: "img.png",
|
||||
Mime: "image/png",
|
||||
},
|
||||
},
|
||||
},
|
||||
ReplyTarget: "chat_id:oc_asset",
|
||||
Sender: channel.Identity{SubjectID: "ext-asset"},
|
||||
Conversation: channel.Conversation{
|
||||
ID: "oc_asset",
|
||||
Type: "p2p",
|
||||
},
|
||||
}
|
||||
|
||||
if err := processor.HandleInbound(context.Background(), cfg, msg, sender); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(chatSvc.persistedIn) != 1 {
|
||||
t.Fatalf("expected one persisted input, got %d", len(chatSvc.persistedIn))
|
||||
}
|
||||
if len(chatSvc.persistedIn[0].Assets) != 1 {
|
||||
t.Fatalf("expected one persisted asset ref, got %d", len(chatSvc.persistedIn[0].Assets))
|
||||
}
|
||||
if got := chatSvc.persistedIn[0].Assets[0].AssetID; got != "asset-1" {
|
||||
t.Fatalf("expected persisted asset id asset-1, got %q", got)
|
||||
}
|
||||
if len(gateway.gotReq.Attachments) != 1 {
|
||||
t.Fatalf("expected one gateway attachment, got %d", len(gateway.gotReq.Attachments))
|
||||
}
|
||||
if got := gateway.gotReq.Attachments[0].AssetID; got != "asset-1" {
|
||||
t.Fatalf("expected gateway attachment asset_id asset-1, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChannelInboundProcessorIngestsPlatformKeyWithResolver(t *testing.T) {
|
||||
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-resolver"}}
|
||||
memberSvc := &fakeMemberService{isMember: true}
|
||||
chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-resolver", RouteID: "route-resolver"}}
|
||||
gateway := &fakeChatGateway{
|
||||
resp: conversation.ChatResponse{
|
||||
Messages: []conversation.ModelMessage{
|
||||
{Role: "assistant", Content: conversation.NewTextContent("ok")},
|
||||
},
|
||||
},
|
||||
}
|
||||
registry := channel.NewRegistry()
|
||||
registry.MustRegister(&fakeAttachmentResolverAdapter{})
|
||||
processor := NewChannelInboundProcessor(slog.Default(), registry, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, nil, nil, nil, "", 0)
|
||||
mediaSvc := &fakeMediaIngestor{nextID: "asset-resolved-1", nextMime: "application/octet-stream"}
|
||||
processor.SetMediaService(mediaSvc)
|
||||
sender := &fakeReplySender{}
|
||||
|
||||
cfg := channel.ChannelConfig{ID: "cfg-resolver", BotID: "bot-1", ChannelType: channel.ChannelType("resolver-test")}
|
||||
msg := channel.InboundMessage{
|
||||
BotID: "bot-1",
|
||||
Channel: channel.ChannelType("resolver-test"),
|
||||
Message: channel.Message{
|
||||
ID: "msg-resolver-1",
|
||||
Text: "attachment resolver test",
|
||||
Attachments: []channel.Attachment{
|
||||
{
|
||||
Type: channel.AttachmentFile,
|
||||
PlatformKey: "platform-file-1",
|
||||
},
|
||||
},
|
||||
},
|
||||
ReplyTarget: "resolver-target",
|
||||
Sender: channel.Identity{SubjectID: "resolver-user"},
|
||||
Conversation: channel.Conversation{
|
||||
ID: "resolver-conv",
|
||||
Type: "p2p",
|
||||
},
|
||||
}
|
||||
|
||||
if err := processor.HandleInbound(context.Background(), cfg, msg, sender); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if mediaSvc.calls != 1 {
|
||||
t.Fatalf("expected media ingest to be called once, got %d", mediaSvc.calls)
|
||||
}
|
||||
if len(gateway.gotReq.Attachments) != 1 {
|
||||
t.Fatalf("expected one gateway attachment, got %d", len(gateway.gotReq.Attachments))
|
||||
}
|
||||
if got := gateway.gotReq.Attachments[0].AssetID; got != "asset-resolved-1" {
|
||||
t.Fatalf("expected resolved asset id, got %q", got)
|
||||
}
|
||||
if len(chatSvc.persistedIn) != 1 || len(chatSvc.persistedIn[0].Assets) != 1 {
|
||||
t.Fatalf("expected one persisted asset ref, got %+v", chatSvc.persistedIn)
|
||||
}
|
||||
if got := chatSvc.persistedIn[0].Assets[0].AssetID; got != "asset-resolved-1" {
|
||||
t.Fatalf("expected persisted asset id asset-resolved-1, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChannelInboundProcessorPersonalGroupNonOwnerIgnored(t *testing.T) {
|
||||
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-member"}}
|
||||
memberSvc := &fakeMemberService{isMember: true}
|
||||
@@ -704,3 +895,242 @@ func TestChannelInboundProcessorProcessingFailedNotifyErrorDoesNotOverrideChatEr
|
||||
t.Fatalf("unexpected processing status lifecycle: %+v", notifier.events)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadInboundAttachmentURLTooLarge(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
w.Header().Set("Content-Length", "999999999")
|
||||
_, _ = w.Write([]byte("x"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
_, err := openInboundAttachmentURL(context.Background(), server.URL)
|
||||
if err == nil {
|
||||
t.Fatalf("expected too-large error")
|
||||
}
|
||||
if !errors.Is(err, media.ErrAssetTooLarge) {
|
||||
t.Fatalf("expected ErrAssetTooLarge, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapStreamChunkToChannelEvents(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
chunk string
|
||||
wantType channel.StreamEventType
|
||||
wantDelta string
|
||||
wantPhase channel.StreamPhase
|
||||
wantToolName string
|
||||
wantAttCount int
|
||||
wantError string
|
||||
wantNilEvents bool
|
||||
}{
|
||||
{
|
||||
name: "text_delta",
|
||||
chunk: `{"type":"text_delta","delta":"hello"}`,
|
||||
wantType: channel.StreamEventDelta,
|
||||
wantDelta: "hello",
|
||||
wantPhase: channel.StreamPhaseText,
|
||||
},
|
||||
{
|
||||
name: "text_delta empty",
|
||||
chunk: `{"type":"text_delta","delta":""}`,
|
||||
wantNilEvents: true,
|
||||
},
|
||||
{
|
||||
name: "reasoning_delta",
|
||||
chunk: `{"type":"reasoning_delta","delta":"thinking"}`,
|
||||
wantType: channel.StreamEventDelta,
|
||||
wantDelta: "thinking",
|
||||
wantPhase: channel.StreamPhaseReasoning,
|
||||
},
|
||||
{
|
||||
name: "reasoning_delta empty",
|
||||
chunk: `{"type":"reasoning_delta","delta":""}`,
|
||||
wantNilEvents: true,
|
||||
},
|
||||
{
|
||||
name: "reasoning_start",
|
||||
chunk: `{"type":"reasoning_start"}`,
|
||||
wantType: channel.StreamEventPhaseStart,
|
||||
wantPhase: channel.StreamPhaseReasoning,
|
||||
},
|
||||
{
|
||||
name: "reasoning_end",
|
||||
chunk: `{"type":"reasoning_end"}`,
|
||||
wantType: channel.StreamEventPhaseEnd,
|
||||
wantPhase: channel.StreamPhaseReasoning,
|
||||
},
|
||||
{
|
||||
name: "text_start",
|
||||
chunk: `{"type":"text_start"}`,
|
||||
wantType: channel.StreamEventPhaseStart,
|
||||
wantPhase: channel.StreamPhaseText,
|
||||
},
|
||||
{
|
||||
name: "text_end",
|
||||
chunk: `{"type":"text_end"}`,
|
||||
wantType: channel.StreamEventPhaseEnd,
|
||||
wantPhase: channel.StreamPhaseText,
|
||||
},
|
||||
{
|
||||
name: "tool_call_start",
|
||||
chunk: `{"type":"tool_call_start","toolName":"search_web","toolCallId":"tc_1","input":{"query":"test"}}`,
|
||||
wantType: channel.StreamEventToolCallStart,
|
||||
wantToolName: "search_web",
|
||||
},
|
||||
{
|
||||
name: "tool_call_end",
|
||||
chunk: `{"type":"tool_call_end","toolName":"search_web","toolCallId":"tc_1","input":{"query":"test"},"result":{"ok":true}}`,
|
||||
wantType: channel.StreamEventToolCallEnd,
|
||||
wantToolName: "search_web",
|
||||
},
|
||||
{
|
||||
name: "attachment_delta",
|
||||
chunk: `{"type":"attachment_delta","attachments":[{"type":"image","url":"https://example.com/img.png"}]}`,
|
||||
wantType: channel.StreamEventAttachment,
|
||||
wantAttCount: 1,
|
||||
},
|
||||
{
|
||||
name: "attachment_delta empty",
|
||||
chunk: `{"type":"attachment_delta","attachments":[]}`,
|
||||
wantNilEvents: true,
|
||||
},
|
||||
{
|
||||
name: "error",
|
||||
chunk: `{"type":"error","error":"something failed"}`,
|
||||
wantType: channel.StreamEventError,
|
||||
wantError: "something failed",
|
||||
},
|
||||
{
|
||||
name: "error fallback to message",
|
||||
chunk: `{"type":"error","message":"fallback msg"}`,
|
||||
wantType: channel.StreamEventError,
|
||||
wantError: "fallback msg",
|
||||
},
|
||||
{
|
||||
name: "agent_start",
|
||||
chunk: `{"type":"agent_start","input":{"agent":"planner"}}`,
|
||||
wantType: channel.StreamEventAgentStart,
|
||||
},
|
||||
{
|
||||
name: "agent_end",
|
||||
chunk: `{"type":"agent_end","result":{"ok":true}}`,
|
||||
wantType: channel.StreamEventAgentEnd,
|
||||
},
|
||||
{
|
||||
name: "processing_started",
|
||||
chunk: `{"type":"processing_started"}`,
|
||||
wantType: channel.StreamEventProcessingStarted,
|
||||
},
|
||||
{
|
||||
name: "processing_completed",
|
||||
chunk: `{"type":"processing_completed"}`,
|
||||
wantType: channel.StreamEventProcessingCompleted,
|
||||
},
|
||||
{
|
||||
name: "processing_failed",
|
||||
chunk: `{"type":"processing_failed","error":"failed"}`,
|
||||
wantType: channel.StreamEventProcessingFailed,
|
||||
wantError: "failed",
|
||||
},
|
||||
{
|
||||
name: "empty chunk",
|
||||
chunk: ``,
|
||||
wantNilEvents: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
events, _, err := mapStreamChunkToChannelEvents(conversation.StreamChunk([]byte(tt.chunk)))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if tt.wantNilEvents {
|
||||
if len(events) > 0 {
|
||||
t.Fatalf("expected nil/empty events, got %d", len(events))
|
||||
}
|
||||
return
|
||||
}
|
||||
if len(events) != 1 {
|
||||
t.Fatalf("expected 1 event, got %d", len(events))
|
||||
}
|
||||
ev := events[0]
|
||||
if ev.Type != tt.wantType {
|
||||
t.Fatalf("expected type %q, got %q", tt.wantType, ev.Type)
|
||||
}
|
||||
if tt.wantDelta != "" && ev.Delta != tt.wantDelta {
|
||||
t.Fatalf("expected delta %q, got %q", tt.wantDelta, ev.Delta)
|
||||
}
|
||||
if tt.wantPhase != "" && ev.Phase != tt.wantPhase {
|
||||
t.Fatalf("expected phase %q, got %q", tt.wantPhase, ev.Phase)
|
||||
}
|
||||
if tt.wantToolName != "" {
|
||||
if ev.ToolCall == nil {
|
||||
t.Fatal("expected non-nil ToolCall")
|
||||
}
|
||||
if ev.ToolCall.Name != tt.wantToolName {
|
||||
t.Fatalf("expected tool name %q, got %q", tt.wantToolName, ev.ToolCall.Name)
|
||||
}
|
||||
}
|
||||
if tt.wantAttCount > 0 && len(ev.Attachments) != tt.wantAttCount {
|
||||
t.Fatalf("expected %d attachments, got %d", tt.wantAttCount, len(ev.Attachments))
|
||||
}
|
||||
if tt.wantError != "" && ev.Error != tt.wantError {
|
||||
t.Fatalf("expected error %q, got %q", tt.wantError, ev.Error)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapStreamChunkToChannelEvents_ToolCallFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
chunk := `{"type":"tool_call_end","toolName":"calc","toolCallId":"c1","input":{"x":1},"result":{"sum":2}}`
|
||||
events, _, err := mapStreamChunkToChannelEvents(conversation.StreamChunk([]byte(chunk)))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(events) != 1 {
|
||||
t.Fatalf("expected 1 event, got %d", len(events))
|
||||
}
|
||||
tc := events[0].ToolCall
|
||||
if tc == nil {
|
||||
t.Fatal("expected non-nil ToolCall")
|
||||
}
|
||||
if tc.Name != "calc" || tc.CallID != "c1" {
|
||||
t.Fatalf("unexpected name/callID: %q / %q", tc.Name, tc.CallID)
|
||||
}
|
||||
if tc.Input == nil || tc.Result == nil {
|
||||
t.Fatal("expected non-nil Input and Result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapStreamChunkToChannelEvents_FinalMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
chunk := `{"type":"agent_end","messages":[{"role":"assistant","content":"done"}]}`
|
||||
events, messages, err := mapStreamChunkToChannelEvents(conversation.StreamChunk([]byte(chunk)))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(events) != 1 {
|
||||
t.Fatalf("expected 1 event, got %d", len(events))
|
||||
}
|
||||
if events[0].Type != channel.StreamEventAgentEnd {
|
||||
t.Fatalf("expected event type %q, got %q", channel.StreamEventAgentEnd, events[0].Type)
|
||||
}
|
||||
if len(messages) != 1 {
|
||||
t.Fatalf("expected 1 final message, got %d", len(messages))
|
||||
}
|
||||
if messages[0].Role != "assistant" {
|
||||
t.Fatalf("expected role assistant, got %q", messages[0].Role)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user