47 changes: 47 additions & 0 deletions src/datasets/iterable_dataset.py
@@ -169,6 +169,19 @@ def _convert_to_arrow(
        yield new_key, pa.Table.from_pylist(cast_to_python_objects(examples, only_1d_for_numpy=True))


def shift_ex_examples_rngs(ex_iterable: "_BaseExamplesIterable", value: int) -> "_BaseExamplesIterable":
    """Recursively walk the chain of ex_iterables: whenever one exposes `shift_rngs`, replace it with a
    copy whose RNG is reseeded with its state shifted by `value`, then recurse into the wrapped `ex_iterable`."""

    def set_seed_recursively(ex_iterable):
        if hasattr(ex_iterable, "shift_rngs"):
            ex_iterable = ex_iterable.shift_rngs(value)
        if hasattr(ex_iterable, "ex_iterable"):
            ex_iterable.ex_iterable = set_seed_recursively(ex_iterable.ex_iterable)
        return ex_iterable

    return set_seed_recursively(ex_iterable)
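
A toy walk-through of this recursion, using hypothetical stand-in classes (`Shuffled` and `Leaf` are illustrations, not library types): every node that exposes `shift_rngs` is rebuilt with a shifted seed, and its wrapped `ex_iterable` is rewritten in turn.

    # Hypothetical classes, for illustration only
    class Leaf:
        pass

    class Shuffled:
        def __init__(self, ex_iterable, seed):
            self.ex_iterable, self.seed = ex_iterable, seed

        def shift_rngs(self, value):
            return Shuffled(self.ex_iterable, self.seed + value)

    chain = Shuffled(Shuffled(Leaf(), seed=0), seed=100)
    shifted = shift_ex_examples_rngs(chain, value=7)
    print(shifted.seed, shifted.ex_iterable.seed)  # 107 7 -- both levels were shifted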


class _BaseExamplesIterable:
"""Base class for the examples iterable used by an IterableDataset"""

@@ -283,6 +296,14 @@ def __init__(
        super().__init__(generate_examples_fn, kwargs)
        self.generator = deepcopy(generator)

    def shift_rngs(self, value: int) -> "_BaseExamplesIterable":
        new_seed = self.generator.bit_generator.state["state"]["state"] + value
        return ShuffledDataSourcesExamplesIterable(
            self.generate_examples_fn,
            self.kwargs,
            np.random.default_rng(seed=new_seed),
        )

    def _init_state_dict(self) -> dict:
        self._state_dict = {"shard_idx": 0, "shard_example_idx": 0, "type": self.__class__.__name__}
        return self._state_dict
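
The seed arithmetic in `shift_rngs` relies on NumPy's default PCG64 bit generator, whose 128-bit internal state is reachable at `bit_generator.state["state"]["state"]`. A minimal sketch (assuming `np.random.default_rng`, i.e. PCG64) of why adding the worker id there yields a distinct but deterministic stream per worker:

    import numpy as np

    base = np.random.default_rng(seed=1234)
    base_state = base.bit_generator.state["state"]["state"]  # 128-bit PCG64 internal state

    # One generator per worker, each seeded with the shifted state
    rngs = [np.random.default_rng(seed=base_state + worker_id) for worker_id in range(4)]
    print([int(rng.integers(0, 100)) for rng in rngs])  # different values per worker, stable across runs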
@@ -390,6 +411,14 @@ def __init__(
        super().__init__(generate_tables_fn, kwargs)
        self.generator = deepcopy(generator)

    def shift_rngs(self, value: int) -> "_BaseExamplesIterable":
        new_seed = self.generator.bit_generator.state["state"]["state"] + value
        return ShuffledDataSourcesArrowExamplesIterable(
            self.generate_tables_fn,
            self.kwargs,
            np.random.default_rng(seed=new_seed),
        )

    def _init_state_dict(self) -> dict:
        self._state_dict = {"shard_idx": 0, "shard_example_idx": 0, "type": self.__class__.__name__}
        return self._state_dict
@@ -1031,6 +1060,15 @@ def __init__(
        self.generator = deepcopy(generator)
        self.probabilities = probabilities

    def shift_rngs(self, value: int) -> "_BaseExamplesIterable":
        new_seed = self.generator.bit_generator.state["state"]["state"] + value
        return RandomlyCyclingMultiSourcesExamplesIterable(
            ex_iterables=self.ex_iterables,
            generator=np.random.default_rng(seed=new_seed),
            probabilities=self.probabilities,
            stopping_strategy=self.stopping_strategy,
        )

    @property
    def is_typed(self):
        return self.ex_iterables[0].is_typed
@@ -1628,6 +1666,14 @@ def __init__(self, ex_iterable: _BaseExamplesIterable, buffer_size: int, generat
        self.buffer_size = buffer_size
        self.generator = generator

    def shift_rngs(self, value: int) -> "_BaseExamplesIterable":
        new_seed = self.generator.bit_generator.state["state"]["state"] + value
        return BufferShuffledExamplesIterable(
            ex_iterable=self.ex_iterable,
            buffer_size=self.buffer_size,
            generator=np.random.default_rng(seed=new_seed),
        )

    @property
    def is_typed(self):
        return self.ex_iterable.is_typed
@@ -2372,6 +2418,7 @@ def _iter_pytorch(self):
        ex_iterable = ex_iterable.shard_data_sources(
            num_shards=worker_info.num_workers, index=worker_info.id, contiguous=False
        )
        ex_iterable = shift_ex_examples_rngs(ex_iterable=ex_iterable, value=worker_info.id)
        self._state_dict = {
            "examples_iterable": ex_iterable._init_state_dict(),
            "epoch": self.epoch,
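A sketch of the failure mode motivating the added line (GH 7567), assuming NumPy generators: if every DataLoader worker holds a copy of the same generator, their shuffle orders coincide; shifting the seed by the worker id makes the streams diverge while staying reproducible.

    import numpy as np
    from copy import deepcopy

    base = np.random.default_rng(seed=1234)

    copies = [deepcopy(base) for _ in range(3)]  # what each worker effectively saw before the fix
    print([rng.permutation(5).tolist() for rng in copies])  # three identical permutations

    state = base.bit_generator.state["state"]["state"]
    shifted = [np.random.default_rng(seed=state + worker_id) for worker_id in range(3)]  # after the fix
    print([rng.permutation(5).tolist() for rng in shifted])  # permutations now differ
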
71 changes: 71 additions & 0 deletions tests/test_iterable_dataset.py
@@ -1553,6 +1553,77 @@ def test_iterable_dataset_from_hub_torch_dataloader_parallel(num_workers, tmp_pa
    assert len(result) == 10


@require_torch
def test_iterable_dataset_shuffle_with_multiple_workers_different_rng():
    # GH 7567
    from torch.utils.data import DataLoader, get_worker_info

    def gen(shard):
        worker_info = get_worker_info()
        for i in range(100):
            yield {"value": i, "worker_id": worker_info.id}

    num_workers = 20
    ds = IterableDataset.from_generator(gen, gen_kwargs={"shard": list(range(num_workers))})
    ds = ds.shuffle(buffer_size=100, seed=1234)
    dataloader = DataLoader(ds, batch_size=None, num_workers=num_workers)

    result = list(dataloader)
    for single_chunk in [result[x : x + num_workers] for x in range(0, len(result), num_workers)]:
        values = [item["value"] for item in single_chunk]
        # With correctly shifted RNGs this can only fail by chance, with probability on the order of (1/100) ** 20
        assert len(set(values)) != 1, "Make sure not all values are identical"


@require_torch
def test_iterable_dataset_interleave_dataset_with_multiple_workers():
    # GH 7567
    from torch.utils.data import DataLoader

    def gen(shard, value):
        for i in range(100):
            yield {"value": value}

    num_workers = 20
    ds = [
        IterableDataset.from_generator(gen, gen_kwargs={"shard": list(range(num_workers)), "value": i})
        for i in range(10)
    ]
    ds = interleave_datasets(ds, probabilities=[1 / len(ds)] * len(ds), seed=1234)
    dataloader = DataLoader(ds, batch_size=None, num_workers=num_workers)

    result = list(dataloader)
    for single_chunk in [result[x : x + num_workers] for x in range(0, len(result), num_workers)]:
        values = [item["value"] for item in single_chunk]
        assert len(set(values)) != 1, "Make sure not all values are identical"


@require_torch
def test_iterable_dataset_interleave_dataset_deterministic_across_iterations():
    # GH 7567
    from torch.utils.data import DataLoader

    def gen(shard, value):
        for i in range(50):
            yield {"value": value, "id": i}

    num_workers = 10
    ds = [
        IterableDataset.from_generator(gen, gen_kwargs={"shard": list(range(num_workers)), "value": i})
        for i in range(5)
    ]
    ds = interleave_datasets(ds, probabilities=[1 / len(ds)] * len(ds), seed=1234)
    dataloader = DataLoader(ds, batch_size=None, num_workers=num_workers)

    # First iteration
    first_result = list(dataloader)

    # Second iteration
    second_result = list(dataloader)

    assert first_result == second_result, "Results should be identical across iterations when using the same seed"


@pytest.mark.parametrize("batch_size", [4, 5])
@pytest.mark.parametrize("drop_last_batch", [False, True])
def test_iterable_dataset_iter_batch(batch_size, drop_last_batch):