Skip to content

Commit dd98c57

Browse files
committed
shared vector pool across whole witness inference
1 parent c7263ab commit dd98c57

File tree

5 files changed

+144
-101
lines changed

5 files changed

+144
-101
lines changed

ceno_zkvm/src/expression.rs

+17-7
Original file line numberDiff line numberDiff line change
@@ -143,24 +143,34 @@ impl<E: ExtensionField> Expression<E> {
143143
}
144144

145145
#[allow(clippy::too_many_arguments)]
146-
pub fn evaluate_with_instance_pool<T>(
146+
pub fn evaluate_with_instance_pool<T, PF1: Fn() -> Vec<E>, PF2: Fn() -> Vec<E::BaseField>>(
147147
&self,
148148
fixed_in: &impl Fn(&Fixed) -> T,
149149
wit_in: &impl Fn(WitnessId) -> T, // witin id
150150
instance: &impl Fn(Instance) -> T,
151151
constant: &impl Fn(E::BaseField) -> T,
152152
challenge: &impl Fn(ChallengeId, usize, E, E) -> T,
153-
sum: &impl Fn(T, T, &mut SimpleVecPool<Vec<E>>, &mut SimpleVecPool<Vec<E::BaseField>>) -> T,
154-
product: &impl Fn(T, T, &mut SimpleVecPool<Vec<E>>, &mut SimpleVecPool<Vec<E::BaseField>>) -> T,
153+
sum: &impl Fn(
154+
T,
155+
T,
156+
&mut SimpleVecPool<Vec<E>, PF1>,
157+
&mut SimpleVecPool<Vec<E::BaseField>, PF2>,
158+
) -> T,
159+
product: &impl Fn(
160+
T,
161+
T,
162+
&mut SimpleVecPool<Vec<E>, PF1>,
163+
&mut SimpleVecPool<Vec<E::BaseField>, PF2>,
164+
) -> T,
155165
scaled: &impl Fn(
156166
T,
157167
T,
158168
T,
159-
&mut SimpleVecPool<Vec<E>>,
160-
&mut SimpleVecPool<Vec<E::BaseField>>,
169+
&mut SimpleVecPool<Vec<E>, PF1>,
170+
&mut SimpleVecPool<Vec<E::BaseField>, PF2>,
161171
) -> T,
162-
pool_e: &mut SimpleVecPool<Vec<E>>,
163-
pool_b: &mut SimpleVecPool<Vec<E::BaseField>>,
172+
pool_e: &mut SimpleVecPool<Vec<E>, PF1>,
173+
pool_b: &mut SimpleVecPool<Vec<E::BaseField>, PF2>,
164174
) -> T {
165175
match self {
166176
Expression::Fixed(f) => fixed_in(f),

ceno_zkvm/src/scheme/mock_prover.rs

+22-55
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,7 @@ use ff_ext::ExtensionField;
2323
use generic_static::StaticTypeMap;
2424
use goldilocks::SmallField;
2525
use itertools::{Itertools, enumerate, izip};
26-
use multilinear_extensions::{
27-
mle::IntoMLEs, util::max_usable_threads, virtual_poly_v2::ArcMultilinearExtension,
28-
};
26+
use multilinear_extensions::{mle::IntoMLEs, virtual_poly_v2::ArcMultilinearExtension};
2927
use rand::thread_rng;
3028
use std::{
3129
collections::{HashMap, HashSet},
@@ -428,7 +426,6 @@ impl<'a, E: ExtensionField + Hash> MockProver<E> {
428426
challenge: Option<[E; 2]>,
429427
lkm: Option<LkMultiplicity>,
430428
) -> Result<(), Vec<MockProverError<E>>> {
431-
let n_threads = max_usable_threads();
432429
let program = Program::new(
433430
CENO_PLATFORM.pc_base(),
434431
CENO_PLATFORM.pc_base(),
@@ -476,12 +473,10 @@ impl<'a, E: ExtensionField + Hash> MockProver<E> {
476473
let (left, right) = expr.unpack_sum().unwrap();
477474
let right = right.neg();
478475

479-
let left_evaluated =
480-
wit_infer_by_expr(&[], wits_in, pi, &challenge, &left, n_threads);
476+
let left_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, &left);
481477
let left_evaluated = left_evaluated.get_base_field_vec();
482478

483-
let right_evaluated =
484-
wit_infer_by_expr(&[], wits_in, pi, &challenge, &right, n_threads);
479+
let right_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, &right);
485480
let right_evaluated = right_evaluated.get_base_field_vec();
486481

487482
// left_evaluated.len() ?= right_evaluated.len() due to padding instance
@@ -501,8 +496,7 @@ impl<'a, E: ExtensionField + Hash> MockProver<E> {
501496
}
502497
} else {
503498
// contains require_zero
504-
let expr_evaluated =
505-
wit_infer_by_expr(&[], wits_in, pi, &challenge, expr, n_threads);
499+
let expr_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, expr);
506500
let expr_evaluated = expr_evaluated.get_base_field_vec();
507501

508502
for (inst_id, element) in enumerate(expr_evaluated) {
@@ -525,7 +519,7 @@ impl<'a, E: ExtensionField + Hash> MockProver<E> {
525519
.iter()
526520
.zip_eq(cb.cs.lk_expressions_namespace_map.iter())
527521
{
528-
let expr_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, expr, n_threads);
522+
let expr_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, expr);
529523
let expr_evaluated = expr_evaluated.get_ext_field_vec();
530524

531525
// Check each lookup expr exists in t vec
@@ -556,7 +550,7 @@ impl<'a, E: ExtensionField + Hash> MockProver<E> {
556550
.map(|expr| {
557551
// TODO generalized to all inst_id
558552
let inst_id = 0;
559-
wit_infer_by_expr(&[], wits_in, pi, &challenge, expr, n_threads)
553+
wit_infer_by_expr(&[], wits_in, pi, &challenge, expr)
560554
.get_base_field_vec()[inst_id]
561555
.to_canonical_u64()
562556
})
@@ -748,7 +742,6 @@ Hints:
748742
witnesses: &ZKVMWitnesses<E>,
749743
pi: &PublicValues<u32>,
750744
) {
751-
let n_threads = max_usable_threads();
752745
let instance = pi
753746
.to_vec::<E>()
754747
.concat()
@@ -822,16 +815,10 @@ Hints:
822815
.zip(cs.lk_expressions_namespace_map.clone().into_iter())
823816
.zip(cs.lk_expressions_items_map.clone().into_iter())
824817
{
825-
let lk_input = (wit_infer_by_expr(
826-
&fixed,
827-
&witness,
828-
&pi_mles,
829-
&challenges,
830-
expr,
831-
n_threads,
832-
)
833-
.get_ext_field_vec())[..num_rows]
834-
.to_vec();
818+
let lk_input =
819+
(wit_infer_by_expr(&fixed, &witness, &pi_mles, &challenges, expr)
820+
.get_ext_field_vec())[..num_rows]
821+
.to_vec();
835822
rom_inputs.entry(rom_type).or_default().push((
836823
lk_input,
837824
circuit_name.clone(),
@@ -851,24 +838,17 @@ Hints:
851838
.iter()
852839
.zip(cs.lk_expressions_items_map.clone().into_iter())
853840
{
854-
let lk_table = wit_infer_by_expr(
855-
&fixed,
856-
&witness,
857-
&pi_mles,
858-
&challenges,
859-
&expr.values,
860-
n_threads,
861-
)
862-
.get_ext_field_vec()
863-
.to_vec();
841+
let lk_table =
842+
wit_infer_by_expr(&fixed, &witness, &pi_mles, &challenges, &expr.values)
843+
.get_ext_field_vec()
844+
.to_vec();
864845

865846
let multiplicity = wit_infer_by_expr(
866847
&fixed,
867848
&witness,
868849
&pi_mles,
869850
&challenges,
870851
&expr.multiplicity,
871-
n_threads,
872852
)
873853
.get_base_field_vec()
874854
.to_vec();
@@ -988,16 +968,10 @@ Hints:
988968
.zip_eq(cs.w_ram_types.iter())
989969
.filter(|((_, _), (ram_type, _))| *ram_type == $ram_type)
990970
{
991-
let write_rlc_records = (wit_infer_by_expr(
992-
fixed,
993-
witness,
994-
&pi_mles,
995-
&challenges,
996-
w_rlc_expr,
997-
n_threads,
998-
)
999-
.get_ext_field_vec())[..*num_rows]
1000-
.to_vec();
971+
let write_rlc_records =
972+
(wit_infer_by_expr(fixed, witness, &pi_mles, &challenges, w_rlc_expr)
973+
.get_ext_field_vec())[..*num_rows]
974+
.to_vec();
1001975

1002976
if $ram_type == RAMType::GlobalState {
1003977
// w_exprs = [GlobalState, pc, timestamp]
@@ -1012,7 +986,6 @@ Hints:
1012986
&pi_mles,
1013987
&challenges,
1014988
expr,
1015-
n_threads,
1016989
);
1017990
v.get_base_field_vec()[..*num_rows].to_vec()
1018991
})
@@ -1057,16 +1030,10 @@ Hints:
10571030
.zip_eq(cs.r_ram_types.iter())
10581031
.filter(|((_, _), (ram_type, _))| *ram_type == $ram_type)
10591032
{
1060-
let read_records = wit_infer_by_expr(
1061-
fixed,
1062-
witness,
1063-
&pi_mles,
1064-
&challenges,
1065-
r_expr,
1066-
n_threads,
1067-
)
1068-
.get_ext_field_vec()[..*num_rows]
1069-
.to_vec();
1033+
let read_records =
1034+
wit_infer_by_expr(fixed, witness, &pi_mles, &challenges, r_expr)
1035+
.get_ext_field_vec()[..*num_rows]
1036+
.to_vec();
10701037
let mut records = vec![];
10711038
for (row, record) in enumerate(read_records) {
10721039
// TODO: return error

ceno_zkvm/src/scheme/prover.rs

+63-15
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ use multilinear_extensions::{
1313
virtual_poly::build_eq_x_r_vec,
1414
virtual_poly_v2::ArcMultilinearExtension,
1515
};
16-
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
16+
use rayon::iter::{
17+
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator,
18+
};
1719
use sumcheck::{
1820
macros::{entered_span, exit_span},
1921
structs::{IOPProverMessage, IOPProverStateV2},
@@ -25,16 +27,18 @@ use crate::{
2527
error::ZKVMError,
2628
expression::Instance,
2729
scheme::{
28-
constants::{MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, NUM_FANIN, NUM_FANIN_LOGUP},
30+
constants::{MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, MIN_PAR_SIZE, NUM_FANIN, NUM_FANIN_LOGUP},
2931
utils::{
3032
infer_tower_logup_witness, infer_tower_product_witness, interleaving_mles_to_mles,
31-
wit_infer_by_expr,
33+
wit_infer_by_expr, wit_infer_by_expr_pool,
3234
},
3335
},
3436
structs::{
3537
Point, ProvingKey, TowerProofs, TowerProver, TowerProverSpec, ZKVMProvingKey, ZKVMWitnesses,
3638
},
37-
utils::{get_challenge_pows, next_pow2_instance_padding, optimal_sumcheck_threads},
39+
utils::{
40+
SimpleVecPool, get_challenge_pows, next_pow2_instance_padding, optimal_sumcheck_threads,
41+
},
3842
virtual_polys::VirtualPolynomials,
3943
};
4044

@@ -238,6 +242,21 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
238242
let wit_inference_span = entered_span!("wit_inference", profiling_3 = true);
239243
// main constraint: read/write record witness inference
240244
let record_span = entered_span!("record");
245+
let len = witnesses[0].evaluations().len();
246+
let mut pool_e: SimpleVecPool<Vec<_>, _> = SimpleVecPool::new(|| {
247+
(0..len)
248+
.into_par_iter()
249+
.with_min_len(MIN_PAR_SIZE)
250+
.map(|_| E::ZERO)
251+
.collect::<Vec<E>>()
252+
});
253+
let mut pool_b: SimpleVecPool<Vec<_>, _> = SimpleVecPool::new(|| {
254+
(0..len)
255+
.into_par_iter()
256+
.with_min_len(MIN_PAR_SIZE)
257+
.map(|_| E::BaseField::ZERO)
258+
.collect::<Vec<E::BaseField>>()
259+
});
241260
let n_threads = max_usable_threads();
242261
let records_wit: Vec<ArcMultilinearExtension<'_, E>> = cs
243262
.r_expressions
@@ -246,7 +265,16 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
246265
.chain(cs.lk_expressions.iter())
247266
.map(|expr| {
248267
assert_eq!(expr.degree(), 1);
249-
wit_infer_by_expr(&[], &witnesses, pi, challenges, expr, n_threads)
268+
wit_infer_by_expr_pool(
269+
&[],
270+
&witnesses,
271+
pi,
272+
challenges,
273+
expr,
274+
n_threads,
275+
&mut pool_e,
276+
&mut pool_b,
277+
)
250278
})
251279
.collect();
252280
let (r_records_wit, w_lk_records_wit) = records_wit.split_at(cs.r_expressions.len());
@@ -526,7 +554,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
526554
// sanity check in debug build and output != instance index for zero check sumcheck poly
527555
if cfg!(debug_assertions) {
528556
let expected_zero_poly =
529-
wit_infer_by_expr(&[], &witnesses, pi, challenges, expr, n_threads);
557+
wit_infer_by_expr(&[], &witnesses, pi, challenges, expr);
530558
let top_100_errors = expected_zero_poly
531559
.get_base_field_vec()
532560
.iter()
@@ -702,21 +730,41 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
702730
let wit_inference_span = entered_span!("wit_inference");
703731
// main constraint: lookup denominator and numerator record witness inference
704732
let record_span = entered_span!("record");
733+
let len = witnesses[0].evaluations().len();
734+
let mut pool_e: SimpleVecPool<Vec<_>, _> = SimpleVecPool::new(|| {
735+
(0..len)
736+
.into_par_iter()
737+
.with_min_len(MIN_PAR_SIZE)
738+
.map(|_| E::ZERO)
739+
.collect::<Vec<E>>()
740+
});
741+
let mut pool_b: SimpleVecPool<Vec<_>, _> = SimpleVecPool::new(|| {
742+
(0..len)
743+
.into_par_iter()
744+
.with_min_len(MIN_PAR_SIZE)
745+
.map(|_| E::BaseField::ZERO)
746+
.collect::<Vec<E::BaseField>>()
747+
});
705748
let n_threads = max_usable_threads();
706749
let mut records_wit: Vec<ArcMultilinearExtension<'_, E>> = cs
707750
.r_table_expressions
708-
.par_iter()
751+
.iter()
709752
.map(|r| &r.expr)
710-
.chain(cs.w_table_expressions.par_iter().map(|w| &w.expr))
711-
.chain(
712-
cs.lk_table_expressions
713-
.par_iter()
714-
.map(|lk| &lk.multiplicity),
715-
)
716-
.chain(cs.lk_table_expressions.par_iter().map(|lk| &lk.values))
753+
.chain(cs.w_table_expressions.iter().map(|w| &w.expr))
754+
.chain(cs.lk_table_expressions.iter().map(|lk| &lk.multiplicity))
755+
.chain(cs.lk_table_expressions.iter().map(|lk| &lk.values))
717756
.map(|expr| {
718757
assert_eq!(expr.degree(), 1);
719-
wit_infer_by_expr(&fixed, &witnesses, pi, challenges, expr, n_threads)
758+
wit_infer_by_expr_pool(
759+
&fixed,
760+
&witnesses,
761+
pi,
762+
challenges,
763+
expr,
764+
n_threads,
765+
&mut pool_e,
766+
&mut pool_b,
767+
)
720768
})
721769
.collect();
722770
let max_log2_num_instance = records_wit.iter().map(|mle| mle.num_vars()).max().unwrap();

0 commit comments

Comments
 (0)