diff --git a/tfhe/src/core_crypto/algorithms/test/noise_distribution/lwe_hpu_noise.rs b/tfhe/src/core_crypto/algorithms/test/noise_distribution/lwe_hpu_noise.rs index 2270224bce..62028c3570 100644 --- a/tfhe/src/core_crypto/algorithms/test/noise_distribution/lwe_hpu_noise.rs +++ b/tfhe/src/core_crypto/algorithms/test/noise_distribution/lwe_hpu_noise.rs @@ -1,15 +1,14 @@ use super::*; use crate::core_crypto::commons::math::ntt::ntt64::Ntt64; -use crate::core_crypto::commons::test_tools::{torus_modular_diff, variance}; +use crate::core_crypto::commons::test_tools::{ + arithmetic_mean, normality_test_f64, torus_modular_diff, variance, +}; use std::io; // This is 1 / 16 which is exactly representable in an f64 (even an f32) // 1 / 32 is too strict and fails the tests const RELATIVE_TOLERANCE: f64 = 0.0625; -const NB_HPU_TESTS: usize = 5; -const NB_PBS: usize = 200; - #[derive(Clone, Copy)] pub struct HpuTestParams { pub lwe_dimension: LweDimension, @@ -111,6 +110,25 @@ pub const HPU_TEST_PARAMS_4_BITS_HPU_64_KS_21_132_TUNIFORM: HpuTestParams = HpuT ntt_modulus: 18446744069414584321, }; +#[allow(unused)] +pub const HPU_TEST_PARAMS_4_BITS_HPU_64_KS_21_132_TUNIFORM_2M128: HpuTestParams = HpuTestParams { + lwe_dimension: LweDimension(879), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(2048), + lwe_noise_distribution: DynamicDistribution::new_t_uniform(3), + glwe_noise_distribution: DynamicDistribution::new_t_uniform(17), + pbs_base_log: DecompositionBaseLog(23), + pbs_level: DecompositionLevelCount(1), + ks_level: DecompositionLevelCount(8), + ks_base_log: DecompositionBaseLog(2), + message_modulus_log: CiphertextModulusLog(4), + ct_width: 64, + ksk_width: 21, + //norm2: 8, + norm2: 5, + ntt_modulus: 18446744069414584321, +}; + #[allow(unused)] pub const HPU_TEST_PARAMS_4_BITS_NATIVE_U64: HpuTestParams = HpuTestParams { lwe_dimension: LweDimension(742), @@ -164,8 +182,20 @@ pub fn get_modulo_value(modulus: &CiphertextModulus) -> u } } +#[derive(Clone, Copy, PartialEq)] +enum HpuNoiseMode { + Variance, + Normality, +} + //fn lwe_noise_distribution_hpu + CastFrom>( -fn hpu_noise_distribution(params: HpuTestParams) { +fn hpu_noise_distribution( + params: HpuTestParams, + max_msg_val: u64, + nb_tests: usize, + nb_pbs_per_test: usize, + test_mode: HpuNoiseMode, +) { let lwe_dimension = params.lwe_dimension; let glwe_dimension = params.glwe_dimension; let polynomial_size = params.polynomial_size; @@ -195,15 +225,19 @@ fn hpu_noise_distribution(params: HpuTestParams) { let mut rsc = TestResources::new(); let msg_modulus = 1 << message_modulus_log.0; - let mut msg: u64 = msg_modulus; + assert!(max_msg_val <= msg_modulus, "Cannot start with msg val {max_msg_val:?} with a message_modulus_log of {message_modulus_log:?}"); + let mut msg: u64 = max_msg_val; let delta: u64 = encoding_with_padding / msg_modulus; let ks_delta: u64 = ksk_encoding_with_padding / msg_modulus; let norm2 = params.norm2; - let num_samples = NB_PBS * NB_HPU_TESTS * (msg as usize); + let num_samples = nb_pbs_per_test * nb_tests * (msg as usize); let mut noise_samples = (0..4) .map(|_| Vec::with_capacity(num_samples)) .collect::>(); + let mut normality_test_samples: Vec = Vec::with_capacity(nb_pbs_per_test); + let mut normality_check_result: Vec = Vec::with_capacity(nb_pbs_per_test); + let mut expvalue_score_result: Vec = Vec::with_capacity(nb_pbs_per_test); println!("ciphertext_modulus {ciphertext_modulus:?} ksk_modulus {ksk_modulus:?} message_modulus_log {message_modulus_log:?} encoding_with_padding {encoding_with_padding } expected_variance {expected_variance:?} msg_modulus {msg_modulus} msg {msg} delta {delta}"); let f = |x: u64| x.wrapping_rem(msg_modulus); @@ -344,7 +378,7 @@ fn hpu_noise_distribution(params: HpuTestParams) { while msg != 0 { msg = msg.wrapping_sub(1); - for i in 0..NB_HPU_TESTS { + for i in 0..nb_tests { // re-generate keys generate_binary_lwe_secret_key(&mut lwe_sk, &mut rsc.secret_random_generator); generate_binary_glwe_secret_key(&mut glwe_sk, &mut rsc.secret_random_generator); @@ -413,8 +447,9 @@ fn hpu_noise_distribution(params: HpuTestParams) { let torus_diff = torus_modular_diff(plaintext.0, decrypted.0, ciphertext_modulus); noise_samples[0].push(torus_diff); + normality_test_samples.clear(); - for j in 0..NB_PBS { + for j in 0..nb_pbs_per_test { // b = b - (Delta * msg) to have an encryption of 0 lwe_ciphertext_plaintext_sub_assign(&mut ct, plaintext); @@ -479,6 +514,31 @@ fn hpu_noise_distribution(params: HpuTestParams) { }; noise_samples[2].push(torus_diff); + normality_test_samples.push(torus_diff); + //println!("added in normality_test_samples: {torus_diff:?}"); + if (j == nb_pbs_per_test - 1) && (test_mode == HpuNoiseMode::Normality) { + let sample_set_mean = arithmetic_mean(&normality_test_samples); + let sample_set_var = variance(&normality_test_samples); + let sample_set_score = sample_set_mean / sample_set_var.get_standard_dev().0 + * f64::sqrt(normality_test_samples.len() as f64); + if (-1.96..1.96).contains(&sample_set_score) { + // score is good, it is a success + expvalue_score_result.push(0.0); + } else { + // if score is too high or too low it means expected value is not + // near enough from 0, it is a failure + expvalue_score_result.push(1.0); + } + + if normality_test_f64(&normality_test_samples, 0.05).null_hypothesis_is_valid { + // If we are normal return 0, it's not a failure + println!("normality_test_f64 returned 0.0 mean {sample_set_mean:?} var {sample_set_var:?} score {sample_set_score:?}"); + normality_check_result.push(0.0); + } else { + println!("normality_test_f64 returned 1.0 mean {sample_set_mean:?} var {sample_set_var:?} score {sample_set_score:?}"); + normality_check_result.push(1.0); + } + } // Compute PBS with NTT programmable_bootstrap_ntt64_bnf_lwe_ciphertext_mem_optimized( @@ -502,136 +562,193 @@ fn hpu_noise_distribution(params: HpuTestParams) { assert_eq!(decoded_pbs, f(msg)); let torus_diff = torus_modular_diff(plaintext.0, decrypted_pbs.0, ciphertext_modulus); - println!("after pbs (msg={msg},test_nb={i}/{NB_HPU_TESTS},pbs_nb={j}/{NB_PBS}): plaintext {:?} post pbs {:?} torus_diff {:?}", plaintext.0, decrypted_pbs.0, torus_diff); + if test_mode == HpuNoiseMode::Variance { + println!("after pbs (msg={msg},test_nb={i}/{nb_tests},pbs_nb={j}/{nb_pbs_per_test}): plaintext {:?} post pbs {:?} torus_diff {:?}", plaintext.0, decrypted_pbs.0, torus_diff); + } noise_samples[3].push(torus_diff); } } } - let encryption_variance = variance(&noise_samples[0]); - let bynorm2_variance = variance(&noise_samples[1]); - let after_ks_variance = variance(&noise_samples[2]); - let after_pbs_variance = variance(&noise_samples[3]); - println!( - "exp encrypt var {:?} encrypt var {:?} bynorm2 var {} after_ks_variance {} after_pbs_variance {:?}", - expected_variance.0, - encryption_variance.0, - bynorm2_variance.0, - after_ks_variance.0, - after_pbs_variance.0 - ); - // variance after *norm2 must be around (exp_pbs_variance)*(norm2**2) - // variance after KS must be around (exp_pbs_variance)*(norm2**2)+exp_add_ks_variance - // variance after PBS must be around (exp_pbs_variance) - let expected_bynorm2_variance = Variance(exp_pbs_variance.0 * (norm2 as f64).powf(2.0)); - let expected_after_ks_variance = Variance(expected_bynorm2_variance.0 + exp_add_ks_variance.0); - - let mut wtr = csv::Writer::from_writer(io::stdout()); - let _ = wtr.write_record([ - "data type", - "encrypt exp", - "encrypt", - "post *norm2", - "post KS", - "theo KS", - "post PBS", - "theo PBS", - ]); - let _ = wtr.write_record([ - "variances", - expected_variance.0.to_string().as_str(), - encryption_variance.0.to_string().as_str(), - bynorm2_variance.0.to_string().as_str(), - after_ks_variance.0.to_string().as_str(), - expected_after_ks_variance.0.to_string().as_str(), - after_pbs_variance.0.to_string().as_str(), - exp_pbs_variance.0.to_string().as_str(), - ]); - let _ = wtr.write_record([ - "std_dev", - expected_variance.get_standard_dev().0.to_string().as_str(), - encryption_variance - .get_standard_dev() - .0 - .to_string() - .as_str(), - bynorm2_variance.get_standard_dev().0.to_string().as_str(), - after_ks_variance.get_standard_dev().0.to_string().as_str(), - expected_after_ks_variance - .get_standard_dev() - .0 - .to_string() - .as_str(), - after_pbs_variance.get_standard_dev().0.to_string().as_str(), - exp_pbs_variance.get_standard_dev().0.to_string().as_str(), - ]); - let _ = wtr.write_record([ - "log2 std_dev + ct_w", - (expected_variance.get_log_standard_dev().0 + params.ct_width as f64) - .to_string() - .as_str(), - (encryption_variance.get_log_standard_dev().0 + params.ct_width as f64) - .to_string() - .as_str(), - (bynorm2_variance.get_log_standard_dev().0 + params.ct_width as f64) - .to_string() - .as_str(), - (after_ks_variance.get_log_standard_dev().0 + params.ct_width as f64) - .to_string() - .as_str(), - (expected_after_ks_variance.get_log_standard_dev().0 + params.ct_width as f64) - .to_string() - .as_str(), - (after_pbs_variance.get_log_standard_dev().0 + params.ct_width as f64) - .to_string() - .as_str(), - (exp_pbs_variance.get_log_standard_dev().0 + params.ct_width as f64) - .to_string() - .as_str(), - ]); - - let var_pbs_abs_diff = (exp_pbs_variance.0 - after_pbs_variance.0).abs(); - let pbs_tolerance_thres = RELATIVE_TOLERANCE * exp_pbs_variance.0; - - let var_ksk_abs_diff = (expected_after_ks_variance.0 - after_ks_variance.0).abs(); - let ks_tolerance_thres = RELATIVE_TOLERANCE * expected_after_ks_variance.0; - - let var_bynorm2_abs_diff = (expected_bynorm2_variance.0 - bynorm2_variance.0).abs(); - let bynorm2_tolerance_thres = RELATIVE_TOLERANCE * expected_bynorm2_variance.0; - - let after_pbs_errbit = params.ct_width as f64 + after_pbs_variance.get_log_standard_dev().0; - let after_pbs_exp_errbit = params.ct_width as f64 + exp_pbs_variance.get_log_standard_dev().0; - let bynorm2_errbit = params.ct_width as f64 + bynorm2_variance.get_log_standard_dev().0; - let bynorm2_exp_errbit = - params.ct_width as f64 + expected_bynorm2_variance.get_log_standard_dev().0; - let after_ks_errbit = params.ct_width as f64 + after_ks_variance.get_log_standard_dev().0; - let after_ks_exp_errbit = - params.ct_width as f64 + expected_after_ks_variance.get_log_standard_dev().0; - assert!( - var_pbs_abs_diff < pbs_tolerance_thres, - "Absolute difference for after PBS is incorrect: {var_pbs_abs_diff} >= {pbs_tolerance_thres}, \ - got variance: {after_pbs_variance:?} - log2(str_dev): {after_pbs_errbit:?}, \ - expected variance: {exp_pbs_variance:?} - log2(std_dev): {after_pbs_exp_errbit:?}" - ); - assert!( - var_bynorm2_abs_diff < bynorm2_tolerance_thres, - "Absolute difference for after *norm2 in incorrect: {var_bynorm2_abs_diff} >= {bynorm2_tolerance_thres} \ - got variance: {bynorm2_variance:?} - log2(str_dev): {bynorm2_errbit:?}, \ - expected variance: {expected_bynorm2_variance:?} - log2(std_dev): {bynorm2_exp_errbit:?}" - ); - assert!( - (var_ksk_abs_diff < ks_tolerance_thres) || (after_ks_errbit < after_ks_exp_errbit && (after_ks_exp_errbit - after_ks_errbit < 1f64)), - "Absolute difference for after KS is incorrect: {var_ksk_abs_diff} >= {ks_tolerance_thres} or more than 1 bit away \ - got variance: {after_ks_variance:?} - log2(str_dev): {after_ks_errbit:?}, \ - expected variance: {expected_after_ks_variance:?} - log2(std_dev): {after_ks_exp_errbit:?}" - ); + match test_mode { + HpuNoiseMode::Normality => { + let normality_failure_rate = arithmetic_mean(&normality_check_result); + println!("normality failure rate: {normality_failure_rate:?}"); + assert!( + normality_failure_rate <= 0.065, + "normality failure rate is not acceptable" + ); + let expvalue_score_failure_rate = arithmetic_mean(&expvalue_score_result); + println!("expected value score failure rate: {expvalue_score_failure_rate:?}"); + assert!( + expvalue_score_failure_rate <= 0.08, + "expected value score failure rate is not acceptable" + ); + } + HpuNoiseMode::Variance => { + let encryption_variance = variance(&noise_samples[0]); + let bynorm2_variance = variance(&noise_samples[1]); + let after_ks_variance = variance(&noise_samples[2]); + let after_pbs_variance = variance(&noise_samples[3]); + println!( + "exp encrypt var {:?} encrypt var {:?} bynorm2 var {} after_ks_variance {} after_pbs_variance {:?}", + expected_variance.0, + encryption_variance.0, + bynorm2_variance.0, + after_ks_variance.0, + after_pbs_variance.0 + ); + // variance after *norm2 must be around (exp_pbs_variance)*(norm2**2) + // variance after KS must be around (exp_pbs_variance)*(norm2**2)+exp_add_ks_variance + // variance after PBS must be around (exp_pbs_variance) + let expected_bynorm2_variance = Variance(exp_pbs_variance.0 * (norm2 as f64).powf(2.0)); + let expected_after_ks_variance = + Variance(expected_bynorm2_variance.0 + exp_add_ks_variance.0); + + let mut wtr = csv::Writer::from_writer(io::stdout()); + let _ = wtr.write_record([ + "data type", + "encrypt exp", + "encrypt", + "post *norm2", + "post KS", + "theo KS", + "post PBS", + "theo PBS", + ]); + let _ = wtr.write_record([ + "variances", + expected_variance.0.to_string().as_str(), + encryption_variance.0.to_string().as_str(), + bynorm2_variance.0.to_string().as_str(), + after_ks_variance.0.to_string().as_str(), + expected_after_ks_variance.0.to_string().as_str(), + after_pbs_variance.0.to_string().as_str(), + exp_pbs_variance.0.to_string().as_str(), + ]); + let _ = wtr.write_record([ + "std_dev", + expected_variance.get_standard_dev().0.to_string().as_str(), + encryption_variance + .get_standard_dev() + .0 + .to_string() + .as_str(), + bynorm2_variance.get_standard_dev().0.to_string().as_str(), + after_ks_variance.get_standard_dev().0.to_string().as_str(), + expected_after_ks_variance + .get_standard_dev() + .0 + .to_string() + .as_str(), + after_pbs_variance.get_standard_dev().0.to_string().as_str(), + exp_pbs_variance.get_standard_dev().0.to_string().as_str(), + ]); + let _ = wtr.write_record([ + "log2 std_dev + ct_w", + (expected_variance.get_log_standard_dev().0 + params.ct_width as f64) + .to_string() + .as_str(), + (encryption_variance.get_log_standard_dev().0 + params.ct_width as f64) + .to_string() + .as_str(), + (bynorm2_variance.get_log_standard_dev().0 + params.ct_width as f64) + .to_string() + .as_str(), + (after_ks_variance.get_log_standard_dev().0 + params.ct_width as f64) + .to_string() + .as_str(), + (expected_after_ks_variance.get_log_standard_dev().0 + params.ct_width as f64) + .to_string() + .as_str(), + (after_pbs_variance.get_log_standard_dev().0 + params.ct_width as f64) + .to_string() + .as_str(), + (exp_pbs_variance.get_log_standard_dev().0 + params.ct_width as f64) + .to_string() + .as_str(), + ]); + + let var_pbs_abs_diff = (exp_pbs_variance.0 - after_pbs_variance.0).abs(); + let pbs_tolerance_thres = RELATIVE_TOLERANCE * exp_pbs_variance.0; + + let var_ksk_abs_diff = (expected_after_ks_variance.0 - after_ks_variance.0).abs(); + let ks_tolerance_thres = RELATIVE_TOLERANCE * expected_after_ks_variance.0; + + let var_bynorm2_abs_diff = (expected_bynorm2_variance.0 - bynorm2_variance.0).abs(); + let bynorm2_tolerance_thres = RELATIVE_TOLERANCE * expected_bynorm2_variance.0; + + let after_pbs_errbit = + params.ct_width as f64 + after_pbs_variance.get_log_standard_dev().0; + let after_pbs_exp_errbit = + params.ct_width as f64 + exp_pbs_variance.get_log_standard_dev().0; + let bynorm2_errbit = params.ct_width as f64 + bynorm2_variance.get_log_standard_dev().0; + let bynorm2_exp_errbit = + params.ct_width as f64 + expected_bynorm2_variance.get_log_standard_dev().0; + let after_ks_errbit = + params.ct_width as f64 + after_ks_variance.get_log_standard_dev().0; + let after_ks_exp_errbit = + params.ct_width as f64 + expected_after_ks_variance.get_log_standard_dev().0; + assert!( + var_pbs_abs_diff < pbs_tolerance_thres, + "Absolute difference for after PBS is incorrect: {var_pbs_abs_diff} >= {pbs_tolerance_thres}, \ + got variance: {after_pbs_variance:?} - log2(str_dev): {after_pbs_errbit:?}, \ + expected variance: {exp_pbs_variance:?} - log2(std_dev): {after_pbs_exp_errbit:?}" + ); + assert!( + var_bynorm2_abs_diff < bynorm2_tolerance_thres, + "Absolute difference for after *norm2 in incorrect: {var_bynorm2_abs_diff} >= {bynorm2_tolerance_thres} \ + got variance: {bynorm2_variance:?} - log2(str_dev): {bynorm2_errbit:?}, \ + expected variance: {expected_bynorm2_variance:?} - log2(std_dev): {bynorm2_exp_errbit:?}" + ); + assert!( + (var_ksk_abs_diff < ks_tolerance_thres) || (after_ks_errbit < after_ks_exp_errbit && (after_ks_exp_errbit - after_ks_errbit < 1f64)), + "Absolute difference for after KS is incorrect: {var_ksk_abs_diff} >= {ks_tolerance_thres} or more than 1 bit away \ + got variance: {after_ks_variance:?} - log2(str_dev): {after_ks_errbit:?}, \ + expected variance: {expected_after_ks_variance:?} - log2(std_dev): {after_ks_exp_errbit:?}" + ); + } + } +} + +// Macro to generate tests for all parameter sets with arguments +macro_rules! create_parameterized_test_hpu{ + ($name:ident { $($param:ident),* $(,)? }, $max_msg:expr, $nb_test:expr, $nb_pbs:expr, $check_var:expr) => { + ::paste::paste! { + $( + #[test] + fn []() { + $name($param, $max_msg, $nb_test, $nb_pbs, $check_var) + } + )* + } + }; } -create_parameterized_test!(hpu_noise_distribution { - //HPU_TEST_PARAMS_4_BITS_NATIVE_U64, - //HPU_TEST_PARAMS_4_BITS_HPU_44_KS_21, - //HPU_TEST_PARAMS_4_BITS_HPU_64_KS_21, - HPU_TEST_PARAMS_4_BITS_HPU_64_KS_21_132_GAUSSIAN, - HPU_TEST_PARAMS_4_BITS_HPU_64_KS_21_132_TUNIFORM, - //HPU_TEST_PARAMS_4_BITS_NATIVE_U64_132_BITS_GAUSSIAN, -}); +static NORMALITY_MODE: HpuNoiseMode = HpuNoiseMode::Normality; +static VARIANCE_MODE: HpuNoiseMode = HpuNoiseMode::Variance; + +// tests with >= 16k samples for variance check +create_parameterized_test_hpu!( + hpu_noise_distribution { + HPU_TEST_PARAMS_4_BITS_HPU_64_KS_21_132_GAUSSIAN, + HPU_TEST_PARAMS_4_BITS_HPU_64_KS_21_132_TUNIFORM, + HPU_TEST_PARAMS_4_BITS_HPU_64_KS_21_132_TUNIFORM_2M128, + }, + 16, + 5, + 200, + VARIANCE_MODE +); + +// tests for checking normality & expected value after KS +create_parameterized_test_hpu!( + hpu_noise_distribution { + HPU_TEST_PARAMS_4_BITS_HPU_64_KS_21_132_TUNIFORM, + HPU_TEST_PARAMS_4_BITS_HPU_64_KS_21_132_TUNIFORM_2M128, + }, + 2, + 100, + 160, + NORMALITY_MODE +);