mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
81 lines
2.1 KiB
Go
81 lines
2.1 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"testing"
|
|
|
|
"google.golang.org/grpc/metadata"
|
|
"google.golang.org/protobuf/proto"
|
|
|
|
pb "github.com/memohai/memoh/internal/workspace/bridgepb"
|
|
)
|
|
|
|
type cancelOnStdoutExecStream struct {
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
|
|
outputs []*pb.ExecOutput
|
|
canceled bool
|
|
}
|
|
|
|
func newCancelOnStdoutExecStream() *cancelOnStdoutExecStream {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
return &cancelOnStdoutExecStream{ctx: ctx, cancel: cancel}
|
|
}
|
|
|
|
func (s *cancelOnStdoutExecStream) Send(msg *pb.ExecOutput) error {
|
|
clone := proto.Clone(msg).(*pb.ExecOutput)
|
|
if len(msg.GetData()) > 0 {
|
|
clone.Data = append([]byte(nil), msg.GetData()...)
|
|
}
|
|
s.outputs = append(s.outputs, clone)
|
|
if !s.canceled && msg.GetStream() == pb.ExecOutput_STDOUT && len(msg.GetData()) > 0 {
|
|
s.canceled = true
|
|
s.cancel()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *cancelOnStdoutExecStream) Recv() (*pb.ExecInput, error) {
|
|
<-s.ctx.Done()
|
|
return nil, s.ctx.Err()
|
|
}
|
|
|
|
func (s *cancelOnStdoutExecStream) Context() context.Context { return s.ctx }
|
|
func (*cancelOnStdoutExecStream) SetHeader(metadata.MD) error { return nil }
|
|
func (*cancelOnStdoutExecStream) SendHeader(metadata.MD) error { return nil }
|
|
func (*cancelOnStdoutExecStream) SetTrailer(metadata.MD) {}
|
|
func (*cancelOnStdoutExecStream) SendMsg(any) error { return nil }
|
|
func (*cancelOnStdoutExecStream) RecvMsg(any) error { return nil }
|
|
|
|
func TestExecPipePreservesExitCodeAcrossStreamCancellation(t *testing.T) {
|
|
stream := newCancelOnStdoutExecStream()
|
|
|
|
err := execPipe(stream, &pb.ExecInput{
|
|
Command: "printf ok; sleep 0.2",
|
|
WorkDir: "/tmp",
|
|
TimeoutSeconds: 5,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("execPipe returned error: %v", err)
|
|
}
|
|
|
|
var stdout string
|
|
var exitCode int32 = -999
|
|
for _, output := range stream.outputs {
|
|
switch output.GetStream() {
|
|
case pb.ExecOutput_STDOUT:
|
|
stdout += string(output.GetData())
|
|
case pb.ExecOutput_EXIT:
|
|
exitCode = output.GetExitCode()
|
|
}
|
|
}
|
|
|
|
if stdout != "ok" {
|
|
t.Fatalf("stdout = %q, want %q", stdout, "ok")
|
|
}
|
|
if exitCode != 0 {
|
|
t.Fatalf("exit code = %d, want 0", exitCode)
|
|
}
|
|
}
|