From fb6291c1a9491b7334986b66f1e81bd2519c5287 Mon Sep 17 00:00:00 2001 From: Tobias Pitters Date: Fri, 17 Oct 2025 08:42:39 +0200 Subject: [PATCH 1/9] WIP: shuffle working, interleave_ds not yet --- src/datasets/iterable_dataset.py | 41 +++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index 7d16baa7d0d..bc8ed064bae 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -190,8 +190,17 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "_BaseExamples """ raise NotImplementedError(f"{type(self)} doesn't implement shuffle_data_sources yet") + def shift_rngs(self, value: int) -> None: + print("[shift_rngs]:", value) + if hasattr(self, 'generator'): + print("[shift_rngs]: has generator") + new_seed = self.generator.bit_generator.state['state']['state'] + value + self.generator = np.random.default_rng(seed=new_seed) + + def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "_BaseExamplesIterable": """Either keep only the requested shard, or propagate the request to the underlying iterable.""" + print(f"[_BaseExamplesIterable.shard_data_sources], num_shards: {num_shards}, index: {index}, contiguous: {contiguous}, has_generator: {hasattr(self, 'generator')}") raise NotImplementedError(f"{type(self)} doesn't implement shard_data_sources yet") def split_shard_indices_by_worker(self, num_shards: int, index: int, contiguous=True) -> list[int]: @@ -258,6 +267,7 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "ExamplesItera def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "ExamplesIterable": """Keep only the requested shard.""" + print(f"[ExamplesIterable.shard_data_sources], num_shards: {num_shards}, index: {index}, contiguous: {contiguous}, has_generator: {hasattr(self, 'generator')}") gen_kwargs_list = _split_gen_kwargs(self.kwargs, max_num_jobs=self.num_shards) shard_indices = self.split_shard_indices_by_worker(num_shards, index, contiguous=contiguous) requested_gen_kwargs = _merge_gen_kwargs([gen_kwargs_list[i] for i in shard_indices]) @@ -298,6 +308,7 @@ def __iter__(self): def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "ExamplesIterable": """Keep only the requested shard.""" + print(f"[ShuffledDataSourcesExamplesIterable.shard_data_sources], num_shards: {num_shards}, index: {index}, contiguous: {contiguous}, has_generator: {hasattr(self, 'generator')}") rng = deepcopy(self.generator) kwargs_with_shuffled_shards = _shuffle_gen_kwargs(rng, self.kwargs) return ExamplesIterable(self.generate_examples_fn, kwargs_with_shuffled_shards).shard_data_sources( @@ -362,6 +373,7 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "ArrowExamples def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "ArrowExamplesIterable": """Keep only the requested shard.""" + print(f"[ArrowExamplesIterable.shard_data_sources], num_shards: {num_shards}, index: {index}, contiguous: {contiguous}, has_generator: {hasattr(self, 'generator')}") gen_kwargs_list = _split_gen_kwargs(self.kwargs, max_num_jobs=self.num_shards) shard_indices = self.split_shard_indices_by_worker(num_shards, index, contiguous=contiguous) requested_gen_kwargs = _merge_gen_kwargs([gen_kwargs_list[i] for i in shard_indices]) @@ -435,6 +447,7 @@ def _iter_arrow(self): def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> 
"ArrowExamplesIterable": """Keep only the requested shard.""" + print(f"[ShuffledDataSourcesArrowExamplesIterable.shard_data_sources], num_shards: {num_shards}, index: {index}, contiguous: {contiguous}, has_generator: {hasattr(self, 'generator')}") rng = deepcopy(self.generator) kwargs_with_shuffled_shards = _shuffle_gen_kwargs(rng, self.kwargs) return ArrowExamplesIterable(self.generate_tables_fn, kwargs_with_shuffled_shards).shard_data_sources( @@ -567,6 +580,7 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "RebatchedArro ) def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "RebatchedArrowExamplesIterable": + print(f"[RebatchedArrowExamplesIterable.shard_data_sources], num_shards: {num_shards}, index: {index}, contiguous: {contiguous}, has_generator: {hasattr(self, 'generator')}") return RebatchedArrowExamplesIterable( self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), self.batch_size, @@ -614,6 +628,7 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "SelectColumns return SelectColumnsIterable(self.ex_iterable.shuffle_data_sources(generator), self.column_names) def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "SelectColumnsIterable": + print(f"[SelectColumnsIterable.shard_data_sources], num_shards: {num_shards}, index: {index}, contiguous: {contiguous}, has_generator: {hasattr(self, 'generator')}") return SelectColumnsIterable( self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), self.column_names ) @@ -658,6 +673,7 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "StepExamplesI ) def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "StepExamplesIterable": + print(f"[StepExamplesIterable.shard_data_sources], num_shards: {num_shards}, index: {index}, contiguous: {contiguous}, has_generator: {hasattr(self, 'generator')}") return StepExamplesIterable( self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), step=self.step, @@ -823,6 +839,7 @@ def shard_data_sources( self, num_shards: int, index: int, contiguous=True ) -> "CyclingMultiSourcesExamplesIterable": """Either keep only the requested shard, or propagate the request to the underlying iterable.""" + print(f"[CyclingMultiSourcesExamplesIterable.shard_data_sources], num_shards: {num_shards}, index: {index}, contiguous: {contiguous}, has_generator: {hasattr(self, 'generator')}") if num_shards < self.num_shards: return CyclingMultiSourcesExamplesIterable( [ @@ -916,6 +933,7 @@ def shard_data_sources( self, num_shards: int, index: int, contiguous=True ) -> "VerticallyConcatenatedMultiSourcesExamplesIterable": """Either keep only the requested shard, or propagate the request to the underlying iterable.""" + print(f"[VerticallyConcatenatedMultiSourcesExamplesIterable.shard_data_sources], num_shards: {num_shards}, index: {index}, contiguous: {contiguous}, has_generator: {hasattr(self, 'generator')}") return VerticallyConcatenatedMultiSourcesExamplesIterable( [iterable.shard_data_sources(num_shards, index, contiguous=contiguous) for iterable in self.ex_iterables] ) @@ -1004,6 +1022,7 @@ def shard_data_sources( self, num_shards: int, index: int, contiguous=True ) -> "HorizontallyConcatenatedMultiSourcesExamplesIterable": """Either keep only the requested shard, or propagate the request to the underlying iterable.""" + print(f"[HorizontallyConcatenatedMultiSourcesExamplesIterable.shard_data_sources], num_shards: {num_shards}, 
index: {index}, contiguous: {contiguous}, has_generator: {hasattr(self, 'generator')}") return HorizontallyConcatenatedMultiSourcesExamplesIterable( [iterable.shard_data_sources(num_shards, index, contiguous=contiguous) for iterable in self.ex_iterables] ) @@ -1085,6 +1104,7 @@ def shard_data_sources( self, num_shards: int, index: int, contiguous=True ) -> "RandomlyCyclingMultiSourcesExamplesIterable": """Either keep only the requested shard, or propagate the request to the underlying iterable.""" + print(f"[RandomlyCyclingMultiSourcesExamplesIterable.shard_data_sources], num_shards: {num_shards}, index: {index}, contiguous: {contiguous}, has_generator: {hasattr(self, 'generator')}") if num_shards < self.num_shards: return RandomlyCyclingMultiSourcesExamplesIterable( [ @@ -1491,6 +1511,7 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "MappedExample def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "MappedExamplesIterable": """Keep only the requested shard.""" + print(f"[MappedExamplesIterable.shard_data_sources], num_shards: {num_shards}, index: {index}, contiguous: {contiguous}, has_generator: {hasattr(self, 'generator')}") return MappedExamplesIterable( self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), function=self.function, @@ -1597,6 +1618,7 @@ def shuffle_data_sources(self, seed: Optional[int]) -> "FilteredExamplesIterable def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "FilteredExamplesIterable": """Keep only the requested shard.""" + print(f"[FilteredExamplesIterable.shard_data_sources], num_shards: {num_shards}, index: {index}, contiguous: {contiguous}, has_generator: {hasattr(self, 'generator')}") return FilteredExamplesIterable( self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), function=self.mask_function, @@ -1615,6 +1637,8 @@ def num_shards(self) -> int: class BufferShuffledExamplesIterable(_BaseExamplesIterable): def __init__(self, ex_iterable: _BaseExamplesIterable, buffer_size: int, generator: np.random.Generator): + # import traceback as tb + # tb.print_stack() super().__init__() self.ex_iterable = ex_iterable self.buffer_size = buffer_size @@ -1654,6 +1678,7 @@ def _iter_random_indices(rng: np.random.Generator, buffer_size: int, random_batc def __iter__(self): buffer_size = self.buffer_size rng = deepcopy(self.generator) + print(f"This is rng {rng.bit_generator.state['state']['state']}") indices_iterator = self._iter_random_indices(rng, buffer_size) # this is the shuffle buffer that we keep in memory mem_buffer = [] @@ -1693,10 +1718,14 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "BufferShuffle def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "BufferShuffledExamplesIterable": """Keep only the requested shard.""" + print(f"[BufferShuffledExamplesIterable.shard_data_sources], num_shards: {num_shards}, index: {index}, contiguous: {contiguous}, has_generator: {hasattr(self, 'generator')}") + # set the seed independently for each BufferShuffledExamplesIterable based on index (which is worker_id) + new_seed = self.generator.bit_generator.state['state']['state'] + index + rng = np.random.default_rng(seed=new_seed) return BufferShuffledExamplesIterable( self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), buffer_size=self.buffer_size, - generator=self.generator, + generator=rng, ) @property @@ -1764,6 +1793,7 @@ def shuffle_data_sources(self, generator: 
np.random.Generator) -> "SkipExamplesI def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "SkipExamplesIterable": """Keep only the requested shard.""" + print(f"[SkipExamplesIterable.shard_data_sources], num_shards: {num_shards}, index: {index}, contiguous: {contiguous}, has_generator: {hasattr(self, 'generator')}") if self.split_when_sharding: return SkipExamplesIterable( self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), @@ -1818,6 +1848,7 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "RepeatExample def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "RepeatExamplesIterable": """Shard, then repeat shards.""" + print(f"[RepeatExamplesIterable.shard_data_sources], num_shards: {num_shards}, index: {index}, contiguous: {contiguous}, has_generator: {hasattr(self, 'generator')}") return RepeatExamplesIterable( self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), num_times=self.num_times, @@ -1889,6 +1920,7 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "TakeExamplesI def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "TakeExamplesIterable": """Keep only the requested shard.""" + print(f"[TakeExamplesIterable.shard_data_sources], num_shards: {num_shards}, index: {index}, contiguous: {contiguous}, has_generator: {hasattr(self, 'generator')}") if self.split_when_sharding: return TakeExamplesIterable( self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), @@ -1985,6 +2017,7 @@ def _init_state_dict(self) -> dict: return self._state_dict def __iter__(self): + print("[FormattedExamplesIterable.__iter__]", type(self.ex_iterable)) if not self.formatting or self.formatting.is_table: formatter = PythonFormatter( features=self._features if not self.ex_iterable.is_typed else None, @@ -1997,6 +2030,7 @@ def __iter__(self): token_per_repo_id=self.token_per_repo_id, ) if self.ex_iterable.iter_arrow: + print("[FormattedExamplesIterable.__iter__] iter_arrow!") # feature casting (inc column addition) handled within self._iter_arrow() for key, pa_table in self._iter_arrow(): batch = formatter.format_batch(pa_table) @@ -2009,6 +2043,7 @@ def __iter__(self): else None # cast in case features is None ) for key, example in self.ex_iterable: + print("[FormattedExamplesIterable] in self.ex_iterable") # don't apply feature types if already applied by ex_iterable (e.g. 
in case of chained with_format) if self.features and not self.ex_iterable.is_typed: example = _apply_feature_types_on_example( @@ -2044,6 +2079,7 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "FormattedExam def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "FormattedExamplesIterable": """Keep only the requested shard.""" + print(f"[FormattedExamplesIterable.shard_data_sources], num_shards: {num_shards}, index: {index}, contiguous: {contiguous}, has_generator: {hasattr(self, 'generator')}") return FormattedExamplesIterable( self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), features=self.features, @@ -2361,6 +2397,8 @@ def _iter_pytorch(self): ex_iterable = ex_iterable.shard_data_sources( num_shards=worker_info.num_workers, index=worker_info.id, contiguous=False ) + ex_iterable.shift_rngs(value=worker_info.id) + print("[DEBUG] Number of shards for this worker:", ex_iterable.num_shards, "worker_id", worker_info.id, type(ex_iterable)) self._state_dict = { "examples_iterable": ex_iterable._init_state_dict(), "epoch": self.epoch, @@ -2979,6 +3017,7 @@ def shuffle( 'text': "sam jones became a very lucky filmmaker the day wilco got dropped from their record label , proving that one man's ruin may be another's fortune ."}] ``` """ + print("[IterableDataset.shuffle]: seed", {seed}) if generator is None: generator = np.random.default_rng(seed) else: From bfc5b38b13b0f575f7075e52b13e8db6e860df55 Mon Sep 17 00:00:00 2001 From: Tobias Pitters Date: Fri, 17 Oct 2025 11:23:14 +0200 Subject: [PATCH 2/9] remove debug statements --- src/datasets/iterable_dataset.py | 39 +++----------------------------- 1 file changed, 3 insertions(+), 36 deletions(-) diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index bc8ed064bae..055599b13bc 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -191,16 +191,12 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "_BaseExamples raise NotImplementedError(f"{type(self)} doesn't implement shuffle_data_sources yet") def shift_rngs(self, value: int) -> None: - print("[shift_rngs]:", value) - if hasattr(self, 'generator'): - print("[shift_rngs]: has generator") - new_seed = self.generator.bit_generator.state['state']['state'] + value + if hasattr(self, "generator"): + new_seed = self.generator.bit_generator.state["state"]["state"] + value self.generator = np.random.default_rng(seed=new_seed) - def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "_BaseExamplesIterable": """Either keep only the requested shard, or propagate the request to the underlying iterable.""" - print(f"[_BaseExamplesIterable.shard_data_sources], num_shards: {num_shards}, index: {index}, contiguous: {contiguous}, has_generator: {hasattr(self, 'generator')}") raise NotImplementedError(f"{type(self)} doesn't implement shard_data_sources yet") def split_shard_indices_by_worker(self, num_shards: int, index: int, contiguous=True) -> list[int]: @@ -267,7 +263,6 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "ExamplesItera def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "ExamplesIterable": """Keep only the requested shard.""" - print(f"[ExamplesIterable.shard_data_sources], num_shards: {num_shards}, index: {index}, contiguous: {contiguous}, has_generator: {hasattr(self, 'generator')}") gen_kwargs_list = _split_gen_kwargs(self.kwargs, max_num_jobs=self.num_shards) shard_indices = 
self.split_shard_indices_by_worker(num_shards, index, contiguous=contiguous) requested_gen_kwargs = _merge_gen_kwargs([gen_kwargs_list[i] for i in shard_indices]) @@ -308,7 +303,6 @@ def __iter__(self): def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "ExamplesIterable": """Keep only the requested shard.""" - print(f"[ShuffledDataSourcesExamplesIterable.shard_data_sources], num_shards: {num_shards}, index: {index}, contiguous: {contiguous}, has_generator: {hasattr(self, 'generator')}") rng = deepcopy(self.generator) kwargs_with_shuffled_shards = _shuffle_gen_kwargs(rng, self.kwargs) return ExamplesIterable(self.generate_examples_fn, kwargs_with_shuffled_shards).shard_data_sources( @@ -373,7 +367,6 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "ArrowExamples def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "ArrowExamplesIterable": """Keep only the requested shard.""" - print(f"[ArrowExamplesIterable.shard_data_sources], num_shards: {num_shards}, index: {index}, contiguous: {contiguous}, has_generator: {hasattr(self, 'generator')}") gen_kwargs_list = _split_gen_kwargs(self.kwargs, max_num_jobs=self.num_shards) shard_indices = self.split_shard_indices_by_worker(num_shards, index, contiguous=contiguous) requested_gen_kwargs = _merge_gen_kwargs([gen_kwargs_list[i] for i in shard_indices]) @@ -447,7 +440,6 @@ def _iter_arrow(self): def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "ArrowExamplesIterable": """Keep only the requested shard.""" - print(f"[ShuffledDataSourcesArrowExamplesIterable.shard_data_sources], num_shards: {num_shards}, index: {index}, contiguous: {contiguous}, has_generator: {hasattr(self, 'generator')}") rng = deepcopy(self.generator) kwargs_with_shuffled_shards = _shuffle_gen_kwargs(rng, self.kwargs) return ArrowExamplesIterable(self.generate_tables_fn, kwargs_with_shuffled_shards).shard_data_sources( @@ -580,7 +572,6 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "RebatchedArro ) def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "RebatchedArrowExamplesIterable": - print(f"[RebatchedArrowExamplesIterable.shard_data_sources], num_shards: {num_shards}, index: {index}, contiguous: {contiguous}, has_generator: {hasattr(self, 'generator')}") return RebatchedArrowExamplesIterable( self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), self.batch_size, @@ -628,7 +619,6 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "SelectColumns return SelectColumnsIterable(self.ex_iterable.shuffle_data_sources(generator), self.column_names) def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "SelectColumnsIterable": - print(f"[SelectColumnsIterable.shard_data_sources], num_shards: {num_shards}, index: {index}, contiguous: {contiguous}, has_generator: {hasattr(self, 'generator')}") return SelectColumnsIterable( self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), self.column_names ) @@ -673,7 +663,6 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "StepExamplesI ) def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "StepExamplesIterable": - print(f"[StepExamplesIterable.shard_data_sources], num_shards: {num_shards}, index: {index}, contiguous: {contiguous}, has_generator: {hasattr(self, 'generator')}") return StepExamplesIterable( self.ex_iterable.shard_data_sources(num_shards, index, 
contiguous=contiguous), step=self.step, @@ -839,7 +828,6 @@ def shard_data_sources( self, num_shards: int, index: int, contiguous=True ) -> "CyclingMultiSourcesExamplesIterable": """Either keep only the requested shard, or propagate the request to the underlying iterable.""" - print(f"[CyclingMultiSourcesExamplesIterable.shard_data_sources], num_shards: {num_shards}, index: {index}, contiguous: {contiguous}, has_generator: {hasattr(self, 'generator')}") if num_shards < self.num_shards: return CyclingMultiSourcesExamplesIterable( [ @@ -933,7 +921,6 @@ def shard_data_sources( self, num_shards: int, index: int, contiguous=True ) -> "VerticallyConcatenatedMultiSourcesExamplesIterable": """Either keep only the requested shard, or propagate the request to the underlying iterable.""" - print(f"[VerticallyConcatenatedMultiSourcesExamplesIterable.shard_data_sources], num_shards: {num_shards}, index: {index}, contiguous: {contiguous}, has_generator: {hasattr(self, 'generator')}") return VerticallyConcatenatedMultiSourcesExamplesIterable( [iterable.shard_data_sources(num_shards, index, contiguous=contiguous) for iterable in self.ex_iterables] ) @@ -1022,7 +1009,6 @@ def shard_data_sources( self, num_shards: int, index: int, contiguous=True ) -> "HorizontallyConcatenatedMultiSourcesExamplesIterable": """Either keep only the requested shard, or propagate the request to the underlying iterable.""" - print(f"[HorizontallyConcatenatedMultiSourcesExamplesIterable.shard_data_sources], num_shards: {num_shards}, index: {index}, contiguous: {contiguous}, has_generator: {hasattr(self, 'generator')}") return HorizontallyConcatenatedMultiSourcesExamplesIterable( [iterable.shard_data_sources(num_shards, index, contiguous=contiguous) for iterable in self.ex_iterables] ) @@ -1104,7 +1090,6 @@ def shard_data_sources( self, num_shards: int, index: int, contiguous=True ) -> "RandomlyCyclingMultiSourcesExamplesIterable": """Either keep only the requested shard, or propagate the request to the underlying iterable.""" - print(f"[RandomlyCyclingMultiSourcesExamplesIterable.shard_data_sources], num_shards: {num_shards}, index: {index}, contiguous: {contiguous}, has_generator: {hasattr(self, 'generator')}") if num_shards < self.num_shards: return RandomlyCyclingMultiSourcesExamplesIterable( [ @@ -1511,7 +1496,6 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "MappedExample def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "MappedExamplesIterable": """Keep only the requested shard.""" - print(f"[MappedExamplesIterable.shard_data_sources], num_shards: {num_shards}, index: {index}, contiguous: {contiguous}, has_generator: {hasattr(self, 'generator')}") return MappedExamplesIterable( self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), function=self.function, @@ -1618,7 +1602,6 @@ def shuffle_data_sources(self, seed: Optional[int]) -> "FilteredExamplesIterable def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "FilteredExamplesIterable": """Keep only the requested shard.""" - print(f"[FilteredExamplesIterable.shard_data_sources], num_shards: {num_shards}, index: {index}, contiguous: {contiguous}, has_generator: {hasattr(self, 'generator')}") return FilteredExamplesIterable( self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), function=self.mask_function, @@ -1637,8 +1620,6 @@ def num_shards(self) -> int: class BufferShuffledExamplesIterable(_BaseExamplesIterable): def __init__(self, ex_iterable: 
_BaseExamplesIterable, buffer_size: int, generator: np.random.Generator): - # import traceback as tb - # tb.print_stack() super().__init__() self.ex_iterable = ex_iterable self.buffer_size = buffer_size @@ -1678,7 +1659,6 @@ def _iter_random_indices(rng: np.random.Generator, buffer_size: int, random_batc def __iter__(self): buffer_size = self.buffer_size rng = deepcopy(self.generator) - print(f"This is rng {rng.bit_generator.state['state']['state']}") indices_iterator = self._iter_random_indices(rng, buffer_size) # this is the shuffle buffer that we keep in memory mem_buffer = [] @@ -1718,14 +1698,10 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "BufferShuffle def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "BufferShuffledExamplesIterable": """Keep only the requested shard.""" - print(f"[BufferShuffledExamplesIterable.shard_data_sources], num_shards: {num_shards}, index: {index}, contiguous: {contiguous}, has_generator: {hasattr(self, 'generator')}") - # set the seed independently for each BufferShuffledExamplesIterable based on index (which is worker_id) - new_seed = self.generator.bit_generator.state['state']['state'] + index - rng = np.random.default_rng(seed=new_seed) return BufferShuffledExamplesIterable( self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), buffer_size=self.buffer_size, - generator=rng, + generator=self.generator, ) @property @@ -1793,7 +1769,6 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "SkipExamplesI def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "SkipExamplesIterable": """Keep only the requested shard.""" - print(f"[SkipExamplesIterable.shard_data_sources], num_shards: {num_shards}, index: {index}, contiguous: {contiguous}, has_generator: {hasattr(self, 'generator')}") if self.split_when_sharding: return SkipExamplesIterable( self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), @@ -1848,7 +1823,6 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "RepeatExample def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "RepeatExamplesIterable": """Shard, then repeat shards.""" - print(f"[RepeatExamplesIterable.shard_data_sources], num_shards: {num_shards}, index: {index}, contiguous: {contiguous}, has_generator: {hasattr(self, 'generator')}") return RepeatExamplesIterable( self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), num_times=self.num_times, @@ -1920,7 +1894,6 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "TakeExamplesI def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "TakeExamplesIterable": """Keep only the requested shard.""" - print(f"[TakeExamplesIterable.shard_data_sources], num_shards: {num_shards}, index: {index}, contiguous: {contiguous}, has_generator: {hasattr(self, 'generator')}") if self.split_when_sharding: return TakeExamplesIterable( self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), @@ -2017,7 +1990,6 @@ def _init_state_dict(self) -> dict: return self._state_dict def __iter__(self): - print("[FormattedExamplesIterable.__iter__]", type(self.ex_iterable)) if not self.formatting or self.formatting.is_table: formatter = PythonFormatter( features=self._features if not self.ex_iterable.is_typed else None, @@ -2030,7 +2002,6 @@ def __iter__(self): token_per_repo_id=self.token_per_repo_id, ) if self.ex_iterable.iter_arrow: - 
print("[FormattedExamplesIterable.__iter__] iter_arrow!") # feature casting (inc column addition) handled within self._iter_arrow() for key, pa_table in self._iter_arrow(): batch = formatter.format_batch(pa_table) @@ -2043,7 +2014,6 @@ def __iter__(self): else None # cast in case features is None ) for key, example in self.ex_iterable: - print("[FormattedExamplesIterable] in self.ex_iterable") # don't apply feature types if already applied by ex_iterable (e.g. in case of chained with_format) if self.features and not self.ex_iterable.is_typed: example = _apply_feature_types_on_example( @@ -2079,7 +2049,6 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "FormattedExam def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "FormattedExamplesIterable": """Keep only the requested shard.""" - print(f"[FormattedExamplesIterable.shard_data_sources], num_shards: {num_shards}, index: {index}, contiguous: {contiguous}, has_generator: {hasattr(self, 'generator')}") return FormattedExamplesIterable( self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), features=self.features, @@ -2398,7 +2367,6 @@ def _iter_pytorch(self): num_shards=worker_info.num_workers, index=worker_info.id, contiguous=False ) ex_iterable.shift_rngs(value=worker_info.id) - print("[DEBUG] Number of shards for this worker:", ex_iterable.num_shards, "worker_id", worker_info.id, type(ex_iterable)) self._state_dict = { "examples_iterable": ex_iterable._init_state_dict(), "epoch": self.epoch, @@ -3017,7 +2985,6 @@ def shuffle( 'text': "sam jones became a very lucky filmmaker the day wilco got dropped from their record label , proving that one man's ruin may be another's fortune ."}] ``` """ - print("[IterableDataset.shuffle]: seed", {seed}) if generator is None: generator = np.random.default_rng(seed) else: From 20835aed04a97127a98e82f7794dfa7e514b7859 Mon Sep 17 00:00:00 2001 From: Tobias Pitters Date: Fri, 17 Oct 2025 12:16:18 +0200 Subject: [PATCH 3/9] add test --- tests/test_iterable_dataset.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index 1bca866bdf8..5a207f1aeb7 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -1553,6 +1553,30 @@ def test_iterable_dataset_from_hub_torch_dataloader_parallel(num_workers, tmp_pa assert len(result) == 10 +@require_torch +@pytest.mark.filterwarnings("ignore:This DataLoader will create:UserWarning") +@pytest.mark.parametrize("num_workers", [4, 8]) +def test_iterable_dataset_shuffle_with_multiple_workers_different_rng(num_workers): + from itertools import groupby + + from torch.utils.data import DataLoader + + ex_iterable = ExamplesIterable( + generate_examples_fn, {"filepaths": [f"{i}.txt" for i in range(num_workers)], "n": 10} + ) + dataset = IterableDataset(ex_iterable).shuffle(buffer_size=100, seed=42) + dataloader = DataLoader(dataset, batch_size=None, num_workers=num_workers) + + result = list(dataloader) + assert len(result) == num_workers * 10 + + chunks = [list(group) for _, group in groupby(enumerate(result), key=lambda x: x[0] // num_workers)] + for chunk_idx, chunk in enumerate(chunks): + values = [ex["id"] for _, ex in chunk] + unique_values = set(values) + assert len(unique_values) > 1, f"Chunk {chunk_idx}: all workers produced same values {values}" + + @pytest.mark.parametrize("batch_size", [4, 5]) @pytest.mark.parametrize("drop_last_batch", [False, True]) def 
test_iterable_dataset_iter_batch(batch_size, drop_last_batch): From 7ad14f6ca49f7577279ac27c11e890917c4f95f7 Mon Sep 17 00:00:00 2001 From: Tobias Pitters Date: Fri, 17 Oct 2025 12:17:57 +0200 Subject: [PATCH 4/9] update test --- tests/test_iterable_dataset.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index 5a207f1aeb7..b1048d03c98 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -1554,13 +1554,12 @@ def test_iterable_dataset_from_hub_torch_dataloader_parallel(num_workers, tmp_pa @require_torch -@pytest.mark.filterwarnings("ignore:This DataLoader will create:UserWarning") -@pytest.mark.parametrize("num_workers", [4, 8]) -def test_iterable_dataset_shuffle_with_multiple_workers_different_rng(num_workers): +def test_iterable_dataset_shuffle_with_multiple_workers_different_rng(): from itertools import groupby from torch.utils.data import DataLoader + num_workers = 8 ex_iterable = ExamplesIterable( generate_examples_fn, {"filepaths": [f"{i}.txt" for i in range(num_workers)], "n": 10} ) From 6a582cb3e432b9b8ad5fc5270b6fa80d501bf95e Mon Sep 17 00:00:00 2001 From: Tobias Pitters Date: Fri, 17 Oct 2025 16:00:32 +0200 Subject: [PATCH 5/9] use recursive overwriting of generator seeds --- src/datasets/iterable_dataset.py | 25 ++++++++++++++-- tests/test_iterable_dataset.py | 51 +++++++++++++++++++++++--------- 2 files changed, 59 insertions(+), 17 deletions(-) diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index 055599b13bc..279dee4b735 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -66,6 +66,19 @@ from .utils.typing import PathLike +def show_chain_of_iterables(ex_iterable: "_BaseExamplesIterable") -> str: + chain_of_iters = [ex_iterable.__class__.__name__] + + def recurse(obj): + new_obj = getattr(obj, "ex_iterable", None) + if new_obj is not None: + chain_of_iters.append(new_obj.__class__.__name__) + recurse(new_obj) + + recurse(ex_iterable) + print(f"[show_chain_of_iterables]: {','.join(chain_of_iters)}") + + if TYPE_CHECKING: import sqlite3 @@ -191,9 +204,14 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "_BaseExamples raise NotImplementedError(f"{type(self)} doesn't implement shuffle_data_sources yet") def shift_rngs(self, value: int) -> None: - if hasattr(self, "generator"): - new_seed = self.generator.bit_generator.state["state"]["state"] + value - self.generator = np.random.default_rng(seed=new_seed) + def set_seed_recursively(ex_iterable): + if hasattr(ex_iterable, "generator"): + new_seed = ex_iterable.generator.bit_generator.state["state"]["state"] + value + ex_iterable.generator = np.random.default_rng(seed=new_seed) + if hasattr(ex_iterable, "ex_iterable"): + set_seed_recursively(ex_iterable.ex_iterable) + + set_seed_recursively(self) def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "_BaseExamplesIterable": """Either keep only the requested shard, or propagate the request to the underlying iterable.""" @@ -2366,6 +2384,7 @@ def _iter_pytorch(self): ex_iterable = ex_iterable.shard_data_sources( num_shards=worker_info.num_workers, index=worker_info.id, contiguous=False ) + show_chain_of_iterables(ex_iterable) ex_iterable.shift_rngs(value=worker_info.id) self._state_dict = { "examples_iterable": ex_iterable._init_state_dict(), diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index b1048d03c98..af3681f56f7 100644 
--- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -1555,25 +1555,48 @@ def test_iterable_dataset_from_hub_torch_dataloader_parallel(num_workers, tmp_pa @require_torch def test_iterable_dataset_shuffle_with_multiple_workers_different_rng(): - from itertools import groupby + # GH 7567 + from torch.utils.data import DataLoader, get_worker_info - from torch.utils.data import DataLoader + def gen(shard): + worker_info = get_worker_info() + for i in range(100): + yield {"value": i, "worker_id": worker_info.id} - num_workers = 8 - ex_iterable = ExamplesIterable( - generate_examples_fn, {"filepaths": [f"{i}.txt" for i in range(num_workers)], "n": 10} - ) - dataset = IterableDataset(ex_iterable).shuffle(buffer_size=100, seed=42) - dataloader = DataLoader(dataset, batch_size=None, num_workers=num_workers) + num_workers = 20 + ds = IterableDataset.from_generator(gen, gen_kwargs={"shard": list(range(num_workers))}) + ds = ds.shuffle(buffer_size=100, seed=1234) + dataloader = DataLoader(ds, batch_size=None, num_workers=num_workers) result = list(dataloader) - assert len(result) == num_workers * 10 + for single_chunk in [result[x : x + num_workers] for x in range(0, len(result), num_workers)]: + values = [item["value"] for item in single_chunk] + # This will fail with the chance 1/100 ** 20! + assert len(set(values)) != 1, "Make sure not all values are identical" + + +@require_torch +def test_iterable_dataset_interleave_dataset_with_multiple_workers(): + # GH 7567 + from torch.utils.data import DataLoader + + def gen(shard, value): + for i in range(100): + yield {"value": value} - chunks = [list(group) for _, group in groupby(enumerate(result), key=lambda x: x[0] // num_workers)] - for chunk_idx, chunk in enumerate(chunks): - values = [ex["id"] for _, ex in chunk] - unique_values = set(values) - assert len(unique_values) > 1, f"Chunk {chunk_idx}: all workers produced same values {values}" + num_workers = 20 + ds = [ + IterableDataset.from_generator(gen, gen_kwargs={"shard": list(range(num_workers)), "value": i}) + for i in range(10) + ] + ds = interleave_datasets(ds, probabilities=[1 / len(ds)] * len(ds), seed=1234) + dataloader = DataLoader(ds, batch_size=None, num_workers=num_workers) + + result = list(dataloader) + for single_chunk in [result[x : x + num_workers] for x in range(0, len(result), num_workers)]: + values = [item["value"] for item in single_chunk] + # This will fail with the chance 1/100 ** 20! + assert len(set(values)) != 1, "Make sure not all values are identical" @pytest.mark.parametrize("batch_size", [4, 5]) From 8cade894d6639b853ed9af976adaea620e31f8e4 Mon Sep 17 00:00:00 2001 From: Tobias Pitters Date: Fri, 17 Oct 2025 16:02:54 +0200 Subject: [PATCH 6/9] update test description --- tests/test_iterable_dataset.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index af3681f56f7..362253d5ea3 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -1595,7 +1595,6 @@ def gen(shard, value): result = list(dataloader) for single_chunk in [result[x : x + num_workers] for x in range(0, len(result), num_workers)]: values = [item["value"] for item in single_chunk] - # This will fail with the chance 1/100 ** 20! 
assert len(set(values)) != 1, "Make sure not all values are identical" From 711a993991f952967cdf7e79db535e884d050efa Mon Sep 17 00:00:00 2001 From: Tobias Pitters Date: Fri, 17 Oct 2025 16:06:40 +0200 Subject: [PATCH 7/9] remove debugging strings --- src/datasets/iterable_dataset.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index 9cc04c965e7..ffd55af1b31 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -74,19 +74,6 @@ from .utils.typing import PathLike -def show_chain_of_iterables(ex_iterable: "_BaseExamplesIterable") -> str: - chain_of_iters = [ex_iterable.__class__.__name__] - - def recurse(obj): - new_obj = getattr(obj, "ex_iterable", None) - if new_obj is not None: - chain_of_iters.append(new_obj.__class__.__name__) - recurse(new_obj) - - recurse(ex_iterable) - print(f"[show_chain_of_iterables]: {','.join(chain_of_iters)}") - - if TYPE_CHECKING: import sqlite3 @@ -2395,7 +2382,6 @@ def _iter_pytorch(self): ex_iterable = ex_iterable.shard_data_sources( num_shards=worker_info.num_workers, index=worker_info.id, contiguous=False ) - show_chain_of_iterables(ex_iterable) ex_iterable.shift_rngs(value=worker_info.id) self._state_dict = { "examples_iterable": ex_iterable._init_state_dict(), From a8ffea576d425532ef496c1dbf7c47f77a14abb1 Mon Sep 17 00:00:00 2001 From: Tobias Pitters Date: Thu, 23 Oct 2025 07:45:49 +0200 Subject: [PATCH 8/9] return instances of _BaseExamplesIterable instead of modifying in place --- src/datasets/iterable_dataset.py | 60 +++++++++++++++++++++++++------- 1 file changed, 48 insertions(+), 12 deletions(-) diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index ffd55af1b31..acbe50e3882 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -169,6 +169,19 @@ def _convert_to_arrow( yield new_key, pa.Table.from_pylist(cast_to_python_objects(examples, only_1d_for_numpy=True)) +def shift_ex_examples_rngs(ex_iterable: "_BaseExamplesIterable", value: int) -> "_BaseExamplesIterable": + """Recursively walk the chain of ex_iterables: for every iterable that supports `shift_rngs`, swap in a new instance with a shifted generator seed instead of mutating it in place, and rewire the result into its parent.""" + + def set_seed_recursively(ex_iterable): + if hasattr(ex_iterable, "shift_rngs"): + ex_iterable = ex_iterable.shift_rngs(value) + if hasattr(ex_iterable, "ex_iterable"): + ex_iterable.ex_iterable = set_seed_recursively(ex_iterable.ex_iterable) + return ex_iterable + + return set_seed_recursively(ex_iterable) + + class _BaseExamplesIterable: """Base class for the examples iterable used by an IterableDataset""" @@ -198,16 +211,6 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "_BaseExamples """ raise NotImplementedError(f"{type(self)} doesn't implement shuffle_data_sources yet") - def shift_rngs(self, value: int) -> None: - def set_seed_recursively(ex_iterable): - if hasattr(ex_iterable, "generator"): - new_seed = ex_iterable.generator.bit_generator.state["state"]["state"] + value - ex_iterable.generator = np.random.default_rng(seed=new_seed) - if hasattr(ex_iterable, "ex_iterable"): - set_seed_recursively(ex_iterable.ex_iterable) - - set_seed_recursively(self) - def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "_BaseExamplesIterable": """Either keep only the requested shard, or propagate the request to the underlying iterable.""" raise NotImplementedError(f"{type(self)} doesn't implement shard_data_sources yet") @@ 
-293,6 +296,14 @@ def __init__( super().__init__(generate_examples_fn, kwargs) self.generator = deepcopy(generator) + def shift_rngs(self, value: int) -> "_BaseExamplesIterable": + new_seed = self.generator.bit_generator.state["state"]["state"] + value + return ShuffledDataSourcesExamplesIterable( + self.generate_examples_fn, + self.kwargs, + np.random.default_rng(seed=new_seed), + ) + def _init_state_dict(self) -> dict: self._state_dict = {"shard_idx": 0, "shard_example_idx": 0, "type": self.__class__.__name__} return self._state_dict @@ -400,6 +411,14 @@ def __init__( super().__init__(generate_tables_fn, kwargs) self.generator = deepcopy(generator) + def shift_rngs(self, value: int) -> "_BaseExamplesIterable": + new_seed = self.generator.bit_generator.state["state"]["state"] + value + return ShuffledDataSourcesArrowExamplesIterable( + self.generate_tables_fn, + self.kwargs, + np.random.default_rng(seed=new_seed), + ) + def _init_state_dict(self) -> dict: self._state_dict = {"shard_idx": 0, "shard_example_idx": 0, "type": self.__class__.__name__} return self._state_dict @@ -1041,6 +1060,15 @@ def __init__( self.generator = deepcopy(generator) self.probabilities = probabilities + def shift_rngs(self, value: int) -> "_BaseExamplesIterable": + new_seed = self.generator.bit_generator.state["state"]["state"] + value + return RandomlyCyclingMultiSourcesExamplesIterable( + ex_iterables=self.ex_iterables, + generator=np.random.default_rng(seed=new_seed), + probabilities=self.probabilities, + stopping_strategy=self.stopping_strategy, + ) + @property def is_typed(self): return self.ex_iterables[0].is_typed @@ -1638,6 +1666,14 @@ def __init__(self, ex_iterable: _BaseExamplesIterable, buffer_size: int, generat self.buffer_size = buffer_size self.generator = generator + def shift_rngs(self, value: int) -> "_BaseExamplesIterable": + new_seed = self.generator.bit_generator.state["state"]["state"] + value + return BufferShuffledExamplesIterable( + ex_iterable=self.ex_iterable, + buffer_size=self.buffer_size, + generator=np.random.default_rng(seed=new_seed), + ) + @property def is_typed(self): return self.ex_iterable.is_typed @@ -2382,7 +2418,7 @@ def _iter_pytorch(self): ex_iterable = ex_iterable.shard_data_sources( num_shards=worker_info.num_workers, index=worker_info.id, contiguous=False ) - ex_iterable.shift_rngs(value=worker_info.id) + ex_iterable = shift_ex_examples_rngs(ex_iterable=ex_iterable, value=worker_info.id) self._state_dict = { "examples_iterable": ex_iterable._init_state_dict(), "epoch": self.epoch, @@ -3682,7 +3718,7 @@ def to_polars( Args: batch_size (`int`, *optional*): The size (number of rows) of the batches if `batched` is `True`. - Defaults to `genomicsml.datasets.config.DEFAULT_MAX_BATCH_SIZE`. + Defaults to `datasets.config.DEFAULT_MAX_BATCH_SIZE`. batched (`bool`): Set to `True` to return a generator that yields the dataset as batches of `batch_size` rows. Defaults to `False` (returns the whole datasets once). 
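A note on the seed derivation shared by the shift_rngs implementations in this patch: each one reads the raw state integer of its generator and offsets it by the worker id, so every worker gets a distinct but deterministic stream. A minimal standalone sketch of that scheme, assuming NumPy's default PCG64 bit generator (whose raw state integer lives at bit_generator.state["state"]["state"]) and using a plain loop index as a stand-in for worker_info.id:

import numpy as np

# Derive one deterministic RNG per worker from a single base generator,
# mirroring shift_rngs: offset the raw PCG64 state integer by the worker id.
base_rng = np.random.default_rng(seed=42)
base_state = base_rng.bit_generator.state["state"]["state"]

for worker_id in range(4):  # stand-in for worker_info.id
    worker_rng = np.random.default_rng(seed=base_state + worker_id)
    # Each worker draws a distinct permutation, reproducible across runs.
    print(worker_id, worker_rng.permutation(5))
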
From 28b53e245ffd08dcaa0a7a6e2e2563ca72a656b7 Mon Sep 17 00:00:00 2001 From: Tobias Pitters Date: Thu, 23 Oct 2025 08:00:53 +0200 Subject: [PATCH 9/9] add test to make sure multiple iterations over data are deterministic --- tests/test_iterable_dataset.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index 362253d5ea3..583f5dab51a 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -1598,6 +1598,32 @@ def gen(shard, value): assert len(set(values)) != 1, "Make sure not all values are identical" +@require_torch +def test_iterable_dataset_interleave_dataset_deterministic_across_iterations(): + # GH 7567 + from torch.utils.data import DataLoader + + def gen(shard, value): + for i in range(50): + yield {"value": value, "id": i} + + num_workers = 10 + ds = [ + IterableDataset.from_generator(gen, gen_kwargs={"shard": list(range(num_workers)), "value": i}) + for i in range(5) + ] + ds = interleave_datasets(ds, probabilities=[1 / len(ds)] * len(ds), seed=1234) + dataloader = DataLoader(ds, batch_size=None, num_workers=num_workers) + + # First iteration + first_result = list(dataloader) + + # Second iteration + second_result = list(dataloader) + + assert first_result == second_result, "Results should be identical across iterations when using same seed" + + @pytest.mark.parametrize("batch_size", [4, 5]) @pytest.mark.parametrize("drop_last_batch", [False, True]) def test_iterable_dataset_iter_batch(batch_size, drop_last_batch):
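
Taken together, the series makes the shuffle and interleave RNG streams diverge per DataLoader worker while keeping repeated iteration deterministic. A minimal usage sketch of the fixed behavior, modeled on the tests above (the generator function and field names are illustrative, not part of the patches):

from torch.utils.data import DataLoader

from datasets import IterableDataset

def gen(shard):  # passing a list for `shard` in gen_kwargs makes the dataset sharded
    for i in range(100):
        yield {"value": i}

num_workers = 4
ds = IterableDataset.from_generator(gen, gen_kwargs={"shard": list(range(num_workers))})
ds = ds.shuffle(buffer_size=100, seed=1234)
dataloader = DataLoader(ds, batch_size=None, num_workers=num_workers)

first = [ex["value"] for ex in dataloader]
second = [ex["value"] for ex in dataloader]

# With the fix, workers no longer emit identical shuffled streams...
assert len(set(first[:num_workers])) > 1
# ...and re-iterating with the same seed stays deterministic.
assert first == second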