Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 31 additions & 23 deletions adb_client/src/message_devices/adb_message_device.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use rand::RngExt;
use std::{path::Path, time::Duration};
use std::path::Path;

use crate::{
Result, RustADBError,
adb_transport::ADBTransport,
message_devices::{
adb_message_transport::ADBMessageTransport,
adb_multiplexer::ADBMessageMultiplexer,
adb_session::ADBSession,
adb_transport_message::{
ADBTransportMessage, AUTH_RSAPUBLICKEY, AUTH_SIGNATURE, AUTH_TOKEN,
Expand All @@ -20,7 +22,7 @@ use crate::{
/// Structure is totally agnostic over which transport is truly used.
#[derive(Debug)]
pub(crate) struct ADBMessageDevice<T: ADBMessageTransport> {
transport: T,
multiplexer: ADBMessageMultiplexer<T>,
}

impl<T: ADBMessageTransport> ADBMessageDevice<T> {
Expand All @@ -36,19 +38,17 @@ impl<T: ADBMessageTransport> ADBMessageDevice<T> {
ADBRsaKey::new_random()?
};

let mut message_device = Self { transport };
let mut message_device = Self {
multiplexer: ADBMessageMultiplexer::new(transport),
};
message_device.connect(&private_key)?;

Ok(message_device)
}

pub(crate) fn get_transport_mut(&mut self) -> &mut T {
&mut self.transport
}

/// Send initial connect
fn connect(&mut self, private_key: &ADBRsaKey) -> Result<()> {
self.get_transport_mut().connect()?;
self.multiplexer.connect()?;

let message = ADBTransportMessage::try_new(
MessageCommand::Cnxn,
Expand All @@ -57,21 +57,21 @@ impl<T: ADBMessageTransport> ADBMessageDevice<T> {
format!("host::{}\0", env!("CARGO_PKG_NAME")).as_bytes(),
)?;

self.get_transport_mut().write_message(message)?;
self.multiplexer.write_message(message)?;

let message = self.get_transport_mut().read_message()?;
let message = self.multiplexer.read_authentication_message()?;

// Check if a client is requesting a secure connection and upgrade it if necessary
match message.header().command() {
MessageCommand::Stls => {
self.get_transport_mut()
self.multiplexer
.write_message(ADBTransportMessage::try_new(
MessageCommand::Stls,
1,
0,
&[],
)?)?;
self.get_transport_mut().upgrade_connection()?;
self.multiplexer.upgrade_connection()?;
log::debug!("Connection successfully upgraded from TCP to TLS");
Ok(())
}
Expand Down Expand Up @@ -116,15 +116,17 @@ impl<T: ADBMessageTransport> ADBMessageDevice<T> {

let message = ADBTransportMessage::try_new(MessageCommand::Auth, AUTH_SIGNATURE, 0, &sign)?;

self.transport.write_message(message)?;
self.multiplexer.write_message(message)?;

let received_response = self.transport.read_message()?;
let received_response = self.multiplexer.read_authentication_message()?;

if received_response.header().command() == MessageCommand::Cnxn {
log::info!(
"Authentication OK, device info {}",
String::from_utf8(received_response.into_payload())?
);
// Authentication is OK, we can now consider sessions
self.multiplexer.set_authenticated();
return Ok(());
}

Expand All @@ -134,11 +136,11 @@ impl<T: ADBMessageTransport> ADBMessageDevice<T> {
let message =
ADBTransportMessage::try_new(MessageCommand::Auth, AUTH_RSAPUBLICKEY, 0, &pubkey)?;

self.transport.write_message(message)?;
self.multiplexer.write_message(message)?;

let response = self
.transport
.read_message_with_timeout(Duration::from_secs(10))
.multiplexer
.read_authentication_message()
.and_then(|message| {
message.assert_command(MessageCommand::Cnxn)?;
Ok(message)
Expand All @@ -148,6 +150,9 @@ impl<T: ADBMessageTransport> ADBMessageDevice<T> {
"Authentication OK, device info {}",
String::from_utf8(response.into_payload())?
);
// Authentication is OK, we can now consider sessions
self.multiplexer.set_authenticated();

Ok(())
}

Expand All @@ -165,9 +170,12 @@ impl<T: ADBMessageTransport> ADBMessageDevice<T> {
0,
cmd.to_string().as_bytes(),
)?;
self.transport.write_message(message)?;
log::debug!("here");
self.multiplexer.write_message(message)?;
log::debug!("after");

let response = self.transport.read_message()?;
let response = self.multiplexer.read_message(local_id)?;
log::debug!("got message from multiplexer");

if response.header().command() != MessageCommand::Okay {
return Err(RustADBError::ADBRequestFailed(format!(
Expand All @@ -184,13 +192,13 @@ impl<T: ADBMessageTransport> ADBMessageDevice<T> {
}

Ok(ADBSession::new(
self.transport.clone(),
self.multiplexer.clone(),
local_id,
response.header().arg0(),
))
}

pub(crate) fn end_transaction(&mut self, session: &mut ADBSession<T>) -> Result<()> {
pub(crate) fn end_transaction(session: &mut ADBSession<T>) -> Result<()> {
let quit_buffer = MessageSubcommand::Quit.with_arg(0u32);
session.send_and_expect_okay(ADBTransportMessage::try_new(
MessageCommand::Write,
Expand All @@ -199,14 +207,14 @@ impl<T: ADBMessageTransport> ADBMessageDevice<T> {
&quit_buffer.encode(),
)?)?;

let _discard_close = self.transport.read_message()?;
let _discard_close = session.read_message()?;
Ok(())
}
}

impl<T: ADBMessageTransport> Drop for ADBMessageDevice<T> {
fn drop(&mut self) {
// Best effort here
let _ = self.get_transport_mut().disconnect();
let _ = self.multiplexer.disconnect();
}
}
163 changes: 163 additions & 0 deletions adb_client/src/message_devices/adb_multiplexer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
use std::{
collections::{HashMap, VecDeque},
sync::{
Arc, RwLock,
atomic::{AtomicBool, Ordering},
},
thread::JoinHandle,
time::Duration,
};

use crate::{
Result,
adb_transport::ADBTransport,
message_devices::{
adb_message_transport::ADBMessageTransport, adb_transport_message::ADBTransportMessage,
},
};

/// Internal structure handling multiplexing of messages over various sessions.
///
/// It spawns a thread reading the transport, and implements [`ADBMessageTransport`] to read / write messages..
#[derive(Clone, Debug)]
pub(crate) struct ADBMessageMultiplexer<T: ADBMessageTransport> {
transport: T,
authenticated_data: Arc<RwLock<HashMap<u32, VecDeque<ADBTransportMessage>>>>,
unauthenticated_data: Arc<RwLock<VecDeque<ADBTransportMessage>>>,
handle: Option<Arc<JoinHandle<Result<()>>>>,
authenticated: Arc<AtomicBool>,
}

impl<T: ADBMessageTransport> ADBMessageMultiplexer<T> {
pub fn new(transport: T) -> Self {
Self {
transport,
authenticated_data: Arc::default(),
unauthenticated_data: Arc::default(),
handle: None,
authenticated: Arc::new(AtomicBool::new(false)),
}
}

pub fn upgrade_connection(&mut self) -> Result<()> {
self.transport.upgrade_connection()
}

pub fn set_authenticated(&mut self) {
log::debug!("multiplexer: authenticated");
self.authenticated.store(true, Ordering::Relaxed);
}

pub(crate) fn read_authentication_message(&mut self) -> Result<ADBTransportMessage> {
self.read_message_with_timeout(None, Duration::from_secs(u64::MAX))
}

pub(crate) fn read_message(&mut self, local_id: u32) -> Result<ADBTransportMessage> {
self.read_message_with_timeout(Some(local_id), Duration::from_secs(u64::MAX))
}

pub(crate) fn write_message(&mut self, message: ADBTransportMessage) -> Result<()> {
self.write_message_with_timeout(message, Duration::from_secs(2))
}

pub fn read_message_with_timeout(
&mut self,
local_id: Option<u32>,
read_timeout: std::time::Duration,
) -> Result<ADBTransportMessage> {
loop {
if let Some(local_id) = local_id {
let mut rw_data = self.authenticated_data.write()?;

if let Some(d) = rw_data.get_mut(&local_id)
&& let Some(v) = d.pop_front()
{
return Ok(v);
}
} else {
let mut rw_data = self.unauthenticated_data.write()?;
if let Some(v) = rw_data.pop_front() {
return Ok(v);
}
}

std::thread::sleep(Duration::from_millis(100));
}
}

pub fn write_message_with_timeout(
&mut self,
message: ADBTransportMessage,
write_timeout: std::time::Duration,
) -> Result<()> {
self.transport
.write_message_with_timeout(message, write_timeout)
}
}

impl<T: ADBMessageTransport> ADBTransport for ADBMessageMultiplexer<T> {
fn connect(&mut self) -> crate::Result<()> {
self.transport.connect()?;

let data = self.authenticated_data.clone();
let unauth_data = self.unauthenticated_data.clone();
let mut transport = self.transport.clone();
let authenticated = self.authenticated.clone();

// Spawn a thread responsible of continously reading the underlying transport
// and pushing messages to the internal data structure
let handle = std::thread::spawn(move || {
loop {
log::trace!("waiting for incoming message");
let message = transport.read_message()?;

let remote_id = message.header().arg1();

if authenticated.load(Ordering::Relaxed) {
log::trace!("got new authenticated message for {remote_id} session");

let mut rw_data = data.write()?;
let new_value = if let Some(mut d) = rw_data.remove(&remote_id) {
d.push_back(message);
d
} else {
let mut v = VecDeque::new();
v.push_back(message);
v
};
rw_data.insert(remote_id, new_value);
} else {
log::trace!("got new pre-authenticated message");
let mut rw_unauth_data = unauth_data.write()?;
rw_unauth_data.push_back(message);
}
}
});

self.handle = Some(Arc::new(handle));

Ok(())
}

fn disconnect(&mut self) -> crate::Result<()> {
// Empty both internal data storage structures
{
let mut rw_data = self.authenticated_data.write()?;
*rw_data = HashMap::default();
}

{
let mut rw_unauth_data = self.unauthenticated_data.write()?;
*rw_unauth_data = VecDeque::default();
}

if let Some(handle) = self.handle.take()
&& let Some(handle) = Arc::into_inner(handle)
&& let Err(e) = handle.join()
{
log::error!("Error joining multiplexer thread: {e:?}");
}

Ok(())
}
}
Loading
Loading