Skip to content

Commit bea18fc

Browse files
Proposed goroutine wrapper solution (#1157)
1 parent 5d28db8 commit bea18fc

5 files changed

+137
-51
lines changed

async.go

+6-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,12 @@ func (sr *snowflakeRestful) processAsync(
3737
}
3838

3939
// spawn goroutine to retrieve asynchronous results
40-
go sr.getAsync(ctx, headers, sr.getFullURL(respd.Data.GetResultURL, nil), timeout, res, rows, cfg)
40+
go GoroutineWrapper(
41+
ctx,
42+
func() {
43+
sr.getAsync(ctx, headers, sr.getFullURL(respd.Data.GetResultURL, nil), timeout, res, rows, cfg)
44+
},
45+
)
4146
return respd, nil
4247
}
4348

authexternalbrowser.go

+6-3
Original file line numberDiff line numberDiff line change
@@ -211,9 +211,12 @@ func authenticateByExternalBrowser(
211211
disableConsoleLogin ConfigBool,
212212
) ([]byte, []byte, error) {
213213
resultChan := make(chan authenticateByExternalBrowserResult, 1)
214-
go func() {
215-
resultChan <- doAuthenticateByExternalBrowser(ctx, sr, authenticator, application, account, user, password, disableConsoleLogin)
216-
}()
214+
go GoroutineWrapper(
215+
ctx,
216+
func() {
217+
resultChan <- doAuthenticateByExternalBrowser(ctx, sr, authenticator, application, account, user, password, disableConsoleLogin)
218+
},
219+
)
217220
select {
218221
case <-time.After(externalBrowserTimeout):
219222
return nil, nil, errors.New("authentication timed out")

chunk_downloader.go

+60-47
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,12 @@ func (scd *snowflakeChunkDownloader) schedule() {
148148
select {
149149
case nextIdx := <-scd.ChunksChan:
150150
logger.WithContext(scd.ctx).Infof("schedule chunk: %v", nextIdx+1)
151-
go scd.FuncDownload(scd.ctx, scd, nextIdx)
151+
go GoroutineWrapper(
152+
scd.ctx,
153+
func() {
154+
scd.FuncDownload(scd.ctx, scd, nextIdx)
155+
},
156+
)
152157
default:
153158
// no more download
154159
logger.WithContext(scd.ctx).Info("no more download")
@@ -162,7 +167,12 @@ func (scd *snowflakeChunkDownloader) checkErrorRetry() (err error) {
162167
errc.Error != context.Canceled &&
163168
errc.Error != context.DeadlineExceeded {
164169
// add the index to the chunks channel so that the download will be retried.
165-
go scd.FuncDownload(scd.ctx, scd, errc.Index)
170+
go GoroutineWrapper(
171+
scd.ctx,
172+
func() {
173+
scd.FuncDownload(scd.ctx, scd, errc.Index)
174+
},
175+
)
166176
scd.ChunksErrorCounter++
167177
logger.WithContext(scd.ctx).Warningf("chunk idx: %v, err: %v. retrying (%v/%v)...",
168178
errc.Index, errc.Error, scd.ChunksErrorCounter, maxChunkDownloaderErrorCounter)
@@ -508,55 +518,58 @@ func (scd *streamChunkDownloader) nextResultSet() error {
508518
}
509519

510520
func (scd *streamChunkDownloader) start() error {
511-
go func() {
512-
readErr := io.EOF
513-
514-
logger.WithContext(scd.ctx).Infof(
515-
"start downloading. downloader id: %v, %v/%v rows, %v chunks",
516-
scd.id, len(scd.RowSet.RowType), scd.Total, len(scd.ChunkMetas))
517-
t := time.Now()
518-
519-
defer func() {
520-
if readErr == io.EOF {
521-
logger.WithContext(scd.ctx).Infof("downloading done. downloader id: %v", scd.id)
522-
} else {
523-
logger.WithContext(scd.ctx).Debugf("downloading error. downloader id: %v", scd.id)
524-
}
525-
scd.readErr = readErr
526-
close(scd.rowStream)
527-
528-
if r := recover(); r != nil {
529-
if err, ok := r.(error); ok {
530-
readErr = err
521+
go GoroutineWrapper(
522+
scd.ctx,
523+
func() {
524+
readErr := io.EOF
525+
526+
logger.WithContext(scd.ctx).Infof(
527+
"start downloading. downloader id: %v, %v/%v rows, %v chunks",
528+
scd.id, len(scd.RowSet.RowType), scd.Total, len(scd.ChunkMetas))
529+
t := time.Now()
530+
531+
defer func() {
532+
if readErr == io.EOF {
533+
logger.WithContext(scd.ctx).Infof("downloading done. downloader id: %v", scd.id)
531534
} else {
532-
readErr = fmt.Errorf("%v", r)
535+
logger.WithContext(scd.ctx).Debugf("downloading error. downloader id: %v", scd.id)
536+
}
537+
scd.readErr = readErr
538+
close(scd.rowStream)
539+
540+
if r := recover(); r != nil {
541+
if err, ok := r.(error); ok {
542+
readErr = err
543+
} else {
544+
readErr = fmt.Errorf("%v", r)
545+
}
533546
}
547+
}()
548+
549+
logger.WithContext(scd.ctx).Infof("sending initial set of rows in %vms", time.Since(t).Microseconds())
550+
t = time.Now()
551+
for _, row := range scd.RowSet.JSON {
552+
scd.rowStream <- row
534553
}
535-
}()
536-
537-
logger.WithContext(scd.ctx).Infof("sending initial set of rows in %vms", time.Since(t).Microseconds())
538-
t = time.Now()
539-
for _, row := range scd.RowSet.JSON {
540-
scd.rowStream <- row
541-
}
542-
scd.RowSet.JSON = nil
543-
544-
// Download and parse one chunk at a time. The fetcher will send each
545-
// parsed row to the row stream. When an error occurs, the fetcher will
546-
// stop writing to the row stream so we can stop processing immediately
547-
for i, chunk := range scd.ChunkMetas {
548-
logger.WithContext(scd.ctx).Infof("starting chunk fetch %d (%d rows)", i, chunk.RowCount)
549-
if err := scd.fetcher.fetch(chunk.URL, scd.rowStream); err != nil {
550-
logger.WithContext(scd.ctx).Debugf(
551-
"failed chunk fetch %d: %#v, downloader id: %v, %v/%v rows, %v chunks",
552-
i, err, scd.id, len(scd.RowSet.RowType), scd.Total, len(scd.ChunkMetas))
553-
readErr = fmt.Errorf("chunk fetch: %w", err)
554-
break
554+
scd.RowSet.JSON = nil
555+
556+
// Download and parse one chunk at a time. The fetcher will send each
557+
// parsed row to the row stream. When an error occurs, the fetcher will
558+
// stop writing to the row stream so we can stop processing immediately
559+
for i, chunk := range scd.ChunkMetas {
560+
logger.WithContext(scd.ctx).Infof("starting chunk fetch %d (%d rows)", i, chunk.RowCount)
561+
if err := scd.fetcher.fetch(chunk.URL, scd.rowStream); err != nil {
562+
logger.WithContext(scd.ctx).Debugf(
563+
"failed chunk fetch %d: %#v, downloader id: %v, %v/%v rows, %v chunks",
564+
i, err, scd.id, len(scd.RowSet.RowType), scd.Total, len(scd.ChunkMetas))
565+
readErr = fmt.Errorf("chunk fetch: %w", err)
566+
break
567+
}
568+
logger.WithContext(scd.ctx).Infof("fetched chunk %d (%d rows) in %vms", i, chunk.RowCount, time.Since(t).Microseconds())
569+
t = time.Now()
555570
}
556-
logger.WithContext(scd.ctx).Infof("fetched chunk %d (%d rows) in %vms", i, chunk.RowCount, time.Since(t).Microseconds())
557-
t = time.Now()
558-
}
559-
}()
571+
},
572+
)
560573
return nil
561574
}
562575

function_wrapper_test.go

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package gosnowflake
2+
3+
import (
4+
"context"
5+
"sync"
6+
"testing"
7+
)
8+
9+
var (
10+
goWrapperCalled = false
11+
testGoRoutineWrapperLock sync.Mutex
12+
)
13+
14+
func setGoWrapperCalled(value bool) {
15+
testGoRoutineWrapperLock.Lock()
16+
defer testGoRoutineWrapperLock.Unlock()
17+
goWrapperCalled = value
18+
}
19+
func getGoWrapperCalled() bool {
20+
testGoRoutineWrapperLock.Lock()
21+
defer testGoRoutineWrapperLock.Unlock()
22+
return goWrapperCalled
23+
}
24+
25+
// this is the go wrapper function we are going to pass into GoroutineWrapper.
26+
// we will know that this has been called if the channel is closed
27+
var closeGoWrapperCalledChannel = func(ctx context.Context, f func()) {
28+
setGoWrapperCalled(true)
29+
f()
30+
}
31+
32+
func TestGoWrapper(t *testing.T) {
33+
runDBTest(t, func(dbt *DBTest) {
34+
oldGoroutineWrapper := GoroutineWrapper
35+
t.Cleanup(func() {
36+
GoroutineWrapper = oldGoroutineWrapper
37+
})
38+
39+
GoroutineWrapper = closeGoWrapperCalledChannel
40+
41+
ctx := WithAsyncMode(context.Background())
42+
rows := dbt.mustQueryContext(ctx, "SELECT 1")
43+
defer rows.Close()
44+
45+
assertTrueF(t, getGoWrapperCalled(), "channel should be closed, indicating our wrapper worked")
46+
})
47+
}

function_wrappers.go

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package gosnowflake
2+
3+
import "context"
4+
5+
// GoroutineWrapperFunc is used to wrap goroutines. This is useful if the caller wants
6+
// to recover panics, rather than letting panics cause a system crash. A suggestion would be to
7+
// use use the recover functionality, and log the panic as is most useful to you
8+
type GoroutineWrapperFunc func(ctx context.Context, f func())
9+
10+
// The default GoroutineWrapperFunc; this does nothing. With this default wrapper
11+
// panics will take down binary as expected
12+
var noopGoroutineWrapper = func(_ context.Context, f func()) {
13+
f()
14+
}
15+
16+
// GoroutineWrapper is used to hold the GoroutineWrapperFunc set by the client, or to
17+
// store the default goroutine wrapper which does nothing
18+
var GoroutineWrapper GoroutineWrapperFunc = noopGoroutineWrapper

0 commit comments

Comments
 (0)