From c8cc1ece299b5a3d5ed168fd19547718ebe07f6d Mon Sep 17 00:00:00 2001
From: Harneet
Date: Fri, 14 Mar 2025 17:18:07 -0600
Subject: [PATCH 1/2] ..

---
 tests/test_evals_time.py | 127 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 127 insertions(+)
 create mode 100644 tests/test_evals_time.py

diff --git a/tests/test_evals_time.py b/tests/test_evals_time.py
new file mode 100644
index 000000000..cc478ab5e
--- /dev/null
+++ b/tests/test_evals_time.py
@@ -0,0 +1,127 @@
+import os
+import sys
+import copy
+from absl import flags
+from absl.testing import absltest
+from absl.testing import parameterized
+from absl import logging
+from collections import namedtuple
+import json
+import jax
+from algoperf import halton
+from algoperf import random_utils as prng
+from algoperf.profiler import PassThroughProfiler
+from algoperf.workloads import workloads
+import submission_runner
+import reference_algorithms.development_algorithms.mnist.mnist_pytorch.submission as submission_pytorch
+import reference_algorithms.development_algorithms.mnist.mnist_jax.submission as submission_jax
+import jax.random as jax_rng
+# try:
+#   import jax.random as jax_rng
+# except (ImportError, ModuleNotFoundError):
+#   logging.warning(
+#       'Could not import jax.random for the submission runner, falling back to '
+#       'numpy random_utils.')
+#   jax_rng = None
+
+FLAGS = flags.FLAGS
+FLAGS(sys.argv)
+
+class Hyperparameters:
+  def __init__(self):
+    self.learning_rate = 0.0005
+    self.one_minus_beta_1 = 0.05
+    self.beta2 = 0.999
+    self.weight_decay = 0.01
+    self.epsilon = 1e-25
+    self.label_smoothing = 0.1
+    self.dropout_rate = 0.1
+
+class CheckTime(parameterized.TestCase):
+  """Tests to check if submission_time + eval_time + logging_time ~ total _wallclock_time """
+  rng_seed = 0
+
+  @parameterized.named_parameters(
+      *[ dict(
+          testcase_name = 'mnist_pytorch',
+          framework = 'pytorch',
+          init_optimizer_state=submission_pytorch.init_optimizer_state,
+          update_params=submission_pytorch.update_params,
+          data_selection=submission_pytorch.data_selection,
+          rng = prng.PRNGKey(rng_seed))],
+
+      *[
+          dict(
+              testcase_name = 'mnist_jax',
+              framework = 'jax',
+              init_optimizer_state=submission_jax.init_optimizer_state,
+              update_params=submission_jax.update_params,
+              data_selection=submission_jax.data_selection,
+              #rng = jax.random.PRNGKey(rng_seed),),
+              rng = prng.PRNGKey(rng_seed),),
+      ]
+  )
+  def test_train_once_time_consistency(self, framework, init_optimizer_state, update_params, data_selection, rng):
+    """Test to check the consistency of timing metrics."""
+    rng_seed = 0
+    #rng = jax.random.PRNGKey(rng_seed)
+    #rng, _ = prng.split(rng, 2)
+    workload_metadata = copy.deepcopy(workloads.WORKLOADS["mnist"])
+    workload_metadata['workload_path'] = os.path.join(
+        workloads.BASE_WORKLOADS_DIR,
+        workload_metadata['workload_path'] + '_' + framework,
+        'workload.py')
+    workload = workloads.import_workload(
+        workload_path=workload_metadata['workload_path'],
+        workload_class_name=workload_metadata['workload_class_name'],
+        workload_init_kwargs={})
+
+    Hp = namedtuple("Hp",["dropout_rate", "learning_rate", "one_minus_beta_1", "weight_decay", "beta2", "warmup_factor", "epsilon" ])
+    hp1 = Hp(0.1,0.0017486387539278373,0.06733926164,0.9955159689799007,0.08121616522670176, 0.02, 1e-25)
+    # HPARAMS = {
+    #   "dropout_rate": 0.1,
+    #   "learning_rate": 0.0017486387539278373,
+    #   "one_minus_beta_1": 0.06733926164,
+    #   "beta2": 0.9955159689799007,
+    #   "weight_decay": 0.08121616522670176,
+    #   "warmup_factor": 0.02,
+    #   "epsilon" : 1e-25
+    # }
+
+
+    accumulated_submission_time, metrics = submission_runner.train_once(
+        workload = workload,
+        workload_name="mnist",
+        global_batch_size = 32,
+        global_eval_batch_size = 256,
+        data_dir = '~/tensorflow_datasets', # not sure
+        imagenet_v2_data_dir = None,
+        hyperparameters= hp1,
+        init_optimizer_state = init_optimizer_state,
+        update_params = update_params,
+        data_selection = data_selection,
+        rng = rng,
+        rng_seed = 0,
+        profiler= PassThroughProfiler(),
+        max_global_steps=500,
+        prepare_for_eval = None)
+
+
+    # Example: Check if total time roughly equals to submission_time + eval_time + logging_time
+    total_logged_time = (metrics['eval_results'][-1][1]['total_duration']
+                         - (accumulated_submission_time +
+                            metrics['eval_results'][-1][1]['accumulated_logging_time'] +
+                            metrics['eval_results'][-1][1]['accumulated_eval_time']))
+
+    # Use a tolerance for floating-point arithmetic
+    tolerance = 10
+    self.assertAlmostEqual(total_logged_time, 0, delta=tolerance,
+                           msg="Total wallclock time does not match the sum of submission, eval, and logging times.")
+
+    # Check if the expected number of evaluations occurred
+    expected_evals = int(accumulated_submission_time // workload.eval_period_time_sec)
+    self.assertTrue(expected_evals <= len(metrics['eval_results']) + 2,
+                    f"Number of evaluations {len(metrics['eval_results'])} exceeded the expected number {expected_evals + 2}.")
+
+if __name__ == '__main__':
+  absltest.main()

From 34e62c13908ec3dff89875285a71d702584c92a8 Mon Sep 17 00:00:00 2001
From: Harneet
Date: Fri, 14 Mar 2025 17:49:28 -0600
Subject: [PATCH 2/2] Added the test_evals_time which checks the time consistency and changed the int to uint

---
 algoperf/random_utils.py |  14 ++--
 tests/test_evals_time.py | 136 +++++++++++++++++++--------------------
 2 files changed, 75 insertions(+), 75 deletions(-)

diff --git a/algoperf/random_utils.py b/algoperf/random_utils.py
index a579976ad..d259f7849 100644
--- a/algoperf/random_utils.py
+++ b/algoperf/random_utils.py
@@ -18,30 +18,30 @@
 # Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 31 - 1] (an
 # unsigned int), while RandomState.randint only accepts and returns signed ints.
-MAX_INT32 = 2**31 - 1
-MIN_INT32 = 0
+MAX_UINT32 = 2**31 - 1
+MIN_UINT32 = 0

 SeedType = Union[int, list, np.ndarray]


 def _signed_to_unsigned(seed: SeedType) -> SeedType:
   if isinstance(seed, int):
-    return seed % MAX_INT32
+    return seed % MAX_UINT32
   if isinstance(seed, list):
-    return [s % MAX_INT32 for s in seed]
+    return [s % MAX_UINT32 for s in seed]
   if isinstance(seed, np.ndarray):
-    return np.array([s % MAX_INT32 for s in seed.tolist()])
+    return np.array([s % MAX_UINT32 for s in seed.tolist()])


 def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]:
   rng = np.random.RandomState(seed=_signed_to_unsigned(seed))
-  new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32)
+  new_seed = rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32)
   return [new_seed, data]


 def _split(seed: SeedType, num: int = 2) -> SeedType:
   rng = np.random.RandomState(seed=_signed_to_unsigned(seed))
-  return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num, 2])
+  return rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32, size=[num, 2])


 def _PRNGKey(seed: SeedType) -> SeedType:  # pylint: disable=invalid-name
diff --git a/tests/test_evals_time.py b/tests/test_evals_time.py
index cc478ab5e..1d175969c 100644
--- a/tests/test_evals_time.py
+++ b/tests/test_evals_time.py
@@ -1,3 +1,10 @@
+"""
+Module for evaluating timing consistency in MNIST workload training.
+
+This script runs timing consistency tests for PyTorch and JAX implementations of an MNIST training workload.
+It ensures that the total reported training time aligns with the sum of submission, evaluation, and logging times.
+"""
+
 import os
 import sys
 import copy
@@ -6,8 +13,6 @@
 from absl.testing import parameterized
 from absl import logging
 from collections import namedtuple
-import json
-import jax
 from algoperf import halton
 from algoperf import random_utils as prng
 from algoperf.profiler import PassThroughProfiler
@@ -16,18 +21,14 @@
 import reference_algorithms.development_algorithms.mnist.mnist_pytorch.submission as submission_pytorch
 import reference_algorithms.development_algorithms.mnist.mnist_jax.submission as submission_jax
 import jax.random as jax_rng
-# try:
-#   import jax.random as jax_rng
-# except (ImportError, ModuleNotFoundError):
-#   logging.warning(
-#       'Could not import jax.random for the submission runner, falling back to '
-#       'numpy random_utils.')
-#   jax_rng = None

 FLAGS = flags.FLAGS
 FLAGS(sys.argv)

 class Hyperparameters:
+  """
+  Defines hyperparameters for training.
+  """
   def __init__(self):
     self.learning_rate = 0.0005
     self.one_minus_beta_1 = 0.05
@@ -38,87 +39,86 @@ def __init__(self):
     self.dropout_rate = 0.1

 class CheckTime(parameterized.TestCase):
-  """Tests to check if submission_time + eval_time + logging_time ~ total _wallclock_time """
+  """
+  Test class to verify timing consistency in MNIST workload training.
+
+  Ensures that submission time, evaluation time, and logging time sum up to approximately the total wall-clock time.
+  """
   rng_seed = 0

   @parameterized.named_parameters(
-      *[ dict(
-          testcase_name = 'mnist_pytorch',
-          framework = 'pytorch',
-          init_optimizer_state=submission_pytorch.init_optimizer_state,
-          update_params=submission_pytorch.update_params,
-          data_selection=submission_pytorch.data_selection,
-          rng = prng.PRNGKey(rng_seed))],
-
-      *[
-          dict(
-              testcase_name = 'mnist_jax',
-              framework = 'jax',
+      dict(
+          testcase_name='mnist_pytorch',
+          framework='pytorch',
+          init_optimizer_state=submission_pytorch.init_optimizer_state,
+          update_params=submission_pytorch.update_params,
+          data_selection=submission_pytorch.data_selection,
+          rng=prng.PRNGKey(rng_seed)
+      ),
+      dict(
+          testcase_name='mnist_jax',
+          framework='jax',
               init_optimizer_state=submission_jax.init_optimizer_state,
               update_params=submission_jax.update_params,
               data_selection=submission_jax.data_selection,
-              #rng = jax.random.PRNGKey(rng_seed),),
-              rng = prng.PRNGKey(rng_seed),),
-      ]
+          rng=jax_rng.PRNGKey(rng_seed)
+      )
   )
   def test_train_once_time_consistency(self, framework, init_optimizer_state, update_params, data_selection, rng):
-    """Test to check the consistency of timing metrics."""
-    rng_seed = 0
-    #rng = jax.random.PRNGKey(rng_seed)
-    #rng, _ = prng.split(rng, 2)
+    """
+    Tests the consistency of timing metrics in the training process.
+
+    Ensures that:
+    - The total logged time is approximately the sum of submission, evaluation, and logging times.
+    - The expected number of evaluations occurred within the training period.
+    """
     workload_metadata = copy.deepcopy(workloads.WORKLOADS["mnist"])
     workload_metadata['workload_path'] = os.path.join(
-        workloads.BASE_WORKLOADS_DIR,
-        workload_metadata['workload_path'] + '_' + framework,
-        'workload.py')
+        workloads.BASE_WORKLOADS_DIR,
+        workload_metadata['workload_path'] + '_' + framework,
+        'workload.py'
+    )
     workload = workloads.import_workload(
         workload_path=workload_metadata['workload_path'],
         workload_class_name=workload_metadata['workload_class_name'],
-        workload_init_kwargs={})
+        workload_init_kwargs={}
+    )

-    Hp = namedtuple("Hp",["dropout_rate", "learning_rate", "one_minus_beta_1", "weight_decay", "beta2", "warmup_factor", "epsilon" ])
-    hp1 = Hp(0.1,0.0017486387539278373,0.06733926164,0.9955159689799007,0.08121616522670176, 0.02, 1e-25)
-    # HPARAMS = {
-    #   "dropout_rate": 0.1,
-    #   "learning_rate": 0.0017486387539278373,
-    #   "one_minus_beta_1": 0.06733926164,
-    #   "beta2": 0.9955159689799007,
-    #   "weight_decay": 0.08121616522670176,
-    #   "warmup_factor": 0.02,
-    #   "epsilon" : 1e-25
-    # }
+    Hp = namedtuple("Hp", ["dropout_rate", "learning_rate", "one_minus_beta_1", "weight_decay", "beta2", "warmup_factor", "epsilon"])
+    hp1 = Hp(0.1, 0.0017486387539278373, 0.06733926164, 0.9955159689799007, 0.08121616522670176, 0.02, 1e-25)

-
     accumulated_submission_time, metrics = submission_runner.train_once(
-        workload = workload,
+        workload=workload,
         workload_name="mnist",
-        global_batch_size = 32,
-        global_eval_batch_size = 256,
-        data_dir = '~/tensorflow_datasets', # not sure
-        imagenet_v2_data_dir = None,
-        hyperparameters= hp1,
-        init_optimizer_state = init_optimizer_state,
-        update_params = update_params,
-        data_selection = data_selection,
-        rng = rng,
-        rng_seed = 0,
-        profiler= PassThroughProfiler(),
+        global_batch_size=32,
+        global_eval_batch_size=256,
+        data_dir='~/tensorflow_datasets', # Dataset location
+        imagenet_v2_data_dir=None,
+        hyperparameters=hp1,
+        init_optimizer_state=init_optimizer_state,
+        update_params=update_params,
+        data_selection=data_selection,
+        rng=rng,
+        rng_seed=0,
+        profiler=PassThroughProfiler(),
         max_global_steps=500,
-        prepare_for_eval = None)
-
-
-    # Example: Check if total time roughly equals to submission_time + eval_time + logging_time
-    total_logged_time = (metrics['eval_results'][-1][1]['total_duration']
-                         - (accumulated_submission_time +
-                            metrics['eval_results'][-1][1]['accumulated_logging_time'] +
-                            metrics['eval_results'][-1][1]['accumulated_eval_time']))
+        prepare_for_eval=None
+    )
+
+    # Calculate total logged time
+    total_logged_time = (
+        metrics['eval_results'][-1][1]['total_duration']
+        - (accumulated_submission_time +
+           metrics['eval_results'][-1][1]['accumulated_logging_time'] +
+           metrics['eval_results'][-1][1]['accumulated_eval_time'])
+    )

-    # Use a tolerance for floating-point arithmetic
+    # Set tolerance for floating-point precision errors
     tolerance = 10
-    self.assertAlmostEqual(total_logged_time, 0, delta=tolerance,
+    self.assertAlmostEqual(total_logged_time, 0, delta=tolerance,
                            msg="Total wallclock time does not match the sum of submission, eval, and logging times.")

-    # Check if the expected number of evaluations occurred
+    # Verify expected number of evaluations
     expected_evals = int(accumulated_submission_time // workload.eval_period_time_sec)
     self.assertTrue(expected_evals <= len(metrics['eval_results']) + 2,
                     f"Number of evaluations {len(metrics['eval_results'])} exceeded the expected number {expected_evals + 2}.")
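
Note on the check this series introduces: test_train_once_time_consistency asserts that the total_duration recorded with the last eval result stays within 10 seconds of accumulated_submission_time plus the accumulated logging and eval times, and that the number of recorded evals falls short of accumulated_submission_time // eval_period_time_sec by at most 2. A minimal standalone sketch of that logic is below; the helper name check_timing_consistency is hypothetical, and accumulated_submission_time / metrics stand in for the values returned by submission_runner.train_once in the patch above.

    # Illustrative sketch only -- mirrors the assertions in the test above; not part of the patch.
    def check_timing_consistency(accumulated_submission_time, metrics, eval_period_time_sec, tolerance=10):
        last_eval = metrics['eval_results'][-1][1]
        # total_duration should roughly equal submission + logging + eval time.
        residual = last_eval['total_duration'] - (
            accumulated_submission_time
            + last_eval['accumulated_logging_time']
            + last_eval['accumulated_eval_time'])
        assert abs(residual) <= tolerance, 'wallclock time != submission + eval + logging time'
        # The recorded eval count should not fall short of the expected count by more than 2.
        expected_evals = int(accumulated_submission_time // eval_period_time_sec)
        assert expected_evals <= len(metrics['eval_results']) + 2, 'fewer evals recorded than expected'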
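Separately, the algoperf/random_utils.py hunk renames MAX_INT32/MIN_INT32 to MAX_UINT32/MIN_UINT32 and draws new seeds with dtype=np.uint32, which keeps every derived seed non-negative and inside the [0, 2**31 - 1] range that np.random.RandomState accepts. The snippet below is an illustrative sketch of that pattern using plain NumPy; only the constant value and the RandomState/randint calls come from the patch, the variable names are made up.

    import numpy as np

    MAX_UINT32 = 2**31 - 1  # RandomState(seed) only accepts seeds in [0, 2**31 - 1]

    seed = -12345 % MAX_UINT32  # fold an arbitrary (possibly signed) seed into range
    rng = np.random.RandomState(seed=seed)
    # As in the patched _fold_in and _split, new seeds are drawn as unsigned 32-bit ints.
    new_seed = rng.randint(0, MAX_UINT32, dtype=np.uint32)
    sub_seeds = rng.randint(0, MAX_UINT32, dtype=np.uint32, size=[2, 2])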