diff --git a/cmd/bridge/server.go b/cmd/bridge/server.go index 164d9048..abeb1c6c 100644 --- a/cmd/bridge/server.go +++ b/cmd/bridge/server.go @@ -361,7 +361,10 @@ func execPipe(stream pb.ContainerService_ExecServer, firstMsg *pb.ExecInput) err timeout = defaultTimeout } - ctx, cancel := context.WithTimeout(stream.Context(), time.Duration(timeout)*time.Second) + // Keep non-PTY execs alive across transport cancellation so a dropped + // stream does not rewrite a successful command into exit -1. The timeout + // still bounds command lifetime, and stream shutdown still closes stdin. + ctx, cancel := context.WithTimeout(context.WithoutCancel(stream.Context()), time.Duration(timeout)*time.Second) defer cancel() cmd := exec.CommandContext(ctx, "/bin/sh", "-c", command) //nolint:gosec // G204: MCP exec tool intentionally executes agent-issued shell commands inside the container diff --git a/cmd/bridge/server_test.go b/cmd/bridge/server_test.go new file mode 100644 index 00000000..595c2dc0 --- /dev/null +++ b/cmd/bridge/server_test.go @@ -0,0 +1,79 @@ +package main + +import ( + "context" + "testing" + + "google.golang.org/grpc/metadata" + + pb "github.com/memohai/memoh/internal/workspace/bridgepb" +) + +type cancelOnStdoutExecStream struct { + ctx context.Context + cancel context.CancelFunc + + outputs []*pb.ExecOutput + canceled bool +} + +func newCancelOnStdoutExecStream() *cancelOnStdoutExecStream { + ctx, cancel := context.WithCancel(context.Background()) + return &cancelOnStdoutExecStream{ctx: ctx, cancel: cancel} +} + +func (s *cancelOnStdoutExecStream) Send(msg *pb.ExecOutput) error { + clone := *msg + if len(msg.GetData()) > 0 { + clone.Data = append([]byte(nil), msg.GetData()...) + } + s.outputs = append(s.outputs, &clone) + if !s.canceled && msg.GetStream() == pb.ExecOutput_STDOUT && len(msg.GetData()) > 0 { + s.canceled = true + s.cancel() + } + return nil +} + +func (s *cancelOnStdoutExecStream) Recv() (*pb.ExecInput, error) { + <-s.ctx.Done() + return nil, s.ctx.Err() +} + +func (s *cancelOnStdoutExecStream) Context() context.Context { return s.ctx } +func (*cancelOnStdoutExecStream) SetHeader(metadata.MD) error { return nil } +func (*cancelOnStdoutExecStream) SendHeader(metadata.MD) error { return nil } +func (*cancelOnStdoutExecStream) SetTrailer(metadata.MD) {} +func (*cancelOnStdoutExecStream) SendMsg(any) error { return nil } +func (*cancelOnStdoutExecStream) RecvMsg(any) error { return nil } + +func TestExecPipePreservesExitCodeAcrossStreamCancellation(t *testing.T) { + stream := newCancelOnStdoutExecStream() + + err := execPipe(stream, &pb.ExecInput{ + Command: "printf ok; sleep 0.2", + WorkDir: "/tmp", + TimeoutSeconds: 5, + }) + if err != nil { + t.Fatalf("execPipe returned error: %v", err) + } + + var stdout string + var exitCode int32 = -999 + for _, output := range stream.outputs { + switch output.GetStream() { + case pb.ExecOutput_STDOUT: + stdout += string(output.GetData()) + case pb.ExecOutput_EXIT: + exitCode = output.GetExitCode() + } + } + + if stdout != "ok" { + t.Fatalf("stdout = %q, want %q", stdout, "ok") + } + if exitCode != 0 { + t.Fatalf("exit code = %d, want 0", exitCode) + } +}