mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
fix(pipeline): preserve tool calls and anchor driver cursor correctly
The discuss driver's RC+TR composition had three compounding bugs that caused old tasks to be re-answered after idle timeouts and made the LLM blind to its own prior tool usage:

- DecodeTurnResponseEntry only kept visible text via TextContent(), so assistant steps carrying only tool_call parts (the first half of every tool round) were dropped entirely. It has been rewritten to render tool_call and tool_result parts as <tool_call>/<tool_result> tags, covering both Vercel-style content parts and legacy OpenAI ToolCalls/role=tool envelopes. Reasoning parts remain stripped to avoid re-injection.
- loadTurnResponses hard-capped TRs at 24h while RC is replayed in full from the events table, producing asymmetric context (user messages from day 1 visible, matching bot replies missing). The cap is removed; any size-bound trimming belongs in compaction, not here.
- lastProcessedMs lived only in memory and was set to time.Now() at turn end. After the 10-minute idle timeout, the goroutine exited and the next turn started with cursor=0, treating the entire history as new traffic. It is now initialised from the latest TR's requested_at on cold start, and advanced to max(consumed RC.ReceivedAtMs) per turn so that messages arriving mid-generation trigger a follow-up round instead of being wrongly marked processed.
This commit is contained in:
@@ -230,6 +230,25 @@ func (d *DiscussDriver) handleReplyWithAgent(ctx context.Context, sess *discussS
|
||||
|
||||
trs := d.loadTurnResponses(ctx, cfg.SessionID)
|
||||
|
||||
// Cold-start / post-idle initialisation: if we haven't processed anything
|
||||
// in this goroutine's lifetime yet, anchor `lastProcessedMs` to the most
|
||||
// recent TR's requested_at. Any RC segment strictly older than that has
|
||||
// already been "seen" by a prior LLM call (whose response is in the TR
|
||||
// stream), so it should not retrigger a reply. Without this anchor, every
|
||||
// idle-timeout restart would treat the entire session history as brand
|
||||
// new external traffic and re-answer it.
|
||||
if sess.lastProcessedMs == 0 {
|
||||
sess.lastProcessedMs = anchorFromTRs(trs)
|
||||
}
|
||||
|
||||
// Re-evaluate the trigger condition now that lastProcessedMs is anchored.
|
||||
// The outer loop used lastProcessedMs=0 to allow first-time dispatch into
|
||||
// this function; after initialisation, we must verify there's actually a
|
||||
// new external event past the anchor before spending an LLM call.
|
||||
if LatestExternalEventMs(rc, sess.lastProcessedMs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
composed := ComposeContext(rc, trs, "")
|
||||
if composed == nil {
|
||||
return
|
||||
@@ -285,8 +304,6 @@ func (d *DiscussDriver) handleReplyWithAgent(ctx context.Context, sess *discussS
|
||||
}
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
if d.deps.Resolver != nil && len(finalMessages) > 0 {
|
||||
var sdkMsgs []sdk.Message
|
||||
if json.Unmarshal(finalMessages, &sdkMsgs) == nil && len(sdkMsgs) > 0 {
|
||||
@@ -299,7 +316,39 @@ func (d *DiscussDriver) handleReplyWithAgent(ctx context.Context, sess *discussS
|
||||
}
|
||||
}
|
||||
|
||||
sess.lastProcessedMs = now.UnixMilli()
|
||||
// Advance the cursor to the latest RC segment actually consumed in this
|
||||
// turn (not wall-clock time). Messages that arrive DURING LLM generation
|
||||
// will land in a newer RC with ReceivedAtMs > this cursor and correctly
|
||||
// trigger another round; wall-clock would wrongly mark them processed.
|
||||
consumedMs := latestRCReceivedAtMs(rc)
|
||||
if consumedMs > sess.lastProcessedMs {
|
||||
sess.lastProcessedMs = consumedMs
|
||||
}
|
||||
}
|
||||
|
||||
// latestRCReceivedAtMs returns the maximum ReceivedAtMs across all segments
|
||||
// in the given RC, or 0 if the RC is empty.
|
||||
func latestRCReceivedAtMs(rc RenderedContext) int64 {
|
||||
var latest int64
|
||||
for _, seg := range rc {
|
||||
if seg.ReceivedAtMs > latest {
|
||||
latest = seg.ReceivedAtMs
|
||||
}
|
||||
}
|
||||
return latest
|
||||
}
|
||||
|
||||
// anchorFromTRs returns the maximum RequestedAtMs across a TR slice. Used
|
||||
// by the cold-start initialisation of `lastProcessedMs` so that RC segments
|
||||
// older than the latest persisted bot response are not re-answered.
|
||||
func anchorFromTRs(trs []TurnResponseEntry) int64 {
|
||||
var latest int64
|
||||
for _, tr := range trs {
|
||||
if tr.RequestedAtMs > latest {
|
||||
latest = tr.RequestedAtMs
|
||||
}
|
||||
}
|
||||
return latest
|
||||
}
|
||||
|
||||
// broadcastDiscussEvent forwards an agent stream event to the RouteHub so the
|
||||
@@ -352,13 +401,20 @@ func agentEventToChannelEvent(e agentpkg.StreamEvent) (channel.StreamEvent, bool
|
||||
}
|
||||
}
|
||||
|
||||
// loadTurnResponses loads every assistant/tool message ever persisted for
|
||||
// this session and decodes them into TR entries. There is no time-based cut
|
||||
// off on purpose: truncating TRs while RC is replayed in full from the events
|
||||
// table creates an asymmetric context (user messages visible, the bot's own
|
||||
// earlier replies missing) that confuses both the LLM and loop-detection.
|
||||
// Any size-bound trimming should happen later via compaction, not here.
|
||||
func (d *DiscussDriver) loadTurnResponses(ctx context.Context, sessionID string) []TurnResponseEntry {
|
||||
if d.deps.MessageService == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
since := time.Now().UTC().Add(-24 * time.Hour)
|
||||
msgs, err := d.deps.MessageService.ListActiveSinceBySession(ctx, sessionID, since)
|
||||
// time.Unix(0, 0) is the Unix epoch; the underlying query uses
|
||||
// `created_at >= $1`, so this effectively loads every session message.
|
||||
msgs, err := d.deps.MessageService.ListActiveSinceBySession(ctx, sessionID, time.Unix(0, 0).UTC())
|
||||
if err != nil {
|
||||
d.logger.Warn("load TRs failed", slog.String("session_id", sessionID), slog.Any("error", err))
|
||||
return nil
|
||||
|
||||
@@ -208,6 +208,109 @@ func TestHandleReplyWithAgent_NoInlineWhenNoVision(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnchorFromTRs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if got := anchorFromTRs(nil); got != 0 {
|
||||
t.Fatalf("empty TRs anchor = %d, want 0", got)
|
||||
}
|
||||
got := anchorFromTRs([]TurnResponseEntry{
|
||||
{RequestedAtMs: 100},
|
||||
{RequestedAtMs: 500},
|
||||
{RequestedAtMs: 300},
|
||||
})
|
||||
if got != 500 {
|
||||
t.Fatalf("anchor = %d, want 500", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLatestRCReceivedAtMs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if got := latestRCReceivedAtMs(nil); got != 0 {
|
||||
t.Fatalf("empty RC = %d, want 0", got)
|
||||
}
|
||||
got := latestRCReceivedAtMs(RenderedContext{
|
||||
{ReceivedAtMs: 100},
|
||||
{ReceivedAtMs: 900},
|
||||
{ReceivedAtMs: 500, IsMyself: true},
|
||||
})
|
||||
if got != 900 {
|
||||
t.Fatalf("latest = %d, want 900", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleReplyWithAgent_ColdStartAnchoredByTR simulates idle-timeout
|
||||
// restart: the session's in-memory lastProcessedMs is 0, but RC replay has
|
||||
// brought back old user messages that were already answered in prior
|
||||
// LLM rounds (represented by TRs). The driver MUST NOT re-answer them.
|
||||
func TestHandleReplyWithAgent_ColdStartAnchoredByTR(t *testing.T) {
|
||||
rc := RenderedContext{
|
||||
{
|
||||
ReceivedAtMs: 100,
|
||||
Content: []RenderedContentPiece{{Type: "text", Text: `<message id="old">task 1</message>`}},
|
||||
},
|
||||
}
|
||||
|
||||
fakeAgent := &fakeDiscussStreamer{}
|
||||
resolver := &fakeRunConfigResolver{}
|
||||
|
||||
driver := NewDiscussDriver(DiscussDriverDeps{
|
||||
Pipeline: NewPipeline(RenderParams{}),
|
||||
Resolver: resolver,
|
||||
MessageService: nil,
|
||||
})
|
||||
|
||||
sess := &discussSession{
|
||||
config: DiscussSessionConfig{BotID: "b", SessionID: "s"},
|
||||
lastProcessedMs: 0,
|
||||
}
|
||||
|
||||
// Simulate a previously answered round by pre-stuffing a TR newer than
|
||||
// the RC segment's ReceivedAtMs. Since we cannot inject MessageService
|
||||
// easily, we instead pre-set lastProcessedMs as the anchor would.
|
||||
sess.lastProcessedMs = 200 // mimic anchorFromTRs result
|
||||
|
||||
driver.handleReplyWithAgent(context.Background(), sess, rc, driver.logger, fakeAgent)
|
||||
|
||||
if fakeAgent.lastConfig != nil {
|
||||
t.Fatal("agent must not be invoked when all RC segments predate lastProcessedMs")
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleReplyWithAgent_CursorAdvancesToRCNotWallClock ensures that after
|
||||
// a turn we set lastProcessedMs to the max ReceivedAtMs actually consumed in
|
||||
// the RC snapshot, not time.Now(). This matters for messages that arrive
|
||||
// mid-turn: they end up in a fresher RC with ReceivedAtMs > cursor, which
|
||||
// correctly triggers the next round.
|
||||
func TestHandleReplyWithAgent_CursorAdvancesToRCNotWallClock(t *testing.T) {
|
||||
rc := RenderedContext{
|
||||
{
|
||||
ReceivedAtMs: 777,
|
||||
Content: []RenderedContentPiece{{Type: "text", Text: `<message id="x">hello</message>`}},
|
||||
},
|
||||
}
|
||||
fakeAgent := &fakeDiscussStreamer{}
|
||||
resolver := &fakeRunConfigResolver{}
|
||||
driver := NewDiscussDriver(DiscussDriverDeps{
|
||||
Pipeline: NewPipeline(RenderParams{}),
|
||||
Resolver: resolver,
|
||||
})
|
||||
sess := &discussSession{
|
||||
config: DiscussSessionConfig{BotID: "b", SessionID: "s"},
|
||||
lastProcessedMs: 0,
|
||||
}
|
||||
|
||||
driver.handleReplyWithAgent(context.Background(), sess, rc, driver.logger, fakeAgent)
|
||||
|
||||
if fakeAgent.lastConfig == nil {
|
||||
t.Fatal("expected agent to be invoked")
|
||||
}
|
||||
if sess.lastProcessedMs != 777 {
|
||||
t.Fatalf("lastProcessedMs = %d, want 777 (max RC ReceivedAtMs)", sess.lastProcessedMs)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Test helpers ---
|
||||
|
||||
type fakeDiscussStreamer struct {
|
||||
|
||||
@@ -2,6 +2,7 @@ package pipeline
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/memohai/memoh/internal/conversation"
|
||||
@@ -9,9 +10,17 @@ import (
|
||||
)
|
||||
|
||||
// DecodeTurnResponseEntry converts a persisted bot message into a TR entry for
|
||||
// pipeline context composition. Only visible text is preserved; structured
|
||||
// tool-call / tool-result payloads are skipped to avoid re-injecting raw JSON
|
||||
// into later prompts.
|
||||
// pipeline context composition.
|
||||
//
|
||||
// Unlike the old implementation (which only kept plain text and dropped all
|
||||
// tool-call / tool-result payloads), this version renders the full turn —
|
||||
// including tool calls and their results — into a single structured string
|
||||
// so the LLM can observe its own prior tool usage when the conversation is
|
||||
// later replayed or summarised.
|
||||
//
|
||||
// The rendering is intentionally compact and XML-flavoured so it survives
|
||||
// round-trips through the merge/compose pipeline without being confused with
|
||||
// the user-facing XML used by Rendering.
|
||||
func DecodeTurnResponseEntry(msg messagepkg.Message) (TurnResponseEntry, bool) {
|
||||
role := strings.TrimSpace(msg.Role)
|
||||
if role != "assistant" && role != "tool" {
|
||||
@@ -23,14 +32,188 @@ func DecodeTurnResponseEntry(msg messagepkg.Message) (TurnResponseEntry, bool) {
|
||||
return TurnResponseEntry{}, false
|
||||
}
|
||||
|
||||
content := strings.TrimSpace(modelMsg.TextContent())
|
||||
if content == "" {
|
||||
var rendered string
|
||||
switch role {
|
||||
case "tool":
|
||||
rendered = renderToolRoleMessage(modelMsg)
|
||||
default:
|
||||
rendered = renderAssistantMessage(modelMsg)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(rendered) == "" {
|
||||
return TurnResponseEntry{}, false
|
||||
}
|
||||
|
||||
return TurnResponseEntry{
|
||||
RequestedAtMs: msg.CreatedAt.UnixMilli(),
|
||||
Role: msg.Role,
|
||||
Content: content,
|
||||
Role: role,
|
||||
Content: rendered,
|
||||
}, true
|
||||
}
|
||||
|
||||
// turnResponsePart is a permissive view of a persisted content part. It
|
||||
// purposefully uses json.RawMessage for tool input/output to avoid losing
|
||||
// structure while keeping the type declaration local to this package.
|
||||
type turnResponsePart struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
ToolCallID string `json:"toolCallId,omitempty"`
|
||||
ToolName string `json:"toolName,omitempty"`
|
||||
Input json.RawMessage `json:"input,omitempty"`
|
||||
Output json.RawMessage `json:"output,omitempty"`
|
||||
Result json.RawMessage `json:"result,omitempty"`
|
||||
}
|
||||
|
||||
func renderAssistantMessage(msg conversation.ModelMessage) string {
|
||||
var b strings.Builder
|
||||
|
||||
// 1) Plain-string content (legacy format).
|
||||
if len(msg.Content) > 0 {
|
||||
var plain string
|
||||
if err := json.Unmarshal(msg.Content, &plain); err == nil {
|
||||
plain = strings.TrimSpace(plain)
|
||||
if plain != "" {
|
||||
b.WriteString(plain)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2) Array-of-parts content (Vercel AI SDK uiMessage format).
|
||||
var parts []turnResponsePart
|
||||
if len(msg.Content) > 0 {
|
||||
_ = json.Unmarshal(msg.Content, &parts)
|
||||
}
|
||||
|
||||
for _, p := range parts {
|
||||
switch strings.ToLower(strings.TrimSpace(p.Type)) {
|
||||
case "text":
|
||||
text := strings.TrimSpace(p.Text)
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
if b.Len() > 0 {
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
b.WriteString(text)
|
||||
case "reasoning":
|
||||
// Intentionally omitted: reasoning is model-internal and must not
|
||||
// leak back into subsequent prompts verbatim.
|
||||
continue
|
||||
case "tool-call":
|
||||
writeToolCallTag(&b, p.ToolCallID, p.ToolName, p.Input)
|
||||
case "tool-result":
|
||||
payload := p.Output
|
||||
if len(payload) == 0 {
|
||||
payload = p.Result
|
||||
}
|
||||
writeToolResultTag(&b, p.ToolCallID, p.ToolName, payload)
|
||||
}
|
||||
}
|
||||
|
||||
// 3) Top-level ToolCalls field (older OpenAI-style wire format).
|
||||
for _, call := range msg.ToolCalls {
|
||||
id := strings.TrimSpace(call.ID)
|
||||
name := strings.TrimSpace(call.Function.Name)
|
||||
args := strings.TrimSpace(call.Function.Arguments)
|
||||
var input json.RawMessage
|
||||
if args != "" {
|
||||
// Arguments is a string containing JSON; try to keep it raw so
|
||||
// the downstream renderer doesn't double-escape.
|
||||
if json.Valid([]byte(args)) {
|
||||
input = json.RawMessage(args)
|
||||
} else {
|
||||
encoded, _ := json.Marshal(args)
|
||||
input = encoded
|
||||
}
|
||||
}
|
||||
writeToolCallTag(&b, id, name, input)
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func renderToolRoleMessage(msg conversation.ModelMessage) string {
|
||||
// Two possible persistence shapes:
|
||||
// a) Content is a JSON array of parts with type="tool-result".
|
||||
// b) Content is the tool result itself, and ToolCallID is set on the
|
||||
// ModelMessage envelope (older OpenAI-style format).
|
||||
var b strings.Builder
|
||||
|
||||
var parts []turnResponsePart
|
||||
if len(msg.Content) > 0 {
|
||||
_ = json.Unmarshal(msg.Content, &parts)
|
||||
}
|
||||
for _, p := range parts {
|
||||
if strings.ToLower(strings.TrimSpace(p.Type)) != "tool-result" {
|
||||
continue
|
||||
}
|
||||
payload := p.Output
|
||||
if len(payload) == 0 {
|
||||
payload = p.Result
|
||||
}
|
||||
writeToolResultTag(&b, p.ToolCallID, p.ToolName, payload)
|
||||
}
|
||||
if b.Len() > 0 {
|
||||
return b.String()
|
||||
}
|
||||
|
||||
if strings.TrimSpace(msg.ToolCallID) != "" {
|
||||
writeToolResultTag(&b, msg.ToolCallID, msg.Name, msg.Content)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func writeToolCallTag(b *strings.Builder, id, name string, input json.RawMessage) {
|
||||
if b.Len() > 0 {
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
fmt.Fprintf(b, `<tool_call id=%q name=%q>`, escapeXMLAttrValue(strings.TrimSpace(id)), escapeXMLAttrValue(strings.TrimSpace(name)))
|
||||
if payload := formatToolPayload(input); payload != "" {
|
||||
b.WriteString(payload)
|
||||
}
|
||||
b.WriteString("</tool_call>")
|
||||
}
|
||||
|
||||
func writeToolResultTag(b *strings.Builder, id, name string, payload json.RawMessage) {
|
||||
if b.Len() > 0 {
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
fmt.Fprintf(b, `<tool_result id=%q name=%q>`, escapeXMLAttrValue(strings.TrimSpace(id)), escapeXMLAttrValue(strings.TrimSpace(name)))
|
||||
if rendered := formatToolPayload(payload); rendered != "" {
|
||||
b.WriteString(rendered)
|
||||
}
|
||||
b.WriteString("</tool_result>")
|
||||
}
|
||||
|
||||
// formatToolPayload returns a compact textual representation of a tool
|
||||
// input/output payload safe to embed inside a tag body.
|
||||
func formatToolPayload(raw json.RawMessage) string {
|
||||
if len(raw) == 0 {
|
||||
return ""
|
||||
}
|
||||
trimmed := strings.TrimSpace(string(raw))
|
||||
if trimmed == "" || trimmed == "null" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// If the payload is a JSON string, unquote it so the body reads naturally.
|
||||
var asString string
|
||||
if err := json.Unmarshal(raw, &asString); err == nil {
|
||||
s := strings.TrimSpace(asString)
|
||||
if s == "" {
|
||||
return ""
|
||||
}
|
||||
return escapeXMLText(s)
|
||||
}
|
||||
|
||||
// Otherwise, re-encode as compact JSON so whitespace is normalised and
|
||||
// any nested structured content round-trips losslessly.
|
||||
var v any
|
||||
if err := json.Unmarshal(raw, &v); err == nil {
|
||||
encoded, err := json.Marshal(v)
|
||||
if err == nil {
|
||||
return escapeXMLText(string(encoded))
|
||||
}
|
||||
}
|
||||
return escapeXMLText(trimmed)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package pipeline
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -39,14 +40,23 @@ func TestDecodeTurnResponseEntryUsesVisibleText(t *testing.T) {
|
||||
if entry.Content != "任务完成" {
|
||||
t.Fatalf("content = %q, want %q", entry.Content, "任务完成")
|
||||
}
|
||||
// Reasoning must never leak into TRs to avoid re-injection into prompts.
|
||||
if strings.Contains(entry.Content, "thinking") {
|
||||
t.Fatalf("reasoning leaked into TR: %q", entry.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeTurnResponseEntrySkipsToolCallOnlyPayload(t *testing.T) {
|
||||
func TestDecodeTurnResponseEntryPreservesToolCallOnlyPayload(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
content, err := json.Marshal([]map[string]any{
|
||||
{"type": "reasoning", "text": "thinking"},
|
||||
{"type": "tool-call", "toolName": "read", "toolCallId": "call-1", "input": map[string]any{"path": "/tmp/a.txt"}},
|
||||
{
|
||||
"type": "tool-call",
|
||||
"toolName": "read",
|
||||
"toolCallId": "call-1",
|
||||
"input": map[string]any{"path": "/tmp/a.txt"},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("marshal content: %v", err)
|
||||
@@ -60,11 +70,193 @@ func TestDecodeTurnResponseEntrySkipsToolCallOnlyPayload(t *testing.T) {
|
||||
t.Fatalf("marshal model message: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := DecodeTurnResponseEntry(messagepkg.Message{
|
||||
entry, ok := DecodeTurnResponseEntry(messagepkg.Message{
|
||||
Role: "assistant",
|
||||
Content: modelMessage,
|
||||
CreatedAt: time.Unix(1710000000, 0).UTC(),
|
||||
}); ok {
|
||||
t.Fatal("expected tool-call-only payload to be skipped")
|
||||
})
|
||||
if !ok {
|
||||
t.Fatal("expected tool-call-only payload to be preserved as TR")
|
||||
}
|
||||
if !strings.Contains(entry.Content, `<tool_call id="call-1" name="read">`) {
|
||||
t.Fatalf("missing tool_call tag: %q", entry.Content)
|
||||
}
|
||||
if !strings.Contains(entry.Content, `"path":"/tmp/a.txt"`) {
|
||||
t.Fatalf("tool input missing: %q", entry.Content)
|
||||
}
|
||||
if strings.Contains(entry.Content, "thinking") {
|
||||
t.Fatalf("reasoning leaked: %q", entry.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeTurnResponseEntryRendersTextAndToolCall(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
content, err := json.Marshal([]map[string]any{
|
||||
{"type": "text", "text": "Let me check."},
|
||||
{
|
||||
"type": "tool-call",
|
||||
"toolName": "web_search",
|
||||
"toolCallId": "call-42",
|
||||
"input": map[string]any{"query": "today news"},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("marshal content: %v", err)
|
||||
}
|
||||
modelMessage, err := json.Marshal(conversation.ModelMessage{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("marshal model message: %v", err)
|
||||
}
|
||||
|
||||
entry, ok := DecodeTurnResponseEntry(messagepkg.Message{
|
||||
Role: "assistant",
|
||||
Content: modelMessage,
|
||||
})
|
||||
if !ok {
|
||||
t.Fatal("expected entry")
|
||||
}
|
||||
if !strings.Contains(entry.Content, "Let me check.") {
|
||||
t.Fatalf("missing text portion: %q", entry.Content)
|
||||
}
|
||||
if !strings.Contains(entry.Content, `<tool_call id="call-42" name="web_search">`) {
|
||||
t.Fatalf("missing tool_call tag: %q", entry.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeTurnResponseEntryToolRoleWithPartsResult(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
content, err := json.Marshal([]map[string]any{
|
||||
{
|
||||
"type": "tool-result",
|
||||
"toolCallId": "call-1",
|
||||
"toolName": "web_search",
|
||||
"output": map[string]any{
|
||||
"count": 3,
|
||||
"summary": "ok",
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("marshal content: %v", err)
|
||||
}
|
||||
|
||||
modelMessage, err := json.Marshal(conversation.ModelMessage{
|
||||
Role: "tool",
|
||||
Content: content,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("marshal model message: %v", err)
|
||||
}
|
||||
|
||||
entry, ok := DecodeTurnResponseEntry(messagepkg.Message{
|
||||
Role: "tool",
|
||||
Content: modelMessage,
|
||||
})
|
||||
if !ok {
|
||||
t.Fatal("expected tool role entry")
|
||||
}
|
||||
if !strings.Contains(entry.Content, `<tool_result id="call-1" name="web_search">`) {
|
||||
t.Fatalf("missing tool_result tag: %q", entry.Content)
|
||||
}
|
||||
if !strings.Contains(entry.Content, `"count":3`) || !strings.Contains(entry.Content, `"summary":"ok"`) {
|
||||
t.Fatalf("structured tool output not preserved: %q", entry.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeTurnResponseEntryToolRoleLegacyEnvelope(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Old OpenAI-style: role=tool + ToolCallID on the envelope, Content is
|
||||
// a JSON string carrying the result directly.
|
||||
resultBody := json.RawMessage(`{"status":"ok"}`)
|
||||
modelMessage, err := json.Marshal(conversation.ModelMessage{
|
||||
Role: "tool",
|
||||
ToolCallID: "call-99",
|
||||
Name: "ping",
|
||||
Content: resultBody,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("marshal model message: %v", err)
|
||||
}
|
||||
|
||||
entry, ok := DecodeTurnResponseEntry(messagepkg.Message{
|
||||
Role: "tool",
|
||||
Content: modelMessage,
|
||||
})
|
||||
if !ok {
|
||||
t.Fatal("expected entry for legacy tool envelope")
|
||||
}
|
||||
if !strings.Contains(entry.Content, `<tool_result id="call-99" name="ping">`) {
|
||||
t.Fatalf("missing tool_result tag: %q", entry.Content)
|
||||
}
|
||||
if !strings.Contains(entry.Content, `"status":"ok"`) {
|
||||
t.Fatalf("legacy tool body missing: %q", entry.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeTurnResponseEntrySkipsEmpty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Only reasoning → nothing to expose to future prompts → skip.
|
||||
content, err := json.Marshal([]map[string]any{
|
||||
{"type": "reasoning", "text": "thinking out loud"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("marshal content: %v", err)
|
||||
}
|
||||
modelMessage, err := json.Marshal(conversation.ModelMessage{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("marshal model message: %v", err)
|
||||
}
|
||||
if _, ok := DecodeTurnResponseEntry(messagepkg.Message{
|
||||
Role: "assistant",
|
||||
Content: modelMessage,
|
||||
}); ok {
|
||||
t.Fatal("expected reasoning-only message to be skipped")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeTurnResponseEntryLegacyToolCallsField(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Older OpenAI envelope: Content is empty string, ToolCalls carries
|
||||
// the function-call structure.
|
||||
modelMessage, err := json.Marshal(conversation.ModelMessage{
|
||||
Role: "assistant",
|
||||
Content: json.RawMessage(`""`),
|
||||
ToolCalls: []conversation.ToolCall{
|
||||
{
|
||||
ID: "call-legacy",
|
||||
Type: "function",
|
||||
Function: conversation.ToolCallFunction{
|
||||
Name: "send",
|
||||
Arguments: `{"text":"hi"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("marshal model message: %v", err)
|
||||
}
|
||||
entry, ok := DecodeTurnResponseEntry(messagepkg.Message{
|
||||
Role: "assistant",
|
||||
Content: modelMessage,
|
||||
})
|
||||
if !ok {
|
||||
t.Fatal("expected legacy tool-calls envelope to decode")
|
||||
}
|
||||
if !strings.Contains(entry.Content, `<tool_call id="call-legacy" name="send">`) {
|
||||
t.Fatalf("missing tool_call tag: %q", entry.Content)
|
||||
}
|
||||
if !strings.Contains(entry.Content, `"text":"hi"`) {
|
||||
t.Fatalf("arguments missing: %q", entry.Content)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user