Skip to content

Commit 0edd895

Browse files
mahadzaryab1bitfield
andauthoredSep 2, 2024
fix data race on p.stderr (#209)
Co-authored-by: John Arundel <john@bitfieldconsulting.com>
1 parent 0daf4b2 commit 0edd895

File tree

3 files changed

+38
-14
lines changed

3 files changed

+38
-14
lines changed
 

‎.github/workflows/ci.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ jobs:
1313
with:
1414
go-version: ${{ matrix.go-version }}
1515
- uses: actions/checkout@v3
16-
- run: go test ./...
16+
- run: go test -race ./...
1717

1818
gocritic:
1919
runs-on: ubuntu-latest

‎script.go

+29-13
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,14 @@ import (
2828
// Pipe represents a pipe object with an associated [ReadAutoCloser].
2929
type Pipe struct {
3030
// Reader is the underlying reader.
31-
Reader ReadAutoCloser
32-
stdout, stderr io.Writer
33-
httpClient *http.Client
31+
Reader ReadAutoCloser
32+
stdout io.Writer
33+
httpClient *http.Client
3434

35-
// because pipe stages are concurrent, protect 'err'
36-
mu *sync.Mutex
37-
err error
35+
// because pipe stages are concurrent, protect 'err' and 'stderr'
36+
mu *sync.Mutex
37+
err error
38+
stderr io.Writer
3839
}
3940

4041
// Args creates a pipe containing the program's command-line arguments from
@@ -414,8 +415,9 @@ func (p *Pipe) Exec(cmdLine string) *Pipe {
414415
cmd.Stdin = r
415416
cmd.Stdout = w
416417
cmd.Stderr = w
417-
if p.stderr != nil {
418-
cmd.Stderr = p.stderr
418+
pipeStderr := p.stdErr()
419+
if pipeStderr != nil {
420+
cmd.Stderr = pipeStderr
419421
}
420422
err = cmd.Start()
421423
if err != nil {
@@ -454,8 +456,9 @@ func (p *Pipe) ExecForEach(cmdLine string) *Pipe {
454456
cmd := exec.Command(args[0], args[1:]...)
455457
cmd.Stdout = w
456458
cmd.Stderr = w
457-
if p.stderr != nil {
458-
cmd.Stderr = p.stderr
459+
pipeStderr := p.stdErr()
460+
if pipeStderr != nil {
461+
cmd.Stderr = pipeStderr
459462
}
460463
err = cmd.Start()
461464
if err != nil {
@@ -839,6 +842,18 @@ func (p *Pipe) Slice() ([]string, error) {
839842
return result, p.Error()
840843
}
841844

845+
// stdErr returns the pipe's configured standard error writer for commands run
846+
// via [Pipe.Exec] and [Pipe.ExecForEach]. The default is nil, which means that
847+
// error output will go to the pipe.
848+
func (p *Pipe) stdErr() io.Writer {
849+
if p.mu == nil { // uninitialised pipe
850+
return nil
851+
}
852+
p.mu.Lock()
853+
defer p.mu.Unlock()
854+
return p.stderr
855+
}
856+
842857
// Stdout copies the pipe's contents to its configured standard output (using
843858
// [Pipe.WithStdout]), or to [os.Stdout] otherwise, and returns the number of
844859
// bytes successfully written, together with any error.
@@ -913,10 +928,11 @@ func (p *Pipe) WithReader(r io.Reader) *Pipe {
913928
return p
914929
}
915930

916-
// WithStderr redirects the standard error output for commands run via
917-
// [Pipe.Exec] or [Pipe.ExecForEach] to the writer w, instead of going to the
918-
// pipe as it normally would.
931+
// WithStderr sets the standard error output for [Pipe.Exec] or
932+
// [Pipe.ExecForEach] commands to w, instead of the pipe.
919933
func (p *Pipe) WithStderr(w io.Writer) *Pipe {
934+
p.mu.Lock()
935+
defer p.mu.Unlock()
920936
p.stderr = w
921937
return p
922938
}

‎script_test.go

+8
Original file line numberDiff line numberDiff line change
@@ -1971,6 +1971,14 @@ func TestEncodeBase64_CorrectlyEncodesInputBytes(t *testing.T) {
19711971
}
19721972
}
19731973

1974+
func TestWithStdErr_IsConcurrencySafeAfterExec(t *testing.T) {
1975+
t.Parallel()
1976+
err := script.Exec("echo").WithStderr(nil).Wait()
1977+
if err != nil {
1978+
t.Fatal(err)
1979+
}
1980+
}
1981+
19741982
func ExampleArgs() {
19751983
script.Args().Stdout()
19761984
// prints command-line arguments

0 commit comments

Comments
 (0)
Please sign in to comment.