Skip to content

Commit 69e7d8b

Browse files
committed
chore: extract tls loader
Signed-off-by: shuiyisong <[email protected]>
1 parent af7f9ed commit 69e7d8b

File tree

8 files changed

+228
-233
lines changed

8 files changed

+228
-233
lines changed

src/common/grpc/src/channel_manager.rs

Lines changed: 40 additions & 136 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,15 @@
1313
// limitations under the License.
1414

1515
use std::path::Path;
16+
use std::sync::Arc;
1617
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
17-
use std::sync::mpsc::channel;
18-
use std::sync::{Arc, RwLock};
1918
use std::time::Duration;
2019

2120
use common_base::readable_size::ReadableSize;
22-
use common_telemetry::{error, info};
21+
use common_telemetry::info;
2322
use dashmap::DashMap;
2423
use dashmap::mapref::entry::Entry;
2524
use lazy_static::lazy_static;
26-
use notify::{EventKind, RecursiveMode, Watcher};
2725
use serde::{Deserialize, Serialize};
2826
use snafu::ResultExt;
2927
use tokio_util::sync::CancellationToken;
@@ -32,7 +30,8 @@ use tonic::transport::{
3230
};
3331
use tower::Service;
3432

35-
use crate::error::{CreateChannelSnafu, FileWatchSnafu, InvalidConfigFilePathSnafu, Result};
33+
use crate::error::{CreateChannelSnafu, InvalidConfigFilePathSnafu, Result};
34+
use crate::reloadable_tls::{ReloadableTlsConfig, TlsConfigLoader};
3635

3736
const RECYCLE_CHANNEL_INTERVAL_SECS: u64 = 60;
3837
pub const DEFAULT_GRPC_REQUEST_TIMEOUT_SECS: u64 = 10;
@@ -191,7 +190,7 @@ impl ChannelManager {
191190
.inner
192191
.reloadable_client_tls_config
193192
.as_ref()
194-
.and_then(|c| c.get_client_config());
193+
.and_then(|c| c.get_config());
195194

196195
let http_prefix = if tls_config.is_some() {
197196
"https"
@@ -296,6 +295,36 @@ fn load_tls_config(tls_option: Option<&ClientTlsOption>) -> Result<Option<Client
296295
Ok(Some(tls_config))
297296
}
298297

298+
impl TlsConfigLoader<ClientTlsConfig> for ClientTlsOption {
299+
type Config = ClientTlsConfig;
300+
type Error = crate::error::Error;
301+
302+
fn load(&self) -> Result<Option<ClientTlsConfig>> {
303+
load_tls_config(Some(self))
304+
}
305+
306+
fn watch_paths(&self) -> Vec<&Path> {
307+
let mut paths = Vec::new();
308+
if let Some(cert_path) = &self.client_cert_path {
309+
paths.push(Path::new(cert_path.as_str()));
310+
}
311+
if let Some(key_path) = &self.client_key_path {
312+
paths.push(Path::new(key_path.as_str()));
313+
}
314+
if let Some(ca_path) = &self.server_ca_cert_path {
315+
paths.push(Path::new(ca_path.as_str()));
316+
}
317+
paths
318+
}
319+
320+
fn watch_enabled(&self) -> bool {
321+
self.enabled && self.watch
322+
}
323+
}
324+
325+
/// Type alias for client-side reloadable TLS config
326+
pub type ReloadableClientTlsConfig = ReloadableTlsConfig<ClientTlsConfig, ClientTlsOption>;
327+
299328
/// Load client TLS configuration from `ClientTlsOption` and return a `ReloadableClientTlsConfig`.
300329
/// This is the primary way to create TLS configuration for the ChannelManager.
301330
pub fn load_client_tls_config(
@@ -310,142 +339,17 @@ pub fn load_client_tls_config(
310339
}
311340
}
312341

313-
/// A mutable container for TLS client config
314-
///
315-
/// This struct allows dynamic reloading of client certificates and keys
316-
#[derive(Debug)]
317-
pub struct ReloadableClientTlsConfig {
318-
tls_option: ClientTlsOption,
319-
config: RwLock<Option<ClientTlsConfig>>,
320-
version: AtomicUsize,
321-
}
322-
323-
impl ReloadableClientTlsConfig {
324-
/// Create client config by loading configuration from `ClientTlsOption`
325-
fn try_new(tls_option: ClientTlsOption) -> Result<ReloadableClientTlsConfig> {
326-
let client_config = load_tls_config(Some(&tls_option))?;
327-
Ok(Self {
328-
tls_option,
329-
config: RwLock::new(client_config),
330-
version: AtomicUsize::new(0),
331-
})
332-
}
333-
334-
/// Reread client certificates and keys from file system.
335-
fn reload(&self) -> Result<()> {
336-
let client_config = load_tls_config(Some(&self.tls_option))?;
337-
*self.config.write().unwrap() = client_config;
338-
self.version.fetch_add(1, Ordering::Relaxed);
339-
Ok(())
340-
}
341-
342-
/// Get the client config held by this container
343-
pub fn get_client_config(&self) -> Option<ClientTlsConfig> {
344-
self.config.read().unwrap().clone()
345-
}
346-
347-
/// Get associated `ClientTlsOption`
348-
pub fn get_tls_option(&self) -> &ClientTlsOption {
349-
&self.tls_option
350-
}
351-
352-
/// Get version of current config
353-
///
354-
/// this version will auto increase when client config get reloaded.
355-
pub fn get_version(&self) -> usize {
356-
self.version.load(Ordering::Relaxed)
357-
}
358-
359-
fn cert_path(&self) -> Option<&Path> {
360-
self.tls_option
361-
.client_cert_path
362-
.as_ref()
363-
.map(|p| Path::new(p.as_str()))
364-
}
365-
366-
fn key_path(&self) -> Option<&Path> {
367-
self.tls_option
368-
.client_key_path
369-
.as_ref()
370-
.map(|p| Path::new(p.as_str()))
371-
}
372-
373-
fn server_ca_cert_path(&self) -> Option<&Path> {
374-
self.tls_option
375-
.server_ca_cert_path
376-
.as_ref()
377-
.map(|p| Path::new(p.as_str()))
378-
}
379-
380-
fn watch_enabled(&self) -> bool {
381-
self.tls_option.enabled && self.tls_option.watch
382-
}
383-
}
384-
385342
pub fn maybe_watch_client_tls_config(
386343
client_tls_config: Arc<ReloadableClientTlsConfig>,
387344
channel_manager: &ChannelManager,
388345
) -> Result<()> {
389-
if !client_tls_config.watch_enabled() {
390-
return Ok(());
391-
}
392-
393-
let client_tls_config_for_watcher = client_tls_config.clone();
394346
let channel_manager_for_watcher = channel_manager.clone();
395347

396-
let (tx, rx) = channel::<notify::Result<notify::Event>>();
397-
let mut watcher = notify::recommended_watcher(tx).context(FileWatchSnafu { path: "<none>" })?;
398-
399-
// Watch client cert if present
400-
if let Some(cert_path) = client_tls_config.cert_path() {
401-
watcher
402-
.watch(cert_path, RecursiveMode::NonRecursive)
403-
.with_context(|_| FileWatchSnafu {
404-
path: cert_path.display().to_string(),
405-
})?;
406-
}
407-
408-
// Watch client key if present
409-
if let Some(key_path) = client_tls_config.key_path() {
410-
watcher
411-
.watch(key_path, RecursiveMode::NonRecursive)
412-
.with_context(|_| FileWatchSnafu {
413-
path: key_path.display().to_string(),
414-
})?;
415-
}
416-
417-
// Watch server CA cert if present
418-
if let Some(ca_path) = client_tls_config.server_ca_cert_path() {
419-
watcher
420-
.watch(ca_path, RecursiveMode::NonRecursive)
421-
.with_context(|_| FileWatchSnafu {
422-
path: ca_path.display().to_string(),
423-
})?;
424-
}
425-
426-
std::thread::spawn(move || {
427-
let _watcher = watcher;
428-
while let Ok(res) = rx.recv() {
429-
if let Ok(event) = res {
430-
match event.kind {
431-
EventKind::Modify(_) | EventKind::Create(_) => {
432-
info!("Detected TLS cert/key file change: {:?}", event);
433-
if let Err(err) = client_tls_config_for_watcher.reload() {
434-
error!(err; "Failed to reload TLS client config");
435-
} else {
436-
info!("Reloaded TLS cert/key file successfully.");
437-
// Clear all existing channels to force reconnection with new certificates
438-
channel_manager_for_watcher.clear_all_channels();
439-
info!("Cleared all existing channels to use new TLS certificates.");
440-
}
441-
}
442-
_ => {}
443-
}
444-
}
445-
}
446-
});
447-
448-
Ok(())
348+
crate::reloadable_tls::maybe_watch_tls_config(client_tls_config, move || {
349+
// Clear all existing channels to force reconnection with new certificates
350+
channel_manager_for_watcher.clear_all_channels();
351+
info!("Cleared all existing channels to use new TLS certificates.");
352+
})
449353
}
450354

451355
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]

src/common/grpc/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ pub mod channel_manager;
1616
pub mod error;
1717
pub mod flight;
1818
pub mod precision;
19+
pub mod reloadable_tls;
1920
pub mod select;
2021

2122
pub use arrow_flight::FlightData;
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
// Copyright 2023 Greptime Team
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
use std::path::Path;
16+
use std::sync::atomic::{AtomicUsize, Ordering};
17+
use std::sync::mpsc::channel;
18+
use std::sync::{Arc, RwLock};
19+
20+
use common_telemetry::{error, info};
21+
use notify::{EventKind, RecursiveMode, Watcher};
22+
use snafu::ResultExt;
23+
24+
use crate::error::{FileWatchSnafu, Result};
25+
26+
/// A trait for loading TLS configuration from an option type
27+
pub trait TlsConfigLoader<T> {
28+
type Config;
29+
type Error;
30+
31+
/// Load the TLS configuration
32+
fn load(&self) -> std::result::Result<Option<T>, Self::Error>;
33+
34+
/// Get paths to certificate files for watching
35+
fn watch_paths(&self) -> Vec<&Path>;
36+
37+
/// Check if watching is enabled
38+
fn watch_enabled(&self) -> bool;
39+
}
40+
41+
/// A mutable container for TLS config
42+
///
43+
/// This struct allows dynamic reloading of certificates and keys.
44+
/// It's generic over the config type (e.g., ServerConfig, ClientTlsConfig)
45+
/// and the option type (e.g., TlsOption, ClientTlsOption).
46+
#[derive(Debug)]
47+
pub struct ReloadableTlsConfig<T, O>
48+
where
49+
O: TlsConfigLoader<T>,
50+
{
51+
tls_option: O,
52+
config: RwLock<Option<T>>,
53+
version: AtomicUsize,
54+
}
55+
56+
impl<T, O> ReloadableTlsConfig<T, O>
57+
where
58+
O: TlsConfigLoader<T>,
59+
{
60+
/// Create config by loading configuration from the option type
61+
pub fn try_new(tls_option: O) -> std::result::Result<Self, O::Error> {
62+
let config = tls_option.load()?;
63+
Ok(Self {
64+
tls_option,
65+
config: RwLock::new(config),
66+
version: AtomicUsize::new(0),
67+
})
68+
}
69+
70+
/// Reread certificates and keys from file system.
71+
pub fn reload(&self) -> std::result::Result<(), O::Error> {
72+
let config = self.tls_option.load()?;
73+
*self.config.write().unwrap() = config;
74+
self.version.fetch_add(1, Ordering::Relaxed);
75+
Ok(())
76+
}
77+
78+
/// Get the config held by this container
79+
pub fn get_config(&self) -> Option<T>
80+
where
81+
T: Clone,
82+
{
83+
self.config.read().unwrap().clone()
84+
}
85+
86+
/// Get associated option
87+
pub fn get_tls_option(&self) -> &O {
88+
&self.tls_option
89+
}
90+
91+
/// Get version of current config
92+
///
93+
/// this version will auto increase when config get reloaded.
94+
pub fn get_version(&self) -> usize {
95+
self.version.load(Ordering::Relaxed)
96+
}
97+
}
98+
99+
/// Watch TLS configuration files for changes and reload automatically
100+
///
101+
/// This is a generic function that works with any ReloadableTlsConfig.
102+
/// When changes are detected, it calls the provided callback after reloading.
103+
pub fn maybe_watch_tls_config<T, O, F, E>(
104+
tls_config: Arc<ReloadableTlsConfig<T, O>>,
105+
on_reload: F,
106+
) -> Result<()>
107+
where
108+
T: Send + Sync + 'static,
109+
O: TlsConfigLoader<T, Error = E> + Send + Sync + 'static,
110+
E: std::error::Error + Send + Sync + 'static,
111+
F: Fn() + Send + 'static,
112+
{
113+
if !tls_config.get_tls_option().watch_enabled() {
114+
return Ok(());
115+
}
116+
117+
let tls_config_for_watcher = tls_config.clone();
118+
119+
let (tx, rx) = channel::<notify::Result<notify::Event>>();
120+
let mut watcher = notify::recommended_watcher(tx).context(FileWatchSnafu { path: "<none>" })?;
121+
122+
// Watch all paths returned by the TlsConfigLoader
123+
for path in tls_config.get_tls_option().watch_paths() {
124+
watcher
125+
.watch(path, RecursiveMode::NonRecursive)
126+
.with_context(|_| FileWatchSnafu {
127+
path: path.display().to_string(),
128+
})?;
129+
}
130+
131+
std::thread::spawn(move || {
132+
let _watcher = watcher;
133+
while let Ok(res) = rx.recv() {
134+
if let Ok(event) = res {
135+
match event.kind {
136+
EventKind::Modify(_) | EventKind::Create(_) => {
137+
info!("Detected TLS cert/key file change: {:?}", event);
138+
if let Err(err) = tls_config_for_watcher.reload() {
139+
error!("Failed to reload TLS config: {}", err);
140+
} else {
141+
info!("Reloaded TLS cert/key file successfully.");
142+
on_reload();
143+
}
144+
}
145+
_ => {}
146+
}
147+
}
148+
}
149+
});
150+
151+
Ok(())
152+
}

0 commit comments

Comments
 (0)