Skip to content

Commit 75d75ea

Browse files
authored
support build virtual polynomials in expression style (#937)
To close #936 ### Design rationales - introduce `VirtualPolynomialsBuilder` to lift a witness of "ArcPoly" type to expression container, so they can involve into expression domain for calculation - apply `VirtualPolynomialsBuilder` in tower prover. - keep scalar in base field as possible via introducing `Either<Base, Ext>` type - reserve design for "eq" degree -1 optimisation > this part work haven't done yet and set as future work :) `VirtualPolynomialsBuilder` is more like a util function for ceno main sumcheck flow. For GKR layer circuit in gk- iop #799 , the expression system will directly applied on chip-builder and skip `VirtualPolynomialsBuilder` ### benchmark there is no impact for e2e benchmark before/after this change, which is expected 2^20 ``` fibonacci_max_steps_1048576/prove_fibonacci/fibonacci_max_steps_1048576 time: [2.3583 s 2.3709 s 2.3848 s] change: [-1.8405% -1.0740% -0.2480%] (p = 0.03 < 0.05) Change within noise threshold. ``` 2^21 ``` fibonacci_max_steps_2097152/prove_fibonacci/fibonacci_max_steps_2097152 time: [4.4650 s 4.4758 s 4.4867 s] change: [-0.6673% -0.3122% +0.0493%] (p = 0.13 > 0.05) No change in performance detected. ``` 2^22 ``` fibonacci_max_steps_4194304/prove_fibonacci/fibonacci_max_steps_4194304 time: [9.0115 s 9.0574 s 9.1011 s] change: [-1.0658% -0.3407% +0.3803%] (p = 0.40 > 0.05) No change in performance detected. ```
1 parent 681a84c commit 75d75ea

File tree

27 files changed

+812
-365
lines changed

27 files changed

+812
-365
lines changed

Cargo.lock

Lines changed: 5 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ cfg-if = "1.0"
3535
clap = { version = "4.5", features = ["derive"] }
3636
criterion = { version = "0.5", features = ["html_reports"] }
3737
crossbeam-channel = "0.5"
38+
either = { version = "1.15.*", features = ["serde"] }
3839
itertools = "0.13"
3940
num-bigint = { version = "0.4.6" }
4041
num-derive = "0.4"

ceno_emul/src/rv32im.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
use anyhow::{Result, anyhow};
1818
use ff_ext::{ExtensionField, SmallField};
19+
use itertools::Either;
1920
use multilinear_extensions::{Expression, impl_expr_from_unsigned};
2021
use num_derive::ToPrimitive;
2122
use strum_macros::{Display, EnumIter};

ceno_zkvm/src/chip_handler/general.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
264264
|| "require_equal",
265265
|cb| {
266266
cb.cs
267-
.require_zero(name_fn, a.to_monomial_form() - b.to_monomial_form())
267+
.require_zero(name_fn, a.get_monomial_form() - b.get_monomial_form())
268268
},
269269
)
270270
}

ceno_zkvm/src/chip_handler/global_state.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@ use ff_ext::ExtensionField;
22

33
use super::GlobalStateRegisterMachineChipOperations;
44
use crate::{circuit_builder::CircuitBuilder, error::ZKVMError, structs::RAMType};
5-
use multilinear_extensions::Expression;
5+
use multilinear_extensions::{Expression, ToExpr};
66
use p3::field::PrimeCharacteristicRing;
77

88
impl<E: ExtensionField> GlobalStateRegisterMachineChipOperations<E> for CircuitBuilder<'_, E> {
99
fn state_in(&mut self, pc: Expression<E>, ts: Expression<E>) -> Result<(), ZKVMError> {
1010
let record: Vec<Expression<E>> = vec![
11-
Expression::Constant(E::BaseField::from_u64(RAMType::GlobalState as u64)),
11+
E::BaseField::from_u64(RAMType::GlobalState as u64).expr(),
1212
pc,
1313
ts,
1414
];
@@ -18,7 +18,7 @@ impl<E: ExtensionField> GlobalStateRegisterMachineChipOperations<E> for CircuitB
1818

1919
fn state_out(&mut self, pc: Expression<E>, ts: Expression<E>) -> Result<(), ZKVMError> {
2020
let record: Vec<Expression<E>> = vec![
21-
Expression::Constant(E::BaseField::from_u64(RAMType::GlobalState as u64)),
21+
E::BaseField::from_u64(RAMType::GlobalState as u64).expr(),
2222
pc,
2323
ts,
2424
];

ceno_zkvm/src/chip_handler/utils.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::iter::successors;
22

33
use ff_ext::ExtensionField;
44
use itertools::izip;
5-
use multilinear_extensions::Expression;
5+
use multilinear_extensions::{Expression, ToExpr};
66
use p3::field::PrimeCharacteristicRing;
77

88
pub fn rlc_chip_record<E: ExtensionField>(
@@ -31,7 +31,7 @@ pub fn power_sequence<E: ExtensionField>(
3131
),
3232
"expression must be constant or challenge"
3333
);
34-
successors(Some(Expression::Constant(E::BaseField::ONE)), move |prev| {
34+
successors(Some(E::BaseField::ONE.expr()), move |prev| {
3535
Some(prev.clone() * base.clone())
3636
})
3737
}

ceno_zkvm/src/circuit_builder.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
use itertools::{Itertools, chain};
2-
use multilinear_extensions::{Expression, Fixed, Instance, StructuralWitIn, WitIn, WitnessId};
2+
use multilinear_extensions::{
3+
Expression, Fixed, Instance, StructuralWitIn, ToExpr, WitIn, WitnessId,
4+
};
35
use serde::de::DeserializeOwned;
46
use std::{collections::HashMap, iter::once, marker::PhantomData};
57

@@ -256,11 +258,9 @@ impl<E: ExtensionField> ConstraintSystem<E> {
256258
record: Vec<Expression<E>>,
257259
) -> Result<(), ZKVMError> {
258260
let rlc_record = self.rlc_chip_record(
259-
std::iter::once(Expression::Constant(E::BaseField::from_u64(
260-
rom_type as u64,
261-
)))
262-
.chain(record.clone())
263-
.collect(),
261+
std::iter::once(E::BaseField::from_u64(rom_type as u64).expr())
262+
.chain(record.clone())
263+
.collect(),
264264
);
265265
assert_eq!(
266266
rlc_record.degree(),
@@ -432,7 +432,7 @@ impl<E: ExtensionField> ConstraintSystem<E> {
432432
let assert_zero_expr = if assert_zero_expr.is_monomial_form() {
433433
assert_zero_expr
434434
} else {
435-
let e = assert_zero_expr.to_monomial_form();
435+
let e = assert_zero_expr.get_monomial_form();
436436
assert!(e.is_monomial_form(), "failed to put into monomial form");
437437
e
438438
};

ceno_zkvm/src/scheme/mock_prover.rs

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1147,22 +1147,16 @@ Hints:
11471147

11481148
let (mut gs_rs, rs_grp_by_anno, mut gs_ws, ws_grp_by_anno, gs) =
11491149
derive_ram_rws!(RAMType::GlobalState);
1150-
gs_rs.insert(eval_by_expr_with_instance(
1151-
&[],
1152-
&[],
1153-
&[],
1154-
&instance,
1155-
&challenges,
1156-
&gs_final,
1157-
));
1158-
gs_ws.insert(eval_by_expr_with_instance(
1159-
&[],
1160-
&[],
1161-
&[],
1162-
&instance,
1163-
&challenges,
1164-
&gs_init,
1165-
));
1150+
gs_rs.insert(
1151+
eval_by_expr_with_instance(&[], &[], &[], &instance, &challenges, &gs_final)
1152+
.right()
1153+
.unwrap(),
1154+
);
1155+
gs_ws.insert(
1156+
eval_by_expr_with_instance(&[], &[], &[], &instance, &challenges, &gs_init)
1157+
.right()
1158+
.unwrap(),
1159+
);
11661160

11671161
// gs stores { (pc, timestamp) }
11681162
find_rw_mismatch!(
@@ -1382,7 +1376,7 @@ mod tests {
13821376
GoldilocksExt2::ONE,
13831377
GoldilocksExt2::ZERO,
13841378
)),
1385-
Box::new(Expression::Constant(Goldilocks::from_u64(U5 as u64))),
1379+
Box::new(Goldilocks::from_u64(U5 as u64).expr()),
13861380
)),
13871381
Box::new(Expression::Challenge(
13881382
0,

ceno_zkvm/src/scheme/prover.rs

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
use ff_ext::ExtensionField;
22
use std::collections::{BTreeMap, BTreeSet, HashMap};
33

4-
use itertools::{Itertools, enumerate, izip};
4+
use itertools::{Either, Itertools, enumerate, izip};
55
use mpcs::{Point, PolynomialCommitmentScheme};
66
use multilinear_extensions::{
7+
Expression,
78
mle::IntoMLE,
89
util::ceil_log2,
910
virtual_poly::{ArcMultilinearExtension, build_eq_x_r_vec},
10-
virtual_polys::VirtualPolynomials,
11+
virtual_polys::{VirtualPolynomials, VirtualPolynomialsBuilder},
1112
};
1213
use p3::field::{PrimeCharacteristicRing, dot_product};
1314
use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};
15+
use std::iter::Iterator;
1416
use sumcheck::{
1517
macros::{entered_span, exit_span},
1618
structs::{IOPProverMessage, IOPProverState},
@@ -1230,10 +1232,15 @@ impl TowerProver {
12301232
let num_threads = optimal_sumcheck_threads(out_rt.len());
12311233

12321234
let eq: ArcMultilinearExtension<E> = build_eq_x_r_vec(&out_rt).into_mle().into();
1233-
let mut virtual_polys = VirtualPolynomials::<E>::new(num_threads, out_rt.len());
1235+
1236+
let mut expr_builder = VirtualPolynomialsBuilder::default();
1237+
let mut exprs =
1238+
Vec::<Expression<E>>::with_capacity(prod_specs.len() + logup_specs.len());
1239+
let eq_expr = expr_builder.lift(&eq);
12341240

12351241
for (s, alpha) in izip!(&prod_specs, &alpha_pows) {
12361242
if round < s.witness.len() {
1243+
let alpha_expr = Expression::Constant(Either::Right(*alpha));
12371244
let layer_polys = &s.witness[round];
12381245

12391246
// sanity check
@@ -1246,11 +1253,9 @@ impl TowerProver {
12461253
})
12471254
);
12481255

1256+
let layer_polys_product = layer_polys.iter().map(|layer_poly| expr_builder.lift(layer_poly)).product::<Expression<E>>();
12491257
// \sum_s eq(rt, s) * alpha^{i} * ([in_i0[s] * in_i1[s] * .... in_i{num_product_fanin}[s]])
1250-
virtual_polys.add_mle_list(
1251-
[vec![&eq], layer_polys.iter().collect()].concat(),
1252-
*alpha,
1253-
)
1258+
exprs.push(eq_expr.clone() * alpha_expr *layer_polys_product);
12541259
}
12551260
}
12561261

@@ -1259,37 +1264,36 @@ impl TowerProver {
12591264
if round < s.witness.len() {
12601265
let layer_polys = &s.witness[round];
12611266
// sanity check
1262-
assert_eq!(layer_polys.len(), 4); // p1, q1, p2, q2
1267+
assert_eq!(layer_polys.len(), 4); // p1, p2, q1, q2
12631268
assert!(
12641269
layer_polys
12651270
.iter()
12661271
.all(|f| f.evaluations().len() == 1 << (log_num_fanin * round)),
12671272
);
12681273

1269-
let (alpha_numerator, alpha_denominator) = (&alpha[0], &alpha[1]);
1274+
let (alpha_numerator, alpha_denominator) = (Expression::Constant(Either::Right(alpha[0])), Expression::Constant(Either::Right(alpha[1])));
12701275

1271-
let (q2, q1, p2, p1) = (
1272-
&layer_polys[3],
1273-
&layer_polys[2],
1274-
&layer_polys[1],
1275-
&layer_polys[0],
1276+
let (p1, p2, q1, q2) = (
1277+
expr_builder.lift(&layer_polys[0]),
1278+
expr_builder.lift(&layer_polys[1]),
1279+
expr_builder.lift(&layer_polys[2]),
1280+
expr_builder.lift(&layer_polys[3]),
12761281
);
12771282

1278-
// \sum_s eq(rt, s) * alpha_numerator^{i} * (p1 * q2 + p2 * q1)
1279-
virtual_polys.add_mle_list(vec![&eq, &p1, &q2], *alpha_numerator);
1280-
virtual_polys.add_mle_list(vec![&eq, &p2, &q1], *alpha_numerator);
1281-
1282-
// \sum_s eq(rt, s) * alpha_denominator^{i} * (q1 * q2)
1283-
virtual_polys.add_mle_list(vec![&eq, &q1, &q2], *alpha_denominator);
1283+
// \sum_s eq(rt, s) * (alpha_numerator^{i} * (p1 * q2 + p2 * q1) + alpha_denominator^{i} * q1 * q2)
1284+
exprs.push(eq_expr.clone() * (alpha_numerator * (p1 * q2.clone() + p2 * q1.clone()) + alpha_denominator * q1 * q2));
12841285
}
12851286
}
12861287

12871288
let wrap_batch_span = entered_span!("wrap_batch");
1288-
// NOTE: at the time of adding this span, visualizing it with the flamegraph layer
1289-
// shows it to be (inexplicably) much more time-consuming than the call to `prove_batch_polys`
1290-
// This is likely a bug in the tracing-flame crate.
12911289
let (sumcheck_proofs, state) = IOPProverState::prove(
1292-
virtual_polys,
1290+
expr_builder.to_virtual_polys(
1291+
num_threads,
1292+
out_rt.len(),
1293+
None,
1294+
&[exprs.into_iter().sum::<Expression<E>>()],
1295+
&[],
1296+
),
12931297
transcript,
12941298
);
12951299
exit_span!(wrap_batch_span);
@@ -1323,11 +1327,10 @@ impl TowerProver {
13231327
for (i, s) in enumerate(&logup_specs) {
13241328
if round < s.witness.len() {
13251329
// collect evals belong to current spec
1326-
// p1, q2, p2, q1
13271330
let p1 = *evals_iter.next().expect("insufficient evals length");
1328-
let q2 = *evals_iter.next().expect("insufficient evals length");
13291331
let p2 = *evals_iter.next().expect("insufficient evals length");
13301332
let q1 = *evals_iter.next().expect("insufficient evals length");
1333+
let q2 = *evals_iter.next().expect("insufficient evals length");
13311334
proofs.push_logup_evals_and_point(i, vec![p1, p2, q1, q2], rt_prime.clone());
13321335
}
13331336
}

ceno_zkvm/src/scheme/utils.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -261,9 +261,11 @@ pub(crate) fn wit_infer_by_expr<'a, E: ExtensionField, const N: usize>(
261261
&|witness_id, _, _, _| structual_witnesses[witness_id as usize].clone(),
262262
&|i| instance[i.0].clone(),
263263
&|scalar| {
264-
let scalar: ArcMultilinearExtension<E> = Arc::new(
265-
DenseMultilinearExtension::from_evaluations_vec(0, vec![scalar]),
266-
);
264+
let scalar: ArcMultilinearExtension<E> =
265+
Arc::new(DenseMultilinearExtension::from_evaluations_vec(
266+
0,
267+
vec![scalar.left().expect("do not support extension field")],
268+
));
267269
scalar
268270
},
269271
&|challenge_id, pow, scalar, offset| {

ceno_zkvm/src/scheme/verifier.rs

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,9 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
303303
pi_evals,
304304
&challenges,
305305
&self.vk.initial_global_state_expr,
306-
);
306+
)
307+
.right()
308+
.unwrap();
307309
prod_w *= initial_global_state;
308310
let finalize_global_state = eval_by_expr_with_instance(
309311
&[],
@@ -312,7 +314,9 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
312314
pi_evals,
313315
&challenges,
314316
&self.vk.finalize_global_state_expr,
315-
);
317+
)
318+
.right()
319+
.unwrap();
316320
prod_r *= finalize_global_state;
317321
// check rw_set equality across all proofs
318322
if prod_r != prod_w {
@@ -476,20 +480,26 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
476480
.iter()
477481
.map(|expr| {
478482
eval_by_expr_with_instance(&[], &proof.wits_in_evals, &[], pi, challenges, expr)
483+
.right()
484+
.unwrap()
479485
})
480486
.collect();
481487
let w_records_in_evals: Vec<_> = cs
482488
.w_expressions
483489
.iter()
484490
.map(|expr| {
485491
eval_by_expr_with_instance(&[], &proof.wits_in_evals, &[], pi, challenges, expr)
492+
.right()
493+
.unwrap()
486494
})
487495
.collect();
488496
let lk_records_in_evals: Vec<_> = cs
489497
.lk_expressions
490498
.iter()
491499
.map(|expr| {
492500
eval_by_expr_with_instance(&[], &proof.wits_in_evals, &[], pi, challenges, expr)
501+
.right()
502+
.unwrap()
493503
})
494504
.collect();
495505
let computed_evals = [
@@ -529,6 +539,8 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
529539
challenges,
530540
expr,
531541
)
542+
.right()
543+
.unwrap()
532544
})
533545
.sum::<E>()
534546
},
@@ -545,6 +557,8 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
545557
// verify zero expression (degree = 1) statement, thus no sumcheck
546558
if cs.assert_zero_expressions.iter().any(|expr| {
547559
eval_by_expr_with_instance(&[], &proof.wits_in_evals, &[], pi, challenges, expr)
560+
.right()
561+
.unwrap()
548562
!= E::ZERO
549563
}) {
550564
return Err(ZKVMError::VerifyError("zero expression != 0".into()));
@@ -813,7 +827,10 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
813827
pi,
814828
challenges,
815829
expr,
816-
) != expected_evals
830+
)
831+
.right()
832+
.unwrap()
833+
!= expected_evals
817834
}) {
818835
return Err(ZKVMError::VerifyError(
819836
"record evaluate != expected_evals".into(),

ceno_zkvm/src/state.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ impl<E: ExtensionField> StateCircuit<E> for GlobalState {
2020
circuit_builder: &mut crate::circuit_builder::CircuitBuilder<E>,
2121
) -> Result<Expression<E>, ZKVMError> {
2222
let states: Vec<Expression<E>> = vec![
23-
Expression::Constant(E::BaseField::from_u64(RAMType::GlobalState as u64)),
23+
E::BaseField::from_u64(RAMType::GlobalState as u64).expr(),
2424
circuit_builder.query_init_pc()?.expr(),
2525
circuit_builder.query_init_cycle()?.expr(),
2626
];
@@ -32,7 +32,7 @@ impl<E: ExtensionField> StateCircuit<E> for GlobalState {
3232
circuit_builder: &mut crate::circuit_builder::CircuitBuilder<E>,
3333
) -> Result<Expression<E>, ZKVMError> {
3434
let states: Vec<Expression<E>> = vec![
35-
Expression::Constant(E::BaseField::from_u64(RAMType::GlobalState as u64)),
35+
E::BaseField::from_u64(RAMType::GlobalState as u64).expr(),
3636
circuit_builder.query_end_pc()?.expr(),
3737
circuit_builder.query_end_cycle()?.expr(),
3838
];

0 commit comments

Comments
 (0)