diff --git a/cmd/bridge/server_test.go b/cmd/bridge/server_test.go new file mode 100644 index 00000000..834afac4 --- /dev/null +++ b/cmd/bridge/server_test.go @@ -0,0 +1,80 @@ +package main + +import ( + "context" + "testing" + + "google.golang.org/grpc/metadata" + "google.golang.org/protobuf/proto" + + pb "github.com/memohai/memoh/internal/workspace/bridgepb" +) + +type cancelOnStdoutExecStream struct { + ctx context.Context + cancel context.CancelFunc + + outputs []*pb.ExecOutput + canceled bool +} + +func newCancelOnStdoutExecStream() *cancelOnStdoutExecStream { + ctx, cancel := context.WithCancel(context.Background()) + return &cancelOnStdoutExecStream{ctx: ctx, cancel: cancel} +} + +func (s *cancelOnStdoutExecStream) Send(msg *pb.ExecOutput) error { + clone := proto.Clone(msg).(*pb.ExecOutput) + if len(msg.GetData()) > 0 { + clone.Data = append([]byte(nil), msg.GetData()...) + } + s.outputs = append(s.outputs, clone) + if !s.canceled && msg.GetStream() == pb.ExecOutput_STDOUT && len(msg.GetData()) > 0 { + s.canceled = true + s.cancel() + } + return nil +} + +func (s *cancelOnStdoutExecStream) Recv() (*pb.ExecInput, error) { + <-s.ctx.Done() + return nil, s.ctx.Err() +} + +func (s *cancelOnStdoutExecStream) Context() context.Context { return s.ctx } +func (*cancelOnStdoutExecStream) SetHeader(metadata.MD) error { return nil } +func (*cancelOnStdoutExecStream) SendHeader(metadata.MD) error { return nil } +func (*cancelOnStdoutExecStream) SetTrailer(metadata.MD) {} +func (*cancelOnStdoutExecStream) SendMsg(any) error { return nil } +func (*cancelOnStdoutExecStream) RecvMsg(any) error { return nil } + +func TestExecPipePreservesExitCodeAcrossStreamCancellation(t *testing.T) { + stream := newCancelOnStdoutExecStream() + + err := execPipe(stream, &pb.ExecInput{ + Command: "printf ok; sleep 0.2", + WorkDir: "/tmp", + TimeoutSeconds: 5, + }) + if err != nil { + t.Fatalf("execPipe returned error: %v", err) + } + + var stdout string + var exitCode int32 = -999 + for _, output := range stream.outputs { + switch output.GetStream() { + case pb.ExecOutput_STDOUT: + stdout += string(output.GetData()) + case pb.ExecOutput_EXIT: + exitCode = output.GetExitCode() + } + } + + if stdout != "ok" { + t.Fatalf("stdout = %q, want %q", stdout, "ok") + } + if exitCode != 0 { + t.Fatalf("exit code = %d, want 0", exitCode) + } +}