diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5b58d59..3842f94 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -33,9 +33,9 @@ jobs: cargo check --all-targets cargo check --all-targets --no-default-features --features tokio cargo check --all-targets --no-default-features --features async-std - cargo test --features js_interop_tests - cargo test --no-default-features --features js_interop_tests,tokio - cargo test --no-default-features --features js_interop_tests,async-std + cargo test --features js_tests + cargo test --no-default-features --features js_tests,tokio + cargo test --no-default-features --features js_tests,async-std cargo test --benches build-extra: diff --git a/Cargo.toml b/Cargo.toml index d77679f..df292c7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,17 +39,18 @@ futures-lite = "1" sha2 = "0.10" curve25519-dalek = "4" crypto_secretstream = "0.2" +futures = "0.3.31" +compact-encoding = "2" [dependencies.hypercore] -version = "0.14.0" -default-features = false - +path = "../core" +#version = "0.14.0" +#default-features = false [dev-dependencies] async-std = { version = "1.12.0", features = ["attributes", "unstable"] } async-compat = "0.2.1" tokio = { version = "1.27.0", features = ["macros", "net", "process", "rt", "rt-multi-thread", "sync", "time"] } -env_logger = "0.7.1" anyhow = "1.0.28" instant = "0.1" criterion = { version = "0.4", features = ["async_std"] } @@ -57,9 +58,9 @@ pretty-bytes = "0.2.2" duplexify = "1.1.0" sluice = "0.5.4" futures = "0.3.13" -log = "0.4" -test-log = { version = "0.2.11", default-features = false, features = ["trace"] } -tracing-subscriber = { version = "0.3.16", features = ["env-filter", "fmt"] } +tracing-subscriber = { version = "0.3.19", features = ["env-filter", "fmt"] } +tracing-tree = "0.4.0" +tokio-util = { version = "0.7.14", features = ["compat"] } [features] default = ["tokio", "sparse"] @@ -72,8 +73,8 @@ tokio = ["hypercore/tokio"] async-std = ["hypercore/async-std"] # Used only in interoperability tests under tests/js-interop which use the javascript version of hypercore # to verify that this crate works. To run them, use: -# cargo test --features js_interop_tests -js_interop_tests = [] +# cargo test --features js_tests +js_tests = [] [profile.bench] # debug = true diff --git a/README.md b/README.md index b8ed180..fada9df 100644 --- a/README.md +++ b/README.md @@ -72,10 +72,10 @@ node examples-nodejs/run.js node ## Development -To test interoperability with Javascript, enable the `js_interop_tests` feature: +To test interoperability with Javascript, enable the `js_tests` feature: ```bash -cargo test --features js_interop_tests +cargo test --features js_tests ``` Run benches with: diff --git a/benches/pipe.rs b/benches/pipe.rs index 630146c..9f87d84 100644 --- a/benches/pipe.rs +++ b/benches/pipe.rs @@ -1,24 +1,26 @@ +#[path = "../src/test_utils.rs"] +mod test_utils; use async_std::task; use criterion::{criterion_group, criterion_main, Criterion, Throughput}; -use futures::io::{AsyncRead, AsyncWrite}; -use futures::stream::StreamExt; -use hypercore_protocol::{schema::*, Duplex}; -use hypercore_protocol::{Channel, Event, Message, Protocol, ProtocolBuilder}; -use log::*; +use futures::{ + io::{AsyncRead, AsyncWrite}, + stream::StreamExt, +}; +use hypercore_protocol::{schema::*, Channel, Duplex, Event, Message, Protocol, ProtocolBuilder}; use pretty_bytes::converter::convert as pretty_bytes; use sluice::pipe::pipe; -use std::io::Result; -use std::time::Instant; +use std::{io::Result, time::Instant}; +use tracing::{debug, error}; const COUNT: u64 = 1000; const SIZE: u64 = 100; const CONNS: u64 = 10; fn bench_throughput(c: &mut Criterion) { - env_logger::from_env(env_logger::Env::default().default_filter_or("error")).init(); + test_utils::log(); let mut group = c.benchmark_group("pipe"); group.sample_size(10); - group.throughput(Throughput::Bytes(SIZE * COUNT * CONNS as u64)); + group.throughput(Throughput::Bytes(SIZE * COUNT * CONNS)); group.bench_function("pipe_echo", |b| { b.iter(|| { task::block_on(async move { @@ -72,7 +74,7 @@ where debug!("[{}] EVENT {:?}", is_initiator, event); match event { Event::Handshake(_) => { - protocol.open(key.clone()).await?; + protocol.open(key).await?; } Event::DiscoveryKey(_dkey) => {} Event::Channel(channel) => { @@ -92,7 +94,7 @@ where } Some(Err(err)) => { error!("ERROR {:?}", err); - return Err(err.into()); + return Err(err); } None => return Ok(0), } @@ -127,20 +129,17 @@ async fn on_channel_init(i: u64, mut channel: Channel) -> Result { let start = std::time::Instant::now(); while let Some(message) = channel.next().await { - match message { - Message::Data(mut data) => { - len += value_len(&data); - debug!("[a] recv {}", index(&data)); - if index(&data) >= COUNT { - debug!("close at {}", index(&data)); - channel.close().await?; - break; - } else { - increment_index(&mut data); - channel.send(Message::Data(data)).await?; - } + if let Message::Data(mut data) = message { + len += value_len(&data); + debug!("[a] recv {}", index(&data)); + if index(&data) >= COUNT { + debug!("close at {}", index(&data)); + channel.close().await?; + break; + } else { + increment_index(&mut data); + channel.send(Message::Data(data)).await?; } - _ => {} } } // let bytes = (COUNT * SIZE) as f64; diff --git a/benches/throughput.rs b/benches/throughput.rs index 76d6874..b19167e 100644 --- a/benches/throughput.rs +++ b/benches/throughput.rs @@ -1,13 +1,18 @@ -use async_std::net::{Shutdown, TcpListener, TcpStream}; -use async_std::task; +#[path = "../src/test_utils.rs"] +mod test_utils; +use async_std::{ + net::{Shutdown, TcpListener, TcpStream}, + task, +}; use criterion::{criterion_group, criterion_main, Criterion, Throughput}; -use futures::future::Either; -use futures::io::{AsyncRead, AsyncWrite}; -use futures::stream::{FuturesUnordered, StreamExt}; -use hypercore_protocol::{schema::*, Duplex}; -use hypercore_protocol::{Channel, Event, Message, ProtocolBuilder}; -use log::*; +use futures::{ + future::Either, + io::{AsyncRead, AsyncWrite}, + stream::{FuturesUnordered, StreamExt}, +}; +use hypercore_protocol::{schema::*, Channel, Event, Message, ProtocolBuilder}; use std::time::Instant; +use tracing::{debug, info, trace}; const PORT: usize = 11011; const SIZE: u64 = 1000; @@ -15,7 +20,7 @@ const COUNT: u64 = 200; const CLIENTS: usize = 1; fn bench_throughput(c: &mut Criterion) { - env_logger::from_env(env_logger::Env::default().default_filter_or("error")).init(); + test_utils::log(); let address = format!("localhost:{}", PORT); let mut group = c.benchmark_group("throughput"); @@ -64,23 +69,22 @@ criterion_main!(server_benches); async fn start_server(address: &str) -> futures::channel::oneshot::Sender<()> { let listener = TcpListener::bind(&address).await.unwrap(); - log::info!("listening on {}", listener.local_addr().unwrap()); + info!("listening on {}", listener.local_addr().unwrap()); let (kill_tx, mut kill_rx) = futures::channel::oneshot::channel(); task::spawn(async move { let mut incoming = listener.incoming(); // let kill_rx = &mut kill_rx; loop { match futures::future::select(incoming.next(), &mut kill_rx).await { - Either::Left((next, _)) => match next { - Some(Ok(stream)) => { + Either::Left((next, _)) => { + if let Some(Ok(stream)) = next { let peer_addr = stream.peer_addr().unwrap(); debug!("new connection from {}", peer_addr); task::spawn(async move { onconnection(stream.clone(), stream, false).await; }); } - _ => {} - }, + } Either::Right((_, _)) => return, } } @@ -88,7 +92,7 @@ async fn start_server(address: &str) -> futures::channel::oneshot::Sender<()> { kill_tx } -async fn onconnection(reader: R, writer: W, is_initiator: bool) -> Duplex +async fn onconnection(reader: R, writer: W, is_initiator: bool) where R: AsyncRead + Send + Unpin + 'static, W: AsyncWrite + Send + Unpin + 'static, @@ -101,19 +105,18 @@ where // eprintln!("RECV EVENT [{}] {:?}", protocol.is_initiator(), event); match event { Event::Handshake(_) => { - protocol.open(key.clone()).await.unwrap(); + protocol.open(key).await.unwrap(); } Event::DiscoveryKey(_) => {} Event::Channel(channel) => { task::spawn(onchannel(channel, is_initiator)); } Event::Close(_dkey) => { - return protocol.release(); + return; } _ => {} } } - protocol.release() } async fn onchannel(mut channel: Channel, is_initiator: bool) { @@ -127,9 +130,8 @@ async fn onchannel(mut channel: Channel, is_initiator: bool) { async fn channel_server(channel: &mut Channel) { while let Some(message) = channel.next().await { - match message { - Message::Data(_) => channel.send(message).await.unwrap(), - _ => {} + if let Message::Data(_) = message { + channel.send(message).await.unwrap() } } } @@ -140,24 +142,21 @@ async fn channel_client(channel: &mut Channel) { let message = msg_data(0, data.clone()); channel.send(message).await.unwrap(); while let Some(message) = channel.next().await { - match message { - Message::Data(ref msg) => { - if index(msg) < COUNT { - let message = msg_data(index(msg) + 1, data.clone()); - channel.send(message).await.unwrap(); - } else { - let time = start.elapsed(); - let bytes = COUNT * SIZE; - trace!( - "client completed. {} blocks, {} bytes, {:?}", - index(msg), - bytes, - time - ); - break; - } + if let Message::Data(ref msg) = message { + if index(msg) < COUNT { + let message = msg_data(index(msg) + 1, data.clone()); + channel.send(message).await.unwrap(); + } else { + let time = start.elapsed(); + let bytes = COUNT * SIZE; + trace!( + "client completed. {} blocks, {} bytes, {:?}", + index(msg), + bytes, + time + ); + break; } - _ => {} } } } diff --git a/examples-nodejs/run.js b/examples-nodejs/run.js index c96541f..ac77bba 100644 --- a/examples-nodejs/run.js +++ b/examples-nodejs/run.js @@ -37,7 +37,8 @@ function startRust (mode, key, color, name) { color: color || 'blue', env: { ...process.env, - RUST_LOG_STYLE: 'always' + RUST_LOG_STYLE: 'always', + RUST_LOG: 'trace' } }) return rust diff --git a/examples/replication.rs b/examples/replication.rs index bf65b72..35e2908 100644 --- a/examples/replication.rs +++ b/examples/replication.rs @@ -1,25 +1,24 @@ +#[path = "../src/test_utils.rs"] +mod test_utils; use anyhow::Result; -use async_std::net::{TcpListener, TcpStream}; -use async_std::prelude::*; -use async_std::sync::{Arc, Mutex}; -use async_std::task; -use env_logger::Env; +use async_std::{ + net::{TcpListener, TcpStream}, + prelude::*, + sync::{Arc, Mutex}, + task, +}; use futures_lite::stream::StreamExt; use hypercore::{ Hypercore, HypercoreBuilder, PartialKeypair, RequestBlock, RequestUpgrade, Storage, VerifyingKey, }; -use log::*; -use std::collections::HashMap; -use std::convert::TryInto; -use std::env; -use std::fmt::Debug; +use std::{collections::HashMap, convert::TryInto, env, fmt::Debug}; +use tracing::{error, info, instrument}; -use hypercore_protocol::schema::*; -use hypercore_protocol::{discovery_key, Channel, Event, Message, ProtocolBuilder}; +use hypercore_protocol::{discovery_key, schema::*, Channel, Event, Message, ProtocolBuilder}; fn main() { - init_logger(); + test_utils::log(); if env::args().count() < 3 { usage(); } @@ -65,12 +64,11 @@ fn main() { hypercore_store.add(hypercore_wrapper); let hypercore_store = Arc::new(hypercore_store); - let result = match mode.as_ref() { + let _ = match mode.as_ref() { "server" => tcp_server(address, onconnection, hypercore_store).await, "client" => tcp_client(address, onconnection, hypercore_store).await, _ => panic!("{:?}", usage()), }; - log_if_error(&result); }); } @@ -84,6 +82,7 @@ fn usage() { // or once when connected (if client). // Unfortunately, everything that touches the hypercore_store or a hypercore has to be generic // at the moment. +#[instrument(skip_all, ret)] async fn onconnection( stream: TcpStream, is_initiator: bool, @@ -93,8 +92,8 @@ async fn onconnection( let mut protocol = ProtocolBuilder::new(is_initiator).connect(stream); info!("protocol created, polling for next()"); while let Some(event) = protocol.next().await { - let event = event?; info!("protocol event {:?}", event); + let event = event?; match event { Event::Handshake(_) => { if is_initiator { @@ -126,17 +125,17 @@ struct HypercoreStore { hypercores: HashMap>, } impl HypercoreStore { - pub fn new() -> Self { + fn new() -> Self { let hypercores = HashMap::new(); Self { hypercores } } - pub fn add(&mut self, hypercore: HypercoreWrapper) { + fn add(&mut self, hypercore: HypercoreWrapper) { let hdkey = hex::encode(hypercore.discovery_key); self.hypercores.insert(hdkey, Arc::new(hypercore)); } - pub fn get(&self, discovery_key: &[u8; 32]) -> Option<&Arc> { + fn get(&self, discovery_key: &[u8; 32]) -> Option<&Arc> { let hdkey = hex::encode(discovery_key); self.hypercores.get(&hdkey) } @@ -151,7 +150,7 @@ struct HypercoreWrapper { } impl HypercoreWrapper { - pub fn from_memory_hypercore(hypercore: Hypercore) -> Self { + fn from_memory_hypercore(hypercore: Hypercore) -> Self { let key = hypercore.key_pair().public.to_bytes(); HypercoreWrapper { key, @@ -160,11 +159,11 @@ impl HypercoreWrapper { } } - pub fn key(&self) -> &[u8; 32] { + fn key(&self) -> &[u8; 32] { &self.key } - pub fn onpeer(&self, mut channel: Channel) { + fn onpeer(&self, mut channel: Channel) { let mut peer_state = PeerState::default(); let mut hypercore = self.hypercore.clone(); task::spawn(async move { @@ -299,6 +298,8 @@ async fn onmessage( start: info.length, length: peer_state.remote_length - info.length, }), + manifest: false, + priority: 0, }; messages.push(Message::Request(msg)); } @@ -405,6 +406,8 @@ async fn onmessage( block: Some(request_block), seek: None, upgrade: None, + manifest: false, + priority: 0, })); } channel.send_batch(&messages).await.unwrap(); @@ -414,20 +417,9 @@ async fn onmessage( Ok(()) } -/// Init EnvLogger, logging info, warn and error messages to stdout. -pub fn init_logger() { - env_logger::from_env(Env::default().default_filter_or("info")).init(); -} - -/// Log a result if it's an error. -pub fn log_if_error(result: &Result<()>) { - if let Err(err) = result.as_ref() { - log::error!("error: {}", err); - } -} - /// A simple async TCP server that calls an async function for each incoming connection. -pub async fn tcp_server( +#[instrument(skip_all, ret)] +async fn tcp_server( address: String, onconnection: impl Fn(TcpStream, bool, C) -> F + Send + Sync + Copy + 'static, context: C, @@ -437,22 +429,22 @@ where C: Clone + Send + 'static, { let listener = TcpListener::bind(&address).await?; - log::info!("listening on {}", listener.local_addr()?); + tracing::info!("listening on {}", listener.local_addr()?); let mut incoming = listener.incoming(); while let Some(Ok(stream)) = incoming.next().await { let context = context.clone(); let peer_addr = stream.peer_addr().unwrap(); - log::info!("new connection from {}", peer_addr); + tracing::info!("new connection from {}", peer_addr); task::spawn(async move { - let result = onconnection(stream, false, context).await; - log_if_error(&result); - log::info!("connection closed from {}", peer_addr); + let _ = onconnection(stream, false, context).await; + tracing::info!("connection closed from {}", peer_addr); }); } Ok(()) } /// A simple async TCP client that calls an async function when connected. +#[instrument(skip_all, ret)] pub async fn tcp_client( address: String, onconnection: impl Fn(TcpStream, bool, C) -> F + Send + Sync + Copy + 'static, @@ -462,8 +454,8 @@ where F: Future> + Send, C: Clone + Send + 'static, { - log::info!("attempting connection to {address}"); + tracing::info!("attempting connection to {address}"); let stream = TcpStream::connect(&address).await?; - log::info!("connected to {address}"); + tracing::info!("connected to {address}"); onconnection(stream, true, context).await } diff --git a/src/builder.rs b/src/builder.rs index d797654..0b9127e 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -1,5 +1,4 @@ -use crate::Protocol; -use crate::{duplex::Duplex, protocol::Options}; +use crate::{duplex::Duplex, protocol::Options, Protocol}; use futures_lite::io::{AsyncRead, AsyncWrite}; /// Build a Protocol instance with options. diff --git a/src/channels.rs b/src/channels.rs index c2e22f8..f16ac7f 100644 --- a/src/channels.rs +++ b/src/channels.rs @@ -1,19 +1,24 @@ -use crate::message::ChannelMessage; -use crate::schema::*; -use crate::util::{map_channel_err, pretty_hash}; -use crate::Message; -use crate::{discovery_key, DiscoveryKey, Key}; +use crate::{ + discovery_key, + message::ChannelMessage, + schema::*, + util::{map_channel_err, pretty_hash}, + DiscoveryKey, Key, Message, +}; use async_channel::{Receiver, Sender, TrySendError}; -use futures_lite::ready; -use futures_lite::stream::Stream; -use std::collections::HashMap; -use std::fmt; -use std::io::{Error, ErrorKind, Result}; -use std::pin::Pin; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; -use std::task::Poll; -use tracing::debug; +use futures_lite::{ready, stream::Stream}; +use std::{ + collections::HashMap, + fmt, + io::{Error, ErrorKind, Result}, + pin::Pin, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + task::Poll, +}; +use tracing::instrument; /// A protocol channel. /// @@ -93,7 +98,6 @@ impl Channel { "Channel is closed", )); } - debug!("TX:\n{message:?}\n"); let message = ChannelMessage::new(self.local_id as u64, message); self.outbound_tx .send(vec![message]) @@ -122,10 +126,7 @@ impl Channel { let messages = messages .iter() - .map(|message| { - debug!("TX:\n{message:?}\n"); - ChannelMessage::new(self.local_id as u64, message.clone()) - }) + .map(|message| ChannelMessage::new(self.local_id as u64, message.clone())) .collect(); self.outbound_tx .send(messages) @@ -249,6 +250,7 @@ impl ChannelHandle { self.remote_state.as_ref().map(|s| s.remote_id) } + #[instrument(skip_all, fields(local_id = local_id))] pub(crate) fn attach_local(&mut self, local_id: usize, key: Key) { let local_state = LocalState { local_id, key }; self.local_state = Some(local_state); @@ -276,6 +278,7 @@ impl ChannelHandle { Ok((&local_state.key, remote_state.remote_capability.as_ref())) } + #[instrument(skip_all)] pub(crate) fn open(&mut self, outbound_tx: Sender>) -> Channel { let local_state = self .local_state @@ -433,6 +436,7 @@ impl ChannelMap { self.channels.remove(&hdkey); } + #[instrument(skip(self))] pub(crate) fn prepare_to_verify(&self, local_id: usize) -> Result<(&Key, Option<&Vec>)> { let channel_handle = self .get_local(local_id) @@ -477,6 +481,7 @@ impl ChannelMap { Ok(()) } + #[instrument(skip_all)] fn alloc_local(&mut self) -> usize { let empty_id = self .local_id diff --git a/src/constants.rs b/src/constants.rs index 77285ee..1efbbed 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -1,15 +1,8 @@ /// Seed for the discovery key hash pub(crate) const DISCOVERY_NS_BUF: &[u8] = b"hypercore"; -/// Default timeout (in seconds) -pub(crate) const DEFAULT_TIMEOUT: u32 = 20; - /// Default keepalive interval (in seconds) pub(crate) const DEFAULT_KEEPALIVE: u32 = 10; -// 16,78MB is the max encrypted wire message size (will be much smaller usually). -// This limitation stems from the 24bit header. -pub(crate) const MAX_MESSAGE_SIZE: u64 = 0xFFFFFF; - /// v10: Protocol name pub(crate) const PROTOCOL_NAME: &str = "hypercore/alpha"; diff --git a/src/crypto/cipher.rs b/src/crypto/cipher.rs index c0e54a9..20cb734 100644 --- a/src/crypto/cipher.rs +++ b/src/crypto/cipher.rs @@ -1,38 +1,25 @@ use super::HandshakeResult; -use crate::util::{stat_uint24_le, write_uint24_le, UINT_24_LENGTH}; use blake2::{ digest::{typenum::U32, FixedOutput, Update}, Blake2bMac, }; use crypto_secretstream::{Header, Key, PullStream, PushStream, Tag}; use rand::rngs::OsRng; -use std::convert::TryInto; -use std::io; +use std::{convert::TryInto, io}; const STREAM_ID_LENGTH: usize = 32; const KEY_LENGTH: usize = 32; -const HEADER_MSG_LEN: usize = UINT_24_LENGTH + STREAM_ID_LENGTH + Header::BYTES; pub(crate) struct DecryptCipher { pull_stream: PullStream, } -pub(crate) struct EncryptCipher { - push_stream: PushStream, -} - impl std::fmt::Debug for DecryptCipher { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "DecryptCipher(crypto_secretstream)") } } -impl std::fmt::Debug for EncryptCipher { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "EncryptCipher(crypto_secretstream)") - } -} - impl DecryptCipher { pub(crate) fn from_handshake_rx_and_init_msg( handshake_result: &HandshakeResult, @@ -75,24 +62,6 @@ impl DecryptCipher { let pull_stream = PullStream::init(Header::from(header), &key); Ok(Self { pull_stream }) } - - pub(crate) fn decrypt( - &mut self, - buf: &mut [u8], - header_len: usize, - body_len: usize, - ) -> io::Result { - let (to_decrypt, _tag) = self.decrypt_buf(&buf[header_len..header_len + body_len])?; - let decrypted_len = to_decrypt.len(); - write_uint24_le(decrypted_len, buf); - let decrypted_end = 3 + to_decrypt.len(); - buf[3..decrypted_end].copy_from_slice(to_decrypt.as_slice()); - // Set extra bytes in the buffer to 0 - let encrypted_end = header_len + body_len; - buf[decrypted_end..encrypted_end].fill(0x00); - Ok(decrypted_end) - } - pub(crate) fn decrypt_buf(&mut self, buf: &[u8]) -> io::Result<(Vec, Tag)> { let mut to_decrypt = buf.to_vec(); let tag = &self.pull_stream.pull(&mut to_decrypt, &[]).map_err(|err| { @@ -102,62 +71,6 @@ impl DecryptCipher { } } -impl EncryptCipher { - pub(crate) fn from_handshake_tx( - handshake_result: &HandshakeResult, - ) -> std::io::Result<(Self, Vec)> { - let key: [u8; KEY_LENGTH] = handshake_result.split_tx[..KEY_LENGTH] - .try_into() - .expect("split_tx with incorrect length"); - let key = Key::from(key); - - let mut header_message: [u8; HEADER_MSG_LEN] = [0; HEADER_MSG_LEN]; - write_uint24_le(STREAM_ID_LENGTH + Header::BYTES, &mut header_message); - write_stream_id( - &handshake_result.handshake_hash, - handshake_result.is_initiator, - &mut header_message[UINT_24_LENGTH..UINT_24_LENGTH + STREAM_ID_LENGTH], - ); - - let (header, push_stream) = PushStream::init(OsRng, &key); - let header = header.as_ref(); - header_message[UINT_24_LENGTH + STREAM_ID_LENGTH..].copy_from_slice(header); - let msg = header_message.to_vec(); - Ok((Self { push_stream }, msg)) - } - - /// Get the length needed for encryption, that includes padding. - pub(crate) fn safe_encrypted_len(&self, plaintext_len: usize) -> usize { - // ChaCha20-Poly1305 uses padding in two places, use two 15 bytes as a safe - // extra room. - // https://mailarchive.ietf.org/arch/msg/cfrg/u734TEOSDDWyQgE0pmhxjdncwvw/ - plaintext_len + 2 * 15 - } - - /// Encrypts message in the given buffer to the same buffer, returns number of bytes - /// of total message. - pub(crate) fn encrypt(&mut self, buf: &mut [u8]) -> io::Result { - let stat = stat_uint24_le(buf); - if let Some((header_len, body_len)) = stat { - let mut to_encrypt = buf[header_len..header_len + body_len as usize].to_vec(); - self.push_stream - .push(&mut to_encrypt, &[], Tag::Message) - .map_err(|err| { - io::Error::new(io::ErrorKind::Other, format!("Encrypt failed: {err}")) - })?; - let encrypted_len = to_encrypt.len(); - write_uint24_le(encrypted_len, buf); - buf[header_len..header_len + encrypted_len].copy_from_slice(to_encrypt.as_slice()); - Ok(3 + encrypted_len) - } else { - Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("Could not encrypt invalid data, len: {}", buf.len()), - )) - } - } -} - // NB: These values come from Javascript-side // // const [NS_INITIATOR, NS_RESPONDER] = crypto.namespace('hyperswarm/secret-stream', 2) @@ -184,3 +97,53 @@ fn write_stream_id(handshake_hash: &[u8], is_initiator: bool, out: &mut [u8]) { let result = result.as_slice(); out.copy_from_slice(result); } + +//NB "raw" here means UN-framed. No frame header. +const RAW_HEADER_MSG_LEN: usize = STREAM_ID_LENGTH + Header::BYTES; + +pub(crate) struct EncryptCipher { + push_stream: PushStream, +} + +impl std::fmt::Debug for EncryptCipher { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "RawEncryptCipher(crypto_secretstream)") + } +} + +impl EncryptCipher { + pub(crate) fn from_handshake_tx( + handshake_result: &HandshakeResult, + ) -> std::io::Result<(Self, Vec)> { + let key: [u8; KEY_LENGTH] = handshake_result.split_tx[..KEY_LENGTH] + .try_into() + .expect("split_tx with incorrect length"); + let key = Key::from(key); + + let mut header_message: [u8; RAW_HEADER_MSG_LEN] = [0; RAW_HEADER_MSG_LEN]; + + write_stream_id( + &handshake_result.handshake_hash, + handshake_result.is_initiator, + &mut header_message[..STREAM_ID_LENGTH], + ); + + let (header, push_stream) = PushStream::init(OsRng, &key); + let header = header.as_ref(); + header_message[STREAM_ID_LENGTH..].copy_from_slice(header); + let msg = header_message.to_vec(); + Ok((Self { push_stream }, msg)) + } + + // TODO make this work in-place + /// Encrypts `msg` and returns the encrypted bytes + pub(crate) fn encrypt(&mut self, msg: &[u8]) -> io::Result> { + let mut out = msg.to_vec(); + self.push_stream + .push(&mut out, &[], Tag::Message) + .map_err(|err| { + io::Error::new(io::ErrorKind::Other, format!("Encrypt failed: {err}")) + })?; + Ok(out) + } +} diff --git a/src/crypto/handshake.rs b/src/crypto/handshake.rs index 64db407..53f3889 100644 --- a/src/crypto/handshake.rs +++ b/src/crypto/handshake.rs @@ -1,12 +1,14 @@ use super::curve::CurveResolver; -use crate::util::wrap_uint24_le; use blake2::{ digest::{typenum::U32, FixedOutput, Update}, Blake2bMac, }; -use snow::resolvers::{DefaultResolver, FallbackResolver}; -use snow::{Builder, Error as SnowError, HandshakeState}; +use snow::{ + resolvers::{DefaultResolver, FallbackResolver}, + Builder, Error as SnowError, HandshakeState, +}; use std::io::{Error, ErrorKind, Result}; +use tracing::instrument; const CIPHERKEYLEN: usize = 32; const HANDSHAKE_PATTERN: &str = "Noise_XX_Ed25519_ChaChaPoly_BLAKE2b"; @@ -23,7 +25,7 @@ const REPLICATE_RESPONDER: [u8; 32] = [ ]; #[derive(Debug, Clone, Default)] -pub(crate) struct HandshakeResult { +pub struct HandshakeResult { pub(crate) is_initiator: bool, pub(crate) local_pubkey: Vec, pub(crate) remote_pubkey: Vec, @@ -33,6 +35,7 @@ pub(crate) struct HandshakeResult { } impl HandshakeResult { + #[instrument(skip_all)] pub(crate) fn capability(&self, key: &[u8]) -> Option> { Some(replicate_capability( self.is_initiator, @@ -49,6 +52,7 @@ impl HandshakeResult { )) } + #[instrument(skip_all)] pub(crate) fn verify_remote_capability( &self, capability: Option>, @@ -69,6 +73,7 @@ impl HandshakeResult { } } +#[derive(Debug)] pub(crate) struct Handshake { result: HandshakeResult, state: HandshakeState, @@ -80,6 +85,7 @@ pub(crate) struct Handshake { } impl Handshake { + #[instrument] pub(crate) fn new(is_initiator: bool) -> Result { let (state, local_pubkey) = build_handshake_state(is_initiator).map_err(map_err)?; @@ -100,11 +106,10 @@ impl Handshake { }) } - pub(crate) fn start(&mut self) -> Result>> { + pub(crate) fn start_raw(&mut self) -> Result>> { if self.is_initiator() { let tx_len = self.send()?; - let wrapped = wrap_uint24_le(&self.tx_buf[..tx_len].to_vec()); - Ok(Some(wrapped)) + Ok(Some(self.tx_buf[..tx_len].to_vec())) } else { Ok(None) } @@ -123,13 +128,14 @@ impl Handshake { .read_message(msg, &mut self.rx_buf) .map_err(map_err) } - fn send(&mut self) -> Result { + pub(crate) fn send(&mut self) -> Result { self.state .write_message(&self.payload, &mut self.tx_buf) .map_err(map_err) } - pub(crate) fn read(&mut self, msg: &[u8]) -> Result>> { + #[instrument(skip_all, fields(is_initiator = %self.result.is_initiator))] + pub(crate) fn read_raw(&mut self, msg: &[u8]) -> Result>> { // eprintln!("hs read len {}", msg.len()); if self.complete() { return Err(Error::new(ErrorKind::Other, "Handshake read after finish")); @@ -137,16 +143,17 @@ impl Handshake { let _rx_len = self.recv(msg)?; + // first non-init if !self.is_initiator() && !self.did_receive { self.did_receive = true; let tx_len = self.send()?; - let wrapped = wrap_uint24_le(&self.tx_buf[..tx_len].to_vec()); + let wrapped = self.tx_buf[..tx_len].to_vec(); return Ok(Some(wrapped)); } let tx_buf = if self.is_initiator() { let tx_len = self.send()?; - let wrapped = wrap_uint24_le(&self.tx_buf[..tx_len].to_vec()); + let wrapped = self.tx_buf[..tx_len].to_vec(); Some(wrapped) } else { None @@ -170,11 +177,11 @@ impl Handshake { Ok(tx_buf) } - pub(crate) fn into_result(self) -> Result { + pub(crate) fn get_result(&self) -> Result<&HandshakeResult> { if !self.complete() { Err(Error::new(ErrorKind::Other, "Handshake is not complete")) } else { - Ok(self.result) + Ok(&self.result) } } } @@ -223,7 +230,7 @@ fn map_err(e: SnowError) -> Error { } /// Create a hash used to indicate replication capability. -/// See https://github.com/hypercore-protocol/hypercore/blob/70b271643c4e4b1e5ecae5bb579966dfe6361ff3/lib/caps.js#L11 +/// See JavaScript [here](https://github.com/hypercore-protocol/hypercore/blob/70b271643c4e4b1e5ecae5bb579966dfe6361ff3/lib/caps.js#L11). fn replicate_capability(is_initiator: bool, key: &[u8], handshake_hash: &[u8]) -> Vec { let seed = if is_initiator { REPLICATE_INITIATOR diff --git a/src/duplex.rs b/src/duplex.rs index fe79c1b..7b0f1e5 100644 --- a/src/duplex.rs +++ b/src/duplex.rs @@ -1,7 +1,9 @@ use futures_lite::{AsyncRead, AsyncWrite}; -use std::io; -use std::pin::Pin; -use std::task::{Context, Poll}; +use std::{ + io, + pin::Pin, + task::{Context, Poll}, +}; #[derive(Clone, Debug, PartialEq)] /// Duplex IO stream from reader and writer halves. diff --git a/src/framing.rs b/src/framing.rs new file mode 100644 index 0000000..760b7c6 --- /dev/null +++ b/src/framing.rs @@ -0,0 +1,382 @@ +//! Wrap bytes in length prefixed framing. +use crate::util::{stat_uint24_le, wrap_uint24_le}; +use futures::{Sink, Stream}; +use futures_lite::io::{AsyncRead, AsyncWrite}; +use std::{ + collections::VecDeque, + fmt::Debug, + io::Result, + pin::Pin, + task::{Context, Poll}, +}; +use tracing::{error, info, instrument, trace, warn}; + +const BUF_SIZE: usize = 1024 * 64; +const _HEADER_LEN: usize = 3; + +/// take a `AsyncWrite` of length prefixed messages and emit them as a Stream +pub struct Uint24LELengthPrefixedFraming { + io: IO, + /// Data from [`Self::io`]'s [`AsyncRead`] interface to be sent out via the [`Stream`] interface. + to_stream: Vec, + /// Data from the `Sink` interface to be written out to [`Self::io`]'s [`AsyncWrite`] interface. + from_sink: VecDeque>, + /// The index in [`Self::to_stream`] of the last byte that was sent to the [`Stream`]. + last_out_idx: usize, + /// The index in [`Self::to_stream`] of the last byte that was read from [`Self::io`]'s + /// [`AsyncRead`] + last_data_idx: usize, + /// Current step of a message being parsed + step: Step, +} +impl Debug for Uint24LELengthPrefixedFraming { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Framer") + //.field("io", &self.io) + .field("to_stream.len()", &self.to_stream.len()) + .field("from_sink", &self.from_sink.len()) + .field("last_out_idx", &self.last_out_idx) + .field("last_data_idx", &self.last_data_idx) + .field("step", &self.step) + .finish() + } +} +impl Uint24LELengthPrefixedFraming +where + IO: AsyncWrite + AsyncRead + Send + Unpin + 'static, +{ + /// Build [`Uint24LELengthPrefixedFraming`] around an [`AsyncWrite`]/[`AsyncRead`] thing. + pub fn new(io: IO) -> Self { + Self { + io, + to_stream: vec![0u8; BUF_SIZE], + from_sink: VecDeque::new(), + last_out_idx: 0, + last_data_idx: 0, + step: Step::Header, + } + } +} + +#[derive(Debug)] +enum Step { + Header, + Body { start: usize, end: u64 }, +} + +impl Stream for Uint24LELengthPrefixedFraming +where + IO: AsyncWrite + AsyncRead + Send + Unpin + 'static, +{ + type Item = Result>; + + #[instrument(skip_all)] + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let Self { + io, + to_stream, + last_out_idx, + last_data_idx, + step, + .. + } = self.get_mut(); + trace!( + "Try to AsyncRead up to (buff_size[{}] - last_data_idx[{}]) = [{}]", + to_stream.len(), + *last_data_idx, + to_stream.len() - *last_data_idx + ); + let n_bytes_read = match Pin::new(io).poll_read(cx, &mut to_stream[*last_data_idx..]) { + Poll::Ready(Ok(n)) => n, + Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))), + Poll::Pending => 0, + }; + // TODO handle if to_stream is full + trace!("adding #=[{n_bytes_read}] bytes to end=[{}]", last_data_idx); + *last_data_idx += n_bytes_read; + // grow buffer if it's full + if *last_data_idx == to_stream.len() - 1 { + warn!("Buffer full, double it's size"); + to_stream.extend(vec![0; to_stream.len()]); + } + + if let Step::Header = step { + trace!(step = ?*step, "enter"); + let cur_data = &to_stream[*last_out_idx..*last_data_idx]; + + let Some((header_len, body_len)) = stat_uint24_le(cur_data) else { + trace!("not enough bytes to read header"); + return Poll::Pending; + }; + + let cur_frame_start = *last_out_idx + header_len; + let cur_frame_end = (cur_frame_start as u64) + body_len; + *step = Step::Body { + start: cur_frame_start, + end: cur_frame_end, + }; + } + + info!(step = ?*step, "enter"); + if let Step::Body { start, end } = step { + let end = *end as usize; + if end <= *last_data_idx { + trace!(frame_size = end - *start, "Frame ready"); + let out = to_stream[*start..end].to_vec(); + *step = Step::Header; + + // remove bytes we're done with + to_stream.rotate_left(end); + *last_data_idx -= end; + *last_out_idx = 0; + return Poll::Ready(Some(Ok(out))); + } + } + Poll::Pending + } +} + +impl Sink> for Uint24LELengthPrefixedFraming +where + IO: AsyncWrite + AsyncRead + Send + Unpin + 'static, +{ + type Error = std::io::Error; + + fn poll_ready( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + #[instrument(skip_all)] + fn start_send(mut self: Pin<&mut Self>, item: Vec) -> std::result::Result<(), Self::Error> { + self.from_sink.push_back(wrap_uint24_le(&item)); + Ok(()) + } + + #[instrument(skip_all)] + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let Self { from_sink, io, .. } = self.get_mut(); + loop { + if let Some(msg) = from_sink.pop_front() { + match Pin::new(&mut *io).poll_write(cx, &msg) { + Poll::Pending => { + from_sink.push_front(msg); + return Poll::Pending; + } + Poll::Ready(Ok(n)) => { + if n != msg.len() { + from_sink.push_front(msg[n..].to_vec()); + warn!("only wrote [{n} / {}] bytes of message", msg.len()); + } + trace!("flushed whole message of N=[{n}] bytes"); + } + Poll::Ready(Err(e)) => { + error!("Error flushing data"); + return Poll::Ready(Err(e)); + } + } + } else { + trace!("No more messages to flush"); + return Poll::Ready(Ok(())); + } + } + } + + fn poll_close( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let Self { io, .. } = self.get_mut(); + Pin::new(&mut *io).poll_close(cx) + } +} +#[cfg(test)] +pub(crate) mod test { + use crate::{test_utils::log, Duplex}; + + use super::*; + use futures::{SinkExt, StreamExt}; + use futures_lite::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + use tokio::spawn; + use tokio_util::compat::TokioAsyncReadCompatExt; + + pub(crate) fn duplex( + channel_size: usize, + ) -> (impl AsyncRead + AsyncWrite, impl AsyncRead + AsyncWrite) { + let (left, right) = tokio::io::duplex(channel_size); + (left.compat(), right.compat()) + } + + #[tokio::test] + async fn duplex_works() -> Result<()> { + let (mut left, mut right) = duplex(64); + left.write_all(b"hello").await?; + let mut b = vec![0; 5]; + right.read_exact(&mut b).await?; + assert_eq!(b, b"hello"); + Ok(()) + } + + #[tokio::test] + async fn input() -> Result<()> { + log(); + let (left, mut right) = duplex(64); + let mut lp = Uint24LELengthPrefixedFraming::new(left); + let input = b"yelp"; + let msg = wrap_uint24_le(input); + right.write_all(&msg).await?; + let Some(Ok(rx)) = lp.next().await else { + panic!() + }; + assert_eq!(rx, input); + Ok(()) + } + #[tokio::test] + async fn stream_many() -> Result<()> { + log(); + let (left, mut right) = duplex(64); + let mut lp = Uint24LELengthPrefixedFraming::new(left); + let data: &[&[u8]] = &[b"yolo", b"squalor", b"idle", b"hello", b"stuff"]; + for d in data { + let msg = wrap_uint24_le(d); + right.write_all(&msg).await?; + } + for d in data { + let Some(Ok(res)) = lp.next().await else { + panic!(); + }; + assert_eq!(&res, d); + } + Ok(()) + } + #[tokio::test] + async fn sink_many() -> Result<()> { + log(); + let (left, mut right) = duplex(64); + let mut lp = Uint24LELengthPrefixedFraming::new(left); + let data: &[&[u8]] = &[b"yolo", b"squalor", b"idle", b"hello", b"stuff"]; + for d in data { + lp.send(d.to_vec()).await.unwrap(); + } + + let mut expected = vec![]; + data.iter().for_each(|d| expected.extend(wrap_uint24_le(d))); + let mut result = vec![0; expected.len()]; + right.read_exact(&mut result).await?; + assert_eq!(result, expected); + Ok(()) + } + + #[tokio::test] + async fn left_and_right() -> Result<()> { + let (left, right) = duplex(64); + + let mut leftlp = Uint24LELengthPrefixedFraming::new(left); + let mut rightlp = Uint24LELengthPrefixedFraming::new(right); + + let data: &[&[u8]] = &[b"yolo", b"squalor", b"idle", b"hello", b"stuff"]; + for d in data { + rightlp.send(d.to_vec()).await.unwrap(); + } + + let mut result1 = vec![]; + for _ in data { + result1.push(leftlp.next().await.unwrap().unwrap()); + } + assert_eq!(result1, data); + + for d in data { + leftlp.send(d.to_vec()).await.unwrap(); + } + let mut result2 = vec![]; + for _ in data { + result2.push(rightlp.next().await.unwrap().unwrap()); + } + assert_eq!(result2, data); + + let mut r3 = vec![]; + let mut r4 = vec![]; + for d in data { + rightlp.send(d.to_vec()).await.unwrap(); + leftlp.send(d.to_vec()).await.unwrap(); + } + + for _ in data { + r3.push(rightlp.next().await.unwrap().unwrap()); + r4.push(leftlp.next().await.unwrap().unwrap()); + } + assert_eq!(r3, data); + assert_eq!(r4, data); + + Ok(()) + } + #[tokio::test] + async fn left_and_right_sluice() -> Result<()> { + let (ar, bw) = sluice::pipe::pipe(); + let (br, aw) = sluice::pipe::pipe(); + let left = Duplex::new(ar, aw); + let right = Duplex::new(br, bw); + + let mut leftlp = Uint24LELengthPrefixedFraming::new(left); + let mut rightlp = Uint24LELengthPrefixedFraming::new(right); + + // NB sluice has a max "chunk" thing of 4 + // so we limit the data we're sending to 3 things + let data: &[&[u8]] = &[b"yolo", b"squalor", b"idle"]; + // NB this sluice pipe + // + for d in data { + rightlp.feed(d.to_vec()).await?; + } + let rflush = spawn(async move { + rightlp.flush().await.unwrap(); + rightlp + }); + + let mut result1 = vec![]; + for _ in data { + result1.push(leftlp.next().await.unwrap()?); + } + let mut rightlp = rflush.await?; + + assert_eq!(result1, data); + + for d in data { + leftlp.feed(d.to_vec()).await?; + } + let lflush = spawn(async move { + leftlp.flush().await.unwrap(); + leftlp + }); + + let mut result2 = vec![]; + for _ in data { + result2.push(rightlp.next().await.unwrap()?); + } + let mut leftlp = lflush.await?; + assert_eq!(result2, data); + + let mut r3 = vec![]; + let mut r4 = vec![]; + + for d in data { + rightlp.send(d.to_vec()).await?; + leftlp.send(d.to_vec()).await?; + } + + for _ in data { + r3.push(rightlp.next().await.unwrap()?); + r4.push(leftlp.next().await.unwrap()?); + } + + assert_eq!(r3, data); + assert_eq!(r4, data); + + Ok(()) + } +} diff --git a/src/lib.rs b/src/lib.rs index 531a068..7857bd1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -50,17 +50,14 @@ //! //! ```no_run //! # async_std::task::block_on(async { -//! use hypercore_protocol::{ProtocolBuilder, Event, Message}; -//! use hypercore_protocol::schema::*; //! use async_std::prelude::*; +//! use hypercore_protocol::{schema::*, Event, Message, ProtocolBuilder}; //! // Start a tcp server. //! let listener = async_std::net::TcpListener::bind("localhost:8000").await.unwrap(); //! async_std::task::spawn(async move { //! let mut incoming = listener.incoming(); //! while let Some(Ok(stream)) = incoming.next().await { -//! async_std::task::spawn(async move { -//! onconnection(stream, false).await -//! }); +//! async_std::task::spawn(async move { onconnection(stream, false).await }); //! } //! }); //! @@ -69,7 +66,7 @@ //! onconnection(stream, true).await; //! //! /// Start Hypercore protocol on a TcpStream. -//! async fn onconnection (stream: async_std::net::TcpStream, is_initiator: bool) { +//! async fn onconnection(stream: async_std::net::TcpStream, is_initiator: bool) { //! // A peer either is the initiator or a connection or is being connected to. //! let name = if is_initiator { "dialer" } else { "listener" }; //! // A key for the channel we want to open. Usually, this is a pre-shared key that both peers @@ -86,7 +83,7 @@ //! // The handshake event is emitted after the protocol is fully established. //! Event::Handshake(_remote_key) => { //! protocol.open(key.clone()).await; -//! }, +//! } //! // A Channel event is emitted for each established channel. //! Event::Channel(mut channel) => { //! // A Channel can be sent to other tasks. @@ -97,7 +94,7 @@ //! eprintln!("{} received message: {:?}", name, message); //! } //! }); -//! }, +//! } //! _ => {} //! } //! } @@ -122,17 +119,22 @@ mod channels; mod constants; mod crypto; mod duplex; +mod framing; mod message; +mod mqueue; +mod noise; mod protocol; -mod reader; +#[cfg(test)] +mod test_utils; mod util; -mod writer; /// The wire messages used by the protocol. pub mod schema; pub use builder::Builder as ProtocolBuilder; pub use channels::Channel; +pub use framing::Uint24LELengthPrefixedFraming; +pub use noise::{encrypted_framed_message_channel, Encrypted, Event as NoiseEvent}; // Export the needed types for Channel::take_receiver, and Channel::local_sender() pub use async_channel::{ Receiver as ChannelReceiver, SendError as ChannelSendError, Sender as ChannelSender, diff --git a/src/message.rs b/src/message.rs index 27b74c1..7665df4 100644 --- a/src/message.rs +++ b/src/message.rs @@ -1,348 +1,117 @@ use crate::schema::*; -use crate::util::{stat_uint24_le, write_uint24_le}; -use hypercore::encoding::{ - CompactEncoding, EncodingError, EncodingErrorKind, HypercoreState, State, +use compact_encoding::{ + decode_usize, take_array, write_array, CompactEncoding, EncodingError, EncodingErrorKind, + VecEncodable, }; use pretty_hash::fmt as pretty_fmt; -use std::fmt; -use std::io; - -/// The type of a data frame. -#[derive(Debug, Clone, PartialEq)] -pub(crate) enum FrameType { - Raw, - Message, -} - -/// Encode data into a buffer. -/// -/// This trait is implemented on data frames and their components -/// (channel messages, messages, and individual message types through prost). -pub(crate) trait Encoder: Sized + fmt::Debug { - /// Calculates the length that the encoded message needs. - fn encoded_len(&mut self) -> Result; - - /// Encodes the message to a buffer. - /// - /// An error will be returned if the buffer does not have sufficient capacity. - fn encode(&mut self, buf: &mut [u8]) -> Result; -} - -impl Encoder for &[u8] { - fn encoded_len(&mut self) -> Result { - Ok(self.len()) - } - - fn encode(&mut self, buf: &mut [u8]) -> Result { - let len = self.encoded_len()?; - if len > buf.len() { - return Err(EncodingError::new( - EncodingErrorKind::Overflow, - &format!("Length does not fit buffer, {} > {}", len, buf.len()), - )); - } - buf[..len].copy_from_slice(&self[..]); - Ok(len) - } -} - -/// A frame of data, either a buffer or a message. -#[derive(Clone, PartialEq)] -pub(crate) enum Frame { - /// A raw batch binary buffer. Used in the handshaking phase. - RawBatch(Vec>), - /// Message batch, containing one or more channel messsages. Used for everything after the handshake. - MessageBatch(Vec), -} - -impl fmt::Debug for Frame { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Frame::RawBatch(batch) => write!(f, "Frame(RawBatch <{}>)", batch.len()), - Frame::MessageBatch(messages) => write!(f, "Frame({messages:?})"), - } - } -} - -impl From for Frame { - fn from(m: ChannelMessage) -> Self { - Self::MessageBatch(vec![m]) - } -} - -impl From> for Frame { - fn from(m: Vec) -> Self { - Self::RawBatch(vec![m]) - } -} - -impl Frame { - /// Decodes a frame from a buffer containing multiple concurrent messages. - pub(crate) fn decode_multiple(buf: &[u8], frame_type: &FrameType) -> Result { - match frame_type { - FrameType::Raw => { - let mut index = 0; - let mut raw_batch: Vec> = vec![]; - while index < buf.len() { - // There might be zero bytes in between, and with LE, the next message will - // start with a non-zero - if buf[index] == 0 { - index += 1; - continue; - } - let stat = stat_uint24_le(&buf[index..]); - if let Some((header_len, body_len)) = stat { - raw_batch.push( - buf[index + header_len..index + header_len + body_len as usize] - .to_vec(), - ); - index += header_len + body_len as usize; - } else { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "received invalid data in raw batch", - )); - } - } - Ok(Frame::RawBatch(raw_batch)) - } - FrameType::Message => { - let mut index = 0; - let mut combined_messages: Vec = vec![]; - while index < buf.len() { - // There might be zero bytes in between, and with LE, the next message will - // start with a non-zero - if buf[index] == 0 { - index += 1; - continue; - } - - let stat = stat_uint24_le(&buf[index..]); - if let Some((header_len, body_len)) = stat { - let (frame, length) = Self::decode_message( - &buf[index + header_len..index + header_len + body_len as usize], - )?; - if length != body_len as usize { - tracing::warn!( - "Did not know what to do with all the bytes, got {} but decoded {}. \ - This may be because the peer implements a newer protocol version \ - that has extra fields.", - body_len, - length - ); - } - if let Frame::MessageBatch(messages) = frame { - for message in messages { - combined_messages.push(message); - } - } else { - unreachable!("Can not get Raw messages"); - } - index += header_len + body_len as usize; - } else { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "received invalid data in multi-message chunk", - )); - } +use std::{fmt, io}; +use tracing::{debug, instrument, trace, warn}; + +const OPEN_MESSAGE_PREFIX: [u8; 2] = [0, 1]; +const CLOSE_MESSAGE_PREFIX: [u8; 2] = [0, 3]; +const MULTI_MESSAGE_PREFIX: [u8; 2] = [0, 0]; +const CHANNEL_CHANGE_SEPERATOR: [u8; 1] = [0]; + +#[instrument(skip_all err)] +pub(crate) fn decode_unframed_channel_messages( + buf: &[u8], +) -> Result<(Vec, usize), io::Error> { + let og_len = buf.len(); + if og_len >= 3 && buf[0] == 0x00 { + // batch of NOT open/close messages + if buf[1] == 0x00 { + let (_, mut buf) = take_array::<2>(buf)?; + // Batch of messages + let mut messages: Vec = vec![]; + + // First, there is the original channel + let mut current_channel; + (current_channel, buf) = u64::decode(buf)?; + while !buf.is_empty() { + // Length of the message is inbetween here + let channel_message_length; + (channel_message_length, buf) = decode_usize(buf)?; + if channel_message_length > buf.len() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "received invalid message length: [{channel_message_length}] +\tbut we have [{}] remaining bytes. +\tInitial buffer size [{og_len}]", + buf.len() + ), + )); } - Ok(Frame::MessageBatch(combined_messages)) - } - } - } - - /// Decode a frame from a buffer. - pub(crate) fn decode(buf: &[u8], frame_type: &FrameType) -> Result { - match frame_type { - FrameType::Raw => Ok(Frame::RawBatch(vec![buf.to_vec()])), - FrameType::Message => { - let (frame, _) = Self::decode_message(buf)?; - Ok(frame) - } - } - } - - fn decode_message(buf: &[u8]) -> Result<(Self, usize), io::Error> { - if buf.len() >= 3 && buf[0] == 0x00 { - if buf[1] == 0x00 { - // Batch of messages - let mut messages: Vec = vec![]; - let mut state = State::new_with_start_and_end(2, buf.len()); - - // First, there is the original channel - let mut current_channel: u64 = state.decode(buf)?; - while state.start() < state.end() { - // Length of the message is inbetween here - let channel_message_length: usize = state.decode(buf)?; - if state.start() + channel_message_length > state.end() { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - format!( - "received invalid message length, {} + {} > {}", - state.start(), - channel_message_length, - state.end() - ), - )); - } - // Then the actual message - let (channel_message, _) = ChannelMessage::decode( - &buf[state.start()..state.start() + channel_message_length], - current_channel, - )?; - messages.push(channel_message); - state.add_start(channel_message_length)?; - // After that, if there is an extra 0x00, that means the channel - // changed. This works because of LE encoding, and channels starting - // from the index 1. - if state.start() < state.end() && buf[state.start()] == 0x00 { - state.add_start(1)?; - current_channel = state.decode(buf)?; - } + // Then the actual message + let channel_message; + let bl = buf.len(); + (channel_message, buf) = ChannelMessage::decode_with_channel(buf, current_channel)?; + trace!( + "Decoded ChannelMessage::{:?} using [{} bytes]", + channel_message.message, + bl - buf.len() + ); + messages.push(channel_message); + // After that, if there is an extra 0x00, that means the channel + // changed. This works because of LE encoding, and channels starting + // from the index 1. + if !buf.is_empty() && buf[0] == 0x00 { + (current_channel, buf) = u64::decode(buf)?; } - Ok((Frame::MessageBatch(messages), state.start())) - } else if buf[1] == 0x01 { - // Open message - let (channel_message, length) = ChannelMessage::decode_open_message(&buf[2..])?; - Ok((Frame::MessageBatch(vec![channel_message]), length + 2)) - } else if buf[1] == 0x03 { - // Close message - let (channel_message, length) = ChannelMessage::decode_close_message(&buf[2..])?; - Ok((Frame::MessageBatch(vec![channel_message]), length + 2)) - } else { - Err(io::Error::new( - io::ErrorKind::InvalidData, - "received invalid special message", - )) } - } else if buf.len() >= 2 { - // Single message - let mut state = State::from_buffer(buf); - let channel: u64 = state.decode(buf)?; - let (channel_message, length) = ChannelMessage::decode(&buf[state.start()..], channel)?; - Ok(( - Frame::MessageBatch(vec![channel_message]), - state.start() + length, - )) + Ok((messages, og_len - buf.len())) + } else if buf[1] == 0x01 { + // Open message + let (channel_message, length) = ChannelMessage::decode_open_message(&buf[2..])?; + Ok((vec![channel_message], length + 2)) + } else if buf[1] == 0x03 { + // Close message + let (channel_message, length) = ChannelMessage::decode_close_message(&buf[2..])?; + Ok((vec![channel_message], length + 2)) } else { Err(io::Error::new( io::ErrorKind::InvalidData, - format!("received too short message, {buf:02X?}"), + "received invalid special message", )) } - } - - fn preencode(&mut self, state: &mut State) -> Result { - match self { - Self::RawBatch(raw_batch) => { - for raw in raw_batch { - state.add_end(raw.as_slice().encoded_len()?)?; - } - } - #[allow(clippy::comparison_chain)] - Self::MessageBatch(messages) => { - if messages.len() == 1 { - if let Message::Open(_) = &messages[0].message { - // This is a special case with 0x00, 0x01 intro bytes - state.add_end(2 + &messages[0].encoded_len()?)?; - } else if let Message::Close(_) = &messages[0].message { - // This is a special case with 0x00, 0x03 intro bytes - state.add_end(2 + &messages[0].encoded_len()?)?; - } else { - (*state).preencode(&messages[0].channel)?; - state.add_end(messages[0].encoded_len()?)?; - } - } else if messages.len() > 1 { - // Two intro bytes 0x00 0x00, then channel id, then lengths - state.add_end(2)?; - let mut current_channel: u64 = messages[0].channel; - state.preencode(¤t_channel)?; - for message in messages.iter_mut() { - if message.channel != current_channel { - // Channel changed, need to add a 0x00 in between and then the new - // channel - state.add_end(1)?; - state.preencode(&message.channel)?; - current_channel = message.channel; - } - let message_length = message.encoded_len()?; - state.preencode(&message_length)?; - state.add_end(message_length)?; - } - } - } - } - Ok(state.end()) + } else if buf.len() >= 2 { + trace!("Decoding single ChannelMessage"); + // Single message + let og_len = buf.len(); + let (channel_message, buf) = ChannelMessage::decode_from_channel_and_message(buf)?; + Ok((vec![channel_message], og_len - buf.len())) + } else { + Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("received too short message, {buf:?}"), + )) } } -impl Encoder for Frame { - fn encoded_len(&mut self) -> Result { - let body_len = self.preencode(&mut State::new())?; - match self { - Self::RawBatch(_) => Ok(body_len), - Self::MessageBatch(_) => Ok(3 + body_len), - } - } - - fn encode(&mut self, buf: &mut [u8]) -> Result { - let mut state = State::new(); - let header_len = if let Self::RawBatch(_) = self { 0 } else { 3 }; - let body_len = self.preencode(&mut state)?; - let len = body_len + header_len; - if buf.len() < len { - return Err(EncodingError::new( - EncodingErrorKind::Overflow, - &format!("Length does not fit buffer, {} > {}", len, buf.len()), - )); - } - match self { - Self::RawBatch(ref raw_batch) => { - for raw in raw_batch { - raw.as_slice().encode(buf)?; +fn vec_channel_messages_encoded_size(messages: &[ChannelMessage]) -> Result { + Ok(match messages { + [] => 0, + [msg] => match msg.message { + Message::Open(_) | Message::Close(_) => 2 + msg.encoded_size()?, + _ => msg.encoded_size()?, + }, + msgs => { + let mut out = MULTI_MESSAGE_PREFIX.len(); + let mut current_channel: u64 = messages[0].channel; + out += current_channel.encoded_size()?; + for message in msgs.iter() { + if message.channel != current_channel { + // Channel changed, need to add a 0x00 in between and then the new + // channel + out += CHANNEL_CHANGE_SEPERATOR.len() + message.channel.encoded_size()?; + current_channel = message.channel; } + let message_length = message.message.encoded_size()?; + out += message_length + (message_length as u64).encoded_size()?; } - #[allow(clippy::comparison_chain)] - Self::MessageBatch(ref mut messages) => { - write_uint24_le(body_len, buf); - let buf = buf.get_mut(3..).expect("Buffer should be over 3 bytes"); - if messages.len() == 1 { - if let Message::Open(_) = &messages[0].message { - // This is a special case with 0x00, 0x01 intro bytes - state.encode(&(0_u8), buf)?; - state.encode(&(1_u8), buf)?; - state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; - } else if let Message::Close(_) = &messages[0].message { - // This is a special case with 0x00, 0x03 intro bytes - state.encode(&(0_u8), buf)?; - state.encode(&(3_u8), buf)?; - state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; - } else { - state.encode(&messages[0].channel, buf)?; - state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; - } - } else if messages.len() > 1 { - // Two intro bytes 0x00 0x00, then channel id, then lengths - state.set_slice_to_buffer(&[0_u8, 0_u8], buf)?; - let mut current_channel: u64 = messages[0].channel; - state.encode(¤t_channel, buf)?; - for message in messages.iter_mut() { - if message.channel != current_channel { - // Channel changed, need to add a 0x00 in between and then the new - // channel - state.encode(&(0_u8), buf)?; - state.encode(&message.channel, buf)?; - current_channel = message.channel; - } - let message_length = message.encoded_len()?; - state.encode(&message_length, buf)?; - state.add_start(message.encode(&mut buf[state.start()..])?)?; - } - } - } - }; - Ok(len) - } + out + } + }) } /// A protocol message. @@ -365,6 +134,114 @@ pub enum Message { LocalSignal((String, Vec)), } +macro_rules! message_from { + ($($val:ident),+) => { + $( + impl From<$val> for Message { + fn from(value: $val) -> Self { + Message::$val(value) + } + } + )* + } +} +message_from!( + Open, + Close, + Synchronize, + Request, + Cancel, + Data, + NoData, + Want, + Unwant, + Bitfield, + Range, + Extension +); + +macro_rules! decode_message { + ($type:ty, $buf:expr) => {{ + let (x, rest) = <$type>::decode($buf)?; + (Message::from(x), rest) + }}; +} + +impl CompactEncoding for Message { + fn encoded_size(&self) -> Result { + let typ_size = if let Self::Open(_) | Self::Close(_) = &self { + 0 + } else { + self.typ().encoded_size()? + }; + let msg_size = match self { + Self::LocalSignal(_) => Ok(0), + Self::Open(x) => x.encoded_size(), + Self::Close(x) => x.encoded_size(), + Self::Synchronize(x) => x.encoded_size(), + Self::Request(x) => x.encoded_size(), + Self::Cancel(x) => x.encoded_size(), + Self::Data(x) => x.encoded_size(), + Self::NoData(x) => x.encoded_size(), + Self::Want(x) => x.encoded_size(), + Self::Unwant(x) => x.encoded_size(), + Self::Bitfield(x) => x.encoded_size(), + Self::Range(x) => x.encoded_size(), + Self::Extension(x) => x.encoded_size(), + }?; + Ok(typ_size + msg_size) + } + + #[instrument(skip_all, fields(name = self.name()))] + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + debug!("Encoding {self:?}"); + let rest = if let Self::Open(_) | Self::Close(_) = &self { + buffer + } else { + self.typ().encode(buffer)? + }; + match self { + Self::Open(x) => x.encode(rest), + Self::Close(x) => x.encode(rest), + Self::Synchronize(x) => x.encode(rest), + Self::Request(x) => x.encode(rest), + Self::Cancel(x) => x.encode(rest), + Self::Data(x) => x.encode(rest), + Self::NoData(x) => x.encode(rest), + Self::Want(x) => x.encode(rest), + Self::Unwant(x) => x.encode(rest), + Self::Bitfield(x) => x.encode(rest), + Self::Range(x) => x.encode(rest), + Self::Extension(x) => x.encode(rest), + Self::LocalSignal(_) => unimplemented!("do not encode LocalSignal"), + } + } + + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let (typ, rest) = u64::decode(buffer)?; + Ok(match typ { + 0 => decode_message!(Synchronize, rest), + 1 => decode_message!(Request, rest), + 2 => decode_message!(Cancel, rest), + 3 => decode_message!(Data, rest), + 4 => decode_message!(NoData, rest), + 5 => decode_message!(Want, rest), + 6 => decode_message!(Unwant, rest), + 7 => decode_message!(Bitfield, rest), + 8 => decode_message!(Range, rest), + 9 => decode_message!(Extension, rest), + _ => { + return Err(EncodingError::new( + EncodingErrorKind::InvalidData, + &format!("Invalid message type to decode: {typ}"), + )) + } + }) + } +} impl Message { /// Wire type of this message. pub(crate) fn typ(&self) -> u64 { @@ -382,71 +259,23 @@ impl Message { value => unimplemented!("{} does not have a type", value), } } - - /// Decode a message from a buffer based on type. - pub(crate) fn decode(buf: &[u8], typ: u64) -> Result<(Self, usize), EncodingError> { - let mut state = HypercoreState::from_buffer(buf); - let message = match typ { - 0 => Ok(Self::Synchronize((*state).decode(buf)?)), - 1 => Ok(Self::Request(state.decode(buf)?)), - 2 => Ok(Self::Cancel((*state).decode(buf)?)), - 3 => Ok(Self::Data(state.decode(buf)?)), - 4 => Ok(Self::NoData((*state).decode(buf)?)), - 5 => Ok(Self::Want((*state).decode(buf)?)), - 6 => Ok(Self::Unwant((*state).decode(buf)?)), - 7 => Ok(Self::Bitfield((*state).decode(buf)?)), - 8 => Ok(Self::Range((*state).decode(buf)?)), - 9 => Ok(Self::Extension((*state).decode(buf)?)), - _ => Err(EncodingError::new( - EncodingErrorKind::InvalidData, - &format!("Invalid message type to decode: {typ}"), - )), - }?; - Ok((message, state.start())) - } - - /// Pre-encodes a message to state, returns length - pub(crate) fn preencode(&self, state: &mut HypercoreState) -> Result { - match self { - Self::Open(ref message) => state.0.preencode(message)?, - Self::Close(ref message) => state.0.preencode(message)?, - Self::Synchronize(ref message) => state.0.preencode(message)?, - Self::Request(ref message) => state.preencode(message)?, - Self::Cancel(ref message) => state.0.preencode(message)?, - Self::Data(ref message) => state.preencode(message)?, - Self::NoData(ref message) => state.0.preencode(message)?, - Self::Want(ref message) => state.0.preencode(message)?, - Self::Unwant(ref message) => state.0.preencode(message)?, - Self::Bitfield(ref message) => state.0.preencode(message)?, - Self::Range(ref message) => state.0.preencode(message)?, - Self::Extension(ref message) => state.0.preencode(message)?, - Self::LocalSignal(_) => 0, - }; - Ok(state.end()) - } - - /// Encodes a message to a given buffer, using preencoded state, results size - pub(crate) fn encode( - &self, - state: &mut HypercoreState, - buf: &mut [u8], - ) -> Result { + /// Get the name of the message + pub fn name(&self) -> &'static str { match self { - Self::Open(ref message) => state.0.encode(message, buf)?, - Self::Close(ref message) => state.0.encode(message, buf)?, - Self::Synchronize(ref message) => state.0.encode(message, buf)?, - Self::Request(ref message) => state.encode(message, buf)?, - Self::Cancel(ref message) => state.0.encode(message, buf)?, - Self::Data(ref message) => state.encode(message, buf)?, - Self::NoData(ref message) => state.0.encode(message, buf)?, - Self::Want(ref message) => state.0.encode(message, buf)?, - Self::Unwant(ref message) => state.0.encode(message, buf)?, - Self::Bitfield(ref message) => state.0.encode(message, buf)?, - Self::Range(ref message) => state.0.encode(message, buf)?, - Self::Extension(ref message) => state.0.encode(message, buf)?, - Self::LocalSignal(_) => 0, - }; - Ok(state.start()) + Message::Open(_) => "Open", + Message::Close(_) => "Close", + Message::Synchronize(_) => "Synchronize", + Message::Request(_) => "Request", + Message::Cancel(_) => "Cancel", + Message::Data(_) => "Data", + Message::NoData(_) => "NoData", + Message::Want(_) => "Want", + Message::Unwant(_) => "Unwant", + Message::Bitfield(_) => "Bitfield", + Message::Range(_) => "Range", + Message::Extension(_) => "Extension", + Message::LocalSignal(_) => "LocalSignal", + } } } @@ -479,7 +308,6 @@ impl fmt::Display for Message { pub(crate) struct ChannelMessage { pub(crate) channel: u64, pub(crate) message: Message, - state: Option, } impl PartialEq for ChannelMessage { @@ -494,14 +322,21 @@ impl fmt::Debug for ChannelMessage { } } +impl fmt::Display for ChannelMessage { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "ChannelMessage {{ channel {}, message {} }}", + self.channel, + self.message.name() + ) + } +} + impl ChannelMessage { /// Create a new message. pub(crate) fn new(channel: u64, message: Message) -> Self { - Self { - channel, - message, - state: None, - } + Self { channel, message } } /// Consume self and return (channel, Message). @@ -513,23 +348,24 @@ impl ChannelMessage { /// /// Note: `buf` has to have a valid length, and without the 3 LE /// bytes in it + #[instrument(skip_all, err)] pub(crate) fn decode_open_message(buf: &[u8]) -> io::Result<(Self, usize)> { - if buf.len() <= 5 { + debug!("Decode ChannelMessage::Open"); + let og_len = buf.len(); + if og_len <= 5 { return Err(io::Error::new( io::ErrorKind::UnexpectedEof, "received too short Open message", )); } - let mut state = State::new_with_start_and_end(0, buf.len()); - let open_msg: Open = state.decode(buf)?; + let (open_msg, buf) = Open::decode(buf)?; Ok(( Self { channel: open_msg.channel, message: Message::Open(open_msg), - state: None, }, - state.start(), + og_len - buf.len(), )) } @@ -538,105 +374,155 @@ impl ChannelMessage { /// Note: `buf` has to have a valid length, and without the 3 LE /// bytes in it pub(crate) fn decode_close_message(buf: &[u8]) -> io::Result<(Self, usize)> { + debug!("Decode ChannelMessage::Close"); + let og_len = buf.len(); if buf.is_empty() { return Err(io::Error::new( io::ErrorKind::UnexpectedEof, "received too short Close message", )); } - let mut state = State::new_with_start_and_end(0, buf.len()); - let close_msg: Close = state.decode(buf)?; + let (close, buf) = Close::decode(buf)?; Ok(( Self { - channel: close_msg.channel, - message: Message::Close(close_msg), - state: None, + channel: close.channel, + message: Message::Close(close), }, - state.start(), + og_len - buf.len(), )) } + #[instrument(err, skip_all)] + pub(crate) fn decode_from_channel_and_message( + buf: &[u8], + ) -> Result<(Self, &[u8]), EncodingError> { + //::decode(buf) + let (channel, buf) = u64::decode(buf)?; + let (message, buf) = ::decode(buf)?; + debug!( + "Decode ChannelMessage{{ channel: {channel}, message: {} }}", + message.name() + ); + Ok((Self { channel, message }, buf)) + } /// Decode a normal channel message from a buffer. /// /// Note: `buf` has to have a valid length, and without the 3 LE /// bytes in it - pub(crate) fn decode(buf: &[u8], channel: u64) -> io::Result<(Self, usize)> { + #[instrument(err, skip(buf))] + pub(crate) fn decode_with_channel(buf: &[u8], channel: u64) -> io::Result<(Self, &[u8])> { if buf.len() <= 1 { return Err(io::Error::new( io::ErrorKind::UnexpectedEof, - "received empty message", + format!("received empty message [{buf:?}]"), )); } - let mut state = State::from_buffer(buf); - let typ: u64 = state.decode(buf)?; - let (message, length) = Message::decode(&buf[state.start()..], typ)?; - Ok(( - Self { - channel, - message, - state: None, - }, - state.start() + length, - )) + let (message, buf) = ::decode(buf)?; + Ok((Self { channel, message }, buf)) } +} - /// Performance optimization for letting calling encoded_len() already do - /// the preencode phase of compact_encoding. - fn prepare_state(&mut self) -> Result<(), EncodingError> { - if self.state.is_none() { - let state = if let Message::Open(_) = self.message { - // Open message doesn't have a type - // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L41 - let mut state = HypercoreState::new(); - self.message.preencode(&mut state)?; - state - } else if let Message::Close(_) = self.message { - // Close message doesn't have a type - // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L162 - let mut state = HypercoreState::new(); - self.message.preencode(&mut state)?; - state - } else { - // The header is the channel id uint followed by message type uint - // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L179 - let mut state = HypercoreState::new(); - let typ = self.message.typ(); - (*state).preencode(&typ)?; - self.message.preencode(&mut state)?; - state - }; - self.state = Some(state); - } - Ok(()) +/// NB: currently this is just for a standalone channel message. ChannelMessages in a vec decode & +/// encode differently +impl CompactEncoding for ChannelMessage { + fn encoded_size(&self) -> Result { + let channel_size = if let Message::Open(_) | Message::Close(_) = &self.message { + 0 + } else { + self.channel.encoded_size()? + }; + + Ok(channel_size + self.message.encoded_size()?) + } + + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + let rest = if let Message::Open(_) | Message::Close(_) = &self.message { + buffer + } else { + self.channel.encode(buffer)? + }; + ::encode(&self.message, rest) + } + + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + ChannelMessage::decode_from_channel_and_message(buffer) } } -impl Encoder for ChannelMessage { - fn encoded_len(&mut self) -> Result { - self.prepare_state()?; - Ok(self.state.as_ref().unwrap().end()) +impl VecEncodable for ChannelMessage { + #[instrument(skip_all, ret)] + fn vec_encoded_size(vec: &[Self]) -> Result + where + Self: Sized, + { + vec_channel_messages_encoded_size(vec) } - fn encode(&mut self, buf: &mut [u8]) -> Result { - self.prepare_state()?; - let state = self.state.as_mut().unwrap(); - if let Message::Open(_) = self.message { - // Open message is different in that the type byte is missing - self.message.encode(state, buf)?; - } else if let Message::Close(_) = self.message { - // Close message is different in that the type byte is missing - self.message.encode(state, buf)?; - } else { - let typ = self.message.typ(); - state.0.encode(&typ, buf)?; - self.message.encode(state, buf)?; + #[instrument(skip_all)] + fn vec_encode<'a>(vec: &[Self], buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> + where + Self: Sized, + { + let in_buf_len = buffer.len(); + trace!( + "Vec::encode to buf.len() = [{}]", + buffer.len() + ); + let mut rest = buffer; + match vec { + [] => Ok(rest), + [msg] => { + rest = match msg.message { + Message::Open(_) => write_array(&OPEN_MESSAGE_PREFIX, rest)?, + Message::Close(_) => write_array(&CLOSE_MESSAGE_PREFIX, rest)?, + _ => msg.channel.encode(rest)?, + }; + msg.message.encode(rest) + } + msgs => { + rest = write_array(&MULTI_MESSAGE_PREFIX, rest)?; + let mut current_channel: u64 = msgs[0].channel; + rest = current_channel.encode(rest)?; + for msg in msgs { + if msg.channel != current_channel { + rest = write_array(&CHANNEL_CHANGE_SEPERATOR, rest)?; + rest = msg.channel.encode(rest)?; + current_channel = msg.channel; + } + let msg_len = msg.message.encoded_size()?; + rest = (msg_len as u64).encode(rest)?; + rest = msg.message.encode(rest)?; + } + trace!("wrote [{}] bytes to buffer", in_buf_len - rest.len()); + Ok(rest) + } } - Ok(state.start()) + } + + fn vec_decode(buffer: &[u8]) -> Result<(Vec, &[u8]), EncodingError> + where + Self: Sized, + { + let mut combined_messages: Vec = vec![]; + let mut rest = buffer; + while !rest.is_empty() { + let (msgs, length) = decode_unframed_channel_messages(rest) + .map_err(|e| EncodingError::external(&format!("{e}")))?; + rest = &rest[length..]; + combined_messages.extend(msgs); + } + Ok((combined_messages, rest)) } } #[cfg(test)] mod tests { + + use crate::test_utils::log; + use super::*; use hypercore::{ DataBlock, DataHash, DataSeek, DataUpgrade, Node, RequestBlock, RequestSeek, RequestUpgrade, @@ -646,19 +532,20 @@ mod tests { ($( $msg:expr ),*) => { $( let channel = rand::random::() as u64; - let mut channel_message = ChannelMessage::new(channel, $msg); - let encoded_len = channel_message.encoded_len().expect("Failed to get encoded length"); - let mut buf = vec![0u8; encoded_len]; - let n = channel_message.encode(&mut buf[..]).expect("Failed to encode message"); - let decoded = ChannelMessage::decode(&buf[..n], channel).expect("Failed to decode message").0.into_split(); - assert_eq!(channel, decoded.0); - assert_eq!($msg, decoded.1); + let channel_message = ChannelMessage::new(channel, $msg); + let encoded_size = channel_message.encoded_size()?; + let mut buf = vec![0u8; encoded_size]; + let rest = ::encode(&channel_message, &mut buf)?; + assert!(rest.is_empty()); + let (decoded, rest) = ::decode(&buf)?; + assert!(rest.is_empty()); + assert_eq!(decoded, channel_message); )* } } #[test] - fn message_encode_decode() { + fn message_encode_decode() -> Result<(), EncodingError> { message_enc_dec! { Message::Synchronize(Synchronize{ fork: 0, @@ -685,7 +572,9 @@ mod tests { upgrade: Some(RequestUpgrade { start: 0, length: 10 - }) + }), + manifest: false, + priority: 0 }), Message::Cancel(Cancel { request: 1, @@ -739,5 +628,30 @@ mod tests { message: vec![0x44, 20] }) }; + Ok(()) + } + + #[test] + fn enc_dec_vec_chan_message() -> Result<(), EncodingError> { + let one = Message::Synchronize(Synchronize { + fork: 0, + length: 4, + remote_length: 0, + downloading: true, + uploading: true, + can_upgrade: true, + }); + let two = Message::Range(Range { + drop: false, + start: 0, + length: 4, + }); + let msgs = vec![ChannelMessage::new(1, one), ChannelMessage::new(1, two)]; + log(); + let buff = msgs.to_encoded_bytes()?; + let (result, rest) = as CompactEncoding>::decode(&buff)?; + assert!(rest.is_empty()); + assert_eq!(result, msgs); + Ok(()) } } diff --git a/src/mqueue.rs b/src/mqueue.rs new file mode 100644 index 0000000..9a2d91a --- /dev/null +++ b/src/mqueue.rs @@ -0,0 +1,209 @@ +//! Interface for reading and writing message to a Stream/Sink + +use std::{ + collections::VecDeque, + fmt::Debug, + io::Result, + pin::Pin, + task::{Context, Poll}, +}; + +use compact_encoding::CompactEncoding as _; +use futures::{Sink, Stream}; +use tracing::{error, info, instrument}; + +use crate::{message::ChannelMessage, noise::EncryptionInfo, NoiseEvent}; + +#[derive(Debug)] +pub(crate) enum MqueueEvent { + Meta(EncryptionInfo), + Message(Result>), +} + +impl From for MqueueEvent { + fn from(e: NoiseEvent) -> Self { + match e { + NoiseEvent::Meta(einf) => Self::Meta(einf), + NoiseEvent::Decrypted(dec_res) => Self::Message(match dec_res { + Ok(encoded) => match >::decode(&encoded) { + Ok((messages, _rest)) => Ok(messages), // _rest.len() == 0 + Err(e) => Err(e.into()), + }, + Err(e) => Err(e), + }), + } + } +} + +pub(crate) struct MessageIo { + io: IO, + write_queue: VecDeque, +} + +impl Debug for MessageIo { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MessageIo") + //.field("io", &self.io) + .field("write_queue", &self.write_queue) + .finish() + } +} + +impl + Sink> + Send + Unpin + 'static> MessageIo { + pub(crate) fn new(io: IO) -> Self { + Self { + io, + write_queue: Default::default(), + } + } + + pub(crate) fn enqueue(&mut self, msg: ChannelMessage) { + self.write_queue.push_back(msg) + } + + #[instrument(skip_all)] + pub(crate) fn poll_outbound(&mut self, cx: &mut Context<'_>) -> Poll> { + let mut pending = true; + // TODO handle error? + while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(&mut self.io), cx) { + pending = false; + if self.write_queue.is_empty() { + break; + } + let mut messages = vec![]; + while let Some(msg) = self.write_queue.pop_front() { + messages.push(msg); + } + + let buf = match messages.to_encoded_bytes() { + Ok(x) => x, + Err(e) => { + error!(error = ?e, "error encoding messages"); + // TODO this would probably be a programming error. + // if so, this sholud just be an unwrap/expect + return Poll::Ready(Err(e.into())); + } + }; + + if let Err(_e) = Sink::start_send(Pin::new(&mut self.io), buf.to_vec()) { + todo!() + } + + match Sink::poll_flush(Pin::new(&mut self.io), cx) { + Poll::Ready(Err(_e)) => todo!(), + Poll::Pending => { + cx.waker().wake_by_ref(); + return Poll::Pending; + } + _ => {} + } + } + + if pending { + cx.waker().wake_by_ref(); + Poll::Pending + } else { + Poll::Ready(Ok(())) + } + } + + pub(crate) fn poll_inbound(&mut self, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.io) + .poll_next(cx) + .map(|opt| opt.map(MqueueEvent::from)) + } +} + +impl + Sink> + Send + Unpin + 'static> Stream + for MessageIo +{ + type Item = MqueueEvent; + + #[instrument(skip_all, ret)] + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let _ = self.poll_outbound(cx); + match self.poll_inbound(cx) { + Poll::Ready(Some(MqueueEvent::Message(Ok(x)))) => { + for m in x.iter() { + info!("RX ChannelMessage::{m}"); + } + Poll::Ready(Some(MqueueEvent::Message(Ok(x)))) + } + x => x, + } + } +} + +#[cfg(test)] +mod test { + use std::io::Result; + + use futures::{future::select, AsyncRead, AsyncWrite}; + use futures_lite::StreamExt; + + use crate::{ + encrypted_framed_message_channel, framing::test::duplex, message::ChannelMessage, + schema::NoData, test_utils::log, Encrypted, Uint24LELengthPrefixedFraming, + }; + + use super::{MessageIo, MqueueEvent}; + pub(crate) fn encrypted_and_framed< + BytesTxRx: AsyncRead + AsyncWrite + Send + Unpin + 'static, + >( + is_initiator: bool, + io: BytesTxRx, + ) -> MessageIo>> { + let io = encrypted_framed_message_channel(is_initiator, io); + MessageIo { + io, + write_queue: Default::default(), + } + } + fn new_msg(channel: u64) -> ChannelMessage { + ChannelMessage { + channel, + message: crate::Message::NoData(NoData { + request: channel + 1, + }), + } + } + + fn take_messages(e: Option) -> Option> { + match e { + Some(MqueueEvent::Message(Result::Ok(out))) => Some(out), + _ => None, + } + } + + #[tokio::test] + async fn mqueue() -> Result<()> { + log(); + + let rtolm = new_msg(38); + let ltorm = new_msg(42); + + let (left, right) = duplex(1024 * 64); + let mut left = encrypted_and_framed(true, left); + let mut right = encrypted_and_framed(false, right); + left.enqueue(ltorm.clone()); + right.enqueue(rtolm.clone()); + + loop { + match select(left.next(), right.next()).await { + futures::future::Either::Left((m, _)) => { + if let Some(m) = take_messages(m) { + assert_eq!(m, vec![rtolm]); + break; + } + } + futures::future::Either::Right((m, _)) => { + if let Some(m) = take_messages(m) { + assert_eq!(m, vec![rtolm]); + break; + } + } + } + } + Ok(()) + } +} diff --git a/src/noise.rs b/src/noise.rs new file mode 100644 index 0000000..2c6e001 --- /dev/null +++ b/src/noise.rs @@ -0,0 +1,700 @@ +use futures::{AsyncRead, AsyncWrite, Sink, Stream}; +use std::{ + collections::VecDeque, + fmt::Debug, + io::Result, + mem::replace, + pin::Pin, + task::{Context, Poll}, +}; +use tracing::{debug, error, instrument, trace, warn}; + +use crate::{ + crypto::{DecryptCipher, EncryptCipher, Handshake, HandshakeResult}, + Uint24LELengthPrefixedFraming, +}; + +/// Create a framed and encrypted Stream/Sink that reads/writes to an AsyncRead/AsyncWrite. +pub fn encrypted_framed_message_channel( + is_initiator: bool, + io: IO, +) -> Encrypted> { + let framed = Uint24LELengthPrefixedFraming::new(io); + Encrypted::new(is_initiator, framed) +} + +#[derive(Debug)] +pub(crate) enum Step { + NotInitialized, + Handshake(Box), + SecretStream((EncryptCipher, HandshakeResult)), + Established((EncryptCipher, DecryptCipher, HandshakeResult)), +} + +impl Step { + fn established(&self) -> bool { + matches!(self, Step::Established(_)) + } +} + +impl std::fmt::Display for Step { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let x = match self { + Step::NotInitialized => "NotInitialized", + Step::Handshake(_) => "Handshake", + Step::SecretStream(_) => "SecretStream", + Step::Established(_) => "Established", + }; + write!(f, "{}", x) + } +} + +#[derive(Debug)] +/// Encryption related info +pub enum EncryptionInfo { + Handshake(HandshakeResult), +} +#[derive(Debug)] +/// Decrypted messages and encryption related events +pub enum Event { + /// Events related to the encryption stream + Meta(EncryptionInfo), + /// A decrypted message + Decrypted(Result>), +} + +impl From>> for Event { + fn from(value: Result>) -> Self { + Self::Decrypted(value) + } +} +impl From for Event { + fn from(value: HandshakeResult) -> Self { + Self::Meta(EncryptionInfo::Handshake(value)) + } +} + +/// Wrap a stream with encryption +pub struct Encrypted { + io: IO, + step: Step, + is_initiator: bool, + encrypted_tx: VecDeque>, + encrypted_rx: VecDeque>>, + plain_tx: VecDeque>, + plain_rx: VecDeque, + flush: bool, +} + +impl Encrypted +where + IO: Stream>> + Sink> + Send + Unpin + 'static, +{ + /// Create [`Self`] from a Stream/Sink + #[instrument(skip_all, fields(initiator = %is_initiator))] + pub fn new(is_initiator: bool, io: IO) -> Self { + Self { + io, + is_initiator, + step: Step::NotInitialized, + encrypted_tx: Default::default(), + encrypted_rx: Default::default(), + plain_tx: Default::default(), + plain_rx: Default::default(), + flush: false, + } + } + /// Wether an encrypted connection has been established. + pub fn encryption_established(&self) -> bool { + self.step.established() + } + + /// Check that we've done as much work as possible. Sending, receiving, encrypting and decrypting. + #[instrument(name = "did_as_much_as_possible", skip_all, ret)] + fn did_as_much_as_possible(&mut self, cx: &mut Context<'_>) -> bool { + // No incoming encrypted messages available. + self.poll_incomming_encrypted_messages(cx).is_pending() + // We're unable to send any anymore encrypted/setup messages either because we have none or the `Sink` is unavailable. + && (self.encrypted_tx.is_empty() || Sink::poll_ready(Pin::new(&mut self.io), cx).is_pending()) + // No encrypted messages waiting to be decrypted. + && self.encrypted_rx.is_empty() + // No plaint text messages waiting to be enccrypted or we're still setting up + && (self.plain_tx.is_empty() || !self.step.established()) + } + + /// Handle all message throughput. Sends, encrypts and decrypts messages + /// Returns `true` `step` is already [`Step::Established`]. + #[instrument(name = "poll_message_throughput", skip_all, ret)] + fn poll_message_throughput(&mut self, cx: &mut Context<'_>) -> bool { + self.poll_outgoing_encrypted_messages(cx); + let _ = self.poll_incomming_encrypted_messages(cx); + if let Step::Established((encryptor, decryptor, ..)) = &mut self.step { + // decrypt incomming msgs + poll_decrypt( + decryptor, + &mut self.encrypted_rx, + &mut self.plain_rx, + self.is_initiator, + ); + // encrypt any pending plaintext outgoinng messages + poll_encrypt( + encryptor, + &mut self.encrypted_tx, + &mut self.plain_tx, + self.is_initiator, + &mut self.flush, + ); + true + } else { + self.poll_setup(); + false + } + } + #[instrument(skip_all, fields(initiator = %self.is_initiator))] + fn poll_setup(&mut self) { + // if we get an error, it could be because the other side reset, and is sending a new + // initialization message. + // If this is the case, we should retry this message after the error. + // But to avoid repeatedly retrying the first message, we should only retry if it is *not* the first msg. + // Still setting up + if let Ok(Some(msg)) = maybe_init(&mut self.step, self.is_initiator) { + // queue the init message to send first + trace!(initiator = %self.is_initiator,"queue initial msg"); + self.encrypted_tx.push_front(msg); + } + // TODO handle error + while let Some(enc_res) = self.encrypted_rx.pop_front() { + match enc_res { + Err(e) => { + error!("Recieved an error during setup encryption setup: {e:?}"); + break; + } + Ok(incoming_msg) => { + trace!(initiator = %self.is_initiator, "encrypted_rx dequeue recieved setup msg"); + if let Ok(msgs) = match self.handle_setup_message(&incoming_msg) { + Ok(x) => Ok(x), + Err(e) => { + error!("handle_setup_message error: {e:?}"); + Err(e) + } + } { + for msg in msgs.into_iter().rev() { + trace!(initiator = %self.is_initiator,"queue more setup msg"); + self.encrypted_tx.push_front(msg); + } + } + } + } + + if self.step.established() { + return; + } + } + } + #[instrument(skip_all, fields(initiator = %self.is_initiator))] + /// Fills `encrypted_rx` and drains `encrypted_tx`. + fn poll_outgoing_encrypted_messages(&mut self, cx: &mut Context<'_>) { + // send any pending outgoing messages + while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(&mut self.io), cx) { + if let Some(encrypted_out) = self.encrypted_tx.pop_front() { + trace!(initiator = %self.is_initiator, msg_len = encrypted_out.len(), step = %self.step, "TX message"); + if let Err(_e) = Sink::start_send(Pin::new(&mut self.io), encrypted_out) { + error!("Error polling encyrpted side io") + } + + self.flush = true; + } else { + break; + } + } + if self.flush { + match Sink::poll_flush(Pin::new(&mut self.io), cx) { + Poll::Ready(Ok(())) => { + self.flush = false; + trace!(initiator = %self.is_initiator, "all flushed"); + } + Poll::Ready(Err(_e)) => { + error!(initiator = %self.is_initiator, "Error sending encrypted msg") + } + Poll::Pending => { + // flush not complete try again later + self.flush = true; + } + } + } + } + + fn poll_incomming_encrypted_messages(&mut self, cx: &mut Context<'_>) -> Poll<()> { + // pull in any incomming encrypted messages + let mut got_some = false; + while let Poll::Ready(Some(encrypted_msg)) = Stream::poll_next(Pin::new(&mut self.io), cx) { + trace!(initiator = %self.is_initiator, step = %self.step, "RX message"); + self.encrypted_rx.push_back(encrypted_msg); + got_some = true; + } + if got_some { + Poll::Ready(()) + } else { + Poll::Pending + } + } + /// handle setup messages: if any are incorrect (cause an error) the state is reset + #[instrument(err, skip_all, fields(initiator = %self.is_initiator))] + fn handle_setup_message(&mut self, msg: &[u8]) -> Result>> { + // this would only happen after reset with a bad message. + let mut first_message = false; + if let Step::NotInitialized = self.step { + first_message = true; + assert!(!self.is_initiator); + warn!(initiator = %self.is_initiator, "Encrypted state was reset"); + let mut handshake = Handshake::new(self.is_initiator)?; + let _ = handshake.start_raw()?; + self.step = Step::Handshake(Box::new(handshake)); + } + match &self.step { + Step::NotInitialized => { + unreachable!("should not happen") + } + Step::Handshake(_) => { + let mut out = vec![]; + if let Step::Handshake(mut handshake) = + replace(&mut self.step, Step::NotInitialized) + { + trace!("RX handshake msg"); + if let Some(response) = match handshake.read_raw(msg) { + Ok(x) => x, + Err(e) => { + let maybe_init_message = + (!first_message && !self.is_initiator).then_some(msg.to_vec()); + + self.reset_encrypted(maybe_init_message); + return Err(e); + } + } { + trace!( + initiator = %self.is_initiator, + "read message and emitting response", + ); + out.push(response); + } + + if handshake.complete() { + debug!(initiator = %self.is_initiator, "Handshake completed"); + let handshake_result = match handshake.get_result() { + Ok(x) => x, + Err(e) => { + error!("into-result error {e:?}"); + return Err(e); + } + }; + // The cipher will be put to use to the writer only after the peer's answer has come + let (cipher, init_msg) = + match EncryptCipher::from_handshake_tx(handshake_result) { + Ok(x) => x, + Err(e) => { + error!("from_handshake_tx error {e:?}"); + return Err(e); + } + }; + out.push(init_msg); + self.step = Step::SecretStream((cipher, handshake_result.clone())); + debug!(initiator = %self.is_initiator, "Step changed to {}", self.step); + } else { + self.step = Step::Handshake(handshake); + } + } + Ok(out) + } + Step::SecretStream(_) => { + if let Step::SecretStream((enc_cipher, hs_result)) = + replace(&mut self.step, Step::NotInitialized) + { + let dec_cipher = + DecryptCipher::from_handshake_rx_and_init_msg(&hs_result, msg)?; + self.plain_rx.push_back(Event::from(hs_result.clone())); + self.step = Step::Established((enc_cipher, dec_cipher, hs_result)); + debug!(initiator = %self.is_initiator, "Step changed to {}", self.step); + } + Ok(vec![]) + } + Step::Established(_) => { + unreachable!("`handle_setup_message` should never be called when Step::Established") + } + } + } + #[instrument(skip_all)] + fn reset_encrypted(&mut self, maybe_init_message: Option>) { + error!("Encrypted RESET"); + self.step = Step::NotInitialized; + self.encrypted_tx.clear(); + self.encrypted_rx.clear(); + if let Some(msg) = maybe_init_message { + self.encrypted_rx.push_front(Ok(msg)); + } + self.flush = false; + } +} + +impl< + IO: Stream>> + + Sink, Error = std::io::Error> + + Send + + Unpin + + 'static, + > Sink> for Encrypted +{ + type Error = std::io::Error; + + fn poll_ready( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Sink::poll_ready(Pin::new(&mut self.io), cx) + } + + #[instrument(skip_all, fields(initiator = %self.is_initiator))] + fn start_send(mut self: Pin<&mut Self>, item: Vec) -> std::result::Result<(), Self::Error> { + trace!(initiator = %self.is_initiator, "enqueue plain_tx"); + self.plain_tx.push_back(item); + Ok(()) + } + + #[instrument(skip_all, fields(initiator = %self.is_initiator))] + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + // The flow here can be understood as reading from the encrypted side moving those messages + // through to the plaintext side, then reading new plaintext messages and moving them to + // the encrypted side. + // We do this repeatedly until there's nothing else to do + loop { + self.poll_message_throughput(cx); + self.poll_outgoing_encrypted_messages(cx); + + // check if we've done all possible work + if self.did_as_much_as_possible(cx) { + if !self.step.established() || !self.encrypted_tx.is_empty() || self.flush { + trace!(not_established = !self.step.established(), tx_msgs_waiting = !self.encrypted_tx.is_empty(), flush = ?self.flush, "not done flushing"); + cx.waker().wake_by_ref(); + return Poll::Pending; + } + return Poll::Ready(Ok(())); + } + } + } + + #[instrument(skip_all, fields(initiator = %self.is_initiator))] + fn poll_close( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + todo!() + } +} + +impl>> + Sink> + Send + Unpin + 'static> Stream + for Encrypted +{ + type Item = Event; + + #[instrument(skip_all, fields(initiator = %self.is_initiator, ret, err))] + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.poll_message_throughput(cx) { + if let Some(msg) = self.plain_rx.pop_front() { + Poll::Ready(Some(msg)) + } else { + Poll::Pending + } + } else { + cx.waker().wake_by_ref(); + Poll::Pending + } + } +} + +#[instrument(skip_all)] +fn poll_decrypt( + decryptor: &mut DecryptCipher, + encrypted_rx: &mut VecDeque>>, + plain_rx: &mut VecDeque, + is_initiator: bool, +) { + // decrypt any incromming encrypted messages + // TODO handle error + while let Some(incoming_msg_res) = encrypted_rx.pop_front() { + match incoming_msg_res { + Ok(incoming_msg) => { + trace!(initiator = %is_initiator, "encrypted_rx dequeue decrypt"); + match decryptor.decrypt_buf(&incoming_msg) { + Ok((plain_msg, _tag)) => { + plain_rx.push_back(Event::from(Ok(plain_msg))); + trace!(initiator = %is_initiator, n_plain_rx_msgs = plain_rx.len(), "plain_rx enqueue"); + } + Err(e) => { + error!(initiator = %is_initiator,"RX message failed to decrypt: {e:?}") + } + } + } + Err(e) => { + error!(initiator = %is_initiator,"RX message failed to decrypt: {e:?}") + } + } + } +} + +#[instrument(skip_all)] +fn poll_encrypt( + encryptor: &mut EncryptCipher, + encrypted_tx: &mut VecDeque>, + plain_tx: &mut VecDeque>, + is_initiator: bool, + flush: &mut bool, +) { + // encrypt any pending plaintext outgoinng messages + while let Some(plain_out) = plain_tx.pop_front() { + let enc_out = match encryptor.encrypt(&plain_out) { + Ok(x) => x, + Err(_e) => todo!("We failed to encrypt our own message...?"), + }; + trace!(initiator = %is_initiator, encrypted_msg_length = enc_out.len(), "enqueue new encrypted message from plain tx queue"); + encrypted_tx.push_back(enc_out); + *flush = true; + } +} + +fn maybe_init(step: &mut Step, is_initiator: bool) -> Result>> { + if !matches!(step, Step::NotInitialized) { + return Ok(None); + } + trace!(initiator = %is_initiator, "Init, state {step:?}"); + let mut handshake = Handshake::new(is_initiator)?; + let out = handshake.start_raw()?; + *step = Step::Handshake(Box::new(handshake)); + Ok(out) +} + +impl std::fmt::Debug for Encrypted { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Encrypted") + //.field("io", &self.io) + //.field("step", &self.step) + .field("initiator", &self.is_initiator) + .field("encrypted_tx.len()", &self.encrypted_tx.len()) + .field("encrypted_rx", &self.encrypted_rx.len()) + .field("plain_tx", &self.plain_tx.len()) + .field("plain_rx", &self.plain_rx.len()) + //.field("flush", &self.flush) + .finish() + } +} + +#[cfg(test)] +mod test { + + use crate::{ + framing::test::duplex, test_utils::create_result_connected, Uint24LELengthPrefixedFraming, + }; + + use super::*; + use futures::{future::join, SinkExt, StreamExt}; + + fn inner(e: Option) -> Vec { + if let Some(Event::Decrypted(Ok(x))) = e { + return x; + } + panic!() + } + #[tokio::test] + async fn encrypted() -> Result<()> { + let hello = b"hello".to_vec(); + let world = b"world".to_vec(); + let (lc, rc) = create_result_connected(); + let mut left = Encrypted::new(true, lc); + let mut right = Encrypted::new(false, rc); + + let (_sent, _received) = join(left.send(hello.clone()), right.next()).await; + let (_sent, received) = join(left.send(hello.clone()), right.next()).await; + assert_eq!(inner(received), hello); + + assert!(left.encryption_established()); + + assert!(right.encryption_established()); + + // NB: we cannot totally finish 'left.send' until the other side becomes active + // because the handshake with the other side ('right') must complete + // before the 'hello' message is sent. So we poll both the send and receive concurrently. + let (_sent, received) = join(left.send(hello.clone()), right.next()).await; + + // right recieves left's message + assert_eq!(inner(received), hello); + + // now that the encrypted channel is established, we don't need to spawn. + right.send(world.clone()).await.unwrap(); + + // left recieves right's message + left.next().await; + assert_eq!(inner(left.next().await), world); + Ok(()) + } + + #[tokio::test] + async fn encrypted_many() -> Result<()> { + let hello = b"hello".to_vec(); + let data = vec![ + b"yolo".to_vec(), + b"squalor".to_vec(), + b"idleness".to_vec(), + b"hello".to_vec(), + b"stuff".to_vec(), + ]; + let (lc, rc) = create_result_connected(); + let mut left = Encrypted::new(true, lc); + let mut right = Encrypted::new(false, rc); + + let (_sent, _received) = join(left.send(hello.clone()), right.next()).await; + let (_sent, received) = join(left.send(hello.clone()), right.next()).await; + assert_eq!(inner(received), hello); + + for d in &data { + right.send(d.to_vec()).await?; + } + let mut result = vec![]; + let _ = left.next().await; + for _ in &data { + result.push(inner(left.next().await)); + } + assert_eq!(result, data); + Ok(()) + } + + #[tokio::test] + async fn with_framing() -> Result<()> { + let hello = b"hello".to_vec(); + + let (left, right) = duplex(1024 * 64); + let left = Uint24LELengthPrefixedFraming::new(left); + let right = Uint24LELengthPrefixedFraming::new(right); + + let mut left = Encrypted::new(true, left); + let mut right = Encrypted::new(false, right); + + let (_sent, _received) = join(left.send(hello.clone()), right.next()).await; + assert_eq!(inner(right.next().await), hello); + + let data = vec![ + b"yolo".to_vec(), + b"squalor".to_vec(), + b"idleness".to_vec(), + b"hello".to_vec(), + b"stuff".to_vec(), + ]; + + // send right to left + for d in &data { + right.send(d.to_vec()).await?; + } + let _ = left.next().await; + let mut result = vec![]; + for _ in &data { + result.push(inner(left.next().await)); + } + assert_eq!(result, data); + + // send left to right + for d in &data { + left.send(d.to_vec()).await?; + } + let mut result = vec![]; + for _ in &data { + result.push(inner(right.next().await)); + } + assert_eq!(result, data); + + // send both ways + let mut res = vec![]; + for d in &data { + left.send(d.to_vec()).await?; + right.send(d.to_vec()).await?; + res.push(d.to_vec()); + } + let mut left_result = vec![]; + let mut right_result = vec![]; + for _ in &data { + right_result.push(inner(right.next().await)); + left_result.push(inner(left.next().await)); + } + assert_eq!(right_result, data); + assert_eq!(left_result, data); + + Ok(()) + } + + #[tokio::test] + async fn test_setup_error_causes_re_init() -> Result<()> { + let (lc, mut init_side_messages) = create_result_connected(); + let (mut other_side_messages, rc) = create_result_connected(); + let mut left = Encrypted::new(true, lc); + let mut right = Encrypted::new(false, rc); + let hello = b"hello".to_vec(); + + let send_fut = tokio::task::spawn(async move { + left.send(hello).await.unwrap(); + left + }); + + let init_msg = init_side_messages.next().await.unwrap()?; + + other_side_messages.send(init_msg).await?; + // other side encrypted needs to be polled to do work and send a response + let other_send_fut = tokio::task::spawn(async move { + right.send(b"other hello".to_vec()).await.unwrap(); + right + }); + + let _first_response = other_side_messages.next().await.unwrap()?; + // both sides now have a handshake in progress + + // send a bad message to init side. It should reset, and emit new init msg + init_side_messages.send(b"bad msg".to_vec()).await?; + + other_side_messages + .send(init_side_messages.next().await.unwrap()?) + .await?; + init_side_messages + .send(other_side_messages.next().await.unwrap()?) + .await?; + other_side_messages + .send(init_side_messages.next().await.unwrap()?) + .await?; + + // exchange one more message then we're set up + init_side_messages + .send(other_side_messages.next().await.unwrap()?) + .await?; + other_side_messages + .send(init_side_messages.next().await.unwrap()?) + .await?; + // now our spawned sends can complete + let mut left = send_fut.await?; + let mut right = other_send_fut.await?; + + // exchange hellos + init_side_messages + .send(other_side_messages.next().await.unwrap()?) + .await?; + other_side_messages + .send(init_side_messages.next().await.unwrap()?) + .await?; + + assert!(left.encryption_established()); + assert!(right.encryption_established()); + let _ = right.next().await; + let _ = left.next().await; + + assert_eq!(inner(right.next().await), b"hello"); + assert_eq!(inner(left.next().await), b"other hello"); + + Ok(()) + } +} diff --git a/src/protocol.rs b/src/protocol.rs index 7b8d468..e188baf 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -1,24 +1,32 @@ use async_channel::{Receiver, Sender}; -use futures_lite::io::{AsyncRead, AsyncWrite}; -use futures_lite::stream::Stream; +use futures_lite::{ + io::{AsyncRead, AsyncWrite}, + stream::Stream, +}; use futures_timer::Delay; -use std::collections::VecDeque; -use std::convert::TryInto; -use std::fmt; -use std::future::Future; -use std::io::{self, Error, ErrorKind, Result}; -use std::pin::Pin; -use std::task::{Context, Poll}; -use std::time::Duration; - -use crate::channels::{Channel, ChannelMap}; -use crate::constants::{DEFAULT_KEEPALIVE, PROTOCOL_NAME}; -use crate::crypto::{DecryptCipher, EncryptCipher, Handshake, HandshakeResult}; -use crate::message::{ChannelMessage, Frame, FrameType, Message}; -use crate::reader::ReadState; -use crate::schema::*; -use crate::util::{map_channel_err, pretty_hash}; -use crate::writer::WriteState; +use std::{ + collections::VecDeque, + convert::TryInto, + fmt, + io::{self, Error, ErrorKind, Result}, + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; +use tracing::{debug, error, instrument, warn}; + +use crate::{ + channels::{Channel, ChannelMap}, + constants::{DEFAULT_KEEPALIVE, PROTOCOL_NAME}, + crypto::HandshakeResult, + encrypted_framed_message_channel, + message::{ChannelMessage, Message}, + mqueue::{MessageIo, MqueueEvent}, + noise::EncryptionInfo, + schema::*, + util::{map_channel_err, pretty_hash}, + Encrypted, Uint24LELengthPrefixedFraming, +}; macro_rules! return_error { ($msg:expr) => { @@ -29,7 +37,6 @@ macro_rules! return_error { } const CHANNEL_CAP: usize = 1000; -const KEEPALIVE_DURATION: Duration = Duration::from_secs(DEFAULT_KEEPALIVE as u64); /// Options for a Protocol instance. #[derive(Debug)] @@ -111,34 +118,9 @@ impl fmt::Debug for Event { } } -/// Protocol state -#[allow(clippy::large_enum_variant)] -pub(crate) enum State { - NotInitialized, - // The Handshake struct sits behind an option only so that we can .take() - // it out, it's never actually empty when in State::Handshake. - Handshake(Option), - SecretStream(Option), - Established, -} - -impl fmt::Debug for State { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - State::NotInitialized => write!(f, "NotInitialized"), - State::Handshake(_) => write!(f, "Handshaking"), - State::SecretStream(_) => write!(f, "SecretStream"), - State::Established => write!(f, "Established"), - } - } -} - /// A Protocol stream. pub struct Protocol { - write_state: WriteState, - read_state: ReadState, - io: IO, - state: State, + io: MessageIo>>, options: Options, handshake: Option, channels: ChannelMap, @@ -153,10 +135,7 @@ pub struct Protocol { impl std::fmt::Debug for Protocol { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Protocol") - .field("write_state", &self.write_state) - .field("read_state", &self.read_state) //.field("io", &self.io) - .field("state", &self.state) .field("options", &self.options) .field("handshake", &self.handshake) .field("channels", &self.channels) @@ -181,12 +160,10 @@ where Sender>, Receiver>, ) = async_channel::bounded(1); + Protocol { - io, - read_state: ReadState::new(), - write_state: WriteState::new(), + io: MessageIo::new(encrypted_framed_message_channel(options.is_initiator, io)), options, - state: State::NotInitialized, channels: ChannelMap::new(), handshake: None, command_rx, @@ -246,18 +223,10 @@ where self.channels.iter().map(|c| c.discovery_key()) } - /// Stop the protocol and return the inner reader and writer. - pub fn release(self) -> IO { - self.io - } - + #[instrument(skip_all, fields(initiator = ?self.is_initiator()))] fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); - if let State::NotInitialized = this.state { - return_error!(this.init()); - } - // Drain queued events first. if let Some(event) = this.queued_events.pop_front() { return Poll::Ready(Ok(event)); @@ -266,8 +235,8 @@ where // Read and process incoming messages. return_error!(this.poll_inbound_read(cx)); - if let State::Established = this.state { - // Check for commands, but only once the connection is established. + // Check for commands, but only once the connection is established. + if this.options.noise && this.handshake.is_some() { return_error!(this.poll_commands(cx)); } @@ -285,43 +254,21 @@ where } } - fn init(&mut self) -> Result<()> { - tracing::debug!( - "protocol init, state {:?}, options {:?}", - self.state, - self.options - ); - match self.state { - State::NotInitialized => {} - _ => return Ok(()), - }; - - self.state = if self.options.noise { - let mut handshake = Handshake::new(self.options.is_initiator)?; - // If the handshake start returns a buffer, send it now. - if let Some(buf) = handshake.start()? { - self.queue_frame_direct(buf.to_vec()).unwrap(); - } - self.read_state.set_frame_type(FrameType::Raw); - State::Handshake(Some(handshake)) - } else { - self.read_state.set_frame_type(FrameType::Message); - State::Established - }; - - Ok(()) - } - /// Poll commands. fn poll_commands(&mut self, cx: &mut Context<'_>) -> Result<()> { while let Poll::Ready(Some(command)) = Pin::new(&mut self.command_rx).poll_next(cx) { - self.on_command(command)?; + if let Err(e) = self.on_command(command) { + error!(error = ?e, "Error handling command"); + return Err(e); + } } Ok(()) } - /// Poll the keepalive timer and queue a ping message if needed. - fn poll_keepalive(&mut self, cx: &mut Context<'_>) { + /// TODO Poll the keepalive timer and queue a ping message if needed. + fn poll_keepalive(&mut self, _cx: &mut Context<'_>) { + /* + const KEEPALIVE_DURATION: Duration = Duration::from_secs(DEFAULT_KEEPALIVE as u64); if Pin::new(&mut self.keepalive).poll(cx).is_ready() { if let State::Established = self.state { // 24 bit header for the empty message, hence the 3 @@ -330,8 +277,10 @@ where } self.keepalive.reset(KEEPALIVE_DURATION); } + */ } + // just handles Close and LocalSignal?? fn on_outbound_message(&mut self, message: &ChannelMessage) -> bool { // If message is close, close the local channel. if let ChannelMessage { @@ -355,36 +304,46 @@ where } /// Poll for inbound messages and processs them. + #[instrument(skip_all, err)] fn poll_inbound_read(&mut self, cx: &mut Context<'_>) -> Result<()> { loop { - let msg = self.read_state.poll_reader(cx, &mut self.io); - match msg { - Poll::Ready(Ok(message)) => { - self.on_inbound_frame(message)?; - } - Poll::Ready(Err(e)) => return Err(e), + match self.io.poll_inbound(cx) { + Poll::Ready(opt) => match opt { + Some(e) => match e { + MqueueEvent::Meta(einf) => match einf { + EncryptionInfo::Handshake(hs_res) => { + let remote_pubkey = parse_key(&hs_res.remote_pubkey)?; + self.handshake = Some(hs_res); + debug!(handshake = ?self.handshake, "set Protocol::handshake"); + self.queue_event(Event::Handshake(remote_pubkey)) + } + }, + MqueueEvent::Message(msgs) => self.on_inbound_channel_messages(msgs?)?, + }, + + None => return Ok(()), + }, Poll::Pending => return Ok(()), } } } /// Poll for outbound messages and write them. + #[instrument(skip_all)] fn poll_outbound_write(&mut self, cx: &mut Context<'_>) -> Result<()> { loop { - if let Poll::Ready(Err(e)) = self.write_state.poll_send(cx, &mut self.io) { + // if no parking or setup in progress + if let Poll::Ready(Err(e)) = self.io.poll_outbound(cx) { + error!(err = ?e, "error from poll_outbound"); return Err(e); } - if !self.write_state.can_park_frame() || !matches!(self.state, State::Established) { - return Ok(()); - } - + // send messages outbound_rx match Pin::new(&mut self.outbound_rx).poll_next(cx) { Poll::Ready(Some(mut messages)) => { if !messages.is_empty() { messages.retain(|message| self.on_outbound_message(message)); - if !messages.is_empty() { - let frame = Frame::MessageBatch(messages); - self.write_state.park_frame(frame); + for msg in messages { + self.io.enqueue(msg); } } } @@ -394,119 +353,15 @@ where } } - fn on_inbound_frame(&mut self, frame: Frame) -> Result<()> { - match frame { - Frame::RawBatch(raw_batch) => { - let mut processed_state: Option = None; - for buf in raw_batch { - let state_name: String = format!("{:?}", self.state); - match self.state { - State::Handshake(_) => self.on_handshake_message(buf)?, - State::SecretStream(_) => self.on_secret_stream_message(buf)?, - State::Established => { - if let Some(processed_state) = processed_state.as_ref() { - let previous_state = if self.options.encrypted { - State::SecretStream(None) - } else { - State::Handshake(None) - }; - if processed_state == &format!("{previous_state:?}") { - // This is the unlucky case where the batch had two or more messages where - // the first one was correctly identified as Raw but everything - // after that should have been (decrypted and) a MessageBatch. Correct the mistake - // here post-hoc. - let buf = self.read_state.decrypt_buf(&buf)?; - let frame = Frame::decode(&buf, &FrameType::Message)?; - self.on_inbound_frame(frame)?; - continue; - } - } - unreachable!( - "May not receive raw frames in Established state" - ) - } - _ => unreachable!( - "May not receive raw frames outside of handshake or secretstream state, was {:?}", - self.state - ), - }; - if processed_state.is_none() { - processed_state = Some(state_name) - } - } - Ok(()) - } - Frame::MessageBatch(channel_messages) => match self.state { - State::Established => { - for channel_message in channel_messages { - self.on_inbound_message(channel_message)? - } - Ok(()) - } - _ => unreachable!("May not receive message batch frames when not established"), - }, - } - } - - fn on_handshake_message(&mut self, buf: Vec) -> Result<()> { - let mut handshake = match &mut self.state { - State::Handshake(handshake) => handshake.take().unwrap(), - _ => unreachable!("May not call on_handshake_message when not in Handshake state"), - }; - - if let Some(response_buf) = handshake.read(&buf)? { - self.queue_frame_direct(response_buf.to_vec()).unwrap(); - } - - if !handshake.complete() { - self.state = State::Handshake(Some(handshake)); - } else { - let handshake_result = handshake.into_result()?; - - if self.options.encrypted { - // The cipher will be put to use to the writer only after the peer's answer has come - let (cipher, init_msg) = EncryptCipher::from_handshake_tx(&handshake_result)?; - self.state = State::SecretStream(Some(cipher)); - - // Send the secret stream init message header to the other side - self.queue_frame_direct(init_msg).unwrap(); - } else { - // Skip secret stream and go straight to Established, then notify about - // handshake - self.read_state.set_frame_type(FrameType::Message); - let remote_public_key = parse_key(&handshake_result.remote_pubkey)?; - self.queue_event(Event::Handshake(remote_public_key)); - self.state = State::Established; - } - // Store handshake result - self.handshake = Some(handshake_result); + #[instrument(skip_all)] + fn on_inbound_channel_messages(&mut self, channel_messages: Vec) -> Result<()> { + for channel_message in channel_messages { + self.on_inbound_message(channel_message)? } Ok(()) } - fn on_secret_stream_message(&mut self, buf: Vec) -> Result<()> { - let encrypt_cipher = match &mut self.state { - State::SecretStream(encrypt_cipher) => encrypt_cipher.take().unwrap(), - _ => { - unreachable!("May not call on_secret_stream_message when not in SecretStream state") - } - }; - let handshake_result = &self - .handshake - .as_ref() - .expect("Handshake result must be set before secret stream"); - let decrypt_cipher = DecryptCipher::from_handshake_rx_and_init_msg(handshake_result, &buf)?; - self.read_state.upgrade_with_decrypt_cipher(decrypt_cipher); - self.write_state.upgrade_with_encrypt_cipher(encrypt_cipher); - self.read_state.set_frame_type(FrameType::Message); - - // Lastly notify that handshake is ready and set state to established - let remote_public_key = parse_key(&handshake_result.remote_pubkey)?; - self.queue_event(Event::Handshake(remote_public_key)); - self.state = State::Established; - Ok(()) - } - + #[instrument(skip_all)] fn on_inbound_message(&mut self, channel_message: ChannelMessage) -> Result<()> { // let channel_message = ChannelMessage::decode(buf)?; let (remote_id, message) = channel_message.into_split(); @@ -520,6 +375,7 @@ where Ok(()) } + #[instrument(skip(self))] fn on_command(&mut self, command: Command) -> Result<()> { match command { Command::Open(key) => self.command_open(key), @@ -529,6 +385,7 @@ where } /// Open a Channel with the given key. Adding it to our channel map + #[instrument(skip_all)] fn command_open(&mut self, key: Key) -> Result<()> { // Create a new channel. let channel_handle = self.channels.attach_local(key); @@ -552,8 +409,7 @@ where capability, }); let channel_message = ChannelMessage::new(channel, message); - self.write_state - .queue_frame(Frame::MessageBatch(vec![channel_message])); + self.io.enqueue(channel_message); Ok(()) } @@ -570,6 +426,7 @@ where Ok(()) } + #[instrument(skip(self))] fn on_open(&mut self, ch: u64, msg: Open) -> Result<()> { let discovery_key: DiscoveryKey = parse_key(&msg.discovery_key)?; let channel_handle = @@ -586,15 +443,12 @@ where Ok(()) } + #[instrument(skip(self))] fn queue_event(&mut self, event: Event) { self.queued_events.push_back(event); } - fn queue_frame_direct(&mut self, body: Vec) -> Result { - let mut frame = Frame::RawBatch(vec![body]); - self.write_state.try_queue_direct(&mut frame) - } - + #[instrument(skip(self))] fn accept_channel(&mut self, local_id: usize) -> Result<()> { let (key, remote_capability) = self.channels.prepare_to_verify(local_id)?; self.verify_remote_capability(remote_capability.cloned(), key)?; @@ -624,6 +478,7 @@ where Ok(()) } + #[instrument(skip_all)] fn capability(&self, key: &[u8]) -> Option> { match self.handshake.as_ref() { Some(handshake) => handshake.capability(key), @@ -631,6 +486,7 @@ where } } + #[instrument(skip_all)] fn verify_remote_capability(&self, capability: Option>, key: &[u8]) -> Result<()> { match self.handshake.as_ref() { Some(handshake) => handshake.verify_remote_capability(capability, key), @@ -648,11 +504,15 @@ where { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Protocol::poll_next(self, cx).map(Some) + match Protocol::poll_next(self, cx) { + Poll::Ready(Ok(e)) => Poll::Ready(Some(Ok(e))), + Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))), + Poll::Pending => Poll::Pending, + } } } -/// Send [Command](Command)s to the [Protocol](Protocol). +/// Send [`Command`]s to the [`Protocol`]. #[derive(Clone, Debug)] pub struct CommandTx(Sender); diff --git a/src/reader.rs b/src/reader.rs deleted file mode 100644 index 51b370b..0000000 --- a/src/reader.rs +++ /dev/null @@ -1,231 +0,0 @@ -use crate::crypto::DecryptCipher; -use futures_lite::io::AsyncRead; -use futures_timer::Delay; -use std::future::Future; -use std::io::{Error, ErrorKind, Result}; -use std::pin::Pin; -use std::task::{Context, Poll}; - -use crate::constants::{DEFAULT_TIMEOUT, MAX_MESSAGE_SIZE}; -use crate::message::{Frame, FrameType}; -use crate::util::stat_uint24_le; -use std::time::Duration; - -const TIMEOUT: Duration = Duration::from_secs(DEFAULT_TIMEOUT as u64); -const READ_BUF_INITIAL_SIZE: usize = 1024 * 128; - -#[derive(Debug)] -pub(crate) struct ReadState { - /// The read buffer. - buf: Vec, - /// The start of the not-yet-processed byte range in the read buffer. - start: usize, - /// The end of the not-yet-processed byte range in the read buffer. - end: usize, - /// The logical state of the reading (either header or body). - step: Step, - /// The timeout after which the connection is closed. - timeout: Delay, - /// Optional decryption cipher. - cipher: Option, - /// The frame type to be passed to the decoder. - frame_type: FrameType, -} - -impl ReadState { - pub(crate) fn new() -> ReadState { - ReadState { - buf: vec![0u8; READ_BUF_INITIAL_SIZE], - start: 0, - end: 0, - step: Step::Header, - timeout: Delay::new(TIMEOUT), - cipher: None, - frame_type: FrameType::Raw, - } - } -} - -#[derive(Debug)] -enum Step { - Header, - Body { - header_len: usize, - body_len: usize, - }, - /// Multiple messages one after another - Batch, -} - -impl ReadState { - pub(crate) fn upgrade_with_decrypt_cipher(&mut self, decrypt_cipher: DecryptCipher) { - self.cipher = Some(decrypt_cipher); - } - - /// Decrypts a given buf with stored cipher, if present. Used to correct - /// the rare mistake that more than two messages came in where the first - /// one created the cipher, and the next one should have been decrypted - /// but wasn't. - pub(crate) fn decrypt_buf(&mut self, buf: &[u8]) -> Result> { - if let Some(cipher) = self.cipher.as_mut() { - Ok(cipher.decrypt_buf(buf)?.0) - } else { - Ok(buf.to_vec()) - } - } - - pub(crate) fn set_frame_type(&mut self, frame_type: FrameType) { - self.frame_type = frame_type; - } - - pub(crate) fn poll_reader( - &mut self, - cx: &mut Context<'_>, - mut reader: &mut R, - ) -> Poll> - where - R: AsyncRead + Unpin, - { - let mut incomplete = true; - loop { - if !incomplete { - if let Some(result) = self.process() { - return Poll::Ready(result); - } - } else { - incomplete = false; - } - let n = match Pin::new(&mut reader).poll_read(cx, &mut self.buf[self.end..]) { - Poll::Ready(Ok(n)) if n > 0 => n, - Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), - // If the reader is pending, poll the timeout. - Poll::Pending | Poll::Ready(Ok(_)) => { - // Return Pending if the timeout is pending, or an error if the - // timeout expired (i.e. returned Poll::Ready). - return Pin::new(&mut self.timeout) - .poll(cx) - .map(|()| Err(Error::new(ErrorKind::TimedOut, "Remote timed out"))); - } - }; - - let end = self.end + n; - let (success, segments) = create_segments(&self.buf[self.start..end])?; - if success { - if let Some(ref mut cipher) = self.cipher { - let mut dec_end = self.start; - for (index, header_len, body_len) in segments { - let de = cipher.decrypt( - &mut self.buf[self.start + index..end], - header_len, - body_len, - )?; - dec_end = self.start + index + de; - } - self.end = dec_end; - } else { - self.end = end; - } - } else { - // Could not segment due to buffer being full, need to cycle the buffer - // and possibly resize it too if the message is too big. - self.cycle_buf_and_resize_if_needed(segments[segments.len() - 1]); - - // Set incomplete flag to skip processing and instead poll more data - incomplete = true; - } - self.timeout.reset(TIMEOUT); - } - } - - fn cycle_buf_and_resize_if_needed(&mut self, last_segment: (usize, usize, usize)) { - let (last_index, last_header_len, last_body_len) = last_segment; - let total_incoming_length = last_index + last_header_len + last_body_len; - if self.buf.len() < total_incoming_length { - // The incoming segments will not fit into the buffer, need to resize it - self.buf.resize(total_incoming_length, 0u8); - } - let temp = self.buf[self.start..].to_vec(); - let len = temp.len(); - self.buf[..len].copy_from_slice(&temp[..]); - self.end = len; - self.start = 0; - } - - fn process(&mut self) -> Option> { - loop { - match self.step { - Step::Header => { - let stat = stat_uint24_le(&self.buf[self.start..self.end]); - if let Some((header_len, body_len)) = stat { - if body_len == 0 { - // This is a keepalive message, just remain in Step::Header - self.start += header_len; - return None; - } else if (self.start + header_len + body_len as usize) < self.end { - // There are more than one message here, create a batch from all of - // then - self.step = Step::Batch; - } else { - let body_len = body_len as usize; - if body_len > MAX_MESSAGE_SIZE as usize { - return Some(Err(Error::new( - ErrorKind::InvalidData, - "Message length above max allowed size", - ))); - } - self.step = Step::Body { - header_len, - body_len, - }; - } - } else { - return Some(Err(Error::new(ErrorKind::InvalidData, "Invalid header"))); - } - } - - Step::Body { - header_len, - body_len, - } => { - let message_len = header_len + body_len; - let range = self.start + header_len..self.start + message_len; - let frame = Frame::decode(&self.buf[range], &self.frame_type); - self.start += message_len; - self.step = Step::Header; - return Some(frame); - } - Step::Batch => { - let frame = - Frame::decode_multiple(&self.buf[self.start..self.end], &self.frame_type); - self.start = self.end; - self.step = Step::Header; - return Some(frame); - } - } - } - } -} - -#[allow(clippy::type_complexity)] -fn create_segments(buf: &[u8]) -> Result<(bool, Vec<(usize, usize, usize)>)> { - let mut index: usize = 0; - let len = buf.len(); - let mut segments: Vec<(usize, usize, usize)> = vec![]; - while index < len { - if let Some((header_len, body_len)) = stat_uint24_le(&buf[index..]) { - let body_len = body_len as usize; - segments.push((index, header_len, body_len)); - if len < index + header_len + body_len { - // The segments will not fit, return false to indicate that more needs to be read - return Ok((false, segments)); - } - index += header_len + body_len; - } else { - return Err(Error::new( - ErrorKind::InvalidData, - "Could not read header while decrypting", - )); - } - } - Ok((true, segments)) -} diff --git a/src/schema.rs b/src/schema.rs index ef58e77..49a0ac5 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -1,7 +1,11 @@ -use hypercore::encoding::{CompactEncoding, EncodingError, HypercoreState, State}; +use compact_encoding::{ + map_decode, map_encode, sum_encoded_size, take_array, take_array_mut, write_array, write_slice, + CompactEncoding, EncodingError, +}; use hypercore::{ DataBlock, DataHash, DataSeek, DataUpgrade, Proof, RequestBlock, RequestSeek, RequestUpgrade, }; +use tracing::instrument; /// Open message #[derive(Debug, Clone, PartialEq)] @@ -16,46 +20,55 @@ pub struct Open { pub capability: Option>, } -impl CompactEncoding for State { - fn preencode(&mut self, value: &Open) -> Result { - self.preencode(&value.channel)?; - self.preencode(&value.protocol)?; - self.preencode(&value.discovery_key)?; - if value.capability.is_some() { - self.add_end(1)?; // flags for future use - self.preencode_fixed_32()?; +impl CompactEncoding for Open { + #[instrument(skip_all, ret, err)] + fn encoded_size(&self) -> Result { + let out = sum_encoded_size!(self.channel, self.protocol, self.discovery_key); + if self.capability.is_some() { + return Ok( + out + + 1 // flags for future use + + 32, // TODO capabalilities buff should always be 32 bytes, but it's a vec + ); } - Ok(self.end()) + Ok(out) } - fn encode(&mut self, value: &Open, buffer: &mut [u8]) -> Result { - self.encode(&value.channel, buffer)?; - self.encode(&value.protocol, buffer)?; - self.encode(&value.discovery_key, buffer)?; - if let Some(capability) = &value.capability { - self.add_start(1)?; // flags for future use - self.encode_fixed_32(capability, buffer)?; + #[instrument(skip_all)] + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + let rest = map_encode!(buffer, self.channel, self.protocol, self.discovery_key); + if let Some(cap) = &self.capability { + let (_, rest) = take_array_mut::<1>(rest)?; + return write_slice(cap, rest); } - Ok(self.start()) - } - - fn decode(&mut self, buffer: &[u8]) -> Result { - let channel: u64 = self.decode(buffer)?; - let protocol: String = self.decode(buffer)?; - let discovery_key: Vec = self.decode(buffer)?; - let capability: Option> = if self.start() < self.end() { - self.add_start(1)?; // flags for future use - let capability: Vec = self.decode_fixed_32(buffer)?.to_vec(); - Some(capability) + Ok(rest) + } + + #[instrument(skip_all, err)] + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let ((channel, protocol, discovery_key), rest) = + map_decode!(buffer, [u64, String, Vec]); + // NB: Open/Close are only sent alone in their own Frame. So we're done when there is no + // more data + let (capability, rest) = if !rest.is_empty() { + let (_, rest) = take_array::<1>(rest)?; + let (capability, rest) = take_array::<32>(rest)?; + (Some(capability.to_vec()), rest) } else { - None + (None, rest) }; - Ok(Open { - channel, - protocol, - discovery_key, - capability, - }) + Ok(( + Self { + channel, + protocol, + discovery_key, + capability, + }, + rest, + )) } } @@ -66,18 +79,21 @@ pub struct Close { pub channel: u64, } -impl CompactEncoding for State { - fn preencode(&mut self, value: &Close) -> Result { - self.preencode(&value.channel) +impl CompactEncoding for Close { + fn encoded_size(&self) -> Result { + self.channel.encoded_size() } - fn encode(&mut self, value: &Close, buffer: &mut [u8]) -> Result { - self.encode(&value.channel, buffer) + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + self.channel.encode(buffer) } - fn decode(&mut self, buffer: &[u8]) -> Result { - let channel: u64 = self.decode(buffer)?; - Ok(Close { channel }) + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let (channel, rest) = u64::decode(buffer)?; + Ok((Self { channel }, rest)) } } @@ -98,40 +114,44 @@ pub struct Synchronize { pub can_upgrade: bool, } -impl CompactEncoding for State { - fn preencode(&mut self, value: &Synchronize) -> Result { - self.add_end(1)?; // flags - self.preencode(&value.fork)?; - self.preencode(&value.length)?; - self.preencode(&value.remote_length) - } - - fn encode(&mut self, value: &Synchronize, buffer: &mut [u8]) -> Result { - let mut flags: u8 = if value.can_upgrade { 1 } else { 0 }; - flags |= if value.uploading { 2 } else { 0 }; - flags |= if value.downloading { 4 } else { 0 }; - self.encode(&flags, buffer)?; - self.encode(&value.fork, buffer)?; - self.encode(&value.length, buffer)?; - self.encode(&value.remote_length, buffer) - } - - fn decode(&mut self, buffer: &[u8]) -> Result { - let flags: u8 = self.decode(buffer)?; - let fork: u64 = self.decode(buffer)?; - let length: u64 = self.decode(buffer)?; - let remote_length: u64 = self.decode(buffer)?; +impl CompactEncoding for Synchronize { + fn encoded_size(&self) -> Result { + Ok(1 + sum_encoded_size!(self.fork, self.length, self.remote_length)) + } + + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + let mut flags: u8 = if self.can_upgrade { 1 } else { 0 }; + flags |= if self.uploading { 2 } else { 0 }; + flags |= if self.downloading { 4 } else { 0 }; + let rest = write_array(&[flags], buffer)?; + Ok(map_encode!( + rest, + self.fork, + self.length, + self.remote_length + )) + } + + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let ([flags], rest) = take_array::<1>(buffer)?; + let ((fork, length, remote_length), rest) = map_decode!(rest, [u64, u64, u64]); let can_upgrade = flags & 1 != 0; let uploading = flags & 2 != 0; let downloading = flags & 4 != 0; - Ok(Synchronize { - fork, - length, - remote_length, - can_upgrade, - uploading, - downloading, - }) + Ok(( + Synchronize { + fork, + length, + remote_length, + can_upgrade, + uploading, + downloading, + }, + rest, + )) } } @@ -150,83 +170,105 @@ pub struct Request { pub seek: Option, /// Request upgrade pub upgrade: Option, + // TODO what is this + /// Request manifest + pub manifest: bool, + // TODO what is this + // this could prob be usize + /// Request priority + pub priority: u64, +} + +macro_rules! maybe_decode { + ($cond:expr, $type:ty, $buf:ident) => { + if $cond { + let (result, rest) = <$type>::decode($buf)?; + (Some(result), rest) + } else { + (None, $buf) + } + }; } -impl CompactEncoding for HypercoreState { - fn preencode(&mut self, value: &Request) -> Result { - self.add_end(1)?; // flags - self.0.preencode(&value.id)?; - self.0.preencode(&value.fork)?; - if let Some(block) = &value.block { - self.preencode(block)?; +impl CompactEncoding for Request { + fn encoded_size(&self) -> Result { + let mut out = 1; // flags + out += sum_encoded_size!(self.id, self.fork); + if let Some(block) = &self.block { + out += block.encoded_size()?; + } + if let Some(hash) = &self.hash { + out += hash.encoded_size()?; } - if let Some(hash) = &value.hash { - self.preencode(hash)?; + if let Some(seek) = &self.seek { + out += seek.encoded_size()?; } - if let Some(seek) = &value.seek { - self.preencode(seek)?; + if let Some(upgrade) = &self.upgrade { + out += upgrade.encoded_size()?; } - if let Some(upgrade) = &value.upgrade { - self.preencode(upgrade)?; + Ok(out) + } + + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + let mut flags: u8 = if self.block.is_some() { 1 } else { 0 }; + flags |= if self.hash.is_some() { 2 } else { 0 }; + flags |= if self.seek.is_some() { 4 } else { 0 }; + flags |= if self.upgrade.is_some() { 8 } else { 0 }; + flags |= if self.manifest { 16 } else { 0 }; + flags |= if self.priority != 0 { 32 } else { 0 }; + let mut rest = write_array(&[flags], buffer)?; + rest = map_encode!(rest, self.id, self.fork); + + if let Some(block) = &self.block { + rest = block.encode(rest)?; } - Ok(self.end()) - } - - fn encode(&mut self, value: &Request, buffer: &mut [u8]) -> Result { - let mut flags: u8 = if value.block.is_some() { 1 } else { 0 }; - flags |= if value.hash.is_some() { 2 } else { 0 }; - flags |= if value.seek.is_some() { 4 } else { 0 }; - flags |= if value.upgrade.is_some() { 8 } else { 0 }; - self.0.encode(&flags, buffer)?; - self.0.encode(&value.id, buffer)?; - self.0.encode(&value.fork, buffer)?; - if let Some(block) = &value.block { - self.encode(block, buffer)?; + if let Some(hash) = &self.hash { + rest = hash.encode(rest)?; } - if let Some(hash) = &value.hash { - self.encode(hash, buffer)?; + if let Some(seek) = &self.seek { + rest = seek.encode(rest)?; } - if let Some(seek) = &value.seek { - self.encode(seek, buffer)?; + if let Some(upgrade) = &self.upgrade { + rest = upgrade.encode(rest)?; } - if let Some(upgrade) = &value.upgrade { - self.encode(upgrade, buffer)?; + + if self.priority != 0 { + rest = self.priority.encode(rest)?; } - Ok(self.start()) + + Ok(rest) } - fn decode(&mut self, buffer: &[u8]) -> Result { - let flags: u8 = self.0.decode(buffer)?; - let id: u64 = self.0.decode(buffer)?; - let fork: u64 = self.0.decode(buffer)?; - let block: Option = if flags & 1 != 0 { - Some(self.decode(buffer)?) - } else { - None - }; - let hash: Option = if flags & 2 != 0 { - Some(self.decode(buffer)?) - } else { - None - }; - let seek: Option = if flags & 4 != 0 { - Some(self.decode(buffer)?) - } else { - None - }; - let upgrade: Option = if flags & 8 != 0 { - Some(self.decode(buffer)?) + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let ([flags], rest) = take_array::<1>(buffer)?; + let ((id, fork), rest) = map_decode!(rest, [u64, u64]); + + let (block, rest) = maybe_decode!(flags & 1 != 0, RequestBlock, rest); + let (hash, rest) = maybe_decode!(flags & 2 != 0, RequestBlock, rest); + let (seek, rest) = maybe_decode!(flags & 4 != 0, RequestSeek, rest); + let (upgrade, rest) = maybe_decode!(flags & 8 != 0, RequestUpgrade, rest); + let manifest = flags & 16 != 0; + let (priority, rest) = if flags & 32 != 0 { + u64::decode(rest)? } else { - None + (0, rest) }; - Ok(Request { - id, - fork, - block, - hash, - seek, - upgrade, - }) + Ok(( + Request { + id, + fork, + block, + hash, + seek, + upgrade, + manifest, + priority, + }, + rest, + )) } } @@ -237,18 +279,21 @@ pub struct Cancel { pub request: u64, } -impl CompactEncoding for State { - fn preencode(&mut self, value: &Cancel) -> Result { - self.preencode(&value.request) +impl CompactEncoding for Cancel { + fn encoded_size(&self) -> Result { + self.request.encoded_size() } - fn encode(&mut self, value: &Cancel, buffer: &mut [u8]) -> Result { - self.encode(&value.request, buffer) + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + self.request.encode(buffer) } - fn decode(&mut self, buffer: &[u8]) -> Result { - let request: u64 = self.decode(buffer)?; - Ok(Cancel { request }) + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let (request, rest) = u64::decode(buffer)?; + Ok((Cancel { request }, rest)) } } @@ -269,81 +314,74 @@ pub struct Data { pub upgrade: Option, } -impl CompactEncoding for HypercoreState { - fn preencode(&mut self, value: &Data) -> Result { - self.add_end(1)?; // flags - self.0.preencode(&value.request)?; - self.0.preencode(&value.fork)?; - if let Some(block) = &value.block { - self.preencode(block)?; +macro_rules! opt_encoded_size { + ($opt:expr, $sum:ident) => { + if let Some(thing) = $opt { + $sum += thing.encoded_size()?; } - if let Some(hash) = &value.hash { - self.preencode(hash)?; - } - if let Some(seek) = &value.seek { - self.preencode(seek)?; - } - if let Some(upgrade) = &value.upgrade { - self.preencode(upgrade)?; - } - Ok(self.end()) - } - - fn encode(&mut self, value: &Data, buffer: &mut [u8]) -> Result { - let mut flags: u8 = if value.block.is_some() { 1 } else { 0 }; - flags |= if value.hash.is_some() { 2 } else { 0 }; - flags |= if value.seek.is_some() { 4 } else { 0 }; - flags |= if value.upgrade.is_some() { 8 } else { 0 }; - self.0.encode(&flags, buffer)?; - self.0.encode(&value.request, buffer)?; - self.0.encode(&value.fork, buffer)?; - if let Some(block) = &value.block { - self.encode(block, buffer)?; - } - if let Some(hash) = &value.hash { - self.encode(hash, buffer)?; - } - if let Some(seek) = &value.seek { - self.encode(seek, buffer)?; - } - if let Some(upgrade) = &value.upgrade { - self.encode(upgrade, buffer)?; - } - Ok(self.start()) - } + }; +} - fn decode(&mut self, buffer: &[u8]) -> Result { - let flags: u8 = self.0.decode(buffer)?; - let request: u64 = self.0.decode(buffer)?; - let fork: u64 = self.0.decode(buffer)?; - let block: Option = if flags & 1 != 0 { - Some(self.decode(buffer)?) - } else { - None - }; - let hash: Option = if flags & 2 != 0 { - Some(self.decode(buffer)?) - } else { - None - }; - let seek: Option = if flags & 4 != 0 { - Some(self.decode(buffer)?) - } else { - None - }; - let upgrade: Option = if flags & 8 != 0 { - Some(self.decode(buffer)?) +// TODO we could write a macro where it takes a $cond that returns an opt. +// if the option is Some(T) then do T::encode(buf) +// also if some add $flag. +// This would simplify some of these impls +macro_rules! opt_encoded_bytes { + ($opt:expr, $buf:ident) => { + if let Some(thing) = $opt { + thing.encode($buf)? } else { - None - }; - Ok(Data { - request, - fork, - block, - hash, - seek, - upgrade, - }) + $buf + } + }; +} +impl CompactEncoding for Data { + fn encoded_size(&self) -> Result { + let mut out = 1; // flags + out += sum_encoded_size!(self.request, self.fork); + opt_encoded_size!(&self.block, out); + opt_encoded_size!(&self.hash, out); + opt_encoded_size!(&self.seek, out); + opt_encoded_size!(&self.upgrade, out); + Ok(out) + } + + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + let mut flags: u8 = if self.block.is_some() { 1 } else { 0 }; + flags |= if self.hash.is_some() { 2 } else { 0 }; + flags |= if self.seek.is_some() { 4 } else { 0 }; + flags |= if self.upgrade.is_some() { 8 } else { 0 }; + let rest = write_array(&[flags], buffer)?; + let rest = map_encode!(rest, self.request, self.fork); + + let rest = opt_encoded_bytes!(&self.block, rest); + let rest = opt_encoded_bytes!(&self.hash, rest); + let rest = opt_encoded_bytes!(&self.seek, rest); + let rest = opt_encoded_bytes!(&self.upgrade, rest); + Ok(rest) + } + + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let ([flags], rest) = take_array::<1>(buffer)?; + let ((request, fork), rest) = map_decode!(rest, [u64, u64]); + let (block, rest) = maybe_decode!(flags & 1 != 0, DataBlock, rest); + let (hash, rest) = maybe_decode!(flags & 2 != 0, DataHash, rest); + let (seek, rest) = maybe_decode!(flags & 4 != 0, DataSeek, rest); + let (upgrade, rest) = maybe_decode!(flags & 8 != 0, DataUpgrade, rest); + Ok(( + Data { + request, + fork, + block, + hash, + seek, + upgrade, + }, + rest, + )) } } @@ -367,18 +405,21 @@ pub struct NoData { pub request: u64, } -impl CompactEncoding for State { - fn preencode(&mut self, value: &NoData) -> Result { - self.preencode(&value.request) +impl CompactEncoding for NoData { + fn encoded_size(&self) -> Result { + Ok(sum_encoded_size!(self.request)) } - fn encode(&mut self, value: &NoData, buffer: &mut [u8]) -> Result { - self.encode(&value.request, buffer) + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + Ok(map_encode!(buffer, self.request)) } - fn decode(&mut self, buffer: &[u8]) -> Result { - let request: u64 = self.decode(buffer)?; - Ok(NoData { request }) + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let (request, rest) = u64::decode(buffer)?; + Ok((Self { request }, rest)) } } @@ -390,21 +431,22 @@ pub struct Want { /// Length pub length: u64, } -impl CompactEncoding for State { - fn preencode(&mut self, value: &Want) -> Result { - self.preencode(&value.start)?; - self.preencode(&value.length) + +impl CompactEncoding for Want { + fn encoded_size(&self) -> Result { + Ok(sum_encoded_size!(self.start, self.length)) } - fn encode(&mut self, value: &Want, buffer: &mut [u8]) -> Result { - self.encode(&value.start, buffer)?; - self.encode(&value.length, buffer) + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + Ok(map_encode!(buffer, self.start, self.length)) } - fn decode(&mut self, buffer: &[u8]) -> Result { - let start: u64 = self.decode(buffer)?; - let length: u64 = self.decode(buffer)?; - Ok(Want { start, length }) + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let ((start, length), rest) = map_decode!(buffer, [u64, u64]); + Ok((Self { start, length }, rest)) } } @@ -416,21 +458,22 @@ pub struct Unwant { /// Length pub length: u64, } -impl CompactEncoding for State { - fn preencode(&mut self, value: &Unwant) -> Result { - self.preencode(&value.start)?; - self.preencode(&value.length) + +impl CompactEncoding for Unwant { + fn encoded_size(&self) -> Result { + Ok(sum_encoded_size!(self.start, self.length)) } - fn encode(&mut self, value: &Unwant, buffer: &mut [u8]) -> Result { - self.encode(&value.start, buffer)?; - self.encode(&value.length, buffer) + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + Ok(map_encode!(buffer, self.start, self.length)) } - fn decode(&mut self, buffer: &[u8]) -> Result { - let start: u64 = self.decode(buffer)?; - let length: u64 = self.decode(buffer)?; - Ok(Unwant { start, length }) + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let ((start, length), rest) = map_decode!(buffer, [u64, u64]); + Ok((Self { start, length }, rest)) } } @@ -442,21 +485,21 @@ pub struct Bitfield { /// Bitfield in 32 bit chunks beginning from `start` pub bitfield: Vec, } -impl CompactEncoding for State { - fn preencode(&mut self, value: &Bitfield) -> Result { - self.preencode(&value.start)?; - self.preencode(&value.bitfield) +impl CompactEncoding for Bitfield { + fn encoded_size(&self) -> Result { + Ok(sum_encoded_size!(self.start, self.bitfield)) } - fn encode(&mut self, value: &Bitfield, buffer: &mut [u8]) -> Result { - self.encode(&value.start, buffer)?; - self.encode(&value.bitfield, buffer) + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + Ok(map_encode!(buffer, self.start, self.bitfield)) } - fn decode(&mut self, buffer: &[u8]) -> Result { - let start: u64 = self.decode(buffer)?; - let bitfield: Vec = self.decode(buffer)?; - Ok(Bitfield { start, bitfield }) + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let ((start, bitfield), rest) = map_decode!(buffer, [u64, Vec]); + Ok((Self { start, bitfield }, rest)) } } @@ -473,41 +516,46 @@ pub struct Range { pub length: u64, } -impl CompactEncoding for State { - fn preencode(&mut self, value: &Range) -> Result { - self.add_end(1)?; // flags - self.preencode(&value.start)?; - if value.length != 1 { - self.preencode(&value.length)?; +impl CompactEncoding for Range { + fn encoded_size(&self) -> Result { + let mut out = 1 + sum_encoded_size!(self.start); + if self.length != 1 { + out += self.length.encoded_size()?; } - Ok(self.end()) + Ok(out) } - fn encode(&mut self, value: &Range, buffer: &mut [u8]) -> Result { - let mut flags: u8 = if value.drop { 1 } else { 0 }; - flags |= if value.length == 1 { 2 } else { 0 }; - self.encode(&flags, buffer)?; - self.encode(&value.start, buffer)?; - if value.length != 1 { - self.encode(&value.length, buffer)?; + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + let mut flags: u8 = if self.drop { 1 } else { 0 }; + flags |= if self.length == 1 { 2 } else { 0 }; + let rest = write_array(&[flags], buffer)?; + let rest = self.start.encode(rest)?; + if self.length != 1 { + return self.length.encode(rest); } - Ok(self.end()) + Ok(rest) } - fn decode(&mut self, buffer: &[u8]) -> Result { - let flags: u8 = self.decode(buffer)?; - let start: u64 = self.decode(buffer)?; + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let ([flags], rest) = take_array::<1>(buffer)?; + let (start, rest) = u64::decode(rest)?; let drop = flags & 1 != 0; - let length: u64 = if flags & 2 != 0 { - 1 + let (length, rest) = if flags & 2 != 0 { + (1, rest) } else { - self.decode(buffer)? + u64::decode(rest)? }; - Ok(Range { - drop, - length, - start, - }) + Ok(( + Range { + drop, + length, + start, + }, + rest, + )) } } @@ -519,20 +567,20 @@ pub struct Extension { /// Message content, use empty vector for no data. pub message: Vec, } -impl CompactEncoding for State { - fn preencode(&mut self, value: &Extension) -> Result { - self.preencode(&value.name)?; - self.preencode_raw_buffer(&value.message) +impl CompactEncoding for Extension { + fn encoded_size(&self) -> Result { + Ok(sum_encoded_size!(self.name, self.message)) } - fn encode(&mut self, value: &Extension, buffer: &mut [u8]) -> Result { - self.encode(&value.name, buffer)?; - self.encode_raw_buffer(&value.message, buffer) + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + Ok(map_encode!(buffer, self.name, self.message)) } - fn decode(&mut self, buffer: &[u8]) -> Result { - let name: String = self.decode(buffer)?; - let message: Vec = self.decode_raw_buffer(buffer)?; - Ok(Extension { name, message }) + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let ((name, message), rest) = map_decode!(buffer, [String, Vec]); + Ok((Self { name, message }, rest)) } } diff --git a/src/test_utils.rs b/src/test_utils.rs new file mode 100644 index 0000000..2e5e994 --- /dev/null +++ b/src/test_utils.rs @@ -0,0 +1,204 @@ +#![allow(dead_code)] +use std::{ + io::{self, ErrorKind}, + pin::Pin, + task::{Context, Poll}, +}; + +use futures::{ + channel::mpsc::{unbounded, UnboundedReceiver as Receiver, UnboundedSender as Sender}, + Sink, Stream, StreamExt, +}; + +#[derive(Debug)] +pub(crate) struct Io { + receiver: Receiver>, + sender: Sender>, +} + +impl Default for Io { + fn default() -> Self { + let (sender, receiver) = unbounded(); + Self { sender, receiver } + } +} + +impl Stream for Io { + type Item = Vec; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.receiver).poll_next(cx) + } +} + +impl Sink> for Io { + type Error = io::Error; + + fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(mut self: Pin<&mut Self>, item: Vec) -> Result<(), Self::Error> { + Pin::new(&mut self.sender) + .start_send(item) + .map_err(|_e| io::Error::new(ErrorKind::Other, "SendError")) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + todo!() + } +} + +#[derive(Default, Debug)] +pub(crate) struct TwoWay { + l_to_r: Io, + r_to_l: Io, +} + +impl TwoWay { + fn split_sides(self) -> (Io, Io) { + let left = Io { + sender: self.l_to_r.sender, + receiver: self.r_to_l.receiver, + }; + let right = Io { + sender: self.r_to_l.sender, + receiver: self.l_to_r.receiver, + }; + (left, right) + } +} + +pub(crate) fn log() { + static START_LOGS: std::sync::OnceLock<()> = std::sync::OnceLock::new(); + START_LOGS.get_or_init(|| { + use tracing_subscriber::{ + layer::SubscriberExt as _, util::SubscriberInitExt as _, EnvFilter, + }; + let env_filter = EnvFilter::from_default_env(); // Reads `RUST_LOG` environment variable + + // Create the hierarchical layer from tracing_tree + let tree_layer = tracing_tree::HierarchicalLayer::new(2) // 2 spaces per indent level + .with_targets(true) + .with_bracketed_fields(true) + .with_indent_lines(true) + .with_thread_ids(false) + .with_thread_names(true) + //.with_span_modes(true) + ; + + tracing_subscriber::registry() + .with(env_filter) + .with(tree_layer) + .init(); + }); +} + +pub(crate) struct Moo { + receiver: Rx, + sender: Tx, +} + +impl + Unpin, Tx: Unpin> Stream for Moo { + type Item = RxItem; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + Pin::new(&mut this.receiver).poll_next(cx) + } +} + +impl + Unpin> Sink + for Moo +{ + type Error = io::Error; + + fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, item: TxItem) -> Result<(), Self::Error> { + let this = self.get_mut(); + Pin::new(&mut this.sender) + .start_send(item) + .map_err(|_e| io::Error::new(ErrorKind::Other, "SendError")) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + todo!() + } +} + +/// Creaee [`Moo`] from return value of [`unbounded`] +impl From<(Tx, Rx)> for Moo { + fn from(value: (Tx, Rx)) -> Self { + Moo { + receiver: value.1, + sender: value.0, + } + } +} + +impl Moo { + /// connect two [`Moo`]s + fn connect( + self, + other: Moo, + ) -> (Moo, Moo) { + let left = Moo { + receiver: self.receiver, + sender: other.sender, + }; + let right = Moo { + receiver: other.receiver, + sender: self.sender, + }; + (left, right) + } +} + +fn result_channel() -> (Sender>, impl Stream>>) { + let (tx, rx) = unbounded::>(); + (tx, rx.map(Ok)) +} + +#[allow(clippy::type_complexity)] +pub(crate) fn create_result_connected() -> ( + Moo>>, impl Sink>>, + Moo>>, impl Sink>>, +) { + let a = Moo::from(result_channel()); + let b = Moo::from(result_channel()); + a.connect(b) +} + +#[cfg(test)] +mod test { + #![allow(unused_imports)] // test's within tests confused clippy + use futures::{SinkExt, StreamExt}; + #[tokio::test] + async fn way_one() { + let mut a = super::Io::default(); + let _ = a.send(b"hello".into()).await; + let Some(res) = a.next().await else { panic!() }; + assert_eq!(res, b"hello"); + } + + #[tokio::test] + async fn split() { + let (mut left, mut right) = (super::TwoWay::default()).split_sides(); + left.send(b"hello".to_vec()).await.unwrap(); + let Some(res) = right.next().await else { + panic!(); + }; + assert_eq!(res, b"hello"); + } +} diff --git a/src/util.rs b/src/util.rs index c99ff9c..5f243f2 100644 --- a/src/util.rs +++ b/src/util.rs @@ -2,11 +2,12 @@ use blake2::{ digest::{typenum::U32, FixedOutput, Update}, Blake2bMac, }; -use std::convert::TryInto; -use std::io::{Error, ErrorKind}; +use std::{ + convert::TryInto, + io::{Error, ErrorKind}, +}; -use crate::constants::DISCOVERY_NS_BUF; -use crate::DiscoveryKey; +use crate::{constants::DISCOVERY_NS_BUF, DiscoveryKey}; /// Calculate the discovery key of a key. /// @@ -29,9 +30,8 @@ pub(crate) fn map_channel_err(err: async_channel::SendError) -> Error { } pub(crate) const UINT_24_LENGTH: usize = 3; - #[inline] -pub(crate) fn wrap_uint24_le(data: &Vec) -> Vec { +pub(crate) fn wrap_uint24_le(data: &[u8]) -> Vec { let mut buf: Vec = vec![0; 3]; let n = data.len(); write_uint24_le(n, &mut buf); @@ -47,6 +47,7 @@ pub(crate) fn write_uint24_le(n: usize, buf: &mut [u8]) { } #[inline] +/// Read uint24 from the given `buffer` as a `u64` pub(crate) fn stat_uint24_le(buffer: &[u8]) -> Option<(usize, u64)> { if buffer.len() >= 3 { let len = diff --git a/src/writer.rs b/src/writer.rs deleted file mode 100644 index e3cc5da..0000000 --- a/src/writer.rs +++ /dev/null @@ -1,173 +0,0 @@ -use crate::crypto::EncryptCipher; -use crate::message::{Encoder, Frame}; - -use futures_lite::{ready, AsyncWrite}; -use std::collections::VecDeque; -use std::fmt; -use std::io::Result; -use std::pin::Pin; -use std::task::{Context, Poll}; - -const BUF_SIZE: usize = 1024 * 64; - -#[derive(Debug)] -pub(crate) enum Step { - Flushing, - Writing, - Processing, -} - -pub(crate) struct WriteState { - queue: VecDeque, - buf: Vec, - current_frame: Option, - start: usize, - end: usize, - cipher: Option, - step: Step, -} - -impl fmt::Debug for WriteState { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("WriteState") - .field("queue (len)", &self.queue.len()) - .field("step", &self.step) - .field("buf (len)", &self.buf.len()) - .field("current_frame", &self.current_frame) - .field("start", &self.start) - .field("end", &self.end) - .field("cipher", &self.cipher.is_some()) - .finish() - } -} - -impl WriteState { - pub(crate) fn new() -> Self { - Self { - queue: VecDeque::new(), - buf: vec![0u8; BUF_SIZE], - current_frame: None, - start: 0, - end: 0, - cipher: None, - step: Step::Processing, - } - } - - pub(crate) fn queue_frame(&mut self, frame: F) - where - F: Into, - { - self.queue.push_back(frame.into()) - } - - pub(crate) fn try_queue_direct(&mut self, frame: &mut T) -> Result { - let promised_len = frame.encoded_len()?; - let padded_promised_len = self.safe_encrypted_len(promised_len); - if self.buf.len() < padded_promised_len { - self.buf.resize(padded_promised_len, 0u8); - } - if padded_promised_len > self.remaining() { - return Ok(false); - } - let actual_len = frame.encode(&mut self.buf[self.end..])?; - if actual_len != promised_len { - panic!( - "encoded_len() did not return that right size, expected={promised_len}, actual={actual_len}" - ); - } - self.advance(padded_promised_len)?; - Ok(true) - } - - pub(crate) fn can_park_frame(&self) -> bool { - self.current_frame.is_none() - } - - pub(crate) fn park_frame(&mut self, frame: F) - where - F: Into, - { - if self.current_frame.is_none() { - self.current_frame = Some(frame.into()) - } - } - - fn advance(&mut self, n: usize) -> Result<()> { - let end = self.end + n; - - let encrypted_end = if let Some(ref mut cipher) = self.cipher { - self.end + cipher.encrypt(&mut self.buf[self.end..end])? - } else { - end - }; - - self.end = encrypted_end; - Ok(()) - } - - pub(crate) fn upgrade_with_encrypt_cipher(&mut self, encrypt_cipher: EncryptCipher) { - self.cipher = Some(encrypt_cipher); - } - - fn remaining(&self) -> usize { - self.buf.len() - self.end - } - - fn pending(&self) -> usize { - self.end - self.start - } - - pub(crate) fn poll_send( - &mut self, - cx: &mut Context<'_>, - mut writer: &mut W, - ) -> Poll> - where - W: AsyncWrite + Unpin, - { - loop { - self.step = match self.step { - Step::Processing => { - if self.current_frame.is_none() && !self.queue.is_empty() { - self.current_frame = self.queue.pop_front(); - } - - if let Some(mut frame) = self.current_frame.take() { - if !self.try_queue_direct(&mut frame)? { - self.current_frame = Some(frame); - } - } - - if self.pending() == 0 { - return Poll::Ready(Ok(())); - } - Step::Writing - } - Step::Writing => { - let n = ready!( - Pin::new(&mut writer).poll_write(cx, &self.buf[self.start..self.end]) - )?; - self.start += n; - if self.start == self.end { - self.start = 0; - self.end = 0; - } - Step::Flushing - } - Step::Flushing => { - ready!(Pin::new(&mut writer).poll_flush(cx))?; - Step::Processing - } - } - } - } - - fn safe_encrypted_len(&self, encoded_len: usize) -> usize { - if let Some(cipher) = &self.cipher { - cipher.safe_encrypted_len(encoded_len) - } else { - encoded_len - } - } -} diff --git a/tests/_util.rs b/tests/_util.rs index 9d0f9bf..78c89e4 100644 --- a/tests/_util.rs +++ b/tests/_util.rs @@ -1,13 +1,24 @@ use async_std::net::TcpStream; -use async_std::prelude::*; -use async_std::task::{self, JoinHandle}; -use futures_lite::io::{AsyncRead, AsyncWrite}; +use futures_lite::{ + io::{AsyncRead, AsyncWrite}, + StreamExt, +}; use hypercore_protocol::{Channel, DiscoveryKey, Duplex, Event, Protocol, ProtocolBuilder}; use instant::Duration; use std::io; +use tokio::{io::DuplexStream, task::JoinHandle}; + +type TokioDuplex = tokio_util::compat::Compat; + +pub(crate) fn duplex(channel_size: usize) -> (TokioDuplex, TokioDuplex) { + use tokio_util::compat::TokioAsyncReadCompatExt as _; + let (left, right) = tokio::io::duplex(channel_size); + (left.compat(), right.compat()) +} pub type MemoryProtocol = Protocol>; -pub async fn create_pair_memory() -> io::Result<(MemoryProtocol, MemoryProtocol)> { + +pub fn create_pair_memory() -> (MemoryProtocol, MemoryProtocol) { let (ar, bw) = sluice::pipe::pipe(); let (br, aw) = sluice::pipe::pipe(); @@ -15,24 +26,23 @@ pub async fn create_pair_memory() -> io::Result<(MemoryProtocol, MemoryProtocol) let b = ProtocolBuilder::new(false); let a = a.connect_rw(ar, aw); let b = b.connect_rw(br, bw); - Ok((a, b)) + (a, b) } -pub type TcpProtocol = Protocol; -pub async fn create_pair_tcp() -> io::Result<(TcpProtocol, TcpProtocol)> { - let (stream_a, stream_b) = tcp::pair().await?; - let a = ProtocolBuilder::new(true).connect(stream_a); - let b = ProtocolBuilder::new(false).connect(stream_b); +pub async fn create_pair_memory2() -> io::Result<(Protocol, Protocol)> { + let (left, right) = duplex(1024 * 1024); + let a = ProtocolBuilder::new(true); + let b = ProtocolBuilder::new(false); + let a = a.connect(left); + let b = b.connect(right); Ok((a, b)) } -pub fn next_event( - mut proto: Protocol, -) -> impl Future, io::Result)> +pub fn next_event(mut proto: Protocol) -> JoinHandle<(Protocol, io::Result)> where IO: AsyncRead + AsyncWrite + Send + Unpin + 'static, { - task::spawn(async move { + tokio::task::spawn(async move { let e1 = proto.next().await; let e1 = e1.unwrap(); (proto, e1) @@ -62,7 +72,7 @@ pub fn drive_until_channel( where IO: AsyncRead + AsyncWrite + Send + Unpin + 'static, { - task::spawn(async move { + tokio::task::spawn(async move { while let Some(event) = proto.next().await { let event = event?; if let Event::Channel(channel) = event { @@ -76,30 +86,10 @@ where }) } -pub mod tcp { - use async_std::net::{TcpListener, TcpStream}; - use async_std::prelude::*; - use async_std::task; - use std::io::{Error, ErrorKind, Result}; - pub async fn pair() -> Result<(TcpStream, TcpStream)> { - let address = "localhost:9999"; - let listener = TcpListener::bind(&address).await?; - let mut incoming = listener.incoming(); - - let connect_task = task::spawn(async move { TcpStream::connect(&address).await }); - - let server_stream = incoming.next().await; - let server_stream = - server_stream.ok_or_else(|| Error::new(ErrorKind::Other, "Stream closed"))?; - let server_stream = server_stream?; - let client_stream = connect_task.await?; - Ok((server_stream, client_stream)) - } -} - -const RETRY_TIMEOUT: u64 = 100_u64; -const NO_RESPONSE_TIMEOUT: u64 = 1000_u64; +#[allow(unused)] pub async fn wait_for_localhost_port(port: u32) { + const RETRY_TIMEOUT: u64 = 100_u64; + const NO_RESPONSE_TIMEOUT: u64 = 1000_u64; loop { let timeout = async_std::future::timeout( Duration::from_millis(NO_RESPONSE_TIMEOUT), diff --git a/tests/basic.rs b/tests/basic.rs index 8a99c7e..d713937 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -1,26 +1,26 @@ -#![allow(dead_code, unused_imports)] - -use async_std::net::TcpStream; -use async_std::prelude::*; -use async_std::task; -use futures_lite::io::{AsyncRead, AsyncWrite}; -use hypercore_protocol::{discovery_key, Channel, Event, Message, Protocol, ProtocolBuilder}; -use hypercore_protocol::{schema::*, DiscoveryKey}; +use _util::{ + create_pair_memory, create_pair_memory2, drive_until_channel, event_channel, + event_discovery_key, next_event, +}; +use futures_lite::StreamExt; +use hypercore_protocol::{discovery_key, schema::*, DiscoveryKey, Event, Message}; use std::io; -use test_log::test; +use tokio::task; mod _util; -use _util::*; -#[test(async_std::test)] +#[tokio::test] async fn basic_protocol() -> anyhow::Result<()> { - // env_logger::init(); - let (proto_a, proto_b) = create_pair_memory().await?; + let (proto_a, proto_b) = create_pair_memory2().await?; let next_a = next_event(proto_a); let next_b = next_event(proto_b); - let (mut proto_a, event_a) = next_a.await; - let (proto_b, event_b) = next_b.await; + let (mut proto_a, event_a) = next_a.await?; + let (proto_b, event_b) = next_b.await?; + + //let (a, b) = join(next_a, next_b).await; + //let (mut proto_a, event_a) = a?; + //let (proto_b, event_b) = b?; assert!(matches!(event_a, Ok(Event::Handshake(_)))); assert!(matches!(event_b, Ok(Event::Handshake(_)))); @@ -35,18 +35,18 @@ async fn basic_protocol() -> anyhow::Result<()> { let next_a = next_event(proto_a); let next_b = next_event(proto_b); - let (mut proto_b, event_b) = next_b.await; + let (mut proto_b, event_b) = next_b.await?; assert!(matches!(event_b, Ok(Event::DiscoveryKey(_)))); assert_eq!(event_discovery_key(event_b.unwrap()), discovery_key(&key)); proto_b.open(key).await?; let next_b = next_event(proto_b); - let (proto_b, event_b) = next_b.await; + let (proto_b, event_b) = next_b.await?; assert!(matches!(event_b, Ok(Event::Channel(_)))); let mut channel_b = event_channel(event_b.unwrap()); - let (proto_a, event_a) = next_a.await; + let (proto_a, event_a) = next_a.await?; assert!(matches!(event_a, Ok(Event::Channel(_)))); let mut channel_a = event_channel(event_a.unwrap()); @@ -68,8 +68,8 @@ async fn basic_protocol() -> anyhow::Result<()> { channel_a.close().await?; - let (_, event_a) = next_a.await; - let (_, event_b) = next_b.await; + let (_, event_a) = next_a.await?; + let (_, event_b) = next_b.await?; assert!(matches!(event_a, Ok(Event::Close(_)))); assert!(matches!(event_b, Ok(Event::Close(_)))); @@ -78,9 +78,9 @@ async fn basic_protocol() -> anyhow::Result<()> { Ok(()) } -#[test(async_std::test)] +#[tokio::test] async fn open_close_channels() -> anyhow::Result<()> { - let (mut proto_a, mut proto_b) = create_pair_memory().await?; + let (mut proto_a, mut proto_b) = create_pair_memory(); let key1 = [0u8; 32]; let key2 = [1u8; 32]; @@ -91,8 +91,8 @@ async fn open_close_channels() -> anyhow::Result<()> { let next_a = drive_until_channel(proto_a); let next_b = drive_until_channel(proto_b); - let (mut proto_a, mut channel_a1) = next_a.await?; - let (mut proto_b, mut channel_b1) = next_b.await?; + let (mut proto_a, mut channel_a1) = next_a.await??; + let (mut proto_b, mut channel_b1) = next_b.await??; proto_a.open(key2).await?; proto_b.open(key2).await?; @@ -100,8 +100,8 @@ async fn open_close_channels() -> anyhow::Result<()> { let next_a = drive_until_channel(proto_a); let next_b = drive_until_channel(proto_b); - let (proto_a, mut channel_a2) = next_a.await?; - let (proto_b, mut channel_b2) = next_b.await?; + let (proto_a, mut channel_a2) = next_a.await??; + let (proto_b, mut channel_b2) = next_b.await??; eprintln!( "got channels: {:?}", @@ -119,8 +119,8 @@ async fn open_close_channels() -> anyhow::Result<()> { let next_a = next_event(proto_a); let next_b = next_event(proto_b); - let (mut proto_a, ev_a) = next_a.await; - let (mut proto_b, ev_b) = next_b.await; + let (mut proto_a, ev_a) = next_a.await?; + let (mut proto_b, ev_b) = next_b.await?; let ev_a = ev_a?; let ev_b = ev_b?; eprintln!("next a: {ev_a:?}"); @@ -165,7 +165,6 @@ async fn open_close_channels() -> anyhow::Result<()> { assert_eq!(msg_b, Some(want(0, 10))); eprintln!("all good!"); - Ok(()) } diff --git a/tests/js/mod.rs b/tests/js/mod.rs index 8894b3d..b8cd6ec 100644 --- a/tests/js/mod.rs +++ b/tests/js/mod.rs @@ -1,8 +1,10 @@ use anyhow::Result; use instant::Duration; -use std::fs::{create_dir_all, remove_dir_all, remove_file}; -use std::path::Path; -use std::process::Command; +use std::{ + fs::{create_dir_all, remove_dir_all, remove_file}, + path::Path, + process::Command, +}; #[cfg(feature = "async-std")] use async_std::{ diff --git a/tests/js_interop.rs b/tests/js_interop.rs index d703734..d81a812 100644 --- a/tests/js_interop.rs +++ b/tests/js_interop.rs @@ -1,44 +1,36 @@ +pub mod _util; +#[path = "../src/test_utils.rs"] +mod test_utils; + use _util::wait_for_localhost_port; use anyhow::Result; +#[cfg(feature = "tokio")] +use async_compat::CompatExt; use futures::Future; use futures_lite::stream::StreamExt; -use hypercore::SigningKey; use hypercore::{ - Hypercore, HypercoreBuilder, PartialKeypair, RequestBlock, RequestUpgrade, Storage, + Hypercore, HypercoreBuilder, PartialKeypair, RequestBlock, RequestUpgrade, SigningKey, Storage, VerifyingKey, PUBLIC_KEY_LENGTH, SECRET_KEY_LENGTH, }; use instant::Duration; -use std::fmt::Debug; -use std::path::Path; -use std::sync::Arc; -use std::sync::Once; - -#[cfg(feature = "tokio")] -use async_compat::CompatExt; -#[cfg(feature = "async-std")] -use async_std::{ - fs::{metadata, File}, - io::{prelude::BufReadExt, BufReader, BufWriter, WriteExt}, - net::{TcpListener, TcpStream}, - sync::Mutex, - task::{self, sleep}, - test as async_test, +use std::{ + fmt::Debug, + path::Path, + sync::{Arc, Once}, }; -use test_log::test; #[cfg(feature = "tokio")] use tokio::{ fs::{metadata, File}, io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter}, net::{TcpListener, TcpStream}, sync::Mutex, - task, test as async_test, + task, time::sleep, }; +use tracing::instrument; -use hypercore_protocol::schema::*; -use hypercore_protocol::{discovery_key, Channel, Event, Message, ProtocolBuilder}; +use hypercore_protocol::{discovery_key, schema::*, Channel, Event, Message, ProtocolBuilder}; -pub mod _util; mod js; use js::{cleanup, install, js_run_client, js_start_server, prepare_test_set}; @@ -49,6 +41,7 @@ fn init() { cleanup(); install(); }); + test_utils::log(); } const TEST_SET_NODE_CLIENT_NODE_SERVER: &str = "ncns"; @@ -59,65 +52,64 @@ const TEST_SET_SERVER_WRITER: &str = "sw"; const TEST_SET_CLIENT_WRITER: &str = "cw"; const TEST_SET_SIMPLE: &str = "simple"; -#[test(async_test)] -#[cfg_attr(not(feature = "js_interop_tests"), ignore)] -async fn js_interop_ncns_simple_server_writer() -> Result<()> { - js_interop_ncns_simple(true, 8101).await?; +#[tokio::test] +#[cfg_attr(not(feature = "js_tests"), ignore)] +async fn ncns_server_writer() -> Result<()> { + ncns(true, 8101).await?; Ok(()) } -#[test(async_test)] -#[cfg_attr(not(feature = "js_interop_tests"), ignore)] -async fn js_interop_ncns_simple_client_writer() -> Result<()> { - js_interop_ncns_simple(false, 8102).await?; +#[tokio::test] +#[cfg_attr(not(feature = "js_tests"), ignore)] +async fn ncns_client_writer() -> Result<()> { + ncns(false, 8102).await?; Ok(()) } -#[test(async_test)] -#[cfg_attr(not(feature = "js_interop_tests"), ignore)] -async fn js_interop_rcns_simple_server_writer() -> Result<()> { - js_interop_rcns_simple(true, 8103).await?; +#[tokio::test] +#[cfg_attr(not(feature = "js_tests"), ignore)] +async fn rcns_server_writer() -> Result<()> { + rcns(true, 8103).await?; Ok(()) } -#[test(async_test)] -//#[cfg_attr(not(feature = "js_interop_tests"), ignore)] -#[ignore] // FIXME this tests hangs sporadically -async fn js_interop_rcns_simple_client_writer() -> Result<()> { - js_interop_rcns_simple(false, 8104).await?; +#[tokio::test] +#[cfg_attr(not(feature = "js_tests"), ignore)] +async fn rcns_client_writer() -> Result<()> { + rcns(false, 8104).await?; Ok(()) } -#[test(async_test)] -#[cfg_attr(not(feature = "js_interop_tests"), ignore)] -async fn js_interop_ncrs_simple_server_writer() -> Result<()> { - js_interop_ncrs_simple(true, 8105).await?; +#[tokio::test] +#[cfg_attr(not(feature = "js_tests"), ignore)] +async fn ncrs_server_writer() -> Result<()> { + ncrs(true, 8105).await?; Ok(()) } -#[test(async_test)] -#[cfg_attr(not(feature = "js_interop_tests"), ignore)] -async fn js_interop_ncrs_simple_client_writer() -> Result<()> { - js_interop_ncrs_simple(false, 8106).await?; +#[tokio::test] +#[cfg_attr(not(feature = "js_tests"), ignore)] +async fn ncrs_client_writer() -> Result<()> { + ncrs(false, 8106).await?; Ok(()) } -#[test(async_test)] -#[cfg_attr(not(feature = "js_interop_tests"), ignore)] -async fn js_interop_rcrs_simple_server_writer() -> Result<()> { - js_interop_rcrs_simple(true, 8107).await?; +#[tokio::test] +#[cfg_attr(not(feature = "js_tests"), ignore)] +async fn rcrs_server_writer() -> Result<()> { + rcrs(true, 8107).await?; Ok(()) } -#[test(async_test)] -//#[cfg_attr(not(feature = "js_interop_tests"), ignore)] -#[ignore] // FIXME this tests hangs sporadically -async fn js_interop_rcrs_simple_client_writer() -> Result<()> { - js_interop_rcrs_simple(false, 8108).await?; +#[tokio::test] +//#[cfg_attr(not(feature = "js_tests"), ignore)] +//#[ignore] // FIXME this tests hangs sporadically +async fn rcrs_client_writer() -> Result<()> { + rcrs(false, 8108).await?; Ok(()) } -async fn js_interop_ncns_simple(server_writer: bool, port: u32) -> Result<()> { +async fn ncns(server_writer: bool, port: u32) -> Result<()> { init(); let test_set = format!( "{}_{}_{}", @@ -156,7 +148,7 @@ async fn js_interop_ncns_simple(server_writer: bool, port: u32) -> Result<()> { Ok(()) } -async fn js_interop_rcns_simple(server_writer: bool, port: u32) -> Result<()> { +async fn rcns(server_writer: bool, port: u32) -> Result<()> { init(); let test_set = format!( "{}_{}_{}", @@ -195,13 +187,14 @@ async fn js_interop_rcns_simple(server_writer: bool, port: u32) -> Result<()> { &result_path, ) .await?; + dbg!(); assert_result(result_path, item_count, item_size, data_char).await?; drop(server); Ok(()) } -async fn js_interop_ncrs_simple(server_writer: bool, port: u32) -> Result<()> { +async fn ncrs(server_writer: bool, port: u32) -> Result<()> { init(); let test_set = format!( "{}_{}_{}", @@ -247,7 +240,7 @@ async fn js_interop_ncrs_simple(server_writer: bool, port: u32) -> Result<()> { Ok(()) } -async fn js_interop_rcrs_simple(server_writer: bool, port: u32) -> Result<()> { +async fn rcrs(server_writer: bool, port: u32) -> Result<()> { init(); let test_set = format!( "{}_{}_{}", @@ -339,11 +332,15 @@ async fn run_client( data_path: &str, result_path: &str, ) -> Result<()> { + dbg!(); let hypercore = if is_writer { + dbg!(); create_writer_hypercore(data_count, data_size, data_char, data_path).await? } else { + dbg!(); create_reader_hypercore(data_path).await? }; + dbg!(); let hypercore_wrapper = HypercoreWrapper::from_disk_hypercore( hypercore, if is_writer { @@ -352,7 +349,9 @@ async fn run_client( Some(result_path.to_string()) }, ); + dbg!(); tcp_client(port, on_replication_connection, Arc::new(hypercore_wrapper)).await?; + dbg!(); Ok(()) } @@ -441,63 +440,38 @@ pub fn get_test_key_pair(include_secret: bool) -> PartialKeypair { PartialKeypair { public, secret } } -#[cfg(feature = "async-std")] -async fn on_replication_connection( - stream: TcpStream, - is_initiator: bool, - hypercore: Arc, -) -> Result<()> { - let mut protocol = ProtocolBuilder::new(is_initiator).connect(stream); - while let Some(event) = protocol.next().await { - let event = event?; - match event { - Event::Handshake(_) => { - if is_initiator { - protocol.open(*hypercore.key()).await?; - } - } - Event::DiscoveryKey(dkey) => { - if hypercore.discovery_key == dkey { - protocol.open(*hypercore.key()).await?; - } else { - panic!("Invalid discovery key"); - } - } - Event::Channel(channel) => { - hypercore.on_replication_peer(channel); - } - Event::Close(_dkey) => { - break; - } - _ => {} - } - } - Ok(()) -} - #[cfg(feature = "tokio")] +#[instrument(skip_all)] async fn on_replication_connection( stream: TcpStream, is_initiator: bool, hypercore: Arc, ) -> Result<()> { + use tracing::info; + let mut protocol = ProtocolBuilder::new(is_initiator).connect(stream.compat()); + let mut channel_opened = false; while let Some(event) = protocol.next().await { let event = event?; match event { Event::Handshake(_) => { - if is_initiator { + info!("Event::Handshake"); + if is_initiator && !channel_opened { protocol.open(*hypercore.key()).await?; + channel_opened = true; } } Event::DiscoveryKey(dkey) => { - if hypercore.discovery_key == dkey { + info!("Event::DiscoveryKey"); + if hypercore.discovery_key == dkey && !channel_opened { protocol.open(*hypercore.key()).await?; + channel_opened = true; } else { panic!("Invalid discovery key"); } } Event::Channel(channel) => { + info!("Event::Channel is_initiator = {is_initiator}"); hypercore.on_replication_peer(channel); } Event::Close(_dkey) => { @@ -647,6 +621,8 @@ async fn on_replication_message( start: info.length, length: peer_state.remote_length - info.length, }), + manifest: false, + priority: 0, }; messages.push(Message::Request(msg)); } @@ -758,6 +734,8 @@ async fn on_replication_message( block: Some(request_block), seek: None, upgrade: None, + manifest: false, + priority: 0, })); } let exit = if synced { @@ -767,7 +745,10 @@ async fn on_replication_message( for i in 0..new_info.contiguous_length { let value = String::from_utf8(hypercore.get(i).await?.unwrap()).unwrap(); let line = format!("{} {}\n", i, value); - writer.write(line.as_bytes()).await?; + let n_written = writer.write(line.as_bytes()).await?; + if line.len() != n_written { + panic!("Couldn't write all write all bytse"); + } } writer.flush().await?; true @@ -847,39 +828,6 @@ impl RustServer { } } -impl Drop for RustServer { - fn drop(&mut self) { - #[cfg(feature = "async-std")] - if let Some(handle) = self.handle.take() { - task::block_on(handle.cancel()); - } - } -} - -#[cfg(feature = "async-std")] -pub async fn tcp_server( - port: u32, - onconnection: impl Fn(TcpStream, bool, C) -> F + Send + Sync + Copy + 'static, - context: C, -) -> Result<()> -where - F: Future> + Send, - C: Clone + Send + 'static, -{ - let listener = TcpListener::bind(&format!("localhost:{}", port)).await?; - let mut incoming = listener.incoming(); - while let Some(Ok(stream)) = incoming.next().await { - let context = context.clone(); - let _peer_addr = stream.peer_addr().unwrap(); - task::spawn(async move { - onconnection(stream, false, context) - .await - .expect("Should return ok"); - }); - } - Ok(()) -} - #[cfg(feature = "tokio")] pub async fn tcp_server( port: u32,