fix(pipeline): preserve tool calls and anchor driver cursor correctly

The discuss driver's RC+TR composition had three compounding bugs that
caused old tasks to be re-answered after idle timeouts and made the LLM
blind to its own prior tool usage:

- DecodeTurnResponseEntry only kept visible text via TextContent(), so
  assistant steps carrying only tool_call parts (the first half of every
  tool round) were dropped entirely. Rewritten to render tool_call and
  tool_result parts as <tool_call>/<tool_result> tags, covering both
  Vercel-style content parts and legacy OpenAI ToolCalls/role=tool
  envelopes. Reasoning parts remain stripped to avoid re-injection.

- loadTurnResponses hard-capped TRs at 24h while RC is replayed in full
  from the events table, producing asymmetric context (user messages
  from day 1 visible, matching bot replies missing). The cap is removed;
  any size-bound trimming belongs in compaction, not here.

- lastProcessedMs lived only in memory and was set to time.Now() at turn
  end. After the 10-minute idle timeout, the goroutine exited and the
  next turn started with cursor=0, treating the entire history as new
  traffic. Now initialised from the latest TR's requested_at on cold
  start, and advanced to max(consumed RC.ReceivedAtMs) per turn so that
  messages arriving mid-generation trigger a follow-up round instead of
  being wrongly marked processed.
This commit is contained in:
Acbox
2026-04-23 20:05:35 +08:00
parent defddc2257
commit e94b4b58ed
4 changed files with 551 additions and 17 deletions
+61 -5
View File
@@ -230,6 +230,25 @@ func (d *DiscussDriver) handleReplyWithAgent(ctx context.Context, sess *discussS
trs := d.loadTurnResponses(ctx, cfg.SessionID)
// Cold-start / post-idle initialisation: if we haven't processed anything
// in this goroutine's lifetime yet, anchor `lastProcessedMs` to the most
// recent TR's requested_at. Any RC segment strictly older than that has
// already been "seen" by a prior LLM call (whose response is in the TR
// stream), so it should not retrigger a reply. Without this anchor, every
// idle-timeout restart would treat the entire session history as brand
// new external traffic and re-answer it.
if sess.lastProcessedMs == 0 {
sess.lastProcessedMs = anchorFromTRs(trs)
}
// Re-evaluate the trigger condition now that lastProcessedMs is anchored.
// The outer loop used lastProcessedMs=0 to allow first-time dispatch into
// this function; after initialisation, we must verify there's actually a
// new external event past the anchor before spending an LLM call.
if LatestExternalEventMs(rc, sess.lastProcessedMs) == 0 {
return
}
composed := ComposeContext(rc, trs, "")
if composed == nil {
return
@@ -285,8 +304,6 @@ func (d *DiscussDriver) handleReplyWithAgent(ctx context.Context, sess *discussS
}
}
now := time.Now()
if d.deps.Resolver != nil && len(finalMessages) > 0 {
var sdkMsgs []sdk.Message
if json.Unmarshal(finalMessages, &sdkMsgs) == nil && len(sdkMsgs) > 0 {
@@ -299,7 +316,39 @@ func (d *DiscussDriver) handleReplyWithAgent(ctx context.Context, sess *discussS
}
}
sess.lastProcessedMs = now.UnixMilli()
// Advance the cursor to the latest RC segment actually consumed in this
// turn (not wall-clock time). Messages that arrive DURING LLM generation
// will land in a newer RC with ReceivedAtMs > this cursor and correctly
// trigger another round; wall-clock would wrongly mark them processed.
consumedMs := latestRCReceivedAtMs(rc)
if consumedMs > sess.lastProcessedMs {
sess.lastProcessedMs = consumedMs
}
}
// latestRCReceivedAtMs returns the maximum ReceivedAtMs across all segments
// in the given RC, or 0 if the RC is empty.
func latestRCReceivedAtMs(rc RenderedContext) int64 {
	latest := int64(0)
	for i := range rc {
		if ts := rc[i].ReceivedAtMs; ts > latest {
			latest = ts
		}
	}
	return latest
}
// anchorFromTRs returns the maximum RequestedAtMs across a TR slice. Used
// by the cold-start initialisation of `lastProcessedMs` so that RC segments
// older than the latest persisted bot response are not re-answered.
func anchorFromTRs(trs []TurnResponseEntry) int64 {
	anchor := int64(0)
	for i := range trs {
		if ts := trs[i].RequestedAtMs; ts > anchor {
			anchor = ts
		}
	}
	return anchor
}
// broadcastDiscussEvent forwards an agent stream event to the RouteHub so the
@@ -352,13 +401,20 @@ func agentEventToChannelEvent(e agentpkg.StreamEvent) (channel.StreamEvent, bool
}
}
// loadTurnResponses loads every assistant/tool message ever persisted for
// this session and decodes them into TR entries. There is no time-based cut
// off on purpose: truncating TRs while RC is replayed in full from the events
// table creates an asymmetric context (user messages visible, the bot's own
// earlier replies missing) that confuses both the LLM and loop-detection.
// Any size-bound trimming should happen later via compaction, not here.
func (d *DiscussDriver) loadTurnResponses(ctx context.Context, sessionID string) []TurnResponseEntry {
if d.deps.MessageService == nil {
return nil
}
since := time.Now().UTC().Add(-24 * time.Hour)
msgs, err := d.deps.MessageService.ListActiveSinceBySession(ctx, sessionID, since)
// time.Unix(0, 0) is the Unix epoch; the underlying query uses
// `created_at >= $1`, so this effectively loads every session message.
msgs, err := d.deps.MessageService.ListActiveSinceBySession(ctx, sessionID, time.Unix(0, 0).UTC())
if err != nil {
d.logger.Warn("load TRs failed", slog.String("session_id", sessionID), slog.Any("error", err))
return nil
+103
View File
@@ -208,6 +208,109 @@ func TestHandleReplyWithAgent_NoInlineWhenNoVision(t *testing.T) {
}
}
// TestAnchorFromTRs verifies the cold-start anchor helper: a nil slice
// yields 0, and a populated slice yields the maximum RequestedAtMs
// regardless of element order.
func TestAnchorFromTRs(t *testing.T) {
	t.Parallel()
	if got := anchorFromTRs(nil); got != 0 {
		t.Fatalf("empty TRs anchor = %d, want 0", got)
	}
	trs := []TurnResponseEntry{
		{RequestedAtMs: 100},
		{RequestedAtMs: 500},
		{RequestedAtMs: 300},
	}
	if got := anchorFromTRs(trs); got != 500 {
		t.Fatalf("anchor = %d, want 500", got)
	}
}
// TestLatestRCReceivedAtMs verifies the RC cursor helper: a nil RC yields 0,
// and a populated RC yields the maximum ReceivedAtMs regardless of segment
// order or whether a segment is the bot's own (IsMyself).
func TestLatestRCReceivedAtMs(t *testing.T) {
	t.Parallel()
	if got := latestRCReceivedAtMs(nil); got != 0 {
		t.Fatalf("empty RC = %d, want 0", got)
	}
	rc := RenderedContext{
		{ReceivedAtMs: 100},
		{ReceivedAtMs: 900},
		{ReceivedAtMs: 500, IsMyself: true},
	}
	if got := latestRCReceivedAtMs(rc); got != 900 {
		t.Fatalf("latest = %d, want 900", got)
	}
}
// TestHandleReplyWithAgent_ColdStartAnchoredByTR simulates idle-timeout
// restart: the session's in-memory lastProcessedMs is 0, but RC replay has
// brought back old user messages that were already answered in prior
// LLM rounds (represented by TRs). The driver MUST NOT re-answer them.
func TestHandleReplyWithAgent_ColdStartAnchoredByTR(t *testing.T) {
	// A single RC segment whose ReceivedAtMs (100) predates the anchor
	// set below (200), i.e. traffic that was already handled.
	rc := RenderedContext{
		{
			ReceivedAtMs: 100,
			Content:      []RenderedContentPiece{{Type: "text", Text: `<message id="old">task 1</message>`}},
		},
	}
	fakeAgent := &fakeDiscussStreamer{}
	resolver := &fakeRunConfigResolver{}
	// MessageService is deliberately nil: loadTurnResponses will return no
	// TRs, so the anchor is injected manually below instead.
	driver := NewDiscussDriver(DiscussDriverDeps{
		Pipeline:       NewPipeline(RenderParams{}),
		Resolver:       resolver,
		MessageService: nil,
	})
	sess := &discussSession{
		config:          DiscussSessionConfig{BotID: "b", SessionID: "s"},
		lastProcessedMs: 0,
	}
	// Simulate a previously answered round by pre-stuffing a TR newer than
	// the RC segment's ReceivedAtMs. Since we cannot inject MessageService
	// easily, we instead pre-set lastProcessedMs as the anchor would.
	// NOTE(review): this bypasses anchorFromTRs itself, so the TR-backed
	// cold-start path is only covered indirectly (via TestAnchorFromTRs);
	// an injectable TR loader would allow a true end-to-end test here.
	sess.lastProcessedMs = 200 // mimic anchorFromTRs result
	driver.handleReplyWithAgent(context.Background(), sess, rc, driver.logger, fakeAgent)
	// lastConfig is only set when the agent is actually invoked.
	if fakeAgent.lastConfig != nil {
		t.Fatal("agent must not be invoked when all RC segments predate lastProcessedMs")
	}
}
// TestHandleReplyWithAgent_CursorAdvancesToRCNotWallClock ensures that after
// a turn we set lastProcessedMs to the max ReceivedAtMs actually consumed in
// the RC snapshot, not time.Now(). This matters for messages that arrive
// mid-turn: they end up in a fresher RC with ReceivedAtMs > cursor, which
// correctly triggers the next round.
func TestHandleReplyWithAgent_CursorAdvancesToRCNotWallClock(t *testing.T) {
	// One fresh segment; 777 is deliberately far from any wall-clock value
	// so a regression to time.Now() fails the final assertion loudly.
	rc := RenderedContext{
		{
			ReceivedAtMs: 777,
			Content:      []RenderedContentPiece{{Type: "text", Text: `<message id="x">hello</message>`}},
		},
	}
	fakeAgent := &fakeDiscussStreamer{}
	resolver := &fakeRunConfigResolver{}
	driver := NewDiscussDriver(DiscussDriverDeps{
		Pipeline: NewPipeline(RenderParams{}),
		Resolver: resolver,
	})
	sess := &discussSession{
		config:          DiscussSessionConfig{BotID: "b", SessionID: "s"},
		lastProcessedMs: 0,
	}
	driver.handleReplyWithAgent(context.Background(), sess, rc, driver.logger, fakeAgent)
	if fakeAgent.lastConfig == nil {
		t.Fatal("expected agent to be invoked")
	}
	// The cursor must equal the consumed RC maximum, not the current time.
	if sess.lastProcessedMs != 777 {
		t.Fatalf("lastProcessedMs = %d, want 777 (max RC ReceivedAtMs)", sess.lastProcessedMs)
	}
}
// --- Test helpers ---
type fakeDiscussStreamer struct {
+190 -7
View File
@@ -2,6 +2,7 @@ package pipeline
import (
"encoding/json"
"fmt"
"strings"
"github.com/memohai/memoh/internal/conversation"
@@ -9,9 +10,17 @@ import (
)
// DecodeTurnResponseEntry converts a persisted bot message into a TR entry for
// pipeline context composition. Only visible text is preserved; structured
// tool-call / tool-result payloads are skipped to avoid re-injecting raw JSON
// into later prompts.
// pipeline context composition.
//
// Unlike the old implementation (which only kept plain text and dropped all
// tool-call / tool-result payloads), this version renders the full turn —
// including tool calls and their results — into a single structured string
// so the LLM can observe its own prior tool usage when the conversation is
// later replayed or summarised.
//
// The rendering is intentionally compact and XML-flavoured so it survives
// round-trips through the merge/compose pipeline without being confused with
// the user-facing XML used by Rendering.
func DecodeTurnResponseEntry(msg messagepkg.Message) (TurnResponseEntry, bool) {
role := strings.TrimSpace(msg.Role)
if role != "assistant" && role != "tool" {
@@ -23,14 +32,188 @@ func DecodeTurnResponseEntry(msg messagepkg.Message) (TurnResponseEntry, bool) {
return TurnResponseEntry{}, false
}
content := strings.TrimSpace(modelMsg.TextContent())
if content == "" {
var rendered string
switch role {
case "tool":
rendered = renderToolRoleMessage(modelMsg)
default:
rendered = renderAssistantMessage(modelMsg)
}
if strings.TrimSpace(rendered) == "" {
return TurnResponseEntry{}, false
}
return TurnResponseEntry{
RequestedAtMs: msg.CreatedAt.UnixMilli(),
Role: msg.Role,
Content: content,
Role: role,
Content: rendered,
}, true
}
// turnResponsePart is a permissive view of a persisted content part. It
// purposefully uses json.RawMessage for tool input/output to avoid losing
// structure while keeping the type declaration local to this package.
type turnResponsePart struct {
	Type       string          `json:"type"`                 // "text", "reasoning", "tool-call", or "tool-result"
	Text       string          `json:"text,omitempty"`       // body of text/reasoning parts
	ToolCallID string          `json:"toolCallId,omitempty"` // correlates a call with its result
	ToolName   string          `json:"toolName,omitempty"`
	Input      json.RawMessage `json:"input,omitempty"`  // tool-call arguments, kept raw
	Output     json.RawMessage `json:"output,omitempty"` // tool-result payload (newer key)
	Result     json.RawMessage `json:"result,omitempty"` // tool-result payload (legacy key)
}
// renderAssistantMessage flattens a persisted assistant ModelMessage into a
// single newline-separated string for TR composition. It handles three
// persistence shapes, in order:
//  1. Content as a plain JSON string (legacy format) — emitted verbatim.
//  2. Content as an array of parts (Vercel AI SDK format) — text kept,
//     reasoning dropped, tool-call/tool-result rendered as tags.
//  3. The top-level ToolCalls field (older OpenAI wire format).
// Shapes 1 and 2 are mutually exclusive for any given Content blob (the
// same JSON value cannot unmarshal as both a string and an array), so at
// most one of the two Content passes contributes output.
func renderAssistantMessage(msg conversation.ModelMessage) string {
	var b strings.Builder
	// 1) Plain-string content (legacy format).
	if len(msg.Content) > 0 {
		var plain string
		if err := json.Unmarshal(msg.Content, &plain); err == nil {
			plain = strings.TrimSpace(plain)
			if plain != "" {
				b.WriteString(plain)
			}
		}
	}
	// 2) Array-of-parts content (Vercel AI SDK uiMessage format).
	var parts []turnResponsePart
	if len(msg.Content) > 0 {
		// Error deliberately ignored: when Content is the plain string
		// handled above, this Unmarshal fails and parts stays empty.
		_ = json.Unmarshal(msg.Content, &parts)
	}
	for _, p := range parts {
		switch strings.ToLower(strings.TrimSpace(p.Type)) {
		case "text":
			text := strings.TrimSpace(p.Text)
			if text == "" {
				continue
			}
			if b.Len() > 0 {
				b.WriteByte('\n')
			}
			b.WriteString(text)
		case "reasoning":
			// Intentionally omitted: reasoning is model-internal and must not
			// leak back into subsequent prompts verbatim.
			continue
		case "tool-call":
			writeToolCallTag(&b, p.ToolCallID, p.ToolName, p.Input)
		case "tool-result":
			// Prefer the newer "output" key; fall back to legacy "result".
			payload := p.Output
			if len(payload) == 0 {
				payload = p.Result
			}
			writeToolResultTag(&b, p.ToolCallID, p.ToolName, payload)
		}
	}
	// 3) Top-level ToolCalls field (older OpenAI-style wire format).
	for _, call := range msg.ToolCalls {
		id := strings.TrimSpace(call.ID)
		name := strings.TrimSpace(call.Function.Name)
		args := strings.TrimSpace(call.Function.Arguments)
		var input json.RawMessage
		if args != "" {
			// Arguments is a string containing JSON; try to keep it raw so
			// the downstream renderer doesn't double-escape.
			if json.Valid([]byte(args)) {
				input = json.RawMessage(args)
			} else {
				// Not valid JSON: wrap the raw text as a JSON string so
				// formatToolPayload can still render it.
				encoded, _ := json.Marshal(args)
				input = encoded
			}
		}
		writeToolCallTag(&b, id, name, input)
	}
	return b.String()
}
// renderToolRoleMessage flattens a persisted role="tool" ModelMessage into
// <tool_result> tags for TR composition.
//
// Two possible persistence shapes:
//
//	a) Content is a JSON array of parts with type="tool-result".
//	b) Content is the tool result itself, and ToolCallID is set on the
//	   ModelMessage envelope (older OpenAI-style format).
//
// Shape (b) is only consulted when shape (a) produced no output.
func renderToolRoleMessage(msg conversation.ModelMessage) string {
	var b strings.Builder
	var parts []turnResponsePart
	if len(msg.Content) > 0 {
		// Error deliberately ignored: for shape (b) Content is not a parts
		// array and parts simply stays empty.
		_ = json.Unmarshal(msg.Content, &parts)
	}
	for _, p := range parts {
		if strings.ToLower(strings.TrimSpace(p.Type)) != "tool-result" {
			continue
		}
		// Prefer the newer "output" key; fall back to legacy "result".
		payload := p.Output
		if len(payload) == 0 {
			payload = p.Result
		}
		writeToolResultTag(&b, p.ToolCallID, p.ToolName, payload)
	}
	if b.Len() > 0 {
		return b.String()
	}
	// Legacy envelope: the whole Content blob is the result payload.
	if strings.TrimSpace(msg.ToolCallID) != "" {
		writeToolResultTag(&b, msg.ToolCallID, msg.Name, msg.Content)
	}
	return b.String()
}
// writeToolCallTag appends a <tool_call> tag for one tool invocation to b,
// separating it from any earlier content with a single newline. The id and
// name are trimmed and attribute-escaped; the input payload, when present,
// becomes the tag body via formatToolPayload.
func writeToolCallTag(b *strings.Builder, id, name string, input json.RawMessage) {
	if b.Len() > 0 {
		b.WriteByte('\n')
	}
	escapedID := escapeXMLAttrValue(strings.TrimSpace(id))
	escapedName := escapeXMLAttrValue(strings.TrimSpace(name))
	b.WriteString(fmt.Sprintf(`<tool_call id=%q name=%q>`, escapedID, escapedName))
	if body := formatToolPayload(input); body != "" {
		b.WriteString(body)
	}
	b.WriteString("</tool_call>")
}
// writeToolResultTag appends a <tool_result> tag for one tool outcome to b,
// separating it from any earlier content with a single newline. The id and
// name are trimmed and attribute-escaped; the payload, when present, becomes
// the tag body via formatToolPayload.
func writeToolResultTag(b *strings.Builder, id, name string, payload json.RawMessage) {
	if b.Len() > 0 {
		b.WriteByte('\n')
	}
	escapedID := escapeXMLAttrValue(strings.TrimSpace(id))
	escapedName := escapeXMLAttrValue(strings.TrimSpace(name))
	b.WriteString(fmt.Sprintf(`<tool_result id=%q name=%q>`, escapedID, escapedName))
	if body := formatToolPayload(payload); body != "" {
		b.WriteString(body)
	}
	b.WriteString("</tool_result>")
}
// formatToolPayload returns a compact textual representation of a tool
// input/output payload safe to embed inside a tag body. Empty, blank, and
// JSON-null payloads render as "". JSON strings are unquoted so the body
// reads naturally; any other valid JSON value is re-encoded compactly.
//
// BUG FIX: the previous version re-encoded via json.Unmarshal into `any`,
// which decodes every JSON number as float64 — silently corrupting integer
// IDs above 2^53 and rewriting numeric literals (e.g. "3.0" -> "3") despite
// claiming a lossless round-trip. Decoding with UseNumber keeps each numeric
// literal verbatim. The json.Valid guard preserves the old Unmarshal
// semantics of only normalising payloads that are a single complete JSON
// value; anything else falls through to the raw trimmed text.
func formatToolPayload(raw json.RawMessage) string {
	if len(raw) == 0 {
		return ""
	}
	trimmed := strings.TrimSpace(string(raw))
	if trimmed == "" || trimmed == "null" {
		return ""
	}
	// If the payload is a JSON string, unquote it so the body reads naturally.
	var asString string
	if err := json.Unmarshal(raw, &asString); err == nil {
		s := strings.TrimSpace(asString)
		if s == "" {
			return ""
		}
		return escapeXMLText(s)
	}
	// Otherwise, re-encode as compact JSON so whitespace is normalised while
	// numeric literals round-trip losslessly via json.Number.
	if json.Valid([]byte(trimmed)) {
		dec := json.NewDecoder(strings.NewReader(trimmed))
		dec.UseNumber()
		var v any
		if err := dec.Decode(&v); err == nil {
			if encoded, err := json.Marshal(v); err == nil {
				return escapeXMLText(string(encoded))
			}
		}
	}
	return escapeXMLText(trimmed)
}
+197 -5
View File
@@ -2,6 +2,7 @@ package pipeline
import (
"encoding/json"
"strings"
"testing"
"time"
@@ -39,14 +40,23 @@ func TestDecodeTurnResponseEntryUsesVisibleText(t *testing.T) {
if entry.Content != "任务完成" {
t.Fatalf("content = %q, want %q", entry.Content, "任务完成")
}
// Reasoning must never leak into TRs to avoid re-injection into prompts.
if strings.Contains(entry.Content, "thinking") {
t.Fatalf("reasoning leaked into TR: %q", entry.Content)
}
}
func TestDecodeTurnResponseEntrySkipsToolCallOnlyPayload(t *testing.T) {
func TestDecodeTurnResponseEntryPreservesToolCallOnlyPayload(t *testing.T) {
t.Parallel()
content, err := json.Marshal([]map[string]any{
{"type": "reasoning", "text": "thinking"},
{"type": "tool-call", "toolName": "read", "toolCallId": "call-1", "input": map[string]any{"path": "/tmp/a.txt"}},
{
"type": "tool-call",
"toolName": "read",
"toolCallId": "call-1",
"input": map[string]any{"path": "/tmp/a.txt"},
},
})
if err != nil {
t.Fatalf("marshal content: %v", err)
@@ -60,11 +70,193 @@ func TestDecodeTurnResponseEntrySkipsToolCallOnlyPayload(t *testing.T) {
t.Fatalf("marshal model message: %v", err)
}
if _, ok := DecodeTurnResponseEntry(messagepkg.Message{
entry, ok := DecodeTurnResponseEntry(messagepkg.Message{
Role: "assistant",
Content: modelMessage,
CreatedAt: time.Unix(1710000000, 0).UTC(),
}); ok {
t.Fatal("expected tool-call-only payload to be skipped")
})
if !ok {
t.Fatal("expected tool-call-only payload to be preserved as TR")
}
if !strings.Contains(entry.Content, `<tool_call id="call-1" name="read">`) {
t.Fatalf("missing tool_call tag: %q", entry.Content)
}
if !strings.Contains(entry.Content, `"path":"/tmp/a.txt"`) {
t.Fatalf("tool input missing: %q", entry.Content)
}
if strings.Contains(entry.Content, "thinking") {
t.Fatalf("reasoning leaked: %q", entry.Content)
}
}
// TestDecodeTurnResponseEntryRendersTextAndToolCall checks that an assistant
// message mixing a plain text part with a tool-call part renders BOTH into
// the TR content: the visible text plus a <tool_call> tag carrying the
// call's id and name.
func TestDecodeTurnResponseEntryRendersTextAndToolCall(t *testing.T) {
	t.Parallel()
	// Vercel-style parts array: text followed by a tool-call.
	content, err := json.Marshal([]map[string]any{
		{"type": "text", "text": "Let me check."},
		{
			"type":       "tool-call",
			"toolName":   "web_search",
			"toolCallId": "call-42",
			"input":      map[string]any{"query": "today news"},
		},
	})
	if err != nil {
		t.Fatalf("marshal content: %v", err)
	}
	modelMessage, err := json.Marshal(conversation.ModelMessage{
		Role:    "assistant",
		Content: content,
	})
	if err != nil {
		t.Fatalf("marshal model message: %v", err)
	}
	entry, ok := DecodeTurnResponseEntry(messagepkg.Message{
		Role:    "assistant",
		Content: modelMessage,
	})
	if !ok {
		t.Fatal("expected entry")
	}
	if !strings.Contains(entry.Content, "Let me check.") {
		t.Fatalf("missing text portion: %q", entry.Content)
	}
	if !strings.Contains(entry.Content, `<tool_call id="call-42" name="web_search">`) {
		t.Fatalf("missing tool_call tag: %q", entry.Content)
	}
}
// TestDecodeTurnResponseEntryToolRoleWithPartsResult covers the modern
// role="tool" persistence shape: Content is a parts array containing a
// type="tool-result" entry whose structured "output" object must survive
// into the TR as a <tool_result> tag with compact JSON in the body.
func TestDecodeTurnResponseEntryToolRoleWithPartsResult(t *testing.T) {
	t.Parallel()
	content, err := json.Marshal([]map[string]any{
		{
			"type":       "tool-result",
			"toolCallId": "call-1",
			"toolName":   "web_search",
			"output": map[string]any{
				"count":   3,
				"summary": "ok",
			},
		},
	})
	if err != nil {
		t.Fatalf("marshal content: %v", err)
	}
	modelMessage, err := json.Marshal(conversation.ModelMessage{
		Role:    "tool",
		Content: content,
	})
	if err != nil {
		t.Fatalf("marshal model message: %v", err)
	}
	entry, ok := DecodeTurnResponseEntry(messagepkg.Message{
		Role:    "tool",
		Content: modelMessage,
	})
	if !ok {
		t.Fatal("expected tool role entry")
	}
	if !strings.Contains(entry.Content, `<tool_result id="call-1" name="web_search">`) {
		t.Fatalf("missing tool_result tag: %q", entry.Content)
	}
	// Both structured fields must survive the compact re-encode.
	if !strings.Contains(entry.Content, `"count":3`) || !strings.Contains(entry.Content, `"summary":"ok"`) {
		t.Fatalf("structured tool output not preserved: %q", entry.Content)
	}
}
// TestDecodeTurnResponseEntryToolRoleLegacyEnvelope covers the older
// OpenAI-style role="tool" shape, where ToolCallID/Name live on the
// ModelMessage envelope rather than inside a parts array.
func TestDecodeTurnResponseEntryToolRoleLegacyEnvelope(t *testing.T) {
	t.Parallel()
	// Old OpenAI-style: role=tool + ToolCallID on the envelope; Content
	// carries the result JSON object directly (no parts array).
	resultBody := json.RawMessage(`{"status":"ok"}`)
	modelMessage, err := json.Marshal(conversation.ModelMessage{
		Role:       "tool",
		ToolCallID: "call-99",
		Name:       "ping",
		Content:    resultBody,
	})
	if err != nil {
		t.Fatalf("marshal model message: %v", err)
	}
	entry, ok := DecodeTurnResponseEntry(messagepkg.Message{
		Role:    "tool",
		Content: modelMessage,
	})
	if !ok {
		t.Fatal("expected entry for legacy tool envelope")
	}
	if !strings.Contains(entry.Content, `<tool_result id="call-99" name="ping">`) {
		t.Fatalf("missing tool_result tag: %q", entry.Content)
	}
	if !strings.Contains(entry.Content, `"status":"ok"`) {
		t.Fatalf("legacy tool body missing: %q", entry.Content)
	}
}
// TestDecodeTurnResponseEntrySkipsEmpty asserts that a message whose only
// part is reasoning decodes to nothing: reasoning is stripped, the rendered
// string is empty, and DecodeTurnResponseEntry reports ok=false.
func TestDecodeTurnResponseEntrySkipsEmpty(t *testing.T) {
	t.Parallel()
	// Only reasoning → nothing to expose to future prompts → skip.
	content, err := json.Marshal([]map[string]any{
		{"type": "reasoning", "text": "thinking out loud"},
	})
	if err != nil {
		t.Fatalf("marshal content: %v", err)
	}
	modelMessage, err := json.Marshal(conversation.ModelMessage{
		Role:    "assistant",
		Content: content,
	})
	if err != nil {
		t.Fatalf("marshal model message: %v", err)
	}
	if _, ok := DecodeTurnResponseEntry(messagepkg.Message{
		Role:    "assistant",
		Content: modelMessage,
	}); ok {
		t.Fatal("expected reasoning-only message to be skipped")
	}
}
// TestDecodeTurnResponseEntryLegacyToolCallsField covers assistant messages
// in the older OpenAI envelope: Content is an empty string and the tool
// invocation lives in the top-level ToolCalls slice. The Arguments JSON
// string must be embedded raw in the tag body, not double-escaped.
func TestDecodeTurnResponseEntryLegacyToolCallsField(t *testing.T) {
	t.Parallel()
	// Older OpenAI envelope: Content is empty string, ToolCalls carries
	// the function-call structure.
	modelMessage, err := json.Marshal(conversation.ModelMessage{
		Role:    "assistant",
		Content: json.RawMessage(`""`),
		ToolCalls: []conversation.ToolCall{
			{
				ID:   "call-legacy",
				Type: "function",
				Function: conversation.ToolCallFunction{
					Name:      "send",
					Arguments: `{"text":"hi"}`,
				},
			},
		},
	})
	if err != nil {
		t.Fatalf("marshal model message: %v", err)
	}
	entry, ok := DecodeTurnResponseEntry(messagepkg.Message{
		Role:    "assistant",
		Content: modelMessage,
	})
	if !ok {
		t.Fatal("expected legacy tool-calls envelope to decode")
	}
	if !strings.Contains(entry.Content, `<tool_call id="call-legacy" name="send">`) {
		t.Fatalf("missing tool_call tag: %q", entry.Content)
	}
	// The Arguments payload must appear as compact JSON, proving it was
	// treated as raw JSON rather than re-quoted as a string literal.
	if !strings.Contains(entry.Content, `"text":"hi"`) {
		t.Fatalf("arguments missing: %q", entry.Content)
	}
}