@@ -2,8 +2,9 @@ package adauth
22
33import (
44 "context"
5- "crypto/rsa "
5+ "crypto/x509 "
66 "encoding/hex"
7+ "encoding/pem"
78 "fmt"
89 "io"
910 "net"
@@ -28,8 +29,15 @@ type Options struct {
2829 CCache string
2930 DomainController string
3031 ForceKerberos bool
31- PFXFileName string
32- PFXPassword string
32+
33+ // It is possible to specify a cert/key pair directly, as PEM files or as a
34+ // single PFX file.
35+ Certificate * x509.Certificate
36+ CertificateKey any
37+ PFXFileName string
38+ PFXPassword string
39+ PEMCertFileName string
40+ PEMKeyFileName string
3341
3442 credential * Credential
3543 flagset * pflag.FlagSet
@@ -273,25 +281,31 @@ func (opts *Options) preliminaryCredential() (*Credential, error) {
273281 Resolver : opts .Resolver ,
274282 }
275283
276- if opts .PFXFileName != "" {
277- pfxData , err := os .ReadFile (opts .PFXFileName )
284+ switch {
285+ case opts .Certificate != nil && opts .CertificateKey == nil :
286+ return nil , fmt .Errorf ("specify a key file for the client certificate" )
287+ case opts .Certificate != nil && opts .CertificateKey != nil :
288+ cred .ClientCert = opts .Certificate
289+ cred .ClientCertKey = opts .CertificateKey
290+ case opts .PFXFileName != "" :
291+ cert , key , caCerts , err := readPFX (opts .PFXFileName , opts .PFXPassword )
278292 if err != nil {
279- return nil , fmt . Errorf ( "read PFX: %w" , err )
293+ return nil , err
280294 }
281295
282- key , cert , caCerts , err := pkcs12 .DecodeChain (pfxData , opts .PFXPassword )
296+ cred .ClientCert = cert
297+ cred .ClientCertKey = key
298+ cred .CACerts = caCerts
299+ case opts .PEMCertFileName != "" && opts .PEMKeyFileName == "" :
300+ return nil , fmt .Errorf ("specify a key file for the client certificate" )
301+ case opts .PEMCertFileName != "" && opts .PEMKeyFileName != "" :
302+ cert , key , err := readPEMCertAndKey (opts .PEMCertFileName , opts .PEMKeyFileName )
283303 if err != nil {
284- return nil , fmt .Errorf ("decode PFX: %w" , err )
285- }
286-
287- rsaKey , ok := key .(* rsa.PrivateKey )
288- if ! ok {
289- return nil , fmt .Errorf ("PFX key is not an RSA private key but %T" , rsaKey )
304+ return nil , err
290305 }
291306
292307 cred .ClientCert = cert
293- cred .ClientCertKey = rsaKey
294- cred .CACerts = caCerts
308+ cred .ClientCertKey = key
295309 }
296310
297311 //nolint:nestif
@@ -313,6 +327,67 @@ func (opts *Options) preliminaryCredential() (*Credential, error) {
313327 return cred , nil
314328}
315329
330+ func readPFX (fileName string , password string ) (* x509.Certificate , any , []* x509.Certificate , error ) {
331+ pfxData , err := os .ReadFile (fileName )
332+ if err != nil {
333+ return nil , nil , nil , fmt .Errorf ("read PFX: %w" , err )
334+ }
335+
336+ key , cert , caCerts , err := pkcs12 .DecodeChain (pfxData , password )
337+ if err != nil {
338+ return nil , nil , nil , fmt .Errorf ("decode PFX: %w" , err )
339+ }
340+
341+ return cert , key , caCerts , nil
342+ }
343+
344+ func readPEMCertAndKey (certFileName string , certKeyFileName string ) (* x509.Certificate , any , error ) {
345+ certData , err := os .ReadFile (certFileName )
346+ if err != nil {
347+ return nil , nil , fmt .Errorf ("read cert file: %w" , err )
348+ }
349+
350+ block , _ := pem .Decode (certData )
351+ if block == nil {
352+ return nil , nil , fmt .Errorf ("could not PEM-decode certificate" )
353+ }
354+
355+ if block .Type != "" && ! strings .Contains (strings .ToLower (block .Type ), "certificate" ) {
356+ return nil , nil , fmt .Errorf ("unexpected block type for certificate: %q" , block .Type )
357+ }
358+
359+ cert , err := x509 .ParseCertificate (block .Bytes )
360+ if err != nil {
361+ return nil , nil , fmt .Errorf ("parse certificate: %w" , err )
362+ }
363+
364+ certKeyData , err := os .ReadFile (certKeyFileName )
365+ if err != nil {
366+ return nil , nil , fmt .Errorf ("read cert key file: %w" , err )
367+ }
368+
369+ block , _ = pem .Decode (certKeyData )
370+ if block == nil {
371+ return nil , nil , fmt .Errorf ("could not PEM-decode certificate key" )
372+ }
373+
374+ if block .Type != "" && ! strings .Contains (strings .ToLower (block .Type ), "key" ) {
375+ return nil , nil , fmt .Errorf ("unexpected block type for key: %q" , block .Type )
376+ }
377+
378+ key , err := x509 .ParsePKCS8PrivateKey (block .Bytes )
379+ if err != nil {
380+ key , pkcs1Err := x509 .ParsePKCS1PrivateKey (block .Bytes )
381+ if pkcs1Err == nil {
382+ return cert , key , nil
383+ }
384+
385+ return nil , nil , fmt .Errorf ("parse private key: %w" , err )
386+ }
387+
388+ return cert , key , nil
389+ }
390+
316391// NewDebugFunc creates a debug output handler.
317392func NewDebugFunc (enabled * bool , writer io.Writer , colored bool ) func (string , ... any ) {
318393 return func (format string , a ... any ) {
0 commit comments