From d05ba8956abe9ce0ac2238c8d416231a0a10d50f Mon Sep 17 00:00:00 2001 From: Acbox Date: Thu, 23 Apr 2026 17:56:09 +0800 Subject: [PATCH] fix: stop escaped history context blowup Avoid feeding structured tool payloads back into pipeline context as doubly encoded JSON, and return readable history summaries instead of raw message blobs. --- internal/agent/tools/history.go | 167 ++++++++++++++++-- internal/agent/tools/history_test.go | 60 +++++++ .../conversation/flow/resolver_history.go | 23 +-- .../conversation/flow/resolver_trim_test.go | 31 ++++ internal/pipeline/driver.go | 18 +- internal/pipeline/turn_response.go | 36 ++++ internal/pipeline/turn_response_test.go | 70 ++++++++ 7 files changed, 359 insertions(+), 46 deletions(-) create mode 100644 internal/agent/tools/history_test.go create mode 100644 internal/pipeline/turn_response.go create mode 100644 internal/pipeline/turn_response_test.go diff --git a/internal/agent/tools/history.go b/internal/agent/tools/history.go index f4354f11..74e2eeab 100644 --- a/internal/agent/tools/history.go +++ b/internal/agent/tools/history.go @@ -295,31 +295,168 @@ func extractTextContent(raw []byte) string { return "" } - if text := msg.TextContent(); text != "" { + if text := extractVisibleHistoryText(msg.Content); text != "" { return text } - // assistant tool_calls: show tool names - if len(msg.ToolCalls) > 0 { - names := make([]string, 0, len(msg.ToolCalls)) - for _, tc := range msg.ToolCalls { - if tc.Function.Name != "" { - names = append(names, tc.Function.Name) - } - } - if len(names) > 0 { - return "[tool_call: " + strings.Join(names, ", ") + "]" - } + if names := extractHistoryToolCallNames(msg); len(names) > 0 { + return "[tool_call: " + strings.Join(names, ", ") + "]" } - // tool result: content may be a JSON object; stringify it - if len(msg.Content) > 0 { - return string(msg.Content) + if names := extractHistoryToolResultNames(msg.Content); len(names) > 0 { + return "[tool_result: " + strings.Join(names, ", ") + "]" } return "" } +type historyContentPart struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + URL string `json:"url,omitempty"` + Emoji string `json:"emoji,omitempty"` + ToolName string `json:"toolName,omitempty"` + Content json.RawMessage `json:"content,omitempty"` +} + +func extractVisibleHistoryText(raw json.RawMessage) string { + if len(raw) == 0 { + return "" + } + + var text string + if err := json.Unmarshal(raw, &text); err == nil { + trimmed := strings.TrimSpace(text) + if trimmed == "" { + return "" + } + if strings.HasPrefix(trimmed, "[") || strings.HasPrefix(trimmed, "{") { + if nested := extractVisibleHistoryText(json.RawMessage(trimmed)); nested != "" { + return nested + } + } + return trimmed + } + + parts := extractHistoryContentParts(raw) + if len(parts) > 0 { + lines := make([]string, 0, len(parts)) + for _, part := range parts { + partType := strings.ToLower(strings.TrimSpace(part.Type)) + switch { + case partType == "reasoning", partType == "tool-call", partType == "tool-result": + continue + case partType == "text" && strings.TrimSpace(part.Text) != "": + lines = append(lines, strings.TrimSpace(part.Text)) + case partType == "link" && strings.TrimSpace(part.URL) != "": + lines = append(lines, strings.TrimSpace(part.URL)) + case partType == "emoji" && strings.TrimSpace(part.Emoji) != "": + lines = append(lines, strings.TrimSpace(part.Emoji)) + case strings.TrimSpace(part.Text) != "": + lines = append(lines, strings.TrimSpace(part.Text)) + } + } + return strings.TrimSpace(strings.Join(lines, "\n")) + } + + var object map[string]any + if err := json.Unmarshal(raw, &object); err == nil { + if value, ok := object["text"].(string); ok { + return strings.TrimSpace(value) + } + } + + return "" +} + +func extractHistoryToolCallNames(msg conversation.ModelMessage) []string { + names := make([]string, 0, len(msg.ToolCalls)) + for _, part := range extractHistoryContentParts(msg.Content) { + if strings.ToLower(strings.TrimSpace(part.Type)) != "tool-call" { + continue + } + if name := strings.TrimSpace(part.ToolName); name != "" { + names = append(names, name) + } + } + if len(names) > 0 { + return dedupeHistoryNames(names) + } + + for _, tc := range msg.ToolCalls { + if name := strings.TrimSpace(tc.Function.Name); name != "" { + names = append(names, name) + } + } + return dedupeHistoryNames(names) +} + +func extractHistoryToolResultNames(raw json.RawMessage) []string { + parts := extractHistoryContentParts(raw) + if len(parts) == 0 { + return nil + } + + names := make([]string, 0, len(parts)) + for _, part := range parts { + if strings.ToLower(strings.TrimSpace(part.Type)) != "tool-result" { + continue + } + if name := strings.TrimSpace(part.ToolName); name != "" { + names = append(names, name) + } + } + return dedupeHistoryNames(names) +} + +func extractHistoryContentParts(raw json.RawMessage) []historyContentPart { + if len(raw) == 0 { + return nil + } + + var parts []historyContentPart + if err := json.Unmarshal(raw, &parts); err == nil { + return parts + } + + var encoded string + if err := json.Unmarshal(raw, &encoded); err == nil { + trimmed := strings.TrimSpace(encoded) + if strings.HasPrefix(trimmed, "[") || strings.HasPrefix(trimmed, "{") { + return extractHistoryContentParts(json.RawMessage(trimmed)) + } + } + + var object struct { + Content json.RawMessage `json:"content"` + } + if err := json.Unmarshal(raw, &object); err == nil && len(object.Content) > 0 { + return extractHistoryContentParts(object.Content) + } + + return nil +} + +func dedupeHistoryNames(names []string) []string { + if len(names) == 0 { + return nil + } + seen := make(map[string]struct{}, len(names)) + out := make([]string, 0, len(names)) + for _, name := range names { + trimmed := strings.TrimSpace(name) + if trimmed == "" { + continue + } + if _, ok := seen[trimmed]; ok { + continue + } + seen[trimmed] = struct{}{} + out = append(out, trimmed) + } + return out +} + var timeFormats = []string{ time.RFC3339, "2006-01-02T15:04:05", diff --git a/internal/agent/tools/history_test.go b/internal/agent/tools/history_test.go new file mode 100644 index 00000000..dda79892 --- /dev/null +++ b/internal/agent/tools/history_test.go @@ -0,0 +1,60 @@ +package tools + +import ( + "encoding/json" + "testing" + + "github.com/memohai/memoh/internal/conversation" +) + +func TestExtractTextContentSummarizesAssistantToolCalls(t *testing.T) { + t.Parallel() + + content, err := json.Marshal([]map[string]any{ + {"type": "reasoning", "text": "thinking"}, + {"type": "tool-call", "toolName": "read", "toolCallId": "call-1", "input": map[string]any{"path": "/tmp/a.txt"}}, + {"type": "tool-call", "toolName": "edit", "toolCallId": "call-2", "input": map[string]any{"path": "/tmp/a.txt"}}, + }) + if err != nil { + t.Fatalf("marshal content: %v", err) + } + + raw, err := json.Marshal(conversation.ModelMessage{ + Role: "assistant", + Content: content, + }) + if err != nil { + t.Fatalf("marshal model message: %v", err) + } + + got := extractTextContent(raw) + want := "[tool_call: read, edit]" + if got != want { + t.Fatalf("extractTextContent() = %q, want %q", got, want) + } +} + +func TestExtractTextContentSummarizesToolResults(t *testing.T) { + t.Parallel() + + content, err := json.Marshal([]map[string]any{ + {"type": "tool-result", "toolName": "search_messages", "toolCallId": "call-1", "result": map[string]any{"count": 3}}, + }) + if err != nil { + t.Fatalf("marshal content: %v", err) + } + + raw, err := json.Marshal(conversation.ModelMessage{ + Role: "tool", + Content: content, + }) + if err != nil { + t.Fatalf("marshal model message: %v", err) + } + + got := extractTextContent(raw) + want := "[tool_result: search_messages]" + if got != want { + t.Fatalf("extractTextContent() = %q, want %q", got, want) + } +} diff --git a/internal/conversation/flow/resolver_history.go b/internal/conversation/flow/resolver_history.go index 56ca8f8d..5efa8e90 100644 --- a/internal/conversation/flow/resolver_history.go +++ b/internal/conversation/flow/resolver_history.go @@ -325,22 +325,11 @@ func (r *Resolver) loadTurnResponses(ctx context.Context, sessionID string) []pi } var trs []pipelinepkg.TurnResponseEntry for _, m := range msgs { - if m.Role != "assistant" && m.Role != "tool" { + entry, ok := pipelinepkg.DecodeTurnResponseEntry(m) + if !ok { continue } - var mm conversation.ModelMessage - if err := json.Unmarshal(m.Content, &mm); err != nil { - continue - } - contentStr := "" - if mm.Content != nil { - contentStr = string(mm.Content) - } - trs = append(trs, pipelinepkg.TurnResponseEntry{ - RequestedAtMs: m.CreatedAt.UnixMilli(), - Role: m.Role, - Content: contentStr, - }) + trs = append(trs, entry) } return trs } @@ -356,8 +345,10 @@ func stripToolMessages(messages []conversation.ModelMessage) []conversation.Mode if strings.EqualFold(role, "tool") { continue } - // Remove assistant messages that contain tool calls (without text content). - if strings.EqualFold(role, "assistant") && len(m.ToolCalls) > 0 { + // Remove assistant messages that only contain tool calls / reasoning with + // no visible text. Tool-call metadata may live either in ToolCalls or in + // structured content parts. + if strings.EqualFold(role, "assistant") && hasToolCallContent(m) { text := m.TextContent() if strings.TrimSpace(text) == "" { continue diff --git a/internal/conversation/flow/resolver_trim_test.go b/internal/conversation/flow/resolver_trim_test.go index 41016326..480687a2 100644 --- a/internal/conversation/flow/resolver_trim_test.go +++ b/internal/conversation/flow/resolver_trim_test.go @@ -1,6 +1,7 @@ package flow import ( + "encoding/json" "testing" "github.com/memohai/memoh/internal/conversation" @@ -170,3 +171,33 @@ func TestTrimMessagesByTokens_EstimatesFallback(t *testing.T) { t.Fatalf("expected [system notice, assistant message], got %d messages: %+v", len(trimmed), trimmed) } } + +func TestStripToolMessages_RemovesAssistantToolCallContentParts(t *testing.T) { + t.Parallel() + + content, err := json.Marshal([]map[string]any{ + {"type": "reasoning", "text": "thinking"}, + {"type": "tool-call", "toolName": "read", "toolCallId": "call-1", "input": map[string]any{"path": "/tmp/a.txt"}}, + }) + if err != nil { + t.Fatalf("marshal content: %v", err) + } + + filtered := stripToolMessages([]conversation.ModelMessage{ + { + Role: "assistant", + Content: content, + }, + { + Role: "assistant", + Content: conversation.NewTextContent("保留这条消息"), + }, + }) + + if len(filtered) != 1 { + t.Fatalf("expected 1 message after filtering, got %d", len(filtered)) + } + if filtered[0].TextContent() != "保留这条消息" { + t.Fatalf("unexpected remaining message: %+v", filtered[0]) + } +} diff --git a/internal/pipeline/driver.go b/internal/pipeline/driver.go index 276b0b6a..9c15a81b 100644 --- a/internal/pipeline/driver.go +++ b/internal/pipeline/driver.go @@ -12,7 +12,6 @@ import ( agentpkg "github.com/memohai/memoh/internal/agent" "github.com/memohai/memoh/internal/channel" - "github.com/memohai/memoh/internal/conversation" messagepkg "github.com/memohai/memoh/internal/message" sessionpkg "github.com/memohai/memoh/internal/session" ) @@ -367,22 +366,11 @@ func (d *DiscussDriver) loadTurnResponses(ctx context.Context, sessionID string) var trs []TurnResponseEntry for _, m := range msgs { - if m.Role != "assistant" && m.Role != "tool" { + entry, ok := DecodeTurnResponseEntry(m) + if !ok { continue } - var mm conversation.ModelMessage - if err := json.Unmarshal(m.Content, &mm); err != nil { - continue - } - contentStr := mm.TextContent() - if contentStr == "" { - continue - } - trs = append(trs, TurnResponseEntry{ - RequestedAtMs: m.CreatedAt.UnixMilli(), - Role: m.Role, - Content: contentStr, - }) + trs = append(trs, entry) } return trs } diff --git a/internal/pipeline/turn_response.go b/internal/pipeline/turn_response.go new file mode 100644 index 00000000..eac3f502 --- /dev/null +++ b/internal/pipeline/turn_response.go @@ -0,0 +1,36 @@ +package pipeline + +import ( + "encoding/json" + "strings" + + "github.com/memohai/memoh/internal/conversation" + messagepkg "github.com/memohai/memoh/internal/message" +) + +// DecodeTurnResponseEntry converts a persisted bot message into a TR entry for +// pipeline context composition. Only visible text is preserved; structured +// tool-call / tool-result payloads are skipped to avoid re-injecting raw JSON +// into later prompts. +func DecodeTurnResponseEntry(msg messagepkg.Message) (TurnResponseEntry, bool) { + role := strings.TrimSpace(msg.Role) + if role != "assistant" && role != "tool" { + return TurnResponseEntry{}, false + } + + var modelMsg conversation.ModelMessage + if err := json.Unmarshal(msg.Content, &modelMsg); err != nil { + return TurnResponseEntry{}, false + } + + content := strings.TrimSpace(modelMsg.TextContent()) + if content == "" { + return TurnResponseEntry{}, false + } + + return TurnResponseEntry{ + RequestedAtMs: msg.CreatedAt.UnixMilli(), + Role: msg.Role, + Content: content, + }, true +} diff --git a/internal/pipeline/turn_response_test.go b/internal/pipeline/turn_response_test.go new file mode 100644 index 00000000..8643aa9b --- /dev/null +++ b/internal/pipeline/turn_response_test.go @@ -0,0 +1,70 @@ +package pipeline + +import ( + "encoding/json" + "testing" + "time" + + "github.com/memohai/memoh/internal/conversation" + messagepkg "github.com/memohai/memoh/internal/message" +) + +func TestDecodeTurnResponseEntryUsesVisibleText(t *testing.T) { + t.Parallel() + + content, err := json.Marshal([]map[string]any{ + {"type": "reasoning", "text": "thinking"}, + {"type": "text", "text": "任务完成"}, + }) + if err != nil { + t.Fatalf("marshal content: %v", err) + } + + modelMessage, err := json.Marshal(conversation.ModelMessage{ + Role: "assistant", + Content: content, + }) + if err != nil { + t.Fatalf("marshal model message: %v", err) + } + + entry, ok := DecodeTurnResponseEntry(messagepkg.Message{ + Role: "assistant", + Content: modelMessage, + CreatedAt: time.Unix(1710000000, 0).UTC(), + }) + if !ok { + t.Fatal("expected turn response entry") + } + if entry.Content != "任务完成" { + t.Fatalf("content = %q, want %q", entry.Content, "任务完成") + } +} + +func TestDecodeTurnResponseEntrySkipsToolCallOnlyPayload(t *testing.T) { + t.Parallel() + + content, err := json.Marshal([]map[string]any{ + {"type": "reasoning", "text": "thinking"}, + {"type": "tool-call", "toolName": "read", "toolCallId": "call-1", "input": map[string]any{"path": "/tmp/a.txt"}}, + }) + if err != nil { + t.Fatalf("marshal content: %v", err) + } + + modelMessage, err := json.Marshal(conversation.ModelMessage{ + Role: "assistant", + Content: content, + }) + if err != nil { + t.Fatalf("marshal model message: %v", err) + } + + if _, ok := DecodeTurnResponseEntry(messagepkg.Message{ + Role: "assistant", + Content: modelMessage, + CreatedAt: time.Unix(1710000000, 0).UTC(), + }); ok { + t.Fatal("expected tool-call-only payload to be skipped") + } +}