Skip to content

Commit 78fee6b

Browse files
committed
Split methods into helpers
1 parent 2d40658 commit 78fee6b

File tree

3 files changed

+91
-101
lines changed

3 files changed

+91
-101
lines changed

sdv/datasets/demo.py

Lines changed: 91 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,44 @@ def is_metainfo_yaml(key):
368368
yield dataset_name, key
369369

370370

371+
def _get_info_from_yaml_key(yaml_key):
    """Fetch the object stored under ``yaml_key`` and parse it as YAML.

    Args:
        yaml_key (str):
            Key passed to ``_get_data_from_bucket`` to retrieve the raw content.

    Returns:
        The value produced by ``yaml.safe_load``. When the parsed result is
        falsy (for example an empty document), an empty ``dict`` is returned.
    """
    parsed = yaml.safe_load(_get_data_from_bucket(yaml_key))
    if parsed:
        return parsed

    return {}
375+
376+
377+
def _parse_size_mb(size_mb_val, dataset_name):
    """Parse the size (MB) value into a float or NaN with logging on failures.

    Args:
        size_mb_val:
            Raw ``dataset-size-mb`` value. ``None`` is treated as missing.
        dataset_name (str):
            Name of the dataset; only used in the log message.

    Returns:
        float:
            The parsed size, or ``np.nan`` when the value is ``None`` or
            cannot be converted to a float.
    """
    if size_mb_val is None:
        return np.nan

    try:
        return float(size_mb_val)
    except (ValueError, TypeError, OverflowError):
        # OverflowError: float() rejects integers too large for a double.
        # Treat it like any other unparseable value instead of propagating.
        LOGGER.info(
            f'Invalid dataset-size-mb {size_mb_val} for dataset {dataset_name}; defaulting to NaN.'
        )
        return np.nan
386+
387+
388+
def _parse_num_tables(num_tables_val, dataset_name):
    """Parse the num-tables value into an int or NaN with logging on failures.

    Args:
        num_tables_val:
            Raw ``num-tables`` value. Strings are first converted to float;
            values for which ``pd.isna`` is true are treated as missing.
        dataset_name (str):
            Name of the dataset; only used in the log messages.

    Returns:
        int or float:
            The value truncated to ``int``, or ``np.nan`` when it is missing
            or cannot be converted.
    """
    if isinstance(num_tables_val, str):
        try:
            num_tables_val = float(num_tables_val)
        except (ValueError, TypeError):
            LOGGER.info(
                f'Could not cast num_tables_val {num_tables_val} to float for '
                f'dataset {dataset_name}; defaulting to NaN.'
            )
            num_tables_val = np.nan

    try:
        return int(num_tables_val) if not pd.isna(num_tables_val) else np.nan
    except (ValueError, TypeError, OverflowError):
        # OverflowError: int() cannot convert +/-infinity (e.g. the string
        # 'inf' or a YAML ``.inf``), which pd.isna does not flag as missing.
        LOGGER.info(
            f'Invalid num-tables {num_tables_val} for dataset {dataset_name} when parsing as int.'
        )
        return np.nan
407+
408+
371409
def get_available_demos(modality):
372410
"""Get demo datasets available for a ``modality``.
373411
@@ -387,38 +425,10 @@ def get_available_demos(modality):
387425
tables_info = defaultdict(list)
388426
for dataset_name, yaml_key in _iter_metainfo_yaml_entries(contents, modality):
389427
try:
390-
raw = _get_data_from_bucket(yaml_key)
391-
info = yaml.safe_load(raw) or {}
428+
info = _get_info_from_yaml_key(yaml_key)
392429

393-
size_mb_val = info.get('dataset-size-mb')
394-
try:
395-
size_mb = float(size_mb_val) if size_mb_val is not None else np.nan
396-
except (ValueError, TypeError):
397-
LOGGER.info(
398-
f'Invalid dataset-size-mb {size_mb_val} for dataset '
399-
f'{dataset_name}; defaulting to NaN.'
400-
)
401-
size_mb = np.nan
402-
403-
num_tables_val = info.get('num-tables', np.nan)
404-
if isinstance(num_tables_val, str):
405-
try:
406-
num_tables_val = float(num_tables_val)
407-
except (ValueError, TypeError):
408-
LOGGER.info(
409-
f'Could not cast num_tables_val {num_tables_val} to float for '
410-
f'dataset {dataset_name}; defaulting to NaN.'
411-
)
412-
num_tables_val = np.nan
413-
414-
try:
415-
num_tables = int(num_tables_val) if not pd.isna(num_tables_val) else np.nan
416-
except (ValueError, TypeError):
417-
LOGGER.info(
418-
f'Invalid num-tables {num_tables_val} for '
419-
f'dataset {dataset_name} when parsing as int.'
420-
)
421-
num_tables = np.nan
430+
size_mb = _parse_size_mb(info.get('dataset-size-mb'), dataset_name)
431+
num_tables = _parse_num_tables(info.get('num-tables', np.nan), dataset_name)
422432

423433
tables_info['dataset_name'].append(dataset_name)
424434
tables_info['size_MB'].append(size_mb)
@@ -456,6 +466,53 @@ def _find_text_key(contents, dataset_prefix, filename):
456466
return None
457467

458468

469+
def _validate_text_file_content(modality, output_filepath, filename):
    """Check the arguments used when fetching a text file for a dataset.

    The modality check is delegated to ``_validate_modalities``.

    Args:
        modality (str):
            Modality forwarded to ``_validate_modalities``.
        output_filepath (str or None):
            Destination path. When not ``None`` it must end in ``.txt``.
        filename (str or None):
            Name of the requested file; only used to word the error message.

    Raises:
        ValueError:
            If ``output_filepath`` is given and does not end in ``.txt``.
    """
    _validate_modalities(modality)
    if output_filepath is None or str(output_filepath).endswith('.txt'):
        return

    if 'readme' in (filename or '').lower():
        file_type = 'README'
    else:
        file_type = 'source'

    raise ValueError(
        f'The {file_type} can only be saved as a txt file. '
        "Please provide a filepath ending in '.txt'"
    )
479+
480+
481+
def _raise_warnings(filename, output_filepath):
    """Emit a ``DemoResourceNotFoundWarning`` for a missing text resource.

    Args:
        filename (str or None):
            Name of the file that was not found. ``README.txt`` and
            ``SOURCE.txt`` (case-insensitive) get dedicated messages.
        output_filepath (str or None):
            If truthy, the message also states this file will not be created.
    """
    known_messages = {
        'README.TXT': 'No README information is available for this dataset.',
        'SOURCE.TXT': 'No source information is available for this dataset.',
    }
    msg = known_messages.get((filename or '').upper())
    if msg is None:
        msg = f'No {filename} information is available for this dataset.'

    if output_filepath:
        msg = f'{msg} The requested file ({output_filepath}) will not be created.'

    warnings.warn(msg, DemoResourceNotFoundWarning)
494+
495+
496+
def _save_document(text, output_filepath, filename, dataset_name):
    """Write ``text`` to ``output_filepath`` when a destination is given.

    Args:
        text (str):
            Content to write, encoded as UTF-8.
        output_filepath (str or None):
            Destination path. Nothing happens when it is falsy.
        filename (str):
            Name of the source file; only used in the log message.
        dataset_name (str):
            Name of the dataset; only used in the log message.

    Raises:
        ValueError:
            If something already exists at ``output_filepath``.
    """
    if output_filepath:
        if os.path.exists(str(output_filepath)):
            raise ValueError(
                f"A file named '{output_filepath}' already exists. Please specify a different filepath."
            )

        # Best effort: any failure creating folders or writing is logged, not raised.
        try:
            folder = os.path.dirname(str(output_filepath))
            if folder:
                os.makedirs(folder, exist_ok=True)
            with open(output_filepath, 'w', encoding='utf-8') as handle:
                handle.write(text)
        except Exception:
            LOGGER.info(f'Error saving {filename} for dataset {dataset_name}.')
514+
515+
459516
def _get_text_file_content(modality, dataset_name, filename, output_filepath=None):
460517
"""Fetch text file content under the dataset prefix.
461518
@@ -473,29 +530,13 @@ def _get_text_file_content(modality, dataset_name, filename, output_filepath=Non
473530
str or None:
474531
The decoded text contents if the file exists, otherwise ``None``.
475532
"""
476-
_validate_modalities(modality)
477-
if output_filepath is not None and not str(output_filepath).endswith('.txt'):
478-
fname = (filename or '').lower()
479-
file_type = 'README' if 'readme' in fname else 'source'
480-
raise ValueError(
481-
f'The {file_type} can only be saved as a txt file. '
482-
"Please provide a filepath ending in '.txt'"
483-
)
533+
_validate_text_file_content(modality, output_filepath, filename)
484534

485535
dataset_prefix = f'{modality}/{dataset_name}/'
486536
contents = _list_objects(dataset_prefix)
487-
488537
key = _find_text_key(contents, dataset_prefix, filename)
489538
if not key:
490-
if file_type in ('README', 'SOURCE'):
491-
msg = f'No {file_type} information is available for this dataset.
492-
else:
493-
msg = f'No {filename} information is available for this dataset.'
494-
495-
if output_filepath:
496-
msg = f'{msg} The requested file ({output_filepath}) will not be created.'
497-
498-
warnings.warn(msg, DemoResourceNotFoundWarning)
539+
_raise_warnings(filename, output_filepath)
499540
return None
500541

501542
try:
@@ -505,22 +546,7 @@ def _get_text_file_content(modality, dataset_name, filename, output_filepath=Non
505546
return None
506547

507548
text = raw.decode('utf-8', errors='replace')
508-
if output_filepath:
509-
if os.path.exists(str(output_filepath)):
510-
raise ValueError(
511-
f"A file named '{output_filepath}' already exists. "
512-
'Please specify a different filepath.'
513-
)
514-
try:
515-
parent = os.path.dirname(str(output_filepath))
516-
if parent:
517-
os.makedirs(parent, exist_ok=True)
518-
with open(output_filepath, 'w', encoding='utf-8') as f:
519-
f.write(text)
520-
521-
except Exception:
522-
LOGGER.info(f'Error saving {filename} for dataset {dataset_name}.')
523-
pass
549+
_save_document(text, output_filepath, filename, dataset_name)
524550

525551
return text
526552

tests/unit/multi_table/test_dayz.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import pandas as pd
55
import pytest
66

7-
from sdv.datasets.demo import download_demo
87
from sdv.errors import SynthesizerInputError, SynthesizerProcessingError
98
from sdv.metadata import Metadata
109
from sdv.multi_table.dayz import (
@@ -333,21 +332,3 @@ def test__validate_relationships_is_list_of_dicts(self, metadata):
333332

334333
with pytest.raises(SynthesizerProcessingError, match=expected_msg):
335334
DayZSynthesizer.validate_parameters(metadata, {'relationships': ['a', 'b', 'c']})
336-
337-
def test__validate_min_cardinality_allows_zero(self):
338-
"""Test that min_cardinality=0 is allowed and does not raise."""
339-
# Setup
340-
data, metadata = download_demo('multi_table', 'financial_v1')
341-
dayz_parameters = DayZSynthesizer.create_parameters(data, metadata)
342-
dayz_parameters['relationships'] = [
343-
{
344-
'parent_table_name': 'district',
345-
'parent_primary_key': 'district_id',
346-
'child_table_name': 'account',
347-
'child_foreign_key': 'district_id',
348-
'min_cardinality': 0,
349-
}
350-
]
351-
352-
# Run
353-
DayZSynthesizer.validate_parameters(metadata, dayz_parameters)

tests/unit/single_table/test_dayz.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,8 @@
55
import pandas as pd
66
import pytest
77

8-
from sdv.datasets.demo import download_demo
98
from sdv.errors import SynthesizerInputError, SynthesizerProcessingError
109
from sdv.metadata import Metadata
11-
from sdv.multi_table.dayz import DayZSynthesizer as MultiTableDayZSynthesizer
1210
from sdv.single_table.dayz import (
1311
DayZSynthesizer,
1412
_validate_column_parameters,
@@ -419,21 +417,6 @@ def test__validate_parameters_errors_with_multi_table_metadata(self):
419417
with pytest.raises(SynthesizerProcessingError, match=expected_error_msg):
420418
_validate_parameters(metadata, dayz_parameters)
421419

422-
def test__validate_parameters_errors_with_relationships(self):
423-
"""Test that single-table validation errors if relationships are provided."""
424-
# Setup
425-
data, metadata = download_demo('multi_table', 'financial_v1')
426-
dayz_parameters = MultiTableDayZSynthesizer.create_parameters(data, metadata)
427-
del dayz_parameters['relationships']
428-
429-
# Run and Assert
430-
expected_error_msg = re.escape(
431-
'Invalid metadata provided for single-table DayZSynthesizer. The metadata contains '
432-
'multiple tables. Please use multi-table DayZSynthesizer instead.'
433-
)
434-
with pytest.raises(SynthesizerProcessingError, match=expected_error_msg):
435-
DayZSynthesizer.validate_parameters(metadata, dayz_parameters)
436-
437420
def test_create_parameters_returns_valid_defaults(self):
438421
"""Test create_parameters returns valid defaults."""
439422
# Setup

0 commit comments

Comments
 (0)