Skip to content

Commit c156b39

Browse files
committed
mirage-crypto-ec: restructure mirage_crypto_ec.ml for SECP256K1
This change modularizes the point representation in preparation for the SECP256K1 implementation, which is based on ECCKiila and uses a different point representation.
1 parent cadf0e1 commit c156b39

File tree

1 file changed

+92
-96
lines changed

1 file changed

+92
-96
lines changed

ec/mirage_crypto_ec.ml

+92-96
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,10 @@ module type Parameters = sig
8888
val first_byte_bits : int option
8989
end
9090

91-
type point = { f_x : field_element; f_y : field_element; f_z : field_element }
92-
93-
type out_point = { m_f_x : out_field_element; m_f_y : out_field_element; m_f_z : out_field_element }
91+
module Point_proj = struct
92+
type point = { f_x : field_element; f_y : field_element; f_z : field_element }
93+
type out_point = { m_f_x : out_field_element; m_f_y : out_field_element; m_f_z : out_field_element }
94+
end
9495

9596
type scalar = Scalar of string
9697

@@ -107,13 +108,19 @@ module type Foreign = sig
107108
val to_octets : bytes -> field_element -> unit
108109
val inv : out_field_element -> field_element -> unit
109110
val select_c : out_field_element -> bool -> field_element -> field_element -> unit
111+
end
112+
113+
module type Foreign_proj = sig
114+
include Foreign
115+
open Point_proj
110116

111117
val double_c : out_point -> point -> unit
112118
val add_c : out_point -> point -> point -> unit
113119
val scalar_mult_base_c : out_point -> string -> unit
114120
end
115121

116122
module type Field_element = sig
123+
val create : unit -> out_field_element
117124
val mul : field_element -> field_element -> field_element
118125
val sub : field_element -> field_element -> field_element
119126
val add : field_element -> field_element -> field_element
@@ -126,9 +133,6 @@ module type Field_element = sig
126133
val select : bool -> then_:field_element -> else_:field_element -> field_element
127134
val from_be_octets : string -> field_element
128135
val to_octets : field_element -> string
129-
val double_point : point -> point
130-
val add_point : point -> point -> point
131-
val scalar_mult_base_point : scalar -> point
132136
end
133137

134138
module Make_field_element (P : Parameters) (F : Foreign) : Field_element = struct
@@ -196,51 +200,24 @@ module Make_field_element (P : Parameters) (F : Foreign) : Field_element = struc
196200
let tmp = create_octets () in
197201
F.to_octets tmp fe;
198202
b_uts tmp
199-
200-
let out_point () = {
201-
m_f_x = create ();
202-
m_f_y = create ();
203-
m_f_z = create ();
204-
}
205-
206-
let out_p_to_p p = {
207-
f_x = b_uts p.m_f_x ;
208-
f_y = b_uts p.m_f_y ;
209-
f_z = b_uts p.m_f_z ;
210-
}
211-
212-
let double_point p =
213-
let tmp = out_point () in
214-
F.double_c tmp p;
215-
out_p_to_p tmp
216-
217-
let add_point a b =
218-
let tmp = out_point () in
219-
F.add_c tmp a b;
220-
out_p_to_p tmp
221-
222-
let scalar_mult_base_point (Scalar d) =
223-
let tmp = out_point () in
224-
F.scalar_mult_base_c tmp d;
225-
out_p_to_p tmp
226203
end
227204

228205
module type Point = sig
229-
val at_infinity : unit -> point
206+
type point
230207
val is_infinity : point -> bool
231-
val add : point -> point -> point
232-
val double : point -> point
233208
val of_octets : string -> (point, error) result
234209
val to_octets : compress:bool -> point -> string
235210
val to_affine_raw : point -> (field_element * field_element) option
236211
val x_of_finite_point : point -> string
237-
val params_g : point
238-
val select : bool -> then_:point -> else_:point -> point
212+
val scalar_mult : scalar -> point -> point
213+
val scalar_mult_add : scalar -> scalar -> point -> point
239214
val scalar_mult_base : scalar -> point
215+
val generator_tables : unit -> string array array array
240216
end
241217

242-
module Make_point (P : Parameters) (F : Foreign) : Point = struct
218+
module Make_point (P : Parameters) (F : Foreign_proj) : Point = struct
243219
module Fe = Make_field_element(P)(F)
220+
include Point_proj
244221

245222
let at_infinity () =
246223
let f_x = Fe.one in
@@ -273,7 +250,7 @@ module Make_point (P : Parameters) (F : Foreign) : Point = struct
273250
(** Convert coordinates to a finite point ensuring:
274251
- x < p
275252
- y < p
276-
- y^2 = ax^3 + ax + b
253+
- y^2 = x^3 + ax + b
277254
*)
278255
let validate_finite_point ~x ~y =
279256
match (check_coordinate x, check_coordinate y) with
@@ -325,9 +302,34 @@ module Make_point (P : Parameters) (F : Foreign) : Point = struct
325302
else
326303
buf
327304

328-
let double p = Fe.double_point p
305+
let out_point () = {
306+
m_f_x = Fe.create ();
307+
m_f_y = Fe.create ();
308+
m_f_z = Fe.create ();
309+
}
310+
311+
let out_p_to_p p =
312+
let b_uts b = Bytes.unsafe_to_string b in
313+
{
314+
f_x = b_uts p.m_f_x ;
315+
f_y = b_uts p.m_f_y ;
316+
f_z = b_uts p.m_f_z ;
317+
}
318+
319+
let double p =
320+
let tmp = out_point () in
321+
F.double_c tmp p;
322+
out_p_to_p tmp
323+
324+
let add a b =
325+
let tmp = out_point () in
326+
F.add_c tmp a b;
327+
out_p_to_p tmp
329328

330-
let add p q = Fe.add_point p q
329+
let scalar_mult_base (Scalar d) =
330+
let tmp = out_point () in
331+
F.scalar_mult_base_c tmp d;
332+
out_p_to_p tmp
331333

332334
let x_of_finite_point p =
333335
match to_affine p with None -> assert false | Some (x, _) -> rev_string x
@@ -418,17 +420,50 @@ module Make_point (P : Parameters) (F : Foreign) : Point = struct
418420
| 0x00 | 0x04 -> Error `Invalid_length
419421
| _ -> Error `Invalid_format
420422

421-
let scalar_mult_base = Fe.scalar_mult_base_point
423+
(* Branchless Montgomery ladder method *)
424+
let scalar_mult (Scalar s) p =
425+
let r0 = ref (at_infinity ()) in
426+
let r1 = ref p in
427+
for i = P.byte_length * 8 - 1 downto 0 do
428+
let bit = bit_at s i in
429+
let sum = add !r0 !r1 in
430+
let r0_double = double !r0 in
431+
let r1_double = double !r1 in
432+
r0 := select bit ~then_:sum ~else_:r0_double;
433+
r1 := select bit ~then_:r1_double ~else_:sum
434+
done;
435+
!r0
436+
437+
let scalar_mult_add a b p =
438+
add (scalar_mult_base a) (scalar_mult b p)
439+
440+
(* Pre-compute multiples of the generator point
441+
returns the tables along with the number of significant bytes *)
442+
let generator_tables () =
443+
let len = P.fe_length * 2 in
444+
let one_table _ = Array.init 15 (fun _ -> at_infinity ()) in
445+
let table = Array.init len one_table in
446+
let base = ref params_g in
447+
for i = 0 to len - 1 do
448+
table.(i).(0) <- !base;
449+
for j = 1 to 14 do
450+
table.(i).(j) <- add !base table.(i).(j - 1)
451+
done;
452+
base := double !base;
453+
base := double !base;
454+
base := double !base;
455+
base := double !base
456+
done;
457+
let convert {f_x; f_y; f_z} = [|f_x; f_y; f_z|] in
458+
Array.map (Array.map convert) table
459+
422460
end
423461

424462
module type Scalar = sig
425463
val not_zero : string -> bool
426464
val is_in_range : string -> bool
427465
val of_octets : string -> (scalar, error) result
428466
val to_octets : scalar -> string
429-
val scalar_mult : scalar -> point -> point
430-
val scalar_mult_base : scalar -> point
431-
val generator_tables : unit -> field_element array array array
432467
end
433468

434469
module Make_scalar (Param : Parameters) (P : Point) : Scalar = struct
@@ -447,43 +482,6 @@ module Make_scalar (Param : Parameters) (P : Point) : Scalar = struct
447482
| false -> Error `Invalid_range
448483

449484
let to_octets (Scalar buf) = rev_string buf
450-
451-
(* Branchless Montgomery ladder method *)
452-
let scalar_mult (Scalar s) p =
453-
let r0 = ref (P.at_infinity ()) in
454-
let r1 = ref p in
455-
for i = Param.byte_length * 8 - 1 downto 0 do
456-
let bit = bit_at s i in
457-
let sum = P.add !r0 !r1 in
458-
let r0_double = P.double !r0 in
459-
let r1_double = P.double !r1 in
460-
r0 := P.select bit ~then_:sum ~else_:r0_double;
461-
r1 := P.select bit ~then_:r1_double ~else_:sum
462-
done;
463-
!r0
464-
465-
(* Specialization of [scalar_mult d p] when [p] is the generator *)
466-
let scalar_mult_base = P.scalar_mult_base
467-
468-
(* Pre-compute multiples of the generator point
469-
returns the tables along with the number of significant bytes *)
470-
let generator_tables () =
471-
let len = Param.fe_length * 2 in
472-
let one_table _ = Array.init 15 (fun _ -> P.at_infinity ()) in
473-
let table = Array.init len one_table in
474-
let base = ref P.params_g in
475-
for i = 0 to len - 1 do
476-
table.(i).(0) <- !base;
477-
for j = 1 to 14 do
478-
table.(i).(j) <- P.add !base table.(i).(j - 1)
479-
done;
480-
base := P.double !base;
481-
base := P.double !base;
482-
base := P.double !base;
483-
base := P.double !base
484-
done;
485-
let convert {f_x; f_y; f_z} = [|f_x; f_y; f_z|] in
486-
Array.map (Array.map convert) table
487485
end
488486

489487
module Make_dh (Param : Parameters) (P : Point) (S : Scalar) : Dh = struct
@@ -498,7 +496,7 @@ module Make_dh (Param : Parameters) (P : Point) (S : Scalar) : Dh = struct
498496
type secret = scalar
499497

500498
let share ?(compress = false) private_key =
501-
let public_key = S.scalar_mult_base private_key in
499+
let public_key = P.scalar_mult_base private_key in
502500
point_to_octets ~compress public_key
503501

504502
let secret_of_octets ?compress s =
@@ -522,7 +520,7 @@ module Make_dh (Param : Parameters) (P : Point) (S : Scalar) : Dh = struct
522520
let key_exchange secret received =
523521
match point_of_octets received with
524522
| Error _ as err -> err
525-
| Ok shared -> Ok (P.x_of_finite_point (S.scalar_mult secret shared))
523+
| Ok shared -> Ok (P.x_of_finite_point (P.scalar_mult secret shared))
526524
end
527525

528526
module type Foreign_n = sig
@@ -671,7 +669,7 @@ module Make_dsa (Param : Parameters) (F : Fn) (P : Point) (S : Scalar) (H : Dige
671669

672670
module K_gen_default = K_gen(H)
673671

674-
type pub = point
672+
type pub = P.point
675673

676674
let pub_of_octets = P.of_octets
677675

@@ -687,16 +685,14 @@ module Make_dsa (Param : Parameters) (F : Fn) (P : Point) (S : Scalar) (H : Dige
687685
in
688686
one ()
689687
in
690-
let q = S.scalar_mult_base d in
688+
let q = P.scalar_mult_base d in
691689
(d, q)
692690

693691
let x_of_finite_point_mod_n p =
694692
match P.to_affine_raw p with
695693
| None -> None
696694
| Some (x, _) ->
697-
let x = F.to_montgomery x in
698695
let x = F.mul x F.one in
699-
let x = F.from_montgomery x in
700696
Some (F.to_be_octets x)
701697

702698
let sign ~key ?k msg =
@@ -714,7 +710,7 @@ module Make_dsa (Param : Parameters) (F : Fn) (P : Point) (S : Scalar) (H : Dige
714710
| Ok ksc -> ksc
715711
| Error _ -> invalid_arg "k not in range" (* if no k is provided, this cannot happen since K_gen_*.gen already preserves the Scalar invariants *)
716712
in
717-
let point = S.scalar_mult_base ksc in
713+
let point = P.scalar_mult_base ksc in
718714
match x_of_finite_point_mod_n point with
719715
| None -> again ()
720716
| Some r ->
@@ -734,7 +730,7 @@ module Make_dsa (Param : Parameters) (F : Fn) (P : Point) (S : Scalar) (H : Dige
734730
in
735731
do_sign g
736732

737-
let pub_of_priv priv = S.scalar_mult_base priv
733+
let pub_of_priv priv = P.scalar_mult_base priv
738734

739735
let verify ~key (r, s) msg =
740736
try
@@ -756,10 +752,7 @@ module Make_dsa (Param : Parameters) (F : Fn) (P : Point) (S : Scalar) (H : Dige
756752
S.of_octets (F.to_be_octets u2)
757753
with
758754
| Ok u1, Ok u2 ->
759-
let point =
760-
P.add
761-
(S.scalar_mult_base u1)
762-
(S.scalar_mult u2 key)
755+
let point = P.scalar_mult_add u1 u2 key
763756
in
764757
begin match x_of_finite_point_mod_n point with
765758
| None -> false (* point is infinity *)
@@ -770,7 +763,7 @@ module Make_dsa (Param : Parameters) (F : Fn) (P : Point) (S : Scalar) (H : Dige
770763
| Message_too_long -> false
771764

772765
module Precompute = struct
773-
let generator_tables = S.generator_tables
766+
let generator_tables = P.generator_tables
774767
end
775768
end
776769

@@ -790,6 +783,7 @@ module P256 : Dh_dsa = struct
790783
end
791784

792785
module Foreign = struct
786+
include Point_proj
793787
external mul : out_field_element -> field_element -> field_element -> unit = "mc_p256_mul" [@@noalloc]
794788
external sub : out_field_element -> field_element -> field_element -> unit = "mc_p256_sub" [@@noalloc]
795789
external add : out_field_element -> field_element -> field_element -> unit = "mc_p256_add" [@@noalloc]
@@ -842,6 +836,7 @@ module P384 : Dh_dsa = struct
842836
end
843837

844838
module Foreign = struct
839+
include Point_proj
845840
external mul : out_field_element -> field_element -> field_element -> unit = "mc_p384_mul" [@@noalloc]
846841
external sub : out_field_element -> field_element -> field_element -> unit = "mc_p384_sub" [@@noalloc]
847842
external add : out_field_element -> field_element -> field_element -> unit = "mc_p384_add" [@@noalloc]
@@ -895,6 +890,7 @@ module P521 : Dh_dsa = struct
895890
end
896891

897892
module Foreign = struct
893+
include Point_proj
898894
external mul : out_field_element -> field_element -> field_element -> unit = "mc_p521_mul" [@@noalloc]
899895
external sub : out_field_element -> field_element -> field_element -> unit = "mc_p521_sub" [@@noalloc]
900896
external add : out_field_element -> field_element -> field_element -> unit = "mc_p521_add" [@@noalloc]

0 commit comments

Comments
 (0)