Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ def set_table_parameters(self, table_name, table_parameters):
self._table_synthesizers[table_name] = self._synthesizer(
metadata=table_metadata, **table_parameters
)
self._table_synthesizers[table_name]._data_processor.table_name = table_name
self._table_parameters[table_name].update(deepcopy(table_parameters))

def _validate_all_tables(self, data):
Expand Down
49 changes: 37 additions & 12 deletions sdv/multi_table/hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,16 @@ def _get_num_data_columns(metadata):
columns_per_table = {}
for table_name, table in metadata.tables.items():
key_columns = metadata._get_all_keys(table_name)
columns_per_table[table_name] = sum([
num_data_columns = sum([
1
for col_name, col_meta in table.columns.items()
if (
col_meta['sdtype'] != 'id'
or (col_name not in key_columns and col_meta.get('pii', False) is False)
)
])
num_extended_columns = 0
columns_per_table[table_name] = [num_data_columns, num_extended_columns]

return columns_per_table

Expand All @@ -85,18 +87,29 @@ def _get_num_extended_columns(
table_name, cls.DEFAULT_SYNTHESIZER_KWARGS['default_distribution']
)

num_parameters = cls.DISTRIBUTIONS_TO_NUM_PARAMETER_COLUMNS[distribution]

num_params_data = cls.DISTRIBUTIONS_TO_NUM_PARAMETER_COLUMNS[distribution]
num_params_extended = cls.DISTRIBUTIONS_TO_NUM_PARAMETER_COLUMNS[
DEFAULT_EXTENDED_COLUMNS_DISTRIBUTION
]
num_rows_columns = len(metadata._get_foreign_keys(parent_table, table_name))

# no parameter columns are generated if there are no data columns
num_data_columns = columns_per_table[table_name]
if num_data_columns == 0:
# no parameter columns are generated if there are no data or extended columns
num_data_columns = columns_per_table[table_name][0]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

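Does it seem reasonable to narrow the meaning of `num_data_columns` so that it now counts only a table's own data columns (a subset of what it counted before), with the extended columns contributed by child tables tracked separately in `num_extended_columns`?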

num_extended_columns = columns_per_table[table_name][1]

if (num_data_columns + num_extended_columns) == 0:
return num_rows_columns

num_parameters_columns = num_rows_columns * num_data_columns * num_parameters
num_parameters_columns = (num_rows_columns * num_data_columns * num_params_data) + (
num_rows_columns * num_extended_columns * num_params_extended
)

num_correlation_columns = num_rows_columns * (num_data_columns - 1) * num_data_columns // 2
num_correlation_columns = (
num_rows_columns
* (num_data_columns + num_extended_columns - 1)
* (num_data_columns + num_extended_columns)
// 2
)

return num_correlation_columns + num_rows_columns + num_parameters_columns

Expand All @@ -118,9 +131,11 @@ def _estimate_columns_traversal(
"""
for child_name in metadata._get_child_map()[table_name]:
if child_name not in visited:
cls._estimate_columns_traversal(metadata, child_name, columns_per_table, visited)
cls._estimate_columns_traversal(
metadata, child_name, columns_per_table, visited, distributions
)

columns_per_table[table_name] += cls._get_num_extended_columns(
columns_per_table[table_name][1] += cls._get_num_extended_columns(
metadata, child_name, table_name, columns_per_table, distributions
)

Expand Down Expand Up @@ -157,7 +172,9 @@ def _estimate_num_columns(cls, metadata, distributions=None):
metadata, table_name, columns_per_table, visited, distributions
)

return columns_per_table
return {
table_name: sum(columns_list) for table_name, columns_list in columns_per_table.items()
}

def __init__(self, metadata, locales=['en_US'], verbose=True):
BaseMultiTableSynthesizer.__init__(self, metadata, locales=locales)
Expand All @@ -173,6 +190,11 @@ def __init__(self, metadata, locales=['en_US'], verbose=True):
BaseHierarchicalSampler.__init__(
self, self.metadata, self._table_synthesizers, self._table_sizes
)
child_tables = set()
for relationship in metadata.relationships:
child_tables.add(relationship['child_table_name'])
for child_table_name in child_tables:
self.set_table_parameters(child_table_name, {'default_distribution': 'norm'})
self._print_estimate_warning()

def set_table_parameters(self, table_name, table_parameters):
Expand Down Expand Up @@ -238,7 +260,7 @@ def _print_estimate_warning(self):
for table, est_cols in self._estimate_num_columns(self.metadata, distributions).items():
entry = []
entry.append(table)
entry.append(metadata_columns[table])
entry.append(sum(metadata_columns[table]))
total_est_cols += est_cols
entry.append(est_cols)
print_table.append(entry)
Expand Down Expand Up @@ -679,6 +701,9 @@ def _get_likelihoods(self, table_rows, parent_rows, table_name, foreign_key):
parameters = self._extract_parameters(row, table_name, foreign_key)
table_meta = self._table_synthesizers[table_name].get_metadata()
synthesizer = self._synthesizer(table_meta, **self._table_parameters[table_name])
extended_columns = getattr(self, '_parent_extended_columns', {}).get(table_name, [])
if extended_columns:
self._set_extended_columns_distributions(synthesizer, table_name, extended_columns)
synthesizer._set_parameters(parameters)
try:
likelihoods[parent_id] = synthesizer._get_likelihood(table_rows)
Expand Down
3 changes: 2 additions & 1 deletion tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -2610,9 +2610,10 @@ def test__estimate_num_columns_to_be_modeled_various_sdtypes():
})
synthesizer = HMASynthesizer(metadata)
synthesizer._finalize = Mock(return_value=data)
distributions = synthesizer._get_distributions()

# Run estimation
estimated_num_columns = synthesizer._estimate_num_columns(metadata)
estimated_num_columns = synthesizer._estimate_num_columns(metadata, distributions)

# Run actual modeling
synthesizer.fit(data)
Expand Down
8 changes: 4 additions & 4 deletions tests/integration/utils/test_poc.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,12 @@ def test_simplify_schema(capsys):
# Assert
expected_message_before = re.compile(
r'PerformanceAlert: Using the HMASynthesizer on this metadata schema is not recommended\.'
r' To model this data, HMA will generate a large number of columns\. \(173818 columns\)\s+'
r' To model this data, HMA will generate a large number of columns\. \(135934 columns\)\s+'
r'Table Name\s*#\s*Columns in Metadata\s*Est # Columns\s*'
r'match_stats\s*24\s*24\s*'
r'matches\s*39\s*412\s*'
r'players\s*5\s*378\s*'
r'teams\s*1\s*173004\s*'
r'matches\s*39\s*364\s*'
r'players\s*5\s*330\s*'
r'teams\s*1\s*135216\s*'
r'We recommend simplifying your metadata schema using '
r"'sdv.utils.poc.simplify_schema'\.\s*"
r'If this is not possible, please visit '
Expand Down
19 changes: 8 additions & 11 deletions tests/unit/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ def test___init__(self):
assert isinstance(instance._table_synthesizers['oseba'], GaussianCopulaSynthesizer)
assert isinstance(instance._table_synthesizers['upravna_enota'], GaussianCopulaSynthesizer)
assert instance._table_parameters == {
'nesreca': {'default_distribution': 'beta'},
'oseba': {'default_distribution': 'beta'},
'nesreca': {'default_distribution': 'norm'},
'oseba': {'default_distribution': 'norm'},
'upravna_enota': {'default_distribution': 'beta'},
}
instance.metadata.validate.assert_called_once_with()
Expand Down Expand Up @@ -70,8 +70,6 @@ def test__get_extension(self):

# Assert
expected = pd.DataFrame({
'__nesreca__upravna_enota__univariates__id_nesreca__a': [1.0, 1.0, 1.0, 1.0],
'__nesreca__upravna_enota__univariates__id_nesreca__b': [1.0, 1.0, 1.0, 1.0],
'__nesreca__upravna_enota__univariates__id_nesreca__loc': [0.0, 1.0, 2.0, 3.0],
'__nesreca__upravna_enota__univariates__id_nesreca__scale': [np.nan] * 4,
'__nesreca__upravna_enota__num_rows': [1.0, 1.0, 1.0, 1.0],
Expand Down Expand Up @@ -187,12 +185,8 @@ def test__augment_table(self):
'nesreca_val': [0, 1, 2, 3],
'value': [0, 1, 2, 3],
'__oseba__id_nesreca__correlation__0__0': [0.0] * 4,
'__oseba__id_nesreca__univariates__oseba_val__a': [1.0] * 4,
'__oseba__id_nesreca__univariates__oseba_val__b': [1.0] * 4,
'__oseba__id_nesreca__univariates__oseba_val__loc': [0.0, 1.0, 2.0, 3.0],
'__oseba__id_nesreca__univariates__oseba_val__scale': [1e-6] * 4,
'__oseba__id_nesreca__univariates__oseba_value__a': [1.0] * 4,
'__oseba__id_nesreca__univariates__oseba_value__b': [1.0] * 4,
'__oseba__id_nesreca__univariates__oseba_value__loc': [0.0, 1.0, 2.0, 3.0],
'__oseba__id_nesreca__univariates__oseba_value__scale': [1e-6] * 4,
'__oseba__id_nesreca__num_rows': [1.0] * 4,
Expand Down Expand Up @@ -877,9 +871,10 @@ def test__estimate_num_columns_to_be_modeled_multiple_foreign_keys(self):
})
synthesizer = HMASynthesizer(metadata)
synthesizer._finalize = Mock(return_value=data)
distributions = synthesizer._get_distributions()

# Run estimation
estimated_num_columns = synthesizer._estimate_num_columns(metadata)
estimated_num_columns = synthesizer._estimate_num_columns(metadata, distributions)

# Run actual modeling
synthesizer.fit(data)
Expand Down Expand Up @@ -1152,9 +1147,10 @@ def test__estimate_num_columns_to_be_modeled(self):
})
synthesizer = HMASynthesizer(metadata)
synthesizer._finalize = Mock(return_value=data)
distributions = synthesizer._get_distributions()

# Run estimation
estimated_num_columns = synthesizer._estimate_num_columns(metadata)
estimated_num_columns = synthesizer._estimate_num_columns(metadata, distributions)

# Run actual modeling
synthesizer.fit(data)
Expand Down Expand Up @@ -1264,9 +1260,10 @@ def test__estimate_num_columns_to_be_modeled_various_sdtypes(self):
})
synthesizer = HMASynthesizer(metadata)
synthesizer._finalize = Mock(return_value=data)
distributions = synthesizer._get_distributions()

# Run estimation
estimated_num_columns = synthesizer._estimate_num_columns(metadata)
estimated_num_columns = synthesizer._estimate_num_columns(metadata, distributions)

# Run actual modeling
synthesizer.fit(data)
Expand Down