diff --git a/.gitignore b/.gitignore index d794a5dc57..aceaa65679 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,9 @@ Cargo.lock **/.env .DS_Store +# Log outputs +*.log + .cache/ rustc-* diff --git a/crates/sdk/src/prover/agg.rs b/crates/sdk/src/prover/agg.rs index aa8fc843cb..5c6faac19a 100644 --- a/crates/sdk/src/prover/agg.rs +++ b/crates/sdk/src/prover/agg.rs @@ -23,11 +23,11 @@ use crate::{ }; pub struct AggStarkProver> { - leaf_prover: VmLocalProver, - leaf_controller: LeafProvingController, + pub leaf_prover: VmLocalProver, + pub leaf_controller: LeafProvingController, - internal_prover: VmLocalProver, - root_prover: RootVerifierLocalProver, + pub internal_prover: VmLocalProver, + pub root_prover: RootVerifierLocalProver, pub num_children_internal: usize, pub max_internal_wrapper_layers: usize, diff --git a/extensions/native/circuit/src/extension.rs b/extensions/native/circuit/src/extension.rs index 1ee1af6885..eecf6370e0 100644 --- a/extensions/native/circuit/src/extension.rs +++ b/extensions/native/circuit/src/extension.rs @@ -15,9 +15,7 @@ use openvm_circuit_derive::{AnyEnum, InstructionExecutor, VmConfig}; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_instructions::{program::DEFAULT_PC_STEP, LocalOpcode, PhantomDiscriminant}; use openvm_native_compiler::{ - CastfOpcode, FieldArithmeticOpcode, FieldExtensionOpcode, FriOpcode, NativeBranchEqualOpcode, - NativeJalOpcode, NativeLoadStore4Opcode, NativeLoadStoreOpcode, NativePhantom, - NativeRangeCheckOpcode, Poseidon2Opcode, VerifyBatchOpcode, BLOCK_LOAD_STORE_SIZE, + CastfOpcode, FieldArithmeticOpcode, FieldExtensionOpcode, FriOpcode, NativeBranchEqualOpcode, NativeJalOpcode, NativeLoadStore4Opcode, NativeLoadStoreOpcode, NativePhantom, NativeRangeCheckOpcode, Poseidon2Opcode, SumcheckOpcode, VerifyBatchOpcode, BLOCK_LOAD_STORE_SIZE }; use openvm_poseidon2_air::Poseidon2Config; use openvm_rv32im_circuit::{ @@ -29,10 +27,7 @@ use serde::{Deserialize, Serialize}; use strum::IntoEnumIterator; use crate::{ - adapters::{convert_adapter::ConvertAdapterChip, *}, - poseidon2::chip::NativePoseidon2Chip, - phantom::*, - *, + adapters::{convert_adapter::ConvertAdapterChip, *}, phantom::*, poseidon2::chip::NativePoseidon2Chip, sumcheck::chip::NativeSumcheckChip, * }; #[derive(Clone, Debug, Serialize, Deserialize, VmConfig, derive_new::new)] @@ -76,6 +71,7 @@ pub enum NativeExecutor { FieldExtension(FieldExtensionChip), FriReducedOpening(FriReducedOpeningChip), VerifyBatch(NativePoseidon2Chip), + SumcheckLayerEval(NativeSumcheckChip), } #[derive(From, ChipUsageGetter, Chip, AnyEnum)] @@ -207,6 +203,17 @@ impl VmExtension for Native { ], )?; + let sumcheck_chip = NativeSumcheckChip::new( + builder.system_port(), + offline_memory.clone(), + ); + inventory.add_executor( + sumcheck_chip, + [ + SumcheckOpcode::SUMCHECK_LAYER_EVAL.global_opcode(), + ] + )?; + builder.add_phantom_sub_executor( NativeHintInputSubEx, PhantomDiscriminant(NativePhantom::HintInput as u16), diff --git a/extensions/native/circuit/src/field_extension/core.rs b/extensions/native/circuit/src/field_extension/core.rs index d8c83fabdd..ea06ecb4c9 100644 --- a/extensions/native/circuit/src/field_extension/core.rs +++ b/extensions/native/circuit/src/field_extension/core.rs @@ -245,10 +245,10 @@ impl FieldExtension { pub(crate) fn add(x: [V; EXT_DEG], y: [V; EXT_DEG]) -> [E; EXT_DEG] where - V: Copy, + V: Clone, V: Add, { - array::from_fn(|i| x[i] + y[i]) + array::from_fn(|i| x[i].clone() + y[i].clone()) } pub(crate) fn subtract(x: [V; EXT_DEG], y: [V; EXT_DEG]) -> [E; EXT_DEG] diff --git a/extensions/native/circuit/src/fri/mod.rs b/extensions/native/circuit/src/fri/mod.rs index 7dbc3fd851..eabf22ef5c 100644 --- a/extensions/native/circuit/src/fri/mod.rs +++ b/extensions/native/circuit/src/fri/mod.rs @@ -538,7 +538,7 @@ fn assert_array_eq, I2: Into, const } } -fn elem_to_ext(elem: F) -> [F; EXT_DEG] { +pub fn elem_to_ext(elem: F) -> [F; EXT_DEG] { let mut ret = [F::ZERO; EXT_DEG]; ret[0] = elem; ret diff --git a/extensions/native/circuit/src/lib.rs b/extensions/native/circuit/src/lib.rs index 46c6bc890f..89261f73d5 100644 --- a/extensions/native/circuit/src/lib.rs +++ b/extensions/native/circuit/src/lib.rs @@ -8,6 +8,7 @@ mod fri; mod jal; mod loadstore; mod poseidon2; +mod sumcheck; pub use branch_eq::*; pub use castf::*; @@ -17,6 +18,7 @@ pub use fri::*; pub use jal::*; pub use loadstore::*; pub use poseidon2::*; +pub use sumcheck::*; mod extension; pub use extension::*; diff --git a/extensions/native/circuit/src/poseidon2/trace.rs b/extensions/native/circuit/src/poseidon2/trace.rs index 27ffe858a9..b212df67a6 100644 --- a/extensions/native/circuit/src/poseidon2/trace.rs +++ b/extensions/native/circuit/src/poseidon2/trace.rs @@ -15,9 +15,9 @@ use openvm_stark_backend::{ }; use crate::{ - chip::TranscriptObservationRecord, poseidon2::{ + poseidon2::{ chip::{ - CellRecord, IncorporateRowRecord, IncorporateSiblingRecord, InsideRowRecord, NativePoseidon2Chip, SimplePoseidonRecord, VerifyBatchRecord, NUM_INITIAL_READS + TranscriptObservationRecord, CellRecord, IncorporateRowRecord, IncorporateSiblingRecord, InsideRowRecord, NativePoseidon2Chip, SimplePoseidonRecord, VerifyBatchRecord, NUM_INITIAL_READS }, columns::{ InsideRowSpecificCols, MultiObserveCols, NativePoseidon2Cols, SimplePoseidonSpecificCols, TopLevelSpecificCols diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs new file mode 100644 index 0000000000..050a28c5e9 --- /dev/null +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -0,0 +1,423 @@ +use std::{array::from_fn, borrow::Borrow, sync::Arc}; +use openvm_circuit::{ + arch::{ContinuationVmProof, ExecutionBridge, ExecutionState}, + system::memory::{offline_checker::MemoryBridge, MemoryAddress}, +}; +use openvm_circuit_primitives::utils::{assert_array_eq, not}; +use openvm_instructions::LocalOpcode; +use openvm_native_compiler::SumcheckOpcode::SUMCHECK_LAYER_EVAL; +use openvm_stark_backend::{ + air_builders::sub::SubAirBuilder, + interaction::{BusIndex, InteractionBuilder, PermutationCheckBus}, + p3_air::{Air, AirBuilder, BaseAir}, + p3_field::{Field, FieldAlgebra}, + p3_matrix::Matrix, + rap::{BaseAirWithPublicValues, PartitionedBaseAir}, +}; +use crate::{sumcheck::columns::{HeaderSpecificCols, LogupSpecificCols, NativeSumcheckCols, ProdSpecificCols}, FieldExtension, EXT_DEG}; + +#[derive(Clone, Debug)] +pub struct NativeSumcheckAir { + pub execution_bridge: ExecutionBridge, + pub memory_bridge: MemoryBridge, + pub address_space: F, +} + +impl BaseAir for NativeSumcheckAir { + fn width(&self) -> usize { + NativeSumcheckCols::::width() + } +} + +impl BaseAirWithPublicValues + for NativeSumcheckAir +{ +} + +impl PartitionedBaseAir + for NativeSumcheckAir +{ +} + +impl Air + for NativeSumcheckAir +{ + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let local = main.row_slice(0); + let local: &NativeSumcheckCols = (*local).borrow(); + let next = main.row_slice(1); + let next: &NativeSumcheckCols = (*next).borrow(); + + let &NativeSumcheckCols { + // Row indicators + header_row, + prod_row, + logup_row, + + // Whether valid prod/logup row operations follow this row + header_continuation, + prod_continuation, + logup_continuation, + + // Round limit + prod_row_within_max_round, + logup_row_within_max_round, + + // What type of evaluation is performed + prod_in_round_evaluation, + prod_next_round_evaluation, + logup_in_round_evaluation, + logup_next_round_evaluation, + + // Indicates whether the round evaluations should be added to the accumulator + prod_acc, + logup_acc, + + // Timestamps + first_timestamp, + start_timestamp, + last_timestamp, + + // Results from reading registers + register_ptrs, + ctx, + prod_nested_len, + logup_nested_len, + + // Challenges + alpha, + challenges, + + curr_prod_n, + curr_logup_n, + + max_round, + within_round_limit, + should_acc, + eval_acc, + specific, + } = local; + + builder.assert_bool(header_row); + builder.assert_bool(prod_row); + builder.assert_bool(logup_row); + builder.assert_bool(header_continuation); + builder.assert_bool(prod_continuation); + builder.assert_bool(logup_continuation); + builder.assert_bool(prod_row_within_max_round); + builder.assert_bool(logup_row_within_max_round); + builder.assert_bool(prod_in_round_evaluation); + builder.assert_bool(logup_in_round_evaluation); + let enabled = header_row + prod_row + logup_row; + builder.assert_bool(enabled.clone()); + let in_round = ctx[7]; + let continuation = header_continuation + prod_continuation + logup_continuation; + builder.assert_bool(continuation.clone()); + + // Randomness transition + let alpha1: [_; EXT_DEG] = challenges[0..EXT_DEG].try_into().expect(""); + let c1: [_; EXT_DEG] = challenges[EXT_DEG..{EXT_DEG * 2}].try_into().expect(""); + let c2: [_; EXT_DEG] = challenges[{EXT_DEG * 2}..{EXT_DEG * 3}].try_into().expect(""); + let alpha2: [_; EXT_DEG] = challenges[{EXT_DEG * 3}..{EXT_DEG * 4}].try_into().expect(""); + let next_alpha1: [_; EXT_DEG] = next.challenges[0..EXT_DEG].try_into().expect(""); + + // Carry along columns + assert_array_eq(&mut builder.when(next.prod_row + next.logup_row), register_ptrs, next.register_ptrs); + assert_array_eq(&mut builder.when(next.prod_row + next.logup_row), ctx, next.ctx); + assert_array_eq::<_, _, _, {EXT_DEG * 2}>( + &mut builder.when(next.prod_row + next.logup_row), + challenges[EXT_DEG..(EXT_DEG * 3)].try_into().expect(""), + next.challenges[EXT_DEG..(EXT_DEG * 3)].try_into().expect("") + ); + builder.when(next.prod_row + next.logup_row).assert_eq(prod_nested_len, next.prod_nested_len); + builder.when(next.prod_row + next.logup_row).assert_eq(logup_nested_len, next.logup_nested_len); + + // Row transition + builder + .when(next.prod_row) + .assert_eq(curr_prod_n + AB::F::ONE, next.curr_prod_n); + builder + .when(next.logup_row) + .assert_eq(curr_logup_n + AB::F::ONE, next.curr_logup_n); + builder + .when(header_row) + .when(next.logup_row) + .assert_zero(ctx[1]); + builder + .when(prod_row) + .when(next.logup_row) + .assert_eq(ctx[1], curr_prod_n); + builder + .when(prod_row) + .when(not(prod_continuation)) + .assert_eq(ctx[1], curr_prod_n); + builder + .when(logup_row) + .when(not(logup_continuation)) + .assert_eq(ctx[2], curr_logup_n); + + // Timestamp transition + builder + .when(header_row) + .when(next.prod_row + next.logup_row) + .assert_eq(next.start_timestamp, start_timestamp + AB::F::from_canonical_usize(7)); + builder + .when(prod_row) + .when(next.prod_row + next.logup_row) + .assert_eq(next.start_timestamp, start_timestamp + AB::F::ONE + within_round_limit * AB::F::TWO); + builder + .when(logup_row) + .when(next.prod_row + next.logup_row) + .assert_eq(next.start_timestamp, start_timestamp + AB::F::ONE + within_round_limit * AB::F::from_canonical_usize(3)); + + // Termination condition + assert_array_eq(&mut builder.when::(not(continuation)), eval_acc, [AB::F::ZERO; 4]); + + // Randomness transition + assert_array_eq(&mut builder.when(header_continuation), next.challenges[0..EXT_DEG].try_into().expect(""), [AB::F::ONE, AB::F::ZERO, AB::F::ZERO, AB::F::ZERO]); + let alpha_denominator = FieldExtension::multiply(alpha1, alpha); + assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_row), alpha_denominator, alpha2); + let prod_next_alpha = FieldExtension::multiply(alpha1, alpha); + assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(prod_continuation), prod_next_alpha, next_alpha1); + let logup_next_alpha = FieldExtension::multiply(alpha2, alpha); + assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_continuation), logup_next_alpha, next_alpha1); + + // Header + let header_row_specific: &HeaderSpecificCols = + specific[..HeaderSpecificCols::::width()].borrow(); + let registers = header_row_specific.registers; + + self.execution_bridge + .execute_and_increment_pc( + AB::Expr::from_canonical_usize(SUMCHECK_LAYER_EVAL.global_opcode().as_usize()), + [ + registers[4].into(), + registers[0].into(), + registers[1].into(), + self.address_space.into(), + self.address_space.into(), + registers[2].into(), + registers[3].into(), + ], + ExecutionState::new(header_row_specific.pc, first_timestamp), + last_timestamp - first_timestamp, + ) + .eval(builder, header_row); + + // Read registers + for i in 0..5usize { + self.memory_bridge + .read( + MemoryAddress::new(self.address_space, registers[i]), + [register_ptrs[i]], + first_timestamp + AB::F::from_canonical_usize(i), + &header_row_specific.read_records[i], + ) + .eval(builder, header_row); + } + + // React ctx + self.memory_bridge + .read( + MemoryAddress::new(self.address_space, register_ptrs[0]), + ctx, + first_timestamp + AB::F::from_canonical_usize(5), + &header_row_specific.read_records[5], + ) + .eval(builder, header_row); + + // Read challenges + self.memory_bridge + .read( + MemoryAddress::new(self.address_space, register_ptrs[1]), + challenges, + first_timestamp + AB::F::from_canonical_usize(6), + &header_row_specific.read_records[6], + ) + .eval(builder, header_row); + + // Write final result + self.memory_bridge + .write( + MemoryAddress::new( + self.address_space, + register_ptrs[4], + ), + eval_acc, + last_timestamp - AB::F::ONE, + &header_row_specific.write_records, + ) + .eval(builder, header_row); + + // Prod spec evaluation + let prod_row_specific: &ProdSpecificCols = + specific[..ProdSpecificCols::::width()].borrow(); + let next_prod_row_specific: &ProdSpecificCols = + next.specific[..ProdSpecificCols::::width()].borrow(); + + self.memory_bridge + .read( + MemoryAddress::new(self.address_space, register_ptrs[0] + AB::F::from_canonical_usize(EXT_DEG * 2) + (curr_prod_n - AB::F::ONE)), // curr_prod_n starts at 1. + [max_round], + start_timestamp, + &prod_row_specific.read_records[0], + ) + .eval(builder, prod_row); + + builder + .when(prod_row_within_max_round) + .assert_eq(prod_row_specific.data_ptr, (prod_nested_len * (curr_prod_n - AB::F::ONE) + ctx[4] * ctx[0]) * AB::F::from_canonical_usize(EXT_DEG)); + builder + .assert_eq(prod_row * prod_row_within_max_round * in_round, prod_in_round_evaluation); + builder + .assert_eq(prod_row * prod_row_within_max_round * not(in_round), prod_next_round_evaluation); + builder + .assert_eq(prod_row * should_acc, prod_acc); + + self.memory_bridge + .read( + MemoryAddress::new( + self.address_space, + register_ptrs[2] + prod_row_specific.data_ptr, + ), + prod_row_specific.p, + start_timestamp + AB::F::ONE, + &prod_row_specific.read_records[1], + ) + .eval(builder, prod_row_within_max_round); + + let p1: [AB::Var; EXT_DEG] = prod_row_specific.p[0..EXT_DEG].try_into().expect(""); + let p2: [AB::Var; EXT_DEG] = prod_row_specific.p[EXT_DEG..(EXT_DEG * 2)].try_into().expect(""); + + self.memory_bridge + .write( + MemoryAddress::new( + self.address_space, + register_ptrs[4] + curr_prod_n * AB::F::from_canonical_usize(EXT_DEG), + ), + prod_row_specific.p_evals, + start_timestamp + AB::F::TWO, + &prod_row_specific.write_record, + ) + .eval(builder, prod_row_within_max_round); + + // Calculate evaluations + let next_round_p_evals = FieldExtension::add( + FieldExtension::multiply::(p1, c1), + FieldExtension::multiply::(p2, c2), + ); + let in_round_p_evals = FieldExtension::multiply::(p1, p2); + assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(prod_in_round_evaluation), in_round_p_evals, prod_row_specific.p_evals); + assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(prod_next_round_evaluation), next_round_p_evals, prod_row_specific.p_evals); + + // Accumulate evaluation + let acc_eval = FieldExtension::multiply::(prod_row_specific.p_evals, alpha1); + assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(prod_acc), prod_row_specific.acc_eval, acc_eval); + + let next_acc = FieldExtension::subtract( + eval_acc, + next_prod_row_specific.acc_eval, + ); + assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(next.prod_acc), next.eval_acc, next_acc); + + // Logup spec evaluation + let logup_row_specific: &LogupSpecificCols = + specific[..LogupSpecificCols::::width()].borrow(); + let next_logup_row_specfic: &LogupSpecificCols = + next.specific[..LogupSpecificCols::::width()].borrow(); + + self.memory_bridge + .read( + MemoryAddress::new(self.address_space, register_ptrs[0] + AB::F::from_canonical_usize(EXT_DEG * 2) + ctx[1] + (curr_logup_n - AB::F::ONE)), // curr_logup_n starts at 1. + [max_round], + start_timestamp, + &logup_row_specific.read_records[0], + ) + .eval(builder, logup_row); + + builder + .when(logup_row_within_max_round) + .assert_eq(logup_row_specific.data_ptr, (logup_nested_len * (curr_logup_n - AB::F::ONE) + ctx[6] * ctx[0]) * AB::F::from_canonical_usize(EXT_DEG)); + builder + .assert_eq(logup_row * logup_row_within_max_round * in_round, logup_in_round_evaluation); + builder + .assert_eq(logup_row * logup_row_within_max_round * not(in_round), logup_next_round_evaluation); + builder + .assert_eq(logup_row * should_acc, logup_acc); + + self.memory_bridge + .read( + MemoryAddress::new( + self.address_space, + register_ptrs[3] + logup_row_specific.data_ptr, + ), + logup_row_specific.pq, + start_timestamp + AB::F::ONE, + &logup_row_specific.read_records[1], + ) + .eval(builder, logup_row_within_max_round); + + let p1: [_; EXT_DEG] = logup_row_specific.pq[0..EXT_DEG].try_into().expect(""); + let p2: [_; EXT_DEG] = logup_row_specific.pq[EXT_DEG..(EXT_DEG * 2)].try_into().expect(""); + let q1: [_; EXT_DEG] = logup_row_specific.pq[(EXT_DEG * 2)..{EXT_DEG * 3}].try_into().expect(""); + let q2: [_; EXT_DEG] = logup_row_specific.pq[(EXT_DEG * 3)..(EXT_DEG * 4)].try_into().expect(""); + + self.memory_bridge + .write( + MemoryAddress::new( + self.address_space, + register_ptrs[4] + (ctx[1] + curr_logup_n) * AB::F::from_canonical_usize(EXT_DEG), + ), + logup_row_specific.p_evals, + start_timestamp + AB::F::TWO, + &logup_row_specific.write_records[0], + ) + .eval(builder, logup_row_within_max_round); + + self.memory_bridge + .write( + MemoryAddress::new( + self.address_space, + register_ptrs[4] + (ctx[1] + ctx[2] + curr_logup_n) * AB::F::from_canonical_usize(EXT_DEG), + ), + logup_row_specific.q_evals, + start_timestamp + AB::F::from_canonical_usize(3), + &logup_row_specific.write_records[1], + ) + .eval(builder, logup_row_within_max_round); + + // Calculate evaluations + let next_round_p_evals = FieldExtension::add( + FieldExtension::multiply::(p1, c1), + FieldExtension::multiply::(p2, c2), + ); + let in_round_p_evals = FieldExtension::add( + FieldExtension::multiply::(p1, q2), + FieldExtension::multiply::(p2, q1), + ); + assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_in_round_evaluation), in_round_p_evals, logup_row_specific.p_evals); + assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_next_round_evaluation), next_round_p_evals, logup_row_specific.p_evals); + + let next_round_q_evals = FieldExtension::add( + FieldExtension::multiply::(q1, c1), + FieldExtension::multiply::(q2, c2), + ); + let in_round_q_evals = FieldExtension::multiply::(q1, q2); + assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_in_round_evaluation), in_round_q_evals, logup_row_specific.q_evals); + assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_next_round_evaluation), next_round_q_evals, logup_row_specific.q_evals); + + // Accumulate evaluation + let acc_eval = FieldExtension::add( + FieldExtension::multiply::(logup_row_specific.p_evals, alpha1), + FieldExtension::multiply::(logup_row_specific.q_evals, alpha2), + ); + assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(logup_acc), logup_row_specific.acc_eval, acc_eval); + + let next_acc = FieldExtension::subtract( + eval_acc, + next_logup_row_specfic.acc_eval, + ); + assert_array_eq::<_, _, _, EXT_DEG>(&mut builder.when(next.logup_acc), next.eval_acc, next_acc); + } +} \ No newline at end of file diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs new file mode 100644 index 0000000000..f99a8e02a3 --- /dev/null +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -0,0 +1,401 @@ +use std::sync::{Arc, Mutex}; +use openvm_circuit::{ + arch::{ + ExecutionBridge, ExecutionError, ExecutionState, InstructionExecutor, Streams, SystemPort, + }, + system::memory::{MemoryController, OfflineMemory, RecordId}, +}; +use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; +use openvm_stark_backend::{ + p3_field::{Field, PrimeField, PrimeField32}, + p3_maybe_rayon::prelude::{ParallelIterator, ParallelSlice}, +}; +use crate::{fri::elem_to_ext, sumcheck::columns::{HeaderSpecificCols, LogupSpecificCols, NativeSumcheckCols, ProdSpecificCols}}; +use openvm_native_compiler::{ + conversion::AS, + SumcheckOpcode::SUMCHECK_LAYER_EVAL, +}; +use crate::sumcheck::air::NativeSumcheckAir; +use crate::{ + field_extension::{FieldExtension, EXT_DEG}, + utils::const_max, +}; +use serde::{Deserialize, Serialize}; +const CONTEXT_ARR_BASE_LEN: usize = EXT_DEG * 2; + +#[repr(C)] +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +#[serde(bound = "F: Field")] +pub struct SumcheckEvalRecord { + pub from_state: ExecutionState, + pub instruction: Instruction, + pub row_type: usize, // 0 - header; 1 - prod; 2 - logup + pub curr_timestamp_increment: usize, + pub final_timestamp_increment: usize, + pub continuation: bool, + + pub register_ptrs: [F; 5], + pub registers: [F; 5], + pub ctx: [F; EXT_DEG * 2], + pub challenges: [F; EXT_DEG * 4], + pub read_data_records: [RecordId; 7], + pub write_data_records: [RecordId; 2], + + pub max_round: F, + pub within_round_limit: bool, + pub should_acc: bool, + pub prod_spec_n: usize, + pub logup_spec_n: usize, + pub alpha: [F; EXT_DEG], + pub alpha1: [F; EXT_DEG], + pub alpha2: [F; EXT_DEG], + pub data_ptr: F, + pub p1: [F; EXT_DEG], + pub p2: [F; EXT_DEG], + pub q1: [F; EXT_DEG], + pub q2: [F; EXT_DEG], + pub p_evals: [F; EXT_DEG], + pub q_evals: [F; EXT_DEG], + pub eval_acc: [F; EXT_DEG], + pub acc_eval: [F; EXT_DEG], +} + +fn calculate_3d_ext_idx( + inner_inner_len: F, + inner_len: F, + outer_idx: F, + inner_idx: F, + inner_inner_idx: F, +) -> F { + (inner_inner_len * inner_len * outer_idx + inner_inner_len * inner_idx + inner_inner_idx) * F::from_canonical_usize(EXT_DEG) +} + +pub struct NativeSumcheckChip { + pub height: usize, + pub(super) air: NativeSumcheckAir, + pub(super) offline_memory: Arc>>, + pub record_set: Vec>, +} + +impl NativeSumcheckChip { + pub fn new( + port: SystemPort, + offline_memory: Arc>>, + ) -> Self { + let air = NativeSumcheckAir { + execution_bridge: ExecutionBridge::new(port.execution_bus, port.program_bus), + memory_bridge: port.memory_bridge, + address_space: F::from_canonical_u32(AS::Native as u32), + }; + + Self { + height: 0, + air, + offline_memory, + record_set: Default::default(), + } + } +} + +impl InstructionExecutor for NativeSumcheckChip { + fn execute( + &mut self, + memory: &mut MemoryController, + instruction: &Instruction, + from_state: ExecutionState, + ) -> Result, ExecutionError> { + let &Instruction { + opcode: op, + a: output_register, + b: input_register_1, + c: input_register_2, + d: data_address_space, + e: register_address_space, + f: input_register_3, + g: input_register_4, + } = instruction; + + if op == SUMCHECK_LAYER_EVAL.global_opcode() { + let mut observation_records: Vec> = vec![]; + let mut curr_timestamp: usize = 0; + + let (read_ctx_pointer, ctx_pointer) = + memory.read_cell(register_address_space, input_register_1); + let (read_cs_pointer, cs_pointer) = + memory.read_cell(register_address_space, input_register_2); + let (read_prod_pointer, prod_ptr) = + memory.read_cell(register_address_space, input_register_3); + let (read_logup_pointer, logup_ptr) = + memory.read_cell(register_address_space, input_register_4); + let (read_result_pointer, r_ptr) = + memory.read_cell(register_address_space, output_register); + let register_ptrs: [F; 5] = [ctx_pointer, cs_pointer, prod_ptr, logup_ptr, r_ptr]; + + let (ctx_read, ctx): (RecordId, [F; EXT_DEG * 2]) = memory.read::<{EXT_DEG * 2}>(data_address_space, ctx_pointer); + let [ + round, + num_prod_spec, + num_logup_spec, + prod_specs_inner_len, + prod_specs_inner_inner_len, + logup_specs_inner_len, + logup_specs_inner_inner_len, + is_op_for_cur_sumcheck_round, // This opcode supports two modes of operation: + // 1. calculate the expected evaluation of two types of sumchecks for the current round + // a. product sumcheck: v' = v[0] * v[1] + // b. logup sumcheck: p'= p[0] * q[1] + p[1] * q[0] and q'= q[0] * q[1]. + // 2. calculate the expected value of next layer: + // a. product sumcheck: v[r] = eq(0,r) * v[0] + eq(1,r) * v[1] + // b. logup sumcheck: p[r] = eq(0,r) * p[0] + eq(1,r) * p[1] and q[r] = eq(0,r) * q[0] + eq(1,r) * q[1] + ] = ctx; + + let (challenges_read, challenges): (RecordId, [F; EXT_DEG * 4]) = memory.read::<{EXT_DEG * 4}>(data_address_space, cs_pointer); + let alpha: [F; 4] = challenges[0..EXT_DEG].try_into().expect(""); + + let mut header_row = SumcheckEvalRecord { + from_state, + instruction: instruction.clone(), + row_type: 0, + continuation: true, + curr_timestamp_increment: curr_timestamp, + register_ptrs, + alpha, + registers: [ + input_register_1, + input_register_2, + input_register_3, + input_register_4, + output_register, + ], + ctx, + challenges, + read_data_records: [ + read_ctx_pointer, + read_cs_pointer, + read_prod_pointer, + read_logup_pointer, + read_result_pointer, + ctx_read, + challenges_read, + ], + ..Default::default() + }; + + observation_records.push(header_row); + self.height += 1; + curr_timestamp += 7; + + let mut eval_acc = elem_to_ext(F::from_canonical_u32(0)); + let mut alpha_acc = elem_to_ext(F::from_canonical_u32(1)); + let c1: [F; 4] = challenges[EXT_DEG..(EXT_DEG * 2)].try_into().expect(""); + let c2: [F; 4] = challenges[(EXT_DEG * 2)..(EXT_DEG * 3)].try_into().expect(""); + + let mut i = F::ZERO; + let mut i_usize = 0usize; + while i < num_prod_spec { + let mut prod_row: SumcheckEvalRecord = SumcheckEvalRecord { + from_state, + instruction: instruction.clone(), + row_type: 1, + continuation: true, + curr_timestamp_increment: curr_timestamp, + register_ptrs, + ctx, + challenges, + alpha, + prod_spec_n: i_usize, + ..Default::default() + }; + prod_row.alpha1 = alpha_acc; + + let (read_max_round, max_round) = memory.read_cell(data_address_space, ctx_pointer + F::from_canonical_usize(CONTEXT_ARR_BASE_LEN) + i); + prod_row.max_round = max_round; + prod_row.read_data_records[0] = read_max_round; + curr_timestamp += 1; + + if round < (max_round - F::from_canonical_usize(1)) { + prod_row.within_round_limit = true; + let start = calculate_3d_ext_idx( + prod_specs_inner_inner_len, + prod_specs_inner_len, + i, + round, + F::from_canonical_usize(0), + ); + prod_row.data_ptr = start; + + let (read_p, ps) = memory.read::<{EXT_DEG * 2}>(data_address_space, prod_ptr + start); + let p1: [F; 4] = ps[0..EXT_DEG].try_into().expect(""); + let p2: [F; 4] = ps[EXT_DEG..(EXT_DEG * 2)].try_into().expect(""); + + prod_row.read_data_records[1] = read_p; + prod_row.p1 = p1; + prod_row.p2 = p2; + + let evals = if is_op_for_cur_sumcheck_round > F::ZERO { + FieldExtension::multiply(p1, p2) + } else { + FieldExtension::add( + FieldExtension::multiply(p1, c1), + FieldExtension::multiply(p2, c2), + ) + }; + prod_row.p_evals = evals; + + let (write_slice_eval_1, _) = memory.write::(data_address_space, r_ptr + (F::ONE + i) * F::from_canonical_usize(EXT_DEG), evals); + prod_row.write_data_records[0] = write_slice_eval_1; + + let is_op_for_next_sumcheck_round = F::ONE - is_op_for_cur_sumcheck_round; + let acc_eval = FieldExtension::multiply(alpha_acc, evals); + prod_row.acc_eval = acc_eval; + + if (round + is_op_for_next_sumcheck_round) < (max_round - F::from_canonical_usize(1)) { + eval_acc = FieldExtension::add(eval_acc, acc_eval); + prod_row.should_acc = true; + prod_row.eval_acc = eval_acc.clone(); + } + + curr_timestamp += 2; + } + + alpha_acc = FieldExtension::multiply(alpha_acc, alpha); + + i = i + F::ONE; + i_usize += 1; + observation_records.push(prod_row); + self.height += 1; + } + + let mut i = F::ZERO; + let mut i_usize = 0usize; + while i < num_logup_spec { + let mut logup_row: SumcheckEvalRecord = SumcheckEvalRecord { + from_state, + instruction: instruction.clone(), + row_type: 2, + continuation: true, + curr_timestamp_increment: curr_timestamp, + register_ptrs, + ctx, + challenges, + alpha, + logup_spec_n: i_usize, + ..Default::default() + }; + logup_row.alpha1 = alpha_acc; + + let (read_max_round, max_round) = memory.read_cell(data_address_space, ctx_pointer + F::from_canonical_usize(CONTEXT_ARR_BASE_LEN) + num_prod_spec + i); + logup_row.max_round = max_round; + logup_row.read_data_records[0] = read_max_round; + curr_timestamp += 1; + + if round < (max_round - F::from_canonical_usize(1)) { + logup_row.within_round_limit = true; + let start = calculate_3d_ext_idx( + logup_specs_inner_inner_len, + logup_specs_inner_len, + i, + round, + F::from_canonical_usize(0), + ); + logup_row.data_ptr = start; + + let (read_pqs, pqs) = memory.read::<{EXT_DEG * 4}>(data_address_space, logup_ptr + start); + let p1: [F; 4] = pqs[0..EXT_DEG].try_into().expect(""); + let p2: [F; 4] = pqs[EXT_DEG..(EXT_DEG * 2)].try_into().expect(""); + let q1: [F; 4] = pqs[(EXT_DEG * 2)..(EXT_DEG * 3)].try_into().expect(""); + let q2: [F; 4] = pqs[(EXT_DEG * 3)..(EXT_DEG * 4)].try_into().expect(""); + + logup_row.read_data_records[1] = read_pqs; + logup_row.p1 = p1; + logup_row.p2 = p2; + logup_row.q1 = q1; + logup_row.q2 = q2; + + let p_evals = if is_op_for_cur_sumcheck_round > F::ZERO { + FieldExtension::add( + FieldExtension::multiply(p1, q2), + FieldExtension::multiply(p2, q1), + ) + } else { + FieldExtension::add( + FieldExtension::multiply(p1, c1), + FieldExtension::multiply(p2, c2), + ) + }; + + let q_evals = if is_op_for_cur_sumcheck_round > F::ZERO { + FieldExtension::multiply(q1, q2) + } else { + FieldExtension::add( + FieldExtension::multiply(q1, c1), + FieldExtension::multiply(q2, c2), + ) + }; + + logup_row.p_evals = p_evals; + logup_row.q_evals = q_evals; + + let (write_slice_eval_1, _) = memory.write::(data_address_space, r_ptr + (F::ONE + num_prod_spec + i) * F::from_canonical_usize(EXT_DEG), p_evals); + let (write_slice_eval_2, _) = memory.write::(data_address_space, r_ptr + (F::ONE + num_prod_spec + num_logup_spec + i) * F::from_canonical_usize(EXT_DEG), q_evals); + + logup_row.write_data_records[0] = write_slice_eval_1; + logup_row.write_data_records[1] = write_slice_eval_2; + + let is_op_for_next_sumcheck_round = F::ONE - is_op_for_cur_sumcheck_round; + let alpha_denominator = FieldExtension::multiply(alpha_acc, alpha); + logup_row.alpha2 = alpha_denominator; + + if (round + is_op_for_next_sumcheck_round) < (max_round - F::from_canonical_usize(1)) { + let acc_eval = FieldExtension::add( + FieldExtension::multiply(alpha_acc, p_evals), + FieldExtension::multiply(alpha_denominator, q_evals), + ); + logup_row.acc_eval = acc_eval; + eval_acc = FieldExtension::add(eval_acc, acc_eval); + logup_row.should_acc = true; + logup_row.eval_acc = eval_acc.clone(); + } + + curr_timestamp += 3; + } + + alpha_acc = FieldExtension::multiply(FieldExtension::multiply(alpha_acc, alpha), alpha); + + i = i + F::ONE; + i_usize += 1; + observation_records.push(logup_row); + self.height += 1; + } + + let (write_r, _) = memory.write::(data_address_space, r_ptr, eval_acc); + curr_timestamp += 1; + observation_records[0].write_data_records[0] = write_r; + + for record in &mut observation_records { + record.final_timestamp_increment = curr_timestamp; + record.eval_acc = FieldExtension::subtract(eval_acc, record.eval_acc); + } + let last_idx = observation_records.len() - 1; + observation_records[last_idx].continuation = false; + + self.record_set.extend(observation_records); + } else { + unreachable!() + } + + Ok(ExecutionState { + pc: from_state.pc + DEFAULT_PC_STEP, + timestamp: memory.timestamp(), + }) + } + + + fn get_opcode_name(&self, opcode: usize) -> String { + if opcode == SUMCHECK_LAYER_EVAL.global_opcode().as_usize() { + String::from("SUMCHECK_LAYER_EVAL") + } else { + unreachable!("unsupported opcode: {}", opcode) + } + } +} \ No newline at end of file diff --git a/extensions/native/circuit/src/sumcheck/columns.rs b/extensions/native/circuit/src/sumcheck/columns.rs new file mode 100644 index 0000000000..9d8cd072f1 --- /dev/null +++ b/extensions/native/circuit/src/sumcheck/columns.rs @@ -0,0 +1,143 @@ +use openvm_circuit::system::memory::offline_checker::{MemoryReadAuxCols, MemoryWriteAuxCols}; +use openvm_circuit_primitives_derive::AlignedBorrow; +use crate::field_extension::EXT_DEG; +use crate::utils::const_max; + +const fn max3(a: usize, b: usize, c: usize) -> usize { + const_max(a, const_max(b, c)) +} + +#[repr(C)] +#[derive(AlignedBorrow)] +pub struct NativeSumcheckCols { + /// Indicates that this row is the header for a layer sum operation + pub header_row: T, + /// Indicates that this row is a step for prod_spec in the layer sum operation + pub prod_row: T, + /// Indicates that this row is a step for logup_spec in the layer sum operation + pub logup_row: T, + + /// Indicates that there are valid operations following this header row + pub header_continuation: T, + /// Indicates that there are valid operations following this product evaluation row + pub prod_continuation: T, + /// Indicates that there are valid operations following this logup row + pub logup_continuation: T, + + /// Indicates that the prod row is within maximum round + pub prod_row_within_max_round: T, + /// Indicates that the logup row is within maximum round + pub logup_row_within_max_round: T, + + /// Indicates what type of evaluation constraints should be applied + pub prod_in_round_evaluation: T, + pub prod_next_round_evaluation: T, + pub logup_in_round_evaluation: T, + pub logup_next_round_evaluation: T, + + /// Indicates if evaluations are accumulated + pub prod_acc: T, + pub logup_acc: T, + + /// Timestamps + pub first_timestamp: T, + pub start_timestamp: T, + pub last_timestamp: T, + + // Register values + pub register_ptrs: [T; 5], + + // Context variables + // [ + // round, + // num_prod_spec, + // num_logup_spec, + // prod_spec_inner_len, + // prod_spec_inner_inner_len, + // logup_spec_inner_len, + // logup_spec_inner_inner_len, + // in_layer, + // ] + pub ctx: [T; EXT_DEG * 2], + + pub prod_nested_len: T, + pub logup_nested_len: T, + + pub curr_prod_n: T, + pub curr_logup_n: T, + + // alpha1, c1, c2, alpha2 (for logup rows) + pub alpha: [T; EXT_DEG], + pub challenges: [T; EXT_DEG * 4], + + // Specific to each row + pub max_round: T, + // Is this round within max_round + pub within_round_limit: T, + // Should the evaluation be accumualted + pub should_acc: T, + + // The current final evaluation accumulator. Extension element. + pub eval_acc: [T; EXT_DEG], + + // /// 1. For header row, 5 registers, ctx, challenges + // /// 2. For the rest: max_variables, p1, p2, q1, q2 + // pub read_records: [MemoryReadAuxCols; 7], + // /// 1. For header row, write final result + // /// 2. For prod rows: write prod_evals + // /// 3. For logup rows: write q_evals, p_evals + // pub write_records: [MemoryWriteAuxCols; 2], + + pub specific: [T; max3( + HeaderSpecificCols::::width(), + ProdSpecificCols::::width(), + LogupSpecificCols::::width(), + )] +} + +#[repr(C)] +#[derive(AlignedBorrow)] +pub struct HeaderSpecificCols { + pub pc: T, + pub registers: [T; 5], + /// 5 register reads + ctx read + challenges read + pub read_records: [MemoryReadAuxCols; 7], + /// Write the final evaluation + pub write_records: MemoryWriteAuxCols +} + +#[repr(C)] +#[derive(AlignedBorrow)] +pub struct ProdSpecificCols { + /// Pointer + pub data_ptr: T, + /// 2 extension elements + pub p: [T; EXT_DEG * 2], + /// read max varibale and 2 p values + pub read_records: [MemoryReadAuxCols; 2], + /// Calculated p evals + pub p_evals: [T; EXT_DEG], + /// write p_evals + pub write_record: MemoryWriteAuxCols, + /// Evaluation for the accumulator + pub acc_eval: [T; EXT_DEG], +} + +#[repr(C)] +#[derive(AlignedBorrow)] +pub struct LogupSpecificCols { + /// Pointer + pub data_ptr: T, + /// 4 extension elements + pub pq: [T; EXT_DEG * 4], + /// read max variable and 4 values: p1, p2, q1, q2 + pub read_records: [MemoryReadAuxCols; 2], + /// Calculated p evals + pub p_evals: [T; EXT_DEG], + /// Calculated q evals + pub q_evals: [T; EXT_DEG], + /// write both p_evals and q_evals + pub write_records: [MemoryWriteAuxCols; 2], + /// Evaluation for the accumulator + pub acc_eval: [T; EXT_DEG], +} \ No newline at end of file diff --git a/extensions/native/circuit/src/sumcheck/mod.rs b/extensions/native/circuit/src/sumcheck/mod.rs new file mode 100644 index 0000000000..8b6cab5165 --- /dev/null +++ b/extensions/native/circuit/src/sumcheck/mod.rs @@ -0,0 +1,5 @@ +pub mod air; +pub mod chip; +mod columns; +// mod tests; +mod trace; \ No newline at end of file diff --git a/extensions/native/circuit/src/sumcheck/trace.rs b/extensions/native/circuit/src/sumcheck/trace.rs new file mode 100644 index 0000000000..1d3b9f0941 --- /dev/null +++ b/extensions/native/circuit/src/sumcheck/trace.rs @@ -0,0 +1,169 @@ +use std::{borrow::BorrowMut, sync::Arc}; + +use openvm_circuit::system::memory::{MemoryAuxColsFactory, OfflineMemory}; +use openvm_circuit_primitives::utils::next_power_of_two_or_zero; +use openvm_instructions::{instruction::Instruction, LocalOpcode}; +use openvm_native_compiler::Poseidon2Opcode::COMP_POS2; +use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, + p3_air::BaseAir, + p3_field::{Field, PrimeField32}, + p3_matrix::dense::RowMajorMatrix, + p3_maybe_rayon::prelude::*, + prover::types::AirProofInput, + AirRef, Chip, ChipUsageGetter, +}; +use rand::distributions::Alphanumeric; +use crate::{FieldExtension, sumcheck::{chip::NativeSumcheckChip, columns::{HeaderSpecificCols, LogupSpecificCols, NativeSumcheckCols, ProdSpecificCols}}, EXT_DEG}; + +impl ChipUsageGetter + for NativeSumcheckChip +{ + fn air_name(&self) -> String { + "SumcheckLayerEval".to_string() + } + + fn current_trace_height(&self) -> usize { + self.height + } + + fn trace_width(&self) -> usize { + NativeSumcheckCols::::width() + } +} + +impl NativeSumcheckChip { + fn generate_trace(self) -> RowMajorMatrix { + let width = self.trace_width(); + let height = next_power_of_two_or_zero(self.height); + let mut flat_trace: Vec = F::zero_vec(width * height); + + let memory = self.offline_memory.lock().unwrap(); + let aux_cols_factory = memory.aux_cols_factory(); + + let mut used_cells = 0; + for record in self.record_set { + let slice = &mut flat_trace[used_cells..used_cells + width]; + let cols: &mut NativeSumcheckCols = slice.borrow_mut(); + cols.first_timestamp = F::from_canonical_u32(record.from_state.timestamp); + cols.start_timestamp = F::from_canonical_usize(record.from_state.timestamp as usize + record.curr_timestamp_increment); + cols.last_timestamp = F::from_canonical_usize(record.from_state.timestamp as usize + record.final_timestamp_increment); + cols.register_ptrs = record.register_ptrs; + cols.ctx = record.ctx; + cols.prod_nested_len = record.ctx[4] * record.ctx[3]; + cols.logup_nested_len = record.ctx[6] * record.ctx[5]; + cols.challenges = record.challenges; + cols.alpha = record.alpha; + cols.max_round = record.max_round; + cols.within_round_limit = if record.within_round_limit { F::ONE } else { F::ZERO }; + cols.should_acc = if record.should_acc { F::ONE } else { F::ZERO }; + cols.eval_acc = record.eval_acc; + + if record.row_type == 0 { + cols.header_row = F::ONE; + cols.header_continuation = if record.continuation { F::ONE } else { F::ZERO }; + let header: &mut HeaderSpecificCols = + cols.specific[..HeaderSpecificCols::::width()].borrow_mut(); + + header.pc = F::from_canonical_u32(record.from_state.pc); + header.registers = record.registers; + + for i in 0..7usize { + let mem_record = memory.record_by_id(record.read_data_records[i]); + aux_cols_factory.generate_read_aux(mem_record, &mut header.read_records[i]); + } + + // write the final result + let mem_record = memory.record_by_id(record.write_data_records[0]); + aux_cols_factory.generate_write_aux(mem_record, &mut header.write_records); + } else if record.row_type == 1 { + cols.prod_row = F::ONE; + cols.prod_continuation = if record.continuation { F::ONE } else { F::ZERO }; + cols.prod_row_within_max_round = if record.within_round_limit { F::ONE } else { F::ZERO }; + cols.prod_in_round_evaluation = if record.within_round_limit { record.ctx[7] } else { F::ZERO }; + cols.prod_next_round_evaluation = if record.within_round_limit { F::ONE - record.ctx[7] } else { F::ZERO }; + cols.prod_acc = if record.should_acc { F::ONE } else { F::ZERO }; + let prod: &mut ProdSpecificCols = + cols.specific[..ProdSpecificCols::::width()].borrow_mut(); + + cols.curr_prod_n = F::from_canonical_usize(record.prod_spec_n + 1); + cols.challenges[0..EXT_DEG].copy_from_slice(&record.alpha1); + prod.p[0..EXT_DEG].copy_from_slice(&record.p1); + prod.p[EXT_DEG..(EXT_DEG * 2)].copy_from_slice(&record.p2); + prod.data_ptr = record.data_ptr; + prod.acc_eval = record.acc_eval; + + // Read max_round + let mem_record = memory.record_by_id(record.read_data_records[0]); + aux_cols_factory.generate_read_aux(mem_record, &mut prod.read_records[0]); + + if record.within_round_limit { + // Read p1, p2 + let mem_record = memory.record_by_id(record.read_data_records[1]); + aux_cols_factory.generate_read_aux(mem_record, &mut prod.read_records[1]); + + // Write p eval + prod.p_evals = record.p_evals; + let mem_record = memory.record_by_id(record.write_data_records[0]); + aux_cols_factory.generate_write_aux(mem_record, &mut prod.write_record); + } + } else if record.row_type == 2 { + cols.logup_row = F::ONE; + cols.logup_continuation = if record.continuation { F::ONE } else { F::ZERO }; + cols.logup_row_within_max_round = if record.within_round_limit { F::ONE } else { F::ZERO }; + cols.logup_in_round_evaluation = if record.within_round_limit { record.ctx[7] } else { F::ZERO }; + cols.logup_next_round_evaluation = if record.within_round_limit { F::ONE - record.ctx[7] } else { F::ZERO }; + cols.logup_acc = if record.should_acc { F::ONE } else { F::ZERO }; + let logup: &mut LogupSpecificCols = + cols.specific[..LogupSpecificCols::::width()].borrow_mut(); + + cols.curr_logup_n = F::from_canonical_usize(record.logup_spec_n + 1); + cols.challenges[0..EXT_DEG].copy_from_slice(&record.alpha1); + cols.challenges[(EXT_DEG * 3)..(EXT_DEG * 4)].copy_from_slice(&record.alpha2); + logup.pq[0..EXT_DEG].copy_from_slice(&record.p1); + logup.pq[EXT_DEG..(EXT_DEG * 2)].copy_from_slice(&record.p2); + logup.pq[(EXT_DEG * 2)..(EXT_DEG * 3)].copy_from_slice(&record.q1); + logup.pq[(EXT_DEG * 3)..(EXT_DEG * 4)].copy_from_slice(&record.q2); + logup.data_ptr = record.data_ptr; + logup.acc_eval = record.acc_eval; + + // Read max_round + let mem_record = memory.record_by_id(record.read_data_records[0]); + aux_cols_factory.generate_read_aux(mem_record, &mut logup.read_records[0]); + + if record.within_round_limit { + // Read p1, p2, q1, q2 + let mem_record = memory.record_by_id(record.read_data_records[1]); + aux_cols_factory.generate_read_aux(mem_record, &mut logup.read_records[1]); + + // Write p and q eval + logup.p_evals = record.p_evals; + logup.q_evals = record.q_evals; + for i in 0..2usize { + let mem_record = memory.record_by_id(record.write_data_records[i]); + aux_cols_factory.generate_write_aux(mem_record, &mut logup.write_records[i]); + } + } + } else { + unreachable!() + } + + used_cells += width; + } + + RowMajorMatrix::new(flat_trace, width) + } +} + +impl Chip + for NativeSumcheckChip> +where + Val: PrimeField32, +{ + fn air(&self) -> AirRef { + Arc::new(self.air.clone()) + } + fn generate_air_proof_input(self) -> AirProofInput { + AirProofInput::simple_no_pis(self.generate_trace()) + } +} \ No newline at end of file diff --git a/extensions/native/compiler/src/asm/compiler.rs b/extensions/native/compiler/src/asm/compiler.rs index 0e0db9cff9..b701844008 100644 --- a/extensions/native/compiler/src/asm/compiler.rs +++ b/extensions/native/compiler/src/asm/compiler.rs @@ -632,6 +632,12 @@ impl + TwoAdicField> AsmCo ); } } + DslIr::SumcheckLayerEval(input_ctx, challenges, prod_ptr, logup_ptr, r_ptr) => { + self.push( + AsmInstruction::SumcheckLayerEval(input_ctx.fp(), challenges.fp(), prod_ptr.fp(), logup_ptr.fp(), r_ptr.fp()), + debug_info, + ); + } _ => unimplemented!(), } } diff --git a/extensions/native/compiler/src/asm/instruction.rs b/extensions/native/compiler/src/asm/instruction.rs index cd4990b08b..48e239cdda 100644 --- a/extensions/native/compiler/src/asm/instruction.rs +++ b/extensions/native/compiler/src/asm/instruction.rs @@ -171,6 +171,15 @@ pub enum AsmInstruction { CycleTrackerStart(), CycleTrackerEnd(), + + // Native opcode for calculating sumcheck layer evaluation + // SumcheckLayerEval(reg_a, reg_b, reg_c, ... , reg_f, reg_g) + // - reg_a: Output ptr for next layer's evaluations + // - reg_b: Context variables + // - reg_c: Challenge values (alpha, coeff) + // - reg_g: GKR product IOP evaluations + // - reg_f: GKR logup IOP evaluations + SumcheckLayerEval(i32, i32, i32, i32, i32), } impl> AsmInstruction { @@ -403,6 +412,9 @@ impl> AsmInstruction { AsmInstruction::RangeCheck(fp, lo_bits, hi_bits) => { write!(f, "range_check_fp ({})fp, ({}), ({})", fp, lo_bits, hi_bits) } + AsmInstruction::SumcheckLayerEval(ctx, cs, p_ptr, l_ptr, r_ptr) => { + write!(f, "sumcheck_layer_eval ({})fp, ({})fp, ({})fp, ({})fp, ({})fp", ctx, cs, p_ptr, l_ptr, r_ptr) + } } } } diff --git a/extensions/native/compiler/src/conversion/mod.rs b/extensions/native/compiler/src/conversion/mod.rs index f8da82c30b..f6fffc2db0 100644 --- a/extensions/native/compiler/src/conversion/mod.rs +++ b/extensions/native/compiler/src/conversion/mod.rs @@ -9,10 +9,7 @@ use openvm_stark_backend::p3_field::{ExtensionField, PrimeField32, PrimeField64} use serde::{Deserialize, Serialize}; use crate::{ - asm::{AsmInstruction, AssemblyCode}, - FieldArithmeticOpcode, FieldExtensionOpcode, FriOpcode, NativeBranchEqualOpcode, - NativeJalOpcode, NativeLoadStore4Opcode, NativeLoadStoreOpcode, NativePhantom, - NativeRangeCheckOpcode, Poseidon2Opcode, VerifyBatchOpcode, + asm::{AsmInstruction, AssemblyCode}, FieldArithmeticOpcode, FieldExtensionOpcode, FriOpcode, NativeBranchEqualOpcode, NativeJalOpcode, NativeLoadStore4Opcode, NativeLoadStoreOpcode, NativePhantom, NativeRangeCheckOpcode, Poseidon2Opcode, SumcheckOpcode, VerifyBatchOpcode }; #[derive(Clone, Copy, Debug, Serialize, Deserialize)] @@ -535,7 +532,19 @@ fn convert_instruction>( // Here it just requires a 0 AS::Immediate, )] - } + }, + AsmInstruction::SumcheckLayerEval(ctx, cs, p_ptr, l_ptr, r_ptr) => vec![ + Instruction { + opcode: options.opcode_with_offset(SumcheckOpcode::SUMCHECK_LAYER_EVAL), + a: i32_f(r_ptr), + b: i32_f(ctx), + c: i32_f(cs), + d: AS::Native.to_field(), + e: AS::Native.to_field(), + f: i32_f(p_ptr), + g: i32_f(l_ptr), + } + ], }; let debug_infos = vec![debug_info; instructions.len()]; diff --git a/extensions/native/compiler/src/ir/instructions.rs b/extensions/native/compiler/src/ir/instructions.rs index 13f5c4a653..af9257d4f0 100644 --- a/extensions/native/compiler/src/ir/instructions.rs +++ b/extensions/native/compiler/src/ir/instructions.rs @@ -319,6 +319,35 @@ pub enum DslIr { CycleTrackerStart(String), /// End the cycle tracker used by a block of code annotated by the string input. CycleTrackerEnd(String), + + /// Native operation for calculating a sumcheck layer's evaluation + /// This op supports two modes: + /// 1. for computing expected evaluation for current layer, + /// output = [ + /// \sum_i alpha^i * prod[i][0] * prod[i][1] + + /// \sum_j alpha^(2j) * (logup_q[i][0] * logup_q[i][1] + + /// alpha* logup_p[i][0] * logup_q[i][1] + + /// alpha * logup_p[i][1] * logup_q[i][0] + /// ]; + /// + /// 2. for computing expected evaluation of next layer, + /// output[1+i] = eq(0,r)*p[i][0] + eq(1,r) * p[i][1]. + SumcheckLayerEval( + Ptr, // Context variables: + // 0. round, + // 1. number of product + // 2. number of logup + // 3. (3D array description) prod_specs_eval inner length + // 4. (3D array description) prod_specs_eval inner_inner length + // 5. (3D array description) logup_spec_eval inner length + // 6. (3D array description) logup_spec_eval inner length + // 7. Operational mode indicator + // 8+. usize-type variables indicating maximum rounds + Ptr, // Challenges: alpha, coeffs + Ptr, // prod_specs_eval + Ptr, // logup_specs_eval + Ptr // output + ) } impl Default for DslIr { diff --git a/extensions/native/compiler/src/ir/mod.rs b/extensions/native/compiler/src/ir/mod.rs index 47e901cd3a..29bb52f086 100644 --- a/extensions/native/compiler/src/ir/mod.rs +++ b/extensions/native/compiler/src/ir/mod.rs @@ -23,6 +23,7 @@ mod types; mod utils; mod var; mod verify_batch; +mod sumcheck; pub trait Config: Clone + Default { type N: PrimeField; diff --git a/extensions/native/compiler/src/ir/sumcheck.rs b/extensions/native/compiler/src/ir/sumcheck.rs new file mode 100644 index 0000000000..f49f56ada8 --- /dev/null +++ b/extensions/native/compiler/src/ir/sumcheck.rs @@ -0,0 +1,53 @@ +use openvm_native_compiler_derive::iter_zip; +use openvm_stark_backend::p3_field::FieldAlgebra; +use crate::ir::Variable; +use super::{Array, ArrayLike, Builder, Config, DslIr, Ext, Felt, MemIndex, Ptr, Usize, Var}; + +impl Builder { + /// Extends native VM ability to calculate the evaluation for a sumcheck layer + /// This opcode supports two modes (indicated by a context variable): + /// 1. calculate the expected evaluation of two types of sumchecks (prod, logup) + /// 2. calculate the expected value of next layer p[r] = eq(0,r)*p[0] + eq(1,r)*p[1] + /// + /// Context variables + /// + /// 0: round, + /// 1: number of product + /// 2. number of logup + /// 3. (3D array description) prod_specs_eval inner length + /// 4. (3D array description) prod_specs_eval inner_inner length + /// 5. (3D array description) logup_spec_eval inner length + /// 6. (3D array description) logup_spec_eval inner length + /// 7. Operational mode indicator + /// 8+ Additional usize-type variables indicating maximum rounds + /// + /// Output + /// + /// 1. for computing expected evaluation, + /// output = [ + /// \sum_i alpha^i * prod[i][0] * prod[i][1] + + /// \sum_j alpha^(2j) * (logup_q[i][0] * logup_q[i][1] + + /// alpha* logup_p[i][0] * logup_q[i][1] + + /// alpha * logup_p[i][1] * logup_q[i][0] + /// ]; + /// + /// 2. for computing expected eval of next layer, + /// output[1+i] = eq(0,r)*p[i][0] + eq(1,r) * p[i][1]. + /// + pub fn sumcheck_layer_eval ( + &mut self, + input_ctx: &Array>, // Context variables + challenges: &Array>, // Challenges + prod_specs_eval: &Array>, // GKR product IOP evaluations. Flattened from 3D array. + logup_specs_eval: &Array>, // GKR logup IOP evaluations. Flattened from 3D array. + r_evals: &Array>, // Next layer's evaluations (pointer used for storing opcode output) + ) { + self.operations.push(DslIr::SumcheckLayerEval( + input_ctx.ptr(), + challenges.ptr(), + prod_specs_eval.ptr(), + logup_specs_eval.ptr(), + r_evals.ptr(), + )); + } +} \ No newline at end of file diff --git a/extensions/native/compiler/src/lib.rs b/extensions/native/compiler/src/lib.rs index 66c786fbd9..6f693efed1 100644 --- a/extensions/native/compiler/src/lib.rs +++ b/extensions/native/compiler/src/lib.rs @@ -212,3 +212,18 @@ pub enum VerifyBatchOpcode { /// per column polynomial, per opening point VERIFY_BATCH, } + +/// Opcodes for sumcheck. +#[derive( + Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, LocalOpcode, +)] +#[opcode_offset = 0x180] +#[repr(usize)] +#[allow(non_camel_case_types)] +pub enum SumcheckOpcode { + /// Compute the expected evaluation for each layer in the tower structure that GKR product IOP and logup IOP uses + /// Supports two modes of operation: + /// 1. Calculate current layer's expected evaluation + /// 2. Calculate next layer's evaluation + SUMCHECK_LAYER_EVAL, +} diff --git a/extensions/native/recursion/tests/sumcheck.rs b/extensions/native/recursion/tests/sumcheck.rs new file mode 100644 index 0000000000..975f63b211 --- /dev/null +++ b/extensions/native/recursion/tests/sumcheck.rs @@ -0,0 +1,422 @@ +use itertools::Itertools; +use openvm_circuit::arch::{instructions::program::Program, SystemConfig, VmConfig, VmExecutor, verify_single, VirtualMachine,}; +use openvm_native_circuit::{Native, NativeConfig, EXT_DEG}; +use openvm_native_compiler::{ + prelude::*, + asm::{AsmBuilder, AsmCompiler}, ir::{Felt, Ext, Usize}, + conversion::{convert_program, CompilerOptions}, +}; +use openvm_native_recursion::{testing_utils::inner::run_recursive_test, challenger::{duplex::DuplexChallengerVariable, CanObserveVariable}}; +use openvm_stark_backend::{ + config::{Domain, StarkGenericConfig}, + p3_commit::PolynomialSpace, + p3_field::{extension::BinomialExtensionField, FieldAlgebra, PackedValue, FieldExtensionAlgebra}, +}; +use openvm_stark_sdk::{ + config::FriParameters, + p3_baby_bear::BabyBear, + utils::ProofInputForTest, + config::{ + baby_bear_poseidon2::BabyBearPoseidon2Engine, + fri_params::standard_fri_params_with_100_bits_conjectured_security, + }, + engine::StarkFriEngine, + utils::create_seeded_rng, +}; +use rand::Rng; +pub type F = BabyBear; +pub type E = BinomialExtensionField; + +#[test] +fn test_sumcheck_layer_eval() { + let mut builder = AsmBuilder::>::default(); + + build_test_program(&mut builder); + + // Fill in test program logic + builder.halt(); + + + + let compilation_options = CompilerOptions::default().with_cycle_tracker(); + let mut compiler = AsmCompiler::new(compilation_options.word_size); + compiler.build(builder.operations); + let asm_code = compiler.code(); + + // let program = Program::from_instructions(&instructions); + let program: Program<_> = convert_program(asm_code, compilation_options); + let sumcheck_max_constraint_degree = 3; + let fri_params = if matches!(std::env::var("OPENVM_FAST_TEST"), Ok(x) if &x == "1") { + FriParameters { + // max constraint degree = 2^log_blowup + 1 + log_blowup: 1, + log_final_poly_len: 0, + num_queries: 2, + proof_of_work_bits: 0, + } + } else { + standard_fri_params_with_100_bits_conjectured_security(1) + }; + + let engine = BabyBearPoseidon2Engine::new(fri_params); + let mut config = NativeConfig::aggregation(0, sumcheck_max_constraint_degree); + config.system.memory_config.max_access_adapter_n = 16; + + + let vm = VirtualMachine::new(engine, config); + + let pk = vm.keygen(); + let result = vm.execute_and_generate(program, vec![]).unwrap(); + let proofs = vm.prove(&pk, result); + + for proof in proofs { + verify_single(&vm.engine, &pk.get_vk(), &proof).expect("Verification failed"); + } +} + +fn build_test_program( + builder: &mut Builder, +) { + let ctx_u32s = [3u32, 6, 5, 8, 2, 8, 4, 0, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9]; + let ctx: Array> = builder.dyn_array(ctx_u32s.len()); + for (idx, n) in ctx_u32s.into_iter().enumerate() { + builder.set(&ctx, idx, Usize::from(n as usize)); + } + + let challenges_u32s = [ + 548478283u32, 456436544, 1716290291, 791326976, + 1829717553, 1422025771, 1917123958, 727015942, + 183548369, 591240150, 96141963, 1286249979, + ]; + let challenges: Array> = builder.dyn_array(challenges_u32s.len() / EXT_DEG); + for (idx, n) in challenges_u32s.chunks(EXT_DEG).enumerate() { + let e: Ext = builder.constant(C::EF::from_base_slice(&[ + C::F::from_canonical_u32(n[0]), + C::F::from_canonical_u32(n[1]), + C::F::from_canonical_u32(n[2]), + C::F::from_canonical_u32(n[3]) + ])); + + builder.set(&challenges, idx, e); + } + + let prod_spec_eval_u32s = [ + 1538906710u32, 637535518, 1753132406, 1395236651, + 278806441, 1722910382, 1475548665, 1117874675, + 1578586709, 1826764884, 384068476, 1852240363, + 707958906, 1960944944, 183554399, 1259273357, + 227285124, 243066436, 1718037317, 369721963, + 1752968006, 1061013677, 775617499, 1464907431, + 544300429, 871461966, 135151545, 1343592602, + 1622220528, 643966158, 3932580, 434948358, + 540553922, 1446502052, 153298741, 1191216273, + 265936762, 1463035257, 1237633339, 1797346310, + 1355791584, 389527741, 1741650463, 1728913415, + 1825739540, 1790924136, 460776743, 29536554, + 6842036, 252495270, 1968285155, 299467416, + 49085744, 1499815729, 1098802236, 644489275, + 1827273105, 1888401527, 390077051, 565528894, + 1366177188, 67441791, 958486301, 402056716, + 590379691, 462035406, 633459131, 843304872, + 584100013, 1932496508, 250656031, 146983915, + 1835173157, 939973454, 1844873638, 1916054832, + 1601784696, 167251717, 409107688, 1062925788, + 1291319514, 1790529531, 495655592, 1093359708, + 790197205, 674458164, 195988318, 399764452, + 106865258, 967050329, 350035523, 1109292118, + 1815460301, 281986036, 900636603, 1121197008, + 1228976590, 1879998708, 1924332706, 434695844, + 1159360621, 471397106, 473371067, 1009065094, + 1320176846, 168020789, 1265321929, 1901808675, + 223657700, 1480150183, 1779968584, 144416591, + 304407746, 1864498679, 1482460119, 1554376965, + 1479261548, 1657723043, 1039345063, 1053923521, + 442080513, 1964082352, 691664908, 1941008321, + 1007729002, 860529393, 849697342, 754485488, + 584295923, 1072251466, 1105105254, 996079746, + 1305909868, 1348028973, 122275988, 464050036, + 692807777, 1098809324, 397235220, 596459886, + 1663209783, 720230826, 1422510715, 1760654694, + 544197700, 1417744567, 1938716517, 1571826328, + 1591430185, 1173137446, 175285007, 1541718596, + 1715958587, 1429966110, 583013357, 1667787861, + 109891172, 668253167, 161783842, 296183397, + 1681897325, 1054396117, 264741948, 464026995, + 1907686022, 1532786783, 394869458, 1766734740, + 136047179, 536856195, 376188855, 700633625, + 515518419, 531043483, 60673499, 556496527, + 1743028981, 873954569, 1371062291, 632169731, + 1353239206, 526507035, 1894490088, 589441599, + 1610487168, 1074160583, 366366374, 247602990, + 1535354896, 894493713, 1555870413, 1389854934, + 1897251683, 1525812801, 675621735, 697919636, + 1690274072, 1466810921, 1221110784, 1741995587, + 1877169764, 390876982, 1794129810, 297662156, + 144295349, 417037264, 1290835727, 1654971513, + 1674131303, 1625667423, 1471248832, 1676797844, + 1172916558, 1707775403, 423725211, 1643279661, + 1695774264, 378140395, 1517569394, 1666625392, + 1803981250, 439036260, 247966130, 709534816, + 361144100, 1546096548, 1240886454, 1898161518, + 843262057, 1709259464, 1301015977, 1997626928, + 677153173, 1606710353, 1216038070, 435565562, + 98686333, 1773787396, 267051994, 99395396, + 545509105, 782289675, 1289865975, 1707775075, + 1158993015, 1506576588, 993215179, 1523099397, + 923914455, 1895162386, 284489994, 1444139016, + 1943825680, 466202724, 1632522710, 1384015062, + 723147188, 1284031324, 1430481515, 341213007, + 171192499, 1061688239, 808927167, 83182639, + 759209907, 1728321272, 976049976, 1652071995, + 1002877840, 69880246, 1095135165, 677588420, + 1384715290, 829619452, 170122781, 1958173727, + 13389238, 789379698, 1883383039, 1279195174, + 1618672336, 1192839317, 1348311124, 758896285, + 1939775389, 684108413, 1838340479, 1332232130, + 1070486028, 549228790, 868851698, 1678207843, + 1754321489, 637000403, 647901906, 45343322, + 1768524074, 1167955205, 1816497210, 1609414096, + 1985231742, 1540534482, 232730819, 232221968, + 1509637836, 1480860627, 884647789, 1096458024, + 163721583, 1248032262, 436419506, 1737102298, + 651105860, 452298073, 1064372507, 1792838683, + 619243471, 860127631, 721724708, 950768433, + 279913448, 339693210, 47730422, 1952683911, + 1316500770, 675944216, 386902809, 619333956, + 1194800389, 43989936, 1944372656, 666045666, + 1155873844, 522696968, 58874730, 1497238023, + 421619994, 1980672127, 1657191856, 1913792631, + 1784663131, 1118400672, 1828104993, 1637808383, + 414755472, 775410449, 747132157, 136820101, + 1082674285, 93190395, 357955402, 335652723, + 1192102705, 480365232, 1354935730, 1391829361, + 966662991, 1601510445, 569528575, 545490940, + 1753711688, 807025222, 580374183, 587718008, + 977546290, 1055719519, 1157107032, 562799608, + 859466927, 840450024, 815325134, 936576801, + 1010587056, 246624382, 1808049797, 1098183398, + 1005077390, 772432546, 1976629565, 1003772218, + 1655315418, 1767931114, 982008720, 785023351, + ]; + + let prod_spec_evals: Array> = builder.dyn_array(prod_spec_eval_u32s.len() / EXT_DEG); + for (idx, n) in prod_spec_eval_u32s.chunks(EXT_DEG).enumerate() { + let e: Ext = builder.constant(C::EF::from_base_slice(&[ + C::F::from_canonical_u32(n[0]), + C::F::from_canonical_u32(n[1]), + C::F::from_canonical_u32(n[2]), + C::F::from_canonical_u32(n[3]) + ])); + + builder.set(&prod_spec_evals, idx, e); + } + + let logup_spec_eval_u32s = [ + 1522353967u32, 457603397, 421847521, 1352563318, + 1746817766, 737872688, 1087008622, 1850835028, + 456475558, 892966330, 638163666, 148568548, + 678863061, 1334386850, 1896333039, 154585769, + 433618446, 1186936470, 970218722, 1213827097, + 1798557019, 861757965, 119285527, 395360622, + 226164366, 1330279872, 66561048, 785421608, + 1950755756, 1559889596, 348449876, 1090789452, + 257578851, 273164442, 1644906, 295600924, + 1187949602, 1168249609, 469763604, 60929061, + 291163036, 403842501, 1421902433, 1700188477, + 1046093370, 921059131, 1638991894, 464012042, + 96905857, 1370999592, 271896041, 13595534, + 1489760970, 1650552701, 133367846, 25680377, + 377631580, 652729291, 645763356, 426747355, + 482475486, 1877299223, 103226636, 1333832358, + 1399609097, 458536972, 976248802, 1109365280, + 515164588, 1579426417, 1601829549, 607169702, + 852817956, 1980537127, 134138338, 913344050, + 737880920, 476360275, 61624034, 1610624252, + 264461991, 546933535, 937769429, 293346965, + 1522058041, 1012551797, 994330314, 23333322, + 1969510890, 974351570, 2012030621, 120742000, + 450250620, 180547360, 642746933, 1815029950, + 629489142, 1176992624, 723354779, 572648755, + 1218615348, 648847054, 351903235, 723149764, + 248065753, 243829448, 1283393001, 1912627886, + 581641342, 702465306, 205969758, 1061911274, + 1, 0, 0, 0, + 1, 0, 0, 0, + 1703043252, 1467887451, 1714319214, 907866644, + 1542426838, 742609036, 1814393459, 448706641, + 1960340767, 46490834, 186512520, 363973095, + 846448854, 463742343, 2012517527, 40473617, + 9472552, 263483342, 105738598, 586389136, + 254290990, 625150844, 960233097, 1488303724, + 1700231692, 1471714612, 1540211186, 1590246915, + 945341972, 1343225515, 179976237, 34857822, + 276912528, 984309272, 1277293398, 1520924162, + 1823117694, 604836357, 1460812009, 600052559, + 970469338, 1771022707, 181855831, 1445947220, + 467514809, 1514677498, 947030389, 170390653, + 415409007, 1601463730, 204153427, 904614278, + 1855419512, 2009471607, 1352607379, 576586082, + 1343812879, 1176377580, 1166188815, 1592289048, + 761793881, 1529621462, 193034837, 344011596, + 1669461833, 1356800025, 314186361, 586497329, + 1832810846, 1288092861, 1619454491, 732529408, + 737934269, 909504928, 769680420, 1437893101, + 1727002258, 1618231110, 535125583, 153412473, + 1917760929, 588586507, 564531165, 1790797737, + 1666283994, 1366948884, 117673690, 476470378, + 2012274032, 1951406668, 1739767532, 1273142151, + 1591812317, 1900205312, 1912608761, 1734766024, + 1265002082, 1450462894, 749810837, 1329222552, + 745081805, 1231519431, 1420957967, 883846107, + 1995463911, 407795592, 161655852, 125886157, + 995318920, 484905024, 284135318, 551493419, + 406742309, 1089024446, 637339867, 1858138403, + 1230680117, 187078889, 1929517480, 1125646261, + 1, 0, 0, 0, + 1, 0, 0, 0, + 1610035932, 462442436, 831412555, 44798862, + 1748147276, 1911945531, 1329343740, 971894393, + 362147969, 1583335926, 1528700112, 426908674, + 847905883, 447889090, 1050883911, 1883537469, + 1487501632, 964178870, 1818828551, 1980840799, + 340372118, 1697179193, 215113037, 1893217470, + 1138628493, 1788052486, 443362955, 1349213730, + 589553425, 562526667, 1006040406, 1194546769, + 1831034644, 612004157, 730213913, 1068905440, + 371983982, 502900790, 802785198, 822377635, + 1477528437, 501356237, 684668525, 1306043781, + 621032592, 1971342708, 1411586583, 733418745, + 186045462, 1559301855, 323758310, 453170140, + 498381240, 976247416, 631213663, 898017829, + 501459603, 609703046, 1379288251, 177682695, + 912381595, 121915494, 1137416430, 504054388, + 1138277238, 1603388253, 1838013301, 1700271853, + 20488607, 58775264, 217974275, 979141729, + 53136584, 1331566240, 1460303356, 525812787, + 718385521, 1477919263, 1663622276, 1089788203, + 1204483837, 54225863, 290660186, 1441441958, + 134168813, 349638823, 1867912015, 1579183319, + 55528656, 1602973359, 194297109, 949763297, + 101931919, 242300116, 1610052257, 1351823848, + 174522860, 776955925, 1706962365, 808187490, + 1487253852, 431806906, 213982593, 1170647308, + 1776840400, 295916317, 378708073, 381270341, + 457494568, 705823997, 1407301442, 1693003013, + 700310785, 1349874247, 1284363817, 1566253815, + 1014298154, 215294365, 1070968678, 871641358, + 1, 0, 0, 0, + 1, 0, 0, 0, + 1302679751, 1121894357, 368587356, 1564724097, + 733815591, 2012670011, 1146780092, 1439780227, + 1801628424, 838692317, 932318853, 213634365, + 155292454, 1644317110, 1599846194, 978829059, + 1282095862, 1780431647, 527412087, 1024583705, + 804423802, 951808322, 689345230, 180304167, + 1784562773, 1514653374, 2009396440, 1143778943, + 235299446, 1553017484, 475425117, 758292254, + 716575432, 517083432, 1728864125, 418010549, + 43202592, 507659742, 433077118, 1268144019, + 1462778342, 1928073362, 1330130180, 1749624351, + 827401013, 1236194147, 1875519726, 1437946791, + 607293265, 309229599, 1009445595, 1725229718, + 1436309341, 1952606463, 943149111, 291680468, + 1989684076, 1944713370, 1285294139, 399758737, + 1572979232, 213817406, 214840530, 184898060, + 1483844295, 1536616777, 494816009, 217625163, + 529448032, 786640964, 1766471731, 1424140424, + 1721961711, 740275169, 169908711, 913969302, + 1359358267, 1328322971, 593228769, 771095186, + 801680440, 450930656, 1796349530, 1824428677, + 1111258504, 1741666629, 1098430204, 1792001884, + 1679003061, 590088446, 647614538, 1324461639, + 818996796, 229187928, 74288115, 1158900266, + 1512606270, 1381672753, 785927403, 493453164, + 425259497, 1367873539, 931023744, 221202218, + 669580668, 424996238, 1840425275, 1873362670, + 967642716, 263556335, 578560519, 1558449223, + 607579284, 1724012378, 333582342, 1195784167, + 1419727276, 199294290, 138807165, 1061030752, + 1, 0, 0, 0, + 1, 0, 0, 0, + 776332180, 1333076185, 1855163818, 1897408938, + 799274251, 950452503, 691904988, 1205387466, + 659107883, 434394982, 129587940, 639018629, + 659238594, 1957584892, 864291238, 589178070, + 1267157231, 48925338, 200093884, 1953762869, + 1227617341, 1471420621, 193077633, 1007876111, + 228491220, 1377349503, 1889411060, 1807513892, + 1593042934, 1240864695, 1472870721, 583021932, + 598239104, 1862008818, 1811242869, 780768026, + 520870395, 292016292, 322246659, 868240490, + 1715620331, 1183509209, 2010262726, 1003957251, + 264895455, 307755941, 201990485, 1662471178, + 1643997923, 1573129362, 277821143, 388834470, + 943361405, 1449402196, 614413575, 1504113993, + 1860552739, 1755127315, 1734129760, 1232115188, + 803035456, 360488092, 271342171, 1269544258, + 290642673, 660703582, 986842267, 870891877, + 454573044, 1999346236, 701614601, 820253867, + 883282765, 137247873, 1727164949, 1320585493, + 1738664600, 1900116905, 472215154, 1114994489, + 104218174, 1694603079, 771486383, 935361143, + 92277671, 881040480, 925124484, 1464396527, + 100625197, 65290355, 1001454341, 134627585, + 58629702, 1541542242, 568583607, 1706262052, + 530687550, 1303187245, 1010302462, 264001857, + 789816678, 561378226, 827432508, 801307507, + 1613508315, 1650822853, 1603502703, 439320335, + 15283580, 1244486577, 254345266, 1745653280, + 1648250354, 1528271018, 528366563, 1078707735, + 1430767759, 1890467731, 2001894083, 799949326, + 1, 0, 0, 0, + 1, 0, 0, 0, + 1341839494, 1092219735, 755644898, 966729319, + 1914277278, 1545367697, 1765189119, 1693413008, + ]; + + let logup_spec_evals: Array> = builder.dyn_array(logup_spec_eval_u32s.len() / EXT_DEG); + for (idx, n) in logup_spec_eval_u32s.chunks(EXT_DEG).enumerate() { + let e: Ext = builder.constant(C::EF::from_base_slice(&[ + C::F::from_canonical_u32(n[0]), + C::F::from_canonical_u32(n[1]), + C::F::from_canonical_u32(n[2]), + C::F::from_canonical_u32(n[3]) + ])); + + builder.set(&logup_spec_evals, idx, e); + } + + let r_evals_u32s = [ + 941378355u32, 1078920879, 696738840, 496039492, + 1555445457, 184545404, 905938226, 1847966044, + 1024875886, 1782716223, 1625644635, 266865456, + 465953066, 1663531470, 757423849, 1957075986, + 1919693393, 839104130, 127480221, 1527842912, + 918650796, 921462354, 575456073, 696646705, + 1585912361, 258186488, 353168830, 1111094691, + 1401166558, 1905942163, 1923083163, 393037255, + 1042127700, 1126793296, 895794165, 1124924482, + 1324266058, 722406365, 1963838171, 968504459, + 1934378800, 714588691, 6465911, 1168379648, + 903786009, 1326035939, 518289228, 418998914, + 1513133474, 1578096058, 617547414, 1658315126, + 68556894, 1697802593, 1346510664, 1709381671, + 345062962, 1254089535, 1002281845, 1882822096, + 700581748, 1431345304, 489112954, 98435728, + 1799886007, 479788390, 223111065, 631662309, + ]; + + let next_layer_evals: Array> = builder.dyn_array(r_evals_u32s.len() / EXT_DEG); + for (idx, n) in r_evals_u32s.chunks(EXT_DEG).enumerate() { + let e: Ext = builder.constant(C::EF::from_base_slice(&[ + C::F::from_canonical_u32(n[0]), + C::F::from_canonical_u32(n[1]), + C::F::from_canonical_u32(n[2]), + C::F::from_canonical_u32(n[3]) + ])); + + builder.set(&next_layer_evals, idx, e); + } + + builder.sumcheck_layer_eval(&ctx, &challenges, &prod_spec_evals, &logup_spec_evals, &next_layer_evals); +}