From 1c48f7fc5340caa0dfae6d168a5bc2387d426f42 Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Fri, 31 Jan 2025 16:19:11 +0100 Subject: [PATCH 1/9] Add Python tests --- .../pyarrow/tests/test_dataset_encryption.py | 132 ++++++++++++++++-- 1 file changed, 118 insertions(+), 14 deletions(-) diff --git a/python/pyarrow/tests/test_dataset_encryption.py b/python/pyarrow/tests/test_dataset_encryption.py index eb79121b1cdbe..284bdc0883547 100644 --- a/python/pyarrow/tests/test_dataset_encryption.py +++ b/python/pyarrow/tests/test_dataset_encryption.py @@ -66,11 +66,11 @@ def create_sample_table(): ) -def create_encryption_config(): +def create_encryption_config(footer_key, column_keys): return pe.EncryptionConfiguration( - footer_key=FOOTER_KEY_NAME, + footer_key=footer_key, plaintext_footer=False, - column_keys={COL_KEY_NAME: ["n_legs", "animal"]}, + column_keys=column_keys, encryption_algorithm="AES_GCM_V1", # requires timedelta or an assertion is raised cache_lifetime=timedelta(minutes=5.0), @@ -82,11 +82,11 @@ def create_decryption_config(): return pe.DecryptionConfiguration(cache_lifetime=300) -def create_kms_connection_config(): +def create_kms_connection_config(keys): return pe.KmsConnectionConfig( custom_kms_conf={ - FOOTER_KEY_NAME: FOOTER_KEY.decode("UTF-8"), - COL_KEY_NAME: COL_KEY.decode("UTF-8"), + key_name: key.decode("UTF-8") if isinstance(key, bytes) else key + for key_name, key in keys.items() } ) @@ -95,15 +95,15 @@ def kms_factory(kms_connection_configuration): return InMemoryKmsClient(kms_connection_configuration) -@pytest.mark.skipif( - encryption_unavailable, reason="Parquet Encryption is not currently enabled" -) -def test_dataset_encryption_decryption(): - table = create_sample_table() - - encryption_config = create_encryption_config() +def do_test_dataset_encryption_decryption( + table, + footer_key=FOOTER_KEY_NAME, + column_keys={COL_KEY_NAME: ["n_legs", "animal"]}, + keys={FOOTER_KEY_NAME: FOOTER_KEY, COL_KEY_NAME: COL_KEY} +): + encryption_config = create_encryption_config(footer_key, column_keys) decryption_config = create_decryption_config() - kms_connection_config = create_kms_connection_config() + kms_connection_config = create_kms_connection_config(keys) crypto_factory = pe.CryptoFactory(kms_factory) parquet_encryption_cfg = ds.ParquetEncryptionConfig( @@ -155,6 +155,110 @@ def test_dataset_encryption_decryption(): assert table.equals(dataset.to_table()) +@pytest.mark.skipif( + encryption_unavailable, reason="Parquet Encryption is not currently enabled" +) +def test_dataset_encryption_decryption(): + do_test_dataset_encryption_decryption(create_sample_table()) + + +@pytest.mark.skipif( + encryption_unavailable, reason="Parquet Encryption is not currently enabled" +) +@pytest.mark.parametrize("column_name", ["list", "list.list.element"]) +def test_list_encryption_decryption(column_name): + list_data = pa.array( + [[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1], [-2], [-3]], + type=pa.list_(pa.int32()), + ) + table = create_sample_table().append_column("list", list_data) + + column_keys = {COL_KEY_NAME: ["animal", column_name]} + do_test_dataset_encryption_decryption(table, column_keys=column_keys) + + +@pytest.mark.skipif( + encryption_unavailable, + reason="Parquet Encryption is not currently enabled" +) +@pytest.mark.parametrize( + "column_name", + ["map", "map.key", "map.value", "map.key_value.key", "map.key_value.value"] +) +def test_map_encryption_decryption(column_name): + map_type = pa.map_(pa.string(), pa.int32()) + map_data = pa.array( + [ + [("k1", 1), ("k2", 2)], [("k1", 3), ("k3", 4)], [("k2", 5), ("k3", 6)], + [("k4", 7)], [], [] + ], + type=map_type + ) + table = create_sample_table().append_column("map", map_data) + + column_keys = {COL_KEY_NAME: ["animal", column_name]} + do_test_dataset_encryption_decryption(table, column_keys=column_keys) + + +@pytest.mark.skipif( + encryption_unavailable, reason="Parquet Encryption is not currently enabled" +) +@pytest.mark.parametrize("column_name", ["struct", "struct.f1", "struct.f2"]) +def test_struct_encryption_decryption(column_name): + struct_fields = [("f1", pa.int32()), ("f2", pa.string())] + struct_type = pa.struct(struct_fields) + struct_data = pa.array( + [(1, "one"), (2, "two"), (3, "three"), (4, "four"), (5, "five"), (6, "six")], + type=struct_type + ) + table = create_sample_table().append_column("struct", struct_data) + + column_keys = {COL_KEY_NAME: ["animal", column_name]} + do_test_dataset_encryption_decryption(table, column_keys=column_keys) + + +@pytest.mark.skipif( + encryption_unavailable, + reason="Parquet Encryption is not currently enabled" +) +@pytest.mark.parametrize( + "column_name", + [ + "col", + "col.list.element", + "col.list.element.key_value.key", + "col.list.element.key_value.value", + "col.list.element.key_value.value.f1", + "col.list.element.key_value.value.f2" + ] +) +def test_deep_nested_encryption_decryption(column_name): + struct_fields = [("f1", pa.int32()), ("f2", pa.string())] + struct_type = pa.struct(struct_fields) + struct1 = (1, "one") + struct2 = (2, "two") + struct3 = (3, "three") + struct4 = (4, "four") + struct5 = (5, "five") + struct6 = (6, "six") + + map_type = pa.map_(pa.int32(), struct_type) + map1 = {1: struct1, 2: struct2} + map2 = {3: struct3} + map3 = {4: struct4} + map4 = {5: struct5, 6: struct6} + + list_type = pa.list_(map_type) + list1 = [map1, map2] + list2 = [map3] + list3 = [map4] + list_data = [pa.array([list1, list2, None, list3, None, None], type=list_type)] + table = create_sample_table().append_column("col", list_data) + + column_keys = {COL_KEY_NAME: ["animal", column_name]} + do_test_dataset_encryption_decryption(table, column_keys=column_keys) + + @pytest.mark.skipif( not encryption_unavailable, reason="Parquet Encryption is currently enabled" ) From cb0f1e6bae2d115fc1c5fd57a4871cfafa16ce82 Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Fri, 31 Jan 2025 21:16:27 +0100 Subject: [PATCH 2/9] Add C++ tests --- .../dataset/file_parquet_encryption_test.cc | 198 +++++++++++++++++- 1 file changed, 194 insertions(+), 4 deletions(-) diff --git a/cpp/src/arrow/dataset/file_parquet_encryption_test.cc b/cpp/src/arrow/dataset/file_parquet_encryption_test.cc index 0287d593d12d3..39f45b7460bef 100644 --- a/cpp/src/arrow/dataset/file_parquet_encryption_test.cc +++ b/cpp/src/arrow/dataset/file_parquet_encryption_test.cc @@ -15,6 +15,11 @@ // specific language governing permissions and limitations // under the License. +#include +#include +#include +#include +#include #include #include "gtest/gtest.h" @@ -43,7 +48,7 @@ constexpr std::string_view kFooterKeyMasterKeyId = "footer_key"; constexpr std::string_view kFooterKeyName = "footer_key"; constexpr std::string_view kColumnMasterKey = "1234567890123450"; constexpr std::string_view kColumnMasterKeyId = "col_key"; -constexpr std::string_view kColumnKeyMapping = "col_key: a"; +constexpr std::string_view kColumnName = "a"; constexpr std::string_view kBaseDir = ""; using arrow::internal::checked_pointer_cast; @@ -90,7 +95,9 @@ class DatasetEncryptionTestBase : public ::testing::Test { auto encryption_config = std::make_shared( std::string(kFooterKeyName)); - encryption_config->column_keys = kColumnKeyMapping; + std::stringstream column_key; + column_key << kColumnMasterKeyId << ": " << ColumnKey(); + encryption_config->column_keys = column_key.str(); auto parquet_encryption_config = std::make_shared(); // Directly assign shared_ptr objects to ParquetEncryptionConfig members parquet_encryption_config->crypto_factory = crypto_factory_; @@ -118,6 +125,7 @@ class DatasetEncryptionTestBase : public ::testing::Test { } virtual void PrepareTableAndPartitioning() = 0; + virtual std::string_view ColumnKey() { return kColumnName; } void TestScanDataset() { // Create decryption properties. @@ -179,8 +187,9 @@ class DatasetEncryptionTest : public DatasetEncryptionTestBase { // The dataset is partitioned using a Hive partitioning scheme. void PrepareTableAndPartitioning() override { // Prepare table data. - auto table_schema = schema({field("a", int64()), field("c", int64()), - field("e", int64()), field("part", utf8())}); + auto table_schema = + schema({field(std::string(kColumnName), int64()), field("c", int64()), + field("e", int64()), field("part", utf8())}); table_ = TableFromJSON(table_schema, {R"([ [ 0, 9, 1, "a" ], [ 1, 8, 2, "a" ], @@ -240,6 +249,187 @@ TEST_F(DatasetEncryptionTest, ReadSingleFile) { ASSERT_EQ(checked_pointer_cast(table->column(2)->chunk(0))->GetView(0), 1); } +class NestedFieldsEncryptionTest : public DatasetEncryptionTestBase, + public ::testing::WithParamInterface { + public: + NestedFieldsEncryptionTest() : rand_gen(0) {} + + // The dataset is partitioned using a Hive partitioning scheme. + void PrepareTableAndPartitioning() override { + // Prepare table and partitioning. + auto table_schema = schema({field("a", std::move(column_type_))}); + table_ = arrow::Table::Make(table_schema, {column_data_}); + partitioning_ = std::make_shared(arrow::schema({})); + } + + std::string_view ColumnKey() override { return GetParam(); } + + protected: + std::shared_ptr column_type_; + std::shared_ptr column_data_; + arrow::random::RandomArrayGenerator rand_gen; +}; + +class ListFieldEncryptionTest : public NestedFieldsEncryptionTest { + public: + ListFieldEncryptionTest() { + arrow::MemoryPool* pool = arrow::default_memory_pool(); + auto value_builder = std::make_shared(pool); + arrow::ListBuilder list_builder = arrow::ListBuilder(pool, value_builder); + ARROW_CHECK_OK(list_builder.Append()); + ARROW_CHECK_OK(value_builder->Append(1)); + ARROW_CHECK_OK(value_builder->Append(2)); + ARROW_CHECK_OK(value_builder->Append(3)); + ARROW_CHECK_OK(list_builder.Append()); + ARROW_CHECK_OK(value_builder->Append(4)); + ARROW_CHECK_OK(value_builder->Append(5)); + ARROW_CHECK_OK(list_builder.Append()); + ARROW_CHECK_OK(value_builder->Append(6)); + + std::shared_ptr list_array; + arrow::Status status = list_builder.Finish(&list_array); + + column_type_ = list(int32()); + column_data_ = list_array; + } +}; + +class MapFieldEncryptionTest : public NestedFieldsEncryptionTest { + public: + MapFieldEncryptionTest() : NestedFieldsEncryptionTest() { + arrow::MemoryPool* pool = arrow::default_memory_pool(); + auto map_type = map(utf8(), int32()); + auto key_builder = std::make_shared(pool); + auto item_builder = std::make_shared(pool); + auto map_builder = + std::make_shared(pool, key_builder, item_builder, map_type); + ARROW_CHECK_OK(map_builder->Append()); + ARROW_CHECK_OK(key_builder->Append("one")); + ARROW_CHECK_OK(item_builder->Append(1)); + ARROW_CHECK_OK(map_builder->Append()); + ARROW_CHECK_OK(key_builder->Append("two")); + ARROW_CHECK_OK(item_builder->Append(2)); + ARROW_CHECK_OK(map_builder->Append()); + ARROW_CHECK_OK(key_builder->Append("three")); + ARROW_CHECK_OK(item_builder->Append(3)); + + std::shared_ptr map_array; + ARROW_CHECK_OK(map_builder->Finish(&map_array)); + + column_type_ = map_type; + column_data_ = map_array; + } +}; + +class StructFieldEncryptionTest : public NestedFieldsEncryptionTest { + public: + StructFieldEncryptionTest() : NestedFieldsEncryptionTest() { + arrow::MemoryPool* pool = arrow::default_memory_pool(); + auto struct_type = struct_({field("f1", int32()), field("f2", utf8())}); + auto f1_builder = std::make_shared(pool); + auto f2_builder = std::make_shared(pool); + std::vector> value_builders = {f1_builder, f2_builder}; + auto struct_builder = std::make_shared(std::move(struct_type), + pool, value_builders); + ARROW_CHECK_OK(struct_builder->Append()); + ARROW_CHECK_OK(f1_builder->Append(1)); + ARROW_CHECK_OK(f2_builder->Append("one")); + ARROW_CHECK_OK(struct_builder->Append()); + ARROW_CHECK_OK(f1_builder->Append(2)); + ARROW_CHECK_OK(f2_builder->Append("two")); + ARROW_CHECK_OK(struct_builder->Append()); + ARROW_CHECK_OK(f1_builder->Append(3)); + ARROW_CHECK_OK(f2_builder->Append("three")); + + std::shared_ptr struct_array; + ARROW_CHECK_OK(struct_builder->Finish(&struct_array)); + + column_type_ = struct_type; + column_data_ = struct_array; + } +}; + +class DeepNestedFieldEncryptionTest : public NestedFieldsEncryptionTest { + public: + DeepNestedFieldEncryptionTest() : NestedFieldsEncryptionTest() { + arrow::MemoryPool* pool = arrow::default_memory_pool(); + + auto struct_type = struct_({field("f1", int32()), field("f2", utf8())}); + auto f1_builder = std::make_shared(pool); + auto f2_builder = std::make_shared(pool); + std::vector> value_builders = {f1_builder, f2_builder}; + auto struct_builder = std::make_shared(std::move(struct_type), + pool, value_builders); + + auto map_type = map(int32(), struct_type); + auto key_builder = std::make_shared(pool); + auto item_builder = struct_builder; + auto map_builder = + std::make_shared(pool, key_builder, item_builder, map_type); + + auto list_type = list(map_type); + auto value_builder = map_builder; + arrow::ListBuilder list_builder = arrow::ListBuilder(pool, value_builder); + + ARROW_CHECK_OK(list_builder.Append()); + ARROW_CHECK_OK(value_builder->Append()); + + ARROW_CHECK_OK(key_builder->Append(1)); + ARROW_CHECK_OK(item_builder->Append()); + ARROW_CHECK_OK(f1_builder->Append(1)); + ARROW_CHECK_OK(f2_builder->Append("one")); + + ARROW_CHECK_OK(key_builder->Append(1)); + ARROW_CHECK_OK(item_builder->Append()); + ARROW_CHECK_OK(f1_builder->Append(2)); + ARROW_CHECK_OK(f2_builder->Append("two")); + + ARROW_CHECK_OK(value_builder->Append()); + + ARROW_CHECK_OK(key_builder->Append(3)); + ARROW_CHECK_OK(item_builder->Append()); + ARROW_CHECK_OK(f1_builder->Append(3)); + ARROW_CHECK_OK(f2_builder->Append("three")); + + ARROW_CHECK_OK(list_builder.Append()); + ARROW_CHECK_OK(value_builder->Append()); + + ARROW_CHECK_OK(key_builder->Append(4)); + ARROW_CHECK_OK(item_builder->Append()); + ARROW_CHECK_OK(f1_builder->Append(4)); + ARROW_CHECK_OK(f2_builder->Append("four")); + + std::shared_ptr list_array; + arrow::Status status = list_builder.Finish(&list_array); + + column_type_ = list_type; + column_data_ = list_array; + } +}; + +// Test writing and reading encrypted nested fields +INSTANTIATE_TEST_SUITE_P(List, ListFieldEncryptionTest, + ::testing::Values("a", "a.list.element")); +INSTANTIATE_TEST_SUITE_P(Map, MapFieldEncryptionTest, + ::testing::Values("a", "a.key", "a.value", "a.key_value.key", + "a.key_value.value")); +INSTANTIATE_TEST_SUITE_P(Struct, StructFieldEncryptionTest, + ::testing::Values("a", "a.f1", "a.f2")); +INSTANTIATE_TEST_SUITE_P(DeepNested, DeepNestedFieldEncryptionTest, + ::testing::Values("a", "a.list.element", + "a.list.element.key_value.key", + "a.list.element.key_value.value", + "a.list.element.key_value.value.f1", + "a.list.element.key_value.value.f2")); + +TEST_P(ListFieldEncryptionTest, ColumnKeys) { TestScanDataset(); } + +TEST_P(MapFieldEncryptionTest, ColumnKeys) { TestScanDataset(); } + +TEST_P(StructFieldEncryptionTest, ColumnKeys) { TestScanDataset(); } + +TEST_P(DeepNestedFieldEncryptionTest, ColumnKeys) { TestScanDataset(); } + // GH-39444: This test covers the case where parquet dataset scanner crashes when // processing encrypted datasets over 2^15 rows in multi-threaded mode. class LargeRowEncryptionTest : public DatasetEncryptionTestBase { From 88b6e096fc0c8e87d36244368027744b3a972758 Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Sat, 1 Feb 2025 21:56:17 +0100 Subject: [PATCH 3/9] More elaborate Python column decryption tests --- .../pyarrow/tests/test_dataset_encryption.py | 69 ++++++++++++++++++- 1 file changed, 67 insertions(+), 2 deletions(-) diff --git a/python/pyarrow/tests/test_dataset_encryption.py b/python/pyarrow/tests/test_dataset_encryption.py index 284bdc0883547..f81c715236291 100644 --- a/python/pyarrow/tests/test_dataset_encryption.py +++ b/python/pyarrow/tests/test_dataset_encryption.py @@ -128,12 +128,77 @@ def do_test_dataset_encryption_decryption( filesystem=mockfs, ) - # read without decryption config -> should error is dataset was properly encrypted + # read without decryption config -> should error if dataset was properly encrypted pformat = pa.dataset.ParquetFileFormat() with pytest.raises(IOError, match=r"no decryption"): ds.dataset("sample_dataset", format=pformat, filesystem=mockfs) - # set decryption config for parquet fragment scan options + # helper method for following tests + def create_format_with_keys(keys): + kms_connection_config = create_kms_connection_config(keys) + parquet_decryption_cfg = ds.ParquetDecryptionConfig( + crypto_factory, kms_connection_config, decryption_config + ) + pq_scan_opts = ds.ParquetFragmentScanOptions( + decryption_config=parquet_decryption_cfg + ) + return pa.dataset.ParquetFileFormat(default_fragment_scan_options=pq_scan_opts) + + def assert_read_table_with_keys_success(keys, column_names): + pformat = create_format_with_keys(keys) + dataset = ds.dataset("sample_dataset", format=pformat, filesystem=mockfs) + assert table.select(column_names).equals(dataset.to_table(columns=column_names)) + + def assert_read_table_with_keys_failure(keys, column_names): + pformat = create_format_with_keys(keys) + # creating the dataset works + dataset = ds.dataset("sample_dataset", format=pformat, filesystem=mockfs) + with pytest.raises(KeyError, match=r"col_key"): + # reading those columns fails + _ = dataset.to_table(column_names) + + # some notable column names and keys + all_column_names = table.column_names + encrypted_column_names = [column_name + for key_name, column_names in column_keys.items() + for column_name in column_names] + plaintext_column_names = [column_name + for column_name in all_column_names + if column_name not in encrypted_column_names] + assert len(encrypted_column_names) > 0 + assert len(plaintext_column_names) > 0 + footer_key_only = {FOOTER_KEY_NAME: FOOTER_KEY} + column_keys_only = {key_name: key for key_name, key in keys.items() if key_name != FOOTER_KEY_NAME} + + # read with footer key only + assert_read_table_with_keys_success(footer_key_only, plaintext_column_names) + #assert_read_table_with_keys_failure(footer_key_only, encrypted_column_names) + #assert_read_table_with_keys_failure(footer_key_only, all_column_names) + + # read with all but footer key + if len(keys) > 1: + assert_read_table_with_keys_success(column_keys_only, plaintext_column_names) # TODO: this is wrong! + assert_read_table_with_keys_failure(column_keys_only, encrypted_column_names) + assert_read_table_with_keys_failure(column_keys_only, all_column_names) + + # with footer key and one column key, all plaintext and + # those encrypted columns that use that key, can be read + if len(keys) > 2: + for column_key_name, encrypted_column_names in column_keys.items(): + for encrypted_column_name in encrypted_column_names: + footer_key_and_one_column_key = {key_name: key for key_name, key in keys.items() + if key_name in [FOOTER_KEY_NAME, column_key_name]} + assert_read_table_with_keys_success(footer_key_and_one_column_key, plaintext_column_names) + assert_read_table_with_keys_success(footer_key_and_one_column_key, plaintext_column_names + [encrypted_column_name]) + assert_read_table_with_keys_failure(footer_key_and_one_column_key, encrypted_column_names) + assert_read_table_with_keys_failure(footer_key_and_one_column_key, all_column_names) + + # with all column keys, all columns can be read + assert_read_table_with_keys_success(keys, plaintext_column_names) + assert_read_table_with_keys_failure(keys, encrypted_column_names) # TODO: this is wrong! + assert_read_table_with_keys_failure(keys, all_column_names) + + # no matter how many keys are configured, test that whole table can be read pq_scan_opts = ds.ParquetFragmentScanOptions( decryption_config=parquet_decryption_cfg ) From 9e23454830c1cbfc30b7134a4c091ab48ab62201 Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Sun, 2 Feb 2025 19:49:00 +0100 Subject: [PATCH 4/9] Rework elaborate Python tests --- python/pyarrow/tests/parquet/encryption.py | 2 + .../pyarrow/tests/test_dataset_encryption.py | 202 ++++++++++-------- 2 files changed, 111 insertions(+), 93 deletions(-) diff --git a/python/pyarrow/tests/parquet/encryption.py b/python/pyarrow/tests/parquet/encryption.py index d07f8ae273520..87476cdbb461f 100644 --- a/python/pyarrow/tests/parquet/encryption.py +++ b/python/pyarrow/tests/parquet/encryption.py @@ -41,6 +41,8 @@ def wrap_key(self, key_bytes, master_key_identifier): def unwrap_key(self, wrapped_key, master_key_identifier): """Not a secure cipher - just extract the key from the wrapped key""" + if master_key_identifier not in self.master_keys_map: + raise ValueError("Unknown master key", master_key_identifier) expected_master_key = self.master_keys_map[master_key_identifier] decoded_wrapped_key = base64.b64decode(wrapped_key) master_key_bytes = decoded_wrapped_key[:16] diff --git a/python/pyarrow/tests/test_dataset_encryption.py b/python/pyarrow/tests/test_dataset_encryption.py index f81c715236291..40bbc190a31b6 100644 --- a/python/pyarrow/tests/test_dataset_encryption.py +++ b/python/pyarrow/tests/test_dataset_encryption.py @@ -47,7 +47,11 @@ FOOTER_KEY_NAME = "footer_key" COL_KEY = b"1234567890123450" COL_KEY_NAME = "col_key" - +KEYS = {FOOTER_KEY_NAME: FOOTER_KEY, COL_KEY_NAME: COL_KEY} +EXTRA_COL_KEY = b"2345678901234501" +EXTRA_COL_KEY_NAME = "col2_key" +COLUMNS = ["year", "n_legs", "animal"] +COLUMN_KEYS = {COL_KEY_NAME: ["n_legs", "animal"]} def create_sample_table(): return pa.table( @@ -95,22 +99,83 @@ def kms_factory(kms_connection_configuration): return InMemoryKmsClient(kms_connection_configuration) -def do_test_dataset_encryption_decryption( +def do_test_dataset_encryption_decryption(table, extra_column_path=None): + # use extra column key for column extra_column_name if given + if extra_column_path: + keys = dict(**KEYS, **{EXTRA_COL_KEY_NAME: EXTRA_COL_KEY}) + column_keys = dict(**COLUMN_KEYS, **{EXTRA_COL_KEY_NAME: [extra_column_path]}) + extra_column_name = extra_column_path.split(".")[0] + else: + keys = KEYS + column_keys = COLUMN_KEYS + extra_column_name = None + + # some notable column names and keys + all_column_names = table.column_names + encrypted_column_names = [column_name.split(".")[0] + for key_name, column_names in column_keys.items() + for column_name in column_names] + plaintext_column_names = [column_name + for column_name in all_column_names + if column_name not in encrypted_column_names and + (extra_column_path is None or not extra_column_path.startswith(f"{column_name}."))] + assert len(encrypted_column_names) > 0 + assert len(plaintext_column_names) > 0 + footer_key_only = {FOOTER_KEY_NAME: FOOTER_KEY} + column_keys_only = {key_name: key for key_name, key in keys.items() if key_name != FOOTER_KEY_NAME} + + # read with footer key only + assert_test_dataset_encryption_decryption(table, column_keys, keys, footer_key_only, plaintext_column_names, True) + assert_test_dataset_encryption_decryption(table, column_keys, keys, footer_key_only, encrypted_column_names, False) + assert_test_dataset_encryption_decryption(table, column_keys, keys, footer_key_only, all_column_names, False) + + # read with all but footer key + assert_test_dataset_encryption_decryption(table, column_keys, keys, column_keys_only, plaintext_column_names, False, False) + assert_test_dataset_encryption_decryption(table, column_keys, keys, column_keys_only, encrypted_column_names, False, False) + assert_test_dataset_encryption_decryption(table, column_keys, keys, column_keys_only, all_column_names, False, False) + + # with footer key and one column key, all plaintext and + # those encrypted columns that use that key, can be read + if len(column_keys) > 1: + for column_key_name, column_key_column_names in column_keys.items(): + for encrypted_column_name in column_key_column_names: + encrypted_column_name = encrypted_column_name.split(".")[0] + footer_key_and_one_column_key = {key_name: key for key_name, key in keys.items() + if key_name in [FOOTER_KEY_NAME, column_key_name]} + assert_test_dataset_encryption_decryption(table, column_keys, keys, footer_key_and_one_column_key, plaintext_column_names, + True) + assert_test_dataset_encryption_decryption(table, column_keys, keys, footer_key_and_one_column_key, plaintext_column_names + [encrypted_column_name], + encrypted_column_name != extra_column_name) + assert_test_dataset_encryption_decryption(table, column_keys, keys, footer_key_and_one_column_key, encrypted_column_names, + False) + assert_test_dataset_encryption_decryption(table, column_keys, keys, footer_key_and_one_column_key, all_column_names, False) + + # with all column keys, all columns can be read + assert_test_dataset_encryption_decryption(table, column_keys, keys, keys, plaintext_column_names, True) + assert_test_dataset_encryption_decryption(table, column_keys, keys, keys, encrypted_column_names, True) + assert_test_dataset_encryption_decryption(table, column_keys, keys, keys, all_column_names, True) + + +def assert_test_dataset_encryption_decryption( table, - footer_key=FOOTER_KEY_NAME, - column_keys={COL_KEY_NAME: ["n_legs", "animal"]}, - keys={FOOTER_KEY_NAME: FOOTER_KEY, COL_KEY_NAME: COL_KEY} + column_keys, + write_keys, + read_keys, + read_columns, + to_table_success, + dataset_success = True, ): - encryption_config = create_encryption_config(footer_key, column_keys) + encryption_config = create_encryption_config(FOOTER_KEY_NAME, column_keys) decryption_config = create_decryption_config() - kms_connection_config = create_kms_connection_config(keys) + encrypt_kms_connection_config = create_kms_connection_config(write_keys) + decrypt_kms_connection_config = create_kms_connection_config(read_keys) crypto_factory = pe.CryptoFactory(kms_factory) parquet_encryption_cfg = ds.ParquetEncryptionConfig( - crypto_factory, kms_connection_config, encryption_config + crypto_factory, encrypt_kms_connection_config, encryption_config ) parquet_decryption_cfg = ds.ParquetDecryptionConfig( - crypto_factory, kms_connection_config, decryption_config + crypto_factory, decrypt_kms_connection_config, decryption_config ) # create write_options with dataset encryption config @@ -133,91 +198,40 @@ def do_test_dataset_encryption_decryption( with pytest.raises(IOError, match=r"no decryption"): ds.dataset("sample_dataset", format=pformat, filesystem=mockfs) - # helper method for following tests - def create_format_with_keys(keys): - kms_connection_config = create_kms_connection_config(keys) - parquet_decryption_cfg = ds.ParquetDecryptionConfig( - crypto_factory, kms_connection_config, decryption_config - ) - pq_scan_opts = ds.ParquetFragmentScanOptions( - decryption_config=parquet_decryption_cfg - ) - return pa.dataset.ParquetFileFormat(default_fragment_scan_options=pq_scan_opts) - - def assert_read_table_with_keys_success(keys, column_names): - pformat = create_format_with_keys(keys) - dataset = ds.dataset("sample_dataset", format=pformat, filesystem=mockfs) - assert table.select(column_names).equals(dataset.to_table(columns=column_names)) - - def assert_read_table_with_keys_failure(keys, column_names): - pformat = create_format_with_keys(keys) - # creating the dataset works - dataset = ds.dataset("sample_dataset", format=pformat, filesystem=mockfs) - with pytest.raises(KeyError, match=r"col_key"): - # reading those columns fails - _ = dataset.to_table(column_names) - - # some notable column names and keys - all_column_names = table.column_names - encrypted_column_names = [column_name - for key_name, column_names in column_keys.items() - for column_name in column_names] - plaintext_column_names = [column_name - for column_name in all_column_names - if column_name not in encrypted_column_names] - assert len(encrypted_column_names) > 0 - assert len(plaintext_column_names) > 0 - footer_key_only = {FOOTER_KEY_NAME: FOOTER_KEY} - column_keys_only = {key_name: key for key_name, key in keys.items() if key_name != FOOTER_KEY_NAME} - - # read with footer key only - assert_read_table_with_keys_success(footer_key_only, plaintext_column_names) - #assert_read_table_with_keys_failure(footer_key_only, encrypted_column_names) - #assert_read_table_with_keys_failure(footer_key_only, all_column_names) - - # read with all but footer key - if len(keys) > 1: - assert_read_table_with_keys_success(column_keys_only, plaintext_column_names) # TODO: this is wrong! - assert_read_table_with_keys_failure(column_keys_only, encrypted_column_names) - assert_read_table_with_keys_failure(column_keys_only, all_column_names) - - # with footer key and one column key, all plaintext and - # those encrypted columns that use that key, can be read - if len(keys) > 2: - for column_key_name, encrypted_column_names in column_keys.items(): - for encrypted_column_name in encrypted_column_names: - footer_key_and_one_column_key = {key_name: key for key_name, key in keys.items() - if key_name in [FOOTER_KEY_NAME, column_key_name]} - assert_read_table_with_keys_success(footer_key_and_one_column_key, plaintext_column_names) - assert_read_table_with_keys_success(footer_key_and_one_column_key, plaintext_column_names + [encrypted_column_name]) - assert_read_table_with_keys_failure(footer_key_and_one_column_key, encrypted_column_names) - assert_read_table_with_keys_failure(footer_key_and_one_column_key, all_column_names) - - # with all column keys, all columns can be read - assert_read_table_with_keys_success(keys, plaintext_column_names) - assert_read_table_with_keys_failure(keys, encrypted_column_names) # TODO: this is wrong! - assert_read_table_with_keys_failure(keys, all_column_names) - - # no matter how many keys are configured, test that whole table can be read + # set decryption config for parquet fragment scan options pq_scan_opts = ds.ParquetFragmentScanOptions( decryption_config=parquet_decryption_cfg ) pformat = pa.dataset.ParquetFileFormat(default_fragment_scan_options=pq_scan_opts) - dataset = ds.dataset("sample_dataset", format=pformat, filesystem=mockfs) - - assert table.equals(dataset.to_table()) + if dataset_success: + dataset = ds.dataset("sample_dataset", format=pformat, filesystem=mockfs) + if to_table_success: + assert table.select(read_columns).equals(dataset.to_table(read_columns)) + else: + with pytest.raises(ValueError, match="Unknown master key"): + assert table.select(read_columns).equals(dataset.to_table(read_columns)) + else: + with pytest.raises(ValueError, match="Unknown master key"): + _ = ds.dataset("sample_dataset", format=pformat, filesystem=mockfs) # set decryption properties for parquet fragment scan options decryption_properties = crypto_factory.file_decryption_properties( - kms_connection_config, decryption_config) + decrypt_kms_connection_config, decryption_config) pq_scan_opts = ds.ParquetFragmentScanOptions( decryption_properties=decryption_properties ) pformat = pa.dataset.ParquetFileFormat(default_fragment_scan_options=pq_scan_opts) - dataset = ds.dataset("sample_dataset", format=pformat, filesystem=mockfs) - - assert table.equals(dataset.to_table()) + if dataset_success: + dataset = ds.dataset("sample_dataset", format=pformat, filesystem=mockfs) + if to_table_success: + assert table.select(read_columns).equals(dataset.to_table(read_columns)) + else: + with pytest.raises(ValueError, match="Unknown master key"): + assert table.select(read_columns).equals(dataset.to_table(read_columns)) + else: + with pytest.raises(ValueError, match="Unknown master key"): + _ = ds.dataset("sample_dataset", format=pformat, filesystem=mockfs) @pytest.mark.skipif( @@ -238,8 +252,7 @@ def test_list_encryption_decryption(column_name): ) table = create_sample_table().append_column("list", list_data) - column_keys = {COL_KEY_NAME: ["animal", column_name]} - do_test_dataset_encryption_decryption(table, column_keys=column_keys) + do_test_dataset_encryption_decryption(table, column_name) @pytest.mark.skipif( @@ -247,8 +260,7 @@ def test_list_encryption_decryption(column_name): reason="Parquet Encryption is not currently enabled" ) @pytest.mark.parametrize( - "column_name", - ["map", "map.key", "map.value", "map.key_value.key", "map.key_value.value"] + "column_name", ["map", "map.key", "map.value", "map.key_value.key", "map.key_value.value"] ) def test_map_encryption_decryption(column_name): map_type = pa.map_(pa.string(), pa.int32()) @@ -261,14 +273,15 @@ def test_map_encryption_decryption(column_name): ) table = create_sample_table().append_column("map", map_data) - column_keys = {COL_KEY_NAME: ["animal", column_name]} - do_test_dataset_encryption_decryption(table, column_keys=column_keys) + do_test_dataset_encryption_decryption(table, column_name) @pytest.mark.skipif( encryption_unavailable, reason="Parquet Encryption is not currently enabled" ) -@pytest.mark.parametrize("column_name", ["struct", "struct.f1", "struct.f2"]) +@pytest.mark.parametrize( + "column_name", [ "struct", "struct.f1", "struct.f2"] +) def test_struct_encryption_decryption(column_name): struct_fields = [("f1", pa.int32()), ("f2", pa.string())] struct_type = pa.struct(struct_fields) @@ -278,8 +291,7 @@ def test_struct_encryption_decryption(column_name): ) table = create_sample_table().append_column("struct", struct_data) - column_keys = {COL_KEY_NAME: ["animal", column_name]} - do_test_dataset_encryption_decryption(table, column_keys=column_keys) + do_test_dataset_encryption_decryption(table, column_name) @pytest.mark.skipif( @@ -290,6 +302,11 @@ def test_struct_encryption_decryption(column_name): "column_name", [ "col", + "col.element", + "col.element.key", + "col.element.value", + "col.element.value.f1", + "col.element.value.f2", "col.list.element", "col.list.element.key_value.key", "col.list.element.key_value.value", @@ -320,8 +337,7 @@ def test_deep_nested_encryption_decryption(column_name): list_data = [pa.array([list1, list2, None, list3, None, None], type=list_type)] table = create_sample_table().append_column("col", list_data) - column_keys = {COL_KEY_NAME: ["animal", column_name]} - do_test_dataset_encryption_decryption(table, column_keys=column_keys) + do_test_dataset_encryption_decryption(table, column_name) @pytest.mark.skipif( From 2f1b4c5736df8a5dcf5d337e8a749324843ef101 Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Sun, 2 Feb 2025 19:53:10 +0100 Subject: [PATCH 5/9] Simplify elaborate Python tests --- .../pyarrow/tests/test_dataset_encryption.py | 221 ++++++++++-------- 1 file changed, 120 insertions(+), 101 deletions(-) diff --git a/python/pyarrow/tests/test_dataset_encryption.py b/python/pyarrow/tests/test_dataset_encryption.py index 40bbc190a31b6..b310aa4021e6b 100644 --- a/python/pyarrow/tests/test_dataset_encryption.py +++ b/python/pyarrow/tests/test_dataset_encryption.py @@ -16,6 +16,7 @@ # under the License. import base64 +from contextlib import contextmanager from datetime import timedelta import random import pyarrow.fs as fs @@ -53,6 +54,7 @@ COLUMNS = ["year", "n_legs", "animal"] COLUMN_KEYS = {COL_KEY_NAME: ["n_legs", "animal"]} + def create_sample_table(): return pa.table( { @@ -99,8 +101,17 @@ def kms_factory(kms_connection_configuration): return InMemoryKmsClient(kms_connection_configuration) +@contextmanager +def cond_raises(success, error_type, match): + if success: + yield + else: + with pytest.raises(error_type, match=match): + yield + + def do_test_dataset_encryption_decryption(table, extra_column_path=None): - # use extra column key for column extra_column_name if given + # use extra column key for column extra_column_path, if given if extra_column_path: keys = dict(**KEYS, **{EXTRA_COL_KEY_NAME: EXTRA_COL_KEY}) column_keys = dict(**COLUMN_KEYS, **{EXTRA_COL_KEY_NAME: [extra_column_path]}) @@ -110,6 +121,77 @@ def do_test_dataset_encryption_decryption(table, extra_column_path=None): column_keys = COLUMN_KEYS extra_column_name = None + # define the actual test + def assert_decrypts( + read_keys, + read_columns, + to_table_success, + dataset_success=True, + ): + # use all keys for writing + write_keys = keys + encryption_config = create_encryption_config(FOOTER_KEY_NAME, column_keys) + decryption_config = create_decryption_config() + encrypt_kms_connection_config = create_kms_connection_config(write_keys) + decrypt_kms_connection_config = create_kms_connection_config(read_keys) + + crypto_factory = pe.CryptoFactory(kms_factory) + parquet_encryption_cfg = ds.ParquetEncryptionConfig( + crypto_factory, encrypt_kms_connection_config, encryption_config + ) + parquet_decryption_cfg = ds.ParquetDecryptionConfig( + crypto_factory, decrypt_kms_connection_config, decryption_config + ) + + # create write_options with dataset encryption config + pformat = pa.dataset.ParquetFileFormat() + write_options = pformat.make_write_options( + encryption_config=parquet_encryption_cfg + ) + + mockfs = fs._MockFileSystem() + mockfs.create_dir("/") + + ds.write_dataset( + data=table, + base_dir="sample_dataset", + format=pformat, + file_options=write_options, + filesystem=mockfs, + ) + + # read without decryption config -> errors if dataset was properly encrypted + pformat = pa.dataset.ParquetFileFormat() + with pytest.raises(IOError, match=r"no decryption"): + ds.dataset("sample_dataset", format=pformat, filesystem=mockfs) + + # set decryption config for parquet fragment scan options + pq_scan_opts = ds.ParquetFragmentScanOptions( + decryption_config=parquet_decryption_cfg + ) + pformat = pa.dataset.ParquetFileFormat( + default_fragment_scan_options=pq_scan_opts + ) + with cond_raises(dataset_success, ValueError, match="Unknown master key"): + dataset = ds.dataset("sample_dataset", format=pformat, filesystem=mockfs) + with cond_raises(to_table_success, ValueError, match="Unknown master key"): + assert table.select(read_columns).equals(dataset.to_table(read_columns)) + + # set decryption properties for parquet fragment scan options + decryption_properties = crypto_factory.file_decryption_properties( + decrypt_kms_connection_config, decryption_config) + pq_scan_opts = ds.ParquetFragmentScanOptions( + decryption_properties=decryption_properties + ) + + pformat = pa.dataset.ParquetFileFormat( + default_fragment_scan_options=pq_scan_opts + ) + with cond_raises(dataset_success, ValueError, match="Unknown master key"): + dataset = ds.dataset("sample_dataset", format=pformat, filesystem=mockfs) + with cond_raises(to_table_success, ValueError, match="Unknown master key"): + assert table.select(read_columns).equals(dataset.to_table(read_columns)) + # some notable column names and keys all_column_names = table.column_names encrypted_column_names = [column_name.split(".")[0] @@ -118,120 +200,55 @@ def do_test_dataset_encryption_decryption(table, extra_column_path=None): plaintext_column_names = [column_name for column_name in all_column_names if column_name not in encrypted_column_names and - (extra_column_path is None or not extra_column_path.startswith(f"{column_name}."))] + (extra_column_path is None or + not extra_column_path.startswith(f"{column_name}."))] assert len(encrypted_column_names) > 0 assert len(plaintext_column_names) > 0 footer_key_only = {FOOTER_KEY_NAME: FOOTER_KEY} - column_keys_only = {key_name: key for key_name, key in keys.items() if key_name != FOOTER_KEY_NAME} + column_keys_only = {key_name: key + for key_name, key in keys.items() + if key_name != FOOTER_KEY_NAME} + + # the test scenarios - # read with footer key only - assert_test_dataset_encryption_decryption(table, column_keys, keys, footer_key_only, plaintext_column_names, True) - assert_test_dataset_encryption_decryption(table, column_keys, keys, footer_key_only, encrypted_column_names, False) - assert_test_dataset_encryption_decryption(table, column_keys, keys, footer_key_only, all_column_names, False) + # read with footer key only, can only read plaintext columns + assert_decrypts(footer_key_only, plaintext_column_names, True) + assert_decrypts(footer_key_only, encrypted_column_names, False) + assert_decrypts(footer_key_only, all_column_names, False) - # read with all but footer key - assert_test_dataset_encryption_decryption(table, column_keys, keys, column_keys_only, plaintext_column_names, False, False) - assert_test_dataset_encryption_decryption(table, column_keys, keys, column_keys_only, encrypted_column_names, False, False) - assert_test_dataset_encryption_decryption(table, column_keys, keys, column_keys_only, all_column_names, False, False) + # read with all but footer key, cannot read any columns + assert_decrypts(column_keys_only, plaintext_column_names, False, False) + assert_decrypts(column_keys_only, encrypted_column_names, False, False) + assert_decrypts(column_keys_only, all_column_names, False, False) # with footer key and one column key, all plaintext and # those encrypted columns that use that key, can be read if len(column_keys) > 1: for column_key_name, column_key_column_names in column_keys.items(): for encrypted_column_name in column_key_column_names: + # if one nested field of a column is encrypted, + # the entire column is considered encrypted encrypted_column_name = encrypted_column_name.split(".")[0] - footer_key_and_one_column_key = {key_name: key for key_name, key in keys.items() - if key_name in [FOOTER_KEY_NAME, column_key_name]} - assert_test_dataset_encryption_decryption(table, column_keys, keys, footer_key_and_one_column_key, plaintext_column_names, - True) - assert_test_dataset_encryption_decryption(table, column_keys, keys, footer_key_and_one_column_key, plaintext_column_names + [encrypted_column_name], - encrypted_column_name != extra_column_name) - assert_test_dataset_encryption_decryption(table, column_keys, keys, footer_key_and_one_column_key, encrypted_column_names, - False) - assert_test_dataset_encryption_decryption(table, column_keys, keys, footer_key_and_one_column_key, all_column_names, False) - - # with all column keys, all columns can be read - assert_test_dataset_encryption_decryption(table, column_keys, keys, keys, plaintext_column_names, True) - assert_test_dataset_encryption_decryption(table, column_keys, keys, keys, encrypted_column_names, True) - assert_test_dataset_encryption_decryption(table, column_keys, keys, keys, all_column_names, True) - - -def assert_test_dataset_encryption_decryption( - table, - column_keys, - write_keys, - read_keys, - read_columns, - to_table_success, - dataset_success = True, -): - encryption_config = create_encryption_config(FOOTER_KEY_NAME, column_keys) - decryption_config = create_decryption_config() - encrypt_kms_connection_config = create_kms_connection_config(write_keys) - decrypt_kms_connection_config = create_kms_connection_config(read_keys) - - crypto_factory = pe.CryptoFactory(kms_factory) - parquet_encryption_cfg = ds.ParquetEncryptionConfig( - crypto_factory, encrypt_kms_connection_config, encryption_config - ) - parquet_decryption_cfg = ds.ParquetDecryptionConfig( - crypto_factory, decrypt_kms_connection_config, decryption_config - ) - - # create write_options with dataset encryption config - pformat = pa.dataset.ParquetFileFormat() - write_options = pformat.make_write_options(encryption_config=parquet_encryption_cfg) - mockfs = fs._MockFileSystem() - mockfs.create_dir("/") + # decrypt with footer key and one column key + read_keys = {key_name: key + for key_name, key in keys.items() + if key_name in [FOOTER_KEY_NAME, column_key_name]} - ds.write_dataset( - data=table, - base_dir="sample_dataset", - format=pformat, - file_options=write_options, - filesystem=mockfs, - ) + # that one encrypted column can only be read + # if it is not a column path / nested field + plaintext_and_one_success = encrypted_column_name != extra_column_name + plaintext_and_one = plaintext_column_names + [encrypted_column_name] - # read without decryption config -> should error if dataset was properly encrypted - pformat = pa.dataset.ParquetFileFormat() - with pytest.raises(IOError, match=r"no decryption"): - ds.dataset("sample_dataset", format=pformat, filesystem=mockfs) - - # set decryption config for parquet fragment scan options - pq_scan_opts = ds.ParquetFragmentScanOptions( - decryption_config=parquet_decryption_cfg - ) - pformat = pa.dataset.ParquetFileFormat(default_fragment_scan_options=pq_scan_opts) - if dataset_success: - dataset = ds.dataset("sample_dataset", format=pformat, filesystem=mockfs) - if to_table_success: - assert table.select(read_columns).equals(dataset.to_table(read_columns)) - else: - with pytest.raises(ValueError, match="Unknown master key"): - assert table.select(read_columns).equals(dataset.to_table(read_columns)) - else: - with pytest.raises(ValueError, match="Unknown master key"): - _ = ds.dataset("sample_dataset", format=pformat, filesystem=mockfs) - - # set decryption properties for parquet fragment scan options - decryption_properties = crypto_factory.file_decryption_properties( - decrypt_kms_connection_config, decryption_config) - pq_scan_opts = ds.ParquetFragmentScanOptions( - decryption_properties=decryption_properties - ) + assert_decrypts(read_keys, plaintext_column_names, True) + assert_decrypts(read_keys, plaintext_and_one, plaintext_and_one_success) + assert_decrypts(read_keys, encrypted_column_names, False) + assert_decrypts(read_keys, all_column_names, False) - pformat = pa.dataset.ParquetFileFormat(default_fragment_scan_options=pq_scan_opts) - if dataset_success: - dataset = ds.dataset("sample_dataset", format=pformat, filesystem=mockfs) - if to_table_success: - assert table.select(read_columns).equals(dataset.to_table(read_columns)) - else: - with pytest.raises(ValueError, match="Unknown master key"): - assert table.select(read_columns).equals(dataset.to_table(read_columns)) - else: - with pytest.raises(ValueError, match="Unknown master key"): - _ = ds.dataset("sample_dataset", format=pformat, filesystem=mockfs) + # with all column keys, all columns can be read + assert_decrypts(keys, plaintext_column_names, True) + assert_decrypts(keys, encrypted_column_names, True) + assert_decrypts(keys, all_column_names, True) @pytest.mark.skipif( @@ -260,7 +277,9 @@ def test_list_encryption_decryption(column_name): reason="Parquet Encryption is not currently enabled" ) @pytest.mark.parametrize( - "column_name", ["map", "map.key", "map.value", "map.key_value.key", "map.key_value.value"] + "column_name", [ + "map", "map.key", "map.value", "map.key_value.key", "map.key_value.value" + ] ) def test_map_encryption_decryption(column_name): map_type = pa.map_(pa.string(), pa.int32()) @@ -280,7 +299,7 @@ def test_map_encryption_decryption(column_name): encryption_unavailable, reason="Parquet Encryption is not currently enabled" ) @pytest.mark.parametrize( - "column_name", [ "struct", "struct.f1", "struct.f2"] + "column_name", ["struct", "struct.f1", "struct.f2"] ) def test_struct_encryption_decryption(column_name): struct_fields = [("f1", pa.int32()), ("f2", pa.string())] From 04409cca06d0ce4fbd2228416b121f30ef7efaf5 Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Thu, 6 Feb 2025 06:57:32 +0100 Subject: [PATCH 6/9] Provide Node's schema path --- cpp/src/parquet/schema.cc | 21 +++++++++++++++++++++ cpp/src/parquet/schema.h | 5 +++++ cpp/src/parquet/schema_test.cc | 27 +++++++++++++++++++++++++++ 3 files changed, 53 insertions(+) diff --git a/cpp/src/parquet/schema.cc b/cpp/src/parquet/schema.cc index 47fa72d829658..c740780dccd88 100644 --- a/cpp/src/parquet/schema.cc +++ b/cpp/src/parquet/schema.cc @@ -67,11 +67,24 @@ std::shared_ptr ColumnPath::FromDotString(const std::string& dotstri } std::shared_ptr ColumnPath::FromNode(const Node& node) { + return FromNode(node, false); +} + +std::shared_ptr ColumnPath::FromNode(const Node& node, bool schema_path) { // Build the path in reverse order as we traverse the nodes to the top std::vector rpath_; const Node* cursor = &node; // The schema node is not part of the ColumnPath while (cursor->parent()) { + if (schema_path && ( + // nested fields in arrow schema do not know these intermediate nodes + cursor->parent()->converted_type() == ConvertedType::MAP || + cursor->parent()->converted_type() == ConvertedType::LIST + ) + ) { + cursor = cursor->parent(); + continue; + } rpath_.push_back(cursor->name()); cursor = cursor->parent(); } @@ -113,6 +126,10 @@ const std::shared_ptr Node::path() const { return ColumnPath::FromNode(*this); } +const std::shared_ptr Node::schema_path() const { + return ColumnPath::FromNode(*this, true); +} + bool Node::EqualsInternal(const Node* other) const { return type_ == other->type_ && name_ == other->name_ && repetition_ == other->repetition_ && converted_type_ == other->converted_type_ && @@ -960,4 +977,8 @@ const std::shared_ptr ColumnDescriptor::path() const { return primitive_node_->path(); } +const std::shared_ptr ColumnDescriptor::schema_path() const { + return primitive_node_->schema_path(); +} + } // namespace parquet diff --git a/cpp/src/parquet/schema.h b/cpp/src/parquet/schema.h index 1addc73bd367d..08881728e1f6f 100644 --- a/cpp/src/parquet/schema.h +++ b/cpp/src/parquet/schema.h @@ -84,6 +84,7 @@ class PARQUET_EXPORT ColumnPath { static std::shared_ptr FromDotString(const std::string& dotstring); static std::shared_ptr FromNode(const Node& node); + static std::shared_ptr FromNode(const Node& node, bool filter_converted_types); std::shared_ptr extend(const std::string& node_name) const; std::string ToDotString() const; @@ -132,6 +133,8 @@ class PARQUET_EXPORT Node { const std::shared_ptr path() const; + const std::shared_ptr schema_path() const; + virtual void ToParquet(void* element) const = 0; // Node::Visitor abstract class for walking schemas with the visitor pattern @@ -386,6 +389,8 @@ class PARQUET_EXPORT ColumnDescriptor { const std::shared_ptr path() const; + const std::shared_ptr schema_path() const; + const schema::NodePtr& schema_node() const { return node_; } std::string ToString() const; diff --git a/cpp/src/parquet/schema_test.cc b/cpp/src/parquet/schema_test.cc index 2532a8656e69f..fa6167b7132c0 100644 --- a/cpp/src/parquet/schema_test.cc +++ b/cpp/src/parquet/schema_test.cc @@ -110,6 +110,33 @@ TEST(TestColumnPath, TestAttrs) { ASSERT_EQ(extended->ToDotString(), "toplevel.leaf.anotherlevel"); } +TEST(TestColumnPath, FromNode) { + auto key = PrimitiveNode::Make("key", Repetition::REQUIRED, Type::INT32, ConvertedType::INT_32); + auto key_value = GroupNode::Make("key_value", Repetition::REQUIRED, {key}, ConvertedType::NONE); + auto map = GroupNode::Make("a", Repetition::REQUIRED, {key_value}, ConvertedType::MAP); + + auto element = PrimitiveNode::Make("element", Repetition::REPEATED, Type::INT32, ConvertedType::INT_32); + auto inner_list = GroupNode::Make("list", Repetition::REQUIRED, {element}, ConvertedType::NONE); + auto list = GroupNode::Make("b", Repetition::REQUIRED, {inner_list}, ConvertedType::LIST); + + auto f1 = PrimitiveNode::Make("f1", Repetition::OPTIONAL, Type::INT32, ConvertedType::INT_32); + auto f2 = PrimitiveNode::Make("f2", Repetition::OPTIONAL, Type::BYTE_ARRAY, ConvertedType::UTF8); + auto struct_ = GroupNode::Make("c", Repetition::REQUIRED, {f1, f2}, ConvertedType::NONE); + + auto schema = GroupNode::Make("schema", Repetition::REQUIRED, {map, list, struct_}, ConvertedType::NONE); + + ASSERT_EQ(ColumnPath::FromNode(*key)->ToDotString(), "a.key_value.key"); + ASSERT_EQ(ColumnPath::FromNode(*key, true)->ToDotString(), "a.key"); + + ASSERT_EQ(ColumnPath::FromNode(*element)->ToDotString(), "b.list.element"); + ASSERT_EQ(ColumnPath::FromNode(*element, true)->ToDotString(), "b.element"); + + ASSERT_EQ(ColumnPath::FromNode(*f1)->ToDotString(), "c.f1"); + ASSERT_EQ(ColumnPath::FromNode(*f1, true)->ToDotString(), "c.f1"); + ASSERT_EQ(ColumnPath::FromNode(*f2)->ToDotString(), "c.f2"); + ASSERT_EQ(ColumnPath::FromNode(*f2, true)->ToDotString(), "c.f2"); +} + // ---------------------------------------------------------------------- // Primitive node From b7cdf8496d7d2b477ecbba4a6f809a6ab255467b Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Thu, 6 Feb 2025 15:59:50 +0100 Subject: [PATCH 7/9] Move encrypted column check into FileEncryptionProperties::encrypt_schema --- cpp/src/parquet/encryption/encryption.cc | 23 +++++++++++++++++++++++ cpp/src/parquet/encryption/encryption.h | 10 ++++++++++ cpp/src/parquet/file_writer.cc | 23 +++-------------------- 3 files changed, 36 insertions(+), 20 deletions(-) diff --git a/cpp/src/parquet/encryption/encryption.cc b/cpp/src/parquet/encryption/encryption.cc index 731120d9a6396..64d105d874b43 100644 --- a/cpp/src/parquet/encryption/encryption.cc +++ b/cpp/src/parquet/encryption/encryption.cc @@ -217,6 +217,29 @@ FileEncryptionProperties::Builder* FileEncryptionProperties::Builder::encrypted_ return this; } +void FileEncryptionProperties::encrypt_schema(const SchemaDescriptor& schema) { + // Check that all columns in columnEncryptionProperties exist in the schema. + auto encrypted_columns = encrypted_columns_; + // if columnEncryptionProperties is empty, every column in file schema will be + // encrypted with footer key. + if (encrypted_columns.size() != 0) { + std::vector column_path_vec; + // First, save all column paths in schema. + for (int i = 0; i < schema.num_columns(); i++) { + column_path_vec.push_back(schema.Column(i)->path()->ToDotString()); + } + // Check if column exists in schema. + for (const auto& elem : encrypted_columns) { + auto it = std::find(column_path_vec.begin(), column_path_vec.end(), elem.first); + if (it == column_path_vec.end()) { + std::stringstream ss; + ss << "Encrypted column " + elem.first + " not in file schema"; + throw ParquetException(ss.str()); + } + } + } +} + void FileEncryptionProperties::WipeOutEncryptionKeys() { footer_key_.clear(); for (const auto& element : encrypted_columns_) { diff --git a/cpp/src/parquet/encryption/encryption.h b/cpp/src/parquet/encryption/encryption.h index 1ddef9e8236db..99d8aca7e6811 100644 --- a/cpp/src/parquet/encryption/encryption.h +++ b/cpp/src/parquet/encryption/encryption.h @@ -498,6 +498,16 @@ class PARQUET_EXPORT FileEncryptionProperties { return encrypted_columns_; } + /// All columns in encrypted_columns must refer to columns in the given schema. + /// They can also refer to parent fields if schema contains nested fields. Then + /// all those nested fields of a matching parent field are encrypted by the same key. + /// This modifies encrypted_columns to reflect this. + /// + /// Columns in encrypted_columns can refer to the parquet column paths as well as the + /// schema paths of columns. Those are usually identical, except for nested fields of + /// lists and maps. + void encrypt_schema(const SchemaDescriptor& schema); + private: EncryptionAlgorithm algorithm_; std::string footer_key_; diff --git a/cpp/src/parquet/file_writer.cc b/cpp/src/parquet/file_writer.cc index f80a095a13587..16b582585d9ee 100644 --- a/cpp/src/parquet/file_writer.cc +++ b/cpp/src/parquet/file_writer.cc @@ -481,26 +481,9 @@ class FileSerializer : public ParquetFileWriter::Contents { // Unencrypted parquet files always start with PAR1 PARQUET_THROW_NOT_OK(sink_->Write(kParquetMagic, 4)); } else { - // Check that all columns in columnEncryptionProperties exist in the schema. - auto encrypted_columns = file_encryption_properties->encrypted_columns(); - // if columnEncryptionProperties is empty, every column in file schema will be - // encrypted with footer key. - if (encrypted_columns.size() != 0) { - std::vector column_path_vec; - // First, save all column paths in schema. - for (int i = 0; i < num_columns(); i++) { - column_path_vec.push_back(schema_.Column(i)->path()->ToDotString()); - } - // Check if column exists in schema. - for (const auto& elem : encrypted_columns) { - auto it = std::find(column_path_vec.begin(), column_path_vec.end(), elem.first); - if (it == column_path_vec.end()) { - std::stringstream ss; - ss << "Encrypted column " + elem.first + " not in file schema"; - throw ParquetException(ss.str()); - } - } - } + // make the file encryption encrypt this schema + // this modifies file_encryption_properties->encrypted_columns() + file_encryption_properties->encrypt_schema(schema_); file_encryptor_ = std::make_unique( file_encryption_properties, properties_->memory_pool()); From 2b1f73234aa1eac2b4afae921316fbf2166fc950 Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Thu, 6 Feb 2025 16:00:28 +0100 Subject: [PATCH 8/9] Improve resolving encrypted columns --- cpp/src/parquet/encryption/encryption.cc | 57 ++++++++++++++++++++---- 1 file changed, 49 insertions(+), 8 deletions(-) diff --git a/cpp/src/parquet/encryption/encryption.cc b/cpp/src/parquet/encryption/encryption.cc index 64d105d874b43..acf675156fbff 100644 --- a/cpp/src/parquet/encryption/encryption.cc +++ b/cpp/src/parquet/encryption/encryption.cc @@ -219,21 +219,62 @@ FileEncryptionProperties::Builder* FileEncryptionProperties::Builder::encrypted_ void FileEncryptionProperties::encrypt_schema(const SchemaDescriptor& schema) { // Check that all columns in columnEncryptionProperties exist in the schema. - auto encrypted_columns = encrypted_columns_; + // Copy the encrypted_columns map as we are going to modify it while iterating it + auto encrypted_columns = ColumnPathToEncryptionPropertiesMap(encrypted_columns_); // if columnEncryptionProperties is empty, every column in file schema will be // encrypted with footer key. if (encrypted_columns.size() != 0) { - std::vector column_path_vec; - // First, save all column paths in schema. + std::vector> column_path_vec; + // First, memorize all column or schema paths of the schema as dot-strings. for (int i = 0; i < schema.num_columns(); i++) { - column_path_vec.push_back(schema.Column(i)->path()->ToDotString()); + auto column = schema.Column(i); + auto column_path = column->path()->ToDotString(); + auto schema_path = column->schema_path()->ToDotString(); + column_path_vec.emplace_back(column_path, column_path); + if (schema_path != column_path) { + column_path_vec.emplace_back(schema_path, column_path); + } } - // Check if column exists in schema. + // Sort them alphabetically, so that we can use binary-search and look up parent columns. + std::sort(column_path_vec.begin(), column_path_vec.end()); + + // Check if encrypted column exists in schema, or if it is a parent field of a column. for (const auto& elem : encrypted_columns) { - auto it = std::find(column_path_vec.begin(), column_path_vec.end(), elem.first); - if (it == column_path_vec.end()) { + auto& encrypted_column = elem.first; + auto encrypted_column_len = encrypted_column.size(); + + // first we look up encrypted_columns as + // find first column that equals encrypted_column or starts with encrypted_column + auto it = std::lower_bound( + column_path_vec.begin(), column_path_vec.end(), encrypted_column, + [&](const std::pair& item, const std::string& term) { + return item.first < term; + }); + bool matches = false; + + // encrypted_column encrypts column 'it' when 'it' is either equal to encrypted_column, + // or 'it' starts with encrypted_column followed by a '.' + while (it != column_path_vec.end() && (it->first == encrypted_column || + (it->first.size() > encrypted_column_len && it->first.substr(0, encrypted_column_len) == encrypted_column && it->first.at(encrypted_column_len) == '.') + )) { + // count columns encrypted by encrypted_column + matches = true; + + // add column 'it' to file_encryption_properties.encrypted_columns + // when encrypted_column is a parent column + if (it->second != encrypted_column) { + encrypted_columns_.erase(encrypted_column); + encrypted_columns_.emplace(it->second, elem.second); + } + + // move to next match + ++it; + } + + // check encrypted_column matches any existing column + if (!matches) { std::stringstream ss; - ss << "Encrypted column " + elem.first + " not in file schema"; + ss << "Encrypted column " + encrypted_column + " not in file schema"; throw ParquetException(ss.str()); } } From ba57bdc5944da64bc486536298d8326afcc26a42 Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Thu, 6 Feb 2025 14:51:45 +0100 Subject: [PATCH 9/9] Test encrypted column check --- .../encryption/write_configurations_test.cc | 113 ++++++++++++++++++ 1 file changed, 113 insertions(+) diff --git a/cpp/src/parquet/encryption/write_configurations_test.cc b/cpp/src/parquet/encryption/write_configurations_test.cc index f27da82694874..3b6f79d254cb0 100644 --- a/cpp/src/parquet/encryption/write_configurations_test.cc +++ b/cpp/src/parquet/encryption/write_configurations_test.cc @@ -223,6 +223,119 @@ TEST_F(TestEncryptionConfiguration, EncryptTwoColumnsAndFooterUseAES_GCM_CTR) { "tmp_encrypt_columns_and_footer_ctr.parquet.encrypted")); } +TEST(TestFileEncryptionProperties, EncryptSchema) { + std::string kFooterEncryptionKey_ = std::string(kFooterEncryptionKey); + std::string kColumnEncryptionKey_ = std::string(kColumnEncryptionKey1); + + std::map> + encryption_cols; + parquet::ColumnEncryptionProperties::Builder encryption_col_builder_21( + "a_map"); + parquet::ColumnEncryptionProperties::Builder encryption_col_builder_22( + "a_list"); + parquet::ColumnEncryptionProperties::Builder encryption_col_builder_23( + "a_struct"); + parquet::ColumnEncryptionProperties::Builder encryption_col_builder_24( + "b_map.key"); + parquet::ColumnEncryptionProperties::Builder encryption_col_builder_25( + "b_map.key_value.value"); + parquet::ColumnEncryptionProperties::Builder encryption_col_builder_26( + "b_list.list.element"); + parquet::ColumnEncryptionProperties::Builder encryption_col_builder_27( + "b_struct.f1"); + parquet::ColumnEncryptionProperties::Builder encryption_col_builder_28( + "c_list.element"); + + encryption_col_builder_21.key(kColumnEncryptionKey_)->key_id("kc1"); + encryption_col_builder_22.key(kColumnEncryptionKey_)->key_id("kc1"); + encryption_col_builder_23.key(kColumnEncryptionKey_)->key_id("kc1"); + encryption_col_builder_24.key(kColumnEncryptionKey_)->key_id("kc1"); + encryption_col_builder_25.key(kColumnEncryptionKey_)->key_id("kc1"); + encryption_col_builder_26.key(kColumnEncryptionKey_)->key_id("kc1"); + encryption_col_builder_27.key(kColumnEncryptionKey_)->key_id("kc1"); + encryption_col_builder_28.key(kColumnEncryptionKey_)->key_id("kc1"); + + encryption_cols["a_map"] = encryption_col_builder_21.build(); + encryption_cols["a_list"] = encryption_col_builder_22.build(); + encryption_cols["a_struct"] = encryption_col_builder_23.build(); + encryption_cols["b_map.key"] = encryption_col_builder_24.build(); + encryption_cols["b_map.key_value.value"] = encryption_col_builder_25.build(); + encryption_cols["b_list.list.element"] = encryption_col_builder_26.build(); + encryption_cols["b_struct.f1"] = encryption_col_builder_27.build(); + encryption_cols["c_list.element"] = encryption_col_builder_28.build(); + + parquet::FileEncryptionProperties::Builder file_encryption_builder(kFooterEncryptionKey_); + file_encryption_builder.encrypted_columns(encryption_cols); + auto encryption_configurations = file_encryption_builder.build(); + + auto a_key = parquet::schema::PrimitiveNode::Make("key", Repetition::REQUIRED, Type::INT32, ConvertedType::INT_32); + auto a_value = parquet::schema::PrimitiveNode::Make("value", Repetition::OPTIONAL, Type::BYTE_ARRAY, ConvertedType::UTF8); + auto a_key_value = parquet::schema::GroupNode::Make("key_value", Repetition::REPEATED, {a_key, a_value}, ConvertedType::NONE); + auto a_map = parquet::schema::GroupNode::Make("a_map", Repetition::OPTIONAL, {a_key_value}, ConvertedType::MAP); + + auto a_list_elem = parquet::schema::PrimitiveNode::Make("element", Repetition::OPTIONAL, Type::INT32, ConvertedType::INT_32); + auto a_list_list = parquet::schema::GroupNode::Make("list", Repetition::REPEATED, {a_list_elem}, ConvertedType::NONE); + auto a_list = parquet::schema::GroupNode::Make("a_list", Repetition::OPTIONAL, {a_list_list}, ConvertedType::LIST); + + auto a_struct_f1 = parquet::schema::PrimitiveNode::Make("f1", Repetition::OPTIONAL, Type::INT32, ConvertedType::INT_32); + auto a_struct_f2 = parquet::schema::PrimitiveNode::Make("f2", Repetition::OPTIONAL, Type::INT64, ConvertedType::INT_64); + auto a_struct = parquet::schema::GroupNode::Make("a_struct", Repetition::OPTIONAL, {a_struct_f1, a_struct_f2}, ConvertedType::NONE); + + auto b_key = parquet::schema::PrimitiveNode::Make("key", Repetition::REQUIRED, Type::INT32, ConvertedType::INT_32); + auto b_value = parquet::schema::PrimitiveNode::Make("value", Repetition::OPTIONAL, Type::BYTE_ARRAY, ConvertedType::UTF8); + auto b_key_value = parquet::schema::GroupNode::Make("key_value", Repetition::REPEATED, {b_key, b_value}, ConvertedType::NONE); + auto b_map = parquet::schema::GroupNode::Make("b_map", Repetition::OPTIONAL, {b_key_value}, ConvertedType::MAP); + + auto b_list_elem = parquet::schema::PrimitiveNode::Make("element", Repetition::OPTIONAL, Type::INT32, ConvertedType::INT_32); + auto b_list_list = parquet::schema::GroupNode::Make("list", Repetition::REPEATED, {b_list_elem}, ConvertedType::NONE); + auto b_list = parquet::schema::GroupNode::Make("b_list", Repetition::OPTIONAL, {b_list_list}, ConvertedType::LIST); + + auto b_struct_f1 = parquet::schema::PrimitiveNode::Make("f1", Repetition::OPTIONAL, Type::INT32, ConvertedType::INT_32); + auto b_struct_f2 = parquet::schema::PrimitiveNode::Make("f2", Repetition::OPTIONAL, Type::INT64, ConvertedType::INT_64); + auto b_struct = parquet::schema::GroupNode::Make("b_struct", Repetition::OPTIONAL, {b_struct_f1, b_struct_f2}, ConvertedType::NONE); + + auto c_list_elem = parquet::schema::PrimitiveNode::Make("element", Repetition::OPTIONAL, Type::INT32, ConvertedType::INT_32); + auto c_list_list = parquet::schema::GroupNode::Make("list", Repetition::REPEATED, {c_list_elem}, ConvertedType::NONE); + auto c_list = parquet::schema::GroupNode::Make("c_list", Repetition::OPTIONAL, {c_list_list}, ConvertedType::LIST); + + auto a_structs_f1 = parquet::schema::PrimitiveNode::Make("f1", Repetition::OPTIONAL, Type::INT32, ConvertedType::INT_32); + auto a_structs_f2 = parquet::schema::PrimitiveNode::Make("f2", Repetition::OPTIONAL, Type::INT64, ConvertedType::INT_64); + auto a_structs = parquet::schema::GroupNode::Make("a_structs", Repetition::OPTIONAL, {a_structs_f1, a_structs_f2}, ConvertedType::NONE); + + auto schema = parquet::schema::GroupNode::Make("schema", Repetition::REQUIRED, {a_map, a_list, a_struct, b_map, b_list, b_struct, c_list, a_structs}); + + SchemaDescriptor descr; + descr.Init(schema); + + // original configuration as set above + auto cols = encryption_configurations->encrypted_columns(); + ASSERT_EQ(cols.at("a_map")->column_path(), "a_map"); + ASSERT_EQ(cols.at("a_list")->column_path(), "a_list"); + ASSERT_EQ(cols.at("a_struct")->column_path(), "a_struct"); + ASSERT_EQ(cols.at("b_map.key")->column_path(), "b_map.key"); + ASSERT_EQ(cols.at("b_map.key_value.value")->column_path(), "b_map.key_value.value"); + ASSERT_EQ(cols.at("b_list.list.element")->column_path(), "b_list.list.element"); + ASSERT_EQ(cols.at("b_struct.f1")->column_path(), "b_struct.f1"); + ASSERT_EQ(cols.at("c_list.element")->column_path(), "c_list.element"); + ASSERT_EQ(cols.size(), 8); + + encryption_configurations->encrypt_schema(descr); + + // the updated configuration where parent fields have been replaced with all their leaf fields + cols = encryption_configurations->encrypted_columns(); + ASSERT_EQ(cols.at("a_map.key_value.key")->column_path(), "a_map"); + ASSERT_EQ(cols.at("a_map.key_value.value")->column_path(), "a_map"); + ASSERT_EQ(cols.at("a_list.list.element")->column_path(), "a_list"); + ASSERT_EQ(cols.at("a_struct.f1")->column_path(), "a_struct"); + ASSERT_EQ(cols.at("a_struct.f2")->column_path(), "a_struct"); + ASSERT_EQ(cols.at("b_map.key_value.key")->column_path(), "b_map.key"); + ASSERT_EQ(cols.at("b_map.key_value.value")->column_path(), "b_map.key_value.value"); + ASSERT_EQ(cols.at("b_list.list.element")->column_path(), "b_list.list.element"); + ASSERT_EQ(cols.at("b_struct.f1")->column_path(), "b_struct.f1"); + ASSERT_EQ(cols.at("c_list.list.element")->column_path(), "c_list.element"); + ASSERT_EQ(cols.size(), 10); +} + // Set temp_dir before running the write/read tests. The encrypted files will // be written/read from this directory. void TestEncryptionConfiguration::SetUpTestCase() {