feat: ui message (#357)

This commit is contained in:
Acbox Liu
2026-04-11 13:29:41 +08:00
committed by GitHub
parent f376a2abe3
commit 7a21fd5f07
19 changed files with 2141 additions and 774 deletions
+96
View File
@@ -0,0 +1,96 @@
package conversation
import (
"strings"
"time"
)
// UIMessageType identifies the frontend-friendly message block type.
type UIMessageType string
const (
UIMessageText UIMessageType = "text"
UIMessageReasoning UIMessageType = "reasoning"
UIMessageTool UIMessageType = "tool"
UIMessageAttachments UIMessageType = "attachments"
)
// UIAttachment is the normalized attachment shape used by the web frontend.
type UIAttachment struct {
ID string `json:"id,omitempty"`
Type string `json:"type"`
Path string `json:"path,omitempty"`
URL string `json:"url,omitempty"`
Base64 string `json:"base64,omitempty"`
Name string `json:"name,omitempty"`
ContentHash string `json:"content_hash,omitempty"`
BotID string `json:"bot_id,omitempty"`
Mime string `json:"mime,omitempty"`
Size int64 `json:"size,omitempty"`
StorageKey string `json:"storage_key,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
}
// UIMessage is the normalized assistant output block used by the web frontend.
type UIMessage struct {
ID int `json:"id"`
Type UIMessageType `json:"type"`
Content string `json:"content,omitempty"`
Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"`
Output any `json:"output,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
Running *bool `json:"running,omitempty"`
Progress []any `json:"progress,omitempty"`
Attachments []UIAttachment `json:"attachments,omitempty"`
}
// UITurn is the normalized chat turn used by the web frontend.
type UITurn struct {
Role string `json:"role"`
Messages []UIMessage `json:"messages,omitempty"`
Text string `json:"text,omitempty"`
Attachments []UIAttachment `json:"attachments,omitempty"`
Timestamp time.Time `json:"timestamp"`
Platform string `json:"platform,omitempty"`
SenderDisplayName string `json:"sender_display_name,omitempty"`
SenderAvatarURL string `json:"sender_avatar_url,omitempty"`
SenderUserID string `json:"sender_user_id,omitempty"`
ID string `json:"id,omitempty"`
}
// UIMessageStreamEvent is the generic event shape accepted by the UI stream converter.
// The handler layer adapts agent/channel events to this struct to avoid package cycles.
type UIMessageStreamEvent struct {
Type string
Delta string
ToolName string
ToolCallID string
Input any
Output any
Progress any
Attachments []UIAttachment
Error string
}
func uiBoolPtr(v bool) *bool {
return &v
}
func normalizeUIAttachmentType(kind, mime string) string {
if trimmed := strings.ToLower(strings.TrimSpace(kind)); trimmed != "" {
return trimmed
}
normalizedMime := strings.ToLower(strings.TrimSpace(mime))
switch {
case strings.HasPrefix(normalizedMime, "image/"):
return "image"
case strings.HasPrefix(normalizedMime, "audio/"):
return "audio"
case strings.HasPrefix(normalizedMime, "video/"):
return "video"
default:
return "file"
}
}
+609
View File
@@ -0,0 +1,609 @@
package conversation
import (
"encoding/json"
"regexp"
"strings"
messagepkg "github.com/memohai/memoh/internal/message"
)
var (
uiMessageYAMLHeaderRe = regexp.MustCompile(`(?s)\A---\n.*?\n---\n?`)
uiMessageAgentTagsRe = regexp.MustCompile(`(?s)<attachments>.*?</attachments>|<reactions>.*?</reactions>|<speech>.*?</speech>`)
uiMessageCollapsedNewlinesRe = regexp.MustCompile(`\n{3,}`)
)
type uiContentPart struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
URL string `json:"url,omitempty"`
Emoji string `json:"emoji,omitempty"`
ToolCallID string `json:"toolCallId,omitempty"`
ToolName string `json:"toolName,omitempty"`
Input any `json:"input,omitempty"`
Output any `json:"output,omitempty"`
Result any `json:"result,omitempty"`
}
type uiExtractedToolCall struct {
ID string
Name string
Input any
}
type uiExtractedToolResult struct {
ToolCallID string
Output any
}
type uiPendingAssistantTurn struct {
Turn UITurn
NextID int
ToolIndexes map[string]int
}
// ConvertRawModelMessagesToUIAssistantMessages converts terminal stream payload
// messages into frontend-friendly assistant UI messages.
func ConvertRawModelMessagesToUIAssistantMessages(raw json.RawMessage) []UIMessage {
if len(raw) == 0 {
return nil
}
var messages []ModelMessage
if err := json.Unmarshal(raw, &messages); err != nil {
return nil
}
return ConvertModelMessagesToUIAssistantMessages(messages)
}
// ConvertModelMessagesToUIAssistantMessages converts assistant/tool output
// messages into frontend-friendly UI message blocks.
func ConvertModelMessagesToUIAssistantMessages(messages []ModelMessage) []UIMessage {
pending := &uiPendingAssistantTurn{
ToolIndexes: map[string]int{},
}
for _, modelMessage := range messages {
switch strings.ToLower(strings.TrimSpace(modelMessage.Role)) {
case "assistant":
for _, reasoning := range extractPersistedReasoning(modelMessage) {
appendPendingAssistantMessage(pending, UIMessage{
Type: UIMessageReasoning,
Content: reasoning,
})
}
if text := extractAssistantStreamMessageText(modelMessage); text != "" {
appendPendingAssistantMessage(pending, UIMessage{
Type: UIMessageText,
Content: text,
})
}
for _, call := range extractPersistedToolCalls(modelMessage) {
appendPendingAssistantMessage(pending, UIMessage{
Type: UIMessageTool,
Name: call.Name,
Input: call.Input,
ToolCallID: call.ID,
Running: uiBoolPtr(true),
})
if call.ID != "" {
pending.ToolIndexes[call.ID] = len(pending.Turn.Messages) - 1
}
}
case "tool":
for _, toolResult := range extractPersistedToolResults(modelMessage) {
idx, ok := pending.ToolIndexes[toolResult.ToolCallID]
if !ok || idx < 0 || idx >= len(pending.Turn.Messages) {
continue
}
if isHiddenCurrentConversationToolOutput(toolResult.Output) {
removePendingAssistantMessage(pending, idx)
delete(pending.ToolIndexes, toolResult.ToolCallID)
continue
}
pending.Turn.Messages[idx].Output = toolResult.Output
pending.Turn.Messages[idx].Running = uiBoolPtr(false)
}
}
}
for _, idx := range pending.ToolIndexes {
if idx >= 0 && idx < len(pending.Turn.Messages) {
pending.Turn.Messages[idx].Running = uiBoolPtr(false)
}
}
return pending.Turn.Messages
}
// ConvertMessagesToUITurns converts persisted message rows into frontend-friendly turns.
func ConvertMessagesToUITurns(messages []messagepkg.Message) []UITurn {
result := make([]UITurn, 0, len(messages))
var pending *uiPendingAssistantTurn
flushPending := func() {
if pending == nil {
return
}
for _, idx := range pending.ToolIndexes {
if idx < 0 || idx >= len(pending.Turn.Messages) {
continue
}
pending.Turn.Messages[idx].Running = uiBoolPtr(false)
}
if len(pending.Turn.Messages) > 0 {
result = append(result, pending.Turn)
}
pending = nil
}
for _, raw := range messages {
modelMessage := decodePersistedModelMessage(raw)
switch strings.ToLower(strings.TrimSpace(raw.Role)) {
case "user":
flushPending()
text := extractPersistedMessageText(raw, modelMessage)
attachments := uiAttachmentsFromMessageAssets(raw)
if text == "" && len(attachments) == 0 {
continue
}
turn := UITurn{
Role: "user",
Text: text,
Attachments: attachments,
Timestamp: raw.CreatedAt,
Platform: resolveUIPersistencePlatform(raw),
ID: strings.TrimSpace(raw.ID),
}
if turn.Platform != "" {
turn.SenderDisplayName = strings.TrimSpace(raw.SenderDisplayName)
turn.SenderAvatarURL = strings.TrimSpace(raw.SenderAvatarURL)
turn.SenderUserID = strings.TrimSpace(raw.SenderUserID)
}
result = append(result, turn)
case "assistant":
toolCalls := extractPersistedToolCalls(modelMessage)
text := extractPersistedMessageText(raw, modelMessage)
reasonings := extractPersistedReasoning(modelMessage)
attachments := uiAttachmentsFromMessageAssets(raw)
if len(toolCalls) > 0 {
if pending == nil {
pending = newPendingAssistantTurn(raw)
}
for _, reasoning := range reasonings {
appendPendingAssistantMessage(pending, UIMessage{
ID: pending.NextID,
Type: UIMessageReasoning,
Content: reasoning,
})
}
if text != "" {
appendPendingAssistantMessage(pending, UIMessage{
ID: pending.NextID,
Type: UIMessageText,
Content: text,
})
}
for _, call := range toolCalls {
block := UIMessage{
ID: pending.NextID,
Type: UIMessageTool,
Name: call.Name,
Input: call.Input,
ToolCallID: call.ID,
Running: uiBoolPtr(true),
}
appendPendingAssistantMessage(pending, block)
if call.ID != "" {
pending.ToolIndexes[call.ID] = len(pending.Turn.Messages) - 1
}
}
if len(attachments) > 0 {
appendPendingAssistantMessage(pending, UIMessage{
ID: pending.NextID,
Type: UIMessageAttachments,
Attachments: attachments,
})
}
continue
}
if pending != nil && (text != "" || len(reasonings) > 0 || len(attachments) > 0) {
for _, reasoning := range reasonings {
appendPendingAssistantMessage(pending, UIMessage{
ID: pending.NextID,
Type: UIMessageReasoning,
Content: reasoning,
})
}
if text != "" {
appendPendingAssistantMessage(pending, UIMessage{
ID: pending.NextID,
Type: UIMessageText,
Content: text,
})
}
if len(attachments) > 0 {
appendPendingAssistantMessage(pending, UIMessage{
ID: pending.NextID,
Type: UIMessageAttachments,
Attachments: attachments,
})
}
flushPending()
continue
}
flushPending()
assistantMessages := buildStandaloneAssistantMessages(text, reasonings, attachments)
if len(assistantMessages) == 0 {
continue
}
result = append(result, UITurn{
Role: "assistant",
Messages: assistantMessages,
Timestamp: raw.CreatedAt,
Platform: resolveUIPersistencePlatform(raw),
ID: strings.TrimSpace(raw.ID),
})
case "tool":
if pending == nil {
continue
}
for _, toolResult := range extractPersistedToolResults(modelMessage) {
idx, ok := pending.ToolIndexes[toolResult.ToolCallID]
if !ok || idx < 0 || idx >= len(pending.Turn.Messages) {
continue
}
if isHiddenCurrentConversationToolOutput(toolResult.Output) {
removePendingAssistantMessage(pending, idx)
delete(pending.ToolIndexes, toolResult.ToolCallID)
continue
}
pending.Turn.Messages[idx].Output = toolResult.Output
pending.Turn.Messages[idx].Running = uiBoolPtr(false)
}
}
}
flushPending()
return result
}
func newPendingAssistantTurn(raw messagepkg.Message) *uiPendingAssistantTurn {
return &uiPendingAssistantTurn{
Turn: UITurn{
Role: "assistant",
Timestamp: raw.CreatedAt,
Platform: resolveUIPersistencePlatform(raw),
ID: strings.TrimSpace(raw.ID),
},
ToolIndexes: map[string]int{},
}
}
func appendPendingAssistantMessage(pending *uiPendingAssistantTurn, message UIMessage) {
if pending == nil {
return
}
message.ID = pending.NextID
pending.NextID++
pending.Turn.Messages = append(pending.Turn.Messages, message)
}
func removePendingAssistantMessage(pending *uiPendingAssistantTurn, idx int) {
if pending == nil || idx < 0 || idx >= len(pending.Turn.Messages) {
return
}
pending.Turn.Messages = append(pending.Turn.Messages[:idx], pending.Turn.Messages[idx+1:]...)
for callID, currentIdx := range pending.ToolIndexes {
switch {
case currentIdx == idx:
delete(pending.ToolIndexes, callID)
case currentIdx > idx:
pending.ToolIndexes[callID] = currentIdx - 1
}
}
}
func buildStandaloneAssistantMessages(text string, reasonings []string, attachments []UIAttachment) []UIMessage {
messages := make([]UIMessage, 0, len(reasonings)+2)
nextID := 0
for _, reasoning := range reasonings {
messages = append(messages, UIMessage{
ID: nextID,
Type: UIMessageReasoning,
Content: reasoning,
})
nextID++
}
if text != "" {
messages = append(messages, UIMessage{
ID: nextID,
Type: UIMessageText,
Content: text,
})
nextID++
}
if len(attachments) > 0 {
messages = append(messages, UIMessage{
ID: nextID,
Type: UIMessageAttachments,
Attachments: attachments,
})
}
return messages
}
func decodePersistedModelMessage(raw messagepkg.Message) ModelMessage {
var message ModelMessage
if err := json.Unmarshal(raw.Content, &message); err != nil {
return ModelMessage{
Role: raw.Role,
Content: raw.Content,
}
}
message.Role = raw.Role
return message
}
func extractPersistedMessageText(raw messagepkg.Message, message ModelMessage) string {
if strings.EqualFold(raw.Role, "user") {
if text := strings.TrimSpace(raw.DisplayContent); text != "" {
return text
}
}
text := strings.TrimSpace(extractTextFromPersistedContent(message.Content))
if text == "" {
return ""
}
if strings.EqualFold(raw.Role, "user") {
return strings.TrimSpace(stripPersistedYAMLHeader(text))
}
return strings.TrimSpace(stripPersistedAgentTags(text))
}
func extractAssistantStreamMessageText(message ModelMessage) string {
return strings.TrimSpace(stripPersistedAgentTags(extractTextFromPersistedContent(message.Content)))
}
func extractTextFromPersistedContent(raw json.RawMessage) string {
if len(raw) == 0 {
return ""
}
var text string
if err := json.Unmarshal(raw, &text); err == nil {
return strings.TrimSpace(text)
}
parts := extractPersistedContentParts(raw)
if len(parts) > 0 {
lines := make([]string, 0, len(parts))
for _, part := range parts {
partType := strings.ToLower(strings.TrimSpace(part.Type))
if partType == "reasoning" {
continue
}
switch {
case partType == "text" && strings.TrimSpace(part.Text) != "":
lines = append(lines, strings.TrimSpace(part.Text))
case partType == "link" && strings.TrimSpace(part.URL) != "":
lines = append(lines, strings.TrimSpace(part.URL))
case partType == "emoji" && strings.TrimSpace(part.Emoji) != "":
lines = append(lines, strings.TrimSpace(part.Emoji))
case strings.TrimSpace(part.Text) != "":
lines = append(lines, strings.TrimSpace(part.Text))
}
}
return strings.TrimSpace(strings.Join(lines, "\n"))
}
var object map[string]any
if err := json.Unmarshal(raw, &object); err == nil {
if value, ok := object["text"].(string); ok {
return strings.TrimSpace(value)
}
}
return ""
}
func extractPersistedReasoning(message ModelMessage) []string {
parts := extractPersistedContentParts(message.Content)
if len(parts) == 0 {
return nil
}
reasonings := make([]string, 0, len(parts))
for _, part := range parts {
if strings.ToLower(strings.TrimSpace(part.Type)) != "reasoning" {
continue
}
if text := strings.TrimSpace(part.Text); text != "" {
reasonings = append(reasonings, text)
}
}
return reasonings
}
func extractPersistedToolCalls(message ModelMessage) []uiExtractedToolCall {
parts := extractPersistedContentParts(message.Content)
calls := make([]uiExtractedToolCall, 0, len(parts)+len(message.ToolCalls))
for _, part := range parts {
if strings.ToLower(strings.TrimSpace(part.Type)) != "tool-call" {
continue
}
calls = append(calls, uiExtractedToolCall{
ID: strings.TrimSpace(part.ToolCallID),
Name: strings.TrimSpace(part.ToolName),
Input: part.Input,
})
}
if len(calls) > 0 {
return calls
}
for _, toolCall := range message.ToolCalls {
input := any(nil)
if rawArgs := strings.TrimSpace(toolCall.Function.Arguments); rawArgs != "" {
if err := json.Unmarshal([]byte(rawArgs), &input); err != nil {
input = rawArgs
}
}
calls = append(calls, uiExtractedToolCall{
ID: strings.TrimSpace(toolCall.ID),
Name: strings.TrimSpace(toolCall.Function.Name),
Input: input,
})
}
return calls
}
func extractPersistedToolResults(message ModelMessage) []uiExtractedToolResult {
parts := extractPersistedContentParts(message.Content)
results := make([]uiExtractedToolResult, 0, len(parts))
for _, part := range parts {
if strings.ToLower(strings.TrimSpace(part.Type)) != "tool-result" {
continue
}
output := part.Output
if output == nil {
output = part.Result
}
results = append(results, uiExtractedToolResult{
ToolCallID: strings.TrimSpace(part.ToolCallID),
Output: output,
})
}
if len(results) > 0 {
return results
}
if strings.TrimSpace(message.ToolCallID) == "" {
return nil
}
var output any
if err := json.Unmarshal(message.Content, &output); err != nil {
output = strings.TrimSpace(string(message.Content))
}
return []uiExtractedToolResult{{
ToolCallID: strings.TrimSpace(message.ToolCallID),
Output: output,
}}
}
func extractPersistedContentParts(raw json.RawMessage) []uiContentPart {
if len(raw) == 0 {
return nil
}
var parts []uiContentPart
if err := json.Unmarshal(raw, &parts); err == nil {
return parts
}
var encoded string
if err := json.Unmarshal(raw, &encoded); err == nil {
trimmed := strings.TrimSpace(encoded)
if strings.HasPrefix(trimmed, "[") && json.Unmarshal([]byte(trimmed), &parts) == nil {
return parts
}
}
var object struct {
Content json.RawMessage `json:"content"`
}
if err := json.Unmarshal(raw, &object); err == nil && len(object.Content) > 0 {
return extractPersistedContentParts(object.Content)
}
return nil
}
func uiAttachmentsFromMessageAssets(raw messagepkg.Message) []UIAttachment {
if len(raw.Assets) == 0 {
return nil
}
attachments := make([]UIAttachment, 0, len(raw.Assets))
for _, asset := range raw.Assets {
attachments = append(attachments, UIAttachment{
ID: strings.TrimSpace(asset.ContentHash),
Type: normalizeUIAttachmentType("", asset.Mime),
Name: strings.TrimSpace(asset.Name),
ContentHash: strings.TrimSpace(asset.ContentHash),
BotID: strings.TrimSpace(raw.BotID),
Mime: strings.TrimSpace(asset.Mime),
Size: asset.SizeBytes,
StorageKey: strings.TrimSpace(asset.StorageKey),
Metadata: asset.Metadata,
})
}
return attachments
}
func resolveUIPersistencePlatform(raw messagepkg.Message) string {
direct := strings.ToLower(strings.TrimSpace(raw.Platform))
if direct == "local" {
return ""
}
if direct != "" {
return direct
}
if raw.Metadata != nil {
if platform, ok := raw.Metadata["platform"].(string); ok {
trimmed := strings.ToLower(strings.TrimSpace(platform))
if trimmed == "local" {
return ""
}
return trimmed
}
}
return ""
}
func stripPersistedYAMLHeader(text string) string {
return strings.TrimSpace(uiMessageYAMLHeaderRe.ReplaceAllString(text, ""))
}
func stripPersistedAgentTags(text string) string {
stripped := uiMessageAgentTagsRe.ReplaceAllString(text, "")
return strings.TrimSpace(uiMessageCollapsedNewlinesRe.ReplaceAllString(stripped, "\n\n"))
}
func isHiddenCurrentConversationToolOutput(output any) bool {
typed, ok := output.(map[string]any)
if !ok {
return false
}
delivered, _ := typed["delivered"].(string)
return strings.EqualFold(strings.TrimSpace(delivered), "current_conversation")
}
+176
View File
@@ -0,0 +1,176 @@
package conversation
import "strings"
type uiTextStreamState struct {
ID int
Content string
}
type uiToolStreamState struct {
Message UIMessage
}
// UIMessageStreamConverter converts low-level stream events into complete UI messages.
type UIMessageStreamConverter struct {
nextID int
text *uiTextStreamState
reasoning *uiTextStreamState
tools map[string]*uiToolStreamState
}
// NewUIMessageStreamConverter creates a new UI stream converter.
func NewUIMessageStreamConverter() *UIMessageStreamConverter {
return &UIMessageStreamConverter{
tools: map[string]*uiToolStreamState{},
}
}
// HandleEvent updates converter state and returns zero or one complete UI messages.
func (c *UIMessageStreamConverter) HandleEvent(event UIMessageStreamEvent) []UIMessage {
switch strings.ToLower(strings.TrimSpace(event.Type)) {
case "text_start":
c.text = &uiTextStreamState{ID: c.nextMessageID()}
return nil
case "text_delta":
if c.text == nil {
c.text = &uiTextStreamState{ID: c.nextMessageID()}
}
c.text.Content += event.Delta
return []UIMessage{{
ID: c.text.ID,
Type: UIMessageText,
Content: c.text.Content,
}}
case "text_end":
c.text = nil
return nil
case "reasoning_start":
c.reasoning = &uiTextStreamState{ID: c.nextMessageID()}
return nil
case "reasoning_delta":
if c.reasoning == nil {
c.reasoning = &uiTextStreamState{ID: c.nextMessageID()}
}
c.reasoning.Content += event.Delta
return []UIMessage{{
ID: c.reasoning.ID,
Type: UIMessageReasoning,
Content: c.reasoning.Content,
}}
case "reasoning_end":
c.reasoning = nil
return nil
case "tool_call_start":
state := &uiToolStreamState{
Message: UIMessage{
ID: c.nextMessageID(),
Type: UIMessageTool,
Name: strings.TrimSpace(event.ToolName),
Input: event.Input,
ToolCallID: strings.TrimSpace(event.ToolCallID),
Running: uiBoolPtr(true),
},
}
if state.Message.ToolCallID != "" {
c.tools[state.Message.ToolCallID] = state
}
c.text = nil
return []UIMessage{state.Message}
case "tool_call_progress":
state := c.findToolState(event.ToolCallID, event.ToolName)
if state == nil {
state = &uiToolStreamState{
Message: UIMessage{
ID: c.nextMessageID(),
Type: UIMessageTool,
Name: strings.TrimSpace(event.ToolName),
Input: event.Input,
ToolCallID: strings.TrimSpace(event.ToolCallID),
Running: uiBoolPtr(true),
},
}
if state.Message.ToolCallID != "" {
c.tools[state.Message.ToolCallID] = state
}
}
state.Message.Progress = append(state.Message.Progress, event.Progress)
if event.Input != nil {
state.Message.Input = event.Input
}
return []UIMessage{cloneToolStreamMessage(state.Message)}
case "tool_call_end":
state := c.findToolState(event.ToolCallID, event.ToolName)
if state == nil {
state = &uiToolStreamState{
Message: UIMessage{
ID: c.nextMessageID(),
Type: UIMessageTool,
Name: strings.TrimSpace(event.ToolName),
Input: event.Input,
ToolCallID: strings.TrimSpace(event.ToolCallID),
},
}
}
if event.Input != nil {
state.Message.Input = event.Input
}
state.Message.Output = event.Output
state.Message.Running = uiBoolPtr(false)
if state.Message.ToolCallID != "" {
delete(c.tools, state.Message.ToolCallID)
}
return []UIMessage{cloneToolStreamMessage(state.Message)}
case "attachment_delta":
if len(event.Attachments) == 0 {
return nil
}
return []UIMessage{{
ID: c.nextMessageID(),
Type: UIMessageAttachments,
Attachments: append([]UIAttachment(nil), event.Attachments...),
}}
default:
return nil
}
}
func (c *UIMessageStreamConverter) nextMessageID() int {
id := c.nextID
c.nextID++
return id
}
func (c *UIMessageStreamConverter) findToolState(toolCallID, toolName string) *uiToolStreamState {
if trimmed := strings.TrimSpace(toolCallID); trimmed != "" {
if state, ok := c.tools[trimmed]; ok {
return state
}
}
normalizedName := strings.TrimSpace(toolName)
for _, state := range c.tools {
if strings.TrimSpace(state.Message.Name) == normalizedName {
return state
}
}
return nil
}
func cloneToolStreamMessage(message UIMessage) UIMessage {
clone := message
if len(message.Progress) > 0 {
clone.Progress = append([]any(nil), message.Progress...)
}
return clone
}
+277
View File
@@ -0,0 +1,277 @@
package conversation
import (
"encoding/json"
"testing"
"time"
messagepkg "github.com/memohai/memoh/internal/message"
)
func TestConvertMessagesToUITurnsGroupsAssistantToolAndFiltersCurrentConversationDelivery(t *testing.T) {
baseTime := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC)
messages := []messagepkg.Message{
{
ID: "user-1",
BotID: "bot-1",
SessionID: "session-1",
Role: "user",
DisplayContent: "hello",
Content: mustUIMessageJSON(t, ModelMessage{
Role: "user",
Content: mustUIRawJSON(t, "hello"),
}),
CreatedAt: baseTime,
},
{
ID: "assistant-1",
BotID: "bot-1",
SessionID: "session-1",
Role: "assistant",
Content: mustUIMessageJSON(t, ModelMessage{
Role: "assistant",
Content: mustUIRawJSON(t, []map[string]any{
{"type": "reasoning", "text": "thinking"},
{"type": "tool-call", "toolCallId": "call-1", "toolName": "read", "input": map[string]any{"path": "/tmp/a.txt"}},
{"type": "tool-call", "toolCallId": "call-2", "toolName": "send", "input": map[string]any{"message": "hi"}},
}),
}),
Assets: []messagepkg.MessageAsset{{
ContentHash: "hash-1",
Mime: "image/png",
StorageKey: "media/hash-1",
Name: "image.png",
}},
CreatedAt: baseTime.Add(1 * time.Minute),
},
{
ID: "tool-1",
BotID: "bot-1",
SessionID: "session-1",
Role: "tool",
Content: mustUIMessageJSON(t, ModelMessage{
Role: "tool",
Content: mustUIRawJSON(t, []map[string]any{
{"type": "tool-result", "toolCallId": "call-1", "toolName": "read", "result": map[string]any{"structuredContent": map[string]any{"stdout": "hello"}}},
}),
}),
CreatedAt: baseTime.Add(2 * time.Minute),
},
{
ID: "tool-2",
BotID: "bot-1",
SessionID: "session-1",
Role: "tool",
Content: mustUIMessageJSON(t, ModelMessage{
Role: "tool",
Content: mustUIRawJSON(t, []map[string]any{
{"type": "tool-result", "toolCallId": "call-2", "toolName": "send", "result": map[string]any{"delivered": "current_conversation"}},
}),
}),
CreatedAt: baseTime.Add(3 * time.Minute),
},
{
ID: "assistant-2",
BotID: "bot-1",
SessionID: "session-1",
Role: "assistant",
Content: mustUIMessageJSON(t, ModelMessage{
Role: "assistant",
Content: mustUIRawJSON(t, []map[string]any{{"type": "text", "text": "done"}}),
}),
CreatedAt: baseTime.Add(4 * time.Minute),
},
}
turns := ConvertMessagesToUITurns(messages)
if len(turns) != 2 {
t.Fatalf("expected 2 turns, got %d", len(turns))
}
userTurn := turns[0]
if userTurn.Role != "user" || userTurn.Text != "hello" {
t.Fatalf("unexpected user turn: %#v", userTurn)
}
assistantTurn := turns[1]
if assistantTurn.Role != "assistant" {
t.Fatalf("expected assistant turn, got %#v", assistantTurn)
}
if len(assistantTurn.Messages) != 4 {
t.Fatalf("expected 4 assistant messages, got %d", len(assistantTurn.Messages))
}
if assistantTurn.Messages[0].Type != UIMessageReasoning || assistantTurn.Messages[0].Content != "thinking" {
t.Fatalf("unexpected reasoning block: %#v", assistantTurn.Messages[0])
}
if assistantTurn.Messages[1].Type != UIMessageTool || assistantTurn.Messages[1].Name != "read" {
t.Fatalf("unexpected tool block: %#v", assistantTurn.Messages[1])
}
if assistantTurn.Messages[1].Running == nil || *assistantTurn.Messages[1].Running {
t.Fatalf("expected tool block to be completed: %#v", assistantTurn.Messages[1])
}
if assistantTurn.Messages[2].Type != UIMessageAttachments || len(assistantTurn.Messages[2].Attachments) != 1 {
t.Fatalf("unexpected attachment block: %#v", assistantTurn.Messages[2])
}
if assistantTurn.Messages[2].Attachments[0].Type != "image" || assistantTurn.Messages[2].Attachments[0].BotID != "bot-1" {
t.Fatalf("unexpected attachment payload: %#v", assistantTurn.Messages[2].Attachments[0])
}
if assistantTurn.Messages[3].Type != UIMessageText || assistantTurn.Messages[3].Content != "done" {
t.Fatalf("unexpected trailing text block: %#v", assistantTurn.Messages[3])
}
for _, block := range assistantTurn.Messages {
if block.Type == UIMessageTool && block.Name == "send" {
t.Fatalf("expected current conversation delivery tool to be filtered out")
}
}
}
func TestConvertMessagesToUITurnsStripsUserYAMLHeaderFallback(t *testing.T) {
now := time.Now().UTC()
turns := ConvertMessagesToUITurns([]messagepkg.Message{{
ID: "user-1",
BotID: "bot-1",
SessionID: "session-1",
Role: "user",
Content: mustUIMessageJSON(t, ModelMessage{
Role: "user",
Content: mustUIRawJSON(t, "---\nmessage-id: 1\nchannel: telegram\n---\nhello"),
}),
CreatedAt: now,
}})
if len(turns) != 1 {
t.Fatalf("expected 1 turn, got %d", len(turns))
}
if turns[0].Text != "hello" {
t.Fatalf("expected YAML header to be stripped, got %q", turns[0].Text)
}
}
func TestUIMessageStreamConverterAccumulatesToolProgress(t *testing.T) {
converter := NewUIMessageStreamConverter()
start := converter.HandleEvent(UIMessageStreamEvent{
Type: "tool_call_start",
ToolName: "exec",
ToolCallID: "call-1",
Input: map[string]any{"command": "ls"},
})
if len(start) != 1 || start[0].Type != UIMessageTool || start[0].Name != "exec" {
t.Fatalf("unexpected tool start event: %#v", start)
}
if start[0].Running == nil || !*start[0].Running {
t.Fatalf("expected running tool start, got %#v", start[0])
}
progressOne := converter.HandleEvent(UIMessageStreamEvent{
Type: "tool_call_progress",
ToolName: "exec",
ToolCallID: "call-1",
Progress: "line 1",
})
progressTwo := converter.HandleEvent(UIMessageStreamEvent{
Type: "tool_call_progress",
ToolName: "exec",
ToolCallID: "call-1",
Progress: map[string]any{"line": 2},
})
if len(progressOne) != 1 || len(progressOne[0].Progress) != 1 {
t.Fatalf("unexpected first progress snapshot: %#v", progressOne)
}
if len(progressTwo) != 1 || len(progressTwo[0].Progress) != 2 {
t.Fatalf("unexpected second progress snapshot: %#v", progressTwo)
}
if progressTwo[0].ID != start[0].ID {
t.Fatalf("expected progress snapshots to reuse tool message id")
}
end := converter.HandleEvent(UIMessageStreamEvent{
Type: "tool_call_end",
ToolName: "exec",
ToolCallID: "call-1",
Output: map[string]any{"structuredContent": map[string]any{"stdout": "done"}},
})
if len(end) != 1 || end[0].Running == nil || *end[0].Running {
t.Fatalf("expected completed tool snapshot, got %#v", end)
}
if end[0].ID != start[0].ID || len(end[0].Progress) != 2 {
t.Fatalf("expected final snapshot to keep id and progress, got %#v", end[0])
}
}
func TestUIMessageStreamConverterStartsNewTextBlockAfterTool(t *testing.T) {
converter := NewUIMessageStreamConverter()
first := converter.HandleEvent(UIMessageStreamEvent{Type: "text_delta", Delta: "hello"})
converter.HandleEvent(UIMessageStreamEvent{Type: "tool_call_start", ToolName: "read", ToolCallID: "call-1"})
converter.HandleEvent(UIMessageStreamEvent{Type: "tool_call_end", ToolName: "read", ToolCallID: "call-1"})
second := converter.HandleEvent(UIMessageStreamEvent{Type: "text_delta", Delta: "world"})
if len(first) != 1 || len(second) != 1 {
t.Fatalf("expected text snapshots, got first=%#v second=%#v", first, second)
}
if first[0].ID == second[0].ID {
t.Fatalf("expected new text block after tool call, got same id %d", first[0].ID)
}
}
func TestConvertRawModelMessagesToUIAssistantMessagesBuildsTerminalSnapshots(t *testing.T) {
raw := mustUIRawJSON(t, []ModelMessage{
{
Role: "assistant",
Content: mustUIRawJSON(t, []map[string]any{
{"type": "reasoning", "text": "thinking"},
{"type": "tool-call", "toolCallId": "call-1", "toolName": "read", "input": map[string]any{"path": "/tmp/a.txt"}},
}),
},
{
Role: "tool",
Content: mustUIRawJSON(t, []map[string]any{
{"type": "tool-result", "toolCallId": "call-1", "toolName": "read", "result": map[string]any{"structuredContent": map[string]any{"stdout": "ok"}}},
}),
},
{
Role: "assistant",
Content: mustUIRawJSON(t, []map[string]any{
{"type": "text", "text": "final answer"},
}),
},
})
messages := ConvertRawModelMessagesToUIAssistantMessages(raw)
if len(messages) != 3 {
t.Fatalf("expected 3 ui messages, got %d", len(messages))
}
if messages[0].ID != 0 || messages[0].Type != UIMessageReasoning {
t.Fatalf("unexpected first ui message: %#v", messages[0])
}
if messages[1].ID != 1 || messages[1].Type != UIMessageTool {
t.Fatalf("unexpected second ui message: %#v", messages[1])
}
if messages[1].Running == nil || *messages[1].Running {
t.Fatalf("expected terminal tool message to be completed: %#v", messages[1])
}
if messages[2].ID != 2 || messages[2].Type != UIMessageText || messages[2].Content != "final answer" {
t.Fatalf("unexpected final ui message: %#v", messages[2])
}
}
func mustUIRawJSON(t *testing.T, value any) json.RawMessage {
t.Helper()
data, err := json.Marshal(value)
if err != nil {
t.Fatalf("marshal raw json: %v", err)
}
return data
}
func mustUIMessageJSON(t *testing.T, message ModelMessage) json.RawMessage {
t.Helper()
data, err := json.Marshal(message)
if err != nil {
t.Fatalf("marshal message: %v", err)
}
return data
}