Skip to content

Commit c55b379

Browse files
authored
SNOW-1825476 Implement programmatic access token (PAT) (#1298)
1 parent 11bb91f commit c55b379

File tree

14 files changed

+358
-28
lines changed

14 files changed

+358
-28
lines changed

.github/workflows/build-test.yml

+3-5
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@ concurrency:
2626
jobs:
2727
lint:
2828
runs-on: ubuntu-latest
29-
strategy:
30-
fail-fast: false
3129
name: Check linter
3230
steps:
3331
- uses: actions/checkout@v4
@@ -54,7 +52,7 @@ jobs:
5452
- uses: actions/checkout@v4
5553
- uses: actions/setup-java@v4 # for wiremock
5654
with:
57-
java-version: 11
55+
java-version: 17
5856
distribution: 'temurin'
5957
- name: Setup go
6058
uses: actions/setup-go@v5
@@ -85,7 +83,7 @@ jobs:
8583
- uses: actions/checkout@v4
8684
- uses: actions/setup-java@v4 # for wiremock
8785
with:
88-
java-version: 11
86+
java-version: 17
8987
distribution: 'temurin'
9088
- name: Setup go
9189
uses: actions/setup-go@v5
@@ -115,7 +113,7 @@ jobs:
115113
- uses: actions/checkout@v4
116114
- uses: actions/setup-java@v4 # for wiremock
117115
with:
118-
java-version: 11
116+
java-version: 17
119117
distribution: 'temurin'
120118
- name: Setup go
121119
uses: actions/setup-go@v5

auth.go

+24
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"io"
1414
"net/http"
1515
"net/url"
16+
"os"
1617
"runtime"
1718
"strconv"
1819
"strings"
@@ -49,6 +50,8 @@ const (
4950
AuthTypeTokenAccessor
5051
// AuthTypeUsernamePasswordMFA is to use username and password with mfa
5152
AuthTypeUsernamePasswordMFA
53+
// AuthTypePat is to use programmatic access token
54+
AuthTypePat
5255
)
5356

5457
func determineAuthenticatorType(cfg *Config, value string) error {
@@ -72,6 +75,9 @@ func determineAuthenticatorType(cfg *Config, value string) error {
7275
} else if upperCaseValue == AuthTypeTokenAccessor.String() {
7376
cfg.Authenticator = AuthTypeTokenAccessor
7477
return nil
78+
} else if upperCaseValue == AuthTypePat.String() && experimentalAuthEnabled() {
79+
cfg.Authenticator = AuthTypePat
80+
return nil
7581
} else {
7682
// possibly Okta case
7783
oktaURLString, err := url.QueryUnescape(lowerCaseValue)
@@ -121,6 +127,8 @@ func (authType AuthType) String() string {
121127
return "TOKENACCESSOR"
122128
case AuthTypeUsernamePasswordMFA:
123129
return "USERNAME_PASSWORD_MFA"
130+
case AuthTypePat:
131+
return "PROGRAMMATIC_ACCESS_TOKEN"
124132
default:
125133
return "UNKNOWN"
126134
}
@@ -440,6 +448,17 @@ func createRequestBody(sc *snowflakeConn, sessionParameters map[string]interface
440448
return nil, err
441449
}
442450
requestMain.Token = jwtTokenString
451+
case AuthTypePat:
452+
if !experimentalAuthEnabled() {
453+
return nil, errors.New("programmatic access tokens are not ready to use")
454+
}
455+
logger.WithContext(sc.ctx).Info("Programmatic access token")
456+
requestMain.Authenticator = AuthTypePat.String()
457+
requestMain.LoginName = sc.cfg.User
458+
requestMain.Token = sc.cfg.Token
459+
if sc.cfg.Password != "" && sc.cfg.Token == "" {
460+
requestMain.Token = sc.cfg.Password
461+
}
443462
case AuthTypeSnowflake:
444463
logger.WithContext(sc.ctx).Info("Username and password")
445464
requestMain.LoginName = sc.cfg.User
@@ -574,3 +593,8 @@ func authenticateWithConfig(sc *snowflakeConn) error {
574593
sc.ctx = context.WithValue(sc.ctx, SFSessionIDKey, authData.SessionID)
575594
return nil
576595
}
596+
597+
func experimentalAuthEnabled() bool {
598+
val, ok := os.LookupEnv("ENABLE_EXPERIMENTAL_AUTHENTICATION")
599+
return ok && strings.EqualFold(val, "true")
600+
}

auth_test.go

+57
Original file line numberDiff line numberDiff line change
@@ -1003,3 +1003,60 @@ func TestContextPropagatedToAuthWhenUsingOpenDB(t *testing.T) {
10031003
assertStringContainsE(t, err.Error(), "context deadline exceeded")
10041004
cancel()
10051005
}
1006+
1007+
func TestPatSuccessfulFlow(t *testing.T) {
1008+
cfg := wiremock.connectionConfig()
1009+
cfg.Authenticator = AuthTypePat
1010+
cfg.Token = "some PAT"
1011+
testPatSuccessfulFlow(t, cfg)
1012+
}
1013+
1014+
func testPatSuccessfulFlow(t *testing.T, cfg *Config) {
1015+
skipOnJenkins(t, "wiremock is not enabled")
1016+
enableExperimentalAuth(t)
1017+
wiremock.registerMappings(t,
1018+
wiremockMapping{filePath: "auth/pat/successful_flow.json"},
1019+
wiremockMapping{filePath: "select1.json", params: map[string]string{
1020+
"%AUTHORIZATION_HEADER%": "Snowflake Token=\\\"session token\\\""},
1021+
},
1022+
)
1023+
connector := NewConnector(SnowflakeDriver{}, *cfg)
1024+
db := sql.OpenDB(connector)
1025+
rows, err := db.Query("SELECT 1")
1026+
assertNilF(t, err)
1027+
var v int
1028+
assertTrueE(t, rows.Next())
1029+
assertNilF(t, rows.Scan(&v))
1030+
assertEqualE(t, v, 1)
1031+
}
1032+
1033+
func enableExperimentalAuth(t *testing.T) {
1034+
err := os.Setenv("ENABLE_EXPERIMENTAL_AUTHENTICATION", "true")
1035+
assertNilF(t, err)
1036+
}
1037+
1038+
func TestPatSuccessfulFlowWithPatAsPasswordWithPatAuthenticator(t *testing.T) {
1039+
cfg := wiremock.connectionConfig()
1040+
cfg.Authenticator = AuthTypePat
1041+
cfg.Password = "some PAT"
1042+
testPatSuccessfulFlow(t, cfg)
1043+
}
1044+
1045+
func TestPatInvalidToken(t *testing.T) {
1046+
skipOnJenkins(t, "wiremock is not enabled")
1047+
enableExperimentalAuth(t)
1048+
wiremock.registerMappings(t,
1049+
wiremockMapping{filePath: "auth/pat/invalid_token.json"},
1050+
)
1051+
cfg := wiremock.connectionConfig()
1052+
cfg.Authenticator = AuthTypePat
1053+
cfg.Token = "some PAT"
1054+
connector := NewConnector(SnowflakeDriver{}, *cfg)
1055+
db := sql.OpenDB(connector)
1056+
_, err := db.Query("SELECT 1")
1057+
assertNotNilF(t, err)
1058+
var se *SnowflakeError
1059+
assertTrueF(t, errors.As(err, &se))
1060+
assertEqualE(t, se.Number, 394400)
1061+
assertEqualE(t, se.Message, "Programmatic access token is invalid.")
1062+
}
+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
pat
+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
include ../../gosnowflake.mak
2+
CMD_TARGET=pat
3+
4+
## Install
5+
install: cinstall
6+
7+
## Run
8+
run: crun
9+
10+
## Lint
11+
lint: clint
12+
13+
## Format source codes
14+
fmt: cfmt
15+
16+
.PHONY: install run lint fmt

cmd/programmatic_access_token/pat.go

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
// you have to configure PAT on your user
2+
3+
package main
4+
5+
import (
6+
"database/sql"
7+
"flag"
8+
"fmt"
9+
sf "github.com/snowflakedb/gosnowflake"
10+
"log"
11+
)
12+
13+
func main() {
14+
if !flag.Parsed() {
15+
flag.Parse()
16+
}
17+
18+
cfg, err := sf.GetConfigFromEnv([]*sf.ConfigParam{
19+
{Name: "Account", EnvName: "SNOWFLAKE_TEST_ACCOUNT", FailOnMissing: true},
20+
{Name: "User", EnvName: "SNOWFLAKE_TEST_USER", FailOnMissing: true},
21+
{Name: "Token", EnvName: "SNOWFLAKE_TEST_PAT", FailOnMissing: true},
22+
{Name: "Host", EnvName: "SNOWFLAKE_TEST_HOST", FailOnMissing: false},
23+
{Name: "Port", EnvName: "SNOWFLAKE_TEST_PORT", FailOnMissing: false},
24+
{Name: "Protocol", EnvName: "SNOWFLAKE_TEST_PROTOCOL", FailOnMissing: false},
25+
})
26+
cfg.Authenticator = sf.AuthTypePat
27+
if err != nil {
28+
log.Fatalf("cannot build config. %v", err)
29+
}
30+
31+
connector := sf.NewConnector(sf.SnowflakeDriver{}, *cfg)
32+
db := sql.OpenDB(connector)
33+
defer db.Close()
34+
35+
query := "SELECT 1"
36+
rows, err := db.Query(query)
37+
if err != nil {
38+
log.Fatalf("failed to run a query. %v, err: %v", query, err)
39+
}
40+
defer rows.Close()
41+
var v int
42+
if !rows.Next() {
43+
log.Fatalf("no rows returned")
44+
}
45+
if err = rows.Scan(&v); err != nil {
46+
log.Fatalf("failed to scan rows. %v", err)
47+
}
48+
if v != 1 {
49+
log.Fatalf("unexpected result, expected 1, got %v", v)
50+
}
51+
fmt.Printf("Congrats! You have successfully run %v with Snowflake DB!\n", query)
52+
}

driver.go

+3
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ func (d SnowflakeDriver) OpenWithConfig(ctx context.Context, config Config) (dri
4747
if err := config.Validate(); err != nil {
4848
return nil, err
4949
}
50+
if config.Params == nil {
51+
config.Params = make(map[string]*string)
52+
}
5053
if config.Tracing != "" {
5154
if err := logger.SetLogLevel(config.Tracing); err != nil {
5255
return nil, err

dsn.go

+17-3
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,11 @@ func fillMissingConfigParameters(cfg *Config) error {
460460
if authRequiresPassword(cfg) && strings.TrimSpace(cfg.Password) == "" {
461461
return errEmptyPassword()
462462
}
463+
464+
if authRequiresEitherPasswordOrToken(cfg) && strings.TrimSpace(cfg.Password) == "" && strings.TrimSpace(cfg.Token) == "" {
465+
return errEmptyPasswordAndToken()
466+
}
467+
463468
if strings.Trim(cfg.Protocol, " ") == "" {
464469
cfg.Protocol = "https"
465470
}
@@ -576,14 +581,20 @@ func buildHostFromAccountAndRegion(account, region string) string {
576581
func authRequiresUser(cfg *Config) bool {
577582
return cfg.Authenticator != AuthTypeOAuth &&
578583
cfg.Authenticator != AuthTypeTokenAccessor &&
579-
cfg.Authenticator != AuthTypeExternalBrowser
584+
cfg.Authenticator != AuthTypeExternalBrowser &&
585+
cfg.Authenticator != AuthTypePat
580586
}
581587

582588
func authRequiresPassword(cfg *Config) bool {
583589
return cfg.Authenticator != AuthTypeOAuth &&
584590
cfg.Authenticator != AuthTypeTokenAccessor &&
585591
cfg.Authenticator != AuthTypeExternalBrowser &&
586-
cfg.Authenticator != AuthTypeJwt
592+
cfg.Authenticator != AuthTypeJwt &&
593+
cfg.Authenticator != AuthTypePat
594+
}
595+
596+
func authRequiresEitherPasswordOrToken(cfg *Config) bool {
597+
return cfg.Authenticator == AuthTypePat
587598
}
588599

589600
// transformAccountToHost transforms account to host
@@ -905,7 +916,7 @@ type ConfigParam struct {
905916

906917
// GetConfigFromEnv is used to parse the environment variable values to specific fields of the Config
907918
func GetConfigFromEnv(properties []*ConfigParam) (*Config, error) {
908-
var account, user, password, role, host, portStr, protocol, warehouse, database, schema, region, passcode, application string
919+
var account, user, password, token, role, host, portStr, protocol, warehouse, database, schema, region, passcode, application string
909920
var privateKey *rsa.PrivateKey
910921
var err error
911922
if len(properties) == 0 || properties == nil {
@@ -923,6 +934,8 @@ func GetConfigFromEnv(properties []*ConfigParam) (*Config, error) {
923934
user = value
924935
case "Password":
925936
password = value
937+
case "Token":
938+
token = value
926939
case "Role":
927940
role = value
928941
case "Host":
@@ -963,6 +976,7 @@ func GetConfigFromEnv(properties []*ConfigParam) (*Config, error) {
963976
Account: account,
964977
User: user,
965978
Password: password,
979+
Token: token,
966980
Role: role,
967981
Host: host,
968982
Port: port,

0 commit comments

Comments
 (0)