Skip to content

Commit 1ef2f69

Browse files
feat: send first sasl message
1 parent 2302797 commit 1ef2f69

File tree

3 files changed

+67
-1
lines changed

3 files changed

+67
-1
lines changed

internal/protocol/auth.go

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,15 @@ package protocol
22

33
import (
44
"crypto/md5"
5+
"crypto/rand"
6+
"encoding/base64"
57
"encoding/binary"
68
"encoding/hex"
79
"fmt"
810
"postgres-protocol-go/internal/pool"
911
"postgres-protocol-go/internal/protocol/messages"
1012
"postgres-protocol-go/pkg/utils"
13+
"strings"
1114
)
1215

1316
func ProcessAuth(pgConnection PgConnection) error {
@@ -25,7 +28,6 @@ func ProcessAuth(pgConnection PgConnection) error {
2528

2629
authType := parseAuthType(answer)
2730

28-
// todo: implement SCRAM-SHA-256
2931
switch authType {
3032
case authenticationOk:
3133
if pgConnection.isVerbose() {
@@ -49,6 +51,49 @@ func ProcessAuth(pgConnection PgConnection) error {
4951
}
5052
}
5153
}
54+
case authenticationSASL:
55+
pgConnection.saslMethod = strings.Trim(string(answer[9:]), "\x00 \n\r")
56+
switch pgConnection.saslMethod {
57+
// https://datatracker.ietf.org/doc/html/rfc7677
58+
case "SCRAM-SHA-256":
59+
initialResponse, err := buildSCRAMInitialResponse(pgConnection.connConfig.User)
60+
if err != nil {
61+
return err
62+
}
63+
64+
initialResponseBytes := []byte(initialResponse)
65+
66+
lengthBytes := make([]byte, 4)
67+
binary.BigEndian.PutUint32(lengthBytes, uint32(len(initialResponseBytes)))
68+
fmt.Println(lengthBytes)
69+
70+
buff := pool.NewWriteBuffer(1024)
71+
buff.StartMessage(messages.SASLInitial)
72+
buff.WriteString(pgConnection.saslMethod)
73+
buff.WriteInt32(int32(len(initialResponseBytes)))
74+
_, err = buff.Write(initialResponseBytes)
75+
if err != nil {
76+
return err
77+
}
78+
buff.FinishMessage()
79+
80+
err = pgConnection.sendMessage(buff)
81+
if err != nil {
82+
return err
83+
}
84+
85+
return ProcessAuth(pgConnection)
86+
default:
87+
return fmt.Errorf("SASL authentication method %s is not supported", pgConnection.saslMethod)
88+
}
89+
case authenticationSASLContinue:
90+
fmt.Println("SASLContinue", pgConnection.saslMethod)
91+
switch pgConnection.saslMethod {
92+
case "SCRAM-SHA-256":
93+
return fmt.Errorf("not implemented")
94+
default:
95+
return fmt.Errorf("SASL authentication method %s is not supported", pgConnection.saslMethod)
96+
}
5297
case authenticationMD5Password:
5398
if pgConnection.connConfig.Password == nil {
5499
return fmt.Errorf("password is required for MD5 authentication")
@@ -119,3 +164,21 @@ func parseAuthType(message []byte) uint32 {
119164
func parseSalt(message []byte) string {
120165
return string(message[9:13])
121166
}
167+
168+
func buildSCRAMInitialResponse(username string) (string, error) {
169+
nonce, err := generateNonce()
170+
if err != nil {
171+
return "", fmt.Errorf("failed to generate nonce: %v", err)
172+
}
173+
// Format: "n,,n=<username>,r=<nonce>"
174+
initialResponse := fmt.Sprintf("n,,n=%s,r=%s", username, nonce)
175+
return initialResponse, nil
176+
}
177+
178+
func generateNonce() (string, error) {
179+
nonceBytes := make([]byte, 16) // 16 bytes = 128 bits of randomness
180+
if _, err := rand.Read(nonceBytes); err != nil {
181+
return "", err
182+
}
183+
return base64.StdEncoding.EncodeToString(nonceBytes), nil
184+
}

internal/protocol/messages/messages.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ const (
77
Startup = 0 // No identifier
88
SSL = 0 // No identifier
99
Auth = 'R'
10+
SASLInitial = 'p'
11+
SASLResponse = 'p'
1012
Password = 'p'
1113
Error = 'E'
1214
SimpleQuery = 'Q'

internal/protocol/pg_connection.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414

1515
type PgConnection struct {
1616
conn net.Conn
17+
saslMethod string
1718
connConfig models.ConnConfig
1819
driveConfig models.DriveConfig
1920
}

0 commit comments

Comments
 (0)