Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,6 @@ contracts/src/flattened/
*.temp-checkpoint.json

# intellij
*.idea/
*.idea/

./snapshots
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ itertools = "0.14.0"
rayon = "1.11.0"
async-stream = "0.3.6"
serde = "1.0"

chrono = { version = "0.4", features = ["serde"] }
serde_json = "1.0"

[dev-dependencies]
rand = "0.9.2"
Expand Down
70 changes: 70 additions & 0 deletions examples/hooks.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
use alloy::{
primitives::{address, Address},
providers::ProviderBuilder,
rpc::client::ClientBuilder,
transports::layers::{RetryBackoffLayer, ThrottleLayer},
};
use alloy_provider::WsConnect;
use amms::{
amms::{uniswap_v2::UniswapV2Pool, uniswap_v3::UniswapV3Pool},
state_space::{hooks::StateHook, StateSpaceBuilder},
};
use futures::StreamExt;
use std::sync::Arc;
use tracing::Level;

#[tokio::main]
async fn main() -> eyre::Result<()> {
tracing_subscriber::fmt().with_max_level(Level::INFO).init();
let rpc_endpoint = std::env::var("ETHEREUM_PROVIDER_WS")?;

let client = ClientBuilder::default()
.layer(ThrottleLayer::new(500))
.layer(RetryBackoffLayer::new(5, 200, 330))
.ws(WsConnect::new(rpc_endpoint))
.await?;

let sync_provider = Arc::new(ProviderBuilder::new().connect_client(client));

let amms = vec![
UniswapV2Pool::new(address!("B4e16d0168e52d35CaCD2c6185b44281Ec28C9Dc"), 300).into(),
UniswapV3Pool::new(address!("88e6A0c2dDD26FEEb64F039a2c41296FcB3f5640")).into(),
];

let state_space_manager = StateSpaceBuilder::new(sync_provider.clone())
.with_amms(amms)
// Enables background snapshotting
// See: [`hooks::SnapshotConfig`] for configuration options
.with_snapshot_enabled(None)
// start syncing from a snapshotted state
// .with_snapshot_path("./snapshots/<Your snapshot>") --- IGNORE ---
// Registers hooks to be called after every block is processed
.with_hooks(vec![simple_counter()])
.sync()
.await?;

let mut stream = state_space_manager.subscribe().await?;
let _res = stream.next().await.iter().take(20);

std::fs::remove_dir_all("./snapshots")?;

Ok(())
}

pub fn simple_counter() -> StateHook<Vec<Address>> {
let count = Arc::new(std::sync::atomic::AtomicU64::new(0));

let hook = move |updated_pools: &Vec<Address>| {
let total = count.fetch_add(
updated_pools.len() as u64,
std::sync::atomic::Ordering::Relaxed,
);

println!(
"total AMMs updated so far: {}",
total + updated_pools.len() as u64
);
};

Arc::new(hook)
}
2 changes: 1 addition & 1 deletion src/amms/abi/WethValueInPools.json

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions src/amms/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ pub enum AMMError {
UnrecognizedEventSignature(FixedBytes<32>),
#[error(transparent)]
JoinError(#[from] tokio::task::JoinError),
#[error("Snapshot Error: {0}")]
SnapshotError(#[from] serde_json::Error),
#[error("Snapshot Error: {0}")]
SnapshotIOError(#[from] std::io::Error),
}

#[derive(Error, Debug)]
Expand Down
8 changes: 4 additions & 4 deletions src/state_space/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ use std::collections::HashMap;
use crate::amms::amm::{AutomatedMarketMaker, AMM};
use arraydeque::ArrayDeque;

#[derive(Debug)]
#[derive(Debug, Clone)]

pub struct StateChangeCache<const CAP: usize> {
oldest_block: u64,
cache: ArrayDeque<StateChange, CAP>,
pub oldest_block: u64,
pub cache: ArrayDeque<StateChange, CAP>,
}

impl<const CAP: usize> Default for StateChangeCache<CAP> {
Expand Down Expand Up @@ -87,7 +87,7 @@ impl<const CAP: usize> StateChangeCache<CAP> {

// NOTE: we can probably make this more efficient and create a state change struct for each amm rather than
// cloning each amm when caching
#[derive(Debug, Clone)]
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct StateChange {
pub state_change: Vec<AMM>,
pub block_number: u64,
Expand Down
1 change: 0 additions & 1 deletion src/state_space/filters/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ where
async fn filter(&self, amms: Vec<AMM>) -> Result<Vec<AMM>, AMMError> {
let pool_infos = amms
.iter()
.cloned()
.map(|amm| {
let pool_address = amm.address();
let pool_type = match amm {
Expand Down
173 changes: 173 additions & 0 deletions src/state_space/hooks.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
use std::{
collections::HashMap,
path::PathBuf,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
time::Duration,
};

use alloy::primitives::Address;
use tokio::sync::RwLock;
use tracing::error;

use crate::state_space::{SerializableStateSpace, StateSpace};

/// A hook that can be used to observe, and react to state changes.
pub type StateHook<T> = Arc<dyn Fn(&T) + Send + Sync + 'static>;

/// Registry for managing state change hooks.
#[derive(Clone)]
pub struct HookRegistry<T: Send + Sync + Clone + 'static> {
inner: Arc<RwLock<HookRegistryInner<T>>>,
next_id: Arc<AtomicU64>,
}

#[derive(Clone)]
struct HookRegistryInner<T: Send + Sync + Clone + 'static> {
hooks: HashMap<usize, StateHook<T>>,
}

impl<T: Send + Sync + Clone + 'static> HookRegistry<T> {
/// Creates a new hook registry.
pub fn new(hooks: Vec<StateHook<T>>) -> Self {
Self {
inner: Arc::new(RwLock::new(HookRegistryInner {
hooks: hooks.into_iter().enumerate().collect(),
})),
next_id: Arc::new(AtomicU64::new(1)),
}
}

/// Registers a hook and returns a handle that unregisters on drop.
pub async fn register(&self, hook: StateHook<T>) -> HookHandle<T> {
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
self.inner.write().await.hooks.insert(id as usize, hook);
HookHandle {
id,
reg: self.clone(),
}
}

/// Notifies all registered hooks of a state change.
pub async fn notify(&self, state_change: &T) {
let mut hooks: Vec<StateHook<T>> = {
let guard = self.inner.read().await;
guard.hooks.values().cloned().collect()
};

for hook in hooks.iter_mut() {
(hook)(state_change);
}
}

async fn unregister(&self, id: u64) {
self.inner.write().await.hooks.remove(&(id as usize));
}
}

/// RAII handle for a registered hook. Unregisters the hook on drop.
#[derive(Clone)]
pub struct HookHandle<T: Send + Sync + Clone + 'static> {
id: u64,
reg: HookRegistry<T>,
}

impl<T: Send + Sync + Clone + 'static> Drop for HookHandle<T> {
fn drop(&mut self) {
let id = self.id;
let reg = self.reg.clone();
tokio::spawn(async move {
reg.unregister(id).await;
});
}
}

#[derive(Clone, Debug)]
pub struct SnapshotConfig {
/// Interval between consecutive snapshots.
pub interval: Duration,
/// Directory to store snapshots.
pub directory: PathBuf,
/// Maximum number of snapshots to retain.
pub max_snapshots: usize,
}

impl Default for SnapshotConfig {
fn default() -> Self {
Self {
interval: Duration::from_secs(60),
directory: PathBuf::from("./snapshots"),
max_snapshots: 5,
}
}
}

impl SnapshotConfig {
/// Creates a new snapshot configuration.
pub fn new(interval: Duration, directory: PathBuf, max_snapshots: usize) -> Self {
Self {
interval,
directory,
max_snapshots,
}
}

pub async fn into_state_hook(self, state: Arc<RwLock<StateSpace>>) -> StateHook<Vec<Address>> {
let interval = self.interval;
let max_snapshots = self.max_snapshots;
let directory = self.directory.clone();

let start = std::time::Instant::now();

let hook = move |_: &Vec<Address>| {
if start.elapsed() < interval {
return;
}

let timestamp = chrono::Utc::now().format("%Y%m%d%H%M%S");
let filename = format!("snapshot_{}.json", timestamp);
let path = directory.join(filename);

std::fs::create_dir_all(&directory)
.inspect_err(
|e| error!(target: "snapshot", "Failed to create snapshot directory: {}", e),
)
.ok();

let state = futures::executor::block_on(state.read()).clone();
let state: SerializableStateSpace = state.into();

let Ok(file) = std::fs::File::create(&path) else {
error!(target: "snapshot", "Failed to create snapshot file: {:?}", path);
return;
};

let writer = std::io::BufWriter::new(file);
let Ok(_) = serde_json::to_writer_pretty(writer, &state) else {
error!(target: "snapshot", "Failed to write snapshot to file: {:?}", path);
return;
};

let Ok(mut entries) = std::fs::read_dir(&directory)
.map(|rd| rd.filter_map(Result::ok).collect::<Vec<_>>())
else {
error!(target: "snapshot", "Failed to read snapshot directory: {:?}", directory);
return;
};

entries.sort_by_key(|e| e.metadata().and_then(|m| m.modified()).ok());
if entries.len() > max_snapshots {
entries.drain(..entries.len() - max_snapshots).for_each(|entry| {
let Ok(_) = std::fs::remove_file(entry.path()) else {
error!(target: "snapshot", "Failed to remove old snapshot: {:?}", entry.path());
return;
};
});
}
};

Arc::new(hook)
}
}
Loading
Loading