Skip to content

Commit 67af8a7

Browse files
improve reproducibility interface
1 parent 7b4f3e8 commit 67af8a7

File tree

5 files changed

+22
-31
lines changed

mostlyai/qa/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,10 +19,11 @@
1919
from packaging.version import Version
2020

2121
from mostlyai.qa.logging import init_logging
22+
from mostlyai.qa.random_state import set_random_state
2223
from mostlyai.qa.reporting import report
2324
from mostlyai.qa.reporting_from_statistics import report_from_statistics
2425

25-
__all__ = ["report", "report_from_statistics", "init_logging"]
26+
__all__ = ["report", "report_from_statistics", "init_logging", "set_random_state"]
2627
__version__ = "1.9.5"
2728

2829
warnings.filterwarnings("ignore", category=FutureWarning, module="phik")

mostlyai/qa/_common.py

Lines changed: 0 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -13,8 +13,6 @@
1313
# limitations under the License.
1414

1515
import logging
16-
import os
17-
import struct
1816
from typing import Protocol
1917

2018
import pandas as pd
@@ -124,21 +122,3 @@ def determine_data_size(
124122
return len(tgt_keys)
125123
else:
126124
return len(tgt_data)
127-
128-
129-
def set_random_state(random_state: int | None = None):
    """
    Seed the global `random` and `numpy.random` generators.

    Args:
        random_state: Seed to apply to both generators. If None, an unsigned
            32-bit seed is drawn from `os.urandom` instead; only an explicitly
            provided seed is logged.
    """
    if random_state is None:
        # no seed requested: take 4 bytes from the OS entropy source and
        # interpret them as one unsigned 32-bit integer (native byte order)
        (os_seed,) = struct.unpack("I", os.urandom(4))
        random_state = int(os_seed)
    else:
        _LOG.info(f"Global random_state set to `{random_state}`")

    import random
    import numpy as np

    # both global generators receive the very same seed
    random.seed(random_state)
    np.random.seed(random_state)

mostlyai/qa/reporting.py

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -62,7 +62,6 @@
6262
TGT_COLUMN_PREFIX,
6363
REPORT_CREDITS,
6464
ProgressCallbackWrapper,
65-
set_random_state,
6665
)
6766
from mostlyai.qa._filesystem import Statistics, TemporaryWorkspace
6867

@@ -88,7 +87,6 @@ def report(
8887
max_sample_size_embeddings: int | None = None,
8988
statistics_path: str | Path | None = None,
9089
update_progress: ProgressCallback | None = None,
91-
random_state: int | None = None,
9290
) -> tuple[Path, ModelMetrics | None]:
9391
"""
9492
Generate an HTML report and metrics for assessing synthetic data quality.
@@ -123,15 +121,12 @@ def report(
123121
max_sample_size_embeddings: The maximum sample size for embedding calculations.
124122
statistics_path: The path of where to store the statistics to be used by `report_from_statistics`
125123
update_progress: The progress callback.
126-
random_state: Seed for the random number generators.
127124
128125
Returns:
129126
The path to the generated HTML report.
130127
Metrics instance with accuracy, similarity, and distances metrics.
131128
"""
132129

133-
set_random_state(random_state)
134-
135130
if syn_ctx_data is not None:
136131
if ctx_primary_key is None:
137132
raise ValueError("If syn_ctx_data is provided, then ctx_primary_key must also be provided.")

mostlyai/qa/reporting_from_statistics.py

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,6 @@
3333
determine_data_size,
3434
REPORT_CREDITS,
3535
ProgressCallbackWrapper,
36-
set_random_state,
3736
)
3837
from mostlyai.qa._filesystem import Statistics, TemporaryWorkspace
3938

@@ -54,7 +53,6 @@ def report_from_statistics(
5453
max_sample_size_accuracy: int | None = None,
5554
max_sample_size_coherence: int | None = None,
5655
update_progress: ProgressCallback | None = None,
57-
random_state: int | None = None,
5856
) -> Path:
5957
"""
6058
Generate an HTML report based on previously generated statistics and newly provided synthetic data samples.
@@ -72,14 +70,11 @@ def report_from_statistics(
7270
max_sample_size_accuracy: The maximum sample size for accuracy calculations.
7371
max_sample_size_coherence: The maximum sample size for coherence calculations.
7472
update_progress: The progress callback.
75-
random_state: Seed for the random number generators.
7673
7774
Returns:
7875
The path to the generated HTML report.
7976
"""
8077

81-
set_random_state(random_state)
82-
8378
with (
8479
TemporaryWorkspace() as workspace,
8580
ProgressCallbackWrapper(update_progress) as progress,

tests/end_to_end/test_report.py

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -304,3 +304,23 @@ def generate_dates(start_date, end_date, num_samples):
304304
"Expected a warning about dtype mismatch for column 'dt'"
305305
)
306306
assert statistics.accuracy.overall > 0.6
307+
308+
309+
def test_reproducibility(tmp_path):
    """
    Check that `qa.report` yields equal metrics when re-run with the same global seed.

    The seed is re-applied via `qa.set_random_state` before each run, and both
    runs receive exactly the same inputs, statistics path and sample-size limits.
    """
    statistics_path = tmp_path / "statistics"
    syn_tgt_data = mock_data(150)
    trn_tgt_data = mock_data(150)
    hol_tgt_data = mock_data(150)
    # NOTE(review): the sample-size limits (120, 80) are below the 150 mocked rows,
    # presumably so that random subsampling is exercised - confirm against `report`.
    kwargs = {
        "syn_tgt_data": syn_tgt_data,
        "trn_tgt_data": trn_tgt_data,
        "hol_tgt_data": hol_tgt_data,
        "statistics_path": statistics_path,
        "max_sample_size_accuracy": 120,
        "max_sample_size_embeddings": 80,
    }
    qa.set_random_state(45)
    _, m1 = qa.report(**kwargs)
    qa.set_random_state(45)
    _, m2 = qa.report(**kwargs)
    # `report` is annotated as returning `ModelMetrics | None`; without this guard
    # the equality check below would pass vacuously if both runs returned None
    assert m1 is not None, "Expected metrics from the first report run"
    assert m1 == m2

0 commit comments

Comments
 (0)