Skip to content

Commit c3914bc

Browse files
committed
Fix VID verifier
1 parent e50297d commit c3914bc

File tree

1 file changed

+155
-36
lines changed

1 file changed

+155
-36
lines changed

saffron/src/vid.rs

Lines changed: 155 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ use rand::{CryptoRng, RngCore};
1717
use rayon::iter::{
1818
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
1919
};
20+
use std::time::Instant;
2021

2122
#[derive(Debug, Clone)]
2223
pub struct VIDProof {
@@ -163,6 +164,7 @@ pub fn divide_by_sub_vanishing_poly(
163164
pub fn prove_vid_ipa<RNG>(
164165
srs: &SRS<Curve>,
165166
domain: EvaluationDomains<ScalarField>,
167+
bases_d2: &[DensePolynomial<ScalarField>],
166168
group_map: &<Curve as CommitmentCurve>::Map,
167169
rng: &mut RNG,
168170
per_node_size: usize,
@@ -193,13 +195,16 @@ where
193195
};
194196

195197
let combined_data_poly: DensePolynomial<ScalarField> =
196-
Evaluations::from_vec_and_domain(combined_data, domain.d1).interpolate_by_ref();
198+
Evaluations::from_vec_and_domain(combined_data.clone(), domain.d1).interpolate_by_ref();
197199

198200
let combined_data_d2 = combined_data_poly.evaluate_over_domain_by_ref(domain.d2);
199201

200202
let mut proofs: Vec<VIDProof> = vec![];
201203
let proofs_number = domain.d2.size() / per_node_size;
202204

205+
println!("proofs_number: {:?}", proofs_number);
206+
println!("per_node_size: {:?}", per_node_size);
207+
203208
let fq_sponge_common = fq_sponge.clone();
204209

205210
println!("computing all divisors");
@@ -229,23 +234,29 @@ where
229234
.collect();
230235

231236
// divisors with cosets
237+
// div_i(X) = X^per_node_size - w^{i*per_node_size}
238+
// div_i(X) is supposed to be zero on all elements from coset i
232239
let all_divisors: Vec<DensePolynomial<ScalarField>> = (0..proofs_number)
233240
.into_par_iter()
234-
.map(|i| {
241+
.map(|node_ix| {
235242
let mut res = DensePolynomial {
236243
coeffs: vec![ScalarField::zero(); per_node_size + 1],
237244
};
238-
res[0] = -all_omegas[i * per_node_size];
245+
res[0] = -all_omegas[node_ix * per_node_size];
239246
res[per_node_size] = ScalarField::one();
240247
res
241248
})
242249
.collect();
243250

244-
assert!(all_divisors[5].evaluate(&all_omegas[per_node_size + 5]) == ScalarField::zero());
251+
for i in 0..3 {
252+
assert!(all_divisors[i].evaluate(&all_omegas[proofs_number + i]) == ScalarField::zero());
253+
}
245254

246255
for node_ix in 0..proofs_number {
256+
let start = Instant::now();
257+
247258
// TEMPORARILY skip most iterations
248-
if node_ix > 4 {
259+
if node_ix > 1 {
249260
continue;
250261
}
251262

@@ -254,10 +265,11 @@ where
254265
.map(|j| j * proofs_number + node_ix)
255266
.collect();
256267

257-
let coset_omega = all_omegas[node_ix * per_node_size].clone();
268+
// c such that (X^N - c) = 0 for all elements in the current coset
269+
let coset_divisor_coeff = all_omegas[node_ix * per_node_size].clone();
258270

259271
let coeff_powers: Vec<_> = (0..proofs_number)
260-
.map(|i| coset_omega.pow([i as u64]))
272+
.map(|i| coset_divisor_coeff.pow([i as u64]))
261273
.collect();
262274

263275
for j in indices.iter() {
@@ -271,8 +283,8 @@ where
271283
// p(X) - \prod L_i(X) e_i
272284
let numerator_eval: Evaluations<ScalarField, R2D<ScalarField>> = {
273285
let mut res = combined_data_d2.clone();
274-
for i in indices {
275-
res.evals[i] = ScalarField::zero();
286+
for i in indices.iter() {
287+
res.evals[*i] = ScalarField::zero();
276288
}
277289
res
278290
};
@@ -288,6 +300,13 @@ where
288300
per_node_size,
289301
&coeff_powers,
290302
);
303+
304+
let divisor_poly = {
305+
let mut coeffs = vec![ScalarField::zero(); per_node_size + 1];
306+
coeffs[0] = -coset_divisor_coeff.clone();
307+
coeffs[per_node_size] = ScalarField::one();
308+
DensePolynomial { coeffs }
309+
};
291310
// let (quotient, res) = DenseOrSparsePolynomial::divide_with_q_and_r(
292311
// &From::from(numerator_eval_interpolated),
293312
// &From::from(all_divisors[i].clone()),
@@ -303,6 +322,8 @@ where
303322
// fail_final_q_division();
304323
// }
305324

325+
assert!(&quotient * &divisor_poly == numerator_eval_interpolated);
326+
306327
quotient
307328
};
308329

@@ -322,7 +343,8 @@ where
322343
// num_chunks = 1 because our constraint is degree 2, which makes the quotient polynomial of degree d1
323344
let quotient_comm: Vec<Curve> = srs.commit_non_hiding(&quotient_poly, 2).chunks;
324345

325-
fq_sponge = fq_sponge_common.clone(); // reset the sponge
346+
let mut fq_sponge = fq_sponge_common.clone(); // reset the sponge
347+
println!("Prover, Quotient comm: {:?}", quotient_comm);
326348
fq_sponge.absorb_g(&quotient_comm);
327349

328350
// aka zeta
@@ -341,13 +363,40 @@ where
341363
println!("evals");
342364

343365
let combined_data_eval = combined_data_poly.evaluate(&evaluation_point);
344-
//let quotient_eval = quotient_poly.evaluate(&evaluation_point);
366+
let quotient_eval = quotient_poly.evaluate(&evaluation_point);
345367
let quotient_eval_1 = quotient_poly_1.evaluate(&evaluation_point);
346368
let quotient_eval_2 = quotient_poly_2.evaluate(&evaluation_point);
347-
//println!("Prover, quotient eval: {:?}", quotient_eval);
369+
println!("Prover, quotient eval: {:?}", quotient_eval);
348370
println!("Prover, quotient eval 1: {:?}", quotient_eval_1);
349371
println!("Prover, quotient eval 2: {:?}", quotient_eval_2);
350372

373+
assert!(
374+
quotient_eval
375+
== quotient_eval_1 + evaluation_point.pow([srs.size() as u64]) * quotient_eval_2
376+
);
377+
378+
// Sanity check for verification
379+
if node_ix == 1 {
380+
let combined_data_at_ixs: Vec<ScalarField> = indices
381+
.iter()
382+
.map(|&i| combined_data_d2[i].clone())
383+
.collect();
384+
385+
let quotient_eval_alt = {
386+
let divisor_poly_at_zeta: ScalarField =
387+
evaluation_point.pow([per_node_size as u64]) - coset_divisor_coeff;
388+
389+
let mut eval = -combined_data_eval;
390+
for (lagrange, data_eval) in bases_d2.iter().zip(combined_data_at_ixs.iter()) {
391+
eval += lagrange.evaluate(&evaluation_point) * data_eval;
392+
}
393+
eval = ScalarField::zero() - eval;
394+
eval = eval * divisor_poly_at_zeta.inverse().unwrap();
395+
eval
396+
};
397+
assert!(quotient_eval_alt == quotient_eval);
398+
}
399+
351400
for eval in [combined_data_eval, quotient_eval_1, quotient_eval_2].into_iter() {
352401
fr_sponge.absorb(&eval);
353402
}
@@ -395,7 +444,12 @@ where
395444
quotient_evals: vec![quotient_eval_1, quotient_eval_2],
396445
combined_data_eval,
397446
opening_proof,
398-
})
447+
});
448+
449+
let duration = start.elapsed();
450+
451+
let millis = duration.as_millis();
452+
println!("Prover time elapsed: {} ms", millis);
399453
}
400454

401455
proofs
@@ -409,22 +463,27 @@ pub fn verify_vid_ipa<RNG>(
409463
rng: &mut RNG,
410464
per_node_size: usize,
411465
node_ix: &usize,
466+
verifier_indices: &[usize],
412467
proof: &VIDProof,
413468
data_comms: &[Curve],
414469
data: &[Vec<ScalarField>],
415470
) -> bool
416471
where
417472
RNG: RngCore + CryptoRng,
418473
{
474+
let start = Instant::now();
475+
419476
let mut fq_sponge = CurveFqSponge::new(Curve::other_curve_sponge_params());
420477
fq_sponge.absorb_g(&data_comms);
421478

422479
let recombination_point = fq_sponge.challenge();
423480
println!("Verifier, recombination point: {:?}", recombination_point);
424481

482+
println!("combining data commitments");
425483
let combined_data_commitment =
426484
crate::utils::aggregate_commitments(recombination_point, data_comms);
427485

486+
println!("Verifier, Quotient comm: {:?}", proof.quotient_comm);
428487
fq_sponge.absorb_g(&proof.quotient_comm);
429488

430489
let evaluation_point = fq_sponge.challenge();
@@ -446,14 +505,62 @@ where
446505
initial
447506
};
448507

508+
println!("Computing alt lagrange");
509+
//let all_omegas: Vec<ScalarField> = (0..domain.d2.size())
510+
// .into_par_iter()
511+
// .map(|i| domain.d2.group_gen.pow([i as u64]))
512+
// .collect();
513+
//let denominators: Vec<ScalarField> = {
514+
// let mut res: Vec<_> = verifier_indices
515+
// .iter()
516+
// .map(|i| {
517+
// let mut acc = ScalarField::zero();
518+
// for j in 0..domain.d2.size() {
519+
// if j != *i {
520+
// acc *= all_omegas[*i] - all_omegas[j]
521+
// }
522+
// }
523+
// acc
524+
// })
525+
// .collect();
526+
// ark_ff::batch_inversion(&mut res);
527+
// res
528+
//};
529+
//let nominator_total: ScalarField = all_omegas
530+
// .clone()
531+
// .into_par_iter()
532+
// .map(|omega_i| evaluation_point - omega_i)
533+
// .reduce_with(|mut l, r| {
534+
// l *= r;
535+
// l
536+
// })
537+
// .unwrap();
538+
539+
//let nominator_diffs: Vec<ScalarField> = {
540+
// let mut res: Vec<_> = verifier_indices
541+
// .iter()
542+
// .map(|i| evaluation_point - all_omegas[*i])
543+
// .collect();
544+
// ark_ff::batch_inversion(&mut res);
545+
// res
546+
//};
547+
548+
//for (i, lagrange) in bases_d2.iter().enumerate() {
549+
// assert!(
550+
// lagrange.evaluate(&evaluation_point)
551+
// == nominator_total * nominator_diffs[i] * denominators[i]
552+
// );
553+
//}
554+
555+
let coset_divisor_coeff = domain.d2.group_gen.pow([(node_ix * per_node_size) as u64]);
556+
557+
println!("quotient eval");
449558
let quotient_eval = {
450-
let divisor_poly_at_zeta: ScalarField = {
451-
evaluation_point.pow([per_node_size as u64])
452-
- domain.d2.group_gen.pow([*node_ix as u64])
453-
};
559+
let divisor_poly_at_zeta: ScalarField =
560+
evaluation_point.pow([per_node_size as u64]) - coset_divisor_coeff;
454561

455562
let mut eval = -proof.combined_data_eval;
456-
for (lagrange, data_eval) in bases_d2.iter().zip(combined_data.iter()) {
563+
for (i, (lagrange, data_eval)) in bases_d2.iter().zip(combined_data.iter()).enumerate() {
457564
eval += lagrange.evaluate(&evaluation_point) * data_eval;
458565
}
459566
eval = ScalarField::zero() - eval;
@@ -516,7 +623,8 @@ where
516623
combined_inner_product(&polyscale, &evalscale, evaluations.as_slice())
517624
};
518625

519-
srs.verify(
626+
println!("verifying IPA");
627+
let res = srs.verify(
520628
group_map,
521629
&mut [BatchEvaluationProof {
522630
sponge: fq_sponge,
@@ -528,7 +636,14 @@ where
528636
combined_inner_product,
529637
}],
530638
rng,
531-
)
639+
);
640+
641+
let duration = start.elapsed();
642+
643+
let millis = duration.as_millis();
644+
println!("Verifier time elapsed: {} ms", millis);
645+
646+
res
532647
}
533648

534649
#[cfg(test)]
@@ -566,14 +681,14 @@ mod tests {
566681
EvaluationDomains::<ScalarField>::create(srs.size()).unwrap();
567682
let group_map = <Vesta as CommitmentCurve>::Map::setup();
568683

569-
let number_of_coms = 32;
570-
let per_node_size = 2048;
684+
let number_of_coms = 8;
685+
let per_node_size = 4;
571686
let proofs_number = domain.d2.size() / per_node_size;
572687

573688
//let verifier_indices: Vec<usize> = generate_unique_u64(512, 1 << 17);
574689
//let verifier_indices: Vec<usize> = (0..(domain.d2.size() / per_node_size)).collect();
575690
//let verifier_indices: Vec<usize> = (0..per_node_size).map(|j| j * (i + 1)).collect();
576-
let verifier_ix = 0; // we're testing verifier 0
691+
let verifier_ix = 1; // we're testing verifier 0
577692
let verifier_indices: Vec<usize> = (0..per_node_size)
578693
.map(|j| j * proofs_number + verifier_ix)
579694
.collect();
@@ -612,6 +727,7 @@ mod tests {
612727
let proofs = prove_vid_ipa(
613728
&srs,
614729
domain,
730+
&bases_d2,
615731
&group_map,
616732
&mut rng,
617733
per_node_size,
@@ -632,19 +748,22 @@ mod tests {
632748
})
633749
.collect();
634750

635-
let res = verify_vid_ipa(
636-
&srs,
637-
domain,
638-
&bases_d2,
639-
&group_map,
640-
&mut rng,
641-
per_node_size,
642-
&verifier_ix,
643-
&proofs[0],
644-
&data_comms,
645-
&expanded_data_at_ixs,
646-
);
647-
assert!(res, "proof must verify")
751+
for i in 0..4 {
752+
let res = verify_vid_ipa(
753+
&srs,
754+
domain,
755+
&bases_d2,
756+
&group_map,
757+
&mut rng,
758+
per_node_size,
759+
&verifier_ix,
760+
&verifier_indices,
761+
&proofs[verifier_ix],
762+
&data_comms,
763+
&expanded_data_at_ixs,
764+
);
765+
assert!(res, "proof must verify")
766+
}
648767
}
649768

650769
#[test]

0 commit comments

Comments
 (0)