Skip to content

Commit eb8b8c1

Browse files
authored
Fix doubling points on the x axis bug, add unittest (#18)
1 parent 55ee2e8 commit eb8b8c1

File tree

2 files changed

+12
-17
lines changed

2 files changed

+12
-17
lines changed

ecc/curve.py

+5-16
Original file line numberDiff line numberDiff line change
@@ -100,25 +100,17 @@ def add_point(self, P: Point, Q: Point) -> Point:
100100
elif Q.is_at_infinity():
101101
return P
102102

103-
if P == Q:
104-
return self._double_point(P)
105103
if P == -Q:
106104
return self.INF
105+
if P == Q:
106+
return self._double_point(P)
107107

108108
return self._add_point(P, Q)
109109

110110
@abstractmethod
111111
def _add_point(self, P: Point, Q: Point) -> Point:
112112
pass
113113

114-
def double_point(self, P: Point) -> Point:
115-
if not self.is_on_curve(P):
116-
raise ValueError("The point is not on the curve.")
117-
if P.is_at_infinity():
118-
return self.INF
119-
120-
return self._double_point(P)
121-
122114
@abstractmethod
123115
def _double_point(self, P: Point) -> Point:
124116
pass
@@ -134,17 +126,14 @@ def mul_point(self, d: int, P: Point) -> Point:
134126
if d == 0:
135127
return self.INF
136128

137-
res = None
129+
res = self.INF
138130
is_negative_scalar = d < 0
139131
d = -d if is_negative_scalar else d
140132
tmp = P
141133
while d:
142134
if d & 0x1 == 1:
143-
if res:
144-
res = self.add_point(res, tmp)
145-
else:
146-
res = tmp
147-
tmp = self.double_point(tmp)
135+
res = self.add_point(res, tmp)
136+
tmp = self.add_point(tmp, tmp)
148137
d >>= 1
149138
if is_negative_scalar:
150139
return -res

tests/test_curve.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import unittest
22

33
from ecc.curve import (
4-
P256, secp256k1, Curve25519, M383, E222, E382
4+
P256, secp256k1, Curve25519, M383, E222, E382, Point
55
)
66

77
CURVES = [P256, secp256k1, Curve25519, M383, E222, E382]
@@ -24,3 +24,9 @@ def test_operator(self):
2424
self.assertEqual(curve.INF + curve.INF, curve.INF)
2525
self.assertEqual(0 * P, curve.INF)
2626
self.assertEqual(1000 * curve.INF, curve.INF)
27+
28+
def test_double_points_y_equals_to_0(self):
29+
P = Point(x=0, y=0, curve=Curve25519)
30+
self.assertEqual(P + P, Curve25519.INF)
31+
self.assertEqual(2 * P, Curve25519.INF)
32+
self.assertEqual(-2 * P, Curve25519.INF)

0 commit comments

Comments
 (0)