diff --git a/src/lib.rs b/src/lib.rs index fc36c776..57719aa3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -58,6 +58,7 @@ use multiaddr::{Multiaddr, Protocol}; use transport::Endpoint; use types::ConnectionId; +use crate::transport::manager::DialFailureAddresses; pub use bandwidth::BandwidthSink; pub use error::Error; pub use peer_id::PeerId; @@ -198,6 +199,7 @@ impl Litep2p { config.fallback_names.clone(), config.codec, litep2p_config.keep_alive_timeout, + DialFailureAddresses::NotRequired, ); let executor = Arc::clone(&litep2p_config.executor); litep2p_config.executor.run(Box::pin(async move { @@ -218,6 +220,7 @@ impl Litep2p { config.fallback_names.clone(), config.codec, litep2p_config.keep_alive_timeout, + DialFailureAddresses::NotRequired, ); litep2p_config.executor.run(Box::pin(async move { RequestResponseProtocol::new(service, config).run().await @@ -233,6 +236,7 @@ impl Litep2p { Vec::new(), protocol.codec(), litep2p_config.keep_alive_timeout, + DialFailureAddresses::NotRequired, ); litep2p_config.executor.run(Box::pin(async move { let _ = protocol.run(service).await; @@ -252,6 +256,7 @@ impl Litep2p { Vec::new(), ping_config.codec, litep2p_config.keep_alive_timeout, + DialFailureAddresses::NotRequired, ); litep2p_config.executor.run(Box::pin(async move { Ping::new(service, ping_config).run().await @@ -275,6 +280,7 @@ impl Litep2p { fallback_names, kademlia_config.codec, litep2p_config.keep_alive_timeout, + DialFailureAddresses::Required, ); litep2p_config.executor.run(Box::pin(async move { let _ = Kademlia::new(service, kademlia_config).run().await; @@ -296,6 +302,7 @@ impl Litep2p { Vec::new(), identify_config.codec, litep2p_config.keep_alive_timeout, + DialFailureAddresses::NotRequired, ); identify_config.public = Some(litep2p_config.keypair.public().into()); @@ -316,6 +323,7 @@ impl Litep2p { Vec::new(), bitswap_config.codec, litep2p_config.keep_alive_timeout, + DialFailureAddresses::NotRequired, ); litep2p_config.executor.run(Box::pin(async move { Bitswap::new(service, bitswap_config).run().await diff --git a/src/transport/manager/mod.rs b/src/transport/manager/mod.rs index 013d4a7b..05b6873e 100644 --- a/src/transport/manager/mod.rs +++ b/src/transport/manager/mod.rs @@ -73,6 +73,19 @@ pub(crate) mod handle; /// Logging target for the file. const LOG_TARGET: &str = "litep2p::transport-manager"; +/// Determines if a protocol requires the list of failed addresses upon a dial failure. +/// +/// This is used during protocol registration with the `TransportManager` to specify +/// whether `InnerTransportEvent::DialFailure` events sent to this protocol should +/// include the specific multiaddresses that failed. +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +pub enum DialFailureAddresses { + /// The protocol needs the list of failed addresses. + Required, + /// The protocol does not need the list of failed addresses. + NotRequired, +} + /// The connection established result. #[derive(Debug, Clone, Copy, Eq, PartialEq)] enum ConnectionEstablishedResult { @@ -106,6 +119,9 @@ pub struct ProtocolContext { /// Fallback names for the protocol. pub fallback_names: Vec, + + /// Specifies if the protocol requires dial failure addresses. + pub dial_failure_mode: DialFailureAddresses, } impl ProtocolContext { @@ -114,11 +130,20 @@ impl ProtocolContext { codec: ProtocolCodec, tx: Sender, fallback_names: Vec, + dial_failure_mode: DialFailureAddresses, ) -> Self { Self { tx, codec, fallback_names, + dial_failure_mode, + } + } + + fn dial_failure_addresses(&self, addresses: &[Multiaddr]) -> Vec { + match self.dial_failure_mode { + DialFailureAddresses::Required => addresses.to_vec(), + DialFailureAddresses::NotRequired => Vec::new(), } } } @@ -332,6 +357,7 @@ impl TransportManager { fallback_names: Vec, codec: ProtocolCodec, keep_alive_timeout: Duration, + dial_failure_mode: DialFailureAddresses, ) -> TransportService { assert!(!self.protocol_names.contains(&protocol)); @@ -352,7 +378,7 @@ impl TransportManager { self.protocols.insert( protocol.clone(), - ProtocolContext::new(codec, sender, fallback_names.clone()), + ProtocolContext::new(codec, sender, fallback_names.clone(), dial_failure_mode), ); self.protocol_names.insert(protocol); self.protocol_names.extend(fallback_names); @@ -1116,10 +1142,10 @@ impl TransportManager { ?protocol, "dial failure, notify protocol", ); - match context.tx.try_send(InnerTransportEvent::DialFailure { - peer, - addresses: vec![address.clone()], - }) { + + let addresses = context.dial_failure_addresses(&[address.clone()]); + + match context.tx.try_send(InnerTransportEvent::DialFailure { peer, addresses: addresses.clone() }) { Ok(()) => {} Err(_) => { tracing::trace!( @@ -1132,10 +1158,7 @@ impl TransportManager { ); let _ = context .tx - .send(InnerTransportEvent::DialFailure { - peer, - addresses: vec![address.clone()], - }) + .send(InnerTransportEvent::DialFailure { peer, addresses }) .await; } } @@ -1266,12 +1289,10 @@ impl TransportManager { .collect::>(); for (protocol, context) in &self.protocols { + let addresses = context.dial_failure_addresses(&addresses); let _ = match context .tx - .try_send(InnerTransportEvent::DialFailure { - peer, - addresses: addresses.clone(), - }) { + .try_send(InnerTransportEvent::DialFailure { peer, addresses: addresses.clone() }) { Ok(_) => Ok(()), Err(_) => { tracing::trace!( @@ -1284,10 +1305,7 @@ impl TransportManager { context .tx - .send(InnerTransportEvent::DialFailure { - peer, - addresses: addresses.clone(), - }) + .send(InnerTransportEvent::DialFailure { peer, addresses }) .await } };