@@ -88,9 +88,10 @@ module type Parameters = sig
88
88
val first_byte_bits : int option
89
89
end
90
90
91
- type point = { f_x : field_element ; f_y : field_element ; f_z : field_element }
92
-
93
- type out_point = { m_f_x : out_field_element ; m_f_y : out_field_element ; m_f_z : out_field_element }
91
+ module Point_proj = struct
92
+ type point = { f_x : field_element ; f_y : field_element ; f_z : field_element }
93
+ type out_point = { m_f_x : out_field_element ; m_f_y : out_field_element ; m_f_z : out_field_element }
94
+ end
94
95
95
96
type scalar = Scalar of string
96
97
@@ -107,13 +108,19 @@ module type Foreign = sig
107
108
val to_octets : bytes -> field_element -> unit
108
109
val inv : out_field_element -> field_element -> unit
109
110
val select_c : out_field_element -> bool -> field_element -> field_element -> unit
111
+ end
112
+
113
+ module type Foreign_proj = sig
114
+ include Foreign
115
+ open Point_proj
110
116
111
117
val double_c : out_point -> point -> unit
112
118
val add_c : out_point -> point -> point -> unit
113
119
val scalar_mult_base_c : out_point -> string -> unit
114
120
end
115
121
116
122
module type Field_element = sig
123
+ val create : unit -> out_field_element
117
124
val mul : field_element -> field_element -> field_element
118
125
val sub : field_element -> field_element -> field_element
119
126
val add : field_element -> field_element -> field_element
@@ -126,9 +133,6 @@ module type Field_element = sig
126
133
val select : bool -> then_ :field_element -> else_ :field_element -> field_element
127
134
val from_be_octets : string -> field_element
128
135
val to_octets : field_element -> string
129
- val double_point : point -> point
130
- val add_point : point -> point -> point
131
- val scalar_mult_base_point : scalar -> point
132
136
end
133
137
134
138
module Make_field_element (P : Parameters ) (F : Foreign ) : Field_element = struct
@@ -196,51 +200,24 @@ module Make_field_element (P : Parameters) (F : Foreign) : Field_element = struc
196
200
let tmp = create_octets () in
197
201
F. to_octets tmp fe;
198
202
b_uts tmp
199
-
200
- let out_point () = {
201
- m_f_x = create () ;
202
- m_f_y = create () ;
203
- m_f_z = create () ;
204
- }
205
-
206
- let out_p_to_p p = {
207
- f_x = b_uts p.m_f_x ;
208
- f_y = b_uts p.m_f_y ;
209
- f_z = b_uts p.m_f_z ;
210
- }
211
-
212
- let double_point p =
213
- let tmp = out_point () in
214
- F. double_c tmp p;
215
- out_p_to_p tmp
216
-
217
- let add_point a b =
218
- let tmp = out_point () in
219
- F. add_c tmp a b;
220
- out_p_to_p tmp
221
-
222
- let scalar_mult_base_point (Scalar d ) =
223
- let tmp = out_point () in
224
- F. scalar_mult_base_c tmp d;
225
- out_p_to_p tmp
226
203
end
227
204
228
205
module type Point = sig
229
- val at_infinity : unit -> point
206
+ type point
230
207
val is_infinity : point -> bool
231
- val add : point -> point -> point
232
- val double : point -> point
233
208
val of_octets : string -> (point , error ) result
234
209
val to_octets : compress :bool -> point -> string
235
210
val to_affine_raw : point -> (field_element * field_element ) option
236
211
val x_of_finite_point : point -> string
237
- val params_g : point
238
- val select : bool -> then_ : point -> else_ : point -> point
212
+ val scalar_mult : scalar -> point -> point
213
+ val scalar_mult_add : scalar -> scalar -> point -> point
239
214
val scalar_mult_base : scalar -> point
215
+ val generator_tables : unit -> string array array array
240
216
end
241
217
242
- module Make_point (P : Parameters ) (F : Foreign ) : Point = struct
218
+ module Make_point (P : Parameters ) (F : Foreign_proj ) : Point = struct
243
219
module Fe = Make_field_element (P )(F )
220
+ include Point_proj
244
221
245
222
let at_infinity () =
246
223
let f_x = Fe. one in
@@ -273,7 +250,7 @@ module Make_point (P : Parameters) (F : Foreign) : Point = struct
273
250
(* * Convert coordinates to a finite point ensuring:
274
251
- x < p
275
252
- y < p
276
- - y^2 = ax ^3 + ax + b
253
+ - y^2 = x ^3 + ax + b
277
254
*)
278
255
let validate_finite_point ~x ~y =
279
256
match (check_coordinate x, check_coordinate y) with
@@ -325,9 +302,34 @@ module Make_point (P : Parameters) (F : Foreign) : Point = struct
325
302
else
326
303
buf
327
304
328
- let double p = Fe. double_point p
305
+ let out_point () = {
306
+ m_f_x = Fe. create () ;
307
+ m_f_y = Fe. create () ;
308
+ m_f_z = Fe. create () ;
309
+ }
310
+
311
+ let out_p_to_p p =
312
+ let b_uts b = Bytes. unsafe_to_string b in
313
+ {
314
+ f_x = b_uts p.m_f_x ;
315
+ f_y = b_uts p.m_f_y ;
316
+ f_z = b_uts p.m_f_z ;
317
+ }
318
+
319
+ let double p =
320
+ let tmp = out_point () in
321
+ F. double_c tmp p;
322
+ out_p_to_p tmp
323
+
324
+ let add a b =
325
+ let tmp = out_point () in
326
+ F. add_c tmp a b;
327
+ out_p_to_p tmp
329
328
330
- let add p q = Fe. add_point p q
329
+ let scalar_mult_base (Scalar d ) =
330
+ let tmp = out_point () in
331
+ F. scalar_mult_base_c tmp d;
332
+ out_p_to_p tmp
331
333
332
334
let x_of_finite_point p =
333
335
match to_affine p with None -> assert false | Some (x , _ ) -> rev_string x
@@ -418,17 +420,50 @@ module Make_point (P : Parameters) (F : Foreign) : Point = struct
418
420
| 0x00 | 0x04 -> Error `Invalid_length
419
421
| _ -> Error `Invalid_format
420
422
421
- let scalar_mult_base = Fe. scalar_mult_base_point
423
+ (* Branchless Montgomery ladder method *)
424
+ let scalar_mult (Scalar s ) p =
425
+ let r0 = ref (at_infinity () ) in
426
+ let r1 = ref p in
427
+ for i = P. byte_length * 8 - 1 downto 0 do
428
+ let bit = bit_at s i in
429
+ let sum = add ! r0 ! r1 in
430
+ let r0_double = double ! r0 in
431
+ let r1_double = double ! r1 in
432
+ r0 := select bit ~then_: sum ~else_: r0_double;
433
+ r1 := select bit ~then_: r1_double ~else_: sum
434
+ done ;
435
+ ! r0
436
+
437
+ let scalar_mult_add a b p =
438
+ add (scalar_mult_base a) (scalar_mult b p)
439
+
440
+ (* Pre-compute multiples of the generator point
441
+ returns the tables along with the number of significant bytes *)
442
+ let generator_tables () =
443
+ let len = P. fe_length * 2 in
444
+ let one_table _ = Array. init 15 (fun _ -> at_infinity () ) in
445
+ let table = Array. init len one_table in
446
+ let base = ref params_g in
447
+ for i = 0 to len - 1 do
448
+ table.(i).(0 ) < - ! base;
449
+ for j = 1 to 14 do
450
+ table.(i).(j) < - add ! base table.(i).(j - 1 )
451
+ done ;
452
+ base := double ! base;
453
+ base := double ! base;
454
+ base := double ! base;
455
+ base := double ! base
456
+ done ;
457
+ let convert {f_x; f_y; f_z} = [|f_x; f_y; f_z|] in
458
+ Array. map (Array. map convert) table
459
+
422
460
end
423
461
424
462
module type Scalar = sig
425
463
val not_zero : string -> bool
426
464
val is_in_range : string -> bool
427
465
val of_octets : string -> (scalar , error ) result
428
466
val to_octets : scalar -> string
429
- val scalar_mult : scalar -> point -> point
430
- val scalar_mult_base : scalar -> point
431
- val generator_tables : unit -> field_element array array array
432
467
end
433
468
434
469
module Make_scalar (Param : Parameters ) (P : Point ) : Scalar = struct
@@ -447,43 +482,6 @@ module Make_scalar (Param : Parameters) (P : Point) : Scalar = struct
447
482
| false -> Error `Invalid_range
448
483
449
484
let to_octets (Scalar buf ) = rev_string buf
450
-
451
- (* Branchless Montgomery ladder method *)
452
- let scalar_mult (Scalar s ) p =
453
- let r0 = ref (P. at_infinity () ) in
454
- let r1 = ref p in
455
- for i = Param. byte_length * 8 - 1 downto 0 do
456
- let bit = bit_at s i in
457
- let sum = P. add ! r0 ! r1 in
458
- let r0_double = P. double ! r0 in
459
- let r1_double = P. double ! r1 in
460
- r0 := P. select bit ~then_: sum ~else_: r0_double;
461
- r1 := P. select bit ~then_: r1_double ~else_: sum
462
- done ;
463
- ! r0
464
-
465
- (* Specialization of [scalar_mult d p] when [p] is the generator *)
466
- let scalar_mult_base = P. scalar_mult_base
467
-
468
- (* Pre-compute multiples of the generator point
469
- returns the tables along with the number of significant bytes *)
470
- let generator_tables () =
471
- let len = Param. fe_length * 2 in
472
- let one_table _ = Array. init 15 (fun _ -> P. at_infinity () ) in
473
- let table = Array. init len one_table in
474
- let base = ref P. params_g in
475
- for i = 0 to len - 1 do
476
- table.(i).(0 ) < - ! base;
477
- for j = 1 to 14 do
478
- table.(i).(j) < - P. add ! base table.(i).(j - 1 )
479
- done ;
480
- base := P. double ! base;
481
- base := P. double ! base;
482
- base := P. double ! base;
483
- base := P. double ! base
484
- done ;
485
- let convert {f_x; f_y; f_z} = [|f_x; f_y; f_z|] in
486
- Array. map (Array. map convert) table
487
485
end
488
486
489
487
module Make_dh (Param : Parameters ) (P : Point ) (S : Scalar ) : Dh = struct
@@ -498,7 +496,7 @@ module Make_dh (Param : Parameters) (P : Point) (S : Scalar) : Dh = struct
498
496
type secret = scalar
499
497
500
498
let share ?(compress = false ) private_key =
501
- let public_key = S . scalar_mult_base private_key in
499
+ let public_key = P . scalar_mult_base private_key in
502
500
point_to_octets ~compress public_key
503
501
504
502
let secret_of_octets ?compress s =
@@ -522,7 +520,7 @@ module Make_dh (Param : Parameters) (P : Point) (S : Scalar) : Dh = struct
522
520
let key_exchange secret received =
523
521
match point_of_octets received with
524
522
| Error _ as err -> err
525
- | Ok shared -> Ok (P. x_of_finite_point (S . scalar_mult secret shared))
523
+ | Ok shared -> Ok (P. x_of_finite_point (P . scalar_mult secret shared))
526
524
end
527
525
528
526
module type Foreign_n = sig
@@ -671,7 +669,7 @@ module Make_dsa (Param : Parameters) (F : Fn) (P : Point) (S : Scalar) (H : Dige
671
669
672
670
module K_gen_default = K_gen (H )
673
671
674
- type pub = point
672
+ type pub = P . point
675
673
676
674
let pub_of_octets = P. of_octets
677
675
@@ -687,16 +685,14 @@ module Make_dsa (Param : Parameters) (F : Fn) (P : Point) (S : Scalar) (H : Dige
687
685
in
688
686
one ()
689
687
in
690
- let q = S . scalar_mult_base d in
688
+ let q = P . scalar_mult_base d in
691
689
(d, q)
692
690
693
691
let x_of_finite_point_mod_n p =
694
692
match P. to_affine_raw p with
695
693
| None -> None
696
694
| Some (x , _ ) ->
697
- let x = F. to_montgomery x in
698
695
let x = F. mul x F. one in
699
- let x = F. from_montgomery x in
700
696
Some (F. to_be_octets x)
701
697
702
698
let sign ~key ?k msg =
@@ -714,7 +710,7 @@ module Make_dsa (Param : Parameters) (F : Fn) (P : Point) (S : Scalar) (H : Dige
714
710
| Ok ksc -> ksc
715
711
| Error _ -> invalid_arg " k not in range" (* if no k is provided, this cannot happen since K_gen_*.gen already preserves the Scalar invariants *)
716
712
in
717
- let point = S . scalar_mult_base ksc in
713
+ let point = P . scalar_mult_base ksc in
718
714
match x_of_finite_point_mod_n point with
719
715
| None -> again ()
720
716
| Some r ->
@@ -734,7 +730,7 @@ module Make_dsa (Param : Parameters) (F : Fn) (P : Point) (S : Scalar) (H : Dige
734
730
in
735
731
do_sign g
736
732
737
- let pub_of_priv priv = S . scalar_mult_base priv
733
+ let pub_of_priv priv = P . scalar_mult_base priv
738
734
739
735
let verify ~key (r , s ) msg =
740
736
try
@@ -756,10 +752,7 @@ module Make_dsa (Param : Parameters) (F : Fn) (P : Point) (S : Scalar) (H : Dige
756
752
S. of_octets (F. to_be_octets u2)
757
753
with
758
754
| Ok u1 , Ok u2 ->
759
- let point =
760
- P. add
761
- (S. scalar_mult_base u1)
762
- (S. scalar_mult u2 key)
755
+ let point = P. scalar_mult_add u1 u2 key
763
756
in
764
757
begin match x_of_finite_point_mod_n point with
765
758
| None -> false (* point is infinity *)
@@ -770,7 +763,7 @@ module Make_dsa (Param : Parameters) (F : Fn) (P : Point) (S : Scalar) (H : Dige
770
763
| Message_too_long -> false
771
764
772
765
module Precompute = struct
773
- let generator_tables = S . generator_tables
766
+ let generator_tables = P . generator_tables
774
767
end
775
768
end
776
769
@@ -790,6 +783,7 @@ module P256 : Dh_dsa = struct
790
783
end
791
784
792
785
module Foreign = struct
786
+ include Point_proj
793
787
external mul : out_field_element -> field_element -> field_element -> unit = " mc_p256_mul" [@@ noalloc]
794
788
external sub : out_field_element -> field_element -> field_element -> unit = " mc_p256_sub" [@@ noalloc]
795
789
external add : out_field_element -> field_element -> field_element -> unit = " mc_p256_add" [@@ noalloc]
@@ -842,6 +836,7 @@ module P384 : Dh_dsa = struct
842
836
end
843
837
844
838
module Foreign = struct
839
+ include Point_proj
845
840
external mul : out_field_element -> field_element -> field_element -> unit = " mc_p384_mul" [@@ noalloc]
846
841
external sub : out_field_element -> field_element -> field_element -> unit = " mc_p384_sub" [@@ noalloc]
847
842
external add : out_field_element -> field_element -> field_element -> unit = " mc_p384_add" [@@ noalloc]
@@ -895,6 +890,7 @@ module P521 : Dh_dsa = struct
895
890
end
896
891
897
892
module Foreign = struct
893
+ include Point_proj
898
894
external mul : out_field_element -> field_element -> field_element -> unit = " mc_p521_mul" [@@ noalloc]
899
895
external sub : out_field_element -> field_element -> field_element -> unit = " mc_p521_sub" [@@ noalloc]
900
896
external add : out_field_element -> field_element -> field_element -> unit = " mc_p521_add" [@@ noalloc]
0 commit comments