Skip to content

Commit 8257f91

Browse files
SNOW-1789753: Support GCS region specific endpoint (#1280)
1 parent 7d34091 commit 8257f91

5 files changed

+239
-18
lines changed

gcs_storage_client.go

+21-6
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ const (
2020
gcsMetadataMatdescKey = gcsMetadataPrefix + "matdesc"
2121
gcsMetadataEncryptionDataProp = gcsMetadataPrefix + "encryptiondata"
2222
gcsFileHeaderDigest = "gcs-file-header-digest"
23+
gcsRegionMeCentral2 = "me-central2"
2324
)
2425

2526
type snowflakeGcsClient struct {
@@ -52,7 +53,7 @@ func (util *snowflakeGcsClient) getFileHeader(meta *fileMetadata, filename strin
5253
if meta.presignedURL != nil {
5354
meta.resStatus = notFoundFile
5455
} else {
55-
URL, err := util.generateFileURL(meta.stageInfo.Location, strings.TrimLeft(filename, "/"))
56+
URL, err := util.generateFileURL(meta.stageInfo, strings.TrimLeft(filename, "/"))
5657
if err != nil {
5758
return nil, err
5859
}
@@ -147,7 +148,7 @@ func (util *snowflakeGcsClient) uploadFile(
147148
var err error
148149

149150
if uploadURL == nil {
150-
uploadURL, err = util.generateFileURL(meta.stageInfo.Location, strings.TrimLeft(meta.dstFileName, "/"))
151+
uploadURL, err = util.generateFileURL(meta.stageInfo, strings.TrimLeft(meta.dstFileName, "/"))
151152
if err != nil {
152153
return err
153154
}
@@ -279,7 +280,7 @@ func (util *snowflakeGcsClient) nativeDownloadFile(
279280
gcsHeaders := make(map[string]string)
280281

281282
if downloadURL == nil || downloadURL.String() == "" {
282-
downloadURL, err = util.generateFileURL(meta.stageInfo.Location, strings.TrimLeft(meta.srcFileName, "/"))
283+
downloadURL, err = util.generateFileURL(meta.stageInfo, strings.TrimLeft(meta.srcFileName, "/"))
283284
if err != nil {
284285
return err
285286
}
@@ -388,10 +389,11 @@ func (util *snowflakeGcsClient) extractBucketNameAndPath(location string) *gcsLo
388389
return &gcsLocation{containerName, path}
389390
}
390391

391-
func (util *snowflakeGcsClient) generateFileURL(stageLocation string, filename string) (*url.URL, error) {
392-
gcsLoc := util.extractBucketNameAndPath(stageLocation)
392+
func (util *snowflakeGcsClient) generateFileURL(stageInfo *execResponseStageInfo, filename string) (*url.URL, error) {
393+
gcsLoc := util.extractBucketNameAndPath(stageInfo.Location)
393394
fullFilePath := gcsLoc.path + filename
394-
URL, err := url.Parse("https://storage.googleapis.com/" + gcsLoc.bucketName + "/" + url.QueryEscape(fullFilePath))
395+
endPoint := getGcsCustomEndpoint(stageInfo)
396+
URL, err := url.Parse(endPoint + "/" + gcsLoc.bucketName + "/" + url.QueryEscape(fullFilePath))
395397
if err != nil {
396398
return nil, err
397399
}
@@ -407,3 +409,16 @@ func newGcsClient() gcsAPI {
407409
Transport: SnowflakeTransport,
408410
}
409411
}
412+
413+
func getGcsCustomEndpoint(info *execResponseStageInfo) string {
414+
endpoint := "https://storage.googleapis.com"
415+
416+
// TODO: SNOW-1789759 hardcoded region will be replaced in the future
417+
isRegionalURLEnabled := (strings.ToLower(info.Region) == gcsRegionMeCentral2) || info.UseRegionalURL
418+
if info.EndPoint != "" {
419+
endpoint = fmt.Sprintf("https://%s", info.EndPoint)
420+
} else if info.Region != "" && isRegionalURLEnabled {
421+
endpoint = fmt.Sprintf("https://storage.%s.rep.googleapis.com", strings.ToLower(info.Region))
422+
}
423+
return endpoint
424+
}

gcs_storage_client_test.go

+93-1
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,9 @@ func TestGenerateFileURL(t *testing.T) {
105105
}
106106
for _, test := range testcases {
107107
t.Run(test.location, func(t *testing.T) {
108-
gcsURL, err := gcsUtil.generateFileURL(test.location, test.fname)
108+
stageInfo := &execResponseStageInfo{}
109+
stageInfo.Location = test.location
110+
gcsURL, err := gcsUtil.generateFileURL(stageInfo, test.fname)
109111
if err != nil {
110112
t.Error(err)
111113
}
@@ -1126,3 +1128,93 @@ func Test_snowflakeGcsClient_nativeDownloadFile(t *testing.T) {
11261128
t.Error("should have raised an error")
11271129
}
11281130
}
1131+
1132+
func TestGetGcsCustomEndpoint(t *testing.T) {
1133+
testcases := []struct {
1134+
desc string
1135+
in execResponseStageInfo
1136+
out string
1137+
}{
1138+
{
1139+
desc: "when the endPoint is not specified and UseRegionalURL is false",
1140+
in: execResponseStageInfo{
1141+
UseRegionalURL: false,
1142+
EndPoint: "",
1143+
Region: "WEST-1",
1144+
},
1145+
out: "https://storage.googleapis.com",
1146+
},
1147+
{
1148+
desc: "when the useRegionalURL is only enabled",
1149+
in: execResponseStageInfo{
1150+
UseRegionalURL: true,
1151+
EndPoint: "",
1152+
Region: "mockLocation",
1153+
},
1154+
out: "https://storage.mocklocation.rep.googleapis.com",
1155+
},
1156+
{
1157+
desc: "when the region is me-central2",
1158+
in: execResponseStageInfo{
1159+
UseRegionalURL: false,
1160+
EndPoint: "",
1161+
Region: "me-central2",
1162+
},
1163+
out: "https://storage.me-central2.rep.googleapis.com",
1164+
},
1165+
{
1166+
desc: "when the region is me-central2 (mixed case)",
1167+
in: execResponseStageInfo{
1168+
UseRegionalURL: false,
1169+
EndPoint: "",
1170+
Region: "ME-cEntRal2",
1171+
},
1172+
out: "https://storage.me-central2.rep.googleapis.com",
1173+
},
1174+
{
1175+
desc: "when the region is me-central2 (uppercase)",
1176+
in: execResponseStageInfo{
1177+
UseRegionalURL: false,
1178+
EndPoint: "",
1179+
Region: "ME-CENTRAL2",
1180+
},
1181+
out: "https://storage.me-central2.rep.googleapis.com",
1182+
},
1183+
{
1184+
desc: "when the endPoint is specified",
1185+
in: execResponseStageInfo{
1186+
UseRegionalURL: false,
1187+
EndPoint: "storage.specialEndPoint.rep.googleapis.com",
1188+
Region: "ME-cEntRal1",
1189+
},
1190+
out: "https://storage.specialEndPoint.rep.googleapis.com",
1191+
},
1192+
{
1193+
desc: "when both the endPoint and the useRegionalUrl are specified",
1194+
in: execResponseStageInfo{
1195+
UseRegionalURL: true,
1196+
EndPoint: "storage.specialEndPoint.rep.googleapis.com",
1197+
Region: "ME-cEntRal1",
1198+
},
1199+
out: "https://storage.specialEndPoint.rep.googleapis.com",
1200+
},
1201+
{
1202+
desc: "when both the endPoint is specified and the region is me-central2",
1203+
in: execResponseStageInfo{
1204+
UseRegionalURL: true,
1205+
EndPoint: "storage.specialEndPoint.rep.googleapis.com",
1206+
Region: "ME-CENTRAL2",
1207+
},
1208+
out: "https://storage.specialEndPoint.rep.googleapis.com",
1209+
},
1210+
}
1211+
1212+
for _, test := range testcases {
1213+
t.Run(test.desc, func(t *testing.T) {
1214+
endpoint := getGcsCustomEndpoint(&test.in)
1215+
if endpoint != test.out {
1216+
t.Errorf("failed. in: %v, expected: %v, got: %v", test.in, test.out, endpoint)
1217+
}
1218+
})
1219+
}
1220+
}

query.go

+2
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,8 @@ type execResponseStageInfo struct {
108108
Creds execResponseCredentials `json:"creds,omitempty"`
109109
PresignedURL string `json:"presignedUrl,omitempty"`
110110
EndPoint string `json:"endPoint,omitempty"`
111+
UseS3RegionalURL bool `json:"useS3RegionalUrl,omitempty"`
112+
UseRegionalURL bool `json:"useRegionalUrl,omitempty"`
111113
}
112114

113115
// make all data field optional

s3_storage_client.go

+24-11
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,17 @@ import (
77
"context"
88
"errors"
99
"fmt"
10+
"io"
11+
"net/http"
12+
"os"
13+
"strings"
14+
1015
"github.com/aws/aws-sdk-go-v2/aws"
1116
"github.com/aws/aws-sdk-go-v2/credentials"
1217
"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
1318
"github.com/aws/aws-sdk-go-v2/service/s3"
1419
"github.com/aws/smithy-go"
1520
"github.com/aws/smithy-go/logging"
16-
"io"
17-
"net/http"
18-
"os"
19-
"strings"
2021
)
2122

2223
const (
@@ -47,20 +48,15 @@ var S3LoggingMode aws.ClientLogMode
4748
func (util *snowflakeS3Client) createClient(info *execResponseStageInfo, useAccelerateEndpoint bool) (cloudClient, error) {
4849
stageCredentials := info.Creds
4950
s3Logger := logging.LoggerFunc(s3LoggingFunc)
50-
51-
var endpoint *string
52-
if info.EndPoint != "" {
53-
tmp := "https://" + info.EndPoint
54-
endpoint = &tmp
55-
}
51+
endPoint := getS3CustomEndpoint(info)
5652

5753
return s3.New(s3.Options{
5854
Region: info.Region,
5955
Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(
6056
stageCredentials.AwsKeyID,
6157
stageCredentials.AwsSecretKey,
6258
stageCredentials.AwsToken)),
63-
BaseEndpoint: endpoint,
59+
BaseEndpoint: endPoint,
6460
UseAccelerate: useAccelerateEndpoint,
6561
HTTPClient: &http.Client{
6662
Transport: SnowflakeTransport,
@@ -70,6 +66,23 @@ func (util *snowflakeS3Client) createClient(info *execResponseStageInfo, useAcce
7066
}), nil
7167
}
7268

69+
func getS3CustomEndpoint(info *execResponseStageInfo) *string {
70+
var endPoint *string
71+
isRegionalURLEnabled := info.UseRegionalURL || info.UseS3RegionalURL
72+
if info.EndPoint != "" {
73+
tmp := fmt.Sprintf("https://%s", info.EndPoint)
74+
endPoint = &tmp
75+
} else if info.Region != "" && isRegionalURLEnabled {
76+
domainSuffixForRegionalURL := "amazonaws.com"
77+
if strings.HasPrefix(strings.ToLower(info.Region), "cn-") {
78+
domainSuffixForRegionalURL = "amazonaws.com.cn"
79+
}
80+
tmp := fmt.Sprintf("https://s3.%s.%s", info.Region, domainSuffixForRegionalURL)
81+
endPoint = &tmp
82+
}
83+
return endPoint
84+
}
85+
7386
func s3LoggingFunc(classification logging.Classification, format string, v ...interface{}) {
7487
switch classification {
7588
case logging.Debug:

s3_storage_client_test.go

+99
Original file line numberDiff line numberDiff line change
@@ -793,3 +793,102 @@ func TestConvertContentLength(t *testing.T) {
793793
})
794794
}
795795
}
796+
797+
func TestGetS3Endpoint(t *testing.T) {
798+
testcases := []struct {
799+
desc string
800+
in execResponseStageInfo
801+
out string
802+
}{
803+
804+
{
805+
desc: "when UseRegionalURL is valid and the region does not start with cn-",
806+
in: execResponseStageInfo{
807+
UseS3RegionalURL: false,
808+
UseRegionalURL: true,
809+
EndPoint: "",
810+
Region: "WEST-1",
811+
},
812+
out: "https://s3.WEST-1.amazonaws.com",
813+
},
814+
{
815+
desc: "when UseS3RegionalURL is valid and the region does not start with cn-",
816+
in: execResponseStageInfo{
817+
UseS3RegionalURL: true,
818+
UseRegionalURL: false,
819+
EndPoint: "",
820+
Region: "WEST-1",
821+
},
822+
out: "https://s3.WEST-1.amazonaws.com",
823+
},
824+
{
825+
desc: "when endPoint is enabled and the region does not start with cn-",
826+
in: execResponseStageInfo{
827+
UseS3RegionalURL: false,
828+
UseRegionalURL: false,
829+
EndPoint: "s3.endpoint",
830+
Region: "mockLocation",
831+
},
832+
out: "https://s3.endpoint",
833+
},
834+
{
835+
desc: "when endPoint is enabled and the region starts with cn-",
836+
in: execResponseStageInfo{
837+
UseS3RegionalURL: false,
838+
UseRegionalURL: false,
839+
EndPoint: "s3.endpoint",
840+
Region: "cn-mockLocation",
841+
},
842+
out: "https://s3.endpoint",
843+
},
844+
{
845+
desc: "when useS3RegionalURL is valid and domain starts with cn",
846+
in: execResponseStageInfo{
847+
UseS3RegionalURL: true,
848+
UseRegionalURL: false,
849+
EndPoint: "",
850+
Region: "cn-mockLocation",
851+
},
852+
out: "https://s3.cn-mockLocation.amazonaws.com.cn",
853+
},
854+
{
855+
desc: "when useRegionalURL is valid and domain starts with cn",
856+
in: execResponseStageInfo{
857+
UseS3RegionalURL: true,
858+
UseRegionalURL: false,
859+
EndPoint: "",
860+
Region: "cn-mockLocation",
861+
},
862+
out: "https://s3.cn-mockLocation.amazonaws.com.cn",
863+
},
864+
{
865+
desc: "when useRegionalURL is valid and domain starts with cn",
866+
in: execResponseStageInfo{
867+
UseS3RegionalURL: true,
868+
UseRegionalURL: false,
869+
EndPoint: "",
870+
Region: "cn-mockLocation",
871+
},
872+
out: "https://s3.cn-mockLocation.amazonaws.com.cn",
873+
},
874+
{
875+
desc: "when endPoint is specified, both UseRegionalURL and useS3PRegionalUrl are valid, and the region starts with cn",
876+
in: execResponseStageInfo{
877+
UseS3RegionalURL: true,
878+
UseRegionalURL: true,
879+
EndPoint: "s3.endpoint",
880+
Region: "cn-mockLocation",
881+
},
882+
out: "https://s3.endpoint",
883+
},
884+
}
885+
886+
for _, test := range testcases {
887+
t.Run(test.desc, func(t *testing.T) {
888+
endpoint := getS3CustomEndpoint(&test.in)
889+
if *endpoint != test.out {
890+
t.Errorf("failed. in: %v, expected: %v, got: %v", test.in, test.out, *endpoint)
891+
}
892+
})
893+
}
894+
}

0 commit comments

Comments
 (0)