diff --git a/metaseq/data/iterators.py b/metaseq/data/iterators.py index 065c77339..9679a137b 100644 --- a/metaseq/data/iterators.py +++ b/metaseq/data/iterators.py @@ -358,23 +358,24 @@ def load_state_dict(self, state_dict): self._itr = self._get_iterator_for_epoch(self.epoch) self._itr.n = n - if True: - # Epilogue bug fixup - # Warning: this fix is not correct for the last ~1% of an epoch, but it only needs to be - # applied once earlier in the epoch to fix any data loaders with incorrect data. - num_workers = self._itr.num_workers - batch_size = self._itr.batch_size * self._itr.num_shards - if sum(sequences_consumed) != n * batch_size: - logger.warning( - f"{distributed_utils.get_global_rank()}: Sequences appear corrupted: " - f"{n}*{batch_size} != sum({sequences_consumed})" - ) - each, left = divmod(n, num_workers) - sequences_consumed = [ - batch_size * (each + (1 if i < left else 0)) - for i in range(num_workers) - ] - assert sum(sequences_consumed) == n * batch_size + # See https://github.com/facebookresearch/metaseq/pull/566/files for context. + # if True: + # # Epilogue bug fixup + # # Warning: this fix is not correct for the last ~1% of an epoch, but it only needs to be + # # applied once earlier in the epoch to fix any data loaders with incorrect data. + # num_workers = self._itr.num_workers + # batch_size = self._itr.batch_size * self._itr.num_shards + # if sum(sequences_consumed) != n * batch_size: + # logger.warning( + # f"{distributed_utils.get_global_rank()}: Sequences appear corrupted: " + # f"{n}*{batch_size} != sum({sequences_consumed})" + # ) + # each, left = divmod(n, num_workers) + # sequences_consumed = [ + # batch_size * (each + (1 if i < left else 0)) + # for i in range(num_workers) + # ] + # assert sum(sequences_consumed) == n * batch_size self._itr.sequences_consumed = sequences_consumed self._itr.next_worker = next_worker