1
- use std:: { marker:: PhantomData , sync:: Arc } ;
1
+ use std:: { marker:: PhantomData , mem , sync:: Arc } ;
2
2
3
3
use ff_ext:: ExtensionField ;
4
4
use gkr_iop:: {
@@ -155,18 +155,16 @@ where
155
155
{
156
156
type Trace = TowerChipTrace ;
157
157
158
- fn phase1_witness ( & self , phase1 : & Self :: Trace ) -> Vec < Vec < E :: BaseField > > {
158
+ fn phase1_witness ( & self , phase1 : Self :: Trace ) -> Vec < Vec < E :: BaseField > > {
159
159
let mut res = vec ! [ vec![ ] ; 2 ] ;
160
160
res[ self . committed_table ] = phase1
161
161
. table
162
- . iter ( )
163
- . cloned ( )
162
+ . into_iter ( )
164
163
. map ( E :: BaseField :: from_u64)
165
164
. collect ( ) ;
166
165
res[ self . committed_count ] = phase1
167
166
. count
168
- . iter ( )
169
- . cloned ( )
167
+ . into_iter ( )
170
168
. map ( E :: BaseField :: from_u64)
171
169
. collect ( ) ;
172
170
res
@@ -187,40 +185,30 @@ where
187
185
layer_wits. push ( LayerWitness :: new ( vec ! [ table. clone( ) ] , vec ! [ ] ) ) ;
188
186
189
187
// Compute den_0, den_1, num_0, num_1 for each layer.
190
- let updated_table = table. iter ( ) . map ( |x| beta + * x) . collect_vec ( ) ;
191
- let mut last_den = vec ! [ ] ;
192
- let mut last_num = vec ! [ ] ;
193
-
194
- ( 0 ..self . params . height ) . for_each ( |i| {
195
- if i == 0 {
196
- let ( num_0, num_1) : ( Vec < E :: BaseField > , Vec < E :: BaseField > ) =
197
- count. chunks ( 2 ) . map ( |chunk| ( chunk[ 0 ] , chunk[ 1 ] ) ) . unzip ( ) ;
198
- let ( den_0, den_1) : ( Vec < E > , Vec < E > ) = updated_table
199
- . chunks ( 2 )
200
- . map ( |chunk| ( chunk[ 0 ] , chunk[ 1 ] ) )
201
- . unzip ( ) ;
202
- ( last_den, last_num) = izip ! ( & den_0, & den_1, & num_0, & num_1)
203
- . map ( |( den_0, den_1, num_0, num_1) | {
204
- ( * den_0 * * den_1, * den_0 * * num_1 + * den_1 * * num_0)
205
- } )
206
- . unzip ( ) ;
207
-
208
- layer_wits. push ( LayerWitness :: new ( vec ! [ num_0, num_1] , vec ! [ den_0, den_1] ) ) ;
209
- } else {
210
- let ( den_0, den_1) : ( Vec < E > , Vec < E > ) =
211
- last_den. chunks ( 2 ) . map ( |chunk| ( chunk[ 0 ] , chunk[ 1 ] ) ) . unzip ( ) ;
212
- let ( num_0, num_1) : ( Vec < E > , Vec < E > ) =
213
- last_num. chunks ( 2 ) . map ( |chunk| ( chunk[ 0 ] , chunk[ 1 ] ) ) . unzip ( ) ;
214
-
215
- ( last_den, last_num) = izip ! ( & den_0, & den_1, & num_0, & num_1)
216
- . map ( |( den_0, den_1, num_0, num_1) | {
217
- ( * den_0 * * den_1, * den_0 * * num_1 + * den_1 * * num_0)
218
- } )
219
- . unzip ( ) ;
220
-
221
- layer_wits. push ( LayerWitness :: new ( vec ! [ ] , vec ! [ den_0, den_1, num_0, num_1] ) ) ;
222
- }
223
- } ) ;
188
+ let updated_table = table. iter ( ) . cloned ( ) . map ( |x| beta + x) . collect_vec ( ) ;
189
+
190
+ let ( num_0, num_1) : ( Vec < E :: BaseField > , Vec < E :: BaseField > ) = count. iter ( ) . tuples ( ) . unzip ( ) ;
191
+ let ( den_0, den_1) : ( Vec < E > , Vec < E > ) = updated_table. into_iter ( ) . tuples ( ) . unzip ( ) ;
192
+ let ( mut last_den, mut last_num) : ( Vec < _ > , Vec < _ > ) = izip ! ( & den_0, & den_1, & num_0, & num_1)
193
+ . map ( |( & den_0, & den_1, & num_0, & num_1) | ( den_0 * den_1, den_0 * num_1 + den_1 * num_0) )
194
+ . unzip ( ) ;
195
+
196
+ layer_wits. push ( LayerWitness :: new ( vec ! [ num_0, num_1] , vec ! [ den_0, den_1] ) ) ;
197
+
198
+ layer_wits. extend ( ( 1 ..self . params . height ) . map ( |_i| {
199
+ let ( den_0, den_1) : ( Vec < E > , Vec < E > ) =
200
+ mem:: take ( & mut last_den) . into_iter ( ) . tuples ( ) . unzip ( ) ;
201
+ let ( num_0, num_1) : ( Vec < E > , Vec < E > ) =
202
+ mem:: take ( & mut last_num) . into_iter ( ) . tuples ( ) . unzip ( ) ;
203
+
204
+ ( last_den, last_num) = izip ! ( & den_0, & den_1, & num_0, & num_1)
205
+ . map ( |( & den_0, & den_1, & num_0, & num_1) | {
206
+ ( den_0 * den_1, den_0 * num_1 + den_1 * num_0)
207
+ } )
208
+ . unzip ( ) ;
209
+
210
+ LayerWitness :: new ( vec ! [ ] , vec ! [ den_0, den_1, num_0, num_1] )
211
+ } ) ) ;
224
212
layer_wits. reverse ( ) ;
225
213
226
214
GKRCircuitWitness { layers : layer_wits }
@@ -240,7 +228,7 @@ fn main() {
240
228
let count = ( 0 ..1 << log_size)
241
229
. map ( |_| OsRng . gen_range ( 0 ..1 << log_size as u64 ) )
242
230
. collect_vec ( ) ;
243
- let phase1_witness = layout. phase1_witness ( & TowerChipTrace { table, count } ) ;
231
+ let phase1_witness = layout. phase1_witness ( TowerChipTrace { table, count } ) ;
244
232
245
233
let mut prover_transcript = BasicTranscript :: < E > :: new ( b"protocol" ) ;
246
234
0 commit comments