Skip to content

Commit 7f760af

Browse files
Add parameter to disable SAML URL check (#1128)
Added the parameter "disableSamlURLCheck" to disable SAML URL check.
1 parent c355711 commit 7f760af

File tree

6 files changed

+135
-25
lines changed

6 files changed

+135
-25
lines changed

auth.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,8 @@ func createRequestBody(sc *snowflakeConn, sessionParameters map[string]interface
428428
sc.cfg.Application,
429429
sc.cfg.Account,
430430
sc.cfg.User,
431-
sc.cfg.Password)
431+
sc.cfg.Password,
432+
sc.cfg.DisableSamlURLCheck)
432433
if err != nil {
433434
return nil, err
434435
}

authokta.go

+16-13
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ func authenticateBySAML(
5959
account string,
6060
user string,
6161
password string,
62+
disableSamlURLCheck ConfigBool,
6263
) (samlResponse []byte, err error) {
6364
logger.WithContext(ctx).Info("step 1: query GS to obtain IDP token and SSO url")
6465
headers := make(map[string]string)
@@ -152,20 +153,22 @@ func authenticateBySAML(
152153
if err != nil {
153154
return nil, err
154155
}
155-
logger.WithContext(ctx).Info("step 5: validate post_back_url matches Snowflake URL")
156-
tgtURL, err := postBackURL(bd)
157-
if err != nil {
158-
return nil, err
159-
}
156+
if disableSamlURLCheck == ConfigBoolFalse {
157+
logger.WithContext(ctx).Info("step 5: validate post_back_url matches Snowflake URL")
158+
tgtURL, err := postBackURL(bd)
159+
if err != nil {
160+
return nil, err
161+
}
160162

161-
fullURL := sr.getURL()
162-
logger.WithContext(ctx).Infof("tgtURL: %v, origURL: %v", tgtURL, fullURL)
163-
if !isPrefixEqual(tgtURL, fullURL) {
164-
return nil, &SnowflakeError{
165-
Number: ErrCodeSSOURLNotMatch,
166-
SQLState: SQLStateConnectionRejected,
167-
Message: errMsgSSOURLNotMatch,
168-
MessageArgs: []interface{}{tgtURL, fullURL},
163+
fullURL := sr.getURL()
164+
logger.WithContext(ctx).Infof("tgtURL: %v, origURL: %v", tgtURL, fullURL)
165+
if !isPrefixEqual(tgtURL, fullURL) {
166+
return nil, &SnowflakeError{
167+
Number: ErrCodeSSOURLNotMatch,
168+
SQLState: SQLStateConnectionRejected,
169+
Message: errMsgSSOURLNotMatch,
170+
MessageArgs: []interface{}{tgtURL, fullURL},
171+
}
169172
}
170173
}
171174
return bd, nil

authokta_test.go

+42-11
Original file line numberDiff line numberDiff line change
@@ -233,64 +233,95 @@ func TestUnitAuthenticateBySAML(t *testing.T) {
233233
TokenAccessor: getSimpleTokenAccessor(),
234234
}
235235
var err error
236-
_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password)
236+
_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password, ConfigBoolFalse)
237237
assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.")
238238
assertEqualE(t, err.Error(), "failed to get SAML response")
239239

240240
sr.FuncPostAuthSAML = postAuthSAMLAuthFail
241-
_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password)
241+
_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password, ConfigBoolFalse)
242242
assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.")
243243
assertEqualE(t, err.Error(), "strconv.Atoi: parsing \"\": invalid syntax")
244244

245245
sr.FuncPostAuthSAML = postAuthSAMLAuthFailWithCode
246-
_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password)
246+
_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password, ConfigBoolFalse)
247247
assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.")
248248
driverErr, ok := err.(*SnowflakeError)
249249
assertTrueF(t, ok, "should be a SnowflakeError")
250250
assertEqualE(t, driverErr.Number, ErrCodeIdpConnectionError)
251251

252252
sr.FuncPostAuthSAML = postAuthSAMLAuthSuccessButInvalidURL
253-
_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password)
253+
_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password, ConfigBoolFalse)
254254
assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.")
255255
driverErr, ok = err.(*SnowflakeError)
256256
assertTrueF(t, ok, "should be a SnowflakeError")
257257
assertEqualE(t, driverErr.Number, ErrCodeIdpConnectionError)
258258

259259
sr.FuncPostAuthSAML = postAuthSAMLAuthSuccessButInvalidTokenURL
260-
_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password)
260+
_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password, ConfigBoolFalse)
261261
assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.")
262262
assertEqualE(t, err.Error(), "failed to parse token URL. invalid!@url$%^")
263263

264264
sr.FuncPostAuthSAML = postAuthSAMLAuthSuccessButInvalidSSOURL
265-
_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password)
265+
_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password, ConfigBoolFalse)
266266
assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.")
267267
assertEqualE(t, err.Error(), "failed to parse SSO URL. invalid!@url$%^")
268268

269269
sr.FuncPostAuthSAML = postAuthSAMLAuthSuccess
270270
sr.FuncPostAuthOKTA = postAuthOKTAError
271-
_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password)
271+
_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password, ConfigBoolFalse)
272272
assertNotNilF(t, err, "should have failed at FuncPostAuthOKTA.")
273273
assertEqualE(t, err.Error(), "failed to get SAML response")
274274

275275
sr.FuncPostAuthOKTA = postAuthOKTASuccess
276276
sr.FuncGetSSO = getSSOError
277-
_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password)
277+
_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password, ConfigBoolFalse)
278278
assertNotNilF(t, err, "should have failed at FuncGetSSO.")
279279
assertEqualE(t, err.Error(), "failed to get SSO html")
280280

281281
sr.FuncGetSSO = getSSOSuccessButInvalidURL
282-
_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password)
282+
_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password, ConfigBoolFalse)
283283
assertNotNilF(t, err, "should have failed at FuncGetSSO.")
284284
assertHasPrefixE(t, err.Error(), "failed to find action field in HTML response")
285285

286286
sr.FuncGetSSO = getSSOSuccess
287-
_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password)
287+
_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password, ConfigBoolFalse)
288288
assertNilF(t, err, "should have succeeded at FuncGetSSO.")
289289

290290
sr.FuncGetSSO = getSSOSuccessButWrongPrefixURL
291-
_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password)
291+
_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password, ConfigBoolFalse)
292292
assertNotNilF(t, err, "should have failed at FuncGetSSO.")
293293
driverErr, ok = err.(*SnowflakeError)
294294
assertTrueF(t, ok, "should be a SnowflakeError")
295295
assertEqualE(t, driverErr.Number, ErrCodeSSOURLNotMatch)
296296
}
297+
298+
func TestDisableSamlURLCheck(t *testing.T) {
299+
authenticator := &url.URL{
300+
Scheme: "https",
301+
Host: "abc.com",
302+
}
303+
application := "testapp"
304+
account := "testaccount"
305+
user := "u"
306+
password := "p"
307+
sr := &snowflakeRestful{
308+
Protocol: "https",
309+
Host: "abc.com",
310+
Port: 443,
311+
FuncPostAuthSAML: postAuthSAMLAuthSuccess,
312+
FuncPostAuthOKTA: postAuthOKTASuccess,
313+
FuncGetSSO: getSSOSuccessButWrongPrefixURL,
314+
TokenAccessor: getSimpleTokenAccessor(),
315+
}
316+
var err error
317+
// Test for disabled SAML URL check
318+
_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password, ConfigBoolTrue)
319+
assertNilF(t, err, "SAML URL check should have disabled.")
320+
321+
// Test for enabled SAML URL check
322+
_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password, ConfigBoolFalse)
323+
assertNotNilF(t, err, "should have failed at FuncGetSSO.")
324+
driverErr, ok := err.(*SnowflakeError)
325+
assertTrueF(t, ok, "should be a SnowflakeError")
326+
assertEqualE(t, driverErr.Number, ErrCodeSSOURLNotMatch)
327+
}

doc.go

+2
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,8 @@ The following connection parameters are supported:
130130
- clientConfigFile: specifies the location of the client configuration json file.
131131
In this file you can configure Easy Logging feature.
132132
133+
- disableSamlURLCheck: disables the SAML URL check. Default value is false.
134+
133135
All other parameters are interpreted as session parameters (https://docs.snowflake.com/en/sql-reference/parameters.html).
134136
For example, the TIMESTAMP_OUTPUT_FORMAT session parameter can be set by adding:
135137

dsn.go

+16
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@ type Config struct {
107107
ClientConfigFile string // File path to the client configuration json file
108108

109109
DisableConsoleLogin ConfigBool // Indicates whether console login should be disabled
110+
111+
DisableSamlURLCheck ConfigBool // Indicates whether the SAML URL check should be disabled
110112
}
111113

112114
// Validate enables testing if config is correct.
@@ -267,6 +269,9 @@ func DSN(cfg *Config) (dsn string, err error) {
267269
if cfg.DisableConsoleLogin != configBoolNotSet {
268270
params.Add("disableConsoleLogin", strconv.FormatBool(cfg.DisableConsoleLogin != ConfigBoolFalse))
269271
}
272+
if cfg.DisableSamlURLCheck != configBoolNotSet {
273+
params.Add("disableSamlURLCheck", strconv.FormatBool(cfg.DisableSamlURLCheck != ConfigBoolFalse))
274+
}
270275

271276
dsn = fmt.Sprintf("%v:%v@%v:%v", url.QueryEscape(cfg.User), url.QueryEscape(cfg.Password), cfg.Host, cfg.Port)
272277
if params.Encode() != "" {
@@ -770,6 +775,17 @@ func parseDSNParams(cfg *Config, params string) (err error) {
770775
} else {
771776
cfg.DisableConsoleLogin = ConfigBoolFalse
772777
}
778+
case "disableSamlURLCheck":
779+
var vv bool
780+
vv, err = strconv.ParseBool(value)
781+
if err != nil {
782+
return
783+
}
784+
if vv {
785+
cfg.DisableSamlURLCheck = ConfigBoolTrue
786+
} else {
787+
cfg.DisableSamlURLCheck = ConfigBoolFalse
788+
}
773789
default:
774790
if cfg.Params == nil {
775791
cfg.Params = make(map[string]*string)

dsn_test.go

+57
Original file line numberDiff line numberDiff line change
@@ -748,6 +748,40 @@ func TestParseDSN(t *testing.T) {
748748
ocspMode: ocspModeFailOpen,
749749
err: nil,
750750
},
751+
{
752+
dsn: "u:[email protected]:9876?account=a&protocol=http&authenticator=EXTERNALBROWSER&disableSamlURLCheck=true",
753+
config: &Config{
754+
Account: "a", User: "u", Password: "p",
755+
Authenticator: AuthTypeExternalBrowser,
756+
Protocol: "http", Host: "a.snowflake.local", Port: 9876,
757+
OCSPFailOpen: OCSPFailOpenTrue,
758+
ValidateDefaultParameters: ConfigBoolTrue,
759+
ClientTimeout: defaultClientTimeout,
760+
JWTClientTimeout: defaultJWTClientTimeout,
761+
ExternalBrowserTimeout: defaultExternalBrowserTimeout,
762+
IncludeRetryReason: ConfigBoolTrue,
763+
DisableSamlURLCheck: ConfigBoolTrue,
764+
},
765+
ocspMode: ocspModeFailOpen,
766+
err: nil,
767+
},
768+
{
769+
dsn: "u:[email protected]:9876?account=a&protocol=http&authenticator=EXTERNALBROWSER&disableSamlURLCheck=false",
770+
config: &Config{
771+
Account: "a", User: "u", Password: "p",
772+
Authenticator: AuthTypeExternalBrowser,
773+
Protocol: "http", Host: "a.snowflake.local", Port: 9876,
774+
OCSPFailOpen: OCSPFailOpenTrue,
775+
ValidateDefaultParameters: ConfigBoolTrue,
776+
ClientTimeout: defaultClientTimeout,
777+
JWTClientTimeout: defaultJWTClientTimeout,
778+
ExternalBrowserTimeout: defaultExternalBrowserTimeout,
779+
IncludeRetryReason: ConfigBoolTrue,
780+
DisableSamlURLCheck: ConfigBoolFalse,
781+
},
782+
ocspMode: ocspModeFailOpen,
783+
err: nil,
784+
},
751785
}
752786

753787
for _, at := range []AuthType{AuthTypeExternalBrowser, AuthTypeOAuth} {
@@ -910,6 +944,9 @@ func TestParseDSN(t *testing.T) {
910944
if test.config.DisableConsoleLogin != cfg.DisableConsoleLogin {
911945
t.Fatalf("%v: Failed to match DisableConsoleLogin. expected: %v, got: %v", i, test.config.DisableConsoleLogin, cfg.DisableConsoleLogin)
912946
}
947+
if test.config.DisableSamlURLCheck != cfg.DisableSamlURLCheck {
948+
t.Fatalf("%v: Failed to match DisableSamlURLCheck. expected: %v, got: %v", i, test.config.DisableSamlURLCheck, cfg.DisableSamlURLCheck)
949+
}
913950
assertEqualF(t, cfg.ClientConfigFile, test.config.ClientConfigFile, "client config file")
914951
case test.err != nil:
915952
driverErrE, okE := test.err.(*SnowflakeError)
@@ -1379,6 +1416,26 @@ func TestDSN(t *testing.T) {
13791416
},
13801417
dsn: "u:[email protected]:443?authenticator=externalbrowser&disableConsoleLogin=false&ocspFailOpen=true&region=b.c&validateDefaultParameters=true",
13811418
},
1419+
{
1420+
cfg: &Config{
1421+
User: "u",
1422+
Password: "p",
1423+
Account: "a.b.c",
1424+
Authenticator: AuthTypeExternalBrowser,
1425+
DisableSamlURLCheck: ConfigBoolTrue,
1426+
},
1427+
dsn: "u:[email protected]:443?authenticator=externalbrowser&disableSamlURLCheck=true&ocspFailOpen=true&region=b.c&validateDefaultParameters=true",
1428+
},
1429+
{
1430+
cfg: &Config{
1431+
User: "u",
1432+
Password: "p",
1433+
Account: "a.b.c",
1434+
Authenticator: AuthTypeExternalBrowser,
1435+
DisableSamlURLCheck: ConfigBoolFalse,
1436+
},
1437+
dsn: "u:[email protected]:443?authenticator=externalbrowser&disableSamlURLCheck=false&ocspFailOpen=true&region=b.c&validateDefaultParameters=true",
1438+
},
13821439
}
13831440
for _, test := range testcases {
13841441
t.Run(test.dsn, func(t *testing.T) {

0 commit comments

Comments
 (0)