diff --git a/ctgan/data_transformer.py b/ctgan/data_transformer.py index 06a4c279..77080e51 100644 --- a/ctgan/data_transformer.py +++ b/ctgan/data_transformer.py @@ -46,7 +46,7 @@ def _fit_continuous(self, data): A ``ColumnTransformInfo`` object. """ column_name = data.columns[0] - gm = ClusterBasedNormalizer(model_missing_values=True, max_clusters=min(len(data), 10)) + gm = ClusterBasedNormalizer(model_missing_values=True, max_clusters=min(len(data), self._max_clusters)) gm.fit(data, column_name) num_components = sum(gm.valid_component_indicator)