1515/// Stores pre-computed data for efficient modular operations.
1616/// - Warning: The operations may leak the modulus through timing or other side channels. So this struct should only be
1717/// used for public moduli.
18- @usableFromInline
19- struct Modulus < T: ScalarType > : Equatable , Sendable {
18+ public struct Modulus < T: ScalarType > : Equatable , Sendable {
2019 /// The maximum valid modulus value.
21- @ usableFromInline static var max : T {
20+ public static var max : T {
2221 ReduceModulus . max
2322 }
2423
@@ -31,15 +30,16 @@ struct Modulus<T: ScalarType>: Equatable, Sendable {
3130 /// `ceil(2^k / modulus) - 2^(2 * T.bitWidth)` for
3231 /// `k = 2 * T.bitWidth + ceil(log2(modulus)`.
3332 @usableFromInline let divisionModulus : DivisionModulus < T >
34- @usableFromInline let modulus : T
33+ /// The modulus, `p`.
34+ public let modulus : T
3535
3636 /// Initializes a ``Modulus``.
3737 /// - Parameters:
38- /// - modulus: Modulus.
38+ /// - modulus: Modulus. Must be less than ``Modulus/max``.
3939 /// - variableTime: Must be `true`, indicating `modulus` is leaked through timing.
4040 /// - Warning: Leaks `modulus` through timing.
4141 @inlinable
42- init ( modulus: T , variableTime: Bool ) {
42+ public init ( modulus: T , variableTime: Bool ) {
4343 precondition ( variableTime)
4444 self . singleWordModulus = ReduceModulus (
4545 modulus: modulus,
@@ -57,21 +57,35 @@ struct Modulus<T: ScalarType>: Equatable, Sendable {
5757 self . modulus = modulus
5858 }
5959
60+ /// Performs modular reduction with modulus `p`.
61+ /// - Parameter x: Value to reduce.
62+ /// - Returns: `x mod p` in `[0, p).`
6063 @inlinable
61- func reduce( _ x: T ) -> T {
64+ public func reduce( _ x: T ) -> T {
6265 singleWordModulus. reduce ( x)
6366 }
6467
68+ /// Performs modular reduction with modulus `p`.
69+ /// - Parameter x: Value to reduce.
70+ /// - Returns: `x mod p` in `[0, p).`
6571 @inlinable
66- func reduce( _ x: T . DoubleWidth ) -> T {
72+ public func reduce( _ x: T . SignedScalar ) -> T {
73+ singleWordModulus. reduce ( x)
74+ }
75+
76+ /// Performs modular reduction with modulus `p`.
77+ /// - Parameter x: Value to reduce.
78+ /// - Returns: `x mod p` in `[0, p).`
79+ @inlinable
80+ public func reduce( _ x: T . DoubleWidth ) -> T {
6781 doubleWordModulus. reduce ( x)
6882 }
6983
7084 /// Performs modular reduction with modulus `p`.
7185 /// - Parameter x: Must be `< p^2`.
72- /// - Returns: `x mod p` for `p`.
86+ /// - Returns: `x mod p` in `[0, p).`
7387 @inlinable
74- func reduceProduct( _ x: T . DoubleWidth ) -> T {
88+ public func reduceProduct( _ x: T . DoubleWidth ) -> T {
7589 reduceProductModulus. reduceProduct ( x)
7690 }
7791
@@ -81,7 +95,7 @@ struct Modulus<T: ScalarType>: Equatable, Sendable {
8195 /// - y: Must be `< p`.
8296 /// - Returns: `x * y mod p`.
8397 @inlinable
84- func multiplyMod( _ x: T , _ y: T ) -> T {
98+ public func multiplyMod( _ x: T , _ y: T ) -> T {
8599 precondition ( x < modulus)
86100 precondition ( y < modulus)
87101 let product = x. multipliedFullWidth ( by: y)
@@ -92,7 +106,7 @@ struct Modulus<T: ScalarType>: Equatable, Sendable {
92106 /// - Parameter dividend: Number to divide.
93107 /// - Returns: `dividend / modulus`, rounded down to the next integer.
94108 @inlinable
95- func dividingFloor( by dividend: T . DoubleWidth ) -> T . DoubleWidth {
109+ public func dividingFloor( by dividend: T . DoubleWidth ) -> T . DoubleWidth {
96110 divisionModulus. dividingFloor ( by: dividend)
97111 }
98112}
@@ -153,15 +167,20 @@ struct ReduceModulus<T: ScalarType>: Equatable, Sendable {
153167
154168 /// The maximum valid modulus value.
155169 @usableFromInline static var max : T {
156- // Constrained by `reduceProduct`
170+ // Constrained by `reduceProduct` and `reduce(_ x: T.SignedScalar)`
157171 ( T ( 1 ) << ( T . bitWidth - 2 ) ) - 1
158172 }
159173
160174 /// Power used in computed Barrett factor.
161175 @usableFromInline let shift : Int
162176 /// Barrett factor.
163177 @usableFromInline let factor : T . DoubleWidth
178+ /// The modulus, `p`.
164179 @usableFromInline let modulus : T
180+ /// `modulus.previousPowerOfTwo`.
181+ @usableFromInline let modulusPreviousPowerOfTwo : T
182+ /// `round(2^{log2(p) - 1) * 2^{T.bitWidth} / p)`.
183+ @usableFromInline let signedFactor : T . SignedScalar
165184
166185 /// Performs pre-computation for fast modular reduction.
167186 /// - Parameters:
@@ -174,12 +193,25 @@ struct ReduceModulus<T: ScalarType>: Equatable, Sendable {
174193 precondition ( variableTime)
175194 precondition ( modulus <= Self . max)
176195 self . modulus = modulus
196+ self . modulusPreviousPowerOfTwo = modulus. previousPowerOfTwo
177197 switch bound {
178198 case . SingleWord:
179199 self . shift = T . bitWidth
180- let numerator = T . DoubleWidth ( 1 ) << shift
181- // 2^T.bitwidth // p
182- self . factor = numerator / T. DoubleWidth ( modulus)
200+ // floor(2^T.bitwidth / p)
201+ self . factor = T . DoubleWidth ( ( high: 1 , low: 0 ) ) / T. DoubleWidth ( modulus)
202+ if modulus. isPowerOfTwo {
203+ // This should actually be `T.SignedScalar.max + 1`, but this works too.
204+ // See `reduce(_ x: T.SignedScalar)` for more information.
205+ self . signedFactor = T . SignedScalar. max
206+ } else {
207+ // We compute `round(2^{log2(p) - 1} * 2^{T.bitWidth} / p)` by noting
208+ // `2^{log2(p)} = q.previousPowerOfTwo`, and `round(x/p) = floor(x + floor(p/2) / p)`.
209+ let numerator = T . DoubleWidth ( ( high: modulus. previousPowerOfTwo >> 1 , low: T . Magnitude ( modulus) >> 1 ) )
210+ // Guaranteed to fit into single word, since `2^{log2(p) - 1) / p < 1/2` for `p` not a power of 2,
211+ // which implies `signedFactor < 2^{T.bitWidth} / 2`
212+ self . signedFactor = T . SignedScalar ( ( numerator / T. DoubleWidth ( modulus) ) . low)
213+ }
214+
183215 case . DoubleWord:
184216 self . shift = 2 * T. bitWidth
185217 self . factor = if modulus. isPowerOfTwo {
@@ -188,11 +220,14 @@ struct ReduceModulus<T: ScalarType>: Equatable, Sendable {
188220 // floor(2^{2 * t} / p) == floor((2^{2 * t} - 1) / p) for p not a power of two
189221 T . DoubleWidth. max / T. DoubleWidth ( modulus)
190222 }
223+ self . signedFactor = 0 // Unused
224+
191225 case . ModulusSquared:
192226 let reduceModulusAlpha = T . bitWidth - 2
193227 self . shift = modulus. significantBitCount + reduceModulusAlpha
194228 let numerator = T . DoubleWidth ( 1 ) << shift
195229 self . factor = numerator / T. DoubleWidth ( modulus)
230+ self . signedFactor = 0 // Unused
196231 }
197232 }
198233
@@ -217,8 +252,43 @@ struct ReduceModulus<T: ScalarType>: Equatable, Sendable {
217252 return z. subtractIfExceeds ( modulus)
218253 }
219254
255+ /// Returns `x mod p` in `[0, p)` for signed integer `x`.
256+ ///
257+ /// Requires the modulus `p` to satisfy `p < 2^{T.bitWidth - 2}`.
258+ /// See Algorithm 5 from <https://eprint.iacr.org/2018/039.pdf>.
259+ /// The proof of Lemma 4 still goes through for odd moduli `q < 2^{T.bitWidth - 2}`, by using the bound
260+ /// `floor(2^k \beta / q) >= 2^k \beta / q - 1`, rather than
261+ /// `floor(2^k \beta / q) >= 2^k \beta / q - 1/2`.
262+ /// For a `q` a power of two, the `signedFactor` is off by one (`2^{T.bitWidth} - 1` instead of `2^{T.bitWidth}`),
263+ /// so we provide a quick proof of correctness in this case.
264+ /// Using notation from the proof of Lemma 4 of <https://eprint.iacr.org/2018/039.pdf>, and assuming `a >= 0`,
265+ /// we have `2^k = q / 2`, so `v = floor(2^k β / q) = β / 2`. Since we are using `v - 1` instead of `v`, we have
266+ /// `r = a - q * floor(a * (v - 1) / (2^k β))`. Using `floor(x) >= x - 1`, we have
267+ /// `<= a - q * (a * (v - 1) / (2^k β)) + q`. Using `v = β / 2` and `2^k = q / 2`, we have
268+ /// `= a - q * (a β / 2 - a) / (β q / 2) + q`
269+ /// `= a - a + q a / (β q / 2) + q`
270+ /// `= a / (β / 2) + q`
271+ /// `< 1 + q` for `a < β / 2`.
272+ /// Since we use `v - 1` instead of `v`, the result can only be larger than as Algorithm 5 is written.
273+ /// Hence, the lower bound `r > -1` from the proof of Lemma 4 still holds.
274+ /// Since `r < q + 1`, `r > -1`, and `r` is integral, we have `r in [0, q]`.
275+ /// The final `subtractIfExceeds` ensures `r in [0, q - 1]`.
276+ ///
277+ /// The proof follows analagously for `a < 0`.
278+ ///
279+ /// - Parameter x: Value to reduce.
280+ /// - Returns: `x mod p` in `[0, p)`.
281+ @inlinable
282+ func reduce( _ x: T . SignedScalar ) -> T {
283+ assert ( shift == T . bitWidth)
284+ var t = x. multiplyHigh ( signedFactor) >> ( modulus. log2 - 1 )
285+ t = t &* T . SignedScalar ( modulus)
286+ return T ( x &- t) . subtractIfExceeds ( modulus)
287+ }
288+
220289 /// Returns `x mod p`.
221290 ///
291+ /// Requires modulus `p < 2^{T.bitWidth - 1}`.
222292 /// Useful when `x >= p^2`, otherwise use `` reduceProduct``.
223293 /// Proof of correctness:
224294 /// Let `t = T.bitWidth`
@@ -234,7 +304,7 @@ struct ReduceModulus<T: ScalarType>: Equatable, Sendable {
234304 /// Adding (3) and (4) yields
235305 /// `0 <= x - q * p < x * p / 2^{2 * t} + p < 2 * p`.
236306 ///
237- /// Note, the bound on `p < 2^63 ` comes from `2 * p < T.max `
307+ /// Note, the bound on `p < 2^{t - 1} ` comes from `2 * p < 2^t `
238308 @inlinable
239309 func reduce( _ x: T . DoubleWidth ) -> T {
240310 assert ( shift == x. bitWidth)
0 commit comments