diff --git a/extensions/native/circuit/cuda/include/native/sumcheck.cuh b/extensions/native/circuit/cuda/include/native/sumcheck.cuh index 052dc03fd5..da7aefd769 100644 --- a/extensions/native/circuit/cuda/include/native/sumcheck.cuh +++ b/extensions/native/circuit/cuda/include/native/sumcheck.cuh @@ -8,14 +8,14 @@ using namespace native; template struct HeaderSpecificCols { T pc; T registers[5]; - MemoryReadAuxCols read_records[7]; + MemoryReadAuxCols read_records[8]; MemoryWriteAuxCols write_records; }; template struct ProdSpecificCols { T data_ptr; T p[EXT_DEG * 2]; - MemoryReadAuxCols read_records[2]; + MemoryReadAuxCols read_records[1]; T p_evals[EXT_DEG]; MemoryWriteAuxCols write_record; T eval_rlc[EXT_DEG]; @@ -24,7 +24,7 @@ template struct ProdSpecificCols { template struct LogupSpecificCols { T data_ptr; T pq[EXT_DEG * 4]; - MemoryReadAuxCols read_records[2]; + MemoryReadAuxCols read_records[1]; T p_evals[EXT_DEG]; T q_evals[EXT_DEG]; MemoryWriteAuxCols write_records[2]; diff --git a/extensions/native/circuit/cuda/src/sumcheck.cu b/extensions/native/circuit/cuda/src/sumcheck.cu index 99c365135b..173abebbe2 100644 --- a/extensions/native/circuit/cuda/src/sumcheck.cu +++ b/extensions/native/circuit/cuda/src/sumcheck.cu @@ -11,7 +11,7 @@ __device__ void fill_sumcheck_specific(RowSlice row, MemoryAuxColsFactory &mem_h uint32_t start_timestamp = row[COL_INDEX(NativeSumcheckCols, start_timestamp)].asUInt32(); if (row[COL_INDEX(NativeSumcheckCols, header_row)] == Fp::one()) { - for (uint32_t i = 0; i < 7; ++i) { + for (uint32_t i = 0; i < 8; ++i) { mem_fill_base( mem_helper, start_timestamp + i, @@ -25,43 +25,33 @@ __device__ void fill_sumcheck_specific(RowSlice row, MemoryAuxColsFactory &mem_h specific.slice_from(COL_INDEX(HeaderSpecificCols, write_records.base)) ); } else if (row[COL_INDEX(NativeSumcheckCols, prod_row)] == Fp::one()) { - mem_fill_base( - mem_helper, - start_timestamp, - specific.slice_from(COL_INDEX(ProdSpecificCols, read_records[0].base)) - ); if (row[COL_INDEX(NativeSumcheckCols, within_round_limit)] == Fp::one()) { mem_fill_base( mem_helper, - start_timestamp + 1, - specific.slice_from(COL_INDEX(ProdSpecificCols, read_records[1].base)) + start_timestamp, + specific.slice_from(COL_INDEX(ProdSpecificCols, read_records[0].base)) ); mem_fill_base( mem_helper, - start_timestamp + 2, + start_timestamp + 1, specific.slice_from(COL_INDEX(ProdSpecificCols, write_record.base)) ); } } else if (row[COL_INDEX(NativeSumcheckCols, logup_row)] == Fp::one()) { - mem_fill_base( - mem_helper, - start_timestamp, - specific.slice_from(COL_INDEX(LogupSpecificCols, read_records[0].base)) - ); if (row[COL_INDEX(NativeSumcheckCols, within_round_limit)] == Fp::one()) { mem_fill_base( mem_helper, - start_timestamp + 1, - specific.slice_from(COL_INDEX(LogupSpecificCols, read_records[1].base)) + start_timestamp, + specific.slice_from(COL_INDEX(LogupSpecificCols, read_records[0].base)) ); mem_fill_base( mem_helper, - start_timestamp + 2, + start_timestamp + 1, specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[0].base)) ); mem_fill_base( mem_helper, - start_timestamp + 3, + start_timestamp + 2, specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[1].base)) ); } diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 1cae3847ca..5bec217a1f 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -167,6 +167,9 @@ impl Air for NativeSumcheckAir { alpha, next.alpha, ); + builder + .when(next.prod_row + next.logup_row) + .assert_eq(max_round, next.max_round); builder .when(next.prod_row + next.logup_row) .assert_eq(prod_nested_len, next.prod_nested_len); @@ -223,21 +226,21 @@ impl Air for NativeSumcheckAir { .when(next.prod_row + next.logup_row) .assert_eq( next.start_timestamp, - start_timestamp + AB::F::from_canonical_usize(7), + start_timestamp + AB::F::from_canonical_usize(8), ); builder .when(prod_row) .when(next.prod_row + next.logup_row) .assert_eq( next.start_timestamp, - start_timestamp + AB::F::ONE + within_round_limit * AB::F::TWO, + start_timestamp + within_round_limit * AB::F::TWO, ); builder .when(logup_row) .when(next.prod_row + next.logup_row) .assert_eq( next.start_timestamp, - start_timestamp + AB::F::ONE + within_round_limit * AB::F::from_canonical_usize(3), + start_timestamp + within_round_limit * AB::F::from_canonical_usize(3), ); // Termination condition @@ -330,6 +333,19 @@ impl Air for NativeSumcheckAir { ) .eval(builder, header_row); + // Read max_round + self.memory_bridge + .read( + MemoryAddress::new( + native_as, + register_ptrs[0] + AB::F::from_canonical_usize(CONTEXT_ARR_BASE_LEN), + ), + [max_round], + first_timestamp + AB::F::from_canonical_usize(7), + &header_row_specific.read_records[7], + ) + .eval(builder, header_row); + // Write final result self.memory_bridge .write( @@ -348,20 +364,6 @@ impl Air for NativeSumcheckAir { let next_prod_row_specific: &ProdSpecificCols = next.specific[..ProdSpecificCols::::width()].borrow(); - self.memory_bridge - .read( - MemoryAddress::new( - native_as, - register_ptrs[0] - + AB::F::from_canonical_usize(CONTEXT_ARR_BASE_LEN) - + (curr_prod_n - AB::F::ONE), - ), // curr_prod_n starts at 1. - [max_round], - start_timestamp, - &prod_row_specific.read_records[0], - ) - .eval(builder, prod_row); - // prod_row * within_round_limit = // prod_in_round_evaluation + prod_next_round_evaluation builder @@ -385,8 +387,8 @@ impl Air for NativeSumcheckAir { .read( MemoryAddress::new(native_as, register_ptrs[2] + prod_row_specific.data_ptr), prod_row_specific.p, - start_timestamp + AB::F::ONE, - &prod_row_specific.read_records[1], + start_timestamp, + &prod_row_specific.read_records[0], ) .eval(builder, prod_row * within_round_limit); @@ -402,7 +404,7 @@ impl Air for NativeSumcheckAir { register_ptrs[4] + curr_prod_n * AB::F::from_canonical_usize(EXT_DEG), ), prod_row_specific.p_evals, - start_timestamp + AB::F::TWO, + start_timestamp + AB::F::ONE, &prod_row_specific.write_record, ) .eval(builder, prod_row * within_round_limit); @@ -449,21 +451,6 @@ impl Air for NativeSumcheckAir { let next_logup_row_specfic: &LogupSpecificCols = next.specific[..LogupSpecificCols::::width()].borrow(); - self.memory_bridge - .read( - MemoryAddress::new( - native_as, - register_ptrs[0] - + AB::F::from_canonical_usize(EXT_DEG * 2) - + num_prod_spec - + (curr_logup_n - AB::F::ONE), - ), // curr_logup_n starts at 1. - [max_round], - start_timestamp, - &logup_row_specific.read_records[0], - ) - .eval(builder, logup_row); - // logup_row * within_round_limit = // logup_in_round_evaluation + logup_next_round_evaluation builder @@ -488,8 +475,8 @@ impl Air for NativeSumcheckAir { .read( MemoryAddress::new(native_as, register_ptrs[3] + logup_row_specific.data_ptr), logup_row_specific.pq, - start_timestamp + AB::F::ONE, - &logup_row_specific.read_records[1], + start_timestamp, + &logup_row_specific.read_records[0], ) .eval(builder, logup_row * within_round_limit); @@ -513,7 +500,7 @@ impl Air for NativeSumcheckAir { + (num_prod_spec + curr_logup_n) * AB::F::from_canonical_usize(EXT_DEG), ), logup_row_specific.p_evals, - start_timestamp + AB::F::TWO, + start_timestamp + AB::F::ONE, &logup_row_specific.write_records[0], ) .eval(builder, logup_row * within_round_limit); @@ -528,7 +515,7 @@ impl Air for NativeSumcheckAir { * AB::F::from_canonical_usize(EXT_DEG), ), logup_row_specific.q_evals, - start_timestamp + AB::F::from_canonical_usize(3), + start_timestamp + AB::F::TWO, &logup_row_specific.write_records[1], ) .eval(builder, logup_row * within_round_limit); diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index bb7cfa7080..d5e6f49a62 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -144,7 +144,6 @@ where let [ctx_ptr]: [F; 1] = memory_read_native(state.memory.data(), ctx_reg.as_canonical_u32()); let ctx: [u32; 8] = memory_read_native(state.memory.data(), ctx_ptr.as_canonical_u32()) .map(|x: F| x.as_canonical_u32()); - let [round, num_prod_spec, num_logup_spec, prod_specs_inner_len, prod_specs_inner_inner_len, logup_specs_inner_len, logup_specs_inner_inner_len, mode] = ctx; // allocate n rows @@ -198,19 +197,22 @@ where r_evals_reg.as_canonical_u32(), head_specific.read_records[4].as_mut(), ); - let ctx: [F; CONTEXT_ARR_BASE_LEN] = tracing_read_native_helper( state.memory, ctx_ptr.as_canonical_u32(), head_specific.read_records[5].as_mut(), ); - let challenges: [F; EXT_DEG * 4] = tracing_read_native_helper( state.memory, challenges_ptr.as_canonical_u32(), head_specific.read_records[6].as_mut(), ); - cur_timestamp += 7; // 5 register reads + ctx read + challenges read + let [max_round]: [F; 1] = tracing_read_native_helper( + state.memory, + ctx_ptr.as_canonical_u32() + CONTEXT_ARR_BASE_LEN as u32, + head_specific.read_records[7].as_mut(), + ); + cur_timestamp += 8; // 5 register reads + ctx read + challenges read + max_round read head_row.challenges.copy_from_slice(&challenges); // challenges = [alpha, c1=r, c2=1-r] @@ -221,7 +223,7 @@ where let mut eval_acc = elem_to_ext(F::from_canonical_u32(0)); let mut alpha_acc = elem_to_ext(F::from_canonical_u32(1)); - // all rows share same register values, ctx, challenges + // all rows share same register values, ctx, challenges, max_round for row in rows.iter_mut() { // c1, c2 are same during the entire execution row.challenges[EXT_DEG..3 * EXT_DEG].copy_from_slice(&challenges[EXT_DEG..3 * EXT_DEG]); @@ -236,6 +238,7 @@ where row.register_ptrs[2] = prod_evals_ptr; row.register_ptrs[3] = logup_evals_ptr; row.register_ptrs[4] = r_evals_ptr; + row.max_round = max_round; } // product rows @@ -256,17 +259,7 @@ where }; prod_row.curr_prod_n = F::from_canonical_usize(i + 1); // curr_prod_n starts from 1 prod_row.start_timestamp = F::from_canonical_u32(cur_timestamp); - - // read max_round - let [max_round]: [F; 1] = tracing_read_native_helper( - state.memory, - ctx_ptr.as_canonical_u32() + (CONTEXT_ARR_BASE_LEN + i) as u32, - prod_specific.read_records[0].as_mut(), - ); - cur_timestamp += 1; - prod_row.challenges[0..EXT_DEG].copy_from_slice(&alpha_acc); - prod_row.max_round = max_round; let max_round = max_round.as_canonical_u32(); // round starts from 0 @@ -285,7 +278,7 @@ where let ps: [F; EXT_DEG * 2] = tracing_read_native_helper( state.memory, prod_evals_ptr.as_canonical_u32() + start, - prod_specific.read_records[1].as_mut(), + prod_specific.read_records[0].as_mut(), ); let p1: [F; EXT_DEG] = ps[0..EXT_DEG].try_into().unwrap(); let p2: [F; EXT_DEG] = ps[EXT_DEG..(EXT_DEG * 2)].try_into().unwrap(); @@ -350,15 +343,6 @@ where }; logup_row.curr_logup_n = F::from_canonical_usize(i + 1); // curr_logup_n starts from 1 logup_row.start_timestamp = F::from_canonical_u32(cur_timestamp); - - let [max_round]: [F; 1] = tracing_read_native_helper( - state.memory, - ctx_ptr.as_canonical_u32() + num_prod_spec + (CONTEXT_ARR_BASE_LEN + i) as u32, - logup_specific.read_records[0].as_mut(), - ); - logup_row.max_round = max_round; - cur_timestamp += 1; - let alpha_numerator = alpha_acc; let alpha_denominator = FieldExtension::multiply(alpha_acc, alpha); logup_row.challenges[0..EXT_DEG].copy_from_slice(&alpha_acc); @@ -380,7 +364,7 @@ where let pqs: [F; EXT_DEG * 4] = tracing_read_native_helper( state.memory, logup_evals_ptr.as_canonical_u32() + start, - logup_specific.read_records[1].as_mut(), + logup_specific.read_records[0].as_mut(), ); let p1: [F; EXT_DEG] = pqs[0..EXT_DEG].try_into().unwrap(); let p2: [F; EXT_DEG] = pqs[EXT_DEG..(EXT_DEG * 2)].try_into().unwrap(); @@ -518,7 +502,7 @@ impl TraceFiller for NativeSumcheckFiller { let header: &mut HeaderSpecificCols = cols.specific[..HeaderSpecificCols::::width()].borrow_mut(); - for i in 0..7usize { + for i in 0..8usize { mem_fill_helper( mem_helper, start_timestamp + i as u32, @@ -534,23 +518,17 @@ impl TraceFiller for NativeSumcheckFiller { let prod_row_specific: &mut ProdSpecificCols = cols.specific[..ProdSpecificCols::::width()].borrow_mut(); - // read max_round - mem_fill_helper( - mem_helper, - start_timestamp, - prod_row_specific.read_records[0].as_mut(), - ); if cols.within_round_limit == F::ONE { // read p1, p2 mem_fill_helper( mem_helper, - start_timestamp + 1, - prod_row_specific.read_records[1].as_mut(), + start_timestamp, + prod_row_specific.read_records[0].as_mut(), ); // write p_eval mem_fill_helper( mem_helper, - start_timestamp + 2, + start_timestamp + 1, prod_row_specific.write_record.as_mut(), ); } @@ -558,29 +536,23 @@ impl TraceFiller for NativeSumcheckFiller { let logup_row_specific: &mut LogupSpecificCols = cols.specific[..LogupSpecificCols::::width()].borrow_mut(); - // read max_round - mem_fill_helper( - mem_helper, - start_timestamp, - logup_row_specific.read_records[0].as_mut(), - ); if cols.within_round_limit == F::ONE { // read p1, p2, q1, q2 mem_fill_helper( mem_helper, - start_timestamp + 1, - logup_row_specific.read_records[1].as_mut(), + start_timestamp, + logup_row_specific.read_records[0].as_mut(), ); // write p_eval mem_fill_helper( mem_helper, - start_timestamp + 2, + start_timestamp + 1, logup_row_specific.write_records[0].as_mut(), ); // write q_eval mem_fill_helper( mem_helper, - start_timestamp + 3, + start_timestamp + 2, logup_row_specific.write_records[1].as_mut(), ); } diff --git a/extensions/native/circuit/src/sumcheck/columns.rs b/extensions/native/circuit/src/sumcheck/columns.rs index b3e6bf4f25..51eb6d39cf 100644 --- a/extensions/native/circuit/src/sumcheck/columns.rs +++ b/extensions/native/circuit/src/sumcheck/columns.rs @@ -92,8 +92,8 @@ pub struct NativeSumcheckCols { pub struct HeaderSpecificCols { pub pc: T, pub registers: [T; 5], - /// 5 register reads + ctx read + challenges read - pub read_records: [MemoryReadAuxCols; 7], + /// 5 register reads + ctx read + max round read + challenges read + pub read_records: [MemoryReadAuxCols; 8], /// Write the final evaluation pub write_records: MemoryWriteAuxCols, } @@ -105,8 +105,8 @@ pub struct ProdSpecificCols { pub data_ptr: T, /// 2 extension elements pub p: [T; EXT_DEG * 2], - /// read max varibale and 2 p values - pub read_records: [MemoryReadAuxCols; 2], + /// read 2 p values + pub read_records: [MemoryReadAuxCols; 1], /// Calculated p evals pub p_evals: [T; EXT_DEG], /// write p_evals @@ -122,8 +122,8 @@ pub struct LogupSpecificCols { pub data_ptr: T, /// 4 extension elements pub pq: [T; EXT_DEG * 4], - /// read max variable and 4 values: p1, p2, q1, q2 - pub read_records: [MemoryReadAuxCols; 2], + /// read 4 values: p1, p2, q1, q2 + pub read_records: [MemoryReadAuxCols; 1], /// Calculated p evals pub p_evals: [T; EXT_DEG], /// Calculated q evals diff --git a/extensions/native/circuit/src/sumcheck/execution.rs b/extensions/native/circuit/src/sumcheck/execution.rs index a475bf9e49..54e06e540c 100644 --- a/extensions/native/circuit/src/sumcheck/execution.rs +++ b/extensions/native/circuit/src/sumcheck/execution.rs @@ -214,6 +214,8 @@ unsafe fn execute_e12_impl( ctx; let challenges: [F; EXT_DEG * 4] = exec_state.vm_read(NATIVE_AS, challenges_ptr.as_canonical_u32()); + let [max_round]: [u32; 1] = + exec_state.vm_read(NATIVE_AS, ctx_ptr_u32 + CONTEXT_ARR_BASE_LEN as u32); let alpha: [F; EXT_DEG] = challenges[0..EXT_DEG].try_into().unwrap(); let c1: [F; EXT_DEG] = challenges[EXT_DEG..EXT_DEG * 2].try_into().unwrap(); let c2: [F; EXT_DEG] = challenges[EXT_DEG * 2..EXT_DEG * 3].try_into().unwrap(); @@ -222,12 +224,7 @@ unsafe fn execute_e12_impl( let mut alpha_acc = elem_to_ext(F::ONE); let mut eval_acc = elem_to_ext(F::ZERO); - let prod_offset = ctx_ptr_u32 + CONTEXT_ARR_BASE_LEN as u32; for i in 0..num_prod_spec { - let [max_round]: [u32; 1] = exec_state - .vm_read(NATIVE_AS, prod_offset + i) - .map(|x: F| x.as_canonical_u32()); - let start = calculate_3d_ext_idx( prod_specs_inner_inner_len, prod_specs_inner_len, @@ -264,12 +261,7 @@ unsafe fn execute_e12_impl( height += 1; } - let logup_offset = ctx_ptr_u32 + CONTEXT_ARR_BASE_LEN as u32 + num_prod_spec; for i in 0..num_logup_spec { - // read max_round - let [max_round]: [u32; 1] = exec_state - .vm_read(NATIVE_AS, logup_offset + i) - .map(|x: F| x.as_canonical_u32()); let start = calculate_3d_ext_idx( logup_specs_inner_inner_len, logup_specs_inner_len, diff --git a/extensions/native/recursion/tests/sumcheck.rs b/extensions/native/recursion/tests/sumcheck.rs index a500ee6aac..284a103021 100644 --- a/extensions/native/recursion/tests/sumcheck.rs +++ b/extensions/native/recursion/tests/sumcheck.rs @@ -103,7 +103,7 @@ fn build_test_program(builder: &mut Builder) { let ctx: Array> = builder.dyn_array(ctx_u32s.len()); for (idx, n) in ctx_u32s.into_iter().enumerate() { - builder.set(&ctx, idx, Usize::from(n as usize)); + builder.set(&ctx, idx, Usize::from(n)); } #[rustfmt::skip] @@ -197,9 +197,9 @@ fn build_test_program(builder: &mut Builder) { } let r_evals = once(eval_acc) - .chain(p_evals.into_iter()) - .chain(logup_p_evals.into_iter()) - .chain(logup_q_evals.into_iter()) + .chain(p_evals) + .chain(logup_p_evals) + .chain(logup_q_evals) .collect::>(); let next_layer_evals: Array> = builder.dyn_array(r_evals.len());