Skip to content

Commit e6db501

Browse files
sipapeterdettman
authored andcommitted
Update safegcd writeup to reflect the code
1 parent 755e1c2 commit e6db501

File tree

1 file changed

+78
-60
lines changed

1 file changed

+78
-60
lines changed

doc/safegcd_implementation.md

+78-60
Original file line numberDiff line numberDiff line change
@@ -155,14 +155,14 @@ do one division by *2<sup>N</sup>* as a final step:
155155
```python
156156
def divsteps_n_matrix(delta, f, g):
157157
"""Compute delta and transition matrix t after N divsteps (multiplied by 2^N)."""
158-
u, v, q, r = 1, 0, 0, 1 # start with identity matrix
158+
u, v, q, r = 2**N, 0, 0, 2**N # start with identity matrix
159159
for _ in range(N):
160160
if delta > 0 and g & 1:
161-
delta, f, g, u, v, q, r = 1 - delta, g, (g - f) // 2, 2*q, 2*r, q-u, r-v
161+
delta, f, g, u, v, q, r = 1 - delta, g, (g-f)//2, q, r, (q-u)//2, (r-v)//2
162162
elif g & 1:
163-
delta, f, g, u, v, q, r = 1 + delta, f, (g + f) // 2, 2*u, 2*v, q+u, r+v
163+
delta, f, g, u, v, q, r = 1 + delta, f, (g+f)//2, u, v, (q+u)//2, (r+v)//2
164164
else:
165-
delta, f, g, u, v, q, r = 1 + delta, f, (g ) // 2, 2*u, 2*v, q , r
165+
delta, f, g, u, v, q, r = 1 + delta, f, (g )//2, u, v, (q )//2, (r )//2
166166
return delta, (u, v, q, r)
167167
```
168168

@@ -414,9 +414,9 @@ operations (and hope the C compiler isn't smart enough to turn them back into br
414414
divstep can be written instead as (compare to the inner loop of `gcd` in section 1).
415415

416416
```python
417-
x = -f if delta > 0 else f # set x equal to (input) -f or f
417+
x = f if delta > 0 else -f # set x equal to (input) f or -f
418418
if g & 1:
419-
g += x # set g to (input) g-f or g+f
419+
g -= x # set g to (input) g-f or g+f
420420
if delta > 0:
421421
delta = -delta
422422
f += g # set f to (input) g (note that g was set to g-f before)
@@ -433,13 +433,13 @@ that *-v == (v ^ -1) - (-1)*. Thus, if we have a variable *c* that takes on valu
433433
Using this we can write:
434434

435435
```python
436-
x = -f if delta > 0 else f
436+
x = f if delta > 0 else -f
437437
```
438438

439439
in constant-time form as:
440440

441441
```python
442-
c1 = (-delta) >> 63
442+
c1 = delta >> 63
443443
# Conditionally negate f based on c1:
444444
x = (f ^ c1) - c1
445445
```
@@ -454,7 +454,7 @@ Using the facts that *x&0=0* and *x&(-1)=x* (on two's complement systems again),
454454

455455
```python
456456
if g & 1:
457-
g += x
457+
g -= x
458458
```
459459

460460
as:
@@ -463,7 +463,7 @@ as:
463463
# Compute c2=0 if g is even and c2=-1 if g is odd.
464464
c2 = -(g & 1)
465465
# This masks out x if g is even, and leaves x be if g is odd.
466-
g += x & c2
466+
g -= x & c2
467467
```
468468

469469
Using the conditional negation trick again we can write:
@@ -478,7 +478,7 @@ as:
478478

479479
```python
480480
# Compute c3=-1 if g is odd and delta>0, and 0 otherwise.
481-
c3 = c1 & c2
481+
c3 = ~c1 & c2
482482
# Conditionally negate delta based on c3:
483483
delta = (delta ^ c3) - c3
484484
```
@@ -497,45 +497,59 @@ becomes:
497497
f += g & c3
498498
```
499499

500-
It turns out that this can be implemented more efficiently by applying the substitution
501-
*&eta;=-&delta;*. In this representation, negating *&delta;* corresponds to negating *&eta;*, and incrementing
502-
*&delta;* corresponds to decrementing *&eta;*. This allows us to remove the negation in the *c1*
503-
computation:
500+
Putting everything together, extending all operations on f,g (with helper x) to also be applied
501+
to u,q (with helper y) and v,r (with helper z), gives:
504502

505503
```python
506-
# Compute a mask c1 for eta < 0, and compute the conditional negation x of f:
507-
c1 = eta >> 63
508-
x = (f ^ c1) - c1
509-
# Compute a mask c2 for odd g, and conditionally add x to g:
510-
c2 = -(g & 1)
511-
g += x & c2
512-
# Compute a mask c for (eta < 0) and odd (input) g, and use it to conditionally negate eta,
513-
# and add g to f:
514-
c3 = c1 & c2
515-
eta = (eta ^ c3) - c3
516-
f += g & c3
517-
# Incrementing delta corresponds to decrementing eta.
518-
eta -= 1
519-
g >>= 1
504+
def divsteps_n_matrix(delta, f, g):
505+
"""Compute delta and transition matrix t after N divsteps (multiplied by 2^N)."""
506+
u, v, q, r = 1 << N, 0, 0, 1 << N # start with identity matrix (scaled by 2^N).
507+
for i in range(N):
508+
c1 = delta >> 63
509+
# Compute x, y, z as conditionally-negated versions of f, u, v.
510+
x, y, z = (f ^ c1) - c1, (u ^ c1) - c1, (v ^ c1) - c1
511+
c2 = -(g & 1)
512+
# Conditionally subtract x, y, z from g, q, r.
513+
g, q, r = g - (x & c2), q - (y & c2), r - (z & c2)
514+
c3 = ~c1 & c2
515+
# Conditionally negate delta, and then increment it by 1.
516+
delta = (delta ^ c3) - c3 + 1
517+
# Conditionally add g, q, r to f, u, v.
518+
f, u, v = f + (g & c3), u + (q & c3), v + (r & c3)
519+
# Shift down g, q, r.
520+
g, q, r = g >> 1, u >> 1, v >> 1
521+
return delta, (u, v, q, r)
520522
```
521523

522-
A variant of divsteps with better worst-case performance can be used instead: starting *&delta;* at
524+
An interesting optimization is possible here. If we were to drop the *-c1* in the computation
525+
of *x*, *y*, and *z*, we are making them at worst *1* less than the correct value. That
526+
translates to *g*, *q*, and *r* further being at worst *1* more than the correct value.
527+
Now observe that at the start of every iteration of the loop, *u*, *v*, *q*, and *r* are
528+
all multiples of *2<sup>N-i</sub>*, with *i* the iteration number, and thus all even.
529+
In other words, this potential off by one in *g*, *q*, and *r* only affects their bottommost
530+
bit, which is shifted away at the end of the loop. Thus we can instead write:
531+
532+
```python
533+
# Compute x, y, z as conditionally complemented versions of f, u, v.
534+
x, y, z = f ^ c1, u ^ c1, v ^ c1
535+
```
536+
537+
Finally, a variant of divsteps with better worst-case performance can be used instead: starting *&delta;* at
523538
*1/2* instead of *1*. This reduces the worst case number of iterations to *590* for *256*-bit inputs
524-
(which can be shown using convex hull analysis). In this case, the substitution *&zeta;=-(&delta;+1/2)*
525-
is used instead to keep the variable integral. Incrementing *&delta;* by *1* still translates to
526-
decrementing *&zeta;* by *1*, but negating *&delta;* now corresponds to going from *&zeta;* to *-(&zeta;+1)*, or
527-
*~&zeta;*. Doing that conditionally based on *c3* is simply:
539+
(which can be shown using [convex hull analysis](https://github.com/sipa/safegcd-bounds)).
540+
In this case, the substitution *&theta;=&delta;-1/2* is used to keep the variable integral.
541+
Negating *&delta;* now corresponds to going from *&theta;* to
542+
*&theta;-1*. Doing that conditionally based on *c3* (and then incrementing by one) gives us:
528543

529544
```python
530545
...
531-
c3 = c1 & c2
532-
zeta ^= c3
546+
theta = (theta ^ c3) + 1
533547
...
534548
```
535549

536550
By replacing the loop in `divsteps_n_matrix` with a variant of the divstep code above (extended to
537551
also apply all *f* operations to *u*, *v* and all *g* operations to *q*, *r*), a constant-time version of
538-
`divsteps_n_matrix` is obtained. The full code will be in section 7.
552+
`divsteps_n_matrix` is obtained. The resulting code will be in section 7.
539553

540554
These bit fiddling tricks can also be used to make the conditional negations and additions in
541555
`update_de` and `normalize` constant-time.
@@ -550,7 +564,7 @@ faster non-constant time `divsteps_n_matrix` function.
550564

551565
To do so, first consider yet another way of writing the inner loop of divstep operations in
552566
`gcd` from section 1. This decomposition is also explained in the paper in section 8.2. We use
553-
the original version with initial *&delta;=1* and *&eta;=-&delta;* here.
567+
the original version with initial *&delta;=1*, but make the substitution *&eta;=-&delta;*.
554568

555569
```python
556570
for _ in range(N):
@@ -651,37 +665,41 @@ Here we need the negated modular inverse, which is a simple transformation of th
651665
have this 6-bit function (based on the 3-bit function above):
652666
- *f(f<sup>2</sup> - 2)*
653667

654-
This loop, again extended to also handle *u*, *v*, *q*, and *r* alongside *f* and *g*, placed in
655-
`divsteps_n_matrix`, gives a significantly faster, but non-constant time version.
668+
This loop, extended to also handle *u*, *v*, *q*, and *r* alongside *f* and *g*, placed in
669+
`divsteps_n_matrix`, gives a significantly faster, but non-constant time version. In order to
670+
avoid intermediary values that need more than N+1 bits, it is possible to instead start
671+
*u* and *v* at *1* instead of at *2<sup>N</sup>*, and then shift up *u* and *v* whenever
672+
*g* is shifted down (instead of shifting down *q* and *r*). This is effectively making the
673+
algorithm operate on *i*-bits downshifted versions of all these variables. The resulting
674+
code is shown in the next section.
656675

657676

658677
## 7. Final Python version
659678

660679
All together we need the following functions:
661680

662681
- A way to compute the transition matrix in constant time, using the `divsteps_n_matrix` function
663-
from section 2, but with its loop replaced by a variant of the constant-time divstep from
664-
section 5, extended to handle *u*, *v*, *q*, *r*:
682+
from section 5, modified to operate on *&theta;* instead of *&delta;*:
665683

666684
```python
667-
def divsteps_n_matrix(zeta, f, g):
668-
"""Compute zeta and transition matrix t after N divsteps (multiplied by 2^N)."""
669-
u, v, q, r = 1, 0, 0, 1 # start with identity matrix
685+
def divsteps_n_matrix(theta, f, g):
686+
"""Compute delta and transition matrix t after N divsteps (multiplied by 2^N)."""
687+
u, v, q, r = 1 << N, 0, 0, 1 << N # start with identity matrix (scaled by 2^N).
670688
for _ in range(N):
671-
c1 = zeta >> 63
672-
# Compute x, y, z as conditionally-negated versions of f, u, v.
673-
x, y, z = (f ^ c1) - c1, (u ^ c1) - c1, (v ^ c1) - c1
689+
c1 = theta >> 63
690+
# Compute x, y, z as conditionally complemented versions of f, u, v.
691+
x, y, z = f ^ c1, u ^ c1, v ^ c1
674692
c2 = -(g & 1)
675-
# Conditionally add x, y, z to g, q, r.
676-
g, q, r = g + (x & c2), q + (y & c2), r + (z & c2)
677-
c1 &= c2 # reusing c1 here for the earlier c3 variable
678-
zeta = (zeta ^ c1) - 1 # inlining the unconditional zeta decrement here
693+
# Conditionally subtract x, y, z from g, q, r.
694+
g, q, r = g - (x & c2), q - (y & c2), r - (z & c2)
695+
c3 = ~c1 & c2
696+
# Conditionally negate delta, and then increment it by 1.
697+
theta = (theta ^ c3) + 1
679698
# Conditionally add g, q, r to f, u, v.
680-
f, u, v = f + (g & c1), u + (q & c1), v + (r & c1)
681-
# When shifting g down, don't shift q, r, as we construct a transition matrix multiplied
682-
# by 2^N. Instead, shift f's coefficients u and v up.
683-
g, u, v = g >> 1, u << 1, v << 1
684-
return zeta, (u, v, q, r)
699+
f, u, v = f + (g & c3), u + (q & c3), v + (r & c3)
700+
# Shift down f, q, r.
701+
g, q, r = g >> 1, u >> 1, v >> 1
702+
return theta, (u, v, q, r)
685703
```
686704

687705
- The functions to update *f* and *g*, and *d* and *e*, from section 2 and section 4, with the constant-time
@@ -723,15 +741,15 @@ def normalize(sign, v, M):
723741
return v
724742
```
725743

726-
- And finally the `modinv` function too, adapted to use *&zeta;* instead of *&delta;*, and using the fixed
744+
- And finally the `modinv` function too, adapted to use *&theta;* instead of *&delta;*, and using the fixed
727745
iteration count from section 5:
728746

729747
```python
730748
def modinv(M, Mi, x):
731749
"""Compute the modular inverse of x mod M, given Mi=1/M mod 2^N."""
732-
zeta, f, g, d, e = -1, M, x, 0, 1
750+
theta, f, g, d, e = 0, M, x, 0, 1
733751
for _ in range((590 + N - 1) // N):
734-
zeta, t = divsteps_n_matrix(zeta, f % 2**N, g % 2**N)
752+
theta, t = divsteps_n_matrix(theta, f % 2**N, g % 2**N)
735753
f, g = update_fg(f, g, t)
736754
d, e = update_de(d, e, t, M, Mi)
737755
return normalize(f, d, M)

0 commit comments

Comments
 (0)