Skip to content

Optimize keccak circuit #938

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions ceno_zkvm/src/scheme.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use ff_ext::ExtensionField;
use gkr_iop::gkr::{Evaluation, GKRCircuit, GKRProverOutput};
use gkr_iop::gkr::GKRProof;
use itertools::Itertools;
use mpcs::PolynomialCommitmentScheme;
use p3::field::PrimeCharacteristicRing;
Expand Down Expand Up @@ -57,11 +57,7 @@ pub struct ZKVMOpcodeProof<E: ExtensionField> {
deserialize = "E::BaseField: DeserializeOwned"
))]
// WARN/TODO: depends on serde's `rc` feature (required for `Arc` support), which might not behave correctly
pub struct GKROpcodeProof<E: ExtensionField> {
output_evals: Vec<E>,
prover_output: GKRProverOutput<E, Evaluation<E>>,
circuit: GKRCircuit,
}
pub struct GKROpcodeProof<E: ExtensionField>(pub GKRProof<E>);

#[derive(Clone, Serialize, Deserialize)]
#[serde(bound(
Expand Down
115 changes: 66 additions & 49 deletions ceno_zkvm/src/scheme/prover.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
use std::{
collections::{BTreeMap, BTreeSet, HashMap},
sync::Arc,
};

use ceno_emul::KeccakSpec;
use ff_ext::ExtensionField;
use gkr_iop::{evaluation::PointAndEval, gkr::GKRCircuitWitness};
use gkr_iop::{
evaluation::PointAndEval,
gkr::{GKRCircuitWitness, GKRProverOutput},
};
use itertools::{Itertools, enumerate, izip};
use mpcs::{Point, PolynomialCommitmentScheme};
use multilinear_extensions::{
Expand All @@ -11,10 +19,6 @@ use multilinear_extensions::{
};
use p3::field::{PrimeCharacteristicRing, dot_product};
use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};
use std::{
collections::{BTreeMap, BTreeSet, HashMap},
sync::Arc,
};
use sumcheck::{
macros::{entered_span, exit_span},
structs::{IOPProverMessage, IOPProverState},
Expand Down Expand Up @@ -748,54 +752,38 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
.collect();
exit_span!(span);

let pcs_open_span = entered_span!("pcs_open", profiling_3 = true);
let opening_dur = std::time::Instant::now();
tracing::debug!(
"[opcode {}]: build opening proof for {} polys",
name,
witnesses.len()
);
tracing::info!(
"[opcode {}] build opening proof took {:?}",
name,
opening_dur.elapsed(),
);
exit_span!(pcs_open_span);

let gkr_span = entered_span!("gkr", profiling_3 = true);
let input_open_point = Arc::new(input_open_point);
let gkr_opcode_proof = if let Some((gkr_iop_pk, gkr_wit)) = gkr_iop_pk {
let gkr_circuit = gkr_iop_pk.vk.get_state().chip.gkr_circuit();

let out_evals = gkr_wit
.layers
.last()
.unwrap()
.bases
.iter()
.map(|base| PointAndEval {
point: input_open_point.clone(),
eval: subprotocols::utils::evaluate_mle_ext(base, &input_open_point),
})
.collect_vec();

let prover_output = gkr_circuit
.prove(gkr_wit, &out_evals, &[], transcript)
.expect("Failed to prove phase");
// unimplemented!("cannot fully handle GKRIOP component yet")
let (gkr_opcode_proof, mut opening_evaluations) =
if let Some((gkr_iop_pk, gkr_wit)) = gkr_iop_pk {
let out_evals = gkr_wit
.layers
.last()
.unwrap()
.bases
.iter()
.map(|base| PointAndEval {
point: input_open_point.clone(),
eval: subprotocols::utils::evaluate_mle_ext(base, &input_open_point),
})
.collect_vec();

let _gkr_open_point = prover_output.opening_evaluations[0].point.clone();
// TODO: open polynomials for GKR proof
let gkr_circuit = gkr_iop_pk.vk.get_state().chip.gkr_circuit();
let prover_output = gkr_circuit
.prove(gkr_wit, &out_evals, &[], transcript)
.expect("Failed to prove phase");
// unimplemented!("cannot fully handle GKRIOP component yet")

let output_evals = out_evals.into_iter().map(|pae| pae.eval).collect_vec();
let GKRProverOutput {
gkr_proof: proof,
opening_evaluations,
} = prover_output;

Some(GKROpcodeProof {
output_evals,
prover_output,
circuit: gkr_circuit,
})
} else {
None
};
(Some(GKROpcodeProof(proof)), opening_evaluations)
} else {
(None, vec![])
};
exit_span!(gkr_span);

// extend with Option(gkr evals (not combined))
Ok((
Expand Down Expand Up @@ -1402,3 +1390,32 @@ impl TowerProver {
(next_rt, proofs)
}
}

/// Re-indexes GKR opening evaluations by their distinct opening points.
///
/// Returns the list of distinct points, in order of first appearance, together
/// with one `mpcs::Evaluation` per input evaluation (in input order) whose
/// `point` field is the position of that evaluation's point in the returned
/// list. The `poly` and `value` fields are carried over unchanged.
fn process_evaluations<E: ExtensionField>(
    evaluations: Vec<gkr_iop::gkr::Evaluation<E>>,
) -> (Vec<Vec<E>>, Vec<mpcs::Evaluation<E>>) {
    let mut distinct_points: Vec<Vec<E>> = Vec::new();
    let mut position_of: HashMap<Vec<E>, usize> = HashMap::new();
    let mut reindexed: Vec<mpcs::Evaluation<E>> = Vec::new();

    for evaluation in evaluations {
        // View the evaluation's point as a plain `&Vec<E>` so it can be used as a map key.
        let key: &Vec<E> = &evaluation.point;

        let position = match position_of.get(key).copied() {
            // Point already seen: reuse the index it was assigned.
            Some(known) => known,
            // First occurrence: it takes the next free slot in the point list.
            None => {
                let fresh = distinct_points.len();
                position_of.insert(key.clone(), fresh);
                distinct_points.push(key.clone());
                fresh
            }
        };

        reindexed.push(mpcs::Evaluation {
            poly: evaluation.poly,
            point: position,
            value: evaluation.value,
        });
    }

    (distinct_points, reindexed)
}
1 change: 1 addition & 0 deletions ceno_zkvm/src/scheme/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ fn test_rw_lk_expression_combination() {
.verify_opcode_proof(
name.as_str(),
verifier.vk.circuit_vks.get(&name).unwrap(),
None,
&proof,
num_instances,
&[],
Expand Down
61 changes: 33 additions & 28 deletions ceno_zkvm/src/scheme/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@ use std::{marker::PhantomData, sync::Arc};
use ceno_emul::{KeccakSpec, SyscallSpec};
use ff_ext::ExtensionField;

#[cfg(debug_assertions)]
use ff_ext::{Instrumented, PoseidonField};

use gkr_iop::precompiles::KECCAK_OUT_EVAL_SIZE;
use itertools::{Itertools, interleave, izip};
use mpcs::{Point, PolynomialCommitmentScheme};
use multilinear_extensions::{
Expand All @@ -22,15 +20,20 @@ use witness::next_pow2_instance_padding;
use crate::{
error::ZKVMError,
expression::{Instance, StructuralWitIn},
instructions::{GKRIOPInstruction, riscv::dummy::LargeEcallDummy},
instructions::{GKRIOPInstruction, Instruction, riscv::dummy::LargeEcallDummy},
scheme::{
constants::{NUM_FANIN, NUM_FANIN_LOGUP, SEL_DEGREE},
utils::eval_by_expr_with_instance,
},
structs::{PointAndEval, TowerProofs, VerifyingKey, ZKVMVerifyingKey},
structs::{
GKRIOPVerifyingKey, KeccakGKRIOP, PointAndEval, TowerProofs, VerifyingKey, ZKVMVerifyingKey,
},
utils::{eq_eval_less_or_equal_than, eval_wellform_address_vec, get_challenge_pows},
};

#[cfg(debug_assertions)]
use ff_ext::{Instrumented, PoseidonField};

use super::{
ZKVMOpcodeProof, ZKVMProof, ZKVMTableProof, constants::MAINCONSTRAIN_SUMCHECK_BATCH_SIZE,
};
Expand Down Expand Up @@ -190,11 +193,20 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
for (index, num_instances) in &vm_proof.num_instances {
let circuit_vk = circuit_vks[*index];
let name = circuit_names[*index];

// Only Keccak has a non-empty GKR-IOP component
let gkr_iop_vk = if *name == LargeEcallDummy::<E, KeccakSpec>::name() {
Some(&self.vk.keccak_vk)
} else {
None
};

if let Some(opcode_proof) = vm_proof.opcode_proofs.get(index) {
transcript.append_field_element(&E::BaseField::from_u64(*index as u64));
rt_points.push(self.verify_opcode_proof(
name,
circuit_vk,
gkr_iop_vk,
opcode_proof,
*num_instances,
pi_evals,
Expand All @@ -214,6 +226,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
dummy_table_item_multiplicity += num_padded_lks_per_instance * num_instances
+ num_lks.next_power_of_two() * num_padded_instance;

tracing::info!("verified proof for opcode {}", name);
prod_r *= opcode_proof
.record_r_out_evals
.iter()
Expand Down Expand Up @@ -333,6 +346,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
&self,
name: &str,
circuit_vk: &VerifyingKey<E>,
gkr_iop_vk: Option<&GKRIOPVerifyingKey<E, PCS, KeccakGKRIOP<E>>>,
proof: &ZKVMOpcodeProof<E>,
num_instances: usize,
pi: &[E],
Expand Down Expand Up @@ -484,34 +498,25 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
.clone()
.expect("Keccak syscall should contain GKR-IOP proof");

// Match output_evals with EcallDummy polynomials
for (i, gkr_out_eval) in gkr_iop.output_evals.iter().enumerate() {
assert_eq!(
*gkr_out_eval,
proof.wits_in_evals[LargeEcallDummy::<E, KeccakSpec>::output_evals_map(i)],
"{i}"
);
}
// Verify GKR proof
let point = Arc::new(input_opening_point.clone());
let out_evals = gkr_iop
.output_evals
.iter()
.map(|eval| gkr_iop::evaluation::PointAndEval {
point: point.clone(),
eval: *eval,
let out_evals = (0..KECCAK_OUT_EVAL_SIZE)
.map(|i| {
let eval =
proof.wits_in_evals[LargeEcallDummy::<E, KeccakSpec>::output_evals_map(i)];
gkr_iop::evaluation::PointAndEval {
point: point.clone(),
eval,
}
})
.collect_vec();

gkr_iop
.circuit
.verify(
gkr_iop.prover_output.gkr_proof.clone(),
&out_evals,
&[],
transcript,
)
.expect("GKR-IOP verify failure");
if let Some(gkr_iop_vk) = gkr_iop_vk {
let gkr_circuit = gkr_iop_vk.get_state().chip.gkr_circuit();
gkr_circuit
.verify(gkr_iop.0.clone(), &out_evals, &[], transcript)
.expect("GKR-IOP verify failure");
}
}

// derive r_records, w_records, lk_records from witness's evaluations
Expand Down
13 changes: 9 additions & 4 deletions ceno_zkvm/src/structs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>, State: Default> Defa
}
}

#[derive(Clone, Debug)]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct GKRIOPVerifyingKey<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>, State> {
pub(crate) state: State,
pub fixed_commit: Option<PCS::Commitment>,
Expand All @@ -172,7 +172,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>, State>
}
}

#[derive(Clone, Default, Debug)]
#[derive(Clone, Default, Debug, Serialize, Deserialize)]
pub struct KeccakGKRIOP<E> {
pub chip: gkr_iop::chip::Chip,
pub layout: KeccakLayout<E>,
Expand Down Expand Up @@ -556,6 +556,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProvingKey<E, PC
.iter()
.map(|(name, pk)| (name.clone(), pk.vk.clone()))
.collect(),
keccak_vk: self.keccak_pk.vk.clone(),
fixed_commit: self.fixed_commit.clone(),
// expression for global state in/out
initial_global_state_expr: self.initial_global_state_expr.clone(),
Expand All @@ -569,12 +570,16 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProvingKey<E, PC
}
}

#[derive(Clone, serde::Serialize, serde::Deserialize)]
#[serde(bound = "E: ExtensionField + DeserializeOwned")]
#[derive(Clone, Serialize, Deserialize)]
#[serde(bound(
serialize = "E::BaseField: Serialize, GKRIOPVerifyingKey<E, PCS, KeccakGKRIOP<E>>: Serialize",
deserialize = "E::BaseField: DeserializeOwned, GKRIOPVerifyingKey<E, PCS, KeccakGKRIOP<E>>: DeserializeOwned",
))]
pub struct ZKVMVerifyingKey<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> {
pub vp: PCS::VerifierParam,
// vk for opcode and table circuits
pub circuit_vks: BTreeMap<String, VerifyingKey<E>>,
pub keccak_vk: GKRIOPVerifyingKey<E, PCS, KeccakGKRIOP<E>>,
pub fixed_commit: Option<<PCS as PolynomialCommitmentScheme<E>>::Commitment>,
// expression for global state in/out
pub initial_global_state_expr: Expression<E>,
Expand Down
4 changes: 3 additions & 1 deletion gkr_iop/src/chip.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use serde::{Deserialize, Serialize};

use crate::{evaluation::EvalExpression, gkr::layer::Layer};

pub mod builder;
pub mod protocol;

/// Chip stores all information required in the GKR protocol, including the
/// commit phases, the GKR phase and the opening phase.
#[derive(Clone, Debug, Default)]
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct Chip {
/// The number of base inputs committed in the whole protocol.
pub n_committed_bases: usize,
Expand Down
13 changes: 7 additions & 6 deletions gkr_iop/src/precompiles/lookup_keccakf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use crate::{
precompiles::utils::{MaskRepresentation, nest, not8_expr, zero_expr},
};
use ndarray::{ArrayView, Ix2, Ix3, s};
use serde::{Deserialize, Serialize};

use super::utils::{CenoLookup, u64s_to_felts, zero_eval};
use p3_field::{PrimeCharacteristicRing, extension::BinomialExtensionField};
Expand All @@ -24,10 +25,10 @@ use transcript::BasicTranscript;

type E = BinomialExtensionField<Goldilocks, 2>;

#[derive(Clone, Debug, Default)]
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct KeccakParams {}

#[derive(Clone, Debug, Default)]
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct KeccakLayout<E> {
_params: KeccakParams,
_input_columns: Vec<usize>,
Expand Down Expand Up @@ -310,6 +311,9 @@ pub const AND_LOOKUPS: usize = ROUNDS * AND_LOOKUPS_PER_ROUND;
pub const XOR_LOOKUPS: usize = ROUNDS * XOR_LOOKUPS_PER_ROUND;
pub const RANGE_LOOKUPS: usize = ROUNDS * RANGE_LOOKUPS_PER_ROUND;

/// Number of output evaluations of the Keccak GKR circuit: the input size,
/// plus the output size, plus the lookup field elements of every round.
pub const KECCAK_OUT_EVAL_SIZE: usize =
KECCAK_INPUT_SIZE + KECCAK_OUTPUT_SIZE + LOOKUP_FELTS_PER_ROUND * ROUNDS; // 200 (5 * 5 * 8, keccak input bytes) + 200 (keccak output bytes) + 24 * 1656 (round auxiliary witnesses)

macro_rules! allocate_and_split {
($chip:expr, $total:expr, $( $size:expr ),* ) => {{
let (witnesses, _) = $chip.allocate_wits_in_layer::<$total, 0>();
Expand Down Expand Up @@ -1135,10 +1139,7 @@ pub fn run_faster_keccakf(states: Vec<[u64; 25]>, verify: bool, test_outputs: bo
})
.collect_vec();

assert_eq!(
out_evals.len(),
KECCAK_INPUT_SIZE + KECCAK_OUTPUT_SIZE + LOOKUP_FELTS_PER_ROUND * ROUNDS
);
assert_eq!(out_evals.len(), KECCAK_OUT_EVAL_SIZE);

out_evals
};
Expand Down
Loading
Loading