Skip to content

Commit f89563f

Browse files
authored
feat: enable TLS cert verification extensions (#425)
1 parent 68fcec7 commit f89563f

File tree

2 files changed

+164
-29
lines changed

2 files changed

+164
-29
lines changed

transport/tls/stream_dialer.go

Lines changed: 78 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -84,16 +84,25 @@ func normalizeHost(host string) string {
8484
return strings.ToLower(host)
8585
}
8686

87-
// ClientConfig encodes the parameters for a TLS client connection.
87+
// ClientConfig holds configuration parameters used for establishing a TLS client connection.
8888
type ClientConfig struct {
89-
// The host name for the Server Name Indication (SNI).
89+
// ServerName specifies the hostname sent for Server Name Indication (SNI).
90+
// This is often the same as the dialed hostname but can be overridden using [WithSNI].
9091
ServerName string
91-
// The hostname to use for certificate validation.
92-
CertificateName string
93-
// The protocol id list for protocol negotiation (ALPN).
92+
93+
// NextProtos lists the application-layer protocols (e.g., "h2", "http/1.1")
94+
// supported by the client for Application-Layer Protocol Negotiation (ALPN).
95+
// See [WithALPN].
9496
NextProtos []string
95-
// The cache for sessin resumption.
97+
98+
// SessionCache enables TLS session resumption by providing a cache for session tickets.
99+
// If nil, session resumption is disabled. See [WithSessionCache].
96100
SessionCache tls.ClientSessionCache
101+
102+
// CertVerifier specifies a custom verifier for the peer's certificate chain.
103+
// If nil, [StandardCertVerifier] is used by default, validating against the dialed
104+
// server name. See [WithCertVerifier].
105+
CertVerifier CertVerifier
97106
}
98107

99108
// toStdConfig creates a [tls.Config] based on the configured parameters.
@@ -106,33 +115,25 @@ func (cfg *ClientConfig) toStdConfig() *tls.Config {
106115
// replacing. This will not disable VerifyConnection.
107116
InsecureSkipVerify: true,
108117
VerifyConnection: func(cs tls.ConnectionState) error {
109-
// This replicates the logic in the standard library verification:
110-
// https://cs.opensource.google/go/go/+/master:src/crypto/tls/handshake_client.go;l=982;drc=b5f87b5407916c4049a3158cc944cebfd7a883a9
111-
// And the documentation example:
112-
// https://pkg.go.dev/crypto/tls#example-Config-VerifyConnection
113-
opts := x509.VerifyOptions{
114-
DNSName: cfg.CertificateName,
115-
Intermediates: x509.NewCertPool(),
116-
}
117-
for _, cert := range cs.PeerCertificates[1:] {
118-
opts.Intermediates.AddCert(cert)
119-
}
120-
_, err := cs.PeerCertificates[0].Verify(opts)
121-
return err
118+
return cfg.CertVerifier.VerifyCertificate(&CertVerificationContext{
119+
PeerCertificates: cs.PeerCertificates,
120+
})
122121
},
123122
}
124123
}
125124

126-
// ClientOption allows configuring the parameters to be used for a client TLS connection.
127-
type ClientOption func(serverName string, config *ClientConfig)
128-
129125
// WrapConn wraps a [transport.StreamConn] in a TLS connection.
130126
func WrapConn(ctx context.Context, conn transport.StreamConn, serverName string, options ...ClientOption) (transport.StreamConn, error) {
131-
cfg := ClientConfig{ServerName: serverName, CertificateName: serverName}
127+
cfg := ClientConfig{ServerName: serverName}
132128
normName := normalizeHost(serverName)
133129
for _, option := range options {
134130
option(normName, &cfg)
135131
}
132+
if cfg.CertVerifier == nil {
133+
// If CertVerifier is not provided, use the default verification logic,
134+
// which validates the peer certificate against the provided serverName.
135+
cfg.CertVerifier = &StandardCertVerifier{CertificateName: serverName}
136+
}
136137
tlsConn := tls.Client(conn, cfg.toStdConfig())
137138
err := tlsConn.HandshakeContext(ctx)
138139
if err != nil {
@@ -181,10 +182,60 @@ func WithSessionCache(sessionCache tls.ClientSessionCache) ClientOption {
181182
}
182183
}
183184

184-
// WithCertificateName sets the hostname to be used for the certificate cerification.
185-
// If absent, defaults to the dialed hostname.
186-
func WithCertificateName(hostname string) ClientOption {
185+
// WithCertVerifier sets the verifier to be used for the certificate verification.
186+
func WithCertVerifier(verifier CertVerifier) ClientOption {
187187
return func(_ string, config *ClientConfig) {
188-
config.CertificateName = hostname
188+
config.CertVerifier = verifier
189189
}
190190
}
191+
192+
// CertVerificationContext provides connection-time context for the certificate verification.
193+
type CertVerificationContext struct {
194+
// PeerCertificates are the parsed certificates sent by the peer, in the
195+
// order in which they were sent. The first element is the leaf certificate
196+
// that the connection is verified against.
197+
//
198+
// On the client side, it can't be empty. On the server side, it can be
199+
// empty if Config.ClientAuth is not RequireAnyClientCert or
200+
// RequireAndVerifyClientCert.
201+
//
202+
// PeerCertificates and its contents should not be modified.
203+
PeerCertificates []*x509.Certificate
204+
}
205+
206+
// CertVerifier verifies peer certificates for TLS connections.
207+
type CertVerifier interface {
208+
// VerifyCertificate verified a peer certificate given the context.
209+
VerifyCertificate(info *CertVerificationContext) error
210+
}
211+
212+
// StandardCertVerifier implements [CertVerifier] using standard TLS certificate chain verification.
213+
type StandardCertVerifier struct {
214+
// CertificateName specifies the expected DNS name (or IP address) against which
215+
// the peer's leaf certificate is verified.
216+
CertificateName string
217+
// Roots contains the set of trusted root certificate authorities.
218+
// If nil, the host's default root CAs are used for certificate chain validation.
219+
Roots *x509.CertPool
220+
}
221+
222+
// VerifyCertificate implements [CertVerifier].
223+
func (v *StandardCertVerifier) VerifyCertificate(certContext *CertVerificationContext) error {
224+
// This replicates the logic in the standard library verification:
225+
// https://cs.opensource.google/go/go/+/master:src/crypto/tls/handshake_client.go;l=982;drc=b5f87b5407916c4049a3158cc944cebfd7a883a9
226+
// And the documentation example:
227+
// https://pkg.go.dev/crypto/tls#example-Config-VerifyConnection
228+
opts := x509.VerifyOptions{
229+
DNSName: v.CertificateName,
230+
Roots: v.Roots,
231+
Intermediates: x509.NewCertPool(),
232+
}
233+
for _, cert := range certContext.PeerCertificates[1:] {
234+
opts.Intermediates.AddCert(cert)
235+
}
236+
_, err := certContext.PeerCertificates[0].Verify(opts)
237+
return err
238+
}
239+
240+
// ClientOption allows configuring the parameters to be used for a client TLS connection.
241+
type ClientOption func(serverName string, config *ClientConfig)

transport/tls/stream_dialer_test.go

Lines changed: 86 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,15 @@ package tls
1616

1717
import (
1818
"context"
19+
"crypto/ecdsa"
20+
"crypto/elliptic"
21+
"crypto/rand"
1922
"crypto/x509"
23+
"crypto/x509/pkix"
24+
"math/big"
25+
"net"
2026
"testing"
27+
"time"
2128

2229
"github.com/Jigsaw-Code/outline-sdk/transport"
2330
"github.com/stretchr/testify/require"
@@ -77,7 +84,7 @@ func TestIP(t *testing.T) {
7784
}
7885

7986
func TestIPOverride(t *testing.T) {
80-
sd, err := NewStreamDialer(&transport.TCPDialer{}, WithCertificateName("8.8.8.8"))
87+
sd, err := NewStreamDialer(&transport.TCPDialer{}, WithCertVerifier(&StandardCertVerifier{CertificateName: "8.8.8.8"}))
8188
require.NoError(t, err)
8289
conn, err := sd.DialStream(context.Background(), "dns.google:443")
8390
require.NoError(t, err)
@@ -101,7 +108,7 @@ func TestNoSNI(t *testing.T) {
101108
}
102109

103110
func TestAllCustom(t *testing.T) {
104-
sd, err := NewStreamDialer(&transport.TCPDialer{}, WithSNI("decoy.android.com"), WithCertificateName("www.youtube.com"))
111+
sd, err := NewStreamDialer(&transport.TCPDialer{}, WithSNI("decoy.android.com"), WithCertVerifier(&StandardCertVerifier{CertificateName: "www.youtube.com"}))
105112
require.NoError(t, err)
106113
conn, err := sd.DialStream(context.Background(), "www.google.com:443")
107114
require.NoError(t, err)
@@ -176,3 +183,80 @@ func (c countedStreamConn) Close() error {
176183
c.counter.activeConns--
177184
return c.StreamConn.Close()
178185
}
186+
187+
// Helper function to create a self-signed certificate (Root CA)
188+
func createRootCA(t *testing.T) (*x509.Certificate, *ecdsa.PrivateKey) {
189+
t.Helper()
190+
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
191+
require.NoError(t, err)
192+
193+
template := x509.Certificate{
194+
SerialNumber: big.NewInt(1),
195+
Subject: pkix.Name{Organization: []string{"Test Root CA"}},
196+
NotBefore: time.Now().Add(-1 * time.Hour),
197+
NotAfter: time.Now().Add(24 * time.Hour),
198+
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
199+
BasicConstraintsValid: true,
200+
IsCA: true,
201+
}
202+
203+
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privKey.PublicKey, privKey)
204+
require.NoError(t, err)
205+
206+
cert, err := x509.ParseCertificate(certDER)
207+
require.NoError(t, err)
208+
209+
return cert, privKey
210+
}
211+
212+
// Helper function to create a leaf certificate signed by a parent
213+
func createLeafCert(t *testing.T, dnsNames []string, ipAddresses []net.IP, parentCert *x509.Certificate, parentKey *ecdsa.PrivateKey, notBefore, notAfter time.Time) (*x509.Certificate, *ecdsa.PrivateKey) {
214+
t.Helper()
215+
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
216+
require.NoError(t, err)
217+
218+
template := x509.Certificate{
219+
SerialNumber: big.NewInt(2),
220+
Subject: pkix.Name{CommonName: dnsNames[0]}, // Use first DNS name as CN
221+
DNSNames: dnsNames,
222+
IPAddresses: ipAddresses,
223+
NotBefore: notBefore,
224+
NotAfter: notAfter,
225+
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
226+
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, // Server cert
227+
BasicConstraintsValid: true,
228+
IsCA: false,
229+
}
230+
231+
certDER, err := x509.CreateCertificate(rand.Reader, &template, parentCert, &privKey.PublicKey, parentKey)
232+
require.NoError(t, err)
233+
234+
cert, err := x509.ParseCertificate(certDER)
235+
require.NoError(t, err)
236+
237+
return cert, privKey
238+
}
239+
240+
func TestGeneratedCert_Valid(t *testing.T) {
241+
// 1. Generate Certs
242+
rootCA, rootKey := createRootCA(t)
243+
leafCert, _ := createLeafCert(t, []string{"test.local"}, nil, rootCA, rootKey, time.Now().Add(-1*time.Hour), time.Now().Add(1*time.Hour))
244+
245+
// 2. Setup Root Pool for Client
246+
rootPool := x509.NewCertPool()
247+
rootPool.AddCert(rootCA)
248+
249+
verificationContext := &CertVerificationContext{PeerCertificates: []*x509.Certificate{leafCert}}
250+
251+
sysVerifier := &StandardCertVerifier{CertificateName: "test.local"}
252+
require.Error(t, sysVerifier.VerifyCertificate(verificationContext))
253+
254+
customVerifier := &StandardCertVerifier{CertificateName: "test.local", Roots: rootPool}
255+
require.NoError(t, customVerifier.VerifyCertificate(verificationContext))
256+
257+
wrongDomainVerifier := &StandardCertVerifier{CertificateName: "other.local", Roots: rootPool}
258+
var hostErr x509.HostnameError
259+
require.ErrorAs(t, wrongDomainVerifier.VerifyCertificate(verificationContext), &hostErr)
260+
require.Equal(t, "other.local", hostErr.Host)
261+
require.Equal(t, leafCert, hostErr.Certificate)
262+
}

0 commit comments

Comments
 (0)