Skip to content

Commit 7affb01

Browse files
authored
Add modular reduction and signed encoding to PlaintextMatrix. (#71)
1 parent 8458194 commit 7affb01

File tree

9 files changed

+269
-47
lines changed

9 files changed

+269
-47
lines changed

Sources/HomomorphicEncryption/Array2d.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ extension Array2d {
119119
throw HeError.invalidRotationParameter(range: range, columnCount: data.count)
120120
}
121121

122-
let effectiveStep = step.toRemainder(range)
122+
let effectiveStep = step.toRemainder(range, variableTime: true)
123123
for index in stride(from: 0, to: data.count, by: range) {
124124
let replacement = data[index + effectiveStep..<index + range] + data[index..<index + effectiveStep]
125125
data.replaceSubrange(index..<index + range, with: replacement)

Sources/HomomorphicEncryption/Encoding.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ extension Context {
5151
guard bounds.contains(Scheme.Scalar.SignedScalar(value)) else {
5252
throw HeError.encodingDataOutOfBounds(for: bounds)
5353
}
54-
return try Scheme.Scalar(value.centeredToRemainder(modulus: plaintextModulus))
54+
return Scheme.Scalar(value.centeredToRemainder(modulus: plaintextModulus))
5555
}
5656
return try encode(values: centeredValues, format: format)
5757
}

Sources/HomomorphicEncryption/Modulus.swift

Lines changed: 87 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,9 @@
1515
/// Stores pre-computed data for efficient modular operations.
1616
/// - Warning: The operations may leak the modulus through timing or other side channels. So this struct should only be
1717
/// used for public moduli.
18-
@usableFromInline
19-
struct Modulus<T: ScalarType>: Equatable, Sendable {
18+
public struct Modulus<T: ScalarType>: Equatable, Sendable {
2019
/// The maximum valid modulus value.
21-
@usableFromInline static var max: T {
20+
public static var max: T {
2221
ReduceModulus.max
2322
}
2423

@@ -31,15 +30,16 @@ struct Modulus<T: ScalarType>: Equatable, Sendable {
3130
/// `ceil(2^k / modulus) - 2^(2 * T.bitWidth)` for
3231
/// `k = 2 * T.bitWidth + ceil(log2(modulus)`.
3332
@usableFromInline let divisionModulus: DivisionModulus<T>
34-
@usableFromInline let modulus: T
33+
/// The modulus, `p`.
34+
public let modulus: T
3535

3636
/// Initializes a ``Modulus``.
3737
/// - Parameters:
38-
/// - modulus: Modulus.
38+
/// - modulus: Modulus. Must be less than ``Modulus/max``.
3939
/// - variableTime: Must be `true`, indicating `modulus` is leaked through timing.
4040
/// - Warning: Leaks `modulus` through timing.
4141
@inlinable
42-
init(modulus: T, variableTime: Bool) {
42+
public init(modulus: T, variableTime: Bool) {
4343
precondition(variableTime)
4444
self.singleWordModulus = ReduceModulus(
4545
modulus: modulus,
@@ -57,21 +57,35 @@ struct Modulus<T: ScalarType>: Equatable, Sendable {
5757
self.modulus = modulus
5858
}
5959

60+
/// Performs modular reduction with modulus `p`.
61+
/// - Parameter x: Value to reduce.
62+
/// - Returns: `x mod p` in `[0, p).`
6063
@inlinable
61-
func reduce(_ x: T) -> T {
64+
public func reduce(_ x: T) -> T {
6265
singleWordModulus.reduce(x)
6366
}
6467

68+
/// Performs modular reduction with modulus `p`.
69+
/// - Parameter x: Value to reduce.
70+
/// - Returns: `x mod p` in `[0, p).`
6571
@inlinable
66-
func reduce(_ x: T.DoubleWidth) -> T {
72+
public func reduce(_ x: T.SignedScalar) -> T {
73+
singleWordModulus.reduce(x)
74+
}
75+
76+
/// Performs modular reduction with modulus `p`.
77+
/// - Parameter x: Value to reduce.
78+
/// - Returns: `x mod p` in `[0, p).`
79+
@inlinable
80+
public func reduce(_ x: T.DoubleWidth) -> T {
6781
doubleWordModulus.reduce(x)
6882
}
6983

7084
/// Performs modular reduction with modulus `p`.
7185
/// - Parameter x: Must be `< p^2`.
72-
/// - Returns: `x mod p` for `p`.
86+
/// - Returns: `x mod p` in `[0, p).`
7387
@inlinable
74-
func reduceProduct(_ x: T.DoubleWidth) -> T {
88+
public func reduceProduct(_ x: T.DoubleWidth) -> T {
7589
reduceProductModulus.reduceProduct(x)
7690
}
7791

@@ -81,7 +95,7 @@ struct Modulus<T: ScalarType>: Equatable, Sendable {
8195
/// - y: Must be `< p`.
8296
/// - Returns: `x * y mod p`.
8397
@inlinable
84-
func multiplyMod(_ x: T, _ y: T) -> T {
98+
public func multiplyMod(_ x: T, _ y: T) -> T {
8599
precondition(x < modulus)
86100
precondition(y < modulus)
87101
let product = x.multipliedFullWidth(by: y)
@@ -92,7 +106,7 @@ struct Modulus<T: ScalarType>: Equatable, Sendable {
92106
/// - Parameter dividend: Number to divide.
93107
/// - Returns: `dividend / modulus`, rounded down to the next integer.
94108
@inlinable
95-
func dividingFloor(by dividend: T.DoubleWidth) -> T.DoubleWidth {
109+
public func dividingFloor(by dividend: T.DoubleWidth) -> T.DoubleWidth {
96110
divisionModulus.dividingFloor(by: dividend)
97111
}
98112
}
@@ -153,15 +167,20 @@ struct ReduceModulus<T: ScalarType>: Equatable, Sendable {
153167

154168
/// The maximum valid modulus value.
155169
@usableFromInline static var max: T {
156-
// Constrained by `reduceProduct`
170+
// Constrained by `reduceProduct` and `reduce(_ x: T.SignedScalar)`
157171
(T(1) << (T.bitWidth - 2)) - 1
158172
}
159173

160174
/// Power used in computed Barrett factor.
161175
@usableFromInline let shift: Int
162176
/// Barrett factor.
163177
@usableFromInline let factor: T.DoubleWidth
178+
/// The modulus, `p`.
164179
@usableFromInline let modulus: T
180+
/// `modulus.previousPowerOfTwo`.
181+
@usableFromInline let modulusPreviousPowerOfTwo: T
182+
/// `round(2^{log2(p) - 1) * 2^{T.bitWidth} / p)`.
183+
@usableFromInline let signedFactor: T.SignedScalar
165184

166185
/// Performs pre-computation for fast modular reduction.
167186
/// - Parameters:
@@ -174,12 +193,25 @@ struct ReduceModulus<T: ScalarType>: Equatable, Sendable {
174193
precondition(variableTime)
175194
precondition(modulus <= Self.max)
176195
self.modulus = modulus
196+
self.modulusPreviousPowerOfTwo = modulus.previousPowerOfTwo
177197
switch bound {
178198
case .SingleWord:
179199
self.shift = T.bitWidth
180-
let numerator = T.DoubleWidth(1) << shift
181-
// 2^T.bitwidth // p
182-
self.factor = numerator / T.DoubleWidth(modulus)
200+
// floor(2^T.bitwidth / p)
201+
self.factor = T.DoubleWidth((high: 1, low: 0)) / T.DoubleWidth(modulus)
202+
if modulus.isPowerOfTwo {
203+
// This should actually be `T.SignedScalar.max + 1`, but this works too.
204+
// See `reduce(_ x: T.SignedScalar)` for more information.
205+
self.signedFactor = T.SignedScalar.max
206+
} else {
207+
// We compute `round(2^{log2(p) - 1} * 2^{T.bitWidth} / p)` by noting
208+
// `2^{log2(p)} = q.previousPowerOfTwo`, and `round(x/p) = floor(x + floor(p/2) / p)`.
209+
let numerator = T.DoubleWidth((high: modulus.previousPowerOfTwo >> 1, low: T.Magnitude(modulus) >> 1))
210+
// Guaranteed to fit into single word, since `2^{log2(p) - 1) / p < 1/2` for `p` not a power of 2,
211+
// which implies `signedFactor < 2^{T.bitWidth} / 2`
212+
self.signedFactor = T.SignedScalar((numerator / T.DoubleWidth(modulus)).low)
213+
}
214+
183215
case .DoubleWord:
184216
self.shift = 2 * T.bitWidth
185217
self.factor = if modulus.isPowerOfTwo {
@@ -188,11 +220,14 @@ struct ReduceModulus<T: ScalarType>: Equatable, Sendable {
188220
// floor(2^{2 * t} / p) == floor((2^{2 * t} - 1) / p) for p not a power of two
189221
T.DoubleWidth.max / T.DoubleWidth(modulus)
190222
}
223+
self.signedFactor = 0 // Unused
224+
191225
case .ModulusSquared:
192226
let reduceModulusAlpha = T.bitWidth - 2
193227
self.shift = modulus.significantBitCount + reduceModulusAlpha
194228
let numerator = T.DoubleWidth(1) << shift
195229
self.factor = numerator / T.DoubleWidth(modulus)
230+
self.signedFactor = 0 // Unused
196231
}
197232
}
198233

@@ -217,8 +252,43 @@ struct ReduceModulus<T: ScalarType>: Equatable, Sendable {
217252
return z.subtractIfExceeds(modulus)
218253
}
219254

255+
/// Returns `x mod p` in `[0, p)` for signed integer `x`.
256+
///
257+
/// Requires the modulus `p` to satisfy `p < 2^{T.bitWidth - 2}`.
258+
/// See Algorithm 5 from <https://eprint.iacr.org/2018/039.pdf>.
259+
/// The proof of Lemma 4 still goes through for odd moduli `q < 2^{T.bitWidth - 2}`, by using the bound
260+
/// `floor(2^k \beta / q) >= 2^k \beta / q - 1`, rather than
261+
/// `floor(2^k \beta / q) >= 2^k \beta / q - 1/2`.
262+
/// For a `q` a power of two, the `signedFactor` is off by one (`2^{T.bitWidth} - 1` instead of `2^{T.bitWidth}`),
263+
/// so we provide a quick proof of correctness in this case.
264+
/// Using notation from the proof of Lemma 4 of <https://eprint.iacr.org/2018/039.pdf>, and assuming `a >= 0`,
265+
/// we have `2^k = q / 2`, so `v = floor(2^k β / q) = β / 2`. Since we are using `v - 1` instead of `v`, we have
266+
/// `r = a - q * floor(a * (v - 1) / (2^k β))`. Using `floor(x) >= x - 1`, we have
267+
/// `<= a - q * (a * (v - 1) / (2^k β)) + q`. Using `v = β / 2` and `2^k = q / 2`, we have
268+
/// `= a - q * (a β / 2 - a) / (β q / 2) + q`
269+
/// `= a - a + q a / (β q / 2) + q`
270+
/// `= a / (β / 2) + q`
271+
/// `< 1 + q` for `a < β / 2`.
272+
/// Since we use `v - 1` instead of `v`, the result can only be larger than as Algorithm 5 is written.
273+
/// Hence, the lower bound `r > -1` from the proof of Lemma 4 still holds.
274+
/// Since `r < q + 1`, `r > -1`, and `r` is integral, we have `r in [0, q]`.
275+
/// The final `subtractIfExceeds` ensures `r in [0, q - 1]`.
276+
///
277+
/// The proof follows analagously for `a < 0`.
278+
///
279+
/// - Parameter x: Value to reduce.
280+
/// - Returns: `x mod p` in `[0, p)`.
281+
@inlinable
282+
func reduce(_ x: T.SignedScalar) -> T {
283+
assert(shift == T.bitWidth)
284+
var t = x.multiplyHigh(signedFactor) >> (modulus.log2 - 1)
285+
t = t &* T.SignedScalar(modulus)
286+
return T(x &- t).subtractIfExceeds(modulus)
287+
}
288+
220289
/// Returns `x mod p`.
221290
///
291+
/// Requires modulus `p < 2^{T.bitWidth - 1}`.
222292
/// Useful when `x >= p^2`, otherwise use `` reduceProduct``.
223293
/// Proof of correctness:
224294
/// Let `t = T.bitWidth`
@@ -234,7 +304,7 @@ struct ReduceModulus<T: ScalarType>: Equatable, Sendable {
234304
/// Adding (3) and (4) yields
235305
/// `0 <= x - q * p < x * p / 2^{2 * t} + p < 2 * p`.
236306
///
237-
/// Note, the bound on `p < 2^63` comes from `2 * p < T.max`
307+
/// Note, the bound on `p < 2^{t - 1}` comes from `2 * p < 2^t`
238308
@inlinable
239309
func reduce(_ x: T.DoubleWidth) -> T {
240310
assert(shift == x.bitWidth)

Sources/HomomorphicEncryption/Scalar.swift

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,22 @@ extension SignedScalarType {
6161
return Self(bitPattern: result)
6262
}
6363

64+
/// Computes the high `Self.bitWidth` bits of `self * rhs`.
65+
/// - Parameter rhs: Multiplicand.
66+
/// - Returns: the high `Self.bitWidth` bits of `self * rhs`.
67+
@inlinable
68+
public func multiplyHigh(_ rhs: Self) -> Self {
69+
multipliedFullWidth(by: rhs).high
70+
}
71+
6472
/// Constant-time centered-to-remainder conversion.
6573
/// - Parameter modulus: Modulus.
66-
/// - Returns: Given `self` in `[-floor(modulus/2), floor(modulus-1)/2]`, returns `self % modulus` in `[0,
67-
/// modulus)`.
68-
/// - Throws: Error upon failure to encode.
74+
/// - Returns: Given `self` in `[-floor(modulus/2), floor((modulus-1)/2)]`,
75+
/// returns `self % modulus` in `[0, modulus)`.
6976
@inlinable
70-
public func centeredToRemainder(modulus: some ScalarType) throws -> Self.UnsignedScalar {
77+
public func centeredToRemainder(modulus: some ScalarType) -> Self.UnsignedScalar {
78+
assert(self <= (Self(modulus) - 1) / 2)
79+
assert(self >= -Self(modulus) / 2)
7180
let condition = Self.UnsignedScalar(bitPattern: self >> (bitWidth - 1))
7281
let thenValue = Self.UnsignedScalar(bitPattern: self &+ Self(bitPattern: Self.UnsignedScalar(modulus)))
7382
let elseValue = Self.UnsignedScalar(bitPattern: self)
@@ -198,7 +207,7 @@ extension FixedWidthInteger {
198207
}
199208

200209
extension ScalarType {
201-
/// Computes the high bits `Self.bitWidth` of `self * rhs`.
210+
/// Computes the high `Self.bitWidth` bits of `self * rhs`.
202211
/// - Parameter rhs: Multiplicand.
203212
/// - Returns: the high `Self.bitWidth` bits of `self * rhs`.
204213
@inlinable
@@ -390,6 +399,14 @@ extension FixedWidthInteger {
390399
return 1 &<< ((self &- 1).log2 &+ 1)
391400
}
392401

402+
/// The next power of two greater than or equal to this value.
403+
///
404+
/// This value must be positive.
405+
@inlinable public var previousPowerOfTwo: Self {
406+
precondition(self > 0)
407+
return 1 &<< (Self.bitWidth &- 1 - leadingZeroBitCount)
408+
}
409+
393410
/// Computes a modular multiplication.
394411
///
395412
/// Is not constant time. Use `ReduceModulus` for a constant-time alternative, which is also faster when the modulus
@@ -629,7 +646,7 @@ extension ScalarType {
629646
/// - Parameter modulus: Modulus.
630647
/// - Returns: Given `self` in `[0,modulus)`, returns `self % modulus` in `[-floor(modulus/2), floor(modulus-1)/2]`.
631648
@inlinable
632-
func remainderToCentered(modulus: Self) -> Self.SignedScalar {
649+
public func remainderToCentered(modulus: Self) -> Self.SignedScalar {
633650
let condition = constantTimeGreaterThan((modulus - 1) >> 1)
634651
let thenValue = Self.SignedScalar(self) - Self.SignedScalar(bitPattern: modulus)
635652
let elseValue = Self.SignedScalar(bitPattern: self)

Sources/HomomorphicEncryption/Util.swift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ extension Sequence {
4343
extension FixedWidthInteger {
4444
// not a constant time operation
4545
@inlinable
46-
func toRemainder(_ mod: Self) -> Self {
46+
func toRemainder(_ mod: Self, variableTime: Bool) -> Self {
47+
precondition(variableTime)
4748
precondition(mod > 0)
4849
var result = self % mod
4950
if result < 0 {

Sources/PrivateNearestNeighborsSearch/PlaintextMatrix.swift

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,19 +137,61 @@ public struct PlaintextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable,
137137
/// - context: Parameter context to encode the data with.
138138
/// - dimensions: Plaintext matrix dimensions.
139139
/// - packing: The packing with which the data is stored.
140-
/// - values: The data values to store in the plaintext matrix; stored in row-major format.
140+
/// - signedValues: The signed data values to store in the plaintext matrix; stored in row-major format.
141+
/// - reduce: If true, values are reduced into the correct range before encoding.
141142
/// - Throws: Error upon failure to create the plaitnext matrix.
142143
@inlinable
143144
public init(
144145
context: Context<Scheme>,
145146
dimensions: MatrixDimensions,
146147
packing: MatrixPacking,
147-
values: [some ScalarType]) throws
148+
signedValues: [Scheme.SignedScalar],
149+
reduce: Bool = false) throws where Format == Coeff
150+
{
151+
let modulus = Modulus(modulus: context.plaintextModulus, variableTime: true)
152+
let centeredValues = if reduce {
153+
signedValues.map { value in
154+
Scheme.Scalar(modulus.reduce(value))
155+
}
156+
} else {
157+
signedValues.map { value in
158+
Scheme.Scalar(value.centeredToRemainder(modulus: modulus.modulus))
159+
}
160+
}
161+
try self.init(
162+
context: context,
163+
dimensions: dimensions,
164+
packing: packing,
165+
values: centeredValues,
166+
reduce: false)
167+
}
168+
169+
/// Creates a new plaintext matrix.
170+
/// - Parameters:
171+
/// - context: Parameter context to encode the data with.
172+
/// - dimensions: Plaintext matrix dimensions.
173+
/// - packing: The packing with which the data is stored.
174+
/// - values: The data values to store in the plaintext matrix; stored in row-major format.
175+
/// - reduce: If true, values are reduced into the correct range before encoding.
176+
/// - Throws: Error upon failure to create the plaitnext matrix.
177+
@inlinable
178+
init(
179+
context: Context<Scheme>,
180+
dimensions: MatrixDimensions,
181+
packing: MatrixPacking,
182+
values: [Scheme.Scalar],
183+
reduce: Bool = false) throws
148184
where Format == Coeff
149185
{
150186
guard values.count == dimensions.count, !values.isEmpty else {
151187
throw PnnsError.wrongEncodingValuesCount(got: values.count, expected: values.count)
152188
}
189+
var values = values
190+
if reduce {
191+
let modulus = Modulus(modulus: context.plaintextModulus, variableTime: true)
192+
values = values.map { value in modulus.reduce(value) }
193+
}
194+
153195
switch packing {
154196
case .denseColumn:
155197
let plaintexts = try PlaintextMatrix.denseColumnPlaintexts(
@@ -421,7 +463,7 @@ public struct PlaintextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable,
421463
/// - Returns: The stored data values in row-major format.
422464
/// - Throws: Error upon failure to unpack the matrix.
423465
@inlinable
424-
func unpack<V: ScalarType>() throws -> [V] where Format == Coeff {
466+
func unpack() throws -> [Scheme.Scalar] where Format == Coeff {
425467
switch packing {
426468
case .denseColumn:
427469
return try unpackDenseColumn()
@@ -433,6 +475,17 @@ public struct PlaintextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable,
433475
}
434476
}
435477

478+
/// Unpacks the plaintext matrix into signed values.
479+
/// - Returns: The stored data values in row-major format.
480+
/// - Throws: Error upon failure to unpack the matrix.
481+
@inlinable
482+
func unpack() throws -> [Scheme.SignedScalar] where Format == Coeff {
483+
let unsigned: [Scheme.Scalar] = try unpack()
484+
return unsigned.map { unsigned in
485+
unsigned.remainderToCentered(modulus: context.plaintextModulus)
486+
}
487+
}
488+
436489
/// Unpacks a plaintext matrix with `denseColumn` packing.
437490
/// - Returns: The stored data values in row-major format.
438491
/// - Throws: Error upon failure to unpack the matrix.

0 commit comments

Comments
 (0)