mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
fix: stop escaped history context blowup
Avoid feeding structured tool payloads back into pipeline context as doubly encoded JSON, and return readable history summaries instead of raw message blobs.
This commit is contained in:
+152
-15
@@ -295,31 +295,168 @@ func extractTextContent(raw []byte) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
if text := msg.TextContent(); text != "" {
|
||||
if text := extractVisibleHistoryText(msg.Content); text != "" {
|
||||
return text
|
||||
}
|
||||
|
||||
// assistant tool_calls: show tool names
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
names := make([]string, 0, len(msg.ToolCalls))
|
||||
for _, tc := range msg.ToolCalls {
|
||||
if tc.Function.Name != "" {
|
||||
names = append(names, tc.Function.Name)
|
||||
}
|
||||
}
|
||||
if len(names) > 0 {
|
||||
return "[tool_call: " + strings.Join(names, ", ") + "]"
|
||||
}
|
||||
if names := extractHistoryToolCallNames(msg); len(names) > 0 {
|
||||
return "[tool_call: " + strings.Join(names, ", ") + "]"
|
||||
}
|
||||
|
||||
// tool result: content may be a JSON object; stringify it
|
||||
if len(msg.Content) > 0 {
|
||||
return string(msg.Content)
|
||||
if names := extractHistoryToolResultNames(msg.Content); len(names) > 0 {
|
||||
return "[tool_result: " + strings.Join(names, ", ") + "]"
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// historyContentPart is the subset of a structured message content part that
// the history helpers decode. Fields absent from the JSON stay at their zero
// value.
type historyContentPart struct {
	// Type is the part discriminator; the helpers in this file compare it
	// case-insensitively against values such as "text", "link", "emoji",
	// "reasoning", "tool-call" and "tool-result".
	Type string `json:"type"`
	// Text is the textual payload of the part, if any.
	Text string `json:"text,omitempty"`
	// URL is rendered for parts whose type is "link".
	URL string `json:"url,omitempty"`
	// Emoji is rendered for parts whose type is "emoji".
	Emoji string `json:"emoji,omitempty"`
	// ToolName identifies the tool on "tool-call" / "tool-result" parts.
	ToolName string `json:"toolName,omitempty"`
	// Content keeps any nested "content" value undecoded.
	// NOTE(review): not read by the helpers visible here — confirm usage.
	Content json.RawMessage `json:"content,omitempty"`
}
|
||||
|
||||
func extractVisibleHistoryText(raw json.RawMessage) string {
|
||||
if len(raw) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var text string
|
||||
if err := json.Unmarshal(raw, &text); err == nil {
|
||||
trimmed := strings.TrimSpace(text)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.HasPrefix(trimmed, "[") || strings.HasPrefix(trimmed, "{") {
|
||||
if nested := extractVisibleHistoryText(json.RawMessage(trimmed)); nested != "" {
|
||||
return nested
|
||||
}
|
||||
}
|
||||
return trimmed
|
||||
}
|
||||
|
||||
parts := extractHistoryContentParts(raw)
|
||||
if len(parts) > 0 {
|
||||
lines := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
partType := strings.ToLower(strings.TrimSpace(part.Type))
|
||||
switch {
|
||||
case partType == "reasoning", partType == "tool-call", partType == "tool-result":
|
||||
continue
|
||||
case partType == "text" && strings.TrimSpace(part.Text) != "":
|
||||
lines = append(lines, strings.TrimSpace(part.Text))
|
||||
case partType == "link" && strings.TrimSpace(part.URL) != "":
|
||||
lines = append(lines, strings.TrimSpace(part.URL))
|
||||
case partType == "emoji" && strings.TrimSpace(part.Emoji) != "":
|
||||
lines = append(lines, strings.TrimSpace(part.Emoji))
|
||||
case strings.TrimSpace(part.Text) != "":
|
||||
lines = append(lines, strings.TrimSpace(part.Text))
|
||||
}
|
||||
}
|
||||
return strings.TrimSpace(strings.Join(lines, "\n"))
|
||||
}
|
||||
|
||||
var object map[string]any
|
||||
if err := json.Unmarshal(raw, &object); err == nil {
|
||||
if value, ok := object["text"].(string); ok {
|
||||
return strings.TrimSpace(value)
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func extractHistoryToolCallNames(msg conversation.ModelMessage) []string {
|
||||
names := make([]string, 0, len(msg.ToolCalls))
|
||||
for _, part := range extractHistoryContentParts(msg.Content) {
|
||||
if strings.ToLower(strings.TrimSpace(part.Type)) != "tool-call" {
|
||||
continue
|
||||
}
|
||||
if name := strings.TrimSpace(part.ToolName); name != "" {
|
||||
names = append(names, name)
|
||||
}
|
||||
}
|
||||
if len(names) > 0 {
|
||||
return dedupeHistoryNames(names)
|
||||
}
|
||||
|
||||
for _, tc := range msg.ToolCalls {
|
||||
if name := strings.TrimSpace(tc.Function.Name); name != "" {
|
||||
names = append(names, name)
|
||||
}
|
||||
}
|
||||
return dedupeHistoryNames(names)
|
||||
}
|
||||
|
||||
func extractHistoryToolResultNames(raw json.RawMessage) []string {
|
||||
parts := extractHistoryContentParts(raw)
|
||||
if len(parts) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
names := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
if strings.ToLower(strings.TrimSpace(part.Type)) != "tool-result" {
|
||||
continue
|
||||
}
|
||||
if name := strings.TrimSpace(part.ToolName); name != "" {
|
||||
names = append(names, name)
|
||||
}
|
||||
}
|
||||
return dedupeHistoryNames(names)
|
||||
}
|
||||
|
||||
func extractHistoryContentParts(raw json.RawMessage) []historyContentPart {
|
||||
if len(raw) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var parts []historyContentPart
|
||||
if err := json.Unmarshal(raw, &parts); err == nil {
|
||||
return parts
|
||||
}
|
||||
|
||||
var encoded string
|
||||
if err := json.Unmarshal(raw, &encoded); err == nil {
|
||||
trimmed := strings.TrimSpace(encoded)
|
||||
if strings.HasPrefix(trimmed, "[") || strings.HasPrefix(trimmed, "{") {
|
||||
return extractHistoryContentParts(json.RawMessage(trimmed))
|
||||
}
|
||||
}
|
||||
|
||||
var object struct {
|
||||
Content json.RawMessage `json:"content"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &object); err == nil && len(object.Content) > 0 {
|
||||
return extractHistoryContentParts(object.Content)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// dedupeHistoryNames trims every name, drops blanks and repeats, and keeps
// the first occurrence of each remaining name in its original order. An
// empty input yields nil.
func dedupeHistoryNames(names []string) []string {
	if len(names) == 0 {
		return nil
	}

	unique := make([]string, 0, len(names))
	known := make(map[string]struct{}, len(names))
	for _, entry := range names {
		candidate := strings.TrimSpace(entry)
		_, duplicate := known[candidate]
		if candidate == "" || duplicate {
			continue
		}
		known[candidate] = struct{}{}
		unique = append(unique, candidate)
	}
	return unique
}
|
||||
|
||||
var timeFormats = []string{
|
||||
time.RFC3339,
|
||||
"2006-01-02T15:04:05",
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/memohai/memoh/internal/conversation"
|
||||
)
|
||||
|
||||
func TestExtractTextContentSummarizesAssistantToolCalls(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
content, err := json.Marshal([]map[string]any{
|
||||
{"type": "reasoning", "text": "thinking"},
|
||||
{"type": "tool-call", "toolName": "read", "toolCallId": "call-1", "input": map[string]any{"path": "/tmp/a.txt"}},
|
||||
{"type": "tool-call", "toolName": "edit", "toolCallId": "call-2", "input": map[string]any{"path": "/tmp/a.txt"}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("marshal content: %v", err)
|
||||
}
|
||||
|
||||
raw, err := json.Marshal(conversation.ModelMessage{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("marshal model message: %v", err)
|
||||
}
|
||||
|
||||
got := extractTextContent(raw)
|
||||
want := "[tool_call: read, edit]"
|
||||
if got != want {
|
||||
t.Fatalf("extractTextContent() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractTextContentSummarizesToolResults(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
content, err := json.Marshal([]map[string]any{
|
||||
{"type": "tool-result", "toolName": "search_messages", "toolCallId": "call-1", "result": map[string]any{"count": 3}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("marshal content: %v", err)
|
||||
}
|
||||
|
||||
raw, err := json.Marshal(conversation.ModelMessage{
|
||||
Role: "tool",
|
||||
Content: content,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("marshal model message: %v", err)
|
||||
}
|
||||
|
||||
got := extractTextContent(raw)
|
||||
want := "[tool_result: search_messages]"
|
||||
if got != want {
|
||||
t.Fatalf("extractTextContent() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
@@ -325,22 +325,11 @@ func (r *Resolver) loadTurnResponses(ctx context.Context, sessionID string) []pi
|
||||
}
|
||||
var trs []pipelinepkg.TurnResponseEntry
|
||||
for _, m := range msgs {
|
||||
if m.Role != "assistant" && m.Role != "tool" {
|
||||
entry, ok := pipelinepkg.DecodeTurnResponseEntry(m)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
var mm conversation.ModelMessage
|
||||
if err := json.Unmarshal(m.Content, &mm); err != nil {
|
||||
continue
|
||||
}
|
||||
contentStr := ""
|
||||
if mm.Content != nil {
|
||||
contentStr = string(mm.Content)
|
||||
}
|
||||
trs = append(trs, pipelinepkg.TurnResponseEntry{
|
||||
RequestedAtMs: m.CreatedAt.UnixMilli(),
|
||||
Role: m.Role,
|
||||
Content: contentStr,
|
||||
})
|
||||
trs = append(trs, entry)
|
||||
}
|
||||
return trs
|
||||
}
|
||||
@@ -356,8 +345,10 @@ func stripToolMessages(messages []conversation.ModelMessage) []conversation.Mode
|
||||
if strings.EqualFold(role, "tool") {
|
||||
continue
|
||||
}
|
||||
// Remove assistant messages that contain tool calls (without text content).
|
||||
if strings.EqualFold(role, "assistant") && len(m.ToolCalls) > 0 {
|
||||
// Remove assistant messages that only contain tool calls / reasoning with
|
||||
// no visible text. Tool-call metadata may live either in ToolCalls or in
|
||||
// structured content parts.
|
||||
if strings.EqualFold(role, "assistant") && hasToolCallContent(m) {
|
||||
text := m.TextContent()
|
||||
if strings.TrimSpace(text) == "" {
|
||||
continue
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package flow
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/memohai/memoh/internal/conversation"
|
||||
@@ -170,3 +171,33 @@ func TestTrimMessagesByTokens_EstimatesFallback(t *testing.T) {
|
||||
t.Fatalf("expected [system notice, assistant message], got %d messages: %+v", len(trimmed), trimmed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripToolMessages_RemovesAssistantToolCallContentParts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
content, err := json.Marshal([]map[string]any{
|
||||
{"type": "reasoning", "text": "thinking"},
|
||||
{"type": "tool-call", "toolName": "read", "toolCallId": "call-1", "input": map[string]any{"path": "/tmp/a.txt"}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("marshal content: %v", err)
|
||||
}
|
||||
|
||||
filtered := stripToolMessages([]conversation.ModelMessage{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: conversation.NewTextContent("保留这条消息"),
|
||||
},
|
||||
})
|
||||
|
||||
if len(filtered) != 1 {
|
||||
t.Fatalf("expected 1 message after filtering, got %d", len(filtered))
|
||||
}
|
||||
if filtered[0].TextContent() != "保留这条消息" {
|
||||
t.Fatalf("unexpected remaining message: %+v", filtered[0])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
|
||||
agentpkg "github.com/memohai/memoh/internal/agent"
|
||||
"github.com/memohai/memoh/internal/channel"
|
||||
"github.com/memohai/memoh/internal/conversation"
|
||||
messagepkg "github.com/memohai/memoh/internal/message"
|
||||
sessionpkg "github.com/memohai/memoh/internal/session"
|
||||
)
|
||||
@@ -367,22 +366,11 @@ func (d *DiscussDriver) loadTurnResponses(ctx context.Context, sessionID string)
|
||||
|
||||
var trs []TurnResponseEntry
|
||||
for _, m := range msgs {
|
||||
if m.Role != "assistant" && m.Role != "tool" {
|
||||
entry, ok := DecodeTurnResponseEntry(m)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
var mm conversation.ModelMessage
|
||||
if err := json.Unmarshal(m.Content, &mm); err != nil {
|
||||
continue
|
||||
}
|
||||
contentStr := mm.TextContent()
|
||||
if contentStr == "" {
|
||||
continue
|
||||
}
|
||||
trs = append(trs, TurnResponseEntry{
|
||||
RequestedAtMs: m.CreatedAt.UnixMilli(),
|
||||
Role: m.Role,
|
||||
Content: contentStr,
|
||||
})
|
||||
trs = append(trs, entry)
|
||||
}
|
||||
return trs
|
||||
}
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
package pipeline
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/memohai/memoh/internal/conversation"
|
||||
messagepkg "github.com/memohai/memoh/internal/message"
|
||||
)
|
||||
|
||||
// DecodeTurnResponseEntry converts a persisted bot message into a TR entry for
|
||||
// pipeline context composition. Only visible text is preserved; structured
|
||||
// tool-call / tool-result payloads are skipped to avoid re-injecting raw JSON
|
||||
// into later prompts.
|
||||
func DecodeTurnResponseEntry(msg messagepkg.Message) (TurnResponseEntry, bool) {
|
||||
role := strings.TrimSpace(msg.Role)
|
||||
if role != "assistant" && role != "tool" {
|
||||
return TurnResponseEntry{}, false
|
||||
}
|
||||
|
||||
var modelMsg conversation.ModelMessage
|
||||
if err := json.Unmarshal(msg.Content, &modelMsg); err != nil {
|
||||
return TurnResponseEntry{}, false
|
||||
}
|
||||
|
||||
content := strings.TrimSpace(modelMsg.TextContent())
|
||||
if content == "" {
|
||||
return TurnResponseEntry{}, false
|
||||
}
|
||||
|
||||
return TurnResponseEntry{
|
||||
RequestedAtMs: msg.CreatedAt.UnixMilli(),
|
||||
Role: msg.Role,
|
||||
Content: content,
|
||||
}, true
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
package pipeline
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/memohai/memoh/internal/conversation"
|
||||
messagepkg "github.com/memohai/memoh/internal/message"
|
||||
)
|
||||
|
||||
func TestDecodeTurnResponseEntryUsesVisibleText(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
content, err := json.Marshal([]map[string]any{
|
||||
{"type": "reasoning", "text": "thinking"},
|
||||
{"type": "text", "text": "任务完成"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("marshal content: %v", err)
|
||||
}
|
||||
|
||||
modelMessage, err := json.Marshal(conversation.ModelMessage{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("marshal model message: %v", err)
|
||||
}
|
||||
|
||||
entry, ok := DecodeTurnResponseEntry(messagepkg.Message{
|
||||
Role: "assistant",
|
||||
Content: modelMessage,
|
||||
CreatedAt: time.Unix(1710000000, 0).UTC(),
|
||||
})
|
||||
if !ok {
|
||||
t.Fatal("expected turn response entry")
|
||||
}
|
||||
if entry.Content != "任务完成" {
|
||||
t.Fatalf("content = %q, want %q", entry.Content, "任务完成")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeTurnResponseEntrySkipsToolCallOnlyPayload(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
content, err := json.Marshal([]map[string]any{
|
||||
{"type": "reasoning", "text": "thinking"},
|
||||
{"type": "tool-call", "toolName": "read", "toolCallId": "call-1", "input": map[string]any{"path": "/tmp/a.txt"}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("marshal content: %v", err)
|
||||
}
|
||||
|
||||
modelMessage, err := json.Marshal(conversation.ModelMessage{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("marshal model message: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := DecodeTurnResponseEntry(messagepkg.Message{
|
||||
Role: "assistant",
|
||||
Content: modelMessage,
|
||||
CreatedAt: time.Unix(1710000000, 0).UTC(),
|
||||
}); ok {
|
||||
t.Fatal("expected tool-call-only payload to be skipped")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user