1313// limitations under the License.
1414
1515use std:: path:: Path ;
16+ use std:: sync:: Arc ;
1617use std:: sync:: atomic:: { AtomicBool , AtomicU64 , AtomicUsize , Ordering } ;
17- use std:: sync:: mpsc:: channel;
18- use std:: sync:: { Arc , RwLock } ;
1918use std:: time:: Duration ;
2019
2120use common_base:: readable_size:: ReadableSize ;
22- use common_telemetry:: { error , info} ;
21+ use common_telemetry:: info;
2322use dashmap:: DashMap ;
2423use dashmap:: mapref:: entry:: Entry ;
2524use lazy_static:: lazy_static;
26- use notify:: { EventKind , RecursiveMode , Watcher } ;
2725use serde:: { Deserialize , Serialize } ;
2826use snafu:: ResultExt ;
2927use tokio_util:: sync:: CancellationToken ;
@@ -32,7 +30,8 @@ use tonic::transport::{
3230} ;
3331use tower:: Service ;
3432
35- use crate :: error:: { CreateChannelSnafu , FileWatchSnafu , InvalidConfigFilePathSnafu , Result } ;
33+ use crate :: error:: { CreateChannelSnafu , InvalidConfigFilePathSnafu , Result } ;
34+ use crate :: reloadable_tls:: { ReloadableTlsConfig , TlsConfigLoader } ;
3635
3736const RECYCLE_CHANNEL_INTERVAL_SECS : u64 = 60 ;
3837pub const DEFAULT_GRPC_REQUEST_TIMEOUT_SECS : u64 = 10 ;
@@ -191,7 +190,7 @@ impl ChannelManager {
191190 . inner
192191 . reloadable_client_tls_config
193192 . as_ref ( )
194- . and_then ( |c| c. get_client_config ( ) ) ;
193+ . and_then ( |c| c. get_config ( ) ) ;
195194
196195 let http_prefix = if tls_config. is_some ( ) {
197196 "https"
@@ -296,6 +295,36 @@ fn load_tls_config(tls_option: Option<&ClientTlsOption>) -> Result<Option<Client
296295 Ok ( Some ( tls_config) )
297296}
298297
298+ impl TlsConfigLoader < ClientTlsConfig > for ClientTlsOption {
299+ type Config = ClientTlsConfig ;
300+ type Error = crate :: error:: Error ;
301+
302+ fn load ( & self ) -> Result < Option < ClientTlsConfig > > {
303+ load_tls_config ( Some ( self ) )
304+ }
305+
306+ fn watch_paths ( & self ) -> Vec < & Path > {
307+ let mut paths = Vec :: new ( ) ;
308+ if let Some ( cert_path) = & self . client_cert_path {
309+ paths. push ( Path :: new ( cert_path. as_str ( ) ) ) ;
310+ }
311+ if let Some ( key_path) = & self . client_key_path {
312+ paths. push ( Path :: new ( key_path. as_str ( ) ) ) ;
313+ }
314+ if let Some ( ca_path) = & self . server_ca_cert_path {
315+ paths. push ( Path :: new ( ca_path. as_str ( ) ) ) ;
316+ }
317+ paths
318+ }
319+
320+ fn watch_enabled ( & self ) -> bool {
321+ self . enabled && self . watch
322+ }
323+ }
324+
325+ /// Type alias for client-side reloadable TLS config
326+ pub type ReloadableClientTlsConfig = ReloadableTlsConfig < ClientTlsConfig , ClientTlsOption > ;
327+
299328/// Load client TLS configuration from `ClientTlsOption` and return a `ReloadableClientTlsConfig`.
300329/// This is the primary way to create TLS configuration for the ChannelManager.
301330pub fn load_client_tls_config (
@@ -310,142 +339,17 @@ pub fn load_client_tls_config(
310339 }
311340}
312341
313- /// A mutable container for TLS client config
314- ///
315- /// This struct allows dynamic reloading of client certificates and keys
316- #[ derive( Debug ) ]
317- pub struct ReloadableClientTlsConfig {
318- tls_option : ClientTlsOption ,
319- config : RwLock < Option < ClientTlsConfig > > ,
320- version : AtomicUsize ,
321- }
322-
323- impl ReloadableClientTlsConfig {
324- /// Create client config by loading configuration from `ClientTlsOption`
325- fn try_new ( tls_option : ClientTlsOption ) -> Result < ReloadableClientTlsConfig > {
326- let client_config = load_tls_config ( Some ( & tls_option) ) ?;
327- Ok ( Self {
328- tls_option,
329- config : RwLock :: new ( client_config) ,
330- version : AtomicUsize :: new ( 0 ) ,
331- } )
332- }
333-
334- /// Reread client certificates and keys from file system.
335- fn reload ( & self ) -> Result < ( ) > {
336- let client_config = load_tls_config ( Some ( & self . tls_option ) ) ?;
337- * self . config . write ( ) . unwrap ( ) = client_config;
338- self . version . fetch_add ( 1 , Ordering :: Relaxed ) ;
339- Ok ( ( ) )
340- }
341-
342- /// Get the client config held by this container
343- pub fn get_client_config ( & self ) -> Option < ClientTlsConfig > {
344- self . config . read ( ) . unwrap ( ) . clone ( )
345- }
346-
347- /// Get associated `ClientTlsOption`
348- pub fn get_tls_option ( & self ) -> & ClientTlsOption {
349- & self . tls_option
350- }
351-
352- /// Get version of current config
353- ///
354- /// this version will auto increase when client config get reloaded.
355- pub fn get_version ( & self ) -> usize {
356- self . version . load ( Ordering :: Relaxed )
357- }
358-
359- fn cert_path ( & self ) -> Option < & Path > {
360- self . tls_option
361- . client_cert_path
362- . as_ref ( )
363- . map ( |p| Path :: new ( p. as_str ( ) ) )
364- }
365-
366- fn key_path ( & self ) -> Option < & Path > {
367- self . tls_option
368- . client_key_path
369- . as_ref ( )
370- . map ( |p| Path :: new ( p. as_str ( ) ) )
371- }
372-
373- fn server_ca_cert_path ( & self ) -> Option < & Path > {
374- self . tls_option
375- . server_ca_cert_path
376- . as_ref ( )
377- . map ( |p| Path :: new ( p. as_str ( ) ) )
378- }
379-
380- fn watch_enabled ( & self ) -> bool {
381- self . tls_option . enabled && self . tls_option . watch
382- }
383- }
384-
385342pub fn maybe_watch_client_tls_config (
386343 client_tls_config : Arc < ReloadableClientTlsConfig > ,
387344 channel_manager : & ChannelManager ,
388345) -> Result < ( ) > {
389- if !client_tls_config. watch_enabled ( ) {
390- return Ok ( ( ) ) ;
391- }
392-
393- let client_tls_config_for_watcher = client_tls_config. clone ( ) ;
394346 let channel_manager_for_watcher = channel_manager. clone ( ) ;
395347
396- let ( tx, rx) = channel :: < notify:: Result < notify:: Event > > ( ) ;
397- let mut watcher = notify:: recommended_watcher ( tx) . context ( FileWatchSnafu { path : "<none>" } ) ?;
398-
399- // Watch client cert if present
400- if let Some ( cert_path) = client_tls_config. cert_path ( ) {
401- watcher
402- . watch ( cert_path, RecursiveMode :: NonRecursive )
403- . with_context ( |_| FileWatchSnafu {
404- path : cert_path. display ( ) . to_string ( ) ,
405- } ) ?;
406- }
407-
408- // Watch client key if present
409- if let Some ( key_path) = client_tls_config. key_path ( ) {
410- watcher
411- . watch ( key_path, RecursiveMode :: NonRecursive )
412- . with_context ( |_| FileWatchSnafu {
413- path : key_path. display ( ) . to_string ( ) ,
414- } ) ?;
415- }
416-
417- // Watch server CA cert if present
418- if let Some ( ca_path) = client_tls_config. server_ca_cert_path ( ) {
419- watcher
420- . watch ( ca_path, RecursiveMode :: NonRecursive )
421- . with_context ( |_| FileWatchSnafu {
422- path : ca_path. display ( ) . to_string ( ) ,
423- } ) ?;
424- }
425-
426- std:: thread:: spawn ( move || {
427- let _watcher = watcher;
428- while let Ok ( res) = rx. recv ( ) {
429- if let Ok ( event) = res {
430- match event. kind {
431- EventKind :: Modify ( _) | EventKind :: Create ( _) => {
432- info ! ( "Detected TLS cert/key file change: {:?}" , event) ;
433- if let Err ( err) = client_tls_config_for_watcher. reload ( ) {
434- error ! ( err; "Failed to reload TLS client config" ) ;
435- } else {
436- info ! ( "Reloaded TLS cert/key file successfully." ) ;
437- // Clear all existing channels to force reconnection with new certificates
438- channel_manager_for_watcher. clear_all_channels ( ) ;
439- info ! ( "Cleared all existing channels to use new TLS certificates." ) ;
440- }
441- }
442- _ => { }
443- }
444- }
445- }
446- } ) ;
447-
448- Ok ( ( ) )
348+ crate :: reloadable_tls:: maybe_watch_tls_config ( client_tls_config, move || {
349+ // Clear all existing channels to force reconnection with new certificates
350+ channel_manager_for_watcher. clear_all_channels ( ) ;
351+ info ! ( "Cleared all existing channels to use new TLS certificates." ) ;
352+ } )
449353}
450354
451355#[ derive( Clone , Debug , PartialEq , Eq , Serialize , Deserialize ) ]
0 commit comments