Skip to content

Commit 4c4fa95

Browse files
Clean ups
1 parent 63ad6f9 commit 4c4fa95

File tree

7 files changed

+57
-51
lines changed

7 files changed

+57
-51
lines changed

.github/workflows/integration.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -62,4 +62,4 @@ jobs:
6262
env:
6363
RAYON_NUM_THREADS: 8
6464
RUST_LOG: debug
65-
run: cargo run --release --package ceno_zkvm --example fibonacci_elf --target ${{ matrix.target }} --
65+
run: cargo run --release --package ceno_zkvm --example fibonacci_elf --target ${{ matrix.target }}

ceno_emul/src/vm_state.rs

+1-7
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,7 @@ impl VMState {
7979

8080
pub fn iter_until_halt(&mut self) -> impl Iterator<Item = Result<StepRecord>> + '_ {
8181
let emu = Emulator::new();
82-
from_fn(move || {
83-
if self.halted() {
84-
None
85-
} else {
86-
Some(self.step(&emu))
87-
}
88-
})
82+
from_fn(move || self.halted().then(|| self.step(&emu)))
8983
}
9084

9185
fn step(&mut self, emu: &Emulator) -> Result<StepRecord> {

ceno_zkvm/examples/fibonacci_elf.rs

+43-21
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use itertools::{Itertools, MinMaxResult, chain, enumerate};
1919
use mpcs::{Basefold, BasefoldRSParams, PolynomialCommitmentScheme};
2020
use std::{
2121
collections::{HashMap, HashSet},
22-
panic,
22+
panic::{self, PanicHookInfo},
2323
time::Instant,
2424
};
2525
use tracing_flame::FlameLayer;
@@ -35,6 +35,24 @@ struct Args {
3535
max_steps: Option<usize>,
3636
}
3737

38+
fn with_panic_hook<F, R>(hook: Box<dyn Fn(&PanicHookInfo<'_>) + Sync + Send + 'static>, f: F) -> R
39+
where
40+
F: FnOnce() -> R,
41+
{
42+
// Save the current panic hook
43+
let original_hook = panic::take_hook();
44+
45+
// Set the new panic hook
46+
panic::set_hook(hook);
47+
48+
let result = f();
49+
50+
// Restore the original panic hook
51+
panic::set_hook(original_hook);
52+
53+
result
54+
}
55+
3856
fn main() {
3957
let args = Args::parse();
4058

@@ -125,7 +143,7 @@ fn main() {
125143

126144
let pk = zkvm_cs
127145
.clone()
128-
.key_gen::<Pcs>(pp.clone(), vp.clone(), zkvm_fixed_traces.clone())
146+
.key_gen::<Pcs>(pp, vp, zkvm_fixed_traces.clone())
129147
.expect("keygen failed");
130148
let vk = pk.get_vk();
131149

@@ -153,14 +171,14 @@ fn main() {
153171
record.insn().codes().kind == EANY
154172
&& record.rs1().unwrap().value == CENO_PLATFORM.ecall_halt()
155173
})
156-
.and_then(|halt_record| halt_record.rs2())
174+
.and_then(StepRecord::rs2)
157175
.map(|rs2| rs2.value);
158176

159177
let final_access = vm.tracer().final_accesses();
160178
let end_cycle: u32 = vm.tracer().cycle().try_into().unwrap();
161179

162180
let pi = PublicValues::new(
163-
exit_code.unwrap_or(0),
181+
exit_code.unwrap_or_default(),
164182
vm.program().entry,
165183
Tracer::SUBCYCLES_PER_INSN as u32,
166184
vm.get_pc().into(),
@@ -188,7 +206,7 @@ fn main() {
188206
MemFinalRecord {
189207
addr: rec.addr,
190208
value: vm.peek_register(index),
191-
cycle: *final_access.get(&vma).unwrap_or(&0),
209+
cycle: final_access.get(&vma).copied().unwrap_or_default(),
192210
}
193211
} else {
194212
// The table is padded beyond the number of registers.
@@ -209,7 +227,7 @@ fn main() {
209227
MemFinalRecord {
210228
addr: rec.addr,
211229
value: vm.peek_memory(vma),
212-
cycle: *final_access.get(&vma).unwrap_or(&0),
230+
cycle: final_access.get(&vma).copied().unwrap_or_default(),
213231
}
214232
})
215233
.collect_vec();
@@ -218,7 +236,12 @@ fn main() {
218236
// Find the final public IO cycles.
219237
let io_final = io_init
220238
.iter()
221-
.map(|rec| *final_access.get(&rec.addr.into()).unwrap_or(&0))
239+
.map(|rec| {
240+
final_access
241+
.get(&rec.addr.into())
242+
.copied()
243+
.unwrap_or_default()
244+
})
222245
.collect_vec();
223246

224247
// assign table circuits
@@ -269,18 +292,16 @@ fn main() {
269292
}
270293

271294
let transcript = Transcript::new(b"riscv");
272-
// change public input maliciously should cause verifier to reject proof
295+
// Maliciously changing the public input should cause the verifier to reject the proof.
273296
zkvm_proof.raw_pi[0] = vec![<GoldilocksExt2 as ff_ext::ExtensionField>::BaseField::ONE];
274297
zkvm_proof.raw_pi[1] = vec![<GoldilocksExt2 as ff_ext::ExtensionField>::BaseField::ONE];
275298

276-
// capture panic message, if have
277-
let default_hook = panic::take_hook();
278-
panic::set_hook(Box::new(|_info| {
279-
// by default it will print msg to stdout/stderr
280-
// we override it to avoid print msg since we will capture the msg by our own
281-
}));
282-
let result = panic::catch_unwind(|| verifier.verify_proof(zkvm_proof, transcript));
283-
panic::set_hook(default_hook);
299+
// capture panic message, if any
300+
// by default it will print msg to stdout/stderr
301+
// we override it to avoid print msg since we will capture the msg by ourselves
302+
let result = with_panic_hook(Box::new(|_info| ()), || {
303+
panic::catch_unwind(|| verifier.verify_proof(zkvm_proof, transcript))
304+
});
284305
match result {
285306
Ok(res) => {
286307
res.expect_err("verify proof should return with error");
@@ -322,23 +343,24 @@ fn debug_memory_ranges(vm: &VMState, mem_final: &[MemFinalRecord]) {
322343

323344
tracing::debug!(
324345
"Memory range (accessed): {:?}",
325-
format_segments(vm.platform(), accessed_addrs.iter().copied())
346+
format_segments(vm.platform(), &accessed_addrs)
326347
);
327348
tracing::debug!(
328349
"Memory range (handled): {:?}",
329-
format_segments(vm.platform(), handled_addrs.iter().copied())
350+
format_segments(vm.platform(), &handled_addrs)
330351
);
331352

332353
for addr in &accessed_addrs {
333354
assert!(handled_addrs.contains(addr), "unhandled addr: {:?}", addr);
334355
}
335356
}
336357

337-
fn format_segments(
358+
fn format_segments<'a>(
338359
platform: &Platform,
339-
addrs: impl Iterator<Item = ByteAddr>,
340-
) -> HashMap<String, MinMaxResult<ByteAddr>> {
360+
addrs: impl IntoIterator<Item = &'a ByteAddr>,
361+
) -> HashMap<String, MinMaxResult<&'a ByteAddr>> {
341362
addrs
363+
.into_iter()
342364
.into_grouping_map_by(|addr| format_segment(platform, addr.0))
343365
.minmax()
344366
}

ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs

+4-13
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,7 @@ impl<E: ExtensionField> MmuConfig<E> {
4646
io_addrs: &[Addr],
4747
) {
4848
assert!(
49-
chain(
50-
static_mem_init.iter().map(|record| record.addr),
51-
io_addrs.iter().copied(),
52-
)
53-
.all_unique(),
49+
chain(static_mem_init.iter().map(|record| &record.addr), io_addrs,).all_unique(),
5450
"memory addresses must be unique"
5551
);
5652

@@ -142,14 +138,9 @@ impl MemPadder {
142138
new_len: usize,
143139
records: Vec<MemInitRecord>,
144140
) -> Vec<MemInitRecord> {
145-
if records.is_empty() {
146-
self.padded(new_len, records)
147-
} else {
148-
self.padded(new_len, records)
149-
.into_iter()
150-
.sorted_by_key(|record| record.addr)
151-
.collect()
152-
}
141+
let mut padded = self.padded(new_len, records);
142+
padded.sort_by_key(|record| record.addr);
143+
padded
153144
}
154145

155146
/// Pad `records` to `new_len` using unused addresses.

ceno_zkvm/src/scheme.rs

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
use ff_ext::ExtensionField;
2-
use itertools::Itertools;
32
use mpcs::PolynomialCommitmentScheme;
43
use serde::{Deserialize, Serialize};
54
use std::{collections::BTreeMap, fmt::Debug};
@@ -133,15 +132,15 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProof<E, PCS> {
133132
.iter()
134133
.map(|pv| {
135134
if pv.len() == 1 {
136-
// this is constant poly, and always evaluate to same constant value
135+
// this is constant poly, and always evaluates to same constant value
137136
E::from(pv[0])
138137
} else {
139-
// set 0 as placeholder. will be evaluate lazily
138+
// set 0 as placeholder. will be evaluated lazily
140139
// Or the vector is empty, i.e. the constant 0 polynomial.
141140
E::ZERO
142141
}
143142
})
144-
.collect_vec();
143+
.collect();
145144
Self {
146145
raw_pi,
147146
pi_evals,

ceno_zkvm/src/scheme/prover.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1241,7 +1241,7 @@ impl TowerProver {
12411241
virtual_polys.add_mle_list(vec![&eq, &q1, &q2], *alpha_denominator);
12421242
}
12431243
}
1244-
tracing::debug!("generated tower proof at round {}/{}", round, max_round_index);
1244+
tracing::debug!("generated tower proof at round {round}/{max_round_index}");
12451245

12461246
let wrap_batch_span = entered_span!("wrap_batch");
12471247
// NOTE: at the time of adding this span, visualizing it with the flamegraph layer

ceno_zkvm/src/scheme/verifier.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,12 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
6060
transcript: Transcript<E>,
6161
does_halt: bool,
6262
) -> Result<bool, ZKVMError> {
63-
// require ecall/halt proof to exist, depending whether we expect a halt.
63+
// require ecall/halt proof to exist, depending on whether we expect a halt.
6464
let num_instances = vm_proof
6565
.opcode_proofs
6666
.get(&HaltInstruction::<E>::name())
6767
.map(|(_, p)| p.num_instances)
68-
.unwrap_or(0);
68+
.unwrap_or_default();
6969
if num_instances != (does_halt as usize) {
7070
return Err(ZKVMError::VerifyError(format!(
7171
"ecall/halt num_instances={}, expected={}",
@@ -117,12 +117,12 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
117117
}
118118
}
119119

120-
for (name, (_, proof)) in vm_proof.opcode_proofs.iter() {
120+
for (name, (_, proof)) in &vm_proof.opcode_proofs {
121121
tracing::debug!("read {}'s commit", name);
122122
PCS::write_commitment(&proof.wits_commit, &mut transcript)
123123
.map_err(ZKVMError::PCSError)?;
124124
}
125-
for (name, (_, proof)) in vm_proof.table_proofs.iter() {
125+
for (name, (_, proof)) in &vm_proof.table_proofs {
126126
tracing::debug!("read {}'s commit", name);
127127
PCS::write_commitment(&proof.wits_commit, &mut transcript)
128128
.map_err(ZKVMError::PCSError)?;

0 commit comments

Comments
 (0)