Skip to content

Commit b11e47a

Browse files
committed
feat(shortint): introduce the KS32 atomic pattern
1 parent b301122 commit b11e47a

File tree

8 files changed

+601
-9
lines changed

8 files changed

+601
-9
lines changed

tfhe/src/core_crypto/commons/math/random/gaussian.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ impl Gaussian<f64> {
3030
}
3131
}
3232

33-
pub fn standard_dev(&self) -> StandardDev {
33+
pub const fn standard_dev(&self) -> StandardDev {
3434
StandardDev(self.std)
3535
}
3636
}

tfhe/src/core_crypto/entities/lwe_keyswitch_key.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,9 @@ pub struct LweKeyswitchKeyConformanceParams {
443443
pub ciphertext_modulus: CiphertextModulus<u64>,
444444
}
445445

446-
impl<C: Container<Element = u64>> ParameterSetConformant for LweKeyswitchKey<C> {
446+
impl<Scalar: UnsignedInteger, C: Container<Element = Scalar>> ParameterSetConformant
447+
for LweKeyswitchKey<C>
448+
{
447449
type ParameterSet = LweKeyswitchKeyConformanceParams;
448450

449451
fn is_conformant(&self, parameter_set: &Self::ParameterSet) -> bool {
@@ -455,7 +457,11 @@ impl<C: Container<Element = u64>> ParameterSetConformant for LweKeyswitchKey<C>
455457
ciphertext_modulus,
456458
} = self;
457459

458-
*ciphertext_modulus == parameter_set.ciphertext_modulus
460+
let Ok(parameters_modulus) = parameter_set.ciphertext_modulus.try_to() else {
461+
return false;
462+
};
463+
464+
*ciphertext_modulus == parameters_modulus
459465
&& data.container_len()
460466
== parameter_set.input_lwe_dimension.0
461467
* lwe_keyswitch_key_input_key_element_encrypted_size(
Lines changed: 359 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,359 @@
1+
use serde::{Deserialize, Serialize};
2+
use tfhe_csprng::seeders::Seed;
3+
use tfhe_versionable::NotVersioned;
4+
5+
use crate::conformance::ParameterSetConformant;
6+
use crate::core_crypto::prelude::{
7+
allocate_and_generate_new_lwe_keyswitch_key, extract_lwe_sample_from_glwe_ciphertext,
8+
keyswitch_lwe_ciphertext_with_scalar_change, CiphertextModulus as CoreCiphertextModulus,
9+
LweCiphertext, LweCiphertextOwned, LweDimension, LweKeyswitchKeyConformanceParams,
10+
LweKeyswitchKeyOwned, LweSecretKey, MonomialDegree, MsDecompressionType,
11+
};
12+
use crate::shortint::ciphertext::{CompressedModulusSwitchedCiphertext, NoiseLevel};
13+
use crate::shortint::engine::ShortintEngine;
14+
use crate::shortint::oprf::generate_pseudo_random_from_pbs;
15+
use crate::shortint::server_key::{
16+
apply_blind_rotate_no_ms_noise_reduction, decompress_and_apply_lookup_table,
17+
switch_modulus_and_compress, LookupTableOwned, LookupTableSize, ManyLookupTableOwned,
18+
PBSConformanceParams, ShortintBootstrappingKey,
19+
};
20+
use crate::shortint::{Ciphertext, CiphertextModulus, ClientKey, PBSParameters};
21+
22+
use super::{apply_programmable_bootstrap, AtomicPattern, AtomicPatternKind, AtomicPatternMut};
23+
24+
/// The definition of the server key elements used in the [`KeySwitch32`] atomic
25+
/// pattern
26+
///
27+
/// [`KeySwitch32`]: AtomicPatternKind::KeySwitch32
28+
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, NotVersioned)] // TODO: Versionize
29+
pub struct KS32AtomicPatternServerKey {
30+
pub key_switching_key: LweKeyswitchKeyOwned<u32>,
31+
pub bootstrapping_key: ShortintBootstrappingKey<u32>,
32+
}
33+
34+
impl ParameterSetConformant for KS32AtomicPatternServerKey {
35+
type ParameterSet = PBSParameters;
36+
37+
fn is_conformant(&self, parameter_set: &Self::ParameterSet) -> bool {
38+
let Self {
39+
key_switching_key,
40+
bootstrapping_key,
41+
} = self;
42+
43+
let params: PBSConformanceParams = parameter_set.into();
44+
45+
let pbs_key_ok = bootstrapping_key.is_conformant(&params);
46+
47+
let param: LweKeyswitchKeyConformanceParams = parameter_set.into();
48+
49+
let ks_key_ok = key_switching_key.is_conformant(&param);
50+
51+
pbs_key_ok && ks_key_ok
52+
}
53+
}
54+
55+
impl KS32AtomicPatternServerKey {
56+
pub fn new(cks: &ClientKey, engine: &mut ShortintEngine) -> Self {
57+
let params = &cks.parameters;
58+
59+
let pbs_params = params.ks32_parameters().unwrap();
60+
61+
let in_key = LweSecretKey::from_container(
62+
cks.small_lwe_secret_key()
63+
.as_ref()
64+
.iter()
65+
.copied()
66+
.map(|x| x as u32)
67+
.collect::<Vec<_>>(),
68+
);
69+
70+
let out_key = &cks.glwe_secret_key;
71+
72+
let bootstrapping_key_base =
73+
engine.new_bootstrapping_key_ks32(pbs_params, &in_key, out_key);
74+
75+
// Creation of the key switching key
76+
let key_switching_key = allocate_and_generate_new_lwe_keyswitch_key(
77+
&cks.large_lwe_secret_key(),
78+
&in_key,
79+
params.ks_base_log(),
80+
params.ks_level(),
81+
pbs_params.lwe_noise_distribution(),
82+
CoreCiphertextModulus::new_native(), // Does it make sense to parametrize this ?
83+
&mut engine.encryption_generator,
84+
);
85+
86+
Self::from_raw_parts(key_switching_key, bootstrapping_key_base)
87+
}
88+
89+
pub fn from_raw_parts(
90+
key_switching_key: LweKeyswitchKeyOwned<u32>,
91+
bootstrapping_key: ShortintBootstrappingKey<u32>,
92+
) -> Self {
93+
assert_eq!(
94+
key_switching_key.input_key_lwe_dimension(),
95+
bootstrapping_key.output_lwe_dimension(),
96+
"Mismatch between the input LweKeyswitchKey LweDimension ({:?}) \
97+
and the ShortintBootstrappingKey output LweDimension ({:?})",
98+
key_switching_key.input_key_lwe_dimension(),
99+
bootstrapping_key.output_lwe_dimension()
100+
);
101+
102+
assert_eq!(
103+
key_switching_key.output_key_lwe_dimension(),
104+
bootstrapping_key.input_lwe_dimension(),
105+
"Mismatch between the output LweKeyswitchKey LweDimension ({:?}) \
106+
and the ShortintBootstrappingKey input LweDimension ({:?})",
107+
key_switching_key.output_key_lwe_dimension(),
108+
bootstrapping_key.input_lwe_dimension()
109+
);
110+
111+
Self {
112+
key_switching_key,
113+
bootstrapping_key,
114+
}
115+
}
116+
117+
pub fn intermediate_lwe_dimension(&self) -> LweDimension {
118+
self.key_switching_key.output_key_lwe_dimension()
119+
}
120+
}
121+
122+
impl AtomicPattern for KS32AtomicPatternServerKey {
123+
fn ciphertext_lwe_dimension(&self) -> LweDimension {
124+
self.key_switching_key.input_key_lwe_dimension()
125+
}
126+
127+
fn ciphertext_modulus(&self) -> CiphertextModulus {
128+
self.key_switching_key
129+
.ciphertext_modulus()
130+
.try_to()
131+
// CiphertextModulus::try_to fails if target scalar is smaller than the input one, we
132+
// know that it is not the case so it is ok to unwrap
133+
.unwrap()
134+
}
135+
136+
fn ciphertext_decompression_method(&self) -> MsDecompressionType {
137+
match &self.bootstrapping_key {
138+
ShortintBootstrappingKey::Classic { .. } => MsDecompressionType::ClassicPbs,
139+
ShortintBootstrappingKey::MultiBit { fourier_bsk, .. } => {
140+
MsDecompressionType::MultiBitPbs(fourier_bsk.grouping_factor())
141+
}
142+
}
143+
}
144+
145+
fn apply_lookup_table_assign(&self, ct: &mut Ciphertext, acc: &LookupTableOwned) {
146+
ShortintEngine::with_thread_local_mut(|engine| {
147+
let (mut ciphertext_buffer, buffers) = engine.get_buffers(
148+
self.intermediate_lwe_dimension(),
149+
CoreCiphertextModulus::new_native(),
150+
);
151+
152+
keyswitch_lwe_ciphertext_with_scalar_change(
153+
&self.key_switching_key,
154+
&ct.ct,
155+
&mut ciphertext_buffer,
156+
);
157+
158+
apply_programmable_bootstrap(
159+
&self.bootstrapping_key,
160+
&ciphertext_buffer,
161+
&mut ct.ct,
162+
&acc.acc,
163+
buffers,
164+
);
165+
});
166+
}
167+
168+
fn apply_many_lookup_table(
169+
&self,
170+
ct: &Ciphertext,
171+
acc: &ManyLookupTableOwned,
172+
) -> Vec<Ciphertext> {
173+
self.keyswitch_programmable_bootstrap_many_lut(ct, acc)
174+
}
175+
176+
fn lookup_table_size(&self) -> LookupTableSize {
177+
LookupTableSize::new(
178+
self.bootstrapping_key.glwe_size(),
179+
self.bootstrapping_key.polynomial_size(),
180+
)
181+
}
182+
183+
fn kind(&self) -> AtomicPatternKind {
184+
AtomicPatternKind::KeySwitch32
185+
}
186+
187+
fn deterministic_execution(&self) -> bool {
188+
self.bootstrapping_key.deterministic_pbs_execution()
189+
}
190+
191+
fn generate_oblivious_pseudo_random(
192+
&self,
193+
seed: Seed,
194+
random_bits_count: u64,
195+
full_bits_count: u64,
196+
) -> LweCiphertextOwned<u64> {
197+
generate_pseudo_random_from_pbs(
198+
&self.bootstrapping_key,
199+
seed,
200+
random_bits_count,
201+
full_bits_count,
202+
self.ciphertext_modulus(),
203+
)
204+
}
205+
206+
fn switch_modulus_and_compress(&self, ct: &Ciphertext) -> CompressedModulusSwitchedCiphertext {
207+
let compressed_modulus_switched_lwe_ciphertext =
208+
ShortintEngine::with_thread_local_mut(|engine| {
209+
let (mut ciphertext_buffer, _) = engine.get_buffers(
210+
self.intermediate_lwe_dimension(),
211+
CoreCiphertextModulus::new_native(),
212+
);
213+
214+
keyswitch_lwe_ciphertext_with_scalar_change(
215+
&self.key_switching_key,
216+
&ct.ct,
217+
&mut ciphertext_buffer,
218+
);
219+
switch_modulus_and_compress(ciphertext_buffer.as_view(), &self.bootstrapping_key)
220+
});
221+
222+
CompressedModulusSwitchedCiphertext {
223+
compressed_modulus_switched_lwe_ciphertext,
224+
degree: ct.degree,
225+
message_modulus: ct.message_modulus,
226+
carry_modulus: ct.carry_modulus,
227+
atomic_pattern: ct.atomic_pattern,
228+
}
229+
}
230+
231+
fn decompress_and_apply_lookup_table(
232+
&self,
233+
compressed_ct: &CompressedModulusSwitchedCiphertext,
234+
lut: &LookupTableOwned,
235+
) -> Ciphertext {
236+
let mut output = LweCiphertext::from_container(
237+
vec![0; self.ciphertext_lwe_dimension().to_lwe_size().0],
238+
self.ciphertext_modulus(),
239+
);
240+
241+
ShortintEngine::with_thread_local_mut(|engine| {
242+
let (mut ciphertext_buffer, buffers) =
243+
engine.get_buffers(self.intermediate_lwe_dimension(), self.ciphertext_modulus());
244+
245+
decompress_and_apply_lookup_table(
246+
compressed_ct,
247+
&lut.acc,
248+
&self.bootstrapping_key,
249+
&mut ciphertext_buffer,
250+
buffers,
251+
);
252+
253+
output
254+
.as_mut()
255+
.copy_from_slice(ciphertext_buffer.into_container())
256+
});
257+
258+
Ciphertext::new(
259+
output,
260+
lut.degree,
261+
NoiseLevel::NOMINAL,
262+
compressed_ct.message_modulus,
263+
compressed_ct.carry_modulus,
264+
compressed_ct.atomic_pattern,
265+
)
266+
}
267+
268+
fn prepare_for_noise_squashing(&self, ct: &Ciphertext) -> LweCiphertextOwned<u64> {
269+
let mut after_ks_ct = LweCiphertext::new(
270+
0,
271+
self.key_switching_key.output_lwe_size(),
272+
self.key_switching_key.ciphertext_modulus(),
273+
);
274+
275+
keyswitch_lwe_ciphertext_with_scalar_change(
276+
&self.key_switching_key,
277+
&ct.ct,
278+
&mut after_ks_ct,
279+
);
280+
281+
let mut scalar_64_ct = LweCiphertext::new(
282+
0u64,
283+
self.key_switching_key.output_lwe_size(),
284+
self.key_switching_key
285+
.ciphertext_modulus()
286+
.try_to()
287+
.unwrap(), // Ok to unwrap because we go from 32 to 64b
288+
);
289+
290+
for (coeff64, coeff32) in scalar_64_ct
291+
.as_mut()
292+
.iter_mut()
293+
.zip(after_ks_ct.as_ref().iter())
294+
{
295+
*coeff64 = *coeff32 as u64;
296+
}
297+
298+
scalar_64_ct
299+
}
300+
}
301+
302+
impl AtomicPatternMut for KS32AtomicPatternServerKey {
303+
fn set_deterministic_execution(&mut self, new_deterministic_execution: bool) {
304+
self.bootstrapping_key
305+
.set_deterministic_pbs_execution(new_deterministic_execution)
306+
}
307+
}
308+
309+
impl KS32AtomicPatternServerKey {
310+
pub(crate) fn keyswitch_programmable_bootstrap_many_lut(
311+
&self,
312+
ct: &Ciphertext,
313+
lut: &ManyLookupTableOwned,
314+
) -> Vec<Ciphertext> {
315+
let mut acc = lut.acc.clone();
316+
317+
ShortintEngine::with_thread_local_mut(|engine| {
318+
// Compute the programmable bootstrapping with fixed test polynomial
319+
let (mut ciphertext_buffer, buffers) = engine.get_buffers(
320+
self.intermediate_lwe_dimension(),
321+
CoreCiphertextModulus::new_native(),
322+
);
323+
324+
// Compute a key switch
325+
keyswitch_lwe_ciphertext_with_scalar_change(
326+
&self.key_switching_key,
327+
&ct.ct,
328+
&mut ciphertext_buffer,
329+
);
330+
331+
apply_blind_rotate_no_ms_noise_reduction(
332+
&self.bootstrapping_key,
333+
&ciphertext_buffer.as_view(),
334+
&mut acc,
335+
buffers,
336+
);
337+
});
338+
339+
// The accumulator has been rotated, we can now proceed with the various sample extractions
340+
let function_count = lut.function_count();
341+
let mut outputs = Vec::with_capacity(function_count);
342+
343+
for (fn_idx, output_degree) in lut.per_function_output_degree.iter().enumerate() {
344+
let monomial_degree = MonomialDegree(fn_idx * lut.sample_extraction_stride);
345+
let mut output_shortint_ct = ct.clone();
346+
347+
extract_lwe_sample_from_glwe_ciphertext(
348+
&acc,
349+
&mut output_shortint_ct.ct,
350+
monomial_degree,
351+
);
352+
353+
output_shortint_ct.degree = *output_degree;
354+
outputs.push(output_shortint_ct);
355+
}
356+
357+
outputs
358+
}
359+
}

0 commit comments

Comments
 (0)