diff --git a/.github/workflows/build-claasp-base-image.yaml b/.github/workflows/build-claasp-base-image.yaml index d3379481c..f1a642fcc 100644 --- a/.github/workflows/build-claasp-base-image.yaml +++ b/.github/workflows/build-claasp-base-image.yaml @@ -2,14 +2,14 @@ name: Build and push Intel and M1 images for testing on: push: branches: - - main + - test_merging_main_to_develop jobs: build-image: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-latest, macos-latest] + os: [ubuntu-latest] steps: - name: Checkout uses: actions/checkout@v3 @@ -20,14 +20,6 @@ jobs: - name: Set up Docker Buildx uses: docker/setup-buildx-action@v2 - - name: Cache Docker layers - uses: actions/cache@v3 - with: - path: /tmp/.buildx-cache - key: ${{ runner.os }}-buildx-${{ github.sha }} - restore-keys: | - ${{ runner.os }}-buildx- - - name: Login Docker Hub uses: docker/login-action@v3 with: @@ -41,12 +33,19 @@ jobs: context: . file: ./docker/Dockerfile push: true - tags: ${{ matrix.os == 'ubuntu-latest' && 'tiicrc/claasp-base:latest' || 'tiicrc/claasp-m1-base:latest' }} + tags: tiicrc/claasp-base:latest1 target: claasp-base - cache-from: type=local,src=/tmp/.buildx-cache - cache-to: type=local,dest=/tmp/.buildx-cache-new,mode=max - - - name: Move cache - run: | - rm -rf /tmp/.buildx-cache - mv /tmp/.buildx-cache-new /tmp/.buildx-cache + cache-from: type=gha,scope=claasp-base + cache-to: type=gha,scope=claasp-base,mode=max + + - name: Build & Push (m1) + uses: docker/build-push-action@v4 + id: built-image-m1 + with: + context: . + file: ./docker/Dockerfile + push: true + tags: tiicrc/claasp-m1-base:latest1 + target: claasp-base + cache-from: type=gha,scope=claasp-base-m1 + cache-to: type=gha,scope=claasp-base-m1,mode=max \ No newline at end of file diff --git a/.gitignore b/.gitignore index 512fc097c..bef20ec45 100644 --- a/.gitignore +++ b/.gitignore @@ -118,3 +118,9 @@ local/ upstream/ SAGE_BIN_PATH + +# experiments codes +diff_ca_kalyna.py +diff_ca_sm4.py +diff_ca_sm4_2.py +output_kalyna.txt \ No newline at end of file diff --git a/claasp/DTOs/component_state.py b/claasp/DTOs/component_state.py index e3c0ea4a1..ee7f15d10 100644 --- a/claasp/DTOs/component_state.py +++ b/claasp/DTOs/component_state.py @@ -1,17 +1,16 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** diff --git a/claasp/DTOs/power_of_2_word_based_dto.py b/claasp/DTOs/power_of_2_word_based_dto.py index 20b84e284..9bca1e8ff 100644 --- a/claasp/DTOs/power_of_2_word_based_dto.py +++ b/claasp/DTOs/power_of_2_word_based_dto.py @@ -1,24 +1,22 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** class PowerOf2WordBasedDTO: - def __init__(self, word_size=None, fixed=False): self._word_size = word_size self._fixed = fixed diff --git a/claasp/cipher.py b/claasp/cipher.py index 0ea15fb8c..af7d4263e 100644 --- a/claasp/cipher.py +++ b/claasp/cipher.py @@ -1,49 +1,52 @@ # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** - +import importlib +import inspect import os import sys -import inspect import claasp from claasp import editor +from claasp.cipher_modules import code_generator +from claasp.cipher_modules import tester, evaluator +from claasp.cipher_modules.inverse_cipher import * +from claasp.cipher_modules.models.algebraic.algebraic_model import AlgebraicModel from claasp.components.cipher_output_component import CipherOutput from claasp.compound_xor_differential_cipher import convert_to_compound_xor_cipher from claasp.rounds import Rounds -from claasp.cipher_modules import tester, evaluator -from claasp.utils.templates import TemplateManager, CSVBuilder -from claasp.cipher_modules.models.algebraic.algebraic_model import AlgebraicModel -from claasp.cipher_modules import code_generator -import importlib -from claasp.cipher_modules.inverse_cipher import * +from claasp.name_mappings import CIPHER_INVERSE_SUFFIX + tii_path = inspect.getfile(claasp) tii_dir_path = os.path.dirname(tii_path) - -TII_C_LIB_PATH = f'{tii_dir_path}/cipher/' +TII_C_LIB_PATH = f"{tii_dir_path}/cipher/" class Cipher: - - - def __init__(self, family_name, cipher_type, cipher_inputs, - cipher_inputs_bit_size, cipher_output_bit_size, - cipher_reference_code=None): + def __init__( + self, + family_name, + cipher_type, + cipher_inputs, + cipher_inputs_bit_size, + cipher_output_bit_size, + cipher_reference_code=None, + ): """ Construct an instance of the Cipher class. @@ -151,6 +154,7 @@ def __init__(self, family_name, cipher_type, cipher_inputs, def __repr__(self): return self.id + def _are_there_not_forbidden_components(self, forbidden_types, forbidden_descriptions): return self._rounds.are_there_not_forbidden_components(forbidden_types, forbidden_descriptions) @@ -170,16 +174,29 @@ def add_FSR_component(self, input_id_links, input_bit_positions, output_bit_size return editor.add_FSR_component(self, input_id_links, input_bit_positions, output_bit_size, description) def add_intermediate_output_component(self, input_id_links, input_bit_positions, output_bit_size, output_tag): - return editor.add_intermediate_output_component(self, input_id_links, input_bit_positions, - output_bit_size, output_tag) + return editor.add_intermediate_output_component( + self, input_id_links, input_bit_positions, output_bit_size, output_tag + ) def add_linear_layer_component(self, input_id_links, input_bit_positions, output_bit_size, description): - return editor.add_linear_layer_component(self, input_id_links, input_bit_positions, - output_bit_size, description) - - def add_mix_column_component(self, input_id_links, input_bit_positions, output_bit_size, mix_column_description): - return editor.add_mix_column_component(self, input_id_links, input_bit_positions, - output_bit_size, mix_column_description) + return editor.add_linear_layer_component( + self, input_id_links, input_bit_positions, output_bit_size, description + ) + + def add_mix_column_component( + self, + input_id_links, + input_bit_positions, + output_bit_size, + mix_column_description, + ): + return editor.add_mix_column_component( + self, + input_id_links, + input_bit_positions, + output_bit_size, + mix_column_description, + ) def add_MODADD_component(self, input_id_links, input_bit_positions, output_bit_size, modulus=None): return editor.add_MODADD_component(self, input_id_links, input_bit_positions, output_bit_size, modulus) @@ -193,9 +210,20 @@ def add_NOT_component(self, input_id_links, input_bit_positions, output_bit_size def add_OR_component(self, input_id_links, input_bit_positions, output_bit_size): return editor.add_OR_component(self, input_id_links, input_bit_positions, output_bit_size) - def add_permutation_component(self, input_id_links, input_bit_positions, output_bit_size, permutation_description): - return editor.add_permutation_component(self, input_id_links, input_bit_positions, - output_bit_size, permutation_description) + def add_permutation_component( + self, + input_id_links, + input_bit_positions, + output_bit_size, + permutation_description, + ): + return editor.add_permutation_component( + self, + input_id_links, + input_bit_positions, + output_bit_size, + permutation_description, + ) def add_reverse_component(self, input_id_links, input_bit_positions, output_bit_size): return editor.add_reverse_component(self, input_id_links, input_bit_positions, output_bit_size) @@ -221,13 +249,35 @@ def add_SHIFT_component(self, input_id_links, input_bit_positions, output_bit_si def add_shift_rows_component(self, input_id_links, input_bit_positions, output_bit_size, parameter): return editor.add_shift_rows_component(self, input_id_links, input_bit_positions, output_bit_size, parameter) - def add_sigma_component(self, input_id_links, input_bit_positions, output_bit_size, rotation_amounts_parameter): - return editor.add_sigma_component(self, input_id_links, input_bit_positions, - output_bit_size, rotation_amounts_parameter) - - def add_theta_gaston_component(self, input_id_links, input_bit_positions, output_bit_size, rotation_amounts_parameter): - return editor.add_theta_gaston_component(self, input_id_links, input_bit_positions, - output_bit_size, rotation_amounts_parameter) + def add_sigma_component( + self, + input_id_links, + input_bit_positions, + output_bit_size, + rotation_amounts_parameter, + ): + return editor.add_sigma_component( + self, + input_id_links, + input_bit_positions, + output_bit_size, + rotation_amounts_parameter, + ) + + def add_theta_gaston_component( + self, + input_id_links, + input_bit_positions, + output_bit_size, + rotation_amounts_parameter, + ): + return editor.add_theta_gaston_component( + self, + input_id_links, + input_bit_positions, + output_bit_size, + rotation_amounts_parameter, + ) def add_theta_keccak_component(self, input_id_links, input_bit_positions, output_bit_size): return editor.add_theta_keccak_component(self, input_id_links, input_bit_positions, output_bit_size) @@ -236,31 +286,45 @@ def add_theta_xoodoo_component(self, input_id_links, input_bit_positions, output return editor.add_theta_xoodoo_component(self, input_id_links, input_bit_positions, output_bit_size) def add_variable_rotate_component(self, input_id_links, input_bit_positions, output_bit_size, parameter): - return editor.add_variable_rotate_component(self, input_id_links, input_bit_positions, - output_bit_size, parameter) + return editor.add_variable_rotate_component( + self, input_id_links, input_bit_positions, output_bit_size, parameter + ) def add_variable_shift_component(self, input_id_links, input_bit_positions, output_bit_size, parameter): - return editor.add_variable_shift_component(self, input_id_links, input_bit_positions, - output_bit_size, parameter) - - def add_word_permutation_component(self, input_id_links, input_bit_positions, - output_bit_size, permutation_description, word_size): - return editor.add_word_permutation_component(self, input_id_links, input_bit_positions, - output_bit_size, permutation_description, word_size) + return editor.add_variable_shift_component( + self, input_id_links, input_bit_positions, output_bit_size, parameter + ) + + def add_word_permutation_component( + self, + input_id_links, + input_bit_positions, + output_bit_size, + permutation_description, + word_size, + ): + return editor.add_word_permutation_component( + self, + input_id_links, + input_bit_positions, + output_bit_size, + permutation_description, + word_size, + ) def add_XOR_component(self, input_id_links, input_bit_positions, output_bit_size): return editor.add_XOR_component(self, input_id_links, input_bit_positions, output_bit_size) def as_python_dictionary(self): return { - 'cipher_id': self._id, - 'cipher_type': self._type, - 'cipher_inputs': self._inputs, - 'cipher_inputs_bit_size': self._inputs_bit_size, - 'cipher_output_bit_size': self._output_bit_size, - 'cipher_number_of_rounds': self.number_of_rounds, - 'cipher_rounds': self._rounds.rounds_as_python_dictionary(), - 'cipher_reference_code': self._reference_code + "cipher_id": self._id, + "cipher_type": self._type, + "cipher_inputs": self._inputs, + "cipher_inputs_bit_size": self._inputs_bit_size, + "cipher_output_bit_size": self._output_bit_size, + "cipher_number_of_rounds": self.number_of_rounds, + "cipher_rounds": self._rounds.rounds_as_python_dictionary(), + "cipher_reference_code": self._reference_code, } def component_from(self, round_number, index): @@ -276,7 +340,7 @@ def delete_generated_evaluate_c_shared_library(self): EXAMPLES:: -ds sage: from claasp.ciphers.toys.fancy_block_cipher import FancyBlockCipher as fancy + sage: from claasp.ciphers.toys.fancy_block_cipher import FancyBlockCipher as fancy sage: fancy().delete_generated_evaluate_c_shared_library() # doctest: +SKIP """ code_generator.delete_generated_evaluate_c_shared_library(self) @@ -530,7 +594,13 @@ def cipher_inverse(self): sage: cipher_inv.evaluate([ciphertext, key]) == plaintext True """ - inverted_cipher = Cipher(f"{self.id}{CIPHER_INVERSE_SUFFIX}", f"{self.type}", [], [], self.output_bit_size) + inverted_cipher = Cipher( + f"{self.id}{CIPHER_INVERSE_SUFFIX}", + f"{self.type}", + [], + [], + self.output_bit_size, + ) inverted_cipher_components = [] cipher_components_tmp = get_cipher_components(self) @@ -543,23 +613,39 @@ def cipher_inverse(self): for c in cipher_components_tmp: # print(c.id, "---------", len(cipher_components_tmp)) # OPTION 1 - Add components that are not invertible - if are_there_enough_available_inputs_to_evaluate_component(c, available_bits, all_equivalent_bits, - key_schedule_component_ids, self): + if are_there_enough_available_inputs_to_evaluate_component( + c, + available_bits, + all_equivalent_bits, + key_schedule_component_ids, + self, + ): # print("--------> evaluated") - inverted_component = evaluated_component(c, available_bits, key_schedule_component_ids, - all_equivalent_bits, self) + inverted_component = evaluated_component( + c, + available_bits, + key_schedule_component_ids, + all_equivalent_bits, + self, + ) update_available_bits_with_component_output_bits(c, available_bits, self) inverted_cipher_components.append(inverted_component) cipher_components_tmp.remove(c) # OPTION 2 - Add components that are invertible - elif (is_possibly_invertible_component(c) and are_there_enough_available_inputs_to_perform_inversion(c, - available_bits, - all_equivalent_bits, - self)) or ( - c.type == CIPHER_INPUT and (c.description[0] == INPUT_KEY or c.description[0] == INPUT_TWEAK)): + elif ( + is_possibly_invertible_component(c) + and are_there_enough_available_inputs_to_perform_inversion( + c, available_bits, all_equivalent_bits, self + ) + ) or (c.type == CIPHER_INPUT and (c.description[0] == INPUT_KEY or c.description[0] == INPUT_TWEAK)): # print("--------> inverted") - inverted_component = component_inverse(c, available_bits, all_equivalent_bits, - key_schedule_component_ids, self) + inverted_component = component_inverse( + c, + available_bits, + all_equivalent_bits, + key_schedule_component_ids, + self, + ) update_available_bits_with_component_input_bits(c, available_bits) update_available_bits_with_component_output_bits(c, available_bits, self) inverted_cipher_components.append(inverted_component) @@ -582,14 +668,14 @@ def cipher_inverse(self): inverted_cipher._rounds.round_at(0)._components.append(component) else: inverted_cipher._rounds.round_at(self.number_of_rounds - 1 - component.round)._components.append( - component) + component + ) sorted_inverted_cipher = sort_cipher_graph(inverted_cipher) return sorted_inverted_cipher def get_partial_cipher(self, start_round=None, end_round=None, keep_key_schedule=True): - if start_round is None: start_round = 0 if end_round is None: @@ -599,13 +685,19 @@ def get_partial_cipher(self, start_round=None, end_round=None, keep_key_schedule assert start_round <= end_round inputs = deepcopy(self.inputs) - partial_cipher = Cipher(f"{self.family_name}_partial_{start_round}_to_{end_round}", f"{self.type}", inputs, - self._inputs_bit_size, self.output_bit_size) + partial_cipher = Cipher( + f"{self.family_name}_partial_{start_round}_to_{end_round}", + f"{self.type}", + inputs, + self._inputs_bit_size, + self.output_bit_size, + ) for round in self.rounds_as_list: partial_cipher.rounds_as_list.append(deepcopy(round)) - removed_components_ids, intermediate_outputs = remove_components_from_rounds(partial_cipher, start_round, - end_round, keep_key_schedule) + removed_components_ids, intermediate_outputs = remove_components_from_rounds( + partial_cipher, start_round, end_round, keep_key_schedule + ) if start_round > 0: for input_type in set([input for input in self.inputs if INPUT_KEY not in input]): @@ -616,19 +708,29 @@ def get_partial_cipher(self, start_round=None, end_round=None, keep_key_schedule partial_cipher.inputs.insert(0, intermediate_outputs[start_round - 1].id) partial_cipher.inputs_bit_size.insert(0, intermediate_outputs[start_round - 1].output_bit_size) - update_input_links_from_rounds(partial_cipher.rounds_as_list[start_round:end_round + 1], - removed_components_ids, intermediate_outputs) + update_input_links_from_rounds( + partial_cipher.rounds_as_list[start_round : end_round + 1], + removed_components_ids, + intermediate_outputs, + ) if end_round < self.number_of_rounds - 1: removed_components_ids.append(CIPHER_OUTPUT) last_round = partial_cipher.rounds_as_list[end_round] for component in last_round.components: - if component.description == ['round_output']: + if component.description == ["round_output"]: last_round.remove_component(component) - new_cipher_output = Component(component.id, CIPHER_OUTPUT, - Input(component.output_bit_size, component.input_id_links, - component.input_bit_positions), - component.output_bit_size, [CIPHER_OUTPUT]) + new_cipher_output = Component( + component.id, + CIPHER_OUTPUT, + Input( + component.output_bit_size, + component.input_id_links, + component.input_bit_positions, + ), + component.output_bit_size, + [CIPHER_OUTPUT], + ) new_cipher_output.__class__ = CipherOutput last_round.add_component(new_cipher_output) @@ -639,22 +741,35 @@ def add_suffix_to_components(self, suffix, component_id_list=None): if component_id_list is None: component_id_list = self.get_all_components_ids() + self.inputs renamed_inputs = [f"{input}{suffix}" if input in component_id_list else input for input in self.inputs] - renamed_cipher = Cipher(f"{self.family_name}", f"{self.type}", renamed_inputs, - self.inputs_bit_size, self.output_bit_size) + renamed_cipher = Cipher( + f"{self.family_name}", + f"{self.type}", + renamed_inputs, + self.inputs_bit_size, + self.output_bit_size, + ) for round in self.rounds_as_list: renamed_cipher.add_round() for component_number in range(round.number_of_components): component = round.component_from(component_number) - renamed_input_id_links = [f"{id}{suffix}" if id in component_id_list else id for id in - component.input_id_links] + renamed_input_id_links = [ + f"{id}{suffix}" if id in component_id_list else id for id in component.input_id_links + ] if component.id in component_id_list: - renamed_component_id = f'{component.id}{suffix}' + renamed_component_id = f"{component.id}{suffix}" else: renamed_component_id = component.id - renamed_component = Component(renamed_component_id, component.type, - Input(component.input_bit_size, renamed_input_id_links, - component.input_bit_positions), - component.output_bit_size, component.description) + renamed_component = Component( + renamed_component_id, + component.type, + Input( + component.input_bit_size, + renamed_input_id_links, + component.input_bit_positions, + ), + component.output_bit_size, + component.description, + ) renamed_component.__class__ = component.__class__ renamed_cipher.rounds.current_round.add_component(renamed_component) @@ -686,9 +801,9 @@ def cipher_partial_inverse(self, start_round=None, end_round=None, keep_key_sche partial_cipher_inverse = partial_cipher.cipher_inverse() key_schedule_component_ids = get_key_schedule_component_ids(partial_cipher_inverse) - key_schedule_components = [partial_cipher_inverse.get_component_from_id(id) for id in key_schedule_component_ids - if - INPUT_KEY not in id] + key_schedule_components = [ + partial_cipher_inverse.get_component_from_id(id) for id in key_schedule_component_ids if INPUT_KEY not in id + ] if not keep_key_schedule: for current_round in partial_cipher_inverse.rounds_as_list: @@ -697,7 +812,14 @@ def cipher_partial_inverse(self, start_round=None, end_round=None, keep_key_sche return partial_cipher_inverse - def evaluate_vectorized(self, cipher_input, intermediate_output=False, verbosity=False, evaluate_api = False, bit_based = False): + def evaluate_vectorized( + self, + cipher_input, + intermediate_output=False, + verbosity=False, + evaluate_api=False, + bit_based=False, + ): """ Return the output of the cipher for multiple inputs. @@ -722,6 +844,7 @@ def evaluate_vectorized(self, cipher_input, intermediate_output=False, verbosity - ``evaluate_api`` -- **boolean** (default: `False`); if set to True, takes integer inputs (as the evaluate function) and returns integer inputs; it is expected that cipher.evaluate(x) == cipher.evaluate_vectorized(x, evaluate_api = True) is True. + EXAMPLES:: sage: import numpy as np @@ -741,10 +864,15 @@ def evaluate_vectorized(self, cipher_input, intermediate_output=False, verbosity sage: int.from_bytes(result[-1][1].tobytes(), byteorder='big') == C1Lib True """ - return evaluator.evaluate_vectorized(self, cipher_input, intermediate_output, verbosity, evaluate_api, bit_based) + return evaluator.evaluate_vectorized(self, cipher_input, intermediate_output, verbosity, evaluate_api) def evaluate_with_intermediate_outputs_continuous_diffusion_analysis( - self, cipher_input, sbox_precomputations, sbox_precomputations_mix_columns, verbosity=False): + self, + cipher_input, + sbox_precomputations, + sbox_precomputations_mix_columns, + verbosity=False, + ): """ Return the output of the continuous generalized cipher. @@ -774,7 +902,12 @@ def evaluate_with_intermediate_outputs_continuous_diffusion_analysis( True """ return evaluator.evaluate_with_intermediate_outputs_continuous_diffusion_analysis( - self, cipher_input, sbox_precomputations, sbox_precomputations_mix_columns, verbosity) + self, + cipher_input, + sbox_precomputations, + sbox_precomputations_mix_columns, + verbosity, + ) def generate_bit_based_c_code(self, intermediate_output=False, verbosity=False): """ @@ -948,8 +1081,14 @@ def is_andrx(self): sage: midori.is_andrx() False """ - forbidden_types = {'sbox', 'mix_column', 'linear_layer'} - forbidden_descriptions = {'OR', 'MODADD', 'MODSUB', 'SHIFT', 'SHIFT_BY_VARIABLE_AMOUNT'} + forbidden_types = {"sbox", "mix_column", "linear_layer"} + forbidden_descriptions = { + "OR", + "MODADD", + "MODSUB", + "SHIFT", + "SHIFT_BY_VARIABLE_AMOUNT", + } return self._are_there_not_forbidden_components(forbidden_types, forbidden_descriptions) @@ -968,8 +1107,14 @@ def is_arx(self): sage: midori.is_arx() False """ - forbidden_types = {'sbox', 'mix_column', 'linear_layer'} - forbidden_descriptions = {'OR', 'AND', 'MODSUB', 'SHIFT', 'SHIFT_BY_VARIABLE_AMOUNT'} + forbidden_types = {"sbox", "mix_column", "linear_layer"} + forbidden_descriptions = { + "OR", + "AND", + "MODSUB", + "SHIFT", + "SHIFT_BY_VARIABLE_AMOUNT", + } return self._are_there_not_forbidden_components(forbidden_types, forbidden_descriptions) @@ -1007,8 +1152,8 @@ def is_shift_arx(self): sage: xtea.is_shift_arx() True """ - forbidden_types = {'sbox', 'mix_column', 'linear_layer'} - forbidden_descriptions = {'AND', 'OR', 'MODSUB'} + forbidden_types = {SBOX, MIX_COLUMN, LINEAR_LAYER} + forbidden_descriptions = {"AND", "OR", "MODSUB"} return self._are_there_not_forbidden_components(forbidden_types, forbidden_descriptions) @@ -1027,10 +1172,21 @@ def is_spn(self): sage: aes.is_spn() True """ - spn_components = {CIPHER_OUTPUT, CONSTANT, INTERMEDIATE_OUTPUT, MIX_COLUMN, - SBOX, 'ROTATE', 'XOR'} - set_of_components, set_of_mix_column_sizes, set_of_rotate_and_shift_values, set_of_sbox_sizes = \ - self.get_sizes_of_components_by_type() + spn_components = { + CIPHER_OUTPUT, + CONSTANT, + INTERMEDIATE_OUTPUT, + MIX_COLUMN, + SBOX, + "ROTATE", + "XOR", + } + ( + set_of_components, + set_of_mix_column_sizes, + set_of_rotate_and_shift_values, + set_of_sbox_sizes, + ) = self.get_sizes_of_components_by_type() if (len(set_of_sbox_sizes) > 1) or (len(set_of_mix_column_sizes) > 1): return False sbox_size = 0 @@ -1055,20 +1211,20 @@ def get_model(self, technique, problem): - ``technique`` -- **string** ; sat, smt, milp or cp - ``problem`` -- **string** ; xor_differential, xor_linear, cipher_model (more to be added as more model types are added to the library) - """ - if technique == 'cp': - technique = 'mzn' - formalism = 'cp' - if problem == 'xor_differential': - constructor_name = f'{technique[0].capitalize()}{technique[1:]}XorDifferentialModel' + """ + if technique == "cp": + technique = "mzn" + formalism = "cp" + if problem == "xor_differential": + constructor_name = f"{technique[0].capitalize()}{technique[1:]}XorDifferentialModel" elif problem == "xor_linear": - constructor_name = f'{technique[0].capitalize()}{technique[1:]}XorLinearModel' - elif problem == 'cipher_model': - constructor_name = f'{technique[0].capitalize()}{technique[1:]}CipherModel' + constructor_name = f"{technique[0].capitalize()}{technique[1:]}XorLinearModel" + elif problem == "cipher_model": + constructor_name = f"{technique[0].capitalize()}{technique[1:]}CipherModel" - module_name = f'claasp.cipher_modules.models.{technique}.{technique}_models.{technique}_{problem}_model' - if technique == 'mzn': - module_name = f'claasp.cipher_modules.models.{formalism}.{technique}_models.{technique}_{problem}_model' + module_name = f"claasp.cipher_modules.models.{technique}.{technique}_models.{technique}_{problem}_model" + if technique == "mzn": + module_name = f"claasp.cipher_modules.models.{formalism}.{technique}_models.{technique}_{problem}_model" module = importlib.import_module(module_name) constructor = getattr(module, constructor_name) @@ -1086,15 +1242,25 @@ def get_sizes_of_components_by_type(self): set_of_mix_column_sizes.add(component.description[2]) if component.type == WORD_OPERATION: set_of_components.add(component.description[0]) - if component.description[0] == 'ROTATE' or component.description[0] == 'SHIFT': + if component.description[0] == "ROTATE" or component.description[0] == "SHIFT": set_of_rotate_and_shift_values.add(component.description[1]) else: set_of_components.add(component.type) - return set_of_components, set_of_mix_column_sizes, set_of_rotate_and_shift_values, set_of_sbox_sizes + return ( + set_of_components, + set_of_mix_column_sizes, + set_of_rotate_and_shift_values, + set_of_sbox_sizes, + ) def make_cipher_id(self): - return editor.make_cipher_id(self._family_name, self._inputs, self._inputs_bit_size, - self._output_bit_size, self.number_of_rounds) + return editor.make_cipher_id( + self._family_name, + self._inputs, + self._inputs_bit_size, + self._output_bit_size, + self.number_of_rounds, + ) def make_file_name(self): return editor.make_file_name(self._id) @@ -1207,8 +1373,8 @@ def print_as_python_dictionary(self): } """ print("cipher = {") - print("'cipher_id': '" + self._id + "',") - print("'cipher_type': '" + self._type + "',") + print(f"'cipher_id': '{self._id}',") + print(f"'cipher_type': '{self._type}',") print(f"'cipher_inputs': {self._inputs},") print(f"'cipher_inputs_bit_size': {self._inputs_bit_size},") print(f"'cipher_output_bit_size': {self._output_bit_size},") @@ -1241,7 +1407,7 @@ def print_as_python_dictionary_to_file(self, file_name=""): original_stdout = sys.stdout # Save a reference to the original standard output if file_name == "": file_name = self._file_name - with open(file_name, 'w') as f: + with open(file_name, "w") as f: sys.stdout = f # Change the standard output to the file we created. self.print_as_python_dictionary() sys.stdout = original_stdout # Reset the standard output to its original value @@ -1322,7 +1488,7 @@ def print_evaluation_python_code_to_file(self, file_name): """ original_stdout = sys.stdout # Save a reference to the original standard output - with open(file_name, 'w') as f: + with open(file_name, "w") as f: sys.stdout = f # Change the standard output to the file we created. self.print_evaluation_python_code() sys.stdout = original_stdout # Reset the standard output to its original value @@ -1486,49 +1652,85 @@ def find_impossible_property(self, type, technique="sat", solver="kissat", scena - ``technique`` -- **string**; {"sat", "smt", "milp", "cp"}: the technique to use for the search - ``solver`` -- **string**; the name of the solver to use for the search """ - from claasp.cipher_modules.models.utils import set_fixed_variables, integer_to_bit_list - model = self.get_model(technique, f'xor_{type}') - if type == 'differential': + from claasp.cipher_modules.models.utils import ( + set_fixed_variables, + integer_to_bit_list, + ) + + model = self.get_model(technique, f"xor_{type}") + if type == "differential": search_function = model.find_one_xor_differential_trail else: search_function = model.find_one_xor_linear_trail last_component_id = self.get_all_components()[-1].id impossible = [] inputs_dictionary = self.inputs_size_to_dict() - plain_bits = inputs_dictionary['plaintext'] - key_bits = inputs_dictionary['key'] + plain_bits = inputs_dictionary[INPUT_PLAINTEXT] + key_bits = inputs_dictionary[INPUT_KEY] if scenario == "single-key": # Fix the key difference to be zero, and the plaintext difference to be non-zero. for input_bit_position in range(plain_bits): for output_bit_position in range(plain_bits): fixed_values = [] - fixed_values.append(set_fixed_variables('key', 'equal', list(range(key_bits)), - integer_to_bit_list(0, key_bits, 'big'))) - fixed_values.append(set_fixed_variables('plaintext', 'equal', list(range(plain_bits)), - integer_to_bit_list(1 << input_bit_position, plain_bits, - 'big'))) - fixed_values.append(set_fixed_variables(last_component_id, 'equal', list(range(plain_bits)), - integer_to_bit_list(1 << output_bit_position, plain_bits, - 'big'))) + fixed_values.append( + set_fixed_variables( + INPUT_KEY, + "equal", + list(range(key_bits)), + integer_to_bit_list(0, key_bits, "big"), + ) + ) + fixed_values.append( + set_fixed_variables( + INPUT_PLAINTEXT, + "equal", + list(range(plain_bits)), + integer_to_bit_list(1 << input_bit_position, plain_bits, "big"), + ) + ) + fixed_values.append( + set_fixed_variables( + last_component_id, + "equal", + list(range(plain_bits)), + integer_to_bit_list(1 << output_bit_position, plain_bits, "big"), + ) + ) solution = search_function(fixed_values, solver_name=solver) - if solution['status'] == "UNSATISFIABLE": + if solution["status"] == "UNSATISFIABLE": impossible.append((1 << input_bit_position, 1 << output_bit_position)) elif scenario == "related-key": for input_bit_position in range(key_bits): for output_bit_position in range(plain_bits): fixed_values = [] - fixed_values.append(set_fixed_variables('key', 'equal', list(range(key_bits)), - integer_to_bit_list(1 << (input_bit_position), key_bits, - 'big'))) - fixed_values.append(set_fixed_variables('plaintext', 'equal', list(range(plain_bits)), - integer_to_bit_list(0, plain_bits, 'big'))) - - fixed_values.append(set_fixed_variables(last_component_id, 'equal', list(range(plain_bits)), - integer_to_bit_list(1 << output_bit_position, plain_bits, - 'big'))) + fixed_values.append( + set_fixed_variables( + INPUT_KEY, + "equal", + list(range(key_bits)), + integer_to_bit_list(1 << (input_bit_position), key_bits, "big"), + ) + ) + fixed_values.append( + set_fixed_variables( + INPUT_PLAINTEXT, + "equal", + list(range(plain_bits)), + integer_to_bit_list(0, plain_bits, "big"), + ) + ) + + fixed_values.append( + set_fixed_variables( + last_component_id, + "equal", + list(range(plain_bits)), + integer_to_bit_list(1 << output_bit_position, plain_bits, "big"), + ) + ) solution = search_function(fixed_values, solver_name=solver) - if solution['status'] == "UNSATISFIABLE": + if solution["status"] == "UNSATISFIABLE": impossible.append((1 << input_bit_position, 1 << output_bit_position)) return impossible @@ -1603,7 +1805,8 @@ def type(self): def create_networx_graph_from_input_ids(self): import networkx as nx - data = self.as_python_dictionary()['cipher_rounds'] + + data = self.as_python_dictionary()["cipher_rounds"] # Create a directed graph G = nx.DiGraph() @@ -1671,4 +1874,3 @@ def get_descendants_subgraph(G, start_nodes): def update_input_id_links_from_component_id(self, component_id, new_input_id_links): round_number = self.get_round_from_component_id(component_id) self._rounds.rounds[round_number].update_input_id_links_from_component_id(component_id, new_input_id_links) - diff --git a/claasp/cipher_modules/avalanche_tests.py b/claasp/cipher_modules/avalanche_tests.py index 6fd3b6afd..add90426b 100644 --- a/claasp/cipher_modules/avalanche_tests.py +++ b/claasp/cipher_modules/avalanche_tests.py @@ -322,8 +322,7 @@ def _generate_random_inputs(self, nb_samples): def _generate_avalanche_probability_vectors(self, dict_intermediate_output_names, inputs, evaluated_inputs, input_diff, index_of_specific_input): inputs_prime = self._generate_inputs_prime(index_of_specific_input, input_diff, inputs) - evaluated_inputs_prime = evaluator.evaluate_vectorized(self._cipher, inputs_prime, - intermediate_output=True, verbosity=False) + evaluated_inputs_prime = evaluator.evaluate_vectorized(self._cipher, inputs_prime, intermediate_output=True) intermediate_avalanche_probability_vectors = {} for intermediate_output_name in list(dict_intermediate_output_names.keys()): intermediate_avalanche_probability_vectors[intermediate_output_name] = \ diff --git a/claasp/cipher_modules/code_generator.py b/claasp/cipher_modules/code_generator.py index c757ac7e8..bfa797157 100644 --- a/claasp/cipher_modules/code_generator.py +++ b/claasp/cipher_modules/code_generator.py @@ -329,15 +329,6 @@ def prepare_input_bit_based_vectorized_python_code_string(component): return params -def constant_to_bitstring(val, output_size): - ret = [] - _val = int(val, 0) - for i in range(output_size): - ret.append((_val >> (output_size - 1 - i)) & 1) - - return ret - - def generate_byte_based_vectorized_python_code_string(cipher, store_intermediate_outputs=False, verbosity=False, integers_inputs_and_outputs = False): r""" Return string python code needed to evaluate a cipher using a vectorized implementation byte based oriented. diff --git a/claasp/cipher_modules/division_trail_search.py b/claasp/cipher_modules/division_trail_search.py deleted file mode 100644 index 06e5b41eb..000000000 --- a/claasp/cipher_modules/division_trail_search.py +++ /dev/null @@ -1,839 +0,0 @@ - -# **************************************************************************** -# Copyright 2023 Technology Innovation Institute -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . -# **************************************************************************** - -import time -from sage.crypto.sbox import SBox -from collections import Counter -from sage.rings.polynomial.pbori.pbori import BooleanPolynomialRing -from claasp.cipher_modules.graph_generator import create_networkx_graph_from_input_ids, _get_predecessors_subgraph -from claasp.cipher_modules.component_analysis_tests import binary_matrix_of_linear_component -from gurobipy import Model, GRB, Env -import os - -verbosity = False - -class MilpDivisionTrailModel(): - """ - - Given a number of rounds of a chosen cipher and a chosen output bit, this module produces a model that can either: - - obtain the ANF of this chosen output bit, - - find the degree of this ANF, - - or check the presence or absence of a specified monomial. - - This module can only be used if the user possesses a Gurobi license. - - """ - - def __init__(self, cipher): - self._cipher = cipher - self._variables = None - self._model = None - self._occurences = None - self._used_variables = [] - self._variables_as_list = [] - self._unused_variables = [] - self._used_predecessors_sorted = None - self._output_id = None - self._output_bit_index_previous_comp = None - self._block_needed = None - self._input_id_link_needed = None - - def get_all_variables_as_list(self): - for component_id in list(self._variables.keys())[:-1]: - for bit_position in self._variables[component_id].keys(): - for value in self._variables[component_id][bit_position].keys(): - if value != "current": - varname = self._variables[component_id][bit_position][value].VarName - if varname not in self._variables_as_list: # rot and intermediate has the same name than original - self._variables_as_list.append(varname) - - def get_unused_variables(self): - self.get_all_variables_as_list() - for variable in self._variables_as_list: - if variable not in self._used_variables: - self._unused_variables.append(variable) - - def set_unused_variables_to_zero(self): - self.get_unused_variables() - for name in self._unused_variables: - var = self._model.getVarByName(name) - self._model.addConstr(var == 0) - - def set_as_used_variables(self, variables): - for v in variables: - if v.VarName not in self._used_variables: - self._used_variables.append(v.VarName) - if "copy" in v.VarName.split("_"): - tmp1 = v.VarName.split("_")[2:] - tmp2 = "_".join(tmp1) - self._used_variables.append(tmp2) - - def build_gurobi_model(self): - env = Env(empty=True) - env.setParam('ComputeServer', os.getenv('GUROBI_COMPUTE_SERVER')) - env.start() - # Create a new model - model = Model("basic_model", env=env) - # model = Model() - model.Params.LogToConsole = 0 - # model.Params.Threads = 16 - self._model = model - - def get_anfs_from_sbox(self, component): - anfs = [] - B = BooleanPolynomialRing(component.output_bit_size, 'x') - C = BooleanPolynomialRing(component.output_bit_size, 'x') - var_names = [f"x{i}" for i in range(component.output_bit_size)] - d = {} - for i in range(component.output_bit_size): - d[B(var_names[i])] = C(var_names[component.output_bit_size - i - 1]) - - sbox = SBox(component.description) - for i in range(component.input_bit_size): - anf = sbox.component_function(1 << i).algebraic_normal_form() - anf = anf.subs(d) # x0 was msb, now it is the lsb - anfs.append(anf) - anfs.reverse() - return anfs - - def get_monomial_occurences(self, component): - B = BooleanPolynomialRing(component.input_bit_size, 'x') - anfs = self.get_anfs_from_sbox(component) - - anfs = [B(anfs[i]) for i in range(component.input_bit_size)] - monomials = [] - for index, anf in enumerate(anfs): - if index in list(self._occurences[component.id].keys()): - monomials += anf.monomials() - monomials_degree_based = {} - sbox = SBox(component.description) - for deg in range(sbox.max_degree() + 1): - monomials_degree_based[deg] = dict( - Counter([monomial for monomial in monomials if monomial.degree() == deg])) - if deg >= 2: - for monomial in monomials_degree_based[deg].keys(): - deg1_monomials = monomial.variables() - for deg1_monomial in deg1_monomials: - if deg1_monomial not in monomials_degree_based[1].keys(): - monomials_degree_based[1][deg1_monomial] = 0 - monomials_degree_based[1][deg1_monomial] += monomials_degree_based[deg][monomial] - - sorted_monomials_degree_based = {1: {}} - for xi in B.variable_names(): - if B(xi) not in monomials_degree_based[1].keys(): - sorted_monomials_degree_based[1][B(xi)] = 0 - else: - sorted_monomials_degree_based[1][B(xi)] = monomials_degree_based[1][B(xi)] - for deg in range(sbox.max_degree() + 1): - if deg != 1: - sorted_monomials_degree_based[deg] = monomials_degree_based[deg] - - return sorted_monomials_degree_based - - def create_gurobi_vars_sbox(self, component, input_vars_concat): - monomial_occurences = self.get_monomial_occurences(component) - B = BooleanPolynomialRing(component.input_bit_size, 'x') - x = B.variable_names() - - copy_xi = {} - for index, xi in enumerate(monomial_occurences[1].keys()): - nb_occurence_xi = monomial_occurences[1][B(xi)] - if nb_occurence_xi != 0: - copy_xi[B(xi)] = self._model.addVars(list(range(nb_occurence_xi)), vtype=GRB.BINARY, - name="copy_" + input_vars_concat[index].VarName + "_as_" + str(xi)) - self._model.update() - self.set_as_used_variables(list(copy_xi[B(xi)].values())) - self.set_as_used_variables([input_vars_concat[index]]) - for i in range(nb_occurence_xi): - self._model.addConstr(input_vars_concat[index] >= copy_xi[B(xi)][i]) - self._model.addConstr( - sum(copy_xi[B(xi)][i] for i in range(nb_occurence_xi)) >= input_vars_concat[index]) - - copy_monomials_deg = {} - for deg in list(monomial_occurences.keys()): - if deg >= 2: - nb_monomials = sum(monomial_occurences[deg].values()) - copy_monomials_deg[deg] = self._model.addVars(list(range(nb_monomials)), vtype=GRB.BINARY) - self._model.update() - - copy_monomials_deg[1] = copy_xi - degrees = list(copy_monomials_deg.keys()) - for deg in degrees: - if deg >= 2: - copy_monomials_deg[deg]["current"] = 0 - elif deg == 1: - monomials = list(copy_monomials_deg[1].keys()) - for monomial in monomials: - copy_monomials_deg[deg][monomial]["current"] = 0 - self._model.update() - return copy_monomials_deg - - def add_sbox_constraints(self, component): - output_vars = self.get_output_vars(component) - input_vars_concat = self.get_input_vars(component) - - B = BooleanPolynomialRing(component.input_bit_size, 'x') - x = B.variable_names() - anfs = self.get_anfs_from_sbox(component) - anfs = [B(anfs[i]) for i in range(component.input_bit_size)] - - copy_monomials_deg = self.create_gurobi_vars_sbox(component, input_vars_concat) - - for index, bit_pos in enumerate(list(self._occurences[component.id].keys())): - constr = 0 - equality = True - monomials = anfs[bit_pos].monomials() - for monomial in monomials: - deg = monomial.degree() - if deg == 1: - current = copy_monomials_deg[deg][monomial]["current"] - constr += copy_monomials_deg[deg][monomial][current] - copy_monomials_deg[deg][monomial]["current"] += 1 - elif deg >= 2: - current = copy_monomials_deg[deg]["current"] - for deg1_monomial in monomial.variables(): - current_deg1 = copy_monomials_deg[1][deg1_monomial]["current"] - self._model.addConstr( - copy_monomials_deg[deg][current] == copy_monomials_deg[1][deg1_monomial][current_deg1]) - self.set_as_used_variables([copy_monomials_deg[deg][current]]) - copy_monomials_deg[1][deg1_monomial]["current"] += 1 - constr += copy_monomials_deg[deg][current] - copy_monomials_deg[deg]["current"] += 1 - elif deg == 0: - equality = False - if equality: - self._model.addConstr(output_vars[index] == constr) - else: - self._model.addConstr(output_vars[index] >= constr) - self._model.update() - - def create_copies_for_linear_layer(self, binary_matrix, input_vars_concat): - copies = {} - for index, var in enumerate(input_vars_concat): - column = [row[index] for row in binary_matrix] - number_of_1s = list(column).count(1) - if number_of_1s > 1: - current = 1 - else: - current = 0 - copies[index] = {} - copies[index][0] = var - copies[index]["current"] = current - self.set_as_used_variables([var]) - new_vars = self._model.addVars(list(range(number_of_1s)), vtype=GRB.BINARY, - name="copy_" + var.VarName) - self._model.update() - for i in range(number_of_1s): - self._model.addConstr(var >= new_vars[i]) - self._model.addConstr( - sum(new_vars[i] for i in range(number_of_1s)) >= var) - self._model.update() - for i in range(1, number_of_1s + 1): - copies[index][i] = new_vars[i - 1] - return copies - - def add_linear_layer_constraints(self, component): - output_vars = self.get_output_vars(component) - input_vars_concat = self.get_input_vars(component) - - if component.type == "linear_layer": - binary_matrix = component.description - else: - binary_matrix = binary_matrix_of_linear_component(component) - - copies = self.create_copies_for_linear_layer(binary_matrix, input_vars_concat) - for index_row, row in enumerate(binary_matrix): - constr = 0 - for index_bit, bit in enumerate(row): - if bit: - current = copies[index_bit]["current"] - constr += copies[index_bit][current] - copies[index_bit]["current"] += 1 - self.set_as_used_variables([copies[index_bit][current]]) - self._model.addConstr(output_vars[index_row] == constr) - self._model.update() - - def add_xor_constraints(self, component): - output_vars = self.get_output_vars(component) - - input_vars_concat = [] - constant_flag = [] - for index, input_name in enumerate(component.input_id_links): - for pos in component.input_bit_positions[index]: - current = self._variables[input_name][pos]["current"] - if input_name[:8] == "constant": - const_comp = self._cipher.get_component_from_id(input_name) - constant_flag.append( - (int(const_comp.description[0], 16) >> (const_comp.output_bit_size - 1 - pos)) & 1) - else: - input_vars_concat.append(self._variables[input_name][pos][current]) - self._variables[input_name][pos]["current"] += 1 - - block_size = component.output_bit_size - nb_blocks = component.description[1] - if constant_flag != []: - nb_blocks -= 1 - for index, bit_pos in enumerate(list(self._occurences[component.id].keys())): - constr = 0 - for j in range(nb_blocks): - constr += input_vars_concat[index + block_size * j] - self.set_as_used_variables([input_vars_concat[index + block_size * j]]) - if (constant_flag != []) and (constant_flag[index]): - self._model.addConstr(output_vars[index] >= constr) - else: - self._model.addConstr(output_vars[index] == constr) - self._model.update() - - def create_copies(self, nb_copies, var_to_copy): - copies = self._model.addVars(list(range(nb_copies)), vtype=GRB.BINARY) - for i in range(nb_copies): - self._model.addConstr(var_to_copy >= copies[i]) - self._model.addConstr(sum(copies[i] for i in range(nb_copies)) >= var_to_copy) - self._model.update() - return list(copies.values()) - - def get_output_vars(self, component): - output_vars = [] - tmp = list(self._occurences[component.id].keys()) - tmp.sort() - for i in tmp: - output_vars.append(self._model.getVarByName(f"{component.id}[{i}]")) - return output_vars - - def get_input_vars(self, component): - input_vars_concat = [] - for index, input_name in enumerate(component.input_id_links): - for pos in component.input_bit_positions[index]: - current = self._variables[input_name][pos]["current"] - input_vars_concat.append(self._variables[input_name][pos][current]) - self._variables[input_name][pos]["current"] += 1 - return input_vars_concat - - def add_modadd_constraints(self, component): - # constraints are taken from https://www.iacr.org/archive/asiacrypt2017/106240224/106240224.pdf - output_vars = self.get_output_vars(component) - - input_vars_concat = [] - for index, input_name in enumerate(component.input_id_links): - for pos in component.input_bit_positions[index]: - current = self._variables[input_name][pos]["current"] - input_vars_concat.append(self._variables[input_name][pos][current]) - self._variables[input_name][pos]["current"] += 1 - self.set_as_used_variables([self._variables[input_name][pos][current]]) - - len_concat = len(input_vars_concat) - n = int(len_concat / 2) - copies = {"a": {}, "b": {}} - copies["a"][n - 1] = self.create_copies(2, input_vars_concat[n - 1]) - copies["b"][n - 1] = self.create_copies(2, input_vars_concat[len_concat - 1]) - self._model.addConstr(output_vars[n - 1] == copies["a"][n - 1][0] + copies["b"][n - 1][0]) - - v = [self._model.addVar()] - self._model.addConstr(v[0] == copies["a"][n - 1][1]) - self._model.addConstr(v[0] == copies["b"][n - 1][1]) - - g0, r0 = self.create_copies(2, v[0]) - g = [g0] - r = [r0] - m = [] - q = [] - w = [] - - copies["a"][n - 2] = self.create_copies(3, input_vars_concat[n - 2]) - copies["b"][n - 2] = self.create_copies(3, input_vars_concat[len_concat - 2]) - - for i in range(2, n - 1): - self._model.addConstr(output_vars[n - i] == copies["a"][n - i][0] + copies["b"][n - i][0] + g[i - 2]) - v.append(self._model.addVar()) - self._model.addConstr(v[i - 1] == copies["a"][n - i][1]) - self._model.addConstr(v[i - 1] == copies["b"][n - i][1]) - m.append(self._model.addVar()) - self._model.addConstr(m[i - 2] == copies["a"][n - i][2] + copies["b"][n - i][2]) - q.append(self._model.addVar()) - self._model.addConstr(q[i - 2] == m[i - 2]) - self._model.addConstr(q[i - 2] == r[i - 2]) - w.append(self._model.addVar()) - self._model.addConstr(w[i - 2] == v[i - 1] + q[i - 2]) - g_i_1, r_i_1 = self.create_copies(2, w[i - 2]) - g.append(g_i_1) - r.append(r_i_1) - copies["a"][n - i - 1] = self.create_copies(3, input_vars_concat[n - i - 1]) - copies["b"][n - i - 1] = self.create_copies(3, input_vars_concat[len_concat - i - 1]) - - self._model.addConstr(output_vars[1] == copies["a"][1][0] + copies["b"][1][0] + g[n - 3]) - v.append(self._model.addVar()) - self._model.addConstr(v[n - 2] == copies["a"][1][1]) - self._model.addConstr(v[n - 2] == copies["b"][1][1]) - m.append(self._model.addVar()) - self._model.addConstr(m[n - 3] == copies["a"][1][2] + copies["b"][1][2]) - q.append(self._model.addVar()) - self._model.addConstr(q[n - 3] == m[n - 3]) - self._model.addConstr(q[n - 3] == r[n - 3]) - w.append(self._model.addVar()) - self._model.addConstr(w[n - 3] == v[n - 2] + q[n - 3]) - self._model.addConstr(output_vars[0] == input_vars_concat[0] + input_vars_concat[n] + w[n - 3]) - self._model.update() - - def add_and_constraints(self, component): - # Constraints taken from Misuse-free paper - output_vars = self.get_output_vars(component) - input_vars_concat = self.get_input_vars(component) - - block_size = int(len(input_vars_concat) // component.description[1]) - for index, bit_pos in enumerate(list(self._occurences[component.id].keys())): - self._model.addConstr(output_vars[index] == input_vars_concat[index]) - self._model.addConstr(output_vars[index] == input_vars_concat[index + block_size]) - self.set_as_used_variables([input_vars_concat[index], input_vars_concat[index + block_size]]) - self._model.update() - - def add_not_constraints(self, component): - output_vars = self.get_output_vars(component) - input_vars_concat = self.get_input_vars(component) - - for index, bit_pos in enumerate(list(self._occurences[component.id].keys())): - self._model.addConstr(output_vars[index] >= input_vars_concat[index]) - self.set_as_used_variables([input_vars_concat[index]]) - self._model.update() - - def get_cipher_output_component_id(self): - for component in self._cipher.get_all_components(): - if component.type == "cipher_output": - return component.id - - def add_constraints(self, predecessors, input_id_link_needed, block_needed): - self.build_gurobi_model() - self.create_gurobi_vars_from_all_components(predecessors, input_id_link_needed, block_needed) - - used_predecessors_sorted = self.order_predecessors(list(self._occurences.keys())) - self._used_predecessors_sorted = used_predecessors_sorted - for component_id in used_predecessors_sorted: - if component_id not in self._cipher.inputs: - component = self._cipher.get_component_from_id(component_id) - if component.type == "sbox": - self.add_sbox_constraints(component) - elif component.type in ["linear_layer", "mix_column"]: - self.add_linear_layer_constraints(component) - elif component.type in ["cipher_output", "constant", "intermediate_output"]: - continue - elif component.type == "word_operation": - if component.description[0] == "XOR": - self.add_xor_constraints(component) - elif component.description[0] == "ROTATE": - continue - elif component.description[0] == "AND": - self.add_and_constraints(component) - elif component.description[0] == "NOT": - self.add_not_constraints(component) - elif component.description[0] == "MODADD": - self.add_modadd_constraints(component) - else: - print(f"---> {component.id} not yet implemented") - - return self._model - - def get_where_component_is_used(self, predecessors, input_id_link_needed, block_needed): - occurences = {} - ids = self._cipher.inputs + predecessors - for name in ids: - for component_id in predecessors: - component = self._cipher.get_component_from_id(component_id) - if (name in component.input_id_links) and (component.type not in ["cipher_output"]): - indexes = [i for i, j in enumerate(component.input_id_links) if j == name] - if name not in occurences.keys(): - occurences[name] = [] - for index in indexes: - occurences[name].append(component.input_bit_positions[index]) - if input_id_link_needed in self._cipher.inputs: - occurences[input_id_link_needed] = [block_needed] - else: - component = self._cipher.get_component_from_id(input_id_link_needed) - occurences[input_id_link_needed] = [[i for i in range(component.output_bit_size)]] - - occurences_final = {} - for component_id in occurences.keys(): - occurences_final[component_id] = self.find_copy_indexes(occurences[component_id]) - - self._occurences = occurences_final - return occurences_final - - def find_copy_indexes(self, input_bit_positions): - l = {} - for input_bit_position in input_bit_positions: - for pos in input_bit_position: - if pos not in l.keys(): - l[pos] = 0 - l[pos] += 1 - return l - - def order_predecessors(self, used_predecessors): - for component_id in self._cipher.inputs: - if component_id in list(self._occurences.keys()): - used_predecessors.remove(component_id) - tmp = {} - final = {} - for r in range(self._cipher.number_of_rounds): - tmp[r] = {} - for component_id in used_predecessors: - if int(component_id.split("_")[-2]) == r: - tmp[r][component_id] = int(component_id.split("_")[-1]) - final[r] = {k: v for k, v in sorted(tmp[r].items(), key=lambda item: item[1])} - - used_predecessors_sorted = [] - for r in range(self._cipher.number_of_rounds): - used_predecessors_sorted += list(final[r].keys()) - - l = [] - for component_id in self._cipher.inputs: - if component_id in list(self._occurences.keys()): - l.append(component_id) - used_predecessors_sorted = l + used_predecessors_sorted - return used_predecessors_sorted - - def create_gurobi_vars_from_all_components(self, predecessors, input_id_link_needed, block_needed): - occurences = self.get_where_component_is_used(predecessors, input_id_link_needed, block_needed) - all_vars = {} - used_predecessors_sorted = self.order_predecessors(list(occurences.keys())) - for component_id in used_predecessors_sorted: - all_vars[component_id] = {} - # We need the inputs vars to be the first ones defined by gurobi in order to find their values with X.values method. - # That's why we split the following loop: we first created the original vars, and then the copies vars when necessary. - if component_id[:3] == "rot": - component = self._cipher.get_component_from_id(component_id) - rotate_offset = component.description[1] - tmp = [] - for index, input_id_link in enumerate(component.input_id_links): - for j, pos in enumerate(component.input_bit_positions[index]): - current = all_vars[input_id_link][pos]["current"] - tmp.append(all_vars[input_id_link][pos][current]) - all_vars[input_id_link][pos]["current"] += 1 - - tmp2 = [] - for j in range(len(tmp)): - all_vars[component_id][j] = {} - all_vars[component_id][j][0] = tmp[(j - rotate_offset) % component.output_bit_size] - tmp2.append(all_vars[component_id][j][0]) - all_vars[component_id][j]["current"] = 0 - - for pos, gurobi_var in enumerate(tmp2): - if pos in list(occurences[component_id].keys()): - nb_copies_needed = occurences[component_id][pos] - if nb_copies_needed >= 2: - all_vars[component_id][pos]["current"] = 1 - for i in range(nb_copies_needed): - all_vars[component_id][pos][i + 1] = self._model.addVar(vtype=GRB.BINARY, - name=f"copy_{i + 1}_" + gurobi_var.VarName) - self._model.addConstr( - all_vars[component_id][pos][0] >= all_vars[component_id][pos][i + 1]) - self._model.addConstr( - sum(all_vars[component_id][pos][i + 1] for i in range(nb_copies_needed)) >= - all_vars[component_id][pos][0]) - elif component_id[:5] == "inter": - component = self._cipher.get_component_from_id(component_id) - tmp = [] - for index, input_id_link in enumerate(component.input_id_links): - for j, pos in enumerate(component.input_bit_positions[index]): - current = all_vars[input_id_link][pos]["current"] - tmp.append(all_vars[input_id_link][pos][current]) - all_vars[input_id_link][pos]["current"] += 1 - - for j in range(len(tmp)): - all_vars[component_id][j] = {} - all_vars[component_id][j][0] = tmp[j] - all_vars[component_id][j]["current"] = 0 - - for pos, gurobi_var in enumerate(tmp): - if pos in list(occurences[component_id].keys()): - nb_copies_needed = occurences[component_id][pos] - if nb_copies_needed >= 2: - all_vars[component_id][pos]["current"] = 1 - for i in range(nb_copies_needed): - all_vars[component_id][pos][i + 1] = self._model.addVar(vtype=GRB.BINARY, - name=f"copy_{i + 1}_" + gurobi_var.VarName) - self._model.addConstr( - all_vars[component_id][pos][0] >= all_vars[component_id][pos][i + 1]) - self._model.addConstr( - sum(all_vars[component_id][pos][i + 1] for i in range(nb_copies_needed)) >= - all_vars[component_id][pos][0]) - else: - for pos in list(occurences[component_id].keys()): - all_vars[component_id][pos] = {} - all_vars[component_id][pos][0] = self._model.addVar(vtype=GRB.BINARY, - name=component_id + f"[{pos}]") - all_vars[component_id][pos]["current"] = 0 - for pos in list(occurences[component_id].keys()): - nb_copies_needed = occurences[component_id][pos] - if nb_copies_needed >= 2: - all_vars[component_id][pos]["current"] = 1 - for i in range(nb_copies_needed): - all_vars[component_id][pos][i + 1] = self._model.addVar(vtype=GRB.BINARY, - name=f"copy_{i + 1}_" + component_id + f"[{pos}]") - self._model.addConstr(all_vars[component_id][pos][0] >= all_vars[component_id][pos][i + 1]) - self._model.addConstr( - sum(all_vars[component_id][pos][i + 1] for i in range(nb_copies_needed)) >= - all_vars[component_id][pos][0]) - self._model.update() - - self._model.update() - # print("all_vars") - # print(all_vars) - self._model.update() - self._variables = all_vars - - def find_index_second_input(self): - occurences = self._occurences - count = 0 - for pos in list(occurences[self._cipher.inputs[0]].keys()): - if occurences[self._cipher.inputs[0]][pos] > 1: - count += occurences[self._cipher.inputs[0]][pos] + 1 - else: - count += occurences[self._cipher.inputs[0]][pos] - return count - - def get_output_bit_index_previous_component(self, output_bit_index_ciphertext, chosen_cipher_output=None): - if chosen_cipher_output != None: - pivot = 0 - for comp in self._cipher.get_all_components(): - for index, id_link in enumerate(comp.input_id_links): - if chosen_cipher_output == id_link: - output_id = comp.id - block_needed = comp.input_bit_positions[index] - input_id_link_needed = chosen_cipher_output - output_bit_index_previous_comp = output_bit_index_ciphertext - return output_id, output_bit_index_previous_comp, block_needed, input_id_link_needed, pivot - else: - output_id = self.get_cipher_output_component_id() - component = self._cipher.get_component_from_id(output_id) - pivot = 0 - output_bit_index_previous_comp = output_bit_index_ciphertext - for index, block in enumerate(component.input_bit_positions): - if pivot <= output_bit_index_ciphertext < pivot + len(block): - output_bit_index_previous_comp = block[output_bit_index_ciphertext - pivot] - block_needed = block - input_id_link_needed = component.input_id_links[index] - break - pivot += len(block) - - if input_id_link_needed[:5] == "inter": - pivot = 0 - component_inter = self._cipher.get_component_from_id(input_id_link_needed) - for index, block in enumerate(component_inter.input_bit_positions): - if pivot <= block_needed[output_bit_index_previous_comp] < pivot + len(block): - output_bit_index_before_inter = block[block_needed[output_bit_index_previous_comp] - pivot] - input_id_link_needed = component_inter.input_id_links[index] - block_needed = block - break - pivot += len(block) - output_bit_index_previous_comp = output_bit_index_before_inter - return output_id, output_bit_index_previous_comp, block_needed, input_id_link_needed, pivot - - def build_generic_model_for_specific_output_bit(self, output_bit_index_ciphertext, fixed_degree=None, - chosen_cipher_output=None): - start = time.time() - output_id, output_bit_index_previous_comp, block_needed, input_id_link_needed, pivot = self.get_output_bit_index_previous_component( - output_bit_index_ciphertext, chosen_cipher_output) - - self._output_id = output_id - self._output_bit_index_previous_comp = output_bit_index_previous_comp - self._block_needed = block_needed - self._input_id_link_needed = input_id_link_needed - - G = create_networkx_graph_from_input_ids(self._cipher) - predecessors = list(_get_predecessors_subgraph(G, [input_id_link_needed])) - for input_id in self._cipher.inputs + ['']: - if input_id in predecessors: - predecessors.remove(input_id) - - self.add_constraints(predecessors, input_id_link_needed, block_needed) - - var_from_block_needed = [] - for i in block_needed: - var_from_block_needed.append(self._variables[input_id_link_needed][i][0]) - - output_vars = self._model.addVars(list(range(pivot, pivot + len(block_needed))), vtype=GRB.BINARY, - name=output_id) - self._variables[output_id] = output_vars - output_vars = list(output_vars.values()) - self._model.update() - - for i in range(len(block_needed)): - self._model.addConstr(output_vars[i] == var_from_block_needed[i]) - self.set_as_used_variables([output_vars[i], var_from_block_needed[i]]) - - ks = self._model.addVar() - self._model.addConstr(ks == sum(output_vars[i] for i in range(len(block_needed)))) - self._model.addConstr(ks == 1) - self._model.addConstr(output_vars[output_bit_index_previous_comp] == 1) - - if fixed_degree != None: - plaintext_vars = [] - for i in range( - self._cipher.inputs_bit_size[0]): # Carreful, here we are assuming that input[0] is the plaintext - plaintext_vars.append(self._model.getVarByName(f"plaintext[{i}]")) - self._model.addConstr( - sum(plaintext_vars[i] for i in range(self._cipher.inputs_bit_size[0])) == fixed_degree) - - self.set_unused_variables_to_zero() - self._model.update() - end = time.time() - building_time = end - start - if verbosity: - print(f"########## building_time : {building_time}") - self._model.update() - - def get_solutions(self): - start = time.time() - index_second_input = self.find_index_second_input() - nb_inputs_used = 0 - for input_id in self._cipher.inputs: - if input_id in list(self._occurences.keys()): - nb_inputs_used += 1 - if nb_inputs_used == 2: - max_input_bit_pos = index_second_input + len(list(self._occurences[self._cipher.inputs[1]].keys())) - first_input_bit_positions = list(self._occurences[self._cipher.inputs[0]].keys()) - second_input_bit_positions = list(self._occurences[self._cipher.inputs[1]].keys()) - else: - max_input_bit_pos = index_second_input - first_input_bit_positions = list(self._occurences[self._cipher.inputs[0]].keys()) - - solCount = self._model.SolCount - monomials = [] - for sol in range(solCount): - self._model.setParam(GRB.Param.SolutionNumber, sol) - values = self._model.Xn - - tmp = "" - for index, v in enumerate(values[:max_input_bit_pos]): - if v == 1: - if nb_inputs_used > 1: - if index < len(list(self._occurences[self._cipher.inputs[0]].keys())): - tmp += self._cipher.inputs[0][0] + str(first_input_bit_positions[index]) - elif index_second_input <= index < index_second_input + len( - list(self._occurences[self._cipher.inputs[1]].keys())): - tmp += self._cipher.inputs[1][0] + str( - second_input_bit_positions[abs(index_second_input - index)]) - else: - if index < len(list(self._occurences[self._cipher.inputs[0]].keys())): - tmp += self._cipher.inputs[0][0] + str(first_input_bit_positions[index]) - if 1 not in values[:max_input_bit_pos]: - tmp += str(1) - else: - if nb_inputs_used == 1: - input1_prefix = self._cipher.inputs[0][0] - l = tmp.split(input1_prefix)[1:] - sorted_l = sorted(l, key=lambda x: (x == '', int(x) if x else 0)) - l = [''] + sorted_l - tmp = input1_prefix.join(l) - - if tmp in monomials: - monomials.remove(tmp) - else: - monomials.append(tmp) - - end = time.time() - printing_time = end - start - if verbosity: - print('Number of solutions (might cancel each other) found: ' + str(solCount)) - print(f"########## printing_time : {printing_time}") - print(f'Number of monomials found: {len(monomials)}') - return monomials - - def optimize_model(self): - start = time.time() - self._model.optimize() - end = time.time() - solving_time = end - start - if verbosity: - print(self._model) - print(f"########## solving_time : {solving_time}") - - def find_anf_of_specific_output_bit(self, output_bit_index, fixed_degree=None, chosen_cipher_output=None): - self.build_generic_model_for_specific_output_bit(output_bit_index, fixed_degree, chosen_cipher_output) - self._model.setParam("PoolSolutions", 200000000) # 200000000 to be large - self._model.setParam(GRB.Param.PoolSearchMode, 2) - self._model.write("division_trail_model.lp") - - self.optimize_model() - return self.get_solutions() - - def check_presence_of_particular_monomial_in_specific_anf(self, monomial, output_bit_index, fixed_degree=None, - chosen_cipher_output=None): - self.build_generic_model_for_specific_output_bit(output_bit_index, fixed_degree, chosen_cipher_output) - self._model.setParam("PoolSolutions", 200000000) # 200000000 to be large - self._model.setParam(GRB.Param.PoolSearchMode, 2) - - for term in monomial: - var_term = self._model.getVarByName(f"{term[0]}[{term[1]}]") - self._model.addConstr(var_term == 1) - self._model.update() - self._model.write("division_trail_model.lp") - - self.optimize_model() - return self.get_solutions() - - def check_presence_of_particular_monomial_in_all_anf(self, monomial, fixed_degree=None, - chosen_cipher_output=None): - s = "" - for term in monomial: - s += term[0][0] + str(term[1]) - for i in range(self._cipher.output_bit_size): - print(f"\nSearch of {s} in anf {i} :") - self.check_presence_of_particular_monomial_in_specific_anf(monomial, i, fixed_degree, - chosen_cipher_output) - - def find_degree_of_specific_output_bit(self, output_bit_index, chosen_cipher_output=None, cube_index=[]): - fixed_degree = None - self.build_generic_model_for_specific_output_bit(output_bit_index, fixed_degree, chosen_cipher_output) - self._model.setParam(GRB.Param.PoolSearchMode, 1) - self._model.setParam('Presolve', 2) - self._model.setParam("MIPFocus", 2) - self._model.setParam("MIPGap", 0) # when set to 0, best solution = optimal solution - self._model.setParam('Cuts', 2) - - index_plaintext = self._cipher.inputs.index("plaintext") - plaintext_bit_size = self._cipher.inputs_bit_size[index_plaintext] - p = [] - nb_plaintext_bits_used = len(list(self._occurences["plaintext"].keys())) - for i in range(nb_plaintext_bits_used): - p.append(self._model.getVarByName(f"plaintext[{i}]")) - self._model.setObjective(sum(p[i] for i in range(nb_plaintext_bits_used)), GRB.MAXIMIZE) - - if cube_index: - for i in range(plaintext_bit_size): - if i not in cube_index: - self._model.addConstr(p[i] == 0) - - self._model.update() - self._model.write("division_trail_model.lp") - self.optimize_model() - - degree = self._model.getObjective().getValue() - return degree - - def re_init(self): - self._variables = None - self._model = None - self._occurences = None - self._used_variables = [] - self._variables_as_list = [] - self._unused_variables = [] - - def find_degree_of_all_output_bits(self, chosen_cipher_output=None): - for i in range(self._cipher.output_bit_size): - self.re_init() - degree = self.find_degree_of_specific_output_bit(i, chosen_cipher_output) - print(f"Degree of anf corresponding to output bit at position {i} = {degree}\n") diff --git a/claasp/cipher_modules/evaluator.py b/claasp/cipher_modules/evaluator.py index 793414604..b7bbae168 100644 --- a/claasp/cipher_modules/evaluator.py +++ b/claasp/cipher_modules/evaluator.py @@ -1,16 +1,16 @@ # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -21,13 +21,14 @@ from subprocess import Popen, PIPE from claasp.cipher_modules import code_generator -from claasp.cipher_modules.generic_functions_vectorized_byte import cipher_inputs_to_evaluate_vectorized_inputs, \ - evaluate_vectorized_outputs_to_integers +from claasp.cipher_modules.generic_functions_vectorized_byte import ( + cipher_inputs_to_evaluate_vectorized_inputs, + evaluate_vectorized_outputs_to_integers, +) def evaluate(cipher, cipher_input, intermediate_output=False, verbosity=False): python_code_string = code_generator.generate_python_code_string(cipher, verbosity) - f_module = ModuleType("evaluate") exec(python_code_string, f_module.__dict__) @@ -41,62 +42,86 @@ def evaluate_using_c(cipher, inputs, intermediate_output, verbosity): cipher.generate_evaluate_c_code_shared_library(intermediate_output, verbosity) name = cipher.id + "_evaluate" c_cipher_inputs = [hex(value) for value in inputs] - process = Popen([code_generator.TII_C_LIB_PATH + name + ".o"] + c_cipher_inputs, stdout=PIPE) + process = Popen( + [code_generator.TII_C_LIB_PATH + name + ".o"] + c_cipher_inputs, stdout=PIPE + ) output = process.stdout if verbosity and intermediate_output: - line = output.readline().decode('utf-8') + line = output.readline().decode("utf-8") - while line != '{\n': - print(line[:-1].decode('utf-8')) - line = output.readline().decode('utf-8') + while line != "{\n": + print(line[:-1].decode("utf-8")) + line = output.readline().decode("utf-8") dict_str = line for line in output.readlines(): - dict_str += line.decode('utf-8') + dict_str += line.decode("utf-8") function_output = eval(dict_str) elif verbosity and not intermediate_output: output_lines = output.readlines() for line in output_lines[:-1]: - print(line[:-1].decode('utf-8')) + print(line[:-1].decode("utf-8")) - function_output = int(output_lines[-1].decode('utf-8')[:-1], 16) + function_output = int(output_lines[-1].decode("utf-8")[:-1], 16) elif intermediate_output: - dict_str = '' + dict_str = "" for line in output.readlines(): - dict_str += line.decode('utf-8') + dict_str += line.decode("utf-8") function_output = eval(dict_str) else: - function_output = int(output.read().decode('utf-8')[:-1], 16) + function_output = int(output.read().decode("utf-8")[:-1], 16) code_generator.delete_generated_evaluate_c_shared_library(cipher) return function_output -def evaluate_vectorized(cipher, cipher_input, intermediate_output=False, verbosity=False, evaluate_api=False, - bit_based=False): - python_code_string = code_generator.generate_byte_based_vectorized_python_code_string(cipher, - store_intermediate_outputs=intermediate_output, - verbosity=verbosity, - integers_inputs_and_outputs=evaluate_api) +def evaluate_vectorized( + cipher, + cipher_input, + intermediate_output=False, + verbosity=False, + evaluate_api=False, +): + python_code_string = ( + code_generator.generate_byte_based_vectorized_python_code_string( + cipher, + store_intermediate_outputs=intermediate_output, + verbosity=verbosity, + integers_inputs_and_outputs=evaluate_api, + ) + ) f_module = ModuleType("evaluate") exec(python_code_string, f_module.__dict__) cipher_output = f_module.evaluate(cipher_input, intermediate_output) return cipher_output -def evaluate_with_intermediate_outputs_continuous_diffusion_analysis(cipher, cipher_input, sbox_precomputations, - sbox_precomputations_mix_columns, verbosity=False): - python_code_string = code_generator.generate_python_code_string_for_continuous_diffusion_analysis(cipher, verbosity) +def evaluate_with_intermediate_outputs_continuous_diffusion_analysis( + cipher, + cipher_input, + sbox_precomputations, + sbox_precomputations_mix_columns, + verbosity=False, +): + python_code_string = ( + code_generator.generate_python_code_string_for_continuous_diffusion_analysis( + cipher, verbosity + ) + ) python_code_string = python_code_string.replace( - "def evaluate(input):", "def evaluate(input, sbox_precomputations, sbox_precomputations_mix_columns):") + "def evaluate(input):", + "def evaluate(input, sbox_precomputations, sbox_precomputations_mix_columns):", + ) f_module = ModuleType("evaluate_continuous_diffusion_analysis") exec(python_code_string, f_module.__dict__) - return f_module.evaluate(cipher_input, sbox_precomputations, sbox_precomputations_mix_columns) + return f_module.evaluate( + cipher_input, sbox_precomputations, sbox_precomputations_mix_columns + ) diff --git a/claasp/cipher_modules/inverse_cipher.py b/claasp/cipher_modules/inverse_cipher.py index 49f830faf..253275cc5 100644 --- a/claasp/cipher_modules/inverse_cipher.py +++ b/claasp/cipher_modules/inverse_cipher.py @@ -1,34 +1,58 @@ from copy import * from sage.crypto.sbox import SBox -from claasp.cipher_modules.component_analysis_tests import binary_matrix_of_linear_component, \ - get_inverse_matrix_in_integer_representation +from sage.rings.finite_rings.finite_field_constructor import FiniteField as GF +from sage.rings.polynomial.polynomial_ring_constructor import PolynomialRing + +from claasp.cipher_modules.component_analysis_tests import ( + binary_matrix_of_linear_component, + get_inverse_matrix_in_integer_representation, +) from claasp.cipher_modules.graph_generator import create_networkx_graph_from_input_ids from claasp.component import Component -from claasp.components import modsub_component, cipher_output_component, linear_layer_component, \ - intermediate_output_component +from claasp.components import ( + cipher_output_component, + intermediate_output_component, + linear_layer_component, + modsub_component, +) from claasp.input import Input -from sage.rings.finite_rings.finite_field_constructor import FiniteField as GF -from sage.rings.polynomial.polynomial_ring_constructor import PolynomialRing from claasp.cipher_modules.component_analysis_tests import int_to_poly -from claasp.name_mappings import * +from claasp.name_mappings import ( + CIPHER_INPUT, + CIPHER_OUTPUT, + CONSTANT, + INPUT_KEY, + INPUT_PLAINTEXT, + INPUT_STATE, + INPUT_TWEAK, + INTERMEDIATE_OUTPUT, + LINEAR_LAYER, + MIX_COLUMN, + SBOX, + WORD_OPERATION, +) def get_cipher_components(self): component_list = self.get_all_components() for c in component_list: - setattr(c, 'round', int(c.id.split("_")[-2])) + setattr(c, "round", int(c.id.split("_")[-2])) # build input components for index, input_id in enumerate(self.inputs): if INPUT_KEY in input_id: - input_component = Component(input_id, "cipher_input", Input(0, [[]], [[]]), self.inputs_bit_size[index], - [INPUT_KEY]) + input_component = Component( + input_id, "cipher_input", Input(0, [[]], [[]]), self.inputs_bit_size[index], [INPUT_KEY] + ) else: - input_component = Component(input_id, "cipher_input", Input(0, [[]], [[]]), self.inputs_bit_size[index], [input_id]) - setattr(input_component, 'round', -1) + input_component = Component( + input_id, "cipher_input", Input(0, [[]], [[]]), self.inputs_bit_size[index], [input_id] + ) + setattr(input_component, "round", -1) component_list.append(input_component) return component_list + def get_all_components_with_the_same_input_id_link_and_input_bit_positions(input_id_link, input_bit_positions, self): cipher_components = get_cipher_components(self) output_list = [] @@ -39,14 +63,25 @@ def get_all_components_with_the_same_input_id_link_and_input_bit_positions(input list_to_be_compared = copy(c.input_bit_positions[i]) list_to_be_compared.sort() # if input_id_link == c.input_id_links[i] and list_to_be_compared in copy_input_bit_positions: #changed adding sort - if input_id_link == c.input_id_links[i] and all(ele in copy_input_bit_positions for ele in list_to_be_compared): #changed adding sort + if input_id_link == c.input_id_links[i] and all( + ele in copy_input_bit_positions for ele in list_to_be_compared + ): # changed adding sort output_list.append(c) break return output_list def are_equal_components(component1, component2): - attributes = ["id", "type", "input_id_links", "input_bit_size", "input_bit_positions", "output_bit_position", "description", "round"] + attributes = [ + "id", + "type", + "input_id_links", + "input_bit_size", + "input_bit_positions", + "output_bit_position", + "description", + "round", + ] for attr in attributes: if getattr(component1, attr) != getattr(component2, attr): return False @@ -75,12 +110,11 @@ def get_output_components(component, self): def is_bit_contained_in(bit, available_bits): for b in available_bits: - if bit["component_id"] == b["component_id"] and \ - bit["position"] == b["position"] and \ - bit["type"] == b["type"]: + if bit["component_id"] == b["component_id"] and bit["position"] == b["position"] and bit["type"] == b["type"]: return True return False + def add_bit_to_bit_list(bit, bit_list): if not is_bit_contained_in(bit, bit_list): bit_list.append(bit) @@ -89,15 +123,12 @@ def add_bit_to_bit_list(bit, bit_list): def _are_all_bits_available(id, input_bit_positions_len, offset, available_bits): for j in range(input_bit_positions_len): - bit = { - "component_id": id, - "position": offset + j, - "type": "input" - } + bit = {"component_id": id, "position": offset + j, "type": "input"} if not is_bit_contained_in(bit, available_bits): return False return True + def get_available_output_components(component, available_bits, self, return_index=False): cipher_components = get_cipher_components(self) available_output_components = [] @@ -105,17 +136,21 @@ def get_available_output_components(component, available_bits, self, return_inde accumulator = 0 for i in range(len(c.input_id_links)): if (component.id == c.input_id_links[i]) and (c not in available_output_components): - all_bits_available = _are_all_bits_available(c.id, len(c.input_bit_positions[i]), accumulator, - available_bits) + all_bits_available = _are_all_bits_available( + c.id, len(c.input_bit_positions[i]), accumulator, available_bits + ) if all_bits_available: if return_index: - available_output_components.append((c, list(range(accumulator, accumulator + len(c.input_bit_positions[i]))))) + available_output_components.append( + (c, list(range(accumulator, accumulator + len(c.input_bit_positions[i])))) + ) else: available_output_components.append(c) - accumulator += len(c.input_bit_positions[i]) # changed + accumulator += len(c.input_bit_positions[i]) # changed return available_output_components + def sort_input_id_links_and_input_bit_positions(input_id_links, input_bit_positions, component, self): updated_input_bit_positions = [] updated_input_id_links = [] @@ -131,8 +166,9 @@ def sort_input_id_links_and_input_bit_positions(input_id_links, input_bit_positi if len(ordered_list) == 0: l = component_input_id_link.input_bit_positions[position] if l != sorted(l): - l_ordered = find_correct_order_for_inversion(l, input_bit_positions[index], - component_input_id_link) + l_ordered = find_correct_order_for_inversion( + l, input_bit_positions[index], component_input_id_link + ) else: l_ordered = input_bit_positions[index] ordered_list.append(l) @@ -149,8 +185,9 @@ def sort_input_id_links_and_input_bit_positions(input_id_links, input_bit_positi ordered_list.insert(position_to_insert, component_input_id_link.input_bit_positions[position]) l = component_input_id_link.input_bit_positions[position] if l != sorted(l): - l_ordered = find_correct_order_for_inversion(l, input_bit_positions[index], - component_input_id_link) + l_ordered = find_correct_order_for_inversion( + l, input_bit_positions[index], component_input_id_link + ) else: l_ordered = input_bit_positions[index] updated_input_bit_positions.insert(position_to_insert, l_ordered) @@ -158,6 +195,7 @@ def sort_input_id_links_and_input_bit_positions(input_id_links, input_bit_positi index += 1 return updated_input_id_links, updated_input_bit_positions + def is_bit_adjacent_to_list_of_bits(bit_name, list_of_bit_names, all_equivalent_bits): if bit_name not in all_equivalent_bits.keys(): return False @@ -166,27 +204,28 @@ def is_bit_adjacent_to_list_of_bits(bit_name, list_of_bit_names, all_equivalent_ return True return False + def equivalent_bits_in_common(bits_of_an_output_component, component_bits, all_equivalent_bits): bits_in_common = [] for bit1 in bits_of_an_output_component: - bit_name1 = bit1["component_id"] + "_" + str(bit1["position"]) + "_" + bit1["type"] + bit_name1 = f"{bit1['component_id']}_{bit1['position']}_{bit1['type']}" if bit_name1 not in all_equivalent_bits.keys(): return [] for bit2 in component_bits: - bit_name2 = bit2["component_id"] + "_" + str(bit2["position"]) + "_" + bit2["type"] + bit_name2 = f"{bit2['component_id']}_{bit2['position']}_{bit2['type']}" if bit_name2 in all_equivalent_bits[bit_name1]: bits_in_common.append(bit1) break return bits_in_common -def compute_input_id_links_and_input_bit_positions_for_inverse_component_from_available_output_components(component, - available_output_components, - all_equivalent_bits, - self): + +def compute_input_id_links_and_input_bit_positions_for_inverse_component_from_available_output_components( + component, available_output_components, all_equivalent_bits, self +): tmp_input_id_links = [] tmp_input_bit_positions = [] for bit_position in range(component.output_bit_size): - bit_name_input = component.id + "_" + str(bit_position) + "_output" + bit_name_input = f"{component.id}_{bit_position}_output" flag_link_found = False for c in available_output_components: if is_possibly_invertible_component(c): @@ -194,24 +233,28 @@ def compute_input_id_links_and_input_bit_positions_for_inverse_component_from_av l = [] for index, link in enumerate(c.input_id_links): if link == component.id: - l += list(range(starting_bit_position, starting_bit_position + len(c.input_bit_positions[index]))) + l += list( + range(starting_bit_position, starting_bit_position + len(c.input_bit_positions[index])) + ) starting_bit_position += len(c.input_bit_positions[index]) for i in l: - bit_name = c.id + "_" + str(i) + "_input" + bit_name = f"{c.id}_{i}_input" if is_bit_adjacent_to_list_of_bits(bit_name_input, [bit_name], all_equivalent_bits): if c.input_bit_size == c.output_bit_size: - bit_name_output_updated = c.id + "_" + str(i) + "_output_updated" - if is_bit_adjacent_to_list_of_bits(bit_name, [bit_name_output_updated], - all_equivalent_bits): + bit_name_output_updated = f"{c.id}_{i}_output_updated" + if is_bit_adjacent_to_list_of_bits( + bit_name, [bit_name_output_updated], all_equivalent_bits + ): tmp_input_id_links.append(c.id) tmp_input_bit_positions.append(i) flag_link_found = True break else: for j in range(c.output_bit_size): - bit_name_output_updated = c.id + "_" + str(j) + "_output_updated" - if is_bit_adjacent_to_list_of_bits(bit_name, [bit_name_output_updated], - all_equivalent_bits): + bit_name_output_updated = f"{c.id}_{j}_output_updated" + if is_bit_adjacent_to_list_of_bits( + bit_name, [bit_name_output_updated], all_equivalent_bits + ): tmp_input_id_links.append(c.id) tmp_input_bit_positions.append(j) flag_link_found = True @@ -239,6 +282,7 @@ def compute_input_id_links_and_input_bit_positions_for_inverse_component_from_av return input_id_links, input_bit_positions + def get_all_bit_names(self): dictio = {} cipher_components = get_cipher_components(self) @@ -248,17 +292,9 @@ def get_all_bit_names(self): for index, input_id_link in enumerate(c.input_id_links): j = 0 for i in c.input_bit_positions[index]: - output_bit = { - "component_id": input_id_link, - "position": i, - "type": "output" - } - output_bit_name = input_id_link + "_" + str(i) + "_output" - input_bit = { - "component_id": c.id, - "position": starting_bit_position + j, - "type": "input" - } + output_bit = {"component_id": input_id_link, "position": i, "type": "output"} + output_bit_name = f"{input_id_link}_{i}_output" + input_bit = {"component_id": c.id, "position": starting_bit_position + j, "type": "input"} input_bit_name = c.id + "_" + str(starting_bit_position + j) + "_input" if output_bit_name not in dictio.keys(): dictio[output_bit_name] = output_bit @@ -266,27 +302,24 @@ def get_all_bit_names(self): dictio[input_bit_name] = input_bit if c.type != CIPHER_OUTPUT: - output_updated_bit = { - "component_id": input_id_link, - "position": i, - "type": "output_updated" - } - output_updated_bit_name = input_id_link + "_" + str(i) + "_output_updated" + output_updated_bit = {"component_id": input_id_link, "position": i, "type": "output_updated"} + output_updated_bit_name = f"{input_id_link}_{i}_output_updated" if output_updated_bit_name not in dictio.keys(): # changed, if added dictio[output_updated_bit_name] = output_updated_bit output_updated_bit = { - "component_id": c.id, - "position": starting_bit_position + j, - "type": "output_updated" - } - output_updated_bit_name = c.id + "_" + str(starting_bit_position + j) + "_output_updated" - if output_updated_bit_name not in dictio.keys(): # changed, if added + "component_id": c.id, + "position": starting_bit_position + j, + "type": "output_updated", + } + output_updated_bit_name = f"{c.id}_{starting_bit_position + j}_output_updated" + if output_updated_bit_name not in dictio.keys(): # changed, if added dictio[output_updated_bit_name] = output_updated_bit j += 1 starting_bit_position += len(c.input_bit_positions[index]) return dictio + def get_all_equivalent_bits(self): dictio = {} component_list = self.get_all_components() @@ -298,8 +331,8 @@ def get_all_equivalent_bits(self): else: input_bit_positions = c.input_bit_positions[index] for i in input_bit_positions: - output_bit_name = input_id_link + "_" + str(i) + "_output" - input_bit_name = c.id + "_" + str(current_bit_position) + "_input" + output_bit_name = f"{input_id_link}_{i}_output" + input_bit_name = f"{c.id}_{current_bit_position}_input" current_bit_position += 1 if output_bit_name not in dictio.keys(): dictio[output_bit_name] = [] @@ -318,7 +351,10 @@ def get_all_equivalent_bits(self): return updated_dictio -def get_equivalent_input_bit_from_output_bit(potential_unwanted_component, base_component, available_bits, all_equivalent_bits, key_schedule_components, self): + +def get_equivalent_input_bit_from_output_bit( + potential_unwanted_component, base_component, available_bits, all_equivalent_bits, key_schedule_components, self +): all_bit_names = get_all_bit_names(self) potential_unwanted_bits = [] potential_unwanted_bits_names = [] @@ -328,23 +364,21 @@ def get_equivalent_input_bit_from_output_bit(potential_unwanted_component, base_ input_bit_positions_of_potential_unwanted_component = base_component.input_bit_positions[index] for i in input_bit_positions_of_potential_unwanted_component: - output_bit = { - "component_id": potential_unwanted_component.id, - "position": i, - "type": "output" - } - output_bit_name = potential_unwanted_component.id + "_" + str(i) + "_output" + output_bit = {"component_id": potential_unwanted_component.id, "position": i, "type": "output"} + output_bit_name = f"{potential_unwanted_component.id}_{i}_output" potential_unwanted_bits.append(output_bit) potential_unwanted_bits_names.append(output_bit_name) equivalent_bits = [] for potential_unwanted_bits_name in potential_unwanted_bits_names: for equivalent_bit in all_equivalent_bits[potential_unwanted_bits_name]: - if (equivalent_bit in all_bit_names.keys()) and ( - all_bit_names[equivalent_bit]["component_id"] != base_component.id) and ( - all_bit_names[equivalent_bit] in available_bits) and ( - all_bit_names[equivalent_bit]["component_id"] not in key_schedule_components) and ( - all_bit_names[equivalent_bit]["type"] == "output_updated"): # changed, line added + if ( + (equivalent_bit in all_bit_names.keys()) + and (all_bit_names[equivalent_bit]["component_id"] != base_component.id) + and (all_bit_names[equivalent_bit] in available_bits) + and (all_bit_names[equivalent_bit]["component_id"] not in key_schedule_components) + and (all_bit_names[equivalent_bit]["type"] == "output_updated") + ): # changed, line added if len(equivalent_bits) == 0: equivalent_bits.append(equivalent_bit) elif all_bit_names[equivalent_bit]["component_id"] == all_bit_names[equivalent_bits[0]]["component_id"]: @@ -359,11 +393,10 @@ def get_equivalent_input_bit_from_output_bit(potential_unwanted_component, base_ input_bit_positions.sort() return all_bit_names[equivalent_bits[0]]["component_id"], input_bit_positions -def compute_input_id_links_and_input_bit_positions_for_inverse_component_from_input_components(component, - available_bits, - all_equivalent_bits, - key_schedule_components, - self): + +def compute_input_id_links_and_input_bit_positions_for_inverse_component_from_input_components( + component, available_bits, all_equivalent_bits, key_schedule_components, self +): input_id_links = [] input_bit_positions = [] for i in range(len(component.input_id_links)): @@ -373,7 +406,7 @@ def compute_input_id_links_and_input_bit_positions_for_inverse_component_from_in bit = { "component_id": component.input_id_links[i], "position": component.input_bit_positions[i][j], - "type": "output" + "type": "output", } bits.append(bit) if not is_bit_contained_in(bit, available_bits): @@ -381,9 +414,16 @@ def compute_input_id_links_and_input_bit_positions_for_inverse_component_from_in break if component_available: potential_unwanted_component = get_component_from_id(component.input_id_links[i], self) - equivalent_component, input_bit_positions_of_equivalent_component = get_equivalent_input_bit_from_output_bit( - potential_unwanted_component, component, available_bits, all_equivalent_bits, key_schedule_components, - self) + equivalent_component, input_bit_positions_of_equivalent_component = ( + get_equivalent_input_bit_from_output_bit( + potential_unwanted_component, + component, + available_bits, + all_equivalent_bits, + key_schedule_components, + self, + ) + ) input_id_links.append(equivalent_component) input_bit_positions.append(input_bit_positions_of_equivalent_component) @@ -395,16 +435,11 @@ def component_input_bits(component): for index, link in enumerate(component.input_id_links): tmp = [] for position in component.input_bit_positions[index]: - tmp.append( - { - "component_id": link, - "position": position, - "type": "output_updated" - } - ) + tmp.append({"component_id": link, "position": position, "type": "output_updated"}) component_input_bits_list.append(tmp) return component_input_bits_list + def component_output_bits(component, self): # set of list_bits needed to invert output_components = get_output_components(component, self) @@ -412,21 +447,19 @@ def component_output_bits(component, self): for c in output_components: tmp = [] for j in range(c.output_bit_size): - bit = { - "component_id": c.id, - "position": j, - "type": "output_updated" - } + bit = {"component_id": c.id, "position": j, "type": "output_updated"} tmp.append(bit) component_output_bits_list.append(tmp) return component_output_bits_list + def are_these_bits_available(bits_list, available_bits): for bit in bits_list: if bit not in available_bits: return False return True + # def are_there_enough_available_inputs_to_evaluate_component(component, available_bits, all_equivalent_bits, key_schedule_components, # self): # # check input links @@ -466,17 +499,24 @@ def are_these_bits_available(bits_list, available_bits): # can_be_evaluated[index] = True # return sum(can_be_evaluated) == len(can_be_evaluated) -def are_there_enough_available_inputs_to_evaluate_component(component, available_bits, all_equivalent_bits, key_schedule_components, self): + +def are_there_enough_available_inputs_to_evaluate_component( + component, available_bits, all_equivalent_bits, key_schedule_components, self +): # check input links component_input_bits_list = component_input_bits(component) can_be_evaluated = [True] * len(component_input_bits_list) available_output_components = [] - if component.type in [CONSTANT, CIPHER_INPUT]: + if component.type in (CONSTANT, CIPHER_INPUT): return False for index, bits_list in enumerate(component_input_bits_list): if not are_these_bits_available(bits_list, available_bits): can_be_evaluated[index] = False - available_input_components = [get_component_from_id(c_id, self) for i,c_id in enumerate(component.input_id_links) if can_be_evaluated[i] == True] + available_input_components = [ + get_component_from_id(c_id, self) + for i, c_id in enumerate(component.input_id_links) + if can_be_evaluated[i] == True + ] if sum(can_be_evaluated) == len(can_be_evaluated): return True @@ -488,32 +528,32 @@ def are_there_enough_available_inputs_to_evaluate_component(component, available # can_be_evaluated_from_outputs = [False] * len(output_components) link_bit_names = [] for bit in component_input_bits_list[index]: - link_bit_name = bit["component_id"] + "_" + str(bit["position"]) + "_output" + link_bit_name = f"{bit['component_id']}_{bit['position']}_output" link_bit_names.append(link_bit_name) - for index_output_comp, output_component in enumerate(output_components): - if (output_component.id not in component.input_id_links) and ( - output_component.id != component.id): + for _, output_component in enumerate(output_components): + if (output_component.id not in component.input_id_links) and (output_component.id != component.id): index_id = output_component.input_id_links.index(link) starting_bit = 0 for index_list, list_bit_positions in enumerate(output_component.input_bit_positions): if index_list == index_id: break starting_bit += len(list_bit_positions) - output_component_bit_name = output_component.id + "_" + str(starting_bit) + "_output_updated" - if is_bit_adjacent_to_list_of_bits(output_component_bit_name, link_bit_names, - all_equivalent_bits): + output_component_bit_name = f"{output_component.id}_{starting_bit}_output_updated" + if is_bit_adjacent_to_list_of_bits( + output_component_bit_name, link_bit_names, all_equivalent_bits + ): # can_be_evaluated[index] = True available_output_components.append(output_component) list_of_bit_names = [] for c in available_output_components: for i in range(c.output_bit_size): - list_of_bit_names.append(c.id + "_" + str(i) + "_output_updated") + list_of_bit_names.append(f"{c.id}_{i}_output_updated") for c in available_input_components: for i in range(c.output_bit_size): - list_of_bit_names.append(c.id + "_" + str(i) + "_output") + list_of_bit_names.append(f"{c.id}_{i}_output") for i in range(component.input_bit_size): - bit_name = component.id + "_" + str(i) + "_input" + bit_name = f"{component.id}_{i}_input" if not is_bit_adjacent_to_list_of_bits(bit_name, list_of_bit_names, all_equivalent_bits): return False return True @@ -523,6 +563,7 @@ def _get_successor_components(component_id, cipher): graph_cipher = create_networkx_graph_from_input_ids(cipher) return list(graph_cipher.successors(component_id)) + def are_there_enough_available_inputs_to_perform_inversion(component, available_bits, all_equivalent_bits, self): """ NOTE: it assumes that the component input size is a multiple of the output size @@ -530,14 +571,14 @@ def are_there_enough_available_inputs_to_perform_inversion(component, available_ # STEP 1 - Special case for output components which have no output links (only cipher output) if (component.type == CIPHER_OUTPUT) or (component.id == INPUT_KEY): return True - if (component.type == INTERMEDIATE_OUTPUT and _get_successor_components(component.id, self) == []): + if component.type == INTERMEDIATE_OUTPUT and _get_successor_components(component.id, self) == []: return False # STEP 2 - Other components bit_lists_link_to_component_from_output = component_output_bits(component, self) component_output_bits_list = [] for i in range(component.output_bit_size): - component_output_bits_list.append({"component_id" : component.id, "position" : i, "type" : "output"}) + component_output_bits_list.append({"component_id": component.id, "position": i, "type": "output"}) bit_lists_link_to_component_from_output_and_available = [] for bit_list in bit_lists_link_to_component_from_output: bits_in_common = equivalent_bits_in_common(bit_list, component_output_bits_list, all_equivalent_bits) @@ -557,16 +598,25 @@ def are_there_enough_available_inputs_to_perform_inversion(component, available_ output_components = get_output_components(component_of_link, self) link_bit_names = [] for bit in bit_lists_link_to_component_from_input[index]: - link_bit_name = bit["component_id"] + "_" + str(bit["position"]) + "_output" + link_bit_name = f"{bit['component_id']}_{bit['position']}_output" link_bit_names.append(link_bit_name) for output_component in output_components: nb_available_output_component_bits = 0 - if (output_component.id not in component.input_id_links) and ( - output_component.id != component.id) and (output_component.type != INTERMEDIATE_OUTPUT): + if ( + (output_component.id not in component.input_id_links) + and (output_component.id != component.id) + and (output_component.type != INTERMEDIATE_OUTPUT) + ): for i in range(output_component.output_bit_size): - output_component_bit_name = output_component.id + "_" + str(i) + "_output_updated" - output_component_bit = {"component_id": output_component.id, "position": i, "type": "output_updated"} - if is_bit_adjacent_to_list_of_bits(output_component_bit_name, link_bit_names, all_equivalent_bits) and (output_component_bit in available_bits): + output_component_bit_name = f"{output_component.id}_{i}_output_updated" + output_component_bit = { + "component_id": output_component.id, + "position": i, + "type": "output_updated", + } + if is_bit_adjacent_to_list_of_bits( + output_component_bit_name, link_bit_names, all_equivalent_bits + ) and (output_component_bit in available_bits): nb_available_output_component_bits += 1 if nb_available_output_component_bits == output_component.output_bit_size: can_be_used_for_inversion[index] = True @@ -582,11 +632,10 @@ def are_there_enough_available_inputs_to_perform_inversion(component, available_ else: return len(bit_lists_link_to_component_from_input_and_output) >= component.input_bit_size -def is_possibly_invertible_component(component): +def is_possibly_invertible_component(component): # if sbox is a permutation - if component.type == SBOX and \ - len(list(set(component.description))) == len(component.description): + if component.type == SBOX and len(list(set(component.description))) == len(component.description): is_invertible = True # if sbox is NOT a permutation, then cannot be inverted elif component.type == SBOX and len(list(set(component.description))) != len(component.description): @@ -621,6 +670,7 @@ def is_possibly_invertible_component(component): return is_invertible + def is_intersection_of_input_id_links_null(inverse_component, component): flag_intersection_null = True for input_id_link in component.input_id_links: @@ -629,47 +679,48 @@ def is_intersection_of_input_id_links_null(inverse_component, component): if flag_intersection_null: return True, [] - if (component.type == "constant"): + if component.type == "constant": return False, list(range(component.output_bit_size)) starting_bit_position = 0 input_bit_positions = [] for index, input_id_link in enumerate(component.input_id_links): if input_id_link not in inverse_component.input_id_links: - input_bit_positions += range(starting_bit_position, starting_bit_position + len(component.input_bit_positions[index])) + input_bit_positions += range( + starting_bit_position, starting_bit_position + len(component.input_bit_positions[index]) + ) starting_bit_position += len(component.input_bit_positions[index]) return False, input_bit_positions + def find_input_id_link_bits_equivalent(inverse_component, component, all_equivalent_bits): bit_positions = [] list_of_keys = [] for index, input_id_link in enumerate(inverse_component.input_id_links): for position, i in enumerate(inverse_component.input_bit_positions[index]): - potential_equivalent_bit_name = input_id_link + "_" + str(i) + "_output_updated" + potential_equivalent_bit_name = f"{input_id_link}_{i}_output_updated" if potential_equivalent_bit_name in all_equivalent_bits.keys(): list_of_keys += all_equivalent_bits[potential_equivalent_bit_name] offset = 0 for index, input_id_link in enumerate(component.input_id_links): for pos, i in enumerate(component.input_bit_positions[index]): - output_bit_name = input_id_link + "_" + str(i) + "_output" - if output_bit_name in all_equivalent_bits and not any("output_updated" in item for item in all_equivalent_bits[output_bit_name]): + output_bit_name = f"{input_id_link}_{i}_output" + if output_bit_name in all_equivalent_bits and not any( + "output_updated" in item for item in all_equivalent_bits[output_bit_name] + ): bit_positions.append(offset + pos) offset += len(component.input_bit_positions[index]) return bit_positions -def update_output_bits(inverse_component, self, all_equivalent_bits, available_bits): +def update_output_bits(inverse_component, self, all_equivalent_bits, available_bits): def _add_output_bit_equivalences(id, bit_positions, component, all_equivalent_bits, available_bits): for i in range(component.output_bit_size): - output_bit_name_updated = id + "_" + str(i) + "_output_updated" - bit = { - "component_id": id, - "position": i, - "type": "output_updated" - } + output_bit_name_updated = f"{id}_{i}_output_updated" + bit = {"component_id": id, "position": i, "type": "output_updated"} available_bits.append(bit) - input_bit_name = id + "_" + str(bit_positions[i]) + "_input" + input_bit_name = f"{id}_{bit_positions[i]}_input" all_equivalent_bits[input_bit_name].append(output_bit_name_updated) if output_bit_name_updated not in all_equivalent_bits.keys(): all_equivalent_bits[output_bit_name_updated] = [] @@ -682,16 +733,16 @@ def _add_output_bit_equivalences(id, bit_positions, component, all_equivalent_bi id = inverse_component.id component = get_component_from_id(id, self) - if (component.description == [INPUT_KEY]) or (component.description == [INPUT_TWEAK]) or(component.type == CONSTANT): + if ( + (component.description == [INPUT_KEY]) + or (component.description == [INPUT_TWEAK]) + or (component.type == CONSTANT) + ): for i in range(component.output_bit_size): - output_bit_name_updated = id + "_" + str(i) + "_output_updated" - bit = { - "component_id": id, - "position": i, - "type": "output_updated" - } + output_bit_name_updated = f"{id}_{i}_output_updated" + bit = {"component_id": id, "position": i, "type": "output_updated"} available_bits.append(bit) - input_bit_name = id + "_" + str(i) + "_output" + input_bit_name = f"{id}_{i}_output" if input_bit_name not in all_equivalent_bits.keys(): all_equivalent_bits[input_bit_name] = [] all_equivalent_bits[input_bit_name].append(output_bit_name_updated) @@ -702,18 +753,24 @@ def _add_output_bit_equivalences(id, bit_positions, component, all_equivalent_bi if name != output_bit_name_updated: all_equivalent_bits[output_bit_name_updated].append(name) elif component.input_bit_size == component.output_bit_size: - _add_output_bit_equivalences(id, range(component.output_bit_size), component, all_equivalent_bits, available_bits) + _add_output_bit_equivalences( + id, range(component.output_bit_size), component, all_equivalent_bits, available_bits + ) else: input_bit_positions = find_input_id_link_bits_equivalent(inverse_component, component, all_equivalent_bits) _add_output_bit_equivalences(id, input_bit_positions, component, all_equivalent_bits, available_bits) + def order_input_id_links_for_modadd(component, input_id_links, input_bit_positions, available_bits, self): available_output_components_with_indices = get_available_output_components(component, available_bits, self, True) old_index = 0 for index, input_id_link in enumerate(input_id_links): - index_id_list = [_ for _, x in enumerate(available_output_components_with_indices) if - x[0].id == input_id_link and set(x[1]) == set(input_bit_positions[index])] + index_id_list = [ + _ + for _, x in enumerate(available_output_components_with_indices) + if x[0].id == input_id_link and set(x[1]) == set(input_bit_positions[index]) + ] if index_id_list: old_index = index break @@ -721,6 +778,7 @@ def order_input_id_links_for_modadd(component, input_id_links, input_bit_positio input_bit_positions.insert(0, input_bit_positions.pop(old_index)) return input_id_links, input_bit_positions + def component_inverse(component, available_bits, all_equivalent_bits, key_schedule_components, self): """ This functions assumes that the component is actually invertible. @@ -729,148 +787,237 @@ def component_inverse(component, available_bits, all_equivalent_bits, key_schedu available_output_components = get_available_output_components(component, available_bits, self) if component.type == SBOX: - input_id_links, input_bit_positions = compute_input_id_links_and_input_bit_positions_for_inverse_component_from_available_output_components(component, output_components, all_equivalent_bits, self) + input_id_links, input_bit_positions = ( + compute_input_id_links_and_input_bit_positions_for_inverse_component_from_available_output_components( + component, output_components, all_equivalent_bits, self + ) + ) S = SBox(component.description) Sinv = list(S.inverse()) - inverse_component = Component(component.id, component.type, Input(component.input_bit_size, input_id_links, input_bit_positions), component.output_bit_size, Sinv) + inverse_component = Component( + component.id, + component.type, + Input(component.input_bit_size, input_id_links, input_bit_positions), + component.output_bit_size, + Sinv, + ) inverse_component.__class__ = component.__class__ setattr(inverse_component, "round", component.round) update_output_bits(inverse_component, self, all_equivalent_bits, available_bits) elif component.type == LINEAR_LAYER: - input_id_links, input_bit_positions = compute_input_id_links_and_input_bit_positions_for_inverse_component_from_available_output_components(component, output_components, all_equivalent_bits, self) + input_id_links, input_bit_positions = ( + compute_input_id_links_and_input_bit_positions_for_inverse_component_from_available_output_components( + component, output_components, all_equivalent_bits, self + ) + ) binary_matrix = binary_matrix_of_linear_component(component) inv_binary_matrix = binary_matrix.inverse() - inverse_component = Component(component.id, component.type, - Input(component.input_bit_size, input_id_links, input_bit_positions), - component.output_bit_size, list(inv_binary_matrix)) + inverse_component = Component( + component.id, + component.type, + Input(component.input_bit_size, input_id_links, input_bit_positions), + component.output_bit_size, + list(inv_binary_matrix), + ) inverse_component.__class__ = component.__class__ setattr(inverse_component, "round", component.round) update_output_bits(inverse_component, self, all_equivalent_bits, available_bits) elif component.type == MIX_COLUMN: - input_id_links, input_bit_positions = compute_input_id_links_and_input_bit_positions_for_inverse_component_from_available_output_components( - component, available_output_components, all_equivalent_bits, self) + input_id_links, input_bit_positions = ( + compute_input_id_links_and_input_bit_positions_for_inverse_component_from_available_output_components( + component, available_output_components, all_equivalent_bits, self + ) + ) description = component.description - G = PolynomialRing(GF(2), 'x') + G = PolynomialRing(GF(2), "x") x = G.gen() irr_poly = int_to_poly(int(description[1]), int(description[2]), x) if irr_poly and not irr_poly.is_irreducible(): binary_matrix = binary_matrix_of_linear_component(component) inv_binary_matrix = binary_matrix.inverse() - inverse_component = Component(component.id, LINEAR_LAYER, Input(component.input_bit_size, input_id_links, input_bit_positions), component.output_bit_size, list(inv_binary_matrix.transpose())) + inverse_component = Component( + component.id, + LINEAR_LAYER, + Input(component.input_bit_size, input_id_links, input_bit_positions), + component.output_bit_size, + list(inv_binary_matrix.transpose()), + ) inverse_component.__class__ = linear_layer_component.LinearLayer else: inv_matrix = get_inverse_matrix_in_integer_representation(component) - inverse_component = Component(component.id, component.type, - Input(component.input_bit_size, input_id_links, input_bit_positions), - component.output_bit_size, [[list(row) for row in inv_matrix]] + component.description[1:]) + inverse_component = Component( + component.id, + component.type, + Input(component.input_bit_size, input_id_links, input_bit_positions), + component.output_bit_size, + [[list(row) for row in inv_matrix]] + component.description[1:], + ) inverse_component.__class__ = component.__class__ setattr(inverse_component, "round", component.round) update_output_bits(inverse_component, self, all_equivalent_bits, available_bits) elif component.type == WORD_OPERATION and component.description[0] == "SIGMA": - input_id_links, input_bit_positions = compute_input_id_links_and_input_bit_positions_for_inverse_component_from_available_output_components(component, output_components, all_equivalent_bits, self) + input_id_links, input_bit_positions = ( + compute_input_id_links_and_input_bit_positions_for_inverse_component_from_available_output_components( + component, output_components, all_equivalent_bits, self + ) + ) binary_matrix = binary_matrix_of_linear_component(component) inv_binary_matrix = binary_matrix.inverse() - inverse_component = Component(component.id, LINEAR_LAYER, - Input(component.input_bit_size, input_id_links, input_bit_positions), - component.output_bit_size, list(inv_binary_matrix.transpose())) + inverse_component = Component( + component.id, + LINEAR_LAYER, + Input(component.input_bit_size, input_id_links, input_bit_positions), + component.output_bit_size, + list(inv_binary_matrix.transpose()), + ) inverse_component.__class__ = component.__class__ setattr(inverse_component, "round", component.round) update_output_bits(inverse_component, self, all_equivalent_bits, available_bits) elif component.type == WORD_OPERATION and component.description[0] == "XOR": - input_id_links_from_output_components, input_bit_positions_from_output_components = compute_input_id_links_and_input_bit_positions_for_inverse_component_from_available_output_components( - component, output_components, all_equivalent_bits, self) - input_id_links_from_input_components, input_bit_positions_from_input_components = compute_input_id_links_and_input_bit_positions_for_inverse_component_from_input_components(component, available_bits, all_equivalent_bits, key_schedule_components, self) + input_id_links_from_output_components, input_bit_positions_from_output_components = ( + compute_input_id_links_and_input_bit_positions_for_inverse_component_from_available_output_components( + component, output_components, all_equivalent_bits, self + ) + ) + input_id_links_from_input_components, input_bit_positions_from_input_components = ( + compute_input_id_links_and_input_bit_positions_for_inverse_component_from_input_components( + component, available_bits, all_equivalent_bits, key_schedule_components, self + ) + ) input_id_links = input_id_links_from_input_components + input_id_links_from_output_components input_bit_positions = input_bit_positions_from_input_components + input_bit_positions_from_output_components - inverse_component = Component(component.id, component.type, - Input(component.input_bit_size, input_id_links, input_bit_positions), - component.output_bit_size, component.description) + inverse_component = Component( + component.id, + component.type, + Input(component.input_bit_size, input_id_links, input_bit_positions), + component.output_bit_size, + component.description, + ) inverse_component.__class__ = component.__class__ setattr(inverse_component, "round", component.round) update_output_bits(inverse_component, self, all_equivalent_bits, available_bits) elif component.type == WORD_OPERATION and component.description[0] == "ROTATE": - input_id_links, input_bit_positions = compute_input_id_links_and_input_bit_positions_for_inverse_component_from_available_output_components(component, available_output_components, all_equivalent_bits, self) - inverse_component = Component(component.id, component.type, - Input(component.input_bit_size, input_id_links, input_bit_positions), - component.output_bit_size, [component.description[0], -component.description[1]]) + input_id_links, input_bit_positions = ( + compute_input_id_links_and_input_bit_positions_for_inverse_component_from_available_output_components( + component, available_output_components, all_equivalent_bits, self + ) + ) + inverse_component = Component( + component.id, + component.type, + Input(component.input_bit_size, input_id_links, input_bit_positions), + component.output_bit_size, + [component.description[0], -component.description[1]], + ) inverse_component.__class__ = component.__class__ setattr(inverse_component, "round", component.round) update_output_bits(inverse_component, self, all_equivalent_bits, available_bits) elif component.type == WORD_OPERATION and component.description[0] == "NOT": - input_id_links, input_bit_positions = compute_input_id_links_and_input_bit_positions_for_inverse_component_from_available_output_components(component, available_output_components, all_equivalent_bits, self) - inverse_component = Component(component.id, component.type, - Input(component.input_bit_size, input_id_links, input_bit_positions), - component.output_bit_size, [component.description[0], component.description[1]]) + input_id_links, input_bit_positions = ( + compute_input_id_links_and_input_bit_positions_for_inverse_component_from_available_output_components( + component, available_output_components, all_equivalent_bits, self + ) + ) + inverse_component = Component( + component.id, + component.type, + Input(component.input_bit_size, input_id_links, input_bit_positions), + component.output_bit_size, + [component.description[0], component.description[1]], + ) inverse_component.__class__ = component.__class__ setattr(inverse_component, "round", component.round) update_output_bits(inverse_component, self, all_equivalent_bits, available_bits) elif component.type == WORD_OPERATION and component.description[0] == "MODADD": - input_id_links_from_output_components, input_bit_positions_from_output_components = compute_input_id_links_and_input_bit_positions_for_inverse_component_from_available_output_components( - component, available_output_components, all_equivalent_bits, self) - input_id_links_from_input_components, input_bit_positions_from_input_components = compute_input_id_links_and_input_bit_positions_for_inverse_component_from_input_components( - component, available_bits, all_equivalent_bits, key_schedule_components, self) + input_id_links_from_output_components, input_bit_positions_from_output_components = ( + compute_input_id_links_and_input_bit_positions_for_inverse_component_from_available_output_components( + component, available_output_components, all_equivalent_bits, self + ) + ) + input_id_links_from_input_components, input_bit_positions_from_input_components = ( + compute_input_id_links_and_input_bit_positions_for_inverse_component_from_input_components( + component, available_bits, all_equivalent_bits, key_schedule_components, self + ) + ) input_id_links = input_id_links_from_input_components + input_id_links_from_output_components input_bit_positions = input_bit_positions_from_input_components + input_bit_positions_from_output_components - input_id_links, input_bit_positions = order_input_id_links_for_modadd(component, input_id_links, input_bit_positions, available_bits, self) - inverse_component = Component(component.id, component.type, - Input(component.input_bit_size, input_id_links, input_bit_positions), - component.output_bit_size, ["MODSUB", component.description[1], component.description[2]]) + input_id_links, input_bit_positions = order_input_id_links_for_modadd( + component, input_id_links, input_bit_positions, available_bits, self + ) + inverse_component = Component( + component.id, + component.type, + Input(component.input_bit_size, input_id_links, input_bit_positions), + component.output_bit_size, + ["MODSUB", component.description[1], component.description[2]], + ) inverse_component.__class__ = modsub_component.MODSUB setattr(inverse_component, "round", component.round) update_output_bits(inverse_component, self, all_equivalent_bits, available_bits) elif component.type == CONSTANT: - inverse_component = Component(component.id, component.type, - Input(0, [[]], [[]]), - component.output_bit_size, component.description) + inverse_component = Component( + component.id, component.type, Input(0, [[]], [[]]), component.output_bit_size, component.description + ) inverse_component.__class__ = component.__class__ setattr(inverse_component, "round", component.round) update_output_bits(inverse_component, self, all_equivalent_bits, available_bits) elif component.type == CIPHER_OUTPUT: - inverse_component = Component(component.id, CIPHER_INPUT, - Input(0, [[]], [[]]), - component.output_bit_size, [CIPHER_INPUT]) + inverse_component = Component( + component.id, CIPHER_INPUT, Input(0, [[]], [[]]), component.output_bit_size, [CIPHER_INPUT] + ) setattr(inverse_component, "round", -1) update_output_bits(inverse_component, self, all_equivalent_bits, available_bits) - elif component.type == CIPHER_INPUT and (component.id in [INPUT_PLAINTEXT, INPUT_STATE] or INTERMEDIATE_OUTPUT in component.id): - input_id_links, input_bit_positions = compute_input_id_links_and_input_bit_positions_for_inverse_component_from_available_output_components( - component, available_output_components, all_equivalent_bits, self) - inverse_component = Component(component.id, CIPHER_OUTPUT, - Input(component.output_bit_size, input_id_links, input_bit_positions), - component.output_bit_size, [component.id]) + elif component.type == CIPHER_INPUT and ( + component.id in [INPUT_PLAINTEXT, INPUT_STATE] or INTERMEDIATE_OUTPUT in component.id + ): + input_id_links, input_bit_positions = ( + compute_input_id_links_and_input_bit_positions_for_inverse_component_from_available_output_components( + component, available_output_components, all_equivalent_bits, self + ) + ) + inverse_component = Component( + component.id, + CIPHER_OUTPUT, + Input(component.output_bit_size, input_id_links, input_bit_positions), + component.output_bit_size, + [component.id], + ) inverse_component.__class__ = cipher_output_component.CipherOutput setattr(inverse_component, "round", component.round) elif component.type == CIPHER_INPUT and (component.description == [INPUT_KEY] or component.id == INPUT_TWEAK): - inverse_component = Component(component.id, CIPHER_INPUT, - Input(0, [[]], [[]]), - component.output_bit_size, [component.id]) + inverse_component = Component( + component.id, CIPHER_INPUT, Input(0, [[]], [[]]), component.output_bit_size, [component.id] + ) inverse_component.__class__ = component.__class__ setattr(inverse_component, "round", -1) update_output_bits(inverse_component, self, all_equivalent_bits, available_bits) elif component.type == INTERMEDIATE_OUTPUT: - input_id_links, input_bit_positions = compute_input_id_links_and_input_bit_positions_for_inverse_component_from_available_output_components( - component, available_output_components, all_equivalent_bits, self) - inverse_component = Component(component.id, INTERMEDIATE_OUTPUT, - Input(component.output_bit_size, input_id_links, input_bit_positions), - component.output_bit_size, component.description) + input_id_links, input_bit_positions = ( + compute_input_id_links_and_input_bit_positions_for_inverse_component_from_available_output_components( + component, available_output_components, all_equivalent_bits, self + ) + ) + inverse_component = Component( + component.id, + INTERMEDIATE_OUTPUT, + Input(component.output_bit_size, input_id_links, input_bit_positions), + component.output_bit_size, + component.description, + ) inverse_component.__class__ = intermediate_output_component.IntermediateOutput setattr(inverse_component, "round", component.round) update_output_bits(inverse_component, self, all_equivalent_bits, available_bits) else: - inverse_component = Component("NA", "NA", - Input(0, [[]], [[]]), - component.output_bit_size, ["NA"]) + inverse_component = Component("NA", "NA", Input(0, [[]], [[]]), component.output_bit_size, ["NA"]) return inverse_component + def update_available_bits_with_component_output_bits(component, available_bits, cipher): output_components = get_output_components(component, cipher) for i in range(component.output_bit_size): - bit = { - "component_id": component.id, - "position": i, - "type": "output" - } + bit = {"component_id": component.id, "position": i, "type": "output"} add_bit_to_bit_list(bit, available_bits) # add bits of the connected output components @@ -879,17 +1026,9 @@ def update_available_bits_with_component_output_bits(component, available_bits, for i in range(len(c.input_id_links)): if c.input_id_links[i] == component.id: for j in range(len(c.input_bit_positions[i])): - component_output_bit = { - "component_id": component.id, - "position": j, - "type": "output" - } + component_output_bit = {"component_id": component.id, "position": j, "type": "output"} if is_bit_contained_in(component_output_bit, available_bits): - c_input_bit = { - "component_id": c.id, - "position": accumulator + j, - "type": "input" - } + c_input_bit = {"component_id": c.id, "position": accumulator + j, "type": "input"} add_bit_to_bit_list(c_input_bit, available_bits) accumulator += len(c.input_bit_positions[i]) return @@ -897,11 +1036,7 @@ def update_available_bits_with_component_output_bits(component, available_bits, def update_available_bits_with_component_input_bits(component, available_bits): for i in range(component.input_bit_size): - bit = { - "component_id": component.id, - "position": i, - "type": "input" - } + bit = {"component_id": component.id, "position": i, "type": "input"} add_bit_to_bit_list(bit, available_bits) # add bits of the connected input components @@ -910,7 +1045,7 @@ def update_available_bits_with_component_input_bits(component, available_bits): bit1 = { "component_id": component.input_id_links[i], "position": component.input_bit_positions[i][j], - "type": "output" + "type": "output", } add_bit_to_bit_list(bit1, available_bits) return @@ -918,33 +1053,23 @@ def update_available_bits_with_component_input_bits(component, available_bits): def all_input_bits_available(component, available_bits): for i in range(component.input_bit_size): - bit = { - "component_id": component.id, - "position": i, - "type": "input" - } + bit = {"component_id": component.id, "position": i, "type": "input"} if not is_bit_contained_in(bit, available_bits): return False return True + def all_output_updated_bits_available(component, available_bits): for i in range(component.input_bit_size): - bit = { - "component_id": component.id, - "position": i, - "type": "output_updated" - } + bit = {"component_id": component.id, "position": i, "type": "output_updated"} if not is_bit_contained_in(bit, available_bits): return False return True + def all_output_bits_available(component, available_bits): for i in range(component.output_bit_size): - bit = { - "component_id": component.id, - "position": i, - "type": "output_updated" - } + bit = {"component_id": component.id, "position": i, "type": "output_updated"} if not is_bit_contained_in(bit, available_bits): return False return True @@ -979,23 +1104,26 @@ def is_output_bits_updated_equivalent_to_input_bits(output_bits_updated_list, in return False return True + def find_correct_order(id1, list1, id2, list2, all_equivalent_bits): list2_ordered = [] for i in list1: - bit = id1 + "_" + str(i) + "_output" + bit = f"{id1}_{i}_output" for j in list2: - bit_potentially_equivalent = id2 + "_" + str(j) + "_input" + bit_potentially_equivalent = f"{id2}_{j}_input" if bit_potentially_equivalent in all_equivalent_bits[bit]: list2_ordered.append(j) break return list2_ordered + def find_correct_order_for_inversion(list1, list2, component): list2_ordered = [] for i in list1: list2_ordered.append(list2[i % component.output_bit_size]) return list2_ordered + # def evaluated_component(component, available_bits, key_schedule_component_ids, all_equivalent_bits, self): # input_id_links = [] # input_bit_positions = [] @@ -1005,7 +1133,7 @@ def find_correct_order_for_inversion(list1, list2, component): # available_output_components = get_available_output_components(component_of_link, available_bits, self) # link_bit_names = [] # for i in range(component_of_link.output_bit_size): -# link_bit_name = link + "_" + str(i) + "_output" +# link_bit_name = f"{link}_{i}_output" # link_bit_names.append(link_bit_name) # for index_available_output_component, available_output_component in enumerate(available_output_components): # if (available_output_component.id not in component.input_id_links) and ( @@ -1076,14 +1204,14 @@ def find_correct_order_for_inversion(list1, list2, component): # # id = component.id # for i in range(evaluated_component.output_bit_size): -# output_bit_name_updated = id + "_" + str(i) + "_output_updated" +# output_bit_name_updated = f"{id}_{i}_output_updated" # bit = { # "component_id": id, # "position": i, # "type": "output_updated" # } # available_bits.append(bit) -# output_bit_name = id + "_" + str(i) + "_output" +# output_bit_name = f"{id}_{i}_output" # if output_bit_name not in all_equivalent_bits.keys(): # all_equivalent_bits[output_bit_name] = [] # all_equivalent_bits[output_bit_name].append(output_bit_name_updated) @@ -1096,6 +1224,7 @@ def find_correct_order_for_inversion(list1, list2, component): # # return evaluated_component + def evaluated_component(component, available_bits, key_schedule_component_ids, all_equivalent_bits, self): input_id_links = [] input_bit_positions = [] @@ -1105,21 +1234,24 @@ def evaluated_component(component, available_bits, key_schedule_component_ids, a starting_bit_position = 0 for i in range(len(component.input_id_links)): components_with_same_input_bits = get_all_components_with_the_same_input_id_link_and_input_bit_positions( - component.input_id_links[i], component.input_bit_positions[i], self) + component.input_id_links[i], component.input_bit_positions[i], self + ) components_with_same_input_bits.remove(component) # check if the original input component has all output bits available original_input_component = get_component_from_id(component.input_id_links[i], self) output_bits_updated_list = [] for j in component.input_bit_positions[i]: - output_bit_updated_name = original_input_component.id + "_" + str(j) + "_output_updated" + output_bit_updated_name = f"{original_input_component.id}_{j}_output_updated" output_bits_updated_list.append(output_bit_updated_name) input_bits_list = [] for k in range(starting_bit_position, starting_bit_position + len(component.input_bit_positions[i])): - input_bit_name = component.id + "_" + str(k) + "_input" + input_bit_name = f"{component.id}_{k}_input" input_bits_list.append(input_bit_name) starting_bit_position += len(component.input_bit_positions[i]) - flag = is_output_bits_updated_equivalent_to_input_bits(output_bits_updated_list, input_bits_list, all_equivalent_bits) + flag = is_output_bits_updated_equivalent_to_input_bits( + output_bits_updated_list, input_bits_list, all_equivalent_bits + ) if all_output_bits_available(original_input_component, available_bits) and flag: input_id_links.append(component.input_id_links[i]) input_bit_positions.append(component.input_bit_positions[i]) @@ -1127,40 +1259,63 @@ def evaluated_component(component, available_bits, key_schedule_component_ids, a # select component for which the connected components have all their inputs available link = component.input_id_links[i] original_input_bit_positions_of_link = component.input_bit_positions[i] - available_output_components = get_available_output_components(original_input_component, available_bits, self) + available_output_components = get_available_output_components( + original_input_component, available_bits, self + ) link_bit_names = [] for l in range(original_input_component.output_bit_size): - link_bit_name = link + "_" + str(l) + "_output" + link_bit_name = f"{link}_{l}_output" link_bit_names.append(link_bit_name) - for index_available_output_component, available_output_component in enumerate( - available_output_components): + for _, available_output_component in enumerate(available_output_components): if (available_output_component.id not in component.input_id_links) and ( - available_output_component.id != component.id): - index_id_list = [_ for _, x in enumerate(available_output_component.input_id_links) if x == link and set(original_input_bit_positions_of_link) <= set(available_output_component.input_bit_positions[_])] - index_id = index_id_list[0] if index_id_list else available_output_component.input_id_links.index(link) + available_output_component.id != component.id + ): + index_id_list = [ + _ + for _, x in enumerate(available_output_component.input_id_links) + if x == link + and set(original_input_bit_positions_of_link) + <= set(available_output_component.input_bit_positions[_]) + ] + index_id = ( + index_id_list[0] if index_id_list else available_output_component.input_id_links.index(link) + ) starting_bit = 0 for index_list, list_bit_positions in enumerate(available_output_component.input_bit_positions): if index_list == index_id: break starting_bit += len(list_bit_positions) - available_output_component_bit_name = available_output_component.id + "_" + str( - starting_bit) + "_output_updated" - if is_bit_adjacent_to_list_of_bits(available_output_component_bit_name, link_bit_names, - all_equivalent_bits): + available_output_component_bit_name = ( + f"{available_output_component.id}_{starting_bit}_output_updated" + ) + if is_bit_adjacent_to_list_of_bits( + available_output_component_bit_name, link_bit_names, all_equivalent_bits + ): # if all_input_bits_available(c, available_bits): input_id_links.append(available_output_component.id) # get input bit positions - accumulator = 0 # changed + accumulator = 0 # changed for j in range(len(available_output_component.input_id_links)): if j == index_id: - if set(original_input_bit_positions_of_link) < set(available_output_component.input_bit_positions[j]): - accumulator += original_input_bit_positions_of_link[0] - available_output_component.input_bit_positions[j][0] - l = [h for h in range(accumulator, accumulator + len(component.input_bit_positions[i]))] - l_ordered = find_correct_order(link, original_input_bit_positions_of_link, available_output_component.id, l, all_equivalent_bits) + if set(original_input_bit_positions_of_link) < set( + available_output_component.input_bit_positions[j] + ): + accumulator += ( + original_input_bit_positions_of_link[0] + - available_output_component.input_bit_positions[j][0] + ) + l = list(range(accumulator, accumulator + len(component.input_bit_positions[i]))) + l_ordered = find_correct_order( + link, + original_input_bit_positions_of_link, + available_output_component.id, + l, + all_equivalent_bits, + ) input_bit_positions.append(l_ordered) break else: - accumulator += len(available_output_component.input_bit_positions[j]) # changed + accumulator += len(available_output_component.input_bit_positions[j]) # changed else: input_id_links = [[]] input_bit_positions = [[]] @@ -1170,21 +1325,22 @@ def evaluated_component(component, available_bits, key_schedule_component_ids, a del input_id_links[index] del input_bit_positions[index] - evaluated_component = Component(component.id, component.type, Input(component.input_bit_size, input_id_links, input_bit_positions), - component.output_bit_size, component.description) + evaluated_component = Component( + component.id, + component.type, + Input(component.input_bit_size, input_id_links, input_bit_positions), + component.output_bit_size, + component.description, + ) evaluated_component.__class__ = component.__class__ setattr(evaluated_component, "round", getattr(component, "round")) id = component.id for i in range(evaluated_component.output_bit_size): - output_bit_name_updated = id + "_" + str(i) + "_output_updated" - bit = { - "component_id": id, - "position": i, - "type": "output_updated" - } + output_bit_name_updated = f"{id}_{i}_output_updated" + bit = {"component_id": id, "position": i, "type": "output_updated"} available_bits.append(bit) - output_bit_name = id + "_" + str(i) + "_output" + output_bit_name = f"{id}_{i}_output" if output_bit_name not in all_equivalent_bits.keys(): all_equivalent_bits[output_bit_name] = [] all_equivalent_bits[output_bit_name].append(output_bit_name_updated) @@ -1202,6 +1358,7 @@ def cipher_find_component(cipher, round_number, component_id): rounds = cipher._rounds.round_at(round_number)._components return next((item for item in rounds if item.id == component_id), None) + def delete_orphan_links(cipher, round_number): """ Delete orphans elements from input_id_link @@ -1215,10 +1372,11 @@ def delete_orphan_links(cipher, round_number): for input_id_link in component.input_id_links: if cipher_find_component(cipher, round_number, input_id_link) == None: idx = component.input_id_links.index(input_id_link) - component.input_id_links[idx] = '' + component.input_id_links[idx] = "" new_components.append(component) return new_components + def topological_sort(round_list): """ Perform topological sort on round components. @@ -1226,7 +1384,7 @@ def topological_sort(round_list): - ``round_list`` -- list of components """ pending = [(component.id, set(component.input_id_links)) for component in round_list] - emitted = [''] + emitted = [""] while pending: next_pending = [] next_emitted = [] @@ -1244,6 +1402,7 @@ def topological_sort(round_list): pending = next_pending emitted = next_emitted + def sort_cipher_graph(cipher): """ Sorts the cipher graph in a way that @@ -1270,10 +1429,13 @@ def sort_cipher_graph(cipher): return cipher + def remove_components_from_rounds(cipher, start_round, end_round, keep_key_schedule): - list_of_rounds = cipher.rounds_as_list[:start_round] + cipher.rounds_as_list[end_round + 1:] + list_of_rounds = cipher.rounds_as_list[:start_round] + cipher.rounds_as_list[end_round + 1 :] key_schedule_component_ids = get_key_schedule_component_ids(cipher) - key_schedule_components = [cipher.get_component_from_id(id) for id in key_schedule_component_ids if INPUT_KEY not in id] + key_schedule_components = [ + cipher.get_component_from_id(id) for id in key_schedule_component_ids if INPUT_KEY not in id + ] if not keep_key_schedule: for current_round in cipher.rounds_as_list: @@ -1284,37 +1446,47 @@ def remove_components_from_rounds(cipher, start_round, end_round, keep_key_sched intermediate_outputs = {} for current_round in list_of_rounds: for component in set(current_round.components) - set(key_schedule_components): - if component.type == INTERMEDIATE_OUTPUT and component.description == ['round_output']: + if component.type == INTERMEDIATE_OUTPUT and component.description == ["round_output"]: intermediate_outputs[current_round.id] = component cipher.rounds.remove_round_component(current_round.id, component) removed_component_ids.append(component.id) return removed_component_ids, intermediate_outputs + def get_relative_position(target_link, target_bit_positions, intermediate_output): if target_link == intermediate_output.id: return target_bit_positions intermediate_output_position_links = {} current_bit_position = 0 - for input_id_link, input_bit_positions in zip(intermediate_output.input_id_links, intermediate_output.input_bit_positions): + for input_id_link, input_bit_positions in zip( + intermediate_output.input_id_links, intermediate_output.input_bit_positions + ): for i in input_bit_positions: intermediate_output_position_links[(input_id_link, i)] = current_bit_position current_bit_position += 1 - return [intermediate_output_position_links[(target_link, bit)] for bit in target_bit_positions if (target_link, bit) in intermediate_output_position_links] + return [ + intermediate_output_position_links[(target_link, bit)] + for bit in target_bit_positions + if (target_link, bit) in intermediate_output_position_links + ] + def get_most_recent_intermediate_output(target_link, intermediate_outputs): for index in sorted(intermediate_outputs, reverse=True): if target_link in intermediate_outputs[index].input_id_links or target_link == intermediate_outputs[index].id: return intermediate_outputs[index] + def update_input_links_from_rounds(cipher_rounds, removed_components, intermediate_outputs): for round in cipher_rounds: for component in round.components: for i, link in enumerate(component.input_id_links): if link in removed_components: intermediate_output = get_most_recent_intermediate_output(link, intermediate_outputs) - component.input_id_links[i] = f'{intermediate_output.id}' - component.input_bit_positions[i] = get_relative_position(link, component.input_bit_positions[i], - intermediate_output) + component.input_id_links[i] = f"{intermediate_output.id}" + component.input_bit_positions[i] = get_relative_position( + link, component.input_bit_positions[i], intermediate_output + ) diff --git a/claasp/cipher_modules/models/cp/minizinc_utils/utils.py b/claasp/cipher_modules/models/cp/minizinc_utils/utils.py index 7d6b5e638..749d3b0d8 100644 --- a/claasp/cipher_modules/models/cp/minizinc_utils/utils.py +++ b/claasp/cipher_modules/models/cp/minizinc_utils/utils.py @@ -1,3 +1,5 @@ +import os + def filter_out_strings_containing_substring(strings_list, substring): return [string for string in strings_list if substring not in string] @@ -6,14 +8,15 @@ def filter_out_strings_containing_substring(strings_list, substring): def group_strings_by_pattern(list_of_data): results = [] data = list_of_data - data = filter_out_strings_containing_substring(data, 'array') + data = filter_out_strings_containing_substring(data, "array") prefixes = set([entry.split("_y")[0].split(": ")[1] for entry in data if "_y" in entry]) # For each prefix, collect matching strings for prefix in prefixes: - sublist = [entry.split(": ")[1][:-1] for entry in data if - entry.startswith(f"var bool: {prefix}") and "_y" in entry] + sublist = [ + entry.split(": ")[1][:-1] for entry in data if entry.startswith(f"var bool: {prefix}") and "_y" in entry + ] if sublist: results.append(sublist) - return results \ No newline at end of file + return results diff --git a/claasp/cipher_modules/models/cp/mzn_model.py b/claasp/cipher_modules/models/cp/mzn_model.py index b5e2da5e8..b2a4c76ca 100644 --- a/claasp/cipher_modules/models/cp/mzn_model.py +++ b/claasp/cipher_modules/models/cp/mzn_model.py @@ -1,52 +1,51 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** - -import os import math import itertools import subprocess import time from copy import deepcopy +from datetime import timedelta from sage.crypto.sbox import SBox -from datetime import timedelta - from minizinc import Instance, Model, Solver, Status from claasp.cipher_modules.component_analysis_tests import branch_number from claasp.cipher_modules.models.cp.minizinc_utils import usefulfunctions -from claasp.cipher_modules.models.utils import write_model_to_file, convert_solver_solution_to_dictionary -from claasp.name_mappings import SBOX -from claasp.cipher_modules.models.cp.solvers import CP_SOLVERS_INTERNAL, CP_SOLVERS_EXTERNAL, MODEL_DEFAULT_PATH, SOLVER_DEFAULT +from claasp.cipher_modules.models.utils import convert_solver_solution_to_dictionary +from claasp.name_mappings import CIPHER, CIPHER_OUTPUT, INTERMEDIATE_OUTPUT, SATISFIABLE, SBOX, UNSATISFIABLE +from claasp.cipher_modules.models.cp.solvers import ( + CP_SOLVERS_INTERNAL, + CP_SOLVERS_EXTERNAL, + SOLVER_DEFAULT, +) -solve_satisfy = 'solve satisfy;' -constraint_type_error = 'Constraint type not defined' +SOLVE_SATISFY = "solve satisfy;" +CONSTRAINT_TYPE_ERROR = "Constraint type not defined" class MznModel: - - def __init__(self, cipher, window_size_list=None, probability_weight_per_round=None, sat_or_milp='sat'): + def __init__(self, cipher, window_size_list=None, probability_weight_per_round=None, sat_or_milp="sat"): self._cipher = cipher self.initialise_model() - if sat_or_milp not in ['sat', 'milp']: + if sat_or_milp not in ("sat", "milp"): raise ValueError("Allowed value for sat_or_milp parameter is either sat or milp") self.sat_or_milp = sat_or_milp @@ -78,7 +77,6 @@ def __init__(self, cipher, window_size_list=None, probability_weight_per_round=N if window_size_list and len(window_size_list) != self._cipher.number_of_rounds: raise ValueError("window_size_list size must be equal to cipher_number_of_rounds") - def initialise_model(self): self._variables_list = [] self._model_constraints = [] @@ -98,7 +96,7 @@ def initialise_model(self): self.list_of_xor_components = [] self.list_of_xor_all_inputs = [] self.component_and_probability = {} - self._model_prefix = ['include "globals.mzn";', f'{usefulfunctions.MINIZINC_USEFUL_FUNCTIONS}'] + self._model_prefix = ['include "globals.mzn";', f"{usefulfunctions.MINIZINC_USEFUL_FUNCTIONS}"] def add_comment(self, comment): """ @@ -108,16 +106,26 @@ def add_comment(self, comment): - ``comment`` -- **string**; string with the comment to be added """ - self.mzn_comments.append("% " + comment) + self.mzn_comments.append(f"% {comment}") def add_constraint_from_str(self, str_constraint): self._model_constraints.append(str_constraint) def add_output_comment(self, comment): - self.mzn_output_directives.append(f'output [\"Comment: {comment}\", \"\\n\"];') - - def add_solutions_from_components_values(self, components_values, memory, model_type, solutions, solve_time, - solver_name, solver_output, total_weight, solve_external): + self.mzn_output_directives.append(f'output ["Comment: {comment}", "\\n"];') + + def add_solutions_from_components_values( + self, + components_values, + memory, + model_type, + solutions, + solve_time, + solver_name, + solver_output, + total_weight, + solve_external, + ): for i in range(len(total_weight)): solution = convert_solver_solution_to_dictionary( self._cipher, @@ -125,39 +133,44 @@ def add_solutions_from_components_values(self, components_values, memory, model_ solver_name, solve_time, memory, - components_values[f'solution{i + 1}'], - total_weight[i]) + components_values[f"solution{i + 1}"], + total_weight[i], + ) if solve_external: - if 'UNSATISFIABLE' in solver_output[0]: - solution['status'] = 'UNSATISFIABLE' + if UNSATISFIABLE in solver_output[0]: + solution["status"] = UNSATISFIABLE else: - solution['status'] = 'SATISFIABLE' + solution["status"] = SATISFIABLE else: if solver_output.status not in [Status.SATISFIED, Status.ALL_SOLUTIONS, Status.OPTIMAL_SOLUTION]: - solution['status'] = 'UNSATISFIABLE' + solution["status"] = UNSATISFIABLE else: - solution['status'] = 'SATISFIABLE' + solution["status"] = SATISFIABLE solutions.append(solution) - def add_solution_to_components_values(self, component_id, component_solution, components_values, j, output_to_parse, - solution_number, string): + def add_solution_to_components_values( + self, component_id, component_solution, components_values, j, output_to_parse, solution_number, string + ): + if f"solution{solution_number}" not in components_values: + components_values[f"solution{solution_number}"] = {} if component_id in self._cipher.inputs: - component_solution['weight'] = 0 - components_values[f'solution{solution_number}'][f'{component_id}'] = component_solution - elif f'{component_id}_i' in string: - component_solution['weight'] = float(output_to_parse[j + 2]) - components_values[f'solution{solution_number}'][f'{component_id}_i'] = component_solution - elif f'{component_id}_o' in string: - component_solution['weight'] = float(output_to_parse[j + 1]) - components_values[f'solution{solution_number}'][f'{component_id}_o'] = component_solution - elif f'{component_id} ' in string: - component_solution['weight'] = float(output_to_parse[j + 1]) - components_values[f'solution{solution_number}'][f'{component_id}'] = component_solution - - def add_solution_to_components_values_internal(self, component_solution, components_values, component_weight, - solution_number, component): - component_solution['weight'] = component_weight - components_values[f'solution{solution_number}'][f'{component}'] = component_solution + component_solution["weight"] = 0 + components_values[f"solution{solution_number}"][f"{component_id}"] = component_solution + elif f"{component_id}_i" in string: + component_solution["weight"] = float(output_to_parse[j + 2]) + components_values[f"solution{solution_number}"][f"{component_id}_i"] = component_solution + elif f"{component_id}_o" in string: + component_solution["weight"] = float(output_to_parse[j + 1]) + components_values[f"solution{solution_number}"][f"{component_id}_o"] = component_solution + elif f"{component_id} " in string: + component_solution["weight"] = float(output_to_parse[j + 1]) + components_values[f"solution{solution_number}"][f"{component_id}"] = component_solution + + def add_solution_to_components_values_internal( + self, component_solution, components_values, component_weight, solution_number, component + ): + component_solution["weight"] = component_weight + components_values[f"solution{solution_number}"][f"{component}"] = component_solution def build_mix_column_truncated_table(self, component): """ @@ -180,20 +193,22 @@ def build_mix_column_truncated_table(self, component): input_size = int(component.input_bit_size) output_size = int(component.output_bit_size) output_id_link = component.id - branch = branch_number(component, 'differential', 'word') + branch = branch_number(component, "differential", "word") total_size = (input_size + output_size) // self.word_size - table_items = '' + table_items = "" solutions = 0 - for i in range(2 ** total_size): - binary_i = f'{i:0{total_size}b}' + for i in range(2**total_size): + binary_i = f"{i:0{total_size}b}" bit_sum = sum(int(x) for x in binary_i) if bit_sum == 0 or bit_sum >= branch: table_items += binary_i solutions += 1 - table = ','.join(table_items) - mix_column_table = f'array[0..{solutions - 1}, 1..{total_size}] of int: ' \ - f'mix_column_truncated_table_{output_id_link} = ' \ - f'array2d(0..{solutions - 1}, 1..{total_size}, [{table}]);' + table = ",".join(table_items) + mix_column_table = ( + f"array[0..{solutions - 1}, 1..{total_size}] of int: " + f"mix_column_truncated_table_{output_id_link} = " + f"array2d(0..{solutions - 1}, 1..{total_size}, [{table}]);" + ) return mix_column_table @@ -218,20 +233,25 @@ def calculate_bit_values(self, bit_values, input_length): return new_bit_values - def calculate_input_bit_positions(self, word_index, input_name_1, input_name_2, new_input_bit_positions_1, - new_input_bit_positions_2): + def calculate_input_bit_positions( + self, word_index, input_name_1, input_name_2, new_input_bit_positions_1, new_input_bit_positions_2 + ): input_bit_positions = [[] for _ in range(3)] if input_name_1 != input_name_2: - input_bit_positions[0] = [int(new_input_bit_positions_1) * self.word_size + index - for index in range(self.word_size)] + input_bit_positions[0] = [ + int(new_input_bit_positions_1) * self.word_size + index for index in range(self.word_size) + ] input_bit_positions[1] = [word_index * self.word_size + index for index in range(self.word_size)] - input_bit_positions[2] = [int(new_input_bit_positions_2) * self.word_size + index - for index in range(self.word_size)] + input_bit_positions[2] = [ + int(new_input_bit_positions_2) * self.word_size + index for index in range(self.word_size) + ] else: - input_bit_positions[0] = [int(new_input_bit_positions_1) * self.word_size + index - for index in range(self.word_size)] - input_bit_positions[0] += [int(new_input_bit_positions_2) * self.word_size + index - for index in range(self.word_size)] + input_bit_positions[0] = [ + int(new_input_bit_positions_1) * self.word_size + index for index in range(self.word_size) + ] + input_bit_positions[0] += [ + int(new_input_bit_positions_2) * self.word_size + index for index in range(self.word_size) + ] input_bit_positions[1] = [word_index * self.word_size + index for index in range(self.word_size)] return input_bit_positions @@ -280,7 +300,7 @@ def find_possible_number_of_active_sboxes(self, weight): return numbers_of_active_sboxes - def fix_variables_value_constraints(self, fixed_variables=[], step='full_model'): + def fix_variables_value_constraints(self, fixed_variables=[], step="full_model"): r""" Return a list of CP constraints that fix the input variables to a specific value. @@ -315,26 +335,35 @@ def fix_variables_value_constraints(self, fixed_variables=[], step='full_model') """ cp_constraints = [] for component in fixed_variables: - component_id = component['component_id'] - bit_positions = component['bit_positions'] - bit_values = component['bit_values'] - if step == 'first_step': + component_id = component["component_id"] + bit_positions = component["bit_positions"] + bit_values = component["bit_values"] + if step == "first_step": if not self._cipher.is_spn(): - raise ValueError('Cipher is not SPN') + raise ValueError("Cipher is not SPN") input_length = len(bit_positions) // self.word_size bit_positions = self.calculate_bit_positions(bit_positions, input_length) bit_values = self.calculate_bit_values(bit_values, input_length) - if component['constraint_type'] == 'equal': - sign = '=' - logic_operator = ' /\\ ' - elif component['constraint_type'] == 'not_equal': - sign = '!=' - logic_operator = ' \\/ ' + if component["constraint_type"] == "equal": + sign = "=" + logic_operator = " /\\ " + elif component["constraint_type"] == "not_equal": + sign = "!=" + logic_operator = " \\/ " else: - raise ValueError(constraint_type_error) - values_constraints = [f'{component_id}[{position}] {sign} {bit_values[i]}' - for i, position in enumerate(bit_positions)] - new_constraint = 'constraint ' + f'{logic_operator}'.join(values_constraints) + ';' + raise ValueError(CONSTRAINT_TYPE_ERROR) + if bit_values[0] not in [0,1]: + variables_values = [] + for v in bit_values: + variables_values.extend([(v[0], i) for i in v[1]]) + values_constraints = [ + f"{component_id}[{position}] {sign} {variables_values[i][0]}[{variables_values[i][1]}]" for i, position in enumerate(bit_positions) + ] + else: + values_constraints = [ + f"{component_id}[{position}] {sign} {bit_values[i]}" for i, position in enumerate(bit_positions) + ] + new_constraint = "constraint " + f"{logic_operator}".join(values_constraints) + ";" cp_constraints.append(new_constraint) return cp_constraints @@ -375,27 +404,28 @@ def fix_variables_value_constraints_for_ARX(self, fixed_variables=[]): sage: minizinc.fix_variables_value_constraints_for_ARX(fixed_variables)[0] 'constraint plaintext_y0+plaintext_y1+plaintext_y2+plaintext_y3>0;' """ + def equal_operator(constraints_, fixed_variables_object_): component_name = fixed_variables_object_["component_id"] for i in range(len(fixed_variables_object_["bit_positions"])): bit_position = fixed_variables_object_["bit_positions"][i] bit_value = fixed_variables_object_["bit_values"][i] - constraints_.append(f'constraint {component_name}_y{bit_position} = {bit_value};') - if 'intermediate_output' in component_name or 'cipher_output' in component_name: - constraints_.append(f'constraint {component_name}_x{bit_position}' - f'=' - f'{bit_value};') + constraints_.append(f"constraint {component_name}_y{bit_position} = {bit_value};") + if INTERMEDIATE_OUTPUT in component_name or CIPHER_OUTPUT in component_name: + constraints_.append(f"constraint {component_name}_x{bit_position}={bit_value};") def sum_operator(constraints_, fixed_variables_object_): component_name = fixed_variables_object_["component_id"] bit_positions = [] for i in range(len(fixed_variables_object_["bit_positions"])): bit_position = fixed_variables_object_["bit_positions"][i] - bit_var_name_position = f'{component_name}_y{bit_position}' + bit_var_name_position = f"{component_name}_y{bit_position}" bit_positions.append(bit_var_name_position) - constraints_.append(f'constraint {"+".join(bit_positions)}' - f'{fixed_variables_object_["operator"]}' - f'{fixed_variables_object_["value"]};') + constraints_.append( + f"constraint {'+'.join(bit_positions)}" + f"{fixed_variables_object_['operator']}" + f"{fixed_variables_object_['value']};" + ) constraints = [] @@ -408,56 +438,56 @@ def sum_operator(constraints_, fixed_variables_object_): return constraints def format_component_value(self, component_id, string): - if f'{component_id}_i' in string: - value = string.replace(f'{component_id}_i', '') - elif f'{component_id}_o' in string: - value = string.replace(f'{component_id}_o', '') - elif f'inverse_{component_id}' in string: - value = string.replace(f'inverse_{component_id}', '') - elif f'{component_id}' in string: - value = string.replace(component_id, '') - value = value.replace('= [', '') - value = value.replace(']', '') - value = value.replace(',', '') - value = value.replace(' ', '') + if f"{component_id}_i" in string: + value = string.replace(f"{component_id}_i", "") + elif f"{component_id}_o" in string: + value = string.replace(f"{component_id}_o", "") + elif f"inverse_{component_id}" in string: + value = string.replace(f"inverse_{component_id}", "") + elif f"{component_id}" in string: + value = string.replace(component_id, "") + value = value.replace("= [", "") + value = value.replace("]", "") + value = value.replace(",", "") + value = value.replace(" ", "") return value - def get_command_for_solver_process(self, input_file_path, model_type, solver_name, num_of_processors, timelimit): - solvers = ['xor_differential_one_solution', - 'xor_linear_one_solution', - 'deterministic_truncated_xor_differential_one_solution', - 'impossible_xor_differential_one_solution', - 'differential_pair_one_solution', - 'evaluate_cipher'] - write_model_to_file(self._model_constraints, input_file_path) + def get_command_for_solver_process(self, model_type, solver_name, num_of_processors, timelimit): + solvers = ( + "deterministic_truncated_xor_differential_one_solution", + "differential_pair_one_solution", + "impossible_xor_differential_one_solution", + "xor_differential_one_solution", + "xor_linear_one_solution", + CIPHER, + ) found_name = False for i in range(len(CP_SOLVERS_EXTERNAL)): - if solver_name == CP_SOLVERS_EXTERNAL[i]['solver_name']: - command_options = deepcopy(CP_SOLVERS_EXTERNAL[i]) + if solver_name == CP_SOLVERS_EXTERNAL[i]["solver_name"]: + command_options = deepcopy(CP_SOLVERS_EXTERNAL[i]["keywords"]["command"]) found_name = True if not found_name: - raise(NameError(f'Solver {solver_name} not defined. Specify a valid solver name.')) - command_options['keywords']['command']['input_file'].append(input_file_path) + raise NameError(f"Solver {solver_name} not defined. Specify a valid solver name.") if model_type not in solvers: - command_options['keywords']['command']['options'].insert(0, '-a') + command_options["options"].insert(0, "-a") if num_of_processors is not None: - command_options['keywords']['command']['options'].insert(0, f'-p {num_of_processors}') + command_options["options"].insert(0, f"-p {num_of_processors}") if timelimit is not None: - command_options['keywords']['command']['options'].append('--time-limit') - command_options['keywords']['command']['options'].append(str(timelimit)) + command_options["options"].extend(["--time-limit", str(timelimit)]) command = [] - for key in command_options['keywords']['command']['format']: - command.extend(command_options['keywords']['command'][key]) - + for key in command_options["format"]: + command.extend(command_options[key]) + return command def get_mix_column_all_inputs(self, input_bit_positions_1, input_id_link_1, numb_of_inp_1): all_inputs = [] for i in range(numb_of_inp_1): for j in range(len(input_bit_positions_1[i]) // self.word_size): - all_inputs.append(f'{input_id_link_1[i]}' - f'[{input_bit_positions_1[i][j * self.word_size] // self.word_size}]') + all_inputs.append( + f"{input_id_link_1[i]}[{input_bit_positions_1[i][j * self.word_size] // self.word_size}]" + ) return all_inputs @@ -467,18 +497,19 @@ def get_total_weight(self, string_total_weight): elif string_total_weight is None: total_weight = None else: - total_weight = [str(int(w)/100.0) for w in string_total_weight] + total_weight = [str(int(w) / 100.0) for w in string_total_weight] return total_weight def output_probability_per_round(self): for mzn_probability_modadd_vars in self.probability_modadd_vars_per_round: mzn_probability_vars_per_round = "++".join(mzn_probability_modadd_vars) - self.mzn_output_directives.append(f'output ["\\n"++"Probability {mzn_probability_vars_per_round}:' - f' "++show(sum({mzn_probability_vars_per_round}))++"\\n"];') - - def parse_solver_information(self, output_to_parse, truncated=False, solve_external = True): + self.mzn_output_directives.append( + f'output ["\\n"++"Probability {mzn_probability_vars_per_round}:' + f' "++show(sum({mzn_probability_vars_per_round}))++"\\n"];' + ) + def parse_solver_information(self, output_to_parse, truncated=False, solve_external=True): memory = -1 time = -1 string_total_weight = [] @@ -486,24 +517,24 @@ def parse_solver_information(self, output_to_parse, truncated=False, solve_exter number_of_solutions = 1 if solve_external: for string in output_to_parse: - if 'time=' in string: + if "time=" in string: time_string = string time = float(time_string.replace("%%%mzn-stat: time=", "")) - elif 'solveTime=' in string: + elif "solveTime=" in string: time_string = string time = float(time_string.replace("%%%mzn-stat: solveTime=", "")) - elif 'trailMem=' in string: + elif "trailMem=" in string: memory_string = string memory = float(memory_string.replace("%%%mzn-stat: trailMem=", "")) - elif 'Trail weight' in string and not truncated: + elif "Trail weight" in string and not truncated: string_total_weight.append(float(string.replace("Trail weight = ", ""))) - components_values[f'solution{number_of_solutions}'] = {} + components_values[f"solution{number_of_solutions}"] = {} number_of_solutions += 1 - elif '----------' in string and truncated: + elif "----------" in string and truncated: string_total_weight.append("0") - components_values[f'solution{number_of_solutions}'] = {} + components_values[f"solution{number_of_solutions}"] = {} number_of_solutions += 1 - elif 'UNSATISFIABLE' in string: + elif UNSATISFIABLE in string: string_total_weight = None if number_of_solutions == 1: components_values = {} @@ -513,7 +544,9 @@ def parse_solver_information(self, output_to_parse, truncated=False, solve_exter return components_values, memory, time return components_values, memory, time, total_weight - def _parse_solver_output(self, output_to_parse, model_type, truncated = False, solve_external = False, solver_name = SOLVER_DEFAULT): + def _parse_solver_output( + self, output_to_parse, model_type, truncated=False, solve_external=False, solver_name=SOLVER_DEFAULT + ): """ Parse solver solution (if needed). @@ -532,11 +565,10 @@ def _parse_solver_output(self, output_to_parse, model_type, truncated = False, s sage: fixed_variables = [set_fixed_variables('key', 'equal', range(64), integer_to_bit_list(0, 64, 'little'))] sage: fixed_variables.append(set_fixed_variables('plaintext', 'equal', range(32), integer_to_bit_list(0, 32, 'little'))) sage: cp.build_xor_differential_trail_model(-1, fixed_variables) - sage: write_model_to_file(cp._model_constraints,'doctesting_file.mzn') - sage: command = ['minizinc', '--solver-statistics', '--solver', 'Chuffed', 'doctesting_file.mzn'] + sage: command = ['minizinc', '--input-from-stdin', '--solver-statistics', '--solver', 'chuffed'] sage: import subprocess - sage: solver_process = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8") - sage: os.remove('doctesting_file.mzn') + sage: model = '\n'.join(cp._model_constraints) + '\n' + sage: solver_process = subprocess.run(command, input=model, capture_output=True, text=True) sage: solver_output = solver_process.stdout.splitlines() sage: cp._parse_solver_output(solver_output, model_type = 'xor_differential_one_solution', solve_external = True) # random (0.018, @@ -544,74 +576,96 @@ def _parse_solver_output(self, output_to_parse, model_type, truncated = False, s 'cipher_output_3_12': {'value': '0', 'weight': 0}}}, ['0']) """ + def set_solution_values_internal(solution): components_values = {} - values = solution.__dict__['_output_item'].splitlines() + values = solution.__dict__["_output_item"].splitlines() total_weight = 0 for i in range(len(values)): curr_val = values[i] - if 'Trail weight' in curr_val: - total_weight = str(int(curr_val[curr_val.index('=') + 2:])/100.0) - elif '[' in curr_val: - component_id = curr_val[:curr_val.index('=') - 1] - value = ''.join(curr_val[curr_val.index('[') + 1:-1].split(', ')) + if "Trail weight" in curr_val: + total_weight = str(int(curr_val[curr_val.index("=") + 2 :]) / 100.0) + elif "[" in curr_val: + component_id = curr_val[: curr_val.index("=") - 1] + value = "".join(curr_val[curr_val.index("[") + 1 : -1].split(", ")) components_values[component_id] = {} self.set_component_solution_value(components_values[component_id], truncated, value) - if '=' not in values[i+1]: - components_values[component_id]['weight'] = float(values[i+1]) + if "=" not in values[i + 1]: + components_values[component_id]["weight"] = float(values[i + 1]) else: - components_values[component_id]['weight'] = 0 + components_values[component_id]["weight"] = 0 return components_values, total_weight - + if solve_external: if truncated: - components_values, memory, time = self.parse_solver_information(output_to_parse, truncated, solve_external) + components_values, memory, time = self.parse_solver_information( + output_to_parse, truncated, solve_external + ) else: - components_values, memory, time, total_weight = self.parse_solver_information(output_to_parse, truncated, solve_external) + components_values, memory, time, total_weight = self.parse_solver_information( + output_to_parse, truncated, solve_external + ) all_components = [*self._cipher.inputs, *self._cipher.get_all_components_ids()] for component_id in all_components: solution_number = 1 for j, string in enumerate(output_to_parse): - if f'{component_id} ' in string or f'{component_id}_i' in string or f'{component_id}_o' in string: + if f"{component_id} " in string or f"{component_id}_i" in string or f"{component_id}_o" in string: value = self.format_component_value(component_id, string) component_solution = {} self.set_component_solution_value(component_solution, truncated, value) - self.add_solution_to_components_values(component_id, component_solution, components_values, j, - output_to_parse, solution_number, string) - elif '----------' in string: + self.add_solution_to_components_values( + component_id, + component_solution, + components_values, + j, + output_to_parse, + solution_number, + string, + ) + elif "----------" in string: solution_number += 1 else: - if 'solveTime' in output_to_parse.statistics: - time = output_to_parse.statistics['solveTime'].total_seconds() + if "solveTime" in output_to_parse.statistics: + time = output_to_parse.statistics["solveTime"].total_seconds() else: - time = output_to_parse.statistics['time'].total_seconds() - if 'trailMem' in output_to_parse.statistics: - memory = output_to_parse.statistics['trailMem'] + time = output_to_parse.statistics["time"].total_seconds() + if "trailMem" in output_to_parse.statistics: + memory = output_to_parse.statistics["trailMem"] else: - memory = '-1' + memory = "-1" if output_to_parse.status not in [Status.SATISFIED, Status.ALL_SOLUTIONS, Status.OPTIMAL_SOLUTION]: - solutions = convert_solver_solution_to_dictionary(self._cipher, model_type, solver_name, time, memory, {}, '0') - solutions['status'] = 'UNSATISFIABLE' + solutions = convert_solver_solution_to_dictionary( + self._cipher, model_type, solver_name, time, memory, {}, "0" + ) + solutions["status"] = UNSATISFIABLE else: - if output_to_parse.statistics['nSolutions'] == 1 or type(output_to_parse.solution) != list: + if output_to_parse.statistics["nSolutions"] == 1 or (not isinstance(output_to_parse.solution, list)): components_values, total_weight = set_solution_values_internal(output_to_parse.solution) - solutions = convert_solver_solution_to_dictionary(self._cipher, model_type, solver_name, time, memory, components_values, total_weight) + solutions = convert_solver_solution_to_dictionary( + self._cipher, model_type, solver_name, time, memory, components_values, total_weight + ) if output_to_parse.status not in [Status.SATISFIED, Status.ALL_SOLUTIONS, Status.OPTIMAL_SOLUTION]: - solutions['status'] = 'UNSATISFIABLE' + solutions["status"] = UNSATISFIABLE else: - solutions['status'] = 'SATISFIABLE' + solutions["status"] = SATISFIABLE else: solutions = [] for solution in output_to_parse.solution: components_values, total_weight = set_solution_values_internal(solution) - solution = convert_solver_solution_to_dictionary(self._cipher, model_type, solver_name, time, memory, components_values, total_weight) - if output_to_parse.status not in [Status.SATISFIED, Status.ALL_SOLUTIONS, Status.OPTIMAL_SOLUTION]: - solution['status'] = 'UNSATISFIABLE' + solution = convert_solver_solution_to_dictionary( + self._cipher, model_type, solver_name, time, memory, components_values, total_weight + ) + if output_to_parse.status not in [ + Status.SATISFIED, + Status.ALL_SOLUTIONS, + Status.OPTIMAL_SOLUTION, + ]: + solution["status"] = UNSATISFIABLE else: - solution['status'] = 'SATISFIABLE' + solution["status"] = SATISFIABLE solutions.append(solution) return solutions - + if truncated: return time, memory, components_values return time, memory, components_values, total_weight @@ -619,16 +673,26 @@ def set_solution_values_internal(solution): def set_component_solution_value(self, component_solution, truncated, value): if not truncated: bin_value = int(value, 2) - hex_value = f'{bin_value:x}' - hex_value = ('0x' + '0' * (math.ceil(len(value) / 4) - len(hex_value))) + hex_value - component_solution['value'] = hex_value + hex_value = f"{bin_value:x}" + hex_value = ("0x" + "0" * (math.ceil(len(value) / 4) - len(hex_value))) + hex_value + component_solution["value"] = hex_value else: - component_solution['value'] = value - - def solve(self, model_type, solver_name=SOLVER_DEFAULT, solve_external=False, timeout_in_seconds_=None, - processes_=None, nr_solutions_=None, random_seed_=None, - all_solutions_=False, intermediate_solutions_=False, - free_search_=False, optimisation_level_=None): + component_solution["value"] = value + + def solve( + self, + model_type, + solver_name=SOLVER_DEFAULT, + solve_external=False, + timeout_in_seconds_=None, + processes_=None, + nr_solutions_=None, + random_seed_=None, + all_solutions_=False, + intermediate_solutions_=False, + free_search_=False, + optimisation_level_=None, + ): """ Return the solution of the model. @@ -644,11 +708,8 @@ def solve(self, model_type, solver_name=SOLVER_DEFAULT, solve_external=False, ti * 'deterministic_truncated_xor_differential' * 'deterministic_truncated_xor_differential_one_solution' * 'impossible_xor_differential' - - ``solver_name`` -- **string** (default: `None`); the name of the solver. Available values are: - - * ``'Chuffed'`` - * ``'Gecode'`` - * ``'COIN-BC'`` + - ``solver_name`` -- **string** (default: `None`); the name of the solver. + See also :meth:`MznModel.solver_names`. - ``num_of_processors`` -- **integer**; the number of processors to be used - ``timelimit`` -- **integer**; time limit to output a result @@ -667,24 +728,22 @@ def solve(self, model_type, solver_name=SOLVER_DEFAULT, solve_external=False, ti 'total_weight': '5.0'}] """ truncated = False - if model_type in ['deterministic_truncated_xor_differential', - 'deterministic_truncated_xor_differential_one_solution', - 'impossible_xor_differential', - 'impossible_xor_differential_one_solution', - 'impossible_xor_differential_attack']: + if model_type in ( + "deterministic_truncated_xor_differential_one_solution", + "deterministic_truncated_xor_differential", + "impossible_xor_differential_attack", + "impossible_xor_differential_one_solution", + "impossible_xor_differential", + ): truncated = True solutions = [] if solve_external: - cipher_name = self.cipher_id - input_file_path = f'{MODEL_DEFAULT_PATH}/{cipher_name}_Mzn_{model_type}_{solver_name}.mzn' - command = self.get_command_for_solver_process( - input_file_path, model_type, solver_name, processes_, timeout_in_seconds_ - ) + command = self.get_command_for_solver_process(model_type, solver_name, processes_, timeout_in_seconds_) + model = "\n".join(self._model_constraints) + "\n" start = time.time() - solver_process = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8") + solver_process = subprocess.run(command, input=model, capture_output=True, text=True) end = time.time() solve_time = end - start - os.remove(input_file_path) if solver_process.returncode >= 0: solver_output = solver_process.stdout.splitlines() else: @@ -696,46 +755,83 @@ def solve(self, model_type, solver_name=SOLVER_DEFAULT, solve_external=False, ti instance = Instance(solver_name_mzn, bit_mzn_model) start = time.time() if processes_ != None and timeout_in_seconds_ != None: - solver_output = instance.solve(processes=processes_, timeout=timedelta(seconds=int(timeout_in_seconds_)), - nr_solutions=nr_solutions_, random_seed=random_seed_, all_solutions=all_solutions_, - intermediate_solutions=intermediate_solutions_, free_search=free_search_, - optimisation_level=optimisation_level_) + solver_output = instance.solve( + processes=processes_, + timeout=timedelta(seconds=int(timeout_in_seconds_)), + nr_solutions=nr_solutions_, + random_seed=random_seed_, + all_solutions=all_solutions_, + intermediate_solutions=intermediate_solutions_, + free_search=free_search_, + optimisation_level=optimisation_level_, + ) else: - solver_output = instance.solve(nr_solutions=nr_solutions_, random_seed=random_seed_, all_solutions=all_solutions_, - intermediate_solutions=intermediate_solutions_, free_search=free_search_, - optimisation_level=optimisation_level_) + solver_output = instance.solve( + nr_solutions=nr_solutions_, + random_seed=random_seed_, + all_solutions=all_solutions_, + intermediate_solutions=intermediate_solutions_, + free_search=free_search_, + optimisation_level=optimisation_level_, + ) end = time.time() solve_time = end - start - return self._parse_solver_output(solver_output, model_type, truncated = truncated, solve_external = solve_external, solver_name=solver_name) + return self._parse_solver_output( + solver_output, model_type, truncated=truncated, solve_external=solve_external, solver_name=solver_name + ) if truncated: - solver_time, memory, components_values = self._parse_solver_output(solver_output, model_type, truncated = True, solve_external = solve_external) + solver_time, memory, components_values = self._parse_solver_output( + solver_output, model_type, truncated=True, solve_external=solve_external + ) total_weight = 0 else: - solver_time, memory, components_values, total_weight = self._parse_solver_output(solver_output, model_type, solve_external = solve_external, solver_name=solver_name) + solver_time, memory, components_values, total_weight = self._parse_solver_output( + solver_output, model_type, solve_external=solve_external, solver_name=solver_name + ) if components_values == {}: - solution = convert_solver_solution_to_dictionary(self._cipher, model_type, solver_name, - solve_time, memory, - components_values, total_weight) - if '=====UNSATISFIABLE=====' in solver_output: - solution['status'] = 'UNSATISFIABLE' + solution = convert_solver_solution_to_dictionary( + self._cipher, model_type, solver_name, solve_time, memory, components_values, total_weight + ) + if "=====UNSATISFIABLE=====" in solver_output: + solution["status"] = UNSATISFIABLE else: - solution['status'] = 'SATISFIABLE' + solution["status"] = SATISFIABLE solutions.append(solution) else: - self.add_solutions_from_components_values(components_values, memory, model_type, solutions, solve_time, - solver_name, solver_output, total_weight, solve_external) - if model_type in ['xor_differential_one_solution', - 'xor_linear_one_solution', - 'deterministic_truncated_one_solution', - 'impossible_xor_differential_one_solution']: + self.add_solutions_from_components_values( + components_values, + memory, + model_type, + solutions, + solve_time, + solver_name, + solver_output, + total_weight, + solve_external, + ) + if model_type in ( + "deterministic_truncated_one_solution", + "impossible_xor_differential_one_solution", + "xor_differential_one_solution", + "xor_linear_one_solution", + CIPHER, + ): return solutions[0] else: return solutions - - def solve_for_ARX(self, solver_name=None, timeout_in_seconds_=30, - processes_=4, nr_solutions_=None, random_seed_=None, - all_solutions_=False, intermediate_solutions_=False, - free_search_=False, optimisation_level_=None): + + def solve_for_ARX( + self, + solver_name=None, + timeout_in_seconds_=30, + processes_=4, + nr_solutions_=None, + random_seed_=None, + all_solutions_=False, + intermediate_solutions_=False, + free_search_=False, + optimisation_level_=None, + ): """ Solve the model passed in `str_model_path` by using `MiniZinc` and `str_solver``. @@ -770,6 +866,7 @@ def solve_for_ARX(self, solver_name=None, timeout_in_seconds_=30, sage: from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher sage: from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_differential_model_arx_optimized import MznXorDifferentialModelARXOptimized + sage: from claasp.cipher_modules.models.cp.solvers import CPSAT sage: speck = SpeckBlockCipher(number_of_rounds=5, block_bit_size=32, key_bit_size=64) sage: minizinc = MznXorDifferentialModelARXOptimized(speck) sage: bit_positions = [i for i in range(speck.output_bit_size)] @@ -785,7 +882,7 @@ def solve_for_ARX(self, solver_name=None, timeout_in_seconds_=30, ....: 'operator': '=', ....: 'value': '0' }) sage: minizinc.build_xor_differential_trail_model(-1, fixed_variables) - sage: result = minizinc.solve_for_ARX('Xor') + sage: result = minizinc.solve_for_ARX(CPSAT) sage: result.statistics['nSolutions'] 1 """ @@ -797,33 +894,52 @@ def solve_for_ARX(self, solver_name=None, timeout_in_seconds_=30, bit_mzn_model.add_string(mzn_model_string) instance = Instance(solver_name_mzn, bit_mzn_model) if processes_ != None and timeout_in_seconds_ != None: - result = instance.solve(processes=processes_, timeout=timedelta(seconds=int(timeout_in_seconds_)), - nr_solutions=nr_solutions_, random_seed=random_seed_, all_solutions=all_solutions_, - intermediate_solutions=intermediate_solutions_, free_search=free_search_, - optimisation_level=optimisation_level_) + result = instance.solve( + processes=processes_, + timeout=timedelta(seconds=int(timeout_in_seconds_)), + nr_solutions=nr_solutions_, + random_seed=random_seed_, + all_solutions=all_solutions_, + intermediate_solutions=intermediate_solutions_, + free_search=free_search_, + optimisation_level=optimisation_level_, + ) else: - result = instance.solve(nr_solutions=nr_solutions_, random_seed=random_seed_, all_solutions=all_solutions_, - intermediate_solutions=intermediate_solutions_, free_search=free_search_, - optimisation_level=optimisation_level_) + result = instance.solve( + nr_solutions=nr_solutions_, + random_seed=random_seed_, + all_solutions=all_solutions_, + intermediate_solutions=intermediate_solutions_, + free_search=free_search_, + optimisation_level=optimisation_level_, + ) return result - def solver_names(self, verbose = False): + def solver_names(self, verbose: bool = False) -> None: + """ + Print the available MiniZinc solvers. + + INPUT: + + - ``verbose`` -- **bool**; beside the solver name, it will be printed the brand name. + + """ if not verbose: - print('Internal CP solvers:') - print('solver brand name | solver name') + print("Internal CP solvers:") + print("solver brand name | solver name") for i in range(len(CP_SOLVERS_INTERNAL)): - print(f'{CP_SOLVERS_INTERNAL[i]["solver_brand_name"]} | {CP_SOLVERS_INTERNAL[i]["solver_name"]}') - print('\n') - print('External CP solvers:') - print('solver brand name | solver name') + print(f"{CP_SOLVERS_INTERNAL[i]['solver_brand_name']} | {CP_SOLVERS_INTERNAL[i]['solver_name']}") + print("\n") + print("External CP solvers:") + print("solver brand name | solver name") for i in range(len(CP_SOLVERS_EXTERNAL)): - print(f'{CP_SOLVERS_EXTERNAL[i]["solver_brand_name"]} | {CP_SOLVERS_EXTERNAL[i]["solver_name"]}') + print(f"{CP_SOLVERS_EXTERNAL[i]['solver_brand_name']} | {CP_SOLVERS_EXTERNAL[i]['solver_name']}") else: - print('Internal CP solvers:') + print("Internal CP solvers:") print(CP_SOLVERS_INTERNAL) - print('\n') - print('External CP solvers:') + print("\n") + print("External CP solvers:") print(CP_SOLVERS_EXTERNAL) def weight_constraints(self, weight): @@ -847,7 +963,7 @@ def weight_constraints(self, weight): if weight == 0 or weight == -1: cp_declarations = [] else: - cp_declarations = [f'constraint weight = {100 * weight};'] + cp_declarations = [f"constraint weight = {100 * weight};"] return cp_declarations, cp_constraints @@ -860,17 +976,20 @@ def write_minizinc_model_to_file(self, file_path, prefix=""): - ``file_path`` -- **string**; the path of the file that will contain the model - ``prefix`` -- **str** (default: ``) """ - model_string = "\n".join(self.mzn_comments) + "\n".join(self._variables_list) + \ - "\n".join(self._model_constraints) + "\n".join(self.mzn_output_directives) + \ - "\n".join(self.mzn_carries_output_directives) + model_string = ( + "\n".join(self.mzn_comments) + + "\n".join(self._variables_list) + + "\n".join(self._model_constraints) + + "\n".join(self.mzn_output_directives) + + "\n".join(self.mzn_carries_output_directives) + ) if prefix == "": - filename = f'{file_path}/{self.cipher_id}_mzn_{self.sat_or_milp}.mzn' + filename = f"{file_path}/{self.cipher_id}_mzn_{self.sat_or_milp}.mzn" else: - filename = f'{file_path}/{prefix}_{self.cipher_id}_mzn_{self.sat_or_milp}.mzn' + filename = f"{file_path}/{prefix}_{self.cipher_id}_mzn_{self.sat_or_milp}.mzn" - f = open(filename, "w") - f.write(model_string) - f.close() + with open(filename, "w") as file: + file.write(model_string) @property def cipher(self): @@ -905,5 +1024,5 @@ def model_constraints(self): ValueError: No model generated """ if not self._model_constraints: - raise ValueError('No model generated') + raise ValueError("No model generated") return self._model_constraints diff --git a/claasp/cipher_modules/models/cp/mzn_models/mzn_boomerang_model_arx_optimized.py b/claasp/cipher_modules/models/cp/mzn_models/mzn_boomerang_model_arx_optimized.py index e45679d7d..74d2ab7ab 100644 --- a/claasp/cipher_modules/models/cp/mzn_models/mzn_boomerang_model_arx_optimized.py +++ b/claasp/cipher_modules/models/cp/mzn_models/mzn_boomerang_model_arx_optimized.py @@ -19,13 +19,14 @@ from claasp.cipher_modules.graph_generator import split_cipher_graph_into_top_bottom from claasp.cipher_modules.models.cp.minizinc_utils.mzn_bct_predicates import get_bct_operations -from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_differential_model_arx_optimized import \ - MznXorDifferentialModelARXOptimized +from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_differential_model_arx_optimized import ( + MznXorDifferentialModelARXOptimized, +) from claasp.cipher_modules.models.cp.minizinc_utils.utils import group_strings_by_pattern class MznBoomerangModelARXOptimized(MznXorDifferentialModelARXOptimized): - def __init__(self, cipher, top_end_ids, bottom_start_ids, middle_ids, window_size_list=None, sat_or_milp='sat'): + def __init__(self, cipher, top_end_ids, bottom_start_ids, middle_ids, window_size_list=None, sat_or_milp="sat"): self.top_end_ids = top_end_ids self.bottom_start_ids = bottom_start_ids self.sboxes_ids = middle_ids @@ -38,8 +39,9 @@ def __init__(self, cipher, top_end_ids, bottom_start_ids, middle_ids, window_siz self.probability_vars = None self.filename = None super().__init__(cipher, window_size_list, None, sat_or_milp) - self.top_graph, self.bottom_graph = split_cipher_graph_into_top_bottom(cipher, self.top_end_ids, - self.bottom_start_ids) + self.top_graph, self.bottom_graph = split_cipher_graph_into_top_bottom( + cipher, self.top_end_ids, self.bottom_start_ids + ) self.create_top_and_bottom_ciphers_from_subgraphs() @staticmethod @@ -51,7 +53,9 @@ def remove_empty_rounds(cipher): @staticmethod def reduce_cipher(new_cipher, original_cipher, graph): for round_number in range(new_cipher.number_of_rounds): - MznBoomerangModelARXOptimized.remove_components_not_in_graph(new_cipher, original_cipher, round_number, graph) + MznBoomerangModelARXOptimized.remove_components_not_in_graph( + new_cipher, original_cipher, round_number, graph + ) @staticmethod def remove_components_not_in_graph(new_cipher, original_cipher, round_number, graph): @@ -69,7 +73,7 @@ def remove_component(new_cipher, component): @staticmethod def initialize_bottom_cipher(original_cipher): bottom_cipher = deepcopy(original_cipher) - bottom_cipher._id = f'{original_cipher.id}_bottom' + bottom_cipher._id = f"{original_cipher.id}_bottom" return bottom_cipher def setup_bottom_cipher_inputs(self, bottom_cipher, original_cipher): @@ -81,12 +85,9 @@ def setup_bottom_cipher_inputs(self, bottom_cipher, original_cipher): for middle_id in self.sboxes_ids: bottom_cipher._inputs.append(middle_id) - bottom_cipher._inputs_bit_size.append( - original_cipher.get_component_from_id(middle_id).output_bit_size - ) + bottom_cipher._inputs_bit_size.append(original_cipher.get_component_from_id(middle_id).output_bit_size) def update_bottom_cipher_inputs(self, bottom_cipher, original_cipher, initial_nodes, new_input_bit_positions): - for node_id in initial_nodes: if node_id in original_cipher.inputs: bottom_cipher._inputs.append(node_id) @@ -126,7 +127,7 @@ def create_bottom_cipher(self, original_cipher): def create_top_cipher(self, original_cipher): top_cipher = deepcopy(original_cipher) - top_cipher._id = f'{original_cipher.id}_top' + top_cipher._id = f"{original_cipher.id}_top" MznBoomerangModelARXOptimized.reduce_cipher(top_cipher, original_cipher, self.top_graph) MznBoomerangModelARXOptimized.remove_empty_rounds(top_cipher) return top_cipher @@ -139,32 +140,35 @@ def reset_round_ids(cipher): @staticmethod def objective_generator(mzn_top_cipher, mzn_bottom_cipher): objective_string = [] - modular_addition_concatenation = "++".join(mzn_top_cipher.probability_vars) + "++" + "++".join( - mzn_bottom_cipher.probability_vars) - objective_string.append(f'solve:: int_search({modular_addition_concatenation},' - f' smallest, indomain_min, complete)') - objective_string.append(f'minimize sum({modular_addition_concatenation});') - mzn_top_cipher.mzn_output_directives.append(f'output ["Total_Probability: "++show(sum(' - f'{modular_addition_concatenation}))];') + modular_addition_concatenation = ( + "++".join(mzn_top_cipher.probability_vars) + "++" + "++".join(mzn_bottom_cipher.probability_vars) + ) + objective_string.append( + f"solve:: int_search({modular_addition_concatenation}, smallest, indomain_min, complete)" + ) + objective_string.append(f"minimize sum({modular_addition_concatenation});") + mzn_top_cipher.mzn_output_directives.append( + f'output ["Total_Probability: "++show(sum({modular_addition_concatenation}))];' + ) return objective_string def create_boomerang_model(self, fixed_variables_for_top_cipher, fixed_variables_for_bottom_cipher): self.differential_model_top_cipher = MznXorDifferentialModelARXOptimized( - self.top_cipher, window_size_list=[0 for _ in range(self.top_cipher.number_of_rounds)], - sat_or_milp='sat', include_word_operations_mzn_file=False - ) - self.differential_model_top_cipher.build_xor_differential_trail_model( - -1, fixed_variables_for_top_cipher + self.top_cipher, + window_size_list=[0 for _ in range(self.top_cipher.number_of_rounds)], + sat_or_milp="sat", + include_word_operations_mzn_file=False, ) + self.differential_model_top_cipher.build_xor_differential_trail_model(-1, fixed_variables_for_top_cipher) self.differential_model_bottom_cipher = MznXorDifferentialModelARXOptimized( - self.bottom_cipher, window_size_list=[0 for _ in range(self.bottom_cipher.number_of_rounds)], - sat_or_milp='sat', include_word_operations_mzn_file=False - ) - self.differential_model_bottom_cipher.build_xor_differential_trail_model( - -1, fixed_variables_for_bottom_cipher + self.bottom_cipher, + window_size_list=[0 for _ in range(self.bottom_cipher.number_of_rounds)], + sat_or_milp="sat", + include_word_operations_mzn_file=False, ) + self.differential_model_bottom_cipher.build_xor_differential_trail_model(-1, fixed_variables_for_bottom_cipher) for sbox_component_id in self.sboxes_ids: sbox_component = self.original_cipher.get_component_from_id(sbox_component_id) @@ -172,22 +176,29 @@ def create_boomerang_model(self, fixed_variables_for_top_cipher, fixed_variables self.differential_model_bottom_cipher.add_constraint_from_str(bct_mzn_model) self.differential_model_bottom_cipher.extend_model_constraints( - MznBoomerangModelARXOptimized.objective_generator(self.differential_model_top_cipher, - self.differential_model_bottom_cipher) + MznBoomerangModelARXOptimized.objective_generator( + self.differential_model_top_cipher, self.differential_model_bottom_cipher + ) ) self.differential_model_bottom_cipher.extend_model_constraints( - self.differential_model_bottom_cipher.weight_constraints(max_weight=None, weight=None, operator=">=")) + self.differential_model_bottom_cipher.weight_constraints(max_weight=None, weight=None, operator=">=") + ) self.differential_model_top_cipher.extend_model_constraints( - self.differential_model_top_cipher.weight_constraints(max_weight=None, weight=None, operator=">=")) + self.differential_model_top_cipher.weight_constraints(max_weight=None, weight=None, operator=">=") + ) from claasp.cipher_modules.models.sat.utils.mzn_predicates import get_word_operations + self._model_constraints.extend([get_word_operations()]) self._model_constraints.extend([get_bct_operations()]) - self._variables_list.extend(self.differential_model_top_cipher.get_variables() + - self.differential_model_bottom_cipher.get_variables()) - self._model_constraints.extend(self.differential_model_top_cipher.get_model_constraints() + - self.differential_model_bottom_cipher.get_model_constraints()) + self._variables_list.extend( + self.differential_model_top_cipher.get_variables() + self.differential_model_bottom_cipher.get_variables() + ) + self._model_constraints.extend( + self.differential_model_top_cipher.get_model_constraints() + + self.differential_model_bottom_cipher.get_model_constraints() + ) top_cipher_probability_vars = self.differential_model_top_cipher.probability_vars bottom_cipher_probability_vars = self.differential_model_bottom_cipher.probability_vars @@ -195,22 +206,29 @@ def create_boomerang_model(self, fixed_variables_for_top_cipher, fixed_variables def write_minizinc_model_to_file(self, file_path, prefix=""): model_string_top = "\n".join(self.differential_model_top_cipher.mzn_comments) + "\n".join( - self.differential_model_top_cipher.mzn_output_directives) + self.differential_model_top_cipher.mzn_output_directives + ) model_string_bottom = "\n".join(self.differential_model_bottom_cipher.mzn_comments) + "\n".join( - self.differential_model_bottom_cipher.mzn_output_directives) + self.differential_model_bottom_cipher.mzn_output_directives + ) if prefix == "": - filename = f'{file_path}/{self.original_cipher.id}_mzn_{self.differential_model_top_cipher.sat_or_milp}.mzn' + filename = f"{file_path}/{self.original_cipher.id}_mzn_{self.differential_model_top_cipher.sat_or_milp}.mzn" self.filename = filename else: - filename = f'{file_path}/{prefix}_{self.original_cipher.id}_mzn_' - filename += f'{self.differential_model_top_cipher.sat_or_milp}.mzn' + filename = f"{file_path}/{prefix}_{self.original_cipher.id}_mzn_" + filename += f"{self.differential_model_top_cipher.sat_or_milp}.mzn" self.filename = filename f = open(filename, "w") f.write( - model_string_top + "\n" + model_string_bottom + "\n" + "\n".join(self._variables_list) + "\n" + "\n".join( - self._model_constraints) + model_string_top + + "\n" + + model_string_bottom + + "\n" + + "\n".join(self._variables_list) + + "\n" + + "\n".join(self._model_constraints) ) f.close() @@ -220,13 +238,13 @@ def parse_components_with_solution(self, result, solution): def get_hex_from_sublists(sublists, bool_dict): hex_values = {} for sublist in sublists: - bit_str = ''.join(['1' if bool_dict[val] else '0' for val in sublist]) + bit_str = "".join(["1" if bool_dict[val] else "0" for val in sublist]) component_id = sublist[0][:-3] weight = 0 - if component_id.startswith('modadd') and component_id not in self.sboxes_ids: - p_modadd_var = [s for s in bool_dict.keys() if s.startswith(f'p_{component_id}')] + if component_id.startswith("modadd") and component_id not in self.sboxes_ids: + p_modadd_var = [s for s in bool_dict.keys() if s.startswith(f"p_{component_id}")] weight = sum(bool_dict[p_modadd_var[0]]) - hex_values[component_id] = {'value': hex(int(bit_str, 2)), 'weight': weight, 'sign': 1} + hex_values[component_id] = {"value": hex(int(bit_str, 2)), "weight": weight, "sign": 1} return hex_values @@ -234,27 +252,27 @@ def get_hex_from_sublists(sublists, bool_dict): list_of_sublist_of_vars = group_strings_by_pattern(self._variables_list) dict_of_component_value = get_hex_from_sublists(list_of_sublist_of_vars, solution.__dict__) - return {'component_values': dict_of_component_value} + return {"component_values": dict_of_component_value} def bct_parse_result(self, result, solver_name, total_weight, model_type): - parsed_result = {'id': self.cipher_id, 'model_type': model_type, 'solver_name': solver_name} + parsed_result = {"id": self.cipher_id, "model_type": model_type, "solver_name": solver_name} if total_weight == "list_of_solutions": solutions = [] for solution in result.solution: - parsed_solution = {'total_weight': None, 'component_values': {}} + parsed_solution = {"total_weight": None, "component_values": {}} parsed_solution_non_linear = self.parse_components_with_solution(result, solution) solution_total_weight = 0 for _, item_value_and_weight in parsed_solution.items(): - solution_total_weight += item_value_and_weight['weight'] - parsed_solution['total_weight'] = solution_total_weight + solution_total_weight += item_value_and_weight["weight"] + parsed_solution["total_weight"] = solution_total_weight parsed_solution = {**parsed_solution_non_linear, **parsed_result} solutions.append(parsed_solution) return solutions else: - parsed_result['total_weight'] = total_weight - parsed_result['statistics'] = result.statistics + parsed_result["total_weight"] = total_weight + parsed_result["statistics"] = result.statistics parsed_result = {**self.parse_components_with_solution(result, result.solution), **parsed_result} - parsed_result['statistics']['flatTime'] = parsed_result['statistics']['flatTime'].total_seconds() - parsed_result['statistics']['time'] = parsed_result['statistics']['time'].total_seconds() + parsed_result["statistics"]["flatTime"] = parsed_result["statistics"]["flatTime"].total_seconds() + parsed_result["statistics"]["time"] = parsed_result["statistics"]["time"].total_seconds() return parsed_result diff --git a/claasp/cipher_modules/models/cp/mzn_models/mzn_cipher_model.py b/claasp/cipher_modules/models/cp/mzn_models/mzn_cipher_model.py index e213a2d96..ce8486c14 100644 --- a/claasp/cipher_modules/models/cp/mzn_models/mzn_cipher_model.py +++ b/claasp/cipher_modules/models/cp/mzn_models/mzn_cipher_model.py @@ -1,29 +1,36 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** -from claasp.cipher_modules.models.cp.mzn_model import MznModel, solve_satisfy -from claasp.name_mappings import (CIPHER_OUTPUT, INTERMEDIATE_OUTPUT, MIX_COLUMN, LINEAR_LAYER, WORD_OPERATION, - CONSTANT, SBOX) +from claasp.cipher_modules.models.cp.mzn_model import MznModel, SOLVE_SATISFY +from claasp.cipher_modules.models.cp.solvers import SOLVER_DEFAULT +from claasp.name_mappings import ( + CIPHER_OUTPUT, + CIPHER, + CONSTANT, + INTERMEDIATE_OUTPUT, + LINEAR_LAYER, + MIX_COLUMN, + SBOX, + WORD_OPERATION, +) class MznCipherModel(MznModel): - def __init__(self, cipher): super().__init__(cipher) @@ -36,18 +43,6 @@ def build_cipher_model(self, fixed_variables=[], second=False): - ``fixed_variables`` -- **list** (default: `[]`); dictionaries containing name, bit_size, value (as integer) for the variables that need to be fixed to a certain value: - { - - 'component_id': 'plaintext', - - 'constraint_type': 'equal'/'not_equal' - - 'bit_positions': [0, 1, 2, 3], - - 'binary_value': '[0, 0, 0, 0]' - - } - EXAMPLES:: sage: from claasp.cipher_modules.models.cp.mzn_models.mzn_cipher_model import MznCipherModel @@ -65,15 +60,16 @@ def build_cipher_model(self, fixed_variables=[], second=False): variables = [] self._variables_list = [] constraints = self.fix_variables_value_constraints(fixed_variables) - component_types = [CIPHER_OUTPUT, CONSTANT, INTERMEDIATE_OUTPUT, LINEAR_LAYER, MIX_COLUMN, SBOX, WORD_OPERATION] - operation_types = ['AND', 'MODADD', 'MODSUB', 'NOT', 'OR', 'ROTATE', 'SHIFT', 'SHIFT_BY_VARIABLE_AMOUNT', 'XOR'] + component_types = (CIPHER_OUTPUT, CONSTANT, INTERMEDIATE_OUTPUT, LINEAR_LAYER, MIX_COLUMN, SBOX, WORD_OPERATION) + operation_types = ("AND", "MODADD", "MODSUB", "NOT", "OR", "ROTATE", "SHIFT", "SHIFT_BY_VARIABLE_AMOUNT", "XOR") self._model_constraints = constraints for component in self._cipher.get_all_components(): operation = component.description[0] if component.type not in component_types or ( - WORD_OPERATION == component.type and operation not in operation_types): - print(f'{component.id} not yet implemented') + WORD_OPERATION == component.type and operation not in operation_types + ): + print(f"{component.id} not yet implemented") else: if component.type != SBOX: variables, constraints = component.cp_constraints() @@ -82,17 +78,17 @@ def build_cipher_model(self, fixed_variables=[], second=False): self._model_constraints.extend(constraints) self._variables_list.extend(variables) - + self._model_constraints.extend(self.final_constraints()) - + if not second: self._model_constraints = self._model_prefix + self._variables_list + self._model_constraints - def evaluate_model(self, fixed_values=[], solver_name='Chuffed'): - self.build_cipher_model(fixed_variables = fixed_values) - - self.solve('evaluate_cipher', solver_name) - + def find_missing_bits(self, fixed_values=[], solver_name=SOLVER_DEFAULT, solver_external=True): + self.build_cipher_model(fixed_variables=fixed_values) + solution = self.solve(CIPHER, solver_name=solver_name, solve_external=solver_external) + + return solution def final_constraints(self): """ @@ -112,18 +108,17 @@ def final_constraints(self): ['solve satisfy;'] """ cipher_inputs = self._cipher.inputs - cp_constraints = [solve_satisfy] - new_constraint = 'output[' + cp_constraints = [SOLVE_SATISFY] + new_constraint = "output[" for element in cipher_inputs: - new_constraint = f'{new_constraint}\"{element} = \"++ show({element}) ++ \"\\n\" ++' + new_constraint = f'{new_constraint}"{element} = "++ show({element}) ++ "\\n" ++' for component_id in self._cipher.get_all_components_ids(): - new_constraint = new_constraint + f'\"{component_id} = \"++ ' \ - f'show({component_id})++ \"\\n\" ++ \"0\" ++ \"\\n\" ++' - new_constraint = new_constraint[:-2] + '];' + new_constraint = new_constraint + f'"{component_id} = "++ show({component_id})++ "\\n" ++ "0" ++ "\\n" ++' + new_constraint = new_constraint[:-2] + "];" cp_constraints.append(new_constraint) return cp_constraints - + def input_constraints(self): """ Return a list of CP constraints for the inputs of the cipher. @@ -144,12 +139,12 @@ def input_constraints(self): 'array[0..31] of var 0..1: cipher_output_3_12;'] """ self.sbox_mant = [] - cp_declarations = [f'array[0..{bit_size - 1}] of var 0..1: {input_};' - for input_, bit_size in zip(self._cipher.inputs, self._cipher.inputs_bit_size)] + cp_declarations = [ + f"array[0..{bit_size - 1}] of var 0..1: {input_};" + for input_, bit_size in zip(self._cipher.inputs, self._cipher.inputs_bit_size) + ] for component in self._cipher.get_all_components(): if CONSTANT not in component.type: - output_id_link = component.id - output_size = int(component.output_bit_size) - cp_declarations.append(f'array[0..{output_size - 1}] of var 0..1: {output_id_link};') + cp_declarations.append(f"array[0..{component.output_bit_size - 1}] of var 0..1: {component.id};") return cp_declarations diff --git a/claasp/cipher_modules/models/cp/mzn_models/mzn_cipher_model_arx_optimized.py b/claasp/cipher_modules/models/cp/mzn_models/mzn_cipher_model_arx_optimized.py index 5c13bca2a..c37e0bb1a 100644 --- a/claasp/cipher_modules/models/cp/mzn_models/mzn_cipher_model_arx_optimized.py +++ b/claasp/cipher_modules/models/cp/mzn_models/mzn_cipher_model_arx_optimized.py @@ -1,17 +1,16 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -22,8 +21,7 @@ class MznCipherModelARXOptimized(MznModel): - - def __init__(self, cipher, window_size_list=None, probability_weight_per_round=None, sat_or_milp='sat'): + def __init__(self, cipher, window_size_list=None, probability_weight_per_round=None, sat_or_milp="sat"): super().__init__(cipher, window_size_list, probability_weight_per_round, sat_or_milp) def build_cipher_model(self, fixed_variables=[]): @@ -52,13 +50,14 @@ def build_cipher_model(self, fixed_variables=[]): constraints = self.fix_variables_value_constraints_for_ARX(fixed_variables) self._model_constraints = constraints component_types = [CIPHER_OUTPUT, INTERMEDIATE_OUTPUT, WORD_OPERATION] - operation_types = ['ROTATE', 'SHIFT', 'XOR'] + operation_types = ["ROTATE", "SHIFT", "XOR"] for component in self._cipher.get_all_components(): operation = component.description[0] if component.type not in component_types or ( - WORD_OPERATION == component.type and operation not in operation_types): - print(f'{component.id} not yet implemented') + WORD_OPERATION == component.type and operation not in operation_types + ): + print(f"{component.id} not yet implemented") else: variables, constraints = component.minizinc_constraints(self) diff --git a/claasp/cipher_modules/models/cp/mzn_models/mzn_deterministic_truncated_xor_differential_model.py b/claasp/cipher_modules/models/cp/mzn_models/mzn_deterministic_truncated_xor_differential_model.py index a0533f677..e0d163f31 100644 --- a/claasp/cipher_modules/models/cp/mzn_models/mzn_deterministic_truncated_xor_differential_model.py +++ b/claasp/cipher_modules/models/cp/mzn_models/mzn_deterministic_truncated_xor_differential_model.py @@ -1,80 +1,87 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** +from minizinc import Status -import os -import math -import itertools -import subprocess - -from minizinc import Instance, Model, Solver, Status - -from claasp.cipher_modules.models.cp.mzn_model import MznModel, solve_satisfy -from claasp.cipher_modules.models.utils import write_model_to_file, convert_solver_solution_to_dictionary, check_if_implemented_component -from claasp.name_mappings import (CONSTANT, INTERMEDIATE_OUTPUT, CIPHER_OUTPUT, LINEAR_LAYER, SBOX, MIX_COLUMN, - WORD_OPERATION, DETERMINISTIC_TRUNCATED_XOR_DIFFERENTIAL) -from claasp.cipher_modules.models.cp.solvers import MODEL_DEFAULT_PATH, SOLVER_DEFAULT +from claasp.cipher_modules.models.cp.mzn_model import MznModel, SOLVE_SATISFY +from claasp.cipher_modules.models.cp.solvers import SOLVER_DEFAULT +from claasp.cipher_modules.models.utils import ( + check_if_implemented_component, + convert_solver_solution_to_dictionary, +) +from claasp.name_mappings import ( + CIPHER_OUTPUT, + CONSTANT, + SBOX, +) class MznDeterministicTruncatedXorDifferentialModel(MznModel): - def __init__(self, cipher): super().__init__(cipher) - def add_solutions_from_components_values(self, components_values, memory, model_type, solutions, solve_time, - solver_name, solver_output, total_weight, solve_external = False): + def add_solutions_from_components_values( + self, + components_values, + memory, + model_type, + solutions, + solve_time, + solver_name, + solver_output, + total_weight, + solve_external=False, + ): for nsol in components_values.keys(): solution = convert_solver_solution_to_dictionary( - self.cipher_id, - model_type, - solver_name, - solve_time, - memory, - components_values[nsol], - 0) + self.cipher_id, model_type, solver_name, solve_time, memory, components_values[nsol], 0 + ) if solve_external: - if 'UNSATISFIABLE' in solver_output[0]: - solution['status'] = 'UNSATISFIABLE' + if "UNSATISFIABLE" in solver_output[0]: + solution["status"] = "UNSATISFIABLE" else: - solution['status'] = 'SATISFIABLE' + solution["status"] = "SATISFIABLE" else: if solver_output.status not in [Status.SATISFIED, Status.ALL_SOLUTIONS, Status.OPTIMAL_SOLUTION]: - solution['status'] = 'UNSATISFIABLE' + solution["status"] = "UNSATISFIABLE" else: - solution['status'] = 'SATISFIABLE' + solution["status"] = "SATISFIABLE" solutions.append(solution) - def add_solution_to_components_values(self, component_id, component_solution, components_values, j, output_to_parse, - solution_number, string): + def add_solution_to_components_values( + self, component_id, component_solution, components_values, j, output_to_parse, solution_number, string + ): if component_id in self._cipher.inputs: - components_values[f'solution{solution_number}'][f'{component_id}'] = component_solution - elif f'{component_id}_i' in string: - components_values[f'solution{solution_number}'][f'{component_id}_i'] = component_solution - elif f'{component_id}_o' in string: - components_values[f'solution{solution_number}'][f'{component_id}_o'] = component_solution - elif f'{component_id} ' in string: - components_values[f'solution{solution_number}'][f'{component_id}'] = component_solution - - def add_solution_to_components_values_internal(self, component_solution, components_values, component_weight, - solution_number, component): - components_values[f'solution{solution_number}'][f'{component}'] = component_solution - - def build_deterministic_truncated_xor_differential_trail_model(self, fixed_variables=[], number_of_rounds=None, minimize=False, wordwise=False): + components_values[f"solution{solution_number}"][f"{component_id}"] = component_solution + elif f"{component_id}_i" in string: + components_values[f"solution{solution_number}"][f"{component_id}_i"] = component_solution + elif f"{component_id}_o" in string: + components_values[f"solution{solution_number}"][f"{component_id}_o"] = component_solution + elif f"{component_id} " in string: + components_values[f"solution{solution_number}"][f"{component_id}"] = component_solution + + def add_solution_to_components_values_internal( + self, component_solution, components_values, component_weight, solution_number, component + ): + components_values[f"solution{solution_number}"][f"{component}"] = component_solution + + def build_deterministic_truncated_xor_differential_trail_model( + self, fixed_variables=[], number_of_rounds=None, minimize=False, wordwise=False + ): """ Build the CP model for the search of deterministic truncated XOR differential trails. @@ -107,7 +114,7 @@ def build_deterministic_truncated_xor_differential_trail_model(self, fixed_varia variables, constraints = self.propagate_deterministically(component, wordwise) self._variables_list.extend(variables) deterministic_truncated_xor_differential.extend(constraints) - + if not wordwise: variables, constraints = self.input_deterministic_truncated_xor_differential_constraints() else: @@ -115,10 +122,14 @@ def build_deterministic_truncated_xor_differential_trail_model(self, fixed_varia self._model_prefix.extend(variables) self._variables_list.extend(constraints) if not wordwise: - deterministic_truncated_xor_differential.extend(self.final_deterministic_truncated_xor_differential_constraints(minimize)) + deterministic_truncated_xor_differential.extend( + self.final_deterministic_truncated_xor_differential_constraints(minimize) + ) else: - deterministic_truncated_xor_differential.extend(self.final_wordwise_deterministic_truncated_xor_differential_constraints(minimize)) - + deterministic_truncated_xor_differential.extend( + self.final_wordwise_deterministic_truncated_xor_differential_constraints(minimize) + ) + self._model_constraints = self._model_prefix + self._variables_list + deterministic_truncated_xor_differential def final_deterministic_truncated_xor_differential_constraints(self, minimize=False): @@ -141,23 +152,30 @@ def final_deterministic_truncated_xor_differential_constraints(self, minimize=Fa cipher_inputs = self._cipher.inputs cipher = self._cipher cp_constraints = [] - new_constraint = 'output[' + new_constraint = "output[" for element in cipher_inputs: - new_constraint = f'{new_constraint}\"{element} = \"++ show({element}) ++ \"\\n\" ++' + new_constraint = f'{new_constraint}"{element} = "++ show({element}) ++ "\\n" ++' for component_id in cipher.get_all_components_ids(): - new_constraint = new_constraint + \ - f'\"{component_id} = \"++ show({component_id})++ \"\\n\" ++ \"0\" ++ \"\\n\" ++' - if 'cipher_output' in component_id and minimize: - cp_constraints.append(f'solve maximize count({self._cipher.get_all_components_ids()[-1]}, 0);') - new_constraint = new_constraint[:-2] + '];' + new_constraint = new_constraint + f'"{component_id} = "++ show({component_id})++ "\\n" ++ "0" ++ "\\n" ++' + if "cipher_output" in component_id and minimize: + cp_constraints.append(f"solve maximize count({self._cipher.get_all_components_ids()[-1]}, 0);") + new_constraint = new_constraint[:-2] + "];" if cp_constraints == []: - cp_constraints.append(solve_satisfy) + cp_constraints.append(SOLVE_SATISFY) cp_constraints.append(new_constraint) return cp_constraints - def find_lowest_varied_patterns_bitwise_deterministic_truncated_xor_differential_trail(self, number_of_rounds=None, - fixed_values=[], solver_name=SOLVER_DEFAULT, num_of_processors=None, timelimit=None, solve_with_API=False, solve_external = False): + def find_lowest_varied_patterns_bitwise_deterministic_truncated_xor_differential_trail( + self, + number_of_rounds=None, + fixed_values=[], + solver_name=SOLVER_DEFAULT, + num_of_processors=None, + timelimit=None, + solve_with_API=False, + solve_external=False, + ): """ Return the solution representing a differential trail with any weight. @@ -165,11 +183,8 @@ def find_lowest_varied_patterns_bitwise_deterministic_truncated_xor_differential - ``number_of_rounds`` -- **integer** (default: `None`); number of rounds - ``fixed_values`` -- **list** (default: `[]`); can be created using ``set_fixed_variables`` method - - ``solver_name`` -- **string** (default: `Chuffed`); the name of the solver. Available values are: - - * ``'Chuffed'`` - * ``'Gecode'`` - * ``'COIN-BC'`` + - ``solver_name`` -- **string** (default: `chuffed`); the name of the solver. + See also :meth:`MznModel.solver_names`. EXAMPLES:: @@ -203,21 +218,37 @@ def find_lowest_varied_patterns_bitwise_deterministic_truncated_xor_differential 'xor_0_4': {'value': '2222222222222220', 'weight': 0}}, 'memory_megabytes': 0.01, 'model_type': 'deterministic_truncated_xor_differential_one_solution', - 'solver_name': 'Chuffed', + 'solver_name': 'chuffed', 'solving_time_seconds': 0.0, 'total_weight': '0.0'}] """ if number_of_rounds is None: number_of_rounds = self._cipher.number_of_rounds - self.build_deterministic_truncated_xor_differential_trail_model(fixed_values, number_of_rounds, minimize = True) + self.build_deterministic_truncated_xor_differential_trail_model(fixed_values, number_of_rounds, minimize=True) if solve_with_API: - return self.solve_for_ARX(solver_name = solver_name, timeout_in_seconds_ = timelimit, processes_ = num_of_processors) - return self.solve('deterministic_truncated_xor_differential_one_solution', solver_name = solver_name, timeout_in_seconds_ = timelimit, processes_ = num_of_processors, solve_external = solve_external) - - def find_all_deterministic_truncated_xor_differential_trails(self, number_of_rounds=None, - fixed_values=[], solver_name=SOLVER_DEFAULT, num_of_processors=None, timelimit=None, solve_with_API=False, solve_external = False): + return self.solve_for_ARX( + solver_name=solver_name, timeout_in_seconds_=timelimit, processes_=num_of_processors + ) + return self.solve( + "deterministic_truncated_xor_differential_one_solution", + solver_name=solver_name, + timeout_in_seconds_=timelimit, + processes_=num_of_processors, + solve_external=solve_external, + ) + + def find_all_deterministic_truncated_xor_differential_trails( + self, + number_of_rounds=None, + fixed_values=[], + solver_name=SOLVER_DEFAULT, + num_of_processors=None, + timelimit=None, + solve_with_API=False, + solve_external=False, + ): """ Return the solution representing a differential trail with any weight. @@ -225,11 +256,8 @@ def find_all_deterministic_truncated_xor_differential_trails(self, number_of_rou - ``number_of_rounds`` -- **integer**; number of rounds - ``fixed_values`` -- **list** (default: `[]`); can be created using ``set_fixed_variables`` method - - ``solver_name`` -- **string** (default: `None`); the name of the solver. Available values are: - - * ``'Chuffed'`` - * ``'Gecode'`` - * ``'COIN-BC'`` + - ``solver_name`` -- **string** (default: `None`); the name of the solver. + See also :meth:`MznModel.solver_names`. EXAMPLES:: @@ -255,7 +283,7 @@ def find_all_deterministic_truncated_xor_differential_trails(self, number_of_rou ... 'memory_megabytes': 0.02, 'model_type': 'deterministic_truncated_xor_differential', - 'solver_name': 'Chuffed', + 'solver_name': 'chuffed', 'solving_time_seconds': 0.002, 'total_weight': '0.0'}] """ @@ -265,11 +293,31 @@ def find_all_deterministic_truncated_xor_differential_trails(self, number_of_rou self.build_deterministic_truncated_xor_differential_trail_model(fixed_values, number_of_rounds) if solve_with_API: - return self.solve_for_ARX(solver_name = solver_name, timeout_in_seconds_ = timelimit, processes_ = num_of_processors, all_solutions_ = True) - return self.solve('deterministic_truncated_xor_differential', solver_name = solver_name, timeout_in_seconds_ = timelimit, processes_ = num_of_processors, all_solutions_ = True, solve_external = solve_external) - - def find_one_deterministic_truncated_xor_differential_trail(self, number_of_rounds=None, - fixed_values=[], solver_name=SOLVER_DEFAULT, num_of_processors=None, timelimit=None, solve_with_API=False, solve_external = False): + return self.solve_for_ARX( + solver_name=solver_name, + timeout_in_seconds_=timelimit, + processes_=num_of_processors, + all_solutions_=True, + ) + return self.solve( + "deterministic_truncated_xor_differential", + solver_name=solver_name, + timeout_in_seconds_=timelimit, + processes_=num_of_processors, + all_solutions_=True, + solve_external=solve_external, + ) + + def find_one_deterministic_truncated_xor_differential_trail( + self, + number_of_rounds=None, + fixed_values=[], + solver_name=SOLVER_DEFAULT, + num_of_processors=None, + timelimit=None, + solve_with_API=False, + solve_external=False, + ): """ Return the solution representing a differential trail with any weight. @@ -277,11 +325,8 @@ def find_one_deterministic_truncated_xor_differential_trail(self, number_of_roun - ``number_of_rounds`` -- **integer** (default: `None`); number of rounds - ``fixed_values`` -- **list** (default: `[]`); can be created using ``set_fixed_variables`` method - - ``solver_name`` -- **string** (default: `Chuffed`); the name of the solver. Available values are: - - * ``'Chuffed'`` - * ``'Gecode'`` - * ``'COIN-BC'`` + - ``solver_name`` -- **string** (default: `chuffed`); the name of the solver. + See also :meth:`MznModel.solver_names`. EXAMPLES:: @@ -315,7 +360,7 @@ def find_one_deterministic_truncated_xor_differential_trail(self, number_of_roun 'xor_0_4': {'value': '2222222222222220', 'weight': 0}}, 'memory_megabytes': 0.01, 'model_type': 'deterministic_truncated_xor_differential_one_solution', - 'solver_name': 'Chuffed', + 'solver_name': 'chuffed', 'solving_time_seconds': 0.0, 'total_weight': '0.0'}] """ @@ -325,8 +370,16 @@ def find_one_deterministic_truncated_xor_differential_trail(self, number_of_roun self.build_deterministic_truncated_xor_differential_trail_model(fixed_values, number_of_rounds) if solve_with_API: - return self.solve_for_ARX(solver_name = solver_name, timeout_in_seconds_ = timelimit, processes_ = num_of_processors) - return self.solve('deterministic_truncated_xor_differential_one_solution', solver_name = solver_name, timeout_in_seconds_ = timelimit, processes_ = num_of_processors, solve_external = solve_external) + return self.solve_for_ARX( + solver_name=solver_name, timeout_in_seconds_=timelimit, processes_=num_of_processors + ) + return self.solve( + "deterministic_truncated_xor_differential_one_solution", + solver_name=solver_name, + timeout_in_seconds_=timelimit, + processes_=num_of_processors, + solve_external=solve_external, + ) def input_deterministic_truncated_xor_differential_constraints(self): """ @@ -352,19 +405,21 @@ def input_deterministic_truncated_xor_differential_constraints(self): number_of_rounds = self._cipher.number_of_rounds cp_constraints = [] - cp_declarations = [f'array[0..{bit_size - 1}] of var 0..2: {input_};' - for input_, bit_size in zip(self._cipher.inputs, self._cipher.inputs_bit_size)] + cp_declarations = [ + f"array[0..{bit_size - 1}] of var 0..2: {input_};" + for input_, bit_size in zip(self._cipher.inputs, self._cipher.inputs_bit_size) + ] cipher = self._cipher rounds = number_of_rounds for component in cipher.get_all_components(): output_id_link = component.id output_size = int(component.output_bit_size) if CIPHER_OUTPUT in component.type: - cp_declarations.append(f'array[0..{output_size - 1}] of var 0..2: {output_id_link};') - cp_constraints.append(f'constraint count({output_id_link},2) < {output_size};') + cp_declarations.append(f"array[0..{output_size - 1}] of var 0..2: {output_id_link};") + cp_constraints.append(f"constraint count({output_id_link},2) < {output_size};") elif CONSTANT not in component.type: - cp_declarations.append(f'array[0..{output_size - 1}] of var 0..2: {output_id_link};') - cp_constraints.append('constraint count(plaintext,1) > 0;') + cp_declarations.append(f"array[0..{output_size - 1}] of var 0..2: {output_id_link};") + cp_constraints.append("constraint count(plaintext,1) > 0;") return cp_declarations, cp_constraints @@ -396,8 +451,8 @@ def output_constraints(self, component): cp_declarations = [] all_inputs = [] for id_link, bit_positions in zip(input_id_links, input_bit_positions): - all_inputs.extend([f'{id_link}[{position}]' for position in bit_positions]) - cp_constraints = [f'constraint {output_id_link}[{i}] = {all_inputs[i]};' for i in range(output_size)] + all_inputs.extend([f"{id_link}[{position}]" for position in bit_positions]) + cp_constraints = [f"constraint {output_id_link}[{i}] = {all_inputs[i]};" for i in range(output_size)] return cp_declarations, cp_constraints @@ -429,20 +484,21 @@ def output_inverse_constraints(self, component): cp_declarations = [] all_inputs = [] for id_link, bit_positions in zip(input_id_links, input_bit_positions): - all_inputs.extend([f'{id_link}[{position}]' for position in bit_positions]) - cp_constraints = [f'constraint {output_id_link}_inverse[{i}] = {all_inputs[i]};' - for i in range(output_size)] + all_inputs.extend([f"{id_link}[{position}]" for position in bit_positions]) + cp_constraints = [f"constraint {output_id_link}_inverse[{i}] = {all_inputs[i]};" for i in range(output_size)] return cp_declarations, cp_constraints def propagate_deterministically(self, component, wordwise=False, inverse=False): if not wordwise: if component.type == SBOX: - variables, constraints, sbox_mant = component.cp_deterministic_truncated_xor_differential_trail_constraints(self.sbox_mant, inverse) + variables, constraints, sbox_mant = ( + component.cp_deterministic_truncated_xor_differential_trail_constraints(self.sbox_mant, inverse) + ) self.sbox_mant = sbox_mant else: variables, constraints = component.cp_deterministic_truncated_xor_differential_trail_constraints() else: variables, constraints = component.cp_wordwise_deterministic_truncated_xor_differential_constraints(self) - + return variables, constraints diff --git a/claasp/cipher_modules/models/cp/mzn_models/mzn_deterministic_truncated_xor_differential_model_arx_optimized.py b/claasp/cipher_modules/models/cp/mzn_models/mzn_deterministic_truncated_xor_differential_model_arx_optimized.py index 35ea098fb..66f93361b 100644 --- a/claasp/cipher_modules/models/cp/mzn_models/mzn_deterministic_truncated_xor_differential_model_arx_optimized.py +++ b/claasp/cipher_modules/models/cp/mzn_models/mzn_deterministic_truncated_xor_differential_model_arx_optimized.py @@ -1,30 +1,27 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** from claasp.cipher_modules.models.cp.mzn_model import MznModel -from claasp.name_mappings import (CONSTANT, INTERMEDIATE_OUTPUT, - CIPHER_OUTPUT, WORD_OPERATION) +from claasp.name_mappings import CONSTANT, INTERMEDIATE_OUTPUT, CIPHER_OUTPUT, WORD_OPERATION class MznDeterministicTruncatedXorDifferentialModelARXOptimized(MznModel): - - def __init__(self, cipher, window_size_list=None, probability_weight_per_round=None, sat_or_milp='sat'): + def __init__(self, cipher, window_size_list=None, probability_weight_per_round=None, sat_or_milp="sat"): super().__init__(cipher, window_size_list, probability_weight_per_round, sat_or_milp) def build_deterministic_truncated_xor_differential_trail_model(self, fixed_variables=[]): @@ -59,10 +56,11 @@ def build_deterministic_truncated_xor_differential_trail_model(self, fixed_varia operation_types = ["ROTATE", "SHIFT"] if component.type in component_types and (component.type != WORD_OPERATION or operation in operation_types): - variables, constraints = \ - component.minizinc_deterministic_truncated_xor_differential_trail_constraints(self) + variables, constraints = component.minizinc_deterministic_truncated_xor_differential_trail_constraints( + self + ) else: - print(f'{component.id} not yet implemented') + print(f"{component.id} not yet implemented") self._variables_list.extend(variables) self._model_constraints.extend(constraints) diff --git a/claasp/cipher_modules/models/cp/mzn_models/mzn_hybrid_impossible_xor_differential_model.py b/claasp/cipher_modules/models/cp/mzn_models/mzn_hybrid_impossible_xor_differential_model.py index 40ab3275c..38dd3782b 100644 --- a/claasp/cipher_modules/models/cp/mzn_models/mzn_hybrid_impossible_xor_differential_model.py +++ b/claasp/cipher_modules/models/cp/mzn_models/mzn_hybrid_impossible_xor_differential_model.py @@ -1,16 +1,16 @@ # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -32,25 +32,36 @@ import ast import itertools import math -import os import subprocess from sage.combinat.permutation import Permutation from sage.crypto.sbox import SBox -from claasp.cipher_modules.models.cp.mzn_model import solve_satisfy, MznModel -from claasp.cipher_modules.models.cp.mzn_models.mzn_impossible_xor_differential_model import \ - MznImpossibleXorDifferentialModel +from claasp.cipher_modules.models.cp.mzn_model import SOLVE_SATISFY, MznModel +from claasp.cipher_modules.models.cp.mzn_models.mzn_impossible_xor_differential_model import ( + MznImpossibleXorDifferentialModel, +) from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_differential_model import update_and_or_ddt_valid_probabilities -from claasp.cipher_modules.models.utils import convert_solver_solution_to_dictionary, check_if_implemented_component, \ - get_bit_bindings -from claasp.name_mappings import (CONSTANT, INTERMEDIATE_OUTPUT, CIPHER_OUTPUT, SBOX, WORD_OPERATION, - IMPOSSIBLE_XOR_DIFFERENTIAL, INPUT_PLAINTEXT, INPUT_KEY) - +from claasp.cipher_modules.models.utils import ( + check_if_implemented_component, + convert_solver_solution_to_dictionary, + get_bit_bindings, +) +from claasp.name_mappings import ( + CIPHER_OUTPUT, + CONSTANT, + IMPOSSIBLE_XOR_DIFFERENTIAL, + INPUT_KEY, + INPUT_PLAINTEXT, + INTERMEDIATE_OUTPUT, + SATISFIABLE, + SBOX, + UNSATISFIABLE, + WORD_OPERATION, +) class MznHybridImpossibleXorDifferentialModel(MznImpossibleXorDifferentialModel): - def __init__(self, cipher): super().__init__(cipher) self.sbox_size = None @@ -58,8 +69,14 @@ def __init__(self, cipher): self.sbox_ddt_values = [] def build_hybrid_impossible_xor_differential_trail_model( - self, fixed_variables=[], number_of_rounds=None, - initial_round=1, middle_round=1, final_round=None, intermediate_components=True, probabilistic=True + self, + fixed_variables=[], + number_of_rounds=None, + initial_round=1, + middle_round=1, + final_round=None, + intermediate_components=True, + probabilistic=True, ): """ Build the CP model for the search of hybrid impossible XOR differential trails. @@ -112,12 +129,14 @@ def build_hybrid_impossible_xor_differential_trail_model( if probabilistic: direct_variables, direct_constraints = self.build_improbable_forward_model(forward_components) - inverse_variables, inverse_constraints = self.build_improbable_backward_model(backward_components, - clean=False) + inverse_variables, inverse_constraints = self.build_improbable_backward_model( + backward_components, clean=False + ) else: direct_variables, direct_constraints = self.build_impossible_forward_model(forward_components) - inverse_variables, inverse_constraints = self.build_impossible_backward_model(backward_components, - clean=False) + inverse_variables, inverse_constraints = self.build_impossible_backward_model( + backward_components, clean=False + ) self._variables_list.extend(direct_variables) deterministic_truncated_xor_differential.extend(direct_constraints) @@ -136,8 +155,12 @@ def build_hybrid_impossible_xor_differential_trail_model( deterministic_truncated_xor_differential.extend( self.final_impossible_constraints( - number_of_rounds, initial_round, middle_round, final_round, - intermediate_components, probabilistic=probabilistic + number_of_rounds, + initial_round, + middle_round, + final_round, + intermediate_components, + probabilistic=probabilistic, ) ) set_of_constraints = self._variables_list + deterministic_truncated_xor_differential @@ -145,21 +168,23 @@ def build_hybrid_impossible_xor_differential_trail_model( self._model_constraints = self._model_prefix + self.clean_constraints( set_of_constraints, initial_round, middle_round, final_round ) + def build_improbable_forward_model(self, forward_components, clean=False): direct_variables = [] direct_constraints = [] key_components, key_ids = self.extract_key_schedule() for component in forward_components: if check_if_implemented_component(component): - variables, constraints = self.propagate_deterministically(component, - key_schedule=(component.id in key_ids), - probabilistic=True) + variables, constraints = self.propagate_deterministically( + component, key_schedule=(component.id in key_ids), probabilistic=True + ) direct_variables.extend(variables) direct_constraints.extend(constraints) if clean: direct_variables, direct_constraints = self.clean_inverse_impossible_variables_constraints( - forward_components, direct_variables, direct_constraints) + forward_components, direct_variables, direct_constraints + ) return direct_variables, direct_constraints @@ -170,9 +195,9 @@ def build_improbable_backward_model(self, backward_components, clean=True): constant_components, constant_ids = self.extract_constants() for component in backward_components: if check_if_implemented_component(component): - variables, constraints = self.propagate_deterministically(component, - key_schedule=(component.id in key_ids), - inverse=True, probabilistic=True) + variables, constraints = self.propagate_deterministically( + component, key_schedule=(component.id in key_ids), inverse=True, probabilistic=True + ) inverse_variables.extend(variables) inverse_constraints.extend(constraints) @@ -183,8 +208,11 @@ def build_improbable_backward_model(self, backward_components, clean=True): input_component = self.get_component_from_id(id_link, self.inverse_cipher) if input_component not in backward_components and id_link not in key_ids + constant_ids: components_to_invert.append(input_component) - inverse_variables, inverse_constraints = self.clean_inverse_impossible_variables_constraints_with_extensions( - components_to_invert, inverse_variables, inverse_constraints) + inverse_variables, inverse_constraints = ( + self.clean_inverse_impossible_variables_constraints_with_extensions( + components_to_invert, inverse_variables, inverse_constraints + ) + ) return inverse_variables, inverse_constraints @@ -201,17 +229,17 @@ def _find_paths(self, graph, end_node, stop_at=INPUT_PLAINTEXT, path=None): path = [] path = [end_node] + path - end_node = end_node[:-1] + ('i',) + end_node = end_node[:-1] + ("i",) # if permutation - if end_node[0] != 'plaintext': + if end_node[0] != "plaintext": component = self.cipher.get_component_from_id(end_node[0]) - if component.type == 'linear_layer': + if component.type == "linear_layer": matrix = component.description try: perm = Permutation([i + 1 for i in self._extract_ones(matrix)]).inverse() P = [i - 1 for i in perm] - end_node = (end_node[0], str(P[int(end_node[-2])])) + ('i',) + end_node = (end_node[0], str(P[int(end_node[-2])])) + ("i",) except Exception: pass @@ -228,7 +256,7 @@ def _get_graph_for_round(self, cipher_round): bit_bindings, intermediate_bit_bindings = get_bit_bindings(cipher_round) for interm_binding in intermediate_bit_bindings.values(): for key, value_list in interm_binding.items(): - filtered_values = [val for val in value_list if val[2] == 'o'] + filtered_values = [val for val in value_list if val[2] == "o"] for val in filtered_values: related_values = [other_val for other_val in value_list if other_val != val] bit_bindings[val] = bit_bindings.get(val, []) + [key] + related_values @@ -238,7 +266,7 @@ def _get_graph_for_round(self, cipher_round): def _get_output_bits_connected_to_sboxes(self, intermediate_output, graph): path_indices = {} for bit in range(intermediate_output.output_bit_size): - path = self._find_paths(graph, (f'{intermediate_output.id}', f'{bit}', 'i'), stop_at=SBOX) + path = self._find_paths(graph, (f"{intermediate_output.id}", f"{bit}", "i"), stop_at=SBOX) if path[0][0] not in path_indices: path_indices[path[0][0]] = [int(path[-1][1])] else: @@ -247,7 +275,7 @@ def _get_output_bits_connected_to_sboxes(self, intermediate_output, graph): def _output_is_aligned_with_sboxes(self, path_indices): for bit_positions in path_indices.values(): - if len(bit_positions ) <= 1: + if len(bit_positions) <= 1: return True lst = sorted(bit_positions) @@ -257,36 +285,45 @@ def _output_is_aligned_with_sboxes(self, path_indices): return True def _generate_wordwise_incompatibility_constraint(self, component): - if self.sbox_size: current_round = self._cipher.get_round_from_component_id(component.id) - wordwise_incompatibility_constraint = '' + wordwise_incompatibility_constraint = "" single_round = self._cipher.remove_key_schedule().rounds.components_in_round(current_round) - round_intermediate_output = [c for c in single_round if c.description == ['round_output']][0] + round_intermediate_output = [c for c in single_round if c.description == ["round_output"]][0] graph = self._get_graph_for_round(self._cipher) path_indices = self._get_output_bits_connected_to_sboxes(round_intermediate_output, graph) if self._cipher.is_spn() or self._output_is_aligned_with_sboxes(path_indices): intermediate_output_bit_positions = path_indices.values() else: - intermediate_output_bit_positions = itertools.combinations(range(len(path_indices.keys()) * self.sbox_size), self.sbox_size) - - for bit_positions, suffix in itertools.product(intermediate_output_bit_positions, ['', 'inverse_']): - constraint = '(' - constraint += '/\\'.join( - [f'({component.id}[{i}]+inverse_{component.id}[{i}]={suffix + component.id}[{i}])' for i in - bit_positions]) - constraint += '/\\' + '/\\'.join([f'({suffix + component.id}[{i}] > 2)' for i in bit_positions]) - constraint += '/\\' + '/\\'.join( - [f'({suffix + component.id}[{i}] = {suffix + component.id}[{bit_positions[0]}])' for i in - bit_positions[1:]]) - wordwise_incompatibility_constraint += constraint + ') \\/' + intermediate_output_bit_positions = itertools.combinations( + range(len(path_indices.keys()) * self.sbox_size), self.sbox_size + ) + + for bit_positions, suffix in itertools.product(intermediate_output_bit_positions, ["", "inverse_"]): + constraint = "(" + constraint += "/\\".join( + [ + f"({component.id}[{i}]+inverse_{component.id}[{i}]={suffix + component.id}[{i}])" + for i in bit_positions + ] + ) + constraint += "/\\" + "/\\".join([f"({suffix + component.id}[{i}] > 2)" for i in bit_positions]) + constraint += "/\\" + "/\\".join( + [ + f"({suffix + component.id}[{i}] = {suffix + component.id}[{bit_positions[0]}])" + for i in bit_positions[1:] + ] + ) + wordwise_incompatibility_constraint += constraint + ") \\/" return wordwise_incompatibility_constraint[:-3] else: - return 'False' + return "False" - def final_impossible_constraints(self, number_of_rounds, initial_round, middle_round, final_round, intermediate_components, probabilistic=False): + def final_impossible_constraints( + self, number_of_rounds, initial_round, middle_round, final_round, intermediate_components, probabilistic=False + ): """ Constraints for output and incompatibility. @@ -317,18 +354,27 @@ def final_impossible_constraints(self, number_of_rounds, initial_round, middle_r sage: final = mzn.final_impossible_constraints(3, 1, 2, 3, False, False) #long# doctest: +SKIP """ - cipher_inputs = self._cipher.inputs if initial_round == 1 else ['key'] + [ - comp.id for comp in self._cipher.get_components_in_round(initial_round - 2) if 'output' in comp.id - ] + cipher_inputs = ( + self._cipher.inputs + if initial_round == 1 + else ["key"] + + [comp.id for comp in self._cipher.get_components_in_round(initial_round - 2) if "output" in comp.id] + ) cipher = self._cipher - cipher_outputs = self.inverse_cipher.inputs if final_round == self._cipher.number_of_rounds else ['key'] + [ - comp.id for comp in self.inverse_cipher.get_components_in_round(self._cipher.number_of_rounds - final_round) - if 'output' in comp.id - ] - cp_constraints = [solve_satisfy] - new_constraint = 'output[' - bitwise_incompatibility_constraint = '' - wordwise_incompatibility_constraint = '' + cipher_outputs = ( + self.inverse_cipher.inputs + if final_round == self._cipher.number_of_rounds + else ["key"] + + [ + comp.id + for comp in self.inverse_cipher.get_components_in_round(self._cipher.number_of_rounds - final_round) + if "output" in comp.id + ] + ) + cp_constraints = [SOLVE_SATISFY] + new_constraint = "output[" + bitwise_incompatibility_constraint = "" + wordwise_incompatibility_constraint = "" key_schedule_components, key_schedule_components_ids = self.extract_key_schedule() @@ -348,44 +394,79 @@ def final_impossible_constraints(self, number_of_rounds, initial_round, middle_r component_inputs = [] input_bit_size = 0 for id_link, bit_positions in zip(input_id_links, input_bit_positions): - component_inputs.extend([f'{id_link}[{position}]' for position in bit_positions]) + component_inputs.extend([f"{id_link}[{position}]" for position in bit_positions]) input_bit_size += len(bit_positions) for i in range(input_bit_size): - bitwise_incompatibility_constraint += f'({component_inputs[i]}+inverse_{component_id}[{i}]=1) \\/ ' - wordwise_incompatibility_constraint += f'({self._generate_wordwise_incompatibility_constraint(component)}) \\/ ' + bitwise_incompatibility_constraint += ( + f"({component_inputs[i]}+inverse_{component_id}[{i}]=1) \\/ " + ) + wordwise_incompatibility_constraint += ( + f"({self._generate_wordwise_incompatibility_constraint(component)}) \\/ " + ) if not probabilistic or component.type in [CIPHER_OUTPUT, INTERMEDIATE_OUTPUT]: for id_link in input_id_links: - new_constraint = new_constraint + \ - f'\"{id_link} = \"++ show({id_link})++ \"\\n\" ++ \"0\" ++ \"\\n\" ++' - new_constraint = new_constraint + \ - f'\"inverse_{component_id} = \"++ show(inverse_{component_id})++ \"\\n\" ++ \"0\" ++ \"\\n\" ++' + new_constraint = ( + new_constraint + f'"{id_link} = "++ show({id_link})++ "\\n" ++ "0" ++ "\\n" ++' + ) + new_constraint = ( + new_constraint + + f'"inverse_{component_id} = "++ show(inverse_{component_id})++ "\\n" ++ "0" ++ "\\n" ++' + ) else: for component in cipher.get_all_components(): - extra_condition = (component.id in key_schedule_components_ids and component.description == ['round_key_output']) if probabilistic else True - if 'output' in component.id: + extra_condition = ( + (component.id in key_schedule_components_ids and component.description == ["round_key_output"]) + if probabilistic + else True + ) + if "output" in component.id: if self.get_component_round(component.id) <= middle_round - 1 and extra_condition: - new_constraint = new_constraint + \ - f'\"{component.id} = \"++ show({component.id})++ \"\\n\" ++ \"0\" ++ \"\\n\" ++' + new_constraint = ( + new_constraint + f'"{component.id} = "++ show({component.id})++ "\\n" ++ "0" ++ "\\n" ++' + ) if self.get_component_round(component.id) >= middle_round - 1 and extra_condition: - new_constraint = new_constraint + \ - f'\"inverse_{component.id} = \"++ show(inverse_{component.id})++ \"\\n\" ++ \"0\" ++ \"\\n\" ++' - if self.get_component_round( - component.id) == middle_round - 1 and component.id not in key_schedule_components_ids and component.description == [ - 'round_output']: + new_constraint = ( + new_constraint + + f'"inverse_{component.id} = "++ show(inverse_{component.id})++ "\\n" ++ "0" ++ "\\n" ++' + ) + if ( + self.get_component_round(component.id) == middle_round - 1 + and component.id not in key_schedule_components_ids + and component.description == ["round_output"] + ): for i in range(component.output_bit_size): - bitwise_incompatibility_constraint += f'({component.id}[{i}]+inverse_{component.id}[{i}]=1) \\/ ' + bitwise_incompatibility_constraint += ( + f"({component.id}[{i}]+inverse_{component.id}[{i}]=1) \\/ " + ) wordwise_incompatibility_constraint += self._generate_wordwise_incompatibility_constraint( - component) + component + ) bitwise_incompatibility_constraint = bitwise_incompatibility_constraint[:-4] - new_constraint = new_constraint[:-2] + '];' - wordwise_incompatibility_constraint = wordwise_incompatibility_constraint.rstrip(' \\/ ') - cp_constraints.append(f"constraint ({bitwise_incompatibility_constraint}) \\/ ({wordwise_incompatibility_constraint});") + new_constraint = new_constraint[:-2] + "];" + wordwise_incompatibility_constraint = wordwise_incompatibility_constraint.rstrip(" \\/ ") + cp_constraints.append( + f"constraint ({bitwise_incompatibility_constraint}) \\/ ({wordwise_incompatibility_constraint});" + ) cp_constraints.append(new_constraint) return cp_constraints - def find_all_impossible_xor_differential_trails(self, number_of_rounds=None, fixed_values=[], solver_name=None, initial_round = 1, middle_round=2, final_round = None, intermediate_components = True, probabilistic=False, num_of_processors=None, timelimit=None, solve_with_API=False, solve_external = True): + def find_all_impossible_xor_differential_trails( + self, + number_of_rounds=None, + fixed_values=[], + solver_name=None, + initial_round=1, + middle_round=2, + final_round=None, + intermediate_components=True, + probabilistic=False, + num_of_processors=None, + timelimit=None, + solve_with_API=False, + solve_external=True, + ): """ Search for all impossible XOR differential trails of a cipher. @@ -412,7 +493,7 @@ def find_all_impossible_xor_differential_trails(self, number_of_rounds=None, fix sage: fixed_variables = [set_fixed_variables('key', 'equal', range(76), [0]*76)] sage: fixed_variables.append(set_fixed_variables('plaintext', 'equal', range(64), [0]*64)) sage: fixed_variables.append(set_fixed_variables('inverse_cipher_output_3_19', 'equal', range(64), [0]*64)) - sage: trails = mzn.find_all_impossible_xor_differential_trails(4, fixed_variables, 'Chuffed', 1, 2, 4, intermediate_components=False) + sage: trails = mzn.find_all_impossible_xor_differential_trails(4, fixed_variables, 'chuffed', 1, 2, 4, intermediate_components=False) sage: from claasp.cipher_modules.models.cp.mzn_models.mzn_hybrid_impossible_xor_differential_model import MznHybridImpossibleXorDifferentialModel sage: from claasp.ciphers.block_ciphers.lblock_block_cipher import LBlockBlockCipher @@ -422,18 +503,57 @@ def find_all_impossible_xor_differential_trails(self, number_of_rounds=None, fix sage: fixed_variables = [set_fixed_variables(component_id='key', constraint_type='equal', bit_positions=range(80), bit_values=[0] * 49 + [1] + [0]*30)] #long# doctest: +SKIP sage: fixed_variables.append(set_fixed_variables(component_id='plaintext', constraint_type='equal', bit_positions= range(64), bit_values= [0] * 60 + [1,0,0,0])) #long# doctest: +SKIP sage: fixed_variables.append(set_fixed_variables('inverse_cipher_output_17_19', 'equal', range(64), [0]*64)) #long# doctest: +SKIP - sage: trails = mzn.find_all_impossible_xor_differential_trails(18, fixed_variables, 'Chuffed', 1, 9, 18, intermediate_components=False, probabilistic=True) #long# doctest: +SKIP + sage: trails = mzn.find_all_impossible_xor_differential_trails(18, fixed_variables, 'chuffed', 1, 9, 18, intermediate_components=False, probabilistic=True) #long# doctest: +SKIP sage: len(trails) #long# doctest: +SKIP 6 - + """ - self.build_hybrid_impossible_xor_differential_trail_model(fixed_values, number_of_rounds, initial_round, middle_round, final_round, intermediate_components, probabilistic) + self.build_hybrid_impossible_xor_differential_trail_model( + fixed_values, + number_of_rounds, + initial_round, + middle_round, + final_round, + intermediate_components, + probabilistic, + ) if solve_with_API: - return self.solve_for_ARX(solver_name=solver_name, timeout_in_seconds_=timelimit, processes_=num_of_processors, all_solutions_=True) - return self.solve(IMPOSSIBLE_XOR_DIFFERENTIAL, solver_name=solver_name, number_of_rounds=number_of_rounds, initial_round=initial_round, middle_round=middle_round, final_round=final_round, timeout_in_seconds_=timelimit, processes_=num_of_processors, all_solutions_=True, solve_external=solve_external, probabilistic=probabilistic) + return self.solve_for_ARX( + solver_name=solver_name, + timeout_in_seconds_=timelimit, + processes_=num_of_processors, + all_solutions_=True, + ) + return self.solve( + IMPOSSIBLE_XOR_DIFFERENTIAL, + solver_name=solver_name, + number_of_rounds=number_of_rounds, + initial_round=initial_round, + middle_round=middle_round, + final_round=final_round, + timeout_in_seconds_=timelimit, + processes_=num_of_processors, + all_solutions_=True, + solve_external=solve_external, + probabilistic=probabilistic, + ) - def find_one_impossible_xor_differential_trail(self, number_of_rounds=None, fixed_values=[], solver_name=None, initial_round=1, middle_round=2, final_round=None, intermediate_components=True, probabilistic=False, num_of_processors=None, timelimit=None, solve_with_API=False, solve_external=True): + def find_one_impossible_xor_differential_trail( + self, + number_of_rounds=None, + fixed_values=[], + solver_name=None, + initial_round=1, + middle_round=2, + final_round=None, + intermediate_components=True, + probabilistic=False, + num_of_processors=None, + timelimit=None, + solve_with_API=False, + solve_external=True, + ): """ Search for one impossible XOR differential trail of a cipher. @@ -460,7 +580,7 @@ def find_one_impossible_xor_differential_trail(self, number_of_rounds=None, fixe sage: fixed_variables = [set_fixed_variables('key', 'equal', range(80), [0]*10+[1]+[0]*69)] sage: fixed_variables.append(set_fixed_variables('plaintext', 'equal', range(64), [0]*64)) sage: fixed_variables.append(set_fixed_variables('inverse_cipher_output_3_19', 'equal', range(64), [0]*64)) - sage: trail = mzn.find_one_impossible_xor_differential_trail(4, fixed_variables, 'Chuffed', 1, 2, 4, intermediate_components=False) + sage: trail = mzn.find_one_impossible_xor_differential_trail(4, fixed_variables, 'chuffed', 1, 2, 4, intermediate_components=False) sage: from claasp.cipher_modules.models.cp.mzn_models.mzn_hybrid_impossible_xor_differential_model import MznHybridImpossibleXorDifferentialModel sage: from claasp.ciphers.block_ciphers.lblock_block_cipher import LBlockBlockCipher @@ -470,19 +590,40 @@ def find_one_impossible_xor_differential_trail(self, number_of_rounds=None, fixe sage: fixed_variables = [set_fixed_variables('key', 'equal', range(80), integer_to_bit_list(0x800, 80, 'big'))] #long# doctest: +SKIP sage: fixed_variables.append(set_fixed_variables('plaintext', 'equal', range(64), [0]*64)) #long# doctest: +SKIP sage: fixed_variables.append(set_fixed_variables('inverse_cipher_output_15_19', 'equal', range(64), [0]*64)) #long# doctest: +SKIP - sage: trail = mzn.find_one_impossible_xor_differential_trail(16, fixed_variables, 'Chuffed', 1, 8, 16, intermediate_components=False) #long# doctest: +SKIP + sage: trail = mzn.find_one_impossible_xor_differential_trail(16, fixed_variables, 'chuffed', 1, 8, 16, intermediate_components=False) #long# doctest: +SKIP ... """ - self.build_hybrid_impossible_xor_differential_trail_model(fixed_values, number_of_rounds, initial_round, middle_round, final_round, intermediate_components, probabilistic) - + self.build_hybrid_impossible_xor_differential_trail_model( + fixed_values, + number_of_rounds, + initial_round, + middle_round, + final_round, + intermediate_components, + probabilistic, + ) + if solve_with_API: - return self.solve_for_ARX(solver_name=solver_name, timeout_in_seconds_=timelimit, processes_=num_of_processors) - return self.solve('impossible_xor_differential_one_solution', solver_name=solver_name, number_of_rounds=number_of_rounds, initial_round=initial_round, middle_round=middle_round, final_round=final_round, timeout_in_seconds_=timelimit, processes_=num_of_processors, solve_external=solve_external, probabilistic=probabilistic) + return self.solve_for_ARX( + solver_name=solver_name, timeout_in_seconds_=timelimit, processes_=num_of_processors + ) + return self.solve( + "impossible_xor_differential_one_solution", + solver_name=solver_name, + number_of_rounds=number_of_rounds, + initial_round=initial_round, + middle_round=middle_round, + final_round=final_round, + timeout_in_seconds_=timelimit, + processes_=num_of_processors, + solve_external=solve_external, + probabilistic=probabilistic, + ) def _get_sbox_max(self): nb_sbox = len([c for c in self._cipher.get_all_components() if c.type == SBOX]) - return 100*self._cipher.number_of_rounds + nb_sbox*10 + return 100 * self._cipher.number_of_rounds + nb_sbox * 10 def input_constraints(self, number_of_rounds=None, middle_round=None, probabilistic=False): if number_of_rounds is None: @@ -493,7 +634,7 @@ def input_constraints(self, number_of_rounds=None, middle_round=None, probabilis cp_constraints = [] cp_declarations = [f"set of int: ext_domain = 0..2 union {{ i | i in 10..{sbox_max} where (i mod 10 = 0)}};"] cp_declarations += [ - f'array[0..{bit_size - 1}] of var ext_domain: {input_};' + f"array[0..{bit_size - 1}] of var ext_domain: {input_};" for input_, bit_size in zip(self._cipher.inputs, self._cipher.inputs_bit_size) ] @@ -505,10 +646,12 @@ def input_constraints(self, number_of_rounds=None, middle_round=None, probabilis for r in range(number_of_rounds - middle_round + 1): backward_components.extend(inverse_cipher.get_components_in_round(r)) - cp_declarations.extend([ - f'array[0..{bit_size - 1}] of var ext_domain: inverse_{input_};' - for input_, bit_size in zip(inverse_cipher.inputs, inverse_cipher.inputs_bit_size) - ]) + cp_declarations.extend( + [ + f"array[0..{bit_size - 1}] of var ext_domain: inverse_{input_};" + for input_, bit_size in zip(inverse_cipher.inputs, inverse_cipher.inputs_bit_size) + ] + ) prob_count = 0 valid_probabilities = {0} if probabilistic else set() @@ -523,46 +666,48 @@ def input_constraints(self, number_of_rounds=None, middle_round=None, probabilis if component.id in key_ids and SBOX in component.type and probabilistic: prob_count += 1 self.update_sbox_ddt_valid_probabilities(component, valid_probabilities) - cp_declarations.append(f'array[0..{output_size - 1}] of var 0..1: {prefix}{output_id_link};') + cp_declarations.append(f"array[0..{output_size - 1}] of var 0..1: {prefix}{output_id_link};") elif component in key_components and WORD_OPERATION in component.type: if probabilistic: if "AND" in component.description[0] or component.description[0] == "OR": prob_count += component.description[1] * component.output_bit_size - update_and_or_ddt_valid_probabilities(and_already_added, component, cp_declarations, - valid_probabilities) + update_and_or_ddt_valid_probabilities( + and_already_added, component, cp_declarations, valid_probabilities + ) elif "MODADD" in component.description[0]: prob_count += component.description[1] - 1 valid_probabilities |= set(range(100 * output_size)[::100]) - cp_declarations.append(f'array[0..{output_size - 1}] of var ext_domain: {prefix}{output_id_link};') + cp_declarations.append(f"array[0..{output_size - 1}] of var ext_domain: {prefix}{output_id_link};") elif "output" in component.type: - cp_declarations.append(f'array[0..{output_size - 1}] of var ext_domain: {prefix}{output_id_link};') + cp_declarations.append(f"array[0..{output_size - 1}] of var ext_domain: {prefix}{output_id_link};") elif CIPHER_OUTPUT in component.type: - cp_declarations.append(f'array[0..{output_size - 1}] of var ext_domain: {prefix}{output_id_link};') - cp_constraints.append(f'constraint count({prefix}{output_id_link},2) < {output_size};') + cp_declarations.append(f"array[0..{output_size - 1}] of var ext_domain: {prefix}{output_id_link};") + cp_constraints.append(f"constraint count({prefix}{output_id_link},2) < {output_size};") if not is_forward and not probabilistic: - cp_constraints.append(f'constraint count({prefix}{output_id_link},1) > 0;') + cp_constraints.append(f"constraint count({prefix}{output_id_link},1) > 0;") elif CONSTANT not in component.type: - cp_declarations.append(f'array[0..{output_size - 1}] of var ext_domain: {prefix}{output_id_link};') + cp_declarations.append(f"array[0..{output_size - 1}] of var ext_domain: {prefix}{output_id_link};") if INPUT_KEY in self._cipher.inputs: cp_constraints.append("constraint inverse_key = key;") for input_id, input_size in self._cipher.inputs_size_to_dict().items(): - cp_constraints.append(f'constraint forall (i in 0..{input_size-1})({input_id}[i] <= 2);') + cp_constraints.append(f"constraint forall (i in 0..{input_size - 1})({input_id}[i] <= 2);") cp_constraints.append(f"constraint count({' ++ '.join(self._cipher.inputs)}, 1) > 0;") - for component in key_components: if component.id in set(key_ids) & set(c.id for c in forward_components) & set( - c.id for c in backward_components): + c.id for c in backward_components + ): cp_declarations.append(f"constraint {component.id} = inverse_{component.id};") if probabilistic: - cp_declarations_weight = 'int: weight = -1;' + cp_declarations_weight = "int: weight = -1;" if prob_count > 0: self._probability = True last_round_key = [c for c in key_components if c.description == ["round_key_output"]][ - number_of_rounds - 1].id + number_of_rounds - 1 + ].id prob_variables = f"array[0..{prob_count - 1}] of var {valid_probabilities}: p;" cp_declarations.append(prob_variables) cp_declarations_weight = f"var int: weight = p[{'] + p['.join(map(str, [val for c, val in self.component_and_probability.items() if key_ids.index(c) < key_ids.index(last_round_key)]))}];" @@ -570,18 +715,22 @@ def input_constraints(self, number_of_rounds=None, middle_round=None, probabilis return cp_declarations, cp_constraints - def propagate_deterministically(self, component, key_schedule=False, wordwise=False, inverse=False, probabilistic=False): + def propagate_deterministically( + self, component, key_schedule=False, wordwise=False, inverse=False, probabilistic=False + ): if not wordwise: if component.type == SBOX: if key_schedule and probabilistic: - variables, constraints = component.cp_xor_differential_propagation_constraints( - self, inverse) + variables, constraints = component.cp_xor_differential_propagation_constraints(self, inverse) else: - variables, constraints, sbox_mant = component.cp_hybrid_deterministic_truncated_xor_differential_constraints( - self.sbox_mant, inverse, self.sboxes_component_number_list) + variables, constraints, sbox_mant = ( + component.cp_hybrid_deterministic_truncated_xor_differential_constraints( + self.sbox_mant, inverse, self.sboxes_component_number_list + ) + ) self.sbox_mant = sbox_mant self.sbox_size = component.output_bit_size - elif component.description[0] == 'XOR': + elif component.description[0] == "XOR": variables, constraints = component.cp_hybrid_deterministic_truncated_xor_differential_constraints() else: variables, constraints = component.cp_deterministic_truncated_xor_differential_trail_constraints() @@ -591,18 +740,20 @@ def propagate_deterministically(self, component, key_schedule=False, wordwise=Fa return variables, constraints def format_component_value(self, component_id, string): - if f'{component_id}_i' in string: - value = string.replace(f'{component_id}_i', '') - elif f'{component_id}_o' in string: - value = string.replace(f'{component_id}_o', '') - elif f'inverse_{component_id}' in string: - value = string.replace(f'inverse_{component_id}', '') - elif f'{component_id}' in string: - value = string.replace(component_id, '') - value = ['.' if x == 0 else str(x) if x < 2 else '?' if x == 2 else str(x % 7 + 2) for x in - ast.literal_eval(value[3:])] - - return ''.join(value) + if f"{component_id}_i" in string: + value = string.replace(f"{component_id}_i", "") + elif f"{component_id}_o" in string: + value = string.replace(f"{component_id}_o", "") + elif f"inverse_{component_id}" in string: + value = string.replace(f"inverse_{component_id}", "") + elif f"{component_id}" in string: + value = string.replace(component_id, "") + value = [ + "." if x == 0 else str(x) if x < 2 else "?" if x == 2 else str(x % 7 + 2) + for x in ast.literal_eval(value[3:]) + ] + + return "".join(value) def update_sbox_ddt_valid_probabilities(self, component, valid_probabilities): input_size = int(component.input_bit_size) @@ -618,12 +769,14 @@ def update_sbox_ddt_valid_probabilities(self, component, valid_probabilities): for i in range(sbox_ddt.nrows()): set_of_occurrences = set(sbox_ddt.rows()[i]) set_of_occurrences -= {0} - valid_probabilities.update({round(100 * math.log2(2 ** input_size / occurrence)) - for occurrence in set_of_occurrences}) + valid_probabilities.update( + {round(100 * math.log2(2**input_size / occurrence)) for occurrence in set_of_occurrences} + ) self.sbox_ddt_values.append((description, output_id_link)) - def _parse_solver_output(self, output_to_parse, number_of_rounds, initial_round, middle_round, final_round, probabilistic=False): - + def _parse_solver_output( + self, output_to_parse, number_of_rounds, initial_round, middle_round, final_round, probabilistic=False + ): if probabilistic: components_values, memory, time, total_weight = self.parse_solver_information(output_to_parse, False, True) else: @@ -634,65 +787,97 @@ def _parse_solver_output(self, output_to_parse, number_of_rounds, initial_round, for r in list(range(initial_round - 1, middle_round)) + list(range(final_round, number_of_rounds)): all_components.extend([component.id for component in [*self._cipher.get_components_in_round(r)]]) for r in list(range(initial_round - 1)) + list(range(middle_round - 1, final_round)): - all_components.extend(['inverse_' + component.id for component in [*self.inverse_cipher.get_components_in_round(number_of_rounds - r - 1)]]) - all_components.extend(['inverse_' + id_link for id_link in [*self.inverse_cipher.inputs]]) - all_components.extend(['inverse_' + id_link for id_link in [*self._cipher.inputs]]) + all_components.extend( + [ + "inverse_" + component.id + for component in [*self.inverse_cipher.get_components_in_round(number_of_rounds - r - 1)] + ] + ) + all_components.extend(["inverse_" + id_link for id_link in [*self.inverse_cipher.inputs]]) + all_components.extend(["inverse_" + id_link for id_link in [*self._cipher.inputs]]) for component_id in all_components: solution_number = 1 for j, string in enumerate(output_to_parse): - if f'{component_id}' in string and 'inverse_' not in component_id + string: + if f"{component_id}" in string and "inverse_" not in component_id + string: value = self.format_component_value(component_id, string) component_solution = {} - component_solution['value'] = value - self.add_solution_to_components_values(component_id, component_solution, components_values, j, - output_to_parse, solution_number, string) - elif f'{component_id}' in string and 'inverse_' in component_id: + component_solution["value"] = value + self.add_solution_to_components_values( + component_id, component_solution, components_values, j, output_to_parse, solution_number, string + ) + elif f"{component_id}" in string and "inverse_" in component_id: value = self.format_component_value(component_id, string) component_solution = {} - component_solution['value'] = value - self.add_solution_to_components_values(component_id, component_solution, components_values, j, - output_to_parse, solution_number, string) - elif '----------' in string: + component_solution["value"] = value + self.add_solution_to_components_values( + component_id, component_solution, components_values, j, output_to_parse, solution_number, string + ) + elif "----------" in string: solution_number += 1 return time, memory, components_values, total_weight - def solve(self, model_type, solver_name=None, number_of_rounds=None, initial_round=None, middle_round=None, final_round=None, processes_=None, timeout_in_seconds_=None, all_solutions_=False, solve_external = False, probabilistic=False): + def solve( + self, + model_type, + solver_name=None, + number_of_rounds=None, + initial_round=None, + middle_round=None, + final_round=None, + processes_=None, + timeout_in_seconds_=None, + all_solutions_=False, + solve_external=False, + probabilistic=False, + ): if number_of_rounds is None: number_of_rounds = self._cipher.number_of_rounds if final_round is None: final_round = self._cipher.number_of_rounds - cipher_name = self.cipher_id - input_file_path = f'{cipher_name}_Mzn_{model_type}_{solver_name}.mzn' - command = self.get_command_for_solver_process(input_file_path, model_type, solver_name, processes_, timeout_in_seconds_) - solver_process = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8") - os.remove(input_file_path) + command = self.get_command_for_solver_process(model_type, solver_name, processes_, timeout_in_seconds_) + model = "\n".join(self._model_constraints) + "\n" + solver_process = subprocess.run(command, input=model, capture_output=True, text=True) if solver_process.returncode >= 0: solutions = [] solver_output = solver_process.stdout.splitlines() - solve_time, memory, components_values, total_weight = self._parse_solver_output(solver_output, number_of_rounds, initial_round, middle_round, final_round, probabilistic) + solve_time, memory, components_values, total_weight = self._parse_solver_output( + solver_output, number_of_rounds, initial_round, middle_round, final_round, probabilistic + ) if probabilistic: weight_list = [int(float(x)) for x in total_weight] summed_weight = sum(2 ** (-x) for x in weight_list if x) cumulated_weight = math.log(summed_weight) / math.log(2) if summed_weight else 0 if components_values == {}: - solution = convert_solver_solution_to_dictionary(self.cipher_id, model_type, solver_name, - solve_time, memory, - components_values, total_weight) - if 'UNSATISFIABLE' in solver_output[0]: - solution['status'] = 'UNSATISFIABLE' + solution = convert_solver_solution_to_dictionary( + self.cipher_id, model_type, solver_name, solve_time, memory, components_values, total_weight + ) + if UNSATISFIABLE in solver_output[0]: + solution["status"] = UNSATISFIABLE else: - solution['status'] = 'SATISFIABLE' + solution["status"] = SATISFIABLE solutions.append(solution) else: - MznModel.add_solutions_from_components_values(self, components_values, memory, model_type, solutions, solve_time, - solver_name, solver_output, total_weight, solve_external) - if model_type in ['xor_differential_one_solution', - 'xor_linear_one_solution', - 'deterministic_truncated_one_solution', - 'impossible_xor_differential_one_solution']: + MznModel.add_solutions_from_components_values( + self, + components_values, + memory, + model_type, + solutions, + solve_time, + solver_name, + solver_output, + total_weight, + solve_external, + ) + if model_type in [ + "xor_differential_one_solution", + "xor_linear_one_solution", + "deterministic_truncated_one_solution", + "impossible_xor_differential_one_solution", + ]: return solutions[0] if probabilistic: - return {'total_weight': cumulated_weight, 'solutions': solutions} + return {"total_weight": cumulated_weight, "solutions": solutions} else: - return solutions \ No newline at end of file + return solutions diff --git a/claasp/cipher_modules/models/cp/mzn_models/mzn_impossible_xor_differential_model.py b/claasp/cipher_modules/models/cp/mzn_models/mzn_impossible_xor_differential_model.py index ca9fa1602..07573fcef 100644 --- a/claasp/cipher_modules/models/cp/mzn_models/mzn_impossible_xor_differential_model.py +++ b/claasp/cipher_modules/models/cp/mzn_models/mzn_impossible_xor_differential_model.py @@ -1,41 +1,43 @@ # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** - -import os -import math -import itertools import subprocess import time -from copy import deepcopy - -from claasp.cipher_modules.models.cp.mzn_model import solve_satisfy, constraint_type_error -from claasp.cipher_modules.models.utils import write_model_to_file, convert_solver_solution_to_dictionary, \ - check_if_implemented_component, set_fixed_variables -from claasp.cipher_modules.models.cp.mzn_models.mzn_deterministic_truncated_xor_differential_model import \ - MznDeterministicTruncatedXorDifferentialModel -from claasp.name_mappings import (CONSTANT, INTERMEDIATE_OUTPUT, CIPHER_OUTPUT, LINEAR_LAYER, SBOX, MIX_COLUMN, - WORD_OPERATION, DETERMINISTIC_TRUNCATED_XOR_DIFFERENTIAL, IMPOSSIBLE_XOR_DIFFERENTIAL) -from claasp.cipher_modules.models.cp.solvers import CP_SOLVERS_EXTERNAL, SOLVER_DEFAULT +from claasp.cipher_modules.models.cp.mzn_model import SOLVE_SATISFY +from claasp.cipher_modules.models.cp.mzn_models.mzn_deterministic_truncated_xor_differential_model import ( + MznDeterministicTruncatedXorDifferentialModel, +) +from claasp.cipher_modules.models.cp.solvers import SOLVER_DEFAULT +from claasp.cipher_modules.models.utils import ( + check_if_implemented_component, + convert_solver_solution_to_dictionary, + set_fixed_variables, +) +from claasp.name_mappings import ( + CIPHER_OUTPUT, + CONSTANT, + IMPOSSIBLE_XOR_DIFFERENTIAL, + SATISFIABLE, + UNSATISFIABLE, +) class MznImpossibleXorDifferentialModel(MznDeterministicTruncatedXorDifferentialModel): - def __init__(self, cipher): super().__init__(cipher) self.inverse_cipher = cipher.cipher_inverse() @@ -44,21 +46,22 @@ def __init__(self, cipher): self.key_involvements = self.get_state_key_bits_positions() self.inverse_key_involvements = self.get_inverse_state_key_bits_positions() - def add_solution_to_components_values(self, component_id, component_solution, components_values, j, output_to_parse, - solution_number, string): + def add_solution_to_components_values( + self, component_id, component_solution, components_values, j, output_to_parse, solution_number, string + ): inverse_cipher = self.inverse_cipher if component_id in self._cipher.inputs: - components_values[f'solution{solution_number}'][f'{component_id}'] = component_solution + components_values[f"solution{solution_number}"][f"{component_id}"] = component_solution elif component_id in self.inverse_cipher.inputs: - components_values[f'solution{solution_number}'][f'inverse_{component_id}'] = component_solution - elif f'{component_id}_i' in string: - components_values[f'solution{solution_number}'][f'{component_id}_i'] = component_solution - elif f'{component_id}_o' in string: - components_values[f'solution{solution_number}'][f'{component_id}_o'] = component_solution - elif f'inverse_{component_id} ' in string: - components_values[f'solution{solution_number}'][f'inverse_{component_id}'] = component_solution - elif f'{component_id} ' in string: - components_values[f'solution{solution_number}'][f'{component_id}'] = component_solution + components_values[f"solution{solution_number}"][f"inverse_{component_id}"] = component_solution + elif f"{component_id}_i" in string: + components_values[f"solution{solution_number}"][f"{component_id}_i"] = component_solution + elif f"{component_id}_o" in string: + components_values[f"solution{solution_number}"][f"{component_id}_o"] = component_solution + elif f"inverse_{component_id} " in string: + components_values[f"solution{solution_number}"][f"inverse_{component_id}"] = component_solution + elif f"{component_id} " in string: + components_values[f"solution{solution_number}"][f"{component_id}"] = component_solution def build_impossible_backward_model(self, backward_components, clean=True): inverse_variables = [] @@ -78,8 +81,11 @@ def build_impossible_backward_model(self, backward_components, clean=True): input_component = self.get_component_from_id(id_link, self.inverse_cipher) if input_component not in backward_components and id_link not in key_ids + constant_ids: components_to_invert.append(input_component) - inverse_variables, inverse_constraints = self.clean_inverse_impossible_variables_constraints_with_extensions( - components_to_invert, inverse_variables, inverse_constraints) + inverse_variables, inverse_constraints = ( + self.clean_inverse_impossible_variables_constraints_with_extensions( + components_to_invert, inverse_variables, inverse_constraints + ) + ) return inverse_variables, inverse_constraints @@ -94,13 +100,14 @@ def build_impossible_forward_model(self, forward_components, clean=False): if clean: direct_variables, direct_constraints = self.clean_inverse_impossible_variables_constraints( - forward_components, direct_variables, direct_constraints) + forward_components, direct_variables, direct_constraints + ) return direct_variables, direct_constraints - def build_impossible_xor_differential_trail_with_extensions_model(self, fixed_variables, number_of_rounds, - initial_round, middle_round, final_round, - intermediate_components): + def build_impossible_xor_differential_trail_with_extensions_model( + self, fixed_variables, number_of_rounds, initial_round, middle_round, final_round, intermediate_components=True + ): """ Build the CP model for the search of deterministic truncated XOR differential trails with extensions for key recovery. @@ -154,7 +161,8 @@ def build_impossible_xor_differential_trail_with_extensions_model(self, fixed_va for input_component in self.inverse_cipher.get_all_components(): if input_component.id == id_link: components_to_link.append( - [self.get_inverse_component_correspondance(input_component), id_link]) + [self.get_inverse_component_correspondance(input_component), id_link] + ) link_constraints = self.link_constraints_for_trail_with_extensions(components_to_link) key_schedule_variables, key_schedule_constraints = self.constraints_for_key_schedule() @@ -170,16 +178,19 @@ def build_impossible_xor_differential_trail_with_extensions_model(self, fixed_va self._variables_list.extend(key_schedule_variables) self._variables_list.extend(constants_variables) deterministic_truncated_xor_differential.extend(constraints) - variables, constraints = self.input_impossible_constraints_with_extensions(number_of_rounds, initial_round, - middle_round, final_round) + variables, constraints = self.input_impossible_constraints_with_extensions( + number_of_rounds, initial_round, middle_round, final_round + ) self._model_prefix.extend(variables) self._variables_list.extend(constraints) deterministic_truncated_xor_differential.extend(link_constraints) deterministic_truncated_xor_differential.extend(key_schedule_constraints) deterministic_truncated_xor_differential.extend(constants_constraints) deterministic_truncated_xor_differential.extend( - self.final_impossible_constraints_with_extensions(number_of_rounds, initial_round, middle_round, - final_round, intermediate_components)) + self.final_impossible_constraints_with_extensions( + number_of_rounds, initial_round, middle_round, final_round, intermediate_components + ) + ) set_of_constraints = self._variables_list + deterministic_truncated_xor_differential cleaned_constraints = [] @@ -189,9 +200,16 @@ def build_impossible_xor_differential_trail_with_extensions_model(self, fixed_va self._model_constraints = cleaned_constraints - def build_impossible_xor_differential_trail_model(self, fixed_variables=[], number_of_rounds=None, initial_round=1, - middle_round=None, final_round=None, - intermediate_components=True): + def build_impossible_xor_differential_trail_model( + self, + fixed_variables=[], + number_of_rounds=None, + initial_round=1, + middle_round=None, + final_round=None, + intermediate_components=True, + fully_automatic=False, + ): """ Build the CP model for the search of deterministic truncated XOR differential trails. @@ -215,10 +233,13 @@ def build_impossible_xor_differential_trail_model(self, fixed_variables=[], numb sage: fixed_variables = [set_fixed_variables('key', 'equal', range(64), integer_to_bit_list(0, 64, 'little'))] sage: cp.build_impossible_xor_differential_trail_model(fixed_variables, 4, 1, 3, 4, False) """ + initial_round, middle_round, final_round, number_of_rounds = self.validate_input_rounds( + initial_round, middle_round, final_round, number_of_rounds + ) self.initialise_model() - if number_of_rounds is None: + if fully_automatic: number_of_rounds = self._cipher.number_of_rounds - if final_round is None: + initial_round = 1 final_round = self._cipher.number_of_rounds inverse_cipher = self.inverse_cipher @@ -227,20 +248,16 @@ def build_impossible_xor_differential_trail_model(self, fixed_variables=[], numb deterministic_truncated_xor_differential = constraints self.middle_round = middle_round - if middle_round is not None: + if fully_automatic: + forward_components = self._cipher.get_all_components() + backward_components = inverse_cipher.get_all_components() + else: forward_components = [] for r in range(middle_round): forward_components.extend(self._cipher.get_components_in_round(r)) backward_components = [] for r in range(number_of_rounds - middle_round + 1): backward_components.extend(inverse_cipher.get_components_in_round(r)) - else: - forward_components = [] - for r in range(final_round): - forward_components.extend(self._cipher.get_components_in_round(r)) - backward_components = [] - for r in range(final_round): - backward_components.extend(inverse_cipher.get_components_in_round(r)) direct_variables, direct_constraints = self.build_impossible_forward_model(forward_components) self._variables_list.extend(direct_variables) @@ -248,59 +265,67 @@ def build_impossible_xor_differential_trail_model(self, fixed_variables=[], numb inverse_variables, inverse_constraints = self.build_impossible_backward_model(backward_components, clean=False) inverse_variables, inverse_constraints = self.clean_inverse_impossible_variables_constraints( - backward_components, inverse_variables, inverse_constraints) + backward_components, inverse_variables, inverse_constraints + ) self._variables_list.extend(inverse_variables) deterministic_truncated_xor_differential.extend(inverse_constraints) - variables, constraints = self.input_impossible_constraints(number_of_rounds=number_of_rounds, - middle_round=middle_round) + variables, constraints = self.input_impossible_constraints( + number_of_rounds=number_of_rounds, middle_round=middle_round, fully_automatic=fully_automatic + ) self._model_prefix.extend(variables) self._variables_list.extend(constraints) deterministic_truncated_xor_differential.extend( - self.final_impossible_constraints(number_of_rounds, initial_round, middle_round, final_round, - intermediate_components)) + self.final_impossible_constraints( + number_of_rounds, initial_round, middle_round, final_round, intermediate_components, fully_automatic + ) + ) set_of_constraints = self._variables_list + deterministic_truncated_xor_differential - self._model_constraints = self._model_prefix + self.clean_constraints(set_of_constraints, initial_round, - middle_round, final_round) + self._model_constraints = self._model_prefix + self.clean_constraints( + set_of_constraints, initial_round, middle_round, final_round, fully_automatic + ) - def clean_constraints(self, set_of_constraints, initial_round, middle_round, final_round): + def clean_constraints(self, set_of_constraints, initial_round, middle_round, final_round, fully_automatic=False): number_of_rounds = self._cipher.number_of_rounds - input_component = 'plaintext' + input_component = "plaintext" model_constraints = [] - if middle_round is not None: + if fully_automatic: + initial_round = 1 + final_round = number_of_rounds + forward_components = self._cipher.get_all_components() + backward_components = self.inverse_cipher.get_all_components() + else: forward_components = [] for r in range(initial_round - 1, middle_round): forward_components.extend([component.id for component in self._cipher.get_components_in_round(r)]) backward_components = [] for r in range(number_of_rounds - final_round, number_of_rounds - middle_round + 1): backward_components.extend( - ['inverse_' + component.id for component in self.inverse_cipher.get_components_in_round(r)]) - else: - forward_components = [] - for r in range(initial_round - 1, final_round): - forward_components.extend([component.id for component in self._cipher.get_components_in_round(r)]) - backward_components = [] - for r in range(number_of_rounds - final_round, number_of_rounds - initial_round + 1): - backward_components.extend( - ['inverse_' + component.id for component in self.inverse_cipher.get_components_in_round(r)]) - key_components, key_ids = self.extract_key_schedule() - components_to_keep = forward_components + backward_components + key_ids + ['inverse_' + id_link for id_link in - key_ids] + ['array['] + [ - solve_satisfy] + ["inverse_" + component.id for component in self.inverse_cipher.get_components_in_round(r)] + ) + key_components, key_ids = self.extract_key_schedule() + components_to_keep = ( + forward_components + + backward_components + + key_ids + + ["inverse_" + id_link for id_link in key_ids] + + ["array["] + + [SOLVE_SATISFY] + ) if initial_round == 1 and final_round == self._cipher.number_of_rounds: for i in range(len(set_of_constraints) - 1): - if set_of_constraints[i] not in set_of_constraints[i + 1:]: + if set_of_constraints[i] not in set_of_constraints[i + 1 :]: model_constraints.append(set_of_constraints[i]) model_constraints.append(set_of_constraints[-1]) return model_constraints if initial_round == 1: components_to_keep.extend(self._cipher.inputs) if final_round == number_of_rounds: - components_to_keep.extend(['inverse_' + id_link for id_link in self.inverse_cipher.inputs]) + components_to_keep.extend(["inverse_" + id_link for id_link in self.inverse_cipher.inputs]) if initial_round > 1: for component in self._cipher.get_components_in_round(initial_round - 2): - if 'output' in component.id: + if "output" in component.id: components_to_keep.append(component.id) input_component = component for constraint in set_of_constraints: @@ -310,47 +335,52 @@ def clean_constraints(self, set_of_constraints, initial_round, middle_round, fin return model_constraints - def clean_inverse_impossible_variables_constraints(self, backward_components, inverse_variables, - inverse_constraints): + def clean_inverse_impossible_variables_constraints( + self, backward_components, inverse_variables, inverse_constraints + ): for component in backward_components: - inverse_variables, inverse_constraints = self.set_inverse_component_id_in_constraints(component, - inverse_variables, - inverse_constraints) - inverse_variables, inverse_constraints = self.clean_repetitions_in_constraints(inverse_variables, - inverse_constraints) + inverse_variables, inverse_constraints = self.set_inverse_component_id_in_constraints( + component, inverse_variables, inverse_constraints + ) + inverse_variables, inverse_constraints = self.clean_repetitions_in_constraints( + inverse_variables, inverse_constraints + ) return inverse_variables, inverse_constraints - def clean_inverse_impossible_variables_constraints_with_extensions(self, backward_components, inverse_variables, - inverse_constraints): + def clean_inverse_impossible_variables_constraints_with_extensions( + self, backward_components, inverse_variables, inverse_constraints + ): key_components, key_ids = self.extract_key_schedule() constant_components, constant_ids = self.extract_constants() for component in backward_components: if component.id not in key_ids + constant_ids: - inverse_variables, inverse_constraints = self.set_inverse_component_id_in_constraints(component, - inverse_variables, - inverse_constraints) - inverse_variables, inverse_constraints = self.clean_repetitions_in_constraints(inverse_variables, - inverse_constraints) + inverse_variables, inverse_constraints = self.set_inverse_component_id_in_constraints( + component, inverse_variables, inverse_constraints + ) + inverse_variables, inverse_constraints = self.clean_repetitions_in_constraints( + inverse_variables, inverse_constraints + ) return inverse_variables, inverse_constraints def clean_repetitions_in_constraints(self, inverse_variables, inverse_constraints): for c in range(len(inverse_constraints)): start = 0 - while 'cipher_output' in inverse_constraints[c][start:]: - new_start = inverse_constraints[c].index('cipher_output', start) - inverse_constraints[c] = inverse_constraints[c][:new_start] + 'inverse_' + inverse_constraints[c][ - new_start:] + while "cipher_output" in inverse_constraints[c][start:]: + new_start = inverse_constraints[c].index("cipher_output", start) + inverse_constraints[c] = ( + inverse_constraints[c][:new_start] + "inverse_" + inverse_constraints[c][new_start:] + ) start = new_start + 9 start = 0 - while 'inverse_inverse_' in inverse_constraints[c][start:]: - new_start = inverse_constraints[c].index('inverse_inverse_', start) - inverse_constraints[c] = inverse_constraints[c][:new_start] + inverse_constraints[c][new_start + 8:] + while "inverse_inverse_" in inverse_constraints[c][start:]: + new_start = inverse_constraints[c].index("inverse_inverse_", start) + inverse_constraints[c] = inverse_constraints[c][:new_start] + inverse_constraints[c][new_start + 8 :] start = new_start for v in range(len(inverse_variables)): start = 0 - while 'inverse_inverse_' in inverse_variables[v][start:]: - new_start = inverse_variables[v].index('inverse_inverse_', start) - inverse_variables[v] = inverse_variables[v][:new_start] + inverse_variables[v][new_start + 8:] + while "inverse_inverse_" in inverse_variables[v][start:]: + new_start = inverse_variables[v].index("inverse_inverse_", start) + inverse_variables[v] = inverse_variables[v][:new_start] + inverse_variables[v][new_start + 8 :] start = new_start return inverse_variables, inverse_constraints @@ -368,14 +398,14 @@ def extract_constants(self): constant_components_ids = [] constant_components = [] for component in cipher.get_all_components(): - if 'constant' in component.id: + if "constant" in component.id: constant_components_ids.append(component.id) constant_components.append(component) - elif '_' in component.id: + elif "_" in component.id: component_inputs = component.input_id_links ks = True for comp_input in component_inputs: - if 'constant' not in comp_input: + if "constant" not in comp_input: ks = False if ks: constant_components_ids.append(component.id) @@ -385,20 +415,20 @@ def extract_constants(self): def extract_key_schedule(self): cipher = self._cipher - key_schedule_components_ids = ['key'] + key_schedule_components_ids = ["key"] key_schedule_components = [] for component in cipher.get_all_components(): component_inputs = component.input_id_links ks = True for comp_input in component_inputs: - if 'constant' not in comp_input and comp_input not in key_schedule_components_ids: + if "constant" not in comp_input and comp_input not in key_schedule_components_ids: ks = False if ks: key_schedule_components_ids.append(component.id) key_schedule_components.append(component) master_key_bits = [] for id_link, bit_positions in zip(component_inputs, component.input_bit_positions): - if id_link == 'key': + if id_link == "key": master_key_bits.extend(bit_positions) else: if id_link in self.key_schedule_bits_distribution: @@ -407,8 +437,9 @@ def extract_key_schedule(self): return key_schedule_components, key_schedule_components_ids - def final_impossible_constraints_with_extensions(self, number_of_rounds, initial_round, middle_round, final_round, - intermediate_components): + def final_impossible_constraints_with_extensions( + self, number_of_rounds, initial_round, middle_round, final_round, intermediate_components + ): """ Constraints for output and incompatibility. @@ -436,23 +467,33 @@ def final_impossible_constraints_with_extensions(self, number_of_rounds, initial cipher_inputs = self._cipher.inputs cipher = self._cipher inverse_cipher = self.inverse_cipher - cp_constraints = [solve_satisfy] - new_constraint = 'output[' - incompatibility_constraint = 'constraint ' + cp_constraints = [SOLVE_SATISFY] + new_constraint = "output[" + incompatibility_constraint = "constraint " for element in cipher_inputs: - new_constraint = f'{new_constraint}\"{element} = \"++ show({element}) ++ \"\\n\" ++' + new_constraint = f'{new_constraint}"{element} = "++ show({element}) ++ "\\n" ++' for element in cipher_inputs: if element not in key_schedule_components_ids: - new_constraint = f'{new_constraint}\"inverse_{element} = \"++ show(inverse_{element}) ++ \"\\n\" ++' + new_constraint = f'{new_constraint}"inverse_{element} = "++ show(inverse_{element}) ++ "\\n" ++' for id_link in self._cipher.get_all_components_ids(): - if id_link not in key_schedule_components_ids and self.get_component_round(id_link) in list( - range(initial_round - 1, middle_round)) + list( - range(final_round, number_of_rounds)) and 'constant' not in id_link and 'output' in id_link: - new_constraint = new_constraint + f'\"{id_link} = \"++ show({id_link})++ \"\\n\" ++ \"0\" ++ \"\\n\" ++' - if id_link not in key_schedule_components_ids and self.get_component_round(id_link) in list( - range(initial_round - 1)) + list( - range(middle_round - 1, final_round)) and 'constant' not in id_link and 'output' in id_link: - new_constraint = new_constraint + f'\"inverse_{id_link} = \"++ show(inverse_{id_link})++ \"\\n\" ++ \"0\" ++ \"\\n\" ++' + if ( + id_link not in key_schedule_components_ids + and self.get_component_round(id_link) + in list(range(initial_round - 1, middle_round)) + list(range(final_round, number_of_rounds)) + and "constant" not in id_link + and "output" in id_link + ): + new_constraint = new_constraint + f'"{id_link} = "++ show({id_link})++ "\\n" ++ "0" ++ "\\n" ++' + if ( + id_link not in key_schedule_components_ids + and self.get_component_round(id_link) + in list(range(initial_round - 1)) + list(range(middle_round - 1, final_round)) + and "constant" not in id_link + and "output" in id_link + ): + new_constraint = ( + new_constraint + f'"inverse_{id_link} = "++ show(inverse_{id_link})++ "\\n" ++ "0" ++ "\\n" ++' + ) if intermediate_components: for component in cipher.get_components_in_round(middle_round - 1): if component.type != CONSTANT and component.id not in key_schedule_components_ids: @@ -462,29 +503,33 @@ def final_impossible_constraints_with_extensions(self, number_of_rounds, initial component_inputs = [] input_bit_size = 0 for id_link, bit_positions in zip(input_id_links, input_bit_positions): - component_inputs.extend([f'{id_link}[{position}]' for position in bit_positions]) + component_inputs.extend([f"{id_link}[{position}]" for position in bit_positions]) input_bit_size += len(bit_positions) # new_constraint = new_constraint + \ # f'\"{id_link} = \"++ show({id_link})++ \"\\n\" ++ \"0\" ++ \"\\n\" ++' # new_constraint = new_constraint + \ # f'\"inverse_{component_id} = \"++ show(inverse_{component_id})++ \"\\n\" ++ \"0\" ++ \"\\n\" ++' for i in range(input_bit_size): - incompatibility_constraint += f'({component_inputs[i]}+inverse_{component_id}[{i}]=1) \\/ ' + incompatibility_constraint += f"({component_inputs[i]}+inverse_{component_id}[{i}]=1) \\/ " else: for component in cipher.get_components_in_round(middle_round - 1): - if 'output' in component.id and component.id not in key_schedule_components_ids: - new_constraint = new_constraint + \ - f'\"{component.id} = \"++ show({component.id})++ \"\\n\" ++ \"0\" ++ \"\\n\" ++' - new_constraint = new_constraint + \ - f'\"inverse_{component.id} = \"++ show(inverse_{component.id})++ \"\\n\" ++ \"0\" ++ \"\\n\" ++' + if "output" in component.id and component.id not in key_schedule_components_ids: + new_constraint = ( + new_constraint + f'"{component.id} = "++ show({component.id})++ "\\n" ++ "0" ++ "\\n" ++' + ) + new_constraint = ( + new_constraint + + f'"inverse_{component.id} = "++ show(inverse_{component.id})++ "\\n" ++ "0" ++ "\\n" ++' + ) for i in range(component.output_bit_size): - incompatibility_constraint += f'({component.id}[{i}]+inverse_{component.id}[{i}]=1) \\/ ' - cp_constraints.extend([incompatibility_constraint[:-4] + ';', new_constraint[:-2] + '];']) + incompatibility_constraint += f"({component.id}[{i}]+inverse_{component.id}[{i}]=1) \\/ " + cp_constraints.extend([incompatibility_constraint[:-4] + ";", new_constraint[:-2] + "];"]) return cp_constraints - def final_impossible_constraints(self, number_of_rounds, initial_round, middle_round, final_round, - intermediate_components): + def final_impossible_constraints( + self, number_of_rounds, initial_round, middle_round, final_round, intermediate_components, fully_automatic + ): """ Constraints for output and incompatibility. @@ -503,87 +548,163 @@ def final_impossible_constraints(self, number_of_rounds, initial_round, middle_r sage: from claasp.cipher_modules.models.utils import set_fixed_variables, integer_to_bit_list sage: speck = SpeckBlockCipher(block_bit_size=32, key_bit_size=64, number_of_rounds=5) sage: cp = MznImpossibleXorDifferentialModel(speck) - sage: cp.final_impossible_constraints(3, 2, 3, 4, False) + sage: cp.final_impossible_constraints(3, 2, 3, 4, False, False) ['solve satisfy;', ... 'output["key = "++ show(key) ++ "\\n" ++"intermediate_output_0_5 = "++ show(intermediate_output_0_5) ++ "\\n" ++"intermediate_output_0_6 = "++ show(intermediate_output_0_6) ++ "\\n" ++"inverse_intermediate_output_3_12 = "++ show(inverse_intermediate_output_3_12) ++ "\\n" ++ "0" ++ "\\n" ++"intermediate_output_0_6 = "++ show(intermediate_output_0_6)++ "\\n" ++ "0" ++ "\\n" ++"intermediate_output_1_12 = "++ show(intermediate_output_1_12)++ "\\n" ++ "0" ++ "\\n" ++"intermediate_output_2_12 = "++ show(intermediate_output_2_12)++ "\\n" ++ "0" ++ "\\n" ++"inverse_intermediate_output_2_12 = "++ show(inverse_intermediate_output_2_12)++ "\\n" ++ "0" ++ "\\n" ++"inverse_intermediate_output_3_12 = "++ show(inverse_intermediate_output_3_12)++ "\\n" ++ "0" ++ "\\n" ++"inverse_cipher_output_4_12 = "++ show(inverse_cipher_output_4_12)++ "\\n" ++ "0" ++ "\\n" ];'] """ + + def show_constraints_intermediate_components( + component_list, key_schedule_components_ids, current_constraint, current_incompatibility_constraint + ): + for component in component_list: + if component.type != CONSTANT and component.id not in key_schedule_components_ids: + component_id = component.id + input_id_links = component.input_id_links + input_bit_positions = component.input_bit_positions + component_inputs = [] + input_bit_size = 0 + for id_link, bit_positions in zip(input_id_links, input_bit_positions): + component_inputs.extend([f"{id_link}[{position}]" for position in bit_positions]) + input_bit_size += len(bit_positions) + current_constraint = ( + current_constraint + f'"{id_link} = "++ show({id_link})++ "\\n" ++ "0" ++ "\\n" ++' + ) + current_constraint = ( + current_constraint + + f'"inverse_{component_id} = "++ show(inverse_{component_id})++ "\\n" ++ "0" ++ "\\n" ++' + ) + for i in range(input_bit_size): + current_incompatibility_constraint += ( + f"({component_inputs[i]}+inverse_{component_id}[{i}]=1) \\/ " + ) + return current_constraint, current_incompatibility_constraint + + def show_constraints(key_schedule_components_ids, current_constraint, current_incompatibility_constraint): + for component in self._cipher.get_all_components(): + if "output" in component.id and component.id not in key_schedule_components_ids: + current_constraint = ( + current_constraint + f'"{component.id} = "++ show({component.id})++ "\\n" ++ "0" ++ "\\n" ++' + ) + current_constraint = ( + current_constraint + + f'"inverse_{component.id} = "++ show(inverse_{component.id})++ "\\n" ++ "0" ++ "\\n" ++' + ) + for i in range(component.output_bit_size): + current_incompatibility_constraint += ( + f"({component.id}[{i}]+inverse_{component.id}[{i}]=1) \\/ " + ) + return current_constraint, current_incompatibility_constraint + + if number_of_rounds is None: + number_of_rounds = self._cipher.number_of_rounds + if fully_automatic: + number_of_rounds = self._cipher.number_of_rounds + initial_round = 1 + final_round = self._cipher.number_of_rounds if initial_round == 1: cipher_inputs = self._cipher.inputs else: - cipher_inputs = ['key'] + cipher_inputs = ["key"] for component in self._cipher.get_components_in_round(initial_round - 2): - if 'output' in component.id: + if "output" in component.id: cipher_inputs.append(component.id) cipher = self._cipher inverse_cipher = self.inverse_cipher if final_round == self._cipher.number_of_rounds: cipher_outputs = inverse_cipher.inputs else: - cipher_outputs = ['key'] + cipher_outputs = ["key"] for component in self.inverse_cipher.get_components_in_round(self._cipher.number_of_rounds - final_round): - if 'output' in component.id: + if "output" in component.id: cipher_outputs.append(component.id) - cp_constraints = [solve_satisfy] - new_constraint = 'output[' - incompatibility_constraint = 'constraint' + cp_constraints = [SOLVE_SATISFY] + new_constraint = "output[" + incompatibility_constraint = "constraint" key_schedule_components, key_schedule_components_ids = self.extract_key_schedule() - for element in cipher_inputs: - new_constraint = f'{new_constraint}\"{element} = \"++ show({element}) ++ \"\\n\" ++' - for element in cipher_outputs: - if element != 'key': - new_constraint = f'{new_constraint}\"inverse_{element} = \"++ show(inverse_{element}) ++ \"\\n\" ++ \"0\" ++ \"\\n\" ++' - if intermediate_components: - if middle_round is not None: - component_list = cipher.get_components_in_round(middle_round - 1) - else: + if fully_automatic: + for element in cipher_inputs: + new_constraint = f'{new_constraint}"{element} = "++ show({element}) ++ "\\n" ++' + for element in cipher_outputs: + if element != "key": + new_constraint = ( + f'{new_constraint}"inverse_{element} = "++ show(inverse_{element}) ++ "\\n" ++ "0" ++ "\\n" ++' + ) + if intermediate_components: component_list = cipher.get_all_components() - for component in component_list: - if component.type != CONSTANT and component.id not in key_schedule_components_ids: - component_id = component.id - input_id_links = component.input_id_links - input_bit_positions = component.input_bit_positions - component_inputs = [] - input_bit_size = 0 - for id_link, bit_positions in zip(input_id_links, input_bit_positions): - component_inputs.extend([f'{id_link}[{position}]' for position in bit_positions]) - input_bit_size += len(bit_positions) - new_constraint = new_constraint + \ - f'\"{id_link} = \"++ show({id_link})++ \"\\n\" ++ \"0\" ++ \"\\n\" ++' - new_constraint = new_constraint + \ - f'\"inverse_{component_id} = \"++ show(inverse_{component_id})++ \"\\n\" ++ \"0\" ++ \"\\n\" ++' - for i in range(input_bit_size): - incompatibility_constraint += f'({component_inputs[i]}+inverse_{component_id}[{i}]=1) \\/ ' + new_constraint, incompatibility_constraint = show_constraints_intermediate_components( + component_list, key_schedule_components_ids, new_constraint, incompatibility_constraint + ) + else: + new_constraint, incompatibility_constraint = show_constraints( + key_schedule_components_ids, new_constraint, incompatibility_constraint + ) else: - if middle_round is not None: + for element in cipher_inputs: + new_constraint = f'{new_constraint}"{element} = "++ show({element}) ++ "\\n" ++' + for element in cipher_outputs: + if element != "key": + new_constraint = ( + f'{new_constraint}"inverse_{element} = "++ show(inverse_{element}) ++ "\\n" ++ "0" ++ "\\n" ++' + ) + if intermediate_components: + component_list = cipher.get_components_in_round(middle_round - 1) + new_constraint, incompatibility_constraint = show_constraints_intermediate_components( + component_list, key_schedule_components_ids, new_constraint, incompatibility_constraint + ) + else: for component in cipher.get_all_components(): - if 'output' in component.id and component.id not in key_schedule_components_ids: + if "output" in component.id and component.id not in key_schedule_components_ids: if self.get_component_round(component.id) <= middle_round - 1: - new_constraint = new_constraint + \ - f'\"{component.id} = \"++ show({component.id})++ \"\\n\" ++ \"0\" ++ \"\\n\" ++' + new_constraint = ( + new_constraint + + f'"{component.id} = "++ show({component.id})++ "\\n" ++ "0" ++ "\\n" ++' + ) if self.get_component_round(component.id) >= middle_round - 1: - new_constraint = new_constraint + \ - f'\"inverse_{component.id} = \"++ show(inverse_{component.id})++ \"\\n\" ++ \"0\" ++ \"\\n\" ++' + new_constraint = ( + new_constraint + + f'"inverse_{component.id} = "++ show(inverse_{component.id})++ "\\n" ++ "0" ++ "\\n" ++' + ) if self.get_component_round(component.id) == middle_round - 1: for i in range(component.output_bit_size): - incompatibility_constraint += f'({component.id}[{i}]+inverse_{component.id}[{i}]=1) \\/ ' - else: - for component in cipher.get_all_components(): - if 'output' in component.id and component.id not in key_schedule_components_ids: - new_constraint = new_constraint + \ - f'\"{component.id} = \"++ show({component.id})++ \"\\n\" ++ \"0\" ++ \"\\n\" ++' - new_constraint = new_constraint + \ - f'\"inverse_{component.id} = \"++ show(inverse_{component.id})++ \"\\n\" ++ \"0\" ++ \"\\n\" ++' - for i in range(component.output_bit_size): - incompatibility_constraint += f'({component.id}[{i}]+inverse_{component.id}[{i}]=1) \\/ ' - cp_constraints.extend([incompatibility_constraint[:-4] + ';', new_constraint[:-2] + '];']) + incompatibility_constraint += ( + f"({component.id}[{i}]+inverse_{component.id}[{i}]=1) \\/ " + ) + cp_constraints.extend([incompatibility_constraint[:-4] + ";", new_constraint[:-2] + "];"]) return cp_constraints - def find_all_impossible_xor_differential_trails(self, number_of_rounds=None, fixed_values=[], solver_name='Chuffed', - initial_round=1, middle_round=None, final_round=None, - intermediate_components=True, num_of_processors=None, - timelimit=None, solve_with_API=False, solve_external=True): + def validate_input_rounds(self, initial_round, middle_round, final_round, number_of_rounds): + if initial_round < 1: + raise ValueError("Initial round must be at least 1.") + if final_round is None: + final_round = self._cipher.number_of_rounds + if middle_round is None: + middle_round = (final_round + initial_round) // 2 + if middle_round is not None and middle_round < initial_round: + raise ValueError("Middle round must be greater than or equal to initial round.") + if final_round is not None and final_round < (middle_round if middle_round is not None else initial_round): + raise ValueError("Final round must be greater than or equal to middle round.") + if number_of_rounds is None: + number_of_rounds = final_round - initial_round + 1 + if number_of_rounds != (final_round - initial_round + 1): + raise ValueError("Number of rounds is inconsistent with initial, middle, and final rounds.") + return initial_round, middle_round, final_round, number_of_rounds + + def find_all_impossible_xor_differential_trails( + self, + number_of_rounds=None, + fixed_values=[], + solver_name=SOLVER_DEFAULT, + initial_round=1, + middle_round=None, + final_round=None, + intermediate_components=True, + num_of_processors=None, + timelimit=None, + solve_with_API=False, + solve_external=True, + ): """ Search for all impossible XOR differential trails of a cipher. @@ -609,25 +730,50 @@ def find_all_impossible_xor_differential_trails(self, number_of_rounds=None, fix sage: fixed_variables = [set_fixed_variables('key', 'equal', range(64), integer_to_bit_list(0, 64, 'little'))] sage: fixed_variables.append(set_fixed_variables('plaintext', 'not_equal', range(32), integer_to_bit_list(0, 32, 'little'))) sage: fixed_variables.append(set_fixed_variables('inverse_cipher_output_3_12', 'not_equal', range(32), integer_to_bit_list(0, 32, 'little'))) - sage: trail = cp.find_all_impossible_xor_differential_trails(4, fixed_variables, 'Chuffed', 1, 3, 4, False) #doctest: +SKIP + sage: trail = cp.find_all_impossible_xor_differential_trails(4, fixed_variables, 'chuffed', 1, 3, 4, False) #doctest: +SKIP """ - self.build_impossible_xor_differential_trail_model(fixed_values, number_of_rounds, initial_round, middle_round, - final_round, intermediate_components) + initial_round, middle_round, final_round, number_of_rounds = self.validate_input_rounds( + initial_round, middle_round, final_round, number_of_rounds + ) + self.build_impossible_xor_differential_trail_model( + fixed_values, number_of_rounds, initial_round, middle_round, final_round, intermediate_components + ) if solve_with_API: - return self.solve_for_ARX(solver_name=solver_name, timeout_in_seconds_=timelimit, - processes_=num_of_processors, all_solutions_=True) - return self.solve(IMPOSSIBLE_XOR_DIFFERENTIAL, solver_name=solver_name, number_of_rounds=number_of_rounds, - initial_round=initial_round, middle_round=middle_round, final_round=final_round, - timeout_in_seconds_=timelimit, processes_=num_of_processors, all_solutions_=True, - solve_external=solve_external) - - def find_lowest_complexity_impossible_xor_differential_trail(self, number_of_rounds=None, fixed_values=[], - solver_name='Chuffed', initial_round=1, middle_round=None, - final_round=None, intermediate_components=True, - num_of_processors=None, timelimit=None, - solve_with_API=False, solve_external=True): + return self.solve_for_ARX( + solver_name=solver_name, + timeout_in_seconds_=timelimit, + processes_=num_of_processors, + all_solutions_=True, + ) + return self.solve( + IMPOSSIBLE_XOR_DIFFERENTIAL, + solver_name=solver_name, + number_of_rounds=number_of_rounds, + initial_round=initial_round, + middle_round=middle_round, + final_round=final_round, + timeout_in_seconds_=timelimit, + processes_=num_of_processors, + all_solutions_=True, + solve_external=solve_external, + ) + + def find_lowest_complexity_impossible_xor_differential_trail( + self, + number_of_rounds=None, + fixed_values=[], + solver_name=SOLVER_DEFAULT, + initial_round=1, + middle_round=None, + final_round=None, + intermediate_components=True, + num_of_processors=None, + timelimit=None, + solve_with_API=False, + solve_external=True, + ): """ Search for the impossible XOR differential trail of a cipher with the highest number of known bits in plaintext and ciphertext difference. @@ -653,26 +799,49 @@ def find_lowest_complexity_impossible_xor_differential_trail(self, number_of_rou sage: fixed_variables = [set_fixed_variables('key', 'equal', range(64), integer_to_bit_list(0, 64, 'little'))] sage: fixed_variables.append(set_fixed_variables('plaintext', 'not_equal', range(32), integer_to_bit_list(0, 32, 'little'))) sage: fixed_variables.append(set_fixed_variables('inverse_cipher_output_3_12', 'not_equal', range(32), integer_to_bit_list(0, 32, 'little'))) - sage: trail = cp.find_lowest_complexity_impossible_xor_differential_trail(4, fixed_variables, 'Chuffed', 1, 3, 4, intermediate_components = False) + sage: trail = cp.find_lowest_complexity_impossible_xor_differential_trail(4, fixed_variables, 'chuffed', 1, 3, 4, intermediate_components = False) """ - self.build_impossible_xor_differential_trail_model(fixed_values, number_of_rounds, initial_round, middle_round, - final_round, intermediate_components) - self._model_constraints.remove(f'solve satisfy;') + initial_round, middle_round, final_round, number_of_rounds = self.validate_input_rounds( + initial_round, middle_round, final_round, number_of_rounds + ) + self.build_impossible_xor_differential_trail_model( + fixed_values, number_of_rounds, initial_round, middle_round, final_round, intermediate_components + ) + self._model_constraints.remove("solve satisfy;") self._model_constraints.append( - f'solve minimize count(plaintext, 2) + count(inverse_{self._cipher.get_all_components_ids()[-1]}, 2);') + f"solve minimize count(plaintext, 2) + count(inverse_{self._cipher.get_all_components_ids()[-1]}, 2);" + ) if solve_with_API: - return self.solve_for_ARX(solver_name=solver_name, timeout_in_seconds_=timelimit, - processes_=num_of_processors) - return self.solve('impossible_xor_differential_one_solution', solver_name=solver_name, - number_of_rounds=number_of_rounds, initial_round=initial_round, middle_round=middle_round, - final_round=final_round, timeout_in_seconds_=timelimit, processes_=num_of_processors, - solve_external=solve_external) - - def find_one_impossible_xor_differential_cluster(self, number_of_rounds=None, fixed_values=[], solver_name='Chuffed', - initial_round=1, middle_round=None, final_round=None, - intermediate_components=True, num_of_processors=None, - timelimit=None, solve_with_API=False, solve_external=True): + return self.solve_for_ARX( + solver_name=solver_name, timeout_in_seconds_=timelimit, processes_=num_of_processors + ) + return self.solve( + "impossible_xor_differential_one_solution", + solver_name=solver_name, + number_of_rounds=number_of_rounds, + initial_round=initial_round, + middle_round=middle_round, + final_round=final_round, + timeout_in_seconds_=timelimit, + processes_=num_of_processors, + solve_external=solve_external, + ) + + def find_one_impossible_xor_differential_cluster( + self, + number_of_rounds=None, + fixed_values=[], + solver_name=SOLVER_DEFAULT, + initial_round=1, + middle_round=None, + final_round=None, + intermediate_components=True, + num_of_processors=None, + timelimit=None, + solve_with_API=False, + solve_external=True, + ): """ Search for the impossible XOR differential trail of a cipher with the highest number of unknown bits in plaintext and ciphertext difference. @@ -698,27 +867,49 @@ def find_one_impossible_xor_differential_cluster(self, number_of_rounds=None, fi sage: fixed_variables = [set_fixed_variables('key', 'equal', range(64), integer_to_bit_list(0, 64, 'little'))] sage: fixed_variables.append(set_fixed_variables('plaintext', 'not_equal', range(32), integer_to_bit_list(0, 32, 'little'))) sage: fixed_variables.append(set_fixed_variables('inverse_cipher_output_3_12', 'not_equal', range(32), integer_to_bit_list(0, 32, 'little'))) - sage: trail = cp.find_one_impossible_xor_differential_cluster(4, fixed_variables, 'Chuffed', 1, 3, 4, intermediate_components = False) + sage: trail = cp.find_one_impossible_xor_differential_cluster(4, fixed_variables, 'chuffed', 1, 3, 4, intermediate_components = False) """ - self.build_impossible_xor_differential_trail_model(fixed_values, number_of_rounds, initial_round, middle_round, - final_round, intermediate_components) - self._model_constraints.remove(f'solve satisfy;') + initial_round, middle_round, final_round, number_of_rounds = self.validate_input_rounds( + initial_round, middle_round, final_round, number_of_rounds + ) + self.build_impossible_xor_differential_trail_model( + fixed_values, number_of_rounds, initial_round, middle_round, final_round, intermediate_components + ) + self._model_constraints.remove("solve satisfy;") self._model_constraints.append( - f'solve maximize count(plaintext, 2) + count(inverse_{self._cipher.get_all_components_ids()[-1]}, 2);') + f"solve maximize count(plaintext, 2) + count(inverse_{self._cipher.get_all_components_ids()[-1]}, 2);" + ) if solve_with_API: - return self.solve_for_ARX(solver_name=solver_name, timeout_in_seconds_=timelimit, - processes_=num_of_processors) - return self.solve('impossible_xor_differential_one_solution', solver_name=solver_name, - number_of_rounds=number_of_rounds, initial_round=initial_round, middle_round=middle_round, - final_round=final_round, timeout_in_seconds_=timelimit, processes_=num_of_processors, - solve_external=solve_external) - - def find_one_impossible_xor_differential_trail_with_extensions(self, number_of_rounds=None, fixed_values=[], - solver_name='Chuffed', initial_round=1, middle_round=None, - final_round=None, intermediate_components=True, - num_of_processors=None, timelimit=None, - solve_with_API=False, solve_external=True): + return self.solve_for_ARX( + solver_name=solver_name, timeout_in_seconds_=timelimit, processes_=num_of_processors + ) + return self.solve( + "impossible_xor_differential_one_solution", + solver_name=solver_name, + number_of_rounds=number_of_rounds, + initial_round=initial_round, + middle_round=middle_round, + final_round=final_round, + timeout_in_seconds_=timelimit, + processes_=num_of_processors, + solve_external=solve_external, + ) + + def find_one_impossible_xor_differential_trail_with_extensions( + self, + number_of_rounds=None, + fixed_values=[], + solver_name=SOLVER_DEFAULT, + initial_round=1, + middle_round=None, + final_round=None, + intermediate_components=True, + num_of_processors=None, + timelimit=None, + solve_with_API=False, + solve_external=True, + ): """ Search for one impossible XOR differential trail of a cipher with forward and backward deterministic extensions for key recovery. @@ -744,24 +935,102 @@ def find_one_impossible_xor_differential_trail_with_extensions(self, number_of_r sage: fixed_variables = [set_fixed_variables('key', 'equal', range(64), integer_to_bit_list(0, 64, 'little'))] sage: fixed_variables.append(set_fixed_variables('plaintext', 'not_equal', range(32), integer_to_bit_list(0, 32, 'little'))) sage: fixed_variables.append(set_fixed_variables('cipher_output_6_12', 'not_equal', range(32), integer_to_bit_list(0, 32, 'little'))) - sage: trail = cp.find_one_impossible_xor_differential_trail_with_extensions(7, fixed_variables, 'Chuffed', 2, 4, 6, intermediate_components = False) + sage: trail = cp.find_one_impossible_xor_differential_trail_with_extensions(7, fixed_variables, 'chuffed', 2, 4, 6, intermediate_components = False) + """ + self.build_impossible_xor_differential_trail_with_extensions_model( + fixed_values, number_of_rounds, initial_round, middle_round, final_round, intermediate_components + ) + + if solve_with_API: + return self.solve_for_ARX( + solver_name=solver_name, timeout_in_seconds_=timelimit, processes_=num_of_processors + ) + return self.solve( + "impossible_xor_differential_one_solution", + solver_name=solver_name, + number_of_rounds=number_of_rounds, + initial_round=initial_round, + middle_round=middle_round, + final_round=final_round, + timeout_in_seconds_=timelimit, + processes_=num_of_processors, + solve_external=solve_external, + ) + + def find_one_impossible_xor_differential_trail( + self, + number_of_rounds=None, + fixed_values=[], + solver_name=SOLVER_DEFAULT, + initial_round=1, + middle_round=None, + final_round=None, + intermediate_components=True, + num_of_processors=None, + timelimit=None, + solve_with_API=False, + solve_external=True, + ): + """ + Search for one impossible XOR differential trail of a cipher. + + INPUT: + + - ``number_of_rounds`` -- **integer** (default: `None`); number of rounds + - ``fixed_values`` -- **list** (default: `[]`); dictionaries containing the variables to be fixed in standard + format + - ``initial_round`` -- **integer** (default: `1`); initial round of the impossible differential + - ``middle_round`` -- **integer** (default: `1`); incosistency round of the impossible differential + - ``final_round`` -- **integer** (default: `None`); final round of the impossible differential + - ``intermediate_components`` -- **Boolean** (default: `True`); check inconsistency on intermediate components of the inconsistency round or only on outputs + - ``num_of_processors`` -- **Integer** (default: `None`); number of processors used for MiniZinc search + - ``timelimit`` -- **Integer** (default: `None`); time limit of MiniZinc search + + EXAMPLES:: + + sage: from claasp.cipher_modules.models.cp.mzn_models.mzn_impossible_xor_differential_model import MznImpossibleXorDifferentialModel + sage: from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher + sage: from claasp.cipher_modules.models.utils import set_fixed_variables, integer_to_bit_list + sage: speck = SpeckBlockCipher(block_bit_size=32, key_bit_size=64, number_of_rounds=4) + sage: cp = MznImpossibleXorDifferentialModel(speck) + sage: fixed_variables = [set_fixed_variables('key', 'equal', range(64), integer_to_bit_list(0, 64, 'little'))] + sage: fixed_variables.append(set_fixed_variables('plaintext', 'not_equal', range(32), integer_to_bit_list(0, 32, 'little'))) + sage: fixed_variables.append(set_fixed_variables('inverse_cipher_output_3_12', 'not_equal', range(32), integer_to_bit_list(0, 32, 'little'))) + sage: trail = cp.find_one_impossible_xor_differential_trail(4, fixed_variables, 'chuffed', 1, 3, 4, intermediate_components = False) """ - self.build_impossible_xor_differential_trail_with_extensions_model(fixed_values, number_of_rounds, - initial_round, middle_round, final_round, - intermediate_components) + initial_round, middle_round, final_round, number_of_rounds = self.validate_input_rounds( + initial_round, middle_round, final_round, number_of_rounds + ) + self.build_impossible_xor_differential_trail_model( + fixed_values, number_of_rounds, initial_round, middle_round, final_round, intermediate_components + ) if solve_with_API: - return self.solve_for_ARX(solver_name=solver_name, timeout_in_seconds_=timelimit, - processes_=num_of_processors) - return self.solve('impossible_xor_differential_one_solution', solver_name=solver_name, - number_of_rounds=number_of_rounds, initial_round=initial_round, middle_round=middle_round, - final_round=final_round, timeout_in_seconds_=timelimit, processes_=num_of_processors, - solve_external=solve_external) - - def find_one_impossible_xor_differential_trail(self, number_of_rounds=None, fixed_values=[], solver_name='Chuffed', - initial_round=1, middle_round=None, final_round=None, - intermediate_components=True, num_of_processors=None, timelimit=None, - solve_with_API=False, solve_external=True): + return self.solve_for_ARX( + solver_name=solver_name, timeout_in_seconds_=timelimit, processes_=num_of_processors + ) + return self.solve( + "impossible_xor_differential_one_solution", + solver_name=solver_name, + number_of_rounds=number_of_rounds, + initial_round=initial_round, + middle_round=middle_round, + final_round=final_round, + timeout_in_seconds_=timelimit, + processes_=num_of_processors, + solve_external=solve_external, + ) + + def find_one_impossible_xor_differential_trail_with_fully_automatic_model( + self, + fixed_values=[], + solver_name=SOLVER_DEFAULT, + intermediate_components=True, + num_of_processors=None, + timelimit=None, + solve_with_API=False, + solve_external=True, + ): """ Search for one impossible XOR differential trail of a cipher. @@ -787,20 +1056,26 @@ def find_one_impossible_xor_differential_trail(self, number_of_rounds=None, fixe sage: fixed_variables = [set_fixed_variables('key', 'equal', range(64), integer_to_bit_list(0, 64, 'little'))] sage: fixed_variables.append(set_fixed_variables('plaintext', 'not_equal', range(32), integer_to_bit_list(0, 32, 'little'))) sage: fixed_variables.append(set_fixed_variables('inverse_cipher_output_3_12', 'not_equal', range(32), integer_to_bit_list(0, 32, 'little'))) - sage: trail = cp.find_one_impossible_xor_differential_trail(4, fixed_variables, 'Chuffed', 1, 3, 4, intermediate_components = False) + sage: trail = cp.find_one_impossible_xor_differential_trail_with_fully_automatic_model(fixed_variables, 'chuffed', intermediate_components = False) """ - self.build_impossible_xor_differential_trail_model(fixed_values, number_of_rounds, initial_round, middle_round, - final_round, intermediate_components) + self.build_impossible_xor_differential_trail_model( + fixed_variables=fixed_values, intermediate_components=intermediate_components, fully_automatic=True + ) if solve_with_API: - return self.solve_for_ARX(solver_name=solver_name, timeout_in_seconds_=timelimit, - processes_=num_of_processors) - return self.solve('impossible_xor_differential_one_solution', solver_name=solver_name, - number_of_rounds=number_of_rounds, initial_round=initial_round, middle_round=middle_round, - final_round=final_round, timeout_in_seconds_=timelimit, processes_=num_of_processors, - solve_external=solve_external) - - def fix_variables_value_constraints(self, fixed_variables=[], step='full_model'): + return self.solve_for_ARX( + solver_name=solver_name, timeout_in_seconds_=timelimit, processes_=num_of_processors + ) + return self.solve( + "impossible_xor_differential_one_solution", + solver_name=solver_name, + timeout_in_seconds_=timelimit, + processes_=num_of_processors, + solve_external=solve_external, + fully_automatic=True, + ) + + def fix_variables_value_constraints(self, fixed_variables=[], step="full_model"): r""" Return a list of CP constraints that fix the input variables to a specific value. @@ -835,14 +1110,24 @@ def fix_variables_value_constraints(self, fixed_variables=[], step='full_model') """ if fixed_variables == []: fixed_variables.append( - set_fixed_variables('plaintext', 'not_equal', list(range(self._cipher.output_bit_size)), - [0] * self._cipher.output_bit_size)) + set_fixed_variables( + "plaintext", + "not_equal", + list(range(self._cipher.output_bit_size)), + [0] * self._cipher.output_bit_size, + ) + ) fixed_variables.append( - set_fixed_variables('inverse_' + self._cipher.get_all_components_ids()[-1], 'not_equal', - list(range(self._cipher.output_bit_size)), [0] * self._cipher.output_bit_size)) + set_fixed_variables( + "inverse_" + self._cipher.get_all_components_ids()[-1], + "not_equal", + list(range(self._cipher.output_bit_size)), + [0] * self._cipher.output_bit_size, + ) + ) for cipher_input, bit_size in zip(self._cipher._inputs, self._cipher._inputs_bit_size): - if cipher_input == 'key': - fixed_variables.append(set_fixed_variables('key', 'equal', list(range(bit_size)), [0] * bit_size)) + if cipher_input == "key": + fixed_variables.append(set_fixed_variables("key", "equal", list(range(bit_size)), [0] * bit_size)) return super().fix_variables_value_constraints(fixed_variables, step) @@ -853,11 +1138,11 @@ def get_component_from_id(self, id_link, curr_cipher): return None def get_component_round(self, id_link): - if '_' in id_link: - last_us = - id_link[::-1].index('_') - 1 - start = - id_link[last_us - 1::-1].index('_') + last_us + if "_" in id_link: + last_us = -id_link[::-1].index("_") - 1 + start = -id_link[last_us - 1 :: -1].index("_") + last_us - return int(id_link[start:len(id_link) + last_us]) + return int(id_link[start : len(id_link) + last_us]) else: return 0 @@ -867,7 +1152,6 @@ def get_direct_component_correspondance(self, forward_component): return inverse_component def get_inverse_component_correspondance(self, backward_component): - for component in self._cipher.get_all_components(): if backward_component.id == component.id: direct_inputs = component.input_id_links @@ -877,7 +1161,7 @@ def get_inverse_component_correspondance(self, backward_component): inverse_outputs.append(component.id) correspondance = [dir_i for dir_i in direct_inputs if dir_i in inverse_outputs] if len(correspondance) > 1: - return 'Not invertible' + return "Not invertible" else: return correspondance[0] @@ -905,15 +1189,12 @@ def get_state_key_bits_positions(self): return key_bits - def input_impossible_constraints_with_extensions(self, number_of_rounds=None, initial_round=None, middle_round=None, - final_round=None): - - if number_of_rounds is None: - number_of_rounds = self._cipher.number_of_rounds - + def input_impossible_constraints_with_extensions(self, number_of_rounds, initial_round, middle_round, final_round): cp_constraints = [] - cp_declarations = [f'array[0..{bit_size - 1}] of var 0..2: {input_};' - for input_, bit_size in zip(self._cipher.inputs, self._cipher.inputs_bit_size)] + cp_declarations = [ + f"array[0..{bit_size - 1}] of var 0..2: {input_};" + for input_, bit_size in zip(self._cipher.inputs, self._cipher.inputs_bit_size) + ] cipher = self._cipher inverse_cipher = self.inverse_cipher @@ -929,7 +1210,10 @@ def input_impossible_constraints_with_extensions(self, number_of_rounds=None, in for component in forward_components: for id_link in component.input_id_links: input_component = self.get_component_from_id(id_link, cipher) - if input_component not in key_ids + constant_ids + forward_components + components_to_direct and input_component != None: + if ( + input_component not in key_ids + constant_ids + forward_components + components_to_direct + and input_component != None + ): components_to_direct.append(input_component) forward_components.extend(components_to_direct) forward_components.extend(key_components) @@ -944,87 +1228,91 @@ def input_impossible_constraints_with_extensions(self, number_of_rounds=None, in for component in backward_components: for id_link in component.input_id_links: input_component = self.get_component_from_id(id_link, inverse_cipher) - if input_component not in key_ids + constant_ids + backward_components + components_to_invert and input_component != None: + if ( + input_component not in key_ids + constant_ids + backward_components + components_to_invert + and input_component != None + ): components_to_invert.append(input_component) backward_components.extend(components_to_invert) for component in forward_components: output_id_link = component.id output_size = int(component.output_bit_size) - if 'output' in component.type: - cp_declarations.append(f'array[0..{output_size - 1}] of var 0..2: {output_id_link};') - cp_constraints.append(f'constraint count({output_id_link},2) < {output_size};') + if "output" in component.type: + cp_declarations.append(f"array[0..{output_size - 1}] of var 0..2: {output_id_link};") + cp_constraints.append(f"constraint count({output_id_link},2) < {output_size};") elif CONSTANT not in component.type: - cp_declarations.append(f'array[0..{output_size - 1}] of var 0..2: {output_id_link};') + cp_declarations.append(f"array[0..{output_size - 1}] of var 0..2: {output_id_link};") for component in backward_components: output_id_link = component.id output_size = int(component.output_bit_size) - if 'output' in component.type: - cp_declarations.append(f'array[0..{output_size - 1}] of var 0..2: inverse_{output_id_link};') - cp_constraints.append(f'constraint count(inverse_{output_id_link},2) < {output_size};') - if self.get_component_round(component.id) == final_round - 1 or self.get_component_round( - component.id) == initial_round - 2: - cp_constraints.append(f'constraint count(inverse_{output_id_link},1) > 0;') + if "output" in component.type: + cp_declarations.append(f"array[0..{output_size - 1}] of var 0..2: inverse_{output_id_link};") + cp_constraints.append(f"constraint count(inverse_{output_id_link},2) < {output_size};") + if ( + self.get_component_round(component.id) == final_round - 1 + or self.get_component_round(component.id) == initial_round - 2 + ): + cp_constraints.append(f"constraint count(inverse_{output_id_link},1) > 0;") elif CONSTANT not in component.type: - cp_declarations.append(f'array[0..{output_size - 1}] of var 0..2: inverse_{output_id_link};') + cp_declarations.append(f"array[0..{output_size - 1}] of var 0..2: inverse_{output_id_link};") - cp_constraints.append(f'constraint count(plaintext,2) < {self._cipher.output_bit_size};') + cp_constraints.append(f"constraint count(plaintext,2) < {self._cipher.output_bit_size};") for component in self._cipher.get_all_components(): if CIPHER_OUTPUT in component.type: - cp_constraints.append(f'constraint count({component.id},2) < {self._cipher.output_bit_size};') + cp_constraints.append(f"constraint count({component.id},2) < {self._cipher.output_bit_size};") return cp_declarations, cp_constraints - def input_impossible_constraints(self, number_of_rounds=None, middle_round=None): - - if number_of_rounds is None: - number_of_rounds = self._cipher.number_of_rounds - + def input_impossible_constraints(self, number_of_rounds, middle_round, fully_automatic=False): cp_constraints = [] - cp_declarations = [f'array[0..{bit_size - 1}] of var 0..2: {input_};' - for input_, bit_size in zip(self._cipher.inputs, self._cipher.inputs_bit_size)] + cp_declarations = [ + f"array[0..{bit_size - 1}] of var 0..2: {input_};" + for input_, bit_size in zip(self._cipher.inputs, self._cipher.inputs_bit_size) + ] cipher = self._cipher inverse_cipher = self.inverse_cipher - if middle_round is not None: + if fully_automatic: + forward_components = cipher.get_all_components() + backward_components = inverse_cipher.get_all_components() + else: forward_components = [] for r in range(middle_round): forward_components.extend(self._cipher.get_components_in_round(r)) backward_components = [] for r in range(number_of_rounds - middle_round + 1): backward_components.extend(inverse_cipher.get_components_in_round(r)) - else: - forward_components = [] - for r in range(number_of_rounds): - forward_components.extend(self._cipher.get_components_in_round(r)) - backward_components = [] - for r in range(number_of_rounds): - backward_components.extend(inverse_cipher.get_components_in_round(r)) - cp_declarations.extend([f'array[0..{bit_size - 1}] of var 0..2: inverse_{input_};' for input_, bit_size in - zip(inverse_cipher.inputs, inverse_cipher.inputs_bit_size) if input_ != 'key']) + cp_declarations.extend( + [ + f"array[0..{bit_size - 1}] of var 0..2: inverse_{input_};" + for input_, bit_size in zip(inverse_cipher.inputs, inverse_cipher.inputs_bit_size) + if input_ != "key" + ] + ) for component in forward_components: output_id_link = component.id output_size = int(component.output_bit_size) - if 'output' in component.type: - cp_declarations.append(f'array[0..{output_size - 1}] of var 0..2: {output_id_link};') + if "output" in component.type: + cp_declarations.append(f"array[0..{output_size - 1}] of var 0..2: {output_id_link};") elif CIPHER_OUTPUT in component.type: - cp_declarations.append(f'array[0..{output_size - 1}] of var 0..2: {output_id_link};') - cp_constraints.append(f'constraint count({output_id_link},2) < {output_size};') + cp_declarations.append(f"array[0..{output_size - 1}] of var 0..2: {output_id_link};") + cp_constraints.append(f"constraint count({output_id_link},2) < {output_size};") elif CONSTANT not in component.type: - cp_declarations.append(f'array[0..{output_size - 1}] of var 0..2: {output_id_link};') + cp_declarations.append(f"array[0..{output_size - 1}] of var 0..2: {output_id_link};") for component in backward_components: output_id_link = component.id output_size = int(component.output_bit_size) - if 'output' in component.type: - cp_declarations.append(f'array[0..{output_size - 1}] of var 0..2: inverse_{output_id_link};') + if "output" in component.type: + cp_declarations.append(f"array[0..{output_size - 1}] of var 0..2: inverse_{output_id_link};") elif CIPHER_OUTPUT in component.type: - cp_declarations.append(f'array[0..{output_size - 1}] of var 0..2: inverse_{output_id_link};') - cp_constraints.append(f'constraint count(inverse_{output_id_link},2) < {output_size};') - cp_constraints.append(f'constraint count(inverse_{output_id_link},1) > 0;') + cp_declarations.append(f"array[0..{output_size - 1}] of var 0..2: inverse_{output_id_link};") + cp_constraints.append(f"constraint count(inverse_{output_id_link},2) < {output_size};") + cp_constraints.append(f"constraint count(inverse_{output_id_link},1) > 0;") elif CONSTANT not in component.type: - cp_declarations.append(f'array[0..{output_size - 1}] of var 0..2: inverse_{output_id_link};') - cp_constraints.append('constraint count(plaintext,1) > 0;') + cp_declarations.append(f"array[0..{output_size - 1}] of var 0..2: inverse_{output_id_link};") + cp_constraints.append("constraint count(plaintext,1) > 0;") return cp_declarations, cp_constraints @@ -1038,43 +1326,49 @@ def is_cross_round_component(self, component, discarded_ids): def link_constraints_for_trail_with_extensions(self, components_to_link): linking_constraints = [] for pairs in components_to_link: - linking_constraints.append(f'constraint {pairs[0]} = inverse_{pairs[1]};') + linking_constraints.append(f"constraint {pairs[0]} = inverse_{pairs[1]};") return linking_constraints - def _parse_solver_output(self, output_to_parse, number_of_rounds, initial_round, middle_round, final_round): + def _parse_solver_output( + self, output_to_parse, number_of_rounds, initial_round, middle_round, final_round, fully_automatic + ): components_values, memory, time = self.parse_solver_information(output_to_parse, True, True) all_components = [*self._cipher.inputs] - if middle_round is not None: + if fully_automatic: + all_components.extend([component.id for component in self._cipher.get_all_components()]) + all_components.extend(["inverse_" + component.id for component in self.inverse_cipher.get_all_components()]) + all_components.extend(["inverse_" + component for component in self.inverse_cipher.inputs]) + else: for r in list(range(initial_round - 1, middle_round)) + list(range(final_round, number_of_rounds)): all_components.extend([component.id for component in [*self._cipher.get_components_in_round(r)]]) for r in list(range(initial_round - 1)) + list(range(middle_round - 1, final_round)): - all_components.extend(['inverse_' + component.id for component in - [*self.inverse_cipher.get_components_in_round(number_of_rounds - r - 1)]]) - else: - for r in list(range(initial_round - 1, number_of_rounds)): - all_components.extend([component.id for component in [*self._cipher.get_components_in_round(r)]]) - for r in list(range(final_round)): - all_components.extend(['inverse_' + component.id for component in - [*self.inverse_cipher.get_components_in_round(number_of_rounds - r - 1)]]) - all_components.extend(['inverse_' + id_link for id_link in [*self.inverse_cipher.inputs]]) - all_components.extend(['inverse_' + id_link for id_link in [*self._cipher.inputs]]) + all_components.extend( + [ + "inverse_" + component.id + for component in [*self.inverse_cipher.get_components_in_round(number_of_rounds - r - 1)] + ] + ) + all_components.extend(["inverse_" + id_link for id_link in [*self.inverse_cipher.inputs]]) + all_components.extend(["inverse_" + id_link for id_link in [*self._cipher.inputs]]) for component_id in all_components: solution_number = 1 for j, string in enumerate(output_to_parse): - if f'{component_id}' in string and 'inverse_' not in component_id + string: + if f"{component_id}" in string and "inverse_" not in component_id + string: value = self.format_component_value(component_id, string) component_solution = {} - component_solution['value'] = value - self.add_solution_to_components_values(component_id, component_solution, components_values, j, - output_to_parse, solution_number, string) - elif f'{component_id}' in string and 'inverse_' in component_id: + component_solution["value"] = value + self.add_solution_to_components_values( + component_id, component_solution, components_values, j, output_to_parse, solution_number, string + ) + elif f"{component_id}" in string and "inverse_" in component_id: value = self.format_component_value(component_id, string) component_solution = {} - component_solution['value'] = value - self.add_solution_to_components_values(component_id, component_solution, components_values, j, - output_to_parse, solution_number, string) - elif '----------' in string: + component_solution["value"] = value + self.add_solution_to_components_values( + component_id, component_solution, components_values, j, output_to_parse, solution_number, string + ) + elif "----------" in string: solution_number += 1 return time, memory, components_values @@ -1084,67 +1378,87 @@ def set_inverse_component_id_in_constraints(self, component, inverse_variables, start = 0 while component.id in inverse_variables[v][start:]: new_start = inverse_variables[v].index(component.id, start) - inverse_variables[v] = inverse_variables[v][:new_start] + 'inverse_' + inverse_variables[v][new_start:] + inverse_variables[v] = inverse_variables[v][:new_start] + "inverse_" + inverse_variables[v][new_start:] start = new_start + 9 for c in range(len(inverse_constraints)): start = 0 while component.id in inverse_constraints[c][start:]: new_start = inverse_constraints[c].index(component.id, start) - inverse_constraints[c] = inverse_constraints[c][:new_start] + 'inverse_' + inverse_constraints[c][ - new_start:] + inverse_constraints[c] = ( + inverse_constraints[c][:new_start] + "inverse_" + inverse_constraints[c][new_start:] + ) start = new_start + 9 return inverse_variables, inverse_constraints - def solve(self, model_type, solver_name=None, number_of_rounds=None, initial_round=1, middle_round=None, - final_round=None, processes_=None, timeout_in_seconds_=None, all_solutions_=False, solve_external=False): - if number_of_rounds is None: + def solve( + self, + model_type, + solver_name=SOLVER_DEFAULT, + number_of_rounds=None, + initial_round=1, + middle_round=None, + final_round=None, + processes_=None, + timeout_in_seconds_=None, + all_solutions_=False, + solve_external=False, + fully_automatic=False, + ): + if fully_automatic: number_of_rounds = self._cipher.number_of_rounds - if final_round is None: final_round = self._cipher.number_of_rounds - cipher_name = self.cipher_id - input_file_path = f'{cipher_name}_Mzn_{model_type}_{solver_name}.mzn' - command = self.get_command_for_solver_process(input_file_path, model_type, solver_name, processes_, - timeout_in_seconds_) + command = self.get_command_for_solver_process(model_type, solver_name, processes_, timeout_in_seconds_) + model = "\n".join(self._model_constraints) + "\n" start = time.time() - solver_process = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8") + solver_process = subprocess.run(command, input=model, capture_output=True, text=True) end = time.time() solve_time = end - start - os.remove(input_file_path) if solver_process.returncode >= 0: solutions = [] solver_output = solver_process.stdout.splitlines() - if model_type in ['deterministic_truncated_xor_differential', - 'deterministic_truncated_xor_differential_one_solution', - 'impossible_xor_differential', - 'impossible_xor_differential_one_solution', - 'impossible_xor_differential_attack']: - solver_time, memory, components_values = self._parse_solver_output(solver_output, number_of_rounds, - initial_round, middle_round, - final_round) + if model_type in [ + "deterministic_truncated_xor_differential", + "deterministic_truncated_xor_differential_one_solution", + "impossible_xor_differential", + "impossible_xor_differential_one_solution", + "impossible_xor_differential_attack", + ]: + solver_time, memory, components_values = self._parse_solver_output( + solver_output, number_of_rounds, initial_round, middle_round, final_round, fully_automatic + ) total_weight = 0 else: - solver_time, memory, components_values, total_weight = self._parse_solver_output(solver_output, - number_of_rounds, - initial_round, - middle_round, - final_round) + solver_time, memory, components_values, total_weight = self._parse_solver_output( + solver_output, number_of_rounds, initial_round, middle_round, final_round, fully_automatic + ) if components_values == {}: - solution = convert_solver_solution_to_dictionary(self.cipher_id, model_type, solver_name, - solve_time, memory, - components_values, total_weight) - if 'UNSATISFIABLE' in solver_output[0]: - solution['status'] = 'UNSATISFIABLE' + solution = convert_solver_solution_to_dictionary( + self.cipher_id, model_type, solver_name, solve_time, memory, components_values, total_weight + ) + if UNSATISFIABLE in solver_output[0]: + solution["status"] = UNSATISFIABLE else: - solution['status'] = 'SATISFIABLE' + solution["status"] = SATISFIABLE solutions.append(solution) else: - self.add_solutions_from_components_values(components_values, memory, model_type, solutions, solve_time, - solver_name, solver_output, 0, solve_external) - if model_type in ['xor_differential_one_solution', - 'xor_linear_one_solution', - 'deterministic_truncated_one_solution', - 'impossible_xor_differential_one_solution']: + self.add_solutions_from_components_values( + components_values, + memory, + model_type, + solutions, + solve_time, + solver_name, + solver_output, + 0, + solve_external, + ) + if model_type in [ + "deterministic_truncated_one_solution", + "impossible_xor_differential_one_solution", + "xor_differential_one_solution", + "xor_linear_one_solution", + ]: return solutions[0] else: - return solutions \ No newline at end of file + return solutions diff --git a/claasp/cipher_modules/models/cp/mzn_models/mzn_wordwise_deterministic_truncated_xor_differential_model.py b/claasp/cipher_modules/models/cp/mzn_models/mzn_wordwise_deterministic_truncated_xor_differential_model.py index b6824079e..398b19500 100644 --- a/claasp/cipher_modules/models/cp/mzn_models/mzn_wordwise_deterministic_truncated_xor_differential_model.py +++ b/claasp/cipher_modules/models/cp/mzn_models/mzn_wordwise_deterministic_truncated_xor_differential_model.py @@ -1,36 +1,32 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** - -import os -import math -import itertools -import subprocess - -from claasp.cipher_modules.models.cp.mzn_models.mzn_deterministic_truncated_xor_differential_model import MznDeterministicTruncatedXorDifferentialModel, solve_satisfy -from claasp.cipher_modules.models.utils import write_model_to_file, convert_solver_solution_to_dictionary -from claasp.name_mappings import (CONSTANT, INTERMEDIATE_OUTPUT, CIPHER_OUTPUT, LINEAR_LAYER, SBOX, MIX_COLUMN, - WORD_OPERATION, DETERMINISTIC_TRUNCATED_XOR_DIFFERENTIAL) -from claasp.cipher_modules.models.cp.solvers import MODEL_DEFAULT_PATH, SOLVER_DEFAULT +from claasp.cipher_modules.models.cp.mzn_models.mzn_deterministic_truncated_xor_differential_model import ( + MznDeterministicTruncatedXorDifferentialModel, + SOLVE_SATISFY, +) +from claasp.cipher_modules.models.cp.solvers import SOLVER_DEFAULT +from claasp.name_mappings import ( + CIPHER_OUTPUT, + CONSTANT, +) class MznWordwiseDeterministicTruncatedXorDifferentialModel(MznDeterministicTruncatedXorDifferentialModel): - def __init__(self, cipher): super().__init__(cipher) @@ -55,64 +51,84 @@ def final_wordwise_deterministic_truncated_xor_differential_constraints(self, mi cipher_inputs = self._cipher.inputs cipher = self._cipher cp_constraints = [] - new_constraint = 'output[' + new_constraint = "output[" for element in cipher_inputs: - new_constraint = f'{new_constraint}\"{element}_active = \"++ show({element}_active) ++ \"\\n\" ++' + new_constraint = f'{new_constraint}"{element}_active = "++ show({element}_active) ++ "\\n" ++' for component_id in cipher.get_all_components_ids(): - new_constraint = new_constraint + \ - f'\"{component_id} = \"++ show({component_id}_active)++ \"\\n\" ++ \"0\" ++ \"\\n\" ++' - if 'cipher_output' in component_id and minimize: - cp_constraints.append(f'solve maximize count({self._cipher.get_all_components_ids()[-1]}_active, 0);') - new_constraint = new_constraint[:-2] + '];' + new_constraint = ( + new_constraint + f'"{component_id} = "++ show({component_id}_active)++ "\\n" ++ "0" ++ "\\n" ++' + ) + if "cipher_output" in component_id and minimize: + cp_constraints.append(f"solve maximize count({self._cipher.get_all_components_ids()[-1]}_active, 0);") + new_constraint = new_constraint[:-2] + "];" if cp_constraints == []: - cp_constraints.append(solve_satisfy) + cp_constraints.append(SOLVE_SATISFY) cp_constraints.append(new_constraint) return cp_constraints - def find_one_wordwise_deterministic_truncated_xor_differential_trail(self, number_of_rounds=None, - fixed_values=[], solver_name=SOLVER_DEFAULT, num_of_processors=None, timelimit=None, solve_with_API=False, solve_external = False): - + def find_one_wordwise_deterministic_truncated_xor_differential_trail( + self, + number_of_rounds=None, + fixed_values=[], + solver_name=SOLVER_DEFAULT, + num_of_processors=None, + timelimit=None, + solve_with_API=False, + solve_external=False, + ): if number_of_rounds is None: number_of_rounds = self._cipher.number_of_rounds self.build_deterministic_truncated_xor_differential_trail_model(fixed_values, number_of_rounds, wordwise=True) if solve_with_API: - return self.solve_for_ARX(solver_name = solver_name, timeout_in_seconds_ = timelimit, processes_ = num_of_processors) - return self.solve('deterministic_truncated_xor_differential_one_solution', solver_name = solver_name, timeout_in_seconds_ = timelimit, processes_ = num_of_processors, solve_external = solve_external) + return self.solve_for_ARX( + solver_name=solver_name, timeout_in_seconds_=timelimit, processes_=num_of_processors + ) + return self.solve( + "deterministic_truncated_xor_differential_one_solution", + solver_name=solver_name, + timeout_in_seconds_=timelimit, + processes_=num_of_processors, + solve_external=solve_external, + ) def input_wordwise_deterministic_truncated_xor_differential_constraints(self): - cp_constraints = [] cp_declarations = [] for input_, bit_size in zip(self._cipher.inputs, self._cipher.inputs_bit_size): - cp_declarations.append(f'array[0..{bit_size // self.word_size - 1}] of var 0..3: {input_}_active;') + cp_declarations.append(f"array[0..{bit_size // self.word_size - 1}] of var 0..3: {input_}_active;") cp_declarations.append( - f'array[0..{bit_size // self.word_size - 1}] of var -2..{2 ** self.word_size - 1}: {input_}_value;') + f"array[0..{bit_size // self.word_size - 1}] of var -2..{2**self.word_size - 1}: {input_}_value;" + ) for i in range(bit_size // self.word_size): - cp_constraints.append(f'constraint if {input_}_active[{i}] == 0 then {input_}_value[{i}] = 0 elseif ' - f'{input_}_active[{i}] == 1 then {input_}_value[{i}] > 0 elseif ' - f'{input_}_active[{i}] == 2 then {input_}_value[{i}] =-1 else ' - f'{input_}_value[{i}] =-2 endif;') + cp_constraints.append( + f"constraint if {input_}_active[{i}] == 0 then {input_}_value[{i}] = 0 elseif " + f"{input_}_active[{i}] == 1 then {input_}_value[{i}] > 0 elseif " + f"{input_}_active[{i}] == 2 then {input_}_value[{i}] =-1 else " + f"{input_}_value[{i}] =-2 endif;" + ) for component in self._cipher.get_all_components(): if CONSTANT not in component.type: output_id_link = component.id output_size = int(component.output_bit_size) cp_declarations.append( - f'array[0..{output_size // self.word_size - 1}] of var 0..3: {output_id_link}_active;') + f"array[0..{output_size // self.word_size - 1}] of var 0..3: {output_id_link}_active;" + ) cp_declarations.append( - f'array[0..{output_size // self.word_size - 1}] of var -2..{2 ** self.word_size - 1}: ' - f'{output_id_link}_value;') + f"array[0..{output_size // self.word_size - 1}] of var -2..{2**self.word_size - 1}: " + f"{output_id_link}_value;" + ) for i in range(output_size // self.word_size): cp_constraints.append( - f'constraint if {output_id_link}_active[{i}] == 0 then {output_id_link}_value[{i}] = 0 elseif ' - f'{output_id_link}_active[{i}] == 1 then {output_id_link}_value[{i}] > 0 elseif ' - f'{output_id_link}_active[{i}] == 2 then {output_id_link}_value[{i}] =-1 else ' - f'{output_id_link}_value[{i}] =-2 endif;') + f"constraint if {output_id_link}_active[{i}] == 0 then {output_id_link}_value[{i}] = 0 elseif " + f"{output_id_link}_active[{i}] == 1 then {output_id_link}_value[{i}] > 0 elseif " + f"{output_id_link}_active[{i}] == 2 then {output_id_link}_value[{i}] =-1 else " + f"{output_id_link}_value[{i}] =-2 endif;" + ) if CIPHER_OUTPUT in component.type: - cp_constraints.append(f'constraint count({output_id_link}_active,2) < {output_size};') - cp_constraints.append('constraint count(plaintext_active,1) > 0;') + cp_constraints.append(f"constraint count({output_id_link}_active,2) < {output_size};") + cp_constraints.append("constraint count(plaintext_active,1) > 0;") return cp_declarations, cp_constraints - diff --git a/claasp/cipher_modules/models/cp/mzn_models/mzn_xor_differential_model.py b/claasp/cipher_modules/models/cp/mzn_models/mzn_xor_differential_model.py index 8b98ca91f..8573be0ad 100644 --- a/claasp/cipher_modules/models/cp/mzn_models/mzn_xor_differential_model.py +++ b/claasp/cipher_modules/models/cp/mzn_models/mzn_xor_differential_model.py @@ -1,17 +1,16 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -21,10 +20,18 @@ import time as tm from sage.crypto.sbox import SBox -from claasp.cipher_modules.models.cp.mzn_model import MznModel, solve_satisfy +from claasp.cipher_modules.models.cp.mzn_model import MznModel, SOLVE_SATISFY from claasp.cipher_modules.models.utils import get_single_key_scenario_format_for_fixed_values -from claasp.name_mappings import (CONSTANT, INTERMEDIATE_OUTPUT, CIPHER_OUTPUT, SBOX, MIX_COLUMN, WORD_OPERATION, - XOR_DIFFERENTIAL, LINEAR_LAYER) +from claasp.name_mappings import ( + CONSTANT, + INTERMEDIATE_OUTPUT, + CIPHER_OUTPUT, + SBOX, + MIX_COLUMN, + WORD_OPERATION, + XOR_DIFFERENTIAL, + LINEAR_LAYER, +) from claasp.cipher_modules.models.cp.solvers import SOLVER_DEFAULT @@ -52,9 +59,9 @@ def and_xor_differential_probability_ddt(numadd): count = 0 for j in range(n): k = i ^ j - binary_j = format(j, f'0{numadd}b') + binary_j = f"{j:0{numadd}b}" result_j = 1 - binary_k = format(k, f'0{numadd}b') + binary_k = f"{k:0{numadd}b}" result_k = 1 for index in range(numadd): result_j *= int(binary_j[index]) @@ -73,26 +80,28 @@ def update_and_or_ddt_valid_probabilities(and_already_added, component, cp_decla ddt_table = and_xor_differential_probability_ddt(numadd) dim_ddt = len([i for i in ddt_table if i]) ddt_entries = [] - ddt_values = '' + ddt_values = "" set_of_occurrences = set(ddt_table) set_of_occurrences -= {0} - valid_probabilities.update({round(100 * math.log2(2 ** numadd / occurrence)) - for occurrence in set_of_occurrences}) + valid_probabilities.update( + {round(100 * math.log2(2**numadd / occurrence)) for occurrence in set_of_occurrences} + ) for i in range(pow(2, numadd + 1)): if ddt_table[i] != 0: - binary_i = format(i, f'0{numadd + 1}b') - ddt_entries += [f'{binary_i[j]}' for j in range(numadd + 1)] + binary_i = f"{i:0{numadd + 1}b}" + ddt_entries += [f"{binary_i[j]}" for j in range(numadd + 1)] ddt_entries.append(str(round(100 * math.log2(pow(2, numadd) / ddt_table[i])))) - ddt_values = ','.join(ddt_entries) - and_declaration = f'array [1..{dim_ddt}, 1..{numadd + 2}] of int: ' \ - f'and{numadd}inputs_DDT = array2d(1..{dim_ddt}, 1..{numadd + 2}, ' \ - f'[{ddt_values}]);' + ddt_values = ",".join(ddt_entries) + and_declaration = ( + f"array [1..{dim_ddt}, 1..{numadd + 2}] of int: " + f"and{numadd}inputs_DDT = array2d(1..{dim_ddt}, 1..{numadd + 2}, " + f"[{ddt_values}]);" + ) cp_declarations.append(and_declaration) and_already_added.append(numadd) class MznXorDifferentialModel(MznModel): - def __init__(self, cipher): self._first_step = [] self._first_step_find_all_solutions = [] @@ -137,16 +146,17 @@ def build_xor_differential_trail_model_template(self, weight, fixed_variables, m if fixed_variables == []: fixed_variables = get_single_key_scenario_format_for_fixed_values(self._cipher) constraints = self.fix_variables_value_constraints(fixed_variables) - component_types = [CONSTANT, INTERMEDIATE_OUTPUT, CIPHER_OUTPUT, LINEAR_LAYER, SBOX, MIX_COLUMN, WORD_OPERATION] - operation_types = ['AND', 'MODADD', 'MODSUB', 'NOT', 'OR', 'ROTATE', 'SHIFT', 'XOR'] + component_types = (CONSTANT, INTERMEDIATE_OUTPUT, CIPHER_OUTPUT, LINEAR_LAYER, SBOX, MIX_COLUMN, WORD_OPERATION) + operation_types = ("AND", "MODADD", "MODSUB", "NOT", "OR", "ROTATE", "SHIFT", "XOR") self._model_constraints = constraints for component in self._cipher.get_all_components(): operation = component.description[0] if component.type not in component_types or ( - WORD_OPERATION == component.type and operation not in operation_types): - print(f'{component.id} not yet implemented') - elif operation in ('MODADD', 'MODSUB') and milp_modadd: + WORD_OPERATION == component.type and operation not in operation_types + ): + print(f"{component.id} not yet implemented") + elif operation in ("MODADD", "MODSUB") and milp_modadd: variables, constraints = component.cp_xor_differential_propagation_constraints_arx_optimized(self) else: variables, constraints = component.cp_xor_differential_propagation_constraints(self) @@ -159,7 +169,7 @@ def build_xor_differential_trail_model_template(self, weight, fixed_variables, m self._variables_list.extend(variables) self._model_constraints.extend(constraints) - def final_xor_differential_constraints(self, weight, milp_modadd = False): + def final_xor_differential_constraints(self, weight, milp_modadd=False): """ Return a CP constraints list for the cipher outputs and solving indications for single or second step model. @@ -180,28 +190,42 @@ def final_xor_differential_constraints(self, weight, milp_modadd = False): cipher_inputs = self._cipher.inputs cp_constraints = [] if weight == -1 and self._probability: - cp_constraints.append('solve:: int_search(p, smallest, indomain_min, complete) minimize weight;') + cp_constraints.append("solve:: int_search(p, smallest, indomain_min, complete) minimize weight;") else: - cp_constraints.append(solve_satisfy) - new_constraint = 'output[' + cp_constraints.append(SOLVE_SATISFY) + new_constraint = "output[" for element in cipher_inputs: - new_constraint = new_constraint + f'\"{element} = \"++ show({element}) ++ \"\\n\" ++' + new_constraint = new_constraint + f'"{element} = "++ show({element}) ++ "\\n" ++' for component in self._cipher.get_all_components(): if SBOX in component.type: - new_constraint = new_constraint + \ - f'\"{component.id} = \"++ show({component.id})++ \"\\n\" ++ ' \ - f'show(p[{self.component_and_probability[component.id]}]/100) ++ \"\\n\" ++' + new_constraint = ( + new_constraint + f'"{component.id} = "++ show({component.id})++ "\\n" ++ ' + f'show(p[{self.component_and_probability[component.id]}]/100) ++ "\\n" ++' + ) elif WORD_OPERATION in component.type: - new_constraint = self.get_word_operation_xor_differential_constraints(component, new_constraint, milp_modadd) + new_constraint = self.get_word_operation_xor_differential_constraints( + component, new_constraint, milp_modadd + ) else: - new_constraint = new_constraint + f'\"{component.id} = \"++ ' \ - f'show({component.id})++ \"\\n\" ++ \"0\" ++ \"\\n\" ++' - new_constraint = new_constraint + '\"Trail weight = \" ++ show(weight)];' + new_constraint = ( + new_constraint + f'"{component.id} = "++ show({component.id})++ "\\n" ++ "0" ++ "\\n" ++' + ) + new_constraint = new_constraint + '"Trail weight = " ++ show(weight)];' cp_constraints.append(new_constraint) return cp_constraints - def find_all_xor_differential_trails_with_fixed_weight(self, fixed_weight, fixed_values=[], solver_name=SOLVER_DEFAULT, num_of_processors=None, timelimit=None, solve_with_API=False, milp_modadd=False, solve_external = False): + def find_all_xor_differential_trails_with_fixed_weight( + self, + fixed_weight, + fixed_values=[], + solver_name=SOLVER_DEFAULT, + num_of_processors=None, + timelimit=None, + solve_with_API=False, + milp_modadd=False, + solve_external=False, + ): """ Return a list of solutions containing all the differential trails having the ``fixed_weight`` weight. By default, the search is set in the single-key setting. @@ -210,11 +234,8 @@ def find_all_xor_differential_trails_with_fixed_weight(self, fixed_weight, fixed - ``fixed_weight`` -- **integer**; the weight to be fixed - ``fixed_values`` -- **list** (default: `[]`); can be created using ``set_fixed_variables`` method - - ``solver_name`` -- **string** (default: `Chuffed`); the name of the solver. Available values are: - - * ``'Chuffed'`` - * ``'Gecode'`` - * ``'COIN-BC'`` + - ``solver_name`` -- **string** (default: `chuffed`); the name of the solver. + See also :meth:`MznModel.solver_names`. EXAMPLES:: @@ -223,7 +244,7 @@ def find_all_xor_differential_trails_with_fixed_weight(self, fixed_weight, fixed sage: from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher sage: speck = SpeckBlockCipher(number_of_rounds=5) sage: cp = MznXorDifferentialModel(speck) - sage: trails = cp.find_all_xor_differential_trails_with_fixed_weight(9, solver_name='Chuffed', solve_external=True) + sage: trails = cp.find_all_xor_differential_trails_with_fixed_weight(9, solver_name='chuffed', solve_external=True) sage: len(trails) 2 @@ -243,17 +264,39 @@ def find_all_xor_differential_trails_with_fixed_weight(self, fixed_weight, fixed end = tm.time() build_time = end - start if solve_with_API: - solutions = self.solve_for_ARX(solver_name = solver_name, timeout_in_seconds_ = timelimit, processes_ = num_of_processors, all_solutions_ = True) + solutions = self.solve_for_ARX( + solver_name=solver_name, + timeout_in_seconds_=timelimit, + processes_=num_of_processors, + all_solutions_=True, + ) else: - solutions = self.solve(XOR_DIFFERENTIAL, solver_name = solver_name, timeout_in_seconds_ = timelimit, processes_ = num_of_processors, all_solutions_ = True, solve_external = solve_external) + solutions = self.solve( + XOR_DIFFERENTIAL, + solver_name=solver_name, + timeout_in_seconds_=timelimit, + processes_=num_of_processors, + all_solutions_=True, + solve_external=solve_external, + ) if solve_external: for solution in solutions: - solution['building_time_seconds'] = build_time - solution['test_name'] = "find_all_xor_differential_trails_with_fixed_weight" + solution["building_time_seconds"] = build_time + solution["test_name"] = "find_all_xor_differential_trails_with_fixed_weight" return solutions - def find_all_xor_differential_trails_with_weight_at_most(self, min_weight, max_weight=64, fixed_values=[], - solver_name=SOLVER_DEFAULT, num_of_processors=None, timelimit=None, solve_with_API=False, milp_modadd=False, solve_external = False): + def find_all_xor_differential_trails_with_weight_at_most( + self, + min_weight, + max_weight=64, + fixed_values=[], + solver_name=SOLVER_DEFAULT, + num_of_processors=None, + timelimit=None, + solve_with_API=False, + milp_modadd=False, + solve_external=False, + ): """ Return a list of solutions containing all the differential trails. By default, the search is set in the single-key setting. @@ -264,11 +307,8 @@ def find_all_xor_differential_trails_with_weight_at_most(self, min_weight, max_w - ``min_weight`` -- **integer**; the weight from which to start the search - ``max_weight`` -- **integer** (default: 64); the weight at which the search stops - ``fixed_values`` -- **list** (default: `[]`); can be created using ``set_fixed_variables`` method - - ``solver_name`` -- **string** (default: `Chuffed`); the name of the solver. Available values are: - - * ``'Chuffed'`` - * ``'Gecode'`` - * ``'COIN-BC'`` + - ``solver_name`` -- **string** (default: `chuffed`); the name of the solver. + See also :meth:`MznModel.solver_names`. EXAMPLES:: @@ -277,7 +317,7 @@ def find_all_xor_differential_trails_with_weight_at_most(self, min_weight, max_w sage: from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher sage: speck = SpeckBlockCipher(number_of_rounds=5) sage: cp = MznXorDifferentialModel(speck) - sage: trails = cp.find_all_xor_differential_trails_with_weight_at_most(9,10, solver_name='Chuffed', solve_external=True) + sage: trails = cp.find_all_xor_differential_trails_with_weight_at_most(9,10, solver_name='chuffed', solve_external=True) sage: len(trails) 28 @@ -295,35 +335,71 @@ def find_all_xor_differential_trails_with_weight_at_most(self, min_weight, max_w """ start = tm.time() self.build_xor_differential_trail_model(0, fixed_values, milp_modadd) - self._model_constraints.append(f'constraint weight >= {100 * min_weight} /\\ weight <= {100 * max_weight} ') + self._model_constraints.append(f"constraint weight >= {100 * min_weight} /\\ weight <= {100 * max_weight} ") end = tm.time() build_time = end - start if solve_with_API: - solutions = self.solve_for_ARX(solver_name = solver_name, timeout_in_seconds_ = timelimit, processes_ = num_of_processors, all_solutions_ = True) + solutions = self.solve_for_ARX( + solver_name=solver_name, + timeout_in_seconds_=timelimit, + processes_=num_of_processors, + all_solutions_=True, + ) else: - solutions = self.solve(XOR_DIFFERENTIAL, solver_name = solver_name, timeout_in_seconds_ = timelimit, processes_ = num_of_processors, all_solutions_ = True, solve_external = solve_external) + solutions = self.solve( + XOR_DIFFERENTIAL, + solver_name=solver_name, + timeout_in_seconds_=timelimit, + processes_=num_of_processors, + all_solutions_=True, + solve_external=solve_external, + ) for solution in solutions: - solution['building_time_seconds'] = build_time - solution['test_name'] = "find_all_xor_differential_trails_with_weight_at_most" + solution["building_time_seconds"] = build_time + solution["test_name"] = "find_all_xor_differential_trails_with_weight_at_most" return solutions - def find_differential_weight(self, fixed_values=[], solver_name=SOLVER_DEFAULT, num_of_processors=None, timelimit=None, solve_with_API=False, milp_modadd=False, solve_external = False): + def find_differential_weight( + self, + fixed_values=[], + solver_name=SOLVER_DEFAULT, + num_of_processors=None, + timelimit=None, + solve_with_API=False, + milp_modadd=False, + solve_external=False, + ): probability = 0 self.build_xor_differential_trail_model(-1, fixed_values, milp_modadd) if solve_with_API: - solutions = self.solve_for_ARX(solver_name = solver_name, all_solutions_ = True) + solutions = self.solve_for_ARX(solver_name=solver_name, all_solutions_=True) else: - solutions = self.solve(XOR_DIFFERENTIAL, solver_name = solver_name, timeout_in_seconds_ = timelimit, processes_ = num_of_processors, solve_external = solve_external) + solutions = self.solve( + XOR_DIFFERENTIAL, + solver_name=solver_name, + timeout_in_seconds_=timelimit, + processes_=num_of_processors, + solve_external=solve_external, + ) if isinstance(solutions, list): for solution in solutions: - weight = solution['total_weight'] - probability += 1 / 2 ** weight + weight = solution["total_weight"] + probability += 1 / 2**weight return math.log2(1 / probability) else: - return solutions['total_weight'] - - def find_lowest_weight_xor_differential_trail(self, fixed_values=[], solver_name=SOLVER_DEFAULT, num_of_processors=None, timelimit=None, solve_with_API=False, milp_modadd=False, solve_external = False): + return solutions["total_weight"] + + def find_lowest_weight_xor_differential_trail( + self, + fixed_values=[], + solver_name=SOLVER_DEFAULT, + num_of_processors=None, + timelimit=None, + solve_with_API=False, + milp_modadd=False, + solve_external=False, + ): """ Return the solution representing a differential trail with the lowest probability weight. By default, the search is set in the single-key setting. @@ -336,11 +412,8 @@ def find_lowest_weight_xor_differential_trail(self, fixed_values=[], solver_name INPUT: - ``fixed_values`` -- **list** (default: `[]`); can be created using ``set_fixed_variables`` method - - ``solver_name`` -- **string** (default: `Chuffed`); the name of the solver. Available values are: - - * ``'Chuffed'`` - * ``'Gecode'`` - * ``'COIN-BC'`` + - ``solver_name`` -- **string** (default: `chuffed`); the name of the solver. + See also :meth:`MznModel.solver_names`. EXAMPLES:: @@ -349,10 +422,10 @@ def find_lowest_weight_xor_differential_trail(self, fixed_values=[], solver_name sage: from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher sage: speck = SpeckBlockCipher(number_of_rounds=5) sage: cp = MznXorDifferentialModel(speck) - sage: cp.find_lowest_weight_xor_differential_trail(solver_name='Chuffed', solve_external=True) # random + sage: cp.find_lowest_weight_xor_differential_trail(solver_name='chuffed', solve_external=True) # random {'cipher': speck_p32_k64_o32_r5, 'model_type': 'xor_differential_one_solution', - 'solver_name': 'Chuffed', + 'solver_name': 'chuffed', 'solving_time_seconds': 120.349, 'memory_megabytes': 0.28, 'components_values': {'plaintext': {'value': '28000010', 'weight': 0}, @@ -380,14 +453,31 @@ def find_lowest_weight_xor_differential_trail(self, fixed_values=[], solver_name end = tm.time() build_time = end - start if solve_with_API: - solution = self.solve_for_ARX(solver_name = solver_name, timeout_in_seconds_ = timelimit, processes_ = num_of_processors) + solution = self.solve_for_ARX( + solver_name=solver_name, timeout_in_seconds_=timelimit, processes_=num_of_processors + ) else: - solution = self.solve('xor_differential_one_solution', solver_name = solver_name, timeout_in_seconds_ = timelimit, processes_ = num_of_processors, solve_external = solve_external) - solution['building_time_seconds'] = build_time - solution['test_name'] = "find_lowest_weight_xor_differential_trail" + solution = self.solve( + "xor_differential_one_solution", + solver_name=solver_name, + timeout_in_seconds_=timelimit, + processes_=num_of_processors, + solve_external=solve_external, + ) + solution["building_time_seconds"] = build_time + solution["test_name"] = "find_lowest_weight_xor_differential_trail" return solution - def find_one_xor_differential_trail(self, fixed_values=[], solver_name=SOLVER_DEFAULT, num_of_processors=None, timelimit=None, solve_with_API=False, milp_modadd=False, solve_external = False): + def find_one_xor_differential_trail( + self, + fixed_values=[], + solver_name=SOLVER_DEFAULT, + num_of_processors=None, + timelimit=None, + solve_with_API=False, + milp_modadd=False, + solve_external=False, + ): """ Return the solution representing a differential trail with any weight. By default, the search is set in the single-key setting. @@ -395,11 +485,8 @@ def find_one_xor_differential_trail(self, fixed_values=[], solver_name=SOLVER_DE INPUT: - ``fixed_values`` -- **list** (default: `[]`); can be created using ``set_fixed_variables`` method - - ``solver_name`` -- **string** (default: `Chuffed`); the name of the solver. Available values are: - - * ``'Chuffed'`` - * ``'Gecode'`` - * ``'COIN-BC'`` + - ``solver_name`` -- **string** (default: `chuffed`); the name of the solver. + See also :meth:`MznModel.solver_names`. EXAMPLES:: @@ -408,7 +495,7 @@ def find_one_xor_differential_trail(self, fixed_values=[], solver_name=SOLVER_DE sage: from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher sage: speck = SpeckBlockCipher(number_of_rounds=2) sage: cp = MznXorDifferentialModel(speck) - sage: cp.find_one_xor_differential_trail(solver_name='Chuffed', solve_external=True) # random + sage: cp.find_one_xor_differential_trail(solver_name='chuffed', solve_external=True) # random {'cipher_id': 'speck_p32_k64_o32_r2', 'model_type': 'xor_differential_one_solution', ... @@ -430,15 +517,32 @@ def find_one_xor_differential_trail(self, fixed_values=[], solver_name=SOLVER_DE end = tm.time() build_time = end - start if solve_with_API: - solution = self.solve_for_ARX(solver_name = solver_name, timeout_in_seconds_ = timelimit, processes_ = num_of_processors) + solution = self.solve_for_ARX( + solver_name=solver_name, timeout_in_seconds_=timelimit, processes_=num_of_processors + ) else: - solution = self.solve('xor_differential_one_solution', solver_name = solver_name, timeout_in_seconds_ = timelimit, processes_ = num_of_processors, solve_external = solve_external) - solution['building_time_seconds'] = build_time - solution['test_name'] = "find_one_xor_differential_trail" + solution = self.solve( + "xor_differential_one_solution", + solver_name=solver_name, + timeout_in_seconds_=timelimit, + processes_=num_of_processors, + solve_external=solve_external, + ) + solution["building_time_seconds"] = build_time + solution["test_name"] = "find_one_xor_differential_trail" return solution - def find_one_xor_differential_trail_with_fixed_weight(self, fixed_weight=-1, fixed_values=[], - solver_name=SOLVER_DEFAULT, num_of_processors=None, timelimit=None, solve_with_API=False, milp_modadd=False, solve_external = False): + def find_one_xor_differential_trail_with_fixed_weight( + self, + fixed_weight=-1, + fixed_values=[], + solver_name=SOLVER_DEFAULT, + num_of_processors=None, + timelimit=None, + solve_with_API=False, + milp_modadd=False, + solve_external=False, + ): """ Return the solution representing a differential trail with the weight of probability equal to ``fixed_weight``. By default, the search is set in the single-key setting. @@ -447,11 +551,8 @@ def find_one_xor_differential_trail_with_fixed_weight(self, fixed_weight=-1, fix - ``fixed_weight`` -- **integer**; the value to which the weight is fixed, if non-negative - ``fixed_values`` -- **list** (default: `[]`); can be created using ``set_fixed_variables`` method - - ``solver_name`` -- **string** (default: `Chuffed`); the name of the solver. Available values are: - - * ``'Chuffed'`` - * ``'Gecode'`` - * ``'COIN-BC'`` + - ``solver_name`` -- **string** (default: `chuffed`); the name of the solver. + See also :meth:`MznModel.solver_names`. EXAMPLES:: @@ -460,7 +561,7 @@ def find_one_xor_differential_trail_with_fixed_weight(self, fixed_weight=-1, fix sage: from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher sage: speck = SpeckBlockCipher(number_of_rounds=3) sage: cp = MznXorDifferentialModel(speck) - sage: trail = cp.find_one_xor_differential_trail_with_fixed_weight(3, solver_name='Chuffed', solve_external=True) # random + sage: trail = cp.find_one_xor_differential_trail_with_fixed_weight(3, solver_name='chuffed', solve_external=True) # random sage: trail['total_weight'] '3.0' @@ -480,24 +581,31 @@ def find_one_xor_differential_trail_with_fixed_weight(self, fixed_weight=-1, fix end = tm.time() build_time = end - start if solve_with_API: - solution = self.solve_for_ARX(solver_name = solver_name, timeout_in_seconds_ = timelimit, processes_ = num_of_processors) + solution = self.solve_for_ARX( + solver_name=solver_name, timeout_in_seconds_=timelimit, processes_=num_of_processors + ) else: - solution = self.solve('xor_differential_one_solution', solver_name = solver_name, timeout_in_seconds_ = timelimit, processes_ = num_of_processors, solve_external = solve_external) + solution = self.solve( + "xor_differential_one_solution", + solver_name=solver_name, + timeout_in_seconds_=timelimit, + processes_=num_of_processors, + solve_external=solve_external, + ) if solve_external: - solution['building_time_seconds'] = build_time - solution['test_name'] = "find_one_xor_differential_trail_with_fixed_weight" + solution["building_time_seconds"] = build_time + solution["test_name"] = "find_one_xor_differential_trail_with_fixed_weight" return solution - def get_word_operation_xor_differential_constraints(self, component, new_constraint, milp_modadd = False): - if 'AND' in component.description[0] or ('MODADD' in component.description[0] and not milp_modadd): - new_constraint = new_constraint + f'\"{component.id} = \"++ show({component.id})++ \"\\n\" ++ show(' + def get_word_operation_xor_differential_constraints(self, component, new_constraint, milp_modadd=False): + if "AND" in component.description[0] or ("MODADD" in component.description[0] and not milp_modadd): + new_constraint = new_constraint + f'"{component.id} = "++ show({component.id})++ "\\n" ++ show(' for i in range(len(self.component_and_probability[component.id])): - new_constraint = new_constraint + f'p[{self.component_and_probability[component.id][i]}]/100+' - new_constraint = new_constraint[:-1] + ') ++ \"\\n\" ++' + new_constraint = new_constraint + f"p[{self.component_and_probability[component.id][i]}]/100+" + new_constraint = new_constraint[:-1] + ') ++ "\\n" ++' else: - new_constraint = new_constraint + f'\"{component.id} = \"++ ' \ - f'show({component.id})++ \"\\n\" ++ \"0\" ++ \"\\n\" ++' + new_constraint = new_constraint + f'"{component.id} = "++ show({component.id})++ "\\n" ++ "0" ++ "\\n" ++' return new_constraint @@ -524,8 +632,10 @@ def input_xor_differential_constraints(self): 'var int: weight = sum(p);'], []) """ - self._cp_xor_differential_constraints = [f'array[0..{bit_size - 1}] of var 0..1: {input_};' - for input_, bit_size in zip(self._cipher.inputs, self._cipher.inputs_bit_size)] + self._cp_xor_differential_constraints = [ + f"array[0..{bit_size - 1}] of var 0..1: {input_};" + for input_, bit_size in zip(self._cipher.inputs, self._cipher.inputs_bit_size) + ] self.sbox_mant = [] prob_count = 0 valid_probabilities = {0} @@ -533,25 +643,28 @@ def input_xor_differential_constraints(self): for component in self._cipher.get_all_components(): if CONSTANT not in component.type: output_id_link = component.id - self._cp_xor_differential_constraints.append(f'array[0..{int(component.output_bit_size) - 1}] of var 0..1: {output_id_link};') + self._cp_xor_differential_constraints.append( + f"array[0..{int(component.output_bit_size) - 1}] of var 0..1: {output_id_link};" + ) if SBOX in component.type: prob_count += 1 self.update_sbox_ddt_valid_probabilities(component, valid_probabilities) elif WORD_OPERATION in component.type: - if 'AND' in component.description[0] or component.description[0] == 'OR': + if "AND" in component.description[0] or component.description[0] == "OR": prob_count += component.description[1] * component.output_bit_size - update_and_or_ddt_valid_probabilities(and_already_added, component, self._cp_xor_differential_constraints, - valid_probabilities) - elif 'MODADD' in component.description[0]: + update_and_or_ddt_valid_probabilities( + and_already_added, component, self._cp_xor_differential_constraints, valid_probabilities + ) + elif "MODADD" in component.description[0]: prob_count += component.description[1] - 1 output_size = component.output_bit_size valid_probabilities |= set(range(100 * output_size)[::100]) - cp_declarations_weight = 'int: weight = 0;' + cp_declarations_weight = "int: weight = 0;" if prob_count > 0: self._probability = True - new_declaration = f'array[0..{prob_count - 1}] of var {valid_probabilities}: p;' + new_declaration = f"array[0..{prob_count - 1}] of var {valid_probabilities}: p;" self._cp_xor_differential_constraints.append(new_declaration) - cp_declarations_weight = 'var int: weight = sum(p);' + cp_declarations_weight = "var int: weight = sum(p);" self._cp_xor_differential_constraints.append(cp_declarations_weight) cp_constraints = [] @@ -571,6 +684,7 @@ def update_sbox_ddt_valid_probabilities(self, component, valid_probabilities): for i in range(sbox_ddt.nrows()): set_of_occurrences = set(sbox_ddt.rows()[i]) set_of_occurrences -= {0} - valid_probabilities.update({round(100 * math.log2(2 ** input_size / occurrence)) - for occurrence in set_of_occurrences}) + valid_probabilities.update( + {round(100 * math.log2(2**input_size / occurrence)) for occurrence in set_of_occurrences} + ) self.sbox_mant.append((description, output_id_link)) diff --git a/claasp/cipher_modules/models/cp/mzn_models/mzn_xor_differential_model_arx_optimized.py b/claasp/cipher_modules/models/cp/mzn_models/mzn_xor_differential_model_arx_optimized.py index d4bf7cae1..419cf5173 100644 --- a/claasp/cipher_modules/models/cp/mzn_models/mzn_xor_differential_model_arx_optimized.py +++ b/claasp/cipher_modules/models/cp/mzn_models/mzn_xor_differential_model_arx_optimized.py @@ -1,17 +1,16 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -24,10 +23,13 @@ class MznXorDifferentialModelARXOptimized(MznModel): - def __init__( - self, cipher, window_size_list=None, probability_weight_per_round=None, sat_or_milp='sat', - include_word_operations_mzn_file=True + self, + cipher, + window_size_list=None, + probability_weight_per_round=None, + sat_or_milp="sat", + include_word_operations_mzn_file=True, ): self.include_word_operations_mzn_file = include_word_operations_mzn_file super().__init__(cipher, window_size_list, probability_weight_per_round, sat_or_milp) @@ -35,9 +37,9 @@ def __init__( @staticmethod def _create_minizinc_1d_array_from_list(mzn_list): mzn_list_size = len(mzn_list) - lst_temp = f'[{",".join(mzn_list)}]' + lst_temp = f"[{','.join(mzn_list)}]" - return f'array1d(0..{mzn_list_size}-1, {lst_temp})' + return f"array1d(0..{mzn_list_size}-1, {lst_temp})" @staticmethod def _get_total_weight(result): @@ -56,19 +58,19 @@ def _get_total_weight(result): @staticmethod def _parse_solution( - result, solution, list_of_vars, probability_vars, result_status, solution_dict, result_statistics=None + result, solution, list_of_vars, probability_vars, result_status, solution_dict, result_statistics=None ): def get_hex_string_from_bool_dict(data, bool_dict, probability_vars_weights_): temp_result = {} for sublist in data: reversed_list = sublist[::-1] bool_list = [bool_dict[item] for item in reversed_list] - int_value = sum([2 ** i if bit else 0 for i, bit in enumerate(bool_list)]) + int_value = sum([2**i if bit else 0 for i, bit in enumerate(bool_list)]) component_id = "_".join(sublist[0].split("_")[:-1]) weight = 0 - if component_id.startswith('modadd') or component_id.startswith('modsub'): - weight = probability_vars_weights_[f'p_{component_id}_0']['weight'] - temp_result[component_id] = {'value': hex(int_value)[2:], 'weight': weight, 'sign': 1} + if component_id.startswith("modadd") or component_id.startswith("modsub"): + weight = probability_vars_weights_[f"p_{component_id}_0"]["weight"] + temp_result[component_id] = {"value": hex(int_value)[2:], "weight": weight, "sign": 1} return temp_result @@ -78,18 +80,16 @@ def get_hex_string_from_bool_dict(data, bool_dict, probability_vars_weights_): probability_vars_weights = MznXorDifferentialModelARXOptimized.parse_probability_vars( result, solution, probability_vars ) - solution_total_weight = sum(item['weight'] for item in probability_vars_weights.values()) - parsed_solution['total_weight'] = solution_total_weight - parsed_solution['component_values'] = get_hex_string_from_bool_dict( + solution_total_weight = sum(item["weight"] for item in probability_vars_weights.values()) + parsed_solution["total_weight"] = solution_total_weight + parsed_solution["component_values"] = get_hex_string_from_bool_dict( list_of_vars, dict_of_solutions, probability_vars_weights ) return parsed_solution @staticmethod - def _parse_result( - result, solver_name, total_weight, model_type, _variables_list, cipher_id, probability_vars - ): + def _parse_result(result, solver_name, total_weight, model_type, _variables_list, cipher_id, probability_vars): def _entry_matches(entry, prefix): valid_starts = [f"var bool: {prefix}", f"var 0..1: {prefix}"] return any(entry.startswith(vs) for vs in valid_starts) @@ -104,35 +104,36 @@ def group_strings_by_pattern(data: list) -> list: return temp_result list_of_vars = group_strings_by_pattern(_variables_list) - common_parsed_data = { - 'id': cipher_id, - 'model_type': model_type, - 'solver_name': solver_name - } + common_parsed_data = {"id": cipher_id, "model_type": model_type, "solver_name": solver_name} if total_weight == "list_of_solutions": solutions = [] for solution in result.solution: - parsed_solution = {'total_weight': None, 'component_values': {}} + parsed_solution = {"total_weight": None, "component_values": {}} parsed_solution_temp = {} if result.status in [Status.SATISFIED, Status.ALL_SOLUTIONS, Status.OPTIMAL_SOLUTION]: parsed_solution_temp = MznXorDifferentialModelARXOptimized._parse_solution( result, solution, list_of_vars, probability_vars, result.status, solution.__dict__ ) - parsed_solution['status'] = str(result.status) + parsed_solution["status"] = str(result.status) parsed_solution = {**parsed_solution, **parsed_solution_temp} solutions.append({**parsed_solution, **common_parsed_data}) return solutions else: - parsed_solution = {'total_weight': None, 'component_values': {}} + parsed_solution = {"total_weight": None, "component_values": {}} parsed_solution_temp = {} if result.status in [Status.SATISFIED, Status.ALL_SOLUTIONS, Status.OPTIMAL_SOLUTION]: parsed_solution_temp = MznXorDifferentialModelARXOptimized._parse_solution( - result, result.solution, list_of_vars, probability_vars, result.status, - result.solution.__dict__, result.statistics + result, + result.solution, + list_of_vars, + probability_vars, + result.status, + result.solution.__dict__, + result.statistics, ) - parsed_solution['status'] = str(result.status) + parsed_solution["status"] = str(result.status) parsed_solution = {**parsed_solution, **parsed_solution_temp} return {**parsed_solution, **common_parsed_data} @@ -167,14 +168,15 @@ def build_xor_differential_trail_model(self, weight=-1, fixed_variables=[]): self._variables_list = [] constraints = self.fix_variables_value_constraints_for_ARX(fixed_variables) component_types = [CONSTANT, INTERMEDIATE_OUTPUT, CIPHER_OUTPUT, WORD_OPERATION] - operation_types = ['MODADD', 'MODSUB', 'ROTATE', 'SHIFT', 'SHIFT_BY_VARIABLE_AMOUNT', 'XOR'] + operation_types = ["MODADD", "MODSUB", "ROTATE", "SHIFT", "SHIFT_BY_VARIABLE_AMOUNT", "XOR"] self._model_constraints = constraints for component in self._cipher.get_all_components(): operation = component.description[0] if component.type not in component_types or ( - WORD_OPERATION == component.type and operation not in operation_types): - print(f'{component.id} not yet implemented') + WORD_OPERATION == component.type and operation not in operation_types + ): + print(f"{component.id} not yet implemented") else: variables, constraints = component.minizinc_xor_differential_propagation_constraints(self) @@ -201,6 +203,7 @@ def build_all_xor_differential_trails_with_fixed_weight(self, fixed_weight, fixe sage: from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher sage: from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_differential_model_arx_optimized import MznXorDifferentialModelARXOptimized + sage: from claasp.cipher_modules.models.cp.solvers import CPSAT sage: speck = SpeckBlockCipher(number_of_rounds=5, block_bit_size=32, key_bit_size=64) sage: minizinc = MznXorDifferentialModelARXOptimized(speck) sage: bit_positions = [i for i in range(speck.output_bit_size)] @@ -216,7 +219,7 @@ def build_all_xor_differential_trails_with_fixed_weight(self, fixed_weight, fixe ....: 'operator': '=', ....: 'value': '0' }) sage: minizinc.build_lowest_weight_xor_differential_trail_model(fixed_variables) - sage: result = minizinc.solve_for_ARX('Xor') + sage: result = minizinc.solve_for_ARX(CPSAT) sage: result.statistics['nSolutions'] > 1 True """ @@ -239,6 +242,7 @@ def build_lowest_weight_xor_differential_trail_model(self, fixed_variables, max_ sage: from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_differential_model_arx_optimized import MznXorDifferentialModelARXOptimized sage: from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher + sage: from claasp.cipher_modules.models.cp.solvers import CPSAT sage: speck = SpeckBlockCipher(number_of_rounds=5, block_bit_size=32, key_bit_size=64) sage: minizinc = MznXorDifferentialModelARXOptimized(speck) sage: bit_positions = [i for i in range(speck.output_bit_size)] @@ -254,14 +258,13 @@ def build_lowest_weight_xor_differential_trail_model(self, fixed_variables, max_ ....: 'operator': '=', ....: 'value': '0' }) sage: minizinc.build_lowest_weight_xor_differential_trail_model(fixed_variables) - sage: result = minizinc.solve_for_ARX('Xor') + sage: result = minizinc.solve_for_ARX(CPSAT) sage: result.statistics['nSolutions'] > 1 True """ self.build_xor_differential_trail_model(-1, fixed_variables) self._model_constraints.extend(self.objective_generator()) - self._model_constraints.extend( - self.weight_constraints(max_weight=max_weight, weight=min_weight, operator=">=")) + self._model_constraints.extend(self.weight_constraints(max_weight=max_weight, weight=min_weight, operator=">=")) def build_lowest_xor_differential_trails_with_at_most_weight(self, fixed_weight, fixed_variables): """ @@ -276,6 +279,7 @@ def build_lowest_xor_differential_trails_with_at_most_weight(self, fixed_weight, sage: from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher sage: from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_differential_model_arx_optimized import MznXorDifferentialModelARXOptimized + sage: from claasp.cipher_modules.models.cp.solvers import CPSAT sage: speck = SpeckBlockCipher(number_of_rounds=5, block_bit_size=32, key_bit_size=64) sage: minizinc = MznXorDifferentialModelARXOptimized(speck) sage: bit_positions = [i for i in range(speck.output_bit_size)] @@ -293,7 +297,7 @@ def build_lowest_xor_differential_trails_with_at_most_weight(self, fixed_weight, sage: minizinc.build_lowest_xor_differential_trails_with_at_most_weight( ....: 100, fixed_variables ....: ) - sage: result = minizinc.solve_for_ARX('Xor') + sage: result = minizinc.solve_for_ARX(CPSAT) sage: result.statistics['nSolutions'] > 1 True """ @@ -327,7 +331,7 @@ def connect_rounds(self): continue ninputs = component.input_bit_size - input_vars = [f'{component.id}_{self.input_postfix}{i}' for i in range(ninputs)] + input_vars = [f"{component.id}_{self.input_postfix}{i}" for i in range(ninputs)] input_links = component.input_id_links input_positions = component.input_bit_positions prev_input_vars = [] @@ -335,7 +339,7 @@ def connect_rounds(self): for k in range(len(input_links)): prev_input_vars += [input_links[k] + "_" + self.output_postfix + str(i) for i in input_positions[k]] - connect_rounds_constraints += [f'constraint {x} = {y};' for (x, y) in zip(input_vars, prev_input_vars)] + connect_rounds_constraints += [f"constraint {x} = {y};" for (x, y) in zip(input_vars, prev_input_vars)] return connect_rounds_constraints @@ -358,6 +362,7 @@ def find_all_xor_differential_trails_with_fixed_weight(self, fixed_weight, fixed sage: from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_differential_model_arx_optimized import MznXorDifferentialModelARXOptimized sage: speck = SpeckBlockCipher(number_of_rounds=5, block_bit_size=32, key_bit_size=64) sage: minizinc = MznXorDifferentialModelARXOptimized(speck) + sage: from claasp.cipher_modules.models.cp.solvers import CPSAT sage: bit_positions = [i for i in range(speck.output_bit_size)] sage: bit_positions_key = list(range(64)) sage: fixed_variables = [{ 'component_id': 'plaintext', @@ -371,7 +376,7 @@ def find_all_xor_differential_trails_with_fixed_weight(self, fixed_weight, fixed ....: 'operator': '=', ....: 'value': '0' }) sage: result = minizinc.find_all_xor_differential_trails_with_fixed_weight( - ....: 5, solver_name='Xor', fixed_values=fixed_variables + ....: 5, solver_name=CPSAT, fixed_values=fixed_variables ....: ) sage: print(result['total_weight']) None @@ -381,14 +386,20 @@ def find_all_xor_differential_trails_with_fixed_weight(self, fixed_weight, fixed result = self.solve_for_ARX(solver_name=solver_name, all_solutions_=True) total_weight = MznXorDifferentialModelARXOptimized._get_total_weight(result) parsed_result = MznXorDifferentialModelARXOptimized._parse_result( - result, solver_name, total_weight, 'xor_differential', self._variables_list, self.cipher_id, - self.probability_vars + result, + solver_name, + total_weight, + "xor_differential", + self._variables_list, + self.cipher_id, + self.probability_vars, ) return parsed_result - def find_all_xor_differential_trails_with_weight_at_most(self, min_weight, max_weight=64, - fixed_values=[], solver_name=None): + def find_all_xor_differential_trails_with_weight_at_most( + self, min_weight, max_weight=64, fixed_values=[], solver_name=None + ): """ Return all XOR differential trails with weight greater than ``min_weight`` and lower/equal to ``max_weight``. @@ -408,6 +419,7 @@ def find_all_xor_differential_trails_with_weight_at_most(self, min_weight, max_w sage: from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_differential_model_arx_optimized import MznXorDifferentialModelARXOptimized sage: speck = SpeckBlockCipher(number_of_rounds=4, block_bit_size=32, key_bit_size=64) sage: minizinc = MznXorDifferentialModelARXOptimized(speck) + sage: from claasp.cipher_modules.models.cp.solvers import CPSAT sage: bit_positions = list(range(32)) sage: bit_positions_key = list(range(64)) sage: fixed_variables = [{ 'component_id': 'plaintext', @@ -421,38 +433,45 @@ def find_all_xor_differential_trails_with_weight_at_most(self, min_weight, max_w ....: 'operator': '=', ....: 'value': '0' }) sage: result = minizinc.find_all_xor_differential_trails_with_weight_at_most( - ....: 1, solver_name='Xor', fixed_values=fixed_variables + ....: 1, solver_name=CPSAT, fixed_values=fixed_variables ....: ) sage: result[0]['total_weight'] > 1 True """ self.build_xor_differential_trail_model(-1, fixed_values) - self._model_constraints.extend( - self.weight_constraints(min_weight, ">", max_weight)) + self._model_constraints.extend(self.weight_constraints(min_weight, ">", max_weight)) result = self.solve_for_ARX(solver_name=solver_name, all_solutions_=True) total_weight = MznXorDifferentialModelARXOptimized._get_total_weight(result) parsed_result = self._parse_result( - result, solver_name, total_weight, 'xor_differential', self._variables_list, self.cipher_id, - self.probability_vars + result, + solver_name, + total_weight, + "xor_differential", + self._variables_list, + self.cipher_id, + self.probability_vars, ) return parsed_result - def find_min_of_max_xor_differential_between_permutation_and_key_schedule( - self, fixed_values=[], solver_name=None - ): + def find_min_of_max_xor_differential_between_permutation_and_key_schedule(self, fixed_values=[], solver_name=None): self.constraint_permutation_and_key_schedule_separately_by_input_sizes() self.build_xor_differential_trail_model(-1, fixed_values) - self._model_constraints.extend(self.objective_generator(strategy='min_max_key_schedule_permutation')) + self._model_constraints.extend(self.objective_generator(strategy="min_max_key_schedule_permutation")) self._model_constraints.extend(self.weight_constraints()) result = self.solve_for_ARX(solver_name=solver_name) total_weight = self._get_total_weight(result) parsed_result = MznXorDifferentialModelARXOptimized._parse_result( - result, solver_name, total_weight, 'xor_differential', self._variables_list, self.cipher_id, - self.probability_vars + result, + solver_name, + total_weight, + "xor_differential", + self._variables_list, + self.cipher_id, + self.probability_vars, ) - parsed_result['objective_strategy'] = 'min_max_key_schedule_permutation' + parsed_result["objective_strategy"] = "min_max_key_schedule_permutation" return parsed_result @@ -472,6 +491,7 @@ def find_lowest_weight_xor_differential_trail(self, fixed_values=[], solver_name sage: from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_differential_model_arx_optimized import MznXorDifferentialModelARXOptimized sage: speck = SpeckBlockCipher(number_of_rounds=5, block_bit_size=32, key_bit_size=64) sage: minizinc = MznXorDifferentialModelARXOptimized(speck) + sage: from claasp.cipher_modules.models.cp.solvers import CPSAT sage: bit_positions = list(range(32)) sage: bit_positions_key = list(range(64)) sage: fixed_variables = [{ 'component_id': 'plaintext', @@ -485,13 +505,13 @@ def find_lowest_weight_xor_differential_trail(self, fixed_values=[], solver_name ....: 'operator': '=', ....: 'value': '0' }) sage: result = minizinc.find_lowest_weight_xor_differential_trail( - ....: solver_name='Xor', fixed_values=fixed_variables + ....: solver_name=CPSAT, fixed_values=fixed_variables ....: ) sage: result["total_weight"] 9 sage: minizinc = MznXorDifferentialModelARXOptimized(speck, [0, 0, 0, 0, 0]) - sage: result = minizinc.find_lowest_weight_xor_differential_trail(solver_name='Xor', fixed_values=fixed_variables) + sage: result = minizinc.find_lowest_weight_xor_differential_trail(solver_name=CPSAT, fixed_values=fixed_variables) sage: result["total_weight"] 9 """ @@ -501,8 +521,13 @@ def find_lowest_weight_xor_differential_trail(self, fixed_values=[], solver_name result = self.solve_for_ARX(solver_name=solver_name) total_weight = MznXorDifferentialModelARXOptimized._get_total_weight(result) parsed_result = self._parse_result( - result, solver_name, total_weight, 'xor_differential', self._variables_list, self.cipher_id, - self.probability_vars + result, + solver_name, + total_weight, + "xor_differential", + self._variables_list, + self.cipher_id, + self.probability_vars, ) return parsed_result @@ -510,15 +535,21 @@ def find_lowest_weight_xor_differential_trail(self, fixed_values=[], solver_name def init_constraints(self): output_string_for_cipher_inputs = [] for i in range(len(self._cipher.inputs)): - var_names_inputs = [self._cipher.inputs[i] + "_" + self.output_postfix + str(j) - for j in range(self._cipher.inputs_bit_size[i])] - output_string_for_cipher_input = \ - "output [\"cipher_input:" + self._cipher.inputs[i] + ":\" ++ show(" + \ - MznXorDifferentialModelARXOptimized._create_minizinc_1d_array_from_list(var_names_inputs) + ")++\"\\n\"];\n" + var_names_inputs = [ + self._cipher.inputs[i] + "_" + self.output_postfix + str(j) + for j in range(self._cipher.inputs_bit_size[i]) + ] + output_string_for_cipher_input = ( + 'output ["cipher_input:' + + self._cipher.inputs[i] + + ':" ++ show(' + + MznXorDifferentialModelARXOptimized._create_minizinc_1d_array_from_list(var_names_inputs) + + ')++"\\n"];\n' + ) output_string_for_cipher_inputs.append(output_string_for_cipher_input) for ii in range(len(var_names_inputs)): - self._variables_list.extend([f'var {self.data_type}: {var_names_inputs[ii]};']) + self._variables_list.extend([f"var {self.data_type}: {var_names_inputs[ii]};"]) self._model_constraints.extend(self.connect_rounds()) if self.sat_or_milp == "sat": @@ -528,8 +559,9 @@ def init_constraints(self): if self.include_word_operations_mzn_file: self._model_constraints.extend([get_word_operations()]) - self._model_constraints.extend([ - f'output [ \"{self.cipher_id}, and window_size={self.window_size_list}\" ++ \"\\n\"];']) + self._model_constraints.extend( + [f'output [ "{self.cipher_id}, and window_size={self.window_size_list}" ++ "\\n"];'] + ) self._model_constraints.extend(output_string_for_cipher_inputs) def get_probability_vars_from_permutation(self): @@ -538,9 +570,9 @@ def get_probability_vars_from_permutation(self): permutation_components = cipher_permutation.get_all_components() probability_vars_from_permutation = [] for permutation_component in permutation_components: - if permutation_component.id.startswith('modadd') or permutation_component.id.startswith('modsub'): + if permutation_component.id.startswith("modadd") or permutation_component.id.startswith("modsub"): for probability_var in self.probability_vars: - if probability_var.startswith(f'p_{permutation_component.id}'): + if probability_var.startswith(f"p_{permutation_component.id}"): probability_vars_from_permutation.append(probability_var) return probability_vars_from_permutation @@ -562,9 +594,9 @@ def get_probability_vars_from_key_schedule(self): key_schedule_ids = set(all_components_ids) - set(permutation_component_ids) key_schedule_prob_var_ids = [] for key_schedule_id in key_schedule_ids: - if key_schedule_id.startswith('modadd') or key_schedule_id.startswith('modsub'): + if key_schedule_id.startswith("modadd") or key_schedule_id.startswith("modsub"): for probability_var in self.probability_vars: - if probability_var.startswith(f'p_{key_schedule_id}'): + if probability_var.startswith(f"p_{key_schedule_id}"): key_schedule_prob_var_ids.append(probability_var) return key_schedule_prob_var_ids @@ -574,24 +606,26 @@ def constraint_permutation_and_key_schedule_separately_by_input_sizes(self): permutation_probability_vars = list(set(self.get_probability_vars_from_permutation())) modadd_key_schedule_concatenation_vars = "++".join(key_schedule_probability_vars) modadd_permutation_probability_vars = "++".join(permutation_probability_vars) - key_index = self.cipher.inputs.index('key') - plaintext_index = self.cipher.inputs.index('plaintext') + key_index = self.cipher.inputs.index("key") + plaintext_index = self.cipher.inputs.index("plaintext") key_input_bit_size = self.cipher.inputs_bit_size[key_index] plaintext_input_bit_size = self.cipher.inputs_bit_size[plaintext_index] - self._model_constraints.append(f'sum({modadd_key_schedule_concatenation_vars}) <= {key_input_bit_size};') - self._model_constraints.append(f'sum({modadd_permutation_probability_vars}) <= {plaintext_input_bit_size};') + self._model_constraints.append(f"sum({modadd_key_schedule_concatenation_vars}) <= {key_input_bit_size};") + self._model_constraints.append(f"sum({modadd_permutation_probability_vars}) <= {plaintext_input_bit_size};") - def objective_generator(self, strategy='min_all_probabilities'): - if strategy == 'min_all_probabilities': + def objective_generator(self, strategy="min_all_probabilities"): + if strategy == "min_all_probabilities": objective_string = [] modular_addition_concatenation = "++".join(self.probability_vars) - objective_string.append(f'solve:: int_search({modular_addition_concatenation},' - f' smallest, indomain_min, complete)') - objective_string.append(f'minimize sum({modular_addition_concatenation});') - self.mzn_output_directives.append(f'output ["Total_Probability: "++show(sum(' - f'{modular_addition_concatenation}))];') - elif strategy == 'min_max_key_schedule_permutation': + objective_string.append( + f"solve:: int_search({modular_addition_concatenation}, smallest, indomain_min, complete)" + ) + objective_string.append(f"minimize sum({modular_addition_concatenation});") + self.mzn_output_directives.append( + f'output ["Total_Probability: "++show(sum({modular_addition_concatenation}))];' + ) + elif strategy == "min_max_key_schedule_permutation": objective_string = [] modular_addition_concatenation = "++".join(self.probability_vars) key_schedule_probability_vars = list(set(self.get_probability_vars_from_key_schedule())) @@ -599,10 +633,13 @@ def objective_generator(self, strategy='min_all_probabilities'): modadd_key_schedule_concatenation_vars = "++".join(key_schedule_probability_vars) modadd_permutation_probability_vars = "++".join(permutation_probability_vars) - objective_string.append(f'solve:: int_search({modular_addition_concatenation},' - f' smallest, indomain_min, complete)') + objective_string.append( + f"solve:: int_search({modular_addition_concatenation}, smallest, indomain_min, complete)" + ) - objective_string.append(f'minimize max(sum({modadd_key_schedule_concatenation_vars}), sum({modadd_permutation_probability_vars}));') + objective_string.append( + f"minimize max(sum({modadd_key_schedule_concatenation_vars}), sum({modadd_permutation_probability_vars}));" + ) else: raise NotImplementedError("Strategy {strategy} no implemented") @@ -612,13 +649,13 @@ def objective_generator(self, strategy='min_all_probabilities'): def parse_probability_vars(result, solution, probability_vars): parsed_result = {} if result.status not in [Status.UNKNOWN, Status.UNSATISFIABLE, Status.ERROR]: - for probability_var in probability_vars: lst_value = solution.__dict__[probability_var] parsed_result[probability_var] = { - 'value': str(hex(int("".join(str(0) if str(x) in ["false", "0"] else str(1) for x in lst_value), - 2))), - 'weight': sum(lst_value) + "value": str( + hex(int("".join(str(0) if str(x) in ["false", "0"] else str(1) for x in lst_value), 2)) + ), + "weight": sum(lst_value), } return parsed_result @@ -626,11 +663,13 @@ def parse_probability_vars(result, solution, probability_vars): def satisfy_generator(self): objective_string = [] modular_addition_concatenation = "++".join(self.probability_vars) - objective_string.append(f'solve:: int_search({modular_addition_concatenation},' - f' smallest, indomain_min, complete)') - objective_string.append(f'satisfy;') - self.mzn_output_directives.append(f'output ["Total_Probability: "++show(sum(' - f'{modular_addition_concatenation}))];') + objective_string.append( + f"solve:: int_search({modular_addition_concatenation}, smallest, indomain_min, complete)" + ) + objective_string.append("satisfy;") + self.mzn_output_directives.append( + f'output ["Total_Probability: "++show(sum({modular_addition_concatenation}))];' + ) return objective_string @@ -658,44 +697,48 @@ def weight_constraints(self, weight=None, operator="=", max_weight=None): modular_addition_concatenation = "++".join(self.probability_vars) if weight is not None: - objective_string.append(f'constraint sum({modular_addition_concatenation}) {operator} {weight};') + objective_string.append(f"constraint sum({modular_addition_concatenation}) {operator} {weight};") if max_weight is not None: - objective_string.append(f'constraint sum({modular_addition_concatenation}) < {max_weight};') + objective_string.append(f"constraint sum({modular_addition_concatenation}) < {max_weight};") if self.probability_weight_per_round: for index, mzn_probability_modadd_vars in enumerate(self.probability_modadd_vars_per_round): weights_per_round = self.probability_weight_per_round[index] - min_weight_per_round = weights_per_round['min_bound'] - max_weight_per_round = weights_per_round['max_bound'] + min_weight_per_round = weights_per_round["min_bound"] + max_weight_per_round = weights_per_round["max_bound"] mzn_probability_vars_per_round = "++".join(mzn_probability_modadd_vars) - objective_string.append(f'constraint sum({mzn_probability_vars_per_round}) <= {max_weight_per_round};') - objective_string.append(f'constraint sum({mzn_probability_vars_per_round}) >= {min_weight_per_round};') + objective_string.append(f"constraint sum({mzn_probability_vars_per_round}) <= {max_weight_per_round};") + objective_string.append(f"constraint sum({mzn_probability_vars_per_round}) >= {min_weight_per_round};") - self.mzn_output_directives.append(f'output ["\\n"++"Probability: "++show(sum(' - f'{modular_addition_concatenation}))++"\\n"];') + self.mzn_output_directives.append( + f'output ["\\n"++"Probability: "++show(sum({modular_addition_concatenation}))++"\\n"];' + ) return objective_string def set_max_number_of_nonlinear_carries(self, max_number_of_nonlinear_carries): carries_vars = self.carries_vars concatenated_str = "array[1.." - sizes_sum = sum(var['mzn_carry_array_size'] for var in carries_vars) + sizes_sum = sum(var["mzn_carry_array_size"] for var in carries_vars) concatenated_str += str(sizes_sum) + "] of var bool: concatenated_carries = " - concatenated_str += " ++ ".join(var['mzn_carry_array_name'] for var in carries_vars) + ";\n" - aux_x_definition_str = f'array[1..{sizes_sum}] of var bool: x_carries;\n' - cluster_constraint = (f'constraint forall(i in 1..{sizes_sum}) (' - f'x_carries[i]<->(concatenated_carries[i] /\\ (i == 1 \\/ not concatenated_carries[i-1]))' - f');\n') + concatenated_str += " ++ ".join(var["mzn_carry_array_name"] for var in carries_vars) + ";\n" + aux_x_definition_str = f"array[1..{sizes_sum}] of var bool: x_carries;\n" + cluster_constraint = ( + f"constraint forall(i in 1..{sizes_sum}) (" + f"x_carries[i]<->(concatenated_carries[i] /\\ (i == 1 \\/ not concatenated_carries[i-1]))" + f");\n" + ) self._variables_list.append(concatenated_str) self._variables_list.append(aux_x_definition_str) self._model_constraints.append(cluster_constraint) - self._model_constraints.append(f'constraint sum(i in 1..{sizes_sum})' - f'(bool2int(x_carries[i])) <= {max_number_of_nonlinear_carries};\n') + self._model_constraints.append( + f"constraint sum(i in 1..{sizes_sum})(bool2int(x_carries[i])) <= {max_number_of_nonlinear_carries};\n" + ) def set_max_number_of_carries_on_arx_cipher(self, max_number_of_carries): - concatenated_str = " ++ ".join(var['mzn_carry_array_name'] for var in self.carries_vars) - self._model_constraints.append(f'constraint sum({concatenated_str}) <= {max_number_of_carries};\n') + concatenated_str = " ++ ".join(var["mzn_carry_array_name"] for var in self.carries_vars) + self._model_constraints.append(f"constraint sum({concatenated_str}) <= {max_number_of_carries};\n") def extend_variables(self, variables): self._variables_list.extend(variables) diff --git a/claasp/cipher_modules/models/cp/mzn_models/mzn_xor_differential_number_of_active_sboxes_model.py b/claasp/cipher_modules/models/cp/mzn_models/mzn_xor_differential_number_of_active_sboxes_model.py index d0deb51ab..fcf9ec382 100644 --- a/claasp/cipher_modules/models/cp/mzn_models/mzn_xor_differential_number_of_active_sboxes_model.py +++ b/claasp/cipher_modules/models/cp/mzn_models/mzn_xor_differential_number_of_active_sboxes_model.py @@ -1,17 +1,16 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -21,8 +20,8 @@ from claasp.input import Input from claasp.component import Component -from claasp.cipher_modules.models.cp.mzn_model import MznModel, solve_satisfy -from claasp.name_mappings import (CONSTANT, INTERMEDIATE_OUTPUT, CIPHER_OUTPUT, SBOX, MIX_COLUMN, WORD_OPERATION) +from claasp.cipher_modules.models.cp.mzn_model import MznModel, SOLVE_SATISFY +from claasp.name_mappings import CONSTANT, INTERMEDIATE_OUTPUT, CIPHER_OUTPUT, SBOX, MIX_COLUMN, WORD_OPERATION def build_xor_truncated_table(numadd): @@ -42,19 +41,20 @@ def build_xor_truncated_table(numadd): sage: build_xor_truncated_table(3) 'array[0..4, 1..3] of int: xor_truncated_table_3 = array2d(0..4, 1..3, [0,0,0,0,1,1,1,0,1,1,1,0,1,1,1]);' """ - size = 2 ** numadd - binary_list = (f'{i:0{numadd}b}' for i in range(size)) - table_items = [','.join(i) for i in binary_list if i.count('1') != 1] - table = ','.join(table_items) - xor_table = f'array[0..{size - numadd - 1}, 1..{numadd}] of int: ' \ - f'xor_truncated_table_{numadd} = array2d(0..{size - numadd - 1}, 1..{numadd}, ' \ - f'[{table}]);' + size = 2**numadd + binary_list = (f"{i:0{numadd}b}" for i in range(size)) + table_items = [",".join(i) for i in binary_list if i.count("1") != 1] + table = ",".join(table_items) + xor_table = ( + f"array[0..{size - numadd - 1}, 1..{numadd}] of int: " + f"xor_truncated_table_{numadd} = array2d(0..{size - numadd - 1}, 1..{numadd}, " + f"[{table}]);" + ) return xor_table class MznXorDifferentialNumberOfActiveSboxesModel(MznModel): - def __init__(self, cipher): self._first_step = [] self._first_step_find_all_solutions = [] @@ -96,7 +96,9 @@ def add_additional_xor_constraints(self, nmax, repetition): if self.list_of_xor_components == temp_list_of_xor_components: break - def build_xor_differential_trail_first_step_model(self, weight=-1, fixed_variables=[], nmax=2, repetition=1, possible_sboxes=0): + def build_xor_differential_trail_first_step_model( + self, weight=-1, fixed_variables=[], nmax=2, repetition=1, possible_sboxes=0 + ): """ Build the CP Model for the second step of the search of XOR differential trail of an SPN cipher. @@ -132,18 +134,19 @@ def build_xor_differential_trail_first_step_model(self, weight=-1, fixed_variabl self._variables_list = [] self.c = 0 self.table_of_solutions_length = 0 - constraints = self.fix_variables_value_constraints(fixed_variables, 'first_step') + constraints = self.fix_variables_value_constraints(fixed_variables, "first_step") self._first_step = constraints self._variables_list.extend(self.input_xor_differential_first_step_constraints(possible_sboxes)) for component in self._cipher.get_all_components(): component_types = [CONSTANT, INTERMEDIATE_OUTPUT, CIPHER_OUTPUT, SBOX, MIX_COLUMN, WORD_OPERATION] operation = component.description[0] - operation_types = ['ROTATE', 'SHIFT', 'XOR', 'NOT'] - if component.type not in component_types or \ - (component.type == WORD_OPERATION and operation not in operation_types): - print(f'{component.id} not yet implemented') - elif component.type == WORD_OPERATION and operation == 'XOR': + operation_types = ["ROTATE", "SHIFT", "XOR", "NOT"] + if component.type not in component_types or ( + component.type == WORD_OPERATION and operation not in operation_types + ): + print(f"{component.id} not yet implemented") + elif component.type == WORD_OPERATION and operation == "XOR": variables, constraints = component.cp_transform_xor_components_for_first_step(self) else: variables, constraints = component.cp_xor_differential_propagation_first_step_constraints(self) @@ -156,8 +159,7 @@ def build_xor_differential_trail_first_step_model(self, weight=-1, fixed_variabl self._variables_list.extend(variables) self._first_step.append(constraints) self._first_step.extend(self.final_xor_differential_first_step_constraints(weight)) - self._first_step = \ - self._model_prefix + self._variables_list + self._first_step + self._first_step = self._model_prefix + self._variables_list + self._first_step def create_xor_component(self, component1, component2, nmax): """ @@ -201,9 +203,10 @@ def create_xor_component(self, component1, component2, nmax): for input_bit in input_bit_positions: input_len += len(input_bit) component_input = Input(input_len, input_id_link, input_bit_positions) - xor_component = Component("", "word_operation", component_input, input_len, ['XOR', new_numb_of_inp]) - xor_components_dictionaries = [component.as_python_dictionary() - for component in self.list_of_xor_components] + xor_component = Component("", "word_operation", component_input, input_len, ["XOR", new_numb_of_inp]) + xor_components_dictionaries = [ + component.as_python_dictionary() for component in self.list_of_xor_components + ] if xor_component.as_python_dictionary() not in xor_components_dictionaries: self.list_of_xor_components.append(xor_component) @@ -229,17 +232,21 @@ def final_xor_differential_first_step_constraints(self, weight=-1): 'solve minimize number_of_active_sBoxes;', 'output[show(number_of_active_sBoxes) ++ "\\n" ++ " table_of_solution_length = "++ show(table_of_solutions_length)];'] """ - inputs = '+'.join([f'{input_[0]}' for input_ in self.input_sbox]) - cp_constraints = [f'constraint number_of_active_sBoxes = {inputs};', - f'int: table_of_solutions_length = {self.table_of_solutions_length};'] + inputs = "+".join([f"{input_[0]}" for input_ in self.input_sbox]) + cp_constraints = [ + f"constraint number_of_active_sBoxes = {inputs};", + f"int: table_of_solutions_length = {self.table_of_solutions_length};", + ] if weight == -1: - cp_constraints.append('solve minimize number_of_active_sBoxes;') + cp_constraints.append("solve minimize number_of_active_sBoxes;") else: - cp_constraints.append(solve_satisfy) - new_constraint = 'output[show(number_of_active_sBoxes) ++ \"\\n\" ++ \" ' \ - 'table_of_solution_length = \"++ show(table_of_solutions_length)' \ - + ''.join([f' ++ \"\\n\" ++ \" {input_[0]} = \"++ show({input_[0]})' - for input_ in self.input_sbox]) + '];' + cp_constraints.append(SOLVE_SATISFY) + new_constraint = ( + 'output[show(number_of_active_sBoxes) ++ "\\n" ++ " ' + 'table_of_solution_length = "++ show(table_of_solutions_length)' + + "".join([f' ++ "\\n" ++ " {input_[0]} = "++ show({input_[0]})' for input_ in self.input_sbox]) + + "];" + ) cp_constraints.append(new_constraint) return cp_constraints @@ -249,19 +256,22 @@ def get_new_xor_input_links_and_positions(self, all_inputs, new_numb_of_inp): input_bit_positions = [[] for _ in range(new_numb_of_inp)] input_index = 0 for i in range(new_numb_of_inp): - divide = all_inputs[i].partition('[') + divide = all_inputs[i].partition("[") new_input_name = divide[0] new_input_bit_positions = divide[2][:-1] if new_input_name not in input_id_link: input_id_link.append(new_input_name) - input_bit_positions[input_index] += [int(new_input_bit_positions) * self.word_size + j - for j in range(self.word_size)] + input_bit_positions[input_index] += [ + int(new_input_bit_positions) * self.word_size + j for j in range(self.word_size) + ] input_index = input_index + 1 else: for j, present_input in enumerate(input_id_link): if present_input == new_input_name: - input_bit_positions[j] += [int(new_input_bit_positions) * self.word_size + word_size - for word_size in range(self.word_size)] + input_bit_positions[j] += [ + int(new_input_bit_positions) * self.word_size + word_size + for word_size in range(self.word_size) + ] input_bit_positions = [x for x in input_bit_positions if x != []] return input_id_link, input_bit_positions @@ -273,11 +283,19 @@ def get_xor_all_inputs(self, component1, component2): input_bit_positions_2 = component2.input_bit_positions old_all_inputs = [] for id_link, bit_positions in zip(input_id_links_1, input_bit_positions_1): - old_all_inputs.extend([f'{id_link}[{bit_positions[j * self.word_size] // self.word_size}]' - for j in range(len(bit_positions) // self.word_size)]) + old_all_inputs.extend( + [ + f"{id_link}[{bit_positions[j * self.word_size] // self.word_size}]" + for j in range(len(bit_positions) // self.word_size) + ] + ) for id_link, bit_positions in zip(input_id_links_2, input_bit_positions_2): - old_all_inputs.extend([f'{id_link}[{bit_positions[j * self.word_size] // self.word_size}]' - for j in range(len(bit_positions) // self.word_size)]) + old_all_inputs.extend( + [ + f"{id_link}[{bit_positions[j * self.word_size] // self.word_size}]" + for j in range(len(bit_positions) // self.word_size) + ] + ) all_inputs = [] for old_input in old_all_inputs: if old_input not in all_inputs: @@ -308,21 +326,27 @@ def input_xor_differential_first_step_constraints(self, possible_sboxes=0): 'array[0..15] of var 0..1: plaintext;'] """ if possible_sboxes != 0: - number_of_active_sBoxes_declaration = 'var {' + number_of_active_sboxes_declaration = "var {" for sboxes_n in possible_sboxes: - number_of_active_sBoxes_declaration += str(sboxes_n) - number_of_active_sBoxes_declaration += ', ' - number_of_active_sBoxes_declaration = number_of_active_sBoxes_declaration[:-2] + '}: number_of_active_sBoxes;' - cp_declarations = [number_of_active_sBoxes_declaration] + number_of_active_sboxes_declaration += str(sboxes_n) + number_of_active_sboxes_declaration += ", " + number_of_active_sboxes_declaration = ( + number_of_active_sboxes_declaration[:-2] + "}: number_of_active_sBoxes;" + ) + cp_declarations = [number_of_active_sboxes_declaration] else: active_sboxes_count = 0 for component in self._cipher.get_all_components(): if SBOX in component.type: input_bit_positions = component.input_bit_positions active_sboxes_count += len(input_bit_positions) - cp_declarations = [f'var 1..{active_sboxes_count}: number_of_active_sBoxes;'] - cp_declarations.extend([f'array[0..{bit_size // self.word_size - 1}] of var 0..1: {input_};' - for input_, bit_size in zip(self._cipher.inputs, self._cipher.inputs_bit_size)]) + cp_declarations = [f"var 1..{active_sboxes_count}: number_of_active_sBoxes;"] + cp_declarations.extend( + [ + f"array[0..{bit_size // self.word_size - 1}] of var 0..1: {input_};" + for input_, bit_size in zip(self._cipher.inputs, self._cipher.inputs_bit_size) + ] + ) return cp_declarations @@ -352,12 +376,18 @@ def xor_xor_differential_first_step_constraints(self, component): numadd = description[1] all_inputs = [] for id_link, bit_positions in zip(input_id_links, input_bit_positions): - all_inputs.extend([f'{id_link}[{bit_positions[j * self.word_size] // self.word_size}]' - for j in range(len(bit_positions) // self.word_size)]) + all_inputs.extend( + [ + f"{id_link}[{bit_positions[j * self.word_size] // self.word_size}]" + for j in range(len(bit_positions) // self.word_size) + ] + ) input_len = len(all_inputs) // numadd - cp_constraints = 'constraint table(' \ - + '++'.join([f'[{all_inputs[input_len * j]}]' for j in range(numadd)]) \ - + f', xor_truncated_table_{numadd});' + cp_constraints = ( + "constraint table(" + + "++".join([f"[{all_inputs[input_len * j]}]" for j in range(numadd)]) + + f", xor_truncated_table_{numadd});" + ) xor_table = build_xor_truncated_table(numadd) cp_declarations = [] if xor_table not in self._variables_list: diff --git a/claasp/cipher_modules/models/cp/mzn_models/mzn_xor_differential_trail_search_fixing_number_of_active_sboxes_model.py b/claasp/cipher_modules/models/cp/mzn_models/mzn_xor_differential_trail_search_fixing_number_of_active_sboxes_model.py index 6632abfd1..e673c0f66 100644 --- a/claasp/cipher_modules/models/cp/mzn_models/mzn_xor_differential_trail_search_fixing_number_of_active_sboxes_model.py +++ b/claasp/cipher_modules/models/cp/mzn_models/mzn_xor_differential_trail_search_fixing_number_of_active_sboxes_model.py @@ -1,49 +1,52 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** - import os -import math import subprocess import time as tm from copy import deepcopy -from sage.crypto.sbox import SBox - - -from claasp.name_mappings import XOR_DIFFERENTIAL, CONSTANT, SBOX, WORD_OPERATION -from claasp.cipher_modules.models.cp.mzn_model import solve_satisfy -from claasp.cipher_modules.models.utils import write_model_to_file, convert_solver_solution_to_dictionary -from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_differential_model import ( - MznXorDifferentialModel, update_and_or_ddt_valid_probabilities) +from claasp.cipher_modules.models.cp.mzn_model import SOLVE_SATISFY +from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_differential_model import MznXorDifferentialModel from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_differential_number_of_active_sboxes_model import ( - MznXorDifferentialNumberOfActiveSboxesModel) -from claasp.cipher_modules.models.cp.solvers import CP_SOLVERS_EXTERNAL, CP_SOLVERS_INTERNAL, MODEL_DEFAULT_PATH, SOLVER_DEFAULT - - -class MznXorDifferentialFixingNumberOfActiveSboxesModel(MznXorDifferentialModel, - MznXorDifferentialNumberOfActiveSboxesModel): - + MznXorDifferentialNumberOfActiveSboxesModel, +) +from claasp.cipher_modules.models.cp.solvers import ( + CP_SOLVERS_EXTERNAL, + SOLVER_DEFAULT, +) +from claasp.cipher_modules.models.utils import convert_solver_solution_to_dictionary +from claasp.name_mappings import UNSATISFIABLE, XOR_DIFFERENTIAL + + +class MznXorDifferentialFixingNumberOfActiveSboxesModel( + MznXorDifferentialModel, MznXorDifferentialNumberOfActiveSboxesModel +): def __init__(self, cipher): self._table_items = [] super().__init__(cipher) + def initialise_model(self): + self._table_items = [] + self._first_step = [] + self._first_step_find_all_solutions = [] + super().initialise_model() + def build_xor_differential_trail_second_step_model(self, weight=-1, fixed_variables=[]): """ Build the CP Model for the second step of the search of XOR differential trail of an SPN cipher. @@ -77,7 +80,18 @@ def build_xor_differential_trail_second_step_model(self, weight=-1, fixed_variab self._model_constraints.extend(self.final_xor_differential_constraints(weight)) self._model_constraints = self._model_prefix + self._variables_list + self._model_constraints - def find_all_xor_differential_trails_with_fixed_weight(self, fixed_weight, fixed_values=[], first_step_solver_name=SOLVER_DEFAULT, second_step_solver_name=SOLVER_DEFAULT, nmax=2, repetition=1, num_of_processors=None, timelimit=None, solve_with_API=False): + def find_all_xor_differential_trails_with_fixed_weight( + self, + fixed_weight, + fixed_values=[], + first_step_solver_name=SOLVER_DEFAULT, + second_step_solver_name=SOLVER_DEFAULT, + nmax=2, + repetition=1, + num_of_processors=None, + timelimit=None, + solve_with_API=False, + ): """ Return a list of solutions containing all the differential trails having the ``fixed_weight`` weight of correlation. @@ -85,14 +99,9 @@ def find_all_xor_differential_trails_with_fixed_weight(self, fixed_weight, fixed - ``fixed_weight`` -- **integer**; the weight to be fixed - ``fixed_values`` -- **list** (default: `[]`); can be created using ``set_fixed_variables`` method - - ``first_step_solver_name`` -- **string** (default: `Chuffed`); the name of the solver for the number of active sboxes search - - ``second_step_solver_name`` -- **string** (default: `Chuffed`); the name of the solver for the differential trails search. Available values for both the solver names are: - - * ``'Chuffed'`` - * ``'Gecode'`` - * ``'COIN-BC'`` - * ``'Xor'`` - * ``'Choco-solver'`` + - ``first_step_solver_name`` -- **string** (default: `chuffed`); the name of the solver for the number of active sboxes search + - ``second_step_solver_name`` -- **string** (default: `chuffed`); the name of the solver for the differential trails search. + See also :meth:`MznModel.solver_names`. EXAMPLES:: @@ -106,14 +115,34 @@ def find_all_xor_differential_trails_with_fixed_weight(self, fixed_weight, fixed ....: integer_to_bit_list(0, 128, 'little'))] sage: fixed_variables.append(set_fixed_variables('plaintext', 'not_equal', range(128), ....: integer_to_bit_list(0, 128, 'little'))) - sage: trails = cp.find_all_xor_differential_trails_with_fixed_weight(224, fixed_variables, 'Chuffed', 'Chuffed') # long # doctest: +SKIP + sage: trails = cp.find_all_xor_differential_trails_with_fixed_weight(224, fixed_variables, 'chuffed', 'chuffed') # long # doctest: +SKIP ... sage: len(trails) # long # doctest: +SKIP 8 """ - return self.solve_full_two_steps_xor_differential_model('xor_differential_all_solutions', fixed_weight, fixed_values, first_step_solver_name, second_step_solver_name, nmax, repetition, num_of_processors, timelimit) - - def find_lowest_weight_xor_differential_trail(self, fixed_values=[], first_step_solver_name=SOLVER_DEFAULT, second_step_solver_name=SOLVER_DEFAULT, nmax=2, repetition=1, num_of_processors=None, timelimit=None, solve_with_API=False): + return self.solve_full_two_steps_xor_differential_model( + "xor_differential_all_solutions", + fixed_weight, + fixed_values, + first_step_solver_name, + second_step_solver_name, + nmax, + repetition, + num_of_processors, + timelimit, + ) + + def find_lowest_weight_xor_differential_trail( + self, + fixed_values=[], + first_step_solver_name=SOLVER_DEFAULT, + second_step_solver_name=SOLVER_DEFAULT, + nmax=2, + repetition=1, + num_of_processors=None, + timelimit=None, + solve_with_API=False, + ): """ Return the solution representing a differential trail with the lowest weight. @@ -125,11 +154,8 @@ def find_lowest_weight_xor_differential_trail(self, fixed_values=[], first_step_ INPUT: - ``fixed_values`` -- **list** (default: `[]`); can be created using ``set_fixed_variables`` method - - ``solver_name`` -- **string** (default: `Chuffed`); the name of the solver. Available values are: - - * ``'Chuffed'`` - * ``'Gecode'`` - * ``'COIN-BC'`` + - ``solver_name`` -- **string** (default: `chuffed`); the name of the solver. + See also :meth:`MznModel.solver_names`. EXAMPLES:: @@ -143,29 +169,46 @@ def find_lowest_weight_xor_differential_trail(self, fixed_values=[], first_step_ ....: integer_to_bit_list(0, 128, 'little'))] sage: fixed_variables.append(set_fixed_variables('plaintext', 'not_equal', range(128), ....: integer_to_bit_list(0, 128, 'little'))) - sage: cp.find_lowest_weight_xor_differential_trail(fixed_variables, 'Chuffed', 'Chuffed') # random + sage: cp.find_lowest_weight_xor_differential_trail(fixed_variables, 'chuffed', 'chuffed') # random 5 {'cipher': 'aes_block_cipher_k128_p128_o128_r2', 'model_type': 'xor_differential', - 'solver_name': 'Chuffed', + 'solver_name': 'chuffed', 'components_values': {'key': {'value': '00000000000000000000000000000000', 'weight': 0}, ... 'total_weight': '30.0'} """ - return self.solve_full_two_steps_xor_differential_model('xor_differential_one_solution', -1, fixed_values, first_step_solver_name, second_step_solver_name, nmax, repetition, num_of_processors, timelimit) - - def find_one_xor_differential_trail(self, fixed_values=[], first_step_solver_name=SOLVER_DEFAULT, second_step_solver_name=SOLVER_DEFAULT, nmax=2, repetition=1, num_of_processors=None, timelimit=None, solve_with_API=False): + return self.solve_full_two_steps_xor_differential_model( + "xor_differential_one_solution", + -1, + fixed_values, + first_step_solver_name, + second_step_solver_name, + nmax, + repetition, + num_of_processors, + timelimit, + ) + + def find_one_xor_differential_trail( + self, + fixed_values=[], + first_step_solver_name=SOLVER_DEFAULT, + second_step_solver_name=SOLVER_DEFAULT, + nmax=2, + repetition=1, + num_of_processors=None, + timelimit=None, + solve_with_API=False, + ): """ Return the solution representing a differential trail with any weight. INPUT: - ``fixed_values`` -- **list** (default: `[]`); can be created using ``set_fixed_variables`` method - - ``solver_name`` -- **string** (default: `Chuffed`); the name of the solver. Available values are: - - * ``'Chuffed'`` - * ``'Gecode'`` - * ``'COIN-BC'`` + - ``solver_name`` -- **string** (default: `chuffed`); the name of the solver. + See also :meth:`MznModel.solver_names`. EXAMPLES:: @@ -179,16 +222,37 @@ def find_one_xor_differential_trail(self, fixed_values=[], first_step_solver_nam ....: integer_to_bit_list(0, 128, 'little'))] sage: fixed_variables.append(set_fixed_variables('plaintext', 'not_equal', range(128), ....: integer_to_bit_list(0, 128, 'little'))) - sage: cp.find_one_xor_differential_trail(fixed_variables, 'Chuffed', 'Chuffed') # random + sage: cp.find_one_xor_differential_trail(fixed_variables, 'chuffed', 'chuffed') # random {'cipher': 'aes_block_cipher_k128_p128_o128_r2', 'model_type': 'xor_differential', ... 'cipher_output_1_32':{'value': 'ffffffffffffffffffffffffffffffff', 'weight': 0.0}}, 'total_weight': '224.0'} """ - return self.solve_full_two_steps_xor_differential_model('xor_differential_one_solution', 0, fixed_values, first_step_solver_name, second_step_solver_name, nmax, repetition, num_of_processors, timelimit) - - def find_one_xor_differential_trail_with_fixed_weight(self, fixed_weight=-1, fixed_values=[], first_step_solver_name=SOLVER_DEFAULT, second_step_solver_name=SOLVER_DEFAULT, nmax=2, repetition=1, num_of_processors=None, timelimit=None, solve_with_API=False): + return self.solve_full_two_steps_xor_differential_model( + "xor_differential_one_solution", + 0, + fixed_values, + first_step_solver_name, + second_step_solver_name, + nmax, + repetition, + num_of_processors, + timelimit, + ) + + def find_one_xor_differential_trail_with_fixed_weight( + self, + fixed_weight=-1, + fixed_values=[], + first_step_solver_name=SOLVER_DEFAULT, + second_step_solver_name=SOLVER_DEFAULT, + nmax=2, + repetition=1, + num_of_processors=None, + timelimit=None, + solve_with_API=False, + ): """ Return the solution representing a differential trail with any weight. @@ -196,11 +260,8 @@ def find_one_xor_differential_trail_with_fixed_weight(self, fixed_weight=-1, fix - ``fixed_weight`` -- **integer**; the value to which the weight is fixed, if non-negative - ``fixed_values`` -- **list** (default: `[]`); can be created using ``set_fixed_variables`` method - - ``solver_name`` -- **string** (default: `Chuffed`); the name of the solver. Available values are: - - * ``'Chuffed'`` - * ``'Gecode'`` - * ``'COIN-BC'`` + - ``solver_name`` -- **string** (default: `chuffed`); the name of the solver. + See also :meth:`MznModel.solver_names`. EXAMPLES:: @@ -214,17 +275,27 @@ def find_one_xor_differential_trail_with_fixed_weight(self, fixed_weight=-1, fix ....: integer_to_bit_list(0, 128, 'little'))] sage: fixed_variables.append(set_fixed_variables('plaintext', 'not_equal', range(128), ....: integer_to_bit_list(0, 128, 'little'))) - sage: cp.find_one_xor_differential_trail_with_fixed_weight(224, fixed_variables, 'Chuffed', 'Chuffed') # random + sage: cp.find_one_xor_differential_trail_with_fixed_weight(224, fixed_variables, 'chuffed', 'chuffed') # random {'cipher': 'aes_block_cipher_k128_p128_o128_r2', 'model_type': 'xor_differential', - 'solver_name': 'Chuffed', + 'solver_name': 'chuffed', ... 'total_weight': '224.0', 'building_time_seconds': 19.993147134780884} """ - return self.solve_full_two_steps_xor_differential_model('xor_differential_one_solution', fixed_weight, fixed_values, first_step_solver_name, second_step_solver_name, nmax, repetition, num_of_processors, timelimit) - - def generate_table_of_solutions(self, solution, solver_name): + return self.solve_full_two_steps_xor_differential_model( + "xor_differential_one_solution", + fixed_weight, + fixed_values, + first_step_solver_name, + second_step_solver_name, + nmax, + repetition, + num_of_processors, + timelimit, + ) + + def generate_table_of_solutions(self, solution): """ Return a table with the solutions from the first step in the two steps model for xor differential trail search. @@ -235,46 +306,59 @@ def generate_table_of_solutions(self, solution, solver_name): EXAMPLES:: sage: from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_differential_trail_search_fixing_number_of_active_sboxes_model import ( - ....: MznXorDifferentialFixingNumberOfActiveSboxesModel) + ....: MznXorDifferentialFixingNumberOfActiveSboxesModel, + ....: ) sage: from claasp.ciphers.block_ciphers.aes_block_cipher import AESBlockCipher - sage: from claasp.cipher_modules.models.utils import set_fixed_variables, integer_to_bit_list + sage: from claasp.cipher_modules.models.utils import set_fixed_variables sage: aes = AESBlockCipher(number_of_rounds=2) sage: cp = MznXorDifferentialFixingNumberOfActiveSboxesModel(aes) - sage: fixed_variables = [set_fixed_variables('key', 'not_equal', list(range(128)), - ....: integer_to_bit_list(0, 128, 'little'))] - sage: cp.build_xor_differential_trail_first_step_model(-1,fixed_variables) - sage: first_step_solution, solve_time = cp.solve_model('xor_differential_first_step', 'Chuffed') - sage: cp.generate_table_of_solutions(first_step_solution, 'Chuffed') + sage: fixed_variables = [set_fixed_variables('key', 'not_equal', list(range(128)), (0,)*128)] + sage: cp.build_xor_differential_trail_first_step_model(-1, fixed_variables) + sage: first_step_solution, solve_time = cp.solve_model('xor_differential_first_step', 'chuffed') + sage: table_of_solutions = cp.generate_table_of_solutions(first_step_solution) + sage: table_of_solutions[:26] + 'array [0..0, 1..40] of int' """ cipher_name = self.cipher_id - separator = '----------' + separator = "----------" count_separator = solution.count(separator) - table_of_solutions_length = '' + table_of_solutions_length = "" for line in solution: - if 'table_of_solution_length' in line: - line = line.replace(' table_of_solution_length = ', '') - table_of_solutions_length = line.rstrip('\n') - table = f'array [0..{count_separator - 1}, 1..{table_of_solutions_length}] of int: ' \ - f'{cipher_name}_table_of_solutions = ' \ - f'array2d(0..{count_separator - 1}, 1..{table_of_solutions_length}, [' + if "table_of_solution_length" in line: + line = line.replace(" table_of_solution_length = ", "") + table_of_solutions_length = line.rstrip("\n") + table_of_solutions = ( + f"array [0..{count_separator - 1}, 1..{table_of_solutions_length}] of int: " + f"{cipher_name}_table_of_solutions = " + f"array2d(0..{count_separator - 1}, 1..{table_of_solutions_length}, [" + ) for line in solution: for item in self.input_sbox: if item[0] in line: - value = line.replace(item[0], '') - value = value.replace(' = ', '') - table = table + value.replace('\n', '') + ',' - table = table[:-1] + ']);' - with open(f'{cipher_name}_table_of_solutions_{solver_name}.mzn', 'w') as table_of_solutions_file: - table_of_solutions_file.write(table) - - def get_solutions_dictionaries_with_build_time(self, build_time, components_values, memory, solver_name, time, - total_weight): - solutions = [convert_solver_solution_to_dictionary(self.cipher_id, XOR_DIFFERENTIAL, solver_name, time, - memory, components_values[f'solution{i + 1}'], - total_weight[i]) - for i in range(len(total_weight))] + value = line.replace(item[0], "") + value = value.replace(" = ", "") + table_of_solutions = table_of_solutions + value.replace("\n", "") + "," + table_of_solutions = table_of_solutions[:-1] + "]);" + + return table_of_solutions + + def get_solutions_dictionaries_with_build_time( + self, build_time, components_values, memory, solver_name, time, total_weight + ): + solutions = [ + convert_solver_solution_to_dictionary( + self.cipher_id, + XOR_DIFFERENTIAL, + solver_name, + time, + memory, + components_values[f"solution{i + 1}"], + total_weight[i], + ) + for i in range(len(total_weight)) + ] for solution in solutions: - solution['building_time_seconds'] = build_time + solution["building_time_seconds"] = build_time if len(solutions) == 1: solutions = solutions[0] return solutions @@ -302,13 +386,23 @@ def input_xor_differential_constraints(self): """ cp_declarations, cp_constraints = super().input_xor_differential_constraints() - table = '++'.join(self._table_items) - cp_constraints = f'constraint table({table}, {self.cipher_id}_table_of_solutions);' + table = "++".join(self._table_items) + cp_constraints = f"constraint table({table}, {self.cipher_id}_table_of_solutions);" return cp_declarations, cp_constraints - def solve_full_two_steps_xor_differential_model(self, model_type='xor_differential_one_solution', weight=-1, fixed_variables=[], - first_step_solver_name=SOLVER_DEFAULT, second_step_solver_name=SOLVER_DEFAULT, nmax=2, repetition=1, num_of_processors=None, timelimit=None): + def solve_full_two_steps_xor_differential_model( + self, + model_type="xor_differential_one_solution", + weight=-1, + fixed_variables=[], + first_step_solver_name=SOLVER_DEFAULT, + second_step_solver_name=SOLVER_DEFAULT, + nmax=2, + repetition=1, + num_of_processors=None, + timelimit=None, + ): """ Return the solution of the model for an SPN cipher. @@ -334,37 +428,36 @@ def solve_full_two_steps_xor_differential_model(self, model_type='xor_differenti sage: cp = MznXorDifferentialFixingNumberOfActiveSboxesModel(aes) sage: fixed_variables = [set_fixed_variables('key', 'not_equal', list(range(128)), ....: integer_to_bit_list(0, 128, 'little'))] - sage: cp.solve_full_two_steps_xor_differential_model('xor_differential_one_solution', -1, fixed_variables, 'Chuffed', 'Chuffed') # random + sage: cp.solve_full_two_steps_xor_differential_model('xor_differential_one_solution', -1, fixed_variables, 'chuffed', 'chuffed') # random 1 {'cipher': 'aes_block_cipher_k128_p128_o128_r2', ... 'total_weight': '6.0', 'building_time': 3.7489726543426514} """ + self.initialise_model() possible_sboxes = 0 if weight > 0: possible_sboxes = self.find_possible_number_of_active_sboxes(weight) if not possible_sboxes: - raise ValueError('There are no trails with the fixed weight!') + raise ValueError("There are no trails with the fixed weight!") - cipher_name = self.cipher_id start = tm.time() self.build_xor_differential_trail_first_step_model(weight, fixed_variables, nmax, repetition, possible_sboxes) end = tm.time() build_time = end - start - first_step_solution, solve_time = self.solve_model('xor_differential_first_step', first_step_solver_name, num_of_processors, timelimit) + first_step_solution, solve_time = self.solve_model( + "xor_differential_first_step", first_step_solver_name, num_of_processors, timelimit + ) start = tm.time() self.build_xor_differential_trail_second_step_model(weight, fixed_variables) end = tm.time() build_time += end - start - input_file_name = f'{MODEL_DEFAULT_PATH}/{cipher_name}_mzn_xor_differential_{first_step_solver_name}.mzn' - solution_file_name = f'{MODEL_DEFAULT_PATH}/{cipher_name}_table_of_solutions_{first_step_solver_name}.mzn' - write_model_to_file(self._model_constraints, input_file_name) for i in range(len(CP_SOLVERS_EXTERNAL)): - if second_step_solver_name == CP_SOLVERS_EXTERNAL[i]['solver_name']: - command_options = deepcopy(CP_SOLVERS_EXTERNAL[i]) - + if second_step_solver_name == CP_SOLVERS_EXTERNAL[i]["solver_name"]: + command_options = deepcopy(CP_SOLVERS_EXTERNAL[i]["keywords"]["command"]) + for attempt in range(10000): if weight == -1: start = tm.time() @@ -372,48 +465,41 @@ def solve_full_two_steps_xor_differential_model(self, model_type='xor_differenti end = tm.time() build_time += end - start first_step_all_solutions, solve_first_step_time = self.solve_model( - 'xor_differential_first_step_find_all_solutions', first_step_solver_name) + "xor_differential_first_step_find_all_solutions", first_step_solver_name + ) solve_time += solve_first_step_time - self.generate_table_of_solutions(first_step_all_solutions, first_step_solver_name) - - command_options['keywords']['command']['input_file'].append(input_file_name) - command_options['keywords']['command']['output_file'].append(solution_file_name) - command_options['keywords']['command']['options'].insert(0, '-a') - elif model_type == 'xor_differential_all_solutions': - self.generate_table_of_solutions(first_step_solution, first_step_solver_name) - - command_options['keywords']['command']['input_file'].append(input_file_name) - command_options['keywords']['command']['output_file'].append(solution_file_name) - command_options['keywords']['command']['options'].insert(0, '-a') + table_of_solutions = self.generate_table_of_solutions(first_step_all_solutions) + command_options["options"].insert(0, "-a") + elif model_type == "xor_differential_all_solutions": + table_of_solutions = self.generate_table_of_solutions(first_step_solution) + command_options["options"].insert(0, "-a") else: - self.generate_table_of_solutions(first_step_solution, first_step_solver_name) - - command_options['keywords']['command']['input_file'].append(input_file_name) - command_options['keywords']['command']['output_file'].append(solution_file_name) + table_of_solutions = self.generate_table_of_solutions(first_step_solution) if num_of_processors is not None: - command_options['keywords']['command']['options'].insert(0, f'-p {num_of_processors}') + command_options["options"].insert(0, f"-p {num_of_processors}") if timelimit is not None: - command_options['keywords']['command']['options'].append('--time-limit') - command_options['keywords']['command']['options'].append(str(timelimit)) + command_options["options"].append(["--time-limit", str(timelimit)]) command = [] - for key in command_options['keywords']['command']['format']: - command.extend(command_options['keywords']['command'][key]) + for key in command_options["format"]: + command.extend(command_options[key]) - solver_process = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8") + model = table_of_solutions + "\n" + "\n".join(self._model_constraints) + "\n" + solver_process = subprocess.run(command, input=model, capture_output=True, text=True) if solver_process.returncode < 0: - raise ValueError('something went wrong with solver subprocess... sorry!') + raise ValueError("something went wrong with solver subprocess... sorry!") solver_output = solver_process.stdout.splitlines() - if any('UNSATISFIABLE' in line for line in solver_output) and weight not in (-1, 0): - os.remove(input_file_name) - os.remove(solution_file_name) - return 'Unsatisfiable' - time, memory, components_values, total_weight = self._parse_solver_output(solver_output, model_type, False, True, second_step_solver_name) - solutions = self.get_solutions_dictionaries_with_build_time(build_time, components_values, memory, - second_step_solver_name, time, total_weight) - os.remove(input_file_name) - os.remove(solution_file_name) + if any(UNSATISFIABLE in line for line in solver_output) and weight not in (-1, 0): + return UNSATISFIABLE + + time, memory, components_values, total_weight = self._parse_solver_output( + solver_output, model_type, False, True, second_step_solver_name + ) + + solutions = self.get_solutions_dictionaries_with_build_time( + build_time, components_values, memory, second_step_solver_name, time, total_weight + ) return solutions @@ -427,11 +513,8 @@ def solve_model(self, model_type, solver_name=SOLVER_DEFAULT, num_of_processors= * 'xor_differential_first_step' * 'xor_differential_first_step_find_all_solutions' - - ``solver_name`` -- **string** (default: `None`); the name of the solver. Available values are: - - * ``'Chuffed'`` - * ``'Gecode'`` - * ``'COIN-BC'`` + - ``solver_name`` -- **string** (default: `None`); the name of the solver. + See also :meth:`MznModel.solver_names`. EXAMPLES:: @@ -444,7 +527,7 @@ def solve_model(self, model_type, solver_name=SOLVER_DEFAULT, num_of_processors= sage: fixed_variables = [set_fixed_variables('key', 'not_equal', list(range(128)), ....: integer_to_bit_list(0, 128, 'little'))] sage: cp.build_xor_differential_trail_first_step_model(-1, fixed_variables) - sage: cp.solve_model('xor_differential_first_step', 'Chuffed') # random + sage: cp.solve_model('xor_differential_first_step', 'chuffed') # random ['1', ' table_of_solution_length = 40', ' xor_0_0[0] = 0', @@ -455,43 +538,38 @@ def solve_model(self, model_type, solver_name=SOLVER_DEFAULT, num_of_processors= 0.19837307929992676)] """ start = tm.time() - cipher_name = self.cipher_id - input_file_name = f'{MODEL_DEFAULT_PATH}/{cipher_name}_Mzn_{model_type}_{solver_name}.mzn' for i in range(len(CP_SOLVERS_EXTERNAL)): - if solver_name == CP_SOLVERS_EXTERNAL[i]['solver_name']: - command_options = deepcopy(CP_SOLVERS_EXTERNAL[i]) - command_options['keywords']['command']['input_file'].append(input_file_name) - - if model_type == 'xor_differential_first_step_find_all_solutions': - write_model_to_file(self._first_step_find_all_solutions, input_file_name) - command_options['keywords']['command']['options'].insert(0, '-a') + if solver_name == CP_SOLVERS_EXTERNAL[i]["solver_name"]: + command_options = deepcopy(CP_SOLVERS_EXTERNAL[i]["keywords"]["command"]) + + if model_type == "xor_differential_first_step_find_all_solutions": + model = "\n".join(self._first_step_find_all_solutions) + "\n" + command_options["options"].insert(0, "-a") else: - if model_type == 'xor_differential_first_step': - write_model_to_file(self._first_step, input_file_name) + if model_type == "xor_differential_first_step": + model = "\n".join(self._first_step) + "\n" else: - write_model_to_file(self._model_constraints, input_file_name) + model = "\n".join(self._model_constraints) + "\n" if num_of_processors is not None: - command_options['keywords']['command']['options'].insert(0, f'-p {num_of_processors}') + command_options["options"].insert(0, f"-p {num_of_processors}") if timelimit is not None: - command_options['keywords']['command']['options'].append('--time-limit') - command_options['keywords']['command']['options'].append(str(timelimit)) - + command_options["options"].append(["--time-limit", str(timelimit)]) + command = [] - for key in command_options['keywords']['command']['format']: - command.extend(command_options['keywords']['command'][key]) - command.remove('--solver-statistics') - solver_process = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding='utf-8') - os.remove(input_file_name) + for key in command_options["format"]: + command.extend(command_options[key]) + command.remove("--solver-statistics") + solver_process = subprocess.run(command, input=model, capture_output=True, text=True) solution = [] temp = [] for c in solver_process.stdout: - if c == '\n': - solution.append(''.join(temp)) + if c == "\n": + solution.append("".join(temp)) temp = [] else: temp.append(c) if temp: - solution.append(''.join(temp)) + solution.append("".join(temp)) end = tm.time() return solution, end - start @@ -514,37 +592,40 @@ def transform_first_step_model(self, attempt, active_sboxes, weight=-1): EXAMPLES:: sage: from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_differential_trail_search_fixing_number_of_active_sboxes_model import ( - ....: MznXorDifferentialFixingNumberOfActiveSboxesModel) + ....: MznXorDifferentialFixingNumberOfActiveSboxesModel, + ....: ) sage: from claasp.ciphers.block_ciphers.aes_block_cipher import AESBlockCipher sage: from claasp.cipher_modules.models.utils import set_fixed_variables, integer_to_bit_list sage: aes = AESBlockCipher(number_of_rounds=2) sage: cp = MznXorDifferentialFixingNumberOfActiveSboxesModel(aes) - sage: fixed_variables = [set_fixed_variables('key', 'not_equal', range(128), - ....: integer_to_bit_list(0, 128, 'little'))] + sage: fixed_variables = [set_fixed_variables('key', 'not_equal', range(128), integer_to_bit_list(0, 128, 'little'))] sage: cp.build_xor_differential_trail_first_step_model(-1, fixed_variables) - sage: first_step_solution, solve_time = cp.solve_model('xor_differential_first_step','Chuffed') + sage: first_step_solution, solve_time = cp.solve_model('xor_differential_first_step','chuffed') sage: cp.transform_first_step_model(0, first_step_solution[0]) - 1 + sage: first_step_solution[0] + '1' """ - print(active_sboxes) self._first_step_find_all_solutions = [] for line in self._first_step: - if ': number_of_active_sBoxes;' in line: + if ": number_of_active_sBoxes;" in line: if weight != -1: possible_sboxes = self.find_possible_number_of_active_sboxes(weight) - self._first_step_find_all_solutions += \ - [f'var {str(possible_sboxes)}:number_of_active_sBoxes;'] + self._first_step_find_all_solutions += [f"var {str(possible_sboxes)}:number_of_active_sBoxes;"] else: - self._first_step_find_all_solutions += \ - [f'var int:number_of_active_sBoxes = {int(active_sboxes) + attempt};'] - elif 'solve minimize' in line: - self._first_step_find_all_solutions += [solve_satisfy] - new_constraint = 'output[show(number_of_active_sBoxes) ++ \"\\n\" ++ \" table_of_solution_length = ' \ - '\"++ show(table_of_solutions_length) ++ \"\\n\" ++' + self._first_step_find_all_solutions += [ + f"var int:number_of_active_sBoxes = {int(active_sboxes) + attempt};" + ] + elif "solve minimize" in line: + self._first_step_find_all_solutions += [SOLVE_SATISFY] + new_constraint = ( + 'output[show(number_of_active_sBoxes) ++ "\\n" ++ " table_of_solution_length = ' + '"++ show(table_of_solutions_length) ++ "\\n" ++' + ) for i in range(len(self.input_sbox)): - new_constraint = f'{new_constraint}\" {self.input_sbox[i][0]} = ' \ - f'\"++ show({self.input_sbox[i][0]})++ \"\\n\" ++' - self._first_step_find_all_solutions += [new_constraint[:-2] + '];\n'] + new_constraint = ( + f'{new_constraint}" {self.input_sbox[i][0]} = "++ show({self.input_sbox[i][0]})++ "\\n" ++' + ) + self._first_step_find_all_solutions += [new_constraint[:-2] + "];\n"] break else: self._first_step_find_all_solutions += [line] @@ -555,12 +636,12 @@ def update_sbox_ddt_valid_probabilities(self, component, valid_probabilities): super().update_sbox_ddt_valid_probabilities(component, valid_probabilities) input_id_link = component.input_id_links[0] input_bit_positions = component.input_bit_positions[0] - all_inputs = [f'{input_id_link}[{position}]' for position in input_bit_positions] + all_inputs = [f"{input_id_link}[{position}]" for position in input_bit_positions] for i in range(input_size // self.word_size): - ineq_left_side = '+'.join([f'{all_inputs[i * self.word_size + j]}' - for j in range(self.word_size)]) - new_declaration = f'constraint ({ineq_left_side} > 0) = word_{output_id_link}[{i}];' + ineq_left_side = "+".join([f"{all_inputs[i * self.word_size + j]}" for j in range(self.word_size)]) + new_declaration = f"constraint ({ineq_left_side} > 0) = word_{output_id_link}[{i}];" self._cp_xor_differential_constraints.append(new_declaration) self._cp_xor_differential_constraints.append( - f'array[0..{input_size // self.word_size - 1}] of var 0..1: word_{output_id_link};') - self._table_items.append(f'[word_{output_id_link}[s] | s in 0..{input_size // self.word_size - 1}]') + f"array[0..{input_size // self.word_size - 1}] of var 0..1: word_{output_id_link};" + ) + self._table_items.append(f"[word_{output_id_link}[s] | s in 0..{input_size // self.word_size - 1}]") diff --git a/claasp/cipher_modules/models/cp/mzn_models/mzn_xor_linear_model.py b/claasp/cipher_modules/models/cp/mzn_models/mzn_xor_linear_model.py index e5aa4b16a..5f4f874b5 100644 --- a/claasp/cipher_modules/models/cp/mzn_models/mzn_xor_linear_model.py +++ b/claasp/cipher_modules/models/cp/mzn_models/mzn_xor_linear_model.py @@ -1,17 +1,16 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -22,21 +21,27 @@ from sage.crypto.sbox import SBox -from claasp.cipher_modules.models.cp.mzn_model import MznModel, solve_satisfy, constraint_type_error -from claasp.cipher_modules.models.utils import get_bit_bindings, \ - get_single_key_scenario_format_for_fixed_values -from claasp.name_mappings import INTERMEDIATE_OUTPUT, XOR_LINEAR, CONSTANT, CIPHER_OUTPUT, LINEAR_LAYER, SBOX, \ - MIX_COLUMN, WORD_OPERATION, INPUT_KEY +from claasp.cipher_modules.models.cp.mzn_model import MznModel, SOLVE_SATISFY, CONSTRAINT_TYPE_ERROR +from claasp.cipher_modules.models.utils import get_bit_bindings, get_single_key_scenario_format_for_fixed_values +from claasp.name_mappings import ( + INTERMEDIATE_OUTPUT, + XOR_LINEAR, + CONSTANT, + CIPHER_OUTPUT, + LINEAR_LAYER, + SBOX, + MIX_COLUMN, + WORD_OPERATION, + INPUT_KEY, +) from claasp.cipher_modules.models.cp.solvers import SOLVER_DEFAULT class MznXorLinearModel(MznModel): - def __init__(self, cipher): super().__init__(cipher) - format_func = lambda record: f'{record[0]}_{record[2]}[{record[1]}]' - self.bit_bindings, self.bit_bindings_for_intermediate_output = get_bit_bindings( - cipher, format_func) + format_func = lambda record: f"{record[0]}_{record[2]}[{record[1]}]" + self.bit_bindings, self.bit_bindings_for_intermediate_output = get_bit_bindings(cipher, format_func) def and_xor_linear_probability_lat(self, numadd): """ @@ -58,13 +63,13 @@ def and_xor_linear_probability_lat(self, numadd): lat = [] for full_mask in range(2 ** (numadd + 1)): num_of_matches = 0 - for values in range(2 ** numadd): + for values in range(2**numadd): full_values = values << 1 bit_of_values = (values >> i & 1 for i in range(numadd)) full_values ^= 0 not in bit_of_values equation = full_values & full_mask addenda = (equation >> i & 1 for i in range(numadd + 1)) - num_of_matches += (sum(addenda) % 2 == 0) + num_of_matches += sum(addenda) % 2 == 0 lat.append(num_of_matches - (2 ** (numadd - 1))) return lat @@ -96,11 +101,11 @@ def branch_xor_linear_constraints(self): for output_bit_id, input_bit_ids in self.bit_bindings.items(): # no fork if len(input_bit_ids) == 1: - cp_constraints.append(f'constraint {output_bit_id} = {input_bit_ids[0]};') + cp_constraints.append(f"constraint {output_bit_id} = {input_bit_ids[0]};") # fork else: - operation = f'({" + ".join(input_bit_ids)}) mod 2;' - cp_constraints.append(f'constraint {output_bit_id} = {operation}') + operation = f"({' + '.join(input_bit_ids)}) mod 2;" + cp_constraints.append(f"constraint {output_bit_id} = {operation}") return cp_constraints @@ -137,21 +142,29 @@ def build_xor_linear_trail_model(self, weight=-1, fixed_variables=[]): cipher_without_key_schedule = self._cipher.remove_key_schedule() self._cipher = cipher_without_key_schedule self.bit_bindings, self.bit_bindings_for_intermediate_output = get_bit_bindings( - self._cipher, lambda record: f'{record[0]}_{record[2]}[{record[1]}]') + self._cipher, lambda record: f"{record[0]}_{record[2]}[{record[1]}]" + ) if fixed_variables == []: fixed_variables = get_single_key_scenario_format_for_fixed_values(self._cipher) constraints = self.fix_variables_value_xor_linear_constraints(fixed_variables) self._model_constraints = constraints for component in self._cipher.get_all_components(): - component_types = [CONSTANT, INTERMEDIATE_OUTPUT, CIPHER_OUTPUT, LINEAR_LAYER, - SBOX, MIX_COLUMN, WORD_OPERATION] + component_types = [ + CONSTANT, + INTERMEDIATE_OUTPUT, + CIPHER_OUTPUT, + LINEAR_LAYER, + SBOX, + MIX_COLUMN, + WORD_OPERATION, + ] operation = component.description[0] operation_types = ["AND", "MODADD", "NOT", "ROTATE", "SHIFT", "XOR", "OR", "MODSUB"] if component.type in component_types and (component.type != WORD_OPERATION or operation in operation_types): variables, constraints = component.cp_xor_linear_mask_propagation_constraints(self) else: - print(f'{component.id} not yet implemented') + print(f"{component.id} not yet implemented") self._variables_list.extend(variables) self._model_constraints.extend(constraints) @@ -192,31 +205,46 @@ def final_xor_linear_constraints(self, weight): ['solve:: int_search(p, smallest, indomain_min, complete) minimize sum(p);'] """ cipher_inputs = self._cipher.inputs - cp_constraints = ['solve:: int_search(p, smallest, indomain_min, complete) minimize sum(p);' - if weight == -1 else solve_satisfy] - new_constraint = 'output[' + cp_constraints = [ + "solve:: int_search(p, smallest, indomain_min, complete) minimize sum(p);" + if weight == -1 + else SOLVE_SATISFY + ] + new_constraint = "output[" for i, element in enumerate(cipher_inputs): - new_constraint += f'\"{element} = \"++ show({element}_o) ++ \"\\n\" ++' + new_constraint += f'"{element} = "++ show({element}_o) ++ "\\n" ++' for component in self._cipher.get_all_components(): if SBOX in component.type: - new_constraint += f'\"{component.id}_i = \"++ show({component.id}_i)++ \"\\n\" ++ ' \ - f'\"{component.id}_o = \"++ show({component.id}_o)++ \"\\n\" ++ ' \ - f'show(p[{self.component_and_probability[component.id]}]) ++ \"\\n\" ++' + new_constraint += ( + f'"{component.id}_i = "++ show({component.id}_i)++ "\\n" ++ ' + f'"{component.id}_o = "++ show({component.id}_o)++ "\\n" ++ ' + f'show(p[{self.component_and_probability[component.id]}]) ++ "\\n" ++' + ) elif CIPHER_OUTPUT in component.type: - new_constraint += f'\"{component.id}_o = \"++ ' \ - f'show({component.id}_i)++ \"\\n\" ++ \"0\" ++ \"\\n\" ++' + new_constraint += f'"{component.id}_o = "++ show({component.id}_i)++ "\\n" ++ "0" ++ "\\n" ++' elif WORD_OPERATION in component.type: new_constraint = self.get_word_operation_final_xor_linear_constraints(component, new_constraint) else: - new_constraint += f'\"{component.id}_i = \"++ show({component.id}_o)++ \"\\n\" ++ ' \ - f'\"{component.id}_o = \"++ show({component.id}_o)++ \"\\n\" ++ \"0\" ++ \"\\n\" ++' + new_constraint += ( + f'"{component.id}_i = "++ show({component.id}_o)++ "\\n" ++ ' + f'"{component.id}_o = "++ show({component.id}_o)++ "\\n" ++ "0" ++ "\\n" ++' + ) - new_constraint += '\"Trail weight = \" ++ show(weight)];' + new_constraint += '"Trail weight = " ++ show(weight)];' cp_constraints.append(new_constraint) return cp_constraints - def find_all_xor_linear_trails_with_fixed_weight(self, fixed_weight, fixed_values=[], solver_name=SOLVER_DEFAULT, num_of_processors=None, timelimit=None, solve_with_API=False, solve_external = False): + def find_all_xor_linear_trails_with_fixed_weight( + self, + fixed_weight, + fixed_values=[], + solver_name=SOLVER_DEFAULT, + num_of_processors=None, + timelimit=None, + solve_with_API=False, + solve_external=False, + ): """ Return a list of solutions containing all the linear trails having the ``fixed_weight`` weight of correlation. By default, the search removes the key schedule, if any. @@ -225,11 +253,8 @@ def find_all_xor_linear_trails_with_fixed_weight(self, fixed_weight, fixed_value - ``fixed_weight`` -- **integer**; the weight to be fixed - ``fixed_values`` -- **list** (default: `[]`); can be created using ``set_fixed_variables`` method - - ``solver_name`` -- **string** (default: `Chuffed`); the name of the solver. Available values are: - - * ``'Chuffed'`` - * ``'Gecode'`` - * ``'COIN-BC'`` + - ``solver_name`` -- **string** (default: `chuffed`); the name of the solver. + See also :meth:`MznModel.solver_names`. EXAMPLES:: @@ -257,16 +282,37 @@ def find_all_xor_linear_trails_with_fixed_weight(self, fixed_weight, fixed_value end = tm.time() build_time = end - start if solve_with_API: - solutions = self.solve_for_ARX(solver_name = solver_name, timeout_in_seconds_ = timelimit, processes_ = num_of_processors, all_solutions_ = True) + solutions = self.solve_for_ARX( + solver_name=solver_name, + timeout_in_seconds_=timelimit, + processes_=num_of_processors, + all_solutions_=True, + ) else: - solutions = self.solve(XOR_LINEAR, solver_name = solver_name, timeout_in_seconds_ = timelimit, processes_ = num_of_processors, all_solutions_ = True, solve_external = solve_external) + solutions = self.solve( + XOR_LINEAR, + solver_name=solver_name, + timeout_in_seconds_=timelimit, + processes_=num_of_processors, + all_solutions_=True, + solve_external=solve_external, + ) for solution in solutions: - solution['building_time_seconds'] = build_time - solution['test_name'] = "find_all_xor_linear_trails_with_fixed_weight" + solution["building_time_seconds"] = build_time + solution["test_name"] = "find_all_xor_linear_trails_with_fixed_weight" return solutions - def find_all_xor_linear_trails_with_weight_at_most(self, min_weight, max_weight=64, - fixed_values=[], solver_name=SOLVER_DEFAULT, num_of_processors=None, timelimit=None, solve_with_API=False, solve_external = False): + def find_all_xor_linear_trails_with_weight_at_most( + self, + min_weight, + max_weight=64, + fixed_values=[], + solver_name=SOLVER_DEFAULT, + num_of_processors=None, + timelimit=None, + solve_with_API=False, + solve_external=False, + ): """ Return a list of solutions containing all the linear trails having the weight of correlation lying in the interval ``[min_weight, max_weight]``. By default, the search removes the key schedule, if any. @@ -276,11 +322,8 @@ def find_all_xor_linear_trails_with_weight_at_most(self, min_weight, max_weight= - ``min_weight`` -- **integer**; the weight from which to start the search - ``max_weight`` -- **integer** (default: `64`); the weight at which the search stops - ``fixed_values`` -- **list** (default: `[]`); can be created using ``set_fixed_variables`` method - - ``solver_name`` -- **string** (default: `Chuffed`); the name of the solver. Available values are: - - * ``'Chuffed'`` - * ``'Gecode'`` - * ``'COIN-BC'`` + - ``solver_name`` -- **string** (default: `chuffed`); the name of the solver. + See also :meth:`MznModel.solver_names`. EXAMPLES:: @@ -305,20 +348,40 @@ def find_all_xor_linear_trails_with_weight_at_most(self, min_weight, max_weight= """ start = tm.time() self.build_xor_linear_trail_model(0, fixed_values) - self._model_constraints.append(f'constraint weight >= {100 * min_weight} /\\ weight <= {100 * max_weight} ') + self._model_constraints.append(f"constraint weight >= {100 * min_weight} /\\ weight <= {100 * max_weight} ") end = tm.time() build_time = end - start if solve_with_API: - solutions = self.solve_for_ARX(solver_name = solver_name, timeout_in_seconds_ = timelimit, processes_ = num_of_processors, all_solutions_ = True) + solutions = self.solve_for_ARX( + solver_name=solver_name, + timeout_in_seconds_=timelimit, + processes_=num_of_processors, + all_solutions_=True, + ) else: - solutions = self.solve(XOR_LINEAR, solver_name = solver_name, timeout_in_seconds_ = timelimit, processes_ = num_of_processors, all_solutions_ = True, solve_external = solve_external) + solutions = self.solve( + XOR_LINEAR, + solver_name=solver_name, + timeout_in_seconds_=timelimit, + processes_=num_of_processors, + all_solutions_=True, + solve_external=solve_external, + ) for solution in solutions: - solution['building_time_seconds'] = build_time - solution['test_name'] = "find_all_xor_linear_trails_with_weight_at_most" + solution["building_time_seconds"] = build_time + solution["test_name"] = "find_all_xor_linear_trails_with_weight_at_most" return solutions - def find_lowest_weight_xor_linear_trail(self, fixed_values=[], solver_name=SOLVER_DEFAULT, num_of_processors=None, timelimit=None, solve_with_API=False, solve_external = False): + def find_lowest_weight_xor_linear_trail( + self, + fixed_values=[], + solver_name=SOLVER_DEFAULT, + num_of_processors=None, + timelimit=None, + solve_with_API=False, + solve_external=False, + ): """ Return the solution representing a linear trail with the lowest weight of correlation. By default, the search removes the key schedule, if any. @@ -331,11 +394,8 @@ def find_lowest_weight_xor_linear_trail(self, fixed_values=[], solver_name=SOLVE INPUT: - ``fixed_values`` -- **list** (default: `[]`); they can be created using ``set_fixed_variables`` method - - ``solver_name`` -- **string** (default: `Chuffed`); the name of the solver. Available values are: - - * ``'Chuffed'`` - * ``'Gecode'`` - * ``'COIN-BC'`` + - ``solver_name`` -- **string** (default: `chuffed`); the name of the solver. + See also :meth:`MznModel.solver_names`. EXAMPLES:: @@ -363,15 +423,31 @@ def find_lowest_weight_xor_linear_trail(self, fixed_values=[], solver_name=SOLVE end = tm.time() build_time = end - start if solve_with_API: - solution = self.solve_for_ARX(solver_name = solver_name, timeout_in_seconds_ = timelimit, processes_ = num_of_processors) + solution = self.solve_for_ARX( + solver_name=solver_name, timeout_in_seconds_=timelimit, processes_=num_of_processors + ) else: - solution = self.solve('xor_linear_one_solution', solver_name = solver_name, timeout_in_seconds_ = timelimit, processes_ = num_of_processors, solve_external = solve_external) - solution['building_time_seconds'] = build_time - solution['test_name'] = "find_lowest_weight_xor_linear_trail" + solution = self.solve( + "xor_linear_one_solution", + solver_name=solver_name, + timeout_in_seconds_=timelimit, + processes_=num_of_processors, + solve_external=solve_external, + ) + solution["building_time_seconds"] = build_time + solution["test_name"] = "find_lowest_weight_xor_linear_trail" return solution - def find_one_xor_linear_trail(self, fixed_values=[], solver_name=SOLVER_DEFAULT, num_of_processors=None, timelimit=None, solve_with_API=False, solve_external = False): + def find_one_xor_linear_trail( + self, + fixed_values=[], + solver_name=SOLVER_DEFAULT, + num_of_processors=None, + timelimit=None, + solve_with_API=False, + solve_external=False, + ): """ Return the solution representing a linear trail with any weight of correlation. By default, the search removes the key schedule, if any. @@ -379,11 +455,8 @@ def find_one_xor_linear_trail(self, fixed_values=[], solver_name=SOLVER_DEFAULT, INPUT: - ``fixed_values`` -- **list** (default: `[]`); can be created using ``set_fixed_variables`` method - - ``solver_name`` -- **string** (default: `Chuffed`); the name of the solver. Available values are: - - * ``'Chuffed'`` - * ``'Gecode'`` - * ``'COIN-BC'`` + - ``solver_name`` -- **string** (default: `chuffed`); the name of the solver. + See also :meth:`MznModel.solver_names`. EXAMPLES:: @@ -407,15 +480,32 @@ def find_one_xor_linear_trail(self, fixed_values=[], solver_name=SOLVER_DEFAULT, end = tm.time() build_time = end - start if solve_with_API: - solution = self.solve_for_ARX(solver_name = solver_name, timeout_in_seconds_ = timelimit, processes_ = num_of_processors) + solution = self.solve_for_ARX( + solver_name=solver_name, timeout_in_seconds_=timelimit, processes_=num_of_processors + ) else: - solution = self.solve('xor_linear_one_solution', solver_name = solver_name, timeout_in_seconds_ = timelimit, processes_ = num_of_processors, solve_external = solve_external) - solution['building_time_seconds'] = build_time - solution['test_name'] = "find_one_xor_linear_trail" + solution = self.solve( + "xor_linear_one_solution", + solver_name=solver_name, + timeout_in_seconds_=timelimit, + processes_=num_of_processors, + solve_external=solve_external, + ) + solution["building_time_seconds"] = build_time + solution["test_name"] = "find_one_xor_linear_trail" return solution - def find_one_xor_linear_trail_with_fixed_weight(self, fixed_weight=-1, fixed_values=[], solver_name=SOLVER_DEFAULT, num_of_processors=None, timelimit=None, solve_with_API=False, solve_external = False): + def find_one_xor_linear_trail_with_fixed_weight( + self, + fixed_weight=-1, + fixed_values=[], + solver_name=SOLVER_DEFAULT, + num_of_processors=None, + timelimit=None, + solve_with_API=False, + solve_external=False, + ): """ Return the solution representing a linear trail with the weight of correlation equal to ``fixed_weight``. By default, the search removes the key schedule, if any. @@ -424,12 +514,8 @@ def find_one_xor_linear_trail_with_fixed_weight(self, fixed_weight=-1, fixed_val - ``fixed_weight`` -- **integer**; the value to which the weight is fixed, if non-negative - ``fixed_values`` -- **list** (default: `[]`); can be created using ``set_fixed_variables`` method - - ``solver_name`` -- **string** (default: `Chuffed`); the name of the solver. Available values are: - - * ``'Chuffed'`` - * ``'Gecode'`` - * ``'COIN-BC'`` - + - ``solver_name`` -- **string** (default: `chuffed`); the name of the solver. + See also :meth:`MznModel.solver_names`. EXAMPLES:: sage: from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_linear_model import MznXorLinearModel @@ -456,11 +542,19 @@ def find_one_xor_linear_trail_with_fixed_weight(self, fixed_weight=-1, fixed_val end = tm.time() build_time = end - start if solve_with_API: - solution = self.solve_for_ARX(solver_name = solver_name, timeout_in_seconds_ = timelimit, processes_ = num_of_processors) + solution = self.solve_for_ARX( + solver_name=solver_name, timeout_in_seconds_=timelimit, processes_=num_of_processors + ) else: - solution = self.solve('xor_linear_one_solution', solver_name = solver_name, timeout_in_seconds_ = timelimit, processes_ = num_of_processors, solve_external = solve_external) - solution['building_time_seconds'] = build_time - solution['test_name'] = "find_one_xor_linear_trail_with_fixed_weight" + solution = self.solve( + "xor_linear_one_solution", + solver_name=solver_name, + timeout_in_seconds_=timelimit, + processes_=num_of_processors, + solve_external=solve_external, + ) + solution["building_time_seconds"] = build_time + solution["test_name"] = "find_one_xor_linear_trail_with_fixed_weight" return solution @@ -498,50 +592,58 @@ def fix_variables_value_xor_linear_constraints(self, fixed_variables=[]): """ cp_constraints = [] for component in fixed_variables: - component_id = component['component_id'] - bit_positions = component['bit_positions'] - bit_values = component['bit_values'] - if component['constraint_type'] == 'equal': - conditions = ' /\\ '.join(f'{component_id}_o[{value}] = {bit_values[index]}' - for index, value in enumerate(bit_positions)) - elif component['constraint_type'] == 'not_equal': - conditions = ' \\/ '.join(f'{component_id}_o[{value}] != {bit_values[index]}' - for index, value in enumerate(bit_positions)) - constraint = f'constraint {conditions};' + component_id = component["component_id"] + bit_positions = component["bit_positions"] + bit_values = component["bit_values"] + if component["constraint_type"] == "equal": + conditions = " /\\ ".join( + f"{component_id}_o[{value}] = {bit_values[index]}" for index, value in enumerate(bit_positions) + ) + elif component["constraint_type"] == "not_equal": + conditions = " \\/ ".join( + f"{component_id}_o[{value}] != {bit_values[index]}" for index, value in enumerate(bit_positions) + ) + constraint = f"constraint {conditions};" cp_constraints.append(constraint) return cp_constraints def get_lat_values(self, lat_table, numadd): lat_entries = [] - lat_values = '' + lat_values = "" for i in range(pow(2, numadd + 1)): if lat_table[i] != 0: - binary_i = format(i, f'0{numadd + 1}b') - lat_entries += [f'{binary_i[j]}' for j in range(numadd + 1)] + binary_i = format(i, f"0{numadd + 1}b") + lat_entries += [f"{binary_i[j]}" for j in range(numadd + 1)] lat_entries.append(str(round(100 * math.log2(pow(2, numadd - 1) / abs(lat_table[i]))))) - lat_values = ','.join(lat_entries) + lat_values = ",".join(lat_entries) return lat_values def get_word_operation_final_xor_linear_constraints(self, component, new_constraint): - if 'AND' in component.description[0]: - new_constraint += f'\"{component.id}_i = \"++ ' \ - f'show({component.id}_i)++ \"\\n\" ++ \"{component.id}_o = ' \ - f'\"++ show({component.id}_o)++ \"\\n\" ++ show(' + if "AND" in component.description[0]: + new_constraint += ( + f'"{component.id}_i = "++ ' + f'show({component.id}_i)++ "\\n" ++ "{component.id}_o = ' + f'"++ show({component.id}_o)++ "\\n" ++ show(' + ) for i in range(len(self.component_and_probability[component.id])): - new_constraint += f'p[{self.component_and_probability[component.id][i]}]+' - new_constraint = new_constraint[:-1] + ') ++ \"\\n\" ++' - elif 'MODADD' in component.description[0]: - new_constraint += f'\"{component.id}_i = \"++ show({component.id}_i)++ \"\\n\" ++ ' \ - f'\"{component.id}_o = \"++ show({component.id}_o)++ \"\\n\" ++ show(' + new_constraint += f"p[{self.component_and_probability[component.id][i]}]+" + new_constraint = new_constraint[:-1] + ') ++ "\\n" ++' + elif "MODADD" in component.description[0]: + new_constraint += ( + f'"{component.id}_i = "++ show({component.id}_i)++ "\\n" ++ ' + f'"{component.id}_o = "++ show({component.id}_o)++ "\\n" ++ show(' + ) for i in range(len(self.component_and_probability[component.id])): - new_constraint += f'p[{self.component_and_probability[component.id][i]}]+' - new_constraint = new_constraint[:-1] + ') ++ \"\\n\" ++' + new_constraint += f"p[{self.component_and_probability[component.id][i]}]+" + new_constraint = new_constraint[:-1] + ') ++ "\\n" ++' else: - new_constraint += f'\"{component.id}_i = \"++ show({component.id}_i)++ \"\\n\" ' \ - f'++\"{component.id}_o = \"++ show({component.id}_o)++ \"\\n\" ' \ - f'++ \"0\" ++ \"\\n\" ++' + new_constraint += ( + f'"{component.id}_i = "++ show({component.id}_i)++ "\\n" ' + f'++"{component.id}_o = "++ show({component.id}_o)++ "\\n" ' + f'++ "0" ++ "\\n" ++' + ) return new_constraint @@ -571,8 +673,10 @@ def input_xor_linear_constraints(self): and_already_added = [] cipher_inputs = self._cipher.inputs cipher_inputs_bit_size = self._cipher.inputs_bit_size - cp_declarations = [f'array[0..{cipher_inputs_bit_size[i] - 1}] of var 0..1: {element}_o;' - for i, element in enumerate(cipher_inputs)] + cp_declarations = [ + f"array[0..{cipher_inputs_bit_size[i] - 1}] of var 0..1: {element}_o;" + for i, element in enumerate(cipher_inputs) + ] prob_count = 0 xor_count = 0 valid_probabilities = {0} @@ -581,20 +685,21 @@ def input_xor_linear_constraints(self): prob_count = prob_count + 1 self.update_sbox_lat_valid_probabilities(component, valid_probabilities) elif WORD_OPERATION in component.type: - if 'AND' in component.description[0] or component.description[0] == 'OR': + if "AND" in component.description[0] or component.description[0] == "OR": prob_count += component.description[1] * component.output_bit_size - self.update_and_or_lat_valid_probabilities(and_already_added, component, cp_declarations, - valid_probabilities) - elif 'MODADD' in component.description[0]: + self.update_and_or_lat_valid_probabilities( + and_already_added, component, cp_declarations, valid_probabilities + ) + elif "MODADD" in component.description[0]: prob_count = prob_count + component.description[1] - 1 output_size = component.output_bit_size valid_probabilities.update({i + 100 for i in range(100 * output_size)[::100]}) - elif 'XOR' in component.description[0]: - if any('constant' in input_links for input_links in component.input_id_links): + elif "XOR" in component.description[0]: + if any("constant" in input_links for input_links in component.input_id_links): xor_count = xor_count + 1 - cp_declarations.append(f'array[0..{prob_count - 1}] of var {valid_probabilities}: p;') - data_type = 'int' - cp_declarations.append(f'var {data_type}: weight = sum(p);') + cp_declarations.append(f"array[0..{prob_count - 1}] of var {valid_probabilities}: p;") + data_type = "int" + cp_declarations.append(f"var {data_type}: weight = sum(p);") return cp_declarations, cp_constraints @@ -608,9 +713,11 @@ def update_and_or_lat_valid_probabilities(self, and_already_added, component, cp for occurrence in set_of_occurrences: valid_probabilities.add(round(100 * math.log2(abs(pow(2, numadd - 1) / occurrence)))) lat_values = self.get_lat_values(lat_table, numadd) - and_declaration = f'array [1..{dim_lat}, 1..{numadd + 2}] of int: ' \ - f'and{numadd}inputs_LAT = array2d(1..{dim_lat}, 1..{numadd + 2}, ' \ - f'[{lat_values}]);' + and_declaration = ( + f"array [1..{dim_lat}, 1..{numadd + 2}] of int: " + f"and{numadd}inputs_LAT = array2d(1..{dim_lat}, 1..{numadd + 2}, " + f"[{lat_values}]);" + ) cp_declarations.append(and_declaration) and_already_added.append(numadd) @@ -629,8 +736,11 @@ def update_sbox_lat_valid_probabilities(self, component, valid_probabilities): set_of_occurrences = set(sbox_lat.rows()[i]) set_of_occurrences -= {0} valid_probabilities.update( - {round(100 * math.log2(abs(pow(2, input_size - 1) / occurence))) for occurence in - set_of_occurrences}) + { + round(100 * math.log2(abs(pow(2, input_size - 1) / occurence))) + for occurence in set_of_occurrences + } + ) self.sbox_mant.append((description, output_id_link)) def weight_xor_linear_constraints(self, weight): diff --git a/claasp/cipher_modules/models/cp/solvers.py b/claasp/cipher_modules/models/cp/solvers.py index ae2ad8258..cbcb35115 100644 --- a/claasp/cipher_modules/models/cp/solvers.py +++ b/claasp/cipher_modules/models/cp/solvers.py @@ -1,108 +1,105 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** -import os +# solvers definition +CHOCO = "choco" +CHUFFED = "chuffed" +COIN_BC = "coin-bc" +CPLEX = "cplex" +FINDMUS = "findmus" +GLOBALIZER = "globalizer" +GUROBI = "gurobi" +SCIP = "scip" +CPSAT = "cp-sat" +XPRESS = "xpress" + + +SOLVER_DEFAULT = CHUFFED -SOLVER_DEFAULT = 'chuffed' -MODEL_DEFAULT_PATH = os.getcwd() -CP_SOLVERS_INTERNAL = [{'solver_brand_name': 'Choco', 'solver_name': 'choco'}, - {'solver_brand_name': 'Chuffed', 'solver_name': 'chuffed'}, - {'solver_brand_name': 'COIN-BC', 'solver_name': 'coin-bc'}, - {'solver_brand_name': 'IBM ILOG CPLEX', 'solver_name': 'cplex'}, - {'solver_brand_name': 'MiniZinc findMUS', 'solver_name': 'findmus'}, - {'solver_brand_name': 'Gecode', 'solver_name': 'gecode'}, - {'solver_brand_name': 'MiniZinc Globalizer', 'solver_name': 'globalizer'}, - {'solver_brand_name': 'Gurobi Optimizer', 'solver_name': 'gurobi'}, - {'solver_brand_name': 'SCIP', 'solver_name': 'scip'}, - {'solver_brand_name': 'OR Tools', 'solver_name': 'Xor'}, - {'solver_brand_name': 'FICO Xpress', 'solver_name': 'xpress'},] +CP_SOLVERS_INTERNAL = [ + {"solver_brand_name": "Choco", "solver_name": CHOCO}, + {"solver_brand_name": "Chuffed", "solver_name": CHUFFED}, + {"solver_brand_name": "COIN-BC", "solver_name": COIN_BC}, + {"solver_brand_name": "IBM ILOG CPLEX", "solver_name": CPLEX}, + {"solver_brand_name": "MiniZinc findMUS", "solver_name": FINDMUS}, + {"solver_brand_name": "MiniZinc Globalizer", "solver_name": GLOBALIZER}, + {"solver_brand_name": "Gurobi Optimizer", "solver_name": GUROBI}, + {"solver_brand_name": "SCIP", "solver_name": SCIP}, + {"solver_brand_name": "OR Tools", "solver_name": CPSAT}, + {"solver_brand_name": "FICO Xpress", "solver_name": XPRESS}, +] + CP_SOLVERS_EXTERNAL = [ { - 'solver_brand_name': 'Chuffed', - 'solver_name': 'Chuffed', # keyword to call the solver - 'keywords': { - 'command': { - 'executable': ['minizinc'], - 'options': ['--solver-statistics'], - 'input_file': [], - 'output_file': [], - 'solver': ['--solver', 'Chuffed'], - 'format': ['executable', 'options', 'solver', 'input_file', 'output_file'], + "solver_brand_name": "Chuffed", + "solver_name": CHUFFED, # keyword to call the solver + "keywords": { + "command": { + "executable": ["minizinc"], + "options": ["--input-from-stdin", "--solver-statistics"], + "input_file": [], + "output_file": [], + "solver": ["--solver", CHUFFED], + "format": ["executable", "options", "solver"], }, }, }, { - 'solver_brand_name': 'Gecode', - 'solver_name': 'gecode', # keyword to call the solver - 'keywords': { - 'command': { - 'executable': ['minizinc'], - 'options': ['--solver-statistics'], - 'input_file': [], - 'output_file': [], - 'solver': ['--solver', 'Gecode'], - 'format': ['executable', 'options', 'solver', 'input_file', 'output_file'], + "solver_brand_name": "OR Tools", + "solver_name": CPSAT, # keyword to call the solver + "keywords": { + "command": { + "executable": ["minizinc"], + "options": ["--input-from-stdin", "--solver-statistics"], + "input_file": [], + "output_file": [], + "solver": ["--solver", CPSAT], + "format": ["executable", "options", "solver"], }, - }, - }, - { - 'solver_brand_name': 'OR Tools', - 'solver_name': 'Xor', # keyword to call the solver - 'keywords': { - 'command': { - 'executable': ['minizinc'], - 'options': ['--solver-statistics'], - 'input_file': [], - 'output_file': [], - 'solver': ['--solver', 'Xor'], - 'format': ['executable', 'options', 'solver', 'input_file', 'output_file'], - }, - }, + }, }, { - 'solver_brand_name': 'COIN-BC', - 'solver_name': 'coin-bc', # keyword to call the solver - 'keywords': { - 'command': { - 'executable': ['minizinc'], - 'options': ['--solver-statistics'], - 'input_file': [], - 'output_file': [], - 'solver': ['--solver', 'COIN-BC'], - 'format': ['executable', 'options', 'solver', 'input_file', 'output_file'], + "solver_brand_name": "COIN-BC", + "solver_name": COIN_BC, # keyword to call the solver + "keywords": { + "command": { + "executable": ["minizinc"], + "options": ["--input-from-stdin", "--solver-statistics"], + "input_file": [], + "output_file": [], + "solver": ["--solver", COIN_BC], + "format": ["executable", "options", "solver"], }, - }, + }, }, { - 'solver_brand_name': 'Choco', - 'solver_name': 'choco', # keyword to call the solver - 'keywords': { - 'command': { - 'executable': ['minizinc'], - 'options': ['--solver-statistics'], - 'input_file': [], - 'output_file': [], - 'solver': ['--solver', 'choco'], - 'format': ['executable', 'options', 'solver', 'input_file', 'output_file'], + "solver_brand_name": "Choco", + "solver_name": CHOCO, # keyword to call the solver + "keywords": { + "command": { + "executable": ["--input-from-stdin", "minizinc"], + "options": ["--solver-statistics"], + "input_file": [], + "output_file": [], + "solver": ["--solver", CHOCO], + "format": ["executable", "options", "solver"], }, - }, + }, }, ] - diff --git a/claasp/cipher_modules/models/milp/__init__.py b/claasp/cipher_modules/models/milp/__init__.py index 795e4ce72..deab14490 100644 --- a/claasp/cipher_modules/models/milp/__init__.py +++ b/claasp/cipher_modules/models/milp/__init__.py @@ -2,5 +2,5 @@ MILP_AUXILIARY_FILE_PATH = os.getcwd() -if os.access(os.path.join(os.path.dirname(__file__), 'utils'), os.W_OK): - MILP_AUXILIARY_FILE_PATH = os.path.join(os.path.dirname(__file__), 'utils') +if os.access(os.path.join(os.path.dirname(__file__), "utils"), os.W_OK): + MILP_AUXILIARY_FILE_PATH = os.path.join(os.path.dirname(__file__), "utils") diff --git a/claasp/cipher_modules/models/milp/milp_model.py b/claasp/cipher_modules/models/milp/milp_model.py index 5b2b97268..15b0b2695 100644 --- a/claasp/cipher_modules/models/milp/milp_model.py +++ b/claasp/cipher_modules/models/milp/milp_model.py @@ -1,17 +1,16 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -36,6 +35,7 @@ The default choice is GLPK. """ + import os import subprocess import time @@ -43,11 +43,21 @@ from sage.numerical.mip import MixedIntegerLinearProgram, MIPSolverException -from claasp.cipher_modules.models.milp.solvers import SOLVER_DEFAULT, MODEL_DEFAULT_PATH, MILP_SOLVERS_EXTERNAL, \ - MILP_SOLVERS_INTERNAL +from claasp.cipher_modules.models.milp.solvers import ( + MILP_SOLVERS_EXTERNAL, + MILP_SOLVERS_INTERNAL, + MODEL_DEFAULT_PATH, + SOLVER_DEFAULT, +) from claasp.cipher_modules.models.milp.utils.milp_name_mappings import MILP_DEFAULT_WEIGHT_PRECISION -from claasp.cipher_modules.models.milp.utils.utils import _get_data, _parse_external_solver_output, _write_model_to_lp_file +from claasp.cipher_modules.models.milp.utils.utils import ( + _get_data, + _parse_external_solver_output, + _write_model_to_lp_file, +) from claasp.cipher_modules.models.utils import convert_solver_solution_to_dictionary +from claasp.name_mappings import SATISFIABLE, UNSATISFIABLE + def get_independent_input_output_variables(component): """ @@ -187,22 +197,47 @@ def fix_variables_value_constraints(self, fixed_variables=[]): 1 <= x_4 + x_6 + x_8 + x_10] """ x = self._binary_variable + n_trails = self._number_of_trails_found constraints = [] - for fixed_variable in fixed_variables: - component_id = fixed_variable["component_id"] - if fixed_variable["constraint_type"] == "equal": - for index, bit_position in enumerate(fixed_variable["bit_positions"]): - constraints.append(x[f"{component_id}_{bit_position}"] == fixed_variable["bit_values"][index]) + + for fv in fixed_variables: + bit_vals = fv["bit_values"] + comp_id = fv["component_id"] + bit_pos = fv["bit_positions"] + ctype = fv["constraint_type"] + + if bit_vals[0] not in [0, 1]: + var_vals = [] + for v in bit_vals: + var_vals.extend([(v_id, i) for v_id, bits in [v] for i in bits]) + + if ctype == "equal": + for i, pos in enumerate(bit_pos): + constraints.append(x[f"{comp_id}_{pos}"] == x[f"{var_vals[i][0]}_{var_vals[i][1]}"]) + else: + for i, pos in enumerate(bit_pos): + neq = f"{comp_id}{pos}_not_equal_{n_trails}" + lhs = x[neq] + a = x[f"{comp_id}_{pos}"] + b = x[f"{var_vals[i][0]}_{var_vals[i][1]}"] + + constraints.append(lhs <= a + b) + constraints.append(lhs >= a - b) + constraints.append(lhs >= b - a) + constraints.append(lhs + a + b <= 2) + + constraints.append(sum(x[f"{comp_id}{p}_not_equal_{n_trails}"] for p in bit_pos) >= 1) else: - for index, bit_position in enumerate(fixed_variable["bit_positions"]): - if fixed_variable["bit_values"][index]: - constraints.append(x[f"{component_id}{bit_position}_not_equal_{self._number_of_trails_found}"] - == 1 - x[f"{component_id}_{bit_position}"]) - else: - constraints.append(x[f"{component_id}{bit_position}_not_equal_{self._number_of_trails_found}"] - == x[f"{component_id}_{bit_position}"]) - constraints.append(sum(x[f"{component_id}{i}_not_equal_{self._number_of_trails_found}"] - for i in fixed_variable["bit_positions"]) >= 1) + if ctype == "equal": + for i, pos in enumerate(bit_pos): + constraints.append(x[f"{comp_id}_{pos}"] == bit_vals[i]) + else: + for i, pos in enumerate(bit_pos): + neq = x[f"{comp_id}{pos}_not_equal_{n_trails}"] + a = x[f"{comp_id}_{pos}"] + bit_val = bit_vals[i] + constraints.append(neq == (1 - a if bit_val else a)) + constraints.append(sum(x[f"{comp_id}{p}_not_equal_{n_trails}"] for p in bit_pos) >= 1) return constraints @@ -233,10 +268,10 @@ def weight_constraints(self, weight, weight_precision=MILP_DEFAULT_WEIGHT_PRECIS constraints = [] if weight >= 0: - constraints.append(p["probability"] == (10 ** weight_precision) * weight) + constraints.append(p["probability"] == (10**weight_precision) * weight) variables = [("p[probability]", p["probability"])] elif weight != -1: - self._model.set_max(p["probability"], - (10 ** weight_precision) * weight) + self._model.set_max(p["probability"], -(10**weight_precision) * weight) variables = [("p[probability]", p["probability"])] return variables, constraints @@ -267,49 +302,50 @@ def init_model_in_sage_milp_class(self, solver_name=SOLVER_DEFAULT): self._non_linear_component_id = [] def _solve_with_external_solver(self, model_type, model_path, solver_name=SOLVER_DEFAULT): - - solver_specs = [specs for specs in MILP_SOLVERS_EXTERNAL if specs["solver_name"] == solver_name.upper()][0] - solution_file_path = f'{MODEL_DEFAULT_PATH}/{model_path[:-3]}.sol' + solution_file_path = f"{MODEL_DEFAULT_PATH}/{model_path[:-3]}.sol" command = "" - for key in solver_specs['keywords']['command']['format']: - parameter = solver_specs['keywords']['command'][key] + for key in solver_specs["keywords"]["command"]["format"]: + parameter = solver_specs["keywords"]["command"][key] if key == "input_file": parameter += " " + model_path elif key == "output_file": - parameter = parameter + solution_file_path if parameter.endswith('=') else parameter + " " + solution_file_path + parameter = ( + parameter + solution_file_path if parameter.endswith("=") else parameter + " " + solution_file_path + ) elif key == "options": parameter = " ".join(parameter) command += " " + parameter tracemalloc.start() solver_process = subprocess.run(command, capture_output=True, shell=True, text=True) - milp_memory = tracemalloc.get_traced_memory()[1] / 10 ** 6 + milp_memory = tracemalloc.get_traced_memory()[1] / 10**6 tracemalloc.stop() if solver_process.stderr: raise MIPSolverException("Make sure that the solver is correctly installed.") - if 'memory' in solver_specs: - milp_memory = _get_data(solver_specs['keywords']['memory'], str(solver_process)) + if "memory" in solver_specs: + milp_memory = _get_data(solver_specs["keywords"]["memory"], str(solver_process)) - return _parse_external_solver_output(self, solver_specs, model_type, solution_file_path, solver_process.stdout) + (milp_memory,) + return _parse_external_solver_output( + self, solver_specs, model_type, solution_file_path, solver_process.stdout + ) + (milp_memory,) def _solve_with_internal_solver(self): - mip = self._model - status = 'UNSATISFIABLE' + status = UNSATISFIABLE self._verbose_print("Solving model in progress ...") time_start = time.time() tracemalloc.start() try: mip.solve() - status = 'SATISFIABLE' + status = SATISFIABLE except MIPSolverException as milp_exception: print(milp_exception) finally: - milp_memory = tracemalloc.get_traced_memory()[1] / 10 ** 6 + milp_memory = tracemalloc.get_traced_memory()[1] / 10**6 tracemalloc.stop() time_end = time.time() milp_time = time_end - time_start @@ -341,13 +377,16 @@ def solve(self, model_type, solver_name=SOLVER_DEFAULT, external_solver_name=Non if external_solver_name or (solver_name.upper().endswith("_EXT")): solver_choice = external_solver_name or solver_name if solver_choice.upper() not in [specs["solver_name"] for specs in MILP_SOLVERS_EXTERNAL]: - raise ValueError(f"Invalid solver name: {solver_choice}.\n" - f"Please select a solver in the following list: {[specs['solver_name'] for specs in MILP_SOLVERS_EXTERNAL]}.") + raise ValueError( + f"Invalid solver name: {solver_choice}.\n" + f"Please select a solver in the following list: {[specs['solver_name'] for specs in MILP_SOLVERS_EXTERNAL]}." + ) solver_name_in_solution = solver_choice model_path = _write_model_to_lp_file(self, model_type) - solution_file_path, status, objective_value, components_values, milp_time, milp_memory = self._solve_with_external_solver( - model_type, model_path, solver_choice) + solution_file_path, status, objective_value, components_values, milp_time, milp_memory = ( + self._solve_with_external_solver(model_type, model_path, solver_choice) + ) os.remove(model_path) os.remove(f"{solution_file_path}") else: @@ -355,22 +394,29 @@ def solve(self, model_type, solver_name=SOLVER_DEFAULT, external_solver_name=Non components_values = None solver_name_in_solution = solver_name status, milp_time, milp_memory = self._solve_with_internal_solver() - if status == 'SATISFIABLE': + if status == SATISFIABLE: objective_value, components_values = self._parse_solver_output() - solution = convert_solver_solution_to_dictionary(self._cipher, model_type, solver_name_in_solution, milp_time, - milp_memory, components_values, objective_value) - solution['status'] = status + solution = convert_solver_solution_to_dictionary( + self._cipher, + model_type, + solver_name_in_solution, + milp_time, + milp_memory, + components_values, + objective_value, + ) + solution["status"] = status return solution def solver_names(self, verbose=False): solver_names = [] - keys = ['solver_brand_name', 'solver_name'] + keys = ["solver_brand_name", "solver_name"] for solver in MILP_SOLVERS_INTERNAL: solver_names.append({key: solver[key] for key in keys}) if verbose: - keys = ['solver_brand_name', 'solver_name', 'keywords'] + keys = ["solver_brand_name", "solver_name", "keywords"] for solver in MILP_SOLVERS_EXTERNAL: solver_names.append({key: solver[key] for key in keys}) @@ -421,7 +467,7 @@ def model_constraints(self): ValueError: No model generated """ if not self._model_constraints: - raise ValueError('No model generated') + raise ValueError("No model generated") return self._model_constraints @property diff --git a/claasp/cipher_modules/models/milp/milp_models/Gurobi/__init__.py b/claasp/cipher_modules/models/milp/milp_models/Gurobi/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/claasp/cipher_modules/models/milp/milp_models/Gurobi/monomial_prediction.py b/claasp/cipher_modules/models/milp/milp_models/Gurobi/monomial_prediction.py new file mode 100644 index 000000000..473388577 --- /dev/null +++ b/claasp/cipher_modules/models/milp/milp_models/Gurobi/monomial_prediction.py @@ -0,0 +1,1969 @@ +# **************************************************************************** +# Copyright 2023 Technology Innovation Institute +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# **************************************************************************** + +import time +from sage.crypto.sbox import SBox +from collections import Counter +from sage.rings.polynomial.pbori.pbori import BooleanPolynomialRing +from claasp.cipher_modules.graph_generator import create_networkx_graph_from_input_ids, _get_predecessors_subgraph +from claasp.cipher_modules.component_analysis_tests import binary_matrix_of_linear_component +from gurobipy import Model, GRB, Env +import os +import secrets +from sage.all import GF + +verbosity = False + + +class MilpMonomialPredictionModel(): + """ + + Given a number of rounds of a chosen cipher and a chosen output bit, this module produces a model that can either: + - find the ANF of this chosen output bit, + - find an upper bound of this ANF, + - find the exact degree of this ANF (slower), + - find the superpoly of this ANF given a chosen cube. + + This module can only be used if the user possesses a Gurobi license. + + """ + + def __init__(self, cipher): + self._cipher = cipher + self._variables = None + self._model = None + self._occurences = None + self._used_variables = [] + self._variables_as_list = [] + self._unused_variables = [] + self._used_predecessors_sorted = None + self._constants = {} + + def build_gurobi_model(self): + if os.getenv('GUROBI_COMPUTE_SERVER') is not None: + env = Env(empty=True) + env.setParam('ComputeServer', os.getenv('GUROBI_COMPUTE_SERVER')) + env.start() + model = Model(env=env) + else: + model = Model() + model.Params.LogToConsole = 0 + self._model = model + + def get_all_variables_as_list(self): + for component_id in list(self._variables.keys())[:-1]: + for bit_position in self._variables[component_id].keys(): + self._variables_as_list.append(self._variables[component_id][bit_position]["original"].VarName) + copies = self._variables[component_id][bit_position]["copies"] + for copy in copies: + self._variables_as_list.append(copy.VarName) + + def get_unused_variables(self): + self.get_all_variables_as_list() + for variable in self._variables_as_list: + if variable not in self._used_variables: + self._unused_variables.append(variable) + + def set_unused_variables_to_zero(self): + self.get_unused_variables() + for name in self._unused_variables: + var = self._model.getVarByName(name) + self._model.addConstr(var == 0) + + def set_as_used_variables(self, variables): + self._model.update() + for v in variables: + try: + if v.VarName not in self._used_variables: + self._used_variables.append(v.VarName) + if "copy" in v.VarName.split("_"): + i = v.VarName.split("_").index("copy") + tmp1 = v.VarName.split("_")[(i + 2):] + tmp2 = "_".join(tmp1) + self._used_variables.append(tmp2) + self._unused_variables = [x for x in self._unused_variables if x != v.VarName] + except: + continue + + def create_all_copies(self): + for name in list(self._variables.keys())[:-1]: + for bit_position in self._variables[name].keys(): + copies = self._variables[name][bit_position]["copies"] + original_var = self._variables[name][bit_position]["original"] + + if copies != []: + for i in range(len(copies)): + self._model.addConstr(original_var >= copies[i]) + self._model.addConstr(sum(copies[i] for i in range(len(copies))) >= original_var) + self._model.update() + + def get_anfs_from_sbox(self, component): + anfs = [] + B = BooleanPolynomialRing(component.output_bit_size, 'x') + C = BooleanPolynomialRing(component.output_bit_size, 'x') + var_names = [f"x{i}" for i in range(component.output_bit_size)] + d = {} + for i in range(component.output_bit_size): + d[B(var_names[i])] = C(var_names[component.output_bit_size - i - 1]) + + sbox = SBox(component.description) + for i in range(component.input_bit_size): + anf = sbox.component_function(1 << i).algebraic_normal_form() + anf = anf.subs(d) # x0 was msb, now it is the lsb + anfs.append(anf) + anfs.reverse() + return anfs + + def get_monomial_occurences(self, component): + B = BooleanPolynomialRing(component.input_bit_size, 'x') + anfs = self.get_anfs_from_sbox(component) + + anfs = [B(anfs[i]) for i in range(component.input_bit_size)] + monomials = [] + for index, anf in enumerate(anfs): + if index in list(self._occurences[component.id].keys()): + monomials += anf.monomials() + monomials_degree_based = {} + sbox = SBox(component.description) + for deg in range(sbox.max_degree() + 1): + monomials_degree_based[deg] = dict( + Counter([monomial for monomial in monomials if monomial.degree() == deg])) + if deg >= 2: + for monomial in monomials_degree_based[deg].keys(): + deg1_monomials = monomial.variables() + for deg1_monomial in deg1_monomials: + if deg1_monomial not in monomials_degree_based[1].keys(): + monomials_degree_based[1][deg1_monomial] = 0 + monomials_degree_based[1][deg1_monomial] += monomials_degree_based[deg][monomial] + + sorted_monomials_degree_based = {1: {}} + for xi in B.variable_names(): + if B(xi) not in monomials_degree_based[1].keys(): + sorted_monomials_degree_based[1][B(xi)] = 0 + else: + sorted_monomials_degree_based[1][B(xi)] = monomials_degree_based[1][B(xi)] + for deg in range(sbox.max_degree() + 1): + if deg != 1: + sorted_monomials_degree_based[deg] = monomials_degree_based[deg] + + return sorted_monomials_degree_based + + def create_gurobi_vars_sbox(self, component, input_vars_concat): + monomial_occurences = self.get_monomial_occurences(component) + B = BooleanPolynomialRing(component.input_bit_size, 'x') + x = B.variable_names() + + copy_xi = {} + for index, xi in enumerate(monomial_occurences[1].keys()): + nb_occurence_xi = monomial_occurences[1][B(xi)] + if nb_occurence_xi != 0: + copy_xi[B(xi)] = self._model.addVars(list(range(nb_occurence_xi)), vtype=GRB.BINARY, + name="copy_" + input_vars_concat[index].VarName + "_as_" + str(xi)) + self._model.update() + self.set_as_used_variables(list(copy_xi[B(xi)].values())) + self.set_as_used_variables([input_vars_concat[index]]) + for i in range(nb_occurence_xi): + self._model.addConstr(input_vars_concat[index] >= copy_xi[B(xi)][i]) + self._model.addConstr( + sum(copy_xi[B(xi)][i] for i in range(nb_occurence_xi)) >= input_vars_concat[index]) + + copy_monomials_deg = {} + for deg in list(monomial_occurences.keys()): + if deg >= 2: + nb_monomials = sum(monomial_occurences[deg].values()) + copy_monomials_deg[deg] = self._model.addVars(list(range(nb_monomials)), vtype=GRB.BINARY) + self._model.update() + + copy_monomials_deg[1] = copy_xi + degrees = list(copy_monomials_deg.keys()) + for deg in degrees: + if deg >= 2: + copy_monomials_deg[deg]["current"] = 0 + elif deg == 1: + monomials = list(copy_monomials_deg[1].keys()) + for monomial in monomials: + copy_monomials_deg[deg][monomial]["current"] = 0 + self._model.update() + return copy_monomials_deg + + def add_sbox_constraints(self, component): + output_vars = self.get_output_vars(component) + input_vars_concat = self.get_input_vars(component) + self._model.update() + + B = BooleanPolynomialRing(component.input_bit_size, 'x') + x = B.variable_names() + anfs = self.get_anfs_from_sbox(component) + anfs = [B(anfs[i]) for i in range(component.input_bit_size)] + + copy_monomials_deg = self.create_gurobi_vars_sbox(component, input_vars_concat) + + for index, bit_pos in enumerate(list(self._occurences[component.id].keys())): + constr = 0 + equality = True + monomials = anfs[bit_pos].monomials() + for monomial in monomials: + deg = monomial.degree() + if deg == 1: + current = copy_monomials_deg[deg][monomial]["current"] + constr += copy_monomials_deg[deg][monomial][current] + copy_monomials_deg[deg][monomial]["current"] += 1 + elif deg >= 2: + current = copy_monomials_deg[deg]["current"] + for deg1_monomial in monomial.variables(): + current_deg1 = copy_monomials_deg[1][deg1_monomial]["current"] + self._model.addConstr( + copy_monomials_deg[deg][current] == copy_monomials_deg[1][deg1_monomial][current_deg1]) + self.set_as_used_variables([copy_monomials_deg[deg][current]]) + copy_monomials_deg[1][deg1_monomial]["current"] += 1 + constr += copy_monomials_deg[deg][current] + copy_monomials_deg[deg]["current"] += 1 + elif deg == 0: + equality = False + if equality: + self._model.addConstr(output_vars[index] == constr) + else: + self._model.addConstr(output_vars[index] >= constr) + self._model.update() + + def create_copies_for_linear_layer(self, binary_matrix, input_vars_concat): + copies = {} + for index, var in enumerate(input_vars_concat): + column = [row[index] for row in binary_matrix] + number_of_1s = list(column).count(1) + if number_of_1s > 1: + current = 1 + else: + current = 0 + copies[index] = {} + copies[index][0] = var + copies[index]["current"] = current + self.set_as_used_variables([var]) + new_vars = self._model.addVars(list(range(number_of_1s)), vtype=GRB.BINARY, + name="copy_" + var.VarName) + self._model.update() + for i in range(number_of_1s): + self._model.addConstr(var >= new_vars[i]) + self._model.addConstr( + sum(new_vars[i] for i in range(number_of_1s)) >= var) + self._model.update() + for i in range(1, number_of_1s + 1): + copies[index][i] = new_vars[i - 1] + return copies + + def add_linear_layer_constraints(self, component): + output_vars = self.get_output_vars(component) + input_vars_concat = self.get_input_vars(component) + + if component.type == "linear_layer": + binary_matrix = component.description + binary_matrix = list(zip(*binary_matrix)) + else: + binary_matrix = binary_matrix_of_linear_component(component) + + copies = self.create_copies_for_linear_layer(binary_matrix, input_vars_concat) + for index_row, row in enumerate(binary_matrix): + constr = 0 + for index_bit, bit in enumerate(row): + if bit: + current = copies[index_bit]["current"] + constr += copies[index_bit][current] + copies[index_bit]["current"] += 1 + self.set_as_used_variables([copies[index_bit][current]]) + self._model.addConstr(output_vars[index_row] == constr) + self._model.update() + + def add_rotate_constraints(self, component): + output_vars = self.get_output_vars(component) + input_vars_concat = self.get_input_vars(component) + self._model.update() + + rotate_offset = component.description[1] + for index, bit_pos in enumerate(list(self._occurences[component.id].keys())): + self._model.addConstr( + output_vars[index] == input_vars_concat[(bit_pos - rotate_offset) % component.output_bit_size]) + self.set_as_used_variables([input_vars_concat[(bit_pos - rotate_offset) % component.output_bit_size]]) + self._model.update() + + def add_shift_constraints(self, component): + output_vars = self.get_output_vars(component) + input_vars_concat = self.get_input_vars(component) + self._model.update() + + shift_offset = component.description[1] + + for index, bit_pos in enumerate(self._occurences[component.id].keys()): + target = bit_pos - shift_offset + + if target < 0 or target >= component.output_bit_size: + self._model.addConstr(output_vars[index] == 0) + else: + self._model.addConstr(output_vars[index] == input_vars_concat[target]) + self.set_as_used_variables([input_vars_concat[target]]) + + self._model.update() + + def add_xor_constraints(self, component): + output_vars = self.get_output_vars(component) + output_size = component.output_bit_size + + var_inputs_per_bit = [[] for _ in range(output_size)] + const_bits_per_bit = [[] for _ in range(output_size)] + + current_output_index = 0 + for input_idx, input_name in enumerate(component.input_id_links): + bit_positions = component.input_bit_positions[input_idx] + + for local_idx, pos in enumerate(bit_positions): + output_index = current_output_index % output_size + + if input_name.startswith("constant"): + const_comp = self._cipher.get_component_from_id(input_name) + value = (int(const_comp.description[0], 16) >> + (const_comp.output_bit_size - 1 - pos)) & 1 + const_bits_per_bit[output_index].append(value) + else: + copy_index = len(self._variables[input_name][pos]["copies"]) + copy_var = self._model.addVar( + vtype=GRB.BINARY, + name=f"copy_{copy_index}_{input_name}[{pos}]" + ) + self._variables[input_name][pos]["copies"].append(copy_var) + var_inputs_per_bit[output_index].append(copy_var) + current_output_index += 1 + + self._model.update() + for bit_idx in range(output_size): + vars_sum = sum(var_inputs_per_bit[bit_idx]) + for v in var_inputs_per_bit[bit_idx]: + self.set_as_used_variables([v]) + const_val = sum(const_bits_per_bit[bit_idx]) % 2 + if const_val == 0: + self._model.addConstr(output_vars[bit_idx] == vars_sum) + else: + self._model.addConstr(output_vars[bit_idx] >= vars_sum) + self._model.update() + + def get_output_vars(self, component): + output_vars = [] + tmp = list(self._occurences[component.id].keys()) + tmp.sort() + for i in tmp: + output_vars.append(self._model.getVarByName(f"{component.id}[{i}]")) + self._model.update() + return output_vars + + def get_input_vars(self, component): + input_vars_concat = [] + for index, input_name in enumerate(component.input_id_links): + for pos in component.input_bit_positions[index]: + copy_index = len(self._variables[input_name][pos]["copies"]) + copy = self._model.addVar(vtype=GRB.BINARY, name=f"copy_{copy_index}_{input_name}[{pos}]") + self._variables[input_name][pos]["copies"].append(copy) + input_vars_concat.append(copy) + self._model.update() + return input_vars_concat + + def add_modadd_constraints(self, component): + """ + Constraints are taken from https://eprint.iacr.org/2024/1335.pdf + """ + output_vars = self.get_output_vars(component) + input_vars_concat = self.get_input_vars(component) + self._model.update() + + total = len(input_vars_concat) + if total % 2 != 0: + raise ValueError("add_modadd_constraints: input length not even") + n = total // 2 + a_bits = input_vars_concat[:n] + b_bits = input_vars_concat[n:2 * n] + z_bits = output_vars + + # Rerverse endianess : index 0 corresponds to LSB now + a_bits = list(reversed(a_bits)) + b_bits = list(reversed(b_bits)) + z_bits = list(reversed(z_bits)) + + # Create carry-out variables for bits 0..n-1 + carry_vars = [None] * n + for i in range(n - 1): + carry_vars[i] = self._model.addVar(vtype=GRB.BINARY, + name=f"modadd_carry_{component.id}_{i}") + # top carry fixed to 0 + carry_vars[n - 1] = self._model.addVar(vtype=GRB.BINARY, lb=0, ub=0, + name=f"modadd_carry_{component.id}_{n - 1}_zero") + self._model.update() + + for i in range(n): + ai = a_bits[i] + bi = b_bits[i] + zi = z_bits[i] + + # carry-in for bit i + if i == 0: + c_in = None # no carry into LSB + else: + c_in = carry_vars[i - 1] + + s_i = self._model.addVar(vtype=GRB.INTEGER, lb=0, ub=3, + name=f"modadd_sum_{component.id}_{i}") + if c_in is not None: + self._model.addConstr(s_i == ai + bi + c_in) + else: + self._model.addConstr(s_i == ai + bi) + + t_i = carry_vars[i] + self._model.addConstr(zi + 2 * t_i == s_i) + + self.set_as_used_variables([ai, bi, zi, t_i, s_i]) + + self._model.update() + + def add_and_constraints(self, component): + output_vars = self.get_output_vars(component) + input_vars_concat = self.get_input_vars(component) + self._model.update() + + block_size = int(len(input_vars_concat) // component.description[1]) + for index, bit_pos in enumerate(list(self._occurences[component.id].keys())): + self._model.addConstr(output_vars[index] == input_vars_concat[index]) + self._model.addConstr(output_vars[index] == input_vars_concat[index + block_size]) + self.set_as_used_variables([input_vars_concat[index], input_vars_concat[index + block_size]]) + self._model.update() + + def add_fsr_constraints(self, component): + output_bit_size = component.output_bit_size + + output_vars = {} + tmp = list(self._occurences[component.id].keys()) + tmp.sort() + for i in tmp: + output_vars[i] = self._model.getVarByName(f"{component.id}[{i}]") + + input_vars_concat = self.get_input_vars(component) + self._model.update() + + interm_input_vars = self._model.addVars(list(range(output_bit_size)), vtype=GRB.BINARY, name=f"interm_input") + for i in range(output_bit_size): + self._model.addConstr(interm_input_vars[i] == input_vars_concat[i]) + self.set_as_used_variables([input_vars_concat[i]]) + + if len(component.description) == 2: + number_of_initialization_clocks = 1 + else: + number_of_initialization_clocks = component.description[-1] + + registers = component.description[0] + registers_lengths = [registers[i][0] for i in range(len(registers))] + registers_lengths_accumulated = [0] + for value in registers_lengths: + registers_lengths_accumulated.append(registers_lengths_accumulated[-1] + value) + + s = {} + s[0] = list(interm_input_vars.values()) + + for clock in range(number_of_initialization_clocks): + tmp = s[clock][:] + self._model.update() + + new_bits = [] + for register in registers: + polynomial = 0 + monomials_indexes = register[1] + for indexes in monomials_indexes: + if len(indexes) > 1: + a = self._model.addVar(vtype=GRB.BINARY) + self._model.update() + y = self._model.addVars(indexes, vtype=GRB.BINARY) + for index in indexes: + self._model.addConstr(y[index] <= tmp[index]) + self._model.addConstr(a <= tmp[index]) + self._model.addConstr(y[index] + a >= tmp[index]) + tmp[index] = y[index] + monomial = a + else: + index = indexes[0] + if index not in registers_lengths_accumulated: + y = self._model.addVar(vtype=GRB.BINARY) + z = self._model.addVar(vtype=GRB.BINARY) + self._model.addConstr(y <= tmp[index]) + self._model.addConstr(z <= tmp[index]) + self._model.addConstr(y + z >= tmp[index]) + monomial = z + tmp[index] = y + else: + monomial = tmp[index] + polynomial += monomial + polynomial_var = self._model.addVar(vtype=GRB.BINARY, name=f"product_{register[0]}_clock_{clock}") + self._model.update() + self._model.addConstr(polynomial_var == polynomial) + new_bits.append(polynomial_var) + self._model.update() + + new_bits = new_bits[-1:] + new_bits[:-1] + for index, length in enumerate(registers_lengths_accumulated[:-1]): + tmp[length] = new_bits[index] + + self._model.update() + s[clock + 1] = [] + for index in range(output_bit_size): + s[clock + 1].append(tmp[(index + 1) % output_bit_size]) + + interm_output_vars = self._model.addVars(list(range(output_bit_size)), vtype=GRB.BINARY, + name=f"interm_{component.id}_output") + self._model.update() + self._variables[f"interm_{component.id}_output"] = {} + for index, var in enumerate(interm_output_vars.values()): + self._variables[f"interm_{component.id}_output"][index] = {"original": var, "copies": []} + + for position in range(component.output_bit_size): + self._model.addConstr(interm_output_vars[position] == s[number_of_initialization_clocks][position]) + + self._model.update() + for position in list(self._occurences[component.id].keys()): + self._model.addConstr(output_vars[position] == interm_output_vars[position]) + self.set_as_used_variables([interm_output_vars[position]]) + + self._model.update() + + def add_not_constraints(self, component): + output_vars = self.get_output_vars(component) + input_vars_concat = self.get_input_vars(component) + self._model.update() + + for index, bit_pos in enumerate(list(self._occurences[component.id].keys())): + self._model.addConstr(output_vars[index] >= input_vars_concat[index]) + self.set_as_used_variables([input_vars_concat[index]]) + self._model.update() + + def add_constant_constraints(self, component): + self._constants[component.id] = {} + output_vars = self.get_output_vars(component) + + if component.description[0].startswith("0b"): + const = int(component.description[0], 2) + elif component.description[0].startswith("0x"): + const = int(component.description[0], 16) + else: + raise ValueError("Unknown format: must start with 0b or 0x") + + for i, bit_pos in enumerate(list(self._occurences[component.id].keys())): + if (const >> (component.output_bit_size - 1 - i)) & 1 == 0: + self._model.addConstr(output_vars[i] == 0) + self._constants[component.id][i] = 0 + else: + self._constants[component.id][i] = 1 + self._model.update() + + def add_or_constraints(self, component): + """ + The OR operation is modeled as: + y = OR(x1, x2, ..., xn) + Then: + - y >= xi for each input xi + - y <= sum(xi) + """ + output_vars = self.get_output_vars(component) + output_size = component.output_bit_size + + var_inputs_per_bit = [[] for _ in range(output_size)] + + for input_idx, input_name in enumerate(component.input_id_links): + bit_positions = component.input_bit_positions[input_idx] + + for local_idx, pos in enumerate(bit_positions): + output_index = pos % output_size + + copy_index = len(self._variables[input_name][pos]["copies"]) + copy_var = self._model.addVar( + vtype=GRB.BINARY, + name=f"copy_{copy_index}_{input_name}[{pos}]" + ) + self._variables[input_name][pos]["copies"].append(copy_var) + var_inputs_per_bit[output_index].append(copy_var) + + self._model.update() + + for bit_idx in range(output_size): + input_vars = var_inputs_per_bit[bit_idx] + output_var = output_vars[bit_idx] + + if not input_vars: + continue + + for v in input_vars: + self._model.addConstr(output_var >= v) + self._model.addConstr(output_var <= sum(input_vars)) + self.set_as_used_variables(input_vars) + self._model.update() + + def add_intermediate_output_constraints(self, component): + output_vars = self.get_output_vars(component) + input_vars_concat = self.get_input_vars(component) + self._model.update() + + for index, bit_pos in enumerate(list(self._occurences[component.id].keys())): + self._model.addConstr(output_vars[index] == input_vars_concat[bit_pos]) + self.set_as_used_variables([input_vars_concat[bit_pos]]) + self._model.update() + + def get_cipher_output_component_id(self): + for component in self._cipher.get_all_components(): + if component.type == "cipher_output": + return component.id + + def add_constraints(self, predecessors, input_id_link_needed, block_needed): + self.build_gurobi_model() + self.create_gurobi_vars_from_all_components(predecessors, input_id_link_needed, block_needed) + + used_predecessors_sorted = self.order_predecessors(list(self._occurences.keys())) + self._used_predecessors_sorted = used_predecessors_sorted + for component_id in used_predecessors_sorted: + if component_id not in self._cipher.inputs: + component = self._cipher.get_component_from_id(component_id) + print(f"---> {component.id}") if verbosity else None + if component.type == "sbox": + self.add_sbox_constraints(component) + elif component.type == "fsr": + self.add_fsr_constraints(component) + elif component.type == "constant": + self.add_constant_constraints(component) + elif component.type in ["linear_layer", "mix_column"]: + self.add_linear_layer_constraints(component) + elif component.type in ["cipher_output", "intermediate_output"]: + self.add_intermediate_output_constraints(component) + elif component.type == "word_operation": + if component.description[0] == "XOR": + self.add_xor_constraints(component) + elif component.description[0] == "ROTATE": + self.add_rotate_constraints(component) + elif component.description[0] == "SHIFT": + self.add_shift_constraints(component) + elif component.description[0] == "AND": + self.add_and_constraints(component) + elif component.description[0] == "NOT": + self.add_not_constraints(component) + elif component.description[0] == "OR": + self.add_or_constraints(component) + elif component.description[0] == "MODADD": + self.add_modadd_constraints(component) + else: + raise NotImplementedError(f"Component {component.description[0]} is not yet implemented") + else: + raise NotImplementedError(f"Component {component.description[0]} is not yet implemented") + + return self._model + + def get_where_component_is_used(self, predecessors, input_id_link_needed, block_needed): + occurences = {} + ids = self._cipher.inputs + predecessors + for name in ids: + for component_id in predecessors: + component = self._cipher.get_component_from_id(component_id) + if name in component.input_id_links: + indexes = [i for i, j in enumerate(component.input_id_links) if j == name] + if name not in occurences.keys(): + occurences[name] = [] + for index in indexes: + occurences[name].append(component.input_bit_positions[index]) + if input_id_link_needed in self._cipher.inputs: + occurences[input_id_link_needed] = [block_needed] + else: + component = self._cipher.get_component_from_id(input_id_link_needed) + occurences[input_id_link_needed] = [[i for i in range(component.output_bit_size)]] + + cipher_id = self.get_cipher_output_component_id() + if input_id_link_needed == cipher_id: + component = self._cipher.get_component_from_id(cipher_id) + occurences[cipher_id] = [[i for i in range(component.output_bit_size)]] + + occurences_final = {} + for component_id in occurences.keys(): + occurences_final[component_id] = self.find_copy_indexes(occurences[component_id]) + + self._occurences = occurences_final + return occurences_final + + def find_copy_indexes(self, input_bit_positions): + l = {} + for input_bit_position in input_bit_positions: + for pos in input_bit_position: + if pos not in l.keys(): + l[pos] = 0 + l[pos] += 1 + return l + + def order_predecessors(self, used_predecessors): + for component_id in self._cipher.inputs: + if component_id in list(self._occurences.keys()): + used_predecessors.remove(component_id) + tmp = {} + final = {} + for r in range(self._cipher.number_of_rounds): + tmp[r] = {} + for component_id in used_predecessors: + if int(component_id.split("_")[-2]) == r: + tmp[r][component_id] = int(component_id.split("_")[-1]) + final[r] = {k: v for k, v in sorted(tmp[r].items(), key=lambda item: item[1])} + + used_predecessors_sorted = [] + for r in range(self._cipher.number_of_rounds): + used_predecessors_sorted += list(final[r].keys()) + + l = [] + for component_id in self._cipher.inputs: + if component_id in list(self._occurences.keys()): + l.append(component_id) + used_predecessors_sorted = l + used_predecessors_sorted + return used_predecessors_sorted + + def create_gurobi_vars_from_all_components(self, predecessors, input_id_link_needed, block_needed): + occurences = self.get_where_component_is_used(predecessors, input_id_link_needed, block_needed) + all_vars = {} + used_predecessors_sorted = self.order_predecessors(list(occurences.keys())) + cipher_id = self.get_cipher_output_component_id() + for component_id in used_predecessors_sorted: + all_vars[component_id] = {} + if component_id != cipher_id: + for pos in list(occurences[component_id].keys()): + all_vars[component_id][pos] = {} + all_vars[component_id][pos]["original"] = self._model.addVar(vtype=GRB.BINARY, + name=component_id + f"[{pos}]") + all_vars[component_id][pos]["copies"] = [] + else: + component = self._cipher.get_component_from_id(cipher_id) + for pos in range(component.output_bit_size): + all_vars[component_id][pos] = {} + all_vars[component_id][pos]["original"] = self._model.addVar(vtype=GRB.BINARY, + name=component_id + f"[{pos}]") + all_vars[component_id][pos]["copies"] = [] + + self._model.update() + self._variables = all_vars + + def find_index_second_input(self): + occurences = self._occurences + return len(list(occurences[self._cipher.inputs[0]].keys())) + + def build_generic_model_for_specific_output_bit(self, output_bit_index, fixed_degree=None, + which_var_degree=None, + chosen_cipher_output=None): + start = time.time() + + if chosen_cipher_output != None: + input_id_link_needed = chosen_cipher_output + else: + input_id_link_needed = self.get_cipher_output_component_id() + component = self._cipher.get_component_from_id(input_id_link_needed) + block_needed = list(range(component.output_bit_size)) + output_bit_index_previous_comp = output_bit_index + + G = create_networkx_graph_from_input_ids(self._cipher) + predecessors = list(_get_predecessors_subgraph(G, [input_id_link_needed])) + for input_id in self._cipher.inputs + ['']: + if input_id in predecessors: + predecessors.remove(input_id) + + self.add_constraints(predecessors, input_id_link_needed, block_needed) + + var_from_block_needed = [] + for i in block_needed: + var_from_block_needed.append(self._variables[input_id_link_needed][i]["original"]) + + output_vars = self._model.addVars(list(range(len(block_needed))), vtype=GRB.BINARY, name="output") + self._variables["output"] = output_vars + output_vars = list(output_vars.values()) + self._model.update() + + for i in range(len(block_needed)): + self._model.addConstr(output_vars[i] == var_from_block_needed[i]) + self.set_as_used_variables([output_vars[i], var_from_block_needed[i]]) + + ks = self._model.addVar() + self._model.addConstr(ks == sum(output_vars[i] for i in range(len(block_needed)))) + self._model.addConstr(ks == 1) + self._model.addConstr(output_vars[output_bit_index_previous_comp] == 1) + + if fixed_degree is not None: + if which_var_degree is not None: + var_input_name = next( + (inp for inp in self._cipher.inputs if inp.startswith(which_var_degree)), + None + ) + if var_input_name is None: + raise ValueError(f"No input found matching prefix '{which_var_degree}'") + else: + var_input_name = self._cipher.inputs[0] + + input_index = self._cipher.inputs.index(var_input_name) + input_size = self._cipher.inputs_bit_size[input_index] + + vars_to_constrain = [] + for i in range(input_size): + v = self._model.getVarByName(f"{var_input_name}[{i}]") + if v is not None: + vars_to_constrain.append(v) + + self._model.addConstr(sum(vars_to_constrain) == fixed_degree, + name=f"degree_{var_input_name}_{fixed_degree}") + + self.set_unused_variables_to_zero() + self.create_all_copies() + self._model.update() + end = time.time() + building_time = end - start + if verbosity: + print(f"########## building_time : {building_time}") + self._model.update() + + def _prefix_for_input(self, name: str) -> str: + return name[:1].lower() + + def get_solutions(self): + start = time.time() + solCount = self._model.SolCount + inputs = [] + for prio, inp_name in enumerate(self._cipher.inputs): + if inp_name not in self._variables: + continue + prefix = self._prefix_for_input(inp_name) + for idx, d in self._variables[inp_name].items(): + inputs.append((prio, prefix, idx, d["original"])) + inputs.sort(key=lambda t: (t[0], t[1], t[2])) + + mono_set = set() + for sn in range(solCount): + self._model.setParam(GRB.Param.SolutionNumber, sn) + toks = [] + for _, prefix, idx, var in inputs: + if var.Xn > 0.5: + toks.append(f"{prefix}{idx}") + mono = "1" if not toks else "".join(toks) + if mono in mono_set: + mono_set.remove(mono) + else: + mono_set.add(mono) + end = time.time() + printing_time = end - start + if verbosity: + print('Number of solutions (might cancel each other) found: ' + str(solCount)) + print(f"########## printing_time : {printing_time}") + print(f'Number of monomials found: {len(mono_set)}') + monomials_list = sorted(mono_set) + return self.anf_list_to_boolean_poly(monomials_list) + + def optimize_model(self): + start = time.time() + self._model.optimize() + end = time.time() + solving_time = end - start + if verbosity: + print(self._model) + print(f"########## solving_time : {solving_time}") + + def anf_list_to_boolean_poly(self, anf_list): + variables = [] + for index, input_name in enumerate(self._cipher.inputs): + bit_size = self._cipher.inputs_bit_size[index] + variables.extend([f"{input_name[0]}{i}" for i in range(bit_size)]) + + B = BooleanPolynomialRing(names=variables) + var_map = {str(v): B(str(v)) for v in variables} + + poly = B(0) + for term in anf_list: + if term == "1": + term_poly = B(1) + else: + i = 0 + factors = [] + while i < len(term): + var = term[i] + i += 1 + digits = '' + while i < len(term) and term[i].isdigit(): + digits += term[i] + i += 1 + factors.append(var_map[f"{var}{digits}"]) + term_poly = factors[0] + for f in factors[1:]: + term_poly *= f + poly += term_poly + return poly + + def get_boolean_polynomial_ring(self): + variables = [] + for index, input_name in enumerate(self._cipher.inputs): + bit_size = self._cipher.inputs_bit_size[index] + variables.extend([f"{input_name[0]}{i}" for i in range(bit_size)]) + R = BooleanPolynomialRing(names=variables) + return R + + def var_list_to_input_positions(self, var_list): + """ + Convert flat variable names (e.g., ``['p1', 'k8']``) into structured + input references tied to the cipher's input components. + + Each variable name's first letter (e.g., ``'p'``, ``'k'``, ``'i'``) + is mapped to its corresponding input (e.g., ``'plaintext'``, ``'key'``, + ``'initialisation_vector'``), and its numeric suffix is treated as the bit index. + For example, ``['p1', 'k8']`` → ``[('plaintext', 1), ('key', 8)]``. + """ + input_map = {} + for index, input_name in enumerate(self._cipher.inputs): + bit_size = self._cipher.inputs_bit_size[index] + prefix = input_name[0] # e.g., 'p' for plaintext, 'k' for key + input_map[prefix] = (input_name, bit_size) + + results = [] + for var in var_list: + prefix = var[0] + index = int(var[1:]) + input_name, bit_size = input_map[prefix] + + if index >= bit_size: + raise ValueError(f"Index {index} out of range for input '{input_name}' (size {bit_size})") + results.append((input_name, index)) + return results + + def re_init(self): + self._variables = None + self._model = None + self._occurences = None + self._used_variables = [] + self._variables_as_list = [] + self._unused_variables = [] + self._used_predecessors_sorted = None + self._constants = {} + + def find_anf_of_specific_output_bit(self, output_bit_index, fixed_degree=None, which_var_degree=None, + chosen_cipher_output=None): + """ + Build and solve the MILP model to compute the Algebraic Normal Form (ANF) + of a specific output bit of the cipher using the Monomial Prediction (MP) approach. + + By default, the model enumerates all possible monomials contributing to the selected output bit. + Optionally, a degree constraint can be applied to restrict the search to monomials of a fixed degree. + + INPUT: + + - ``output_bit_index`` -- **integer**; index of the ciphertext bit whose ANF is to be computed. + - ``fixed_degree`` -- **integer** (default: ``None``); if not ``None``, only monomials + whose degree equals this value are returned. + - ``which_var_degree`` -- **string** (default: ``None``); prefix or full name of the input + variable on which the degree constraint (``fixed_degree``) is applied. + Typical values include: + * ``"p"`` or ``"plaintext"`` for plaintext variables + * ``"k"`` or ``"key"`` for key variables + * ``"i"`` for initialization vector variables + If ``None``, defaults to the first input listed in ``self._cipher.inputs``. + - ``chosen_cipher_output`` -- **string** (default: ``None``); specify a cipher component + ID if you want to compute the ANF for an intermediate output instead of the final cipher output. + + EXAMPLES:: + + # Example 1: Compute the ANF of the first ciphertext bit in SIMON (round 1) + sage: from claasp.ciphers.block_ciphers.simon_block_cipher import SimonBlockCipher + sage: cipher = SimonBlockCipher(number_of_rounds=1) + sage: from claasp.cipher_modules.models.milp.milp_models.Gurobi.monomial_prediction import MilpMonomialPredictionModel + sage: milp = MilpMonomialPredictionModel(cipher) # doctest: +SKIP + sage: milp.find_anf_of_specific_output_bit(0) # doctest: +SKIP + sage: R = milp.get_boolean_polynomial_ring() # doctest: +SKIP + sage: anf == R("p1*p8 + p2 + p16 + k48") # doctest: +SKIP + ... + + # Example 2: Restrict the analysis to degree-2 monomials on plaintext variables + sage: anf = milp.find_anf_of_specific_output_bit(0, fixed_degree=2, which_var_degree="p") # doctest: +SKIP + sage: anf == R("p1*p8") # doctest: +SKIP + ... + + # Example 3: Restrict the analysis to degree-1 monomials on key variables + sage: milp.find_anf_of_specific_output_bit(0, fixed_degree=1, which_var_degree="k") # doctest: +SKIP + sage: anf == R("k48") # doctest: +SKIP + ... + """ + + self.build_generic_model_for_specific_output_bit(output_bit_index, fixed_degree, which_var_degree, + chosen_cipher_output) + self._model.setParam("PoolSolutions", 200000000) + self._model.setParam(GRB.Param.PoolSearchMode, 2) + + self.optimize_model() + anf = self.get_solutions() + self._log_experiment( + "anf", + { + "output_bit_index": output_bit_index, + "fixed_degree": fixed_degree, + "which_var_degree": which_var_degree, + "chosen_cipher_output": chosen_cipher_output, + }, + anf + ) + + return anf + + def check_anf_correctness(self, output_bit_index, num_tests=10, endian="msb"): + """ + Verify the correctness of the computed Algebraic Normal Form (ANF) + for a specific cipher output bit by random testing. + + This method compares the value of an output bit obtained from the + cipher evaluation and from its ANF evaluation, across several + random input assignments. + + INPUT: + + - ``output_bit_index`` -- **integer**; index (0-based) of the output bit to test. + The indexing direction depends on the ``endian`` parameter. + - ``num_tests`` -- **integer** (default: ``10``); number of random input assignments + to test. + - ``endian`` -- **string** (default: ``"msb"``); defines how bit positions are indexed + and extracted: + * ``"msb"`` : bit index 0 corresponds to the most significant bit (default) + * ``"lsb"`` : bit index 0 corresponds to the least significant bit + + OUTPUT: + + - **bool**; returns ``True`` if the ANF output matches the cipher output + for all tested input assignments, ``False`` otherwise. + + EXAMPLES:: + + sage: from claasp.ciphers.block_ciphers.simon_block_cipher import SimonBlockCipher + sage: cipher = SimonBlockCipher(number_of_rounds=2) + sage: from claasp.cipher_modules.models.milp.milp_models.Gurobi.monomial_prediction import MilpMonomialPredictionModel + sage: milp = MilpMonomialPredictionModel(cipher) # doctest: +SKIP + sage: milp.check_anf_correctness(0, endian="msb") # doctest: +SKIP + ... + """ + + # 1) Generate random test vectors for all cipher inputs + test_vectors = [] + for _ in range(num_tests): + assignment = {} + for inp, size in zip(self._cipher.inputs, self._cipher.inputs_bit_size): + assignment[inp] = secrets.randbits(size) + test_vectors.append(assignment) + + # 2) Compute the ANF for the specified output bit + anf_poly = self.find_anf_of_specific_output_bit(output_bit_index) + print("ANF:", anf_poly) if verbosity else None + + B = self.get_boolean_polynomial_ring() + + # 3) Helper: evaluate the ANF polynomial for a given input assignment + def evaluate_poly(assignments): + var_values = {} + for inp, size in zip(self._cipher.inputs, self._cipher.inputs_bit_size): + val = assignments[inp] + for i in range(size): + if endian == "msb": + # MSB-first: inp0 = MSB, inp{n-1} = LSB + bit = (val >> (size - 1 - i)) & 1 + elif endian == "lsb": + # LSB-first: inp0 = LSB, inp{n-1} = MSB + bit = (val >> i) & 1 + else: + raise ValueError("Invalid endian value. Use 'msb' or 'lsb'.") + var_values[f"{inp[0]}{i}"] = bit + return int(GF(2)(anf_poly(**var_values))) + + # 4) Evaluate and compare ANF vs cipher outputs + output_size = self._cipher.output_bit_size + for trial, assign in enumerate(test_vectors): + print(f"trial = {trial}") if verbosity else None + cipher_output = self._cipher.evaluate( + [assign[inp] for inp in self._cipher.inputs] + ) + if endian == "msb": + real_index = output_size - 1 - output_bit_index + else: + real_index = output_bit_index + + expected_bit = (cipher_output >> real_index) & 1 + computed_bit = evaluate_poly(assign) + + if expected_bit != computed_bit: + return False + return True + + def find_superpoly_of_specific_output_bit(self, output_bit_index, cube, chosen_cipher_output=None): + """ + Compute the superpoly of a specific cipher output bit under a given cube. + + INPUT: + + - ``cube`` -- **list of strings**; variable names forming the cube. + Each variable follows the convention: + * ``"i"`` prefix for IV bits + * ``"p"`` prefix for plaintext bits + Example: ``["i9", "i19", "i29", "i39", "i49", "i59", "i69", "i79"]``. + - ``output_bit_index`` -- **integer**; index (0-based, counting from the most significant bit) + of the cipher output bit for which the superpoly is computed. + - ``chosen_cipher_output`` -- **string** (default: ``None``); specify a cipher component + ID if the computation targets an intermediate output instead of the final cipher output. + + OUTPUT: + + - **BooleanPolynomial**; the resulting superpoly polynomial in the Boolean ring. + + EXAMPLES:: + + sage: from claasp.ciphers.stream_ciphers.trivium_stream_cipher import TriviumStreamCipher + sage: cipher = TriviumStreamCipher(keystream_bit_len=1, number_of_initialization_clocks=590) + sage: from claasp.cipher_modules.models.milp.milp_models.Gurobi.monomial_prediction import MilpMonomialPredictionModel + sage: milp = MilpMonomialPredictionModel(cipher) # doctest: +SKIP + sage: cube = ["i9", "i19", "i29", "i39", "i49", "i59", "i69", "i79"] + sage: superpoly = milp.find_superpoly_of_specific_output_bit(output_bit_index=0, cube) # doctest: +SKIP + sage: R = milp.get_boolean_polynomial_ring() # doctest: +SKIP + sage: superpoly == R("k20*i60*i61 + k20*i60*i74 + k20*i60 + k20*i73 + i8*i60*i61 + i8*i60*i74 + i8*i60 + i8*i73 + i60*i61*i71 + i60*i61*i72*i73 + i60*i71*i74 + i60*i71 + i60*i72*i73*i74 + i60*i72*i73 + i71*i73 + i72*i73") # doctest: +SKIP + ... + """ + + fixed_degree = None + which_var_degree = None + self.build_generic_model_for_specific_output_bit(output_bit_index, fixed_degree, which_var_degree, + chosen_cipher_output) + self._model.setParam("PoolSolutions", 200000000) + self._model.setParam(GRB.Param.PoolSearchMode, 2) + + # Convert compact cube names like "i9" -> ("initialisation_vector", 9) + cube_verbose = self.var_list_to_input_positions(cube) + + for term in cube_verbose: + var_term = self._model.getVarByName(f"{term[0]}[{term[1]}]") + self._model.update() + self._model.addConstr(var_term == 1) + + self._model.update() + self.optimize_model() + poly = self.get_solutions() + + assignments = {v: 1 for v in cube} + poly_sub = poly.subs(assignments) + + self._log_experiment( + "superpoly", + { + "output_bit_index": output_bit_index, + "chosen_cipher_output": chosen_cipher_output, + "cube": cube, + }, + poly_sub + ) + + return poly_sub + + def find_exact_degree_of_superpoly_of_specific_output_bit( + self, output_bit_index, cube, chosen_cipher_output=None): + """ + Compute the exact algebraic degree of the superpoly + corresponding to a specific cipher output bit under a given cube. + + INPUT: + + - ``cube`` -- list[str]; variable names forming the cube (e.g. ["i9", "i19", ...]) + - ``output_bit_index`` -- int; index (0-based, MSB-first) of the cipher output bit. + - ``chosen_cipher_output`` -- str | None; specify a cipher component ID if + targeting an intermediate output. + + OUTPUT: + - integer; exact algebraic degree of the superpoly (with respect to key variables). + """ + + # === 1. Build generic model for the chosen output bit + fixed_degree = None + which_var_degree = None + self.build_generic_model_for_specific_output_bit( + output_bit_index, fixed_degree, which_var_degree, chosen_cipher_output) + + m = self._model + m.Params.OutputFlag = 0 + m.setParam(GRB.Param.PoolSearchMode, 2) + m.setParam(GRB.Param.PoolSolutions, 200000000) + m.setParam(GRB.Param.PoolGap, 0.0) + + cube_verbose = self.var_list_to_input_positions(cube) + for term in cube_verbose: + var_term = m.getVarByName(f"{term[0]}[{term[1]}]") + if var_term is not None: + m.addConstr(var_term == 1) + m.update() + + key_input_index = None + for i, inp in enumerate(self._cipher.inputs): + if inp.startswith("k"): + key_input_index = i + break + if key_input_index is None: + raise ValueError("No key input found in cipher definition.") + + key_size = self._cipher.inputs_bit_size[key_input_index] + key_vars = [ + m.getVarByName(f"key[{i}]") for i in range(key_size) + if m.getVarByName(f"key[{i}]") is not None + ] + + m.setObjective(sum(key_vars), GRB.MAXIMIZE) + m.update() + m.optimize() + + degree_drop = False + if m.Status not in [GRB.OPTIMAL, GRB.SUBOPTIMAL]: + print(f"[INFO] Model is infeasible") if verbosity else None + exact_degree = -1 + else: + d = int(round(m.ObjVal)) + monomial_parity = {} + for s in range(m.SolCount): + m.Params.SolutionNumber = s + active_indices = tuple(i for i, v in enumerate(key_vars) if v.Xn > 0.5) + if len(active_indices) == d: + monomial_parity[active_indices] = monomial_parity.get(active_indices, 0) ^ 1 + + if any(val == 1 for val in monomial_parity.values()): + exact_degree = d + else: + degree_drop = True + exact_degree = d - 1 + + self._log_experiment( + "exact degree superpoly", + { + "output_bit_index": output_bit_index, + "chosen_cipher_output": chosen_cipher_output, + "cube": cube, + "degree_drop": degree_drop + }, + exact_degree + ) + + return exact_degree + + def find_exact_degree_of_superpoly_of_all_output_bits(self, cube, chosen_cipher_output=None): + """ + Compute the exact algebraic degree of the superpoly + for all output bits of the cipher under a given cube. + + INPUT: + + - ``cube`` -- list[str]; variable names forming the cube (e.g. ["i9", "i19", ...]) + - ``chosen_cipher_output`` -- str | None; specify a cipher component ID if + the computation targets an intermediate output instead of the final cipher output. + + OUTPUT: + - list[int]; exact algebraic degrees of the superpoly for each output bit. + """ + + global verbosity + old_verbosity = verbosity + verbosity = False + + degrees = [] + output_size = self._cipher.output_bit_size + + for i in range(output_size): + self.re_init() + deg_exact = self.find_exact_degree_of_superpoly_of_specific_output_bit( + i, cube, chosen_cipher_output) + degrees.append(deg_exact) + + verbosity = old_verbosity + self._log_experiment( + "all output bits exact degree superpoly", + { + "chosen_cipher_output": chosen_cipher_output, + "cube": cube + }, + degrees + ) + + return degrees + + def find_upper_bound_degree_of_specific_output_bit(self, output_bit_index, which_var_degree=None, + chosen_cipher_output=None): + """ + Compute an upper bound on the algebraic degree of a specific cipher output bit + with respect to a chosen input variable (e.g., key, IV, or plaintext). + + INPUT: + + - ``output_bit_index`` -- **integer**; index (0-based, counting from the most significant bit) + of the cipher output bit to analyze. + - ``which_var_degree`` -- **string** (default: ``None``); prefix identifying which + input the algebraic degree should be computed over: + * ``"k"`` → degree with respect to key bits + * ``"p"`` → degree with respect to plaintext bits + * ``"i"`` → degree with respect to IV bits + If ``None`` (default), the first input listed in ``self._cipher.inputs`` is used. + - ``chosen_cipher_output`` -- **string** (default: ``None``); specify a cipher component + ID if the computation targets an intermediate output instead of the final cipher output. + + OUTPUT: + + - **integer**; upper bound on the algebraic degree of the selected output bit + with respect to the chosen input variable group. + + EXAMPLES:: + + sage: from claasp.ciphers.stream_ciphers.trivium_stream_cipher import TriviumStreamCipher + sage: cipher = TriviumStreamCipher(keystream_bit_len=1, number_of_initialization_clocks=508) + sage: from claasp.cipher_modules.models.milp.milp_models.Gurobi.monomial_prediction import MilpMonomialPredictionModel + sage: milp = MilpMonomialPredictionModel(cipher) # doctest: +SKIP + sage: milp.find_upper_bound_degree_of_specific_output_bit(0, which_var_degree="i") # doctest: +SKIP + ... + """ + fixed_degree = None + self.build_generic_model_for_specific_output_bit(output_bit_index, fixed_degree, which_var_degree, + chosen_cipher_output) + + self._model.setParam(GRB.Param.PoolSearchMode, 0) # single optimal solution (fastest) + self._model.setParam("MIPGap", 0) + self._model.Params.OutputFlag = 0 + + if which_var_degree is None: + target_inputs = [(self._cipher.inputs[0], self._cipher.inputs_bit_size[0])] + else: + target_inputs = [ + (inp, size) + for inp, size in zip(self._cipher.inputs, self._cipher.inputs_bit_size) + if inp.startswith(which_var_degree) + ] + + vars_target = [] + for inp, size in target_inputs: + for i in range(size): + var = self._model.getVarByName(f"{inp}[{i}]") + if var is not None: + vars_target.append(var) + + self._model.setObjective(sum(vars_target), GRB.MAXIMIZE) + self._model.update() + self.optimize_model() + + if self._model.Status not in [GRB.OPTIMAL, GRB.SUBOPTIMAL]: + print(f"[INFO] Model is infeasible") if verbosity else None + degree_upper_bound = -1 + else: + degree_upper_bound = int(round(self._model.ObjVal)) + + self._log_experiment( + "upper bound degree", + { + "output_bit_index": output_bit_index, + "chosen_cipher_output": chosen_cipher_output, + "which_var_degree": which_var_degree, + }, + degree_upper_bound + ) + + return degree_upper_bound + + def find_upper_bound_degree_of_all_output_bits(self, which_var_degree=None, chosen_cipher_output=None): + """ + Compute the upper bound on the algebraic degree for all cipher output bits. + + INPUT: + + - ``which_var_degree`` -- **string** (default: ``None``); prefix indicating which + variable group the degree should be computed over: + * ``"k"`` → key bits + * ``"p"`` → plaintext bits + * ``"i"`` → IV bits + If ``None`` (default), the degree is computed with respect to the first input + listed in ``self._cipher.inputs``. + - ``chosen_cipher_output`` -- **string** (default: ``None``); specify a cipher + component ID if the computation targets an intermediate output (e.g., after a + given round) instead of the final cipher output. + + OUTPUT: + + - **list of integers**; upper bounds on the algebraic degrees of all cipher output bits. + + EXAMPLES:: + + sage: from claasp.ciphers.block_ciphers.simon_block_cipher import SimonBlockCipher + sage: cipher = SimonBlockCipher(number_of_rounds=4) + sage: from claasp.cipher_modules.models.milp.milp_models.Gurobi.monomial_prediction import MilpMonomialPredictionModel + sage: milp = MilpMonomialPredictionModel(cipher) # doctest: +SKIP + sage: milp.find_upper_bound_degree_of_all_output_bits(which_var_degree="p") # doctest: +SKIP + ... + """ + global verbosity + old_verbosity = verbosity + verbosity = False + degrees = [] + for i in range(self._cipher.output_bit_size): + self.re_init() + degree = self.find_upper_bound_degree_of_specific_output_bit( + i, which_var_degree=which_var_degree, chosen_cipher_output=chosen_cipher_output + ) + degrees.append(degree) + verbosity = old_verbosity + + self._log_experiment( + "all output bits upper bound degree", + { + "chosen_cipher_output": chosen_cipher_output, + "which_var_degree": which_var_degree, + }, + degrees + ) + + return degrees + + def find_exact_degree_of_specific_output_bit(self, output_bit_index, which_var_degree=None, + chosen_cipher_output=None): + """ + Compute the exact algebraic degree of a specific cipher output bit + with respect to a chosen input variable group (e.g., key, IV, or plaintext). + + Unlike the upper-bound computation, this method enumerates all optimal MILP + solutions corresponding to maximal-degree monomials and checks their parity + (mod 2). The exact algebraic degree is the highest degree for which the number + of monomials with that degree is odd. + + INPUT: + + - ``output_bit_index`` -- **integer**; index (0-based, counting from the most significant bit) + of the cipher output bit to analyze. + - ``which_var_degree`` -- **string** (default: ``None``); prefix identifying which + input group the algebraic degree should be computed over: + * ``"k"`` → degree with respect to key bits + * ``"p"`` → degree with respect to plaintext bits + * ``"i"`` → degree with respect to IV bits + If ``None`` (default), the first input listed in ``self._cipher.inputs`` is used. + - ``chosen_cipher_output`` -- **string** (default: ``None``); specify a cipher component + ID if the computation targets an intermediate output instead of the final cipher output. + + OUTPUT: + + - **integer**; exact algebraic degree of the selected output bit with respect to + the chosen input variable group. + + EXAMPLES:: + + sage: from claasp.ciphers.stream_ciphers.trivium_stream_cipher import TriviumStreamCipher + sage: cipher = TriviumStreamCipher(keystream_bit_len=1, number_of_initialization_clocks=508) + sage: from claasp.cipher_modules.models.milp.milp_models.Gurobi.monomial_prediction import MilpMonomialPredictionModel + sage: milp = MilpMonomialPredictionModel(cipher) # doctest: +SKIP + sage: milp.find_exact_degree_of_specific_output_bit(0, which_var_degree="i") # doctest: +SKIP + ... + """ + + fixed_degree = None + self.build_generic_model_for_specific_output_bit(output_bit_index, fixed_degree, which_var_degree, + chosen_cipher_output) + + m = self._model + m.Params.OutputFlag = 0 + m.setParam(GRB.Param.PoolSearchMode, 2) # enumerate all optimal solutions + m.setParam(GRB.Param.PoolSolutions, 200000000) # large enough for enumeration + m.setParam(GRB.Param.PoolGap, 0.0) # ensure only optimal solutions are put in the Pool + + if which_var_degree is None: + target_inputs = [(self._cipher.inputs[0], self._cipher.inputs_bit_size[0])] + else: + target_inputs = [ + (inp, size) + for inp, size in zip(self._cipher.inputs, self._cipher.inputs_bit_size) + if inp.startswith(which_var_degree) + ] + + vars_target = [] + for inp, size in target_inputs: + for i in range(size): + var = m.getVarByName(f"{inp}[{i}]") + if var is not None: + vars_target.append(var) + + m.setObjective(sum(vars_target), GRB.MAXIMIZE) + m.update() + m.optimize() + + degree_drop = False + if self._model.Status not in [GRB.OPTIMAL, GRB.SUBOPTIMAL]: + print(f"[INFO] Model is infeasible") if verbosity else None + exact_degree = -1 + else: + d = int(round(m.ObjVal)) + # Gather all distinct monomials of degree d and compute parity + monomial_parity = {} + for s in range(m.SolCount): + m.Params.SolutionNumber = s + active_indices = tuple(i for i, v in enumerate(vars_target) if v.Xn > 0.5) + if len(active_indices) == d: + monomial_parity[active_indices] = monomial_parity.get(active_indices, 0) ^ 1 + + if any(val == 1 for val in monomial_parity.values()): + exact_degree = d + else: + degree_drop = True + exact_degree = d - 1 + + self._log_experiment( + "exact degree", + { + "output_bit_index": output_bit_index, + "chosen_cipher_output": chosen_cipher_output, + "which_var_degree": which_var_degree, + "degree_drop": degree_drop + }, + exact_degree + ) + + return exact_degree + + def find_exact_degree_of_all_output_bits(self, which_var_degree=None, chosen_cipher_output=None): + """ + Compute the exact algebraic degree for all cipher output bits. + + INPUT: + + - ``which_var_degree`` -- **string** (default: ``None``); prefix indicating which + variable group the algebraic degree should be computed over: + * ``"k"`` → key bits + * ``"p"`` → plaintext bits + * ``"i"`` → IV bits + If ``None`` (default), the degree is computed with respect to the first input + listed in ``self._cipher.inputs``. + - ``chosen_cipher_output`` -- **string** (default: ``None``); specify a cipher + component ID if the computation targets an intermediate output instead of the final cipher output. + + OUTPUT: + + - **list of integers**; exact algebraic degrees of all cipher output bits. + + EXAMPLES:: + + sage: from claasp.ciphers.block_ciphers.simon_block_cipher import SimonBlockCipher + sage: cipher = SimonBlockCipher(number_of_rounds=4) + sage: from claasp.cipher_modules.models.milp.milp_models.Gurobi.monomial_prediction import MilpMonomialPredictionModel + sage: milp = MilpMonomialPredictionModel(cipher) # doctest: +SKIP + sage: milp.find_exact_degree_of_all_output_bits(which_var_degree="p") # doctest: +SKIP + ... + """ + global verbosity + old_verbosity = verbosity + verbosity = False + degrees = [] + for i in range(self._cipher.output_bit_size): + self.re_init() + degree = self.find_exact_degree_of_specific_output_bit( + i, which_var_degree=which_var_degree, chosen_cipher_output=chosen_cipher_output + ) + degrees.append(degree) + verbosity = old_verbosity + + self._log_experiment( + "all output bits exact degree", + { + "chosen_cipher_output": chosen_cipher_output, + "which_var_degree": which_var_degree, + }, + degrees + ) + + return degrees + + def find_upper_bound_degree_of_cube_monomial_of_specific_output_bit( + self, + output_bit_index, + cube, + chosen_cipher_output=None, + ): + r""" + Compute an upper bound degree of the given cube monomial relatively to the given cipher output bit. + + INPUT: + + - ``output_bit_index`` -- **integer** + Index (0-based, counting from the most significant bit). + + - ``cube`` -- **list of strings** + List of cube variable names (e.g. ``["p1", "p3", "p8"]``) representing the cube variables fixed to 1. + + - ``chosen_cipher_output`` -- **string** (default: ``None``) + Optional component ID if the computation targets an intermediate output + instead of the final cipher output. + + OUTPUT: + + - **integer** + Upper bound degree of the given cube monomial. Maximum value is the number of variables involved in the cube. + Returns ``-1`` if the model is infeasible. + + EXAMPLES:: + + sage: from claasp.ciphers.block_ciphers.simon_block_cipher import SimonBlockCipher + sage: cipher = SimonBlockCipher(number_of_rounds=13) + sage: from claasp.cipher_modules.models.milp.milp_models.Gurobi.monomial_prediction import MilpMonomialPredictionModel + sage: milp = MilpMonomialPredictionModel(cipher) # doctest: +SKIP + sage: cube = [f"p{i}" for i in range(1, 32)] + sage: d = milp.find_upper_bound_degree_of_cube_monomial_of_specific_output_bit(16, cube) # doctest: +SKIP + ... + """ + self.build_generic_model_for_specific_output_bit( + output_bit_index, fixed_degree=None, which_var_degree=None, chosen_cipher_output=chosen_cipher_output + ) + m = self._model + m.Params.OutputFlag = 0 + m.setParam(GRB.Param.PoolSearchMode, 0) + m.setParam("MIPGap", 0) + + # Fix cube bits to 1 + cube_verbose = self.var_list_to_input_positions(cube) + for inp_name, idx in cube_verbose: + v = m.getVarByName(f"{inp_name}[{idx}]") + m.addConstr(v == 1) + + # Objective: maximize number of cube variables influencing the degree + cube_vars = [m.getVarByName(f"{inp_name}[{idx}]") for (inp_name, idx) in cube_verbose] + m.setObjective(sum(cube_vars), GRB.MAXIMIZE) + m.update() + m.optimize() + + if m.Status not in [GRB.OPTIMAL, GRB.SUBOPTIMAL]: + if verbosity: + print(f"[INFO] Model infeasible for output bit {output_bit_index}") + return -1 + + degree_upper_bound = int(round(m.ObjVal)) + + self._log_experiment( + "upper bound degree of cube monomial", + { + "output_bit_index": output_bit_index, + "chosen_cipher_output": chosen_cipher_output, + "cube": cube + }, + degree_upper_bound + ) + + return degree_upper_bound + + + def find_keycoeff_of_cube_monomial_of_specific_output_bit( + self, + output_bit_index, + cube, + chosen_cipher_output=None, + ): + r""" + Compute the coefficient of the given cube monomial over key varables only, of a given cipher output bit. + + INPUT: + + - ``output_bit_index`` -- **integer** + Index (0-based, counting from the most significant bit) + + - ``cube`` -- **list of strings** + List of cube variable names (e.g. ``["p1", "p3", "p8"]``) representing the cube variables fixed to 1. + + - ``chosen_cipher_output`` -- **string** (default: ``None``) + Optional component ID if the computation targets an intermediate output + instead of the final cipher output. + + OUTPUT: + + - **Sage BooleanPolynomial** + Boolean polynomial over key variables corresponding to the coefficient + of the given cube. + Returns ``0`` if the model is infeasible or no valid solutions are found. + + EXAMPLES:: + + sage: from claasp.ciphers.stream_ciphers.trivium_stream_cipher import TriviumStreamCipher + sage: cipher = TriviumStreamCipher(keystream_bit_len=1, number_of_initialization_clocks=200) + sage: from claasp.cipher_modules.models.milp.milp_models.Gurobi.monomial_prediction import MilpMonomialPredictionModel + sage: milp = MilpMonomialPredictionModel(cipher) # doctest: +SKIP + sage: cube = ["i53"] + sage: coeff = milp.find_keycoeff_of_cube_monomial_of_specific_output_bit(0, cube) # doctest: +SKIP + ... + + sage: from claasp.ciphers.block_ciphers.simon_block_cipher import SimonBlockCipher + sage: cipher = SimonBlockCipher(number_of_rounds=13) + sage: from claasp.cipher_modules.models.milp.milp_models.Gurobi.monomial_prediction import MilpMonomialPredictionModel + sage: milp = MilpMonomialPredictionModel(cipher) # doctest: +SKIP + sage: cube = [f"p{i}" for i in range(1, 32)] + sage: coeff = milp.find_keycoeff_of_cube_monomial_of_specific_output_bit(15, cube) # doctest: +SKIP + ... + """ + self.build_generic_model_for_specific_output_bit( + output_bit_index, fixed_degree=None, which_var_degree=None, chosen_cipher_output=chosen_cipher_output + ) + m = self._model + m.Params.OutputFlag = 0 + m.setParam(GRB.Param.PoolSearchMode, 2) + m.setParam(GRB.Param.PoolSolutions, 200000000) + m.setParam(GRB.Param.PoolGap, 0.0) + + cube_verbose = self.var_list_to_input_positions(cube) + cube_set = {(a, b) for (a, b) in cube_verbose} + + # Fix cube bits to 1 + for (inp_name, idx) in cube_verbose: + v = m.getVarByName(f"{inp_name}[{idx}]") + m.addConstr(v == 1) + + cube_vars = [m.getVarByName(f"{a}[{b}]") for (a, b) in cube_verbose] + m.addConstr(sum(cube_vars) == len(cube)) + + # Fix all other non-key input bits to 0 + for (inp, sz) in zip(self._cipher.inputs, self._cipher.inputs_bit_size): + pref = inp[0] + if pref in {"p", "i"}: + for i in range(sz): + if (inp, i) in cube_set: + continue + v = m.getVarByName(f"{inp}[{i}]") + if v is not None: + m.addConstr(v == 0) + + m.setObjective(0.0, GRB.MAXIMIZE) + m.update() + m.optimize() + + if m.Status not in [GRB.OPTIMAL, GRB.SUBOPTIMAL] or m.SolCount == 0: + if verbosity: + print(f"[INFO] Model infeasible or no valid solutions for output bit {output_bit_index}") + return self.get_boolean_polynomial_ring()(0) + + poly_full = self.get_solutions() + + # Substitute cube bits to 1 + subs_map = {f"{inp_name[0]}{idx}": 1 for (inp_name, idx) in cube_verbose} + key_coef_poly = poly_full.subs(subs_map) + + self._log_experiment( + "key coefficient of cube monomial", + { + "output_bit_index": output_bit_index, + "chosen_cipher_output": chosen_cipher_output, + "cube": cube + }, + key_coef_poly + ) + + + return key_coef_poly + + def _log_experiment( + self, + experiment_name: str, + details: dict, + content: str, + ): + """ + Internal helper to log experiment results to a timestamped text file. + Used whenever verbosity is enabled to avoid code duplication. + + Args: + experiment_name (str): Short label for the experiment (e.g., "ANF", "superpoly"). + details (dict): Key-value pairs to include in the header (e.g., output_bit_index, cube, etc.). + content (str): The main content to log (e.g., a polynomial, integer, or list). + folder (str): Target folder name for logs (default: "monomial_prediction_experiments"). + """ + if not verbosity: + return + + folder = "monomial_prediction_experiments" + os.makedirs(folder, exist_ok=True) + filename = os.path.join(folder, f"{self._cipher._id}.txt") + try: + with open(filename, "a", encoding="utf-8") as f: + f.write("\n" + "=" * 80 + "\n") + f.write(f"Experiment: {experiment_name}\n") + for k, v in details.items(): + f.write(f"{k}: {v}\n") + f.write(f"Timestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}\n") + f.write("-" * 10 + "\n") + f.write(str(content)) + f.write("\n\n") + print(f"[INFO] {experiment_name} successfully saved to '{filename}'") + except Exception as e: + print(f"[WARNING] Failed to save {experiment_name} to file: {e}") + + + +################################ +######## END OF CLASS ########## +################################ + + +def _valuation_from_assign(cipher, full_assign_bits, allowed_prefixes=None): + val = {} + for name, size in zip(cipher.inputs, cipher.inputs_bit_size): + pref = name[0] + if allowed_prefixes is not None and pref not in allowed_prefixes: + continue + w = full_assign_bits[name] + for i in range(size): # i is MSB index + bit = (w >> (size - 1 - i)) & 1 + val[f'{pref}{i}'] = bit + return val + + +def _parse_cube_positions(cipher, cube_tokens): + pref_map = {name[0]: (name, size) + for name, size in zip(cipher.inputs, cipher.inputs_bit_size)} + out = [] + for tok in cube_tokens: + pref, msb_pos = tok[0], int(tok[1:]) + name, size = pref_map[pref] + if msb_pos < 0 or msb_pos >= size: + raise ValueError(f"{tok} out of range for input {name} (size {size})") + out.append((name, msb_pos)) + return out + + +def _eval_boolean_poly(poly, valuation): + vars_in_poly = [str(v) for v in poly.variables()] + if not vars_in_poly: + return int(GF(2)(poly)) + vals = {name: valuation.get(name, 0) for name in vars_in_poly} + return int(GF(2)(poly(**vals))) + + +def check_correctness_of_keycoeff_of_cube_monomial_or_superpoly(cipher, output_bit_index, cube, poly, + public_assign_bits=None, trials=16): + """ + Check the correctness of a computed cube monomial coefficient or superpoly + for a specific cipher output bit, by evaluating the cipher multiple times + with random key and public variable assignments. + + The method performs the full cube summation over the selected cube bits and + compares the resulting bit parity with the evaluation of the provided Boolean + polynomial (`poly`) under the same key assignment. A match across all trials + confirms the correctness of the derived superpoly or key coefficient. + + INPUT: + + - ``cipher`` -- **Cipher object** + The CLAASP cipher instance implementing the `evaluate()` method. + + - ``output_bit_index`` -- **integer** + Index (0-based, counting from the most significant bit). + + - ``cube`` -- **list of strings** + List of cube variable names (e.g. ``["p1", "p3"]`` or ``["i9", "i19", ...]``) + that define the cube monomial being analyzed. + + - ``poly`` -- **Sage BooleanPolynomial** + The candidate Boolean polynomial representing either the key coefficient + or the superpoly predicted by the MILP model. + + - ``public_assign_bits`` -- **dict** (default: ``None``) + Optional mapping specifying fixed assignments for public variables (e.g. plaintext or IV). + If omitted, all non-cube public variables default to zero. + Example: ``{"plaintext": 0xfda120472589641}`` or ``{"initialization_vector": (1 << 80) - 1}``. + + - ``trials`` -- **integer** (default: ``16``) + Number of random key assignments to test. Each trial independently verifies the + cube summation equivalence under new random key values. + + OUTPUT: + + - **boolean** + ``True`` if the cube-summation result matches the polynomial evaluation for all trials, + otherwise ``False``. + + Example:: + sage: from claasp.ciphers.stream_ciphers.trivium_stream_cipher import TriviumStreamCipher + sage: cipher = TriviumStreamCipher(keystream_bit_len=1, number_of_initialization_clocks= 200) + sage: from claasp.cipher_modules.models.milp.milp_models.Gurobi.monomial_prediction import * + sage: milp = MilpMonomialPredictionModel(cipher) # doctest: +SKIP + sage: cube = ["i53"] + sage: coef_poly = milp.find_keycoeff_of_cube_monomial_of_specific_output_bit(0, cube) # doctest: +SKIP + sage: check_correctness_of_keycoeff_of_cube_monomial_or_superpoly(cipher, 0, cube, coef_poly) # doctest: +SKIP + ... + + sage: from claasp.ciphers.stream_ciphers.trivium_stream_cipher import TriviumStreamCipher + sage: cipher = TriviumStreamCipher(keystream_bit_len=1, number_of_initialization_clocks= 590) + sage: from claasp.cipher_modules.models.milp.milp_models.Gurobi.monomial_prediction import * + sage: milp = MilpMonomialPredictionModel(cipher) # doctest: +SKIP + sage: cube = ['i9', 'i19', 'i29', 'i39', 'i49', 'i59', 'i69', 'i79'] + sage: superpoly = milp.find_superpoly_of_specific_output_bit(0, cube) # doctest: +SKIP + sage: check_correctness_of_keycoeff_of_cube_monomial_or_superpoly(cipher, 0, cube, superpoly) # doctest: +SKIP + ... + + # by defult non-cube public variables assign to zero but that can be assigned + # to any arbitrary constant values, for example see below + # From the following dictionary 'pub' all non-cube public vars will be set to constant 1. + + sage: pub = {"initialization_vector": (1 << 80) - 1} # Every non cube vars set to 1. + sage: check_correctness_of_keycoeff_of_cube_monomial_or_superpoly(cipher, 0, cube, superpoly, public_assign_bits= pub) # doctest: +SKIP + ... + + # A short example + sage: from claasp.ciphers.block_ciphers.present_block_cipher import PresentBlockCipher + sage: cipher = PresentBlockCipher(number_of_rounds=1) + sage: from claasp.cipher_modules.models.milp.milp_models.Gurobi.monomial_prediction import * + sage: milp = MilpMonomialPredictionModel(cipher) # doctest: +SKIP + sage: cube = ['p2', 'p3'] + sage: superpoly = milp.find_superpoly_of_specific_output_bit(0, cube) # doctest: +SKIP + sage: check_correctness_of_keycoeff_of_cube_monomial_or_superpoly(cipher, 0, cube, superpoly) # doctest: +SKIP + ... + sage: pub = {"plaintext": 0xfda120472589641} # Set to 1 or 0 the plaintext vars according to the given pattern. + sage: check_correctness_of_keycoeff_of_cube_monomial_or_superpoly(cipher, 0, cube, superpoly, public_assign_bits= pub) # doctest: +SKIP + ... + + """ + cube_pos = _parse_cube_positions(cipher, cube) + m = len(cube_pos) + size_map = dict(zip(cipher.inputs, cipher.inputs_bit_size)) + needed_prefixes = {str(v)[0] for v in poly.variables()} + + for _ in range(trials): + # random keys, fixed/zero publics + assign = {} + for name, size in zip(cipher.inputs, cipher.inputs_bit_size): + if name.startswith('k'): + assign[name] = secrets.randbits(size) + else: + if needed_prefixes <= {'k'}: + assign[name] = 0 + else: + if public_assign_bits and name in public_assign_bits: + assign[name] = int(public_assign_bits[name]) + else: + assign[name] = 0 + + acc = 0 + for a in range(1 << m): + cur = dict(assign) + for j, (inp_name, msb_pos) in enumerate(cube_pos): + size = size_map[inp_name] + lsb_idx = size - 1 - msb_pos # MSB is first + mask = 1 << lsb_idx + if (a >> j) & 1: + cur[inp_name] |= mask + else: + cur[inp_name] &= ~mask + + output = cipher.evaluate([cur[name] for name in cipher.inputs]) + out_lsb_idx = cipher.output_bit_size - 1 - output_bit_index + acc ^= (output >> out_lsb_idx) & 1 + + vals = _valuation_from_assign(cipher, assign, allowed_prefixes=needed_prefixes) + rhs = _eval_boolean_poly(poly, vals) + + if acc != rhs: + return False + return True diff --git a/claasp/cipher_modules/models/milp/milp_models/milp_bitwise_deterministic_truncated_xor_differential_model.py b/claasp/cipher_modules/models/milp/milp_models/milp_bitwise_deterministic_truncated_xor_differential_model.py index b423268e7..ad6f9f517 100644 --- a/claasp/cipher_modules/models/milp/milp_models/milp_bitwise_deterministic_truncated_xor_differential_model.py +++ b/claasp/cipher_modules/models/milp/milp_models/milp_bitwise_deterministic_truncated_xor_differential_model.py @@ -1,34 +1,45 @@ # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** import time -from claasp.cipher_modules.models.milp.solvers import SOLVER_DEFAULT -from claasp.cipher_modules.models.milp.utils.milp_name_mappings import MILP_BITWISE_DETERMINISTIC_TRUNCATED, \ - MILP_BACKWARD_SUFFIX, MILP_BUILDING_MESSAGE, MILP_TRUNCATED_XOR_DIFFERENTIAL_OBJECTIVE -from claasp.cipher_modules.models.milp.utils.milp_truncated_utils import \ - fix_variables_value_deterministic_truncated_xor_differential_constraints + from claasp.cipher_modules.models.milp.milp_model import MilpModel +from claasp.cipher_modules.models.milp.solvers import SOLVER_DEFAULT +from claasp.cipher_modules.models.milp.utils.milp_name_mappings import ( + MILP_BITWISE_DETERMINISTIC_TRUNCATED, + MILP_BUILDING_MESSAGE, + MILP_TRUNCATED_XOR_DIFFERENTIAL_OBJECTIVE, +) +from claasp.cipher_modules.models.milp.utils.milp_truncated_utils import ( + fix_variables_value_deterministic_truncated_xor_differential_constraints, +) from claasp.cipher_modules.models.utils import set_component_solution -from claasp.name_mappings import (CONSTANT, INTERMEDIATE_OUTPUT, CIPHER_OUTPUT, - WORD_OPERATION, LINEAR_LAYER, SBOX, MIX_COLUMN) +from claasp.name_mappings import ( + CIPHER_OUTPUT, + CONSTANT, + INTERMEDIATE_OUTPUT, + LINEAR_LAYER, + MIX_COLUMN, + SBOX, + WORD_OPERATION, +) class MilpBitwiseDeterministicTruncatedXorDifferentialModel(MilpModel): - def __init__(self, cipher, n_window_heuristic=None, verbose=False): super().__init__(cipher, n_window_heuristic, verbose) self._trunc_binvar = None @@ -86,7 +97,6 @@ def add_constraints_to_build_in_sage_milp_class(self, fixed_variables=[]): x = self._binary_variable p = self._integer_variable - components = self._cipher.get_all_components() last_component = components[-1] @@ -98,14 +108,19 @@ def add_constraints_to_build_in_sage_milp_class(self, fixed_variables=[]): input_id_tuples, output_id_tuples = last_component._get_input_output_variables_tuples() input_ids, output_ids = last_component._get_input_output_variables() - linking_constraints = self.link_binary_tuples_to_integer_variables(input_id_tuples + output_id_tuples, - input_ids + output_ids) + linking_constraints = self.link_binary_tuples_to_integer_variables( + input_id_tuples + output_id_tuples, input_ids + output_ids + ) for constraint in linking_constraints: mip.add_constraint(constraint) - mip.add_constraint(p[MILP_TRUNCATED_XOR_DIFFERENTIAL_OBJECTIVE] == sum(x[output_msb] for output_msb in [id[0] for id in output_id_tuples])) - - - def build_bitwise_deterministic_truncated_xor_differential_trail_model(self, fixed_variables=[], component_list=None): + mip.add_constraint( + p[MILP_TRUNCATED_XOR_DIFFERENTIAL_OBJECTIVE] + == sum(x[output_msb] for output_msb in [id[0] for id in output_id_tuples]) + ) + + def build_bitwise_deterministic_truncated_xor_differential_trail_model( + self, fixed_variables=[], component_list=None + ): """ Build the model for the search of bitwise deterministic truncated XOR differential trails. @@ -132,26 +147,42 @@ def build_bitwise_deterministic_truncated_xor_differential_trail_model(self, fix self._variables_list = [] variables = [] constraints = self.fix_variables_value_bitwise_deterministic_truncated_xor_differential_constraints( - fixed_variables) + fixed_variables + ) self._model_constraints = constraints component_list = component_list or self._cipher.get_all_components() for component in component_list: - component_types = [CONSTANT, INTERMEDIATE_OUTPUT, CIPHER_OUTPUT, LINEAR_LAYER, SBOX, MIX_COLUMN, - WORD_OPERATION] + component_types = ( + CIPHER_OUTPUT, + CONSTANT, + INTERMEDIATE_OUTPUT, + LINEAR_LAYER, + MIX_COLUMN, + SBOX, + WORD_OPERATION, + ) operation = component.description[0] - operation_types = ['AND', 'MODADD', 'MODSUB', 'NOT', 'OR', 'ROTATE', 'SHIFT', 'XOR'] + operation_types = ("AND", "MODADD", "MODSUB", "NOT", "OR", "ROTATE", "SHIFT", "XOR") if component.type in component_types and (component.type != WORD_OPERATION or operation in operation_types): - if operation in ['XOR','MODADD'] or component.type == LINEAR_LAYER: - variables, constraints = component.milp_bitwise_deterministic_truncated_xor_differential_binary_constraints(self) + if operation in ("XOR", "MODADD") or component.type == LINEAR_LAYER: + variables, constraints = ( + component.milp_bitwise_deterministic_truncated_xor_differential_binary_constraints(self) + ) elif component.type == SBOX: - variables, constraints = component.milp_undisturbed_bits_bitwise_deterministic_truncated_xor_differential_constraints(self) + variables, constraints = ( + component.milp_undisturbed_bits_bitwise_deterministic_truncated_xor_differential_constraints( + self + ) + ) else: - variables, constraints = component.milp_bitwise_deterministic_truncated_xor_differential_constraints(self) + variables, constraints = ( + component.milp_bitwise_deterministic_truncated_xor_differential_constraints(self) + ) else: - print(f'{component.id} not yet implemented') + print(f"{component.id} not yet implemented") self._variables_list.extend(variables) self._model_constraints.extend(constraints) @@ -203,7 +234,9 @@ def fix_variables_value_bitwise_deterministic_truncated_xor_differential_constra """ - return fix_variables_value_deterministic_truncated_xor_differential_constraints(self, self.trunc_binvar, fixed_variables) + return fix_variables_value_deterministic_truncated_xor_differential_constraints( + self, self.trunc_binvar, fixed_variables + ) def link_binary_tuples_to_integer_variables(self, id_tuples, ids): """ @@ -243,7 +276,6 @@ def link_binary_tuples_to_integer_variables(self, id_tuples, ids): """ - x = self.binary_variable x_class = self.trunc_binvar @@ -252,12 +284,13 @@ def link_binary_tuples_to_integer_variables(self, id_tuples, ids): variables = [x_class[i] for i in ids] for index, var in enumerate(variables): - constraints.append( - var == sum([2 ** i * var_bit for i, var_bit in enumerate(variables_tuples[index][::-1])])) + constraints.append(var == sum([2**i * var_bit for i, var_bit in enumerate(variables_tuples[index][::-1])])) return constraints - def find_one_bitwise_deterministic_truncated_xor_differential_trail(self, fixed_values=[], solver_name=SOLVER_DEFAULT, external_solver_name=None): + def find_one_bitwise_deterministic_truncated_xor_differential_trail( + self, fixed_values=[], solver_name=SOLVER_DEFAULT, external_solver_name=None + ): """ Returns one deterministic truncated XOR differential trail. @@ -312,11 +345,13 @@ def find_one_bitwise_deterministic_truncated_xor_differential_trail(self, fixed_ end = time.time() building_time = end - start solution = self.solve(MILP_BITWISE_DETERMINISTIC_TRUNCATED, solver_name, external_solver_name) - solution['building_time'] = building_time + solution["building_time"] = building_time return solution - def find_lowest_varied_patterns_bitwise_deterministic_truncated_xor_differential_trail(self, fixed_values=[], solver_name=SOLVER_DEFAULT, external_solver_name=None): + def find_lowest_varied_patterns_bitwise_deterministic_truncated_xor_differential_trail( + self, fixed_values=[], solver_name=SOLVER_DEFAULT, external_solver_name=None + ): """ Return the solution representing a differential trail with the lowest number of unknown variables. @@ -353,7 +388,7 @@ def find_lowest_varied_patterns_bitwise_deterministic_truncated_xor_differential end = time.time() building_time = end - start solution = self.solve(MILP_BITWISE_DETERMINISTIC_TRUNCATED, solver_name, external_solver_name) - solution['building_time'] = building_time + solution["building_time"] = building_time return solution @@ -368,6 +403,7 @@ def _get_component_values(self, objective_variables, components_variables): dict_tmp = self._get_component_value_weight(component_id, components_variables) components_values[component_id] = dict_tmp return components_values + def _parse_solver_output(self): mip = self._model components_variables = mip.get_values(self._trunc_binvar) @@ -378,7 +414,6 @@ def _parse_solver_output(self): return objective_value, components_values def _get_component_value_weight(self, component_id, components_variables): - if component_id in self._cipher.inputs: output_size = self._cipher.inputs_bit_size[self._cipher.inputs.index(component_id)] else: @@ -407,5 +442,3 @@ def _get_final_output(self, component_id, components_variables, suffix_dict): final_output.append(set_component_solution(diff_str)) return final_output - - diff --git a/claasp/cipher_modules/models/milp/milp_models/milp_bitwise_impossible_xor_differential_model.py b/claasp/cipher_modules/models/milp/milp_models/milp_bitwise_impossible_xor_differential_model.py index 806c72759..7b7935f50 100644 --- a/claasp/cipher_modules/models/milp/milp_models/milp_bitwise_impossible_xor_differential_model.py +++ b/claasp/cipher_modules/models/milp/milp_models/milp_bitwise_impossible_xor_differential_model.py @@ -1,16 +1,16 @@ # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -18,22 +18,27 @@ import time from claasp.cipher_modules.inverse_cipher import get_key_schedule_component_ids +from claasp.cipher_modules.models.milp.milp_models.milp_bitwise_deterministic_truncated_xor_differential_model import ( + MilpBitwiseDeterministicTruncatedXorDifferentialModel, +) from claasp.cipher_modules.models.milp.solvers import SOLVER_DEFAULT -from claasp.cipher_modules.models.milp.milp_models.milp_bitwise_deterministic_truncated_xor_differential_model import \ - MilpBitwiseDeterministicTruncatedXorDifferentialModel -from claasp.cipher_modules.models.milp.utils.milp_name_mappings import MILP_BITWISE_IMPOSSIBLE, \ - MILP_BITWISE_IMPOSSIBLE_AUTO, MILP_BACKWARD_SUFFIX, MILP_BUILDING_MESSAGE -from claasp.name_mappings import CIPHER_OUTPUT, INPUT_KEY from claasp.cipher_modules.models.milp.utils import utils as milp_utils, milp_truncated_utils +from claasp.cipher_modules.models.milp.utils.milp_name_mappings import ( + MILP_BACKWARD_SUFFIX, + MILP_BITWISE_IMPOSSIBLE_AUTO, + MILP_BITWISE_IMPOSSIBLE, + MILP_BUILDING_MESSAGE, +) +from claasp.name_mappings import CIPHER_OUTPUT, INPUT_KEY class MilpBitwiseImpossibleXorDifferentialModel(MilpBitwiseDeterministicTruncatedXorDifferentialModel): - def __init__(self, cipher, n_window_heuristic=None, verbose=False): super().__init__(cipher, n_window_heuristic, verbose) self._forward_cipher = None self._backward_cipher = None self._incompatible_components = None + def build_bitwise_impossible_xor_differential_trail_model(self, fixed_variables=[]): """ Build the model for the search of bitwise impossible XOR differential trails. @@ -100,13 +105,15 @@ def add_constraints_to_build_in_sage_milp_class(self, middle_round=None, fixed_v assert middle_round < self._cipher.number_of_rounds self._forward_cipher = self._cipher.get_partial_cipher(0, middle_round - 1, keep_key_schedule=True) - backward_cipher = self._cipher.cipher_partial_inverse(middle_round, self._cipher.number_of_rounds - 1, - keep_key_schedule=False) - self._backward_cipher = backward_cipher.add_suffix_to_components(MILP_BACKWARD_SUFFIX, - [backward_cipher.get_all_components_ids()[-1]]) + backward_cipher = self._cipher.cipher_partial_inverse( + middle_round, self._cipher.number_of_rounds - 1, keep_key_schedule=False + ) + self._backward_cipher = backward_cipher.add_suffix_to_components( + MILP_BACKWARD_SUFFIX, [backward_cipher.get_all_components_ids()[-1]] + ) self.build_bitwise_impossible_xor_differential_trail_model(fixed_variables) - for index, constraint in enumerate(self._model_constraints): + for constraint in self._model_constraints: mip.add_constraint(constraint) # finding incompatibility @@ -116,25 +123,31 @@ def add_constraints_to_build_in_sage_milp_class(self, middle_round=None, fixed_v _, output_ids = forward_output._get_input_output_variables() forward_vars = [x_class[id] for id in output_ids] - backward_vars = [x_class["_".join(id.split("_")[:-1] + ["backward"] + [id.split("_")[-1]])] for id in - output_ids] + backward_vars = [ + x_class["_".join(id.split("_")[:-1] + ["backward"] + [id.split("_")[-1]])] for id in output_ids + ] inconsistent_vars = [x[f"{forward_output.id}_inconsistent_{_}"] for _ in range(output_bit_size)] constraints.extend([sum(inconsistent_vars) == 1]) for inconsistent_index in range(output_bit_size): incompatibility_constraint = [forward_vars[inconsistent_index] + backward_vars[inconsistent_index] == 1] constraints.extend( - milp_utils.milp_if_then(inconsistent_vars[inconsistent_index], incompatibility_constraint, - self._model.get_max(x_class) * 2)) + milp_utils.milp_if_then( + inconsistent_vars[inconsistent_index], incompatibility_constraint, self._model.get_max(x_class) * 2 + ) + ) for constraint in constraints: mip.add_constraint(constraint) _, forward_output_id_tuples = forward_output._get_input_output_variables_tuples() - mip.add_constraint(p["number_of_unknown_patterns"] == sum( - x[output_msb] for output_msb in [id[0] for id in forward_output_id_tuples])) - - def add_constraints_to_build_in_sage_milp_class_with_chosen_incompatible_components(self, component_id_list=None, - fixed_variables=[]): + mip.add_constraint( + p["number_of_unknown_patterns"] + == sum(x[output_msb] for output_msb in [id[0] for id in forward_output_id_tuples]) + ) + + def add_constraints_to_build_in_sage_milp_class_with_chosen_incompatible_components( + self, component_id_list=None, fixed_variables=[] + ): """ Take the constraints contained in self._model_constraints and add them to the build-in sage class. @@ -168,7 +181,9 @@ def add_constraints_to_build_in_sage_milp_class_with_chosen_incompatible_compone if component_id_list == None: return self.add_constraints_to_build_in_sage_milp_class(fixed_variables=fixed_variables) - assert set(component_id_list) <= set(self._cipher.get_all_components_ids()) - set(get_key_schedule_component_ids(self._cipher)) + assert set(component_id_list) <= set(self._cipher.get_all_components_ids()) - set( + get_key_schedule_component_ids(self._cipher) + ) middle_round_numbers = [self._cipher.get_round_from_component_id(id) for id in component_id_list] @@ -176,24 +191,40 @@ def add_constraints_to_build_in_sage_milp_class_with_chosen_incompatible_compone middle_round_number = middle_round_numbers[0] - if len(component_id_list) == 1 and self._cipher.get_component_from_id(component_id_list[0]).description == ['round_output']: + if len(component_id_list) == 1 and self._cipher.get_component_from_id(component_id_list[0]).description == [ + "round_output" + ]: return self.add_constraints_to_build_in_sage_milp_class(middle_round_number + 1, fixed_variables) self._forward_cipher = self._cipher.get_partial_cipher(0, middle_round_number, keep_key_schedule=True) - backward_cipher = self._cipher.cipher_partial_inverse(middle_round_number, self._cipher.number_of_rounds - 1, - keep_key_schedule=False) + backward_cipher = self._cipher.cipher_partial_inverse( + middle_round_number, self._cipher.number_of_rounds - 1, keep_key_schedule=False + ) self._incompatible_components = component_id_list - backward_last_round_components = set(backward_cipher._rounds.round_at(self._cipher.number_of_rounds - 1 - middle_round_number).get_components_ids() + [backward_cipher.get_all_components_ids()[-1]]) - input_id_links_of_chosen_components = [_ for c in [backward_cipher.get_component_from_id(id) for id in component_id_list] for _ in c.input_id_links] - round_input_id_links_of_chosen_components = [backward_cipher.get_round_from_component_id(id) for id in input_id_links_of_chosen_components] - links_round = [_ for r in round_input_id_links_of_chosen_components for _ in backward_cipher._rounds.round_at(r).get_components_ids()] - self._backward_cipher = backward_cipher.add_suffix_to_components(MILP_BACKWARD_SUFFIX, backward_last_round_components | set(links_round)) + backward_last_round_components = set( + backward_cipher._rounds.round_at( + self._cipher.number_of_rounds - 1 - middle_round_number + ).get_components_ids() + + [backward_cipher.get_all_components_ids()[-1]] + ) + input_id_links_of_chosen_components = [ + _ for c in [backward_cipher.get_component_from_id(id) for id in component_id_list] for _ in c.input_id_links + ] + round_input_id_links_of_chosen_components = [ + backward_cipher.get_round_from_component_id(id) for id in input_id_links_of_chosen_components + ] + links_round = [ + _ + for r in round_input_id_links_of_chosen_components + for _ in backward_cipher._rounds.round_at(r).get_components_ids() + ] + self._backward_cipher = backward_cipher.add_suffix_to_components( + MILP_BACKWARD_SUFFIX, backward_last_round_components | set(links_round) + ) self.build_bitwise_impossible_xor_differential_trail_model(fixed_variables) - - # finding incompatibility incompatibility_constraints = [] @@ -209,25 +240,39 @@ def add_constraints_to_build_in_sage_milp_class_with_chosen_incompatible_compone inconsistent_vars = [x[f"{forward_component.id}_inconsistent_{_}"] for _ in range(output_bit_size)] # for multiple input components such as the XOR, ensures compatibility occurs on the correct branch - for index, input_id in enumerate(["_".join(i.split("_")[:-1]) if MILP_BACKWARD_SUFFIX in i else i for i in backward_component.input_id_links]): + for index, input_id in enumerate( + [ + "_".join(i.split("_")[:-1]) if MILP_BACKWARD_SUFFIX in i else i + for i in backward_component.input_id_links + ] + ): if INPUT_KEY not in input_id and self._cipher.get_component_from_id(input_id).input_id_links == [id]: - backward_vars = [x_class[f'{input_id}_{pos}'] for pos in backward_component.input_bit_positions[index]] + backward_vars = [ + x_class[f"{input_id}_{pos}"] for pos in backward_component.input_bit_positions[index] + ] incompatibility_constraints.extend([sum(inconsistent_vars) == 1]) for inconsistent_index in range(output_bit_size): incompatibility_constraint = [forward_vars[inconsistent_index] + backward_vars[inconsistent_index] == 1] incompatibility_constraints.extend( - milp_utils.milp_if_then(inconsistent_vars[inconsistent_index], incompatibility_constraint, - self._model.get_max(x_class) * 2)) + milp_utils.milp_if_then( + inconsistent_vars[inconsistent_index], + incompatibility_constraint, + self._model.get_max(x_class) * 2, + ) + ) _, forward_output_id_tuples = forward_component._get_input_output_variables_tuples() - optimization_constraint = [p["number_of_unknown_patterns"] == sum( - x[output_msb] for output_msb in [id[0] for id in forward_output_id_tuples])] + optimization_constraint = [ + p["number_of_unknown_patterns"] + == sum(x[output_msb] for output_msb in [id[0] for id in forward_output_id_tuples]) + ] for constraint in self._model_constraints + incompatibility_constraints + optimization_constraint: mip.add_constraint(constraint) - def add_constraints_to_build_fully_automatic_model_in_sage_milp_class(self, fixed_variables=[], include_all_components=False): - + def add_constraints_to_build_fully_automatic_model_in_sage_milp_class( + self, fixed_variables=[], include_all_components=False + ): """ Take the constraints contained in self._model_constraints and add them to the build-in sage class. @@ -264,22 +309,27 @@ def add_constraints_to_build_fully_automatic_model_in_sage_milp_class(self, fixe self._backward_cipher = self._cipher.cipher_inverse().add_suffix_to_components(MILP_BACKWARD_SUFFIX) self.build_bitwise_impossible_xor_differential_trail_model(fixed_variables) - for index, constraint in enumerate(self._model_constraints): + for constraint in self._model_constraints: mip.add_constraint(constraint) # finding incompatibility - constraints = milp_truncated_utils.generate_all_incompatibility_constraints_for_fully_automatic_model(self, MILP_BITWISE_IMPOSSIBLE_AUTO, x, x_class, include_all_components) + constraints = milp_truncated_utils.generate_all_incompatibility_constraints_for_fully_automatic_model( + self, MILP_BITWISE_IMPOSSIBLE_AUTO, x, x_class, include_all_components + ) for constraint in constraints: mip.add_constraint(constraint) forward_output = [c for c in self._forward_cipher.get_all_components() if c.type == CIPHER_OUTPUT][0] _, forward_output_id_tuples = forward_output._get_input_output_variables_tuples() - mip.add_constraint(p["number_of_unknown_patterns"] == sum( - x[output_msb] for output_msb in [id[0] for id in forward_output_id_tuples])) - - def find_one_bitwise_impossible_xor_differential_trail(self, middle_round, fixed_values=[], - solver_name=SOLVER_DEFAULT, external_solver_name=None): + mip.add_constraint( + p["number_of_unknown_patterns"] + == sum(x[output_msb] for output_msb in [id[0] for id in forward_output_id_tuples]) + ) + + def find_one_bitwise_impossible_xor_differential_trail( + self, middle_round, fixed_values=[], solver_name=SOLVER_DEFAULT, external_solver_name=None + ): """ Returns one bitwise impossible XOR differential trail. @@ -325,12 +375,13 @@ def find_one_bitwise_impossible_xor_differential_trail(self, middle_round, fixed end = time.time() building_time = end - start solution = self.solve(MILP_BITWISE_IMPOSSIBLE, solver_name, external_solver_name) - solution['building_time'] = building_time + solution["building_time"] = building_time return solution - def find_one_bitwise_impossible_xor_differential_trail_with_chosen_incompatible_components(self, component_id_list, fixed_values=[], - solver_name=SOLVER_DEFAULT, external_solver_name=None): + def find_one_bitwise_impossible_xor_differential_trail_with_chosen_incompatible_components( + self, component_id_list, fixed_values=[], solver_name=SOLVER_DEFAULT, external_solver_name=None + ): """ Returns one bitwise impossible XOR differential trail. @@ -375,16 +426,19 @@ def find_one_bitwise_impossible_xor_differential_trail_with_chosen_incompatible_ self._verbose_print(f"Solver used : {solver_name} (Choose Gurobi for Better performance)") mip = self._model mip.set_objective(None) - self.add_constraints_to_build_in_sage_milp_class_with_chosen_incompatible_components(component_id_list, fixed_values) + self.add_constraints_to_build_in_sage_milp_class_with_chosen_incompatible_components( + component_id_list, fixed_values + ) end = time.time() building_time = end - start solution = self.solve(MILP_BITWISE_IMPOSSIBLE, solver_name, external_solver_name) - solution['building_time'] = building_time + solution["building_time"] = building_time return solution - def find_one_bitwise_impossible_xor_differential_trail_with_fully_automatic_model(self, fixed_values=[], include_all_components=False, - solver_name=SOLVER_DEFAULT, external_solver_name=None): + def find_one_bitwise_impossible_xor_differential_trail_with_fully_automatic_model( + self, fixed_values=[], include_all_components=False, solver_name=SOLVER_DEFAULT, external_solver_name=None + ): """ Returns one bitwise impossible XOR differential trail. @@ -428,24 +482,27 @@ def find_one_bitwise_impossible_xor_differential_trail_with_fully_automatic_mode self._verbose_print(f"Solver used : {solver_name} (Choose Gurobi for Better performance)") mip = self._model mip.set_objective(None) - self.add_constraints_to_build_fully_automatic_model_in_sage_milp_class(fixed_variables=fixed_values, include_all_components=include_all_components) + self.add_constraints_to_build_fully_automatic_model_in_sage_milp_class( + fixed_variables=fixed_values, include_all_components=include_all_components + ) end = time.time() building_time = end - start solution = self.solve(MILP_BITWISE_IMPOSSIBLE_AUTO, solver_name, external_solver_name) - solution['building_time'] = building_time + solution["building_time"] = building_time return solution def _get_component_values(self, objective_variables, components_variables): - return milp_utils._get_component_values_for_impossible_models(self, objective_variables, components_variables) + return milp_utils._get_component_values_for_impossible_models(self, objective_variables, components_variables) def _parse_solver_output(self): mip = self._model components_variables = mip.get_values(self._trunc_binvar) if self._forward_cipher == self._cipher: objective_variables = mip.get_values(self._binary_variable) - inconsistent_component_var = \ - [i for i in objective_variables.keys() if objective_variables[i] > 0 and "inconsistent" in i][0] + inconsistent_component_var = [ + i for i in objective_variables.keys() if objective_variables[i] > 0 and "inconsistent" in i + ][0] objective_value = "_".join(inconsistent_component_var.split("_")[:-3]) else: objective_variables = mip.get_values(self._integer_variable) @@ -454,7 +511,6 @@ def _parse_solver_output(self): return objective_value, components_values def _get_component_value_weight(self, component_id, components_variables): - if component_id in self._cipher.inputs: output_size = self._cipher.inputs_bit_size[self._cipher.inputs.index(component_id)] elif self._forward_cipher != self._cipher and component_id.endswith(MILP_BACKWARD_SUFFIX): @@ -474,4 +530,4 @@ def _get_component_value_weight(self, component_id, components_variables): if len(final_output) == 1: final_output = final_output[0] - return final_output \ No newline at end of file + return final_output diff --git a/claasp/cipher_modules/models/milp/milp_models/milp_cipher_model.py b/claasp/cipher_modules/models/milp/milp_models/milp_cipher_model.py index f6a7e29dc..8d777e4ed 100644 --- a/claasp/cipher_modules/models/milp/milp_models/milp_cipher_model.py +++ b/claasp/cipher_modules/models/milp/milp_models/milp_cipher_model.py @@ -1,28 +1,26 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** from claasp.cipher_modules.models.milp.milp_model import MilpModel -from claasp.name_mappings import (INTERMEDIATE_OUTPUT, CIPHER_OUTPUT, LINEAR_LAYER, WORD_OPERATION, MIX_COLUMN) +from claasp.name_mappings import CIPHER_OUTPUT, INTERMEDIATE_OUTPUT, LINEAR_LAYER, MIX_COLUMN, WORD_OPERATION class MilpCipherModel(MilpModel): - def __init__(self, cipher, n_window_heuristic=None): super().__init__(cipher, n_window_heuristic) @@ -54,15 +52,16 @@ def build_cipher_model(self, fixed_variables=[]): variables = [] self._variables_list = [] constraints = self.fix_variables_value_constraints(fixed_variables) - component_types = [CIPHER_OUTPUT, INTERMEDIATE_OUTPUT, LINEAR_LAYER, MIX_COLUMN, WORD_OPERATION] - operation_types = ['NOT', 'ROTATE', 'SHIFT', 'XOR'] + component_types = (CIPHER_OUTPUT, INTERMEDIATE_OUTPUT, LINEAR_LAYER, MIX_COLUMN, WORD_OPERATION) + operation_types = ("NOT", "ROTATE", "SHIFT", "XOR") self._model_constraints = constraints for component in self._cipher.get_all_components(): operation = component.description[0] if component.type not in component_types or ( - WORD_OPERATION == component.type and operation not in operation_types): - print(f'{component.id} not yet implemented') + WORD_OPERATION == component.type and operation not in operation_types + ): + print(f"{component.id} not yet implemented") else: variables, constraints = component.milp_constraints(self) diff --git a/claasp/cipher_modules/models/milp/milp_models/milp_wordwise_deterministic_truncated_xor_differential_model.py b/claasp/cipher_modules/models/milp/milp_models/milp_wordwise_deterministic_truncated_xor_differential_model.py index 823fb5bcc..cde1c64e3 100644 --- a/claasp/cipher_modules/models/milp/milp_models/milp_wordwise_deterministic_truncated_xor_differential_model.py +++ b/claasp/cipher_modules/models/milp/milp_models/milp_wordwise_deterministic_truncated_xor_differential_model.py @@ -1,42 +1,53 @@ # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** import time -from claasp.cipher_modules.models.milp.solvers import SOLVER_DEFAULT + +from numpy import array_split + from claasp.cipher_modules.models.milp.milp_model import MilpModel -from claasp.cipher_modules.models.milp.utils.milp_name_mappings import MILP_WORDWISE_DETERMINISTIC_TRUNCATED, \ - MILP_BUILDING_MESSAGE, MILP_TRUNCATED_XOR_DIFFERENTIAL_OBJECTIVE -from claasp.cipher_modules.models.milp.utils.utils import espresso_pos_to_constraints, \ - _get_variables_values_as_string -from claasp.cipher_modules.models.milp.utils.milp_truncated_utils import \ - fix_variables_value_deterministic_truncated_xor_differential_constraints -from claasp.cipher_modules.models.utils import set_component_solution -from claasp.name_mappings import (CONSTANT, INTERMEDIATE_OUTPUT, CIPHER_OUTPUT, - WORD_OPERATION, LINEAR_LAYER, SBOX, MIX_COLUMN) +from claasp.cipher_modules.models.milp.solvers import SOLVER_DEFAULT from claasp.cipher_modules.models.milp.utils.generate_inequalities_for_wordwise_truncated_xor_with_n_input_bits import ( update_dictionary_that_contains_wordwise_truncated_input_inequalities, - output_dictionary_that_contains_wordwise_truncated_input_inequalities + output_dictionary_that_contains_wordwise_truncated_input_inequalities, +) +from claasp.cipher_modules.models.milp.utils.milp_name_mappings import ( + MILP_WORDWISE_DETERMINISTIC_TRUNCATED, + MILP_BUILDING_MESSAGE, + MILP_TRUNCATED_XOR_DIFFERENTIAL_OBJECTIVE, +) +from claasp.cipher_modules.models.milp.utils.milp_truncated_utils import ( + fix_variables_value_deterministic_truncated_xor_differential_constraints, ) +from claasp.cipher_modules.models.milp.utils.utils import espresso_pos_to_constraints, _get_variables_values_as_string +from claasp.cipher_modules.models.utils import set_component_solution from claasp.editor import get_output_bit_size_from_id -from numpy import array_split +from claasp.name_mappings import ( + CIPHER_OUTPUT, + CONSTANT, + INTERMEDIATE_OUTPUT, + LINEAR_LAYER, + MIX_COLUMN, + SBOX, + WORD_OPERATION, +) class MilpWordwiseDeterministicTruncatedXorDifferentialModel(MilpModel): - def __init__(self, cipher, n_window_heuristic=None, verbose=False): super().__init__(cipher, n_window_heuristic, verbose) self._trunc_wordvar = None @@ -111,11 +122,14 @@ def add_constraints_to_build_in_sage_milp_class(self, fixed_bits=[], fixed_words # objective is the number of unknown patterns i.e. tuples of the form (1, x) _, output_ids = last_component._get_wordwise_input_output_linked_class_tuples(self) - mip.add_constraint(p[MILP_TRUNCATED_XOR_DIFFERENTIAL_OBJECTIVE] == sum( - x[output_msb] for output_msb in [id[0] for id in output_ids])) - - def build_wordwise_deterministic_truncated_xor_differential_trail_model(self, fixed_bits=[], fixed_words=[], - cipher_list=None): + mip.add_constraint( + p[MILP_TRUNCATED_XOR_DIFFERENTIAL_OBJECTIVE] + == sum(x[output_msb] for output_msb in [id[0] for id in output_ids]) + ) + + def build_wordwise_deterministic_truncated_xor_differential_trail_model( + self, fixed_bits=[], fixed_words=[], cipher_list=None + ): """ Build the model for the search of wordwise deterministic truncated XOR differential trails. @@ -146,32 +160,43 @@ def build_wordwise_deterministic_truncated_xor_differential_trail_model(self, fi """ self._variables_list = [] cipher_list = cipher_list or [self._cipher] - component_list = [c for cipher_component in [cipher.get_all_components() for cipher in cipher_list] for c in - cipher_component] + component_list = [ + c for cipher_component in [cipher.get_all_components() for cipher in cipher_list] for c in cipher_component + ] variables, constraints = self.input_wordwise_deterministic_truncated_xor_differential_constraints( - component_list) + component_list + ) constraints += self.fix_variables_value_wordwise_deterministic_truncated_xor_differential_constraints( - fixed_bits, fixed_words, cipher_list) + fixed_bits, fixed_words, cipher_list + ) self._model_constraints = constraints for component in component_list: - component_types = [CONSTANT, INTERMEDIATE_OUTPUT, CIPHER_OUTPUT, LINEAR_LAYER, SBOX, MIX_COLUMN, - WORD_OPERATION] + component_types = ( + CIPHER_OUTPUT, + CONSTANT, + INTERMEDIATE_OUTPUT, + LINEAR_LAYER, + MIX_COLUMN, + SBOX, + WORD_OPERATION, + ) operation = component.description[0] - operation_types = ['AND', 'MODADD', 'MODSUB', 'NOT', 'OR', 'ROTATE', 'SHIFT', 'XOR'] + operation_types = ("AND", "MODADD", "MODSUB", "NOT", "OR", "ROTATE", "SHIFT", "XOR") if component.type in component_types or operation in operation_types: variables, constraints = component.milp_wordwise_deterministic_truncated_xor_differential_constraints( - self) + self + ) else: - print(f'{component.id} not yet implemented') + print(f"{component.id} not yet implemented") self._variables_list.extend(variables) self._model_constraints.extend(constraints) - def fix_variables_value_wordwise_deterministic_truncated_xor_differential_constraints(self, fixed_bits=[], - fixed_words=[], - cipher_list=None): + def fix_variables_value_wordwise_deterministic_truncated_xor_differential_constraints( + self, fixed_bits=[], fixed_words=[], cipher_list=None + ): """ Returns a list of constraints that fix the input variables to a specific value. @@ -224,14 +249,21 @@ def fix_variables_value_wordwise_deterministic_truncated_xor_differential_constr if fixed_variable["constraint_type"] == "equal": output_bit_size = get_output_bit_size_from_id(cipher_list, fixed_variable["component_id"]) for i, current_word_bits in enumerate( - array_split(range(output_bit_size), output_bit_size // self._word_size)): + array_split(range(output_bit_size), output_bit_size // self._word_size) + ): if set(current_word_bits) <= set(fixed_variable["bit_positions"]): - if sum([fixed_variable["bit_values"][fixed_variable["bit_positions"].index(_)] for _ in - current_word_bits]) == 0: - constraints.append(x[f'{fixed_variable["component_id"]}_word_{i}_class'] == 0) - - return constraints + fix_variables_value_deterministic_truncated_xor_differential_constraints(self, x, - fixed_words) + if ( + sum( + fixed_variable["bit_values"][fixed_variable["bit_positions"].index(_)] + for _ in current_word_bits + ) + == 0 + ): + constraints.append(x[f"{fixed_variable['component_id']}_word_{i}_class"] == 0) + + return constraints + fix_variables_value_deterministic_truncated_xor_differential_constraints( + self, x, fixed_words + ) def input_wordwise_deterministic_truncated_xor_differential_constraints(self, component_list=None): """ @@ -288,8 +320,10 @@ def input_wordwise_deterministic_truncated_xor_differential_constraints(self, co minimized_constraints = espresso_pos_to_constraints(inequalities, word_vars) constraints.extend(minimized_constraints) - variables.extend([(f"x_class[{var}]", x_class[var]) for var in all_int_vars] + \ - [(f"x[{var}]", x[var]) for var in all_vars]) + variables.extend( + [(f"x_class[{var}]", x_class[var]) for var in all_int_vars] + + [(f"x[{var}]", x[var]) for var in all_vars] + ) # link class tuple (c0, c1) to the integer value of the class (0, 1, 2, 3) input_tuples, output_tuples = component._get_wordwise_input_output_linked_class_tuples(self) @@ -297,13 +331,14 @@ def input_wordwise_deterministic_truncated_xor_differential_constraints(self, co for index, var in enumerate(all_int_vars): constraints.append( - x_class[var] == sum([2 ** i * var_bit for i, var_bit in enumerate(variables_tuples[index][::-1])])) + x_class[var] == sum([2**i * var_bit for i, var_bit in enumerate(variables_tuples[index][::-1])]) + ) return variables, constraints - def find_one_wordwise_deterministic_truncated_xor_differential_trail(self, fixed_bits=[], fixed_words=[], - solver_name=SOLVER_DEFAULT, - external_solver_name=None): + def find_one_wordwise_deterministic_truncated_xor_differential_trail( + self, fixed_bits=[], fixed_words=[], solver_name=SOLVER_DEFAULT, external_solver_name=None + ): """ Returns one deterministic truncated XOR differential trail. @@ -339,14 +374,13 @@ def find_one_wordwise_deterministic_truncated_xor_differential_trail(self, fixed end = time.time() building_time = end - start solution = self.solve(MILP_WORDWISE_DETERMINISTIC_TRUNCATED, solver_name, external_solver_name) - solution['building_time'] = building_time + solution["building_time"] = building_time return solution - def find_lowest_varied_patterns_wordwise_deterministic_truncated_xor_differential_trail(self, fixed_bits=[], - fixed_words=[], - solver_name=SOLVER_DEFAULT, - external_solver_name=None): + def find_lowest_varied_patterns_wordwise_deterministic_truncated_xor_differential_trail( + self, fixed_bits=[], fixed_words=[], solver_name=SOLVER_DEFAULT, external_solver_name=None + ): """ Return the solution representing a differential trail with the lowest number of unknown variables. @@ -384,7 +418,7 @@ def find_lowest_varied_patterns_wordwise_deterministic_truncated_xor_differentia end = time.time() building_time = end - start solution = self.solve(MILP_WORDWISE_DETERMINISTIC_TRUNCATED, solver_name, external_solver_name) - solution['building_time'] = building_time + solution["building_time"] = building_time return solution @@ -414,7 +448,6 @@ def _parse_solver_output(self): return objective_value, components_values def _get_component_value_weight(self, component_id, components_variables): - wordsize = self._word_size if component_id in self._cipher.inputs: output_size = self._cipher.inputs_bit_size[self._cipher.inputs.index(component_id)] // wordsize @@ -431,8 +464,9 @@ def _get_component_value_weight(self, component_id, components_variables): def _get_final_output(self, component_id, components_variables, suffix_dict): final_output = [] for suffix in suffix_dict.keys(): - diff_str = _get_variables_values_as_string(component_id + "_word", components_variables, suffix, - suffix_dict[suffix]) + diff_str = _get_variables_values_as_string( + f"{component_id}_word", components_variables, suffix, suffix_dict[suffix] + ) final_output.append(set_component_solution(diff_str)) return final_output diff --git a/claasp/cipher_modules/models/milp/milp_models/milp_wordwise_impossible_xor_differential_model.py b/claasp/cipher_modules/models/milp/milp_models/milp_wordwise_impossible_xor_differential_model.py index 7488f281f..877e8d996 100644 --- a/claasp/cipher_modules/models/milp/milp_models/milp_wordwise_impossible_xor_differential_model.py +++ b/claasp/cipher_modules/models/milp/milp_models/milp_wordwise_impossible_xor_differential_model.py @@ -1,16 +1,16 @@ # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -18,16 +18,21 @@ import time from claasp.cipher_modules.inverse_cipher import get_key_schedule_component_ids +from claasp.cipher_modules.models.milp.milp_models.milp_wordwise_deterministic_truncated_xor_differential_model import ( + MilpWordwiseDeterministicTruncatedXorDifferentialModel, +) from claasp.cipher_modules.models.milp.solvers import SOLVER_DEFAULT -from claasp.cipher_modules.models.milp.milp_models.milp_wordwise_deterministic_truncated_xor_differential_model import MilpWordwiseDeterministicTruncatedXorDifferentialModel -from claasp.cipher_modules.models.milp.utils.milp_name_mappings import MILP_WORDWISE_IMPOSSIBLE_AUTO, \ - MILP_WORDWISE_IMPOSSIBLE, MILP_BACKWARD_SUFFIX, MILP_BUILDING_MESSAGE -from claasp.name_mappings import CIPHER_OUTPUT, INPUT_KEY from claasp.cipher_modules.models.milp.utils import utils as milp_utils, milp_truncated_utils +from claasp.cipher_modules.models.milp.utils.milp_name_mappings import ( + MILP_BACKWARD_SUFFIX, + MILP_BUILDING_MESSAGE, + MILP_WORDWISE_IMPOSSIBLE_AUTO, + MILP_WORDWISE_IMPOSSIBLE, +) +from claasp.name_mappings import CIPHER_OUTPUT, INPUT_KEY class MilpWordwiseImpossibleXorDifferentialModel(MilpWordwiseDeterministicTruncatedXorDifferentialModel): - def __init__(self, cipher, n_window_heuristic=None, verbose=False): super().__init__(cipher, n_window_heuristic, verbose) self._forward_cipher = None @@ -63,7 +68,9 @@ def build_wordwise_impossible_xor_differential_trail_model(self, fixed_bits=[], ... """ cipher_list = [self._forward_cipher, self._backward_cipher] - return self.build_wordwise_deterministic_truncated_xor_differential_trail_model(fixed_bits, fixed_words, cipher_list) + return self.build_wordwise_deterministic_truncated_xor_differential_trail_model( + fixed_bits, fixed_words, cipher_list + ) def add_constraints_to_build_in_sage_milp_class(self, middle_round=None, fixed_bits=[], fixed_words=[]): """ @@ -104,8 +111,12 @@ def add_constraints_to_build_in_sage_milp_class(self, middle_round=None, fixed_b assert middle_round < self._cipher.number_of_rounds self._forward_cipher = self._cipher.get_partial_cipher(0, middle_round - 1, keep_key_schedule=True) - backward_cipher = self._cipher.cipher_partial_inverse(middle_round, self._cipher.number_of_rounds - 1, keep_key_schedule=False) - self._backward_cipher = backward_cipher.add_suffix_to_components(MILP_BACKWARD_SUFFIX, [backward_cipher.get_all_components_ids()[-1]]) + backward_cipher = self._cipher.cipher_partial_inverse( + middle_round, self._cipher.number_of_rounds - 1, keep_key_schedule=False + ) + self._backward_cipher = backward_cipher.add_suffix_to_components( + MILP_BACKWARD_SUFFIX, [backward_cipher.get_all_components_ids()[-1]] + ) self.build_wordwise_impossible_xor_differential_trail_model(fixed_bits, fixed_words) for index, constraint in enumerate(self._model_constraints): @@ -126,15 +137,25 @@ def add_constraints_to_build_in_sage_milp_class(self, middle_round=None, fixed_b constraints.extend([sum(inconsistent_vars) == 1]) for inconsistent_index in range(output_size): incompatibility_constraints = [forward_vars[inconsistent_index] + backward_vars[inconsistent_index] <= 2] - dummy = x[f'dummy_incompatibility_{x[forward_vars[inconsistent_index]]}_or_{x[backward_vars[inconsistent_index]]}_is_0'] - incompatibility_constraints += [forward_vars[inconsistent_index] <= self._model.get_max(x_class) * (1 - dummy)] + dummy = x[ + f"dummy_incompatibility_{x[forward_vars[inconsistent_index]]}_or_{x[backward_vars[inconsistent_index]]}_is_0" + ] + incompatibility_constraints += [ + forward_vars[inconsistent_index] <= self._model.get_max(x_class) * (1 - dummy) + ] incompatibility_constraints += [backward_vars[inconsistent_index] <= self._model.get_max(x_class) * dummy] - constraints.extend(milp_utils.milp_if_then(inconsistent_vars[inconsistent_index], incompatibility_constraints, self._model.get_max(x_class) * 2)) + constraints.extend( + milp_utils.milp_if_then( + inconsistent_vars[inconsistent_index], incompatibility_constraints, self._model.get_max(x_class) * 2 + ) + ) # output is fixed cipher_output = [c for c in self._cipher.get_all_components() if c.type == CIPHER_OUTPUT][0] _, cipher_output_ids = cipher_output._get_wordwise_input_output_linked_class(self) - constraints.extend([x_class[id] <= 2 for id in cipher_output_ids] + [sum([x_class[id] for id in cipher_output_ids]) >= 1]) + constraints.extend( + [x_class[id] <= 2 for id in cipher_output_ids] + [sum([x_class[id] for id in cipher_output_ids]) >= 1] + ) for constraint in constraints: mip.add_constraint(constraint) @@ -142,10 +163,13 @@ def add_constraints_to_build_in_sage_milp_class(self, middle_round=None, fixed_b # unknown patterns are tuples of the form (1,x) (i.e pattern = 2 or 3) _, forward_output_id_tuple = forward_output._get_wordwise_input_output_linked_class_tuples(self) mip.add_constraint( - p["number_of_unknown_patterns"] == sum(x[output_msb] for output_msb in [id[0] for id in forward_output_id_tuple])) - + p["number_of_unknown_patterns"] + == sum(x[output_msb] for output_msb in [id[0] for id in forward_output_id_tuple]) + ) - def add_constraints_to_build_in_sage_milp_class_with_chosen_incompatible_components(self, component_id_list=None, fixed_bits=[], fixed_words=[]): + def add_constraints_to_build_in_sage_milp_class_with_chosen_incompatible_components( + self, component_id_list=None, fixed_bits=[], fixed_words=[] + ): """ Take the constraints contained in self._model_constraints and add them to the build-in sage class. @@ -181,7 +205,9 @@ def add_constraints_to_build_in_sage_milp_class_with_chosen_incompatible_compone if component_id_list == None: return self.add_constraints_to_build_in_sage_milp_class(fixed_bits=fixed_bits, fixed_words=fixed_words) - assert set(component_id_list) <= set(self._cipher.get_all_components_ids()) - set(get_key_schedule_component_ids(self._cipher)) + assert set(component_id_list) <= set(self._cipher.get_all_components_ids()) - set( + get_key_schedule_component_ids(self._cipher) + ) middle_round_numbers = [self._cipher.get_round_from_component_id(id) for id in component_id_list] @@ -190,25 +216,36 @@ def add_constraints_to_build_in_sage_milp_class_with_chosen_incompatible_compone middle_round_number = middle_round_numbers[0] if len(component_id_list) == 1 and self._cipher.get_component_from_id(component_id_list[0]).description == [ - 'round_output']: + "round_output" + ]: return self.add_constraints_to_build_in_sage_milp_class(middle_round_number + 1, fixed_bits, fixed_words) self._forward_cipher = self._cipher.get_partial_cipher(0, middle_round_number, keep_key_schedule=True) - backward_cipher = self._cipher.cipher_partial_inverse(middle_round_number, self._cipher.number_of_rounds - 1, keep_key_schedule=False) + backward_cipher = self._cipher.cipher_partial_inverse( + middle_round_number, self._cipher.number_of_rounds - 1, keep_key_schedule=False + ) self._incompatible_components = component_id_list - backward_last_round_components = set(backward_cipher._rounds.round_at(self._cipher.number_of_rounds - 1 - middle_round_number).get_components_ids() + [backward_cipher.get_all_components_ids()[-1]]) - input_id_links_of_chosen_components = [_ for c in - [backward_cipher.get_component_from_id(id) for id in component_id_list] - for _ in c.input_id_links] - round_input_id_links_of_chosen_components = [backward_cipher.get_round_from_component_id(id) for id in - input_id_links_of_chosen_components] - links_round = [_ for r in round_input_id_links_of_chosen_components for _ in - backward_cipher._rounds.round_at(r).get_components_ids()] - self._backward_cipher = backward_cipher.add_suffix_to_components(MILP_BACKWARD_SUFFIX, - backward_last_round_components | set( - links_round)) - + backward_last_round_components = set( + backward_cipher._rounds.round_at( + self._cipher.number_of_rounds - 1 - middle_round_number + ).get_components_ids() + + [backward_cipher.get_all_components_ids()[-1]] + ) + input_id_links_of_chosen_components = [ + _ for c in [backward_cipher.get_component_from_id(id) for id in component_id_list] for _ in c.input_id_links + ] + round_input_id_links_of_chosen_components = [ + backward_cipher.get_round_from_component_id(id) for id in input_id_links_of_chosen_components + ] + links_round = [ + _ + for r in round_input_id_links_of_chosen_components + for _ in backward_cipher._rounds.round_at(r).get_components_ids() + ] + self._backward_cipher = backward_cipher.add_suffix_to_components( + MILP_BACKWARD_SUFFIX, backward_last_round_components | set(links_round) + ) self.build_wordwise_impossible_xor_differential_trail_model(fixed_bits, fixed_words) @@ -227,29 +264,48 @@ def add_constraints_to_build_in_sage_milp_class_with_chosen_incompatible_compone inconsistent_vars = [x[f"{forward_component.id}_inconsistent_{_}"] for _ in range(output_size)] # for multiple input components such as the XOR, ensures compatibility occurs on the correct branch - for index, input_id in enumerate(["_".join(i.split("_")[:-1]) if MILP_BACKWARD_SUFFIX in i else i for i in - backward_component.input_id_links]): + for index, input_id in enumerate( + [ + "_".join(i.split("_")[:-1]) if MILP_BACKWARD_SUFFIX in i else i + for i in backward_component.input_id_links + ] + ): if INPUT_KEY not in input_id and self._cipher.get_component_from_id(input_id).input_id_links == [id]: - backward_vars = [x_class[f'{input_id}_{pos}'] for pos in - backward_component.input_bit_positions[index]] + backward_vars = [ + x_class[f"{input_id}_{pos}"] for pos in backward_component.input_bit_positions[index] + ] incompatibility_constraints.extend([sum(inconsistent_vars) == 1]) for inconsistent_index in range(output_size): incompatibility_constraint = [forward_vars[inconsistent_index] + backward_vars[inconsistent_index] <= 2] - incompatibility_constraints.extend(milp_utils.milp_if_then(inconsistent_vars[inconsistent_index], incompatibility_constraint, self._model.get_max(x_class) * 2)) + incompatibility_constraints.extend( + milp_utils.milp_if_then( + inconsistent_vars[inconsistent_index], + incompatibility_constraint, + self._model.get_max(x_class) * 2, + ) + ) # output is fixed cipher_output = [c for c in self._cipher.get_all_components() if c.type == CIPHER_OUTPUT][0] _, cipher_output_ids = cipher_output._get_wordwise_input_output_linked_class(self) - incompatibility_constraints.extend([x_class[id] <= 1 for id in cipher_output_ids] + [sum([x_class[id] for id in cipher_output_ids]) >= 1]) + incompatibility_constraints.extend( + [x_class[id] <= 1 for id in cipher_output_ids] + [sum([x_class[id] for id in cipher_output_ids]) >= 1] + ) # unknown patterns are tuples of the form (1,x) (i.e pattern = 2 or 3) _, forward_output_id_tuple = forward_component._get_wordwise_input_output_linked_class_tuples(self) - optimization_constraint = [p["number_of_unknown_patterns"] == sum(x[output_msb] for output_msb in [id[0] for id in forward_output_id_tuple])] + optimization_constraint = [ + p["number_of_unknown_patterns"] + == sum(x[output_msb] for output_msb in [id[0] for id in forward_output_id_tuple]) + ] for constraint in self._model_constraints + incompatibility_constraints + optimization_constraint: mip.add_constraint(constraint) - def add_constraints_to_build_fully_automatic_model_in_sage_milp_class(self, fixed_bits=[], fixed_words=[], include_all_components=False): + + def add_constraints_to_build_fully_automatic_model_in_sage_milp_class( + self, fixed_bits=[], fixed_words=[], include_all_components=False + ): """ Take the constraints contained in self._model_constraints and add them to the build-in sage class. @@ -291,11 +347,15 @@ def add_constraints_to_build_fully_automatic_model_in_sage_milp_class(self, fixe mip.add_constraint(constraint) # finding incompatibility - constraints = milp_truncated_utils.generate_all_incompatibility_constraints_for_fully_automatic_model(self, MILP_WORDWISE_IMPOSSIBLE_AUTO, x, x_class, include_all_components) + constraints = milp_truncated_utils.generate_all_incompatibility_constraints_for_fully_automatic_model( + self, MILP_WORDWISE_IMPOSSIBLE_AUTO, x, x_class, include_all_components + ) # decryption input is fixed and non-zero constraints.extend( - [x_class[id] <= 1 for id in self._backward_cipher.inputs] + [sum([x_class[id] for id in self._backward_cipher.inputs]) >= 1]) + [x_class[id] <= 1 for id in self._backward_cipher.inputs] + + [sum([x_class[id] for id in self._backward_cipher.inputs]) >= 1] + ) for constraint in constraints: mip.add_constraint(constraint) @@ -304,10 +364,13 @@ def add_constraints_to_build_fully_automatic_model_in_sage_milp_class(self, fixe forward_output = [c for c in self._forward_cipher.get_all_components() if c.type == CIPHER_OUTPUT][0] _, forward_output_id_tuple = forward_output._get_wordwise_input_output_linked_class_tuples(self) mip.add_constraint( - p["number_of_unknown_patterns"] == sum(x[output_msb] for output_msb in [id[0] for id in forward_output_id_tuple])) + p["number_of_unknown_patterns"] + == sum(x[output_msb] for output_msb in [id[0] for id in forward_output_id_tuple]) + ) - - def find_one_wordwise_impossible_xor_differential_trail(self, middle_round=None, fixed_bits=[], fixed_words=[], solver_name=SOLVER_DEFAULT, external_solver_name=None): + def find_one_wordwise_impossible_xor_differential_trail( + self, middle_round=None, fixed_bits=[], fixed_words=[], solver_name=SOLVER_DEFAULT, external_solver_name=None + ): """ Returns one wordwise impossible XOR differential trail. @@ -338,11 +401,13 @@ def find_one_wordwise_impossible_xor_differential_trail(self, middle_round=None, end = time.time() building_time = end - start solution = self.solve(MILP_WORDWISE_IMPOSSIBLE, solver_name, external_solver_name) - solution['building_time'] = building_time + solution["building_time"] = building_time return solution - def find_one_wordwise_impossible_xor_differential_trail_with_chosen_components(self, component_id_list, fixed_bits=[], fixed_words=[], solver_name=SOLVER_DEFAULT, external_solver_name=None): + def find_one_wordwise_impossible_xor_differential_trail_with_chosen_components( + self, component_id_list, fixed_bits=[], fixed_words=[], solver_name=SOLVER_DEFAULT, external_solver_name=None + ): """ Returns one wordwise impossible XOR differential trail. @@ -369,16 +434,24 @@ def find_one_wordwise_impossible_xor_differential_trail_with_chosen_components(s self._verbose_print(f"Solver used : {solver_name} (Choose Gurobi for Better performance)") mip = self._model mip.set_objective(None) - self.add_constraints_to_build_in_sage_milp_class_with_chosen_incompatible_components(component_id_list, - fixed_bits, fixed_words) + self.add_constraints_to_build_in_sage_milp_class_with_chosen_incompatible_components( + component_id_list, fixed_bits, fixed_words + ) end = time.time() building_time = end - start solution = self.solve(MILP_WORDWISE_IMPOSSIBLE, solver_name, external_solver_name) - solution['building_time'] = building_time + solution["building_time"] = building_time return solution - def find_one_wordwise_impossible_xor_differential_trail_with_fully_automatic_model(self, fixed_bits=[], fixed_words=[], include_all_components=False, solver_name=SOLVER_DEFAULT, external_solver_name=None): + def find_one_wordwise_impossible_xor_differential_trail_with_fully_automatic_model( + self, + fixed_bits=[], + fixed_words=[], + include_all_components=False, + solver_name=SOLVER_DEFAULT, + external_solver_name=None, + ): """ Returns one wordwise impossible XOR differential trail. @@ -406,24 +479,27 @@ def find_one_wordwise_impossible_xor_differential_trail_with_fully_automatic_mod self._verbose_print(f"Solver used : {solver_name} (Choose Gurobi for Better performance)") mip = self._model mip.set_objective(None) - self.add_constraints_to_build_fully_automatic_model_in_sage_milp_class(fixed_bits, fixed_words, include_all_components=include_all_components) + self.add_constraints_to_build_fully_automatic_model_in_sage_milp_class( + fixed_bits, fixed_words, include_all_components=include_all_components + ) end = time.time() building_time = end - start solution = self.solve(MILP_WORDWISE_IMPOSSIBLE_AUTO, solver_name, external_solver_name) - solution['building_time'] = building_time + solution["building_time"] = building_time return solution def _get_component_values(self, objective_variables, components_variables): - return milp_utils._get_component_values_for_impossible_models(self, objective_variables, components_variables) + return milp_utils._get_component_values_for_impossible_models(self, objective_variables, components_variables) def _parse_solver_output(self): mip = self._model if self._forward_cipher == self._cipher: components_variables = mip.get_values(self._trunc_wordvar) objective_variables = mip.get_values(self._binary_variable) - inconsistent_component_var = \ - [i for i in objective_variables.keys() if objective_variables[i] > 0 and "inconsistent" in i][0] + inconsistent_component_var = [ + i for i in objective_variables.keys() if objective_variables[i] > 0 and "inconsistent" in i + ][0] objective_value = "_".join(inconsistent_component_var.split("_")[:-3]) else: components_variables = mip.get_values(self._trunc_wordvar) @@ -434,19 +510,17 @@ def _parse_solver_output(self): return objective_value, components_values def _get_component_value_weight(self, component_id, components_variables): - wordsize = self._word_size if component_id in self._cipher.inputs: output_size = self._cipher.inputs_bit_size[self._cipher.inputs.index(component_id)] // wordsize - elif self._forward_cipher != self._cipher and component_id.endswith( - MILP_BACKWARD_SUFFIX): + elif self._forward_cipher != self._cipher and component_id.endswith(MILP_BACKWARD_SUFFIX): component = self._backward_cipher.get_component_from_id(component_id) output_size = component.output_bit_size // wordsize - elif self._forward_cipher == self._cipher and component_id.endswith( - MILP_BACKWARD_SUFFIX): + elif self._forward_cipher == self._cipher and component_id.endswith(MILP_BACKWARD_SUFFIX): if component_id in self._backward_cipher.inputs: - output_size = self._backward_cipher.inputs_bit_size[ - self._backward_cipher.inputs.index(component_id)] // wordsize + output_size = ( + self._backward_cipher.inputs_bit_size[self._backward_cipher.inputs.index(component_id)] // wordsize + ) else: component = self._backward_cipher.get_component_from_id(component_id) output_size = component.output_bit_size // wordsize diff --git a/claasp/cipher_modules/models/milp/milp_models/milp_xor_differential_model.py b/claasp/cipher_modules/models/milp/milp_models/milp_xor_differential_model.py index 4cb3d0a9c..56466533c 100644 --- a/claasp/cipher_modules/models/milp/milp_models/milp_xor_differential_model.py +++ b/claasp/cipher_modules/models/milp/milp_models/milp_xor_differential_model.py @@ -1,16 +1,16 @@ # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -21,27 +21,47 @@ import numpy as np from bitstring import BitArray -from claasp.cipher_modules.models.milp.solvers import SOLVER_DEFAULT from claasp.cipher_modules.models.milp.milp_model import MilpModel -from claasp.cipher_modules.models.milp.utils.milp_name_mappings import MILP_XOR_DIFFERENTIAL, MILP_PROBABILITY_SUFFIX, \ - MILP_BUILDING_MESSAGE, MILP_XOR_DIFFERENTIAL_OBJECTIVE, MILP_DEFAULT_WEIGHT_PRECISION -from claasp.cipher_modules.models.milp.utils.utils import _string_to_hex, _get_variables_values_as_string, \ - _filter_fixed_variables, _set_weight_precision -from claasp.cipher_modules.models.utils import integer_to_bit_list, set_component_solution, \ - get_single_key_scenario_format_for_fixed_values -from claasp.name_mappings import (CONSTANT, INTERMEDIATE_OUTPUT, CIPHER_OUTPUT, - WORD_OPERATION, LINEAR_LAYER, SBOX, MIX_COLUMN, INPUT_KEY) +from claasp.cipher_modules.models.milp.solvers import SOLVER_DEFAULT +from claasp.cipher_modules.models.milp.utils.milp_name_mappings import ( + MILP_BUILDING_MESSAGE, + MILP_DEFAULT_WEIGHT_PRECISION, + MILP_PROBABILITY_SUFFIX, + MILP_XOR_DIFFERENTIAL_OBJECTIVE, + MILP_XOR_DIFFERENTIAL, +) +from claasp.cipher_modules.models.milp.utils.utils import ( + _filter_fixed_variables, + _get_variables_values_as_string, + _set_weight_precision, + _string_to_hex, +) +from claasp.cipher_modules.models.utils import ( + get_single_key_scenario_format_for_fixed_values, + integer_to_bit_list, + set_component_solution, +) +from claasp.name_mappings import ( + CIPHER_OUTPUT, + CONSTANT, + INPUT_KEY, + INTERMEDIATE_OUTPUT, + LINEAR_LAYER, + MIX_COLUMN, + SBOX, + WORD_OPERATION, +) class MilpXorDifferentialModel(MilpModel): - def __init__(self, cipher, n_window_heuristic=None, verbose=False): super().__init__(cipher, n_window_heuristic, verbose) self._weight_precision = MILP_DEFAULT_WEIGHT_PRECISION self._has_non_integer_weight = False - def add_constraints_to_build_in_sage_milp_class(self, weight=-1, weight_precision=MILP_DEFAULT_WEIGHT_PRECISION, - fixed_variables=[]): + def add_constraints_to_build_in_sage_milp_class( + self, weight=-1, weight_precision=MILP_DEFAULT_WEIGHT_PRECISION, fixed_variables=[] + ): """ Take the constraints contained in self._model_constraints and add them to the build-in sage class. @@ -76,10 +96,14 @@ def add_constraints_to_build_in_sage_milp_class(self, weight=-1, weight_precisio self.build_xor_differential_trail_model(weight, fixed_variables) mip = self._model p = self._integer_variable - for index, constraint in enumerate(self._model_constraints): + for constraint in self._model_constraints: mip.add_constraint(constraint) - mip.add_constraint(p[MILP_XOR_DIFFERENTIAL_OBJECTIVE] == sum( - p[self._non_linear_component_id[i] + "_probability"] for i in range(len(self._non_linear_component_id)))) + mip.add_constraint( + p[MILP_XOR_DIFFERENTIAL_OBJECTIVE] + == sum( + p[self._non_linear_component_id[i] + "_probability"] for i in range(len(self._non_linear_component_id)) + ) + ) def build_xor_differential_trail_model(self, weight=-1, fixed_variables=[]): """ @@ -110,15 +134,16 @@ def build_xor_differential_trail_model(self, weight=-1, fixed_variables=[]): if fixed_variables == []: fixed_variables = get_single_key_scenario_format_for_fixed_values(self._cipher) constraints = self.fix_variables_value_constraints(fixed_variables) - component_types = [CONSTANT, INTERMEDIATE_OUTPUT, CIPHER_OUTPUT, LINEAR_LAYER, SBOX, MIX_COLUMN, WORD_OPERATION] - operation_types = ['AND', 'MODADD', 'MODSUB', 'NOT', 'OR', 'ROTATE', 'SHIFT', 'XOR'] + component_types = (CONSTANT, INTERMEDIATE_OUTPUT, CIPHER_OUTPUT, LINEAR_LAYER, SBOX, MIX_COLUMN, WORD_OPERATION) + operation_types = ("AND", "MODADD", "MODSUB", "NOT", "OR", "ROTATE", "SHIFT", "XOR") self._model_constraints = constraints for component in self._cipher.get_all_components(): operation = component.description[0] if component.type not in component_types or ( - WORD_OPERATION == component.type and operation not in operation_types): - print(f'{component.id} not yet implemented') + WORD_OPERATION == component.type and operation not in operation_types + ): + print(f"{component.id} not yet implemented") else: variables, constraints = component.milp_xor_differential_propagation_constraints(self) @@ -130,9 +155,14 @@ def build_xor_differential_trail_model(self, weight=-1, fixed_variables=[]): self._variables_list.extend(variables) self._model_constraints.extend(constraints) - def find_all_xor_differential_trails_with_fixed_weight(self, fixed_weight, fixed_values=[], - weight_precision=MILP_DEFAULT_WEIGHT_PRECISION, - solver_name=SOLVER_DEFAULT, external_solver_name=None): + def find_all_xor_differential_trails_with_fixed_weight( + self, + fixed_weight, + fixed_values=[], + weight_precision=MILP_DEFAULT_WEIGHT_PRECISION, + solver_name=SOLVER_DEFAULT, + external_solver_name=None, + ): """ Return all the XOR differential trails with weight equal to ``fixed_weight`` as a list in standard format. By default, the search is set in the single-key setting. @@ -204,12 +234,12 @@ def find_all_xor_differential_trails_with_fixed_weight(self, fixed_weight, fixed looking_for_other_solutions = 1 while looking_for_other_solutions: try: - f = open(os.devnull, 'w') + f = open(os.devnull, "w") sys.stdout = f solution = self.solve(MILP_XOR_DIFFERENTIAL, solver_name, external_solver_name) sys.stdout = sys.__stdout__ - solution['building_time'] = building_time - solution['test_name'] = "find_all_xor_differential_trails_with_fixed_weight" + solution["building_time"] = building_time + solution["test_name"] = "find_all_xor_differential_trails_with_fixed_weight" self._number_of_trails_found += 1 self._verbose_print(f"trails found : {self._number_of_trails_found}") list_trails.append(solution) @@ -229,7 +259,7 @@ def find_all_xor_differential_trails_with_fixed_weight(self, fixed_weight, fixed self._number_of_trails_found = 0 - return [trail for trail in list_trails if trail['status'] == 'SATISFIABLE'] + return [trail for trail in list_trails if trail["status"] == "SATISFIABLE"] def exclude_variables_value_constraints(self, fixed_variables=[]): """ @@ -280,26 +310,20 @@ def exclude_variables_value_constraints(self, fixed_variables=[]): if fixed_variable["constraint_type"] == "not_equal": for index, bit_position in enumerate(fixed_variable["bit_positions"]): if fixed_variable["bit_values"][index]: - constraints.append(x[component_id + - str(bit_position) + - "_not_equal_" + - str(self._number_of_trails_found)] == 1 - - x[component_id + - '_' + - str(bit_position)]) + constraints.append( + x[f"{component_id}{bit_position}_not_equal_{self._number_of_trails_found}"] + == 1 - x[f"{component_id}_{bit_position}"] + ) else: - constraints.append(x[component_id + - str(bit_position) + - "_not_equal_" + - str(self._number_of_trails_found)] == x[component_id + - '_' + - str(bit_position)]) + constraints.append( + x[f"{component_id}{bit_position}_not_equal_{self._number_of_trails_found}"] + == x[f"{component_id}_{bit_position}"] + ) var_sum = 0 for fixed_variable in fixed_variables: for i in fixed_variable["bit_positions"]: - var_sum += x[fixed_variable["component_id"] + - str(i) + "_not_equal_" + str(self._number_of_trails_found)] + var_sum += x[f"{fixed_variable['component_id']}{i}_not_equal_{self._number_of_trails_found}"] constraints.append(var_sum >= 1) return constraints @@ -324,18 +348,27 @@ def is_single_key(self, fixed_values=[]): """ cipher_inputs = self._cipher.inputs cipher_inputs_bit_size = self._cipher.inputs_bit_size - for fixed_input in [value for value in fixed_values if value['component_id'] in cipher_inputs]: - input_size = cipher_inputs_bit_size[cipher_inputs.index(fixed_input['component_id'])] - if fixed_input['component_id'] == 'key' and fixed_input['constraint_type'] == 'equal' \ - and list(fixed_input['bit_positions']) == list(range(input_size)) \ - and all(v == 0 for v in fixed_input['bit_values']): + for fixed_input in [value for value in fixed_values if value["component_id"] in cipher_inputs]: + input_size = cipher_inputs_bit_size[cipher_inputs.index(fixed_input["component_id"])] + if ( + fixed_input["component_id"] == "key" + and fixed_input["constraint_type"] == "equal" + and list(fixed_input["bit_positions"]) == list(range(input_size)) + and all(v == 0 for v in fixed_input["bit_values"]) + ): return True return False - def find_all_xor_differential_trails_with_weight_at_most(self, min_weight, max_weight, - fixed_values=[], weight_precision=MILP_DEFAULT_WEIGHT_PRECISION, - solver_name=SOLVER_DEFAULT, external_solver_name=None): + def find_all_xor_differential_trails_with_weight_at_most( + self, + min_weight, + max_weight, + fixed_values=[], + weight_precision=MILP_DEFAULT_WEIGHT_PRECISION, + solver_name=SOLVER_DEFAULT, + external_solver_name=None, + ): """ Return all XOR differential trails with weight greater than ``min_weight`` and lower/equal to ``max_weight``. By default, the search is set in the single-key setting. @@ -409,12 +442,12 @@ def find_all_xor_differential_trails_with_weight_at_most(self, min_weight, max_w number_new_constraints = len(weight_constraints) while looking_for_other_solutions: try: - f = open(os.devnull, 'w') + f = open(os.devnull, "w") sys.stdout = f solution = self.solve(MILP_XOR_DIFFERENTIAL, solver_name, external_solver_name) sys.stdout = sys.__stdout__ - solution['building_time'] = building_time - solution['test_name'] = "find_all_xor_differential_trails_with_weight_at_most" + solution["building_time"] = building_time + solution["test_name"] = "find_all_xor_differential_trails_with_weight_at_most" self._number_of_trails_found += 1 self._verbose_print(f"trails found : {self._number_of_trails_found}") list_trails.append(solution) @@ -432,10 +465,15 @@ def find_all_xor_differential_trails_with_weight_at_most(self, min_weight, max_w mip.remove_constraints(range(number_constraints - number_new_constraints, number_constraints)) self._number_of_trails_found = 0 - return [trail for trail in list_trails if trail['status'] == 'SATISFIABLE'] + return [trail for trail in list_trails if trail["status"] == "SATISFIABLE"] - def find_lowest_weight_xor_differential_trail(self, fixed_values=[], weight_precision=MILP_DEFAULT_WEIGHT_PRECISION, - solver_name=SOLVER_DEFAULT, external_solver_name=False): + def find_lowest_weight_xor_differential_trail( + self, + fixed_values=[], + weight_precision=MILP_DEFAULT_WEIGHT_PRECISION, + solver_name=SOLVER_DEFAULT, + external_solver_name=False, + ): """ Return a XOR differential trail with the lowest weight in standard format, i.e. the solver solution. By default, the search is set in the single-key setting. @@ -487,12 +525,18 @@ def find_lowest_weight_xor_differential_trail(self, fixed_values=[], weight_prec end = time.time() building_time = end - start solution = self.solve(MILP_XOR_DIFFERENTIAL, solver_name, external_solver_name) - solution['building_time'] = building_time - solution['test_name'] = "find_lowest_weight_xor_differential_trail" + solution["building_time"] = building_time + solution["test_name"] = "find_lowest_weight_xor_differential_trail" return solution - def find_one_xor_differential_trail(self, fixed_values=[], weight_precision=MILP_DEFAULT_WEIGHT_PRECISION, solver_name=SOLVER_DEFAULT, external_solver_name=None): + def find_one_xor_differential_trail( + self, + fixed_values=[], + weight_precision=MILP_DEFAULT_WEIGHT_PRECISION, + solver_name=SOLVER_DEFAULT, + external_solver_name=None, + ): """ Return a XOR differential trail, not necessarily the one with the lowest weight. By default, the search is set in the single-key setting. @@ -536,13 +580,19 @@ def find_one_xor_differential_trail(self, fixed_values=[], weight_precision=MILP end = time.time() building_time = end - start solution = self.solve(MILP_XOR_DIFFERENTIAL, solver_name, external_solver_name) - solution['building_time'] = building_time - solution['test_name'] = "find_one_xor_differential_trail" + solution["building_time"] = building_time + solution["test_name"] = "find_one_xor_differential_trail" return solution - def find_one_xor_differential_trail_with_fixed_weight(self, fixed_weight, fixed_values=[], weight_precision=MILP_DEFAULT_WEIGHT_PRECISION, - solver_name=SOLVER_DEFAULT, external_solver_name=None): + def find_one_xor_differential_trail_with_fixed_weight( + self, + fixed_weight, + fixed_values=[], + weight_precision=MILP_DEFAULT_WEIGHT_PRECISION, + solver_name=SOLVER_DEFAULT, + external_solver_name=None, + ): """ Return one XOR differential trail with weight equal to ``fixed_weight`` as a list in standard format. By default, the search is set in the single-key setting. @@ -594,8 +644,8 @@ def find_one_xor_differential_trail_with_fixed_weight(self, fixed_weight, fixed_ end = time.time() building_time = end - start solution = self.solve(MILP_XOR_DIFFERENTIAL, solver_name, external_solver_name) - solution['building_time'] = building_time - solution['test_name'] = "find_one_xor_differential_trail_with_fixed_weight" + solution["building_time"] = building_time + solution["test_name"] = "find_one_xor_differential_trail_with_fixed_weight" return solution @@ -603,23 +653,29 @@ def _get_fixed_variables_from_solution(self, fixed_values, inputs_ids, solution) fixed_variables = [] for input in inputs_ids: input_bit_size = self._cipher.inputs_bit_size[self._cipher.inputs.index(input)] - fixed_variable = {"component_id": input, - "bit_positions": list(range(input_bit_size)), - "constraint_type": "not_equal", - "bit_values": (integer_to_bit_list( - BitArray(solution["components_values"][input]["value"]).int, - input_bit_size, 'big'))} + fixed_variable = { + "component_id": input, + "bit_positions": list(range(input_bit_size)), + "constraint_type": "not_equal", + "bit_values": ( + integer_to_bit_list( + BitArray(solution["components_values"][input]["value"]).int, input_bit_size, "big" + ) + ), + } _filter_fixed_variables(fixed_values, fixed_variable, input) fixed_variables.append(fixed_variable) for component in self._cipher.get_all_components(): output_bit_size = component.output_bit_size - fixed_variable = {"component_id": component.id, - "bit_positions": list(range(output_bit_size)), - "constraint_type": "not_equal", - "bit_values": integer_to_bit_list( - BitArray(solution["components_values"][component.id]["value"]).int, - output_bit_size, 'big')} + fixed_variable = { + "component_id": component.id, + "bit_positions": list(range(output_bit_size)), + "constraint_type": "not_equal", + "bit_values": integer_to_bit_list( + BitArray(solution["components_values"][component.id]["value"]).int, output_bit_size, "big" + ), + } _filter_fixed_variables(fixed_values, fixed_variable, component.id) fixed_variables.append(fixed_variable) @@ -629,22 +685,20 @@ def _get_component_values(self, objective_variables, components_variables): components_values = {} list_component_ids = self._cipher.inputs + self._cipher.get_all_components_ids() for component_id in list_component_ids: - dict_tmp = self._get_component_value_weight(component_id, - objective_variables, components_variables) + dict_tmp = self._get_component_value_weight(component_id, objective_variables, components_variables) components_values[component_id] = dict_tmp return components_values def _parse_solver_output(self): mip = self._model objective_variables = mip.get_values(self._integer_variable) - objective_value = objective_variables[MILP_XOR_DIFFERENTIAL_OBJECTIVE] / float(10 ** self._weight_precision) + objective_value = objective_variables[MILP_XOR_DIFFERENTIAL_OBJECTIVE] / float(10**self._weight_precision) components_variables = mip.get_values(self._binary_variable) components_values = self._get_component_values(objective_variables, components_variables) return objective_value, components_values def _get_component_value_weight(self, component_id, probability_variables, components_variables): - if component_id in self._cipher.inputs: output_size = self._cipher.inputs_bit_size[self._cipher.inputs.index(component_id)] else: @@ -657,15 +711,16 @@ def _get_component_value_weight(self, component_id, probability_variables, compo return final_output - def _get_final_output(self, component_id, components_variables, probability_variables, - suffix_dict): + def _get_final_output(self, component_id, components_variables, probability_variables, suffix_dict): final_output = [] for suffix in suffix_dict.keys(): diff_str = _get_variables_values_as_string(component_id, components_variables, suffix, suffix_dict[suffix]) difference = _string_to_hex(diff_str) weight = 0 if component_id + MILP_PROBABILITY_SUFFIX in probability_variables: - weight = probability_variables[component_id + MILP_PROBABILITY_SUFFIX] / float(10 ** self._weight_precision) + weight = probability_variables[component_id + MILP_PROBABILITY_SUFFIX] / float( + 10**self._weight_precision + ) final_output.append(set_component_solution(value=difference, weight=weight)) return final_output diff --git a/claasp/cipher_modules/models/milp/milp_models/milp_xor_linear_model.py b/claasp/cipher_modules/models/milp/milp_models/milp_xor_linear_model.py index 4f35dd325..de825de4d 100644 --- a/claasp/cipher_modules/models/milp/milp_models/milp_xor_linear_model.py +++ b/claasp/cipher_modules/models/milp/milp_models/milp_xor_linear_model.py @@ -1,16 +1,16 @@ # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -23,30 +23,54 @@ import numpy as np from bitstring import BitArray -from claasp.cipher_modules.models.milp.solvers import SOLVER_DEFAULT -from claasp.cipher_modules.models.milp.utils.generate_inequalities_for_xor_with_n_input_bits import \ - update_dictionary_that_contains_xor_inequalities_between_n_input_bits, \ - output_dictionary_that_contains_xor_inequalities from claasp.cipher_modules.models.milp.milp_model import MilpModel -from claasp.cipher_modules.models.milp.utils.milp_name_mappings import MILP_XOR_LINEAR, MILP_PROBABILITY_SUFFIX, \ - MILP_BUILDING_MESSAGE, MILP_XOR_LINEAR_OBJECTIVE, MILP_DEFAULT_WEIGHT_PRECISION -from claasp.cipher_modules.models.milp.utils.utils import _get_variables_values_as_string, _string_to_hex, \ - _filter_fixed_variables, _set_weight_precision -from claasp.cipher_modules.models.utils import get_bit_bindings, set_fixed_variables, integer_to_bit_list, \ - set_component_solution, get_single_key_scenario_format_for_fixed_values -from claasp.name_mappings import (INTERMEDIATE_OUTPUT, CONSTANT, CIPHER_OUTPUT, LINEAR_LAYER, SBOX, MIX_COLUMN, - WORD_OPERATION, INPUT_KEY) +from claasp.cipher_modules.models.milp.solvers import SOLVER_DEFAULT +from claasp.cipher_modules.models.milp.utils.generate_inequalities_for_xor_with_n_input_bits import ( + output_dictionary_that_contains_xor_inequalities, + update_dictionary_that_contains_xor_inequalities_between_n_input_bits, +) +from claasp.cipher_modules.models.milp.utils.milp_name_mappings import ( + MILP_BUILDING_MESSAGE, + MILP_DEFAULT_WEIGHT_PRECISION, + MILP_PROBABILITY_SUFFIX, + MILP_XOR_LINEAR_OBJECTIVE, + MILP_XOR_LINEAR, +) +from claasp.cipher_modules.models.milp.utils.utils import ( + _filter_fixed_variables, + _get_variables_values_as_string, + _set_weight_precision, + _string_to_hex, +) +from claasp.cipher_modules.models.utils import ( + get_bit_bindings, + get_single_key_scenario_format_for_fixed_values, + integer_to_bit_list, + set_component_solution, +) +from claasp.name_mappings import ( + CIPHER_OUTPUT, + CONSTANT, + INPUT_KEY, + INTERMEDIATE_OUTPUT, + LINEAR_LAYER, + MIX_COLUMN, + SATISFIABLE, + SBOX, + WORD_OPERATION, +) class MilpXorLinearModel(MilpModel): def __init__(self, cipher, n_window_heuristic=None, verbose=False): super().__init__(cipher, n_window_heuristic, verbose) - self.bit_bindings, self.bit_bindings_for_intermediate_output = get_bit_bindings(cipher, '_'.join) + self.bit_bindings, self.bit_bindings_for_intermediate_output = get_bit_bindings(cipher, "_".join) self._weight_precision = MILP_DEFAULT_WEIGHT_PRECISION self._has_non_integer_weight = False - def add_constraints_to_build_in_sage_milp_class(self, weight=-1, weight_precision=MILP_DEFAULT_WEIGHT_PRECISION, - fixed_variables=[]): + def add_constraints_to_build_in_sage_milp_class( + self, weight=-1, weight_precision=MILP_DEFAULT_WEIGHT_PRECISION, fixed_variables=[] + ): """ Take the constraints contained in self._model_constraints and add them to the build-in sage class. @@ -83,8 +107,10 @@ def add_constraints_to_build_in_sage_milp_class(self, weight=-1, weight_precisio p = self._integer_variable for constraint in self._model_constraints: mip.add_constraint(constraint) - mip.add_constraint(p[MILP_XOR_LINEAR_OBJECTIVE] == sum( - p[self._non_linear_component_id[i] + "_probability"] for i in range(len(self._non_linear_component_id)))) + mip.add_constraint( + p[MILP_XOR_LINEAR_OBJECTIVE] + == sum(p[f"{component_id}_probability"] for component_id in self._non_linear_component_id) + ) def branch_xor_linear_constraints(self): """ @@ -137,8 +163,9 @@ def branch_xor_linear_constraints(self): constraints.append(constraint >= 2 * x[f"{output_var}_dummy"]) # more than a 3-way fork as in SIMON else: - self.update_xor_linear_constraints_for_more_than_two_bits(constraints, input_vars, number_of_inputs, - output_var, x) + self.update_xor_linear_constraints_for_more_than_two_bits( + constraints, input_vars, number_of_inputs, output_var, x + ) return constraints @@ -166,21 +193,28 @@ def build_xor_linear_trail_model(self, weight=-1, fixed_variables=[]): variables = [] if INPUT_KEY not in [variable["component_id"] for variable in fixed_variables]: self._cipher = self._cipher.remove_key_schedule() - self.bit_bindings, self.bit_bindings_for_intermediate_output = get_bit_bindings(self.cipher, '_'.join) + self.bit_bindings, self.bit_bindings_for_intermediate_output = get_bit_bindings(self.cipher, "_".join) if fixed_variables == []: fixed_variables = get_single_key_scenario_format_for_fixed_values(self._cipher) constraints = self.fix_variables_value_xor_linear_constraints(fixed_variables) self._model_constraints = constraints for component in self._cipher.get_all_components(): - component_types = [CONSTANT, INTERMEDIATE_OUTPUT, CIPHER_OUTPUT, LINEAR_LAYER, - SBOX, MIX_COLUMN, WORD_OPERATION] + component_types = [ + CONSTANT, + INTERMEDIATE_OUTPUT, + CIPHER_OUTPUT, + LINEAR_LAYER, + SBOX, + MIX_COLUMN, + WORD_OPERATION, + ] operation = component.description[0] - operation_types = ["AND", "MODADD", "NOT", "ROTATE", "SHIFT", "XOR", "OR", "MODSUB"] + operation_types = ("AND", "MODADD", "NOT", "ROTATE", "SHIFT", "XOR", "OR", "MODSUB") if component.type in component_types and (component.type != WORD_OPERATION or operation in operation_types): variables, constraints = component.milp_xor_linear_mask_propagation_constraints(self) else: - print(f'{component.id} not yet implemented') + print(f"{component.id} not yet implemented") self._variables_list.extend(variables) self._model_constraints.extend(constraints) @@ -242,25 +276,32 @@ def exclude_variables_value_xor_linear_constraints(self, fixed_variables=[]): if fixed_variable["constraint_type"] == "not_equal": for index, bit_position in enumerate(fixed_variable["bit_positions"]): if fixed_variable["bit_values"][index]: - constraints.append(x[component_id + str(bit_position) + '_o' + "_not_equal_" + str( - self._number_of_trails_found)] == 1 - x[component_id + '_' + str(bit_position) + '_o']) + constraints.append( + x[f"{component_id}{bit_position}_o_not_equal_{self._number_of_trails_found}"] + == 1 - x[f"{component_id}_{bit_position}_o"] + ) else: - constraints.append(x[component_id + str(bit_position) + '_o' + "_not_equal_" + - str(self._number_of_trails_found)] == - x[component_id + '_' + str(bit_position) + '_o']) + constraints.append( + x[f"{component_id}{bit_position}_o_not_equal_{self._number_of_trails_found}"] + == x[f"{component_id}_{bit_position}_o"] + ) var_sum = 0 for fixed_variable in fixed_variables: for i in fixed_variable["bit_positions"]: - var_sum += x[ - fixed_variable["component_id"] + str(i) + '_o' + "_not_equal_" + str(self._number_of_trails_found)] + var_sum += x[f"{fixed_variable['component_id']}{i}_o_not_equal_{self._number_of_trails_found}"] constraints.append(var_sum >= 1) return constraints - def find_all_xor_linear_trails_with_fixed_weight(self, fixed_weight, fixed_values=[], - weight_precision=MILP_DEFAULT_WEIGHT_PRECISION, - solver_name=SOLVER_DEFAULT, external_solver_name=None): + def find_all_xor_linear_trails_with_fixed_weight( + self, + fixed_weight, + fixed_values=[], + weight_precision=MILP_DEFAULT_WEIGHT_PRECISION, + solver_name=SOLVER_DEFAULT, + external_solver_name=None, + ): """ Return all the XOR linear trails with weight equal to ``fixed_weight`` as a solutions list in standard format. By default, the search removes the key schedule, if any. @@ -325,12 +366,12 @@ def find_all_xor_linear_trails_with_fixed_weight(self, fixed_weight, fixed_value looking_for_other_solutions = 1 while looking_for_other_solutions: try: - f = open(os.devnull, 'w') + f = open(os.devnull, "w") sys.stdout = f solution = self.solve(MILP_XOR_LINEAR, solver_name, external_solver_name) sys.stdout = sys.__stdout__ - solution['building_time'] = building_time - solution['test_name'] = "find_all_xor_linear_trails_with_fixed_weight" + solution["building_time"] = building_time + solution["test_name"] = "find_all_xor_linear_trails_with_fixed_weight" self._number_of_trails_found += 1 self._verbose_print(f"trails found : {self._number_of_trails_found}") list_trails.append(solution) @@ -351,11 +392,17 @@ def find_all_xor_linear_trails_with_fixed_weight(self, fixed_weight, fixed_value self._number_of_trails_found = 0 - return [trail for trail in list_trails if trail['status'] == 'SATISFIABLE'] - - def find_all_xor_linear_trails_with_weight_at_most(self, min_weight, max_weight, fixed_values=[], - weight_precision=MILP_DEFAULT_WEIGHT_PRECISION, - solver_name=SOLVER_DEFAULT, external_solver_name=None): + return [trail for trail in list_trails if trail["status"] == SATISFIABLE] + + def find_all_xor_linear_trails_with_weight_at_most( + self, + min_weight, + max_weight, + fixed_values=[], + weight_precision=MILP_DEFAULT_WEIGHT_PRECISION, + solver_name=SOLVER_DEFAULT, + external_solver_name=None, + ): """ Return all XOR linear trails with weight greater than ``min_weight`` and lower than or equal to ``max_weight``. By default, the search removes the key schedule, if any. @@ -424,12 +471,12 @@ def find_all_xor_linear_trails_with_weight_at_most(self, min_weight, max_weight, number_new_constraints = len(weight_constraints) while looking_for_other_solutions: try: - f = open(os.devnull, 'w') + f = open(os.devnull, "w") sys.stdout = f solution = self.solve(MILP_XOR_LINEAR, solver_name, external_solver_name) sys.stdout = sys.__stdout__ - solution['building_time'] = building_time - solution['test_name'] = "find_all_xor_linear_trails_with_weight_at_most" + solution["building_time"] = building_time + solution["test_name"] = "find_all_xor_linear_trails_with_weight_at_most" self._number_of_trails_found += 1 self._verbose_print(f"trails found : {self._number_of_trails_found}") list_trails.append(solution) @@ -447,10 +494,15 @@ def find_all_xor_linear_trails_with_weight_at_most(self, min_weight, max_weight, mip.remove_constraints(range(number_constraints - number_new_constraints, number_constraints)) self._number_of_trails_found = 0 - return [trail for trail in list_trails if trail['status'] == 'SATISFIABLE'] + return [trail for trail in list_trails if trail["status"] == SATISFIABLE] - def find_lowest_weight_xor_linear_trail(self, fixed_values=[], weight_precision=MILP_DEFAULT_WEIGHT_PRECISION, - solver_name=SOLVER_DEFAULT, external_solver_name=None): + def find_lowest_weight_xor_linear_trail( + self, + fixed_values=[], + weight_precision=MILP_DEFAULT_WEIGHT_PRECISION, + solver_name=SOLVER_DEFAULT, + external_solver_name=None, + ): """ Return a XOR linear trail with the lowest weight in standard format, i.e. the solver solution. By default, the search removes the key schedule, if any. @@ -518,13 +570,18 @@ def find_lowest_weight_xor_linear_trail(self, fixed_values=[], weight_precision= end = time.time() building_time = end - start solution = self.solve(MILP_XOR_LINEAR, solver_name, external_solver_name) - solution['building_time'] = building_time - solution['test_name'] = "find_lowest_weight_xor_linear_trail" + solution["building_time"] = building_time + solution["test_name"] = "find_lowest_weight_xor_linear_trail" return solution - def find_one_xor_linear_trail(self, fixed_values=[], weight_precision=MILP_DEFAULT_WEIGHT_PRECISION, - solver_name=SOLVER_DEFAULT, external_solver_name=None): + def find_one_xor_linear_trail( + self, + fixed_values=[], + weight_precision=MILP_DEFAULT_WEIGHT_PRECISION, + solver_name=SOLVER_DEFAULT, + external_solver_name=None, + ): """ Return a XOR linear trail, not necessarily the one with the lowest weight. By default, the search removes the key schedule, if any. @@ -568,14 +625,19 @@ def find_one_xor_linear_trail(self, fixed_values=[], weight_precision=MILP_DEFAU end = time.time() building_time = end - start solution = self.solve(MILP_XOR_LINEAR, solver_name, external_solver_name) - solution['building_time'] = building_time - solution['test_name'] = "find_lowest_weight_xor_linear_trail" + solution["building_time"] = building_time + solution["test_name"] = "find_lowest_weight_xor_linear_trail" return solution - def find_one_xor_linear_trail_with_fixed_weight(self, fixed_weight, fixed_values=[], - weight_precision=MILP_DEFAULT_WEIGHT_PRECISION, - solver_name=SOLVER_DEFAULT, external_solver_name=None): + def find_one_xor_linear_trail_with_fixed_weight( + self, + fixed_weight, + fixed_values=[], + weight_precision=MILP_DEFAULT_WEIGHT_PRECISION, + solver_name=SOLVER_DEFAULT, + external_solver_name=None, + ): """ Return one XOR linear trail with weight equal to ``fixed_weight`` as a list in standard format. By default, the search removes the key schedule, if any. @@ -627,8 +689,8 @@ def find_one_xor_linear_trail_with_fixed_weight(self, fixed_weight, fixed_values end = time.time() building_time = end - start solution = self.solve(MILP_XOR_LINEAR, solver_name, external_solver_name) - solution['building_time'] = building_time - solution['test_name'] = "find_one_xor_linear_trail_with_fixed_weight" + solution["building_time"] = building_time + solution["test_name"] = "find_one_xor_linear_trail_with_fixed_weight" return solution @@ -677,22 +739,26 @@ def fix_variables_value_xor_linear_constraints(self, fixed_variables=[]): component_id = fixed_variable["component_id"] if fixed_variable["constraint_type"] == "equal": for index, bit_position in enumerate(fixed_variable["bit_positions"]): - constraints.append(x[component_id + '_' + str(bit_position) + '_o'] - == fixed_variable["bit_values"][index]) + constraints.append(x[f"{component_id}_{bit_position}_o"] == fixed_variable["bit_values"][index]) else: for index, bit_position in enumerate(fixed_variable["bit_positions"]): if fixed_variable["bit_values"][index]: constraints.append( - x[component_id + str(bit_position) + '_o' + "_not_equal_" + - str(self._number_of_trails_found)] == - 1 - x[component_id + '_' + str(bit_position) + '_o']) + x[f"{component_id}{bit_position}_o_not_equal_{self._number_of_trails_found}"] + == 1 - x[f"{component_id}_{bit_position}_o"] + ) else: constraints.append( - x[component_id + str(bit_position) + '_o' + "_not_equal_" + - str(self._number_of_trails_found)] == x[component_id + '_' + str(bit_position) + '_o']) - constraints.append(sum( - x[component_id + str(i) + '_o' + "_not_equal_" + str(self._number_of_trails_found)] for i in - fixed_variable["bit_positions"]) >= 1) + x[f"{component_id}{bit_position}_o_not_equal_{self._number_of_trails_found}"] + == x[f"{component_id}_{bit_position}_o"] + ) + constraints.append( + sum( + x[f"{component_id}{i}_o_not_equal_{self._number_of_trails_found}"] + for i in fixed_variable["bit_positions"] + ) + >= 1 + ) return constraints @@ -700,29 +766,34 @@ def _get_fixed_variables_from_solution(self, fixed_values, inputs_ids, solution) fixed_variables = [] for input in inputs_ids: input_bit_size = self._cipher.inputs_bit_size[self._cipher.inputs.index(input)] - fixed_variable = {"component_id": input, - "bit_positions": list(range(input_bit_size)), - "constraint_type": "not_equal", - "bit_values": integer_to_bit_list( - BitArray(solution["components_values"][input]["value"]).int, - input_bit_size, 'big')} + fixed_variable = { + "component_id": input, + "bit_positions": list(range(input_bit_size)), + "constraint_type": "not_equal", + "bit_values": integer_to_bit_list( + BitArray(solution["components_values"][input]["value"]).int, input_bit_size, "big" + ), + } _filter_fixed_variables(fixed_values, fixed_variable, input) fixed_variables.append(fixed_variable) for component in self._cipher.get_all_components(): output_bit_size = component.output_bit_size - fixed_variable = {"component_id": component.id, - "bit_positions": list(range(output_bit_size)), - "constraint_type": "not_equal", - "bit_values": integer_to_bit_list( - BitArray(solution["components_values"][component.id + "_o"]["value"]).int, - output_bit_size, 'big')} + fixed_variable = { + "component_id": component.id, + "bit_positions": list(range(output_bit_size)), + "constraint_type": "not_equal", + "bit_values": integer_to_bit_list( + BitArray(solution["components_values"][f"{component.id}_o"]["value"]).int, output_bit_size, "big" + ), + } _filter_fixed_variables(fixed_values, fixed_variable, component.id) fixed_variables.append(fixed_variable) return fixed_variables - def update_xor_linear_constraints_for_more_than_two_bits(self, constraints, input_vars, - number_of_inputs, output_var, x): + def update_xor_linear_constraints_for_more_than_two_bits( + self, constraints, input_vars, number_of_inputs, output_var, x + ): update_dictionary_that_contains_xor_inequalities_between_n_input_bits(number_of_inputs) dict_inequalities = output_dictionary_that_contains_xor_inequalities() inequalities = dict_inequalities[number_of_inputs] @@ -774,28 +845,26 @@ def _get_component_values(self, objective_variables, components_variables): components_values = {} list_component_ids = self._cipher.inputs + self._cipher.get_all_components_ids() for component_id in list_component_ids: - dict_tmp = self._get_component_value_weight(component_id, - objective_variables, components_variables) + dict_tmp = self._get_component_value_weight(component_id, objective_variables, components_variables) if component_id in self._cipher.inputs: components_values[component_id] = dict_tmp[1] - elif 'cipher_output' not in component_id: - components_values[component_id + '_i'] = dict_tmp[0] - components_values[component_id + '_o'] = dict_tmp[1] + elif "cipher_output" not in component_id: + components_values[f"{component_id}_i"] = dict_tmp[0] + components_values[f"{component_id}_o"] = dict_tmp[1] else: - components_values[component_id + '_o'] = dict_tmp[1] + components_values[f"{component_id}_o"] = dict_tmp[1] return components_values def _parse_solver_output(self): mip = self._model objective_variables = mip.get_values(self._integer_variable) - objective_value = objective_variables[MILP_XOR_LINEAR_OBJECTIVE] / float(10 ** self._weight_precision) + objective_value = objective_variables[MILP_XOR_LINEAR_OBJECTIVE] / float(10**self._weight_precision) components_variables = mip.get_values(self._binary_variable) components_values = self._get_component_values(objective_variables, components_variables) return objective_value, components_values def _get_component_value_weight(self, component_id, probability_variables, components_variables): - if component_id in self._cipher.inputs: output_size = self._cipher.inputs_bit_size[self._cipher.inputs.index(component_id)] input_size = output_size @@ -810,18 +879,17 @@ def _get_component_value_weight(self, component_id, probability_variables, compo return final_output - def _get_final_output(self, component_id, components_variables, probability_variables, - suffix_dict): + def _get_final_output(self, component_id, components_variables, probability_variables, suffix_dict): final_output = [] for suffix in suffix_dict.keys(): mask_str = _get_variables_values_as_string(component_id, components_variables, suffix, suffix_dict[suffix]) mask = _string_to_hex(mask_str) bias = 0 if component_id + MILP_PROBABILITY_SUFFIX in probability_variables: - bias = probability_variables[component_id + MILP_PROBABILITY_SUFFIX] / float(10 ** self._weight_precision) + bias = probability_variables[component_id + MILP_PROBABILITY_SUFFIX] / float(10**self._weight_precision) final_output.append(set_component_solution(mask, bias, sign=1)) return final_output @property def weight_precision(self): - return self._weight_precision \ No newline at end of file + return self._weight_precision diff --git a/claasp/cipher_modules/models/milp/solvers.py b/claasp/cipher_modules/models/milp/solvers.py index e8304330a..5e2fa6ea3 100644 --- a/claasp/cipher_modules/models/milp/solvers.py +++ b/claasp/cipher_modules/models/milp/solvers.py @@ -1,33 +1,37 @@ -import os - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** +import os SOLVER_DEFAULT = "GLPK" MODEL_DEFAULT_PATH = os.getcwd() - MILP_SOLVERS_INTERNAL = [ {"solver_brand_name": "GLPK (GNU Linear Programming Kit) (using Sage backend)", "solver_name": "GLPK"}, - {"solver_brand_name": "GLPK (GNU Linear Programming Kit) with simplex method based on exact arithmetic (using Sage backend)", "solver_name": "GLPK/exact"}, + { + "solver_brand_name": "GLPK (GNU Linear Programming Kit) with simplex method based on exact arithmetic (using Sage backend)", + "solver_name": "GLPK/exact", + }, {"solver_brand_name": "COIN-BC (COIN Branch and Cut) (using Sage backend)", "solver_name": "Coin"}, - {"solver_brand_name": "CVXOPT (Python Software for Convex Optimization) (using Sage backend)", "solver_name": "CVXOPT"}, + { + "solver_brand_name": "CVXOPT (Python Software for Convex Optimization) (using Sage backend)", + "solver_name": "CVXOPT", + }, {"solver_brand_name": "Gurobi Optimizer (using Sage backend)", "solver_name": "Gurobi"}, {"solver_brand_name": "PPL (Parma Polyhedra Library) (using Sage backend)", "solver_name": "PPL"}, {"solver_brand_name": "InteractiveLP (using Sage backend)", "solver_name": "InteractiveLP"}, diff --git a/claasp/cipher_modules/models/milp/utils/generate_inequalities_for_and_operation_2_input_bits.py b/claasp/cipher_modules/models/milp/utils/generate_inequalities_for_and_operation_2_input_bits.py index 5af3f6af1..6d3ae646c 100644 --- a/claasp/cipher_modules/models/milp/utils/generate_inequalities_for_and_operation_2_input_bits.py +++ b/claasp/cipher_modules/models/milp/utils/generate_inequalities_for_and_operation_2_input_bits.py @@ -1,23 +1,23 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** """The target of this module is to generate MILP inequalities for a AND operation between 2 input bits.""" + from claasp.cipher_modules.models.milp.solvers import SOLVER_DEFAULT @@ -37,13 +37,7 @@ def and_inequalities(): def and_LAT(): - valid_points = [ - [0, 0, 0], - [0, 0, 1], - [0, 1, 1], - [1, 0, 1], - [1, 1, 1] - ] + valid_points = [[0, 0, 0], [0, 0, 1], [0, 1, 1], [1, 0, 1], [1, 1, 1]] chosen_ineqs = cutting_off_greedy(valid_points) return chosen_ineqs @@ -78,15 +72,11 @@ def cutting_off_greedy(valid_points): poly = convex_hull(valid_points) poly_points = poly.integral_points() remaining_ineqs = list(poly.inequalities()) - impossible = [vector(poly.base_ring(), v) - for v in VectorSpace(GF(2), poly.ambient_dim()) - if v not in poly_points] + impossible = [vector(poly.base_ring(), v) for v in VectorSpace(GF(2), poly.ambient_dim()) if v not in poly_points] while impossible != []: - if len(remaining_ineqs) == 0: - raise ValueError("no more inequalities to choose, but still " - "%d impossible points left" % len(impossible)) + raise ValueError("no more inequalities to choose, but still %d impossible points left" % len(impossible)) # find inequality in remaining_ineqs that cuts off the most # impossible points and add this to the chosen_ineqs @@ -100,10 +90,7 @@ def cutting_off_greedy(valid_points): remaining_ineqs.remove(chosen_ineqs[-1]) # remove all cut off impossible points - impossible = [v - for v in impossible - if chosen_ineqs[-1].contains(v) - ] + impossible = [v for v in impossible if chosen_ineqs[-1].contains(v)] return chosen_ineqs @@ -132,16 +119,10 @@ def cutting_off_milp(valid_points, number_of_ineqs=None): poly = convex_hull(valid_points) ineqs = list(poly.inequalities()) poly_points = poly.integral_points() - impossible = [vector(poly.base_ring(), v) - for v in VectorSpace(GF(2), poly.ambient_dim()) - if v not in poly_points] + impossible = [vector(poly.base_ring(), v) for v in VectorSpace(GF(2), poly.ambient_dim()) if v not in poly_points] # precompute which inequality removes which impossible point - precomputation = matrix( - [[int(not (ineq.contains(p))) - for p in impossible] - for ineq in ineqs] - ) + precomputation = matrix([[int(not (ineq.contains(p))) for p in impossible] for ineq in ineqs]) milp = MixedIntegerLinearProgram(maximization=False, solver=SOLVER_DEFAULT) var_ineqs = milp.new_variable(binary=True, name="ineqs") @@ -150,26 +131,17 @@ def cutting_off_milp(valid_points, number_of_ineqs=None): milp.set_objective(sum([var_ineqs[i] for i in range(len(ineqs))])) # or the given number else: - milp.add_constraint(sum( - [var_ineqs[i] - for i in range(len(ineqs))] - ) == number_of_ineqs) + milp.add_constraint(sum([var_ineqs[i] for i in range(len(ineqs))]) == number_of_ineqs) nrows, ncols = precomputation.dimensions() for c in range(ncols): - lhs = sum([var_ineqs[r] - for r in range(nrows) - if precomputation[r][c] == 1]) + lhs = sum([var_ineqs[r] for r in range(nrows) if precomputation[r][c] == 1]) # milp.add_constraint(lhs >= 1) - if (not isinstance(lhs, int)): + if not isinstance(lhs, int): milp.add_constraint(lhs >= 1) milp.solve() - remaining_ineqs = [ - ineq - for ineq, (var, val) in zip(ineqs, milp.get_values(var_ineqs).items()) - if val == 1 - ] + remaining_ineqs = [ineq for ineq, (var, val) in zip(ineqs, milp.get_values(var_ineqs).items()) if val == 1] return remaining_ineqs diff --git a/claasp/cipher_modules/models/milp/utils/generate_inequalities_for_large_sboxes.py b/claasp/cipher_modules/models/milp/utils/generate_inequalities_for_large_sboxes.py index 462c1276c..f10492fca 100644 --- a/claasp/cipher_modules/models/milp/utils/generate_inequalities_for_large_sboxes.py +++ b/claasp/cipher_modules/models/milp/utils/generate_inequalities_for_large_sboxes.py @@ -1,17 +1,16 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -22,24 +21,19 @@ The logic minimizer espresso is required for this module. It is already installed in the docker. """ + import pickle, os import subprocess from claasp.cipher_modules.models.milp import MILP_AUXILIARY_FILE_PATH -from sage.rings.integer_ring import ZZ large_sbox_file_name = "dictionary_that_contains_inequalities_for_large_sboxes.obj" large_sbox_xor_linear_file_name = "dictionary_that_contains_inequalities_for_large_sboxes_xor_linear.obj" large_sboxes_inequalities_file_path = os.path.join(MILP_AUXILIARY_FILE_PATH, large_sbox_file_name) -large_sboxes_xor_linear_inequalities_file_path = os.path.join(MILP_AUXILIARY_FILE_PATH, - large_sbox_xor_linear_file_name) +large_sboxes_xor_linear_inequalities_file_path = os.path.join(MILP_AUXILIARY_FILE_PATH, large_sbox_xor_linear_file_name) def generate_espresso_input(input_size, output_size, value, valid_transformations_matrix): - # little_endian - def to_bits(x, size): - return ZZ(x).digits(base=2, padto=size)[::-1] - espresso_input = [f"# there are {input_size + output_size} input variables\n"] espresso_input.append(f".i {input_size + output_size}") espresso_input.append("# there is only 1 output result\n") @@ -49,7 +43,7 @@ def to_bits(x, size): n, m = input_size, output_size for i in range(0, 1 << n): for o in range(0, 1 << m): - io = "".join([str(i) for i in to_bits(i, input_size) + to_bits(o, output_size)]) + io = f"{i:0{input_size}b}{o:0{output_size}b}" if i + o > 0 and valid_transformations_matrix[i][o] == value: espresso_input.append(f"{io} 1\n") else: @@ -58,10 +52,10 @@ def to_bits(x, size): espresso_input.append("# end of the PLA data\n") espresso_input.append(".e") - return ''.join(espresso_input) + return "".join(espresso_input) -def generate_product_of_sum_from_espresso(sbox, analysis="differential"): +def generate_product_of_sum_from_espresso(sbox, analysis="differential"): dict_espresso_outputs = {} if analysis == "differential": valid_transformations_matrix = sbox.difference_distribution_table() @@ -75,9 +69,12 @@ def generate_product_of_sum_from_espresso(sbox, analysis="differential"): raise TypeError("analysis (%s) has to be one of ['differential', 'linear']" % (analysis,)) for value in values_in_matrix: - espresso_input = generate_espresso_input(sbox.input_size(), sbox.output_size(), value, valid_transformations_matrix) - espresso_process = subprocess.run(['espresso', '-epos', '-okiss'], input=espresso_input, - capture_output=True, text=True) + espresso_input = generate_espresso_input( + sbox.input_size(), sbox.output_size(), value, valid_transformations_matrix + ) + espresso_process = subprocess.run( + ["espresso", "-epos", "-okiss"], input=espresso_input, capture_output=True, text=True + ) espresso_output = espresso_process.stdout.splitlines() dict_espresso_outputs[value] = [line[:-2] for line in espresso_output[4:]] @@ -96,20 +93,27 @@ def get_dictionary_that_contains_inequalities_for_large_sboxes(analysis="differe - then Espresso is used to compute the minimum product-of-sum representation of each pb-DDT, seen as a boolean function """ - file_path = large_sboxes_inequalities_file_path if analysis == "differential" else large_sboxes_xor_linear_inequalities_file_path + file_path = ( + large_sboxes_inequalities_file_path + if analysis == "differential" + else large_sboxes_xor_linear_inequalities_file_path + ) - read_file = open(file_path, 'rb') + read_file = open(file_path, "rb") inequalities = pickle.load(read_file) read_file.close() return inequalities def update_dictionary_that_contains_inequalities_for_large_sboxes(sbox, analysis="differential"): - - file_path = large_sboxes_inequalities_file_path if analysis == "differential" else large_sboxes_xor_linear_inequalities_file_path + file_path = ( + large_sboxes_inequalities_file_path + if analysis == "differential" + else large_sboxes_xor_linear_inequalities_file_path + ) try: - read_file = open(file_path, 'rb') + read_file = open(file_path, "rb") dictio = pickle.load(read_file) read_file.close() except OSError: @@ -120,13 +124,17 @@ def update_dictionary_that_contains_inequalities_for_large_sboxes(sbox, analysis dict_product_of_sum = generate_product_of_sum_from_espresso(sbox, analysis) dictio[str(sbox)] = dict_product_of_sum - write_file = open(file_path, 'wb') + write_file = open(file_path, "wb") pickle.dump(dictio, write_file) write_file.close() def delete_dictionary_that_contains_inequalities_for_large_sboxes(analysis="differential"): - file_path = large_sboxes_inequalities_file_path if analysis == "differential" else large_sboxes_xor_linear_inequalities_file_path - write_file = open(file_path, 'wb') + file_path = ( + large_sboxes_inequalities_file_path + if analysis == "differential" + else large_sboxes_xor_linear_inequalities_file_path + ) + write_file = open(file_path, "wb") pickle.dump({}, write_file) write_file.close() diff --git a/claasp/cipher_modules/models/milp/utils/generate_inequalities_for_wordwise_truncated_mds_matrices.py b/claasp/cipher_modules/models/milp/utils/generate_inequalities_for_wordwise_truncated_mds_matrices.py index e8dbc7163..e83598229 100644 --- a/claasp/cipher_modules/models/milp/utils/generate_inequalities_for_wordwise_truncated_mds_matrices.py +++ b/claasp/cipher_modules/models/milp/utils/generate_inequalities_for_wordwise_truncated_mds_matrices.py @@ -22,6 +22,7 @@ using model 5 from https://tosc.iacr.org/index.php/ToSC/article/view/8702/8294 """ + from itertools import product from math import ceil, log import pickle, os @@ -32,10 +33,9 @@ wordwise_truncated_mds_file_path = os.path.join(MILP_AUXILIARY_FILE_PATH, wordwise_truncated_mds_file_name) - -def generate_valid_points_for_truncated_mds_matrix(dimensions=(4,4), max_pattern_value=3): +def generate_valid_points_for_truncated_mds_matrix(dimensions=(4, 4), max_pattern_value=3): """ - Model 5 from https://tosc.iacr.org/index.php/ToSC/article/view/8702/8294 + Model 5 from https://tosc.iacr.org/index.php/ToSC/article/view/8702/8294 """ nrows, ncols = dimensions @@ -53,19 +53,19 @@ def generate_valid_points_for_truncated_mds_matrix(dimensions=(4,4), max_pattern else: delta_output = [3 for _ in range(nrows)] - tmp = ''.join(format(delta[i], '0' + str(bit_len) + 'b') for i in range(ncols)) + \ - ''.join(format(delta_output[i], '0' + str(bit_len) + 'b') for i in range(nrows)) + tmp = "".join(format(delta[i], "0" + str(bit_len) + "b") for i in range(ncols)) + "".join( + format(delta_output[i], "0" + str(bit_len) + "b") for i in range(nrows) + ) valid_points.append(tmp) else: raise NotImplementedError - return valid_points -def update_dictionary_that_contains_wordwise_truncated_mds_inequalities(wordsize=8, dimensions=(4,4)): +def update_dictionary_that_contains_wordwise_truncated_mds_inequalities(wordsize=8, dimensions=(4, 4)): try: - read_file = open(wordwise_truncated_mds_file_path, 'rb') + read_file = open(wordwise_truncated_mds_file_path, "rb") dictio = pickle.load(read_file) read_file.close() except OSError: @@ -76,11 +76,12 @@ def update_dictionary_that_contains_wordwise_truncated_mds_inequalities(wordsize if dimensions not in dictio[wordsize].keys(): print( - f"Adding wordwise mds inequalities for {dimensions[0]} x {dimensions[1]} matrices for words of {wordsize} bits in pre-saved dictionary") + f"Adding wordwise mds inequalities for {dimensions[0]} x {dimensions[1]} matrices for words of {wordsize} bits in pre-saved dictionary" + ) valid_points = generate_valid_points_for_truncated_mds_matrix(dimensions) inequalities = milp_utils.generate_product_of_sum_from_espresso(valid_points) dictio[wordsize][dimensions] = inequalities - write_file = open(wordwise_truncated_mds_file_path, 'wb') + write_file = open(wordwise_truncated_mds_file_path, "wb") pickle.dump(dictio, write_file) write_file.close() diff --git a/claasp/cipher_modules/models/milp/utils/generate_inequalities_for_wordwise_truncated_xor_with_n_input_bits.py b/claasp/cipher_modules/models/milp/utils/generate_inequalities_for_wordwise_truncated_xor_with_n_input_bits.py index fe9f80f84..9c33a1d0e 100644 --- a/claasp/cipher_modules/models/milp/utils/generate_inequalities_for_wordwise_truncated_xor_with_n_input_bits.py +++ b/claasp/cipher_modules/models/milp/utils/generate_inequalities_for_wordwise_truncated_xor_with_n_input_bits.py @@ -21,6 +21,7 @@ The target of this module is to generate MILP inequalities for a wordwise truncated XOR operation between n input words. """ + import itertools from math import ceil, log import pickle, os @@ -31,21 +32,24 @@ input_patterns_file_name = "dictionary_containing_truncated_input_pattern_inequalities.obj" xor_n_inputs_file_name = "dictionary_containing_truncated_xor_inequalities_between_n_input_bits.obj" -wordwise_truncated_input_pattern_inequalities_file_path = os.path.join(MILP_AUXILIARY_FILE_PATH, input_patterns_file_name) -wordwise_truncated_xor_inequalities_between_n_input_bits_file_path = os.path.join(MILP_AUXILIARY_FILE_PATH, xor_n_inputs_file_name) - +wordwise_truncated_input_pattern_inequalities_file_path = os.path.join( + MILP_AUXILIARY_FILE_PATH, input_patterns_file_name +) +wordwise_truncated_xor_inequalities_between_n_input_bits_file_path = os.path.join( + MILP_AUXILIARY_FILE_PATH, xor_n_inputs_file_name +) def generate_valid_points_input_words(wordsize=4, max_pattern_value=3): """ - Model 1 from https://tosc.iacr.org/index.php/ToSC/article/view/8702/8294 + Model 1 from https://tosc.iacr.org/index.php/ToSC/article/view/8702/8294 - delta | zeta - ------------ - 0 | Z (0) - 1 | N (> 0) - 2 | N* - 3 | U + delta | zeta + ------------ + 0 | Z (0) + 1 | N (> 0) + 2 | N* + 3 | U """ @@ -54,14 +58,10 @@ def generate_valid_points_input_words(wordsize=4, max_pattern_value=3): valid_points = [] if max_pattern_value == 3: - - list_of_possible_inputs = [(0, 0)] + \ - [(1, i) for i in range(1, 1 << wordsize)] + \ - [(2, 0)] + [(3, 0)] + list_of_possible_inputs = [(0, 0)] + [(1, i) for i in range(1, 1 << wordsize)] + [(2, 0)] + [(3, 0)] for delta, zeta in list_of_possible_inputs: - tmp = ''.join(format(delta, '0' + str(bit_len) + 'b') + - format(zeta, '0' + str(wordsize) + 'b')) + tmp = "".join(format(delta, "0" + str(bit_len) + "b") + format(zeta, "0" + str(wordsize) + "b")) valid_points.append(tmp) else: raise NotImplementedError @@ -71,7 +71,7 @@ def generate_valid_points_input_words(wordsize=4, max_pattern_value=3): def update_dictionary_that_contains_wordwise_truncated_input_inequalities(wordsize): try: - read_file = open(wordwise_truncated_input_pattern_inequalities_file_path, 'rb') + read_file = open(wordwise_truncated_input_pattern_inequalities_file_path, "rb") dictio = pickle.load(read_file) read_file.close() except OSError: @@ -82,7 +82,7 @@ def update_dictionary_that_contains_wordwise_truncated_input_inequalities(wordsi valid_points = generate_valid_points_input_words(wordsize) inequalities = generate_product_of_sum_from_espresso(valid_points) dictio[wordsize] = inequalities - write_file = open(wordwise_truncated_input_pattern_inequalities_file_path, 'wb') + write_file = open(wordwise_truncated_input_pattern_inequalities_file_path, "wb") pickle.dump(dictio, write_file) write_file.close() @@ -96,7 +96,6 @@ def delete_dictionary_that_contains_wordwise_truncated_input_inequalities(): def get_valid_points_for_wordwise_xor(delta_in_1, zeta_in_1, delta_in_2, zeta_in_2): - zeta_out = 0 if delta_in_1 + delta_in_2 > 2: delta_out = 3 @@ -117,43 +116,42 @@ def get_valid_points_for_wordwise_xor(delta_in_1, zeta_in_1, delta_in_2, zeta_in return delta_out, zeta_out + def generate_valid_points_for_xor_between_n_input_words(wordsize=4, number_of_words=2): """ - Model 2 from https://tosc.iacr.org/index.php/ToSC/article/view/8702/8294 - - For the wordwise truncated xor between two inputs, the file is: - - # there are 6 input variables - .i 6# there is only 1 output result - .o 1 - # the following is the truth table - 000000 1 - 000101 1 - 001010 1 - 001111 1 - 010001 1 - 010100 1 - 010101 1 - 011011 1 - 011111 1 - 100010 1 - 100111 1 - 101011 1 - 101111 1 - 110011 1 - 110111 1 - 111011 1 - 111111 1 - # end of the PLA data - .e + Model 2 from https://tosc.iacr.org/index.php/ToSC/article/view/8702/8294 + + For the wordwise truncated xor between two inputs, the file is: + + # there are 6 input variables + .i 6# there is only 1 output result + .o 1 + # the following is the truth table + 000000 1 + 000101 1 + 001010 1 + 001111 1 + 010001 1 + 010100 1 + 010101 1 + 011011 1 + 011111 1 + 100010 1 + 100111 1 + 101011 1 + 101111 1 + 110011 1 + 110111 1 + 111011 1 + 111111 1 + # end of the PLA data + .e """ bit_len = 2 valid_points = [] - list_of_possible_inputs = [(0, 0)] + \ - [(1, i) for i in range(1, 1 << wordsize)] + \ - [(2, -1)] + [(3, -2)] + list_of_possible_inputs = [(0, 0)] + [(1, i) for i in range(1, 1 << wordsize)] + [(2, -1)] + [(3, -2)] for input in itertools.product(list_of_possible_inputs, repeat=number_of_words): delta = [input[_][0] for _ in range(number_of_words)] @@ -165,10 +163,9 @@ def generate_valid_points_for_xor_between_n_input_words(wordsize=4, number_of_wo tmp_zeta[0] = zeta[0] for summand in range(number_of_words - 2): - tmp_delta[summand + 1], tmp_zeta[summand + 1] = get_valid_points_for_wordwise_xor(tmp_delta[summand], - tmp_zeta[summand], - delta[summand + 1], - zeta[summand + 1]) + tmp_delta[summand + 1], tmp_zeta[summand + 1] = get_valid_points_for_wordwise_xor( + tmp_delta[summand], tmp_zeta[summand], delta[summand + 1], zeta[summand + 1] + ) delta_output, zeta_output = get_valid_points_for_wordwise_xor(tmp_delta[-1], tmp_zeta[-1], delta[-1], zeta[-1]) zeta_output = max(0, zeta_output) @@ -178,11 +175,15 @@ def generate_valid_points_for_xor_between_n_input_words(wordsize=4, number_of_wo if reduce(lambda a, b: a ^ b, only_fixed_patterns) == 0: delta_output = 2 - tmp = ''.join(format(delta[i], '0' + str(bit_len) + 'b') + - format(zeta[i] if (delta[i] == 1) else 0, '0' + str(wordsize) + 'b') for i in - range(number_of_words)) + \ - format(delta_output, '0' + str(bit_len) + 'b') + \ - format(zeta_output, '0' + str(wordsize) + 'b') + tmp = ( + "".join( + format(delta[i], "0" + str(bit_len) + "b") + + format(zeta[i] if (delta[i] == 1) else 0, "0" + str(wordsize) + "b") + for i in range(number_of_words) + ) + + format(delta_output, "0" + str(bit_len) + "b") + + format(zeta_output, "0" + str(wordsize) + "b") + ) valid_points.append(tmp) @@ -191,7 +192,7 @@ def generate_valid_points_for_xor_between_n_input_words(wordsize=4, number_of_wo def update_dictionary_that_contains_wordwise_truncated_xor_inequalities_between_n_inputs(wordsize, number_of_inputs): try: - read_file = open(wordwise_truncated_xor_inequalities_between_n_input_bits_file_path, 'rb') + read_file = open(wordwise_truncated_xor_inequalities_between_n_input_bits_file_path, "rb") dictio = pickle.load(read_file) read_file.close() except OSError: @@ -202,11 +203,12 @@ def update_dictionary_that_contains_wordwise_truncated_xor_inequalities_between_ if number_of_inputs not in dictio[wordsize].keys(): print( - f"Adding wordwise xor inequalities between {number_of_inputs} inputs of size {wordsize} in pre-saved dictionary") + f"Adding wordwise xor inequalities between {number_of_inputs} inputs of size {wordsize} in pre-saved dictionary" + ) valid_points = generate_valid_points_for_xor_between_n_input_words(wordsize, number_of_inputs) inequalities = milp_utils.generate_product_of_sum_from_espresso(valid_points) dictio[wordsize][number_of_inputs] = inequalities - write_file = open(wordwise_truncated_xor_inequalities_between_n_input_bits_file_path, 'wb') + write_file = open(wordwise_truncated_xor_inequalities_between_n_input_bits_file_path, "wb") pickle.dump(dictio, write_file) write_file.close() @@ -223,10 +225,14 @@ def update_dictionary_that_contains_xor_inequalities_for_specific_wordwise_matri number_of_1_in_each_cols.append(number_of_1) number_of_1_in_each_cols = list(set(number_of_1_in_each_cols)) for number_of_input_bits in number_of_1_in_each_cols: - update_dictionary_that_contains_wordwise_truncated_xor_inequalities_between_n_inputs(wordsize, number_of_input_bits) + update_dictionary_that_contains_wordwise_truncated_xor_inequalities_between_n_inputs( + wordsize, number_of_input_bits + ) + def output_dictionary_that_contains_wordwise_truncated_xor_inequalities(): return milp_utils.output_espresso_dictionary(wordwise_truncated_xor_inequalities_between_n_input_bits_file_path) + def delete_dictionary_that_contains_wordwise_truncated_xor_inequalities(): - return milp_utils.delete_espresso_dictionary(wordwise_truncated_xor_inequalities_between_n_input_bits_file_path) \ No newline at end of file + return milp_utils.delete_espresso_dictionary(wordwise_truncated_xor_inequalities_between_n_input_bits_file_path) diff --git a/claasp/cipher_modules/models/milp/utils/generate_inequalities_for_xor_with_n_input_bits.py b/claasp/cipher_modules/models/milp/utils/generate_inequalities_for_xor_with_n_input_bits.py index c883bfb23..4809587d2 100644 --- a/claasp/cipher_modules/models/milp/utils/generate_inequalities_for_xor_with_n_input_bits.py +++ b/claasp/cipher_modules/models/milp/utils/generate_inequalities_for_xor_with_n_input_bits.py @@ -1,17 +1,16 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -20,6 +19,7 @@ """ The target of this module is to generate MILP inequalities for a XOR operation between n input bits. """ + import pickle, os from claasp.cipher_modules.models.milp import MILP_AUXILIARY_FILE_PATH @@ -56,9 +56,7 @@ def update_dictionary_that_contains_xor_inequalities_between_n_input_bits(number if number_of_input_bits not in dictio.keys(): print(f"Adding xor inequalities between {number_of_input_bits} input bits in pre-saved dictionary") dictio[number_of_input_bits] = generate_impossible_points_for_xor_between_n_input_bits(number_of_input_bits) - write_file = open( - xor_inequalities_between_n_input_bits_file_path, - 'wb') + write_file = open(xor_inequalities_between_n_input_bits_file_path, "wb") pickle.dump(dictio, write_file) write_file.close() @@ -80,7 +78,7 @@ def update_dictionary_that_contains_xor_inequalities_for_specific_matrix(mat): def output_dictionary_that_contains_xor_inequalities(): try: - read_file = open(xor_inequalities_between_n_input_bits_file_path, 'rb') + read_file = open(xor_inequalities_between_n_input_bits_file_path, "rb") dictio = pickle.load(read_file) read_file.close() except (OSError, EOFError): @@ -89,6 +87,6 @@ def output_dictionary_that_contains_xor_inequalities(): def delete_dictionary_that_contains_xor_inequalities(): - write_file = open(xor_inequalities_between_n_input_bits_file_path, 'wb') + write_file = open(xor_inequalities_between_n_input_bits_file_path, "wb") pickle.dump({}, write_file) - write_file.close() \ No newline at end of file + write_file.close() diff --git a/claasp/cipher_modules/models/milp/utils/generate_sbox_inequalities_for_trail_search.py b/claasp/cipher_modules/models/milp/utils/generate_sbox_inequalities_for_trail_search.py index 8cfc23401..8e1553961 100644 --- a/claasp/cipher_modules/models/milp/utils/generate_sbox_inequalities_for_trail_search.py +++ b/claasp/cipher_modules/models/milp/utils/generate_sbox_inequalities_for_trail_search.py @@ -1,22 +1,20 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** - """ The target of this module is to generate MILP inequalities for small sboxes (4 bits) by using the convex hull method. @@ -26,21 +24,21 @@ The module generate_inequalities_for_large_sboxes.py take care of both cases, small and large sboxes. Hence, this module can be removed, but we decide to keep it for comparison purpose. """ -import pickle, os -from claasp.cipher_modules.models.milp import MILP_AUXILIARY_FILE_PATH -from sage.rings.integer_ring import ZZ +import os +import pickle + +from claasp.cipher_modules.models.milp import MILP_AUXILIARY_FILE_PATH from claasp.cipher_modules.models.milp.solvers import SOLVER_DEFAULT small_sbox_file_name = "dictionary_that_contains_inequalities_for_small_sboxes.obj" small_sbox_xor_linear_file_name = "dictionary_that_contains_inequalities_for_small_sboxes_xor_linear.obj" inequalities_for_small_sboxes_path = os.path.join(MILP_AUXILIARY_FILE_PATH, small_sbox_file_name) -inequalities_for_small_sboxes_xor_linear_path = os.path.join(MILP_AUXILIARY_FILE_PATH, - small_sbox_xor_linear_file_name) +inequalities_for_small_sboxes_xor_linear_path = os.path.join(MILP_AUXILIARY_FILE_PATH, small_sbox_xor_linear_file_name) -def sbox_inequalities(sbox, analysis="differential", algorithm="milp", big_endian=False): +def sbox_inequalities(sbox, analysis="differential", algorithm="milp"): """ Compute inequalities for modeling the given S-box. @@ -50,7 +48,6 @@ def sbox_inequalities(sbox, analysis="differential", algorithm="milp", big_endia - ``analysis`` -- **string** (default: `differential`); choosing between 'differential' and 'linear' cryptanalysis - ``algorithm`` -- **string** (default: `greedy`); choosing the algorithm for computing the S-box model, one of ['none', 'greedy', 'milp'] - - ``big_endian`` -- **boolean** (default: `False`); representation of transitions vectors EXAMPLES:: @@ -61,7 +58,7 @@ def sbox_inequalities(sbox, analysis="differential", algorithm="milp", big_endia sage: sbox_ineqs[2][1] An inequality (0, 0, 0, 1, 1, 0, 1, 0) x - 1 >= 0 """ - ch = convex_hull(sbox, analysis, big_endian) + ch = convex_hull(sbox, analysis) if algorithm == "greedy": return cutting_off_greedy(ch) elif algorithm == "milp": @@ -72,7 +69,7 @@ def sbox_inequalities(sbox, analysis="differential", algorithm="milp", big_endia raise ValueError("algorithm (%s) has to be one of ['greedy', 'milp']" % (algorithm,)) -def convex_hull(sbox, analysis="differential", big_endian=False): +def convex_hull(sbox, analysis="differential"): """ Compute the convex hull of the differential or linear behaviour of the given S-box. @@ -80,7 +77,6 @@ def convex_hull(sbox, analysis="differential", big_endian=False): - ``sbox`` -- **SBox object**; the S-box for which the convex hull should be computed - ``analysis`` -- **string** (default: `differential`); choosing between differential and linear behaviour - - ``big_endian`` -- **boolean** (default: `False`); representation of transitions vectors """ from sage.geometry.polyhedron.constructor import Polyhedron @@ -100,9 +96,8 @@ def convex_hull(sbox, analysis="differential", big_endian=False): dict_points[value] = [] for i in range(0, 1 << n): for o in range(0, 1 << m): - if i+o > 0 and valid_transformations_matrix[i][o] != 0: - dict_points[valid_transformations_matrix[i][o]].append( - to_bits(n, i, big_endian) + to_bits(n, o, big_endian)) + if i + o > 0 and valid_transformations_matrix[i][o] != 0: + dict_points[valid_transformations_matrix[i][o]].append(list(map(int, f"{i:0{n}b}{o:0{n}b}"))) for value in values_in_matrix: if dict_points[value]: dict_polyhedron[value] = Polyhedron(vertices=dict_points[value]) @@ -110,12 +105,6 @@ def convex_hull(sbox, analysis="differential", big_endian=False): return dict_polyhedron -def to_bits(n, x, big_endian=False): - if big_endian: - return ZZ(x).digits(base=2, padto=n) - return ZZ(x).digits(base=2, padto=n)[::-1] - - def cutting_off_greedy(dict_polyhedron): """ Compute a set of inequalities that is cutting-off equivalent to the H-representation of the given convex hull. @@ -133,13 +122,16 @@ def cutting_off_greedy(dict_polyhedron): chosen_ineqs = [] poly_points = dict_polyhedron[proba].integral_points() remaining_ineqs = list(dict_polyhedron[proba].inequalities()) - impossible = [vector(dict_polyhedron[proba].base_ring(), v) - for v in VectorSpace(GF(2), dict_polyhedron[proba].ambient_dim()) - if v not in poly_points] + impossible = [ + vector(dict_polyhedron[proba].base_ring(), v) + for v in VectorSpace(GF(2), dict_polyhedron[proba].ambient_dim()) + if v not in poly_points + ] while impossible != []: if len(remaining_ineqs) == 0: - raise ValueError("no more inequalities to choose, but still " - "%d impossible points left" % len(impossible)) + raise ValueError( + "no more inequalities to choose, but still %d impossible points left" % len(impossible) + ) # find inequality in remaining_ineqs that cuts off the most # impossible points and add this to the chosen_ineqs @@ -153,10 +145,7 @@ def cutting_off_greedy(dict_polyhedron): remaining_ineqs.remove(chosen_ineqs[-1]) # remove all cut off impossible points - impossible = [v - for v in impossible - if chosen_ineqs[-1].contains(v) - ] + impossible = [v for v in impossible if chosen_ineqs[-1].contains(v)] dict_chosen_inequalities[proba] = chosen_ineqs return dict_chosen_inequalities @@ -188,16 +177,14 @@ def cutting_off_milp(dict_polyhedron, number_of_ineqs=None): for proba in dict_polyhedron.keys(): ineqs = list(dict_polyhedron[proba].inequalities()) poly_points = dict_polyhedron[proba].integral_points() - impossible = [vector(dict_polyhedron[proba].base_ring(), v) - for v in VectorSpace(GF(2), dict_polyhedron[proba].ambient_dim()) - if v not in poly_points] + impossible = [ + vector(dict_polyhedron[proba].base_ring(), v) + for v in VectorSpace(GF(2), dict_polyhedron[proba].ambient_dim()) + if v not in poly_points + ] # precompute which inequality removes which impossible point - precomputation = matrix( - [[int(not (ineq.contains(p))) - for p in impossible] - for ineq in ineqs] - ) + precomputation = matrix([[int(not (ineq.contains(p))) for p in impossible] for ineq in ineqs]) milp = MixedIntegerLinearProgram(maximization=False, solver=SOLVER_DEFAULT) var_ineqs = milp.new_variable(binary=True, name="ineqs") @@ -206,26 +193,17 @@ def cutting_off_milp(dict_polyhedron, number_of_ineqs=None): milp.set_objective(sum([var_ineqs[i] for i in range(len(ineqs))])) # or the given number else: - milp.add_constraint(sum( - [var_ineqs[i] - for i in range(len(ineqs))] - ) == number_of_ineqs) + milp.add_constraint(sum([var_ineqs[i] for i in range(len(ineqs))]) == number_of_ineqs) nrows, ncols = precomputation.dimensions() for c in range(ncols): - lhs = sum([var_ineqs[r] - for r in range(nrows) - if precomputation[r][c] == 1]) - if (not isinstance(lhs, int)): + lhs = sum([var_ineqs[r] for r in range(nrows) if precomputation[r][c] == 1]) + if not isinstance(lhs, int): milp.add_constraint(lhs >= 1) milp.solve() - remaining_ineqs = [ - ineq - for ineq, (var, val) in zip(ineqs, milp.get_values(var_ineqs).items()) - if val == 1 - ] + remaining_ineqs = [ineq for ineq, (var, val) in zip(ineqs, milp.get_values(var_ineqs).items()) if val == 1] dict_chosen_inequalities[proba] = remaining_ineqs return dict_chosen_inequalities @@ -239,8 +217,12 @@ def get_dictionary_that_contains_inequalities_for_small_sboxes(analysis="differe - ``analysis`` - **string** (default: `differential`); """ - file_path = inequalities_for_small_sboxes_path if analysis == "differential" else inequalities_for_small_sboxes_xor_linear_path - read_file = open(file_path, 'rb') + file_path = ( + inequalities_for_small_sboxes_path + if analysis == "differential" + else inequalities_for_small_sboxes_xor_linear_path + ) + read_file = open(file_path, "rb") dictio = pickle.load(read_file) read_file.close() @@ -248,9 +230,13 @@ def get_dictionary_that_contains_inequalities_for_small_sboxes(analysis="differe def update_dictionary_that_contains_inequalities_for_small_sboxes(sbox, analysis="differential"): - file_path = inequalities_for_small_sboxes_path if analysis == "differential" else inequalities_for_small_sboxes_xor_linear_path + file_path = ( + inequalities_for_small_sboxes_path + if analysis == "differential" + else inequalities_for_small_sboxes_xor_linear_path + ) try: - read_file = open(file_path, 'rb') + read_file = open(file_path, "rb") dictio = pickle.load(read_file) read_file.close() except OSError: @@ -260,13 +246,17 @@ def update_dictionary_that_contains_inequalities_for_small_sboxes(sbox, analysis print("Adding sbox inequalities in pre-saved dictionary") dict_inequalities = sbox_inequalities(sbox, analysis) dictio[str(sbox)] = dict_inequalities - write_file = open(file_path, 'wb') + write_file = open(file_path, "wb") pickle.dump(dictio, write_file) write_file.close() def delete_dictionary_that_contains_inequalities_for_small_sboxes(analysis="differential"): - file_path = inequalities_for_small_sboxes_path if analysis == "differential" else inequalities_for_small_sboxes_xor_linear_path - write_file = open(file_path, 'wb') + file_path = ( + inequalities_for_small_sboxes_path + if analysis == "differential" + else inequalities_for_small_sboxes_xor_linear_path + ) + write_file = open(file_path, "wb") pickle.dump({}, write_file) write_file.close() diff --git a/claasp/cipher_modules/models/milp/utils/generate_undisturbed_bits_inequalities_for_sboxes.py b/claasp/cipher_modules/models/milp/utils/generate_undisturbed_bits_inequalities_for_sboxes.py index 3a1e8b77d..5c21cda19 100644 --- a/claasp/cipher_modules/models/milp/utils/generate_undisturbed_bits_inequalities_for_sboxes.py +++ b/claasp/cipher_modules/models/milp/utils/generate_undisturbed_bits_inequalities_for_sboxes.py @@ -1,17 +1,16 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -25,53 +24,62 @@ The logic minimizer espresso is required for this module. It is already installed in the docker. """ -import pickle, os -from claasp.cipher_modules.models.milp import MILP_AUXILIARY_FILE_PATH -from claasp.cipher_modules.models.milp.utils.utils import generate_espresso_input, delete_espresso_dictionary, \ - output_espresso_dictionary, generate_product_of_sum_from_espresso - -from sage.rings.integer_ring import ZZ -undisturbed_bit_sboxes_inequalities_file_name = "dictionary_that_contains_inequalities_for_sboxes_with_undisturbed_bits.obj" -undisturbed_bit_sboxes_inequalities_file_path = os.path.join(MILP_AUXILIARY_FILE_PATH, undisturbed_bit_sboxes_inequalities_file_name) +import os +import pickle +from claasp.cipher_modules.models.milp import MILP_AUXILIARY_FILE_PATH +from claasp.cipher_modules.models.milp.utils.utils import ( + delete_espresso_dictionary, + generate_product_of_sum_from_espresso, + output_espresso_dictionary, +) -def _to_bits(x, input_size): - return ZZ(x).digits(base=2, padto=input_size)[::-1] +undisturbed_bit_sboxes_inequalities_file_name = ( + "dictionary_that_contains_inequalities_for_sboxes_with_undisturbed_bits.obj" +) +undisturbed_bit_sboxes_inequalities_file_path = os.path.join( + MILP_AUXILIARY_FILE_PATH, undisturbed_bit_sboxes_inequalities_file_name +) def _encode_transition(delta_in, delta_out, verbose): - encoded_in = [_ for j in delta_in for _ in _to_bits(j, 2)] - encoded_out = [_ for j in delta_out for _ in _to_bits(j, 2)] + encoded_in = [_ for j in delta_in for _ in map(int, f"{j:02b}")] + encoded_out = [_ for j in delta_out for _ in map(int, f"{j:02b}")] if verbose: _print_transition(delta_in, delta_out, True) return "".join(str(_) for _ in encoded_in + encoded_out) def _print_transition(delta_in, delta_out, print_undisturbed_only=False): - input_str = ''.join(['1' if _ == 1 else '0' if _ == 0 else '?' for _ in delta_in]) - output_str = ''.join(['1' if _ == 1 else '0' if _ == 0 else '?' for _ in delta_out]) + input_str = "".join(["1" if _ == 1 else "0" if _ == 0 else "?" for _ in delta_in]) + output_str = "".join(["1" if _ == 1 else "0" if _ == 0 else "?" for _ in delta_out]) if print_undisturbed_only: - if output_str != ''.join(['?' for _ in delta_out]): + if output_str != "".join(["?" for _ in delta_out]): print(f" {input_str} -> {output_str}") else: print(f" {input_str} -> {output_str}") -def get_transitions_for_single_output_bit(sbox, valid_points, verbose=False): - ddt_with_undisturbed_bits_transitions = [_encode_transition(input, output, verbose) for input, output in valid_points] + +def get_transitions_for_single_output_bit(sbox, valid_points, verbose=False): + ddt_with_undisturbed_bits_transitions = [ + _encode_transition(input, output, verbose) for input, output in valid_points + ] n = sbox.input_size() valid_points = {} for position in range(n): valid_points[position] = {} for encoding_bit in range(2): - valid_points[position][encoding_bit] = [transition[:2 * n] + transition[2 * (n + position) + encoding_bit] for transition in ddt_with_undisturbed_bits_transitions] + valid_points[position][encoding_bit] = [ + transition[: 2 * n] + transition[2 * (n + position) + encoding_bit] + for transition in ddt_with_undisturbed_bits_transitions + ] return valid_points def generate_dict_product_of_sum_from_espresso(sbox, valid_points): - dict_espresso_outputs = {} valid_transitions = get_transitions_for_single_output_bit(sbox, valid_points) @@ -87,12 +95,12 @@ def generate_dict_product_of_sum_from_espresso(sbox, valid_points): def get_dictionary_that_contains_inequalities_for_sboxes_with_undisturbed_bits(): return output_espresso_dictionary(undisturbed_bit_sboxes_inequalities_file_path) -def update_dictionary_that_contains_inequalities_for_sboxes_with_undisturbed_bits(sbox, valid_points): +def update_dictionary_that_contains_inequalities_for_sboxes_with_undisturbed_bits(sbox, valid_points): file_path = undisturbed_bit_sboxes_inequalities_file_path try: - read_file = open(file_path, 'rb') + read_file = open(file_path, "rb") dictio = pickle.load(read_file) read_file.close() except OSError: @@ -103,7 +111,7 @@ def update_dictionary_that_contains_inequalities_for_sboxes_with_undisturbed_bit dict_product_of_sum = generate_dict_product_of_sum_from_espresso(sbox, valid_points) dictio[str(sbox)] = dict_product_of_sum - write_file = open(file_path, 'wb') + write_file = open(file_path, "wb") pickle.dump(dictio, write_file) write_file.close() diff --git a/claasp/cipher_modules/models/milp/utils/milp_name_mappings.py b/claasp/cipher_modules/models/milp/utils/milp_name_mappings.py index b175dd44f..2efc1769d 100644 --- a/claasp/cipher_modules/models/milp/utils/milp_name_mappings.py +++ b/claasp/cipher_modules/models/milp/utils/milp_name_mappings.py @@ -1,18 +1,18 @@ # Model names -MILP_XOR_DIFFERENTIAL = "xor_differential" -MILP_XOR_LINEAR = "xor_linear" MILP_BITWISE_DETERMINISTIC_TRUNCATED = "bitwise_deterministic_truncated_xor_differential" -MILP_WORDWISE_DETERMINISTIC_TRUNCATED ="wordwise_deterministic_truncated_xor_differential" -MILP_WORDWISE_IMPOSSIBLE = "wordwise_impossible_xor_differential" -MILP_WORDWISE_IMPOSSIBLE_AUTO = "wordwise_impossible_xor_differential_fully_automated" MILP_BITWISE_IMPOSSIBLE = "bitwise_impossible_xor_differential" MILP_BITWISE_IMPOSSIBLE_AUTO = "bitwise_impossible_xor_differential_fully_automated" +MILP_WORDWISE_DETERMINISTIC_TRUNCATED = "wordwise_deterministic_truncated_xor_differential" +MILP_WORDWISE_IMPOSSIBLE = "wordwise_impossible_xor_differential" +MILP_WORDWISE_IMPOSSIBLE_AUTO = "wordwise_impossible_xor_differential_fully_automated" +MILP_XOR_DIFFERENTIAL = "xor_differential" +MILP_XOR_LINEAR = "xor_linear" # Model utils MILP_BACKWARD_SUFFIX = "_backward" -MILP_PROBABILITY_SUFFIX = "_probability" MILP_BUILDING_MESSAGE = "Building model in progress ..." +MILP_DEFAULT_WEIGHT_PRECISION = 2 +MILP_PROBABILITY_SUFFIX = "_probability" +MILP_TRUNCATED_XOR_DIFFERENTIAL_OBJECTIVE = "number_of_unknown_patterns" MILP_XOR_DIFFERENTIAL_OBJECTIVE = "probability" MILP_XOR_LINEAR_OBJECTIVE = "probability" -MILP_TRUNCATED_XOR_DIFFERENTIAL_OBJECTIVE = "number_of_unknown_patterns" -MILP_DEFAULT_WEIGHT_PRECISION = 2 \ No newline at end of file diff --git a/claasp/cipher_modules/models/milp/utils/milp_truncated_utils.py b/claasp/cipher_modules/models/milp/utils/milp_truncated_utils.py index 0a0b6e83d..823f65f08 100644 --- a/claasp/cipher_modules/models/milp/utils/milp_truncated_utils.py +++ b/claasp/cipher_modules/models/milp/utils/milp_truncated_utils.py @@ -1,12 +1,16 @@ from claasp.cipher_modules.inverse_cipher import get_key_schedule_component_ids -from claasp.cipher_modules.models.milp.utils.milp_name_mappings import MILP_BITWISE_IMPOSSIBLE_AUTO, \ - MILP_WORDWISE_IMPOSSIBLE_AUTO, MILP_BACKWARD_SUFFIX +from claasp.cipher_modules.models.milp.utils.milp_name_mappings import ( + MILP_BACKWARD_SUFFIX, + MILP_BITWISE_IMPOSSIBLE_AUTO, + MILP_WORDWISE_IMPOSSIBLE_AUTO, +) from claasp.cipher_modules.models.milp.utils.utils import milp_if_then from claasp.name_mappings import CIPHER_OUTPUT, INPUT_KEY -def generate_incompatiblity_constraints_for_component(model, model_type, x, x_class, backward_component, - include_all_components): +def generate_incompatiblity_constraints_for_component( + model, model_type, x, x_class, backward_component, include_all_components +): incompatiblity_constraints = [] if model_type == MILP_BITWISE_IMPOSSIBLE_AUTO: @@ -22,7 +26,9 @@ def generate_incompatiblity_constraints_for_component(model, model_type, x, x_cl # for multiple input components such as the XOR, ensures compatibility occurs on the correct branch inputs_to_be_kept = [] for index, input_id in enumerate(["_".join(i.split("_")[:-1]) for i in set(backward_component.input_id_links)]): - if INPUT_KEY not in input_id and [link + MILP_BACKWARD_SUFFIX for link in model._cipher.get_component_from_id(input_id).input_id_links] == [backward_component.id]: + if INPUT_KEY not in input_id and [ + link + MILP_BACKWARD_SUFFIX for link in model._cipher.get_component_from_id(input_id).input_id_links + ] == [backward_component.id]: inputs_to_be_kept.extend([_ for _ in input_ids if input_id in _]) backward_vars = [x_class[id] for id in (inputs_to_be_kept or input_ids) if INPUT_KEY not in id] else: @@ -36,36 +42,40 @@ def generate_incompatiblity_constraints_for_component(model, model_type, x, x_cl else: incompatibility_constraint = [forward_vars[inconsistent_index] + backward_vars[inconsistent_index] <= 2] incompatiblity_constraints.extend( - milp_if_then(inconsistent_vars[inconsistent_index], incompatibility_constraint, - model._model.get_max(x_class) * 2)) + milp_if_then( + inconsistent_vars[inconsistent_index], incompatibility_constraint, model._model.get_max(x_class) * 2 + ) + ) return incompatiblity_constraints, inconsistent_vars -def generate_all_incompatibility_constraints_for_fully_automatic_model(model, model_type, x, x_class, - include_all_components): +def generate_all_incompatibility_constraints_for_fully_automatic_model( + model, model_type, x, x_class, include_all_components +): assert model_type in [MILP_BITWISE_IMPOSSIBLE_AUTO, MILP_WORDWISE_IMPOSSIBLE_AUTO] constraints = [] forward_output = [c for c in model._forward_cipher.get_all_components() if c.type == CIPHER_OUTPUT][0] all_inconsistent_vars = [] - backward_components = [c for c in model._backward_cipher.get_all_components() if - c.description == ['round_output'] and set(c.input_id_links) != { - forward_output.id + MILP_BACKWARD_SUFFIX}] + backward_components = [ + c + for c in model._backward_cipher.get_all_components() + if c.description == ["round_output"] and set(c.input_id_links) != {forward_output.id + MILP_BACKWARD_SUFFIX} + ] key_flow = set(get_key_schedule_component_ids(model._cipher)) - {INPUT_KEY} - backward_key_flow = [f'{id}{MILP_BACKWARD_SUFFIX}' for id in key_flow] + backward_key_flow = [f"{id}{MILP_BACKWARD_SUFFIX}" for id in key_flow] if include_all_components: backward_components = set(model._backward_cipher.get_all_components()) - set( - model._backward_cipher.get_component_from_id(key_flow_id) for key_flow_id in backward_key_flow) + model._backward_cipher.get_component_from_id(key_flow_id) for key_flow_id in backward_key_flow + ) for backward_component in backward_components: - incompatibility_constraints, inconsistent_vars = generate_incompatiblity_constraints_for_component(model, - model_type, - x, x_class, - backward_component, - include_all_components) + incompatibility_constraints, inconsistent_vars = generate_incompatiblity_constraints_for_component( + model, model_type, x, x_class, backward_component, include_all_components + ) all_inconsistent_vars += inconsistent_vars constraints.extend(incompatibility_constraints) @@ -74,10 +84,11 @@ def generate_all_incompatibility_constraints_for_fully_automatic_model(model, mo return constraints -def fix_variables_value_deterministic_truncated_xor_differential_constraints(milp_model, model_variables, - fixed_variables=[]): +def fix_variables_value_deterministic_truncated_xor_differential_constraints( + milp_model, model_variables, fixed_variables=[] +): constraints = [] - if 'Wordwise' in milp_model.__class__.__name__: + if "Wordwise" in milp_model.__class__.__name__: prefix = "_word" suffix = "_class" else: @@ -87,11 +98,12 @@ def fix_variables_value_deterministic_truncated_xor_differential_constraints(mil for fixed_variable in fixed_variables: if fixed_variable["constraint_type"] == "equal": for index, bit_position in enumerate(fixed_variable["bit_positions"]): - component_bit = f'{fixed_variable["component_id"]}{prefix}_{bit_position}{suffix}' + component_bit = f"{fixed_variable['component_id']}{prefix}_{bit_position}{suffix}" constraints.append(model_variables[component_bit] == fixed_variable["bit_values"][index]) else: constraints.extend( - _generate_value_exclusion_constraints(milp_model, model_variables, fixed_variable, prefix, suffix)) + _generate_value_exclusion_constraints(milp_model, model_variables, fixed_variable, prefix, suffix) + ) return constraints @@ -99,8 +111,13 @@ def fix_variables_value_deterministic_truncated_xor_differential_constraints(mil def _generate_value_exclusion_constraints(milp_model, model_variables, fixed_variable, prefix, suffix): constraints = [] if sum(fixed_variable["bit_values"]) == 0: - constraints.append(sum(model_variables[f'{fixed_variable["component_id"]}{prefix}_{i}{suffix}'] for i in - fixed_variable["bit_positions"]) >= 1) + constraints.append( + sum( + model_variables[f"{fixed_variable['component_id']}{prefix}_{i}{suffix}"] + for i in fixed_variable["bit_positions"] + ) + >= 1 + ) else: M = milp_model._model.get_max(model_variables) + 1 d = milp_model._binary_variable @@ -108,12 +125,12 @@ def _generate_value_exclusion_constraints(milp_model, model_variables, fixed_var for index, bit_position in enumerate(fixed_variable["bit_positions"]): # eq = 1 iff bit_position == diff_index - eq = d[f'{fixed_variable["component_id"]}{prefix}_{bit_position}{suffix}_is_diff_index'] + eq = d[f"{fixed_variable['component_id']}{prefix}_{bit_position}{suffix}_is_diff_index"] one_among_n += eq # x[diff_index] < fixed_variable[diff_index] or fixed_variable[diff_index] < x[diff_index] - dummy = d[f'{fixed_variable["component_id"]}{prefix}_{bit_position}{suffix}_is_diff_index'] - a = model_variables[f'{fixed_variable["component_id"]}{prefix}_{bit_position}{suffix}'] + dummy = d[f"{fixed_variable['component_id']}{prefix}_{bit_position}{suffix}_is_diff_index"] + a = model_variables[f"{fixed_variable['component_id']}{prefix}_{bit_position}{suffix}"] b = fixed_variable["bit_values"][index] constraints.extend([a <= b - 1 + M * (2 - dummy - eq), a >= b + 1 - M * (dummy + 1 - eq)]) diff --git a/claasp/cipher_modules/models/milp/utils/mzn_predicates.py b/claasp/cipher_modules/models/milp/utils/mzn_predicates.py index d748e204c..c2fc17f9d 100644 --- a/claasp/cipher_modules/models/milp/utils/mzn_predicates.py +++ b/claasp/cipher_modules/models/milp/utils/mzn_predicates.py @@ -1,24 +1,22 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** def get_word_operations(): - functions_with_window_size = """ % Left rotation of X by val positions function array[int] of var 0..1: LRot(array[int] of var 0..1: X, int: val)= diff --git a/claasp/cipher_modules/models/milp/utils/utils.py b/claasp/cipher_modules/models/milp/utils/utils.py index dd288c054..b2fb650bb 100644 --- a/claasp/cipher_modules/models/milp/utils/utils.py +++ b/claasp/cipher_modules/models/milp/utils/utils.py @@ -14,26 +14,34 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** + import datetime import pickle import re -from subprocess import run +from subprocess import run from bitstring import BitArray -from sage.arith.misc import is_power_of_two - -from claasp.cipher_modules.models.milp.utils.generate_inequalities_for_large_sboxes import \ - get_dictionary_that_contains_inequalities_for_large_sboxes -from claasp.cipher_modules.models.milp.utils.generate_inequalities_for_xor_with_n_input_bits import ( - output_dictionary_that_contains_xor_inequalities, - update_dictionary_that_contains_xor_inequalities_between_n_input_bits) from sage.numerical.mip import MIPSolverException -from claasp.cipher_modules.models.milp.utils.milp_name_mappings import MILP_BITWISE_DETERMINISTIC_TRUNCATED, \ - MILP_WORDWISE_DETERMINISTIC_TRUNCATED, MILP_BACKWARD_SUFFIX, MILP_TRUNCATED_XOR_DIFFERENTIAL_OBJECTIVE, \ - MILP_XOR_DIFFERENTIAL_OBJECTIVE, MILP_BITWISE_IMPOSSIBLE, MILP_WORDWISE_IMPOSSIBLE, MILP_BITWISE_IMPOSSIBLE_AUTO, \ - MILP_WORDWISE_IMPOSSIBLE_AUTO +from claasp.cipher_modules.models.milp.utils.generate_inequalities_for_large_sboxes import ( + get_dictionary_that_contains_inequalities_for_large_sboxes, +) +from claasp.cipher_modules.models.milp.utils.generate_inequalities_for_xor_with_n_input_bits import ( + output_dictionary_that_contains_xor_inequalities, + update_dictionary_that_contains_xor_inequalities_between_n_input_bits, +) +from claasp.cipher_modules.models.milp.utils.milp_name_mappings import ( + MILP_BACKWARD_SUFFIX, + MILP_BITWISE_DETERMINISTIC_TRUNCATED, + MILP_BITWISE_IMPOSSIBLE_AUTO, + MILP_BITWISE_IMPOSSIBLE, + MILP_TRUNCATED_XOR_DIFFERENTIAL_OBJECTIVE, + MILP_WORDWISE_DETERMINISTIC_TRUNCATED, + MILP_WORDWISE_IMPOSSIBLE_AUTO, + MILP_WORDWISE_IMPOSSIBLE, + MILP_XOR_DIFFERENTIAL_OBJECTIVE, +) from claasp.name_mappings import SBOX @@ -59,23 +67,23 @@ def _get_data(data_keywords, lines): def _get_variables_value(internal_variables, read_file): variables_value = {} for key in internal_variables.keys(): - index = int(re.search(r'\d+', str(internal_variables[key])).group()) + 1 - match = re.search(r'[xyz]_%s[\s]+[\*]?[\s]*([0-9]*[.]?[0-9]+)' % index, read_file) + index = int(re.search(r"\d+", str(internal_variables[key])).group()) + 1 + match = re.search(r"[xyz]_%s[\s]+[\*]?[\s]*([0-9]*[.]?[0-9]+)" % index, read_file) variables_value[key] = float(match.group(1)) if match else 0.0 return variables_value def _parse_external_solver_output(model, solver_specs, model_type, solution_file_path, solver_process): - solve_time = _get_data(solver_specs['keywords']['time'], solver_process) + solve_time = _get_data(solver_specs["keywords"]["time"], solver_process) - status = 'UNSATISFIABLE' + status = "UNSATISFIABLE" objective_value = None components_values = None - if re.findall(solver_specs['keywords']['unsat_condition'], solver_process) == []: - status = 'SATISFIABLE' + if re.findall(solver_specs["keywords"]["unsat_condition"], solver_process) == []: + status = "SATISFIABLE" - with open(solution_file_path, 'r') as lp_file: + with open(solution_file_path, "r") as lp_file: read_file = lp_file.read() if model_type in [MILP_BITWISE_DETERMINISTIC_TRUNCATED, MILP_BITWISE_IMPOSSIBLE]: @@ -85,8 +93,9 @@ def _parse_external_solver_output(model, solver_specs, model_type, solution_file elif model_type == MILP_BITWISE_IMPOSSIBLE_AUTO: components_variables = _get_variables_value(model.trunc_binvar, read_file) objective_variables = _get_variables_value(model.binary_variable, read_file) - inconsistent_component_var = \ - [i for i in objective_variables.keys() if objective_variables[i] > 0 and "inconsistent" in i][0] + inconsistent_component_var = [ + i for i in objective_variables.keys() if objective_variables[i] > 0 and "inconsistent" in i + ][0] objective_value = "_".join(inconsistent_component_var.split("_")[:-3]) elif model_type in [MILP_WORDWISE_DETERMINISTIC_TRUNCATED, MILP_WORDWISE_IMPOSSIBLE]: components_variables = _get_variables_value(model.trunc_wordvar, read_file) @@ -95,13 +104,14 @@ def _parse_external_solver_output(model, solver_specs, model_type, solution_file elif model_type == MILP_WORDWISE_IMPOSSIBLE_AUTO: components_variables = _get_variables_value(model.trunc_wordvar, read_file) objective_variables = _get_variables_value(model.binary_variable, read_file) - inconsistent_component_var = \ - [i for i in objective_variables.keys() if objective_variables[i] > 0 and "inconsistent" in i][0] + inconsistent_component_var = [ + i for i in objective_variables.keys() if objective_variables[i] > 0 and "inconsistent" in i + ][0] objective_value = "_".join(inconsistent_component_var.split("_")[:-3]) else: components_variables = _get_variables_value(model.binary_variable, read_file) objective_variables = _get_variables_value(model.integer_variable, read_file) - objective_value = objective_variables[MILP_XOR_DIFFERENTIAL_OBJECTIVE] / float(10 ** model.weight_precision) + objective_value = objective_variables[MILP_XOR_DIFFERENTIAL_OBJECTIVE] / float(10**model.weight_precision) components_values = model._get_component_values(objective_variables, components_variables) @@ -112,7 +122,6 @@ def _parse_external_solver_output(model, solver_specs, model_type, solution_file def generate_espresso_input(valid_points): - input_size = len(valid_points[0]) espresso_input = [f"# there are {input_size} input variables\n"] @@ -127,7 +136,7 @@ def generate_espresso_input(valid_points): espresso_input.append("# end of the PLA data\n") espresso_input.append(".e") - return ''.join(espresso_input) + return "".join(espresso_input) def generate_product_of_sum_from_espresso(valid_points): @@ -146,22 +155,21 @@ def generate_product_of_sum_from_espresso(valid_points): """ espresso_input = generate_espresso_input(valid_points) - espresso_process = run(['espresso', '-epos', '-okiss'], input=espresso_input, - capture_output=True, text=True) + espresso_process = run(["espresso", "-epos", "-okiss"], input=espresso_input, capture_output=True, text=True) espresso_output = espresso_process.stdout.splitlines() return [line[:-2] for line in espresso_output[4:]] def output_espresso_dictionary(file_path): - read_file = open(file_path, 'rb') + read_file = open(file_path, "rb") dictio = pickle.load(read_file) read_file.close() return dictio def delete_espresso_dictionary(file_path): - write_file = open(file_path, 'wb') + write_file = open(file_path, "wb") pickle.dump({}, write_file) write_file.close() @@ -197,8 +205,7 @@ def milp_less(model, a, b, big_m): """ d = model.binary_variable a_less_b = d[str(a) + "_less_" + str(b) + "_dummy"] - constraints = [a <= b - 1 + big_m * (1 - a_less_b), - a >= b - big_m * a_less_b] + constraints = [a <= b - 1 + big_m * (1 - a_less_b), a >= b - big_m * a_less_b] return a_less_b, constraints @@ -239,9 +246,7 @@ def milp_and(model, a, b): d = model.binary_variable a_and_b = d[str(a) + "_and_" + str(b) + "_dummy"] - constraint = [a + b - 1 <= a_and_b, - a_and_b <= a, - a_and_b <= b] + constraint = [a + b - 1 <= a_and_b, a_and_b <= a, a_and_b <= b] return a_and_b, constraint @@ -255,9 +260,7 @@ def milp_or(model, a, b): d = model.binary_variable a_or_b = d[str(a) + "_or_" + str(b) + "_dummy"] - constraint = [a + b >= a_or_b, - a_or_b >= a, - a_or_b >= b] + constraint = [a + b >= a_or_b, a_or_b >= a, a_or_b >= b] return a_or_b, constraint @@ -294,9 +297,9 @@ def milp_generalized_and(model, var_list): """ d = model.binary_variable - generalized_and_varname = '' + generalized_and_varname = "" for i in range(len(var_list)): - generalized_and_varname += str(var_list[i]) + '{}'.format('_and_' if i < len(var_list) - 1 else '_dummy') + generalized_and_varname += str(var_list[i]) + "{}".format("_and_" if i < len(var_list) - 1 else "_dummy") generalized_and = d[generalized_and_varname] constraint = [sum(var_list) - len(var_list) + 1 <= generalized_and] @@ -410,10 +413,7 @@ def milp_xor(a, b, c): sage: a x_0 """ - constraints = [a + b >= c, - a + c >= b, - b + c >= a, - a + b + c <= 2] + constraints = [a + b >= c, a + c >= b, b + c >= a, a + b + c <= 2] return constraints @@ -446,7 +446,6 @@ def milp_generalized_xor(input_var_list, output_bit): dict_inequalities = output_dictionary_that_contains_xor_inequalities() inequalities = dict_inequalities[number_of_inputs] - for ineq in inequalities: constraint = 0 for var in range(number_of_inputs): @@ -505,6 +504,8 @@ def milp_else(var_if, else_constraints, big_m): constraints.append(rhs <= lhs + big_m * var_if) return constraints + + def milp_if_then_else(var_if, then_constraints, else_constraints, big_m): """ Returns a list of variables and a list of constraints to model an if-then-else statement. @@ -526,7 +527,7 @@ def milp_if_elif_else(model, var_if_list, then_constraints_list, else_constraint https://stackoverflow.com/questions/41009196/if-then-elseif-then-in-mixed-integer-linear-programming """ - assert (len(var_if_list) == len(then_constraints_list)) + assert len(var_if_list) == len(then_constraints_list) constraints = [] num_cond = len(var_if_list) @@ -535,11 +536,11 @@ def milp_if_elif_else(model, var_if_list, then_constraints_list, else_constraint else: d = model.binary_variable - decision_varname = '' + decision_varname = "" for i in range(num_cond): - decision_varname += str(var_if_list[i]) + '{}'.format('_and_' if i < num_cond - 1 else '_dummy') + decision_varname += str(var_if_list[i]) + "{}".format("_and_" if i < num_cond - 1 else "_dummy") - decision_var = [d[decision_varname + '_' + str(i)] for i in range(num_cond)] + decision_var = [d[decision_varname + "_" + str(i)] for i in range(num_cond)] for i in range(num_cond): decision_constraints = 0 @@ -547,7 +548,7 @@ def milp_if_elif_else(model, var_if_list, then_constraints_list, else_constraint decision_constraints += 1 - var_if_list[j] decision_constraints += var_if_list[i] constraints.append(decision_constraints <= decision_var[i] + num_cond - 1) - constraints.append(1. / num_cond * decision_constraints >= decision_var[i]) + constraints.append(1.0 / num_cond * decision_constraints >= decision_var[i]) constraints.extend(milp_if_then(decision_var[i], then_constraints_list[i], big_m)) @@ -555,6 +556,7 @@ def milp_if_elif_else(model, var_if_list, then_constraints_list, else_constraint return constraints + def espresso_pos_to_constraints(espresso_inequalities, variables): constraints = [] for ineq in espresso_inequalities: @@ -567,6 +569,7 @@ def espresso_pos_to_constraints(espresso_inequalities, variables): constraints.append(constraint >= 1) return constraints + def milp_xor_truncated(model, input_1, input_2, output): """ Returns a list of variables and a list of constraints for the XOR for two input bits @@ -604,13 +607,24 @@ def milp_xor_truncated(model, input_1, input_2, output): """ x = model.binary_variable - espresso_inequalities = ['-1-000', '-0-100', '----11', '0-0-1-', '-0-0-1', - '-1-1-1', '11----', '--1-0-', '1---0-', '--11--'] + espresso_inequalities = [ + "-1-000", + "-0-100", + "----11", + "0-0-1-", + "-0-0-1", + "-1-1-1", + "11----", + "--1-0-", + "1---0-", + "--11--", + ] all_vars = [x[i] for i in input_1 + input_2 + output] return espresso_pos_to_constraints(espresso_inequalities, all_vars) + def milp_xor_truncated_wordwise(model, input_1, input_2, output): """ Returns a list of variables and a list of constraints for the XOR for two input bytes @@ -647,53 +661,99 @@ def milp_xor_truncated_wordwise(model, input_1, input_2, output): x = model.binary_variable - espresso_inequalities = ['0-00000000-0---------1--------', '-0--------0-00000000-1--------', - '-1----------00000000-0--------', '--00000000-1---------0--------', - '---------------------01-------', '--------------------0100000000', - '---------------------0-1------', '--------------------1-1-------', - '---------------------0--1-----', '--------------------1--1------', - '---------------------0---1----', '--------------------1---1-----', - '---------------------0----1---', '--------------------1----1----', - '---------------------0-----1--', '--1---------0-------0-0-------', - '--0---------1-------0-0-------', '---------------------0------1-', - '---1---------0------0--0------', '---0---------1------0--0------', - '----1---------0-----0---0-----', '----0---------1-----0---0-----', - '--------------------1-----1---', '-----1---------0----0----0----', - '-----0---------1----0----0----', '------1---------0---0-----0---', - '------0---------1---0-----0---', '-------1---------0--0------0--', - '-------0---------1--0------0--', '--------1---------0-0-------0-', - '--------0---------1-0-------0-', '---------1---------00--------0', - '---------0---------10--------0', '---------------------0-------1', - '--------------------1------1--', '--------------------1-------1-', - '--------------------1--------1', '0100000000--------------------', - '----------0100000000----------', '---------0---------0---------1', - '---------1---------1---------1', '1---------1----------0--------', - '0---------0---------1---------', '-------0---------0---------1--', - '------0---------0---------1---', '-----0---------0---------1----', - '----0---------0---------1-----', '---0---------0---------1------', - '--0---------0---------1-------', '--------0---------0---------1-', - '--------1---------1---------1-', '--1---------1---------1-------', - '------1---------1---------1---', '-----1---------1---------1----', - '----1---------1---------1-----', '---1---------1---------1------', - '-------1---------1---------1--', '----------1---------0---------', - '1-------------------0---------', '-----------0------1-----------', - '----------1-------1-----------', '-----------01-----------------', - '----------1-1-----------------', '-----------0----1-------------', - '----------1-----1-------------', '-----------0---1--------------', - '----------1----1--------------', '-----------0--1---------------', - '----------1---1---------------', '-----------0-1----------------', - '----------1--1----------------', '-0------1---------------------', - '-0-----1----------------------', '-0----1-----------------------', - '-0---1------------------------', '-0--1-------------------------', - '-0-1--------------------------', '-01---------------------------', - '-----------0-----1------------', '----------1------1------------', - '1-------1---------------------', '1------1----------------------', - '1-----1-----------------------', '1----1------------------------', - '1---1-------------------------', '1--1--------------------------', - '1-1---------------------------', '-----------0-------1----------', - '----------1--------1----------', '-0-------1--------------------', - '1--------1--------------------'] - + espresso_inequalities = [ + "0-00000000-0---------1--------", + "-0--------0-00000000-1--------", + "-1----------00000000-0--------", + "--00000000-1---------0--------", + "---------------------01-------", + "--------------------0100000000", + "---------------------0-1------", + "--------------------1-1-------", + "---------------------0--1-----", + "--------------------1--1------", + "---------------------0---1----", + "--------------------1---1-----", + "---------------------0----1---", + "--------------------1----1----", + "---------------------0-----1--", + "--1---------0-------0-0-------", + "--0---------1-------0-0-------", + "---------------------0------1-", + "---1---------0------0--0------", + "---0---------1------0--0------", + "----1---------0-----0---0-----", + "----0---------1-----0---0-----", + "--------------------1-----1---", + "-----1---------0----0----0----", + "-----0---------1----0----0----", + "------1---------0---0-----0---", + "------0---------1---0-----0---", + "-------1---------0--0------0--", + "-------0---------1--0------0--", + "--------1---------0-0-------0-", + "--------0---------1-0-------0-", + "---------1---------00--------0", + "---------0---------10--------0", + "---------------------0-------1", + "--------------------1------1--", + "--------------------1-------1-", + "--------------------1--------1", + "0100000000--------------------", + "----------0100000000----------", + "---------0---------0---------1", + "---------1---------1---------1", + "1---------1----------0--------", + "0---------0---------1---------", + "-------0---------0---------1--", + "------0---------0---------1---", + "-----0---------0---------1----", + "----0---------0---------1-----", + "---0---------0---------1------", + "--0---------0---------1-------", + "--------0---------0---------1-", + "--------1---------1---------1-", + "--1---------1---------1-------", + "------1---------1---------1---", + "-----1---------1---------1----", + "----1---------1---------1-----", + "---1---------1---------1------", + "-------1---------1---------1--", + "----------1---------0---------", + "1-------------------0---------", + "-----------0------1-----------", + "----------1-------1-----------", + "-----------01-----------------", + "----------1-1-----------------", + "-----------0----1-------------", + "----------1-----1-------------", + "-----------0---1--------------", + "----------1----1--------------", + "-----------0--1---------------", + "----------1---1---------------", + "-----------0-1----------------", + "----------1--1----------------", + "-0------1---------------------", + "-0-----1----------------------", + "-0----1-----------------------", + "-0---1------------------------", + "-0--1-------------------------", + "-0-1--------------------------", + "-01---------------------------", + "-----------0-----1------------", + "----------1------1------------", + "1-------1---------------------", + "1------1----------------------", + "1-----1-----------------------", + "1----1------------------------", + "1---1-------------------------", + "1--1--------------------------", + "1-1---------------------------", + "-----------0-------1----------", + "----------1--------1----------", + "-0-------1--------------------", + "1--------1--------------------", + ] all_vars = [x[i] for i in input_1 + input_2 + output] return espresso_pos_to_constraints(espresso_inequalities, all_vars) @@ -703,14 +763,17 @@ def milp_xor_truncated_wordwise(model, input_1, input_2, output): def _get_component_values_for_impossible_models(model, objective_variables, components_variables): components_values = {} if model._forward_cipher == model._cipher: - inconsistent_component_var = \ - [i for i in objective_variables.keys() if objective_variables[i] > 0 and "inconsistent" in i][0] + inconsistent_component_var = [ + i for i in objective_variables.keys() if objective_variables[i] > 0 and "inconsistent" in i + ][0] inconsistent_component_id = "_".join(inconsistent_component_var.split("_")[:-3]) full_cipher_components = model._cipher.get_all_components_ids() backward_components = model._backward_cipher.get_all_components_ids() + model._backward_cipher.inputs index = full_cipher_components.index(inconsistent_component_id) - updated_cipher_components = full_cipher_components[:index + 1] + [ - c + MILP_BACKWARD_SUFFIX if c + MILP_BACKWARD_SUFFIX in backward_components else c for c in full_cipher_components[index:]] + updated_cipher_components = full_cipher_components[: index + 1] + [ + c + MILP_BACKWARD_SUFFIX if c + MILP_BACKWARD_SUFFIX in backward_components else c + for c in full_cipher_components[index:] + ] list_component_ids = model._forward_cipher.inputs + updated_cipher_components elif model._incompatible_components != None: full_cipher_components = model._cipher.get_all_components_ids() @@ -718,14 +781,20 @@ def _get_component_values_for_impossible_models(model, objective_variables, comp indices = [] for id in model._incompatible_components: - backward_incompatible_component = model._backward_cipher.get_component_from_id(id + f"{MILP_BACKWARD_SUFFIX}") - input_ids, _ = backward_incompatible_component._get_input_output_variables() - renamed_input_ids = set(["_".join(id.split("_")[:-2]) if MILP_BACKWARD_SUFFIX in id else "_".join(id.split("_")[:-1]) for id in input_ids]) + backward_incompatible_component = model._backward_cipher.get_component_from_id( + id + f"{MILP_BACKWARD_SUFFIX}" + ) + input_ids, _ = backward_incompatible_component._get_input_output_variables() + renamed_input_ids = { + "_".join(id.split("_")[:-2]) if MILP_BACKWARD_SUFFIX in id else "_".join(id.split("_")[:-1]) + for id in input_ids + } indices += sorted(indices + [full_cipher_components.index(c) for c in renamed_input_ids]) - updated_cipher_components = full_cipher_components[:indices[0]] + [ - c + MILP_BACKWARD_SUFFIX if c + MILP_BACKWARD_SUFFIX in backward_components else c for c in - full_cipher_components[indices[0]:]] + updated_cipher_components = full_cipher_components[: indices[0]] + [ + c + MILP_BACKWARD_SUFFIX if c + MILP_BACKWARD_SUFFIX in backward_components else c + for c in full_cipher_components[indices[0] :] + ] list_component_ids = model._forward_cipher.inputs + updated_cipher_components else: full_cipher_components = model._cipher.get_all_components_ids() @@ -752,7 +821,8 @@ def _get_variables_values_as_string(component_id, components_variables, suffix, diff_str += "*" return diff_str -def _string_to_hex( string): + +def _string_to_hex(string): string = "0b" + string try: value = BitArray(string) @@ -764,6 +834,7 @@ def _string_to_hex( string): value = string return value + def _filter_fixed_variables(fixed_values, fixed_variable, id): fixed_values_to_keep = [variable for variable in fixed_values if variable["constraint_type"] == "equal"] if id in [value["component_id"] for value in fixed_values_to_keep]: @@ -772,14 +843,15 @@ def _filter_fixed_variables(fixed_values, fixed_variable, id): bit_index = fixed_variable["bit_positions"].index(bit) del fixed_variable["bit_values"][bit_index] del fixed_variable["bit_positions"][bit_index] - + + def _set_weight_precision(model, analysis_type): if any(SBOX in item for item in model.non_linear_component_id): dict_product_of_sum = get_dictionary_that_contains_inequalities_for_large_sboxes(analysis=analysis_type) for id in model.non_linear_component_id: sb = tuple(model._cipher.get_component_from_id(id).description) for proba in dict_product_of_sum[str(sb)].keys(): - if not is_power_of_two(proba): + if (proba & (proba - 1)) != 0: # proba not power of two model._has_non_integer_weight = True break else: @@ -787,7 +859,7 @@ def _set_weight_precision(model, analysis_type): break if model._has_non_integer_weight: - step = 1 / float(10 ** model.weight_precision) + step = 1 / float(10**model.weight_precision) else: step = 1 - return step \ No newline at end of file + return step diff --git a/claasp/cipher_modules/models/sat/cms_models/cms_bitwise_deterministic_truncated_xor_differential_model.py b/claasp/cipher_modules/models/sat/cms_models/cms_bitwise_deterministic_truncated_xor_differential_model.py index c18fb10b5..122dbf695 100644 --- a/claasp/cipher_modules/models/sat/cms_models/cms_bitwise_deterministic_truncated_xor_differential_model.py +++ b/claasp/cipher_modules/models/sat/cms_models/cms_bitwise_deterministic_truncated_xor_differential_model.py @@ -1,22 +1,20 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** - """CryptoMiniSat model of Cipher. .. _cms-deterministic-truncated-standard: @@ -44,17 +42,18 @@ `_. """ - -from claasp.cipher_modules.models.sat.sat_models.sat_bitwise_deterministic_truncated_xor_differential_model import \ - SatBitwiseDeterministicTruncatedXorDifferentialModel +from claasp.cipher_modules.models.sat.sat_models.sat_bitwise_deterministic_truncated_xor_differential_model import ( + SatBitwiseDeterministicTruncatedXorDifferentialModel, +) class CmsSatDeterministicTruncatedXorDifferentialModel(SatBitwiseDeterministicTruncatedXorDifferentialModel): - - def __init__(self, cipher, counter='sequential', compact=False): + def __init__(self, cipher, counter="sequential", compact=False): super().__init__(cipher, counter, compact) - print("\n*** WARNING ***\n" - "At the best of the authors knowldege, deterministic truncated XOR differential model " - "cannot take any advantage of CryptoMiniSat. Therefore, the implementation is the same " - "of the SAT one.") + print( + "\n*** WARNING ***\n" + "At the best of the authors knowldege, deterministic truncated XOR differential model " + "cannot take any advantage of CryptoMiniSat. Therefore, the implementation is the same " + "of the SAT one." + ) diff --git a/claasp/cipher_modules/models/sat/cms_models/cms_cipher_model.py b/claasp/cipher_modules/models/sat/cms_models/cms_cipher_model.py index 8bf7a8d24..084b63f6a 100644 --- a/claasp/cipher_modules/models/sat/cms_models/cms_cipher_model.py +++ b/claasp/cipher_modules/models/sat/cms_models/cms_cipher_model.py @@ -1,28 +1,26 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** - """CryptoMiniSat model of Cipher. .. _cms-cipher-standard: CMS cipher model of a cipher ------------------------------------- +---------------------------- The target of this class is to override the methods of the superclass :py:class:`Sat Cipher Model ` to take the advantage given by @@ -42,16 +40,23 @@ For any further information, visit `CryptoMiniSat - XOR clauses `_. """ + from claasp.cipher_modules.models.sat.sat_model import SatModel from claasp.cipher_modules.models.sat.utils import utils from claasp.cipher_modules.models.sat.sat_models.sat_cipher_model import SatCipherModel -from claasp.name_mappings import (CONSTANT, SBOX, INTERMEDIATE_OUTPUT, CIPHER_OUTPUT, - LINEAR_LAYER, MIX_COLUMN, WORD_OPERATION) +from claasp.name_mappings import ( + CIPHER_OUTPUT, + CONSTANT, + INTERMEDIATE_OUTPUT, + LINEAR_LAYER, + MIX_COLUMN, + SBOX, + WORD_OPERATION, +) class CmsSatCipherModel(SatCipherModel): - - def __init__(self, cipher, counter='sequential', compact=False): + def __init__(self, cipher, counter="sequential", compact=False): super().__init__(cipher, counter, compact) def _add_clauses_to_solver(self, numerical_cnf, solver): @@ -84,16 +89,17 @@ def build_cipher_model(self, fixed_variables=[]): """ variables = [] constraints = SatModel.fix_variables_value_constraints(fixed_variables) - component_types = [CIPHER_OUTPUT, CONSTANT, INTERMEDIATE_OUTPUT, LINEAR_LAYER, MIX_COLUMN, SBOX, WORD_OPERATION] - operation_types = ['AND', 'MODADD', 'MODSUB', 'NOT', 'OR', 'ROTATE', 'SHIFT', 'SHIFT_BY_VARIABLE_AMOUNT', 'XOR'] + component_types = (CIPHER_OUTPUT, CONSTANT, INTERMEDIATE_OUTPUT, LINEAR_LAYER, MIX_COLUMN, SBOX, WORD_OPERATION) + operation_types = ("AND", "MODADD", "MODSUB", "NOT", "OR", "ROTATE", "SHIFT", "SHIFT_BY_VARIABLE_AMOUNT", "XOR") self._model_constraints = constraints self._variables_list = [] for component in self._cipher.get_all_components(): operation = component.description[0] if component.type not in component_types or ( - WORD_OPERATION == component.type and operation not in operation_types): - print(f'{component.id} not yet implemented') + WORD_OPERATION == component.type and operation not in operation_types + ): + print(f"{component.id} not yet implemented") else: variables, constraints = component.cms_constraints() diff --git a/claasp/cipher_modules/models/sat/cms_models/cms_xor_differential_model.py b/claasp/cipher_modules/models/sat/cms_models/cms_xor_differential_model.py index 43fd9239e..2a001dc0e 100644 --- a/claasp/cipher_modules/models/sat/cms_models/cms_xor_differential_model.py +++ b/claasp/cipher_modules/models/sat/cms_models/cms_xor_differential_model.py @@ -1,22 +1,20 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** - """CryptoMiniSat model of Cipher. .. _cms-differential-standard: @@ -42,16 +40,23 @@ For any further information, visit `CryptoMiniSat - XOR clauses `_. """ + from claasp.cipher_modules.models.sat.sat_model import SatModel from claasp.cipher_modules.models.sat.utils import utils from claasp.cipher_modules.models.sat.sat_models.sat_xor_differential_model import SatXorDifferentialModel -from claasp.name_mappings import WORD_OPERATION, CONSTANT, INTERMEDIATE_OUTPUT, CIPHER_OUTPUT, \ - LINEAR_LAYER, SBOX, MIX_COLUMN +from claasp.name_mappings import ( + CIPHER_OUTPUT, + CONSTANT, + INTERMEDIATE_OUTPUT, + LINEAR_LAYER, + MIX_COLUMN, + SBOX, + WORD_OPERATION, +) class CmsSatXorDifferentialModel(SatXorDifferentialModel): - - def __init__(self, cipher, counter='sequential', compact=False): + def __init__(self, cipher, counter="sequential", compact=False): super().__init__(cipher, counter, compact) def _add_clauses_to_solver(self, numerical_cnf, solver): @@ -88,17 +93,18 @@ def build_xor_differential_trail_model(self, weight=-1, fixed_variables=[]): variables = [] self._variables_list = [] constraints = SatModel.fix_variables_value_constraints(fixed_variables) - component_types = [CONSTANT, INTERMEDIATE_OUTPUT, CIPHER_OUTPUT, LINEAR_LAYER, SBOX, MIX_COLUMN, WORD_OPERATION] - operation_types = ['AND', 'MODADD', 'MODSUB', 'NOT', 'OR', 'ROTATE', 'SHIFT', 'XOR'] + component_types = (CONSTANT, INTERMEDIATE_OUTPUT, CIPHER_OUTPUT, LINEAR_LAYER, SBOX, MIX_COLUMN, WORD_OPERATION) + operation_types = ("AND", "MODADD", "MODSUB", "NOT", "OR", "ROTATE", "SHIFT", "XOR") self._model_constraints = constraints for component in self._cipher.get_all_components(): operation = component.description[0] if component.type not in component_types or ( - WORD_OPERATION == component.type and operation not in operation_types): + WORD_OPERATION == component.type and operation not in operation_types + ): variables, constraints = component.cms_xor_differential_propagation_constraints(self) else: - print(f'{component.id} not yet implemented') + print(f"{component.id} not yet implemented") self._variables_list.extend(variables) self._model_constraints.extend(constraints) diff --git a/claasp/cipher_modules/models/sat/cms_models/cms_xor_linear_model.py b/claasp/cipher_modules/models/sat/cms_models/cms_xor_linear_model.py index db9232c91..c168a5bd3 100644 --- a/claasp/cipher_modules/models/sat/cms_models/cms_xor_linear_model.py +++ b/claasp/cipher_modules/models/sat/cms_models/cms_xor_linear_model.py @@ -1,22 +1,20 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** - """CryptoMiniSat model of Cipher. .. _cms-linear-standard: @@ -42,17 +40,17 @@ For any further information, visit `CryptoMiniSat - XOR clauses `_. """ + +from claasp.cipher_modules.models.sat.sat_models.sat_xor_linear_model import SatXorLinearModel from claasp.cipher_modules.models.sat.utils import utils from claasp.cipher_modules.models.utils import get_bit_bindings -from claasp.cipher_modules.models.sat.sat_models.sat_xor_linear_model import SatXorLinearModel from claasp.name_mappings import CONSTANT, LINEAR_LAYER, SBOX, MIX_COLUMN, WORD_OPERATION class CmsSatXorLinearModel(SatXorLinearModel): - - def __init__(self, cipher, counter='sequential', compact=False): + def __init__(self, cipher, counter="sequential", compact=False): super().__init__(cipher, counter, compact) - self.bit_bindings, self.bit_bindings_for_intermediate_output = get_bit_bindings(cipher, '_'.join) + self.bit_bindings, self.bit_bindings_for_intermediate_output = get_bit_bindings(cipher, "_".join) def _add_clauses_to_solver(self, numerical_cnf, solver): """ @@ -91,8 +89,8 @@ def branch_xor_linear_constraints(self): """ constraints = [] for output_bit, input_bits in self.bit_bindings.items(): - operands = [f'x -{output_bit}'] + input_bits - constraints.append(' '.join(operands)) + operands = [f"x -{output_bit}"] + input_bits + constraints.append(" ".join(operands)) return constraints @@ -120,13 +118,13 @@ def build_xor_linear_trail_model(self, weight=-1, fixed_variables=[]): self._model_constraints = constraints for component in self._cipher.get_all_components(): - component_types = [CONSTANT, LINEAR_LAYER, SBOX, MIX_COLUMN, WORD_OPERATION] + component_types = (CONSTANT, LINEAR_LAYER, SBOX, MIX_COLUMN, WORD_OPERATION) operation = component.description[0] - operation_types = ["AND", "MODADD", "NOT", "ROTATE", "SHIFT", "XOR", "OR", "MODSUB"] + operation_types = ("AND", "MODADD", "NOT", "ROTATE", "SHIFT", "XOR", "OR", "MODSUB") if component.type in component_types and (component.type != WORD_OPERATION or operation in operation_types): variables, constraints = component.cms_xor_linear_mask_propagation_constraints(self) else: - print(f'{component.id} not yet implemented') + print(f"{component.id} not yet implemented") self._variables_list.extend(variables) self._model_constraints.extend(constraints) diff --git a/claasp/cipher_modules/models/sat/sat_model.py b/claasp/cipher_modules/models/sat/sat_model.py index 9fc72136d..704bcffaf 100644 --- a/claasp/cipher_modules/models/sat/sat_model.py +++ b/claasp/cipher_modules/models/sat/sat_model.py @@ -1,17 +1,16 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -53,6 +52,7 @@ (MSB) is indexed by 0. Be careful whenever inspecting the code or, as well, a CNF. """ + import copy import math import time @@ -65,12 +65,19 @@ from claasp.cipher_modules.models.sat.utils import utils from claasp.cipher_modules.models.utils import set_component_solution, convert_solver_solution_to_dictionary from claasp.editor import remove_permutations, remove_rotations -from claasp.name_mappings import SBOX, CIPHER_OUTPUT, CONSTANT, INTERMEDIATE_OUTPUT, LINEAR_LAYER, MIX_COLUMN, \ - WORD_OPERATION +from claasp.name_mappings import ( + SBOX, + CIPHER_OUTPUT, + CONSTANT, + INTERMEDIATE_OUTPUT, + LINEAR_LAYER, + MIX_COLUMN, + WORD_OPERATION, +) class SatModel: - def __init__(self, cipher, counter='sequential', compact=False): + def __init__(self, cipher, counter="sequential", compact=False): """ Initialise the sat model. @@ -88,7 +95,7 @@ def __init__(self, cipher, counter='sequential', compact=False): internal_cipher = remove_rotations(internal_cipher) # set the counter to fix the weight - if counter == 'sequential': + if counter == "sequential": self._counter = self._sequential_counter else: self._counter = self._parallel_counter @@ -115,10 +122,10 @@ def _get_cipher_inputs_components_solutions(self, out_suffix, variable2value): value = 0 for i in range(bit_size): value <<= 1 - if f'{cipher_input}_{i}{out_suffix}' in variable2value: - value ^= variable2value[f'{cipher_input}_{i}{out_suffix}'] + if f"{cipher_input}_{i}{out_suffix}" in variable2value: + value ^= variable2value[f"{cipher_input}_{i}{out_suffix}"] hex_digits = bit_size // 4 + (bit_size % 4 != 0) - hex_value = f'{value:#0{hex_digits+2}x}' + hex_value = f"{value:#0{hex_digits + 2}x}" component_solution = set_component_solution(hex_value) components_solutions[cipher_input] = component_solution @@ -130,12 +137,12 @@ def _get_cipher_inputs_components_solutions_double_ids(self, variable2value): values = [] for i in range(bit_size): value = 0 - if f'{cipher_input}_{i}_0' in variable2value: - value ^= variable2value[f'{cipher_input}_{i}_0'] << 1 - if f'{cipher_input}_{i}_1' in variable2value: - value ^= variable2value[f'{cipher_input}_{i}_1'] - values.append(f'{value}') - component_solution = set_component_solution(''.join(values).replace('2', '?').replace('3', '?')) + if f"{cipher_input}_{i}_0" in variable2value: + value ^= variable2value[f"{cipher_input}_{i}_0"] << 1 + if f"{cipher_input}_{i}_1" in variable2value: + value ^= variable2value[f"{cipher_input}_{i}_1"] + values.append(f"{value}") + component_solution = set_component_solution("".join(values).replace("2", "?").replace("3", "?")) components_solutions[cipher_input] = component_solution return components_solutions @@ -145,10 +152,10 @@ def _get_component_hex_value(self, component, out_suffix, variable2value): value = 0 for i in range(output_bit_size): value <<= 1 - if f'{component.id}_{i}{out_suffix}' in variable2value: - value ^= variable2value[f'{component.id}_{i}{out_suffix}'] + if f"{component.id}_{i}{out_suffix}" in variable2value: + value ^= variable2value[f"{component.id}_{i}{out_suffix}"] hex_digits = output_bit_size // 4 + (output_bit_size % 4 != 0) - hex_value = f'{value:#0{hex_digits+2}x}' + hex_value = f"{value:#0{hex_digits + 2}x}" return hex_value @@ -157,21 +164,19 @@ def _get_component_value_double_ids(self, component, variable2value): values = [] for i in range(output_bit_size): variable_value = 0 - if f'{component.id}_{i}_0' in variable2value: - variable_value ^= variable2value[f'{component.id}_{i}_0'] << 1 - if f'{component.id}_{i}_1' in variable2value: - variable_value ^= variable2value[f'{component.id}_{i}_1'] - values.append(f'{variable_value}') - value = ''.join(values).replace('2', '?').replace('3', '?') + if f"{component.id}_{i}_0" in variable2value: + variable_value ^= variable2value[f"{component.id}_{i}_0"] << 1 + if f"{component.id}_{i}_1" in variable2value: + variable_value ^= variable2value[f"{component.id}_{i}_1"] + values.append(f"{variable_value}") + value = "".join(values).replace("2", "?").replace("3", "?") return value - def _get_solver_solution_parsed(self, variable2number, values): - variable2value = {} - for i, variable in enumerate(variable2number): - variable2value[variable] = 0 if values[i][0] == '-' else 1 + def _get_solver_solution_parsed(self, variables, values): + variable_to_value = {variable: 0 if values[i][0] == "-" else 1 for i, variable in enumerate(variables)} - return variable2value + return variable_to_value def _parallel_counter(self, hw_list, weight): """ @@ -185,54 +190,77 @@ def _parallel_counter(self, hw_list, weight): variables = [] constraints = [] num_of_orders = math.ceil(math.log2(len(hw_list))) - dummy_list = [f'dummy_hw_{i}' for i in range(len(hw_list), 2 ** num_of_orders)] + dummy_list = [f"dummy_hw_{i}" for i in range(len(hw_list), 2**num_of_orders)] variables.extend(dummy_list) hw_list.extend(dummy_list) - constraints.extend(f'-{d}' for d in dummy_list) - for i in range(0, 2 ** num_of_orders, 2): - variables.append(f'r_{num_of_orders - 1}_{i // 2}_0') - variables.append(f'r_{num_of_orders - 1}_{i // 2}_1') - constraints.extend(utils.cnf_and(f'r_{num_of_orders - 1}_{i // 2}_0', - (f'{hw_list[i]}', f'{hw_list[i + 1]}'))) - constraints.extend(utils.cnf_xor(f'r_{num_of_orders - 1}_{i // 2}_1', - [f'{hw_list[i]}', f'{hw_list[i + 1]}'])) + constraints.extend(f"-{d}" for d in dummy_list) + for i in range(0, 2**num_of_orders, 2): + variables.append(f"r_{num_of_orders - 1}_{i // 2}_0") + variables.append(f"r_{num_of_orders - 1}_{i // 2}_1") + constraints.extend( + utils.cnf_and(f"r_{num_of_orders - 1}_{i // 2}_0", (f"{hw_list[i]}", f"{hw_list[i + 1]}")) + ) + constraints.extend( + utils.cnf_xor(f"r_{num_of_orders - 1}_{i // 2}_1", [f"{hw_list[i]}", f"{hw_list[i + 1]}"]) + ) # recursively summing couple words series = num_of_orders - 2 for i in range(2, num_of_orders + 1): - for j in range(0, 2 ** num_of_orders, 2 ** i): + for j in range(0, 2**num_of_orders, 2**i): # carries computed as usual (remember the library convention: MSB indexed by 0) for k in range(0, i - 1): - variables.append(f'c_{series}_{j // (2 ** i)}_{k}') - constraints.extend(utils.cnf_carry(f'c_{series}_{j // (2 ** i)}_{k}', - f'r_{series + 1}_{j // (2 ** (i - 1))}_{k}', - f'r_{series + 1}_{j // (2 ** (i - 1)) + 1}_{k}', - f'c_{series}_{j // (2 ** i)}_{k + 1}')) + variables.append(f"c_{series}_{j // (2**i)}_{k}") + constraints.extend( + utils.cnf_carry( + f"c_{series}_{j // (2**i)}_{k}", + f"r_{series + 1}_{j // (2 ** (i - 1))}_{k}", + f"r_{series + 1}_{j // (2 ** (i - 1)) + 1}_{k}", + f"c_{series}_{j // (2**i)}_{k + 1}", + ) + ) # the carry for the tens is the first not null - variables.append(f'c_{series}_{j // (2 ** i)}_{i - 1}') - constraints.extend(utils.cnf_and(f'c_{series}_{j // (2 ** i)}_{i - 1}', - [f'r_{series + 1}_{j // (2 ** (i - 1))}_{i - 1}', - f'r_{series + 1}_{j // (2 ** (i - 1)) + 1}_{i - 1}'])) + variables.append(f"c_{series}_{j // (2**i)}_{i - 1}") + constraints.extend( + utils.cnf_and( + f"c_{series}_{j // (2**i)}_{i - 1}", + [ + f"r_{series + 1}_{j // (2 ** (i - 1))}_{i - 1}", + f"r_{series + 1}_{j // (2 ** (i - 1)) + 1}_{i - 1}", + ], + ) + ) # first bit of the result (MSB) is simply the carry of the previous MSBs - variables.append(f'r_{series}_{j // (2 ** i)}_0') - constraints.extend(utils.cnf_equivalent([f'r_{series}_{j // (2 ** i)}_0', - f'c_{series}_{j // (2 ** i)}_0'])) + variables.append(f"r_{series}_{j // (2**i)}_0") + constraints.extend(utils.cnf_equivalent([f"r_{series}_{j // (2**i)}_0", f"c_{series}_{j // (2**i)}_0"])) # remaining bits of the result except the last one are as usual for k in range(1, i): - variables.append(f'r_{series}_{j // (2 ** i)}_{k}') - constraints.extend(utils.cnf_xor(f'r_{series}_{j // (2 ** i)}_{k}', - [f'r_{series + 1}_{j // (2 ** (i - 1))}_{k - 1}', - f'r_{series + 1}_{j // (2 ** (i - 1)) + 1}_{k - 1}', - f'c_{series}_{j // (2 ** i)}_{k}'])) + variables.append(f"r_{series}_{j // (2**i)}_{k}") + constraints.extend( + utils.cnf_xor( + f"r_{series}_{j // (2**i)}_{k}", + [ + f"r_{series + 1}_{j // (2 ** (i - 1))}_{k - 1}", + f"r_{series + 1}_{j // (2 ** (i - 1)) + 1}_{k - 1}", + f"c_{series}_{j // (2**i)}_{k}", + ], + ) + ) # last bit of the result (LSB) - variables.append(f'r_{series}_{j // (2 ** i)}_{i}') - constraints.extend(utils.cnf_xor(f'r_{series}_{j // (2 ** i)}_{i}', - [f'r_{series + 1}_{j // (2 ** (i - 1))}_{i - 1}', - f'r_{series + 1}_{j // (2 ** (i - 1)) + 1}_{i - 1}'])) + variables.append(f"r_{series}_{j // (2**i)}_{i}") + constraints.extend( + utils.cnf_xor( + f"r_{series}_{j // (2**i)}_{i}", + [ + f"r_{series + 1}_{j // (2 ** (i - 1))}_{i - 1}", + f"r_{series + 1}_{j // (2 ** (i - 1)) + 1}_{i - 1}", + ], + ) + ) series -= 1 # bit length of hamming weight, needed to fix weight when building the model bit_length_of_hw = num_of_orders + 1 - minus_signs = ['-' * (int(bit) ^ 1) for bit in f'{weight:0{bit_length_of_hw}b}'] - constraints.extend([f'{minus_signs[i]}r_0_0_{i}' for i in range(bit_length_of_hw)]) + minus_signs = ["-" * (int(bit) ^ 1) for bit in f"{weight:0{bit_length_of_hw}b}"] + constraints.extend([f"{minus_signs[i]}r_0_0_{i}" for i in range(bit_length_of_hw)]) return variables, constraints @@ -240,104 +268,110 @@ def _sequential_counter_algorithm(self, hw_list, weight, dummy_id, greater_or_eq n = len(hw_list) if greater_or_equal: weight = n - weight - minus = '' + minus = "" else: - minus = '-' - dummy_variables = [[f'{dummy_id}_{i}_{j}' for j in range(weight)] for i in range(n - 1)] - constraints = [f'{minus}{hw_list[0]} {dummy_variables[0][0]}'] - constraints.extend([f'-{dummy_variables[0][j]}' for j in range(1, weight)]) + minus = "-" + dummy_variables = [[f"{dummy_id}_{i}_{j}" for j in range(weight)] for i in range(n - 1)] + constraints = [f"{minus}{hw_list[0]} {dummy_variables[0][0]}"] + constraints.extend([f"-{dummy_variables[0][j]}" for j in range(1, weight)]) for i in range(1, n - 1): - constraints.append(f'{minus}{hw_list[i]} {dummy_variables[i][0]}') - constraints.append(f'-{dummy_variables[i - 1][0]} {dummy_variables[i][0]}') - constraints.extend([f'{minus}{hw_list[i]} -{dummy_variables[i - 1][j - 1]} {dummy_variables[i][j]}' - for j in range(1, weight)]) - constraints.extend([f'-{dummy_variables[i - 1][j]} {dummy_variables[i][j]}' - for j in range(1, weight)]) - constraints.append(f'{minus}{hw_list[i]} -{dummy_variables[i - 1][weight - 1]}') - constraints.append(f'{minus}{hw_list[n - 1]} -{dummy_variables[n - 2][weight - 1]}') + constraints.append(f"{minus}{hw_list[i]} {dummy_variables[i][0]}") + constraints.append(f"-{dummy_variables[i - 1][0]} {dummy_variables[i][0]}") + constraints.extend( + [ + f"{minus}{hw_list[i]} -{dummy_variables[i - 1][j - 1]} {dummy_variables[i][j]}" + for j in range(1, weight) + ] + ) + constraints.extend([f"-{dummy_variables[i - 1][j]} {dummy_variables[i][j]}" for j in range(1, weight)]) + constraints.append(f"{minus}{hw_list[i]} -{dummy_variables[i - 1][weight - 1]}") + constraints.append(f"{minus}{hw_list[n - 1]} -{dummy_variables[n - 2][weight - 1]}") dummy_variables = [d for dummy_list in dummy_variables for d in dummy_list] return dummy_variables, constraints - def _sequential_counter(self, hw_list, weight, dummy_id='dummy_hw_0'): + def _sequential_counter(self, hw_list, weight, dummy_id="dummy_hw_0"): return self._sequential_counter_algorithm(hw_list, weight, dummy_id) def _sequential_counter_greater_or_equal(self, weight, dummy_id): - hw_list = [variable_id for variable_id in self._variables_list if variable_id.startswith('hw_')] - variables, constraints = self._sequential_counter_algorithm(hw_list, weight, dummy_id, - greater_or_equal=True) + hw_list = [variable_id for variable_id in self._variables_list if variable_id.startswith("hw_")] + variables, constraints = self._sequential_counter_algorithm(hw_list, weight, dummy_id, greater_or_equal=True) self._variables_list.extend(variables) self._model_constraints.extend(constraints) def _solve_with_external_sat_solver(self, model_type, solver_name, options, host=None, env_vars_string=""): - solver_specs = [specs for specs in solvers.SAT_SOLVERS_EXTERNAL - if specs['solver_name'] == solver_name.upper()][0] - if host and (not solver_specs['keywords']['is_dimacs_compliant']): - raise ValueError('{solver_name} not supported.') + solver_specs = [specs for specs in solvers.SAT_SOLVERS_EXTERNAL if specs["solver_name"] == solver_name.upper()][ + 0 + ] + if host and (not solver_specs["keywords"]["is_dimacs_compliant"]): + raise ValueError(f"{solver_name} not supported.") # creating the dimacs - variable2number, numerical_cnf = utils.create_numerical_cnf(self._model_constraints) - dimacs = utils.numerical_cnf_to_dimacs(len(variable2number), numerical_cnf) + variables, numerical_cnf = utils.create_numerical_cnf(self._model_constraints) + dimacs = utils.numerical_cnf_to_dimacs(variables, numerical_cnf) # running the SAT solver - file_id = f'{uuid.uuid4()}' + file_id = f"{uuid.uuid4()}" if host is not None: - status, sat_time, sat_memory, values = utils.run_sat_solver(solver_specs, options, - dimacs, host, env_vars_string) + status, sat_time, sat_memory, values = utils.run_sat_solver( + solver_specs, options, dimacs, host, env_vars_string + ) else: - if solver_specs['keywords']['is_dimacs_compliant']: - status, sat_time, sat_memory, values = utils.run_sat_solver(solver_specs, options, - dimacs) - elif solver_specs['solver_name'] == 'MINISAT_EXT': - input_file = f'{self.cipher_id}_{file_id}_sat_input.cnf' - output_file = f'{self.cipher_id}_{file_id}_sat_output.cnf' - status, sat_time, sat_memory, values = utils.run_minisat(solver_specs, options, dimacs, - input_file, output_file) - elif solver_specs['solver_name'] == 'PARKISSAT_EXT': - input_file = f'{self.cipher_id}_{file_id}_sat_input.cnf' + if solver_specs["keywords"]["is_dimacs_compliant"]: + status, sat_time, sat_memory, values = utils.run_sat_solver(solver_specs, options, dimacs) + elif solver_specs["solver_name"] == solvers.MINISAT_EXT: + input_file = f"{self.cipher_id}_{file_id}_sat_input.cnf" + output_file = f"{self.cipher_id}_{file_id}_sat_output.cnf" + status, sat_time, sat_memory, values = utils.run_minisat( + solver_specs, options, dimacs, input_file, output_file + ) + elif solver_specs["solver_name"] == solvers.PARKISSAT_EXT: + input_file = f"{self.cipher_id}_{file_id}_sat_input.cnf" status, sat_time, sat_memory, values = utils.run_parkissat(solver_specs, options, dimacs, input_file) - elif solver_specs['solver_name'] == 'YICES_SAT_EXT': - input_file = f'{self.cipher_id}_{file_id}_sat_input.cnf' + elif solver_specs["solver_name"] == solvers.YICES_SAT_EXT: + input_file = f"{self.cipher_id}_{file_id}_sat_input.cnf" status, sat_time, sat_memory, values = utils.run_yices(solver_specs, options, dimacs, input_file) # parsing the solution - if status == 'SATISFIABLE': - variable2value = self._get_solver_solution_parsed(variable2number, values) - component2fields, total_weight = self._parse_solver_output(variable2value) - + if status == "SATISFIABLE": + variable_to_value = self._get_solver_solution_parsed(variables, values) + component_to_fields, total_weight = self._parse_solver_output(variable_to_value) else: - component2fields, total_weight = {}, None + component_to_fields, total_weight = {}, None + if total_weight is not None: total_weight = float(total_weight) - solution = convert_solver_solution_to_dictionary(self._cipher, model_type, solver_name, sat_time, - sat_memory, component2fields, total_weight) - solution['status'] = status + solution = convert_solver_solution_to_dictionary( + self._cipher, model_type, solver_name, sat_time, sat_memory, component_to_fields, total_weight + ) + solution["status"] = status return solution def _solve_with_sage_sat_solver(self, model_type, solver_name): - variable2number, numerical_cnf = utils.create_numerical_cnf(self._model_constraints) + variable_to_number, numerical_cnf = utils.create_numerical_cnf(self._model_constraints) solver = SAT(solver=solver_name) self._add_clauses_to_solver(numerical_cnf, solver) start_time = time.time() tracemalloc.start() values = solver() - sat_memory = tracemalloc.get_traced_memory()[1] / 10 ** 6 + sat_memory = tracemalloc.get_traced_memory()[1] / 10**6 tracemalloc.stop() sat_time = time.time() - start_time if values: - values = [f'{v-1}' for v in values[1:]] - variable2value = self._get_solver_solution_parsed(variable2number, values) - component2fields, total_weight = self._parse_solver_output(variable2value) - status = 'SATISFIABLE' + values = [f"{v - 1}" for v in values[1:]] + variable_to_value = self._get_solver_solution_parsed(variable_to_number, values) + component_to_fields, total_weight = self._parse_solver_output(variable_to_value) + status = "SATISFIABLE" else: - component2fields, total_weight = {}, None - status = 'UNSATISFIABLE' + component_to_fields, total_weight = {}, None + status = "UNSATISFIABLE" if total_weight is not None: total_weight = float(total_weight) - solution = convert_solver_solution_to_dictionary(self._cipher, model_type, solver_name, sat_time, - sat_memory, component2fields, total_weight) - solution['status'] = status + solution = convert_solver_solution_to_dictionary( + self._cipher, model_type, solver_name, sat_time, sat_memory, component_to_fields, total_weight + ) + solution["status"] = status return solution @@ -379,28 +413,58 @@ def fix_variables_value_constraints(fixed_variables=[]): '-ciphertext_0 -ciphertext_1 -ciphertext_2 ciphertext_3'] """ constraints = [] + for variable in fixed_variables: - component_id = variable['component_id'] - is_equal = (variable['constraint_type'] == 'equal') - bit_positions = variable['bit_positions'] - bit_values = variable['bit_values'] - variables_ids = [] - for position, value in zip(bit_positions, bit_values): - is_negative = '-' * (value ^ is_equal) - variables_ids.append(f'{is_negative}{component_id}_{position}') - if is_equal: - constraints.extend(variables_ids) + if variable["bit_values"][0] not in [0,1]: + variables_values = [] + for v in variable["bit_values"]: + variables_values.extend([(v[0], i) for i in v[1]]) + + component_id = variable["component_id"] + is_equal = variable["constraint_type"] == "equal" + bit_positions = variable["bit_positions"] + bit_values = variable["bit_values"] + variables_ids = [] + if is_equal: + for position, value in zip(bit_positions, variables_values): + constraints.extend(utils.cnf_equivalent( + [f"{component_id}_{position}", f"{value[0]}_{value[1]}"] + )) + else: + for position, value in zip(bit_positions, variables_values): + constraints.extend(utils.cnf_xor( + f"{component_id}_fix_{position}", [f"{component_id}_{position}", f"{value[0]}_{value[1]}"] + )) + constraints.append(" ".join(f"{component_id}_fix_{position}" for position in bit_positions)) + else: - constraints.append(' '.join(variables_ids)) + component_id = variable["component_id"] + is_equal = variable["constraint_type"] == "equal" + bit_positions = variable["bit_positions"] + bit_values = variable["bit_values"] + variables_ids = [] + for position, value in zip(bit_positions, bit_values): + is_negative = "-" * (value ^ is_equal) + variables_ids.append(f"{is_negative}{component_id}_{position}") + if is_equal: + constraints.extend(variables_ids) + else: + constraints.append(" ".join(variables_ids)) return constraints def calculate_component_weight(self, component, out_suffix, output_values_dict): weight = 0 - if ('MODSUB' in component.description or 'MODADD' in component.description or 'AND' in component.description - or 'OR' in component.description or SBOX in component.type): - weight = sum([output_values_dict[f'hw_{component.id}_{i}{out_suffix}'] - for i in range(component.output_bit_size)]) + if ( + "MODSUB" in component.description + or "MODADD" in component.description + or "AND" in component.description + or "OR" in component.description + or SBOX in component.type + ): + weight = sum( + [output_values_dict[f"hw_{component.id}_{i}{out_suffix}"] for i in range(component.output_bit_size)] + ) return weight def solve(self, model_type, solver_name=solvers.SOLVER_DEFAULT, options=None): @@ -446,11 +510,11 @@ def solve(self, model_type, solver_name=solvers.SOLVER_DEFAULT, options=None): """ if options is None: options = [] - if solver_name.endswith('_EXT'): + if solver_name.endswith("_EXT"): solution = self._solve_with_external_sat_solver(model_type, solver_name, options) else: if options: - raise ValueError('Options not allowed for SageMath solvers.') + raise ValueError("Options not allowed for SageMath solvers.") solution = self._solve_with_sage_sat_solver(model_type, solver_name) return solution @@ -479,29 +543,32 @@ def weight_constraints(self, weight): '-hw_modadd_2_7_14 -dummy_hw_0_77_6', '-hw_modadd_2_7_15 -dummy_hw_0_78_6']) """ - hw_list = [variable_id for variable_id in self._variables_list if variable_id.startswith('hw_')] + hw_list = [variable_id for variable_id in self._variables_list if variable_id.startswith("hw_")] if weight == 0: - return [], [f'-{variable}' for variable in hw_list] + return [], [f"-{variable}" for variable in hw_list] return self._counter(hw_list, weight) def build_generic_sat_model_from_dictionary(self, component_and_model_types): self._variables_list = [] self._model_constraints = [] - component_types = [CIPHER_OUTPUT, CONSTANT, INTERMEDIATE_OUTPUT, LINEAR_LAYER, MIX_COLUMN, SBOX, WORD_OPERATION] - operation_types = ['AND', 'MODADD', 'MODSUB', 'NOT', 'OR', 'ROTATE', 'SHIFT', 'SHIFT_BY_VARIABLE_AMOUNT', 'XOR'] + component_types = (CIPHER_OUTPUT, CONSTANT, INTERMEDIATE_OUTPUT, LINEAR_LAYER, MIX_COLUMN, SBOX, WORD_OPERATION) + operation_types = ("AND", "MODADD", "MODSUB", "NOT", "OR", "ROTATE", "SHIFT", "SHIFT_BY_VARIABLE_AMOUNT", "XOR") for component_and_model_type in component_and_model_types: component = component_and_model_type["component_object"] model_type = component_and_model_type["model_type"] operation = component.description[0] if component.type not in component_types or ( - WORD_OPERATION == component.type and operation not in operation_types): - print(f'{component.id} not yet implemented') + WORD_OPERATION == component.type and operation not in operation_types + ): + print(f"{component.id} not yet implemented") else: sat_xor_differential_propagation_constraints = getattr(component, model_type) - if model_type in ['sat_bitwise_deterministic_truncated_xor_differential_constraints', - 'sat_semi_deterministic_truncated_xor_differential_constraints']: + if model_type in ( + "sat_bitwise_deterministic_truncated_xor_differential_constraints", + "sat_semi_deterministic_truncated_xor_differential_constraints", + ): variables, constraints = sat_xor_differential_propagation_constraints() else: variables, constraints = sat_xor_differential_propagation_constraints(self) @@ -535,7 +602,7 @@ def model_constraints(self): ValueError: No model generated """ if not self._model_constraints: - raise ValueError('No model generated') + raise ValueError("No model generated") return self._model_constraints @property diff --git a/claasp/cipher_modules/models/sat/sat_models/sat_bitwise_deterministic_truncated_xor_differential_model.py b/claasp/cipher_modules/models/sat/sat_models/sat_bitwise_deterministic_truncated_xor_differential_model.py index 4aa6c9526..2da3ecb9e 100644 --- a/claasp/cipher_modules/models/sat/sat_models/sat_bitwise_deterministic_truncated_xor_differential_model.py +++ b/claasp/cipher_modules/models/sat/sat_models/sat_bitwise_deterministic_truncated_xor_differential_model.py @@ -1,38 +1,46 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** - import time from claasp.cipher_modules.models.sat import solvers -from claasp.cipher_modules.models.sat.sat_models.sat_truncated_xor_differential_model import \ - SatTruncatedXorDifferentialModel +from claasp.cipher_modules.models.sat.sat_models.sat_truncated_xor_differential_model import ( + SatTruncatedXorDifferentialModel, +) from claasp.cipher_modules.models.utils import set_component_solution -from claasp.name_mappings import (CIPHER_OUTPUT, CONSTANT, DETERMINISTIC_TRUNCATED_XOR_DIFFERENTIAL, - INTERMEDIATE_OUTPUT, INPUT_PLAINTEXT, LINEAR_LAYER, MIX_COLUMN, SBOX, WORD_OPERATION) +from claasp.name_mappings import ( + CIPHER_OUTPUT, + CONSTANT, + DETERMINISTIC_TRUNCATED_XOR_DIFFERENTIAL, + INTERMEDIATE_OUTPUT, + INPUT_PLAINTEXT, + LINEAR_LAYER, + MIX_COLUMN, + SBOX, + WORD_OPERATION, +) class SatBitwiseDeterministicTruncatedXorDifferentialModel(SatTruncatedXorDifferentialModel): - def __init__(self, cipher, counter='sequential', compact=False): + def __init__(self, cipher, counter="sequential", compact=False): super().__init__(cipher, counter, compact) def build_bitwise_deterministic_truncated_xor_differential_trail_model(self, number_of_unknown_variables=None, - fixed_variables=[]): + fixed_variables=[], component_list=None): """ Build the model for the search of deterministic truncated XOR DIFFERENTIAL trails. @@ -57,18 +65,21 @@ def build_bitwise_deterministic_truncated_xor_differential_trail_model(self, num ... """ variables = [] - constraints = SatBitwiseDeterministicTruncatedXorDifferentialModel.fix_variables_value_constraints(fixed_variables) + constraints = SatBitwiseDeterministicTruncatedXorDifferentialModel.fix_variables_value_constraints( + fixed_variables + ) self._variables_list = [] self._model_constraints = constraints component_types = (CIPHER_OUTPUT, CONSTANT, INTERMEDIATE_OUTPUT, LINEAR_LAYER, MIX_COLUMN, SBOX, WORD_OPERATION) - operation_types = ('AND', 'MODADD', 'NOT', 'OR', 'ROTATE', 'SHIFT', 'XOR') + operation_types = ("AND", "MODADD", "MODSUB", "NOT", "OR", "ROTATE", "SHIFT", "XOR") - for component in self._cipher.get_all_components(): + component_list = component_list or self._cipher.get_all_components() + for component in component_list: operation = component.description[0] if component.type in component_types and (component.type != WORD_OPERATION or operation in operation_types): variables, constraints = component.sat_bitwise_deterministic_truncated_xor_differential_constraints() else: - print(f'{component.id} not yet implemented') + print(f"{component.id} not yet implemented") self._variables_list.extend(variables) self._model_constraints.extend(constraints) @@ -78,8 +89,9 @@ def build_bitwise_deterministic_truncated_xor_differential_trail_model(self, num self._variables_list.extend(variables) self._model_constraints.extend(constraints) - def find_one_bitwise_deterministic_truncated_xor_differential_trail(self, fixed_values=[], - solver_name=solvers.SOLVER_DEFAULT): + def find_one_bitwise_deterministic_truncated_xor_differential_trail( + self, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT, options=None + ): """ Returns one deterministic truncated XOR differential trail. @@ -126,12 +138,14 @@ def find_one_bitwise_deterministic_truncated_xor_differential_trail(self, fixed_ start_building_time = time.time() self.build_bitwise_deterministic_truncated_xor_differential_trail_model(fixed_variables=fixed_values) end_building_time = time.time() - solution = self.solve(DETERMINISTIC_TRUNCATED_XOR_DIFFERENTIAL, solver_name=solver_name) - solution['building_time_seconds'] = end_building_time - start_building_time + solution = self.solve(DETERMINISTIC_TRUNCATED_XOR_DIFFERENTIAL, solver_name=solver_name, options=options) + solution["building_time_seconds"] = end_building_time - start_building_time return solution - def find_lowest_varied_patterns_bitwise_deterministic_truncated_xor_differential_trail(self, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT): + def find_lowest_varied_patterns_bitwise_deterministic_truncated_xor_differential_trail( + self, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT, options=None + ): """ Return the solution representing a differential trail with the lowest number of unknown variables. @@ -159,24 +173,26 @@ def find_lowest_varied_patterns_bitwise_deterministic_truncated_xor_differential current_unknowns_count = 1 start_building_time = time.time() self.build_bitwise_deterministic_truncated_xor_differential_trail_model( - number_of_unknown_variables=current_unknowns_count, fixed_variables=fixed_values) + number_of_unknown_variables=current_unknowns_count, fixed_variables=fixed_values + ) end_building_time = time.time() - solution = self.solve(DETERMINISTIC_TRUNCATED_XOR_DIFFERENTIAL, solver_name=solver_name) - solution['building_time_seconds'] = end_building_time - start_building_time - total_time = solution['solving_time_seconds'] - max_memory = solution['memory_megabytes'] - while solution['status'] != 'SATISFIABLE': + solution = self.solve(DETERMINISTIC_TRUNCATED_XOR_DIFFERENTIAL, solver_name=solver_name, options=options) + solution["building_time_seconds"] = end_building_time - start_building_time + total_time = solution["solving_time_seconds"] + max_memory = solution["memory_megabytes"] + while solution["status"] != "SATISFIABLE": current_unknowns_count += 1 start_building_time = time.time() self.build_bitwise_deterministic_truncated_xor_differential_trail_model( - number_of_unknown_variables=current_unknowns_count, fixed_variables=fixed_values) + number_of_unknown_variables=current_unknowns_count, fixed_variables=fixed_values + ) end_building_time = time.time() - solution = self.solve(DETERMINISTIC_TRUNCATED_XOR_DIFFERENTIAL, solver_name=solver_name) - solution['building_time_seconds'] = end_building_time - start_building_time - total_time += solution['solving_time_seconds'] - max_memory = max((max_memory, solution['memory_megabytes'])) - solution['solving_time_seconds'] = total_time - solution['memory_megabytes'] = max_memory + solution = self.solve(DETERMINISTIC_TRUNCATED_XOR_DIFFERENTIAL, solver_name=solver_name, options=options) + solution["building_time_seconds"] = end_building_time - start_building_time + total_time += solution["solving_time_seconds"] + max_memory = max((max_memory, solution["memory_megabytes"])) + solution["solving_time_seconds"] = total_time + solution["memory_megabytes"] = max_memory return solution @@ -206,10 +222,13 @@ def weight_constraints(self, number_of_unknown_variables): '-cipher_output_2_12_31_0 -dummy_hw_0_62_3']) """ cipher_output_id = self._cipher.get_all_components_ids()[-1] - set_to_be_minimized = [f"{INPUT_PLAINTEXT}_{i}_0" - for i in range(self._cipher.inputs_bit_size[self._cipher.inputs.index(INPUT_PLAINTEXT)])] - set_to_be_minimized.extend([bit_id for bit_id in self._variables_list - if bit_id.startswith(cipher_output_id) and bit_id.endswith("_0")]) + set_to_be_minimized = [ + f"{INPUT_PLAINTEXT}_{i}_0" + for i in range(self._cipher.inputs_bit_size[self._cipher.inputs.index(INPUT_PLAINTEXT)]) + ] + set_to_be_minimized.extend( + [bit_id for bit_id in self._variables_list if bit_id.startswith(cipher_output_id) and bit_id.endswith("_0")] + ) return self._counter(set_to_be_minimized, number_of_unknown_variables) @@ -218,6 +237,6 @@ def _parse_solver_output(self, variable2value): for component in self._cipher.get_all_components(): value = self._get_component_value_double_ids(component, variable2value) component_solution = set_component_solution(value) - components_solutions[f'{component.id}'] = component_solution + components_solutions[f"{component.id}"] = component_solution return components_solutions, None diff --git a/claasp/cipher_modules/models/sat/sat_models/sat_bitwise_impossible_xor_differential_model.py b/claasp/cipher_modules/models/sat/sat_models/sat_bitwise_impossible_xor_differential_model.py new file mode 100644 index 000000000..05ff64f4e --- /dev/null +++ b/claasp/cipher_modules/models/sat/sat_models/sat_bitwise_impossible_xor_differential_model.py @@ -0,0 +1,457 @@ +# **************************************************************************** +# Copyright 2023 Technology Innovation Institute +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# **************************************************************************** + +import time + +from claasp.cipher_modules.inverse_cipher import get_key_schedule_component_ids +from claasp.cipher_modules.models.sat import solvers +from claasp.cipher_modules.models.sat.sat_models.sat_bitwise_deterministic_truncated_xor_differential_model import ( + SatBitwiseDeterministicTruncatedXorDifferentialModel, +) +from claasp.cipher_modules.models.sat.utils import utils +from claasp.cipher_modules.models.utils import set_component_solution +from claasp.name_mappings import CIPHER_OUTPUT, IMPOSSIBLE_XOR_DIFFERENTIAL, INPUT_KEY + + +class SatBitwiseImpossibleXorDifferentialModel(SatBitwiseDeterministicTruncatedXorDifferentialModel): + def __init__(self, cipher, compact=False): + super().__init__(cipher, compact) + self._forward_cipher = None + self._backward_cipher = None + self._middle_round = None + self._incompatible_components = None + + def build_bitwise_impossible_xor_differential_trail_model(self, fixed_variables=[]): + """ + Build the model for the search of bitwise impossible XOR differential trails. + + INPUTS: + + - ``fixed_variables`` -- **list** (default: `[]`); dictionaries containing the variables to be fixed in + standard format + + .. SEEALSO:: + + :py:meth:`~cipher_modules.models.utils.set_fixed_variables` + + EXAMPLES:: + + sage: from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher + sage: from claasp.cipher_modules.models.sat.sat_models.sat_bitwise_impossible_xor_differential_model import SatBitwiseImpossibleXorDifferentialModel + sage: speck = SpeckBlockCipher(number_of_rounds=2) + sage: sat = SatBitwiseImpossibleXorDifferentialModel(speck) + sage: sat._forward_cipher = speck.get_partial_cipher(0, 1, keep_key_schedule=True) + sage: backward_cipher = sat._cipher.cipher_partial_inverse(1, 1, keep_key_schedule=False) + sage: sat._backward_cipher = backward_cipher.add_suffix_to_components("_backward", [backward_cipher.get_all_components_ids()[-1]]) + sage: sat.build_bitwise_impossible_xor_differential_trail_model() + ... + """ + component_list = self._forward_cipher.get_all_components() + self._backward_cipher.get_all_components() + return self.build_bitwise_deterministic_truncated_xor_differential_trail_model( + fixed_variables=fixed_variables, component_list=component_list + ) + + def find_one_bitwise_impossible_xor_differential_trail( + self, middle_round, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT, options=None + ): + """ + Returns one bitwise impossible XOR differential trail. + + INPUTS: + + - ``middle_round`` -- **integer**; the round number for which the incompatibility occurs + - ``fixed_values`` -- *list of dict*, the variables to be fixed in + standard format (see :py:meth:`~GenericModel.set_fixed_variables`) + - ``solver_name`` -- *str*, the solver to call + + EXAMPLES:: + + # table 9 from https://eprint.iacr.org/2014/761.pdf + sage: from claasp.cipher_modules.models.utils import integer_to_bit_list, set_fixed_variables + sage: from claasp.ciphers.block_ciphers.simon_block_cipher import SimonBlockCipher + sage: simon = SimonBlockCipher(block_bit_size=32, number_of_rounds=11) + sage: from claasp.cipher_modules.models.sat.sat_models.sat_bitwise_impossible_xor_differential_model import SatBitwiseImpossibleXorDifferentialModel + sage: sat = SatBitwiseImpossibleXorDifferentialModel(simon) + sage: plaintext = set_fixed_variables(component_id='plaintext', constraint_type='equal', bit_positions=range(32), bit_values=[0]*31 + [1]) + sage: key = set_fixed_variables(component_id='key', constraint_type='equal', bit_positions=range(64), bit_values=[0]*64) + sage: ciphertext = set_fixed_variables(component_id='cipher_output_10_13', constraint_type='equal', bit_positions=range(32), bit_values=[0]*6 + [2,0,2] + [0]*23) + sage: trail = sat.find_one_bitwise_impossible_xor_differential_trail(6, fixed_values=[plaintext, key, ciphertext]) + + # table 10 from https://eprint.iacr.org/2014/761.pdf + sage: from claasp.cipher_modules.models.utils import integer_to_bit_list, set_fixed_variables + sage: from claasp.ciphers.block_ciphers.simon_block_cipher import SimonBlockCipher + sage: simon = SimonBlockCipher(block_bit_size=48, key_bit_size=72, number_of_rounds=12) + sage: from claasp.cipher_modules.models.sat.sat_models.sat_bitwise_impossible_xor_differential_model import SatBitwiseImpossibleXorDifferentialModel + sage: sat = SatBitwiseImpossibleXorDifferentialModel(simon) + sage: plaintext = set_fixed_variables(component_id='plaintext', constraint_type='equal', bit_positions=range(48), bit_values=[0]*47 + [1]) + sage: key = set_fixed_variables(component_id='key', constraint_type='equal', bit_positions=range(72), bit_values=[0]*72) + sage: ciphertext = set_fixed_variables(component_id='cipher_output_11_12', constraint_type='equal', bit_positions=range(48), bit_values=[1]+[0]*16 + [2,0,0,0,2,2,2] + [0]*24) + sage: trail = sat.find_one_bitwise_impossible_xor_differential_trail(7, fixed_values=[plaintext, key, ciphertext]) + + # https://eprint.iacr.org/2016/490.pdf + # requires to comment the constraints ' '.join(incompatibility_ids) as we are considering half rounds not full rounds + sage: from claasp.cipher_modules.models.utils import integer_to_bit_list, set_fixed_variables + sage: from claasp.ciphers.permutations.ascon_sbox_sigma_permutation import AsconSboxSigmaPermutation + sage: ascon = AsconSboxSigmaPermutation(number_of_rounds=5) + sage: from claasp.cipher_modules.models.sat.sat_models.sat_bitwise_impossible_xor_differential_model import SatBitwiseImpossibleXorDifferentialModel + sage: sat = SatBitwiseImpossibleXorDifferentialModel(ascon) + sage: plaintext = set_fixed_variables(component_id='plaintext', constraint_type='equal', bit_positions=range(320), bit_values=[1] + [0]*191 + [1] + [0]*63 + [1] + [0]*63 ) + sage: P1 = set_fixed_variables(component_id='intermediate_output_0_71', constraint_type='equal', bit_positions=range(320), bit_values= [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 2, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + sage: P2 = set_fixed_variables(component_id='intermediate_output_1_71', constraint_type='equal', bit_positions=range(320), bit_values= [2, 2, 0, 2, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 0, 0, 2, 2, 0, 2, 2, 0, 0, 0, 0, 2, 0, 0, 2, 2, 0, 0, 0, 0, 2, 0, 2, 0, 2, 2, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 2, 0, 0, 2, 2, 0, 2, 0, 0, 2, 2, 0, 0, 2, 0, 0, 0, 2, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 2, 2, 0, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 2, 0, 2, 0, 0, 2, 2, 0, 2, 2, 2, 2, 0, 0, 2, 2, 0, 0, 2, 2, 2, 0, 0, 0, 2, 2, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 0, 2, 2, 0, 0, 0, 0, 2, 2, 0, 0, 2, 2, 0, 0, 2, 0, 2, 2, 2, 0, 2, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 2, 0, 0, 2, 0, 0, 0, 2, 0, 0, 2, 0, 0, 2, 0, 0, 0, 0, 0, 2, 2, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 0, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 2, 0, 0, 2, 0, 0]) + sage: P3 = set_fixed_variables(component_id='intermediate_output_2_71', constraint_type='equal', bit_positions=range(320), bit_values= [2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 0, 2, 2, 2, 2, 0, 2, 0, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 0, 2, 2, 2, 2, 0, 0, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 0, 2, 2, 2, 2, 0, 2, 0, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2]) + sage: P5 = set_fixed_variables(component_id='cipher_output_4_71', constraint_type='equal', bit_positions=range(320), bit_values= [0]*192 + [1] + [0]* 127) + sage: trail = sat.find_one_bitwise_impossible_xor_differential_trail(4, fixed_values=[plaintext, P1, P2, P3, P5]) + + """ + start = time.time() + if middle_round is None: + middle_round = self._cipher.number_of_rounds // 2 + assert middle_round < self._cipher.number_of_rounds + self._middle_round = middle_round + + self._forward_cipher = self._cipher.get_partial_cipher(0, middle_round - 1, keep_key_schedule=True) + backward_cipher = self._cipher.cipher_partial_inverse( + middle_round, self._cipher.number_of_rounds - 1, keep_key_schedule=False + ) + self._backward_cipher = backward_cipher.add_suffix_to_components( + "_backward", [backward_cipher.get_all_components_ids()[-1]] + ) + + self.build_bitwise_impossible_xor_differential_trail_model(fixed_variables=fixed_values) + + forward_output = [c for c in self._forward_cipher.get_all_components() if c.type == CIPHER_OUTPUT][0] + out_size, forward_out_ids_0, forward_out_ids_1 = forward_output._generate_output_double_ids() + backward_out_ids_0 = [ + "_".join(id_.split("_")[:-2] + ["backward"] + id_.split("_")[-2:]) for id_ in forward_out_ids_0 + ] + backward_out_ids_1 = [ + "_".join(id_.split("_")[:-2] + ["backward"] + id_.split("_")[-2:]) for id_ in forward_out_ids_1 + ] + end = time.time() + building_time = end - start + + incompatibility_ids = [f"incompatibility_{forward_output.id}_{i}" for i in range(out_size)] + + for i in range(out_size): + self._model_constraints.extend( + utils.incompatibility( + incompatibility_ids[i], + (forward_out_ids_0[i], forward_out_ids_1[i]), + (backward_out_ids_0[i], backward_out_ids_1[i]), + ) + ) + self._model_constraints.append(" ".join(incompatibility_ids)) + + solution = self.solve(IMPOSSIBLE_XOR_DIFFERENTIAL, solver_name=solver_name, options=options) + solution["building_time"] = building_time + + return solution + + def find_one_bitwise_impossible_xor_differential_trail_with_chosen_incompatible_components( + self, component_id_list, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT, options=None + ): + """ + Returns one bitwise impossible XOR differential trail. + + INPUTS: + + - ``solver_name`` -- *str*, the solver to call + - ``component_id_list`` -- **str**; the list of component ids for which the incompatibility occurs + - ``fixed_values`` -- *list of dict*, the variables to be fixed in + standard format (see :py:meth:`~GenericModel.set_fixed_variables`) + + EXAMPLES:: + + sage: from claasp.cipher_modules.models.utils import integer_to_bit_list, set_fixed_variables + sage: from claasp.ciphers.block_ciphers.simon_block_cipher import SimonBlockCipher + sage: simon = SimonBlockCipher(block_bit_size=32, number_of_rounds=11) + sage: from claasp.cipher_modules.models.sat.sat_models.sat_bitwise_impossible_xor_differential_model import SatBitwiseImpossibleXorDifferentialModel + sage: sat = SatBitwiseImpossibleXorDifferentialModel(simon) + sage: plaintext = set_fixed_variables(component_id='plaintext', constraint_type='equal', bit_positions=range(32), bit_values=[0]*31 + [1]) + sage: key = set_fixed_variables(component_id='key', constraint_type='equal', bit_positions=range(64), bit_values=[0]*64) + sage: ciphertext = set_fixed_variables(component_id='cipher_output_10_13', constraint_type='equal', bit_positions=range(32), bit_values=[0]*6 + [2,0,2] + [0]*23) + sage: trail = sat.find_one_bitwise_impossible_xor_differential_trail_with_chosen_incompatible_components(component_id_list=['intermediate_output_5_12'], fixed_values=[plaintext, key, ciphertext],solver_name='cryptominisat') + + sage: from claasp.cipher_modules.models.utils import integer_to_bit_list, set_fixed_variables + sage: from claasp.ciphers.permutations.ascon_sbox_sigma_permutation import AsconSboxSigmaPermutation + sage: ascon = AsconSboxSigmaPermutation(number_of_rounds=5) + sage: from claasp.cipher_modules.models.sat.sat_models.sat_bitwise_impossible_xor_differential_model import SatBitwiseImpossibleXorDifferentialModel + sage: sat = SatBitwiseImpossibleXorDifferentialModel(ascon) + sage: plaintext = set_fixed_variables(component_id='plaintext', constraint_type='equal', bit_positions=range(320), bit_values=[1] + [0]*191 + [1] + [0]*63 + [1] + [0]*63 ) + sage: P1 = set_fixed_variables(component_id='intermediate_output_0_71', constraint_type='equal', bit_positions=range(320), bit_values= [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 2, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + sage: P2 = set_fixed_variables(component_id='intermediate_output_1_71', constraint_type='equal', bit_positions=range(320), bit_values= [2, 2, 0, 2, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 0, 0, 2, 2, 0, 2, 2, 0, 0, 0, 0, 2, 0, 0, 2, 2, 0, 0, 0, 0, 2, 0, 2, 0, 2, 2, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 2, 0, 0, 2, 2, 0, 2, 0, 0, 2, 2, 0, 0, 2, 0, 0, 0, 2, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 2, 2, 0, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 2, 0, 2, 0, 0, 2, 2, 0, 2, 2, 2, 2, 0, 0, 2, 2, 0, 0, 2, 2, 2, 0, 0, 0, 2, 2, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 0, 2, 2, 0, 0, 0, 0, 2, 2, 0, 0, 2, 2, 0, 0, 2, 0, 2, 2, 2, 0, 2, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 2, 0, 0, 2, 0, 0, 0, 2, 0, 0, 2, 0, 0, 2, 0, 0, 0, 0, 0, 2, 2, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 0, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 2, 0, 0, 2, 0, 0]) + sage: P3 = set_fixed_variables(component_id='intermediate_output_2_71', constraint_type='equal', bit_positions=range(320), bit_values= [2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 0, 2, 2, 2, 2, 0, 2, 0, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 0, 2, 2, 2, 2, 0, 0, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 0, 2, 2, 2, 2, 0, 2, 0, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2]) + sage: P5 = set_fixed_variables(component_id='cipher_output_4_71', constraint_type='equal', bit_positions=range(320), bit_values= [0]*192 + [1] + [0]* 127) + sage: trail = sat.find_one_bitwise_impossible_xor_differential_trail_with_chosen_incompatible_components(component_id_list=["sbox_3_56"], fixed_values=[plaintext, P1, P2, P3, P5], solver_name='cryptominisat') #doctest: +SKIP + """ + start = time.time() + + if component_id_list is None: + return self.find_one_bitwise_impossible_xor_differential_trail( + middle_round=None, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT + ) + assert set(component_id_list) <= set(self._cipher.get_all_components_ids()) - set( + get_key_schedule_component_ids(self._cipher) + ) + + rounds = [self._cipher.get_round_from_component_id(cid) for cid in component_id_list] + assert len(set(rounds)) == 1, "All chosen components must be in the same round" + middle = rounds[0] + + if len(component_id_list) == 1: + comp = self._cipher.get_component_from_id(component_id_list[0]) + if comp.description == ["round_output"]: + return self.find_one_bitwise_impossible_xor_differential_trail(middle + 1, fixed_values, solver_name) + + assert middle < self._cipher.number_of_rounds - 1 + self._middle_round = middle + self._forward_cipher = self._cipher.get_partial_cipher(0, middle, keep_key_schedule=True) + backward_cipher = self._cipher.cipher_partial_inverse( + middle, self._cipher.number_of_rounds - 1, keep_key_schedule=False + ) + + self._incompatible_components = component_id_list + + suffix = "_backward" + self._backward_cipher = backward_cipher.add_suffix_to_components( + suffix, backward_cipher.get_all_components_ids() + ) + + self.build_bitwise_impossible_xor_differential_trail_model(fixed_variables=fixed_values) + + incompat_ids = [] + for cid in component_id_list: + fwd_comp = self._forward_cipher.get_component_from_id(cid) + out_size, fwd_out_ids_0, fwd_out_ids_1 = fwd_comp._generate_output_double_ids() + + backward_cid = cid + "_backward" + bwd_comp = self._backward_cipher.get_component_from_id(backward_cid) + bwd_in_ids_0, bwd_in_ids_1 = bwd_comp._generate_input_double_ids() + + for i in range(out_size): + inv_id = f"incompatibility_{cid}_{i}" + incompat_ids.append(inv_id) + self._model_constraints.extend( + utils.incompatibility( + inv_id, (fwd_out_ids_0[i], fwd_out_ids_1[i]), (bwd_in_ids_0[i], bwd_in_ids_1[i]) + ) + ) + + self._model_constraints.append(" ".join(incompat_ids)) + + solution = self.solve(IMPOSSIBLE_XOR_DIFFERENTIAL, solver_name=solver_name, options=options) + solution["building_time"] = time.time() - start + return solution + + def find_one_bitwise_impossible_xor_differential_trail_with_fully_automatic_model( + self, fixed_values=[], include_all_components=False, solver_name=solvers.SOLVER_DEFAULT, options=None + ): + """ + Returns one bitwise impossible XOR differential trail. + + INPUTS: + + - ``solver_name`` -- *str*, the solver to call + - ``fixed_values`` -- *list of dict*, the variables to be fixed in + standard format (see :py:meth:`~GenericModel.set_fixed_variables`) + - ``include_all_components`` -- **boolean** (default: `False`); when set to `True`, every component output can be + a source of incompatibility; otherwise, only round outputs are considered + + EXAMPLES:: + + sage: from claasp.cipher_modules.models.utils import integer_to_bit_list, set_fixed_variables + sage: from claasp.ciphers.block_ciphers.simon_block_cipher import SimonBlockCipher + sage: simon = SimonBlockCipher(block_bit_size=32, number_of_rounds=11) + sage: from claasp.cipher_modules.models.sat.sat_models.sat_bitwise_impossible_xor_differential_model import SatBitwiseImpossibleXorDifferentialModel + sage: sat = SatBitwiseImpossibleXorDifferentialModel(simon) + sage: plaintext = set_fixed_variables(component_id='plaintext', constraint_type='equal', bit_positions=range(32), bit_values=[0]*31 + [1]) + sage: key = set_fixed_variables(component_id='key', constraint_type='equal', bit_positions=range(64), bit_values=[0]*64) + sage: key_backward = set_fixed_variables(component_id='key_backward', constraint_type='equal', bit_positions=range(64), bit_values=[0]*64) + sage: ciphertext_backward = set_fixed_variables(component_id='cipher_output_10_13_backward', constraint_type='equal', bit_positions=range(32), bit_values=[0]*6 + [2,0,2] + [0]*23) + sage: trail = sat.find_one_bitwise_impossible_xor_differential_trail_with_fully_automatic_model(fixed_values=[plaintext, key, key_backward, ciphertext_backward]) + + + sage: from claasp.cipher_modules.models.utils import integer_to_bit_list, set_fixed_variables + sage: from claasp.ciphers.permutations.ascon_sbox_sigma_permutation import AsconSboxSigmaPermutation + sage: ascon = AsconSboxSigmaPermutation(number_of_rounds=5) + sage: from claasp.cipher_modules.models.sat.sat_models.sat_bitwise_impossible_xor_differential_model import SatBitwiseImpossibleXorDifferentialModel + sage: sat = SatBitwiseImpossibleXorDifferentialModel(ascon) + sage: P = set_fixed_variables(component_id='plaintext', constraint_type='equal', bit_positions=range(320), bit_values= [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] ) + sage: trail = sat.find_one_bitwise_impossible_xor_differential_trail_with_fully_automatic_model(fixed_values=[P]) + """ + + start = time.time() + self._forward_cipher = self._cipher + self._backward_cipher = self._cipher.cipher_inverse().add_suffix_to_components("_backward") + + self.build_bitwise_impossible_xor_differential_trail_model(fixed_variables=fixed_values) + + backward_components = [] + forward_output = [c for c in self._forward_cipher.get_all_components() if c.type == CIPHER_OUTPUT][0] + forward_output_id = forward_output.id + "_backward" + + for comp in self._backward_cipher.get_all_components(): + if comp.description == ["round_output"]: + if set(comp.input_id_links) == {forward_output_id}: + continue + backward_components.append(comp) + + if include_all_components: + key_flow = set(get_key_schedule_component_ids(self._cipher)) + backward_key_ids = {f"{k_id}_backward" for k_id in key_flow} + backward_components = [ + c for c in self._backward_cipher.get_all_components() if c.id not in backward_key_ids + ] + + incompat_ids = [] + + for comp in backward_components: + comp_id = comp.id + try: + fwd_comp = self._forward_cipher.get_component_from_id(comp_id.replace("_backward", "")) + except ValueError: + # Skip this backward component because we can't map it to a forward component (es: plaintext_backward). + continue + + _, fwd_out_ids_0, fwd_out_ids_1 = fwd_comp._generate_output_double_ids() + forward_pairs = list(zip(fwd_out_ids_0, fwd_out_ids_1)) + + if include_all_components: + bwd_in_ids_0, bwd_in_ids_1 = comp._generate_input_double_ids() + + inputs_to_be_kept = [] + unique_input_bases = ["_".join(i.split("_")[:-1]) for i in set(comp.input_id_links)] + for input_base in unique_input_bases: + if INPUT_KEY not in input_base: + try: + input_comp = self._cipher.get_component_from_id(input_base) + except ValueError: + continue + linked_backward_ids = [link + "_backward" for link in input_comp.input_id_links] + if linked_backward_ids == [comp_id]: + inputs_to_be_kept.extend([_ for _ in bwd_in_ids_0 + bwd_in_ids_1 if input_base in _]) + + if inputs_to_be_kept: + bwd_ids_filtered_0 = [id_ for id_ in bwd_in_ids_0 if id_ in inputs_to_be_kept] + bwd_ids_filtered_1 = [id_ for id_ in bwd_in_ids_1 if id_ in inputs_to_be_kept] + else: + bwd_ids_filtered_0 = bwd_in_ids_0 + bwd_ids_filtered_1 = bwd_in_ids_1 + + backward_pairs = list(zip(bwd_ids_filtered_0, bwd_ids_filtered_1)) + + else: + bwd_out_ids_0 = [ + "_".join(id_.split("_")[:-2] + ["backward"] + id_.split("_")[-2:]) for id_ in fwd_out_ids_0 + ] + bwd_out_ids_1 = [ + "_".join(id_.split("_")[:-2] + ["backward"] + id_.split("_")[-2:]) for id_ in fwd_out_ids_1 + ] + backward_pairs = list(zip(bwd_out_ids_0, bwd_out_ids_1)) + + for i, (fwd_pair, bwd_pair) in enumerate(zip(forward_pairs, backward_pairs)): + inv_id = f"incompatibility_{fwd_comp.id}_{i}" + incompat_ids.append(inv_id) + self._model_constraints.extend(utils.incompatibility(inv_id, fwd_pair, bwd_pair)) + if incompat_ids: + self._model_constraints.append(" ".join(incompat_ids)) + + solution = self.solve(IMPOSSIBLE_XOR_DIFFERENTIAL, solver_name=solver_name, options=options) + solution["building_time"] = time.time() - start + + return solution + + def _parse_solver_output(self, variable2value): + active_incompatibilities = [ + var for var, val in variable2value.items() if var.startswith("incompatibility_") and val == 1 + ] + + incompatible_components = set() + for var in active_incompatibilities: + parts = var.split("_") + comp_id = "_".join(parts[1:-1]) + incompatible_components.add(comp_id) + + components_solutions = self._get_cipher_inputs_components_solutions_double_ids(variable2value) + + incompatible_rounds = {} + for comp_id in incompatible_components: + round_num = self._cipher.get_round_from_component_id(comp_id) + incompatible_rounds.setdefault(round_num, set()).add(comp_id) + + start_backward = False + if self._incompatible_components is not None: + all_backward_input_ids = set() + for i_component in self._incompatible_components: + backward_component = self._backward_cipher.get_component_from_id(i_component + "_backward") + input_ids = backward_component.input_id_links + all_backward_input_ids.update(input_ids) + + for component in self._cipher.get_all_components(): + comp_id = component.id + comp_round = self._cipher.get_round_from_component_id(comp_id) + if self._forward_cipher == self._cipher: + if comp_round in incompatible_rounds and comp_id in incompatible_rounds[comp_round]: + fwd = self._get_component_value_from_cipher(component, variable2value, "forward") + bwd = self._get_component_value_from_cipher(component, variable2value, "backward") + components_solutions[comp_id] = set_component_solution(fwd) + components_solutions[comp_id + "_backward"] = set_component_solution(bwd) + start_backward = True + elif start_backward: + bwd = self._get_component_value_from_cipher(component, variable2value, "backward") + components_solutions[comp_id + "_backward"] = set_component_solution(bwd) + else: + value = self._get_component_value_double_ids(component, variable2value) + components_solutions[comp_id] = set_component_solution(value) + elif self._incompatible_components != None: + if comp_id + "_backward" in all_backward_input_ids: + bwd = self._get_component_value_from_cipher(component, variable2value, "backward") + components_solutions[comp_id + "_backward"] = set_component_solution(bwd) + else: + value = self._get_component_value_double_ids(component, variable2value) + components_solutions[comp_id] = set_component_solution(value) + else: + if comp_round in incompatible_rounds and comp_id in incompatible_rounds[comp_round]: + fwd = self._get_component_value_from_cipher(component, variable2value, "forward") + bwd = self._get_component_value_from_cipher(component, variable2value, "backward") + components_solutions[comp_id] = set_component_solution(fwd) + components_solutions[comp_id + "_backward"] = set_component_solution(bwd) + else: + value = self._get_component_value_double_ids(component, variable2value) + components_solutions[comp_id] = set_component_solution(value) + + return components_solutions, None + + def _get_component_value_from_cipher(self, component, variable2value, cipher_type): + if cipher_type == "forward": + forward_component = self._forward_cipher.get_component_from_id(component.id) + return self._get_component_value_double_ids(forward_component, variable2value) + + if cipher_type == "backward": + backward_id = f"{component.id}_backward" + values = [] + for i in range(component.output_bit_size): + variable_value = 0 + if f"{backward_id}_{i}_0" in variable2value: + variable_value ^= variable2value[f"{backward_id}_{i}_0"] << 1 + if f"{backward_id}_{i}_1" in variable2value: + variable_value ^= variable2value[f"{backward_id}_{i}_1"] + values.append(f"{variable_value}") + backward_component_value = "".join(values).replace("2", "?").replace("3", "?") + return backward_component_value + + return self._get_component_value_double_ids(component, variable2value) diff --git a/claasp/cipher_modules/models/sat/sat_models/sat_cipher_model.py b/claasp/cipher_modules/models/sat/sat_models/sat_cipher_model.py index d55f011e2..b04f22dd4 100644 --- a/claasp/cipher_modules/models/sat/sat_models/sat_cipher_model.py +++ b/claasp/cipher_modules/models/sat/sat_models/sat_cipher_model.py @@ -1,33 +1,39 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** - import time from claasp.cipher_modules.models.sat import solvers from claasp.cipher_modules.models.sat.sat_model import SatModel from claasp.cipher_modules.models.utils import set_component_solution -from claasp.name_mappings import (CIPHER, WORD_OPERATION, CIPHER_OUTPUT, CONSTANT, INTERMEDIATE_OUTPUT, LINEAR_LAYER, - MIX_COLUMN, SBOX) +from claasp.name_mappings import ( + CIPHER_OUTPUT, + CIPHER, + CONSTANT, + INTERMEDIATE_OUTPUT, + LINEAR_LAYER, + MIX_COLUMN, + SBOX, + WORD_OPERATION, +) class SatCipherModel(SatModel): - def __init__(self, cipher, counter='sequential', compact=False): + def __init__(self, cipher, counter="sequential", compact=False): super().__init__(cipher, counter, compact) def build_cipher_model(self, fixed_variables=[]): @@ -54,14 +60,15 @@ def build_cipher_model(self, fixed_variables=[]): constraints = SatModel.fix_variables_value_constraints(fixed_variables) self._variables_list = [] self._model_constraints = constraints - component_types = [CIPHER_OUTPUT, CONSTANT, INTERMEDIATE_OUTPUT, LINEAR_LAYER, MIX_COLUMN, SBOX, WORD_OPERATION] - operation_types = ['AND', 'MODADD', 'MODSUB', 'NOT', 'OR', 'ROTATE', 'SHIFT', 'SHIFT_BY_VARIABLE_AMOUNT', 'XOR'] + component_types = (CIPHER_OUTPUT, CONSTANT, INTERMEDIATE_OUTPUT, LINEAR_LAYER, MIX_COLUMN, SBOX, WORD_OPERATION) + operation_types = ("AND", "MODADD", "MODSUB", "NOT", "OR", "ROTATE", "SHIFT", "SHIFT_BY_VARIABLE_AMOUNT", "XOR") for component in self._cipher.get_all_components(): operation = component.description[0] if component.type not in component_types or ( - WORD_OPERATION == component.type and operation not in operation_types): - print(f'{component.id} not yet implemented') + WORD_OPERATION == component.type and operation not in operation_types + ): + print(f"{component.id} not yet implemented") else: variables, constraints = component.sat_constraints() @@ -73,19 +80,20 @@ def build_generic_sat_model_from_dictionary(self, fixed_variables, component_and constraints = SatModel.fix_variables_value_constraints(fixed_variables) self._variables_list = [] self._model_constraints = constraints - component_types = [CIPHER_OUTPUT, CONSTANT, INTERMEDIATE_OUTPUT, LINEAR_LAYER, MIX_COLUMN, SBOX, WORD_OPERATION] - operation_types = ['AND', 'MODADD', 'MODSUB', 'NOT', 'OR', 'ROTATE', 'SHIFT', 'SHIFT_BY_VARIABLE_AMOUNT', 'XOR'] + component_types = (CIPHER_OUTPUT, CONSTANT, INTERMEDIATE_OUTPUT, LINEAR_LAYER, MIX_COLUMN, SBOX, WORD_OPERATION) + operation_types = ("AND", "MODADD", "MODSUB", "NOT", "OR", "ROTATE", "SHIFT", "SHIFT_BY_VARIABLE_AMOUNT", "XOR") for component_and_model_type in component_and_model_types: component = component_and_model_type["component_object"] model_type = component_and_model_type["model_type"] operation = component.description[0] if component.type not in component_types or ( - WORD_OPERATION == component.type and operation not in operation_types): - print(f'{component.id} not yet implemented') + WORD_OPERATION == component.type and operation not in operation_types + ): + print(f"{component.id} not yet implemented") else: sat_xor_differential_propagation_constraints = getattr(component, model_type) - if model_type == 'sat_bitwise_deterministic_truncated_xor_differential_constraints': + if model_type == "sat_bitwise_deterministic_truncated_xor_differential_constraints": variables, constraints = sat_xor_differential_propagation_constraints() else: variables, constraints = sat_xor_differential_propagation_constraints(self) @@ -93,7 +101,7 @@ def build_generic_sat_model_from_dictionary(self, fixed_variables, component_and self._model_constraints.extend(constraints) self._variables_list.extend(variables) - def find_missing_bits(self, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT): + def find_missing_bits(self, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT, options=None): """ Return the solution representing a generic flow of the cipher from plaintext and key to ciphertext. @@ -113,32 +121,27 @@ def find_missing_bits(self, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT) sage: from claasp.cipher_modules.models.sat.sat_models.sat_cipher_model import SatCipherModel sage: sat = SatCipherModel(speck) sage: from claasp.cipher_modules.models.utils import set_fixed_variables, integer_to_bit_list + sage: ciphertext_id = speck.get_all_components_ids()[-1] sage: ciphertext = set_fixed_variables( - ....: component_id=speck.get_all_components_ids()[-1], - ....: constraint_type='equal', - ....: bit_positions=range(32), - ....: bit_values=integer_to_bit_list(endianness='big', list_length=32, int_value=0xaffec7ed)) - sage: sat.find_missing_bits(fixed_values=[ciphertext]) # random - {'cipher_id': 'speck_p32_k64_o32_r22', - 'model_type': 'cipher', - 'solver_name': 'CRYPTOMINISAT_EXT', - ... - 'intermediate_output_21_11': {'value': '1411'}, - 'cipher_output_21_12': {'value': 'affec7ed'}}, - 'total_weight': None, - 'status': 'SATISFIABLE', - 'building_time_seconds': 0.019376516342163086} + ....: component_id=ciphertext_id, + ....: constraint_type="equal", + ....: bit_positions=range(32), + ....: bit_values=integer_to_bit_list(endianness="big", list_length=32, int_value=0xaffec7ed) + ....: ) + sage: trail = sat.find_missing_bits(fixed_values=[ciphertext]) + sage: trail["components_values"][ciphertext_id]["value"] + '0xaffec7ed' """ start_building_time = time.time() self.build_cipher_model(fixed_variables=fixed_values) end_building_time = time.time() - solution = self.solve(CIPHER, solver_name=solver_name) - solution['building_time_seconds'] = end_building_time - start_building_time + solution = self.solve(CIPHER, solver_name=solver_name, options=options) + solution["building_time_seconds"] = end_building_time - start_building_time return solution def _parse_solver_output(self, variable2value): - out_suffix = '' + out_suffix = "" components_solutions = self._get_cipher_inputs_components_solutions(out_suffix, variable2value) for component in self._cipher.get_all_components(): hex_value = self._get_component_hex_value(component, out_suffix, variable2value) diff --git a/claasp/cipher_modules/models/sat/sat_models/sat_differential_linear_model.py b/claasp/cipher_modules/models/sat/sat_models/sat_differential_linear_model.py index c53975e04..6fa6d5a10 100644 --- a/claasp/cipher_modules/models/sat/sat_models/sat_differential_linear_model.py +++ b/claasp/cipher_modules/models/sat/sat_models/sat_differential_linear_model.py @@ -1,19 +1,39 @@ +# **************************************************************************** +# Copyright 2023 Technology Innovation Institute +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# **************************************************************************** + import time from claasp.cipher_modules.models.sat import solvers from claasp.cipher_modules.models.sat.sat_model import SatModel from claasp.cipher_modules.models.sat.sat_models.sat_bitwise_deterministic_truncated_xor_differential_model import ( - SatBitwiseDeterministicTruncatedXorDifferentialModel + SatBitwiseDeterministicTruncatedXorDifferentialModel, +) +from claasp.cipher_modules.models.sat.sat_models.sat_semi_deterministic_truncated_xor_differential_model import ( + SatSemiDeterministicTruncatedXorDifferentialModel, ) -from claasp.cipher_modules.models.sat.sat_models.sat_semi_deterministic_truncated_xor_differential_model import \ - SatSemiDeterministicTruncatedXorDifferentialModel from claasp.cipher_modules.models.sat.sat_models.sat_xor_linear_model import SatXorLinearModel from claasp.cipher_modules.models.sat.utils import utils as sat_utils, constants -from claasp.cipher_modules.models.sat.utils.utils import _generate_component_model_types, \ - _update_component_model_types_for_truncated_components, _update_component_model_types_for_linear_components +from claasp.cipher_modules.models.sat.utils.utils import ( + _generate_component_model_types, + _update_component_model_types_for_truncated_components, + _update_component_model_types_for_linear_components, +) from claasp.cipher_modules.models.utils import set_component_solution, get_bit_bindings -from claasp.ciphers.block_ciphers.threefish_block_cipher import INPUT_TWEAK -from claasp.name_mappings import INPUT_KEY, INPUT_PLAINTEXT +from claasp.name_mappings import INPUT_KEY, INPUT_PLAINTEXT, INPUT_TWEAK class SatDifferentialLinearModel(SatModel): @@ -23,10 +43,10 @@ class SatDifferentialLinearModel(SatModel): """ def __init__( - self, - cipher, - list_of_components, - middle_part_model="sat_bitwise_deterministic_truncated_xor_differential_constraints" + self, + cipher, + list_of_components, + middle_part_model="sat_bitwise_deterministic_truncated_xor_differential_constraints", ): """ Initializes the model with cipher and components. @@ -52,31 +72,33 @@ def __init__( _update_component_model_types_for_linear_components(component_model_types, bottom_part_component_ids) self.dict_of_components = component_model_types - self.regular_components = self._get_components_by_type('sat_xor_differential_propagation_constraints') + self.regular_components = self._get_components_by_type("sat_xor_differential_propagation_constraints") - model_types = set(component['model_type'] for component in self.dict_of_components) + model_types = {component["model_type"] for component in self.dict_of_components} truncated_model_types = [ - item for item in model_types if - item != 'sat_xor_differential_propagation_constraints' and item != 'sat_xor_linear_mask_propagation_constraints' + item + for item in model_types + if item != "sat_xor_differential_propagation_constraints" + and item != "sat_xor_linear_mask_propagation_constraints" ] allow_truncated_models_types = [ - 'sat_semi_deterministic_truncated_xor_differential_constraints', - 'sat_bitwise_deterministic_truncated_xor_differential_constraints' + "sat_semi_deterministic_truncated_xor_differential_constraints", + "sat_bitwise_deterministic_truncated_xor_differential_constraints", ] - if len(truncated_model_types + allow_truncated_models_types) == 0 or len( - truncated_model_types + allow_truncated_models_types) == 2: - + if ( + len(truncated_model_types + allow_truncated_models_types) == 0 + or len(truncated_model_types + allow_truncated_models_types) == 2 + ): raise ValueError(f"Model types should be one of {allow_truncated_models_types}") self.truncated_model_type = truncated_model_types[0] self.truncated_components = self._get_components_by_type(self.truncated_model_type) - self.linear_components = self._get_components_by_type( - 'sat_xor_linear_mask_propagation_constraints') - self.bit_bindings, self.bit_bindings_for_intermediate_output = get_bit_bindings(cipher, '_'.join) + self.linear_components = self._get_components_by_type("sat_xor_linear_mask_propagation_constraints") + self.bit_bindings, self.bit_bindings_for_intermediate_output = get_bit_bindings(cipher, "_".join) super().__init__(cipher, "sequential", False) def _get_components_by_type(self, model_type): @@ -89,7 +111,7 @@ def _get_components_by_type(self, model_type): RETURN: - **list**; A list of components of the specified type. """ - return [component for component in self.dict_of_components if component['model_type'] == model_type] + return [component for component in self.dict_of_components if component["model_type"] == model_type] def _get_regular_xor_differential_components_in_border(self): """ @@ -98,11 +120,11 @@ def _get_regular_xor_differential_components_in_border(self): RETURN: - **list**; A list of regular components at the border. """ - regular_component_ids = {item['component_id'] for item in self.regular_components} + regular_component_ids = {item["component_id"] for item in self.regular_components} border_components = [] for truncated_component in self.truncated_components: - component_obj = self.cipher.get_component_from_id(truncated_component['component_id']) + component_obj = self.cipher.get_component_from_id(truncated_component["component_id"]) for input_id in component_obj.input_id_links: if input_id in regular_component_ids: border_components.append(input_id) @@ -116,10 +138,10 @@ def _get_truncated_xor_differential_components_in_border(self): RETURN: - **list**; A list of truncated components at the border. """ - truncated_component_ids = {item['component_id'] for item in self.truncated_components} + truncated_component_ids = {item["component_id"] for item in self.truncated_components} border_components = [] for linear_component in self.linear_components: - component_obj = self.cipher.get_component_from_id(linear_component['component_id']) + component_obj = self.cipher.get_component_from_id(linear_component["component_id"]) for input_id in component_obj.input_id_links: if input_id in truncated_component_ids: border_components.append(input_id) @@ -133,7 +155,7 @@ def _get_connecting_constraints(self): def get_component_output_bit_size(component_identifier): component_output_bit_size = 0 - if component_identifier not in [INPUT_KEY, INPUT_PLAINTEXT, INPUT_TWEAK]: + if component_identifier not in (INPUT_KEY, INPUT_PLAINTEXT, INPUT_TWEAK): component = self.cipher.get_component_from_id(component_identifier) component_output_bit_size = component.output_bit_size else: @@ -153,32 +175,32 @@ def is_any_string_in_list_substring_of_string(string, string_list): output_bit_size = get_component_output_bit_size(component_id) for idx in range(output_bit_size): constraints = sat_utils.get_cnf_bitwise_truncate_constraints( - f'{component_id}_{idx}', f'{component_id}_{idx}_0', f'{component_id}_{idx}_1' + f"{component_id}_{idx}", f"{component_id}_{idx}_0", f"{component_id}_{idx}_1" ) self._model_constraints.extend(constraints) - self._variables_list.extend([ - f'{component_id}_{idx}', f'{component_id}_{idx}_0', f'{component_id}_{idx}_1' - ]) + self._variables_list.extend( + [f"{component_id}_{idx}", f"{component_id}_{idx}_0", f"{component_id}_{idx}_1"] + ) border_components = self._get_truncated_xor_differential_components_in_border() - linear_component_ids = [item['component_id'] for item in self.linear_components] + linear_component_ids = [item["component_id"] for item in self.linear_components] for component_id in border_components: component = self.cipher.get_component_from_id(component_id) for idx in range(component.output_bit_size): - truncated_component = f'{component_id}_{idx}_o' + truncated_component = f"{component_id}_{idx}_o" component_successors = self.bit_bindings[truncated_component] for component_successor in component_successors: length_component_successor = len(component_successor) - component_successor_id = component_successor[:length_component_successor-2] + component_successor_id = component_successor[: length_component_successor - 2] if is_any_string_in_list_substring_of_string(component_successor_id, linear_component_ids): constraints = sat_utils.get_cnf_truncated_linear_constraints( - component_successor, f'{component_id}_{idx}_0' + component_successor, f"{component_id}_{idx}_0" ) self._model_constraints.extend(constraints) - self._variables_list.extend([component_successor, f'{component_id}_{idx}_0']) + self._variables_list.extend([component_successor, f"{component_id}_{idx}_0"]) def _build_weight_constraints(self, weight): """ @@ -191,7 +213,7 @@ def _build_weight_constraints(self, weight): - **tuple**; A tuple containing a list of variables and a list of constraints. """ - hw_variables = [var_id for var_id in self._variables_list if var_id.startswith('hw_')] + hw_variables = [var_id for var_id in self._variables_list if var_id.startswith("hw_")] linear_component_ids = [linear_component["component_id"] for linear_component in self.linear_components] hw_linear_variables = [] @@ -201,7 +223,7 @@ def _build_weight_constraints(self, weight): hw_linear_variables.append(hw_variable) hw_variables.extend(hw_linear_variables) if weight == 0: - return [], [f'-{var}' for var in hw_variables] + return [], [f"-{var}" for var in hw_variables] return self._counter(hw_variables, weight) @@ -225,10 +247,10 @@ def _build_unknown_variable_constraints(self, num_unknowns): return self._sequential_counter(minimize_vars, num_unknowns, "dummy_id_unknown") def build_xor_differential_linear_model( - self, - weight=-1, - num_unknown_vars=None, - unknown_window_size_configuration=None, + self, + weight=-1, + num_unknown_vars=None, + unknown_window_size_configuration=None, ): """ Constructs a model to search for differential-linear trails. @@ -280,7 +302,8 @@ def build_xor_differential_linear_model( SatSemiDeterministicTruncatedXorDifferentialModel.unknown_window_size_configuration_constraints( unknown_window_size_configuration, variables_list=self._variables_list, - cardinality_constraint_method=self._counter) + cardinality_constraint_method=self._counter, + ) ) self._variables_list.extend(variables) self._model_constraints.extend(constraints) @@ -289,7 +312,8 @@ def build_xor_differential_linear_model( @staticmethod def fix_variables_value_constraints( - fixed_variables, regular_components=None, truncated_components=None, linear_components=None): + fixed_variables, regular_components=None, truncated_components=None, linear_components=None + ): """ Imposes fixed value constraints on variables within differential, truncated, and linear components. @@ -309,7 +333,7 @@ def fix_variables_value_constraints( for var in fixed_variables: component_id = var["component_id"] - if component_id in [comp["component_id"] for comp in regular_components] and 2 in var['bit_values']: + if component_id in [comp["component_id"] for comp in regular_components] and 2 in var["bit_values"]: raise ValueError("The fixed value in a regular XOR differential component cannot be 2") if component_id in [comp["component_id"] for comp in truncated_components]: @@ -323,7 +347,8 @@ def fix_variables_value_constraints( regular_constraints = SatModel.fix_variables_value_constraints(regular_vars) truncated_constraints = SatBitwiseDeterministicTruncatedXorDifferentialModel.fix_variables_value_constraints( - truncated_vars) + truncated_vars + ) linear_constraints = SatXorLinearModel.fix_variables_value_xor_linear_constraints(linear_vars) return regular_constraints + truncated_constraints + linear_constraints @@ -340,22 +365,22 @@ def _parse_solver_output(self, variable2value): RETURN: - **tuple**; a tuple containing the dictionary of component solutions and the total weight. """ - components_solutions = self._get_cipher_inputs_components_solutions('', variable2value) + components_solutions = self._get_cipher_inputs_components_solutions("", variable2value) total_weight_diff = 0 total_weight_lin = 0 for component in self._cipher.get_all_components(): - if component.id in [d['component_id'] for d in self.regular_components]: - hex_value = self._get_component_hex_value(component, '', variable2value) - weight = self.calculate_component_weight(component, '', variable2value) + if component.id in [d["component_id"] for d in self.regular_components]: + hex_value = self._get_component_hex_value(component, "", variable2value) + weight = self.calculate_component_weight(component, "", variable2value) components_solutions[component.id] = set_component_solution(hex_value, weight) total_weight_diff += weight - elif component.id in [d['component_id'] for d in self.truncated_components]: + elif component.id in [d["component_id"] for d in self.truncated_components]: value = self._get_component_value_double_ids(component, variable2value) components_solutions[component.id] = set_component_solution(value, weight=0) - elif component.id in [d['component_id'] for d in self.linear_components]: + elif component.id in [d["component_id"] for d in self.linear_components]: hex_value = self._get_component_hex_value(component, constants.OUTPUT_BIT_ID_SUFFIX, variable2value) weight = self.calculate_component_weight(component, constants.OUTPUT_BIT_ID_SUFFIX, variable2value) total_weight_lin += weight @@ -364,12 +389,13 @@ def _parse_solver_output(self, variable2value): return components_solutions, total_weight_diff + 2 * total_weight_lin def find_one_differential_linear_trail_with_fixed_weight( - self, - weight, - num_unknown_vars=None, - fixed_values=[], - solver_name=solvers.SOLVER_DEFAULT, - unknown_window_size_configuration=None + self, + weight, + num_unknown_vars=None, + fixed_values=[], + solver_name=solvers.SOLVER_DEFAULT, + unknown_window_size_configuration=None, + options=None, ): """ Finds one XOR differential-linear trail with a fixed weight. The weight must be the sum of the probability weight @@ -436,21 +462,18 @@ def find_one_differential_linear_trail_with_fixed_weight( self.build_xor_differential_linear_model(weight, num_unknown_vars, unknown_window_size_configuration) constraints = self.fix_variables_value_constraints( - fixed_values, - self.regular_components, - self.truncated_components, - self.linear_components + fixed_values, self.regular_components, self.truncated_components, self.linear_components ) self.model_constraints.extend(constraints) - solution = self.solve("XOR_DIFFERENTIAL_LINEAR_MODEL", solver_name=solver_name) - solution['building_time_seconds'] = time.time() - start_time - solution['test_name'] = "find_one_differential_linear_trail" + solution = self.solve("XOR_DIFFERENTIAL_LINEAR_MODEL", solver_name=solver_name, options=options) + solution["building_time_seconds"] = time.time() - start_time + solution["test_name"] = "find_one_differential_linear_trail" return solution def find_lowest_weight_xor_differential_linear_trail( - self, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT, num_unknown_vars=1 + self, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT, num_unknown_vars=1, options=None ): """ Finds the differential-linear trail with the lowest weight. @@ -467,28 +490,25 @@ def find_lowest_weight_xor_differential_linear_trail( start_building_time = time.time() self.build_xor_differential_linear_model(current_weight, num_unknown_vars) constraints = self.fix_variables_value_constraints( - fixed_values, - self.regular_components, - self.truncated_components, - self.linear_components + fixed_values, self.regular_components, self.truncated_components, self.linear_components ) self.model_constraints.extend(constraints) end_building_time = time.time() - solution = self.solve("XOR_DIFFERENTIAL_LINEAR_MODEL", solver_name=solver_name) - solution['building_time_seconds'] = end_building_time - start_building_time - total_time = solution['solving_time_seconds'] - max_memory = solution['memory_megabytes'] - while solution['total_weight'] is None: + solution = self.solve("XOR_DIFFERENTIAL_LINEAR_MODEL", solver_name=solver_name, options=options) + solution["building_time_seconds"] = end_building_time - start_building_time + total_time = solution["solving_time_seconds"] + max_memory = solution["memory_megabytes"] + while solution["total_weight"] is None: current_weight += 1 self.build_xor_differential_linear_model(current_weight, num_unknown_vars) self.model_constraints.extend(constraints) - solution = self.solve("XOR_DIFFERENTIAL_LINEAR_MODEL", solver_name=solver_name) - total_time += solution['solving_time_seconds'] - max_memory = max(max_memory, solution['memory_megabytes']) + solution = self.solve("XOR_DIFFERENTIAL_LINEAR_MODEL", solver_name=solver_name, options=options) + total_time += solution["solving_time_seconds"] + max_memory = max(max_memory, solution["memory_megabytes"]) - solution['solving_time_seconds'] = total_time - solution['memory_megabytes'] = max_memory - solution['test_name'] = "find_lowest_weight_differential_linear_trail" + solution["solving_time_seconds"] = total_time + solution["memory_megabytes"] = max_memory + solution["test_name"] = "find_lowest_weight_differential_linear_trail" return solution diff --git a/claasp/cipher_modules/models/sat/sat_models/sat_probabilistic_xor_truncated_differential_model.py b/claasp/cipher_modules/models/sat/sat_models/sat_probabilistic_xor_truncated_differential_model.py index e0bcde6b7..4e8fb18fd 100644 --- a/claasp/cipher_modules/models/sat/sat_models/sat_probabilistic_xor_truncated_differential_model.py +++ b/claasp/cipher_modules/models/sat/sat_models/sat_probabilistic_xor_truncated_differential_model.py @@ -1,11 +1,30 @@ +# **************************************************************************** +# Copyright 2023 Technology Innovation Institute +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# **************************************************************************** + import time from claasp.cipher_modules.models.sat import solvers from claasp.cipher_modules.models.sat.sat_model import SatModel -from claasp.cipher_modules.models.sat.sat_models.sat_semi_deterministic_truncated_xor_differential_model import \ - SatSemiDeterministicTruncatedXorDifferentialModel -from claasp.cipher_modules.models.sat.sat_models.sat_truncated_xor_differential_model import \ - SatTruncatedXorDifferentialModel +from claasp.cipher_modules.models.sat.sat_models.sat_semi_deterministic_truncated_xor_differential_model import ( + SatSemiDeterministicTruncatedXorDifferentialModel, +) +from claasp.cipher_modules.models.sat.sat_models.sat_truncated_xor_differential_model import ( + SatTruncatedXorDifferentialModel, +) from claasp.cipher_modules.models.sat.sat_models.sat_xor_differential_model import SatXorDifferentialModel from claasp.cipher_modules.models.sat.utils import utils as sat_utils from claasp.cipher_modules.models.utils import set_component_solution @@ -25,24 +44,26 @@ def __init__(self, cipher, dict_of_components): - ``dict_of_components`` -- **dict**; Dictionary mapping component IDs to their respective models and types. """ self.dict_of_components = dict_of_components - model_types = set(component['model_type'] for component in self.dict_of_components) + model_types = {component["model_type"] for component in self.dict_of_components} - truncated_model_types = {item for item in model_types if item != 'sat_xor_differential_propagation_constraints'} + truncated_model_types = {item for item in model_types if item != "sat_xor_differential_propagation_constraints"} allow_truncated_models_types = { - 'sat_semi_deterministic_truncated_xor_differential_constraints', - 'sat_bitwise_deterministic_truncated_xor_differential_constraints' + "sat_semi_deterministic_truncated_xor_differential_constraints", + "sat_bitwise_deterministic_truncated_xor_differential_constraints", } - if len(truncated_model_types & allow_truncated_models_types) == 0 or len( - truncated_model_types & allow_truncated_models_types) == 2: + if ( + len(truncated_model_types & allow_truncated_models_types) == 0 + or len(truncated_model_types & allow_truncated_models_types) == 2 + ): raise ValueError(f"Model types should be one of {allow_truncated_models_types}") truncated_model_type = list(truncated_model_types)[0] self.truncated_model_type = truncated_model_type - self.regular_components = self._get_components_by_type('sat_xor_differential_propagation_constraints') + self.regular_components = self._get_components_by_type("sat_xor_differential_propagation_constraints") self.truncated_components = self._get_components_by_type(truncated_model_type) - super().__init__(cipher, "sequential", False) + super().__init__(cipher=cipher, counter="sequential", compact=False) def _get_components_by_type(self, model_type): """ @@ -54,7 +75,7 @@ def _get_components_by_type(self, model_type): RETURN: - **list**; A list of components of the specified type. """ - return [component for component in self.dict_of_components if component['model_type'] == model_type] + return [component for component in self.dict_of_components if component["model_type"] == model_type] def _get_regular_xor_differential_components_in_border(self): """ @@ -63,11 +84,11 @@ def _get_regular_xor_differential_components_in_border(self): RETURN: - **list**; A list of regular components at the border. """ - regular_component_ids = {item['component_id'] for item in self.regular_components} + regular_component_ids = {item["component_id"] for item in self.regular_components} border_components = [] for truncated_component in self.truncated_components: - component_obj = self.cipher.get_component_from_id(truncated_component['component_id']) + component_obj = self.cipher.get_component_from_id(truncated_component["component_id"]) for input_id in component_obj.input_id_links: if input_id in regular_component_ids: border_components.append(input_id) @@ -84,12 +105,12 @@ def _get_connecting_constraints(self): component = self.cipher.get_component_from_id(component_id) for idx in range(component.output_bit_size): constraints = sat_utils.get_cnf_bitwise_truncate_constraints( - f'{component_id}_{idx}', f'{component_id}_{idx}_0', f'{component_id}_{idx}_1' + f"{component_id}_{idx}", f"{component_id}_{idx}_0", f"{component_id}_{idx}_1" ) self._model_constraints.extend(constraints) - self._variables_list.extend([ - f'{component_id}_{idx}', f'{component_id}_{idx}_0', f'{component_id}_{idx}_1' - ]) + self._variables_list.extend( + [f"{component_id}_{idx}", f"{component_id}_{idx}_0", f"{component_id}_{idx}_1"] + ) def _build_weight_constraints(self, top_part_weight, truncated_part_weight_configuration=None): """ @@ -104,13 +125,16 @@ def _build_weight_constraints(self, top_part_weight, truncated_part_weight_confi variables = [] constraints = [] hw_variables = [ - var_id for var_id in self._variables_list if - var_id.startswith('hw_') and not var_id.startswith('hw_p') and not var_id.startswith( - 'hw_q') and not var_id.startswith('hw_r') + var_id + for var_id in self._variables_list + if var_id.startswith("hw_") + and not var_id.startswith("hw_p") + and not var_id.startswith("hw_q") + and not var_id.startswith("hw_r") ] if top_part_weight == 0: - return [], [f'-{var}' for var in hw_variables] + return [], [f"-{var}" for var in hw_variables] top_part_weight_variables, top_part_weight_constraints = self._counter(hw_variables, top_part_weight) variables.extend(top_part_weight_variables) @@ -119,11 +143,11 @@ def _build_weight_constraints(self, top_part_weight, truncated_part_weight_confi return variables, constraints def build_xor_probabilistic_truncated_differential_model( - self, - weight=-1, - number_of_unknowns_per_component=None, - unknown_window_size_configuration=None, - fixed_variables=[] + self, + weight=-1, + number_of_unknowns_per_component=None, + unknown_window_size_configuration=None, + fixed_variables=[], ): """ Constructs a model to search for probabilistic truncated XOR differential trails. @@ -173,10 +197,12 @@ def build_xor_probabilistic_truncated_differential_model( self._model_constraints.extend(constraints) if unknown_window_size_configuration is not None: - variables, constraints = SatSemiDeterministicTruncatedXorDifferentialModel.unknown_window_size_configuration_constraints( - unknown_window_size_configuration, - variables_list=self._variables_list, - cardinality_constraint_method=self._counter + variables, constraints = ( + SatSemiDeterministicTruncatedXorDifferentialModel.unknown_window_size_configuration_constraints( + unknown_window_size_configuration, + variables_list=self._variables_list, + cardinality_constraint_method=self._counter, + ) ) self._variables_list.extend(variables) self._model_constraints.extend(constraints) @@ -201,7 +227,7 @@ def fix_variables_value_constraints(fixed_variables, regular_components=None, tr for var in fixed_variables: component_id = var["component_id"] - if component_id in [comp["component_id"] for comp in regular_components] and 2 in var['bit_values']: + if component_id in [comp["component_id"] for comp in regular_components] and 2 in var["bit_values"]: raise ValueError("The fixed value in a regular XOR differential component cannot be 2") if component_id in [comp["component_id"] for comp in truncated_components]: @@ -210,8 +236,7 @@ def fix_variables_value_constraints(fixed_variables, regular_components=None, tr regular_vars.append(var) regular_constraints = SatModel.fix_variables_value_constraints(regular_vars) - truncated_constraints = SatTruncatedXorDifferentialModel.fix_variables_value_constraints( - truncated_vars) + truncated_constraints = SatTruncatedXorDifferentialModel.fix_variables_value_constraints(truncated_vars) return regular_constraints + truncated_constraints @@ -225,38 +250,40 @@ def _parse_solver_output(self, variable2value): RETURN: - **tuple**; a tuple containing the dictionary of component solutions and the total weight. """ - out_suffix = '' + out_suffix = "" components_solutions = self._get_cipher_inputs_components_solutions(out_suffix, variable2value) total_weight = 0 total_weight_truncated = 0 for component in self._cipher.get_all_components(): - if component.id in [d['component_id'] for d in self.regular_components]: - hex_value = self._get_component_hex_value(component, '', variable2value) - weight = self.calculate_component_weight(component, '', variable2value) + if component.id in [d["component_id"] for d in self.regular_components]: + hex_value = self._get_component_hex_value(component, "", variable2value) + weight = self.calculate_component_weight(component, "", variable2value) components_solutions[component.id] = set_component_solution(hex_value, weight) total_weight += weight - elif component.id in [d['component_id'] for d in self.truncated_components]: - + elif component.id in [d["component_id"] for d in self.truncated_components]: value = self._get_component_value_double_ids(component, variable2value) weight = 0 - if component.description[ - 0] == 'MODADD' and self.truncated_model_type == 'sat_semi_deterministic_truncated_xor_differential_constraints': - weight = SatSemiDeterministicTruncatedXorDifferentialModel._calculate_component_weight(component, - variable2value, - self._variables_list) + if ( + component.description[0] == "MODADD" + and self.truncated_model_type == "sat_semi_deterministic_truncated_xor_differential_constraints" + ): + weight = SatSemiDeterministicTruncatedXorDifferentialModel._calculate_component_weight( + component, variable2value, self._variables_list + ) total_weight += weight total_weight_truncated += weight components_solutions[component.id] = set_component_solution(value, weight) return components_solutions, total_weight def find_one_xor_probabilistic_truncated_differential_trail_with_fixed_weight( - self, - weight, - number_of_unknowns_per_component=None, - fixed_values=[], - solver_name=solvers.SOLVER_DEFAULT, - unknown_window_size_configuration=None + self, + weight, + number_of_unknowns_per_component=None, + fixed_values=[], + solver_name=solvers.SOLVER_DEFAULT, + unknown_window_size_configuration=None, + options=None, ): """ Finds one XOR probabilistic truncated differential trail with a fixed weight. @@ -277,21 +304,22 @@ def find_one_xor_probabilistic_truncated_differential_trail_with_fixed_weight( self.build_xor_probabilistic_truncated_differential_model( weight, number_of_unknowns_per_component=number_of_unknowns_per_component, - unknown_window_size_configuration=unknown_window_size_configuration + unknown_window_size_configuration=unknown_window_size_configuration, ) constraints = self.fix_variables_value_constraints( fixed_values, self.regular_components, self.truncated_components ) self.model_constraints.extend(constraints) - solution = self.solve("XOR_REGULAR_DETERMINISTIC_DIFFERENTIAL", solver_name=solver_name) - solution['building_time_seconds'] = time.time() - start_time - solution['test_name'] = "find_one_regular_truncated_xor_differential_trail" + solution = self.solve("XOR_REGULAR_DETERMINISTIC_DIFFERENTIAL", solver_name=solver_name, options=options) + solution["building_time_seconds"] = time.time() - start_time + solution["test_name"] = "find_one_regular_truncated_xor_differential_trail" return solution - def find_lowest_weight_xor_probabilistic_truncated_differential_trail(self, fixed_values=[], - solver_name=solvers.SOLVER_DEFAULT): + def find_lowest_weight_xor_probabilistic_truncated_differential_trail( + self, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT, options=None + ): """ Finds the XOR probabilistic truncated differential trail with the lowest weight. @@ -305,30 +333,31 @@ def find_lowest_weight_xor_probabilistic_truncated_differential_trail(self, fixe current_weight = 0 start_building_time = time.time() self.build_xor_probabilistic_truncated_differential_model(current_weight) - constraints = self.fix_variables_value_constraints(fixed_values, self.regular_components, - self.truncated_components) + constraints = self.fix_variables_value_constraints( + fixed_values, self.regular_components, self.truncated_components + ) end_building_time = time.time() self.model_constraints.extend(constraints) - solution = self.solve("XOR_REGULAR_DETERMINISTIC_DIFFERENTIAL", solver_name=solver_name) - solution['building_time_seconds'] = end_building_time - start_building_time - total_time = solution['solving_time_seconds'] - max_memory = solution['memory_megabytes'] + solution = self.solve("XOR_REGULAR_DETERMINISTIC_DIFFERENTIAL", solver_name=solver_name, options=options) + solution["building_time_seconds"] = end_building_time - start_building_time + total_time = solution["solving_time_seconds"] + max_memory = solution["memory_megabytes"] - while solution['total_weight'] is None: + while solution["total_weight"] is None: current_weight += 1 start_building_time = time.time() self.build_xor_probabilistic_truncated_differential_model(current_weight) self.model_constraints.extend(constraints) - solution = self.solve("XOR_REGULAR_DETERMINISTIC_DIFFERENTIAL", solver_name=solver_name) + solution = self.solve("XOR_REGULAR_DETERMINISTIC_DIFFERENTIAL", solver_name=solver_name, options=options) end_building_time = time.time() - solution['building_time_seconds'] = end_building_time - start_building_time - total_time += solution['solving_time_seconds'] - max_memory = max(max_memory, solution['memory_megabytes']) + solution["building_time_seconds"] = end_building_time - start_building_time + total_time += solution["solving_time_seconds"] + max_memory = max(max_memory, solution["memory_megabytes"]) - solution['solving_time_seconds'] = total_time - solution['memory_megabytes'] = max_memory - solution['test_name'] = "find_lowest_weight_xor_probabilistic_truncated_differential_trail" + solution["solving_time_seconds"] = total_time + solution["memory_megabytes"] = max_memory + solution["test_name"] = "find_lowest_weight_xor_probabilistic_truncated_differential_trail" return solution diff --git a/claasp/cipher_modules/models/sat/sat_models/sat_semi_deterministic_truncated_xor_differential_model.py b/claasp/cipher_modules/models/sat/sat_models/sat_semi_deterministic_truncated_xor_differential_model.py index 6372ca040..b1422a8fd 100644 --- a/claasp/cipher_modules/models/sat/sat_models/sat_semi_deterministic_truncated_xor_differential_model.py +++ b/claasp/cipher_modules/models/sat/sat_models/sat_semi_deterministic_truncated_xor_differential_model.py @@ -1,29 +1,37 @@ # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** - import time from claasp.cipher_modules.models.sat import solvers -from claasp.cipher_modules.models.sat.sat_models.sat_truncated_xor_differential_model import \ - SatTruncatedXorDifferentialModel +from claasp.cipher_modules.models.sat.sat_models.sat_truncated_xor_differential_model import ( + SatTruncatedXorDifferentialModel, +) from claasp.cipher_modules.models.utils import set_component_solution -from claasp.name_mappings import (CIPHER_OUTPUT, CONSTANT, DETERMINISTIC_TRUNCATED_XOR_DIFFERENTIAL, - INTERMEDIATE_OUTPUT, LINEAR_LAYER, MIX_COLUMN, SBOX, WORD_OPERATION) +from claasp.name_mappings import ( + CIPHER_OUTPUT, + CONSTANT, + DETERMINISTIC_TRUNCATED_XOR_DIFFERENTIAL, + INTERMEDIATE_OUTPUT, + LINEAR_LAYER, + MIX_COLUMN, + SBOX, + WORD_OPERATION, +) def group_triples(var_names): @@ -38,7 +46,7 @@ def group_triples(var_names): grouped = {} for name in var_names: if "hw_p_modadd" in name or "hw_q_modadd" in name or "hw_r_modadd" in name: - parts = name.split('_') + parts = name.split("_") probability_component = parts[1] round_index = parts[3] component_index = parts[4] @@ -46,24 +54,24 @@ def group_triples(var_names): key = (round_index, component_index, position_index) if key not in grouped: - grouped[key] = {'p': None, 'q': None, 'r': None} + grouped[key] = {"p": None, "q": None, "r": None} grouped[key][probability_component] = name triples_dict = {} for k, bit_map in grouped.items(): - p_name = bit_map['p'] - q_name = bit_map['q'] - r_name = bit_map['r'] + p_name = bit_map["p"] + q_name = bit_map["q"] + r_name = bit_map["r"] triples_dict[k] = (p_name, q_name, r_name) return triples_dict class SatSemiDeterministicTruncatedXorDifferentialModel(SatTruncatedXorDifferentialModel): - def __init__(self, cipher, counter='sequential', compact=False): + def __init__(self, cipher, counter="sequential", compact=False): super().__init__(cipher, counter, compact) def build_semi_deterministic_truncated_xor_differential_trail_model( - self, number_of_unknowns_per_component=None, unknown_window_size_configuration=None, fixed_variables=[] + self, number_of_unknowns_per_component=None, unknown_window_size_configuration=None, fixed_variables=[] ): """ Build the model for the search of deterministic truncated XOR DIFFERENTIAL trails. @@ -93,14 +101,14 @@ def build_semi_deterministic_truncated_xor_differential_trail_model( self._variables_list = [] self._model_constraints = constraints component_types = (CIPHER_OUTPUT, CONSTANT, INTERMEDIATE_OUTPUT, LINEAR_LAYER, MIX_COLUMN, SBOX, WORD_OPERATION) - operation_types = ('AND', 'MODADD', 'NOT', 'OR', 'ROTATE', 'SHIFT', 'XOR') + operation_types = ("AND", "MODADD", "NOT", "OR", "ROTATE", "SHIFT", "XOR") for component in self._cipher.get_all_components(): operation = component.description[0] if component.type in component_types and (component.type != WORD_OPERATION or operation in operation_types): variables, constraints = component.sat_semi_deterministic_truncated_xor_differential_constraints() else: - print(f'{component.id} not yet implemented') + print(f"{component.id} not yet implemented") self._variables_list.extend(variables) self._model_constraints.extend(constraints) @@ -109,20 +117,23 @@ def build_semi_deterministic_truncated_xor_differential_trail_model( self._build_unknown_variable_constraints(number_of_unknowns_per_component) if unknown_window_size_configuration is not None: - variables, constraints = SatSemiDeterministicTruncatedXorDifferentialModel.unknown_window_size_configuration_constraints( - unknown_window_size_configuration, - variables_list=self._variables_list, - cardinality_constraint_method=self._counter + variables, constraints = ( + SatSemiDeterministicTruncatedXorDifferentialModel.unknown_window_size_configuration_constraints( + unknown_window_size_configuration, + variables_list=self._variables_list, + cardinality_constraint_method=self._counter, + ) ) self._variables_list.extend(variables) self._model_constraints.extend(constraints) def find_one_semi_deterministic_truncated_xor_differential_trail( - self, - fixed_values=[], - solver_name=solvers.SOLVER_DEFAULT, - unknown_window_size_configuration=None, - number_of_unknowns_per_component=None + self, + fixed_values=[], + solver_name=solvers.SOLVER_DEFAULT, + unknown_window_size_configuration=None, + number_of_unknowns_per_component=None, + options=None, ): """ Returns one deterministic truncated XOR differential trail. @@ -139,18 +150,18 @@ def find_one_semi_deterministic_truncated_xor_differential_trail( self.build_semi_deterministic_truncated_xor_differential_trail_model( fixed_variables=fixed_values, unknown_window_size_configuration=unknown_window_size_configuration, - number_of_unknowns_per_component=number_of_unknowns_per_component + number_of_unknowns_per_component=number_of_unknowns_per_component, ) end_building_time = time.time() - solution = self.solve(DETERMINISTIC_TRUNCATED_XOR_DIFFERENTIAL, solver_name=solver_name) - solution['building_time_seconds'] = end_building_time - start_building_time + solution = self.solve(DETERMINISTIC_TRUNCATED_XOR_DIFFERENTIAL, solver_name=solver_name, options=options) + solution["building_time_seconds"] = end_building_time - start_building_time return solution @staticmethod def unknown_window_size_configuration_constraints( - unknown_window_size_configuration, variables_list=None, cardinality_constraint_method=None + unknown_window_size_configuration, variables_list=None, cardinality_constraint_method=None ): """ Return lists of variables and constraints that fix the number of unknown @@ -166,11 +177,11 @@ def unknown_window_size_configuration_constraints( new_variables_list = [] new_constraints_list = [] - max_number_of_seq_window_size_0 = unknown_window_size_configuration['max_number_of_sequences_window_size_0'] - max_number_of_seq_window_size_1 = unknown_window_size_configuration['max_number_of_sequences_window_size_1'] - max_number_of_seq_window_size_2 = unknown_window_size_configuration['max_number_of_sequences_window_size_2'] + max_number_of_seq_window_size_0 = unknown_window_size_configuration["max_number_of_sequences_window_size_0"] + max_number_of_seq_window_size_1 = unknown_window_size_configuration["max_number_of_sequences_window_size_1"] + max_number_of_seq_window_size_2 = unknown_window_size_configuration["max_number_of_sequences_window_size_2"] - hw_variables = [var_id for var_id in variables_list if var_id.startswith('hw_')] + hw_variables = [var_id for var_id in variables_list if var_id.startswith("hw_")] def x_iff_abc_cnf(a: str, b: str, c: str, x: str) -> list: """ @@ -192,7 +203,7 @@ def negate(var): f"{negate(x)} {a}", f"{negate(x)} {b}", f"{negate(x)} {c}", - f"{negate(a)} {negate(b)} {negate(c)} {x}" + f"{negate(a)} {negate(b)} {negate(c)} {x}", ] return clauses @@ -203,18 +214,14 @@ def negate(var): for tuple_key, tuple_value in triples_dict.items(): window_1_var = "hw_window_1_" + "_".join(tuple_key) window_1_vars.append(window_1_var) - constraints = x_iff_abc_cnf( - tuple_value[0], "-" + tuple_value[1], tuple_value[2], window_1_var - ) + constraints = x_iff_abc_cnf(tuple_value[0], "-" + tuple_value[1], tuple_value[2], window_1_var) new_variables_list.extend([window_1_var]) new_constraints_list.extend(constraints) window_2_var = "hw_window_2_" + "_".join(tuple_key) window_2_vars.append(window_2_var) - constraints = x_iff_abc_cnf( - tuple_value[0], "-" + tuple_value[1], "-" + tuple_value[2], window_2_var - ) + constraints = x_iff_abc_cnf(tuple_value[0], "-" + tuple_value[1], "-" + tuple_value[2], window_2_var) new_variables_list.extend([window_2_var]) new_constraints_list.extend(constraints) cardinality_variables_window_1, cardinality_constraints_window_1 = cardinality_constraint_method( @@ -302,10 +309,14 @@ def get_probability_expressions(input_dict): return counts weight = 0 - if ('MODSUB' in component.description or 'MODADD' in component.description or 'AND' in component.description - or 'OR' in component.description or SBOX in component.type): - - hw_variables = [var_id for var_id in variables_list if var_id.startswith('hw_')] + if ( + "MODSUB" in component.description + or "MODADD" in component.description + or "AND" in component.description + or "OR" in component.description + or SBOX in component.type + ): + hw_variables = [var_id for var_id in variables_list if var_id.startswith("hw_")] hw_variables = [var_id for var_id in hw_variables if component.id in var_id] triples_dict = group_triples(hw_variables) @@ -328,6 +339,6 @@ def _parse_solver_output(self, variable2value): total_weight += weight component_solution = set_component_solution(value, weight) - components_solutions[f'{component.id}'] = component_solution + components_solutions[f"{component.id}"] = component_solution return components_solutions, total_weight diff --git a/claasp/cipher_modules/models/sat/sat_models/sat_shared_difference_paired_input_differential_linear_model.py b/claasp/cipher_modules/models/sat/sat_models/sat_shared_difference_paired_input_differential_linear_model.py index 5fe1f3cb5..e70945183 100644 --- a/claasp/cipher_modules/models/sat/sat_models/sat_shared_difference_paired_input_differential_linear_model.py +++ b/claasp/cipher_modules/models/sat/sat_models/sat_shared_difference_paired_input_differential_linear_model.py @@ -1,3 +1,20 @@ +# **************************************************************************** +# Copyright 2023 Technology Innovation Institute +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# **************************************************************************** + import time from copy import deepcopy @@ -5,8 +22,10 @@ from claasp.cipher_modules.models.sat.sat_model import SatModel from claasp.cipher_modules.models.sat.sat_models.sat_xor_linear_model import SatXorLinearModel from claasp.cipher_modules.models.sat.utils import utils as sat_utils, constants -from claasp.cipher_modules.models.sat.utils.utils import _generate_component_model_types, \ - _update_component_model_types_for_linear_components +from claasp.cipher_modules.models.sat.utils.utils import ( + _generate_component_model_types, + _update_component_model_types_for_linear_components, +) from claasp.cipher_modules.models.utils import set_component_solution, get_bit_bindings @@ -49,7 +68,7 @@ def __init__(self, cipher, dict_of_components): _update_component_model_types_for_linear_components(component_model_types, bottom_part_components) self.dict_of_components = component_model_types - self.regular_components = self._get_components_by_type('sat_xor_differential_propagation_constraints') + self.regular_components = self._get_components_by_type("sat_xor_differential_propagation_constraints") new_regular_components = [] regular_components = deepcopy(self.regular_components) for regular_component_dict in regular_components: @@ -58,26 +77,28 @@ def __init__(self, cipher, dict_of_components): regular_component = regular_component_dict["component_object"] round_number = cipher.get_round_from_component_id(regular_component_id) regular_component_copy = deepcopy(regular_component) - regular_component_copy._id = 'cipher1_' + regular_component._id + regular_component_copy._id = "cipher1_" + regular_component._id new_input_id_links = [ - f'cipher1_{input_id_link}' if input_id_link not in cipher.inputs else input_id_link + f"cipher1_{input_id_link}" if input_id_link not in cipher.inputs else input_id_link for input_id_link in regular_component_copy.input_id_links ] regular_component_copy.set_input_id_links(new_input_id_links) cipher._rounds.rounds[round_number]._components.extend([regular_component_copy]) - new_regular_components.append({ - 'component_id': regular_component_copy.id, - 'component_object': regular_component_copy, - "model_type": "sat_xor_differential_propagation_constraints" - }) + new_regular_components.append( + { + "component_id": regular_component_copy.id, + "component_object": regular_component_copy, + "model_type": "sat_xor_differential_propagation_constraints", + } + ) self.regular_components.extend(new_regular_components) self.new_regular_components = new_regular_components - self.linear_components = self._get_components_by_type('sat_xor_linear_mask_propagation_constraints') - self.bit_bindings, self.bit_bindings_for_intermediate_output = get_bit_bindings(cipher, '_'.join) - super().__init__(cipher, "sequential", False) + self.linear_components = self._get_components_by_type("sat_xor_linear_mask_propagation_constraints") + self.bit_bindings, self.bit_bindings_for_intermediate_output = get_bit_bindings(cipher, "_".join) + super().__init__(cipher=cipher, counter="sequential", compact=False) def _get_components_by_type(self, model_type): """ @@ -89,7 +110,7 @@ def _get_components_by_type(self, model_type): RETURN: - **list**; A list of components of the specified type. """ - return [component for component in self.dict_of_components if component['model_type'] == model_type] + return [component for component in self.dict_of_components if component["model_type"] == model_type] def _get_regular_xor_differential_components_in_border(self): """ @@ -98,11 +119,11 @@ def _get_regular_xor_differential_components_in_border(self): RETURN: - **list**; A list of regular components at the border. """ - regular_component_ids = {item['component_id'] for item in self.regular_components} + regular_component_ids = {item["component_id"] for item in self.regular_components} border_components = [] for linear_component in self.linear_components: - component_obj = self.cipher.get_component_from_id(linear_component['component_id']) + component_obj = self.cipher.get_component_from_id(linear_component["component_id"]) for input_id in component_obj.input_id_links: if input_id in regular_component_ids: border_components.append(input_id) @@ -118,29 +139,29 @@ def is_any_string_in_list_substring_of_string(string, string_list): return any(s in string for s in string_list) border_components = self._get_regular_xor_differential_components_in_border() - linear_component_ids = [item['component_id'] for item in self.linear_components] + linear_component_ids = [item["component_id"] for item in self.linear_components] for component_id in border_components: component = self.cipher.get_component_from_id(component_id) for idx in range(component.output_bit_size): - linear_component = f'{component_id}_{idx}_o' + linear_component = f"{component_id}_{idx}_o" component_successors = self.bit_bindings[linear_component] for component_successor in component_successors: length_component_successor = len(component_successor) - component_successor_id = component_successor[:length_component_successor - 2] + component_successor_id = component_successor[: length_component_successor - 2] if is_any_string_in_list_substring_of_string(component_successor_id, linear_component_ids): # TODO: update method name get_cnf_truncated_linear_constraints for something more general constraints = sat_utils.get_cnf_truncated_linear_constraints( - component_successor, f'{component_id}_{idx}' + component_successor, f"{component_id}_{idx}" ) self._model_constraints.extend(constraints) constraints = sat_utils.get_cnf_truncated_linear_constraints( - component_successor, f'cipher1_{component_id}_{idx}' + component_successor, f"cipher1_{component_id}_{idx}" ) self._model_constraints.extend(constraints) - self._variables_list.extend([component_successor, f'{component_id}_{idx}']) + self._variables_list.extend([component_successor, f"{component_id}_{idx}"]) def _build_weight_constraints(self, weight): """ @@ -153,7 +174,7 @@ def _build_weight_constraints(self, weight): - **tuple**; A tuple containing a list of variables and a list of constraints. """ - hw_variables = [var_id for var_id in self._variables_list if var_id.startswith('hw_')] + hw_variables = [var_id for var_id in self._variables_list if var_id.startswith("hw_")] linear_component_ids = [linear_component["component_id"] for linear_component in self.linear_components] hw_linear_variables = [] @@ -163,7 +184,7 @@ def _build_weight_constraints(self, weight): hw_linear_variables.append(hw_variable) hw_variables.extend(hw_linear_variables) if weight == 0: - return [], [f'-{var}' for var in hw_variables] + return [], [f"-{var}" for var in hw_variables] return self._counter(hw_variables, weight) @@ -196,10 +217,10 @@ def build_shared_difference_paired_input_differential_model(self, weight=-1): self._model_constraints.extend(constraints) high_order_differential_constraints = [] for component in self._cipher.get_all_components(): - if (component.id.startswith('cipher1_') and "modadd" in component.id): + if component.id.startswith("cipher1_") and "modadd" in component.id: component_copy_id = component.id.split("cipher1_")[1] for i in range(component.output_bit_size): - new_constraint_cnf = [f'-cipher1_{component_copy_id}_{i} -{component_copy_id}_{i}'] + new_constraint_cnf = [f"-cipher1_{component_copy_id}_{i} -{component_copy_id}_{i}"] high_order_differential_constraints.extend(new_constraint_cnf) self._model_constraints.extend(high_order_differential_constraints) @@ -211,8 +232,7 @@ def build_shared_difference_paired_input_differential_model(self, weight=-1): self._get_connecting_constraints() @staticmethod - def fix_variables_value_constraints( - fixed_variables, regular_components=None, linear_components=None): + def fix_variables_value_constraints(fixed_variables, regular_components=None, linear_components=None): """ Imposes fixed value constraints on variables within differential and linear components. @@ -228,7 +248,7 @@ def fix_variables_value_constraints( for var in fixed_variables: component_id = var["component_id"] - if component_id in [comp["component_id"] for comp in regular_components] and 2 in var['bit_values']: + if component_id in [comp["component_id"] for comp in regular_components] and 2 in var["bit_values"]: raise ValueError("The fixed value in a regular XOR differential component cannot be 2") elif component_id in [comp["component_id"] for comp in linear_components]: linear_vars.append(var) @@ -255,19 +275,19 @@ def _parse_solver_output(self, variable2value): - **tuple**; a tuple containing the dictionary of component solutions and the total weight. """ - components_solutions = self._get_cipher_inputs_components_solutions('', variable2value) + components_solutions = self._get_cipher_inputs_components_solutions("", variable2value) total_weight_diff = 0 total_weight_lin = 0 for component in self._cipher.get_all_components(): - if component.id in [d['component_id'] for d in self.regular_components]: - hex_value = self._get_component_hex_value(component, '', variable2value) + if component.id in [d["component_id"] for d in self.regular_components]: + hex_value = self._get_component_hex_value(component, "", variable2value) - weight = self.calculate_component_weight(component, '', variable2value) + weight = self.calculate_component_weight(component, "", variable2value) components_solutions[component.id] = set_component_solution(hex_value, weight) total_weight_diff += weight - elif component.id in [d['component_id'] for d in self.linear_components]: + elif component.id in [d["component_id"] for d in self.linear_components]: hex_value = self._get_component_hex_value(component, constants.OUTPUT_BIT_ID_SUFFIX, variable2value) weight = self.calculate_component_weight(component, constants.OUTPUT_BIT_ID_SUFFIX, variable2value) total_weight_lin += weight @@ -276,7 +296,8 @@ def _parse_solver_output(self, variable2value): return components_solutions, total_weight_diff + 2 * total_weight_lin def find_one_shared_difference_paired_input_differential_linear_trail_with_fixed_weight( - self, weight, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT): + self, weight, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT, options=None + ): """ Finds a high-order differential-linear trail with fixed weight using paired inputs and a shared difference. @@ -359,15 +380,15 @@ def find_one_shared_difference_paired_input_differential_linear_trail_with_fixed self.build_shared_difference_paired_input_differential_model(weight) constraints = self.fix_variables_value_constraints( - fixed_values, - self.regular_components, - self.linear_components + fixed_values, self.regular_components, self.linear_components ) self.model_constraints.extend(constraints) - solution = self.solve("SHARED_DIFFERENCE_PAIRED_INPUT_DIFFERENTIAL_LINEAR_MODEL", solver_name=solver_name) - solution['building_time_seconds'] = time.time() - start_time - solution['test_name'] = "find_one_shared_difference_paired_input_differential_linear_model_trail" + solution = self.solve( + "SHARED_DIFFERENCE_PAIRED_INPUT_DIFFERENTIAL_LINEAR_MODEL", solver_name=solver_name, options=options + ) + solution["building_time_seconds"] = time.time() - start_time + solution["test_name"] = "find_one_shared_difference_paired_input_differential_linear_model_trail" return solution diff --git a/claasp/cipher_modules/models/sat/sat_models/sat_shared_difference_paired_input_differential_model.py b/claasp/cipher_modules/models/sat/sat_models/sat_shared_difference_paired_input_differential_model.py index b437106bb..52d45bf63 100644 --- a/claasp/cipher_modules/models/sat/sat_models/sat_shared_difference_paired_input_differential_model.py +++ b/claasp/cipher_modules/models/sat/sat_models/sat_shared_difference_paired_input_differential_model.py @@ -15,7 +15,6 @@ # along with this program. If not, see . # **************************************************************************** - import time from copy import deepcopy @@ -28,9 +27,9 @@ def add_prefix_id_to_components(cipher, prefix): all_components = cipher.rounds.get_all_components() for component in all_components: - component.set_id(f'{prefix}_{component.id}') + component.set_id(f"{prefix}_{component.id}") new_input_id_links = [ - f'{prefix}_{input_id_link}' if input_id_link not in cipher.inputs else input_id_link + f"{prefix}_{input_id_link}" if input_id_link not in cipher.inputs else input_id_link for input_id_link in component.input_id_links ] @@ -66,7 +65,7 @@ def __init__(self, cipher): """ cipher1 = cipher cipher2 = deepcopy(cipher) - add_prefix_id_to_components(cipher1, 'cipher1') + add_prefix_id_to_components(cipher1, "cipher1") for round_number in range(cipher.number_of_rounds): round_components2 = cipher2.get_components_in_round(round_number) cipher1._rounds.rounds[round_number]._components.extend(round_components2) @@ -105,16 +104,17 @@ def build_shared_difference_paired_input_differential_model(self, weight=-1, fix self._variables_list = self.differential_model._variables_list new_constraints = [] for component in self._cipher.get_all_components(): - if ((component.id.startswith('cipher1_') and "modadd" in component.id) or - (component.id.startswith('cipher1_') and "modsub" in component.id)): + if (component.id.startswith("cipher1_") and "modadd" in component.id) or ( + component.id.startswith("cipher1_") and "modsub" in component.id + ): component_copy_id = component.id.split("cipher1_")[1] for i in range(component.output_bit_size): - new_constraints.append(f'-cipher1_{component_copy_id}_{i} -{component_copy_id}_{i}') + new_constraints.append(f"-cipher1_{component_copy_id}_{i} -{component_copy_id}_{i}") self._model_constraints.extend(new_constraints) self.differential_model._model_constraints.extend(new_constraints) def find_one_shared_difference_paired_input_differential_trail_with_fixed_weight( - self, weight, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT + self, weight, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT, options=None ): """ Return a single solution representing a high-order XOR differential trail for paired inputs (`x`, `y`) @@ -182,22 +182,22 @@ def find_one_shared_difference_paired_input_differential_trail_with_fixed_weight start_time = time.time() self.build_shared_difference_paired_input_differential_model(weight, fixed_variables=fixed_values) solution = self.differential_model.solve( - "SHARED_DIFFERENCE_PAIRED_INPUT_DIFFERENTIAL_MODEL", solver_name=solver_name + "SHARED_DIFFERENCE_PAIRED_INPUT_DIFFERENTIAL_MODEL", solver_name=solver_name, options=options ) - solution['building_time_seconds'] = time.time() - start_time - solution['test_name'] = "find_one_shared_difference_paired_input_differential_model_trail" + solution["building_time_seconds"] = time.time() - start_time + solution["test_name"] = "find_one_shared_difference_paired_input_differential_model_trail" return solution def _parse_solver_output(self, variable2value): - out_suffix = '' + out_suffix = "" components_solutions = self._get_cipher_inputs_components_solutions(out_suffix, variable2value) total_weight = 0 for component in self._cipher.get_all_components(): hex_value = self._get_component_hex_value(component, out_suffix, variable2value) weight = self.calculate_component_weight(component, out_suffix, variable2value) component_solution = set_component_solution(hex_value, weight) - components_solutions[f'{component.id}{out_suffix}'] = component_solution + components_solutions[f"{component.id}{out_suffix}"] = component_solution total_weight += weight return components_solutions, total_weight diff --git a/claasp/cipher_modules/models/sat/sat_models/sat_truncated_xor_differential_model.py b/claasp/cipher_modules/models/sat/sat_models/sat_truncated_xor_differential_model.py index 8e4e89fef..fa20f55f5 100644 --- a/claasp/cipher_modules/models/sat/sat_models/sat_truncated_xor_differential_model.py +++ b/claasp/cipher_modules/models/sat/sat_models/sat_truncated_xor_differential_model.py @@ -1,8 +1,25 @@ +# **************************************************************************** +# Copyright 2023 Technology Innovation Institute +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# **************************************************************************** + from claasp.cipher_modules.models.sat.sat_model import SatModel class SatTruncatedXorDifferentialModel(SatModel): - def __init__(self, cipher, counter='sequential', compact=False): + def __init__(self, cipher, counter="sequential", compact=False): super().__init__(cipher, counter, compact) @staticmethod @@ -55,39 +72,39 @@ def fix_variables_value_constraints(fixed_variables=[]): """ constraints = [] for variable in fixed_variables: - component_id = variable['component_id'] - is_equal = (variable['constraint_type'] == 'equal') - bit_positions = variable['bit_positions'] - bit_values = variable['bit_values'] + component_id = variable["component_id"] + is_equal = variable["constraint_type"] == "equal" + bit_positions = variable["bit_positions"] + bit_values = variable["bit_values"] variables_ids = [] all_values_are_2 = all(v == 2 for v in bit_values) for position, value in zip(bit_positions, bit_values): - false_sign = '-' * is_equal - true_sign = '-' * (not is_equal) + false_sign = "-" * is_equal + true_sign = "-" * (not is_equal) if value == 0: - variables_ids.append(f'{false_sign}{component_id}_{position}_0') - variables_ids.append(f'{false_sign}{component_id}_{position}_1') + variables_ids.append(f"{false_sign}{component_id}_{position}_0") + variables_ids.append(f"{false_sign}{component_id}_{position}_1") elif value == 1: - variables_ids.append(f'{false_sign}{component_id}_{position}_0') - variables_ids.append(f'{true_sign}{component_id}_{position}_1') + variables_ids.append(f"{false_sign}{component_id}_{position}_0") + variables_ids.append(f"{true_sign}{component_id}_{position}_1") elif value == 2: if not is_equal: # Forbid (1,0) and ensure mutual exclusion of (1,1) - constraints.append(f'-{component_id}_{position}_0 {component_id}_{position}_1') - constraints.append(f'-{component_id}_{position}_0 -{component_id}_{position}_1') + constraints.append(f"-{component_id}_{position}_0 {component_id}_{position}_1") + constraints.append(f"-{component_id}_{position}_0 -{component_id}_{position}_1") else: - variables_ids.append(f'{true_sign}{component_id}_{position}_0') + variables_ids.append(f"{true_sign}{component_id}_{position}_0") if is_equal: constraints.extend(variables_ids) else: if all_values_are_2: # Require at least one (0,1) tuple - clause = ' '.join([f'{component_id}_{position}_1' for position in bit_positions]) + clause = " ".join([f"{component_id}_{position}_1" for position in bit_positions]) constraints.append(clause) else: - joined_clause = ' '.join(variables_ids) + joined_clause = " ".join(variables_ids) if joined_clause: constraints.append(joined_clause) @@ -99,18 +116,21 @@ def _build_unknown_variable_constraints(self, number_of_unknowns_per_component): for component_id in list(number_of_unknowns_per_component.keys()): if component_id in self._cipher.get_all_components_ids(): set_to_be_minimized = [] - set_to_be_minimized.extend([bit_id for bit_id in self._variables_list - if bit_id.startswith(component_id) and bit_id.endswith("_0")]) + set_to_be_minimized.extend( + [ + bit_id + for bit_id in self._variables_list + if bit_id.startswith(component_id) and bit_id.endswith("_0") + ] + ) number_of_unknowns_per_component = number_of_unknowns_per_component[component_id] unknown_variables, unknown_constraints = self._sequential_counter_algorithm( - set_to_be_minimized, - number_of_unknowns_per_component, - f'unknown_vars_for_{component_id}' + set_to_be_minimized, number_of_unknowns_per_component, f"unknown_vars_for_{component_id}" ) variables.extend(unknown_variables) constraints.extend(unknown_constraints) else: - raise ValueError(f'Component {component_id} not found in number_of_unknowns_per_component') + raise ValueError(f"Component {component_id} not found in number_of_unknowns_per_component") return variables, constraints diff --git a/claasp/cipher_modules/models/sat/sat_models/sat_xor_differential_model.py b/claasp/cipher_modules/models/sat/sat_models/sat_xor_differential_model.py index 70ae149c4..e60dbd1c4 100644 --- a/claasp/cipher_modules/models/sat/sat_models/sat_xor_differential_model.py +++ b/claasp/cipher_modules/models/sat/sat_models/sat_xor_differential_model.py @@ -15,7 +15,6 @@ # along with this program. If not, see . # **************************************************************************** - import time from copy import deepcopy @@ -23,12 +22,20 @@ from claasp.cipher_modules.models.sat.sat_model import SatModel from claasp.cipher_modules.models.sat.sat_models.sat_cipher_model import SatCipherModel from claasp.cipher_modules.models.utils import set_component_solution, get_single_key_scenario_format_for_fixed_values -from claasp.name_mappings import (CIPHER_OUTPUT, CONSTANT, INTERMEDIATE_OUTPUT, LINEAR_LAYER, - MIX_COLUMN, SBOX, WORD_OPERATION, XOR_DIFFERENTIAL) +from claasp.name_mappings import ( + CIPHER_OUTPUT, + CONSTANT, + INTERMEDIATE_OUTPUT, + LINEAR_LAYER, + MIX_COLUMN, + SBOX, + WORD_OPERATION, + XOR_DIFFERENTIAL, +) class SatXorDifferentialModel(SatModel): - def __init__(self, cipher, counter='sequential', compact=False): + def __init__(self, cipher, counter="sequential", compact=False): self._window_size_by_component_id_values = None self._window_size_by_round_values = None self._window_size_full_window_vars = None @@ -66,13 +73,14 @@ def build_xor_differential_trail_model(self, weight=-1, fixed_variables=[]): constraints = SatModel.fix_variables_value_constraints(fixed_variables) self._model_constraints = constraints component_types = (CONSTANT, INTERMEDIATE_OUTPUT, CIPHER_OUTPUT, LINEAR_LAYER, SBOX, MIX_COLUMN, WORD_OPERATION) - operation_types = ('AND', 'MODADD', 'MODSUB', 'NOT', 'OR', 'ROTATE', 'SHIFT', 'XOR') + operation_types = ("AND", "MODADD", "MODSUB", "NOT", "OR", "ROTATE", "SHIFT", "XOR") for component in self._cipher.get_all_components(): operation = component.description[0] if component.type not in component_types or ( - WORD_OPERATION == component.type and operation not in operation_types): - print(f'{component.id} not yet implemented') + WORD_OPERATION == component.type and operation not in operation_types + ): + print(f"{component.id} not yet implemented") else: variables, constraints = component.sat_xor_differential_propagation_constraints(self) @@ -89,46 +97,46 @@ def build_xor_differential_trail_model(self, weight=-1, fixed_variables=[]): if self._window_size_number_of_full_window == 0: self._variables_list.extend([]) - self._model_constraints.extend([f'-{variable}' for variable in self._window_size_full_window_vars]) + self._model_constraints.extend([f"-{variable}" for variable in self._window_size_full_window_vars]) return - if self._window_size_full_window_operator == 'at_least': + if self._window_size_full_window_operator == "at_least": all_ones_dummy_variables, all_ones_constraints = self._sequential_counter_algorithm( self._window_size_full_window_vars, self._window_size_number_of_full_window - 1, - 'dummy_all_ones_at_least', - greater_or_equal=True + "dummy_all_ones_at_least", + greater_or_equal=True, ) - elif self._window_size_full_window_operator == 'at_most': + elif self._window_size_full_window_operator == "at_most": all_ones_dummy_variables, all_ones_constraints = self._sequential_counter_algorithm( self._window_size_full_window_vars, self._window_size_number_of_full_window, - 'dummy_all_ones_at_most', - greater_or_equal=False + "dummy_all_ones_at_most", + greater_or_equal=False, ) - elif self._window_size_full_window_operator == 'exactly': + elif self._window_size_full_window_operator == "exactly": all_ones_dummy_variables1, all_ones_constraints1 = self._sequential_counter_algorithm( self._window_size_full_window_vars, self._window_size_number_of_full_window, - 'dummy_all_ones_at_least', - greater_or_equal=True + "dummy_all_ones_at_least", + greater_or_equal=True, ) all_ones_dummy_variables2, all_ones_constraints2 = self._sequential_counter_algorithm( self._window_size_full_window_vars, self._window_size_number_of_full_window, - 'dummy_all_ones_at_most', - greater_or_equal=False + "dummy_all_ones_at_most", + greater_or_equal=False, ) all_ones_dummy_variables = all_ones_dummy_variables1 + all_ones_dummy_variables2 all_ones_constraints = all_ones_constraints1 + all_ones_constraints2 else: - raise ValueError(f'Unknown operator {self._window_size_full_window_operator}') + raise ValueError(f"Unknown operator {self._window_size_full_window_operator}") self._variables_list.extend(all_ones_dummy_variables) self._model_constraints.extend(all_ones_constraints) def build_xor_differential_trail_and_checker_model_at_intermediate_output_level( - self, weight=-1, fixed_variables=[] + self, weight=-1, fixed_variables=[] ): """ Build the model for the search of XOR DIFFERENTIAL trails and the model to check that there is at least one pair @@ -160,8 +168,9 @@ def build_xor_differential_trail_and_checker_model_at_intermediate_output_level( self._variables_list.extend(sat._variables_list) self._model_constraints.extend(sat._model_constraints) - def find_all_xor_differential_trails_with_fixed_weight(self, fixed_weight, fixed_values=[], - solver_name=solvers.SOLVER_DEFAULT): + def find_all_xor_differential_trails_with_fixed_weight( + self, fixed_weight, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT, options=None + ): """ Return a list of solutions containing all the XOR differential trails having the ``fixed_weight`` weight. By default, the search is set in the single-key setting. @@ -205,34 +214,36 @@ def find_all_xor_differential_trails_with_fixed_weight(self, fixed_weight, fixed start_building_time = time.time() self.build_xor_differential_trail_model(weight=fixed_weight, fixed_variables=fixed_values) if self._counter == self._sequential_counter: - self._sequential_counter_greater_or_equal(fixed_weight, 'dummy_hw_1') + self._sequential_counter_greater_or_equal(fixed_weight, "dummy_hw_1") end_building_time = time.time() - solution = self.solve(XOR_DIFFERENTIAL, solver_name=solver_name) - solution['building_time_seconds'] = end_building_time - start_building_time + solution = self.solve(XOR_DIFFERENTIAL, solver_name=solver_name, options=options) + solution["building_time_seconds"] = end_building_time - start_building_time solutions_list = [] - while solution['total_weight'] is not None: + while solution["total_weight"] is not None: solutions_list.append(solution) literals = [] for input_, bit_len in zip(self._cipher.inputs, self._cipher.inputs_bit_size): - value_to_avoid = int(solution['components_values'][input_]['value'], base=16) - minus = ['-' * (value_to_avoid >> i & 1) for i in reversed(range(bit_len))] - literals.extend([f'{minus[i]}{input_}_{i}' for i in range(bit_len)]) + value_to_avoid = int(solution["components_values"][input_]["value"], base=16) + minus = ["-" * (value_to_avoid >> i & 1) for i in reversed(range(bit_len))] + literals.extend([f"{minus[i]}{input_}_{i}" for i in range(bit_len)]) for component in self._cipher.get_all_components(): bit_len = component.output_bit_size - if component.type == SBOX or \ - (component.type == WORD_OPERATION and - component.description[0] in ('AND', 'MODADD', 'MODSUB', 'OR', 'SHIFT_BY_VARIABLE_AMOUNT')): - value_to_avoid = int(solution['components_values'][component.id]['value'], base=16) - minus = ['-' * (value_to_avoid >> i & 1) for i in reversed(range(bit_len))] - literals.extend([f'{minus[i]}{component.id}_{i}' for i in range(bit_len)]) - self._model_constraints.append(' '.join(literals)) - solution = self.solve(XOR_DIFFERENTIAL, solver_name=solver_name) - solution['building_time_seconds'] = end_building_time - start_building_time - solution['test_name'] = "find_all_xor_differential_trails_with_fixed_weight" + if component.type == SBOX or ( + component.type == WORD_OPERATION + and component.description[0] in ("AND", "MODADD", "MODSUB", "OR", "SHIFT_BY_VARIABLE_AMOUNT") + ): + value_to_avoid = int(solution["components_values"][component.id]["value"], base=16) + minus = ["-" * (value_to_avoid >> i & 1) for i in reversed(range(bit_len))] + literals.extend([f"{minus[i]}{component.id}_{i}" for i in range(bit_len)]) + self._model_constraints.append(" ".join(literals)) + solution = self.solve(XOR_DIFFERENTIAL, solver_name=solver_name, options=options) + solution["building_time_seconds"] = end_building_time - start_building_time + solution["test_name"] = "find_all_xor_differential_trails_with_fixed_weight" return solutions_list - def find_all_xor_differential_trails_with_weight_at_most(self, min_weight, max_weight, fixed_values=[], - solver_name=solvers.SOLVER_DEFAULT): + def find_all_xor_differential_trails_with_weight_at_most( + self, min_weight, max_weight, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT, options=None + ): """ Return a list of solutions. By default, the search is set in the single-key setting. @@ -279,17 +290,19 @@ def find_all_xor_differential_trails_with_weight_at_most(self, min_weight, max_w """ solutions_list = [] for weight in range(min_weight, max_weight + 1): - solutions = self.find_all_xor_differential_trails_with_fixed_weight(weight, - fixed_values=fixed_values, - solver_name=solver_name) + solutions = self.find_all_xor_differential_trails_with_fixed_weight( + weight, fixed_values=fixed_values, solver_name=solver_name, options=options + ) for solution in solutions: - solution['test_name'] = "find_all_xor_differential_trails_with_weight_at_most" + solution["test_name"] = "find_all_xor_differential_trails_with_weight_at_most" solutions_list.extend(solutions) return solutions_list - def find_lowest_weight_xor_differential_trail(self, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT): + def find_lowest_weight_xor_differential_trail( + self, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT, options=None + ): """ Return the solution representing a trail with the lowest weight. By default, the search is set in the single-key setting. @@ -338,26 +351,26 @@ def find_lowest_weight_xor_differential_trail(self, fixed_values=[], solver_name start_building_time = time.time() self.build_xor_differential_trail_model(weight=current_weight, fixed_variables=fixed_values) end_building_time = time.time() - solution = self.solve(XOR_DIFFERENTIAL, solver_name=solver_name) - solution['building_time_seconds'] = end_building_time - start_building_time - total_time = solution['solving_time_seconds'] - max_memory = solution['memory_megabytes'] - while solution['total_weight'] is None: + solution = self.solve(XOR_DIFFERENTIAL, solver_name=solver_name, options=options) + solution["building_time_seconds"] = end_building_time - start_building_time + total_time = solution["solving_time_seconds"] + max_memory = solution["memory_megabytes"] + while solution["total_weight"] is None: current_weight += 1 start_building_time = time.time() self.build_xor_differential_trail_model(weight=current_weight, fixed_variables=fixed_values) end_building_time = time.time() - solution = self.solve(XOR_DIFFERENTIAL, solver_name=solver_name) - solution['building_time_seconds'] = end_building_time - start_building_time - total_time += solution['solving_time_seconds'] - max_memory = max((max_memory, solution['memory_megabytes'])) - solution['solving_time_seconds'] = total_time - solution['memory_megabytes'] = max_memory - solution['test_name'] = "find_lowest_weight_xor_differential_trail" + solution = self.solve(XOR_DIFFERENTIAL, solver_name=solver_name, options=options) + solution["building_time_seconds"] = end_building_time - start_building_time + total_time += solution["solving_time_seconds"] + max_memory = max((max_memory, solution["memory_megabytes"])) + solution["solving_time_seconds"] = total_time + solution["memory_megabytes"] = max_memory + solution["test_name"] = "find_lowest_weight_xor_differential_trail" return solution - def find_one_xor_differential_trail(self, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT): + def find_one_xor_differential_trail(self, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT, options=None): """ Return the solution representing a XOR differential trail. By default, the search is set in the single-key setting. @@ -400,14 +413,15 @@ def find_one_xor_differential_trail(self, fixed_values=[], solver_name=solvers.S start_building_time = time.time() self.build_xor_differential_trail_model(fixed_variables=fixed_values) end_building_time = time.time() - solution = self.solve(XOR_DIFFERENTIAL, solver_name=solver_name) - solution['building_time_seconds'] = end_building_time - start_building_time - solution['test_name'] = "find_one_xor_differential_trail" + solution = self.solve(XOR_DIFFERENTIAL, solver_name=solver_name, options=options) + solution["building_time_seconds"] = end_building_time - start_building_time + solution["test_name"] = "find_one_xor_differential_trail" return solution - def find_one_xor_differential_trail_with_fixed_weight(self, fixed_weight, fixed_values=[], - solver_name=solvers.SOLVER_DEFAULT): + def find_one_xor_differential_trail_with_fixed_weight( + self, fixed_weight, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT, options=None + ): """ Return the solution representing a XOR differential trail whose probability is ``2 ** fixed_weight``. By default, the search is set in the single-key setting. @@ -452,32 +466,32 @@ def find_one_xor_differential_trail_with_fixed_weight(self, fixed_weight, fixed_ start_building_time = time.time() self.build_xor_differential_trail_model(weight=fixed_weight, fixed_variables=fixed_values) if self._counter == self._sequential_counter: - self._sequential_counter_greater_or_equal(fixed_weight, 'dummy_hw_1') + self._sequential_counter_greater_or_equal(fixed_weight, "dummy_hw_1") end_building_time = time.time() - solution = self.solve(XOR_DIFFERENTIAL, solver_name=solver_name) - solution['building_time_seconds'] = end_building_time - start_building_time - solution['test_name'] = "find_one_xor_differential_trail_with_fixed_weight" + solution = self.solve(XOR_DIFFERENTIAL, solver_name=solver_name, options=options) + solution["building_time_seconds"] = end_building_time - start_building_time + solution["test_name"] = "find_one_xor_differential_trail_with_fixed_weight" return solution def _parse_solver_output(self, variable2value): - out_suffix = '' + out_suffix = "" components_solutions = self._get_cipher_inputs_components_solutions(out_suffix, variable2value) total_weight = 0 for component in self._cipher.get_all_components(): hex_value = self._get_component_hex_value(component, out_suffix, variable2value) weight = self.calculate_component_weight(component, out_suffix, variable2value) component_solution = set_component_solution(hex_value, weight) - components_solutions[f'{component.id}{out_suffix}'] = component_solution + components_solutions[f"{component.id}{out_suffix}"] = component_solution total_weight += weight return components_solutions, total_weight def set_window_size_heuristic_by_round( - self, window_size_by_round_values, number_of_full_windows=None, full_window_operator='at_least' + self, window_size_by_round_values, number_of_full_windows=None, full_window_operator="at_least" ): if not self._cipher.is_arx(): - raise Exception('Cipher is not ARX. Window Size Heuristic is only supported for ARX ciphers.') + raise TypeError("Cipher is not ARX. Window Size Heuristic is only supported for ARX ciphers.") self._window_size_by_round_values = window_size_by_round_values if number_of_full_windows is not None: self._window_size_full_window_vars = [] @@ -485,23 +499,24 @@ def set_window_size_heuristic_by_round( self._window_size_full_window_operator = full_window_operator def set_window_size_heuristic_by_component_id( - self, window_size_by_component_id_values, number_of_full_windows=None, full_window_operator='at_least' + self, window_size_by_component_id_values, number_of_full_windows=None, full_window_operator="at_least" ): if not self._cipher.is_arx(): - raise Exception('Cipher is not ARX. Window Size Heuristic is only supported for ARX ciphers.') + raise TypeError("Cipher is not ARX. Window Size Heuristic is only supported for ARX ciphers.") self._window_size_by_component_id_values = window_size_by_component_id_values if number_of_full_windows is not None: self._window_size_full_window_vars = [] self._window_size_number_of_full_window = number_of_full_windows self._window_size_full_window_operator = full_window_operator - def set_window_size_weight_pr_vars(self, window_size_weight_pr_vars): - self._window_size_weight_pr_vars = window_size_weight_pr_vars - @property def window_size_weight_pr_vars(self): return self._window_size_weight_pr_vars + @window_size_weight_pr_vars.setter + def window_size_weight_pr_vars(self, window_size_weight_pr_vars): + self._window_size_weight_pr_vars = window_size_weight_pr_vars + @property def window_size_number_of_full_window(self): return self._window_size_number_of_full_window @@ -514,6 +529,10 @@ def window_size_full_window_vars(self): def window_size_by_round_values(self): return self._window_size_by_round_values + @window_size_by_round_values.setter + def window_size_by_round_values(self, window_size_by_round_values): + self._window_size_by_round_values = window_size_by_round_values + @property def window_size_by_component_id_values(self): return self._window_size_by_component_id_values diff --git a/claasp/cipher_modules/models/sat/sat_models/sat_xor_linear_model.py b/claasp/cipher_modules/models/sat/sat_models/sat_xor_linear_model.py index 8d366a763..f67ea82a6 100644 --- a/claasp/cipher_modules/models/sat/sat_models/sat_xor_linear_model.py +++ b/claasp/cipher_modules/models/sat/sat_models/sat_xor_linear_model.py @@ -1,38 +1,48 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** - import time from claasp.cipher_modules.models.sat import solvers -from claasp.cipher_modules.models.sat.utils import constants, utils from claasp.cipher_modules.models.sat.sat_model import SatModel +from claasp.cipher_modules.models.sat.utils import constants, utils from claasp.cipher_modules.models.sat.utils.constants import OUTPUT_BIT_ID_SUFFIX, INPUT_BIT_ID_SUFFIX -from claasp.cipher_modules.models.utils import get_bit_bindings, set_component_solution, \ - get_single_key_scenario_format_for_fixed_values -from claasp.name_mappings import (CIPHER_OUTPUT, CONSTANT, INTERMEDIATE_OUTPUT, LINEAR_LAYER, - MIX_COLUMN, SBOX, WORD_OPERATION, XOR_LINEAR, INPUT_KEY) +from claasp.cipher_modules.models.utils import ( + get_bit_bindings, + set_component_solution, + get_single_key_scenario_format_for_fixed_values, +) +from claasp.name_mappings import ( + CIPHER_OUTPUT, + CONSTANT, + INPUT_KEY, + INTERMEDIATE_OUTPUT, + LINEAR_LAYER, + MIX_COLUMN, + SBOX, + WORD_OPERATION, + XOR_LINEAR, +) class SatXorLinearModel(SatModel): - def __init__(self, cipher, counter='sequential', compact=False): + def __init__(self, cipher, counter="sequential", compact=False): super().__init__(cipher, counter, compact) - self.bit_bindings, self.bit_bindings_for_intermediate_output = get_bit_bindings(cipher, '_'.join) + self.bit_bindings, self.bit_bindings_for_intermediate_output = get_bit_bindings(cipher, "_".join) @staticmethod def branch_xor_linear_constraints(bindings): @@ -89,7 +99,7 @@ def build_xor_linear_trail_model(self, weight=-1, fixed_variables=[]): variables = [] if INPUT_KEY not in [variable["component_id"] for variable in fixed_variables]: self._cipher = self._cipher.remove_key_schedule() - self.bit_bindings, self.bit_bindings_for_intermediate_output = get_bit_bindings(self._cipher, '_'.join) + self.bit_bindings, self.bit_bindings_for_intermediate_output = get_bit_bindings(self._cipher, "_".join) if fixed_variables == []: fixed_variables = get_single_key_scenario_format_for_fixed_values(self._cipher) constraints = SatXorLinearModel.fix_variables_value_xor_linear_constraints(fixed_variables) @@ -102,7 +112,7 @@ def build_xor_linear_trail_model(self, weight=-1, fixed_variables=[]): if component.type in component_types and (component.type != WORD_OPERATION or operation in operation_types): variables, constraints = component.sat_xor_linear_mask_propagation_constraints(self) else: - print(f'{component.id} not yet implemented') + print(f"{component.id} not yet implemented") self._variables_list.extend(variables) self._model_constraints.extend(constraints) @@ -115,8 +125,9 @@ def build_xor_linear_trail_model(self, weight=-1, fixed_variables=[]): self._variables_list.extend(variables) self._model_constraints.extend(constraints) - def find_all_xor_linear_trails_with_fixed_weight(self, fixed_weight, fixed_values=[], - solver_name=solvers.SOLVER_DEFAULT): + def find_all_xor_linear_trails_with_fixed_weight( + self, fixed_weight, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT, options=None + ): """ Return a list of solutions containing all the XOR linear trails having weight equal to ``fixed_weight``. By default, the search removes the key schedule, if any. @@ -156,19 +167,19 @@ def find_all_xor_linear_trails_with_fixed_weight(self, fixed_weight, fixed_value start_building_time = time.time() self.build_xor_linear_trail_model(weight=fixed_weight, fixed_variables=fixed_values) if self._counter == self._sequential_counter: - self._sequential_counter_greater_or_equal(fixed_weight, 'dummy_hw_1') + self._sequential_counter_greater_or_equal(fixed_weight, "dummy_hw_1") end_building_time = time.time() - solution = self.solve(XOR_LINEAR, solver_name=solver_name) - solution['building_time_seconds'] = end_building_time - start_building_time + solution = self.solve(XOR_LINEAR, solver_name=solver_name, options=options) + solution["building_time_seconds"] = end_building_time - start_building_time solutions_list = [] - while solution['total_weight'] is not None: + while solution["total_weight"] is not None: solutions_list.append(solution) literals = [] - for component in solution['components_values']: - value_as_hex_string = solution['components_values'][component]['value'] + for component in solution["components_values"]: + value_as_hex_string = solution["components_values"][component]["value"] value_to_avoid = int(value_as_hex_string, base=16) bit_len = (len(value_as_hex_string) - 2) * 4 - minus = ['-' * (value_to_avoid >> i & 1) for i in reversed(range(bit_len))] + minus = ["-" * (value_to_avoid >> i & 1) for i in reversed(range(bit_len))] if CONSTANT in component and component.endswith(INPUT_BIT_ID_SUFFIX): continue elif component.endswith(INPUT_BIT_ID_SUFFIX) or component.endswith(OUTPUT_BIT_ID_SUFFIX): @@ -177,15 +188,16 @@ def find_all_xor_linear_trails_with_fixed_weight(self, fixed_weight, fixed_value else: component_id = component suffix = OUTPUT_BIT_ID_SUFFIX - literals.extend([f'{minus[i]}{component_id}_{i}{suffix}' for i in range(bit_len)]) - self._model_constraints.append(' '.join(literals)) - solution = self.solve(XOR_LINEAR, solver_name=solver_name) - solution['building_time_seconds'] = end_building_time - start_building_time - solution['test_name'] = "find_all_xor_linear_trails_with_fixed_weight" + literals.extend([f"{minus[i]}{component_id}_{i}{suffix}" for i in range(bit_len)]) + self._model_constraints.append(" ".join(literals)) + solution = self.solve(XOR_LINEAR, solver_name=solver_name, options=options) + solution["building_time_seconds"] = end_building_time - start_building_time + solution["test_name"] = "find_all_xor_linear_trails_with_fixed_weight" return solutions_list - def find_all_xor_linear_trails_with_weight_at_most(self, min_weight, max_weight, fixed_values=[], - solver_name=solvers.SOLVER_DEFAULT): + def find_all_xor_linear_trails_with_weight_at_most( + self, min_weight, max_weight, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT, options=None + ): """ Return a list of solutions. By default, the search removes the key schedule, if any. @@ -227,16 +239,16 @@ def find_all_xor_linear_trails_with_weight_at_most(self, min_weight, max_weight, """ solutions_list = [] for weight in range(min_weight, max_weight + 1): - solutions = self.find_all_xor_linear_trails_with_fixed_weight(weight, - fixed_values=fixed_values, - solver_name=solver_name) + solutions = self.find_all_xor_linear_trails_with_fixed_weight( + weight, fixed_values=fixed_values, solver_name=solver_name, options=options + ) for solution in solutions: - solution['test_name'] = "find_all_xor_linear_trails_with_weight_at_most" + solution["test_name"] = "find_all_xor_linear_trails_with_weight_at_most" solutions_list.extend(solutions) return solutions_list - def find_lowest_weight_xor_linear_trail(self, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT): + def find_lowest_weight_xor_linear_trail(self, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT, options=None): """ Return the solution representing a XOR LINEAR trail with the lowest possible weight. By default, the search removes the key schedule, if any. @@ -281,26 +293,26 @@ def find_lowest_weight_xor_linear_trail(self, fixed_values=[], solver_name=solve start_building_time = time.time() self.build_xor_linear_trail_model(weight=current_weight, fixed_variables=fixed_values) end_building_time = time.time() - solution = self.solve(XOR_LINEAR, solver_name=solver_name) - solution['building_time_seconds'] = end_building_time - start_building_time - total_time = solution['solving_time_seconds'] - max_memory = solution['memory_megabytes'] - while solution['total_weight'] is None: + solution = self.solve(XOR_LINEAR, solver_name=solver_name, options=options) + solution["building_time_seconds"] = end_building_time - start_building_time + total_time = solution["solving_time_seconds"] + max_memory = solution["memory_megabytes"] + while solution["total_weight"] is None: current_weight += 1 start_building_time = time.time() self.build_xor_linear_trail_model(weight=current_weight, fixed_variables=fixed_values) end_building_time = time.time() - solution = self.solve(XOR_LINEAR, solver_name=solver_name) - solution['building_time_seconds'] = end_building_time - start_building_time - total_time += solution['solving_time_seconds'] - max_memory = max((max_memory, solution['memory_megabytes'])) - solution['solving_time_seconds'] = total_time - solution['memory_megabytes'] = max_memory - solution['test_name'] = "find_lowest_weight_xor_linear_trail" + solution = self.solve(XOR_LINEAR, solver_name=solver_name, options=options) + solution["building_time_seconds"] = end_building_time - start_building_time + total_time += solution["solving_time_seconds"] + max_memory = max((max_memory, solution["memory_megabytes"])) + solution["solving_time_seconds"] = total_time + solution["memory_megabytes"] = max_memory + solution["test_name"] = "find_lowest_weight_xor_linear_trail" return solution - def find_one_xor_linear_trail(self, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT): + def find_one_xor_linear_trail(self, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT, options=None): """ Return the solution representing a XOR linear trail. By default, the search removes the key schedule, if any. @@ -345,14 +357,15 @@ def find_one_xor_linear_trail(self, fixed_values=[], solver_name=solvers.SOLVER_ start_building_time = time.time() self.build_xor_linear_trail_model(fixed_variables=fixed_values) end_building_time = time.time() - solution = self.solve(XOR_LINEAR, solver_name=solver_name) - solution['building_time_seconds'] = end_building_time - start_building_time - solution['test_name'] = "find_one_xor_linear_trail" + solution = self.solve(XOR_LINEAR, solver_name=solver_name, options=options) + solution["building_time_seconds"] = end_building_time - start_building_time + solution["test_name"] = "find_one_xor_linear_trail" return solution - def find_one_xor_linear_trail_with_fixed_weight(self, fixed_weight, fixed_values=[], - solver_name=solvers.SOLVER_DEFAULT): + def find_one_xor_linear_trail_with_fixed_weight( + self, fixed_weight, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT, options=None + ): """ Return the solution representing a XOR linear trail whose weight is ``fixed_weight``. By default, the search removes the key schedule, if any. @@ -392,11 +405,11 @@ def find_one_xor_linear_trail_with_fixed_weight(self, fixed_weight, fixed_values start_building_time = time.time() self.build_xor_linear_trail_model(weight=fixed_weight, fixed_variables=fixed_values) if self._counter == self._sequential_counter: - self._sequential_counter_greater_or_equal(fixed_weight, 'dummy_hw_1') + self._sequential_counter_greater_or_equal(fixed_weight, "dummy_hw_1") end_building_time = time.time() - solution = self.solve(XOR_LINEAR, solver_name=solver_name) - solution['building_time_seconds'] = end_building_time - start_building_time - solution['test_name'] = "find_one_xor_linear_trail_with_fixed_weight" + solution = self.solve(XOR_LINEAR, solver_name=solver_name, options=options) + solution["building_time_seconds"] = end_building_time - start_building_time + solution["test_name"] = "find_one_xor_linear_trail_with_fixed_weight" return solution @@ -440,18 +453,18 @@ def fix_variables_value_xor_linear_constraints(fixed_variables=[]): constraints = [] out_suffix = constants.OUTPUT_BIT_ID_SUFFIX for variable in fixed_variables: - component_id = variable['component_id'] - is_equal = (variable['constraint_type'] == 'equal') - bit_positions = variable['bit_positions'] - bit_values = variable['bit_values'] + component_id = variable["component_id"] + is_equal = variable["constraint_type"] == "equal" + bit_positions = variable["bit_positions"] + bit_values = variable["bit_values"] variables_ids = [] for position, value in zip(bit_positions, bit_values): - is_negative = '-' * (value ^ is_equal) - variables_ids.append(f'{is_negative}{component_id}_{position}{out_suffix}') + is_negative = "-" * (value ^ is_equal) + variables_ids.append(f"{is_negative}{component_id}_{position}{out_suffix}") if is_equal: constraints.extend(variables_ids) else: - constraints.append(' '.join(variables_ids)) + constraints.append(" ".join(variables_ids)) return constraints @@ -467,11 +480,11 @@ def _parse_solver_output(self, variable2value): hex_solution = self._get_component_hex_value(component, out_suffix, variable2value) weight = self.calculate_component_weight(component, out_suffix, variable2value) component_solution = set_component_solution(hex_solution, weight) - components_solutions[f'{component.id}{out_suffix}'] = component_solution + components_solutions[f"{component.id}{out_suffix}"] = component_solution total_weight += weight input_hex_value = self._get_component_hex_value(component, in_suffix, variable2value) component_solution = set_component_solution(input_hex_value, 0) - components_solutions[f'{component.id}{in_suffix}'] = component_solution + components_solutions[f"{component.id}{in_suffix}"] = component_solution return components_solutions, total_weight diff --git a/claasp/cipher_modules/models/sat/solvers.py b/claasp/cipher_modules/models/sat/solvers.py index d2124e148..a7a76ad29 100644 --- a/claasp/cipher_modules/models/sat/solvers.py +++ b/claasp/cipher_modules/models/sat/solvers.py @@ -34,26 +34,42 @@ needed. """ +# internal solvers definition +CRYPTOMINISAT = "cryptominisat" +PICOSAT = "picosat" +GLUCOSE = "glucose" +GLUCOSE_SYRUP = "glucose-syrup" +# external solvers definition +CADICAL_EXT = "CADICAL_EXT" +CRYPTOMINISAT_EXT = "CRYPTOMINISAT_EXT" +GLUCOSE_EXT = "GLUCOSE_EXT" +GLUCOSE_SYRUP_EXT = "GLUCOSE_SYRUP_EXT" +MATHSAT_EXT = "MATHSAT_EXT" +MINISAT_EXT = "MINISAT_EXT" +KISSAT_EXT = "KISSAT_EXT" +PARKISSAT_EXT = "PARKISSAT_EXT" +YICES_SAT_EXT = "YICES_SAT_EXT" -SOLVER_DEFAULT = "CRYPTOMINISAT_EXT" + +SOLVER_DEFAULT = CRYPTOMINISAT_EXT SAT_SOLVERS_INTERNAL = [ { "solver_brand_name": "CryptoMiniSat SAT solver (using Sage backend)", - "solver_name": "cryptominisat", + "solver_name": CRYPTOMINISAT, }, { "solver_brand_name": "PicoSAT (using Sage backend)", - "solver_name": "picosat", + "solver_name": PICOSAT, }, { "solver_brand_name": "Glucose SAT solver (using Sage backend)", - "solver_name": "glucose", + "solver_name": GLUCOSE, }, { "solver_brand_name": "Glucose (Syrup) SAT solver (using Sage backend)", - "solver_name": "glucose-syrup", + "solver_name": GLUCOSE_SYRUP, }, ] @@ -61,7 +77,7 @@ SAT_SOLVERS_EXTERNAL = [ { "solver_brand_name": "CaDiCal Simplified Satisfiability Solver", - "solver_name": "CADICAL_EXT", + "solver_name": CADICAL_EXT, "keywords": { "command": { "executable": "cadical", @@ -80,7 +96,7 @@ }, { "solver_brand_name": "CryptoMiniSat SAT solver", - "solver_name": "CRYPTOMINISAT_EXT", + "solver_name": CRYPTOMINISAT_EXT, "keywords": { "command": { "executable": "cryptominisat5", @@ -99,7 +115,7 @@ }, { "solver_brand_name": "Glucose SAT solver", - "solver_name": "GLUCOSE_EXT", + "solver_name": GLUCOSE_EXT, "keywords": { "command": { "executable": "glucose", @@ -118,7 +134,7 @@ }, { "solver_brand_name": "Glucose (Syrup) SAT solver", - "solver_name": "GLUCOSE_SYRUP_EXT", + "solver_name": GLUCOSE_SYRUP_EXT, "keywords": { "command": { "executable": "glucose-syrup", @@ -137,7 +153,7 @@ }, { "solver_brand_name": "The Kissat SAT solver", - "solver_name": "KISSAT_EXT", + "solver_name": KISSAT_EXT, "keywords": { "command": { "executable": "kissat", @@ -156,7 +172,7 @@ }, { "solver_brand_name": "ParKissat-RS", - "solver_name": "PARKISSAT_EXT", + "solver_name": PARKISSAT_EXT, "keywords": { "command": { "executable": "parkissat", @@ -175,7 +191,7 @@ }, { "solver_brand_name": "MathSAT", - "solver_name": "MATHSAT_EXT", + "solver_name": MATHSAT_EXT, "keywords": { "command": { "executable": "mathsat", @@ -194,7 +210,7 @@ }, { "solver_brand_name": "MiniSat", - "solver_name": "MINISAT_EXT", + "solver_name": MINISAT_EXT, "keywords": { "command": { "executable": "minisat", @@ -213,7 +229,7 @@ }, { "solver_brand_name": "Yices2", - "solver_name": "YICES_SAT_EXT", + "solver_name": YICES_SAT_EXT, "keywords": { "command": { "executable": "yices-sat", diff --git a/claasp/cipher_modules/models/sat/utils/constants.py b/claasp/cipher_modules/models/sat/utils/constants.py index 956ecbde6..2696a6482 100644 --- a/claasp/cipher_modules/models/sat/utils/constants.py +++ b/claasp/cipher_modules/models/sat/utils/constants.py @@ -1,2 +1,2 @@ -INPUT_BIT_ID_SUFFIX = '_i' -OUTPUT_BIT_ID_SUFFIX = '_o' +INPUT_BIT_ID_SUFFIX = "_i" +OUTPUT_BIT_ID_SUFFIX = "_o" diff --git a/claasp/cipher_modules/models/sat/utils/mzn_predicates.py b/claasp/cipher_modules/models/sat/utils/mzn_predicates.py index 3fa9f55e3..d01257c6e 100644 --- a/claasp/cipher_modules/models/sat/utils/mzn_predicates.py +++ b/claasp/cipher_modules/models/sat/utils/mzn_predicates.py @@ -1,24 +1,22 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** def get_word_operations(): - functions_with_window_size = """ % Left rotation of X by val positions function array[int] of var bool: LRot(array[int] of var bool: X, int: val)= diff --git a/claasp/cipher_modules/models/sat/utils/n_window_heuristic_helper.py b/claasp/cipher_modules/models/sat/utils/n_window_heuristic_helper.py index 0c3d4fb94..d068ab5c6 100644 --- a/claasp/cipher_modules/models/sat/utils/n_window_heuristic_helper.py +++ b/claasp/cipher_modules/models/sat/utils/n_window_heuristic_helper.py @@ -1,17 +1,16 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -27,55 +26,55 @@ def save_list(data, filename): """Save a list to a file using pickle.""" try: - with open(filename, 'wb') as file: + with open(filename, "wb") as file: pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) print(f"List successfully saved to {filename}") except Exception as e: print(f"Error saving list: {e}") + def load_list(filename): """Load a list from a file using pickle.""" try: - with open(filename, 'rb') as file: + with open(filename, "rb") as file: return pickle.load(file) except Exception as e: print(f"Error loading list: {e}") return None + def generating_n_window_clauses(window_size_plus_one): def compute_ex(i): return Xor(first_diff_addend_vars[i], second_diff_addend_vars[i], output_diff_vars[i]) - - - filename = f"{window_size_plus_one-1}-window_size_list_of_clauses.pkl" + filename = f"{window_size_plus_one - 1}-window_size_list_of_clauses.pkl" if os.path.exists(filename): return load_list(filename) - # Define your variables - first_diff_addend_vars = symbols('a[:{}]'.format(window_size_plus_one)) - second_diff_addend_vars = symbols('b[:{}]'.format(window_size_plus_one)) - output_diff_vars = symbols('c[:{}]'.format(window_size_plus_one)) - temp_var = symbols('aux') + first_diff_addend_vars = symbols("a[:{}]".format(window_size_plus_one)) + second_diff_addend_vars = symbols("b[:{}]".format(window_size_plus_one)) + output_diff_vars = symbols("c[:{}]".format(window_size_plus_one)) + temp_var = symbols("aux") if window_size_plus_one == 1: ex = Not( Xor( first_diff_addend_vars[window_size_plus_one - 1], second_diff_addend_vars[window_size_plus_one - 1], - output_diff_vars[window_size_plus_one - 1] + output_diff_vars[window_size_plus_one - 1], ) ) else: results = Parallel(n_jobs=-1)(delayed(compute_ex)(i) for i in range(window_size_plus_one - 1)) ex2 = Equivalent(And(*results), temp_var) ex1 = And( - temp_var, Xor( + temp_var, + Xor( first_diff_addend_vars[window_size_plus_one - 1], second_diff_addend_vars[window_size_plus_one - 1], - output_diff_vars[window_size_plus_one - 1] - ) + output_diff_vars[window_size_plus_one - 1], + ), ) ex = And(Not(ex1), ex2) @@ -88,19 +87,19 @@ def compute_ex(i): def convert_clauses(clauses): import re - clean_clauses = re.sub(r'[{}()\s]', '', clauses) + clean_clauses = re.sub(r"[{}()\s]", "", clauses) - clause_list = clean_clauses.split('&') + clause_list = clean_clauses.split("&") formatted_clauses = [] for clause in clause_list: - literals = clause.split('|') + literals = clause.split("|") pos_vars = [] neg_vars = [] for literal in literals: - if literal.startswith('~'): + if literal.startswith("~"): neg_vars.append(literal[1:]) else: pos_vars.append(literal) @@ -108,19 +107,20 @@ def convert_clauses(clauses): pos_vars.sort() neg_vars.sort() - formatted_clause = f"f'" - formatted_clause += f" ".join(f"{{{var}}}" for var in pos_vars) - formatted_clause += f" " if pos_vars and neg_vars else "" + formatted_clause = "f'" + formatted_clause += " ".join(f"{{{var}}}" for var in pos_vars) + formatted_clause += " " if pos_vars and neg_vars else "" formatted_clause += " ".join(f"-{{{var}}}" for var in neg_vars) - formatted_clause += f"'" + formatted_clause += "'" formatted_clauses.append(formatted_clause) return formatted_clauses + def generate_window_size_clauses(first_input_difference, second_input_difference, output_difference, aux_var): """ - Returns a set of clauses representing a simplified CNF (Conjunctive Normal Form) expression + Returns a set of clauses representing a simplified CNF (Conjunctive Normal Form) expression for the n-window size heuristic applied to a + b = c. Specifically, these clauses ensure that no more than n variables are true (i.e., there are no sequences of n+1 ones in the carry differences of a + b = c). These clauses were obtained after simplifying the formula below (in sympy notation): @@ -149,12 +149,7 @@ def generate_window_size_clauses(first_input_difference, second_input_difference """ window_size_plus_one = len(first_input_difference) - context = { - 'a': first_input_difference, - 'b': second_input_difference, - 'c': output_difference, - 'aux': aux_var - } + context = {"a": first_input_difference, "b": second_input_difference, "c": output_difference, "aux": aux_var} new_clauses = [] string_generated_clauses = generating_n_window_clauses(window_size_plus_one) diff --git a/claasp/cipher_modules/models/sat/utils/utils.py b/claasp/cipher_modules/models/sat/utils/utils.py index e01307ca2..af61cd31d 100644 --- a/claasp/cipher_modules/models/sat/utils/utils.py +++ b/claasp/cipher_modules/models/sat/utils/utils.py @@ -1,17 +1,16 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -46,10 +45,14 @@ method for SAT solvers in :py:class:`Sat Model `. """ + import itertools import os import re import subprocess +import time + +from claasp.cipher_modules.models.sat import solvers # ----------------- # @@ -64,50 +67,48 @@ def cms_add_clauses_to_solver(numerical_cnf, solver): It needs to be overwritten in this class because it must handle the XOR clauses. """ for clause in numerical_cnf: - if clause.startswith('x '): - rhs = bool(True ^ (clause.count('-') % 2)) - literals = clause.replace('-', '').split()[1:] + if clause.startswith("x "): + rhs = bool(1 ^ (clause.count("-") % 2)) + literals = clause.replace("-", "").split()[1:] solver.add_xor_clause([int(literal) for literal in literals], rhs) else: solver.add_clause([int(literal) for literal in clause.split()]) -def create_numerical_cnf(cnf): +def create_numerical_cnf(cnf: list[str]) -> tuple[list[str], list[str]]: # creating dictionary (variable -> string, numeric_id -> int) - family_of_variables = ' '.join(cnf).replace('-', '') - if family_of_variables.startswith('x '): - family_of_variables = family_of_variables[2:] - family_of_variables = family_of_variables.replace(' x ', ' ') - variables = sorted(set(family_of_variables.split())) - variable2number = {variable: i + 1 for (i, variable) in enumerate(variables)} + variables = " ".join(cnf).replace("-", "").replace("x ", " ") + variables = sorted(set(variables.split())) + variable_to_number = {variable: i + 1 for i, variable in enumerate(variables)} # creating numerical CNF numerical_cnf = [] for clause in cnf: literals = clause.split() numerical_literals = [] - if literals[0] == 'x': + if literals[0] == "x": literals = literals[1:] - numerical_literals = ['x'] - lits_are_neg = (literal[0] == '-' for literal in literals) - numerical_literals.extend(tuple(f'{"-" * lit_is_neg}{variable2number[literal[lit_is_neg:]]}' - for lit_is_neg, literal in zip(lits_are_neg, literals))) - numerical_clause = ' '.join(numerical_literals) - numerical_cnf.append(numerical_clause) + numerical_literals = ["x"] + signs = (literal[0] == "-" for literal in literals) + numerical_literals.extend( + [f"{'-' * sign}{variable_to_number[literal[sign:]]}" for sign, literal in zip(signs, literals)] + ) + numerical_cnf.append(" ".join(numerical_literals)) - return variable2number, numerical_cnf + return variables, numerical_cnf -def numerical_cnf_to_dimacs(number_of_variables, numerical_cnf): - dimacs = f'p cnf {number_of_variables} {len(numerical_cnf)}\n' - dimacs_clauses = tuple(f'{numerical_clause} 0\n' for numerical_clause in numerical_cnf) +def numerical_cnf_to_dimacs(variables: list[str], numerical_cnf: list[str]) -> str: + dimacs = [f"p cnf {len(variables)} {len(numerical_cnf)}"] + dimacs.extend(f"{numerical_clause} 0" for numerical_clause in numerical_cnf) + dimacs = "\n".join(dimacs) - return dimacs + ''.join(dimacs_clauses) + return dimacs def cnf_n_window_heuristic_on_w_vars(hw_bit_ids): - cnf_constraint_lst = [f'-{hw_bit}' for hw_bit in hw_bit_ids] + cnf_constraint_lst = [f"-{hw_bit}" for hw_bit in hw_bit_ids] - return [' '.join(cnf_constraint_lst)] + return [" ".join(cnf_constraint_lst)] # ----------------------------------------------------------------- # @@ -133,7 +134,7 @@ def cnf_equivalent(variables): """ variables_shifted = [variables[-1]] + variables[:-1] - return [f'{variables[i]} -{variables_shifted[i]}' for i in range(len(variables))] + return [f"{variables[i]} -{variables_shifted[i]}" for i in range(len(variables))] def cnf_inequality(left_var, right_var): @@ -151,7 +152,7 @@ def cnf_inequality(left_var, right_var): sage: cnf_inequality('a', 'b') ('a b', '-a -b') """ - return (f'{left_var} {right_var}', f'-{left_var} -{right_var}') + return (f"{left_var} {right_var}", f"-{left_var} -{right_var}") def cnf_and(result, variables): @@ -171,8 +172,8 @@ def cnf_and(result, variables): sage: cnf_and('r', ['a', 'b', 'c']) ['-r a', '-r b', '-r c', 'r -a -b -c'] """ - cnf = [f'-{result} {variable}' for variable in variables] - cnf.append(f'{result} -{" -".join(variables)}') + cnf = [f"-{result} {variable}" for variable in variables] + cnf.append(f"{result} -{' -'.join(variables)}") return cnf @@ -202,8 +203,8 @@ def cnf_or(result, variables): sage: cnf_or('r', ['a', 'b', 'c']) ['r -a', 'r -b', 'r -c', '-r a b c'] """ - model = [f'{result} -{variable}' for variable in variables] - model.append(f'-{result} {" ".join(variables)}') + model = [f"{result} -{variable}" for variable in variables] + model.append(f"-{result} {' '.join(variables)}") return model @@ -246,8 +247,8 @@ def cnf_xor(result, variables): for i in range(1, num_of_operands + 1, 2): subsets = tuple(itertools.combinations(range(num_of_operands), i)) for s in subsets: - literals = ['-' * (j in s) + f'{operands[j]}' for j in range(num_of_operands)] - model.append(' '.join(literals)) + literals = ["-" * (j in s) + f"{operands[j]}" for j in range(num_of_operands)] + model.append(" ".join(literals)) return model @@ -310,12 +311,14 @@ def cnf_carry(carry, x, y, previous_carry): 'y_3 c_2 -c_3', '-y_3 -c_2 c_3') """ - return (f'{x} {y} -{carry}', - f'-{x} -{y} {carry}', - f'{x} {previous_carry} -{carry}', - f'-{x} -{previous_carry} {carry}', - f'{y} {previous_carry} -{carry}', - f'-{y} -{previous_carry} {carry}') + return ( + f"{x} {y} -{carry}", + f"-{x} -{y} {carry}", + f"{x} {previous_carry} -{carry}", + f"-{x} -{previous_carry} {carry}", + f"{y} {previous_carry} -{carry}", + f"-{y} -{previous_carry} {carry}", + ) def cnf_carry_comp2(carry, x, previous_carry): @@ -337,9 +340,7 @@ def cnf_carry_comp2(carry, x, previous_carry): sage: cnf_carry_comp2('c_3', 'x_2', 'c_2') ('-c_3 c_2', '-c_3 -x_2', 'c_3 -c_2 x_2') """ - return (f'-{carry} {previous_carry}', - f'-{carry} -{x}', - f'{carry} -{previous_carry} {x}') + return (f"-{carry} {previous_carry}", f"-{carry} -{x}", f"{carry} -{previous_carry} {x}") def cnf_result_comp2(result, x, carry): @@ -360,10 +361,7 @@ def cnf_result_comp2(result, x, carry): sage: cnf_result_comp2('r_3', 'x_3', 'c_3') ('c_3 -r_3 -x_3', '-c_3 r_3 -x_3', '-c_3 -r_3 x_3', 'c_3 r_3 x_3') """ - return (f'{carry} -{result} -{x}', - f'-{carry} {result} -{x}', - f'-{carry} -{result} {x}', - f'{carry} {result} {x}') + return (f"{carry} -{result} -{x}", f"-{carry} {result} -{x}", f"-{carry} -{result} {x}", f"{carry} {result} {x}") def cnf_vshift_id(out_id, in_id, in_shifted, shift_id): @@ -385,10 +383,12 @@ def cnf_vshift_id(out_id, in_id, in_shifted, shift_id): sage: cnf_vshift_id('s_3', 'i_3', 'i_4', 'k_7') ('-s_3 i_3 k_7', 's_3 -i_3 k_7', '-s_3 i_4 -k_7', 's_3 -i_4 -k_7') """ - return (f'-{out_id} {in_id} {shift_id}', - f'{out_id} -{in_id} {shift_id}', - f'-{out_id} {in_shifted} -{shift_id}', - f'{out_id} -{in_shifted} -{shift_id}') + return ( + f"-{out_id} {in_id} {shift_id}", + f"{out_id} -{in_id} {shift_id}", + f"-{out_id} {in_shifted} -{shift_id}", + f"{out_id} -{in_shifted} -{shift_id}", + ) def cnf_vshift_false(out_id, in_id, shift_id): @@ -410,9 +410,7 @@ def cnf_vshift_false(out_id, in_id, shift_id): sage: cnf_vshift_false('s_1', 'i_1', 'k_7') ('-s_1 i_1', '-s_1 -k_7', 's_1 -i_1 k_7') """ - return (f'-{out_id} {in_id}', - f'-{out_id} -{shift_id}', - f'{out_id} -{in_id} {shift_id}') + return (f"-{out_id} {in_id}", f"-{out_id} -{shift_id}", f"{out_id} -{in_id} {shift_id}") def cnf_hw_lipmaa(hw, alpha, beta, gamma): @@ -439,11 +437,13 @@ def cnf_hw_lipmaa(hw, alpha, beta, gamma): 'alpha_7 beta_7 gamma_7 -hw_6', '-alpha_7 -beta_7 -gamma_7 -hw_6') """ - return (f'{alpha} -{gamma} {hw}', - f'{beta} -{alpha} {hw}', - f'{gamma} -{beta} {hw}', - f'{alpha} {beta} {gamma} -{hw}', - f'-{alpha} -{beta} -{gamma} -{hw}') + return ( + f"{alpha} -{gamma} {hw}", + f"{beta} -{alpha} {hw}", + f"{gamma} -{beta} {hw}", + f"{alpha} {beta} {gamma} -{hw}", + f"-{alpha} -{beta} -{gamma} -{hw}", + ) def cnf_lipmaa(hw, dummy, beta_1, alpha, beta, gamma): @@ -477,16 +477,18 @@ def cnf_lipmaa(hw, dummy, beta_1, alpha, beta, gamma): '-alpha_10 -beta_10 dummy_10 -gamma_10', '-alpha_10 -beta_10 -dummy_10 gamma_10') """ - return (f'{beta_1} -{dummy} {hw}', - f'-{beta_1} {dummy} {hw}', - f'{alpha} {beta} {dummy} -{gamma}', - f'{alpha} {beta} -{dummy} {gamma}', - f'{alpha} -{beta} {dummy} {gamma}', - f'-{alpha} {beta} {dummy} {gamma}', - f'{alpha} -{beta} -{dummy} -{gamma}', - f'-{alpha} {beta} -{dummy} -{gamma}', - f'-{alpha} -{beta} {dummy} -{gamma}', - f'-{alpha} -{beta} -{dummy} {gamma}') + return ( + f"{beta_1} -{dummy} {hw}", + f"-{beta_1} {dummy} {hw}", + f"{alpha} {beta} {dummy} -{gamma}", + f"{alpha} {beta} -{dummy} {gamma}", + f"{alpha} -{beta} {dummy} {gamma}", + f"-{alpha} {beta} {dummy} {gamma}", + f"{alpha} -{beta} -{dummy} -{gamma}", + f"-{alpha} {beta} -{dummy} -{gamma}", + f"-{alpha} -{beta} {dummy} -{gamma}", + f"-{alpha} -{beta} -{dummy} {gamma}", + ) def cnf_modadd_inequality(z, u, v): @@ -509,8 +511,7 @@ def cnf_modadd_inequality(z, u, v): sage: cnf_modadd_inequality('z', 'u', 'v') ('z u -v', 'z -u v') """ - return (f'{z} {u} -{v}', - f'{z} -{u} {v}') + return (f"{z} {u} -{v}", f"{z} -{u} {v}") def cnf_and_differential(diff_in_0, diff_in_1, diff_out, hw): @@ -530,10 +531,7 @@ def cnf_and_differential(diff_in_0, diff_in_1, diff_out, hw): sage: cnf_and_differential('and_0', 'and_1', 'and_out', 'hw') ('-and_out hw', 'and_0 and_1 -hw', '-and_0 hw', '-and_1 hw') """ - return (f'-{diff_out} {hw}', - f'{diff_in_0} {diff_in_1} -{hw}', - f'-{diff_in_0} {hw}', - f'-{diff_in_1} {hw}') + return (f"-{diff_out} {hw}", f"{diff_in_0} {diff_in_1} -{hw}", f"-{diff_in_0} {hw}", f"-{diff_in_1} {hw}") def cnf_and_linear(mask_in_0, mask_in_1, mask_out, hw): @@ -555,10 +553,7 @@ def cnf_and_linear(mask_in_0, mask_in_1, mask_out, hw): sage: cnf_and_linear('and_0', 'and_1', 'and_out', 'hw') ('-and_0 hw', '-and_1 hw', '-and_out hw', 'and_out -hw') """ - return (f'-{mask_in_0} {hw}', - f'-{mask_in_1} {hw}', - f'-{mask_out} {hw}', - f'{mask_out} -{hw}') + return (f"-{mask_in_0} {hw}", f"-{mask_in_1} {hw}", f"-{mask_out} {hw}", f"{mask_out} -{hw}") def cnf_xor_truncated(result, variable_0, variable_1): @@ -607,13 +602,15 @@ def cnf_xor_truncated(result, variable_0, variable_1): 'b1 r0 r1 -a1', 'r0 -a1 -b1 -r1'] """ - return [f'{result[0]} -{variable_0[0]}', - f'{result[0]} -{variable_1[0]}', - f'{variable_0[0]} {variable_1[0]} -{result[0]}', - f'{variable_0[1]} {variable_1[1]} {result[0]} -{result[1]}', - f'{variable_0[1]} {result[0]} {result[1]} -{variable_1[1]}', - f'{variable_1[1]} {result[0]} {result[1]} -{variable_0[1]}', - f'{result[0]} -{variable_0[1]} -{variable_1[1]} -{result[1]}'] + return [ + f"{result[0]} -{variable_0[0]}", + f"{result[0]} -{variable_1[0]}", + f"{variable_0[0]} {variable_1[0]} -{result[0]}", + f"{variable_0[1]} {variable_1[1]} {result[0]} -{result[1]}", + f"{variable_0[1]} {result[0]} {result[1]} -{variable_1[1]}", + f"{variable_1[1]} {result[0]} {result[1]} -{variable_0[1]}", + f"{result[0]} -{variable_0[1]} -{variable_1[1]} -{result[1]}", + ] def cnf_xor_truncated_seq(results, variables): @@ -651,87 +648,130 @@ def cnf_xor_truncated_seq(results, variables): def get_cnf_bitwise_truncate_constraints(a, a_0, a_1): - return [ - f'-{a_0}', f'{a_1} -{a}', f'{a} -{a_1}' - ] + return [f"-{a_0}", f"{a_1} -{a}", f"{a} -{a_1}"] def get_cnf_truncated_linear_constraints(a, a_0): - return [ - f'-{a} -{a_0}' - ] + return [f"-{a} -{a_0}"] def modadd_truncated_lsb(result, variable_0, variable_1, next_carry): - return [f'{next_carry[0]} -{next_carry[1]}', - f'{next_carry[0]} -{variable_1[1]}', - f'{next_carry[0]} -{result[0]}', - f'{next_carry[0]} -{result[1]}', - f'{result[0]} -{variable_0[0]}', - f'{result[0]} -{variable_1[0]}', - f'{variable_0[0]} {variable_1[0]} -{result[0]}', - f'{variable_0[1]} {variable_1[1]} {result[0]} -{next_carry[0]}', - f'{variable_0[1]} {result[0]} {result[1]} -{variable_1[1]}', - f'{variable_1[1]} {result[0]} {result[1]} -{variable_0[1]}', - f'{result[0]} -{variable_0[1]} -{variable_1[1]} -{result[1]}'] + return [ + f"{next_carry[0]} -{next_carry[1]}", + f"{next_carry[0]} -{variable_1[1]}", + f"{next_carry[0]} -{result[0]}", + f"{next_carry[0]} -{result[1]}", + f"{result[0]} -{variable_0[0]}", + f"{result[0]} -{variable_1[0]}", + f"{variable_0[0]} {variable_1[0]} -{result[0]}", + f"{variable_0[1]} {variable_1[1]} {result[0]} -{next_carry[0]}", + f"{variable_0[1]} {result[0]} {result[1]} -{variable_1[1]}", + f"{variable_1[1]} {result[0]} {result[1]} -{variable_0[1]}", + f"{result[0]} -{variable_0[1]} -{variable_1[1]} -{result[1]}", + ] def modadd_truncated(result, variable_0, variable_1, carry, next_carry): - return [f'{next_carry[0]} -{next_carry[1]}', - f'{next_carry[0]} -{variable_1[1]}', - f'{next_carry[0]} -{result[0]}', - f'{next_carry[0]} -{result[1]}', - f'{result[0]} -{carry[0]}', - f'{result[0]} -{carry[1]}', - f'{result[0]} -{variable_0[0]}', - f'{result[0]} -{variable_1[0]}', - f'{variable_0[1]} {variable_1[1]} {result[0]} -{next_carry[0]}', - f'{variable_0[1]} {result[0]} {result[1]} -{variable_1[1]}', - f'{variable_1[1]} {result[0]} {result[1]} -{variable_0[1]}', - f'{carry[0]} {carry[1]} {variable_0[0]} {variable_1[0]} -{result[0]}', - f'{result[0]} -{variable_0[1]} -{variable_1[1]} -{result[1]}'] + return [ + f"{next_carry[0]} -{next_carry[1]}", + f"{next_carry[0]} -{variable_1[1]}", + f"{next_carry[0]} -{result[0]}", + f"{next_carry[0]} -{result[1]}", + f"{result[0]} -{carry[0]}", + f"{result[0]} -{carry[1]}", + f"{result[0]} -{variable_0[0]}", + f"{result[0]} -{variable_1[0]}", + f"{variable_0[1]} {variable_1[1]} {result[0]} -{next_carry[0]}", + f"{variable_0[1]} {result[0]} {result[1]} -{variable_1[1]}", + f"{variable_1[1]} {result[0]} {result[1]} -{variable_0[1]}", + f"{carry[0]} {carry[1]} {variable_0[0]} {variable_1[0]} -{result[0]}", + f"{result[0]} -{variable_0[1]} -{variable_1[1]} -{result[1]}", + ] def modadd_truncated_msb(result, variable_0, variable_1, carry): - return [f'{result[0]} -{carry[0]}', - f'{result[0]} -{carry[1]}', - f'{result[0]} -{variable_0[0]}', - f'{result[0]} -{variable_1[0]}', - f'{variable_0[1]} {variable_1[1]} {result[0]} -{result[1]}', - f'{variable_0[1]} {result[0]} {result[1]} -{variable_1[1]}', - f'{variable_1[1]} {result[0]} {result[1]} -{variable_0[1]}', - f'{carry[0]} {carry[1]} {variable_0[0]} {variable_1[0]} -{result[0]}', - f'{result[0]} -{variable_0[1]} -{variable_1[1]} -{result[1]}'] + return [ + f"{result[0]} -{carry[0]}", + f"{result[0]} -{carry[1]}", + f"{result[0]} -{variable_0[0]}", + f"{result[0]} -{variable_1[0]}", + f"{variable_0[1]} {variable_1[1]} {result[0]} -{result[1]}", + f"{variable_0[1]} {result[0]} {result[1]} -{variable_1[1]}", + f"{variable_1[1]} {result[0]} {result[1]} -{variable_0[1]}", + f"{carry[0]} {carry[1]} {variable_0[0]} {variable_1[0]} -{result[0]}", + f"{result[0]} -{variable_0[1]} -{variable_1[1]} -{result[1]}", + ] + + +def incompatibility(incompatibility_var, forward_var, backward_var): + """ + Return a list of strings representing CNF clauses encoding incompatibility + constraints between forward and backward variables with respect to an + incompatibility variable. + + INPUT: + + - ``incompatibility_var`` -- **string**; variable representing the incompatibility condition + - ``forward_var`` -- **tuple of strings**; pair of variables related to the forward direction + - ``backward_var`` -- **tuple of strings**; pair of variables related to the backward direction + + OUTPUT: + + - **list of strings**; each string is a CNF clause encoding part of the incompatibility constraint + + EXAMPLES:: + + sage: from claasp.cipher_modules.models.sat.utils.utils import incompatibility + sage: incompatibility_var = 'i' + sage: forward_var = ('f0', 'f1') + sage: backward_var = ('b0', 'b1') + sage: incompatibility(incompatibility_var, forward_var, backward_var) + ['-f0 -i', + '-b0 -i', + 'f1 b1 -i', + '-f1 -b1 -i', + 'f0 f1 b0 i -b1', + 'f0 b0 b1 i -f1'] + """ + return [f'-{forward_var[0]} -{incompatibility_var}', + f'-{backward_var[0]} -{incompatibility_var}', + f'{forward_var[1]} {backward_var[1]} -{incompatibility_var}', + f'-{forward_var[1]} -{backward_var[1]} -{incompatibility_var}', + f'{forward_var[0]} {forward_var[1]} {backward_var[0]} {incompatibility_var} -{backward_var[1]}', + f'{forward_var[0]} {backward_var[0]} {backward_var[1]} {incompatibility_var} -{forward_var[1]}'] # ---------------------------- # # - Running SAT solver - # # ---------------------------- # + def _get_data(data_keywords, lines): data_line = [line for line in lines if data_keywords in line][0] - data = float(re.findall(r'[0-9]+\.?[0-9]*', data_line)[0]) + data = float(re.findall(r"[0-9]+\.?[0-9]*", data_line)[0]) return data def run_sat_solver(solver_specs, options, dimacs_input, host=None, env_vars_string=""): """Call the SAT solver specified in `solver_specs`, using input and output pipes.""" - solver_name = solver_specs['solver_name'] - command = [solver_specs['keywords']['command']['executable']] + solver_specs['keywords']['command']['options'] + options + solver_name = solver_specs["solver_name"] + command = ( + [solver_specs["keywords"]["command"]["executable"]] + solver_specs["keywords"]["command"]["options"] + options + ) if host: - command = ['ssh', f'{host}'] + [env_vars_string] + command + command = ["ssh", f"{host}"] + [env_vars_string] + command solver_process = subprocess.run(command, input=dimacs_input, capture_output=True, text=True) solver_output = solver_process.stdout.splitlines() - status = [line for line in solver_output if line.startswith('s')][0].split()[1] + status = [line for line in solver_output if line.startswith("s")][0].split()[1] values = [] - if status == 'SATISFIABLE': + if status == "SATISFIABLE": for line in solver_output: - if line.startswith('v'): + if line.startswith("v"): values.extend(line.split()[1:]) values = values[:-1] - if solver_name == 'kissat': - data_keywords = solver_specs['keywords']['time'] + if solver_name == solvers.KISSAT_EXT: + data_keywords = solver_specs["keywords"]["time"] lines = solver_output data_line = [line for line in lines if data_keywords in line][0] seconds_str_index = data_line.find("seconds") - 2 @@ -739,107 +779,114 @@ def run_sat_solver(solver_specs, options, dimacs_input, host=None, env_vars_stri while data_line[seconds_str_index] != " ": output_str += data_line[seconds_str_index] seconds_str_index -= 1 - time = float(output_str[::-1]) + solver_time = float(output_str[::-1]) else: - time = _get_data(solver_specs['keywords']['time'], solver_output) - memory = float('inf') - memory_keywords = solver_specs['keywords']['memory'] + solver_time = _get_data(solver_specs["keywords"]["time"], solver_output) + solver_memory = float("inf") + memory_keywords = solver_specs["keywords"]["memory"] if memory_keywords: - if not (solver_name == 'glucose-syrup' and status != 'SATISFIABLE'): - memory = _get_data(memory_keywords, solver_output) - if solver_name == 'kissat': - memory = memory / 10**6 - if solver_name == 'cryptominisat': - memory = memory / 10**3 + if not (solver_name == solvers.GLUCOSE_SYRUP_EXT and status != "SATISFIABLE"): + solver_memory = _get_data(memory_keywords, solver_output) + if solver_name == solvers.KISSAT_EXT: + solver_memory = solver_memory / 10**6 + if solver_name == solvers.CRYPTOMINISAT_EXT: + solver_memory = solver_memory / 10**3 - return status, time, memory, values + return status, solver_time, solver_memory, values def run_minisat(solver_specs, options, dimacs_input, input_file_name, output_file_name): """Call the MiniSat solver specified in `solver_specs`, using input and output files.""" - with open(input_file_name, 'wt') as input_file: + with open(input_file_name, "wt") as input_file: input_file.write(dimacs_input) - command = [solver_specs['keywords']['command']['executable']] + solver_specs['keywords']['command']['options'] + options + command = ( + [solver_specs["keywords"]["command"]["executable"]] + solver_specs["keywords"]["command"]["options"] + options + ) command.append(input_file_name) command.append(output_file_name) solver_process = subprocess.run(command, capture_output=True, text=True) solver_output = solver_process.stdout.splitlines() - time = _get_data(solver_specs['keywords']['time'], solver_output) - memory = _get_data(solver_specs['keywords']['memory'], solver_output) + solver_time = _get_data(solver_specs["keywords"]["time"], solver_output) + solver_memory = _get_data(solver_specs["keywords"]["memory"], solver_output) status = solver_output[-1] values = [] - if status == 'SATISFIABLE': - with open(output_file_name, 'rt') as output_file: + if status == "SATISFIABLE": + with open(output_file_name, "rt") as output_file: values = output_file.read().splitlines()[1].split()[:-1] os.remove(input_file_name) os.remove(output_file_name) - return status, time, memory, values + return status, solver_time, solver_memory, values def run_parkissat(solver_specs, options, dimacs_input, input_file_name): """Call the Parkissat solver specified in `solver_specs`, using input and output files.""" - with open(input_file_name, 'wt') as input_file: + with open(input_file_name, "wt") as input_file: input_file.write(dimacs_input) - import time - command = [solver_specs['keywords']['command']['executable']] + solver_specs['keywords']['command']['options'] + options + command = ( + [solver_specs["keywords"]["command"]["executable"]] + solver_specs["keywords"]["command"]["options"] + options + ) command.append(input_file_name) start = time.time() solver_process = subprocess.run(command, capture_output=True, text=True) end = time.time() solver_output = solver_process.stdout.splitlines() - time = end - start - memory = 0 + solver_time = end - start + solver_memory = 0 status = solver_output[0].split()[1] values = "" - if status == 'SATISFIABLE': + if status == "SATISFIABLE": solver_output = solver_output[1:] - solver_output = list(map(lambda s: s.replace('v ', ''), solver_output)) + solver_output = [s.replace("v ", "") for s in solver_output] values = [] for element in solver_output: substrings = element.split() values.extend(substrings) os.remove(input_file_name) - return status, time, memory, values + return status, solver_time, solver_memory, values def run_yices(solver_specs, options, dimacs_input, input_file_name): """Call the Yices SAT solver specified in `solver_specs`, using input file.""" - with open(input_file_name, 'wt') as input_file: + with open(input_file_name, "wt") as input_file: input_file.write(dimacs_input) - command = [solver_specs['keywords']['command']['executable']] + solver_specs['keywords']['command']['options'] + options + command = ( + [solver_specs["keywords"]["command"]["executable"]] + solver_specs["keywords"]["command"]["options"] + options + ) command.append(input_file_name) solver_process = subprocess.run(command, capture_output=True, text=True) solver_stats = solver_process.stderr.splitlines() solver_output = solver_process.stdout.splitlines() - time = _get_data(solver_specs['keywords']['time'], solver_stats) - memory = _get_data(solver_specs['keywords']['memory'], solver_stats) - status = 'SATISFIABLE' if solver_output[0] == 'sat' else 'UNSATISFIABLE' + solver_time = _get_data(solver_specs["keywords"]["time"], solver_stats) + solver_memory = _get_data(solver_specs["keywords"]["memory"], solver_stats) + status = "SATISFIABLE" if solver_output[0] == "sat" else "UNSATISFIABLE" values = [] - if status == 'SATISFIABLE': + if status == "SATISFIABLE": values = solver_output[1].split()[:-1] os.remove(input_file_name) - return status, time, memory, values + return status, solver_time, solver_memory, values def _generate_component_model_types(speck_cipher): """Generates the component model types for a given Speck cipher.""" component_model_types = [] for component in speck_cipher.get_all_components(): - component_model_types.append({ - "component_id": component.id, - "component_object": component, - "model_type": "sat_xor_differential_propagation_constraints" - }) + component_model_types.append( + { + "component_id": component.id, + "component_object": component, + "model_type": "sat_xor_differential_propagation_constraints", + } + ) return component_model_types def _update_component_model_types_for_truncated_components( - component_model_types, - truncated_components, - truncated_model_type="sat_bitwise_deterministic_truncated_xor_differential_constraints" + component_model_types, + truncated_components, + truncated_model_type="sat_bitwise_deterministic_truncated_xor_differential_constraints", ): """Updates the component model types for truncated components.""" for component_model_type in component_model_types: @@ -855,381 +902,418 @@ def _update_component_model_types_for_linear_components(component_model_types, l def get_semi_deterministic_cnf_window_0( - A_t0, A_t1, A_v0, A_v1, - B_t0, B_t1, B_v0, B_v1, - C_t0, C_t1, C_v0, C_v1, - p0, q0, r0 + A_t0, A_t1, A_v0, A_v1, B_t0, B_t1, B_v0, B_v1, C_t0, C_t1, C_v0, C_v1, p0, q0, r0 ): return [ - f'{C_v1} {B_t1} {C_t0} {B_v0} {A_v0} {A_t0} {A_t1} {B_v1} -{C_v0} {A_v1} {B_t0}', - f'{C_v1} {B_t1} {C_t0} {B_v0} -{A_v0} {A_t0} {A_t1} {B_v1} {C_v0} {A_v1} {B_t0}', - f'{C_v1} {B_t1} {C_t0} -{B_v0} {A_v0} {A_t0} {A_t1} {B_v1} {C_v0} {A_v1} {B_t0}', - f'{C_v1} {B_t1} {C_t0} -{B_v0} -{A_v0} {A_t0} {A_t1} {B_v1} -{C_v0} {A_v1} {B_t0}', - f'{C_v1} {B_t1} -{C_t0} {A_t1} {B_v1} {A_v1} {C_t1}', - f'{C_v1} {B_t1} -{p0} {A_t1} {B_v1} {A_v1}', - f'{C_v1} {B_t1} {B_v0} {A_v0} {A_t1} {B_v1} -{C_v0} {A_v1} {C_t1}', - f'{C_v1} {B_t1} {B_v0} -{A_v0} {A_t1} {B_v1} {C_v0} {A_v1} {C_t1}', - f'{C_v1} {B_t1} -{B_v0} {A_v0} {A_t1} {B_v1} {C_v0} {A_v1} {C_t1}', - f'{C_v1} {B_t1} -{B_v0} -{A_v0} {A_t1} {B_v1} -{C_v0} {A_v1} {C_t1}', - f'{C_v1} {B_t1} -{r0} {A_t1} {B_v1} {A_v1}', - f'{C_v1} {B_t1} -{A_t0} {A_t1} {B_v1} {A_v1} {C_t1}', - f'{C_v1} {B_t1} {A_t1} {B_v1} {A_v1} -{B_t0} {C_t1}', - f'{C_v1} {C_t0} {p0} {B_v0} {A_v0} {A_t0} -{C_v0} {B_t0}', - f'{C_v1} {C_t0} {p0} {B_v0} -{A_v0} {A_t0} {C_v0} {B_t0}', - f'{C_v1} {C_t0} {p0} -{B_v0} {A_v0} {A_t0} {C_v0} {B_t0}', - f'{C_v1} {C_t0} {p0} -{B_v0} -{A_v0} {A_t0} -{C_v0} {B_t0}', - f'{C_v1} {C_t0} {p0} {A_t0} -{B_v1} {B_t0}', - f'{C_v1} {C_t0} {p0} {A_t0} -{A_v1} {B_t0}', - f'{C_v1} {C_t0} {B_v0} {r0} {A_v0} {A_t0} -{C_v0} {B_t0}', - f'{C_v1} {C_t0} {B_v0} {r0} -{A_v0} {A_t0} {C_v0} {B_t0}', - f'{C_v1} {C_t0} -{B_v0} {r0} {A_v0} {A_t0} {C_v0} {B_t0}', - f'{C_v1} {C_t0} -{B_v0} {r0} -{A_v0} {A_t0} -{C_v0} {B_t0}', - f'{C_v1} {C_t0} {r0} {A_t0} -{B_v1} {B_t0}', - f'{C_v1} {C_t0} {r0} {A_t0} -{A_v1} {B_t0}', - f'-{C_v1} {B_t1} -{C_t0} {A_t1} -{B_v1} -{A_v1} {C_t1}', - f'-{C_v1} {B_t1} -{p0} {A_t1} -{B_v1} -{A_v1} {C_t1}', - f'-{C_v1} {B_t1} {B_v0} {A_v0} {A_t1} -{B_v1} {C_v0} -{A_v1} {C_t1}', - f'-{C_v1} {B_t1} {B_v0} -{A_v0} {A_t1} -{B_v1} -{C_v0} -{A_v1} {C_t1}', - f'-{C_v1} {B_t1} -{B_v0} {A_v0} {A_t1} -{B_v1} -{C_v0} -{A_v1} {C_t1}', - f'-{C_v1} {B_t1} -{B_v0} -{A_v0} {A_t1} -{B_v1} {C_v0} -{A_v1} {C_t1}', - f'-{C_v1} {B_t1} -{r0} {A_t1} -{B_v1} -{A_v1} {C_t1}', - f'-{C_v1} {B_t1} -{A_t0} {A_t1} -{B_v1} -{A_v1} {C_t1}', - f'-{C_v1} {B_t1} {A_t1} -{B_v1} -{A_v1} -{B_t0} {C_t1}', - f'-{C_v1} {C_t0} {p0} {B_v0} {A_v0} {A_t0} {C_v0} {B_t0}', - f'-{C_v1} {C_t0} {p0} {B_v0} -{A_v0} {A_t0} -{C_v0} {B_t0}', - f'-{C_v1} {C_t0} {p0} -{B_v0} {A_v0} {A_t0} -{C_v0} {B_t0}', - f'-{C_v1} {C_t0} {p0} -{B_v0} -{A_v0} {A_t0} {C_v0} {B_t0}', - f'-{C_v1} {C_t0} {p0} {A_t0} {B_v1} {B_t0}', - f'-{C_v1} {C_t0} {p0} {A_t0} {A_v1} {B_t0}', - f'-{C_v1} {C_t0} {B_v0} {r0} {A_v0} {A_t0} {C_v0} {B_t0}', - f'-{C_v1} {C_t0} {B_v0} {r0} -{A_v0} {A_t0} -{C_v0} {B_t0}', - f'-{C_v1} {C_t0} -{B_v0} {r0} {A_v0} {A_t0} -{C_v0} {B_t0}', - f'-{C_v1} {C_t0} -{B_v0} {r0} -{A_v0} {A_t0} {C_v0} {B_t0}', - f'-{C_v1} {C_t0} {r0} {A_t0} {B_v1} {B_t0}', - f'-{C_v1} {C_t0} {r0} {A_t0} {A_v1} {B_t0}', - f'{B_t1} {C_t0} {A_t0} {A_t1} {B_v1} {A_v1} {B_t0} -{C_t1}', - f'{B_t1} -{p0} {A_t1} {B_v1} {A_v1} -{C_t1}', - f'{B_t1} -{r0} {A_t1} {B_v1} {A_v1} -{C_t1}', - f'-{B_t1} {C_t0} {p0} {A_t0} {B_t0}', - f'-{B_t1} {C_t0} {r0} {A_t0} {B_t0}', - f'{C_t0} {p0} {B_v0} {A_v0} {A_t0} {B_v1} -{C_v0} {B_t0}', - f'{C_t0} {p0} {B_v0} {A_v0} {A_t0} -{B_v1} {C_v0} {B_t0}', - f'{C_t0} {p0} {B_v0} {A_v0} {A_t0} {C_v0} -{A_v1} {B_t0}', - f'{C_t0} {p0} {B_v0} {A_v0} {A_t0} -{C_v0} {A_v1} {B_t0}', - f'{C_t0} {p0} {B_v0} -{A_v0} {A_t0} {B_v1} {C_v0} {B_t0}', - f'{C_t0} {p0} {B_v0} -{A_v0} {A_t0} -{B_v1} -{C_v0} {B_t0}', - f'{C_t0} {p0} {B_v0} -{A_v0} {A_t0} {C_v0} {A_v1} {B_t0}', - f'{C_t0} {p0} {B_v0} -{A_v0} {A_t0} -{C_v0} -{A_v1} {B_t0}', - f'{C_t0} {p0} -{B_v0} {A_v0} {A_t0} {B_v1} {C_v0} {B_t0}', - f'{C_t0} {p0} -{B_v0} {A_v0} {A_t0} -{B_v1} -{C_v0} {B_t0}', - f'{C_t0} {p0} -{B_v0} {A_v0} {A_t0} {C_v0} {A_v1} {B_t0}', - f'{C_t0} {p0} -{B_v0} {A_v0} {A_t0} -{C_v0} -{A_v1} {B_t0}', - f'{C_t0} {p0} -{B_v0} -{A_v0} {A_t0} {B_v1} -{C_v0} {B_t0}', - f'{C_t0} {p0} -{B_v0} -{A_v0} {A_t0} -{B_v1} {C_v0} {B_t0}', - f'{C_t0} {p0} -{B_v0} -{A_v0} {A_t0} {C_v0} -{A_v1} {B_t0}', - f'{C_t0} {p0} -{B_v0} -{A_v0} {A_t0} -{C_v0} {A_v1} {B_t0}', - f'{C_t0} {p0} {A_t0} -{A_t1} {B_t0}', - f'{C_t0} {p0} {A_t0} {B_v1} -{A_v1} {B_t0}', - f'{C_t0} {p0} {A_t0} -{B_v1} {A_v1} {B_t0}', - f'{C_t0} {p0} {A_t0} {B_t0} -{C_t1}', - f'{C_t0} {B_v0} {r0} {A_v0} {A_t0} {B_v1} -{C_v0} {B_t0}', - f'{C_t0} {B_v0} {r0} {A_v0} {A_t0} -{B_v1} {C_v0} {B_t0}', - f'{C_t0} {B_v0} {r0} {A_v0} {A_t0} {C_v0} -{A_v1} {B_t0}', - f'{C_t0} {B_v0} {r0} {A_v0} {A_t0} -{C_v0} {A_v1} {B_t0}', - f'{C_t0} {B_v0} {r0} -{A_v0} {A_t0} {B_v1} {C_v0} {B_t0}', - f'{C_t0} {B_v0} {r0} -{A_v0} {A_t0} -{B_v1} -{C_v0} {B_t0}', - f'{C_t0} {B_v0} {r0} -{A_v0} {A_t0} {C_v0} {A_v1} {B_t0}', - f'{C_t0} {B_v0} {r0} -{A_v0} {A_t0} -{C_v0} -{A_v1} {B_t0}', - f'{C_t0} -{B_v0} {r0} {A_v0} {A_t0} {B_v1} {C_v0} {B_t0}', - f'{C_t0} -{B_v0} {r0} {A_v0} {A_t0} -{B_v1} -{C_v0} {B_t0}', - f'{C_t0} -{B_v0} {r0} {A_v0} {A_t0} {C_v0} {A_v1} {B_t0}', - f'{C_t0} -{B_v0} {r0} {A_v0} {A_t0} -{C_v0} -{A_v1} {B_t0}', - f'{C_t0} -{B_v0} {r0} -{A_v0} {A_t0} {B_v1} -{C_v0} {B_t0}', - f'{C_t0} -{B_v0} {r0} -{A_v0} {A_t0} -{B_v1} {C_v0} {B_t0}', - f'{C_t0} -{B_v0} {r0} -{A_v0} {A_t0} {C_v0} -{A_v1} {B_t0}', - f'{C_t0} -{B_v0} {r0} -{A_v0} {A_t0} -{C_v0} {A_v1} {B_t0}', - f'{C_t0} {r0} {A_t0} -{A_t1} {B_t0}', - f'{C_t0} {r0} {A_t0} {B_v1} -{A_v1} {B_t0}', - f'{C_t0} {r0} {A_t0} -{B_v1} {A_v1} {B_t0}', - f'{C_t0} {r0} {A_t0} {B_t0} -{C_t1}', - f'-{C_t0} -{p0}', - f'-{C_t0} -{r0}', - f'-{q0}', - f'{p0} -{r0}', - f'-{p0} {r0}', - f'-{p0} -{A_t0}', - f'-{p0} -{B_t0}', - f'-{r0} -{A_t0}', - f'-{r0} -{B_t0}' + f"{C_v1} {B_t1} {C_t0} {B_v0} {A_v0} {A_t0} {A_t1} {B_v1} -{C_v0} {A_v1} {B_t0}", + f"{C_v1} {B_t1} {C_t0} {B_v0} -{A_v0} {A_t0} {A_t1} {B_v1} {C_v0} {A_v1} {B_t0}", + f"{C_v1} {B_t1} {C_t0} -{B_v0} {A_v0} {A_t0} {A_t1} {B_v1} {C_v0} {A_v1} {B_t0}", + f"{C_v1} {B_t1} {C_t0} -{B_v0} -{A_v0} {A_t0} {A_t1} {B_v1} -{C_v0} {A_v1} {B_t0}", + f"{C_v1} {B_t1} -{C_t0} {A_t1} {B_v1} {A_v1} {C_t1}", + f"{C_v1} {B_t1} -{p0} {A_t1} {B_v1} {A_v1}", + f"{C_v1} {B_t1} {B_v0} {A_v0} {A_t1} {B_v1} -{C_v0} {A_v1} {C_t1}", + f"{C_v1} {B_t1} {B_v0} -{A_v0} {A_t1} {B_v1} {C_v0} {A_v1} {C_t1}", + f"{C_v1} {B_t1} -{B_v0} {A_v0} {A_t1} {B_v1} {C_v0} {A_v1} {C_t1}", + f"{C_v1} {B_t1} -{B_v0} -{A_v0} {A_t1} {B_v1} -{C_v0} {A_v1} {C_t1}", + f"{C_v1} {B_t1} -{r0} {A_t1} {B_v1} {A_v1}", + f"{C_v1} {B_t1} -{A_t0} {A_t1} {B_v1} {A_v1} {C_t1}", + f"{C_v1} {B_t1} {A_t1} {B_v1} {A_v1} -{B_t0} {C_t1}", + f"{C_v1} {C_t0} {p0} {B_v0} {A_v0} {A_t0} -{C_v0} {B_t0}", + f"{C_v1} {C_t0} {p0} {B_v0} -{A_v0} {A_t0} {C_v0} {B_t0}", + f"{C_v1} {C_t0} {p0} -{B_v0} {A_v0} {A_t0} {C_v0} {B_t0}", + f"{C_v1} {C_t0} {p0} -{B_v0} -{A_v0} {A_t0} -{C_v0} {B_t0}", + f"{C_v1} {C_t0} {p0} {A_t0} -{B_v1} {B_t0}", + f"{C_v1} {C_t0} {p0} {A_t0} -{A_v1} {B_t0}", + f"{C_v1} {C_t0} {B_v0} {r0} {A_v0} {A_t0} -{C_v0} {B_t0}", + f"{C_v1} {C_t0} {B_v0} {r0} -{A_v0} {A_t0} {C_v0} {B_t0}", + f"{C_v1} {C_t0} -{B_v0} {r0} {A_v0} {A_t0} {C_v0} {B_t0}", + f"{C_v1} {C_t0} -{B_v0} {r0} -{A_v0} {A_t0} -{C_v0} {B_t0}", + f"{C_v1} {C_t0} {r0} {A_t0} -{B_v1} {B_t0}", + f"{C_v1} {C_t0} {r0} {A_t0} -{A_v1} {B_t0}", + f"-{C_v1} {B_t1} -{C_t0} {A_t1} -{B_v1} -{A_v1} {C_t1}", + f"-{C_v1} {B_t1} -{p0} {A_t1} -{B_v1} -{A_v1} {C_t1}", + f"-{C_v1} {B_t1} {B_v0} {A_v0} {A_t1} -{B_v1} {C_v0} -{A_v1} {C_t1}", + f"-{C_v1} {B_t1} {B_v0} -{A_v0} {A_t1} -{B_v1} -{C_v0} -{A_v1} {C_t1}", + f"-{C_v1} {B_t1} -{B_v0} {A_v0} {A_t1} -{B_v1} -{C_v0} -{A_v1} {C_t1}", + f"-{C_v1} {B_t1} -{B_v0} -{A_v0} {A_t1} -{B_v1} {C_v0} -{A_v1} {C_t1}", + f"-{C_v1} {B_t1} -{r0} {A_t1} -{B_v1} -{A_v1} {C_t1}", + f"-{C_v1} {B_t1} -{A_t0} {A_t1} -{B_v1} -{A_v1} {C_t1}", + f"-{C_v1} {B_t1} {A_t1} -{B_v1} -{A_v1} -{B_t0} {C_t1}", + f"-{C_v1} {C_t0} {p0} {B_v0} {A_v0} {A_t0} {C_v0} {B_t0}", + f"-{C_v1} {C_t0} {p0} {B_v0} -{A_v0} {A_t0} -{C_v0} {B_t0}", + f"-{C_v1} {C_t0} {p0} -{B_v0} {A_v0} {A_t0} -{C_v0} {B_t0}", + f"-{C_v1} {C_t0} {p0} -{B_v0} -{A_v0} {A_t0} {C_v0} {B_t0}", + f"-{C_v1} {C_t0} {p0} {A_t0} {B_v1} {B_t0}", + f"-{C_v1} {C_t0} {p0} {A_t0} {A_v1} {B_t0}", + f"-{C_v1} {C_t0} {B_v0} {r0} {A_v0} {A_t0} {C_v0} {B_t0}", + f"-{C_v1} {C_t0} {B_v0} {r0} -{A_v0} {A_t0} -{C_v0} {B_t0}", + f"-{C_v1} {C_t0} -{B_v0} {r0} {A_v0} {A_t0} -{C_v0} {B_t0}", + f"-{C_v1} {C_t0} -{B_v0} {r0} -{A_v0} {A_t0} {C_v0} {B_t0}", + f"-{C_v1} {C_t0} {r0} {A_t0} {B_v1} {B_t0}", + f"-{C_v1} {C_t0} {r0} {A_t0} {A_v1} {B_t0}", + f"{B_t1} {C_t0} {A_t0} {A_t1} {B_v1} {A_v1} {B_t0} -{C_t1}", + f"{B_t1} -{p0} {A_t1} {B_v1} {A_v1} -{C_t1}", + f"{B_t1} -{r0} {A_t1} {B_v1} {A_v1} -{C_t1}", + f"-{B_t1} {C_t0} {p0} {A_t0} {B_t0}", + f"-{B_t1} {C_t0} {r0} {A_t0} {B_t0}", + f"{C_t0} {p0} {B_v0} {A_v0} {A_t0} {B_v1} -{C_v0} {B_t0}", + f"{C_t0} {p0} {B_v0} {A_v0} {A_t0} -{B_v1} {C_v0} {B_t0}", + f"{C_t0} {p0} {B_v0} {A_v0} {A_t0} {C_v0} -{A_v1} {B_t0}", + f"{C_t0} {p0} {B_v0} {A_v0} {A_t0} -{C_v0} {A_v1} {B_t0}", + f"{C_t0} {p0} {B_v0} -{A_v0} {A_t0} {B_v1} {C_v0} {B_t0}", + f"{C_t0} {p0} {B_v0} -{A_v0} {A_t0} -{B_v1} -{C_v0} {B_t0}", + f"{C_t0} {p0} {B_v0} -{A_v0} {A_t0} {C_v0} {A_v1} {B_t0}", + f"{C_t0} {p0} {B_v0} -{A_v0} {A_t0} -{C_v0} -{A_v1} {B_t0}", + f"{C_t0} {p0} -{B_v0} {A_v0} {A_t0} {B_v1} {C_v0} {B_t0}", + f"{C_t0} {p0} -{B_v0} {A_v0} {A_t0} -{B_v1} -{C_v0} {B_t0}", + f"{C_t0} {p0} -{B_v0} {A_v0} {A_t0} {C_v0} {A_v1} {B_t0}", + f"{C_t0} {p0} -{B_v0} {A_v0} {A_t0} -{C_v0} -{A_v1} {B_t0}", + f"{C_t0} {p0} -{B_v0} -{A_v0} {A_t0} {B_v1} -{C_v0} {B_t0}", + f"{C_t0} {p0} -{B_v0} -{A_v0} {A_t0} -{B_v1} {C_v0} {B_t0}", + f"{C_t0} {p0} -{B_v0} -{A_v0} {A_t0} {C_v0} -{A_v1} {B_t0}", + f"{C_t0} {p0} -{B_v0} -{A_v0} {A_t0} -{C_v0} {A_v1} {B_t0}", + f"{C_t0} {p0} {A_t0} -{A_t1} {B_t0}", + f"{C_t0} {p0} {A_t0} {B_v1} -{A_v1} {B_t0}", + f"{C_t0} {p0} {A_t0} -{B_v1} {A_v1} {B_t0}", + f"{C_t0} {p0} {A_t0} {B_t0} -{C_t1}", + f"{C_t0} {B_v0} {r0} {A_v0} {A_t0} {B_v1} -{C_v0} {B_t0}", + f"{C_t0} {B_v0} {r0} {A_v0} {A_t0} -{B_v1} {C_v0} {B_t0}", + f"{C_t0} {B_v0} {r0} {A_v0} {A_t0} {C_v0} -{A_v1} {B_t0}", + f"{C_t0} {B_v0} {r0} {A_v0} {A_t0} -{C_v0} {A_v1} {B_t0}", + f"{C_t0} {B_v0} {r0} -{A_v0} {A_t0} {B_v1} {C_v0} {B_t0}", + f"{C_t0} {B_v0} {r0} -{A_v0} {A_t0} -{B_v1} -{C_v0} {B_t0}", + f"{C_t0} {B_v0} {r0} -{A_v0} {A_t0} {C_v0} {A_v1} {B_t0}", + f"{C_t0} {B_v0} {r0} -{A_v0} {A_t0} -{C_v0} -{A_v1} {B_t0}", + f"{C_t0} -{B_v0} {r0} {A_v0} {A_t0} {B_v1} {C_v0} {B_t0}", + f"{C_t0} -{B_v0} {r0} {A_v0} {A_t0} -{B_v1} -{C_v0} {B_t0}", + f"{C_t0} -{B_v0} {r0} {A_v0} {A_t0} {C_v0} {A_v1} {B_t0}", + f"{C_t0} -{B_v0} {r0} {A_v0} {A_t0} -{C_v0} -{A_v1} {B_t0}", + f"{C_t0} -{B_v0} {r0} -{A_v0} {A_t0} {B_v1} -{C_v0} {B_t0}", + f"{C_t0} -{B_v0} {r0} -{A_v0} {A_t0} -{B_v1} {C_v0} {B_t0}", + f"{C_t0} -{B_v0} {r0} -{A_v0} {A_t0} {C_v0} -{A_v1} {B_t0}", + f"{C_t0} -{B_v0} {r0} -{A_v0} {A_t0} -{C_v0} {A_v1} {B_t0}", + f"{C_t0} {r0} {A_t0} -{A_t1} {B_t0}", + f"{C_t0} {r0} {A_t0} {B_v1} -{A_v1} {B_t0}", + f"{C_t0} {r0} {A_t0} -{B_v1} {A_v1} {B_t0}", + f"{C_t0} {r0} {A_t0} {B_t0} -{C_t1}", + f"-{C_t0} -{p0}", + f"-{C_t0} -{r0}", + f"-{q0}", + f"{p0} -{r0}", + f"-{p0} {r0}", + f"-{p0} -{A_t0}", + f"-{p0} -{B_t0}", + f"-{r0} -{A_t0}", + f"-{r0} -{B_t0}", ] def get_cnf_semi_deterministic_window_1( - A_t0, A_t1, A_t2, A_v0, A_v1, A_v2, - B_t0, B_t1, B_t2, B_v0, B_v1, B_v2, - C_t0, C_t1, C_t2, C_v0, C_v1, - p0, q0, r0 + A_t0, A_t1, A_t2, A_v0, A_v1, A_v2, B_t0, B_t1, B_t2, B_v0, B_v1, B_v2, C_t0, C_t1, C_t2, C_v0, C_v1, p0, q0, r0 ): return [ - f'{C_v1} {A_t1} {A_v1} {A_t0} {B_t1} {C_t0} {B_t0} {A_v0} {B_v1} {B_v0} -{C_v0}', - f'{C_v1} {A_t1} {A_v1} {A_t0} {B_t1} {C_t0} {B_t0} {A_v0} {B_v1} -{B_v0} {C_v0}', - f'{C_v1} {A_t1} {A_v1} {A_t0} {B_t1} {C_t0} {B_t0} -{A_v0} {B_v1} {B_v0} {C_v0}', - f'{C_v1} {A_t1} {A_v1} {A_t0} {B_t1} {C_t0} {B_t0} -{A_v0} {B_v1} -{B_v0} -{C_v0}', - f'{C_v1} {A_t1} {A_v1} -{A_t0} {C_t1} {B_t1} {B_v1}', - f'{C_v1} {A_t1} {A_v1} {C_t1} {B_t1} -{C_t0} {B_v1}', - f'{C_v1} {A_t1} {A_v1} {C_t1} {B_t1} -{B_t0} {B_v1}', - f'{C_v1} {A_t1} {A_v1} {C_t1} {B_t1} {A_v0} {B_v1} {B_v0} -{C_v0}', - f'{C_v1} {A_t1} {A_v1} {C_t1} {B_t1} {A_v0} {B_v1} -{B_v0} {C_v0}', - f'{C_v1} {A_t1} {A_v1} {C_t1} {B_t1} -{A_v0} {B_v1} {B_v0} {C_v0}', - f'{C_v1} {A_t1} {A_v1} {C_t1} {B_t1} -{A_v0} {B_v1} -{B_v0} -{C_v0}', - f'{C_v1} {A_t1} {A_v1} {C_t1} {B_t1} -{p0} {B_v1}', - f'{C_v1} {A_t1} {A_v1} -{r0} {B_t1} {B_v1}', - f'{C_v1} {A_t1} {A_v1} {A_v2} {B_t1} -{C_t2} {A_t2} {B_t2} {B_v2} -{p0} {B_v1}', - f'{C_v1} {A_t1} {A_v1} {B_t1} {A_v0} -{p0} {B_v1} {B_v0} -{C_v0}', - f'{C_v1} {A_t1} {A_v1} {B_t1} {A_v0} -{p0} {B_v1} -{B_v0} {C_v0}', - f'{C_v1} {A_t1} {A_v1} {B_t1} -{A_v0} -{p0} {B_v1} {B_v0} {C_v0}', - f'{C_v1} {A_t1} {A_v1} {B_t1} -{A_v0} -{p0} {B_v1} -{B_v0} -{C_v0}', - f'{C_v1} -{A_v1} {A_t0} {r0} {C_t0} {B_t0}', - f'{C_v1} -{A_v1} {A_t0} {C_t0} {B_t0} {p0}', - f'{C_v1} {A_t0} {r0} {C_t0} {B_t0} {A_v0} {B_v0} -{C_v0}', - f'{C_v1} {A_t0} {r0} {C_t0} {B_t0} {A_v0} -{B_v0} {C_v0}', - f'{C_v1} {A_t0} {r0} {C_t0} {B_t0} -{A_v0} {B_v0} {C_v0}', - f'{C_v1} {A_t0} {r0} {C_t0} {B_t0} -{A_v0} -{B_v0} -{C_v0}', - f'{C_v1} {A_t0} {r0} {C_t0} {B_t0} -{B_v1}', - f'{C_v1} {A_t0} {C_t0} {B_t0} {A_v0} {p0} {B_v0} -{C_v0}', - f'{C_v1} {A_t0} {C_t0} {B_t0} {A_v0} {p0} -{B_v0} {C_v0}', - f'{C_v1} {A_t0} {C_t0} {B_t0} -{A_v0} {p0} {B_v0} {C_v0}', - f'{C_v1} {A_t0} {C_t0} {B_t0} -{A_v0} {p0} -{B_v0} -{C_v0}', - f'{C_v1} {A_t0} {C_t0} {B_t0} {p0} -{B_v1}', - f'-{C_v1} {A_t1} -{A_v1} -{A_t0} {C_t1} {B_t1} -{B_v1}', - f'-{C_v1} {A_t1} -{A_v1} {C_t1} -{r0} {B_t1} -{B_v1}', - f'-{C_v1} {A_t1} -{A_v1} {C_t1} {B_t1} -{C_t0} -{B_v1}', - f'-{C_v1} {A_t1} -{A_v1} {C_t1} {B_t1} -{B_t0} -{B_v1}', - f'-{C_v1} {A_t1} -{A_v1} {C_t1} {B_t1} {A_v0} -{B_v1} {B_v0} {C_v0}', - f'-{C_v1} {A_t1} -{A_v1} {C_t1} {B_t1} {A_v0} -{B_v1} -{B_v0} -{C_v0}', - f'-{C_v1} {A_t1} -{A_v1} {C_t1} {B_t1} -{A_v0} -{B_v1} {B_v0} -{C_v0}', - f'-{C_v1} {A_t1} -{A_v1} {C_t1} {B_t1} -{A_v0} -{B_v1} -{B_v0} {C_v0}', - f'-{C_v1} {A_t1} -{A_v1} {C_t1} {B_t1} -{p0} -{B_v1}', - f'-{C_v1} {A_v1} {A_t0} {C_t1} {r0} {C_t0} {B_t0}', - f'-{C_v1} {A_v1} {A_t0} {r0} {A_v2} {C_t0} {B_t0} -{C_t2} {A_t2} {B_t2} {B_v2}', - f'-{C_v1} {A_v1} {A_t0} {C_t0} {B_t0} {p0}', - f'-{C_v1} {A_t0} {C_t1} {r0} {C_t0} {B_t0} {A_v0} {B_v0} {C_v0}', - f'-{C_v1} {A_t0} {C_t1} {r0} {C_t0} {B_t0} {A_v0} -{B_v0} -{C_v0}', - f'-{C_v1} {A_t0} {C_t1} {r0} {C_t0} {B_t0} -{A_v0} {B_v0} -{C_v0}', - f'-{C_v1} {A_t0} {C_t1} {r0} {C_t0} {B_t0} -{A_v0} -{B_v0} {C_v0}', - f'-{C_v1} {A_t0} {C_t1} {r0} {C_t0} {B_t0} {B_v1}', - f'-{C_v1} {A_t0} {r0} {A_v2} {C_t0} {B_t0} {A_v0} -{C_t2} {A_t2} {B_t2} {B_v2} {B_v0} {C_v0}', - f'-{C_v1} {A_t0} {r0} {A_v2} {C_t0} {B_t0} {A_v0} -{C_t2} {A_t2} {B_t2} {B_v2} -{B_v0} -{C_v0}', - f'-{C_v1} {A_t0} {r0} {A_v2} {C_t0} {B_t0} -{A_v0} -{C_t2} {A_t2} {B_t2} {B_v2} {B_v0} -{C_v0}', - f'-{C_v1} {A_t0} {r0} {A_v2} {C_t0} {B_t0} -{A_v0} -{C_t2} {A_t2} {B_t2} {B_v2} -{B_v0} {C_v0}', - f'-{C_v1} {A_t0} {r0} {A_v2} {C_t0} {B_t0} -{C_t2} {A_t2} {B_t2} {B_v2} {B_v1}', - f'-{C_v1} {A_t0} {C_t0} {B_t0} {A_v0} {p0} {B_v0} {C_v0}', - f'-{C_v1} {A_t0} {C_t0} {B_t0} {A_v0} {p0} -{B_v0} -{C_v0}', - f'-{C_v1} {A_t0} {C_t0} {B_t0} -{A_v0} {p0} {B_v0} -{C_v0}', - f'-{C_v1} {A_t0} {C_t0} {B_t0} -{A_v0} {p0} -{B_v0} {C_v0}', - f'-{C_v1} {A_t0} {C_t0} {B_t0} {p0} {B_v1}', - f'{A_t1} {A_v1} {A_t0} -{C_t1} {A_v2} {B_t1} {C_t0} {B_t0} -{C_t2} {A_t2} {B_t2} {B_v2} {B_v1}', - f'{A_t1} {A_v1} {A_t0} -{C_t1} {B_t1} {C_t0} {B_t0} {A_v0} {B_v1} {B_v0} -{C_v0}', - f'{A_t1} {A_v1} {A_t0} -{C_t1} {B_t1} {C_t0} {B_t0} {A_v0} {B_v1} -{B_v0} {C_v0}', - f'{A_t1} {A_v1} {A_t0} -{C_t1} {B_t1} {C_t0} {B_t0} -{A_v0} {B_v1} {B_v0} {C_v0}', - f'{A_t1} {A_v1} {A_t0} -{C_t1} {B_t1} {C_t0} {B_t0} -{A_v0} {B_v1} -{B_v0} -{C_v0}', - f'{A_t1} {A_v1} -{C_t1} -{r0} {B_t1} {B_v1}', - f'{A_t1} {A_v1} -{C_t1} {A_v2} {B_t1} -{C_t2} {A_t2} {B_t2} {B_v2} -{p0} {B_v1}', - f'{A_t1} {A_v1} -{C_t1} {B_t1} {A_v0} -{p0} {B_v1} {B_v0} -{C_v0}', - f'{A_t1} {A_v1} -{C_t1} {B_t1} {A_v0} -{p0} {B_v1} -{B_v0} {C_v0}', - f'{A_t1} {A_v1} -{C_t1} {B_t1} -{A_v0} -{p0} {B_v1} {B_v0} {C_v0}', - f'{A_t1} {A_v1} -{C_t1} {B_t1} -{A_v0} -{p0} {B_v1} -{B_v0} -{C_v0}', - f'-{A_t1} {A_t0} {r0} {C_t0} {B_t0}', - f'-{A_t1} {A_t0} {C_t0} {B_t0} {p0}', - f'-{A_t1} {r0} -{p0}', - f'{A_v1} {A_t0} {r0} {C_t0} {B_t0} {A_v0} {B_v0} -{C_v0}', - f'{A_v1} {A_t0} {r0} {C_t0} {B_t0} {A_v0} -{B_v0} {C_v0}', - f'{A_v1} {A_t0} {r0} {C_t0} {B_t0} -{A_v0} {B_v0} {C_v0}', - f'{A_v1} {A_t0} {r0} {C_t0} {B_t0} -{A_v0} -{B_v0} -{C_v0}', - f'{A_v1} {A_t0} {r0} {C_t0} {B_t0} -{B_v1}', - f'{A_v1} {A_t0} {C_t0} {B_t0} {A_v0} {p0} {B_v0} -{C_v0}', - f'{A_v1} {A_t0} {C_t0} {B_t0} {A_v0} {p0} -{B_v0} {C_v0}', - f'{A_v1} {A_t0} {C_t0} {B_t0} -{A_v0} {p0} {B_v0} {C_v0}', - f'{A_v1} {A_t0} {C_t0} {B_t0} -{A_v0} {p0} -{B_v0} -{C_v0}', - f'{A_v1} {A_t0} {C_t0} {B_t0} {p0} -{B_v1}', - f'-{A_v1} {A_t0} -{C_t1} {r0} {C_t0} {B_t0}', - f'-{A_v1} {A_t0} {r0} {C_t0} {B_t0} {A_v0} {B_v0} {C_v0}', - f'-{A_v1} {A_t0} {r0} {C_t0} {B_t0} {A_v0} -{B_v0} -{C_v0}', - f'-{A_v1} {A_t0} {r0} {C_t0} {B_t0} -{A_v0} {B_v0} -{C_v0}', - f'-{A_v1} {A_t0} {r0} {C_t0} {B_t0} -{A_v0} -{B_v0} {C_v0}', - f'-{A_v1} {A_t0} {r0} {C_t0} {B_t0} {B_v1}', - f'-{A_v1} {A_t0} {C_t0} {B_t0} {A_v0} {p0} {B_v0} {C_v0}', - f'-{A_v1} {A_t0} {C_t0} {B_t0} {A_v0} {p0} -{B_v0} -{C_v0}', - f'-{A_v1} {A_t0} {C_t0} {B_t0} -{A_v0} {p0} {B_v0} -{C_v0}', - f'-{A_v1} {A_t0} {C_t0} {B_t0} -{A_v0} {p0} -{B_v0} {C_v0}', - f'-{A_v1} {A_t0} {C_t0} {B_t0} {p0} {B_v1}', - f'-{A_v1} {r0} -{p0}', - f'{A_t0} -{C_t1} {r0} {A_v2} {C_t0} {B_t0} -{C_t2} {A_t2} {B_t2} {B_v2}', - f'{A_t0} -{C_t1} {r0} {C_t0} {B_t0} {A_v0} {B_v0} -{C_v0}', - f'{A_t0} -{C_t1} {r0} {C_t0} {B_t0} {A_v0} -{B_v0} {C_v0}', - f'{A_t0} -{C_t1} {r0} {C_t0} {B_t0} -{A_v0} {B_v0} {C_v0}', - f'{A_t0} -{C_t1} {r0} {C_t0} {B_t0} -{A_v0} -{B_v0} -{C_v0}', - f'{A_t0} -{C_t1} {r0} {C_t0} {B_t0} -{B_v1}', - f'{A_t0} -{C_t1} {C_t0} {B_t0} {p0}', - f'{A_t0} {r0} -{B_t1} {C_t0} {B_t0}', - f'{A_t0} {r0} {C_t0} {B_t0} {A_v0} {B_v1} {B_v0} -{C_v0}', - f'{A_t0} {r0} {C_t0} {B_t0} {A_v0} {B_v1} -{B_v0} {C_v0}', - f'{A_t0} {r0} {C_t0} {B_t0} {A_v0} -{B_v1} {B_v0} {C_v0}', - f'{A_t0} {r0} {C_t0} {B_t0} {A_v0} -{B_v1} -{B_v0} -{C_v0}', - f'{A_t0} {r0} {C_t0} {B_t0} -{A_v0} {B_v1} {B_v0} {C_v0}', - f'{A_t0} {r0} {C_t0} {B_t0} -{A_v0} {B_v1} -{B_v0} -{C_v0}', - f'{A_t0} {r0} {C_t0} {B_t0} -{A_v0} -{B_v1} {B_v0} -{C_v0}', - f'{A_t0} {r0} {C_t0} {B_t0} -{A_v0} -{B_v1} -{B_v0} {C_v0}', - f'{A_t0} -{B_t1} {C_t0} {B_t0} {p0}', - f'{A_t0} {C_t0} {B_t0} {A_v0} {p0} {B_v1} {B_v0} -{C_v0}', - f'{A_t0} {C_t0} {B_t0} {A_v0} {p0} {B_v1} -{B_v0} {C_v0}', - f'{A_t0} {C_t0} {B_t0} {A_v0} {p0} -{B_v1} {B_v0} {C_v0}', - f'{A_t0} {C_t0} {B_t0} {A_v0} {p0} -{B_v1} -{B_v0} -{C_v0}', - f'{A_t0} {C_t0} {B_t0} -{A_v0} {p0} {B_v1} {B_v0} {C_v0}', - f'{A_t0} {C_t0} {B_t0} -{A_v0} {p0} {B_v1} -{B_v0} -{C_v0}', - f'{A_t0} {C_t0} {B_t0} -{A_v0} {p0} -{B_v1} {B_v0} -{C_v0}', - f'{A_t0} {C_t0} {B_t0} -{A_v0} {p0} -{B_v1} -{B_v0} {C_v0}', - f'-{A_t0} -{r0}', - f'-{A_t0} -{p0}', - f'{C_t1} {r0} -{p0}', - f'{r0} {A_v2} -{C_t2} {A_t2} {B_t2} {B_v2} -{p0}', - f'{r0} -{B_t1} -{p0}', - f'{r0} {A_v0} -{p0} {B_v0} -{C_v0}', - f'{r0} {A_v0} -{p0} -{B_v0} {C_v0}', - f'{r0} -{A_v0} -{p0} {B_v0} {C_v0}', - f'{r0} -{A_v0} -{p0} -{B_v0} -{C_v0}', - f'{r0} -{p0} -{B_v1}', - f'-{r0} -{C_t0}', - f'-{r0} -{B_t0}', - f'-{r0} {p0}', - f'-{C_t0} -{p0}', - f'-{B_t0} -{p0}', - f'-{q0}' + f"{C_v1} {A_t1} {A_v1} {A_t0} {B_t1} {C_t0} {B_t0} {A_v0} {B_v1} {B_v0} -{C_v0}", + f"{C_v1} {A_t1} {A_v1} {A_t0} {B_t1} {C_t0} {B_t0} {A_v0} {B_v1} -{B_v0} {C_v0}", + f"{C_v1} {A_t1} {A_v1} {A_t0} {B_t1} {C_t0} {B_t0} -{A_v0} {B_v1} {B_v0} {C_v0}", + f"{C_v1} {A_t1} {A_v1} {A_t0} {B_t1} {C_t0} {B_t0} -{A_v0} {B_v1} -{B_v0} -{C_v0}", + f"{C_v1} {A_t1} {A_v1} -{A_t0} {C_t1} {B_t1} {B_v1}", + f"{C_v1} {A_t1} {A_v1} {C_t1} {B_t1} -{C_t0} {B_v1}", + f"{C_v1} {A_t1} {A_v1} {C_t1} {B_t1} -{B_t0} {B_v1}", + f"{C_v1} {A_t1} {A_v1} {C_t1} {B_t1} {A_v0} {B_v1} {B_v0} -{C_v0}", + f"{C_v1} {A_t1} {A_v1} {C_t1} {B_t1} {A_v0} {B_v1} -{B_v0} {C_v0}", + f"{C_v1} {A_t1} {A_v1} {C_t1} {B_t1} -{A_v0} {B_v1} {B_v0} {C_v0}", + f"{C_v1} {A_t1} {A_v1} {C_t1} {B_t1} -{A_v0} {B_v1} -{B_v0} -{C_v0}", + f"{C_v1} {A_t1} {A_v1} {C_t1} {B_t1} -{p0} {B_v1}", + f"{C_v1} {A_t1} {A_v1} -{r0} {B_t1} {B_v1}", + f"{C_v1} {A_t1} {A_v1} {A_v2} {B_t1} -{C_t2} {A_t2} {B_t2} {B_v2} -{p0} {B_v1}", + f"{C_v1} {A_t1} {A_v1} {B_t1} {A_v0} -{p0} {B_v1} {B_v0} -{C_v0}", + f"{C_v1} {A_t1} {A_v1} {B_t1} {A_v0} -{p0} {B_v1} -{B_v0} {C_v0}", + f"{C_v1} {A_t1} {A_v1} {B_t1} -{A_v0} -{p0} {B_v1} {B_v0} {C_v0}", + f"{C_v1} {A_t1} {A_v1} {B_t1} -{A_v0} -{p0} {B_v1} -{B_v0} -{C_v0}", + f"{C_v1} -{A_v1} {A_t0} {r0} {C_t0} {B_t0}", + f"{C_v1} -{A_v1} {A_t0} {C_t0} {B_t0} {p0}", + f"{C_v1} {A_t0} {r0} {C_t0} {B_t0} {A_v0} {B_v0} -{C_v0}", + f"{C_v1} {A_t0} {r0} {C_t0} {B_t0} {A_v0} -{B_v0} {C_v0}", + f"{C_v1} {A_t0} {r0} {C_t0} {B_t0} -{A_v0} {B_v0} {C_v0}", + f"{C_v1} {A_t0} {r0} {C_t0} {B_t0} -{A_v0} -{B_v0} -{C_v0}", + f"{C_v1} {A_t0} {r0} {C_t0} {B_t0} -{B_v1}", + f"{C_v1} {A_t0} {C_t0} {B_t0} {A_v0} {p0} {B_v0} -{C_v0}", + f"{C_v1} {A_t0} {C_t0} {B_t0} {A_v0} {p0} -{B_v0} {C_v0}", + f"{C_v1} {A_t0} {C_t0} {B_t0} -{A_v0} {p0} {B_v0} {C_v0}", + f"{C_v1} {A_t0} {C_t0} {B_t0} -{A_v0} {p0} -{B_v0} -{C_v0}", + f"{C_v1} {A_t0} {C_t0} {B_t0} {p0} -{B_v1}", + f"-{C_v1} {A_t1} -{A_v1} -{A_t0} {C_t1} {B_t1} -{B_v1}", + f"-{C_v1} {A_t1} -{A_v1} {C_t1} -{r0} {B_t1} -{B_v1}", + f"-{C_v1} {A_t1} -{A_v1} {C_t1} {B_t1} -{C_t0} -{B_v1}", + f"-{C_v1} {A_t1} -{A_v1} {C_t1} {B_t1} -{B_t0} -{B_v1}", + f"-{C_v1} {A_t1} -{A_v1} {C_t1} {B_t1} {A_v0} -{B_v1} {B_v0} {C_v0}", + f"-{C_v1} {A_t1} -{A_v1} {C_t1} {B_t1} {A_v0} -{B_v1} -{B_v0} -{C_v0}", + f"-{C_v1} {A_t1} -{A_v1} {C_t1} {B_t1} -{A_v0} -{B_v1} {B_v0} -{C_v0}", + f"-{C_v1} {A_t1} -{A_v1} {C_t1} {B_t1} -{A_v0} -{B_v1} -{B_v0} {C_v0}", + f"-{C_v1} {A_t1} -{A_v1} {C_t1} {B_t1} -{p0} -{B_v1}", + f"-{C_v1} {A_v1} {A_t0} {C_t1} {r0} {C_t0} {B_t0}", + f"-{C_v1} {A_v1} {A_t0} {r0} {A_v2} {C_t0} {B_t0} -{C_t2} {A_t2} {B_t2} {B_v2}", + f"-{C_v1} {A_v1} {A_t0} {C_t0} {B_t0} {p0}", + f"-{C_v1} {A_t0} {C_t1} {r0} {C_t0} {B_t0} {A_v0} {B_v0} {C_v0}", + f"-{C_v1} {A_t0} {C_t1} {r0} {C_t0} {B_t0} {A_v0} -{B_v0} -{C_v0}", + f"-{C_v1} {A_t0} {C_t1} {r0} {C_t0} {B_t0} -{A_v0} {B_v0} -{C_v0}", + f"-{C_v1} {A_t0} {C_t1} {r0} {C_t0} {B_t0} -{A_v0} -{B_v0} {C_v0}", + f"-{C_v1} {A_t0} {C_t1} {r0} {C_t0} {B_t0} {B_v1}", + f"-{C_v1} {A_t0} {r0} {A_v2} {C_t0} {B_t0} {A_v0} -{C_t2} {A_t2} {B_t2} {B_v2} {B_v0} {C_v0}", + f"-{C_v1} {A_t0} {r0} {A_v2} {C_t0} {B_t0} {A_v0} -{C_t2} {A_t2} {B_t2} {B_v2} -{B_v0} -{C_v0}", + f"-{C_v1} {A_t0} {r0} {A_v2} {C_t0} {B_t0} -{A_v0} -{C_t2} {A_t2} {B_t2} {B_v2} {B_v0} -{C_v0}", + f"-{C_v1} {A_t0} {r0} {A_v2} {C_t0} {B_t0} -{A_v0} -{C_t2} {A_t2} {B_t2} {B_v2} -{B_v0} {C_v0}", + f"-{C_v1} {A_t0} {r0} {A_v2} {C_t0} {B_t0} -{C_t2} {A_t2} {B_t2} {B_v2} {B_v1}", + f"-{C_v1} {A_t0} {C_t0} {B_t0} {A_v0} {p0} {B_v0} {C_v0}", + f"-{C_v1} {A_t0} {C_t0} {B_t0} {A_v0} {p0} -{B_v0} -{C_v0}", + f"-{C_v1} {A_t0} {C_t0} {B_t0} -{A_v0} {p0} {B_v0} -{C_v0}", + f"-{C_v1} {A_t0} {C_t0} {B_t0} -{A_v0} {p0} -{B_v0} {C_v0}", + f"-{C_v1} {A_t0} {C_t0} {B_t0} {p0} {B_v1}", + f"{A_t1} {A_v1} {A_t0} -{C_t1} {A_v2} {B_t1} {C_t0} {B_t0} -{C_t2} {A_t2} {B_t2} {B_v2} {B_v1}", + f"{A_t1} {A_v1} {A_t0} -{C_t1} {B_t1} {C_t0} {B_t0} {A_v0} {B_v1} {B_v0} -{C_v0}", + f"{A_t1} {A_v1} {A_t0} -{C_t1} {B_t1} {C_t0} {B_t0} {A_v0} {B_v1} -{B_v0} {C_v0}", + f"{A_t1} {A_v1} {A_t0} -{C_t1} {B_t1} {C_t0} {B_t0} -{A_v0} {B_v1} {B_v0} {C_v0}", + f"{A_t1} {A_v1} {A_t0} -{C_t1} {B_t1} {C_t0} {B_t0} -{A_v0} {B_v1} -{B_v0} -{C_v0}", + f"{A_t1} {A_v1} -{C_t1} -{r0} {B_t1} {B_v1}", + f"{A_t1} {A_v1} -{C_t1} {A_v2} {B_t1} -{C_t2} {A_t2} {B_t2} {B_v2} -{p0} {B_v1}", + f"{A_t1} {A_v1} -{C_t1} {B_t1} {A_v0} -{p0} {B_v1} {B_v0} -{C_v0}", + f"{A_t1} {A_v1} -{C_t1} {B_t1} {A_v0} -{p0} {B_v1} -{B_v0} {C_v0}", + f"{A_t1} {A_v1} -{C_t1} {B_t1} -{A_v0} -{p0} {B_v1} {B_v0} {C_v0}", + f"{A_t1} {A_v1} -{C_t1} {B_t1} -{A_v0} -{p0} {B_v1} -{B_v0} -{C_v0}", + f"-{A_t1} {A_t0} {r0} {C_t0} {B_t0}", + f"-{A_t1} {A_t0} {C_t0} {B_t0} {p0}", + f"-{A_t1} {r0} -{p0}", + f"{A_v1} {A_t0} {r0} {C_t0} {B_t0} {A_v0} {B_v0} -{C_v0}", + f"{A_v1} {A_t0} {r0} {C_t0} {B_t0} {A_v0} -{B_v0} {C_v0}", + f"{A_v1} {A_t0} {r0} {C_t0} {B_t0} -{A_v0} {B_v0} {C_v0}", + f"{A_v1} {A_t0} {r0} {C_t0} {B_t0} -{A_v0} -{B_v0} -{C_v0}", + f"{A_v1} {A_t0} {r0} {C_t0} {B_t0} -{B_v1}", + f"{A_v1} {A_t0} {C_t0} {B_t0} {A_v0} {p0} {B_v0} -{C_v0}", + f"{A_v1} {A_t0} {C_t0} {B_t0} {A_v0} {p0} -{B_v0} {C_v0}", + f"{A_v1} {A_t0} {C_t0} {B_t0} -{A_v0} {p0} {B_v0} {C_v0}", + f"{A_v1} {A_t0} {C_t0} {B_t0} -{A_v0} {p0} -{B_v0} -{C_v0}", + f"{A_v1} {A_t0} {C_t0} {B_t0} {p0} -{B_v1}", + f"-{A_v1} {A_t0} -{C_t1} {r0} {C_t0} {B_t0}", + f"-{A_v1} {A_t0} {r0} {C_t0} {B_t0} {A_v0} {B_v0} {C_v0}", + f"-{A_v1} {A_t0} {r0} {C_t0} {B_t0} {A_v0} -{B_v0} -{C_v0}", + f"-{A_v1} {A_t0} {r0} {C_t0} {B_t0} -{A_v0} {B_v0} -{C_v0}", + f"-{A_v1} {A_t0} {r0} {C_t0} {B_t0} -{A_v0} -{B_v0} {C_v0}", + f"-{A_v1} {A_t0} {r0} {C_t0} {B_t0} {B_v1}", + f"-{A_v1} {A_t0} {C_t0} {B_t0} {A_v0} {p0} {B_v0} {C_v0}", + f"-{A_v1} {A_t0} {C_t0} {B_t0} {A_v0} {p0} -{B_v0} -{C_v0}", + f"-{A_v1} {A_t0} {C_t0} {B_t0} -{A_v0} {p0} {B_v0} -{C_v0}", + f"-{A_v1} {A_t0} {C_t0} {B_t0} -{A_v0} {p0} -{B_v0} {C_v0}", + f"-{A_v1} {A_t0} {C_t0} {B_t0} {p0} {B_v1}", + f"-{A_v1} {r0} -{p0}", + f"{A_t0} -{C_t1} {r0} {A_v2} {C_t0} {B_t0} -{C_t2} {A_t2} {B_t2} {B_v2}", + f"{A_t0} -{C_t1} {r0} {C_t0} {B_t0} {A_v0} {B_v0} -{C_v0}", + f"{A_t0} -{C_t1} {r0} {C_t0} {B_t0} {A_v0} -{B_v0} {C_v0}", + f"{A_t0} -{C_t1} {r0} {C_t0} {B_t0} -{A_v0} {B_v0} {C_v0}", + f"{A_t0} -{C_t1} {r0} {C_t0} {B_t0} -{A_v0} -{B_v0} -{C_v0}", + f"{A_t0} -{C_t1} {r0} {C_t0} {B_t0} -{B_v1}", + f"{A_t0} -{C_t1} {C_t0} {B_t0} {p0}", + f"{A_t0} {r0} -{B_t1} {C_t0} {B_t0}", + f"{A_t0} {r0} {C_t0} {B_t0} {A_v0} {B_v1} {B_v0} -{C_v0}", + f"{A_t0} {r0} {C_t0} {B_t0} {A_v0} {B_v1} -{B_v0} {C_v0}", + f"{A_t0} {r0} {C_t0} {B_t0} {A_v0} -{B_v1} {B_v0} {C_v0}", + f"{A_t0} {r0} {C_t0} {B_t0} {A_v0} -{B_v1} -{B_v0} -{C_v0}", + f"{A_t0} {r0} {C_t0} {B_t0} -{A_v0} {B_v1} {B_v0} {C_v0}", + f"{A_t0} {r0} {C_t0} {B_t0} -{A_v0} {B_v1} -{B_v0} -{C_v0}", + f"{A_t0} {r0} {C_t0} {B_t0} -{A_v0} -{B_v1} {B_v0} -{C_v0}", + f"{A_t0} {r0} {C_t0} {B_t0} -{A_v0} -{B_v1} -{B_v0} {C_v0}", + f"{A_t0} -{B_t1} {C_t0} {B_t0} {p0}", + f"{A_t0} {C_t0} {B_t0} {A_v0} {p0} {B_v1} {B_v0} -{C_v0}", + f"{A_t0} {C_t0} {B_t0} {A_v0} {p0} {B_v1} -{B_v0} {C_v0}", + f"{A_t0} {C_t0} {B_t0} {A_v0} {p0} -{B_v1} {B_v0} {C_v0}", + f"{A_t0} {C_t0} {B_t0} {A_v0} {p0} -{B_v1} -{B_v0} -{C_v0}", + f"{A_t0} {C_t0} {B_t0} -{A_v0} {p0} {B_v1} {B_v0} {C_v0}", + f"{A_t0} {C_t0} {B_t0} -{A_v0} {p0} {B_v1} -{B_v0} -{C_v0}", + f"{A_t0} {C_t0} {B_t0} -{A_v0} {p0} -{B_v1} {B_v0} -{C_v0}", + f"{A_t0} {C_t0} {B_t0} -{A_v0} {p0} -{B_v1} -{B_v0} {C_v0}", + f"-{A_t0} -{r0}", + f"-{A_t0} -{p0}", + f"{C_t1} {r0} -{p0}", + f"{r0} {A_v2} -{C_t2} {A_t2} {B_t2} {B_v2} -{p0}", + f"{r0} -{B_t1} -{p0}", + f"{r0} {A_v0} -{p0} {B_v0} -{C_v0}", + f"{r0} {A_v0} -{p0} -{B_v0} {C_v0}", + f"{r0} -{A_v0} -{p0} {B_v0} {C_v0}", + f"{r0} -{A_v0} -{p0} -{B_v0} -{C_v0}", + f"{r0} -{p0} -{B_v1}", + f"-{r0} -{C_t0}", + f"-{r0} -{B_t0}", + f"-{r0} {p0}", + f"-{C_t0} -{p0}", + f"-{B_t0} -{p0}", + f"-{q0}", ] def get_cnf_semi_deterministic_window_2( - A_t0, A_t1, A_t2, A_t3, - A_v0, A_v1, A_v2, A_v3, - B_t0, B_t1, B_t2, B_t3, - B_v0, B_v1, B_v2, B_v3, - C_t0, C_t1, C_t2, C_t3, - C_v0, C_v1, - p0, q0, r0 + A_t0, + A_t1, + A_t2, + A_t3, + A_v0, + A_v1, + A_v2, + A_v3, + B_t0, + B_t1, + B_t2, + B_t3, + B_v0, + B_v1, + B_v2, + B_v3, + C_t0, + C_t1, + C_t2, + C_t3, + C_v0, + C_v1, + p0, + q0, + r0, ): return [ - f'{A_t3} {A_v3} {B_t3} {B_v3} -{C_t3} -{q0}', - f'-{A_t1} -{q0}', - f'{A_t2} {A_v2} {B_t2} {B_v2} -{C_t2} -{p0} {r0}', - f'{A_t1} {A_v0} -{A_v1} {B_t1} {B_v0} -{B_v1} {C_t1} {C_v0} -{C_v1}', - f'{A_t1} -{A_v0} -{A_v1} {B_t1} -{B_v0} -{B_v1} {C_t1} {C_v0} -{C_v1}', - f'{A_t1} -{A_v0} -{A_v1} {B_t1} {B_v0} -{B_v1} {C_t1} -{C_v0} -{C_v1}', - f'{A_t1} {A_v0} -{A_v1} {B_t1} -{B_v0} -{B_v1} {C_t1} -{C_v0} -{C_v1}', - f'-{A_t2} -{q0}', - f'-{A_v1} -{q0}', - f'-{A_v2} -{q0}', - f'-{B_t1} -{q0}', - f'{A_t1} -{A_v1} {B_t1} -{B_v1} {C_t1} -{C_v1} -{r0}', - f'{A_t0} {A_t1} -{A_v0} {A_v1} {B_t0} {B_t1} {B_v0} {B_v1} {C_t0} -{C_t1} {C_v0}', - f'{A_t0} {A_t1} {A_v0} {A_v1} {B_t0} {B_t1} -{B_v0} {B_v1} {C_t0} -{C_t1} {C_v0}', - f'{A_t0} {A_t1} {A_v0} {A_v1} {B_t0} {B_t1} {B_v0} {B_v1} {C_t0} -{C_t1} -{C_v0}', - f'{A_t0} {A_t1} -{A_v0} {A_v1} {B_t0} {B_t1} -{B_v0} {B_v1} {C_t0} -{C_t1} -{C_v0}', - f'{A_t0} {A_t1} -{A_v0} {A_v1} {B_t0} {B_t1} {B_v0} {B_v1} {C_t0} {C_v0} {C_v1}', - f'{A_t0} {A_t1} {A_v0} {A_v1} {B_t0} {B_t1} -{B_v0} {B_v1} {C_t0} {C_v0} {C_v1}', - f'{A_t0} {A_t1} {A_v0} {A_v1} {B_t0} {B_t1} {B_v0} {B_v1} {C_t0} -{C_v0} {C_v1}', - f'{A_t0} {A_t1} -{A_v0} {A_v1} {B_t0} {B_t1} -{B_v0} {B_v1} {C_t0} -{C_v0} {C_v1}', - f'-{B_t2} -{q0}', - f'{A_t1} {A_v1} {B_t1} {B_v1} -{C_t1} {q0} -{r0}', - f'-{A_v1} -{p0} {r0}', - f'{A_t1} {A_v1} {B_t1} {B_v1} {C_v1} {q0} -{r0}', - f'{A_t0} {B_t0} {B_v1} {C_t0} -{C_v1} {p0} {r0}', - f'-{B_v1} -{q0}', - f'-{B_v2} -{q0}', - f'{A_t0} -{A_v1} {B_t0} {C_t0} {C_v1} {r0}', - f'{C_t1} -{p0} {r0}', - f'{A_t0} {A_v1} {B_t0} -{B_v1} {C_t0} {r0}', - f'{C_t2} -{q0}', - f'{A_t0} -{A_t1} {B_t0} {C_t0} {r0}', - f'-{A_t0} -{r0}', - f'-{B_t0} -{r0}', - f'{A_t0} {B_t0} -{B_t1} {C_t0} {r0}', - f'-{C_t0} -{r0}', - f'-{p0} -{q0}', - f'{A_t0} {B_t0} {C_t0} -{C_t1} {p0} {q0}', - f'{C_t1} {p0} -{r0}', - f'-{q0} {r0}', - f'{A_t1} -{A_v1} {B_t1} -{B_v1} -{C_t0} {C_t1} -{C_v1}', - f'{A_t1} -{A_v1} -{B_t0} {B_t1} -{B_v1} {C_t1} -{C_v1}', - f'-{A_t0} {A_t1} -{A_v1} {B_t1} -{B_v1} {C_t1} -{C_v1}', - f'{A_t1} {A_v1} {B_t1} {B_v1} -{C_t0} {C_t1} {C_v1}', - f'{A_t1} {A_v1} -{B_t0} {B_t1} {B_v1} {C_t1} {C_v1}', - f'-{A_t0} {A_t1} {A_v1} {B_t1} {B_v1} {C_t1} {C_v1}', - f'-{C_t0} -{p0}', - f'-{B_t0} -{p0}', - f'-{A_t0} -{p0}', + f"{A_t3} {A_v3} {B_t3} {B_v3} -{C_t3} -{q0}", + f"-{A_t1} -{q0}", + f"{A_t2} {A_v2} {B_t2} {B_v2} -{C_t2} -{p0} {r0}", + f"{A_t1} {A_v0} -{A_v1} {B_t1} {B_v0} -{B_v1} {C_t1} {C_v0} -{C_v1}", + f"{A_t1} -{A_v0} -{A_v1} {B_t1} -{B_v0} -{B_v1} {C_t1} {C_v0} -{C_v1}", + f"{A_t1} -{A_v0} -{A_v1} {B_t1} {B_v0} -{B_v1} {C_t1} -{C_v0} -{C_v1}", + f"{A_t1} {A_v0} -{A_v1} {B_t1} -{B_v0} -{B_v1} {C_t1} -{C_v0} -{C_v1}", + f"-{A_t2} -{q0}", + f"-{A_v1} -{q0}", + f"-{A_v2} -{q0}", + f"-{B_t1} -{q0}", + f"{A_t1} -{A_v1} {B_t1} -{B_v1} {C_t1} -{C_v1} -{r0}", + f"{A_t0} {A_t1} -{A_v0} {A_v1} {B_t0} {B_t1} {B_v0} {B_v1} {C_t0} -{C_t1} {C_v0}", + f"{A_t0} {A_t1} {A_v0} {A_v1} {B_t0} {B_t1} -{B_v0} {B_v1} {C_t0} -{C_t1} {C_v0}", + f"{A_t0} {A_t1} {A_v0} {A_v1} {B_t0} {B_t1} {B_v0} {B_v1} {C_t0} -{C_t1} -{C_v0}", + f"{A_t0} {A_t1} -{A_v0} {A_v1} {B_t0} {B_t1} -{B_v0} {B_v1} {C_t0} -{C_t1} -{C_v0}", + f"{A_t0} {A_t1} -{A_v0} {A_v1} {B_t0} {B_t1} {B_v0} {B_v1} {C_t0} {C_v0} {C_v1}", + f"{A_t0} {A_t1} {A_v0} {A_v1} {B_t0} {B_t1} -{B_v0} {B_v1} {C_t0} {C_v0} {C_v1}", + f"{A_t0} {A_t1} {A_v0} {A_v1} {B_t0} {B_t1} {B_v0} {B_v1} {C_t0} -{C_v0} {C_v1}", + f"{A_t0} {A_t1} -{A_v0} {A_v1} {B_t0} {B_t1} -{B_v0} {B_v1} {C_t0} -{C_v0} {C_v1}", + f"-{B_t2} -{q0}", + f"{A_t1} {A_v1} {B_t1} {B_v1} -{C_t1} {q0} -{r0}", + f"-{A_v1} -{p0} {r0}", + f"{A_t1} {A_v1} {B_t1} {B_v1} {C_v1} {q0} -{r0}", + f"{A_t0} {B_t0} {B_v1} {C_t0} -{C_v1} {p0} {r0}", + f"-{B_v1} -{q0}", + f"-{B_v2} -{q0}", + f"{A_t0} -{A_v1} {B_t0} {C_t0} {C_v1} {r0}", + f"{C_t1} -{p0} {r0}", + f"{A_t0} {A_v1} {B_t0} -{B_v1} {C_t0} {r0}", + f"{C_t2} -{q0}", + f"{A_t0} -{A_t1} {B_t0} {C_t0} {r0}", + f"-{A_t0} -{r0}", + f"-{B_t0} -{r0}", + f"{A_t0} {B_t0} -{B_t1} {C_t0} {r0}", + f"-{C_t0} -{r0}", + f"-{p0} -{q0}", + f"{A_t0} {B_t0} {C_t0} -{C_t1} {p0} {q0}", + f"{C_t1} {p0} -{r0}", + f"-{q0} {r0}", + f"{A_t1} -{A_v1} {B_t1} -{B_v1} -{C_t0} {C_t1} -{C_v1}", + f"{A_t1} -{A_v1} -{B_t0} {B_t1} -{B_v1} {C_t1} -{C_v1}", + f"-{A_t0} {A_t1} -{A_v1} {B_t1} -{B_v1} {C_t1} -{C_v1}", + f"{A_t1} {A_v1} {B_t1} {B_v1} -{C_t0} {C_t1} {C_v1}", + f"{A_t1} {A_v1} -{B_t0} {B_t1} {B_v1} {C_t1} {C_v1}", + f"-{A_t0} {A_t1} {A_v1} {B_t1} {B_v1} {C_t1} {C_v1}", + f"-{C_t0} -{p0}", + f"-{B_t0} -{p0}", + f"-{A_t0} -{p0}", ] def get_cnf_semi_deterministic_window_3( - A_t0, A_t1, A_t2, A_t3, A_t4, - A_v0, A_v1, A_v2, A_v3, A_v4, - B_t0, B_t1, B_t2, B_t3, B_t4, - B_v0, B_v1, B_v2, B_v3, B_v4, - C_t0, C_t1, C_t2, C_t3, C_t4, - C_v0, C_v1, p0, q0, r0): + A_t0, + A_t1, + A_t2, + A_t3, + A_t4, + A_v0, + A_v1, + A_v2, + A_v3, + A_v4, + B_t0, + B_t1, + B_t2, + B_t3, + B_t4, + B_v0, + B_v1, + B_v2, + B_v3, + B_v4, + C_t0, + C_t1, + C_t2, + C_t3, + C_t4, + C_v0, + C_v1, + p0, + q0, + r0, +): return [ - f'{A_t4} {A_v4} {B_t4} {B_v4} -{C_t4} -{q0} {r0}', - f'{A_t3} {A_v3} {B_t3} {B_v3} -{C_t3} -{q0} -{r0}', - f'-{A_t3} -{q0} {r0}', - f'-{A_v3} -{q0} {r0}', - f'-{B_t3} -{q0} {r0}', - f'-{B_v3} -{q0} {r0}', - f'{C_t3} -{q0} {r0}', - f'{A_t1} {A_v0} -{A_v1} {B_t1} {B_v0} -{B_v1} {C_t1} {C_v0} -{C_v1}', - f'{A_t1} -{A_v0} -{A_v1} {B_t1} -{B_v0} -{B_v1} {C_t1} {C_v0} -{C_v1}', - f'{A_t1} -{A_v0} -{A_v1} {B_t1} {B_v0} -{B_v1} {C_t1} -{C_v0} -{C_v1}', - f'{A_t1} {A_v0} -{A_v1} {B_t1} -{B_v0} -{B_v1} {C_t1} -{C_v0} -{C_v1}', - f'-{A_v1} -{q0}', - f'-{B_v1} -{q0}', - f'{A_t0} {A_t2} {A_v2} {B_t0} {B_t2} {B_v2} {C_t0} -{C_t1} -{C_t2} {q0} {r0}', - f'{A_t1} -{A_v1} {B_t1} -{B_v1} {C_t1} -{C_v1} -{r0}', - f'{A_t0} {A_t1} -{A_v0} {A_v1} {B_t0} {B_t1} {B_v0} {B_v1} {C_t0} -{C_t1} {C_v0}', - f'{A_t0} {A_t1} {A_v0} {A_v1} {B_t0} {B_t1} -{B_v0} {B_v1} {C_t0} -{C_t1} {C_v0}', - f'{A_t0} {A_t1} {A_v0} {A_v1} {B_t0} {B_t1} {B_v0} {B_v1} {C_t0} -{C_t1} -{C_v0}', - f'{A_t0} {A_t1} -{A_v0} {A_v1} {B_t0} {B_t1} -{B_v0} {B_v1} {C_t0} -{C_t1} -{C_v0}', - f'{A_t0} -{A_t2} {B_t0} {C_t0} -{C_t1} {p0}', - f'{A_t1} {A_v1} {B_t1} {B_v1} -{C_t1} {q0} -{r0}', - f'{A_t0} {A_t1} -{A_v0} {A_v1} {B_t0} {B_t1} {B_v0} {B_v1} {C_t0} {C_v0} {C_v1}', - f'{A_t0} {A_t1} {A_v0} {A_v1} {B_t0} {B_t1} -{B_v0} {B_v1} {C_t0} {C_v0} {C_v1}', - f'{A_t0} {A_t1} {A_v0} {A_v1} {B_t0} {B_t1} {B_v0} {B_v1} {C_t0} -{C_v0} {C_v1}', - f'{A_t0} {A_t1} -{A_v0} {A_v1} {B_t0} {B_t1} -{B_v0} {B_v1} {C_t0} -{C_v0} {C_v1}', - f'{A_t1} {A_v1} {B_t1} {B_v1} {C_v1} {q0} -{r0}', - f'{A_t0} -{A_v2} {B_t0} {C_t0} -{C_t1} {p0}', - f'{A_t0} {B_t0} {B_v1} {C_t0} -{C_v1} {p0} {q0}', - f'{A_t0} {B_t0} -{B_t2} {C_t0} -{C_t1} {p0}', - f'{C_t1} -{p0} {r0}', - f'{A_t0} {B_t0} -{B_v2} {C_t0} -{C_t1} {p0}', - f'{A_t0} -{A_v1} {B_t0} {C_t0} -{C_t1} {r0}', - f'{A_t0} -{A_v1} {B_t0} {C_t0} {C_v1} {r0}', - f'{A_t0} {A_v1} {B_t0} -{B_v1} {C_t0} {r0}', - f'-{p0} -{q0}', - f'{A_t0} {B_t0} {C_t0} -{C_t1} {C_t2} {p0}', - f'-{A_t1} {p0} -{r0}', - f'-{B_t1} {p0} -{r0}', - f'{p0} {q0} -{r0}', - f'{A_t0} -{A_t1} {B_t0} {C_t0} {r0}', - f'{A_t0} {B_t0} -{B_t1} {C_t0} {r0}', - f'{A_t1} -{A_v1} {B_t1} -{B_v1} -{C_t0} {C_t1} -{C_v1}', - f'{A_t1} -{A_v1} -{B_t0} {B_t1} -{B_v1} {C_t1} -{C_v1}', - f'-{A_t0} {A_t1} -{A_v1} {B_t1} -{B_v1} {C_t1} -{C_v1}', - f'{A_t1} {A_v1} {B_t1} {B_v1} -{C_t0} {C_t1} {C_v1}', - f'{A_t1} {A_v1} -{B_t0} {B_t1} {B_v1} {C_t1} {C_v1}', - f'-{A_t0} {A_t1} {A_v1} {B_t1} {B_v1} {C_t1} {C_v1}', - f'-{C_t0} -{p0}', - f'-{B_t0} -{p0}', - f'-{A_t0} -{p0}', - f'-{C_t0} -{q0}', - f'-{B_t0} -{q0}', - f'-{A_t0} -{q0}', - f'{C_t1} -{q0}', + f"{A_t4} {A_v4} {B_t4} {B_v4} -{C_t4} -{q0} {r0}", + f"{A_t3} {A_v3} {B_t3} {B_v3} -{C_t3} -{q0} -{r0}", + f"-{A_t3} -{q0} {r0}", + f"-{A_v3} -{q0} {r0}", + f"-{B_t3} -{q0} {r0}", + f"-{B_v3} -{q0} {r0}", + f"{C_t3} -{q0} {r0}", + f"{A_t1} {A_v0} -{A_v1} {B_t1} {B_v0} -{B_v1} {C_t1} {C_v0} -{C_v1}", + f"{A_t1} -{A_v0} -{A_v1} {B_t1} -{B_v0} -{B_v1} {C_t1} {C_v0} -{C_v1}", + f"{A_t1} -{A_v0} -{A_v1} {B_t1} {B_v0} -{B_v1} {C_t1} -{C_v0} -{C_v1}", + f"{A_t1} {A_v0} -{A_v1} {B_t1} -{B_v0} -{B_v1} {C_t1} -{C_v0} -{C_v1}", + f"-{A_v1} -{q0}", + f"-{B_v1} -{q0}", + f"{A_t0} {A_t2} {A_v2} {B_t0} {B_t2} {B_v2} {C_t0} -{C_t1} -{C_t2} {q0} {r0}", + f"{A_t1} -{A_v1} {B_t1} -{B_v1} {C_t1} -{C_v1} -{r0}", + f"{A_t0} {A_t1} -{A_v0} {A_v1} {B_t0} {B_t1} {B_v0} {B_v1} {C_t0} -{C_t1} {C_v0}", + f"{A_t0} {A_t1} {A_v0} {A_v1} {B_t0} {B_t1} -{B_v0} {B_v1} {C_t0} -{C_t1} {C_v0}", + f"{A_t0} {A_t1} {A_v0} {A_v1} {B_t0} {B_t1} {B_v0} {B_v1} {C_t0} -{C_t1} -{C_v0}", + f"{A_t0} {A_t1} -{A_v0} {A_v1} {B_t0} {B_t1} -{B_v0} {B_v1} {C_t0} -{C_t1} -{C_v0}", + f"{A_t0} -{A_t2} {B_t0} {C_t0} -{C_t1} {p0}", + f"{A_t1} {A_v1} {B_t1} {B_v1} -{C_t1} {q0} -{r0}", + f"{A_t0} {A_t1} -{A_v0} {A_v1} {B_t0} {B_t1} {B_v0} {B_v1} {C_t0} {C_v0} {C_v1}", + f"{A_t0} {A_t1} {A_v0} {A_v1} {B_t0} {B_t1} -{B_v0} {B_v1} {C_t0} {C_v0} {C_v1}", + f"{A_t0} {A_t1} {A_v0} {A_v1} {B_t0} {B_t1} {B_v0} {B_v1} {C_t0} -{C_v0} {C_v1}", + f"{A_t0} {A_t1} -{A_v0} {A_v1} {B_t0} {B_t1} -{B_v0} {B_v1} {C_t0} -{C_v0} {C_v1}", + f"{A_t1} {A_v1} {B_t1} {B_v1} {C_v1} {q0} -{r0}", + f"{A_t0} -{A_v2} {B_t0} {C_t0} -{C_t1} {p0}", + f"{A_t0} {B_t0} {B_v1} {C_t0} -{C_v1} {p0} {q0}", + f"{A_t0} {B_t0} -{B_t2} {C_t0} -{C_t1} {p0}", + f"{C_t1} -{p0} {r0}", + f"{A_t0} {B_t0} -{B_v2} {C_t0} -{C_t1} {p0}", + f"{A_t0} -{A_v1} {B_t0} {C_t0} -{C_t1} {r0}", + f"{A_t0} -{A_v1} {B_t0} {C_t0} {C_v1} {r0}", + f"{A_t0} {A_v1} {B_t0} -{B_v1} {C_t0} {r0}", + f"-{p0} -{q0}", + f"{A_t0} {B_t0} {C_t0} -{C_t1} {C_t2} {p0}", + f"-{A_t1} {p0} -{r0}", + f"-{B_t1} {p0} -{r0}", + f"{p0} {q0} -{r0}", + f"{A_t0} -{A_t1} {B_t0} {C_t0} {r0}", + f"{A_t0} {B_t0} -{B_t1} {C_t0} {r0}", + f"{A_t1} -{A_v1} {B_t1} -{B_v1} -{C_t0} {C_t1} -{C_v1}", + f"{A_t1} -{A_v1} -{B_t0} {B_t1} -{B_v1} {C_t1} -{C_v1}", + f"-{A_t0} {A_t1} -{A_v1} {B_t1} -{B_v1} {C_t1} -{C_v1}", + f"{A_t1} {A_v1} {B_t1} {B_v1} -{C_t0} {C_t1} {C_v1}", + f"{A_t1} {A_v1} -{B_t0} {B_t1} {B_v1} {C_t1} {C_v1}", + f"-{A_t0} {A_t1} {A_v1} {B_t1} {B_v1} {C_t1} {C_v1}", + f"-{C_t0} -{p0}", + f"-{B_t0} -{p0}", + f"-{A_t0} -{p0}", + f"-{C_t0} -{q0}", + f"-{B_t0} -{q0}", + f"-{A_t0} -{q0}", + f"{C_t1} -{q0}", ] diff --git a/claasp/cipher_modules/models/smt/smt_model.py b/claasp/cipher_modules/models/smt/smt_model.py index 719897052..7e4857c31 100644 --- a/claasp/cipher_modules/models/smt/smt_model.py +++ b/claasp/cipher_modules/models/smt/smt_model.py @@ -1,22 +1,20 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** - """ SMT model of Cipher. @@ -37,23 +35,25 @@ class :py:class:`Sat Model `). SMT-LIB is t :py:mod:`claasp.cipher_modules.models.smt.solvers.py` and to the section :ref:`Available SMT solvers`. """ + import math import re import subprocess -from claasp.name_mappings import (SBOX, CIPHER, XOR_LINEAR) -from claasp.cipher_modules.models.smt.utils import constants, utils from claasp.cipher_modules.models.smt import solvers +from claasp.cipher_modules.models.smt.solvers import MATHSAT_EXT, YICES_EXT, Z3_EXT +from claasp.cipher_modules.models.smt.utils import constants, utils from claasp.cipher_modules.models.utils import convert_solver_solution_to_dictionary, set_component_solution +from claasp.name_mappings import SATISFIABLE, SBOX, UNSATISFIABLE def mathsat_parser(output_to_parse): tmp_dict = {} for line in output_to_parse[1:]: - if line.strip().startswith('(define-fun'): - solution = line.strip()[1:-1].split(' ') + if line.strip().startswith("(define-fun"): + solution = line.strip()[1:-1].split(" ") var_name = solution[1] - var_value = 1 if solution[-1] == 'true' else 0 + var_value = 1 if solution[-1] == "true" else 0 tmp_dict[var_name] = var_value return tmp_dict @@ -62,10 +62,10 @@ def mathsat_parser(output_to_parse): def yices_parser(output_to_parse): tmp_dict = {} for line in output_to_parse[1:]: - if line.startswith('(='): - solution = line[1:-1].split(' ') + if line.startswith("(="): + solution = line[1:-1].split(" ") var_name = solution[1] - var_value = 1 if solution[-1] == 'true' else 0 + var_value = 1 if solution[-1] == "true" else 0 tmp_dict[var_name] = var_value return tmp_dict @@ -74,17 +74,17 @@ def yices_parser(output_to_parse): def z3_parser(output_to_parse): tmp_dict = {} for index in range(2, len(output_to_parse), 2): - if output_to_parse[index] == ')': + if output_to_parse[index] == ")": break var_name = output_to_parse[index].split()[1] - var_value = 1 if output_to_parse[index + 1].strip()[:-1] == 'true' else 0 + var_value = 1 if output_to_parse[index + 1].strip()[:-1] == "true" else 0 tmp_dict[var_name] = var_value return tmp_dict class SmtModel: - def __init__(self, cipher, counter='sequential'): + def __init__(self, cipher, counter="sequential"): self._cipher = cipher self._variables_list = [] self._model_constraints = [] @@ -94,38 +94,13 @@ def __init__(self, cipher, counter='sequential'): self._sboxes_lat_templates = {} # set the counter to fix the weight - if counter == 'sequential': + if counter == "sequential": self._counter = self._sequential_counter else: self._counter = self._parallel_counter def _declarations_builder(self): - self._declarations = [f'(declare-const {variable} Bool)' - for variable in self._variables_list] - - def _generate_component_input_ids(self, component): - input_id_link = component.id - in_suffix = constants.INPUT_BIT_ID_SUFFIX - input_bit_size = component.input_bit_size - input_bit_ids = [f'{input_id_link}_{i}{in_suffix}' for i in range(input_bit_size)] - - return input_bit_size, input_bit_ids - - def _generate_input_ids(self, component, suffix=''): - input_id_link = component.input_id_links - input_bit_positions = component.input_bit_positions - input_bit_ids = [] - for link, positions in zip(input_id_link, input_bit_positions): - input_bit_ids.extend([f'{link}_{j}{suffix}' for j in positions]) - - return component.input_bit_size, input_bit_ids - - def _generate_output_ids(self, component, suffix=''): - output_id_link = component.id - output_bit_size = component.output_bit_size - output_bit_ids = [f'{output_id_link}_{j}{suffix}' for j in range(output_bit_size)] - - return output_bit_size, output_bit_ids + self._declarations = [f"(declare-const {variable} Bool)" for variable in self._variables_list] def _get_cipher_inputs_components_solutions(self, out_suffix, variable2value): components_solutions = {} @@ -133,10 +108,10 @@ def _get_cipher_inputs_components_solutions(self, out_suffix, variable2value): value = 0 for i in range(bit_size): value <<= 1 - if f'{cipher_input}_{i}{out_suffix}' in variable2value: - value ^= variable2value[f'{cipher_input}_{i}{out_suffix}'] + if f"{cipher_input}_{i}{out_suffix}" in variable2value: + value ^= variable2value[f"{cipher_input}_{i}{out_suffix}"] hex_digits = bit_size // 4 + (bit_size % 4 != 0) - hex_value = f'{value:0{hex_digits}x}' + hex_value = f"{value:#0{hex_digits + 2}x}" component_solution = set_component_solution(hex_value) components_solutions[cipher_input] = component_solution @@ -157,64 +132,80 @@ def _parallel_counter(self, hw_list, weight): variables = [] constraints = [] num_of_orders = math.ceil(math.log2(len(hw_list))) - dummy_list = [f'dummy_hw_{i}' for i in range(len(hw_list), 2 ** num_of_orders)] + dummy_list = [f"dummy_hw_{i}" for i in range(len(hw_list), 2**num_of_orders)] variables.extend(dummy_list) hw_list.extend(dummy_list) constraints.extend(utils.smt_assert(utils.smt_not(dummy)) for dummy in dummy_list) - for i in range(0, 2 ** num_of_orders, 2): - variables.append(f'r_{num_of_orders - 1}_{i // 2}_0') - variables.append(f'r_{num_of_orders - 1}_{i // 2}_1') - carry = utils.smt_and((f'{hw_list[i]}', f'{hw_list[i + 1]}')) - equation = utils.smt_equivalent((f'r_{num_of_orders - 1}_{i // 2}_0', carry)) + for i in range(0, 2**num_of_orders, 2): + variables.append(f"r_{num_of_orders - 1}_{i // 2}_0") + variables.append(f"r_{num_of_orders - 1}_{i // 2}_1") + carry = utils.smt_and((f"{hw_list[i]}", f"{hw_list[i + 1]}")) + equation = utils.smt_equivalent((f"r_{num_of_orders - 1}_{i // 2}_0", carry)) constraints.append(utils.smt_assert(equation)) - result = utils.smt_xor((f'{hw_list[i]}', f'{hw_list[i + 1]}')) - equation = utils.smt_equivalent((f'r_{num_of_orders - 1}_{i // 2}_1', result)) + result = utils.smt_xor((f"{hw_list[i]}", f"{hw_list[i + 1]}")) + equation = utils.smt_equivalent((f"r_{num_of_orders - 1}_{i // 2}_1", result)) constraints.append(utils.smt_assert(equation)) # recursively adding couple words series = num_of_orders - 2 for i in range(2, num_of_orders + 1): - for j in range(0, 2 ** num_of_orders, 2 ** i): + for j in range(0, 2**num_of_orders, 2**i): # carries computed as usual (remember the library convention: MSB indexed by 0) for k in range(0, i - 1): - variables.append(f'c_{series}_{j // (2 ** i)}_{k}') - carry = utils.smt_carry(f'r_{series + 1}_{j // (2 ** (i - 1))}_{k}', - f'r_{series + 1}_{j // (2 ** (i - 1)) + 1}_{k}', - f'c_{series}_{j // (2 ** i)}_{k + 1}') - equation = utils.smt_equivalent((f'c_{series}_{j // (2 ** i)}_{k}', carry)) + variables.append(f"c_{series}_{j // (2**i)}_{k}") + carry = utils.smt_carry( + f"r_{series + 1}_{j // (2 ** (i - 1))}_{k}", + f"r_{series + 1}_{j // (2 ** (i - 1)) + 1}_{k}", + f"c_{series}_{j // (2**i)}_{k + 1}", + ) + equation = utils.smt_equivalent((f"c_{series}_{j // (2**i)}_{k}", carry)) constraints.append(utils.smt_assert(equation)) # the carry for the tens is the first not null - variables.append(f'c_{series}_{j // (2 ** i)}_{i - 1}') - carry = utils.smt_and((f'r_{series + 1}_{j // (2 ** (i - 1))}_{i - 1}', - f'r_{series + 1}_{j // (2 ** (i - 1)) + 1}_{i - 1}')) - equation = utils.smt_equivalent((f'c_{series}_{j // (2 ** i)}_{i - 1}', carry)) + variables.append(f"c_{series}_{j // (2**i)}_{i - 1}") + carry = utils.smt_and( + ( + f"r_{series + 1}_{j // (2 ** (i - 1))}_{i - 1}", + f"r_{series + 1}_{j // (2 ** (i - 1)) + 1}_{i - 1}", + ) + ) + equation = utils.smt_equivalent((f"c_{series}_{j // (2**i)}_{i - 1}", carry)) constraints.append(utils.smt_assert(equation)) # first bit of the result (MSB) is simply the carry of the previous MSBs - variables.append(f'r_{series}_{j // (2 ** i)}_0') - equation = utils.smt_equivalent((f'r_{series}_{j // (2 ** i)}_0', - f'c_{series}_{j // (2 ** i)}_0')) + variables.append(f"r_{series}_{j // (2**i)}_0") + equation = utils.smt_equivalent((f"r_{series}_{j // (2**i)}_0", f"c_{series}_{j // (2**i)}_0")) constraints.append(utils.smt_assert(equation)) # remaining bits of the result except the last one are as usual for k in range(1, i): - variables.append(f'r_{series}_{j // (2 ** i)}_{k}') - result = utils.smt_xor((f'r_{series + 1}_{j // (2 ** (i - 1))}_{k - 1}', - f'r_{series + 1}_{j // (2 ** (i - 1)) + 1}_{k - 1}', - f'c_{series}_{j // (2 ** i)}_{k}')) - equation = utils.smt_equivalent((f'r_{series}_{j // (2 ** i)}_{k}', result)) + variables.append(f"r_{series}_{j // (2**i)}_{k}") + result = utils.smt_xor( + ( + f"r_{series + 1}_{j // (2 ** (i - 1))}_{k - 1}", + f"r_{series + 1}_{j // (2 ** (i - 1)) + 1}_{k - 1}", + f"c_{series}_{j // (2**i)}_{k}", + ) + ) + equation = utils.smt_equivalent((f"r_{series}_{j // (2**i)}_{k}", result)) constraints.append(utils.smt_assert(equation)) # last bit of the result (LSB) - variables.append(f'r_{series}_{j // (2 ** i)}_{i}') - result = utils.smt_xor((f'r_{series + 1}_{j // (2 ** (i - 1))}_{i - 1}', - f'r_{series + 1}_{j // (2 ** (i - 1)) + 1}_{i - 1}')) - equation = utils.smt_equivalent((f'r_{series}_{j // (2 ** i)}_{i}', result)) + variables.append(f"r_{series}_{j // (2**i)}_{i}") + result = utils.smt_xor( + ( + f"r_{series + 1}_{j // (2 ** (i - 1))}_{i - 1}", + f"r_{series + 1}_{j // (2 ** (i - 1)) + 1}_{i - 1}", + ) + ) + equation = utils.smt_equivalent((f"r_{series}_{j // (2**i)}_{i}", result)) constraints.append(utils.smt_assert(equation)) series -= 1 # the bit length of hamming weight, needed to fix weight when building the model bit_length_of_hw = num_of_orders + 1 - constraints.extend(utils.smt_assert(f'r_0_0_{i}') if weight >> (bit_length_of_hw - 1 - i) & 1 - else utils.smt_assert(utils.smt_not(f'r_0_0_{i}')) - for i in range(bit_length_of_hw)) + constraints.extend( + utils.smt_assert(f"r_0_0_{i}") + if weight >> (bit_length_of_hw - 1 - i) & 1 + else utils.smt_assert(utils.smt_not(f"r_0_0_{i}")) + for i in range(bit_length_of_hw) + ) return variables, constraints, bit_length_of_hw @@ -223,21 +214,19 @@ def _sequential_counter_algorithm(self, hw_list, weight, dummy_id, greater_or_eq if greater_or_equal: weight = n - weight hw_list = [utils.smt_not(id_) for id_ in hw_list] - dummy_variables = [[f'{dummy_id}_{i}_{j}' for j in range(weight)] for i in range(n - 1)] + dummy_variables = [[f"{dummy_id}_{i}_{j}" for j in range(weight)] for i in range(n - 1)] constraints = [utils.smt_assert(utils.smt_implies(hw_list[0], dummy_variables[0][0]))] for j in range(1, weight): constraints.append(utils.smt_assert(utils.smt_not(dummy_variables[0][j]))) for i in range(1, n - 1): - constraints.append(utils.smt_assert(utils.smt_implies(hw_list[i], - dummy_variables[i][0]))) - constraints.append(utils.smt_assert(utils.smt_implies(dummy_variables[i - 1][0], - dummy_variables[i][0]))) + constraints.append(utils.smt_assert(utils.smt_implies(hw_list[i], dummy_variables[i][0]))) + constraints.append(utils.smt_assert(utils.smt_implies(dummy_variables[i - 1][0], dummy_variables[i][0]))) for j in range(1, weight): antecedent = utils.smt_and((hw_list[i], dummy_variables[i - 1][j - 1])) - constraints.append(utils.smt_assert(utils.smt_implies(antecedent, - dummy_variables[i][j]))) - constraints.append(utils.smt_assert(utils.smt_implies(dummy_variables[i - 1][j], - dummy_variables[i][j]))) + constraints.append(utils.smt_assert(utils.smt_implies(antecedent, dummy_variables[i][j]))) + constraints.append( + utils.smt_assert(utils.smt_implies(dummy_variables[i - 1][j], dummy_variables[i][j])) + ) opposite_dummy = utils.smt_not(dummy_variables[i - 1][weight - 1]) constraints.append(utils.smt_assert(utils.smt_implies(hw_list[i], opposite_dummy))) opposite_dummy = utils.smt_not(dummy_variables[n - 2][weight - 1]) @@ -247,15 +236,15 @@ def _sequential_counter_algorithm(self, hw_list, weight, dummy_id, greater_or_eq return dummy_variables, constraints def _sequential_counter(self, hw_list, weight): - return self._sequential_counter_algorithm(hw_list, weight, 'dummy_hw_0') + return self._sequential_counter_algorithm(hw_list, weight, "dummy_hw_0") def _sequential_counter_greater_or_equal(self, weight, dummy_id): - hw_list = [variable_id for variable_id in self._variables_list if variable_id.startswith('hw_')] - variables, constraints = self._sequential_counter_algorithm(hw_list, weight, dummy_id, - greater_or_equal=True) + hw_list = [variable_id for variable_id in self._variables_list if variable_id.startswith("hw_")] + variables, constraints = self._sequential_counter_algorithm(hw_list, weight, dummy_id, greater_or_equal=True) number_of_declarations = len(self._variables_list) formulae = self._model_constraints[ - len(constants.MODEL_PREFIX)+number_of_declarations:-len(constants.MODEL_SUFFIX)] + len(constants.MODEL_PREFIX) + number_of_declarations : -len(constants.MODEL_SUFFIX) + ] self._variables_list.extend(variables) self._declarations_builder() formulae.extend(constraints) @@ -263,10 +252,15 @@ def _sequential_counter_greater_or_equal(self, weight, dummy_id): def calculate_component_weight(self, component, out_suffix, output_values_dict): weight = 0 - if ('MODADD' in component.description or 'AND' in component.description - or 'OR' in component.description or SBOX in component.type): - weight = sum([output_values_dict[f'hw_{component.id}_{i}{out_suffix}'] - for i in range(component.output_bit_size)]) + if ( + "MODADD" in component.description + or "AND" in component.description + or "OR" in component.description + or SBOX in component.type + ): + weight = sum( + [output_values_dict[f"hw_{component.id}_{i}{out_suffix}"] for i in range(component.output_bit_size)] + ) return weight def cipher_input_variables(self): @@ -290,9 +284,11 @@ def cipher_input_variables(self): 'key_62', 'key_63'] """ - cipher_input_bit_ids = [f'{input_id}_{j}' - for input_id, size in zip(self._cipher.inputs, self._cipher.inputs_bit_size) - for j in range(size)] + cipher_input_bit_ids = [ + f"{input_id}_{j}" + for input_id, size in zip(self._cipher.inputs, self._cipher.inputs_bit_size) + for j in range(size) + ] return cipher_input_bit_ids @@ -321,16 +317,16 @@ def fix_variables_value_constraints(self, fixed_variables=[]): """ constraints = [] for component in fixed_variables: - component_id = component['component_id'] - bit_positions = component['bit_positions'] - bit_values = component['bit_values'] + component_id = component["component_id"] + bit_positions = component["bit_positions"] + bit_values = component["bit_values"] - if component['constraint_type'] not in ['equal', 'not_equal']: - raise ValueError('constraint type not defined or misspelled.') + if component["constraint_type"] not in ["equal", "not_equal"]: + raise ValueError("constraint type not defined or misspelled.") - if component['constraint_type'] == 'equal': + if component["constraint_type"] == "equal": self.update_constraints_for_equal_type(bit_positions, bit_values, component_id, constraints) - elif component['constraint_type'] == 'not_equal': + elif component["constraint_type"] == "not_equal": self.update_constraints_for_not_equal_type(bit_positions, bit_values, component_id, constraints) return constraints @@ -351,19 +347,20 @@ def get_xor_probability_constraints(self, bit_ids, template): def update_constraints_for_equal_type(self, bit_positions, bit_values, component_id, constraints, out_suffix=""): for i, position in enumerate(bit_positions): if bit_values[i]: - constraint = f'{component_id}_{position}{out_suffix}' + constraint = f"{component_id}_{position}{out_suffix}" else: - constraint = utils.smt_not(f'{component_id}_{position}{out_suffix}') + constraint = utils.smt_not(f"{component_id}_{position}{out_suffix}") constraints.append(utils.smt_assert(constraint)) - def update_constraints_for_not_equal_type(self, bit_positions, bit_values, - component_id, constraints, out_suffix=""): + def update_constraints_for_not_equal_type( + self, bit_positions, bit_values, component_id, constraints, out_suffix="" + ): literals = [] for i, position in enumerate(bit_positions): if bit_values[i]: - literals.append(utils.smt_not(f'{component_id}_{position}{out_suffix}')) + literals.append(utils.smt_not(f"{component_id}_{position}{out_suffix}")) else: - literals.append(f'{component_id}_{position}{out_suffix}') + literals.append(f"{component_id}_{position}{out_suffix}") constraints.append(utils.smt_assert(utils.smt_or(literals))) def solve(self, model_type, solver_name=solvers.SOLVER_DEFAULT): @@ -398,38 +395,41 @@ def solve(self, model_type, solver_name=solvers.SOLVER_DEFAULT): 'components_values': {}, 'total_weight': None} """ + def _get_data(data_string, lines): data_line = [line for line in lines if data_string in line][0] - data = float(re.findall(r'\d+\.?\d*', data_line)[0]) + data = float(re.findall(r"\d+\.?\d*", data_line)[0]) return data - solver_specs = [specs for specs in solvers.SMT_SOLVERS_EXTERNAL - if specs['solver_name'] == solver_name.upper()][0] - solver_name = solver_specs['solver_name'] - command = [solver_specs['keywords']['command']['executable']] + solver_specs['keywords']['command']['options'] - smt_input = '\n'.join(self._model_constraints) + '\n' + solver_specs = [specs for specs in solvers.SMT_SOLVERS_EXTERNAL if specs["solver_name"] == solver_name.upper()][ + 0 + ] + solver_name = solver_specs["solver_name"] + command = [solver_specs["keywords"]["command"]["executable"]] + solver_specs["keywords"]["command"]["options"] + smt_input = "\n".join(self._model_constraints) + "\n" solver_process = subprocess.run(command, input=smt_input, capture_output=True, text=True) solver_output = solver_process.stdout.splitlines() - solve_time = _get_data(solver_specs['keywords']['time'], solver_output) - memory = _get_data(solver_specs['keywords']['memory'], solver_output) - if solver_output[0] == 'sat': - if solver_name == 'Z3_EXT': + solve_time = _get_data(solver_specs["keywords"]["time"], solver_output) + memory = _get_data(solver_specs["keywords"]["memory"], solver_output) + if solver_output[0] == "sat": + if solver_name == Z3_EXT: variable2value = z3_parser(solver_output) - elif solver_name == 'YICES_EXT': + elif solver_name == YICES_EXT: variable2value = yices_parser(solver_output) - elif solver_name == 'MATHSAT_EXT': + elif solver_name == MATHSAT_EXT: variable2value = mathsat_parser(solver_output) component2attributes, total_weight = self._parse_solver_output(variable2value) - status = 'SATISFIABLE' + status = SATISFIABLE else: component2attributes, total_weight = {}, None - status = 'UNSATISFIABLE' + status = UNSATISFIABLE if total_weight is not None: total_weight = float(total_weight) - solution = convert_solver_solution_to_dictionary(self._cipher, model_type, solver_name, solve_time, - memory, component2attributes, total_weight) - solution['status'] = status + solution = convert_solver_solution_to_dictionary( + self._cipher, model_type, solver_name, solve_time, memory, component2attributes, total_weight + ) + solution["status"] = status return solution @@ -441,7 +441,7 @@ def weight_constraints(self, weight): - ``weight`` -- **integer**; represents the total weight of the trail """ - hw_list = [variable_id for variable_id in self._variables_list if variable_id.startswith('hw_')] + hw_list = [variable_id for variable_id in self._variables_list if variable_id.startswith("hw_")] if weight == 0: return [], [utils.smt_assert(utils.smt_not(variable)) for variable in hw_list] @@ -474,7 +474,7 @@ def model_constraints(self): ValueError: No model generated """ if not self._model_constraints: - raise ValueError('No model generated') + raise ValueError("No model generated") return self._model_constraints @property diff --git a/claasp/cipher_modules/models/smt/smt_models/smt_cipher_model.py b/claasp/cipher_modules/models/smt/smt_models/smt_cipher_model.py index 9eda9217a..82941f32d 100644 --- a/claasp/cipher_modules/models/smt/smt_models/smt_cipher_model.py +++ b/claasp/cipher_modules/models/smt/smt_models/smt_cipher_model.py @@ -1,22 +1,20 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** - import time from claasp.cipher_modules.models.smt import solvers @@ -24,12 +22,20 @@ from claasp.cipher_modules.models.smt.utils import constants from claasp.cipher_modules.models.smt.utils.utils import get_component_hex_value from claasp.cipher_modules.models.utils import set_component_solution -from claasp.name_mappings import (SBOX, WORD_OPERATION, CONSTANT, INTERMEDIATE_OUTPUT, CIPHER_OUTPUT, LINEAR_LAYER, - MIX_COLUMN, CIPHER) +from claasp.name_mappings import ( + CIPHER_OUTPUT, + CIPHER, + CONSTANT, + INTERMEDIATE_OUTPUT, + LINEAR_LAYER, + MIX_COLUMN, + SBOX, + WORD_OPERATION, +) class SmtCipherModel(SmtModel): - def __init__(self, cipher, counter='sequential'): + def __init__(self, cipher, counter="sequential"): super().__init__(cipher, counter) def build_cipher_model(self, fixed_variables=[]): @@ -58,15 +64,16 @@ def build_cipher_model(self, fixed_variables=[]): variables = [] self._variables_list = [] constraints = self.fix_variables_value_constraints(fixed_variables) - component_types = [CIPHER_OUTPUT, CONSTANT, INTERMEDIATE_OUTPUT, LINEAR_LAYER, MIX_COLUMN, SBOX, WORD_OPERATION] - operation_types = ['AND', 'MODADD', 'MODSUB', 'NOT', 'OR', 'ROTATE', 'SHIFT', 'SHIFT_BY_VARIABLE_AMOUNT', 'XOR'] + component_types = (CIPHER_OUTPUT, CONSTANT, INTERMEDIATE_OUTPUT, LINEAR_LAYER, MIX_COLUMN, SBOX, WORD_OPERATION) + operation_types = ("AND", "MODADD", "MODSUB", "NOT", "OR", "ROTATE", "SHIFT", "SHIFT_BY_VARIABLE_AMOUNT", "XOR") self._model_constraints = constraints for component in self._cipher.get_all_components(): operation = component.description[0] if component.type not in component_types or ( - WORD_OPERATION == component.type and operation not in operation_types): - print(f'{component.id} not yet implemented') + WORD_OPERATION == component.type and operation not in operation_types + ): + print(f"{component.id} not yet implemented") else: variables, constraints = component.smt_constraints() @@ -75,8 +82,9 @@ def build_cipher_model(self, fixed_variables=[]): self._variables_list.extend(self.cipher_input_variables()) self._declarations_builder() - self._model_constraints = \ + self._model_constraints = ( constants.MODEL_PREFIX + self._declarations + self._model_constraints + constants.MODEL_SUFFIX + ) def find_missing_bits(self, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT): """ @@ -116,12 +124,12 @@ def find_missing_bits(self, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT) self.build_cipher_model(fixed_variables=fixed_values) end_building_time = time.time() solution = self.solve(CIPHER, solver_name=solver_name) - solution['building_time_seconds'] = end_building_time - start_building_time + solution["building_time_seconds"] = end_building_time - start_building_time return solution def _parse_solver_output(self, variable2value): - out_suffix = '' + out_suffix = "" components_solutions = self._get_cipher_inputs_components_solutions(out_suffix, variable2value) for component in self._cipher.get_all_components(): hex_value = get_component_hex_value(component, out_suffix, variable2value) diff --git a/claasp/cipher_modules/models/smt/smt_models/smt_deterministic_truncated_xor_differential_model.py b/claasp/cipher_modules/models/smt/smt_models/smt_deterministic_truncated_xor_differential_model.py index 1a59e669a..12f6a4c44 100644 --- a/claasp/cipher_modules/models/smt/smt_models/smt_deterministic_truncated_xor_differential_model.py +++ b/claasp/cipher_modules/models/smt/smt_models/smt_deterministic_truncated_xor_differential_model.py @@ -1,17 +1,16 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -21,8 +20,10 @@ class SmtDeterministicTruncatedXorDifferentialModel(SmtModel): - def __init__(self, cipher, counter='sequential'): - raise NotImplementedError("The model is not implemented since, at the best of the authors knowledge, " - "deterministic truncated XOR differential model cannot take any advantage " - "of an SMT solver. Therefore, there is no SMT implementation for deterministic " - "truncated XOR differential model.") + def __init__(self, cipher, counter="sequential"): + raise NotImplementedError( + "The model is not implemented since, at the best of the authors knowledge, " + "deterministic truncated XOR differential model cannot take any advantage " + "of an SMT solver. Therefore, there is no SMT implementation for deterministic " + "truncated XOR differential model." + ) diff --git a/claasp/cipher_modules/models/smt/smt_models/smt_xor_differential_model.py b/claasp/cipher_modules/models/smt/smt_models/smt_xor_differential_model.py index df3b3a222..0a192ad55 100644 --- a/claasp/cipher_modules/models/smt/smt_models/smt_xor_differential_model.py +++ b/claasp/cipher_modules/models/smt/smt_models/smt_xor_differential_model.py @@ -1,34 +1,40 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** - import time from claasp.cipher_modules.models.smt import solvers from claasp.cipher_modules.models.smt.smt_model import SmtModel from claasp.cipher_modules.models.smt.utils import constants, utils from claasp.cipher_modules.models.utils import set_component_solution, get_single_key_scenario_format_for_fixed_values -from claasp.name_mappings import (CIPHER_OUTPUT, CONSTANT, INTERMEDIATE_OUTPUT, LINEAR_LAYER, - MIX_COLUMN, SBOX, WORD_OPERATION, XOR_DIFFERENTIAL) +from claasp.name_mappings import ( + CIPHER_OUTPUT, + CONSTANT, + INTERMEDIATE_OUTPUT, + LINEAR_LAYER, + MIX_COLUMN, + SBOX, + WORD_OPERATION, + XOR_DIFFERENTIAL, +) class SmtXorDifferentialModel(SmtModel): - def __init__(self, cipher, counter='sequential'): + def __init__(self, cipher, counter="sequential"): super().__init__(cipher, counter) def build_xor_differential_trail_model(self, weight=-1, fixed_variables=[]): @@ -64,14 +70,15 @@ def build_xor_differential_trail_model(self, weight=-1, fixed_variables=[]): fixed_variables = get_single_key_scenario_format_for_fixed_values(self._cipher) constraints = self.fix_variables_value_constraints(fixed_variables) component_types = (CONSTANT, INTERMEDIATE_OUTPUT, CIPHER_OUTPUT, LINEAR_LAYER, SBOX, MIX_COLUMN, WORD_OPERATION) - operation_types = ('AND', 'MODADD', 'MODSUB', 'NOT', 'OR', 'ROTATE', 'SHIFT', 'XOR') + operation_types = ("AND", "MODADD", "MODSUB", "NOT", "OR", "ROTATE", "SHIFT", "XOR") self._model_constraints = constraints for component in self._cipher.get_all_components(): operation = component.description[0] if component.type not in component_types or ( - WORD_OPERATION == component.type and operation not in operation_types): - print(f'{component.id} not yet implemented') + WORD_OPERATION == component.type and operation not in operation_types + ): + print(f"{component.id} not yet implemented") else: variables, constraints = component.smt_xor_differential_propagation_constraints(self) @@ -85,10 +92,13 @@ def build_xor_differential_trail_model(self, weight=-1, fixed_variables=[]): self._variables_list.extend(self.cipher_input_variables()) self._declarations_builder() - self._model_constraints = \ + self._model_constraints = ( constants.MODEL_PREFIX + self._declarations + self._model_constraints + constants.MODEL_SUFFIX + ) - def find_all_xor_differential_trails_with_fixed_weight(self, fixed_weight, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT): + def find_all_xor_differential_trails_with_fixed_weight( + self, fixed_weight, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT + ): """ Return a list of solutions containing all the XOR differential trails having the ``fixed_weight`` weight. By default, the search is set in the single-key setting. @@ -132,34 +142,47 @@ def find_all_xor_differential_trails_with_fixed_weight(self, fixed_weight, fixed start_building_time = time.time() self.build_xor_differential_trail_model(weight=fixed_weight, fixed_variables=fixed_values) if self._counter == self._sequential_counter: - self._sequential_counter_greater_or_equal(fixed_weight, 'dummy_hw_1') + self._sequential_counter_greater_or_equal(fixed_weight, "dummy_hw_1") end_building_time = time.time() solution = self.solve(XOR_DIFFERENTIAL, solver_name=solver_name) - solution['building_time_seconds'] = end_building_time - start_building_time + solution["building_time_seconds"] = end_building_time - start_building_time solutions_list = [] - while solution['total_weight'] is not None: + while solution["total_weight"] is not None: solutions_list.append(solution) operands = self.get_operands(solution) for component in self._cipher.get_all_components(): bit_len = component.output_bit_size - is_word_operation = component.type == WORD_OPERATION and component.description[0] in \ - ('AND', 'MODADD', 'MODSUB', 'OR', 'SHIFT_BY_VARIABLE_AMOUNT') + is_word_operation = component.type == WORD_OPERATION and component.description[0] in ( + "AND", + "MODADD", + "MODSUB", + "OR", + "SHIFT_BY_VARIABLE_AMOUNT", + ) if component.type == SBOX or is_word_operation: - value_to_avoid = int(solution['components_values'][component.id]['value'], base=16) - operands.extend([utils.smt_not(f'{component.id}_{j}') - if value_to_avoid >> (bit_len - 1 - j) & 1 - else f'{component.id}_{j}' - for j in range(bit_len)]) + value_to_avoid = int(solution["components_values"][component.id]["value"], base=16) + operands.extend( + [ + utils.smt_not(f"{component.id}_{j}") + if value_to_avoid >> (bit_len - 1 - j) & 1 + else f"{component.id}_{j}" + for j in range(bit_len) + ] + ) clause = utils.smt_or(operands) - self._model_constraints = self._model_constraints[:-len(constants.MODEL_SUFFIX)] \ - + [utils.smt_assert(clause)] + constants.MODEL_SUFFIX + self._model_constraints = ( + self._model_constraints[: -len(constants.MODEL_SUFFIX)] + + [utils.smt_assert(clause)] + + constants.MODEL_SUFFIX + ) solution = self.solve(XOR_DIFFERENTIAL, solver_name=solver_name) - solution['building_time_seconds'] = end_building_time - start_building_time - solution['test_name'] = "find_all_xor_differential_trails_with_fixed_weight" + solution["building_time_seconds"] = end_building_time - start_building_time + solution["test_name"] = "find_all_xor_differential_trails_with_fixed_weight" return solutions_list - def find_all_xor_differential_trails_with_weight_at_most(self, min_weight, max_weight, fixed_values=[], - solver_name=solvers.SOLVER_DEFAULT): + def find_all_xor_differential_trails_with_weight_at_most( + self, min_weight, max_weight, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT + ): """ Return a list of solutions. By default, the search is set in the single-key setting. @@ -206,11 +229,11 @@ def find_all_xor_differential_trails_with_weight_at_most(self, min_weight, max_w """ solutions_list = [] for weight in range(min_weight, max_weight + 1): - solutions = self.find_all_xor_differential_trails_with_fixed_weight(weight, - fixed_values=fixed_values, - solver_name=solver_name) + solutions = self.find_all_xor_differential_trails_with_fixed_weight( + weight, fixed_values=fixed_values, solver_name=solver_name + ) for solution in solutions: - solution['test_name'] = "find_all_xor_differential_trails_with_weight_at_most" + solution["test_name"] = "find_all_xor_differential_trails_with_weight_at_most" solutions_list.extend(solutions) return solutions_list @@ -265,21 +288,21 @@ def find_lowest_weight_xor_differential_trail(self, fixed_values=[], solver_name self.build_xor_differential_trail_model(weight=current_weight, fixed_variables=fixed_values) end_building_time = time.time() solution = self.solve(XOR_DIFFERENTIAL, solver_name=solver_name) - solution['building_time_seconds'] = end_building_time - start_building_time - total_time = solution['solving_time_seconds'] - max_memory = solution['memory_megabytes'] - while solution['total_weight'] is None: + solution["building_time_seconds"] = end_building_time - start_building_time + total_time = solution["solving_time_seconds"] + max_memory = solution["memory_megabytes"] + while solution["total_weight"] is None: current_weight += 1 start_building_time = time.time() self.build_xor_differential_trail_model(weight=current_weight, fixed_variables=fixed_values) end_building_time = time.time() solution = self.solve(XOR_DIFFERENTIAL, solver_name=solver_name) - solution['building_time_seconds'] = end_building_time - start_building_time - total_time += solution['solving_time_seconds'] - max_memory = max((max_memory, solution['memory_megabytes'])) - solution['solving_time_seconds'] = total_time - solution['memory_megabytes'] = max_memory - solution['test_name'] = "find_lowest_weight_xor_differential_trail" + solution["building_time_seconds"] = end_building_time - start_building_time + total_time += solution["solving_time_seconds"] + max_memory = max((max_memory, solution["memory_megabytes"])) + solution["solving_time_seconds"] = total_time + solution["memory_megabytes"] = max_memory + solution["test_name"] = "find_lowest_weight_xor_differential_trail" return solution @@ -332,13 +355,14 @@ def find_one_xor_differential_trail(self, fixed_values=[], solver_name=solvers.S self.build_xor_differential_trail_model(fixed_variables=fixed_values) end_building_time = time.time() solution = self.solve(XOR_DIFFERENTIAL, solver_name=solver_name) - solution['building_time_seconds'] = end_building_time - start_building_time - solution['test_name'] = "find_one_xor_differential_trail" + solution["building_time_seconds"] = end_building_time - start_building_time + solution["test_name"] = "find_one_xor_differential_trail" return solution - def find_one_xor_differential_trail_with_fixed_weight(self, fixed_weight, fixed_values=[], - solver_name=solvers.SOLVER_DEFAULT): + def find_one_xor_differential_trail_with_fixed_weight( + self, fixed_weight, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT + ): """ Return the solution representing a XOR differential trail whose probability is ``2 ** fixed_weight``. By default, the search is set in the single-key setting. @@ -382,32 +406,34 @@ def find_one_xor_differential_trail_with_fixed_weight(self, fixed_weight, fixed_ start_building_time = time.time() self.build_xor_differential_trail_model(weight=fixed_weight, fixed_variables=fixed_values) if self._counter == self._sequential_counter: - self._sequential_counter_greater_or_equal(fixed_weight, 'dummy_hw_1') + self._sequential_counter_greater_or_equal(fixed_weight, "dummy_hw_1") end_building_time = time.time() solution = self.solve(XOR_DIFFERENTIAL, solver_name=solver_name) - solution['building_time_seconds'] = end_building_time - start_building_time - solution['test_name'] = "find_one_xor_differential_trail_with_fixed_weight" + solution["building_time_seconds"] = end_building_time - start_building_time + solution["test_name"] = "find_one_xor_differential_trail_with_fixed_weight" return solution def get_operands(self, solution): operands = [] for input_, bit_len in zip(self._cipher.inputs, self._cipher.inputs_bit_size): - value_to_avoid = int(solution['components_values'][input_]['value'], base=16) - operands.extend([utils.smt_not(f'{input_}_{j}') - if value_to_avoid >> (bit_len - 1 - j) & 1 - else f'{input_}_{j}' - for j in range(bit_len)]) + value_to_avoid = int(solution["components_values"][input_]["value"], base=16) + operands.extend( + [ + utils.smt_not(f"{input_}_{j}") if value_to_avoid >> (bit_len - 1 - j) & 1 else f"{input_}_{j}" + for j in range(bit_len) + ] + ) return operands def _parse_solver_output(self, variable2value): - out_suffix = '' + out_suffix = "" components_solutions = self._get_cipher_inputs_components_solutions(out_suffix, variable2value) total_weight = 0 for component in self._cipher.get_all_components(): hex_value = utils.get_component_hex_value(component, out_suffix, variable2value) weight = self.calculate_component_weight(component, out_suffix, variable2value) component_solution = set_component_solution(hex_value, weight) - components_solutions[f'{component.id}{out_suffix}'] = component_solution + components_solutions[f"{component.id}{out_suffix}"] = component_solution total_weight += weight return components_solutions, total_weight diff --git a/claasp/cipher_modules/models/smt/smt_models/smt_xor_linear_model.py b/claasp/cipher_modules/models/smt/smt_models/smt_xor_linear_model.py index e252e0c6e..9d228d293 100644 --- a/claasp/cipher_modules/models/smt/smt_models/smt_xor_linear_model.py +++ b/claasp/cipher_modules/models/smt/smt_models/smt_xor_linear_model.py @@ -1,38 +1,48 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** - import time from claasp.cipher_modules.models.smt import solvers from claasp.cipher_modules.models.smt.utils import constants, utils from claasp.cipher_modules.models.smt.smt_model import SmtModel from claasp.cipher_modules.models.smt.utils.constants import INPUT_BIT_ID_SUFFIX, OUTPUT_BIT_ID_SUFFIX -from claasp.cipher_modules.models.utils import get_bit_bindings, set_component_solution, \ - get_single_key_scenario_format_for_fixed_values -from claasp.name_mappings import (CIPHER_OUTPUT, CONSTANT, INTERMEDIATE_OUTPUT, LINEAR_LAYER, - MIX_COLUMN, SBOX, WORD_OPERATION, XOR_LINEAR, INPUT_KEY) +from claasp.cipher_modules.models.utils import ( + get_bit_bindings, + get_single_key_scenario_format_for_fixed_values, + set_component_solution, +) +from claasp.name_mappings import ( + CIPHER_OUTPUT, + CONSTANT, + INPUT_KEY, + INTERMEDIATE_OUTPUT, + LINEAR_LAYER, + MIX_COLUMN, + SBOX, + WORD_OPERATION, + XOR_LINEAR, +) class SmtXorLinearModel(SmtModel): - def __init__(self, cipher, counter='sequential'): + def __init__(self, cipher, counter="sequential"): super().__init__(cipher, counter) - self.bit_bindings, self.bit_bindings_for_intermediate_output = get_bit_bindings(cipher, '_'.join) + self.bit_bindings, self.bit_bindings_for_intermediate_output = get_bit_bindings(cipher, "_".join) def branch_xor_linear_constraints(self): """ @@ -55,8 +65,10 @@ def branch_xor_linear_constraints(self): '(assert (not (xor xor_2_10_14_o cipher_output_2_12_30_i)))', '(assert (not (xor xor_2_10_15_o cipher_output_2_12_31_i)))'] """ - return [utils.smt_assert(utils.smt_not(utils.smt_xor([output_bit] + input_bits))) - for output_bit, input_bits in self.bit_bindings.items()] + return [ + utils.smt_assert(utils.smt_not(utils.smt_xor([output_bit] + input_bits))) + for output_bit, input_bits in self.bit_bindings.items() + ] def build_xor_linear_trail_model(self, weight=-1, fixed_variables=[]): """ @@ -89,21 +101,28 @@ def build_xor_linear_trail_model(self, weight=-1, fixed_variables=[]): variables = [] if INPUT_KEY not in [variable["component_id"] for variable in fixed_variables]: self._cipher = self._cipher.remove_key_schedule() - self.bit_bindings, self.bit_bindings_for_intermediate_output = get_bit_bindings(self._cipher, '_'.join) + self.bit_bindings, self.bit_bindings_for_intermediate_output = get_bit_bindings(self._cipher, "_".join) if fixed_variables == []: fixed_variables = get_single_key_scenario_format_for_fixed_values(self._cipher) constraints = self.fix_variables_value_xor_linear_constraints(fixed_variables) self._model_constraints = constraints for component in self._cipher.get_all_components(): - component_types = (CONSTANT, INTERMEDIATE_OUTPUT, CIPHER_OUTPUT, LINEAR_LAYER, - SBOX, MIX_COLUMN, WORD_OPERATION) + component_types = ( + CONSTANT, + INTERMEDIATE_OUTPUT, + CIPHER_OUTPUT, + LINEAR_LAYER, + SBOX, + MIX_COLUMN, + WORD_OPERATION, + ) operation = component.description[0] operation_types = ("AND", "MODADD", "NOT", "ROTATE", "SHIFT", "XOR", "OR", "MODSUB") if component.type in component_types and (component.type != WORD_OPERATION or operation in operation_types): variables, constraints = component.smt_xor_linear_mask_propagation_constraints(self) else: - print(f'{component.id} not yet implemented') + print(f"{component.id} not yet implemented") self._variables_list.extend(variables) self._model_constraints.extend(constraints) @@ -118,8 +137,9 @@ def build_xor_linear_trail_model(self, weight=-1, fixed_variables=[]): self._variables_list.extend(self.cipher_input_xor_linear_variables()) self._declarations_builder() - self._model_constraints = \ + self._model_constraints = ( constants.MODEL_PREFIX + self._declarations + self._model_constraints + constants.MODEL_SUFFIX + ) def cipher_input_xor_linear_variables(self): """ @@ -143,14 +163,17 @@ def cipher_input_xor_linear_variables(self): 'key_63_o'] """ out_suffix = constants.OUTPUT_BIT_ID_SUFFIX - cipher_input_bit_ids = [f'{input_id}_{j}{out_suffix}' - for input_id, size in zip(self._cipher.inputs, self._cipher.inputs_bit_size) - for j in range(size)] + cipher_input_bit_ids = [ + f"{input_id}_{j}{out_suffix}" + for input_id, size in zip(self._cipher.inputs, self._cipher.inputs_bit_size) + for j in range(size) + ] return cipher_input_bit_ids - def find_all_xor_linear_trails_with_fixed_weight(self, fixed_weight, fixed_values=[], - solver_name=solvers.SOLVER_DEFAULT): + def find_all_xor_linear_trails_with_fixed_weight( + self, fixed_weight, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT + ): """ Return a list of solutions containing all the XOR linear trails having weight equal to ``fixed_weight``. By default, the search removes the key schedule, if any. @@ -190,18 +213,18 @@ def find_all_xor_linear_trails_with_fixed_weight(self, fixed_weight, fixed_value start_building_time = time.time() self.build_xor_linear_trail_model(weight=fixed_weight, fixed_variables=fixed_values) if self._counter == self._sequential_counter: - self._sequential_counter_greater_or_equal(fixed_weight, 'dummy_hw_1') + self._sequential_counter_greater_or_equal(fixed_weight, "dummy_hw_1") end_building_time = time.time() solution = self.solve(XOR_LINEAR, solver_name=solver_name) - solution['building_time_seconds'] = end_building_time - start_building_time + solution["building_time_seconds"] = end_building_time - start_building_time solutions_list = [] - while solution['total_weight'] is not None: + while solution["total_weight"] is not None: solutions_list.append(solution) operands = [] - for component in solution['components_values']: - value_as_hex_string = solution['components_values'][component]['value'] + for component in solution["components_values"]: + value_as_hex_string = solution["components_values"][component]["value"] value_to_avoid = int(value_as_hex_string, base=16) - bit_len = len(value_as_hex_string) * 4 + bit_len = (len(value_as_hex_string) - 2) * 4 if CONSTANT in component and component.endswith(INPUT_BIT_ID_SUFFIX): continue elif component.endswith(INPUT_BIT_ID_SUFFIX) or component.endswith(OUTPUT_BIT_ID_SUFFIX): @@ -210,21 +233,29 @@ def find_all_xor_linear_trails_with_fixed_weight(self, fixed_weight, fixed_value else: component_id = component suffix = OUTPUT_BIT_ID_SUFFIX - operands.extend([utils.smt_not(f'{component_id}_{j}{suffix}') - if value_to_avoid >> (bit_len - 1 - j) & 1 - else f'{component_id}_{j}{suffix}' - for j in range(bit_len)]) + operands.extend( + [ + utils.smt_not(f"{component_id}_{j}{suffix}") + if value_to_avoid >> (bit_len - 1 - j) & 1 + else f"{component_id}_{j}{suffix}" + for j in range(bit_len) + ] + ) clause = utils.smt_or(operands) - self._model_constraints = self._model_constraints[:-len(constants.MODEL_SUFFIX)] \ - + [utils.smt_assert(clause)] + constants.MODEL_SUFFIX + self._model_constraints = ( + self._model_constraints[: -len(constants.MODEL_SUFFIX)] + + [utils.smt_assert(clause)] + + constants.MODEL_SUFFIX + ) solution = self.solve(XOR_LINEAR, solver_name=solver_name) - solution['building_time_seconds'] = end_building_time - start_building_time - solution['test_name'] = "find_all_xor_linear_trails_with_fixed_weight" + solution["building_time_seconds"] = end_building_time - start_building_time + solution["test_name"] = "find_all_xor_linear_trails_with_fixed_weight" return solutions_list - def find_all_xor_linear_trails_with_weight_at_most(self, min_weight, max_weight, fixed_values=[], - solver_name=solvers.SOLVER_DEFAULT): + def find_all_xor_linear_trails_with_weight_at_most( + self, min_weight, max_weight, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT + ): """ Return a list of solutions. By default, the search removes the key schedule, if any. @@ -266,11 +297,11 @@ def find_all_xor_linear_trails_with_weight_at_most(self, min_weight, max_weight, """ solutions_list = [] for weight in range(min_weight, max_weight + 1): - solutions = self.find_all_xor_linear_trails_with_fixed_weight(weight, - fixed_values=fixed_values, - solver_name=solver_name) + solutions = self.find_all_xor_linear_trails_with_fixed_weight( + weight, fixed_values=fixed_values, solver_name=solver_name + ) for solution in solutions: - solution['test_name'] = "find_all_xor_linear_trails_with_weight_at_most" + solution["test_name"] = "find_all_xor_linear_trails_with_weight_at_most" solutions_list.extend(solutions) return solutions_list @@ -321,21 +352,21 @@ def find_lowest_weight_xor_linear_trail(self, fixed_values=[], solver_name=solve self.build_xor_linear_trail_model(weight=current_weight, fixed_variables=fixed_values) end_building_time = time.time() solution = self.solve(XOR_LINEAR, solver_name=solver_name) - solution['building_time_seconds'] = end_building_time - start_building_time - total_time = solution['solving_time_seconds'] - max_memory = solution['memory_megabytes'] - while solution['total_weight'] is None: + solution["building_time_seconds"] = end_building_time - start_building_time + total_time = solution["solving_time_seconds"] + max_memory = solution["memory_megabytes"] + while solution["total_weight"] is None: current_weight += 1 start_building_time = time.time() self.build_xor_linear_trail_model(weight=current_weight, fixed_variables=fixed_values) end_building_time = time.time() solution = self.solve(XOR_LINEAR, solver_name=solver_name) - solution['building_time_seconds'] = end_building_time - start_building_time - total_time += solution['solving_time_seconds'] - max_memory = max((max_memory, solution['memory_megabytes'])) - solution['solving_time_seconds'] = total_time - solution['memory_megabytes'] = max_memory - solution['test_name'] = "find_lowest_weight_xor_linear_trail" + solution["building_time_seconds"] = end_building_time - start_building_time + total_time += solution["solving_time_seconds"] + max_memory = max((max_memory, solution["memory_megabytes"])) + solution["solving_time_seconds"] = total_time + solution["memory_megabytes"] = max_memory + solution["test_name"] = "find_lowest_weight_xor_linear_trail" return solution @@ -384,13 +415,14 @@ def find_one_xor_linear_trail(self, fixed_values=[], solver_name=solvers.SOLVER_ self.build_xor_linear_trail_model(fixed_variables=fixed_values) end_building_time = time.time() solution = self.solve(XOR_LINEAR, solver_name=solver_name) - solution['building_time_seconds'] = end_building_time - start_building_time - solution['test_name'] = "find_one_xor_linear_trail" + solution["building_time_seconds"] = end_building_time - start_building_time + solution["test_name"] = "find_one_xor_linear_trail" return solution - def find_one_xor_linear_trail_with_fixed_weight(self, fixed_weight, fixed_values=[], - solver_name=solvers.SOLVER_DEFAULT): + def find_one_xor_linear_trail_with_fixed_weight( + self, fixed_weight, fixed_values=[], solver_name=solvers.SOLVER_DEFAULT + ): """ Return the solution representing a XOR linear trail whose weight is ``fixed_weight``. By default, the search removes the key schedule, if any. @@ -430,11 +462,11 @@ def find_one_xor_linear_trail_with_fixed_weight(self, fixed_weight, fixed_values start_building_time = time.time() self.build_xor_linear_trail_model(weight=fixed_weight, fixed_variables=fixed_values) if self._counter == self._sequential_counter: - self._sequential_counter_greater_or_equal(fixed_weight, 'dummy_hw_1') + self._sequential_counter_greater_or_equal(fixed_weight, "dummy_hw_1") end_building_time = time.time() solution = self.solve(XOR_LINEAR, solver_name=solver_name) - solution['building_time_seconds'] = end_building_time - start_building_time - solution['test_name'] = "find_one_xor_linear_trail_with_fixed_weight" + solution["building_time_seconds"] = end_building_time - start_building_time + solution["test_name"] = "find_one_xor_linear_trail_with_fixed_weight" return solution @@ -464,18 +496,19 @@ def fix_variables_value_xor_linear_constraints(self, fixed_variables=[]): constraints = [] out_suffix = constants.OUTPUT_BIT_ID_SUFFIX for component in fixed_variables: - component_id = component['component_id'] - bit_positions = component['bit_positions'] - bit_values = component['bit_values'] + component_id = component["component_id"] + bit_positions = component["bit_positions"] + bit_values = component["bit_values"] - if component['constraint_type'] not in ('equal', 'not_equal'): - raise ValueError('constraint type not defined or misspelled.') + if component["constraint_type"] not in ("equal", "not_equal"): + raise ValueError("constraint type not defined or misspelled.") - if component['constraint_type'] == 'equal': + if component["constraint_type"] == "equal": self.update_constraints_for_equal_type(bit_positions, bit_values, component_id, constraints, out_suffix) else: - self.update_constraints_for_not_equal_type(bit_positions, bit_values, - component_id, constraints, out_suffix) + self.update_constraints_for_not_equal_type( + bit_positions, bit_values, component_id, constraints, out_suffix + ) return constraints @@ -491,11 +524,11 @@ def _parse_solver_output(self, variable2value): hex_solution = utils.get_component_hex_value(component, out_suffix, variable2value) weight = self.calculate_component_weight(component, out_suffix, variable2value) component_solution = set_component_solution(hex_solution, weight, 1) - components_solutions[f'{component.id}{out_suffix}'] = component_solution + components_solutions[f"{component.id}{out_suffix}"] = component_solution total_weight += weight input_hex_value = utils.get_component_hex_value(component, in_suffix, variable2value) component_solution = set_component_solution(input_hex_value, 0, 1) - components_solutions[f'{component.id}{in_suffix}'] = component_solution + components_solutions[f"{component.id}{in_suffix}"] = component_solution return components_solutions, total_weight diff --git a/claasp/cipher_modules/models/smt/solvers.py b/claasp/cipher_modules/models/smt/solvers.py index cf47eaa3b..670ec6758 100644 --- a/claasp/cipher_modules/models/smt/solvers.py +++ b/claasp/cipher_modules/models/smt/solvers.py @@ -29,8 +29,12 @@ needed. """ +# external solvers definition +MATHSAT_EXT = "MATHSAT_EXT" +YICES_EXT = "YICES_EXT" +Z3_EXT = "Z3_EXT" -SOLVER_DEFAULT = "Z3_EXT" +SOLVER_DEFAULT = Z3_EXT SMT_SOLVERS_INTERNAL = [] @@ -39,11 +43,11 @@ SMT_SOLVERS_EXTERNAL = [ { "solver_brand_name": "MathSAT 5", - "solver_name": "MATHSAT_EXT", + "solver_name": MATHSAT_EXT, "keywords": { "command": { "executable": "mathsat", - "options": ['-model', '-stats'], + "options": ["-model", "-stats"], "input_file": "", "solve": "", "output_file": "", @@ -57,7 +61,7 @@ }, { "solver_brand_name": "Yices2", - "solver_name": "YICES_EXT", + "solver_name": YICES_EXT, "keywords": { "command": { "executable": "yices-smt2", @@ -75,7 +79,7 @@ }, { "solver_brand_name": "Z3 Theorem Prover", - "solver_name": "Z3_EXT", + "solver_name": Z3_EXT, "keywords": { "command": { "executable": "z3", diff --git a/claasp/cipher_modules/models/smt/utils/constants.py b/claasp/cipher_modules/models/smt/utils/constants.py index eaf470a88..d1a26199e 100644 --- a/claasp/cipher_modules/models/smt/utils/constants.py +++ b/claasp/cipher_modules/models/smt/utils/constants.py @@ -1,4 +1,4 @@ -INPUT_BIT_ID_SUFFIX = '_i' -OUTPUT_BIT_ID_SUFFIX = '_o' -MODEL_PREFIX = ['(set-option :print-success false)', '(set-logic QF_UF)'] -MODEL_SUFFIX = ['(check-sat)', '(get-model)', '(exit)'] +INPUT_BIT_ID_SUFFIX = "_i" +OUTPUT_BIT_ID_SUFFIX = "_o" +MODEL_PREFIX = ["(set-option :print-success false)", "(set-logic QF_UF)"] +MODEL_SUFFIX = ["(check-sat)", "(get-model)", "(exit)"] diff --git a/claasp/cipher_modules/models/smt/utils/utils.py b/claasp/cipher_modules/models/smt/utils/utils.py index 0e4262f23..1a8aff918 100644 --- a/claasp/cipher_modules/models/smt/utils/utils.py +++ b/claasp/cipher_modules/models/smt/utils/utils.py @@ -1,17 +1,16 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -21,6 +20,7 @@ # - Build formulae - # # ------------------------ # + def smt_and(formulae): """ Return a string representing the AND of formulae in SMT-LIB standard. @@ -35,7 +35,7 @@ def smt_and(formulae): sage: smt_and(['a', 'c', 'e']) '(and a c e)' """ - return f'(and {" ".join(formulae)})' + return f"(and {' '.join(formulae)})" def smt_assert(formula): @@ -52,7 +52,7 @@ def smt_assert(formula): sage: smt_assert('(= a b c)') '(assert (= a b c))' """ - return f'(assert {formula})' + return f"(assert {formula})" def smt_distinct(variable_0, variable_1): @@ -70,7 +70,7 @@ def smt_distinct(variable_0, variable_1): sage: smt_distinct('a', 'q') '(distinct a q)' """ - return f'(distinct {variable_0} {variable_1})' + return f"(distinct {variable_0} {variable_1})" def smt_equivalent(formulae): @@ -87,7 +87,7 @@ def smt_equivalent(formulae): sage: smt_equivalent(['a', 'b', 'c', 'd']) '(= a b c d)' """ - return f'(= {" ".join(formulae)})' + return f"(= {' '.join(formulae)})" def smt_implies(antecedent, consequent): @@ -105,7 +105,7 @@ def smt_implies(antecedent, consequent): sage: smt_implies('(and a c)', '(or l f)') '(=> (and a c) (or l f))' """ - return f'(=> {antecedent} {consequent})' + return f"(=> {antecedent} {consequent})" def smt_ite(condition, consequent, alternative): @@ -124,7 +124,7 @@ def smt_ite(condition, consequent, alternative): sage: smt_ite('t', '(and a b)', '(and a e)') '(ite t (and a b) (and a e))' """ - return f'(ite {condition} {consequent} {alternative})' + return f"(ite {condition} {consequent} {alternative})" def smt_not(formula): @@ -141,7 +141,7 @@ def smt_not(formula): sage: smt_not('(xor a e)') '(not (xor a e))' """ - return f'(not {formula})' + return f"(not {formula})" def smt_or(formulae): @@ -158,7 +158,7 @@ def smt_or(formulae): sage: smt_or(['b', 'd', 'f']) '(or b d f)' """ - return f'(or {" ".join(formulae)})' + return f"(or {' '.join(formulae)})" def smt_xor(formulae): @@ -175,7 +175,7 @@ def smt_xor(formulae): sage: smt_xor(['b', 'd', 'f']) '(xor b d f)' """ - return f'(xor {" ".join(formulae)})' + return f"(xor {' '.join(formulae)})" def smt_carry(x, y, previous_carry): @@ -236,9 +236,9 @@ def get_component_hex_value(component, out_suffix, variable2value): value = 0 for i in range(output_bit_size): value <<= 1 - if f'{component.id}_{i}{out_suffix}' in variable2value: - value ^= variable2value[f'{component.id}_{i}{out_suffix}'] + if f"{component.id}_{i}{out_suffix}" in variable2value: + value ^= variable2value[f"{component.id}_{i}{out_suffix}"] hex_digits = output_bit_size // 4 + (output_bit_size % 4 != 0) - hex_value = f'{value:0{hex_digits}x}' + hex_value = f"{value:#0{hex_digits + 2}x}" return hex_value diff --git a/claasp/cipher_modules/models/utils.py b/claasp/cipher_modules/models/utils.py index 733cbccb6..a58b9b178 100644 --- a/claasp/cipher_modules/models/utils.py +++ b/claasp/cipher_modules/models/utils.py @@ -36,7 +36,6 @@ INPUT_MESSAGE, INPUT_STATE, ) -from claasp.utils.utils import get_k_th_bit def add_arcs(arcs, component, curr_input_bit_ids, input_bit_size, intermediate_output_arcs, previous_output_bit_ids): @@ -812,7 +811,7 @@ def extract_bits(columns, positions): for j in range(num_columns): byte_index = (bit_size - positions[i] - 1) // 8 bit_index = positions[i] % 8 - result[i, j] = get_k_th_bit(columns[:, j][byte_index], bit_index) + result[i, j] = 1 & (columns[:, j][byte_index] >> bit_index) return result diff --git a/claasp/cipher_modules/neural_network_tests.py b/claasp/cipher_modules/neural_network_tests.py index de4f62fe0..fcc4acee6 100644 --- a/claasp/cipher_modules/neural_network_tests.py +++ b/claasp/cipher_modules/neural_network_tests.py @@ -391,10 +391,14 @@ class RoundNumberTooHigh(Exception): def get_neural_network(self, network_name, input_size, word_size=None, depth=1): from tensorflow.keras.optimizers import Adam + if input_size is None or input_size==0: + input_size = self.cipher.output_bit_size + if word_size is None or word_size == 0: + word_size = self.cipher.output_bit_size + input_size = int(input_size) + word_size = int(word_size) + depth = int(depth) if network_name == 'gohr_resnet': - if word_size is None or word_size == 0: - print("Word size not specified for ", network_name, ", defaulting to ciphertext size...") - word_size = self.cipher.output_bit_size neural_network = self._make_resnet(word_size=word_size, input_size=input_size, depth=depth) elif network_name == 'dbitnet': neural_network = self._make_dbitnet(input_size=input_size) @@ -498,11 +502,11 @@ def train_neural_distinguisher(self, data_generator, starting_round, neural_netw x_eval, y_eval = data_generator(samples=testing_samples, nr=nr) if save_prefix is None: h = neural_network.fit(x, y, epochs=int(epochs), batch_size=bs, - validation_data=(x_eval, y_eval)) + validation_data=(x_eval, y_eval), verbose=2) else: h = neural_network.fit(x, y, epochs=int(epochs), batch_size=bs, validation_data=(x_eval, y_eval), - callbacks=[self.make_checkpoint(save_prefix + str(nr)+'.h5')]) + callbacks=[self._make_checkpoint(save_prefix + str(nr)+'.h5')], verbose=2) acc[nr] = np.max(h.history["val_acc"]) print(f'Validation accuracy at {nr} rounds :{acc[nr]}') nr +=1 @@ -700,7 +704,7 @@ def data_generator(nr, samples): print(f'Training {neural_net} on input difference {[hex(x) for x in input_difference]} ({self.cipher.inputs}), from round {nr}...') neural_results = self.train_neural_distinguisher(data_generator, nr, neural_network, training_samples, - testing_samples, number_of_epochs) + testing_samples, number_of_epochs, save_prefix=save_prefix) neural_distinguisher_test_results['test_results']['plaintext']['cipher_output'][ 'neural_distinguisher_test'].append({'accuracies': list(neural_results.values())}) diff --git a/claasp/cipher_modules/report.py b/claasp/cipher_modules/report.py index 6bf416736..a6b736bd5 100644 --- a/claasp/cipher_modules/report.py +++ b/claasp/cipher_modules/report.py @@ -905,7 +905,7 @@ def save_as_image(self, show_as_hex=False, test_name=None, fixed_input=None, fix sage: speck = SpeckBlockCipher(number_of_rounds=5) sage: avalanche_test_results = AvalancheTests(speck).avalanche_tests() sage: report = Report(avalanche_test_results) - sage: report.save_as_image(test_name='avalanche_weight_vectors', fixed_input='plaintext', fixed_output='round_output', fixed_input_difference='average') # random + sage: report.save_as_image(test_name='avalanche_weight_vectors', fixed_input='plaintext', fixed_output='round_output', fixed_input_difference='average') # doctest: +SKIP """ time = '_date:' + 'time:'.join(str(datetime.now()).split(' ')) diff --git a/claasp/ciphers/block_ciphers/kalyna_block_cipher.py b/claasp/ciphers/block_ciphers/kalyna_block_cipher.py new file mode 100644 index 000000000..ba26f60a0 --- /dev/null +++ b/claasp/ciphers/block_ciphers/kalyna_block_cipher.py @@ -0,0 +1,705 @@ +# **************************************************************************** +# Copyright 2023 Technology Innovation Institute +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# **************************************************************************** +from claasp.cipher import Cipher +from claasp.name_mappings import INPUT_KEY, INPUT_PLAINTEXT +import numpy as np + + +class KalynaBlockCipher(Cipher): + """ + Return a cipher object of Kalyna-128/128 Block Cipher. + The technical specifications along with the test vectors can be found here: https://eprint.iacr.org/2015/650.pdf + + EXAMPLES:: + + sage: from claasp.ciphers.block_ciphers.kalyna_block_cipher import KalynaBlockCipher + sage: kalyna = KalynaBlockCipher() + sage: key = 0x0F0E0D0C0B0A09080706050403020100 + sage: plaintext = 0x1F1E1D1C1B1A19181716151413121110 + sage: ciphertext = 0x06ADD2B439EAC9E120AC9B777D1CBF81 + sage: kalyna.evaluate([key, plaintext]) == ciphertext + True + """ + + def __init__(self, number_of_rounds=10): + # cipher dictionary initialize + self.CIPHER_BLOCK_SIZE = 128 + self.KEY_BLOCK_SIZE = 128 + self.NROUNDS = number_of_rounds + self.SBOX_BIT_SIZE = 8 + self.round_keys = {} + + super().__init__( + family_name="kalyna_block_cipher", + cipher_type="BLOCK_CIPHER", + cipher_inputs=[INPUT_KEY, INPUT_PLAINTEXT], + cipher_inputs_bit_size=[self.KEY_BLOCK_SIZE, self.CIPHER_BLOCK_SIZE], + cipher_output_bit_size=self.CIPHER_BLOCK_SIZE, + ) + # fmt: off + pi0_e = [ + 0xA8, 0x43, 0x5F, 0x06, 0x6B, 0x75, 0x6C, 0x59, 0x71, 0xDF, 0x87, 0x95, 0x17, 0xF0, 0xD8, 0x09, + 0x6D, 0xF3, 0x1D, 0xCB, 0xC9, 0x4D, 0x2C, 0xAF, 0x79, 0xE0, 0x97, 0xFD, 0x6F, 0x4B, 0x45, 0x39, + 0x3E, 0xDD, 0xA3, 0x4F, 0xB4, 0xB6, 0x9A, 0x0E, 0x1F, 0xBF, 0x15, 0xE1, 0x49, 0xD2, 0x93, 0xC6, + 0x92, 0x72, 0x9E, 0x61, 0xD1, 0x63, 0xFA, 0xEE, 0xF4, 0x19, 0xD5, 0xAD, 0x58, 0xA4, 0xBB, 0xA1, + 0xDC, 0xF2, 0x83, 0x37, 0x42, 0xE4, 0x7A, 0x32, 0x9C, 0xCC, 0xAB, 0x4A, 0x8F, 0x6E, 0x04, 0x27, + 0x2E, 0xE7, 0xE2, 0x5A, 0x96, 0x16, 0x23, 0x2B, 0xC2, 0x65, 0x66, 0x0F, 0xBC, 0xA9, 0x47, 0x41, + 0x34, 0x48, 0xFC, 0xB7, 0x6A, 0x88, 0xA5, 0x53, 0x86, 0xF9, 0x5B, 0xDB, 0x38, 0x7B, 0xC3, 0x1E, + 0x22, 0x33, 0x24, 0x28, 0x36, 0xC7, 0xB2, 0x3B, 0x8E, 0x77, 0xBA, 0xF5, 0x14, 0x9F, 0x08, 0x55, + 0x9B, 0x4C, 0xFE, 0x60, 0x5C, 0xDA, 0x18, 0x46, 0xCD, 0x7D, 0x21, 0xB0, 0x3F, 0x1B, 0x89, 0xFF, + 0xEB, 0x84, 0x69, 0x3A, 0x9D, 0xD7, 0xD3, 0x70, 0x67, 0x40, 0xB5, 0xDE, 0x5D, 0x30, 0x91, 0xB1, + 0x78, 0x11, 0x01, 0xE5, 0x00, 0x68, 0x98, 0xA0, 0xC5, 0x02, 0xA6, 0x74, 0x2D, 0x0B, 0xA2, 0x76, + 0xB3, 0xBE, 0xCE, 0xBD, 0xAE, 0xE9, 0x8A, 0x31, 0x1C, 0xEC, 0xF1, 0x99, 0x94, 0xAA, 0xF6, 0x26, + 0x2F, 0xEF, 0xE8, 0x8C, 0x35, 0x03, 0xD4, 0x7F, 0xFB, 0x05, 0xC1, 0x5E, 0x90, 0x20, 0x3D, 0x82, + 0xF7, 0xEA, 0x0A, 0x0D, 0x7E, 0xF8, 0x50, 0x1A, 0xC4, 0x07, 0x57, 0xB8, 0x3C, 0x62, 0xE3, 0xC8, + 0xAC, 0x52, 0x64, 0x10, 0xD0, 0xD9, 0x13, 0x0C, 0x12, 0x29, 0x51, 0xB9, 0xCF, 0xD6, 0x73, 0x8D, + 0x81, 0x54, 0xC0, 0xED, 0x4E, 0x44, 0xA7, 0x2A, 0x85, 0x25, 0xE6, 0xCA, 0x7C, 0x8B, 0x56, 0x80 + ] + pi1_e = [ + 0xCE, 0xBB, 0xEB, 0x92, 0xEA, 0xCB, 0x13, 0xC1, 0xE9, 0x3A, 0xD6, 0xB2, 0xD2, 0x90, 0x17, 0xF8, + 0x42, 0x15, 0x56, 0xB4, 0x65, 0x1C, 0x88, 0x43, 0xC5, 0x5C, 0x36, 0xBA, 0xF5, 0x57, 0x67, 0x8D, + 0x31, 0xF6, 0x64, 0x58, 0x9E, 0xF4, 0x22, 0xAA, 0x75, 0x0F, 0x02, 0xB1, 0xDF, 0x6D, 0x73, 0x4D, + 0x7C, 0x26, 0x2E, 0xF7, 0x08, 0x5D, 0x44, 0x3E, 0x9F, 0x14, 0xC8, 0xAE, 0x54, 0x10, 0xD8, 0xBC, + 0x1A, 0x6B, 0x69, 0xF3, 0xBD, 0x33, 0xAB, 0xFA, 0xD1, 0x9B, 0x68, 0x4E, 0x16, 0x95, 0x91, 0xEE, + 0x4C, 0x63, 0x8E, 0x5B, 0xCC, 0x3C, 0x19, 0xA1, 0x81, 0x49, 0x7B, 0xD9, 0x6F, 0x37, 0x60, 0xCA, + 0xE7, 0x2B, 0x48, 0xFD, 0x96, 0x45, 0xFC, 0x41, 0x12, 0x0D, 0x79, 0xE5, 0x89, 0x8C, 0xE3, 0x20, + 0x30, 0xDC, 0xB7, 0x6C, 0x4A, 0xB5, 0x3F, 0x97, 0xD4, 0x62, 0x2D, 0x06, 0xA4, 0xA5, 0x83, 0x5F, + 0x2A, 0xDA, 0xC9, 0x00, 0x7E, 0xA2, 0x55, 0xBF, 0x11, 0xD5, 0x9C, 0xCF, 0x0E, 0x0A, 0x3D, 0x51, + 0x7D, 0x93, 0x1B, 0xFE, 0xC4, 0x47, 0x09, 0x86, 0x0B, 0x8F, 0x9D, 0x6A, 0x07, 0xB9, 0xB0, 0x98, + 0x18, 0x32, 0x71, 0x4B, 0xEF, 0x3B, 0x70, 0xA0, 0xE4, 0x40, 0xFF, 0xC3, 0xA9, 0xE6, 0x78, 0xF9, + 0x8B, 0x46, 0x80, 0x1E, 0x38, 0xE1, 0xB8, 0xA8, 0xE0, 0x0C, 0x23, 0x76, 0x1D, 0x25, 0x24, 0x05, + 0xF1, 0x6E, 0x94, 0x28, 0x9A, 0x84, 0xE8, 0xA3, 0x4F, 0x77, 0xD3, 0x85, 0xE2, 0x52, 0xF2, 0x82, + 0x50, 0x7A, 0x2F, 0x74, 0x53, 0xB3, 0x61, 0xAF, 0x39, 0x35, 0xDE, 0xCD, 0x1F, 0x99, 0xAC, 0xAD, + 0x72, 0x2C, 0xDD, 0xD0, 0x87, 0xBE, 0x5E, 0xA6, 0xEC, 0x04, 0xC6, 0x03, 0x34, 0xFB, 0xDB, 0x59, + 0xB6, 0xC2, 0x01, 0xF0, 0x5A, 0xED, 0xA7, 0x66, 0x21, 0x7F, 0x8A, 0x27, 0xC7, 0xC0, 0x29, 0xD7 + ] + pi2_e = [ + 0x93, 0xD9, 0x9A, 0xB5, 0x98, 0x22, 0x45, 0xFC, 0xBA, 0x6A, 0xDF, 0x02, 0x9F, 0xDC, 0x51, 0x59, + 0x4A, 0x17, 0x2B, 0xC2, 0x94, 0xF4, 0xBB, 0xA3, 0x62, 0xE4, 0x71, 0xD4, 0xCD, 0x70, 0x16, 0xE1, + 0x49, 0x3C, 0xC0, 0xD8, 0x5C, 0x9B, 0xAD, 0x85, 0x53, 0xA1, 0x7A, 0xC8, 0x2D, 0xE0, 0xD1, 0x72, + 0xA6, 0x2C, 0xC4, 0xE3, 0x76, 0x78, 0xB7, 0xB4, 0x09, 0x3B, 0x0E, 0x41, 0x4C, 0xDE, 0xB2, 0x90, + 0x25, 0xA5, 0xD7, 0x03, 0x11, 0x00, 0xC3, 0x2E, 0x92, 0xEF, 0x4E, 0x12, 0x9D, 0x7D, 0xCB, 0x35, + 0x10, 0xD5, 0x4F, 0x9E, 0x4D, 0xA9, 0x55, 0xC6, 0xD0, 0x7B, 0x18, 0x97, 0xD3, 0x36, 0xE6, 0x48, + 0x56, 0x81, 0x8F, 0x77, 0xCC, 0x9C, 0xB9, 0xE2, 0xAC, 0xB8, 0x2F, 0x15, 0xA4, 0x7C, 0xDA, 0x38, + 0x1E, 0x0B, 0x05, 0xD6, 0x14, 0x6E, 0x6C, 0x7E, 0x66, 0xFD, 0xB1, 0xE5, 0x60, 0xAF, 0x5E, 0x33, + 0x87, 0xC9, 0xF0, 0x5D, 0x6D, 0x3F, 0x88, 0x8D, 0xC7, 0xF7, 0x1D, 0xE9, 0xEC, 0xED, 0x80, 0x29, + 0x27, 0xCF, 0x99, 0xA8, 0x50, 0x0F, 0x37, 0x24, 0x28, 0x30, 0x95, 0xD2, 0x3E, 0x5B, 0x40, 0x83, + 0xB3, 0x69, 0x57, 0x1F, 0x07, 0x1C, 0x8A, 0xBC, 0x20, 0xEB, 0xCE, 0x8E, 0xAB, 0xEE, 0x31, 0xA2, + 0x73, 0xF9, 0xCA, 0x3A, 0x1A, 0xFB, 0x0D, 0xC1, 0xFE, 0xFA, 0xF2, 0x6F, 0xBD, 0x96, 0xDD, 0x43, + 0x52, 0xB6, 0x08, 0xF3, 0xAE, 0xBE, 0x19, 0x89, 0x32, 0x26, 0xB0, 0xEA, 0x4B, 0x64, 0x84, 0x82, + 0x6B, 0xF5, 0x79, 0xBF, 0x01, 0x5F, 0x75, 0x63, 0x1B, 0x23, 0x3D, 0x68, 0x2A, 0x65, 0xE8, 0x91, + 0xF6, 0xFF, 0x13, 0x58, 0xF1, 0x47, 0x0A, 0x7F, 0xC5, 0xA7, 0xE7, 0x61, 0x5A, 0x06, 0x46, 0x44, + 0x42, 0x04, 0xA0, 0xDB, 0x39, 0x86, 0x54, 0xAA, 0x8C, 0x34, 0x21, 0x8B, 0xF8, 0x0C, 0x74, 0x67 + ] + pi3_e = [ + 0x68, 0x8D, 0xCA, 0x4D, 0x73, 0x4B, 0x4E, 0x2A, 0xD4, 0x52, 0x26, 0xB3, 0x54, 0x1E, 0x19, 0x1F, + 0x22, 0x03, 0x46, 0x3D, 0x2D, 0x4A, 0x53, 0x83, 0x13, 0x8A, 0xB7, 0xD5, 0x25, 0x79, 0xF5, 0xBD, + 0x58, 0x2F, 0x0D, 0x02, 0xED, 0x51, 0x9E, 0x11, 0xF2, 0x3E, 0x55, 0x5E, 0xD1, 0x16, 0x3C, 0x66, + 0x70, 0x5D, 0xF3, 0x45, 0x40, 0xCC, 0xE8, 0x94, 0x56, 0x08, 0xCE, 0x1A, 0x3A, 0xD2, 0xE1, 0xDF, + 0xB5, 0x38, 0x6E, 0x0E, 0xE5, 0xF4, 0xF9, 0x86, 0xE9, 0x4F, 0xD6, 0x85, 0x23, 0xCF, 0x32, 0x99, + 0x31, 0x14, 0xAE, 0xEE, 0xC8, 0x48, 0xD3, 0x30, 0xA1, 0x92, 0x41, 0xB1, 0x18, 0xC4, 0x2C, 0x71, + 0x72, 0x44, 0x15, 0xFD, 0x37, 0xBE, 0x5F, 0xAA, 0x9B, 0x88, 0xD8, 0xAB, 0x89, 0x9C, 0xFA, 0x60, + 0xEA, 0xBC, 0x62, 0x0C, 0x24, 0xA6, 0xA8, 0xEC, 0x67, 0x20, 0xDB, 0x7C, 0x28, 0xDD, 0xAC, 0x5B, + 0x34, 0x7E, 0x10, 0xF1, 0x7B, 0x8F, 0x63, 0xA0, 0x05, 0x9A, 0x43, 0x77, 0x21, 0xBF, 0x27, 0x09, + 0xC3, 0x9F, 0xB6, 0xD7, 0x29, 0xC2, 0xEB, 0xC0, 0xA4, 0x8B, 0x8C, 0x1D, 0xFB, 0xFF, 0xC1, 0xB2, + 0x97, 0x2E, 0xF8, 0x65, 0xF6, 0x75, 0x07, 0x04, 0x49, 0x33, 0xE4, 0xD9, 0xB9, 0xD0, 0x42, 0xC7, + 0x6C, 0x90, 0x00, 0x8E, 0x6F, 0x50, 0x01, 0xC5, 0xDA, 0x47, 0x3F, 0xCD, 0x69, 0xA2, 0xE2, 0x7A, + 0xA7, 0xC6, 0x93, 0x0F, 0x0A, 0x06, 0xE6, 0x2B, 0x96, 0xA3, 0x1C, 0xAF, 0x6A, 0x12, 0x84, 0x39, + 0xE7, 0xB0, 0x82, 0xF7, 0xFE, 0x9D, 0x87, 0x5C, 0x81, 0x35, 0xDE, 0xB4, 0xA5, 0xFC, 0x80, 0xEF, + 0xCB, 0xBB, 0x6B, 0x76, 0xBA, 0x5A, 0x7D, 0x78, 0x0B, 0x95, 0xE3, 0xAD, 0x74, 0x98, 0x3B, 0x36, + 0x64, 0x6D, 0xDC, 0xF0, 0x59, 0xA9, 0x4C, 0x17, 0x7F, 0x91, 0xB8, 0xC9, 0x57, 0x1B, 0xE0, 0x61 + ] + # fmt: on + + sboxes = {0: pi0_e, 1: pi1_e, 2: pi2_e, 3: pi3_e} + mapping = [0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7] + + M = np.zeros((128, 128), dtype=int) + + for output_byte_index in range(16): + input_byte_index = mapping[output_byte_index] + for bit_index in range(8): + output_bit = output_byte_index * 8 + bit_index + input_bit = input_byte_index * 8 + bit_index + M[output_bit, input_bit] = 1 + + self.kalyna_matrix = [ + [0x01, 0x04, 0x07, 0x06, 0x08, 0x01, 0x05, 0x01], + [0x01, 0x01, 0x04, 0x07, 0x06, 0x08, 0x01, 0x05], + [0x05, 0x01, 0x01, 0x04, 0x07, 0x06, 0x08, 0x01], + [0x01, 0x05, 0x01, 0x01, 0x04, 0x07, 0x06, 0x08], + [0x08, 0x01, 0x05, 0x01, 0x01, 0x04, 0x07, 0x06], + [0x06, 0x08, 0x01, 0x05, 0x01, 0x01, 0x04, 0x07], + [0x07, 0x06, 0x08, 0x01, 0x05, 0x01, 0x01, 0x04], + [0x04, 0x07, 0x06, 0x08, 0x01, 0x05, 0x01, 0x01], + ] + + self.irreducible_polynomial = {8: 0x11D} + self.kalyna_matrix_description = [ + self.kalyna_matrix, + 0x11D, + 8, + ] + + self.add_round() + S0 = self.add_constant_component( + self.CIPHER_BLOCK_SIZE, 0x00000000000000000000000000000005 + ) + g1_first = self.add_MODADD_component( + [INPUT_KEY, S0.id], + [ + [i for i in range(self.CIPHER_BLOCK_SIZE // 2, self.CIPHER_BLOCK_SIZE)], + [i for i in range(self.CIPHER_BLOCK_SIZE // 2, self.CIPHER_BLOCK_SIZE)], + ], + self.CIPHER_BLOCK_SIZE // 2, + ) + g1_second = self.add_MODADD_component( + [INPUT_KEY, S0.id], + [ + [i for i in range(self.CIPHER_BLOCK_SIZE // 2)], + [i for i in range(self.CIPHER_BLOCK_SIZE // 2)], + ], + self.CIPHER_BLOCK_SIZE // 2, + ) + sboxes_components = [] + for i in range(self.CIPHER_BLOCK_SIZE // 16): + sboxes_components.append( + self.add_SBOX_component( + [g1_first.id], + [list(range((8 * i), (8 * i + 8)))], + self.SBOX_BIT_SIZE, + sboxes[3 - (i % 4)], + ) + ) + for i in range(self.CIPHER_BLOCK_SIZE // 16, self.CIPHER_BLOCK_SIZE // 8): + sboxes_components.append( + self.add_SBOX_component( + [g1_second.id], + [list(range((8 * i) - 64, (8 * i + 8) - 64))], + self.SBOX_BIT_SIZE, + sboxes[3 - (i % 4)], + ) + ) + input_ids = [c.id for c in sboxes_components] + input_positions = [list(range(8)) for _ in range(16)] + shift1 = self.add_linear_layer_component( + input_ids, input_positions, self.CIPHER_BLOCK_SIZE, M.tolist() + ) + g1_first = self.add_mix_column_component( + [shift1.id], + [[i for i in range(self.CIPHER_BLOCK_SIZE // 2)]], + self.CIPHER_BLOCK_SIZE // 2, + self.kalyna_matrix_description, + ) + g1_second = self.add_mix_column_component( + [shift1.id], + [[i for i in range(self.CIPHER_BLOCK_SIZE // 2, self.CIPHER_BLOCK_SIZE)]], + self.CIPHER_BLOCK_SIZE // 2, + self.kalyna_matrix_description, + ) + g2_first = self.add_XOR_component( + [g1_first.id, INPUT_KEY], + [ + [i for i in range(self.CIPHER_BLOCK_SIZE // 2)], + [i for i in range(self.KEY_BLOCK_SIZE // 2)], + ], + self.CIPHER_BLOCK_SIZE // 2, + ) + g2_second = self.add_XOR_component( + [g1_second.id, INPUT_KEY], + [ + [i for i in range(self.CIPHER_BLOCK_SIZE // 2)], + [i for i in range(self.KEY_BLOCK_SIZE // 2, self.KEY_BLOCK_SIZE)], + ], + self.CIPHER_BLOCK_SIZE // 2, + ) + sboxes_components_2 = [] + for i in range(self.CIPHER_BLOCK_SIZE // 16): + sboxes_components_2.append( + self.add_SBOX_component( + [g2_first.id], + [list(range((8 * i), (8 * i + 8)))], + self.SBOX_BIT_SIZE, + sboxes[3 - (i % 4)], + ) + ) + for i in range(self.CIPHER_BLOCK_SIZE // 16, self.CIPHER_BLOCK_SIZE // 8): + sboxes_components_2.append( + self.add_SBOX_component( + [g2_second.id], + [list(range((8 * i) - 64, (8 * i + 8) - 64))], + self.SBOX_BIT_SIZE, + sboxes[3 - (i % 4)], + ) + ) + + input_ids_2 = [c.id for c in sboxes_components_2] + input_positions_2 = [list(range(8)) for _ in range(16)] + shift_2 = self.add_linear_layer_component( + input_ids_2, input_positions_2, self.CIPHER_BLOCK_SIZE, M.tolist() + ) + g2_first = self.add_mix_column_component( + [shift_2.id], + [[i for i in range(self.CIPHER_BLOCK_SIZE // 2)]], + self.CIPHER_BLOCK_SIZE // 2, + self.kalyna_matrix_description, + ) + g2_second = self.add_mix_column_component( + [shift_2.id], + [[i for i in range(self.CIPHER_BLOCK_SIZE // 2, self.CIPHER_BLOCK_SIZE)]], + self.CIPHER_BLOCK_SIZE // 2, + self.kalyna_matrix_description, + ) + g2_first_modadd = self.add_MODADD_component( + [INPUT_KEY, g2_first.id], + [ + [i for i in range(self.KEY_BLOCK_SIZE // 2, self.KEY_BLOCK_SIZE)], + [i for i in range(64)], + ], + self.CIPHER_BLOCK_SIZE // 2, + ) + g2_second_modadd = self.add_MODADD_component( + [INPUT_KEY, g2_second.id], + [ + [i for i in range(self.KEY_BLOCK_SIZE // 2)], + [i for i in range(self.CIPHER_BLOCK_SIZE // 2)], + ], + self.CIPHER_BLOCK_SIZE // 2, + ) + sboxes_components_3 = [] + for i in range(self.CIPHER_BLOCK_SIZE // 16): + sboxes_components_3.append( + self.add_SBOX_component( + [g2_second_modadd.id], + [list(range((8 * i), (8 * i + 8)))], + self.SBOX_BIT_SIZE, + sboxes[3 - (i % 4)], + ) + ) + for i in range(self.CIPHER_BLOCK_SIZE // 16, self.CIPHER_BLOCK_SIZE // 8): + sboxes_components_3.append( + self.add_SBOX_component( + [g2_first_modadd.id], + [list(range((8 * i) - 64, (8 * i + 8) - 64))], + self.SBOX_BIT_SIZE, + sboxes[3 - (i % 4)], + ) + ) + + input_ids_3 = [c.id for c in sboxes_components_3] + input_positions_3 = [list(range(8)) for _ in range(16)] + shift_3 = self.add_linear_layer_component( + input_ids_3, input_positions_3, self.CIPHER_BLOCK_SIZE, M.tolist() + ) + g3_first = self.add_mix_column_component( + [shift_3.id], + [[i for i in range(self.CIPHER_BLOCK_SIZE // 2)]], + self.CIPHER_BLOCK_SIZE // 2, + self.kalyna_matrix_description, + ) + g3_second = self.add_mix_column_component( + [shift_3.id], + [[i for i in range(self.CIPHER_BLOCK_SIZE // 2, self.CIPHER_BLOCK_SIZE)]], + self.CIPHER_BLOCK_SIZE // 2, + self.kalyna_matrix_description, + ) + + const = self.add_constant_component(self.CIPHER_BLOCK_SIZE, 0x0) + g3_second_128 = self.add_XOR_component( + [g3_second.id, const.id], + [ + [j for j in range(self.CIPHER_BLOCK_SIZE // 2)], + [j for j in range(self.CIPHER_BLOCK_SIZE // 2)], + ], + self.CIPHER_BLOCK_SIZE, + ) + g3_first_128 = self.add_XOR_component( + [const.id, g3_first.id], + [ + [j for j in range(self.CIPHER_BLOCK_SIZE // 2)], + [j for j in range(self.CIPHER_BLOCK_SIZE // 2)], + ], + self.CIPHER_BLOCK_SIZE, + ) + k_T = self.add_XOR_component( + [g3_second_128.id, g3_first_128.id], + [ + [j for j in range(self.CIPHER_BLOCK_SIZE)], + [j for j in range(self.CIPHER_BLOCK_SIZE)], + ], + self.CIPHER_BLOCK_SIZE, + ) + tmv_0 = self.add_constant_component( + self.CIPHER_BLOCK_SIZE, 0x00010001000100010001000100010001 + ) + even_round_keys = {} + for k in range(0, self.NROUNDS + 1, 2): + tmv = self.add_rotate_component( + [tmv_0.id], + [list(range(self.KEY_BLOCK_SIZE))], + self.KEY_BLOCK_SIZE, + -k // 2, + ) + rotated_key = self.add_rotate_component( + [INPUT_KEY], + [list(range(self.KEY_BLOCK_SIZE))], + self.KEY_BLOCK_SIZE, + 32 * k, + ) + k_i_prime = self.add_MODADD_component( + [tmv.id, k_T.id], + [ + [j for j in range(self.KEY_BLOCK_SIZE)], + [j for j in range(self.KEY_BLOCK_SIZE)], + ], + self.CIPHER_BLOCK_SIZE, + ) + k_i_prime_int = self.add_MODADD_component( + [k_i_prime.id, rotated_key.id], + [ + [j for j in range(self.KEY_BLOCK_SIZE)], + [j for j in range(self.KEY_BLOCK_SIZE)], + ], + self.CIPHER_BLOCK_SIZE, + ) + sboxes_components_E = [] + for i in range(self.KEY_BLOCK_SIZE // 8): + sboxes_components_E.append( + self.add_SBOX_component( + [k_i_prime_int.id], + [list(range((8 * i), (8 * i + 8)))], + self.SBOX_BIT_SIZE, + sboxes[3 - (i % 4)], + ) + ) + input_ids_E = [c.id for c in sboxes_components_E] + input_positions_E = [list(range(8)) for _ in range(16)] + shift_E = self.add_linear_layer_component( + input_ids_E, input_positions_E, self.CIPHER_BLOCK_SIZE, M.tolist() + ) + gE_first = self.add_mix_column_component( + [shift_E.id], + [[i for i in range(self.CIPHER_BLOCK_SIZE // 2)]], + self.CIPHER_BLOCK_SIZE // 2, + self.kalyna_matrix_description, + ) + gE_second = self.add_mix_column_component( + [shift_E.id], + [ + [ + i + for i in range( + self.CIPHER_BLOCK_SIZE // 2, self.CIPHER_BLOCK_SIZE + ) + ] + ], + self.CIPHER_BLOCK_SIZE // 2, + self.kalyna_matrix_description, + ) + k_i_prime_first_half = self.add_XOR_component( + [gE_second.id, k_i_prime.id], + [ + [j for j in range(self.CIPHER_BLOCK_SIZE // 2)], + [j for j in range(self.KEY_BLOCK_SIZE // 2)], + ], + self.KEY_BLOCK_SIZE // 2, + ) + k_i_prime_second_half = self.add_XOR_component( + [gE_first.id, k_i_prime.id], + [ + [j for j in range(self.CIPHER_BLOCK_SIZE // 2)], + [j for j in range(self.KEY_BLOCK_SIZE // 2, self.KEY_BLOCK_SIZE)], + ], + self.KEY_BLOCK_SIZE // 2, + ) + sboxes_components_E2 = [] + for i in range(self.KEY_BLOCK_SIZE // 16): + sboxes_components_E2.append( + self.add_SBOX_component( + [k_i_prime_first_half.id], + [list(range((8 * i), (8 * i + 8)))], + self.SBOX_BIT_SIZE, + sboxes[3 - (i % 4)], + ) + ) + for i in range(self.KEY_BLOCK_SIZE // 16, self.KEY_BLOCK_SIZE // 8): + sboxes_components_E2.append( + self.add_SBOX_component( + [k_i_prime_second_half.id], + [list(range((8 * i) - 64, (8 * i + 8) - 64))], + self.SBOX_BIT_SIZE, + sboxes[3 - (i % 4)], + ) + ) + + input_ids_E2 = [c.id for c in sboxes_components_E2] + input_positions_E2 = [list(range(8)) for _ in range(16)] + shift_E2 = self.add_linear_layer_component( + input_ids_E2, input_positions_E2, self.CIPHER_BLOCK_SIZE, M.tolist() + ) + gE2_first = self.add_mix_column_component( + [shift_E2.id], + [[i for i in range(self.CIPHER_BLOCK_SIZE // 2)]], + self.CIPHER_BLOCK_SIZE // 2, + self.kalyna_matrix_description, + ) + gE2_second = self.add_mix_column_component( + [shift_E2.id], + [ + [ + i + for i in range( + self.CIPHER_BLOCK_SIZE // 2, self.CIPHER_BLOCK_SIZE + ) + ] + ], + self.CIPHER_BLOCK_SIZE // 2, + self.kalyna_matrix_description, + ) + k_2i_prime_first_half = self.add_MODADD_component( + [gE2_second.id, k_i_prime.id], + [ + [j for j in range(self.CIPHER_BLOCK_SIZE // 2)], + [j for j in range(self.KEY_BLOCK_SIZE // 2)], + ], + self.KEY_BLOCK_SIZE // 2, + ) + k_2i_prime_second_half = self.add_MODADD_component( + [gE2_first.id, k_i_prime.id], + [ + [j for j in range(self.CIPHER_BLOCK_SIZE // 2)], + [j for j in range(self.KEY_BLOCK_SIZE // 2, self.KEY_BLOCK_SIZE)], + ], + self.KEY_BLOCK_SIZE // 2, + ) + self.add_round_key_output_component( + [k_2i_prime_first_half.id, k_2i_prime_second_half.id], + [ + list(range(self.CIPHER_BLOCK_SIZE // 2)), + list(range(self.CIPHER_BLOCK_SIZE // 2)), + ], + self.KEY_BLOCK_SIZE, + ) + even_round_keys[k] = { + "first_half": k_2i_prime_first_half, + "second_half": k_2i_prime_second_half, + } + + odd_round_keys = {} + for i in range(1, self.NROUNDS, 2): + even_first_id = even_round_keys[i - 1]["first_half"] + even_second_id = even_round_keys[i - 1]["second_half"] + left_shifted = self.add_rotate_component( + [even_second_id.id, even_first_id.id], + [ + [j for j in range(self.KEY_BLOCK_SIZE // 2)], + [j for j in range(self.KEY_BLOCK_SIZE // 2)], + ], + self.KEY_BLOCK_SIZE, + -self.KEY_BLOCK_SIZE // 4 + 24, + ) + self.add_round_key_output_component( + [left_shifted.id], + [list(range(self.KEY_BLOCK_SIZE))], + self.KEY_BLOCK_SIZE, + ) + odd_round_keys[i] = left_shifted + + k0_first = even_round_keys[0]["first_half"] + k0_second = even_round_keys[0]["second_half"] + + # self.add_round() + first_half = self.add_MODADD_component( + [k0_first.id, INPUT_PLAINTEXT], + [ + [i for i in range(self.KEY_BLOCK_SIZE // 2)], + [i for i in range(self.CIPHER_BLOCK_SIZE // 2)], + ], + self.CIPHER_BLOCK_SIZE // 2, + ) + + second_half = self.add_MODADD_component( + [k0_second.id, INPUT_PLAINTEXT], + [ + [i for i in range(self.KEY_BLOCK_SIZE // 2)], + [i for i in range(self.CIPHER_BLOCK_SIZE // 2, self.CIPHER_BLOCK_SIZE)], + ], + self.CIPHER_BLOCK_SIZE // 2, + ) + for k in range(1, self.NROUNDS): + self.add_round() + sboxes_components = [] + for i in range(self.CIPHER_BLOCK_SIZE // 16): + sboxes_components.append( + self.add_SBOX_component( + [first_half.id], + [list(range((8 * i), (8 * i + 8)))], + self.SBOX_BIT_SIZE, + sboxes[3 - (i % 4)], + ) + ) + for i in range(self.CIPHER_BLOCK_SIZE // 16, self.CIPHER_BLOCK_SIZE // 8): + sboxes_components.append( + self.add_SBOX_component( + [second_half.id], + [list(range((8 * i) - 64, (8 * i + 8) - 64))], + self.SBOX_BIT_SIZE, + sboxes[3 - (i % 4)], + ) + ) + + input_ids = [c.id for c in sboxes_components] + input_positions = [list(range(8)) for _ in range(16)] + shift = self.add_linear_layer_component( + input_ids, input_positions, self.CIPHER_BLOCK_SIZE, M.tolist() + ) + g_first = self.add_mix_column_component( + [shift.id], + [[i for i in range(self.CIPHER_BLOCK_SIZE // 2)]], + self.CIPHER_BLOCK_SIZE // 2, + self.kalyna_matrix_description, + ) + g_second = self.add_mix_column_component( + [shift.id], + [ + [ + i + for i in range( + self.CIPHER_BLOCK_SIZE // 2, self.CIPHER_BLOCK_SIZE + ) + ] + ], + self.CIPHER_BLOCK_SIZE // 2, + self.kalyna_matrix_description, + ) + if k % 2 == 1: + k_round = odd_round_keys[k] + first_half = self.add_XOR_component( + [g_second.id, k_round.id], + [ + [j for j in range(self.CIPHER_BLOCK_SIZE // 2)], + [j for j in range(self.KEY_BLOCK_SIZE // 2)], + ], + self.CIPHER_BLOCK_SIZE // 2, + ) + second_half = self.add_XOR_component( + [g_first.id, k_round.id], + [ + [j for j in range(self.CIPHER_BLOCK_SIZE // 2)], + [ + j + for j in range( + self.KEY_BLOCK_SIZE // 2, self.KEY_BLOCK_SIZE + ) + ], + ], + self.CIPHER_BLOCK_SIZE // 2, + ) + else: + k_round_first_half = even_round_keys[k]["first_half"] + k_round_second_half = even_round_keys[k]["second_half"] + first_half = self.add_XOR_component( + [g_second.id, k_round_first_half.id], + [ + [j for j in range(self.CIPHER_BLOCK_SIZE // 2)], + [j for j in range(self.KEY_BLOCK_SIZE // 2)], + ], + self.CIPHER_BLOCK_SIZE // 2, + ) + second_half = self.add_XOR_component( + [g_first.id, k_round_second_half.id], + [ + [j for j in range(self.CIPHER_BLOCK_SIZE // 2)], + [j for j in range(self.KEY_BLOCK_SIZE // 2)], + ], + self.CIPHER_BLOCK_SIZE // 2, + ) + + self.add_round_output_component( + [first_half.id, second_half.id], + [ + list(range(self.CIPHER_BLOCK_SIZE // 2)), + list(range(self.CIPHER_BLOCK_SIZE // 2)), + ], + self.CIPHER_BLOCK_SIZE, + ) + + self.add_round() + sboxes_components = [] + for i in range(self.CIPHER_BLOCK_SIZE // 16): + sboxes_components.append( + self.add_SBOX_component( + [first_half.id], + [list(range((8 * i), (8 * i + 8)))], + self.SBOX_BIT_SIZE, + sboxes[3 - (i % 4)], + ) + ) + for i in range(self.CIPHER_BLOCK_SIZE // 16, self.CIPHER_BLOCK_SIZE // 8): + sboxes_components.append( + self.add_SBOX_component( + [second_half.id], + [list(range((8 * i) - 64, (8 * i + 8) - 64))], + self.SBOX_BIT_SIZE, + sboxes[3 - (i % 4)], + ) + ) + + input_ids = [c.id for c in sboxes_components] + input_positions = [list(range(8)) for _ in range(16)] + shift = self.add_linear_layer_component( + input_ids, input_positions, self.CIPHER_BLOCK_SIZE, M.tolist() + ) + g_first = self.add_mix_column_component( + [shift.id], + [[i for i in range(self.CIPHER_BLOCK_SIZE // 2)]], + self.CIPHER_BLOCK_SIZE // 2, + self.kalyna_matrix_description, + ) + g_second = self.add_mix_column_component( + [shift.id], + [[i for i in range(self.CIPHER_BLOCK_SIZE // 2, self.CIPHER_BLOCK_SIZE)]], + self.CIPHER_BLOCK_SIZE // 2, + self.kalyna_matrix_description, + ) + + k_round_first_half = even_round_keys[self.NROUNDS]["first_half"] + k_round_second_half = even_round_keys[self.NROUNDS]["second_half"] + first_half = self.add_MODADD_component( + [g_second.id, k_round_first_half.id], + [ + [j for j in range(self.CIPHER_BLOCK_SIZE // 2)], + [j for j in range(self.KEY_BLOCK_SIZE // 2)], + ], + self.CIPHER_BLOCK_SIZE // 2, + ) + second_half = self.add_MODADD_component( + [g_first.id, k_round_second_half.id], + [ + [j for j in range(self.CIPHER_BLOCK_SIZE // 2)], + [j for j in range(self.KEY_BLOCK_SIZE // 2)], + ], + self.CIPHER_BLOCK_SIZE // 2, + ) + self.add_cipher_output_component( + [first_half.id, second_half.id], + [ + [i for i in range(self.CIPHER_BLOCK_SIZE // 2)], + [i for i in range(self.CIPHER_BLOCK_SIZE // 2)], + ], + self.CIPHER_BLOCK_SIZE, + ) diff --git a/claasp/ciphers/block_ciphers/led_block_cipher.py b/claasp/ciphers/block_ciphers/led_block_cipher.py new file mode 100644 index 000000000..0efe7958c --- /dev/null +++ b/claasp/ciphers/block_ciphers/led_block_cipher.py @@ -0,0 +1,173 @@ +# **************************************************************************** +# Copyright 2023 Technology Innovation Institute +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# **************************************************************************** + + +from claasp.cipher import Cipher +from claasp.DTOs.component_state import ComponentState +from claasp.name_mappings import BLOCK_CIPHER, INPUT_PLAINTEXT, INPUT_KEY + +SBOX = [0xC, 0x5, 0x6, 0xB, 0x9, 0x0, 0xA, 0xD, 0x3, 0xE, 0xF, 0x8, 0x4, 0x7, 0x1, 0x2] + +M = [[0x4, 0x1, 0x2, 0x2], [0x8, 0x6, 0x5, 0x6], [0xB, 0xE, 0xA, 0x9], [0x2, 0x2, 0xF, 0xB]] + +IRREDUCIBLE_POLYNOMIAL = 0x13 + +PARAMETERS_CONFIGURATION_LIST = [ + {"key_bit_size": 64, "number_of_rounds": 32}, + {"key_bit_size": 128, "number_of_rounds": 48}, +] + + +def get_round_register(round_number): + round_register = [0, 0, 0, 0, 0, 0] + for _ in range(round_number + 1): + new_round_bit = round_register[0] ^ round_register[1] ^ 1 + round_register = round_register[1:] + [new_round_bit] + return round_register + + +class LedBlockCipher(Cipher): + """ + LED Block Cipher implementation + + Note that this implementation do not use the number of steps as a parameter, + instead it derives it from the number of rounds (number_of_steps = number_of_rounds // 4). + """ + + def __init__(self, key_bit_size=64, number_of_rounds=32): + assert number_of_rounds % 4 == 0, "Number of rounds must be a multiple of 4." + self.block_bit_size = 64 + self.key_bit_size = key_bit_size + self.number_of_steps = number_of_rounds // 4 + + super().__init__( + family_name="led", + cipher_type=BLOCK_CIPHER, + cipher_inputs=[INPUT_PLAINTEXT, INPUT_KEY], + cipher_inputs_bit_size=[self.block_bit_size, self.key_bit_size], + cipher_output_bit_size=self.block_bit_size, + ) + + state = ComponentState([INPUT_PLAINTEXT], [list(range(self.block_bit_size))]) + if key_bit_size == self.block_bit_size: + range_0 = list(range(self.block_bit_size)) + range_1 = list(range(self.block_bit_size)) + elif key_bit_size == 5 * self.block_bit_size // 4: + range_0 = list(range(self.block_bit_size)) + range_1 = list(range(self.block_bit_size, self.key_bit_size)) + list(range(3 * self.block_bit_size // 4)) + elif key_bit_size == 2 * self.block_bit_size: + range_0 = list(range(self.block_bit_size)) + range_1 = list(range(self.block_bit_size, self.key_bit_size)) + key = [ComponentState([INPUT_KEY], [range_0]), ComponentState([INPUT_KEY], [range_1])] + + round_number = 0 + key_index = 0 + + self.add_round() + state = self.add_round_key(state, key[key_index]) + for step_number in range(self.number_of_steps): + for _ in range(4): + state = self.add_constants(state, round_number) + state = self.sub_cells(state) + state = self.shift_rows(state) + state = self.mix_columns(state) + round_number += 1 + state = self.add_round_key(state, key[key_index]) + key_index = (key_index + 1) % 2 + if step_number != self.number_of_steps - 1: + self.add_round_output_component(state.id, state.input_bit_positions, self.block_bit_size) + self.add_round() + else: + self.add_cipher_output_component(state.id, state.input_bit_positions, self.block_bit_size) + + def get_round_constant(self, round_number): + register = get_round_register(round_number) + rc_high = "".join(map(str, register[0:3])) + rc_low = "".join(map(str, register[3:6])) + rc_high_number = int(rc_high, 2) + rc_low_number = int(rc_low, 2) + + ks_high = 4 if self.key_bit_size == 64 else 8 + + constant = ( + ks_high << 60 + | rc_high_number << 56 + | (ks_high ^ 1) << 44 + | rc_low_number << 40 + | 2 << 28 + | rc_high_number << 24 + | 3 << 12 + | rc_low_number << 8 + ) + + return constant + + def add_constants(self, state, round_number): + constant = self.get_round_constant(round_number) + const_id = self.add_constant_component(self.block_bit_size, constant).id + + xor_id = self.add_XOR_component( + [*state.id, const_id], + [*state.input_bit_positions, list(range(self.block_bit_size))], + self.block_bit_size, + ).id + return ComponentState([xor_id], [list(range(self.block_bit_size))]) + + def sub_cells(self, state): + sbox_out_ids = [] + for i in range(16): + id_sbox = self.add_SBOX_component(state.id, [state.input_bit_positions[0][i * 4 : (i + 1) * 4]], 4, SBOX).id + sbox_out_ids.append(id_sbox) + return ComponentState(sbox_out_ids, [list(range(4))] * 16) + + def shift_rows(self, state): + shifted = [] + for i in range(4): + row_data = state.id[i * 4 : (i + 1) * 4] + row_data = row_data[i:] + row_data[:i] + shifted.extend(row_data) + + return ComponentState(shifted, [list(range(4))] * 16) + + def mix_columns(self, state): + mix_columns_ids = [] + + for col in range(4): + col_ids = [state.id[row * 4 + col] for row in range(4)] + col_pos = [state.input_bit_positions[row * 4 + col] for row in range(4)] + + mix_columns_id = self.add_mix_column_component( + col_ids, col_pos, self.block_bit_size // 4, [M, IRREDUCIBLE_POLYNOMIAL, 4] + ).id + + mix_columns_ids.append(mix_columns_id) + + new_state = [] + new_positions = [] + + for i in range(4): + new_state.extend(mix_columns_ids) + new_positions.extend(list(range(4 * i, 4 * (i + 1))) for _ in range(4)) + + return ComponentState(new_state, new_positions) + + def add_round_key(self, state, key): + xor_id = self.add_XOR_component( + [*state.id, *key.id], [*state.input_bit_positions, *key.input_bit_positions], self.block_bit_size + ).id + + return ComponentState([xor_id], [list(range(self.block_bit_size))]) diff --git a/claasp/ciphers/block_ciphers/mantis_block_cipher.py b/claasp/ciphers/block_ciphers/mantis_block_cipher.py new file mode 100644 index 000000000..8348656dc --- /dev/null +++ b/claasp/ciphers/block_ciphers/mantis_block_cipher.py @@ -0,0 +1,345 @@ +# **************************************************************************** +# Copyright 2023 Technology Innovation Institute +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# **************************************************************************** + +from claasp.cipher import Cipher +from claasp.name_mappings import INPUT_KEY, INPUT_PLAINTEXT, INPUT_TWEAK +from claasp.utils.utils import extract_inputs +import numpy as np + +MANTIS_ROUND_CONSTANTS = [ + 0x13198a2e03707344, + 0xa4093822299f31d0, + 0x082efa98ec4e6c89, + 0x452821e638d01377, + 0xbe5466cf34e90c6c, + 0xc0ac29b7c97c50dd, + 0x3f84d5b5b5470917, + 0x9216d5d98979fb1b +] + +MANTIS_ALPHA = 0x243f6a8885a308d3 + +MANTIS_SBOX = [0xC, 0xA, 0xD, 0x3, 0xE, 0xB, 0xF, 0x7, + 0x8, 0x9, 0x1, 0x5, 0x0, 0x2, 0x4, 0x6] + +TWEAK_PERMUTATION = [6, 5, 14, 15, 0, 1, 2, 3, 7, 12, 13, 4, 8, 9, 10, 11] +TWEAK_PERMUTATION_INV = [4, 5, 6, 7, 11, 1, 0, 8, 12, 13, 14, 15, 9, 10, 2, 3] + +CELL_PERMUTATION = [0, 11, 6, 13, 10, 1, 12, 7, 5, 14, 3, 8, 15, 4, 9, 2] +CELL_PERMUTATION_INV = [0, 5, 15, 10, 13, 8, 2, 7, 11, 14, 4, 1, 6, 3, 9, 12] + +MIDORI_M = [[0, 1, 1, 1], [1, 0, 1, 1], [1, 1, 0, 1], [1, 1, 1, 0]] + + +class MantisBlockCipher(Cipher): + + """ + Return a cipher object of the MANTIS Block Cipher. + + The MANTIS cipher is a lightweight tweakable block cipher designed for low-latency applications. + It operates on 64-bit blocks with a 128-bit key and a 64-bit tweak. + INPUT: + + - ``number_of_rounds`` -- **integer** (default: `6`); + number of rounds of the cipher. Must be one of [5, 6, 7, 8]. + + EXAMPLES:: + + sage: from claasp.ciphers.block_ciphers.mantis_block_cipher import MantisBlockCipher + sage: mantis = MantisBlockCipher(number_of_rounds=6) + sage: plaintext = 0xd6522035c1c0c6c1 + sage: key = 0x92f09952c625e3e9d7a060f714c0292b + sage: tweak = 0xba912e6f1055fed2 + sage: ciphertext = 0x60e43457311936fd + sage: mantis.evaluate([plaintext, key, tweak]) == ciphertext + True + """ + + def __init__(self, number_of_rounds=6): + self.block_bit_size = 64 + self.key_bit_size = 128 + self.tweak_bit_size = 64 + + super().__init__( + family_name="mantis", + cipher_type="block_cipher", + cipher_inputs=[INPUT_PLAINTEXT, INPUT_KEY, INPUT_TWEAK], + cipher_inputs_bit_size=[ + self.block_bit_size, + self.key_bit_size, + self.tweak_bit_size], + cipher_output_bit_size=self.block_bit_size + ) + self.add_round() + current_state = self.add_pre_whitening() + current_state, current_tweak = self.add_forward_rounds( + current_state, number_of_rounds) + current_state = self.add_middle_layer(current_state) + current_state, current_tweak = self.add_backward_rounds( + current_state, number_of_rounds, current_tweak) + ciphertext = self.add_post_whitening(current_state, current_tweak) + self.add_cipher_output_component( + [ciphertext], [list(range(64))], 64) + + def apply_sbox_layer(self, current_state): + sbox_id_list = [] + sbox_bit_positions = [] + for i in range(16): + data_id_list, data_bit_positions = extract_inputs( + [current_state], [list(range(64))], list(range(i * 4, (i + 1) * 4))) + sbox_output = self.add_SBOX_component( + data_id_list, data_bit_positions, 4, MANTIS_SBOX) + sbox_id_list.append(sbox_output.id) + sbox_bit_positions.append(list(range(4))) + + zero_constant = self.add_constant_component(64, 0) + concatenated_state = self.add_XOR_component( + sbox_id_list + [zero_constant.id], + sbox_bit_positions + [list(range(64))], + 64 + ) + return concatenated_state.id + + def add_round_constant(self, sbox_id_list, sbox_bit_positions, round_idx): + constant = self.add_constant_component( + 64, MANTIS_ROUND_CONSTANTS[round_idx]) + constant_xor = self.add_XOR_component( + sbox_id_list + [constant.id], + sbox_bit_positions + [list(range(64))], + 64 + ) + return constant_xor.id + + def add_tweakey(self, current_tweak): + permuted_tweak = self.add_word_permutation_component( + [current_tweak], [list(range(64))], 64, TWEAK_PERMUTATION, 4 + ).id + + tweakey = self.add_XOR_component( + [permuted_tweak, INPUT_KEY], + [list(range(64)), list(range(64, 128))], 64 + ) + return tweakey.id, permuted_tweak + + def permute_cells(self, current_state): + permuted_state = self.add_word_permutation_component( + [current_state], [list(range(64))], 64, CELL_PERMUTATION, 4 + ) + return permuted_state.id + + def apply_mixcolumns(self, current_state): + column_size = 16 + num_columns = 4 + + groups = [] + for col in range(num_columns): + indices = [] + for row in range(num_columns): + nibble_index = row * 4 + col + start_bit = nibble_index * 4 + indices.extend(range(start_bit, start_bit + 4)) + groups.append(indices) + + mix_column_ids = [] + for i in range(num_columns): + data_id_list, data_bit_positions = extract_inputs( + [current_state], [list(range(64))], groups[i]) + + mix_output = self.add_mix_column_component( + data_id_list, data_bit_positions, column_size, [ + MIDORI_M, 19, 4] + ) + mix_column_ids.append(mix_output.id) + + zero_constant = self.add_constant_component(64, 0) + concatenated_state = self.add_XOR_component( + mix_column_ids + [zero_constant.id], + [list(range(column_size)) + for _ in range(num_columns)] + [list(range(64))], + 64 + ) + + column_to_row_permutation = [ + 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15 + ] + mixed_state = self.add_word_permutation_component( + [concatenated_state.id], [ + list(range(64))], 64, column_to_row_permutation, 4 + ) + return mixed_state.id + + def permute_cells_inverse(self, current_state): + permuted_state = self.add_word_permutation_component( + [current_state], [list(range(64))], 64, CELL_PERMUTATION_INV, 4 + ) + return permuted_state.id + + def add_tweakey_backward( + self, current_state, current_tweak, round_idx, num_rounds): + if round_idx != num_rounds - 1: + current_tweak = self.add_word_permutation_component( + [current_tweak], + [list(range(64))], + 64, + TWEAK_PERMUTATION_INV, + 4 + ).id + alpha_component = self.add_constant_component(64, MANTIS_ALPHA) + k1_xor_alpha = self.add_XOR_component( + [INPUT_KEY, alpha_component.id], + [list(range(64, 128)), list(range(64))], + 64 + ) + tweakey = self.add_XOR_component( + [current_tweak, k1_xor_alpha.id], + [list(range(64)), list(range(64))], + 64 + ) + result = self.add_XOR_component( + [current_state, tweakey.id], + [list(range(64)), list(range(64))], + 64 + ) + return result.id, current_tweak + + def add_round_constant_direct(self, current_state, round_idx): + constant = self.add_constant_component( + 64, MANTIS_ROUND_CONSTANTS[round_idx]) + constant_xor = self.add_XOR_component( + [current_state, constant.id], + [list(range(64)), list(range(64))], + 64 + ) + return constant_xor.id + + def add_pre_whitening(self): + k1_xor_tweak = self.add_XOR_component( + [INPUT_KEY, INPUT_TWEAK], [list(range(64, 128)), list(range(64))], 64) + m_xor_k0 = self.add_XOR_component([INPUT_PLAINTEXT, INPUT_KEY], [ + list(range(64)), list(range(64))], 64) + pre_whitening = self.add_XOR_component([m_xor_k0.id, k1_xor_tweak.id], [ + list(range(64)), list(range(64))], 64) + return pre_whitening.id + + def add_forward_rounds(self, current_state, num_rounds): + current_tweak = INPUT_TWEAK + + for round_idx in range(num_rounds): + sbox_id_list = [] + sbox_bit_positions = [] + for i in range(16): + data_id_list, data_bit_positions = extract_inputs( + [current_state], [list(range(64))], list(range(i * 4, (i + 1) * 4))) + sbox_output = self.add_SBOX_component( + data_id_list, data_bit_positions, 4, MANTIS_SBOX) + sbox_id_list.append(sbox_output.id) + sbox_bit_positions.append(list(range(4))) + + current_state = self.add_round_constant( + sbox_id_list, sbox_bit_positions, round_idx) + + tweakey_id, current_tweak = self.add_tweakey(current_tweak) + current_state = self.add_XOR_component( + [current_state, tweakey_id], + [list(range(64)), list(range(64))], 64 + ).id + + current_state = self.permute_cells(current_state) + + current_state = self.apply_mixcolumns(current_state) + + self.add_round_output_component( + [current_state], [list(range(64))], 64) + self.add_round() + + return current_state, current_tweak + + def add_middle_layer(self, current_state): + current_state = self.apply_sbox_layer(current_state) + + current_state = self.apply_mixcolumns(current_state) + + current_state = self.apply_sbox_layer(current_state) + + return current_state + + def add_backward_rounds(self, current_state, num_rounds, current_tweak): + for round_idx in range(num_rounds - 1, -1, -1): + current_state = self.apply_mixcolumns(current_state) + + current_state = self.permute_cells_inverse(current_state) + + current_state, current_tweak = self.add_tweakey_backward( + current_state, current_tweak, round_idx, num_rounds + ) + + current_state = self.add_round_constant_direct( + current_state, round_idx) + + current_state = self.apply_sbox_layer(current_state) + + self.add_round_output_component( + [current_state], [list(range(64))], 64) + + if round_idx != 0: + self.add_round() + + return current_state, current_tweak + + def add_post_whitening(self, current_state, current_tweak): + current_tweak = self.add_word_permutation_component( + [current_tweak], + [list(range(64))], + 64, + TWEAK_PERMUTATION_INV, + 4 + ).id + k0_rot1 = self.add_rotate_component( + [INPUT_KEY], [list(range(64))], 64, 1 + ) + k0_sh63 = self.add_SHIFT_component( + [INPUT_KEY], [list(range(64))], 64, 63 + ) + k0_prime = self.add_XOR_component( + [k0_rot1.id, k0_sh63.id], + [list(range(64)), list(range(64))], + 64 + ) + alpha_component = self.add_constant_component(64, MANTIS_ALPHA) + k1_xor_alpha = self.add_XOR_component( + [INPUT_KEY, alpha_component.id], + [list(range(64, 128)), list(range(64))], + 64 + ) + + k1_alpha_xor_tweak = self.add_XOR_component( + [k1_xor_alpha.id, current_tweak], + [list(range(64)), list(range(64))], + 64 + ) + state_xor_tweakey = self.add_XOR_component( + [current_state, k1_alpha_xor_tweak.id], + [list(range(64)), list(range(64))], + 64 + ) + post_whitening = self.add_XOR_component( + [state_xor_tweakey.id, k0_prime.id], + [list(range(64)), list(range(64))], + 64 + ) + + return post_whitening.id diff --git a/claasp/ciphers/block_ciphers/skipjack_block_cipher.py b/claasp/ciphers/block_ciphers/skipjack_block_cipher.py new file mode 100644 index 000000000..fb55253e9 --- /dev/null +++ b/claasp/ciphers/block_ciphers/skipjack_block_cipher.py @@ -0,0 +1,265 @@ +# **************************************************************************** +# Copyright 2023 Technology Innovation Institute +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# **************************************************************************** + + +""" +SKIPJACK block cipher [NIST1998] +https://csrc.nist.gov/csrc/media/projects/cryptographic-algorithm-validation-program/documents/skipjack/skipjack.pdf + +Cipher Specifications: +- Block size: 64 bits (4 words of 16 bits each) +- Key size: 80 bits (10 bytes) +- Rounds: 32 (alternating Rule A and Rule B) +- Structure: + * Rounds 1-8: Rule A + * Rounds 9-16: Rule B + * Rounds 17-24: Rule A + * Rounds 25-32: Rule B +""" + +from claasp.cipher import Cipher +from claasp.DTOs.component_state import ComponentState +from claasp.name_mappings import INPUT_PLAINTEXT, INPUT_KEY + + +# SKIPJACK F-Table (S-box 8x8 bits) +SKIPJACK_FTABLE = [ + 0xa3, 0xd7, 0x09, 0x83, 0xf8, 0x48, 0xf6, 0xf4, 0xb3, 0x21, 0x15, 0x78, 0x99, 0xb1, 0xaf, 0xf9, + 0xe7, 0x2d, 0x4d, 0x8a, 0xce, 0x4c, 0xca, 0x2e, 0x52, 0x95, 0xd9, 0x1e, 0x4e, 0x38, 0x44, 0x28, + 0x0a, 0xdf, 0x02, 0xa0, 0x17, 0xf1, 0x60, 0x68, 0x12, 0xb7, 0x7a, 0xc3, 0xe9, 0xfa, 0x3d, 0x53, + 0x96, 0x84, 0x6b, 0xba, 0xf2, 0x63, 0x9a, 0x19, 0x7c, 0xae, 0xe5, 0xf5, 0xf7, 0x16, 0x6a, 0xa2, + 0x39, 0xb6, 0x7b, 0x0f, 0xc1, 0x93, 0x81, 0x1b, 0xee, 0xb4, 0x1a, 0xea, 0xd0, 0x91, 0x2f, 0xb8, + 0x55, 0xb9, 0xda, 0x85, 0x3f, 0x41, 0xbf, 0xe0, 0x5a, 0x58, 0x80, 0x5f, 0x66, 0x0b, 0xd8, 0x90, + 0x35, 0xd5, 0xc0, 0xa7, 0x33, 0x06, 0x65, 0x69, 0x45, 0x00, 0x94, 0x56, 0x6d, 0x98, 0x9b, 0x76, + 0x97, 0xfc, 0xb2, 0xc2, 0xb0, 0xfe, 0xdb, 0x20, 0xe1, 0xeb, 0xd6, 0xe4, 0xdd, 0x47, 0x4a, 0x1d, + 0x42, 0xed, 0x9e, 0x6e, 0x49, 0x3c, 0xcd, 0x43, 0x27, 0xd2, 0x07, 0xd4, 0xde, 0xc7, 0x67, 0x18, + 0x89, 0xcb, 0x30, 0x1f, 0x8d, 0xc6, 0x8f, 0xaa, 0xc8, 0x74, 0xdc, 0xc9, 0x5d, 0x5c, 0x31, 0xa4, + 0x70, 0x88, 0x61, 0x2c, 0x9f, 0x0d, 0x2b, 0x87, 0x50, 0x82, 0x54, 0x64, 0x26, 0x7d, 0x03, 0x40, + 0x34, 0x4b, 0x1c, 0x73, 0xd1, 0xc4, 0xfd, 0x3b, 0xcc, 0xfb, 0x7f, 0xab, 0xe6, 0x3e, 0x5b, 0xa5, + 0xad, 0x04, 0x23, 0x9c, 0x14, 0x51, 0x22, 0xf0, 0x29, 0x79, 0x71, 0x7e, 0xff, 0x8c, 0x0e, 0xe2, + 0x0c, 0xef, 0xbc, 0x72, 0x75, 0x6f, 0x37, 0xa1, 0xec, 0xd3, 0x8e, 0x62, 0x8b, 0x86, 0x10, 0xe8, + 0x08, 0x77, 0x11, 0xbe, 0x92, 0x4f, 0x24, 0xc5, 0x32, 0x36, 0x9d, 0xcf, 0xf3, 0xa6, 0xbb, 0xac, + 0x5e, 0x6c, 0xa9, 0x13, 0x57, 0x25, 0xb5, 0xe3, 0xbd, 0xa8, 0x3a, 0x01, 0x05, 0x59, 0x2a, 0x46 +] + + +class SkipjackBlockCipher(Cipher): + """ + SKIPJACK NSA - 64-bit block, 80-bit key, 32 rounds + + Based on official CLAASP documentation: + - add_SBOX_component(input_id_links, input_bit_positions, output_bit_size, description) + - add_XOR_component(input_id_links, input_bit_positions, output_bit_size) + - add_concatenate_component(input_id_links, input_bit_positions, output_bit_size) + + Where: + - input_id_links: list of strings ['comp1', 'comp2'] + - input_bit_positions: list of lists [[bits1], [bits2]] + """ + + def __init__(self, number_of_rounds=32): + self.WORD_SIZE = 16 + self.BLOCK_SIZE = 64 + self.KEY_SIZE = 80 + + super().__init__( + family_name="skipjack", + cipher_type="block_cipher", + cipher_inputs=[INPUT_PLAINTEXT, INPUT_KEY], + cipher_inputs_bit_size=[self.BLOCK_SIZE, self.KEY_SIZE], + cipher_output_bit_size=self.BLOCK_SIZE, + ) + + # Initialize state: w1, w2, w3, w4 from plaintext + # Plaintext format (big-endian): bits 0-15 (w1), 16-31 (w2), 32-47 (w3), 48-63 (w4) + w1, w2, w3, w4 = self._initialize_state() + + # 32 rounds + for round_number in range(number_of_rounds): + self.add_round() + counter = round_number + 1 + + if (1 <= counter <= 8) or (17 <= counter <= 24): + # Rule A + w1, w2, w3, w4 = self._rule_a(w1, w2, w3, w4, counter, round_number) + else: + # Rule B + w1, w2, w3, w4 = self._rule_b(w1, w2, w3, w4, counter, round_number) + + self._add_round_output(w1, w2, w3, w4, round_number, number_of_rounds) + + def _initialize_state(self): + """ + Extract w1, w2, w3, w4 from plaintext. + + Plaintext = 0x33221100ddccbbaa (64 bits) + Conventional big-endian: + w1 = 0x3322 (bits 0-15, most significant) + w2 = 0x1100 (bits 16-31) + w3 = 0xddcc (bits 32-47) + w4 = 0xbbaa (bits 48-63, least significant) + """ + w1 = ComponentState([INPUT_PLAINTEXT], [list(range(0, 16))]) + w2 = ComponentState([INPUT_PLAINTEXT], [list(range(16, 32))]) + w3 = ComponentState([INPUT_PLAINTEXT], [list(range(32, 48))]) + w4 = ComponentState([INPUT_PLAINTEXT], [list(range(48, 64))]) + return w1, w2, w3, w4 + + def _g_permutation(self, word, step): + """ + G function: 4-round Feistel network with F-table SBOX. + + Algorithm: + - Input: 16-bit word + - g[0] = high byte, g[1] = low byte + - 4 Feistel rounds: g[i+2] = F[g[i+1] XOR key[j]] XOR g[i] + - Output: (g[4] << 8) | g[5] + + CLAASP convention (big-endian): + - bits [0:7] = high byte = g[0] + - bits [8:15] = low byte = g[1] + """ + # Extract g[0] (high byte, bits 0-7) and g[1] (low byte, bits 8-15) + g0 = ComponentState(word.id, [word.input_bit_positions[0][0:8]]) + g1 = ComponentState(word.id, [word.input_bit_positions[0][8:16]]) + + g_prev = g0 # g[0] + g_out = g1 # g[1] + + for feistel_round in range(4): + # Key index: j = (4*step + feistel_round) % 10 + key_index = (4 * step + feistel_round) % 10 + + # Extract corresponding key byte (KEY = 80 bits, 10 bytes in big-endian) + # byte[0] = bits 0-7, ..., byte[9] = bits 72-79 + bit_start = key_index * 8 + bit_end = (key_index + 1) * 8 + key_byte = ComponentState([INPUT_KEY], [list(range(bit_start, bit_end))]) + + # XOR: g_out XOR key_byte + self.add_XOR_component( + [g_out.id[0], key_byte.id[0]], + [g_out.input_bit_positions[0], key_byte.input_bit_positions[0]], + 8 + ) + xor_result = ComponentState([self.get_current_component_id()], [list(range(8))]) + + # SBOX: F[xor_result] + self.add_SBOX_component( + [xor_result.id[0]], + [xor_result.input_bit_positions[0]], + 8, + SKIPJACK_FTABLE + ) + sbox_result = ComponentState([self.get_current_component_id()], [list(range(8))]) + + # XOR: sbox_result XOR g_prev + self.add_XOR_component( + [sbox_result.id[0], g_prev.id[0]], + [sbox_result.input_bit_positions[0], g_prev.input_bit_positions[0]], + 8 + ) + g_new = ComponentState([self.get_current_component_id()], [list(range(8))]) + + # Update for next iteration + g_prev = g_out + g_out = g_new + + # At the end: g_prev = g[4], g_out = g[5] + # Result: (g[4] << 8) | g[5] + # In CLAASP concatenate: first input goes to MSB (high byte position) + # So we need: [g[4], g[5]] where g[4] is high byte + self.add_concatenate_component( + [g_prev.id[0], g_out.id[0]], + [g_prev.input_bit_positions[0], g_out.input_bit_positions[0]], + 16 + ) + return ComponentState([self.get_current_component_id()], [list(range(16))]) + + def _rule_a(self, w1, w2, w3, w4, counter, round_number): + """Rule A: w1' = G(w1) XOR w4 XOR counter, w2' = G(w1), w3' = w2, w4' = w3""" + g_output = self._g_permutation(w1, round_number) + + # Counter + self.add_constant_component(16, counter) + counter_comp = ComponentState([self.get_current_component_id()], [list(range(16))]) + + # w1' = G(w1) XOR w4 XOR counter + self.add_XOR_component( + [g_output.id[0], w4.id[0]], + [g_output.input_bit_positions[0], w4.input_bit_positions[0]], + 16 + ) + temp = ComponentState([self.get_current_component_id()], [list(range(16))]) + + self.add_XOR_component( + [temp.id[0], counter_comp.id[0]], + [temp.input_bit_positions[0], counter_comp.input_bit_positions[0]], + 16 + ) + w1_new = ComponentState([self.get_current_component_id()], [list(range(16))]) + + return w1_new, g_output, w2, w3 + + def _rule_b(self, w1, w2, w3, w4, counter, round_number): + """Rule B: w1' = w4, w2' = G(w1), w3' = w1 XOR w2 XOR counter, w4' = w3""" + g_output = self._g_permutation(w1, round_number) + + # Counter + self.add_constant_component(16, counter) + counter_comp = ComponentState([self.get_current_component_id()], [list(range(16))]) + + # w3' = w1 XOR w2 XOR counter + self.add_XOR_component( + [w1.id[0], w2.id[0]], + [w1.input_bit_positions[0], w2.input_bit_positions[0]], + 16 + ) + temp = ComponentState([self.get_current_component_id()], [list(range(16))]) + + self.add_XOR_component( + [temp.id[0], counter_comp.id[0]], + [temp.input_bit_positions[0], counter_comp.input_bit_positions[0]], + 16 + ) + w3_new = ComponentState([self.get_current_component_id()], [list(range(16))]) + + return w4, g_output, w3_new, w3 + + def _add_round_output(self, w1, w2, w3, w4, round_number, total_rounds): + """Add round output: concatenate w1||w2||w3||w4""" + self.add_concatenate_component( + [w1.id[0], w2.id[0], w3.id[0], w4.id[0]], + [w1.input_bit_positions[0], w2.input_bit_positions[0], w3.input_bit_positions[0], w4.input_bit_positions[0]], + 64 + ) + + if round_number == total_rounds - 1: + # Final cipher output + self.add_cipher_output_component( + [self.get_current_component_id()], + [list(range(64))], + 64 + ) + else: + # Intermediate output + self.add_round_output_component( + [self.get_current_component_id()], + [list(range(64))], + 64 + ) + \ No newline at end of file diff --git a/claasp/ciphers/block_ciphers/sm4_block_cipher.py b/claasp/ciphers/block_ciphers/sm4_block_cipher.py new file mode 100644 index 000000000..aedcecbfe --- /dev/null +++ b/claasp/ciphers/block_ciphers/sm4_block_cipher.py @@ -0,0 +1,294 @@ +# **************************************************************************** +# Copyright 2023 Technology Innovation Institute +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# **************************************************************************** +from claasp.cipher import Cipher +from claasp.name_mappings import INPUT_KEY, INPUT_PLAINTEXT + + +class SM4(Cipher): + """ + Return a cipher object of SM4 Block Cipher. + The technical specifications along with the test vectors can be found here: http://www.gmbz.org.cn/upload/2025-01-23/1737625646289030731.pdf + + EXAMPLES:: + + sage: from claasp.ciphers.block_ciphers.sm4_block_cipher import SM4 + sage: sm4 = SM4() + sage: key = 0x0123456789ABCDEFFEDCBA9876543210 + sage: plaintext = 0x0123456789ABCDEFFEDCBA9876543210 + sage: ciphertext = 0x681EDF34D206965E86B3E94F536E4246 + sage: sm4.evaluate([key, plaintext]) == ciphertext + True + """ + + def __init__(self, number_of_rounds=32, word_size=8, state_size=8): + # cipher dictionary initialize + self.CIPHER_BLOCK_SIZE = 128 + self.KEY_BLOCK_SIZE = 128 + self.NROUNDS = number_of_rounds + self.NUM_ROWS = state_size + self.SBOX_BIT_SIZE = word_size + self.ROW_SIZE = state_size * word_size + self.round_keys = {} + + super().__init__( + family_name="sm4_block_cipher", + cipher_type="BLOCK_CIPHER", + cipher_inputs=[INPUT_KEY, INPUT_PLAINTEXT], + cipher_inputs_bit_size=[self.KEY_BLOCK_SIZE, self.CIPHER_BLOCK_SIZE], + cipher_output_bit_size=self.CIPHER_BLOCK_SIZE, + ) + # fmt: off + self.sbox = [ + 0xd6, 0x90, 0xe9, 0xfe, 0xcc, 0xe1, 0x3d, 0xb7, 0x16, 0xb6, 0x14, 0xc2, 0x28, 0xfb, 0x2c, 0x05, + 0x2b, 0x67, 0x9a, 0x76, 0x2a, 0xbe, 0x04, 0xc3, 0xaa, 0x44, 0x13, 0x26, 0x49, 0x86, 0x06, 0x99, + 0x9c, 0x42, 0x50, 0xf4, 0x91, 0xef, 0x98, 0x7a, 0x33, 0x54, 0x0b, 0x43, 0xed, 0xcf, 0xac, 0x62, + 0xe4, 0xb3, 0x1c, 0xa9, 0xc9, 0x08, 0xe8, 0x95, 0x80, 0xdf, 0x94, 0xfa, 0x75, 0x8f, 0x3f, 0xa6, + 0x47, 0x07, 0xa7, 0xfc, 0xf3, 0x73, 0x17, 0xba, 0x83, 0x59, 0x3c, 0x19, 0xe6, 0x85, 0x4f, 0xa8, + 0x68, 0x6b, 0x81, 0xb2, 0x71, 0x64, 0xda, 0x8b, 0xf8, 0xeb, 0x0f, 0x4b, 0x70, 0x56, 0x9d, 0x35, + 0x1e, 0x24, 0x0e, 0x5e, 0x63, 0x58, 0xd1, 0xa2, 0x25, 0x22, 0x7c, 0x3b, 0x01, 0x21, 0x78, 0x87, + 0xd4, 0x00, 0x46, 0x57, 0x9f, 0xd3, 0x27, 0x52, 0x4c, 0x36, 0x02, 0xe7, 0xa0, 0xc4, 0xc8, 0x9e, + 0xea, 0xbf, 0x8a, 0xd2, 0x40, 0xc7, 0x38, 0xb5, 0xa3, 0xf7, 0xf2, 0xce, 0xf9, 0x61, 0x15, 0xa1, + 0xe0, 0xae, 0x5d, 0xa4, 0x9b, 0x34, 0x1a, 0x55, 0xad, 0x93, 0x32, 0x30, 0xf5, 0x8c, 0xb1, 0xe3, + 0x1d, 0xf6, 0xe2, 0x2e, 0x82, 0x66, 0xca, 0x60, 0xc0, 0x29, 0x23, 0xab, 0x0d, 0x53, 0x4e, 0x6f, + 0xd5, 0xdb, 0x37, 0x45, 0xde, 0xfd, 0x8e, 0x2f, 0x03, 0xff, 0x6a, 0x72, 0x6d, 0x6c, 0x5b, 0x51, + 0x8d, 0x1b, 0xaf, 0x92, 0xbb, 0xdd, 0xbc, 0x7f, 0x11, 0xd9, 0x5c, 0x41, 0x1f, 0x10, 0x5a, 0xd8, + 0x0a, 0xc1, 0x31, 0x88, 0xa5, 0xcd, 0x7b, 0xbd, 0x2d, 0x74, 0xd0, 0x12, 0xb8, 0xe5, 0xb4, 0xb0, + 0x89, 0x69, 0x97, 0x4a, 0x0c, 0x96, 0x77, 0x7e, 0x65, 0xb9, 0xf1, 0x09, 0xc5, 0x6e, 0xc6, 0x84, + 0x18, 0xf0, 0x7d, 0xec, 0x3a, 0xdc, 0x4d, 0x20, 0x79, 0xee, 0x5f, 0x3e, 0xd7, 0xcb, 0x39, 0x48, + ] + self.FK = [ + 0xa3b1bac6, 0x56aa3350, 0x677d9197, 0xb27022dc, + ] + + self.CK = [ + 0x00070e15, 0x1c232a31, 0x383f464d, 0x545b6269, + 0x70777e85, 0x8c939aa1, 0xa8afb6bd, 0xc4cbd2d9, + 0xe0e7eef5, 0xfc030a11, 0x181f262d, 0x343b4249, + 0x50575e65, 0x6c737a81, 0x888f969d, 0xa4abb2b9, + 0xc0c7ced5, 0xdce3eaf1, 0xf8ff060d, 0x141b2229, + 0x30373e45, 0x4c535a61, 0x686f767d, 0x848b9299, + 0xa0a7aeb5, 0xbcc3cad1, 0xd8dfe6ed, 0xf4fb0209, + 0x10171e25, 0x2c333a41, 0x484f565d, 0x646b7279, + ] + # fmt: on + + word_bits = word_size * 4 + + def encrypt_block(self, INPUT_PLAINTEXT, INPUT_KEY, word_bits): + K = [] + for idx in range(4): + FK_const = self.add_constant_component(word_bits, self.FK[idx]) + Ki = self.add_XOR_component( + [INPUT_KEY, FK_const.id], + [ + [i for i in range(idx * 32, (idx + 1) * 32)], + [i for i in range(32)], + ], + self.CIPHER_BLOCK_SIZE // 4, + ) + K.append(Ki) + + X = [ + {"id": INPUT_PLAINTEXT, "bit_position": list(range(32))}, + {"id": INPUT_PLAINTEXT, "bit_position": list(range(32, 64))}, + {"id": INPUT_PLAINTEXT, "bit_position": list(range(64, 96))}, + {"id": INPUT_PLAINTEXT, "bit_position": list(range(96, 128))}, + ] + + for i in range(self.NROUNDS): + CK_const = self.add_constant_component(word_bits, self.CK[i]) + + t1 = self.add_XOR_component( + [K[i + 3].id, CK_const.id], + [ + [j for j in range(self.KEY_BLOCK_SIZE // 4)], + [j for j in range(self.CIPHER_BLOCK_SIZE // 4)], + ], + self.CIPHER_BLOCK_SIZE // 4, + ) + t2 = self.add_XOR_component( + [t1.id, K[i + 2].id], + [ + [j for j in range(self.CIPHER_BLOCK_SIZE // 4)], + [j for j in range(self.KEY_BLOCK_SIZE // 4)], + ], + self.CIPHER_BLOCK_SIZE // 4, + ) + temp_k = self.add_XOR_component( + [t2.id, K[i + 1].id], + [ + [j for j in range(self.CIPHER_BLOCK_SIZE // 4)], + [j for j in range(self.KEY_BLOCK_SIZE // 4)], + ], + self.CIPHER_BLOCK_SIZE // 4, + ) + + sboxes_k = [ + self.add_SBOX_component( + [temp_k.id], [[b * 8 + k for k in range(8)]], 8, self.sbox + ) + for b in range(4) + ] + ids_k = [c.id for c in sboxes_k] + pos_k = [list(range(8)) for _ in range(4)] + + rot13 = self.add_rotate_component(ids_k, pos_k, word_bits, -13) + rot23 = self.add_rotate_component(ids_k, pos_k, word_bits, -23) + xor_rot = self.add_XOR_component( + [rot13.id, rot23.id], + [ + [j for j in range(self.CIPHER_BLOCK_SIZE // 4)], + [j for j in range(self.CIPHER_BLOCK_SIZE // 4)], + ], + self.CIPHER_BLOCK_SIZE // 4, + ) + + L_prime = self.add_XOR_component( + [sboxes_k[b].id for b in range(4)] + [xor_rot.id], + [list(range(8)) for _ in range(4)] + [list(range(32))], + 32, + ) + + Ki4 = self.add_XOR_component( + [L_prime.id, K[i].id], + [ + [j for j in range(self.KEY_BLOCK_SIZE // 4)], + [j for j in range(self.KEY_BLOCK_SIZE // 4)], + ], + self.CIPHER_BLOCK_SIZE // 4, + ) + K.append(Ki4) + rk_i = Ki4 + + t1 = self.add_XOR_component( + [X[i + 3]["id"], rk_i.id], + [ + X[i + 3]["bit_position"], + [j for j in range(self.KEY_BLOCK_SIZE // 4)], + ], + self.CIPHER_BLOCK_SIZE // 4, + ) + t2 = self.add_XOR_component( + [t1.id, X[i + 2]["id"]], + [ + [j for j in range(self.CIPHER_BLOCK_SIZE // 4)], + X[i + 2]["bit_position"], + ], + self.CIPHER_BLOCK_SIZE // 4, + ) + temp_x = self.add_XOR_component( + [t2.id, X[i + 1]["id"]], + [ + [j for j in range(self.CIPHER_BLOCK_SIZE // 4)], + X[i + 1]["bit_position"], + ], + self.CIPHER_BLOCK_SIZE // 4, + ) + + sboxes_x = [ + self.add_SBOX_component( + [temp_x.id], [[b * 8 + k for k in range(8)]], 8, self.sbox + ) + for b in range(4) + ] + ids_x = [c.id for c in sboxes_x] + pos_x = [list(range(8)) for _ in range(4)] + + rot2 = self.add_rotate_component(ids_x, pos_x, word_bits, -2) + rot10 = self.add_rotate_component(ids_x, pos_x, word_bits, -10) + rot18 = self.add_rotate_component(ids_x, pos_x, word_bits, -18) + rot24 = self.add_rotate_component(ids_x, pos_x, word_bits, -24) + + xor_rot2_10 = self.add_XOR_component( + [rot2.id, rot10.id], + [ + [j for j in range(self.CIPHER_BLOCK_SIZE // 4)], + [j for j in range(self.CIPHER_BLOCK_SIZE // 4)], + ], + self.CIPHER_BLOCK_SIZE // 4, + ) + xor_rot18_24 = self.add_XOR_component( + [rot18.id, rot24.id], + [ + [j for j in range(self.CIPHER_BLOCK_SIZE // 4)], + [j for j in range(self.CIPHER_BLOCK_SIZE // 4)], + ], + self.CIPHER_BLOCK_SIZE // 4, + ) + xor_all = self.add_XOR_component( + [xor_rot2_10.id, xor_rot18_24.id], + [ + [j for j in range(self.CIPHER_BLOCK_SIZE // 4)], + [j for j in range(self.CIPHER_BLOCK_SIZE // 4)], + ], + self.CIPHER_BLOCK_SIZE // 4, + ) + + L_out = self.add_XOR_component( + [sboxes_x[b].id for b in range(4)] + [xor_all.id], + [list(range(8)) for _ in range(4)] + + [list(range(self.CIPHER_BLOCK_SIZE // 4))], + 32, + ) + + Xi4 = self.add_XOR_component( + [L_out.id, X[i]["id"]], + [ + [j for j in range(self.CIPHER_BLOCK_SIZE // 4)], + X[i]["bit_position"], + ], + self.CIPHER_BLOCK_SIZE // 4, + ) + + self.add_round_key_output_component( + [rk_i.id], + [list(range(self.CIPHER_BLOCK_SIZE // 4))], + self.CIPHER_BLOCK_SIZE // 4, + ) + self.add_round_output_component( + [Xi4.id], + [list(range(self.CIPHER_BLOCK_SIZE // 4))], + self.CIPHER_BLOCK_SIZE // 4, + ) + + X.append( + { + "id": Xi4.id, + "bit_position": list(range(self.CIPHER_BLOCK_SIZE // 4)), + } + ) + if i < self.NROUNDS - 1: + self.add_round() + + C0 = X[self.NROUNDS + 3] + C1 = X[self.NROUNDS + 2] + C2 = X[self.NROUNDS + 1] + C3 = X[self.NROUNDS] + return C0, C1, C2, C3 + + self.add_round() + C0, C1, C2, C3 = encrypt_block(self, INPUT_PLAINTEXT, INPUT_KEY, word_bits) + self.add_cipher_output_component( + [C0["id"], C1["id"], C2["id"], C3["id"]], + [ + C0["bit_position"], + C1["bit_position"], + C2["bit_position"], + C3["bit_position"], + ], + self.CIPHER_BLOCK_SIZE, + ) diff --git a/claasp/ciphers/block_ciphers/threefish_block_cipher.py b/claasp/ciphers/block_ciphers/threefish_block_cipher.py index 68cb08e16..35d2267d2 100644 --- a/claasp/ciphers/block_ciphers/threefish_block_cipher.py +++ b/claasp/ciphers/block_ciphers/threefish_block_cipher.py @@ -18,11 +18,10 @@ from math import log2 from claasp.cipher import Cipher +from claasp.name_mappings import BLOCK_CIPHER, INPUT_PLAINTEXT, INPUT_KEY, INPUT_TWEAK from claasp.utils.utils import extract_inputs -from claasp.name_mappings import BLOCK_CIPHER, INPUT_PLAINTEXT, INPUT_KEY -INPUT_TWEAK = "tweak" ROUND_CONSTANTS = [ [[0x0E, 0x10], [0x2E, 0x24, 0x13, 0x25], [0x18, 0x0D, 0x08, 0x2F, 0x08, 0x11, 0x16, 0x25]], [[0x34, 0x39], [0x21, 0x1B, 0x0E, 0x2A], [0x26, 0x13, 0x0A, 0x37, 0x31, 0x12, 0x17, 0x34]], diff --git a/claasp/ciphers/stream_ciphers/chacha_stream_cipher.py b/claasp/ciphers/stream_ciphers/chacha_stream_cipher.py index 75bfc2d38..4524dc304 100644 --- a/claasp/ciphers/stream_ciphers/chacha_stream_cipher.py +++ b/claasp/ciphers/stream_ciphers/chacha_stream_cipher.py @@ -124,3 +124,5 @@ def __init__(self, block_bit_size=512, key_bit_size=256, number_of_rounds=20, component = self.component_from(last_round, component_number) if component.type == "cipher_output": component.set_input_id_links(lst_ids) + + self.sort_cipher() diff --git a/claasp/ciphers/stream_ciphers/trivium_stream_cipher.py b/claasp/ciphers/stream_ciphers/trivium_stream_cipher.py index b1847e1d8..37b596cf5 100644 --- a/claasp/ciphers/stream_ciphers/trivium_stream_cipher.py +++ b/claasp/ciphers/stream_ciphers/trivium_stream_cipher.py @@ -118,7 +118,7 @@ def get_keystream_bit_len(self, keystream_bit_len): def trivium_state_initialization(self, key, iv): cst0 = self.add_constant_component(13, 0x0).id - cst1 = self.add_constant_component(111, 0xE000000000000000000000000000).id + cst1 = self.add_constant_component(111, 0x7000000000000000000000000000).id state0_id = [cst0] + key[0] + [cst0] + iv[0] + [cst1] state0_pos = [ @@ -128,13 +128,8 @@ def trivium_state_initialization(self, key, iv): list(range(self.iv_bit_size)), list(range(111)), ] - triv_state = self.add_FSR_component(state0_id, state0_pos, self.state_bit_size, NLFSR_DESCR).id - triv_state = self.add_FSR_component( - [triv_state], - [list(range(self.state_bit_size))], - self.state_bit_size, - NLFSR_DESCR + [self.number_of_initialization_clocks - 1], - ).id + triv_state = self.add_FSR_component(state0_id, state0_pos, self.state_bit_size, + NLFSR_DESCR + [self.number_of_initialization_clocks]).id return triv_state def trivium_key_stream(self, state, clock_number, key_stream): diff --git a/claasp/component.py b/claasp/component.py index aa2946448..c2321c8f4 100644 --- a/claasp/component.py +++ b/claasp/component.py @@ -1,17 +1,16 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -27,6 +26,15 @@ from claasp.cipher_modules.models.sat.utils import constants from claasp.DTOs.power_of_2_word_based_dto import PowerOf2WordBasedDTO +from claasp.name_mappings import ( + CIPHER_OUTPUT, + CONCATENATE, + INTERMEDIATE_OUTPUT, + LINEAR_LAYER, + MIX_COLUMN, + SBOX, + WORD_OPERATION, +) def check_size(position_list, size): @@ -37,7 +45,7 @@ def check_size(position_list, size): if position_list[j] % size == 0 and (position_list[j + size - 1] + 1) % size == 0: # check consecutive positions i = position_list[j] - for position in position_list[j + 1:j + size]: + for position in position_list[j + 1 : j + size]: i += 1 if i != position: return False @@ -55,7 +63,7 @@ def linear_layer_to_binary_matrix(linear_layer_function, input_bit_size, output_ for i in range(p_matrix.nrows()): p_matrix[i] = vector_space.random_element() - c_matrix = matrix(GF(2), input_bit_size, output_bit_size)#, input_bit_size) + c_matrix = matrix(GF(2), input_bit_size, output_bit_size) # , input_bit_size) for i in range(c_matrix.nrows()): result = linear_layer_function(BitArray(list(p_matrix[i])), *list_specific_inputs) c_matrix[i] = vector(GF(2), result) @@ -64,11 +72,18 @@ def linear_layer_to_binary_matrix(linear_layer_function, input_bit_size, output_ def free_input(code): - code.append('\tdelete_bitstring(input);\n') + code.append("\tdelete_bitstring(input);\n") class Component: - def __init__(self, component_id, component_type, component_input, output_bit_size, description): + def __init__( + self, + component_id, + component_type, + component_input, + output_bit_size, + description, + ): if not isinstance(component_input.id_links, list): print("type of [input_id_link] should be a list") return @@ -97,13 +112,13 @@ def __init__(self, component_id, component_type, component_input, output_bit_siz self._input = deepcopy(component_input) self._output_bit_size = output_bit_size self._description = description - self._suffixes = ['_i', '_o'] + self._suffixes = ("_i", "_o") def _create_minizinc_1d_array_from_list(self, mzn_list): mzn_list_size = len(mzn_list) - lst_temp = f'[{",".join(mzn_list)}]' + lst_temp = f"[{','.join(mzn_list)}]" - return f'array1d(0..{mzn_list_size}-1, {lst_temp})' + return f"array1d(0..{mzn_list_size}-1, {lst_temp})" def _define_var(self, input_postfix, output_postfix, data_type): """ @@ -121,10 +136,10 @@ def _define_var(self, input_postfix, output_postfix, data_type): output_size = self.output_bit_size var_names_temp = [] if self.type != "constant": - var_names_temp += [component_id + "_" + input_postfix + str(i) for i in range(input_size)] - var_names_temp += [component_id + "_" + output_postfix + str(i) for i in range(output_size)] + var_names_temp += [f"{component_id}_{input_postfix}{i}" for i in range(input_size)] + var_names_temp += [f"{component_id}_{output_postfix}{i}" for i in range(output_size)] for i in range(len(var_names_temp)): - var_definition_names.append(f'var {data_type}: {var_names_temp[i]};') + var_definition_names.append(f"var {data_type}: {var_names_temp[i]};") return var_definition_names @@ -132,35 +147,35 @@ def _generate_component_input_ids(self): input_id_link = self.id in_suffix = constants.INPUT_BIT_ID_SUFFIX input_bit_size = self.input_bit_size - input_bit_ids = [f'{input_id_link}_{i}{in_suffix}' for i in range(input_bit_size)] + input_bit_ids = [f"{input_id_link}_{i}{in_suffix}" for i in range(input_bit_size)] return input_bit_size, input_bit_ids - def _generate_input_ids(self, suffix=''): + def _generate_input_ids(self, suffix=""): input_id_link = self.input_id_links input_bit_positions = self.input_bit_positions input_bit_ids = [] for link, positions in zip(input_id_link, input_bit_positions): - input_bit_ids.extend([f'{link}_{j}{suffix}' for j in positions]) + input_bit_ids.extend([f"{link}_{j}{suffix}" for j in positions]) - return self.input_bit_size, input_bit_ids + return input_bit_ids def _generate_input_double_ids(self): - _, in_ids_0 = self._generate_input_ids(suffix='_0') - _, in_ids_1 = self._generate_input_ids(suffix='_1') + in_ids_0 = self._generate_input_ids(suffix="_0") + in_ids_1 = self._generate_input_ids(suffix="_1") return in_ids_0, in_ids_1 - def _generate_output_ids(self, suffix=''): + def _generate_output_ids(self, suffix=""): output_id_link = self.id output_bit_size = self.output_bit_size - output_bit_ids = [f'{output_id_link}_{j}{suffix}' for j in range(output_bit_size)] + output_bit_ids = [f"{output_id_link}_{j}{suffix}" for j in range(output_bit_size)] return output_bit_size, output_bit_ids def _generate_output_double_ids(self): - out_len, out_ids_0 = self._generate_output_ids(suffix='_0') - _, out_ids_1 = self._generate_output_ids(suffix='_1') + out_len, out_ids_0 = self._generate_output_ids(suffix="_0") + _, out_ids_1 = self._generate_output_ids(suffix="_1") return out_len, out_ids_0, out_ids_1 @@ -283,11 +298,17 @@ def _get_input_output_variables_tuples(self): """ tuple_size = 2 - output_ids_tuple = [tuple(f"{self.id}_{i}_class_bit_{j}" for j in range(tuple_size)) for i in range(self.output_bit_size)] + output_ids_tuple = [ + tuple(f"{self.id}_{i}_class_bit_{j}" for j in range(tuple_size)) for i in range(self.output_bit_size) + ] input_ids_tuple = [] for index, link in enumerate(self.input_id_links): - input_ids_tuple.extend([tuple(f"{link}_{pos}_class_bit_{j}" for j in range(tuple_size)) for pos in self.input_bit_positions[index]]) - + input_ids_tuple.extend( + [ + tuple(f"{link}_{pos}_class_bit_{j}" for j in range(tuple_size)) + for pos in self.input_bit_positions[index] + ] + ) return input_ids_tuple, output_ids_tuple @@ -329,13 +350,12 @@ def _get_wordwise_input_output_linked_class(self, model): 'rot_0_18_word_3_class'] """ - output_class_ids = [self.id + '_word_' + str(i) + '_class' for i in - range(self.output_bit_size // model.word_size)] + output_class_ids = [f"{self.id}_word_{i}_class" for i in range(self.output_bit_size // model.word_size)] input_class_ids = [] for index, link in enumerate(self.input_id_links): - for pos in self.input_bit_positions[index][::model.word_size]: - input_class_ids.append(link + '_word_' + str(pos // model.word_size) + '_class') + for pos in self.input_bit_positions[index][:: model.word_size]: + input_class_ids.append(f"{link}_word_{pos // model.word_size}_class") return input_class_ids, output_class_ids @@ -372,14 +392,11 @@ def _get_wordwise_input_output_linked_class_tuples(self, model): tuple_size = 2 input_class, output_class = self._get_wordwise_input_output_linked_class(model) - output_class_tuples = [tuple(f"{id}_bit_{i}" for i in range(tuple_size)) for id in - output_class] - input_class_tuples = [tuple(f"{id}_bit_{i}" for i in range(tuple_size)) for id in - input_class] + output_class_tuples = [tuple(f"{id}_bit_{i}" for i in range(tuple_size)) for id in output_class] + input_class_tuples = [tuple(f"{id}_bit_{i}" for i in range(tuple_size)) for id in input_class] return input_class_tuples, output_class_tuples - def _get_wordwise_input_output_full_tuples(self, model): """ @@ -420,23 +437,26 @@ def _get_wordwise_input_output_full_tuples(self, model): input_ids, output_ids = self._get_input_output_variables() input_class_id_tuples, output_class_id_tuples = self._get_wordwise_input_output_linked_class_tuples(model) - input_full_tuple = [tuple(list(input_class_id_tuples[i]) + input_ids[i * word_size: (i + 1) * word_size]) for i in - range(len(input_ids) // word_size)] - output_full_tuple = [tuple(list(output_class_id_tuples[i]) + output_ids[i * word_size: (i + 1) * word_size]) for i in - range(len(output_ids) // word_size)] + input_full_tuple = [ + tuple(list(input_class_id_tuples[i]) + input_ids[i * word_size : (i + 1) * word_size]) + for i in range(len(input_ids) // word_size) + ] + output_full_tuple = [ + tuple(list(output_class_id_tuples[i]) + output_ids[i * word_size : (i + 1) * word_size]) + for i in range(len(output_ids) // word_size) + ] return input_full_tuple, output_full_tuple - def as_python_dictionary(self): return { - 'id': self._id, - 'type': self._type, - 'input_bit_size': self.input_bit_size, - 'input_id_link': self.input_id_links, - 'input_bit_positions': self.input_bit_positions, - 'output_bit_size': self._output_bit_size, - 'description': self._description + "id": self._id, + "type": self._type, + "input_bit_size": self.input_bit_size, + "input_id_link": self.input_id_links, + "input_bit_positions": self.input_bit_positions, + "output_bit_size": self._output_bit_size, + "description": self._description, } def get_graph_representation(self): @@ -447,18 +467,18 @@ def get_graph_representation(self): "input_id_link": deepcopy(self._input.id_links), "input_bit_positions": deepcopy(self._input.bit_positions), "output_bit_size": self._output_bit_size, - "description": self._description + "description": self._description, } def is_id_equal_to(self, component_id): return self._id == component_id def is_power_of_2_word_based(self, dto): - available_word_sizes = [64, 32, 16, 8] + available_word_sizes = (64, 32, 16, 8) fixed = dto.fixed word_size = dto.word_size - if self._type in ('sbox', 'mix_column', 'linear_layer'): + if self._type in (SBOX, MIX_COLUMN, LINEAR_LAYER): return PowerOf2WordBasedDTO(False, fixed) # Check output size @@ -467,7 +487,7 @@ def is_power_of_2_word_based(self, dto): return PowerOf2WordBasedDTO(False, fixed) # Check input positions and size - if self._type != 'constant': + if self._type != "constant": valid_sizes = [positions for positions in self.input_bit_positions if not check_size(positions, word_size)] if valid_sizes or self.input_bit_size % word_size != 0: return PowerOf2WordBasedDTO(False, fixed) @@ -475,7 +495,7 @@ def is_power_of_2_word_based(self, dto): return PowerOf2WordBasedDTO(word_size, fixed) def check_output_size(self, available_word_sizes, fixed, word_size): - if self._type in ('concatenate', 'intermediate_output', 'cipher_output'): + if self._type in (CONCATENATE, INTERMEDIATE_OUTPUT, CIPHER_OUTPUT): word_size = self.output_size_for_concatenate(available_word_sizes, fixed, word_size) if word_size is None: return None, fixed @@ -494,8 +514,11 @@ def output_size_for_concatenate(self, available_word_sizes, fixed, word_size): if word_sizes: word_size = word_sizes[0] else: - word_sizes = [size for size in available_word_sizes[available_word_sizes.index(word_size):] - if self._output_bit_size % size != 0] + word_sizes = [ + size + for size in available_word_sizes[available_word_sizes.index(word_size) :] + if self._output_bit_size % size != 0 + ] if (fixed and self._output_bit_size % word_size != 0) or (not fixed and not word_sizes): word_size = None elif not fixed: @@ -506,7 +529,7 @@ def output_size_for_concatenate(self, available_word_sizes, fixed, word_size): def is_forbidden(self, forbidden_types, forbidden_descriptions): if self._type in forbidden_types: return True - if self._type == "word_operation" and self._description[0] in forbidden_descriptions: + if self._type == WORD_OPERATION and self._description[0] in forbidden_descriptions: return True return False @@ -521,8 +544,8 @@ def print(self): print(f" description =", self._description) def print_as_python_dictionary(self): - print(" 'id': '" + self._id + "',") - print(" 'type': '" + self._type + "',") + print(f" 'id': '{self._id}',") + print(f" 'type': '{self._type}',") print(f" 'input_bit_size': {self.input_bit_size},") print(f" 'input_id_link': {self.input_id_links},") print(f" 'input_bit_positions': {self.input_bit_positions},") @@ -540,29 +563,29 @@ def set_input_bit_positions(self, bit_positions): def print_values(self, code): code.append(f'\tprintf("{self.id}_input = ");') - code.append('\tprint_bitstring(input, 16);') + code.append("\tprint_bitstring(input, 16);") code.append(f'\tprintf("{self.id}_output = ");') - code.append(f'\tprint_bitstring({self.id}, 16);\n') + code.append(f"\tprint_bitstring({self.id}, 16);\n") def print_word_values(self, code): code.append(f'\tprintf("{self.id}_input = ");') - code.append('\tprint_wordstring(input, 16);') + code.append("\tprint_wordstring(input, 16);") code.append(f'\tprintf("{self.id}_output = ");') - code.append(f'\tprint_wordstring({self.id}, 16);\n') + code.append(f"\tprint_wordstring({self.id}, 16);\n") def select_bits(self, code): n = len(self.input_id_links) - code.append((f'\tinput_id = (BitString*[]) {{{", ".join(self.input_id_links)}}};\n' - f'\tinput_positions = (uint16_t*[]) {{')) + code.append( + (f"\tinput_id = (BitString*[]) {{{', '.join(self.input_id_links)}}};\n\tinput_positions = (uint16_t*[]) {{") + ) for position_list in self.input_bit_positions: - code.append( - (f'\t\t(uint16_t[]) {{{len(position_list)}, {", ".join([str(p) for p in position_list])}}},')) + code.append((f"\t\t(uint16_t[]) {{{len(position_list)}, {', '.join(map(str, position_list))}}},")) - code.append('\t};') + code.append("\t};") - code.append(f'\tinput = select_bits({n}, input_id, input_positions, {self.output_bit_size});') + code.append(f"\tinput = select_bits({n}, input_id, input_positions, {self.output_bit_size});") def select_words(self, code, word_size, input=True): word_list = [] @@ -570,18 +593,18 @@ def select_words(self, code, word_size, input=True): for position_list in self.input_bit_positions: for j in range(0, len(position_list), word_size): - word_list.append(f'{self.input_id_links[i]} -> list[{position_list[j] // word_size}]') + word_list.append(f"{self.input_id_links[i]} -> list[{position_list[j] // word_size}]") i += 1 if input: - code.append(f'\tinput -> list = (Word[]) {{{", ".join(word_list)}}};') - code.append(f'\tinput -> string_size = {len(word_list)};') + code.append(f"\tinput -> list = (Word[]) {{{', '.join(word_list)}}};") + code.append(f"\tinput -> string_size = {len(word_list)};") else: - code.append(f'\tWordString* {self.id} = create_wordstring({len(word_list)}, false);') + code.append(f"\tWordString* {self.id} = create_wordstring({len(word_list)}, false);") code.append( - f'\tmemcpy({self.id} -> ' - f'list, (Word[]) {{{", ".join(word_list)}}}, {len(word_list)} * sizeof(Word));') + f"\tmemcpy({self.id} -> list, (Word[]) {{{', '.join(word_list)}}}, {len(word_list)} * sizeof(Word));" + ) def set_id(self, id_string): self._id = id_string diff --git a/claasp/components/and_component.py b/claasp/components/and_component.py index 19c0130a8..743bb2b44 100644 --- a/claasp/components/and_component.py +++ b/claasp/components/and_component.py @@ -1,17 +1,16 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -24,8 +23,9 @@ def cp_twoterms(model, inp1, inp2, out, cp_constraints): - cp_constraints.append(f'constraint Ham_weight(Andz({inp1}, {inp2}, {out})) == 0 /\\ p[{model.c}] = ' - f'Ham_weight(OR({inp1}, {inp2}));') + cp_constraints.append( + f"constraint Ham_weight(Andz({inp1}, {inp2}, {out})) == 0 /\\ p[{model.c}] = Ham_weight(OR({inp1}, {inp2}));" + ) return cp_constraints @@ -50,9 +50,9 @@ def cp_xor_differential_probability_ddt(numadd): count = 0 for j in range(n): k = i ^ j - binary_j = format(j, f'0{numadd}b') + binary_j = f"{j:0{numadd}b}" result_j = 1 - binary_k = format(k, f'0{numadd}b') + binary_k = f"{k:0{numadd}b}" result_k = 1 for addenda in range(numadd): result_j *= int(binary_j[addenda]) @@ -82,23 +82,35 @@ def cp_xor_linear_probability_lat(numadd): lat = [] for full_mask in range(2 ** (numadd + 1)): num_of_matches = 0 - for values in range(2 ** numadd): + for values in range(2**numadd): full_values = values << 1 bit_of_values = (values >> i & 1 for i in range(numadd)) full_values ^= 0 not in bit_of_values equation = full_values & full_mask addenda = (equation >> i & 1 for i in range(numadd + 1)) - num_of_matches += (sum(addenda) % 2 == 0) + num_of_matches += sum(addenda) % 2 == 0 lat.append(num_of_matches - (2 ** (numadd - 1))) return lat class AND(MultiInputNonlinearLogicalOperator): - def __init__(self, current_round_number, current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size): - super().__init__(current_round_number, current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size, 'and') + def __init__( + self, + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + ): + super().__init__( + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + "and", + ) def algebraic_polynomials(self, model): """ @@ -133,9 +145,9 @@ def algebraic_polynomials(self, model): noutputs = self.output_bit_size word_size = noutputs ring_R = model.ring() - input_vars = [self.id + "_" + model.input_postfix + str(i) for i in range(ninputs)] - output_vars = [self.id + "_" + model.output_postfix + str(i) for i in range(noutputs)] - words_vars = [list(map(ring_R, input_vars))[i:i + word_size] for i in range(0, ninputs, word_size)] + input_vars = [f"{self.id}_{model.input_postfix}{i}" for i in range(ninputs)] + output_vars = [f"{self.id}_{model.output_postfix}{i}" for i in range(noutputs)] + words_vars = [list(map(ring_R, input_vars))[i : i + word_size] for i in range(0, ninputs, word_size)] x = [ring_R.one() for _ in range(noutputs)] for word_vars in words_vars: @@ -164,19 +176,15 @@ def cp_constraints(self): ... 'constraint and_0_8[11] = xor_0_7[11] * key[23];']) """ - output_size = int(self.output_bit_size) - input_id_links = self.input_id_links - output_id_link = self.id - input_bit_positions = self.input_bit_positions cp_declarations = [] all_inputs = [] - for id_link, bit_positions in zip(input_id_links, input_bit_positions): - all_inputs.extend([f'{id_link}[{position}]' for position in bit_positions]) + for id_link, bit_positions in zip(self.input_id_links, self.input_bit_positions): + all_inputs.extend([f"{id_link}[{position}]" for position in bit_positions]) cp_constraints = [] - for i in range(output_size): - operation = ' * '.join(all_inputs[i::output_size]) - new_constraint = f'constraint {output_id_link}[{i}] = {operation};' - cp_constraints.append(new_constraint) + for i in range(self.output_bit_size): + operation = " * ".join(all_inputs[i :: self.output_bit_size]) + cp_constraint = f"constraint {self.id}[{i}] = {operation};" + cp_constraints.append(cp_constraint) return cp_declarations, cp_constraints @@ -202,39 +210,39 @@ def cp_xor_linear_mask_propagation_constraints(self, model): ... 'constraint table([and_0_8_i[11]]++[and_0_8_i[23]]++[and_0_8_o[11]]++[p[11]],and2inputs_LAT);']) """ - input_size = int(self.input_bit_size) - output_size = int(self.output_bit_size) output_id_link = self.id cp_declarations = [] cp_constraints = [] num_add = self.description[1] - input_len = input_size // num_add - cp_declarations.append(f'array[0..{input_size - 1}] of var 0..1:{output_id_link}_i;') - cp_declarations.append(f'array[0..{output_size - 1}] of var 0..1:{output_id_link}_o;') + input_len = self.input_bit_size // num_add + cp_declarations.append(f"array[0..{self.input_bit_size - 1}] of var 0..1:{output_id_link}_i;") + cp_declarations.append(f"array[0..{self.output_bit_size - 1}] of var 0..1:{output_id_link}_o;") model.component_and_probability[output_id_link] = 0 probability = [] - for i in range(output_size): - new_constraint = f'constraint table(' + for i in range(self.output_bit_size): + new_constraint = "constraint table(" for j in range(num_add): - new_constraint = new_constraint + f'[{output_id_link}_i[{i + input_len * j}]]++' + new_constraint = new_constraint + f"[{output_id_link}_i[{i + input_len * j}]]++" if model.float_and_lat_values: - cp_declarations.append(f'var :p_{output_id_link}_{i};') - new_constraint = \ - new_constraint + f'[{output_id_link}_o[{i}]]++[p_{output_id_link}_{i}],and{num_add}inputs_LAT);' + cp_declarations.append(f"var :p_{output_id_link}_{i};") + new_constraint = ( + new_constraint + f"[{output_id_link}_o[{i}]]++[p_{output_id_link}_{i}],and{num_add}inputs_LAT);" + ) cp_constraints.append(new_constraint) for k in range(len(model.float_and_lat_values)): rounded_float = round(float(model.float_and_lat_values[k]), 2) cp_constraints.append( - f'constraint if p_{output_id_link}_{i} == {1000 + k} then p[{model.c}]={rounded_float} else ' - f'p[{model.c}]=p_{output_id_link}_{i} endif;') + f"constraint if p_{output_id_link}_{i} == {1000 + k} then p[{model.c}]={rounded_float} else " + f"p[{model.c}]=p_{output_id_link}_{i} endif;" + ) else: - new_constraint = new_constraint + f'[{output_id_link}_o[{i}]]++[p[{model.c}]],and{num_add}inputs_LAT);' + new_constraint = new_constraint + f"[{output_id_link}_o[{i}]]++[p[{model.c}]],and{num_add}inputs_LAT);" cp_constraints.append(new_constraint) probability.append(model.c) model.c += 1 model.component_and_probability[output_id_link] = probability - result = cp_declarations, cp_constraints - return result + + return cp_declarations, cp_constraints def milp_bitwise_deterministic_truncated_xor_differential_constraints(self, model): """ @@ -281,14 +289,16 @@ def milp_bitwise_deterministic_truncated_xor_differential_constraints(self, mode variables = [(f"x_class[{var}]", x_class[var]) for var in input_vars + output_vars] constraints = [] - a = [[x_class[input_vars[i + chunk * input_bit_size]] for chunk in range(number_of_inputs)] for i in - range(input_bit_size)] + a = [ + [x_class[input_vars[i + chunk * input_bit_size]] for chunk in range(number_of_inputs)] + for i in range(input_bit_size) + ] b = [x_class[output_vars[i]] for i in range(output_bit_size)] upper_bound = model._model.get_max(x_class) for i in range(output_bit_size): - input_sum = sum([a[i][chunk] for chunk in range(number_of_inputs)]) + input_sum = sum(a[i][chunk] for chunk in range(number_of_inputs)) # if d_leq == 1 if sum(a_i) <= 0 d_leq, c_leq = milp_utils.milp_leq(model, input_sum, 0, number_of_inputs * upper_bound) constraints += c_leq @@ -327,10 +337,10 @@ def generic_sign_linear_constraints(self, inputs, outputs): return sign def get_bit_based_vectorized_python_code(self, params, convert_output_to_bytes): - return [f' {self.id} = bit_vector_AND([{",".join(params)} ], {self.description[1]}, {self.output_bit_size})'] + return [f" {self.id} = bit_vector_AND([{','.join(params)} ], {self.description[1]}, {self.output_bit_size})"] def get_byte_based_vectorized_python_code(self, params): - return [f' {self.id} =byte_vector_AND({params})'] + return [f" {self.id} = byte_vector_AND({params})"] def sat_constraints(self): """ @@ -365,7 +375,7 @@ def sat_constraints(self): '-and_0_8_11 key_23', 'and_0_8_11 -xor_0_7_11 -key_23']) """ - _, input_bit_ids = self._generate_input_ids() + input_bit_ids = self._generate_input_ids() output_bit_len, output_bit_ids = self._generate_output_ids() constraints = [] for i in range(output_bit_len): @@ -402,7 +412,7 @@ def smt_constraints(self): '(assert (= and_0_8_10 (and xor_0_7_10 key_22)))', '(assert (= and_0_8_11 (and xor_0_7_11 key_23)))']) """ - _, input_bit_ids = self._generate_input_ids() + input_bit_ids = self._generate_input_ids() output_bit_len, output_bit_ids = self._generate_output_ids() constraints = [] for i in range(output_bit_len): diff --git a/claasp/components/cipher_output_component.py b/claasp/components/cipher_output_component.py index 1e82f6bd8..d34a8fb0e 100644 --- a/claasp/components/cipher_output_component.py +++ b/claasp/components/cipher_output_component.py @@ -1,16 +1,16 @@ # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -20,21 +20,30 @@ from claasp.component import Component from claasp.cipher_modules.models.smt.utils import utils as smt_utils from claasp.cipher_modules.models.sat.utils import constants, utils as sat_utils +from claasp.name_mappings import CIPHER_OUTPUT, INTERMEDIATE_OUTPUT class CipherOutput(Component): - def __init__(self, current_round_number, current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size, is_intermediate=False, output_tag=""): + def __init__( + self, + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + is_intermediate=False, + output_tag="", + ): if is_intermediate: - component_type = 'intermediate_output' + component_type = INTERMEDIATE_OUTPUT description = [output_tag] else: - component_type = 'cipher_output' - description = ['cipher_output'] - component_id = f'{component_type}_{current_round_number}_{current_round_number_of_components}' + component_type = CIPHER_OUTPUT + description = [CIPHER_OUTPUT] + component_id = f"{component_type}_{current_round_number}_{current_round_number_of_components}" component_input = Input(output_bit_size, input_id_links, input_bit_positions) super().__init__(component_id, component_type, component_input, output_bit_size, description) - self._suffixes = ['_o'] + self._suffixes = ["_o"] def cms_constraints(self): """ @@ -91,15 +100,11 @@ def cp_constraints(self): ... 'constraint cipher_output_2_12[31] = xor_2_10[15];']) """ - output_size = int(self.output_bit_size) - input_id_links = self.input_id_links - output_id_link = self.id - input_bit_positions = self.input_bit_positions cp_declarations = [] all_inputs = [] - for id_link, bit_positions in zip(input_id_links, input_bit_positions): - all_inputs.extend([f'{id_link}[{position}]' for position in bit_positions]) - cp_constraints = [f'constraint {output_id_link}[{i}] = {all_inputs[i]};' for i in range(output_size)] + for id_link, bit_positions in zip(self.input_id_links, self.input_bit_positions): + all_inputs.extend([f"{id_link}[{position}]" for position in bit_positions]) + cp_constraints = [f"constraint {self.id}[{i}] = {all_inputs[i]};" for i in range(self.output_bit_size)] return cp_declarations, cp_constraints @@ -131,22 +136,27 @@ def cp_wordwise_deterministic_truncated_xor_differential_constraints(self, model ... 'constraint intermediate_output_0_35_active[15] = xor_0_34_active[3];']) """ - input_id_link = self.input_id_links - output_id_link = self.id - input_bit_positions = self.input_bit_positions + cp_declarations = [] all_inputs_active = [] all_inputs_value = [] - cp_declarations = [] - for id_link, bit_positions in zip(input_id_link, input_bit_positions): - all_inputs_active.extend([f'{id_link}_active[{bit_positions[j * model.word_size] // model.word_size}]' - for j in range(len(bit_positions) // model.word_size)]) - for id_link, bit_positions in zip(input_id_link, input_bit_positions): - all_inputs_value.extend([f'{id_link}_value[{bit_positions[j * model.word_size] // model.word_size}]' - for j in range(len(bit_positions) // model.word_size)]) - cp_constraints = [f'constraint {output_id_link}_value[{i}] = {input_};' - for i, input_ in enumerate(all_inputs_value)] - cp_constraints.extend([f'constraint {output_id_link}_active[{i}] = {input_};' - for i, input_ in enumerate(all_inputs_active)]) + for id_link, bit_positions in zip(self.input_id_links, self.input_bit_positions): + all_inputs_active.extend( + [ + f"{id_link}_active[{bit_positions[j * model.word_size] // model.word_size}]" + for j in range(len(bit_positions) // model.word_size) + ] + ) + for id_link, bit_positions in zip(self.input_id_links, self.input_bit_positions): + all_inputs_value.extend( + [ + f"{id_link}_value[{bit_positions[j * model.word_size] // model.word_size}]" + for j in range(len(bit_positions) // model.word_size) + ] + ) + cp_constraints = [f"constraint {self.id}_value[{i}] = {input_};" for i, input_ in enumerate(all_inputs_value)] + cp_constraints.extend( + [f"constraint {self.id}_active[{i}] = {input_};" for i, input_ in enumerate(all_inputs_active)] + ) return cp_declarations, cp_constraints @@ -174,20 +184,19 @@ def cp_xor_differential_propagation_first_step_constraints(self, model): ... 'constraint intermediate_output_0_35[15] = xor_0_34[3];']) """ - output_size = int(self.output_bit_size) - input_id_links = self.input_id_links - output_id_link = self.id - input_bit_positions = self.input_bit_positions + cp_declarations = [f"array[0..{(self.output_bit_size - 1) // model.word_size}] of var 0..1: {self.id};"] all_inputs = [] cp_constraints = [] - for id_link, bit_positions in zip(input_id_links, input_bit_positions): - all_inputs.extend([f'{id_link}[{bit_positions[j * model.word_size] // model.word_size}]' - for j in range(len(bit_positions) // model.word_size)]) - cp_declarations = [f'array[0..{(output_size - 1) // model.word_size}] of var 0..1: {output_id_link};'] - cp_constraints.extend([f'constraint {output_id_link}[{i}] = {input_};' - for i, input_ in enumerate(all_inputs)]) - result = cp_declarations, cp_constraints - return result + for id_link, bit_positions in zip(self.input_id_links, self.input_bit_positions): + all_inputs.extend( + [ + f"{id_link}[{bit_positions[j * model.word_size] // model.word_size}]" + for j in range(len(bit_positions) // model.word_size) + ] + ) + cp_constraints.extend([f"constraint {self.id}[{i}] = {input_};" for i, input_ in enumerate(all_inputs)]) + + return cp_declarations, cp_constraints def cp_xor_differential_propagation_constraints(self, model): return self.cp_constraints() @@ -217,41 +226,42 @@ def cp_xor_linear_mask_propagation_constraints(self, model=None): 'constraint cipher_output_21_12_o[30] = cipher_output_21_12_i[30];', 'constraint cipher_output_21_12_o[31] = cipher_output_21_12_i[31];']) """ - id_ = self.id - output_bit_size = self.output_bit_size - cp_declarations = [f'array[0..{output_bit_size - 1}] of var 0..1: {id_}_i;', - f'array[0..{output_bit_size - 1}] of var 0..1: {id_}_o;'] - cp_constraints = [f'constraint {id_}_o[{i}] = {id_}_i[{i}];' - for i in range(output_bit_size)] + cp_declarations = [ + f"array[0..{self.output_bit_size - 1}] of var 0..1: {self.id}_i;", + f"array[0..{self.output_bit_size - 1}] of var 0..1: {self.id}_o;", + ] + cp_constraints = [f"constraint {self.id}_o[{i}] = {self.id}_i[{i}];" for i in range(self.output_bit_size)] return cp_declarations, cp_constraints def get_bit_based_vectorized_python_code(self, params, convert_output_to_bytes): code = [] - cipher_output_params = [f'bit_vector_select_word({self.input_id_links[i]}, {self.input_bit_positions[i]})' - for i in range(len(self.input_id_links))] - code.append(f' {self.id} = bit_vector_CONCAT([{",".join(cipher_output_params)} ])') + cipher_output_params = [ + f"bit_vector_select_word({link}, {positions})" + for link, positions in zip(self.input_id_links, self.input_bit_positions) + ] + code.append(f" {self.id} = bit_vector_CONCAT([{','.join(cipher_output_params)} ])") code.append(f' if "{self.description[0]}" not in intermediateOutputs.keys():') code.append(f' intermediateOutputs["{self.description[0]}"] = []') if convert_output_to_bytes: code.append( - f' intermediateOutputs["{self.description[0]}"]' - f'.append(np.packbits({self.id}, axis=0).transpose())') + f' intermediateOutputs["{self.description[0]}"].append(np.packbits({self.id}, axis=0).transpose())' + ) else: - code.append( - f' intermediateOutputs["{self.description[0]}"]' - f'.append({self.id}.transpose())') + code.append(f' intermediateOutputs["{self.description[0]}"].append({self.id}.transpose())') return code def get_byte_based_vectorized_python_code(self, params): - return [f' {self.id} = {params}[0]', - f' if "{self.description[0]}" not in intermediateOutputs.keys():', - f' intermediateOutputs["{self.description[0]}"] = []', - f' if integers_inputs_and_outputs:', -# f' intermediateOutputs["{self.description[0]}"].append(evaluate_vectorized_outputs_to_integers([{self.id}.transpose()], {self.input_bit_size}))', - f' intermediateOutputs["{self.description[0]}"] = evaluate_vectorized_outputs_to_integers([{self.id}.transpose()], {self.input_bit_size})', - f' else:', - f' intermediateOutputs["{self.description[0]}"].append({self.id}.transpose())'] + return [ + f" {self.id} = {params}[0]", + f' if "{self.description[0]}" not in intermediateOutputs.keys():', + f' intermediateOutputs["{self.description[0]}"] = []', + " if integers_inputs_and_outputs:", + # f' intermediateOutputs["{self.description[0]}"].append(evaluate_vectorized_outputs_to_integers([{self.id}.transpose()], {self.input_bit_size}))', + f' intermediateOutputs["{self.description[0]}"] = evaluate_vectorized_outputs_to_integers([{self.id}.transpose()], {self.input_bit_size})', + " else:", + f' intermediateOutputs["{self.description[0]}"].append({self.id}.transpose())', + ] def milp_constraints(self, model): """ @@ -284,9 +294,8 @@ def milp_constraints(self, model): input_vars, output_vars = self._get_input_output_variables() variables = [(f"x[{var}]", x[var]) for var in input_vars + output_vars] constraints = [] - output_bit_size = self.output_bit_size - model.intermediate_output_names.append([self.id, output_bit_size]) - for i in range(output_bit_size): + model.intermediate_output_names.append([self.id, self.output_bit_size]) + for i in range(self.output_bit_size): constraints.append(x[output_vars[i]] == x[input_vars[i]]) return variables, constraints @@ -323,9 +332,8 @@ def milp_bitwise_deterministic_truncated_xor_differential_constraints(self, mode input_vars, output_vars = self._get_input_output_variables() variables = [(f"x_class[{var}]", x_class[var]) for var in input_vars + output_vars] constraints = [] - output_bit_size = self.output_bit_size - model.intermediate_output_names.append([self.id, output_bit_size]) - for i in range(output_bit_size): + model.intermediate_output_names.append([self.id, self.output_bit_size]) + for i in range(self.output_bit_size): constraints.append(x_class[output_vars[i]] == x_class[input_vars[i]]) return variables, constraints @@ -428,19 +436,26 @@ def minizinc_constraints(self, model): intermediate_component_string = [] component_id = self.id ninputs = self.input_bit_size - input_vars = [f'{component_id}_{model.input_postfix}{i}' for i in range(ninputs)] - output_vars = [f'{component_id}_{model.output_postfix}{i}' for i in range(ninputs)] + input_vars = [f"{component_id}_{model.input_postfix}{i}" for i in range(ninputs)] + output_vars = [f"{component_id}_{model.output_postfix}{i}" for i in range(ninputs)] - for i in range(len(input_vars)): - intermediate_component_string.append(f'constraint {input_vars[i]} = {output_vars[i]};') + for input_var, output_var in zip(input_vars, output_vars): + intermediate_component_string.append(f"constraint {input_var} = {output_var};") mzn_input_array = self._create_minizinc_1d_array_from_list(input_vars) if self.description[0] in ["round_output", "cipher_output", "round_key_output"]: - model.mzn_output_directives.append("\noutput [\"component description: " + self.description[0] + - ", id: " + component_id + "_input:\" ++ show(" + mzn_input_array + - ")++\"\\n\"];" + "\n") - - model.intermediate_constraints_array.append({f'{component_id}_input': input_vars}) + model.mzn_output_directives.append( + '\noutput ["component description: ' + + self.description[0] + + ", id: " + + component_id + + '_input:" ++ show(' + + mzn_input_array + + ')++"\\n"];' + + "\n" + ) + + model.intermediate_constraints_array.append({f"{component_id}_input": input_vars}) return var_names, intermediate_component_string @@ -479,7 +494,7 @@ def sat_constraints(self): 'cipher_output_2_12_31 -xor_2_10_15', 'xor_2_10_15 -cipher_output_2_12_31']) """ - _, input_bit_ids = self._generate_input_ids() + input_bit_ids = self._generate_input_ids() output_bit_len, output_bit_ids = self._generate_output_ids() constraints = [] for i in range(output_bit_len): @@ -623,7 +638,7 @@ def smt_constraints(self): '(assert (= cipher_output_2_12_30 xor_2_10_14))', '(assert (= cipher_output_2_12_31 xor_2_10_15))']) """ - _, input_bit_ids = self._generate_input_ids() + input_bit_ids = self._generate_input_ids() output_bit_len, output_bit_ids = self._generate_output_ids() constraints = [] for i in range(output_bit_len): diff --git a/claasp/components/concatenate_component.py b/claasp/components/concatenate_component.py index 09a0c28c2..7f5fa03fa 100644 --- a/claasp/components/concatenate_component.py +++ b/claasp/components/concatenate_component.py @@ -1,17 +1,16 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -19,14 +18,21 @@ from claasp.input import Input from claasp.component import Component +from claasp.name_mappings import CONCATENATE class Concatenate(Component): - def __init__(self, current_round_number, current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size): - component_id = f'concatenate_{current_round_number}_{current_round_number_of_components}' - component_type = 'concatenate' - description = ['', 0] + def __init__( + self, + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + ): + component_id = f"{CONCATENATE}_{current_round_number}_{current_round_number_of_components}" + component_type = CONCATENATE + description = ["", 0] component_input = Input(output_bit_size, input_id_links, input_bit_positions) super().__init__(component_id, component_type, component_input, output_bit_size, description) @@ -34,7 +40,7 @@ def get_bit_based_c_code(self, verbosity): concatenate_code = [] self.select_bits(concatenate_code) - concatenate_code.append(f'\tBitString *{self.id} = input;') + concatenate_code.append(f"\tBitString *{self.id} = input;") if verbosity: self.print_values(concatenate_code) @@ -42,10 +48,10 @@ def get_bit_based_c_code(self, verbosity): return concatenate_code def get_bit_based_vectorized_python_code(self, params, convert_output_to_bytes): - return [f' {self.id} = bit_vector_CONCAT([{",".join(params)} ])'] + return [f" {self.id} = bit_vector_CONCAT([{','.join(params)} ])"] def get_byte_based_vectorized_python_code(self, params): - return [f' {self.id} = np.vstack({params})'] + return [f" {self.id} = np.vstack({params})"] def get_word_based_c_code(self, verbosity, word_size, wordstring_variables): concatenate_code = [] @@ -53,9 +59,9 @@ def get_word_based_c_code(self, verbosity, word_size, wordstring_variables): wordstring_variables.append(self.id) if verbosity: - concatenate_code.append(f'\tstr = wordstring_to_hex_string({self.id});') + concatenate_code.append(f"\tstr = wordstring_to_hex_string({self.id});") concatenate_code.append(f'\tprintf("{self.id} input: %s\\n", str);') concatenate_code.append(f'\tprintf("{self.id} output: %s\\n", str);') - concatenate_code.append('\tfree(str);') + concatenate_code.append("\tfree(str);") return concatenate_code diff --git a/claasp/components/constant_component.py b/claasp/components/constant_component.py index 878083c68..e220ef171 100644 --- a/claasp/components/constant_component.py +++ b/claasp/components/constant_component.py @@ -1,4 +1,3 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute # @@ -21,8 +20,8 @@ from claasp.component import Component from claasp.cipher_modules.models.sat.utils import constants from claasp.cipher_modules.models.smt.utils import utils as smt_utils -from claasp.cipher_modules.code_generator import constant_to_bitstring -from claasp.cipher_modules.generic_functions_vectorized_byte import integer_array_to_evaluate_vectorized_input +from claasp.name_mappings import CONSTANT + def constant_to_repr(val, output_size): _val = int(val, 0) @@ -30,24 +29,20 @@ def constant_to_repr(val, output_size): s = output_size + (8 - (output_size % 8)) else: s = output_size - ret = [(_val >> s - (8 * (i + 1))) & 0xff for i in range(s // 8)] + ret = [(_val >> s - (8 * (i + 1))) & 0xFF for i in range(s // 8)] return ret - - class Constant(Component): - - def __init__(self, current_round_number, current_round_number_of_components, - output_bit_size, value): - component_id = f'constant_{current_round_number}_{current_round_number_of_components}' - component_type = 'constant' + def __init__(self, current_round_number, current_round_number_of_components, output_bit_size, value): + component_id = f"{CONSTANT}_{current_round_number}_{current_round_number_of_components}" + component_type = CONSTANT if output_bit_size % 4 == 0: description = [f"{value:#0{(output_bit_size // 4) + 2}x}"] else: description = [f"{value:#0{output_bit_size + 2}b}"] - component_input = Input(0, [''], [[]]) + component_input = Input(0, [""], [[]]) super().__init__(component_id, component_type, component_input, output_bit_size, description) def algebraic_polynomials(self, model): @@ -95,7 +90,7 @@ def algebraic_polynomials(self, model): constant = int(self.description[0], 16) ring_R = model.ring() - y = list(map(ring_R, [self.id + "_" + model.output_postfix + str(i) for i in range(noutputs)])) + y = list(map(ring_R, [f"{self.id}_{model.output_postfix}{i}" for i in range(noutputs)])) b = list(map(int, reversed(bin(constant)[2:]))) b += [0] * (noutputs - len(b)) @@ -168,17 +163,12 @@ def cp_constraints(self): 'constraint constant_2_0[12] = 0;', 'constraint constant_2_0[13] = 0;', 'constraint constant_2_0[14] = 0;', - 'constraint constant_2_0[15] = 0;']) + 'constraint constant_2_0[15] = 1;']) """ - output_size = self.output_bit_size - output_id_link = self.id - description = self.description - value = f'{int(description[0], 16):0{output_size}b}' - new_declaration = f'array[0..{int(output_size) - 1}] of var 0..1: {output_id_link};' - cp_declarations = [new_declaration] - cp_constraints = [] - for i in range(output_size): - cp_constraints.append(f'constraint {output_id_link}[{i}] = 0;') + cp_declarations = [f"array[0..{self.output_bit_size - 1}] of var 0..1: {self.id};"] + value = int(self.description[0], 16) + bits = map(int, f"{value:0{self.output_bit_size}b}") + cp_constraints = [f"constraint {self.id}[{i}] = {bit};" for i, bit in enumerate(bits)] return cp_declarations, cp_constraints @@ -206,16 +196,21 @@ def cp_wordwise_deterministic_truncated_xor_differential_constraints(self, model 'array[0..1] of var 0..1: constant_0_18_value = array1d(0..1, [0,0]);'], []) """ - output_size = int(self.output_bit_size) - output_id_link = self.id + output_bit_size = self.output_bit_size word_size = model.word_size - new_declaration = f'array[0..{(output_size - 1) // word_size}] of var 0..1: ' \ - f'{output_id_link}_active = array1d(0..{(output_size - 1) // word_size}, [' \ - + ','.join('0' * (output_size // word_size)) + ']);' + new_declaration = ( + f"array[0..{(output_bit_size - 1) // word_size}] of var 0..1: " + f"{self.id}_active = array1d(0..{(output_bit_size - 1) // word_size}, [" + + ",".join("0" * (output_bit_size // word_size)) + + "]);" + ) cp_declarations = [new_declaration] - cp_declarations.append(f'array[0..{(output_size - 1) // word_size}] of var 0..1: ' - f'{output_id_link}_value = array1d(0..{(output_size - 1) // word_size}, [' - + ','.join('0' * (output_size // word_size)) + ']);') + cp_declarations.append( + f"array[0..{(output_bit_size - 1) // word_size}] of var 0..1: " + f"{self.id}_value = array1d(0..{(output_bit_size - 1) // word_size}, [" + + ",".join("0" * (output_bit_size // word_size)) + + "]);" + ) cp_constraints = [] return cp_declarations, cp_constraints @@ -238,15 +233,15 @@ def cp_xor_differential_propagation_first_step_constraints(self, model): sage: constant_component.cp_xor_differential_propagation_first_step_constraints(cp) (['array[0..3] of var 0..1: constant_0_30 = array1d(0..3, [0,0,0,0]);'], []) """ - output_size = int(self.output_bit_size) - output_id_link = self.id - new_declaration = f'array[0..{(output_size - 1) // model.word_size}] of var 0..1: ' \ - f'{output_id_link} = array1d(0..{(output_size - 1) // model.word_size}, [' \ - + ','.join('0' * (output_size // model.word_size)) + ']);' - cp_declarations = [new_declaration] + cp_declarations = [ + f"array[0..{(self.output_bit_size - 1) // model.word_size}] of var 0..1: " + f"{self.id} = array1d(0..{(self.output_bit_size - 1) // model.word_size}, [" + + ",".join("0" * (self.output_bit_size // model.word_size)) + + "]);" + ] cp_constraints = [] - result = cp_declarations, cp_constraints - return result + + return cp_declarations, cp_constraints def cp_xor_differential_propagation_constraints(self, model=None): """ @@ -280,13 +275,8 @@ def cp_xor_differential_propagation_constraints(self, model=None): 'constraint constant_2_0[14] = 0;', 'constraint constant_2_0[15] = 0;']) """ - output_size = int(self.output_bit_size) - output_id_link = self.id - new_declaration = f'array[0..{int(output_size) - 1}] of var 0..2: {output_id_link};' - cp_declarations = [new_declaration] - cp_constraints = [] - for i in range(output_size): - cp_constraints.append(f'constraint {output_id_link}[{i}] = 0;') + cp_declarations = [f"array[0..{self.output_bit_size - 1}] of var 0..2: {self.id};"] + cp_constraints = [f"constraint {self.id}[{i}] = 0;" for i in range(self.output_bit_size)] return cp_declarations, cp_constraints @@ -307,44 +297,45 @@ def cp_xor_linear_mask_propagation_constraints(self, model=None): (['array[0..15] of var 0..1: constant_2_0_o;'], []) """ - output_size = int(self.output_bit_size) - output_id_link = self.id - cp_declarations = [] + cp_declarations = [f"array[0..{self.output_bit_size - 1}] of var 0..1: {self.id}_o;"] cp_constraints = [] - new_declaration = f'array[0..{output_size - 1}] of var 0..1: {output_id_link}_o;' - cp_declarations.append(new_declaration) - result = cp_declarations, cp_constraints - return result + + return cp_declarations, cp_constraints def get_bit_based_c_code(self, verbosity): - constant_code = [f'\tBitString *{self.id} = bitstring_from_hex_string("' - f'{int(self.description[0], 16):#0{(self.output_bit_size // 4) + 2}x}", ' - f'{self.output_bit_size});'] + constant_code = [ + f'\tBitString *{self.id} = bitstring_from_hex_string("' + f'{int(self.description[0], 16):#0{(self.output_bit_size // 4) + 2}x}", ' + f"{self.output_bit_size});" + ] if verbosity: constant_code.append(f'\tprintf("{self.id} input: 0x0");') constant_code.append(f'\tprintf("{self.id} output: ");') - constant_code.append(f'\tprint_bitstring({self.id}, 16);\n') + constant_code.append(f"\tprint_bitstring({self.id}, 16);\n") return constant_code def get_bit_based_vectorized_python_code(self, params, convert_output_to_bytes): - return [f' {self.id} = np.array({constant_to_bitstring(self.description[0], self.output_bit_size)}, ' - f'dtype=np.uint8).reshape({self.output_bit_size, 1})'] + value = int(self.description[0], 0) + bits = list(map(int, f"{value:0{self.output_bit_size}b}")) + return [f" {self.id} = np.array({bits}, dtype=np.uint8).reshape({self.output_bit_size}, 1)"] def get_byte_based_vectorized_python_code(self, params): val = constant_to_repr(self.description[0], self.output_bit_size) - return [f' {self.id} = np.array({val}, dtype=np.uint8).reshape({len(val)}, 1)'] + return [f" {self.id} = np.array({val}, dtype=np.uint8).reshape({len(val)}, 1)"] def get_word_based_c_code(self, verbosity, word_size, wordstring_variables): - constant_code = [f'\tWordString *{self.id} = wordstring_from_hex_string("' - f'{int(self.description[0], 16):#0{(self.output_bit_size // 4) + 2}x}", ' - f'{self.output_bit_size // word_size});'] + constant_code = [ + f'\tWordString *{self.id} = wordstring_from_hex_string("' + f'{int(self.description[0], 16):#0{(self.output_bit_size // 4) + 2}x}", ' + f"{self.output_bit_size // word_size});" + ] wordstring_variables.append(self.id) if verbosity: constant_code.append(f'\tprintf("{self.id} input: 0x0\\n");') constant_code.append(f'\tprintf("{self.id} output: ");') - constant_code.append(f'\tprint_wordstring({self.id}, 16);\n') + constant_code.append(f"\tprint_wordstring({self.id}, 16);\n") return constant_code @@ -375,9 +366,8 @@ def milp_wordwise_deterministic_truncated_xor_differential_constraints(self, mod input_vars, output_vars = self._get_wordwise_input_output_linked_class(model) variables = [(f"x_class[{var}]", x_class[var]) for var in input_vars + output_vars] - constraints = [] - for i in range(len(output_vars)): - constraints.append(x_class[output_vars[i]] == 0) + constraints = [x_class[output_var] == 0 for output_var in output_vars] + return variables, constraints def milp_bitwise_deterministic_truncated_xor_differential_constraints(self, model): @@ -411,9 +401,8 @@ def milp_bitwise_deterministic_truncated_xor_differential_constraints(self, mode input_vars, output_vars = self._get_input_output_variables() variables = [(f"x_class[{var}]", x_class[var]) for var in input_vars + output_vars] - constraints = [] - for i in range(self.output_bit_size): - constraints.append(x_class[output_vars[i]] == 0) + constraints = [x_class[output_var] == 0 for output_var in output_vars] + return variables, constraints def milp_xor_differential_propagation_constraints(self, model): @@ -449,9 +438,9 @@ def milp_xor_differential_propagation_constraints(self, model): x = model.binary_variable input_vars, output_vars = self._get_input_output_variables() variables = [(f"x[{var}]", x[var]) for var in input_vars + output_vars] - constraints = [x[output_vars[i]] == 0 for i in range(self.output_bit_size)] - result = variables, constraints - return result + constraints = [x[output_var] == 0 for output_var in output_vars] + + return variables, constraints def milp_xor_linear_mask_propagation_constraints(self, model): """ @@ -483,8 +472,8 @@ def milp_xor_linear_mask_propagation_constraints(self, model): input_vars, output_vars = self._get_independent_input_output_variables() variables = [(f"x[{var}]", x[var]) for var in input_vars + output_vars] constraints = [] - result = variables, constraints - return result + + return variables, constraints def minizinc_deterministic_truncated_xor_differential_trail_constraints(self, model): return self.minizinc_xor_differential_propagation_constraints(model) @@ -506,16 +495,15 @@ def minizinc_xor_differential_propagation_constraints(self, model): sage: constant_component = fancy.get_component_from_id("constant_0_10") sage: _, constant_xor_differential_constraints = constant_component.minizinc_xor_differential_propagation_constraints(minizinc) sage: constant_xor_differential_constraints[6] - 'constraint constant_0_10_y6=0;' + 'constraint constant_0_10_y6 = 0;' """ var_names = self._define_var(model.input_postfix, model.output_postfix, model.data_type) constant_component_string = [] - noutputs = self.output_bit_size - constant_str_values = [self.id + "_" + model.output_postfix + str(i) for i in range(noutputs)] + constant_str_values = [f"{self.id}_{model.output_postfix}{i}" for i in range(self.output_bit_size)] for constant_str in constant_str_values: - constant_component_string.append(f'constraint {constant_str}=0;') - result = var_names, constant_component_string - return result + constant_component_string.append(f"constraint {constant_str} = 0;") + + return var_names, constant_component_string def sat_constraints(self): """ @@ -548,11 +536,11 @@ def sat_constraints(self): '-constant_2_0_14', 'constant_2_0_15']) """ - output_bit_len, output_bit_ids = self._generate_output_ids() + _, output_bit_ids = self._generate_output_ids() value = int(self.description[0], 16) - value_bits = [value >> i & 1 for i in reversed(range(output_bit_len))] - minus = ['-' * (not i) for i in value_bits] - constraints = [f'{minus[i]}{output_bit_ids[i]}' for i in range(output_bit_len)] + bits = map(int, f"{value:0{self.output_bit_size}b}") + signs = ["-" * (bit ^ 1) for bit in bits] + constraints = [f"{sign}{output_bit_id}" for sign, output_bit_id in zip(signs, output_bit_ids)] return output_bit_ids, constraints @@ -590,7 +578,8 @@ def sat_bitwise_deterministic_truncated_xor_differential_constraints(self): '-constant_2_0_15_1']) """ _, out_ids_0, out_ids_1 = self._generate_output_double_ids() - constraints = [f'-{out_id}' for out_id in out_ids_0] + [f'-{out_id}' for out_id in out_ids_1] + constraints = [f"-{out_id}" for out_id in out_ids_0] + [f"-{out_id}" for out_id in out_ids_1] + return out_ids_0 + out_ids_1, constraints def sat_semi_deterministic_truncated_xor_differential_constraints(self): @@ -628,9 +617,9 @@ def sat_xor_differential_propagation_constraints(self, model=None): '-constant_2_0_15']) """ _, output_bit_ids = self._generate_output_ids() - constraints = [f'-{output_bit_id}' for output_bit_id in output_bit_ids] - result = output_bit_ids, constraints - return result + constraints = [f"-{output_bit_id}" for output_bit_id in output_bit_ids] + + return output_bit_ids, constraints def sat_xor_linear_mask_propagation_constraints(self, model=None): """ @@ -662,8 +651,8 @@ def sat_xor_linear_mask_propagation_constraints(self, model=None): """ out_suffix = constants.OUTPUT_BIT_ID_SUFFIX _, output_bit_ids = self._generate_output_ids(suffix=out_suffix) - result = output_bit_ids, [] - return result + + return output_bit_ids, [] def smt_constraints(self): """ @@ -692,11 +681,13 @@ def smt_constraints(self): '(assert (not constant_0_2_30))', '(assert constant_0_2_31)']) """ - output_bit_len, output_bit_ids = self._generate_output_ids() + _, output_bit_ids = self._generate_output_ids() value = int(self.description[0], 16) - constraints = [smt_utils.smt_assert(output_bit_ids[i]) if value >> (output_bit_len - 1 - i) & 1 - else smt_utils.smt_assert(smt_utils.smt_not(output_bit_ids[i])) - for i in range(output_bit_len)] + bits = map(int, f"{value:0{self.output_bit_size}b}") + constraints = [ + smt_utils.smt_assert(output_bit_id) if bit else smt_utils.smt_assert(smt_utils.smt_not(output_bit_id)) + for bit, output_bit_id in zip(bits, output_bit_ids) + ] return output_bit_ids, constraints @@ -727,11 +718,10 @@ def smt_xor_differential_propagation_constraints(self, model=None): '(assert (not constant_0_2_30))', '(assert (not constant_0_2_31))']) """ - output_bit_len, output_bit_ids = self._generate_output_ids() - constraints = [smt_utils.smt_assert(smt_utils.smt_not(output_bit_ids[i])) - for i in range(output_bit_len)] - result = output_bit_ids, constraints - return result + _, output_bit_ids = self._generate_output_ids() + constraints = [smt_utils.smt_assert(smt_utils.smt_not(output_bit_id)) for output_bit_id in output_bit_ids] + + return output_bit_ids, constraints def smt_xor_linear_mask_propagation_constraints(self, model=None): """ @@ -759,5 +749,5 @@ def smt_xor_linear_mask_propagation_constraints(self, model=None): """ out_suffix = constants.OUTPUT_BIT_ID_SUFFIX _, output_bit_ids = self._generate_output_ids(out_suffix) - result = output_bit_ids, [] - return result + + return output_bit_ids, [] diff --git a/claasp/components/fsr_component.py b/claasp/components/fsr_component.py index 1a7edd270..60fd6f7c7 100644 --- a/claasp/components/fsr_component.py +++ b/claasp/components/fsr_component.py @@ -1,4 +1,3 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute # @@ -16,14 +15,13 @@ # along with this program. If not, see . # **************************************************************************** - from sage.modules.free_module_element import vector from sage.rings.polynomial.polynomial_ring_constructor import PolynomialRing from sage.rings.finite_rings.finite_field_constructor import FiniteField as GF + from claasp.input import Input from claasp.component import Component -from claasp.cipher_modules.generic_functions import _bits_to_words_array - +from claasp.cipher_modules.generic_functions import _bits_to_words_array def _get_polynomial_from_binary_polynomial_index_list(polynomial_index_list, R): @@ -38,6 +36,7 @@ def _get_polynomial_from_binary_polynomial_index_list(polynomial_index_list, R): p += m return p + def _get_polynomial_from_word_polynomial_index_list(polynomial_index_list, R): if polynomial_index_list == []: return R(1) @@ -47,32 +46,42 @@ def _get_polynomial_from_word_polynomial_index_list(polynomial_index_list, R): for _ in polynomial_index_list: m = 0 # presently it is for field of characteristic 2 only - cc = "{0:b}".format(_[0]) + cc = f"{_[0]:b}" for i in range(len(cc)): - if cc[i] == '1': m = m + pow(y, len(cc) - 1 - i) + if cc[i] == "1": + m = m + pow(y, len(cc) - 1 - i) for i in _[1]: m = m * x[i] p += m return p + def _words_array_to_bits(word_array, word_gf): bits_inside_word = word_gf.degree() - output = [0] * (len(word_array)*bits_inside_word) + output = [0] * (len(word_array) * bits_inside_word) for i in range(len(word_array)): coeffcients = word_array[i].coefficients() monomials = word_array[i].monomials() for j in range(len(coeffcients)): bits = coeffcients[j].polynomial().monomials() for b in bits: - output[i*bits_inside_word+(bits_inside_word-b.degree()-1)] += monomials[j] + output[i * bits_inside_word + (bits_inside_word - b.degree() - 1)] += monomials[j] return output + class FSR(Component): - def __init__(self, current_round_number, current_round_number_of_components, input_id_links, - input_bit_positions, output_bit_size, description): - component_id = f'fsr_{current_round_number}_{current_round_number_of_components}' - component_type = 'fsr' + def __init__( + self, + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + description, + ): + component_id = f"fsr_{current_round_number}_{current_round_number_of_components}" + component_type = "fsr" input_len = 0 for bits in input_bit_positions: input_len = input_len + len(bits) @@ -135,21 +144,19 @@ def algebraic_polynomials(self, model): sage: S[480] fsr_0_714_y480 + fsr_0_714_x352 + fsr_0_714_x64 + fsr_0_714_x0 """ - bits_inside_word = self.description[1] if bits_inside_word == 1: return self._algebraic_polynomials_binary(model) - else: - return self._algebraic_polynomials_word(model) + return self._algebraic_polynomials_word(model) def _algebraic_polynomials_binary(self, model): noutputs = self.output_bit_size ninputs = self.input_bit_size ring_R = model.ring() - x_vars = [self.id + "_" + model.input_postfix + str(i) for i in range(ninputs)] + x_vars = [f"{self.id}_{model.input_postfix}{i}" for i in range(ninputs)] x_polynomial_ring = PolynomialRing(ring_R.base(), x_vars) - x = vector(ring_R, (map(ring_R, [self.id + "_" + model.input_postfix + str(i) for i in range(ninputs)]))) - y = vector(ring_R, (map(ring_R, [self.id + "_" + model.output_postfix + str(i) for i in range(noutputs)]))) + x = vector(ring_R, (map(ring_R, [f"{self.id}_{model.input_postfix}{i}" for i in range(ninputs)]))) + y = vector(ring_R, (map(ring_R, [f"{self.id}_{model.output_postfix}{i}" for i in range(noutputs)]))) number_of_registers = len(self.description[0]) registers_polynomial = [0 for _ in range(number_of_registers)] registers_start = [0 for _ in range(number_of_registers)] @@ -163,12 +170,16 @@ def _algebraic_polynomials_binary(self, model): end = 0 for i in range(number_of_registers): - registers_polynomial[i] = _get_polynomial_from_binary_polynomial_index_list(self.description[0][i][1], x_polynomial_ring) + registers_polynomial[i] = _get_polynomial_from_binary_polynomial_index_list( + self.description[0][i][1], x_polynomial_ring + ) registers_start[i] = end end += self.description[0][i][0] - registers_update_bit[i] = end-1 + registers_update_bit[i] = end - 1 if len(self.description[0][i]) > 2: - clock_polynomials[i] = _get_polynomial_from_binary_polynomial_index_list(self.description[0][i][2], x_polynomial_ring) + clock_polynomials[i] = _get_polynomial_from_binary_polynomial_index_list( + self.description[0][i][2], x_polynomial_ring + ) for _ in range(clocks): for i in range(number_of_registers): @@ -176,34 +187,33 @@ def _algebraic_polynomials_binary(self, model): if clock_polynomials[i] is not None: clock_bit = clock_polynomials[i](*x) for k in range(registers_start[i], registers_update_bit[i]): - x[k] = clock_bit*x[k+1] + (clock_bit+1)*x[k] - x[registers_update_bit[i]] = clock_bit*feedback_bit + (clock_bit+1)*x[registers_update_bit[i]] + x[k] = clock_bit * x[k + 1] + (clock_bit + 1) * x[k] + x[registers_update_bit[i]] = clock_bit * feedback_bit + (clock_bit + 1) * x[registers_update_bit[i]] else: for k in range(registers_start[i], registers_update_bit[i]): - x[k] = x[k+1] + x[k] = x[k + 1] x[registers_update_bit[i]] = feedback_bit - output_polynomials = y+vector(x) + output_polynomials = y + vector(x) return output_polynomials def _algebraic_polynomials_word(self, model): - bits_inside_word = self.description[1] noutputs = self.output_bit_size ninputs = self.input_bit_size - word_gf = GF(2 ** bits_inside_word) # Finite field 2^bits_inside_word - x_vars = [self.id + "_" + model.input_postfix + str(i) for i in range(ninputs)] - y_vars = [self.id + "_" + model.output_postfix + str(i) for i in range(noutputs)] + word_gf = GF(2**bits_inside_word) # Finite field 2^bits_inside_word + x_vars = [f"{self.id}_{model.input_postfix}{i}" for i in range(ninputs)] + y_vars = [f"{self.id}_{model.output_postfix}{i}" for i in range(noutputs)] ring_R = PolynomialRing(word_gf, x_vars + y_vars) # Now the base ring is GF(2^n) number_of_words = int(ninputs / bits_inside_word) - x = vector(ring_R, (map(ring_R, [self.id + "_" + model.input_postfix + str(i) for i in range(ninputs)]))) - y = vector(ring_R, (map(ring_R, [self.id + "_" + model.output_postfix + str(i) for i in range(noutputs)]))) + x = vector(ring_R, (map(ring_R, [f"{self.id}_{model.input_postfix}{i}" for i in range(ninputs)]))) + y = vector(ring_R, (map(ring_R, [f"{self.id}_{model.output_postfix}{i}" for i in range(noutputs)]))) word_array = _bits_to_words_array(x, bits_inside_word, word_gf) - word_polynomial_ring = PolynomialRing(word_gf, number_of_words, 'w') + word_polynomial_ring = PolynomialRing(word_gf, number_of_words, "w") number_of_registers = len(self.description[0]) registers_polynomial = [0 for _ in range(number_of_registers)] @@ -216,8 +226,9 @@ def _algebraic_polynomials_word(self, model): end = 0 for i in range(number_of_registers): - registers_polynomial[i] = _get_polynomial_from_word_polynomial_index_list(self.description[0][i][1], - word_polynomial_ring) + registers_polynomial[i] = _get_polynomial_from_word_polynomial_index_list( + self.description[0][i][1], word_polynomial_ring + ) registers_start[i] = end end += self.description[0][i][0] registers_update_word[i] = end - 1 @@ -234,4 +245,4 @@ def _algebraic_polynomials_word(self, model): ring_R = model.ring() output_polynomials_gf2 = [ring_R(str(p)) for p in output_polynomials] - return output_polynomials_gf2 \ No newline at end of file + return output_polynomials_gf2 diff --git a/claasp/components/intermediate_output_component.py b/claasp/components/intermediate_output_component.py index 5637670a6..43d3a6a06 100644 --- a/claasp/components/intermediate_output_component.py +++ b/claasp/components/intermediate_output_component.py @@ -1,17 +1,16 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -20,9 +19,10 @@ from claasp.components.cipher_output_component import CipherOutput from claasp.cipher_modules.models.sat.utils import utils as sat_utils from claasp.cipher_modules.models.smt.utils import utils as smt_utils -from claasp.cipher_modules.models.milp.utils.generate_inequalities_for_xor_with_n_input_bits import \ - update_dictionary_that_contains_xor_inequalities_between_n_input_bits, \ - output_dictionary_that_contains_xor_inequalities +from claasp.cipher_modules.models.milp.utils.generate_inequalities_for_xor_with_n_input_bits import ( + update_dictionary_that_contains_xor_inequalities_between_n_input_bits, + output_dictionary_that_contains_xor_inequalities, +) def update_xor_linear_constraints_for_more_than_one_bit(constraints, intermediate_var, linked_components, x): @@ -50,11 +50,25 @@ def update_xor_linear_constraints_for_more_than_one_bit(constraints, intermediat class IntermediateOutput(CipherOutput): - def __init__(self, current_round_number, current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size, output_tag): - super().__init__(current_round_number, current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size, True, output_tag) - self._suffixes = ['_i', '_o'] + def __init__( + self, + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + output_tag, + ): + super().__init__( + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + True, + output_tag, + ) + self._suffixes = ["_i", "_o"] def cp_xor_linear_mask_propagation_constraints(self, model): """ @@ -89,41 +103,42 @@ def cp_xor_linear_mask_propagation_constraints(self, model): for intermediate_var, linked_components in bit_bindings.items(): # no fork if len(linked_components) == 1: - constraints.append(f'constraint {intermediate_var} = {linked_components[0]};') + constraints.append(f"constraint {intermediate_var} = {linked_components[0]};") # fork else: operation = " + ".join(linked_components) - constraints.append(f'constraint {intermediate_var} = ({operation}) mod 2;') + constraints.append(f"constraint {intermediate_var} = ({operation}) mod 2;") return variables, constraints def get_bit_based_vectorized_python_code(self, params, convert_output_to_bytes): code = [] - intermediate_output_params = [f'bit_vector_select_word({self.input_id_links[i]}, {self.input_bit_positions[i]})' - for i in range(len(self.input_id_links))] - code.append(f' {self.id} = bit_vector_CONCAT([{",".join(intermediate_output_params)} ])') + intermediate_output_params = [ + f"bit_vector_select_word({self.input_id_links[i]}, {self.input_bit_positions[i]})" + for i in range(len(self.input_id_links)) + ] + code.append(f" {self.id} = bit_vector_CONCAT([{','.join(intermediate_output_params)} ])") code.append(f' if "{self.description[0]}" not in intermediateOutputs.keys():') code.append(f' intermediateOutputs["{self.description[0]}"] = []') if convert_output_to_bytes: code.append( - f' intermediateOutputs["{self.description[0]}"]' - f'.append(np.packbits({self.id}, axis=0).transpose())') + f' intermediateOutputs["{self.description[0]}"].append(np.packbits({self.id}, axis=0).transpose())' + ) else: - code.append( - f' intermediateOutputs["{self.description[0]}"]' - f'.append({self.id}.transpose())') + code.append(f' intermediateOutputs["{self.description[0]}"].append({self.id}.transpose())') return code def get_byte_based_vectorized_python_code(self, params): - return [f' {self.id} = {params}[0]', - f' if "{self.description[0]}" not in intermediateOutputs.keys():', - f' intermediateOutputs["{self.description[0]}"] = []', - f' if integers_inputs_and_outputs:', - #f' intermediateOutputs["{self.description[0]}"].append(evaluate_vectorized_outputs_to_integers([{self.id}.transpose()], {self.input_bit_size}))', - f' intermediateOutputs["{self.description[0]}"] = evaluate_vectorized_outputs_to_integers([{self.id}.transpose()], {self.input_bit_size})', - f' else:', - f' intermediateOutputs["{self.description[0]}"].append({self.id}.transpose())'] - + return [ + f" {self.id} = {params}[0]", + f' if "{self.description[0]}" not in intermediateOutputs.keys():', + f' intermediateOutputs["{self.description[0]}"] = []', + " if integers_inputs_and_outputs:", + # f' intermediateOutputs["{self.description[0]}"].append(evaluate_vectorized_outputs_to_integers([{self.id}.transpose()], {self.input_bit_size}))', + f' intermediateOutputs["{self.description[0]}"] = evaluate_vectorized_outputs_to_integers([{self.id}.transpose()], {self.input_bit_size})', + " else:", + f' intermediateOutputs["{self.description[0]}"].append({self.id}.transpose())', + ] def milp_xor_linear_mask_propagation_constraints(self, model): """ @@ -164,8 +179,9 @@ def milp_xor_linear_mask_propagation_constraints(self, model): constraints.append(binary_variable[intermediate_var] == binary_variable[linked_components[0]]) # fork else: - update_xor_linear_constraints_for_more_than_one_bit(constraints, intermediate_var, - linked_components, binary_variable) + update_xor_linear_constraints_for_more_than_one_bit( + constraints, intermediate_var, linked_components, binary_variable + ) return variables, constraints @@ -205,8 +221,9 @@ def sat_xor_linear_mask_propagation_constraints(self, model=None): constraints.extend(sat_utils.cnf_equivalent([intermediate_var] + linked_components)) # fork else: - result_bit_ids = [f'inter_{i}_{intermediate_var}' - for i in range(len(linked_components) - 2)] + [intermediate_var] + result_bit_ids = [f"inter_{i}_{intermediate_var}" for i in range(len(linked_components) - 2)] + [ + intermediate_var + ] constraints.extend(sat_utils.cnf_xor_seq(result_bit_ids, linked_components)) return variables, constraints diff --git a/claasp/components/linear_layer_component.py b/claasp/components/linear_layer_component.py index 1ca73b1d5..b5141db92 100644 --- a/claasp/components/linear_layer_component.py +++ b/claasp/components/linear_layer_component.py @@ -1,17 +1,16 @@ -import numpy as np # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -21,9 +20,10 @@ from sage.modules.free_module_element import vector from sage.rings.finite_rings.finite_field_constructor import FiniteField -from claasp.cipher_modules.models.milp.utils.generate_inequalities_for_wordwise_truncated_xor_with_n_input_bits import \ - update_dictionary_that_contains_xor_inequalities_for_specific_wordwise_matrix, \ - output_dictionary_that_contains_wordwise_truncated_xor_inequalities +from claasp.cipher_modules.models.milp.utils.generate_inequalities_for_wordwise_truncated_xor_with_n_input_bits import ( + update_dictionary_that_contains_xor_inequalities_for_specific_wordwise_matrix, + output_dictionary_that_contains_wordwise_truncated_xor_inequalities, +) from claasp.cipher_modules.models.milp.utils.utils import espresso_pos_to_constraints from claasp.input import Input from claasp.component import Component, free_input @@ -32,11 +32,13 @@ from claasp.cipher_modules.models.milp.utils import utils as milp_utils from claasp.cipher_modules.models.milp.utils.generate_inequalities_for_xor_with_n_input_bits import ( update_dictionary_that_contains_xor_inequalities_for_specific_matrix, - output_dictionary_that_contains_xor_inequalities) + output_dictionary_that_contains_xor_inequalities, +) -def update_constraints_for_more_than_one_bit(constraints, dict_inequalities, i, indexes_of_values_in_col, input_vars, - number_of_1s, output_vars, x): +def update_constraints_for_more_than_one_bit( + constraints, dict_inequalities, i, indexes_of_values_in_col, input_vars, number_of_1s, output_vars, x +): inequalities = dict_inequalities[number_of_1s] for ineq in inequalities: index_ineq = 0 @@ -59,10 +61,17 @@ def update_constraints_for_more_than_one_bit(constraints, dict_inequalities, i, class LinearLayer(Component): - def __init__(self, current_round_number, current_round_number_of_components, input_id_links, - input_bit_positions, output_bit_size, description): - component_id = f'linear_layer_{current_round_number}_{current_round_number_of_components}' - component_type = 'linear_layer' + def __init__( + self, + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + description, + ): + component_id = f"linear_layer_{current_round_number}_{current_round_number_of_components}" + component_type = "linear_layer" input_len = 0 for bits in input_bit_positions: input_len = input_len + len(bits) @@ -92,9 +101,8 @@ def algebraic_polynomials(self, model): ninputs = self.input_bit_size ring_R = model.ring() M = Matrix(ring_R, self.description, nrows=noutputs, ncols=ninputs) - x = vector(ring_R, (map(ring_R, [self.id + "_" + model.input_postfix + str(i) for i in range(ninputs)]))) - y = vector(ring_R, - list(map(ring_R, [self.id + "_" + model.output_postfix + str(i) for i in range(noutputs)]))) + x = vector(ring_R, (map(ring_R, [f"{self.id}_{model.input_postfix}{i}" for i in range(ninputs)]))) + y = vector(ring_R, list(map(ring_R, [f"{self.id}_{model.output_postfix}{i}" for i in range(noutputs)]))) return (y - M * x).list() @@ -124,14 +132,14 @@ def cms_constraints(self): 'x -linear_layer_0_6_22 sbox_0_0_2 sbox_0_2_2 sbox_0_3_2 sbox_0_4_3 sbox_0_5_0 sbox_0_5_1 sbox_0_5_3', 'x -linear_layer_0_6_23 sbox_0_0_0 sbox_0_0_1 sbox_0_0_2 sbox_0_0_3 sbox_0_1_3 sbox_0_2_1 sbox_0_3_1 sbox_0_3_2 sbox_0_3_3 sbox_0_4_1 sbox_0_4_2 sbox_0_4_3 sbox_0_5_1 sbox_0_5_2 sbox_0_5_3']) """ - input_bit_len, input_bit_ids = self._generate_input_ids() + input_bit_ids = self._generate_input_ids() output_bit_len, output_bit_ids = self._generate_output_ids() matrix = self.description constraints = [] for i in range(output_bit_len): - operands = [f'x -{output_bit_ids[i]}'] - operands.extend(input_bit_ids[j] for j in range(input_bit_len) if matrix[j][i]) - constraints.append(' '.join(operands)) + operands = [f"x -{output_bit_ids[i]}"] + operands.extend(input_bit_id for j, input_bit_id in enumerate(input_bit_ids) if matrix[j][i]) + constraints.append(" ".join(operands)) return output_bit_ids, constraints @@ -174,13 +182,13 @@ def cms_xor_linear_mask_propagation_constraints(self, model=None): operands = [input_bit_ids[i]] for j in range(output_bit_len): if inverse_matrix[j][i]: - variable = f'dummy_{i}_{output_bit_ids[j]}' + variable = f"dummy_{i}_{output_bit_ids[j]}" operands.append(variable) dummy_variables[j].append(variable) constraints.extend(sat_utils.cnf_equivalent(operands)) for i in range(output_bit_len): - operands = [f'x -{output_bit_ids[i]}'] + dummy_variables[i] - constraints.append(' '.join(operands)) + operands = [f"x -{output_bit_ids[i]}"] + dummy_variables[i] + constraints.append(" ".join(operands)) dummy_bit_ids = [d for i in range(output_bit_len) for d in dummy_variables[i]] return input_bit_ids + dummy_bit_ids + output_bit_ids, constraints @@ -204,21 +212,16 @@ def cp_constraints(self): ... 'constraint linear_layer_0_6[23] = (sbox_0_0[0] + sbox_0_0[1] + sbox_0_0[2] + sbox_0_0[3] + sbox_0_1[3] + sbox_0_2[1] + sbox_0_3[1] + sbox_0_3[2] + sbox_0_3[3] + sbox_0_4[1] + sbox_0_4[2] + sbox_0_4[3] + sbox_0_5[1] + sbox_0_5[2] + sbox_0_5[3]) mod 2;']) """ - output_size = int(self.output_bit_size) - input_id_links = self.input_id_links - output_id_link = self.id - input_bit_positions = self.input_bit_positions matrix = self.description cp_declarations = [] all_inputs = [] - for id_link, bit_positions in zip(input_id_links, input_bit_positions): - all_inputs.extend([f'{id_link}[{position}]' for position in bit_positions]) + for id_link, bit_positions in zip(self.input_id_links, self.input_bit_positions): + all_inputs.extend([f"{id_link}[{position}]" for position in bit_positions]) cp_constraints = [] - for i in range(output_size): + for i in range(self.output_bit_size): addenda = [all_inputs[j] for j in range(len(matrix)) if matrix[j][i]] - sum_of_addenda = ' + '.join(addenda) - new_constraint = f'constraint {output_id_link}[{i}] = ({sum_of_addenda}) mod 2;' - cp_constraints.append(new_constraint) + sum_of_addenda = " + ".join(addenda) + cp_constraints.append(f"constraint {self.id}[{i}] = ({sum_of_addenda}) mod 2;") return cp_declarations, cp_constraints @@ -237,29 +240,23 @@ def cp_deterministic_truncated_xor_differential_constraints(self): sage: linear_layer_component = fancy.component_from(0, 6) sage: linear_layer_component.cp_deterministic_truncated_xor_differential_constraints() ([], - ['constraint if ((sbox_0_0[2] < 2) /\\ (sbox_0_0[3] < 2) /\\ (sbox_0_1[0] < 2) /\\ (sbox_0_1[1] < 2) /\\ (sbox_0_1[3] < 2) /\\ (sbox_0_2[0] < 2) /\\ (sbox_0_2[1] < 2) /\\ (sbox_0_3[1] < 2) /\\ (sbox_0_4[2] < 2) /\\ (sbox_0_5[1] < 2) /\\ (sbox_0_5[3]< 2)) then linear_layer_0_6[0] = (sbox_0_0[2] + sbox_0_0[3] + sbox_0_1[0] + sbox_0_1[1] + sbox_0_1[3] + sbox_0_2[0] + sbox_0_2[1] + sbox_0_3[1] + sbox_0_4[2] + sbox_0_5[1] + sbox_0_5[3]) mod 2 else linear_layer_0_6[0] = 2 endif;', + ['constraint if ((sbox_0_0[2] < 2) /\\ (sbox_0_0[3] < 2) /\\ (sbox_0_1[0] < 2) /\\ (sbox_0_1[1] < 2) /\\ (sbox_0_1[3] < 2) /\\ (sbox_0_2[0] < 2) /\\ (sbox_0_2[1] < 2) /\\ (sbox_0_3[1] < 2) /\\ (sbox_0_4[2] < 2) /\\ (sbox_0_5[1] < 2) /\\ (sbox_0_5[3] < 2)) then linear_layer_0_6[0] = (sbox_0_0[2] + sbox_0_0[3] + sbox_0_1[0] + sbox_0_1[1] + sbox_0_1[3] + sbox_0_2[0] + sbox_0_2[1] + sbox_0_3[1] + sbox_0_4[2] + sbox_0_5[1] + sbox_0_5[3]) mod 2 else linear_layer_0_6[0] = 2 endif;', ... - 'constraint if ((sbox_0_0[0] < 2) /\\ (sbox_0_0[1] < 2) /\\ (sbox_0_0[2] < 2) /\\ (sbox_0_0[3] < 2) /\\ (sbox_0_1[3] < 2) /\\ (sbox_0_2[1] < 2) /\\ (sbox_0_3[1] < 2) /\\ (sbox_0_3[2] < 2) /\\ (sbox_0_3[3] < 2) /\\ (sbox_0_4[1] < 2) /\\ (sbox_0_4[2] < 2) /\\ (sbox_0_4[3] < 2) /\\ (sbox_0_5[1] < 2) /\\ (sbox_0_5[2] < 2) /\\ (sbox_0_5[3]< 2)) then linear_layer_0_6[23] = (sbox_0_0[0] + sbox_0_0[1] + sbox_0_0[2] + sbox_0_0[3] + sbox_0_1[3] + sbox_0_2[1] + sbox_0_3[1] + sbox_0_3[2] + sbox_0_3[3] + sbox_0_4[1] + sbox_0_4[2] + sbox_0_4[3] + sbox_0_5[1] + sbox_0_5[2] + sbox_0_5[3]) mod 2 else linear_layer_0_6[23] = 2 endif;']) + 'constraint if ((sbox_0_0[0] < 2) /\\ (sbox_0_0[1] < 2) /\\ (sbox_0_0[2] < 2) /\\ (sbox_0_0[3] < 2) /\\ (sbox_0_1[3] < 2) /\\ (sbox_0_2[1] < 2) /\\ (sbox_0_3[1] < 2) /\\ (sbox_0_3[2] < 2) /\\ (sbox_0_3[3] < 2) /\\ (sbox_0_4[1] < 2) /\\ (sbox_0_4[2] < 2) /\\ (sbox_0_4[3] < 2) /\\ (sbox_0_5[1] < 2) /\\ (sbox_0_5[2] < 2) /\\ (sbox_0_5[3] < 2)) then linear_layer_0_6[23] = (sbox_0_0[0] + sbox_0_0[1] + sbox_0_0[2] + sbox_0_0[3] + sbox_0_1[3] + sbox_0_2[1] + sbox_0_3[1] + sbox_0_3[2] + sbox_0_3[3] + sbox_0_4[1] + sbox_0_4[2] + sbox_0_4[3] + sbox_0_5[1] + sbox_0_5[2] + sbox_0_5[3]) mod 2 else linear_layer_0_6[23] = 2 endif;']) """ - output_size = int(self.output_bit_size) - input_id_links = self.input_id_links - output_id_link = self.id - input_bit_positions = self.input_bit_positions - matrix = self.description cp_declarations = [] + matrix = self.description all_inputs = [] - - for id_link, bit_positions in zip(input_id_links, input_bit_positions): - all_inputs.extend([f'{id_link}[{position}]' for position in bit_positions]) + for id_link, bit_positions in zip(self.input_id_links, self.input_bit_positions): + all_inputs.extend([f"{id_link}[{position}]" for position in bit_positions]) cp_constraints = [] - for i in range(output_size): + for i in range(self.output_bit_size): addenda = [all_inputs[j] for j in range(len(matrix)) if matrix[j][i]] - operation = f' < 2) /\\ ('.join(addenda) - new_constraint = f'constraint if ((' - new_constraint += operation + f'< 2)) then ' - operation2 = ' + '.join(addenda) - new_constraint += f'{output_id_link}[{i}] = ({operation2}) mod 2 else {output_id_link}[{i}] = 2 endif;' - cp_constraints.append(new_constraint) + operation = " < 2) /\\ (".join(addenda) + cp_constraint = f"constraint if (({operation} < 2)) then " + operation2 = " + ".join(addenda) + cp_constraint += f"{self.id}[{i}] = ({operation2}) mod 2 else {self.id}[{i}] = 2 endif;" + cp_constraints.append(cp_constraint) return cp_declarations, cp_constraints @@ -281,34 +278,40 @@ def cp_wordwise_deterministic_truncated_xor_differential_constraints(self, model sage: linear_layer_component = fancy.component_from(0, 6) sage: linear_layer_component.cp_deterministic_truncated_xor_differential_constraints() ([], - ['constraint if ((sbox_0_0[2] < 2) /\\ (sbox_0_0[3] < 2) /\\ (sbox_0_1[0] < 2) /\\ (sbox_0_1[1] < 2) /\\ (sbox_0_1[3] < 2) /\\ (sbox_0_2[0] < 2) /\\ (sbox_0_2[1] < 2) /\\ (sbox_0_3[1] < 2) /\\ (sbox_0_4[2] < 2) /\\ (sbox_0_5[1] < 2) /\\ (sbox_0_5[3]< 2)) then linear_layer_0_6[0] = (sbox_0_0[2] + sbox_0_0[3] + sbox_0_1[0] + sbox_0_1[1] + sbox_0_1[3] + sbox_0_2[0] + sbox_0_2[1] + sbox_0_3[1] + sbox_0_4[2] + sbox_0_5[1] + sbox_0_5[3]) mod 2 else linear_layer_0_6[0] = 2 endif;', + ['constraint if ((sbox_0_0[2] < 2) /\\ (sbox_0_0[3] < 2) /\\ (sbox_0_1[0] < 2) /\\ (sbox_0_1[1] < 2) /\\ (sbox_0_1[3] < 2) /\\ (sbox_0_2[0] < 2) /\\ (sbox_0_2[1] < 2) /\\ (sbox_0_3[1] < 2) /\\ (sbox_0_4[2] < 2) /\\ (sbox_0_5[1] < 2) /\\ (sbox_0_5[3] < 2)) then linear_layer_0_6[0] = (sbox_0_0[2] + sbox_0_0[3] + sbox_0_1[0] + sbox_0_1[1] + sbox_0_1[3] + sbox_0_2[0] + sbox_0_2[1] + sbox_0_3[1] + sbox_0_4[2] + sbox_0_5[1] + sbox_0_5[3]) mod 2 else linear_layer_0_6[0] = 2 endif;', ... - 'constraint if ((sbox_0_0[0] < 2) /\\ (sbox_0_0[1] < 2) /\\ (sbox_0_0[2] < 2) /\\ (sbox_0_0[3] < 2) /\\ (sbox_0_1[3] < 2) /\\ (sbox_0_2[1] < 2) /\\ (sbox_0_3[1] < 2) /\\ (sbox_0_3[2] < 2) /\\ (sbox_0_3[3] < 2) /\\ (sbox_0_4[1] < 2) /\\ (sbox_0_4[2] < 2) /\\ (sbox_0_4[3] < 2) /\\ (sbox_0_5[1] < 2) /\\ (sbox_0_5[2] < 2) /\\ (sbox_0_5[3]< 2)) then linear_layer_0_6[23] = (sbox_0_0[0] + sbox_0_0[1] + sbox_0_0[2] + sbox_0_0[3] + sbox_0_1[3] + sbox_0_2[1] + sbox_0_3[1] + sbox_0_3[2] + sbox_0_3[3] + sbox_0_4[1] + sbox_0_4[2] + sbox_0_4[3] + sbox_0_5[1] + sbox_0_5[2] + sbox_0_5[3]) mod 2 else linear_layer_0_6[23] = 2 endif;']) + 'constraint if ((sbox_0_0[0] < 2) /\\ (sbox_0_0[1] < 2) /\\ (sbox_0_0[2] < 2) /\\ (sbox_0_0[3] < 2) /\\ (sbox_0_1[3] < 2) /\\ (sbox_0_2[1] < 2) /\\ (sbox_0_3[1] < 2) /\\ (sbox_0_3[2] < 2) /\\ (sbox_0_3[3] < 2) /\\ (sbox_0_4[1] < 2) /\\ (sbox_0_4[2] < 2) /\\ (sbox_0_4[3] < 2) /\\ (sbox_0_5[1] < 2) /\\ (sbox_0_5[2] < 2) /\\ (sbox_0_5[3] < 2)) then linear_layer_0_6[23] = (sbox_0_0[0] + sbox_0_0[1] + sbox_0_0[2] + sbox_0_0[3] + sbox_0_1[3] + sbox_0_2[1] + sbox_0_3[1] + sbox_0_3[2] + sbox_0_3[3] + sbox_0_4[1] + sbox_0_4[2] + sbox_0_4[3] + sbox_0_5[1] + sbox_0_5[2] + sbox_0_5[3]) mod 2 else linear_layer_0_6[23] = 2 endif;']) """ - input_id_links = self.input_id_links - output_id_link = self.id - input_bit_positions = self.input_bit_positions cp_declarations = [] all_inputs_value = [] all_inputs_active = [] matrix = self.description word_size = model.word_size output_size = len(matrix) // word_size - for id_link, bit_positions in zip(input_id_links, input_bit_positions): - all_inputs_value.extend([f'{id_link}_value[{bit_positions[j * word_size] // word_size}]' - for j in range(len(bit_positions) // word_size)]) - all_inputs_active.extend([f'{id_link}_active[{bit_positions[j * word_size] // word_size}]' - for j in range(len(bit_positions) // word_size)]) - input_len = len(all_inputs_value) + for id_link, bit_positions in zip(self.input_id_links, self.input_bit_positions): + all_inputs_value.extend( + [ + f"{id_link}_value[{bit_positions[j * word_size] // word_size}]" + for j in range(len(bit_positions) // word_size) + ] + ) + all_inputs_active.extend( + [ + f"{id_link}_active[{bit_positions[j * word_size] // word_size}]" + for j in range(len(bit_positions) // word_size) + ] + ) cp_constraints = [] for i in range(output_size): - operation = ' == 0) /\\ ('.join(all_inputs_active[i::output_size]) - new_constraint = 'constraint if ((' - new_constraint += operation + '== 0)) then ' - new_constraint += f'{output_id_link}_active[{i}] = 0 /\\ {output_id_link}_value[{i}] = 0 else'\ - f'{output_id_link}_active[{i}] = 3 /\\ {output_id_link}_value[{i}] = -2 endif;' - cp_constraints.append(new_constraint) - + operation = " == 0) /\\ (".join(all_inputs_active[i::output_size]) + cp_constraint = "constraint if ((" + cp_constraint += operation + "== 0)) then " + cp_constraint += ( + f"{self.id}_active[{i}] = 0 /\\ {self.id}_value[{i}] = 0 else" + f"{self.id}_active[{i}] = 3 /\\ {self.id}_value[{i}] = -2 endif;" + ) + cp_constraints.append(cp_constraint) + return cp_declarations, cp_constraints def cp_xor_differential_propagation_constraints(self, model): @@ -334,22 +337,16 @@ def cp_xor_linear_mask_propagation_constraints(self, model=None): ... 'constraint linear_layer_0_6_i[23]=(linear_layer_0_6_o[0]+linear_layer_0_6_o[1]+linear_layer_0_6_o[2]+linear_layer_0_6_o[3]+linear_layer_0_6_o[4]+linear_layer_0_6_o[7]+linear_layer_0_6_o[8]+linear_layer_0_6_o[11]+linear_layer_0_6_o[13]+linear_layer_0_6_o[14]+linear_layer_0_6_o[15]+linear_layer_0_6_o[18]+linear_layer_0_6_o[19]+linear_layer_0_6_o[20]+linear_layer_0_6_o[21]+linear_layer_0_6_o[22]+linear_layer_0_6_o[23]) mod 2;']) """ - input_size = int(self.input_bit_size) - output_size = int(self.output_bit_size) - output_id_link = self.id - description = self.description - cp_declarations = [] + cp_declarations = [ + f"array[0..{self.input_bit_size - 1}] of var 0..1:{self.id}_i;", + f"array[0..{self.output_bit_size - 1}] of var 0..1:{self.id}_o;", + ] cp_constraints = [] - matrix = Matrix(FiniteField(2), description) - cp_declarations.append(f'array[0..{input_size - 1}] of var 0..1:{output_id_link}_i;') - cp_declarations.append(f'array[0..{output_size - 1}] of var 0..1:{output_id_link}_o;') - for i in range(input_size): - new_constraint = f'constraint {output_id_link}_i[{i}]=(' - for j in range(input_size): - if matrix[i][j] == 1: - new_constraint = new_constraint + f'{output_id_link}_o[{j}]+' - new_constraint = new_constraint[:-1] + f') mod 2;' - cp_constraints.append(new_constraint) + matrix = Matrix(FiniteField(2), self.description) + for i in range(self.input_bit_size): + addenda = [f"{self.id}_o[{j}]" for j in range(self.input_bit_size) if matrix[i][j] == 1] + cp_constraint = f"constraint {self.id}_i[{i}]=(" + "+".join(addenda) + ") mod 2;" + cp_constraints.append(cp_constraint) return cp_declarations, cp_constraints @@ -357,12 +354,12 @@ def get_bit_based_c_code(self, verbosity): linear_layer_code = [] self.select_bits(linear_layer_code) - linear_layer_code.append('\tlinear_transformation = (uint8_t*[]) {') + linear_layer_code.append("\tlinear_transformation = (uint8_t*[]) {") for row in self.description: - linear_layer_code.append(f'\t\t(uint8_t[]) {{{", ".join([str(x) for x in row])}}},') - linear_layer_code.append('\t};') + linear_layer_code.append(f"\t\t(uint8_t[]) {{{', '.join(map(str, row))}}},") + linear_layer_code.append("\t};") - linear_layer_code.append(f'\tBitString* {self.id} = LINEAR_LAYER(input, linear_transformation);\n') + linear_layer_code.append(f"\tBitString* {self.id} = LINEAR_LAYER(input, linear_transformation);\n") if verbosity: self.print_values(linear_layer_code) @@ -372,10 +369,10 @@ def get_bit_based_c_code(self, verbosity): return linear_layer_code def get_bit_based_vectorized_python_code(self, params, convert_output_to_bytes): - return [f' {self.id} = bit_vector_linear_layer(bit_vector_CONCAT([{",".join(params)} ]), {self.description})'] + return [f" {self.id} = bit_vector_linear_layer(bit_vector_CONCAT([{','.join(params)} ]), {self.description})"] def get_byte_based_vectorized_python_code(self, params): - return [f' {self.id} = byte_vector_linear_layer({params}, {self.description})'] + return [f" {self.id} = byte_vector_linear_layer({params}, {self.description})"] def milp_constraints(self, model): """ @@ -414,7 +411,7 @@ def milp_constraints(self, model): matrix = self.description dict_inequalities = {} - matrix_without_unit_vectors = [row for row in matrix if sum([int(i) for i in row]) > 1] + matrix_without_unit_vectors = [row for row in matrix if sum(map(int, row)) > 1] if matrix_without_unit_vectors: update_dictionary_that_contains_xor_inequalities_for_specific_matrix(matrix_without_unit_vectors) dict_inequalities = output_dictionary_that_contains_xor_inequalities() @@ -424,8 +421,16 @@ def milp_constraints(self, model): number_of_1s = len([bit for bit in col if bit]) indexes_of_values_in_col = [value_index for value_index, value in enumerate(col) if value] if number_of_1s >= 2 and number_of_1s in dict_inequalities.keys(): - update_constraints_for_more_than_one_bit(constraints, dict_inequalities, i, indexes_of_values_in_col, - input_vars, number_of_1s, output_vars, x) + update_constraints_for_more_than_one_bit( + constraints, + dict_inequalities, + i, + indexes_of_values_in_col, + input_vars, + number_of_1s, + output_vars, + x, + ) if number_of_1s == 1: constraints.append(x[output_vars[i]] == x[input_vars[indexes_of_values_in_col[0]]]) @@ -480,8 +485,16 @@ def milp_xor_linear_mask_propagation_constraints(self, model): number_of_1s = len([bit for bit in col if bit]) indexes_of_values_in_col = [value_index for value_index, value in enumerate(col) if value] if number_of_1s >= 2: - update_constraints_for_more_than_one_bit(constraints, dict_inequalities, i, indexes_of_values_in_col, - input_vars, number_of_1s, output_vars, x) + update_constraints_for_more_than_one_bit( + constraints, + dict_inequalities, + i, + indexes_of_values_in_col, + input_vars, + number_of_1s, + output_vars, + x, + ) if number_of_1s == 1: constraints.append(x[output_vars[i]] == x[input_vars[indexes_of_values_in_col[0]]]) @@ -532,7 +545,6 @@ def milp_bitwise_deterministic_truncated_xor_differential_constraints(self, mode col = [row[i] for row in matrix] number_of_1s = len([bit for bit in col if bit]) if number_of_1s >= 2: - # performing generalized_xor_deterministic_truncated_xor_differential a = [x_class[input_vars[j]] for j in range(len(col)) if col[j]] list_aj_less_2 = [] @@ -546,8 +558,9 @@ def milp_bitwise_deterministic_truncated_xor_differential_constraints(self, mode constraints.extend(constr) xor_constr = milp_utils.milp_generalized_xor(a, x_class[output_vars[i]]) - constr = milp_utils.milp_if_then_else(all_aj_less_2, xor_constr, [x_class[output_vars[i]] == 2], - model._model.get_max(x_class) * len(a)) + constr = milp_utils.milp_if_then_else( + all_aj_less_2, xor_constr, [x_class[output_vars[i]] == 2], model._model.get_max(x_class) * len(a) + ) constraints.extend(constr) if number_of_1s == 1: @@ -558,7 +571,6 @@ def milp_bitwise_deterministic_truncated_xor_differential_constraints(self, mode return variables, constraints - def milp_bitwise_deterministic_truncated_xor_differential_binary_constraints(self, model): """ Returns a list of variables and a list of constraints for linear layer @@ -588,8 +600,9 @@ def milp_bitwise_deterministic_truncated_xor_differential_binary_constraints(sel matrix = self.description input_id_tuples, output_id_tuples = self._get_input_output_variables_tuples() - linking_constraints = model.link_binary_tuples_to_integer_variables(input_id_tuples + output_id_tuples, - input_ids + output_ids) + linking_constraints = model.link_binary_tuples_to_integer_variables( + input_id_tuples + output_id_tuples, input_ids + output_ids + ) constraints = [] + linking_constraints for i in range(len(matrix)): @@ -597,19 +610,31 @@ def milp_bitwise_deterministic_truncated_xor_differential_binary_constraints(sel number_of_inputs = len([bit for bit in col if bit]) if number_of_inputs >= 2: xor_inputs = [input_id_tuples[j] for j in range(len(col)) if col[j]] - result_ids = [(f'temp_xor_{j}_{self.id}_{i}_0', f'temp_xor_{j}_{self.id}_{i}_1') for j in - range(number_of_inputs - 2)] + [output_id_tuples[i]] - contains_2, greater_constraints = milp_utils.milp_greater(model, sum( - x[input_msb] for input_msb in [id[0] for id in xor_inputs]), 0, len(xor_inputs) + 1) + result_ids = [ + (f"temp_xor_{j}_{self.id}_{i}_0", f"temp_xor_{j}_{self.id}_{i}_1") + for j in range(number_of_inputs - 2) + ] + [output_id_tuples[i]] + contains_2, greater_constraints = milp_utils.milp_greater( + model, sum(x[input_msb] for input_msb in [id[0] for id in xor_inputs]), 0, len(xor_inputs) + 1 + ) constraints.extend(greater_constraints) - sequential_truncated_xor_constraints = milp_utils.milp_xor_truncated(model, xor_inputs[0], xor_inputs[1], result_ids[0]) + sequential_truncated_xor_constraints = milp_utils.milp_xor_truncated( + model, xor_inputs[0], xor_inputs[1], result_ids[0] + ) for chunk in range(1, number_of_inputs - 1): - sequential_truncated_xor_constraints.extend(milp_utils.milp_xor_truncated(model, xor_inputs[chunk + 1], - result_ids[chunk - 1], result_ids[chunk])) + sequential_truncated_xor_constraints.extend( + milp_utils.milp_xor_truncated( + model, xor_inputs[chunk + 1], result_ids[chunk - 1], result_ids[chunk] + ) + ) # if one of the inputs is varied, then the output is varied, # else, perform sequential_xor_deterministic_truncated_xor_differential - constraints.extend(milp_utils.milp_if_then_else(contains_2, [x_class[output_ids[i]] == 2], sequential_truncated_xor_constraints, 6)) + constraints.extend( + milp_utils.milp_if_then_else( + contains_2, [x_class[output_ids[i]] == 2], sequential_truncated_xor_constraints, 6 + ) + ) if number_of_inputs == 1: for index, value in enumerate(col): @@ -619,7 +644,6 @@ def milp_bitwise_deterministic_truncated_xor_differential_binary_constraints(sel return variables, constraints - def milp_wordwise_deterministic_truncated_xor_differential_constraints(self, model): """ Returns a list of variables and a list of constraints for linear layer @@ -660,11 +684,18 @@ def milp_wordwise_deterministic_truncated_xor_differential_constraints(self, mod constraints = [] M = Matrix(self.description) - if M.ncols() > model.word_size and [len(input) for input in self.input_bit_positions] != [model.word_size] * len(self.input_bit_positions): + if M.ncols() > model.word_size and [len(input) for input in self.input_bit_positions] != [ + model.word_size + ] * len(self.input_bit_positions): self.print() # truncated matrix - matrix = [[not M[i:i + model.word_size, j:j + model.word_size].is_zero() for j in - range(0, M.ncols(), model.word_size)] for i in range(0, M.nrows(), model.word_size)] + matrix = [ + [ + not M[i : i + model.word_size, j : j + model.word_size].is_zero() + for j in range(0, M.ncols(), model.word_size) + ] + for i in range(0, M.nrows(), model.word_size) + ] else: matrix = self.description @@ -687,13 +718,22 @@ def milp_wordwise_deterministic_truncated_xor_differential_constraints(self, mod else: # performing sequential wordwise_deterministic_truncated_xor xor_inputs = [input_vars[j] for j in range(len(col)) if col[j]] - result_ids = [tuple([f'temp_xor_{j}_{self.id}_word_{i}_0', f'temp_xor_{j}_{self.id}_word_{i}_1'] + [f'temp_xor_{j}_{self.id}_word_{i}_bit_{k}' for k in range(model.word_size)]) for j in - range(number_of_1s - 2)] + [output_vars[i]] + result_ids = [ + tuple( + [f"temp_xor_{j}_{self.id}_word_{i}_0", f"temp_xor_{j}_{self.id}_word_{i}_1"] + + [f"temp_xor_{j}_{self.id}_word_{i}_bit_{k}" for k in range(model.word_size)] + ) + for j in range(number_of_1s - 2) + ] + [output_vars[i]] constraints.extend( - milp_utils.milp_xor_truncated_wordwise(model, xor_inputs[0], xor_inputs[1], result_ids[0])) + milp_utils.milp_xor_truncated_wordwise(model, xor_inputs[0], xor_inputs[1], result_ids[0]) + ) for chunk in range(1, number_of_1s - 1): - constraints.extend(milp_utils.milp_xor_truncated_wordwise(model, xor_inputs[chunk + 1], - result_ids[chunk - 1], result_ids[chunk])) + constraints.extend( + milp_utils.milp_xor_truncated_wordwise( + model, xor_inputs[chunk + 1], result_ids[chunk - 1], result_ids[chunk] + ) + ) if number_of_1s == 1: index = col.index(1) @@ -725,12 +765,12 @@ def sat_constraints(self): sage: constraints[-1] 'linear_layer_0_6_23 -sbox_0_0_0 -sbox_0_0_1 -sbox_0_0_2 -sbox_0_0_3 -sbox_0_1_3 -sbox_0_2_1 -sbox_0_3_1 -sbox_0_3_2 -sbox_0_3_3 -sbox_0_4_1 -sbox_0_4_2 -sbox_0_4_3 -sbox_0_5_1 -sbox_0_5_2 -sbox_0_5_3' """ - input_bit_len, input_bit_ids = self._generate_input_ids() + input_bit_ids = self._generate_input_ids() output_bit_len, output_bit_ids = self._generate_output_ids() matrix = self.description constraints = [] for i in range(output_bit_len): - operands = [input_bit_ids[j] for j in range(input_bit_len) if matrix[j][i]] + operands = [input_bit_id for j, input_bit_id in enumerate(input_bit_ids) if matrix[j][i]] constraints.extend(sat_utils.cnf_xor(output_bit_ids[i], operands)) return output_bit_ids, constraints @@ -763,10 +803,10 @@ def sat_bitwise_deterministic_truncated_xor_differential_constraints(self): matrix = self.description constraints = [] for i, out_ids_pair in enumerate(zip(out_ids_0, out_ids_1)): - operands = [in_ids_pair for j, in_ids_pair in enumerate(zip(in_ids_0, in_ids_1)) - if matrix[j][i]] - result_ids = [(f'inter_{j}_{self.id}_{i}_0', f'inter_{j}_{self.id}_{i}_1') - for j in range(len(operands) - 2)] + operands = [in_ids_pair for j, in_ids_pair in enumerate(zip(in_ids_0, in_ids_1)) if matrix[j][i]] + result_ids = [ + (f"inter_{j}_{self.id}_{i}_0", f"inter_{j}_{self.id}_{i}_1") for j in range(len(operands) - 2) + ] result_ids.append(out_ids_pair) if len(operands) == 1: constraints.extend(sat_utils.cnf_equivalent([result_ids[0][0], operands[0][0]])) @@ -834,7 +874,7 @@ def sat_xor_linear_mask_propagation_constraints(self, model=None): operands = [input_bit_ids[i]] for j in range(output_bit_len): if inverse_matrix[j][i]: - variable = f'dummy_{i}_{output_bit_ids[j]}' + variable = f"dummy_{i}_{output_bit_ids[j]}" operands.append(variable) dummy_variables[j].append(variable) constraints.extend(sat_utils.cnf_equivalent(operands)) @@ -863,7 +903,7 @@ def smt_constraints(self): sage: constraints[-1] '(assert (= linear_layer_0_6_23 (xor sbox_0_0_0 sbox_0_0_1 sbox_0_0_2 sbox_0_0_3 sbox_0_1_3 sbox_0_2_1 sbox_0_3_1 sbox_0_3_2 sbox_0_3_3 sbox_0_4_1 sbox_0_4_2 sbox_0_4_3 sbox_0_5_1 sbox_0_5_2 sbox_0_5_3)))' """ - _, input_bit_ids = self._generate_input_ids() + input_bit_ids = self._generate_input_ids() output_bit_len, output_bit_ids = self._generate_output_ids() matrix = self.description constraints = [] @@ -928,7 +968,7 @@ def smt_xor_linear_mask_propagation_constraints(self, model=None): operands = [input_bit_ids[i]] for j in range(output_bit_len): if inverse_matrix[j][i]: - variable = f'dummy_{i}_{output_bit_ids[j]}' + variable = f"dummy_{i}_{output_bit_ids[j]}" operands.append(variable) dummy_variables[j].append(variable) equivalence = smt_utils.smt_equivalent(operands) diff --git a/claasp/components/mix_column_component.py b/claasp/components/mix_column_component.py index b71744ff3..92f75c747 100644 --- a/claasp/components/mix_column_component.py +++ b/claasp/components/mix_column_component.py @@ -1,17 +1,16 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -26,22 +25,25 @@ from sage.rings.finite_rings.finite_field_constructor import FiniteField from sage.rings.polynomial.polynomial_ring_constructor import PolynomialRing -from claasp.cipher_modules.models.milp.utils.generate_inequalities_for_wordwise_truncated_mds_matrices import \ - update_dictionary_that_contains_wordwise_truncated_mds_inequalities, \ - output_dictionary_that_contains_wordwise_truncated_mds_inequalities, \ - delete_dictionary_that_contains_wordwise_truncated_mds_inequalities +from claasp.cipher_modules.models.milp.utils.generate_inequalities_for_wordwise_truncated_mds_matrices import ( + update_dictionary_that_contains_wordwise_truncated_mds_inequalities, + output_dictionary_that_contains_wordwise_truncated_mds_inequalities, +) from claasp.cipher_modules.models.milp.utils.utils import espresso_pos_to_constraints from claasp.input import Input from claasp.component import Component, free_input from claasp.utils.utils import int_to_poly from claasp.components.linear_layer_component import LinearLayer -from claasp.cipher_modules.component_analysis_tests import binary_matrix_of_linear_component, branch_number, has_maximal_branch_number +from claasp.cipher_modules.component_analysis_tests import ( + binary_matrix_of_linear_component, + branch_number, + has_maximal_branch_number, +) def add_xor_components(word_size, output_id_link_1, output_id_link_2, output_size, list_of_xor_components): for i in range(output_size // word_size): - input_id_link = [output_id_link_1, output_id_link_2, - f'output_xor_{output_id_link_1}_{output_id_link_2}'] + input_id_link = [output_id_link_1, output_id_link_2, f"output_xor_{output_id_link_1}_{output_id_link_2}"] input_bit_positions = [[] for _ in range(3)] for index in range(word_size): for m in range(3): @@ -51,24 +53,21 @@ def add_xor_components(word_size, output_id_link_1, output_id_link_2, output_siz for input_bit in input_bit_positions: input_len += len(input_bit) component_input = Input(input_len, input_id_link, input_bit_positions) - xor_component = Component("", "word_operation", component_input, input_len, ['XOR', 3]) + xor_component = Component("", "word_operation", component_input, input_len, ["XOR", 3]) list_of_xor_components.append(xor_component) -def calculate_input_bit_positions(word_size, word_index, input_name_1, input_name_2, - new_input_bit_positions_1, new_input_bit_positions_2): +def calculate_input_bit_positions( + word_size, word_index, input_name_1, input_name_2, new_input_bit_positions_1, new_input_bit_positions_2 +): input_bit_positions = [[] for _ in range(3)] if input_name_1 != input_name_2: - input_bit_positions[0] = [int(new_input_bit_positions_1) * word_size + index - for index in range(word_size)] + input_bit_positions[0] = [int(new_input_bit_positions_1) * word_size + index for index in range(word_size)] input_bit_positions[1] = [word_index * word_size + index for index in range(word_size)] - input_bit_positions[2] = [int(new_input_bit_positions_2) * word_size + index - for index in range(word_size)] + input_bit_positions[2] = [int(new_input_bit_positions_2) * word_size + index for index in range(word_size)] else: - input_bit_positions[0] = [int(new_input_bit_positions_1) * word_size + index - for index in range(word_size)] - input_bit_positions[0] += [int(new_input_bit_positions_2) * word_size + index - for index in range(word_size)] + input_bit_positions[0] = [int(new_input_bit_positions_1) * word_size + index for index in range(word_size)] + input_bit_positions[0] += [int(new_input_bit_positions_2) * word_size + index for index in range(word_size)] input_bit_positions[1] = [word_index * word_size + index for index in range(word_size)] return input_bit_positions @@ -78,25 +77,39 @@ def cp_get_all_inputs(word_size, input_bit_positions, input_id_link, numb_of_inp all_inputs = [] for i in range(numb_of_inp): for j in range(len(input_bit_positions[i]) // word_size): - all_inputs.append(f'{input_id_link[i]}' - f'[{input_bit_positions[i][j * word_size] // word_size}]') + all_inputs.append(f"{input_id_link[i]}[{input_bit_positions[i][j * word_size] // word_size}]") return all_inputs class MixColumn(LinearLayer): - def __init__(self, current_round_number, current_round_number_of_components, input_id_links, - input_bit_positions, output_bit_size, description): - super().__init__(current_round_number, current_round_number_of_components, input_id_links, - input_bit_positions, output_bit_size, description) - self._id = f'mix_column_{current_round_number}_{current_round_number_of_components}' - self._type = 'mix_column' - - def _cp_add_declarations_and_constraints(self, word_size, mix_column_mant, list_of_xor_components, - cp_constraints, cp_declarations, mix_column_name): + def __init__( + self, + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + description, + ): + super().__init__( + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + description, + ) + self._id = f"mix_column_{current_round_number}_{current_round_number_of_components}" + self._type = "mix_column" + + def _cp_add_declarations_and_constraints( + self, word_size, mix_column_mant, list_of_xor_components, cp_constraints, cp_declarations, mix_column_name + ): for component_mix in mix_column_mant: - variables, constraints = self._cp_create_component(word_size, component_mix, mix_column_name, - list_of_xor_components) + variables, constraints = self._cp_create_component( + word_size, component_mix, mix_column_name, list_of_xor_components + ) cp_declarations.extend(variables) cp_constraints.extend(constraints) @@ -121,20 +134,22 @@ def _cp_build_truncated_table(self, word_size): input_size = int(self.input_bit_size) output_size = int(self.output_bit_size) output_id_link = self.id - branch = branch_number(self, 'differential', 'word') + branch = branch_number(self, "differential", "word") total_size = (input_size + output_size) // word_size - table_items = '' + table_items = "" solutions = 0 - for i in range(2 ** total_size): - binary_i = f'{i:0{total_size}b}' - bit_sum = sum(int(x) for x in binary_i) + for i in range(2**total_size): + binary_i = f"{i:0{total_size}b}" + bit_sum = sum(map(int, binary_i)) if bit_sum == 0 or bit_sum >= branch: table_items += binary_i solutions += 1 - table = ','.join(table_items) - mix_column_table = f'array[0..{solutions - 1}, 1..{total_size}] of int: ' \ - f'mix_column_truncated_table_{output_id_link} = ' \ - f'array2d(0..{solutions - 1}, 1..{total_size}, [{table}]);' + table = ",".join(table_items) + mix_column_table = ( + f"array[0..{solutions - 1}, 1..{total_size}] of int: " + f"mix_column_truncated_table_{output_id_link} = " + f"array2d(0..{solutions - 1}, 1..{total_size}, [{table}]);" + ) return mix_column_table @@ -168,52 +183,59 @@ def _cp_create_component(self, word_size, component, mix_column_name, list_of_xo return cp_declarations, cp_constraints input_id_link_1 = component.input_id_links - all_inputs_1 = cp_get_all_inputs(word_size, component.input_bit_positions, input_id_link_1, - len(input_id_link_1)) + all_inputs_1 = cp_get_all_inputs( + word_size, component.input_bit_positions, input_id_link_1, len(input_id_link_1) + ) input_id_link_2 = self.input_id_links - all_inputs_2 = cp_get_all_inputs(word_size, self.input_bit_positions, input_id_link_2, - len(input_id_link_2)) + all_inputs_2 = cp_get_all_inputs(word_size, self.input_bit_positions, input_id_link_2, len(input_id_link_2)) input_size = int(component.input_bit_size) output_id_link_1 = component.id output_id_link_2 = self.id cp_declarations.append( - f'array[0..{(input_size - 1) // word_size}] of var 0..1: input_xor_{output_id_link_1}_{output_id_link_2};') + f"array[0..{(input_size - 1) // word_size}] of var 0..1: input_xor_{output_id_link_1}_{output_id_link_2};" + ) cp_declarations.append( - f'array[0..{(input_size - 1) // word_size}] of var 0..1: output_xor_{output_id_link_1}_{output_id_link_2};') + f"array[0..{(input_size - 1) // word_size}] of var 0..1: output_xor_{output_id_link_1}_{output_id_link_2};" + ) for word_index in range(input_size // word_size): input_id_link = [] - divide_1 = all_inputs_1[word_index].partition('[') + divide_1 = all_inputs_1[word_index].partition("[") input_name_1 = divide_1[0] new_input_bit_positions_1 = divide_1[2][:-1] - divide_2 = all_inputs_2[word_index].partition('[') + divide_2 = all_inputs_2[word_index].partition("[") input_name_2 = divide_2[0] new_input_bit_positions_2 = divide_2[2][:-1] if all_inputs_1[word_index] == all_inputs_2[word_index]: input_bit_positions = [[] for _ in range(3)] - cp_constraints.append( - f'constraint input_xor_{output_id_link_1}_{output_id_link_2}[{word_index}] = 0') + cp_constraints.append(f"constraint input_xor_{output_id_link_1}_{output_id_link_2}[{word_index}] = 0") else: input_id_link.append(input_name_1) - input_id_link.append(f'input_xor_{output_id_link_1}_{output_id_link_2}') + input_id_link.append(f"input_xor_{output_id_link_1}_{output_id_link_2}") if input_name_1 != input_name_2: input_id_link.append(input_name_2) - input_bit_positions = calculate_input_bit_positions(word_size, word_index, - input_name_1, input_name_2, - new_input_bit_positions_1, - new_input_bit_positions_2) + input_bit_positions = calculate_input_bit_positions( + word_size, + word_index, + input_name_1, + input_name_2, + new_input_bit_positions_1, + new_input_bit_positions_2, + ) input_bit_positions = [x for x in input_bit_positions if x != []] input_len = 0 for input_bit in input_bit_positions: input_len += len(input_bit) component_input = Input(input_len, input_id_link, input_bit_positions) - xor_component = Component("", "word_operation", component_input, input_len, ['XOR', 3]) + xor_component = Component("", "word_operation", component_input, input_len, ["XOR", 3]) list_of_xor_components.append(xor_component) - new_constraint = 'constraint table(' - new_constraint += f'[input_xor_{output_id_link_1}_{output_id_link_2}[s]|s in ' \ - f'0..{(input_size - 1) // word_size}]++' - new_constraint += f'[output_xor_{output_id_link_1}_{output_id_link_2}[s]|s in ' \ - f'0..{(input_size - 1) // word_size}]' - new_constraint += f',mix_column_truncated_table_{mix_column_name});' + new_constraint = "constraint table(" + new_constraint += ( + f"[input_xor_{output_id_link_1}_{output_id_link_2}[s]|s in 0..{(input_size - 1) // word_size}]++" + ) + new_constraint += ( + f"[output_xor_{output_id_link_1}_{output_id_link_2}[s]|s in 0..{(input_size - 1) // word_size}]" + ) + new_constraint += f",mix_column_truncated_table_{mix_column_name});" cp_constraints.append(new_constraint) output_size = int(component.output_bit_size) add_xor_components(word_size, output_id_link_1, output_id_link_2, output_size, list_of_xor_components) @@ -249,9 +271,9 @@ def algebraic_polynomials(self, model): deg_of_extension = self.description[2] if self.description[1] != 0: coefficient_vector = ZZ(self.description[1]).digits(base=2) - E = FiniteField(2 ** deg_of_extension, name='Z', modulus=coefficient_vector) + E = FiniteField(2**deg_of_extension, name="Z", modulus=coefficient_vector) else: - E = FiniteField(2 ** deg_of_extension) + E = FiniteField(2**deg_of_extension) init_matrix = self.description[0] M = Matrix(E, [[E.fetch_int(value) for value in row] for row in init_matrix]) @@ -267,8 +289,8 @@ def algebraic_polynomials(self, model): F = Sequence((M * X).list()[i] + Y[i] for i in range(noutput_words)).weil_restriction() - input_vars = [self.id + '_' + model.input_postfix + str(i) for i in range(ninputs)] - output_vars = [self.id + '_' + model.output_postfix + str(i) for i in range(noutputs)] + input_vars = [f"{self.id}_{model.input_postfix}{i}" for i in range(ninputs)] + output_vars = [f"{self.id}_{model.output_postfix}{i}" for i in range(noutputs)] ring_R = F.ring().change_ring(names=input_vars + output_vars) polynomials = [ring_R(f) for f in F] @@ -329,8 +351,7 @@ def cp_constraints(self): 'constraint mix_column_0_21[31] = (rot_0_17[0] + rot_0_17[7] + rot_0_18[7] + rot_0_19[7] + rot_0_20[0]) mod 2;']) """ matrix = binary_matrix_of_linear_component(self) - matrix_transposed = [[matrix[i][j] for i in range(matrix.nrows())] - for j in range(matrix.ncols())] + matrix_transposed = [[matrix[i][j] for i in range(matrix.nrows())] for j in range(matrix.ncols())] original_description = deepcopy(self.description) self.set_description(matrix_transposed) cp_declarations, cp_constraints = super().cp_constraints() @@ -353,13 +374,12 @@ def cp_deterministic_truncated_xor_differential_constraints(self, inverse=False) sage: mix_column_component = aes.component_from(0, 21) sage: mix_column_component.cp_deterministic_truncated_xor_differential_constraints() ([], - ['constraint if ((rot_0_17[1] < 2) /\\ (rot_0_18[0] < 2) /\\ (rot_0_18[1] < 2) /\\ (rot_0_19[0] < 2) /\\ (rot_0_20[0]< 2)) then mix_column_0_21[0] = (rot_0_17[1] + rot_0_18[0] + rot_0_18[1] + rot_0_19[0] + rot_0_20[0]) mod 2 else mix_column_0_21[0] = 2 endif;', + ['constraint if ((rot_0_17[1] < 2) /\\ (rot_0_18[0] < 2) /\\ (rot_0_18[1] < 2) /\\ (rot_0_19[0] < 2) /\\ (rot_0_20[0] < 2)) then mix_column_0_21[0] = (rot_0_17[1] + rot_0_18[0] + rot_0_18[1] + rot_0_19[0] + rot_0_20[0]) mod 2 else mix_column_0_21[0] = 2 endif;', ... - 'constraint if ((rot_0_17[0] < 2) /\\ (rot_0_17[7] < 2) /\\ (rot_0_18[7] < 2) /\\ (rot_0_19[7] < 2) /\\ (rot_0_20[0]< 2)) then mix_column_0_21[31] = (rot_0_17[0] + rot_0_17[7] + rot_0_18[7] + rot_0_19[7] + rot_0_20[0]) mod 2 else mix_column_0_21[31] = 2 endif;']) + 'constraint if ((rot_0_17[0] < 2) /\\ (rot_0_17[7] < 2) /\\ (rot_0_18[7] < 2) /\\ (rot_0_19[7] < 2) /\\ (rot_0_20[0] < 2)) then mix_column_0_21[31] = (rot_0_17[0] + rot_0_17[7] + rot_0_18[7] + rot_0_19[7] + rot_0_20[0]) mod 2 else mix_column_0_21[31] = 2 endif;']) """ matrix = binary_matrix_of_linear_component(self) - matrix_transposed = [[matrix[i][j] for i in range(matrix.nrows())] - for j in range(matrix.ncols())] + matrix_transposed = [[matrix[i][j] for i in range(matrix.nrows())] for j in range(matrix.ncols())] original_description = deepcopy(self.description) self.set_description(matrix_transposed) cp_declarations, cp_constraints = super().cp_deterministic_truncated_xor_differential_constraints() @@ -400,15 +420,16 @@ def cp_xor_differential_propagation_first_step_constraints(self, model): mix_column_name = output_id_link number_of_mix = 0 is_mix = False - additional_constraint = 'no' + additional_constraint = "no" for i in range(numb_of_inp): for j in range(len(input_bit_positions[i]) // model.word_size): all_inputs.append( - f'{input_id_link[i]}[{input_bit_positions[i][j * model.word_size] // model.word_size}]') + f"{input_id_link[i]}[{input_bit_positions[i][j * model.word_size] // model.word_size}]" + ) rem = len(input_bit_positions[i]) % model.word_size if rem != 0: rem = model.word_size - (len(input_bit_positions[i]) % model.word_size) - all_inputs.append(f'{output_id_link}_i[{number_of_mix}]') + all_inputs.append(f"{output_id_link}_i[{number_of_mix}]") number_of_mix += 1 is_mix = True l = 1 @@ -419,8 +440,8 @@ def cp_xor_differential_propagation_first_step_constraints(self, model): l += 1 cp_declarations = [] if is_mix: - cp_declarations.append(f'array[0..{number_of_mix - 1}] of var 0..1: {output_id_link}_i;') - cp_declarations.append(f'array[0..{(output_size - 1) // model.word_size}] of var 0..1: {output_id_link};') + cp_declarations.append(f"array[0..{number_of_mix - 1}] of var 0..1: {output_id_link}_i;") + cp_declarations.append(f"array[0..{(output_size - 1) // model.word_size}] of var 0..1: {output_id_link};") already_in = False for mant in model.mix_column_mant: if description == mant.description: @@ -429,15 +450,21 @@ def cp_xor_differential_propagation_first_step_constraints(self, model): break if not already_in: cp_declarations.append(self._cp_build_truncated_table(model.word_size)) - table_inputs = '++'.join([f'[{input_}]' for input_ in all_inputs]) - table_outputs = '++'.join([f'[{output_id_link}[{i}]]' for i in range(output_size // model.word_size)]) - new_constraint = f'constraint table({table_inputs}++{table_outputs}, ' \ - f'mix_column_truncated_table_{mix_column_name});' + table_inputs = "++".join([f"[{input_}]" for input_ in all_inputs]) + table_outputs = "++".join([f"[{output_id_link}[{i}]]" for i in range(output_size // model.word_size)]) + new_constraint = ( + f"constraint table({table_inputs}++{table_outputs}, mix_column_truncated_table_{mix_column_name});" + ) cp_constraints = [new_constraint] - if additional_constraint == 'yes': - self._cp_add_declarations_and_constraints(model.word_size, model.mix_column_mant, - model.list_of_xor_components, cp_constraints, - cp_declarations, mix_column_name) + if additional_constraint == "yes": + self._cp_add_declarations_and_constraints( + model.word_size, + model.mix_column_mant, + model.list_of_xor_components, + cp_constraints, + cp_declarations, + mix_column_name, + ) model.mix_column_mant.append(self) result = cp_declarations, cp_constraints @@ -466,38 +493,33 @@ def cp_xor_linear_mask_propagation_constraints(self, model=None): ... 'constraint mix_column_0_21_i[31]=(mix_column_0_21_o[0]+mix_column_0_21_o[2]+mix_column_0_21_o[7]+mix_column_0_21_o[9]+mix_column_0_21_o[10]+mix_column_0_21_o[15]+mix_column_0_21_o[18]+mix_column_0_21_o[23]+mix_column_0_21_o[24]+mix_column_0_21_o[25]+mix_column_0_21_o[26]) mod 2;']) """ - input_size = int(self.input_bit_size) - output_size = int(self.output_bit_size) - output_id_link = self.id - matrix_component = binary_matrix_of_linear_component(self) - cp_declarations = [] + cp_declarations = [ + f"array[0..{self.input_bit_size - 1}] of var 0..1:{self.id}_i;", + f"array[0..{self.output_bit_size - 1}] of var 0..1:{self.id}_o;", + ] cp_constraints = [] + matrix_component = binary_matrix_of_linear_component(self) matrix = Matrix(FiniteField(2), matrix_component) inverse_matrix = matrix.inverse() - cp_declarations.append(f'array[0..{input_size - 1}] of var 0..1:{output_id_link}_i;') - cp_declarations.append(f'array[0..{output_size - 1}] of var 0..1:{output_id_link}_o;') - for i in range(input_size): - new_constraint = f'constraint {output_id_link}_i[{i}]=(' - for j in range(input_size): - if inverse_matrix[i][j] == 1: - new_constraint = new_constraint + f'{output_id_link}_o[{j}]+' - new_constraint = new_constraint[:-1] + ') mod 2;' + for i in range(self.input_bit_size): + addenda = [f"{self.id}_o[{j}]" for j in range(self.input_bit_size) if inverse_matrix[i][j] == 1] + new_constraint = f"constraint {self.id}_i[{i}]=(" + "+".join(addenda) + ") mod 2;" cp_constraints.append(new_constraint) - result = cp_declarations, cp_constraints - return result + + return cp_declarations, cp_constraints def get_bit_based_c_code(self, verbosity): mix_column_code = [] self.select_bits(mix_column_code) - mix_column_code.append('\tmatrix = (uint64_t*[]) {') + mix_column_code.append("\tmatrix = (uint64_t*[]) {") for row in self.description[0]: - mix_column_code.append(f'\t\t(uint64_t[]) {{{", ".join([str(x) for x in row])}}},') - mix_column_code.append('\t};') + mix_column_code.append(f"\t\t(uint64_t[]) {{{', '.join(map(str, row))}}},") + mix_column_code.append("\t};") mix_column_code.append( - f'\tBitString* {self.id} = ' - f'MIX_COLUMNS(input, matrix, {self.description[1]}, {self.description[2]});\n') + f"\tBitString* {self.id} = MIX_COLUMNS(input, matrix, {self.description[1]}, {self.description[2]});\n" + ) if verbosity: self.print_values(mix_column_code) @@ -510,42 +532,49 @@ def get_bit_based_vectorized_python_code(self, params, convert_output_to_bytes): matrix = self.description[0] polynomial = self.description[1] input_size = self.description[2] - params_mix_column = '' - mul_tables = '' + params_mix_column = "" + mul_tables = "" if polynomial > 0 and polynomial != 257: - mul_tables = dict() - F2 = FiniteField(2)['x'] + mul_tables = {} + F2 = FiniteField(2)["x"] _modulus = int_to_poly(polynomial, input_size + 1, F2.gen()) - F = FiniteField(pow(2, input_size), name='a', modulus=_modulus) + F = FiniteField(pow(2, input_size), name="a", modulus=_modulus) for row in matrix: for element in row: if element not in mul_tables: - mul_tables[element] = [(F.fetch_int(i) * F.fetch_int(element)).integer_representation() - for i in range(2 ** input_size)] + mul_tables[element] = [ + (F.fetch_int(i) * F.fetch_int(element)).integer_representation() + for i in range(2**input_size) + ] params_mix_column = [ - f'bit_vector_select_word({self.input_id_links[i]}, {self.input_bit_positions[i]})' - for i in range(len(self.input_id_links))] + f"bit_vector_select_word({self.input_id_links[i]}, {self.input_bit_positions[i]})" + for i in range(len(self.input_id_links)) + ] - return [f' {self.id} = bit_vector_mix_column(bit_vector_CONCAT([{",".join(params_mix_column)} ]), ' - f'{matrix}, {mul_tables}, {input_size})'] + return [ + f" {self.id} = bit_vector_mix_column(bit_vector_CONCAT([{','.join(params_mix_column)} ]), " + f"{matrix}, {mul_tables}, {input_size})" + ] def get_byte_based_vectorized_python_code(self, params): matrix = self.description[0] polynomial = self.description[1] input_size = self.description[2] if polynomial > 0 and polynomial != 257: # check if in 0..2**n-1 - mul_tables = dict() - F2 = FiniteField(2)['x'] + mul_tables = {} + F2 = FiniteField(2)["x"] _modulus = int_to_poly(polynomial, input_size + 1, F2.gen()) - F = FiniteField(pow(2, input_size), name='a', modulus=_modulus) + F = FiniteField(pow(2, input_size), name="a", modulus=_modulus) for row in matrix: for element in row: if element not in mul_tables: - mul_tables[element] = [(F.fetch_int(i) * F.fetch_int(element)).integer_representation() - for i in range(2 ** input_size)] - return [f' {self.id}=byte_vector_mix_column({params} , {matrix}, {mul_tables}, {input_size})'] - return [f' {self.id}=byte_vector_mix_column_poly0({params} , {matrix}, {input_size})'] + mul_tables[element] = [ + (F.fetch_int(i) * F.fetch_int(element)).integer_representation() + for i in range(2**input_size) + ] + return [f" {self.id}=byte_vector_mix_column({params} , {matrix}, {mul_tables}, {input_size})"] + return [f" {self.id}=byte_vector_mix_column_poly0({params} , {matrix}, {input_size})"] def milp_constraints(self, model): """ @@ -577,8 +606,7 @@ def milp_constraints(self, model): 1 <= 1 + x_1 + x_8 - x_9 + x_16 + x_24 + x_32] """ bin_matrix = binary_matrix_of_linear_component(self) - matrix_transposed = [[bin_matrix[i][j] for i in range(bin_matrix.nrows())] - for j in range(bin_matrix.ncols())] + matrix_transposed = [[bin_matrix[i][j] for i in range(bin_matrix.nrows())] for j in range(bin_matrix.ncols())] original_description = deepcopy(self.description) self.set_description(matrix_transposed) variables, constraints = super().milp_constraints(model) @@ -622,8 +650,7 @@ def milp_xor_linear_mask_propagation_constraints(self, model): x_127 == x_35] """ bin_matrix = binary_matrix_of_linear_component(self) - matrix_transposed = [[bin_matrix[i][j] for i in range(bin_matrix.nrows())] - for j in range(bin_matrix.ncols())] + matrix_transposed = [[bin_matrix[i][j] for i in range(bin_matrix.nrows())] for j in range(bin_matrix.ncols())] original_description = deepcopy(self.description) self.set_description(matrix_transposed) variables, constraints = super().milp_xor_linear_mask_propagation_constraints(model) @@ -679,8 +706,11 @@ def milp_wordwise_deterministic_truncated_xor_differential_constraints(self, mod if has_maximal_branch_number(self): x = model.binary_variable input_class_tuple, output_class_tuple = self._get_wordwise_input_output_linked_class_tuples(model) - variables = [(f"x[{var_elt}]", x[var_elt]) for var_tuple in input_class_tuple + output_class_tuple for - var_elt in var_tuple] + variables = [ + (f"x[{var_elt}]", x[var_elt]) + for var_tuple in input_class_tuple + output_class_tuple + for var_elt in var_tuple + ] matrix = Matrix(self.description[0]) all_vars = [x[i] for _ in input_class_tuple + output_class_tuple for i in _] @@ -693,8 +723,7 @@ def milp_wordwise_deterministic_truncated_xor_differential_constraints(self, mod constraints.extend(minimized_constraints) else: M = self.description[0] - bin_matrix = Matrix([[1 if M[i][j] else 0 for i in range(len(M))] - for j in range(len(M[0]))]) + bin_matrix = Matrix([[1 if M[i][j] else 0 for i in range(len(M))] for j in range(len(M[0]))]) bin_matrix_transposed = [list(_) for _ in list(zip(*bin_matrix))] self.set_description(bin_matrix_transposed) variables, constraints = super().milp_wordwise_deterministic_truncated_xor_differential_constraints(model) diff --git a/claasp/components/modadd_component.py b/claasp/components/modadd_component.py index e7ad86c9a..c8124a4ae 100644 --- a/claasp/components/modadd_component.py +++ b/claasp/components/modadd_component.py @@ -1,17 +1,16 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -25,13 +24,14 @@ def cms_modadd(output_ids, input0_ids, input1_ids, carry_ids): # The CMS modular addition between 2 addenda constraints = [] - for carry_id, input0_id, input1_id, previous_carry_id in zip(carry_ids, input0_ids[1:], - input1_ids[1:], carry_ids[1:]): + for carry_id, input0_id, input1_id, previous_carry_id in zip( + carry_ids, input0_ids[1:], input1_ids[1:], carry_ids[1:] + ): constraints.extend(sat_utils.cnf_carry(carry_id, input0_id, input1_id, previous_carry_id)) constraints.extend(sat_utils.cnf_and(carry_ids[-1], (input0_ids[-1], input1_ids[-1]))) for output_id, input0_id, input1_id, carry_id in zip(output_ids, input0_ids, input1_ids, carry_ids): - constraints.append(f'x -{output_id} {input0_id} {input1_id} {carry_id}') - constraints.append(f'x -{output_ids[-1]} {input0_ids[-1]} {input1_ids[-1]}') + constraints.append(f"x -{output_id} {input0_id} {input1_id} {carry_id}") + constraints.append(f"x -{output_ids[-1]} {input0_ids[-1]} {input1_ids[-1]}") return constraints @@ -44,18 +44,23 @@ def cms_modadd_seq(outputs_ids, inputs_ids, carries_ids): def cp_twoterms(input_1, input_2, out, input_length, cp_constraints, cp_declarations): - cp_declarations.append(f'array[1..{input_length - 1}] of var 0..1: carry_{out};') + cp_declarations.append(f"array[1..{input_length - 1}] of var 0..1: carry_{out};") for i in range(1, input_length - 1): cp_constraints.append( - f'constraint carry_{out}[{i}] = ({input_1}[{i}]*{input_2}[{i}] + ' - f'{input_1}[{i}]*carry_{out}[{i + 1}] + carry_{out}[{i + 1}]*{input_2}[{i}]) mod 2;') - cp_constraints.append(f'constraint carry_{out}[{input_length - 1}] = ' - f'({input_1}[{input_length - 1}] * {input_2}[{input_length - 1}]) mod 2;') + f"constraint carry_{out}[{i}] = ({input_1}[{i}]*{input_2}[{i}] + " + f"{input_1}[{i}]*carry_{out}[{i + 1}] + carry_{out}[{i + 1}]*{input_2}[{i}]) mod 2;" + ) + cp_constraints.append( + f"constraint carry_{out}[{input_length - 1}] = " + f"({input_1}[{input_length - 1}] * {input_2}[{input_length - 1}]) mod 2;" + ) for i in range(input_length - 1): - cp_constraints.append(f'constraint {out}[{i}] = ' - f'({input_1}[{i}] + {input_2}[{i}] + carry_{out}[{i + 1}]) mod 2;') - cp_constraints.append(f'constraint {out}[{input_length - 1}] = ' - f'({input_1}[{input_length - 1}] + {input_2}[{input_length - 1}]) mod 2;') + cp_constraints.append( + f"constraint {out}[{i}] = ({input_1}[{i}] + {input_2}[{i}] + carry_{out}[{i + 1}]) mod 2;" + ) + cp_constraints.append( + f"constraint {out}[{input_length - 1}] = ({input_1}[{input_length - 1}] + {input_2}[{input_length - 1}]) mod 2;" + ) return cp_declarations, cp_constraints @@ -63,7 +68,9 @@ def cp_twoterms(input_1, input_2, out, input_length, cp_constraints, cp_declarat def sat_modadd(output_ids, input0_ids, input1_ids, carry_ids): # The SAT modular addition between 2 addenda constraints = [] - for carry_id, input0_id, input1_id, previous_carry_id in zip(carry_ids, input0_ids[1:], input1_ids[1:], carry_ids[1:]): + for carry_id, input0_id, input1_id, previous_carry_id in zip( + carry_ids, input0_ids[1:], input1_ids[1:], carry_ids[1:] + ): constraints.extend(sat_utils.cnf_carry(carry_id, input0_id, input1_id, previous_carry_id)) constraints.extend(sat_utils.cnf_and(carry_ids[-1], (input0_ids[-1], input1_ids[-1]))) for output_id, input0_id, input1_id, carry_id in zip(output_ids, input0_ids, input1_ids, carry_ids): @@ -83,7 +90,9 @@ def sat_modadd_seq(outputs_ids, inputs_ids, carries_ids): def smt_modadd(output_ids, input0_ids, input1_ids, carry_ids): # The SMT modular addition between 2 addenda constraints = [] - for carry_id, input0_id, input1_id, previous_carry_id in zip(carry_ids, input0_ids[1:], input1_ids[1:], carry_ids[1:]): + for carry_id, input0_id, input1_id, previous_carry_id in zip( + carry_ids, input0_ids[1:], input1_ids[1:], carry_ids[1:] + ): operation = smt_utils.smt_carry(input0_id, input1_id, previous_carry_id) equation = smt_utils.smt_equivalent((carry_id, operation)) constraints.append(smt_utils.smt_assert(equation)) @@ -109,10 +118,24 @@ def smt_modadd_seq(outputs_ids, inputs_ids, carries_ids): class MODADD(Modular): - def __init__(self, current_round_number, current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size, modulus): - super().__init__(current_round_number, current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size, 'modadd', modulus) + def __init__( + self, + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + modulus, + ): + super().__init__( + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + "modadd", + modulus, + ) def algebraic_polynomials(self, model): """ @@ -142,12 +165,10 @@ def algebraic_polynomials(self, model): ninput_bits = self.input_bit_size noutput_bits = word_size = self.output_bit_size - input_vars = [component_id + "_" + model.input_postfix + str(i) for i in range(ninput_bits)] - output_vars = [component_id + "_" + model.output_postfix + str(i) for i in range(noutput_bits)] - carries_vars = \ - [[component_id + "_" + "c" + str(n) + "_" + str(i) for i in range(word_size)] for n in range(nadditions)] - aux_outputs_vars = [[component_id + "_" + "o" + str(n) + "_" + str(i) for i in range(word_size)] for n in - range(nadditions - 1)] + input_vars = [f"{component_id}_{model.input_postfix}{i}" for i in range(ninput_bits)] + output_vars = [f"{component_id}_{model.output_postfix}{i}" for i in range(noutput_bits)] + carries_vars = [[f"{component_id}_c{n}_{i}" for i in range(word_size)] for n in range(nadditions)] + aux_outputs_vars = [[f"{component_id}_o{n}_{i}" for i in range(word_size)] for n in range(nadditions - 1)] ring_R = model.ring() input_vars = list(map(ring_R, input_vars)) @@ -155,7 +176,9 @@ def algebraic_polynomials(self, model): carries_vars = [list(map(ring_R, carry_vars)) for carry_vars in carries_vars] aux_outputs_vars = [list(map(ring_R, aux_output_vars)) for aux_output_vars in aux_outputs_vars] - def maj(xi, yi, zi): return xi * yi + xi * zi + yi * zi + def maj(xi, yi, zi): + return xi * yi + xi * zi + yi * zi + polynomials = [] for n in range(nadditions): # z = x + y if n == 0: @@ -168,7 +191,7 @@ def maj(xi, yi, zi): return xi * yi + xi * zi + yi * zi else: z = aux_outputs_vars[n] - y = input_vars[(n + 1) * word_size: (n + 1) * word_size + word_size] + y = input_vars[(n + 1) * word_size : (n + 1) * word_size + word_size] c = carries_vars[n] polynomials += [c[0] + 0] @@ -209,28 +232,33 @@ def cms_constraints(self): 'x -modadd_0_1_14 rot_0_0_14 plaintext_30 carry_modadd_0_1_14', 'x -modadd_0_1_15 rot_0_0_15 plaintext_31']) """ - _, input_bit_ids = self._generate_input_ids() + input_bit_ids = self._generate_input_ids() output_bit_len, output_bit_ids = self._generate_output_ids() - carry_bit_ids = [f'carry_{output_bit_ids[i]}' for i in range(output_bit_len - 1)] + carry_bit_ids = [f"carry_{output_bit_ids[i]}" for i in range(output_bit_len - 1)] constraints = [] # carries for i in range(output_bit_len - 2): - constraints.extend(sat_utils.cnf_carry(carry_bit_ids[i], - input_bit_ids[i + 1], - input_bit_ids[output_bit_len + i + 1], - carry_bit_ids[i + 1])) - constraints.extend(sat_utils.cnf_and(carry_bit_ids[output_bit_len - 2], - (input_bit_ids[output_bit_len - 1], - input_bit_ids[2 * output_bit_len - 1]))) + constraints.extend( + sat_utils.cnf_carry( + carry_bit_ids[i], input_bit_ids[i + 1], input_bit_ids[output_bit_len + i + 1], carry_bit_ids[i + 1] + ) + ) + constraints.extend( + sat_utils.cnf_and( + carry_bit_ids[output_bit_len - 2], + (input_bit_ids[output_bit_len - 1], input_bit_ids[2 * output_bit_len - 1]), + ) + ) # results for CryptoMiniSat can be implemented using the leading x for i in range(output_bit_len - 1): - constraints.append(f'x -{output_bit_ids[i]} ' - f'{input_bit_ids[i]} ' - f'{input_bit_ids[output_bit_len + i]} ' - f'{carry_bit_ids[i]}') - constraints.append(f'x -{output_bit_ids[output_bit_len - 1]} ' - f'{input_bit_ids[output_bit_len - 1]} ' - f'{input_bit_ids[2 * output_bit_len - 1]}') + constraints.append( + f"x -{output_bit_ids[i]} {input_bit_ids[i]} {input_bit_ids[output_bit_len + i]} {carry_bit_ids[i]}" + ) + constraints.append( + f"x -{output_bit_ids[output_bit_len - 1]} " + f"{input_bit_ids[output_bit_len - 1]} " + f"{input_bit_ids[2 * output_bit_len - 1]}" + ) return carry_bit_ids + output_bit_ids, constraints @@ -256,55 +284,71 @@ def cp_constraints(self): 'constraint modadd_0_1[14] = (pre_modadd_0_1_1[14] + pre_modadd_0_1_0[14] + carry_modadd_0_1[15]) mod 2;', 'constraint modadd_0_1[15] = (pre_modadd_0_1_1[15] + pre_modadd_0_1_0[15]) mod 2;']) """ - output_size = int(self.output_bit_size) - input_id_links = self.input_id_links output_id_link = self.id - input_bit_positions = self.input_bit_positions num_add = self.description[1] all_inputs = [] - for id_link, bit_positions in zip(input_id_links, input_bit_positions): - all_inputs.extend([f'{id_link}[{position}]' for position in bit_positions]) + for id_link, bit_positions in zip(self.input_id_links, self.input_bit_positions): + all_inputs.extend([f"{id_link}[{position}]" for position in bit_positions]) input_len = len(all_inputs) // num_add cp_declarations = [] cp_constraints = [] for i in range(num_add): - cp_declarations.append(f'array[0..{input_len - 1}] of var 0..1: pre_{output_id_link}_{i};') - cp_constraints.extend([f'constraint pre_{output_id_link}_{i}[{j}] = {all_inputs[i * input_len + j]};' - for j in range(input_len)]) + cp_declarations.append(f"array[0..{input_len - 1}] of var 0..1: pre_{output_id_link}_{i};") + cp_constraints.extend( + [ + f"constraint pre_{output_id_link}_{i}[{j}] = {all_inputs[i * input_len + j]};" + for j in range(input_len) + ] + ) for i in range(num_add, 2 * num_add - 2): - cp_declarations.append(f'array[0..{input_len - 1}] of var 0..1: pre_{output_id_link}_{i};') + cp_declarations.append(f"array[0..{input_len - 1}] of var 0..1: pre_{output_id_link}_{i};") for i in range(num_add - 2): - cp_twoterms(f'pre_{output_id_link}_{num_add - 1}', f'pre_{output_id_link}_{i + 1}', - f'pre_{output_id_link}_{num_add + i}', output_size, - cp_constraints, cp_declarations) - cp_twoterms(f'pre_{output_id_link}_{2 * num_add - 3}', f'pre_{output_id_link}_0', f'{output_id_link}', - output_size, cp_constraints, cp_declarations) + cp_twoterms( + f"pre_{output_id_link}_{num_add - 1}", + f"pre_{output_id_link}_{i + 1}", + f"pre_{output_id_link}_{num_add + i}", + self.output_bit_size, + cp_constraints, + cp_declarations, + ) + cp_twoterms( + f"pre_{output_id_link}_{2 * num_add - 3}", + f"pre_{output_id_link}_0", + f"{output_id_link}", + self.output_bit_size, + cp_constraints, + cp_declarations, + ) return cp_declarations, cp_constraints - def cp_twoterms_xor_differential_probability(self, inp1, inp2, out, inplen, cp_constraints, cp_declarations, c, model): + def cp_twoterms_xor_differential_probability( + self, inp1, inp2, out, inplen, cp_constraints, cp_declarations, c, model + ): if inp1 not in model.modadd_twoterms_mant: - cp_declarations.append(f'array[0..{inplen - 1}] of var 0..1: Shi_{inp1} = LShift({inp1},1);') + cp_declarations.append(f"array[0..{inplen - 1}] of var 0..1: Shi_{inp1} = LShift({inp1},1);") model.modadd_twoterms_mant.append(inp1) if inp2 not in model.modadd_twoterms_mant: - cp_declarations.append(f'array[0..{inplen - 1}] of var 0..1: Shi_{inp2} = LShift({inp2},1);') + cp_declarations.append(f"array[0..{inplen - 1}] of var 0..1: Shi_{inp2} = LShift({inp2},1);") model.modadd_twoterms_mant.append(inp2) if out not in model.modadd_twoterms_mant: - cp_declarations.append(f'array[0..{inplen - 1}] of var 0..1: Shi_{out} = LShift({out},1);') + cp_declarations.append(f"array[0..{inplen - 1}] of var 0..1: Shi_{out} = LShift({out},1);") model.modadd_twoterms_mant.append(out) - cp_declarations.append(f'array[0..{inplen - 1}] of var 0..1: eq_{out} = Eq(Shi_{inp1}, Shi_{inp2}, Shi_{out});') + cp_declarations.append(f"array[0..{inplen - 1}] of var 0..1: eq_{out} = Eq(Shi_{inp1}, Shi_{inp2}, Shi_{out});") cp_constraints.append( - f'constraint forall(j in 0..{inplen - 1})(if eq_{out}[j] = 1 then (sum([{inp1}[j], {inp2}[j], ' - f'{out}[j]]) mod 2) = Shi_{inp2}[j] else true endif) /\\ p[{c}] = {100 * inplen}-100 * sum(eq_{out});') + f"constraint forall(j in 0..{inplen - 1})(if eq_{out}[j] = 1 then (sum([{inp1}[j], {inp2}[j], " + f"{out}[j]]) mod 2) = Shi_{inp2}[j] else true endif) /\\ p[{c}] = {100 * inplen}-100 * sum(eq_{out});" + ) return cp_declarations, cp_constraints def get_bit_based_vectorized_python_code(self, params, convert_output_to_bytes): - return [f' {self.id} = ' - f'bit_vector_MODADD([{",".join(params)} ], {self.description[1]}, {self.output_bit_size})'] + return [ + f" {self.id} = bit_vector_MODADD([{','.join(params)} ], {self.description[1]}, {self.output_bit_size})" + ] def get_byte_based_vectorized_python_code(self, params): - return [f' {self.id} = byte_vector_MODADD({params})'] + return [f" {self.id} = byte_vector_MODADD({params})"] def sat_constraints(self): """ @@ -339,20 +383,22 @@ def sat_constraints(self): 'modadd_0_1_15 rot_0_0_15 -plaintext_31', '-modadd_0_1_15 -rot_0_0_15 -plaintext_31']) """ - _, input_ids = self._generate_input_ids() + input_ids = self._generate_input_ids() output_len, output_ids = self._generate_output_ids() num_of_addenda = self.description[1] # reformat of the in_ids - inputs_ids = [input_ids[i * output_len: (i + 1) * output_len] for i in range(num_of_addenda)] + inputs_ids = [input_ids[i * output_len : (i + 1) * output_len] for i in range(num_of_addenda)] # carries - carries_ids = [[f'carry_{i}_{output_id}' for output_id in output_ids[:-1]] for i in range(num_of_addenda - 1)] + carries_ids = [[f"carry_{i}_{output_id}" for output_id in output_ids[:-1]] for i in range(num_of_addenda - 1)] # reformat of the outputs_ids - outputs_ids = [[f'modadd_output_{i}_{output_id}' for output_id in output_ids] - for i in range(num_of_addenda - 2)] + [output_ids] + outputs_ids = [ + [f"modadd_output_{i}_{output_id}" for output_id in output_ids] for i in range(num_of_addenda - 2) + ] + [output_ids] constraints = sat_modadd_seq(outputs_ids, inputs_ids, carries_ids) # flattening lists ids = [carry_id for carry_ids in carries_ids for carry_id in carry_ids] ids.extend([output_id for output_ids in outputs_ids for output_id in output_ids]) + return ids, constraints def smt_constraints(self): @@ -384,18 +430,20 @@ def smt_constraints(self): '(assert (= modadd_0_1_30 (xor shift_0_0_30 key_30 carry_0_modadd_0_1_30)))', '(assert (= modadd_0_1_31 (xor shift_0_0_31 key_31)))']) """ - _, input_ids = self._generate_input_ids() + input_ids = self._generate_input_ids() output_len, output_ids = self._generate_output_ids() num_of_addenda = self.description[1] # reformat of the in_ids - inputs_ids = [input_ids[i * output_len: (i + 1) * output_len] for i in range(num_of_addenda)] + inputs_ids = [input_ids[i * output_len : (i + 1) * output_len] for i in range(num_of_addenda)] # carries - carries_ids = [[f'carry_{i}_{output_id}' for output_id in output_ids[:-1]] for i in range(num_of_addenda - 1)] + carries_ids = [[f"carry_{i}_{output_id}" for output_id in output_ids[:-1]] for i in range(num_of_addenda - 1)] # reformat of the outputs_ids - outputs_ids = [[f'modadd_output_{i}_{output_id}' for output_id in output_ids] - for i in range(num_of_addenda - 2)] + [output_ids] + outputs_ids = [ + [f"modadd_output_{i}_{output_id}" for output_id in output_ids] for i in range(num_of_addenda - 2) + ] + [output_ids] constraints = smt_modadd_seq(outputs_ids, inputs_ids, carries_ids) # flattening lists ids = [carry_id for carry_ids in carries_ids for carry_id in carry_ids] ids.extend([output_id for output_ids in outputs_ids for output_id in output_ids]) + return ids, constraints diff --git a/claasp/components/modsub_component.py b/claasp/components/modsub_component.py index 1c2f795cb..826540a7f 100644 --- a/claasp/components/modsub_component.py +++ b/claasp/components/modsub_component.py @@ -1,16 +1,16 @@ # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -22,21 +22,35 @@ def cp_twoterms(input_1, input_2, out, component_name, input_length, cp_constraints, cp_declarations): - cp_declarations.append(f'array[0..{input_length - 1}] of var 0..1:pre_minus_{input_2};') - cp_declarations.append(f'array[0..{input_length - 1}] of var 0..1:minus_{input_2};') + cp_declarations.append(f"array[0..{input_length - 1}] of var 0..1:pre_minus_{input_2};") + cp_declarations.append(f"array[0..{input_length - 1}] of var 0..1:minus_{input_2};") for i in range(input_length): - cp_constraints.append(f'constraint pre_minus_{input_2}[{i}]=({input_2}[{i}] + 1) mod 2;') - cp_constraints.append(f'constraint modadd(pre_minus_{input_2}, constant_{component_name}, minus_{input_2});') - cp_constraints.append(f'constraint modadd({input_1},minus_{input_2},{out});') + cp_constraints.append(f"constraint pre_minus_{input_2}[{i}]=({input_2}[{i}] + 1) mod 2;") + cp_constraints.append(f"constraint modadd(pre_minus_{input_2}, constant_{component_name}, minus_{input_2});") + cp_constraints.append(f"constraint modadd({input_1},minus_{input_2},{out});") return cp_declarations, cp_constraints class MODSUB(Modular): - def __init__(self, current_round_number, current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size, modulus): - super().__init__(current_round_number, current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size, 'modsub', modulus) + def __init__( + self, + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + modulus, + ): + super().__init__( + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + "modsub", + modulus, + ) def algebraic_polynomials(self, model): """ @@ -73,12 +87,10 @@ def algebraic_polynomials(self, model): ninput_bits = self.input_bit_size noutput_bits = word_size = self.output_bit_size - input_vars = [component_id + "_" + model.input_postfix + str(i) for i in range(ninput_bits)] - output_vars = [component_id + "_" + model.output_postfix + str(i) for i in range(noutput_bits)] - borrows_vars = [[component_id + "_" + "b" + str(n) + "_" + str(i) for i in range(word_size)] for n in - range(nsubtractions)] - aux_outputs_vars = [[component_id + "_" + "o" + str(n) + "_" + str(i) for i in range(word_size)] for n in - range(nsubtractions - 1)] + input_vars = [f"{component_id}_{model.input_postfix}{i}" for i in range(ninput_bits)] + output_vars = [f"{component_id}_{model.output_postfix}{i}" for i in range(noutput_bits)] + borrows_vars = [[f"{component_id}_b{n}_{i}" for i in range(word_size)] for n in range(nsubtractions)] + aux_outputs_vars = [[f"{component_id}_o{n}_{i}" for i in range(word_size)] for n in range(nsubtractions - 1)] ring_R = model.ring() input_vars = list(map(ring_R, input_vars)) @@ -101,7 +113,7 @@ def borrow_polynomial(xi, yi, bi): else: z = aux_outputs_vars[n] - y = input_vars[(n + 1) * word_size: (n + 2) * word_size] + y = input_vars[(n + 1) * word_size : (n + 2) * word_size] b = borrows_vars[n] polynomials += [b[0] + 0] @@ -170,55 +182,80 @@ def cp_constraints(self): 'constraint modadd(pre_minus_pre_modsub_0_7_1, constant_modsub_0_7, minus_pre_modsub_0_7_1);', 'constraint modadd(pre_modsub_0_7_0,minus_pre_modsub_0_7_1,modsub_0_7);']) """ - output_size = int(self.output_bit_size) - input_id_link = self.input_id_links - numb_of_inp = len(input_id_link) + output_size = self.output_bit_size output_id_link = self.id - input_bit_positions = self.input_bit_positions cp_declarations = [] cp_constraints = [] num_add = self.description[1] all_inputs = [] - for i in range(numb_of_inp): - for j in range(len(input_bit_positions[i])): - all_inputs.append(f'{input_id_link[i]}[{input_bit_positions[i][j]}]') + for i, input_id_link in enumerate(self.input_id_links): + for position in self.input_bit_positions[i]: + all_inputs.append(f"{input_id_link}[{position}]") total_input_len = len(all_inputs) input_len = total_input_len // num_add - new_declaration = f'array[0..{output_size - 1}] of var 0..1: constant_{output_id_link}= ' \ - f'array1d(0..{output_size - 1},[' + new_declaration = ( + f"array[0..{output_size - 1}] of var 0..1: constant_{output_id_link}= array1d(0..{output_size - 1},[" + ) for i in range(output_size - 1): - new_declaration = new_declaration + '0, ' - new_declaration = new_declaration + '1]);' + new_declaration = new_declaration + "0, " + new_declaration = new_declaration + "1]);" cp_declarations.append(new_declaration) - cp_declarations.append(f'array[0..{output_size - 1}] of var 0..1: {output_id_link};') + cp_declarations.append(f"array[0..{output_size - 1}] of var 0..1: {output_id_link};") for i in range(num_add): - cp_declarations.append(f'array[0..{input_len - 1}] of var 0..1:pre_{output_id_link}_{i};') + cp_declarations.append(f"array[0..{input_len - 1}] of var 0..1:pre_{output_id_link}_{i};") for j in range(input_len): - cp_constraints.append(f'constraint pre_{output_id_link}_{i}[{j}]={all_inputs[i * input_len + j]};') + cp_constraints.append(f"constraint pre_{output_id_link}_{i}[{j}]={all_inputs[i * input_len + j]};") for i in range(num_add - 2): - cp_declarations.append(f'array[0..{output_size - 1}] of var 0..1:temp_{output_id_link}_{i};') + cp_declarations.append(f"array[0..{output_size - 1}] of var 0..1:temp_{output_id_link}_{i};") if num_add == 2: - cp_twoterms(f'pre_{output_id_link}_0', f'pre_{output_id_link}_1', str(output_id_link), - str(output_id_link), output_size, cp_constraints, cp_declarations) + cp_twoterms( + f"pre_{output_id_link}_0", + f"pre_{output_id_link}_1", + str(output_id_link), + str(output_id_link), + output_size, + cp_constraints, + cp_declarations, + ) elif num_add > 2: - cp_twoterms(f'pre_{output_id_link}_0', f'pre_{output_id_link}_1', f'temp_{output_id_link}_0', - str(output_id_link), output_size, cp_constraints, cp_declarations) + cp_twoterms( + f"pre_{output_id_link}_0", + f"pre_{output_id_link}_1", + f"temp_{output_id_link}_0", + str(output_id_link), + output_size, + cp_constraints, + cp_declarations, + ) for i in range(1, num_add - 2): - cp_twoterms(f'pre_{output_id_link}_{i + 1}', f'temp_{output_id_link}_{i - 1}', - f'temp_{output_id_link}_{i}', str(output_id_link), output_size, cp_constraints, - cp_declarations) - cp_twoterms(f'pre_{output_id_link}_{num_add - 1}', f'temp_{output_id_link}_{num_add - 3}', - str(output_id_link), str(output_id_link), output_size, cp_constraints, - cp_declarations) + cp_twoterms( + f"pre_{output_id_link}_{i + 1}", + f"temp_{output_id_link}_{i - 1}", + f"temp_{output_id_link}_{i}", + str(output_id_link), + output_size, + cp_constraints, + cp_declarations, + ) + cp_twoterms( + f"pre_{output_id_link}_{num_add - 1}", + f"temp_{output_id_link}_{num_add - 3}", + str(output_id_link), + str(output_id_link), + output_size, + cp_constraints, + cp_declarations, + ) return cp_declarations, cp_constraints def get_bit_based_vectorized_python_code(self, params, convert_output_to_bytes): - return [f' {self.id} = bit_vector_MODSUB([{",".join(params)} ], ' - f'{self.description[1]}, {self.output_bit_size})'] + return [ + f" {self.id} = bit_vector_MODSUB([{','.join(params)} ], {self.description[1]}, {self.output_bit_size})" + ] def get_byte_based_vectorized_python_code(self, params): - return [f' {self.id} = byte_vector_MODSUB({params})'] + return [f" {self.id} = byte_vector_MODSUB({params})"] def sat_constraints(self): """ @@ -255,44 +292,56 @@ def sat_constraints(self): 'modsub_0_7_31 modadd_0_4_31 -temp_input_plaintext_63', '-modsub_0_7_31 -modadd_0_4_31 -temp_input_plaintext_63']) """ - _, input_bit_ids = self._generate_input_ids() + input_bit_ids = self._generate_input_ids() output_bit_len, output_bit_ids = self._generate_output_ids() - temp_carry_bit_ids = [f'temp_carry_{input_bit_ids[output_bit_len + i]}' for i in range(output_bit_len - 1)] - temp_input_bit_ids = [f'temp_input_{input_bit_ids[output_bit_len + i]}' for i in range(output_bit_len)] - carry_bit_ids = [f'carry_{output_bit_ids[i]}' for i in range(output_bit_len - 1)] + temp_carry_bit_ids = [f"temp_carry_{input_bit_ids[output_bit_len + i]}" for i in range(output_bit_len - 1)] + temp_input_bit_ids = [f"temp_input_{input_bit_ids[output_bit_len + i]}" for i in range(output_bit_len)] + carry_bit_ids = [f"carry_{output_bit_ids[i]}" for i in range(output_bit_len - 1)] constraints = [] # carries complement 2 for i in range(output_bit_len - 2): - constraints.extend(sat_utils.cnf_carry_comp2(temp_carry_bit_ids[i], - input_bit_ids[output_bit_len + i + 1], - temp_carry_bit_ids[i + 1])) - constraints.extend(sat_utils.cnf_inequality(temp_carry_bit_ids[output_bit_len - 2], - input_bit_ids[2 * output_bit_len - 1])) + constraints.extend( + sat_utils.cnf_carry_comp2( + temp_carry_bit_ids[i], input_bit_ids[output_bit_len + i + 1], temp_carry_bit_ids[i + 1] + ) + ) + constraints.extend( + sat_utils.cnf_inequality(temp_carry_bit_ids[output_bit_len - 2], input_bit_ids[2 * output_bit_len - 1]) + ) # results complement 2 for i in range(output_bit_len - 1): - constraints.extend(sat_utils.cnf_result_comp2(temp_input_bit_ids[i], - input_bit_ids[output_bit_len + i], - temp_carry_bit_ids[i])) - constraints.extend(sat_utils.cnf_equivalent([temp_input_bit_ids[output_bit_len - 1], - input_bit_ids[2 * output_bit_len - 1]])) + constraints.extend( + sat_utils.cnf_result_comp2( + temp_input_bit_ids[i], input_bit_ids[output_bit_len + i], temp_carry_bit_ids[i] + ) + ) + constraints.extend( + sat_utils.cnf_equivalent([temp_input_bit_ids[output_bit_len - 1], input_bit_ids[2 * output_bit_len - 1]]) + ) # carries for i in range(output_bit_len - 2): - constraints.extend(sat_utils.cnf_carry(carry_bit_ids[i], - input_bit_ids[i + 1], - temp_input_bit_ids[i + 1], - carry_bit_ids[i + 1])) - constraints.extend(sat_utils.cnf_and(carry_bit_ids[output_bit_len - 2], - (input_bit_ids[output_bit_len - 1], - temp_input_bit_ids[output_bit_len - 1]))) + constraints.extend( + sat_utils.cnf_carry( + carry_bit_ids[i], input_bit_ids[i + 1], temp_input_bit_ids[i + 1], carry_bit_ids[i + 1] + ) + ) + constraints.extend( + sat_utils.cnf_and( + carry_bit_ids[output_bit_len - 2], + (input_bit_ids[output_bit_len - 1], temp_input_bit_ids[output_bit_len - 1]), + ) + ) # results for i in range(output_bit_len - 1): - constraints.extend(sat_utils.cnf_xor(output_bit_ids[i], - [input_bit_ids[i], - temp_input_bit_ids[i], - carry_bit_ids[i]])) - constraints.extend(sat_utils.cnf_xor(output_bit_ids[output_bit_len - 1], - [input_bit_ids[output_bit_len - 1], - temp_input_bit_ids[output_bit_len - 1]])) + constraints.extend( + sat_utils.cnf_xor(output_bit_ids[i], [input_bit_ids[i], temp_input_bit_ids[i], carry_bit_ids[i]]) + ) + constraints.extend( + sat_utils.cnf_xor( + output_bit_ids[output_bit_len - 1], + [input_bit_ids[output_bit_len - 1], temp_input_bit_ids[output_bit_len - 1]], + ) + ) return temp_carry_bit_ids + temp_input_bit_ids + carry_bit_ids + output_bit_ids, constraints @@ -327,31 +376,33 @@ def smt_constraints(self): '(assert (= modsub_0_7_30 (xor modadd_0_4_30 temp_input_plaintext_62 carry_modsub_0_7_30)))', '(assert (= modsub_0_7_31 (xor modadd_0_4_31 temp_input_plaintext_63)))']) """ - _, input_bit_ids = self._generate_input_ids() + input_bit_ids = self._generate_input_ids() output_bit_len, output_bit_ids = self._generate_output_ids() - temp_carry_bit_ids = [f'temp_carry_{input_bit_ids[output_bit_len + i]}' for i in range(output_bit_len - 1)] - temp_input_bit_ids = [f'temp_input_{input_bit_ids[output_bit_len + i]}' for i in range(output_bit_len)] - carry_bit_ids = [f'carry_{output_bit_ids[i]}' for i in range(output_bit_len - 1)] + temp_carry_bit_ids = [f"temp_carry_{input_bit_ids[output_bit_len + i]}" for i in range(output_bit_len - 1)] + temp_input_bit_ids = [f"temp_input_{input_bit_ids[output_bit_len + i]}" for i in range(output_bit_len)] + carry_bit_ids = [f"carry_{output_bit_ids[i]}" for i in range(output_bit_len - 1)] constraints = [] # carries complement 2 for i in range(output_bit_len - 2): - operation = smt_utils.smt_and((smt_utils.smt_not(input_bit_ids[output_bit_len + i + 1]), - temp_carry_bit_ids[i + 1])) + operation = smt_utils.smt_and( + (smt_utils.smt_not(input_bit_ids[output_bit_len + i + 1]), temp_carry_bit_ids[i + 1]) + ) equation = smt_utils.smt_equivalent((temp_carry_bit_ids[i], operation)) constraints.append(smt_utils.smt_assert(equation)) - distinction = smt_utils.smt_distinct(temp_carry_bit_ids[output_bit_len - 2], - input_bit_ids[2 * output_bit_len - 1]) + distinction = smt_utils.smt_distinct( + temp_carry_bit_ids[output_bit_len - 2], input_bit_ids[2 * output_bit_len - 1] + ) constraints.append(smt_utils.smt_assert(distinction)) # results complement 2 for i in range(output_bit_len - 1): - operation = smt_utils.smt_xor((smt_utils.smt_not(input_bit_ids[output_bit_len + i]), - temp_carry_bit_ids[i])) + operation = smt_utils.smt_xor((smt_utils.smt_not(input_bit_ids[output_bit_len + i]), temp_carry_bit_ids[i])) equation = smt_utils.smt_equivalent((temp_input_bit_ids[i], operation)) constraints.append(smt_utils.smt_assert(equation)) - equation = smt_utils.smt_equivalent((temp_input_bit_ids[output_bit_len - 1], - input_bit_ids[2 * output_bit_len - 1])) + equation = smt_utils.smt_equivalent( + (temp_input_bit_ids[output_bit_len - 1], input_bit_ids[2 * output_bit_len - 1]) + ) constraints.append(smt_utils.smt_assert(equation)) # carries diff --git a/claasp/components/modular_component.py b/claasp/components/modular_component.py index 7c2a3bfdf..cf4258674 100644 --- a/claasp/components/modular_component.py +++ b/claasp/components/modular_component.py @@ -1,16 +1,16 @@ # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -21,6 +21,7 @@ from claasp.cipher_modules.models.smt.utils import utils as smt_utils from claasp.component import Component from claasp.input import Input +from claasp.name_mappings import WORD_OPERATION def milp_n_window_heuristic(input_vars, output_vars, component_id, window_size, mip, x): @@ -30,15 +31,14 @@ def create_window_size_array(j, input_1_vars, input_2_vars, output_vars): temp_vars = x[input_1_vars[j - i]] + x[input_2_vars[j - i]] + x[output_vars[j - i]] mod_add_var = mip.new_variable(name="mod") mip.set_max(mod_add_var, 1) - u = mip.new_variable(name='u') - mip.add_constraint(temp_vars == 2 * u['u' + component_id + str(j) + - str(i)] + mod_add_var["mod" + component_id + str(j) + str(i)]) + u = mip.new_variable(name="u") + mip.add_constraint(temp_vars == 2 * u[f"u{component_id}{j}{i}"] + mod_add_var[f"mod{component_id}{j}{i}"]) temp_array.append(mod_add_var["mod" + component_id + str(j) + str(i)]) return temp_array input_size = int(len(input_vars) / 2) input_1_vars = input_vars[:input_size] - input_2_vars = input_vars[input_size:2 * input_size] + input_2_vars = input_vars[input_size : 2 * input_size] for j in range(window_size, input_size - 1): window_size_array = create_window_size_array(j, input_1_vars, input_2_vars, output_vars) mip.add_constraint(mip.sum(window_size_array) <= int(window_size)) @@ -71,11 +71,18 @@ def generic_sign_linear_constraints(inputs, outputs): class Modular(Component): - def __init__(self, current_round_number, current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size, operation, modulus): - - component_id = f'{operation}_{current_round_number}_{current_round_number_of_components}' - component_type = 'word_operation' + def __init__( + self, + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + operation, + modulus, + ): + component_id = f"{operation}_{current_round_number}_{current_round_number_of_components}" + component_type = WORD_OPERATION input_len = 0 for bits in input_bit_positions: input_len = input_len + len(bits) @@ -120,21 +127,20 @@ def cms_xor_linear_mask_propagation_constraints(self, model=None): _, input_bit_ids = self._generate_component_input_ids() out_suffix = constants.OUTPUT_BIT_ID_SUFFIX output_bit_len, output_bit_ids = self._generate_output_ids(suffix=out_suffix) - hw_bit_ids = [f'hw_{output_bit_ids[i]}' for i in range(output_bit_len)] - constraints = [f'-{hw_bit_ids[0]}'] - constraints.append(f'x -{hw_bit_ids[1]} {output_bit_ids[0]} ' - f'{input_bit_ids[0]} {input_bit_ids[output_bit_len]}') + hw_bit_ids = [f"hw_{output_bit_ids[i]}" for i in range(output_bit_len)] + constraints = [f"-{hw_bit_ids[0]}"] + constraints.append(f"x -{hw_bit_ids[1]} {output_bit_ids[0]} {input_bit_ids[0]} {input_bit_ids[output_bit_len]}") for i in range(2, output_bit_len): - constraints.append(f'x -{hw_bit_ids[i]} {hw_bit_ids[i - 1]} {output_bit_ids[i - 1]} ' - f'{input_bit_ids[i - 1]} {input_bit_ids[output_bit_len + i - 1]}') + constraints.append( + f"x -{hw_bit_ids[i]} {hw_bit_ids[i - 1]} {output_bit_ids[i - 1]} " + f"{input_bit_ids[i - 1]} {input_bit_ids[output_bit_len + i - 1]}" + ) for i in range(output_bit_len): - constraints.extend(sat_utils.cnf_modadd_inequality(hw_bit_ids[i], - output_bit_ids[i], - input_bit_ids[i])) + constraints.extend(sat_utils.cnf_modadd_inequality(hw_bit_ids[i], output_bit_ids[i], input_bit_ids[i])) for i in range(output_bit_len): - constraints.extend(sat_utils.cnf_modadd_inequality(hw_bit_ids[i], - output_bit_ids[i], - input_bit_ids[output_bit_len + i])) + constraints.extend( + sat_utils.cnf_modadd_inequality(hw_bit_ids[i], output_bit_ids[i], input_bit_ids[output_bit_len + i]) + ) result = input_bit_ids + output_bit_ids + hw_bit_ids, constraints return result @@ -159,49 +165,59 @@ def cp_deterministic_truncated_xor_differential_constraints(self): 'constraint pre_modadd_0_1_1[15] = plaintext[31];', 'constraint modular_addition_word(pre_modadd_0_1_1, pre_modadd_0_1_0, modadd_0_1);']) """ - input_id_links = self.input_id_links output_id_link = self.id - input_bit_positions = self.input_bit_positions num_add = self.description[1] all_inputs = [] - for id_link, bit_positions in zip(input_id_links, input_bit_positions): - all_inputs.extend([f'{id_link}[{position}]' for position in bit_positions]) + for id_link, bit_positions in zip(self.input_id_links, self.input_bit_positions): + all_inputs.extend([f"{id_link}[{position}]" for position in bit_positions]) input_len = len(all_inputs) // num_add cp_declarations = [] cp_constraints = [] for i in range(num_add): - cp_declarations.append(f'array[0..{input_len - 1}] of var 0..2: pre_{output_id_link}_{i};') - cp_constraints.extend([f'constraint pre_{output_id_link}_{i}[{j}] = {all_inputs[i * input_len + j]};' - for j in range(input_len)]) + cp_declarations.append(f"array[0..{input_len - 1}] of var 0..2: pre_{output_id_link}_{i};") + cp_constraints.extend( + [ + f"constraint pre_{output_id_link}_{i}[{j}] = {all_inputs[i * input_len + j]};" + for j in range(input_len) + ] + ) for i in range(num_add, 2 * num_add - 2): - cp_declarations.append(f'array[0..{input_len - 1}] of var 0..1: pre_{output_id_link}_{i};') + cp_declarations.append(f"array[0..{input_len - 1}] of var 0..1: pre_{output_id_link}_{i};") for i in range(num_add - 2): - cp_constraints.append(f'constraint modular_addition_word(pre_{output_id_link}_{num_add - 1}, ' - f'pre_{output_id_link}_{i + 1}, pre_{output_id_link}_{num_add + i});') - cp_constraints.append(f'constraint modular_addition_word(pre_{output_id_link}_{2 * num_add - 3}, ' - f'pre_{output_id_link}_0, {output_id_link});') + cp_constraints.append( + f"constraint modular_addition_word(pre_{output_id_link}_{num_add - 1}, " + f"pre_{output_id_link}_{i + 1}, pre_{output_id_link}_{num_add + i});" + ) + cp_constraints.append( + f"constraint modular_addition_word(pre_{output_id_link}_{2 * num_add - 3}, " + f"pre_{output_id_link}_0, {output_id_link});" + ) return cp_declarations, cp_constraints def cp_deterministic_truncated_xor_differential_trail_constraints(self): return self.cp_deterministic_truncated_xor_differential_constraints() - def cp_twoterms_xor_differential_probability(self, input_1, input_2, out, input_length, - cp_constraints, cp_declarations, c, model): + def cp_twoterms_xor_differential_probability( + self, input_1, input_2, out, input_length, cp_constraints, cp_declarations, c, model + ): if input_1 not in model.modadd_twoterms_mant: - cp_declarations.append(f'array[0..{input_length - 1}] of var 0..1: Shi_{input_1} = LShift({input_1},1);') + cp_declarations.append(f"array[0..{input_length - 1}] of var 0..1: Shi_{input_1} = LShift({input_1},1);") model.modadd_twoterms_mant.append(input_1) if input_2 not in model.modadd_twoterms_mant: - cp_declarations.append(f'array[0..{input_length - 1}] of var 0..1: Shi_{input_2} = LShift({input_2},1);') + cp_declarations.append(f"array[0..{input_length - 1}] of var 0..1: Shi_{input_2} = LShift({input_2},1);") model.modadd_twoterms_mant.append(input_2) if out not in model.modadd_twoterms_mant: - cp_declarations.append(f'array[0..{input_length - 1}] of var 0..1: Shi_{out} = LShift({out},1);') + cp_declarations.append(f"array[0..{input_length - 1}] of var 0..1: Shi_{out} = LShift({out},1);") model.modadd_twoterms_mant.append(out) - cp_declarations.append(f'array[0..{input_length - 1}] of var 0..1: eq_{out} = ' - f'Eq(Shi_{input_1}, Shi_{input_2}, Shi_{out});') - cp_constraints.append(f'constraint forall(j in 0..{input_length - 1})(if eq_{out}[j] = ' - f'1 then (sum([{input_1}[j], {input_2}[j], {out}[j]]) mod 2) = Shi_{input_2}[j] else ' - f'true endif) /\\ p[{c}] = {input_length}-sum(eq_{out});') + cp_declarations.append( + f"array[0..{input_length - 1}] of var 0..1: eq_{out} = Eq(Shi_{input_1}, Shi_{input_2}, Shi_{out});" + ) + cp_constraints.append( + f"constraint forall(j in 0..{input_length - 1})(if eq_{out}[j] = " + f"1 then (sum([{input_1}[j], {input_2}[j], {out}[j]]) mod 2) = Shi_{input_2}[j] else " + f"true endif) /\\ p[{c}] = {input_length}-sum(eq_{out});" + ) return cp_declarations, cp_constraints @@ -228,35 +244,43 @@ def cp_wordwise_deterministic_truncated_xor_differential_constraints(self, model ... 'constraint if xor_0_0_temp_0_15_active + xor_0_0_temp_1_15_active > 2 then xor_0_0_active[15] == 3 /\\ xor_0_0_value[15] = -2 elseif xor_0_0_temp_0_15_active + xor_0_0_temp_1_15_active == 1 then xor_0_0_active[15] = 1 /\\ xor_0_0_value[15] = xor_0_0_temp_0_15_value + xor_0_0_temp_1_15_value elseif xor_0_0_temp_0_15_active + xor_0_0_temp_1_15_active == 0 then xor_0_0_active[15] = 0 /\\ xor_0_0_value[15] = 0 elseif xor_0_0_temp_0_15_value + xor_0_0_temp_1_15_value < 0 then xor_0_0_active[15] = 2 /\\ xor_0_0_value[15] = -1 elseif xor_0_0_temp_0_15_value == xor_0_0_temp_1_15_value then xor_0_0_active[15] = 0 /\\ xor_0_0_value[15] = 0 else xor_0_0_active[15] = 1 /\\ xor_0_0_value[15] = sum([(((floor(xor_0_0_temp_0_15_value/pow(2,j)) + floor(xor_0_0_temp_1_15_value/pow(2,j))) mod 2) * pow(2,j)) | j in 0..xor_0_0_bound_value_0_15]) endif;']) """ - input_id_links = self.input_id_links output_id_link = self.id - input_bit_positions = self.input_bit_positions cp_declarations = [] all_inputs_value = [] all_inputs_active = [] numadd = self.description[1] word_size = model.word_size - for id_link, bit_positions in zip(input_id_links, input_bit_positions): - all_inputs_value.extend([f'{id_link}_value[{bit_positions[j * word_size] // word_size}]' - for j in range(len(bit_positions) // word_size)]) - all_inputs_active.extend([f'{id_link}_active[{bit_positions[j * word_size] // word_size}]' - for j in range(len(bit_positions) // word_size)]) + for id_link, bit_positions in zip(self.input_id_links, self.input_bit_positions): + all_inputs_value.extend( + [ + f"{id_link}_value[{bit_positions[j * word_size] // word_size}]" + for j in range(len(bit_positions) // word_size) + ] + ) + all_inputs_active.extend( + [ + f"{id_link}_active[{bit_positions[j * word_size] // word_size}]" + for j in range(len(bit_positions) // word_size) + ] + ) input_len = len(all_inputs_value) // numadd cp_constraints = [] - cp_declarations.append(f'array[0..{input_len}] of var 0..2: carry_{output_id_link};') - cp_constraints.append(f'constraint carry_{output_id_link}[0] = 0;') + cp_declarations.append(f"array[0..{input_len}] of var 0..2: carry_{output_id_link};") + cp_constraints.append(f"constraint carry_{output_id_link}[0] = 0;") for i in range(input_len): - new_constraint = f'constraint if ' - operation = f' == 0 /\\ '.join(all_inputs_active[i::num_add]) + new_constraint = "constraint if " + operation = " == 0 /\\ ".join(all_inputs_active[i :: self.description[1]]) new_constraint += operation - new_constraint += f' == 0 /\\ carry_{output_id_link}[{i}] == 0 then {output_id_link}_active[{i}] = 0 /\\ {output_id_link}_value[{i}] = 0 /\\ carry_{output_id_link}[{i + 1}] = 0 else ' \ - f'{output_id_link}_active[{i}] = 3 /\\ {output_id_link}_value[{i}] = -2 /\\ carry_{output_id_link}[{i + 1}] = 3 endif;' + new_constraint += ( + f" == 0 /\\ carry_{output_id_link}[{i}] == 0 then {output_id_link}_active[{i}] = 0 /\\ {output_id_link}_value[{i}] = 0 /\\ carry_{output_id_link}[{i + 1}] = 0 else " + f"{output_id_link}_active[{i}] = 3 /\\ {output_id_link}_value[{i}] = -2 /\\ carry_{output_id_link}[{i + 1}] = 3 endif;" + ) cp_constraints.append(new_constraint) return cp_declarations, cp_constraints def cp_xor_differential_propagation_constraints(self, model): - r""" + """ Return lists of declarations and constraints for the probability of Modular Addition/Substraction component for CP xor differential probability. INPUT: @@ -279,39 +303,53 @@ def cp_xor_differential_propagation_constraints(self, model): 'constraint pre_modadd_0_1_1[15] = plaintext[31];', 'constraint forall(j in 0..15)(if eq_modadd_0_1[j] = 1 then (sum([pre_modadd_0_1_1[j], pre_modadd_0_1_0[j], modadd_0_1[j]]) mod 2) = Shi_pre_modadd_0_1_0[j] else true endif) /\\ p[0] = 1600-100 * sum(eq_modadd_0_1);']) """ - output_size = int(self.output_bit_size) - input_id_links = self.input_id_links output_id_link = self.id - input_bit_positions = self.input_bit_positions num_add = self.description[1] all_inputs = [] - for id_link, bit_positions in zip(input_id_links, input_bit_positions): - all_inputs.extend([f'{id_link}[{position}]' for position in bit_positions]) + for id_link, bit_positions in zip(self.input_id_links, self.input_bit_positions): + all_inputs.extend([f"{id_link}[{position}]" for position in bit_positions]) input_len = len(all_inputs) // num_add cp_declarations = [] cp_constraints = [] for i in range(num_add): - cp_declarations.append(f'array[0..{input_len - 1}] of var 0..1: pre_{output_id_link}_{i};') - cp_constraints.extend([f'constraint pre_{output_id_link}_{i}[{j}] = {all_inputs[i * input_len + j]};' - for j in range(input_len)]) + cp_declarations.append(f"array[0..{input_len - 1}] of var 0..1: pre_{output_id_link}_{i};") + cp_constraints.extend( + [ + f"constraint pre_{output_id_link}_{i}[{j}] = {all_inputs[i * input_len + j]};" + for j in range(input_len) + ] + ) for i in range(num_add, 2 * num_add - 2): - cp_declarations.append(f'array[0..{input_len - 1}] of var 0..1: pre_{output_id_link}_{i};') + cp_declarations.append(f"array[0..{input_len - 1}] of var 0..1: pre_{output_id_link}_{i};") probability = [] for i in range(num_add - 2): - self.cp_twoterms_xor_differential_probability(f'pre_{output_id_link}_{num_add - 1}', - f'pre_{output_id_link}_{i + 1}', - f'pre_{output_id_link}_{num_add + i}', output_size, - cp_constraints, cp_declarations, model.c, model) + self.cp_twoterms_xor_differential_probability( + f"pre_{output_id_link}_{num_add - 1}", + f"pre_{output_id_link}_{i + 1}", + f"pre_{output_id_link}_{num_add + i}", + self.output_bit_size, + cp_constraints, + cp_declarations, + model.c, + model, + ) probability.append(model.c) model.c += 1 - self.cp_twoterms_xor_differential_probability(f'pre_{output_id_link}_{2 * num_add - 3}', - f'pre_{output_id_link}_0', f'{output_id_link}', - output_size, cp_constraints, cp_declarations, model.c, model) + self.cp_twoterms_xor_differential_probability( + f"pre_{output_id_link}_{2 * num_add - 3}", + f"pre_{output_id_link}_0", + f"{output_id_link}", + self.output_bit_size, + cp_constraints, + cp_declarations, + model.c, + model, + ) probability.append(model.c) model.c += 1 model.component_and_probability[output_id_link] = probability - result = cp_declarations, cp_constraints - return result + + return cp_declarations, cp_constraints def cp_xor_differential_propagation_constraints_arx_optimized(self, model): """ @@ -332,7 +370,6 @@ def cp_xor_differential_propagation_constraints_arx_optimized(self, model): sage: constraints[6] 'constraint pre_modadd_1_9_1[0] = sbox_1_0[0];' """ - output_size = int(self.output_bit_size) input_id_links = self.input_id_links output_id_link = self.id input_bit_positions = self.input_bit_positions @@ -343,65 +380,86 @@ def cp_xor_differential_propagation_constraints_arx_optimized(self, model): all_inputs = [] for id_link, bit_positions in zip(input_id_links, input_bit_positions): - all_inputs.extend([f'{id_link}[{position}]' for position in bit_positions]) + all_inputs.extend([f"{id_link}[{position}]" for position in bit_positions]) input_len = len(all_inputs) // num_add cp_declarations = [] cp_constraints = [] - cp_declarations.append(f'array[0..{input_len - 1}] of var 0..1: dummy_{output_id_link};') - cp_declarations.append(f'array[0..{input_len - 1}] of var 0..1: x1_{output_id_link};') - cp_declarations.append(f'array[0..{input_len - 1}] of var 0..1: x2_{output_id_link};') + cp_declarations.append(f"array[0..{input_len - 1}] of var 0..1: dummy_{output_id_link};") + cp_declarations.append(f"array[0..{input_len - 1}] of var 0..1: x1_{output_id_link};") + cp_declarations.append(f"array[0..{input_len - 1}] of var 0..1: x2_{output_id_link};") for i in range(input_len): cp_constraints.append( - f'constraint x1_{output_id_link}[{i}] = {input_id_links[0]}[{input_bit_positions[0][i]}];') + f"constraint x1_{output_id_link}[{i}] = {input_id_links[0]}[{input_bit_positions[0][i]}];" + ) cp_constraints.append( - f'constraint x2_{output_id_link}[{i}] = {input_id_links[1]}[{input_bit_positions[1][i]}];') + f"constraint x2_{output_id_link}[{i}] = {input_id_links[1]}[{input_bit_positions[1][i]}];" + ) cp_constraints.append( - f'constraint x2_{output_id_link}[{input_len - 1}] + x1_{output_id_link}[{input_len - 1}] + {output_id_link}[{input_len - 1}] <= 2;') + f"constraint x2_{output_id_link}[{input_len - 1}] + x1_{output_id_link}[{input_len - 1}] + {output_id_link}[{input_len - 1}] <= 2;" + ) cp_constraints.append( - f'constraint x2_{output_id_link}[{input_len - 1}] + x1_{output_id_link}[{input_len - 1}] + {output_id_link}[{input_len - 1}] - 2*dummy_{output_id_link}[{input_len - 1}] >= 0;') + f"constraint x2_{output_id_link}[{input_len - 1}] + x1_{output_id_link}[{input_len - 1}] + {output_id_link}[{input_len - 1}] - 2*dummy_{output_id_link}[{input_len - 1}] >= 0;" + ) cp_constraints.append( - f'constraint dummy_{output_id_link}[{input_len - 1}] - x2_{output_id_link}[{input_len - 1}] >= 0;') + f"constraint dummy_{output_id_link}[{input_len - 1}] - x2_{output_id_link}[{input_len - 1}] >= 0;" + ) cp_constraints.append( - f'constraint dummy_{output_id_link}[{input_len - 1}] - x1_{output_id_link}[{input_len - 1}] >= 0;') + f"constraint dummy_{output_id_link}[{input_len - 1}] - x1_{output_id_link}[{input_len - 1}] >= 0;" + ) cp_constraints.append( - f'constraint dummy_{output_id_link}[{input_len - 1}] - {output_id_link}[{input_len - 1}] >= 0;') + f"constraint dummy_{output_id_link}[{input_len - 1}] - {output_id_link}[{input_len - 1}] >= 0;" + ) for i in range(input_len - 1): cp_constraints.append( - f'constraint x1_{output_id_link}[{i + 1}] - {output_id_link}[{i + 1}] + dummy_{output_id_link}[{i}] >= 0;') + f"constraint x1_{output_id_link}[{i + 1}] - {output_id_link}[{i + 1}] + dummy_{output_id_link}[{i}] >= 0;" + ) cp_constraints.append( - f'constraint x2_{output_id_link}[{i + 1}] - x1_{output_id_link}[{i + 1}] + dummy_{output_id_link}[{i}] >= 0;') + f"constraint x2_{output_id_link}[{i + 1}] - x1_{output_id_link}[{i + 1}] + dummy_{output_id_link}[{i}] >= 0;" + ) cp_constraints.append( - f'constraint {output_id_link}[{i + 1}] - x2_{output_id_link}[{i + 1}] + dummy_{output_id_link}[{i}] >= 0;') + f"constraint {output_id_link}[{i + 1}] - x2_{output_id_link}[{i + 1}] + dummy_{output_id_link}[{i}] >= 0;" + ) cp_constraints.append( - f'constraint x2_{output_id_link}[{i + 1}] + x1_{output_id_link}[{i + 1}] + {output_id_link}[{i + 1}] + dummy_{output_id_link}[{i}] <= 3;') + f"constraint x2_{output_id_link}[{i + 1}] + x1_{output_id_link}[{i + 1}] + {output_id_link}[{i + 1}] + dummy_{output_id_link}[{i}] <= 3;" + ) cp_constraints.append( - f'constraint x2_{output_id_link}[{i + 1}] + x1_{output_id_link}[{i + 1}] + {output_id_link}[{i + 1}] - dummy_{output_id_link}[{i}] >= 0;') + f"constraint x2_{output_id_link}[{i + 1}] + x1_{output_id_link}[{i + 1}] + {output_id_link}[{i + 1}] - dummy_{output_id_link}[{i}] >= 0;" + ) cp_constraints.append( - f'constraint - x1_{output_id_link}[{i + 1}] + x2_{output_id_link}[{i}] + x1_{output_id_link}[{i}] + {output_id_link}[{i}] + dummy_{output_id_link}[{i}] >= 0;') + f"constraint - x1_{output_id_link}[{i + 1}] + x2_{output_id_link}[{i}] + x1_{output_id_link}[{i}] + {output_id_link}[{i}] + dummy_{output_id_link}[{i}] >= 0;" + ) cp_constraints.append( - f'constraint x1_{output_id_link}[{i + 1}] + x2_{output_id_link}[{i}] - x1_{output_id_link}[{i}] + {output_id_link}[{i}] + dummy_{output_id_link}[{i}] >= 0;') + f"constraint x1_{output_id_link}[{i + 1}] + x2_{output_id_link}[{i}] - x1_{output_id_link}[{i}] + {output_id_link}[{i}] + dummy_{output_id_link}[{i}] >= 0;" + ) cp_constraints.append( - f'constraint x1_{output_id_link}[{i + 1}] - x2_{output_id_link}[{i}] + x1_{output_id_link}[{i}] + {output_id_link}[{i}] + dummy_{output_id_link}[{i}] >= 0;') + f"constraint x1_{output_id_link}[{i + 1}] - x2_{output_id_link}[{i}] + x1_{output_id_link}[{i}] + {output_id_link}[{i}] + dummy_{output_id_link}[{i}] >= 0;" + ) cp_constraints.append( - f'constraint x2_{output_id_link}[{i + 1}] + x2_{output_id_link}[{i}] + x1_{output_id_link}[{i}] - {output_id_link}[{i}] + dummy_{output_id_link}[{i}] >= 0;') + f"constraint x2_{output_id_link}[{i + 1}] + x2_{output_id_link}[{i}] + x1_{output_id_link}[{i}] - {output_id_link}[{i}] + dummy_{output_id_link}[{i}] >= 0;" + ) cp_constraints.append( - f'constraint {output_id_link}[{i + 1}] - x2_{output_id_link}[{i}] - x1_{output_id_link}[{i}] - {output_id_link}[{i}] + dummy_{output_id_link}[{i}] >= -2;') + f"constraint {output_id_link}[{i + 1}] - x2_{output_id_link}[{i}] - x1_{output_id_link}[{i}] - {output_id_link}[{i}] + dummy_{output_id_link}[{i}] >= -2;" + ) cp_constraints.append( - f'constraint - x1_{output_id_link}[{i + 1}] + x2_{output_id_link}[{i}] - x1_{output_id_link}[{i}] - {output_id_link}[{i}] + dummy_{output_id_link}[{i}] >= -2;') + f"constraint - x1_{output_id_link}[{i + 1}] + x2_{output_id_link}[{i}] - x1_{output_id_link}[{i}] - {output_id_link}[{i}] + dummy_{output_id_link}[{i}] >= -2;" + ) cp_constraints.append( - f'constraint - x1_{output_id_link}[{i + 1}] - x2_{output_id_link}[{i}] + x1_{output_id_link}[{i}] - {output_id_link}[{i}] + dummy_{output_id_link}[{i}] >= -2;') + f"constraint - x1_{output_id_link}[{i + 1}] - x2_{output_id_link}[{i}] + x1_{output_id_link}[{i}] - {output_id_link}[{i}] + dummy_{output_id_link}[{i}] >= -2;" + ) cp_constraints.append( - f'constraint - x1_{output_id_link}[{i + 1}] - x2_{output_id_link}[{i}] - x1_{output_id_link}[{i}] + {output_id_link}[{i}] + dummy_{output_id_link}[{i}] >= -2;') + f"constraint - x1_{output_id_link}[{i + 1}] - x2_{output_id_link}[{i}] - x1_{output_id_link}[{i}] + {output_id_link}[{i}] + dummy_{output_id_link}[{i}] >= -2;" + ) cp_constraints.append( - f'constraint p[{model.c}] = sum([if (x1_{output_id_link}[i+1] = x2_{output_id_link}[i+1]) /\\ (x1_{output_id_link}[i+1] = {output_id_link}[i+1]) then 0 else 100 endif | i in 0..{input_len - 2}]);') + f"constraint p[{model.c}] = sum([if (x1_{output_id_link}[i+1] = x2_{output_id_link}[i+1]) /\\ (x1_{output_id_link}[i+1] = {output_id_link}[i+1]) then 0 else 100 endif | i in 0..{input_len - 2}]);" + ) model.c += 1 - result = cp_declarations, cp_constraints - return result + + return cp_declarations, cp_constraints def cp_xor_linear_mask_propagation_constraints(self, model): """ @@ -425,52 +483,51 @@ def cp_xor_linear_mask_propagation_constraints(self, model): 'constraint pre_modadd_0_1_1[15]=modadd_0_1_i[31];', 'constraint modadd_linear(pre_modadd_0_1_1, pre_modadd_0_1_0, modadd_0_1_o, p[0]);']) """ - input_size = int(self.input_bit_size) - output_size = int(self.output_bit_size) output_id_link = self.id cp_declarations = [] cp_constraints = [] num_add = self.description[1] - input_len = input_size // num_add - cp_declarations.append(f'array[0..{input_size - 1}] of var 0..1: {output_id_link}_i;') - cp_declarations.append(f'array[0..{output_size - 1}] of var 0..1: {output_id_link}_o;') + input_len = self.input_bit_size // num_add + cp_declarations.append(f"array[0..{self.input_bit_size - 1}] of var 0..1: {output_id_link}_i;") + cp_declarations.append(f"array[0..{self.output_bit_size - 1}] of var 0..1: {output_id_link}_o;") probability = [] for i in range(num_add): - cp_declarations.append(f'array[0..{input_len - 1}] of var 0..1: pre_{output_id_link}_{i};') + cp_declarations.append(f"array[0..{input_len - 1}] of var 0..1: pre_{output_id_link}_{i};") for j in range(input_len): cp_constraints.append( - f'constraint pre_{output_id_link}_{i}[{j}]={output_id_link}_i[{i * input_len + j}];') + f"constraint pre_{output_id_link}_{i}[{j}]={output_id_link}_i[{i * input_len + j}];" + ) for i in range(num_add, 2 * num_add - 2): - cp_declarations.append(f'array[0..{output_size - 1}] of var 0..1: pre_{output_id_link}_{i};') + cp_declarations.append(f"array[0..{self.output_bit_size - 1}] of var 0..1: pre_{output_id_link}_{i};") for i in range(num_add - 2): cp_constraints.append( - f'constraint modadd_linear(pre_{output_id_link}_{num_add - 1}, pre_{output_id_link}_{i + 1}, ' - f'pre_{output_id_link}_{num_add + i}, p[{model.c}]);') + f"constraint modadd_linear(pre_{output_id_link}_{num_add - 1}, pre_{output_id_link}_{i + 1}, " + f"pre_{output_id_link}_{num_add + i}, p[{model.c}]);" + ) probability.append(model.c) model.c = model.c + 1 cp_constraints.append( - f'constraint modadd_linear(pre_{output_id_link}_{2 * num_add - 3}, pre_{output_id_link}_0, ' - f'{output_id_link}_o, p[{model.c}]);') + f"constraint modadd_linear(pre_{output_id_link}_{2 * num_add - 3}, pre_{output_id_link}_0, " + f"{output_id_link}_o, p[{model.c}]);" + ) probability.append(model.c) model.c = model.c + 1 model.component_and_probability[output_id_link] = probability - result = cp_declarations, cp_constraints - return result + + return cp_declarations, cp_constraints def get_word_operation_sign(self, sign, solution): output_id_link = self.id - input_size = self.input_bit_size - output_size = self.output_bit_size - input_int = int(solution['components_values'][f'{output_id_link}_i']['value'], 16) - output_int = int(solution['components_values'][f'{output_id_link}_o']['value'], 16) - inputs = [int(digit) for digit in format(input_int, f'0{input_size}b')] - outputs = [int(digit) for digit in format(output_int, f'0{output_size}b')] + input_int = int(solution["components_values"][f"{output_id_link}_i"]["value"], 16) + output_int = int(solution["components_values"][f"{output_id_link}_o"]["value"], 16) + inputs = [int(digit) for digit in format(input_int, f"0{self.input_bit_size}b")] + outputs = [int(digit) for digit in format(output_int, f"0{self.output_bit_size}b")] component_sign = generic_sign_linear_constraints(inputs, outputs) sign = sign * component_sign - solution['components_values'][f'{output_id_link}_o']['sign'] = component_sign - solution['components_values'][output_id_link] = solution['components_values'][f'{output_id_link}_o'] - del solution['components_values'][f'{output_id_link}_o'] - del solution['components_values'][f'{output_id_link}_i'] + solution["components_values"][f"{output_id_link}_o"]["sign"] = component_sign + solution["components_values"][output_id_link] = solution["components_values"][f"{output_id_link}_o"] + del solution["components_values"][f"{output_id_link}_o"] + del solution["components_values"][f"{output_id_link}_i"] return sign @@ -518,50 +575,115 @@ def milp_xor_differential_propagation_constraints(self, model): constraints.append(x[component_id + "_dummy"] >= x[output_vars[output_bit_size - 1]]) constraints.append(x[component_id + "_dummy"] >= x[input_vars[output_bit_size - 1]]) constraints.append(x[component_id + "_dummy"] >= x[input_vars[2 * output_bit_size - 1]]) - constraints.append(x[output_vars[output_bit_size - 1]] + x[input_vars[output_bit_size - 1]] + - x[input_vars[2 * output_bit_size - 1]] - 2 * x[component_id + "_dummy"] >= 0) - constraints.append(x[output_vars[output_bit_size - 1]] + - x[input_vars[output_bit_size - 1]] + x[input_vars[2 * output_bit_size - 1]] <= 2) + constraints.append( + x[output_vars[output_bit_size - 1]] + + x[input_vars[output_bit_size - 1]] + + x[input_vars[2 * output_bit_size - 1]] + - 2 * x[component_id + "_dummy"] + >= 0 + ) + constraints.append( + x[output_vars[output_bit_size - 1]] + + x[input_vars[output_bit_size - 1]] + + x[input_vars[2 * output_bit_size - 1]] + <= 2 + ) # 2nd condition: # indice 0 for the MSB for i in range(output_bit_size - 1, 0, -1): - constraints.append(x[input_vars[output_bit_size + i]] - x[output_vars[i]] + - x[component_id + "_eq_" + str(i)] >= 0) - constraints.append(x[input_vars[i]] - x[input_vars[output_bit_size + i]] + - x[component_id + "_eq_" + str(i)] >= 0) + constraints.append( + x[input_vars[output_bit_size + i]] - x[output_vars[i]] + x[component_id + "_eq_" + str(i)] >= 0 + ) + constraints.append( + x[input_vars[i]] - x[input_vars[output_bit_size + i]] + x[component_id + "_eq_" + str(i)] >= 0 + ) constraints.append(-x[input_vars[i]] + x[output_vars[i]] + x[component_id + "_eq_" + str(i)] >= 0) - constraints.append(-x[input_vars[i]] - x[input_vars[output_bit_size + i]] - x[output_vars[i]] - x[ - component_id + "_eq_" + str(i)] >= -3) - constraints.append(x[input_vars[i]] + x[input_vars[output_bit_size + i]] + x[output_vars[i]] - x[ - component_id + "_eq_" + str(i)] >= 0) constraints.append( - -x[input_vars[output_bit_size + i]] + x[input_vars[i - 1]] + x[input_vars[output_bit_size + i - 1]] + x[ - output_vars[i - 1]] + x[component_id + "_eq_" + str(i)] >= 0) + -x[input_vars[i]] + - x[input_vars[output_bit_size + i]] + - x[output_vars[i]] + - x[component_id + "_eq_" + str(i)] + >= -3 + ) + constraints.append( + x[input_vars[i]] + + x[input_vars[output_bit_size + i]] + + x[output_vars[i]] + - x[component_id + "_eq_" + str(i)] + >= 0 + ) constraints.append( - x[input_vars[output_bit_size + i]] + x[input_vars[i - 1]] - x[input_vars[output_bit_size + i - 1]] + x[ - output_vars[i - 1]] + x[component_id + "_eq_" + str(i)] >= 0) + -x[input_vars[output_bit_size + i]] + + x[input_vars[i - 1]] + + x[input_vars[output_bit_size + i - 1]] + + x[output_vars[i - 1]] + + x[component_id + "_eq_" + str(i)] + >= 0 + ) constraints.append( - x[input_vars[output_bit_size + i]] - x[input_vars[i - 1]] + x[input_vars[output_bit_size + i - 1]] + x[ - output_vars[i - 1]] + x[component_id + "_eq_" + str(i)] >= 0) - constraints.append(x[input_vars[i]] + x[input_vars[i - 1]] + x[input_vars[output_bit_size + i - 1]] - x[ - output_vars[i - 1]] + x[component_id + "_eq_" + str(i)] >= 0) - constraints.append(x[output_vars[i]] - x[input_vars[i - 1]] - x[input_vars[output_bit_size + i - 1]] - x[ - output_vars[i - 1]] + x[component_id + "_eq_" + str(i)] >= -2) + x[input_vars[output_bit_size + i]] + + x[input_vars[i - 1]] + - x[input_vars[output_bit_size + i - 1]] + + x[output_vars[i - 1]] + + x[component_id + "_eq_" + str(i)] + >= 0 + ) constraints.append( - -x[input_vars[output_bit_size + i]] - x[input_vars[output_bit_size + i - 1]] + x[input_vars[i - 1]] - x[ - output_vars[i - 1]] + x[component_id + "_eq_" + str(i)] >= -2) + x[input_vars[output_bit_size + i]] + - x[input_vars[i - 1]] + + x[input_vars[output_bit_size + i - 1]] + + x[output_vars[i - 1]] + + x[component_id + "_eq_" + str(i)] + >= 0 + ) constraints.append( - -x[input_vars[output_bit_size + i]] + x[input_vars[output_bit_size + i - 1]] - x[input_vars[i - 1]] - x[ - output_vars[i - 1]] + x[component_id + "_eq_" + str(i)] >= -2) + x[input_vars[i]] + + x[input_vars[i - 1]] + + x[input_vars[output_bit_size + i - 1]] + - x[output_vars[i - 1]] + + x[component_id + "_eq_" + str(i)] + >= 0 + ) constraints.append( - -x[input_vars[output_bit_size + i]] - x[input_vars[output_bit_size + i - 1]] - x[input_vars[i - 1]] + x[ - output_vars[i - 1]] + x[component_id + "_eq_" + str(i)] >= -2) - constraints.append(p[component_id + "_probability"] == (10 ** model.weight_precision) * sum( - x[component_id + "_eq_" + str(i)] for i in range(output_bit_size - 1, 0, -1))) + x[output_vars[i]] + - x[input_vars[i - 1]] + - x[input_vars[output_bit_size + i - 1]] + - x[output_vars[i - 1]] + + x[component_id + "_eq_" + str(i)] + >= -2 + ) + constraints.append( + -x[input_vars[output_bit_size + i]] + - x[input_vars[output_bit_size + i - 1]] + + x[input_vars[i - 1]] + - x[output_vars[i - 1]] + + x[component_id + "_eq_" + str(i)] + >= -2 + ) + constraints.append( + -x[input_vars[output_bit_size + i]] + + x[input_vars[output_bit_size + i - 1]] + - x[input_vars[i - 1]] + - x[output_vars[i - 1]] + + x[component_id + "_eq_" + str(i)] + >= -2 + ) + constraints.append( + -x[input_vars[output_bit_size + i]] + - x[input_vars[output_bit_size + i - 1]] + - x[input_vars[i - 1]] + + x[output_vars[i - 1]] + + x[component_id + "_eq_" + str(i)] + >= -2 + ) + constraints.append( + p[component_id + "_probability"] + == (10**model.weight_precision) + * sum(x[component_id + "_eq_" + str(i)] for i in range(output_bit_size - 1, 0, -1)) + ) # the most significant bit is not taken in consideration if model.n_window_heuristic is not None: - milp_n_window_heuristic(input_vars, output_vars, component_id, - model.n_window_heuristic, model.model, x) + milp_n_window_heuristic(input_vars, output_vars, component_id, model.n_window_heuristic, model.model, x) result = variables, constraints return result @@ -619,7 +741,6 @@ def milp_bitwise_deterministic_truncated_xor_differential_constraints(self, mode c = [x_class[output_vars[i]] for i in range(output_bit_size)] for i in range(output_bit_size): - M = output_bit_size + 1 # i_less = 1 iff i < pivot @@ -658,8 +779,9 @@ def milp_bitwise_deterministic_truncated_xor_differential_constraints(self, mode # # # # if a_b_less_2 == 1 then c = a XOR b # # # # else c = 2 normal_xor_constr = milp_utils.milp_generalized_xor([a[i], b[i]], c[i]) - truncated_xor_constr = milp_utils.milp_if_then_else(a_b_less_2, normal_xor_constr, [c[i] == 2], - model._model.get_max(x_class) * num_of_inputs) + truncated_xor_constr = milp_utils.milp_if_then_else( + a_b_less_2, normal_xor_constr, [c[i] == 2], model._model.get_max(x_class) * num_of_inputs + ) constr = milp_utils.milp_if_then(p_eq, truncated_xor_constr, model._model.get_max(x_class) * num_of_inputs) constraints.extend(constr) @@ -711,11 +833,13 @@ def milp_bitwise_deterministic_truncated_xor_differential_binary_constraints(sel input_id_tuples, output_id_tuples = self._get_input_output_variables_tuples() input_ids, output_ids = self._get_input_output_variables() - linking_constraints = model.link_binary_tuples_to_integer_variables(input_id_tuples + output_id_tuples, - input_ids + output_ids) + linking_constraints = model.link_binary_tuples_to_integer_variables( + input_id_tuples + output_id_tuples, input_ids + output_ids + ) - variables = [(f"x[{var_elt}]", x[var_elt]) for var_tuple in input_id_tuples + output_id_tuples for var_elt in - var_tuple] + variables = [ + (f"x[{var_elt}]", x[var_elt]) for var_tuple in input_id_tuples + output_id_tuples for var_elt in var_tuple + ] constraints = [] + linking_constraints input_vars = [tuple(x[i] for i in _) for _ in input_id_tuples] @@ -734,9 +858,13 @@ def milp_bitwise_deterministic_truncated_xor_differential_binary_constraints(sel for i in range(pivot + 1, output_bit_size): constraints_pivot.extend([sum(input_vars[i] + input_vars[i + output_bit_size] + output_vars[i]) == 0]) constraints_pivot.extend( - milp_utils.milp_xor_truncated(model, input_id_tuples[pivot::output_bit_size][0], - input_id_tuples[pivot::output_bit_size][1], - output_id_tuples[pivot])) + milp_utils.milp_xor_truncated( + model, + input_id_tuples[pivot::output_bit_size][0], + input_id_tuples[pivot::output_bit_size][1], + output_id_tuples[pivot], + ) + ) constraints.extend(milp_utils.milp_if_then(pivot_vars[pivot], constraints_pivot, output_bit_size + 1)) return variables, constraints @@ -761,55 +889,59 @@ def minizinc_xor_differential_propagation_constraints(self, model): 'constraint modular_addition_word(array1d(0..6-1, [modadd_1_9_x0,modadd_1_9_x1,modadd_1_9_x2,modadd_1_9_x3,modadd_1_9_x4,modadd_1_9_x5]),array1d(0..6-1, [modadd_1_9_x6,modadd_1_9_x7,modadd_1_9_x8,modadd_1_9_x9,modadd_1_9_x10,modadd_1_9_x11]),array1d(0..6-1, [modadd_1_9_y0_0,modadd_1_9_y1_0,modadd_1_9_y2_0,modadd_1_9_y3_0,modadd_1_9_y4_0,modadd_1_9_y5_0]), p_modadd_1_9_0, dummy_modadd_1_9_0, -1)=1;\nconstraint carry_modadd_1_9_0 = XOR3(array1d(0..6-1, [modadd_1_9_x0,modadd_1_9_x1,modadd_1_9_x2,modadd_1_9_x3,modadd_1_9_x4,modadd_1_9_x5]),array1d(0..6-1, [modadd_1_9_x6,modadd_1_9_x7,modadd_1_9_x8,modadd_1_9_x9,modadd_1_9_x10,modadd_1_9_x11]),array1d(0..6-1, [modadd_1_9_y0_0,modadd_1_9_y1_0,modadd_1_9_y2_0,modadd_1_9_y3_0,modadd_1_9_y4_0,modadd_1_9_y5_0]));\n' """ - def create_block_of_modadd_constraints(input_vars_1_temp, input_vars_2_temp, - output_varstrs_temp, i, round_number): + def create_block_of_modadd_constraints( + input_vars_1_temp, input_vars_2_temp, output_varstrs_temp, i, round_number + ): mzn_input_array_1 = self._create_minizinc_1d_array_from_list(input_vars_1_temp) mzn_input_array_2 = self._create_minizinc_1d_array_from_list(input_vars_2_temp) mzn_output_array = self._create_minizinc_1d_array_from_list(output_varstrs_temp) - dummy_declaration = f'var {model.data_type}: dummy_{component_id}_{i};\n' - mzn_probability_var = f'p_{component_id}_{i}' + dummy_declaration = f"var {model.data_type}: dummy_{component_id}_{i};\n" + mzn_probability_var = f"p_{component_id}_{i}" model.probability_vars.append(mzn_probability_var) - pr_declaration = (f'array [0..{noutput_bits}-2] of var {model.data_type}:' - f'{mzn_probability_var};\n') + pr_declaration = f"array [0..{noutput_bits}-2] of var {model.data_type}:{mzn_probability_var};\n" model.probability_modadd_vars_per_round[round_number - 1].append(mzn_probability_var) mzn_block_variables = "" dummy_id = "" if model.sat_or_milp == "milp": mzn_block_variables += dummy_declaration - dummy_id += f'dummy_{component_id}_{i},' + dummy_id += f"dummy_{component_id}_{i}," mzn_block_variables += pr_declaration if model.window_size_list: round_window_size = model.window_size_list[round_number - 1] - mzn_block_constraints = (f'constraint modular_addition_word(' - f'{mzn_input_array_1},{mzn_input_array_2},{mzn_output_array},' - f' p_{component_id}_{i},' - f' {dummy_id}' - f' {round_window_size}' - f')={model.true_value};\n') + mzn_block_constraints = ( + f"constraint modular_addition_word(" + f"{mzn_input_array_1},{mzn_input_array_2},{mzn_output_array}," + f" p_{component_id}_{i}," + f" {dummy_id}" + f" {round_window_size}" + f")={model.true_value};\n" + ) else: - mzn_block_constraints = (f'constraint modular_addition_word(' - f'{mzn_input_array_1},{mzn_input_array_2},{mzn_output_array},' - f' p_{component_id}_{i},' - f' {dummy_id}' - f' -1' - f')={model.true_value};\n') - - mzn_carry_var = f'carry_{component_id}_{i}' - modadd_carries_definition = (f'array [0..{noutput_bits}-1] of var {model.data_type}:' - f'{mzn_carry_var};\n') + mzn_block_constraints = ( + f"constraint modular_addition_word(" + f"{mzn_input_array_1},{mzn_input_array_2},{mzn_output_array}," + f" p_{component_id}_{i}," + f" {dummy_id}" + f" -1" + f")={model.true_value};\n" + ) + + mzn_carry_var = f"carry_{component_id}_{i}" + modadd_carries_definition = f"array [0..{noutput_bits}-1] of var {model.data_type}:{mzn_carry_var};\n" mzn_block_variables += modadd_carries_definition - model.carries_vars.append({'mzn_carry_array_name': mzn_carry_var, 'mzn_carry_array_size': noutput_bits}) - mzn_block_constraints_carries = (f'constraint {mzn_carry_var} = ' - f'XOR3(' - f'{mzn_input_array_1},{mzn_input_array_2},' - f'{mzn_output_array});\n') + model.carries_vars.append({"mzn_carry_array_name": mzn_carry_var, "mzn_carry_array_size": noutput_bits}) + mzn_block_constraints_carries = ( + f"constraint {mzn_carry_var} = XOR3({mzn_input_array_1},{mzn_input_array_2},{mzn_output_array});\n" + ) mzn_block_constraints += mzn_block_constraints_carries - model.mzn_carries_output_directives.append(f'output ["carries {component_id}:"++show(XOR3(' - f'{mzn_input_array_1},{mzn_input_array_2},' - f'{mzn_output_array}))++"\\n"];') + model.mzn_carries_output_directives.append( + f'output ["carries {component_id}:"++show(XOR3(' + f"{mzn_input_array_1},{mzn_input_array_2}," + f'{mzn_output_array}))++"\\n"];' + ) return mzn_block_variables, mzn_block_constraints @@ -831,19 +963,21 @@ def create_block_of_modadd_constraints(input_vars_1_temp, input_vars_2_temp, for i in range(ninput_words - 2): new_output_vars_temp = [] for output_var in output_varstrs: - mzn_constraints += [f'var {model.data_type}: {output_var}_{i};'] + mzn_constraints += [f"var {model.data_type}: {output_var}_{i};"] new_output_vars_temp.append(output_var + "_" + str(i)) new_output_vars.append(new_output_vars_temp) for i in range(ninput_words - 1): - input_vars_1 = input_varstrs[i * word_chunk:i * word_chunk + word_chunk] - input_vars_2 = input_varstrs[i * word_chunk + word_chunk:i * word_chunk + word_chunk + word_chunk] + input_vars_1 = input_varstrs[i * word_chunk : i * word_chunk + word_chunk] + input_vars_2 = input_varstrs[i * word_chunk + word_chunk : i * word_chunk + word_chunk + word_chunk] if i == ninput_words - 2: - mzn_variables_and_constraints = create_block_of_modadd_constraints(input_vars_1, input_vars_2, - output_varstrs, i, round_number) + mzn_variables_and_constraints = create_block_of_modadd_constraints( + input_vars_1, input_vars_2, output_varstrs, i, round_number + ) else: - mzn_variables_and_constraints = create_block_of_modadd_constraints(input_vars_1, input_vars_2, - new_output_vars[i], i, round_number) + mzn_variables_and_constraints = create_block_of_modadd_constraints( + input_vars_1, input_vars_2, new_output_vars[i], i, round_number + ) var_names += [mzn_variables_and_constraints[0]] mzn_constraints += [mzn_variables_and_constraints[1]] @@ -892,42 +1026,46 @@ def milp_xor_linear_mask_propagation_constraints(self, model): variables = [] constraints = [] if number_of_inputs == 2: - variables, constraints = self.twoterms_milp_probability_xor_linear_constraints(binary_variable, - integer_variable, - input_vars, - output_vars, 0) - constraints.append(correlation[component_id + "_probability"] == (10 ** model.weight_precision) * - correlation[component_id + "_modadd_probability" + str(0)]) + variables, constraints = self.twoterms_milp_probability_xor_linear_constraints( + binary_variable, integer_variable, input_vars, output_vars, 0 + ) + constraints.append( + correlation[component_id + "_probability"] + == (10**model.weight_precision) * correlation[component_id + "_modadd_probability" + str(0)] + ) elif number_of_inputs > 2: - temp_output_vars = [[f"{var}_temp_modadd_{i}" for var in output_vars] - for i in range(number_of_inputs - 2)] - variables, constraints = \ - self.twoterms_milp_probability_xor_linear_constraints(binary_variable, integer_variable, - input_vars[:2 * output_bit_size], - temp_output_vars[0], 0) + temp_output_vars = [[f"{var}_temp_modadd_{i}" for var in output_vars] for i in range(number_of_inputs - 2)] + variables, constraints = self.twoterms_milp_probability_xor_linear_constraints( + binary_variable, integer_variable, input_vars[: 2 * output_bit_size], temp_output_vars[0], 0 + ) for i in range(1, number_of_inputs - 2): temp_output_vars.extend([[f"{var}_temp_modadd_{i}" for var in output_vars]]) temp_variables, temp_constraints = self.twoterms_milp_probability_xor_linear_constraints( binary_variable, integer_variable, - input_vars[(i + 1) * output_bit_size:(i + 2) * output_bit_size] + temp_output_vars[i - 1], - temp_output_vars[i], i) + input_vars[(i + 1) * output_bit_size : (i + 2) * output_bit_size] + temp_output_vars[i - 1], + temp_output_vars[i], + i, + ) variables.extend(temp_variables) constraints.extend(temp_constraints) - temp_variables, temp_constraints = \ - self.twoterms_milp_probability_xor_linear_constraints( - binary_variable, integer_variable, - input_vars[(number_of_inputs - 1) * output_bit_size: number_of_inputs * output_bit_size] + - temp_output_vars[number_of_inputs - 3], - output_vars, number_of_inputs - 2) + temp_variables, temp_constraints = self.twoterms_milp_probability_xor_linear_constraints( + binary_variable, + integer_variable, + input_vars[(number_of_inputs - 1) * output_bit_size : number_of_inputs * output_bit_size] + + temp_output_vars[number_of_inputs - 3], + output_vars, + number_of_inputs - 2, + ) variables.extend(temp_variables) constraints.extend(temp_constraints) - constraints.append(correlation[component_id + "_probability"] == - (10 ** model.weight_precision) * sum( - correlation[component_id + "_modadd_probability" + str(i)] - for i in range(number_of_inputs - 1))) + constraints.append( + correlation[component_id + "_probability"] + == (10**model.weight_precision) + * sum(correlation[component_id + "_modadd_probability" + str(i)] for i in range(number_of_inputs - 1)) + ) result = variables, constraints return result @@ -968,52 +1106,66 @@ def sat_xor_differential_propagation_constraints(self, model=None): """ def extend_constraints_for_window_size( - model_, output_bit_len_, window_size_, input_bit_ids_, output_bit_ids_, constraints_ + model_, output_bit_len_, window_size_, input_bit_ids_, output_bit_ids_, constraints_ ): window_size_ += 1 for i in range(output_bit_len_ - window_size_): - aux_var = f'full_window_track_{self.id}_{i}' + aux_var = f"full_window_track_{self.id}_{i}" if model_.window_size_number_of_full_window is not None: model_._window_size_full_window_vars.append(aux_var) - first_addend = input_bit_ids_[i:i + window_size_] - second_addend = input_bit_ids_[output_bit_len_ + i:output_bit_len_ + i + window_size_] - result = output_bit_ids_[i:i + window_size_] - from claasp.cipher_modules.models.sat.utils.n_window_heuristic_helper import \ - generate_window_size_clauses + first_addend = input_bit_ids_[i : i + window_size_] + second_addend = input_bit_ids_[output_bit_len_ + i : output_bit_len_ + i + window_size_] + result = output_bit_ids_[i : i + window_size_] + from claasp.cipher_modules.models.sat.utils.n_window_heuristic_helper import ( + generate_window_size_clauses, + ) + new_constraints = generate_window_size_clauses(first_addend, second_addend, result, aux_var) constraints_.extend(new_constraints) - _, input_bit_ids = self._generate_input_ids() + input_bit_ids = self._generate_input_ids() output_bit_len, output_bit_ids = self._generate_output_ids() - dummy_bit_ids = [f'dummy_{output_bit_ids[i]}' for i in range(output_bit_len - 1)] - hw_bit_ids = [f'hw_{output_bit_ids[i]}' for i in range(output_bit_len)] + dummy_bit_ids = [f"dummy_{output_bit_ids[i]}" for i in range(output_bit_len - 1)] + hw_bit_ids = [f"hw_{output_bit_ids[i]}" for i in range(output_bit_len)] constraints = [] # Hamming weight for i in range(output_bit_len - 1): - constraints.extend(sat_utils.cnf_hw_lipmaa(hw_bit_ids[i], - input_bit_ids[i + 1], - input_bit_ids[output_bit_len + i + 1], - output_bit_ids[i + 1])) - constraints.append(f'-{hw_bit_ids[output_bit_len - 1]}') + constraints.extend( + sat_utils.cnf_hw_lipmaa( + hw_bit_ids[i], input_bit_ids[i + 1], input_bit_ids[output_bit_len + i + 1], output_bit_ids[i + 1] + ) + ) + constraints.append(f"-{hw_bit_ids[output_bit_len - 1]}") # Trail validity # for i in range(output_bit_len - 1): - constraints.extend(sat_utils.cnf_lipmaa(hw_bit_ids[i], - dummy_bit_ids[i], - input_bit_ids[output_bit_len + i + 1], - input_bit_ids[i], - input_bit_ids[output_bit_len + i], - output_bit_ids[i])) - constraints.extend(sat_utils.cnf_xor(output_bit_ids[output_bit_len - 1], - [input_bit_ids[output_bit_len - 1], - input_bit_ids[2 * output_bit_len - 1]])) + constraints.extend( + sat_utils.cnf_lipmaa( + hw_bit_ids[i], + dummy_bit_ids[i], + input_bit_ids[output_bit_len + i + 1], + input_bit_ids[i], + input_bit_ids[output_bit_len + i], + output_bit_ids[i], + ) + ) + constraints.extend( + sat_utils.cnf_xor( + output_bit_ids[output_bit_len - 1], + [input_bit_ids[output_bit_len - 1], input_bit_ids[2 * output_bit_len - 1]], + ) + ) from claasp.cipher_modules.models.sat.sat_models.sat_xor_differential_model import SatXorDifferentialModel + if type(model) is SatXorDifferentialModel and model.window_size_by_round_values is not None: if model.window_size_weight_pr_vars != -1: for i in range(output_bit_len - model.window_size_weight_pr_vars): - constraints.extend(sat_utils.cnf_n_window_heuristic_on_w_vars( - hw_bit_ids[i: i + (model.window_size_weight_pr_vars + 1)])) + constraints.extend( + sat_utils.cnf_n_window_heuristic_on_w_vars( + hw_bit_ids[i : i + (model.window_size_weight_pr_vars + 1)] + ) + ) component_round_number = model._cipher.get_round_from_component_id(self.id) if type(model) is SatXorDifferentialModel and model.window_size_by_round_values is not None: @@ -1072,23 +1224,35 @@ def sat_bitwise_deterministic_truncated_xor_differential_constraints(self): """ in_ids_0, in_ids_1 = self._generate_input_double_ids() out_len, out_ids_0, out_ids_1 = self._generate_output_double_ids() - carry_ids_0 = [f'carry_{out_id}_0' for out_id in out_ids_0] - carry_ids_1 = [f'carry_{out_id}_1' for out_id in out_ids_1] - constraints = [f'-{carry_ids_0[-1]} -{carry_ids_1[-1]}'] - constraints.extend(sat_utils.modadd_truncated_msb((out_ids_0[0], out_ids_1[0]), - (in_ids_0[0], in_ids_1[0]), - (in_ids_0[out_len], in_ids_1[out_len]), - (carry_ids_0[0], carry_ids_1[0]))) + carry_ids_0 = [f"carry_{out_id}_0" for out_id in out_ids_0] + carry_ids_1 = [f"carry_{out_id}_1" for out_id in out_ids_1] + constraints = [f"-{carry_ids_0[-1]} -{carry_ids_1[-1]}"] + constraints.extend( + sat_utils.modadd_truncated_msb( + (out_ids_0[0], out_ids_1[0]), + (in_ids_0[0], in_ids_1[0]), + (in_ids_0[out_len], in_ids_1[out_len]), + (carry_ids_0[0], carry_ids_1[0]), + ) + ) for i in range(1, out_len - 1): - constraints.extend(sat_utils.modadd_truncated((out_ids_0[i], out_ids_1[i]), - (in_ids_0[i], in_ids_1[i]), - (in_ids_0[i + out_len], in_ids_1[i + out_len]), - (carry_ids_0[i], carry_ids_1[i]), - (carry_ids_0[i - 1], carry_ids_1[i - 1]))) - constraints.extend(sat_utils.modadd_truncated_lsb((out_ids_0[-1], out_ids_1[-1]), - (in_ids_0[out_len - 1], in_ids_1[out_len - 1]), - (in_ids_0[2 * out_len - 1], in_ids_1[2 * out_len - 1]), - (carry_ids_0[-2], carry_ids_1[-2]))) + constraints.extend( + sat_utils.modadd_truncated( + (out_ids_0[i], out_ids_1[i]), + (in_ids_0[i], in_ids_1[i]), + (in_ids_0[i + out_len], in_ids_1[i + out_len]), + (carry_ids_0[i], carry_ids_1[i]), + (carry_ids_0[i - 1], carry_ids_1[i - 1]), + ) + ) + constraints.extend( + sat_utils.modadd_truncated_lsb( + (out_ids_0[-1], out_ids_1[-1]), + (in_ids_0[out_len - 1], in_ids_1[out_len - 1]), + (in_ids_0[2 * out_len - 1], in_ids_1[2 * out_len - 1]), + (carry_ids_0[-2], carry_ids_1[-2]), + ) + ) return out_ids_0 + out_ids_1 + carry_ids_0 + carry_ids_1, constraints @@ -1113,10 +1277,10 @@ def position_0_constraints(a_t15, b_t15, c_t15, a_v15, b_v15, c_v15): clauses.append(f"-{b_t15}") clauses.append(f"-{c_t15}") - clauses.append(f' {a_v15} {b_v15} -{c_v15}') - clauses.append(f' {a_v15} -{b_v15} {c_v15}') - clauses.append(f'-{a_v15} {b_v15} {c_v15}') - clauses.append(f'-{a_v15} -{b_v15} -{c_v15}') + clauses.append(f" {a_v15} {b_v15} -{c_v15}") + clauses.append(f" {a_v15} -{b_v15} {c_v15}") + clauses.append(f"-{a_v15} {b_v15} {c_v15}") + clauses.append(f"-{a_v15} -{b_v15} -{c_v15}") return clauses @@ -1124,10 +1288,10 @@ def position_0_constraints(a_t15, b_t15, c_t15, a_v15, b_v15, c_v15): out_len, out_ids_0, out_ids_1 = self._generate_output_double_ids() A_t = in_ids_0[0:out_len] - B_t = in_ids_0[out_len:2 * out_len] + B_t = in_ids_0[out_len : 2 * out_len] A_v = in_ids_1[0:out_len] - B_v = in_ids_1[out_len:2 * out_len] + B_v = in_ids_1[out_len : 2 * out_len] C_t = out_ids_0[0:out_len] C_v = out_ids_1[0:out_len] @@ -1137,9 +1301,9 @@ def position_0_constraints(a_t15, b_t15, c_t15, a_v15, b_v15, c_v15): ) word_size = out_len - p = [f'hw_p_{self.id}_{i}' for i in range(word_size)] - q = [f'hw_q_{self.id}_{i}' for i in range(word_size)] - r = [f'hw_r_{self.id}_{i}' for i in range(word_size)] + p = [f"hw_p_{self.id}_{i}" for i in range(word_size)] + q = [f"hw_q_{self.id}_{i}" for i in range(word_size)] + r = [f"hw_r_{self.id}_{i}" for i in range(word_size)] window_size_ = 3 for bit_position in range(word_size - 1): @@ -1149,45 +1313,106 @@ def position_0_constraints(a_t15, b_t15, c_t15, a_v15, b_v15, c_v15): A_t_bit_positions.append(A_t[bit_position + i]) if len(A_t_bit_positions) == 3: cnf_clauses = sat_utils.get_cnf_semi_deterministic_window_1( - A_t[bit_position], A_t[bit_position + 1], A_t[bit_position + 2], - A_v[bit_position], A_v[bit_position + 1], A_v[bit_position + 2], - B_t[bit_position], B_t[bit_position + 1], B_t[bit_position + 2], - B_v[bit_position], B_v[bit_position + 1], B_v[bit_position + 2], - C_t[bit_position], C_t[bit_position + 1], C_t[bit_position + 2], - C_v[bit_position], C_v[bit_position + 1], - p[bit_position], q[bit_position], r[bit_position]) + A_t[bit_position], + A_t[bit_position + 1], + A_t[bit_position + 2], + A_v[bit_position], + A_v[bit_position + 1], + A_v[bit_position + 2], + B_t[bit_position], + B_t[bit_position + 1], + B_t[bit_position + 2], + B_v[bit_position], + B_v[bit_position + 1], + B_v[bit_position + 2], + C_t[bit_position], + C_t[bit_position + 1], + C_t[bit_position + 2], + C_v[bit_position], + C_v[bit_position + 1], + p[bit_position], + q[bit_position], + r[bit_position], + ) elif len(A_t_bit_positions) == 2: cnf_clauses = sat_utils.get_semi_deterministic_cnf_window_0( - A_t[bit_position], A_t[bit_position + 1], - A_v[bit_position], A_v[bit_position + 1], - B_t[bit_position], B_t[bit_position + 1], - B_v[bit_position], B_v[bit_position + 1], - C_t[bit_position], C_t[bit_position + 1], - C_v[bit_position], C_v[bit_position + 1], - p[bit_position], q[bit_position], r[bit_position]) + A_t[bit_position], + A_t[bit_position + 1], + A_v[bit_position], + A_v[bit_position + 1], + B_t[bit_position], + B_t[bit_position + 1], + B_v[bit_position], + B_v[bit_position + 1], + C_t[bit_position], + C_t[bit_position + 1], + C_v[bit_position], + C_v[bit_position + 1], + p[bit_position], + q[bit_position], + r[bit_position], + ) elif len(A_t_bit_positions) == 4: cnf_clauses = sat_utils.get_cnf_semi_deterministic_window_2( - A_t[bit_position], A_t[bit_position + 1], A_t[bit_position + 2], A_t[bit_position + 3], - A_v[bit_position], A_v[bit_position + 1], A_v[bit_position + 2], A_v[bit_position + 3], - B_t[bit_position], B_t[bit_position + 1], B_t[bit_position + 2], B_t[bit_position + 3], - B_v[bit_position], B_v[bit_position + 1], B_v[bit_position + 2], B_v[bit_position + 3], - C_t[bit_position], C_t[bit_position + 1], C_t[bit_position + 2], C_t[bit_position + 3], - C_v[bit_position], C_v[bit_position + 1], - p[bit_position], q[bit_position], r[bit_position]) + A_t[bit_position], + A_t[bit_position + 1], + A_t[bit_position + 2], + A_t[bit_position + 3], + A_v[bit_position], + A_v[bit_position + 1], + A_v[bit_position + 2], + A_v[bit_position + 3], + B_t[bit_position], + B_t[bit_position + 1], + B_t[bit_position + 2], + B_t[bit_position + 3], + B_v[bit_position], + B_v[bit_position + 1], + B_v[bit_position + 2], + B_v[bit_position + 3], + C_t[bit_position], + C_t[bit_position + 1], + C_t[bit_position + 2], + C_t[bit_position + 3], + C_v[bit_position], + C_v[bit_position + 1], + p[bit_position], + q[bit_position], + r[bit_position], + ) elif len(A_t_bit_positions) == 5: cnf_clauses = sat_utils.get_cnf_semi_deterministic_window_3( - A_t[bit_position], A_t[bit_position + 1], A_t[bit_position + 2], A_t[bit_position + 3], + A_t[bit_position], + A_t[bit_position + 1], + A_t[bit_position + 2], + A_t[bit_position + 3], A_t[bit_position + 4], - A_v[bit_position], A_v[bit_position + 1], A_v[bit_position + 2], A_v[bit_position + 3], + A_v[bit_position], + A_v[bit_position + 1], + A_v[bit_position + 2], + A_v[bit_position + 3], A_v[bit_position + 4], - B_t[bit_position], B_t[bit_position + 1], B_t[bit_position + 2], B_t[bit_position + 3], + B_t[bit_position], + B_t[bit_position + 1], + B_t[bit_position + 2], + B_t[bit_position + 3], B_t[bit_position + 4], - B_v[bit_position], B_v[bit_position + 1], B_v[bit_position + 2], B_v[bit_position + 3], + B_v[bit_position], + B_v[bit_position + 1], + B_v[bit_position + 2], + B_v[bit_position + 3], B_v[bit_position + 4], - C_t[bit_position], C_t[bit_position + 1], C_t[bit_position + 2], C_t[bit_position + 3], + C_t[bit_position], + C_t[bit_position + 1], + C_t[bit_position + 2], + C_t[bit_position + 3], C_t[bit_position + 4], - C_v[bit_position], C_v[bit_position + 1], - p[bit_position], q[bit_position], r[bit_position]) + C_v[bit_position], + C_v[bit_position + 1], + p[bit_position], + q[bit_position], + r[bit_position], + ) else: raise Exception("Window size not supported") @@ -1230,26 +1455,29 @@ def sat_xor_linear_mask_propagation_constraints(self, model=None): _, input_bit_ids = self._generate_component_input_ids() out_suffix = constants.OUTPUT_BIT_ID_SUFFIX output_bit_len, output_bit_ids = self._generate_output_ids(suffix=out_suffix) - hw_bit_ids = [f'hw_{output_bit_ids[i]}' for i in range(output_bit_len)] - constraints = [f'-{hw_bit_ids[0]}'] - constraints.extend(sat_utils.cnf_xor(hw_bit_ids[1], - [output_bit_ids[0], - input_bit_ids[0], - input_bit_ids[output_bit_len]])) + hw_bit_ids = [f"hw_{output_bit_ids[i]}" for i in range(output_bit_len)] + constraints = [f"-{hw_bit_ids[0]}"] + constraints.extend( + sat_utils.cnf_xor(hw_bit_ids[1], [output_bit_ids[0], input_bit_ids[0], input_bit_ids[output_bit_len]]) + ) for i in range(2, output_bit_len): - constraints.extend(sat_utils.cnf_xor(hw_bit_ids[i], - [hw_bit_ids[i - 1], - output_bit_ids[i - 1], - input_bit_ids[i - 1], - input_bit_ids[output_bit_len + i - 1]])) + constraints.extend( + sat_utils.cnf_xor( + hw_bit_ids[i], + [ + hw_bit_ids[i - 1], + output_bit_ids[i - 1], + input_bit_ids[i - 1], + input_bit_ids[output_bit_len + i - 1], + ], + ) + ) for i in range(output_bit_len): - constraints.extend(sat_utils.cnf_modadd_inequality(hw_bit_ids[i], - output_bit_ids[i], - input_bit_ids[i])) + constraints.extend(sat_utils.cnf_modadd_inequality(hw_bit_ids[i], output_bit_ids[i], input_bit_ids[i])) for i in range(output_bit_len): - constraints.extend(sat_utils.cnf_modadd_inequality(hw_bit_ids[i], - output_bit_ids[i], - input_bit_ids[output_bit_len + i])) + constraints.extend( + sat_utils.cnf_modadd_inequality(hw_bit_ids[i], output_bit_ids[i], input_bit_ids[output_bit_len + i]) + ) result = input_bit_ids + output_bit_ids + hw_bit_ids, constraints return result @@ -1286,30 +1514,38 @@ def smt_xor_differential_propagation_constraints(self, model=None): '(assert (or hw_modadd_0_1_30 (not (xor shift_0_0_30 key_30 modadd_0_1_30 key_31))))', '(assert (not (xor modadd_0_1_31 shift_0_0_31 key_31)))']) """ - _, input_bit_ids = self._generate_input_ids() + input_bit_ids = self._generate_input_ids() output_bit_len, output_bit_ids = self._generate_output_ids() - hw_bit_ids = [f'hw_{output_bit_ids[i]}' for i in range(output_bit_len)] + hw_bit_ids = [f"hw_{output_bit_ids[i]}" for i in range(output_bit_len)] constraints = [] # Hamming weight for i in range(output_bit_len - 1): - operation = smt_utils.smt_equivalent((input_bit_ids[i + 1], - input_bit_ids[output_bit_len + i + 1], - output_bit_ids[i + 1])) + operation = smt_utils.smt_equivalent( + (input_bit_ids[i + 1], input_bit_ids[output_bit_len + i + 1], output_bit_ids[i + 1]) + ) equation = smt_utils.smt_equivalent([smt_utils.smt_not(hw_bit_ids[i]), operation]) constraints.append(smt_utils.smt_assert(equation)) constraints.append(smt_utils.smt_assert(smt_utils.smt_not(hw_bit_ids[output_bit_len - 1]))) # Trail validity # for i in range(output_bit_len - 1): - lipmaa = smt_utils.smt_lipmaa(hw_bit_ids[i], - input_bit_ids[i], - input_bit_ids[output_bit_len + i], - output_bit_ids[i], - input_bit_ids[output_bit_len + i + 1]) + lipmaa = smt_utils.smt_lipmaa( + hw_bit_ids[i], + input_bit_ids[i], + input_bit_ids[output_bit_len + i], + output_bit_ids[i], + input_bit_ids[output_bit_len + i + 1], + ) constraints.append(smt_utils.smt_assert(lipmaa)) - lipmaa_lsb = smt_utils.smt_not(smt_utils.smt_xor([output_bit_ids[output_bit_len - 1], - input_bit_ids[output_bit_len - 1], - input_bit_ids[2 * output_bit_len - 1]])) + lipmaa_lsb = smt_utils.smt_not( + smt_utils.smt_xor( + [ + output_bit_ids[output_bit_len - 1], + input_bit_ids[output_bit_len - 1], + input_bit_ids[2 * output_bit_len - 1], + ] + ) + ) constraints.append(smt_utils.smt_assert(lipmaa_lsb)) result = output_bit_ids + hw_bit_ids, constraints return result @@ -1350,16 +1586,15 @@ def smt_xor_linear_mask_propagation_constraints(self, model=None): _, input_bit_ids = self._generate_component_input_ids() out_suffix = constants.OUTPUT_BIT_ID_SUFFIX output_bit_len, output_bit_ids = self._generate_output_ids(suffix=out_suffix) - hw_bit_ids = [f'hw_{output_bit_ids[i]}' for i in range(output_bit_len)] + hw_bit_ids = [f"hw_{output_bit_ids[i]}" for i in range(output_bit_len)] constraints = [smt_utils.smt_assert(smt_utils.smt_not(hw_bit_ids[0]))] operation = smt_utils.smt_xor((output_bit_ids[0], input_bit_ids[0], input_bit_ids[output_bit_len])) equation = smt_utils.smt_equivalent((hw_bit_ids[1], operation)) constraints.append(smt_utils.smt_assert(equation)) for i in range(2, output_bit_len): - operation = smt_utils.smt_xor((hw_bit_ids[i - 1], - output_bit_ids[i - 1], - input_bit_ids[i - 1], - input_bit_ids[output_bit_len + i - 1])) + operation = smt_utils.smt_xor( + (hw_bit_ids[i - 1], output_bit_ids[i - 1], input_bit_ids[i - 1], input_bit_ids[output_bit_len + i - 1]) + ) equation = smt_utils.smt_equivalent((hw_bit_ids[i], operation)) constraints.append(smt_utils.smt_assert(equation)) for i in range(output_bit_len): @@ -1373,8 +1608,9 @@ def smt_xor_linear_mask_propagation_constraints(self, model=None): result = input_bit_ids + output_bit_ids + hw_bit_ids, constraints return result - def twoterms_milp_probability_xor_linear_constraints(self, binary_variable, integer_variable, input_vars, - output_vars, chunk_number): + def twoterms_milp_probability_xor_linear_constraints( + self, binary_variable, integer_variable, input_vars, output_vars, chunk_number + ): """ Return lists of variables and constraints for the probability of Modular Addition/Substraction for two inputs MILP xor linear model. @@ -1401,75 +1637,103 @@ def twoterms_milp_probability_xor_linear_constraints(self, binary_variable, inte output_bit_size = len(output_vars) for i in range(output_bit_size): - constraints.append(x[f"{self.id}_chunk_{chunk_number}_dummy_{i}"] - - x[input_vars[output_bit_size + i]] - - x[input_vars[i]] + - x[output_vars[i]] + - x[f"{self.id}_chunk_{chunk_number}_dummy_{i + 1}"] >= 0) - constraints.append(x[f"{self.id}_chunk_{chunk_number}_dummy_{i}"] + - x[input_vars[output_bit_size + i]] + - x[input_vars[i]] - - x[output_vars[i]] - - x[f"{self.id}_chunk_{chunk_number}_dummy_{i + 1}"] >= 0) - constraints.append(x[f"{self.id}_chunk_{chunk_number}_dummy_{i}"] + - x[input_vars[output_bit_size + i]] - - x[input_vars[i]] - - x[output_vars[i]] + - x[f"{self.id}_chunk_{chunk_number}_dummy_{i + 1}"] >= 0) - constraints.append(x[f"{self.id}_chunk_{chunk_number}_dummy_{i}"] - - x[input_vars[output_bit_size + i]] + - x[input_vars[i]] - - x[output_vars[i]] + - x[f"{self.id}_chunk_{chunk_number}_dummy_{i + 1}"] >= 0) - constraints.append(x[f"{self.id}_chunk_{chunk_number}_dummy_{i}"] + - x[input_vars[output_bit_size + i]] - - x[input_vars[i]] + - x[output_vars[i]] - - x[f"{self.id}_chunk_{chunk_number}_dummy_{i + 1}"] >= 0) - constraints.append(x[f"{self.id}_chunk_{chunk_number}_dummy_{i}"] - - x[input_vars[output_bit_size + i]] + - x[input_vars[i]] + - x[output_vars[i]] - - x[f"{self.id}_chunk_{chunk_number}_dummy_{i + 1}"] >= 0) - constraints.append(x[input_vars[output_bit_size + i]] - - x[f"{self.id}_chunk_{chunk_number}_dummy_{i}"] + - x[input_vars[i]] + x[output_vars[i]] + - x[f"{self.id}_chunk_{chunk_number}_dummy_{i + 1}"] >= 0) - constraints.append(x[f"{self.id}_chunk_{chunk_number}_dummy_{i}"] + - x[input_vars[output_bit_size + i]] + x[input_vars[i]] + - x[output_vars[i]] + - x[f"{self.id}_chunk_{chunk_number}_dummy_{i + 1}"] <= 4) - - constraints.append(correlation[f"{self.id}_modadd_probability{chunk_number}"] == sum( - x[f"{self.id}_chunk_{chunk_number}_dummy_{i}"] for i in range(output_bit_size))) + constraints.append( + x[f"{self.id}_chunk_{chunk_number}_dummy_{i}"] + - x[input_vars[output_bit_size + i]] + - x[input_vars[i]] + + x[output_vars[i]] + + x[f"{self.id}_chunk_{chunk_number}_dummy_{i + 1}"] + >= 0 + ) + constraints.append( + x[f"{self.id}_chunk_{chunk_number}_dummy_{i}"] + + x[input_vars[output_bit_size + i]] + + x[input_vars[i]] + - x[output_vars[i]] + - x[f"{self.id}_chunk_{chunk_number}_dummy_{i + 1}"] + >= 0 + ) + constraints.append( + x[f"{self.id}_chunk_{chunk_number}_dummy_{i}"] + + x[input_vars[output_bit_size + i]] + - x[input_vars[i]] + - x[output_vars[i]] + + x[f"{self.id}_chunk_{chunk_number}_dummy_{i + 1}"] + >= 0 + ) + constraints.append( + x[f"{self.id}_chunk_{chunk_number}_dummy_{i}"] + - x[input_vars[output_bit_size + i]] + + x[input_vars[i]] + - x[output_vars[i]] + + x[f"{self.id}_chunk_{chunk_number}_dummy_{i + 1}"] + >= 0 + ) + constraints.append( + x[f"{self.id}_chunk_{chunk_number}_dummy_{i}"] + + x[input_vars[output_bit_size + i]] + - x[input_vars[i]] + + x[output_vars[i]] + - x[f"{self.id}_chunk_{chunk_number}_dummy_{i + 1}"] + >= 0 + ) + constraints.append( + x[f"{self.id}_chunk_{chunk_number}_dummy_{i}"] + - x[input_vars[output_bit_size + i]] + + x[input_vars[i]] + + x[output_vars[i]] + - x[f"{self.id}_chunk_{chunk_number}_dummy_{i + 1}"] + >= 0 + ) + constraints.append( + x[input_vars[output_bit_size + i]] + - x[f"{self.id}_chunk_{chunk_number}_dummy_{i}"] + + x[input_vars[i]] + + x[output_vars[i]] + + x[f"{self.id}_chunk_{chunk_number}_dummy_{i + 1}"] + >= 0 + ) + constraints.append( + x[f"{self.id}_chunk_{chunk_number}_dummy_{i}"] + + x[input_vars[output_bit_size + i]] + + x[input_vars[i]] + + x[output_vars[i]] + + x[f"{self.id}_chunk_{chunk_number}_dummy_{i + 1}"] + <= 4 + ) + + constraints.append( + correlation[f"{self.id}_modadd_probability{chunk_number}"] + == sum(x[f"{self.id}_chunk_{chunk_number}_dummy_{i}"] for i in range(output_bit_size)) + ) return variables, constraints def create_bct_mzn_constraint_from_component_ids(self): component_dict = self.as_python_dictionary() - delta_left_component_id = component_dict['input_id_link'][0] - delta_right_component_id = component_dict['input_id_link'][1] + delta_left_component_id = component_dict["input_id_link"][0] + delta_right_component_id = component_dict["input_id_link"][1] nabla_left_component_id = self.id - nabla_right_component_id = f'new_{delta_right_component_id}' + nabla_right_component_id = f"new_{delta_right_component_id}" branch_size = self.output_bit_size delta_left_vars = [] delta_right_vars = [] nabla_left_vars = [] nabla_right_vars = [] for i in range(branch_size): - delta_left_vars.append(f'{delta_left_component_id}_y{i}') - delta_right_vars.append(f'{delta_right_component_id}_y{i}') - nabla_left_vars.append(f'{nabla_left_component_id}_y{i}') - nabla_right_vars.append(f'{nabla_right_component_id}_y{i}') + delta_left_vars.append(f"{delta_left_component_id}_y{i}") + delta_right_vars.append(f"{delta_right_component_id}_y{i}") + nabla_left_vars.append(f"{nabla_left_component_id}_y{i}") + nabla_right_vars.append(f"{nabla_right_component_id}_y{i}") delta_left_str = ",".join(delta_left_vars) delta_right_str = ",".join(delta_right_vars) nabla_left_str = ",".join(nabla_left_vars) nabla_right_str = ",".join(nabla_right_vars) - delta_left = f'array1d(0..{branch_size}-1, [{delta_left_str}])' - delta_right = f'array1d(0..{branch_size}-1, [{delta_right_str}])' - nabla_left = f'array1d(0..{branch_size}-1, [{nabla_left_str}])' - nabla_right = f'array1d(0..{branch_size}-1, [{nabla_right_str}])' + delta_left = f"array1d(0..{branch_size}-1, [{delta_left_str}])" + delta_right = f"array1d(0..{branch_size}-1, [{delta_right_str}])" + nabla_left = f"array1d(0..{branch_size}-1, [{nabla_left_str}])" + nabla_right = f"array1d(0..{branch_size}-1, [{nabla_right_str}])" constraint = ( f"constraint onlyLargeSwitch_BCT_enum({delta_left}, {delta_right}, " diff --git a/claasp/components/multi_input_non_linear_logical_operator_component.py b/claasp/components/multi_input_non_linear_logical_operator_component.py index 45837adbc..485664b53 100644 --- a/claasp/components/multi_input_non_linear_logical_operator_component.py +++ b/claasp/components/multi_input_non_linear_logical_operator_component.py @@ -1,17 +1,16 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -21,17 +20,26 @@ from claasp.component import Component from claasp.cipher_modules.models.smt.utils import utils as smt_utils from claasp.cipher_modules.models.sat.utils import constants, utils as sat_utils -from claasp.cipher_modules.models.milp.utils.generate_inequalities_for_and_operation_2_input_bits import (and_LAT, - and_inequalities) +from claasp.cipher_modules.models.milp.utils.generate_inequalities_for_and_operation_2_input_bits import ( + and_LAT, + and_inequalities, +) +from claasp.name_mappings import WORD_OPERATION class MultiInputNonlinearLogicalOperator(Component): - - def __init__(self, current_round_number, current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size, operation): - component_id = f'{operation}_{current_round_number}_{current_round_number_of_components}' - component_type = 'word_operation' - input_len = sum(len(bits) for bits in input_bit_positions) + def __init__( + self, + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + operation, + ): + component_id = f"{operation}_{current_round_number}_{current_round_number_of_components}" + component_type = WORD_OPERATION + input_len = sum(map(len, input_bit_positions)) description = [operation.upper(), int(input_len / output_bit_size)] component_input = Input(input_len, input_id_links, input_bit_positions) super().__init__(component_id, component_type, component_input, output_bit_size, description) @@ -91,20 +99,15 @@ def cp_deterministic_truncated_xor_differential_constraints(self): ... 'constraint if xor_0_7[11] == 0 /\\ key[23] == 0 then and_0_8[11] = 0 else and_0_8[11] = 2 endif;']) """ - output_size = int(self.output_bit_size) - input_id_links = self.input_id_links - output_id_link = self.id - input_bit_positions = self.input_bit_positions cp_declarations = [] all_inputs = [] - for id_link, bit_positions in zip(input_id_links, input_bit_positions): - all_inputs.extend([f'{id_link}[{position}]' for position in bit_positions]) + for id_link, bit_positions in zip(self.input_id_links, self.input_bit_positions): + all_inputs.extend([f"{id_link}[{position}]" for position in bit_positions]) cp_constraints = [] - for i in range(output_size): - operation = f' == 0 /\\ '.join(all_inputs[i::output_size]) - new_constraint = f'constraint if {operation} == 0 then {output_id_link}[{i}] = 0 ' \ - f'else {output_id_link}[{i}] = 2 endif;' - cp_constraints.append(new_constraint) + for i in range(self.output_bit_size): + operation = " == 0 /\\ ".join(all_inputs[i :: self.output_bit_size]) + cp_constraint = f"constraint if {operation} == 0 then {self.id}[{i}] = 0 else {self.id}[{i}] = 2 endif;" + cp_constraints.append(cp_constraint) return cp_declarations, cp_constraints @@ -135,27 +138,33 @@ def cp_wordwise_deterministic_truncated_xor_differential_constraints(self, model ... 'constraint if sbox_0_14_active[0] == 0 then and_0_18_active[3] = 0 /\\ and_0_18_value[3] = 0 else and_0_18_active[3] = 3 /\\ and_0_18_value[3] = -2 endif;']) """ - - input_id_links = self.input_id_links - output_id_link = self.id - input_bit_positions = self.input_bit_positions cp_declarations = [] all_inputs_value = [] all_inputs_active = [] numadd = self.description[1] word_size = model.word_size - for id_link, bit_positions in zip(input_id_links, input_bit_positions): - all_inputs_value.extend([f'{id_link}_value[{bit_positions[j * word_size] // word_size}]' - for j in range(len(bit_positions) // word_size)]) - all_inputs_active.extend([f'{id_link}_active[{bit_positions[j * word_size] // word_size}]' - for j in range(len(bit_positions) // word_size)]) + for id_link, bit_positions in zip(self.input_id_links, self.input_bit_positions): + all_inputs_value.extend( + [ + f"{id_link}_value[{bit_positions[j * word_size] // word_size}]" + for j in range(len(bit_positions) // word_size) + ] + ) + all_inputs_active.extend( + [ + f"{id_link}_active[{bit_positions[j * word_size] // word_size}]" + for j in range(len(bit_positions) // word_size) + ] + ) input_len = len(all_inputs_value) // numadd cp_constraints = [] for i in range(input_len): - operation = f' == 0 /\\ '.join(all_inputs_active[i::input_len]) - new_constraint = f'constraint if {operation} == 0 then {output_id_link}_active[{i}] = 0 ' \ - f'/\\ {output_id_link}_value[{i}] = 0 else {output_id_link}_active[{i}] = 3 ' \ - f'/\\ {output_id_link}_value[{i}] = -2 endif;' + operation = " == 0 /\\ ".join(all_inputs_active[i::input_len]) + new_constraint = ( + f"constraint if {operation} == 0 then {self.id}_active[{i}] = 0 " + f"/\\ {self.id}_value[{i}] = 0 else {self.id}_active[{i}] = 3 " + f"/\\ {self.id}_value[{i}] = -2 endif;" + ) cp_constraints.append(new_constraint) return cp_declarations, cp_constraints @@ -181,30 +190,23 @@ def cp_xor_differential_propagation_constraints(self, model): ... 'constraint table([xor_0_7[11]]++[key[23]]++[and_0_8[11]]++[p[11]],and2inputs_DDT);']) """ - output_size = int(self.output_bit_size) - input_id_links = self.input_id_links - output_id_link = self.id - input_bit_positions = self.input_bit_positions num_add = self.description[1] all_inputs = [] - for id_link, bit_positions in zip(input_id_links, input_bit_positions): - all_inputs.extend([f'{id_link}[{position}]' for position in bit_positions]) + for id_link, bit_positions in zip(self.input_id_links, self.input_bit_positions): + all_inputs.extend([f"{id_link}[{position}]" for position in bit_positions]) input_len = len(all_inputs) // num_add cp_declarations = [] cp_constraints = [] probability = [] - for i in range(output_size): - new_constraint = f'constraint table(' - for j in range(num_add): - new_constraint = new_constraint + f'[{all_inputs[i + input_len * j]}]++' - new_constraint = new_constraint + f'[{output_id_link}[{i}]]++[p[{model.c}]],and{num_add}inputs_DDT);' - cp_constraints.append(new_constraint) + for i in range(self.output_bit_size): + inputs = "++".join(f"[{all_inputs[i + input_len * j]}]" for j in range(num_add)) + cp_constraint = f"constraint table({inputs}++[{self.id}[{i}]]++[p[{model.c}]],and{num_add}inputs_DDT);" + cp_constraints.append(cp_constraint) model.c += 1 probability.append(model.c) - model.component_and_probability[output_id_link] = probability - result = cp_declarations, cp_constraints + model.component_and_probability[self.id] = probability - return result + return cp_declarations, cp_constraints def generic_sign_linear_constraints(self, inputs, outputs): """AND component and OR component override this method.""" @@ -212,23 +214,22 @@ def generic_sign_linear_constraints(self, inputs, outputs): def get_word_operation_sign(self, sign, solution): output_id_link = self.id - input_size = self.input_bit_size - output_size = self.output_bit_size - input_int = int(solution['components_values'][f'{output_id_link}_i']['value'], 16) - output_int = int(solution['components_values'][f'{output_id_link}_o']['value'], 16) - inputs = [int(digit) for digit in format(input_int, f'0{input_size}b')] - outputs = [int(digit) for digit in format(output_int, f'0{output_size}b')] + input_int = int(solution["components_values"][f"{output_id_link}_i"]["value"], 16) + output_int = int(solution["components_values"][f"{output_id_link}_o"]["value"], 16) + inputs = [int(digit) for digit in format(input_int, f"0{self.input_bit_size}b")] + outputs = [int(digit) for digit in format(output_int, f"0{self.output_bit_size}b")] component_sign = self.generic_sign_linear_constraints(inputs, outputs) sign = sign * component_sign - solution['components_values'][f'{output_id_link}_o']['sign'] = component_sign - solution['components_values'][output_id_link] = solution['components_values'][f'{output_id_link}_o'] - del solution['components_values'][f'{output_id_link}_o'] - del solution['components_values'][f'{output_id_link}_i'] + solution["components_values"][f"{output_id_link}_o"]["sign"] = component_sign + solution["components_values"][output_id_link] = solution["components_values"][f"{output_id_link}_o"] + del solution["components_values"][f"{output_id_link}_o"] + del solution["components_values"][f"{output_id_link}_i"] return sign - def milp_twoterms_xor_linear_probability_constraints(self, binary_variable, integer_variable, - input_vars, output_vars, chunk_number): + def milp_twoterms_xor_linear_probability_constraints( + self, binary_variable, integer_variable, input_vars, output_vars, chunk_number + ): """ Return a variables list and a constraints list to compute the probability for AND component, for two inputs for MILP xor linear probability. @@ -252,15 +253,16 @@ def milp_twoterms_xor_linear_probability_constraints(self, binary_variable, inte inequalities = and_LAT() for ineq in inequalities: - for index in range(len(output_vars)): - tmp = x[input_vars[index]] * ineq[1] - tmp += x[input_vars[index + len(output_vars)]] * ineq[2] - tmp += x[output_vars[index]] * ineq[3] + for i, output_var in enumerate(output_vars): + tmp = x[input_vars[i]] * ineq[1] + tmp += x[input_vars[i + len(output_vars)]] * ineq[2] + tmp += x[output_var] * ineq[3] tmp += ineq[0] constraints.append(tmp >= 0) - constraints.append(p[self.id + "_and_probability" + str(chunk_number)] == - sum(x[output_vars[i]] for i in range(len(output_vars)))) + constraints.append( + p[f"{self.id}_and_probability{chunk_number}"] == sum(x[output_var] for output_var in output_vars) + ) return variables, constraints @@ -308,16 +310,18 @@ def milp_xor_differential_propagation_constraints(self, model): model.non_linear_component_id.append(component_id) inequalities = and_inequalities() for ineq in inequalities: - for index in range(len(output_vars)): + for i, output_var in enumerate(output_vars): tmp = 0 for number_of_chunk in range(self.description[1]): - tmp += x[input_vars[index + number_of_chunk * len(output_vars)]] * ineq[number_of_chunk + 1] - tmp += x[output_vars[index]] * ineq[self.description[1] + 1] - tmp += x[component_id + "_and_" + str(index)] * ineq[self.description[1] + 2] + tmp += x[input_vars[i + number_of_chunk * len(output_vars)]] * ineq[number_of_chunk + 1] + tmp += x[output_var] * ineq[self.description[1] + 1] + tmp += x[f"{component_id}_and_{i}"] * ineq[self.description[1] + 2] tmp += ineq[0] constraints.append(tmp >= 0) - constraints.append(p[component_id + "_probability"] == (10 ** model.weight_precision) * sum(x[component_id + "_and_" + str(i)] - for i in range(len(output_vars)))) + constraints.append( + p[component_id + "_probability"] + == (10**model.weight_precision) * sum(x[component_id + "_and_" + str(i)] for i in range(len(output_vars))) + ) result = variables, constraints return result @@ -372,34 +376,45 @@ def milp_xor_linear_mask_propagation_constraints(self, model): constraints = [] if number_of_inputs == 2: variables, constraints = self.milp_twoterms_xor_linear_probability_constraints( - binary_variable, integer_variable, input_vars, output_vars, 0) - constraints.append(p[component_id + "_probability"] == (10 ** model.weight_precision) * p[component_id + "_and_probability" + str(0)]) + binary_variable, integer_variable, input_vars, output_vars, 0 + ) + constraints.append( + p[component_id + "_probability"] + == (10**model.weight_precision) * p[component_id + "_and_probability" + str(0)] + ) elif number_of_inputs > 2: - temp_output_vars = [[f"{var}_temp_and_{i}" for var in output_vars] - for i in range(number_of_inputs - 2)] + temp_output_vars = [[f"{var}_temp_and_{i}" for var in output_vars] for i in range(number_of_inputs - 2)] variables, constraints = self.milp_twoterms_xor_linear_probability_constraints( - binary_variable, integer_variable, input_vars[:2 * output_bit_size], temp_output_vars[0], 0) + binary_variable, integer_variable, input_vars[: 2 * output_bit_size], temp_output_vars[0], 0 + ) for i in range(1, number_of_inputs - 2): temp_output_vars.extend([[f"{var}_temp_and_{i}" for var in output_vars]]) - temp_variables, temp_constraints = \ - self.milp_twoterms_xor_linear_probability_constraints( - binary_variable, integer_variable, - input_vars[(i + 1) * output_bit_size:(i + 2) * output_bit_size] + temp_output_vars[i - 1], - temp_output_vars[i], i) + temp_variables, temp_constraints = self.milp_twoterms_xor_linear_probability_constraints( + binary_variable, + integer_variable, + input_vars[(i + 1) * output_bit_size : (i + 2) * output_bit_size] + temp_output_vars[i - 1], + temp_output_vars[i], + i, + ) variables.extend(temp_variables) constraints.extend(temp_constraints) - temp_variables, temp_constraints = \ - self.milp_twoterms_xor_linear_probability_constraints( - binary_variable, integer_variable, - input_vars[(number_of_inputs - 1) * output_bit_size: number_of_inputs * output_bit_size] + - temp_output_vars[number_of_inputs - 3], output_vars, number_of_inputs - 2) + temp_variables, temp_constraints = self.milp_twoterms_xor_linear_probability_constraints( + binary_variable, + integer_variable, + input_vars[(number_of_inputs - 1) * output_bit_size : number_of_inputs * output_bit_size] + + temp_output_vars[number_of_inputs - 3], + output_vars, + number_of_inputs - 2, + ) variables.extend(temp_variables) constraints.extend(temp_constraints) constraints.append( - p[component_id + "_probability"] == (10 ** model.weight_precision) * sum(p[component_id + "_and_probability" + str(i)] - for i in range(number_of_inputs - 1))) + p[component_id + "_probability"] + == (10**model.weight_precision) + * sum(p[component_id + "_and_probability" + str(i)] for i in range(number_of_inputs - 1)) + ) result = variables, constraints return result @@ -437,10 +452,10 @@ def sat_bitwise_deterministic_truncated_xor_differential_constraints(self): out_len, out_ids_0, out_ids_1 = self._generate_output_double_ids() constraints = [] for i in range(out_len): - constraints.extend([f'{out_ids_0[i]} -{in_id}' for in_id in in_ids_0[i::out_len]]) - constraints.extend([f'{out_ids_0[i]} -{in_id}' for in_id in in_ids_1[i::out_len]]) - constraints.append(f'{out_ids_0[i]} -{out_ids_1[i]}') - clause = f'{" ".join(in_ids_0[i::out_len])} {" ".join(in_ids_1[i::out_len])} -{out_ids_0[i]}' + constraints.extend([f"{out_ids_0[i]} -{in_id}" for in_id in in_ids_0[i::out_len]]) + constraints.extend([f"{out_ids_0[i]} -{in_id}" for in_id in in_ids_1[i::out_len]]) + constraints.append(f"{out_ids_0[i]} -{out_ids_1[i]}") + clause = f"{' '.join(in_ids_0[i::out_len])} {' '.join(in_ids_1[i::out_len])} -{out_ids_0[i]}" constraints.append(clause) return out_ids_0 + out_ids_1, constraints @@ -474,13 +489,16 @@ def sat_xor_differential_propagation_constraints(self, model=None): '-xor_0_7_11 hw_and_0_8_11', '-key_23 hw_and_0_8_11']) """ - _, input_bit_ids = self._generate_input_ids() + input_bit_ids = self._generate_input_ids() output_bit_len, output_bit_ids = self._generate_output_ids() - hw_bit_ids = [f'hw_{output_bit_ids[i]}' for i in range(output_bit_len)] + hw_bit_ids = [f"hw_{output_bit_ids[i]}" for i in range(output_bit_len)] constraints = [] for i in range(output_bit_len): - constraints.extend(sat_utils.cnf_and_differential(input_bit_ids[i], input_bit_ids[output_bit_len + i], - output_bit_ids[i], hw_bit_ids[i])) + constraints.extend( + sat_utils.cnf_and_differential( + input_bit_ids[i], input_bit_ids[output_bit_len + i], output_bit_ids[i], hw_bit_ids[i] + ) + ) result = output_bit_ids + hw_bit_ids, constraints return result @@ -513,11 +531,14 @@ def sat_xor_linear_mask_propagation_constraints(self, model=None): _, input_bit_ids = self._generate_component_input_ids() out_suffix = constants.OUTPUT_BIT_ID_SUFFIX output_bit_len, output_bit_ids = self._generate_output_ids(out_suffix) - hw_bit_ids = [f'hw_{output_bit_ids[i]}' for i in range(output_bit_len)] + hw_bit_ids = [f"hw_{output_bit_ids[i]}" for i in range(output_bit_len)] constraints = [] for i in range(output_bit_len): - constraints.extend(sat_utils.cnf_and_linear(input_bit_ids[i], input_bit_ids[output_bit_len + i], - output_bit_ids[i], hw_bit_ids[i])) + constraints.extend( + sat_utils.cnf_and_linear( + input_bit_ids[i], input_bit_ids[output_bit_len + i], output_bit_ids[i], hw_bit_ids[i] + ) + ) result = input_bit_ids + output_bit_ids + hw_bit_ids, constraints return result @@ -555,15 +576,19 @@ def smt_xor_differential_propagation_constraints(self, model=None): '(assert (or (and (not xor_0_7_10) (not key_22) (not and_0_8_10) (not hw_and_0_8_10)) (and xor_0_7_10 hw_and_0_8_10) (and key_22 hw_and_0_8_10)))', '(assert (or (and (not xor_0_7_11) (not key_23) (not and_0_8_11) (not hw_and_0_8_11)) (and xor_0_7_11 hw_and_0_8_11) (and key_23 hw_and_0_8_11)))']) """ - _, input_bit_ids = self._generate_input_ids() + input_bit_ids = self._generate_input_ids() output_bit_len, output_bit_ids = self._generate_output_ids() - hw_bit_ids = [f'hw_{output_bit_ids[i]}' for i in range(output_bit_len)] + hw_bit_ids = [f"hw_{output_bit_ids[i]}" for i in range(output_bit_len)] constraints = [] for i in range(output_bit_len): - minterm_0 = smt_utils.smt_and((smt_utils.smt_not(input_bit_ids[i]), - smt_utils.smt_not(input_bit_ids[output_bit_len + i]), - smt_utils.smt_not(output_bit_ids[i]), - smt_utils.smt_not(hw_bit_ids[i]))) + minterm_0 = smt_utils.smt_and( + ( + smt_utils.smt_not(input_bit_ids[i]), + smt_utils.smt_not(input_bit_ids[output_bit_len + i]), + smt_utils.smt_not(output_bit_ids[i]), + smt_utils.smt_not(hw_bit_ids[i]), + ) + ) minterm_1 = smt_utils.smt_and((input_bit_ids[i], hw_bit_ids[i])) minterm_2 = smt_utils.smt_and((input_bit_ids[output_bit_len + i], hw_bit_ids[i])) sop = smt_utils.smt_or((minterm_0, minterm_1, minterm_2)) @@ -600,13 +625,17 @@ def smt_xor_linear_mask_propagation_constraints(self, model=None): _, input_bit_ids = self._generate_component_input_ids() out_suffix = constants.OUTPUT_BIT_ID_SUFFIX output_bit_len, output_bit_ids = self._generate_output_ids(out_suffix) - hw_bit_ids = [f'hw_{output_bit_ids[i]}' for i in range(output_bit_len)] + hw_bit_ids = [f"hw_{output_bit_ids[i]}" for i in range(output_bit_len)] constraints = [] for i in range(output_bit_len): - minterm_0 = smt_utils.smt_and((smt_utils.smt_not(input_bit_ids[i]), - smt_utils.smt_not(input_bit_ids[output_bit_len + i]), - smt_utils.smt_not(output_bit_ids[i]), - smt_utils.smt_not(hw_bit_ids[i]))) + minterm_0 = smt_utils.smt_and( + ( + smt_utils.smt_not(input_bit_ids[i]), + smt_utils.smt_not(input_bit_ids[output_bit_len + i]), + smt_utils.smt_not(output_bit_ids[i]), + smt_utils.smt_not(hw_bit_ids[i]), + ) + ) minterm_1 = smt_utils.smt_and((output_bit_ids[i], hw_bit_ids[i])) sop = smt_utils.smt_or((minterm_0, minterm_1)) constraints.append(smt_utils.smt_assert(sop)) diff --git a/claasp/components/not_component.py b/claasp/components/not_component.py index c30b7c906..76430a4e7 100644 --- a/claasp/components/not_component.py +++ b/claasp/components/not_component.py @@ -1,17 +1,16 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -21,14 +20,21 @@ from claasp.component import Component from claasp.cipher_modules.models.smt.utils import utils as smt_utils from claasp.cipher_modules.models.sat.utils import constants, utils as sat_utils +from claasp.name_mappings import WORD_OPERATION class NOT(Component): - def __init__(self, current_round_number, current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size): - component_id = f'not_{current_round_number}_{current_round_number_of_components}' - component_type = 'word_operation' - description = ['NOT', 0] + def __init__( + self, + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + ): + component_id = f"not_{current_round_number}_{current_round_number_of_components}" + component_type = WORD_OPERATION + description = ["NOT", 0] component_input = Input(output_bit_size, input_id_links, input_bit_positions) super().__init__(component_id, component_type, component_input, output_bit_size, description) @@ -58,8 +64,8 @@ def algebraic_polynomials(self, model): """ ninputs = self.input_bit_size noutputs = self.output_bit_size - input_vars = [self.id + "_" + model.input_postfix + str(i) for i in range(ninputs)] - output_vars = [self.id + "_" + model.output_postfix + str(i) for i in range(noutputs)] + input_vars = [f"{self.id}_{model.input_postfix}{i}" for i in range(ninputs)] + output_vars = [f"{self.id}_{model.output_postfix}{i}" for i in range(noutputs)] ring_R = model.ring() x = list(map(ring_R, input_vars)) y = list(map(ring_R, output_vars)) @@ -121,15 +127,11 @@ def cp_constraints(self): ... 'constraint not_0_8[31] = (xor_0_6[31] + 1) mod 2;']) """ - input_id_links = self.input_id_links - output_id_link = self.id - input_bit_positions = self.input_bit_positions cp_declarations = [] all_inputs = [] - for id_link, bit_positions in zip(input_id_links, input_bit_positions): - all_inputs.extend([f'{id_link}[{position}]' for position in bit_positions]) - cp_constraints = [f'constraint {output_id_link}[{i}] = ({input_} + 1) mod 2;' - for i, input_ in enumerate(all_inputs)] + for id_link, bit_positions in zip(self.input_id_links, self.input_bit_positions): + all_inputs.extend([f"{id_link}[{position}]" for position in bit_positions]) + cp_constraints = [f"constraint {self.id}[{i}] = ({input_} + 1) mod 2;" for i, input_ in enumerate(all_inputs)] return cp_declarations, cp_constraints @@ -152,15 +154,11 @@ def cp_deterministic_truncated_xor_differential_constraints(self): ... 'constraint not_0_8[31] = xor_0_6[31];']) """ - input_id_links = self.input_id_links - output_id_link = self.id - input_bit_positions = self.input_bit_positions cp_declarations = [] all_inputs = [] - for id_link, bit_positions in zip(input_id_links, input_bit_positions): - all_inputs.extend([f'{id_link}[{position}]' for position in bit_positions]) - cp_constraints = [f'constraint {output_id_link}[{i}] = {input_};' - for i, input_ in enumerate(all_inputs)] + for id_link, bit_positions in zip(self.input_id_links, self.input_bit_positions): + all_inputs.extend([f"{id_link}[{position}]" for position in bit_positions]) + cp_constraints = [f"constraint {self.id}[{i}] = {input_};" for i, input_ in enumerate(all_inputs)] return cp_declarations, cp_constraints @@ -168,28 +166,34 @@ def cp_deterministic_truncated_xor_differential_trail_constraints(self): return self.cp_deterministic_truncated_xor_differential_constraints() def cp_wordwise_deterministic_truncated_xor_differential_constraints(self, model): - input_id_links = self.input_id_links - output_id_link = self.id - input_bit_positions = self.input_bit_positions cp_declarations = [] all_inputs_value = [] all_inputs_active = [] word_size = model.word_size - for id_link, bit_positions in zip(input_id_links, input_bit_positions): - all_inputs_value.extend([f'{id_link}_value[{bit_positions[j * word_size] // word_size}]' - for j in range(len(bit_positions) // word_size)]) - all_inputs_active.extend([f'{id_link}_active[{bit_positions[j * word_size] // word_size}]' - for j in range(len(bit_positions) // word_size)]) + for id_link, bit_positions in zip(self.input_id_links, self.input_bit_positions): + all_inputs_value.extend( + [ + f"{id_link}_value[{bit_positions[j * word_size] // word_size}]" + for j in range(len(bit_positions) // word_size) + ] + ) + all_inputs_active.extend( + [ + f"{id_link}_active[{bit_positions[j * word_size] // word_size}]" + for j in range(len(bit_positions) // word_size) + ] + ) input_len = len(all_inputs_value) cp_constraints = [] for i in range(input_len): - cp_constraints.append(f'constraint {output_id_link}_active[{i}] = {all_inputs_active[i]};') - cp_constraints.append(f'if {all_inputs_value[i]} < 0 then {output_id_link}_value[{i}] = {all_inputs_value[i]} '\ - f'else {output_id_link}_value[{i}] = {2**word_size - 1} - {all_inputs_value[i]}') + cp_constraints.append(f"constraint {self.id}_active[{i}] = {all_inputs_active[i]};") + cp_constraints.append( + f"if {all_inputs_value[i]} < 0 then {self.id}_value[{i}] = {all_inputs_value[i]} " + f"else {self.id}_value[{i}] = {2**word_size - 1} - {all_inputs_value[i]}" + ) return cp_declarations, cp_constraints - def cp_xor_differential_first_step_constraints(self, model): """ Return lists of declarations and constraints for NOT component for the CP xor differential first step model. @@ -213,18 +217,17 @@ def cp_xor_differential_first_step_constraints(self, model): 'constraint not_0_18[2] = sbox_0_10[0];', 'constraint not_0_18[3] = sbox_0_14[0];']) """ - output_size = int(self.output_bit_size) - input_id_links = self.input_id_links - output_id_link = self.id - input_bit_positions = self.input_bit_positions word_size = model.word_size - cp_declarations = [f'array[0..{(output_size - 1) // model.word_size}] of var 0..1: {output_id_link};'] + cp_declarations = [f"array[0..{(self.output_bit_size - 1) // model.word_size}] of var 0..1: {self.id};"] all_inputs = [] - for id_link, bit_positions in zip(input_id_links, input_bit_positions): - all_inputs.extend([f'{id_link}[{bit_positions[j * word_size] // word_size}]' - for j in range(len(bit_positions) // word_size)]) - cp_constraints = [f'constraint {output_id_link}[{i}] = {input_};' - for i, input_ in enumerate(all_inputs)] + for id_link, bit_positions in zip(self.input_id_links, self.input_bit_positions): + all_inputs.extend( + [ + f"{id_link}[{bit_positions[j * word_size] // word_size}]" + for j in range(len(bit_positions) // word_size) + ] + ) + cp_constraints = [f"constraint {self.id}[{i}] = {input_};" for i, input_ in enumerate(all_inputs)] return cp_declarations, cp_constraints @@ -247,17 +250,13 @@ def cp_xor_differential_propagation_constraints(self, model=None): ... 'constraint not_0_8[31] = xor_0_6[31];']) """ - input_id_links = self.input_id_links - output_id_link = self.id - input_bit_positions = self.input_bit_positions cp_declarations = [] all_inputs = [] - for id_link, bit_positions in zip(input_id_links, input_bit_positions): - all_inputs.extend([f'{id_link}[{position}]' for position in bit_positions]) - cp_constraints = [f'constraint {output_id_link}[{i}] = {input_};' - for i, input_ in enumerate(all_inputs)] - result = cp_declarations, cp_constraints - return result + for id_link, bit_positions in zip(self.input_id_links, self.input_bit_positions): + all_inputs.extend([f"{id_link}[{position}]" for position in bit_positions]) + cp_constraints = [f"constraint {self.id}[{i}] = {input_};" for i, input_ in enumerate(all_inputs)] + + return cp_declarations, cp_constraints def cp_xor_differential_propagation_first_step_constraints(self, model): return self.cp_xor_differential_first_step_constraints(model) @@ -282,35 +281,32 @@ def cp_xor_linear_mask_propagation_constraints(self, model=None): ... 'constraint not_0_5_o[63]=not_0_5_i[63];']) """ - input_size = int(self.input_bit_size) - output_size = int(self.output_bit_size) - output_id_link = self.id - cp_declarations = [] + cp_declarations = [ + f"array[0..{self.input_bit_size - 1}] of var 0..1:{self.id}_i;", + f"array[0..{self.output_bit_size - 1}] of var 0..1:{self.id}_o;", + ] cp_constraints = [] - cp_declarations.append(f'array[0..{input_size - 1}] of var 0..1:{output_id_link}_i;') - cp_declarations.append(f'array[0..{output_size - 1}] of var 0..1:{output_id_link}_o;') - for i in range(input_size): - cp_constraints.append(f'constraint {output_id_link}_o[{i}]={output_id_link}_i[{i}];') - result = cp_declarations, cp_constraints - return result + for i in range(self.input_bit_size): + cp_constraints.append(f"constraint {self.id}_o[{i}]={self.id}_i[{i}];") + + return cp_declarations, cp_constraints def get_bit_based_vectorized_python_code(self, params, convert_output_to_bytes): - return [f' {self.id} = bit_vector_NOT([{",".join(params)} ])'] + return [f" {self.id} = bit_vector_NOT([{','.join(params)} ])"] def get_byte_based_vectorized_python_code(self, params): - return [f' {self.id} = byte_vector_NOT({params})'] + return [f" {self.id} = byte_vector_NOT({params})"] def get_word_operation_sign(self, sign, solution): output_id_link = self.id - input_size = self.input_bit_size - input_int = int(solution['components_values'][f'{output_id_link}_i']['value'], 16) - inputs = [int(digit) for digit in format(input_int, f'0{input_size}b')] + input_int = int(solution["components_values"][f"{output_id_link}_i"]["value"], 16) + inputs = [int(digit) for digit in format(input_int, f"0{self.input_bit_size}b")] component_sign = self.generic_sign_linear_constraints(inputs) sign = sign * component_sign - solution['components_values'][f'{output_id_link}_o']['sign'] = component_sign - solution['components_values'][output_id_link] = solution['components_values'][f'{output_id_link}_o'] - del solution['components_values'][f'{output_id_link}_o'] - del solution['components_values'][f'{output_id_link}_i'] + solution["components_values"][f"{output_id_link}_o"]["sign"] = component_sign + solution["components_values"][output_id_link] = solution["components_values"][f"{output_id_link}_o"] + del solution["components_values"][f"{output_id_link}_o"] + del solution["components_values"][f"{output_id_link}_i"] return sign @@ -332,11 +328,7 @@ def generic_sign_linear_constraints(self, inputs): sage: not_component.generic_sign_linear_constraints(inputs) 1 """ - ones = 0 - for entry in inputs: - if entry == 1: - ones += 1 - parity = ones % 2 + parity = inputs.count(1) % 2 if parity == 1: sign = -1 else: @@ -541,7 +533,7 @@ def sat_constraints(self): 'not_0_8_31 xor_0_6_31', '-not_0_8_31 -xor_0_6_31']) """ - _, input_bit_ids = self._generate_input_ids() + input_bit_ids = self._generate_input_ids() output_bit_len, output_bit_ids = self._generate_output_ids() constraints = [] for i in range(output_bit_len): @@ -589,8 +581,8 @@ def sat_bitwise_deterministic_truncated_xor_differential_constraints(self): for out_id, in_id in zip(out_ids_0, in_ids_0): constraints.extend(sat_utils.cnf_equivalent([out_id, in_id])) for out_id, in_id_0, in_id_1 in zip(out_ids_1, in_ids_0, in_ids_1): - constraints.append(f'{in_id_0} {in_id_1} {out_id}') - constraints.append(f'{in_id_0} -{in_id_1} -{out_id}') + constraints.append(f"{in_id_0} {in_id_1} {out_id}") + constraints.append(f"{in_id_0} -{in_id_1} -{out_id}") return out_ids_0 + out_ids_1, constraints @@ -626,7 +618,7 @@ def sat_xor_differential_propagation_constraints(self, model=None): 'not_0_8_31 -xor_0_6_31', 'xor_0_6_31 -not_0_8_31']) """ - _, input_bit_ids = self._generate_input_ids() + input_bit_ids = self._generate_input_ids() output_bit_len, output_bit_ids = self._generate_output_ids() constraints = [] for i in range(output_bit_len): @@ -707,7 +699,7 @@ def smt_constraints(self): '(assert (distinct not_0_5_62 xor_0_2_62))', '(assert (distinct not_0_5_63 xor_0_2_63))']) """ - _, input_bit_ids = self._generate_input_ids() + input_bit_ids = self._generate_input_ids() output_bit_len, output_bit_ids = self._generate_output_ids() constraints = [] for i in range(output_bit_len): @@ -744,7 +736,7 @@ def smt_xor_differential_propagation_constraints(self, model=None): '(assert (= not_0_5_62 xor_0_2_62))', '(assert (= not_0_5_63 xor_0_2_63))']) """ - _, input_bit_ids = self._generate_input_ids() + input_bit_ids = self._generate_input_ids() output_bit_len, output_bit_ids = self._generate_output_ids() constraints = [] for i in range(output_bit_len): @@ -784,7 +776,9 @@ def smt_xor_linear_mask_propagation_constraints(self, model=None): _, input_bit_ids = self._generate_component_input_ids() out_suffix = constants.OUTPUT_BIT_ID_SUFFIX _, output_bit_ids = self._generate_output_ids(suffix=out_suffix) - constraints = [smt_utils.smt_assert(smt_utils.smt_equivalent((input_bit_id, output_bit_id))) - for input_bit_id, output_bit_id in zip(input_bit_ids, output_bit_ids)] + constraints = [ + smt_utils.smt_assert(smt_utils.smt_equivalent((input_bit_id, output_bit_id))) + for input_bit_id, output_bit_id in zip(input_bit_ids, output_bit_ids) + ] result = input_bit_ids + output_bit_ids, constraints return result diff --git a/claasp/components/or_component.py b/claasp/components/or_component.py index 35744bc87..b7f40fb73 100644 --- a/claasp/components/or_component.py +++ b/claasp/components/or_component.py @@ -1,16 +1,16 @@ # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -22,10 +22,22 @@ class OR(MultiInputNonlinearLogicalOperator): - def __init__(self, current_round_number, current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size): - super().__init__(current_round_number, current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size, 'or') + def __init__( + self, + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + ): + super().__init__( + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + "or", + ) def algebraic_polynomials(self, model): """ @@ -82,9 +94,9 @@ def algebraic_polynomials(self, model): ors_number = self.description[1] - 1 word_size = noutputs ring_R = model.ring() - input_vars = [self.id + "_" + model.input_postfix + str(i) for i in range(ninputs)] - output_vars = [self.id + "_" + model.output_postfix + str(i) for i in range(noutputs)] - words_vars = [list(map(ring_R, input_vars))[i:i + word_size] for i in range(0, ninputs, word_size)] + input_vars = [f"{self.id}_{model.input_postfix}{i}" for i in range(ninputs)] + output_vars = [f"{self.id}_{model.output_postfix}{i}" for i in range(noutputs)] + words_vars = [list(map(ring_R, input_vars))[i : i + word_size] for i in range(0, ninputs, word_size)] def or_polynomial(x0, x1): return x0 * x1 + x0 + x1 @@ -121,7 +133,6 @@ def cp_constraints(self): 'constraint pre_or_0_9_1[11]=key[23];', 'constraint or(pre_or_0_9_0, pre_or_0_9_1, or_0_9);']) """ - output_size = int(self.output_bit_size) input_id_link = self.input_id_links numb_of_inp = len(input_id_link) output_id_link = self.id @@ -132,29 +143,31 @@ def cp_constraints(self): all_inputs = [] for i in range(numb_of_inp): for j in range(len(input_bit_positions[i])): - all_inputs.append(f'{input_id_link[i]}[{input_bit_positions[i][j]}]') + all_inputs.append(f"{input_id_link[i]}[{input_bit_positions[i][j]}]") total_input_len = len(all_inputs) input_len = total_input_len // num_add - cp_declarations.append(f'array[0..{output_size - 1}] of var 0..1: {output_id_link};') + cp_declarations.append(f"array[0..{self.output_bit_size - 1}] of var 0..1: {output_id_link};") for i in range(num_add): - cp_declarations.append(f'array[0..{input_len - 1}] of var 0..1:pre_{output_id_link}_{i};') + cp_declarations.append(f"array[0..{input_len - 1}] of var 0..1:pre_{output_id_link}_{i};") for j in range(input_len): - cp_constraints.append(f'constraint pre_{output_id_link}_{i}[{j}]={all_inputs[i * input_len + j]};') + cp_constraints.append(f"constraint pre_{output_id_link}_{i}[{j}]={all_inputs[i * input_len + j]};") for i in range(num_add - 2): - cp_declarations.append(f'array[0..{output_size - 1}] of var 0..1:temp_{output_id_link}_{i};') + cp_declarations.append(f"array[0..{self.output_bit_size - 1}] of var 0..1:temp_{output_id_link}_{i};") if num_add == 2: - cp_constraints.append( - f'constraint or(pre_{output_id_link}_0, pre_{output_id_link}_1, {output_id_link});') + cp_constraints.append(f"constraint or(pre_{output_id_link}_0, pre_{output_id_link}_1, {output_id_link});") elif num_add > 2: cp_constraints.append( - f'constraint or(pre_{output_id_link}_0, pre_{output_id_link}_1, temp_{output_id_link}_0);') + f"constraint or(pre_{output_id_link}_0, pre_{output_id_link}_1, temp_{output_id_link}_0);" + ) for i in range(1, num_add - 2): cp_constraints.append( - f'constraint or(pre_{output_id_link}_{i + 1}, temp_{output_id_link}_{i - 1}, ' - f'temp_{output_id_link}_{i});') + f"constraint or(pre_{output_id_link}_{i + 1}, temp_{output_id_link}_{i - 1}, " + f"temp_{output_id_link}_{i});" + ) cp_constraints.append( - f'constraint or(pre_{output_id_link}_{num_add - 1}, temp_{output_id_link}_{num_add - 3},' - f'{output_id_link});') + f"constraint or(pre_{output_id_link}_{num_add - 1}, temp_{output_id_link}_{num_add - 3}," + f"{output_id_link});" + ) return cp_declarations, cp_constraints @@ -186,31 +199,28 @@ def cp_xor_linear_mask_propagation_constraints(self, model): 'constraint p[0] = sum(p_or_39_6);'] """ - input_size = int(self.input_bit_size) - output_size = int(self.output_bit_size) - output_id_link = self.id - cp_declarations = [] + cp_declarations = [ + f"array[0..{self.output_bit_size - 1}] of var 0..{100 * self.output_bit_size}: p_{self.id};", + f"array[0..{self.input_bit_size - 1}] of var 0..1:{self.id}_i;", + f"array[0..{self.output_bit_size - 1}] of var 0..1:{self.id}_o;", + ] cp_constraints = [] num_add = self.description[1] - input_len = input_size // num_add - cp_declarations.append(f'array[0..{output_size - 1}] of var 0..{100 * output_size}: p_{output_id_link};') - cp_declarations.append(f'array[0..{input_size - 1}] of var 0..1:{output_id_link}_i;') - cp_declarations.append(f'array[0..{output_size - 1}] of var 0..1:{output_id_link}_o;') - model.component_and_probability[output_id_link] = 0 + input_len = self.input_bit_size // num_add + model.component_and_probability[self.id] = 0 p_count = 0 - for i in range(output_size): - new_constraint = f'constraint table(' - for j in range(num_add): - new_constraint = new_constraint + f'[{output_id_link}_i[{i + input_len * j}]]++' - new_constraint = new_constraint + f'[{output_id_link}_o[{i}]]++[p_{output_id_link}[{p_count}]],and{num_add}inputs_LAT);' - cp_constraints.append(new_constraint) + for i in range(self.output_bit_size): + inputs = "++".join(f"[{self.id}_i[{i + input_len * j}]]" for j in range(num_add)) + cp_constraint = ( + f"constraint table({inputs}++[{self.id}_o[{i}]]++[p_{self.id}[{p_count}]],and{num_add}inputs_LAT);" + ) + cp_constraints.append(cp_constraint) p_count = p_count + 1 - cp_constraints.append(f'constraint p[{model.c}] = sum(p_{output_id_link});') - model.component_and_probability[output_id_link] = model.c + cp_constraints.append(f"constraint p[{model.c}] = sum(p_{self.id});") + model.component_and_probability[self.id] = model.c model.c = model.c + 1 - result = cp_declarations, cp_constraints - return result + return cp_declarations, cp_constraints def generic_sign_linear_constraints(self, inputs, outputs): """ @@ -240,10 +250,10 @@ def generic_sign_linear_constraints(self, inputs, outputs): return sign def get_bit_based_vectorized_python_code(self, params, convert_output_to_bytes): - return [f' {self.id} = bit_vector_OR([{",".join(params)} ], {self.description[1]}, {self.output_bit_size})'] + return [f" {self.id} = bit_vector_OR([{','.join(params)} ], {self.description[1]}, {self.output_bit_size})"] def get_byte_based_vectorized_python_code(self, params): - return [f' {self.id} = byte_vector_OR({params})'] + return [f" {self.id} = byte_vector_OR({params})"] def sat_constraints(self): """ @@ -278,7 +288,7 @@ def sat_constraints(self): 'or_0_4_31 -xor_0_1_31', '-or_0_4_31 xor_0_3_31 xor_0_1_31']) """ - _, input_bit_ids = self._generate_input_ids() + input_bit_ids = self._generate_input_ids() output_bit_len, output_bit_ids = self._generate_output_ids() constraints = [] for i in range(output_bit_len): @@ -315,7 +325,7 @@ def smt_constraints(self): '(assert (= or_0_4_30 (or xor_0_3_30 xor_0_1_30)))', '(assert (= or_0_4_31 (or xor_0_3_31 xor_0_1_31)))']) """ - _, input_bit_ids = self._generate_input_ids() + input_bit_ids = self._generate_input_ids() output_bit_len, output_bit_ids = self._generate_output_ids() constraints = [] for i in range(output_bit_len): diff --git a/claasp/components/permutation_component.py b/claasp/components/permutation_component.py index 4e26d0889..49bae53d7 100644 --- a/claasp/components/permutation_component.py +++ b/claasp/components/permutation_component.py @@ -1,17 +1,16 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -21,12 +20,25 @@ class Permutation(LinearLayer): - def __init__(self, current_round_number, current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size, permutation_description): + def __init__( + self, + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + permutation_description, + ): matrix = [] for i in range(output_bit_size): row = [0] * output_bit_size row[permutation_description[i]] = 1 matrix.append(row) - super().__init__(current_round_number, current_round_number_of_components, input_id_links, - input_bit_positions, output_bit_size, matrix) + super().__init__( + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + matrix, + ) diff --git a/claasp/components/reverse_component.py b/claasp/components/reverse_component.py index 1f63763ef..a9d2b9237 100644 --- a/claasp/components/reverse_component.py +++ b/claasp/components/reverse_component.py @@ -1,17 +1,16 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -21,12 +20,24 @@ class Reverse(LinearLayer): - def __init__(self, current_round_number, current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size): + def __init__( + self, + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + ): matrix = [] for i in range(output_bit_size): row = [0] * output_bit_size row[output_bit_size - i - 1] = 1 matrix.append(row) - super().__init__(current_round_number, current_round_number_of_components, input_id_links, - input_bit_positions, output_bit_size, matrix) + super().__init__( + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + matrix, + ) diff --git a/claasp/components/rotate_component.py b/claasp/components/rotate_component.py index ba1ef3763..4dee6bb6b 100644 --- a/claasp/components/rotate_component.py +++ b/claasp/components/rotate_component.py @@ -1,17 +1,16 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -21,14 +20,22 @@ from claasp.component import Component from claasp.cipher_modules.models.smt.utils import utils as smt_utils from claasp.cipher_modules.models.sat.utils import constants, utils as sat_utils +from claasp.name_mappings import WORD_OPERATION class Rotate(Component): - def __init__(self, current_round_number, current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size, parameter): - component_id = f'rot_{current_round_number}_{current_round_number_of_components}' - component_type = 'word_operation' - description = ['ROTATE', parameter] + def __init__( + self, + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + parameter, + ): + component_id = f"rot_{current_round_number}_{current_round_number_of_components}" + component_type = WORD_OPERATION + description = ["ROTATE", parameter] component_input = Input(output_bit_size, input_id_links, input_bit_positions) super().__init__(component_id, component_type, component_input, output_bit_size, description) @@ -60,8 +67,8 @@ def algebraic_polynomials(self, model): rotation_const = self.description[1] ninputs = noutputs = self.output_bit_size - input_vars = [self.id + "_" + model.input_postfix + str(i) for i in range(ninputs)] - output_vars = [self.id + "_" + model.output_postfix + str(i) for i in range(noutputs)] + input_vars = [f"{self.id}_{model.input_postfix}{i}" for i in range(ninputs)] + output_vars = [f"{self.id}_{model.output_postfix}{i}" for i in range(noutputs)] ring_R = model.ring() x = list(map(ring_R, input_vars)) y = list(map(ring_R, output_vars)) @@ -122,22 +129,22 @@ def cp_constraints(self): ... 'constraint rot_0_0[15] = plaintext[8];']) """ - output_size = int(self.output_bit_size) - input_id_links = self.input_id_links - output_id_link = self.id - input_bit_positions = self.input_bit_positions rot_amount = abs(self.description[1]) all_inputs = [] - for id_link, bit_positions in zip(input_id_links, input_bit_positions): - all_inputs.extend([f'{id_link}[{position}]' for position in bit_positions]) + for id_link, bit_positions in zip(self.input_id_links, self.input_bit_positions): + all_inputs.extend([f"{id_link}[{position}]" for position in bit_positions]) cp_declarations = [] input_len = len(all_inputs) if rot_amount == self.description[1]: - cp_constraints = [f'constraint {output_id_link}[{i}] = {all_inputs[(i - rot_amount) % input_len]};' - for i in range(output_size)] + cp_constraints = [ + f"constraint {self.id}[{i}] = {all_inputs[(i - rot_amount) % input_len]};" + for i in range(self.output_bit_size) + ] else: - cp_constraints = [f'constraint {output_id_link}[{i}] = {all_inputs[(i + rot_amount) % input_len]};' - for i in range(output_size)] + cp_constraints = [ + f"constraint {self.id}[{i}] = {all_inputs[(i + rot_amount) % input_len]};" + for i in range(self.output_bit_size) + ] return cp_declarations, cp_constraints @@ -163,45 +170,54 @@ def cp_inverse_constraints(self): ... 'constraint rot_0_0_inverse[15] = plaintext[8];']) """ - output_size = int(self.output_bit_size) - input_id_links = self.input_id_links - output_id_link = self.id - input_bit_positions = self.input_bit_positions rot_amount = abs(self.description[1]) all_inputs = [] - for id_link, bit_positions in zip(input_id_links, input_bit_positions): - all_inputs.extend([f'{id_link}[{position}]' for position in bit_positions]) + for id_link, bit_positions in zip(self.input_id_links, self.input_bit_positions): + all_inputs.extend([f"{id_link}[{position}]" for position in bit_positions]) cp_declarations = [] input_len = len(all_inputs) if rot_amount == self.description[1]: - cp_constraints = [f'constraint {output_id_link}_inverse[{i}] = {all_inputs[(i - rot_amount) % input_len]};' - for i in range(output_size)] + cp_constraints = [ + f"constraint {self.id}_inverse[{i}] = {all_inputs[(i - rot_amount) % input_len]};" + for i in range(self.output_bit_size) + ] else: - cp_constraints = [f'constraint {output_id_link}_inverse[{i}] = {all_inputs[(i + rot_amount) % input_len]};' - for i in range(output_size)] + cp_constraints = [ + f"constraint {self.id}_inverse[{i}] = {all_inputs[(i + rot_amount) % input_len]};" + for i in range(self.output_bit_size) + ] return cp_declarations, cp_constraints def cp_wordwise_deterministic_truncated_xor_differential_constraints(self, model): - input_id_links = self.input_id_links - output_id_link = self.id - input_bit_positions = self.input_bit_positions cp_declarations = [] all_inputs_value = [] all_inputs_active = [] word_size = model.word_size rot_amount = self.description[1] // word_size - for id_link, bit_positions in zip(input_id_links, input_bit_positions): - all_inputs_value.extend([f'{id_link}_value[{bit_positions[j * word_size] // word_size}]' - for j in range(len(bit_positions) // word_size)]) - all_inputs_active.extend([f'{id_link}_active[{bit_positions[j * word_size] // word_size}]' - for j in range(len(bit_positions) // word_size)]) + for id_link, bit_positions in zip(self.input_id_links, self.input_bit_positions): + all_inputs_value.extend( + [ + f"{id_link}_value[{bit_positions[j * word_size] // word_size}]" + for j in range(len(bit_positions) // word_size) + ] + ) + all_inputs_active.extend( + [ + f"{id_link}_active[{bit_positions[j * word_size] // word_size}]" + for j in range(len(bit_positions) // word_size) + ] + ) input_len = len(all_inputs_value) cp_constraints = [] for i in range(input_len): - cp_constraints.append(f'constraint {output_id_link}_active[{i}] = {all_inputs_active[(i - rot_amount) % input_len]};') - cp_constraints.append(f'constraint {output_id_link}_value[{i}] = {all_inputs_value[(i - rot_amount) % input_len]};') - + cp_constraints.append( + f"constraint {self.id}_active[{i}] = {all_inputs_active[(i - rot_amount) % input_len]};" + ) + cp_constraints.append( + f"constraint {self.id}_value[{i}] = {all_inputs_value[(i - rot_amount) % input_len]};" + ) + return cp_declarations, cp_constraints def cp_xor_differential_first_step_constraints(self, model): @@ -226,7 +242,6 @@ def cp_xor_differential_first_step_constraints(self, model): 'constraint rot_0_18[2] = sbox_0_14[0];', 'constraint rot_0_18[3] = sbox_0_2[0];']) """ - output_size = int(self.output_bit_size) input_id_link = self.input_id_links output_id_link = self.id input_bit_positions = self.input_bit_positions @@ -238,11 +253,11 @@ def cp_xor_differential_first_step_constraints(self, model): is_mix = False for i in range(numb_of_inp): for j in range(len(input_bit_positions[i]) // word_size): - all_inputs.append(f'{input_id_link[i]}[{input_bit_positions[i][j * word_size] // word_size}]') + all_inputs.append(f"{input_id_link[i]}[{input_bit_positions[i][j * word_size] // word_size}]") rem = len(input_bit_positions[i]) % word_size if rem != 0: rem = word_size - (len(input_bit_positions[i]) % word_size) - all_inputs.append(f'{output_id_link}_i[{number_of_mix}]') + all_inputs.append(f"{output_id_link}_i[{number_of_mix}]") number_of_mix += 1 is_mix = True l = 1 @@ -251,16 +266,20 @@ def cp_xor_differential_first_step_constraints(self, model): del input_bit_positions[i + l][0:rem] rem -= length l += 1 - cp_declarations = [f'array[0..{(output_size - 1) // word_size}] of var 0..1: {output_id_link};'] + cp_declarations = [f"array[0..{(self.output_bit_size - 1) // word_size}] of var 0..1: {output_id_link};"] if is_mix: - cp_declarations.append(f'array[0..{number_of_mix - 1}] of var 0..1: {output_id_link}_i;') + cp_declarations.append(f"array[0..{number_of_mix - 1}] of var 0..1: {output_id_link}_i;") input_len = len(all_inputs) if rot_amount == self.description[1]: - cp_constraints = [f'constraint {output_id_link}[{i}] = {all_inputs[(i - rot_amount) % input_len]};' - for i in range(output_size // word_size)] + cp_constraints = [ + f"constraint {output_id_link}[{i}] = {all_inputs[(i - rot_amount) % input_len]};" + for i in range(self.output_bit_size // word_size) + ] else: - cp_constraints = [f'constraint {output_id_link}[{i}] = {all_inputs[(i + rot_amount) % input_len]};' - for i in range(output_size // word_size)] + cp_constraints = [ + f"constraint {output_id_link}[{i}] = {all_inputs[(i + rot_amount) % input_len]};" + for i in range(self.output_bit_size // word_size) + ] return cp_declarations, cp_constraints @@ -290,29 +309,27 @@ def cp_xor_linear_mask_propagation_constraints(self, model=None): ... 'constraint rot_0_0_o[15]=rot_0_0_i[8];']) """ + cp_declarations = [ + f"array[0..{self.output_bit_size - 1}] of var 0..1: {self.id}_i;", + f"array[0..{self.output_bit_size - 1}] of var 0..1: {self.id}_o;", + ] output_size = int(self.output_bit_size) - output_id_link = self.id rot_amount = abs(self.description[1]) cp_constraints = [] - cp_declarations = [f'array[0..{output_size - 1}] of var 0..1: {output_id_link}_i;', - f'array[0..{output_size - 1}] of var 0..1: {output_id_link}_o;'] if rot_amount == self.description[1]: for i in range(output_size): - cp_constraints.append( - f'constraint {output_id_link}_o[{i}]={output_id_link}_i[{(i - rot_amount) % output_size}];') + cp_constraints.append(f"constraint {self.id}_o[{i}]={self.id}_i[{(i - rot_amount) % output_size}];") else: for i in range(output_size): - cp_constraints.append( - f'constraint {output_id_link}_o[{i}]={output_id_link}_i[{(i + rot_amount) % output_size}];') - result = cp_declarations, cp_constraints + cp_constraints.append(f"constraint {self.id}_o[{i}]={self.id}_i[{(i + rot_amount) % output_size}];") - return result + return cp_declarations, cp_constraints def get_bit_based_vectorized_python_code(self, params, convert_output_to_bytes): - return [f' {self.id} = bit_vector_ROTATE([{",".join(params)} ], {self.description[1]})'] + return [f" {self.id} = bit_vector_ROTATE([{','.join(params)} ], {self.description[1]})"] def get_byte_based_vectorized_python_code(self, params): - return [f' {self.id} = byte_vector_ROTATE({params}, {self.description[1]}, {self.input_bit_size})'] + return [f" {self.id} = byte_vector_ROTATE({params}, {self.description[1]}, {self.input_bit_size})"] def get_word_based_c_code(self, verbosity, word_size, wordstring_variables): rotate_code = [] @@ -321,8 +338,8 @@ def get_word_based_c_code(self, verbosity, word_size, wordstring_variables): wordstring_variables.append(self.id) direction = "RIGHT" if self.description[1] >= 0 else "LEFT" rotate_code.append( - f'\tWordString *{self.id} = ' - f'{direction}_{self.description[0]}(input, {abs(self.description[1])});') + f"\tWordString *{self.id} = {direction}_{self.description[0]}(input, {abs(self.description[1])});" + ) if verbosity: self.print_word_values(rotate_code) @@ -333,10 +350,10 @@ def get_word_operation_sign(self, sign, solution): output_id_link = self.id component_sign = 1 sign = sign * component_sign - solution['components_values'][f'{output_id_link}_o']['sign'] = component_sign - solution['components_values'][output_id_link] = solution['components_values'][f'{output_id_link}_o'] - del solution['components_values'][f'{output_id_link}_o'] - del solution['components_values'][f'{output_id_link}_i'] + solution["components_values"][f"{output_id_link}_o"]["sign"] = component_sign + solution["components_values"][output_id_link] = solution["components_values"][f"{output_id_link}_o"] + del solution["components_values"][f"{output_id_link}_o"] + del solution["components_values"][f"{output_id_link}_i"] return sign @@ -371,7 +388,6 @@ def milp_constraints(self, model): x_31 == x_8] """ x = model.binary_variable - output_bit_size = self.output_bit_size rotation_step = self.description[1] abs_rotation_step = abs(rotation_step) input_vars, output_vars = self._get_input_output_variables() @@ -384,8 +400,8 @@ def milp_constraints(self, model): elif rotation_step > 0: tmp = input_vars[-abs_rotation_step:] input_vars = tmp + input_vars[:-abs_rotation_step] - for i in range(output_bit_size): - constraints.append(x[output_vars[i]] == x[input_vars[i]]) + for output_var, input_var in zip(output_vars, input_vars): + constraints.append(x[output_var] == x[input_var]) return variables, constraints @@ -481,10 +497,8 @@ def milp_bitwise_deterministic_truncated_xor_differential_constraints(self, mode x_31 == x_8] """ - x_class = model.trunc_binvar - output_size = self.output_bit_size rotation_step = self.description[1] abs_rotation_step = abs(rotation_step) input_class_vars, output_class_vars = self._get_input_output_variables() @@ -497,8 +511,8 @@ def milp_bitwise_deterministic_truncated_xor_differential_constraints(self, mode elif rotation_step > 0: tmp = input_class_vars[-abs_rotation_step:] input_class_vars = tmp + input_class_vars[:-abs_rotation_step] - for i in range(output_size): - constraints.append(x_class[output_class_vars[i]] == x_class[input_class_vars[i]]) + for output_class_var, input_class_var in zip(output_class_vars, input_class_vars): + constraints.append(x_class[output_class_var] == x_class[input_class_var]) return class_variables, constraints @@ -536,7 +550,6 @@ def milp_xor_linear_mask_propagation_constraints(self, model): x_31 == x_8] """ x = model.binary_variable - output_bit_size = self.output_bit_size rotation_step = self.description[1] abs_rotation_step = abs(rotation_step) input_vars, output_vars = self._get_independent_input_output_variables() @@ -548,11 +561,10 @@ def milp_xor_linear_mask_propagation_constraints(self, model): elif rotation_step > 0: tmp = input_vars[-abs_rotation_step:] input_vars = tmp + input_vars[:-abs_rotation_step] - for i in range(output_bit_size): - constraints.append(x[output_vars[i]] == x[input_vars[i]]) - result = variables, constraints + for output_var, input_var in zip(output_vars, input_vars): + constraints.append(x[output_var] == x[input_var]) - return result + return variables, constraints def minizinc_constraints(self, model): r""" @@ -581,8 +593,8 @@ def minizinc_constraints(self, model): var_names = self._define_var(input_postfix, output_postfix, model.data_type) rotation_const = self.description[1] ninputs = noutputs = self.output_bit_size - input_vars = [self.id + "_" + input_postfix + str(i) for i in range(ninputs)] - output_vars = [self.id + "_" + output_postfix + str(i) for i in range(noutputs)] + input_vars = [f"{self.id}_{input_postfix}{i}" for i in range(ninputs)] + output_vars = [f"{self.id}_{output_postfix}{i}" for i in range(noutputs)] input_vars_1 = input_vars mzn_input_array_1 = self._create_minizinc_1d_array_from_list(input_vars_1) output_vars_1 = output_vars @@ -590,10 +602,12 @@ def minizinc_constraints(self, model): if rotation_const < 0: rotate_mzn_constraints = [ - f'constraint LRot({mzn_input_array_1}, {int(-1*rotation_const)})={mzn_output_array_1};\n'] + f"constraint LRot({mzn_input_array_1}, {int(-1 * rotation_const)})={mzn_output_array_1};\n" + ] else: rotate_mzn_constraints = [ - f'constraint RRot({mzn_input_array_1}, {int(rotation_const)})={mzn_output_array_1};\n'] + f"constraint RRot({mzn_input_array_1}, {int(rotation_const)})={mzn_output_array_1};\n" + ] return var_names, rotate_mzn_constraints @@ -636,7 +650,7 @@ def sat_constraints(self): 'rot_1_1_15 -key_40', 'key_40 -rot_1_1_15']) """ - _, input_bit_ids = self._generate_input_ids() + input_bit_ids = self._generate_input_ids() output_bit_len, output_bit_ids = self._generate_output_ids() rotation = self.description[1] input_bit_ids_rotated = input_bit_ids[-rotation:] + input_bit_ids[:-rotation] @@ -801,7 +815,7 @@ def smt_constraints(self): '(assert (= rot_0_0_14 plaintext_7))', '(assert (= rot_0_0_15 plaintext_8))']) """ - _, input_bit_ids = self._generate_input_ids() + input_bit_ids = self._generate_input_ids() output_bit_len, output_bit_ids = self._generate_output_ids() rotation = self.description[1] input_bit_ids_rotated = input_bit_ids[-rotation:] + input_bit_ids[:-rotation] diff --git a/claasp/components/sbox_component.py b/claasp/components/sbox_component.py index 20a634fce..b1ce9d563 100644 --- a/claasp/components/sbox_component.py +++ b/claasp/components/sbox_component.py @@ -1,22 +1,20 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** - import math import subprocess from itertools import product, combinations @@ -26,10 +24,10 @@ from sage.arith.misc import is_power_of_two from sage.crypto.sbox import SBox -from claasp.cipher_modules.models.milp.utils.generate_undisturbed_bits_inequalities_for_sboxes import \ - update_dictionary_that_contains_inequalities_for_sboxes_with_undisturbed_bits, \ - get_dictionary_that_contains_inequalities_for_sboxes_with_undisturbed_bits, \ - delete_dictionary_that_contains_inequalities_for_sboxes_with_undisturbed_bits +from claasp.cipher_modules.models.milp.utils.generate_undisturbed_bits_inequalities_for_sboxes import ( + update_dictionary_that_contains_inequalities_for_sboxes_with_undisturbed_bits, + get_dictionary_that_contains_inequalities_for_sboxes_with_undisturbed_bits, +) from claasp.cipher_modules.models.milp.utils.milp_name_mappings import MILP_DEFAULT_WEIGHT_PRECISION from claasp.cipher_modules.models.milp.utils.utils import espresso_pos_to_constraints from claasp.input import Input @@ -39,22 +37,27 @@ from claasp.cipher_modules.models.milp.utils import utils as milp_utils from claasp.cipher_modules.models.milp.utils.generate_inequalities_for_large_sboxes import ( update_dictionary_that_contains_inequalities_for_large_sboxes, - get_dictionary_that_contains_inequalities_for_large_sboxes) + get_dictionary_that_contains_inequalities_for_large_sboxes, +) from claasp.cipher_modules.models.milp.utils.generate_sbox_inequalities_for_trail_search import ( update_dictionary_that_contains_inequalities_for_small_sboxes, - get_dictionary_that_contains_inequalities_for_small_sboxes) + get_dictionary_that_contains_inequalities_for_small_sboxes, +) def check_table_feasibility(table, table_type, solver): occurrences = set(abs(value) for row in table.rows() for value in set(row)) - {0} for occurrence in occurrences: if not is_power_of_two(occurrence): - raise ValueError(f'The S-box {table_type} of the cipher contains {occurrence} ' - f'which is not a power of two. Currently, {solver} cannot handle it.') + raise ValueError( + f"The S-box {table_type} of the cipher contains {occurrence} " + f"which is not a power of two. Currently, {solver} cannot handle it." + ) -def cp_update_ddt_valid_probabilities(cipher, component, word_size, cp_declarations, - table_items, valid_probabilities, sbox_mant): +def cp_update_ddt_valid_probabilities( + cipher, component, word_size, cp_declarations, table_items, valid_probabilities, sbox_mant +): input_size = int(component.input_bit_size) output_id_link = component.id description = component.description @@ -68,21 +71,20 @@ def cp_update_ddt_valid_probabilities(cipher, component, word_size, cp_declarati for i in range(sbox_ddt.nrows()): set_of_occurrences = set(sbox_ddt.rows()[i]) set_of_occurrences -= {0} - valid_probabilities.update({round(100 * math.log2(2 ** input_size / occurrence)) - for occurrence in set_of_occurrences}) + valid_probabilities.update( + {round(100 * math.log2(2**input_size / occurrence)) for occurrence in set_of_occurrences} + ) sbox_mant.append((description, output_id_link)) if cipher.is_spn(): input_id_link = component.input_id_links[0] input_bit_positions = component.input_bit_positions[0] - all_inputs = [f'{input_id_link}[{position}]' for position in input_bit_positions] + all_inputs = [f"{input_id_link}[{position}]" for position in input_bit_positions] for i in range(input_size // word_size): - ineq_left_side = '+'.join([f'{all_inputs[i * word_size + j]}' - for j in range(word_size)]) - new_declaration = f'constraint ({ineq_left_side} > 0) = word_{output_id_link}[{i}];' + ineq_left_side = "+".join([f"{all_inputs[i * word_size + j]}" for j in range(word_size)]) + new_declaration = f"constraint ({ineq_left_side} > 0) = word_{output_id_link}[{i}];" cp_declarations.append(new_declaration) - cp_declarations.append( - f'array[0..{input_size // word_size - 1}] of var 0..1: word_{output_id_link};') - table_items.append(f'[word_{output_id_link}[s] | s in 0..{input_size // word_size - 1}]') + cp_declarations.append(f"array[0..{input_size // word_size - 1}] of var 0..1: word_{output_id_link};") + table_items.append(f"[word_{output_id_link}[s] | s in 0..{input_size // word_size - 1}]") def cp_update_lat_valid_probabilities(component, valid_probabilities, sbox_mant): @@ -99,20 +101,33 @@ def cp_update_lat_valid_probabilities(component, valid_probabilities, sbox_mant) for i in range(sbox_lat.nrows()): set_of_occurrences = set(sbox_lat.rows()[i]) set_of_occurrences -= {0} - valid_probabilities.update({round(100 * math.log2(abs(pow(2, input_size - 1) / occurence))) for occurence in set_of_occurrences}) + valid_probabilities.update( + {round(100 * math.log2(abs(pow(2, input_size - 1) / occurence))) for occurence in set_of_occurrences} + ) sbox_mant.append((description, output_id_link)) -def milp_set_constraints_from_dictionnary_for_large_sbox(component_id, input_vars, - output_vars, sbox_input_size, sbox_output_size, x, p, - probability_dictionary, analysis, weight_precision): +def milp_set_constraints_from_dictionnary_for_large_sbox( + component_id, + input_vars, + output_vars, + sbox_input_size, + sbox_output_size, + x, + p, + probability_dictionary, + analysis, + weight_precision, +): constraints = [] # condition to know if sbox is active or not constraints.append( - sbox_input_size * x[f"{component_id}_active"] >= sum(x[input_vars[i]] for i in range(sbox_input_size))) + sbox_input_size * x[f"{component_id}_active"] >= sum(x[input_vars[i]] for i in range(sbox_input_size)) + ) constraints.append( - sbox_input_size * (1 - x[f"{component_id}_active"]) >= - -sum(x[input_vars[i]] for i in range(sbox_input_size)) + 1) + sbox_input_size * (1 - x[f"{component_id}_active"]) + >= -sum(x[input_vars[i]] for i in range(sbox_input_size)) + 1 + ) constraints += [x[f"{component_id}_active"] >= x[output_vars[i]] for i in range(sbox_output_size)] # mip.add_constraint(sum(x[output_vars[i]] for i in range(sbox.input_size())) >= x[id + "_active"]) @@ -121,28 +136,31 @@ def milp_set_constraints_from_dictionnary_for_large_sbox(component_id, input_var else: exponent = sbox_input_size - 1 - M = (10 ** weight_precision) * sbox_input_size + M = (10**weight_precision) * sbox_input_size constraint_choice_proba = 0 constraint_compute_proba = 0 for proba in probability_dictionary.keys(): for ineq in probability_dictionary[proba]: - constraint = milp_large_xor_probability_constraint_for_inequality(M, component_id, ineq, input_vars, - output_vars, proba, sbox_input_size, - sbox_output_size, x) + constraint = milp_large_xor_probability_constraint_for_inequality( + M, component_id, ineq, input_vars, output_vars, proba, sbox_input_size, sbox_output_size, x + ) constraints.append(constraint >= 0) constraint_choice_proba += x[f"{component_id}_sboxproba_{proba}"] - constraint_compute_proba += (x[f"{component_id}_sboxproba_{proba}"] * - (10 ** weight_precision) * round(-log(abs(proba) / (2 ** exponent), 2), - weight_precision)) + constraint_compute_proba += ( + x[f"{component_id}_sboxproba_{proba}"] + * (10**weight_precision) + * round(-log(abs(proba) / (2**exponent), 2), weight_precision) + ) constraints.append(constraint_choice_proba == x[f"{component_id}_active"]) constraints.append(p[f"{component_id}_probability"] == constraint_compute_proba) return constraints -def milp_large_xor_probability_constraint_for_inequality(M, component_id, ineq, input_vars, - output_vars, proba, sbox_input_size, sbox_output_size, x): +def milp_large_xor_probability_constraint_for_inequality( + M, component_id, ineq, input_vars, output_vars, proba, sbox_input_size, sbox_output_size, x +): constraint = 0 for i in range(sbox_input_size - 1, -1, -1): char = ineq[i] @@ -165,28 +183,27 @@ def milp_large_xor_probability_constraint_for_inequality(M, component_id, ineq, def sat_build_table_template(table, get_hamming_weight_function, input_bit_len, output_bit_len): # create espresso input input_length = input_bit_len + 2 * output_bit_len - espresso_input = [f'.i {input_length}', '.o 1'] + espresso_input = [f".i {input_length}", ".o 1"] for i in range(table.nrows()): for j in range(table.ncols()): if table[i, j] != 0: - input_diff = f'{i:0{input_bit_len}b}' - output_diff = f'{j:0{output_bit_len}b}' + input_diff = f"{i:0{input_bit_len}b}" + output_diff = f"{j:0{output_bit_len}b}" hamming_weight = get_hamming_weight_function(input_bit_len, table[i, j]) - weight_vec = '0' * (output_bit_len - hamming_weight) - weight_vec += '1' * hamming_weight - espresso_input.append(f'{input_diff}{output_diff}{weight_vec} 1') - espresso_input.append('.e') - espresso_input = '\n'.join(espresso_input) + '\n' + weight_vec = "0" * (output_bit_len - hamming_weight) + weight_vec += "1" * hamming_weight + espresso_input.append(f"{input_diff}{output_diff}{weight_vec} 1") + espresso_input.append(".e") + espresso_input = "\n".join(espresso_input) + "\n" # execute espresso process - espresso_process = subprocess.run(['espresso', '-epos'], input=espresso_input, - capture_output=True, text=True) + espresso_process = subprocess.run(["espresso", "-epos"], input=espresso_input, capture_output=True, text=True) espresso_output = espresso_process.stdout.splitlines() # formatting template template = [] for line in espresso_output[4:-1]: - clause = tuple((int(line[i]), i) for i in range(input_length) if line[i] != '-') + clause = tuple((int(line[i]), i) for i in range(input_length) if line[i] != "-") template.append(clause) return template @@ -210,10 +227,6 @@ def smt_get_sbox_probability_constraints(bit_ids, template): return constraints -def _to_int(bits): - return int("".join(map(str, bits)), 2) - - def _combine_truncated(input_1, input_2): return [bit_1 if bit_1 == bit_2 else 2 for bit_1, bit_2 in zip(input_1, input_2)] @@ -229,17 +242,20 @@ def _get_truncated_output_difference(ddt_row, n): output_bits[n - 1 - bit] = 1 if delta[0] else 0 return has_undisturbed_bits, output_bits -def _mzn_update_sbox_mant_for_deterministic_truncated_xor_differential(inv_output_id_link, undisturbed_bits, sbox_mant, inverse): + +def _mzn_update_sbox_mant_for_deterministic_truncated_xor_differential( + inv_output_id_link, undisturbed_bits, sbox_mant, inverse +): undisturbed_bits_ddt = [] for pair in undisturbed_bits: undisturbed_bits_ddt += list(pair[0]) + list(pair[1]) for i in range(len(undisturbed_bits_ddt)): undisturbed_bits_ddt[i] = str(undisturbed_bits_ddt[i]) - undisturbed_table_bits = ','.join(undisturbed_bits_ddt) + undisturbed_table_bits = ",".join(undisturbed_bits_ddt) already_in = False output_id_link_sost = inv_output_id_link for mant in sbox_mant: - if undisturbed_table_bits == mant[0] and ((not inverse) or (inverse and 'inverse' in mant[1])): + if undisturbed_table_bits == mant[0] and ((not inverse) or (inverse and "inverse" in mant[1])): already_in = True output_id_link_sost = mant[1] if not already_in: @@ -247,13 +263,22 @@ def _mzn_update_sbox_mant_for_deterministic_truncated_xor_differential(inv_outpu return already_in, output_id_link_sost, undisturbed_table_bits + class SBOX(Component): sboxes_ddt_templates = {} sboxes_lat_templates = {} - def __init__(self, current_round_number, current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size, s_box_description): - component_id = f'sbox_{current_round_number}_{current_round_number_of_components}' - component_type = 'sbox' + + def __init__( + self, + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + s_box_description, + ): + component_id = f"sbox_{current_round_number}_{current_round_number_of_components}" + component_type = "sbox" input_len = sum(map(len, input_bit_positions)) description = s_box_description component_input = Input(input_len, input_id_links, input_bit_positions) @@ -330,7 +355,7 @@ def get_ddt_with_undisturbed_transitions(self): all_combinations_of_inputs_with_undisturbed_bits = {} for input_bits in all_fixed_inputs: - delta_in = _to_int(input_bits) + delta_in = int("".join(map(str, input_bits)), 2) has_undisturbed_bits, output_bits = _get_truncated_output_difference(ddt[delta_in], n) if has_undisturbed_bits: fixed_inputs_with_undisturbed_bits.append(input_bits) @@ -340,7 +365,7 @@ def get_ddt_with_undisturbed_transitions(self): tested_inputs = all_fixed_inputs[:] inputs_to_combine = fixed_inputs_with_undisturbed_bits[:] - while (len(inputs_to_combine) != 0): + while len(inputs_to_combine) != 0: newly_combined_inputs = [] for input_1, input_2 in combinations(inputs_to_combine, 2): truncated_positions = list(map(xor, input_1, input_2)) @@ -359,7 +384,7 @@ def get_ddt_with_undisturbed_transitions(self): inputs_to_combine = newly_combined_inputs for input_bits in set(list(product([0, 1, 2], repeat=n))).difference(set(tested_inputs)): - valid_points.append((input_bits, (2,)*n)) + valid_points.append((input_bits, (2,) * n)) return valid_points @@ -414,39 +439,37 @@ def cp_constraints(self, sbox_mant, second=False): (['array [1..16, 1..8] of int: table_sbox_0_5 = array2d(1..16, 1..8, [0,0,0,0,1,1,0,0,0,0,0,1,1,0,1,0,0,0,1,0,1,1,0,1,0,0,1,1,0,0,1,1,0,1,0,0,1,1,1,0,0,1,0,1,1,0,1,1,0,1,1,0,1,1,1,1,0,1,1,1,0,1,1,1,1,0,0,0,1,0,0,0,1,0,0,1,1,0,0,1,1,0,1,0,0,0,0,1,1,0,1,1,0,1,0,1,1,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,1,1,1,0,0,1,0,0,1,1,1,1,0,1,1,0]);'], ['constraint table([xor_0_1[4]]++[xor_0_1[5]]++[xor_0_1[6]]++[xor_0_1[7]]++[sbox_0_5[0]]++[sbox_0_5[1]]++[sbox_0_5[2]]++[sbox_0_5[3]], table_sbox_0_5);']) """ - input_size = int(self.input_bit_size) - output_size = int(self.output_bit_size) - input_id_links = self.input_id_links - output_id_link = self.id - input_bit_positions = self.input_bit_positions sbox = self.description if second: - sec_output_id_link = 'second_' + self.id + sec_output_id_link = f"second_{self.id}" else: sec_output_id_link = self.id already_in = False output_id_link_sost = sec_output_id_link for mant in sbox_mant: - if sbox == mant[0] and ((not second) or (second and 'second' in mant[1])): + if sbox == mant[0] and ((not second) or (second and "second" in mant[1])): already_in = True output_id_link_sost = mant[1] cp_declarations = [] + input_size = self.input_bit_size + output_size = self.output_bit_size if not already_in: - bin_i = (','.join(f'{i:0{input_size}b}') for i in range(2 ** input_size)) - bin_sbox = (','.join(f'{sbox[i]:0{output_size}b}') for i in range(2 ** input_size)) - table_values = ','.join([f'{i},{s}' for i, s in zip(bin_i, bin_sbox)]) - sbox_declaration = f'array [1..{len(sbox)}, 1..{input_size + output_size}] of int: ' \ - f'table_{output_id_link_sost} = array2d(1..{len(sbox)}, 1..{input_size + output_size}, ' \ - f'[{table_values}]);' + bin_i = (",".join(f"{i:0{input_size}b}") for i in range(2**input_size)) + bin_sbox = (",".join(f"{sbox[i]:0{output_size}b}") for i in range(2**input_size)) + table_values = ",".join([f"{i},{s}" for i, s in zip(bin_i, bin_sbox)]) + sbox_declaration = ( + f"array [1..{len(sbox)}, 1..{input_size + output_size}] of int: " + f"table_{output_id_link_sost} = array2d(1..{len(sbox)}, 1..{input_size + output_size}, " + f"[{table_values}]);" + ) cp_declarations.append(sbox_declaration) - sbox_mant.append((sbox, output_id_link)) + sbox_mant.append((sbox, self.id)) all_inputs = [] - for id_link, bit_positions in zip(input_id_links, input_bit_positions): - all_inputs.extend([f'[{id_link}[{position}]]' for position in bit_positions]) - table_input = '++'.join(all_inputs) - table_output = '++'.join([f'[{output_id_link}[{i}]]' for i in range(output_size)]) - new_constraint = f'constraint table({table_input}++{table_output}, table_{output_id_link_sost});' - cp_constraints = [new_constraint] + for id_link, bit_positions in zip(self.input_id_links, self.input_bit_positions): + all_inputs.extend([f"[{id_link}[{position}]]" for position in bit_positions]) + table_input = "++".join(all_inputs) + table_output = "++".join([f"[{self.id}[{i}]]" for i in range(output_size)]) + cp_constraints = [f"constraint table({table_input}++{table_output}, table_{output_id_link_sost});"] return cp_declarations, cp_constraints @@ -468,35 +491,37 @@ def cp_deterministic_truncated_xor_differential_constraints(self, sbox_mant, inv ['constraint table([xor_0_0[0]]++[xor_0_0[1]]++[xor_0_0[2]]++[xor_0_0[3]]++[xor_0_0[4]]++[xor_0_0[5]]++[xor_0_0[6]]++[xor_0_0[7]]++[sbox_0_1[0]]++[sbox_0_1[1]]++[sbox_0_1[2]]++[sbox_0_1[3]]++[sbox_0_1[4]]++[sbox_0_1[5]]++[sbox_0_1[6]]++[sbox_0_1[7]], table_sbox_0_1);'] """ - input_id_links = self.input_id_links output_id_link = self.id if inverse: - inv_output_id_link = 'inverse_' + self.id + inv_output_id_link = f"inverse_{self.id}" else: inv_output_id_link = self.id - output_size = self.output_bit_size - input_bit_positions = self.input_bit_positions - cp_declarations = [] - cp_constraints = [] all_inputs = [] eventual_undisturbed_bits = self.get_ddt_with_undisturbed_transitions() num_pairs = len(eventual_undisturbed_bits) len_input = len(eventual_undisturbed_bits[0][0]) len_output = len(eventual_undisturbed_bits[0][1]) - for id_link, bit_positions in zip(input_id_links, input_bit_positions): - all_inputs.extend([f'[{id_link}[{position}]]' for position in bit_positions]) - table_input = '++'.join(all_inputs) - table_output = '++'.join([f'[{output_id_link}[{i}]]' for i in range(output_size)]) - - already_in, output_id_link_sost, undisturbed_table_bits = _mzn_update_sbox_mant_for_deterministic_truncated_xor_differential( - inv_output_id_link, eventual_undisturbed_bits, sbox_mant, inverse) + for id_link, bit_positions in zip(self.input_id_links, self.input_bit_positions): + all_inputs.extend([f"[{id_link}[{position}]]" for position in bit_positions]) + table_input = "++".join(all_inputs) + table_output = "++".join([f"[{output_id_link}[{i}]]" for i in range(self.output_bit_size)]) + + already_in, output_id_link_sost, undisturbed_table_bits = ( + _mzn_update_sbox_mant_for_deterministic_truncated_xor_differential( + inv_output_id_link, eventual_undisturbed_bits, sbox_mant, inverse + ) + ) + cp_declarations = [] + cp_constraints = [] if not already_in: - undisturbed_declaration = f'array [1..{num_pairs}, 1..{len_input + len_output}] of int: ' \ - f'table_{output_id_link_sost} = array2d(1..{num_pairs}, 1..{len_input + len_output}, ' \ - f'[{undisturbed_table_bits}]);' + undisturbed_declaration = ( + f"array [1..{num_pairs}, 1..{len_input + len_output}] of int: " + f"table_{output_id_link_sost} = array2d(1..{num_pairs}, 1..{len_input + len_output}, " + f"[{undisturbed_table_bits}]);" + ) cp_declarations.append(undisturbed_declaration) - new_constraint = f'constraint table({table_input}++{table_output}, table_{output_id_link_sost});' + new_constraint = f"constraint table({table_input}++{table_output}, table_{output_id_link_sost});" cp_constraints.append(new_constraint) return cp_declarations, cp_constraints, sbox_mant @@ -504,7 +529,9 @@ def cp_deterministic_truncated_xor_differential_constraints(self, sbox_mant, inv def cp_deterministic_truncated_xor_differential_trail_constraints(self, sbox_mant, inverse=False): return self.cp_deterministic_truncated_xor_differential_constraints(sbox_mant, inverse) - def cp_hybrid_deterministic_truncated_xor_differential_constraints(self, sbox_mant, inverse=False, list_of_component_number=[]): + def cp_hybrid_deterministic_truncated_xor_differential_constraints( + self, sbox_mant, inverse=False, list_of_component_number=[] + ): """ Return lists of declarations and constraints for SBOX component for CP hybrid deterministic truncated xor differential. @@ -523,25 +550,26 @@ def cp_hybrid_deterministic_truncated_xor_differential_constraints(self, sbox_ma sage: constraints ['constraint abstract_sbox_0_2(array1d(0..3, [xor_0_1[4]]++[xor_0_1[5]]++[xor_0_1[6]]++[xor_0_1[7]]), array1d(0..3, [sbox_0_2[0]]++[sbox_0_2[1]]++[sbox_0_2[2]]++[sbox_0_2[3]]), 0, 0);'] """ + def _get_abstracted_predicate(): if list_of_component_number != []: max_number_of_sboxes = max(len(lst) for lst in list_of_component_number.values()) - sbox_id = f"{max_number_of_sboxes*10}*r + 10*(index+1)" + sbox_id = f"{max_number_of_sboxes * 10}*r + 10*(index+1)" else: - sbox_id = f"100*r + 10*(index+1)" + sbox_id = "100*r + 10*(index+1)" sbox_declaration = f"predicate abstract_{output_id_link_sost}(array[int] of var int: x, array[int] of var int: y, int: r, int: index) =\n" - condition1 = r" /\ ".join([f"x{[j]} == 0" for j in range(len(all_inputs))]) - outcome1 = r" /\ ".join([f"y{[j]} = 0" for j in range(output_size)]) - condition2 = r" \/ ".join([f"x{[j]} == 1" for j in range(len(all_inputs))]) - outcome2_1 = r" /\ ".join([f"y{[j]} = {sbox_id}" for j in range(output_size)]) + condition1 = " /\\ ".join(f"x[{j}] == 0" for j in range(len(all_inputs))) + outcome1 = " /\\ ".join([f"y[{j}] = 0" for j in range(output_size)]) + condition2 = " \\/ ".join([f"x[{j}] == 1" for j in range(len(all_inputs))]) + outcome2_1 = " /\\ ".join([f"y[{j}] = {sbox_id}" for j in range(output_size)]) outcome2_2 = f"apply_{output_id_link_sost}(x, y)" outcome2 = f"({outcome2_1}) \\/ ({outcome2_2})" - condition3_1 = r" /\ ".join([f"x{[j]} > 2" for j in range(len(all_inputs))]) - condition3_2 = r" /\ ".join([f"x{[j]} = x{[0]}" for j in range(1, len(all_inputs))]) + condition3_1 = " /\\ ".join([f"x[{j}] > 2" for j in range(len(all_inputs))]) + condition3_2 = " /\\ ".join([f"x[{j}] = x[0]" for j in range(1, len(all_inputs))]) condition3 = f"({condition3_1}) /\\ ({condition3_2})" - default = r" /\ ".join([f"y{[j]} = 2" for j in range(len(all_inputs))]) + default = " /\\ ".join([f"y[{j}] = 2" for j in range(len(all_inputs))]) sbox_declaration += ( f"\tif ({condition1}) then ({outcome1})\n" f"\telseif ({condition2}) then ({outcome2})\n" @@ -550,42 +578,42 @@ def _get_abstracted_predicate(): ) return sbox_declaration + def _get_sbox_predicate(): - sbox_declaration = f"predicate apply_{output_id_link_sost}(array[int] of var int: x, array[int] of var int: y) =\n" + sbox_declaration = ( + f"predicate apply_{output_id_link_sost}(array[int] of var int: x, array[int] of var int: y) =\n" + ) for i, (inputs, outputs) in enumerate(undisturbed_bits): - condition = r" /\ ".join([f"x{[j]} == {inputs[j]}" for j in range(len(all_inputs))]) - outcome = r" /\ ".join([f"y{[j]} = {outputs[j]}" for j in range(output_size)]) + condition = " /\\ ".join([f"x[{j}] == {inputs[j]}" for j in range(len(all_inputs))]) + outcome = " /\\ ".join([f"y[{j}] = {outputs[j]}" for j in range(output_size)]) if i == 0: sbox_declaration += f"\tif ({condition}) then ({outcome})\n" else: sbox_declaration += f"\telseif ({condition}) then ({outcome})\n" - else_condition = r" /\ ".join([f"y{[j]} = 2" for j in range(output_size)]) + else_condition = " /\\ ".join([f"y{[j]} = 2" for j in range(output_size)]) sbox_declaration += f"\telse ({else_condition}) endif;" return sbox_declaration - input_id_links = self.input_id_links - output_id_link = self.id if inverse: - inv_output_id_link = 'inverse_' + self.id + inv_output_id_link = "inverse_" + self.id else: inv_output_id_link = self.id output_size = self.output_bit_size - input_bit_positions = self.input_bit_positions - cp_declarations = [] - cp_constraints = [] all_inputs = [] eventual_undisturbed_bits = self.get_ddt_with_undisturbed_transitions() - undisturbed_bits = [tr for tr in eventual_undisturbed_bits if tr[1] != (2,)*self.output_bit_size] - for id_link, bit_positions in zip(input_id_links, input_bit_positions): - all_inputs.extend([f'[{id_link}[{position}]]' for position in bit_positions]) - table_input = '++'.join(all_inputs) - table_output = '++'.join([f'[{output_id_link}[{i}]]' for i in range(output_size)]) + undisturbed_bits = [tr for tr in eventual_undisturbed_bits if tr[1] != (2,) * self.output_bit_size] + for id_link, bit_positions in zip(self.input_id_links, self.input_bit_positions): + all_inputs.extend([f"[{id_link}[{position}]]" for position in bit_positions]) + table_input = "++".join(all_inputs) + table_output = "++".join([f"[{self.id}[{i}]]" for i in range(output_size)]) - already_in, output_id_link_sost, undisturbed_table_bits = _mzn_update_sbox_mant_for_deterministic_truncated_xor_differential( - inv_output_id_link, undisturbed_bits, sbox_mant, inverse) + already_in, output_id_link_sost, _ = _mzn_update_sbox_mant_for_deterministic_truncated_xor_differential( + inv_output_id_link, undisturbed_bits, sbox_mant, inverse + ) + cp_declarations = [] if not already_in: sbox_declaration = _get_sbox_predicate() abstract_sbox_declaration = _get_abstracted_predicate() @@ -594,8 +622,8 @@ def _get_sbox_predicate(): index_of_id = 0 if list_of_component_number != []: index_of_id = list_of_component_number[round].index(id_in_round) - new_constraint = f'constraint abstract_{output_id_link_sost}(array1d(0..{len(all_inputs)-1}, {table_input}), array1d(0..{output_size-1}, {table_output}), {round}, {index_of_id});' - cp_constraints.append(new_constraint) + cp_constraint = f"constraint abstract_{output_id_link_sost}(array1d(0..{len(all_inputs) - 1}, {table_input}), array1d(0..{output_size - 1}, {table_output}), {round}, {index_of_id});" + cp_constraints = [cp_constraint] return cp_declarations, cp_constraints, sbox_mant @@ -618,20 +646,21 @@ def cp_wordwise_deterministic_truncated_xor_differential_constraints(self, model ([], ['constraint if xor_0_0_value[0]==0 then sbox_0_1_active[0] = 0 else sbox_0_1_active[0] = 2 endif;']) """ - input_id_links = self.input_id_links - output_id_link = self.id - input_bit_positions = self.input_bit_positions - cp_constraints = [] cp_declarations = [] all_inputs = [] word_size = model.word_size - for id_link, bit_positions in zip(input_id_links, input_bit_positions): - all_inputs.extend([f'{id_link}_value[{bit_positions[j * word_size] // word_size}]' - for j in range(len(bit_positions) // word_size)]) + for id_link, bit_positions in zip(self.input_id_links, self.input_bit_positions): + all_inputs.extend( + [ + f"{id_link}_value[{bit_positions[j * word_size] // word_size}]" + for j in range(len(bit_positions) // word_size) + ] + ) + cp_constraints = [] for i, input_ in enumerate(all_inputs): cp_constraints.append( - f'constraint if {input_}==0 then {output_id_link}_active[{i}] = 0' - f' else {output_id_link}_active[{i}] = 2 endif;') + f"constraint if {input_}==0 then {self.id}_active[{i}] = 0 else {self.id}_active[{i}] = 2 endif;" + ) return cp_declarations, cp_constraints @@ -654,22 +683,17 @@ def cp_xor_differential_first_step_constraints(self, model): (['array[0..0] of var 0..1: sbox_0_1;'], ['constraint sbox_0_1[0] = xor_0_0[0];']) """ - input_size = int(self.input_bit_size) - output_size = int(self.output_bit_size) - input_id_links = self.input_id_links - output_id_link = self.id - input_bit_positions = self.input_bit_positions all_inputs = [] - cp_constraints = [] word_size = model.word_size - for id_link, bit_positions in zip(input_id_links, input_bit_positions): + for id_link, bit_positions in zip(self.input_id_links, self.input_bit_positions): for j in range(len(bit_positions) // word_size): - all_inputs.append(f'{id_link}[{bit_positions[j * word_size] // word_size}]') - model.input_sbox.append((f'{id_link}[{bit_positions[j * word_size] // word_size}]', - input_size // word_size - 1)) - model.table_of_solutions_length += input_size // word_size - cp_declarations = [f'array[0..{(output_size - 1) // word_size}] of var 0..1: {output_id_link};'] - cp_constraints.extend([f'constraint {output_id_link}[{i}] = {input_};' for i, input_ in enumerate(all_inputs)]) + all_inputs.append(f"{id_link}[{bit_positions[j * word_size] // word_size}]") + model.input_sbox.append( + (f"{id_link}[{bit_positions[j * word_size] // word_size}]", self.input_bit_size // word_size - 1) + ) + model.table_of_solutions_length += self.input_bit_size // word_size + cp_declarations = [f"array[0..{(self.output_bit_size - 1) // word_size}] of var 0..1: {self.id};"] + cp_constraints = [f"constraint {self.id}[{i}] = {input_};" for i, input_ in enumerate(all_inputs)] return cp_declarations, cp_constraints @@ -697,21 +721,18 @@ def cp_xor_differential_propagation_constraints(self, model, inverse=False): """ input_size = int(self.input_bit_size) output_size = int(self.output_bit_size) - input_id_links = self.input_id_links - output_id_link = self.id - input_bit_positions = self.input_bit_positions description = self.description sbox = SBox(description) - cp_declarations = [] already_in = False if inverse: - output_id_link_sost = 'inverse_' + output_id_link + output_id_link_sost = f"inverse_{self.id}" else: - output_id_link_sost = output_id_link + output_id_link_sost = self.id for mant in model.sbox_mant: - if description == mant[0] and ((not inverse) or (inverse and 'inverse' in mant[1])): + if description == mant[0] and ((not inverse) or (inverse and "inverse" in mant[1])): already_in = True output_id_link_sost = mant[1] + cp_declarations = [] if not already_in: sbox_ddt = sbox.difference_distribution_table() dim_ddt = len([i for i in sbox_ddt.list() if i]) @@ -719,24 +740,26 @@ def cp_xor_differential_propagation_constraints(self, model, inverse=False): for i in range(sbox_ddt.nrows()): for j in range(sbox_ddt.ncols()): if sbox_ddt[i][j]: - sep_bin_i = ','.join(f'{i:0{input_size}b}') - sep_bin_j = ','.join(f'{j:0{output_size}b}') - log_of_prob = round(100 * math.log2((2 ** input_size) / sbox_ddt[i][j])) - ddt_entries.append(f'{sep_bin_i},{sep_bin_j},{log_of_prob}') - ddt_values = ','.join(ddt_entries) - sbox_declaration = f'array [1..{dim_ddt}, 1..{input_size + output_size + 1}] of int: ' \ - f'DDT_{output_id_link_sost} = array2d(1..{dim_ddt}, 1..{input_size + output_size + 1}, ' \ - f'[{ddt_values}]);' + sep_bin_i = ",".join(f"{i:0{input_size}b}") + sep_bin_j = ",".join(f"{j:0{output_size}b}") + log_of_prob = round(100 * math.log2((2**input_size) / sbox_ddt[i][j])) + ddt_entries.append(f"{sep_bin_i},{sep_bin_j},{log_of_prob}") + ddt_values = ",".join(ddt_entries) + sbox_declaration = ( + f"array [1..{dim_ddt}, 1..{input_size + output_size + 1}] of int: " + f"DDT_{output_id_link_sost} = array2d(1..{dim_ddt}, 1..{input_size + output_size + 1}, " + f"[{ddt_values}]);" + ) cp_declarations.append(sbox_declaration) - model.sbox_mant.append((description, output_id_link)) + model.sbox_mant.append((description, self.id)) all_inputs = [] - for id_link, bit_positions in zip(input_id_links, input_bit_positions): - all_inputs.extend([f'[{id_link}[{position}]]' for position in bit_positions]) - table_input = '++'.join(all_inputs) - table_output = '++'.join([f'[{output_id_link}[{i}]]' for i in range(output_size)]) - constraint = f'constraint table({table_input}++{table_output}++[p[{model.c}]], DDT_{output_id_link_sost});' + for id_link, bit_positions in zip(self.input_id_links, self.input_bit_positions): + all_inputs.extend([f"[{id_link}[{position}]]" for position in bit_positions]) + table_input = "++".join(all_inputs) + table_output = "++".join([f"[{self.id}[{i}]]" for i in range(output_size)]) + constraint = f"constraint table({table_input}++{table_output}++[p[{model.c}]], DDT_{output_id_link_sost});" cp_constraints = [constraint] - model.component_and_probability[output_id_link] = model.c + model.component_and_probability[self.id] = model.c model.c += 1 return cp_declarations, cp_constraints @@ -776,31 +799,34 @@ def cp_xor_linear_mask_propagation_constraints(self, model): if already_in == 0: size = 0 sbox_lat = sbox.linear_approximation_table() - sbox_declaration = '[' + sbox_declaration = "[" for i in range(sbox_lat.nrows()): for j in range(sbox_lat.ncols()): - sep_bin_i = ','.join(f'{i:0{input_size}b}') - sep_bin_j = ','.join(f'{j:0{output_size}b}') + sep_bin_i = ",".join(f"{i:0{input_size}b}") + sep_bin_j = ",".join(f"{j:0{output_size}b}") if sbox_lat[i, j] != 0: size += 1 bias = round(100 * math.log2(abs(pow(2, input_size - 1) / sbox_lat[i, j]))) - sbox_declaration = sbox_declaration + f'{sep_bin_i},{sep_bin_j},{bias},' - pre_declaration = f'array [1..{size},1..{input_size + output_size + 1}] of int: ' \ - f'LAT_{output_id_link}=array2d(1..{size},1..{input_size + output_size + 1},' - sbox_declaration = pre_declaration + sbox_declaration[:-1] + ']);' + sbox_declaration = sbox_declaration + f"{sep_bin_i},{sep_bin_j},{bias}," + pre_declaration = ( + f"array [1..{size},1..{input_size + output_size + 1}] of int: " + f"LAT_{output_id_link}=array2d(1..{size},1..{input_size + output_size + 1}," + ) + sbox_declaration = pre_declaration + sbox_declaration[:-1] + "]);" cp_declarations.append(sbox_declaration) sbox_mant.append((description, output_id_link)) - cp_declarations.append(f'array[0..{input_size - 1}] of var 0..1: {output_id_link}_i;') - cp_declarations.append(f'array[0..{output_size - 1}] of var 0..1: {output_id_link}_o;') - new_constraint = 'constraint table(' + cp_declarations.append(f"array[0..{input_size - 1}] of var 0..1: {output_id_link}_i;") + cp_declarations.append(f"array[0..{output_size - 1}] of var 0..1: {output_id_link}_o;") + new_constraint = "constraint table(" for i in range(input_size): - new_constraint = new_constraint + f'[{output_id_link}_i[{i}]]++' + new_constraint = new_constraint + f"[{output_id_link}_i[{i}]]++" for i in range(output_size): - new_constraint = new_constraint + f'[{output_id_link}_o[{i}]]++' - new_constraint = new_constraint + f'[p[{model.c}]],LAT_{output_id_link_sost});' + new_constraint = new_constraint + f"[{output_id_link}_o[{i}]]++" + new_constraint = new_constraint + f"[p[{model.c}]],LAT_{output_id_link_sost});" cp_constraints.append(new_constraint) model.component_and_probability[output_id_link] = model.c model.c = model.c + 1 + return cp_declarations, cp_constraints def generate_sbox_sign_lat(self): @@ -821,12 +847,8 @@ def get_bit_based_c_code(self, verbosity): sbox_code = [] self.select_bits(sbox_code) - sbox_code.append( - f'\tsubstitution_list = ' - f'(uint64_t[]) {{{", ".join([str(x) for x in self.description])}}};') - sbox_code.append( - f'\tBitString* {self.id} = ' - f'SBOX(input, {self.output_bit_size}, substitution_list);\n') + sbox_code.append(f"\tsubstitution_list = (uint64_t[]) {{{', '.join([str(x) for x in self.description])}}};") + sbox_code.append(f"\tBitString* {self.id} = SBOX(input, {self.output_bit_size}, substitution_list);\n") if verbosity: self.print_values(sbox_code) @@ -836,20 +858,25 @@ def get_bit_based_c_code(self, verbosity): return sbox_code def get_bit_based_vectorized_python_code(self, params, convert_output_to_bytes): - sbox_params = [f'bit_vector_select_word({self.input_id_links[i]}, {self.input_bit_positions[i]})' - for i in range(len(self.input_id_links))] - return [f' {self.id} = bit_vector_SBOX(bit_vector_CONCAT([{",".join(sbox_params)} ]), ' - f'np.array({self.description}, dtype=np.uint8), output_bit_size = {self.output_bit_size})'] + sbox_params = [ + f"bit_vector_select_word({self.input_id_links[i]}, {self.input_bit_positions[i]})" + for i in range(len(self.input_id_links)) + ] + return [ + f" {self.id} = bit_vector_SBOX(bit_vector_CONCAT([{','.join(sbox_params)} ]), " + f"np.array({self.description}, dtype=np.uint8), output_bit_size = {self.output_bit_size})" + ] def get_byte_based_vectorized_python_code(self, params): - return [f' {self.id} = byte_vector_SBOX({params}, {self.description}, {self.input_bit_size})'] + return [f" {self.id} = byte_vector_SBOX({params}, {self.description}, {self.input_bit_size})"] def get_word_based_c_code(self, verbosity, word_size, wordstring_variables): # TODO: consider the option for sbox - return ['\t//// TODO'] + return ["\t//// TODO"] - def milp_large_xor_differential_probability_constraints(self, binary_variable, integer_variable, - non_linear_component_id, weight_precision=MILP_DEFAULT_WEIGHT_PRECISION): + def milp_large_xor_differential_probability_constraints( + self, binary_variable, integer_variable, non_linear_component_id, weight_precision=MILP_DEFAULT_WEIGHT_PRECISION + ): """ Return lists of variables and constrains modeling SBOX component, with input bit size less or equal to 6. @@ -901,13 +928,24 @@ def milp_large_xor_differential_probability_constraints(self, binary_variable, i update_dictionary_that_contains_inequalities_for_large_sboxes(sbox, analysis="differential") dict_product_of_sum = get_dictionary_that_contains_inequalities_for_large_sboxes(analysis="differential") - constraints = milp_set_constraints_from_dictionnary_for_large_sbox(component_id, input_vars, - output_vars, sbox_input_size, sbox_output_size, x, p, - dict_product_of_sum[str(sbox)], analysis="differential", weight_precision=weight_precision) + constraints = milp_set_constraints_from_dictionnary_for_large_sbox( + component_id, + input_vars, + output_vars, + sbox_input_size, + sbox_output_size, + x, + p, + dict_product_of_sum[str(sbox)], + analysis="differential", + weight_precision=weight_precision, + ) return variables, constraints - def milp_large_xor_linear_probability_constraints(self, binary_variable, integer_variable, non_linear_component_id, weight_precision=MILP_DEFAULT_WEIGHT_PRECISION): + def milp_large_xor_linear_probability_constraints( + self, binary_variable, integer_variable, non_linear_component_id, weight_precision=MILP_DEFAULT_WEIGHT_PRECISION + ): """ Return lists of variables and constrains modeling SBOX component, with input bit size less or equal to 6. @@ -957,16 +995,24 @@ def milp_large_xor_linear_probability_constraints(self, binary_variable, integer update_dictionary_that_contains_inequalities_for_large_sboxes(sbox, analysis="linear") dict_product_of_sum = get_dictionary_that_contains_inequalities_for_large_sboxes(analysis="linear") - constraints = milp_set_constraints_from_dictionnary_for_large_sbox(component_id, input_vars, - output_vars, sbox_input_size, - sbox_output_size, x, p, - dict_product_of_sum[str(sbox)], - analysis="linear", weight_precision=weight_precision) + constraints = milp_set_constraints_from_dictionnary_for_large_sbox( + component_id, + input_vars, + output_vars, + sbox_input_size, + sbox_output_size, + x, + p, + dict_product_of_sum[str(sbox)], + analysis="linear", + weight_precision=weight_precision, + ) return variables, constraints - def milp_small_xor_differential_probability_constraints(self, binary_variable, integer_variable, - non_linear_component_id, weight_precision=MILP_DEFAULT_WEIGHT_PRECISION): + def milp_small_xor_differential_probability_constraints( + self, binary_variable, integer_variable, non_linear_component_id, weight_precision=MILP_DEFAULT_WEIGHT_PRECISION + ): """ Return a list of variables and a list of constrains modeling a component of type SBOX. @@ -1026,29 +1072,36 @@ def milp_small_xor_differential_probability_constraints(self, binary_variable, i constraints.append(x[f"{self.id}_active"] >= x[output_vars[i]]) # mip.add_constraint(sum(x[output_vars[i]] for i in range(sbox.input_size())) >= x[id + "_active"]) - M = (10 ** weight_precision) * max(input_size, output_size) + M = (10**weight_precision) * max(input_size, output_size) dict_constraints = {} for proba in dict_inequalities: dict_constraints[proba] = [] for ineq in dict_inequalities[proba]: - dict_constraints[proba].append(sum(x[input_vars[i]] * ineq[i + 1] for i in range(len(input_vars))) + - sum(x[output_vars[i]] * ineq[i + 1 + len(input_vars)] for i in - range(len(output_vars))) + - ineq[0] + M * (1 - x[f"{self.id}_proba_{proba}"]) >= 0) + dict_constraints[proba].append( + sum(x[input_vars[i]] * ineq[i + 1] for i in range(len(input_vars))) + + sum(x[output_vars[i]] * ineq[i + 1 + len(input_vars)] for i in range(len(output_vars))) + + ineq[0] + + M * (1 - x[f"{self.id}_proba_{proba}"]) + >= 0 + ) for proba in dict_constraints: constraints.extend(dict_constraints[proba]) + constraints.append(sum(x[f"{self.id}_proba_{proba}"] for proba in dict_constraints) == x[f"{self.id}_active"]) constraints.append( - sum(x[f"{self.id}_proba_{proba}"] for proba in dict_constraints) == x[f"{self.id}_active"]) - constraints.append(p[f"{self.id}_probability"] == (10 ** weight_precision) * sum( - x[f"{self.id}_proba_{proba}"] * (-log(proba / 2 ** sbox.input_size(), 2)) for proba in - dict_constraints)) + p[f"{self.id}_probability"] + == (10**weight_precision) + * sum( + x[f"{self.id}_proba_{proba}"] * (-log(proba / 2 ** sbox.input_size(), 2)) for proba in dict_constraints + ) + ) return variables, constraints - def milp_small_xor_linear_probability_constraints(self, binary_variable, integer_variable, non_linear_component_id, - weight_precision=MILP_DEFAULT_WEIGHT_PRECISION): + def milp_small_xor_linear_probability_constraints( + self, binary_variable, integer_variable, non_linear_component_id, weight_precision=MILP_DEFAULT_WEIGHT_PRECISION + ): """ Return a list of variables and a list of constrains modeling a component of type Sbox. @@ -1112,26 +1165,35 @@ def milp_small_xor_linear_probability_constraints(self, binary_variable, integer # Big-M Reformulation method as used in 4.1 of # https://tosc.iacr.org/index.php/ToSC/article/view/805/759 - M = (10 ** weight_precision) * max(input_size, output_size) + M = (10**weight_precision) * max(input_size, output_size) dict_constraints = {} for proba in dict_inequalities: dict_constraints[proba] = [] for ineq in dict_inequalities[proba]: - dict_constraints[proba].append(sum(x[input_vars[i]] * ineq[i + 1] for i in range(len(input_vars))) + - sum(x[output_vars[i]] * ineq[i + 1 + len(input_vars)] - for i in range(len(output_vars))) + - ineq[0] + M * (1 - x[f"{component_id}_proba_{proba}"]) >= 0) + dict_constraints[proba].append( + sum(x[input_vars[i]] * ineq[i + 1] for i in range(len(input_vars))) + + sum(x[output_vars[i]] * ineq[i + 1 + len(input_vars)] for i in range(len(output_vars))) + + ineq[0] + + M * (1 - x[f"{component_id}_proba_{proba}"]) + >= 0 + ) for proba in dict_constraints: constraints.extend(dict_constraints[proba]) constraints.append( - sum(x[f"{component_id}_proba_{proba}"] for proba in dict_constraints) == x[f"{component_id}_active"]) + sum(x[f"{component_id}_proba_{proba}"] for proba in dict_constraints) == x[f"{component_id}_active"] + ) # correlation[i,j] = 2p[i,j] - 1, where p[i,j] = LAT[i,j] / 2^n + 1/2 - constraints.append(p[f"{component_id}_probability"] == (10 ** weight_precision) * sum(x[f"{component_id}_proba_{proba}"] * - (log((2 ** (sbox.input_size() - 1)) / abs( - proba), 2)) for proba in dict_constraints)) + constraints.append( + p[f"{component_id}_probability"] + == (10**weight_precision) + * sum( + x[f"{component_id}_proba_{proba}"] * (log((2 ** (sbox.input_size() - 1)) / abs(proba), 2)) + for proba in dict_constraints + ) + ) return variables, constraints @@ -1169,10 +1231,9 @@ def milp_xor_differential_propagation_constraints(self, model): integer_variable = model.integer_variable non_linear_component_id = model.non_linear_component_id weight_precision = model.weight_precision - variables, constraints = self.milp_large_xor_differential_probability_constraints(binary_variable, - integer_variable, - non_linear_component_id, - weight_precision) + variables, constraints = self.milp_large_xor_differential_probability_constraints( + binary_variable, integer_variable, non_linear_component_id, weight_precision + ) return variables, constraints @@ -1211,9 +1272,9 @@ def milp_xor_linear_mask_propagation_constraints(self, model): integer_variable = model.integer_variable non_linear_component_id = model.non_linear_component_id weight_precision = model.weight_precision - variables, constraints = self.milp_large_xor_linear_probability_constraints(binary_variable, - integer_variable, - non_linear_component_id, weight_precision) + variables, constraints = self.milp_large_xor_linear_probability_constraints( + binary_variable, integer_variable, non_linear_component_id, weight_precision + ) return variables, constraints def milp_wordwise_deterministic_truncated_xor_differential_constraints(self, model): @@ -1269,24 +1330,30 @@ def milp_wordwise_deterministic_truncated_xor_differential_constraints(self, mod input_class_tuple, output_class_tuple = self._get_wordwise_input_output_linked_class_tuples(model) - variables = [(f"x[{var_elt}]", x[var_elt]) for var_tuple in input_class_tuple + output_class_tuple for var_elt in var_tuple] + variables = [ + (f"x[{var_elt}]", x[var_elt]) + for var_tuple in input_class_tuple + output_class_tuple + for var_elt in var_tuple + ] input_vars = [x[i] for _ in input_class_tuple for i in _] output_vars = [x[i] for _ in output_class_tuple for i in _] - constraints = [1 + output_vars[1] >= input_vars[0] + input_vars[1], - input_vars[0] + input_vars[1] >= output_vars[0], - input_vars[0] >= output_vars[1], - input_vars[1] >= output_vars[1], - output_vars[0] >= input_vars[1], - output_vars[0] >= input_vars[0]] + constraints = [ + 1 + output_vars[1] >= input_vars[0] + input_vars[1], + input_vars[0] + input_vars[1] >= output_vars[0], + input_vars[0] >= output_vars[1], + input_vars[1] >= output_vars[1], + output_vars[0] >= input_vars[1], + output_vars[0] >= input_vars[0], + ] return variables, constraints def milp_wordwise_deterministic_truncated_xor_differential_simple_constraints(self, model): """ Models the wordwise Sbox component according to a simplified version of Model 4 from [SGWW2020]_ - + The valid set for the input output pair (x, y) is {(0, 0), (1, 2), (2, 2), (3, 3)} if dX = 1 @@ -1331,8 +1398,9 @@ def milp_wordwise_deterministic_truncated_xor_differential_simple_constraints(se var_if, if_constraints = milp_utils.milp_eq(model, x_class[input], 1, big_m) then_constraints = [x_class[output] == 2] else_constraints = [x_class[input] == x_class[output]] - constraints.extend(if_constraints + milp_utils.milp_if_then_else(var_if, then_constraints, else_constraints, - big_m)) + constraints.extend( + if_constraints + milp_utils.milp_if_then_else(var_if, then_constraints, else_constraints, big_m) + ) return variables, constraints @@ -1374,17 +1442,17 @@ def milp_bitwise_deterministic_truncated_xor_differential_constraints(self, mode variables = [(f"x_class[{var}]", x_class[var]) for var in input_class_vars + output_class_vars] constraints = [] - input_sum = sum([x_class[input] for input in input_class_vars]) + input_sum = sum(x_class[input] for input in input_class_vars) # if sum(x_class[input]) <= 0 (i.e. all x_class[input] == 0) d_leq, c_leq = milp_utils.milp_leq(model, input_sum, 0, 2 * len(input_class_vars)) constraints += c_leq # then all outputs are 0's, else they are all 2's - constraints += milp_utils.milp_if_then_else(d_leq, [x_class[_] == 0 for _ in output_class_vars], - [x_class[_] == 2 for _ in output_class_vars], 2) + constraints += milp_utils.milp_if_then_else( + d_leq, [x_class[_] == 0 for _ in output_class_vars], [x_class[_] == 2 for _ in output_class_vars], 2 + ) return variables, constraints - def milp_undisturbed_bits_bitwise_deterministic_truncated_xor_differential_constraints(self, model): """ Models the wordwise Sbox component, with added undisturbed bits information, as mentioned in [CZZ2023]_ @@ -1433,11 +1501,13 @@ def milp_undisturbed_bits_bitwise_deterministic_truncated_xor_differential_const input_id_tuples, output_id_tuples = self._get_input_output_variables_tuples() input_ids, output_ids = self._get_input_output_variables() - linking_constraints = model.link_binary_tuples_to_integer_variables(input_id_tuples + output_id_tuples, - input_ids + output_ids) + linking_constraints = model.link_binary_tuples_to_integer_variables( + input_id_tuples + output_id_tuples, input_ids + output_ids + ) - variables = [(f"x[{var_elt}]", x[var_elt]) for var_tuple in input_id_tuples + output_id_tuples for var_elt in - var_tuple] + variables = [ + (f"x[{var_elt}]", x[var_elt]) for var_tuple in input_id_tuples + output_id_tuples for var_elt in var_tuple + ] constraints = [] + linking_constraints input_vars = [tuple(x[i] for i in _) for _ in input_id_tuples] @@ -1487,17 +1557,17 @@ def sat_constraints(self): '-xor_0_0_4 -xor_0_0_5 -xor_0_0_6 -xor_0_0_7 sbox_0_2_2', '-xor_0_0_4 -xor_0_0_5 -xor_0_0_6 -xor_0_0_7 -sbox_0_2_3']) """ - input_bit_len, input_bit_ids = self._generate_input_ids() + input_bit_ids = self._generate_input_ids() output_bit_len, output_bit_ids = self._generate_output_ids() sbox_outputs = self.description constraints = [] for sbox_input, sbox_output in enumerate(sbox_outputs): - input_signs = ('-' * (sbox_input >> j & 1) for j in reversed(range(input_bit_len))) - current_input_bit_ids = (f'{sign}{bit_id}' for sign, bit_id in zip(input_signs, input_bit_ids)) - output_signs = ('-' * ((sbox_output >> j & 1) ^ 1) for j in reversed(range(output_bit_len))) - current_output_bit_ids = (f'{sign}{bit_id}' for sign, bit_id in zip(output_signs, output_bit_ids)) - input_constraint = ' '.join(current_input_bit_ids) - current_constraints = (f'{input_constraint} {bit_id}' for bit_id in current_output_bit_ids) + input_signs = ("-" * (sbox_input >> j & 1) for j in reversed(range(self.input_bit_size))) + current_input_bit_ids = (f"{sign}{bit_id}" for sign, bit_id in zip(input_signs, input_bit_ids)) + output_signs = ("-" * ((sbox_output >> j & 1) ^ 1) for j in reversed(range(output_bit_len))) + current_output_bit_ids = (f"{sign}{bit_id}" for sign, bit_id in zip(output_signs, output_bit_ids)) + input_constraint = " ".join(current_input_bit_ids) + current_constraints = (f"{input_constraint} {bit_id}" for bit_id in current_output_bit_ids) constraints.extend(current_constraints) return output_bit_ids, constraints @@ -1534,14 +1604,18 @@ def sat_bitwise_deterministic_truncated_xor_differential_constraints(self): espresso_input_length = 2 * (len(valid_transitions[0][0]) + len(valid_transitions[0][1])) espresso_input = [f".i {espresso_input_length}", ".o 1"] for transition in valid_transitions: - espresso_condition = ['0'*(value == 0 or value == 1) + '1'*(value == 2) for value in transition[0]] - espresso_condition += ['0'*(value == 0) + '1'*(value == 1) + '-'*(value == 2) for value in transition[0]] - espresso_condition += ['0'*(value == 0 or value == 1) + '1'*(value == 2) for value in transition[1]] - espresso_condition += ['0'*(value == 0) + '1'*(value == 1) + '-'*(value == 2) for value in transition[1]] + espresso_condition = ["0" * (value == 0 or value == 1) + "1" * (value == 2) for value in transition[0]] + espresso_condition += [ + "0" * (value == 0) + "1" * (value == 1) + "-" * (value == 2) for value in transition[0] + ] + espresso_condition += ["0" * (value == 0 or value == 1) + "1" * (value == 2) for value in transition[1]] + espresso_condition += [ + "0" * (value == 0) + "1" * (value == 1) + "-" * (value == 2) for value in transition[1] + ] espresso_input += ["".join(espresso_condition) + " 1"] espresso_input += [".e"] espresso_input = "\n".join(espresso_input) - espresso_process = subprocess.run(['espresso', '-epos'], input=espresso_input, capture_output=True, text=True) + espresso_process = subprocess.run(["espresso", "-epos"], input=espresso_input, capture_output=True, text=True) espresso_output = espresso_process.stdout.splitlines() # building constraints input_ids_0, input_ids_1 = self._generate_input_double_ids() @@ -1551,8 +1625,8 @@ def sat_bitwise_deterministic_truncated_xor_differential_constraints(self): ids = input_ids + output_ids constraints = [] for line in espresso_output[4:-1]: - literals = ['-' * int(line[i]) + ids[i] for i in range(espresso_input_length) if line[i] != '-'] - constraints.append(' '.join(literals)) + literals = ["-" * int(line[i]) + ids[i] for i in range(espresso_input_length) if line[i] != "-"] + constraints.append(" ".join(literals)) return output_ids, constraints @@ -1591,27 +1665,27 @@ def sat_xor_differential_propagation_constraints(self, model=None): 'xor_0_0_5 xor_0_0_6 sbox_0_2_0 sbox_0_2_2 -hw_sbox_0_2_1', '-hw_sbox_0_2_0']) """ - input_bit_len, input_bit_ids = self._generate_input_ids() + input_bit_ids = self._generate_input_ids() output_bit_len, output_bit_ids = self._generate_output_ids() - hw_bit_ids = [f'hw_{output_bit_ids[i]}' for i in range(output_bit_len)] + hw_bit_ids = [f"hw_{output_bit_ids[i]}" for i in range(output_bit_len)] sbox_values = self.description # if optimized SAT DDT template is not initialized in instance fields, compute it - if f'{sbox_values}' not in self.sboxes_ddt_templates: + if f"{sbox_values}" not in self.sboxes_ddt_templates: ddt = SBox(sbox_values).difference_distribution_table() - check_table_feasibility(ddt, 'DDT', 'SAT') + check_table_feasibility(ddt, "DDT", "SAT") - get_hamming_weight_function = (lambda input_bit_len, entry: input_bit_len - int(math.log2(entry))) - template = sat_build_table_template(ddt, get_hamming_weight_function, input_bit_len, output_bit_len) - self.sboxes_ddt_templates[f'{sbox_values}'] = template + get_hamming_weight_function = lambda input_bit_len, entry: input_bit_len - int(math.log2(entry)) + template = sat_build_table_template(ddt, get_hamming_weight_function, self.input_bit_size, output_bit_len) + self.sboxes_ddt_templates[f"{sbox_values}"] = template bit_ids = input_bit_ids + output_bit_ids + hw_bit_ids - template = self.sboxes_ddt_templates[f'{sbox_values}'] + template = self.sboxes_ddt_templates[f"{sbox_values}"] constraints = [] for clause in template: - literals = ['-' * value[0] + bit_ids[value[1]] for value in clause] - constraints.append(' '.join(literals)) + literals = ["-" * value[0] + bit_ids[value[1]] for value in clause] + constraints.append(" ".join(literals)) return output_bit_ids + hw_bit_ids, constraints @@ -1655,25 +1729,25 @@ def sat_xor_linear_mask_propagation_constraints(self, model=None): input_bit_len, input_bit_ids = self._generate_component_input_ids() out_suffix = constants.OUTPUT_BIT_ID_SUFFIX output_bit_len, output_bit_ids = self._generate_output_ids(suffix=out_suffix) - hw_bit_ids = [f'hw_{output_bit_ids[i]}' for i in range(input_bit_len)] + hw_bit_ids = [f"hw_{output_bit_ids[i]}" for i in range(input_bit_len)] sbox_values = self.description # if optimized SAT LAT template is not initialized in instance fields, compute it - if f'{sbox_values}' not in self.sboxes_lat_templates: + if f"{sbox_values}" not in self.sboxes_lat_templates: lat = SBox(sbox_values).linear_approximation_table() - check_table_feasibility(lat, 'LAT', 'SAT') + check_table_feasibility(lat, "LAT", "SAT") - get_hamming_weight_function = (lambda input_bit_len, entry: input_bit_len - int(math.log2(abs(entry))) - 1) + get_hamming_weight_function = lambda input_bit_len, entry: input_bit_len - int(math.log2(abs(entry))) - 1 template = sat_build_table_template(lat, get_hamming_weight_function, input_bit_len, output_bit_len) - self.sboxes_lat_templates[f'{sbox_values}'] = template + self.sboxes_lat_templates[f"{sbox_values}"] = template bit_ids = input_bit_ids + output_bit_ids + hw_bit_ids - template = self.sboxes_lat_templates[f'{sbox_values}'] + template = self.sboxes_lat_templates[f"{sbox_values}"] constraints = [] for clause in template: - literals = ['-' * value[0] + bit_ids[value[1]] for value in clause] - constraints.append(' '.join(literals)) + literals = ["-" * value[0] + bit_ids[value[1]] for value in clause] + constraints.append(" ".join(literals)) return bit_ids, constraints @@ -1702,21 +1776,23 @@ def smt_constraints(self): '(assert (=> (and xor_0_0_0 xor_0_0_1 xor_0_0_2 (not xor_0_0_3)) (and (not sbox_0_1_0) (not sbox_0_1_1) (not sbox_0_1_2) sbox_0_1_3)))', '(assert (=> (and xor_0_0_0 xor_0_0_1 xor_0_0_2 xor_0_0_3) (and (not sbox_0_1_0) (not sbox_0_1_1) sbox_0_1_2 (not sbox_0_1_3))))']) """ - input_bit_len, input_bit_ids = self._generate_input_ids() - output_bit_len, output_bit_ids = self._generate_output_ids() - sbox_values = self.description + input_bit_ids = self._generate_input_ids() + _, output_bit_ids = self._generate_output_ids() + sbox = self.description constraints = [] - for i in range(len(sbox_values)): - input_difference_lits = [input_bit_ids[j] - if i >> (input_bit_len - 1 - j) & 1 - else smt_utils.smt_not(input_bit_ids[j]) - for j in range(input_bit_len)] - input_difference = smt_utils.smt_and(input_difference_lits) - output_difference_lits = [output_bit_ids[j] - if sbox_values[i] >> (output_bit_len - 1 - j) & 1 - else smt_utils.smt_not(output_bit_ids[j]) - for j in range(output_bit_len)] - output_difference = smt_utils.smt_and(output_difference_lits) + for in_value, out_value in enumerate(sbox): + bits = map(int, f"{in_value:0{self.input_bit_size}b}") + input_literals = [ + input_bit_id if bit else smt_utils.smt_not(input_bit_id) + for input_bit_id, bit in zip(input_bit_ids, bits) + ] + input_difference = smt_utils.smt_and(input_literals) + bits = map(int, f"{out_value:0{self.input_bit_size}b}") + output_literals = [ + output_bit_id if bit else smt_utils.smt_not(output_bit_id) + for output_bit_id, bit in zip(output_bit_ids, bits) + ] + output_difference = smt_utils.smt_and(output_literals) implication = smt_utils.smt_implies(input_difference, output_difference) constraints.append(smt_utils.smt_assert(implication)) @@ -1751,24 +1827,24 @@ def smt_xor_differential_propagation_constraints(self, model): '(assert (or (not hw_sbox_0_5_1)))', '(assert (or (not hw_sbox_0_5_0)))']) """ - input_bit_len, input_bit_ids = self._generate_input_ids() + input_bit_ids = self._generate_input_ids() output_bit_len, output_bit_ids = self._generate_output_ids() - hw_bit_ids = [f'hw_{output_bit_ids[i]}' for i in range(output_bit_len)] + hw_bit_ids = [f"hw_{output_bit_ids[i]}" for i in range(output_bit_len)] sbox_values = self.description sboxes_ddt_templates = model.sboxes_ddt_templates # if optimized DDT template is not initialized in instance fields, compute it - if f'{sbox_values}' not in sboxes_ddt_templates: + if f"{sbox_values}" not in sboxes_ddt_templates: ddt = SBox(sbox_values).difference_distribution_table() - check_table_feasibility(ddt, 'DDT', 'SMT') + check_table_feasibility(ddt, "DDT", "SMT") - get_hamming_weight_function = (lambda input_bit_len, entry: input_bit_len - int(math.log2(entry))) - template = smt_build_table_template(ddt, get_hamming_weight_function, input_bit_len, output_bit_len) - sboxes_ddt_templates[f'{sbox_values}'] = template + get_hamming_weight_function = lambda input_bit_len, entry: input_bit_len - int(math.log2(entry)) + template = smt_build_table_template(ddt, get_hamming_weight_function, self.input_bit_size, output_bit_len) + sboxes_ddt_templates[f"{sbox_values}"] = template bit_ids = input_bit_ids + output_bit_ids + hw_bit_ids - template = sboxes_ddt_templates[f'{sbox_values}'] + template = sboxes_ddt_templates[f"{sbox_values}"] constraints = smt_get_sbox_probability_constraints(bit_ids, template) return output_bit_ids + hw_bit_ids, constraints @@ -1805,22 +1881,22 @@ def smt_xor_linear_mask_propagation_constraints(self, model): input_bit_len, input_bit_ids = self._generate_component_input_ids() out_suffix = constants.OUTPUT_BIT_ID_SUFFIX output_bit_len, output_bit_ids = self._generate_output_ids(suffix=out_suffix) - hw_bit_ids = [f'hw_{output_bit_ids[i]}' for i in range(input_bit_len)] + hw_bit_ids = [f"hw_{output_bit_ids[i]}" for i in range(input_bit_len)] sbox_values = self.description sboxes_lat_templates = model.sboxes_lat_templates # if optimized LAT template is not initialized in instance fields, compute it - if f'{sbox_values}' not in sboxes_lat_templates: + if f"{sbox_values}" not in sboxes_lat_templates: lat = SBox(sbox_values).linear_approximation_table() - check_table_feasibility(lat, 'LAT', 'SMT') + check_table_feasibility(lat, "LAT", "SMT") - get_hamming_weight_function = (lambda input_bit_len, entry: input_bit_len - int(math.log2(abs(entry))) - 1) + get_hamming_weight_function = lambda input_bit_len, entry: input_bit_len - int(math.log2(abs(entry))) - 1 template = smt_build_table_template(lat, get_hamming_weight_function, input_bit_len, output_bit_len) - sboxes_lat_templates[f'{sbox_values}'] = template + sboxes_lat_templates[f"{sbox_values}"] = template bit_ids = input_bit_ids + output_bit_ids + hw_bit_ids - template = sboxes_lat_templates[f'{sbox_values}'] + template = sboxes_lat_templates[f"{sbox_values}"] constraints = smt_get_sbox_probability_constraints(bit_ids, template) return bit_ids, constraints diff --git a/claasp/components/shift_component.py b/claasp/components/shift_component.py index 260159a92..2307a6b9c 100644 --- a/claasp/components/shift_component.py +++ b/claasp/components/shift_component.py @@ -1,17 +1,16 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -21,14 +20,22 @@ from claasp.component import Component from claasp.cipher_modules.models.smt.utils import utils as smt_utils from claasp.cipher_modules.models.sat.utils import constants, utils as sat_utils +from claasp.name_mappings import WORD_OPERATION class SHIFT(Component): - def __init__(self, current_round_number, current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size, parameter): - component_id = f'shift_{current_round_number}_{current_round_number_of_components}' - component_type = 'word_operation' - description = ['SHIFT', parameter] + def __init__( + self, + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + parameter, + ): + component_id = f"shift_{current_round_number}_{current_round_number_of_components}" + component_type = WORD_OPERATION + description = ["SHIFT", parameter] component_input = Input(output_bit_size, input_id_links, input_bit_positions) super().__init__(component_id, component_type, component_input, output_bit_size, description) @@ -60,14 +67,15 @@ def algebraic_polynomials(self, model): ninputs = noutputs = self.output_bit_size shift_constant = self.description[1] % noutputs - input_vars = [self.id + "_" + model.input_postfix + str(i) for i in range(ninputs)] - output_vars = [self.id + "_" + model.output_postfix + str(i) for i in range(noutputs)] + input_vars = [f"{self.id}_{model.input_postfix}{i}" for i in range(ninputs)] + output_vars = [f"{self.id}_{model.output_postfix}{i}" for i in range(noutputs)] ring_R = model.ring() x = list(map(ring_R, input_vars)) y = list(map(ring_R, output_vars)) - polynomials = [y[i] for i in range(shift_constant)] + \ - [y[shift_constant:][i] + x[i] for i in range(noutputs - shift_constant)] + polynomials = [y[i] for i in range(shift_constant)] + [ + y[shift_constant:][i] + x[i] for i in range(noutputs - shift_constant) + ] return polynomials @@ -128,24 +136,30 @@ def cp_constraints(self): 'constraint shift_0_0[30] = 0;', 'constraint shift_0_0[31] = 0;']) """ - output_size = int(self.output_bit_size) - input_id_links = self.input_id_links - output_id_link = self.id - input_bit_positions = self.input_bit_positions shift_amount = abs(self.description[1]) cp_declarations = [] all_inputs = [] - for id_link, bit_positions in zip(input_id_links, input_bit_positions): - all_inputs.extend([f'{id_link}[{position}]' for position in bit_positions]) + for id_link, bit_positions in zip(self.input_id_links, self.input_bit_positions): + all_inputs.extend([f"{id_link}[{position}]" for position in bit_positions]) if shift_amount == self.description[1]: - cp_constraints = [f'constraint {output_id_link}[{i}] = 0;' for i in range(shift_amount)] - cp_constraints.extend([f'constraint {output_id_link}[{i}] = {all_inputs[i - shift_amount]};' - for i in range(shift_amount, output_size)]) + cp_constraints = [f"constraint {self.id}[{i}] = 0;" for i in range(shift_amount)] + cp_constraints.extend( + [ + f"constraint {self.id}[{i}] = {all_inputs[i - shift_amount]};" + for i in range(shift_amount, self.output_bit_size) + ] + ) else: - cp_constraints = [f'constraint {output_id_link}[{i}] = {all_inputs[i + shift_amount]};' - for i in range(output_size - shift_amount)] - cp_constraints.extend([f'constraint {output_id_link}[{i}] = 0;' - for i in range(output_size - shift_amount, output_size)]) + cp_constraints = [ + f"constraint {self.id}[{i}] = {all_inputs[i + shift_amount]};" + for i in range(self.output_bit_size - shift_amount) + ] + cp_constraints.extend( + [ + f"constraint {self.id}[{i}] = 0;" + for i in range(self.output_bit_size - shift_amount, self.output_bit_size) + ] + ) return cp_declarations, cp_constraints @@ -174,24 +188,30 @@ def cp_inverse_constraints(self): ... 'constraint shift_0_0_inverse[31] = 0;']) """ - output_size = int(self.output_bit_size) - input_id_links = self.input_id_links - output_id_link = self.id - input_bit_positions = self.input_bit_positions shift_amount = abs(self.description[1]) cp_declarations = [] all_inputs = [] - for id_link, bit_positions in zip(input_id_links, input_bit_positions): - all_inputs.extend([f'{id_link}[{position}]' for position in bit_positions]) + for id_link, bit_positions in zip(self.input_id_links, self.input_bit_positions): + all_inputs.extend([f"{id_link}[{position}]" for position in bit_positions]) if shift_amount == self.description[1]: - cp_constraints = [f'constraint {output_id_link}_inverse[{i}] = 0;' for i in range(shift_amount)] - cp_constraints.extend([f'constraint {output_id_link}_inverse[{i}] = {all_inputs[i - shift_amount]};' - for i in range(shift_amount, output_size)]) + cp_constraints = [f"constraint {self.id}_inverse[{i}] = 0;" for i in range(shift_amount)] + cp_constraints.extend( + [ + f"constraint {self.id}_inverse[{i}] = {all_inputs[i - shift_amount]};" + for i in range(shift_amount, self.output_bit_size) + ] + ) else: - cp_constraints = [f'constraint {output_id_link}_inverse[{i}] = {all_inputs[i + shift_amount]};' - for i in range(output_size - shift_amount)] - cp_constraints.extend([f'constraint {output_id_link}_inverse[{i}] = 0;' - for i in range(output_size - shift_amount, output_size)]) + cp_constraints = [ + f"constraint {self.id}_inverse[{i}] = {all_inputs[i + shift_amount]};" + for i in range(self.output_bit_size - shift_amount) + ] + cp_constraints.extend( + [ + f"constraint {self.id}_inverse[{i}] = 0;" + for i in range(self.output_bit_size - shift_amount, self.output_bit_size) + ] + ) return cp_declarations, cp_constraints @@ -220,38 +240,64 @@ def cp_wordwise_deterministic_truncated_xor_differential_constraints(self, model 'constraint shift_0_18_value[3] = 0;']) """ output_size = int(self.output_bit_size) - input_id_link = self.input_id_links output_id_link = self.id - input_bit_positions = self.input_bit_positions word_size = model.word_size shift_amount = abs(self.description[1]) // word_size all_inputs_active = [] all_inputs_value = [] cp_declarations = [] - for id_link, bit_positions in zip(input_id_link, input_bit_positions): - all_inputs_active.extend([f'{id_link}_active[{bit_positions[j * word_size] // word_size}]' - for j in range(len(bit_positions) // word_size)]) - for id_link, bit_positions in zip(input_id_link, input_bit_positions): - all_inputs_value.extend([f'{id_link}_value[{bit_positions[j * word_size] // word_size}]' - for j in range(len(bit_positions) // word_size)]) + for id_link, bit_positions in zip(self.input_id_links, self.input_bit_positions): + all_inputs_active.extend( + [ + f"{id_link}_active[{bit_positions[j * word_size] // word_size}]" + for j in range(len(bit_positions) // word_size) + ] + ) + for id_link, bit_positions in zip(self.input_id_links, self.input_bit_positions): + all_inputs_value.extend( + [ + f"{id_link}_value[{bit_positions[j * word_size] // word_size}]" + for j in range(len(bit_positions) // word_size) + ] + ) if shift_amount == self.description[1]: - cp_constraints = [f'constraint {output_id_link}_active[{i}] = 0;' for i in range(shift_amount)] - cp_constraints.extend([f'constraint {output_id_link}_active[{i}] = {all_inputs_active[i - shift_amount]};' - for i in range(shift_amount, output_size // word_size)]) - cp_constraints.extend([f'constraint {output_id_link}_value[{i}] = 0;' for i in range(shift_amount)]) - cp_constraints.extend([f'constraint {output_id_link}_value[{i}] = {all_inputs_active[i - shift_amount]};' - for i in range(shift_amount, output_size // word_size)]) + cp_constraints = [f"constraint {output_id_link}_active[{i}] = 0;" for i in range(shift_amount)] + cp_constraints.extend( + [ + f"constraint {output_id_link}_active[{i}] = {all_inputs_active[i - shift_amount]};" + for i in range(shift_amount, output_size // word_size) + ] + ) + cp_constraints.extend([f"constraint {output_id_link}_value[{i}] = 0;" for i in range(shift_amount)]) + cp_constraints.extend( + [ + f"constraint {output_id_link}_value[{i}] = {all_inputs_active[i - shift_amount]};" + for i in range(shift_amount, output_size // word_size) + ] + ) else: - cp_constraints = [f'constraint {output_id_link}_active[{i}] = {all_inputs_active[i + shift_amount]};' - for i in range(output_size // word_size - shift_amount)] - cp_constraints.extend([f'constraint {output_id_link}_active[{i}] = 0;' - for i in - range(output_size // word_size - shift_amount, output_size // word_size)]) - cp_constraints.extend([f'constraint {output_id_link}_value[{i}] = {all_inputs_active[i + shift_amount]};' - for i in range(output_size // word_size - shift_amount)]) - cp_constraints.extend([f'constraint {output_id_link}_value[{i}] = 0;' - for i in - range(output_size // word_size - shift_amount, output_size // word_size)]) + cp_constraints = [ + f"constraint {output_id_link}_active[{i}] = {all_inputs_active[i + shift_amount]};" + for i in range(output_size // word_size - shift_amount) + ] + cp_constraints.extend( + [ + f"constraint {output_id_link}_active[{i}] = 0;" + for i in range(output_size // word_size - shift_amount, output_size // word_size) + ] + ) + cp_constraints.extend( + [ + f"constraint {output_id_link}_value[{i}] = {all_inputs_active[i + shift_amount]};" + for i in range(output_size // word_size - shift_amount) + ] + ) + cp_constraints.extend( + [ + f"constraint {output_id_link}_value[{i}] = 0;" + for i in range(output_size // word_size - shift_amount, output_size // word_size) + ] + ) return cp_declarations, cp_constraints @@ -290,11 +336,12 @@ def cp_xor_differential_first_step_constraints(self, model): for i in range(numb_of_inp): for j in range(len(input_bit_positions[i]) // model.word_size): all_inputs.append( - f'{input_id_link[i]}[{input_bit_positions[i][j * model.word_size] // model.word_size}]') + f"{input_id_link[i]}[{input_bit_positions[i][j * model.word_size] // model.word_size}]" + ) rem = len(input_bit_positions[i]) % model.word_size if rem != 0: rem = model.word_size - (len(input_bit_positions[i]) % model.word_size) - all_inputs.append(f'{output_id_link}_i[{number_of_mix}]') + all_inputs.append(f"{output_id_link}_i[{number_of_mix}]") number_of_mix += 1 is_mix = True l = 1 @@ -303,20 +350,29 @@ def cp_xor_differential_first_step_constraints(self, model): del input_bit_positions[i + l][0:rem] rem -= length l += 1 - cp_declarations = [f'array[0..{(output_size - 1) // model.word_size}] of var 0..1: {output_id_link};'] + cp_declarations = [f"array[0..{(output_size - 1) // model.word_size}] of var 0..1: {output_id_link};"] if is_mix: - cp_declarations.append(f'array[0..{number_of_mix - 1}] of var 0..1: {output_id_link}_i;') + cp_declarations.append(f"array[0..{number_of_mix - 1}] of var 0..1: {output_id_link}_i;") if shift_amount == self.description[1]: - cp_constraints = [f'constraint {output_id_link}[{i}] = 0;' for i in range(shift_amount)] - cp_constraints.extend([f'constraint {output_id_link}[{i}] = {all_inputs[i - shift_amount]};' - for i in range(shift_amount, output_size // model.word_size)]) + cp_constraints = [f"constraint {output_id_link}[{i}] = 0;" for i in range(shift_amount)] + cp_constraints.extend( + [ + f"constraint {output_id_link}[{i}] = {all_inputs[i - shift_amount]};" + for i in range(shift_amount, output_size // model.word_size) + ] + ) else: - cp_constraints = [f'constraint {output_id_link}[{i}] = {all_inputs[i + shift_amount]};' - for i in range(output_size // model.word_size - shift_amount)] - cp_constraints.extend([ - f'constraint {output_id_link}[{i}] = 0;' - for i in range(output_size // model.word_size - shift_amount, output_size // model.word_size)]) + cp_constraints = [ + f"constraint {output_id_link}[{i}] = {all_inputs[i + shift_amount]};" + for i in range(output_size // model.word_size - shift_amount) + ] + cp_constraints.extend( + [ + f"constraint {output_id_link}[{i}] = 0;" + for i in range(output_size // model.word_size - shift_amount, output_size // model.word_size) + ] + ) return cp_declarations, cp_constraints @@ -350,31 +406,33 @@ def cp_xor_linear_mask_propagation_constraints(self, model=None): ... 'constraint shift_0_0_i[3]=0;']) """ - output_size = int(self.output_bit_size) + output_size = self.output_bit_size output_id_link = self.id shift_amount = abs(self.description[1]) + cp_declarations = [ + f"array[0..{output_size - 1}] of var 0..1: {output_id_link}_i;", + f"array[0..{output_size - 1}] of var 0..1: {output_id_link}_o;", + ] cp_constraints = [] - cp_declarations = [f'array[0..{output_size - 1}] of var 0..1: {output_id_link}_i;', - f'array[0..{output_size - 1}] of var 0..1: {output_id_link}_o;'] if shift_amount == self.description[1]: for i in range(output_size - shift_amount, output_size): - cp_constraints.append(f'constraint {output_id_link}_i[{i}]=0;') + cp_constraints.append(f"constraint {output_id_link}_i[{i}]=0;") for i in range(shift_amount, output_size): - cp_constraints.append(f'constraint {output_id_link}_o[{i}]={output_id_link}_i[{i - shift_amount}];') + cp_constraints.append(f"constraint {output_id_link}_o[{i}]={output_id_link}_i[{i - shift_amount}];") else: for i in range(output_size - shift_amount): - cp_constraints.append(f'constraint {output_id_link}_o[{i}]={output_id_link}_i[{i + shift_amount}];') + cp_constraints.append(f"constraint {output_id_link}_o[{i}]={output_id_link}_i[{i + shift_amount}];") for i in range(shift_amount): - cp_constraints.append(f'constraint {output_id_link}_i[{i}]=0;') + cp_constraints.append(f"constraint {output_id_link}_i[{i}]=0;") result = cp_declarations, cp_constraints return result def get_bit_based_vectorized_python_code(self, params, convert_output_to_bytes): - return [f' {self.id} = bit_vector_SHIFT([{",".join(params)} ], {self.description[1]})'] + return [f" {self.id} = bit_vector_SHIFT([{','.join(params)} ], {self.description[1]})"] def get_byte_based_vectorized_python_code(self, params): - return [f' {self.id} = byte_vector_SHIFT({params}, {self.description[1]})'] + return [f" {self.id} = byte_vector_SHIFT({params}, {self.description[1]})"] def get_word_based_c_code(self, verbosity, word_size, wordstring_variables): shift_code = [] @@ -383,8 +441,8 @@ def get_word_based_c_code(self, verbosity, word_size, wordstring_variables): wordstring_variables.append(self.id) direction = "RIGHT" if self.description[1] >= 0 else "LEFT" shift_code.append( - f'\tWordString *{self.id} = ' - f'{direction}_{self.description[0]}(input, {abs(self.description[1])});') + f"\tWordString *{self.id} = {direction}_{self.description[0]}(input, {abs(self.description[1])});" + ) if verbosity: self.print_word_values(shift_code) @@ -395,10 +453,10 @@ def get_word_operation_sign(self, sign, solution): output_id_link = self.id component_sign = 1 sign = sign * component_sign - solution['components_values'][f'{output_id_link}_o']['sign'] = component_sign - solution['components_values'][output_id_link] = solution['components_values'][f'{output_id_link}_o'] - del solution['components_values'][f'{output_id_link}_o'] - del solution['components_values'][f'{output_id_link}_i'] + solution["components_values"][f"{output_id_link}_o"]["sign"] = component_sign + solution["components_values"][output_id_link] = solution["components_values"][f"{output_id_link}_o"] + del solution["components_values"][f"{output_id_link}_o"] + del solution["components_values"][f"{output_id_link}_i"] return sign @@ -648,10 +706,12 @@ def minizinc_constraints(self, model): if shift_const < 0: shift_mzn_constraints = [ - f'constraint LSHIFT({mzn_input_array_1}, {int(-1*shift_const)})={mzn_output_array_1};\n'] + f"constraint LSHIFT({mzn_input_array_1}, {int(-1 * shift_const)})={mzn_output_array_1};\n" + ] else: shift_mzn_constraints = [ - f'constraint RSHIFT({mzn_input_array_1}, {int(shift_const)})={mzn_output_array_1};\n'] + f"constraint RSHIFT({mzn_input_array_1}, {int(shift_const)})={mzn_output_array_1};\n" + ] return var_names, shift_mzn_constraints @@ -695,7 +755,7 @@ def sat_constraints(self): '-shift_0_0_30', '-shift_0_0_31']) """ - _, input_bit_ids = self._generate_input_ids() + input_bit_ids = self._generate_input_ids() output_bit_len, output_bit_ids = self._generate_output_ids() shift_amount = self.description[1] constraints = [] @@ -704,10 +764,10 @@ def sat_constraints(self): for i in range(output_bit_len - shift_amount): constraints.extend(sat_utils.cnf_equivalent([output_bit_ids[i], input_bit_ids[i + shift_amount]])) for i in range(output_bit_len - shift_amount, output_bit_len): - constraints.append(f'-{output_bit_ids[i]}') + constraints.append(f"-{output_bit_ids[i]}") else: for i in range(shift_amount): - constraints.append(f'-{output_bit_ids[i]}') + constraints.append(f"-{output_bit_ids[i]}") for i in range(shift_amount, output_bit_len): constraints.extend(sat_utils.cnf_equivalent([output_bit_ids[i], input_bit_ids[i - shift_amount]])) @@ -756,12 +816,12 @@ def sat_bitwise_deterministic_truncated_xor_differential_constraints(self): constraints.extend(sat_utils.cnf_equivalent([out_ids_0[i], in_ids_0[i + shift_amount]])) constraints.extend(sat_utils.cnf_equivalent([out_ids_1[i], in_ids_1[i + shift_amount]])) for i in range(out_len - shift_amount, out_len): - constraints.append(f'-{out_ids_0[i]}') - constraints.append(f'-{out_ids_1[i]}') + constraints.append(f"-{out_ids_0[i]}") + constraints.append(f"-{out_ids_1[i]}") else: for i in range(shift_amount): - constraints.append(f'-{out_ids_0[i]}') - constraints.append(f'-{out_ids_1[i]}') + constraints.append(f"-{out_ids_0[i]}") + constraints.append(f"-{out_ids_1[i]}") for i in range(shift_amount, out_len): constraints.extend(sat_utils.cnf_equivalent([out_ids_0[i], in_ids_0[i - shift_amount]])) constraints.extend(sat_utils.cnf_equivalent([out_ids_1[i], in_ids_1[i - shift_amount]])) @@ -840,14 +900,13 @@ def sat_xor_linear_mask_propagation_constraints(self, model=None): constraints = [] if shift_amount < 0: shift_amount = -shift_amount - constraints.extend([f'-{input_bit_ids[i]}' for i in range(shift_amount)]) + constraints.extend([f"-{input_bit_ids[i]}" for i in range(shift_amount)]) for i in range(output_bit_len - shift_amount): constraints.extend(sat_utils.cnf_equivalent([output_bit_ids[i], input_bit_ids[i + shift_amount]])) else: for i in range(shift_amount, output_bit_len): constraints.extend(sat_utils.cnf_equivalent([output_bit_ids[i], input_bit_ids[i - shift_amount]])) - constraints.extend([f'-{input_bit_ids[i]}' - for i in range(output_bit_len - shift_amount, output_bit_len)]) + constraints.extend([f"-{input_bit_ids[i]}" for i in range(output_bit_len - shift_amount, output_bit_len)]) result = input_bit_ids + output_bit_ids, constraints return result @@ -883,7 +942,7 @@ def smt_constraints(self): '(assert (not shift_0_0_30))', '(assert (not shift_0_0_31))']) """ - _, input_bit_ids = self._generate_input_ids() + input_bit_ids = self._generate_input_ids() output_bit_len, output_bit_ids = self._generate_output_ids() shift_amount = self.description[1] constraints = [] @@ -973,8 +1032,7 @@ def smt_xor_linear_mask_propagation_constraints(self, model=None): constraints = [] if shift_amount < 0: shift_amount = -shift_amount - constraints.extend([smt_utils.smt_assert(smt_utils.smt_not(input_bit_ids[i])) - for i in range(shift_amount)]) + constraints.extend([smt_utils.smt_assert(smt_utils.smt_not(input_bit_ids[i])) for i in range(shift_amount)]) for i in range(output_bit_len - shift_amount): equation = smt_utils.smt_equivalent((output_bit_ids[i], input_bit_ids[i + shift_amount])) constraints.append(smt_utils.smt_assert(equation)) @@ -982,8 +1040,12 @@ def smt_xor_linear_mask_propagation_constraints(self, model=None): for i in range(shift_amount, output_bit_len): equation = smt_utils.smt_equivalent((output_bit_ids[i], input_bit_ids[i - shift_amount])) constraints.append(smt_utils.smt_assert(equation)) - constraints.extend([smt_utils.smt_assert(smt_utils.smt_not(input_bit_ids[i])) - for i in range(output_bit_len - shift_amount, output_bit_len)]) + constraints.extend( + [ + smt_utils.smt_assert(smt_utils.smt_not(input_bit_ids[i])) + for i in range(output_bit_len - shift_amount, output_bit_len) + ] + ) result = input_bit_ids + output_bit_ids, constraints return result diff --git a/claasp/components/shift_rows_component.py b/claasp/components/shift_rows_component.py index 8b2043bd5..f4a8df064 100644 --- a/claasp/components/shift_rows_component.py +++ b/claasp/components/shift_rows_component.py @@ -1,29 +1,42 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** from claasp.components.rotate_component import Rotate +from claasp.name_mappings import WORD_OPERATION class ShiftRows(Rotate): - def __init__(self, current_round_number, current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size, parameter): - super().__init__(current_round_number, current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size, parameter) - self._id = f'shift_rows_{current_round_number}_{current_round_number_of_components}' - self._type = 'word_operation' + def __init__( + self, + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + parameter, + ): + super().__init__( + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + parameter, + ) + self._id = f"shift_rows_{current_round_number}_{current_round_number_of_components}" + self._type = WORD_OPERATION diff --git a/claasp/components/sigma_component.py b/claasp/components/sigma_component.py index e69997fe4..16a80448c 100644 --- a/claasp/components/sigma_component.py +++ b/claasp/components/sigma_component.py @@ -1,17 +1,16 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -24,11 +23,26 @@ class Sigma(LinearLayer): - def __init__(self, current_round_number, current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size, rotation_amounts_parameter): - binary_matrix = linear_layer_to_binary_matrix(SIGMA, output_bit_size, output_bit_size, [rotation_amounts_parameter]) + def __init__( + self, + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + rotation_amounts_parameter, + ): + binary_matrix = linear_layer_to_binary_matrix( + SIGMA, output_bit_size, output_bit_size, [rotation_amounts_parameter] + ) description = list(binary_matrix.transpose()) - super().__init__(current_round_number, current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size, description) - self._id = f'sigma_{current_round_number}_{current_round_number_of_components}' + super().__init__( + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + description, + ) + self._id = f"sigma_{current_round_number}_{current_round_number_of_components}" self._input = Input(output_bit_size, input_id_links, input_bit_positions) diff --git a/claasp/components/theta_gaston_component.py b/claasp/components/theta_gaston_component.py index 6c91c5db3..060178eb9 100644 --- a/claasp/components/theta_gaston_component.py +++ b/claasp/components/theta_gaston_component.py @@ -1,24 +1,23 @@ -import os -import pickle -from typing import Any - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** +import os +import pickle +from typing import Any from claasp.input import Input from claasp.cipher_modules.generic_functions import THETA_GASTON @@ -34,13 +33,21 @@ CACHE_DIR = os.path.join(ROOT_DIR, "ciphers", "permutations") os.makedirs(CACHE_DIR, exist_ok=True) + def _matrix_cache_path(cipher_id): return os.path.join(CACHE_DIR, f"gaston_theta_{cipher_id}.pkl") -class ThetaGaston(LinearLayer): - def __init__(self, current_round_number, current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size, rotation_amounts_parameter): +class ThetaGaston(LinearLayer): + def __init__( + self, + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + rotation_amounts_parameter, + ): matrix_id = "_".join(str(p) for p in rotation_amounts_parameter) if matrix_id in _cached_matrices: binary_matrix = _cached_matrices[matrix_id] @@ -51,13 +58,20 @@ def __init__(self, current_round_number, current_round_number_of_components, binary_matrix = pickle.load(f) else: binary_matrix = linear_layer_to_binary_matrix( - THETA_GASTON, output_bit_size, output_bit_size, [rotation_amounts_parameter]) + THETA_GASTON, output_bit_size, output_bit_size, [rotation_amounts_parameter] + ) with open(path, "wb") as f: pickle.dump(binary_matrix, f) _cached_matrices[matrix_id] = binary_matrix description = list(binary_matrix.transpose()) - super().__init__(current_round_number, current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size, description) - self._id = f'theta_gaston_{current_round_number}_{current_round_number_of_components}' + super().__init__( + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + description, + ) + self._id = f"theta_gaston_{current_round_number}_{current_round_number_of_components}" self._input = Input(output_bit_size, input_id_links, input_bit_positions) diff --git a/claasp/components/theta_keccak_component.py b/claasp/components/theta_keccak_component.py index ccb1c81d4..f9790dda0 100644 --- a/claasp/components/theta_keccak_component.py +++ b/claasp/components/theta_keccak_component.py @@ -1,17 +1,16 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -24,11 +23,23 @@ class ThetaKeccak(LinearLayer): - def __init__(self, current_round_number, current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size): + def __init__( + self, + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + ): binary_matrix = linear_layer_to_binary_matrix(THETA_KECCAK, output_bit_size, output_bit_size, []) description = list(binary_matrix.transpose()) - super().__init__(current_round_number, current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size, description) - self._id = f'theta_keccak_{current_round_number}_{current_round_number_of_components}' + super().__init__( + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + description, + ) + self._id = f"theta_keccak_{current_round_number}_{current_round_number_of_components}" self._input = Input(output_bit_size, input_id_links, input_bit_positions) diff --git a/claasp/components/theta_xoodoo_component.py b/claasp/components/theta_xoodoo_component.py index dccd72003..2fb34b4a1 100644 --- a/claasp/components/theta_xoodoo_component.py +++ b/claasp/components/theta_xoodoo_component.py @@ -1,17 +1,16 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -24,11 +23,23 @@ class ThetaXoodoo(LinearLayer): - def __init__(self, current_round_number, current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size): + def __init__( + self, + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + ): binary_matrix = linear_layer_to_binary_matrix(THETA_XOODOO, output_bit_size, output_bit_size, []) description = list(binary_matrix.transpose()) - super().__init__(current_round_number, current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size, description) - self._id = f'theta_xoodoo_{current_round_number}_{current_round_number_of_components}' + super().__init__( + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + description, + ) + self._id = f"theta_xoodoo_{current_round_number}_{current_round_number_of_components}" self._input = Input(output_bit_size, input_id_links, input_bit_positions) diff --git a/claasp/components/variable_rotate_component.py b/claasp/components/variable_rotate_component.py index 58244261a..3f2feae6b 100644 --- a/claasp/components/variable_rotate_component.py +++ b/claasp/components/variable_rotate_component.py @@ -1,17 +1,16 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -19,17 +18,23 @@ from claasp.input import Input from claasp.component import Component +from claasp.name_mappings import WORD_OPERATION class VariableRotate(Component): - def __init__(self, current_round_number, current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size, parameter): - component_id = f'var_rot_{current_round_number}_{current_round_number_of_components}' - component_type = 'word_operation' - input_len = 0 - for bits in input_bit_positions: - input_len = input_len + len(bits) - description = ['ROTATE_BY_VARIABLE_AMOUNT', parameter] + def __init__( + self, + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + parameter, + ): + component_id = f"var_rot_{current_round_number}_{current_round_number_of_components}" + component_type = WORD_OPERATION + input_len = sum(map(len, input_bit_positions)) + description = ["ROTATE_BY_VARIABLE_AMOUNT", parameter] component_input = Input(input_len, input_id_links, input_bit_positions) super().__init__(component_id, component_type, component_input, output_bit_size, description) @@ -39,7 +44,7 @@ def get_word_based_c_code(self, verbosity, word_size, wordstring_variables): self.select_words(variable_rotate_code, word_size) wordstring_variables.append(self.id) direction = "RIGHT" if self.description[1] >= 0 else "LEFT" - variable_rotate_code.append(f'\tWordString *{self.id} = {direction}_{self.description[0]}(input);') + variable_rotate_code.append(f"\tWordString *{self.id} = {direction}_{self.description[0]}(input);") if verbosity: self.print_word_values(variable_rotate_code) @@ -50,9 +55,9 @@ def get_word_operation_sign(self, sign, solution): output_id_link = self.id component_sign = 1 sign = sign * component_sign - solution['components_values'][f'{output_id_link}_o']['sign'] = component_sign - solution['components_values'][output_id_link] = solution['components_values'][f'{output_id_link}_o'] - del solution['components_values'][f'{output_id_link}_o'] - del solution['components_values'][f'{output_id_link}_i'] + solution["components_values"][f"{output_id_link}_o"]["sign"] = component_sign + solution["components_values"][output_id_link] = solution["components_values"][f"{output_id_link}_o"] + del solution["components_values"][f"{output_id_link}_o"] + del solution["components_values"][f"{output_id_link}_i"] return sign diff --git a/claasp/components/variable_shift_component.py b/claasp/components/variable_shift_component.py index 46646b3f2..60496e55f 100644 --- a/claasp/components/variable_shift_component.py +++ b/claasp/components/variable_shift_component.py @@ -1,17 +1,16 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -23,17 +22,23 @@ from claasp.cipher_modules.models.smt.utils import utils as smt_utils from claasp.component import Component from claasp.input import Input +from claasp.name_mappings import WORD_OPERATION class VariableShift(Component): - def __init__(self, current_round_number, current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size, parameter): - component_id = f'var_shift_{current_round_number}_{current_round_number_of_components}' - component_type = 'word_operation' - input_len = 0 - for bits in input_bit_positions: - input_len = input_len + len(bits) - description = ['SHIFT_BY_VARIABLE_AMOUNT', parameter] + def __init__( + self, + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + parameter, + ): + component_id = f"var_shift_{current_round_number}_{current_round_number_of_components}" + component_type = WORD_OPERATION + input_len = sum(map(len, input_bit_positions)) + description = ["SHIFT_BY_VARIABLE_AMOUNT", parameter] component_input = Input(input_len, input_id_links, input_bit_positions) super().__init__(component_id, component_type, component_input, output_bit_size, description) @@ -99,33 +104,40 @@ def cp_constraints(self): all_inputs = [] for i in range(numb_of_inp - 1): for j in range(len(input_bit_positions[i])): - all_inputs.append(f'{input_id_link[i]}[{input_bit_positions[i][j]}]') - cp_declarations.append(f'array[0..{output_size - 1}] of var 0..1: pre_{output_id_link};') + all_inputs.append(f"{input_id_link[i]}[{input_bit_positions[i][j]}]") + cp_declarations.append(f"array[0..{output_size - 1}] of var 0..1: pre_{output_id_link};") for i in range(output_size): - cp_constraints.append(f'constraint pre_{output_id_link}[{i}]={all_inputs[i]};') - cp_declarations.append(f'var int: shift_amount_{output_id_link};') + cp_constraints.append(f"constraint pre_{output_id_link}[{i}]={all_inputs[i]};") + cp_declarations.append(f"var int: shift_amount_{output_id_link};") cp_constraints.append( - f'constraint bitArrayToInt([{input_id_link[numb_of_inp - 1]}[i]|i in ' - f'{input_bit_positions[numb_of_inp - 1][len(input_bit_positions[numb_of_inp - 1]) - bit_for_shift_amount]}' - f'..{input_bit_positions[numb_of_inp - 1][len(input_bit_positions[numb_of_inp - 1]) - 1]}],' - f'shift_amount_{output_id_link});') + f"constraint bitArrayToInt([{input_id_link[numb_of_inp - 1]}[i]|i in " + f"{input_bit_positions[numb_of_inp - 1][len(input_bit_positions[numb_of_inp - 1]) - bit_for_shift_amount]}" + f"..{input_bit_positions[numb_of_inp - 1][len(input_bit_positions[numb_of_inp - 1]) - 1]}]," + f"shift_amount_{output_id_link});" + ) if shift_direction == 1: cp_constraints.append( - f'constraint {output_id_link}=RShift(pre_{output_id_link},shift_amount_{output_id_link});') + f"constraint {output_id_link}=RShift(pre_{output_id_link},shift_amount_{output_id_link});" + ) else: cp_constraints.append( - f'constraint {output_id_link}=LShift(pre_{output_id_link},shift_amount_{output_id_link});') + f"constraint {output_id_link}=LShift(pre_{output_id_link},shift_amount_{output_id_link});" + ) return cp_declarations, cp_constraints def get_bit_based_vectorized_python_code(self, params, convert_output_to_bytes): - return [f' {self.id} = bit_vector_SHIFT_BY_VARIABLE_AMOUNT([{",".join(params)} ], ' - f'{self.output_bit_size}, {self.description[1]})'] + return [ + f" {self.id} = bit_vector_SHIFT_BY_VARIABLE_AMOUNT([{','.join(params)} ], " + f"{self.output_bit_size}, {self.description[1]})" + ] def get_byte_based_vectorized_python_code(self, params): - return [f' {self.id} = byte_vector_SHIFT_BY_VARIABLE_AMOUNT({params}, ' - f'{self.output_bit_size}, {self.description[1]})'] + return [ + f" {self.id} = byte_vector_SHIFT_BY_VARIABLE_AMOUNT({params}, " + f"{self.output_bit_size}, {self.description[1]})" + ] def get_word_based_c_code(self, verbosity, word_size, wordstring_variables): variable_shift_code = [] @@ -133,7 +145,7 @@ def get_word_based_c_code(self, verbosity, word_size, wordstring_variables): self.select_words(variable_shift_code, word_size) wordstring_variables.append(self.id) direction = "RIGHT" if self.description[1] >= 0 else "LEFT" - variable_shift_code.append(f'\tWordString *{self.id} = {direction}_{self.description[0]}(input);') + variable_shift_code.append(f"\tWordString *{self.id} = {direction}_{self.description[0]}(input);") if verbosity: self.print_word_values(variable_shift_code) @@ -144,10 +156,10 @@ def get_word_operation_sign(self, sign, solution): output_id_link = self.id component_sign = 1 sign = sign * component_sign - solution['components_values'][f'{output_id_link}_o']['sign'] = component_sign - solution['components_values'][output_id_link] = solution['components_values'][f'{output_id_link}_o'] - del solution['components_values'][f'{output_id_link}_o'] - del solution['components_values'][f'{output_id_link}_i'] + solution["components_values"][f"{output_id_link}_o"]["sign"] = component_sign + solution["components_values"][output_id_link] = solution["components_values"][f"{output_id_link}_o"] + del solution["components_values"][f"{output_id_link}_o"] + del solution["components_values"][f"{output_id_link}_i"] return sign @@ -184,7 +196,7 @@ def minizinc_xor_differential_propagation_constraints(self, model): for i in range(len(second_subvector_input_vars)): index_subvector = len(second_subvector_input_vars) - i - 1 - bin_terms.append(f'{2**index_subvector}*{second_subvector_input_vars[index_subvector]}') + bin_terms.append(f"{2**index_subvector}*{second_subvector_input_vars[index_subvector]}") str_shift_amount = "+".join(bin_terms) shift_direction = self.description[1] @@ -192,11 +204,15 @@ def minizinc_xor_differential_propagation_constraints(self, model): mzn_input_array_output = self._create_minizinc_1d_array_from_list(output_vars) if shift_direction < 0: - mzn_shift_by_variable_amount_constraints = [f'constraint LSHIFT_BY_VARIABLE_AMOUNT({mzn_input_array_input},' - f' {str_shift_amount})={mzn_input_array_output};\n'] + mzn_shift_by_variable_amount_constraints = [ + f"constraint LSHIFT_BY_VARIABLE_AMOUNT({mzn_input_array_input}," + f" {str_shift_amount})={mzn_input_array_output};\n" + ] else: - mzn_shift_by_variable_amount_constraints = [f'constraint RSHIFT_BY_VARIABLE_AMOUNT({mzn_input_array_input},' - f' {str_shift_amount})={mzn_input_array_output};\n'] + mzn_shift_by_variable_amount_constraints = [ + f"constraint RSHIFT_BY_VARIABLE_AMOUNT({mzn_input_array_input}," + f" {str_shift_amount})={mzn_input_array_output};\n" + ] return var_names, mzn_shift_by_variable_amount_constraints @@ -229,34 +245,48 @@ def sat_constraints(self): '-var_shift_0_2_31 -key_91', 'var_shift_0_2_31 -state_3_var_shift_0_2_31 key_91']) """ - _, input_bit_ids = self._generate_input_ids() + input_bit_ids = self._generate_input_ids() output_bit_len, output_bit_ids = self._generate_output_ids() input_ids = input_bit_ids[:output_bit_len] shift_ids = input_bit_ids[output_bit_len:] number_of_states = int(math.log2(output_bit_len)) - 1 - states = [[f'state_{i}_{output_bit_ids[j]}' for j in range(output_bit_len)] - for i in range(number_of_states)] + states = [[f"state_{i}_{output_bit_ids[j]}" for j in range(output_bit_len)] for i in range(number_of_states)] constraints = [] for j in range(output_bit_len - 1): - constraints.extend(sat_utils.cnf_vshift_id(states[0][j], input_ids[j], - input_ids[j + 1], shift_ids[output_bit_len - 1])) - constraints.extend(sat_utils.cnf_vshift_false(states[0][output_bit_len - 1], input_ids[output_bit_len - 1], - shift_ids[output_bit_len - 1])) + constraints.extend( + sat_utils.cnf_vshift_id(states[0][j], input_ids[j], input_ids[j + 1], shift_ids[output_bit_len - 1]) + ) + constraints.extend( + sat_utils.cnf_vshift_false( + states[0][output_bit_len - 1], input_ids[output_bit_len - 1], shift_ids[output_bit_len - 1] + ) + ) for i in range(1, number_of_states): - for j in range(output_bit_len - 2 ** i): - constraints.extend(sat_utils.cnf_vshift_id(states[i][j], states[i - 1][j], - states[i - 1][j + 2 ** i], - shift_ids[output_bit_len - 1 - i])) - for j in range(output_bit_len - 2 ** i, output_bit_len): - constraints.extend(sat_utils.cnf_vshift_false(states[i][j], states[i - 1][j], - shift_ids[output_bit_len - 1 - i])) - for j in range(output_bit_len - 2 ** number_of_states): - constraints.extend(sat_utils.cnf_vshift_id(output_bit_ids[j], states[number_of_states - 1][j], - states[number_of_states - 1][j + 2 ** number_of_states], - shift_ids[output_bit_len - 1 - number_of_states])) - for j in range(output_bit_len - 2 ** number_of_states, output_bit_len): - constraints.extend(sat_utils.cnf_vshift_false(output_bit_ids[j], states[number_of_states - 1][j], - shift_ids[output_bit_len - 1 - number_of_states])) + for j in range(output_bit_len - 2**i): + constraints.extend( + sat_utils.cnf_vshift_id( + states[i][j], states[i - 1][j], states[i - 1][j + 2**i], shift_ids[output_bit_len - 1 - i] + ) + ) + for j in range(output_bit_len - 2**i, output_bit_len): + constraints.extend( + sat_utils.cnf_vshift_false(states[i][j], states[i - 1][j], shift_ids[output_bit_len - 1 - i]) + ) + for j in range(output_bit_len - 2**number_of_states): + constraints.extend( + sat_utils.cnf_vshift_id( + output_bit_ids[j], + states[number_of_states - 1][j], + states[number_of_states - 1][j + 2**number_of_states], + shift_ids[output_bit_len - 1 - number_of_states], + ) + ) + for j in range(output_bit_len - 2**number_of_states, output_bit_len): + constraints.extend( + sat_utils.cnf_vshift_false( + output_bit_ids[j], states[number_of_states - 1][j], shift_ids[output_bit_len - 1 - number_of_states] + ) + ) return output_bit_ids, constraints @@ -285,17 +315,17 @@ def smt_constraints(self): '(assert (ite key_91 (not var_shift_0_2_30) (= var_shift_0_2_30 state_3_var_shift_0_2_30)))', '(assert (ite key_91 (not var_shift_0_2_31) (= var_shift_0_2_31 state_3_var_shift_0_2_31)))']) """ - _, input_bit_ids = self._generate_input_ids() + input_bit_ids = self._generate_input_ids() output_bit_len, output_bit_ids = self._generate_output_ids() input_ids = input_bit_ids[:output_bit_len] shift_ids = input_bit_ids[output_bit_len:] states = [] number_of_states = int(math.log2(output_bit_len)) - 1 for i in range(number_of_states): - states.append([f'state_{i}_{output_bit_ids[j]}' for j in range(output_bit_len)]) + states.append([f"state_{i}_{output_bit_ids[j]}" for j in range(output_bit_len)]) constraints = [] if len(states) <= 0: - raise ValueError('states must not be empty') + raise ValueError("states must not be empty") # first shift for j in range(output_bit_len - 1): @@ -310,25 +340,26 @@ def smt_constraints(self): # intermediate shifts for i in range(1, number_of_states): - for j in range(output_bit_len - 2 ** i): - consequent = smt_utils.smt_equivalent((states[i][j], states[i - 1][j + 2 ** i])) + for j in range(output_bit_len - 2**i): + consequent = smt_utils.smt_equivalent((states[i][j], states[i - 1][j + 2**i])) alternative = smt_utils.smt_equivalent((states[i][j], states[i - 1][j])) shift = smt_utils.smt_ite(shift_ids[output_bit_len - 1 - i], consequent, alternative) constraints.append(smt_utils.smt_assert(shift)) - for j in range(output_bit_len - 2 ** i, output_bit_len): + for j in range(output_bit_len - 2**i, output_bit_len): consequent = smt_utils.smt_not(states[i][j]) alternative = smt_utils.smt_equivalent((states[i][j], states[i - 1][j])) shift = smt_utils.smt_ite(shift_ids[output_bit_len - 1 - i], consequent, alternative) constraints.append(smt_utils.smt_assert(shift)) # last shift - for j in range(output_bit_len - 2 ** number_of_states): + for j in range(output_bit_len - 2**number_of_states): consequent = smt_utils.smt_equivalent( - (output_bit_ids[j], states[number_of_states - 1][j + 2 ** number_of_states])) + (output_bit_ids[j], states[number_of_states - 1][j + 2**number_of_states]) + ) alternative = smt_utils.smt_equivalent((output_bit_ids[j], states[number_of_states - 1][j])) shift = smt_utils.smt_ite(shift_ids[output_bit_len - 1 - number_of_states], consequent, alternative) constraints.append(smt_utils.smt_assert(shift)) - for j in range(output_bit_len - 2 ** number_of_states, output_bit_len): + for j in range(output_bit_len - 2**number_of_states, output_bit_len): consequent = smt_utils.smt_not(output_bit_ids[j]) alternative = smt_utils.smt_equivalent((output_bit_ids[j], states[number_of_states - 1][j])) shift = smt_utils.smt_ite(shift_ids[output_bit_len - 1 - number_of_states], consequent, alternative) diff --git a/claasp/components/word_permutation_component.py b/claasp/components/word_permutation_component.py index 85f72fd42..631325d46 100644 --- a/claasp/components/word_permutation_component.py +++ b/claasp/components/word_permutation_component.py @@ -1,17 +1,16 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -21,14 +20,27 @@ class WordPermutation(MixColumn): - def __init__(self, current_round_number, current_round_number_of_components, - input_id_links, input_bit_positions, - output_bit_size, permutation_description, word_size): + def __init__( + self, + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + permutation_description, + word_size, + ): matrix = [] for i in range(len(permutation_description)): row = [0] * len(permutation_description) row[permutation_description[i]] = 1 matrix.append(row) description = [matrix, 0, word_size] - super().__init__(current_round_number, current_round_number_of_components, input_id_links, - input_bit_positions, output_bit_size, description) + super().__init__( + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + description, + ) diff --git a/claasp/components/xor_component.py b/claasp/components/xor_component.py index 9e4a207c2..1a77dc451 100644 --- a/claasp/components/xor_component.py +++ b/claasp/components/xor_component.py @@ -1,17 +1,16 @@ -from claasp.cipher_modules.models.milp.utils.utils import espresso_pos_to_constraints # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -19,17 +18,19 @@ from claasp.input import Input from claasp.component import Component -from claasp.name_mappings import CONSTANT from claasp.cipher_modules.models.smt.utils import utils as smt_utils from claasp.cipher_modules.models.sat.utils import constants, utils as sat_utils from claasp.cipher_modules.models.milp.utils import utils as milp_utils from claasp.cipher_modules.models.milp.utils.generate_inequalities_for_xor_with_n_input_bits import ( output_dictionary_that_contains_xor_inequalities, - update_dictionary_that_contains_xor_inequalities_between_n_input_bits) + update_dictionary_that_contains_xor_inequalities_between_n_input_bits, +) from claasp.cipher_modules.models.milp.utils.generate_inequalities_for_wordwise_truncated_xor_with_n_input_bits import ( update_dictionary_that_contains_wordwise_truncated_xor_inequalities_between_n_inputs, - output_dictionary_that_contains_wordwise_truncated_xor_inequalities + output_dictionary_that_contains_wordwise_truncated_xor_inequalities, ) +from claasp.cipher_modules.models.milp.utils.utils import espresso_pos_to_constraints +from claasp.name_mappings import WORD_OPERATION def cp_build_truncated_table(numadd): @@ -46,13 +47,15 @@ def cp_build_truncated_table(numadd): sage: cp_build_truncated_table(3) 'array[0..4, 1..3] of int: xor_truncated_table_3 = array2d(0..4, 1..3, [0,0,0,0,1,1,1,0,1,1,1,0,1,1,1]);' """ - size = 2 ** numadd - binary_list = (f'{i:0{numadd}b}' for i in range(size)) - table_items = [','.join(i) for i in binary_list if i.count('1') != 1] - table = ','.join(table_items) - xor_table = f'array[0..{size - numadd - 1}, 1..{numadd}] of int: ' \ - f'xor_truncated_table_{numadd} = array2d(0..{size - numadd - 1}, 1..{numadd}, ' \ - f'[{table}]);' + size = 2**numadd + binary_list = (f"{i:0{numadd}b}" for i in range(size)) + table_items = [",".join(i) for i in binary_list if i.count("1") != 1] + table = ",".join(table_items) + xor_table = ( + f"array[0..{size - numadd - 1}, 1..{numadd}] of int: " + f"xor_truncated_table_{numadd} = array2d(0..{size - numadd - 1}, 1..{numadd}, " + f"[{table}]);" + ) return xor_table @@ -94,16 +97,16 @@ def get_transformed_xor_input_links_and_positions(word_size, all_inputs, i, inpu for j in range(numadd + 1): if all_inputs[i + input_len * j][0] not in input_id_link: input_id_link.append(all_inputs[i + input_len * j][0]) - input_bit_positions[new_numb_of_inp] += [all_inputs[i + input_len * j][1] * word_size + k - for k in range(word_size)] + input_bit_positions[new_numb_of_inp] += [ + all_inputs[i + input_len * j][1] * word_size + k for k in range(word_size) + ] new_numb_of_inp += 1 else: index = 0 for c in range(len(input_id_link)): if input_id_link[c] == all_inputs[i + input_len * j][0]: index += c - input_bit_positions[index] += [all_inputs[i + input_len * j][1] * word_size + k - for k in range(word_size)] + input_bit_positions[index] += [all_inputs[i + input_len * j][1] * word_size + k for k in range(word_size)] input_bit_positions = [x for x in input_bit_positions if x != []] return input_bit_positions, input_id_link @@ -134,12 +137,18 @@ def get_milp_constraints_from_inequalities(inequalities, input_vars, number_of_i class XOR(Component): - def __init__(self, current_round_number, current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size): - component_id = f'xor_{current_round_number}_{current_round_number_of_components}' - component_type = 'word_operation' + def __init__( + self, + current_round_number, + current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + ): + component_id = f"xor_{current_round_number}_{current_round_number_of_components}" + component_type = WORD_OPERATION input_len = sum(map(len, input_bit_positions)) - description = ['XOR', int(input_len / output_bit_size)] + description = ["XOR", int(input_len / output_bit_size)] component_input = Input(input_len, input_id_links, input_bit_positions) super().__init__(component_id, component_type, component_input, output_bit_size, description) @@ -175,10 +184,10 @@ def algebraic_polynomials(self, model): ninputs = self.input_bit_size noutputs = self.output_bit_size word_size = noutputs - input_vars = [self.id + "_" + model.input_postfix + str(i) for i in range(ninputs)] - output_vars = [self.id + "_" + model.output_postfix + str(i) for i in range(noutputs)] + input_vars = [f"{self.id}_{model.input_postfix}{i}" for i in range(ninputs)] + output_vars = [f"{self.id}_{model.output_postfix}{i}" for i in range(noutputs)] ring_R = model.ring() - words_vars = [list(map(ring_R, input_vars))[i:i + word_size] for i in range(0, ninputs, word_size)] + words_vars = [list(map(ring_R, input_vars))[i : i + word_size] for i in range(0, ninputs, word_size)] x = [ring_R.zero() for _ in range(noutputs)] for word_vars in words_vars: @@ -216,13 +225,13 @@ def cms_constraints(self): 'x -xor_0_2_14 modadd_0_1_14 key_62', 'x -xor_0_2_15 modadd_0_1_15 key_63']) """ - _, input_bit_ids = self._generate_input_ids() + input_bit_ids = self._generate_input_ids() output_bit_len, output_bit_ids = self._generate_output_ids() constraints = [] for i in range(output_bit_len): - operands = [f'x -{output_bit_ids[i]}'] + operands = [f"x -{output_bit_ids[i]}"] operands.extend(input_bit_ids[i::output_bit_len]) - constraints.append(' '.join(operands)) + constraints.append(" ".join(operands)) return output_bit_ids, constraints @@ -251,19 +260,14 @@ def cp_constraints(self): ... 'constraint xor_0_2[15] = (modadd_0_1[15] + key[63]) mod 2;']) """ - output_size = int(self.output_bit_size) - input_id_links = self.input_id_links - output_id_link = self.id - input_bit_positions = self.input_bit_positions cp_declarations = [] all_inputs = [] - for id_link, bit_positions in zip(input_id_links, input_bit_positions): - all_inputs.extend([f'{id_link}[{position}]' for position in bit_positions]) + for id_link, bit_positions in zip(self.input_id_links, self.input_bit_positions): + all_inputs.extend([f"{id_link}[{position}]" for position in bit_positions]) cp_constraints = [] - for i in range(output_size): - operation = ' + '.join(all_inputs[i::output_size]) - new_constraint = f'constraint {output_id_link}[{i}] = ({operation}) mod 2;' - cp_constraints.append(new_constraint) + for i in range(self.output_bit_size): + operation = " + ".join(all_inputs[i::self.output_bit_size]) + cp_constraints.append(f"constraint {self.id}[{i}] = ({operation}) mod 2;") return cp_declarations, cp_constraints @@ -286,21 +290,17 @@ def cp_deterministic_truncated_xor_differential_constraints(self): ... 'constraint if ((modadd_0_1[15] < 2) /\\ (key[63]< 2)) then xor_0_2[15] = (modadd_0_1[15] + key[63]) mod 2 else xor_0_2[15] = 2 endif;']) """ - output_size = int(self.output_bit_size) - input_id_links = self.input_id_links - output_id_link = self.id - input_bit_positions = self.input_bit_positions cp_declarations = [] all_inputs = [] - for id_link, bit_positions in zip(input_id_links, input_bit_positions): - all_inputs.extend([f'{id_link}[{position}]' for position in bit_positions]) + for id_link, bit_positions in zip(self.input_id_links, self.input_bit_positions): + all_inputs.extend([f"{id_link}[{position}]" for position in bit_positions]) cp_constraints = [] - for i in range(output_size): - operation = ' < 2) /\\ ('.join(all_inputs[i::output_size]) - new_constraint = 'constraint if ((' - new_constraint += operation + '< 2)) then ' - operation2 = ' + '.join(all_inputs[i::output_size]) - new_constraint += f'{output_id_link}[{i}] = ({operation2}) mod 2 else {output_id_link}[{i}] = 2 endif;' + for i in range(self.output_bit_size): + operation = " < 2) /\\ (".join(all_inputs[i::self.output_bit_size]) + new_constraint = "constraint if ((" + new_constraint += operation + "< 2)) then " + operation2 = " + ".join(all_inputs[i::self.output_bit_size]) + new_constraint += f"{self.id}[{i}] = ({operation2}) mod 2 else {self.id}[{i}] = 2 endif;" cp_constraints.append(new_constraint) return cp_declarations, cp_constraints @@ -325,31 +325,22 @@ def cp_hybrid_deterministic_truncated_xor_differential_constraints(self): ... 'constraint if (modadd_0_1[15] < 2) /\\ (key[63] < 2) then xor_0_2[15] = (modadd_0_1[15] + key[63]) mod 2 elseif (modadd_0_1[15] + key[63] = modadd_0_1[15]) then xor_0_2[15] = modadd_0_1[15] elseif (modadd_0_1[15] + key[63] = key[63]) then xor_0_2[15] = key[63] else xor_0_2[15] = 2 endif;']) """ - - output_size = int(self.output_bit_size) - input_id_links = self.input_id_links - output_id_link = self.id - input_bit_positions = self.input_bit_positions cp_declarations = [] all_inputs = [ - f'{id_link}[{position}]' - for id_link, bit_positions in zip(input_id_links, input_bit_positions) + f"{id_link}[{position}]" + for id_link, bit_positions in zip(self.input_id_links, self.input_bit_positions) for position in bit_positions ] cp_constraints = [] - for i in range(output_size): - inputs = all_inputs[i::output_size] - condition = ' < 2) /\\ ('.join(inputs) + ' < 2' - operation_sum = ' + '.join(inputs) - + for i in range(self.output_bit_size): + inputs = all_inputs[i::self.output_bit_size] + condition = " < 2) /\\ (".join(inputs) + " < 2" + operation_sum = " + ".join(inputs) new_constraint = ( - f'constraint if ({condition}) then ' - f'{output_id_link}[{i}] = ({operation_sum}) mod 2 ' - f'elseif ({operation_sum} = {inputs[0]}) then ' - f'{output_id_link}[{i}] = {inputs[0]} ' - f'elseif ({operation_sum} = {inputs[1]}) then ' - f'{output_id_link}[{i}] = {inputs[1]} ' - f'else {output_id_link}[{i}] = 2 endif;' + f"constraint if ({condition}) then {self.id}[{i}] = ({operation_sum}) mod 2 " + f"elseif ({operation_sum} = {inputs[0]}) then {self.id}[{i}] = {inputs[0]} " + f"elseif ({operation_sum} = {inputs[1]}) then {self.id}[{i}] = {inputs[1]} " + f"else {self.id}[{i}] = 2 endif;" ) cp_constraints.append(new_constraint) @@ -381,75 +372,111 @@ def cp_wordwise_deterministic_truncated_xor_differential_constraints(self, model ... 'constraint if xor_0_0_temp_0_15_active + xor_0_0_temp_1_15_active > 2 then xor_0_0_active[15] == 3 /\\ xor_0_0_value[15] = -2 elseif xor_0_0_temp_0_15_active + xor_0_0_temp_1_15_active == 1 then xor_0_0_active[15] = 1 /\\ xor_0_0_value[15] = xor_0_0_temp_0_15_value + xor_0_0_temp_1_15_value elseif xor_0_0_temp_0_15_active + xor_0_0_temp_1_15_active == 0 then xor_0_0_active[15] = 0 /\\ xor_0_0_value[15] = 0 elseif xor_0_0_temp_0_15_value + xor_0_0_temp_1_15_value < 0 then xor_0_0_active[15] = 2 /\\ xor_0_0_value[15] = -1 elseif xor_0_0_temp_0_15_value == xor_0_0_temp_1_15_value then xor_0_0_active[15] = 0 /\\ xor_0_0_value[15] = 0 else xor_0_0_active[15] = 1 /\\ xor_0_0_value[15] = sum([(((floor(xor_0_0_temp_0_15_value/pow(2,j)) + floor(xor_0_0_temp_1_15_value/pow(2,j))) mod 2) * pow(2,j)) | j in 0..xor_0_0_bound_value_0_15]) endif;']) """ - input_id_links = self.input_id_links output_id_link = self.id - input_bit_positions = self.input_bit_positions cp_declarations = [] all_inputs_value = [] all_inputs_active = [] numadd = self.description[1] word_size = model.word_size - for id_link, bit_positions in zip(input_id_links, input_bit_positions): - all_inputs_value.extend([f'{id_link}_value[{bit_positions[j * word_size] // word_size}]' - for j in range(len(bit_positions) // word_size)]) - all_inputs_active.extend([f'{id_link}_active[{bit_positions[j * word_size] // word_size}]' - for j in range(len(bit_positions) // word_size)]) + for id_link, bit_positions in zip(self.input_id_links, self.input_bit_positions): + all_inputs_value.extend( + [ + f"{id_link}_value[{bit_positions[j * word_size] // word_size}]" + for j in range(len(bit_positions) // word_size) + ] + ) + all_inputs_active.extend( + [ + f"{id_link}_active[{bit_positions[j * word_size] // word_size}]" + for j in range(len(bit_positions) // word_size) + ] + ) input_len = len(all_inputs_value) // numadd cp_constraints = [] initial_constraints = [] for i in range(2 * numadd - 2): for j in range(input_len): - cp_declarations.append( - f'var -2..{2**word_size - 1}: {output_id_link}_temp_{i}_{j}_value;') - cp_declarations.append( - f'var 0..3: {output_id_link}_temp_{i}_{j}_active;') + cp_declarations.append(f"var -2..{2**word_size - 1}: {output_id_link}_temp_{i}_{j}_value;") + cp_declarations.append(f"var 0..3: {output_id_link}_temp_{i}_{j}_active;") for i in range(input_len): for summand in range(numadd - 2): - cp_declarations.append(f'var 0..{word_size + 1}: {output_id_link}_bound_value_{numadd + summand - 1}_{i} = if {output_id_link}_temp_{numadd + summand - 1}_{i}_value ' \ - f'+ {output_id_link}_temp_{summand}_{i}_value > 0 then ceil(log2({output_id_link}_temp_{numadd + summand - 1}_{i}_value ' \ - f'+ {output_id_link}_temp_{summand}_{i}_value)) else 0 endif;') - cp_declarations.append(f'var 0..{word_size + 1}: {output_id_link}_bound_value_{numadd - 2}_{i} = if {output_id_link}_temp_{numadd - 2}_{i}_value ' \ - f'+ {output_id_link}_temp_{2 * numadd - 3}_{i}_value > 0 then ceil(log2({output_id_link}_temp_{numadd - 2}_{i}_value ' \ - f'+ {output_id_link}_temp_{2 * numadd - 3}_{i}_value)) else 0 endif;') + cp_declarations.append( + f"var 0..{word_size + 1}: {output_id_link}_bound_value_{numadd + summand - 1}_{i} = if {output_id_link}_temp_{numadd + summand - 1}_{i}_value " + f"+ {output_id_link}_temp_{summand}_{i}_value > 0 then ceil(log2({output_id_link}_temp_{numadd + summand - 1}_{i}_value " + f"+ {output_id_link}_temp_{summand}_{i}_value)) else 0 endif;" + ) + cp_declarations.append( + f"var 0..{word_size + 1}: {output_id_link}_bound_value_{numadd - 2}_{i} = if {output_id_link}_temp_{numadd - 2}_{i}_value " + f"+ {output_id_link}_temp_{2 * numadd - 3}_{i}_value > 0 then ceil(log2({output_id_link}_temp_{numadd - 2}_{i}_value " + f"+ {output_id_link}_temp_{2 * numadd - 3}_{i}_value)) else 0 endif;" + ) for i in range(numadd): for j in range(input_len): initial_constraints.append( - f'constraint {output_id_link}_temp_{i}_{j}_value = {all_inputs_value[i * input_len + j]} /\\ ' - f'{output_id_link}_temp_{i}_{j}_active = {all_inputs_active[i * input_len + j]};') + f"constraint {output_id_link}_temp_{i}_{j}_value = {all_inputs_value[i * input_len + j]} /\\ " + f"{output_id_link}_temp_{i}_{j}_active = {all_inputs_active[i * input_len + j]};" + ) cp_constraints += initial_constraints for i in range(input_len): - new_constraint = '' + new_constraint = "" for summand in range(numadd - 2): - new_constraint += f'constraint if {output_id_link}_temp_{numadd + summand - 1}_{i}_active + {output_id_link}_temp_{summand}_{i}_active > 2 then ' \ - f'{output_id_link}_temp_{numadd + summand}_{i}_active = 3 /\\ {output_id_link}_temp_{numadd + summand}_{i}_value = -2 ' - new_constraint += f'elseif {output_id_link}_temp_{numadd + summand - 1}_{i}_active + {output_id_link}_temp_{summand}_{i}_active == 1 then' \ - f' {output_id_link}_temp_{numadd + summand}_{i}_active = 1 /\\ {output_id_link}_temp_{numadd + summand}_{i}_value =' \ - f' {output_id_link}_temp_{numadd + summand - 1}_{i}_value + {output_id_link}_temp_{summand}_{i}_value ' - new_constraint += f'elseif {output_id_link}_temp_{numadd + summand - 1}_{i}_active + {output_id_link}_temp_{summand}_{i}_active == 0 then' \ - f' {output_id_link}_temp_{numadd + summand}_{i}_active = 0 /\\ {output_id_link}_temp_{numadd + summand}_{i}_value = 0 ' - new_constraint += f'elseif {output_id_link}_temp_{numadd + summand - 1}_{i}_value + {output_id_link}_temp_{summand}_{i}_value < 0 then ' \ - f'{output_id_link}_temp_{numadd + summand}_{i}_active = 2 /\\ {output_id_link}_temp_{numadd + summand}_{i}_value = -1 ' - new_constraint += f'elseif {output_id_link}_temp_{numadd + summand - 1}_{i}_value == {output_id_link}_temp_{summand}_{i}_value then ' \ - f'{output_id_link}_temp_{numadd + summand}_{i}_active = 0 /\\ {output_id_link}_temp_{numadd + summand}_{i}_value = 0 ' - xor_to_int = f'sum([(((floor({output_id_link}_temp_{numadd + summand - 1}_{i}_value/pow(2,j)) + floor({output_id_link}_temp_{summand}_{i}' \ - f'_value/pow(2,j))) mod 2) * pow(2,j)) | j in 0..{output_id_link}_bound_value_{numadd + summand - 1}_{i}])' - new_constraint += f'else {output_id_link}_temp_{numadd + summand}_{i}_active = 1 /\\ {output_id_link}_temp_{numadd + summand}_{i}_value' \ - f' = {xor_to_int} endif;\n' - new_constraint += f'constraint if {output_id_link}_temp_{numadd - 2}_{i}_active + {output_id_link}_temp_{2 * numadd - 3}_{i}_active > 2 then ' \ - f'{output_id_link}_active[{i}] == 3 /\\ {output_id_link}_value[{i}] = -2 ' - new_constraint += f'elseif {output_id_link}_temp_{numadd - 2}_{i}_active + {output_id_link}_temp_{2 * numadd - 3}_{i}_active == 1 then ' \ - f'{output_id_link}_active[{i}] = 1 /\\ {output_id_link}_value[{i}] = {output_id_link}_temp_{numadd - 2}_' \ - f'{i}_value + {output_id_link}_temp_{2 * numadd - 3}_{i}_value ' - new_constraint += f'elseif {output_id_link}_temp_{numadd - 2}_{i}_active + {output_id_link}_temp_{2 * numadd - 3}_{i}_active == 0 then ' \ - f'{output_id_link}_active[{i}] = 0 /\\ {output_id_link}_value[{i}] = 0 ' - new_constraint += f'elseif {output_id_link}_temp_{numadd - 2}_{i}_value + {output_id_link}_temp_{2 * numadd - 3}_{i}_value < 0 then ' \ - f'{output_id_link}_active[{i}] = 2 /\\ {output_id_link}_value[{i}] = -1 ' - new_constraint += f'elseif {output_id_link}_temp_{numadd - 2}_{i}_value == {output_id_link}_temp_{2 * numadd - 3}_{i}_value then ' \ - f'{output_id_link}_active[{i}] = 0 /\\ {output_id_link}_value[{i}] = 0 ' - xor_to_int = f'sum([(((floor({output_id_link}_temp_{numadd - 2}_{i}_value/pow(2,j)) + floor({output_id_link}_temp_{2 * numadd - 3}_{i}' \ - f'_value/pow(2,j))) mod 2) * pow(2,j)) | j in 0..{output_id_link}_bound_value_{numadd - 2}_{i}])' - new_constraint += f'else {output_id_link}_active[{i}] = 1 /\\ {output_id_link}_value[{i}] =' \ - f' {xor_to_int} endif;' + new_constraint += ( + f"constraint if {output_id_link}_temp_{numadd + summand - 1}_{i}_active + {output_id_link}_temp_{summand}_{i}_active > 2 then " + f"{output_id_link}_temp_{numadd + summand}_{i}_active = 3 /\\ {output_id_link}_temp_{numadd + summand}_{i}_value = -2 " + ) + new_constraint += ( + f"elseif {output_id_link}_temp_{numadd + summand - 1}_{i}_active + {output_id_link}_temp_{summand}_{i}_active == 1 then" + f" {output_id_link}_temp_{numadd + summand}_{i}_active = 1 /\\ {output_id_link}_temp_{numadd + summand}_{i}_value =" + f" {output_id_link}_temp_{numadd + summand - 1}_{i}_value + {output_id_link}_temp_{summand}_{i}_value " + ) + new_constraint += ( + f"elseif {output_id_link}_temp_{numadd + summand - 1}_{i}_active + {output_id_link}_temp_{summand}_{i}_active == 0 then" + f" {output_id_link}_temp_{numadd + summand}_{i}_active = 0 /\\ {output_id_link}_temp_{numadd + summand}_{i}_value = 0 " + ) + new_constraint += ( + f"elseif {output_id_link}_temp_{numadd + summand - 1}_{i}_value + {output_id_link}_temp_{summand}_{i}_value < 0 then " + f"{output_id_link}_temp_{numadd + summand}_{i}_active = 2 /\\ {output_id_link}_temp_{numadd + summand}_{i}_value = -1 " + ) + new_constraint += ( + f"elseif {output_id_link}_temp_{numadd + summand - 1}_{i}_value == {output_id_link}_temp_{summand}_{i}_value then " + f"{output_id_link}_temp_{numadd + summand}_{i}_active = 0 /\\ {output_id_link}_temp_{numadd + summand}_{i}_value = 0 " + ) + xor_to_int = ( + f"sum([(((floor({output_id_link}_temp_{numadd + summand - 1}_{i}_value/pow(2,j)) + floor({output_id_link}_temp_{summand}_{i}" + f"_value/pow(2,j))) mod 2) * pow(2,j)) | j in 0..{output_id_link}_bound_value_{numadd + summand - 1}_{i}])" + ) + new_constraint += ( + f"else {output_id_link}_temp_{numadd + summand}_{i}_active = 1 /\\ {output_id_link}_temp_{numadd + summand}_{i}_value" + f" = {xor_to_int} endif;\n" + ) + new_constraint += ( + f"constraint if {output_id_link}_temp_{numadd - 2}_{i}_active + {output_id_link}_temp_{2 * numadd - 3}_{i}_active > 2 then " + f"{output_id_link}_active[{i}] == 3 /\\ {output_id_link}_value[{i}] = -2 " + ) + new_constraint += ( + f"elseif {output_id_link}_temp_{numadd - 2}_{i}_active + {output_id_link}_temp_{2 * numadd - 3}_{i}_active == 1 then " + f"{output_id_link}_active[{i}] = 1 /\\ {output_id_link}_value[{i}] = {output_id_link}_temp_{numadd - 2}_" + f"{i}_value + {output_id_link}_temp_{2 * numadd - 3}_{i}_value " + ) + new_constraint += ( + f"elseif {output_id_link}_temp_{numadd - 2}_{i}_active + {output_id_link}_temp_{2 * numadd - 3}_{i}_active == 0 then " + f"{output_id_link}_active[{i}] = 0 /\\ {output_id_link}_value[{i}] = 0 " + ) + new_constraint += ( + f"elseif {output_id_link}_temp_{numadd - 2}_{i}_value + {output_id_link}_temp_{2 * numadd - 3}_{i}_value < 0 then " + f"{output_id_link}_active[{i}] = 2 /\\ {output_id_link}_value[{i}] = -1 " + ) + new_constraint += ( + f"elseif {output_id_link}_temp_{numadd - 2}_{i}_value == {output_id_link}_temp_{2 * numadd - 3}_{i}_value then " + f"{output_id_link}_active[{i}] = 0 /\\ {output_id_link}_value[{i}] = 0 " + ) + xor_to_int = ( + f"sum([(((floor({output_id_link}_temp_{numadd - 2}_{i}_value/pow(2,j)) + floor({output_id_link}_temp_{2 * numadd - 3}_{i}" + f"_value/pow(2,j))) mod 2) * pow(2,j)) | j in 0..{output_id_link}_bound_value_{numadd - 2}_{i}])" + ) + new_constraint += ( + f"else {output_id_link}_active[{i}] = 1 /\\ {output_id_link}_value[{i}] = {xor_to_int} endif;" + ) cp_constraints.append(new_constraint) return cp_declarations, cp_constraints @@ -477,18 +504,21 @@ def cp_xor_differential_propagation_first_step_constraints(self, model, variable (['array[0..1, 1..2] of int: xor_truncated_table_2 = array2d(0..1, 1..2, [0,0,1,1]);'], 'constraint table([rot_2_16[0]]++[xor_2_26[0]], xor_truncated_table_2);') """ - input_id_links = self.input_id_links - input_bit_positions = self.input_bit_positions - description = self.description - numadd = description[1] + numadd = self.description[1] all_inputs = [] - for id_link, bit_positions in zip(input_id_links, input_bit_positions): - all_inputs.extend([f'{id_link}[{bit_positions[j * model.word_size] // model.word_size}]' - for j in range(len(bit_positions) // model.word_size)]) + for id_link, bit_positions in zip(self.input_id_links, self.input_bit_positions): + all_inputs.extend( + [ + f"{id_link}[{bit_positions[j * model.word_size] // model.word_size}]" + for j in range(len(bit_positions) // model.word_size) + ] + ) input_len = len(all_inputs) // numadd - cp_constraints = 'constraint table(' \ - + '++'.join([f'[{all_inputs[input_len * j]}]' for j in range(numadd)]) \ - + f', xor_truncated_table_{numadd});' + cp_constraints = ( + "constraint table(" + + "++".join(f"[{all_inputs[input_len * j]}]" for j in range(numadd)) + + f", xor_truncated_table_{numadd});" + ) xor_table = cp_build_truncated_table(numadd) cp_declarations = [] if xor_table not in variables_list: @@ -516,45 +546,47 @@ def cp_xor_linear_mask_propagation_constraints(self, model=None): ... 'constraint xor_0_2_o[15] = xor_0_2_i[31];']) """ - input_size = self.input_bit_size - output_size = self.output_bit_size - output_id_link = self.id + cp_declarations = [ + f"array[0..{self.input_bit_size - 1}] of var 0..1: {self.id}_i;", + f"array[0..{self.output_bit_size - 1}] of var 0..1: {self.id}_o;", + ] num_of_addenda = self.description[1] - input_len = input_size // num_of_addenda - cp_declarations = [f'array[0..{input_size - 1}] of var 0..1: {output_id_link}_i;', - f'array[0..{output_size - 1}] of var 0..1: {output_id_link}_o;'] + input_len = self.input_bit_size // num_of_addenda cp_constraints = [] - for i in range(output_size): - cp_constraints.extend([f'constraint {output_id_link}_o[{i}] = {output_id_link}_i[{i + input_len * j}];' - for j in range(num_of_addenda)]) + for i in range(self.output_bit_size): + cp_constraints.extend( + [ + f"constraint {self.id}_o[{i}] = {self.id}_i[{i + input_len * j}];" + for j in range(num_of_addenda) + ] + ) result = cp_declarations, cp_constraints return result def get_bit_based_vectorized_python_code(self, params, convert_output_to_bytes): - return [f' {self.id} = bit_vector_XOR([{",".join(params)} ], {self.description[1]}, {self.output_bit_size})'] + return [f" {self.id} = bit_vector_XOR([{','.join(params)} ], {self.description[1]}, {self.output_bit_size})"] def get_byte_based_vectorized_python_code(self, params): - return [f' {self.id} = byte_vector_XOR({params})'] + return [f" {self.id} = byte_vector_XOR({params})"] def get_word_operation_sign(self, constants, sign, solution): output_id_link = self.id - input_id_links = self.input_id_links - input_size = self.input_bit_size - for i, input_id_link in enumerate(input_id_links): - if 'constant' in input_id_link: - int_const_mask = int(solution['components_values'][f'{input_id_link}_o']['value']) - bit_const_mask = [int(digit) for digit in format(int_const_mask, f'0{input_size}b')] + for i, input_id_link in enumerate(self.input_id_links): + if "constant" in input_id_link: + int_const_mask = int(solution["components_values"][f"{input_id_link}_o"]["value"]) + bit_const_mask = [int(digit) for digit in format(int_const_mask, f"0{self.input_bit_size}b")] input_bit_positions = self.input_bit_positions[i] constant = int(constants[input_id_link]) - bit_constant = [int(digit) for digit in format(constant, f'0{input_size}b')] - component_sign = generic_with_constant_sign_linear_constraints(bit_constant, bit_const_mask, - input_bit_positions) + bit_constant = [int(digit) for digit in format(constant, f"0{self.input_bit_size}b")] + component_sign = generic_with_constant_sign_linear_constraints( + bit_constant, bit_const_mask, input_bit_positions + ) sign = sign * component_sign - solution['components_values'][f'{output_id_link}_o']['sign'] = component_sign - solution['components_values'][output_id_link] = solution['components_values'][f'{output_id_link}_o'] - del solution['components_values'][f'{output_id_link}_o'] - del solution['components_values'][f'{output_id_link}_i'] + solution["components_values"][f"{output_id_link}_o"]["sign"] = component_sign + solution["components_values"][output_id_link] = solution["components_values"][f"{output_id_link}_o"] + del solution["components_values"][f"{output_id_link}_o"] + del solution["components_values"][f"{output_id_link}_i"] return sign @@ -607,8 +639,9 @@ def milp_constraints(self, model): update_dictionary_that_contains_xor_inequalities_between_n_input_bits(number_of_input_bits) dict_inequalities = output_dictionary_that_contains_xor_inequalities() inequalities = dict_inequalities[number_of_input_bits] - constraints.extend(get_milp_constraints_from_inequalities(inequalities, input_vars, - number_of_input_bits, output_vars, x)) + constraints.extend( + get_milp_constraints_from_inequalities(inequalities, input_vars, number_of_input_bits, output_vars, x) + ) return variables, constraints @@ -724,22 +757,31 @@ def milp_bitwise_deterministic_truncated_xor_differential_binary_constraints(sel input_id_tuples, output_id_tuples = self._get_input_output_variables_tuples() input_ids, output_ids = self._get_input_output_variables() - variables = [(f"x[{var_elt}]", x[var_elt]) for var_tuple in input_id_tuples + output_id_tuples for var_elt in - var_tuple] + variables = [ + (f"x[{var_elt}]", x[var_elt]) for var_tuple in input_id_tuples + output_id_tuples for var_elt in var_tuple + ] - linking_constraints = model.link_binary_tuples_to_integer_variables(input_id_tuples + output_id_tuples, - input_ids + output_ids) + linking_constraints = model.link_binary_tuples_to_integer_variables( + input_id_tuples + output_id_tuples, input_ids + output_ids + ) number_of_inputs = self.description[1] constraints = [] + linking_constraints for i, output_id in enumerate(output_id_tuples): - result_ids = [(f'temp_xor_{j}_{self.id}_{i}_0', f'temp_xor_{j}_{self.id}_{i}_1') - for j in range(number_of_inputs - 2)] + [output_id] - constraints.extend(milp_utils.milp_xor_truncated(model, input_id_tuples[i::output_bit_size][0], - input_id_tuples[i::output_bit_size][1], result_ids[0])) + result_ids = [ + (f"temp_xor_{j}_{self.id}_{i}_0", f"temp_xor_{j}_{self.id}_{i}_1") for j in range(number_of_inputs - 2) + ] + [output_id] + constraints.extend( + milp_utils.milp_xor_truncated( + model, input_id_tuples[i::output_bit_size][0], input_id_tuples[i::output_bit_size][1], result_ids[0] + ) + ) for chunk in range(1, number_of_inputs - 1): - constraints.extend(milp_utils.milp_xor_truncated(model, input_id_tuples[i::output_bit_size][chunk + 1], - result_ids[chunk - 1], result_ids[chunk])) + constraints.extend( + milp_utils.milp_xor_truncated( + model, input_id_tuples[i::output_bit_size][chunk + 1], result_ids[chunk - 1], result_ids[chunk] + ) + ) return variables, constraints @@ -789,8 +831,10 @@ def milp_bitwise_deterministic_truncated_xor_differential_constraints(self, mode # if a_i < 2 for all i then b = XOR(a_i, i in range(num_inputs)) # else b = 2 - a = [[x_class[input_vars[i + chunk * input_bit_size]] for chunk in range(num_of_inputs)] for i in - range(input_bit_size)] + a = [ + [x_class[input_vars[i + chunk * input_bit_size]] for chunk in range(num_of_inputs)] + for i in range(input_bit_size) + ] b = [x_class[output_vars[i]] for i in range(output_bit_size)] for i in range(output_bit_size): @@ -808,8 +852,9 @@ def milp_bitwise_deterministic_truncated_xor_differential_constraints(self, mode # if all_ai_less_2 == 1 then b = XOR(a_i, i in range(num_inputs)) # else b = 2 xor_constr = milp_utils.milp_generalized_xor(a[i], b[i]) - constr = milp_utils.milp_if_then_else(all_ai_less_2, xor_constr, [b[i] == 2], - model._model.get_max(x_class) * num_of_inputs) + constr = milp_utils.milp_if_then_else( + all_ai_less_2, xor_constr, [b[i] == 2], model._model.get_max(x_class) * num_of_inputs + ) constraints.extend(constr) return variables, constraints @@ -872,8 +917,9 @@ def milp_wordwise_deterministic_truncated_xor_differential_constraints(self, mod variables = [(f"x[{var}]", x[var]) for sublist in input_vars + output_vars for var in sublist] constraints = [] - update_dictionary_that_contains_wordwise_truncated_xor_inequalities_between_n_inputs(model.word_size, - num_of_inputs) + update_dictionary_that_contains_wordwise_truncated_xor_inequalities_between_n_inputs( + model.word_size, num_of_inputs + ) dict_inequalities = output_dictionary_that_contains_wordwise_truncated_xor_inequalities() inequalities = dict_inequalities[model.word_size][num_of_inputs] @@ -931,18 +977,24 @@ def milp_wordwise_deterministic_truncated_xor_differential_sequential_constraint constraints = [] for i, output_var in enumerate(output_vars): - result_ids = [tuple([f'temp_xor_{j}_{self.id}_word_{i}_0', f'temp_xor_{j}_{self.id}_word_{i}_1'] + [ - f'temp_xor_{j}_{self.id}_word_{i}_bit_{k}' for k in range(model.word_size)]) for j in - range(number_of_inputs - 2)] + [output_var] - constraints.extend(milp_utils.milp_xor_truncated_wordwise(model, - input_vars[i::output_word_size][0], - input_vars[i::output_word_size][1], - result_ids[0])) + result_ids = [ + tuple( + [f"temp_xor_{j}_{self.id}_word_{i}_0", f"temp_xor_{j}_{self.id}_word_{i}_1"] + + [f"temp_xor_{j}_{self.id}_word_{i}_bit_{k}" for k in range(model.word_size)] + ) + for j in range(number_of_inputs - 2) + ] + [output_var] + constraints.extend( + milp_utils.milp_xor_truncated_wordwise( + model, input_vars[i::output_word_size][0], input_vars[i::output_word_size][1], result_ids[0] + ) + ) for chunk in range(1, number_of_inputs - 1): - constraints.extend(milp_utils.milp_xor_truncated_wordwise(model, - input_vars[i::output_word_size][chunk + 1], - result_ids[chunk - 1], - result_ids[chunk])) + constraints.extend( + milp_utils.milp_xor_truncated_wordwise( + model, input_vars[i::output_word_size][chunk + 1], result_ids[chunk - 1], result_ids[chunk] + ) + ) return variables, constraints @@ -988,7 +1040,7 @@ def milp_wordwise_deterministic_truncated_xor_differential_simple_constraints(se output_vars = [x_class[var] for var in output_class_vars] for i in range(len(output_class_vars)): - input_words = input_vars[i::len(output_class_vars)] + input_words = input_vars[i :: len(output_class_vars)] input_a = input_words[0] input_b = input_words[1] output_c = output_vars[i] @@ -996,8 +1048,9 @@ def milp_wordwise_deterministic_truncated_xor_differential_simple_constraints(se then_constraints_list = [] # if dX0 + dX1 > 2 then dY = 3 - a_b_greater_2, geq_2_constraints = milp_utils.milp_geq(model, input_a + input_b, 2, - 2 * model._model.get_max(x_class) + 1) + a_b_greater_2, geq_2_constraints = milp_utils.milp_geq( + model, input_a + input_b, 2, 2 * model._model.get_max(x_class) + 1 + ) var_if_list.append(a_b_greater_2) constraints.extend(geq_2_constraints) then_constraints_list.append([output_c == 3]) @@ -1017,8 +1070,15 @@ def milp_wordwise_deterministic_truncated_xor_differential_simple_constraints(se # else dY = 2 else_constraints = [output_c == 2] - constraints.extend(milp_utils.milp_if_elif_else(model, var_if_list, then_constraints_list, else_constraints, - num_of_inputs * model._model.get_max(x_class))) + constraints.extend( + milp_utils.milp_if_elif_else( + model, + var_if_list, + then_constraints_list, + else_constraints, + num_of_inputs * model._model.get_max(x_class), + ) + ) return variables, constraints @@ -1047,13 +1107,17 @@ def create_block_of_xor_constraints(input_vars_1_temp, input_vars_2_temp, output mzn_input_array_2 = self._create_minizinc_1d_array_from_list(input_vars_2_temp) mzn_output_array = self._create_minizinc_1d_array_from_list(output_varstrs_temp) if model.sat_or_milp == "sat": - mzn_block_variables = '' - mzn_block_constraints = f'constraint xor_word(\n{mzn_input_array_1},' \ - f'\n{mzn_input_array_2},\n{mzn_output_array})={model.true_value};\n' + mzn_block_variables = "" + mzn_block_constraints = ( + f"constraint xor_word(\n{mzn_input_array_1}," + f"\n{mzn_input_array_2},\n{mzn_output_array})={model.true_value};\n" + ) else: - mzn_block_variables = f'array [0..{noutputs}-1] of var 0..1: dummy_{component_id}_{i};\n' - mzn_block_constraints = f'constraint xor_word(\n{mzn_input_array_1},\n{mzn_input_array_2},' \ - f'\n{mzn_output_array},\ndummy_{component_id}_{i})={model.true_value};\n' + mzn_block_variables = f"array [0..{noutputs}-1] of var 0..1: dummy_{component_id}_{i};\n" + mzn_block_constraints = ( + f"constraint xor_word(\n{mzn_input_array_1},\n{mzn_input_array_2}," + f"\n{mzn_output_array},\ndummy_{component_id}_{i})={model.true_value};\n" + ) return mzn_block_variables, mzn_block_constraints if self.description[0].lower() != "xor": @@ -1069,23 +1133,25 @@ def create_block_of_xor_constraints(input_vars_1_temp, input_vars_2_temp, output output_vars = [component_id + "_" + model.output_postfix + str(i) for i in range(noutputs)] ninput_words = int(self.description[1]) word_chunk = noutputs - new_output_vars = [input_vars[0 * word_chunk:0 * word_chunk + word_chunk]] + new_output_vars = [input_vars[0 * word_chunk : 0 * word_chunk + word_chunk]] for i in range(ninput_words - 2): new_output_vars_temp = [] for output_var in output_vars: - mzn_constraints += [f'var {model.data_type}: {output_var}_{str(i)};\n'] + mzn_constraints += [f"var {model.data_type}: {output_var}_{str(i)};\n"] new_output_vars_temp.append(output_var + "_" + str(i)) new_output_vars.append(new_output_vars_temp) for i in range(1, ninput_words): - input_vars_1 = input_vars[i * word_chunk:i * word_chunk + word_chunk] + input_vars_1 = input_vars[i * word_chunk : i * word_chunk + word_chunk] input_vars_2 = new_output_vars[i - 1] if i == ninput_words - 1: - mzn_variables_and_constraints = create_block_of_xor_constraints(input_vars_1, input_vars_2, - output_vars, i) + mzn_variables_and_constraints = create_block_of_xor_constraints( + input_vars_1, input_vars_2, output_vars, i + ) else: - mzn_variables_and_constraints = create_block_of_xor_constraints(input_vars_1, input_vars_2, - new_output_vars[i], i) + mzn_variables_and_constraints = create_block_of_xor_constraints( + input_vars_1, input_vars_2, new_output_vars[i], i + ) var_names += [mzn_variables_and_constraints[0]] mzn_constraints += [mzn_variables_and_constraints[1]] @@ -1128,12 +1194,13 @@ def sat_constraints(self): 'xor_0_2_15 modadd_0_1_15 -key_63', '-xor_0_2_15 -modadd_0_1_15 -key_63']) """ - _, input_bit_ids = self._generate_input_ids() + input_bit_ids = self._generate_input_ids() output_bit_len, output_bit_ids = self._generate_output_ids() constraints = [] for i in range(output_bit_len): - result_bit_ids = [f'inter_{j}_{output_bit_ids[i]}' - for j in range(self.description[1] - 2)] + [output_bit_ids[i]] + result_bit_ids = [f"inter_{j}_{output_bit_ids[i]}" for j in range(self.description[1] - 2)] + [ + output_bit_ids[i] + ] constraints.extend(sat_utils.cnf_xor_seq(result_bit_ids, input_bit_ids[i::output_bit_len])) return output_bit_ids, constraints @@ -1174,8 +1241,9 @@ def sat_bitwise_deterministic_truncated_xor_differential_constraints(self): out_ids = [(id_0, id_1) for id_0, id_1 in zip(out_ids_0, out_ids_1)] constraints = [] for i, out_id in enumerate(out_ids): - result_ids = [(f'inter_{j}_{self.id}_{i}_0', f'inter_{j}_{self.id}_{i}_1') - for j in range(self.description[1] - 2)] + [out_id] + result_ids = [ + (f"inter_{j}_{self.id}_{i}_0", f"inter_{j}_{self.id}_{i}_1") for j in range(self.description[1] - 2) + ] + [out_id] constraints.extend(sat_utils.cnf_xor_truncated_seq(result_ids, in_ids[i::out_len])) return out_ids_0 + out_ids_1, constraints @@ -1284,7 +1352,7 @@ def smt_constraints(self): '(assert (= xor_0_2_14 (xor modadd_0_1_14 key_62)))', '(assert (= xor_0_2_15 (xor modadd_0_1_15 key_63)))']) """ - _, input_bit_ids = self._generate_input_ids() + input_bit_ids = self._generate_input_ids() output_bit_len, output_bit_ids = self._generate_output_ids() constraints = [] for i in range(output_bit_len): @@ -1388,11 +1456,10 @@ def cp_transform_xor_components_for_first_step(self, model): input_id_link = self.input_id_links output_id_link = self.id input_bit_positions = self.input_bit_positions - description = self.description - numadd = description[1] + numadd = self.description[1] numb_of_inp = len(input_id_link) all_inputs = [] - cp_declarations = [f'array[0..{(output_size - 1) // model.word_size}] of var 0..1: {output_id_link};'] + cp_declarations = [f"array[0..{(output_size - 1) // model.word_size}] of var 0..1: {output_id_link};"] number_of_mix = 0 is_mix = False for i in range(numb_of_inp): @@ -1401,7 +1468,7 @@ def cp_transform_xor_components_for_first_step(self, model): rem = len(input_bit_positions[i]) % model.word_size if rem != 0: rem = model.word_size - (len(input_bit_positions[i]) % model.word_size) - all_inputs.append([f'{output_id_link}_i', number_of_mix]) + all_inputs.append([f"{output_id_link}_i", number_of_mix]) number_of_mix += 1 is_mix = True l = 1 @@ -1411,18 +1478,18 @@ def cp_transform_xor_components_for_first_step(self, model): rem -= length l += 1 if is_mix: - cp_declarations.append(f'array[0..{number_of_mix - 1}] of var 0..1: {output_id_link}_i;') + cp_declarations.append(f"array[0..{number_of_mix - 1}] of var 0..1: {output_id_link}_i;") all_inputs += [[output_id_link, i] for i in range(output_size // model.word_size)] input_len = output_size // model.word_size for i in range(input_len): - input_bit_positions, input_id_link = \ - get_transformed_xor_input_links_and_positions(model.word_size, all_inputs, i, - input_len, numadd, numb_of_inp) + input_bit_positions, input_id_link = get_transformed_xor_input_links_and_positions( + model.word_size, all_inputs, i, input_len, numadd, numb_of_inp + ) input_bits = 0 for input_bit in input_bit_positions: input_bits += len(input_bit) xor_component = XOR("", "", input_id_link, input_bit_positions, input_bits) - xor_component.set_description(['XOR', numadd + 1]) + xor_component.set_description(["XOR", numadd + 1]) model.list_of_xor_components.append(xor_component) cp_constraints = [] diff --git a/claasp/compound_xor_differential_cipher.py b/claasp/compound_xor_differential_cipher.py index 97b60d9d0..7989e4bd3 100644 --- a/claasp/compound_xor_differential_cipher.py +++ b/claasp/compound_xor_differential_cipher.py @@ -1,12 +1,13 @@ from copy import deepcopy from claasp.components.xor_component import XOR +from claasp.name_mappings import CIPHER_OUTPUT, INTERMEDIATE_OUTPUT def get_component_pair(round_component_): original_component = deepcopy(round_component_) - new_id_pair1 = f'{original_component.id}_pair1' - new_id_pair2 = f'{original_component.id}_pair2' + new_id_pair1 = f"{original_component.id}_pair1" + new_id_pair2 = f"{original_component.id}_pair2" original_component.set_id(new_id_pair1) component_copy = deepcopy(original_component) component_copy.set_id(new_id_pair2) @@ -18,10 +19,10 @@ def update_input_id_links(component1_, component2_): input_id_links2 = component2_.input_id_links new_input_id_link1 = [] for input_id_link1 in input_id_links1: - new_input_id_link1.append(f'{input_id_link1}_pair1') + new_input_id_link1.append(f"{input_id_link1}_pair1") new_input_id_link2 = [] for input_id_link2 in input_id_links2: - new_input_id_link2.append(f'{input_id_link2}_pair2') + new_input_id_link2.append(f"{input_id_link2}_pair2") component1_.set_input_id_links(new_input_id_link1) component2_.set_input_id_links(new_input_id_link2) @@ -31,8 +32,8 @@ def update_cipher_inputs(cipher): new_inputs_pair2 = [] old_cipher_inputs_ = deepcopy(cipher.inputs) for cipher_input in cipher.inputs: - new_inputs_pair1.append(f'{cipher_input}_pair1') - new_inputs_pair2.append(f'{cipher_input}_pair2') + new_inputs_pair1.append(f"{cipher_input}_pair1") + new_inputs_pair2.append(f"{cipher_input}_pair2") cipher._inputs = new_inputs_pair1 + new_inputs_pair2 return old_cipher_inputs_ @@ -42,14 +43,20 @@ def create_xor_component_inputs(old_cipher_inputs_, cipher, round_object): half_number_of_cipher_inputs = int(len(cipher.inputs_bit_size) / 2) i = 0 for cipher_input in old_cipher_inputs_: - input_link_positions = [list(range(cipher.inputs_bit_size[i]))] + \ - [list(range(cipher.inputs_bit_size[i + half_number_of_cipher_inputs]))] - input_links = [f'{cipher_input}_pair1', f'{cipher_input}_pair2'] + input_link_positions = [list(range(cipher.inputs_bit_size[i]))] + [ + list(range(cipher.inputs_bit_size[i + half_number_of_cipher_inputs])) + ] + input_links = [f"{cipher_input}_pair1", f"{cipher_input}_pair2"] current_components_number = round_object.get_number_of_components() output_bit_size = cipher.inputs_bit_size[i] - new_xor_component = XOR(0, current_components_number, input_links, input_link_positions, - output_bit_size) - new_xor_component.set_id(f'{cipher_input}_pair1_pair2') + new_xor_component = XOR( + 0, + current_components_number, + input_links, + input_link_positions, + output_bit_size, + ) + new_xor_component.set_id(f"{cipher_input}_pair1_pair2") round_object.add_component(new_xor_component) i += 1 @@ -59,10 +66,16 @@ def create_xor_component(component1_, component2_, round_object, round_number): input_links = [component1_.id, component2_.id] current_components_number = round_object.get_number_of_components() output_bit_size = component1_.output_bit_size - new_xor_component = XOR(round_number, current_components_number, input_links, input_link_positions, output_bit_size) + new_xor_component = XOR( + round_number, + current_components_number, + input_links, + input_link_positions, + output_bit_size, + ) - if component1_.type == 'intermediate_output' or component1_.type == 'cipher_output': - component_id = "_".join(component1_.id.split('_')[:-1]) + if component1_.type == INTERMEDIATE_OUTPUT or component1_.type == CIPHER_OUTPUT: + component_id = "_".join(component1_.id.split("_")[:-1]) new_xor_component.set_id(component_id) round_object.add_component(new_xor_component) diff --git a/claasp/editor.py b/claasp/editor.py index a1ebaa688..46c12de04 100644 --- a/claasp/editor.py +++ b/claasp/editor.py @@ -1,54 +1,61 @@ -import sys # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** +import sys from copy import deepcopy -from claasp.components.or_component import OR from claasp.components.and_component import AND -from claasp.components.theta_gaston_component import ThetaGaston -from claasp.components.xor_component import XOR -from claasp.components.not_component import NOT +from claasp.components.cipher_output_component import CipherOutput +from claasp.components.concatenate_component import Concatenate +from claasp.components.constant_component import Constant from claasp.components.fsr_component import FSR -from claasp.components.sbox_component import SBOX -from claasp.components.shift_component import SHIFT -from claasp.components.sigma_component import Sigma -from claasp.components.rotate_component import Rotate +from claasp.components.intermediate_output_component import IntermediateOutput +from claasp.components.linear_layer_component import LinearLayer +from claasp.components.mix_column_component import MixColumn from claasp.components.modadd_component import MODADD from claasp.components.modsub_component import MODSUB +from claasp.components.not_component import NOT +from claasp.components.or_component import OR +from claasp.components.permutation_component import Permutation from claasp.components.reverse_component import Reverse -from claasp.components.constant_component import Constant +from claasp.components.rotate_component import Rotate +from claasp.components.sbox_component import SBOX +from claasp.components.shift_component import SHIFT from claasp.components.shift_rows_component import ShiftRows -from claasp.components.mix_column_component import MixColumn -from claasp.components.permutation_component import Permutation -from claasp.components.concatenate_component import Concatenate -from claasp.components.linear_layer_component import LinearLayer -from claasp.components.theta_xoodoo_component import ThetaXoodoo +from claasp.components.sigma_component import Sigma +from claasp.components.theta_gaston_component import ThetaGaston from claasp.components.theta_keccak_component import ThetaKeccak -from claasp.components.cipher_output_component import CipherOutput -from claasp.components.variable_shift_component import VariableShift +from claasp.components.theta_xoodoo_component import ThetaXoodoo from claasp.components.variable_rotate_component import VariableRotate +from claasp.components.variable_shift_component import VariableShift from claasp.components.word_permutation_component import WordPermutation -from claasp.components.intermediate_output_component import IntermediateOutput -from claasp.name_mappings import INTERMEDIATE_OUTPUT, CIPHER_OUTPUT, CONSTANT, INPUT_KEY, LINEAR_LAYER +from claasp.components.xor_component import XOR +from claasp.name_mappings import ( + INTERMEDIATE_OUTPUT, + CIPHER_OUTPUT, + CONSTANT, + INPUT_KEY, + LINEAR_LAYER, +) -cipher_round_not_found_error = "Error! The cipher has no round: please run self.add_round() before adding any " \ - "component. " +CIPHER_ROUND_NOT_FOUND_ERROR = ( + "Error! The cipher has no round: please run self.add_round() before adding any component. " +) def add_AND_component(cipher, input_id_links, input_bit_positions, output_bit_size): @@ -88,11 +95,16 @@ def add_AND_component(cipher, input_id_links, input_bit_positions, output_bit_si cipher_reference_code = None """ if cipher.current_round_number is None: - print(cipher_round_not_found_error) + print(CIPHER_ROUND_NOT_FOUND_ERROR) return None - new_component = AND(cipher.current_round_number, cipher.current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size) + new_component = AND( + cipher.current_round_number, + cipher.current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + ) add_component(cipher, new_component) return new_component @@ -134,11 +146,16 @@ def add_cipher_output_component(cipher, input_id_links, input_bit_positions, out cipher_reference_code = None """ if cipher.current_round_number is None: - print(cipher_round_not_found_error) + print(CIPHER_ROUND_NOT_FOUND_ERROR) return None - new_component = CipherOutput(cipher.current_round_number, cipher.current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size) + new_component = CipherOutput( + cipher.current_round_number, + cipher.current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + ) add_component(cipher, new_component) return new_component @@ -184,11 +201,16 @@ def add_concatenate_component(cipher, input_id_links, input_bit_positions, outpu cipher_reference_code = None """ if cipher.current_round_number is None: - print(cipher_round_not_found_error) + print(CIPHER_ROUND_NOT_FOUND_ERROR) return None - new_component = Concatenate(cipher.current_round_number, cipher.current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size) + new_component = Concatenate( + cipher.current_round_number, + cipher.current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + ) add_component(cipher, new_component) return new_component @@ -239,11 +261,15 @@ def add_constant_component(cipher, output_bit_size, value): cipher_reference_code = None """ if cipher.current_round_number is None: - print("Error! The cipher has no rounds: please run self.add_round() before adding any component.") + print(CIPHER_ROUND_NOT_FOUND_ERROR) return None - new_component = Constant(cipher.current_round_number, cipher.current_round_number_of_components, - output_bit_size, value) + new_component = Constant( + cipher.current_round_number, + cipher.current_round_number_of_components, + output_bit_size, + value, + ) add_component(cipher, new_component) return new_component @@ -303,11 +329,17 @@ def add_FSR_component(cipher, input_id_links, input_bit_positions, output_bit_si """ if cipher.current_round_number is None: - print(cipher_round_not_found_error) + print(CIPHER_ROUND_NOT_FOUND_ERROR) return None - new_component = FSR(cipher.current_round_number, cipher.current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size, description) + new_component = FSR( + cipher.current_round_number, + cipher.current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + description, + ) add_component(cipher, new_component) return new_component @@ -350,11 +382,17 @@ def add_intermediate_output_component(cipher, input_id_links, input_bit_position cipher_reference_code = None """ if cipher.current_round_number is None: - print(cipher_round_not_found_error) + print(CIPHER_ROUND_NOT_FOUND_ERROR) return None - new_component = IntermediateOutput(cipher.current_round_number, cipher.current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size, output_tag) + new_component = IntermediateOutput( + cipher.current_round_number, + cipher.current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + output_tag, + ) add_component(cipher, new_component) return new_component @@ -398,11 +436,17 @@ def add_linear_layer_component(cipher, input_id_links, input_bit_positions, outp cipher_reference_code = None """ if cipher.current_round_number is None: - print(cipher_round_not_found_error) + print(CIPHER_ROUND_NOT_FOUND_ERROR) return None - new_component = LinearLayer(cipher.current_round_number, cipher.current_round_number_of_components, input_id_links, - input_bit_positions, output_bit_size, description) + new_component = LinearLayer( + cipher.current_round_number, + cipher.current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + description, + ) add_component(cipher, new_component) return new_component @@ -445,11 +489,17 @@ def add_mix_column_component(cipher, input_id_links, input_bit_positions, output cipher_reference_code = None """ if cipher.current_round_number is None: - print(cipher_round_not_found_error) + print(CIPHER_ROUND_NOT_FOUND_ERROR) return None - new_component = MixColumn(cipher.current_round_number, cipher.current_round_number_of_components, input_id_links, - input_bit_positions, output_bit_size, mix_column_description) + new_component = MixColumn( + cipher.current_round_number, + cipher.current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + mix_column_description, + ) add_component(cipher, new_component) return new_component @@ -491,11 +541,17 @@ def add_MODADD_component(cipher, input_id_links, input_bit_positions, output_bit cipher_reference_code = None """ if cipher.current_round_number is None: - print(cipher_round_not_found_error) + print(CIPHER_ROUND_NOT_FOUND_ERROR) return None - new_component = MODADD(cipher.current_round_number, cipher.current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size, modulus) + new_component = MODADD( + cipher.current_round_number, + cipher.current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + modulus, + ) add_component(cipher, new_component) return new_component @@ -537,11 +593,17 @@ def add_MODSUB_component(cipher, input_id_links, input_bit_positions, output_bit cipher_reference_code = None """ if cipher.current_round_number is None: - print(cipher_round_not_found_error) + print(CIPHER_ROUND_NOT_FOUND_ERROR) return None - new_component = MODSUB(cipher.current_round_number, cipher.current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size, modulus) + new_component = MODSUB( + cipher.current_round_number, + cipher.current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + modulus, + ) add_component(cipher, new_component) return new_component @@ -583,11 +645,16 @@ def add_NOT_component(cipher, input_id_links, input_bit_positions, output_bit_si cipher_reference_code = None """ if cipher.current_round_number is None: - print(cipher_round_not_found_error) + print(CIPHER_ROUND_NOT_FOUND_ERROR) return None - new_component = NOT(cipher.current_round_number, cipher.current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size) + new_component = NOT( + cipher.current_round_number, + cipher.current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + ) add_component(cipher, new_component) return new_component @@ -629,17 +696,27 @@ def add_OR_component(cipher, input_id_links, input_bit_positions, output_bit_siz cipher_reference_code = None """ if cipher.current_round_number is None: - print(cipher_round_not_found_error) + print(CIPHER_ROUND_NOT_FOUND_ERROR) return None - new_component = OR(cipher.current_round_number, cipher.current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size) + new_component = OR( + cipher.current_round_number, + cipher.current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + ) add_component(cipher, new_component) return new_component -def add_permutation_component(cipher, input_id_links, input_bit_positions, output_bit_size, - permutation_description): +def add_permutation_component( + cipher, + input_id_links, + input_bit_positions, + output_bit_size, + permutation_description, +): """ Create a permutation component to permute the bit position in the editor. @@ -677,11 +754,17 @@ def add_permutation_component(cipher, input_id_links, input_bit_positions, outpu cipher_reference_code = None """ if cipher.current_round_number is None: - print(cipher_round_not_found_error) + print(CIPHER_ROUND_NOT_FOUND_ERROR) return None - new_component = Permutation(cipher.current_round_number, cipher.current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size, permutation_description) + new_component = Permutation( + cipher.current_round_number, + cipher.current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + permutation_description, + ) add_component(cipher, new_component) return new_component @@ -724,11 +807,16 @@ def add_reverse_component(cipher, input_id_links, input_bit_positions, output_bi """ if cipher.current_round_number is None: - print(cipher_round_not_found_error) + print(CIPHER_ROUND_NOT_FOUND_ERROR) return None - new_component = Reverse(cipher.current_round_number, cipher.current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size) + new_component = Reverse( + cipher.current_round_number, + cipher.current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + ) add_component(cipher, new_component) return new_component @@ -772,11 +860,17 @@ def add_rotate_component(cipher, input_id_links, input_bit_positions, output_bit cipher_reference_code = None """ if cipher.current_round_number is None: - print(cipher_round_not_found_error) + print(CIPHER_ROUND_NOT_FOUND_ERROR) return None - new_component = Rotate(cipher.current_round_number, cipher.current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size, parameter) + new_component = Rotate( + cipher.current_round_number, + cipher.current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + parameter, + ) add_component(cipher, new_component) return new_component @@ -812,8 +906,15 @@ def add_round(cipher): } """ cipher.rounds.add_round() - cipher.set_id(make_cipher_id(cipher.family_name, cipher.inputs, cipher.inputs_bit_size, - cipher.output_bit_size, cipher.number_of_rounds)) + cipher.set_id( + make_cipher_id( + cipher.family_name, + cipher.inputs, + cipher.inputs_bit_size, + cipher.output_bit_size, + cipher.number_of_rounds, + ) + ) cipher.set_file_name(make_file_name(cipher.id)) @@ -854,11 +955,17 @@ def add_round_key_output_component(cipher, input_id_links, input_bit_positions, cipher_reference_code = None """ if cipher.current_round_number is None: - print(cipher_round_not_found_error) + print(CIPHER_ROUND_NOT_FOUND_ERROR) return None - new_component = IntermediateOutput(cipher.current_round_number, cipher.current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size, 'round_key_output') + new_component = IntermediateOutput( + cipher.current_round_number, + cipher.current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + "round_key_output", + ) add_component(cipher, new_component) return new_component @@ -900,11 +1007,17 @@ def add_round_output_component(cipher, input_id_links, input_bit_positions, outp cipher_reference_code = None """ if cipher.current_round_number is None: - print(cipher_round_not_found_error) + print(CIPHER_ROUND_NOT_FOUND_ERROR) return None - new_component = IntermediateOutput(cipher.current_round_number, cipher.current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size, 'round_output') + new_component = IntermediateOutput( + cipher.current_round_number, + cipher.current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + "round_output", + ) add_component(cipher, new_component) return new_component @@ -948,11 +1061,17 @@ def add_SBOX_component(cipher, input_id_links, input_bit_positions, output_bit_s cipher_reference_code = None """ if cipher.current_round_number is None: - print(cipher_round_not_found_error) + print(CIPHER_ROUND_NOT_FOUND_ERROR) return None - new_component = SBOX(cipher.current_round_number, cipher.current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size, description) + new_component = SBOX( + cipher.current_round_number, + cipher.current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + description, + ) add_component(cipher, new_component) return new_component @@ -996,11 +1115,17 @@ def add_SHIFT_component(cipher, input_id_links, input_bit_positions, output_bit_ cipher_reference_code = None """ if cipher.current_round_number is None: - print(cipher_round_not_found_error) + print(CIPHER_ROUND_NOT_FOUND_ERROR) return None - new_component = SHIFT(cipher.current_round_number, cipher.current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size, parameter) + new_component = SHIFT( + cipher.current_round_number, + cipher.current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + parameter, + ) add_component(cipher, new_component) return new_component @@ -1044,16 +1169,28 @@ def add_shift_rows_component(cipher, input_id_links, input_bit_positions, output cipher_reference_code = None """ if cipher.current_round_number is None: - print(cipher_round_not_found_error) + print(CIPHER_ROUND_NOT_FOUND_ERROR) return None - new_component = ShiftRows(cipher.current_round_number, cipher.current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size, parameter) + new_component = ShiftRows( + cipher.current_round_number, + cipher.current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + parameter, + ) add_component(cipher, new_component) return new_component -def add_sigma_component(cipher, input_id_links, input_bit_positions, output_bit_size, rotation_amounts_parameter): +def add_sigma_component( + cipher, + input_id_links, + input_bit_positions, + output_bit_size, + rotation_amounts_parameter, +): """ Use this function to create a sigma component in cipher. @@ -1096,16 +1233,28 @@ def add_sigma_component(cipher, input_id_links, input_bit_positions, output_bit_ cipher_reference_code = None """ if cipher.current_round_number is None: - print(cipher_round_not_found_error) + print(CIPHER_ROUND_NOT_FOUND_ERROR) return None - linear_layer_component = Sigma(cipher.current_round_number, cipher.current_round_number_of_components, - input_id_links, input_bit_positions, - output_bit_size, rotation_amounts_parameter) + linear_layer_component = Sigma( + cipher.current_round_number, + cipher.current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + rotation_amounts_parameter, + ) add_component(cipher, linear_layer_component) return linear_layer_component -def add_theta_gaston_component(cipher, input_id_links, input_bit_positions, output_bit_size, rotation_amounts_parameter): + +def add_theta_gaston_component( + cipher, + input_id_links, + input_bit_positions, + output_bit_size, + rotation_amounts_parameter, +): """ Use this function to create the theta component of Gaston in cipher. @@ -1136,15 +1285,21 @@ def add_theta_gaston_component(cipher, input_id_links, input_bit_positions, outp 3520 """ if cipher.current_round_number is None: - print(cipher_round_not_found_error) + print(CIPHER_ROUND_NOT_FOUND_ERROR) return None - linear_layer_component = ThetaGaston(cipher.current_round_number, cipher.current_round_number_of_components, - input_id_links, input_bit_positions, - output_bit_size, rotation_amounts_parameter) + linear_layer_component = ThetaGaston( + cipher.current_round_number, + cipher.current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + rotation_amounts_parameter, + ) add_component(cipher, linear_layer_component) return linear_layer_component + def add_theta_keccak_component(cipher, input_id_links, input_bit_positions, output_bit_size): """ Use this function to create the theta component of Keccak in cipher. @@ -1172,11 +1327,16 @@ def add_theta_keccak_component(cipher, input_id_links, input_bit_positions, outp 'linear_layer' """ if cipher.number_of_rounds == 0: - print(cipher_round_not_found_error) + print(CIPHER_ROUND_NOT_FOUND_ERROR) return None - new_component = ThetaKeccak(cipher.current_round_number, cipher.current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size) + new_component = ThetaKeccak( + cipher.current_round_number, + cipher.current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + ) add_component(cipher, new_component) return new_component @@ -1208,11 +1368,16 @@ def add_theta_xoodoo_component(cipher, input_id_links, input_bit_positions, outp 'linear_layer' """ if cipher.number_of_rounds == 0: - print(cipher_round_not_found_error) + print(CIPHER_ROUND_NOT_FOUND_ERROR) return None - theta_xoodoo_component = ThetaXoodoo(cipher.current_round_number, cipher.current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size) + theta_xoodoo_component = ThetaXoodoo( + cipher.current_round_number, + cipher.current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + ) add_component(cipher, theta_xoodoo_component) return deepcopy(theta_xoodoo_component) @@ -1257,11 +1422,17 @@ def add_variable_rotate_component(cipher, input_id_links, input_bit_positions, o cipher_reference_code = None """ if cipher.current_round_number is None: - print(cipher_round_not_found_error) + print(CIPHER_ROUND_NOT_FOUND_ERROR) return None - new_component = VariableRotate(cipher.current_round_number, cipher.current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size, parameter) + new_component = VariableRotate( + cipher.current_round_number, + cipher.current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + parameter, + ) add_component(cipher, new_component) return new_component @@ -1304,17 +1475,29 @@ def add_variable_shift_component(cipher, input_id_links, input_bit_positions, ou cipher_reference_code = None """ if cipher.current_round_number is None: - print(cipher_round_not_found_error) + print(CIPHER_ROUND_NOT_FOUND_ERROR) return None - new_component = VariableShift(cipher.current_round_number, cipher.current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size, parameter) + new_component = VariableShift( + cipher.current_round_number, + cipher.current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + parameter, + ) add_component(cipher, new_component) return new_component -def add_word_permutation_component(cipher, input_id_links, input_bit_positions, output_bit_size, - permutation_description, word_size): +def add_word_permutation_component( + cipher, + input_id_links, + input_bit_positions, + output_bit_size, + permutation_description, + word_size, +): """ Create a permutation component to permute the word position in the editor. @@ -1353,12 +1536,18 @@ def add_word_permutation_component(cipher, input_id_links, input_bit_positions, cipher_reference_code = None """ if cipher.current_round_number is None: - print(cipher_round_not_found_error) + print(CIPHER_ROUND_NOT_FOUND_ERROR) return None - new_component = WordPermutation(cipher.current_round_number, cipher.current_round_number_of_components, - input_id_links, input_bit_positions, - output_bit_size, permutation_description, word_size) + new_component = WordPermutation( + cipher.current_round_number, + cipher.current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + permutation_description, + word_size, + ) add_component(cipher, new_component) return new_component @@ -1400,11 +1589,16 @@ def add_XOR_component(cipher, input_id_links, input_bit_positions, output_bit_si cipher_reference_code = None """ if cipher.current_round_number is None: - print(cipher_round_not_found_error) + print(CIPHER_ROUND_NOT_FOUND_ERROR) return None - new_component = XOR(cipher.current_round_number, cipher.current_round_number_of_components, - input_id_links, input_bit_positions, output_bit_size) + new_component = XOR( + cipher.current_round_number, + cipher.current_round_number_of_components, + input_id_links, + input_bit_positions, + output_bit_size, + ) add_component(cipher, new_component) return new_component @@ -1448,26 +1642,29 @@ def get_unique_links_information(new_links): return unique_lengths, unique_links -def is_linear_layer_permutation(M, M_T): - ones = [1] * len(M) - M_has_only_one_1_in_rows = ([sum(row) for row in M] == ones) - M_has_only_one_1_in_cols = ([sum(row) for row in M_T] == ones) +def is_linear_layer_permutation(matrix, matrix_transposed): + ones = (1,) * len(matrix) + has_only_one_1_in_rows = tuple(sum(row) for row in matrix) == ones + has_only_one_1_in_cols = tuple(sum(row) for row in matrix_transposed) == ones - return M_has_only_one_1_in_rows and M_has_only_one_1_in_cols + return has_only_one_1_in_rows and has_only_one_1_in_cols -def make_cipher_id(family_name, inputs, inputs_bit_size, - output_bit_size, number_of_rounds): - cipher_id = f'{family_name}' - for i in range(len(inputs)): - cipher_id += f'_{inputs[i][0]}{inputs_bit_size[i]}' +def make_cipher_id(family_name, inputs, inputs_bit_size, output_bit_size, number_of_rounds): + tokens = [f"{family_name}",] + for input_, size in zip(inputs, inputs_bit_size): + if not input_.startswith("input_"): + tokens += [f"{input_[0]}{size}"] + else: + tokens +=[f"{input_[6]}{size}"] + tokens += [f"o{output_bit_size}", f"r{number_of_rounds}"] + cipher_id = "_".join(tokens) - cipher_id += f'_o{output_bit_size}_r{number_of_rounds}' return cipher_id def make_file_name(cipher_id): - return f'{cipher_id}.py' + return f"{cipher_id}.py" def next_component_index_from(index): @@ -1485,13 +1682,12 @@ def propagate_equivalences(cipher, round_id, component_id, new_expanded_links, n new_input_positions = [new_positions[i] for i in old_positions] unique_lengths, unique_links = get_unique_links_information(new_links) final_input_positions = get_final_input_positions(new_input_positions, unique_lengths) - input_id_links = input_id_link[:id_index] + unique_links \ - + input_id_link[id_index + 1:] + input_id_links = input_id_link[:id_index] + unique_links + input_id_link[id_index + 1 :] component.set_input_id_links(input_id_links) input_bit_positions = component.input_bit_positions - component.set_input_bit_positions(input_bit_positions[:id_index] \ - + final_input_positions \ - + input_bit_positions[id_index + 1:]) + component.set_input_bit_positions( + input_bit_positions[:id_index] + final_input_positions + input_bit_positions[id_index + 1 :] + ) while [] in component.input_bit_positions: component.input_bit_positions.remove([]) @@ -1502,22 +1698,28 @@ def propagate_permutations(cipher): for round_ in cipher_without_permutations.rounds_as_list: for component in round_.components: if component.type == LINEAR_LAYER: - M = component.description - number_of_rows = len(M) - number_of_columns = len(M[0]) - M_is_square = (number_of_rows == number_of_columns) - if M_is_square: - M_T = [[M[i][j] for i in range(number_of_rows)] for j in range(number_of_columns)] - if is_linear_layer_permutation(M, M_T): + matrix = component.description + nrows = len(matrix) + ncols = len(matrix[0]) + matrix_is_square = nrows == ncols + if matrix_is_square: + matrix_transposed = [[matrix[i][j] for i in range(nrows)] for j in range(ncols)] + if is_linear_layer_permutation(matrix, matrix_transposed): ids_of_permutations.append(component.id) input_bit_positions = component.input_bit_positions expanded_links = generate_expanded_links(component, input_bit_positions) - flat_input_bit_positions = [position for positions in input_bit_positions - for position in positions] - new_expanded_links = [expanded_links[row.index(1)] for row in M_T] - new_positions = [flat_input_bit_positions[row.index(1)] for row in M_T] - propagate_equivalences(cipher_without_permutations, round_.id, component.id, - new_expanded_links, new_positions) + flat_input_bit_positions = [ + position for positions in input_bit_positions for position in positions + ] + new_expanded_links = [expanded_links[row.index(1)] for row in matrix_transposed] + new_positions = [flat_input_bit_positions[row.index(1)] for row in matrix_transposed] + propagate_equivalences( + cipher_without_permutations, + round_.id, + component.id, + new_expanded_links, + new_positions, + ) return (ids_of_permutations, cipher_without_permutations) @@ -1525,18 +1727,22 @@ def propagate_rotations(cipher): cipher_without_rotations = deepcopy(cipher) for round_ in cipher_without_rotations.rounds_as_list: for component in round_.components: - if component.description[0] == 'ROTATE': + if component.description[0] == "ROTATE": input_bit_positions = component.input_bit_positions expanded_links = [] for link, positions in zip(component.input_id_links, input_bit_positions): expanded_links.extend([link] * len(positions)) - flat_input_bit_positions = [position for positions in input_bit_positions - for position in positions] + flat_input_bit_positions = [position for positions in input_bit_positions for position in positions] amount = component.description[1] new_expanded_links = expanded_links[-amount:] + expanded_links[:-amount] new_positions = flat_input_bit_positions[-amount:] + flat_input_bit_positions[:-amount] - propagate_equivalences(cipher_without_rotations, round_.id, component.id, - new_expanded_links, new_positions) + propagate_equivalences( + cipher_without_rotations, + round_.id, + component.id, + new_expanded_links, + new_positions, + ) return cipher_without_rotations @@ -1618,7 +1824,10 @@ def remove_key_schedule(cipher, keep_round_key_injection=True): for round_ in cipher_without_key_schedule.rounds_as_list: for component in round_.components: if any("key" in id for id in component.input_id_links): - key_index = next((i for i, link in enumerate(component.input_id_links) if "key" in link), None) + key_index = next( + (i for i, link in enumerate(component.input_id_links) if "key" in link), + None, + ) component.input_id_links.pop(key_index) component.input_bit_positions.pop(key_index) if len(component.input_bit_positions) == 1: @@ -1806,7 +2015,7 @@ def remove_rotations(cipher): cipher_without_rotations = propagate_rotations(cipher) for round_ in cipher.rounds_as_list: for component in round_.components: - if component.description[0] == 'ROTATE': + if component.description[0] == "ROTATE": cipher_without_rotations.remove_round_component_from_id(round_.id, component.id) return cipher_without_rotations @@ -1913,8 +2122,10 @@ def sort_cipher(cipher): for i in range(cipher.number_of_rounds): current_round = cipher.rounds.round_at(i) for fixed_index in range(current_round.get_number_of_components()): - for moving_index in range(next_component_index_from(fixed_index), - current_round.number_of_components): + for moving_index in range( + next_component_index_from(fixed_index), + current_round.number_of_components, + ): if current_round.is_component_input(fixed_index, moving_index): current_round.swap_components(fixed_index, moving_index) @@ -1931,7 +2142,7 @@ def update_component_inputs(component, component_id, parent_links): parent_links.add(component.id) input_id_links = component.input_id_links for i in range(len(input_id_links)): - if input_id_links[i] not in parent_links and input_id_links[i] != '': + if input_id_links[i] not in parent_links and input_id_links[i] != "": input_id_links[i] = component_id bit_len = len(component.input_bit_positions[i]) component.input_bit_positions[i] = list(range(offset, bit_len + offset)) @@ -1944,11 +2155,12 @@ def update_inputs(cipher_without_key_schedule, keep_round_key_addition): parent_links = set(cipher_without_key_schedule.inputs) for cipher_round in cipher_without_key_schedule.rounds_as_list: for index, component in enumerate(cipher_round.components): - component_id = f'key_{cipher_round.id}_{index}' + component_id = f"key_{cipher_round.id}_{index}" modified, offset = update_component_inputs(component, component_id, parent_links) if keep_round_key_addition: update_cipher_inputs(cipher_without_key_schedule, component_id, modified, offset) + def get_output_bit_size_from_id(cipher_list, component_id): try: for cipher in cipher_list: @@ -1956,8 +2168,6 @@ def get_output_bit_size_from_id(cipher_list, component_id): return cipher.inputs_bit_size[cipher.inputs.index(component_id)] elif component_id in cipher.get_all_components_ids(): return cipher.get_component_from_id(component_id).output_bit_size - raise ValueError(f'{component_id} not found.') + raise ValueError(f"{component_id} not found.") except ValueError as e: sys.exit(str(e)) - - diff --git a/claasp/input.py b/claasp/input.py index 7acd5ceb5..aaf6551c5 100644 --- a/claasp/input.py +++ b/claasp/input.py @@ -1,17 +1,16 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** diff --git a/claasp/name_mappings.py b/claasp/name_mappings.py index c49b3d674..587984b92 100644 --- a/claasp/name_mappings.py +++ b/claasp/name_mappings.py @@ -1,11 +1,11 @@ -# cipher_type +# cipher types BLOCK_CIPHER = "block_cipher" STREAM_CIPHER = "stream_cipher" TWEAKABLE_BLOCK_CIPHER = "tweakable_block_cipher" PERMUTATION = "permutation" HASH_FUNCTION = "hash_function" -# CIPHER INPUTS +# cipher inputs INPUT_KEY = "key" INPUT_PLAINTEXT = "plaintext" INPUT_INITIALIZATION_VECTOR = "initialization_vector" @@ -30,12 +30,16 @@ FSR = "fsr" # model types -CIPHER = 'cipher' -XOR_DIFFERENTIAL = 'xor_differential' -XOR_LINEAR = 'xor_linear' -DETERMINISTIC_TRUNCATED_XOR_DIFFERENTIAL = 'deterministic_truncated_xor_differential' -IMPOSSIBLE_XOR_DIFFERENTIAL = 'impossible_xor_differential' -BOOMERANG_XOR_DIFFERENTIAL = 'boomerang_xor_differential' +CIPHER = "cipher" +XOR_DIFFERENTIAL = "xor_differential" +XOR_LINEAR = "xor_linear" +DETERMINISTIC_TRUNCATED_XOR_DIFFERENTIAL = "deterministic_truncated_xor_differential" +IMPOSSIBLE_XOR_DIFFERENTIAL = "impossible_xor_differential" +BOOMERANG_XOR_DIFFERENTIAL = "boomerang_xor_differential" # cipher inverse -CIPHER_INVERSE_SUFFIX = "_inverse" \ No newline at end of file +CIPHER_INVERSE_SUFFIX = "_inverse" + +# models +SATISFIABLE = "SATISFIABLE" +UNSATISFIABLE = "UNSATISFIABLE" diff --git a/claasp/round.py b/claasp/round.py index 3242a510b..721eefb99 100644 --- a/claasp/round.py +++ b/claasp/round.py @@ -1,17 +1,16 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -28,8 +27,7 @@ def add_component(self, component): def are_there_forbidden_components(self, forbidden_types, forbidden_descriptions): is_there_forbidden_component = False for component in self._components: - is_there_forbidden_component = component.is_forbidden(forbidden_types, - forbidden_descriptions) + is_there_forbidden_component = component.is_forbidden(forbidden_types, forbidden_descriptions) if is_there_forbidden_component: return is_there_forbidden_component @@ -56,8 +54,7 @@ def get_round_from_component_id(self, component_id): return self._id def is_component_input(self, fixed_index, moving_index): - return self._components[moving_index].id in \ - self._components[fixed_index].input_id_links + return self._components[moving_index].id in self._components[fixed_index].input_id_links def is_power_of_2_word_based(self, dto): for component in self._components: @@ -69,8 +66,7 @@ def is_power_of_2_word_based(self, dto): def print_round(self): for component_number in range(self.number_of_components): - print("\n # round = {} - round component = {}" - .format(self._id, component_number)) + print(f"\n # round = {self._id} - round component = {component_number}") requested_component = self.component_from(component_number) requested_component.print() @@ -123,4 +119,3 @@ def update_input_id_links_from_component_id(self, component_id, new_input_id_lin break i += 1 self._components[i].set_input_id_links(new_input_id_links) - diff --git a/claasp/rounds.py b/claasp/rounds.py index e346611f9..5aa1b5d68 100644 --- a/claasp/rounds.py +++ b/claasp/rounds.py @@ -1,17 +1,16 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -37,8 +36,9 @@ def add_round(self): def are_there_not_forbidden_components(self, forbidden_types, forbidden_descriptions): for cipher_round in self._rounds: - are_there_forbidden_components = cipher_round.are_there_forbidden_components(forbidden_types, - forbidden_descriptions) + are_there_forbidden_components = cipher_round.are_there_forbidden_components( + forbidden_types, forbidden_descriptions + ) if are_there_forbidden_components: return not are_there_forbidden_components @@ -115,7 +115,7 @@ def get_component_from_id(self, component_id): component = cipher_round.get_component_from_id(component_id) if component is not None: return component - raise ValueError(f'Component with id {component_id} not found.') + raise ValueError(f"Component with id {component_id} not found.") def get_round_from_component_id(self, component_id): for cipher_round in self._rounds: diff --git a/claasp/utils/integer.py b/claasp/utils/integer.py index c46eeb188..546e4260b 100644 --- a/claasp/utils/integer.py +++ b/claasp/utils/integer.py @@ -1,17 +1,16 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** diff --git a/claasp/utils/integer_functions.py b/claasp/utils/integer_functions.py index 68e4ed45e..6c15dfad2 100644 --- a/claasp/utils/integer_functions.py +++ b/claasp/utils/integer_functions.py @@ -1,28 +1,26 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** - -def bytearray_to_int(data, endianess='big'): +def bytearray_to_int(data, endianess="big"): return int.from_bytes(data, endianess, signed=False) -def int_to_bytearray(data, size, endianess='big'): +def int_to_bytearray(data, size, endianess="big"): return bytearray(data.to_bytes(size // 8, endianess)) @@ -42,26 +40,26 @@ def wordlist_to_bytearray(data, word_size, size=None): return int_to_bytearray(data_int, size) -def int_to_wordlist(value, word_size, size, endianess='big'): +def int_to_wordlist(value, word_size, size, endianess="big"): wordlist = [] for _ in range(size // word_size): wordlist.append(value % 2**word_size) value = value >> word_size - if endianess == 'big': + if endianess == "big": wordlist.reverse() return wordlist -def wordlist_to_int(wordlist, word_size, endianess='big'): +def wordlist_to_int(wordlist, word_size, endianess="big"): value = 0 - if endianess == 'little': - for i in range(len(wordlist)): - value += wordlist[i] * 2**(word_size * i) - else: - for i in range(len(wordlist)): - value += wordlist[i] * 2**(word_size * (len(wordlist) - i - 1)) + if endianess == "little": + ordered_list = wordlist + elif endianess == "big": + ordered_list = reversed(wordlist) + for i, word in enumerate(ordered_list): + value += word * 2 ** (word_size * i) return value diff --git a/claasp/utils/sage_scripts.py b/claasp/utils/sage_scripts.py index 7022a386f..f9ee076a2 100644 --- a/claasp/utils/sage_scripts.py +++ b/claasp/utils/sage_scripts.py @@ -1,46 +1,46 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** - import yaml from os import listdir +from claasp.name_mappings import BLOCK_CIPHER, HASH_FUNCTION, PERMUTATION, STREAM_CIPHER + def get_cipher(cipher_module): for name in cipher_module.__dict__: - if 'BlockCipher' in name or 'HashFunction' in name or 'Permutation' in name: + if "BlockCipher" in name or "HashFunction" in name or "Permutation" in name: return cipher_module.__dict__[name] return 0 def get_ciphers(): - ciphers_files = listdir('claasp/ciphers/block_ciphers') - ciphers_files.extend(listdir('claasp/ciphers/permutations')) - ciphers_files.extend(listdir('claasp/ciphers/hash_functions')) - ciphers_files.extend(listdir('claasp/ciphers/stream_ciphers')) + ciphers_files = listdir("claasp/ciphers/block_ciphers") + ciphers_files.extend(listdir("claasp/ciphers/permutations")) + ciphers_files.extend(listdir("claasp/ciphers/hash_functions")) + ciphers_files.extend(listdir("claasp/ciphers/stream_ciphers")) ciphers_files = list(set(ciphers_files)) ciphers = [cipher for cipher in ciphers_files if get_cipher_type(cipher)] return ciphers def make_cipher_id(cipher_family_name, inputs, inputs_bit_size, output_bit_size): - cipher_id = f'{cipher_family_name}' + cipher_id = f"{cipher_family_name}" for i in range(len(inputs)): cipher_id = cipher_id + "_" + inputs[i][0] + str(inputs_bit_size[i]) cipher_id = cipher_id + "_o" + str(output_bit_size) @@ -52,7 +52,7 @@ def create_scenario_string(scenario_dict): final = [] conversions = {">": "greater", "=": "equal", ">=": "greater_or_equal"} for key, value in scenario_dict.items(): - final.append(f'{key}_{conversions[value]}') + final.append(f"{key}_{conversions[value]}") final.sort() return "_".join(final) @@ -68,13 +68,13 @@ def load_parameters(file_path): def get_cipher_type(cipher_filename): cipher_type = "" - if "block_cipher" in cipher_filename: + if BLOCK_CIPHER in cipher_filename: cipher_type = "block_ciphers" - elif "permutation" in cipher_filename: + elif PERMUTATION in cipher_filename: cipher_type = "permutations" - elif "hash_function" in cipher_filename: + elif HASH_FUNCTION in cipher_filename: cipher_type = "hash_functions" - elif "stream_cipher" in cipher_filename: + elif STREAM_CIPHER in cipher_filename: cipher_type = "stream_ciphers" return cipher_type diff --git a/claasp/utils/sequence_operations.py b/claasp/utils/sequence_operations.py index 253d9df1f..c31d8071d 100644 --- a/claasp/utils/sequence_operations.py +++ b/claasp/utils/sequence_operations.py @@ -1,17 +1,16 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** diff --git a/claasp/utils/templates.py b/claasp/utils/templates.py index 71221d1cf..0f58aad9d 100644 --- a/claasp/utils/templates.py +++ b/claasp/utils/templates.py @@ -1,17 +1,16 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** @@ -52,10 +51,10 @@ def get_template(self): class Template: - def __init__(self): - self._j2_env = Environment(loader=FileSystemLoader('claasp/utils/tii_reports'), - trim_blocks=True, autoescape=True) + self._j2_env = Environment( + loader=FileSystemLoader("claasp/utils/tii_reports"), trim_blocks=True, autoescape=True + ) self.__header = None self.__footer = None self.__body = None @@ -70,10 +69,9 @@ def set_footer(self, footer): self.__footer = footer def render_template(self, rule_data_): - return self._j2_env.get_template(rule_data_['template_path']).render(header=self.__header.content, - body=self.__body.content, - footer=self.__footer.content, - rule_data=rule_data_) + return self._j2_env.get_template(rule_data_["template_path"]).render( + header=self.__header.content, body=self.__body.content, footer=self.__footer.content, rule_data=rule_data_ + ) class Builder(object): @@ -84,9 +82,14 @@ class Builder(object): """ - def get_header(self): pass - def get_footer(self): pass - def get_body(self): pass + def get_header(self): + pass + + def get_footer(self): + pass + + def get_body(self): + pass class LatexBuilder(Builder): @@ -102,13 +105,13 @@ def __init__(self, data): def get_header(self): header = Header() - header.content = 'TII - Latex - Report' + header.content = "TII - Latex - Report" return header def get_footer(self): footer = Footer() - footer.content = '' + footer.content = "" return footer @@ -132,13 +135,13 @@ def __init__(self, data): def get_header(self): header = Header() - header.content = 'TII - CSV - Report' + header.content = "TII - CSV - Report" return header def get_footer(self): footer = Footer() - footer.content = '' + footer.content = "" return footer diff --git a/claasp/utils/utils.py b/claasp/utils/utils.py index cf8c0f2b1..de2827569 100644 --- a/claasp/utils/utils.py +++ b/claasp/utils/utils.py @@ -1,22 +1,20 @@ - # **************************************************************************** # Copyright 2023 Technology Innovation Institute -# +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . # **************************************************************************** - import json import pprint import random @@ -84,7 +82,7 @@ def aggregate_list_of_dictionary(dataset, group_by_key, sum_value_keys): def bytes_positions_to_little_endian_for_32_bits(lst): - r""" + """ Read the bytes positions in little-endian order. INPUT: @@ -102,7 +100,7 @@ def bytes_positions_to_little_endian_for_32_bits(lst): temp_lst = [] for j in range(4): - temp_lst += lst[(3 - j) * 8:(3 - j) * 8 + 8] + temp_lst += lst[(3 - j) * 8 : (3 - j) * 8 + 8] return temp_lst @@ -110,7 +108,7 @@ def bytes_positions_to_little_endian_for_32_bits(lst): def bytes_positions_to_little_endian_for_multiple_of_32(lst, number_of_blocks): output_lst = [] for block_number in range(number_of_blocks): - temp_lst = lst[block_number * 32:block_number * 32 + 32] + temp_lst = lst[block_number * 32 : block_number * 32 + 32] temp2_lst = bytes_positions_to_little_endian_for_32_bits(temp_lst) output_lst.append(temp2_lst) @@ -189,7 +187,7 @@ def get_2d_array_element_from_1d_array_index(i, lst, array_dim): def get_ci(i, qi, si, t): q = qi[(i % 7)] s = si[(i % 6)] - ci = t ** s * (q + t ** 3) + ci = t**s * (q + t**3) _ci = ci.change_ring(IntegerRing()) return _ci(2) @@ -221,8 +219,8 @@ def get_number_of_rounds_from(block_bit_size, key_bit_size, number_of_rounds, pa if number_of_rounds == 0: n = None for parameters in parameters_configurations: - if parameters['block_bit_size'] == block_bit_size and parameters['key_bit_size'] == key_bit_size: - n = parameters['number_of_rounds'] + if parameters["block_bit_size"] == block_bit_size and parameters["key_bit_size"] == key_bit_size: + n = parameters["number_of_rounds"] break if n is None: raise ValueError("No available number of rounds for the given parameters.") @@ -232,24 +230,6 @@ def get_number_of_rounds_from(block_bit_size, key_bit_size, number_of_rounds, pa return n -def get_k_th_bit(n, k): - """ - Return the k-th bit of the number n. - - INPUT: - - - ``n`` -- **integer**; integer number - - ``k`` -- **integer**; integer number representing the index of the bit we need - - EXAMPLES:: - - sage: from claasp.utils.utils import get_k_th_bit - sage: get_k_th_bit(3, 0) - 1 - """ - return 1 & (n >> k) - - def group_list_by_key(lst): """ Group list of dictionaries by key. @@ -266,6 +246,7 @@ def group_list_by_key(lst): defaultdict(, {'cipher_output': [[{'1': 0}], [{'2': 0}], [{'4': 0}]], 'round_key_output': [[{'1': 0}], [{'3': 0}], [{'2': 0}]]}) """ from collections import defaultdict + joint_results_objects_group_by_tag_output = defaultdict(list) for value in lst: for key, item in value.items(): @@ -287,9 +268,10 @@ def layer_and_lane_initialization(plane_num=3, lane_num=4, lane_size=32): planes = [] plane_size = lane_num * lane_size for i in range(plane_num): - p = ComponentState([INPUT_PLAINTEXT for _ in range(lane_num)], - [[k + j * lane_size + i * plane_size for k in range(lane_size)] - for j in range(lane_num)]) + p = ComponentState( + [INPUT_PLAINTEXT for _ in range(lane_num)], + [[k + j * lane_size + i * plane_size for k in range(lane_size)] for j in range(lane_num)], + ) planes.append(p) return planes @@ -315,7 +297,7 @@ def merging_list_of_lists(lst): def pprint_dictionary(dictionary): - r""" + """ Pretty-print of a dictionary. INPUT: @@ -338,7 +320,7 @@ def pprint_dictionary(dictionary): def pprint_dictionary_to_file(dictionary, name_file): - r""" + """ Pretty-print of a dictionary. INPUT: @@ -367,10 +349,10 @@ def pprint_dictionary_to_file(dictionary, name_file): sage: os.remove(f"{tii_dir_path}/test_json") """ - if 'cipher' in dictionary.keys(): - dictionary['cipher'] = dictionary['cipher'].id + if "cipher" in dictionary.keys(): + dictionary["cipher"] = dictionary["cipher"].id dictionary_json = json.loads(str(dictionary).replace("'", '"')) - source_file = open(name_file, 'w') + source_file = open(name_file, "w") print(json.dumps(dictionary_json, indent=4), file=source_file) source_file.close() @@ -416,9 +398,7 @@ def signed_distance(lst_x, lst_y): sage: signed_distance(lst_x, lst_y) 0 """ - n = len(lst_x) - - return sum([abs(sgn_function(lst_x[i]) - sgn_function(lst_y[i])) for i in range(n)]) + return sum(abs(sgn_function(i) - sgn_function(j)) for i, j in zip(lst_x, lst_y)) def simplify_inputs(inputs_id, inputs_pos): @@ -471,7 +451,7 @@ def poly_to_int(polynom, word_size, a): str_poly = str_poly.split(" + ") binary_lst = [] for i in range(word_size): - tmp = a ** i + tmp = a**i if str(tmp) in str_poly: binary_lst.append("1") else: diff --git a/docker/Dockerfile b/docker/Dockerfile index 668ec4c1d..6736240a3 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -9,7 +9,8 @@ RUN apt-get -q update RUN apt-get install --no-install-recommends -y \ libboost-program-options-dev \ libsqlite3-dev \ - libstdc++-9-dev + libstdc++-9-dev \ + ca-certificates RUN apt-get install --no-install-recommends -y \ gfortran \ @@ -27,6 +28,10 @@ RUN apt-get install -y \ git \ wget +RUN apt-get update && apt-get install -y \ + chromium-browser \ + chromium-chromedriver + RUN apt-get install -y \ dieharder=3.31.1.2-1build1 \ latexmk=1:4.76-1 \ @@ -48,8 +53,7 @@ ENV LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:${GUROBI_HOME}/lib" WORKDIR /opt -## Installing Soplex (for SCIP) - +# Installing Soplex (for SCIP) RUN wget https://github.com/scipopt/soplex/archive/refs/tags/release-603.tar.gz \ && tar -xf release-603.tar.gz \ && rm release-603.tar.gz \ @@ -61,20 +65,19 @@ RUN wget https://github.com/scipopt/soplex/archive/refs/tags/release-603.tar.gz ENV SOPLEX_HOME="/opt/soplex-release-603/build" -## Installing HiGHS - +# Installing HiGHS RUN wget https://github.com/ERGO-Code/HiGHS/releases/download/v1.10.0/source-archive.tar.gz \ && tar -xf source-archive.tar.gz \ && rm source-archive.tar.gz \ && cd HiGHS \ && cmake -S. -B build \ - && cmake --build build --parallel + && cmake --build build ENV PATH="${PATH}:/opt/HiGHS/build/bin" WORKDIR /opt -## Installing SCIP +# Installing SCIP RUN wget https://github.com/scipopt/scip/archive/refs/tags/v803.tar.gz \ && tar -xf v803.tar.gz \ && rm v803.tar.gz \ @@ -88,6 +91,7 @@ ENV PATH="${PATH}:/opt/scip-803/build/bin" # Installing SageMath tools RUN sage -pip install bitstring==4.0.1 \ + kaleido==1.0.0 \ keras==2.13.1 \ minizinc==0.5.0 \ pandas==1.5.2 \ @@ -100,6 +104,7 @@ RUN sage -pip install bitstring==4.0.1 \ sphinx==5.0.0 \ sphinxcontrib-bibtex==2.5.0 \ tensorflow==2.13.0 \ + plotly==6.2.0 \ pytest==7.2.1 \ pytest-cov==4.0.0 \ pytest-xdist==3.2.0 \ @@ -108,7 +113,8 @@ RUN sage -pip install bitstring==4.0.1 \ numpy==1.24.3 \ joblib==1.4.2 \ gurobipy==11.0 \ - pytest-isolate==0.0.11 + pytest-isolate==0.0.11 \ + psutil==7.1.3 # Installing nist sts RUN curl -O -s https://csrc.nist.gov/CSRC/media/Projects/Random-Bit-Generation/documents/sts-2_1_2.zip \ @@ -122,27 +128,93 @@ RUN cd /opt/sts-2.1.2/sts-2.1.2 \ && ln -s /usr/local/bin/sts-2.1.2/assess /usr/local/bin/niststs \ && rm /opt/sts-2_1_2.zip +# Installing Minizinc +RUN wget https://github.com/MiniZinc/MiniZincIDE/releases/download/2.9.4/MiniZincIDE-2.9.4-bundle-linux-x86_64.tgz \ + && pwd && tar -xf MiniZincIDE-2.9.4-bundle-linux-x86_64.tgz +RUN mv MiniZincIDE-2.9.4-bundle-linux-x86_64 /MiniZinc \ + && rm MiniZincIDE-2.9.4-bundle-linux-x86_64.tgz + + +# Copy all MiniZinc components to system locations +RUN cp /MiniZinc/bin/* /usr/local/bin/ +RUN mkdir -p /usr/local/share/minizinc && cp -r /MiniZinc/share/minizinc/* /usr/local/share/minizinc/ + +RUN printf '%s\n' \ + '#!/bin/sh' \ + 'export LD_LIBRARY_PATH=/MiniZinc/lib:"$LD_LIBRARY_PATH"' \ + 'export SCIPOPTDIR=/opt/scip-803/build' \ + 'export PATH="$SCIPOPTDIR/bin:$PATH"' \ + 'export LD_LIBRARY_PATH="$SCIPOPTDIR/lib:$LD_LIBRARY_PATH"' \ + 'exec /MiniZinc/bin/minizinc "$@"' \ + > /usr/local/bin/minizinc-wrapper \ + && chmod +x /usr/local/bin/minizinc-wrapper \ + && ln -sf /usr/local/bin/minizinc-wrapper /usr/local/bin/minizinc + + +# Installing Chuffed +RUN wget https://github.com/chuffed/chuffed/archive/refs/tags/0.13.2.tar.gz \ + && tar -xf 0.13.2.tar.gz \ + && rm 0.13.2.tar.gz + +RUN cd chuffed-0.13.2 \ + && cmake -B build -S . \ + && cmake --build build \ + && cmake --build build --target install + + +# Installing Choco WORKDIR /opt -# Installing Minizinc -RUN wget https://github.com/MiniZinc/MiniZincIDE/releases/download/2.6.4/MiniZincIDE-2.6.4-bundle-linux-x86_64.tgz \ - && tar -xf MiniZincIDE-2.6.4-bundle-linux-x86_64.tgz \ - && rm MiniZincIDE-2.6.4-bundle-linux-x86_64.tgz +RUN wget https://github.com/chocoteam/choco-solver/archive/refs/tags/v4.10.18.tar.gz \ + && tar -xf v4.10.18.tar.gz \ + && rm v4.10.18.tar.gz + +RUN mv choco-solver-4.10.18/parsers/src/main/minizinc/fzn-choco* /usr/local/bin \ + && mkdir /usr/local/share/minizinc/choco \ + && mv choco-solver-4.10.18/parsers/src/main/minizinc/mzn_lib/* /usr/local/share/minizinc/choco -ENV PATH="/opt/MiniZincIDE-2.6.4-bundle-linux-x86_64/bin:${PATH}" +RUN wget https://github.com/chocoteam/choco-solver/releases/download/v4.10.18/choco-solver-4.10.18-light.jar + +# Update files +RUN sed -i 's&JAR_FILE=.*&JAR_FILE="/opt/choco-solver-4.10.18-light.jar"&g' /usr/local/bin/fzn-choco.py + +RUN echo '\ +{\n\ + "name": "Choco-solver",\n\ + "version": "4.10.18",\n\ + "id": "org.choco.choco",\n\ + "executable": "/usr/local/bin/fzn-choco.sh",\n\ + "mznlib": "/usr/local/share/minizinc/choco",\n\ + "tags": ["cp","int"],\n\ + "stdFlags": ["-a","-n","-s","-p","-r","-f","-t","--cp-profiler"]\n\ +}\ +' > /usr/local/share/minizinc/solvers/choco.msc # Installing CryptoMiniSat -RUN wget https://github.com/msoos/cryptominisat/archive/refs/tags/5.11.4.tar.gz \ - && tar -xf 5.11.4.tar.gz \ - && rm 5.11.4.tar.gz +# explicit commits are used instead of releases because +# no semantic versioning is used for cadical and cadiback +WORKDIR /opt -RUN cd cryptominisat-5.11.4 \ - && mkdir build \ - && cd build \ - && cmake .. \ - && make \ - && make install \ - && ldconfig +RUN git clone https://github.com/meelgroup/cadical.git \ + && cd cadical \ + && git checkout f5ac6ffa6eaaf72b407f6fb591091c5ff02271d8 \ + && CXXFLAGS=-fPIC ./configure --competition \ + && make -j4 \ + && cp build/libcadical.so /usr/lib/ + +RUN git clone https://github.com/meelgroup/cadiback.git \ + && cd cadiback \ + && git checkout 6050733173995f3cff6ecb04ac99d3691bd0b5e4 \ + && CXX=c++ ./configure \ + && make -j4 \ + && cp libcadiback.so /usr/lib/ + +RUN git clone https://github.com/msoos/cryptominisat.git \ + && cd cryptominisat \ + && git checkout 34ec41b3bb0332bd313e51d131298d30c52aa71f \ + && cmake -DENABLE_TESTING=OFF -DIPASIR=ON -B build -S . \ + && cmake --build build \ + && cmake --build build --target install WORKDIR /opt @@ -174,8 +246,10 @@ WORKDIR /opt ENV PATH="/opt/ParKissat-RS-master/:${PATH}" # Installing Glucose -RUN wget https://www.labri.fr/perso/lsimon/downloads/softwares/glucose-syrup-4.1.tgz \ - && tar -xf glucose-syrup-4.1.tgz \ +RUN wget https://www.labri.fr/perso/lsimon/downloads/softwares/glucose-syrup-4.1.tgz || \ + wget https://fosszone.csd.auth.gr/sagemath/spkg/upstream/glucose/glucose-syrup-4.1.tgz + +RUN tar -xf glucose-syrup-4.1.tgz \ && rm glucose-syrup-4.1.tgz RUN cd glucose-syrup-4.1/simp \ @@ -231,11 +305,11 @@ RUN cd cadical-rel-1.5.3 \ && ./configure \ && make -WORKDIR /opt - ENV PATH="/opt/cadical-rel-1.5.3/build:${PATH}" -# Installing Yices-Sat +# Installing Yices +WORKDIR /opt + RUN wget https://yices.csl.sri.com/releases/2.6.4/yices-2.6.4-x86_64-pc-linux-gnu.tar.gz \ && tar -xf yices-2.6.4-x86_64-pc-linux-gnu.tar.gz \ && rm yices-2.6.4-x86_64-pc-linux-gnu.tar.gz @@ -243,90 +317,17 @@ RUN wget https://yices.csl.sri.com/releases/2.6.4/yices-2.6.4-x86_64-pc-linux-gn RUN cd yices-2.6.4 \ && ./install-yices -WORKDIR /opt - -# Installing or-tools -RUN wget https://github.com/google/or-tools/releases/download/v9.2/or-tools_amd64_flatzinc_ubuntu-21.10_v9.2.9972.tar.gz \ - && tar -xf or-tools_amd64_flatzinc_ubuntu-21.10_v9.2.9972.tar.gz \ - && rm or-tools_amd64_flatzinc_ubuntu-21.10_v9.2.9972.tar.gz - -RUN mkdir -p /opt/minizinc/solvers/s - -RUN echo '\ -{ \n\ -"executable": "/opt/or-tools_flatzinc_Ubuntu-21.10-64bit_v9.2.9972/bin/fzn-or-tools", \n\ -"id": "Xor", \n\ -"isGUIApplication": false, \n\ -"mznlib": "/opt/or-tools_flatzinc_Ubuntu-21.10-64bit_v9.2.9972/lib", \n\ -"mznlibVersion": 1, \n\ -"name": "Xor", \n\ -"needsMznExecutable": false, \n\ -"needsPathsFile": false, \n\ -"needsSolns2Out": true, \n\ -"needsStdlibDir": false, \n\ -"stdFlags": [ \n\ - "-a", \n\ - "-p", \n\ - "-r", \n\ - "-f" \n\ -], \n\ -"supportsFzn": true, \n\ -"supportsMzn": false, \n\ -"supportsNL": false, \n\ -"version": "8.2" } \ -' > /opt/minizinc/solvers/Xor.msc - -# Installing Choco - -# Copy Choco's executable from the previous stage across -RUN wget https://github.com/chocoteam/choco-solver/archive/refs/tags/v4.10.12.tar.gz \ - && tar -xf v4.10.12.tar.gz \ - && rm v4.10.12.tar.gz - -RUN wget https://github.com/chocoteam/choco-solver/releases/download/v4.10.12/choco-solver-4.10.12.jar - -# Update files -RUN sed -i 's&CHOCO_JAR=.*&CHOCO_JAR=/opt/choco-solver-4.10.12.jar&g' /opt/choco-solver-4.10.12/parsers/src/main/minizinc/fzn-choco && \ - sed -i 's&"mznlib".*&"mznlib":"/opt/choco-solver-4.10.12/parsers/src/main/minizinc/mzn-lib/",&g' /opt/choco-solver-4.10.12/parsers/src/main/minizinc/choco.msc && \ - sed -i 's&"executable".*&"executable":"/opt/choco-solver-4.10.12/parsers/src/main/minizinc/fzn-choco",&g' /opt/choco-solver-4.10.12/parsers/src/main/minizinc/choco.msc - -ENV PATH="/opt/choco-solver-4.10.12:${PATH}" - -RUN echo '\ -{ \n\ - "id": "org.choco.choco", \n\ - "name": "Choco-solver", \n\ - "description": "Choco FlatZinc executable", \n\ - "version": "4.10.12", \n\ - "mznlib": "/opt/choco-solver-4.10.12/parsers/src/main/minizinc/mzn_lib", \n\ - "executable": "/opt/choco-solver-4.10.12/parsers/src/main/minizinc/fzn-choco", \n\ - "tags": ["cp","int"], \n\ - "stdFlags": ["-a","-f","-n","-p","-r","-s","-t"], \n\ - "supportsMzn": false, \n\ - "supportsFzn": true, \n\ - "needsSolns2Out": true, \n\ - "needsMznExecutable": false, \n\ - "needsStdlibDir": false, \n\ - "isGUIApplication": false \n\ -} \ -' > /opt/minizinc/solvers/choco.msc - -ENV MZN_SOLVER_PATH="/opt/minizinc/solvers" - -ENV LD_LIBRARY_PATH="/opt/MiniZincIDE-2.6.4-bundle-linux-x86_64/lib:${LD_LIBRARY_PATH}" - -RUN rm -rf /opt/MiniZincIDE-2.6.4-bundle-linux-x86_64/lib/liblzma.so.5 -RUN rm -rf /opt/MiniZincIDE-2.6.4-bundle-linux-x86_64/lib/libselinux.so.1 -RUN rm -rf /opt/MiniZincIDE-2.6.4-bundle-linux-x86_64/lib/libsystemd.so.0 -RUN rm -rf /opt/MiniZincIDE-2.6.4-bundle-linux-x86_64/lib/libcrypt.so.1 - -RUN sage -pip install plotly -U kaleido COPY required_dependencies/sage_numerical_backends_gurobi-9.3.1.tar.gz /opt/ RUN cd /opt/ && sage -pip install sage_numerical_backends_gurobi-9.3.1.tar.gz RUN apt-get install -y coinor-cbc coinor-libcbc-dev RUN sage -python -m pip install sage-numerical-backends-coin==9.0b12 +# clean image +RUN rm -r /opt/choco-solver-4.10.18 +RUN rm -r /opt/chuffed-0.13.2 + + FROM claasp-base AS claasp-lib # Create a non-root user "sage" with home directory @@ -343,4 +344,4 @@ WORKDIR /home/${NAME}/tii-claasp COPY . . RUN make install -ENV TERM=xterm-color \ No newline at end of file +ENV TERM=xterm-color diff --git a/docs/create_rst_structure.py b/docs/create_rst_structure.py index c916b76cd..317ba0cf2 100644 --- a/docs/create_rst_structure.py +++ b/docs/create_rst_structure.py @@ -7,55 +7,45 @@ EXCLUDED_FOLDERS = ["__pycache__", "DTOs", "tii_reports"] EXCLUDED_FILES = ["__init__.py", "constants.py", ".DS_Store", "name_mappings.py", "finalAnalysisReportExample.txt"] EXCLUDED_EXTENSIONS = [".md"] -ROOT_FOLDER = '../claasp/' +ROOT_FOLDER = "../claasp/" SOURCE_ROOT_FOLDER = "./source/" Path(SOURCE_ROOT_FOLDER).mkdir(exist_ok=True) -IS_HTML = sys.argv[1] == 'html' +IS_HTML = sys.argv[1] == "html" REFERENCES_EXTENSION = "rst" if IS_HTML else "bib" copyfile("conf.py", Path("source", "conf.py")) -copyfile('references.rst', Path("source", f'references.{REFERENCES_EXTENSION}')) +copyfile("references.rst", Path("source", f"references.{REFERENCES_EXTENSION}")) def header_style(section, level): if not section: return "" - sections = { - 0: "=", - 1: "-", - 2: "=", - 3: "-", - 4: "`", - 5: "'", - 6: ".", - 7: "~", - 8: "*", - 9: "+", - 10: "^" - } + sections = {0: "=", 1: "-", 2: "=", 3: "-", 4: "`", 5: "'", 6: ".", 7: "~", 8: "*", 9: "+", 10: "^"} style = sections[level] * len(section) - if level in [0, 1]: + if level in (0, 1): return f"{style}\n{section}\n{style}\n" - if level in [2, 3, 4, 5, 6, 7, 8, 9, 10]: + if level in (2, 3, 4, 5, 6, 7, 8, 9, 10): return f"{section}\n{style}\n" return section with Path(SOURCE_ROOT_FOLDER, "index.rst").open(mode="w") as index_rst_file: - index_rst_file.write("=========================\n" - "CLAASP: Cryptographic Library for Automated Analysis of Symmetric Primitives\n" - "=========================\n" - "\n" - "This is a sample reference manual for CLAASP.\n" - "\n" - "To use this module, you need to import it: \n\n" - " from claasp import *\n\n" - "This reference shows a minimal example of documentation of \n" - "CLAASP following SageMath guidelines.\n") + index_rst_file.write( + "=========================\n" + "CLAASP: Cryptographic Library for Automated Analysis of Symmetric Primitives\n" + "=========================\n" + "\n" + "This is a sample reference manual for CLAASP.\n" + "\n" + "To use this module, you need to import it: \n\n" + " from claasp import *\n\n" + "This reference shows a minimal example of documentation of \n" + "CLAASP following SageMath guidelines.\n" + ) for root, directories, files in os.walk(ROOT_FOLDER): path = root.split(os.sep) @@ -76,26 +66,30 @@ def header_style(section, level): file_header = file_name.replace("_", " ").capitalize() adornment = "=" * len(file_header) link = file_path.replace("../claasp/", "").replace("/", ".") - rst_file.write(f"{header_style(file_header, 1)}\n" - f".. automodule:: {link}\n" - " :members:\n" - " :undoc-members:\n" - " :inherited-members:\n" - " :show-inheritance:\n\n") + rst_file.write( + f"{header_style(file_header, 1)}\n" + f".. automodule:: {link}\n" + " :members:\n" + " :undoc-members:\n" + " :inherited-members:\n" + " :show-inheritance:\n\n" + ) index_rst_file.write("\n") if IS_HTML: - index_rst_file.write("\n\n" - "General Information\n" - "===================\n" - "\n" - "* :ref:`Bibliographic References `\n" - "\n" - "Indices and Tables\n" - "==================\n" - "\n" - "* :ref:`genindex`\n" - "* :ref:`modindex`\n" - "* :ref:`search`\n") + index_rst_file.write( + "\n\n" + "General Information\n" + "===================\n" + "\n" + "* :ref:`Bibliographic References `\n" + "\n" + "Indices and Tables\n" + "==================\n" + "\n" + "* :ref:`genindex`\n" + "* :ref:`modindex`\n" + "* :ref:`search`\n" + ) else: index_rst_file.write(".. include:: references.bib") diff --git a/docs/references.rst b/docs/references.rst index 9372d2fe0..cad18f1df 100644 --- a/docs/references.rst +++ b/docs/references.rst @@ -181,6 +181,13 @@ solving the rank decoding and minrank problems* : In Advances in Cryptology–ASIACRYPT2020 +.. [BJKL+2016] + Beierle C., Jean J., Kölbl S., Leander G., Moradi A., Peyrin T., Sasaki Y., + Sasdrich P., Sim S.M. : + *The SKINNY Family of Block Ciphers and Its Low-Latency Variant MANTIS* : + Cryptology ePrint Archive, Report 2016/660, 2016. + https://eprint.iacr.org/2016/660.pdf + .. _claasp-ref-C: .. only:: html @@ -409,6 +416,11 @@ **N** +.. [NIST1998] + National Institute of Standards and Technology (NIST) : *SKIPJACK and + KEA Algorithm Specifications* : Version 2.0. 29 de mayo de 1998 : + https://csrc.nist.gov/csrc/media/projects/cryptographic-algorithm-validation-program/documents/skipjack/skipjack.pdf + .. _claasp-ref-O: .. only:: html diff --git a/sonar-project.properties b/sonar-project.properties index 6c597cdbd..aa1801e3c 100644 --- a/sonar-project.properties +++ b/sonar-project.properties @@ -7,6 +7,7 @@ sonar.organization=crypto-tii sonar.sources=claasp sonar.cpd.exclusions=claasp/ciphers/** sonar.tests=tests/unit +sonar.coverage.exclusions=claasp/cipher_modules/models/milp/milp_models/Gurobi/monomial_prediction.py ## C/C++ config sonar.c.file.suffixes=- diff --git a/tests/benchmark/sat_xor_differential_model_test.py b/tests/benchmark/sat_xor_differential_model_test.py index fbe2e36c6..5ccda9fa2 100644 --- a/tests/benchmark/sat_xor_differential_model_test.py +++ b/tests/benchmark/sat_xor_differential_model_test.py @@ -1,89 +1,96 @@ +from claasp.cipher_modules.models.sat.sat_models.sat_xor_differential_model import SatXorDifferentialModel +from claasp.cipher_modules.models.utils import set_fixed_variables from claasp.ciphers.block_ciphers.aes_block_cipher import AESBlockCipher from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher -from claasp.cipher_modules.models.utils import set_fixed_variables, integer_to_bit_list -from claasp.cipher_modules.models.sat.sat_models.sat_xor_differential_model import SatXorDifferentialModel +from claasp.name_mappings import INPUT_KEY, INPUT_PLAINTEXT -speck = SpeckBlockCipher() -aes = AESBlockCipher() +SPECK = SpeckBlockCipher() +AES = AESBlockCipher() def test_build_xor_differential_trail_model_with_speck_cipher(benchmark): - sat = SatXorDifferentialModel(speck) + sat = SatXorDifferentialModel(SPECK) plaintext = set_fixed_variables( - component_id='plaintext', - constraint_type='not_equal', + component_id=INPUT_PLAINTEXT, + constraint_type="not_equal", bit_positions=range(32), - bit_values=integer_to_bit_list(0, 32, 'big')) + bit_values=(0,) * 32, + ) key = set_fixed_variables( - component_id='key', - constraint_type='equal', - bit_positions=range(64), - bit_values=(0,) * 64) + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 + ) benchmark(sat.build_xor_differential_trail_model, 3, fixed_variables=[plaintext, key]) def test_build_xor_differential_trail_model_with_aes_cipher(benchmark): - sat = SatXorDifferentialModel(aes) + sat = SatXorDifferentialModel(AES) plaintext = set_fixed_variables( - component_id='plaintext', - constraint_type='not_equal', + component_id=INPUT_PLAINTEXT, + constraint_type="not_equal", bit_positions=range(32), - bit_values=integer_to_bit_list(0, 32, 'big')) + bit_values=(0,) * 32, + ) key = set_fixed_variables( - component_id='key', - constraint_type='equal', - bit_positions=range(64), - bit_values=(0,) * 64) + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 + ) benchmark(sat.build_xor_differential_trail_model, 3, fixed_variables=[plaintext, key]) def test_find_all_xor_differential_trails_with_fixed_weight_with_speck_cipher(benchmark): - sat = SatXorDifferentialModel(speck) - sat.set_window_size_weight_pr_vars(1) - plaintext = set_fixed_variables(component_id='plaintext', constraint_type='not_equal', - bit_positions=range(32), bit_values=(0,) * 32) - key = set_fixed_variables(component_id='key', constraint_type='equal', - bit_positions=range(64), bit_values=(0,) * 64) + sat = SatXorDifferentialModel(SPECK) + sat.window_size_weight_pr_vars = 1 + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, constraint_type="not_equal", bit_positions=range(32), bit_values=(0,) * 32 + ) + key = set_fixed_variables( + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 + ) benchmark(sat.find_all_xor_differential_trails_with_fixed_weight, 9, fixed_values=[plaintext, key]) def test_find_all_xor_differential_trails_with_fixed_weight_with_aes_cipher(benchmark): - sat = SatXorDifferentialModel(aes) - sat.set_window_size_weight_pr_vars(1) - plaintext = set_fixed_variables(component_id='plaintext', constraint_type='not_equal', - bit_positions=range(32), bit_values=(0,) * 32) - key = set_fixed_variables(component_id='key', constraint_type='equal', - bit_positions=range(64), bit_values=(0,) * 64) + sat = SatXorDifferentialModel(AES) + sat.window_size_weight_pr_vars = 1 + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, constraint_type="not_equal", bit_positions=range(32), bit_values=(0,) * 32 + ) + key = set_fixed_variables( + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 + ) benchmark(sat.find_all_xor_differential_trails_with_fixed_weight, 9, fixed_values=[plaintext, key]) def test_find_lowest_weight_xor_differential_trail_with_speck_cipher(benchmark): speck = SpeckBlockCipher(number_of_rounds=7) sat = SatXorDifferentialModel(speck) - plaintext = set_fixed_variables(component_id='plaintext', constraint_type='not_equal', - bit_positions=range(32), bit_values=(0,) * 32) - key = set_fixed_variables(component_id='key', constraint_type='equal', - bit_positions=range(64), bit_values=(0,) * 64) + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, constraint_type="not_equal", bit_positions=range(32), bit_values=(0,) * 32 + ) + key = set_fixed_variables( + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 + ) benchmark(sat.find_lowest_weight_xor_differential_trail, fixed_values=[plaintext, key]) def test_find_one_xor_differential_trail_with_fixed_weight(benchmark): - window_size_by_round_list = [0 for _ in range(speck.number_of_rounds)] - sat = SatXorDifferentialModel(speck) - sat.set_window_size_heuristic_by_round( - window_size_by_round_list + window_size_by_round_list = [0] * SPECK.number_of_rounds + sat = SatXorDifferentialModel(SPECK) + sat.set_window_size_heuristic_by_round(window_size_by_round_list) + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, constraint_type="not_equal", bit_positions=range(32), bit_values=(0,) * 32 + ) + key = set_fixed_variables( + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 ) - plaintext = set_fixed_variables(component_id='plaintext', constraint_type='not_equal', - bit_positions=range(32), bit_values=(0,) * 32) - key = set_fixed_variables(component_id='key', constraint_type='equal', - bit_positions=range(64), bit_values=(0,) * 64) benchmark(sat.find_one_xor_differential_trail_with_fixed_weight, 3, fixed_values=[plaintext, key]) def test_find_one_xor_differential_trail_with_fixed_weight_with_aes_cipher(benchmark): - sat = SatXorDifferentialModel(aes) - plaintext = set_fixed_variables(component_id='plaintext', constraint_type='not_equal', - bit_positions=range(32), bit_values=(0,) * 32) - key = set_fixed_variables(component_id='key', constraint_type='equal', - bit_positions=range(64), bit_values=(0,) * 64) + sat = SatXorDifferentialModel(AES) + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, constraint_type="not_equal", bit_positions=range(32), bit_values=(0,) * 32 + ) + key = set_fixed_variables( + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 + ) benchmark(sat.find_one_xor_differential_trail_with_fixed_weight, 3, fixed_values=[plaintext, key]) diff --git a/tests/unit/cipher_modules/continuous_diffusion_analysis_test.py b/tests/unit/cipher_modules/continuous_diffusion_analysis_test.py index be5dbc8d7..8126843ca 100644 --- a/tests/unit/cipher_modules/continuous_diffusion_analysis_test.py +++ b/tests/unit/cipher_modules/continuous_diffusion_analysis_test.py @@ -1,3 +1,7 @@ +import pickle + +from plotly.basedatatypes import BaseFigure + from claasp.cipher_modules.continuous_diffusion_analysis import ContinuousDiffusionAnalysis from claasp.cipher_modules.report import Report from claasp.ciphers.block_ciphers.aes_block_cipher import AESBlockCipher @@ -12,12 +16,21 @@ def test_continuous_tests(): assert test_results['plaintext']['cipher_output']['continuous_neutrality_measure'][0]['values'][0] > 0.009 -def test_continuous_tests_report(): - import pickle +def test_continuous_tests_report(monkeypatch, tmp_path): + captured_writes = [] + + def fake_write_image(self, file, *args, **kwargs): + captured_writes.append(file) + return None + + monkeypatch.setattr(BaseFigure, 'write_image', fake_write_image) with open('tests/unit/cipher_modules/pre_computed_cda_obj.pkl', 'rb') as f: cda_for_repo = pickle.load(f) cda_repo = Report(cda_for_repo) - cda_repo.save_as_image() + output_dir = str(tmp_path / 'cda-report') + cda_repo.save_as_image(output_directory=output_dir) + cda_repo.clean_reports(output_dir=output_dir) + assert captured_writes, 'Expected Plotly write_image to be invoked' def test_continuous_avalanche_factor(): diff --git a/tests/unit/cipher_modules/data/report_test_cache.pkl b/tests/unit/cipher_modules/data/report_test_cache.pkl new file mode 100644 index 000000000..0c039a8b3 Binary files /dev/null and b/tests/unit/cipher_modules/data/report_test_cache.pkl differ diff --git a/tests/unit/cipher_modules/division_trail_search_test.py b/tests/unit/cipher_modules/division_trail_search_test.py deleted file mode 100644 index 9ec64434e..000000000 --- a/tests/unit/cipher_modules/division_trail_search_test.py +++ /dev/null @@ -1,65 +0,0 @@ -from claasp.ciphers.block_ciphers.simon_block_cipher import SimonBlockCipher -from claasp.ciphers.permutations.gaston_sbox_permutation import GastonSboxPermutation -from claasp.ciphers.block_ciphers.aradi_block_cipher import AradiBlockCipher -from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher -from claasp.ciphers.block_ciphers.midori_block_cipher import MidoriBlockCipher -from claasp.cipher_modules.division_trail_search import * - -""" - -Given a number of rounds of a chosen cipher and a chosen output bit, this module produces a model that can either: -- obtain the ANF of this chosen output bit, -- find the degree of this ANF, -- or check the presence or absence of a specified monomial. - -This module can only be used if the user possesses a Gurobi license. - -""" - -def test_find_anf_of_specific_output_bit(): - # Return the monomials of the anf of the chosen output bit - cipher = SimonBlockCipher(number_of_rounds=2) - milp = MilpDivisionTrailModel(cipher) - monomials = milp.find_anf_of_specific_output_bit(0) - assert monomials == ['p18','k32','p0','p3p24','p0p3p9','p2p9p24','p0p2p9','p10p17','p2p9p10','p10k49','p3k56','p17p24','p2p9k56','p0p9p17','k50','p24k49','p0p9k49','p4','k49k56','p17k56'] - - # Return the monomials of degree 2 of the anf of the chosen output bit - cipher = SimonBlockCipher(number_of_rounds=2) - milp = MilpDivisionTrailModel(cipher) - monomials = milp.find_anf_of_specific_output_bit(0, fixed_degree=2) - assert monomials ==['p17p24', 'p0p9k49', 'p3p24', 'p2p9k56', 'p10p17'] - -def test_find_degree_of_specific_output_bit(): - # Return the degree of the anf of the chosen output bit of the ciphertext - cipher = AradiBlockCipher(number_of_rounds=1) - milp = MilpDivisionTrailModel(cipher) - degree = milp.find_degree_of_specific_output_bit(0) - assert degree == 3 - - # Return the degree of the anf of the chosen output bit of the component xor_0_12 - cipher = AradiBlockCipher(number_of_rounds=1) - milp = MilpDivisionTrailModel(cipher) - degree = milp.find_degree_of_specific_output_bit(0, chosen_cipher_output='xor_0_12') - assert degree == 3 - - cipher = SpeckBlockCipher(number_of_rounds=1) - milp = MilpDivisionTrailModel(cipher) - degree = milp.find_degree_of_specific_output_bit(15) - assert degree == 1 - - cipher = GastonSboxPermutation(number_of_rounds=1) - milp = MilpDivisionTrailModel(cipher) - degree = milp.find_degree_of_specific_output_bit(0) - assert degree == 2 - - cipher = MidoriBlockCipher(number_of_rounds=2) - milp = MilpDivisionTrailModel(cipher) - degree = milp.find_degree_of_specific_output_bit(0) - assert degree == 8 - -def test_check_presence_of_particular_monomial_in_specific_anf(): - # Return the all monomials that contains p230 of the anf of the chosen output bit - cipher = GastonSboxPermutation(number_of_rounds=1) - milp = MilpDivisionTrailModel(cipher) - monomials = milp.check_presence_of_particular_monomial_in_specific_anf([("plaintext", 230)], 0) - assert monomials == ['p181p230','p15p230','p33p230','p54p230','p55p230','p82p230','p100p230','p114p230','p115p230','p128p230','p140p230','p141p230','p146p230','p223p230','p205p230','p209p230','p210p230','p230p267','p230p313','p230p314','p230p315'] diff --git a/tests/unit/cipher_modules/models/cp/mzn_model_test.py b/tests/unit/cipher_modules/models/cp/mzn_model_test.py index 18d15a895..a8fc7add1 100644 --- a/tests/unit/cipher_modules/models/cp/mzn_model_test.py +++ b/tests/unit/cipher_modules/models/cp/mzn_model_test.py @@ -5,11 +5,15 @@ from claasp.ciphers.block_ciphers.midori_block_cipher import MidoriBlockCipher from claasp.cipher_modules.models.cp.mzn_model import MznModel from claasp.ciphers.block_ciphers.raiden_block_cipher import RaidenBlockCipher -from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_differential_model_arx_optimized import \ - MznXorDifferentialModelARXOptimized +from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_differential_model import MznXorDifferentialModel +from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_differential_model_arx_optimized import ( + MznXorDifferentialModelARXOptimized, +) from claasp.cipher_modules.models.cp.mzn_models.mzn_cipher_model_arx_optimized import MznCipherModelARXOptimized -from claasp.cipher_modules.models.cp.mzn_models.mzn_deterministic_truncated_xor_differential_model_arx_optimized \ - import MznDeterministicTruncatedXorDifferentialModelARXOptimized +from claasp.cipher_modules.models.cp.mzn_models.mzn_deterministic_truncated_xor_differential_model_arx_optimized import ( + MznDeterministicTruncatedXorDifferentialModelARXOptimized, +) +from claasp.cipher_modules.models.utils import set_fixed_variables @pytest.mark.filterwarnings("ignore::DeprecationWarning:") @@ -17,40 +21,26 @@ def test_build_mix_column_truncated_table(): aes = AESBlockCipher(number_of_rounds=3) mzn = MznModel(aes) mix_column = aes.component_from(0, 21) - assert mzn.build_mix_column_truncated_table(mix_column) == 'array[0..93, 1..8] of int: ' \ - 'mix_column_truncated_table_mix_column_0_21 = ' \ - 'array2d(0..93, 1..8, [0,0,0,0,0,0,0,0,0,0,0,1,1,' \ - '1,1,1,0,0,1,0,1,1,1,1,0,0,1,1,0,1,1,1,0,0,1,1,1,' \ - '0,1,1,0,0,1,1,1,1,0,1,0,0,1,1,1,1,1,0,0,0,1,1,1,' \ - '1,1,1,0,1,0,0,1,1,1,1,0,1,0,1,0,1,1,1,0,1,0,1,1,' \ - '0,1,1,0,1,0,1,1,1,0,1,0,1,0,1,1,1,1,0,0,1,0,1,1,' \ - '1,1,1,0,1,1,0,0,1,1,1,0,1,1,0,1,0,1,1,0,1,1,0,1,' \ - '1,0,1,0,1,1,0,1,1,1,0,0,1,1,0,1,1,1,1,0,1,1,1,0,' \ - '0,1,1,0,1,1,1,0,1,0,1,0,1,1,1,0,1,1,0,0,1,1,1,0,' \ - '1,1,1,0,1,1,1,1,0,0,1,0,1,1,1,1,0,1,0,0,1,1,1,1,' \ - '0,1,1,0,1,1,1,1,1,0,0,0,1,1,1,1,1,0,1,0,1,1,1,1,' \ - '1,1,0,0,1,1,1,1,1,1,1,1,0,0,0,1,1,1,1,1,0,0,1,0,' \ - '1,1,1,1,0,0,1,1,0,1,1,1,0,0,1,1,1,0,1,1,0,0,1,1,' \ - '1,1,0,1,0,0,1,1,1,1,1,1,0,1,0,0,1,1,1,1,0,1,0,1,' \ - '0,1,1,1,0,1,0,1,1,0,1,1,0,1,0,1,1,1,0,1,0,1,0,1,' \ - '1,1,1,1,0,1,1,0,0,1,1,1,0,1,1,0,1,0,1,1,0,1,1,0,' \ - '1,1,0,1,0,1,1,0,1,1,1,1,0,1,1,1,0,0,1,1,0,1,1,1,' \ - '0,1,0,1,0,1,1,1,0,1,1,1,0,1,1,1,1,0,0,1,0,1,1,1,' \ - '1,0,1,1,0,1,1,1,1,1,0,1,0,1,1,1,1,1,1,1,1,0,0,0,' \ - '1,1,1,1,1,0,0,1,0,1,1,1,1,0,0,1,1,0,1,1,1,0,0,1,' \ - '1,1,0,1,1,0,0,1,1,1,1,1,1,0,1,0,0,1,1,1,1,0,1,0,' \ - '1,0,1,1,1,0,1,0,1,1,0,1,1,0,1,0,1,1,1,1,1,0,1,1,' \ - '0,0,1,1,1,0,1,1,0,1,0,1,1,0,1,1,0,1,1,1,1,0,1,1,' \ - '1,0,0,1,1,0,1,1,1,0,1,1,1,0,1,1,1,1,0,1,1,0,1,1,' \ - '1,1,1,1,1,1,0,0,0,1,1,1,1,1,0,0,1,0,1,1,1,1,0,0,' \ - '1,1,0,1,1,1,0,0,1,1,1,1,1,1,0,1,0,0,1,1,1,1,0,1,' \ - '0,1,0,1,1,1,0,1,0,1,1,1,1,1,0,1,1,0,0,1,1,1,0,1,' \ - '1,0,1,1,1,1,0,1,1,1,0,1,1,1,0,1,1,1,1,1,1,1,1,0,' \ - '0,0,1,1,1,1,1,0,0,1,0,1,1,1,1,0,0,1,1,1,1,1,1,0,' \ - '1,0,0,1,1,1,1,0,1,0,1,1,1,1,1,0,1,1,0,1,1,1,1,0,' \ - '1,1,1,1,1,1,1,1,0,0,0,1,1,1,1,1,0,0,1,1,1,1,1,1,' \ - '0,1,0,1,1,1,1,1,0,1,1,1,1,1,1,1,1,0,0,1,1,1,1,1,' \ - '1,0,1,1,1,1,1,1,1,1,0,1,1,1,1,1,1,1,1]);' + assert ( + mzn.build_mix_column_truncated_table(mix_column) == "array[0..93, 1..8] of int: " + "mix_column_truncated_table_mix_column_0_21 = array2d(0..93, 1..8, [" + "0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,1,0,1,1,1,1,0,0,1,1,0,1,1,1,0,0,1,1,1,0,1,1,0,0,1,1,1,1,0," + "1,0,0,1,1,1,1,1,0,0,0,1,1,1,1,1,1,0,1,0,0,1,1,1,1,0,1,0,1,0,1,1,1,0,1,0,1,1,0,1,1,0,1,0,1,1,1," + "0,1,0,1,0,1,1,1,1,0,0,1,0,1,1,1,1,1,0,1,1,0,0,1,1,1,0,1,1,0,1,0,1,1,0,1,1,0,1,1,0,1,0,1,1,0,1," + "1,1,0,0,1,1,0,1,1,1,1,0,1,1,1,0,0,1,1,0,1,1,1,0,1,0,1,0,1,1,1,0,1,1,0,0,1,1,1,0,1,1,1,0,1,1,1," + "1,0,0,1,0,1,1,1,1,0,1,0,0,1,1,1,1,0,1,1,0,1,1,1,1,1,0,0,0,1,1,1,1,1,0,1,0,1,1,1,1,1,1,0,0,1,1," + "1,1,1,1,1,1,0,0,0,1,1,1,1,1,0,0,1,0,1,1,1,1,0,0,1,1,0,1,1,1,0,0,1,1,1,0,1,1,0,0,1,1,1,1,0,1,0," + "0,1,1,1,1,1,1,0,1,0,0,1,1,1,1,0,1,0,1,0,1,1,1,0,1,0,1,1,0,1,1,0,1,0,1,1,1,0,1,0,1,0,1,1,1,1,1," + "0,1,1,0,0,1,1,1,0,1,1,0,1,0,1,1,0,1,1,0,1,1,0,1,0,1,1,0,1,1,1,1,0,1,1,1,0,0,1,1,0,1,1,1,0,1,0," + "1,0,1,1,1,0,1,1,1,0,1,1,1,1,0,0,1,0,1,1,1,1,0,1,1,0,1,1,1,1,1,0,1,0,1,1,1,1,1,1,1,1,0,0,0,1,1," + "1,1,1,0,0,1,0,1,1,1,1,0,0,1,1,0,1,1,1,0,0,1,1,1,0,1,1,0,0,1,1,1,1,1,1,0,1,0,0,1,1,1,1,0,1,0,1," + "0,1,1,1,0,1,0,1,1,0,1,1,0,1,0,1,1,1,1,1,0,1,1,0,0,1,1,1,0,1,1,0,1,0,1,1,0,1,1,0,1,1,1,1,0,1,1," + "1,0,0,1,1,0,1,1,1,0,1,1,1,0,1,1,1,1,0,1,1,0,1,1,1,1,1,1,1,1,0,0,0,1,1,1,1,1,0,0,1,0,1,1,1,1,0," + "0,1,1,0,1,1,1,0,0,1,1,1,1,1,1,0,1,0,0,1,1,1,1,0,1,0,1,0,1,1,1,0,1,0,1,1,1,1,1,0,1,1,0,0,1,1,1," + "0,1,1,0,1,1,1,1,0,1,1,1,0,1,1,1,0,1,1,1,1,1,1,1,1,0,0,0,1,1,1,1,1,0,0,1,0,1,1,1,1,0,0,1,1,1,1," + "1,1,0,1,0,0,1,1,1,1,0,1,0,1,1,1,1,1,0,1,1,0,1,1,1,1,0,1,1,1,1,1,1,1,1,0,0,0,1,1,1,1,1,0,0,1,1," + "1,1,1,1,0,1,0,1,1,1,1,1,0,1,1,1,1,1,1,1,1,0,0,1,1,1,1,1,1,0,1,1,1,1,1,1,1,1,0,1,1,1,1,1,1,1,1]);" + ) def test_find_possible_number_of_active_sboxes(): @@ -61,36 +51,38 @@ def test_find_possible_number_of_active_sboxes(): def test_fix_variables_value_constraints(): - raiden = RaidenBlockCipher(number_of_rounds=1) mzn = MznXorDifferentialModelARXOptimized(raiden) mzn.build_xor_differential_trail_model() - fixed_variables = [{ - 'component_id': 'key', - 'constraint_type': 'equal', - 'bit_positions': [0, 1, 2, 3], - 'bit_values': [0, 1, 0, 1]}] + fixed_variables = [ + {"component_id": "key", "constraint_type": "equal", "bit_positions": [0, 1, 2, 3], "bit_values": [0, 1, 0, 1]} + ] - constraint_key_y_0 = 'constraint key_y0 = 0;' + constraint_key_y_0 = "constraint key_y0 = 0;" assert mzn.fix_variables_value_constraints_for_ARX(fixed_variables)[0] == constraint_key_y_0 - fixed_variables = [{'component_id': 'plaintext', - 'constraint_type': 'sum', - 'bit_positions': [0, 1, 2, 3], - 'operator': '>', - 'value': '0'}] - - assert mzn.fix_variables_value_constraints_for_ARX(fixed_variables)[0] == f'constraint plaintext_y0+plaintext_y1+' \ - f'plaintext_y2+plaintext_y3>0;' + fixed_variables = [ + { + "component_id": "plaintext", + "constraint_type": "sum", + "bit_positions": [0, 1, 2, 3], + "operator": ">", + "value": "0", + } + ] + + assert ( + mzn.fix_variables_value_constraints_for_ARX(fixed_variables)[0] + == "constraint plaintext_y0+plaintext_y1+plaintext_y2+plaintext_y3>0;" + ) raiden = RaidenBlockCipher(number_of_rounds=1) mzn = MznDeterministicTruncatedXorDifferentialModelARXOptimized(raiden) mzn.build_deterministic_truncated_xor_differential_trail_model() - fixed_variables = [{'component_id': 'key', - 'constraint_type': 'equal', - 'bit_positions': [0, 1, 2, 3], - 'bit_values': [0, 1, 0, 1]}] + fixed_variables = [ + {"component_id": "key", "constraint_type": "equal", "bit_positions": [0, 1, 2, 3], "bit_values": [0, 1, 0, 1]} + ] assert mzn.fix_variables_value_constraints_for_ARX(fixed_variables)[0] == constraint_key_y_0 @@ -98,13 +90,30 @@ def test_fix_variables_value_constraints(): mzn = MznCipherModelARXOptimized(raiden) mzn.build_cipher_model() - fixed_variables = [{'component_id': 'key', - 'constraint_type': 'equal', - 'bit_positions': [0, 1, 2, 3], - 'bit_values': [0, 1, 0, 1]}] + fixed_variables = [ + {"component_id": "key", "constraint_type": "equal", "bit_positions": [0, 1, 2, 3], "bit_values": [0, 1, 0, 1]} + ] assert mzn.fix_variables_value_constraints_for_ARX(fixed_variables)[0] == constraint_key_y_0 + speck = SpeckBlockCipher(number_of_rounds=3) + mzn = MznXorDifferentialModel(speck) + fixed_values = [set_fixed_variables('plaintext','equal',range(32),[(speck.get_all_components_ids()[-1],list(range(32)))])] + trail = mzn.find_one_xor_differential_trail(fixed_values=fixed_values) + assert trail['components_values']['plaintext']['value'] == trail['components_values'][speck.get_all_components_ids()[-1]]['value'] + + mzn.initialise_model() + fixed_values = [set_fixed_variables('plaintext','not_equal',range(32),[(speck.get_all_components_ids()[-1],list(range(32)))])] + trail = mzn.find_one_xor_differential_trail(fixed_values=fixed_values) + assert trail['components_values']['plaintext']['value'] != trail['components_values'][speck.get_all_components_ids()[-1]]['value'] + + mzn.initialise_model() + fixed_values = [set_fixed_variables('plaintext','equal',range(32),[0]*31+[1])] + fixed_values.append(set_fixed_variables(speck.get_all_components_ids()[-1],'equal',range(32),[0]*31+[1])) + fixed_values.append(set_fixed_variables('plaintext','not_equal',range(32),[(speck.get_all_components_ids()[-1],list(range(32)))])) + trail = mzn.find_one_xor_differential_trail(fixed_values=fixed_values) + assert trail['status'] == 'UNSATISFIABLE' + def test_model_constraints(): with pytest.raises(Exception): diff --git a/tests/unit/cipher_modules/models/cp/mzn_models/mzn_boomerang_model_arx_optimized_test.py b/tests/unit/cipher_modules/models/cp/mzn_models/mzn_boomerang_model_arx_optimized_test.py index 8db402f0b..f693cfd7c 100644 --- a/tests/unit/cipher_modules/models/cp/mzn_models/mzn_boomerang_model_arx_optimized_test.py +++ b/tests/unit/cipher_modules/models/cp/mzn_models/mzn_boomerang_model_arx_optimized_test.py @@ -1,44 +1,37 @@ +import os +import numpy as np + from claasp.cipher_modules.models.cp.mzn_models.mzn_boomerang_model_arx_optimized import MznBoomerangModelARXOptimized +from claasp.cipher_modules.models.cp.solvers import CPSAT from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher from claasp.ciphers.permutations.chacha_permutation import ChachaPermutation from claasp.name_mappings import BOOMERANG_XOR_DIFFERENTIAL -import numpy as np -import os -from os import urandom - - -def speck32_64_word_size(): - return 16 - -def speck32_64_alpha(): - return 7 - -def speck32_64_beta(): - return 2 - - -MASK_VAL = 2 ** speck32_64_word_size() - 1 +SPECK32_64_WORD_SIZE = 16 +SPECK32_64_ALPHA = 7 +SPECK32_64_BETA = 2 +MASK_VAL = 2 ** SPECK32_64_WORD_SIZE - 1 def speck32_64_rol(x, k): - return ((x << k) & MASK_VAL) | (x >> (speck32_64_word_size() - k)) + return ((x << k) & MASK_VAL) | (x >> (SPECK32_64_WORD_SIZE - k)) def speck32_64_ror(x, k): - return (x >> k) | ((x << (speck32_64_word_size() - k)) & MASK_VAL) + return (x >> k) | ((x << (SPECK32_64_WORD_SIZE - k)) & MASK_VAL) def speck32_64_decrypt(ciphertext, ks): def dec_one_round(c, subkey_): c0, c1 = c c1 = c1 ^ c0 - c1 = speck32_64_ror(c1, speck32_64_beta()) + c1 = speck32_64_ror(c1, SPECK32_64_BETA) c0 = c0 ^ subkey_ c0 = (c0 - c1) & MASK_VAL - c0 = speck32_64_rol(c0, speck32_64_alpha()) + c0 = speck32_64_rol(c0, SPECK32_64_ALPHA) return c0, c1 + x, y = ciphertext for subkey in reversed(ks): x, y = dec_one_round((x, y), subkey) @@ -47,10 +40,10 @@ def dec_one_round(c, subkey_): def speck32_64_enc_one_round(plaintext, subkey): c0, c1 = plaintext - c0 = speck32_64_ror(c0, speck32_64_alpha()) + c0 = speck32_64_ror(c0, SPECK32_64_ALPHA) c0 = (c0 + c1) & MASK_VAL c0 = c0 ^ subkey - c1 = speck32_64_rol(c1, speck32_64_beta()) + c1 = speck32_64_rol(c1, SPECK32_64_BETA) c1 = c1 ^ c0 return c0, c1 @@ -65,17 +58,16 @@ def speck32_64_expand_key(k, t): def speck32_64_encrypt(p, ks): - x, y = p for k in ks: x, y = speck32_64_enc_one_round((x, y), k) return x, y -def speck32_64_bct_distinguisher_verifier(delta_, nabla_, nr, n=2 ** 10): - keys = np.frombuffer(urandom(8*n), dtype=np.uint16).reshape(4, -1) - plaintext_data_0_left = np.frombuffer(urandom(2*n), dtype=np.uint16) - plaintext_data_0_right = np.frombuffer(urandom(2*n), dtype=np.uint16) +def speck32_64_bct_distinguisher_verifier(delta_, nabla_, nr, n=2**10): + keys = np.frombuffer(os.urandom(8 * n), dtype=np.uint16).reshape(4, -1) + plaintext_data_0_left = np.frombuffer(os.urandom(2 * n), dtype=np.uint16) + plaintext_data_0_right = np.frombuffer(os.urandom(2 * n), dtype=np.uint16) plaintext_data_1_left = plaintext_data_0_left ^ delta_[0] plaintext_data_1_right = plaintext_data_0_right ^ delta_[1] subkey_list = speck32_64_expand_key(keys, nr) @@ -93,7 +85,7 @@ def speck32_64_bct_distinguisher_verifier(delta_, nabla_, nr, n=2 ** 10): plaintext_data_3_left, plaintext_data_3_right = speck32_64_decrypt(output_xor_nabla_1, subkey_list) nabla_temp = (np.uint32(delta_[0]) << 16) ^ delta_[1] - nabla_prime_temp_left = (np.uint32(plaintext_data_2_left ^ plaintext_data_3_left) << 16) + nabla_prime_temp_left = np.uint32(plaintext_data_2_left ^ plaintext_data_3_left) << 16 nabla_prime_temp = nabla_prime_temp_left ^ (plaintext_data_2_right ^ plaintext_data_3_right) total = np.sum(nabla_temp == nabla_prime_temp) @@ -110,70 +102,105 @@ def test_build_boomerang_model_speck_single_key(): speck = SpeckBlockCipher(number_of_rounds=8) speck = speck.remove_key_schedule() - top_cipher_end = [ - "xor_3_10", - "rot_4_6" - ] + top_cipher_end = ["xor_3_10", "rot_4_6"] - bottom_cipher_start = [ - "xor_4_8", - "rot_4_9", - 'key_4_2', - 'key_5_2', - 'key_6_2', - 'key_7_2' - ] + bottom_cipher_start = ["xor_4_8", "rot_4_9", "key_4_2", "key_5_2", "key_6_2", "key_7_2"] - sboxes = [ - "modadd_4_7", - ] + sboxes = ["modadd_4_7"] mzn_bct_model = MznBoomerangModelARXOptimized(speck, top_cipher_end, bottom_cipher_start, sboxes) fixed_variables_for_top_cipher = [ - {'component_id': 'plaintext', 'constraint_type': 'sum', 'bit_positions': [i for i in range(32)], - 'operator': '>', 'value': '0'}, - {'component_id': 'key_0_2', 'constraint_type': 'equal', 'bit_positions': [i for i in range(16)], - 'bit_values': [0 for _ in range(16)]}, - {'component_id': 'key_1_2', 'constraint_type': 'equal', 'bit_positions': [i for i in range(16)], - 'bit_values': [0 for _ in range(16)]}, - {'component_id': 'key_2_2', 'constraint_type': 'equal', 'bit_positions': [i for i in range(16)], - 'bit_values': [0 for _ in range(16)]}, - {'component_id': 'key_3_2', 'constraint_type': 'equal', 'bit_positions': [i for i in range(16)], - 'bit_values': [0 for _ in range(16)]}, - {'component_id': 'xor_3_10', 'constraint_type': 'sum', 'bit_positions': [i for i in range(16)], - 'operator': '>', 'value': '0'}] + { + "component_id": "plaintext", + "constraint_type": "sum", + "bit_positions": list(range(32)), + "operator": ">", + "value": "0", + }, + { + "component_id": "key_0_2", + "constraint_type": "equal", + "bit_positions": list(range(16)), + "bit_values": (0,) * 16, + }, + { + "component_id": "key_1_2", + "constraint_type": "equal", + "bit_positions": list(range(16)), + "bit_values": (0,) * 16, + }, + { + "component_id": "key_2_2", + "constraint_type": "equal", + "bit_positions": list(range(16)), + "bit_values": (0,) * 16, + }, + { + "component_id": "key_3_2", + "constraint_type": "equal", + "bit_positions": list(range(16)), + "bit_values": (0,) * 16, + }, + { + "component_id": "xor_3_10", + "constraint_type": "sum", + "bit_positions": list(range(16)), + "operator": ">", + "value": "0", + }, + ] fixed_variables_for_bottom_cipher = [ - {'component_id': 'new_xor_3_10', 'constraint_type': 'sum', 'bit_positions': [i for i in range(16)], - 'operator': '>', 'value': '0'}, - {'component_id': 'key_4_2', 'constraint_type': 'equal', - 'bit_positions': [i for i in range(16)], - 'bit_values': [0 for _ in range(16)]}, - {'component_id': 'key_5_2', 'constraint_type': 'equal', - 'bit_positions': [i for i in range(16)], - 'bit_values': [0 for _ in range(16)]}, - {'component_id': 'key_6_2', 'constraint_type': 'equal', - 'bit_positions': [i for i in range(16)], - 'bit_values': [0 for _ in range(16)]}, - {'component_id': 'key_7_2', 'constraint_type': 'equal', - 'bit_positions': [i for i in range(16)], 'bit_values': [0 for _ in range(16)]}, + { + "component_id": "new_xor_3_10", + "constraint_type": "sum", + "bit_positions": list(range(16)), + "operator": ">", + "value": "0", + }, + { + "component_id": "key_4_2", + "constraint_type": "equal", + "bit_positions": list(range(16)), + "bit_values": (0,) * 16, + }, + { + "component_id": "key_5_2", + "constraint_type": "equal", + "bit_positions": list(range(16)), + "bit_values": (0,) * 16, + }, + { + "component_id": "key_6_2", + "constraint_type": "equal", + "bit_positions": list(range(16)), + "bit_values": (0,) * 16, + }, + { + "component_id": "key_7_2", + "constraint_type": "equal", + "bit_positions": list(range(16)), + "bit_values": (0,) * 16, + }, ] mzn_bct_model.create_boomerang_model(fixed_variables_for_top_cipher, fixed_variables_for_bottom_cipher) - result = mzn_bct_model.solve_for_ARX(solver_name='Xor') + result = mzn_bct_model.solve_for_ARX(solver_name=CPSAT) total_weight = MznBoomerangModelARXOptimized._get_total_weight(result) - parsed_result = mzn_bct_model.bct_parse_result(result, 'Xor', total_weight, BOOMERANG_XOR_DIFFERENTIAL) - filename = '.' + parsed_result = mzn_bct_model.bct_parse_result(result, CPSAT, total_weight, BOOMERANG_XOR_DIFFERENTIAL) + filename = "." mzn_bct_model.write_minizinc_model_to_file(filename) assert os.path.exists(mzn_bct_model.filename), "File was not created" os.remove(mzn_bct_model.filename) - assert total_weight == parsed_result['total_weight'] - input_difference = split_32bit_to_16bit(int(parsed_result['component_values']['plaintext']['value'], 16)) - output_difference = split_32bit_to_16bit(int(parsed_result['component_values']['cipher_output_7_12']['value'], 16)) - assert speck32_64_bct_distinguisher_verifier(input_difference, output_difference, speck.number_of_rounds, n=2**20) \ - > 0.0001 + assert total_weight == parsed_result["total_weight"] + input_difference = split_32bit_to_16bit(int(parsed_result["component_values"]["plaintext"]["value"], 16)) + output_difference = split_32bit_to_16bit(int(parsed_result["component_values"]["cipher_output_7_12"]["value"], 16)) + assert ( + speck32_64_bct_distinguisher_verifier(input_difference, output_difference, speck.number_of_rounds, n=2**20) + > 0.0001 + ) def test_build_boomerang_model_chacha(): @@ -183,72 +210,95 @@ def test_build_boomerang_model_chacha(): "rot_3_5", "modadd_3_3", "rot_3_2", - "modadd_3_6", "rot_3_11", "modadd_3_9", "rot_3_8", - "modadd_3_12", "rot_3_17", "modadd_3_15", "rot_3_14", - "modadd_3_18", "rot_3_23", "modadd_3_21", - "rot_3_20" + "rot_3_20", ] bottom_cipher_start = [ "xor_4_4", "modadd_4_3", "xor_4_1", - "xor_4_10", "modadd_4_9", "xor_4_7", - "xor_4_16", "modadd_4_15", "xor_4_13", - "xor_4_22", "modadd_4_21", - "xor_4_19" + "xor_4_19", ] - sboxes = [ - "modadd_4_0", - "modadd_4_6", - "modadd_4_12", - "modadd_4_18" - ] + sboxes = ["modadd_4_0", "modadd_4_6", "modadd_4_12", "modadd_4_18"] mzn_bct_model = MznBoomerangModelARXOptimized(chacha, top_cipher_end, bottom_cipher_start, sboxes) fixed_variables_for_top_cipher = [ - {'component_id': 'plaintext', 'constraint_type': 'sum', 'bit_positions': [i for i in range(512)], - 'operator': '>', 'value': '0'}, - {'component_id': 'plaintext', 'constraint_type': 'sum', 'bit_positions': [i for i in range(384)], - 'operator': '=', 'value': '0'} + { + "component_id": "plaintext", + "constraint_type": "sum", + "bit_positions": list(range(512)), + "operator": ">", + "value": "0", + }, + { + "component_id": "plaintext", + "constraint_type": "sum", + "bit_positions": list(range(384)), + "operator": "=", + "value": "0", + }, ] fixed_variables_for_bottom_cipher = [ - {'component_id': 'new_rot_3_23', 'constraint_type': 'sum', 'bit_positions': [i for i in range(32)], - 'operator': '>', 'value': '0'}, - {'component_id': 'new_rot_3_5', 'constraint_type': 'sum', 'bit_positions': [i for i in range(32)], - 'operator': '>', 'value': '0'}, - {'component_id': 'new_rot_3_11', 'constraint_type': 'sum', 'bit_positions': [i for i in range(32)], - 'operator': '>', 'value': '0'}, - {'component_id': 'new_rot_3_17', 'constraint_type': 'sum', 'bit_positions': [i for i in range(32)], - 'operator': '>', 'value': '0'}] + { + "component_id": "new_rot_3_23", + "constraint_type": "sum", + "bit_positions": list(range(32)), + "operator": ">", + "value": "0", + }, + { + "component_id": "new_rot_3_5", + "constraint_type": "sum", + "bit_positions": list(range(32)), + "operator": ">", + "value": "0", + }, + { + "component_id": "new_rot_3_11", + "constraint_type": "sum", + "bit_positions": list(range(32)), + "operator": ">", + "value": "0", + }, + { + "component_id": "new_rot_3_17", + "constraint_type": "sum", + "bit_positions": list(range(32)), + "operator": ">", + "value": "0", + }, + ] mzn_bct_model.create_boomerang_model(fixed_variables_for_top_cipher, fixed_variables_for_bottom_cipher) - result = mzn_bct_model.solve_for_ARX(solver_name='Xor') + result = mzn_bct_model.solve_for_ARX(solver_name=CPSAT) total_weight = MznBoomerangModelARXOptimized._get_total_weight(result) - parsed_result = mzn_bct_model.bct_parse_result(result, 'Xor', total_weight, BOOMERANG_XOR_DIFFERENTIAL) - filename = '.' + parsed_result = mzn_bct_model.bct_parse_result(result, CPSAT, total_weight, BOOMERANG_XOR_DIFFERENTIAL) + filename = "." mzn_bct_model.write_minizinc_model_to_file(filename) + assert os.path.exists(mzn_bct_model.filename), "File was not created" + os.remove(mzn_bct_model.filename) - assert total_weight == parsed_result['total_weight'] + + assert total_weight == parsed_result["total_weight"] diff --git a/tests/unit/cipher_modules/models/cp/mzn_models/mzn_cipher_model_test.py b/tests/unit/cipher_modules/models/cp/mzn_models/mzn_cipher_model_test.py index e16bc1dbf..89e2de40f 100644 --- a/tests/unit/cipher_modules/models/cp/mzn_models/mzn_cipher_model_test.py +++ b/tests/unit/cipher_modules/models/cp/mzn_models/mzn_cipher_model_test.py @@ -1,15 +1,22 @@ -from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher from claasp.cipher_modules.models.cp.mzn_models.mzn_cipher_model import MznCipherModel from claasp.cipher_modules.models.utils import set_fixed_variables, integer_to_bit_list +from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher +from claasp.name_mappings import INPUT_PLAINTEXT, INPUT_KEY -def test_build_cipher_model(): - speck = SpeckBlockCipher(block_bit_size=32, key_bit_size=64, number_of_rounds=4) +def test_find_missing_bits(): + speck = SpeckBlockCipher(block_bit_size=32, key_bit_size=64, number_of_rounds=22) mzn = MznCipherModel(speck) - fixed_variables = [set_fixed_variables('key', 'equal', range(64), integer_to_bit_list(0, 64, 'little')), - set_fixed_variables('plaintext', 'equal', range(32), integer_to_bit_list(0, 32, 'little'))] - mzn.build_cipher_model(fixed_variables) - assert len(mzn.model_constraints) == 1204 - assert mzn.model_constraints[2] == 'array[0..31] of var 0..1: plaintext;' - assert mzn.model_constraints[3] == 'array[0..63] of var 0..1: key;' - assert mzn.model_constraints[4] == 'array[0..15] of var 0..1: rot_0_0;' + cipher_output_id = speck.get_all_components_ids()[-1] + plaintext_bits = integer_to_bit_list(0x6574694C, 32, "big") + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, constraint_type="equal", bit_positions=range(32), bit_values=plaintext_bits + ) + key_bits = integer_to_bit_list(0x1918111009080100, 64, "big") + key = set_fixed_variables( + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(64), bit_values=key_bits + ) + + missing_bits = mzn.find_missing_bits(fixed_values=[plaintext, key]) + + assert missing_bits["components_values"][cipher_output_id]["value"] == "0xa86842f2" diff --git a/tests/unit/cipher_modules/models/cp/mzn_models/mzn_deterministic_truncated_xor_differential_model_arx_optimized_test.py b/tests/unit/cipher_modules/models/cp/mzn_models/mzn_deterministic_truncated_xor_differential_model_arx_optimized_test.py index 17e286731..cf05e619a 100644 --- a/tests/unit/cipher_modules/models/cp/mzn_models/mzn_deterministic_truncated_xor_differential_model_arx_optimized_test.py +++ b/tests/unit/cipher_modules/models/cp/mzn_models/mzn_deterministic_truncated_xor_differential_model_arx_optimized_test.py @@ -1,6 +1,7 @@ +from claasp.cipher_modules.models.cp.mzn_models.mzn_deterministic_truncated_xor_differential_model_arx_optimized import ( + MznDeterministicTruncatedXorDifferentialModelARXOptimized, +) from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher -from claasp.cipher_modules.models.cp.mzn_models.mzn_deterministic_truncated_xor_differential_model_arx_optimized \ - import MznDeterministicTruncatedXorDifferentialModelARXOptimized def test_build_deterministic_truncated_xor_differential_trail_model(): diff --git a/tests/unit/cipher_modules/models/cp/mzn_models/mzn_deterministic_truncated_xor_differential_model_test.py b/tests/unit/cipher_modules/models/cp/mzn_models/mzn_deterministic_truncated_xor_differential_model_test.py index 07c281c74..3fda84f37 100644 --- a/tests/unit/cipher_modules/models/cp/mzn_models/mzn_deterministic_truncated_xor_differential_model_test.py +++ b/tests/unit/cipher_modules/models/cp/mzn_models/mzn_deterministic_truncated_xor_differential_model_test.py @@ -1,65 +1,81 @@ -from claasp.ciphers.block_ciphers.aes_block_cipher import AESBlockCipher +from claasp.cipher_modules.models.cp.mzn_models.mzn_deterministic_truncated_xor_differential_model import ( + MznDeterministicTruncatedXorDifferentialModel, +) +from claasp.cipher_modules.models.cp.solvers import CHUFFED +from claasp.cipher_modules.models.utils import set_fixed_variables from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher -from claasp.cipher_modules.models.utils import set_fixed_variables, integer_to_bit_list -from claasp.cipher_modules.models.cp.mzn_models.mzn_deterministic_truncated_xor_differential_model import \ - MznDeterministicTruncatedXorDifferentialModel +from claasp.name_mappings import INPUT_KEY, INPUT_PLAINTEXT def test_build_deterministic_truncated_xor_differential_trail_model(): speck = SpeckBlockCipher(block_bit_size=32, key_bit_size=64, number_of_rounds=2) mzn = MznDeterministicTruncatedXorDifferentialModel(speck) - fixed_variables = [set_fixed_variables('key', 'equal', range(64), integer_to_bit_list(0, 64, 'little'))] + fixed_variables = [set_fixed_variables(INPUT_KEY, "equal", range(64), (0,) * 64)] mzn.build_deterministic_truncated_xor_differential_trail_model(fixed_variables) assert len(mzn.model_constraints) == 438 - assert mzn.model_constraints[2] == 'array[0..31] of var 0..2: plaintext;' - assert mzn.model_constraints[3] == 'array[0..63] of var 0..2: key;' - assert mzn.model_constraints[4] == 'array[0..15] of var 0..2: rot_0_0;' + assert mzn.model_constraints[2] == "array[0..31] of var 0..2: plaintext;" + assert mzn.model_constraints[3] == "array[0..63] of var 0..2: key;" + assert mzn.model_constraints[4] == "array[0..15] of var 0..2: rot_0_0;" def test_find_all_deterministic_truncated_xor_differential_trails(): speck = SpeckBlockCipher(number_of_rounds=3) mzn = MznDeterministicTruncatedXorDifferentialModel(speck) - plaintext = set_fixed_variables(component_id='plaintext', constraint_type='not_equal', - bit_positions=range(32), bit_values=[0] * 32) - key = set_fixed_variables(component_id='key', constraint_type='equal', bit_positions=range(64), bit_values=[0] * 64) - trail = mzn.find_all_deterministic_truncated_xor_differential_trails(3, [plaintext, key], 'Chuffed', solve_external = True) + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, constraint_type="not_equal", bit_positions=range(32), bit_values=(0,) * 32 + ) + key = set_fixed_variables( + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 + ) + trails = mzn.find_all_deterministic_truncated_xor_differential_trails( + 3, [plaintext, key], CHUFFED, solve_external=True + ) - assert len(trail) == 4 - for i in range(len(trail)): - assert str(trail[i]['cipher']) == 'speck_p32_k64_o32_r3' - assert trail[i]['model_type'] == 'deterministic_truncated_xor_differential' - assert trail[i]['model_type'] == 'deterministic_truncated_xor_differential' - assert trail[i]['solver_name'] == 'Chuffed' + assert len(trails) == 4 + for trail in trails: + assert str(trail["cipher"]) == "speck_p32_k64_o32_r3" + assert trail["model_type"] == "deterministic_truncated_xor_differential" + assert trail["solver_name"] == CHUFFED - trail = mzn.find_all_deterministic_truncated_xor_differential_trails(3, [plaintext, key], 'chuffed', solve_external = False) + trails = mzn.find_all_deterministic_truncated_xor_differential_trails( + 3, [plaintext, key], CHUFFED, solve_external=False + ) - assert len(trail) == 4 - for i in range(len(trail)): - assert str(trail[i]['cipher']) == 'speck_p32_k64_o32_r3' - assert trail[i]['model_type'] == 'deterministic_truncated_xor_differential' - assert trail[i]['model_type'] == 'deterministic_truncated_xor_differential' - assert trail[i]['solver_name'] == 'chuffed' + assert len(trails) == 4 + for trail in trails: + assert str(trail["cipher"]) == "speck_p32_k64_o32_r3" + assert trail["model_type"] == "deterministic_truncated_xor_differential" + assert trail["solver_name"] == CHUFFED def test_find_one_deterministic_truncated_xor_differential_trail(): speck = SpeckBlockCipher(number_of_rounds=1) mzn = MznDeterministicTruncatedXorDifferentialModel(speck) - plaintext = set_fixed_variables(component_id='plaintext', constraint_type='not_equal', - bit_positions=range(32), bit_values=[0] * 32) - key = set_fixed_variables(component_id='key', constraint_type='equal', bit_positions=range(64), bit_values=[0] * 64) - trail = mzn.find_one_deterministic_truncated_xor_differential_trail(1, [plaintext, key], 'Chuffed', solve_external = True) + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, constraint_type="not_equal", bit_positions=range(32), bit_values=(0,) * 32 + ) + key = set_fixed_variables( + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 + ) + trail = mzn.find_one_deterministic_truncated_xor_differential_trail( + 1, [plaintext, key], CHUFFED, solve_external=True + ) - assert str(trail[0]['cipher']) == 'speck_p32_k64_o32_r1' + assert str(trail[0]["cipher"]) == "speck_p32_k64_o32_r1" - assert trail[0]['components_values']['key']['value'] == '000000000000000000000000000000000000000000000000000000' \ - '0000000000' - assert trail[0]['model_type'] == 'deterministic_truncated_xor_differential_one_solution' - assert trail[0]['solver_name'] == 'Chuffed' + assert ( + trail[0]["components_values"][INPUT_KEY]["value"] + == "0000000000000000000000000000000000000000000000000000000000000000" + ) + assert trail[0]["model_type"] == "deterministic_truncated_xor_differential_one_solution" + assert trail[0]["solver_name"] == CHUFFED - trail = mzn.find_one_deterministic_truncated_xor_differential_trail(1, [plaintext, key], 'chuffed', solve_external = False) + trail = mzn.find_one_deterministic_truncated_xor_differential_trail( + 1, [plaintext, key], CHUFFED, solve_external=False + ) - assert str(trail['cipher']) == 'speck_p32_k64_o32_r1' + assert str(trail["cipher"]) == "speck_p32_k64_o32_r1" - assert trail['model_type'] == 'deterministic_truncated_xor_differential_one_solution' - assert trail['solver_name'] == 'chuffed' + assert trail["model_type"] == "deterministic_truncated_xor_differential_one_solution" + assert trail["solver_name"] == CHUFFED diff --git a/tests/unit/cipher_modules/models/cp/mzn_models/mzn_hybrid_impossible_xor_differential_model_test.py b/tests/unit/cipher_modules/models/cp/mzn_models/mzn_hybrid_impossible_xor_differential_model_test.py index e2f066370..29561fe02 100644 --- a/tests/unit/cipher_modules/models/cp/mzn_models/mzn_hybrid_impossible_xor_differential_model_test.py +++ b/tests/unit/cipher_modules/models/cp/mzn_models/mzn_hybrid_impossible_xor_differential_model_test.py @@ -1,97 +1,153 @@ -from claasp.cipher_modules.models.cp.mzn_models.mzn_hybrid_impossible_xor_differential_model import \ - MznHybridImpossibleXorDifferentialModel +from claasp.cipher_modules.models.cp.mzn_models.mzn_hybrid_impossible_xor_differential_model import ( + MznHybridImpossibleXorDifferentialModel, +) +from claasp.cipher_modules.models.cp.solvers import CHUFFED +from claasp.cipher_modules.models.utils import set_fixed_variables from claasp.ciphers.block_ciphers.lblock_block_cipher import LBlockBlockCipher -from claasp.cipher_modules.models.utils import set_fixed_variables, integer_to_bit_list +from claasp.name_mappings import INPUT_KEY, INPUT_PLAINTEXT, SATISFIABLE def test_build_impossible_xor_differential_trail_model(): lblock = LBlockBlockCipher(number_of_rounds=4) mzn = MznHybridImpossibleXorDifferentialModel(lblock) - fixed_variables = [set_fixed_variables('key', 'equal', range(80), [0]*80)] - mzn.build_hybrid_impossible_xor_differential_trail_model(number_of_rounds=4, fixed_variables=fixed_variables, middle_round=3) + fixed_variables = [set_fixed_variables(INPUT_KEY, "equal", range(80), (0,) * 80)] + mzn.build_hybrid_impossible_xor_differential_trail_model( + number_of_rounds=4, fixed_variables=fixed_variables, middle_round=3 + ) assert len(mzn.model_constraints) == 2442 - assert mzn.model_constraints[2] == 'set of int: ext_domain = 0..2 union { i | i in 10..800 where (i mod 10 = 0)};' - assert mzn.model_constraints[3] == 'array[0..63] of var ext_domain: plaintext;' - assert mzn.model_constraints[4] == 'array[0..79] of var ext_domain: key;' - assert mzn.model_constraints[5] == 'array[0..63] of var ext_domain: inverse_cipher_output_3_19;' + assert mzn.model_constraints[2] == "set of int: ext_domain = 0..2 union { i | i in 10..800 where (i mod 10 = 0)};" + assert mzn.model_constraints[3] == "array[0..63] of var ext_domain: plaintext;" + assert mzn.model_constraints[4] == "array[0..79] of var ext_domain: key;" + assert mzn.model_constraints[5] == "array[0..63] of var ext_domain: inverse_cipher_output_3_19;" def test_find_all_impossible_xor_differential_trails(): lblock = LBlockBlockCipher(number_of_rounds=4) mzn = MznHybridImpossibleXorDifferentialModel(lblock) - plaintext = set_fixed_variables(component_id='plaintext', constraint_type='equal', - bit_positions=range(64), bit_values=[0] * 64) - ciphertext = set_fixed_variables(component_id='inverse_' + lblock.get_all_components_ids()[-1], constraint_type='equal', - bit_positions=range(64), bit_values=[0] * 64) - key = set_fixed_variables('key', constraint_type='equal', - bit_positions=range(78), bit_values=[0] * 78) - trails = mzn.find_all_impossible_xor_differential_trails(4, [plaintext, ciphertext, key], 'Chuffed', 1, 3, 4, False, solve_external=True) + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 + ) + ciphertext = set_fixed_variables( + component_id="inverse_" + lblock.get_all_components_ids()[-1], + constraint_type="equal", + bit_positions=range(64), + bit_values=(0,) * 64, + ) + key = set_fixed_variables(INPUT_KEY, constraint_type="equal", bit_positions=range(78), bit_values=[0] * 78) + trails = mzn.find_all_impossible_xor_differential_trails( + 4, [plaintext, ciphertext, key], CHUFFED, 1, 3, 4, False, solve_external=True + ) assert len(trails) == 6 - assert trails[0]['status'] == 'SATISFIABLE' - assert trails[0]['components_values']['plaintext'][ - 'value'] == '................................................................' - assert trails[0]['components_values']['inverse_cipher_output_3_19'][ - 'value'] == '................................................................' + assert trails[0]["status"] == SATISFIABLE + assert ( + trails[0]["components_values"][INPUT_PLAINTEXT]["value"] + == "................................................................" + ) + assert ( + trails[0]["components_values"]["inverse_cipher_output_3_19"]["value"] + == "................................................................" + ) + def test_find_all_improbable_xor_differential_trails(): lblock = LBlockBlockCipher(number_of_rounds=4) mzn = MznHybridImpossibleXorDifferentialModel(lblock) - plaintext = set_fixed_variables(component_id='plaintext', constraint_type='equal', - bit_positions=range(64), bit_values=[0] * 64) - ciphertext = set_fixed_variables(component_id='inverse_' + lblock.get_all_components_ids()[-1], constraint_type='equal', - bit_positions=range(64), bit_values=[0] * 64) - key = set_fixed_variables(component_id='key', constraint_type='equal', bit_positions=range(80), - bit_values=[0]*10+[1]+[0]*69) - trails = mzn.find_all_impossible_xor_differential_trails(4, [plaintext, ciphertext, key], 'Chuffed', 1, 3, 4, False, - probabilistic=True, solve_external=True) - assert trails['total_weight'] == 0.0 - assert len(trails['solutions']) == 6 - assert trails['solutions'][0]['status'] == 'SATISFIABLE' - assert trails['solutions'][0]['components_values']['plaintext'][ - 'value'] == '................................................................' - assert trails['solutions'][0]['components_values']['inverse_cipher_output_3_19'][ - 'value'] == '................................................................' + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 + ) + ciphertext = set_fixed_variables( + component_id="inverse_" + lblock.get_all_components_ids()[-1], + constraint_type="equal", + bit_positions=range(64), + bit_values=(0,) * 64, + ) + key = set_fixed_variables( + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(80), bit_values=[0] * 10 + [1] + [0] * 69 + ) + trails = mzn.find_all_impossible_xor_differential_trails( + 4, [plaintext, ciphertext, key], CHUFFED, 1, 3, 4, False, probabilistic=True, solve_external=True + ) + assert trails["total_weight"] == 0.0 + assert len(trails["solutions"]) == 6 + assert trails["solutions"][0]["status"] == SATISFIABLE + assert ( + trails["solutions"][0]["components_values"][INPUT_PLAINTEXT]["value"] + == "................................................................" + ) + assert ( + trails["solutions"][0]["components_values"]["inverse_cipher_output_3_19"]["value"] + == "................................................................" + ) + def test_find_one_impossible_xor_differential_trail(): lblock = LBlockBlockCipher(number_of_rounds=4) mzn = MznHybridImpossibleXorDifferentialModel(lblock) - plaintext = set_fixed_variables(component_id='plaintext', constraint_type='equal', - bit_positions=range(64), bit_values=[0] * 64) - ciphertext = set_fixed_variables(component_id='inverse_' + lblock.get_all_components_ids()[-1], - constraint_type='equal', - bit_positions=range(64), bit_values=[0] * 64) - key = set_fixed_variables(component_id='key', constraint_type='equal', bit_positions=range(80), - bit_values=[0] * 10 + [1] + [0] * 69) - trail = mzn.find_one_impossible_xor_differential_trail(fixed_values=[plaintext, ciphertext, key], solver_name='Chuffed', middle_round=3, intermediate_components=False, solve_external = True) + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 + ) + ciphertext = set_fixed_variables( + component_id="inverse_" + lblock.get_all_components_ids()[-1], + constraint_type="equal", + bit_positions=range(64), + bit_values=(0,) * 64, + ) + key = set_fixed_variables( + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(80), bit_values=[0] * 10 + [1] + [0] * 69 + ) + trail = mzn.find_one_impossible_xor_differential_trail( + fixed_values=[plaintext, ciphertext, key], + solver_name=CHUFFED, + middle_round=3, + intermediate_components=False, + solve_external=True, + ) + + assert str(trail["cipher"]) == "lblock_p64_k80_o64_r4" + assert trail["model_type"] == "impossible_xor_differential_one_solution" + assert trail["solver_name"] == CHUFFED - assert str(trail['cipher']) == 'lblock_p64_k80_o64_r4' - assert trail['model_type'] == 'impossible_xor_differential_one_solution' - assert trail['solver_name'] == 'Chuffed' + assert ( + trail["components_values"][INPUT_PLAINTEXT]["value"] + == "................................................................" + ) + assert ( + trail["components_values"]["inverse_cipher_output_3_19"]["value"] + == "................................................................" + ) + assert trail["status"] == SATISFIABLE - assert trail['components_values']['plaintext']['value'] == '................................................................' - assert trail['components_values']['inverse_cipher_output_3_19']['value'] == '................................................................' - assert trail['status'] == 'SATISFIABLE' def test_find_one_improbable_xor_differential_trail(): lblock = LBlockBlockCipher(number_of_rounds=4) mzn = MznHybridImpossibleXorDifferentialModel(lblock) - plaintext = set_fixed_variables(component_id='plaintext', constraint_type='equal', - bit_positions=range(64), bit_values=[0] * 64) - ciphertext = set_fixed_variables(component_id='inverse_' + lblock.get_all_components_ids()[-1], - constraint_type='equal', - bit_positions=range(64), bit_values=[0] * 64) - key = set_fixed_variables(component_id='key', constraint_type='equal', bit_positions=range(80), - bit_values=[0] * 10 + [1] + [0] * 69) - trail = mzn.find_one_impossible_xor_differential_trail(4, [plaintext, ciphertext, key], 'Chuffed', 1, 3, 4, False, - probabilistic=True, solve_external=True) + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 + ) + ciphertext = set_fixed_variables( + component_id="inverse_" + lblock.get_all_components_ids()[-1], + constraint_type="equal", + bit_positions=range(64), + bit_values=(0,) * 64, + ) + key = set_fixed_variables( + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(80), bit_values=[0] * 10 + [1] + [0] * 69 + ) + trail = mzn.find_one_impossible_xor_differential_trail( + 4, [plaintext, ciphertext, key], CHUFFED, 1, 3, 4, False, probabilistic=True, solve_external=True + ) - assert str(trail['cipher']) == 'lblock_p64_k80_o64_r4' - assert trail['model_type'] == 'impossible_xor_differential_one_solution' - assert trail['solver_name'] == 'Chuffed' + assert str(trail["cipher"]) == "lblock_p64_k80_o64_r4" + assert trail["model_type"] == "impossible_xor_differential_one_solution" + assert trail["solver_name"] == CHUFFED - assert trail['components_values']['plaintext'][ - 'value'] == '................................................................' - assert trail['components_values']['inverse_cipher_output_3_19'][ - 'value'] == '................................................................' - assert float(trail['total_weight']) in [2.0, 3.0] + assert ( + trail["components_values"][INPUT_PLAINTEXT]["value"] + == "................................................................" + ) + assert ( + trail["components_values"]["inverse_cipher_output_3_19"]["value"] + == "................................................................" + ) + assert float(trail["total_weight"]) in [2.0, 3.0] diff --git a/tests/unit/cipher_modules/models/cp/mzn_models/mzn_impossible_xor_differential_model_test.py b/tests/unit/cipher_modules/models/cp/mzn_models/mzn_impossible_xor_differential_model_test.py index 77343b3db..9c1d3e402 100644 --- a/tests/unit/cipher_modules/models/cp/mzn_models/mzn_impossible_xor_differential_model_test.py +++ b/tests/unit/cipher_modules/models/cp/mzn_models/mzn_impossible_xor_differential_model_test.py @@ -1,147 +1,217 @@ +from claasp.cipher_modules.models.cp.mzn_models.mzn_impossible_xor_differential_model import ( + MznImpossibleXorDifferentialModel, +) +from claasp.cipher_modules.models.cp.solvers import CHUFFED +from claasp.cipher_modules.models.utils import set_fixed_variables from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher -from claasp.cipher_modules.models.utils import set_fixed_variables, integer_to_bit_list -from claasp.cipher_modules.models.cp.mzn_models.mzn_impossible_xor_differential_model import \ - MznImpossibleXorDifferentialModel +from claasp.ciphers.block_ciphers.simon_block_cipher import SimonBlockCipher +from claasp.name_mappings import INPUT_PLAINTEXT, INPUT_KEY, UNSATISFIABLE def test_build_impossible_xor_differential_trail_with_extensions_model(): speck = SpeckBlockCipher(number_of_rounds=6) mzn = MznImpossibleXorDifferentialModel(speck) - fixed_variables = [set_fixed_variables('key', 'equal', range(64), integer_to_bit_list(0, 64, 'little'))] - mzn.build_impossible_xor_differential_trail_with_extensions_model(number_of_rounds=6, fixed_variables=fixed_variables, initial_round=2, middle_round=3, final_round=5, intermediate_components=False) + fixed_variables = [set_fixed_variables(INPUT_KEY, "equal", range(64), (0,) * 64)] + mzn.build_impossible_xor_differential_trail_with_extensions_model( + number_of_rounds=6, + fixed_variables=fixed_variables, + initial_round=2, + middle_round=3, + final_round=5, + intermediate_components=False, + ) assert len(mzn.model_constraints) == 1764 - assert mzn.model_constraints[99] == 'array[0..31] of var 0..2: inverse_plaintext;' - assert mzn.model_constraints[3] == 'array[0..63] of var 0..2: key;' - assert mzn.model_constraints[39] == 'array[0..31] of var 0..2: cipher_output_5_12;' + assert mzn.model_constraints[99] == "array[0..31] of var 0..2: inverse_plaintext;" + assert mzn.model_constraints[3] == "array[0..63] of var 0..2: key;" + assert mzn.model_constraints[39] == "array[0..31] of var 0..2: cipher_output_5_12;" def test_build_impossible_xor_differential_trail_model(): speck = SpeckBlockCipher(number_of_rounds=5) mzn = MznImpossibleXorDifferentialModel(speck) - fixed_variables = [set_fixed_variables('key', 'equal', range(64), integer_to_bit_list(0, 64, 'little'))] - mzn.build_impossible_xor_differential_trail_model(number_of_rounds=5, fixed_variables=fixed_variables, middle_round=3) + fixed_variables = [set_fixed_variables(INPUT_KEY, "equal", range(64), (0,) * 64)] + mzn.build_impossible_xor_differential_trail_model( + number_of_rounds=5, fixed_variables=fixed_variables, middle_round=3 + ) assert len(mzn.model_constraints) == 1661 - assert mzn.model_constraints[2] == 'array[0..31] of var 0..2: plaintext;' - assert mzn.model_constraints[3] == 'array[0..63] of var 0..2: key;' - assert mzn.model_constraints[4] == 'array[0..31] of var 0..2: inverse_cipher_output_4_12;' + assert mzn.model_constraints[2] == "array[0..31] of var 0..2: plaintext;" + assert mzn.model_constraints[3] == "array[0..63] of var 0..2: key;" + assert mzn.model_constraints[4] == "array[0..31] of var 0..2: inverse_cipher_output_4_12;" def test_find_all_impossible_xor_differential_trails(): speck = SpeckBlockCipher(number_of_rounds=7) mzn = MznImpossibleXorDifferentialModel(speck) - plaintext = set_fixed_variables(component_id='plaintext', constraint_type='not_equal', - bit_positions=range(32), bit_values=[0] * 32) - ciphertext = set_fixed_variables(component_id='inverse_' + speck.get_all_components_ids()[-1], constraint_type='not_equal', - bit_positions=range(32), bit_values=[0] * 32) - key = set_fixed_variables('key', constraint_type='equal', - bit_positions=range(64), bit_values=[0] * 64) - trail = mzn.find_all_impossible_xor_differential_trails(7, [plaintext, ciphertext, key], 'Chuffed', 1, 3, 7, False, solve_external = True) - - assert trail[0]['status'] == 'UNSATISFIABLE' + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, constraint_type="not_equal", bit_positions=range(32), bit_values=(0,) * 32 + ) + key = set_fixed_variables( + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 + ) + ciphertext_id = "inverse_" + speck.get_all_components_ids()[-1] + ciphertext = set_fixed_variables( + component_id=ciphertext_id, constraint_type="not_equal", bit_positions=range(32), bit_values=(0,) * 32 + ) + trail = mzn.find_all_impossible_xor_differential_trails( + 7, [plaintext, ciphertext, key], CHUFFED, 1, 3, 7, False, solve_external=True + ) + + assert trail[0]["status"] == UNSATISFIABLE def test_find_lowest_complexity_impossible_xor_differential_trail(): speck = SpeckBlockCipher(number_of_rounds=6) mzn = MznImpossibleXorDifferentialModel(speck) - plaintext = set_fixed_variables(component_id='plaintext', constraint_type='not_equal', - bit_positions=range(32), bit_values=[0] * 32) - ciphertext = set_fixed_variables(component_id='inverse_' + speck.get_all_components_ids()[-1], constraint_type='not_equal', - bit_positions=range(32), bit_values=[0] * 32) - key = set_fixed_variables('key', constraint_type='equal', - bit_positions=range(64), bit_values=[0] * 64) - trail = mzn.find_lowest_complexity_impossible_xor_differential_trail(6, [plaintext, ciphertext, key], 'Chuffed', 1, 3, 6, True, solve_external = True) - - assert str(trail['cipher']) == 'speck_p32_k64_o32_r6' - assert trail['model_type'] == 'impossible_xor_differential_one_solution' - assert trail['solver_name'] == 'Chuffed' - - assert trail['components_values']['plaintext']['value'] == '00000000010000000000000000000000' - assert trail['components_values']['inverse_cipher_output_5_12']['value'] == '10000000000000001000000000000010' - - assert trail['components_values']['xor_1_10']['value'] == '2222222100000010' - assert trail['components_values']['inverse_rot_2_9']['value'] == '2222222210022222' + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, constraint_type="not_equal", bit_positions=range(32), bit_values=(0,) * 32 + ) + key = set_fixed_variables( + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 + ) + ciphertext_id = "inverse_" + speck.get_all_components_ids()[-1] + ciphertext = set_fixed_variables( + component_id=ciphertext_id, constraint_type="not_equal", bit_positions=range(32), bit_values=(0,) * 32 + ) + trail = mzn.find_lowest_complexity_impossible_xor_differential_trail( + 6, [plaintext, ciphertext, key], CHUFFED, 1, 3, 6, True, solve_external=True + ) + + assert str(trail["cipher"]) == "speck_p32_k64_o32_r6" + assert trail["model_type"] == "impossible_xor_differential_one_solution" + assert trail["solver_name"] == CHUFFED + + assert trail["components_values"][INPUT_PLAINTEXT]["value"] != "0" * 32 + assert trail["components_values"][INPUT_KEY]["value"] == "0" * 64 + assert trail["components_values"]["inverse_cipher_output_5_12"]["value"] != "0" * 32 def test_find_one_impossible_xor_differential_trail(): speck = SpeckBlockCipher(number_of_rounds=6) mzn = MznImpossibleXorDifferentialModel(speck) - plaintext = set_fixed_variables(component_id='plaintext', constraint_type='not_equal', - bit_positions=range(32), bit_values=[0] * 32) - ciphertext = set_fixed_variables(component_id='inverse_' + speck.get_all_components_ids()[-1], constraint_type='not_equal', - bit_positions=range(32), bit_values=[0] * 32) - key = set_fixed_variables('key', constraint_type='equal', - bit_positions=range(64), bit_values=[0] * 64) - trail = mzn.find_one_impossible_xor_differential_trail(fixed_values=[plaintext, ciphertext, key], solver_name='Chuffed', middle_round=3, intermediate_components=True, solve_external = True) - - assert str(trail['cipher']) == 'speck_p32_k64_o32_r6' - assert trail['model_type'] == 'impossible_xor_differential_one_solution' - assert trail['solver_name'] == 'Chuffed' - - assert trail['components_values']['plaintext']['value'] == '00000000021000000010000000000000' - assert trail['components_values']['inverse_cipher_output_5_12']['value'] == '10000000000000001000000000000010' - - assert trail['components_values']['xor_1_10']['value'] == '2222222221000022' - assert trail['components_values']['inverse_rot_2_9']['value'] == '2222222210022222' + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, constraint_type="not_equal", bit_positions=range(32), bit_values=(0,) * 32 + ) + key = set_fixed_variables( + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 + ) + ciphertext_id = "inverse_" + speck.get_all_components_ids()[-1] + ciphertext = set_fixed_variables( + component_id=ciphertext_id, constraint_type="not_equal", bit_positions=range(32), bit_values=(0,) * 32 + ) + trail = mzn.find_one_impossible_xor_differential_trail( + fixed_values=[plaintext, ciphertext, key], + solver_name=CHUFFED, + middle_round=3, + intermediate_components=True, + solve_external=True, + ) + + assert str(trail["cipher"]) == "speck_p32_k64_o32_r6" + assert trail["model_type"] == "impossible_xor_differential_one_solution" + assert trail["solver_name"] == CHUFFED + + assert trail["components_values"][INPUT_PLAINTEXT]["value"] != "0" * 32 + assert trail["components_values"][INPUT_KEY]["value"] == "0" * 64 + assert trail["components_values"]["inverse_cipher_output_5_12"]["value"] != "0" * 32 + + +def test_find_one_impossible_xor_differential_trail_with_fully_automatic_model(): + simon = SimonBlockCipher(block_bit_size=32, number_of_rounds=11) + mzn = MznImpossibleXorDifferentialModel(simon) + plaintext = set_fixed_variables( + component_id="plaintext", constraint_type="equal", bit_positions=range(32), bit_values=[0] * 31 + [1] + ) + key = set_fixed_variables(component_id="key", constraint_type="equal", bit_positions=range(64), bit_values=[0] * 64) + ciphertext = set_fixed_variables( + component_id="inverse_cipher_output_10_13", + constraint_type="equal", + bit_positions=range(32), + bit_values=[0] * 6 + [2, 0, 2] + [0] * 23, + ) + trail = mzn.find_one_impossible_xor_differential_trail_with_fully_automatic_model( + fixed_values=[plaintext, key, ciphertext], solver_name=CHUFFED, intermediate_components=False + ) + + assert trail["status"] == "SATISFIABLE" + + assert trail["components_values"]["plaintext"]["value"] == "00000000000000000000000000000001" + assert trail["components_values"]["inverse_cipher_output_10_13"]["value"] == "00000020200000000000000000000000" + + assert trail["components_values"]["intermediate_output_5_12"]["value"] == "22222222222222220222222122222202" + assert trail["components_values"]["inverse_intermediate_output_5_12"]["value"] == "22222222002222202222222022222222" def test_find_one_impossible_xor_differential_trail_with_initial_and_final_round(): speck = SpeckBlockCipher(number_of_rounds=6) mzn = MznImpossibleXorDifferentialModel(speck) - plaintext = set_fixed_variables(component_id='plaintext', constraint_type='not_equal', - bit_positions=range(32), bit_values=[0] * 32) - ciphertext = set_fixed_variables(component_id='inverse_' + speck.get_all_components_ids()[-1], - constraint_type='not_equal', - bit_positions=range(32), bit_values=[0] * 32) - key = set_fixed_variables('key', constraint_type='equal', - bit_positions=range(64), bit_values=[0] * 64) - trail = mzn.find_one_impossible_xor_differential_trail(fixed_values=[plaintext, ciphertext, key], - solver_name='Chuffed', initial_round=1, final_round=6, - intermediate_components=True, solve_external=True) - - assert str(trail['cipher']) == 'speck_p32_k64_o32_r6' - assert trail['model_type'] == 'impossible_xor_differential_one_solution' - assert trail['solver_name'] == 'Chuffed' - - assert trail['components_values']['plaintext']['value'] == '00000000022200000021000000000000' - assert trail['components_values']['inverse_cipher_output_5_12']['value'] == '10000000000000001000000000000010' - - assert trail['components_values']['xor_1_10']['value'] == '2222222222100022' - assert trail['components_values']['inverse_rot_2_9']['value'] == '2222222210022222' + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, constraint_type="not_equal", bit_positions=range(32), bit_values=(0,) * 32 + ) + key = set_fixed_variables( + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 + ) + ciphertext_id = "inverse_" + speck.get_all_components_ids()[-1] + ciphertext = set_fixed_variables( + component_id=ciphertext_id, constraint_type="not_equal", bit_positions=range(32), bit_values=(0,) * 32 + ) + trail = mzn.find_one_impossible_xor_differential_trail( + fixed_values=[plaintext, ciphertext, key], + solver_name=CHUFFED, + initial_round=1, + final_round=6, + intermediate_components=True, + solve_external=True, + ) + + assert str(trail["cipher"]) == "speck_p32_k64_o32_r6" + assert trail["model_type"] == "impossible_xor_differential_one_solution" + assert trail["solver_name"] == CHUFFED + + assert trail["components_values"][INPUT_PLAINTEXT]["value"] != "0" * 32 + assert trail["components_values"][INPUT_KEY]["value"] == "0" * 64 + assert trail["components_values"]["inverse_cipher_output_5_12"]["value"] != "0" * 32 def test_find_one_impossible_xor_differential_trail_with_extensions(): speck = SpeckBlockCipher(number_of_rounds=6) mzn = MznImpossibleXorDifferentialModel(speck) - plaintext = set_fixed_variables(component_id='inverse_plaintext', constraint_type='not_equal', - bit_positions=range(32), bit_values=[0] * 32) - ciphertext = set_fixed_variables(component_id=speck.get_all_components_ids()[-1], constraint_type='not_equal', - bit_positions=range(32), bit_values=[0] * 32) - key = set_fixed_variables('key', constraint_type='equal', - bit_positions=range(64), bit_values=[0] * 64) - trail = mzn.find_one_impossible_xor_differential_trail_with_extensions(6, [plaintext, ciphertext, key], 'Chuffed', 2, 3, 5, True, solve_external = True) - - assert str(trail['cipher']) == 'speck_p32_k64_o32_r6' - assert trail['model_type'] == 'impossible_xor_differential_one_solution' - assert trail['solver_name'] == 'Chuffed' - - assert trail['components_values']['inverse_plaintext']['value'] == '22222220022222220000100000022200' - assert trail['components_values']['inverse_cipher_output_5_12']['value'] == '22222210000000002222221000000011' - - assert trail['components_values']['intermediate_output_2_12']['value'] == '22222222220000002222222222000022' - assert trail['components_values']['inverse_intermediate_output_2_12']['value'] == '22222222222222222222222222122222' + plaintext = set_fixed_variables( + component_id="inverse_plaintext", constraint_type="not_equal", bit_positions=range(32), bit_values=(0,) * 32 + ) + key = set_fixed_variables( + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 + ) + ciphertext_id = speck.get_all_components_ids()[-1] + ciphertext = set_fixed_variables( + component_id=ciphertext_id, constraint_type="not_equal", bit_positions=range(32), bit_values=(0,) * 32 + ) + trail = mzn.find_one_impossible_xor_differential_trail_with_extensions( + 6, [plaintext, ciphertext, key], CHUFFED, 2, 3, 5, True, solve_external=True + ) + + assert str(trail["cipher"]) == "speck_p32_k64_o32_r6" + assert trail["model_type"] == "impossible_xor_differential_one_solution" + assert trail["solver_name"] == CHUFFED + + assert trail["components_values"]["inverse_plaintext"]["value"] != "0" * 32 + assert trail["components_values"]["inverse_cipher_output_5_12"]["value"] != "0" * 32 def test_find_one_impossible_xor_differential_cluster(): speck = SpeckBlockCipher(block_bit_size=32, key_bit_size=64, number_of_rounds=4) mzn = MznImpossibleXorDifferentialModel(speck) - fixed_variables = [set_fixed_variables('key', 'equal', range(64), integer_to_bit_list(0, 64, 'little')), - set_fixed_variables('plaintext', 'not_equal', range(32), integer_to_bit_list(0, 32, 'little')), - set_fixed_variables('inverse_cipher_output_3_12', 'not_equal', range(32), integer_to_bit_list(0, 32, 'little'))] - trail = mzn.find_one_impossible_xor_differential_cluster(4, fixed_variables, 'Chuffed', 1, 3, 4, intermediate_components=False) - assert str(trail['cipher']) == 'speck_p32_k64_o32_r4' - assert trail['model_type'] == 'impossible_xor_differential_one_solution' - assert trail['solver_name'] == 'Chuffed' - assert trail['components_values']['key']['value'] == '0000000000000000000000000000000000000000000000000000000000000000' - assert trail['status'] == 'SATISFIABLE' + fixed_variables = [ + set_fixed_variables(INPUT_KEY, "equal", range(64), (0,) * 64), + set_fixed_variables(INPUT_PLAINTEXT, "not_equal", range(32), (0,) * 32), + set_fixed_variables("inverse_cipher_output_3_12", "not_equal", range(32), (0,) * 32), + ] + trail = mzn.find_one_impossible_xor_differential_cluster( + 4, fixed_variables, CHUFFED, 1, 3, 4, intermediate_components=False + ) + assert str(trail["cipher"]) == "speck_p32_k64_o32_r4" + assert trail["model_type"] == "impossible_xor_differential_one_solution" + assert trail["solver_name"] == CHUFFED + assert trail["components_values"][INPUT_KEY]["value"] == "0" * 64 + assert trail["status"] == "SATISFIABLE" diff --git a/tests/unit/cipher_modules/models/cp/mzn_models/mzn_wordwise_deterministic_truncated_xor_differential_model_test.py b/tests/unit/cipher_modules/models/cp/mzn_models/mzn_wordwise_deterministic_truncated_xor_differential_model_test.py index c5cbb0398..984e09c3d 100644 --- a/tests/unit/cipher_modules/models/cp/mzn_models/mzn_wordwise_deterministic_truncated_xor_differential_model_test.py +++ b/tests/unit/cipher_modules/models/cp/mzn_models/mzn_wordwise_deterministic_truncated_xor_differential_model_test.py @@ -1,27 +1,29 @@ -from claasp.ciphers.block_ciphers.aes_block_cipher import AESBlockCipher -from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher +from claasp.cipher_modules.models.cp.mzn_models.mzn_wordwise_deterministic_truncated_xor_differential_model import ( + MznWordwiseDeterministicTruncatedXorDifferentialModel, +) +from claasp.cipher_modules.models.cp.solvers import CHUFFED from claasp.cipher_modules.models.utils import set_fixed_variables, integer_to_bit_list -from claasp.cipher_modules.models.cp.mzn_models.mzn_wordwise_deterministic_truncated_xor_differential_model import \ - MznWordwiseDeterministicTruncatedXorDifferentialModel +from claasp.ciphers.block_ciphers.aes_block_cipher import AESBlockCipher def test_find_one_wordwise_deterministic_truncated_xor_differential_trail(): aes = AESBlockCipher(number_of_rounds=2) mzn = MznWordwiseDeterministicTruncatedXorDifferentialModel(aes) - fixed_variables = [set_fixed_variables('key_value', 'equal', range(16), integer_to_bit_list(0, 16, 'little'))] + fixed_variables = [set_fixed_variables("key_value", "equal", range(16), (0,) * 16)] mzn.build_deterministic_truncated_xor_differential_trail_model(fixed_variables, wordwise=True) assert len(mzn.model_constraints) == 1361 - assert mzn.model_constraints[2] == 'array[0..15] of var 0..3: key_active;' - assert mzn.model_constraints[3] == 'array[0..15] of var -2..255: key_value;' - assert mzn.model_constraints[4] == 'array[0..15] of var 0..3: plaintext_active;' - assert mzn.model_constraints[5] == 'array[0..15] of var -2..255: plaintext_value;' + assert mzn.model_constraints[2] == "array[0..15] of var 0..3: key_active;" + assert mzn.model_constraints[3] == "array[0..15] of var -2..255: key_value;" + assert mzn.model_constraints[4] == "array[0..15] of var 0..3: plaintext_active;" + assert mzn.model_constraints[5] == "array[0..15] of var -2..255: plaintext_value;" -''' + +""" def test_build_wordwise_deterministic_truncated_xor_differential_trail_model(): aes = AESBlockCipher(number_of_rounds=2) mzn = MznWordwiseDeterministicTruncatedXorDifferentialModel(aes) - fixed_variables = [set_fixed_variables('key_value', 'equal', range(16), integer_to_bit_list(0, 16, 'little'))] + fixed_variables = [set_fixed_variables('key_value', 'equal', range(16), (0,) * 16)] mzn.build_wordwise_deterministic_truncated_xor_differential_trail_model(fixed_variables) assert len(mzn.model_constraints) == 1361 @@ -35,16 +37,16 @@ def test_find_one_wordwise_deterministic_truncated_xor_differential_trail(): aes = AESBlockCipher(number_of_rounds=2) mzn = MznWordwiseDeterministicTruncatedXorDifferentialModel(aes) plaintext = set_fixed_variables(component_id='plaintext_value', constraint_type='not_equal', - bit_positions=range(16), bit_values=[0] * 16) - key = set_fixed_variables(component_id='key_value', constraint_type='equal', bit_positions=range(16), bit_values=[0] * 16) - trail = mzn.find_one_wordwise_deterministic_truncated_xor_differential_trail(1, [plaintext, key], 'Chuffed', solve_external = True) + bit_positions=range(16), bit_values=(0,) * 16) + key = set_fixed_variables(component_id='key_value', constraint_type='equal', bit_positions=range(16), bit_values=(0,) * 16) + trail = mzn.find_one_wordwise_deterministic_truncated_xor_differential_trail(1, [plaintext, key], 'chuffed', solve_external = True) assert str(trail[0]['cipher']) == 'speck_p32_k64_o32_r1' assert trail[0]['components_values']['key']['value'] == '000000000000000000000000000000000000000000000000000000' \ '0000000000' assert trail[0]['model_type'] == 'deterministic_truncated_xor_differential_one_solution' - assert trail[0]['solver_name'] == 'Chuffed' + assert trail[0]['solver_name'] == 'chuffed' trail = mzn.find_one_wordwise_deterministic_truncated_xor_differential_trail(1, [plaintext, key], 'chuffed', solve_external = False) @@ -54,4 +56,4 @@ def test_find_one_wordwise_deterministic_truncated_xor_differential_trail(): '0000000000' assert trail[0]['model_type'] == 'deterministic_truncated_xor_differential_one_solution' assert trail[0]['solver_name'] == 'chuffed' -''' +""" diff --git a/tests/unit/cipher_modules/models/cp/mzn_models/mzn_xor_differential_model_arx_optimized_test.py b/tests/unit/cipher_modules/models/cp/mzn_models/mzn_xor_differential_model_arx_optimized_test.py index 9a6f009cb..14ad79a0e 100644 --- a/tests/unit/cipher_modules/models/cp/mzn_models/mzn_xor_differential_model_arx_optimized_test.py +++ b/tests/unit/cipher_modules/models/cp/mzn_models/mzn_xor_differential_model_arx_optimized_test.py @@ -1,29 +1,38 @@ -from claasp.ciphers.block_ciphers.tea_block_cipher import TeaBlockCipher -from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher +from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_differential_model_arx_optimized import ( + MznXorDifferentialModelARXOptimized, +) from claasp.ciphers.block_ciphers.raiden_block_cipher import RaidenBlockCipher -from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_differential_model_arx_optimized import \ - MznXorDifferentialModelARXOptimized +from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher +from claasp.ciphers.block_ciphers.tea_block_cipher import TeaBlockCipher -speck4 = SpeckBlockCipher(number_of_rounds=4, block_bit_size=32, key_bit_size=64) -mzn4 = MznXorDifferentialModelARXOptimized(speck4) +from claasp.cipher_modules.models.cp.solvers import CPSAT -speck5 = SpeckBlockCipher(number_of_rounds=5, block_bit_size=32, key_bit_size=64) -mzn5 = MznXorDifferentialModelARXOptimized(speck5) +SPECK4 = SpeckBlockCipher(number_of_rounds=4, block_bit_size=32, key_bit_size=64) +MZN4 = MznXorDifferentialModelARXOptimized(SPECK4) + +SPECK5 = SpeckBlockCipher(number_of_rounds=5, block_bit_size=32, key_bit_size=64) +MZN5 = MznXorDifferentialModelARXOptimized(SPECK5) def generate_fixed_variables(block_size, key_size): bit_positions = list(range(block_size)) bit_positions_key = list(range(key_size)) - fixed_variables = [{'component_id': 'plaintext', - 'constraint_type': 'sum', - 'bit_positions': bit_positions, - 'operator': '>', - 'value': '0'}, - {'component_id': 'key', - 'constraint_type': 'sum', - 'bit_positions': bit_positions_key, - 'operator': '=', - 'value': '0'}] + fixed_variables = [ + { + "component_id": "plaintext", + "constraint_type": "sum", + "bit_positions": bit_positions, + "operator": ">", + "value": "0", + }, + { + "component_id": "key", + "constraint_type": "sum", + "bit_positions": bit_positions_key, + "operator": "=", + "value": "0", + }, + ] return fixed_variables @@ -32,125 +41,138 @@ def generate_fixed_variables(block_size, key_size): def test_build_lowest_weight_xor_differential_trail_model(): - mzn5.build_lowest_weight_xor_differential_trail_model(fixed_variables_32_64) - result = mzn5.solve_for_ARX('Xor') - assert result.statistics['nSolutions'] > 1 + MZN5.build_lowest_weight_xor_differential_trail_model(fixed_variables_32_64) + result = MZN5.solve_for_ARX(CPSAT) + assert result.statistics["nSolutions"] > 1 def test_build_lowest_xor_differential_trails_with_at_most_weight(): - mzn5.build_lowest_xor_differential_trails_with_at_most_weight(100, fixed_variables_32_64) - result = mzn5.solve_for_ARX('Xor') + MZN5.build_lowest_xor_differential_trails_with_at_most_weight(100, fixed_variables_32_64) + result = MZN5.solve_for_ARX(CPSAT) - assert result.statistics['nSolutions'] > 1 + assert result.statistics["nSolutions"] > 1 def test_find_all_xor_differential_trails_with_fixed_weight(): - result = mzn5.find_all_xor_differential_trails_with_fixed_weight( - 5, solver_name='Xor', fixed_values=fixed_variables_32_64) + result = MZN5.find_all_xor_differential_trails_with_fixed_weight( + 5, solver_name=CPSAT, fixed_values=fixed_variables_32_64 + ) - assert result['total_weight'] is None + assert result["total_weight"] is None def test_find_all_xor_differential_trails_with_weight_at_most(): - result = mzn4.find_all_xor_differential_trails_with_weight_at_most( - 1, solver_name='Xor', fixed_values=fixed_variables_32_64 + result = MZN4.find_all_xor_differential_trails_with_weight_at_most( + 1, solver_name=CPSAT, fixed_values=fixed_variables_32_64 ) - assert result[0]['total_weight'] > 1 + assert result[0]["total_weight"] > 1 def test_find_lowest_weight_xor_differential_trail(): - result = mzn5.find_lowest_weight_xor_differential_trail(solver_name='Xor', fixed_values=fixed_variables_32_64) + result = MZN5.find_lowest_weight_xor_differential_trail(solver_name=CPSAT, fixed_values=fixed_variables_32_64) assert result["total_weight"] == 9 - mzn = MznXorDifferentialModelARXOptimized(speck5, [0, 0, 0, 0, 0]) - result = mzn.find_lowest_weight_xor_differential_trail(solver_name='Xor', fixed_values=fixed_variables_32_64) + mzn = MznXorDifferentialModelARXOptimized(SPECK5, [0, 0, 0, 0, 0]) + result = mzn.find_lowest_weight_xor_differential_trail(solver_name=CPSAT, fixed_values=fixed_variables_32_64) assert result["total_weight"] == 9 speck = SpeckBlockCipher(number_of_rounds=4) mzn = MznXorDifferentialModelARXOptimized(speck) bit_positions_key = list(range(64)) - fixed_variables = [{'component_id': 'key', - 'constraint_type': 'sum', - 'bit_positions': bit_positions_key, - 'operator': '>', - 'value': '0'}] - result = mzn.find_lowest_weight_xor_differential_trail(solver_name='Xor', fixed_values=fixed_variables) + fixed_variables = [ + { + "component_id": "key", + "constraint_type": "sum", + "bit_positions": bit_positions_key, + "operator": ">", + "value": "0", + } + ] + result = mzn.find_lowest_weight_xor_differential_trail(solver_name=CPSAT, fixed_values=fixed_variables) assert result["total_weight"] == 0 tea = TeaBlockCipher(number_of_rounds=2) mzn = MznXorDifferentialModelARXOptimized(tea) - result = mzn.find_lowest_weight_xor_differential_trail(solver_name='Xor', fixed_values=fixed_variables_64_128) + result = mzn.find_lowest_weight_xor_differential_trail(solver_name=CPSAT, fixed_values=fixed_variables_64_128) assert result["total_weight"] > 1 raiden = RaidenBlockCipher(number_of_rounds=2) mzn = MznXorDifferentialModelARXOptimized(raiden, sat_or_milp="milp") - result = mzn.find_lowest_weight_xor_differential_trail(solver_name='Xor', fixed_values=fixed_variables_64_128) - assert result['total_weight'] == 6 - - pr_weights_per_round = [{"min_bound": 2, "max_bound": 4}, - {"min_bound": 2, "max_bound": 4}, - {"min_bound": 2, "max_bound": 4}, - {"min_bound": 2, "max_bound": 4}, - {"min_bound": 2, "max_bound": 4}] - mzn = MznXorDifferentialModelARXOptimized(speck5, probability_weight_per_round=pr_weights_per_round) - solution = mzn.find_lowest_weight_xor_differential_trail(solver_name='Xor', fixed_values=fixed_variables_32_64) - round1_weight = solution['component_values']['modadd_0_1']['weight'] + result = mzn.find_lowest_weight_xor_differential_trail(solver_name=CPSAT, fixed_values=fixed_variables_64_128) + assert result["total_weight"] == 6 + + pr_weights_per_round = [ + {"min_bound": 2, "max_bound": 4}, + {"min_bound": 2, "max_bound": 4}, + {"min_bound": 2, "max_bound": 4}, + {"min_bound": 2, "max_bound": 4}, + {"min_bound": 2, "max_bound": 4}, + ] + mzn = MznXorDifferentialModelARXOptimized(SPECK5, probability_weight_per_round=pr_weights_per_round) + solution = mzn.find_lowest_weight_xor_differential_trail(solver_name=CPSAT, fixed_values=fixed_variables_32_64) + round1_weight = solution["component_values"]["modadd_0_1"]["weight"] assert 2 <= round1_weight <= 4 - component_values = solution['component_values'] - round2_weight = component_values['modadd_1_2']['weight'] + component_values['modadd_1_7']['weight'] - round3_weight = component_values['modadd_2_2']['weight'] + component_values['modadd_2_7']['weight'] - round4_weight = component_values['modadd_3_2']['weight'] + component_values['modadd_3_7']['weight'] - round5_weight = component_values['modadd_4_2']['weight'] + component_values['modadd_4_7']['weight'] + component_values = solution["component_values"] + round2_weight = component_values["modadd_1_2"]["weight"] + component_values["modadd_1_7"]["weight"] + round3_weight = component_values["modadd_2_2"]["weight"] + component_values["modadd_2_7"]["weight"] + round4_weight = component_values["modadd_3_2"]["weight"] + component_values["modadd_3_7"]["weight"] + round5_weight = component_values["modadd_4_2"]["weight"] + component_values["modadd_4_7"]["weight"] assert 2 <= round2_weight <= 4 assert 2 <= round3_weight <= 4 assert 2 <= round4_weight <= 4 assert 2 <= round5_weight <= 4 - mzn = MznXorDifferentialModelARXOptimized(speck5, sat_or_milp="milp") - result = mzn.find_lowest_weight_xor_differential_trail(solver_name='Xor', fixed_values=fixed_variables_32_64) + mzn = MznXorDifferentialModelARXOptimized(SPECK5, sat_or_milp="milp") + result = mzn.find_lowest_weight_xor_differential_trail(solver_name=CPSAT, fixed_values=fixed_variables_32_64) assert result["total_weight"] == 9 - mzn = MznXorDifferentialModelARXOptimized(speck5, [0, 0, 0, 0, 0]) - result = mzn.find_lowest_weight_xor_differential_trail(solver_name='Xor', fixed_values=fixed_variables_32_64) + mzn = MznXorDifferentialModelARXOptimized(SPECK5, [0, 0, 0, 0, 0]) + result = mzn.find_lowest_weight_xor_differential_trail(solver_name=CPSAT, fixed_values=fixed_variables_32_64) assert result["total_weight"] == 9 speck = SpeckBlockCipher(number_of_rounds=4) mzn = MznXorDifferentialModelARXOptimized(speck, sat_or_milp="milp") bit_positions_key = list(range(64)) - fixed_variables = [{'component_id': 'key', - 'constraint_type': 'sum', - 'bit_positions': bit_positions_key, - 'operator': '>', - 'value': '0'}] - result = mzn.find_lowest_weight_xor_differential_trail(solver_name='Xor', fixed_values=fixed_variables) + fixed_variables = [ + { + "component_id": "key", + "constraint_type": "sum", + "bit_positions": bit_positions_key, + "operator": ">", + "value": "0", + } + ] + result = mzn.find_lowest_weight_xor_differential_trail(solver_name=CPSAT, fixed_values=fixed_variables) assert result["total_weight"] == 0 tea = TeaBlockCipher(number_of_rounds=2) mzn = MznXorDifferentialModelARXOptimized(tea, sat_or_milp="milp") - result = mzn.find_lowest_weight_xor_differential_trail(solver_name='Xor', fixed_values=fixed_variables_64_128) + result = mzn.find_lowest_weight_xor_differential_trail(solver_name=CPSAT, fixed_values=fixed_variables_64_128) assert result["total_weight"] > 1 raiden = RaidenBlockCipher(number_of_rounds=2) mzn = MznXorDifferentialModelARXOptimized(raiden, sat_or_milp="milp") - result = mzn.find_lowest_weight_xor_differential_trail(solver_name='Xor', fixed_values=fixed_variables_64_128) - assert result['total_weight'] == 6 - - pr_weights_per_round = [{"min_bound": 2, "max_bound": 4}, - {"min_bound": 2, "max_bound": 4}, - {"min_bound": 2, "max_bound": 4}, - {"min_bound": 2, "max_bound": 4}, - {"min_bound": 2, "max_bound": 4}] - mzn = MznXorDifferentialModelARXOptimized(speck5, - probability_weight_per_round=pr_weights_per_round, - sat_or_milp="milp") - solution = mzn.find_lowest_weight_xor_differential_trail(solver_name='Xor', fixed_values=fixed_variables_32_64) - round1_weight = solution['component_values']['modadd_0_1']['weight'] + result = mzn.find_lowest_weight_xor_differential_trail(solver_name=CPSAT, fixed_values=fixed_variables_64_128) + assert result["total_weight"] == 6 + + pr_weights_per_round = [ + {"min_bound": 2, "max_bound": 4}, + {"min_bound": 2, "max_bound": 4}, + {"min_bound": 2, "max_bound": 4}, + {"min_bound": 2, "max_bound": 4}, + {"min_bound": 2, "max_bound": 4}, + ] + mzn = MznXorDifferentialModelARXOptimized( + SPECK5, probability_weight_per_round=pr_weights_per_round, sat_or_milp="milp" + ) + solution = mzn.find_lowest_weight_xor_differential_trail(solver_name=CPSAT, fixed_values=fixed_variables_32_64) + round1_weight = solution["component_values"]["modadd_0_1"]["weight"] assert 2 <= round1_weight <= 4 - component_values = solution['component_values'] - round2_weight = component_values['modadd_1_2']['weight'] + component_values['modadd_1_7']['weight'] - round3_weight = component_values['modadd_2_2']['weight'] + component_values['modadd_2_7']['weight'] - round4_weight = component_values['modadd_3_2']['weight'] + component_values['modadd_3_7']['weight'] - round5_weight = component_values['modadd_4_2']['weight'] + component_values['modadd_4_7']['weight'] + component_values = solution["component_values"] + round2_weight = component_values["modadd_1_2"]["weight"] + component_values["modadd_1_7"]["weight"] + round3_weight = component_values["modadd_2_2"]["weight"] + component_values["modadd_2_7"]["weight"] + round4_weight = component_values["modadd_3_2"]["weight"] + component_values["modadd_3_7"]["weight"] + round5_weight = component_values["modadd_4_2"]["weight"] + component_values["modadd_4_7"]["weight"] assert 2 <= round2_weight <= 4 assert 2 <= round3_weight <= 4 assert 2 <= round4_weight <= 4 @@ -158,29 +180,29 @@ def test_find_lowest_weight_xor_differential_trail(): def test_find_lowest_weight_for_short_xor_differential_trail(): - mzn4.set_max_number_of_carries_on_arx_cipher(0) - mzn4.set_max_number_of_nonlinear_carries(0) - result = mzn4.find_lowest_weight_xor_differential_trail(solver_name='Xor', fixed_values=fixed_variables_32_64) + MZN4.set_max_number_of_carries_on_arx_cipher(0) + MZN4.set_max_number_of_nonlinear_carries(0) + result = MZN4.find_lowest_weight_xor_differential_trail(solver_name=CPSAT, fixed_values=fixed_variables_32_64) assert result["total_weight"] == 5 def test_get_probability_vars_from_key_schedule(): - mzn = MznXorDifferentialModelARXOptimized(speck4) + mzn = MznXorDifferentialModelARXOptimized(SPECK4) mzn.build_xor_differential_trail_model(fixed_variables=[]) - expected_result = ['p_modadd_1_2_0', 'p_modadd_2_2_0', 'p_modadd_3_2_0'] + expected_result = ["p_modadd_1_2_0", "p_modadd_2_2_0", "p_modadd_3_2_0"] assert set(mzn.get_probability_vars_from_key_schedule()) == set(expected_result) def test_get_probability_vars_from_permutation(): - mzn = MznXorDifferentialModelARXOptimized(speck4) + mzn = MznXorDifferentialModelARXOptimized(SPECK4) mzn.build_xor_differential_trail_model(fixed_variables=[]) - expected_result = ['p_modadd_0_1_0', 'p_modadd_1_7_0', 'p_modadd_2_7_0', 'p_modadd_3_7_0'] + expected_result = ["p_modadd_0_1_0", "p_modadd_1_7_0", "p_modadd_2_7_0", "p_modadd_3_7_0"] assert set(mzn.get_probability_vars_from_permutation()) == set(expected_result) def test_find_min_of_max_xor_differential_between_permutation_and_key_schedule(): - mzn = MznXorDifferentialModelARXOptimized(speck4) + mzn = MznXorDifferentialModelARXOptimized(SPECK4) result = mzn.find_min_of_max_xor_differential_between_permutation_and_key_schedule( - fixed_values=fixed_variables_32_64, solver_name='Xor' + fixed_values=fixed_variables_32_64, solver_name=CPSAT ) - assert result['total_weight'] == 5 + assert result["total_weight"] == 5 diff --git a/tests/unit/cipher_modules/models/cp/mzn_models/mzn_xor_differential_model_test.py b/tests/unit/cipher_modules/models/cp/mzn_models/mzn_xor_differential_model_test.py index 82b835918..e7f4695a5 100644 --- a/tests/unit/cipher_modules/models/cp/mzn_models/mzn_xor_differential_model_test.py +++ b/tests/unit/cipher_modules/models/cp/mzn_models/mzn_xor_differential_model_test.py @@ -1,7 +1,11 @@ -from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher -from claasp.cipher_modules.models.utils import set_fixed_variables from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_differential_model import ( - MznXorDifferentialModel, and_xor_differential_probability_ddt) + and_xor_differential_probability_ddt, + MznXorDifferentialModel, +) +from claasp.cipher_modules.models.cp.solvers import CHUFFED +from claasp.cipher_modules.models.utils import set_fixed_variables +from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher +from claasp.name_mappings import INPUT_PLAINTEXT, UNSATISFIABLE def test_and_xor_differential_probability_ddt(): @@ -11,34 +15,35 @@ def test_and_xor_differential_probability_ddt(): def test_find_all_xor_differential_trails_with_fixed_weight(): speck = SpeckBlockCipher(block_bit_size=8, key_bit_size=16, number_of_rounds=2) mzn = MznXorDifferentialModel(speck) - trails = mzn.find_all_xor_differential_trails_with_fixed_weight(1, solver_name='Chuffed', solve_external=True) + trails = mzn.find_all_xor_differential_trails_with_fixed_weight(1, solver_name=CHUFFED, solve_external=True) assert len(trails) == 6 - trails = mzn.find_all_xor_differential_trails_with_fixed_weight(1, solver_name='chuffed', solve_external=False) + trails = mzn.find_all_xor_differential_trails_with_fixed_weight(1, solver_name=CHUFFED, solve_external=False) assert len(trails) == 6 - + + def test_solving_unsatisfiability(): speck = SpeckBlockCipher(block_bit_size=8, key_bit_size=16, number_of_rounds=4) mzn = MznXorDifferentialModel(speck) - trails = mzn.find_one_xor_differential_trail_with_fixed_weight(1, solver_name='Chuffed', solve_external=True) + trails = mzn.find_one_xor_differential_trail_with_fixed_weight(1, solver_name=CHUFFED, solve_external=True) - assert trails['status'] == 'UNSATISFIABLE' + assert trails["status"] == UNSATISFIABLE - trails = mzn.find_one_xor_differential_trail_with_fixed_weight(1, solver_name='chuffed', solve_external=False) + trails = mzn.find_one_xor_differential_trail_with_fixed_weight(1, solver_name=CHUFFED, solve_external=False) - assert trails['status'] == 'UNSATISFIABLE' + assert trails["status"] == UNSATISFIABLE def test_find_all_xor_differential_trails_with_weight_at_most(): speck = SpeckBlockCipher(block_bit_size=8, key_bit_size=16, number_of_rounds=2) mzn = MznXorDifferentialModel(speck) - trails = mzn.find_all_xor_differential_trails_with_weight_at_most(0, 1, solver_name='Chuffed', solve_external=True) + trails = mzn.find_all_xor_differential_trails_with_weight_at_most(0, 1, solver_name=CHUFFED, solve_external=True) assert len(trails) == 7 - trails = mzn.find_all_xor_differential_trails_with_weight_at_most(0, 1, solver_name='chuffed', solve_external=False) + trails = mzn.find_all_xor_differential_trails_with_weight_at_most(0, 1, solver_name=CHUFFED, solve_external=False) assert len(trails) == 7 @@ -46,68 +51,77 @@ def test_find_all_xor_differential_trails_with_weight_at_most(): def test_find_lowest_weight_xor_differential_trail(): speck = SpeckBlockCipher(number_of_rounds=5) mzn = MznXorDifferentialModel(speck) - trail = mzn.find_lowest_weight_xor_differential_trail(solver_name='Chuffed', solve_external=True) + trail = mzn.find_lowest_weight_xor_differential_trail(solver_name=CHUFFED, solve_external=True) + + assert str(trail["cipher"]) == "speck_p32_k64_o32_r5" + assert trail["total_weight"] == "9.0" + assert int(trail["components_values"]["cipher_output_4_12"]["value"], base=16) >= 0 + assert trail["components_values"]["cipher_output_4_12"]["weight"] == 0 + + trail = mzn.find_lowest_weight_xor_differential_trail(solver_name=CHUFFED, solve_external=True) - assert str(trail['cipher']) == 'speck_p32_k64_o32_r5' - assert trail['total_weight'] == '9.0' - assert eval(trail['components_values']['cipher_output_4_12']['value']) >= 0 - assert trail['components_values']['cipher_output_4_12']['weight'] == 0 + assert str(trail["cipher"]) == "speck_p32_k64_o32_r5" + assert trail["total_weight"] == "9.0" + assert int(trail["components_values"]["cipher_output_4_12"]["value"], base=16) >= 0 + assert trail["components_values"]["cipher_output_4_12"]["weight"] == 0 - trail = mzn.find_lowest_weight_xor_differential_trail(solver_name='chuffed', solve_external=False) + trail = mzn.find_lowest_weight_xor_differential_trail(solver_name=CHUFFED, solve_external=False) - assert str(trail['cipher']) == 'speck_p32_k64_o32_r5' - assert trail['total_weight'] == '9.0' - assert eval(trail['components_values']['cipher_output_4_12']['value']) >= 0 - assert trail['components_values']['cipher_output_4_12']['weight'] == 0 + assert str(trail["cipher"]) == "speck_p32_k64_o32_r5" + assert trail["total_weight"] == "9.0" + assert int(trail["components_values"]["cipher_output_4_12"]["value"], base=16) >= 0 + assert trail["components_values"]["cipher_output_4_12"]["weight"] == 0 def test_find_one_xor_differential_trail(): speck = SpeckBlockCipher(number_of_rounds=2) mzn = MznXorDifferentialModel(speck) - plaintext = set_fixed_variables(component_id='plaintext', constraint_type='not_equal', - bit_positions=range(32), bit_values=[0] * 32) - trail = mzn.find_one_xor_differential_trail([plaintext], 'Chuffed', solve_external=True) + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, constraint_type="not_equal", bit_positions=range(32), bit_values=(0,) * 32 + ) + trail = mzn.find_one_xor_differential_trail([plaintext], CHUFFED, solve_external=True) - assert str(trail['cipher']) == 'speck_p32_k64_o32_r2' - assert trail['model_type'] == 'xor_differential_one_solution' - assert eval(trail['components_values']['cipher_output_1_12']['value']) >= 0 - assert trail['components_values']['cipher_output_1_12']['weight'] == 0 - assert eval(trail['total_weight']) >= 0 + assert str(trail["cipher"]) == "speck_p32_k64_o32_r2" + assert trail["model_type"] == "xor_differential_one_solution" + assert int(trail["components_values"]["cipher_output_1_12"]["value"], base=16) >= 0 + assert trail["components_values"]["cipher_output_1_12"]["weight"] == 0 + assert float(trail["total_weight"]) >= 0 - trail = mzn.find_one_xor_differential_trail([plaintext], 'chuffed', solve_external=False) + trail = mzn.find_one_xor_differential_trail([plaintext], CHUFFED, solve_external=False) - assert str(trail['cipher']) == 'speck_p32_k64_o32_r2' - assert trail['model_type'] == 'xor_differential_one_solution' - assert eval(trail['components_values']['cipher_output_1_12']['value']) >= 0 - assert trail['components_values']['cipher_output_1_12']['weight'] == 0 - assert eval(trail['total_weight']) >= 0 + assert str(trail["cipher"]) == "speck_p32_k64_o32_r2" + assert trail["model_type"] == "xor_differential_one_solution" + assert int(trail["components_values"]["cipher_output_1_12"]["value"], base=16) >= 0 + assert trail["components_values"]["cipher_output_1_12"]["weight"] == 0 + assert float(trail["total_weight"]) >= 0 def test_find_one_xor_differential_trail_with_fixed_weight(): speck = SpeckBlockCipher(number_of_rounds=5) mzn = MznXorDifferentialModel(speck) - plaintext = set_fixed_variables(component_id='plaintext', constraint_type='not_equal', - bit_positions=range(32), bit_values=[0] * 32) - trail = mzn.find_one_xor_differential_trail_with_fixed_weight(9, [plaintext], 'Chuffed', solve_external=True) - - assert str(trail['cipher']) == 'speck_p32_k64_o32_r5' - assert trail['model_type'] == 'xor_differential_one_solution' - assert eval(trail['components_values']['intermediate_output_0_5']['value']) >= 0 - assert trail['components_values']['intermediate_output_0_5']['weight'] == 0 - assert eval(trail['components_values']['intermediate_output_1_11']['value']) >= 0 - assert trail['components_values']['intermediate_output_1_11']['weight'] == 0 - assert eval(trail['components_values']['xor_3_8']['value']) >= 0 - assert trail['components_values']['xor_3_8']['weight'] == 0 - assert trail['total_weight'] == '9.0' - - trail = mzn.find_one_xor_differential_trail_with_fixed_weight(9, [plaintext], 'chuffed', solve_external=False) - - assert str(trail['cipher']) == 'speck_p32_k64_o32_r5' - assert trail['model_type'] == 'xor_differential_one_solution' - assert eval(trail['components_values']['intermediate_output_0_5']['value']) >= 0 - assert trail['components_values']['intermediate_output_0_5']['weight'] == 0 - assert eval(trail['components_values']['intermediate_output_1_11']['value']) >= 0 - assert trail['components_values']['intermediate_output_1_11']['weight'] == 0 - assert eval(trail['components_values']['xor_3_8']['value']) >= 0 - assert trail['components_values']['xor_3_8']['weight'] == 0 - assert trail['total_weight'] == '9.0' + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, constraint_type="not_equal", bit_positions=range(32), bit_values=(0,) * 32 + ) + trail = mzn.find_one_xor_differential_trail_with_fixed_weight(9, [plaintext], CHUFFED, solve_external=True) + + assert str(trail["cipher"]) == "speck_p32_k64_o32_r5" + assert trail["model_type"] == "xor_differential_one_solution" + assert int(trail["components_values"]["intermediate_output_0_5"]["value"], base=16) >= 0 + assert trail["components_values"]["intermediate_output_0_5"]["weight"] == 0 + assert int(trail["components_values"]["intermediate_output_1_11"]["value"], base=16) >= 0 + assert trail["components_values"]["intermediate_output_1_11"]["weight"] == 0 + assert int(trail["components_values"]["xor_3_8"]["value"], base=16) >= 0 + assert trail["components_values"]["xor_3_8"]["weight"] == 0 + assert trail["total_weight"] == "9.0" + + trail = mzn.find_one_xor_differential_trail_with_fixed_weight(9, [plaintext], CHUFFED, solve_external=False) + + assert str(trail["cipher"]) == "speck_p32_k64_o32_r5" + assert trail["model_type"] == "xor_differential_one_solution" + assert int(trail["components_values"]["intermediate_output_0_5"]["value"], base=16) >= 0 + assert trail["components_values"]["intermediate_output_0_5"]["weight"] == 0 + assert int(trail["components_values"]["intermediate_output_1_11"]["value"], base=16) >= 0 + assert trail["components_values"]["intermediate_output_1_11"]["weight"] == 0 + assert int(trail["components_values"]["xor_3_8"]["value"], base=16) >= 0 + assert trail["components_values"]["xor_3_8"]["weight"] == 0 + assert trail["total_weight"] == "9.0" diff --git a/tests/unit/cipher_modules/models/cp/mzn_models/mzn_xor_differential_number_of_active_sboxes_model_test.py b/tests/unit/cipher_modules/models/cp/mzn_models/mzn_xor_differential_number_of_active_sboxes_model_test.py index 060beaeb8..f79e47fa5 100644 --- a/tests/unit/cipher_modules/models/cp/mzn_models/mzn_xor_differential_number_of_active_sboxes_model_test.py +++ b/tests/unit/cipher_modules/models/cp/mzn_models/mzn_xor_differential_number_of_active_sboxes_model_test.py @@ -1,16 +1,18 @@ import pytest +from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_differential_number_of_active_sboxes_model import ( + MznXorDifferentialNumberOfActiveSboxesModel, +) +from claasp.cipher_modules.models.utils import set_fixed_variables from claasp.ciphers.block_ciphers.aes_block_cipher import AESBlockCipher -from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_differential_number_of_active_sboxes_model import \ - MznXorDifferentialNumberOfActiveSboxesModel, build_xor_truncated_table -from claasp.cipher_modules.models.utils import set_fixed_variables, integer_to_bit_list +from claasp.name_mappings import INPUT_KEY @pytest.mark.filterwarnings("ignore::DeprecationWarning:") def test_add_additional_xor_constraints(): aes = AESBlockCipher(number_of_rounds=2) mzn = MznXorDifferentialNumberOfActiveSboxesModel(aes) - fixed_variables = [set_fixed_variables('key', 'not_equal', range(128), integer_to_bit_list(0, 128, 'little'))] + fixed_variables = [set_fixed_variables(INPUT_KEY, "not_equal", range(128), (0,) * 128)] mzn.build_xor_differential_trail_first_step_model(-1, fixed_variables) mzn.add_additional_xor_constraints(5, 1) diff --git a/tests/unit/cipher_modules/models/cp/mzn_models/mzn_xor_differential_trail_search_fixing_number_of_active_sboxes_model_test.py b/tests/unit/cipher_modules/models/cp/mzn_models/mzn_xor_differential_trail_search_fixing_number_of_active_sboxes_model_test.py index 2ac6623bb..227abdb1c 100644 --- a/tests/unit/cipher_modules/models/cp/mzn_models/mzn_xor_differential_trail_search_fixing_number_of_active_sboxes_model_test.py +++ b/tests/unit/cipher_modules/models/cp/mzn_models/mzn_xor_differential_trail_search_fixing_number_of_active_sboxes_model_test.py @@ -1,18 +1,20 @@ -import os -import pytest - -from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_differential_trail_search_fixing_number_of_active_sboxes_model \ - import (MznXorDifferentialFixingNumberOfActiveSboxesModel) +from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_differential_trail_search_fixing_number_of_active_sboxes_model import ( + MznXorDifferentialFixingNumberOfActiveSboxesModel, +) +from claasp.cipher_modules.models.cp.solvers import CHUFFED +from claasp.cipher_modules.models.utils import set_fixed_variables from claasp.ciphers.block_ciphers.aes_block_cipher import AESBlockCipher -from claasp.cipher_modules.models.utils import set_fixed_variables, integer_to_bit_list +from claasp.name_mappings import INPUT_KEY, INPUT_PLAINTEXT, XOR_DIFFERENTIAL def test_find_all_xor_differential_trails_with_fixed_weight(): aes = AESBlockCipher(number_of_rounds=2) mzn = MznXorDifferentialFixingNumberOfActiveSboxesModel(aes) - fixed_variables = [set_fixed_variables('key', 'equal', range(128), integer_to_bit_list(0, 128, 'little')), - set_fixed_variables('plaintext', 'not_equal', range(128), integer_to_bit_list(0, 128, 'little'))] - trails = mzn.find_all_xor_differential_trails_with_fixed_weight(30, fixed_variables, 'Chuffed', 'Chuffed') + fixed_variables = [ + set_fixed_variables(INPUT_KEY, "equal", range(128), (0,) * 128), + set_fixed_variables(INPUT_PLAINTEXT, "not_equal", range(128), (0,) * 128), + ] + trails = mzn.find_all_xor_differential_trails_with_fixed_weight(30, fixed_variables, CHUFFED, CHUFFED) assert len(trails) == 255 @@ -20,65 +22,81 @@ def test_find_all_xor_differential_trails_with_fixed_weight(): def test_find_lowest_weight_xor_differential_trail(): aes = AESBlockCipher(number_of_rounds=2) mzn = MznXorDifferentialFixingNumberOfActiveSboxesModel(aes) - fixed_variables = [set_fixed_variables('key', 'equal', range(128), integer_to_bit_list(0, 128, 'little')), - set_fixed_variables('plaintext', 'not_equal', range(128), integer_to_bit_list(0, 128, 'little'))] - solution = mzn.find_lowest_weight_xor_differential_trail(fixed_variables, 'Chuffed', 'Chuffed') - - assert str(solution['cipher']) == 'aes_block_cipher_k128_p128_o128_r2' - assert solution['model_type'] == 'xor_differential' - assert solution['solver_name'] == 'Chuffed' - assert solution['total_weight'] == '30.0' - assert solution['components_values']['key'] == {'value': '0x00000000000000000000000000000000', 'weight': 0} - assert eval(solution['components_values']['plaintext']['value']) > 0 - assert solution['components_values']['plaintext']['weight'] == 0 - assert eval(solution['components_values']['cipher_output_1_32']['value']) >= 0 - assert solution['components_values']['cipher_output_1_32']['weight'] == 0 + fixed_variables = [ + set_fixed_variables(INPUT_KEY, "equal", range(128), (0,) * 128), + set_fixed_variables(INPUT_PLAINTEXT, "not_equal", range(128), (0,) * 128), + ] + solution = mzn.find_lowest_weight_xor_differential_trail(fixed_variables, CHUFFED, CHUFFED) + + assert str(solution["cipher"]) == "aes_block_cipher_k128_p128_o128_r2" + assert solution["model_type"] == XOR_DIFFERENTIAL + assert solution["solver_name"] == CHUFFED + assert solution["total_weight"] == "30.0" + assert solution["components_values"][INPUT_KEY] == {"value": "0x00000000000000000000000000000000", "weight": 0} + assert eval(solution["components_values"][INPUT_PLAINTEXT]["value"]) > 0 + assert solution["components_values"][INPUT_PLAINTEXT]["weight"] == 0 + assert eval(solution["components_values"]["cipher_output_1_32"]["value"]) >= 0 + assert solution["components_values"]["cipher_output_1_32"]["weight"] == 0 def test_find_one_xor_differential_trail(): aes = AESBlockCipher(number_of_rounds=2) mzn = MznXorDifferentialFixingNumberOfActiveSboxesModel(aes) - fixed_variables = [set_fixed_variables('key', 'equal', range(128), integer_to_bit_list(0, 128, 'little')), - set_fixed_variables('plaintext', 'not_equal', range(128), integer_to_bit_list(0, 128, 'little'))] - solution = mzn.find_one_xor_differential_trail(fixed_variables, 'Chuffed', 'Chuffed') + fixed_variables = [ + set_fixed_variables(INPUT_KEY, "equal", range(128), (0,) * 128), + set_fixed_variables(INPUT_PLAINTEXT, "not_equal", range(128), (0,) * 128), + ] + solution = mzn.find_one_xor_differential_trail(fixed_variables, CHUFFED, CHUFFED) - assert str(solution['cipher']) == 'aes_block_cipher_k128_p128_o128_r2' - assert solution['model_type'] == 'xor_differential' - assert solution['solver_name'] == 'Chuffed' - assert eval(solution['total_weight']) >= 0.0 - assert solution['components_values']['key'] == {'value': '0x00000000000000000000000000000000', 'weight': 0} - assert solution['components_values']['plaintext']['weight'] == 0 + assert str(solution["cipher"]) == "aes_block_cipher_k128_p128_o128_r2" + assert solution["model_type"] == XOR_DIFFERENTIAL + assert solution["solver_name"] == CHUFFED + assert eval(solution["total_weight"]) >= 0.0 + assert solution["components_values"][INPUT_KEY] == {"value": "0x00000000000000000000000000000000", "weight": 0} + assert solution["components_values"][INPUT_PLAINTEXT]["weight"] == 0 + + solution = mzn.find_one_xor_differential_trail(fixed_variables, CHUFFED, CHUFFED) + + assert str(solution["cipher"]) == "aes_block_cipher_k128_p128_o128_r2" + assert solution["model_type"] == XOR_DIFFERENTIAL + assert solution["solver_name"] == CHUFFED + assert eval(solution["total_weight"]) >= 0.0 + assert solution["components_values"][INPUT_KEY] == {"value": "0x00000000000000000000000000000000", "weight": 0} + assert solution["components_values"][INPUT_PLAINTEXT]["weight"] == 0 def test_find_one_xor_differential_trail_with_fixed_weight(): aes = AESBlockCipher(number_of_rounds=2) mzn = MznXorDifferentialFixingNumberOfActiveSboxesModel(aes) - fixed_variables = [set_fixed_variables('key', 'equal', range(128), integer_to_bit_list(0, 128, 'little')), - set_fixed_variables('plaintext', 'not_equal', range(128), integer_to_bit_list(0, 128, 'little'))] - solution = mzn.find_one_xor_differential_trail_with_fixed_weight(224, fixed_variables, 'Chuffed', 'Chuffed') + fixed_variables = [ + set_fixed_variables(INPUT_KEY, "equal", range(128), (0,) * 128), + set_fixed_variables(INPUT_PLAINTEXT, "not_equal", range(128), (0,) * 128), + ] + solution = mzn.find_one_xor_differential_trail_with_fixed_weight(224, fixed_variables, CHUFFED, CHUFFED) - assert str(solution['cipher']) == 'aes_block_cipher_k128_p128_o128_r2' - assert solution['model_type'] == 'xor_differential' - assert solution['solver_name'] == 'Chuffed' - assert eval(solution['total_weight']) == 224.0 - assert solution['components_values']['key'] == {'value': '0x00000000000000000000000000000000', 'weight': 0} - assert eval(solution['components_values']['plaintext']['value']) > 0 - assert solution['components_values']['plaintext']['weight'] == 0 - assert solution['components_values']['cipher_output_1_32']['weight'] == 0 + assert str(solution["cipher"]) == "aes_block_cipher_k128_p128_o128_r2" + assert solution["model_type"] == XOR_DIFFERENTIAL + assert solution["solver_name"] == CHUFFED + assert eval(solution["total_weight"]) == 224.0 + assert solution["components_values"][INPUT_KEY] == {"value": "0x00000000000000000000000000000000", "weight": 0} + assert eval(solution["components_values"][INPUT_PLAINTEXT]["value"]) > 0 + assert solution["components_values"][INPUT_PLAINTEXT]["weight"] == 0 + assert solution["components_values"]["cipher_output_1_32"]["weight"] == 0 def test_solve_full_two_steps_xor_differential_model(): aes = AESBlockCipher(number_of_rounds=2) mzn = MznXorDifferentialFixingNumberOfActiveSboxesModel(aes) - fixed_variables = [ - set_fixed_variables('key', 'not_equal', list(range(128)), integer_to_bit_list(0, 128, 'little'))] - constraints = mzn.solve_full_two_steps_xor_differential_model('xor_differential_one_solution', -1, fixed_variables, 'Chuffed', 'Chuffed') - - assert str(constraints['cipher']) == 'aes_block_cipher_k128_p128_o128_r2' - assert eval(constraints['components_values']['intermediate_output_0_35']['value']) >= 0 - assert constraints['components_values']['intermediate_output_0_35']['weight'] == 0 - assert eval(constraints['components_values']['xor_0_36']['value']) >= 0 - assert constraints['components_values']['xor_0_36']['weight'] == 0 - assert eval(constraints['components_values']['intermediate_output_0_37']['value']) >= 0 - assert constraints['components_values']['intermediate_output_0_37']['weight'] == 0 - assert eval(constraints['total_weight']) >= 0 + fixed_variables = [set_fixed_variables(INPUT_KEY, "not_equal", range(128), (0,) * 128)] + constraints = mzn.solve_full_two_steps_xor_differential_model( + "xor_differential_one_solution", -1, fixed_variables, CHUFFED, CHUFFED + ) + + assert str(constraints["cipher"]) == "aes_block_cipher_k128_p128_o128_r2" + assert eval(constraints["components_values"]["intermediate_output_0_35"]["value"]) >= 0 + assert constraints["components_values"]["intermediate_output_0_35"]["weight"] == 0 + assert eval(constraints["components_values"]["xor_0_36"]["value"]) >= 0 + assert constraints["components_values"]["xor_0_36"]["weight"] == 0 + assert eval(constraints["components_values"]["intermediate_output_0_37"]["value"]) >= 0 + assert constraints["components_values"]["intermediate_output_0_37"]["weight"] == 0 + assert eval(constraints["total_weight"]) >= 0 diff --git a/tests/unit/cipher_modules/models/cp/mzn_models/mzn_xor_linear_model_test.py b/tests/unit/cipher_modules/models/cp/mzn_models/mzn_xor_linear_model_test.py index 35ae14330..8c7c1acb1 100644 --- a/tests/unit/cipher_modules/models/cp/mzn_models/mzn_xor_linear_model_test.py +++ b/tests/unit/cipher_modules/models/cp/mzn_models/mzn_xor_linear_model_test.py @@ -1,8 +1,9 @@ +from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_linear_model import MznXorLinearModel +from claasp.cipher_modules.models.cp.solvers import CHUFFED +from claasp.cipher_modules.models.utils import set_fixed_variables, integer_to_bit_list from claasp.ciphers.block_ciphers.simon_block_cipher import SimonBlockCipher from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher from claasp.ciphers.toys.fancy_block_cipher import FancyBlockCipher -from claasp.cipher_modules.models.utils import set_fixed_variables, integer_to_bit_list -from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_linear_model import MznXorLinearModel def test_and_xor_linear_probability_lat(): @@ -17,8 +18,9 @@ def test_final_xor_linear_constraints(): mzn = MznXorLinearModel(speck) mzn.build_xor_linear_trail_model(-1) - assert mzn.final_xor_linear_constraints(-1)[:-1] == \ - ['solve:: int_search(p, smallest, indomain_min, complete) minimize sum(p);'] + assert mzn.final_xor_linear_constraints(-1)[:-1] == [ + "solve:: int_search(p, smallest, indomain_min, complete) minimize sum(p);" + ] def test_find_all_xor_linear_trails_with_fixed_weight(): @@ -28,7 +30,7 @@ def test_find_all_xor_linear_trails_with_fixed_weight(): assert len(trails) == 12 - trails = mzn.find_all_xor_linear_trails_with_fixed_weight(1, solver_name = 'chuffed', solve_external=False) + trails = mzn.find_all_xor_linear_trails_with_fixed_weight(1, solver_name=CHUFFED, solve_external=False) assert len(trails) == 12 @@ -40,7 +42,7 @@ def test_find_all_xor_linear_trails_with_weight_at_most(): assert len(trails) == 13 - trails = mzn.find_all_xor_linear_trails_with_weight_at_most(0, 1, solver_name = 'chuffed', solve_external=False) + trails = mzn.find_all_xor_linear_trails_with_weight_at_most(0, 1, solver_name=CHUFFED, solve_external=False) assert len(trails) == 13 @@ -50,15 +52,15 @@ def test_find_lowest_weight_xor_linear_trail(): mzn = MznXorLinearModel(speck) trail = mzn.find_lowest_weight_xor_linear_trail(solve_external=False) - assert str(trail['cipher']) == 'speck_p32_k64_o32_r4' - assert eval(trail['components_values']['cipher_output_3_12_o']['value']) >= 0 - assert trail['components_values']['cipher_output_3_12_o']['weight'] == 0 - assert trail['total_weight'] == '3.0' + assert str(trail["cipher"]) == "speck_p32_k64_o32_r4" + assert eval(trail["components_values"]["cipher_output_3_12_o"]["value"]) >= 0 + assert trail["components_values"]["cipher_output_3_12_o"]["weight"] == 0 + assert trail["total_weight"] == "3.0" - trail = mzn.find_lowest_weight_xor_linear_trail(solver_name = 'chuffed', solve_external=False) + trail = mzn.find_lowest_weight_xor_linear_trail(solver_name=CHUFFED, solve_external=False) - assert str(trail['cipher']) == 'speck_p32_k64_o32_r4' - assert trail['total_weight'] == '3.0' + assert str(trail["cipher"]) == "speck_p32_k64_o32_r4" + assert trail["total_weight"] == "3.0" def test_find_one_xor_linear_trail(): @@ -66,19 +68,19 @@ def test_find_one_xor_linear_trail(): mzn = MznXorLinearModel(speck) trail = mzn.find_one_xor_linear_trail(solve_external=False) - assert str(trail['cipher']) == 'speck_p32_k64_o32_r4' - assert trail['components_values']['plaintext']['weight'] == 0 - assert eval(trail['components_values']['plaintext']['value']) > 0 - assert trail['components_values']['cipher_output_3_12_o']['weight'] == 0 - assert eval(trail['components_values']['cipher_output_3_12_o']['value']) >= 0 - assert eval(trail['total_weight']) >= 0 + assert str(trail["cipher"]) == "speck_p32_k64_o32_r4" + assert trail["components_values"]["plaintext"]["weight"] == 0 + assert eval(trail["components_values"]["plaintext"]["value"]) > 0 + assert trail["components_values"]["cipher_output_3_12_o"]["weight"] == 0 + assert eval(trail["components_values"]["cipher_output_3_12_o"]["value"]) >= 0 + assert eval(trail["total_weight"]) >= 0 - trail = mzn.find_one_xor_linear_trail(solver_name = 'chuffed', solve_external=False) + trail = mzn.find_one_xor_linear_trail(solver_name=CHUFFED, solve_external=False) - assert str(trail['cipher']) == 'speck_p32_k64_o32_r4' - assert trail['components_values']['plaintext']['weight'] == 0 - assert eval(trail['components_values']['plaintext']['value']) > 0 - assert eval(trail['total_weight']) >= 0 + assert str(trail["cipher"]) == "speck_p32_k64_o32_r4" + assert trail["components_values"]["plaintext"]["weight"] == 0 + assert eval(trail["components_values"]["plaintext"]["value"]) > 0 + assert eval(trail["total_weight"]) >= 0 def test_find_one_xor_linear_trail_with_fixed_weight(): @@ -86,26 +88,26 @@ def test_find_one_xor_linear_trail_with_fixed_weight(): mzn = MznXorLinearModel(speck) trail = mzn.find_one_xor_linear_trail_with_fixed_weight(3, solve_external=False) - assert str(trail['cipher']) == 'speck_p32_k64_o32_r4' - assert trail['model_type'] == 'xor_linear_one_solution' - assert trail['total_weight'] == '3.0' + assert str(trail["cipher"]) == "speck_p32_k64_o32_r4" + assert trail["model_type"] == "xor_linear_one_solution" + assert trail["total_weight"] == "3.0" - trail = mzn.find_one_xor_linear_trail_with_fixed_weight(3, solver_name = 'chuffed', solve_external=False) + trail = mzn.find_one_xor_linear_trail_with_fixed_weight(3, solver_name=CHUFFED, solve_external=False) - assert str(trail['cipher']) == 'speck_p32_k64_o32_r4' - assert trail['model_type'] == 'xor_linear_one_solution' - assert trail['total_weight'] == '3.0' + assert str(trail["cipher"]) == "speck_p32_k64_o32_r4" + assert trail["model_type"] == "xor_linear_one_solution" + assert trail["total_weight"] == "3.0" def test_fix_variables_value_xor_linear_constraints(): speck = SpeckBlockCipher(block_bit_size=32, key_bit_size=64, number_of_rounds=4) mzn = MznXorLinearModel(speck) assert mzn.fix_variables_value_xor_linear_constraints( - [set_fixed_variables('plaintext', 'equal', list(range(4)), integer_to_bit_list(5, 4, 'big'))]) == \ - ['constraint plaintext_o[0] = 0 /\\ plaintext_o[1] = 1 /\\ plaintext_o[2] = 0 /\\ plaintext_o[3] = 1;'] + [set_fixed_variables("plaintext", "equal", list(range(4)), integer_to_bit_list(5, 4, "big"))] + ) == ["constraint plaintext_o[0] = 0 /\\ plaintext_o[1] = 1 /\\ plaintext_o[2] = 0 /\\ plaintext_o[3] = 1;"] assert mzn.fix_variables_value_xor_linear_constraints( - [set_fixed_variables('plaintext', 'not_equal', list(range(4)), integer_to_bit_list(5, 4, 'big'))]) == \ - ['constraint plaintext_o[0] != 0 \\/ plaintext_o[1] != 1 \\/ plaintext_o[2] != 0 \\/ plaintext_o[3] != 1;'] + [set_fixed_variables("plaintext", "not_equal", list(range(4)), integer_to_bit_list(5, 4, "big"))] + ) == ["constraint plaintext_o[0] != 0 \\/ plaintext_o[1] != 1 \\/ plaintext_o[2] != 0 \\/ plaintext_o[3] != 1;"] def test_input_xor_linear_constraints(): @@ -113,20 +115,24 @@ def test_input_xor_linear_constraints(): mzn = MznXorLinearModel(speck) declarations, constraints = mzn.input_xor_linear_constraints() - assert declarations[0] == 'array[0..31] of var 0..1: plaintext_o;' - assert declarations[1] == 'array[0..63] of var 0..1: key_o;' - assert declarations[2] == 'array[0..6] of var {0, 1600, 900, 200, 1100, 400, 1300, 600, 1500, 800, 100, 1000, ' \ - '300, 1200, 500, 1400, 700}: p;' - assert declarations[3] == 'var int: weight = sum(p);' + assert declarations[0] == "array[0..31] of var 0..1: plaintext_o;" + assert declarations[1] == "array[0..63] of var 0..1: key_o;" + assert ( + declarations[2] == "array[0..6] of var " + "{0, 1600, 900, 200, 1100, 400, 1300, 600, 1500, 800, 100, 1000, 300, 1200, 500, 1400, 700}: p;" + ) + assert declarations[3] == "var int: weight = sum(p);" assert constraints == [] fancy = FancyBlockCipher(number_of_rounds=4) mzn = MznXorLinearModel(fancy) declarations, constraints = mzn.input_xor_linear_constraints() - assert declarations[0] == 'array[0..23] of var 0..1: plaintext_o;' - assert declarations[1] == 'array[0..23] of var 0..1: key_o;' - assert declarations[2] == 'array [1..5, 1..4] of int: and2inputs_LAT = array2d(1..5, 1..4, ' \ - '[0,0,0,0,0,0,1,100,0,1,1,100,1,0,1,100,1,1,1,100]);' - assert declarations[3] == 'array[0..127] of var {0, 100, 200, 300, 400, 500, 600}: p;' + assert declarations[0] == "array[0..23] of var 0..1: plaintext_o;" + assert declarations[1] == "array[0..23] of var 0..1: key_o;" + assert ( + declarations[2] + == "array [1..5, 1..4] of int: and2inputs_LAT = array2d(1..5, 1..4, [0,0,0,0,0,0,1,100,0,1,1,100,1,0,1,100,1,1,1,100]);" + ) + assert declarations[3] == "array[0..127] of var {0, 100, 200, 300, 400, 500, 600}: p;" assert constraints == [] diff --git a/tests/unit/cipher_modules/models/milp/milp_model_test.py b/tests/unit/cipher_modules/models/milp/milp_model_test.py index e4548e218..37f157f92 100644 --- a/tests/unit/cipher_modules/models/milp/milp_model_test.py +++ b/tests/unit/cipher_modules/models/milp/milp_model_test.py @@ -1,12 +1,16 @@ import pytest +from claasp.cipher_modules.models.milp.milp_model import ( + get_independent_input_output_variables, + get_input_output_variables, +) from claasp.cipher_modules.models.milp.milp_model import MilpModel from claasp.cipher_modules.models.milp.milp_models.milp_xor_differential_model import MilpXorDifferentialModel from claasp.cipher_modules.models.milp.milp_models.milp_xor_linear_model import MilpXorLinearModel from claasp.ciphers.block_ciphers.simon_block_cipher import SimonBlockCipher from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher -from claasp.cipher_modules.models.milp.milp_model import get_independent_input_output_variables, \ - get_input_output_variables +from claasp.name_mappings import INPUT_PLAINTEXT, XOR_DIFFERENTIAL, XOR_LINEAR +from claasp.cipher_modules.models.utils import set_fixed_variables def test_get_independent_input_output_variables(): @@ -15,16 +19,16 @@ def test_get_independent_input_output_variables(): input_output_variables = get_independent_input_output_variables(component) assert len(input_output_variables[0]) == 32 - assert input_output_variables[0][0] == 'xor_1_10_0_i' - assert input_output_variables[0][1] == 'xor_1_10_1_i' - assert input_output_variables[0][30] == 'xor_1_10_30_i' - assert input_output_variables[0][31] == 'xor_1_10_31_i' + assert input_output_variables[0][0] == "xor_1_10_0_i" + assert input_output_variables[0][1] == "xor_1_10_1_i" + assert input_output_variables[0][30] == "xor_1_10_30_i" + assert input_output_variables[0][31] == "xor_1_10_31_i" assert len(input_output_variables[1]) == 16 - assert input_output_variables[1][0] == 'xor_1_10_0_o' - assert input_output_variables[1][1] == 'xor_1_10_1_o' - assert input_output_variables[1][14] == 'xor_1_10_14_o' - assert input_output_variables[1][15] == 'xor_1_10_15_o' + assert input_output_variables[1][0] == "xor_1_10_0_o" + assert input_output_variables[1][1] == "xor_1_10_1_o" + assert input_output_variables[1][14] == "xor_1_10_14_o" + assert input_output_variables[1][15] == "xor_1_10_15_o" def test_get_input_output_variables(): @@ -33,44 +37,64 @@ def test_get_input_output_variables(): input_output_variables = get_input_output_variables(component) assert len(input_output_variables[0]) == 16 - assert input_output_variables[0][0] == 'plaintext_0' - assert input_output_variables[0][1] == 'plaintext_1' - assert input_output_variables[0][14] == 'plaintext_14' - assert input_output_variables[0][15] == 'plaintext_15' + assert input_output_variables[0][0] == "plaintext_0" + assert input_output_variables[0][1] == "plaintext_1" + assert input_output_variables[0][14] == "plaintext_14" + assert input_output_variables[0][15] == "plaintext_15" assert len(input_output_variables[1]) == 16 - assert input_output_variables[1][0] == 'rot_0_0_0' - assert input_output_variables[1][1] == 'rot_0_0_1' - assert input_output_variables[1][14] == 'rot_0_0_14' - assert input_output_variables[1][15] == 'rot_0_0_15' + assert input_output_variables[1][0] == "rot_0_0_0" + assert input_output_variables[1][1] == "rot_0_0_1" + assert input_output_variables[1][14] == "rot_0_0_14" + assert input_output_variables[1][15] == "rot_0_0_15" def test_fix_variables_value_constraints(): simon = SimonBlockCipher(block_bit_size=32, key_bit_size=64, number_of_rounds=2) milp = MilpModel(simon) milp.init_model_in_sage_milp_class() - fixed_variables = [{'component_id': 'plaintext', - 'constraint_type': 'equal', - 'bit_positions': [0, 1, 2, 3], - 'bit_values': [1, 0, 1, 1] - }, - {'component_id': 'cipher_output_1_8', - 'constraint_type': 'not_equal', - 'bit_positions': [0, 1, 2, 3], - 'bit_values': [1, 1, 1, 0] - }] + fixed_variables = [ + { + "component_id": INPUT_PLAINTEXT, + "constraint_type": "equal", + "bit_positions": [0, 1, 2, 3], + "bit_values": [1, 0, 1, 1], + }, + { + "component_id": "cipher_output_1_8", + "constraint_type": "not_equal", + "bit_positions": [0, 1, 2, 3], + "bit_values": [1, 1, 1, 0], + }, + ] constraints = milp.fix_variables_value_constraints(fixed_variables) assert len(constraints) == 9 - assert str(constraints[0]) == 'x_0 == 1' - assert str(constraints[1]) == 'x_1 == 0' - assert str(constraints[2]) == 'x_2 == 1' - assert str(constraints[3]) == 'x_3 == 1' - assert str(constraints[4]) == 'x_4 == 1 - x_5' - assert str(constraints[5]) == 'x_6 == 1 - x_7' - assert str(constraints[6]) == 'x_8 == 1 - x_9' - assert str(constraints[7]) == 'x_10 == x_11' - assert str(constraints[8]) == '1 <= x_4 + x_6 + x_8 + x_10' + assert str(constraints[0]) == "x_0 == 1" + assert str(constraints[1]) == "x_1 == 0" + assert str(constraints[2]) == "x_2 == 1" + assert str(constraints[3]) == "x_3 == 1" + assert str(constraints[4]) == "x_4 == 1 - x_5" + assert str(constraints[5]) == "x_6 == 1 - x_7" + assert str(constraints[6]) == "x_8 == 1 - x_9" + assert str(constraints[7]) == "x_10 == x_11" + assert str(constraints[8]) == "1 <= x_4 + x_6 + x_8 + x_10" + + speck = SpeckBlockCipher(number_of_rounds=3) + milp = MilpXorDifferentialModel(speck) + fixed_values = [set_fixed_variables('plaintext','equal',range(32),[(speck.get_all_components_ids()[-1],list(range(32)))])] + trail = milp.find_one_xor_differential_trail(fixed_values=fixed_values) + assert trail['components_values']['plaintext']['value'] == trail['components_values'][speck.get_all_components_ids()[-1]]['value'] + + fixed_values = [set_fixed_variables('plaintext','not_equal',range(32),[(speck.get_all_components_ids()[-1],list(range(32)))])] + trail = milp.find_one_xor_differential_trail(fixed_values=fixed_values, solver_name='SCIP_EXT') + assert trail['components_values']['plaintext']['value'] != trail['components_values'][speck.get_all_components_ids()[-1]]['value'] + + fixed_values = [set_fixed_variables('plaintext','equal',range(32),[0]*31+[1])] + fixed_values.append(set_fixed_variables(speck.get_all_components_ids()[-1],'equal',range(32),[0]*31+[1])) + fixed_values.append(set_fixed_variables('plaintext','not_equal',range(32),[(speck.get_all_components_ids()[-1],list(range(32)))])) + trail = milp.find_one_xor_differential_trail(fixed_values=fixed_values) + assert trail['status'] == 'UNSATISFIABLE' def test_model_constraints(): @@ -85,24 +109,23 @@ def test_solve(): milp = MilpXorDifferentialModel(speck) milp.init_model_in_sage_milp_class() milp.add_constraints_to_build_in_sage_milp_class() - differential_solution = milp.solve("xor_differential") + differential_solution = milp.solve(XOR_DIFFERENTIAL) - assert str(differential_solution['cipher']) == 'speck_p32_k64_o32_r4' - assert differential_solution['model_type'] == 'xor_differential' - assert differential_solution['components_values']['key']['weight'] == 0 - assert differential_solution['components_values']['modadd_0_1']['weight'] >= 0 - assert differential_solution['solver_name'] == 'GLPK' - assert differential_solution['total_weight'] >= 0.0 + assert str(differential_solution["cipher"]) == "speck_p32_k64_o32_r4" + assert differential_solution["model_type"] == "xor_differential" + assert differential_solution["components_values"]["key"]["weight"] == 0 + assert differential_solution["components_values"]["modadd_0_1"]["weight"] >= 0 + assert differential_solution["solver_name"] == "GLPK" + assert differential_solution["total_weight"] >= 0.0 milp = MilpXorLinearModel(speck) milp.init_model_in_sage_milp_class() milp.add_constraints_to_build_in_sage_milp_class() linear_solution = milp.solve("xor_linear") - assert str(linear_solution['cipher']) == 'speck_p32_k64_o32_r4' - assert linear_solution['model_type'] == 'xor_linear' - assert differential_solution['components_values']['key']['weight'] == 0 - assert linear_solution['components_values']['modadd_1_7_i']['weight'] >= 0 - assert linear_solution['solver_name'] == 'GLPK' - assert linear_solution['total_weight'] >= 0.0 - + assert str(linear_solution["cipher"]) == "speck_p32_k64_o32_r4" + assert linear_solution["model_type"] == XOR_LINEAR + assert differential_solution["components_values"]["key"]["weight"] == 0 + assert linear_solution["components_values"]["modadd_1_7_i"]["weight"] >= 0 + assert linear_solution["solver_name"] == "GLPK" + assert linear_solution["total_weight"] >= 0.0 diff --git a/tests/unit/cipher_modules/models/milp/milp_models/Gurobi/monomial_prediction_test.py b/tests/unit/cipher_modules/models/milp/milp_models/Gurobi/monomial_prediction_test.py new file mode 100644 index 000000000..23c0e3b5c --- /dev/null +++ b/tests/unit/cipher_modules/models/milp/milp_models/Gurobi/monomial_prediction_test.py @@ -0,0 +1,112 @@ +import pytest + +from claasp.ciphers.block_ciphers.simon_block_cipher import SimonBlockCipher +from claasp.ciphers.stream_ciphers.trivium_stream_cipher import TriviumStreamCipher +from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher +from claasp.ciphers.permutations.gimli_permutation import GimliPermutation +from claasp.ciphers.permutations.ascon_permutation import AsconPermutation +from claasp.ciphers.block_ciphers.aes_block_cipher import AESBlockCipher +from claasp.cipher_modules.models.milp.milp_models.Gurobi.monomial_prediction import * + +""" + +Given a number of rounds of a chosen cipher and a chosen output bit, this module produces a model that can either: +- obtain the ANF of this chosen output bit, +- find the degree of this ANF, +- or check the presence or absence of a specified monomial. + +This module can only be used if the user possesses a Gurobi license. + +""" + +@pytest.mark.skip(reason="Requires Gurobi license") +def test_find_anf_of_specific_output_bit(): + # Return the anf of the chosen output bit + cipher = GimliPermutation(number_of_rounds=1, word_size=4) + milp = MilpMonomialPredictionModel(cipher) + R = milp.get_boolean_polynomial_ring() + poly = milp.find_anf_of_specific_output_bit(0, chosen_cipher_output="xor_0_16") + expected = R("p0 + p1*p33 + p1 + p17 + p33") + assert poly == expected + + cipher = TriviumStreamCipher(keystream_bit_len=1, number_of_initialization_clocks=13) + milp = MilpMonomialPredictionModel(cipher) + R = milp.get_boolean_polynomial_ring() + poly = milp.find_anf_of_specific_output_bit(0) + expected = R("k0 + k27 + i9 + i24") + assert poly == expected + + cipher = AsconPermutation(number_of_rounds=1) + milp = MilpMonomialPredictionModel(cipher) + R = milp.get_boolean_polynomial_ring() + poly = milp.find_anf_of_specific_output_bit(0, chosen_cipher_output="xor_0_15") + expected = R("p0 + p64*p128 + p128 + p256") + assert poly == expected + +@pytest.mark.skip(reason="Requires Gurobi license") +def test_find_upper_bound_degree_of_specific_output_bit(): + # Return an upper bound on the degree of the anf of the chosen output bit + cipher = AESBlockCipher(number_of_rounds=2, word_size=2, state_size=2) + milp = MilpMonomialPredictionModel(cipher) + degree = milp.find_upper_bound_degree_of_specific_output_bit(0, chosen_cipher_output="mix_column_0_7") + assert degree == 2 + +@pytest.mark.skip(reason="Requires Gurobi license") +def test_find_superpoly_of_specific_output_bit(): + cipher = SimonBlockCipher(number_of_rounds=3) + milp = MilpMonomialPredictionModel(cipher) + R = milp.get_boolean_polynomial_ring() + superpoly = milp.find_superpoly_of_specific_output_bit(cube=["p1", "p2"], output_bit_index=0) + expected = R("p3*p10*p11 + p3*p10 + p4*p10 + p5*p10 + p10*p11*p18 + p10*p11*k50 + p10*p18 + p10*p19 + p10*k33 + p10*k50 + p10*k51 + p10 + p25 + k57") + assert superpoly == expected + +@pytest.mark.skip(reason="Requires Gurobi license") +def test_find_exact_degree_of_superpoly_of_all_output_bits(): + cipher = SimonBlockCipher(number_of_rounds=4) + milp = MilpMonomialPredictionModel(cipher) + degrees = milp.find_exact_degree_of_superpoly_of_all_output_bits(["p1", "p2"]) + expected = [-1, -1, -1, -1, 2, 3, 3, 3, 3, -1, -1, -1, -1, 3, 3, 3, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, -1, 0] + assert degrees == expected + +@pytest.mark.skip(reason="Requires Gurobi license") +def test_find_exact_degree_of_all_output_bits(): + cipher = SimonBlockCipher(number_of_rounds=2) + milp = MilpMonomialPredictionModel(cipher) + degrees = milp.find_exact_degree_of_all_output_bits(which_var_degree="p") + expected = [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2] + assert degrees == expected + +@pytest.mark.skip(reason="Requires Gurobi license") +def test_check_anf_correctness(): + cipher = SpeckBlockCipher(number_of_rounds=1) + milp = MilpMonomialPredictionModel(cipher) + check = milp.check_anf_correctness(14) + assert check == True + +@pytest.mark.skip(reason="Requires Gurobi license") +def test_find_upper_bound_degree_of_cube_monomial_of_specific_output_bit(): + cipher = SimonBlockCipher(number_of_rounds=2) + milp = MilpMonomialPredictionModel(cipher) + cube = ["p0", "p2"] + degree = milp.find_upper_bound_degree_of_cube_monomial_of_specific_output_bit(0, cube) + expected = 2 + assert degree == expected + +@pytest.mark.skip(reason="Requires Gurobi license") +def test_find_keycoeff_of_cube_monomial_of_specific_output_bit(): + cipher = SimonBlockCipher(number_of_rounds=2) + milp = MilpMonomialPredictionModel(cipher) + cube = ["p0", "p9"] + keycoeff = milp.find_keycoeff_of_cube_monomial_of_specific_output_bit(0, cube) + R = milp.get_boolean_polynomial_ring() + expected = R("k49") + assert keycoeff == expected + +@pytest.mark.skip(reason="Requires Gurobi license") +def test_check_correctness_of_keycoeff_of_cube_monomial_or_superpoly(): + cipher = SimonBlockCipher(number_of_rounds=2) + milp = MilpMonomialPredictionModel(cipher) + cube = ["p0", "p9"] + keycoeff = milp.find_keycoeff_of_cube_monomial_of_specific_output_bit(0, cube) + res = check_correctness_of_keycoeff_of_cube_monomial_or_superpoly(cipher, 0, cube, keycoeff) + assert res == True diff --git a/tests/unit/cipher_modules/models/milp/milp_models/milp_bitwise_deterministic_truncated_xor_differential_model_test.py b/tests/unit/cipher_modules/models/milp/milp_models/milp_bitwise_deterministic_truncated_xor_differential_model_test.py index 838dca2db..aa5dd0aa4 100644 --- a/tests/unit/cipher_modules/models/milp/milp_models/milp_bitwise_deterministic_truncated_xor_differential_model_test.py +++ b/tests/unit/cipher_modules/models/milp/milp_models/milp_bitwise_deterministic_truncated_xor_differential_model_test.py @@ -1,7 +1,9 @@ +from claasp.cipher_modules.models.milp.milp_models.milp_bitwise_deterministic_truncated_xor_differential_model import ( + MilpBitwiseDeterministicTruncatedXorDifferentialModel, +) from claasp.cipher_modules.models.utils import get_single_key_scenario_format_for_fixed_values, set_fixed_variables from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher -from claasp.cipher_modules.models.milp.milp_models.milp_bitwise_deterministic_truncated_xor_differential_model import \ - MilpBitwiseDeterministicTruncatedXorDifferentialModel +from claasp.name_mappings import INPUT_KEY, INPUT_PLAINTEXT, SATISFIABLE def test_build_bitwise_deterministic_truncated_xor_differential_trail_model(): @@ -12,36 +14,47 @@ def test_build_bitwise_deterministic_truncated_xor_differential_trail_model(): constraints = milp.model_constraints assert len(constraints) == 62624 - assert str(constraints[0]) == 'x_16 == x_9' - assert str(constraints[1]) == 'x_17 == x_10' - assert str(constraints[-2]) == 'x_13273 == x_13225' - assert str(constraints[-1]) == 'x_13274 == x_13226' + assert str(constraints[0]) == "x_16 == x_9" + assert str(constraints[1]) == "x_17 == x_10" + assert str(constraints[-2]) == "x_13273 == x_13225" + assert str(constraints[-1]) == "x_13274 == x_13226" + def test_find_one_bitwise_deterministic_truncated_xor_differential_trail(): speck = SpeckBlockCipher(block_bit_size=32, key_bit_size=64, number_of_rounds=2) milp = MilpBitwiseDeterministicTruncatedXorDifferentialModel(speck) - plaintext = set_fixed_variables(component_id='plaintext', constraint_type='equal', bit_positions=range(32), - bit_values=[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) - key = set_fixed_variables(component_id='key', constraint_type='equal', bit_positions=range(64), bit_values=[0] * 64) + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, + constraint_type="equal", + bit_positions=range(32), + bit_values=(0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), + ) + key = set_fixed_variables( + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 + ) trail = milp.find_one_bitwise_deterministic_truncated_xor_differential_trail(fixed_values=[plaintext, key]) - assert trail['components_values']['intermediate_output_0_6']['value'] == '????100000000000????100000000011' - + assert trail["components_values"]["intermediate_output_0_6"]["value"] == "????100000000000????100000000011" speck = SpeckBlockCipher(block_bit_size=32, key_bit_size=64, number_of_rounds=3) milp = MilpBitwiseDeterministicTruncatedXorDifferentialModel(speck) - plaintext = set_fixed_variables(component_id='plaintext', constraint_type='equal', bit_positions=range(32), - bit_values=[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) - key = set_fixed_variables(component_id='key', constraint_type='equal', bit_positions=range(64), bit_values=[0] * 64) + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, + constraint_type="equal", + bit_positions=range(32), + bit_values=(0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), + ) + key = set_fixed_variables( + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 + ) trail = milp.find_one_bitwise_deterministic_truncated_xor_differential_trail(fixed_values=[plaintext, key]) - assert trail['components_values']['cipher_output_2_12']['value'] == '???????????????0????????????????' + assert trail["components_values"]["cipher_output_2_12"]["value"] == "???????????????0????????????????" def test_find_lowest_varied_patterns_bitwise_deterministic_truncated_xor_differential_trail(): speck = SpeckBlockCipher(block_bit_size=32, key_bit_size=64, number_of_rounds=2) milp = MilpBitwiseDeterministicTruncatedXorDifferentialModel(speck) trail = milp.find_lowest_varied_patterns_bitwise_deterministic_truncated_xor_differential_trail( - get_single_key_scenario_format_for_fixed_values(speck)) - assert trail['status'] == 'SATISFIABLE' - assert str(trail['total_weight']) == '14.0' + get_single_key_scenario_format_for_fixed_values(speck) + ) + assert trail["status"] == SATISFIABLE + assert str(trail["total_weight"]) == "14.0" diff --git a/tests/unit/cipher_modules/models/milp/milp_models/milp_bitwise_impossible_xor_differential_model_test.py b/tests/unit/cipher_modules/models/milp/milp_models/milp_bitwise_impossible_xor_differential_model_test.py index 8d205674b..597bdfa54 100644 --- a/tests/unit/cipher_modules/models/milp/milp_models/milp_bitwise_impossible_xor_differential_model_test.py +++ b/tests/unit/cipher_modules/models/milp/milp_models/milp_bitwise_impossible_xor_differential_model_test.py @@ -1,152 +1,213 @@ +from claasp.cipher_modules.models.milp.milp_models.milp_bitwise_impossible_xor_differential_model import ( + MilpBitwiseImpossibleXorDifferentialModel, +) from claasp.cipher_modules.models.utils import set_fixed_variables from claasp.ciphers.block_ciphers.simon_block_cipher import SimonBlockCipher -from claasp.cipher_modules.models.milp.milp_models.milp_bitwise_impossible_xor_differential_model import \ - MilpBitwiseImpossibleXorDifferentialModel from claasp.ciphers.permutations.ascon_sbox_sigma_permutation import AsconSboxSigmaPermutation +from claasp.name_mappings import INPUT_KEY, INPUT_PLAINTEXT, SATISFIABLE + + +SIMON_INCOMPATIBLE_ROUND_OUTPUT = "????????00?????0???????0????????" + -SIMON_INCOMPATIBLE_ROUND_OUTPUT = '????????00?????0???????0????????' def test_build_bitwise_impossible_xor_differential_trail_model(): simon = SimonBlockCipher(block_bit_size=32, number_of_rounds=2) milp = MilpBitwiseImpossibleXorDifferentialModel(simon) milp.init_model_in_sage_milp_class() milp._forward_cipher = simon.get_partial_cipher(0, 1, keep_key_schedule=True) backward_cipher = milp._cipher.cipher_partial_inverse(1, 1, keep_key_schedule=False) - milp._backward_cipher = backward_cipher.add_suffix_to_components("_backward", [ - backward_cipher.get_all_components_ids()[-1]]) + milp._backward_cipher = backward_cipher.add_suffix_to_components( + "_backward", [backward_cipher.get_all_components_ids()[-1]] + ) milp.build_bitwise_impossible_xor_differential_trail_model() constraints = milp.model_constraints assert len(constraints) == 2400 - assert str(constraints[0]) == 'x_16 == x_0' - assert str(constraints[1]) == 'x_17 == x_1' - assert str(constraints[-2]) == 'x_926 == x_766' - assert str(constraints[-1]) == 'x_927 == x_767' + assert str(constraints[0]) == "x_16 == x_0" + assert str(constraints[1]) == "x_17 == x_1" + assert str(constraints[-2]) == "x_926 == x_766" + assert str(constraints[-1]) == "x_927 == x_767" + def test_find_one_bitwise_impossible_xor_differential_trail_model(): simon = SimonBlockCipher(block_bit_size=32, number_of_rounds=11) milp = MilpBitwiseImpossibleXorDifferentialModel(simon) - plaintext = set_fixed_variables(component_id='plaintext', constraint_type='equal', bit_positions=range(32), bit_values=[0] * 31 + [1]) - key = set_fixed_variables(component_id='key', constraint_type='equal', bit_positions=range(64), bit_values=[0] * 64) - ciphertext = set_fixed_variables(component_id='cipher_output_10_13', constraint_type='equal', bit_positions=range(32), bit_values=[0] * 6 + [2, 0, 2] + [0] * 23) + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, constraint_type="equal", bit_positions=range(32), bit_values=[0] * 31 + [1] + ) + key = set_fixed_variables( + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 + ) + ciphertext = set_fixed_variables( + component_id="cipher_output_10_13", + constraint_type="equal", + bit_positions=range(32), + bit_values=[0] * 6 + [2, 0, 2] + [0] * 23, + ) trail = milp.find_one_bitwise_impossible_xor_differential_trail(6, fixed_values=[plaintext, key, ciphertext]) - assert trail['status'] == 'SATISFIABLE' - assert trail['components_values']['intermediate_output_5_12']['value'] == '????????????????0??????1??????0?' - assert trail['components_values']['intermediate_output_5_12_backward']['value'] == SIMON_INCOMPATIBLE_ROUND_OUTPUT - + assert trail["status"] == SATISFIABLE + assert trail["components_values"]["intermediate_output_5_12"]["value"] == "????????????????0??????1??????0?" + assert trail["components_values"]["intermediate_output_5_12_backward"]["value"] == SIMON_INCOMPATIBLE_ROUND_OUTPUT + + def test_find_one_bitwise_impossible_xor_differential_trail_with_fully_automatic_model(): simon = SimonBlockCipher(block_bit_size=32, number_of_rounds=11) milp = MilpBitwiseImpossibleXorDifferentialModel(simon) - plaintext = set_fixed_variables(component_id='plaintext', constraint_type='equal', bit_positions=range(32), bit_values=[0] * 31 + [1]) - key = set_fixed_variables(component_id='key', constraint_type='equal', bit_positions=range(64), bit_values=[0] * 64) - key_backward = set_fixed_variables(component_id='key_backward', constraint_type='equal', bit_positions=range(64), bit_values=[0] * 64) - ciphertext_backward = set_fixed_variables(component_id='cipher_output_10_13_backward', constraint_type='equal', bit_positions=range(32), bit_values=[0] * 6 + [2, 0, 2] + [0] * 23) - trail = milp.find_one_bitwise_impossible_xor_differential_trail_with_fully_automatic_model(fixed_values=[plaintext, key, key_backward, ciphertext_backward]) - assert trail['status'] == 'SATISFIABLE' - assert trail['components_values']['plaintext']['value'] == '00000000000000000000000000000001' - assert trail['components_values']['intermediate_output_5_12_backward']['value'] == SIMON_INCOMPATIBLE_ROUND_OUTPUT - assert trail['components_values']['cipher_output_10_13_backward']['value'] == '000000?0?00000000000000000000000' + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, constraint_type="equal", bit_positions=range(32), bit_values=[0] * 31 + [1] + ) + key = set_fixed_variables( + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(64), bit_values=[0] * 64 + ) + key_backward = set_fixed_variables( + component_id="key_backward", constraint_type="equal", bit_positions=range(64), bit_values=[0] * 64 + ) + ciphertext_backward = set_fixed_variables( + component_id="cipher_output_10_13_backward", + constraint_type="equal", + bit_positions=range(32), + bit_values=[0] * 6 + [2, 0, 2] + [0] * 23, + ) + trail = milp.find_one_bitwise_impossible_xor_differential_trail_with_fully_automatic_model( + fixed_values=[plaintext, key, key_backward, ciphertext_backward] + ) + assert trail["status"] == SATISFIABLE + assert trail["components_values"][INPUT_PLAINTEXT]["value"] == "00000000000000000000000000000001" + assert trail["components_values"]["intermediate_output_5_12_backward"]["value"] == SIMON_INCOMPATIBLE_ROUND_OUTPUT + assert trail["components_values"]["cipher_output_10_13_backward"]["value"] == "000000?0?00000000000000000000000" trail = milp.find_one_bitwise_impossible_xor_differential_trail_with_fully_automatic_model( - fixed_values=[plaintext, key, key_backward, ciphertext_backward], include_all_components=True) - assert trail['status'] == 'SATISFIABLE' - assert trail['components_values']['plaintext']['value'] == '00000000000000000000000000000001' - assert trail['components_values']['intermediate_output_5_12_backward']['value'] == SIMON_INCOMPATIBLE_ROUND_OUTPUT - assert trail['components_values']['cipher_output_10_13_backward']['value'] == '000000?0?00000000000000000000000' + fixed_values=[plaintext, key, key_backward, ciphertext_backward], include_all_components=True + ) + assert trail["status"] == SATISFIABLE + assert trail["components_values"][INPUT_PLAINTEXT]["value"] == "00000000000000000000000000000001" + assert trail["components_values"]["intermediate_output_5_12_backward"]["value"] == SIMON_INCOMPATIBLE_ROUND_OUTPUT + assert trail["components_values"]["cipher_output_10_13_backward"]["value"] == "000000?0?00000000000000000000000" + +# fmt: off def test_find_one_bitwise_impossible_xor_differential_trail_with_chosen_incompatible_components(): - ascon = AsconSboxSigmaPermutation(number_of_rounds=5) - milp = MilpBitwiseImpossibleXorDifferentialModel(ascon) - milp.init_model_in_sage_milp_class() - plaintext = set_fixed_variables(component_id='plaintext', constraint_type='equal', bit_positions=range(320), - bit_values=[1] + [0] * 191 + [1] + [0] * 63 + [1] + [0] * 63) - P1 = set_fixed_variables(component_id='intermediate_output_0_71', constraint_type='equal', - bit_positions=range(320), - bit_values=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 2, 0, 0, 2, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0]) - P2 = set_fixed_variables(component_id='intermediate_output_1_71', constraint_type='equal', - bit_positions=range(320), - bit_values=[2, 2, 0, 2, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 0, 0, 2, 2, 0, 2, 2, 0, 0, 0, - 0, 2, 0, 0, 2, 2, 0, 0, 0, 0, 2, 0, 2, 0, 2, 2, 0, 0, 0, 0, 0, 2, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 2, 0, 0, 2, 2, 0, 2, 0, 0, 2, 2, - 0, 0, 2, 0, 0, 0, 2, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 2, 0, 0, 2, 2, 0, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, - 2, 0, 2, 0, 0, 2, 2, 0, 2, 2, 2, 2, 0, 0, 2, 2, 0, 0, 2, 2, 2, 0, 0, 0, - 2, 2, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, - 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 0, - 2, 2, 0, 0, 0, 0, 2, 2, 0, 0, 2, 2, 0, 0, 2, 0, 2, 2, 2, 0, 2, 0, 0, 2, - 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 2, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 2, 0, 0, 2, 0, 0, 0, 2, 0, 0, 2, - 0, 0, 2, 0, 0, 0, 0, 0, 2, 2, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 2, 2, 0, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, - 0, 0, 2, 0, 0, 2, 0, 0]) - P3 = set_fixed_variables(component_id='intermediate_output_2_71', constraint_type='equal', - bit_positions=range(320), - bit_values=[2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, - 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, - 2, 2, 0, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 0, 2, 2, - 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, - 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 0, 2, 2, 2, 2, 0, 2, 0, 2, - 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, - 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 2, 2, 2, 2, 2, 2, 2, - 2, 2, 2, 0, 2, 2, 2, 2, 0, 2, 2, 2, 2, 0, 0, 2, 2, 2, 2, 2, 0, 2, 2, 2, - 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, - 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 0, - 2, 2, 2, 2, 0, 2, 0, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, - 2, 0, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, - 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 0, 2, - 2, 2, 2, 2, 2, 2, 2, 2]) - P5 = set_fixed_variables(component_id='cipher_output_4_71', constraint_type='equal', bit_positions=range(320), - bit_values=[0] * 192 + [1] + [0] * 127) - trail = milp.find_one_bitwise_impossible_xor_differential_trail_with_chosen_incompatible_components(["sbox_3_56"], - fixed_values=[ - plaintext, - P1, P2, P3, - P5]) - assert trail['status'] == 'SATISFIABLE' - assert trail['components_values']['sbox_3_56']['value'] == '00000' - assert trail['components_values']['sigma_3_69_backward']['value'] == '1000101000101010101010000000001010001000000010101000001010000000' + ascon = AsconSboxSigmaPermutation(number_of_rounds=5) + milp = MilpBitwiseImpossibleXorDifferentialModel(ascon) + milp.init_model_in_sage_milp_class() + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, + constraint_type="equal", + bit_positions=range(320), + bit_values=[1] + [0] * 191 + [1] + [0] * 63 + [1] + [0] * 63 + ) + P1 = set_fixed_variables( + component_id="intermediate_output_0_71", + constraint_type="equal", + bit_positions=range(320), + bit_values = [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, + 2, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + ) + P2 = set_fixed_variables( + component_id="intermediate_output_1_71", + constraint_type="equal", + bit_positions=range(320), + bit_values = [ + 2, 2, 0, 2, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 0, 0, 2, 2, 0, 2, 2, 0, 0, 0, 0, 2, 0, 0, 2, 2, 0, 0, + 0, 0, 2, 0, 2, 0, 2, 2, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 2, 0, 0, + 2, 2, 0, 2, 0, 0, 2, 2, 0, 0, 2, 0, 0, 0, 2, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 2, 0, 0, 2, 2, 0, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 2, 0, 2, 0, 0, 2, 2, 0, + 2, 2, 2, 2, 0, 0, 2, 2, 0, 0, 2, 2, 2, 0, 0, 0, 2, 2, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 2, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 0, + 2, 2, 0, 0, 0, 0, 2, 2, 0, 0, 2, 2, 0, 0, 2, 0, 2, 2, 2, 0, 2, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 0, + 0, 0, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 2, 0, 0, + 2, 0, 0, 0, 2, 0, 0, 2, 0, 0, 2, 0, 0, 0, 0, 0, 2, 2, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 2, 2, 0, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 2, 0, 0, 2, 0, 0, + ], + ) + P3 = set_fixed_variables( + component_id="intermediate_output_2_71", + constraint_type="equal", + bit_positions=range(320), + bit_values = [ + 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 0, + 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, + 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 0, 2, 2, 2, 2, 0, 2, 0, 2, 2, 2, 2, 2, 0, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, + 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 0, 2, 2, 2, 2, 0, 0, 2, 2, 2, 2, 2, 0, 2, 2, 2, + 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 0, 2, 2, 2, 2, 0, 2, 0, 2, 2, 2, 2, 2, 0, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, + ], + ) + P5 = set_fixed_variables( + component_id="cipher_output_4_71", + constraint_type="equal", + bit_positions=range(320), + bit_values=[0] * 192 + [1] + [0] * 127, + ) + trail = milp.find_one_bitwise_impossible_xor_differential_trail_with_chosen_incompatible_components( + ["sbox_3_56"], fixed_values=[plaintext, P1, P2, P3, P5] + ) + assert trail["status"] == SATISFIABLE + assert trail["components_values"]["sbox_3_56"]["value"] == "00000" + assert trail["components_values"]["sigma_3_69_backward"]["value"] == "1000101000101010101010000000001010001000000010101000001010000000" +# fmt: on def test_find_one_bitwise_impossible_xor_differential_trail_model_with_external_solver(): simon = SimonBlockCipher(block_bit_size=32, number_of_rounds=11) milp = MilpBitwiseImpossibleXorDifferentialModel(simon) - plaintext = set_fixed_variables(component_id='plaintext', constraint_type='equal', bit_positions=range(32), - bit_values=[0] * 31 + [1]) - key = set_fixed_variables(component_id='key', constraint_type='equal', bit_positions=range(64), bit_values=[0] * 64) - ciphertext = set_fixed_variables(component_id='cipher_output_10_13', constraint_type='equal', - bit_positions=range(32), bit_values=[0] * 6 + [2, 0, 2] + [0] * 23) - trail = milp.find_one_bitwise_impossible_xor_differential_trail(6, fixed_values=[plaintext, key, ciphertext], external_solver_name='glpk_ext') - assert trail['status'] == 'SATISFIABLE' - assert trail['components_values']['intermediate_output_5_12']['value'] == '????????????????0??????1??????0?' - assert trail['components_values']['intermediate_output_5_12_backward']['value'] == SIMON_INCOMPATIBLE_ROUND_OUTPUT + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, constraint_type="equal", bit_positions=range(32), bit_values=(0,) * 31 + (1,) + ) + key = set_fixed_variables( + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 + ) + ciphertext = set_fixed_variables( + component_id="cipher_output_10_13", + constraint_type="equal", + bit_positions=range(32), + bit_values=[0] * 6 + [2, 0, 2] + [0] * 23, + ) + trail = milp.find_one_bitwise_impossible_xor_differential_trail( + 6, fixed_values=[plaintext, key, ciphertext], external_solver_name="glpk_ext" + ) + assert trail["status"] == SATISFIABLE + assert trail["components_values"]["intermediate_output_5_12"]["value"] == "????????????????0??????1??????0?" + assert trail["components_values"]["intermediate_output_5_12_backward"]["value"] == SIMON_INCOMPATIBLE_ROUND_OUTPUT def test_find_one_bitwise_impossible_xor_differential_trail_with_fully_automatic_model_with_external_solver(): simon = SimonBlockCipher(block_bit_size=32, number_of_rounds=11) milp = MilpBitwiseImpossibleXorDifferentialModel(simon) - plaintext = set_fixed_variables(component_id='plaintext', constraint_type='equal', bit_positions=range(32), - bit_values=[0] * 31 + [1]) - key = set_fixed_variables(component_id='key', constraint_type='equal', bit_positions=range(64), bit_values=[0] * 64) - key_backward = set_fixed_variables(component_id='key_backward', constraint_type='equal', bit_positions=range(64), - bit_values=[0] * 64) - ciphertext_backward = set_fixed_variables(component_id='cipher_output_10_13_backward', constraint_type='equal', - bit_positions=range(32), bit_values=[0] * 6 + [2, 0, 2] + [0] * 23) + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, constraint_type="equal", bit_positions=range(32), bit_values=(0,) * 31 + (1,) + ) + key = set_fixed_variables( + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 + ) + key_backward = set_fixed_variables( + component_id="key_backward", constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 + ) + ciphertext_backward = set_fixed_variables( + component_id="cipher_output_10_13_backward", + constraint_type="equal", + bit_positions=range(32), + bit_values=[0] * 6 + [2, 0, 2] + [0] * 23, + ) trail = milp.find_one_bitwise_impossible_xor_differential_trail_with_fully_automatic_model( - fixed_values=[plaintext, key, key_backward, ciphertext_backward], external_solver_name='glpk_ext') - assert trail['status'] == 'SATISFIABLE' - assert trail['components_values']['plaintext']['value'] == '00000000000000000000000000000001' - assert trail['components_values']['intermediate_output_5_12_backward']['value'] == SIMON_INCOMPATIBLE_ROUND_OUTPUT - assert trail['components_values']['cipher_output_10_13_backward']['value'] == '000000?0?00000000000000000000000' + fixed_values=[plaintext, key, key_backward, ciphertext_backward], external_solver_name="glpk_ext" + ) + assert trail["status"] == SATISFIABLE + assert trail["components_values"][INPUT_PLAINTEXT]["value"] == "00000000000000000000000000000001" + assert trail["components_values"]["intermediate_output_5_12_backward"]["value"] == SIMON_INCOMPATIBLE_ROUND_OUTPUT + assert trail["components_values"]["cipher_output_10_13_backward"]["value"] == "000000?0?00000000000000000000000" diff --git a/tests/unit/cipher_modules/models/milp/milp_models/milp_cipher_model_test.py b/tests/unit/cipher_modules/models/milp/milp_models/milp_cipher_model_test.py index d31f9e4d4..dc1f2c7a2 100644 --- a/tests/unit/cipher_modules/models/milp/milp_models/milp_cipher_model_test.py +++ b/tests/unit/cipher_modules/models/milp/milp_models/milp_cipher_model_test.py @@ -10,7 +10,7 @@ def test_build_cipher_model(): constraints = milp.model_constraints assert len(constraints) == 9296 - assert str(constraints[0]) == 'x_16 == x_9' - assert str(constraints[1]) == 'x_17 == x_10' - assert str(constraints[9294]) == 'x_4926 == x_4878' - assert str(constraints[9295]) == 'x_4927 == x_4879' + assert str(constraints[0]) == "x_16 == x_9" + assert str(constraints[1]) == "x_17 == x_10" + assert str(constraints[9294]) == "x_4926 == x_4878" + assert str(constraints[9295]) == "x_4927 == x_4879" diff --git a/tests/unit/cipher_modules/models/milp/milp_models/milp_wordwise_deterministic_truncated_xor_differential_model_test.py b/tests/unit/cipher_modules/models/milp/milp_models/milp_wordwise_deterministic_truncated_xor_differential_model_test.py index a6aa8fe2d..72d52d2d7 100644 --- a/tests/unit/cipher_modules/models/milp/milp_models/milp_wordwise_deterministic_truncated_xor_differential_model_test.py +++ b/tests/unit/cipher_modules/models/milp/milp_models/milp_wordwise_deterministic_truncated_xor_differential_model_test.py @@ -1,9 +1,10 @@ +from claasp.cipher_modules.models.milp.milp_models.milp_wordwise_deterministic_truncated_xor_differential_model import ( + MilpWordwiseDeterministicTruncatedXorDifferentialModel, +) from claasp.cipher_modules.models.utils import get_single_key_scenario_format_for_fixed_values, set_fixed_variables from claasp.ciphers.block_ciphers.aes_block_cipher import AESBlockCipher from claasp.ciphers.block_ciphers.midori_block_cipher import MidoriBlockCipher - -from claasp.cipher_modules.models.milp.milp_models.milp_wordwise_deterministic_truncated_xor_differential_model import \ - MilpWordwiseDeterministicTruncatedXorDifferentialModel +from claasp.name_mappings import INPUT_KEY, INPUT_PLAINTEXT, SATISFIABLE def test_build_wordwise_deterministic_truncated_xor_differential_trail_model(): @@ -14,28 +15,38 @@ def test_build_wordwise_deterministic_truncated_xor_differential_trail_model(): constraints = milp.model_constraints assert len(constraints) == 19768 - assert str(constraints[0]) == '1 <= 1 + x_0 - x_1 + x_2 + x_3 + x_4 + x_5 + x_6 + x_7 + x_8 + x_9' - assert str(constraints[1]) == '1 <= 1 + x_1 - x_9' - assert str(constraints[-2]) == 'x_3062 == x_2886' - assert str(constraints[-1]) == 'x_3063 == x_2887' + assert str(constraints[0]) == "1 <= 1 + x_0 - x_1 + x_2 + x_3 + x_4 + x_5 + x_6 + x_7 + x_8 + x_9" + assert str(constraints[1]) == "1 <= 1 + x_1 - x_9" + assert str(constraints[-2]) == "x_3062 == x_2886" + assert str(constraints[-1]) == "x_3063 == x_2887" + def test_find_one_wordwise_deterministic_truncated_xor_differential_trail_model(): aes = AESBlockCipher(number_of_rounds=2) milp = MilpWordwiseDeterministicTruncatedXorDifferentialModel(aes) - plaintext = set_fixed_variables(component_id='plaintext', constraint_type='equal', bit_positions=range(16), - bit_values=[0, 1, 0, 3] + [0] * 12) - key = set_fixed_variables(component_id='key', constraint_type='equal', bit_positions=range(128), - bit_values=[0] * 128) - trail = milp.find_one_wordwise_deterministic_truncated_xor_differential_trail(fixed_bits=[key], - fixed_words=[plaintext]) - assert trail['status'] == 'SATISFIABLE' + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, constraint_type="equal", bit_positions=range(16), bit_values=[0, 1, 0, 3] + [0] * 12 + ) + key = set_fixed_variables( + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(128), bit_values=(0,) * 128 + ) + trail = milp.find_one_wordwise_deterministic_truncated_xor_differential_trail( + fixed_bits=[key], fixed_words=[plaintext] + ) + assert trail["status"] == SATISFIABLE + def test_find_lowest_varied_patterns_wordwise_deterministic_truncated_xor_differential_trail_model(): midori = MidoriBlockCipher(number_of_rounds=2) milp = MilpWordwiseDeterministicTruncatedXorDifferentialModel(midori) - ciphertext = set_fixed_variables(component_id='cipher_output_1_18', constraint_type='equal', bit_positions=range(16), - bit_values=[0, 2, 2, 2] + [0] * 12) + ciphertext = set_fixed_variables( + component_id="cipher_output_1_18", + constraint_type="equal", + bit_positions=range(16), + bit_values=[0, 2, 2, 2] + [0] * 12, + ) trail = milp.find_lowest_varied_patterns_wordwise_deterministic_truncated_xor_differential_trail( - get_single_key_scenario_format_for_fixed_values(midori), [ciphertext]) - assert trail['status'] == 'SATISFIABLE' - assert trail['total_weight'] == 3.0 \ No newline at end of file + get_single_key_scenario_format_for_fixed_values(midori), [ciphertext] + ) + assert trail["status"] == SATISFIABLE + assert trail["total_weight"] == 3.0 diff --git a/tests/unit/cipher_modules/models/milp/milp_models/milp_wordwise_impossible_xor_differential_model_test.py b/tests/unit/cipher_modules/models/milp/milp_models/milp_wordwise_impossible_xor_differential_model_test.py index db07d0062..81bce78d7 100644 --- a/tests/unit/cipher_modules/models/milp/milp_models/milp_wordwise_impossible_xor_differential_model_test.py +++ b/tests/unit/cipher_modules/models/milp/milp_models/milp_wordwise_impossible_xor_differential_model_test.py @@ -1,7 +1,9 @@ +from claasp.cipher_modules.models.milp.milp_models.milp_wordwise_impossible_xor_differential_model import ( + MilpWordwiseImpossibleXorDifferentialModel, +) from claasp.cipher_modules.models.utils import set_fixed_variables from claasp.ciphers.block_ciphers.aes_block_cipher import AESBlockCipher -from claasp.cipher_modules.models.milp.milp_models.milp_wordwise_impossible_xor_differential_model import \ - MilpWordwiseImpossibleXorDifferentialModel +from claasp.name_mappings import INPUT_KEY, INPUT_PLAINTEXT, SATISFIABLE def test_build_wordwise_impossible_xor_differential_trail_model(): @@ -10,113 +12,157 @@ def test_build_wordwise_impossible_xor_differential_trail_model(): milp.init_model_in_sage_milp_class() milp._forward_cipher = aes.get_partial_cipher(0, 1, keep_key_schedule=True) backward_cipher = milp._cipher.cipher_partial_inverse(1, 1, keep_key_schedule=False) - milp._backward_cipher = backward_cipher.add_suffix_to_components("_backward", [backward_cipher.get_all_components_ids()[-1]]) + milp._backward_cipher = backward_cipher.add_suffix_to_components( + "_backward", [backward_cipher.get_all_components_ids()[-1]] + ) milp.build_wordwise_impossible_xor_differential_trail_model() constraints = milp.model_constraints assert len(constraints) == 24200 - assert str(constraints[0]) == '1 <= 1 + x_0 - x_1 + x_2 + x_3 + x_4 + x_5 + x_6 + x_7 + x_8 + x_9' - assert str(constraints[1]) == '1 <= 1 + x_1 - x_9' - assert str(constraints[-2]) == 'x_3238 == x_2065' - assert str(constraints[-1]) == 'x_3239 == x_2066' + assert str(constraints[0]) == "1 <= 1 + x_0 - x_1 + x_2 + x_3 + x_4 + x_5 + x_6 + x_7 + x_8 + x_9" + assert str(constraints[1]) == "1 <= 1 + x_1 - x_9" + assert str(constraints[-2]) == "x_3238 == x_2065" + assert str(constraints[-1]) == "x_3239 == x_2066" + def test_find_one_wordwise_impossible_xor_differential_trail_model(): aes = AESBlockCipher(number_of_rounds=2) milp = MilpWordwiseImpossibleXorDifferentialModel(aes) - plaintext = set_fixed_variables(component_id='plaintext', constraint_type='equal', bit_positions=range(16), - bit_values=[1, 0, 0, 3] + [0]*12) - key = set_fixed_variables(component_id='key', constraint_type='equal', bit_positions=range(128), - bit_values=[0] * 128) - ciphertext = set_fixed_variables(component_id='cipher_output_1_32', constraint_type='equal', bit_positions=range(16), - bit_values=[1] + [0]*15) - trail = milp.find_one_wordwise_impossible_xor_differential_trail(1, fixed_bits=[key], - fixed_words=[plaintext, ciphertext]) - assert trail['status'] == 'SATISFIABLE' - assert trail['components_values']['plaintext']['value'] == '1003000000000000' - assert trail['components_values']['key']['value'] == '0000000000000000' - assert trail['components_values']['cipher_output_1_32']['value'] == '1000000000000000' - assert trail['components_values']['intermediate_output_0_37']['value'] == '2222333300000000' - assert trail['components_values']['intermediate_output_0_37_backward']['value'] == '2000000000000000' + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, + constraint_type="equal", + bit_positions=range(16), + bit_values=(1, 0, 0, 3) + (0,) * 12, + ) + key = set_fixed_variables( + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(128), bit_values=(0,) * 128 + ) + ciphertext = set_fixed_variables( + component_id="cipher_output_1_32", constraint_type="equal", bit_positions=range(16), bit_values=(1,) + (0,) * 15 + ) + trail = milp.find_one_wordwise_impossible_xor_differential_trail( + 1, fixed_bits=[key], fixed_words=[plaintext, ciphertext] + ) + assert trail["status"] == SATISFIABLE + assert trail["components_values"][INPUT_PLAINTEXT]["value"] == "1003000000000000" + assert trail["components_values"][INPUT_KEY]["value"] == "0000000000000000" + assert trail["components_values"]["cipher_output_1_32"]["value"] == "1000000000000000" + assert trail["components_values"]["intermediate_output_0_37"]["value"] == "2222333300000000" + assert trail["components_values"]["intermediate_output_0_37_backward"]["value"] == "2000000000000000" def test_find_one_wordwise_impossible_xor_differential_trail_model_with_fixed_components(): aes = AESBlockCipher(number_of_rounds=2) milp = MilpWordwiseImpossibleXorDifferentialModel(aes) - plaintext = set_fixed_variables(component_id='plaintext', constraint_type='equal', bit_positions=range(16), - bit_values=[1, 0, 0, 3] + [0]*12) - key = set_fixed_variables(component_id='key', constraint_type='equal', bit_positions=range(128), - bit_values=[0] * 128) - ciphertext = set_fixed_variables(component_id='cipher_output_1_32', constraint_type='equal', bit_positions=range(16), - bit_values=[1] + [0]*15) - trail = milp.find_one_wordwise_impossible_xor_differential_trail_with_chosen_components(['mix_column_0_21'], - fixed_bits=[key], - fixed_words=[plaintext, - ciphertext]) - assert trail['status'] == 'SATISFIABLE' - assert trail['components_values']['plaintext']['value'] == '1003000000000000' - assert trail['components_values']['key']['value'] == '0000000000000000' - assert trail['components_values']['cipher_output_1_32']['value'] == '1000000000000000' - assert trail['components_values']['xor_0_36_backward']['value'] == '2000000000000000' - assert trail['components_values']['intermediate_output_0_37_backward']['value'] == '2000000000000000' - + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, + constraint_type="equal", + bit_positions=range(16), + bit_values=(1, 0, 0, 3) + (0,) * 12, + ) + key = set_fixed_variables( + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(128), bit_values=(0,) * 128 + ) + ciphertext = set_fixed_variables( + component_id="cipher_output_1_32", constraint_type="equal", bit_positions=range(16), bit_values=(1,) + (0,) * 15 + ) + trail = milp.find_one_wordwise_impossible_xor_differential_trail_with_chosen_components( + ["mix_column_0_21"], fixed_bits=[key], fixed_words=[plaintext, ciphertext] + ) + assert trail["status"] == SATISFIABLE + assert trail["components_values"][INPUT_PLAINTEXT]["value"] == "1003000000000000" + assert trail["components_values"][INPUT_KEY]["value"] == "0000000000000000" + assert trail["components_values"]["cipher_output_1_32"]["value"] == "1000000000000000" + assert trail["components_values"]["xor_0_36_backward"]["value"] == "2000000000000000" + assert trail["components_values"]["intermediate_output_0_37_backward"]["value"] == "2000000000000000" def test_find_one_wordwise_impossible_xor_differential_trail_with_fully_automatic_model(): aes = AESBlockCipher(number_of_rounds=2) milp = MilpWordwiseImpossibleXorDifferentialModel(aes) - plaintext = set_fixed_variables(component_id='plaintext', constraint_type='equal', bit_positions=range(16), - bit_values=[1, 0, 0, 3] + [0]*12) - key = set_fixed_variables(component_id='key', constraint_type='equal', bit_positions=range(128), - bit_values=[0] * 128) - key_backward = set_fixed_variables(component_id='key_backward', constraint_type='equal', bit_positions=range(128), - bit_values=[0] * 128) - ciphertext_backward = set_fixed_variables(component_id='cipher_output_1_32_backward', constraint_type='equal', bit_positions=range(16), - bit_values=[1] + [0]*15) - trail = milp.find_one_wordwise_impossible_xor_differential_trail_with_fully_automatic_model(fixed_bits=[key, key_backward], - fixed_words=[plaintext, ciphertext_backward]) - assert trail['status'] == 'SATISFIABLE' - assert trail['components_values']['plaintext']['value'] == '1003000000000000' - assert trail['components_values']['key']['value'] == '0000000000000000' - assert trail['components_values']['cipher_output_1_32_backward']['value'] == '1000000000000000' - assert trail['components_values']['intermediate_output_0_37']['value'] == '2222333300000000' - assert trail['components_values']['intermediate_output_0_37_backward']['value'] == '2000000000000000' + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, + constraint_type="equal", + bit_positions=range(16), + bit_values=(1, 0, 0, 3) + (0,) * 12, + ) + key = set_fixed_variables( + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(128), bit_values=(0,) * 128 + ) + key_backward = set_fixed_variables( + component_id="key_backward", constraint_type="equal", bit_positions=range(128), bit_values=(0,) * 128 + ) + ciphertext_backward = set_fixed_variables( + component_id="cipher_output_1_32_backward", + constraint_type="equal", + bit_positions=range(16), + bit_values=(1,) + (0,) * 15, + ) + trail = milp.find_one_wordwise_impossible_xor_differential_trail_with_fully_automatic_model( + fixed_bits=[key, key_backward], fixed_words=[plaintext, ciphertext_backward] + ) + assert trail["status"] == SATISFIABLE + assert trail["components_values"][INPUT_PLAINTEXT]["value"] == "1003000000000000" + assert trail["components_values"][INPUT_KEY]["value"] == "0000000000000000" + assert trail["components_values"]["cipher_output_1_32_backward"]["value"] == "1000000000000000" + assert trail["components_values"]["intermediate_output_0_37"]["value"] == "2222333300000000" + assert trail["components_values"]["intermediate_output_0_37_backward"]["value"] == "2000000000000000" def test_find_one_wordwise_impossible_xor_differential_trail_model_with_external_solver(): aes = AESBlockCipher(number_of_rounds=2) milp = MilpWordwiseImpossibleXorDifferentialModel(aes) - plaintext = set_fixed_variables(component_id='plaintext', constraint_type='equal', bit_positions=range(16), - bit_values=[1, 0, 0, 3] + [0]*12) - key = set_fixed_variables(component_id='key', constraint_type='equal', bit_positions=range(128), - bit_values=[0] * 128) - ciphertext = set_fixed_variables(component_id='cipher_output_1_32', constraint_type='equal', bit_positions=range(16), - bit_values=[1] + [0]*15) - trail = milp.find_one_wordwise_impossible_xor_differential_trail(1, fixed_bits=[key], - fixed_words=[plaintext, ciphertext], external_solver_name='glpk_ext') - assert trail['status'] == 'SATISFIABLE' - assert trail['components_values']['plaintext']['value'] == '1003000000000000' - assert trail['components_values']['key']['value'] == '0000000000000000' - assert trail['components_values']['cipher_output_1_32']['value'] == '1000000000000000' - assert trail['components_values']['intermediate_output_0_37']['value'] == '2222333300000000' - assert trail['components_values']['intermediate_output_0_37_backward']['value'] == '2000000000000000' + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, + constraint_type="equal", + bit_positions=range(16), + bit_values=(1, 0, 0, 3) + (0,) * 12, + ) + key = set_fixed_variables( + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(128), bit_values=(0,) * 128 + ) + ciphertext = set_fixed_variables( + component_id="cipher_output_1_32", constraint_type="equal", bit_positions=range(16), bit_values=(1,) + (0,) * 15 + ) + trail = milp.find_one_wordwise_impossible_xor_differential_trail( + 1, fixed_bits=[key], fixed_words=[plaintext, ciphertext], external_solver_name="glpk_ext" + ) + assert trail["status"] == SATISFIABLE + assert trail["components_values"][INPUT_PLAINTEXT]["value"] == "1003000000000000" + assert trail["components_values"][INPUT_KEY]["value"] == "0000000000000000" + assert trail["components_values"]["cipher_output_1_32"]["value"] == "1000000000000000" + assert trail["components_values"]["intermediate_output_0_37"]["value"] == "2222333300000000" + assert trail["components_values"]["intermediate_output_0_37_backward"]["value"] == "2000000000000000" + def test_find_one_wordwise_impossible_xor_differential_trail_with_fully_automatic_model_with_external_solver(): aes = AESBlockCipher(number_of_rounds=2) milp = MilpWordwiseImpossibleXorDifferentialModel(aes) - plaintext = set_fixed_variables(component_id='plaintext', constraint_type='equal', bit_positions=range(16), - bit_values=[1, 0, 0, 3] + [0]*12) - key = set_fixed_variables(component_id='key', constraint_type='equal', bit_positions=range(128), - bit_values=[0] * 128) - key_backward = set_fixed_variables(component_id='key_backward', constraint_type='equal', bit_positions=range(128), - bit_values=[0] * 128) - ciphertext_backward = set_fixed_variables(component_id='cipher_output_1_32_backward', constraint_type='equal', bit_positions=range(16), - bit_values=[1] + [0]*15) - trail = milp.find_one_wordwise_impossible_xor_differential_trail_with_fully_automatic_model(fixed_bits=[key, key_backward], - fixed_words=[plaintext, ciphertext_backward], external_solver_name='glpk_ext') - assert trail['status'] == 'SATISFIABLE' - assert trail['components_values']['plaintext']['value'] == '1003000000000000' - assert trail['components_values']['key']['value'] == '0000000000000000' - assert trail['components_values']['cipher_output_1_32_backward']['value'] == '1000000000000000' - assert trail['components_values']['intermediate_output_0_37']['value'] == '2222333300000000' - assert trail['components_values']['intermediate_output_0_37_backward']['value'] == '2000000000000000' + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, + constraint_type="equal", + bit_positions=range(16), + bit_values=(1, 0, 0, 3) + (0,) * 12, + ) + key = set_fixed_variables( + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(128), bit_values=(0,) * 128 + ) + key_backward = set_fixed_variables( + component_id="key_backward", constraint_type="equal", bit_positions=range(128), bit_values=(0,) * 128 + ) + ciphertext_backward = set_fixed_variables( + component_id="cipher_output_1_32_backward", + constraint_type="equal", + bit_positions=range(16), + bit_values=(1,) + (0,) * 15, + ) + trail = milp.find_one_wordwise_impossible_xor_differential_trail_with_fully_automatic_model( + fixed_bits=[key, key_backward], fixed_words=[plaintext, ciphertext_backward], external_solver_name="glpk_ext" + ) + assert trail["status"] == SATISFIABLE + assert trail["components_values"][INPUT_PLAINTEXT]["value"] == "1003000000000000" + assert trail["components_values"][INPUT_KEY]["value"] == "0000000000000000" + assert trail["components_values"]["cipher_output_1_32_backward"]["value"] == "1000000000000000" + assert trail["components_values"]["intermediate_output_0_37"]["value"] == "2222333300000000" + assert trail["components_values"]["intermediate_output_0_37_backward"]["value"] == "2000000000000000" diff --git a/tests/unit/cipher_modules/models/milp/milp_models/milp_xor_differential_model_test.py b/tests/unit/cipher_modules/models/milp/milp_models/milp_xor_differential_model_test.py index ca4fdbcfd..2db1e29a5 100644 --- a/tests/unit/cipher_modules/models/milp/milp_models/milp_xor_differential_model_test.py +++ b/tests/unit/cipher_modules/models/milp/milp_models/milp_xor_differential_model_test.py @@ -1,25 +1,26 @@ +from claasp.cipher_modules.models.milp.milp_models.milp_xor_differential_model import MilpXorDifferentialModel from claasp.cipher_modules.models.utils import set_fixed_variables, integer_to_bit_list from claasp.ciphers.block_ciphers.present_block_cipher import PresentBlockCipher from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher -from claasp.cipher_modules.models.milp.milp_models.milp_xor_differential_model import MilpXorDifferentialModel from claasp.ciphers.block_ciphers.tea_block_cipher import TeaBlockCipher +from claasp.name_mappings import INPUT_KEY, INPUT_PLAINTEXT def test_find_all_xor_differential_trails_with_fixed_weight(): speck = SpeckBlockCipher(block_bit_size=8, key_bit_size=16, number_of_rounds=2) milp = MilpXorDifferentialModel(speck) - trail = milp.find_all_xor_differential_trails_with_fixed_weight(1) + trails = milp.find_all_xor_differential_trails_with_fixed_weight(1) - assert len(trail) == 6 - for i in range(len(trail)): - assert str(trail[i]['cipher']) == 'speck_p8_k16_o8_r2' - assert trail[i]['total_weight'] == 1.0 - assert eval(trail[i]['components_values']['plaintext']['value']) > 0 - assert eval(trail[i]['components_values']['key']['value']) == 0 - assert eval(trail[i]['components_values']['modadd_0_1']['value']) >= 0 - assert trail[i]['components_values']['modadd_0_1']['weight'] >= 0.0 - assert eval(trail[i]['components_values']['intermediate_output_0_6']['value']) >= 0 - assert trail[i]['components_values']['intermediate_output_0_6']['weight'] == 0 + assert len(trails) == 6 + for trail in trails: + assert str(trail["cipher"]) == "speck_p8_k16_o8_r2" + assert trail["total_weight"] == 1.0 + assert int(trail["components_values"][INPUT_PLAINTEXT]["value"], base=16) > 0 + assert int(trail["components_values"][INPUT_KEY]["value"], base=16) == 0 + assert int(trail["components_values"]["modadd_0_1"]["value"], base=16) >= 0 + assert trail["components_values"]["modadd_0_1"]["weight"] >= 0.0 + assert int(trail["components_values"]["intermediate_output_0_6"]["value"], base=16) >= 0 + assert trail["components_values"]["intermediate_output_0_6"]["weight"] == 0 def test_find_all_xor_differential_trails_with_weight_at_most(): @@ -27,9 +28,8 @@ def test_find_all_xor_differential_trails_with_weight_at_most(): milp = MilpXorDifferentialModel(speck) trails = milp.find_all_xor_differential_trails_with_weight_at_most(0, 1) assert len(trails) == 7 - for i in range(len(trails)): - assert trails[i]['total_weight'] <= 1.0 - assert trails[i]['total_weight'] >= 0.0 + for trail in trails: + assert 0.0 <= trail["total_weight"] <= 1.0 def test_find_lowest_weight_xor_differential_trail(): @@ -74,22 +74,30 @@ def test_find_one_xor_differential_trail_with_fixed_weight(): tea = TeaBlockCipher(block_bit_size=16, key_bit_size=32, number_of_rounds=2) milp = MilpXorDifferentialModel(tea) - key = set_fixed_variables(component_id='key', constraint_type='equal', - bit_positions=range(32), bit_values=[0] * 32) - round_0_output = set_fixed_variables('intermediate_output_0_15', 'equal', list(range(16)), - integer_to_bit_list(0x0084, 16, 'big')) - cipher_output = set_fixed_variables('cipher_output_1_16', 'equal', list(range(16)), - integer_to_bit_list(0x404a, 16, 'big')) - trail = milp.find_one_xor_differential_trail_with_fixed_weight(15, fixed_values=[key, round_0_output, cipher_output]) + key = set_fixed_variables( + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(32), bit_values=[0] * 32 + ) + round_0_output = set_fixed_variables( + "intermediate_output_0_15", "equal", list(range(16)), integer_to_bit_list(0x0084, 16, "big") + ) + cipher_output = set_fixed_variables( + "cipher_output_1_16", "equal", list(range(16)), integer_to_bit_list(0x404A, 16, "big") + ) + trail = milp.find_one_xor_differential_trail_with_fixed_weight( + 15, fixed_values=[key, round_0_output, cipher_output] + ) assert trail["total_weight"] == 15.0 # speck = SpeckBlockCipher(block_bit_size=32, key_bit_size=64, number_of_rounds=2) milp = MilpXorDifferentialModel(speck) - round_0_output = set_fixed_variables('intermediate_output_0_6', 'equal', list(range(16)), - integer_to_bit_list(0x10001000, 16, 'big')) - cipher_output = set_fixed_variables('cipher_output_1_12', 'equal', list(range(16)), - integer_to_bit_list(0x70203020, 16, 'big')) - key = set_fixed_variables(component_id='key', constraint_type='not_equal', - bit_positions=range(64), bit_values=[0] * 64) + round_0_output = set_fixed_variables( + "intermediate_output_0_6", "equal", list(range(16)), integer_to_bit_list(0x10001000, 16, "big") + ) + cipher_output = set_fixed_variables( + "cipher_output_1_12", "equal", list(range(16)), integer_to_bit_list(0x70203020, 16, "big") + ) + key = set_fixed_variables( + component_id=INPUT_KEY, constraint_type="not_equal", bit_positions=range(64), bit_values=(0,) * 64 + ) trail = milp.find_one_xor_differential_trail_with_fixed_weight(5, fixed_values=[key, round_0_output, cipher_output]) assert trail["total_weight"] == 5.0 diff --git a/tests/unit/cipher_modules/models/milp/milp_models/milp_xor_linear_model_test.py b/tests/unit/cipher_modules/models/milp/milp_models/milp_xor_linear_model_test.py index a37799b74..39e36694b 100644 --- a/tests/unit/cipher_modules/models/milp/milp_models/milp_xor_linear_model_test.py +++ b/tests/unit/cipher_modules/models/milp/milp_models/milp_xor_linear_model_test.py @@ -1,8 +1,11 @@ +import pytest + +from claasp.cipher_modules.models.milp.milp_models.milp_xor_linear_model import MilpXorLinearModel +from claasp.cipher_modules.models.utils import set_fixed_variables, integer_to_bit_list from claasp.ciphers.block_ciphers.simon_block_cipher import SimonBlockCipher from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher -from claasp.cipher_modules.models.utils import set_fixed_variables, integer_to_bit_list -from claasp.cipher_modules.models.milp.milp_models.milp_xor_linear_model import MilpXorLinearModel -import pytest +from claasp.name_mappings import INPUT_PLAINTEXT + def test_build_xor_linear_trail_model(): speck = SpeckBlockCipher(number_of_rounds=22) @@ -10,11 +13,11 @@ def test_build_xor_linear_trail_model(): milp.init_model_in_sage_milp_class() milp.build_xor_linear_trail_model() - assert str(milp.model_constraints[0]) == 'x_0 == x_1' - assert str(milp.model_constraints[1]) == 'x_2 == x_3' - assert str(milp.model_constraints[2]) == 'x_4 == x_5' - assert str(milp.model_constraints[12369]) == 'x_6400 == x_6432' - assert str(milp.model_constraints[12370]) == 'x_6401 == x_6433' + assert str(milp.model_constraints[0]) == "x_0 == x_1" + assert str(milp.model_constraints[1]) == "x_2 == x_3" + assert str(milp.model_constraints[2]) == "x_4 == x_5" + assert str(milp.model_constraints[12369]) == "x_6400 == x_6432" + assert str(milp.model_constraints[12370]) == "x_6401 == x_6433" def test_find_all_xor_linear_trails_with_fixed_weight(): @@ -23,16 +26,16 @@ def test_find_all_xor_linear_trails_with_fixed_weight(): trails = milp.find_all_xor_linear_trails_with_fixed_weight(1) assert len(trails) == 12 - for i in range(len(trails)): - assert str(trails[i]['cipher']) == 'speck_p8_k16_o8_r3' - assert trails[i]['total_weight'] == 1.0 - assert eval(trails[i]['components_values']['plaintext']['value']) > 0 - assert eval(trails[i]['components_values']['key_0_2']['value']) >= 0 - assert trails[i]['components_values']['key_0_2']['weight'] == 0 - assert trails[i]['components_values']['key_0_2']['sign'] == 1 - assert eval(trails[i]['components_values']['rot_0_0_i']['value']) >= 0 - assert trails[i]['components_values']['rot_0_0_i']['weight'] == 0 - assert trails[i]['components_values']['rot_0_0_i']['sign'] == 1 + for trail in trails: + assert str(trail["cipher"]) == "speck_p8_k16_o8_r3" + assert trail["total_weight"] == 1.0 + assert int(trail["components_values"][INPUT_PLAINTEXT]["value"], base=16) > 0 + assert int(trail["components_values"]["key_0_2"]["value"], base=16) >= 0 + assert trail["components_values"]["key_0_2"]["weight"] == 0 + assert trail["components_values"]["key_0_2"]["sign"] == 1 + assert int(trail["components_values"]["rot_0_0_i"]["value"], base=16) >= 0 + assert trail["components_values"]["rot_0_0_i"]["weight"] == 0 + assert trail["components_values"]["rot_0_0_i"]["sign"] == 1 def test_find_all_xor_linear_trails_with_weight_at_most(): @@ -41,17 +44,17 @@ def test_find_all_xor_linear_trails_with_weight_at_most(): trails = milp.find_all_xor_linear_trails_with_weight_at_most(0, 1) assert len(trails) == 13 - for i in range(len(trails)): - assert str(trails[i]['cipher']) == 'speck_p8_k16_o8_r3' - assert trails[i]['total_weight'] <= 1.0 - assert trails[i]['total_weight'] >= 0.0 - assert eval(trails[i]['components_values']['plaintext']['value']) > 0 - assert eval(trails[i]['components_values']['key_0_2']['value']) >= 0 - assert trails[i]['components_values']['key_0_2']['weight'] == 0 - assert trails[i]['components_values']['key_0_2']['sign'] == 1 - assert eval(trails[i]['components_values']['rot_0_0_i']['value']) >= 0 - assert trails[i]['components_values']['rot_0_0_i']['weight'] == 0 - assert trails[i]['components_values']['rot_0_0_i']['sign'] == 1 + for trail in trails: + assert str(trail["cipher"]) == "speck_p8_k16_o8_r3" + assert trail["total_weight"] <= 1.0 + assert trail["total_weight"] >= 0.0 + assert int(trail["components_values"][INPUT_PLAINTEXT]["value"], base=16) > 0 + assert int(trail["components_values"]["key_0_2"]["value"], base=16) >= 0 + assert trail["components_values"]["key_0_2"]["weight"] == 0 + assert trail["components_values"]["key_0_2"]["sign"] == 1 + assert int(trail["components_values"]["rot_0_0_i"]["value"], base=16) >= 0 + assert trail["components_values"]["rot_0_0_i"]["weight"] == 0 + assert trail["components_values"]["rot_0_0_i"]["sign"] == 1 def test_find_lowest_weight_xor_linear_trail(): @@ -85,8 +88,12 @@ def test_find_lowest_weight_xor_linear_trail(): def test_find_one_xor_linear_trail(): speck = SpeckBlockCipher(block_bit_size=32, key_bit_size=64, number_of_rounds=2) milp = MilpXorLinearModel(speck) - plaintext = set_fixed_variables(component_id='plaintext', constraint_type='equal', bit_positions=range(32), - bit_values=integer_to_bit_list(0x03805224, 32, 'big')) + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, + constraint_type="equal", + bit_positions=range(32), + bit_values=integer_to_bit_list(0x03805224, 32, "big"), + ) trail = milp.find_one_xor_linear_trail(fixed_values=[plaintext]) assert trail["total_weight"] >= 3.0 @@ -132,6 +139,7 @@ def test_find_one_xor_linear_trail_with_fixed_weight_with_installed_external_sol milp = MilpXorLinearModel(speck) trail = milp.find_one_xor_linear_trail_with_fixed_weight(1, external_solver_name="Gurobi_ext") + def test_find_one_xor_linear_trail_with_fixed_weight_with_unsupported_external_solver(): with pytest.raises(Exception) as e_info: speck = SpeckBlockCipher(block_bit_size=32, key_bit_size=64, number_of_rounds=2) @@ -143,18 +151,23 @@ def test_fix_variables_value_xor_linear_constraints(): simon = SimonBlockCipher(block_bit_size=32, key_bit_size=64, number_of_rounds=2) milp = MilpXorLinearModel(simon) milp.init_model_in_sage_milp_class() - fixed_variables = [{'component_id': 'plaintext', - 'constraint_type': 'equal', - 'bit_positions': [0, 1, 2, 3], - 'bit_values': [1, 0, 1, 1] - }, {'component_id': 'cipher_output_1_8', - 'constraint_type': 'not_equal', - 'bit_positions': [0, 1, 2, 3], - 'bit_values': [1, 1, 1, 0] - }] + fixed_variables = [ + { + "component_id": INPUT_PLAINTEXT, + "constraint_type": "equal", + "bit_positions": [0, 1, 2, 3], + "bit_values": [1, 0, 1, 1], + }, + { + "component_id": "cipher_output_1_8", + "constraint_type": "not_equal", + "bit_positions": [0, 1, 2, 3], + "bit_values": [1, 1, 1, 0], + }, + ] constraints = milp.fix_variables_value_xor_linear_constraints(fixed_variables) - assert str(constraints[0]) == 'x_0 == 1' - assert str(constraints[1]) == 'x_1 == 0' - assert str(constraints[7]) == 'x_10 == x_11' - assert str(constraints[8]) == '1 <= x_4 + x_6 + x_8 + x_10' + assert str(constraints[0]) == "x_0 == 1" + assert str(constraints[1]) == "x_1 == 0" + assert str(constraints[7]) == "x_10 == x_11" + assert str(constraints[8]) == "1 <= x_4 + x_6 + x_8 + x_10" diff --git a/tests/unit/cipher_modules/models/milp/utils/generate_inequalities_for_wordwise_truncated_mds_matrix_test.py b/tests/unit/cipher_modules/models/milp/utils/generate_inequalities_for_wordwise_truncated_mds_matrix_test.py index 757ef6773..91dc029eb 100644 --- a/tests/unit/cipher_modules/models/milp/utils/generate_inequalities_for_wordwise_truncated_mds_matrix_test.py +++ b/tests/unit/cipher_modules/models/milp/utils/generate_inequalities_for_wordwise_truncated_mds_matrix_test.py @@ -1,10 +1,12 @@ -from claasp.cipher_modules.models.milp.utils.generate_inequalities_for_wordwise_truncated_mds_matrices import \ - generate_valid_points_for_truncated_mds_matrix +from claasp.cipher_modules.models.milp.utils.generate_inequalities_for_wordwise_truncated_mds_matrices import ( + generate_valid_points_for_truncated_mds_matrix, +) + + def test_generate_valid_points_for_truncated_mds_matrix(): valid_points = generate_valid_points_for_truncated_mds_matrix(dimensions=(4, 4), max_pattern_value=3) assert len(valid_points) == 256 - assert valid_points[0] == '0000000000000000' - assert valid_points[1] == '0000000110101010' - assert valid_points[-2] == '1111111011111111' - assert valid_points[-1] == '1111111111111111' - + assert valid_points[0] == "0000000000000000" + assert valid_points[1] == "0000000110101010" + assert valid_points[-2] == "1111111011111111" + assert valid_points[-1] == "1111111111111111" diff --git a/tests/unit/cipher_modules/models/milp/utils/generate_inequalities_for_wordwise_truncated_xor_with_n_input_bits_test.py b/tests/unit/cipher_modules/models/milp/utils/generate_inequalities_for_wordwise_truncated_xor_with_n_input_bits_test.py index c9f29e815..004d66b66 100644 --- a/tests/unit/cipher_modules/models/milp/utils/generate_inequalities_for_wordwise_truncated_xor_with_n_input_bits_test.py +++ b/tests/unit/cipher_modules/models/milp/utils/generate_inequalities_for_wordwise_truncated_xor_with_n_input_bits_test.py @@ -1,26 +1,32 @@ from claasp.cipher_modules.models.milp.utils.generate_inequalities_for_wordwise_truncated_xor_with_n_input_bits import * + + def test_generate_valid_points_input_words(): valid_points = generate_valid_points_input_words() assert len(valid_points) == 18 - assert valid_points[0] == '000000' - assert valid_points[1] == '010001' - assert valid_points[-2] == '100000' - assert valid_points[-1] == '110000' + assert valid_points[0] == "000000" + assert valid_points[1] == "010001" + assert valid_points[-2] == "100000" + assert valid_points[-1] == "110000" + def test_update_dictionary_that_contains_wordwise_truncated_input_inequalities(): update_dictionary_that_contains_wordwise_truncated_input_inequalities(3) dictio = output_dictionary_that_contains_wordwise_truncated_input_inequalities() - assert dictio[3] == ['01000', '-0--1', '1---1', '-0-1-', '1--1-', '-01--', '1-1--'] + assert dictio[3] == ["01000", "-0--1", "1---1", "-0-1-", "1--1-", "-01--", "1-1--"] + + def test_generate_valid_points_for_xor_between_n_input_words(): valid_points = generate_valid_points_for_xor_between_n_input_words() assert len(valid_points) == 324 - assert valid_points[0] == '000000000000000000' - assert valid_points[1] == '000000010001010001' - assert valid_points[-2] == '110000100000110000' - assert valid_points[-1] == '110000110000110000' + assert valid_points[0] == "000000000000000000" + assert valid_points[1] == "000000010001010001" + assert valid_points[-2] == "110000100000110000" + assert valid_points[-1] == "110000110000110000" + def test_update_dictionary_that_contains_wordwise_truncated_xor_inequalities_between_n_inputs(): update_dictionary_that_contains_wordwise_truncated_xor_inequalities_between_n_inputs(3, 3) dictio = output_dictionary_that_contains_wordwise_truncated_xor_inequalities() - assert dictio[3][3][:2] == ['0-000-0---00----1---', '0-00000----0----1---'] - assert dictio[3][3][-2:] == ['-----1---------0----', '1--------------0----'] \ No newline at end of file + assert dictio[3][3][:2] == ["0-000-0---00----1---", "0-00000----0----1---"] + assert dictio[3][3][-2:] == ["-----1---------0----", "1--------------0----"] diff --git a/tests/unit/cipher_modules/models/milp/utils/generate_sbox_inequalities_for_trail_search_test.py b/tests/unit/cipher_modules/models/milp/utils/generate_sbox_inequalities_for_trail_search_test.py index 089de4f89..7accfa6f9 100644 --- a/tests/unit/cipher_modules/models/milp/utils/generate_sbox_inequalities_for_trail_search_test.py +++ b/tests/unit/cipher_modules/models/milp/utils/generate_sbox_inequalities_for_trail_search_test.py @@ -7,4 +7,4 @@ def test_generate_sbox_inequalities_for_trail_search(): SBox_PRESENT = SBox([12, 5, 6, 11, 9, 0, 10, 13, 3, 14, 15, 8, 4, 7, 1, 2]) sbox_ineqs = sbox_inequalities(SBox_PRESENT) - assert str(sbox_ineqs[2][1]) == 'An inequality (0, 0, 0, 1, 1, 0, 1, 0) x - 1 >= 0' + assert str(sbox_ineqs[2][1]) == "An inequality (0, 0, 0, 1, 1, 0, 1, 0) x - 1 >= 0" diff --git a/tests/unit/cipher_modules/models/milp/utils/generate_undisturbed_bits_inequalities_for_sboxes_test.py b/tests/unit/cipher_modules/models/milp/utils/generate_undisturbed_bits_inequalities_for_sboxes_test.py index 757e459ee..d95d72444 100644 --- a/tests/unit/cipher_modules/models/milp/utils/generate_undisturbed_bits_inequalities_for_sboxes_test.py +++ b/tests/unit/cipher_modules/models/milp/utils/generate_undisturbed_bits_inequalities_for_sboxes_test.py @@ -1,6 +1,8 @@ from sage.crypto.sbox import SBox from claasp.ciphers.block_ciphers.present_block_cipher import PresentBlockCipher from claasp.cipher_modules.models.milp.utils.generate_undisturbed_bits_inequalities_for_sboxes import * + + def test_update_dictionary_that_contains_inequalities_for_sboxes_with_undisturbed_bits(): delete_dictionary_that_contains_inequalities_for_sboxes_with_undisturbed_bits() dict = get_dictionary_that_contains_inequalities_for_sboxes_with_undisturbed_bits() @@ -9,12 +11,16 @@ def test_update_dictionary_that_contains_inequalities_for_sboxes_with_undisturbe present = PresentBlockCipher(number_of_rounds=1) sbox_component = present.component_from(0, 1) valid_points = sbox_component.get_ddt_with_undisturbed_transitions() - undisturbed_points = [i for i in valid_points if i[1]!= (2,2,2,2)] + undisturbed_points = [i for i in valid_points if i[1] != (2, 2, 2, 2)] assert len(valid_points) == 81 - assert undisturbed_points == [((0, 0, 0, 0), (0, 0, 0, 0)), ((0, 0, 0, 1), (2, 2, 2, 1)), ((1, 0, 0, 0), (2, 2, 2, 1)), ((1, 0, 0, 1), (2, 2, 2, 0))] + assert undisturbed_points == [ + ((0, 0, 0, 0), (0, 0, 0, 0)), + ((0, 0, 0, 1), (2, 2, 2, 1)), + ((1, 0, 0, 0), (2, 2, 2, 1)), + ((1, 0, 0, 1), (2, 2, 2, 0)), + ] sbox = SBox(sbox_component.description) update_dictionary_that_contains_inequalities_for_sboxes_with_undisturbed_bits(sbox, valid_points) dict = get_dictionary_that_contains_inequalities_for_sboxes_with_undisturbed_bits() - assert dict[str(sbox)][0][1] == ['------11-', '----11---', '--11-----', '11-------', '--------1'] - + assert dict[str(sbox)][0][1] == ["------11-", "----11---", "--11-----", "11-------", "--------1"] diff --git a/tests/unit/cipher_modules/models/sat/cms_models/cms_cipher_model_test.py b/tests/unit/cipher_modules/models/sat/cms_models/cms_cipher_model_test.py index 09f7db403..2effa094f 100644 --- a/tests/unit/cipher_modules/models/sat/cms_models/cms_cipher_model_test.py +++ b/tests/unit/cipher_modules/models/sat/cms_models/cms_cipher_model_test.py @@ -1,8 +1,23 @@ from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher from claasp.cipher_modules.models.sat.cms_models.cms_cipher_model import CmsSatCipherModel +from claasp.cipher_modules.models.sat.solvers import CRYPTOMINISAT +from claasp.cipher_modules.models.utils import set_fixed_variables, integer_to_bit_list +from claasp.name_mappings import INPUT_PLAINTEXT, INPUT_KEY -def test_build_cipher_model(): - speck = SpeckBlockCipher(number_of_rounds=22) +def test_find_missing_bits(): + speck = SpeckBlockCipher(block_bit_size=32, key_bit_size=64, number_of_rounds=22) cms = CmsSatCipherModel(speck) - cms.build_cipher_model() + cipher_output_id = speck.get_all_components_ids()[-1] + plaintext_bits = integer_to_bit_list(0x6574694C, 32, "big") + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, constraint_type="equal", bit_positions=range(32), bit_values=plaintext_bits + ) + key_bits = integer_to_bit_list(0x1918111009080100, 64, "big") + key = set_fixed_variables( + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(64), bit_values=key_bits + ) + + missing_bits = cms.find_missing_bits(fixed_values=[plaintext, key], solver_name=CRYPTOMINISAT) + + assert missing_bits["components_values"][cipher_output_id]["value"] == "0xa86842f2" diff --git a/tests/unit/cipher_modules/models/sat/cms_models/cms_deterministic_truncated_xor_differential_model_test.py b/tests/unit/cipher_modules/models/sat/cms_models/cms_deterministic_truncated_xor_differential_model_test.py index ac8d13f72..a42027202 100644 --- a/tests/unit/cipher_modules/models/sat/cms_models/cms_deterministic_truncated_xor_differential_model_test.py +++ b/tests/unit/cipher_modules/models/sat/cms_models/cms_deterministic_truncated_xor_differential_model_test.py @@ -1,6 +1,7 @@ +from claasp.cipher_modules.models.sat.cms_models.cms_bitwise_deterministic_truncated_xor_differential_model import ( + CmsSatDeterministicTruncatedXorDifferentialModel, +) from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher -from claasp.cipher_modules.models.sat.cms_models.cms_bitwise_deterministic_truncated_xor_differential_model import \ - CmsSatDeterministicTruncatedXorDifferentialModel def test_build_bitwise_deterministic_truncated_xor_differential_trail_model(): diff --git a/tests/unit/cipher_modules/models/sat/sat_model_test.py b/tests/unit/cipher_modules/models/sat/sat_model_test.py index 8d4ffcc1b..ea705452c 100644 --- a/tests/unit/cipher_modules/models/sat/sat_model_test.py +++ b/tests/unit/cipher_modules/models/sat/sat_model_test.py @@ -7,6 +7,7 @@ from claasp.cipher_modules.models.sat.sat_model import SatModel from claasp.cipher_modules.models.sat.sat_models.sat_cipher_model import SatCipherModel from claasp.cipher_modules.models.sat.sat_models.sat_xor_differential_model import SatXorDifferentialModel +from claasp.cipher_modules.models.sat.solvers import CRYPTOMINISAT, CRYPTOMINISAT_EXT, KISSAT_EXT, PARKISSAT_EXT def test_solve(): @@ -14,18 +15,16 @@ def test_solve(): tea = TeaBlockCipher(number_of_rounds=32) sat = SatCipherModel(tea) sat.build_cipher_model() - solution = sat.solve("cipher", solver_name="CRYPTOMINISAT_EXT") + solution = sat.solve("cipher", solver_name=CRYPTOMINISAT_EXT) assert str(solution["cipher"]) == "tea_p64_k128_o64_r32" - assert solution["solver_name"] == "CRYPTOMINISAT_EXT" assert eval(solution["components_values"]["modadd_0_3"]["value"]) >= 0 assert eval(solution["components_values"]["cipher_output_31_16"]["value"]) >= 0 # testing with sage solver simon = SimonBlockCipher(number_of_rounds=32) sat = SatCipherModel(simon) sat.build_cipher_model() - solution = sat.solve("cipher", solver_name="cryptominisat") + solution = sat.solve("cipher", solver_name=CRYPTOMINISAT) assert str(solution["cipher"]) == "simon_p32_k64_o32_r32" - assert solution["solver_name"] == "cryptominisat" assert eval(solution["components_values"]["rot_0_3"]["value"]) >= 0 assert eval(solution["components_values"]["cipher_output_31_13"]["value"]) >= 0 @@ -53,6 +52,22 @@ def test_fix_variables_value_constraints(): "-ciphertext_0 -ciphertext_1 -ciphertext_2 ciphertext_3", ] + speck = SpeckBlockCipher(number_of_rounds=3) + sat = SatXorDifferentialModel(speck) + fixed_values = [set_fixed_variables('plaintext', 'equal', range(32), [(speck.get_all_components_ids()[-1], list(range(32)))])] + trail = sat.find_one_xor_differential_trail(fixed_values=fixed_values) + assert trail['components_values']['plaintext']['value'] == trail['components_values'][speck.get_all_components_ids()[-1]]['value'] + + fixed_values = [set_fixed_variables('plaintext', 'not_equal', range(32), [(speck.get_all_components_ids()[-1], list(range(32)))])] + trail = sat.find_one_xor_differential_trail(fixed_values=fixed_values) + assert trail['components_values']['plaintext']['value'] != trail['components_values'][speck.get_all_components_ids()[-1]]['value'] + + fixed_values = [set_fixed_variables('plaintext', 'equal', range(32), [0]*31+[1])] + fixed_values.append(set_fixed_variables(speck.get_all_components_ids()[-1], 'equal', range(32), [0]*31+[1])) + fixed_values.append(set_fixed_variables('plaintext', 'not_equal', range(32), [(speck.get_all_components_ids()[-1], list(range(32)))])) + trail = sat.find_one_xor_differential_trail(fixed_values=fixed_values) + assert trail['status'] == 'UNSATISFIABLE' + def test_build_xor_differential_sat_model_from_dictionary(): component_model_types = [] @@ -89,7 +104,7 @@ def test_build_xor_differential_sat_model_from_dictionary(): variables, constraints = sat_model.weight_constraints(3) sat_model._variables_list.extend(variables) sat_model._model_constraints.extend(constraints) - result = sat_model._solve_with_external_sat_solver("xor_differential", "PARKISSAT_EXT", ["-c=6"]) + result = sat_model._solve_with_external_sat_solver("xor_differential", PARKISSAT_EXT, ["-c=6"]) assert result["status"] == "SATISFIABLE" @@ -136,10 +151,7 @@ def test_build_generic_sat_model_from_dictionary(): sat_model = SatCipherModel(speck) sat_model.build_generic_sat_model_from_dictionary([plaintext, key], component_model_types) - # variables, constraints = sat_model.weight_constraints(3) - # sat_model._variables_list.extend(variables) - # sat_model._model_constraints.extend(constraints) - result = sat_model._solve_with_external_sat_solver("xor_differential", "KISSAT_EXT", []) + result = sat_model._solve_with_external_sat_solver("xor_differential", KISSAT_EXT, []) assert result["status"] == "SATISFIABLE" diff --git a/tests/unit/cipher_modules/models/sat/sat_models/sat_bitwise_deterministic_truncated_xor_differential_model_test.py b/tests/unit/cipher_modules/models/sat/sat_models/sat_bitwise_deterministic_truncated_xor_differential_model_test.py index f0dbb2529..fd0069fca 100644 --- a/tests/unit/cipher_modules/models/sat/sat_models/sat_bitwise_deterministic_truncated_xor_differential_model_test.py +++ b/tests/unit/cipher_modules/models/sat/sat_models/sat_bitwise_deterministic_truncated_xor_differential_model_test.py @@ -1,7 +1,9 @@ +from claasp.cipher_modules.models.sat.sat_models.sat_bitwise_deterministic_truncated_xor_differential_model import ( + SatBitwiseDeterministicTruncatedXorDifferentialModel, +) from claasp.cipher_modules.models.utils import get_single_key_scenario_format_for_fixed_values, set_fixed_variables from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher -from claasp.cipher_modules.models.sat.sat_models.sat_bitwise_deterministic_truncated_xor_differential_model import \ - SatBitwiseDeterministicTruncatedXorDifferentialModel +from claasp.name_mappings import INPUT_KEY, INPUT_PLAINTEXT, SATISFIABLE def test_build_bitwise_deterministic_truncated_xor_differential_trail_model(): @@ -11,34 +13,46 @@ def test_build_bitwise_deterministic_truncated_xor_differential_trail_model(): constraints = sat.model_constraints assert len(constraints) == 28761 - assert str(constraints[0]) == 'rot_0_0_0_0 -plaintext_9_0' - assert str(constraints[1]) == 'plaintext_9_0 -rot_0_0_0_0' - assert str(constraints[-2]) == 'cipher_output_21_12_31_1 -xor_21_10_15_1' - assert str(constraints[-1]) == 'xor_21_10_15_1 -cipher_output_21_12_31_1' + assert str(constraints[0]) == "rot_0_0_0_0 -plaintext_9_0" + assert str(constraints[1]) == "plaintext_9_0 -rot_0_0_0_0" + assert str(constraints[-2]) == "cipher_output_21_12_31_1 -xor_21_10_15_1" + assert str(constraints[-1]) == "xor_21_10_15_1 -cipher_output_21_12_31_1" def test_find_one_bitwise_deterministic_truncated_xor_differential_trail(): speck = SpeckBlockCipher(block_bit_size=32, key_bit_size=64, number_of_rounds=2) sat = SatBitwiseDeterministicTruncatedXorDifferentialModel(speck) - plaintext = set_fixed_variables(component_id='plaintext', constraint_type='equal', bit_positions=range(32), - bit_values=(0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)) - key = set_fixed_variables(component_id='key', constraint_type='equal', bit_positions=range(64), bit_values=(0,) * 64) + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, + constraint_type="equal", + bit_positions=range(32), + bit_values=(0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), + ) + key = set_fixed_variables( + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 + ) trail = sat.find_one_bitwise_deterministic_truncated_xor_differential_trail(fixed_values=[plaintext, key]) - assert trail['components_values']['intermediate_output_0_6']['value'] == '????100000000000????100000000011' + assert trail["components_values"]["intermediate_output_0_6"]["value"] == "????100000000000????100000000011" speck = SpeckBlockCipher(block_bit_size=32, key_bit_size=64, number_of_rounds=3) sat = SatBitwiseDeterministicTruncatedXorDifferentialModel(speck) - plaintext = set_fixed_variables(component_id='plaintext', constraint_type='equal', bit_positions=range(32), - bit_values=(0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0)) - key = set_fixed_variables(component_id='key', constraint_type='equal', bit_positions=range(64), bit_values=(0,) * 64) + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, + constraint_type="equal", + bit_positions=range(32), + bit_values=(0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), + ) + key = set_fixed_variables( + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 + ) trail = sat.find_one_bitwise_deterministic_truncated_xor_differential_trail(fixed_values=[plaintext, key]) - assert trail['components_values']['cipher_output_2_12']['value'] == '???????????????0????????????????' + assert trail["components_values"]["cipher_output_2_12"]["value"] == "???????????????0????????????????" def test_find_lowest_varied_patterns_bitwise_deterministic_truncated_xor_differential_trail(): speck = SpeckBlockCipher(block_bit_size=32, key_bit_size=64, number_of_rounds=2) sat = SatBitwiseDeterministicTruncatedXorDifferentialModel(speck) - trail = sat.find_lowest_varied_patterns_bitwise_deterministic_truncated_xor_differential_trail(get_single_key_scenario_format_for_fixed_values(speck)) - assert trail['status'] == 'SATISFIABLE' + trail = sat.find_lowest_varied_patterns_bitwise_deterministic_truncated_xor_differential_trail( + get_single_key_scenario_format_for_fixed_values(speck) + ) + assert trail["status"] == SATISFIABLE diff --git a/tests/unit/cipher_modules/models/sat/sat_models/sat_bitwise_impossible_xor_differential_model_test.py b/tests/unit/cipher_modules/models/sat/sat_models/sat_bitwise_impossible_xor_differential_model_test.py new file mode 100644 index 000000000..1e3105151 --- /dev/null +++ b/tests/unit/cipher_modules/models/sat/sat_models/sat_bitwise_impossible_xor_differential_model_test.py @@ -0,0 +1,161 @@ +from claasp.cipher_modules.models.sat.sat_models.sat_bitwise_impossible_xor_differential_model import ( + SatBitwiseImpossibleXorDifferentialModel, +) +from claasp.cipher_modules.models.utils import set_fixed_variables +from claasp.ciphers.block_ciphers.simon_block_cipher import SimonBlockCipher +from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher +from claasp.ciphers.permutations.ascon_sbox_sigma_permutation import AsconSboxSigmaPermutation +from claasp.name_mappings import INPUT_KEY, INPUT_PLAINTEXT + + +SIMON_INCOMPATIBLE_ROUND_OUTPUT = "????????00?????0???????0????????" + + +def test_build_bitwise_deterministic_truncated_xor_differential_trail_model(): + speck = SpeckBlockCipher(number_of_rounds=2) + sat = SatBitwiseImpossibleXorDifferentialModel(speck) + sat._forward_cipher = speck.get_partial_cipher(0, 1, keep_key_schedule=True) + backward_cipher = sat._cipher.cipher_partial_inverse(1, 1, keep_key_schedule=False) + sat._backward_cipher = backward_cipher.add_suffix_to_components( + "_backward", [backward_cipher.get_all_components_ids()[-1]] + ) + sat.build_bitwise_impossible_xor_differential_trail_model() + constraints = sat.model_constraints + + assert len(constraints) == 2764 + assert constraints[0] == "rot_0_0_0_0 -plaintext_9_0" + assert constraints[1] == "plaintext_9_0 -rot_0_0_0_0" + assert constraints[-2] == "intermediate_output_0_6_backward_31_1 -rot_1_9_15_1" + assert constraints[-1] == "rot_1_9_15_1 -intermediate_output_0_6_backward_31_1" + + +def test_find_one_bitwise_impossible_xor_differential_trail_model(): + simon = SimonBlockCipher(block_bit_size=32, number_of_rounds=11) + sat = SatBitwiseImpossibleXorDifferentialModel(simon) + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, constraint_type="equal", bit_positions=range(32), bit_values=(0,) * 31 + (1,) + ) + key = set_fixed_variables( + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 + ) + ciphertext = set_fixed_variables( + component_id="cipher_output_10_13", + constraint_type="equal", + bit_positions=range(32), + bit_values=[0] * 6 + [2, 0, 2] + [0] * 23, + ) + trail = sat.find_one_bitwise_impossible_xor_differential_trail(6, fixed_values=[plaintext, key, ciphertext]) + assert trail["status"] == "SATISFIABLE" + assert trail["components_values"]["intermediate_output_5_12"]["value"] == "????????????????0??????1??????0?" + assert trail["components_values"]["intermediate_output_5_12_backward"]["value"] == SIMON_INCOMPATIBLE_ROUND_OUTPUT + + +# fmt: off +def test_find_one_bitwise_impossible_xor_differential_trail_with_chosen_incompatible_components(): + ascon = AsconSboxSigmaPermutation(number_of_rounds=5) + sat = SatBitwiseImpossibleXorDifferentialModel(ascon) + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, + constraint_type="equal", + bit_positions=range(320), + bit_values=(1,) + (0,) * 191 + (1,) + (0,) * 63 + (1,) + (0,) * 63 + ) + P1 = set_fixed_variables( + component_id="intermediate_output_0_71", + constraint_type="equal", + bit_positions=range(320), + bit_values=( + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, + 2, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ), + ) + P2 = set_fixed_variables( + component_id="intermediate_output_1_71", + constraint_type="equal", + bit_positions=range(320), + bit_values=( + 2, 2, 0, 2, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 0, 0, 2, 2, 0, 2, 2, 0, 0, 0, 0, 2, 0, 0, 2, 2, 0, 0, + 0, 0, 2, 0, 2, 0, 2, 2, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 2, 0, 0, + 2, 2, 0, 2, 0, 0, 2, 2, 0, 0, 2, 0, 0, 0, 2, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 2, 0, 0, 2, 2, 0, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 2, 0, 2, 0, 0, 2, 2, 0, + 2, 2, 2, 2, 0, 0, 2, 2, 0, 0, 2, 2, 2, 0, 0, 0, 2, 2, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 2, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 0, + 2, 2, 0, 0, 0, 0, 2, 2, 0, 0, 2, 2, 0, 0, 2, 0, 2, 2, 2, 0, 2, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 0, + 0, 0, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 2, 0, 0, + 2, 0, 0, 0, 2, 0, 0, 2, 0, 0, 2, 0, 0, 0, 0, 0, 2, 2, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 2, 2, 0, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 2, 0, 0, 2, 0, 0, + ), + ) + P3 = set_fixed_variables( + component_id="intermediate_output_2_71", + constraint_type="equal", + bit_positions=range(320), + bit_values=( + 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 0, + 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, + 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 0, 2, 2, 2, 2, 0, 2, 0, 2, 2, 2, 2, 2, 0, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, + 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 0, 2, 2, 2, 2, 0, 0, 2, 2, 2, 2, 2, 0, 2, 2, 2, + 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 0, 2, 2, 2, 2, 0, 2, 0, 2, 2, 2, 2, 2, 0, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, + ), + ) + P5 = set_fixed_variables( + component_id="cipher_output_4_71", + constraint_type="equal", + bit_positions=range(320), + bit_values=(0,) * 192 + (1,) + (0,) * 127 + ) + trail = sat.find_one_bitwise_impossible_xor_differential_trail_with_chosen_incompatible_components( + ["sbox_3_56"], fixed_values=[plaintext, P1, P2, P3, P5] + ) + assert trail["status"] == "SATISFIABLE" + assert trail["components_values"]["sbox_3_56"]["value"] == "00000" + assert trail["components_values"]["sigma_3_69_backward"]["value"] == "1000101000101010101010000000001010001000000010101000001010000000" +# fmt: on + + +def test_find_one_bitwise_impossible_xor_differential_trail_with_fully_automatic_model(): + simon = SimonBlockCipher(block_bit_size=32, number_of_rounds=11) + sat = SatBitwiseImpossibleXorDifferentialModel(simon) + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, constraint_type="equal", bit_positions=range(32), bit_values=(0,) * 31 + (1,) + ) + key = set_fixed_variables( + component_id="key", constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 + ) + key_backward = set_fixed_variables( + component_id="key_backward", constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 + ) + ciphertext_backward = set_fixed_variables( + component_id="cipher_output_10_13_backward", + constraint_type="equal", + bit_positions=range(32), + bit_values=(0,) * 6 + (2, 0, 2) + (0,) * 23, + ) + trail = sat.find_one_bitwise_impossible_xor_differential_trail_with_fully_automatic_model( + fixed_values=[plaintext, key, key_backward, ciphertext_backward] + ) + assert trail["status"] == "SATISFIABLE" + assert trail["components_values"]["plaintext"]["value"] == "00000000000000000000000000000001" + assert trail["components_values"]["intermediate_output_5_12_backward"]["value"] == SIMON_INCOMPATIBLE_ROUND_OUTPUT + assert trail["components_values"]["cipher_output_10_13_backward"]["value"] == "000000?0?00000000000000000000000" + + trail = sat.find_one_bitwise_impossible_xor_differential_trail_with_fully_automatic_model( + fixed_values=[plaintext, key, key_backward, ciphertext_backward], include_all_components=True + ) + assert trail["status"] == "SATISFIABLE" + assert trail["components_values"]["plaintext"]["value"] == "00000000000000000000000000000001" + assert trail["components_values"]["intermediate_output_5_12_backward"]["value"] == SIMON_INCOMPATIBLE_ROUND_OUTPUT + assert trail["components_values"]["cipher_output_10_13_backward"]["value"] == "000000?0?00000000000000000000000" diff --git a/tests/unit/cipher_modules/models/sat/sat_models/sat_cipher_model_test.py b/tests/unit/cipher_modules/models/sat/sat_models/sat_cipher_model_test.py index 9aad94932..f698a9bcf 100644 --- a/tests/unit/cipher_modules/models/sat/sat_models/sat_cipher_model_test.py +++ b/tests/unit/cipher_modules/models/sat/sat_models/sat_cipher_model_test.py @@ -1,23 +1,22 @@ -from claasp.cipher_modules.models.utils import set_fixed_variables, integer_to_bit_list from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher from claasp.cipher_modules.models.sat.sat_models.sat_cipher_model import SatCipherModel +from claasp.cipher_modules.models.utils import set_fixed_variables, integer_to_bit_list +from claasp.name_mappings import INPUT_PLAINTEXT, INPUT_KEY def test_find_missing_bits(): - speck = SpeckBlockCipher(number_of_rounds=22) - cipher_output_id = speck.get_all_components_ids()[-1] + speck = SpeckBlockCipher(block_bit_size=32, key_bit_size=64, number_of_rounds=22) sat = SatCipherModel(speck) - ciphertext = set_fixed_variables( - component_id=cipher_output_id, - constraint_type="equal", - bit_positions=range(32), - bit_values=integer_to_bit_list(0x1234ABCD, 32, "big"), + cipher_output_id = speck.get_all_components_ids()[-1] + plaintext_bits = integer_to_bit_list(0x6574694C, 32, "big") + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, constraint_type="equal", bit_positions=range(32), bit_values=plaintext_bits + ) + key_bits = integer_to_bit_list(0x1918111009080100, 64, "big") + key = set_fixed_variables( + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(64), bit_values=key_bits ) - missing_bits = sat.find_missing_bits(fixed_values=[ciphertext]) + missing_bits = sat.find_missing_bits(fixed_values=[plaintext, key]) - assert str(missing_bits["cipher"]) == "speck_p32_k64_o32_r22" - assert missing_bits["model_type"] == "cipher" - assert missing_bits["solver_name"] == "CRYPTOMINISAT_EXT" - assert missing_bits["components_values"][cipher_output_id] == {"value": "0x1234abcd"} - assert missing_bits["status"] == "SATISFIABLE" + assert missing_bits["components_values"][cipher_output_id]["value"] == "0xa86842f2" diff --git a/tests/unit/cipher_modules/models/sat/sat_models/sat_differential_linear_test.py b/tests/unit/cipher_modules/models/sat/sat_models/sat_differential_linear_test.py index fb24db8ab..b32eaf72d 100644 --- a/tests/unit/cipher_modules/models/sat/sat_models/sat_differential_linear_test.py +++ b/tests/unit/cipher_modules/models/sat/sat_models/sat_differential_linear_test.py @@ -1,11 +1,18 @@ import itertools +import math from claasp.cipher_modules.models.sat.sat_models.sat_differential_linear_model import SatDifferentialLinearModel -from claasp.cipher_modules.models.utils import set_fixed_variables, integer_to_bit_list, \ - differential_linear_checker_for_permutation, differential_linear_checker_for_block_cipher_single_key +from claasp.cipher_modules.models.sat.solvers import CADICAL_EXT +from claasp.cipher_modules.models.utils import ( + set_fixed_variables, + integer_to_bit_list, + differential_linear_checker_for_permutation, + differential_linear_checker_for_block_cipher_single_key, +) from claasp.ciphers.block_ciphers.aradi_block_cipher_sbox import AradiBlockCipherSBox from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher from claasp.ciphers.permutations.chacha_permutation import ChachaPermutation +from claasp.name_mappings import INPUT_PLAINTEXT, INPUT_KEY, SATISFIABLE def test_differential_linear_trail_with_fixed_weight_6_rounds_speck(): @@ -25,37 +32,34 @@ def test_differential_linear_trail_with_fixed_weight_6_rounds_speck(): bottom_part_components = [component.id for component in bottom_part_components] plaintext = set_fixed_variables( - component_id='plaintext', - constraint_type='equal', + component_id=INPUT_PLAINTEXT, + constraint_type="equal", bit_positions=range(32), - bit_values=integer_to_bit_list(0x05020402, 32, 'big') + bit_values=integer_to_bit_list(0x05020402, 32, "big"), ) key = set_fixed_variables( - component_id='key', - constraint_type='equal', - bit_positions=range(64), - bit_values=(0,) * 64 + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 ) ciphertext_difference = set_fixed_variables( - component_id='cipher_output_5_12', - constraint_type='equal', + component_id="cipher_output_5_12", + constraint_type="equal", bit_positions=range(32), - bit_values=integer_to_bit_list(0x00040004, 32, 'big') + bit_values=integer_to_bit_list(0x00040004, 32, "big"), ) component_model_list = { - 'middle_part_components': middle_part_components, - 'bottom_part_components': bottom_part_components + "middle_part_components": middle_part_components, + "bottom_part_components": bottom_part_components, } sat_heterogeneous_model = SatDifferentialLinearModel(speck, component_model_list) trail = sat_heterogeneous_model.find_one_differential_linear_trail_with_fixed_weight( - weight=10, fixed_values=[key, plaintext, ciphertext_difference], solver_name="CADICAL_EXT", num_unknown_vars=2 + weight=10, fixed_values=[key, plaintext, ciphertext_difference], solver_name=CADICAL_EXT, num_unknown_vars=2 ) - assert trail["status"] == 'SATISFIABLE' + assert trail["status"] == SATISFIABLE def test_lowest_differential_linear_trail_with_fixed_weight_6_rounds_speck(): @@ -75,43 +79,33 @@ def test_lowest_differential_linear_trail_with_fixed_weight_6_rounds_speck(): bottom_part_components = [component.id for component in bottom_part_components] plaintext = set_fixed_variables( - component_id='plaintext', - constraint_type='not_equal', - bit_positions=range(32), - bit_values=integer_to_bit_list(0x0, 32, 'big') + component_id=INPUT_PLAINTEXT, constraint_type="not_equal", bit_positions=range(32), bit_values=(0,) * 32 ) key = set_fixed_variables( - component_id='key', - constraint_type='equal', - bit_positions=range(64), - bit_values=(0,) * 64 + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 ) ciphertext_difference = set_fixed_variables( - component_id='cipher_output_5_12', - constraint_type='not_equal', - bit_positions=range(32), - bit_values=integer_to_bit_list(0x0, 32, 'big') + component_id="cipher_output_5_12", constraint_type="not_equal", bit_positions=range(32), bit_values=(0,) * 32 ) component_model_list = { - 'middle_part_components': middle_part_components, - 'bottom_part_components': bottom_part_components + "middle_part_components": middle_part_components, + "bottom_part_components": bottom_part_components, } sat_heterogeneous_model = SatDifferentialLinearModel(speck, component_model_list) trail = sat_heterogeneous_model.find_lowest_weight_xor_differential_linear_trail( - fixed_values=[key, plaintext, ciphertext_difference], solver_name="CADICAL_EXT", num_unknown_vars=2 + fixed_values=[key, plaintext, ciphertext_difference], solver_name=CADICAL_EXT, num_unknown_vars=2 ) - assert trail["status"] == 'SATISFIABLE' + assert trail["status"] == SATISFIABLE def test_differential_linear_trail_with_fixed_weight_3_rounds_chacha(): """Test for finding a differential-linear trail with fixed weight for 3 rounds of ChaCha permutation.""" chacha = ChachaPermutation(number_of_rounds=6) - import itertools top_part_components = [] middle_part_components = [] bottom_part_components = [] @@ -129,44 +123,51 @@ def test_differential_linear_trail_with_fixed_weight_3_rounds_chacha(): bottom_part_components = [component.id for component in bottom_part_components] plaintext = set_fixed_variables( - component_id='plaintext', - constraint_type='equal', + component_id=INPUT_PLAINTEXT, + constraint_type="equal", bit_positions=range(512), - bit_values=integer_to_bit_list(0x00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000008000000000000000000000000, 512, 'big') + bit_values=integer_to_bit_list( + 0x00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000008000000000000000000000000, + 512, + "big", + ), ) cipher_output_5_24 = set_fixed_variables( - component_id='cipher_output_5_24', - constraint_type='equal', + component_id="cipher_output_5_24", + constraint_type="equal", bit_positions=range(512), - bit_values=integer_to_bit_list(0x00010000000100010000000100030003000000800000008000000000000001800000000000000001000000010000000201000101010000000000010103000101, 512, 'big') + bit_values=integer_to_bit_list( + 0x00010000000100010000000100030003000000800000008000000000000001800000000000000001000000010000000201000101010000000000010103000101, + 512, + "big", + ), ) modadd_3_15 = set_fixed_variables( - component_id=f'modadd_3_15', - constraint_type='not_equal', - bit_positions=range(32), - bit_values=[0] * 32 + component_id=f"modadd_3_15", constraint_type="not_equal", bit_positions=range(32), bit_values=[0] * 32 ) component_model_list = { - 'middle_part_components': middle_part_components, - 'bottom_part_components': bottom_part_components + "middle_part_components": middle_part_components, + "bottom_part_components": bottom_part_components, } sat_heterogeneous_model = SatDifferentialLinearModel(chacha, component_model_list) trail = sat_heterogeneous_model.find_one_differential_linear_trail_with_fixed_weight( - weight=5, fixed_values=[plaintext, modadd_3_15, cipher_output_5_24], solver_name="CADICAL_EXT", num_unknown_vars=511 + weight=5, + fixed_values=[plaintext, modadd_3_15, cipher_output_5_24], + solver_name=CADICAL_EXT, + num_unknown_vars=511, ) - assert trail["status"] == 'SATISFIABLE' + assert trail["status"] == SATISFIABLE assert trail["total_weight"] <= 5 def test_differential_linear_trail_with_fixed_weight_4_rounds_aradi(): """Test for finding a differential-linear trail with fixed weight for 4 rounds of Aradi block cipher.""" aradi = AradiBlockCipherSBox(number_of_rounds=4) - import itertools top_part_components = [] middle_part_components = [] bottom_part_components = [] @@ -183,53 +184,47 @@ def test_differential_linear_trail_with_fixed_weight_4_rounds_aradi(): bottom_part_components = [component.id for component in bottom_part_components] plaintext = set_fixed_variables( - component_id='plaintext', - constraint_type='equal', + component_id=INPUT_PLAINTEXT, + constraint_type="equal", bit_positions=range(128), - bit_values=integer_to_bit_list(0x00000000000080000000000000008000, 128, 'big') + bit_values=integer_to_bit_list(0x00000000000080000000000000008000, 128, "big"), ) cipher_output_3_86 = set_fixed_variables( - component_id='cipher_output_3_86', - constraint_type='equal', + component_id="cipher_output_3_86", + constraint_type="equal", bit_positions=range(128), - bit_values=integer_to_bit_list(0x90900120800000011010002000000000, 128, 'big') + bit_values=integer_to_bit_list(0x90900120800000011010002000000000, 128, "big"), ) key = set_fixed_variables( - component_id='key', - constraint_type='equal', - bit_positions=range(256), - bit_values=[0] * 256 + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(256), bit_values=[0] * 256 ) sbox_4_8 = set_fixed_variables( - component_id=f'sbox_3_8', - constraint_type='not_equal', - bit_positions=range(4), - bit_values=[0] * 4 + component_id="sbox_3_8", constraint_type="not_equal", bit_positions=range(4), bit_values=[0] * 4 ) component_model_list = { - 'middle_part_components': middle_part_components, - 'bottom_part_components': bottom_part_components + "middle_part_components": middle_part_components, + "bottom_part_components": bottom_part_components, } sat_heterogeneous_model = SatDifferentialLinearModel(aradi, component_model_list) trail = sat_heterogeneous_model.find_one_differential_linear_trail_with_fixed_weight( - weight=10, fixed_values=[key, plaintext, sbox_4_8, cipher_output_3_86], solver_name="CADICAL_EXT", num_unknown_vars=128-1 + weight=10, + fixed_values=[key, plaintext, sbox_4_8, cipher_output_3_86], + solver_name=CADICAL_EXT, + num_unknown_vars=128 - 1, ) - assert trail["status"] == 'SATISFIABLE' + assert trail["status"] == SATISFIABLE assert trail["total_weight"] <= 10 def test_differential_linear_trail_with_fixed_weight_4_rounds_chacha(): """Test for finding a differential-linear trail with fixed weight for 4 rounds of ChaCha permutation.""" chacha = ChachaPermutation(number_of_rounds=8) - - import itertools - top_part_components = [] middle_part_components = [] bottom_part_components = [] @@ -247,43 +242,37 @@ def test_differential_linear_trail_with_fixed_weight_4_rounds_chacha(): bottom_part_components = [component.id for component in bottom_part_components] plaintext = set_fixed_variables( - component_id='plaintext', - constraint_type='equal', + component_id=INPUT_PLAINTEXT, + constraint_type="equal", bit_positions=range(512), bit_values=integer_to_bit_list( 0x00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000088088780, 512, - 'big' - ) + "big", + ), ) modadd_4_15 = set_fixed_variables( - component_id=f'modadd_4_15', - constraint_type='not_equal', - bit_positions=range(32), - bit_values=[0] * 32 + component_id="modadd_4_15", constraint_type="not_equal", bit_positions=range(32), bit_values=[0] * 32 ) component_model_list = { - 'middle_part_components': middle_part_components, - 'bottom_part_components': bottom_part_components + "middle_part_components": middle_part_components, + "bottom_part_components": bottom_part_components, } sat_heterogeneous_model = SatDifferentialLinearModel(chacha, component_model_list) trail = sat_heterogeneous_model.find_one_differential_linear_trail_with_fixed_weight( - weight=32, fixed_values=[plaintext, modadd_4_15], solver_name="CADICAL_EXT", num_unknown_vars=511 + weight=32, fixed_values=[plaintext, modadd_4_15], solver_name=CADICAL_EXT, num_unknown_vars=511 ) - assert trail["status"] == 'SATISFIABLE' + assert trail["status"] == SATISFIABLE assert trail["total_weight"] <= 32 def test_differential_linear_trail_with_fixed_weight_4_rounds_chacha_second_case(): """Test for finding a differential-linear trail with fixed weight for 4 rounds of ChaCha permutation.""" chacha = ChachaPermutation(number_of_rounds=8) - - import itertools - top_part_components = [] middle_part_components = [] bottom_part_components = [] @@ -301,36 +290,31 @@ def test_differential_linear_trail_with_fixed_weight_4_rounds_chacha_second_case bottom_part_components = [component.id for component in bottom_part_components] plaintext = set_fixed_variables( - component_id='plaintext', - constraint_type='equal', + component_id=INPUT_PLAINTEXT, + constraint_type="equal", bit_positions=range(512), bit_values=integer_to_bit_list( 0x00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000088088780, 512, - 'big' - ) + "big", + ), ) modadd_4_15 = set_fixed_variables( - component_id=f'modadd_4_15', - constraint_type='not_equal', - bit_positions=range(32), - bit_values=[0] * 32 + component_id=f"modadd_4_15", constraint_type="not_equal", bit_positions=range(32), bit_values=(0,) * 32 ) component_model_list = { - 'middle_part_components': middle_part_components, - 'bottom_part_components': bottom_part_components + "middle_part_components": middle_part_components, + "bottom_part_components": bottom_part_components, } - sat_heterogeneous_model = SatDifferentialLinearModel( - chacha, component_model_list - ) + sat_heterogeneous_model = SatDifferentialLinearModel(chacha, component_model_list) trail = sat_heterogeneous_model.find_one_differential_linear_trail_with_fixed_weight( - weight=32, fixed_values=[plaintext, modadd_4_15], solver_name="CADICAL_EXT", num_unknown_vars=511 + weight=32, fixed_values=[plaintext, modadd_4_15], solver_name=CADICAL_EXT, num_unknown_vars=511 ) - assert trail["status"] == 'SATISFIABLE' + assert trail["status"] == SATISFIABLE assert trail["total_weight"] <= 32 @@ -339,8 +323,6 @@ def test_differential_linear_trail_with_fixed_weight_8_rounds_chacha_one_case(): This test is using in the middle part the semi-deterministic model. """ chacha = ChachaPermutation(number_of_rounds=8) - import itertools - top_part_components = [] middle_part_components = [] bottom_part_components = [] @@ -362,63 +344,63 @@ def test_differential_linear_trail_with_fixed_weight_8_rounds_chacha_one_case(): initial_state_positions[508] = 1 plaintext = set_fixed_variables( - component_id='plaintext', - constraint_type='equal', + component_id=INPUT_PLAINTEXT, + constraint_type="equal", bit_positions=list(range(state_size)), bit_values=integer_to_bit_list( 0x00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000040000000, state_size, - 'big' - ) + "big", + ), ) - intermediate_output_2_24_string = '0000000000000000000000000000000000000000000000000000000000000000001000000010000000000?10000000100000010000000000000000000000000000000000010000000000000000000000?100000000000100000000000000000000000100000000000100000000000100001000100010001000100010001000100000001000000010001000000010000000000000000000000000010000000000000000000000?1000000000001000000000000000100000001000000000001000000000000000000000000000000000000000?100000001000100000001000000000000000000000000001000000000000000000000011000000000001000000' + intermediate_output_2_24_string = "0000000000000000000000000000000000000000000000000000000000000000001000000010000000000?10000000100000010000000000000000000000000000000000010000000000000000000000?100000000000100000000000000000000000100000000000100000000000100001000100010001000100010001000100000001000000010001000000010000000000000000000000000010000000000000000000000?1000000000001000000000000000100000001000000000001000000000000000000000000000000000000000?100000001000100000001000000000000000000000000001000000000000000000000011000000000001000000" intermediate_output_2_24_position_values = [] for intermediate_output_2_24_char in intermediate_output_2_24_string: - if intermediate_output_2_24_char == '?': + if intermediate_output_2_24_char == "?": intermediate_output_2_24_position_values.append(2) else: intermediate_output_2_24_position_values.append(int(intermediate_output_2_24_char)) intermediate_output_2_24 = set_fixed_variables( - component_id='intermediate_output_2_24', - constraint_type='equal', + component_id="intermediate_output_2_24", + constraint_type="equal", bit_positions=list(range(state_size)), - bit_values=intermediate_output_2_24_position_values + bit_values=intermediate_output_2_24_position_values, ) ciphertext = set_fixed_variables( - component_id='cipher_output_7_24', - constraint_type='equal', + component_id="cipher_output_7_24", + constraint_type="equal", bit_positions=list(range(state_size)), bit_values=integer_to_bit_list( 0x00000001000000000000000101010181000080800000000000000000000800800000100000000101000000010000000000000000000000010100000100000101, state_size, - 'big' - ) + "big", + ), ) component_model_list = { - 'middle_part_components': middle_part_components, - 'bottom_part_components': bottom_part_components + "middle_part_components": middle_part_components, + "bottom_part_components": bottom_part_components, } sat_heterogeneous_model = SatDifferentialLinearModel( - chacha, component_model_list, middle_part_model='sat_semi_deterministic_truncated_xor_differential_constraints' + chacha, component_model_list, middle_part_model="sat_semi_deterministic_truncated_xor_differential_constraints" ) trail = sat_heterogeneous_model.find_one_differential_linear_trail_with_fixed_weight( weight=60, fixed_values=[plaintext, ciphertext, intermediate_output_2_24], - solver_name="CADICAL_EXT", + solver_name=CADICAL_EXT, num_unknown_vars=8, unknown_window_size_configuration={ "max_number_of_sequences_window_size_0": 80, "max_number_of_sequences_window_size_1": 25, - "max_number_of_sequences_window_size_2": 190 + "max_number_of_sequences_window_size_2": 190, }, ) - assert trail["status"] == 'SATISFIABLE' + assert trail["status"] == SATISFIABLE assert trail["total_weight"] == 11 @@ -428,8 +410,6 @@ def test_differential_linear_trail_with_fixed_weight_4_rounds_chacha_golden(): """ chacha = ChachaPermutation(number_of_rounds=8) # import ipdb; ipdb.set_trace() - import itertools - top_part_components = [] middle_part_components = [] bottom_part_components = [] @@ -449,37 +429,32 @@ def test_differential_linear_trail_with_fixed_weight_4_rounds_chacha_golden(): state_size = 512 plaintext = set_fixed_variables( - component_id='plaintext', - constraint_type='not_equal', + component_id=INPUT_PLAINTEXT, + constraint_type="not_equal", bit_positions=list(range(state_size)), - bit_values=[0] * state_size + bit_values=(0,) * state_size, ) modadd_3_0 = set_fixed_variables( - component_id='modadd_4_0', - constraint_type='not_equal', - bit_positions=list(range(32)), - bit_values=[0] * 32 + component_id="modadd_4_0", constraint_type="not_equal", bit_positions=list(range(32)), bit_values=(0,) * 32 ) component_model_list = { - 'middle_part_components': middle_part_components, - 'bottom_part_components': bottom_part_components + "middle_part_components": middle_part_components, + "bottom_part_components": bottom_part_components, } sat_heterogeneous_model = SatDifferentialLinearModel( - chacha, - component_model_list, - middle_part_model='sat_semi_deterministic_truncated_xor_differential_constraints' + chacha, component_model_list, middle_part_model="sat_semi_deterministic_truncated_xor_differential_constraints" ) trail = sat_heterogeneous_model.find_one_differential_linear_trail_with_fixed_weight( weight=12, fixed_values=[plaintext, modadd_3_0], - solver_name="CADICAL_EXT", + solver_name=CADICAL_EXT, num_unknown_vars=8, ) - assert trail["status"] == 'SATISFIABLE' + assert trail["status"] == SATISFIABLE assert trail["total_weight"] <= 12 @@ -489,14 +464,13 @@ def test_diff_lin_chacha(): """ input_difference = 0x00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000008000000000000000000000000 output_mask = 0x00010000000100010000000100030003000000800000008000000000000001800000000000000001000000010000000201000101010000000000010103000101 - number_of_samples = 2 ** 12 + number_of_samples = 2**12 number_of_rounds = 6 state_size = 512 chacha = ChachaPermutation(number_of_rounds=number_of_rounds) corr = differential_linear_checker_for_permutation( chacha, input_difference, output_mask, number_of_samples, state_size ) - import math abs_corr = abs(corr) assert abs(math.log(abs_corr, 2)) < 3 @@ -507,14 +481,13 @@ def test_diff_lin_chacha_8(): """ input_difference = 0x00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000040000000 output_mask = 0x00000001000000000000000101010181000080800000000000000000000800800000100000000101000000010000000000000000000000010100000100000101 - number_of_samples = 2 ** 10 + number_of_samples = 2**10 number_of_rounds = 8 state_size = 512 chacha = ChachaPermutation(number_of_rounds=number_of_rounds) corr = differential_linear_checker_for_permutation( chacha, input_difference, output_mask, number_of_samples, state_size ) - import math abs_corr = abs(corr) assert abs(math.log(abs_corr, 2)) < 8 @@ -523,9 +496,9 @@ def test_diff_lin_speck(): """ This test is verifying experimentally the test test_differential_linear_trail_with_fixed_weight_6_rounds_speck """ - input_difference = 0x02110a04 + input_difference = 0x02110A04 output_mask = 0x02000201 - number_of_samples = 2 ** 15 + number_of_samples = 2**15 number_of_rounds = 6 fixed_key = 0x0 speck = SpeckBlockCipher(number_of_rounds=number_of_rounds) @@ -534,7 +507,6 @@ def test_diff_lin_speck(): corr = differential_linear_checker_for_block_cipher_single_key( speck, input_difference, output_mask, number_of_samples, block_size, key_size, fixed_key, seed=42 ) - import math abs_corr = abs(corr) assert abs(math.log(abs_corr, 2)) <= 8 @@ -545,7 +517,7 @@ def test_diff_lin_aradi(): """ input_difference = 0x00000000000080000000000000008000 output_mask = 0x90900120800000011010002000000000 - number_of_samples = 2 ** 12 + number_of_samples = 2**12 number_of_rounds = 4 fixed_key = 0x90900120800000011010002000000000 speck = AradiBlockCipherSBox(number_of_rounds=number_of_rounds) @@ -554,6 +526,5 @@ def test_diff_lin_aradi(): corr = differential_linear_checker_for_block_cipher_single_key( speck, input_difference, output_mask, number_of_samples, block_size, key_size, fixed_key, seed=42 ) - import math abs_corr = abs(corr) assert abs(math.log(abs_corr, 2)) < 8 diff --git a/tests/unit/cipher_modules/models/sat/sat_models/sat_probabilistic_xor_truncated_differential_test.py b/tests/unit/cipher_modules/models/sat/sat_models/sat_probabilistic_xor_truncated_differential_test.py index c326eba32..e31274da2 100644 --- a/tests/unit/cipher_modules/models/sat/sat_models/sat_probabilistic_xor_truncated_differential_test.py +++ b/tests/unit/cipher_modules/models/sat/sat_models/sat_probabilistic_xor_truncated_differential_test.py @@ -1,77 +1,24 @@ +import itertools +import pytest + +from claasp.cipher_modules.models.sat import solvers from claasp.cipher_modules.models.sat.sat_models.sat_probabilistic_xor_truncated_differential_model import ( - SatProbabilisticXorTruncatedDifferentialModel + SatProbabilisticXorTruncatedDifferentialModel, +) +from claasp.cipher_modules.models.sat.utils.utils import ( + _generate_component_model_types, + _update_component_model_types_for_truncated_components, +) +from claasp.cipher_modules.models.utils import ( + set_fixed_variables, + integer_to_bit_list, + differential_truncated_checker_single_key, + differential_truncated_checker_permutation, ) - -from claasp.cipher_modules.models.sat.utils.utils import _generate_component_model_types, \ - _update_component_model_types_for_truncated_components -from claasp.cipher_modules.models.utils import set_fixed_variables, integer_to_bit_list, \ - differential_truncated_checker_single_key, differential_truncated_checker_permutation from claasp.ciphers.block_ciphers.aradi_block_cipher_sbox import AradiBlockCipherSBox from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher from claasp.ciphers.permutations.chacha_permutation import ChachaPermutation - -WORD_SIZE = 16 -MASK_VAL = 2 ** WORD_SIZE - 1 -ALPHA = 7 -BETA = 2 - - -def speck_rol(value, shift): - """Performs a left rotation on a 16-bit word.""" - return ((value << shift) & MASK_VAL) | (value >> (WORD_SIZE - shift)) - - -def speck_ror(value, shift): - """Performs a right rotation on a 16-bit word.""" - return (value >> shift) | ((value << (WORD_SIZE - shift)) & MASK_VAL) - - -def speck_encrypt_round(plaintext, subkey): - """Performs one round of encryption for Speck32/64.""" - left_part, right_part = plaintext - left_part = speck_ror(left_part, ALPHA) - left_part = (left_part + right_part) & MASK_VAL - left_part ^= subkey - right_part = speck_rol(right_part, BETA) - right_part ^= left_part - return left_part, right_part - - -def speck_decrypt_round(ciphertext, subkey): - """Performs one round of decryption for Speck32/64.""" - left_part, right_part = ciphertext - right_part ^= left_part - right_part = speck_ror(right_part, BETA) - left_part ^= subkey - left_part = (left_part - right_part) & MASK_VAL - left_part = speck_rol(left_part, ALPHA) - return left_part, right_part - - -def speck_encrypt(plaintext, subkeys): - """Encrypts the given plaintext using the provided subkeys.""" - left_part, right_part = plaintext - for subkey in subkeys: - left_part, right_part = speck_encrypt_round((left_part, right_part), subkey) - return left_part, right_part - - -def speck_decrypt(ciphertext, subkeys): - """Decrypts the given ciphertext using the provided key schedule.""" - left_part, right_part = ciphertext - for subkey in reversed(subkeys): - left_part, right_part = speck_decrypt_round((left_part, right_part), subkey) - return left_part, right_part - - -def speck_key_expansion(key, rounds): - """Expands a key for the specified number of rounds of encryption.""" - ks = [0] * rounds - ks[0] = key[-1] - left_words = list(reversed(key[:-1])) - for i in range(rounds - 1): - left_words[i % len(left_words)], ks[i + 1] = speck_encrypt_round((left_words[i % len(left_words)], ks[i]), i) - return ks +from claasp.name_mappings import INPUT_PLAINTEXT, INPUT_KEY, SATISFIABLE def test_differential_truncated_in_single_key_scenario_speck3264(): @@ -83,8 +30,8 @@ def test_differential_truncated_in_single_key_scenario_speck3264(): expected probability for the resulting differential is approximately 2^-12. """ speck = SpeckBlockCipher(number_of_rounds=3) - num_samples = 2 ** 14 - input_diff = 0xfe2ecdf8 + num_samples = 2**14 + input_diff = 0xFE2ECDF8 output_diff = "????100000000000????100000000011" key_size = speck.inputs_bit_size[1] total_prob_weight = differential_truncated_checker_single_key( @@ -99,14 +46,17 @@ def test_differential_in_single_key_scenario_aradi(): which occurs with probability 2^-8. """ aradi = AradiBlockCipherSBox(number_of_rounds=4) - num_samples = 2 ** 12 + num_samples = 2**12 input_diff = 0x00080021000800210000000000000000 - output_diff = ("?0???0??0??0?0??????00??????0?0??0???0??0??0?0??????00??????0?0??0???0??0??0?0??????00??????0?0??0" - "???0??0??0?0??????00??????0?0?") + output_diff = ( + "?0???0??0??0?0??????00??????0?0??0???0??0??0?0??????00??????0?0?" + "?0???0??0??0?0??????00??????0?0??0???0??0??0?0??????00??????0?0?" + ) key_size = aradi.inputs_bit_size[1] aradi = differential_truncated_checker_single_key( - aradi, input_diff, output_diff, num_samples, aradi.output_bit_size, 0x0, key_size, seed=42) + aradi, input_diff, output_diff, num_samples, aradi.output_bit_size, 0x0, key_size, seed=42 + ) assert 9 > abs(aradi) > 2 @@ -118,55 +68,49 @@ def test_find_one_xor_probabilistic_truncated_differential_trail_with_fixed_weig speck = SpeckBlockCipher(number_of_rounds=4) plaintext = set_fixed_variables( - component_id='plaintext', - constraint_type='not_equal', - bit_positions=range(32), - bit_values=[0] * 32 + component_id=INPUT_PLAINTEXT, constraint_type="not_equal", bit_positions=range(32), bit_values=(0,) * 32 ) intermediate_output_1_12 = set_fixed_variables( - component_id='intermediate_output_1_12', - constraint_type='equal', + component_id="intermediate_output_1_12", + constraint_type="equal", bit_positions=range(32), - bit_values=(0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) + bit_values=(0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), ) key = set_fixed_variables( - component_id='key', - constraint_type='equal', - bit_positions=range(64), - bit_values=(0,) * 64 + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 ) component_model_types = _generate_component_model_types(speck) truncated_components = [ - 'constant_2_0', - 'rot_2_1', - 'modadd_2_2', - 'xor_2_3', - 'rot_2_4', - 'xor_2_5', - 'rot_2_6', - 'modadd_2_7', - 'xor_2_8', - 'rot_2_9', - 'xor_2_10', - 'intermediate_output_2_11', - 'intermediate_output_2_12', - 'constant_3_0', - 'rot_3_1', - 'modadd_3_2', - 'xor_3_3', - 'rot_3_4', - 'xor_3_5', - 'rot_3_6', - 'modadd_3_7', - 'xor_3_8', - 'rot_3_9', - 'xor_3_10', - 'intermediate_output_3_11', - 'intermediate_output_3_12', - 'cipher_output_3_12' + "constant_2_0", + "rot_2_1", + "modadd_2_2", + "xor_2_3", + "rot_2_4", + "xor_2_5", + "rot_2_6", + "modadd_2_7", + "xor_2_8", + "rot_2_9", + "xor_2_10", + "intermediate_output_2_11", + "intermediate_output_2_12", + "constant_3_0", + "rot_3_1", + "modadd_3_2", + "xor_3_3", + "rot_3_4", + "xor_3_5", + "rot_3_6", + "modadd_3_7", + "xor_3_8", + "rot_3_9", + "xor_3_10", + "intermediate_output_3_11", + "intermediate_output_3_12", + "cipher_output_3_12", ] _update_component_model_types_for_truncated_components(component_model_types, truncated_components) @@ -174,10 +118,9 @@ def test_find_one_xor_probabilistic_truncated_differential_trail_with_fixed_weig trail = sat_heterogeneous_model.find_one_xor_probabilistic_truncated_differential_trail_with_fixed_weight( weight=8, fixed_values=[intermediate_output_1_12, key, plaintext], - number_of_unknowns_per_component={'cipher_output_3_12': 31}, - solver_name="CRYPTOMINISAT_EXT" + number_of_unknowns_per_component={"cipher_output_3_12": 31}, ) - assert trail['components_values']['cipher_output_3_12']['value'] == '????????00000000????????000000?1' + assert trail["components_values"]["cipher_output_3_12"]["value"] == "????????00000000????????000000?1" def test_find_one_xor_probabilistic_truncated_differential_trail_with_fixed_weight_5_rounds(): @@ -185,37 +128,62 @@ def test_find_one_xor_probabilistic_truncated_differential_trail_with_fixed_weig speck = SpeckBlockCipher(number_of_rounds=5) plaintext = set_fixed_variables( - component_id='plaintext', - constraint_type='not_equal', - bit_positions=range(32), - bit_values=[0] * 32 + component_id=INPUT_PLAINTEXT, constraint_type="not_equal", bit_positions=range(32), bit_values=(0,) * 32 ) intermediate_output_1_12 = set_fixed_variables( - component_id='intermediate_output_1_12', - constraint_type='equal', + component_id="intermediate_output_1_12", + constraint_type="equal", bit_positions=range(32), - bit_values=(0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) + bit_values=(0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), ) key = set_fixed_variables( - component_id='key', - constraint_type='equal', - bit_positions=range(64), - bit_values=(0,) * 64 + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 ) component_model_types = _generate_component_model_types(speck) truncated_components = [ - 'constant_2_0', 'rot_2_1', 'modadd_2_2', 'xor_2_3', 'rot_2_4', - 'xor_2_5', 'rot_2_6', 'modadd_2_7', 'xor_2_8', 'rot_2_9', 'xor_2_10', - 'intermediate_output_2_11', 'intermediate_output_2_12', - 'constant_3_0', 'rot_3_1', 'modadd_3_2', 'xor_3_3', 'rot_3_4', - 'xor_3_5', 'rot_3_6', 'modadd_3_7', 'xor_3_8', 'rot_3_9', 'xor_3_10', - 'intermediate_output_3_11', 'intermediate_output_3_12', - 'constant_4_0', 'rot_4_1', 'modadd_4_2', 'xor_4_3', 'rot_4_4', - 'xor_4_5', 'rot_4_6', 'modadd_4_7', 'xor_4_8', 'rot_4_9', 'xor_4_10', - 'intermediate_output_4_11', 'intermediate_output_4_12', 'cipher_output_4_12' + "constant_2_0", + "rot_2_1", + "modadd_2_2", + "xor_2_3", + "rot_2_4", + "xor_2_5", + "rot_2_6", + "modadd_2_7", + "xor_2_8", + "rot_2_9", + "xor_2_10", + "intermediate_output_2_11", + "intermediate_output_2_12", + "constant_3_0", + "rot_3_1", + "modadd_3_2", + "xor_3_3", + "rot_3_4", + "xor_3_5", + "rot_3_6", + "modadd_3_7", + "xor_3_8", + "rot_3_9", + "xor_3_10", + "intermediate_output_3_11", + "intermediate_output_3_12", + "constant_4_0", + "rot_4_1", + "modadd_4_2", + "xor_4_3", + "rot_4_4", + "xor_4_5", + "rot_4_6", + "modadd_4_7", + "xor_4_8", + "rot_4_9", + "xor_4_10", + "intermediate_output_4_11", + "intermediate_output_4_12", + "cipher_output_4_12", ] _update_component_model_types_for_truncated_components(component_model_types, truncated_components) @@ -224,11 +192,10 @@ def test_find_one_xor_probabilistic_truncated_differential_trail_with_fixed_weig trail = sat_heterogeneous_model.find_one_xor_probabilistic_truncated_differential_trail_with_fixed_weight( weight=8, fixed_values=[intermediate_output_1_12, key, plaintext], - number_of_unknowns_per_component={'cipher_output_4_12': 31}, - solver_name="CRYPTOMINISAT_EXT" + number_of_unknowns_per_component={"cipher_output_4_12": 31}, ) - assert trail['components_values']['cipher_output_4_12']['value'] == '???????????????0????????????????' + assert trail["components_values"]["cipher_output_4_12"]["value"] == "???????????????0????????????????" def test_find_lowest_xor_probabilistic_truncated_differential_trail_with_fixed_weight_5_rounds(): @@ -236,122 +203,139 @@ def test_find_lowest_xor_probabilistic_truncated_differential_trail_with_fixed_w speck = SpeckBlockCipher(number_of_rounds=5) plaintext = set_fixed_variables( - component_id='plaintext', - constraint_type='not_equal', - bit_positions=range(32), - bit_values=[0] * 32 + component_id=INPUT_PLAINTEXT, constraint_type="not_equal", bit_positions=range(32), bit_values=(0,) * 32 ) intermediate_output_1_12 = set_fixed_variables( - component_id='intermediate_output_1_12', - constraint_type='equal', + component_id="intermediate_output_1_12", + constraint_type="equal", bit_positions=range(32), - bit_values=(0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) + bit_values=(0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), ) key = set_fixed_variables( - component_id='key', - constraint_type='equal', - bit_positions=range(64), - bit_values=(0,) * 64 + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 ) component_model_types = _generate_component_model_types(speck) truncated_components = [ - 'constant_2_0', 'rot_2_1', 'modadd_2_2', 'xor_2_3', 'rot_2_4', - 'xor_2_5', 'rot_2_6', 'modadd_2_7', 'xor_2_8', 'rot_2_9', 'xor_2_10', - 'intermediate_output_2_11', 'intermediate_output_2_12', - 'constant_3_0', 'rot_3_1', 'modadd_3_2', 'xor_3_3', 'rot_3_4', - 'xor_3_5', 'rot_3_6', 'modadd_3_7', 'xor_3_8', 'rot_3_9', 'xor_3_10', - 'intermediate_output_3_11', 'intermediate_output_3_12', - 'constant_4_0', 'rot_4_1', 'modadd_4_2', 'xor_4_3', 'rot_4_4', - 'xor_4_5', 'rot_4_6', 'modadd_4_7', 'xor_4_8', 'rot_4_9', 'xor_4_10', - 'intermediate_output_4_11', 'intermediate_output_4_12', 'cipher_output_4_12' + "constant_2_0", + "rot_2_1", + "modadd_2_2", + "xor_2_3", + "rot_2_4", + "xor_2_5", + "rot_2_6", + "modadd_2_7", + "xor_2_8", + "rot_2_9", + "xor_2_10", + "intermediate_output_2_11", + "intermediate_output_2_12", + "constant_3_0", + "rot_3_1", + "modadd_3_2", + "xor_3_3", + "rot_3_4", + "xor_3_5", + "rot_3_6", + "modadd_3_7", + "xor_3_8", + "rot_3_9", + "xor_3_10", + "intermediate_output_3_11", + "intermediate_output_3_12", + "constant_4_0", + "rot_4_1", + "modadd_4_2", + "xor_4_3", + "rot_4_4", + "xor_4_5", + "rot_4_6", + "modadd_4_7", + "xor_4_8", + "rot_4_9", + "xor_4_10", + "intermediate_output_4_11", + "intermediate_output_4_12", + "cipher_output_4_12", ] _update_component_model_types_for_truncated_components(component_model_types, truncated_components) sat_heterogeneous_model = SatProbabilisticXorTruncatedDifferentialModel(speck, component_model_types) trail = sat_heterogeneous_model.find_lowest_weight_xor_probabilistic_truncated_differential_trail( - fixed_values=[intermediate_output_1_12, key, plaintext], solver_name="CRYPTOMINISAT_EXT" + fixed_values=[intermediate_output_1_12, key, plaintext], solver_name=solvers.CRYPTOMINISAT_EXT ) - assert trail['components_values']['cipher_output_4_12']['value'] == '???????????????0????????????????' + assert trail["components_values"]["cipher_output_4_12"]["value"] == "???????????????0????????????????" def test_wrong_fixed_variables_assignment(): speck = SpeckBlockCipher(number_of_rounds=5) key = set_fixed_variables( - component_id='key', - constraint_type='equal', - bit_positions=range(64), - bit_values=(0,) * 64 + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 ) plaintext = set_fixed_variables( - component_id='plaintext', - constraint_type='not_equal', - bit_positions=range(32), - bit_values=[0] * 32 + component_id=INPUT_PLAINTEXT, constraint_type="not_equal", bit_positions=range(32), bit_values=(0,) * 32 ) intermediate_output_1_12 = set_fixed_variables( - component_id='intermediate_output_1_12', - constraint_type='equal', + component_id="intermediate_output_1_12", + constraint_type="equal", bit_positions=range(32), - bit_values=(0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) + bit_values=(0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), ) modadd_1_2 = set_fixed_variables( - component_id='modadd_1_2', - constraint_type='equal', + component_id="modadd_1_2", + constraint_type="equal", bit_positions=range(32), - bit_values=( - 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 2, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 - ) + bit_values=(0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 2, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), ) component_model_types = _generate_component_model_types(speck) sat_bitwise_deterministic_truncated_components = [ - 'constant_2_0', - 'rot_2_1', - 'modadd_2_2', - 'xor_2_3', - 'rot_2_4', - 'xor_2_5', - 'rot_2_6', - 'modadd_2_7', - 'xor_2_8', - 'rot_2_9', - 'xor_2_10', - 'intermediate_output_2_11', - 'intermediate_output_2_12', - 'constant_3_0', - 'rot_3_1', - 'modadd_3_2', - 'xor_3_3', - 'rot_3_4', - 'xor_3_5', - 'rot_3_6', - 'modadd_3_7', - 'xor_3_8', - 'rot_3_9', - 'xor_3_10', - 'intermediate_output_3_11', - 'intermediate_output_3_12', - 'constant_4_0', - 'rot_4_1', - 'modadd_4_2', - 'xor_4_3', - 'rot_4_4', - 'xor_4_5', - 'rot_4_6', - 'modadd_4_7', - 'xor_4_8', - 'rot_4_9', - 'xor_4_10', - 'intermediate_output_4_11', - 'intermediate_output_4_12', - 'cipher_output_4_12' + "constant_2_0", + "rot_2_1", + "modadd_2_2", + "xor_2_3", + "rot_2_4", + "xor_2_5", + "rot_2_6", + "modadd_2_7", + "xor_2_8", + "rot_2_9", + "xor_2_10", + "intermediate_output_2_11", + "intermediate_output_2_12", + "constant_3_0", + "rot_3_1", + "modadd_3_2", + "xor_3_3", + "rot_3_4", + "xor_3_5", + "rot_3_6", + "modadd_3_7", + "xor_3_8", + "rot_3_9", + "xor_3_10", + "intermediate_output_3_11", + "intermediate_output_3_12", + "constant_4_0", + "rot_4_1", + "modadd_4_2", + "xor_4_3", + "rot_4_4", + "xor_4_5", + "rot_4_6", + "modadd_4_7", + "xor_4_8", + "rot_4_9", + "xor_4_10", + "intermediate_output_4_11", + "intermediate_output_4_12", + "cipher_output_4_12", ] _update_component_model_types_for_truncated_components( component_model_types, sat_bitwise_deterministic_truncated_components @@ -359,13 +343,12 @@ def test_wrong_fixed_variables_assignment(): sat_heterogeneous_model = SatProbabilisticXorTruncatedDifferentialModel(speck, component_model_types) - import pytest with pytest.raises(ValueError) as exc_info: sat_heterogeneous_model.find_one_xor_probabilistic_truncated_differential_trail_with_fixed_weight( 8, fixed_values=[intermediate_output_1_12, key, plaintext, modadd_1_2], - number_of_unknowns_per_component={'cipher_output_4_12': 31}, - solver_name="CRYPTOMINISAT_EXT" + number_of_unknowns_per_component={"cipher_output_4_12": 31}, + solver_name="CRYPTOMINISAT_EXT", ) assert str(exc_info.value) == "The fixed value in a regular XOR differential model cannot be 2" @@ -373,9 +356,6 @@ def test_wrong_fixed_variables_assignment(): def test_differential_linear_trail_with_fixed_weight_4_rounds_aradi(): """Test for finding a XOR regular truncated differential trail with fixed weight for 4 rounds of Aradi cipher.""" aradi = AradiBlockCipherSBox(number_of_rounds=4) - import itertools - - top_part_components = [] bottom_part_components = [] for round_number in range(2, 4): bottom_part_components.append(aradi.get_components_in_round(round_number)) @@ -383,17 +363,14 @@ def test_differential_linear_trail_with_fixed_weight_4_rounds_aradi(): bottom_part_components = [component.id for component in bottom_part_components] plaintext = set_fixed_variables( - component_id='plaintext', - constraint_type='equal', + component_id=INPUT_PLAINTEXT, + constraint_type="equal", bit_positions=range(128), - bit_values=integer_to_bit_list(0x00080021000800210000000000000000, 128, 'big') + bit_values=integer_to_bit_list(0x00080021000800210000000000000000, 128, "big"), ) key = set_fixed_variables( - component_id='key', - constraint_type='equal', - bit_positions=range(256), - bit_values=(0,) * 256 + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(256), bit_values=(0,) * 256 ) component_model_types = _generate_component_model_types(aradi) @@ -403,95 +380,91 @@ def test_differential_linear_trail_with_fixed_weight_4_rounds_aradi(): trail = sat_heterogeneous_model.find_one_xor_probabilistic_truncated_differential_trail_with_fixed_weight( weight=8, fixed_values=[key, plaintext], - solver_name="CADICAL_EXT", - number_of_unknowns_per_component={'cipher_output_3_86': 127} + solver_name=solvers.CADICAL_EXT, + number_of_unknowns_per_component={"cipher_output_3_86": 127}, ) - assert trail['components_values']['cipher_output_3_86']['value'] == ('?0???0??0??0?0??????00??????0?0??0???0??0??0' - '?0??????00??????0?0??0???0??0??0?0??????00' - '??????0?0??0???0??0??0?0??????00??????0?0?') + assert trail["components_values"]["cipher_output_3_86"]["value"] == ( + "?0???0??0??0?0??????00??????0?0??0???0??0??0?0??????00??????0?0?" + "?0???0??0??0?0??????00??????0?0??0???0??0??0?0??????00??????0?0?" + ) def test_differential_linear_trail_with_fixed_weight_3_rounds_chacha(): """Test for finding a XOR regular truncated differential trail with fixed weight for 4 rounds of ChaCha cipher.""" chacha = ChachaPermutation(number_of_rounds=3) - import itertools - - top_part_components = [] bottom_part_components = [] for round_number in range(2, 3): bottom_part_components.append(chacha.get_components_in_round(round_number)) bottom_part_components = list(itertools.chain(*bottom_part_components)) bottom_part_components = [component.id for component in bottom_part_components] initial_state_positions = integer_to_bit_list( - int('00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000008008000000000000000000000000', - 16), + 0x00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000008008000000000000000000000000, 512, - 'big' + "big", ) plaintext = set_fixed_variables( - component_id='plaintext', - constraint_type='equal', + component_id=INPUT_PLAINTEXT, + constraint_type="equal", bit_positions=list(range(512)), - bit_values=initial_state_positions + bit_values=initial_state_positions, ) intermediate_output_0_24_state = integer_to_bit_list( - int('00000000000000000000000000000000800008000000000000000000000000008008000000000000000000000000000080080000000000000000000000000000', - 16), + 0x00000000000000000000000000000000800008000000000000000000000000008008000000000000000000000000000080080000000000000000000000000000, 512, - 'big' + "big", ) intermediate_output_1_24_state = integer_to_bit_list( - int('80000800000000000000000000000000000400040000000000000000000000008800000000000000000000000000000008080000000000000000000000000000', - 16), + 0x80000800000000000000000000000000000400040000000000000000000000008800000000000000000000000000000008080000000000000000000000000000, 512, - 'big' + "big", ) intermediate_output_0_24 = set_fixed_variables( - component_id='intermediate_output_0_24', - constraint_type='equal', + component_id="intermediate_output_0_24", + constraint_type="equal", bit_positions=list(range(512)), - bit_values=intermediate_output_0_24_state + bit_values=intermediate_output_0_24_state, ) intermediate_output_1_24 = set_fixed_variables( - component_id='intermediate_output_1_24', - constraint_type='equal', + component_id="intermediate_output_1_24", + constraint_type="equal", bit_positions=list(range(512)), - bit_values=intermediate_output_1_24_state - ) - - cipher_output_2_24_state = [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, - 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 1, 0, - 0, 2, 2, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, - 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, - 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + bit_values=intermediate_output_1_24_state, + ) + # fmt: off + cipher_output_2_24_state = [ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, + 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 1, 0, 0, + 2, 2, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, + 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ] + # fmt: on cipher_output_2_24 = set_fixed_variables( - component_id='cipher_output_2_24', - constraint_type='equal', + component_id="cipher_output_2_24", + constraint_type="equal", bit_positions=list(range(512)), - bit_values=cipher_output_2_24_state + bit_values=cipher_output_2_24_state, ) component_model_types = _generate_component_model_types(chacha) _update_component_model_types_for_truncated_components( component_model_types, bottom_part_components, - truncated_model_type="sat_semi_deterministic_truncated_xor_differential_constraints" + truncated_model_type="sat_semi_deterministic_truncated_xor_differential_constraints", ) sat_heterogeneous_model = SatProbabilisticXorTruncatedDifferentialModel(chacha, component_model_types) @@ -499,7 +472,7 @@ def test_differential_linear_trail_with_fixed_weight_3_rounds_chacha(): unknown_window_size_configuration = { "max_number_of_sequences_window_size_0": 9, "max_number_of_sequences_window_size_1": 9, - "max_number_of_sequences_window_size_2": 20 + "max_number_of_sequences_window_size_2": 20, } max_number_of_unknowns_per_component = {"cipher_output_2_24": 12} @@ -508,13 +481,13 @@ def test_differential_linear_trail_with_fixed_weight_3_rounds_chacha(): weight=14, number_of_unknowns_per_component=max_number_of_unknowns_per_component, fixed_values=[plaintext, intermediate_output_0_24, intermediate_output_1_24, cipher_output_2_24], - solver_name="CADICAL_EXT", - unknown_window_size_configuration=unknown_window_size_configuration + solver_name=solvers.CADICAL_EXT, + unknown_window_size_configuration=unknown_window_size_configuration, ) - assert trail['status'] == 'SATISFIABLE' + assert trail["status"] == SATISFIABLE - input_difference = int(trail['components_values']['plaintext']['value'], 16) - output_difference = trail['components_values']['cipher_output_2_24']['value'] + input_difference = int(trail["components_values"][INPUT_PLAINTEXT]["value"], 16) + output_difference = trail["components_values"]["cipher_output_2_24"]["value"] prob = differential_truncated_checker_permutation( chacha, input_difference, output_difference, 1 << 14, 512, seed=42 ) diff --git a/tests/unit/cipher_modules/models/sat/sat_models/sat_semi_deterministic_xor_differential_model_test.py b/tests/unit/cipher_modules/models/sat/sat_models/sat_semi_deterministic_xor_differential_model_test.py index bdab8c267..5b2e83546 100644 --- a/tests/unit/cipher_modules/models/sat/sat_models/sat_semi_deterministic_xor_differential_model_test.py +++ b/tests/unit/cipher_modules/models/sat/sat_models/sat_semi_deterministic_xor_differential_model_test.py @@ -1,106 +1,149 @@ -from claasp.cipher_modules.models.sat.sat_models.sat_semi_deterministic_truncated_xor_differential_model import \ - SatSemiDeterministicTruncatedXorDifferentialModel -from claasp.cipher_modules.models.utils import set_fixed_variables, differential_truncated_checker_permutation, \ - integer_to_bit_list +from claasp.cipher_modules.models.sat.sat_models.sat_semi_deterministic_truncated_xor_differential_model import ( + SatSemiDeterministicTruncatedXorDifferentialModel, +) +from claasp.cipher_modules.models.sat.solvers import CADICAL_EXT +from claasp.cipher_modules.models.utils import ( + set_fixed_variables, + differential_truncated_checker_permutation, + integer_to_bit_list, +) from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher from claasp.ciphers.permutations.chacha_permutation import ChachaPermutation +from claasp.name_mappings import INPUT_PLAINTEXT, INPUT_KEY def test_find_one_semi_deterministic_truncated_xor_differential_trail(): speck = SpeckBlockCipher(number_of_rounds=3) sat = SatSemiDeterministicTruncatedXorDifferentialModel(speck) - bit_values = [0]*32 + bit_values = [0] * 32 bit_values[10] = 1 - plaintext = set_fixed_variables(component_id='plaintext', constraint_type='equal', bit_positions=range(32), - bit_values=bit_values) + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, constraint_type="equal", bit_positions=range(32), bit_values=bit_values + ) intermediate_output_0_6 = set_fixed_variables( - component_id='intermediate_output_0_6', constraint_type='equal', bit_positions=range(32), - bit_values=[2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + component_id="intermediate_output_0_6", + constraint_type="equal", + bit_positions=range(32), + bit_values=[2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ) intermediate_output_1_12 = set_fixed_variables( - component_id='intermediate_output_1_12', constraint_type='equal', bit_positions=range(32), - bit_values=[0, 2, 2, 2, 2, 2, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 0, 1, 0, 0, 0, 0, 0, 2, 1]) + component_id="intermediate_output_1_12", + constraint_type="equal", + bit_positions=range(32), + bit_values=[0, 2, 2, 2, 2, 2, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 0, 1, 0, 0, 0, 0, 0, 2, 1], + ) cipher_output_2_12 = set_fixed_variables( - component_id='cipher_output_2_12', constraint_type='equal', bit_positions=range(32), - bit_values=[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]) + component_id="cipher_output_2_12", + constraint_type="equal", + bit_positions=range(32), + bit_values=[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], + ) - key = set_fixed_variables(component_id='key', constraint_type='equal', bit_positions=range(64), - bit_values=(0,) * 64) + key = set_fixed_variables( + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 + ) trail = sat.find_one_semi_deterministic_truncated_xor_differential_trail( fixed_values=[plaintext, intermediate_output_0_6, intermediate_output_1_12, cipher_output_2_12, key] ) - assert trail['components_values']['cipher_output_2_12']['value'] == '???????????????0????????????????' + assert trail["components_values"]["cipher_output_2_12"]["value"] == "???????????????0????????????????" def test_find_one_semi_deterministic_truncated_xor_differential_trail_with_window_size_configuration(): speck = SpeckBlockCipher(number_of_rounds=3) sat = SatSemiDeterministicTruncatedXorDifferentialModel(speck) - plaintext = set_fixed_variables(component_id='plaintext', constraint_type='equal', bit_positions=range(32), - bit_values=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, - 0, 0, 0, 0, 0, 0, 0, 0]) + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, + constraint_type="equal", + bit_positions=range(32), + bit_values=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ) intermediate_output_0_6 = set_fixed_variables( - component_id='intermediate_output_0_6', constraint_type='equal', bit_positions=range(32), - bit_values=[2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + component_id="intermediate_output_0_6", + constraint_type="equal", + bit_positions=range(32), + bit_values=[2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ) intermediate_output_1_12 = set_fixed_variables( - component_id='intermediate_output_1_12', constraint_type='equal', bit_positions=range(32), - bit_values=[0, 1, 0, 0, 2, 2, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 2, 2, 2, 0, 1, 0, 0, 0, 0, 0, 2, 1]) + component_id="intermediate_output_1_12", + constraint_type="equal", + bit_positions=range(32), + bit_values=[0, 1, 0, 0, 2, 2, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 2, 2, 2, 0, 1, 0, 0, 0, 0, 0, 2, 1], + ) cipher_output_2_12 = set_fixed_variables( - component_id='cipher_output_2_12', constraint_type='equal', bit_positions=range(32), - bit_values=[2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1]) + component_id="cipher_output_2_12", + constraint_type="equal", + bit_positions=range(32), + bit_values=[2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1], + ) - key = set_fixed_variables(component_id='key', constraint_type='equal', bit_positions=range(64), - bit_values=(0,) * 64) + key = set_fixed_variables( + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 + ) trail = sat.find_one_semi_deterministic_truncated_xor_differential_trail( fixed_values=[plaintext, intermediate_output_0_6, intermediate_output_1_12, cipher_output_2_12, key], unknown_window_size_configuration={ "max_number_of_sequences_window_size_0": 20, "max_number_of_sequences_window_size_1": 20, - "max_number_of_sequences_window_size_2": 20 - } + "max_number_of_sequences_window_size_2": 20, + }, ) - assert trail['components_values']['cipher_output_2_12']['value'] == '????0??????????0???????????????1' + assert trail["components_values"]["cipher_output_2_12"]["value"] == "????0??????????0???????????????1" def test_find_one_semi_deterministic_truncated_xor_differential_trail_with_window_size_configuration_unsat(): speck = SpeckBlockCipher(number_of_rounds=3) sat = SatSemiDeterministicTruncatedXorDifferentialModel(speck) - plaintext = set_fixed_variables(component_id='plaintext', constraint_type='equal', bit_positions=range(32), - bit_values=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, - 0, 0, 0, 0, 0, 0, 0, 0]) + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, + constraint_type="equal", + bit_positions=range(32), + bit_values=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ) intermediate_output_0_6 = set_fixed_variables( - component_id='intermediate_output_0_6', constraint_type='equal', bit_positions=range(32), - bit_values=[2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + component_id="intermediate_output_0_6", + constraint_type="equal", + bit_positions=range(32), + bit_values=[2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ) intermediate_output_1_12 = set_fixed_variables( - component_id='intermediate_output_1_12', constraint_type='equal', bit_positions=range(32), - bit_values=[0, 1, 0, 0, 2, 2, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 2, 2, 2, 0, 1, 0, 0, 0, 0, 0, 2, 1]) + component_id="intermediate_output_1_12", + constraint_type="equal", + bit_positions=range(32), + bit_values=[0, 1, 0, 0, 2, 2, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 2, 2, 2, 0, 1, 0, 0, 0, 0, 0, 2, 1], + ) cipher_output_2_12 = set_fixed_variables( - component_id='cipher_output_2_12', constraint_type='equal', bit_positions=range(32), - bit_values=[2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1]) + component_id="cipher_output_2_12", + constraint_type="equal", + bit_positions=range(32), + bit_values=[2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1], + ) - key = set_fixed_variables(component_id='key', constraint_type='equal', bit_positions=range(64), - bit_values=(0,) * 64) + key = set_fixed_variables( + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 + ) trail = sat.find_one_semi_deterministic_truncated_xor_differential_trail( fixed_values=[plaintext, intermediate_output_0_6, intermediate_output_1_12, cipher_output_2_12, key], unknown_window_size_configuration={ "max_number_of_sequences_window_size_0": 20, "max_number_of_sequences_window_size_1": 1, - "max_number_of_sequences_window_size_2": 20 - } + "max_number_of_sequences_window_size_2": 20, + }, ) - assert trail['status'] == 'UNSATISFIABLE' + assert trail["status"] == "UNSATISFIABLE" def test_find_one_semi_deterministic_truncated_xor_differential_trail_with_window_size_configuration_chacha(): @@ -109,19 +152,20 @@ def test_find_one_semi_deterministic_truncated_xor_differential_trail_with_windo state_size = 512 initial_state = [0] * state_size initial_state[389] = 1 - plaintext = set_fixed_variables(component_id='plaintext', constraint_type='equal', bit_positions=range(state_size), - bit_values=initial_state) + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, constraint_type="equal", bit_positions=range(state_size), bit_values=initial_state + ) trail = sat.find_one_semi_deterministic_truncated_xor_differential_trail( fixed_values=[plaintext], unknown_window_size_configuration={ "max_number_of_sequences_window_size_0": 20, "max_number_of_sequences_window_size_1": 1, - "max_number_of_sequences_window_size_2": 20 - } + "max_number_of_sequences_window_size_2": 20, + }, ) - assert trail['status'] == 'UNSATISFIABLE' + assert trail["status"] == "UNSATISFIABLE" def test_find_one_semi_deterministic_truncated_xor_differential_trail_with_window_size_for_chacha_1_round_satisfiable(): @@ -132,35 +176,27 @@ def test_find_one_semi_deterministic_truncated_xor_differential_trail_with_windo initial_state_positions[508] = 1 plaintext = set_fixed_variables( - component_id='plaintext', - constraint_type='equal', + component_id=INPUT_PLAINTEXT, + constraint_type="equal", bit_positions=list(range(state_size)), - bit_values=initial_state_positions + bit_values=initial_state_positions, ) - intermediate_output_0_24_int = int( - '00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000010000000000000000000', - 2) + intermediate_output_0_24_int = 0x00000000000000000000000000000000000000000000000000000000800000000000000000000000000000000008000000000000000000000000000000080000 intermediate_output_0_24 = set_fixed_variables( - component_id='intermediate_output_0_24', - constraint_type='equal', + component_id="intermediate_output_0_24", + constraint_type="equal", bit_positions=list(range(state_size)), - bit_values=integer_to_bit_list(intermediate_output_0_24_int, state_size, 'big') + bit_values=integer_to_bit_list(intermediate_output_0_24_int, state_size, "big"), ) - cipher_output_1_24_int = list( - '000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000010000000000010000000????100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000????100000001000000000001000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001000000000000000000010000000') - cipher_output_1_24_int_temp = [] - for bit in cipher_output_1_24_int: - if bit == '?': - cipher_output_1_24_int_temp.append(2) - else: - cipher_output_1_24_int_temp.append(int(bit)) + cipher_output_1_24_int = "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000010000000000010000000????100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000????100000001000000000001000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001000000000000000000010000000" + cipher_output_1_24_int_temp = list(map(int, cipher_output_1_24_int.replace("?", "2"))) cipher_output_1_24 = set_fixed_variables( - component_id='cipher_output_1_24', - constraint_type='equal', + component_id="cipher_output_1_24", + constraint_type="equal", bit_positions=list(range(state_size)), - bit_values=cipher_output_1_24_int_temp + bit_values=cipher_output_1_24_int_temp, ) trail = sat.find_one_semi_deterministic_truncated_xor_differential_trail( @@ -168,16 +204,16 @@ def test_find_one_semi_deterministic_truncated_xor_differential_trail_with_windo unknown_window_size_configuration={ "max_number_of_sequences_window_size_0": 3, "max_number_of_sequences_window_size_1": 3, - "max_number_of_sequences_window_size_2": 3 + "max_number_of_sequences_window_size_2": 3, }, number_of_unknowns_per_component={"cipher_output_1_24": 8}, - solver_name='CADICAL_EXT' + solver_name=CADICAL_EXT, ) - assert trail['status'] == 'SATISFIABLE' + assert trail["status"] == "SATISFIABLE" - input_difference = int(trail['components_values']['plaintext']['value'], 2) - output_difference = trail['components_values']['cipher_output_1_24']['value'] + input_difference = int(trail["components_values"][INPUT_PLAINTEXT]["value"], 2) + output_difference = trail["components_values"]["cipher_output_1_24"]["value"] prob = differential_truncated_checker_permutation( chacha, input_difference, output_difference, 1 << 12, state_size, seed=42 ) diff --git a/tests/unit/cipher_modules/models/sat/sat_models/sat_shared_difference_paired_input_differential_linear_model_test.py b/tests/unit/cipher_modules/models/sat/sat_models/sat_shared_difference_paired_input_differential_linear_model_test.py index 845229b26..7818c5975 100644 --- a/tests/unit/cipher_modules/models/sat/sat_models/sat_shared_difference_paired_input_differential_linear_model_test.py +++ b/tests/unit/cipher_modules/models/sat/sat_models/sat_shared_difference_paired_input_differential_linear_model_test.py @@ -1,12 +1,16 @@ - import itertools import os import pickle -from claasp.cipher_modules.models.sat.sat_models.sat_shared_difference_paired_input_differential_linear_model import \ - SharedDifferencePairedInputDifferentialLinearModel -from claasp.cipher_modules.models.utils import set_fixed_variables, integer_to_bit_list, \ - shared_difference_paired_input_differential_linear_checker_permutation +from claasp.cipher_modules.models.sat.sat_models.sat_shared_difference_paired_input_differential_linear_model import ( + SharedDifferencePairedInputDifferentialLinearModel, +) +from claasp.cipher_modules.models.sat.solvers import KISSAT_EXT +from claasp.cipher_modules.models.utils import ( + set_fixed_variables, + integer_to_bit_list, + shared_difference_paired_input_differential_linear_checker_permutation, +) from claasp.ciphers.permutations.chacha_permutation import ChachaPermutation from claasp.components.intermediate_output_component import IntermediateOutput from claasp.components.modsub_component import MODSUB @@ -15,7 +19,7 @@ def add_prefix_id_to_inputs(chacha_permutation, prefix): new_inputs = [] for chacha_permutation_input in chacha_permutation.inputs: - new_inputs.append(f'{prefix}_{chacha_permutation_input}') + new_inputs.append(f"{prefix}_{chacha_permutation_input}") chacha_permutation.set_inputs(new_inputs, chacha_permutation.inputs_bit_size) @@ -25,17 +29,15 @@ def add_ciphertext_and_new_plaintext_to_inputs(chacha_permutation): chacha_permutation.inputs_bit_size.append(512) chacha_permutation.inputs_bit_size.append(512) modsub_ids = [] - constants_ids = [] round_object = chacha_permutation.rounds.round_at(0) for i in range(16): new_modsub_component = MODSUB( 0, round_object.get_number_of_components(), ["ciphertext_final", "fake_plaintext"], - [list(range(i * 32, (i) * 32 + 32)), - list(range(i * 32, (i) * 32 + 32))], + [list(range(i * 32, (i) * 32 + 32)), list(range(i * 32, (i) * 32 + 32))], 32, - None + None, ) round_object.add_component(new_modsub_component) @@ -44,9 +46,9 @@ def add_ciphertext_and_new_plaintext_to_inputs(chacha_permutation): 0, round_object.get_number_of_components(), modsub_ids, - [list(range(32)) for i in range(16)], + [list(range(32)) for _ in range(16)], 512, - "round_output" + "round_output", ) round_object.add_component(new_intermediate_output_component) new_intermediate_output_component.set_id(chacha_permutation.inputs[0]) @@ -57,8 +59,8 @@ def add_ciphertext_and_new_plaintext_to_inputs(chacha_permutation): def add_prefix_id_to_components(chacha_permutation, prefix): all_components = chacha_permutation.rounds.get_all_components() for component in all_components: - component.set_id(f'{prefix}_{component.id}') - new_input_id_links = [f'{prefix}_{input_id_link}' for input_id_link in component.input_id_links] + component.set_id(f"{prefix}_{component.id}") + new_input_id_links = [f"{prefix}_{input_id_link}" for input_id_link in component.input_id_links] component.set_input_id_links(new_input_id_links) return 0 @@ -95,77 +97,60 @@ def test_backward_direction_distinguisher(): bottom_part_components = [component.id for component in bottom_part_components] ciphertext_final = set_fixed_variables( - component_id='bottom_ciphertext_final', - constraint_type='equal', + component_id="bottom_ciphertext_final", + constraint_type="equal", bit_positions=range(512), - bit_values=integer_to_bit_list( - 0x0, - 512, - 'big' - ) + bit_values=integer_to_bit_list(0x0, 512, "big"), ) plaintext = set_fixed_variables( - component_id='bottom_fake_plaintext', - constraint_type='equal', + component_id="bottom_fake_plaintext", + constraint_type="equal", bit_positions=range(512), bit_values=integer_to_bit_list( 0x00000000000000000000000000000000080000400000000000000000000000000000000000000000000000000000000000000000000000000000000000000000, 512, - 'big' - ) + "big", + ), ) plaintext_constants = set_fixed_variables( - component_id='bottom_fake_plaintext', - constraint_type='equal', + component_id="bottom_fake_plaintext", + constraint_type="equal", bit_positions=range(128), - bit_values=integer_to_bit_list( - 0x0, - 128, - 'big' - ) + bit_values=integer_to_bit_list(0x0, 128, "big"), ) plaintext_nonce = set_fixed_variables( - component_id='bottom_fake_plaintext', - constraint_type='equal', + component_id="bottom_fake_plaintext", + constraint_type="equal", bit_positions=range(384, 512), - bit_values=integer_to_bit_list( - 0x0, - 128, - 'big' - ) + bit_values=integer_to_bit_list(0x0, 128, "big"), ) bottom_cipher_output_1_24 = set_fixed_variables( - component_id='bottom_cipher_output_3_24', - constraint_type='not_equal', + component_id="bottom_cipher_output_3_24", + constraint_type="not_equal", bit_positions=range(512), - bit_values=integer_to_bit_list( - 0x0, - 512, - 'big' - ) + bit_values=integer_to_bit_list(0x0, 512, "big"), ) bottom_plaintext = set_fixed_variables( - component_id='bottom_plaintext', - constraint_type='equal', + component_id="bottom_plaintext", + constraint_type="equal", bit_positions=range(512), bit_values=integer_to_bit_list( 0x00000000000000010000000000000000000000000000000100000000000000000000000000000001000000000000000000000000000000000000000000000000, 512, - 'big' - ) + "big", + ), ) - component_model_list = { - 'bottom_part_components': bottom_part_components - } + component_model_list = {"bottom_part_components": bottom_part_components} - sat_heterogeneous_model = SharedDifferencePairedInputDifferentialLinearModel(chacha_stream_cipher, - component_model_list) + sat_heterogeneous_model = SharedDifferencePairedInputDifferentialLinearModel( + chacha_stream_cipher, component_model_list + ) trail = sat_heterogeneous_model.find_one_shared_difference_paired_input_differential_linear_trail_with_fixed_weight( weight=40, fixed_values=[ @@ -174,23 +159,18 @@ def test_backward_direction_distinguisher(): bottom_cipher_output_1_24, bottom_plaintext, plaintext_constants, - plaintext_nonce + plaintext_nonce, ], - solver_name="KISSAT_EXT" + solver_name=KISSAT_EXT, ) assert trail["status"] == "SATISFIABLE" - input_difference = int(trail['components_values']['bottom_fake_plaintext']['value'], 16) - output_difference1 = int(trail['components_values']['bottom_plaintext']['value'], 16) + input_difference = int(trail["components_values"]["bottom_fake_plaintext"]["value"], 16) + output_difference1 = int(trail["components_values"]["bottom_plaintext"]["value"], 16) prob = shared_difference_paired_input_differential_linear_checker_permutation( - chacha_stream_cipher, - input_difference, - output_difference1, - 1 << 8, - 512, - 1 + chacha_stream_cipher, input_difference, output_difference1, 1 << 8, 512, 1 ) assert prob < 14 diff --git a/tests/unit/cipher_modules/models/sat/sat_models/sat_shared_difference_paired_input_differential_model_test.py b/tests/unit/cipher_modules/models/sat/sat_models/sat_shared_difference_paired_input_differential_model_test.py index 40eb99ce1..dac1a1d26 100644 --- a/tests/unit/cipher_modules/models/sat/sat_models/sat_shared_difference_paired_input_differential_model_test.py +++ b/tests/unit/cipher_modules/models/sat/sat_models/sat_shared_difference_paired_input_differential_model_test.py @@ -1,9 +1,14 @@ -from claasp.cipher_modules.models.sat.sat_models.sat_shared_difference_paired_input_differential_model import \ - SharedDifferencePairedInputDifferentialModel -from claasp.cipher_modules.models.utils import set_fixed_variables, \ - shared_difference_paired_input_differential_checker_permutation, \ - integer_to_bit_list +from claasp.cipher_modules.models.sat.sat_models.sat_shared_difference_paired_input_differential_model import ( + SharedDifferencePairedInputDifferentialModel, +) +from claasp.cipher_modules.models.sat.solvers import CADICAL_EXT +from claasp.cipher_modules.models.utils import ( + set_fixed_variables, + shared_difference_paired_input_differential_checker_permutation, + integer_to_bit_list, +) from claasp.ciphers.permutations.chacha_permutation import ChachaPermutation +from claasp.name_mappings import INPUT_PLAINTEXT, SATISFIABLE def test_sat_shared_difference_paired_input_differential_model_on_chacha_permutation(): @@ -12,42 +17,35 @@ def test_sat_shared_difference_paired_input_differential_model_on_chacha_permuta fixed_variables = [ set_fixed_variables( - 'plaintext', - 'equal', + INPUT_PLAINTEXT, + "equal", bit_positions=list(range(512)), bit_values=integer_to_bit_list( - 0x800000008000000080000000e0000000800000008000000080000000e00000008000000080000000800000008000400000008000000080000000800040008000, + 0x800000008000000080000000E0000000800000008000000080000000E00000008000000080000000800000008000400000008000000080000000800040008000, list_length=512, - endianness='big' - ) + endianness="big", + ), ), set_fixed_variables( - 'cipher_output_0_24', - 'equal', + "cipher_output_0_24", + "equal", bit_positions=list(range(512)), bit_values=integer_to_bit_list( - 0x0000000000000000000000000000000000000800000008000000080000000e000000000000000000000000000000000080000000800000008000000080004000, + 0x0000000000000000000000000000000000000800000008000000080000000E000000000000000000000000000000000080000000800000008000000080004000, list_length=512, - endianness='big' - ) - ) + endianness="big", + ), + ), ] trail = sat_model.find_one_shared_difference_paired_input_differential_trail_with_fixed_weight( - 8, - fixed_variables, - solver_name="CADICAL_EXT" + 8, fixed_variables, solver_name=CADICAL_EXT ) - assert trail['status'] == 'SATISFIABLE' - input_difference = int(trail['components_values']['plaintext']['value'], 16) - output_difference1 = int(trail['components_values']['cipher_output_0_24']['value'], 16) - output_difference2 = int(trail['components_values']['cipher1_cipher_output_0_24']['value'], 16) + assert trail["status"] == SATISFIABLE + input_difference = int(trail["components_values"][INPUT_PLAINTEXT]["value"], 16) + output_difference1 = int(trail["components_values"]["cipher_output_0_24"]["value"], 16) + output_difference2 = int(trail["components_values"]["cipher1_cipher_output_0_24"]["value"], 16) prob = shared_difference_paired_input_differential_checker_permutation( - chacha1, - input_difference, - output_difference1 ^ output_difference2, - 1 << 13, - 512, - 16 + chacha1, input_difference, output_difference1 ^ output_difference2, 1 << 13, 512, 16 ) assert prob <= 13 diff --git a/tests/unit/cipher_modules/models/sat/sat_models/sat_xor_differential_model_test.py b/tests/unit/cipher_modules/models/sat/sat_models/sat_xor_differential_model_test.py index 283c25177..b76dd7b1c 100644 --- a/tests/unit/cipher_modules/models/sat/sat_models/sat_xor_differential_model_test.py +++ b/tests/unit/cipher_modules/models/sat/sat_models/sat_xor_differential_model_test.py @@ -1,17 +1,17 @@ -import math import numpy as np -from claasp.utils.utils import get_k_th_bit -from claasp.components.modadd_component import MODADD -from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher -from claasp.cipher_modules.models.utils import set_fixed_variables, integer_to_bit_list from claasp.cipher_modules.models.sat.sat_models.sat_xor_differential_model import SatXorDifferentialModel +from claasp.cipher_modules.models.sat.solvers import CADICAL_EXT, KISSAT_EXT, PARKISSAT_EXT +from claasp.cipher_modules.models.utils import set_fixed_variables, integer_to_bit_list +from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher +from claasp.components.modadd_component import MODADD +from claasp.name_mappings import INPUT_KEY, INPUT_PLAINTEXT, SATISFIABLE, XOR_DIFFERENTIAL def count_sequences_of_ones(data, full_window_size): count = 0 for entry in data: - for key, binary_str in entry.items(): + for binary_str in entry.values(): binary_str = binary_str[2:] # Remove the '0b' prefix sequences = binary_str.split("0") for seq in sequences: @@ -33,7 +33,7 @@ def extract_bits_from_hex(hex_value, bit_positions): last_bit_position = bit_positions[-1] for bit_position in bit_positions: - bit_value = get_k_th_bit(hex_value, last_bit_position - bit_position) + bit_value = 1 & (hex_value >> (last_bit_position - bit_position)) bin_values.append(bit_value) extracted_value = binary_list_to_int(bin_values) @@ -70,9 +70,9 @@ def compute_modadd_xor(modadd_objects, component_values): def test_find_all_xor_differential_trails_with_fixed_weight(): sat = SatXorDifferentialModel(speck_5rounds) - sat.set_window_size_weight_pr_vars(1) + sat.window_size_weight_pr_vars = 1 - assert int(sat.find_all_xor_differential_trails_with_fixed_weight(9)[0]["total_weight"]) == int(9.0) + assert int(sat.find_all_xor_differential_trails_with_fixed_weight(9)[0]["total_weight"]) == 9 def test_find_all_xor_differential_trails_with_weight_at_most(): @@ -85,32 +85,30 @@ def test_find_all_xor_differential_trails_with_weight_at_most(): def test_find_lowest_weight_xor_differential_trail(): speck = speck_5rounds - sat = SatXorDifferentialModel(speck) + sat = SatXorDifferentialModel(speck, counter="parallel") trail = sat.find_lowest_weight_xor_differential_trail() - assert int(trail["total_weight"]) == int(9.0) + assert int(trail["total_weight"]) == 9 def test_find_one_xor_differential_trail(): speck = speck_5rounds sat = SatXorDifferentialModel(speck) plaintext = set_fixed_variables( - component_id="plaintext", + component_id=INPUT_PLAINTEXT, constraint_type="not_equal", bit_positions=range(32), - bit_values=integer_to_bit_list(0, 32, "big"), + bit_values=(0,) * 32, ) trail = sat.find_one_xor_differential_trail(fixed_values=[plaintext]) assert str(trail["cipher"]) == "speck_p32_k64_o32_r5" - assert trail["model_type"] == "xor_differential" - assert trail["solver_name"] == "CRYPTOMINISAT_EXT" - assert trail["status"] == "SATISFIABLE" + assert trail["model_type"] == XOR_DIFFERENTIAL + assert trail["status"] == SATISFIABLE - trail = sat.find_one_xor_differential_trail(fixed_values=[plaintext], solver_name="KISSAT_EXT") + trail = sat.find_one_xor_differential_trail(fixed_values=[plaintext], solver_name=KISSAT_EXT) - assert trail["solver_name"] == "KISSAT_EXT" - assert trail["status"] == "SATISFIABLE" + assert trail["status"] == SATISFIABLE def test_find_one_xor_differential_trail_with_fixed_weight(): @@ -119,16 +117,16 @@ def test_find_one_xor_differential_trail_with_fixed_weight(): sat.set_window_size_heuristic_by_round([0, 0, 0]) result = sat.find_one_xor_differential_trail_with_fixed_weight(3) - assert int(result["total_weight"]) == int(3.0) + assert int(result["total_weight"]) == 3 def test_find_one_xor_differential_trail_with_fixed_weight_with_at_least_one_full_2_window(): speck = SpeckBlockCipher(number_of_rounds=9) sat = SatXorDifferentialModel(speck) sat.set_window_size_heuristic_by_round([2 for _ in range(9)], number_of_full_windows=1) - result = sat.find_one_xor_differential_trail_with_fixed_weight(30, solver_name="CADICAL_EXT") + result = sat.find_one_xor_differential_trail_with_fixed_weight(30, solver_name=CADICAL_EXT) - assert int(result["total_weight"]) == int(30.0) + assert int(result["total_weight"]) == 30 def test_find_one_xor_differential_trail_with_fixed_weight_and_with_exactly_three_full_2_window(): @@ -139,7 +137,7 @@ def test_find_one_xor_differential_trail_with_fixed_weight_and_with_exactly_thre sat.set_window_size_heuristic_by_round( [window_size for _ in range(9)], number_of_full_windows=number_of_full_windows ) - result = sat.find_one_xor_differential_trail_with_fixed_weight(30, solver_name="CADICAL_EXT") + result = sat.find_one_xor_differential_trail_with_fixed_weight(30, solver_name=CADICAL_EXT) speck_components = speck.get_all_components() modadd_objects = list(filter(lambda obj: isinstance(obj, MODADD), speck_components)) @@ -161,16 +159,16 @@ def test_find_one_xor_differential_trail_with_fixed_weight_and_with_exactly_one_ ) plaintext = set_fixed_variables( - component_id="plaintext", + component_id=INPUT_PLAINTEXT, constraint_type="equal", bit_positions=range(32), - bit_values=[1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + bit_values=(1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0), ) key = set_fixed_variables( - component_id="key", constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 ) sat.build_xor_differential_trail_model(34, fixed_variables=[plaintext, key]) - result = sat._solve_with_external_sat_solver("xor_differential", "PARKISSAT_EXT", ["-c=6"]) + result = sat._solve_with_external_sat_solver(XOR_DIFFERENTIAL, PARKISSAT_EXT, ["-c=6"]) speck_components = speck.get_all_components() modadd_objects = list(filter(lambda obj: isinstance(obj, MODADD), speck_components)) carry_list = compute_modadd_xor(modadd_objects, result["components_values"]) @@ -185,9 +183,9 @@ def test_find_one_xor_differential_trail_with_fixed_weight_9_rounds(): sat = SatXorDifferentialModel(speck) sat.set_window_size_heuristic_by_round([2 for _ in range(9)]) - result = sat.find_one_xor_differential_trail_with_fixed_weight(30, solver_name="CADICAL_EXT") + result = sat.find_one_xor_differential_trail_with_fixed_weight(30, solver_name=CADICAL_EXT) - assert int(result["total_weight"]) == int(30.0) + assert int(result["total_weight"]) == 30 def test_find_one_xor_differential_trail_with_fixed_weight_with_at_least_one_full_window_parallel(): @@ -195,18 +193,18 @@ def test_find_one_xor_differential_trail_with_fixed_weight_with_at_least_one_ful sat = SatXorDifferentialModel(speck) sat.set_window_size_heuristic_by_round([3 for _ in range(10)], number_of_full_windows=1) plaintext = set_fixed_variables( - component_id="plaintext", + component_id=INPUT_PLAINTEXT, constraint_type="not_equal", bit_positions=range(32), - bit_values=integer_to_bit_list(0, 32, "big"), + bit_values=(0,) * 32, ) key = set_fixed_variables( - component_id="key", constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(64), bit_values=(0,) * 64 ) sat.build_xor_differential_trail_model(34, fixed_variables=[plaintext, key]) - result = sat._solve_with_external_sat_solver("xor_differential", "PARKISSAT_EXT", ["-c=10"]) + result = sat._solve_with_external_sat_solver(XOR_DIFFERENTIAL, PARKISSAT_EXT, ["-c=10"]) - assert int(result["total_weight"]) == int(34.0) + assert int(result["total_weight"]) == 34 def test_find_one_xor_differential_trail_with_fixed_weight_and_window_heuristic_per_component(): @@ -219,7 +217,7 @@ def test_find_one_xor_differential_trail_with_fixed_weight_and_window_heuristic_ sat.set_window_size_heuristic_by_component_id(dict_of_window_heuristic_per_component) result = sat.find_one_xor_differential_trail_with_fixed_weight(3) - assert int(result["total_weight"]) == int(3.0) + assert int(result["total_weight"]) == 3 def test_build_xor_differential_trail_model_fixed_weight_and_parkissat(): @@ -227,9 +225,9 @@ def test_build_xor_differential_trail_model_fixed_weight_and_parkissat(): speck = SpeckBlockCipher(number_of_rounds=3) sat = SatXorDifferentialModel(speck) sat.build_xor_differential_trail_model(3) - result = sat._solve_with_external_sat_solver("xor_differential", "PARKISSAT_EXT", [f"-c={number_of_cores}"]) + result = sat._solve_with_external_sat_solver(XOR_DIFFERENTIAL, PARKISSAT_EXT, [f"-c={number_of_cores}"]) - assert int(result["total_weight"]) == int(3.0) + assert int(result["total_weight"]) == 3 def repeat_input_difference(input_difference_, number_of_samples_, number_of_bytes_): @@ -238,43 +236,3 @@ def repeat_input_difference(input_difference_, number_of_samples_, number_of_byt column_array = np_array.reshape(-1, 1) return np.tile(column_array, (1, number_of_samples_)) - - -def test_differential_in_related_key_scenario_speck3264(): - rng = np.random.default_rng(seed=42) - number_of_samples = 2**20 - input_difference = 0x00402000 - output_difference = 0x0 - key_difference = 0x000AFA3D0030A000 - input_difference_data = repeat_input_difference(input_difference, number_of_samples, 4) - output_difference_data = repeat_input_difference(output_difference, number_of_samples, 4) - key_difference_data = repeat_input_difference(key_difference, number_of_samples, 8) - key_data = rng.integers(low=0, high=256, size=(8, number_of_samples), dtype=np.uint8) - plaintext_data1 = rng.integers(low=0, high=256, size=(4, number_of_samples), dtype=np.uint8) - plaintext_data2 = plaintext_data1 ^ input_difference_data - ciphertext1 = speck_5rounds.evaluate_vectorized([plaintext_data1, key_data]) - ciphertext2 = speck_5rounds.evaluate_vectorized([plaintext_data2, key_data ^ key_difference_data]) - rows_all_true = np.all((ciphertext1[0] ^ ciphertext2[0] == output_difference_data.T), axis=1) - total = np.count_nonzero(rows_all_true) - total_prob_weight = math.log(total / number_of_samples, 2) - - assert 21 > abs(total_prob_weight) > 12 - - -def test_differential_in_single_key_scenario_speck3264(): - rng = np.random.default_rng(seed=3) - number_of_samples = 2**9 - input_difference = 0x02110A04 - output_difference = 0x81008102 - input_difference_data = repeat_input_difference(input_difference, number_of_samples, 4) - output_difference_data = repeat_input_difference(output_difference, number_of_samples, 4) - key_data = rng.integers(low=0, high=256, size=(8, number_of_samples), dtype=np.uint8) - plaintext_data1 = rng.integers(low=0, high=256, size=(4, number_of_samples), dtype=np.uint8) - plaintext_data2 = plaintext_data1 ^ input_difference_data - ciphertext1 = speck_4rounds.evaluate_vectorized([plaintext_data1, key_data]) - ciphertext2 = speck_4rounds.evaluate_vectorized([plaintext_data2, key_data]) - rows_all_true = np.all((ciphertext1[0] ^ ciphertext2[0] == output_difference_data.T), axis=1) - total = np.count_nonzero(rows_all_true) - total_prob_weight = math.log(total / number_of_samples, 2) - - assert 19 > abs(total_prob_weight) > 6 diff --git a/tests/unit/cipher_modules/models/sat/sat_models/sat_xor_linear_model_test.py b/tests/unit/cipher_modules/models/sat/sat_models/sat_xor_linear_model_test.py index 35dc305d3..acf399fa8 100644 --- a/tests/unit/cipher_modules/models/sat/sat_models/sat_xor_linear_model_test.py +++ b/tests/unit/cipher_modules/models/sat/sat_models/sat_xor_linear_model_test.py @@ -41,7 +41,6 @@ def test_find_one_xor_linear_trail(): assert str(trail["cipher"]) == "speck_p32_k64_o32_r4" assert trail["model_type"] == "xor_linear" - assert trail["solver_name"] == "CRYPTOMINISAT_EXT" assert trail["status"] == "SATISFIABLE" diff --git a/tests/unit/cipher_modules/models/sat/utils/sat_model_utils_test.py b/tests/unit/cipher_modules/models/sat/utils/sat_model_utils_test.py index 597e7a38b..10e6cc2f9 100644 --- a/tests/unit/cipher_modules/models/sat/utils/sat_model_utils_test.py +++ b/tests/unit/cipher_modules/models/sat/utils/sat_model_utils_test.py @@ -1,16 +1,16 @@ -from claasp.cipher_modules.models.sat.utils.utils import (cnf_or, cnf_xor_seq) +from claasp.cipher_modules.models.sat.utils.utils import cnf_or, cnf_xor_seq def test_cnf_or(): - assert cnf_or('r', ['a', 'b', 'c']) == ['r -a', 'r -b', 'r -c', '-r a b c'] + assert cnf_or("r", ["a", "b", "c"]) == ["r -a", "r -b", "r -c", "-r a b c"] def test_cnf_xor_seq(): - xor_seq = cnf_xor_seq(['i_0', 'i_1', 'r_7'], ['a_7', 'b_7', 'c_7', 'd_7']) + xor_seq = cnf_xor_seq(["i_0", "i_1", "r_7"], ["a_7", "b_7", "c_7", "d_7"]) - assert xor_seq[0] == '-i_0 a_7 b_7' - assert xor_seq[1] == 'i_0 -a_7 b_7' - assert xor_seq[2] == 'i_0 a_7 -b_7' - assert xor_seq[-3] == 'r_7 -i_1 d_7' - assert xor_seq[-2] == 'r_7 i_1 -d_7' - assert xor_seq[-1] == '-r_7 -i_1 -d_7' + assert xor_seq[0] == "-i_0 a_7 b_7" + assert xor_seq[1] == "i_0 -a_7 b_7" + assert xor_seq[2] == "i_0 a_7 -b_7" + assert xor_seq[-3] == "r_7 -i_1 d_7" + assert xor_seq[-2] == "r_7 i_1 -d_7" + assert xor_seq[-1] == "-r_7 -i_1 -d_7" diff --git a/tests/unit/cipher_modules/models/smt/smt_model_test.py b/tests/unit/cipher_modules/models/smt/smt_model_test.py index 18dc550d6..ae3bf7cfa 100644 --- a/tests/unit/cipher_modules/models/smt/smt_model_test.py +++ b/tests/unit/cipher_modules/models/smt/smt_model_test.py @@ -3,20 +3,24 @@ from claasp.cipher_modules.models.smt.smt_model import SmtModel from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher from claasp.cipher_modules.models.utils import set_fixed_variables, integer_to_bit_list +from claasp.name_mappings import INPUT_PLAINTEXT def test_fix_variables_value_constraints(): speck = SpeckBlockCipher(number_of_rounds=3) smt = SmtModel(speck) - fixed_variables = [set_fixed_variables('plaintext', 'equal', range(4), integer_to_bit_list(5, 4, 'big'))] - assert smt.fix_variables_value_constraints(fixed_variables) == ['(assert (not plaintext_0))', - '(assert plaintext_1)', - '(assert (not plaintext_2))', - '(assert plaintext_3)'] + fixed_variables = [set_fixed_variables(INPUT_PLAINTEXT, "equal", range(4), integer_to_bit_list(5, 4, "big"))] + assert smt.fix_variables_value_constraints(fixed_variables) == [ + "(assert (not plaintext_0))", + "(assert plaintext_1)", + "(assert (not plaintext_2))", + "(assert plaintext_3)", + ] - fixed_variables = [set_fixed_variables('plaintext', 'not_equal', range(4), integer_to_bit_list(5, 4, 'big'))] + fixed_variables = [set_fixed_variables(INPUT_PLAINTEXT, "not_equal", range(4), integer_to_bit_list(5, 4, "big"))] assert smt.fix_variables_value_constraints(fixed_variables) == [ - '(assert (or plaintext_0 (not plaintext_1) plaintext_2 (not plaintext_3)))'] + "(assert (or plaintext_0 (not plaintext_1) plaintext_2 (not plaintext_3)))" + ] def test_model_constraints(): diff --git a/tests/unit/cipher_modules/models/smt/smt_models/smt_cipher_model_test.py b/tests/unit/cipher_modules/models/smt/smt_models/smt_cipher_model_test.py index 40c4345d7..8bfbace49 100644 --- a/tests/unit/cipher_modules/models/smt/smt_models/smt_cipher_model_test.py +++ b/tests/unit/cipher_modules/models/smt/smt_models/smt_cipher_model_test.py @@ -1,21 +1,22 @@ -from claasp.cipher_modules.models.utils import integer_to_bit_list, set_fixed_variables from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher from claasp.cipher_modules.models.smt.smt_models.smt_cipher_model import SmtCipherModel +from claasp.cipher_modules.models.utils import set_fixed_variables, integer_to_bit_list +from claasp.name_mappings import INPUT_PLAINTEXT, INPUT_KEY def test_find_missing_bits(): - speck = SpeckBlockCipher(number_of_rounds=22) - cipher_output_id = speck.get_all_components_ids()[-1] + speck = SpeckBlockCipher(block_bit_size=32, key_bit_size=64, number_of_rounds=22) smt = SmtCipherModel(speck) - ciphertext = set_fixed_variables(component_id=cipher_output_id, - constraint_type='equal', - bit_positions=range(32), - bit_values=integer_to_bit_list(0x1234abcd, 32, 'big')) + cipher_output_id = speck.get_all_components_ids()[-1] + plaintext_bits = integer_to_bit_list(0x6574694C, 32, "big") + plaintext = set_fixed_variables( + component_id=INPUT_PLAINTEXT, constraint_type="equal", bit_positions=range(32), bit_values=plaintext_bits + ) + key_bits = integer_to_bit_list(0x1918111009080100, 64, "big") + key = set_fixed_variables( + component_id=INPUT_KEY, constraint_type="equal", bit_positions=range(64), bit_values=key_bits + ) - missing_bits = smt.find_missing_bits(fixed_values=[ciphertext]) + missing_bits = smt.find_missing_bits(fixed_values=[plaintext, key]) - assert str(missing_bits['cipher']) == 'speck_p32_k64_o32_r22' - assert missing_bits['model_type'] == 'cipher' - assert missing_bits['solver_name'] == 'Z3_EXT' - assert missing_bits['components_values'][cipher_output_id] == {'value': '1234abcd'} - assert missing_bits['status'] == 'SATISFIABLE' + assert missing_bits["components_values"][cipher_output_id]["value"] == "0xa86842f2" diff --git a/tests/unit/cipher_modules/models/smt/smt_models/smt_xor_differential_model_test.py b/tests/unit/cipher_modules/models/smt/smt_models/smt_xor_differential_model_test.py index 29586ebb8..bedc0f1de 100644 --- a/tests/unit/cipher_modules/models/smt/smt_models/smt_xor_differential_model_test.py +++ b/tests/unit/cipher_modules/models/smt/smt_models/smt_xor_differential_model_test.py @@ -1,5 +1,6 @@ -from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher from claasp.cipher_modules.models.smt.smt_models.smt_xor_differential_model import SmtXorDifferentialModel +from claasp.cipher_modules.models.smt.solvers import Z3_EXT +from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher def test_find_all_xor_differential_trails_with_weight_at_most(): @@ -13,23 +14,23 @@ def test_find_lowest_weight_xor_differential_trail(): speck = SpeckBlockCipher(number_of_rounds=5) smt = SmtXorDifferentialModel(speck) trail = smt.find_lowest_weight_xor_differential_trail() - assert trail['total_weight'] == 9.0 + assert trail["total_weight"] == 9.0 def test_find_one_xor_differential_trail(): speck = SpeckBlockCipher(number_of_rounds=5) smt = SmtXorDifferentialModel(speck) solution = smt.find_one_xor_differential_trail() - assert str(solution['cipher']) == 'speck_p32_k64_o32_r5' - assert solution['solver_name'] == 'Z3_EXT' - assert eval('0x' + solution['components_values']['intermediate_output_0_6']['value']) >= 0 - assert solution['components_values']['intermediate_output_0_6']['weight'] == 0 - assert eval('0x' + solution['components_values']['cipher_output_4_12']['value']) >= 0 - assert solution['components_values']['cipher_output_4_12']['weight'] == 0 + assert str(solution["cipher"]) == "speck_p32_k64_o32_r5" + assert solution["solver_name"] == Z3_EXT + assert int(solution["components_values"]["intermediate_output_0_6"]["value"], 16) >= 0 + assert solution["components_values"]["intermediate_output_0_6"]["weight"] == 0 + assert int(solution["components_values"]["cipher_output_4_12"]["value"], 16) >= 0 + assert solution["components_values"]["cipher_output_4_12"]["weight"] == 0 def test_find_one_xor_differential_trail_with_fixed_weight(): speck = SpeckBlockCipher(number_of_rounds=3) smt = SmtXorDifferentialModel(speck) result = smt.find_one_xor_differential_trail_with_fixed_weight(3) - assert result['total_weight'] == 3.0 + assert result["total_weight"] == 3.0 diff --git a/tests/unit/cipher_modules/models/smt/smt_models/smt_xor_linear_model_test.py b/tests/unit/cipher_modules/models/smt/smt_models/smt_xor_linear_model_test.py index 88d0ab790..533769892 100644 --- a/tests/unit/cipher_modules/models/smt/smt_models/smt_xor_linear_model_test.py +++ b/tests/unit/cipher_modules/models/smt/smt_models/smt_xor_linear_model_test.py @@ -1,50 +1,54 @@ -from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher -from claasp.cipher_modules.models.utils import set_fixed_variables, integer_to_bit_list from claasp.cipher_modules.models.smt.smt_models.smt_xor_linear_model import SmtXorLinearModel +from claasp.cipher_modules.models.smt.solvers import Z3_EXT +from claasp.cipher_modules.models.utils import set_fixed_variables, integer_to_bit_list +from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher +from claasp.name_mappings import INPUT_KEY, INPUT_PLAINTEXT def test_find_all_xor_linear_trails_with_weight_at_most(): speck = SpeckBlockCipher(block_bit_size=8, key_bit_size=16, number_of_rounds=4) smt = SmtXorLinearModel(speck) - key = set_fixed_variables('key', 'not_equal', list(range(16)), [0] * 16) - trails = smt.find_all_xor_linear_trails_with_weight_at_most(0, 3, fixed_values=[key]) + key = set_fixed_variables(INPUT_KEY, "not_equal", list(range(16)), (0,) * 16) + trails = smt.find_all_xor_linear_trails_with_weight_at_most(0, 2, fixed_values=[key]) - assert len(trails) == 73 + assert len(trails) == 8 def test_find_lowest_weight_xor_linear_trail(): speck = SpeckBlockCipher(number_of_rounds=3) smt = SmtXorLinearModel(speck) trail = smt.find_lowest_weight_xor_linear_trail() - assert trail['total_weight'] == 1.0 + assert trail["total_weight"] == 1.0 def test_find_one_xor_linear_trail(): speck = SpeckBlockCipher(number_of_rounds=4) smt = SmtXorLinearModel(speck) solution = smt.find_one_xor_linear_trail() - assert str(solution['cipher']) == 'speck_p32_k64_o32_r4' - assert solution['solver_name'] == 'Z3_EXT' - assert eval('0x' + solution['components_values']['modadd_0_1_i']['value']) >= 0 - assert solution['components_values']['modadd_0_1_i']['weight'] == 0 - assert solution['components_values']['modadd_0_1_i']['sign'] == 1 - assert eval('0x' + solution['components_values']['xor_0_4_o']['value']) >= 0 - assert solution['components_values']['xor_0_4_o']['weight'] == 0 - assert solution['components_values']['xor_0_4_o']['sign'] == 1 + assert str(solution["cipher"]) == "speck_p32_k64_o32_r4" + assert solution["solver_name"] == Z3_EXT + assert int(solution["components_values"]["modadd_0_1_i"]["value"], 16) >= 0 + assert solution["components_values"]["modadd_0_1_i"]["weight"] == 0 + assert solution["components_values"]["modadd_0_1_i"]["sign"] == 1 + assert int(solution["components_values"]["xor_0_4_o"]["value"], 16) >= 0 + assert solution["components_values"]["xor_0_4_o"]["weight"] == 0 + assert solution["components_values"]["xor_0_4_o"]["sign"] == 1 def test_find_one_xor_linear_trail_with_fixed_weight(): speck = SpeckBlockCipher(number_of_rounds=3) smt = SmtXorLinearModel(speck) result = smt.find_one_xor_linear_trail_with_fixed_weight(7) - assert result['total_weight'] == 7.0 + assert result["total_weight"] == 7.0 def test_fix_variables_value_xor_linear_constraints(): speck = SpeckBlockCipher(number_of_rounds=3) smt = SmtXorLinearModel(speck) - fixed_variables = [set_fixed_variables('plaintext', 'equal', range(4), integer_to_bit_list(5, 4, 'big'))] - assert smt.fix_variables_value_xor_linear_constraints(fixed_variables) == ['(assert (not plaintext_0_o))', - '(assert plaintext_1_o)', - '(assert (not plaintext_2_o))', - '(assert plaintext_3_o)'] + fixed_variables = [set_fixed_variables(INPUT_PLAINTEXT, "equal", range(4), integer_to_bit_list(5, 4, "big"))] + assert smt.fix_variables_value_xor_linear_constraints(fixed_variables) == [ + "(assert (not plaintext_0_o))", + "(assert plaintext_1_o)", + "(assert (not plaintext_2_o))", + "(assert plaintext_3_o)", + ] diff --git a/tests/unit/cipher_modules/report_test.py b/tests/unit/cipher_modules/report_test.py index 3473f87db..0484977ad 100644 --- a/tests/unit/cipher_modules/report_test.py +++ b/tests/unit/cipher_modules/report_test.py @@ -1,65 +1,217 @@ +import copy +import pickle +from pathlib import Path + +from plotly.basedatatypes import BaseFigure + from claasp.cipher_modules.models.sat.sat_models.sat_xor_differential_model import SatXorDifferentialModel from claasp.cipher_modules.models.smt.smt_models.smt_xor_differential_model import SmtXorDifferentialModel from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_differential_model import MznXorDifferentialModel -from claasp.cipher_modules.models.milp.milp_models.milp_xor_differential_model import MilpXorDifferentialModel from claasp.cipher_modules.models.utils import set_fixed_variables from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher from claasp.ciphers.block_ciphers.simon_block_cipher import SimonBlockCipher from claasp.cipher_modules.report import Report from claasp.ciphers.block_ciphers.present_block_cipher import PresentBlockCipher from claasp.cipher_modules.statistical_tests.dieharder_statistical_tests import DieharderTests -from claasp.cipher_modules.statistical_tests.nist_statistical_tests import NISTStatisticalTests from claasp.cipher_modules.neural_network_tests import NeuralNetworkTests from claasp.cipher_modules.algebraic_tests import AlgebraicTests from claasp.cipher_modules.avalanche_tests import AvalancheTests from claasp.cipher_modules.component_analysis_tests import CipherComponentsAnalysis from claasp.cipher_modules.continuous_diffusion_analysis import ContinuousDiffusionAnalysis -from sage.all import load -# from tests.precomputed_test_results import speck_three_rounds_component_analysis, speck_three_rounds_avalanche_tests, speck_three_rounds_neural_network_tests, speck_three_rounds_dieharder_tests, present_four_rounds_find_one_xor_differential_trail -def test_save_as_image(): - speck = SpeckBlockCipher(number_of_rounds=2) - sat = SatXorDifferentialModel(speck) - plaintext = set_fixed_variables( - component_id='plaintext', - constraint_type='not_equal', - bit_positions=range(32), - bit_values=(0,) * 32) - key = set_fixed_variables( - component_id='key', - constraint_type='equal', - bit_positions=range(64), - bit_values=(0,) * 64) - trail = sat.find_lowest_weight_xor_differential_trail(fixed_values=[plaintext, key]) +CACHE_DIR = Path(__file__).resolve().parent / 'data' +CACHE_FILE = CACHE_DIR / 'report_test_cache.pkl' +# Heavy report fixtures are persisted to disk so subsequent test runs can +# reuse them instantly. Delete the pickle to force regeneration. +_CACHE = None + + +def _load_cache(): + global _CACHE + if _CACHE is None: + if CACHE_FILE.exists(): + with CACHE_FILE.open('rb') as cache_file: + _CACHE = pickle.load(cache_file) + else: + _CACHE = {} + return _CACHE + + +def _persist_cache(): + CACHE_DIR.mkdir(parents=True, exist_ok=True) + with CACHE_FILE.open('wb') as cache_file: + pickle.dump(_CACHE, cache_file) + + +def _get_cached_result(key, factory): + cache = _load_cache() + if key in cache: + return copy.deepcopy(cache[key]) + result = factory() + cache[key] = result + _persist_cache() + return copy.deepcopy(result) + + +def _cached_speck_trail(): + def _generate(): + cipher = SpeckBlockCipher(number_of_rounds=2) + sat_model = SatXorDifferentialModel(cipher) + plaintext = set_fixed_variables( + component_id='plaintext', + constraint_type='not_equal', + bit_positions=range(32), + bit_values=(0,) * 32) + key = set_fixed_variables( + component_id='key', + constraint_type='equal', + bit_positions=range(64), + bit_values=(0,) * 64) + return sat_model.find_lowest_weight_xor_differential_trail(fixed_values=[plaintext, key]) + + return _get_cached_result('speck_r2_sat_trail', _generate) + + +def _cached_speck_avalanche_results(): + def _generate(): + cipher = SpeckBlockCipher(number_of_rounds=2) + return AvalancheTests(cipher).avalanche_tests() + + return _get_cached_result('speck_r2_avalanche_tests', _generate) + + +def _cached_speck_neural_network_results(): + def _generate(): + cipher = SpeckBlockCipher(number_of_rounds=2) + return NeuralNetworkTests(cipher).neural_network_blackbox_distinguisher_tests(nb_samples=10) + + return _get_cached_result('speck_r2_neural_network_blackbox', _generate) + + +def _cached_speck_algebraic_results(): + def _generate(): + cipher = SpeckBlockCipher(number_of_rounds=2) + return AlgebraicTests(cipher).algebraic_tests(timeout_in_seconds=1) + + return _get_cached_result('speck_r2_algebraic_tests', _generate) + + +def _cached_speck_component_analysis(): + def _generate(): + cipher = SpeckBlockCipher(number_of_rounds=2) + return CipherComponentsAnalysis(cipher).component_analysis_tests() + + return _get_cached_result('speck_r2_component_analysis', _generate) + + +def _cached_speck_cda_results(): + def _generate(): + cipher = SpeckBlockCipher(number_of_rounds=2) + cda_module = ContinuousDiffusionAnalysis(cipher) + return cda_module.continuous_diffusion_tests() + + return _get_cached_result('speck_r2_continuous_diffusion', _generate) + + +def _cached_present_trail(): + def _generate(): + cipher = PresentBlockCipher(number_of_rounds=2) + sat_model = SatXorDifferentialModel(cipher) + related_key_setting = [ + set_fixed_variables(component_id='key', constraint_type='not_equal', bit_positions=list(range(80)), + bit_values=[0] * 80), + set_fixed_variables(component_id='plaintext', constraint_type='equal', bit_positions=list(range(64)), + bit_values=[0] * 64) + ] + return sat_model.find_one_xor_differential_trail_with_fixed_weight( + fixed_weight=16, + fixed_values=related_key_setting, + solver_name='KISSAT_EXT') + + return _get_cached_result('present_r2_sat_fixed_weight_16', _generate) + + +def _cached_speck_dieharder_results(): + def _generate(): + cipher = SpeckBlockCipher(number_of_rounds=2) + dieharder = DieharderTests(cipher) + return dieharder.dieharder_statistical_tests('avalanche', dieharder_test_option=100) + + return _get_cached_result('speck_r2_dieharder_avalanche', _generate) + + +def test_save_as_image(monkeypatch, tmp_path): + captured_writes = [] + + def fake_write_image(self, file, *args, **kwargs): + captured_writes.append((file, args, kwargs)) + return None + + monkeypatch.setattr(BaseFigure, 'write_image', fake_write_image) + + output_dir = str(tmp_path / 'report-cache') + _run_save_as_image(output_dir) + + assert captured_writes, 'Expected Plotly write_image to be invoked at least once' + + +def _run_save_as_image(output_dir): + trail = _cached_speck_trail() trail_report = Report(trail) - trail_report.save_as_image() + trail_report.save_as_image(output_directory=output_dir) + trail_report.clean_reports(output_dir=output_dir) - avalanche_results = AvalancheTests(speck).avalanche_tests() + avalanche_results = _cached_speck_avalanche_results() avalanche_report = Report(avalanche_results) - avalanche_report.save_as_image(test_name='avalanche_weight_vectors', fixed_input='plaintext', fixed_output='round_output', + avalanche_report.save_as_image(output_directory=output_dir, test_name='avalanche_weight_vectors', fixed_input='plaintext', fixed_output='round_output', fixed_input_difference='average') + avalanche_report.clean_reports(output_dir=output_dir) - blackbox_results = NeuralNetworkTests(speck).neural_network_blackbox_distinguisher_tests() + blackbox_results = _cached_speck_neural_network_results() blackbox_report = Report(blackbox_results) - blackbox_report.save_as_image() + blackbox_report.save_as_image(output_directory=output_dir) + blackbox_report.clean_reports(output_dir=output_dir) - algebraic_results = AlgebraicTests(speck).algebraic_tests(timeout_in_seconds=1) + algebraic_results = _cached_speck_algebraic_results() algebraic_report = Report(algebraic_results) - algebraic_report.save_as_image() + algebraic_report.save_as_image(output_directory=output_dir) + algebraic_report.clean_reports(output_dir=output_dir) - component_analysis = CipherComponentsAnalysis(speck).component_analysis_tests() + component_analysis = _cached_speck_component_analysis() report_cca = Report(component_analysis) - report_cca.save_as_image() + report_cca.save_as_image(output_directory=output_dir) + report_cca.clean_reports(output_dir=output_dir) - speck = SpeckBlockCipher(number_of_rounds=2) - cda = ContinuousDiffusionAnalysis(speck) - cda_for_repo = cda.continuous_diffusion_tests() + cda_for_repo = _cached_speck_cda_results() cda_repo = Report(cda_for_repo) - cda_repo.save_as_image() + cda_repo.save_as_image(output_directory=output_dir) + cda_repo.clean_reports(output_dir=output_dir) + + +def _run_show(): + component_analysis = _cached_speck_component_analysis() + report_cca = Report(component_analysis) + report_cca.show() + avalanche_results = _cached_speck_avalanche_results() + avalanche_report = Report(avalanche_results) + avalanche_report.show(test_name=None) + avalanche_report.show(test_name='avalanche_weight_vectors', fixed_input_difference=None) + avalanche_report.show(test_name='avalanche_weight_vectors', fixed_input_difference='average') + present_trail = _cached_present_trail() + present_trail_report = Report(present_trail) + present_trail_report.show() + dieharder_results = _cached_speck_dieharder_results() + dieharder_report = Report(dieharder_results) + dieharder_report.show() + + neural_network_results = _cached_speck_neural_network_results() + neural_network_tests_report = Report(neural_network_results) + neural_network_tests_report.show(fixed_input_difference=None) + neural_network_tests_report.show(fixed_input_difference='0xa') def test_save_as_latex_table(): simon = SimonBlockCipher(number_of_rounds=2) @@ -81,13 +233,14 @@ def test_save_as_latex_table(): avalanche_test_results = AvalancheTests(simon).avalanche_tests() avalanche_report = Report(avalanche_test_results) avalanche_report.save_as_latex_table(fixed_input='plaintext',fixed_output='round_output',fixed_test='avalanche_weight_vectors') - + avalanche_report.clean_reports() trail_report = Report(trail) trail_report.save_as_latex_table() - + trail_report.clean_reports() dieharder=DieharderTests(simon) report_sts = Report(dieharder.dieharder_statistical_tests('avalanche', dieharder_test_option=100)) report_sts.save_as_latex_table() + report_sts.clean_reports() def test_save_as_DataFrame(): speck = SpeckBlockCipher(number_of_rounds=2) @@ -107,26 +260,30 @@ def test_save_as_DataFrame(): avalanche_results = AvalancheTests(speck).avalanche_tests() avalanche_report = Report(avalanche_results) avalanche_report.save_as_DataFrame(fixed_input='plaintext',fixed_output='round_output',fixed_test='avalanche_weight_vectors') - + avalanche_report.clean_reports() trail_report = Report(trail) trail_report.save_as_DataFrame() - + trail_report.clean_reports() dieharder = DieharderTests(speck) report_sts = Report(dieharder.dieharder_statistical_tests('avalanche', dieharder_test_option=100)) report_sts.save_as_DataFrame() - + report_sts.clean_reports() def test_save_as_json(): - simon = SimonBlockCipher(number_of_rounds=2) - neural_network_blackbox_distinguisher_tests_results = NeuralNetworkTests( - simon).neural_network_blackbox_distinguisher_tests() + speck = SpeckBlockCipher(number_of_rounds=2) + + neural_network_blackbox_distinguisher_tests_results = NeuralNetworkTests(speck).neural_network_blackbox_distinguisher_tests(nb_samples=10) blackbox_report = Report(neural_network_blackbox_distinguisher_tests_results) blackbox_report.save_as_json(fixed_input='plaintext',fixed_output='round_output') + blackbox_report.clean_reports() + + simon = SimonBlockCipher(number_of_rounds=2) + dieharder = DieharderTests(simon) report_sts = Report(dieharder.dieharder_statistical_tests('avalanche', dieharder_test_option=100)) report_sts.save_as_json() - + report_sts.clean_reports() present = PresentBlockCipher(number_of_rounds=2) sat = SatXorDifferentialModel(present) related_key_setting = [ @@ -143,25 +300,25 @@ def test_save_as_json(): avalanche_results = AvalancheTests(simon).avalanche_tests() avalanche_report = Report(avalanche_results) avalanche_report.save_as_json(fixed_input='plaintext',fixed_output='round_output',fixed_test='avalanche_weight_vectors') + avalanche_report.clean_reports() +def test_show(monkeypatch, tmp_path): + captured_fig_shows = [] + component_charts = [] -def test_show(): - precomputed_results = load('tests/precomputed_results.sobj') - component_analysis = precomputed_results['speck_three_rounds_component_analysis'] - report_cca = Report(component_analysis) - report_cca.show() - avalanche_results = precomputed_results['speck_three_rounds_avalanche_test'] - avalanche_report = Report(avalanche_results) - avalanche_report.show(test_name=None) - avalanche_report.show(test_name='avalanche_weight_vectors', fixed_input_difference=None) - avalanche_report.show(test_name='avalanche_weight_vectors', fixed_input_difference='average') - trail = precomputed_results['present_four_rounds_trail_search'] - trail_report = Report(trail) - trail_report.show() - dieharder_test_results = precomputed_results['speck_three_rounds_dieharder_test'] - report_sts = Report(dieharder_test_results) - report_sts.show() - neural_network_test_results = precomputed_results['speck_three_rounds_neural_network_test'] - neural_network_tests_report = Report(neural_network_test_results) - neural_network_tests_report.show(fixed_input_difference=None) - neural_network_tests_report.show(fixed_input_difference='0xa') \ No newline at end of file + def fake_show(self, *args, **kwargs): + captured_fig_shows.append((args, kwargs)) + return None + + def fake_component_radar(self, results): + component_charts.append(results) + return None + + monkeypatch.chdir(tmp_path) + monkeypatch.setattr(BaseFigure, 'show', fake_show) + monkeypatch.setattr(CipherComponentsAnalysis, 'print_component_analysis_as_radar_charts', fake_component_radar) + + _run_show() + + assert captured_fig_shows, 'Expected Plotly show to be invoked at least once' + assert component_charts, 'Expected component analysis radar chart to be produced' \ No newline at end of file diff --git a/tests/unit/ciphers/block_ciphers/kalyna_block_cipher_test.py b/tests/unit/ciphers/block_ciphers/kalyna_block_cipher_test.py new file mode 100644 index 000000000..d6dc2868c --- /dev/null +++ b/tests/unit/ciphers/block_ciphers/kalyna_block_cipher_test.py @@ -0,0 +1,15 @@ +"""Kalyna tests + +Reference: https://eprint.iacr.org/2015/650.pdf +""" + +from claasp.ciphers.block_ciphers.kalyna_block_cipher import KalynaBlockCipher + + +def test_kalyna_block_cipher(): + kalyna = KalynaBlockCipher() + key = 0x0F0E0D0C0B0A09080706050403020100 + plaintext = 0x1F1E1D1C1B1A19181716151413121110 + ciphertext = 0x06ADD2B439EAC9E120AC9B777D1CBF81 + assert kalyna.evaluate([key, plaintext]) == ciphertext + assert kalyna.evaluate_vectorized([key, plaintext], evaluate_api=True) == ciphertext diff --git a/tests/unit/ciphers/block_ciphers/led_block_cipher_test.py b/tests/unit/ciphers/block_ciphers/led_block_cipher_test.py new file mode 100644 index 000000000..1918dde8f --- /dev/null +++ b/tests/unit/ciphers/block_ciphers/led_block_cipher_test.py @@ -0,0 +1,46 @@ +"""LED tests + +Test vectors from https://eprint.iacr.org/2012/600.pdf +""" + +import pytest + +from claasp.ciphers.block_ciphers.led_block_cipher import LedBlockCipher +from claasp.name_mappings import BLOCK_CIPHER + + +@pytest.mark.filterwarnings("ignore::DeprecationWarning:") +def test_led_block_cipher(): + led = LedBlockCipher() + + assert led.type == BLOCK_CIPHER + assert led.family_name == "led" + assert led.number_of_rounds == 8 + assert led.id == "led_p64_k64_o64_r8" + assert led.component_from(0, 0).id == "xor_0_0" + + plaintext = 0x0000000000000000 + key = 0x0000000000000000 + ciphertext = 0x39C2401003A0C798 + assert led.evaluate([plaintext, key]) == ciphertext + assert led.evaluate_vectorized([plaintext, key], evaluate_api=True) == ciphertext + + plaintext = 0x0123456789ABCDEF + key = 0x0123456789ABCDEF + ciphertext = 0xA003551E3893FC58 + assert led.evaluate([plaintext, key]) == ciphertext + assert led.evaluate_vectorized([plaintext, key], evaluate_api=True) == ciphertext + + led = LedBlockCipher(key_bit_size=128, number_of_rounds=48) + + plaintext = 0x0000000000000000 + key = 0x00000000000000000000000000000000 + ciphertext = 0x3DECB2A0850CDBA1 + assert led.evaluate([plaintext, key]) == ciphertext + assert led.evaluate_vectorized([plaintext, key], evaluate_api=True) == ciphertext + + plaintext = 0x0123456789ABCDEF + key = 0x0123456789ABCDEF0123456789ABCDEF + ciphertext = 0xD6B824587F014FC2 + assert led.evaluate([plaintext, key]) == ciphertext + assert led.evaluate_vectorized([plaintext, key], evaluate_api=True) == ciphertext diff --git a/tests/unit/ciphers/block_ciphers/mantis_block_cipher_test.py b/tests/unit/ciphers/block_ciphers/mantis_block_cipher_test.py new file mode 100644 index 000000000..7e0c448ae --- /dev/null +++ b/tests/unit/ciphers/block_ciphers/mantis_block_cipher_test.py @@ -0,0 +1,36 @@ +from claasp.ciphers.block_ciphers.mantis_block_cipher import MantisBlockCipher + +""" +MANTIS block cipher validation +Test vectors reproduced from Beierle et al., "The SKINNY Family of Block Ciphers and its Low-Latency Variant MANTIS" +(IACR ePrint 2016/660): https://eprint.iacr.org/2016/660.pdf +""" + +def test_mantis_block_cipher(): + mantis5 = MantisBlockCipher(number_of_rounds=5) + plaintext = 0x3b5c77a4921f9718 + key = 0x92f09952c625e3e9d7a060f714c0292b + tweak = 0xba912e6f1055fed2 + ciphertext = 0xd6522035c1c0c6c1 + assert mantis5.evaluate([plaintext, key, tweak]) == ciphertext + + mantis6 = MantisBlockCipher(number_of_rounds=6) + plaintext = 0xd6522035c1c0c6c1 + key = 0x92f09952c625e3e9d7a060f714c0292b + tweak = 0xba912e6f1055fed2 + ciphertext = 0x60e43457311936fd + assert mantis6.evaluate([plaintext, key, tweak]) == ciphertext + + mantis7 = MantisBlockCipher(number_of_rounds=7) + plaintext = 0x60e43457311936fd + key = 0x92f09952c625e3e9d7a060f714c0292b + tweak = 0xba912e6f1055fed2 + ciphertext = 0x308e8a07f168f517 + assert mantis7.evaluate([plaintext, key, tweak]) == ciphertext + + mantis8 = MantisBlockCipher(number_of_rounds=8) + plaintext = 0x308e8a07f168f517 + key = 0x92f09952c625e3e9d7a060f714c0292b + tweak = 0xba912e6f1055fed2 + ciphertext = 0x971ea01a86b410bb + assert mantis8.evaluate([plaintext, key, tweak]) == ciphertext diff --git a/tests/unit/ciphers/block_ciphers/qarmav2_block_cipher_test.py b/tests/unit/ciphers/block_ciphers/qarmav2_block_cipher_test.py index 90b496550..32c0b8ec3 100644 --- a/tests/unit/ciphers/block_ciphers/qarmav2_block_cipher_test.py +++ b/tests/unit/ciphers/block_ciphers/qarmav2_block_cipher_test.py @@ -9,7 +9,7 @@ def test_qarmav2_block_cipher(): assert qarmav2.type == 'block_cipher' assert qarmav2.family_name == 'qarmav2_block_cipher' assert qarmav2.number_of_rounds == 9 - assert qarmav2.id == 'qarmav2_block_cipher_k128_p64_i128_o64_r9' + assert qarmav2.id == 'qarmav2_block_cipher_k128_p64_t128_o64_r9' assert qarmav2.component_from(0, 0).id == 'constant_0_0' qarmav2 = QARMAv2BlockCipher(number_of_rounds = 4) diff --git a/tests/unit/ciphers/block_ciphers/qarmav2_with_mixcolumn_block_cipher_test.py b/tests/unit/ciphers/block_ciphers/qarmav2_with_mixcolumn_block_cipher_test.py index e6a679c2c..2c7d4b842 100644 --- a/tests/unit/ciphers/block_ciphers/qarmav2_with_mixcolumn_block_cipher_test.py +++ b/tests/unit/ciphers/block_ciphers/qarmav2_with_mixcolumn_block_cipher_test.py @@ -9,7 +9,7 @@ def test_qarmav2_mixcolumn_block_cipher(): assert qarmav2.type == 'block_cipher' assert qarmav2.family_name == 'qarmav2_block_cipher' assert qarmav2.number_of_rounds == 9 - assert qarmav2.id == 'qarmav2_block_cipher_k128_p64_i128_o64_r9' + assert qarmav2.id == 'qarmav2_block_cipher_k128_p64_t128_o64_r9' assert qarmav2.component_from(0, 0).id == 'linear_layer_0_0' qarmav2 = QARMAv2MixColumnBlockCipher(number_of_rounds = 4) diff --git a/tests/unit/ciphers/block_ciphers/scarf_block_cipher_test.py b/tests/unit/ciphers/block_ciphers/scarf_block_cipher_test.py index 6e0d5f881..15cc05baf 100644 --- a/tests/unit/ciphers/block_ciphers/scarf_block_cipher_test.py +++ b/tests/unit/ciphers/block_ciphers/scarf_block_cipher_test.py @@ -6,7 +6,7 @@ def test_scarf_block_cipher(): assert cipher.type == 'block_cipher' assert cipher.family_name == 'scarf' assert cipher.number_of_rounds == 8 - assert cipher.id == 'scarf_p10_k240_i48_o10_r8' + assert cipher.id == 'scarf_p10_k240_t48_o10_r8' assert cipher.component_from(0, 0).id == 'constant_0_0' plaintext = 0x0 diff --git a/tests/unit/ciphers/block_ciphers/skipjack_block_cipher_test.py b/tests/unit/ciphers/block_ciphers/skipjack_block_cipher_test.py new file mode 100644 index 000000000..05f5d1557 --- /dev/null +++ b/tests/unit/ciphers/block_ciphers/skipjack_block_cipher_test.py @@ -0,0 +1,107 @@ +from claasp.ciphers.block_ciphers.skipjack_block_cipher import SkipjackBlockCipher + + +def test_skipjack_block_cipher(): + """ + Test vectors from [NIST1998] + https://csrc.nist.gov/csrc/media/projects/cryptographic-algorithm-validation-program/documents/skipjack/skipjack.pdf + + This test verifies the SKIPJACK implementation using the official NSA test vector + for 32 rounds, as well as intermediate round outputs for 8 and 16 rounds. + + The test checks both the Python evaluation and the vectorized evaluation + to ensure consistency with the expected outputs. + """ + # Test cipher properties + skipjack = SkipjackBlockCipher() + assert skipjack.type == 'block_cipher' + assert skipjack.family_name == 'skipjack' + assert skipjack.number_of_rounds == 32 + assert skipjack.id == 'skipjack_p64_k80_o64_r32' + assert skipjack.output_bit_size == 64 + + # Test component structure + components = skipjack.get_all_components() + assert len(components) == 576 + first_comp = skipjack.component_from(0, 0) + assert first_comp is not None + assert hasattr(first_comp, 'id') + + # Test official vectors Key and Plaintext + key = 0x00998877665544332211 + plaintext = 0x33221100ddccbbaa + + # Test intermediate output at 8 rounds to verify Rule A phase correctness + skipjack_8 = SkipjackBlockCipher(number_of_rounds=8) + ciphertext_8 = 0xd79b5599be50dd90 + assert skipjack_8.evaluate([plaintext, key]) == ciphertext_8 + assert skipjack_8.evaluate_vectorized([plaintext, key], evaluate_api=True) == ciphertext_8 + + # Test intermediate output at 16 rounds to verify Rule A and Rule B phases together + skipjack_16 = SkipjackBlockCipher(number_of_rounds=16) + ciphertext_16 = 0xd7f8899053979883 + assert skipjack_16.number_of_rounds == 16 + assert skipjack_16.id == 'skipjack_p64_k80_o64_r16' + assert skipjack_16.evaluate([plaintext, key]) == ciphertext_16 + assert skipjack_16.evaluate_vectorized([plaintext, key], evaluate_api=True) == ciphertext_16 + + # Test final output after full 32 rounds to verify second Rule B phase completion + skipjack_32 = SkipjackBlockCipher() + ciphertext_32 = 0x2587cae27a12d300 + assert skipjack_32.evaluate([plaintext, key]) == ciphertext_32 + assert skipjack_32.evaluate_vectorized([plaintext, key], evaluate_api=True) == ciphertext_32 + + +def test_skipjack_all_rounds(): + """ + Test each individual round output from 1 to 32 using intermediate values from [NIST1998] + + This test verifies the correctness of each round output to ensure + proper implementation of both Rule A and Rule B operations. + """ + key = 0x00998877665544332211 + plaintext = 0x33221100ddccbbaa + + # Expected output after each round from [NIST1998] + expected_outputs = [ + 0xb0040baf1100ddcc, # round 1 + 0xe6883b460baf1100, # round 2 + 0x3c762d753b460baf, # round 3 + 0x4c4547ee2d753b46, # round 4 + 0xb949820a47ee2d75, # round 5 + 0xf0e3dd90820a47ee, # round 6 + 0xf9b9be50dd90820a, # round 7 + 0xd79b5599be50dd90, # round 8 + 0xdd901e0b820bbe50, # round 9 + 0xbe504c52c391820b, # round 10 + 0x820b7f51f209c391, # round 11 + 0xc391f9c2fd56f209, # round 12 + 0xf20925ff3a5efd56, # round 13 + 0xfd5665dad7f83a5e, # round 14 + 0x3a5e69d99883d7f8, # round 15 + 0xd7f8899053979883, # round 16 + 0x9c00049289905397, # round 17 + 0x9fdccc5904928990, # round 18 + 0x3731beb2cc590492, # round 19 + 0x7afb7e7dbeb2cc59, # round 20 + 0x7759bb157e7dbeb2, # round 21 + 0xfb6445c0bb157e7d, # round 22 + 0x6f7f111545c0bb15, # round 23 + 0x65a7deaa111545c0, # round 24 + 0x45c0e0f9bb141115, # round 25 + 0x11153913a523bb14, # round 26 + 0xbb148ee6281da523, # round 27 + 0xa523bfe235ee281d, # round 28 + 0x281d0d841adc35ee, # round 29 + 0x35eee6f125871adc, # round 30 + 0x1adc60eed3002587, # round 31 + 0x2587cae27a12d300, # round 32 + ] + + # Test each round output + for rounds in range(1, 33): + skipjack = SkipjackBlockCipher(number_of_rounds=rounds) + result = skipjack.evaluate([plaintext, key]) + assert result == expected_outputs[rounds - 1], \ + f"Round {rounds} failed: expected {hex(expected_outputs[rounds - 1])}, got {hex(result)}" + \ No newline at end of file diff --git a/tests/unit/ciphers/block_ciphers/sm4_block_cipher_test.py b/tests/unit/ciphers/block_ciphers/sm4_block_cipher_test.py new file mode 100644 index 000000000..e324b2afc --- /dev/null +++ b/tests/unit/ciphers/block_ciphers/sm4_block_cipher_test.py @@ -0,0 +1,24 @@ +from claasp.ciphers.block_ciphers.sm4_block_cipher import SM4 + +""" +The technical specifications along with the test vectors can be found here: http://www.gmbz.org.cn/upload/2025-01-23/1737625646289030731.pdf +and https://datatracker.ietf.org/doc/html/draft-ribose-cfrg-sm4-10. +""" + + +def test_sm4_block_cipher(): + sm4 = SM4() + key = 0x0123456789ABCDEFFEDCBA9876543210 + plaintext = 0x0123456789ABCDEFFEDCBA9876543210 + ciphertext = 0x681EDF34D206965E86B3E94F536E4246 + assert sm4.evaluate([key, plaintext]) == ciphertext + assert sm4.evaluate_vectorized([key, plaintext], evaluate_api=True) == ciphertext + + +def test_sm4_block_cipher(): + sm4 = SM4() + key = 0xFEDCBA98765432100123456789ABCDEF + plaintext = 0x000102030405060708090A0B0C0D0E0F + ciphertext = 0xF766678F13F01ADEAC1B3EA955ADB594 + assert sm4.evaluate([key, plaintext]) == ciphertext + assert sm4.evaluate_vectorized([key, plaintext], evaluate_api=True) == ciphertext diff --git a/tests/unit/ciphers/hash_functions/blake2_hash_function_test.py b/tests/unit/ciphers/hash_functions/blake2_hash_function_test.py index 26e4d14d2..99b975c90 100644 --- a/tests/unit/ciphers/hash_functions/blake2_hash_function_test.py +++ b/tests/unit/ciphers/hash_functions/blake2_hash_function_test.py @@ -6,12 +6,12 @@ def test_blake2_hash_function(): assert blake2.number_of_rounds == 12 assert blake2.type == 'hash_function' assert blake2.family_name == 'blake2' - assert blake2.id == 'blake2_i1024_i1024_o1024_r12' + assert blake2.id == 'blake2_m1024_s1024_o1024_r12' assert blake2.component_from(0, 0).id == 'modadd_0_0' blake2 = Blake2HashFunction(number_of_rounds=4) assert blake2.number_of_rounds == 4 - assert blake2.id == 'blake2_i1024_i1024_o1024_r4' + assert blake2.id == 'blake2_m1024_s1024_o1024_r4' assert blake2.component_from(3, 0).id == 'modadd_3_0' blake2 = Blake2HashFunction() diff --git a/tests/unit/ciphers/hash_functions/blake_hash_function_test.py b/tests/unit/ciphers/hash_functions/blake_hash_function_test.py index e410c75f9..f13b7ff2a 100644 --- a/tests/unit/ciphers/hash_functions/blake_hash_function_test.py +++ b/tests/unit/ciphers/hash_functions/blake_hash_function_test.py @@ -6,13 +6,13 @@ def test_blake_hash_function(): assert blake.number_of_rounds == 28 assert blake.type == 'hash_function' assert blake.family_name == 'blake' - assert blake.id == 'blake_i512_i512_o512_r28' + assert blake.id == 'blake_m512_s512_o512_r28' assert blake.component_from(0, 0).id == 'constant_0_0' blake = BlakeHashFunction(number_of_rounds=4) assert blake.number_of_rounds == 4 assert blake.type == 'hash_function' - assert blake.id == 'blake_i512_i512_o512_r4' + assert blake.id == 'blake_m512_s512_o512_r4' assert blake.component_from(3, 0).id == 'constant_3_0' blake = BlakeHashFunction() diff --git a/tests/unit/ciphers/hash_functions/md5_hash_function_test.py b/tests/unit/ciphers/hash_functions/md5_hash_function_test.py index d2e073f51..07b0d48ad 100644 --- a/tests/unit/ciphers/hash_functions/md5_hash_function_test.py +++ b/tests/unit/ciphers/hash_functions/md5_hash_function_test.py @@ -6,12 +6,12 @@ def test_md5_hash_function(): assert md5.number_of_rounds == 64 assert md5.type == 'hash_function' assert md5.family_name == 'MD5' - assert md5.id == 'MD5_i512_o64_r64' + assert md5.id == 'MD5_m512_o64_r64' assert md5.component_from(0, 0).id == 'constant_0_0' md5 = MD5HashFunction(number_of_rounds=4) assert md5.number_of_rounds == 4 - assert md5.id == 'MD5_i512_o64_r4' + assert md5.id == 'MD5_m512_o64_r4' assert md5.component_from(3, 0).id == 'constant_3_0' md5 = MD5HashFunction() diff --git a/tests/unit/ciphers/hash_functions/sha1_hash_function_test.py b/tests/unit/ciphers/hash_functions/sha1_hash_function_test.py index adf741b11..5c8f83934 100644 --- a/tests/unit/ciphers/hash_functions/sha1_hash_function_test.py +++ b/tests/unit/ciphers/hash_functions/sha1_hash_function_test.py @@ -6,12 +6,12 @@ def test_sha1_hash_function(): assert sha1.number_of_rounds == 80 assert sha1.type == 'hash_function' assert sha1.family_name == 'SHA1' - assert sha1.id == 'SHA1_i512_o160_r80' + assert sha1.id == 'SHA1_m512_o160_r80' assert sha1.component_from(0, 0).id == 'constant_0_0' sha1 = SHA1HashFunction(number_of_rounds=4) assert sha1.number_of_rounds == 4 - assert sha1.id == 'SHA1_i512_o160_r4' + assert sha1.id == 'SHA1_m512_o160_r4' assert sha1.component_from(3, 0).id == 'and_3_0' sha1 = SHA1HashFunction() diff --git a/tests/unit/ciphers/hash_functions/sha2_hash_function_test.py b/tests/unit/ciphers/hash_functions/sha2_hash_function_test.py index 517d9fc3c..d84a3cb48 100644 --- a/tests/unit/ciphers/hash_functions/sha2_hash_function_test.py +++ b/tests/unit/ciphers/hash_functions/sha2_hash_function_test.py @@ -6,12 +6,12 @@ def test_sha2_hash_function(): assert sha2.number_of_rounds == 65 assert sha2.type == 'hash_function' assert sha2.family_name == 'SHA2_family' - assert sha2.id == 'SHA2_family_i512_o256_r65' + assert sha2.id == 'SHA2_family_m512_o256_r65' assert sha2.component_from(0, 0).id == 'constant_0_0' sha2 = SHA2HashFunction(number_of_rounds=4) assert sha2.number_of_rounds == 4 - assert sha2.id == 'SHA2_family_i512_o256_r4' + assert sha2.id == 'SHA2_family_m512_o256_r4' assert sha2.component_from(3, 0).id == 'constant_3_0' sha2 = SHA2HashFunction() diff --git a/tests/unit/ciphers/hash_functions/whirlpool_hash_function_test.py b/tests/unit/ciphers/hash_functions/whirlpool_hash_function_test.py index 54f2f548b..7332229b1 100644 --- a/tests/unit/ciphers/hash_functions/whirlpool_hash_function_test.py +++ b/tests/unit/ciphers/hash_functions/whirlpool_hash_function_test.py @@ -4,12 +4,12 @@ def test_whirlpool_hash_function(): whirlpool = WhirlpoolHashFunction() assert whirlpool.type == 'hash_function' assert whirlpool.family_name == 'whirlpool_hash_function' - assert whirlpool.id == 'whirlpool_hash_function_i512_o512_r10' + assert whirlpool.id == 'whirlpool_hash_function_m512_o512_r10' assert whirlpool.component_from(0,0).id == 'constant_0_0' whirlpool = WhirlpoolHashFunction(number_of_rounds=4) assert whirlpool.number_of_rounds == 4 - assert whirlpool.id == 'whirlpool_hash_function_i512_o512_r4' + assert whirlpool.id == 'whirlpool_hash_function_m512_o512_r4' assert whirlpool.component_from(3,0).id == 'sbox_3_0' # The following test vector values have been obtained from the reference implementation of Whirlpool diff --git a/tests/unit/ciphers/permutations/grain_core_permutation_test.py b/tests/unit/ciphers/permutations/grain_core_permutation_test.py index 9e43e3163..dc9e50794 100644 --- a/tests/unit/ciphers/permutations/grain_core_permutation_test.py +++ b/tests/unit/ciphers/permutations/grain_core_permutation_test.py @@ -6,12 +6,12 @@ def test_grain_core_permutation(): assert grain_core.family_name == 'grain_core' assert grain_core.type == 'permutation' assert grain_core.number_of_rounds == 160 - assert grain_core.id == 'grain_core_i80_o80_r160' + assert grain_core.id == 'grain_core_s80_o80_r160' assert grain_core.component_from(0, 0).id == 'xor_0_0' grain_core = GrainCorePermutation(number_of_rounds=4) assert grain_core.number_of_rounds == 4 - assert grain_core.id == 'grain_core_i80_o80_r4' + assert grain_core.id == 'grain_core_s80_o80_r4' assert grain_core.component_from(3, 0).id == 'xor_3_0' grain_core = GrainCorePermutation() diff --git a/tests/unit/ciphers/permutations/spongent_pi_fsr_permutation_test.py b/tests/unit/ciphers/permutations/spongent_pi_fsr_permutation_test.py index 84d00dcd9..bc4a4f360 100644 --- a/tests/unit/ciphers/permutations/spongent_pi_fsr_permutation_test.py +++ b/tests/unit/ciphers/permutations/spongent_pi_fsr_permutation_test.py @@ -23,4 +23,4 @@ def test_spongent_pi_fsr_permutation(): spongentpi = SpongentPiFSRPermutation(state_bit_size=176, number_of_rounds=4) plaintext = 0x0123456789abcdef0123456789abcdef0123456789ab ciphertext = 0x8675478f97cafe723bf668c5e573ae9b582131499660 - assert spongentpi.evaluate([plaintext]) == ciphertext + assert spongentpi.evaluate([plaintext]) == ciphertext \ No newline at end of file diff --git a/tests/unit/ciphers/permutations/spongent_pi_precomputation_permutation_test.py b/tests/unit/ciphers/permutations/spongent_pi_precomputation_permutation_test.py index a3bd0ee34..ea4dd74c4 100644 --- a/tests/unit/ciphers/permutations/spongent_pi_precomputation_permutation_test.py +++ b/tests/unit/ciphers/permutations/spongent_pi_precomputation_permutation_test.py @@ -24,4 +24,4 @@ def test_spongent_pi_precomputation_permutation(): plaintext = 0x0123456789abcdef0123456789abcdef0123456789ab ciphertext = 0x8675478f97cafe723bf668c5e573ae9b582131499660 assert spongentpi.evaluate([plaintext]) == ciphertext - assert spongentpi.evaluate_vectorized([plaintext], evaluate_api=True) == ciphertext + assert spongentpi.evaluate_vectorized([plaintext], evaluate_api=True) == ciphertext \ No newline at end of file diff --git a/tests/unit/ciphers/stream_ciphers/a5_1_stream_cipher_test.py b/tests/unit/ciphers/stream_ciphers/a5_1_stream_cipher_test.py index 86ea1495e..be2bb350c 100644 --- a/tests/unit/ciphers/stream_ciphers/a5_1_stream_cipher_test.py +++ b/tests/unit/ciphers/stream_ciphers/a5_1_stream_cipher_test.py @@ -5,12 +5,11 @@ def test_a51(): assert a51.family_name == 'a51' assert a51.type == 'stream_cipher' assert a51.number_of_rounds == 229 - assert a51.id == 'a51_k64_i22_o228_r229' + assert a51.id == 'a51_k64_f22_o228_r229' assert a51.component_from(0, 0).id == 'constant_0_0' assert a51.component_from(1, 0).id == 'fsr_1_0' key = 0x48c4a2e691d5b3f7 frame = 0b0010110010000000000000 keystream = 0x534eaa582fe8151ab6e1855a728c093f4d68d757ed949b4cbe41b7c6b - assert a51.evaluate([key, frame]) == keystream - + assert a51.evaluate([key, frame]) == keystream \ No newline at end of file diff --git a/tests/unit/ciphers/stream_ciphers/a5_2_stream_cipher_test.py b/tests/unit/ciphers/stream_ciphers/a5_2_stream_cipher_test.py index 40f09b8b1..9f2790b0a 100644 --- a/tests/unit/ciphers/stream_ciphers/a5_2_stream_cipher_test.py +++ b/tests/unit/ciphers/stream_ciphers/a5_2_stream_cipher_test.py @@ -5,12 +5,11 @@ def test_a52(): assert a52.family_name == 'a52' assert a52.type == 'stream_cipher' assert a52.number_of_rounds == 229 - assert a52.id == 'a52_k64_i22_o228_r229' + assert a52.id == 'a52_k64_f22_o228_r229' assert a52.component_from(0, 0).id == 'constant_0_0' assert a52.component_from(1, 0).id == 'fsr_1_0' key = 0x003fffffffffffff frame = 0b1000010000000000000000 keystream = 0xf4512cac13593764460b722dadd51200350ca385a853735ee5c889944 - assert a52.evaluate([key, frame]) == keystream - + assert a52.evaluate([key, frame]) == keystream \ No newline at end of file diff --git a/tests/unit/ciphers/stream_ciphers/bivium_stream_cipher_test.py b/tests/unit/ciphers/stream_ciphers/bivium_stream_cipher_test.py index 5d37f1687..997d135d1 100644 --- a/tests/unit/ciphers/stream_ciphers/bivium_stream_cipher_test.py +++ b/tests/unit/ciphers/stream_ciphers/bivium_stream_cipher_test.py @@ -12,4 +12,4 @@ def test_bivium_stream_cipher_test_vector(): key = 0xffffffffff0000000000 iv = 0xffffffffff ks = 0xdebe55784f853606399af3f6f4b8d0a706963a91f2ba4c687baea16da074f3c3 - assert biv.evaluate([key, iv]) == ks + assert biv.evaluate([key, iv]) == ks \ No newline at end of file diff --git a/tests/unit/ciphers/stream_ciphers/bluetooth_stream_cipher_e0_test.py b/tests/unit/ciphers/stream_ciphers/bluetooth_stream_cipher_e0_test.py index be07cae95..92307445d 100644 --- a/tests/unit/ciphers/stream_ciphers/bluetooth_stream_cipher_e0_test.py +++ b/tests/unit/ciphers/stream_ciphers/bluetooth_stream_cipher_e0_test.py @@ -13,4 +13,4 @@ def test_bluetooth_stream_cipher_e0_test_vector(): key = 0xe22f92fff8c245c49d10359a02f1e555 input = int(hex(key << 4 | fsm), 16) # key.append(fsm) keystream = 0x1198636720bac54986d1ab5a494866c9 - assert e0.evaluate([input]) == keystream + assert e0.evaluate([input]) == keystream \ No newline at end of file diff --git a/tests/unit/ciphers/stream_ciphers/chacha_stream_cipher_test.py b/tests/unit/ciphers/stream_ciphers/chacha_stream_cipher_test.py index afd909737..cf2afb1e7 100644 --- a/tests/unit/ciphers/stream_ciphers/chacha_stream_cipher_test.py +++ b/tests/unit/ciphers/stream_ciphers/chacha_stream_cipher_test.py @@ -15,7 +15,6 @@ def test_chacha_stream_cipher(): assert chacha.component_from(3, 0).id == 'modadd_3_0' cipher = ChachaStreamCipher(number_of_rounds=40) - cipher.sort_cipher() plaintext = 0x61707865_3320646e_79622d32_6b206574_03020100_07060504_0b0a0908_0f0e0d0c_13121110_17161514_1b1a1918_1f1e1d1c_00000001_09000000_4a000000_00000000 key = 0x00010203_04050607_08090a0b_0c0d0e0f_10111213_14151617_18191a1b_1c1d1e1f nonce = 0x00000000_00000009_0000004a_00000000 diff --git a/tests/unit/ciphers/stream_ciphers/trivium_stream_cipher_test.py b/tests/unit/ciphers/stream_ciphers/trivium_stream_cipher_test.py index b61e7105c..9d0c4fb6d 100644 --- a/tests/unit/ciphers/stream_ciphers/trivium_stream_cipher_test.py +++ b/tests/unit/ciphers/stream_ciphers/trivium_stream_cipher_test.py @@ -6,4 +6,4 @@ def test_trivium_stream_cipher_test_vector(): key = 0x00000000000000000000 iv = 0x00000000000000000000 ks = 0xdf07fd641a9aa0d88a5e7472c4f993fe6a4cc06898e0f3b4e7159ef0854d97b3 - assert triv.evaluate([key, iv]) == ks + assert triv.evaluate([key, iv]) == ks \ No newline at end of file diff --git a/tests/unit/components/linear_layer_component_test.py b/tests/unit/components/linear_layer_component_test.py index b782c3fb7..6122ea896 100644 --- a/tests/unit/components/linear_layer_component_test.py +++ b/tests/unit/components/linear_layer_component_test.py @@ -84,14 +84,14 @@ def test_cp_deterministic_truncated_xor_differential_constraints(): assert constraints[0] == 'constraint if ((sbox_0_0[2] < 2) /\\ (sbox_0_0[3] < 2) /\\ (sbox_0_1[0] < 2) /\\ ' \ '(sbox_0_1[1] < 2) /\\ (sbox_0_1[3] < 2) /\\ (sbox_0_2[0] < 2) /\\ (sbox_0_2[1] < 2)' \ ' /\\ (sbox_0_3[1] < 2) /\\ (sbox_0_4[2] < 2) /\\ (sbox_0_5[1] < 2) /\\ ' \ - '(sbox_0_5[3]< 2)) then linear_layer_0_6[0] = (sbox_0_0[2] + sbox_0_0[3] + sbox_0_1[0] ' \ + '(sbox_0_5[3] < 2)) then linear_layer_0_6[0] = (sbox_0_0[2] + sbox_0_0[3] + sbox_0_1[0] ' \ '+ sbox_0_1[1] + sbox_0_1[3] + sbox_0_2[0] + sbox_0_2[1] + sbox_0_3[1] + sbox_0_4[2] + ' \ 'sbox_0_5[1] + sbox_0_5[3]) mod 2 else linear_layer_0_6[0] = 2 endif;' assert constraints[-1] == 'constraint if ((sbox_0_0[0] < 2) /\\ (sbox_0_0[1] < 2) /\\ (sbox_0_0[2] < 2) /\\ ' \ '(sbox_0_0[3] < 2) /\\ (sbox_0_1[3] < 2) /\\ (sbox_0_2[1] < 2) /\\ (sbox_0_3[1] < 2) ' \ '/\\ (sbox_0_3[2] < 2) /\\ (sbox_0_3[3] < 2) /\\ (sbox_0_4[1] < 2) /\\ ' \ '(sbox_0_4[2] < 2) /\\ (sbox_0_4[3] < 2) /\\ (sbox_0_5[1] < 2) /\\ (sbox_0_5[2] < 2) ' \ - '/\\ (sbox_0_5[3]< 2)) then linear_layer_0_6[23] = (sbox_0_0[0] + sbox_0_0[1] + ' \ + '/\\ (sbox_0_5[3] < 2)) then linear_layer_0_6[23] = (sbox_0_0[0] + sbox_0_0[1] + ' \ 'sbox_0_0[2] + sbox_0_0[3] + sbox_0_1[3] + sbox_0_2[1] + sbox_0_3[1] + sbox_0_3[2] + ' \ 'sbox_0_3[3] + sbox_0_4[1] + sbox_0_4[2] + sbox_0_4[3] + sbox_0_5[1] + sbox_0_5[2] + ' \ 'sbox_0_5[3]) mod 2 else linear_layer_0_6[23] = 2 endif;' @@ -106,7 +106,7 @@ def test_cp_deterministic_truncated_xor_differential_constraints(): '(sbox_0_1[0] < 2) /\\ (sbox_0_1[1] < 2) /\\ (sbox_0_1[3] < 2)' \ ' /\\ (sbox_0_2[0] < 2) /\\ (sbox_0_2[1] < 2) /\\ ' \ '(sbox_0_3[1] < 2) /\\ (sbox_0_4[2] < 2) /\\ (sbox_0_5[1] < 2) ' \ - '/\\ (sbox_0_5[3]< 2)) then linear_layer_0_6[0] = (sbox_0_0[2]' \ + '/\\ (sbox_0_5[3] < 2)) then linear_layer_0_6[0] = (sbox_0_0[2]' \ ' + sbox_0_0[3] + sbox_0_1[0] + sbox_0_1[1] + ' \ 'sbox_0_1[3] + sbox_0_2[0] + sbox_0_2[1] + sbox_0_3[1]' \ ' + sbox_0_4[2] + sbox_0_5[1] + sbox_0_5[3]) mod 2 else ' \ @@ -117,7 +117,7 @@ def test_cp_deterministic_truncated_xor_differential_constraints(): ' /\\ (sbox_0_2[1] < 2) /\\ (sbox_0_3[1] < 2) /\\ ' \ '(sbox_0_3[2] < 2) /\\ (sbox_0_3[3] < 2) /\\ (sbox_0_4[1] < 2)' \ ' /\\ (sbox_0_4[2] < 2) /\\ (sbox_0_4[3] < 2) /\\' \ - ' (sbox_0_5[1] < 2) /\\ (sbox_0_5[2] < 2) /\\ (sbox_0_5[3]< 2))' \ + ' (sbox_0_5[1] < 2) /\\ (sbox_0_5[2] < 2) /\\ (sbox_0_5[3] < 2))' \ ' then linear_layer_0_6[23] = (sbox_0_0[0] + sbox_0_0[1] +' \ ' sbox_0_0[2] + sbox_0_0[3] + sbox_0_1[3] + sbox_0_2[1]' \ ' + sbox_0_3[1] + sbox_0_3[2] + sbox_0_3[3] + ' \ diff --git a/tests/unit/components/mix_column_component_test.py b/tests/unit/components/mix_column_component_test.py index 9761a03f9..19d4b546a 100644 --- a/tests/unit/components/mix_column_component_test.py +++ b/tests/unit/components/mix_column_component_test.py @@ -71,11 +71,11 @@ def test_cp_deterministic_truncated_xor_differential_constraints(): assert declarations == [] assert constraints[0] == 'constraint if ((rot_0_17[1] < 2) /\\ (rot_0_18[0] < 2) /\\ (rot_0_18[1] < 2) /\\ ' \ - '(rot_0_19[0] < 2) /\\ (rot_0_20[0]< 2)) then mix_column_0_21[0] = (rot_0_17[1] + ' \ + '(rot_0_19[0] < 2) /\\ (rot_0_20[0] < 2)) then mix_column_0_21[0] = (rot_0_17[1] + ' \ 'rot_0_18[0] + rot_0_18[1] + rot_0_19[0] + rot_0_20[0]) mod 2 else ' \ 'mix_column_0_21[0] = 2 endif;' assert constraints[-1] == 'constraint if ((rot_0_17[0] < 2) /\\ (rot_0_17[7] < 2) /\\ (rot_0_18[7] < 2) /\\ ' \ - '(rot_0_19[7] < 2) /\\ (rot_0_20[0]< 2)) then mix_column_0_21[31] = (rot_0_17[0] + ' \ + '(rot_0_19[7] < 2) /\\ (rot_0_20[0] < 2)) then mix_column_0_21[31] = (rot_0_17[0] + ' \ 'rot_0_17[7] + rot_0_18[7] + rot_0_19[7] + rot_0_20[0]) mod 2 else ' \ 'mix_column_0_21[31] = 2 endif;' diff --git a/tests/compound_xor_differential_cipher_test.py b/tests/unit/compound_xor_differential_cipher_test.py similarity index 53% rename from tests/compound_xor_differential_cipher_test.py rename to tests/unit/compound_xor_differential_cipher_test.py index f25edca07..745065f6d 100644 --- a/tests/compound_xor_differential_cipher_test.py +++ b/tests/unit/compound_xor_differential_cipher_test.py @@ -1,96 +1,94 @@ -from claasp.cipher_modules.models.sat.sat_models.sat_xor_differential_model import SatXorDifferentialModel -from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher from claasp.cipher_modules.models.sat.sat_models.sat_cipher_model import SatCipherModel +from claasp.cipher_modules.models.sat.sat_models.sat_xor_differential_model import SatXorDifferentialModel +from claasp.cipher_modules.models.sat.solvers import CRYPTOMINISAT_EXT, KISSAT_EXT from claasp.cipher_modules.models.utils import set_fixed_variables, integer_to_bit_list -from claasp.name_mappings import CIPHER, XOR_DIFFERENTIAL +from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher +from claasp.name_mappings import CIPHER, SATISFIABLE, UNSATISFIABLE, XOR_DIFFERENTIAL -key_bit_size = 64 -block_bit_size = 32 -key_schedule_bit_size = 16 +KEY_BIT_SIZE = 64 +BLOCK_BIT_SIZE = 32 +KEY_SCHEDULE_BIT_SIZE = 16 -def get_round_key_values(dictionary, number_of_rounds, suffix=""): +def get_round_key_values(dictionary, number_of_rounds): round_key_values = [] for round_number in range(number_of_rounds): - component_id = get_intermediate_component_id_from_key_schedule(round_number, number_of_rounds, suffix) - if component_id != '': - round_key_values.append(int('0x' + dictionary[component_id]['value'], 16)) + component_id = get_intermediate_component_id_from_key_schedule(round_number, number_of_rounds) + if component_id != "": + round_key_values.append(int(dictionary[component_id]["value"], 16)) return round_key_values def get_round_data_values(dictionary, number_of_rounds, suffix=""): round_data_values = [] - for round_number in range(number_of_rounds+1): + for round_number in range(number_of_rounds + 1): component_id = get_intermediate_component_id_from_main_process(round_number, number_of_rounds, suffix) - if component_id != '': - round_data_values.append(int('0x' + dictionary[component_id]['value'], 16)) + if component_id != "": + round_data_values.append(int(dictionary[component_id]["value"], 16)) return round_data_values -def get_intermediate_component_id_from_key_schedule(round_number, number_of_rounds, suffix): - component_id = '' +def get_intermediate_component_id_from_key_schedule(round_number, number_of_rounds): + component_id = "" if round_number == 0: - component_id = f'intermediate_output_0_5' + component_id = "intermediate_output_0_5" if 0 < round_number < number_of_rounds: - component_id = f'intermediate_output_{round_number}_11' + component_id = f"intermediate_output_{round_number}_11" return component_id def get_intermediate_component_id_from_main_process(round_number, number_of_rounds, suffix): - component_id = '' + component_id = "" if round_number == 0: - component_id = f'plaintext{suffix}' + component_id = f"plaintext{suffix}" if round_number == 1: - component_id = f'intermediate_output_0_6' + component_id = "intermediate_output_0_6" if 1 < round_number < number_of_rounds: - component_id = f'intermediate_output_{round_number - 1}_12' + component_id = f"intermediate_output_{round_number - 1}_12" if round_number == number_of_rounds: - component_id = f'cipher_output_{number_of_rounds - 1}_12' + component_id = f"cipher_output_{number_of_rounds - 1}_12" return component_id def get_constraint(component_id_, bit_size, bit_values): - constraint_ = ( - set_fixed_variables( - component_id=component_id_, - constraint_type="equal", - bit_positions=list(range(bit_size)), - bit_values=bit_values) + constraint_ = set_fixed_variables( + component_id=component_id_, constraint_type="equal", bit_positions=list(range(bit_size)), bit_values=bit_values ) return constraint_ def get_constraints(list_key, list_data, key_differential, suffix=""): key_pair1_pair2 = set_fixed_variables( - component_id='key_pair1_pair2', - constraint_type='equal', - bit_positions=list(range(key_bit_size)), - bit_values=integer_to_bit_list(key_differential, key_bit_size, 'big')) + component_id="key_pair1_pair2", + constraint_type="equal", + bit_positions=list(range(KEY_BIT_SIZE)), + bit_values=integer_to_bit_list(key_differential, KEY_BIT_SIZE, "big"), + ) round_number = 0 component_ids = [] number_of_states = len(list_key) fixed_variables = [key_pair1_pair2] for num in list_key: - binary_list = integer_to_bit_list(num, key_schedule_bit_size, 'big') - component_id = get_intermediate_component_id_from_key_schedule(round_number, number_of_states, suffix) + binary_list = integer_to_bit_list(num, KEY_SCHEDULE_BIT_SIZE, "big") + component_id = get_intermediate_component_id_from_key_schedule(round_number, number_of_states) component_ids.append(component_id) - fixed_variables.append(get_constraint(component_id, key_schedule_bit_size, binary_list)) + fixed_variables.append(get_constraint(component_id, KEY_SCHEDULE_BIT_SIZE, binary_list)) round_number += 1 round_number = 0 number_of_states = len(list_data) - 1 for num in list_data: - binary_list = integer_to_bit_list(num, block_bit_size, 'big') + binary_list = integer_to_bit_list(num, BLOCK_BIT_SIZE, "big") component_id = get_intermediate_component_id_from_main_process(round_number, number_of_states, suffix) component_ids.append(component_id) - fixed_variables.append(get_constraint(component_id, block_bit_size, binary_list)) + fixed_variables.append(get_constraint(component_id, BLOCK_BIT_SIZE, binary_list)) round_number += 1 return fixed_variables, component_ids def test_satisfiable_differential_trail_related_key(): - speck = SpeckBlockCipher(number_of_rounds=14, block_bit_size=block_bit_size, key_bit_size=key_bit_size) + speck = SpeckBlockCipher(number_of_rounds=14, block_bit_size=BLOCK_BIT_SIZE, key_bit_size=KEY_BIT_SIZE) speck.convert_to_compound_xor_cipher() sat = SatCipherModel(speck) list_key = [ @@ -114,9 +112,9 @@ def test_satisfiable_differential_trail_related_key(): 0x14080008, 0x200000, 0x40004000, - 0xc0b1c0b0, - 0x667764b4, - 0x907002a1, + 0xC0B1C0B0, + 0x667764B4, + 0x907002A1, 0x8810205, 0x140800, 0x20000000, @@ -125,18 +123,18 @@ def test_satisfiable_differential_trail_related_key(): 0x0, 0x80008000, 0x1000102, - 0x8102850a, + 0x8102850A, ] - fixed_variables, component_ids = get_constraints(list_key, list_data, 0x0a80088000681000, '_pair1_pair2') + fixed_variables, _ = get_constraints(list_key, list_data, 0x0A80088000681000, "_pair1_pair2") sat.build_cipher_model(fixed_variables=fixed_variables) - assert sat.solve(CIPHER, solver_name="cryptominisat")["status"] == "SATISFIABLE" + assert sat.solve(CIPHER, solver_name=CRYPTOMINISAT_EXT)["status"] == SATISFIABLE def test_satisfiable_differential_trail_single_key(): - """ The following is an compatible trail presented in Table 5 of [SongHY16]_.""" + """The following is an compatible trail presented in Table 5 of [SongHY16]_.""" - speck = SpeckBlockCipher(number_of_rounds=10, block_bit_size=block_bit_size, key_bit_size=key_bit_size) + speck = SpeckBlockCipher(number_of_rounds=10, block_bit_size=BLOCK_BIT_SIZE, key_bit_size=KEY_BIT_SIZE) speck.convert_to_compound_xor_cipher() sat = SatCipherModel(speck) list_data = [ @@ -150,17 +148,17 @@ def test_satisfiable_differential_trail_single_key(): 0x00040000, 0x08000800, 0x08102810, - 0x0800A840 + 0x0800A840, ] - fixed_variables, component_ids = get_constraints([], list_data, 0x0, '_pair1_pair2') + fixed_variables, _ = get_constraints([], list_data, 0x0, "_pair1_pair2") sat.build_cipher_model(fixed_variables=fixed_variables) - assert sat.solve(CIPHER, solver_name="cryptominisat")["status"] == "SATISFIABLE" + assert sat.solve(CIPHER, solver_name=CRYPTOMINISAT_EXT)["status"] == SATISFIABLE def test_unsatisfiable_differential_trail_related_key(): - """ The following is an incompatible trail presented in Table 28 of [Sad2020]_.""" + """The following is an incompatible trail presented in Table 28 of [Sad2020]_.""" - speck = SpeckBlockCipher(number_of_rounds=14, block_bit_size=block_bit_size, key_bit_size=key_bit_size) + speck = SpeckBlockCipher(number_of_rounds=14, block_bit_size=BLOCK_BIT_SIZE, key_bit_size=KEY_BIT_SIZE) speck.convert_to_compound_xor_cipher() sat = SatCipherModel(speck) list_key = [ @@ -177,7 +175,7 @@ def test_unsatisfiable_differential_trail_related_key(): 0x87C0, 0x0042, 0x8140, - 0x0557 + 0x0557, ] list_data = [ @@ -195,97 +193,67 @@ def test_unsatisfiable_differential_trail_related_key(): 0x80028500, 0x80429440, 0x9000C102, - 0xC575C17E + 0xC575C17E, ] - fixed_variables, component_ids = get_constraints(list_key, list_data, 0x0001400008800025, '_pair1_pair2') + fixed_variables, _ = get_constraints(list_key, list_data, 0x0001400008800025, "_pair1_pair2") sat.build_cipher_model(fixed_variables=fixed_variables) - assert sat.solve(CIPHER, solver_name="cryptominisat")["status"] == "UNSATISFIABLE" + assert sat.solve(CIPHER, solver_name=CRYPTOMINISAT_EXT)["status"] == UNSATISFIABLE def test_satisfiable_differential_trail_single_key_generated_using_claasp(): - speck = SpeckBlockCipher(number_of_rounds=4, block_bit_size=block_bit_size, key_bit_size=key_bit_size) + speck = SpeckBlockCipher(number_of_rounds=4, block_bit_size=BLOCK_BIT_SIZE, key_bit_size=KEY_BIT_SIZE) sat = SatCipherModel(speck) sat_xor_diff_model = SatXorDifferentialModel( speck, ) - fixed_variables = [set_fixed_variables('key', 'not_equal', range(64), integer_to_bit_list(0, 64, 'little')), - set_fixed_variables('plaintext', 'not_equal', range(32), - integer_to_bit_list(0, 32, 'little'))] + fixed_variables = [ + set_fixed_variables("key", "not_equal", range(64), integer_to_bit_list(0, 64, "little")), + set_fixed_variables("plaintext", "not_equal", range(32), integer_to_bit_list(0, 32, "little")), + ] sat_xor_diff_model.build_xor_differential_trail_model(5, fixed_variables=fixed_variables) - sat_output = sat_xor_diff_model._solve_with_external_sat_solver('xor_differential', 'cryptominisat', []) + sat_output = sat_xor_diff_model._solve_with_external_sat_solver(XOR_DIFFERENTIAL, CRYPTOMINISAT_EXT, []) list_key = get_round_key_values(sat_output["components_values"], speck.number_of_rounds) list_data = get_round_data_values(sat_output["components_values"], speck.number_of_rounds) speck.convert_to_compound_xor_cipher() - fixed_variables, component_ids = get_constraints(list_key, list_data, 0x0) + fixed_variables, _ = get_constraints(list_key, list_data, 0x0) sat.build_cipher_model(fixed_variables=fixed_variables) - assert sat.solve(CIPHER, solver_name="cryptominisat")["status"] == "SATISFIABLE" + assert sat.solve(CIPHER, solver_name=CRYPTOMINISAT_EXT)["status"] == SATISFIABLE def test_build_xor_differential_model_and_checker_unsat(): - list_key = [ - 0x0, - 0x0, - 0x0, - 0x8000, - 0x8002, - 0xfff4, - 0x19bf, - 0x0e0d, - 0x3834, - 0x6090, - 0x0, - 0x0, - 0x8100, - 0x606, - 0x1e1e - ] + list_key = [0x0, 0x0, 0x0, 0x8000, 0x8002, 0xFFF4, 0x19BF, 0x0E0D, 0x3834, 0x6090, 0x0, 0x0, 0x8100, 0x606, 0x1E1E] list_data = [ 0x0, 0x0, 0x0, 0x80008000, 0x1020100, - 0xfb0aff0a, - 0xbb534778, - 0xe1fffc1e, - 0xfa7f0a04, + 0xFB0AFF0A, + 0xBB534778, + 0xE1FFFC1E, + 0xFA7F0A04, 0x28000010, 0x400000, 0x80008000, 0x2, - 0x0604060c, - 0x101e082e + 0x0604060C, + 0x101E082E, ] - fixed_variables, component_ids = get_constraints(list_key, list_data, 0x0040000000000000, '_pair1_pair2') + fixed_variables, _ = get_constraints(list_key, list_data, 0x0040000000000000, "_pair1_pair2") speck = SpeckBlockCipher(number_of_rounds=15) - sat = SatXorDifferentialModel(speck, window_size_by_round=[0, 0, 0, 0, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0]) + sat = SatXorDifferentialModel(speck) + sat.window_size_by_round_values = [0, 0, 0, 0, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0] sat.build_xor_differential_trail_and_checker_model_at_intermediate_output_level( 144, fixed_variables=fixed_variables ) - solution = sat.solve(XOR_DIFFERENTIAL, solver_name='kissat') - assert solution['status'] == 'UNSATISFIABLE' + solution = sat.solve(XOR_DIFFERENTIAL, solver_name=KISSAT_EXT) + assert solution["status"] == UNSATISFIABLE def test_build_xor_differential_model_and_checker_sat(): - list_key = [ - 0x0, - 0x40, - 0x8100, - 0x8002, - 0x8, - 0x00d8, - 0x400, - 0x1000, - 0x4001, - 0x0, - 0x0, - 0x200, - 0x0, - 0x0, - 0x4 - ] + list_key = [0x0, 0x40, 0x8100, 0x8002, 0x8, 0x00D8, 0x400, 0x1000, 0x4001, 0x0, 0x0, 0x200, 0x0, 0x0, 0x4] list_data = [ 0x28140810, 0x20400000, @@ -293,23 +261,21 @@ def test_build_xor_differential_model_and_checker_sat(): 0x2, 0x80008008, 0x81008122, - 0x8284860e, - 0x8b099333, - 0xb7d9fb17, - 0xfc771028, - 0x00a04000, + 0x8284860E, + 0x8B099333, + 0xB7D9FB17, + 0xFC771028, + 0x00A04000, 0x10000, 0x0, 0x0, 0x0, - 0x40004 - + 0x40004, ] - fixed_variables, component_ids = get_constraints(list_key, list_data, 0x8002204020000000, '_pair1_pair2') + fixed_variables, _ = get_constraints(list_key, list_data, 0x8002204020000000, "_pair1_pair2") speck = SpeckBlockCipher(number_of_rounds=15) - sat = SatXorDifferentialModel(speck, window_size_by_round=[0, 0, 0, 0, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0]) - sat.build_xor_differential_trail_and_checker_model_at_intermediate_output_level( - 92, fixed_variables=fixed_variables - ) - solution = sat.solve(XOR_DIFFERENTIAL, solver_name='kissat') - assert solution['status'] == 'SATISFIABLE' + sat = SatXorDifferentialModel(speck) + sat.window_size_by_round_values = [0, 0, 0, 0, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0] + sat.build_xor_differential_trail_and_checker_model_at_intermediate_output_level(92, fixed_variables=fixed_variables) + solution = sat.solve(XOR_DIFFERENTIAL, solver_name=KISSAT_EXT) + assert solution["status"] == SATISFIABLE diff --git a/tests/unit/utils/integer_test.py b/tests/unit/utils/integer_test.py index 8f3a6b70c..9deae213e 100644 --- a/tests/unit/utils/integer_test.py +++ b/tests/unit/utils/integer_test.py @@ -3,8 +3,8 @@ def test_generate_bitmask(): - assert bin(generate_bitmask(4)) == '0b1111' - assert hex(generate_bitmask(32)) == '0xffffffff' + assert generate_bitmask(4) == 0b1111 + assert generate_bitmask(32) == 0xffffffff def test_to_binary(): diff --git a/tests/unit/utils/utils_test.py b/tests/unit/utils/utils_test.py index 7a50197e3..4a89b5801 100644 --- a/tests/unit/utils/utils_test.py +++ b/tests/unit/utils/utils_test.py @@ -3,7 +3,6 @@ import claasp from claasp.utils.utils import point_pair -from claasp.utils.utils import get_k_th_bit from claasp.utils.utils import sgn_function from claasp.utils.utils import signed_distance from claasp.utils.utils import pprint_dictionary @@ -20,10 +19,6 @@ def test_bytes_positions_to_little_endian_for_32_bits(): assert bytes_positions_to_little_endian_for_32_bits(lst) == output_lst -def test_get_k_th_bit(): - assert get_k_th_bit(3, 0) == 1 - - def test_pprint_dictionary(): speck = SpeckBlockCipher(block_bit_size=16, key_bit_size=32, number_of_rounds=5) test = AvalancheTests(speck) @@ -43,6 +38,7 @@ def test_pprint_dictionary_to_file(): assert os.path.isfile(f"{tii_dir_path}/test_json") is True os.remove(f"{tii_dir_path}/test_json") + def test_sgn_function(): assert sgn_function(-1) == -1