@@ -497,8 +497,12 @@ def __next__(self):
         return self.micro_batches.pop(0)

     def refill(self):
-        # this will raise StopIteration when empty
-        batch = next(self.iter)
+        # reset the iterator if StopIteration arrives, and re-raise it to allow multiple epochs to run
+        try:
+            batch = next(self.iter)
+        except StopIteration:
+            self.iter = iter(self.dl)
+            raise StopIteration
         micro_batches = defaultdict(dict)
         # XXX: replace with more efficient all-to-all?

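The first hunk's pattern can be seen in isolation in the following minimal sketch (the EpochAwareIterator class and the list standing in for a dataloader are illustrative, not this PR's code): when the wrapped iterator is exhausted, it is re-created from the dataloader before StopIteration is propagated, so the caller can treat the exception as an epoch boundary and keep calling refill() for the next epoch.

# Illustrative sketch only (assumed names, not the PR's implementation).
class EpochAwareIterator:
    def __init__(self, dl):
        self.dl = dl
        self.iter = iter(dl)

    def refill(self):
        try:
            return next(self.iter)
        except StopIteration:
            # re-arm the iterator so the next epoch can start cleanly,
            # then let the caller see the end-of-epoch signal
            self.iter = iter(self.dl)
            raise

# each inner loop is one epoch; StopIteration marks the epoch boundary
buf = EpochAwareIterator([1, 2, 3])
for epoch in range(2):
    while True:
        try:
            print(epoch, buf.refill())
        except StopIteration:
            break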
@@ -982,7 +986,8 @@ def forward(
         if output_reduction == "mean":
             incoming_grad /= shards

-        x_grad = torch.zeros_like(x)
+        # XXX: deal with the use case of running in inference mode, where we don't need backward
+        x_grad = torch.zeros_like(x) if x_requires_grad else None
         x_shards = list(torch.chunk(x, chunks=shards, dim=0))
         y_shards = list(torch.chunk(y, chunks=shards, dim=0))
         if mask is not None:
@@ -1007,15 +1012,18 @@ def forward(
             shard_step = x_shards[i].shape[0]
             shard_offset = i * x_shards[0].shape[0]

-            x_shard.grad = x_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)
-
-            with torch.enable_grad():
-                args = (self, x_shard, y_shard)
-                if mask is not None:
-                    args += (mask_shards[i],)
+            args = (self, x_shard, y_shard)
+            if mask is not None:
+                args += (mask_shards[i],)
+            if x_grad is not None:
+                x_shard.grad = x_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)
+                with torch.enable_grad():
+                    output = fn(*args)
+                    output_shards.append(output)
+                    torch.autograd.backward(output, incoming_grad)
+            else:
                 output = fn(*args)
                 output_shards.append(output)
-                torch.autograd.backward(output, incoming_grad)

         output_unsharded = torch.cat([l.unsqueeze(0) for l in output_shards], dim=0)

@@ -1025,9 +1033,10 @@ def forward(
             output = output_unsharded.sum()

         # unflatten
-        x_grad = x_grad.view(bs, seqlen, *x_grad.shape[1:])
+        if x_grad is not None:
+            x_grad = x_grad.view(bs, seqlen, *x_grad.shape[1:])
+            ctx.save_for_backward(x_grad.detach())

-        ctx.save_for_backward(x_grad.detach())
         return output

     @staticmethod
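The forward-path hunks all serve the inference-mode guard named in the XXX comment: when the input does not require grad, the tiled forward skips allocating the x_grad buffer, skips the per-shard enable_grad/backward step, and saves nothing for the backward pass. A reduced, self-contained sketch of that control flow (function and variable names here are illustrative, not the PR's API):

import torch

def tiled_loss(fn, x, y, shards=2):
    # Illustrative sketch: run fn shard by shard; only build a graph and call
    # backward when x participates in autograd, otherwise do a plain forward.
    x_requires_grad = x.requires_grad
    x_grad = torch.zeros_like(x) if x_requires_grad else None  # nothing to fill in inference
    output_shards = []
    x_shards = list(torch.chunk(x, chunks=shards, dim=0))
    y_shards = list(torch.chunk(y, chunks=shards, dim=0))
    for i, (x_shard, y_shard) in enumerate(zip(x_shards, y_shards)):
        if x_grad is not None:
            # make the shard a graph leaf so its grad can be captured
            x_shard = x_shard.detach().requires_grad_(True)
            with torch.enable_grad():
                out = fn(x_shard, y_shard)
                # scale so the accumulated grads match the mean over shards
                out.backward(torch.full_like(out, 1.0 / shards))
            output_shards.append(out.detach())
            # copy the shard grad back into the pre-allocated buffer
            offset = i * x_shards[0].shape[0]
            x_grad.narrow(0, offset, x_shard.shape[0]).copy_(x_shard.grad)
        else:
            # inference: no graph, no backward, nothing saved
            output_shards.append(fn(x_shard, y_shard))
    return torch.stack(output_shards).mean(), x_grad

x = torch.randn(8, 4, requires_grad=True)
y = torch.randn(8, 4)
mse = lambda a, b: ((a - b) ** 2).mean()
loss, x_grad = tiled_loss(mse, x, y)                    # training path: x_grad is filled
loss_eval, none_grad = tiled_loss(mse, x.detach(), y)   # inference path: none_grad is None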