Skip to content

Commit 62103a9

Browse files
authored
batched sumcheck suffix alignment (#870)
Closes #867. The new design applies only to the prover; the verifier logic remains unchanged. ### Benchmark The original e2e benchmark result shows no change, which is expected. ``` fibonacci_max_steps_1048576/prove_fibonacci/fibonacci_max_steps_1048576 time: [3.9061 s 3.9327 s 3.9608 s] change: [-2.0470% -1.2697% -0.4292%] (p = 0.01 < 0.05) Change within noise threshold. ```
1 parent 057a016 commit 62103a9

File tree

8 files changed

+102
-46
lines changed

8 files changed

+102
-46
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

multilinear_extensions/src/virtual_polys.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ pub struct VirtualPolynomials<'a, E: ExtensionField> {
1818

1919
impl<'a, E: ExtensionField> VirtualPolynomials<'a, E> {
2020
pub fn new(num_threads: usize, max_num_variables: usize) -> Self {
21+
debug_assert!(num_threads > 0);
2122
VirtualPolynomials {
2223
num_threads,
2324
polys: (0..num_threads)

sumcheck/src/prover.rs

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
285285
assert!(extrapolation_aux.len() == max_degree - 1);
286286
let num_polys = polynomial.flattened_ml_extensions.len();
287287
Self {
288+
max_num_variables: polynomial.aux_info.max_num_variables,
288289
challenges: Vec::with_capacity(polynomial.aux_info.max_num_variables),
289290
round: 0,
290291
poly: polynomial,
@@ -335,7 +336,6 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
335336
let chal = challenge.unwrap();
336337
self.challenges.push(chal);
337338
let r = self.challenges[self.round - 1];
338-
339339
self.fix_var(r.elements);
340340
}
341341
exit_span!(span);
@@ -345,22 +345,11 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
345345

346346
// Step 2: generate sum for the partial evaluated polynomial:
347347
// f(r_1, ... r_m,, x_{m+1}... x_n)
348-
//
349-
// To deal with different num_vars, we exploit a fact that for each product which num_vars < max_num_vars,
350-
// for it evaluation value we need to times 2^(max_num_vars - num_vars)
351-
// E.g. Giving multivariate poly f(X) = f_1(X1) + f_2(X), X1 \in {F}^{n'}, X \in {F}^{n}, |X1| := n', |X| = n, n' <= n
352-
// For i round univariate poly, f^i(x)
353-
// f^i[0] = \sum_b f(r, 0, b), b \in {0, 1}^{n-i-1}, r \in {F}^{n-i-1} chanllenge get from prev rounds
354-
// = \sum_b f_1(r, 0, b1) + f_2(r, 0, b), |b| >= |b1|, |b| - |b1| = n - n'
355-
// = 2^(|b| - |b1|) * \sum_b1 f_1(r, 0, b1) + \sum_b f_2(r, 0, b)
356-
// same applied on f^i[1]
357-
// It imply that, for every evals in f_1, to compute univariate poly, we just need to times a factor 2^(|b| - |b1|) for it evaluation value
358348
let span = entered_span!("products_sum");
359349
let AdditiveVec(products_sum) = self.poly.products.iter().fold(
360350
AdditiveVec::new(self.poly.aux_info.max_degree + 1),
361351
|mut products_sum, (coefficient, products)| {
362352
let span = entered_span!("sum");
363-
364353
let f = &self.poly.flattened_ml_extensions;
365354
let mut sum: Vec<E> = match products.len() {
366355
1 => sumcheck_code_gen!(1, false, |i| &f[products[i]]).to_vec(),
@@ -418,12 +407,22 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
418407
.collect()
419408
}
420409

410+
pub fn expected_numvars_at_round(&self) -> usize {
411+
// the first round starts from 1
412+
let num_vars = self.max_num_variables + 1 - self.round;
413+
debug_assert!(num_vars > 0, "make sumcheck work on constant");
414+
num_vars
415+
}
416+
421417
/// fix_var
422418
pub fn fix_var(&mut self, r: E) {
419+
let expected_numvars_at_round = self.expected_numvars_at_round();
423420
self.poly_index_fixvar_in_place
424421
.iter_mut()
425422
.zip_eq(self.poly.flattened_ml_extensions.iter_mut())
426423
.for_each(|(can_fixvar_in_place, poly)| {
424+
debug_assert!(poly.num_vars() <= expected_numvars_at_round);
425+
debug_assert!(poly.num_vars() > 0);
427426
if *can_fixvar_in_place {
428427
// in place
429428
let poly = Arc::get_mut(poly);
@@ -433,8 +432,10 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
433432
}
434433
};
435434
} else if poly.num_vars() > 0 {
436-
*poly = Arc::new(poly.fix_variables(&[r]));
437-
*can_fixvar_in_place = true;
435+
if expected_numvars_at_round == poly.num_vars() {
436+
*poly = Arc::new(poly.fix_variables(&[r]));
437+
*can_fixvar_in_place = true;
438+
}
438439
} else {
439440
panic!("calling sumcheck on constant")
440441
}
@@ -524,6 +525,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
524525
let max_degree = polynomial.aux_info.max_degree;
525526
let num_polys = polynomial.flattened_ml_extensions.len();
526527
let prover_state = Self {
528+
max_num_variables: polynomial.aux_info.max_num_variables,
527529
challenges: Vec::with_capacity(polynomial.aux_info.max_num_variables),
528530
round: 0,
529531
poly: polynomial,
@@ -579,7 +581,6 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
579581
let chal = challenge.unwrap();
580582
self.challenges.push(chal);
581583
let r = self.challenges[self.round - 1];
582-
583584
self.fix_var(r.elements);
584585
}
585586
exit_span!(span);
@@ -641,6 +642,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
641642

642643
/// fix_var
643644
pub fn fix_var_parallel(&mut self, r: E) {
645+
let expected_numvars_at_round = self.expected_numvars_at_round();
644646
self.poly_index_fixvar_in_place
645647
.par_iter_mut()
646648
.zip_eq(self.poly.flattened_ml_extensions.par_iter_mut())
@@ -654,8 +656,10 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
654656
}
655657
};
656658
} else if poly.num_vars() > 0 {
657-
*poly = Arc::new(poly.fix_variables_parallel(&[r]));
658-
*can_fixvar_in_place = true;
659+
if expected_numvars_at_round == poly.num_vars() {
660+
*poly = Arc::new(poly.fix_variables_parallel(&[r]));
661+
*can_fixvar_in_place = true;
662+
}
659663
} else {
660664
panic!("calling sumcheck on constant")
661665
}

sumcheck/src/structs.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ pub struct IOPProverState<'a, E: ExtensionField> {
4444
/// points with precomputed barycentric weights for extrapolating smaller
4545
/// degree uni-polys to `max_degree + 1` evaluations.
4646
pub(crate) extrapolation_aux: Vec<(Vec<E>, Vec<E>)>,
47+
pub(crate) max_num_variables: usize,
4748
/// record poly should fix variable in place or not
4849
pub(crate) poly_index_fixvar_in_place: Vec<bool>,
4950
}

sumcheck/src/test.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use crate::{
55
use ark_std::{rand::RngCore, test_rng};
66
use ff_ext::{ExtensionField, FromUniformBytes, GoldilocksExt2};
77
use multilinear_extensions::{
8+
util::max_usable_threads,
89
virtual_poly::{VPAuxInfo, VirtualPolynomial},
910
virtual_polys::VirtualPolynomials,
1011
};
@@ -13,17 +14,19 @@ use transcript::{BasicTranscript, Transcript};
1314

1415
#[test]
1516
fn test_sumcheck_with_different_degree() {
16-
let nv = vec![4, 5]; // test polynomial mixed with different num_var
17-
test_sumcheck_with_different_degree_helper::<GoldilocksExt2>(nv);
17+
// test polynomial mixed with different num_var
18+
let nv = vec![3, 4, 5];
19+
let num_polys = nv.len();
20+
for num_threads in 1..num_polys.min(max_usable_threads()) {
21+
test_sumcheck_with_different_degree_helper::<GoldilocksExt2>(num_threads, &nv);
22+
}
1823
}
1924

20-
fn test_sumcheck_with_different_degree_helper<E: ExtensionField>(nv: Vec<usize>) {
25+
fn test_sumcheck_with_different_degree_helper<E: ExtensionField>(num_threads: usize, nv: &[usize]) {
2126
let mut rng = test_rng();
2227
let degree = 2;
2328
let num_multiplicands_range = (degree, degree + 1);
2429
let num_products = 1;
25-
// TODO investigate error when num_threads > 1
26-
let num_threads = 1;
2730
let mut transcript = BasicTranscript::<E>::new(b"test");
2831

2932
let max_num_variables = *nv.iter().max().unwrap();
@@ -69,10 +72,11 @@ fn test_sumcheck_with_different_degree_helper<E: ExtensionField>(nv: Vec<usize>)
6972
.map(|c| c.elements)
7073
.collect::<Vec<_>>();
7174
assert_eq!(r.len(), max_num_variables);
75+
// r is right-aligned: each polynomial is evaluated on the last `max_num_variables` challenges
7276
assert!(
7377
input_polys
7478
.iter()
75-
.map(|(poly, _)| { poly.evaluate(&r[..poly.aux_info.max_num_variables]) })
79+
.map(|(poly, _)| { poly.evaluate(&r[r.len() - poly.aux_info.max_num_variables..]) })
7680
.sum::<E>()
7781
== subclaim.expected_evaluation,
7882
"wrong subclaim"

sumcheck_macro/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ itertools.workspace = true
1717
p3 = { path = "../p3" }
1818
proc-macro2 = "1.0.92"
1919
quote = "1.0"
20+
rand.workspace = true
2021
syn = { version = "2.0", features = ["full"] }
2122

2223
[dev-dependencies]

sumcheck_macro/examples/expand.rs

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,21 @@
22
/// ```sh
33
/// cargo expand --example expand
44
/// ```
5-
use ff_ext::ExtensionField;
6-
use ff_ext::GoldilocksExt2;
5+
use ff_ext::{ExtensionField, GoldilocksExt2};
76
use multilinear_extensions::{
87
mle::FieldType, util::largest_even_below, virtual_poly::VirtualPolynomial,
98
};
109
use p3::field::PrimeCharacteristicRing;
10+
use rand::rngs::OsRng;
1111
use sumcheck::util::{AdditiveArray, ceil_log2};
1212

1313
#[derive(Default)]
1414
struct Container<'a, E: ExtensionField> {
1515
poly: VirtualPolynomial<'a, E>,
16-
round: usize,
1716
}
1817

1918
fn main() {
20-
let c = Container::<GoldilocksExt2>::default();
19+
let c = Container::<GoldilocksExt2>::new();
2120
c.run();
2221
}
2322

@@ -26,4 +25,14 @@ impl<E: ExtensionField> Container<'_, E> {
2625
let _result: AdditiveArray<_, 4> =
2726
sumcheck_macro::sumcheck_code_gen!(3, false, |_| &self.poly.flattened_ml_extensions[0]);
2827
}
28+
29+
pub fn expected_numvars_at_round(&self) -> usize {
30+
1
31+
}
32+
33+
pub fn new() -> Self {
34+
Self {
35+
poly: VirtualPolynomial::random(3, (4, 5), 2, &mut OsRng).0,
36+
}
37+
}
2938
}

sumcheck_macro/src/lib.rs

Lines changed: 54 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -219,33 +219,68 @@ pub fn sumcheck_code_gen(input: proc_macro::TokenStream) -> proc_macro::TokenStr
219219
};
220220

221221
let iter = if parallalize {
222-
quote! {.into_par_iter().step_by(2).with_min_len(64)}
222+
quote! {.into_par_iter().step_by(2).rev().with_min_len(64)}
223223
} else {
224224
quote! {.step_by(2).rev()}
225225
};
226226

227227
// Generate the final AdditiveArray expression.
228+
229+
// special case: generate the product for a polynomial whose num_var is less than the currently expected num_var,
230+
// which happens when batching sumchecks with different num_vars
231+
let product = mul_exprs(
232+
(1..=degree)
233+
.map(|j: u32| {
234+
let v = ident(format!("v{j}"));
235+
quote! {#v[b]}
236+
})
237+
.collect(),
238+
);
239+
228240
let degree_plus_one = (degree + 1) as usize;
229241
quote! {
230-
let res = (0..largest_even_below(v1.len()))
231-
#iter
232-
.map(|b| {
233-
#additive_array_items
234-
})
235-
.sum::<AdditiveArray<_, #degree_plus_one>>();
236-
let res = if v1.len() == 1 {
237-
let b = 0;
238-
AdditiveArray::<_, #degree_plus_one>([#additive_array_first_item ; #degree_plus_one])
239-
} else {
240-
res
241-
};
242-
let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(v1.len()).max(1) + self.round - 1);
243-
if num_vars_multiplicity > 0 {
244-
AdditiveArray(res.0.map(|e| e * E::BaseField::from_u64(1 << num_vars_multiplicity)))
242+
// To deal with different num_vars, we exploit the fact that for each product whose num_vars < max_num_vars
243+
// we need its full sum, multiplied by 2^(bh_num_vars - num_vars), accumulated into the univariate computation
244+
// E.g. Giving multivariate poly f(X) = f_1(X1) + f_2(X), X1 \in {F}^{n'}, X \in {F}^{n}, |X1| := n', |X| = n, n' <= n
245+
// For i < n - n', to compute univariate poly, f^i(x), b is i-th round boolean hypercube
246+
// f^i[0] = \sum_b f(r, 0, b), b \in {0, 1}^{n-i-1}, r \in {F}^{i} are the challenges from previous rounds
247+
// = \sum_b f_1(b) + f_2(r, 0, b)
248+
// = 2^(|b| - |b1|) * \sum_b1 f_1(b1) + \sum_b f_2(r, 0, b)
249+
// b1 is suffix alignment with b
250+
// the same applies to f^i[1], f^i[2], ..., f^i[degree]
251+
// This implies that, for every evaluation of f_1, to compute the univariate poly we just need to multiply its value by the factor 2^(|b| - |b1|)
252+
253+
// NOTE: the current method works in suffix-alignment order
254+
let num_var = ceil_log2(v1.len());
255+
let expected_numvars_at_round = self.expected_numvars_at_round();
256+
if num_var < expected_numvars_at_round {
257+
// TODO optimize by caching computed result for later round reuse
258+
// need to figure out how to cache in one place to support base/extension field
259+
let mut sum = (0..largest_even_below(v1.len())).map(
260+
|b| {
261+
#product
262+
},
263+
).sum();
264+
// calculate multiplicity term
265+
// minus one because when the expected number of variables is n_i, the boolean hypercube dimension is only n_i - 1
266+
let num_vars_multiplicity = self.expected_numvars_at_round().saturating_sub(1).saturating_sub(num_var);
267+
if num_vars_multiplicity > 0 {
268+
sum *= E::BaseField::from_u64(1 << num_vars_multiplicity);
269+
}
270+
AdditiveArray::<_, #degree_plus_one>([sum; #degree_plus_one])
245271
} else {
246-
res
272+
if v1.len() == 1 {
273+
let b = 0;
274+
AdditiveArray::<_, #degree_plus_one>([#additive_array_first_item ; #degree_plus_one])
275+
} else {
276+
(0..largest_even_below(v1.len()))
277+
#iter
278+
.map(|b| {
279+
#additive_array_items
280+
})
281+
.sum::<AdditiveArray<_, #degree_plus_one>>()
282+
}
247283
}
248-
249284
}
250285
};
251286

@@ -314,7 +349,7 @@ pub fn sumcheck_code_gen(input: proc_macro::TokenStream) -> proc_macro::TokenStr
314349
// Generate the second match statement that maps f vars to AdditiveArray.
315350
out = quote! {
316351
{
317-
#out
352+
#out
318353
match (#match_input) {
319354
#match_arms
320355
_ => unreachable!(),

0 commit comments

Comments
 (0)