Skip to content

Commit 8a67b7b

Browse files
authored
Fix debug tests (#21)
1 parent 4149a25 commit 8a67b7b

File tree

4 files changed

+77
-7
lines changed

4 files changed

+77
-7
lines changed

Sources/HomomorphicEncryption/PolyRq/PolyRq+Ntt.swift

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -241,11 +241,11 @@ extension PolyContext {
241241
let rootOfUnityPowers = nttContext.rootOfUnityPowers
242242
// The forward butterfly transforms `x,y` in
243243
// `[0, k * modulus)` -> `[0, (k + 2) * modulus)`.
244-
// We delay modular reduction until overflowing T, i.e.
245-
// `(kMax + 2) * modulus > T.max`, so `kMax = floor(T.max / modulus) - 2
244+
// We delay modular reduction until overflowing the maximum input to `T.subtractIfExceeds`, i.e.,
245+
// we find the largest `kMax` such that `(kMax + 2) * modulus <= (Self.max >> 1) + modulus`.
246+
// So, we have `kMax = T.max / (2 * modulus) - 1`
246247
var lazyReductionCounter = -1 // k
247-
// kMax
248-
let maxLazyReductionCounter = modulusReduceFactor.singleWordModulus.factor.low &- 2
248+
let maxLazyReductionCounter = T.max / (2 * modulus) - 1 // kMax
249249

250250
func applyFinalStageOp(m: Int, op: (_ x: inout T, _ y: inout T) -> Void) {
251251
for i in 0..<m {
@@ -292,14 +292,17 @@ extension PolyContext {
292292
let timeToReduce = lazyReductionCounter > maxLazyReductionCounter
293293
if timeToReduce {
294294
if t == 1 {
295-
lazyReductionCounter &-= 2
295+
// if lazyReductionCounter == 3, `subtractIfExceeds(twiceModulus)`
296+
// only ensures `x in [0, 2 * p - 1]`
297+
lazyReductionCounter = max(lazyReductionCounter - 2, 2)
296298
} else {
297299
lazyReductionCounter = 1
298300
}
299301
}
300302
switch (t, timeToReduce) {
301303
case (1, true):
302304
applyFinalStageOp(m: m) { x, _ in
305+
// Ensure butterfly doesn't overflow
303306
x = x.subtractIfExceeds(twiceModulus)
304307
}
305308
case (1, false):

Sources/HomomorphicEncryption/Scalar.swift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,11 +169,12 @@ extension ScalarType {
169169
///
170170
/// Computes a conditional subtraction, `if self >= modulus ? self - modulus : self`, which can be used for modular
171171
/// reduction of `self` from range `[0, 2 * modulus - 1]` to `[0, modulus - 1]`. The computation is constant-time.
172+
/// `self` must be less than or equal to `(Self.max >> 1) + modulus``
172173
/// - Parameter modulus: Modulus.
173174
/// - Returns: `self >= modulus ? self - modulus : self`.
174175
@inlinable
175176
public func subtractIfExceeds(_ modulus: Self) -> Self {
176-
assert(self < 2 * modulus)
177+
assert(self <= (Self.max &>> 1) + modulus) // difference mask fails otherwise
177178
let difference = self &- modulus
178179
let mask = Self(0) &- (difference >> (bitWidth - 1))
179180
return difference &+ (modulus & mask)

Tests/HomomorphicEncryptionTests/NttTests.swift

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,55 @@ final class NttTests: XCTestCase {
106106
])
107107
}
108108

109+
func testNtt16() throws {
110+
// modulus near top of range
111+
try runNttTest(
112+
moduli: [UInt32(536_870_849)],
113+
coeffData: [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
114+
evalData: [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
115+
try runNttTest(
116+
moduli: [UInt32(536_870_849)],
117+
coeffData: [[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
118+
evalData: [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
119+
try runNttTest(
120+
moduli: [UInt32(536_870_849)],
121+
coeffData: [[
122+
477_051_601,
123+
421_524_611,
124+
456_257_859,
125+
247_136_825,
126+
128_775_020,
127+
76_785_070,
128+
49_764_016,
129+
525_812_772,
130+
325_605_371,
131+
88_935_943,
132+
255_470_762,
133+
39_507_048,
134+
404_978_219,
135+
379_383_003,
136+
244_420_585,
137+
346_826_612,
138+
]], evalData: [[
139+
230_846_094,
140+
480_599_401,
141+
157_364_576,
142+
360_442_736,
143+
531_052_463,
144+
294_311_347,
145+
432_899_854,
146+
219_721_533,
147+
286_807_067,
148+
260_650_843,
149+
362_842_688,
150+
315_862_017,
151+
493_042_020,
152+
520_739_674,
153+
167_758_416,
154+
370_401_491,
155+
]])
156+
}
157+
109158
func testNtt32() throws {
110159
let modulus = UInt32(769)
111160

@@ -175,8 +224,11 @@ final class NttTests: XCTestCase {
175224
return try xEval.inverseNtt()
176225
}
177226

178-
let moduli = [UInt64(576_460_752_303_436_801)]
179227
let degree = 128
228+
let moduli = try UInt32.generatePrimes(
229+
significantBitCounts: [30],
230+
preferringSmall: false,
231+
nttDegree: degree)
180232
let context = try PolyContext(degree: degree, moduli: moduli)
181233
let x = PolyRq<_, Coeff>.random(context: context)
182234
let y = PolyRq<_, Coeff>.random(context: context)

Tests/HomomorphicEncryptionTests/ScalarTests.swift

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,20 @@ import TestUtilities
1717
import XCTest
1818

1919
class ScalarTests: XCTestCase {
20+
func testSubtractIfExceeds() {
21+
do {
22+
let modulus: UInt32 = (1 << 29) - 63
23+
XCTAssertEqual(UInt32(2 * modulus + 1).subtractIfExceeds(modulus), modulus + 1)
24+
XCTAssertEqual(UInt32(modulus - 1).subtractIfExceeds(modulus), modulus - 1)
25+
}
26+
do {
27+
let modulus: UInt32 = (1 << 31) - 10
28+
let max = (UInt32.max >> 1) + modulus
29+
XCTAssertEqual(UInt32(max).subtractIfExceeds(modulus), max - modulus)
30+
XCTAssertEqual(UInt32(modulus - 1).subtractIfExceeds(modulus), modulus - 1)
31+
}
32+
}
33+
2034
func testAddMod() {
2135
XCTAssertEqual(UInt32(0).addMod(1, modulus: 3), 1)
2236
XCTAssertEqual(UInt32(1).addMod(2, modulus: 3), 0)

0 commit comments

Comments
 (0)