mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
386 lines
11 KiB
Go
386 lines
11 KiB
Go
package handlers
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
sdkjsonrpc "github.com/modelcontextprotocol/go-sdk/jsonrpc"
|
|
|
|
mcptools "github.com/memohai/memoh/internal/mcp"
|
|
)
|
|
|
|
// fakeMCPConnection implements sdkmcp.Connection for testing.
|
|
// onWrite is called synchronously when Write is called; if it returns a
|
|
// non-nil Response the response is queued to be returned by Read.
|
|
type fakeMCPConnection struct {
|
|
mu sync.Mutex
|
|
writes []*sdkjsonrpc.Request
|
|
readCh chan sdkjsonrpc.Message
|
|
closed chan struct{}
|
|
closeMu sync.Once
|
|
onWrite func(req *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error)
|
|
}
|
|
|
|
func newFakeMCPConnection(onWrite func(req *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error)) *fakeMCPConnection {
|
|
return &fakeMCPConnection{
|
|
writes: make([]*sdkjsonrpc.Request, 0, 16),
|
|
readCh: make(chan sdkjsonrpc.Message, 32),
|
|
closed: make(chan struct{}),
|
|
onWrite: onWrite,
|
|
}
|
|
}
|
|
|
|
func (c *fakeMCPConnection) Read(ctx context.Context) (sdkjsonrpc.Message, error) {
|
|
select {
|
|
case <-c.closed:
|
|
return nil, io.EOF
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
case msg, ok := <-c.readCh:
|
|
if !ok {
|
|
return nil, io.EOF
|
|
}
|
|
return msg, nil
|
|
}
|
|
}
|
|
|
|
func (c *fakeMCPConnection) Write(ctx context.Context, msg sdkjsonrpc.Message) error {
|
|
req, ok := msg.(*sdkjsonrpc.Request)
|
|
if !ok {
|
|
return fmt.Errorf("unsupported message type: %T", msg)
|
|
}
|
|
cloned := cloneJSONRPCRequest(req)
|
|
c.mu.Lock()
|
|
c.writes = append(c.writes, cloned)
|
|
c.mu.Unlock()
|
|
|
|
if c.onWrite == nil {
|
|
return nil
|
|
}
|
|
resp, err := c.onWrite(cloned)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if resp == nil {
|
|
return nil
|
|
}
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case <-c.closed:
|
|
return io.EOF
|
|
case c.readCh <- resp:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func (c *fakeMCPConnection) Close() error {
|
|
c.closeMu.Do(func() {
|
|
close(c.closed)
|
|
close(c.readCh)
|
|
})
|
|
return nil
|
|
}
|
|
|
|
// SessionID returns a fixed identifier; the fake does not track real sessions.
func (*fakeMCPConnection) SessionID() string {
	const id = "test-session"
	return id
}
|
|
|
|
func cloneJSONRPCRequest(req *sdkjsonrpc.Request) *sdkjsonrpc.Request {
|
|
if req == nil {
|
|
return nil
|
|
}
|
|
params := append([]byte(nil), req.Params...)
|
|
return &sdkjsonrpc.Request{
|
|
ID: req.ID,
|
|
Method: req.Method,
|
|
Params: params,
|
|
Extra: req.Extra,
|
|
}
|
|
}
|
|
|
|
func jsonRPCSuccessResponse(id sdkjsonrpc.ID, payload map[string]any) *sdkjsonrpc.Response {
|
|
body, _ := json.Marshal(payload)
|
|
return &sdkjsonrpc.Response{ID: id, Result: body}
|
|
}
|
|
|
|
func newTestMCPSession(conn *fakeMCPConnection) *mcpSession {
|
|
readCtx, cancelRead := context.WithCancel(context.Background()) //nolint:gosec // G118: cancelRead is stored in mcpSession.cancelRead
|
|
return &mcpSession{
|
|
pending: map[string]chan *sdkjsonrpc.Response{},
|
|
conn: conn,
|
|
closed: make(chan struct{}),
|
|
readCtx: readCtx,
|
|
cancelRead: cancelRead,
|
|
}
|
|
}
|
|
|
|
// --- Tests ---
|
|
|
|
// TestMCPSession_CallRaw_ResponseEnvelope verifies that callRaw returns a
|
|
// standard JSON-RPC envelope {"jsonrpc","id","result"}.
|
|
func TestMCPSession_CallRaw_ResponseEnvelope(t *testing.T) {
|
|
conn := newFakeMCPConnection(func(req *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error) {
|
|
return jsonRPCSuccessResponse(req.ID, map[string]any{"tools": []any{}}), nil
|
|
})
|
|
sess := newTestMCPSession(conn)
|
|
sess.initState = mcpSessionInitStateReady
|
|
go sess.readLoop()
|
|
defer sess.closeWithError(io.EOF)
|
|
|
|
payload, err := sess.call(context.Background(), mcptools.JSONRPCRequest{
|
|
JSONRPC: "2.0",
|
|
ID: mcptools.RawStringID("1"),
|
|
Method: "tools/list",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("call failed: %v", err)
|
|
}
|
|
|
|
// Verify standard JSON-RPC envelope.
|
|
if payload["jsonrpc"] != "2.0" {
|
|
t.Errorf("expected jsonrpc=2.0, got %v", payload["jsonrpc"])
|
|
}
|
|
if _, ok := payload["id"]; !ok {
|
|
t.Errorf("expected 'id' field in envelope, got %v", payload)
|
|
}
|
|
if _, ok := payload["result"]; !ok {
|
|
t.Errorf("expected 'result' field in envelope, got %v", payload)
|
|
}
|
|
if _, ok := payload["error"]; ok {
|
|
t.Errorf("unexpected 'error' field in success envelope")
|
|
}
|
|
}
|
|
|
|
// TestMCPSession_CallRaw_ErrorEnvelope verifies that server-side errors are
|
|
// returned as {"jsonrpc","id","error"} envelope, not a Go error.
|
|
func TestMCPSession_CallRaw_ErrorEnvelope(t *testing.T) {
|
|
conn := newFakeMCPConnection(func(req *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error) {
|
|
return &sdkjsonrpc.Response{
|
|
ID: req.ID,
|
|
Error: &sdkjsonrpc.Error{Code: -32601, Message: "Method not found"},
|
|
}, nil
|
|
})
|
|
sess := newTestMCPSession(conn)
|
|
sess.initState = mcpSessionInitStateReady
|
|
go sess.readLoop()
|
|
defer sess.closeWithError(io.EOF)
|
|
|
|
payload, err := sess.call(context.Background(), mcptools.JSONRPCRequest{
|
|
JSONRPC: "2.0",
|
|
ID: mcptools.RawStringID("2"),
|
|
Method: "unknown/method",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("unexpected Go error (server errors should be in envelope): %v", err)
|
|
}
|
|
errField, ok := payload["error"].(map[string]any)
|
|
if !ok {
|
|
t.Fatalf("expected 'error' field in envelope, got %v", payload)
|
|
}
|
|
if errField["code"] != int64(-32601) {
|
|
t.Errorf("unexpected error code: %v", errField["code"])
|
|
}
|
|
if _, ok := payload["result"]; ok {
|
|
t.Errorf("unexpected 'result' field in error envelope")
|
|
}
|
|
}
|
|
|
|
// TestMCPSession_InitializeRetryAfterFailure tests that the session retries
|
|
// the initialize handshake after the first attempt fails.
|
|
func TestMCPSession_InitializeRetryAfterFailure(t *testing.T) {
|
|
initCalls := 0
|
|
conn := newFakeMCPConnection(func(req *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error) {
|
|
switch req.Method {
|
|
case "initialize":
|
|
initCalls++
|
|
if initCalls == 1 {
|
|
return &sdkjsonrpc.Response{
|
|
ID: req.ID,
|
|
Error: &sdkjsonrpc.Error{Code: -32603, Message: "temporary init failure"},
|
|
}, nil
|
|
}
|
|
return jsonRPCSuccessResponse(req.ID, map[string]any{"protocolVersion": "2025-06-18"}), nil
|
|
case "tools/list":
|
|
return jsonRPCSuccessResponse(req.ID, map[string]any{"tools": []any{}}), nil
|
|
default:
|
|
return nil, nil
|
|
}
|
|
})
|
|
sess := newTestMCPSession(conn)
|
|
go sess.readLoop()
|
|
defer sess.closeWithError(io.EOF)
|
|
|
|
_, firstErr := sess.call(context.Background(), mcptools.JSONRPCRequest{
|
|
JSONRPC: "2.0",
|
|
ID: mcptools.RawStringID("1"),
|
|
Method: "tools/list",
|
|
})
|
|
if firstErr == nil {
|
|
t.Fatal("first call should fail when initialize fails")
|
|
}
|
|
|
|
secondPayload, secondErr := sess.call(context.Background(), mcptools.JSONRPCRequest{
|
|
JSONRPC: "2.0",
|
|
ID: mcptools.RawStringID("2"),
|
|
Method: "tools/list",
|
|
})
|
|
if secondErr != nil {
|
|
t.Fatalf("second call should recover by retrying initialize: %v", secondErr)
|
|
}
|
|
if initCalls != 2 {
|
|
t.Fatalf("initialize should be retried once, got calls: %d", initCalls)
|
|
}
|
|
result, ok := secondPayload["result"].(map[string]any)
|
|
if !ok {
|
|
t.Fatalf("missing tools/list result in envelope: %#v", secondPayload)
|
|
}
|
|
if _, ok := result["tools"].([]any); !ok {
|
|
t.Fatalf("missing tools field: %#v", result)
|
|
}
|
|
}
|
|
|
|
// TestMCPSession_ExplicitInitializeNoDoubling tests that sending an explicit
|
|
// "initialize" call does not cause the session to auto-initialize again.
|
|
func TestMCPSession_ExplicitInitializeNoDoubling(t *testing.T) {
|
|
initializeCalls := 0
|
|
initializedNotifications := 0
|
|
conn := newFakeMCPConnection(func(req *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error) {
|
|
switch req.Method {
|
|
case "initialize":
|
|
initializeCalls++
|
|
return jsonRPCSuccessResponse(req.ID, map[string]any{"protocolVersion": "2025-06-18"}), nil
|
|
case "notifications/initialized":
|
|
initializedNotifications++
|
|
return nil, nil
|
|
case "tools/list":
|
|
return jsonRPCSuccessResponse(req.ID, map[string]any{"tools": []any{}}), nil
|
|
default:
|
|
return nil, nil
|
|
}
|
|
})
|
|
sess := newTestMCPSession(conn)
|
|
go sess.readLoop()
|
|
defer sess.closeWithError(io.EOF)
|
|
|
|
_, initErr := sess.call(context.Background(), mcptools.JSONRPCRequest{
|
|
JSONRPC: "2.0",
|
|
ID: mcptools.RawStringID("100"),
|
|
Method: "initialize",
|
|
Params: json.RawMessage(`{"protocolVersion":"2025-06-18","capabilities":{},"clientInfo":{"name":"test","version":"v1"}}`),
|
|
})
|
|
if initErr != nil {
|
|
t.Fatalf("explicit initialize should succeed: %v", initErr)
|
|
}
|
|
|
|
_, listErr := sess.call(context.Background(), mcptools.JSONRPCRequest{
|
|
JSONRPC: "2.0",
|
|
ID: mcptools.RawStringID("101"),
|
|
Method: "tools/list",
|
|
})
|
|
if listErr != nil {
|
|
t.Fatalf("tools/list after initialize should succeed: %v", listErr)
|
|
}
|
|
if initializeCalls != 1 {
|
|
t.Fatalf("initialize should not be duplicated, got: %d", initializeCalls)
|
|
}
|
|
if initializedNotifications != 1 {
|
|
t.Fatalf("should send exactly one notifications/initialized, got: %d", initializedNotifications)
|
|
}
|
|
}
|
|
|
|
// TestMCPSession_PendingCleanupOnContextCancel tests that cancelling a request
|
|
// context removes it from the pending map.
|
|
func TestMCPSession_PendingCleanupOnContextCancel(t *testing.T) {
|
|
conn := newFakeMCPConnection(func(_ *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error) {
|
|
// Never reply — caller should time out.
|
|
return nil, nil
|
|
})
|
|
sess := newTestMCPSession(conn)
|
|
sess.initState = mcpSessionInitStateReady
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond)
|
|
defer cancel()
|
|
_, err := sess.call(ctx, mcptools.JSONRPCRequest{
|
|
JSONRPC: "2.0",
|
|
ID: mcptools.RawStringID("200"),
|
|
Method: "tools/list",
|
|
})
|
|
if err == nil {
|
|
t.Fatal("call should fail on context timeout")
|
|
}
|
|
|
|
sess.pendingMu.Lock()
|
|
pendingCount := len(sess.pending)
|
|
sess.pendingMu.Unlock()
|
|
if pendingCount != 0 {
|
|
t.Fatalf("pending map should be empty after cancellation, got: %d", pendingCount)
|
|
}
|
|
}
|
|
|
|
// TestMCPSession_PendingCleanupOnClose tests that closing the session drains
|
|
// all pending channels (callers unblock).
|
|
func TestMCPSession_PendingCleanupOnClose(t *testing.T) {
|
|
conn := newFakeMCPConnection(func(_ *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error) {
|
|
return nil, nil // never reply
|
|
})
|
|
sess := newTestMCPSession(conn)
|
|
sess.initState = mcpSessionInitStateReady
|
|
|
|
errCh := make(chan error, 1)
|
|
go func() {
|
|
_, err := sess.call(context.Background(), mcptools.JSONRPCRequest{
|
|
JSONRPC: "2.0",
|
|
ID: mcptools.RawStringID("300"),
|
|
Method: "tools/list",
|
|
})
|
|
errCh <- err
|
|
}()
|
|
|
|
// Give goroutine time to register in pending.
|
|
time.Sleep(10 * time.Millisecond)
|
|
sess.closeWithError(io.EOF)
|
|
|
|
select {
|
|
case err := <-errCh:
|
|
if err == nil {
|
|
t.Error("expected error after session close, got nil")
|
|
}
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("call did not unblock after session close")
|
|
}
|
|
|
|
sess.pendingMu.Lock()
|
|
pendingCount := len(sess.pending)
|
|
sess.pendingMu.Unlock()
|
|
if pendingCount != 0 {
|
|
t.Fatalf("pending map should be empty after close, got: %d", pendingCount)
|
|
}
|
|
}
|
|
|
|
// TestMCPSession_ReadLoopCancelOnClose tests that closing the session
|
|
// (which cancels readCtx) causes readLoop to exit.
|
|
func TestMCPSession_ReadLoopCancelOnClose(t *testing.T) {
|
|
conn := newFakeMCPConnection(nil)
|
|
sess := newTestMCPSession(conn)
|
|
|
|
loopDone := make(chan struct{})
|
|
go func() {
|
|
sess.readLoop()
|
|
close(loopDone)
|
|
}()
|
|
|
|
// Close the session; this should cancel readCtx and unblock readLoop.
|
|
sess.closeWithError(io.EOF)
|
|
|
|
select {
|
|
case <-loopDone:
|
|
// readLoop exited as expected.
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("readLoop did not exit after session close")
|
|
}
|
|
}
|