Skip to content

Commit a7600ac

Browse files
authored
Fix random seed on shuffle and interleave_datasets (#7823)
* WIP: shuffle working, interleave_ds not yet * remove debug statements * add test * update test * use recursive overwriting of generator seeds * update test description * remove debugging strings * return instances of baseexiterable instead of modifying inplace * add test to make sure multiple iterations over data are deterministic
1 parent 5138876 commit a7600ac

File tree

2 files changed

+118
-0
lines changed

2 files changed

+118
-0
lines changed

src/datasets/iterable_dataset.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,19 @@ def _convert_to_arrow(
169169
yield new_key, pa.Table.from_pylist(cast_to_python_objects(examples, only_1d_for_numpy=True))
170170

171171

172+
def shift_ex_examples_rngs(ex_iterable: "_BaseExamplesIterable", value: int) -> "_BaseExamplesIterable":
    """Walk the chain of wrapped ex_iterables and derive shifted RNGs.

    Any node exposing ``shift_rngs`` is replaced by the reseeded instance it
    returns; wrapper nodes (those holding an inner ``ex_iterable``) keep their
    identity but have their inner iterable rebound to the reseeded subtree.
    ``value`` is the offset applied to each generator's seed (e.g. a dataloader
    worker id), making every worker's RNG distinct but deterministic.
    """

    def _reseed(node):
        # Shuffling nodes know how to produce a shifted copy of themselves.
        if hasattr(node, "shift_rngs"):
            node = node.shift_rngs(value)
        # NOTE(review): wrappers without shift_rngs are rebound in place here —
        # presumably fine because this runs on a per-worker copy; confirm.
        if hasattr(node, "ex_iterable"):
            node.ex_iterable = _reseed(node.ex_iterable)
        return node

    return _reseed(ex_iterable)
183+
184+
172185
class _BaseExamplesIterable:
173186
"""Base class for the examples iterable used by an IterableDataset"""
174187

@@ -283,6 +296,14 @@ def __init__(
283296
super().__init__(generate_examples_fn, kwargs)
284297
self.generator = deepcopy(generator)
285298

299+
def shift_rngs(self, value: int) -> "_BaseExamplesIterable":
    """Build a reseeded copy of this shuffled examples iterable.

    The new generator is seeded with the current bit-generator state word
    offset by ``value`` (e.g. the dataloader worker id), so each worker
    shuffles with a different but deterministic RNG stream.
    """
    # Assumes a PCG64-style nested state dict with a "state" counter — the
    # layout produced by np.random.default_rng().
    base_state = self.generator.bit_generator.state["state"]["state"]
    fresh_rng = np.random.default_rng(seed=base_state + value)
    return ShuffledDataSourcesExamplesIterable(self.generate_examples_fn, self.kwargs, fresh_rng)
306+
286307
def _init_state_dict(self) -> dict:
    # Fresh iteration state: shard position and example offset within the
    # shard; "type" records the concrete class name, presumably for
    # validation when restoring the state elsewhere — confirm.
    self._state_dict = {"shard_idx": 0, "shard_example_idx": 0, "type": self.__class__.__name__}
    return self._state_dict
@@ -390,6 +411,14 @@ def __init__(
390411
super().__init__(generate_tables_fn, kwargs)
391412
self.generator = deepcopy(generator)
392413

414+
def shift_rngs(self, value: int) -> "_BaseExamplesIterable":
    """Return a reseeded copy of this shuffled Arrow iterable.

    The current bit-generator state word is offset by ``value`` (e.g. the
    dataloader worker id) and used to seed a fresh generator, giving each
    worker a distinct but deterministic shuffle.
    """
    new_seed = self.generator.bit_generator.state["state"]["state"] + value
    # Bug fix: __init__ forwards `generate_tables_fn` to the Arrow base class
    # (see the constructor above), so this class presumably has no
    # `generate_examples_fn` attribute and the previous code raised
    # AttributeError when shift_rngs was invoked — confirm against the base.
    return ShuffledDataSourcesArrowExamplesIterable(
        self.generate_tables_fn,
        self.kwargs,
        np.random.default_rng(seed=new_seed),
    )
421+
393422
def _init_state_dict(self) -> dict:
    # Fresh iteration state: shard position and example offset within the
    # shard; "type" records the concrete class name, presumably for
    # validation when restoring the state elsewhere — confirm.
    self._state_dict = {"shard_idx": 0, "shard_example_idx": 0, "type": self.__class__.__name__}
    return self._state_dict
@@ -1031,6 +1060,15 @@ def __init__(
10311060
self.generator = deepcopy(generator)
10321061
self.probabilities = probabilities
10331062

1063+
def shift_rngs(self, value: int) -> "_BaseExamplesIterable":
    """Return a copy of this multi-source iterable with a shifted RNG.

    Sources, sampling probabilities and stopping strategy are carried over
    unchanged; only the generator driving the random cycling is reseeded,
    using the current bit-generator state word offset by ``value``.
    """
    state_word = self.generator.bit_generator.state["state"]["state"]
    reseeded = np.random.default_rng(seed=state_word + value)
    return RandomlyCyclingMultiSourcesExamplesIterable(
        ex_iterables=self.ex_iterables,
        probabilities=self.probabilities,
        stopping_strategy=self.stopping_strategy,
        generator=reseeded,
    )
1071+
10341072
@property
10351073
def is_typed(self):
10361074
return self.ex_iterables[0].is_typed
@@ -1628,6 +1666,14 @@ def __init__(self, ex_iterable: _BaseExamplesIterable, buffer_size: int, generat
16281666
self.buffer_size = buffer_size
16291667
self.generator = generator
16301668

1669+
def shift_rngs(self, value: int) -> "_BaseExamplesIterable":
    """Return a buffer-shuffle iterable whose RNG seed is offset by ``value``.

    The wrapped iterable and buffer size are reused as-is; only the shuffle
    generator is replaced by one seeded from the current state plus ``value``.
    """
    state_word = self.generator.bit_generator.state["state"]["state"]
    reseeded = np.random.default_rng(seed=state_word + value)
    return BufferShuffledExamplesIterable(
        ex_iterable=self.ex_iterable,
        generator=reseeded,
        buffer_size=self.buffer_size,
    )
1676+
16311677
@property
16321678
def is_typed(self):
16331679
return self.ex_iterable.is_typed
@@ -2372,6 +2418,7 @@ def _iter_pytorch(self):
23722418
ex_iterable = ex_iterable.shard_data_sources(
23732419
num_shards=worker_info.num_workers, index=worker_info.id, contiguous=False
23742420
)
2421+
ex_iterable = shift_ex_examples_rngs(ex_iterable=ex_iterable, value=worker_info.id)
23752422
self._state_dict = {
23762423
"examples_iterable": ex_iterable._init_state_dict(),
23772424
"epoch": self.epoch,

tests/test_iterable_dataset.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1553,6 +1553,77 @@ def test_iterable_dataset_from_hub_torch_dataloader_parallel(num_workers, tmp_pa
15531553
assert len(result) == 10
15541554

15551555

1556+
@require_torch
def test_iterable_dataset_shuffle_with_multiple_workers_different_rng():
    # GH 7567: every dataloader worker must shuffle with its own RNG stream,
    # not an identical copy of the same seeded generator.
    from torch.utils.data import DataLoader, get_worker_info

    def gen(shard):
        worker_info = get_worker_info()
        for i in range(100):
            yield {"value": i, "worker_id": worker_info.id}

    num_workers = 20
    ds = IterableDataset.from_generator(gen, gen_kwargs={"shard": list(range(num_workers))})
    ds = ds.shuffle(buffer_size=100, seed=1234)
    dataloader = DataLoader(ds, batch_size=None, num_workers=num_workers)

    result = list(dataloader)
    # Workers are presumably drained round-robin, so each consecutive group of
    # `num_workers` rows holds one row per worker; if all workers shared one
    # RNG, every row in a group would carry the same shuffled value. A false
    # failure by chance is astronomically unlikely.
    for start in range(0, len(result), num_workers):
        chunk = result[start : start + num_workers]
        values = [row["value"] for row in chunk]
        assert len(set(values)) != 1, "Make sure not all values are identical"
1576+
1577+
1578+
@require_torch
def test_iterable_dataset_interleave_dataset_with_multiple_workers():
    # GH 7567: interleave_datasets must not make every worker sample the
    # sources in an identical order.
    from torch.utils.data import DataLoader

    def gen(shard, value):
        for i in range(100):
            yield {"value": value}

    num_workers = 20
    sources = [
        IterableDataset.from_generator(gen, gen_kwargs={"shard": list(range(num_workers)), "value": i})
        for i in range(10)
    ]
    ds = interleave_datasets(sources, probabilities=[1 / len(sources)] * len(sources), seed=1234)
    dataloader = DataLoader(ds, batch_size=None, num_workers=num_workers)

    result = list(dataloader)
    # One row per worker in each consecutive group; identical cycling RNGs
    # would make all rows of a group come from the same source value.
    for start in range(0, len(result), num_workers):
        chunk = result[start : start + num_workers]
        values = [row["value"] for row in chunk]
        assert len(set(values)) != 1, "Make sure not all values are identical"
1599+
1600+
1601+
@require_torch
def test_iterable_dataset_interleave_dataset_deterministic_across_iterations():
    # GH 7567: per-worker RNG shifting must stay deterministic — iterating the
    # same seeded dataset twice has to yield the exact same stream.
    from torch.utils.data import DataLoader

    def gen(shard, value):
        for i in range(50):
            yield {"value": value, "id": i}

    num_workers = 10
    sources = [
        IterableDataset.from_generator(gen, gen_kwargs={"shard": list(range(num_workers)), "value": i})
        for i in range(5)
    ]
    ds = interleave_datasets(sources, probabilities=[1 / len(sources)] * len(sources), seed=1234)
    dataloader = DataLoader(ds, batch_size=None, num_workers=num_workers)

    first_result = list(dataloader)
    second_result = list(dataloader)

    assert first_result == second_result, "Results should be identical across iterations when using same seed"
1625+
1626+
15561627
@pytest.mark.parametrize("batch_size", [4, 5])
15571628
@pytest.mark.parametrize("drop_last_batch", [False, True])
15581629
def test_iterable_dataset_iter_batch(batch_size, drop_last_batch):

0 commit comments

Comments
 (0)