Skip to content

support build virtual polynomials in expression style #937

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ cfg-if = "1.0"
clap = { version = "4.5", features = ["derive"] }
criterion = { version = "0.5", features = ["html_reports"] }
crossbeam-channel = "0.5"
either = { version = "1.15.*", features = ["serde"] }
itertools = "0.13"
num-bigint = { version = "0.4.6" }
num-derive = "0.4"
Expand Down
1 change: 1 addition & 0 deletions ceno_emul/src/rv32im.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

use anyhow::{Result, anyhow};
use ff_ext::{ExtensionField, SmallField};
use itertools::Either;
use multilinear_extensions::{Expression, impl_expr_from_unsigned};
use num_derive::ToPrimitive;
use strum_macros::{Display, EnumIter};
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/chip_handler/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
|| "require_equal",
|cb| {
cb.cs
.require_zero(name_fn, a.to_monomial_form() - b.to_monomial_form())
.require_zero(name_fn, a.get_monomial_form() - b.get_monomial_form())
},
)
}
Expand Down
6 changes: 3 additions & 3 deletions ceno_zkvm/src/chip_handler/global_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ use ff_ext::ExtensionField;

use super::GlobalStateRegisterMachineChipOperations;
use crate::{circuit_builder::CircuitBuilder, error::ZKVMError, structs::RAMType};
use multilinear_extensions::Expression;
use multilinear_extensions::{Expression, ToExpr};
use p3::field::PrimeCharacteristicRing;

impl<E: ExtensionField> GlobalStateRegisterMachineChipOperations<E> for CircuitBuilder<'_, E> {
fn state_in(&mut self, pc: Expression<E>, ts: Expression<E>) -> Result<(), ZKVMError> {
let record: Vec<Expression<E>> = vec![
Expression::Constant(E::BaseField::from_u64(RAMType::GlobalState as u64)),
E::BaseField::from_u64(RAMType::GlobalState as u64).expr(),
pc,
ts,
];
Expand All @@ -18,7 +18,7 @@ impl<E: ExtensionField> GlobalStateRegisterMachineChipOperations<E> for CircuitB

fn state_out(&mut self, pc: Expression<E>, ts: Expression<E>) -> Result<(), ZKVMError> {
let record: Vec<Expression<E>> = vec![
Expression::Constant(E::BaseField::from_u64(RAMType::GlobalState as u64)),
E::BaseField::from_u64(RAMType::GlobalState as u64).expr(),
pc,
ts,
];
Expand Down
4 changes: 2 additions & 2 deletions ceno_zkvm/src/chip_handler/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::iter::successors;

use ff_ext::ExtensionField;
use itertools::izip;
use multilinear_extensions::Expression;
use multilinear_extensions::{Expression, ToExpr};
use p3::field::PrimeCharacteristicRing;

pub fn rlc_chip_record<E: ExtensionField>(
Expand Down Expand Up @@ -31,7 +31,7 @@ pub fn power_sequence<E: ExtensionField>(
),
"expression must be constant or challenge"
);
successors(Some(Expression::Constant(E::BaseField::ONE)), move |prev| {
successors(Some(E::BaseField::ONE.expr()), move |prev| {
Some(prev.clone() * base.clone())
})
}
14 changes: 7 additions & 7 deletions ceno_zkvm/src/circuit_builder.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use itertools::{Itertools, chain};
use multilinear_extensions::{Expression, Fixed, Instance, StructuralWitIn, WitIn, WitnessId};
use multilinear_extensions::{
Expression, Fixed, Instance, StructuralWitIn, ToExpr, WitIn, WitnessId,
};
use serde::de::DeserializeOwned;
use std::{collections::HashMap, iter::once, marker::PhantomData};

Expand Down Expand Up @@ -256,11 +258,9 @@ impl<E: ExtensionField> ConstraintSystem<E> {
record: Vec<Expression<E>>,
) -> Result<(), ZKVMError> {
let rlc_record = self.rlc_chip_record(
std::iter::once(Expression::Constant(E::BaseField::from_u64(
rom_type as u64,
)))
.chain(record.clone())
.collect(),
std::iter::once(E::BaseField::from_u64(rom_type as u64).expr())
.chain(record.clone())
.collect(),
);
assert_eq!(
rlc_record.degree(),
Expand Down Expand Up @@ -432,7 +432,7 @@ impl<E: ExtensionField> ConstraintSystem<E> {
let assert_zero_expr = if assert_zero_expr.is_monomial_form() {
assert_zero_expr
} else {
let e = assert_zero_expr.to_monomial_form();
let e = assert_zero_expr.get_monomial_form();
assert!(e.is_monomial_form(), "failed to put into monomial form");
e
};
Expand Down
28 changes: 11 additions & 17 deletions ceno_zkvm/src/scheme/mock_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1147,22 +1147,16 @@ Hints:

let (mut gs_rs, rs_grp_by_anno, mut gs_ws, ws_grp_by_anno, gs) =
derive_ram_rws!(RAMType::GlobalState);
gs_rs.insert(eval_by_expr_with_instance(
&[],
&[],
&[],
&instance,
&challenges,
&gs_final,
));
gs_ws.insert(eval_by_expr_with_instance(
&[],
&[],
&[],
&instance,
&challenges,
&gs_init,
));
gs_rs.insert(
eval_by_expr_with_instance(&[], &[], &[], &instance, &challenges, &gs_final)
.right()
.unwrap(),
);
gs_ws.insert(
eval_by_expr_with_instance(&[], &[], &[], &instance, &challenges, &gs_init)
.right()
.unwrap(),
);

// gs stores { (pc, timestamp) }
find_rw_mismatch!(
Expand Down Expand Up @@ -1382,7 +1376,7 @@ mod tests {
GoldilocksExt2::ONE,
GoldilocksExt2::ZERO,
)),
Box::new(Expression::Constant(Goldilocks::from_u64(U5 as u64))),
Box::new(Goldilocks::from_u64(U5 as u64).expr()),
)),
Box::new(Expression::Challenge(
0,
Expand Down
55 changes: 29 additions & 26 deletions ceno_zkvm/src/scheme/prover.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
use ff_ext::ExtensionField;
use std::collections::{BTreeMap, BTreeSet, HashMap};

use itertools::{Itertools, enumerate, izip};
use itertools::{Either, Itertools, enumerate, izip};
use mpcs::{Point, PolynomialCommitmentScheme};
use multilinear_extensions::{
Expression,
mle::IntoMLE,
util::ceil_log2,
virtual_poly::{ArcMultilinearExtension, build_eq_x_r_vec},
virtual_polys::VirtualPolynomials,
virtual_polys::{VirtualPolynomials, VirtualPolynomialsBuilder},
};
use p3::field::{PrimeCharacteristicRing, dot_product};
use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};
use std::iter::Iterator;
use sumcheck::{
macros::{entered_span, exit_span},
structs::{IOPProverMessage, IOPProverState},
Expand Down Expand Up @@ -1230,10 +1232,15 @@ impl TowerProver {
let num_threads = optimal_sumcheck_threads(out_rt.len());

let eq: ArcMultilinearExtension<E> = build_eq_x_r_vec(&out_rt).into_mle().into();
let mut virtual_polys = VirtualPolynomials::<E>::new(num_threads, out_rt.len());

let mut expr_builder = VirtualPolynomialsBuilder::default();
let mut exprs =
Vec::<Expression<E>>::with_capacity(prod_specs.len() + logup_specs.len());
let eq_expr = expr_builder.lift(&eq);

for (s, alpha) in izip!(&prod_specs, &alpha_pows) {
if round < s.witness.len() {
let alpha_expr = Expression::Constant(Either::Right(*alpha));
let layer_polys = &s.witness[round];

// sanity check
Expand All @@ -1246,11 +1253,9 @@ impl TowerProver {
})
);

let layer_polys_product = layer_polys.iter().map(|layer_poly| expr_builder.lift(layer_poly)).product::<Expression<E>>();
// \sum_s eq(rt, s) * alpha^{i} * ([in_i0[s] * in_i1[s] * .... in_i{num_product_fanin}[s]])
virtual_polys.add_mle_list(
[vec![&eq], layer_polys.iter().collect()].concat(),
*alpha,
)
exprs.push(eq_expr.clone() * alpha_expr *layer_polys_product);
}
}

Expand All @@ -1259,37 +1264,36 @@ impl TowerProver {
if round < s.witness.len() {
let layer_polys = &s.witness[round];
// sanity check
assert_eq!(layer_polys.len(), 4); // p1, q1, p2, q2
assert_eq!(layer_polys.len(), 4); // p1, p2, q1, q2
assert!(
layer_polys
.iter()
.all(|f| f.evaluations().len() == 1 << (log_num_fanin * round)),
);

let (alpha_numerator, alpha_denominator) = (&alpha[0], &alpha[1]);
let (alpha_numerator, alpha_denominator) = (Expression::Constant(Either::Right(alpha[0])), Expression::Constant(Either::Right(alpha[1])));

let (q2, q1, p2, p1) = (
&layer_polys[3],
&layer_polys[2],
&layer_polys[1],
&layer_polys[0],
let (p1, p2, q1, q2) = (
expr_builder.lift(&layer_polys[0]),
expr_builder.lift(&layer_polys[1]),
expr_builder.lift(&layer_polys[2]),
expr_builder.lift(&layer_polys[3]),
);

// \sum_s eq(rt, s) * alpha_numerator^{i} * (p1 * q2 + p2 * q1)
virtual_polys.add_mle_list(vec![&eq, &p1, &q2], *alpha_numerator);
virtual_polys.add_mle_list(vec![&eq, &p2, &q1], *alpha_numerator);

// \sum_s eq(rt, s) * alpha_denominator^{i} * (q1 * q2)
virtual_polys.add_mle_list(vec![&eq, &q1, &q2], *alpha_denominator);
// \sum_s eq(rt, s) * (alpha_numerator^{i} * (p1 * q2 + p2 * q1) + alpha_denominator^{i} * q1 * q2)
exprs.push(eq_expr.clone() * (alpha_numerator * (p1 * q2.clone() + p2 * q1.clone()) + alpha_denominator * q1 * q2));
}
}

let wrap_batch_span = entered_span!("wrap_batch");
// NOTE: at the time of adding this span, visualizing it with the flamegraph layer
// shows it to be (inexplicably) much more time-consuming than the call to `prove_batch_polys`
// This is likely a bug in the tracing-flame crate.
let (sumcheck_proofs, state) = IOPProverState::prove(
virtual_polys,
expr_builder.to_virtual_polys(
num_threads,
out_rt.len(),
None,
&[exprs.into_iter().sum::<Expression<E>>()],
&[],
),
transcript,
);
exit_span!(wrap_batch_span);
Expand Down Expand Up @@ -1323,11 +1327,10 @@ impl TowerProver {
for (i, s) in enumerate(&logup_specs) {
if round < s.witness.len() {
// collect evals belong to current spec
// p1, q2, p2, q1
let p1 = *evals_iter.next().expect("insufficient evals length");
let q2 = *evals_iter.next().expect("insufficient evals length");
let p2 = *evals_iter.next().expect("insufficient evals length");
let q1 = *evals_iter.next().expect("insufficient evals length");
let q2 = *evals_iter.next().expect("insufficient evals length");
proofs.push_logup_evals_and_point(i, vec![p1, p2, q1, q2], rt_prime.clone());
}
}
Expand Down
8 changes: 5 additions & 3 deletions ceno_zkvm/src/scheme/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,9 +261,11 @@ pub(crate) fn wit_infer_by_expr<'a, E: ExtensionField, const N: usize>(
&|witness_id, _, _, _| structual_witnesses[witness_id as usize].clone(),
&|i| instance[i.0].clone(),
&|scalar| {
let scalar: ArcMultilinearExtension<E> = Arc::new(
DenseMultilinearExtension::from_evaluations_vec(0, vec![scalar]),
);
let scalar: ArcMultilinearExtension<E> =
Arc::new(DenseMultilinearExtension::from_evaluations_vec(
0,
vec![scalar.left().expect("do not support extension field")],
));
scalar
},
&|challenge_id, pow, scalar, offset| {
Expand Down
23 changes: 20 additions & 3 deletions ceno_zkvm/src/scheme/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,9 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
pi_evals,
&challenges,
&self.vk.initial_global_state_expr,
);
)
.right()
.unwrap();
prod_w *= initial_global_state;
let finalize_global_state = eval_by_expr_with_instance(
&[],
Expand All @@ -312,7 +314,9 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
pi_evals,
&challenges,
&self.vk.finalize_global_state_expr,
);
)
.right()
.unwrap();
prod_r *= finalize_global_state;
// check rw_set equality across all proofs
if prod_r != prod_w {
Expand Down Expand Up @@ -476,20 +480,26 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
.iter()
.map(|expr| {
eval_by_expr_with_instance(&[], &proof.wits_in_evals, &[], pi, challenges, expr)
.right()
.unwrap()
})
.collect();
let w_records_in_evals: Vec<_> = cs
.w_expressions
.iter()
.map(|expr| {
eval_by_expr_with_instance(&[], &proof.wits_in_evals, &[], pi, challenges, expr)
.right()
.unwrap()
})
.collect();
let lk_records_in_evals: Vec<_> = cs
.lk_expressions
.iter()
.map(|expr| {
eval_by_expr_with_instance(&[], &proof.wits_in_evals, &[], pi, challenges, expr)
.right()
.unwrap()
})
.collect();
let computed_evals = [
Expand Down Expand Up @@ -529,6 +539,8 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
challenges,
expr,
)
.right()
.unwrap()
})
.sum::<E>()
},
Expand All @@ -545,6 +557,8 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
// verify zero expression (degree = 1) statement, thus no sumcheck
if cs.assert_zero_expressions.iter().any(|expr| {
eval_by_expr_with_instance(&[], &proof.wits_in_evals, &[], pi, challenges, expr)
.right()
.unwrap()
!= E::ZERO
}) {
return Err(ZKVMError::VerifyError("zero expression != 0".into()));
Expand Down Expand Up @@ -813,7 +827,10 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
pi,
challenges,
expr,
) != expected_evals
)
.right()
.unwrap()
!= expected_evals
}) {
return Err(ZKVMError::VerifyError(
"record evaluate != expected_evals".into(),
Expand Down
4 changes: 2 additions & 2 deletions ceno_zkvm/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ impl<E: ExtensionField> StateCircuit<E> for GlobalState {
circuit_builder: &mut crate::circuit_builder::CircuitBuilder<E>,
) -> Result<Expression<E>, ZKVMError> {
let states: Vec<Expression<E>> = vec![
Expression::Constant(E::BaseField::from_u64(RAMType::GlobalState as u64)),
E::BaseField::from_u64(RAMType::GlobalState as u64).expr(),
circuit_builder.query_init_pc()?.expr(),
circuit_builder.query_init_cycle()?.expr(),
];
Expand All @@ -32,7 +32,7 @@ impl<E: ExtensionField> StateCircuit<E> for GlobalState {
circuit_builder: &mut crate::circuit_builder::CircuitBuilder<E>,
) -> Result<Expression<E>, ZKVMError> {
let states: Vec<Expression<E>> = vec![
Expression::Constant(E::BaseField::from_u64(RAMType::GlobalState as u64)),
E::BaseField::from_u64(RAMType::GlobalState as u64).expr(),
circuit_builder.query_end_pc()?.expr(),
circuit_builder.query_end_cycle()?.expr(),
];
Expand Down
Loading