diff --git a/Cargo.lock b/Cargo.lock index 1d0535ea5..e7392a699 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -887,6 +887,9 @@ name = "either" version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +dependencies = [ + "serde", +] [[package]] name = "elf" @@ -1607,6 +1610,7 @@ dependencies = [ name = "multilinear_extensions" version = "0.1.0" dependencies = [ + "either", "env_logger", "ff_ext", "itertools 0.13.0", @@ -2760,6 +2764,7 @@ version = "0.1.0" dependencies = [ "criterion", "crossbeam-channel", + "either", "ff_ext", "itertools 0.13.0", "multilinear_extensions", diff --git a/Cargo.toml b/Cargo.toml index fd664b800..d90f59573 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,6 +35,7 @@ cfg-if = "1.0" clap = { version = "4.5", features = ["derive"] } criterion = { version = "0.5", features = ["html_reports"] } crossbeam-channel = "0.5" +either = { version = "1.15.*", features = ["serde"] } itertools = "0.13" num-bigint = { version = "0.4.6" } num-derive = "0.4" diff --git a/ceno_emul/src/rv32im.rs b/ceno_emul/src/rv32im.rs index d816c4d07..1e2c1c0d9 100644 --- a/ceno_emul/src/rv32im.rs +++ b/ceno_emul/src/rv32im.rs @@ -16,6 +16,7 @@ use anyhow::{Result, anyhow}; use ff_ext::{ExtensionField, SmallField}; +use itertools::Either; use multilinear_extensions::{Expression, impl_expr_from_unsigned}; use num_derive::ToPrimitive; use strum_macros::{Display, EnumIter}; diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index 452f88635..7ac1456dc 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -264,7 +264,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { || "require_equal", |cb| { cb.cs - .require_zero(name_fn, a.to_monomial_form() - b.to_monomial_form()) + .require_zero(name_fn, a.get_monomial_form() - b.get_monomial_form()) }, ) } diff --git a/ceno_zkvm/src/chip_handler/global_state.rs b/ceno_zkvm/src/chip_handler/global_state.rs index 9af586020..a3306c64e 100644 --- a/ceno_zkvm/src/chip_handler/global_state.rs +++ b/ceno_zkvm/src/chip_handler/global_state.rs @@ -2,13 +2,13 @@ use ff_ext::ExtensionField; use super::GlobalStateRegisterMachineChipOperations; use crate::{circuit_builder::CircuitBuilder, error::ZKVMError, structs::RAMType}; -use multilinear_extensions::Expression; +use multilinear_extensions::{Expression, ToExpr}; use p3::field::PrimeCharacteristicRing; impl GlobalStateRegisterMachineChipOperations for CircuitBuilder<'_, E> { fn state_in(&mut self, pc: Expression, ts: Expression) -> Result<(), ZKVMError> { let record: Vec> = vec![ - Expression::Constant(E::BaseField::from_u64(RAMType::GlobalState as u64)), + E::BaseField::from_u64(RAMType::GlobalState as u64).expr(), pc, ts, ]; @@ -18,7 +18,7 @@ impl GlobalStateRegisterMachineChipOperations for CircuitB fn state_out(&mut self, pc: Expression, ts: Expression) -> Result<(), ZKVMError> { let record: Vec> = vec![ - Expression::Constant(E::BaseField::from_u64(RAMType::GlobalState as u64)), + E::BaseField::from_u64(RAMType::GlobalState as u64).expr(), pc, ts, ]; diff --git a/ceno_zkvm/src/chip_handler/utils.rs b/ceno_zkvm/src/chip_handler/utils.rs index 8c3c6abae..3dd90636e 100644 --- a/ceno_zkvm/src/chip_handler/utils.rs +++ b/ceno_zkvm/src/chip_handler/utils.rs @@ -2,7 +2,7 @@ use std::iter::successors; use ff_ext::ExtensionField; use itertools::izip; -use multilinear_extensions::Expression; +use multilinear_extensions::{Expression, ToExpr}; use p3::field::PrimeCharacteristicRing; pub fn rlc_chip_record( @@ -31,7 +31,7 @@ pub fn power_sequence( ), "expression must be constant or challenge" ); - successors(Some(Expression::Constant(E::BaseField::ONE)), move |prev| { + successors(Some(E::BaseField::ONE.expr()), move |prev| { Some(prev.clone() * base.clone()) }) } diff --git a/ceno_zkvm/src/circuit_builder.rs b/ceno_zkvm/src/circuit_builder.rs index 6cfc18f50..59471dc61 100644 --- a/ceno_zkvm/src/circuit_builder.rs +++ b/ceno_zkvm/src/circuit_builder.rs @@ -1,5 +1,7 @@ use itertools::{Itertools, chain}; -use multilinear_extensions::{Expression, Fixed, Instance, StructuralWitIn, WitIn, WitnessId}; +use multilinear_extensions::{ + Expression, Fixed, Instance, StructuralWitIn, ToExpr, WitIn, WitnessId, +}; use serde::de::DeserializeOwned; use std::{collections::HashMap, iter::once, marker::PhantomData}; @@ -256,11 +258,9 @@ impl ConstraintSystem { record: Vec>, ) -> Result<(), ZKVMError> { let rlc_record = self.rlc_chip_record( - std::iter::once(Expression::Constant(E::BaseField::from_u64( - rom_type as u64, - ))) - .chain(record.clone()) - .collect(), + std::iter::once(E::BaseField::from_u64(rom_type as u64).expr()) + .chain(record.clone()) + .collect(), ); assert_eq!( rlc_record.degree(), @@ -432,7 +432,7 @@ impl ConstraintSystem { let assert_zero_expr = if assert_zero_expr.is_monomial_form() { assert_zero_expr } else { - let e = assert_zero_expr.to_monomial_form(); + let e = assert_zero_expr.get_monomial_form(); assert!(e.is_monomial_form(), "failed to put into monomial form"); e }; diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index fd2e1ce0a..f4474f228 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -1147,22 +1147,16 @@ Hints: let (mut gs_rs, rs_grp_by_anno, mut gs_ws, ws_grp_by_anno, gs) = derive_ram_rws!(RAMType::GlobalState); - gs_rs.insert(eval_by_expr_with_instance( - &[], - &[], - &[], - &instance, - &challenges, - &gs_final, - )); - gs_ws.insert(eval_by_expr_with_instance( - &[], - &[], - &[], - &instance, - &challenges, - &gs_init, - )); + gs_rs.insert( + eval_by_expr_with_instance(&[], &[], &[], &instance, &challenges, &gs_final) + .right() + .unwrap(), + ); + gs_ws.insert( + eval_by_expr_with_instance(&[], &[], &[], &instance, &challenges, &gs_init) + .right() + .unwrap(), + ); // gs stores { (pc, timestamp) } find_rw_mismatch!( @@ -1382,7 +1376,7 @@ mod tests { GoldilocksExt2::ONE, GoldilocksExt2::ZERO, )), - Box::new(Expression::Constant(Goldilocks::from_u64(U5 as u64))), + Box::new(Goldilocks::from_u64(U5 as u64).expr()), )), Box::new(Expression::Challenge( 0, diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 46aa58243..2db1043ee 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -1,16 +1,18 @@ use ff_ext::ExtensionField; use std::collections::{BTreeMap, BTreeSet, HashMap}; -use itertools::{Itertools, enumerate, izip}; +use itertools::{Either, Itertools, enumerate, izip}; use mpcs::{Point, PolynomialCommitmentScheme}; use multilinear_extensions::{ + Expression, mle::IntoMLE, util::ceil_log2, virtual_poly::{ArcMultilinearExtension, build_eq_x_r_vec}, - virtual_polys::VirtualPolynomials, + virtual_polys::{VirtualPolynomials, VirtualPolynomialsBuilder}, }; use p3::field::{PrimeCharacteristicRing, dot_product}; use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator}; +use std::iter::Iterator; use sumcheck::{ macros::{entered_span, exit_span}, structs::{IOPProverMessage, IOPProverState}, @@ -1230,10 +1232,15 @@ impl TowerProver { let num_threads = optimal_sumcheck_threads(out_rt.len()); let eq: ArcMultilinearExtension = build_eq_x_r_vec(&out_rt).into_mle().into(); - let mut virtual_polys = VirtualPolynomials::::new(num_threads, out_rt.len()); + + let mut expr_builder = VirtualPolynomialsBuilder::default(); + let mut exprs = + Vec::>::with_capacity(prod_specs.len() + logup_specs.len()); + let eq_expr = expr_builder.lift(&eq); for (s, alpha) in izip!(&prod_specs, &alpha_pows) { if round < s.witness.len() { + let alpha_expr = Expression::Constant(Either::Right(*alpha)); let layer_polys = &s.witness[round]; // sanity check @@ -1246,11 +1253,9 @@ impl TowerProver { }) ); + let layer_polys_product = layer_polys.iter().map(|layer_poly| expr_builder.lift(layer_poly)).product::>(); // \sum_s eq(rt, s) * alpha^{i} * ([in_i0[s] * in_i1[s] * .... in_i{num_product_fanin}[s]]) - virtual_polys.add_mle_list( - [vec![&eq], layer_polys.iter().collect()].concat(), - *alpha, - ) + exprs.push(eq_expr.clone() * alpha_expr *layer_polys_product); } } @@ -1259,37 +1264,36 @@ impl TowerProver { if round < s.witness.len() { let layer_polys = &s.witness[round]; // sanity check - assert_eq!(layer_polys.len(), 4); // p1, q1, p2, q2 + assert_eq!(layer_polys.len(), 4); // p1, p2, q1, q2 assert!( layer_polys .iter() .all(|f| f.evaluations().len() == 1 << (log_num_fanin * round)), ); - let (alpha_numerator, alpha_denominator) = (&alpha[0], &alpha[1]); + let (alpha_numerator, alpha_denominator) = (Expression::Constant(Either::Right(alpha[0])), Expression::Constant(Either::Right(alpha[1]))); - let (q2, q1, p2, p1) = ( - &layer_polys[3], - &layer_polys[2], - &layer_polys[1], - &layer_polys[0], + let (p1, p2, q1, q2) = ( + expr_builder.lift(&layer_polys[0]), + expr_builder.lift(&layer_polys[1]), + expr_builder.lift(&layer_polys[2]), + expr_builder.lift(&layer_polys[3]), ); - // \sum_s eq(rt, s) * alpha_numerator^{i} * (p1 * q2 + p2 * q1) - virtual_polys.add_mle_list(vec![&eq, &p1, &q2], *alpha_numerator); - virtual_polys.add_mle_list(vec![&eq, &p2, &q1], *alpha_numerator); - - // \sum_s eq(rt, s) * alpha_denominator^{i} * (q1 * q2) - virtual_polys.add_mle_list(vec![&eq, &q1, &q2], *alpha_denominator); + // \sum_s eq(rt, s) * (alpha_numerator^{i} * (p1 * q2 + p2 * q1) + alpha_denominator^{i} * q1 * q2) + exprs.push(eq_expr.clone() * (alpha_numerator * (p1 * q2.clone() + p2 * q1.clone()) + alpha_denominator * q1 * q2)); } } let wrap_batch_span = entered_span!("wrap_batch"); - // NOTE: at the time of adding this span, visualizing it with the flamegraph layer - // shows it to be (inexplicably) much more time-consuming than the call to `prove_batch_polys` - // This is likely a bug in the tracing-flame crate. let (sumcheck_proofs, state) = IOPProverState::prove( - virtual_polys, + expr_builder.to_virtual_polys( + num_threads, + out_rt.len(), + None, + &[exprs.into_iter().sum::>()], + &[], + ), transcript, ); exit_span!(wrap_batch_span); @@ -1323,11 +1327,10 @@ impl TowerProver { for (i, s) in enumerate(&logup_specs) { if round < s.witness.len() { // collect evals belong to current spec - // p1, q2, p2, q1 let p1 = *evals_iter.next().expect("insufficient evals length"); - let q2 = *evals_iter.next().expect("insufficient evals length"); let p2 = *evals_iter.next().expect("insufficient evals length"); let q1 = *evals_iter.next().expect("insufficient evals length"); + let q2 = *evals_iter.next().expect("insufficient evals length"); proofs.push_logup_evals_and_point(i, vec![p1, p2, q1, q2], rt_prime.clone()); } } diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index ededc5d82..f1a4e2d73 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -261,9 +261,11 @@ pub(crate) fn wit_infer_by_expr<'a, E: ExtensionField, const N: usize>( &|witness_id, _, _, _| structual_witnesses[witness_id as usize].clone(), &|i| instance[i.0].clone(), &|scalar| { - let scalar: ArcMultilinearExtension = Arc::new( - DenseMultilinearExtension::from_evaluations_vec(0, vec![scalar]), - ); + let scalar: ArcMultilinearExtension = + Arc::new(DenseMultilinearExtension::from_evaluations_vec( + 0, + vec![scalar.left().expect("do not support extension field")], + )); scalar }, &|challenge_id, pow, scalar, offset| { diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index d0ce434d3..634e524c3 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -303,7 +303,9 @@ impl> ZKVMVerifier pi_evals, &challenges, &self.vk.initial_global_state_expr, - ); + ) + .right() + .unwrap(); prod_w *= initial_global_state; let finalize_global_state = eval_by_expr_with_instance( &[], @@ -312,7 +314,9 @@ impl> ZKVMVerifier pi_evals, &challenges, &self.vk.finalize_global_state_expr, - ); + ) + .right() + .unwrap(); prod_r *= finalize_global_state; // check rw_set equality across all proofs if prod_r != prod_w { @@ -476,6 +480,8 @@ impl> ZKVMVerifier .iter() .map(|expr| { eval_by_expr_with_instance(&[], &proof.wits_in_evals, &[], pi, challenges, expr) + .right() + .unwrap() }) .collect(); let w_records_in_evals: Vec<_> = cs @@ -483,6 +489,8 @@ impl> ZKVMVerifier .iter() .map(|expr| { eval_by_expr_with_instance(&[], &proof.wits_in_evals, &[], pi, challenges, expr) + .right() + .unwrap() }) .collect(); let lk_records_in_evals: Vec<_> = cs @@ -490,6 +498,8 @@ impl> ZKVMVerifier .iter() .map(|expr| { eval_by_expr_with_instance(&[], &proof.wits_in_evals, &[], pi, challenges, expr) + .right() + .unwrap() }) .collect(); let computed_evals = [ @@ -529,6 +539,8 @@ impl> ZKVMVerifier challenges, expr, ) + .right() + .unwrap() }) .sum::() }, @@ -545,6 +557,8 @@ impl> ZKVMVerifier // verify zero expression (degree = 1) statement, thus no sumcheck if cs.assert_zero_expressions.iter().any(|expr| { eval_by_expr_with_instance(&[], &proof.wits_in_evals, &[], pi, challenges, expr) + .right() + .unwrap() != E::ZERO }) { return Err(ZKVMError::VerifyError("zero expression != 0".into())); @@ -813,7 +827,10 @@ impl> ZKVMVerifier pi, challenges, expr, - ) != expected_evals + ) + .right() + .unwrap() + != expected_evals }) { return Err(ZKVMError::VerifyError( "record evaluate != expected_evals".into(), diff --git a/ceno_zkvm/src/state.rs b/ceno_zkvm/src/state.rs index 74961390e..d2dfa6cfc 100644 --- a/ceno_zkvm/src/state.rs +++ b/ceno_zkvm/src/state.rs @@ -20,7 +20,7 @@ impl StateCircuit for GlobalState { circuit_builder: &mut crate::circuit_builder::CircuitBuilder, ) -> Result, ZKVMError> { let states: Vec> = vec![ - Expression::Constant(E::BaseField::from_u64(RAMType::GlobalState as u64)), + E::BaseField::from_u64(RAMType::GlobalState as u64).expr(), circuit_builder.query_init_pc()?.expr(), circuit_builder.query_init_cycle()?.expr(), ]; @@ -32,7 +32,7 @@ impl StateCircuit for GlobalState { circuit_builder: &mut crate::circuit_builder::CircuitBuilder, ) -> Result, ZKVMError> { let states: Vec> = vec![ - Expression::Constant(E::BaseField::from_u64(RAMType::GlobalState as u64)), + E::BaseField::from_u64(RAMType::GlobalState as u64).expr(), circuit_builder.query_end_pc()?.expr(), circuit_builder.query_end_cycle()?.expr(), ]; diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 203dd63e4..b2246d671 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -8,7 +8,7 @@ use crate::{ }; use ceno_emul::{CENO_PLATFORM, Platform, StepRecord}; use ff_ext::{ExtensionField, SmallField}; -use itertools::Itertools; +use itertools::{Either, Itertools}; use mpcs::{Point, PolynomialCommitmentScheme}; use multilinear_extensions::{ Expression, impl_expr_from_unsigned, virtual_poly::ArcMultilinearExtension, diff --git a/ceno_zkvm/src/uint.rs b/ceno_zkvm/src/uint.rs index da6ef2ee9..6aff57b25 100644 --- a/ceno_zkvm/src/uint.rs +++ b/ceno_zkvm/src/uint.rs @@ -156,7 +156,7 @@ impl UIntLimbs { limbs .into_iter() .take(Self::NUM_LIMBS) - .map(|limb| Expression::Constant(E::BaseField::from_u64(limb.into()))) + .map(|limb| E::BaseField::from_u64(limb.into()).expr()) .collect::>>(), ), carries: None, @@ -300,7 +300,7 @@ impl UIntLimbs { let k = C / 8; let shift_pows = { let mut shift_pows = Vec::with_capacity(k); - shift_pows.push(Expression::Constant(E::BaseField::ONE)); + shift_pows.push(E::BaseField::ONE.expr()); (0..k - 1).for_each(|_| shift_pows.push(shift_pows.last().unwrap() << 8)); shift_pows }; @@ -330,7 +330,7 @@ impl UIntLimbs { let k = C / 8; let shift_pows = { let mut shift_pows = Vec::with_capacity(k); - shift_pows.push(Expression::Constant(E::BaseField::ONE)); + shift_pows.push(E::BaseField::ONE.expr()); (0..k - 1).for_each(|_| shift_pows.push(shift_pows.last().unwrap() << 8)); shift_pows }; diff --git a/ceno_zkvm/src/uint/arithmetic.rs b/ceno_zkvm/src/uint/arithmetic.rs index 9cc859e3c..c9e592e4c 100644 --- a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -76,15 +76,14 @@ impl UIntLimbs { let Expression::Constant(c) = constant else { panic!("addend is not a constant type"); }; - let b = c.to_canonical_u64(); + let b = c + .left() + .expect("do not support extension field here") + .to_canonical_u64(); // convert Expression::Constant to limbs let b_limbs = (0..Self::NUM_LIMBS) - .map(|i| { - Expression::Constant(E::BaseField::from_u64( - (b >> (C * i)) & Self::LIMB_BIT_MASK, - )) - }) + .map(|i| E::BaseField::from_u64((b >> (C * i)) & Self::LIMB_BIT_MASK).expr()) .collect_vec(); self.internal_add(cb, &b_limbs, with_overflow) @@ -296,9 +295,10 @@ mod tests { circuit_builder::{CircuitBuilder, ConstraintSystem}, uint::UIntLimbs, }; - use ff_ext::{ExtensionField, FieldInto, GoldilocksExt2}; + use ff_ext::{ExtensionField, GoldilocksExt2}; use itertools::Itertools; - use multilinear_extensions::{Expression, ToExpr, utils::eval_by_expr}; + use multilinear_extensions::{ToExpr, utils::eval_by_expr}; + use p3::field::PrimeCharacteristicRing; type E = GoldilocksExt2; #[test] @@ -430,7 +430,7 @@ mod tests { let uint_b = UIntLimbs::::new(|| "uint_b", &mut cb).unwrap(); uint_a.add(|| "uint_c", &mut cb, &uint_b, overflow).unwrap() } else { - let const_b = Expression::Constant(const_b.unwrap().into_f()); + let const_b = E::BaseField::from_u64(const_b.unwrap()).expr(); uint_a .add_const(|| "uint_c", &mut cb, const_b, overflow) .unwrap() diff --git a/ceno_zkvm/src/utils.rs b/ceno_zkvm/src/utils.rs index 5f4c836e3..5ef0f4aef 100644 --- a/ceno_zkvm/src/utils.rs +++ b/ceno_zkvm/src/utils.rs @@ -224,7 +224,11 @@ pub fn add_mle_list_by_expr<'a, E: ExtensionField>( &|_| unreachable!(), &|witness_id| vec![(E::ONE, { vec![witness_id] })], &|structural_witness_id, _, _, _| vec![(E::ONE, { vec![structural_witness_id] })], - &|scalar| vec![(E::from(scalar), { vec![] })], + &|scalar| { + vec![(scalar.map_either(E::from, |scalar| scalar).into_inner(), { + vec![] + })] + }, &|challenge_id, pow, scalar, offset| { let challenge = challenges[challenge_id as usize]; vec![(challenge.exp_u64(pow as u64) * scalar + offset, vec![])] diff --git a/mpcs/src/lib.rs b/mpcs/src/lib.rs index 410f0773e..d7b5d7234 100644 --- a/mpcs/src/lib.rs +++ b/mpcs/src/lib.rs @@ -17,8 +17,7 @@ pub type Param = >::Param; pub type ProverParam = >::ProverParam; pub type VerifierParam = >::VerifierParam; -/// A point is a vector of num_var length -pub type Point = Vec; +pub type Point = multilinear_extensions::mle::Point; pub fn pcs_setup>( poly_size: usize, diff --git a/multilinear_extensions/Cargo.toml b/multilinear_extensions/Cargo.toml index f60844fe0..4e7cae4ba 100644 --- a/multilinear_extensions/Cargo.toml +++ b/multilinear_extensions/Cargo.toml @@ -10,6 +10,7 @@ repository.workspace = true version.workspace = true [dependencies] +either.workspace = true ff_ext = { path = "../ff_ext" } itertools.workspace = true p3.workspace = true diff --git a/multilinear_extensions/src/expression.rs b/multilinear_extensions/src/expression.rs index e025fd650..19951c421 100644 --- a/multilinear_extensions/src/expression.rs +++ b/multilinear_extensions/src/expression.rs @@ -1,4 +1,4 @@ -mod monomial; +pub mod monomial; pub mod utils; use std::{ @@ -11,6 +11,7 @@ use std::{ use serde::de::DeserializeOwned; use ff_ext::{ExtensionField, SmallField}; +use itertools::Either; use p3::field::PrimeCharacteristicRing; pub type WitnessId = u16; @@ -32,7 +33,7 @@ pub enum Expression { /// Public Values Instance(Instance), /// Constant poly - Constant(E::BaseField), + Constant(Either), /// This is the sum of two expressions Sum(Box>, Box>), /// This is the product of two expressions @@ -51,9 +52,21 @@ enum MonomialState { ProductTerm, } +#[macro_export] +macro_rules! combine_cumulative_either { + ($a:expr, $b:expr, $op:expr) => { + match ($a, $b) { + (Either::Left(c1), Either::Left(c2)) => Either::Left($op(c1, c2)), + (Either::Left(c1), Either::Right(c2)) => Either::Right($op(c2, c1)), + (Either::Right(c1), Either::Left(c2)) => Either::Right($op(c1, c2)), + (Either::Right(c1), Either::Right(c2)) => Either::Right($op(c2, c1)), + } + }; +} + impl Expression { - pub const ZERO: Expression = Expression::Constant(E::BaseField::ZERO); - pub const ONE: Expression = Expression::Constant(E::BaseField::ONE); + pub const ZERO: Expression = Expression::Constant(Either::Left(E::BaseField::ZERO)); + pub const ONE: Expression = Expression::Constant(Either::Left(E::BaseField::ONE)); pub fn degree(&self) -> usize { match self { @@ -75,7 +88,7 @@ impl Expression { fixed_in: &impl Fn(&Fixed) -> T, wit_in: &impl Fn(WitnessId) -> T, // witin id structural_wit_in: &impl Fn(WitnessId, usize, u32, usize) -> T, - constant: &impl Fn(E::BaseField) -> T, + constant: &impl Fn(Either) -> T, challenge: &impl Fn(ChallengeId, usize, E, E) -> T, sum: &impl Fn(T, T) -> T, product: &impl Fn(T, T) -> T, @@ -101,7 +114,7 @@ impl Expression { wit_in: &impl Fn(WitnessId) -> T, // witin id structural_wit_in: &impl Fn(WitnessId, usize, u32, usize) -> T, instance: &impl Fn(Instance) -> T, - constant: &impl Fn(E::BaseField) -> T, + constant: &impl Fn(Either) -> T, challenge: &impl Fn(ChallengeId, usize, E, E) -> T, sum: &impl Fn(T, T) -> T, product: &impl Fn(T, T) -> T, @@ -211,8 +224,8 @@ impl Expression { Self::is_monomial_form_inner(MonomialState::SumTerm, self) } - pub fn to_monomial_form(&self) -> Self { - self.to_monomial_form_inner() + pub fn get_monomial_form(&self) -> Self { + self.get_monomial_terms().into_iter().sum() } pub fn is_constant(&self) -> bool { @@ -225,7 +238,9 @@ impl Expression { Expression::WitIn(_) => false, Expression::StructuralWitIn(..) => false, Expression::Instance(_) => false, - Expression::Constant(c) => *c == E::BaseField::ZERO, + Expression::Constant(c) => c + .map_either(|c| c == E::BaseField::ZERO, |c| c == E::ZERO) + .into_inner(), Expression::Sum(a, b) => Self::is_zero_expr(a) && Self::is_zero_expr(b), Expression::Product(a, b) => Self::is_zero_expr(a) || Self::is_zero_expr(b), Expression::ScaledSum(x, a, b) => { @@ -276,10 +291,10 @@ impl Neg for Expression { | Expression::StructuralWitIn(..) | Expression::Instance(_) => Expression::ScaledSum( Box::new(self), - Box::new(Expression::Constant(-E::BaseField::ONE)), - Box::new(Expression::Constant(E::BaseField::ZERO)), + Box::new(Expression::Constant(Either::Left(-E::BaseField::ONE))), + Box::new(Expression::Constant(Either::Left(E::BaseField::ZERO))), ), - Expression::Constant(c1) => Expression::Constant(-c1), + Expression::Constant(c1) => Expression::Constant(c1.map_either(|c| -c, |c| -c)), Expression::Sum(a, b) => Expression::Sum(-a, -b), Expression::Product(a, b) => Expression::Product(-a, b.clone()), Expression::ScaledSum(x, a, b) => Expression::ScaledSum(x, -a, -b), @@ -322,14 +337,14 @@ impl Add for Expression { | (Expression::Fixed(_), Expression::Constant(_)) | (Expression::Instance(_), Expression::Constant(_)) => Expression::ScaledSum( Box::new(self), - Box::new(Expression::Constant(E::BaseField::ONE)), + Box::new(Expression::Constant(Either::Left(E::BaseField::ONE))), Box::new(rhs), ), (Expression::Constant(_), Expression::WitIn(_)) | (Expression::Constant(_), Expression::Fixed(_)) | (Expression::Constant(_), Expression::Instance(_)) => Expression::ScaledSum( Box::new(rhs), - Box::new(Expression::Constant(E::BaseField::ONE)), + Box::new(Expression::Constant(Either::Left(E::BaseField::ONE))), Box::new(self), ), // challenge + witness @@ -339,14 +354,14 @@ impl Add for Expression { | (Expression::Fixed(_), Expression::Challenge(..)) | (Expression::Instance(_), Expression::Challenge(..)) => Expression::ScaledSum( Box::new(self), - Box::new(Expression::Constant(E::BaseField::ONE)), + Box::new(Expression::Constant(Either::Left(E::BaseField::ONE))), Box::new(rhs), ), (Expression::Challenge(..), Expression::WitIn(_)) | (Expression::Challenge(..), Expression::Fixed(_)) | (Expression::Challenge(..), Expression::Instance(_)) => Expression::ScaledSum( Box::new(rhs), - Box::new(Expression::Constant(E::BaseField::ONE)), + Box::new(Expression::Constant(Either::Left(E::BaseField::ONE))), Box::new(self), ), // constant + challenge @@ -357,7 +372,12 @@ impl Add for Expression { | ( Expression::Challenge(challenge_id, pow, scalar, offset), Expression::Constant(c1), - ) => Expression::Challenge(*challenge_id, *pow, *scalar, *offset + *c1), + ) => Expression::Challenge( + *challenge_id, + *pow, + *scalar, + either::for_both!(*c1, c1 => *offset + c1), + ), // challenge + challenge ( @@ -377,7 +397,9 @@ impl Add for Expression { } // constant + constant - (Expression::Constant(c1), Expression::Constant(c2)) => Expression::Constant(*c1 + *c2), + (Expression::Constant(c1), Expression::Constant(c2)) => { + Expression::Constant(combine_cumulative_either!(*c1, *c2, |c1, c2| c1 + c2)) + } // constant + scaled sum (c1 @ Expression::Constant(_), Expression::ScaledSum(x, a, b)) @@ -458,7 +480,7 @@ impl Sub for Expression { | (Expression::Fixed(_), Expression::Constant(_)) | (Expression::Instance(_), Expression::Constant(_)) => Expression::ScaledSum( Box::new(self), - Box::new(Expression::Constant(E::BaseField::ONE)), + Box::new(Expression::Constant(Either::Left(E::BaseField::ONE))), Box::new(rhs.neg()), ), @@ -469,7 +491,7 @@ impl Sub for Expression { | (Expression::Constant(_), Expression::Fixed(_)) | (Expression::Constant(_), Expression::Instance(_)) => Expression::ScaledSum( Box::new(rhs), - Box::new(Expression::Constant(E::BaseField::ONE.neg())), + Box::new(Expression::Constant(Either::Left(E::BaseField::ONE.neg()))), Box::new(self), ), @@ -480,7 +502,7 @@ impl Sub for Expression { | (Expression::Fixed(_), Expression::Challenge(..)) | (Expression::Instance(_), Expression::Challenge(..)) => Expression::ScaledSum( Box::new(self), - Box::new(Expression::Constant(E::BaseField::ONE)), + Box::new(Expression::Constant(Either::Left(E::BaseField::ONE))), Box::new(rhs.neg()), ), @@ -491,7 +513,7 @@ impl Sub for Expression { | (Expression::Challenge(..), Expression::Fixed(_)) | (Expression::Challenge(..), Expression::Instance(_)) => Expression::ScaledSum( Box::new(rhs), - Box::new(Expression::Constant(E::BaseField::ONE.neg())), + Box::new(Expression::Constant(Either::Left(E::BaseField::ONE.neg()))), Box::new(self), ), @@ -499,13 +521,23 @@ impl Sub for Expression { ( Expression::Constant(c1), Expression::Challenge(challenge_id, pow, scalar, offset), - ) => Expression::Challenge(*challenge_id, *pow, *scalar, offset.neg() + *c1), + ) => Expression::Challenge( + *challenge_id, + *pow, + *scalar, + either::for_both!(*c1, c1 => offset.neg() + c1), + ), // challenge - constant ( Expression::Challenge(challenge_id, pow, scalar, offset), Expression::Constant(c1), - ) => Expression::Challenge(*challenge_id, *pow, *scalar, *offset - *c1), + ) => Expression::Challenge( + *challenge_id, + *pow, + *scalar, + either::for_both!(*c1, c1 => *offset - c1), + ), // challenge - challenge ( @@ -525,7 +557,14 @@ impl Sub for Expression { } // constant - constant - (Expression::Constant(c1), Expression::Constant(c2)) => Expression::Constant(*c1 - *c2), + (Expression::Constant(c1), Expression::Constant(c2)) => { + Expression::Constant(match (c1, c2) { + (Either::Left(c1), Either::Left(c2)) => Either::Left(*c1 - *c2), + (Either::Left(c1), Either::Right(c2)) => Either::Right(c2.neg() + *c1), + (Either::Right(c1), Either::Left(c2)) => Either::Right(*c1 - *c2), + (Either::Right(c1), Either::Right(c2)) => Either::Right(*c1 - *c2), + }) + } // constant - scalesum (c1 @ Expression::Constant(_), Expression::ScaledSum(x, a, b)) => { @@ -674,7 +713,7 @@ impl Mul for Expression { | (w @ Expression::Fixed(..), c @ Expression::Constant(_)) => Expression::ScaledSum( Box::new(w.clone()), Box::new(c.clone()), - Box::new(Expression::Constant(E::BaseField::ZERO)), + Box::new(Expression::Constant(Either::Left(E::BaseField::ZERO))), ), // challenge * witin // challenge * fixed @@ -684,7 +723,7 @@ impl Mul for Expression { | (w @ Expression::Fixed(..), c @ Expression::Challenge(..)) => Expression::ScaledSum( Box::new(w.clone()), Box::new(c.clone()), - Box::new(Expression::Constant(E::BaseField::ZERO)), + Box::new(Expression::Constant(Either::Left(E::BaseField::ZERO))), ), // instance * witin // instance * fixed @@ -694,7 +733,7 @@ impl Mul for Expression { | (w @ Expression::Fixed(..), c @ Expression::Instance(..)) => Expression::ScaledSum( Box::new(w.clone()), Box::new(c.clone()), - Box::new(Expression::Constant(E::BaseField::ZERO)), + Box::new(Expression::Constant(Either::Left(E::BaseField::ZERO))), ), // constant * challenge ( @@ -704,7 +743,12 @@ impl Mul for Expression { | ( Expression::Challenge(challenge_id, pow, scalar, offset), Expression::Constant(c1), - ) => Expression::Challenge(*challenge_id, *pow, *scalar * *c1, *offset * *c1), + ) => Expression::Challenge( + *challenge_id, + *pow, + either::for_both!(*c1, c1 => *scalar * c1), + either::for_both!(*c1, c1 => *offset * c1), + ), // challenge * challenge ( Expression::Challenge(challenge_id1, pow1, s1, offset1), @@ -755,7 +799,9 @@ impl Mul for Expression { } // constant * constant - (Expression::Constant(c1), Expression::Constant(c2)) => Expression::Constant(*c1 * *c2), + (Expression::Constant(c1), Expression::Constant(c2)) => { + Expression::Constant(combine_cumulative_either!(*c1, *c2, |c1, c2| c1 * c2)) + } // scaledsum * constant (Expression::ScaledSum(x, a, b), c2 @ Expression::Constant(_)) | (c2 @ Expression::Constant(_), Expression::ScaledSum(x, a, b)) => { @@ -859,7 +905,7 @@ impl ToExpr for Instance { impl> ToExpr for F { type Output = Expression; fn expr(&self) -> Expression { - Expression::Constant(*self) + Expression::Constant(Either::Left(*self)) } } @@ -884,7 +930,7 @@ macro_rules! impl_expr_from_unsigned { $( impl> From<$t> for Expression { fn from(value: $t) -> Self { - Expression::Constant(F::from_u64(value as u64)) + Expression::Constant(Either::Left(F::from_u64(value as u64))) } } )* @@ -899,7 +945,7 @@ macro_rules! impl_from_signed { impl> From<$t> for Expression { fn from(value: $t) -> Self { let reduced = (value as i128).rem_euclid(F::MODULUS_U64 as i128) as u64; - Expression::Constant(F::from_u64(reduced)) + Expression::Constant(Either::Left(F::from_u64(reduced))) } } )* @@ -956,9 +1002,13 @@ pub mod fmt { s } } - Expression::Constant(constant) => { - base_field::(constant, true).to_string() - } + Expression::Constant(constant) => constant + .as_ref() + .map_either( + |constant| base_field::(constant, true).to_string(), + |constant| field(constant).to_string(), + ) + .into_inner(), Expression::Fixed(fixed) => format!("{:?}", fixed), Expression::Instance(i) => format!("{:?}", i), Expression::Sum(left, right) => { @@ -989,9 +1039,6 @@ pub mod fmt { } pub fn field(field: &E) -> String { - let name = format!("{:?}", field); - let name = name.split('(').next().unwrap_or("ExtensionField"); - let data = field .as_bases() .iter() @@ -1004,7 +1051,7 @@ pub mod fmt { if only_one_limb { data[0].to_string() } else { - format!("{name}[{}]", data.join(",")) + format!("[{}]", data.join(",")) } } @@ -1019,9 +1066,9 @@ pub mod fmt { } else { // hex if value > F::MODULUS_U64 - (u32::MAX as u64 + u16::MAX as u64) { - parens(format!("-{:#x}", F::MODULUS_U64 - value), add_parens) + parens(format!("-{}", F::MODULUS_U64 - value), add_parens) } else { - format!("{value:#x}") + format!("{value}") } } } @@ -1061,6 +1108,7 @@ mod tests { use crate::expression::WitIn; use super::{Expression, ToExpr, fmt}; + use either::Either; use ff_ext::{FieldInto, GoldilocksExt2}; use p3::field::PrimeCharacteristicRing; @@ -1092,14 +1140,14 @@ mod tests { expr, Expression::ScaledSum( Box::new(x.expr()), - Box::new(Expression::Constant(3.into_f())), - Box::new(Expression::Constant(0.into_f())) + Box::new(Expression::Constant(Either::Left(3.into_f()))), + Box::new(Expression::Constant(Either::Left(0.into_f()))) ) ); // constant * challenge // 3 * (c^3 + 1) - let expr: Expression = Expression::Constant(3.into_f()); + let expr: Expression = Expression::Constant(Either::Left(3.into_f())); let c = Expression::Challenge(0, 3, 1.into_f(), 1.into_f()); assert_eq!( expr * c, diff --git a/multilinear_extensions/src/expression/monomial.rs b/multilinear_extensions/src/expression/monomial.rs index 78abe7ac6..fd45b7081 100644 --- a/multilinear_extensions/src/expression/monomial.rs +++ b/multilinear_extensions/src/expression/monomial.rs @@ -1,28 +1,35 @@ use ff_ext::ExtensionField; use itertools::{Itertools, chain, iproduct}; +use serde::{Deserialize, Serialize}; use super::Expression; +use crate::expression::ToExpr; use Expression::*; -use std::iter::Sum; +use p3::field::PrimeCharacteristicRing; +use std::{fmt::Display, iter::Sum}; impl Expression { - pub(super) fn to_monomial_form_inner(&self) -> Self { - Self::combine(self.distribute()).into_iter().sum() + pub fn get_monomial_terms(&self) -> Vec, Expression>> { + Self::combine(self.distribute()) + .into_iter() + // filter coeff = 0 monimial terms + .filter(|Term { scalar, .. }| *scalar != E::BaseField::ZERO.expr()) + .collect_vec() } - fn distribute(&self) -> Vec> { + fn distribute(&self) -> Vec, Expression>> { match self { Constant(_) => { vec![Term { - coeff: self.clone(), - vars: vec![], + scalar: self.clone(), + product: vec![], }] } Fixed(_) | WitIn(_) | StructuralWitIn(..) | Instance(_) | Challenge(..) => { vec![Term { - coeff: Expression::ONE, - vars: vec![self.clone()], + scalar: Expression::ONE, + product: vec![self.clone()], }] } @@ -30,50 +37,66 @@ impl Expression { Product(a, b) => iproduct!(a.distribute(), b.distribute()) .map(|(a, b)| Term { - coeff: &a.coeff * &b.coeff, - vars: chain!(&a.vars, &b.vars).cloned().collect(), + scalar: &a.scalar * &b.scalar, + product: chain!(&a.product, &b.product).cloned().collect(), }) .collect(), ScaledSum(x, a, b) => chain!( b.distribute(), iproduct!(x.distribute(), a.distribute()).map(|(x, a)| Term { - coeff: &x.coeff * &a.coeff, - vars: chain!(&x.vars, &a.vars).cloned().collect(), + scalar: &x.scalar * &a.scalar, + product: chain!(&x.product, &a.product).cloned().collect(), }) ) .collect(), } } - fn combine(mut terms: Vec>) -> Vec> { - for Term { vars, .. } in &mut terms { - vars.sort(); + fn combine( + mut terms: Vec, Expression>>, + ) -> Vec, Expression>> { + for Term { product, .. } in &mut terms { + product.sort(); } terms .into_iter() - .map(|Term { coeff, vars }| (vars, coeff)) + .map(|Term { scalar, product }| (product, scalar)) .into_group_map() .into_iter() - .map(|(vars, coeffs)| Term { - coeff: coeffs.into_iter().sum(), - vars, + .map(|(product, scalar)| Term { + scalar: scalar.into_iter().sum(), + product, }) .collect() } } -impl Sum> for Expression { - fn sum>>(iter: I) -> Self { - iter.map(|term| term.coeff * term.vars.into_iter().product::>()) +impl Sum, Expression>> for Expression { + fn sum, Expression>>>(iter: I) -> Self { + iter.map(|term| term.scalar * term.product.into_iter().product::>()) .sum() } } -#[derive(Clone, Debug)] -struct Term { - coeff: Expression, - vars: Vec>, +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Term { + pub scalar: S, + pub product: Vec

, +} + +impl Display for Term, Expression> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // join the product terms with " * " + let product_str = self + .product + .iter() + .map(|p| p.to_string()) + .collect::>() + .join(" * "); + // format as: scalar * (a * b * c) + write!(f, "{} * ({})", self.scalar, product_str) + } } #[cfg(test)] @@ -81,6 +104,7 @@ mod tests { use crate::expression::{Fixed as FixedS, utils::eval_by_expr_with_fixed}; use super::*; + use either::Either; use ff_ext::{FieldInto, FromUniformBytes, GoldilocksExt2 as E}; use p3::{field::PrimeCharacteristicRing, goldilocks::Goldilocks as F}; use rand::thread_rng; @@ -94,30 +118,40 @@ mod tests { let a = || Fixed(FixedS(0)); let b = || Fixed(FixedS(1)); let c = || Fixed(FixedS(2)); - let x = || WitIn(0); - let y = || WitIn(1); - let z = || WitIn(2); - let n = || Constant(104u64.into_f()); - let m = || Constant(-F::from_u64(599)); + let x1 = || WitIn(0); + let x2 = || WitIn(1); + let x3 = || WitIn(2); + let x4 = || WitIn(3); + let x5 = || WitIn(4); + let x6 = || WitIn(5); + let x7 = || WitIn(6); + + let n1 = || Constant(Either::Left(103u64.into_f())); + let n2 = || Constant(Either::Left(101u64.into_f())); + let m = || Constant(Either::Left(-F::from_u64(599))); let r = || Challenge(0, 1, E::ONE, E::ZERO); let test_exprs: &[Expression] = &[ - a() * x() * x(), + a() * x1() * x2(), a(), - x(), - n(), + x1(), + n1(), r(), - a() + b() + x() + y() + n() + m() + r(), - a() * x() * n() * r(), - x() * y() * z(), - (x() + y() + a()) * b() * (y() + z()) + c(), - (r() * x() + n() + z()) * m() * y(), - (b() + y() + m() * z()) * (x() + y() + c()), - a() * r() * x(), + a() + b() + x1() + x2() + n1() + m() + r(), + a() * x1() * n1() * r(), + x1() * x2() * x3(), + (x1() + x2() + a()) * b() * (x2() + x3()) + c(), + (r() * x1() + n1() + x3()) * m() * x2(), + (b() + x2() + m() * x3()) * (x1() + x2() + c()), + a() * r() * x1(), + x1() * (n1() * (x2() * x3() + x4() * x5())) + n2() * x2() * x4() + x1() * x6() * x7(), ]; for factored in test_exprs { - let monomials = factored.to_monomial_form_inner(); + let monomials = factored + .get_monomial_terms() + .into_iter() + .sum::>(); assert!(monomials.is_monomial_form()); // Check that the two forms are equivalent (Schwartz-Zippel test). @@ -140,6 +174,10 @@ mod tests { E::random(&mut rng), E::random(&mut rng), E::random(&mut rng), + E::random(&mut rng), + E::random(&mut rng), + E::random(&mut rng), + E::random(&mut rng), ]; let challenges = vec![ E::random(&mut rng), diff --git a/multilinear_extensions/src/expression/utils.rs b/multilinear_extensions/src/expression/utils.rs index 31c3f1d49..ef740a659 100644 --- a/multilinear_extensions/src/expression/utils.rs +++ b/multilinear_extensions/src/expression/utils.rs @@ -1,5 +1,8 @@ +use either::Either; use ff_ext::ExtensionField; +use crate::combine_cumulative_either; + use super::{Expression, StructuralWitIn, WitIn}; impl WitIn { @@ -34,7 +37,11 @@ pub fn eval_by_expr_with_fixed( &|f| fixed[f.0], &|witness_id| witnesses[witness_id as usize], &|witness_id, _, _, _| structural_witnesses[witness_id as usize], - &|scalar| scalar.into(), + &|scalar| { + scalar + .map_either(|scalar| E::from(scalar), |scalar| scalar) + .into_inner() + }, &|challenge_id, pow, scalar, offset| { // TODO cache challenge power to be acquired once for each power let challenge = challenges[challenge_id as usize]; @@ -53,20 +60,24 @@ pub fn eval_by_expr_with_instance( instance: &[E], challenges: &[E], expr: &Expression, -) -> E { - expr.evaluate_with_instance::( - &|f| fixed[f.0], - &|witness_id| witnesses[witness_id as usize], - &|witness_id, _, _, _| structural_witnesses[witness_id as usize], - &|i| instance[i.0], - &|scalar| scalar.into(), +) -> Either { + expr.evaluate_with_instance::>( + &|f| Either::Right(fixed[f.0]), + &|witness_id| Either::Right(witnesses[witness_id as usize]), + &|witness_id, _, _, _| Either::Right(structural_witnesses[witness_id as usize]), + &|i| Either::Right(instance[i.0]), + &|scalar| scalar, &|challenge_id, pow, scalar, offset| { // TODO cache challenge power to be acquired once for each power let challenge = challenges[challenge_id as usize]; - challenge.exp_u64(pow as u64) * scalar + offset + Either::Right(challenge.exp_u64(pow as u64) * scalar + offset) + }, + &|a, b| combine_cumulative_either!(a, b, |a, b| a + b), + &|a, b| combine_cumulative_either!(a, b, |a, b| a * b), + &|x, a, b| { + let ax = combine_cumulative_either!(a, x, |c1, c2| c1 * c2); + // ax + b + combine_cumulative_either!(ax, b, |c1, c2| c1 + c2) }, - &|a, b| a + b, - &|a, b| a * b, - &|x, a, b| a * x + b, ) } diff --git a/multilinear_extensions/src/lib.rs b/multilinear_extensions/src/lib.rs index e42011c02..0ee53aed5 100644 --- a/multilinear_extensions/src/lib.rs +++ b/multilinear_extensions/src/lib.rs @@ -1,5 +1,6 @@ #![deny(clippy::cargo)] #![feature(decl_macro)] +#![feature(strict_overflow_ops)] mod expression; pub use expression::*; pub mod macros; diff --git a/multilinear_extensions/src/mle.rs b/multilinear_extensions/src/mle.rs index fb5bc37af..02225e1d4 100644 --- a/multilinear_extensions/src/mle.rs +++ b/multilinear_extensions/src/mle.rs @@ -15,6 +15,9 @@ use rayon::iter::{ use serde::{Deserialize, Serialize, de::DeserializeOwned}; use std::fmt::Debug; +/// A point is a vector of num_var length +pub type Point = Vec; + pub trait MultilinearExtension: Send + Sync { type Output; fn fix_variables(&self, partial_point: &[E]) -> Self::Output; @@ -895,31 +898,8 @@ impl<'a, E: ExtensionField> MultilinearExtension for RangedMultilinearExtensi unimplemented!() } - fn fix_high_variables(&self, partial_point: &[E]) -> Self::Output { - // TODO: return error. - assert!( - partial_point.len() <= self.num_vars(), - "invalid size of partial point" - ); - if !partial_point.is_empty() { - let last = partial_point.last().unwrap(); - let inner = self.inner; - let half_size = self.offset >> 1; - let mut mle = op_mle!(inner, |evaluations| { - DenseMultilinearExtension::from_evaluations_ext_vec(self.num_vars() - 1, { - let (lo, hi) = evaluations[self.start..][..self.offset].split_at(half_size); - lo.par_iter() - .zip(hi) - .with_min_len(64) - .map(|(lo, hi)| *last * (*hi - *lo) + *lo) - .collect() - }) - }); - mle.fix_high_variables_in_place(&partial_point[..partial_point.len() - 1]); - mle - } else { - self.inner.clone() - } + fn fix_high_variables(&self, _partial_point: &[E]) -> Self::Output { + unimplemented!() } fn fix_high_variables_in_place(&mut self, _partial_point: &[E]) { @@ -934,8 +914,8 @@ impl<'a, E: ExtensionField> MultilinearExtension for RangedMultilinearExtensi self.num_vars } - fn fix_variables_parallel(&self, partial_point: &[E]) -> Self::Output { - self.inner.fix_variables_parallel(partial_point) + fn fix_variables_parallel(&self, _partial_point: &[E]) -> Self::Output { + unimplemented!() } fn fix_variables_in_place_parallel(&mut self, _partial_point: &[E]) { diff --git a/multilinear_extensions/src/virtual_poly.rs b/multilinear_extensions/src/virtual_poly.rs index c1c49d63b..e01f90b92 100644 --- a/multilinear_extensions/src/virtual_poly.rs +++ b/multilinear_extensions/src/virtual_poly.rs @@ -5,8 +5,10 @@ use std::{ use crate::{ macros::{entered_span, exit_span}, mle::{ArcDenseMultilinearExtension, DenseMultilinearExtension, MultilinearExtension}, + monomial::Term, util::{bit_decompose, create_uninit_vec, max_usable_threads}, }; +use either::Either; use ff_ext::ExtensionField; use itertools::Itertools; use p3::field::Field; @@ -15,10 +17,21 @@ use rayon::{ iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}, slice::ParallelSliceMut, }; -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Serialize, de::DeserializeOwned}; pub type ArcMultilinearExtension<'a, E> = Arc> + 'a>; +pub type MonomialTermsType<'a, E> = + Vec::BaseField, E>, ArcMultilinearExtension<'a, E>>>; + +#[derive(Default, Clone, Serialize, Deserialize)] +#[serde(bound( + serialize = "E::BaseField: Serialize", + deserialize = "E::BaseField: DeserializeOwned" +))] +pub struct MonomialTerms { + pub terms: Vec, usize>>, +} #[rustfmt::skip] /// A virtual polynomial is a sum of products of multilinear polynomials; @@ -51,13 +64,15 @@ pub type ArcMultilinearExtension<'a, E> = pub struct VirtualPolynomial<'a, E: ExtensionField> { /// Aux information about the multilinear polynomial pub aux_info: VPAuxInfo, - /// list of reference to products (as usize) of multilinear extension - pub products: Vec<(E, Vec)>, + // format (eq, monomial_form_formula) + pub products: Vec<(Option, MonomialTerms)>, /// Stores multilinear extensions in which product multiplicand can refer /// to. pub flattened_ml_extensions: Vec>, /// Pointers to the above poly extensions raw_pointers_lookup_table: HashMap, + + } #[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)] @@ -101,59 +116,131 @@ impl<'a, E: ExtensionField> VirtualPolynomial<'a, E> { phantom: PhantomData, }, // here `0` points to the first polynomial of `flattened_ml_extensions` - products: vec![(coefficient, vec![0])], + products: vec![( + None, + MonomialTerms { + terms: vec![Term { + scalar: Either::Right(coefficient), + product: vec![0], + }], + }, + )], flattened_ml_extensions: vec![mle], raw_pointers_lookup_table: hm, } } + /// registers a multilinear extension (MLE) in flat storage and tracks its pointer to ensure uniqueness. + /// + /// assigns a unique index to the given `mle` and asserts that it hasn't been registered before + /// by checking its raw pointer. + /// + /// panics if the same MLE (by pointer) is registered more than once. + pub fn register_mle(&mut self, mle: ArcMultilinearExtension<'a, E>) { + let mle_ptr: usize = Arc::as_ptr(&mle) as *const () as usize; + let curr_index = self.flattened_ml_extensions.len(); + self.flattened_ml_extensions.push(mle); + let prev = self.raw_pointers_lookup_table.insert(mle_ptr, curr_index); + assert!(prev.is_none(), "duplicate mle_ptr: {}", mle_ptr); + } + + pub fn add_monomial_terms( + &mut self, + zero_check_half_eq: Option>, + monomial_terms: MonomialTermsType<'a, E>, + ) -> (Option, &MonomialTerms) { + // TODO probably need to add sanity check for all monomial_terms poly equals to eq num_vars + 1 + + let terms = monomial_terms + .into_iter() + .map(|term| { + let Term { scalar, product } = term; + assert!(!product.is_empty(), "some term product is empty"); + // sanity check: all mle in product must have same num_vars() + assert!(product.iter().map(|m| { m.num_vars() }).all_equal()); + + self.aux_info.max_degree = max(self.aux_info.max_degree, product.len()); + let mut indexed_product = Vec::with_capacity(product.len()); + + for mle in product { + let mle_ptr: usize = Arc::as_ptr(&mle) as *const () as usize; + if let Some(index) = self.raw_pointers_lookup_table.get(&mle_ptr) { + indexed_product.push(*index) + } else { + let curr_index = self.flattened_ml_extensions.len(); + self.flattened_ml_extensions.push(mle); + self.raw_pointers_lookup_table.insert(mle_ptr, curr_index); + indexed_product.push(curr_index); + } + } + Term { + scalar, + product: indexed_product, + } + }) + .collect_vec(); + + let eq_index = if let Some(zero_check_half_eq) = zero_check_half_eq { + let eq_ptr: usize = Arc::as_ptr(&zero_check_half_eq) as *const () as usize; + let curr_index = self.flattened_ml_extensions.len(); + self.flattened_ml_extensions.push(zero_check_half_eq); + self.raw_pointers_lookup_table.insert(eq_ptr, curr_index); + Some(curr_index) + } else { + None + }; + + self.products.push((eq_index, MonomialTerms { terms })); + ( + eq_index, + self.products + .last() + .map(|(_, monomial_terms)| monomial_terms) + .unwrap(), + ) + } + /// Add a product of list of multilinear extensions to self /// Returns an error if the list is empty. /// - /// mle in mle_list must be in same num_vars() in same product, + /// mle in product must be in same num_vars() in same product, /// while different product can have different num_vars() /// /// The MLEs will be multiplied together, and then multiplied by the scalar - /// `coefficient`. + /// `scalar`. pub fn add_mle_list( &mut self, - mle_list: Vec>, - coefficient: E, - ) -> &[usize] { - let mle_list: Vec> = mle_list.into_iter().collect(); - let mut indexed_product = Vec::with_capacity(mle_list.len()); - - assert!(!mle_list.is_empty(), "input mle_list is empty"); - // sanity check: all mle in mle_list must have same num_vars() - assert!(mle_list.iter().map(|m| { m.num_vars() }).all_equal()); - - self.aux_info.max_degree = max(self.aux_info.max_degree, mle_list.len()); - - for mle in mle_list { - let mle_ptr: usize = Arc::as_ptr(&mle) as *const () as usize; - if let Some(index) = self.raw_pointers_lookup_table.get(&mle_ptr) { - indexed_product.push(*index) - } else { - let curr_index = self.flattened_ml_extensions.len(); - self.flattened_ml_extensions.push(mle); - self.raw_pointers_lookup_table.insert(mle_ptr, curr_index); - indexed_product.push(curr_index); - } - } - self.products.push((coefficient, indexed_product)); - &self.products.last().unwrap().1 + product: Vec>, + scalar: E, + ) -> &MonomialTerms { + let (_, monomial_terms) = self.add_monomial_terms( + None, + vec![Term { + scalar: Either::Right(scalar), + product, + }], + ); + monomial_terms } /// in-place merge with another virtual polynomial pub fn merge(&mut self, other: &VirtualPolynomial<'a, E>) { let start = entered_span!("virtual poly add"); - for (coeffient, products) in other.products.iter() { - let cur: Vec<_> = products + for (zero_check_half_eq_index, MonomialTerms { terms }) in other.products.iter() { + let new_monomial_term = terms .iter() - .map(|&x| other.flattened_ml_extensions[x].clone()) - .collect(); - - self.add_mle_list(cur, *coeffient); + .map(|Term { scalar, product }| Term { + scalar: *scalar, + product: product + .iter() + .map(|&x| other.flattened_ml_extensions[x].clone()) + .collect(), + }) + .collect_vec(); + let zero_check_eq = zero_check_half_eq_index.map(|zero_check_half_eq_index| { + other.flattened_ml_extensions[zero_check_half_eq_index].clone() + }); + self.add_monomial_terms(zero_check_eq, new_monomial_term); } exit_span!(start); } @@ -180,7 +267,15 @@ impl<'a, E: ExtensionField> VirtualPolynomial<'a, E> { let res = self .products .iter() - .map(|(c, p)| p.iter().map(|&i| evals[i]).product::() * *c) + .map(|(zero_check_half_eq, MonomialTerms { terms })| { + assert!(zero_check_half_eq.is_none(), "do not support evaluate with eq"); + terms + .iter() + .map(|Term { scalar, product }| { + either::for_both!(scalar, c => product.iter().map(|&i| evals[i]).product::() * *c) + }) + .sum() + }) .sum(); exit_span!(start); diff --git a/multilinear_extensions/src/virtual_polys.rs b/multilinear_extensions/src/virtual_polys.rs index 253d0ddda..64f1bb578 100644 --- a/multilinear_extensions/src/virtual_polys.rs +++ b/multilinear_extensions/src/virtual_polys.rs @@ -1,18 +1,26 @@ use std::{ collections::{BTreeMap, HashMap}, + marker::PhantomData, sync::Arc, }; use crate::{ + Expression, WitnessId, + expression::monomial::Term, util::ceil_log2, - virtual_poly::{ArcMultilinearExtension, VirtualPolynomial}, + utils::eval_by_expr_with_instance, + virtual_poly::{ArcMultilinearExtension, MonomialTerms, VirtualPolynomial}, }; +use either::Either; use ff_ext::ExtensionField; use itertools::Itertools; use p3::util::log2_strict_usize; use crate::util::transpose; +pub type MonomialTermsType<'a, E> = + Vec::BaseField, E>, &'a ArcMultilinearExtension<'a, E>>>; + #[derive(Debug, Default, Clone, Copy)] pub enum PolyMeta { #[default] @@ -20,6 +28,117 @@ pub enum PolyMeta { Phase2Only, } +/// a builder for constructing expressive polynomial formulas represented as expression, +/// primarily used in the sumcheck protocol. +/// +/// this struct manages witness identifiers and multilinear extensions (mles), +/// enabling reuse and deduplication of polynomial +#[derive(Default)] +pub struct VirtualPolynomialsBuilder<'a, E: ExtensionField> { + num_witin: WitnessId, + mles_storage: BTreeMap)>, + _phantom: PhantomData, +} + +impl<'a, E: ExtensionField> VirtualPolynomialsBuilder<'a, E> { + pub fn lift(&mut self, mle: &'a ArcMultilinearExtension<'a, E>) -> Expression { + let mle_ptr: usize = Arc::as_ptr(mle) as *const () as usize; + let (witin_id, _) = self.mles_storage.entry(mle_ptr).or_insert_with(|| { + let witin_id = self.num_witin; + self.num_witin = self.num_witin.strict_add(1); + (witin_id as usize, mle) + }); + + Expression::WitIn(*witin_id as u16) + } + + pub fn to_virtual_polys( + self, + num_threads: usize, + max_num_variables: usize, + half_eq_mles: Option>>, + expressions: &[Expression], + challenges: &[E], + ) -> VirtualPolynomials<'a, E> { + let mles_storage = self + .mles_storage + .values() + .collect::>() // collect into Vec<&(usize, &ArcMultilinearExtension)> + .into_iter() + .sorted_by_key(|(witin_id, _)| *witin_id) // sort by witin_id + .map(|(_, mle)| *mle) // extract &ArcMultilinearExtension + .collect::>(); + + // when half_eq is provided, then all monomial term need to be in same num_vars + let expected_num_vars_per_expr = if let Some(half_eq_mles) = half_eq_mles.as_ref() { + assert_eq!(half_eq_mles.len(), expressions.len()); + Some( + half_eq_mles + .iter() + .map(|half_eq| half_eq.num_vars() + 1) // half_eq + .collect_vec(), + ) + } else { + None + }; + + let mut virtual_polys = VirtualPolynomials::::new(num_threads, max_num_variables); + // register mles to assure index matching the arc_poly order + virtual_polys.register_mles(mles_storage.clone()); + + // convert expression into monomial_terms and add to virtual_polys + for (i, expression) in expressions.iter().enumerate() { + let monomial_terms_expr = expression.get_monomial_terms(); + let monomial_terms = monomial_terms_expr + .into_iter() + .map( + |Term { + scalar: scalar_expr, + product, + }| { + let expected_num_vars = expected_num_vars_per_expr.as_ref().and_then( + |expected_num_vars_per_expr| expected_num_vars_per_expr.get(i), + ); + + let product_mle = product + .into_iter() + .map(|expr| match expr { + Expression::WitIn(witin_id) => { + let mle = mles_storage[witin_id as usize]; + if let Some(expected_num_vars) = expected_num_vars { + assert_eq!(*expected_num_vars, mle.num_vars()); + } + mle + } + other => unimplemented!("un supported expression: {:?}", other), + }) + .collect_vec(); + let scalar = eval_by_expr_with_instance( + &[], + &[], + &[], + &[], + challenges, + &scalar_expr, + ); + Term { + scalar, + product: product_mle, + } + }, + ) + .collect_vec(); + virtual_polys.add_monomial_terms( + half_eq_mles + .as_ref() + .and_then(|half_eq_mles| half_eq_mles.get(i).cloned()), + monomial_terms, + ); + } + virtual_polys + } +} + pub struct VirtualPolynomials<'a, E: ExtensionField> { pub num_threads: usize, polys: Vec>, @@ -56,65 +175,164 @@ impl<'a, E: ExtensionField> VirtualPolynomials<'a, E> { .collect_vec() } - pub fn add_mle_list(&mut self, polys: Vec<&'a ArcMultilinearExtension<'a, E>>, coeff: E) { + /// registers a batch of multilinear extensions (MLEs) across all threads, + /// distributing each based on num_vars. + /// + /// for each input `mle`, if it is large enough (i.e., has more variables than `log2(num_threads)`), + /// it is split and assigned to the corresponding thread using `get_range_polys_by_thread_id`. + /// otherwise, the full polynomial is duplicated across all threads. + /// + /// the per-thread instances are registered locally and stored in `thread_based_mles_storage` + /// using the MLE’s raw pointer as the key to ensure uniqueness and reference consistency. + pub fn register_mles(&mut self, mles: Vec<&'a ArcMultilinearExtension<'a, E>>) { let log2_num_threads = log2_strict_usize(self.num_threads); - let (poly_meta, polys): (Vec, Vec>>) = polys + for mle in mles { + let mle_ptr: usize = Arc::as_ptr(mle) as *const () as usize; + let mles = (0..self.num_threads) + .map(|thread_id| { + let mle_thread_based = if mle.num_vars() > log2_num_threads { + self.get_range_polys_by_thread_id(thread_id, vec![mle]) + .remove(0) + } else { + // polynomial is too small + Arc::new(mle.get_ranged_mle(1, 0)) + }; + self.polys[thread_id].register_mle(mle_thread_based.clone()); + mle_thread_based + }) + .collect_vec(); + self.thread_based_mles_storage.insert(mle_ptr, mles); + } + } + + /// Adds a group of monomial terms to the current expression set. + /// + /// NOTE: When `zero_check_half_eq` is provided, no deduplication of equality constraints + /// is performed internally. It is the caller’s responsibility to ensure that + /// `zero_check_half_eq` contains only equality constraints unique to this `monomial_terms` group, + /// as reusing the same equality across different groups is semantically invalid. + pub fn add_monomial_terms( + &mut self, + zero_check_half_eq: Option<&'a ArcMultilinearExtension<'a, E>>, + monomial_terms: MonomialTermsType<'a, E>, + ) { + let log2_num_threads = log2_strict_usize(self.num_threads); + + // process eq and separate to thread + let zero_check_half_eq_per_threads = if let Some(zero_check_half_eq) = zero_check_half_eq { + Some( + (0..self.num_threads) + .map(|thread_id| { + if zero_check_half_eq.num_vars() > log2_num_threads { + self.get_range_polys_by_thread_id(thread_id, vec![zero_check_half_eq]) + .remove(0) + } else { + // polynomial is too small + Arc::new(zero_check_half_eq.get_ranged_mle(1, 0)) + } + }) + .collect_vec(), + ) + } else { + None + }; + + let (poly_meta, momomial_terms): (Vec<_>, Vec<_>) = monomial_terms .into_iter() - .map(|p| { - let mle_ptr: usize = Arc::as_ptr(p) as *const () as usize; - let poly_meta = if p.num_vars() > log2_num_threads { + .map(|Term { scalar, product }| { + assert!(!product.is_empty(), "some term product is empty"); + // all mle in product must have same num_vars() + assert!(product.iter().map(|m| { m.num_vars() }).all_equal()); + + let poly_meta = if product.first().unwrap().num_vars() > log2_num_threads { PolyMeta::Normal } else { // polynomial is too small PolyMeta::Phase2Only }; - let mles_cloned = if let Some(mles) = self.thread_based_mles_storage.get(&mle_ptr) { - mles.clone() - } else { - let mles = (0..self.num_threads) - .map(|thread_id| match poly_meta { - PolyMeta::Normal => self - .get_range_polys_by_thread_id(thread_id, vec![p]) - .remove(0), - PolyMeta::Phase2Only => Arc::new(p.get_ranged_mle(1, 0)), - }) - .collect_vec(); - let mles_cloned = mles.clone(); - self.thread_based_mles_storage.insert(mle_ptr, mles); - mles_cloned - }; - (poly_meta, mles_cloned) + + let product_per_threads: Vec>> = product + .into_iter() + .map(|p| { + let mle_ptr: usize = Arc::as_ptr(p) as *const () as usize; + let mles_cloned = + if let Some(mles) = self.thread_based_mles_storage.get(&mle_ptr) { + mles.clone() + } else { + let mles = (0..self.num_threads) + .map(|thread_id| match poly_meta { + PolyMeta::Normal => self + .get_range_polys_by_thread_id(thread_id, vec![p]) + .remove(0), + PolyMeta::Phase2Only => Arc::new(p.get_ranged_mle(1, 0)), + }) + .collect_vec(); + let mles_cloned = mles.clone(); + self.thread_based_mles_storage.insert(mle_ptr, mles); + mles_cloned + }; + mles_cloned + }) + .collect_vec(); + + // product -> thread to thread -> product + ( + poly_meta, + transpose(product_per_threads) + .into_iter() + .map(|product| Term { scalar, product }) + // return Vec, with total length equal #threads + .collect_vec(), + ) }) .unzip(); - // poly -> thread to thread -> poly - let polys = transpose(polys); - let poly_index: &[usize] = self - .polys - .iter_mut() - .zip_eq(polys) - .map(|(poly, polys)| poly.add_mle_list(polys, coeff)) - .collect_vec() - .first() - .expect("expect to get at index from first thread"); + let momomial_terms_threads = transpose(momomial_terms); + assert_eq!(momomial_terms_threads.len(), self.num_threads); - poly_index - .iter() - .zip_eq(&poly_meta) - .for_each(|(index, poly_meta)| { + // collect per thread momomial_terms and add to thread-based virtual_poly + let (hald_eq_index, monomial_term_product_index): (Option, &MonomialTerms) = + *momomial_terms_threads + .into_iter() + .zip_eq(self.polys.iter_mut()) + .enumerate() + .map(|(thread_id, (momomial_terms, virtual_poly))| { + let zero_check_half_eq = zero_check_half_eq_per_threads + .as_ref() + .and_then(|zero_check_half_eq| zero_check_half_eq.get(thread_id).cloned()); + virtual_poly.add_monomial_terms(zero_check_half_eq, momomial_terms) + }) + .collect_vec() + .first() + .expect(""); + + // update poly_meta w.r.t index, optionally record index for eq + if let Some((index, half_eq_mle)) = hald_eq_index.as_ref().zip(zero_check_half_eq.as_ref()) + { + let poly_meta = if half_eq_mle.num_vars() + 1 > log2_num_threads { + PolyMeta::Normal + } else { + // polynomial is too small + PolyMeta::Phase2Only + }; + self.poly_meta.insert(*index, poly_meta); + } + for (poly_meta, term) in poly_meta.iter().zip_eq(&monomial_term_product_index.terms) { + for index in &term.product { self.poly_meta.insert(*index, *poly_meta); - }); + } + } } - /// in-place merge with another virtual polynomial - pub fn merge(&mut self, other: &'a VirtualPolynomial<'a, E>) { - for (coeffient, products) in other.products.iter() { - let cur: Vec<_> = products - .iter() - .map(|&x| &other.flattened_ml_extensions[x]) - .collect(); - self.add_mle_list(cur, *coeffient); - } + // add with only single monomial term + pub fn add_mle_list(&mut self, polys: Vec<&'a ArcMultilinearExtension<'a, E>>, scalar: E) { + self.add_monomial_terms( + None, + vec![Term { + scalar: Either::Right(scalar), + product: polys, + }], + ); } /// return thread_based polynomial with its polynomial type @@ -133,4 +351,24 @@ impl<'a, E: ExtensionField> VirtualPolynomials<'a, E> { .map(|p| p.aux_info.max_degree) .unwrap_or_default() } + + /// in-place merge with another virtual polynomial + pub fn merge(&mut self, other: &'a VirtualPolynomial<'a, E>) { + for (zero_check_half_eq_index, MonomialTerms { terms }) in other.products.iter() { + let new_monomial_term = terms + .iter() + .map(|Term { scalar, product }| Term { + scalar: *scalar, + product: product + .iter() + .map(|&x| &other.flattened_ml_extensions[x]) + .collect(), + }) + .collect_vec(); + let zero_check_half_eq = zero_check_half_eq_index.map(|zero_check_half_eq_index| { + &other.flattened_ml_extensions[zero_check_half_eq_index] + }); + self.add_monomial_terms(zero_check_half_eq, new_monomial_term); + } + } } diff --git a/sumcheck/Cargo.toml b/sumcheck/Cargo.toml index a48ac0a5f..7eda88bac 100644 --- a/sumcheck/Cargo.toml +++ b/sumcheck/Cargo.toml @@ -10,6 +10,7 @@ repository.workspace = true version.workspace = true [dependencies] +either.workspace = true ff_ext = { path = "../ff_ext" } itertools.workspace = true p3.workspace = true diff --git a/sumcheck/src/prover.rs b/sumcheck/src/prover.rs index 88bbebbcd..a0a0b3fe2 100644 --- a/sumcheck/src/prover.rs +++ b/sumcheck/src/prover.rs @@ -5,9 +5,10 @@ use ff_ext::ExtensionField; use itertools::Itertools; use multilinear_extensions::{ mle::FieldType, + monomial::Term, op_mle, util::largest_even_below, - virtual_poly::VirtualPolynomial, + virtual_poly::{MonomialTerms, VirtualPolynomial}, virtual_polys::{PolyMeta, VirtualPolynomials}, }; use rayon::{ @@ -400,40 +401,47 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { // Step 2: generate sum for the partial evaluated polynomial: // f(r_1, ... r_m,, x_{m+1}... x_n) - let span = entered_span!("products_sum"); - let AdditiveVec(products_sum) = self.poly.products.iter().fold( + let span = entered_span!("build_uni_poly"); + let AdditiveVec(uni_polys) = self.poly.products.iter().fold( AdditiveVec::new(self.poly.aux_info.max_degree + 1), - |mut products_sum, (coefficient, prod)| { - let span = entered_span!("sum"); - let f = &self.poly.flattened_ml_extensions; - let f_type = &self.poly_meta; - let get_poly_meta = || f_type[prod[0]]; - let mut sum: Vec = match prod.len() { - 1 => sumcheck_code_gen!(1, false, |i| &f[prod[i]], || get_poly_meta()).to_vec(), - 2 => sumcheck_code_gen!(2, false, |i| &f[prod[i]], || get_poly_meta()).to_vec(), - 3 => sumcheck_code_gen!(3, false, |i| &f[prod[i]], || get_poly_meta()).to_vec(), - 4 => sumcheck_code_gen!(4, false, |i| &f[prod[i]], || get_poly_meta()).to_vec(), - 5 => sumcheck_code_gen!(5, false, |i| &f[prod[i]], || get_poly_meta()).to_vec(), - _ => unimplemented!("do not support degree {} > 5", prod.len()), - }; - exit_span!(span); - - sum.iter_mut().for_each(|sum| *sum *= *coefficient); - - let span = entered_span!("extrapolation"); - let extrapolation = (0..self.poly.aux_info.max_degree - prod.len()) - .map(|i| { - let (points, weights) = &self.extrapolation_aux[prod.len() - 1]; - let at = E::from_u64((prod.len() + 1 + i) as u64); - serial_extrapolate(points, weights, &sum, &at) - }) - .collect::>(); - sum.extend(extrapolation); - exit_span!(span); - let span = entered_span!("extend_extrapolate"); - products_sum += AdditiveVec(sum); - exit_span!(span); - products_sum + |mut uni_polys, (_half_eq_opt, MonomialTerms { terms })| { + for Term { + scalar, + product: prod, + } in terms + { + let f = &self.poly.flattened_ml_extensions; + let f_type = &self.poly_meta; + let get_poly_meta = || f_type[prod[0]]; + let mut uni_variate: Vec = match prod.len() { + 1 => sumcheck_code_gen!(1, false, |i| &f[prod[i]], || get_poly_meta()) + .to_vec(), + 2 => sumcheck_code_gen!(2, false, |i| &f[prod[i]], || get_poly_meta()) + .to_vec(), + 3 => sumcheck_code_gen!(3, false, |i| &f[prod[i]], || get_poly_meta()) + .to_vec(), + 4 => sumcheck_code_gen!(4, false, |i| &f[prod[i]], || get_poly_meta()) + .to_vec(), + 5 => sumcheck_code_gen!(5, false, |i| &f[prod[i]], || get_poly_meta()) + .to_vec(), + _ => unimplemented!("do not support degree {} > 5", prod.len()), + }; + + uni_variate + .iter_mut() + .for_each(|sum| either::for_both!(scalar, scalar => *sum *= *scalar)); + + let extrapolation = (0..self.poly.aux_info.max_degree - prod.len()) + .map(|i| { + let (points, weights) = &self.extrapolation_aux[prod.len() - 1]; + let at = E::from_u64((prod.len() + 1 + i) as u64); + serial_extrapolate(points, weights, &uni_variate, &at) + }) + .collect::>(); + uni_variate.extend(extrapolation); + uni_polys += AdditiveVec(uni_variate); + } + uni_polys }, ); exit_span!(span); @@ -441,7 +449,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { exit_span!(start); IOPProverMessage { - evaluations: products_sum, + evaluations: uni_polys, } } @@ -660,50 +668,50 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { // Step 2: generate sum for the partial evaluated polynomial: // f(r_1, ... r_m,, x_{m+1}... x_n) - let span = entered_span!("products_sum"); - let AdditiveVec(products_sum) = self + let span = entered_span!("build_uni_poly"); + let AdditiveVec(uni_polys) = self .poly .products .par_iter() .fold_with( AdditiveVec::new(self.poly.aux_info.max_degree + 1), - |mut products_sum, (coefficient, prod)| { - let span = entered_span!("sum"); - - let f = &self.poly.flattened_ml_extensions; - let f_type = &self.poly_meta; - let get_poly_meta = || f_type[prod[0]]; - let mut sum: Vec = match prod.len() { - 1 => sumcheck_code_gen!(1, true, |i| &f[prod[i]], || get_poly_meta()) - .to_vec(), - 2 => sumcheck_code_gen!(2, true, |i| &f[prod[i]], || get_poly_meta()) - .to_vec(), - 3 => sumcheck_code_gen!(3, true, |i| &f[prod[i]], || get_poly_meta()) - .to_vec(), - 4 => sumcheck_code_gen!(4, true, |i| &f[prod[i]], || get_poly_meta()) - .to_vec(), - 5 => sumcheck_code_gen!(5, true, |i| &f[prod[i]], || get_poly_meta()) - .to_vec(), - _ => unimplemented!("do not support degree {} > 5", prod.len()), - }; - exit_span!(span); - sum.iter_mut().for_each(|sum| *sum *= *coefficient); - - let span = entered_span!("extrapolation"); - let extrapolation = (0..self.poly.aux_info.max_degree - prod.len()) - .into_par_iter() - .map(|i| { - let (points, weights) = &self.extrapolation_aux[prod.len() - 1]; - let at = E::from_u64((prod.len() + 1 + i) as u64); - extrapolate(points, weights, &sum, &at) - }) - .collect::>(); - sum.extend(extrapolation); - exit_span!(span); - let span = entered_span!("extend_extrapolate"); - products_sum += AdditiveVec(sum); - exit_span!(span); - products_sum + |mut uni_polys, (_half_eq_opt, MonomialTerms { terms })| { + for Term { + scalar, + product: prod, + } in terms + { + let f = &self.poly.flattened_ml_extensions; + let f_type = &self.poly_meta; + let get_poly_meta = || f_type[prod[0]]; + let mut sum: Vec = match prod.len() { + 1 => sumcheck_code_gen!(1, true, |i| &f[prod[i]], || get_poly_meta()) + .to_vec(), + 2 => sumcheck_code_gen!(2, true, |i| &f[prod[i]], || get_poly_meta()) + .to_vec(), + 3 => sumcheck_code_gen!(3, true, |i| &f[prod[i]], || get_poly_meta()) + .to_vec(), + 4 => sumcheck_code_gen!(4, true, |i| &f[prod[i]], || get_poly_meta()) + .to_vec(), + 5 => sumcheck_code_gen!(5, true, |i| &f[prod[i]], || get_poly_meta()) + .to_vec(), + _ => unimplemented!("do not support degree {} > 5", prod.len()), + }; + sum.iter_mut() + .for_each(|sum| either::for_both!(*scalar, scalar => *sum *= scalar)); + + let extrapolation = (0..self.poly.aux_info.max_degree - prod.len()) + .into_par_iter() + .map(|i| { + let (points, weights) = &self.extrapolation_aux[prod.len() - 1]; + let at = E::from_u64((prod.len() + 1 + i) as u64); + extrapolate(points, weights, &sum, &at) + }) + .collect::>(); + sum.extend(extrapolation); + uni_polys += AdditiveVec(sum); + } + uni_polys }, ) .reduce_with(|acc, item| acc + item) @@ -713,7 +721,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { exit_span!(start); IOPProverMessage { - evaluations: products_sum, + evaluations: uni_polys, } }