diff --git a/sdv/datasets/demo.py b/sdv/datasets/demo.py index 5142679d0..6d8c49a81 100644 --- a/sdv/datasets/demo.py +++ b/sdv/datasets/demo.py @@ -12,17 +12,19 @@ import boto3 import numpy as np import pandas as pd +import yaml from botocore import UNSIGNED from botocore.client import Config -from botocore.exceptions import ClientError +from sdv.errors import DemoResourceNotFoundError, DemoResourceNotFoundWarning from sdv.metadata.metadata import Metadata LOGGER = logging.getLogger(__name__) -BUCKET = 'sdv-demo-datasets' -BUCKET_URL = 'https://sdv-demo-datasets.s3.amazonaws.com' +BUCKET = 'sdv-datasets-public' +BUCKET_URL = f'https://{BUCKET}.s3.amazonaws.com' SIGNATURE_VERSION = UNSIGNED METADATA_FILENAME = 'metadata.json' +FALLBACK_ENCODING = 'latin-1' def _validate_modalities(modality): @@ -39,27 +41,147 @@ def _validate_output_folder(output_folder_name): ) +def _create_s3_client(): + """Create and return an S3 client with unsigned requests.""" + return boto3.client('s3', config=Config(signature_version=SIGNATURE_VERSION)) + + def _get_data_from_bucket(object_key): - session = boto3.Session() - s3 = session.client('s3', config=Config(signature_version=SIGNATURE_VERSION)) + s3 = _create_s3_client() response = s3.get_object(Bucket=BUCKET, Key=object_key) return response['Body'].read() -def _download(modality, dataset_name): - dataset_url = f'{BUCKET_URL}/{modality.upper()}/{dataset_name}.zip' - object_key = f'{modality.upper()}/{dataset_name}.zip' - LOGGER.info(f'Downloading dataset {dataset_name} from {dataset_url}') - try: - file_content = _get_data_from_bucket(object_key) - except ClientError: - raise ValueError( - f"Invalid dataset name '{dataset_name}'. " - 'Make sure you have the correct modality for the dataset name or ' - "use 'get_available_demos' to get a list of demo datasets." +def _list_objects(prefix): + """List all objects under a given prefix using pagination. + + Args: + prefix (str): + The S3 prefix to list. + + Returns: + list[dict]: + A list of object summaries. + """ + client = _create_s3_client() + contents = [] + paginator = client.get_paginator('list_objects_v2') + for resp in paginator.paginate(Bucket=BUCKET, Prefix=prefix): + contents.extend(resp.get('Contents', [])) + + if not contents: + raise DemoResourceNotFoundError(f"No objects found under '{prefix}' in bucket '{BUCKET}'.") + + return contents + + +def _search_contents_keys(contents, match_fn): + """Return list of keys from ``contents`` that satisfy ``match_fn``. + + Args: + contents (list[dict]): + S3 list_objects-like contents entries. + match_fn (callable): + Function that receives a key (str) and returns True if it matches. + + Returns: + list[str]: + Keys in their original order that matched the predicate. + """ + matches = [] + for entry in contents or []: + key = entry.get('Key', '') + try: + if match_fn(key): + matches.append(key) + except Exception: + continue + + return matches + + +def _find_data_zip_key(contents, dataset_prefix): + """Find the 'data.zip' object key under dataset prefix, case-insensitive. + + Args: + contents (list[dict]): + List of objects from S3. + dataset_prefix (str): + Prefix like 'single_table/dataset/'. + + Returns: + str: + The key to the data zip if found. 
+ """ + prefix_lower = dataset_prefix.lower() + + def is_data_zip(key): + return key.lower() == f'{prefix_lower}data.zip' + + matches = _search_contents_keys(contents, is_data_zip) + if matches: + return matches[0] + + raise DemoResourceNotFoundError("Could not find 'data.zip' for the requested dataset.") + + +def _get_first_v1_metadata_bytes(contents, dataset_prefix): + """Find and return bytes of the first V1 metadata JSON under `dataset_prefix`. + + Scans S3 listing `contents` and, for any JSON file directly under the dataset prefix, + downloads and returns its bytes if it contains METADATA_SPEC_VERSION == 'V1'. + + Returns: + bytes: + The bytes of the first V1 metadata JSON. + """ + prefix_lower = dataset_prefix.lower() + + def is_direct_json_under_prefix(key): + key_lower = key.lower() + return ( + key_lower.startswith(prefix_lower) + and key_lower.endswith('.json') + and 'metadata' in key_lower + and key_lower.count('/') == prefix_lower.count('/') ) - return io.BytesIO(file_content) + candidate_keys = _search_contents_keys(contents, is_direct_json_under_prefix) + + for key in candidate_keys: + try: + raw = _get_data_from_bucket(key) + metadict = json.loads(raw) + if isinstance(metadict, dict) and metadict.get('METADATA_SPEC_VERSION') == 'V1': + return raw + + except Exception: + continue + + raise DemoResourceNotFoundError( + 'Could not find a valid metadata JSON with METADATA_SPEC_VERSION "V1".' + ) + + +def _download(modality, dataset_name): + """Download dataset resources from a bucket. + + Returns: + tuple: + (BytesIO(zip_bytes), metadata_bytes) + """ + dataset_prefix = f'{modality}/{dataset_name}/' + LOGGER.info( + f"Downloading dataset '{dataset_name}' for modality '{modality}' from " + f'{BUCKET_URL}/{dataset_prefix}' + ) + contents = _list_objects(dataset_prefix) + + zip_key = _find_data_zip_key(contents, dataset_prefix) + zip_bytes = _get_data_from_bucket(zip_key) + metadata_bytes = _get_first_v1_metadata_bytes(contents, dataset_prefix) + + return io.BytesIO(zip_bytes), metadata_bytes def _extract_data(bytes_io, output_folder_name): @@ -67,13 +189,6 @@ def _extract_data(bytes_io, output_folder_name): if output_folder_name: os.makedirs(output_folder_name, exist_ok=True) zf.extractall(output_folder_name) - metadata_v0_filepath = os.path.join(output_folder_name, 'metadata_v0.json') - if os.path.isfile(metadata_v0_filepath): - os.remove(metadata_v0_filepath) - os.rename( - os.path.join(output_folder_name, 'metadata_v1.json'), - os.path.join(output_folder_name, METADATA_FILENAME), - ) else: in_memory_directory = {} @@ -83,20 +198,72 @@ def _extract_data(bytes_io, output_folder_name): return in_memory_directory -def _get_data(modality, output_folder_name, in_memory_directory): +def _get_data_with_output_folder(output_folder_name): + """Load CSV tables from an extracted folder on disk. + + Returns a tuple of (data_dict, skipped_files). + Non-CSV files are ignored. 
+ """ data = {} - if output_folder_name: - for filename in os.listdir(output_folder_name): - if filename.endswith('.csv'): - table_name = Path(filename).stem - data_path = os.path.join(output_folder_name, filename) + skipped_files = [] + for root, _dirs, files in os.walk(output_folder_name): + for filename in files: + if not filename.lower().endswith('.csv'): + skipped_files.append(filename) + continue + + table_name = Path(filename).stem + data_path = os.path.join(root, filename) + try: data[table_name] = pd.read_csv(data_path) + except UnicodeDecodeError: + data[table_name] = pd.read_csv(data_path, encoding=FALLBACK_ENCODING) + except Exception as e: + rel = os.path.relpath(data_path, output_folder_name) + skipped_files.append(f'{rel}: {e}') + + return data, skipped_files + + +def _get_data_without_output_folder(in_memory_directory): + """Load CSV tables directly from in-memory zip contents. + + Returns a tuple of (data_dict, skipped_files). + Non-CSV entries are ignored. + """ + data = {} + skipped_files = [] + for filename, file_ in in_memory_directory.items(): + if not filename.lower().endswith('.csv'): + skipped_files.append(filename) + continue + + table_name = Path(filename).stem + try: + data[table_name] = pd.read_csv(io.BytesIO(file_), low_memory=False) + except UnicodeDecodeError: + data[table_name] = pd.read_csv( + io.BytesIO(file_), low_memory=False, encoding=FALLBACK_ENCODING + ) + except Exception as e: + skipped_files.append(f'{filename}: {e}') + + return data, skipped_files + +def _get_data(modality, output_folder_name, in_memory_directory): + if output_folder_name: + data, skipped_files = _get_data_with_output_folder(output_folder_name) else: - for filename, file_ in in_memory_directory.items(): - if filename.endswith('.csv'): - table_name = Path(filename).stem - data[table_name] = pd.read_csv(io.StringIO(file_.decode()), low_memory=False) + data, skipped_files = _get_data_without_output_folder(in_memory_directory) + + if skipped_files: + warnings.warn('Skipped files: ' + ', '.join(sorted(skipped_files))) + + if not data: + raise DemoResourceNotFoundError( + 'Demo data could not be downloaded because no csv files were found in data.zip' + ) if modality != 'multi_table': data = data.popitem()[1] @@ -104,20 +271,41 @@ def _get_data(modality, output_folder_name, in_memory_directory): return data -def _get_metadata(output_folder_name, in_memory_directory, dataset_name): - metadata = Metadata() - if output_folder_name: - metadata_path = os.path.join(output_folder_name, METADATA_FILENAME) - metadata = metadata.load_from_json(metadata_path, dataset_name) +def _get_metadata(metadata_bytes, dataset_name, output_folder_name=None): + """Parse metadata bytes and optionally persist to ``output_folder_name``. - else: - metadata_path = 'metadata_v2.json' - if metadata_path not in in_memory_directory: - warnings.warn(f'Metadata for {dataset_name} is missing updated version v2.') - metadata_path = 'metadata_v1.json' + Args: + metadata_bytes (bytes): + Raw bytes of the metadata JSON file. + dataset_name (str): + The dataset name used when loading into ``Metadata``. + output_folder_name (str or None): + Optional folder path where to write ``metadata.json``. - metadict = json.loads(in_memory_directory[metadata_path]) - metadata = metadata.load_from_dict(metadict, dataset_name) + Returns: + Metadata: + Parsed metadata object. 
+    """
+    try:
+        metadict = json.loads(metadata_bytes)
+        metadata = Metadata().load_from_dict(metadict, dataset_name)
+    except Exception as e:
+        raise DemoResourceNotFoundError('Failed to parse metadata JSON for the dataset.') from e
+
+    if output_folder_name:
+        try:
+            metadata_path = os.path.join(str(output_folder_name), METADATA_FILENAME)
+            with open(metadata_path, 'wb') as f:
+                f.write(metadata_bytes)
+
+        except Exception:
+            warnings.warn(
+                f'Error saving {METADATA_FILENAME} for dataset {dataset_name} into '
+                f'{output_folder_name}.',
+                DemoResourceNotFoundWarning,
+            )
 
     return metadata
 
@@ -129,7 +317,7 @@ def download_demo(modality, dataset_name, output_folder_name=None):
         modality (str):
             The modality of the dataset: ``'single_table'``, ``'multi_table'``, ``'sequential'``.
         dataset_name (str):
-            Name of the dataset to be downloaded from the sdv-datasets S3 bucket.
+            Name of the dataset to be downloaded from the sdv-datasets-public S3 bucket.
         output_folder_name (str or None):
            The name of the local folder where the metadata and data should be stored.
            If ``None`` the data is not saved locally and is loaded as a Python object.
@@ -149,14 +337,75 @@ def download_demo(modality, dataset_name, output_folder_name=None):
     """
     _validate_modalities(modality)
    _validate_output_folder(output_folder_name)
-    bytes_io = _download(modality, dataset_name)
-    in_memory_directory = _extract_data(bytes_io, output_folder_name)
+
+    data_io, metadata_bytes = _download(modality, dataset_name)
+    in_memory_directory = _extract_data(data_io, output_folder_name)
     data = _get_data(modality, output_folder_name, in_memory_directory)
-    metadata = _get_metadata(output_folder_name, in_memory_directory, dataset_name)
+    metadata = _get_metadata(metadata_bytes, dataset_name, output_folder_name)
 
     return data, metadata
 
 
+def _iter_metainfo_yaml_entries(contents, modality):
+    """Yield (dataset_name, yaml_key) for metainfo.yaml files under a modality.
+
+    This matches keys like '<modality>/<dataset_name>/metainfo.yaml'.
+    """
+    modality_lower = (modality or '').lower()
+
+    def is_metainfo_yaml(key):
+        parts = key.split('/')
+        if len(parts) != 3:
+            return False
+        if parts[0].lower() != modality_lower:
+            return False
+        if parts[-1].lower() != 'metainfo.yaml':
+            return False
+        return bool(parts[1])
+
+    for key in _search_contents_keys(contents, is_metainfo_yaml):
+        dataset_name = key.split('/')[1]
+        yield dataset_name, key
+
+
+def _get_info_from_yaml_key(yaml_key):
+    """Load and parse YAML metadata from an S3 key."""
+    raw = _get_data_from_bucket(yaml_key)
+    return yaml.safe_load(raw) or {}
+
+
+def _parse_size_mb(size_mb_val, dataset_name):
+    """Parse the size (MB) value into a float or NaN with logging on failures."""
+    try:
+        return float(size_mb_val) if size_mb_val is not None else np.nan
+    except (ValueError, TypeError):
+        LOGGER.info(
+            f'Invalid dataset-size-mb {size_mb_val} for dataset {dataset_name}; defaulting to NaN.'
+        )
+        return np.nan
+
+
+def _parse_num_tables(num_tables_val, dataset_name):
+    """Parse the num-tables value into an int or NaN with logging on failures."""
+    if isinstance(num_tables_val, str):
+        try:
+            num_tables_val = float(num_tables_val)
+        except (ValueError, TypeError):
+            LOGGER.info(
+                f'Could not cast num_tables_val {num_tables_val} to float for '
+                f'dataset {dataset_name}; defaulting to NaN.'
+ ) + num_tables_val = np.nan + + try: + return int(num_tables_val) if not pd.isna(num_tables_val) else np.nan + except (ValueError, TypeError): + LOGGER.info( + f'Invalid num-tables {num_tables_val} for dataset {dataset_name} when parsing as int.' + ) + return np.nan + + def get_available_demos(modality): """Get demo datasets available for a ``modality``. @@ -170,23 +419,169 @@ def get_available_demos(modality): ``dataset_name``: The name of the dataset. ``size_MB``: The unzipped folder size in MB. ``num_tables``: The number of tables in the dataset. - - Raises: - Error: - * If ``modality`` is not ``'single_table'``, ``'multi_table'`` or ``'sequential'``. """ _validate_modalities(modality) - client = boto3.client('s3', config=Config(signature_version=SIGNATURE_VERSION)) + contents = _list_objects(f'{modality}/') tables_info = defaultdict(list) - for item in client.list_objects(Bucket=BUCKET)['Contents']: - dataset_modality, dataset = item['Key'].split('/', 1) - if dataset_modality == modality.upper(): - tables_info['dataset_name'].append(dataset.replace('.zip', '')) - headers = client.head_object(Bucket=BUCKET, Key=item['Key'])['Metadata'] - size_mb = headers.get('size-mb', np.nan) - tables_info['size_MB'].append(round(float(size_mb), 2)) - tables_info['num_tables'].append(headers.get('num-tables', np.nan)) - - df = pd.DataFrame(tables_info) - df['num_tables'] = pd.to_numeric(df['num_tables']) - return df + for dataset_name, yaml_key in _iter_metainfo_yaml_entries(contents, modality): + try: + info = _get_info_from_yaml_key(yaml_key) + + size_mb = _parse_size_mb(info.get('dataset-size-mb'), dataset_name) + num_tables = _parse_num_tables(info.get('num-tables', np.nan), dataset_name) + + tables_info['dataset_name'].append(dataset_name) + tables_info['size_MB'].append(size_mb) + tables_info['num_tables'].append(num_tables) + + except Exception: + continue + + return pd.DataFrame(tables_info) + + +def _find_text_key(contents, dataset_prefix, filename): + """Find a text file key (README.txt or SOURCE.txt). + + Performs a case-insensitive search for ``filename`` directly under ``dataset_prefix``. + + Args: + contents (list[dict]): + List of objects from S3. + dataset_prefix (str): + Prefix like 'single_table/dataset/'. + filename (str): + The filename to look for (e.g., 'README.txt'). + + Returns: + str or None: + The key if found, otherwise ``None``. + """ + expected_lower = f'{dataset_prefix}{filename}'.lower() + for entry in contents: + key = entry.get('Key') or '' + if key.lower() == expected_lower: + return key + + return None + + +def _validate_text_file_content(modality, output_filepath, filename): + """Validation for the text file content method.""" + _validate_modalities(modality) + if output_filepath is not None and not str(output_filepath).endswith('.txt'): + fname = (filename or '').lower() + file_type = 'README' if 'readme' in fname else 'source' + raise ValueError( + f'The {file_type} can only be saved as a txt file. ' + "Please provide a filepath ending in '.txt'" + ) + + +def _raise_warnings(filename, output_filepath): + """Warn about missing text resources for a dataset.""" + if (filename or '').upper() == 'README.TXT': + msg = 'No README information is available for this dataset.' + elif (filename or '').upper() == 'SOURCE.TXT': + msg = 'No source information is available for this dataset.' + else: + msg = f'No {filename} information is available for this dataset.' + + if output_filepath: + msg = f'{msg} The requested file ({output_filepath}) will not be created.' 
+ + warnings.warn(msg, DemoResourceNotFoundWarning) + + +def _save_document(text, output_filepath, filename, dataset_name): + """Persist ``text`` to ``output_filepath`` if provided.""" + if not output_filepath: + return + + if os.path.exists(str(output_filepath)): + raise ValueError( + f"A file named '{output_filepath}' already exists. Please specify a different filepath." + ) + + try: + parent = os.path.dirname(str(output_filepath)) + if parent: + os.makedirs(parent, exist_ok=True) + with open(output_filepath, 'w', encoding='utf-8') as f: + f.write(text) + except Exception: + LOGGER.info(f'Error saving {filename} for dataset {dataset_name}.') + + +def _get_text_file_content(modality, dataset_name, filename, output_filepath=None): + """Fetch text file content under the dataset prefix. + + Args: + modality (str): + The modality of the dataset: ``'single_table'``, ``'multi_table'``, ``'sequential'``. + dataset_name (str): + The name of the dataset. + filename (str): + The filename to fetch (``'README.txt'`` or ``'SOURCE.txt'``). + output_filepath (str or None): + If provided, save the file contents at this path. + + Returns: + str or None: + The decoded text contents if the file exists, otherwise ``None``. + """ + _validate_text_file_content(modality, output_filepath, filename) + + dataset_prefix = f'{modality}/{dataset_name}/' + contents = _list_objects(dataset_prefix) + key = _find_text_key(contents, dataset_prefix, filename) + if not key: + _raise_warnings(filename, output_filepath) + return None + + try: + raw = _get_data_from_bucket(key) + except Exception: + LOGGER.info(f'Error fetching {filename} for dataset {dataset_name}.') + return None + + text = raw.decode('utf-8', errors='replace') + _save_document(text, output_filepath, filename, dataset_name) + + return text + + +def get_source(modality, dataset_name, output_filepath=None): + """Get dataset source/citation text. + + Args: + modality (str): + The modality of the dataset: ``'single_table'``, ``'multi_table'``, ``'sequential'``. + dataset_name (str): + The name of the dataset to get the source information for. + output_filepath (str or None): + Optional path where to save the file. + + Returns: + str or None: + The contents of the source file if it exists; otherwise ``None``. + """ + return _get_text_file_content(modality, dataset_name, 'SOURCE.txt', output_filepath) + + +def get_readme(modality, dataset_name, output_filepath=None): + """Get dataset README text. + + Args: + modality (str): + The modality of the dataset: ``'single_table'``, ``'multi_table'``, ``'sequential'``. + dataset_name (str): + The name of the dataset to get the README for. + output_filepath (str or None): + Optional path where to save the file. + + Returns: + str or None: + The contents of the README file if it exists; otherwise ``None``. + """ + return _get_text_file_content(modality, dataset_name, 'README.txt', output_filepath) diff --git a/sdv/errors.py b/sdv/errors.py index a44c21e14..bb72271e1 100644 --- a/sdv/errors.py +++ b/sdv/errors.py @@ -95,3 +95,20 @@ class RefitWarning(UserWarning): class SynthesizerProcessingError(Exception): """Error to raise when synthesizer parameters are invalid.""" + + +class DemoResourceNotFoundError(Exception): + """Raised when a demo dataset or one of its resources cannot be found. + + This error is intended for missing demo assets such as the dataset archive, + metadata, license, README, or other auxiliary files in the demo bucket. 
+ """ + + +class DemoResourceNotFoundWarning(UserWarning): + """Warning raised when an optional demo resource is not available. + + This warning indicates that a non-critical artifact (e.g., README or SOURCE + information) is not present for a given demo dataset. The operation can + continue, but the requested information cannot be provided. + """ diff --git a/tests/integration/datasets/test_demo.py b/tests/integration/datasets/test_demo.py index 7fac34495..af2822536 100644 --- a/tests/integration/datasets/test_demo.py +++ b/tests/integration/datasets/test_demo.py @@ -1,220 +1,87 @@ import pandas as pd -from pandas.api.types import is_integer_dtype from sdv.datasets.demo import get_available_demos def test_get_available_demos_single_table(): - """Test it can get demos for single table.""" + """Test single_table demos listing equals the expected filtered list and values.""" # Run tables_info = get_available_demos('single_table') + mask = ~( + tables_info['dataset_name'].str.startswith('bad_') + | tables_info['dataset_name'].str.startswith('dataset') + ) + tables_info = tables_info[mask].reset_index(drop=True) # Assert - expected_table = pd.DataFrame({ + expected = pd.DataFrame({ 'dataset_name': [ 'adult', 'alarm', 'census', + 'census_extended', 'child', 'covtype', 'expedia_hotel_logs', + 'fake_companies', + 'fake_hotel_guests', 'insurance', 'intrusion', 'news', + 'student_placements', + 'student_placements_pii', ], 'size_MB': [ - '3.907448', - '4.520128', - '98.165608', - '3.200128', - '255.645408', - '0.200128', - '3.340128', - '162.039016', - '18.712096', + 3.91, + 4.52, + 98.17, + 4.95, + 3.20, + 255.65, + 0.20, + 0.00, + 0.03, + 3.34, + 162.04, + 18.71, + 0.03, + 0.03, + ], + 'num_tables': [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, ], - 'num_tables': ['1'] * 9, }) - expected_table['size_MB'] = expected_table['size_MB'].astype(float).round(2) - expected_table['num_tables'] = pd.to_numeric(expected_table['num_tables']) - assert is_integer_dtype(tables_info['num_tables']) - assert len(expected_table.merge(tables_info)) == len(expected_table) + pd.testing.assert_frame_equal(tables_info[['dataset_name', 'size_MB', 'num_tables']], expected) def test_get_available_demos_multi_table(): - """Test it can get demos for multi table.""" + """Test multi_table demos listing is returned with expected columns and types.""" # Run tables_info = get_available_demos('multi_table') # Assert - expected_table = pd.DataFrame({ + expected = pd.DataFrame({ 'dataset_name': [ - 'Accidents_v1', - 'Atherosclerosis_v1', - 'AustralianFootball_v1', - 'Biodegradability_v1', - 'Bupa_v1', - 'CORA_v1', - 'Carcinogenesis_v1', - 'Chess_v1', - 'Countries_v1', - 'DCG_v1', - 'Dunur_v1', - 'Elti_v1', - 'FNHK_v1', - 'Facebook_v1', - 'Hepatitis_std_v1', - 'Mesh_v1', - 'Mooney_Family_v1', - 'MuskSmall_v1', - 'NBA_v1', - 'NCAA_v1', - 'PTE_v1', - 'Pima_v1', - 'PremierLeague_v1', - 'Pyrimidine_v1', - 'SAP_v1', - 'SAT_v1', - 'SalesDB_v1', - 'Same_gen_v1', - 'Student_loan_v1', - 'Telstra_v1', - 'Toxicology_v1', - 'Triazine_v1', - 'TubePricing_v1', - 'UTube_v1', - 'UW_std_v1', - 'WebKP_v1', - 'airbnb-simplified', - 'financial_v1', - 'ftp_v1', - 'genes_v1', - 'got_families', - 'imdb_MovieLens_v1', - 'imdb_ijs_v1', - 'imdb_small_v1', - 'legalActs_v1', - 'mutagenesis_v1', - 'nations_v1', - 'restbase_v1', - 'rossmann', - 'trains_v1', - 'university_v1', - 'walmart', - 'world_v1', + 'fake_hotels', + 'fake_hotels_extended', ], 'size_MB': [ - '296.202744', - '7.916808', - '32.534832', - '0.692008', - '0.059144', - 
'1.987328', - '1.642592', - '0.403784', - '10.52272', - '0.321536', - '0.020224', - '0.054912', - '141.560872', - '1.481056', - '0.809472', - '0.101856', - '0.121784', - '0.646752', - '0.16632', - '29.137896', - '1.31464', - '0.160896', - '17.37664', - '0.038144', - '196.479272', - '0.500224', - '325.19768', - '0.056176', - '0.180256', - '5.503512', - '1.495496', - '0.156496', - '15.414536', - '0.135912', - '0.0576', - '1.9718', - '293.14392', - '94.718016', - '5.45568', - '0.440016', - '0.001', - '55.253264', - '259.140656', - '0.205728', - '186.132944', - '0.618088', - '0.540336', - '1.01452', - '73.328504', - '0.00644', - '0.009632', - '14.642184', - '0.295032', - ], - 'num_tables': [ - '3', - '4', - '4', - '5', - '9', - '3', - '6', - '2', - '4', - '2', - '17', - '11', - '3', - '2', - '7', - '29', - '68', - '2', - '4', - '9', - '38', - '9', - '4', - '2', - '4', - '36', - '4', - '4', - '10', - '5', - '4', - '2', - '20', - '2', - '4', - '3', - '2', - '8', - '2', - '3', - '3', - '7', - '7', - '7', - '5', - '3', - '3', - '3', - '2', - '2', - '5', - '3', - '3', + 0.05, + 0.07, ], + 'num_tables': [2, 2], }) - expected_table['size_MB'] = expected_table['size_MB'].astype(float).round(2) - expected_table['num_tables'] = pd.to_numeric(expected_table['num_tables']) - assert is_integer_dtype(tables_info['num_tables']) - assert len(expected_table.merge(tables_info, on='dataset_name')) == len(expected_table) + pd.testing.assert_frame_equal(tables_info[['dataset_name', 'size_MB', 'num_tables']], expected) diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index 33e38aafd..11886dd71 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -37,7 +37,7 @@ def test_hma(self): a 0.5 scale and one with 1.5 scale. """ # Setup - data, metadata = download_demo('multi_table', 'got_families') + data, metadata = download_demo('multi_table', 'fake_hotels') hmasynthesizer = HMASynthesizer(metadata) # Run @@ -46,8 +46,8 @@ def test_hma(self): increased_sample = hmasynthesizer.sample(1.5) # Assert - assert set(normal_sample) == {'characters', 'character_families', 'families'} - assert set(increased_sample) == {'characters', 'character_families', 'families'} + assert set(normal_sample) == set(data.keys()) + assert set(increased_sample) == set(data.keys()) for table_name, table in normal_sample.items(): assert set(table.columns) == set(data[table_name]) @@ -62,7 +62,7 @@ def test_hma_metadata(self): a 0.5 scale and one with 1.5 scale. 
""" # Setup - data, multi_metadata = download_demo('multi_table', 'got_families') + data, multi_metadata = download_demo('multi_table', 'fake_hotels') metadata = Metadata.load_from_dict(multi_metadata.to_dict()) hmasynthesizer = HMASynthesizer(metadata) @@ -72,8 +72,8 @@ def test_hma_metadata(self): increased_sample = hmasynthesizer.sample(1.5) # Assert - assert set(normal_sample) == {'characters', 'character_families', 'families'} - assert set(increased_sample) == {'characters', 'character_families', 'families'} + assert set(normal_sample) == set(data.keys()) + assert set(increased_sample) == set(data.keys()) for table_name, table in normal_sample.items(): assert set(table.columns) == set(data[table_name]) @@ -88,13 +88,13 @@ def test_hma_reset_sampling(self): """ # Setup faker = Faker() - data, metadata = download_demo('multi_table', 'got_families') + data, metadata = download_demo('multi_table', 'fake_hotels') metadata.add_column( 'ssn', - 'characters', + 'guests', sdtype='ssn', ) - data['characters']['ssn'] = [faker.lexify() for _ in range(len(data['characters']))] + data['guests']['ssn'] = [faker.lexify() for _ in range(len(data['guests']))] for table in metadata.tables.values(): table.alternate_keys = [] @@ -164,16 +164,15 @@ def test_hma_set_table_parameters(self): Validate that the ``set_table_parameters`` sets new parameters to the synthesizers. """ # Setup - _data, metadata = download_demo('multi_table', 'got_families') + _data, metadata = download_demo('multi_table', 'fake_hotels') hmasynthesizer = HMASynthesizer(metadata) # Run - hmasynthesizer.set_table_parameters('characters', {'default_distribution': 'gamma'}) - hmasynthesizer.set_table_parameters('families', {'default_distribution': 'uniform'}) - hmasynthesizer.set_table_parameters('character_families', {'default_distribution': 'norm'}) + hmasynthesizer.set_table_parameters('hotels', {'default_distribution': 'gamma'}) + hmasynthesizer.set_table_parameters('guests', {'default_distribution': 'uniform'}) # Assert - character_params = hmasynthesizer.get_table_parameters('characters') + character_params = hmasynthesizer.get_table_parameters('hotels') assert character_params['synthesizer_name'] == 'GaussianCopulaSynthesizer' assert character_params['synthesizer_parameters'] == { 'default_distribution': 'gamma', @@ -182,7 +181,7 @@ def test_hma_set_table_parameters(self): 'locales': ['en_US'], 'numerical_distributions': {}, } - families_params = hmasynthesizer.get_table_parameters('families') + families_params = hmasynthesizer.get_table_parameters('guests') assert families_params['synthesizer_name'] == 'GaussianCopulaSynthesizer' assert families_params['synthesizer_parameters'] == { 'default_distribution': 'uniform', @@ -191,21 +190,8 @@ def test_hma_set_table_parameters(self): 'locales': ['en_US'], 'numerical_distributions': {}, } - char_families_params = hmasynthesizer.get_table_parameters('character_families') - assert char_families_params['synthesizer_name'] == 'GaussianCopulaSynthesizer' - assert char_families_params['synthesizer_parameters'] == { - 'default_distribution': 'norm', - 'enforce_min_max_values': True, - 'enforce_rounding': True, - 'locales': ['en_US'], - 'numerical_distributions': {}, - } - - assert hmasynthesizer._table_synthesizers['characters'].default_distribution == 'gamma' - assert hmasynthesizer._table_synthesizers['families'].default_distribution == 'uniform' - assert ( - hmasynthesizer._table_synthesizers['character_families'].default_distribution == 'norm' - ) + assert 
hmasynthesizer._table_synthesizers['hotels'].default_distribution == 'gamma' + assert hmasynthesizer._table_synthesizers['guests'].default_distribution == 'uniform' def get_custom_constraint_data_and_metadata(self): """Return data and metadata for the custom constraint tests.""" @@ -531,7 +517,7 @@ def test_use_own_data_using_hma(self, tmp_path): ) # Run - load CSVs - datasets = load_csvs(data_folder) + datasets = load_csvs(data_folder / 'data') # Assert - loaded CSVs correctly assert datasets.keys() == {'guests', 'hotels'} @@ -646,14 +632,13 @@ def test_use_own_data_using_hma(self, tmp_path): def test_progress_bar_print(self, capsys): """Test that the progress bar prints correctly.""" # Setup - data, metadata = download_demo('multi_table', 'got_families') + data, metadata = download_demo('multi_table', 'fake_hotels') hmasynthesizer = HMASynthesizer(metadata) key_phrases = [ r'Preprocess Tables:', r'Learning relationships:', - r"\(1/2\) Tables 'characters' and 'character_families' \('character_id'\):", - r"\(2/2\) Tables 'families' and 'character_families' \('family_id'\):", + r"Tables 'hotels' and 'guests' \('hotel_id'\):", ] # Run @@ -670,7 +655,25 @@ def test_progress_bar_print(self, capsys): def test_warning_message_too_many_cols(self, capsys): """Test that a warning appears if there are more than 1000 expected columns""" # Setup - (_, metadata) = download_demo(modality='multi_table', dataset_name='NBA_v1') + parent_columns = {'parent_id': {'sdtype': 'id'}, 'parent_data': {'sdtype': 'categorical'}} + child_columns = {'child_id': {'sdtype': 'id'}, 'parent_id': {'sdtype': 'id'}} + for i in range(999): + child_columns[f'col_{i}'] = {'sdtype': 'categorical'} + + large_metadata = Metadata.load_from_dict({ + 'tables': { + 'parent': {'columns': parent_columns, 'primary_key': 'parent_id'}, + 'child': {'columns': child_columns, 'primary_key': 'child_id'}, + }, + 'relationships': [ + { + 'parent_table_name': 'parent', + 'parent_primary_key': 'parent_id', + 'child_table_name': 'child', + 'child_foreign_key': 'parent_id', + } + ], + }) key_phrases = [ r'PerformanceAlert:', @@ -679,7 +682,7 @@ def test_warning_message_too_many_cols(self, capsys): ] # Run - HMASynthesizer(metadata) + HMASynthesizer(large_metadata) captured = capsys.readouterr() @@ -687,7 +690,35 @@ def test_warning_message_too_many_cols(self, capsys): for constraint in key_phrases: match = re.search(constraint, captured.out + captured.err) assert match is not None - (_, small_metadata) = download_demo(modality='multi_table', dataset_name='trains_v1') + + # Setup small metadata that shouldn't trigger warning + small_metadata = Metadata.load_from_dict({ + 'tables': { + 'parent': { + 'columns': { + 'parent_id': {'sdtype': 'id'}, + 'parent_data': {'sdtype': 'categorical'}, + }, + 'primary_key': 'parent_id', + }, + 'child': { + 'columns': { + 'child_id': {'sdtype': 'id'}, + 'parent_id': {'sdtype': 'id'}, + 'child_data': {'sdtype': 'categorical'}, + }, + 'primary_key': 'child_id', + }, + }, + 'relationships': [ + { + 'parent_table_name': 'parent', + 'parent_primary_key': 'parent_id', + 'child_table_name': 'child', + 'child_foreign_key': 'parent_id', + } + ], + }) # Run HMASynthesizer(small_metadata) @@ -1135,31 +1166,27 @@ def test_get_learned_distributions_error_msg(self): def test__get_likelihoods(self): """Test ``_get_likelihoods`` generates likelihoods for parents.""" # Setup - data, metadata = download_demo('multi_table', 'got_families') + data, metadata = download_demo('multi_table', 'fake_hotels') hmasynthesizer = 
HMASynthesizer(metadata) hmasynthesizer.fit(data) sampled_data = {} - sampled_data['characters'] = hmasynthesizer._sample_rows( - hmasynthesizer._table_synthesizers['characters'], len(data['characters']) + sampled_data['hotels'] = hmasynthesizer._sample_rows( + hmasynthesizer._table_synthesizers['hotels'], len(data['hotels']) ) - hmasynthesizer._sample_children('characters', sampled_data) + hmasynthesizer._sample_children('hotels', sampled_data) # Run likelihoods = hmasynthesizer._get_likelihoods( - sampled_data['character_families'], - sampled_data['characters'].set_index('character_id'), - 'character_families', - 'character_id', + sampled_data['guests'], + sampled_data['hotels'].set_index('hotel_id'), + 'guests', + 'hotel_id', ) # Assert - not_nan_cols = [1, 3, 6] - nan_cols = [2, 4, 5, 7] - assert set(likelihoods.columns) == {1, 2, 3, 4, 5, 6, 7} - assert len(likelihoods) == len(sampled_data['character_families']) - assert not any(likelihoods[not_nan_cols].isna().any()) - assert all(likelihoods[nan_cols].isna()) + assert len(likelihoods) == len(sampled_data['guests']) + assert not likelihoods.isna().any().any() def test__extract_parameters(self): """Test it when parameters are out of bounds.""" @@ -1267,7 +1294,7 @@ def test_metadata_updated_no_warning(self, tmp_path): initialization, but is saved to a file before fitting. """ # Setup - data, multi_metadata = download_demo('multi_table', 'got_families') + data, multi_metadata = download_demo('multi_table', 'fake_hotels') metadata = Metadata.load_from_dict(multi_metadata.to_dict()) # Run 1 @@ -1305,7 +1332,7 @@ def test_metadata_updated_no_warning(self, tmp_path): # Run 3 instance = HMASynthesizer(metadata_detect) metadata_detect.update_column( - table_name='characters', column_name='age', sdtype='categorical' + table_name='guests', column_name='room_rate', sdtype='numerical' ) file_name = tmp_path / 'multitable_2.json' metadata_detect.save_to_json(file_name) @@ -1323,7 +1350,7 @@ def test_metadata_updated_warning_detect(self): not be raised again when calling ``fit``. 
""" # Setup - data, metadata = download_demo('multi_table', 'got_families') + data, metadata = download_demo('multi_table', 'fake_hotels') metadata_detect = Metadata.detect_from_dataframes(data) metadata_detect.relationships = metadata.relationships @@ -2243,7 +2270,7 @@ def test_fit_raises_version_error(): def test_hma_relationship_validity(): """Test the quality of the HMA synthesizer GH#1834.""" # Setup - data, metadata = download_demo('multi_table', 'Dunur_v1') + data, metadata = download_demo('multi_table', 'fake_hotels') synthesizer = HMASynthesizer(metadata) report = DiagnosticReport() @@ -2259,7 +2286,7 @@ def test_hma_relationship_validity(): def test_hma_not_fit_raises_sampling_error(): """Test that ``HMA`` will raise a ``SamplingError`` if it wasn't fit.""" # Setup - _data, metadata = download_demo('multi_table', 'Dunur_v1') + _data, metadata = download_demo('multi_table', 'fake_hotels') synthesizer = HMASynthesizer(metadata) # Run and Assert @@ -2402,7 +2429,7 @@ def test_table_name_logging(caplog): def test_disjointed_tables(): """Test to see if synthesizer works with disjointed tables.""" # Setup - real_data, metadata = download_demo('multi_table', 'Bupa_v1') + real_data, metadata = download_demo('multi_table', 'fake_hotels') # Delete Some Relationships to make it disjointed remove_some_dict = metadata.to_dict() diff --git a/tests/integration/utils/test_poc.py b/tests/integration/utils/test_poc.py index 19ba309aa..6c503a269 100644 --- a/tests/integration/utils/test_poc.py +++ b/tests/integration/utils/test_poc.py @@ -55,27 +55,108 @@ def data(): return {'parent': parent, 'child': child} -def test_simplify_schema(capsys): +@pytest.fixture +def large_data(): + great_grandparent = pd.DataFrame({'ggp_id': [1, 2, 3], 'ggp_data': ['A', 'B', 'C']}) + grandparent = pd.DataFrame({ + 'gp_id': [10, 11, 12, 13], + 'ggp_id': [1, 1, 2, 3], + 'gp_data': ['X', 'Y', 'Z', 'W'], + }) + parent = pd.DataFrame({ + 'p_id': [100, 101, 102, 103, 104], + 'gp_id': [10, 10, 11, 12, 13], + 'p_data': ['Alpha', 'Beta', 'Gamma', 'Delta', 'Epsilon'], + }) + child = pd.DataFrame({ + 'c_id': [1000, 1001, 1002, 1003, 1004, 1005], + 'p_id': [100, 100, 101, 102, 103, 104], + 'c_data': ['One', 'Two', 'Three', 'Four', 'Five', 'Six'], + }) + return { + 'great_grandparent': great_grandparent, + 'grandparent': grandparent, + 'parent': parent, + 'child': child, + } + + +@pytest.fixture +def large_metadata(): + return Metadata.load_from_dict({ + 'tables': { + 'great_grandparent': { + 'columns': {'ggp_id': {'sdtype': 'id'}, 'ggp_data': {'sdtype': 'categorical'}}, + 'primary_key': 'ggp_id', + }, + 'grandparent': { + 'columns': { + 'gp_id': {'sdtype': 'id'}, + 'ggp_id': {'sdtype': 'id'}, + 'gp_data': {'sdtype': 'categorical'}, + }, + 'primary_key': 'gp_id', + }, + 'parent': { + 'columns': { + 'p_id': {'sdtype': 'id'}, + 'gp_id': {'sdtype': 'id'}, + 'p_data': {'sdtype': 'categorical'}, + }, + 'primary_key': 'p_id', + }, + 'child': { + 'columns': { + 'c_id': {'sdtype': 'id'}, + 'p_id': {'sdtype': 'id'}, + 'c_data': {'sdtype': 'categorical'}, + }, + 'primary_key': 'c_id', + }, + }, + 'relationships': [ + { + 'parent_table_name': 'great_grandparent', + 'parent_primary_key': 'ggp_id', + 'child_table_name': 'grandparent', + 'child_foreign_key': 'ggp_id', + }, + { + 'parent_table_name': 'grandparent', + 'parent_primary_key': 'gp_id', + 'child_table_name': 'parent', + 'child_foreign_key': 'gp_id', + }, + { + 'parent_table_name': 'parent', + 'parent_primary_key': 'p_id', + 'child_table_name': 'child', + 'child_foreign_key': 'p_id', 
+ }, + ], + }) + + +def test_simplify_schema(capsys, large_data, large_metadata): """Test ``simplify_schema`` end to end.""" # Setup - data, metadata = download_demo('multi_table', 'AustralianFootball_v1') - num_estimated_column_before_simplification = _get_total_estimated_columns(metadata) - HMASynthesizer(metadata) + num_estimated_column_before_simplification = _get_total_estimated_columns(large_metadata) + HMASynthesizer(large_metadata) captured_before_simplification = capsys.readouterr() # Run - data_simplify, metadata_simplify = simplify_schema(data, metadata) + data_simplify, metadata_simplify = simplify_schema(large_data, large_metadata) captured_after_simplification = capsys.readouterr() # Assert expected_message_before = re.compile( r'PerformanceAlert: Using the HMASynthesizer on this metadata schema is not recommended\.' - r' To model this data, HMA will generate a large number of columns\. \(173818 columns\)\s+' + r' To model this data, HMA will generate a large number of columns\. \(1034 columns\)\s+' r'Table Name\s*#\s*Columns in Metadata\s*Est # Columns\s*' - r'match_stats\s*24\s*24\s*' - r'matches\s*39\s*412\s*' - r'players\s*5\s*378\s*' - r'teams\s*1\s*173004\s*' + r'great_grandparent\s*1\s*986\s*' + r'grandparent\s*1\s*41\s*' + r'parent\s*1\s*6\s*' + r'child\s*1\s*1\s*' r'We recommend simplifying your metadata schema using ' r"'sdv.utils.poc.simplify_schema'\.\s*" r'If this is not possible, please visit ' @@ -84,24 +165,24 @@ def test_simplify_schema(capsys): expected_message_after = re.compile( r'Success! The schema has been simplified\.\s+' r'Table Name\s*#\s*Columns \(Before\)\s*#\s*Columns \(After\)\s*' - r'match_stats\s*28\s*3\s*' - r'matches\s*42\s*21\s*' - r'players\s*6\s*0\s*' - r'teams\s*2\s*2' + r'child\s*3\s*0\s*' + r'grandparent\s*3\s*3\s*' + r'great_grandparent\s*2\s*2\s*' + r'parent\s*3\s*2' ) assert expected_message_before.match(captured_before_simplification.out.strip()) assert expected_message_after.match(captured_after_simplification.out.strip()) metadata_simplify.validate() metadata_simplify.validate_data(data_simplify) num_estimated_column_after_simplification = _get_total_estimated_columns(metadata_simplify) - assert num_estimated_column_before_simplification == 173818 - assert num_estimated_column_after_simplification == 517 + assert num_estimated_column_before_simplification == 1034 + assert num_estimated_column_after_simplification == 13 def test_simpliy_nothing_to_simplify(): """Test ``simplify_schema`` end to end when no simplification is required.""" # Setup - data, metadata = download_demo('multi_table', 'Biodegradability_v1') + data, metadata = download_demo('multi_table', 'fake_hotels') # Run data_simplify, metadata_simplify = simplify_schema(data, metadata) @@ -117,117 +198,101 @@ def test_simpliy_nothing_to_simplify(): def test_simplify_no_grandchild(): """Test ``simplify_schema`` end to end when there is no grandchild table.""" # Setup - data, metadata = download_demo('multi_table', 'MuskSmall_v1') - num_estimated_column_before_simplification = _get_total_estimated_columns(metadata) + parent_data = pd.DataFrame({ + 'parent_id': range(500), + 'parent_col1': np.random.choice(['A', 'B', 'C'], 500), + 'parent_col2': np.random.randn(500), + }) + child_columns = {'child_id': range(500), 'parent_id': np.random.choice(range(500), 500)} + for i in range(168): + child_columns[f'child_col_{i}'] = np.random.choice(['X', 'Y', 'Z'], 500) + child_data = pd.DataFrame(child_columns) + data = {'parent': parent_data, 'child': child_data} + parent_columns = { 
+ 'parent_id': {'sdtype': 'id'}, + 'parent_col1': {'sdtype': 'categorical'}, + 'parent_col2': {'sdtype': 'numerical'}, + } + child_columns_meta = {'child_id': {'sdtype': 'id'}, 'parent_id': {'sdtype': 'id'}} + for i in range(168): + child_columns_meta[f'child_col_{i}'] = {'sdtype': 'categorical'} + + metadata = Metadata.load_from_dict({ + 'tables': { + 'parent': {'columns': parent_columns, 'primary_key': 'parent_id'}, + 'child': {'columns': child_columns_meta, 'primary_key': 'child_id'}, + }, + 'relationships': [ + { + 'parent_table_name': 'parent', + 'parent_primary_key': 'parent_id', + 'child_table_name': 'child', + 'child_foreign_key': 'parent_id', + } + ], + }) # Run + num_estimated_column_before_simplification = _get_total_estimated_columns(metadata) data_simplify, metadata_simplify = simplify_schema(data, metadata) # Assert metadata_simplify.validate() metadata_simplify.validate_data(data_simplify) num_estimated_column_after_simplification = _get_total_estimated_columns(metadata_simplify) - assert num_estimated_column_before_simplification == 14527 - assert num_estimated_column_after_simplification == 982 + assert num_estimated_column_before_simplification > num_estimated_column_after_simplification -def test_simplify_schema_big_demo_datasets(): +def test_simplify_schema_big_demo_datasets(large_data, large_metadata): """Test ``simplify_schema`` end to end for demo datasets that require simplification. This test will fail if the number of estimated columns after simplification is greater than the maximum number of columns allowed for any dataset. """ - # Setup - list_datasets = [ - 'AustralianFootball_v1', - 'MuskSmall_v1', - 'Countries_v1', - 'NBA_v1', - 'NCAA_v1', - 'PremierLeague_v1', - 'financial_v1', - ] - for dataset in list_datasets: - real_data, metadata = download_demo('multi_table', dataset) - - # Run - _data_simplify, metadata_simplify = simplify_schema(real_data, metadata) - - # Assert - estimate_column_before = _get_total_estimated_columns(metadata) - estimate_column_after = _get_total_estimated_columns(metadata_simplify) - assert estimate_column_before > MAX_NUMBER_OF_COLUMNS - assert estimate_column_after <= MAX_NUMBER_OF_COLUMNS - - -@pytest.mark.parametrize( - ('dataset_name', 'main_table_1', 'main_table_2', 'num_rows_1', 'num_rows_2'), - [ - ('AustralianFootball_v1', 'matches', 'players', 1000, 1000), - ('MuskSmall_v1', 'molecule', 'conformation', 50, 150), - ('NBA_v1', 'Team', 'Actions', 10, 200), - ('NCAA_v1', 'tourney_slots', 'tourney_compact_results', 1000, 1000), - ], -) -def test_get_random_subset(dataset_name, main_table_1, main_table_2, num_rows_1, num_rows_2): - """Test ``get_random_subset`` end to end. - - The goal here is test that the function works for various schema and also by subsampling - different main tables. + # Run + _data_simplify, metadata_simplify = simplify_schema(large_data, large_metadata) - For `AustralianFootball_v1` (parent with child and grandparent): - - main table 1 = `matches` which is the child of `teams` and the parent of `match_stats`. - - main table 2 = `players` which is the parent of `matches`. + # Assert + estimate_column_before = _get_total_estimated_columns(large_metadata) + estimate_column_after = _get_total_estimated_columns(metadata_simplify) + assert estimate_column_before > MAX_NUMBER_OF_COLUMNS + assert estimate_column_after <= MAX_NUMBER_OF_COLUMNS - For `MuskSmall_v1` (1 parent - 1 child relationship): - - main table 1 = `molecule` which is the parent of `conformation`. 
- - main table 2 = `conformation` which is the child of `molecule`. - For `NBA_v1` (child with parents and grandparent): - - main table 1 = `Team` which is the root table. - - main table 2 = `Actions` which is the last child. It has relationships with `Game` and `Team` - and `Player`. +def test_get_random_subset(): + """Test ``get_random_subset`` end to end. - For `NCAA_v1` (child with multiple parents): - - main table 1 = `tourney_slots` which is only the child of `seasons`. - - main table 2 = `tourney_compact_results` which is the child of `teams` with two relationships - and of `seasons` with one relationship. + The goal here is test that the function works for various schema and also by subsampling + different main tables. """ # Setup - real_data, metadata = download_demo('multi_table', dataset_name) + real_data, metadata = download_demo('multi_table', 'fake_hotels') # Run - result_1 = get_random_subset(real_data, metadata, main_table_1, num_rows_1, verbose=False) - result_2 = get_random_subset(real_data, metadata, main_table_2, num_rows_2, verbose=False) + result_1 = get_random_subset(real_data, metadata, 'hotels', 10, verbose=False) + result_2 = get_random_subset(real_data, metadata, 'guests', 20, verbose=False) # Assert - assert len(result_1[main_table_1]) == num_rows_1 - assert len(result_2[main_table_2]) == num_rows_2 + assert len(result_1['hotels']) == 10 + assert len(result_2['guests']) == 20 def test_get_random_subset_disconnected_schema(): - """Test ``get_random_subset`` end to end for a disconnected schema. - - Here we break the schema so there is only parent-child relationships between - `Player`-`Action` and `Team`-`Game`. - The part that is not connected to the main table (`Player`) should be subsampled also - in a similar proportion. 
- """ + """Test ``get_random_subset`` end to end for a disconnected schema.""" # Setup - real_data, metadata = download_demo('multi_table', 'NBA_v1') - metadata.remove_relationship('Game', 'Actions') - metadata.remove_relationship('Team', 'Actions') + real_data, metadata = download_demo('multi_table', 'fake_hotels') + metadata.remove_relationship('hotels', 'guests') metadata.validate = Mock() metadata.validate_data = Mock() proportion_to_keep = 0.6 - num_rows_to_keep = int(len(real_data['Player']) * proportion_to_keep) + num_rows_to_keep = int(len(real_data['guests']) * proportion_to_keep) # Run - result = get_random_subset(real_data, metadata, 'Player', num_rows_to_keep, verbose=False) + result = get_random_subset(real_data, metadata, 'guests', num_rows_to_keep, verbose=False) # Assert - assert len(result['Player']) == num_rows_to_keep - assert len(result['Team']) == int(len(real_data['Team']) * proportion_to_keep) + assert len(result['guests']) == num_rows_to_keep + assert len(result['hotels']) >= int(len(real_data['hotels']) * proportion_to_keep) def test_get_random_subset_with_missing_values(metadata, data): diff --git a/tests/unit/datasets/test_demo.py b/tests/unit/datasets/test_demo.py index dbb2e7cf8..8dd95abde 100644 --- a/tests/unit/datasets/test_demo.py +++ b/tests/unit/datasets/test_demo.py @@ -1,11 +1,36 @@ +import io +import json +import logging import re -from unittest.mock import MagicMock, Mock, patch +import zipfile +from unittest.mock import Mock, patch import numpy as np import pandas as pd import pytest -from sdv.datasets.demo import _download, _get_data_from_bucket, download_demo, get_available_demos +from sdv.datasets.demo import ( + _download, + _find_data_zip_key, + _find_text_key, + _get_data_from_bucket, + _get_first_v1_metadata_bytes, + _get_metadata, + _get_text_file_content, + _iter_metainfo_yaml_entries, + download_demo, + get_available_demos, + get_readme, + get_source, +) +from sdv.errors import DemoResourceNotFoundError, DemoResourceNotFoundWarning + + +def _make_zip_with_csv(csv_name: str, df: pd.DataFrame) -> bytes: + buf = io.BytesIO() + with zipfile.ZipFile(buf, mode='w', compression=zipfile.ZIP_DEFLATED) as zf: + zf.writestr(csv_name, df.to_csv(index=False)) + return buf.getvalue() def test_download_demo_invalid_modality(): @@ -27,27 +52,44 @@ def test_download_demo_folder_already_exists(tmpdir): download_demo('single_table', 'dataset_name', tmpdir) -def test_download_demo_dataset_doesnt_exist(): - """Test it crashes when ``dataset_name`` doesn't exist.""" - # Run and Assert - err_msg = re.escape( - "Invalid dataset name 'invalid_dataset'. " - 'Make sure you have the correct modality for the dataset name or ' - "use 'get_available_demos' to get a list of demo datasets." 
- ) - with pytest.raises(ValueError, match=err_msg): - download_demo('single_table', 'invalid_dataset') +@patch('sdv.datasets.demo._get_data_from_bucket') +@patch('sdv.datasets.demo._list_objects') +def test_download_demo_single_table(mock_list, mock_get, tmpdir): + """Test it can download a single table dataset using the new structure.""" + mock_list.return_value = [ + {'Key': 'single_table/ring/data.zip'}, + {'Key': 'single_table/ring/metadata.json'}, + ] + df = pd.DataFrame({'0': [0, 0], '1': [0, 0]}) + zip_bytes = _make_zip_with_csv('ring.csv', df) + meta_bytes = json.dumps({ + 'METADATA_SPEC_VERSION': 'V1', + 'tables': { + 'ring': { + 'columns': { + '0': {'sdtype': 'numerical', 'computer_representation': 'Int64'}, + '1': {'sdtype': 'numerical', 'computer_representation': 'Int64'}, + } + } + }, + 'relationships': [], + }).encode() + + def side_effect(key): + if key.endswith('data.zip'): + return zip_bytes + if key.endswith('metadata.json'): + return meta_bytes + raise KeyError(key) + mock_get.side_effect = side_effect -def test_download_demo_single_table(tmpdir): - """Test it can download a single table dataset.""" # Run table, metadata = download_demo('single_table', 'ring', tmpdir / 'test_folder') # Assert expected_table = pd.DataFrame({'0': [0, 0], '1': [0, 0]}) pd.testing.assert_frame_equal(table.head(2), expected_table) - expected_metadata_dict = { 'tables': { 'ring': { @@ -63,13 +105,13 @@ def test_download_demo_single_table(tmpdir): assert metadata.to_dict() == expected_metadata_dict -@patch('boto3.Session') +@patch('sdv.datasets.demo._create_s3_client') @patch('sdv.datasets.demo.BUCKET', 'bucket') -def test__get_data_from_bucket(session_mock): +def test__get_data_from_bucket(create_client_mock): """Test the ``_get_data_from_bucket`` method.""" # Setup mock_s3_client = Mock() - session_mock.return_value.client.return_value = mock_s3_client + create_client_mock.return_value = mock_s3_client mock_s3_client.get_object.return_value = {'Body': Mock(read=lambda: b'data')} # Run @@ -77,32 +119,60 @@ def test__get_data_from_bucket(session_mock): # Assert assert result == b'data' - session_mock.assert_called_once() + create_client_mock.assert_called_once() mock_s3_client.get_object.assert_called_once_with(Bucket='bucket', Key='object_key') @patch('sdv.datasets.demo._get_data_from_bucket') -def test__download(mock_get_data_from_bucket): - """Test the ``_download`` method.""" +@patch('sdv.datasets.demo._list_objects') +def test__download(mock_list, mock_get_data_from_bucket): + """Test the ``_download`` method with new structure.""" # Setup - mock_get_data_from_bucket.return_value = b'' + mock_list.return_value = [ + {'Key': 'single_table/ring/data.zip'}, + {'Key': 'single_table/ring/metadata.json'}, + ] + mock_get_data_from_bucket.return_value = json.dumps({'METADATA_SPEC_VERSION': 'V1'}).encode() # Run - _download('single_table', 'ring') + data_io, metadata_bytes = _download('single_table', 'ring') # Assert - mock_get_data_from_bucket.assert_called_once_with('SINGLE_TABLE/ring.zip') + assert isinstance(data_io, io.BytesIO) + assert isinstance(metadata_bytes, (bytes, bytearray)) -def test_download_demo_single_table_no_output_folder(): +@patch('sdv.datasets.demo._get_data_from_bucket') +@patch('sdv.datasets.demo._list_objects') +def test_download_demo_single_table_no_output_folder(mock_list, mock_get): """Test it can download a single table dataset when no output folder is passed.""" + # Setup + mock_list.return_value = [ + {'Key': 'single_table/ring/data.zip'}, + {'Key': 
'single_table/ring/metadata.json'}, + ] + df = pd.DataFrame({'0': [0, 0], '1': [0, 0]}) + zip_bytes = _make_zip_with_csv('ring.csv', df) + meta_bytes = json.dumps({ + 'METADATA_SPEC_VERSION': 'V1', + 'tables': { + 'ring': { + 'columns': { + '0': {'sdtype': 'numerical', 'computer_representation': 'Int64'}, + '1': {'sdtype': 'numerical', 'computer_representation': 'Int64'}, + } + } + }, + 'relationships': [], + }).encode() + mock_get.side_effect = lambda key: zip_bytes if key.endswith('data.zip') else meta_bytes + # Run table, metadata = download_demo('single_table', 'ring') # Assert expected_table = pd.DataFrame({'0': [0, 0], '1': [0, 0]}) pd.testing.assert_frame_equal(table.head(2), expected_table) - expected_metadata_dict = { 'tables': { 'ring': { @@ -115,12 +185,43 @@ def test_download_demo_single_table_no_output_folder(): 'METADATA_SPEC_VERSION': 'V1', 'relationships': [], } - assert metadata.to_dict() == expected_metadata_dict -def test_download_demo_timeseries(tmpdir): - """Test it can download a timeseries dataset.""" +@patch('sdv.datasets.demo._get_data_from_bucket') +@patch('sdv.datasets.demo._list_objects') +def test_download_demo_timeseries(mock_list, mock_get, tmpdir): + """Test it can download a timeseries dataset using new structure.""" + # Setup + mock_list.return_value = [ + {'Key': 'sequential/Libras/data.zip'}, + {'Key': 'sequential/Libras/metadata.json'}, + ] + df = pd.DataFrame({ + 'ml_class': [1, 1], + 'e_id': [0, 0], + 's_index': [0, 1], + 'tt_split': [1, 1], + 'dim_0': [0.67892, 0.68085], + 'dim_1': [0.27315, 0.27315], + }) + zip_bytes = _make_zip_with_csv('Libras.csv', df) + meta_bytes = json.dumps({ + 'METADATA_SPEC_VERSION': 'V1', + 'relationships': [], + 'tables': { + 'Libras': { + 'columns': { + 'e_id': {'sdtype': 'numerical', 'computer_representation': 'Int64'}, + 'dim_0': {'sdtype': 'numerical', 'computer_representation': 'Float'}, + 'dim_1': {'sdtype': 'numerical', 'computer_representation': 'Float'}, + 'ml_class': {'sdtype': 'categorical'}, + } + } + }, + }).encode() + mock_get.side_effect = lambda key: zip_bytes if key.endswith('data.zip') else meta_bytes + # Run table, metadata = download_demo('sequential', 'Libras', tmpdir / 'test_folder') @@ -134,7 +235,6 @@ def test_download_demo_timeseries(tmpdir): 'dim_1': [0.27315, 0.27315], }) pd.testing.assert_frame_equal(table.head(2), expected_table) - expected_metadata_dict = { 'METADATA_SPEC_VERSION': 'V1', 'relationships': [], @@ -152,34 +252,31 @@ def test_download_demo_timeseries(tmpdir): assert metadata.to_dict() == expected_metadata_dict -def test_download_demo_multi_table(tmpdir): - """Test it can download a multi table dataset.""" - # Run - tables, metadata = download_demo('multi_table', 'got_families', tmpdir / 'test_folder') - - # Assert - expected_families = pd.DataFrame({ - 'family_id': [1, 2], - 'name': ['Stark', 'Tully'], - }) - pd.testing.assert_frame_equal(tables['families'].head(2), expected_families) - - expected_character_families = pd.DataFrame({ +@patch('sdv.datasets.demo._get_data_from_bucket') +@patch('sdv.datasets.demo._list_objects') +def test_download_demo_multi_table(mock_list, mock_get, tmpdir): + """Test it can download a multi table dataset using the new structure.""" + # Setup + mock_list.return_value = [ + {'Key': 'multi_table/got_families/data.zip'}, + {'Key': 'multi_table/got_families/metadata.json'}, + ] + families = pd.DataFrame({'family_id': [1, 2], 'name': ['Stark', 'Tully']}) + character_families = pd.DataFrame({ 'character_id': [1, 1], 'family_id': [1, 4], 'generation': [8, 
5], 'type': ['father', 'mother'], }) - pd.testing.assert_frame_equal(tables['character_families'].head(2), expected_character_families) - - expected_characters = pd.DataFrame({ - 'age': [20, 16], - 'character_id': [1, 2], - 'name': ['Jon', 'Arya'], - }) - pd.testing.assert_frame_equal(tables['characters'].head(2), expected_characters) + characters = pd.DataFrame({'age': [20, 16], 'character_id': [1, 2], 'name': ['Jon', 'Arya']}) - expected_metadata_dict = { + zip_buf = io.BytesIO() + with zipfile.ZipFile(zip_buf, 'w', compression=zipfile.ZIP_DEFLATED) as zf: + zf.writestr('families.csv', families.to_csv(index=False)) + zf.writestr('character_families.csv', character_families.to_csv(index=False)) + zf.writestr('characters.csv', characters.to_csv(index=False)) + zip_bytes = zip_buf.getvalue() + meta_bytes = json.dumps({ 'tables': { 'characters': { 'columns': { @@ -220,7 +317,17 @@ def test_download_demo_multi_table(tmpdir): }, ], 'METADATA_SPEC_VERSION': 'V1', - } + }).encode() + mock_get.side_effect = lambda key: zip_bytes if key.endswith('data.zip') else meta_bytes + + # Run + tables, metadata = download_demo('multi_table', 'got_families', tmpdir / 'test_folder') + + # Assert + pd.testing.assert_frame_equal(tables['families'].head(2), families.head(2)) + pd.testing.assert_frame_equal(tables['character_families'].head(2), character_families.head(2)) + pd.testing.assert_frame_equal(tables['characters'].head(2), characters.head(2)) + expected_metadata_dict = json.loads(meta_bytes.decode()) assert metadata.to_dict() == expected_metadata_dict @@ -232,59 +339,859 @@ def test_get_available_demos_invalid_modality(): get_available_demos('invalid_modality') -@patch('boto3.client') -def test_get_available_demos(client_mock): - """Test it gets the correct output.""" +def test__find_data_zip_key(): # Setup - contents_objects = { - 'Contents': [{'Key': 'SINGLE_TABLE/dataset1.zip'}, {'Key': 'SINGLE_TABLE/dataset2.zip'}] - } - client_mock.return_value.list_objects = Mock(return_value=contents_objects) + contents = [ + {'Key': 'single_table/fake_hotel_guests/data.ZIP'}, + {'Key': 'single_table/fake_hotel_guests/metadata.json'}, + ] + dataset_prefix = 'single_table/fake_hotel_guests/' - def metadata_func(Bucket, Key): # noqa: N803 - if Key == 'SINGLE_TABLE/dataset1.zip': - return {'Metadata': {'size-mb': 123, 'num-tables': 321}} + # Run + zip_key = _find_data_zip_key(contents, dataset_prefix) + + # Assert + assert zip_key == 'single_table/fake_hotel_guests/data.ZIP' + + +@patch('sdv.datasets.demo._get_data_from_bucket') +def test__get_first_v1_metadata_bytes(mock_get): + # Setup + v2 = json.dumps({'METADATA_SPEC_VERSION': 'V2'}).encode() + bad = b'not-json' + v1 = json.dumps({'METADATA_SPEC_VERSION': 'V1'}).encode() - return {'Metadata': {'size-mb': 456, 'num-tables': 654}} + def side_effect(key): + return { + 'single_table/dataset/k1.json': v2, + 'single_table/dataset/k2.json': bad, + 'single_table/dataset/k_metadata_k.json': v1, + }[key] - client_mock.return_value.head_object = MagicMock(side_effect=metadata_func) + mock_get.side_effect = side_effect + contents = [ + {'Key': 'single_table/dataset/k1.json'}, + {'Key': 'single_table/dataset/k2.json'}, + {'Key': 'single_table/dataset/k_metadata_k.json'}, + ] # Run - tables_info = get_available_demos('single_table') + got = _get_first_v1_metadata_bytes(contents, 'single_table/dataset/') # Assert - expected_table = pd.DataFrame({ - 'dataset_name': ['dataset1', 'dataset2'], - 'size_MB': [123.00, 456.00], - 'num_tables': [321, 654], - }) - 
pd.testing.assert_frame_equal(tables_info, expected_table) + assert got == v1 + + +def test__iter_metainfo_yaml_entries_filters(): + # Setup + contents = [ + {'Key': 'single_table/d1/metainfo.yaml'}, + {'Key': 'single_table/d1/METAINFO.YAML'}, + {'Key': 'single_table/d2/not.yaml'}, + {'Key': 'multi_table/d3/metainfo.yaml'}, + {'Key': 'single_table/metainfo.yaml'}, + ] + + # Run + got = list(_iter_metainfo_yaml_entries(contents, 'single_table')) + + # Assert + assert ('d1', 'single_table/d1/metainfo.yaml') in got + assert ('d1', 'single_table/d1/METAINFO.YAML') in got + assert all(name != 'd3' for name, _ in got) + assert all(key != 'single_table/metainfo.yaml' for _, key in got) + + +@patch('sdv.datasets.demo._get_data_from_bucket') +@patch('sdv.datasets.demo._list_objects') +def test_get_available_demos_robust_parsing(mock_list, mock_get): + # Setup + mock_list.return_value = [ + {'Key': 'single_table/d1/metainfo.yaml'}, + {'Key': 'single_table/d2/metainfo.yaml'}, + {'Key': 'single_table/bad/metainfo.yaml'}, + {'Key': 'single_table/ignore.txt'}, + ] + + def side_effect(key): + if key.endswith('d1/metainfo.yaml'): + return b'dataset-name: d1\nnum-tables: 2\ndataset-size-mb: 10.5\nsource: EXTERNAL\n' + if key.endswith('d2/metainfo.yaml'): + return b'dataset-name: d2\nnum-tables: not_a_number\ndataset-size-mb: NaN\n' + raise ValueError('invalid yaml') + + mock_get.side_effect = side_effect + + # Run + df = get_available_demos('single_table') + assert set(df['dataset_name']) == {'d1', 'd2'} + # Assert + # d1 parsed correctly + row1 = df[df['dataset_name'] == 'd1'].iloc[0] + assert row1['num_tables'] == 2 + assert row1['size_MB'] == 10.5 + # d2 falls back to NaN + row2 = df[df['dataset_name'] == 'd2'].iloc[0] + assert np.isnan(row2['num_tables']) or row2['num_tables'] is None + assert np.isnan(row2['size_MB']) or row2['size_MB'] is None + + +@patch('sdv.datasets.demo._get_data_from_bucket') +@patch('sdv.datasets.demo._list_objects') +def test_get_available_demos_logs_invalid_size_mb(mock_list, mock_get, caplog): + # Setup + mock_list.return_value = [ + {'Key': 'single_table/dsize/metainfo.yaml'}, + ] + + def side_effect(key): + return b'dataset-name: dsize\nnum-tables: 2\ndataset-size-mb: invalid\n' + + mock_get.side_effect = side_effect + + # Run + caplog.set_level(logging.INFO, logger='sdv.datasets.demo') + df = get_available_demos('single_table') + + # Assert + expected = 'Invalid dataset-size-mb invalid for dataset dsize; defaulting to NaN.' + assert expected in caplog.messages + row = df[df['dataset_name'] == 'dsize'].iloc[0] + assert row['num_tables'] == 2 + assert np.isnan(row['size_MB']) or row['size_MB'] is None + + +@patch('sdv.datasets.demo._get_data_from_bucket') +@patch('sdv.datasets.demo._list_objects') +def test_get_available_demos_logs_num_tables_str_cast_fail_exact(mock_list, mock_get, caplog): + # Setup + mock_list.return_value = [ + {'Key': 'single_table/dnum/metainfo.yaml'}, + ] + + def side_effect(key): + return b'dataset-name: dnum\nnum-tables: not_a_number\ndataset-size-mb: 1.1\n' + + mock_get.side_effect = side_effect + + # Run + caplog.set_level(logging.INFO, logger='sdv.datasets.demo') + df = get_available_demos('single_table') + + # Assert + expected = ( + 'Could not cast num_tables_val not_a_number to float for dataset dnum; defaulting to NaN.' 
+ ) + assert expected in caplog.messages + row = df[df['dataset_name'] == 'dnum'].iloc[0] + assert np.isnan(row['num_tables']) or row['num_tables'] is None + assert row['size_MB'] == 1.1 + + +@patch('sdv.datasets.demo._get_data_from_bucket') +@patch('sdv.datasets.demo._list_objects') +def test_get_available_demos_logs_num_tables_int_parse_fail_exact(mock_list, mock_get, caplog): + # Setup + mock_list.return_value = [ + {'Key': 'single_table/dnum/metainfo.yaml'}, + ] + + def side_effect(key): + return b'dataset-name: dnum\nnum-tables: [1, 2]\ndataset-size-mb: 1.1\n' + + mock_get.side_effect = side_effect + + # Run + caplog.set_level(logging.INFO, logger='sdv.datasets.demo') + df = get_available_demos('single_table') + + # Assert + expected = 'Invalid num-tables [1, 2] for dataset dnum when parsing as int.' + assert expected in caplog.messages + row = df[df['dataset_name'] == 'dnum'].iloc[0] + assert np.isnan(row['num_tables']) or row['num_tables'] is None + assert row['size_MB'] == 1.1 -@patch('boto3.client') -def test_get_available_demos_missing_metadata(client_mock): - """Test it can handle data with missing metadata information.""" + +@patch('sdv.datasets.demo._get_data_from_bucket') +@patch('sdv.datasets.demo._list_objects') +def test_get_available_demos_ignores_yaml_dataset_name_mismatch(mock_list, mock_get): + """When YAML dataset-name mismatches folder, use folder name from S3 path.""" # Setup - contents_objects = { - 'Contents': [{'Key': 'SINGLE_TABLE/dataset1.zip'}, {'Key': 'SINGLE_TABLE/dataset2.zip'}] + mock_list.return_value = [ + {'Key': 'single_table/folder_name/metainfo.yaml'}, + ] + + # YAML uses a different name; should be ignored for dataset_name field + def side_effect(key): + return b'dataset-name: DIFFERENT\nnum-tables: 3\ndataset-size-mb: 2.5\n' + + mock_get.side_effect = side_effect + + # Run + df = get_available_demos('single_table') + + # Assert + assert set(df['dataset_name']) == {'folder_name'} + row = df[df['dataset_name'] == 'folder_name'].iloc[0] + assert row['num_tables'] == 3 + assert row['size_MB'] == 2.5 + + +@patch('sdv.datasets.demo._get_data_from_bucket') +@patch('sdv.datasets.demo._list_objects') +def test_download_demo_success_single_table(mock_list, mock_get): + # Setup + mock_list.return_value = [ + {'Key': 'single_table/word/data.ZIP'}, + {'Key': 'single_table/word/metadata.json'}, + ] + df = pd.DataFrame({'id': [1, 2], 'name': ['a', 'b']}) + zip_bytes = _make_zip_with_csv('word.csv', df) + meta_bytes = json.dumps({ + 'METADATA_SPEC_VERSION': 'V1', + 'tables': { + 'word': { + 'columns': { + 'id': {'sdtype': 'id'}, + 'name': {'sdtype': 'categorical'}, + }, + 'primary_key': 'id', + } + }, + 'relationships': [], + }).encode() + + def side_effect(key): + if key.endswith('data.ZIP'): + return zip_bytes + if key.endswith('metadata.json'): + return meta_bytes + raise KeyError(key) + + mock_get.side_effect = side_effect + + # Run + data, metadata = download_demo('single_table', 'word') + + # Assert + assert isinstance(data, pd.DataFrame) + assert set(data.columns) == {'id', 'name'} + assert metadata.to_dict()['tables']['word']['primary_key'] == 'id' + + +@patch('sdv.datasets.demo._get_data_from_bucket', return_value=b'{}') +@patch('sdv.datasets.demo._list_objects') +def test_download_demo_missing_zip_raises(mock_list, _mock_get): + # Setup + mock_list.return_value = [ + {'Key': 'single_table/word/metadata.json'}, + ] + + # Run and Assert + with pytest.raises(DemoResourceNotFoundError, match="Could not find 'data.zip'"): + download_demo('single_table', 
'word') + + +@patch('sdv.datasets.demo._get_data_from_bucket') +@patch('sdv.datasets.demo._list_objects') +def test_download_demo_no_v1_metadata_raises(mock_list, mock_get): + # Setup + mock_list.return_value = [ + {'Key': 'single_table/word/data.zip'}, + {'Key': 'single_table/word/metadata.json'}, + ] + mock_get.side_effect = lambda key: json.dumps({'METADATA_SPEC_VERSION': 'V2'}).encode() + + # Run and Assert + with pytest.raises(DemoResourceNotFoundError, match='METADATA_SPEC_VERSION'): + download_demo('single_table', 'word') + + +@patch('builtins.open', side_effect=OSError('fail-open')) +def test__get_metadata_warns_on_save_error(_mock_open, tmp_path): + """_get_metadata should emit a warning if writing metadata.json fails.""" + # Setup + meta = { + 'METADATA_SPEC_VERSION': 'V1', + 'relationships': [], + 'tables': { + 't': { + 'columns': { + 'a': {'sdtype': 'numerical'}, + } + } + }, + } + meta_bytes = json.dumps(meta).encode() + out_dir = tmp_path / 'out' + out_dir.mkdir(parents=True, exist_ok=True) + + # Run and Assert + warn_msg = 'Error saving metadata.json' + with pytest.warns(DemoResourceNotFoundWarning, match=warn_msg): + md = _get_metadata(meta_bytes, 'dataset1', str(out_dir)) + + assert md.to_dict() == meta + + +def test__get_metadata_raises_on_invalid_json(): + """_get_metadata should raise a helpful error when JSON is invalid.""" + # Run / Assert + err = 'Failed to parse metadata JSON for the dataset.' + with pytest.raises(DemoResourceNotFoundError, match=re.escape(err)): + _get_metadata(b'not-json', 'dataset1') + + +@patch('sdv.datasets.demo._get_data_from_bucket') +@patch('sdv.datasets.demo._list_objects') +def test_download_demo_writes_metadata_and_discovers_nested_csv(mock_list, mock_get, tmp_path): + """When output folder is set, it writes metadata.json and finds nested CSVs.""" + # Setup + mock_list.return_value = [ + {'Key': 'single_table/nested/data.zip'}, + {'Key': 'single_table/nested/metadata.json'}, + ] + + df = pd.DataFrame({'a': [1, 2], 'b': ['x', 'y']}) + buf = io.BytesIO() + with zipfile.ZipFile(buf, mode='w', compression=zipfile.ZIP_DEFLATED) as zf: + zf.writestr('level1/level2/my_table.csv', df.to_csv(index=False)) + zip_bytes = buf.getvalue() + + meta_dict = { + 'METADATA_SPEC_VERSION': 'V1', + 'tables': { + 'my_table': { + 'columns': { + 'a': {'sdtype': 'numerical', 'computer_representation': 'Int64'}, + 'b': {'sdtype': 'categorical'}, + } + } + }, + 'relationships': [], } - client_mock.return_value.list_objects = Mock(return_value=contents_objects) + meta_bytes = json.dumps(meta_dict).encode() - def metadata_func(Bucket, Key): # noqa: N803 - if Key == 'SINGLE_TABLE/dataset1.zip': - return {'Metadata': {}} + def side_effect(key): + if key.endswith('data.zip'): + return zip_bytes + if key.endswith('metadata.json'): + return meta_bytes + raise KeyError(key) - return {'Metadata': {'size-mb': 456, 'num-tables': 654}} + mock_get.side_effect = side_effect - client_mock.return_value.head_object = MagicMock(side_effect=metadata_func) + out = tmp_path / 'outdir' # Run - tables_info = get_available_demos('single_table') + data, metadata = download_demo('single_table', 'nested', out) # Assert - expected_table = pd.DataFrame({ - 'dataset_name': ['dataset1', 'dataset2'], - 'size_MB': [np.nan, 456.00], - 'num_tables': [np.nan, 654], - }) - pd.testing.assert_frame_equal(tables_info, expected_table) + pd.testing.assert_frame_equal(data, df) + assert metadata.to_dict() == meta_dict + + meta_path = out / 'metadata.json' + assert meta_path.is_file() + + with open(meta_path, 
'rb') as f: + on_disk = f.read() + assert on_disk == meta_bytes + + +def test__find_text_key_returns_none_when_missing(): + """Test it returns None when the key is missing.""" + # Setup + contents = [ + {'Key': 'single_table/dataset/metadata.json'}, + {'Key': 'single_table/dataset/data.zip'}, + ] + dataset_prefix = 'single_table/dataset/' + + # Run + key = _find_text_key(contents, dataset_prefix, 'README.txt') + + # Assert + assert key is None + + +def test__find_text_key_ignores_nested_paths(): + """Test it ignores files in nested folders under the dataset prefix.""" + # Setup + contents = [ + {'Key': 'single_table/dataset1/bad_folder/SOURCE.txt'}, + ] + dataset_prefix = 'single_table/dataset1/' + + # Run + key = _find_text_key(contents, dataset_prefix, 'SOURCE.txt') + + # Assert + assert key is None + + +@patch('sdv.datasets.demo._get_data_from_bucket') +@patch('sdv.datasets.demo._list_objects') +def test__get_text_file_content_happy_path(mock_list, mock_get, tmpdir): + """Test it gets the text file content when it exists.""" + # Setup + mock_list.return_value = [ + {'Key': 'single_table/dataset1/README.txt'}, + ] + mock_get.return_value = 'Hello README'.encode() + + # Run + text = _get_text_file_content('single_table', 'dataset1', 'README.txt') + + # Assert + assert text == 'Hello README' + + +@patch('sdv.datasets.demo._list_objects') +def test__get_text_file_content_missing_key_returns_none(mock_list): + """Test it returns None when the key is missing.""" + # Setup + mock_list.return_value = [ + {'Key': 'single_table/dataset1/metadata.json'}, + ] + + # Run + text = _get_text_file_content('single_table', 'dataset1', 'README.txt') + + # Assert + assert text is None + + +@patch('sdv.datasets.demo._get_data_from_bucket') +@patch('sdv.datasets.demo._list_objects') +def test__get_text_file_content_fetch_error_returns_none(mock_list, mock_get): + """Test it returns None when the fetch error occurs.""" + # Setup + mock_list.return_value = [ + {'Key': 'single_table/dataset1/SOURCE.txt'}, + ] + mock_get.side_effect = Exception('boom') + + # Run + text = _get_text_file_content('single_table', 'dataset1', 'SOURCE.txt') + + # Assert + assert text is None + + +@patch('sdv.datasets.demo._get_data_from_bucket') +@patch('sdv.datasets.demo._list_objects') +def test__get_text_file_content_logs_on_fetch_error(mock_list, mock_get, caplog): + """It logs an info when fetching the key raises an error.""" + # Setup + mock_list.return_value = [ + {'Key': 'single_table/dataset1/SOURCE.txt'}, + ] + mock_get.side_effect = Exception('boom') + + # Run + caplog.set_level(logging.INFO, logger='sdv.datasets.demo') + text = _get_text_file_content('single_table', 'dataset1', 'SOURCE.txt') + + # Assert + assert text is None + assert 'Error fetching SOURCE.txt for dataset dataset1.' 
in caplog.text
+
+
+@patch('sdv.datasets.demo._get_data_from_bucket')
+@patch('sdv.datasets.demo._list_objects')
+def test__get_text_file_content_writes_file_when_output_filepath_given(
+    mock_list, mock_get, tmp_path
+):
+    """Test it writes the file when the output filepath is given."""
+    # Setup
+    mock_list.return_value = [
+        {'Key': 'single_table/dataset1/README.txt'},
+    ]
+    mock_get.return_value = 'Write me'.encode()
+    out = tmp_path / 'subdir' / 'readme.txt'
+
+    # Run
+    text = _get_text_file_content('single_table', 'dataset1', 'README.txt', str(out))
+
+    # Assert
+    assert text == 'Write me'
+    with open(out, 'r', encoding='utf-8') as f:
+        assert f.read() == 'Write me'
+
+
+@patch('sdv.datasets.demo._get_data_from_bucket')
+@patch('sdv.datasets.demo._list_objects')
+def test__get_text_file_content_logs_on_save_error(
+    mock_list, mock_get, tmp_path, caplog, monkeypatch
+):
+    """It logs an info when saving to disk fails."""
+    # Setup
+    mock_list.return_value = [
+        {'Key': 'single_table/dataset1/README.txt'},
+    ]
+    mock_get.return_value = 'Write me'.encode()
+    out = tmp_path / 'subdir' / 'readme.txt'
+
+    def _fail_open(*args, **kwargs):
+        raise OSError('fail-open')
+
+    monkeypatch.setattr('builtins.open', _fail_open)
+
+    # Run
+    caplog.set_level(logging.INFO, logger='sdv.datasets.demo')
+    text = _get_text_file_content('single_table', 'dataset1', 'README.txt', str(out))
+
+    # Assert
+    assert text == 'Write me'
+    assert 'Error saving README.txt for dataset dataset1.' in caplog.text
+
+
+def test_get_readme_and_get_source_call_wrapper(monkeypatch):
+    """Test that get_readme and get_source delegate to _get_text_file_content."""
+    # Setup
+    calls = []
+
+    def fake(modality, dataset_name, filename, output_filepath=None):
+        calls.append((modality, dataset_name, filename, output_filepath))
+        return 'X'
+
+    monkeypatch.setattr('sdv.datasets.demo._get_text_file_content', fake)
+
+    # Run
+    readme = get_readme('single_table', 'dataset1', '/tmp/readme.txt')
+    source = get_source('single_table', 'dataset1', '/tmp/source.txt')
+
+    # Assert
+    assert readme == 'X' and source == 'X'
+    assert calls[0] == ('single_table', 'dataset1', 'README.txt', '/tmp/readme.txt')
+    assert calls[1] == ('single_table', 'dataset1', 'SOURCE.txt', '/tmp/source.txt')
+
+
+@patch('sdv.datasets.demo._get_data_from_bucket')
+@patch('sdv.datasets.demo._list_objects')
+def test_get_readme_raises_if_output_file_exists(mock_list, mock_get, tmp_path):
+    """get_readme should raise ValueError if output file already exists."""
+    # Setup
+    mock_list.return_value = [
+        {'Key': 'single_table/dataset1/README.txt'},
+    ]
+    mock_get.return_value = b'Readme contents'
+    out = tmp_path / 'subdir' / 'readme.txt'
+    out.parent.mkdir(parents=True, exist_ok=True)
+    out.write_text('already here', encoding='utf-8')
+
+    # Run / Assert
+    err = f"A file named '{out}' already exists. Please specify a different filepath."
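+    # Note: _list_objects and _get_data_from_bucket are mocked to succeed, so the
+    # ValueError asserted below must come from get_readme's own check for a
+    # pre-existing output file (assumed to run before the README is written to disk).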
+ with pytest.raises(ValueError, match=re.escape(err)): + get_readme('single_table', 'dataset1', str(out)) + + +@patch('sdv.datasets.demo._get_data_from_bucket') +@patch('sdv.datasets.demo._list_objects') +def test_get_source_raises_if_output_file_exists(mock_list, mock_get, tmp_path): + """get_source should raise ValueError if output file already exists.""" + # Setup + mock_list.return_value = [ + {'Key': 'single_table/dataset1/SOURCE.txt'}, + ] + mock_get.return_value = b'Source contents' + out = tmp_path / 'subdir' / 'source.txt' + out.parent.mkdir(parents=True, exist_ok=True) + out.write_text('already here', encoding='utf-8') + + # Run / Assert + err = f"A file named '{out}' already exists. Please specify a different filepath." + with pytest.raises(ValueError, match=re.escape(err)): + get_source('single_table', 'dataset1', str(out)) + + +def test_get_readme_raises_for_non_txt_output(): + """get_readme should raise ValueError if output path is not .txt.""" + err = "The README can only be saved as a txt file. Please provide a filepath ending in '.txt'" + with pytest.raises(ValueError, match=re.escape(err)): + get_readme('single_table', 'dataset1', '/tmp/readme.md') + + +def test_get_source_raises_for_non_txt_output(): + """get_source should raise ValueError if output path is not .txt.""" + err = "The source can only be saved as a txt file. Please provide a filepath ending in '.txt'" + with pytest.raises(ValueError, match=re.escape(err)): + get_source('single_table', 'dataset1', '/tmp/source.pdf') + + +@patch('sdv.datasets.demo._list_objects') +def test_get_readme_missing_emits_warning(mock_list): + """When README is missing, warn the user with DemoResourceNotFoundWarning.""" + # Setup + mock_list.return_value = [ + {'Key': 'single_table/dataset1/metadata.json'}, + ] + + # Run / Assert + warn_msg = 'No README information is available for this dataset.' + with pytest.warns(DemoResourceNotFoundWarning, match=warn_msg): + result = get_readme('single_table', 'dataset1') + + assert result is None + + +@patch('sdv.datasets.demo._list_objects') +def test_get_source_missing_emits_warning(mock_list): + """When SOURCE is missing, warn the user with DemoResourceNotFoundWarning.""" + # Setup + mock_list.return_value = [ + {'Key': 'single_table/dataset1/metadata.json'}, + ] + + # Run / Assert + warn_msg = 'No source information is available for this dataset.' + with pytest.warns(DemoResourceNotFoundWarning, match=warn_msg): + result = get_source('single_table', 'dataset1') + + assert result is None + + +@patch('sdv.datasets.demo._list_objects') +def test_get_source_missing_emits_warning_and_does_not_create_file(mock_list, tmp_path): + """When source is missing and output path provided, warn and do not create a file.""" + # Setup + mock_list.return_value = [ + {'Key': 'single_table/dataset1/metadata.json'}, + ] + out = tmp_path / 'subdir' / 'source.txt' + + # Run / Assert + warn_msg = re.escape( + 'No source information is available for this dataset.' + f' The requested file ({str(out)}) will not be created.' 
+    )
+    with pytest.warns(DemoResourceNotFoundWarning, match=warn_msg):
+        result = get_source('single_table', 'dataset1', str(out))
+
+    assert result is None
+
+
+@patch('sdv.datasets.demo._list_objects')
+def test_get_readme_missing_emits_warning_and_does_not_create_file(mock_list, tmp_path):
+    """When README is missing and output path provided, warn and do not create a file."""
+    # Setup
+    mock_list.return_value = [
+        {'Key': 'single_table/dataset1/metadata.json'},
+    ]
+    out = tmp_path / 'subdir' / 'readme.txt'
+
+    # Run / Assert
+    warn_msg = re.escape(
+        'No README information is available for this dataset.'
+        f' The requested file ({str(out)}) will not be created.'
+    )
+    with pytest.warns(DemoResourceNotFoundWarning, match=warn_msg):
+        result = get_readme('single_table', 'dataset1', str(out))
+
+    assert result is None
+
+
+@patch('sdv.datasets.demo._get_data_from_bucket')
+@patch('sdv.datasets.demo._list_objects')
+def test_download_demo_raises_when_no_csv_in_zip_single_table(mock_list, mock_get):
+    """It should raise a helpful error if the zip contains no CSVs (single_table)."""
+    # Setup
+    mock_list.return_value = [
+        {'Key': 'single_table/word/data.zip'},
+        {'Key': 'single_table/word/metadata.json'},
+    ]
+
+    # Create a zip with a non-CSV file only
+    zip_buf = io.BytesIO()
+    with zipfile.ZipFile(zip_buf, 'w', compression=zipfile.ZIP_DEFLATED) as zf:
+        zf.writestr('README.txt', 'no tables here')
+
+    zip_bytes = zip_buf.getvalue()
+    meta_bytes = json.dumps({'METADATA_SPEC_VERSION': 'V1'}).encode()
+
+    mock_get.side_effect = lambda key: zip_bytes if key.endswith('data.zip') else meta_bytes
+
+    # Run and Assert
+    msg = 'Demo data could not be downloaded because no csv files were found in data.zip'
+    with pytest.raises(DemoResourceNotFoundError, match=re.escape(msg)):
+        download_demo('single_table', 'word')
+
+
+@patch('sdv.datasets.demo._get_data_from_bucket')
+@patch('sdv.datasets.demo._list_objects')
+def test_download_demo_skips_non_csv_in_memory_single_warning(mock_list, mock_get):
+    """In-memory path: skip non-CSV files and emit one warning listing them; load valid CSVs."""
+    # Setup
+    mock_list.return_value = [
+        {'Key': 'single_table/mix/data.zip'},
+        {'Key': 'single_table/mix/metadata.json'},
+    ]
+
+    df = pd.DataFrame({'id': [1, 2], 'name': ['a', 'b']})
+    buf = io.BytesIO()
+    with zipfile.ZipFile(buf, mode='w', compression=zipfile.ZIP_DEFLATED) as zf:
+        zf.writestr('good.csv', df.to_csv(index=False))
+        zf.writestr('note.txt', 'hello world')
+        zf.writestr('nested/readme.md', '# readme')
+        # Add a directory entry explicitly
+        zf.writestr('empty_dir/', '')
+    zip_bytes = buf.getvalue()
+
+    meta_bytes = json.dumps({
+        'METADATA_SPEC_VERSION': 'V1',
+        'tables': {
+            'good': {
+                'columns': {
+                    'id': {'sdtype': 'numerical', 'computer_representation': 'Int64'},
+                    'name': {'sdtype': 'categorical'},
+                }
+            }
+        },
+        'relationships': [],
+    }).encode()
+
+    mock_get.side_effect = lambda key: zip_bytes if key.endswith('data.zip') else meta_bytes
+
+    # Run and Assert
+    warn_msg = 'Skipped files: empty_dir/, nested/readme.md, note.txt'
+    with pytest.warns(UserWarning, match=warn_msg) as rec:
+        data, _ = download_demo('single_table', 'mix')
+
+    assert len(rec) == 1
+    expected = pd.DataFrame({'id': [1, 2], 'name': ['a', 'b']})
+    pd.testing.assert_frame_equal(data, expected)
+
+
+@patch('sdv.datasets.demo._get_data_from_bucket')
+@patch('sdv.datasets.demo._list_objects')
+def test_download_demo_on_disk_warns_failed_csv_only(mock_list, mock_get, tmp_path, monkeypatch):
+    """On-disk path: one warning lists failed CSVs with their errors and non-CSV files; nothing is
skipped silently.""" + # Setup + mock_list.return_value = [ + {'Key': 'single_table/mix/data.zip'}, + {'Key': 'single_table/mix/metadata.json'}, + ] + + good = pd.DataFrame({'x': [1, 2]}) + buf = io.BytesIO() + with zipfile.ZipFile(buf, mode='w', compression=zipfile.ZIP_DEFLATED) as zf: + zf.writestr('good.csv', good.to_csv(index=False)) + zf.writestr('bad.csv', 'will_fail') + zf.writestr('info.txt', 'ignore me') + zip_bytes = buf.getvalue() + + meta_bytes = json.dumps({ + 'METADATA_SPEC_VERSION': 'V1', + 'tables': { + 'good': { + 'columns': { + 'x': {'sdtype': 'numerical', 'computer_representation': 'Int64'}, + } + } + }, + 'relationships': [], + }).encode() + + mock_get.side_effect = lambda key: zip_bytes if key.endswith('data.zip') else meta_bytes + + # Force read_csv to fail on bad.csv only + orig_read_csv = pd.read_csv + + def fake_read_csv(path_or_buf, *args, **kwargs): + if isinstance(path_or_buf, str) and path_or_buf.endswith('bad.csv'): + raise ValueError('bad-parse') + return orig_read_csv(path_or_buf, *args, **kwargs) + + monkeypatch.setattr('pandas.read_csv', fake_read_csv) + + out_dir = tmp_path / 'mix_out' + + # Run and Assert + warn_msg = 'Skipped files: bad.csv: bad-parse, info.txt' + with pytest.warns(UserWarning, match=warn_msg) as rec: + data, _ = download_demo('single_table', 'mix', out_dir) + + assert len(rec) == 1 + pd.testing.assert_frame_equal(data, good) + + +@patch('sdv.datasets.demo._get_data_from_bucket') +@patch('sdv.datasets.demo._list_objects') +def test_download_demo_handles_non_utf8_in_memory(mock_list, mock_get): + """It should successfully read Latin-1 encoded CSVs from in-memory extraction.""" + # Setup + mock_list.return_value = [ + {'Key': 'single_table/nonutf/data.zip'}, + {'Key': 'single_table/nonutf/metadata.json'}, + ] + + df = pd.DataFrame({'id': [1], 'name': ['café']}) + buf = io.BytesIO() + with zipfile.ZipFile(buf, mode='w', compression=zipfile.ZIP_DEFLATED) as zf: + zf.writestr('nonutf.csv', df.to_csv(index=False).encode('latin-1')) + zip_bytes = buf.getvalue() + + meta_bytes = json.dumps({ + 'METADATA_SPEC_VERSION': 'V1', + 'tables': { + 'nonutf': { + 'columns': { + 'id': {'sdtype': 'numerical', 'computer_representation': 'Int64'}, + 'name': {'sdtype': 'categorical'}, + } + } + }, + 'relationships': [], + }).encode() + + mock_get.side_effect = lambda key: zip_bytes if key.endswith('data.zip') else meta_bytes + + # Run + data, _ = download_demo('single_table', 'nonutf') + + # Assert + expected = pd.DataFrame({'id': [1], 'name': ['café']}) + pd.testing.assert_frame_equal(data, expected) + + +@patch('sdv.datasets.demo._get_data_from_bucket') +@patch('sdv.datasets.demo._list_objects') +def test_download_demo_handles_non_utf8_on_disk(mock_list, mock_get, tmp_path): + """It should successfully read Latin-1 encoded CSVs when extracted to disk.""" + # Setup + mock_list.return_value = [ + {'Key': 'single_table/nonutf/data.zip'}, + {'Key': 'single_table/nonutf/metadata.json'}, + ] + + df = pd.DataFrame({'id': [1], 'name': ['café']}) + buf = io.BytesIO() + with zipfile.ZipFile(buf, mode='w', compression=zipfile.ZIP_DEFLATED) as zf: + zf.writestr('nonutf.csv', df.to_csv(index=False).encode('latin-1')) + zip_bytes = buf.getvalue() + + meta_bytes = json.dumps({ + 'METADATA_SPEC_VERSION': 'V1', + 'tables': { + 'nonutf': { + 'columns': { + 'id': {'sdtype': 'numerical', 'computer_representation': 'Int64'}, + 'name': {'sdtype': 'categorical'}, + } + } + }, + 'relationships': [], + }).encode() + + mock_get.side_effect = lambda key: zip_bytes if 
key.endswith('data.zip') else meta_bytes + + out_dir = tmp_path / 'latin_out' + + # Run + data, _ = download_demo('single_table', 'nonutf', out_dir) + + # Assert + expected = pd.DataFrame({'id': [1], 'name': ['café']}) + pd.testing.assert_frame_equal(data, expected) diff --git a/tests/unit/multi_table/test_dayz.py b/tests/unit/multi_table/test_dayz.py index 0d2dba095..ac214f695 100644 --- a/tests/unit/multi_table/test_dayz.py +++ b/tests/unit/multi_table/test_dayz.py @@ -5,7 +5,6 @@ import pandas as pd import pytest -from sdv.datasets.demo import download_demo from sdv.errors import SynthesizerInputError, SynthesizerProcessingError from sdv.metadata import Metadata from sdv.multi_table.dayz import ( @@ -439,21 +438,3 @@ def test__validate_relationships_is_list_of_dicts(self, metadata): with pytest.raises(SynthesizerProcessingError, match=expected_msg): DayZSynthesizer.validate_parameters(metadata, {'relationships': ['a', 'b', 'c']}) - - def test__validate_min_cardinality_allows_zero(self): - """Test that min_cardinality=0 is allowed and does not raise.""" - # Setup - data, metadata = download_demo('multi_table', 'financial_v1') - dayz_parameters = DayZSynthesizer.create_parameters(data, metadata) - dayz_parameters['relationships'] = [ - { - 'parent_table_name': 'district', - 'parent_primary_key': 'district_id', - 'child_table_name': 'account', - 'child_foreign_key': 'district_id', - 'min_cardinality': 0, - } - ] - - # Run - DayZSynthesizer.validate_parameters(metadata, dayz_parameters) diff --git a/tests/unit/single_table/test_dayz.py b/tests/unit/single_table/test_dayz.py index 76901fe50..4a1904c3d 100644 --- a/tests/unit/single_table/test_dayz.py +++ b/tests/unit/single_table/test_dayz.py @@ -6,10 +6,8 @@ import pandas as pd import pytest -from sdv.datasets.demo import download_demo from sdv.errors import SynthesizerInputError, SynthesizerProcessingError from sdv.metadata import Metadata -from sdv.multi_table.dayz import DayZSynthesizer as MultiTableDayZSynthesizer from sdv.single_table.dayz import ( DayZSynthesizer, _detect_column_parameters, @@ -560,21 +558,6 @@ def test__validate_parameters_errors_with_multi_table_metadata(self): with pytest.raises(SynthesizerProcessingError, match=expected_error_msg): _validate_parameters(metadata, dayz_parameters) - def test__validate_parameters_errors_with_relationships(self): - """Test that single-table validation errors if relationships are provided.""" - # Setup - data, metadata = download_demo('multi_table', 'financial_v1') - dayz_parameters = MultiTableDayZSynthesizer.create_parameters(data, metadata) - del dayz_parameters['relationships'] - - # Run and Assert - expected_error_msg = re.escape( - 'Invalid metadata provided for single-table DayZSynthesizer. The metadata contains ' - 'multiple tables. Please use multi-table DayZSynthesizer instead.' - ) - with pytest.raises(SynthesizerProcessingError, match=expected_error_msg): - DayZSynthesizer.validate_parameters(metadata, dayz_parameters) - def test_create_parameters_returns_valid_defaults(self): """Test create_parameters returns valid defaults.""" # Setup