Skip to content

Commit 2852c84

Browse files
committed
PR changes
1 parent 6570045 commit 2852c84

File tree

3 files changed

+159
-37
lines changed

3 files changed

+159
-37
lines changed

Sources/NIOCertificateReloading/CertificateReloader.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ extension TLSConfiguration {
117117

118118
/// Configure a ``CertificateReloader`` to observe updates for the certificate and key pair used.
119119
/// - Parameter reloader: A ``CertificateReloader`` to watch for certificate and key pair updates.
120-
mutating public func setCertificateReloader(_ reloader: some CertificateReloader) {
120+
public mutating func setCertificateReloader(_ reloader: some CertificateReloader) {
121121
self.sslContextCallback = { _, promise in
122122
promise.succeed(reloader.sslContextConfigurationOverride)
123123
}

Sources/NIOCertificateReloading/TimedCertificateReloader.swift

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ public struct TimedCertificateReloader: CertificateReloader {
129129
public struct Location: Sendable, CustomStringConvertible {
130130
fileprivate enum _Backing: CustomStringConvertible {
131131
case file(path: String)
132-
case memory(provider: @Sendable () -> [UInt8]?)
132+
case memory(provider: @Sendable () throws -> [UInt8])
133133

134134
var description: String {
135135
switch self {
@@ -157,10 +157,10 @@ public struct TimedCertificateReloader: CertificateReloader {
157157
public static func file(path: String) -> Self { Self(_Backing.file(path: path)) }
158158

159159
/// This certificate/key is available in memory, and will be provided by the given closure.
160-
/// - Parameter provider: A closure providing the bytes for the given certificate or key. This closure should return
161-
/// `nil` if a certificate/key isn't currently available for whatever reason.
160+
/// - Parameter provider: A closure providing the bytes for the given certificate or key. It may throw if, e.g., a
161+
/// certificate or key isn't available.
162162
/// - Returns: A `Location`.
163-
public static func memory(provider: @Sendable @escaping () -> [UInt8]?) -> Self {
163+
public static func memory(provider: @Sendable @escaping () throws -> [UInt8]) -> Self {
164164
Self(_Backing.memory(provider: provider))
165165
}
166166
}
@@ -310,7 +310,7 @@ public struct TimedCertificateReloader: CertificateReloader {
310310
/// - logger: An optional logger.
311311
/// - Returns: The newly created ``TimedCertificateReloader``.
312312
/// - Throws: If either the certificate or private key sources cannot be loaded, an error will be thrown.
313-
static public func makeReloaderValidatingSources(
313+
public static func makeReloaderValidatingSources(
314314
refreshInterval: Duration,
315315
certificateSource: CertificateSource,
316316
privateKeySource: PrivateKeySource,
@@ -348,18 +348,18 @@ public struct TimedCertificateReloader: CertificateReloader {
348348

349349
/// Manually attempt a certificate and private key pair update.
350350
public func reload() throws {
351-
if let certificateBytes = try self.loadCertificate(),
352-
let keyBytes = try self.loadPrivateKey(),
353-
let certificate = try self.parseCertificate(from: certificateBytes),
351+
let certificateBytes = try self.loadCertificate()
352+
let keyBytes = try self.loadPrivateKey()
353+
if let certificate = try self.parseCertificate(from: certificateBytes),
354354
let key = try self.parsePrivateKey(from: keyBytes),
355355
key.publicKey.isValidSignature(certificate.signature, for: certificate)
356356
{
357357
try self.attemptToUpdatePair(certificate: certificate, key: key)
358358
}
359359
}
360360

361-
private func loadCertificate() throws -> [UInt8]? {
362-
let certificateBytes: [UInt8]?
361+
private func loadCertificate() throws -> [UInt8] {
362+
let certificateBytes: [UInt8]
363363
switch self.certificateSource.location._backing {
364364
case .file(let path):
365365
guard let bytes = FileManager.default.contents(atPath: path) else {
@@ -368,13 +368,13 @@ public struct TimedCertificateReloader: CertificateReloader {
368368
certificateBytes = Array(bytes)
369369

370370
case .memory(let bytesProvider):
371-
certificateBytes = bytesProvider()
371+
certificateBytes = try bytesProvider()
372372
}
373373
return certificateBytes
374374
}
375375

376-
private func loadPrivateKey() throws -> [UInt8]? {
377-
let keyBytes: [UInt8]?
376+
private func loadPrivateKey() throws -> [UInt8] {
377+
let keyBytes: [UInt8]
378378
switch self.privateKeySource.location._backing {
379379
case .file(let path):
380380
guard let bytes = FileManager.default.contents(atPath: path) else {
@@ -383,7 +383,7 @@ public struct TimedCertificateReloader: CertificateReloader {
383383
keyBytes = Array(bytes)
384384

385385
case .memory(let bytesProvider):
386-
keyBytes = bytesProvider()
386+
keyBytes = try bytesProvider()
387387
}
388388
return keyBytes
389389
}

Tests/NIOCertificateReloadingTests/TimedCertificateReloaderTests.swift

Lines changed: 144 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@ import SwiftASN1
2020
import X509
2121
import XCTest
2222

23+
#if canImport(FoundationEssentials)
24+
import FoundationEssentials
25+
#else
26+
import Foundation
27+
#endif
28+
2329
final class TimedCertificateReloaderTests: XCTestCase {
2430
func testCertificatePathDoesNotExist() async throws {
2531
try await runTimedCertificateReloaderTest(
@@ -58,7 +64,7 @@ final class TimedCertificateReloaderTests: XCTestCase {
5864
func testKeyPathDoesNotExist() async throws {
5965
try await runTimedCertificateReloaderTest(
6066
certificate: .init(
61-
location: .memory(provider: { try? Self.sampleCert.serializeAsPEM().derBytes }),
67+
location: .memory(provider: { try Self.sampleCert.serializeAsPEM().derBytes }),
6268
format: .der
6369
),
6470
privateKey: .init(
@@ -77,7 +83,7 @@ final class TimedCertificateReloaderTests: XCTestCase {
7783
do {
7884
try await runTimedCertificateReloaderTest(
7985
certificate: .init(
80-
location: .memory(provider: { try? Self.sampleCert.serializeAsPEM().derBytes }),
86+
location: .memory(provider: { try Self.sampleCert.serializeAsPEM().derBytes }),
8187
format: .der
8288
),
8389
privateKey: .init(
@@ -95,10 +101,39 @@ final class TimedCertificateReloaderTests: XCTestCase {
95101
}
96102
}
97103

98-
func testCertificateIsInUnexpectedFormat() async throws {
104+
func testCertificateIsInUnexpectedFormat_FromMemory() async throws {
105+
try await runTimedCertificateReloaderTest(
106+
certificate: .init(
107+
location: .memory(provider: { try Self.sampleCert.serializeAsPEM().derBytes }),
108+
format: .pem
109+
),
110+
privateKey: .init(
111+
location: .memory(provider: { Array(Self.samplePrivateKey.derRepresentation) }),
112+
format: .der
113+
)
114+
) { reloader in
115+
let override = reloader.sslContextConfigurationOverride
116+
XCTAssertNil(override.certificateChain)
117+
XCTAssertNil(override.privateKey)
118+
}
119+
}
120+
121+
private func createTempFile(contents: Data) throws -> URL {
122+
let directory = FileManager.default.temporaryDirectory
123+
let filename = UUID().uuidString
124+
let fileURL = directory.appendingPathComponent(filename)
125+
guard FileManager.default.createFile(atPath: fileURL.path, contents: contents) else {
126+
throw TestError.couldNotCreateFile
127+
}
128+
return fileURL
129+
}
130+
131+
func testCertificateIsInUnexpectedFormat_FromFile() async throws {
132+
let certBytes = try Self.sampleCert.serializeAsPEM().derBytes
133+
let file = try self.createTempFile(contents: Data(certBytes))
99134
try await runTimedCertificateReloaderTest(
100135
certificate: .init(
101-
location: .memory(provider: { try? Self.sampleCert.serializeAsPEM().derBytes }),
136+
location: .file(path: file.path),
102137
format: .pem
103138
),
104139
privateKey: .init(
@@ -112,10 +147,10 @@ final class TimedCertificateReloaderTests: XCTestCase {
112147
}
113148
}
114149

115-
func testKeyIsInUnexpectedFormat() async throws {
150+
func testKeyIsInUnexpectedFormat_FromMemory() async throws {
116151
try await runTimedCertificateReloaderTest(
117152
certificate: .init(
118-
location: .memory(provider: { try? Self.sampleCert.serializeAsPEM().derBytes }),
153+
location: .memory(provider: { try Self.sampleCert.serializeAsPEM().derBytes }),
119154
format: .der
120155
),
121156
privateKey: .init(
@@ -129,10 +164,29 @@ final class TimedCertificateReloaderTests: XCTestCase {
129164
}
130165
}
131166

167+
func testKeyIsInUnexpectedFormat_FromFile() async throws {
168+
let keyBytes = Self.samplePrivateKey.derRepresentation
169+
let file = try self.createTempFile(contents: keyBytes)
170+
try await runTimedCertificateReloaderTest(
171+
certificate: .init(
172+
location: .memory(provider: { try Self.sampleCert.serializeAsPEM().derBytes }),
173+
format: .der
174+
),
175+
privateKey: .init(
176+
location: .file(path: file.path),
177+
format: .pem
178+
)
179+
) { reloader in
180+
let override = reloader.sslContextConfigurationOverride
181+
XCTAssertNil(override.certificateChain)
182+
XCTAssertNil(override.privateKey)
183+
}
184+
}
185+
132186
func testCertificateAndKeyDoNotMatch() async throws {
133187
try await runTimedCertificateReloaderTest(
134188
certificate: .init(
135-
location: .memory(provider: { try? Self.sampleCert.serializeAsPEM().derBytes }),
189+
location: .memory(provider: { try Self.sampleCert.serializeAsPEM().derBytes }),
136190
format: .der
137191
),
138192
privateKey: .init(
@@ -146,17 +200,31 @@ final class TimedCertificateReloaderTests: XCTestCase {
146200
}
147201
}
148202

149-
func testReloadSuccessfully() async throws {
150-
let certificateBox: NIOLockedValueBox<[UInt8]?> = NIOLockedValueBox(nil)
203+
enum TestError: Error {
204+
case emptyCertificate
205+
case emptyPrivateKey
206+
case couldNotCreateFile
207+
}
208+
209+
func testReloadSuccessfully_FromMemory() async throws {
210+
let certificateBox: NIOLockedValueBox<[UInt8]> = NIOLockedValueBox([])
151211
try await runTimedCertificateReloaderTest(
152212
certificate: .init(
153-
location: .memory(provider: { certificateBox.withLockedValue({ $0 }) }),
213+
location: .memory(provider: {
214+
let cert = certificateBox.withLockedValue({ $0 })
215+
if cert.isEmpty {
216+
throw TestError.emptyCertificate
217+
}
218+
return cert
219+
}),
154220
format: .der
155221
),
156222
privateKey: .init(
157223
location: .memory(provider: { Array(Self.samplePrivateKey.derRepresentation) }),
158224
format: .der
159-
)
225+
),
226+
// We need to disable validation because the provider will initially be empty.
227+
validateSources: false
160228
) { reloader in
161229
// On first attempt, we should have no certificate or private key overrides available,
162230
// since the certificate box is empty.
@@ -183,13 +251,61 @@ final class TimedCertificateReloaderTests: XCTestCase {
183251
}
184252
}
185253

254+
func testReloadSuccessfully_FromFile() async throws {
255+
// Start with empty files.
256+
let certificateFile = try self.createTempFile(contents: Data())
257+
let privateKeyFile = try self.createTempFile(contents: Data())
258+
try await runTimedCertificateReloaderTest(
259+
certificate: .init(
260+
location: .file(path: certificateFile.path),
261+
format: .der
262+
),
263+
privateKey: .init(
264+
location: .file(path: privateKeyFile.path),
265+
format: .der
266+
),
267+
// We need to disable validation because the files will not initially have any contents.
268+
validateSources: false
269+
) { reloader in
270+
// On first attempt, we should have no certificate or private key overrides available,
271+
// since the certificate box is empty.
272+
var override = reloader.sslContextConfigurationOverride
273+
XCTAssertNil(override.certificateChain)
274+
XCTAssertNil(override.privateKey)
275+
276+
// Update the files to contain data
277+
try Data(try Self.sampleCert.serializeAsPEM().derBytes).write(to: certificateFile)
278+
try Self.samplePrivateKey.derRepresentation.write(to: privateKeyFile)
279+
280+
// Give the reload loop some time to run and update the cert-key pair.
281+
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
282+
283+
// Now the overrides should be present.
284+
override = reloader.sslContextConfigurationOverride
285+
XCTAssertEqual(
286+
override.certificateChain,
287+
[.certificate(try .init(bytes: Self.sampleCert.serializeAsPEM().derBytes, format: .der))]
288+
)
289+
XCTAssertEqual(
290+
override.privateKey,
291+
.privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der))
292+
)
293+
}
294+
}
295+
186296
func testCertificateNotFoundAtReload() async throws {
187-
let certificateBox: NIOLockedValueBox<[UInt8]?> = NIOLockedValueBox(
297+
let certificateBox: NIOLockedValueBox<[UInt8]> = NIOLockedValueBox(
188298
try! Self.sampleCert.serializeAsPEM().derBytes
189299
)
190300
try await runTimedCertificateReloaderTest(
191301
certificate: .init(
192-
location: .memory(provider: { certificateBox.withLockedValue({ $0 }) }),
302+
location: .memory(provider: {
303+
let cert = certificateBox.withLockedValue({ $0 })
304+
if cert.isEmpty {
305+
throw TestError.emptyCertificate
306+
}
307+
return cert
308+
}),
193309
format: .der
194310
),
195311
privateKey: .init(
@@ -208,8 +324,8 @@ final class TimedCertificateReloaderTests: XCTestCase {
208324
.privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der))
209325
)
210326

211-
// Update the box to not contain a certificate.
212-
certificateBox.withLockedValue({ $0 = nil })
327+
// Update the box to contain empty bytes: this will cause the provider to throw.
328+
certificateBox.withLockedValue({ $0 = [] })
213329

214330
// Give the reload loop some time to run and update the cert-key pair.
215331
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
@@ -228,16 +344,22 @@ final class TimedCertificateReloaderTests: XCTestCase {
228344
}
229345

230346
func testKeyNotFoundAtReload() async throws {
231-
let keyBox: NIOLockedValueBox<[UInt8]?> = NIOLockedValueBox(
347+
let keyBox: NIOLockedValueBox<[UInt8]> = NIOLockedValueBox(
232348
Array(Self.samplePrivateKey.derRepresentation)
233349
)
234350
try await runTimedCertificateReloaderTest(
235351
certificate: .init(
236-
location: .memory(provider: { try! Self.sampleCert.serializeAsPEM().derBytes }),
352+
location: .memory(provider: { try Self.sampleCert.serializeAsPEM().derBytes }),
237353
format: .der
238354
),
239355
privateKey: .init(
240-
location: .memory(provider: { keyBox.withLockedValue({ $0 }) }),
356+
location: .memory(provider: {
357+
let key = keyBox.withLockedValue({ $0 })
358+
if key.isEmpty {
359+
throw TestError.emptyPrivateKey
360+
}
361+
return key
362+
}),
241363
format: .der
242364
)
243365
) { reloader in
@@ -252,8 +374,8 @@ final class TimedCertificateReloaderTests: XCTestCase {
252374
.privateKey(try .init(bytes: Array(Self.samplePrivateKey.derRepresentation), format: .der))
253375
)
254376

255-
// Update the box to not contain a key.
256-
keyBox.withLockedValue({ $0 = nil })
377+
// Update the box to contain empty bytes: this will cause the provider to throw.
378+
keyBox.withLockedValue({ $0 = [] })
257379

258380
// Give the reload loop some time to run and update the cert-key pair.
259381
try await Task.sleep(for: .milliseconds(100), tolerance: .zero)
@@ -272,12 +394,12 @@ final class TimedCertificateReloaderTests: XCTestCase {
272394
}
273395

274396
func testCertificateAndKeyDoNotMatchOnReload() async throws {
275-
let keyBox: NIOLockedValueBox<[UInt8]?> = NIOLockedValueBox(
397+
let keyBox: NIOLockedValueBox<[UInt8]> = NIOLockedValueBox(
276398
Array(Self.samplePrivateKey.derRepresentation)
277399
)
278400
try await runTimedCertificateReloaderTest(
279401
certificate: .init(
280-
location: .memory(provider: { try! Self.sampleCert.serializeAsPEM().derBytes }),
402+
location: .memory(provider: { try Self.sampleCert.serializeAsPEM().derBytes }),
281403
format: .der
282404
),
283405
privateKey: .init(

0 commit comments

Comments
 (0)