@@ -19,7 +19,7 @@ use itertools::{Itertools, MinMaxResult, chain, enumerate};
19
19
use mpcs:: { Basefold , BasefoldRSParams , PolynomialCommitmentScheme } ;
20
20
use std:: {
21
21
collections:: { HashMap , HashSet } ,
22
- panic,
22
+ panic:: { self , PanicHookInfo } ,
23
23
time:: Instant ,
24
24
} ;
25
25
use tracing_flame:: FlameLayer ;
@@ -35,6 +35,24 @@ struct Args {
35
35
max_steps : Option < usize > ,
36
36
}
37
37
38
+ fn with_panic_hook < F , R > ( hook : Box < dyn Fn ( & PanicHookInfo < ' _ > ) + Sync + Send + ' static > , f : F ) -> R
39
+ where
40
+ F : FnOnce ( ) -> R ,
41
+ {
42
+ // Save the current panic hook
43
+ let original_hook = panic:: take_hook ( ) ;
44
+
45
+ // Set the new panic hook
46
+ panic:: set_hook ( hook) ;
47
+
48
+ let result = f ( ) ;
49
+
50
+ // Restore the original panic hook
51
+ panic:: set_hook ( original_hook) ;
52
+
53
+ result
54
+ }
55
+
38
56
fn main ( ) {
39
57
let args = Args :: parse ( ) ;
40
58
@@ -125,7 +143,7 @@ fn main() {
125
143
126
144
let pk = zkvm_cs
127
145
. clone ( )
128
- . key_gen :: < Pcs > ( pp. clone ( ) , vp. clone ( ) , zkvm_fixed_traces. clone ( ) )
146
+ . key_gen :: < Pcs > ( pp, vp, zkvm_fixed_traces. clone ( ) )
129
147
. expect ( "keygen failed" ) ;
130
148
let vk = pk. get_vk ( ) ;
131
149
@@ -153,14 +171,14 @@ fn main() {
153
171
record. insn ( ) . codes ( ) . kind == EANY
154
172
&& record. rs1 ( ) . unwrap ( ) . value == CENO_PLATFORM . ecall_halt ( )
155
173
} )
156
- . and_then ( |halt_record| halt_record . rs2 ( ) )
174
+ . and_then ( StepRecord :: rs2)
157
175
. map ( |rs2| rs2. value ) ;
158
176
159
177
let final_access = vm. tracer ( ) . final_accesses ( ) ;
160
178
let end_cycle: u32 = vm. tracer ( ) . cycle ( ) . try_into ( ) . unwrap ( ) ;
161
179
162
180
let pi = PublicValues :: new (
163
- exit_code. unwrap_or ( 0 ) ,
181
+ exit_code. unwrap_or_default ( ) ,
164
182
vm. program ( ) . entry ,
165
183
Tracer :: SUBCYCLES_PER_INSN as u32 ,
166
184
vm. get_pc ( ) . into ( ) ,
@@ -188,7 +206,7 @@ fn main() {
188
206
MemFinalRecord {
189
207
addr : rec. addr ,
190
208
value : vm. peek_register ( index) ,
191
- cycle : * final_access. get ( & vma) . unwrap_or ( & 0 ) ,
209
+ cycle : final_access. get ( & vma) . copied ( ) . unwrap_or_default ( ) ,
192
210
}
193
211
} else {
194
212
// The table is padded beyond the number of registers.
@@ -209,7 +227,7 @@ fn main() {
209
227
MemFinalRecord {
210
228
addr : rec. addr ,
211
229
value : vm. peek_memory ( vma) ,
212
- cycle : * final_access. get ( & vma) . unwrap_or ( & 0 ) ,
230
+ cycle : final_access. get ( & vma) . copied ( ) . unwrap_or_default ( ) ,
213
231
}
214
232
} )
215
233
. collect_vec ( ) ;
@@ -218,7 +236,12 @@ fn main() {
218
236
// Find the final public IO cycles.
219
237
let io_final = io_init
220
238
. iter ( )
221
- . map ( |rec| * final_access. get ( & rec. addr . into ( ) ) . unwrap_or ( & 0 ) )
239
+ . map ( |rec| {
240
+ final_access
241
+ . get ( & rec. addr . into ( ) )
242
+ . copied ( )
243
+ . unwrap_or_default ( )
244
+ } )
222
245
. collect_vec ( ) ;
223
246
224
247
// assign table circuits
@@ -269,18 +292,16 @@ fn main() {
269
292
}
270
293
271
294
let transcript = Transcript :: new ( b"riscv" ) ;
272
- // change public input maliciously should cause verifier to reject proof
295
+ // Maliciously changing the public input should cause the verifier to reject the proof.
273
296
zkvm_proof. raw_pi [ 0 ] = vec ! [ <GoldilocksExt2 as ff_ext:: ExtensionField >:: BaseField :: ONE ] ;
274
297
zkvm_proof. raw_pi [ 1 ] = vec ! [ <GoldilocksExt2 as ff_ext:: ExtensionField >:: BaseField :: ONE ] ;
275
298
276
- // capture panic message, if have
277
- let default_hook = panic:: take_hook ( ) ;
278
- panic:: set_hook ( Box :: new ( |_info| {
279
- // by default it will print msg to stdout/stderr
280
- // we override it to avoid print msg since we will capture the msg by our own
281
- } ) ) ;
282
- let result = panic:: catch_unwind ( || verifier. verify_proof ( zkvm_proof, transcript) ) ;
283
- panic:: set_hook ( default_hook) ;
299
+ // capture panic message, if any
300
+ // by default it will print msg to stdout/stderr
301
+ // we override it to avoid print msg since we will capture the msg by ourselves
302
+ let result = with_panic_hook ( Box :: new ( |_info| ( ) ) , || {
303
+ panic:: catch_unwind ( || verifier. verify_proof ( zkvm_proof, transcript) )
304
+ } ) ;
284
305
match result {
285
306
Ok ( res) => {
286
307
res. expect_err ( "verify proof should return with error" ) ;
@@ -322,23 +343,24 @@ fn debug_memory_ranges(vm: &VMState, mem_final: &[MemFinalRecord]) {
322
343
323
344
tracing:: debug!(
324
345
"Memory range (accessed): {:?}" ,
325
- format_segments( vm. platform( ) , accessed_addrs. iter ( ) . copied ( ) )
346
+ format_segments( vm. platform( ) , & accessed_addrs)
326
347
) ;
327
348
tracing:: debug!(
328
349
"Memory range (handled): {:?}" ,
329
- format_segments( vm. platform( ) , handled_addrs. iter ( ) . copied ( ) )
350
+ format_segments( vm. platform( ) , & handled_addrs)
330
351
) ;
331
352
332
353
for addr in & accessed_addrs {
333
354
assert ! ( handled_addrs. contains( addr) , "unhandled addr: {:?}" , addr) ;
334
355
}
335
356
}
336
357
337
- fn format_segments (
358
+ fn format_segments < ' a > (
338
359
platform : & Platform ,
339
- addrs : impl Iterator < Item = ByteAddr > ,
340
- ) -> HashMap < String , MinMaxResult < ByteAddr > > {
360
+ addrs : impl IntoIterator < Item = & ' a ByteAddr > ,
361
+ ) -> HashMap < String , MinMaxResult < & ' a ByteAddr > > {
341
362
addrs
363
+ . into_iter ( )
342
364
. into_grouping_map_by ( |addr| format_segment ( platform, addr. 0 ) )
343
365
. minmax ( )
344
366
}
0 commit comments