Skip to content

Commit a0b719d

Browse files
authored
sumcheck refactor & unittest to demonstrate various num_var within single sumcheck (#862)
This PR addressed few question raised from @Jiangkm3 when deal with mpcs poly batching tasks ### change scope - [x] refactor `virtual_polys` out from `ceno_zkvm` to mle crate so could be share for other crates - [x] unittest to show sumcheck with different variables
1 parent 104ffcb commit a0b719d

File tree

11 files changed

+339
-355
lines changed

11 files changed

+339
-355
lines changed

ceno_zkvm/src/lib.rs

-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ pub mod stats;
2121
pub mod structs;
2222
mod uint;
2323
mod utils;
24-
mod virtual_polys;
2524
mod witness;
2625

2726
pub use structs::ROMType;

ceno_zkvm/src/scheme/prover.rs

+4-3
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use multilinear_extensions::{
1010
mle::{IntoMLE, MultilinearExtension},
1111
util::ceil_log2,
1212
virtual_poly::{ArcMultilinearExtension, build_eq_x_r_vec},
13+
virtual_polys::VirtualPolynomials,
1314
};
1415
use p3_field::PrimeCharacteristicRing;
1516
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
@@ -33,8 +34,7 @@ use crate::{
3334
structs::{
3435
Point, ProvingKey, TowerProofs, TowerProver, TowerProverSpec, ZKVMProvingKey, ZKVMWitnesses,
3536
},
36-
utils::{get_challenge_pows, optimal_sumcheck_threads},
37-
virtual_polys::VirtualPolynomials,
37+
utils::{add_mle_list_by_expr, get_challenge_pows, optimal_sumcheck_threads},
3838
};
3939

4040
use super::{PublicValues, ZKVMOpcodeProof, ZKVMProof, ZKVMTableProof};
@@ -572,7 +572,8 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
572572
}
573573
}
574574

575-
distrinct_zerocheck_terms_set.extend(virtual_polys.add_mle_list_by_expr(
575+
distrinct_zerocheck_terms_set.extend(add_mle_list_by_expr(
576+
&mut virtual_polys,
576577
sel_non_lc_zero_sumcheck.as_ref(),
577578
witnesses.iter().collect_vec(),
578579
expr,

ceno_zkvm/src/utils.rs

+142-17
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
use std::{
2-
collections::HashMap,
2+
collections::{BTreeSet, HashMap},
33
fmt::Display,
44
hash::Hash,
5+
mem,
56
panic::{self, PanicHookInfo},
67
};
78

89
use ff_ext::{ExtensionField, SmallField};
910
use itertools::Itertools;
10-
use multilinear_extensions::util::max_usable_threads;
11+
use multilinear_extensions::{
12+
util::max_usable_threads, virtual_poly::ArcMultilinearExtension,
13+
virtual_polys::VirtualPolynomials,
14+
};
1115
use p3_field::Field;
1216
use transcript::Transcript;
1317

18+
use crate::expression::Expression;
19+
1420
pub fn i64_to_base<F: SmallField>(x: i64) -> F {
1521
if x >= 0 {
1622
F::from_u64(x as u64)
@@ -167,21 +173,6 @@ pub fn eval_wellform_address_vec<E: ExtensionField>(offset: u64, scaled: u64, r:
167173
.sum::<E>()
168174
}
169175

170-
/// transpose 2d vector without clone
171-
pub fn transpose<T>(v: Vec<Vec<T>>) -> Vec<Vec<T>> {
172-
assert!(!v.is_empty());
173-
let len = v[0].len();
174-
let mut iters: Vec<_> = v.into_iter().map(|n| n.into_iter()).collect();
175-
(0..len)
176-
.map(|_| {
177-
iters
178-
.iter_mut()
179-
.map(|n| n.next().unwrap())
180-
.collect::<Vec<T>>()
181-
})
182-
.collect()
183-
}
184-
185176
pub fn display_hashmap<K: Display, V: Display>(map: &HashMap<K, V>) -> String {
186177
format!(
187178
"[{}]",
@@ -223,3 +214,137 @@ where
223214

224215
result
225216
}
217+
218+
/// add mle terms into virtual poly by expression
219+
/// return distinct witin in set
220+
pub fn add_mle_list_by_expr<'a, E: ExtensionField>(
221+
virtual_polys: &mut VirtualPolynomials<'a, E>,
222+
selector: Option<&'a ArcMultilinearExtension<'a, E>>,
223+
wit_ins: Vec<&'a ArcMultilinearExtension<'a, E>>,
224+
expr: &Expression<E>,
225+
challenges: &[E],
226+
// sumcheck batch challenge
227+
alpha: E,
228+
) -> BTreeSet<u16> {
229+
assert!(expr.is_monomial_form());
230+
let monomial_terms = expr.evaluate(
231+
&|_| unreachable!(),
232+
&|witness_id| vec![(E::ONE, { vec![witness_id] })],
233+
&|structural_witness_id, _, _, _| vec![(E::ONE, { vec![structural_witness_id] })],
234+
&|scalar| vec![(E::from(scalar), { vec![] })],
235+
&|challenge_id, pow, scalar, offset| {
236+
let challenge = challenges[challenge_id as usize];
237+
vec![(challenge.exp_u64(pow as u64) * scalar + offset, vec![])]
238+
},
239+
&|mut a, b| {
240+
a.extend(b);
241+
a
242+
},
243+
&|mut a, mut b| {
244+
assert!(a.len() <= 2);
245+
assert!(b.len() <= 2);
246+
// special logic to deal with scaledsum
247+
// scaledsum second parameter must be 0
248+
if a.len() == 2 {
249+
assert!((a[1].0, a[1].1.is_empty()) == (E::ZERO, true));
250+
a.truncate(1);
251+
}
252+
if b.len() == 2 {
253+
assert!((b[1].0, b[1].1.is_empty()) == (E::ZERO, true));
254+
b.truncate(1);
255+
}
256+
257+
a[0].1.extend(mem::take(&mut b[0].1));
258+
// return [ab]
259+
vec![(a[0].0 * b[0].0, mem::take(&mut a[0].1))]
260+
},
261+
&|mut x, a, b| {
262+
assert!(a.len() == 1 && a[0].1.is_empty()); // for challenge or constant, term should be empty
263+
assert!(b.len() == 1 && b[0].1.is_empty()); // for challenge or constant, term should be empty
264+
assert!(x.len() == 1 && (x[0].0, x[0].1.len()) == (E::ONE, 1)); // witin size only 1
265+
if b[0].0 == E::ZERO {
266+
// only include first term if b = 0
267+
vec![(a[0].0, mem::take(&mut x[0].1))]
268+
} else {
269+
// return [ax, b]
270+
vec![(a[0].0, mem::take(&mut x[0].1)), (b[0].0, vec![])]
271+
}
272+
},
273+
);
274+
for (constant, monomial_term) in monomial_terms.iter() {
275+
if *constant != E::ZERO && monomial_term.is_empty() && selector.is_none() {
276+
todo!("make virtual poly support pure constant")
277+
}
278+
let sel = selector.map(|sel| vec![sel]).unwrap_or_default();
279+
let terms_polys = monomial_term
280+
.iter()
281+
.map(|wit_id| wit_ins[*wit_id as usize])
282+
.collect_vec();
283+
284+
virtual_polys.add_mle_list([sel, terms_polys].concat(), *constant * alpha);
285+
}
286+
287+
monomial_terms
288+
.into_iter()
289+
.flat_map(|(_, monomial_term)| monomial_term.into_iter().collect_vec())
290+
.collect::<BTreeSet<u16>>()
291+
}
292+
293+
#[cfg(test)]
294+
mod tests {
295+
use ff_ext::GoldilocksExt2;
296+
use itertools::Itertools;
297+
use multilinear_extensions::{
298+
mle::IntoMLE, virtual_poly::ArcMultilinearExtension, virtual_polys::VirtualPolynomials,
299+
};
300+
use p3_field::PrimeCharacteristicRing;
301+
302+
use crate::{
303+
circuit_builder::{CircuitBuilder, ConstraintSystem},
304+
expression::{Expression, ToExpr},
305+
utils::add_mle_list_by_expr,
306+
};
307+
308+
#[test]
309+
fn test_add_mle_list_by_expr() {
310+
type E = ff_ext::GoldilocksExt2;
311+
type F = p3_goldilocks::Goldilocks;
312+
let mut cs = ConstraintSystem::new(|| "test_root");
313+
let mut cb = CircuitBuilder::<E>::new(&mut cs);
314+
let x = cb.create_witin(|| "x");
315+
let y = cb.create_witin(|| "y");
316+
317+
let wits_in: Vec<ArcMultilinearExtension<E>> = (0..cs.num_witin as usize)
318+
.map(|_| vec![F::from_u64(1)].into_mle().into())
319+
.collect();
320+
321+
let mut virtual_polys = VirtualPolynomials::new(1, 0);
322+
323+
// 3xy + 2y
324+
let expr: Expression<E> = 3 * x.expr() * y.expr() + 2 * y.expr();
325+
326+
let distrinct_zerocheck_terms_set = add_mle_list_by_expr(
327+
&mut virtual_polys,
328+
None,
329+
wits_in.iter().collect_vec(),
330+
&expr,
331+
&[],
332+
GoldilocksExt2::ONE,
333+
);
334+
assert!(distrinct_zerocheck_terms_set.len() == 2);
335+
assert!(virtual_polys.degree() == 2);
336+
337+
// 3x^3
338+
let expr: Expression<E> = 3 * x.expr() * x.expr() * x.expr();
339+
let distrinct_zerocheck_terms_set = add_mle_list_by_expr(
340+
&mut virtual_polys,
341+
None,
342+
wits_in.iter().collect_vec(),
343+
&expr,
344+
&[],
345+
GoldilocksExt2::ONE,
346+
);
347+
assert!(distrinct_zerocheck_terms_set.len() == 1);
348+
assert!(virtual_polys.degree() == 3);
349+
}
350+
}

0 commit comments

Comments
 (0)