mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
178 lines
4.3 KiB
Go
178 lines
4.3 KiB
Go
package flow
|
|
|
|
import (
|
|
"strings"
|
|
|
|
sdk "github.com/memohai/twilight-ai/sdk"
|
|
|
|
"github.com/memohai/memoh/internal/conversation"
|
|
)
|
|
|
|
const syntheticToolClosureError = "tool execution interrupted before a response was recorded"
|
|
|
|
type pendingToolCall struct {
|
|
ID string
|
|
ToolName string
|
|
}
|
|
|
|
func repairToolCallClosures(messages []conversation.ModelMessage, reason string) []conversation.ModelMessage {
|
|
if len(messages) == 0 {
|
|
return messages
|
|
}
|
|
if strings.TrimSpace(reason) == "" {
|
|
reason = syntheticToolClosureError
|
|
}
|
|
|
|
repaired := make([]conversation.ModelMessage, 0, len(messages))
|
|
pending := make(map[string]pendingToolCall)
|
|
pendingOrder := make([]string, 0)
|
|
|
|
flushPending := func() {
|
|
if len(pendingOrder) == 0 {
|
|
return
|
|
}
|
|
for _, callID := range pendingOrder {
|
|
call, ok := pending[callID]
|
|
if !ok {
|
|
continue
|
|
}
|
|
repaired = append(repaired, syntheticToolResultMessage(call.ID, call.ToolName, reason))
|
|
delete(pending, callID)
|
|
}
|
|
pendingOrder = pendingOrder[:0]
|
|
}
|
|
|
|
for _, msg := range messages {
|
|
role := strings.ToLower(strings.TrimSpace(msg.Role))
|
|
switch role {
|
|
case "assistant":
|
|
if len(pending) > 0 {
|
|
flushPending()
|
|
}
|
|
repaired = append(repaired, msg)
|
|
for _, call := range extractAssistantToolCallParts(msg) {
|
|
callID := strings.TrimSpace(call.ToolCallID)
|
|
if callID == "" {
|
|
continue
|
|
}
|
|
if _, exists := pending[callID]; exists {
|
|
continue
|
|
}
|
|
pending[callID] = pendingToolCall{
|
|
ID: callID,
|
|
ToolName: strings.TrimSpace(call.ToolName),
|
|
}
|
|
pendingOrder = append(pendingOrder, callID)
|
|
}
|
|
|
|
case "tool":
|
|
filtered := filterToolMessageToPending(msg, pending)
|
|
if filtered == nil {
|
|
continue
|
|
}
|
|
repaired = append(repaired, *filtered)
|
|
for _, result := range extractToolResultParts(*filtered) {
|
|
delete(pending, strings.TrimSpace(result.ToolCallID))
|
|
}
|
|
if len(pending) == 0 && len(pendingOrder) > 0 {
|
|
pendingOrder = pendingOrder[:0]
|
|
continue
|
|
}
|
|
if len(pendingOrder) > 0 {
|
|
nextOrder := pendingOrder[:0]
|
|
for _, callID := range pendingOrder {
|
|
if _, ok := pending[callID]; ok {
|
|
nextOrder = append(nextOrder, callID)
|
|
}
|
|
}
|
|
pendingOrder = nextOrder
|
|
}
|
|
|
|
default:
|
|
if len(pending) > 0 {
|
|
flushPending()
|
|
}
|
|
repaired = append(repaired, msg)
|
|
}
|
|
}
|
|
|
|
if len(pending) > 0 {
|
|
flushPending()
|
|
}
|
|
return repaired
|
|
}
|
|
|
|
func extractAssistantToolCallParts(msg conversation.ModelMessage) []sdk.ToolCallPart {
|
|
sdkMsg := modelMessageToSDKMessage(msg)
|
|
if len(sdkMsg.Content) == 0 {
|
|
return nil
|
|
}
|
|
calls := make([]sdk.ToolCallPart, 0, len(sdkMsg.Content))
|
|
for _, part := range sdkMsg.Content {
|
|
call, ok := part.(sdk.ToolCallPart)
|
|
if !ok {
|
|
continue
|
|
}
|
|
calls = append(calls, call)
|
|
}
|
|
return calls
|
|
}
|
|
|
|
func extractToolResultParts(msg conversation.ModelMessage) []sdk.ToolResultPart {
|
|
sdkMsg := modelMessageToSDKMessage(msg)
|
|
if len(sdkMsg.Content) == 0 {
|
|
return nil
|
|
}
|
|
results := make([]sdk.ToolResultPart, 0, len(sdkMsg.Content))
|
|
for _, part := range sdkMsg.Content {
|
|
result, ok := part.(sdk.ToolResultPart)
|
|
if !ok {
|
|
continue
|
|
}
|
|
results = append(results, result)
|
|
}
|
|
return results
|
|
}
|
|
|
|
func filterToolMessageToPending(msg conversation.ModelMessage, pending map[string]pendingToolCall) *conversation.ModelMessage {
|
|
results := extractToolResultParts(msg)
|
|
if len(results) == 0 {
|
|
return nil
|
|
}
|
|
|
|
filtered := make([]sdk.ToolResultPart, 0, len(results))
|
|
for _, result := range results {
|
|
if _, ok := pending[strings.TrimSpace(result.ToolCallID)]; !ok {
|
|
continue
|
|
}
|
|
filtered = append(filtered, result)
|
|
}
|
|
if len(filtered) == 0 {
|
|
return nil
|
|
}
|
|
|
|
converted := sdkMessagesToModelMessages([]sdk.Message{sdk.ToolMessage(filtered...)})
|
|
if len(converted) == 0 {
|
|
return nil
|
|
}
|
|
filteredMsg := converted[0]
|
|
filteredMsg.Usage = msg.Usage
|
|
return &filteredMsg
|
|
}
|
|
|
|
func syntheticToolResultMessage(toolCallID, toolName, reason string) conversation.ModelMessage {
|
|
converted := sdkMessagesToModelMessages([]sdk.Message{sdk.ToolMessage(sdk.ToolResultPart{
|
|
ToolCallID: strings.TrimSpace(toolCallID),
|
|
ToolName: strings.TrimSpace(toolName),
|
|
Result: strings.TrimSpace(reason),
|
|
IsError: true,
|
|
})})
|
|
if len(converted) == 0 {
|
|
return conversation.ModelMessage{
|
|
Role: "tool",
|
|
Content: conversation.NewTextContent(strings.TrimSpace(reason)),
|
|
}
|
|
}
|
|
return converted[0]
|
|
}
|