Skip to content

fix(race-conditon): fix data race condition when accessing stderr on pipe #209

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Sep 2, 2024
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
with:
go-version: ${{ matrix.go-version }}
- uses: actions/checkout@v3
- run: go test ./...
- run: go test -race ./...

gocritic:
runs-on: ubuntu-latest
Expand Down
42 changes: 29 additions & 13 deletions script.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,14 @@ import (
// Pipe represents a pipe object with an associated [ReadAutoCloser].
type Pipe struct {
// Reader is the underlying reader.
Reader ReadAutoCloser
stdout, stderr io.Writer
httpClient *http.Client
Reader ReadAutoCloser
stdout io.Writer
httpClient *http.Client

// because pipe stages are concurrent, protect 'err'
mu *sync.Mutex
err error
// because pipe stages are concurrent, protect 'err' and 'stderr'
mu *sync.Mutex
err error
stderr io.Writer
}

// Args creates a pipe containing the program's command-line arguments from
Expand Down Expand Up @@ -414,8 +415,9 @@ func (p *Pipe) Exec(cmdLine string) *Pipe {
cmd.Stdin = r
cmd.Stdout = w
cmd.Stderr = w
if p.stderr != nil {
cmd.Stderr = p.stderr
pipeStderr := p.stdErr()
if pipeStderr != nil {
cmd.Stderr = pipeStderr
}
err = cmd.Start()
if err != nil {
Expand Down Expand Up @@ -454,8 +456,9 @@ func (p *Pipe) ExecForEach(cmdLine string) *Pipe {
cmd := exec.Command(args[0], args[1:]...)
cmd.Stdout = w
cmd.Stderr = w
if p.stderr != nil {
cmd.Stderr = p.stderr
pipeStderr := p.stdErr()
if pipeStderr != nil {
cmd.Stderr = pipeStderr
}
err = cmd.Start()
if err != nil {
Expand Down Expand Up @@ -839,6 +842,18 @@ func (p *Pipe) Slice() ([]string, error) {
return result, p.Error()
}

// stdErr returns the pipe's configured standard error writer for commands run
// via [Pipe.Exec] and [Pipe.ExecForEach]. The default is nil, which means that
// error output will go to the pipe.
func (p *Pipe) stdErr() io.Writer {
if p.mu == nil { // uninitialised pipe
return nil
}
p.mu.Lock()
defer p.mu.Unlock()
return p.stderr
}

// Stdout copies the pipe's contents to its configured standard output (using
// [Pipe.WithStdout]), or to [os.Stdout] otherwise, and returns the number of
// bytes successfully written, together with any error.
Expand Down Expand Up @@ -913,10 +928,11 @@ func (p *Pipe) WithReader(r io.Reader) *Pipe {
return p
}

// WithStderr redirects the standard error output for commands run via
// [Pipe.Exec] or [Pipe.ExecForEach] to the writer w, instead of going to the
// pipe as it normally would.
// WithStderr sets the standard error output for [Pipe.Exec] or
// [Pipe.ExecForEach] commands to w, instead of the pipe.
func (p *Pipe) WithStderr(w io.Writer) *Pipe {
p.mu.Lock()
defer p.mu.Unlock()
p.stderr = w
return p
}
Expand Down
8 changes: 8 additions & 0 deletions script_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1971,6 +1971,14 @@ func TestEncodeBase64_CorrectlyEncodesInputBytes(t *testing.T) {
}
}

func TestWithStdErr_IsConcurrencySafeAfterExec(t *testing.T) {
t.Parallel()
err := script.Exec("echo").WithStderr(nil).Wait()
if err != nil {
t.Fatal(err)
}
}

func ExampleArgs() {
script.Args().Stdout()
// prints command-line arguments
Expand Down
Loading