diff --git a/Cargo.lock b/Cargo.lock index 2c78e2031..03e989477 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2029,6 +2029,7 @@ dependencies = [ "bytes", "copyvec", "criterion", + "derive_more", "either", "env_logger 0.11.5", "ethereum-types", @@ -5296,6 +5297,16 @@ dependencies = [ "tracing-core", ] +[[package]] +name = "tracing-serde" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc6b213177105856957181934e4920de57730fc69bf42c37ee5bb664d406d9e1" +dependencies = [ + "serde", + "tracing-core", +] + [[package]] name = "tracing-subscriber" version = "0.3.18" @@ -5306,12 +5317,15 @@ dependencies = [ "nu-ansi-term", "once_cell", "regex", + "serde", + "serde_json", "sharded-slab", "smallvec", "thread_local", "tracing", "tracing-core", "tracing-log", + "tracing-serde", ] [[package]] @@ -5897,6 +5911,8 @@ version = "0.1.0" dependencies = [ "alloy", "alloy-compat", + "alloy-primitives", + "alloy-serde", "anyhow", "async-stream", "axum", diff --git a/Cargo.toml b/Cargo.toml index cadf0a13a..c8307e3b3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,8 @@ alloy = { version = '0.3.0', default-features = false, features = [ "transport-http", "rpc-types-debug", ] } +alloy-primitives = "0.8.0" +alloy-serde = "0.3.0" anyhow = "1.0.86" async-stream = "0.3.5" axum = "0.7.5" @@ -46,6 +48,7 @@ clap = { version = "4.5.7", features = ["derive", "env"] } alloy-compat = "0.1.0" copyvec = "0.2.0" criterion = "0.5.1" +derive_more = { version = "1.0.0", features = ["deref", "deref_mut"] } dotenvy = "0.15.7" either = "1.12.0" enum-as-inner = "0.6.0" @@ -86,6 +89,7 @@ ruint = "1.12.3" serde = "1.0.203" serde_json = "1.0.118" serde_path_to_error = "0.1.16" +serde_with = "3.8.1" serde-big-array = "0.5.1" sha2 = "0.10.8" static_assertions = "1.1.0" @@ -94,8 +98,8 @@ thiserror = "1.0.61" tiny-keccak = "2.0.2" tokio = { version = "1.38.0", features = ["full"] } tower = "0.4" -tracing = "0.1" -tracing-subscriber = { version = "0.3", features = 
["env-filter"] } +tracing = { version = "0.1", features = ["attributes"] } +tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] } trybuild = "1.0" u4 = "0.1.0" uint = "0.9.5" diff --git a/evm_arithmetization/Cargo.toml b/evm_arithmetization/Cargo.toml index 22a40013f..3a083a01d 100644 --- a/evm_arithmetization/Cargo.toml +++ b/evm_arithmetization/Cargo.toml @@ -21,6 +21,7 @@ anyhow.workspace = true bitvec.workspace = true bytes.workspace = true copyvec.workspace = true +derive_more.workspace = true either.workspace = true env_logger.workspace = true ethereum-types.workspace = true diff --git a/evm_arithmetization/benches/fibonacci_25m_gas.rs b/evm_arithmetization/benches/fibonacci_25m_gas.rs index 2242b3049..8b752919e 100644 --- a/evm_arithmetization/benches/fibonacci_25m_gas.rs +++ b/evm_arithmetization/benches/fibonacci_25m_gas.rs @@ -192,6 +192,7 @@ fn prepare_setup() -> anyhow::Result> { prev_hashes: vec![H256::default(); 256], cur_hash: H256::default(), }, + batch_jumpdest_tables: vec![None], }) } diff --git a/evm_arithmetization/src/cpu/kernel/asm/core/jumpdest_analysis.asm b/evm_arithmetization/src/cpu/kernel/asm/core/jumpdest_analysis.asm index 842d794a9..7da6ad924 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/core/jumpdest_analysis.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/core/jumpdest_analysis.asm @@ -1,14 +1,14 @@ -// Set @SEGMENT_JUMPDEST_BITS to one between positions [init_pos, final_pos], +// Set @SEGMENT_JUMPDEST_BITS to one between positions [init_pos, final_pos], // for the given context's code. 
// Pre stack: init_pos, ctx, final_pos, retdest // Post stack: (empty) global verify_path_and_write_jumpdest_table: SWAP2 DUP2 - ADD // final_addr + ADD // final_addr = final_pos + ctx, i = init_pos // stack: final_addr, ctx, i, retdest SWAP2 - ADD // init_addr + ADD // init_addr = i + ctx loop: // stack: i, final_pos, retdest DUP2 DUP2 EQ // i == final_pos @@ -16,12 +16,12 @@ loop: DUP2 DUP2 GT // i > final_pos %jumpi(proof_not_ok) - // stack: i, final_pos, retdest + // stack: i, final_pos, retdest DUP1 MLOAD_GENERAL // SEGMENT_CODE == 0 // stack: opcode, i, final_pos, retdest - DUP1 + DUP1 // Slightly more efficient than `%eq_const(0x5b) ISZERO` PUSH 0x5b SUB @@ -141,7 +141,7 @@ global write_table_if_jumpdest: // stack: proof_prefix_addr, ctx, jumpdest, retdest // If we are here we need to check that the next 32 bytes are not // PUSHXX for XX > 32 - n, n in {1, 32}. - + %stack (proof_prefix_addr, ctx) -> (ctx, proof_prefix_addr, 32, proof_prefix_addr, ctx) @@ -214,11 +214,11 @@ return: // non-deterministically guessing the sequence of jumpdest // addresses used during program execution within the current context. // For each jumpdest address we also non-deterministically guess -// a proof, which is another address in the code such that +// a proof, which is another address in the code such that // is_jumpdest doesn't abort, when the proof is at the top of the stack // an the jumpdest address below. If that's the case we set the // corresponding bit in @SEGMENT_JUMPDEST_BITS to 1. 
-// +// // stack: ctx, code_len, retdest // stack: (empty) global jumpdest_analysis: @@ -235,10 +235,11 @@ global jumpdest_analysis_end: %pop2 JUMP check_proof: + // stack: address + 1, ctx, code_len, retdest // stack: address, ctx, code_len, retdest DUP3 DUP2 %assert_le %decrement - // stack: proof, ctx, code_len, retdest + // stack: address, ctx, code_len, retdest DUP2 SWAP1 // stack: address, ctx, ctx, code_len, retdest // We read the proof @@ -246,7 +247,7 @@ check_proof: // stack: proof, address, ctx, ctx, code_len, retdest %write_table_if_jumpdest // stack: ctx, code_len, retdest - + %jump(jumpdest_analysis) %macro jumpdest_analysis diff --git a/evm_arithmetization/src/cpu/kernel/interpreter.rs b/evm_arithmetization/src/cpu/kernel/interpreter.rs index d9745504e..7d832d3b1 100644 --- a/evm_arithmetization/src/cpu/kernel/interpreter.rs +++ b/evm_arithmetization/src/cpu/kernel/interpreter.rs @@ -5,10 +5,12 @@ //! the future execution and generate nondeterministically the corresponding //! jumpdest table, before the actual CPU carries on with contract execution. 
+use core::option::Option::None; use std::collections::{BTreeSet, HashMap}; use anyhow::anyhow; use ethereum_types::{BigEndianHash, U256}; +use keccak_hash::H256; use log::Level; use mpt_trie::partial_trie::PartialTrie; use plonky2::hash::hash_types::RichField; @@ -19,8 +21,10 @@ use crate::cpu::columns::CpuColumnsView; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; use crate::generation::debug_inputs; +use crate::generation::jumpdest::{Context, JumpDestTableProcessed, JumpDestTableWitness}; use crate::generation::linked_list::LinkedListsPtrs; use crate::generation::mpt::{load_linked_lists_and_txn_and_receipt_mpts, TrieRootPtrs}; +use crate::generation::prover_input::get_proofs_and_jumpdests; use crate::generation::rlp::all_rlp_prover_inputs_reversed; use crate::generation::state::{ all_ger_prover_inputs, all_withdrawals_prover_inputs_reversed, GenerationState, @@ -54,7 +58,9 @@ pub(crate) struct Interpreter { /// The interpreter will halt only if the current context matches /// halt_context pub(crate) halt_context: Option, - jumpdest_table: HashMap>, + /// A table of call contexts and the JUMPDEST offsets that they jumped to. + // todo + jumpdest_table_interpreter: HashMap>, /// `true` if the we are currently carrying out a jumpdest analysis. 
pub(crate) is_jumpdest_analysis: bool, /// Holds the value of the clock: the clock counts the number of operations @@ -73,10 +79,9 @@ pub(crate) struct Interpreter { pub(crate) fn simulate_cpu_and_get_user_jumps( final_label: &str, state: &GenerationState, -) -> Option>> { - match state.jumpdest_table { - Some(_) => None, - None => { +) -> Option<(JumpDestTableProcessed, JumpDestTableWitness)> { + match state.jumpdest_tables { + _ => { let halt_pc = KERNEL.global_labels[final_label]; let initial_context = state.registers.context; let mut interpreter = Interpreter::new_with_state_and_halt_condition( @@ -90,20 +95,23 @@ pub(crate) fn simulate_cpu_and_get_user_jumps( let _ = interpreter.run(); - log::trace!("jumpdest table = {:?}", interpreter.jumpdest_table); + log::trace!( + "jumpdest table = {:?}", + interpreter.jumpdest_table_interpreter + ); let clock = interpreter.get_clock(); - interpreter + let (jdtp, jdtw) = interpreter .generation_state - .set_jumpdest_analysis_inputs(interpreter.jumpdest_table); + .get_jumpdest_analysis_inputs(interpreter.jumpdest_table_interpreter.clone()); log::debug!( "Simulated CPU for jumpdest analysis halted after {:?} cycles.", clock ); - - interpreter.generation_state.jumpdest_table + // interpreter.generation_state.jumpdest_table = Some(jdtp.clone()); + Some((jdtp, jdtw)) } } } @@ -116,7 +124,8 @@ pub(crate) struct ExtraSegmentData { pub(crate) withdrawal_prover_inputs: Vec, pub(crate) ger_prover_inputs: Vec, pub(crate) trie_root_ptrs: TrieRootPtrs, - pub(crate) jumpdest_table: Option>>, + // todo + pub(crate) jumpdest_table: Vec>, pub(crate) access_lists_ptrs: LinkedListsPtrs, pub(crate) state_ptrs: LinkedListsPtrs, pub(crate) next_txn_index: usize, @@ -152,6 +161,57 @@ pub(crate) fn set_registers_and_run( interpreter.run() } +/// Computes the JUMPDEST proofs for each context. +/// +/// # Arguments +/// +/// - `jumpdest_table_rpc`: The raw table received from RPC. 
+/// - `code_map`: The corresponding database of contract code used in the trace. +/// +/// # Output +/// +/// Returns a [`JumpDestTableProcessed`]. +pub(crate) fn get_jumpdest_analysis_inputs_rpc( + jumpdest_table_rpc: &JumpDestTableWitness, + code_map: &HashMap>, + prev_max_batch_ctx: usize, +) -> JumpDestTableProcessed { + let ctx_proofs = jumpdest_table_rpc + .iter() + .flat_map(|(code_addr, ctx_jumpdests)| { + let code = if code_map.contains_key(code_addr) { + &code_map[code_addr] + } else { + &vec![] + }; + prove_context_jumpdests(code, ctx_jumpdests) + }) + .collect(); + JumpDestTableProcessed::new_with_ctx_offset(ctx_proofs, prev_max_batch_ctx) +} + +/// Orchestrates the proving of all contexts in a specific bytecode. +/// +/// # Arguments +/// +/// - `code`: The bytecode for the context `ctx`. +/// - `ctx`: Map from `ctx` to its list of `JUMPDEST` offsets. +/// +/// # Outputs +/// +/// Returns a [`HashMap`] from `ctx` to [`Vec`] of proofs. Each proof is a +/// pair. +fn prove_context_jumpdests(code: &[u8], ctx: &Context) -> HashMap> { + ctx.iter() + .map(|(&ctx, jumpdests)| { + let proofs = jumpdests.last().map_or(Vec::default(), |&largest_address| { + get_proofs_and_jumpdests(code, largest_address, jumpdests.clone()) + }); + (ctx, proofs) + }) + .collect() +} + impl Interpreter { /// Returns an instance of `Interpreter` given `GenerationInputs`, and /// assuming we are initializing with the `KERNEL` code. 
@@ -164,6 +224,7 @@ impl Interpreter { debug_inputs(inputs); let mut result = Self::new(initial_offset, initial_stack, max_cpu_len_log); + result.generation_state.jumpdest_tables = vec![None; inputs.batch_jumpdest_tables.len()]; result.initialize_interpreter_state(inputs); result } @@ -182,7 +243,8 @@ impl Interpreter { halt_context: None, #[cfg(test)] opcode_count: HashMap::new(), - jumpdest_table: HashMap::new(), + // todo + jumpdest_table_interpreter: HashMap::new(), is_jumpdest_analysis: false, clock: 0, max_cpu_len_log, @@ -214,7 +276,8 @@ impl Interpreter { halt_context: Some(halt_context), #[cfg(test)] opcode_count: HashMap::new(), - jumpdest_table: HashMap::new(), + // check + jumpdest_table_interpreter: HashMap::new(), is_jumpdest_analysis: true, clock: 0, max_cpu_len_log, @@ -473,14 +536,15 @@ impl Interpreter { .content } + // what happens here? pub(crate) fn add_jumpdest_offset(&mut self, offset: usize) { if let Some(jumpdest_table) = self - .jumpdest_table + .jumpdest_table_interpreter .get_mut(&self.generation_state.registers.context) { jumpdest_table.insert(offset); } else { - self.jumpdest_table.insert( + self.jumpdest_table_interpreter.insert( self.generation_state.registers.context, BTreeSet::from([offset]), ); diff --git a/evm_arithmetization/src/cpu/kernel/tests/add11.rs b/evm_arithmetization/src/cpu/kernel/tests/add11.rs index 683987244..e46886585 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/add11.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/add11.rs @@ -193,6 +193,7 @@ fn test_add11_yml() { prev_hashes: vec![H256::default(); 256], cur_hash: H256::default(), }, + batch_jumpdest_tables: vec![None], }; let initial_stack = vec![]; @@ -370,6 +371,7 @@ fn test_add11_yml_with_exception() { prev_hashes: vec![H256::default(); 256], cur_hash: H256::default(), }, + batch_jumpdest_tables: vec![None], }; let initial_stack = vec![]; diff --git a/evm_arithmetization/src/cpu/kernel/tests/core/jumpdest_analysis.rs 
b/evm_arithmetization/src/cpu/kernel/tests/core/jumpdest_analysis.rs index f2d00ede5..b047a4de1 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/core/jumpdest_analysis.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/core/jumpdest_analysis.rs @@ -10,13 +10,17 @@ use plonky2::hash::hash_types::RichField; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::interpreter::Interpreter; use crate::cpu::kernel::opcodes::{get_opcode, get_push_opcode}; +use crate::generation::jumpdest::JumpDestTableProcessed; use crate::memory::segments::Segment; use crate::witness::memory::MemoryAddress; use crate::witness::operation::CONTEXT_SCALING_FACTOR; impl Interpreter { pub(crate) fn set_jumpdest_analysis_inputs(&mut self, jumps: HashMap>) { - self.generation_state.set_jumpdest_analysis_inputs(jumps); + let (jdtp, _jdtw) = self.generation_state.get_jumpdest_analysis_inputs(jumps); + + let tx_in_batch_idx = self.generation_state.next_txn_index - 1; + self.generation_state.jumpdest_tables[tx_in_batch_idx] = Some(jdtp); } pub(crate) fn get_jumpdest_bit(&self, offset: usize) -> U256 { @@ -101,12 +105,16 @@ fn test_jumpdest_analysis() -> Result<()> { ), )])); - // The `set_jumpdest_analysis_inputs` method is never used. + let tx_in_batch_idx = interpreter.generation_state.next_txn_index - 1; + // TODO The `set_jumpdest_analysis_inputs` method is never used. assert_eq!( - interpreter.generation_state.jumpdest_table, + interpreter.generation_state.jumpdest_tables[tx_in_batch_idx], // Context 3 has jumpdest 1, 5, 7. All have proof 0 and hence // the list [proof_0, jumpdest_0, ... 
] is [0, 1, 0, 5, 0, 7, 8, 40] - Some(HashMap::from([(3, vec![0, 1, 0, 5, 0, 7, 8, 40])])) + Some(JumpDestTableProcessed::new(HashMap::from([( + 3, + vec![0, 1, 0, 5, 0, 7, 8, 40] + )]))) ); // Run jumpdest analysis with context = 3 @@ -121,14 +129,13 @@ fn test_jumpdest_analysis() -> Result<()> { .push(U256::from(CONTEXT) << CONTEXT_SCALING_FACTOR) .expect("The stack should not overflow"); + let tx_in_batch_idx = interpreter.generation_state.next_txn_index - 1; // We need to manually pop the jumpdest_table and push its value on the top of // the stack - interpreter - .generation_state - .jumpdest_table + interpreter.generation_state.jumpdest_tables[tx_in_batch_idx] .as_mut() .unwrap() - .get_mut(&CONTEXT) + .try_get_batch_ctx_mut(&CONTEXT) .unwrap() .pop(); interpreter @@ -175,7 +182,10 @@ fn test_packed_verification() -> Result<()> { let mut interpreter: Interpreter = Interpreter::new(write_table_if_jumpdest, initial_stack.clone(), None); interpreter.set_code(CONTEXT, code.clone()); - interpreter.generation_state.jumpdest_table = Some(HashMap::from([(3, vec![1, 33])])); + let tx_in_batch_idx = interpreter.generation_state.next_txn_index - 1; + interpreter.generation_state.jumpdest_tables[tx_in_batch_idx] = Some( + JumpDestTableProcessed::new(HashMap::from([(3, vec![1, 33])])), + ); interpreter.run()?; @@ -188,7 +198,10 @@ fn test_packed_verification() -> Result<()> { let mut interpreter: Interpreter = Interpreter::new(write_table_if_jumpdest, initial_stack.clone(), None); interpreter.set_code(CONTEXT, code.clone()); - interpreter.generation_state.jumpdest_table = Some(HashMap::from([(3, vec![1, 33])])); + let tx_in_batch_idx = interpreter.generation_state.next_txn_index - 1; + interpreter.generation_state.jumpdest_tables[tx_in_batch_idx] = Some( + JumpDestTableProcessed::new(HashMap::from([(3, vec![1, 33])])), + ); assert!(interpreter.run().is_err()); diff --git a/evm_arithmetization/src/cpu/kernel/tests/init_exc_stop.rs 
b/evm_arithmetization/src/cpu/kernel/tests/init_exc_stop.rs index 2dea58b55..9000870ac 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/init_exc_stop.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/init_exc_stop.rs @@ -101,6 +101,7 @@ fn test_init_exc_stop() { cur_hash: H256::default(), }, ger_data: None, + batch_jumpdest_tables: vec![None], }; let initial_stack = vec![]; let initial_offset = KERNEL.global_labels["init"]; diff --git a/evm_arithmetization/src/generation/jumpdest.rs b/evm_arithmetization/src/generation/jumpdest.rs new file mode 100644 index 000000000..dc180c74d --- /dev/null +++ b/evm_arithmetization/src/generation/jumpdest.rs @@ -0,0 +1,264 @@ +//! EVM opcode 0x5B or 91 is [`JUMPDEST`] which encodes a valid offset that +//! opcodes `JUMP` and `JUMPI` can jump to. Jumps to non-[`JUMPDEST`] +//! instructions are invalid. During an execution a [`JUMPDEST`] may be visited +//! zero or more times. Offsets are measured in bytes with respect to the +//! beginning of some contract code, which is uniquely identified by its +//! `CodeHash`. Every time control flow switches to another contract through +//! a `CALL`-like instruction a new call context, `Context`, is created. Thus, +//! the triple (`CodeHash`, `Context`, `Offset`) uniquely identifies a visited +//! [`JUMPDEST`] offset of an execution. +//! +//! Since an operation like e.g. `PUSH 0x5B` does not encode a valid +//! [`JUMPDEST`] in its second byte, and `PUSH32 +//! 5B5B5B5B5B5B5B5B5B5B5B5B5B5B5B5B5B5B5B5B5B5B5B5B5B5B5B5B5B5B5B5B` does not +//! encode any valid [`JUMPDEST`] in bytes 1-32, some diligence must be +//! exercised when proving validity of jump operations. +//! +//! This module concerns itself with data structures for collecting these +//! offsets for each [`JUMPDEST`] that was visited during an execution, without +//! recording duplicates. The proofs that each of these offsets is not rendered +//! invalid by `PUSH1`-`PUSH32` in any of the previous 32 bytes are computed +//! 
later in `prove_context_jumpdests` on basis of these collections. +//! +//! [`JUMPDEST`]: https://www.evm.codes/?fork=cancun#5b + +use std::cmp::max; +use std::{ + collections::{BTreeSet, HashMap}, + fmt::Display, +}; + +use derive_more::derive::{Deref, DerefMut}; +use itertools::{sorted, Itertools}; +use keccak_hash::H256; +use serde::{Deserialize, Serialize}; + +/// Each `CodeHash` can be called one or more times, +/// each time creating a new `Context`. +/// Each `Context` will contain one or more offsets of `JUMPDEST`. +#[derive(PartialEq, Eq, Debug, Clone, Serialize, Deserialize, Default, Deref, DerefMut)] +pub struct Context(pub HashMap>); + +/// The result after proving a [`JumpDestTableWitness`]. +#[derive(PartialEq, Eq, Debug, Clone, Serialize, Deserialize, Default)] +pub(crate) struct JumpDestTableProcessed { + witness_contexts: HashMap>, + ctx_offset: usize, +} + +/// Map `CodeHash -> (Context -> [JumpDests])` +#[derive(PartialEq, Eq, Debug, Clone, Serialize, Deserialize, Default, Deref, DerefMut)] +pub struct JumpDestTableWitness(HashMap); + +impl Context { + // pub fn get(&self, ctx: usize) -> Option<&BTreeSet> { + // self.get(&ctx) + // } + + pub fn insert(&mut self, ctx: usize, offset_opt: Option) { + let context = self.entry(ctx).or_default(); + + if let Some(offset) = offset_opt { + context.insert(offset); + }; + } + + pub fn max_batch_ctx(&self) -> usize { + self.keys().max().copied().unwrap_or(0) + } +} + +impl JumpDestTableProcessed { + pub fn new(ctx_map: HashMap>) -> Self { + Self { + witness_contexts: ctx_map, + ctx_offset: 0, + } + } + + pub fn new_with_ctx_offset(ctx_map: HashMap>, ctx_offset: usize) -> Self { + Self { + witness_contexts: ctx_map, + ctx_offset, + } + } + + pub fn normalize(&self) -> Self { + let witness_contexts = self + .witness_contexts + .iter() + .map(|(ctx, offsets)| (ctx + self.ctx_offset, offsets.clone())) + .collect(); + + Self { + witness_contexts, + ctx_offset: 0, + } + } + + pub fn try_get_batch_ctx_mut(&mut 
self, batch_ctx: &usize) -> Option<&mut Vec> { + log::info!("query_ctx {}", batch_ctx,); + let witness_context = *batch_ctx - self.ctx_offset; + self.witness_contexts.get_mut(&witness_context) + } + + pub fn remove_batch_ctx(&mut self, batch_ctx: &usize) { + let witness_context = *batch_ctx - self.ctx_offset; + self.witness_contexts.remove(&witness_context); + } + + pub fn max_batch_ctx(&self) -> usize { + self.witness_contexts.keys().max().copied().unwrap_or(0) + } +} + +impl JumpDestTableWitness { + pub fn get(&self, code_hash: &H256) -> Option<&Context> { + self.0.get(code_hash) + } + + /// Insert `offset` into `ctx` under the corresponding `code_hash`. + /// Creates the required `ctx` keys and `code_hash`. Idempotent. + pub fn insert(&mut self, code_hash: H256, ctx: usize, offset_opt: Option) { + (*self) + .entry(code_hash) + .or_default() + .insert(ctx, offset_opt); + } + + pub fn extend(mut self, other: &Self, prev_max_ctx: usize) -> (Self, usize) { + let mut curr_max_ctx = prev_max_ctx; + + for (code_hash, ctx_tbl) in (*other).iter() { + for (ctx, jumpdests) in ctx_tbl.0.iter() { + let batch_ctx = prev_max_ctx + ctx; + curr_max_ctx = max(curr_max_ctx, batch_ctx); + + for offset in jumpdests { + self.insert(*code_hash, batch_ctx, Some(*offset)); + } + } + } + + (self, curr_max_ctx) + } + + pub fn merge<'a>(jdts: impl IntoIterator) -> (Self, usize) { + jdts.into_iter() + .fold((Default::default(), 0), |(acc, cnt), t| acc.extend(t, cnt)) + } + + /// Obtain the context within any `code_hash` with maximal numeric value. + pub fn max_batch_ctx(&self) -> usize { + self.values() + .map(|ctx| ctx.max_batch_ctx()) + .max() + .unwrap_or(0) + } +} + +// The following Display instances are added to make it easier to read diffs. 
+impl Display for JumpDestTableWitness { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "\n=== JumpDestTableWitness ===")?; + + for (code, ctxtbls) in &self.0 { + write!(f, "codehash: {:#x}\n{}", code, ctxtbls)?; + } + Ok(()) + } +} + +impl Display for Context { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let v: Vec<_> = self.0.iter().sorted().collect(); + for (ctx, offsets) in v.into_iter() { + write!(f, " ctx: {:>4}: [", ctx)?; + for offset in offsets { + write!(f, "{:#}, ", offset)?; + } + writeln!(f, "]")?; + } + Ok(()) + } +} + +impl Display for JumpDestTableProcessed { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "\n=== JumpDestTableProcessed ===")?; + + let v = sorted(self.witness_contexts.clone()); + for (ctx, code) in v { + writeln!(f, "ctx: {:?} {:?}", ctx, code)?; + } + Ok(()) + } +} + +impl FromIterator<(H256, usize, usize)> for JumpDestTableWitness { + fn from_iter>(iter: T) -> Self { + let mut jdtw = JumpDestTableWitness::default(); + for (code_hash, ctx, offset) in iter.into_iter() { + jdtw.insert(code_hash, ctx, Some(offset)); + } + jdtw + } +} + +#[cfg(test)] +mod test { + use std::collections::{BTreeSet, HashMap}; + + use keccak_hash::H256; + + use super::JumpDestTableWitness; + use crate::jumpdest::Context; + + #[test] + fn test_extend_from_iter() { + let code_hash = H256::default(); + + let ctx_map = vec![ + (code_hash, 1, 1), + (code_hash, 2, 2), + (code_hash, 42, 3), + (code_hash, 43, 4), + ]; + let table1 = JumpDestTableWitness::from_iter(ctx_map); + let table2 = table1.clone(); + + let jdts = [&table1, &table2]; + let (actual, max_ctx) = JumpDestTableWitness::merge(jdts); + + let ctx_map_merged = vec![ + (code_hash, 1, 1), + (code_hash, 2, 2), + (code_hash, 42, 3), + (code_hash, 43, 4), + (code_hash, 44, 1), + (code_hash, 45, 2), + (code_hash, 85, 3), + (code_hash, 86, 4), + ]; + let expected = 
JumpDestTableWitness::from_iter(ctx_map_merged); + + assert_eq!(86, max_ctx); + assert_eq!(expected, actual) + } + + #[test] + fn test_create_context() { + let code_hash = H256::default(); + let mut table1 = JumpDestTableWitness::default(); + table1.insert(code_hash, 42, None); + + let offsets = BTreeSet::::default(); + let mut ctx = HashMap::::default(); + ctx.insert(42, offsets); + let mut contexts = HashMap::::default(); + contexts.insert(code_hash, Context(ctx)); + let table2 = JumpDestTableWitness(contexts); + + assert_eq!(table1, table2); + } +} diff --git a/evm_arithmetization/src/generation/mod.rs b/evm_arithmetization/src/generation/mod.rs index 023e1f2ac..e03d3909d 100644 --- a/evm_arithmetization/src/generation/mod.rs +++ b/evm_arithmetization/src/generation/mod.rs @@ -3,6 +3,7 @@ use std::fmt::Display; use anyhow::anyhow; use ethereum_types::{Address, BigEndianHash, H256, U256}; +use jumpdest::JumpDestTableWitness; use keccak_hash::keccak; use log::error; use mpt_trie::partial_trie::{HashedPartialTrie, PartialTrie}; @@ -34,6 +35,7 @@ use crate::util::{h2u, u256_to_usize}; use crate::witness::memory::{MemoryAddress, MemoryChannel, MemoryState}; use crate::witness::state::RegistersState; +pub mod jumpdest; pub(crate) mod linked_list; pub mod mpt; pub(crate) mod prover_input; @@ -131,6 +133,10 @@ pub struct GenerationInputs { /// /// This is specific to `cdk-erigon`. pub ger_data: Option<(H256, H256)>, + + /// A table listing each JUMPDESTs reached in each call context under + /// associated code hash. + pub batch_jumpdest_tables: Vec>, } /// A lighter version of [`GenerationInputs`], which have been trimmed @@ -181,6 +187,10 @@ pub struct TrimmedGenerationInputs { /// The hash of the current block, and a list of the 256 previous block /// hashes. pub block_hashes: BlockHashes, + + /// A list of tables listing each JUMPDESTs reached in each call context + /// under associated code hash. 
+ pub jumpdest_table: Vec>, } #[derive(Clone, Debug, Deserialize, Serialize, Default)] @@ -255,6 +265,7 @@ impl GenerationInputs { burn_addr: self.burn_addr, block_metadata: self.block_metadata.clone(), block_hashes: self.block_hashes.clone(), + jumpdest_table: self.batch_jumpdest_tables.clone(), } } } diff --git a/evm_arithmetization/src/generation/prover_input.rs b/evm_arithmetization/src/generation/prover_input.rs index 704e2f4c6..efeefc563 100644 --- a/evm_arithmetization/src/generation/prover_input.rs +++ b/evm_arithmetization/src/generation/prover_input.rs @@ -6,10 +6,12 @@ use std::str::FromStr; use anyhow::{bail, Error, Result}; use ethereum_types::{BigEndianHash, H256, U256, U512}; use itertools::Itertools; +use keccak_hash::keccak; use num_bigint::BigUint; use plonky2::hash::hash_types::RichField; use serde::{Deserialize, Serialize}; +use super::jumpdest::{JumpDestTableProcessed, JumpDestTableWitness}; #[cfg(test)] use super::linked_list::testing::{LinkedList, ADDRESSES_ACCESS_LIST_LEN}; use super::linked_list::{ @@ -22,7 +24,9 @@ use crate::cpu::kernel::constants::cancun_constants::{ POINT_EVALUATION_PRECOMPILE_RETURN_VALUE, }; use crate::cpu::kernel::constants::context_metadata::ContextMetadata; -use crate::cpu::kernel::interpreter::simulate_cpu_and_get_user_jumps; +use crate::cpu::kernel::interpreter::{ + get_jumpdest_analysis_inputs_rpc, simulate_cpu_and_get_user_jumps, +}; use crate::curve_pairings::{bls381, CurveAff, CyclicGroup}; use crate::extension_tower::{FieldExt, Fp12, Fp2, BLS381, BLS_BASE, BLS_SCALAR, BN254, BN_BASE}; use crate::generation::prover_input::EvmField::{ @@ -40,6 +44,9 @@ use crate::witness::memory::MemoryAddress; use crate::witness::operation::CONTEXT_SCALING_FACTOR; use crate::witness::util::{current_context_peek, stack_peek}; +/// A set to hold contract code as byte vectors. +pub type CodeDb = BTreeSet>; + /// Prover input function represented as a scoped function name. 
/// Example: `PROVER_INPUT(ff::bn254_base::inverse)` is represented as /// `ProverInputFn([ff, bn254_base, inverse])`. @@ -79,7 +86,7 @@ impl GenerationState { fn run_end_of_txns(&mut self) -> Result { // Reset the jumpdest table before the next transaction. - self.jumpdest_table = None; + // self.jumpdest_table = None; let end = self.next_txn_index == self.inputs.txn_hashes.len(); if end { Ok(U256::one()) @@ -350,26 +357,43 @@ impl GenerationState { .ok_or(ProgramError::ProverInputError(OutOfGerData)) } - /// Returns the next used jump address. + /// Returns the next used jumpdest address. fn run_next_jumpdest_table_address(&mut self) -> Result { - let context = u256_to_usize(stack_peek(self, 0)? >> CONTEXT_SCALING_FACTOR)?; + let batch_context = u256_to_usize(stack_peek(self, 0)? >> CONTEXT_SCALING_FACTOR)?; - if self.jumpdest_table.is_none() { + log::info!( + "Current ctx {} current tx {}", + batch_context, + self.next_txn_index - 1 + ); + let tx_in_batch_idx = self.next_txn_index - 1; + + if self.jumpdest_tables[tx_in_batch_idx].is_none() { self.generate_jumpdest_table()?; } - let Some(jumpdest_table) = &mut self.jumpdest_table else { + let Some(jumpdest_table) = &mut self.jumpdest_tables[tx_in_batch_idx] else { return Err(ProgramError::ProverInputError( ProverInputError::InvalidJumpdestSimulation, )); }; - if let Some(ctx_jumpdest_table) = jumpdest_table.get_mut(&context) + if let Some(ctx_jumpdest_table) = jumpdest_table.try_get_batch_ctx_mut(&batch_context) && let Some(next_jumpdest_address) = ctx_jumpdest_table.pop() { + log::info!( + "run_next_jumpdest_table_address, ctx {:>5}, address {:>5}", + batch_context, + next_jumpdest_address + 1 + ); Ok((next_jumpdest_address + 1).into()) } else { - jumpdest_table.remove(&context); + log::info!( + "run_next_jumpdest_table_address, ctx {:>5}, address {:>5}", + batch_context, + 0 + ); + jumpdest_table.remove_batch_ctx(&batch_context); Ok(U256::zero()) } } @@ -377,15 +401,21 @@ impl GenerationState { /// Returns the 
proof for the last jump address. fn run_next_jumpdest_table_proof(&mut self) -> Result { let context = u256_to_usize(stack_peek(self, 1)? >> CONTEXT_SCALING_FACTOR)?; - let Some(jumpdest_table) = &mut self.jumpdest_table else { + let tx_in_batch_idx = self.next_txn_index - 1; + let Some(jumpdest_table) = &mut self.jumpdest_tables[tx_in_batch_idx] else { return Err(ProgramError::ProverInputError( ProverInputError::InvalidJumpdestSimulation, )); }; - if let Some(ctx_jumpdest_table) = jumpdest_table.get_mut(&context) + if let Some(ctx_jumpdest_table) = jumpdest_table.try_get_batch_ctx_mut(&context) && let Some(next_jumpdest_proof) = ctx_jumpdest_table.pop() { + log::info!( + "run_next_jumpdest_table_proof, ctx {:>5}, proof {:>5}", + context, + next_jumpdest_proof + ); Ok(next_jumpdest_proof.into()) } else { Err(ProgramError::ProverInputError( @@ -402,8 +432,18 @@ impl GenerationState { let address = u256_to_usize(stack_peek(self, 0)?)?; let closest_opcode_addr = get_closest_opcode_address(&code, address); Ok(if closest_opcode_addr < 32 { + log::info!( + "run_next_non_jumpdest_proof address {:>5}, closest_opcode_addr {:>5}, returns 0", + address, + closest_opcode_addr, + ); U256::zero() } else { + log::info!( + "run_next_non_jumpdest_proof address, {:>5}, closest_opcode_addr {:>5}", + address, + closest_opcode_addr, + ); closest_opcode_addr.into() }) } @@ -784,23 +824,70 @@ impl GenerationState { /// Simulate the user's code and store all the jump addresses with their /// respective contexts. 
fn generate_jumpdest_table(&mut self) -> Result<(), ProgramError> { + let tx_in_batch_idx = self.next_txn_index - 1; + let prev_max_witness_ctx: usize = self + .jumpdest_tables + .get(tx_in_batch_idx - 1) + .map(|x| x.as_ref()) + .flatten() + .map(|x| x.max_batch_ctx()) + .unwrap_or(0); + log::info!("TX BATCH: {:#?}", self.inputs.txn_hashes); + log::info!("BATCH LEN: {}", self.inputs.txn_hashes.len()); + log::info!("TXN_NUM_BEFORE: {}", self.inputs.txn_number_before); + log::info!("Maximum CTX in previous tx: {}", prev_max_witness_ctx); + log::info!("TXIDX: {}", tx_in_batch_idx); + log::info!("TX HASH: {:#?}", self.inputs.txn_hashes[tx_in_batch_idx]); + let rpcw = self.inputs.jumpdest_table[tx_in_batch_idx].clone(); + let rpcp: Option = + rpcw.as_ref().map(|jdt: &JumpDestTableWitness| { + get_jumpdest_analysis_inputs_rpc( + jdt, + &self.inputs.contract_code, + prev_max_witness_ctx, + ) + .normalize() + }); + log::info!("RPCW {:#?}", &rpcw); + log::info!("RPCP {:#?}", &rpcp); + if rpcp.is_some() { + self.jumpdest_tables[tx_in_batch_idx] = rpcp; + return Ok(()); + } // Simulate the user's code and (unnecessarily) part of the kernel code, // skipping the validate table call - self.jumpdest_table = simulate_cpu_and_get_user_jumps("terminate_common", self); + self.jumpdest_tables[tx_in_batch_idx] = None; + let (simp, simw) = simulate_cpu_and_get_user_jumps("terminate_common", &*self) + .ok_or(ProgramError::ProverInputError(InvalidJumpdestSimulation))?; + log::info!("SIMW {:#?}", &simw); + log::info!("SIMP {:#?}", &simp); + + if let Some(rpcp) = rpcp { + if &rpcp != &simp { + log::warn!("MISMATCH"); + dbg!(Some(&rpcp), Some(&simp)); + } + } - Ok(()) + self.jumpdest_tables[tx_in_batch_idx] = Some(simp); + return Ok(()); } /// Given a HashMap containing the contexts and the jumpdest addresses, /// compute their respective proofs, by calling /// `get_proofs_and_jumpdests` - pub(crate) fn set_jumpdest_analysis_inputs( - &mut self, + pub(crate) fn 
get_jumpdest_analysis_inputs( + &self, jumpdest_table: HashMap>, - ) { - self.jumpdest_table = Some(HashMap::from_iter(jumpdest_table.into_iter().map( + ) -> (JumpDestTableProcessed, JumpDestTableWitness) { + let mut jdtw = JumpDestTableWitness::default(); + let jdtp = JumpDestTableProcessed::new(HashMap::from_iter(jumpdest_table.into_iter().map( |(ctx, jumpdest_table)| { let code = self.get_code(ctx).unwrap(); + let code_hash = keccak(code.clone()); + for offset in jumpdest_table.clone() { + jdtw.insert(code_hash, ctx, Some(offset)); + } if let Some(&largest_address) = jumpdest_table.last() { let proofs = get_proofs_and_jumpdests(&code, largest_address, jumpdest_table); (ctx, proofs) @@ -809,6 +896,7 @@ impl GenerationState { } }, ))); + (jdtp, jdtw) } pub(crate) fn get_current_code(&self) -> Result, ProgramError> { @@ -855,7 +943,7 @@ impl GenerationState { /// for which none of the previous 32 bytes in the code (including opcodes /// and pushed bytes) is a PUSHXX and the address is in its range. It returns /// a vector of even size containing proofs followed by their addresses. 
-fn get_proofs_and_jumpdests( +pub(crate) fn get_proofs_and_jumpdests( code: &[u8], largest_address: usize, jumpdest_table: std::collections::BTreeSet, diff --git a/evm_arithmetization/src/generation/segments.rs b/evm_arithmetization/src/generation/segments.rs index 1df63af29..6bdcaad32 100644 --- a/evm_arithmetization/src/generation/segments.rs +++ b/evm_arithmetization/src/generation/segments.rs @@ -81,7 +81,8 @@ fn build_segment_data( .clone(), ger_prover_inputs: interpreter.generation_state.ger_prover_inputs.clone(), trie_root_ptrs: interpreter.generation_state.trie_root_ptrs.clone(), - jumpdest_table: interpreter.generation_state.jumpdest_table.clone(), + // todo verify + jumpdest_table: interpreter.generation_state.jumpdest_tables.clone(), next_txn_index: interpreter.generation_state.next_txn_index, access_lists_ptrs: interpreter.generation_state.access_lists_ptrs.clone(), state_ptrs: interpreter.generation_state.state_ptrs.clone(), diff --git a/evm_arithmetization/src/generation/state.rs b/evm_arithmetization/src/generation/state.rs index abe4b4f1a..c06f06dcc 100644 --- a/evm_arithmetization/src/generation/state.rs +++ b/evm_arithmetization/src/generation/state.rs @@ -8,6 +8,7 @@ use keccak_hash::keccak; use log::Level; use plonky2::hash::hash_types::RichField; +use super::jumpdest::JumpDestTableProcessed; use super::linked_list::LinkedListsPtrs; use super::mpt::TrieRootPtrs; use super::segments::GenerationSegmentData; @@ -100,6 +101,7 @@ pub(crate) trait State { } /// Returns the context in which the jumpdest analysis should end. + // this seems pointless fn get_halt_context(&self) -> Option { None } @@ -353,6 +355,7 @@ pub struct GenerationState { pub(crate) memory: MemoryState, pub(crate) traces: Traces, + /// In the batch / block?? 
pub(crate) next_txn_index: usize, /// Memory used by stale contexts can be pruned so proving segments can be @@ -386,8 +389,8 @@ pub struct GenerationState { /// "proof" for a jump destination is either 0 or an address i > 32 in /// the code (not necessarily pointing to an opcode) such that for every /// j in [i, i+32] it holds that code[j] < 0x7f - j + i. - pub(crate) jumpdest_table: Option>>, - + // jumpdest_table: Option, + pub(crate) jumpdest_tables: Vec>, /// Provides quick access to pointers that reference the location /// of either and account or a slot in the respective access list. pub(crate) access_lists_ptrs: LinkedListsPtrs, @@ -456,7 +459,7 @@ impl GenerationState { txn_root_ptr: 0, receipt_root_ptr: 0, }, - jumpdest_table: None, + jumpdest_tables: vec![], access_lists_ptrs: LinkedListsPtrs::default(), state_ptrs: LinkedListsPtrs::default(), ger_prover_inputs, @@ -494,12 +497,12 @@ impl GenerationState { // We cannot observe anything as the stack is empty. return Ok(()); } - if dst == KERNEL.global_labels["observe_new_address"] { + if dst == KERNEL.global_labels["observe_new_address"] && self.is_kernel() { let tip_u256 = stack_peek(self, 0)?; let tip_h256 = H256::from_uint(&tip_u256); let tip_h160 = H160::from(tip_h256); self.observe_address(tip_h160); - } else if dst == KERNEL.global_labels["observe_new_contract"] { + } else if dst == KERNEL.global_labels["observe_new_contract"] && self.is_kernel() { let tip_u256 = stack_peek(self, 0)?; let tip_h256 = H256::from_uint(&tip_u256); self.observe_contract(tip_h256)?; @@ -572,7 +575,7 @@ impl GenerationState { txn_root_ptr: 0, receipt_root_ptr: 0, }, - jumpdest_table: None, + jumpdest_tables: self.jumpdest_tables.clone(), access_lists_ptrs: self.access_lists_ptrs.clone(), state_ptrs: self.state_ptrs.clone(), } @@ -589,7 +592,8 @@ impl GenerationState { .clone_from(&segment_data.extra_data.ger_prover_inputs); self.trie_root_ptrs .clone_from(&segment_data.extra_data.trie_root_ptrs); - self.jumpdest_table + // 
todo verify + self.jumpdest_tables .clone_from(&segment_data.extra_data.jumpdest_table); self.state_ptrs .clone_from(&segment_data.extra_data.state_ptrs); diff --git a/evm_arithmetization/src/lib.rs b/evm_arithmetization/src/lib.rs index 9bc6021e2..6cf31bb80 100644 --- a/evm_arithmetization/src/lib.rs +++ b/evm_arithmetization/src/lib.rs @@ -280,6 +280,9 @@ pub mod verifier; pub mod generation; pub mod witness; +pub use generation::jumpdest; +pub use generation::prover_input::CodeDb; + // Utility modules pub mod curve_pairings; pub mod extension_tower; diff --git a/evm_arithmetization/src/witness/transition.rs b/evm_arithmetization/src/witness/transition.rs index fdcf9af65..8a6f1d39a 100644 --- a/evm_arithmetization/src/witness/transition.rs +++ b/evm_arithmetization/src/witness/transition.rs @@ -345,14 +345,6 @@ where where Self: Sized, { - self.perform_op(op, row)?; - self.incr_pc(match op { - Operation::Syscall(_, _, _) | Operation::ExitKernel => 0, - Operation::Push(n) => n as usize + 1, - Operation::Jump | Operation::Jumpi => 0, - _ => 1, - }); - self.incr_gas(gas_to_charge(op)); let registers = self.get_registers(); let gas_limit_address = MemoryAddress::new( @@ -373,6 +365,14 @@ where } } + self.perform_op(op, row)?; + self.incr_pc(match op { + Operation::Syscall(_, _, _) | Operation::ExitKernel => 0, + Operation::Push(n) => n as usize + 1, + Operation::Jump | Operation::Jumpi => 0, + _ => 1, + }); + Ok(op) } diff --git a/evm_arithmetization/tests/add11_yml.rs b/evm_arithmetization/tests/add11_yml.rs index 5406ebe6a..6c6e9d539 100644 --- a/evm_arithmetization/tests/add11_yml.rs +++ b/evm_arithmetization/tests/add11_yml.rs @@ -198,6 +198,7 @@ fn get_generation_inputs() -> GenerationInputs { prev_hashes: vec![H256::default(); 256], cur_hash: H256::default(), }, + batch_jumpdest_tables: vec![None], } } /// The `add11_yml` test case from https://github.com/ethereum/tests diff --git a/evm_arithmetization/tests/erc20.rs b/evm_arithmetization/tests/erc20.rs index 
f594c7bd3..0b3cfe7c7 100644 --- a/evm_arithmetization/tests/erc20.rs +++ b/evm_arithmetization/tests/erc20.rs @@ -196,6 +196,7 @@ fn test_erc20() -> anyhow::Result<()> { prev_hashes: vec![H256::default(); 256], cur_hash: H256::default(), }, + batch_jumpdest_tables: vec![None], }; let max_cpu_len_log = 20; diff --git a/evm_arithmetization/tests/erc721.rs b/evm_arithmetization/tests/erc721.rs index e39915100..5a47f087a 100644 --- a/evm_arithmetization/tests/erc721.rs +++ b/evm_arithmetization/tests/erc721.rs @@ -203,6 +203,7 @@ fn test_erc721() -> anyhow::Result<()> { prev_hashes: vec![H256::default(); 256], cur_hash: H256::default(), }, + batch_jumpdest_tables: vec![None], }; let max_cpu_len_log = 20; diff --git a/evm_arithmetization/tests/global_exit_root.rs b/evm_arithmetization/tests/global_exit_root.rs index 5495d4100..22120041a 100644 --- a/evm_arithmetization/tests/global_exit_root.rs +++ b/evm_arithmetization/tests/global_exit_root.rs @@ -114,6 +114,7 @@ fn test_global_exit_root() -> anyhow::Result<()> { prev_hashes: vec![H256::default(); 256], cur_hash: H256::default(), }, + jumpdest_table: None, }; let max_cpu_len_log = 20; diff --git a/evm_arithmetization/tests/log_opcode.rs b/evm_arithmetization/tests/log_opcode.rs index 4f274b0b9..bec6ddedd 100644 --- a/evm_arithmetization/tests/log_opcode.rs +++ b/evm_arithmetization/tests/log_opcode.rs @@ -266,6 +266,7 @@ fn test_log_opcodes() -> anyhow::Result<()> { prev_hashes: vec![H256::default(); 256], cur_hash: H256::default(), }, + batch_jumpdest_tables: vec![None], }; let max_cpu_len_log = 20; diff --git a/evm_arithmetization/tests/selfdestruct.rs b/evm_arithmetization/tests/selfdestruct.rs index e528ae804..cbebca587 100644 --- a/evm_arithmetization/tests/selfdestruct.rs +++ b/evm_arithmetization/tests/selfdestruct.rs @@ -170,6 +170,7 @@ fn test_selfdestruct() -> anyhow::Result<()> { prev_hashes: vec![H256::default(); 256], cur_hash: H256::default(), }, + batch_jumpdest_tables: vec![None], }; let 
max_cpu_len_log = 20; diff --git a/evm_arithmetization/tests/simple_transfer.rs b/evm_arithmetization/tests/simple_transfer.rs index 6d2cbb6d0..add46007c 100644 --- a/evm_arithmetization/tests/simple_transfer.rs +++ b/evm_arithmetization/tests/simple_transfer.rs @@ -162,6 +162,7 @@ fn test_simple_transfer() -> anyhow::Result<()> { prev_hashes: vec![H256::default(); 256], cur_hash: H256::default(), }, + batch_jumpdest_tables: vec![None], }; let max_cpu_len_log = 20; diff --git a/evm_arithmetization/tests/withdrawals.rs b/evm_arithmetization/tests/withdrawals.rs index f179d7b7e..1874a2b7f 100644 --- a/evm_arithmetization/tests/withdrawals.rs +++ b/evm_arithmetization/tests/withdrawals.rs @@ -105,6 +105,7 @@ fn test_withdrawals() -> anyhow::Result<()> { prev_hashes: vec![H256::default(); 256], cur_hash: H256::default(), }, + batch_jumpdest_tables: vec![None], }; let max_cpu_len_log = 20; diff --git a/test_batching.sh b/test_batching.sh new file mode 100755 index 000000000..26a06985d --- /dev/null +++ b/test_batching.sh @@ -0,0 +1,17 @@ +#!/usr/bin/env bash + +set -exo + +test_batch_size() { + cargo --quiet run --release --bin leader -- --runtime in-memory -b $2 -n 1 --test-only --save-inputs-on-error stdio < $1.witness.json +} + +main() { + local NUM_TX=$(expr $(cast block $1 | wc -l) - 10) + for BATCH_SIZE in $(seq $NUM_TX) + do + test_batch_size $1 $BATCH_SIZE + done +} + +main 748 diff --git a/trace_decoder/src/core.rs b/trace_decoder/src/core.rs index cd2bcfa6e..e0abcd30f 100644 --- a/trace_decoder/src/core.rs +++ b/trace_decoder/src/core.rs @@ -1,15 +1,24 @@ +use core::{convert::Into as _, option::Option::None}; use std::{ cmp, collections::{BTreeMap, BTreeSet, HashMap}, + iter::repeat, marker::PhantomData, mem, }; +use alloy::{ + consensus::{Transaction, TxEnvelope}, + primitives::{address, TxKind}, + rlp::Decodable as _, +}; +use alloy_compat::Compat as _; use anyhow::{anyhow, bail, ensure, Context as _}; use either::Either; -use ethereum_types::{Address, 
BigEndianHash as _, U256}; +use ethereum_types::{Address, BigEndianHash as _, H160, U256}; use evm_arithmetization::{ generation::TrieInputs, + jumpdest::JumpDestTableWitness, proof::{BlockMetadata, TrieRoots}, tries::{MptKey, ReceiptTrie, StateMpt, StorageTrie, TransactionTrie}, world::{Hasher, KeccakHash, PoseidonHash, Type1World, Type2World, World}, @@ -28,6 +37,24 @@ use crate::{ TxnInfo, TxnMeta, TxnTrace, }; +/// Addresses of precompiled Ethereum contracts. +pub fn is_precompile(addr: H160) -> bool { + let precompiled_addresses = if cfg!(feature = "eth_mainnet") { + address!("0000000000000000000000000000000000000001") + ..address!("000000000000000000000000000000000000000a") + } else { + // Remove KZG Peval for non-Eth mainnet networks + address!("0000000000000000000000000000000000000001") + ..address!("0000000000000000000000000000000000000009") + }; + + precompiled_addresses.contains(&addr.compat()) + || (cfg!(feature = "polygon_pos") + // Include P256Verify for Polygon PoS + && addr.compat() + == address!("0000000000000000000000000000000000000100")) +} + /// Expected trie type when parsing from binary in a [`BlockTrace`]. /// /// See [`crate::wire`] and [`CombinedPreImages`] for more. 
@@ -80,7 +107,7 @@ pub fn entrypoint( let batches = match start { Either::Left((type1world, mut code)) => { - code.extend(code_db); + code.extend(code_db.clone()); Either::Left( middle( type1world, @@ -97,7 +124,7 @@ pub fn entrypoint( ) } Either::Right((type2world, mut code)) => { - code.extend(code_db); + code.extend(code_db.clone()); Either::Right( middle( type2world, @@ -132,6 +159,7 @@ pub fn entrypoint( }, after, withdrawals, + jumpdest_tables, }| { let (state, storage) = world .clone() @@ -144,7 +172,7 @@ pub fn entrypoint( running_gas_used += gas_used; running_gas_used.into() }, - signed_txns: byte_code.into_iter().map(Into::into).collect(), + signed_txns: byte_code.clone().into_iter().map(Into::into).collect(), withdrawals, ger_data, tries: TrieInputs { @@ -156,20 +184,49 @@ pub fn entrypoint( trie_roots_after: after, checkpoint_state_trie_root, checkpoint_consolidated_hash, - contract_code: contract_code - .into_iter() - .map(|it| match &world { + contract_code: { + let init_codes = + byte_code + .iter() + .filter_map(|nonempty_txn_bytes| -> Option> { + let tx_envelope = + TxEnvelope::decode(&mut &nonempty_txn_bytes[..]).unwrap(); + match tx_envelope.to() { + TxKind::Create => Some(tx_envelope.input().to_vec()), + TxKind::Call(_address) => None, + } + }); + + let mut result = match &world { Either::Left(_type1) => { - (::CodeHasher::hash(&it), it) + Hash2Code::<::CodeHasher>::new() } Either::Right(_type2) => { - (::CodeHasher::hash(&it), it) + panic!() } - }) - .collect(), + }; + result.extend(init_codes); + result.extend(contract_code); + result.extend(code_db.clone()); + result.into_hashmap() + }, block_metadata: b_meta.clone(), block_hashes: b_hashes.clone(), burn_addr, + batch_jumpdest_tables: { + // TODO(einar-polygon): + // Note that this causes any batch containing just a + // single `None` to collapse into a `None`, which + // causing failover to simulating jumpdest analysis for + // the whole batch. There is an optimization opportunity + // here. 
+ dbg!(&jumpdest_tables); + if jumpdest_tables.iter().any(Option::is_none) { + repeat(None).take(jumpdest_tables.len()).collect::>() + } else { + jumpdest_tables + } + }, } }, ) @@ -334,6 +391,8 @@ struct Batch { /// Empty for all but the final batch pub withdrawals: Vec<(Address, U256)>, + + pub jumpdest_tables: Vec>, } impl Batch { @@ -346,6 +405,7 @@ impl Batch { before, after, withdrawals, + jumpdest_tables, } = self; Batch { first_txn_ix, @@ -355,6 +415,7 @@ impl Batch { before: before.map(f), after, withdrawals, + jumpdest_tables, } } } @@ -447,6 +508,8 @@ where )?; } + let mut jumpdest_tables = vec![]; + for txn in batch { let do_increment_txn_ix = txn.is_some(); let TxnInfo { @@ -456,6 +519,7 @@ where byte_code, new_receipt_trie_node_byte, gas_used: txn_gas_used, + jumpdest_table, }, } = txn.unwrap_or_default(); @@ -575,6 +639,8 @@ where } } + jumpdest_tables.push(jumpdest_table); + if do_increment_txn_ix { txn_ix += 1; } @@ -608,6 +674,7 @@ where transactions_root: transaction_trie.root(), receipts_root: receipt_trie.root(), }, + jumpdest_tables, }); observer.collect_tries( @@ -797,6 +864,7 @@ fn map_receipt_bytes(bytes: Vec) -> anyhow::Result> { /// trace. /// If there are any txns that create contracts, then they will also /// get added here as we process the deltas. +#[derive(Default)] struct Hash2Code { /// Key must always be [`hash`](World::CodeHasher) of value. 
inner: HashMap>, @@ -813,11 +881,15 @@ impl Hash2Code { this } pub fn get(&mut self, hash: H256) -> Option> { - self.inner.get(&hash).cloned() + let res = self.inner.get(&hash).cloned(); + res } pub fn insert(&mut self, code: Vec) { self.inner.insert(H::hash(&code), code); } + pub fn into_hashmap(self) -> HashMap> { + self.inner + } } impl Extend> for Hash2Code { diff --git a/trace_decoder/src/interface.rs b/trace_decoder/src/interface.rs index abe3b0af0..35248f9a3 100644 --- a/trace_decoder/src/interface.rs +++ b/trace_decoder/src/interface.rs @@ -5,6 +5,7 @@ use std::collections::{BTreeMap, BTreeSet, HashMap}; use ethereum_types::{Address, U256}; +use evm_arithmetization::jumpdest::JumpDestTableWitness; use evm_arithmetization::proof::{BlockHashes, BlockMetadata}; use evm_arithmetization::ConsolidatedHash; use keccak_hash::H256; @@ -111,6 +112,9 @@ pub struct TxnMeta { /// Gas used by this txn (Note: not cumulative gas used). pub gas_used: u64, + + /// JumpDest table + pub jumpdest_table: Option, } /// A "trace" specific to an account for a txn. 
diff --git a/trace_decoder/src/lib.rs b/trace_decoder/src/lib.rs index 4bd68bc11..524df90c5 100644 --- a/trace_decoder/src/lib.rs +++ b/trace_decoder/src/lib.rs @@ -68,7 +68,7 @@ mod type1; mod type2; mod wire; -pub use core::{entrypoint, WireDisposition}; +pub use core::{entrypoint, is_precompile, WireDisposition}; mod core; diff --git a/zero/Cargo.toml b/zero/Cargo.toml index 27fd5bdb5..9ee54e11e 100644 --- a/zero/Cargo.toml +++ b/zero/Cargo.toml @@ -11,6 +11,8 @@ categories.workspace = true [dependencies] alloy.workspace = true alloy-compat = "0.1.1" +alloy-primitives.workspace = true +alloy-serde.workspace = true anyhow.workspace = true async-stream.workspace = true axum.workspace = true @@ -64,4 +66,4 @@ cdk_erigon = ["evm_arithmetization/cdk_erigon", "trace_decoder/cdk_erigon"] polygon_pos = ["evm_arithmetization/polygon_pos", "trace_decoder/polygon_pos"] [lints] -workspace = true +workspace = true \ No newline at end of file diff --git a/zero/src/bin/leader.rs b/zero/src/bin/leader.rs index 939b6a655..7dc467f08 100644 --- a/zero/src/bin/leader.rs +++ b/zero/src/bin/leader.rs @@ -103,6 +103,7 @@ async fn main() -> Result<()> { Command::Rpc { rpc_url, rpc_type, + jumpdest_src, checkpoint_block, previous_proof, block_time, @@ -110,6 +111,7 @@ async fn main() -> Result<()> { end_block, backoff, max_retries, + timeout, } => { // Construct the provider. 
let previous_proof = get_previous_proof(previous_proof)?; @@ -139,6 +141,8 @@ async fn main() -> Result<()> { previous_proof, prover_config, }, + jumpdest_src, + timeout, ) .await?; } diff --git a/zero/src/bin/leader/cli.rs b/zero/src/bin/leader/cli.rs index e83471ad2..92beac4c2 100644 --- a/zero/src/bin/leader/cli.rs +++ b/zero/src/bin/leader/cli.rs @@ -1,11 +1,13 @@ use std::path::PathBuf; +use std::time::Duration; use alloy::eips::BlockId; use alloy::transports::http::reqwest::Url; use clap::{Parser, Subcommand, ValueEnum, ValueHint}; +use zero::parsing::parse_duration; use zero::prover::cli::CliProverConfig; use zero::prover_state::cli::CliProverStateConfig; -use zero::rpc::RpcType; +use zero::rpc::{JumpdestSrc, RpcType}; const WORKER_HELP_HEADING: &str = "Worker Config options"; @@ -70,6 +72,14 @@ pub(crate) enum Command { default_value = "jerigon" )] rpc_type: RpcType, + /// The source of jumpdest tables. + #[arg( + short = 'j', + long, + default_value = "client-fetched-structlogs", + required = false + )] + jumpdest_src: JumpdestSrc, /// The start of the block range to prove (inclusive). #[arg(short = 's', long, env = "ZERO_BIN_START_BLOCK")] start_block: BlockId, @@ -94,6 +104,9 @@ pub(crate) enum Command { /// The maximum number of retries #[arg(long, env = "ZERO_BIN_MAX_RETRIES", default_value_t = 0)] max_retries: u32, + /// Timeout for fetching structlog traces + #[arg(long, default_value = "60", value_parser = parse_duration)] + timeout: Duration, }, /// Reads input from HTTP and writes output to a directory. 
Http { diff --git a/zero/src/bin/leader/client.rs b/zero/src/bin/leader/client.rs index 6f2015833..619d1c35a 100644 --- a/zero/src/bin/leader/client.rs +++ b/zero/src/bin/leader/client.rs @@ -1,4 +1,5 @@ use std::sync::Arc; +use std::time::Duration; use alloy::providers::Provider; use alloy::rpc::types::{BlockId, BlockNumberOrTag}; @@ -11,7 +12,7 @@ use zero::pre_checks::check_previous_proof_and_checkpoint; use zero::proof_types::GeneratedBlockProof; use zero::prover::{self, BlockProverInput, ProverConfig}; use zero::provider::CachedProvider; -use zero::rpc; +use zero::rpc::{self, JumpdestSrc}; use crate::ProofRuntime; @@ -29,6 +30,8 @@ pub(crate) async fn client_main( block_time: u64, block_interval: BlockInterval, mut leader_config: LeaderConfig, + jumpdest_src: JumpdestSrc, + timeout: Duration, ) -> Result<()> where ProviderT: Provider + 'static, @@ -81,6 +84,8 @@ where cached_provider.clone(), block_id, leader_config.checkpoint_block_number, + jumpdest_src, + timeout, ) .await?; block_tx diff --git a/zero/src/bin/rpc.rs b/zero/src/bin/rpc.rs index 331f3ec91..5186c668d 100644 --- a/zero/src/bin/rpc.rs +++ b/zero/src/bin/rpc.rs @@ -1,4 +1,5 @@ use std::sync::Arc; +use std::time::Duration; use alloy::primitives::B256; use alloy::providers::Provider; @@ -13,10 +14,12 @@ use tracing_subscriber::{prelude::*, EnvFilter}; use url::Url; use zero::block_interval::BlockInterval; use zero::block_interval::BlockIntervalStream; +use zero::parsing::parse_duration; use zero::prover::BlockProverInput; use zero::prover::WIRE_DISPOSITION; use zero::provider::CachedProvider; use zero::rpc; +use zero::rpc::JumpdestSrc; use self::rpc::{retry::build_http_retry_provider, RpcType}; @@ -25,6 +28,8 @@ struct FetchParams { pub start_block: u64, pub end_block: u64, pub checkpoint_block_number: Option, + pub jumpdest_src: JumpdestSrc, + pub timeout: Duration, } #[derive(Args, Clone, Debug)] @@ -40,12 +45,23 @@ struct RpcToolConfig { default_value = "jerigon" )] rpc_type: RpcType, + /// The 
source of jumpdest tables. + #[arg( + short = 'j', + long, + default_value = "client-fetched-structlogs", + required = false + )] + jumpdest_src: JumpdestSrc, /// Backoff in milliseconds for retry requests. #[arg(long, env = "ZERO_BIN_BACKOFF", default_value_t = 0)] backoff: u64, /// The maximum number of retries. #[arg(long, env = "ZERO_BIN_MAX_RETRIES", default_value_t = 0)] max_retries: u32, + /// Timeout for fetching structlog traces + #[arg(long, default_value = "60", value_parser = parse_duration)] + timeout: Duration, } #[derive(Subcommand)] @@ -102,9 +118,14 @@ where let (block_num, _is_last_block) = block_interval_elem?; let block_id = BlockId::Number(BlockNumberOrTag::Number(block_num)); // Get the prover input for particular block. - let result = - rpc::block_prover_input(cached_provider.clone(), block_id, checkpoint_block_number) - .await?; + let result = rpc::block_prover_input( + cached_provider.clone(), + block_id, + checkpoint_block_number, + params.jumpdest_src, + params.timeout, + ) + .await?; block_prover_inputs.push(result); } @@ -131,6 +152,8 @@ impl Cli { start_block, end_block, checkpoint_block_number, + jumpdest_src: self.config.jumpdest_src, + timeout: self.config.timeout, }; let block_prover_inputs = @@ -156,6 +179,8 @@ impl Cli { start_block: block_number, end_block: block_number, checkpoint_block_number: None, + jumpdest_src: self.config.jumpdest_src, + timeout: self.config.timeout, }; let block_prover_inputs = @@ -209,8 +234,11 @@ async fn main() -> anyhow::Result<()> { tracing_subscriber::Registry::default() .with( tracing_subscriber::fmt::layer() + // With the default configuration trace information is written + // to stdout, but we already use stdout to write our payload (the witness). 
+ .with_writer(std::io::stderr) + // .json() .with_ansi(false) - .compact() .with_filter(EnvFilter::from_default_env()), ) .init(); diff --git a/zero/src/parsing.rs b/zero/src/parsing.rs index 5643f82f5..19c49ba83 100644 --- a/zero/src/parsing.rs +++ b/zero/src/parsing.rs @@ -1,5 +1,11 @@ //! Parsing utilities. -use std::{fmt::Display, ops::Add, ops::Range, str::FromStr}; +use std::{ + fmt::Display, + num::ParseIntError, + ops::{Add, Range}, + str::FromStr, + time::Duration, +}; use thiserror::Error; @@ -66,6 +72,11 @@ where } } +pub fn parse_duration(arg: &str) -> Result { + let seconds = arg.parse()?; + Ok(Duration::from_secs(seconds)) +} + #[cfg(test)] mod test { use super::*; diff --git a/zero/src/rpc/jerigon.rs b/zero/src/rpc/jerigon.rs index 915176852..dfe0544ec 100644 --- a/zero/src/rpc/jerigon.rs +++ b/zero/src/rpc/jerigon.rs @@ -1,12 +1,25 @@ -use alloy::{providers::Provider, rpc::types::eth::BlockId, transports::Transport}; +use core::iter::Iterator; +use std::ops::Deref as _; +use std::time::Duration; + +use alloy::eips::BlockNumberOrTag; +use alloy::{ + providers::Provider, + rpc::types::{eth::BlockId, Block, BlockTransactionsKind}, + transports::Transport, +}; +use alloy_primitives::Address; use anyhow::Context as _; +use evm_arithmetization::jumpdest::JumpDestTableWitness; use serde::Deserialize; use serde_json::json; use trace_decoder::{BlockTrace, BlockTraceTriePreImages, CombinedPreImages, TxnInfo}; +use tracing::{debug, warn}; -use super::fetch_other_block_data; +use super::{fetch_other_block_data, JumpdestSrc}; use crate::prover::BlockProverInput; use crate::provider::CachedProvider; +use crate::rpc::jumpdest::{generate_jumpdest_table, get_block_normalized_structlogs}; const WITNESS_ENDPOINT: &str = if cfg!(feature = "cdk_erigon") { "zkevm_getWitness" @@ -26,6 +39,8 @@ pub async fn block_prover_input( cached_provider: std::sync::Arc>, target_block_id: BlockId, checkpoint_block_number: u64, + jumpdest_src: JumpdestSrc, + fetch_timeout: Duration, ) 
-> anyhow::Result where ProviderT: Provider, @@ -39,16 +54,53 @@ where "debug_traceBlockByNumber".into(), (target_block_id, json!({"tracer": "zeroTracer"})), ) - .await?; + .await? + .into_iter() + .map(|ztr| ztr.result) + .collect::>(); // Grab block witness info (packed as combined trie pre-images) - let block_witness = cached_provider .get_provider() .await? .raw_request::<_, String>(WITNESS_ENDPOINT.into(), vec![target_block_id]) .await?; + let block: Block = cached_provider + .get_block(target_block_id, BlockTransactionsKind::Full) + .await? + .context("no block")?; + + let block_jumpdest_table_witnesses: Vec> = match jumpdest_src { + JumpdestSrc::ProverSimulation => vec![None; tx_results.len()], + JumpdestSrc::ClientFetchedStructlogs => { + // In case of the error with retrieving structlogs from the server, + // continue without interruption. Equivalent to `ProverSimulation` case. + process_transactions( + &block, + cached_provider.get_provider().await?.deref(), + &tx_results, + &fetch_timeout, + ) + .await + .unwrap_or_else(|e| { + warn!("failed to fetch server structlogs for block {target_block_id}: {e}"); + vec![None; tx_results.len()] + }) + } + JumpdestSrc::Serverside => todo!(), + }; + + // weave in the JDTs + let txn_info = tx_results + .into_iter() + .zip(block_jumpdest_table_witnesses) + .map(|(mut tx_info, jdtw)| { + tx_info.meta.jumpdest_table = jdtw; + tx_info + }) + .collect(); + let other_data = fetch_other_block_data(cached_provider, target_block_id, checkpoint_block_number).await?; @@ -61,9 +113,61 @@ where "invalid hex returned from call to {WITNESS_ENDPOINT}" ))?, }), - txn_info: tx_results.into_iter().map(|it| it.result).collect(), + txn_info, code_db: Default::default(), }, other_data, }) } + +/// Processes the transactions in the given block, generating jumpdest tables +/// and updates the code database +pub async fn process_transactions<'i, ProviderT, TransportT>( + block: &Block, + provider: &ProviderT, + tx_results: &[TxnInfo], + 
fetch_timeout: &Duration, +) -> anyhow::Result>> +where + ProviderT: Provider, + TransportT: Transport + Clone, +{ + let block_structlogs = get_block_normalized_structlogs( + provider, + &BlockNumberOrTag::from(block.header.number), + fetch_timeout, + ) + .await?; + + let tx_traces = tx_results.iter().map(|tx| { + tx.traces + .iter() + .map(|(h, t)| (Address::from(h.as_fixed_bytes()), t)) + }); + + let block_jumpdest_tables = block + .transactions + .as_transactions() + .context("no transactions in block")? + .iter() + .zip(block_structlogs) + .zip(tx_traces) + .map(|((tx, structlog), tx_trace)| { + structlog.and_then(|it| { + generate_jumpdest_table(tx, &it.1, tx_trace).map_or_else( + |error| { + debug!( + "{}: JumpDestTable generation failed with reason: {:?}", + tx.hash.to_string(), + error + ); + None + }, + Some, + ) + }) + }) + .collect::>(); + + Ok(block_jumpdest_tables) +} diff --git a/zero/src/rpc/jumpdest.rs b/zero/src/rpc/jumpdest.rs new file mode 100644 index 000000000..88f93038e --- /dev/null +++ b/zero/src/rpc/jumpdest.rs @@ -0,0 +1,392 @@ +use core::default::Default; +use core::option::Option::None; +use std::collections::HashMap; +use std::ops::Not as _; +use std::time::Duration; + +// use ::compat::Compat; +use alloy::eips::BlockNumberOrTag; +use alloy::primitives::Address; +use alloy::providers::ext::DebugApi; +use alloy::providers::Provider; +use alloy::rpc::types::eth::Transaction; +use alloy::rpc::types::trace::geth::{ + GethDebugTracingOptions, GethDefaultTracingOptions, GethTrace, StructLog, TraceResult, +}; +use alloy::transports::Transport; +use alloy_compat::Compat as _; +use alloy_primitives::{TxHash, U256}; +use anyhow::bail; +use anyhow::ensure; +use evm_arithmetization::jumpdest::JumpDestTableWitness; +use keccak_hash::{keccak, H256}; +use tokio::time::timeout; +use trace_decoder::is_precompile; +use trace_decoder::ContractCodeUsage; +use trace_decoder::TxnTrace; +use tracing::{info, warn}; + +// use crate::rpc::H256; + 
+#[derive(Debug, Clone)] +pub struct TxStructLogs(pub Option, pub Vec); + +/// Pass `true` for the components needed. +fn structlog_tracing_options(stack: bool, memory: bool, storage: bool) -> GethDebugTracingOptions { + GethDebugTracingOptions { + config: GethDefaultTracingOptions { + disable_stack: Some(!stack), + // needed for CREATE2 + disable_memory: Some(!memory), + disable_storage: Some(!storage), + ..GethDefaultTracingOptions::default() + }, + tracer: None, + ..GethDebugTracingOptions::default() + } +} + +/// Get code hash from a read or write operation of contract code. +fn get_code_hash(usage: &ContractCodeUsage) -> H256 { + match usage { + ContractCodeUsage::Read(hash) => *hash, + ContractCodeUsage::Write(bytes) => keccak(bytes), + } +} + +pub(crate) async fn get_block_normalized_structlogs( + provider: &ProviderT, + block: &BlockNumberOrTag, + fetch_timeout: &Duration, +) -> anyhow::Result>> +where + ProviderT: Provider, + TransportT: Transport + Clone, +{ + let block_stackonly_structlog_traces_fut = + provider.debug_trace_block_by_number(*block, structlog_tracing_options(true, false, false)); + + let block_stackonly_structlog_traces = + match timeout(*fetch_timeout, block_stackonly_structlog_traces_fut).await { + Ok(traces) => traces?, + Err(elapsed) => { + bail!(elapsed); + } + }; + + let block_normalized_stackonly_structlog_traces = block_stackonly_structlog_traces + .into_iter() + .map(|tx_trace_result| match tx_trace_result { + TraceResult::Success { + result, tx_hash, .. + } => Ok(trace_to_tx_structlog(tx_hash, result)), + TraceResult::Error { error, tx_hash } => Err(anyhow::anyhow!( + "error fetching structlog for tx: {tx_hash:?}. Error: {error:?}" + )), + }) + .collect::>, anyhow::Error>>()?; + + Ok(block_normalized_stackonly_structlog_traces) +} + +/// Generate at JUMPDEST table by simulating the call stack in EVM, +/// using a Geth structlog as input. 
+pub(crate) fn generate_jumpdest_table<'a>( + tx: &Transaction, + structlog: &[StructLog], + tx_traces: impl Iterator, +) -> anyhow::Result { + let mut jumpdest_table = JumpDestTableWitness::default(); + + // This map does neither contain the `init` field of Contract Deployment + // transactions nor CREATE, CREATE2 payloads. + let callee_addr_to_code_hash: HashMap = tx_traces + .filter_map(|(callee_addr, trace)| { + trace + .code_usage + .as_ref() + .map(|code| (callee_addr, get_code_hash(code))) + }) + .collect(); + + let entrypoint_code_hash: H256 = match tx.to { + Some(to_address) if is_precompile(to_address.compat()) => return Ok(jumpdest_table), + Some(to_address) if callee_addr_to_code_hash.contains_key(&to_address).not() => { + return Ok(jumpdest_table) + } + Some(to_address) => callee_addr_to_code_hash[&to_address], + None => { + let init = &tx.input; + keccak(init) + } + }; + + // `None` encodes that previous `entry` was not a JUMP or JUMPI with true + // condition, `Some(jump_target)` encodes we came from a JUMP or JUMPI with + // true condition and target `jump_target`. + let mut prev_jump: Option = None; + + // The next available context. Starts at 1. Never decrements. + let mut next_ctx_available = 1; + // Immediately use context 1; + let mut call_stack = vec![(entrypoint_code_hash, next_ctx_available)]; + next_ctx_available += 1; + + let mut stuctlog_iter = structlog.iter().enumerate().peekable(); + while let Some((step, entry)) = stuctlog_iter.next() { + let op = entry.op.as_str(); + let curr_depth: usize = entry.depth.try_into().unwrap(); + + ensure!(curr_depth <= next_ctx_available, "Structlog is malformed."); + + while curr_depth < call_stack.len() { + call_stack.pop(); + } + + ensure!( + call_stack.is_empty().not(), + "Call stack was unexpectedly empty." 
+ ); + let (ref code_hash, ref ctx) = call_stack.last().unwrap().clone(); + info!("INSERT {} {}", *code_hash, *ctx); + jumpdest_table.insert(*code_hash, *ctx, None); + + // REVIEW: will be removed before merge + tracing::info!( + step, + curr_depth, + tx_hash = ?tx.hash, + ?code_hash, + ctx, + next_ctx_available, + pc = entry.pc, + pc_hex = format!("{:08x?}", entry.pc), + gas = entry.gas, + gas_cost = entry.gas_cost, + op, + ?entry, + ); + + match op { + "CALL" | "CALLCODE" | "DELEGATECALL" | "STATICCALL" => { + prev_jump = None; + ensure!(entry.stack.as_ref().is_some(), "No evm stack found."); + // We reverse the stack, so the order matches our assembly code. + let evm_stack: Vec<_> = entry.stack.as_ref().unwrap().iter().rev().collect(); + // These opcodes expect 6 or 7 operands on the stack, but for jumpdest-table + // generation we only use 2, and failures will be handled in + // next iteration by popping the stack accordingly. + let operands_used = 2; + + if evm_stack.len() < operands_used { + // Note for future debugging: There may exist edge cases, where the call + // context has been incremented before the call op fails. This should be + // accounted for before this and the following `continue`. The details are + // defined in `sys_calls.asm`. + continue; + } + // This is the same stack index (i.e. 2nd) for all four opcodes. See https://ethervm.io/#F1 + let [_gas, address, ..] = evm_stack[..] else { + unreachable!() + }; + + let callee_address = stack_value_to_address(address); + if callee_addr_to_code_hash.contains_key(&callee_address) { + let next_code_hash = callee_addr_to_code_hash[&callee_address]; + call_stack.push((next_code_hash, next_ctx_available)); + }; + + if let Some((_next_step, next_entry)) = stuctlog_iter.peek() { + let next_depth: usize = next_entry.depth.try_into().unwrap(); + if next_depth < curr_depth { + // The call caused an exception. Skip over incrementing + // `next_ctx_available`. 
+ continue; + } + } + // `peek()` only returns `None` if we are at the last entry of + // the Structlog, whether we are on a `CALL` op that throws an + // exception or not. But this is of no consequence to the + // generated Jumpdest table, so we can ignore the case. + + jumpdest_table.insert(*code_hash, next_ctx_available, None); + next_ctx_available += 1; + } + "CREATE" | "CREATE2" => { + bail!(format!( + "{} requires memory, aborting JUMPDEST-table generation.", + tx.hash + )); + } + "JUMP" => { + prev_jump = None; + ensure!(entry.stack.as_ref().is_some(), "No evm stack found."); + // We reverse the stack, so the order matches our assembly code. + let evm_stack: Vec<_> = entry.stack.as_ref().unwrap().iter().rev().collect(); + let operands = 1; + if evm_stack.len() < operands { + continue; + } + let [jump_target, ..] = evm_stack[..] else { + unreachable!() + }; + + prev_jump = Some(*jump_target); + } + "JUMPI" => { + prev_jump = None; + ensure!(entry.stack.as_ref().is_some(), "No evm stack found."); + // We reverse the stack, so the order matches our assembly code. + let evm_stack: Vec<_> = entry.stack.as_ref().unwrap().iter().rev().collect(); + let operands = 2; + if evm_stack.len() < operands { + continue; + }; + + let [jump_target, condition, ..] = evm_stack[..] 
else { + unreachable!() + }; + let jump_condition = condition.is_zero().not(); + + if jump_condition { + prev_jump = Some(*jump_target) + } + } + "JUMPDEST" => { + let mut jumped_here = false; + + if let Some(jmp_target) = prev_jump { + jumped_here = jmp_target == U256::from(entry.pc); + } + prev_jump = None; + + if jumped_here.not() { + continue; + } + + let jumpdest_offset = TryInto::::try_into(entry.pc); + if jumpdest_offset.is_err() { + continue; + } + ensure!(jumpdest_offset.unwrap() < 24576); + jumpdest_table.insert(*code_hash, *ctx, Some(jumpdest_offset.unwrap())); + } + "EXTCODECOPY" | "EXTCODESIZE" => { + prev_jump = None; + jumpdest_table.insert(*code_hash, next_ctx_available, None); + next_ctx_available += 1; + } + _ => { + prev_jump = None; + } + } + } + info!("RETURN {}", &jumpdest_table); + Ok(jumpdest_table) +} + +fn stack_value_to_address(operand: &U256) -> Address { + let all_bytes: [u8; 32] = operand.to_be_bytes(); + let mut lower_20_bytes = [0u8; 20]; + // Based on `__compat_primitive_types::H160::from(H256::from(all_bytes)). + // into()`. + lower_20_bytes[0..20].copy_from_slice(&all_bytes[32 - 20..32]); + Address::from(lower_20_bytes) +} + +fn trace_to_tx_structlog(tx_hash: Option, trace: GethTrace) -> Option { + match trace { + GethTrace::Default(structlog_frame) => { + Some(TxStructLogs(tx_hash, structlog_frame.struct_logs)) + } + GethTrace::JS(it) => { + let default_frame = compat::deserialize(it) + .inspect_err(|e| warn!("failed to deserialize js default frame {e:?}")) + .ok()?; + Some(TxStructLogs(tx_hash, default_frame.struct_logs)) + } + _ => None, + } +} + +/// This module exists as a workaround for parsing `StructLog`. The `error` +/// field is a string in Geth and Alloy but an object in Erigon. A PR[\^1] has +/// been merged to fix this upstream and should eventually render this +/// unnecessary. 
[\^1]: `https://github.com/erigontech/erigon/pull/12089` +mod compat { + use std::{collections::BTreeMap, fmt, iter}; + + use alloy::rpc::types::trace::geth::{DefaultFrame, StructLog}; + use alloy_primitives::{Bytes, B256, U256}; + use serde::{de::SeqAccess, Deserialize, Deserializer}; + + pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result { + _DefaultFrame::deserialize(d) + } + + /// The `error` field is a `string` in `geth` etc. but an `object` in + /// `erigon`. + fn error<'de, D: Deserializer<'de>>(d: D) -> Result, D::Error> { + #[derive(Deserialize)] + #[serde(untagged)] + enum Error { + String(String), + #[allow(dead_code)] + Object(serde_json::Map), + } + Ok(match Error::deserialize(d)? { + Error::String(it) => Some(it), + Error::Object(_) => None, + }) + } + + #[derive(Deserialize)] + #[serde(remote = "DefaultFrame", rename_all = "camelCase")] + struct _DefaultFrame { + failed: bool, + gas: u64, + return_value: Bytes, + #[serde(deserialize_with = "vec_structlog")] + struct_logs: Vec, + } + + fn vec_structlog<'de, D: Deserializer<'de>>(d: D) -> Result, D::Error> { + struct Visitor; + impl<'de> serde::de::Visitor<'de> for Visitor { + type Value = Vec; + fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("an array of `StructLog`") + } + fn visit_seq>(self, mut seq: A) -> Result { + #[derive(Deserialize)] + struct With(#[serde(with = "_StructLog")] StructLog); + let v = iter::from_fn(|| seq.next_element().transpose()) + .map(|it| it.map(|With(it)| it)) + .collect::>()?; + Ok(v) + } + } + + d.deserialize_seq(Visitor) + } + + #[derive(Deserialize)] + #[serde(remote = "StructLog", rename_all = "camelCase")] + struct _StructLog { + pc: u64, + op: String, + gas: u64, + gas_cost: u64, + depth: u64, + #[serde(default, deserialize_with = "error")] + error: Option, + stack: Option>, + return_data: Option, + memory: Option>, + #[serde(rename = "memSize")] + memory_size: Option, + storage: Option>, + #[serde(rename = "refund")] + 
refund_counter: Option, + } +} diff --git a/zero/src/rpc/mod.rs b/zero/src/rpc/mod.rs index b972bc501..98a063425 100644 --- a/zero/src/rpc/mod.rs +++ b/zero/src/rpc/mod.rs @@ -1,6 +1,6 @@ zk_evm_common::check_chain_features!(); -use std::sync::Arc; +use std::{sync::Arc, time::Duration}; use alloy::{ primitives::{Address, Bloom, Bytes, FixedBytes, B256, U256}, @@ -23,6 +23,7 @@ use tracing::warn; use crate::prover::BlockProverInput; pub mod jerigon; +pub mod jumpdest; pub mod native; pub mod retry; @@ -39,11 +40,21 @@ pub enum RpcType { Native, } +/// The Jumpdest source type. +#[derive(ValueEnum, Clone, Debug, Copy)] +pub enum JumpdestSrc { + ProverSimulation, + ClientFetchedStructlogs, + Serverside, // later +} + /// Obtain the prover input for one block pub async fn block_prover_input( cached_provider: Arc>, block_id: BlockId, checkpoint_block_number: u64, + jumpdest_src: JumpdestSrc, + fetch_timeout: Duration, ) -> Result where ProviderT: Provider, @@ -51,10 +62,24 @@ where { match cached_provider.rpc_type { RpcType::Jerigon => { - jerigon::block_prover_input(cached_provider, block_id, checkpoint_block_number).await + jerigon::block_prover_input( + cached_provider, + block_id, + checkpoint_block_number, + jumpdest_src, + fetch_timeout, + ) + .await } RpcType::Native => { - native::block_prover_input(cached_provider, block_id, checkpoint_block_number).await + native::block_prover_input( + cached_provider, + block_id, + checkpoint_block_number, + jumpdest_src, + fetch_timeout, + ) + .await } } } diff --git a/zero/src/rpc/native/mod.rs b/zero/src/rpc/native/mod.rs index a4dc7e0c6..dda71a85e 100644 --- a/zero/src/rpc/native/mod.rs +++ b/zero/src/rpc/native/mod.rs @@ -1,6 +1,5 @@ -use std::collections::BTreeSet; -use std::ops::Deref; use std::sync::Arc; +use std::{ops::Deref, time::Duration}; use alloy::{ providers::Provider, @@ -16,20 +15,24 @@ use crate::provider::CachedProvider; mod state; mod txn; -type CodeDb = BTreeSet>; +pub use txn::{process_transaction, 
process_transactions}; + +use super::JumpdestSrc; /// Fetches the prover input for the given BlockId. pub async fn block_prover_input( provider: Arc>, block_number: BlockId, checkpoint_block_number: u64, + jumpdest_src: JumpdestSrc, + fetch_timeout: Duration, ) -> anyhow::Result where ProviderT: Provider, TransportT: Transport + Clone, { let (block_trace, other_data) = try_join!( - process_block_trace(provider.clone(), block_number), + process_block_trace(provider.clone(), block_number, jumpdest_src, &fetch_timeout), crate::rpc::fetch_other_block_data(provider.clone(), block_number, checkpoint_block_number) )?; @@ -40,9 +43,11 @@ where } /// Processes the block with the given block number and returns the block trace. -async fn process_block_trace( +pub(crate) async fn process_block_trace( cached_provider: Arc>, block_number: BlockId, + jumpdest_src: JumpdestSrc, + fetch_timeout: &Duration, ) -> anyhow::Result where ProviderT: Provider, @@ -53,8 +58,13 @@ where .await? .ok_or(anyhow::anyhow!("block not found {}", block_number))?; - let (code_db, txn_info) = - txn::process_transactions(&block, cached_provider.get_provider().await?.deref()).await?; + let (code_db, txn_info) = txn::process_transactions( + &block, + cached_provider.get_provider().await?.deref(), + jumpdest_src, + fetch_timeout, + ) + .await?; let trie_pre_images = state::process_state_witness(cached_provider, block, &txn_info).await?; Ok(BlockTrace { diff --git a/zero/src/rpc/native/txn.rs b/zero/src/rpc/native/txn.rs index 4989c1ab5..daf3d6388 100644 --- a/zero/src/rpc/native/txn.rs +++ b/zero/src/rpc/native/txn.rs @@ -1,5 +1,10 @@ +use core::option::Option::None; use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; +use std::time::Duration; +// use __compat_primitive_types::{H256, U256}; +use alloy::eips::BlockNumberOrTag; +use alloy::rpc::types::trace::geth::TraceResult; use alloy::{ primitives::{keccak256, Address, B256, U256}, providers::{ @@ -12,32 +17,116 @@ use alloy::{ trace::geth::{ 
AccountState, DiffMode, GethDebugBuiltInTracerType, GethDebugTracerType, GethDebugTracingOptions, GethTrace, PreStateConfig, PreStateFrame, PreStateMode, + StructLog, }, }, transports::Transport, }; use alloy_compat::Compat; -use anyhow::Context as _; +use anyhow::{bail, Context as _, Ok}; +use evm_arithmetization::{jumpdest::JumpDestTableWitness, CodeDb}; use futures::stream::{FuturesOrdered, TryStreamExt}; use trace_decoder::{ContractCodeUsage, TxnInfo, TxnMeta, TxnTrace}; +use tracing::{debug, warn}; -use super::CodeDb; +use crate::rpc::jumpdest::get_block_normalized_structlogs; +use crate::rpc::{ + jumpdest::{self}, + JumpdestSrc, +}; +pub(crate) async fn get_block_prestate_traces( + provider: &ProviderT, + block: &BlockNumberOrTag, + tracing_options: GethDebugTracingOptions, +) -> anyhow::Result> +where + ProviderT: Provider, + TransportT: Transport + Clone, +{ + let block_prestate_traces = provider + .debug_trace_block_by_number(*block, tracing_options) + .await?; + + block_prestate_traces + .into_iter() + .map(|trace_result| match trace_result { + TraceResult::Success { result, .. } => Ok(result), + TraceResult::Error { error, .. } => { + bail!("error fetching block prestate traces: {:?}", error) + } + }) + .collect::, anyhow::Error>>() +} /// Processes the transactions in the given block and updates the code db. 
-pub(super) async fn process_transactions( +pub async fn process_transactions( block: &Block, provider: &ProviderT, + jumpdest_src: JumpdestSrc, + fetch_timeout: &Duration, ) -> anyhow::Result<(CodeDb, Vec)> where ProviderT: Provider, TransportT: Transport + Clone, { + // Get block prestate traces + let block_prestate_trace = get_block_prestate_traces( + provider, + &BlockNumberOrTag::from(block.header.number), + prestate_tracing_options(false), + ) + .await?; + + // Get block diff traces + let block_diff_trace = get_block_prestate_traces( + provider, + &BlockNumberOrTag::from(block.header.number), + prestate_tracing_options(true), + ) + .await?; + + let block_structlogs = match jumpdest_src { + JumpdestSrc::ProverSimulation => vec![None; block_prestate_trace.len()], + JumpdestSrc::ClientFetchedStructlogs => { + // In case of the error with retrieving structlogs from the server, + // continue without interruption. Equivalent to `ProverSimulation` case. + get_block_normalized_structlogs( + provider, + &BlockNumberOrTag::from(block.header.number), + fetch_timeout, + ) + .await + .unwrap_or_else(|e| { + warn!( + "failed to fetch server structlogs for block {}: {e}", + block.header.number + ); + vec![None; block_prestate_trace.len()] + }) + .into_iter() + .map(|tx_struct_log| tx_struct_log.map(|it| it.1)) + .collect() + } + JumpdestSrc::Serverside => todo!( + "Not implemented. See https://github.com/0xPolygonZero/erigon/issues/20 for details." + ), + }; + block .transactions .as_transactions() .context("No transactions in block")? 
.iter() - .map(|tx| process_transaction(provider, tx)) + .zip( + block_prestate_trace.into_iter().zip( + block_diff_trace + .into_iter() + .zip(block_structlogs.into_iter()), + ), + ) + .map(|(tx, (pre_trace, (diff_trace, structlog)))| { + process_transaction(provider, tx, pre_trace, diff_trace, structlog) + }) .collect::>() .try_fold( (BTreeSet::new(), Vec::new()), @@ -52,25 +141,22 @@ where /// Processes the transaction with the given transaction hash and updates the /// accounts state. -async fn process_transaction( +pub async fn process_transaction( provider: &ProviderT, tx: &Transaction, + pre_trace: GethTrace, + diff_trace: GethTrace, + structlog_opt: Option>, ) -> anyhow::Result<(CodeDb, TxnInfo)> where ProviderT: Provider, TransportT: Transport + Clone, { - let (tx_receipt, pre_trace, diff_trace) = fetch_tx_data(provider, &tx.hash).await?; + let tx_receipt = fetch_tx_receipt(provider, &tx.hash).await?; let tx_status = tx_receipt.status(); let tx_receipt = tx_receipt.map_inner(rlp::map_receipt_envelope); let access_list = parse_access_list(tx.access_list.as_ref()); - let tx_meta = TxnMeta { - byte_code: ::TxEnvelope::try_from(tx.clone())?.encoded_2718(), - new_receipt_trie_node_byte: alloy::rlp::encode(tx_receipt.inner), - gas_used: tx_receipt.gas_used as u64, - }; - let (code_db, mut tx_traces) = match (pre_trace, diff_trace) { ( GethTrace::PreStateTracer(PreStateFrame::Default(read)), @@ -82,7 +168,29 @@ where // Handle case when transaction failed and a contract creation was reverted if !tx_status && tx_receipt.contract_address.is_some() { tx_traces.insert(tx_receipt.contract_address.unwrap(), TxnTrace::default()); - } + }; + + let jumpdest_table: Option = structlog_opt.and_then(|struct_logs| { + jumpdest::generate_jumpdest_table(tx, &struct_logs, tx_traces.iter().map(|(a, t)| (*a, t))) + .map_or_else( + |error| { + debug!( + "{}: JumpDestTable generation failed with reason: {:?}", + tx.hash.to_string(), + error + ); + None + }, + Some, + ) + }); + + let 
tx_meta = TxnMeta { + byte_code: ::TxEnvelope::try_from(tx.clone())?.encoded_2718(), + new_receipt_trie_node_byte: alloy::rlp::encode(tx_receipt.inner), + gas_used: tx_receipt.gas_used as u64, + jumpdest_table, + }; Ok(( code_db, @@ -97,26 +205,16 @@ where } /// Fetches the transaction data for the given transaction hash. -async fn fetch_tx_data( +async fn fetch_tx_receipt( provider: &ProviderT, tx_hash: &B256, -) -> anyhow::Result<(::ReceiptResponse, GethTrace, GethTrace), anyhow::Error> +) -> anyhow::Result<::ReceiptResponse> where ProviderT: Provider, TransportT: Transport + Clone, { - let tx_receipt_fut = provider.get_transaction_receipt(*tx_hash); - let pre_trace_fut = provider.debug_trace_transaction(*tx_hash, prestate_tracing_options(false)); - let diff_trace_fut = provider.debug_trace_transaction(*tx_hash, prestate_tracing_options(true)); - - let (tx_receipt, pre_trace, diff_trace) = - futures::try_join!(tx_receipt_fut, pre_trace_fut, diff_trace_fut,)?; - - Ok(( - tx_receipt.context("Transaction receipt not found.")?, - pre_trace, - diff_trace, - )) + let tx_receipt = provider.get_transaction_receipt(*tx_hash).await?; + tx_receipt.context("Transaction receipt not found.") } /// Parse the access list data into a hashmap.