diff --git a/algoperf/random_utils.py b/algoperf/random_utils.py
index a579976ad..d259f7849 100644
--- a/algoperf/random_utils.py
+++ b/algoperf/random_utils.py
@@ -18,30 +18,30 @@
 # Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 31 - 1] (an
 # unsigned int), while RandomState.randint only accepts and returns signed ints.
-MAX_INT32 = 2**31 - 1
-MIN_INT32 = 0
+MAX_UINT32 = 2**31 - 1
+MIN_UINT32 = 0
 
 SeedType = Union[int, list, np.ndarray]
 
 
 def _signed_to_unsigned(seed: SeedType) -> SeedType:
   if isinstance(seed, int):
-    return seed % MAX_INT32
+    return seed % MAX_UINT32
   if isinstance(seed, list):
-    return [s % MAX_INT32 for s in seed]
+    return [s % MAX_UINT32 for s in seed]
   if isinstance(seed, np.ndarray):
-    return np.array([s % MAX_INT32 for s in seed.tolist()])
+    return np.array([s % MAX_UINT32 for s in seed.tolist()])
 
 
 def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]:
   rng = np.random.RandomState(seed=_signed_to_unsigned(seed))
-  new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32)
+  new_seed = rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32)
   return [new_seed, data]
 
 
 def _split(seed: SeedType, num: int = 2) -> SeedType:
   rng = np.random.RandomState(seed=_signed_to_unsigned(seed))
-  return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num, 2])
+  return rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32, size=[num, 2])
 
 
 def _PRNGKey(seed: SeedType) -> SeedType:  # pylint: disable=invalid-name
diff --git a/tests/test_evals_time.py b/tests/test_evals_time.py
new file mode 100644
index 000000000..1d175969c
--- /dev/null
+++ b/tests/test_evals_time.py
@@ -0,0 +1,127 @@
+"""
+Module for evaluating timing consistency in MNIST workload training.
+
+This script runs timing consistency tests for the PyTorch and JAX
+implementations of the MNIST training workload. It ensures that the total
+reported training time matches the sum of submission, evaluation, and
+logging times.
+"""
+
+import copy
+import os
+import sys
+from collections import namedtuple
+
+import jax.random as jax_rng
+from absl import flags
+from absl.testing import absltest
+from absl.testing import parameterized
+
+import reference_algorithms.development_algorithms.mnist.mnist_jax.submission as submission_jax
+import reference_algorithms.development_algorithms.mnist.mnist_pytorch.submission as submission_pytorch
+import submission_runner
+from algoperf import random_utils as prng
+from algoperf.profiler import PassThroughProfiler
+from algoperf.workloads import workloads
+
+FLAGS = flags.FLAGS
+FLAGS(sys.argv)
+
+
+class Hyperparameters:
+  """Defines hyperparameters for training."""
+
+  def __init__(self):
+    self.learning_rate = 0.0005
+    self.one_minus_beta_1 = 0.05
+    self.beta2 = 0.999
+    self.weight_decay = 0.01
+    self.epsilon = 1e-25
+    self.label_smoothing = 0.1
+    self.dropout_rate = 0.1
+
+
+class CheckTime(parameterized.TestCase):
+  """Verifies timing consistency in MNIST workload training.
+
+  Ensures that submission time, evaluation time, and logging time sum to
+  approximately the total wall-clock time.
+  """
+  rng_seed = 0
+
+  @parameterized.named_parameters(
+      dict(
+          testcase_name='mnist_pytorch',
+          framework='pytorch',
+          init_optimizer_state=submission_pytorch.init_optimizer_state,
+          update_params=submission_pytorch.update_params,
+          data_selection=submission_pytorch.data_selection,
+          rng=prng.PRNGKey(rng_seed)),
+      dict(
+          testcase_name='mnist_jax',
+          framework='jax',
+          init_optimizer_state=submission_jax.init_optimizer_state,
+          update_params=submission_jax.update_params,
+          data_selection=submission_jax.data_selection,
+          rng=jax_rng.PRNGKey(rng_seed)))
+  def test_train_once_time_consistency(self, framework, init_optimizer_state,
+                                       update_params, data_selection, rng):
+    """Tests the consistency of timing metrics in the training process.
+
+    Ensures that:
+    - The total logged time is approximately the sum of submission,
+      evaluation, and logging times.
+    - The expected number of evaluations occurred within the training period.
+    """
+    workload_metadata = copy.deepcopy(workloads.WORKLOADS['mnist'])
+    workload_metadata['workload_path'] = os.path.join(
+        workloads.BASE_WORKLOADS_DIR,
+        workload_metadata['workload_path'] + '_' + framework,
+        'workload.py')
+    workload = workloads.import_workload(
+        workload_path=workload_metadata['workload_path'],
+        workload_class_name=workload_metadata['workload_class_name'],
+        workload_init_kwargs={})
+
+    Hp = namedtuple('Hp', [
+        'dropout_rate', 'learning_rate', 'one_minus_beta_1', 'weight_decay',
+        'beta2', 'warmup_factor', 'epsilon'
+    ])
+    hp1 = Hp(0.1, 0.0017486387539278373, 0.06733926164, 0.9955159689799007,
+             0.08121616522670176, 0.02, 1e-25)
+
+    accumulated_submission_time, metrics = submission_runner.train_once(
+        workload=workload,
+        workload_name='mnist',
+        global_batch_size=32,
+        global_eval_batch_size=256,
+        data_dir='~/tensorflow_datasets',  # Dataset location.
+        imagenet_v2_data_dir=None,
+        hyperparameters=hp1,
+        init_optimizer_state=init_optimizer_state,
+        update_params=update_params,
+        data_selection=data_selection,
+        rng=rng,
+        rng_seed=0,
+        profiler=PassThroughProfiler(),
+        max_global_steps=500,
+        prepare_for_eval=None)
+
+    # Discrepancy between the reported total duration and the sum of
+    # submission, logging, and eval time.
+    total_logged_time = (
+        metrics['eval_results'][-1][1]['total_duration'] -
+        (accumulated_submission_time +
+         metrics['eval_results'][-1][1]['accumulated_logging_time'] +
+         metrics['eval_results'][-1][1]['accumulated_eval_time']))
+
+    # Tolerance (in seconds) for timing overhead not captured by the
+    # individual counters.
+    tolerance = 10
+    self.assertAlmostEqual(
+        total_logged_time,
+        0,
+        delta=tolerance,
+        msg='Total wall-clock time does not match the sum of submission, '
+        'eval, and logging times.')
+
+    # Verify that roughly the expected number of evaluations happened
+    # (allowing a slack of two evals).
+    expected_evals = int(accumulated_submission_time //
+                         workload.eval_period_time_sec)
+    self.assertLessEqual(
+        expected_evals,
+        len(metrics['eval_results']) + 2,
+        f'Expected about {expected_evals} evaluations, but only '
+        f"{len(metrics['eval_results'])} were logged.")
+
+
+if __name__ == '__main__':
+  absltest.main()