|
1 | 1 | package protocol
|
2 | 2 |
|
3 | 3 | import (
|
| 4 | + "bytes" |
| 5 | + "crypto/hmac" |
4 | 6 | "crypto/md5"
|
5 | 7 | "crypto/rand"
|
| 8 | + "crypto/sha256" |
6 | 9 | "encoding/base64"
|
7 | 10 | "encoding/binary"
|
8 | 11 | "encoding/hex"
|
9 | 12 | "fmt"
|
| 13 | + pbkdf2 "postgres-protocol-go" |
10 | 14 | "postgres-protocol-go/internal/pool"
|
11 | 15 | "postgres-protocol-go/internal/protocol/messages"
|
12 | 16 | "postgres-protocol-go/pkg/utils"
|
| 17 | + "strconv" |
13 | 18 | "strings"
|
14 | 19 | )
|
15 | 20 |
|
16 | 21 | func ProcessAuth(pgConnection PgConnection) error {
|
17 |
| - answer, err := pgConnection.readMessage() |
| 22 | + var ( |
| 23 | + saslMethod string |
| 24 | + clientNonce string |
| 25 | + expectedServerSig []byte |
| 26 | + ) |
18 | 27 |
|
19 |
| - if err != nil { |
20 |
| - return err |
21 |
| - } |
| 28 | + for { |
| 29 | + answer, err := pgConnection.readMessage() |
| 30 | + if err != nil { |
| 31 | + return err |
| 32 | + } |
22 | 33 |
|
23 |
| - identifier := utils.ParseIdentifier(answer) |
| 34 | + identifier := utils.ParseIdentifier(answer) |
| 35 | + if identifier != string(messages.Auth) { |
| 36 | + return fmt.Errorf("expected auth message, got %s", identifier) |
| 37 | + } |
24 | 38 |
|
25 |
| - if identifier != string(messages.Auth) { |
26 |
| - return fmt.Errorf("expected auth message, got %s", identifier) |
27 |
| - } |
| 39 | + authType := parseAuthType(answer) |
28 | 40 |
|
29 |
| - authType := parseAuthType(answer) |
| 41 | + switch authType { |
| 42 | + case authenticationOk: |
| 43 | + if pgConnection.isVerbose() { |
| 44 | + fmt.Println("Authentication successful") |
| 45 | + fmt.Println("Waiting for ReadyForQuery message") |
| 46 | + } |
| 47 | + return waitForReady(pgConnection) |
| 48 | + case authenticationSASL: |
| 49 | + saslMethod = strings.Trim(string(answer[9:]), "\x00 \n\r") |
| 50 | + switch saslMethod { |
| 51 | + case "SCRAM-SHA-256": |
| 52 | + nonce, initialResponse, err := buildSCRAMInitialResponse(pgConnection.connConfig.User) |
| 53 | + if err != nil { |
| 54 | + return err |
| 55 | + } |
| 56 | + clientNonce = nonce |
30 | 57 |
|
31 |
| - switch authType { |
32 |
| - case authenticationOk: |
33 |
| - if pgConnection.isVerbose() { |
34 |
| - fmt.Println("Authentication successful") |
35 |
| - fmt.Println("Waiting for ReadyForQuery message") |
36 |
| - } |
| 58 | + buff := pool.NewWriteBuffer(1024) |
| 59 | + buff.StartMessage(messages.SASLInitial) |
| 60 | + buff.WriteString(saslMethod) |
| 61 | + buff.WriteInt32(int32(len(initialResponse))) |
| 62 | + buff.Write(initialResponse) |
| 63 | + buff.FinishMessage() |
37 | 64 |
|
38 |
| - for { |
39 |
| - message, err := pgConnection.readMessage() |
40 |
| - if err != nil { |
41 |
| - return err |
| 65 | + if err := pgConnection.sendMessage(buff); err != nil { |
| 66 | + return err |
| 67 | + } |
| 68 | + default: |
| 69 | + return fmt.Errorf("SASL authentication method %s is not supported", saslMethod) |
42 | 70 | }
|
| 71 | + case authenticationSASLContinue: |
| 72 | + switch saslMethod { |
| 73 | + case "SCRAM-SHA-256": |
| 74 | + if clientNonce == "" { |
| 75 | + return fmt.Errorf("client nonce not available") |
| 76 | + } |
| 77 | + if pgConnection.connConfig.Password == nil { |
| 78 | + return fmt.Errorf("password is required for SCRAM authentication") |
| 79 | + } |
43 | 80 |
|
44 |
| - // there are other useful messages that can be processed here like client_enconding, DateStyle, BackendKeyData, etc. |
45 |
| - switch utils.ParseIdentifier(message) { |
46 |
| - case string(messages.ReadyForQuery): |
47 |
| - return nil |
48 |
| - default: |
49 |
| - if pgConnection.isVerbose() { |
50 |
| - fmt.Printf("Auth: Unknown message: %s\n", string(message)) |
| 81 | + serverMessage := string(answer[9:]) |
| 82 | + parts := strings.Split(serverMessage, ",") |
| 83 | + var serverNonce, saltB64 string |
| 84 | + var iterations int |
| 85 | + for _, part := range parts { |
| 86 | + kv := strings.SplitN(part, "=", 2) |
| 87 | + if len(kv) != 2 { |
| 88 | + return fmt.Errorf("invalid part in SASLContinue message: %s", part) |
| 89 | + } |
| 90 | + key, value := kv[0], kv[1] |
| 91 | + switch key { |
| 92 | + case "r": |
| 93 | + serverNonce = value |
| 94 | + case "s": |
| 95 | + saltB64 = value |
| 96 | + case "i": |
| 97 | + i, err := strconv.Atoi(value) |
| 98 | + if err != nil { |
| 99 | + return fmt.Errorf("invalid iteration count: %v", err) |
| 100 | + } |
| 101 | + iterations = i |
| 102 | + default: |
| 103 | + return fmt.Errorf("unexpected key in SASLContinue message: %s", key) |
| 104 | + } |
| 105 | + } |
| 106 | + |
| 107 | + if !strings.HasPrefix(serverNonce, clientNonce) { |
| 108 | + return fmt.Errorf("server nonce does not start with client nonce") |
| 109 | + } |
| 110 | + |
| 111 | + salt, err := base64.StdEncoding.DecodeString(saltB64) |
| 112 | + if err != nil { |
| 113 | + return fmt.Errorf("failed to decode salt: %v", err) |
| 114 | + } |
| 115 | + |
| 116 | + password := []byte(*pgConnection.connConfig.Password) |
| 117 | + saltedPassword := pbkdf2.Key(password, salt, iterations, 32, sha256.New) |
| 118 | + |
| 119 | + clientKey := hmac.New(sha256.New, saltedPassword) |
| 120 | + clientKey.Write([]byte("Client Key")) |
| 121 | + clientKeyBytes := clientKey.Sum(nil) |
| 122 | + |
| 123 | + storedKey := sha256.Sum256(clientKeyBytes) |
| 124 | + |
| 125 | + clientFirstBare := fmt.Sprintf("n=%s,r=%s", pgConnection.connConfig.User, clientNonce) |
| 126 | + serverFirst := fmt.Sprintf("r=%s,s=%s,i=%d", serverNonce, saltB64, iterations) |
| 127 | + clientFinalWithoutProof := fmt.Sprintf("c=biws,r=%s", serverNonce) |
| 128 | + authMessage := strings.Join([]string{clientFirstBare, serverFirst, clientFinalWithoutProof}, ",") |
| 129 | + |
| 130 | + clientSignature := hmac.New(sha256.New, storedKey[:]) |
| 131 | + clientSignature.Write([]byte(authMessage)) |
| 132 | + clientSignatureBytes := clientSignature.Sum(nil) |
| 133 | + |
| 134 | + clientProof := make([]byte, len(clientKeyBytes)) |
| 135 | + for i := 0; i < len(clientKeyBytes); i++ { |
| 136 | + clientProof[i] = clientKeyBytes[i] ^ clientSignatureBytes[i] |
| 137 | + } |
| 138 | + clientProofB64 := base64.StdEncoding.EncodeToString(clientProof) |
| 139 | + |
| 140 | + serverKey := hmac.New(sha256.New, saltedPassword) |
| 141 | + serverKey.Write([]byte("Server Key")) |
| 142 | + serverKeyBytes := serverKey.Sum(nil) |
| 143 | + |
| 144 | + expectedServerSigHasher := hmac.New(sha256.New, serverKeyBytes) |
| 145 | + expectedServerSigHasher.Write([]byte(authMessage)) |
| 146 | + expectedServerSig = expectedServerSigHasher.Sum(nil) |
| 147 | + |
| 148 | + clientFinalMessage := fmt.Sprintf("c=biws,r=%s,p=%s", serverNonce, clientProofB64) |
| 149 | + |
| 150 | + buf := pool.NewWriteBuffer(1024) |
| 151 | + buf.StartMessage(messages.SASLResponse) |
| 152 | + buf.Write([]byte(clientFinalMessage)) |
| 153 | + buf.FinishMessage() |
| 154 | + |
| 155 | + if err := pgConnection.sendMessage(buf); err != nil { |
| 156 | + return err |
51 | 157 | }
|
| 158 | + default: |
| 159 | + return fmt.Errorf("SASL authentication method %s is not supported", saslMethod) |
52 | 160 | }
|
53 |
| - } |
54 |
| - case authenticationSASL: |
55 |
| - pgConnection.saslMethod = strings.Trim(string(answer[9:]), "\x00 \n\r") |
56 |
| - switch pgConnection.saslMethod { |
57 |
| - // https://datatracker.ietf.org/doc/html/rfc7677 |
58 |
| - case "SCRAM-SHA-256": |
59 |
| - initialResponse, err := buildSCRAMInitialResponse(pgConnection.connConfig.User) |
60 |
| - if err != nil { |
61 |
| - return err |
| 161 | + case authenticationSASLFinal: |
| 162 | + switch saslMethod { |
| 163 | + case "SCRAM-SHA-256": |
| 164 | + serverMessage := string(answer[9:]) |
| 165 | + |
| 166 | + fmt.Println("server messsage:", serverMessage) |
| 167 | + var serverSigB64 string |
| 168 | + for _, part := range strings.Split(serverMessage, ",") { |
| 169 | + kv := strings.SplitN(part, "=", 2) |
| 170 | + if len(kv) != 2 { |
| 171 | + continue |
| 172 | + } |
| 173 | + if kv[0] == "v" { |
| 174 | + serverSigB64 = kv[1] |
| 175 | + break |
| 176 | + } |
| 177 | + } |
| 178 | + |
| 179 | + if serverSigB64 == "" { |
| 180 | + return fmt.Errorf("missing server signature in SASLFinal message") |
| 181 | + } |
| 182 | + |
| 183 | + serverSig, err := base64.StdEncoding.DecodeString(serverSigB64) |
| 184 | + if err != nil { |
| 185 | + return fmt.Errorf("failed to decode server signature: %v", err) |
| 186 | + } |
| 187 | + |
| 188 | + if !bytes.Equal(serverSig, expectedServerSig) { |
| 189 | + return fmt.Errorf("server signature mismatch") |
| 190 | + } |
| 191 | + default: |
| 192 | + return fmt.Errorf("SASL authentication method %s is not supported", saslMethod) |
62 | 193 | }
|
63 | 194 |
|
64 |
| - initialResponseBytes := []byte(initialResponse) |
| 195 | + case authenticationMD5Password: |
| 196 | + if pgConnection.connConfig.Password == nil { |
| 197 | + return fmt.Errorf("password is required for MD5 authentication") |
| 198 | + } |
| 199 | + salt := parseSalt(answer) |
| 200 | + hashedPassword := hashPasswordMD5(*pgConnection.connConfig.Password, pgConnection.connConfig.User, string(salt)) |
65 | 201 |
|
66 |
| - lengthBytes := make([]byte, 4) |
67 |
| - binary.BigEndian.PutUint32(lengthBytes, uint32(len(initialResponseBytes))) |
68 |
| - fmt.Println(lengthBytes) |
| 202 | + buf := pool.NewWriteBuffer(1024) |
| 203 | + buf.StartMessage(messages.Password) |
| 204 | + buf.WriteString(hashedPassword) |
| 205 | + buf.FinishMessage() |
69 | 206 |
|
70 |
| - buff := pool.NewWriteBuffer(1024) |
71 |
| - buff.StartMessage(messages.SASLInitial) |
72 |
| - buff.WriteString(pgConnection.saslMethod) |
73 |
| - buff.WriteInt32(int32(len(initialResponseBytes))) |
74 |
| - _, err = buff.Write(initialResponseBytes) |
75 |
| - if err != nil { |
| 207 | + if err := pgConnection.sendMessage(buf); err != nil { |
76 | 208 | return err
|
77 | 209 | }
|
78 |
| - buff.FinishMessage() |
79 | 210 |
|
80 |
| - err = pgConnection.sendMessage(buff) |
81 |
| - if err != nil { |
82 |
| - return err |
| 211 | + case authenticationCleartextPassword: |
| 212 | + if pgConnection.connConfig.Password == nil { |
| 213 | + return fmt.Errorf("password is required for cleartext authentication") |
83 | 214 | }
|
84 | 215 |
|
85 |
| - return ProcessAuth(pgConnection) |
86 |
| - default: |
87 |
| - return fmt.Errorf("SASL authentication method %s is not supported", pgConnection.saslMethod) |
88 |
| - } |
89 |
| - case authenticationSASLContinue: |
90 |
| - fmt.Println("SASLContinue", pgConnection.saslMethod) |
91 |
| - switch pgConnection.saslMethod { |
92 |
| - case "SCRAM-SHA-256": |
93 |
| - return fmt.Errorf("not implemented") |
| 216 | + buf := pool.NewWriteBuffer(1024) |
| 217 | + buf.StartMessage(messages.Password) |
| 218 | + buf.WriteString(*pgConnection.connConfig.Password) |
| 219 | + buf.FinishMessage() |
| 220 | + |
| 221 | + if err := pgConnection.sendMessage(buf); err != nil { |
| 222 | + return err |
| 223 | + } |
94 | 224 | default:
|
95 |
| - return fmt.Errorf("SASL authentication method %s is not supported", pgConnection.saslMethod) |
96 |
| - } |
97 |
| - case authenticationMD5Password: |
98 |
| - if pgConnection.connConfig.Password == nil { |
99 |
| - return fmt.Errorf("password is required for MD5 authentication") |
| 225 | + return fmt.Errorf("unsupported authentication method: %d", authType) |
100 | 226 | }
|
| 227 | + } |
| 228 | +} |
101 | 229 |
|
102 |
| - salt := parseSalt(answer) |
103 |
| - hashedPassword := hashPasswordMD5(*pgConnection.connConfig.Password, pgConnection.connConfig.User, string(salt)) |
104 |
| - |
105 |
| - buf := pool.NewWriteBuffer(1024) |
106 |
| - buf.StartMessage(messages.Password) |
107 |
| - buf.WriteString(hashedPassword) |
108 |
| - buf.FinishMessage() |
109 |
| - |
110 |
| - err := pgConnection.sendMessage(buf) |
111 |
| - |
| 230 | +func waitForReady(pgConnection PgConnection) error { |
| 231 | + for { |
| 232 | + message, err := pgConnection.readMessage() |
112 | 233 | if err != nil {
|
113 | 234 | return err
|
114 | 235 | }
|
115 | 236 |
|
116 |
| - return ProcessAuth(pgConnection) |
117 |
| - case authenticationCleartextPassword: |
118 |
| - if pgConnection.connConfig.Password == nil { |
119 |
| - return fmt.Errorf("password is required for cleartext authentication") |
| 237 | + switch utils.ParseIdentifier(message) { |
| 238 | + case string(messages.ReadyForQuery): |
| 239 | + return nil |
| 240 | + default: |
| 241 | + if pgConnection.isVerbose() { |
| 242 | + fmt.Printf("Auth: Unknown message: %s\n", string(message)) |
| 243 | + } |
120 | 244 | }
|
| 245 | + } |
| 246 | +} |
121 | 247 |
|
122 |
| - buf := pool.NewWriteBuffer(1024) |
123 |
| - buf.StartMessage(messages.Password) |
124 |
| - buf.WriteString(*pgConnection.connConfig.Password) |
125 |
| - buf.FinishMessage() |
126 |
| - |
127 |
| - err := pgConnection.sendMessage(buf) |
128 |
| - if err != nil { |
129 |
| - return err |
130 |
| - } |
| 248 | +func buildSCRAMInitialResponse(username string) (string, []byte, error) { |
| 249 | + nonce, err := generateNonce() |
| 250 | + if err != nil { |
| 251 | + return "", nil, fmt.Errorf("failed to generate nonce: %v", err) |
| 252 | + } |
| 253 | + initialResponse := fmt.Sprintf("n,,n=%s,r=%s", username, nonce) |
| 254 | + return nonce, []byte(initialResponse), nil |
| 255 | +} |
131 | 256 |
|
132 |
| - return ProcessAuth(pgConnection) |
133 |
| - default: |
134 |
| - return fmt.Errorf("unsupported authentication method: %d", authType) |
| 257 | +func generateNonce() (string, error) { |
| 258 | + nonceBytes := make([]byte, 16) |
| 259 | + if _, err := rand.Read(nonceBytes); err != nil { |
| 260 | + return "", err |
135 | 261 | }
|
| 262 | + return base64.StdEncoding.EncodeToString(nonceBytes), nil |
136 | 263 | }
|
137 | 264 |
|
138 | 265 | func hashPasswordMD5(password, username, salt string) string {
|
@@ -164,21 +291,3 @@ func parseAuthType(message []byte) uint32 {
|
164 | 291 | func parseSalt(message []byte) string {
|
165 | 292 | return string(message[9:13])
|
166 | 293 | }
|
167 |
| - |
168 |
| -func buildSCRAMInitialResponse(username string) (string, error) { |
169 |
| - nonce, err := generateNonce() |
170 |
| - if err != nil { |
171 |
| - return "", fmt.Errorf("failed to generate nonce: %v", err) |
172 |
| - } |
173 |
| - // Format: "n,,n=<username>,r=<nonce>" |
174 |
| - initialResponse := fmt.Sprintf("n,,n=%s,r=%s", username, nonce) |
175 |
| - return initialResponse, nil |
176 |
| -} |
177 |
| - |
178 |
| -func generateNonce() (string, error) { |
179 |
| - nonceBytes := make([]byte, 16) // 16 bytes = 128 bits of randomness |
180 |
| - if _, err := rand.Read(nonceBytes); err != nil { |
181 |
| - return "", err |
182 |
| - } |
183 |
| - return base64.StdEncoding.EncodeToString(nonceBytes), nil |
184 |
| -} |
|
0 commit comments