Skip to content

Commit 72a121f

Browse files
SNOW-1859664 use correct transport for calling cloud providers (#1288)
1 parent 7f77aea commit 72a121f

11 files changed

+106
-28
lines changed

azure_storage_client.go

+6-6
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ func (util *snowflakeAzureClient) createClient(info *execResponseStageInfo, _ bo
5252
RetryDelay: 2 * time.Second,
5353
},
5454
Transport: &http.Client{
55-
Transport: SnowflakeTransport,
55+
Transport: getTransport(util.cfg),
5656
},
5757
},
5858
})
@@ -74,7 +74,7 @@ func (util *snowflakeAzureClient) getFileHeader(meta *fileMetadata, filename str
7474
return nil, err
7575
}
7676
path := azureLoc.path + strings.TrimLeft(filename, "/")
77-
containerClient, err := createContainerClient(client.URL())
77+
containerClient, err := createContainerClient(client.URL(), util.cfg)
7878
if err != nil {
7979
return nil, &SnowflakeError{
8080
Message: "failed to create container client",
@@ -188,7 +188,7 @@ func (util *snowflakeAzureClient) uploadFile(
188188
Message: "failed to cast to azure client",
189189
}
190190
}
191-
containerClient, err := createContainerClient(client.URL())
191+
containerClient, err := createContainerClient(client.URL(), util.cfg)
192192

193193
if err != nil {
194194
return &SnowflakeError{
@@ -273,7 +273,7 @@ func (util *snowflakeAzureClient) nativeDownloadFile(
273273
Message: "failed to cast to azure client",
274274
}
275275
}
276-
containerClient, err := createContainerClient(client.URL())
276+
containerClient, err := createContainerClient(client.URL(), util.cfg)
277277
if err != nil {
278278
return &SnowflakeError{
279279
Message: "failed to create container client",
@@ -348,10 +348,10 @@ func (util *snowflakeAzureClient) detectAzureTokenExpireError(resp *http.Respons
348348
strings.Contains(errStr, "Server failed to authenticate the request")
349349
}
350350

351-
func createContainerClient(clientURL string) (*container.Client, error) {
351+
func createContainerClient(clientURL string, cfg *Config) (*container.Client, error) {
352352
return container.NewClientWithNoCredential(clientURL, &container.ClientOptions{ClientOptions: azcore.ClientOptions{
353353
Transport: &http.Client{
354-
Transport: SnowflakeTransport,
354+
Transport: getTransport(cfg),
355355
},
356356
}})
357357
}

client_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ func (t *DummyTransport) RoundTrip(r *http.Request) (*http.Response, error) {
2323
}
2424
return &http.Response{StatusCode: 200}, nil
2525
}
26-
return snowflakeInsecureTransport.RoundTrip(r)
26+
return snowflakeNoOcspTransport.RoundTrip(r)
2727
}
2828

2929
func TestInternalClient(t *testing.T) {

connection.go

+19-1
Original file line numberDiff line numberDiff line change
@@ -790,7 +790,7 @@ func buildSnowflakeConn(ctx context.Context, config Config) (*snowflakeConn, err
790790
if sc.cfg.Transporter == nil {
791791
if sc.cfg.DisableOCSPChecks || sc.cfg.InsecureMode {
792792
// no revocation check with OCSP. Think twice when you want to enable this option.
793-
st = snowflakeInsecureTransport
793+
st = snowflakeNoOcspTransport
794794
} else {
795795
// set OCSP fail open mode
796796
ocspResponseCacheLock.Lock()
@@ -856,3 +856,21 @@ func buildSnowflakeConn(ctx context.Context, config Config) (*snowflakeConn, err
856856

857857
return sc, nil
858858
}
859+
860+
func getTransport(cfg *Config) http.RoundTripper {
861+
if cfg == nil {
862+
logger.Debug("getTransport: got nil Config, will perform OCSP validation for cloud storage")
863+
return SnowflakeTransport
864+
}
865+
// if user configured a custom Transporter, prioritize that
866+
if cfg.Transporter != nil {
867+
logger.Debug("getTransport: using Transporter configured by the user")
868+
return cfg.Transporter
869+
}
870+
if cfg.DisableOCSPChecks || cfg.InsecureMode {
871+
logger.Debug("getTransport: skipping OCSP validation for cloud storage")
872+
return snowflakeNoOcspTransport
873+
}
874+
logger.Debug("getTransport: will perform OCSP validation for cloud storage")
875+
return SnowflakeTransport
876+
}

connection_test.go

+53
Original file line numberDiff line numberDiff line change
@@ -826,3 +826,56 @@ func TestBeginCreatesTransaction(t *testing.T) {
826826
}
827827
})
828828
}
829+
830+
type EmptyTransporter struct{}
831+
832+
func (t EmptyTransporter) RoundTrip(*http.Request) (*http.Response, error) {
833+
return nil, nil
834+
}
835+
836+
func TestGetTransport(t *testing.T) {
837+
testcases := []struct {
838+
name string
839+
cfg *Config
840+
transport http.RoundTripper
841+
}{
842+
{
843+
name: "DisableOCSPChecks and InsecureMode false",
844+
cfg: &Config{Account: "one", DisableOCSPChecks: false, InsecureMode: false},
845+
transport: SnowflakeTransport,
846+
},
847+
{
848+
name: "DisableOCSPChecks true and InsecureMode false",
849+
cfg: &Config{Account: "two", DisableOCSPChecks: true, InsecureMode: false},
850+
transport: snowflakeNoOcspTransport,
851+
},
852+
{
853+
name: "DisableOCSPChecks false and InsecureMode true",
854+
cfg: &Config{Account: "three", DisableOCSPChecks: false, InsecureMode: true},
855+
transport: snowflakeNoOcspTransport,
856+
},
857+
{
858+
name: "DisableOCSPChecks and InsecureMode missing from Config",
859+
cfg: &Config{Account: "four"},
860+
transport: SnowflakeTransport,
861+
},
862+
{
863+
name: "whole Config is missing",
864+
cfg: nil,
865+
transport: SnowflakeTransport,
866+
},
867+
{
868+
name: "Using custom Transporter",
869+
cfg: &Config{Account: "five", DisableOCSPChecks: true, InsecureMode: false, Transporter: EmptyTransporter{}},
870+
transport: EmptyTransporter{},
871+
},
872+
}
873+
for _, test := range testcases {
874+
t.Run(test.name, func(t *testing.T) {
875+
result := getTransport(test.cfg)
876+
if test.transport != result {
877+
t.Errorf("Failed to return the correct transport, input :%#v, expected: %v, got: %v", test.cfg, test.transport, result)
878+
}
879+
})
880+
}
881+
}

driver_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -1993,7 +1993,7 @@ type CountingTransport struct {
19931993

19941994
func (t *CountingTransport) RoundTrip(r *http.Request) (*http.Response, error) {
19951995
t.requests++
1996-
return snowflakeInsecureTransport.RoundTrip(r)
1996+
return snowflakeNoOcspTransport.RoundTrip(r)
19971997
}
19981998

19991999
func TestOpenWithTransport(t *testing.T) {

file_transfer_agent.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -597,15 +597,15 @@ type s3BucketAccelerateConfigGetter interface {
597597

598598
type s3ClientCreator interface {
599599
extractBucketNameAndPath(location string) (*s3Location, error)
600-
createClient(info *execResponseStageInfo, useAccelerateEndpoint bool) (cloudClient, error)
600+
createClientWithConfig(info *execResponseStageInfo, useAccelerateEndpoint bool, cfg *Config) (cloudClient, error)
601601
}
602602

603603
func (sfa *snowflakeFileTransferAgent) transferAccelerateConfigWithUtil(s3Util s3ClientCreator) error {
604604
s3Loc, err := s3Util.extractBucketNameAndPath(sfa.stageInfo.Location)
605605
if err != nil {
606606
return err
607607
}
608-
s3Cli, err := s3Util.createClient(sfa.stageInfo, false)
608+
s3Cli, err := s3Util.createClientWithConfig(sfa.stageInfo, false, sfa.sc.cfg)
609609
if err != nil {
610610
return err
611611
}

file_transfer_agent_test.go

+6-6
Original file line numberDiff line numberDiff line change
@@ -53,15 +53,15 @@ func TestGetBucketAccelerateConfiguration(t *testing.T) {
5353

5454
type s3ClientCreatorMock struct {
5555
extract func(string) (*s3Location, error)
56-
create func(info *execResponseStageInfo, useAccelerateEndpoint bool) (cloudClient, error)
56+
create func(info *execResponseStageInfo, useAccelerateEndpoint bool, cfg *Config) (cloudClient, error)
5757
}
5858

5959
func (mock *s3ClientCreatorMock) extractBucketNameAndPath(location string) (*s3Location, error) {
6060
return mock.extract(location)
6161
}
6262

63-
func (mock *s3ClientCreatorMock) createClient(info *execResponseStageInfo, useAccelerateEndpoint bool) (cloudClient, error) {
64-
return mock.create(info, useAccelerateEndpoint)
63+
func (mock *s3ClientCreatorMock) createClientWithConfig(info *execResponseStageInfo, useAccelerateEndpoint bool, cfg *Config) (cloudClient, error) {
64+
return mock.create(info, useAccelerateEndpoint, cfg)
6565
}
6666

6767
type s3BucketAccelerateConfigGetterMock struct {
@@ -96,7 +96,7 @@ func TestGetBucketAccelerateConfigurationTooManyRetries(t *testing.T) {
9696
extract: func(s string) (*s3Location, error) {
9797
return &s3Location{bucketName: "test", s3Path: "test"}, nil
9898
},
99-
create: func(info *execResponseStageInfo, useAccelerateEndpoint bool) (cloudClient, error) {
99+
create: func(info *execResponseStageInfo, useAccelerateEndpoint bool, cfg *Config) (cloudClient, error) {
100100
return &s3BucketAccelerateConfigGetterMock{err: errors.New("testing")}, nil
101101
},
102102
})
@@ -146,7 +146,7 @@ func TestGetBucketAccelerateConfigurationFailedCreateClient(t *testing.T) {
146146
extract: func(s string) (*s3Location, error) {
147147
return &s3Location{bucketName: "test", s3Path: "test"}, nil
148148
},
149-
create: func(info *execResponseStageInfo, useAccelerateEndpoint bool) (cloudClient, error) {
149+
create: func(info *execResponseStageInfo, useAccelerateEndpoint bool, cfg *Config) (cloudClient, error) {
150150
return nil, errors.New("failed creation")
151151
},
152152
})
@@ -172,7 +172,7 @@ func TestGetBucketAccelerateConfigurationInvalidClient(t *testing.T) {
172172
extract: func(s string) (*s3Location, error) {
173173
return &s3Location{bucketName: "test", s3Path: "test"}, nil
174174
},
175-
create: func(info *execResponseStageInfo, useAccelerateEndpoint bool) (cloudClient, error) {
175+
create: func(info *execResponseStageInfo, useAccelerateEndpoint bool, cfg *Config) (cloudClient, error) {
176176
return 1, nil
177177
},
178178
})

gcs_storage_client.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ func (util *snowflakeGcsClient) getFileHeader(meta *fileMetadata, filename strin
7373
for k, v := range gcsHeaders {
7474
req.Header.Add(k, v)
7575
}
76-
client := newGcsClient()
76+
client := newGcsClient(util.cfg)
7777
// for testing only
7878
if meta.mockGcsClient != nil {
7979
client = meta.mockGcsClient
@@ -226,7 +226,7 @@ func (util *snowflakeGcsClient) uploadFile(
226226
for k, v := range gcsHeaders {
227227
req.Header.Add(k, v)
228228
}
229-
client := newGcsClient()
229+
client := newGcsClient(util.cfg)
230230
// for testing only
231231
if meta.mockGcsClient != nil {
232232
client = meta.mockGcsClient
@@ -307,7 +307,7 @@ func (util *snowflakeGcsClient) nativeDownloadFile(
307307
for k, v := range gcsHeaders {
308308
req.Header.Add(k, v)
309309
}
310-
client := newGcsClient()
310+
client := newGcsClient(util.cfg)
311311
// for testing only
312312
if meta.mockGcsClient != nil {
313313
client = meta.mockGcsClient
@@ -409,9 +409,9 @@ func (util *snowflakeGcsClient) isTokenExpired(resp *http.Response) bool {
409409
return resp.StatusCode == 401
410410
}
411411

412-
func newGcsClient() gcsAPI {
412+
func newGcsClient(cfg *Config) gcsAPI {
413413
return &http.Client{
414-
Transport: SnowflakeTransport,
414+
Transport: getTransport(cfg),
415415
}
416416
}
417417

ocsp.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -637,7 +637,7 @@ func getRevocationStatus(ctx context.Context, subject, issuer *x509.Certificate)
637637
}
638638
ocspClient := &http.Client{
639639
Timeout: timeout,
640-
Transport: snowflakeInsecureTransport,
640+
Transport: snowflakeNoOcspTransport,
641641
}
642642
ocspRes, ocspResBytes, ocspS := retryOCSP(
643643
ctx, ocspClient, http.NewRequest, u, headers, ocspReq, issuer, timeout)
@@ -786,7 +786,7 @@ func downloadOCSPCacheServer() {
786786
}
787787
ocspClient := &http.Client{
788788
Timeout: timeout,
789-
Transport: snowflakeInsecureTransport,
789+
Transport: snowflakeNoOcspTransport,
790790
}
791791
ret, ocspStatus := checkOCSPCacheServer(context.Background(), ocspClient, http.NewRequest, u, timeout)
792792
if ocspStatus.code != ocspSuccess {
@@ -1075,8 +1075,8 @@ func init() {
10751075
initOCSPCache()
10761076
}
10771077

1078-
// snowflakeInsecureTransport is the transport object that doesn't do certificate revocation check.
1079-
var snowflakeInsecureTransport = &http.Transport{
1078+
// snowflakeNoOcspTransport is the transport object that doesn't do certificate revocation check with OCSP.
1079+
var snowflakeNoOcspTransport = &http.Transport{
10801080
MaxIdleConns: 10,
10811081
IdleConnTimeout: 30 * time.Minute,
10821082
Proxy: http.ProxyFromEnvironment,

ocsp_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ func TestOCSP(t *testing.T) {
3434
}
3535

3636
transports := []*http.Transport{
37-
snowflakeInsecureTransport,
37+
snowflakeNoOcspTransport,
3838
SnowflakeTransport,
3939
}
4040

s3_storage_client.go

+8-1
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,20 @@ func (util *snowflakeS3Client) createClient(info *execResponseStageInfo, useAcce
5959
BaseEndpoint: endPoint,
6060
UseAccelerate: useAccelerateEndpoint,
6161
HTTPClient: &http.Client{
62-
Transport: SnowflakeTransport,
62+
Transport: getTransport(util.cfg),
6363
},
6464
ClientLogMode: S3LoggingMode,
6565
Logger: s3Logger,
6666
}), nil
6767
}
6868

69+
// to be used with S3 transferAccelerateConfigWithUtil
70+
func (util *snowflakeS3Client) createClientWithConfig(info *execResponseStageInfo, useAccelerateEndpoint bool, cfg *Config) (cloudClient, error) {
71+
// copy snowflakeFileTransferAgent's config onto the cloud client so we could decide which Transport to use
72+
util.cfg = cfg
73+
return util.createClient(info, useAccelerateEndpoint)
74+
}
75+
6976
func getS3CustomEndpoint(info *execResponseStageInfo) *string {
7077
var endPoint *string
7178
isRegionalURLEnabled := info.UseRegionalURL || info.UseS3RegionalURL

0 commit comments

Comments
 (0)