diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2StandardConnection.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2StandardConnection.cs index 4f89a39047..f15fd5741d 100644 --- a/csharp/src/Drivers/Apache/Hive2/HiveServer2StandardConnection.cs +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2StandardConnection.cs @@ -112,19 +112,15 @@ protected override TTransport CreateTransport() TTransport baseTransport; if (TlsOptions.IsTlsEnabled) { - X509Certificate2? trustedCert = !string.IsNullOrEmpty(TlsOptions.TrustedCertificatePath) - ? new X509Certificate2(TlsOptions.TrustedCertificatePath!) - : null; - RemoteCertificateValidationCallback certValidator = (sender, cert, chain, errors) => HiveServer2TlsImpl.ValidateCertificate(cert, errors, TlsOptions); if (IPAddress.TryParse(hostName!, out var ipAddress)) { - baseTransport = new TTlsSocketTransport(ipAddress, portValue, config: new(), 0, trustedCert, certValidator); + baseTransport = new TTlsSocketTransport(ipAddress, portValue, config: new(), 0, null, certValidator); } else { - baseTransport = new TTlsSocketTransport(hostName!, portValue, config: new(), 0, trustedCert, certValidator); + baseTransport = new TTlsSocketTransport(hostName!, portValue, config: new(), 0, null, certValidator); } } else diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2TlsImpl.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2TlsImpl.cs index 9a9311d7af..836f7d46dd 100644 --- a/csharp/src/Drivers/Apache/Hive2/HiveServer2TlsImpl.cs +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2TlsImpl.cs @@ -22,6 +22,7 @@ using System.Net.Http; using System.Net.Security; using System.Security.Cryptography.X509Certificates; +using System.Text.RegularExpressions; namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 { @@ -138,6 +139,30 @@ static internal TlsProperties GetStandardTlsOptions(IReadOnlyDictionary LoadPemCertificates(string pemPath) + { + List certs = new(); + string pemContent = File.ReadAllText(pemPath); + + MatchCollection matches = Regex.Matches( + pemContent, + "-----BEGIN CERTIFICATE-----(.*?)-----END CERTIFICATE-----", + RegexOptions.Singleline); + + foreach (Match match in matches) + { + string base64 = match.Groups[1].Value + .Replace("\r", "") + .Replace("\n", "") + .Trim(); + + byte[] rawData = Convert.FromBase64String(base64); + certs.Add(new X509Certificate2(rawData)); + } + + return certs; + } + static internal bool ValidateCertificate(X509Certificate? cert, SslPolicyErrors policyErrors, TlsProperties tlsProperties) { if (policyErrors == SslPolicyErrors.None || tlsProperties.DisableServerCertificateValidation) @@ -155,14 +180,39 @@ static internal bool ValidateCertificate(X509Certificate? cert, SslPolicyErrors return !policyErrors.HasFlag(SslPolicyErrors.RemoteCertificateChainErrors) || (tlsProperties.AllowSelfSigned && IsSelfSigned(cert2)); } - X509Certificate2 trustedRoot = new X509Certificate2(tlsProperties.TrustedCertificatePath); X509Chain customChain = new(); - customChain.ChainPolicy.ExtraStore.Add(trustedRoot); // "tell the X509Chain class that I do trust this root certs and it should check just the certs in the chain and nothing else" customChain.ChainPolicy.VerificationFlags = X509VerificationFlags.AllowUnknownCertificateAuthority; + var collection = LoadPemCertificates(tlsProperties.TrustedCertificatePath!); + + foreach (var trustedCert in collection) + { + customChain.ChainPolicy.ExtraStore.Add(trustedCert); + } customChain.ChainPolicy.RevocationMode = X509RevocationMode.Online; bool chainValid = customChain.Build(cert2); + if (chainValid) + { + bool trustedBy = false; + foreach (X509ChainElement element in customChain.ChainElements) + { + foreach (X509Certificate2 ca in collection) + { + if (element.Certificate.Thumbprint == ca.Thumbprint) + { + trustedBy = true; + break; + } + } + if (trustedBy) + { + break; + } + } + chainValid = chainValid && trustedBy; + } + return chainValid || (tlsProperties.AllowSelfSigned && IsSelfSigned(cert2)); } }