Skip to content

Commit d7988bd

Browse files
feat: add SCRAM-SHA-256 support
1 parent 1ef2f69 commit d7988bd

File tree

3 files changed

+294
-110
lines changed

3 files changed

+294
-110
lines changed

internal/protocol/auth.go

Lines changed: 216 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -1,138 +1,265 @@
11
package protocol
22

33
import (
4+
"bytes"
5+
"crypto/hmac"
46
"crypto/md5"
57
"crypto/rand"
8+
"crypto/sha256"
69
"encoding/base64"
710
"encoding/binary"
811
"encoding/hex"
912
"fmt"
13+
pbkdf2 "postgres-protocol-go"
1014
"postgres-protocol-go/internal/pool"
1115
"postgres-protocol-go/internal/protocol/messages"
1216
"postgres-protocol-go/pkg/utils"
17+
"strconv"
1318
"strings"
1419
)
1520

1621
func ProcessAuth(pgConnection PgConnection) error {
17-
answer, err := pgConnection.readMessage()
22+
var (
23+
saslMethod string
24+
clientNonce string
25+
expectedServerSig []byte
26+
)
1827

19-
if err != nil {
20-
return err
21-
}
28+
for {
29+
answer, err := pgConnection.readMessage()
30+
if err != nil {
31+
return err
32+
}
2233

23-
identifier := utils.ParseIdentifier(answer)
34+
identifier := utils.ParseIdentifier(answer)
35+
if identifier != string(messages.Auth) {
36+
return fmt.Errorf("expected auth message, got %s", identifier)
37+
}
2438

25-
if identifier != string(messages.Auth) {
26-
return fmt.Errorf("expected auth message, got %s", identifier)
27-
}
39+
authType := parseAuthType(answer)
2840

29-
authType := parseAuthType(answer)
41+
switch authType {
42+
case authenticationOk:
43+
if pgConnection.isVerbose() {
44+
fmt.Println("Authentication successful")
45+
fmt.Println("Waiting for ReadyForQuery message")
46+
}
47+
return waitForReady(pgConnection)
48+
case authenticationSASL:
49+
saslMethod = strings.Trim(string(answer[9:]), "\x00 \n\r")
50+
switch saslMethod {
51+
case "SCRAM-SHA-256":
52+
nonce, initialResponse, err := buildSCRAMInitialResponse(pgConnection.connConfig.User)
53+
if err != nil {
54+
return err
55+
}
56+
clientNonce = nonce
3057

31-
switch authType {
32-
case authenticationOk:
33-
if pgConnection.isVerbose() {
34-
fmt.Println("Authentication successful")
35-
fmt.Println("Waiting for ReadyForQuery message")
36-
}
58+
buff := pool.NewWriteBuffer(1024)
59+
buff.StartMessage(messages.SASLInitial)
60+
buff.WriteString(saslMethod)
61+
buff.WriteInt32(int32(len(initialResponse)))
62+
buff.Write(initialResponse)
63+
buff.FinishMessage()
3764

38-
for {
39-
message, err := pgConnection.readMessage()
40-
if err != nil {
41-
return err
65+
if err := pgConnection.sendMessage(buff); err != nil {
66+
return err
67+
}
68+
default:
69+
return fmt.Errorf("SASL authentication method %s is not supported", saslMethod)
4270
}
71+
case authenticationSASLContinue:
72+
switch saslMethod {
73+
case "SCRAM-SHA-256":
74+
if clientNonce == "" {
75+
return fmt.Errorf("client nonce not available")
76+
}
77+
if pgConnection.connConfig.Password == nil {
78+
return fmt.Errorf("password is required for SCRAM authentication")
79+
}
4380

44-
// there are other useful messages that can be processed here like client_enconding, DateStyle, BackendKeyData, etc.
45-
switch utils.ParseIdentifier(message) {
46-
case string(messages.ReadyForQuery):
47-
return nil
48-
default:
49-
if pgConnection.isVerbose() {
50-
fmt.Printf("Auth: Unknown message: %s\n", string(message))
81+
serverMessage := string(answer[9:])
82+
parts := strings.Split(serverMessage, ",")
83+
var serverNonce, saltB64 string
84+
var iterations int
85+
for _, part := range parts {
86+
kv := strings.SplitN(part, "=", 2)
87+
if len(kv) != 2 {
88+
return fmt.Errorf("invalid part in SASLContinue message: %s", part)
89+
}
90+
key, value := kv[0], kv[1]
91+
switch key {
92+
case "r":
93+
serverNonce = value
94+
case "s":
95+
saltB64 = value
96+
case "i":
97+
i, err := strconv.Atoi(value)
98+
if err != nil {
99+
return fmt.Errorf("invalid iteration count: %v", err)
100+
}
101+
iterations = i
102+
default:
103+
return fmt.Errorf("unexpected key in SASLContinue message: %s", key)
104+
}
105+
}
106+
107+
if !strings.HasPrefix(serverNonce, clientNonce) {
108+
return fmt.Errorf("server nonce does not start with client nonce")
109+
}
110+
111+
salt, err := base64.StdEncoding.DecodeString(saltB64)
112+
if err != nil {
113+
return fmt.Errorf("failed to decode salt: %v", err)
114+
}
115+
116+
password := []byte(*pgConnection.connConfig.Password)
117+
saltedPassword := pbkdf2.Key(password, salt, iterations, 32, sha256.New)
118+
119+
clientKey := hmac.New(sha256.New, saltedPassword)
120+
clientKey.Write([]byte("Client Key"))
121+
clientKeyBytes := clientKey.Sum(nil)
122+
123+
storedKey := sha256.Sum256(clientKeyBytes)
124+
125+
clientFirstBare := fmt.Sprintf("n=%s,r=%s", pgConnection.connConfig.User, clientNonce)
126+
serverFirst := fmt.Sprintf("r=%s,s=%s,i=%d", serverNonce, saltB64, iterations)
127+
clientFinalWithoutProof := fmt.Sprintf("c=biws,r=%s", serverNonce)
128+
authMessage := strings.Join([]string{clientFirstBare, serverFirst, clientFinalWithoutProof}, ",")
129+
130+
clientSignature := hmac.New(sha256.New, storedKey[:])
131+
clientSignature.Write([]byte(authMessage))
132+
clientSignatureBytes := clientSignature.Sum(nil)
133+
134+
clientProof := make([]byte, len(clientKeyBytes))
135+
for i := 0; i < len(clientKeyBytes); i++ {
136+
clientProof[i] = clientKeyBytes[i] ^ clientSignatureBytes[i]
137+
}
138+
clientProofB64 := base64.StdEncoding.EncodeToString(clientProof)
139+
140+
serverKey := hmac.New(sha256.New, saltedPassword)
141+
serverKey.Write([]byte("Server Key"))
142+
serverKeyBytes := serverKey.Sum(nil)
143+
144+
expectedServerSigHasher := hmac.New(sha256.New, serverKeyBytes)
145+
expectedServerSigHasher.Write([]byte(authMessage))
146+
expectedServerSig = expectedServerSigHasher.Sum(nil)
147+
148+
clientFinalMessage := fmt.Sprintf("c=biws,r=%s,p=%s", serverNonce, clientProofB64)
149+
150+
buf := pool.NewWriteBuffer(1024)
151+
buf.StartMessage(messages.SASLResponse)
152+
buf.Write([]byte(clientFinalMessage))
153+
buf.FinishMessage()
154+
155+
if err := pgConnection.sendMessage(buf); err != nil {
156+
return err
51157
}
158+
default:
159+
return fmt.Errorf("SASL authentication method %s is not supported", saslMethod)
52160
}
53-
}
54-
case authenticationSASL:
55-
pgConnection.saslMethod = strings.Trim(string(answer[9:]), "\x00 \n\r")
56-
switch pgConnection.saslMethod {
57-
// https://datatracker.ietf.org/doc/html/rfc7677
58-
case "SCRAM-SHA-256":
59-
initialResponse, err := buildSCRAMInitialResponse(pgConnection.connConfig.User)
60-
if err != nil {
61-
return err
161+
case authenticationSASLFinal:
162+
switch saslMethod {
163+
case "SCRAM-SHA-256":
164+
serverMessage := string(answer[9:])
165+
166+
fmt.Println("server messsage:", serverMessage)
167+
var serverSigB64 string
168+
for _, part := range strings.Split(serverMessage, ",") {
169+
kv := strings.SplitN(part, "=", 2)
170+
if len(kv) != 2 {
171+
continue
172+
}
173+
if kv[0] == "v" {
174+
serverSigB64 = kv[1]
175+
break
176+
}
177+
}
178+
179+
if serverSigB64 == "" {
180+
return fmt.Errorf("missing server signature in SASLFinal message")
181+
}
182+
183+
serverSig, err := base64.StdEncoding.DecodeString(serverSigB64)
184+
if err != nil {
185+
return fmt.Errorf("failed to decode server signature: %v", err)
186+
}
187+
188+
if !bytes.Equal(serverSig, expectedServerSig) {
189+
return fmt.Errorf("server signature mismatch")
190+
}
191+
default:
192+
return fmt.Errorf("SASL authentication method %s is not supported", saslMethod)
62193
}
63194

64-
initialResponseBytes := []byte(initialResponse)
195+
case authenticationMD5Password:
196+
if pgConnection.connConfig.Password == nil {
197+
return fmt.Errorf("password is required for MD5 authentication")
198+
}
199+
salt := parseSalt(answer)
200+
hashedPassword := hashPasswordMD5(*pgConnection.connConfig.Password, pgConnection.connConfig.User, string(salt))
65201

66-
lengthBytes := make([]byte, 4)
67-
binary.BigEndian.PutUint32(lengthBytes, uint32(len(initialResponseBytes)))
68-
fmt.Println(lengthBytes)
202+
buf := pool.NewWriteBuffer(1024)
203+
buf.StartMessage(messages.Password)
204+
buf.WriteString(hashedPassword)
205+
buf.FinishMessage()
69206

70-
buff := pool.NewWriteBuffer(1024)
71-
buff.StartMessage(messages.SASLInitial)
72-
buff.WriteString(pgConnection.saslMethod)
73-
buff.WriteInt32(int32(len(initialResponseBytes)))
74-
_, err = buff.Write(initialResponseBytes)
75-
if err != nil {
207+
if err := pgConnection.sendMessage(buf); err != nil {
76208
return err
77209
}
78-
buff.FinishMessage()
79210

80-
err = pgConnection.sendMessage(buff)
81-
if err != nil {
82-
return err
211+
case authenticationCleartextPassword:
212+
if pgConnection.connConfig.Password == nil {
213+
return fmt.Errorf("password is required for cleartext authentication")
83214
}
84215

85-
return ProcessAuth(pgConnection)
86-
default:
87-
return fmt.Errorf("SASL authentication method %s is not supported", pgConnection.saslMethod)
88-
}
89-
case authenticationSASLContinue:
90-
fmt.Println("SASLContinue", pgConnection.saslMethod)
91-
switch pgConnection.saslMethod {
92-
case "SCRAM-SHA-256":
93-
return fmt.Errorf("not implemented")
216+
buf := pool.NewWriteBuffer(1024)
217+
buf.StartMessage(messages.Password)
218+
buf.WriteString(*pgConnection.connConfig.Password)
219+
buf.FinishMessage()
220+
221+
if err := pgConnection.sendMessage(buf); err != nil {
222+
return err
223+
}
94224
default:
95-
return fmt.Errorf("SASL authentication method %s is not supported", pgConnection.saslMethod)
96-
}
97-
case authenticationMD5Password:
98-
if pgConnection.connConfig.Password == nil {
99-
return fmt.Errorf("password is required for MD5 authentication")
225+
return fmt.Errorf("unsupported authentication method: %d", authType)
100226
}
227+
}
228+
}
101229

102-
salt := parseSalt(answer)
103-
hashedPassword := hashPasswordMD5(*pgConnection.connConfig.Password, pgConnection.connConfig.User, string(salt))
104-
105-
buf := pool.NewWriteBuffer(1024)
106-
buf.StartMessage(messages.Password)
107-
buf.WriteString(hashedPassword)
108-
buf.FinishMessage()
109-
110-
err := pgConnection.sendMessage(buf)
111-
230+
func waitForReady(pgConnection PgConnection) error {
231+
for {
232+
message, err := pgConnection.readMessage()
112233
if err != nil {
113234
return err
114235
}
115236

116-
return ProcessAuth(pgConnection)
117-
case authenticationCleartextPassword:
118-
if pgConnection.connConfig.Password == nil {
119-
return fmt.Errorf("password is required for cleartext authentication")
237+
switch utils.ParseIdentifier(message) {
238+
case string(messages.ReadyForQuery):
239+
return nil
240+
default:
241+
if pgConnection.isVerbose() {
242+
fmt.Printf("Auth: Unknown message: %s\n", string(message))
243+
}
120244
}
245+
}
246+
}
121247

122-
buf := pool.NewWriteBuffer(1024)
123-
buf.StartMessage(messages.Password)
124-
buf.WriteString(*pgConnection.connConfig.Password)
125-
buf.FinishMessage()
126-
127-
err := pgConnection.sendMessage(buf)
128-
if err != nil {
129-
return err
130-
}
248+
func buildSCRAMInitialResponse(username string) (string, []byte, error) {
249+
nonce, err := generateNonce()
250+
if err != nil {
251+
return "", nil, fmt.Errorf("failed to generate nonce: %v", err)
252+
}
253+
initialResponse := fmt.Sprintf("n,,n=%s,r=%s", username, nonce)
254+
return nonce, []byte(initialResponse), nil
255+
}
131256

132-
return ProcessAuth(pgConnection)
133-
default:
134-
return fmt.Errorf("unsupported authentication method: %d", authType)
257+
func generateNonce() (string, error) {
258+
nonceBytes := make([]byte, 16)
259+
if _, err := rand.Read(nonceBytes); err != nil {
260+
return "", err
135261
}
262+
return base64.StdEncoding.EncodeToString(nonceBytes), nil
136263
}
137264

138265
func hashPasswordMD5(password, username, salt string) string {
@@ -164,21 +291,3 @@ func parseAuthType(message []byte) uint32 {
164291
func parseSalt(message []byte) string {
165292
return string(message[9:13])
166293
}
167-
168-
func buildSCRAMInitialResponse(username string) (string, error) {
169-
nonce, err := generateNonce()
170-
if err != nil {
171-
return "", fmt.Errorf("failed to generate nonce: %v", err)
172-
}
173-
// Format: "n,,n=<username>,r=<nonce>"
174-
initialResponse := fmt.Sprintf("n,,n=%s,r=%s", username, nonce)
175-
return initialResponse, nil
176-
}
177-
178-
func generateNonce() (string, error) {
179-
nonceBytes := make([]byte, 16) // 16 bytes = 128 bits of randomness
180-
if _, err := rand.Read(nonceBytes); err != nil {
181-
return "", err
182-
}
183-
return base64.StdEncoding.EncodeToString(nonceBytes), nil
184-
}

internal/protocol/pg_connection.go

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ import (
1414

1515
type PgConnection struct {
1616
conn net.Conn
17-
saslMethod string
1817
connConfig models.ConnConfig
1918
driveConfig models.DriveConfig
2019
}
@@ -202,8 +201,6 @@ func parseConnStr(connUrl string) (models.ConnConfig, error) {
202201
continue
203202
}
204203
if strings.HasPrefix(s, "sslmode=") {
205-
fmt.Println("Secure connection")
206-
207204
sslmode := strings.SplitN(s, "=", 2)[1]
208205
if sslmode == "require" {
209206
connConfig.Secure = true

0 commit comments

Comments
 (0)