@@ -169,6 +169,19 @@ def _convert_to_arrow(
169169        yield  new_key , pa .Table .from_pylist (cast_to_python_objects (examples , only_1d_for_numpy = True ))
170170
171171
def shift_ex_examples_rngs(ex_iterable: "_BaseExamplesIterable", value: int) -> "_BaseExamplesIterable":
    """Return a re-seeded version of an examples-iterable tree.

    The tree is walked recursively through both the single-child attribute
    ``ex_iterable`` and the multi-child attribute ``ex_iterables``. Every node
    exposing ``shift_rngs`` is replaced by ``node.shift_rngs(value)``. Any other
    node whose children must be rebound is shallow-copied first, so the objects
    passed in by the caller are left untouched.

    Args:
        ex_iterable: Root of the tree of examples iterables.
        value: Offset forwarded unchanged to every ``shift_rngs`` call.

    Returns:
        The root of the re-seeded tree. A node that has neither ``shift_rngs``
        nor children is returned as the very same object.
    """
    # Function-scope import: only this helper needs the shallow-copy function.
    import copy

    def set_seed_recursively(node):
        has_own_rng = hasattr(node, "shift_rngs")
        if has_own_rng:
            # NOTE(review): assumes shift_rngs returns a new object (the
            # implementations visible in this file all do) - confirm for others.
            node = node.shift_rngs(value)

        has_child = hasattr(node, "ex_iterable")
        has_children = hasattr(node, "ex_iterables")
        if not (has_child or has_children):
            return node

        if not has_own_rng:
            # This node is still the caller's object: copy before rebinding
            # its children so the original tree is not mutated.
            node = copy.copy(node)
        if has_child:
            node.ex_iterable = set_seed_recursively(node.ex_iterable)
        if has_children:
            node.ex_iterables = [set_seed_recursively(child) for child in node.ex_iterables]
        return node

    return set_seed_recursively(ex_iterable)
183+ 
184+ 
172185class  _BaseExamplesIterable :
173186    """Base class for the examples iterable used by an IterableDataset""" 
174187
@@ -283,6 +296,14 @@ def __init__(
283296        super ().__init__ (generate_examples_fn , kwargs )
284297        self .generator  =  deepcopy (generator )
285298
299+     def  shift_rngs (self , value : int ) ->  "_BaseExamplesIterable" :
300+         new_seed  =  self .generator .bit_generator .state ["state" ]["state" ] +  value 
301+         return  ShuffledDataSourcesExamplesIterable (
302+             self .generate_examples_fn ,
303+             self .kwargs ,
304+             np .random .default_rng (seed = new_seed ),
305+         )
306+ 
286307    def  _init_state_dict (self ) ->  dict :
287308        self ._state_dict  =  {"shard_idx" : 0 , "shard_example_idx" : 0 , "type" : self .__class__ .__name__ }
288309        return  self ._state_dict 
@@ -390,6 +411,14 @@ def __init__(
390411        super ().__init__ (generate_tables_fn , kwargs )
391412        self .generator  =  deepcopy (generator )
392413
414+     def  shift_rngs (self , value : int ) ->  "_BaseExamplesIterable" :
415+         new_seed  =  self .generator .bit_generator .state ["state" ]["state" ] +  value 
416+         return  ShuffledDataSourcesArrowExamplesIterable (
417+             self .generate_examples_fn ,
418+             self .kwargs ,
419+             np .random .default_rng (seed = new_seed ),
420+         )
421+ 
393422    def  _init_state_dict (self ) ->  dict :
394423        self ._state_dict  =  {"shard_idx" : 0 , "shard_example_idx" : 0 , "type" : self .__class__ .__name__ }
395424        return  self ._state_dict 
@@ -1031,6 +1060,15 @@ def __init__(
10311060        self .generator  =  deepcopy (generator )
10321061        self .probabilities  =  probabilities 
10331062
1063+     def  shift_rngs (self , value : int ) ->  "_BaseExamplesIterable" :
1064+         new_seed  =  self .generator .bit_generator .state ["state" ]["state" ] +  value 
1065+         return  RandomlyCyclingMultiSourcesExamplesIterable (
1066+             ex_iterables = self .ex_iterables ,
1067+             generator = np .random .default_rng (seed = new_seed ),
1068+             probabilities = self .probabilities ,
1069+             stopping_strategy = self .stopping_strategy ,
1070+         )
1071+ 
10341072    @property  
10351073    def  is_typed (self ):
10361074        return  self .ex_iterables [0 ].is_typed 
@@ -1628,6 +1666,14 @@ def __init__(self, ex_iterable: _BaseExamplesIterable, buffer_size: int, generat
16281666        self .buffer_size  =  buffer_size 
16291667        self .generator  =  generator 
16301668
1669+     def  shift_rngs (self , value : int ) ->  "_BaseExamplesIterable" :
1670+         new_seed  =  self .generator .bit_generator .state ["state" ]["state" ] +  value 
1671+         return  BufferShuffledExamplesIterable (
1672+             ex_iterable = self .ex_iterable ,
1673+             buffer_size = self .buffer_size ,
1674+             generator = np .random .default_rng (seed = new_seed ),
1675+         )
1676+ 
16311677    @property  
16321678    def  is_typed (self ):
16331679        return  self .ex_iterable .is_typed 
@@ -2372,6 +2418,7 @@ def _iter_pytorch(self):
23722418            ex_iterable  =  ex_iterable .shard_data_sources (
23732419                num_shards = worker_info .num_workers , index = worker_info .id , contiguous = False 
23742420            )
2421+             ex_iterable  =  shift_ex_examples_rngs (ex_iterable = ex_iterable , value = worker_info .id )
23752422            self ._state_dict  =  {
23762423                "examples_iterable" : ex_iterable ._init_state_dict (),
23772424                "epoch" : self .epoch ,
0 commit comments