|
1 | 1 | package adauth |
2 | 2 |
|
3 | 3 | import ( |
| 4 | + "bytes" |
4 | 5 | "context" |
| 6 | + "crypto/ecdsa" |
| 7 | + "crypto/ed25519" |
5 | 8 | "crypto/rsa" |
6 | 9 | "crypto/x509" |
7 | 10 | "fmt" |
@@ -73,7 +76,7 @@ func CredentialFromPFXBytes( |
73 | 76 | Domain: domain, |
74 | 77 | } |
75 | 78 |
|
76 | | - key, cert, caCerts, err := pkcs12.DecodeChain(pfxData, pfxPassword) |
| 79 | + key, cert, caCerts, err := DecodePFX(pfxData, pfxPassword) |
77 | 80 | if err != nil { |
78 | 81 | return nil, fmt.Errorf("decode PFX: %w", err) |
79 | 82 | } |
@@ -258,3 +261,76 @@ func splitUserIntoDomainAndUsername(user string) (domain string, username string |
258 | 261 | return "", user |
259 | 262 | } |
260 | 263 | } |
| 264 | + |
| 265 | +// DecodePFX loads the private key, certificate and certificate chain from PFX |
| 266 | +// bytes that may or may not be protected by a password. |
| 267 | +func DecodePFX(pfxData []byte, password string) (privateKey any, cert *x509.Certificate, chain []*x509.Certificate, err error) { |
| 268 | + // In some PFXs, especially those create by Microsoft tools, the cert and |
| 269 | + // chain order is reversed such that pkcs12.DecodeChain returns the CA cert |
| 270 | + // as "cert" and the leaf certificate in the chain (see |
| 271 | + // https://github.com/SSLMate/go-pkcs12/issues/54). Our strategy is that we |
| 272 | + // swap certifiates such that "cert" is the certificate that belongs to the |
| 273 | + // private key and "chain" contains all other certificates. |
| 274 | + privateKey, cert, chain, err = pkcs12.DecodeChain(pfxData, password) |
| 275 | + if err != nil || certMatchesKey(privateKey, cert) { |
| 276 | + return privateKey, cert, chain, err |
| 277 | + } |
| 278 | + |
| 279 | + for i := range chain { |
| 280 | + if !certMatchesKey(privateKey, chain[i]) { |
| 281 | + continue |
| 282 | + } |
| 283 | + |
| 284 | + newCert := chain[i] |
| 285 | + chain[i] = cert |
| 286 | + |
| 287 | + return privateKey, newCert, chain, nil |
| 288 | + } |
| 289 | + |
| 290 | + return privateKey, cert, chain, fmt.Errorf("private key does not match any of the %d certificates in PFX", len(chain)+1) |
| 291 | +} |
| 292 | + |
| 293 | +func certMatchesKey(key any, cert *x509.Certificate) bool { |
| 294 | + switch pub := cert.PublicKey.(type) { |
| 295 | + case *rsa.PublicKey: |
| 296 | + priv, ok := key.(*rsa.PrivateKey) |
| 297 | + if !ok { |
| 298 | + return false |
| 299 | + } |
| 300 | + |
| 301 | + if pub.N.Cmp(priv.N) != 0 { |
| 302 | + return false |
| 303 | + } |
| 304 | + |
| 305 | + return true |
| 306 | + case *ecdsa.PublicKey: |
| 307 | + priv, ok := key.(*ecdsa.PrivateKey) |
| 308 | + if !ok { |
| 309 | + return false |
| 310 | + } |
| 311 | + |
| 312 | + if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 { |
| 313 | + return false |
| 314 | + } |
| 315 | + |
| 316 | + return true |
| 317 | + case ed25519.PublicKey: |
| 318 | + priv, ok := key.(ed25519.PrivateKey) |
| 319 | + if !ok { |
| 320 | + return false |
| 321 | + } |
| 322 | + |
| 323 | + privPublicKey, ok := priv.Public().(ed25519.PublicKey) |
| 324 | + if !ok { |
| 325 | + return false |
| 326 | + } |
| 327 | + |
| 328 | + if !bytes.Equal(privPublicKey, pub) { |
| 329 | + return false |
| 330 | + } |
| 331 | + |
| 332 | + return true |
| 333 | + default: |
| 334 | + return false |
| 335 | + } |
| 336 | +} |
0 commit comments