diff --git a/conn.go b/conn.go index bc098360..c28ec12a 100644 --- a/conn.go +++ b/conn.go @@ -1113,20 +1113,24 @@ func (cn *conn) ssl(o values) error { return nil } - w := cn.writeBuf(0) - w.int32(80877103) - if err = cn.sendStartupPacket(w); err != nil { - return err - } + // only negotiate the ssl handshake if requested (which is the default). + // sllnegotiation=direct is supported by pg17 and above. + if sslnegotiation(o) { + w := cn.writeBuf(0) + w.int32(80877103) + if err = cn.sendStartupPacket(w); err != nil { + return err + } - b := cn.scratch[:1] - _, err = io.ReadFull(cn.c, b) - if err != nil { - return err - } + b := cn.scratch[:1] + _, err = io.ReadFull(cn.c, b) + if err != nil { + return err + } - if b[0] != 'S' { - return ErrSSLNotSupported + if b[0] != 'S' { + return ErrSSLNotSupported + } } cn.c, err = upgrade(cn.c) diff --git a/ssl.go b/ssl.go index 36b61ba4..5fd9bb73 100644 --- a/ssl.go +++ b/ssl.go @@ -202,3 +202,14 @@ func sslVerifyCertificateAuthority(client *tls.Conn, tlsConf *tls.Config) error _, err = certs[0].Verify(opts) return err } + +// sslnegotiation returns true if we should negotiate SSL. +// returns false if there should be no negotiation and we should upgrade immediately. +func sslnegotiation(o values) bool { + if negotiation, ok := o["sslnegotiation"]; ok { + if negotiation == "direct" { + return false + } + } + return true +} diff --git a/ssl_test.go b/ssl_test.go index 4c631b81..86c97cb0 100644 --- a/ssl_test.go +++ b/ssl_test.go @@ -308,30 +308,49 @@ func TestSNISupport(t *testing.T) { conn_param string hostname string expected_sni string + direct bool }{ { name: "SNI is set by default", conn_param: "", hostname: "localhost", expected_sni: "localhost", + direct: false, }, { name: "SNI is passed when asked for", conn_param: "sslsni=1", hostname: "localhost", expected_sni: "localhost", + direct: false, }, { name: "SNI is not passed when disabled", conn_param: "sslsni=0", hostname: "localhost", expected_sni: "", + direct: false, }, { name: "SNI is not set for IPv4", conn_param: "", hostname: "127.0.0.1", expected_sni: "", + direct: false, + }, + { + name: "SNI is set for negotiated ssl", + conn_param: "sslnegotiation=postgres", + hostname: "localhost", + expected_sni: "localhost", + direct: false, + }, + { + name: "SNI is set for direct ssl", + conn_param: "sslnegotiation=direct", + hostname: "localhost", + expected_sni: "localhost", + direct: true, }, } for _, tt := range tests { @@ -346,7 +365,7 @@ func TestSNISupport(t *testing.T) { } serverErrChan := make(chan error, 1) serverSNINameChan := make(chan string, 1) - go mockPostgresSSL(listener, serverErrChan, serverSNINameChan) + go mockPostgresSSL(listener, tt.direct, serverErrChan, serverSNINameChan) defer listener.Close() defer close(serverErrChan) @@ -381,7 +400,7 @@ func TestSNISupport(t *testing.T) { // // Accepts postgres StartupMessage and handles TLS clientHello, then closes a connection. // While reading clientHello catch passed SNI data and report it to nameChan. -func mockPostgresSSL(listener net.Listener, errChan chan error, nameChan chan string) { +func mockPostgresSSL(listener net.Listener, direct bool, errChan chan error, nameChan chan string) { var sniHost string conn, err := listener.Accept() @@ -397,23 +416,25 @@ func mockPostgresSSL(listener net.Listener, errChan chan error, nameChan chan st return } - // Receive StartupMessage with SSL Request - startupMessage := make([]byte, 8) - if _, err := io.ReadFull(conn, startupMessage); err != nil { - errChan <- err - return - } - // StartupMessage: first four bytes -- total len = 8, last four bytes SslRequestNumber - if !bytes.Equal(startupMessage, []byte{0, 0, 0, 0x8, 0x4, 0xd2, 0x16, 0x2f}) { - errChan <- fmt.Errorf("unexpected startup message: %#v", startupMessage) - return - } + if !direct { + // Receive StartupMessage with SSL Request + startupMessage := make([]byte, 8) + if _, err := io.ReadFull(conn, startupMessage); err != nil { + errChan <- err + return + } + // StartupMessage: first four bytes -- total len = 8, last four bytes SslRequestNumber + if !bytes.Equal(startupMessage, []byte{0, 0, 0, 0x8, 0x4, 0xd2, 0x16, 0x2f}) { + errChan <- fmt.Errorf("unexpected startup message: %#v", startupMessage) + return + } - // Respond with SSLOk - _, err = conn.Write([]byte("S")) - if err != nil { - errChan <- err - return + // Respond with SSLOk + _, err = conn.Write([]byte("S")) + if err != nil { + errChan <- err + return + } } // Set up TLS context to catch clientHello. It will always error out during handshake