File tree 1 file changed +2
-4
lines changed
1 file changed +2
-4
lines changed Original file line number Diff line number Diff line change @@ -417,7 +417,7 @@ def _should_start_trace():
417
417
stop_trace_step = self .step + 3
418
418
self ._step = self ._step + 1
419
419
self .vlog (3 , "Start step %s" , self .step )
420
- output = self ._run_step (input_batch )
420
+ output = self ._run_step (utils . host_to_global_device_array ( input_batch ) )
421
421
self .vlog (3 , "Done step %s" , self .step )
422
422
num_steps += 1
423
423
if num_steps % 100 == 0 :
@@ -688,13 +688,11 @@ def _run_step(self, input_batch: NestedTensor) -> NestedTensor:
688
688
"""Runs a single training step.
689
689
690
690
Args:
691
- input_batch: a NestedTensor.
691
+ input_batch: a NestedTensor containing global arrays .
692
692
693
693
Returns:
694
694
A dict containing 'loss' and 'aux' outputs.
695
695
"""
696
- input_batch = utils .host_to_global_device_array (input_batch )
697
-
698
696
with jax .profiler .StepTraceAnnotation ("train" , step_num = self .step ):
699
697
# Note(Jan 2022):
700
698
# pjit currently requires all parameters to be specified as positional args.
You can’t perform that action at this time.
0 commit comments