Skip to content

Commit c00c632

Browse files
committed
Moves the `utils.host_to_global_device_array` call from `SpmdTrainer._run_step` to `SpmdTrainer.run`.
This makes it easier for subclasses of `SpmdTrainer` to override `_run_step`.
1 parent b2ccd7b commit c00c632

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

axlearn/common/trainer.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,7 @@ def _should_start_trace():
417417
stop_trace_step = self.step + 3
418418
self._step = self._step + 1
419419
self.vlog(3, "Start step %s", self.step)
420-
output = self._run_step(input_batch)
420+
output = self._run_step(utils.host_to_global_device_array(input_batch))
421421
self.vlog(3, "Done step %s", self.step)
422422
num_steps += 1
423423
if num_steps % 100 == 0:
@@ -688,13 +688,11 @@ def _run_step(self, input_batch: NestedTensor) -> NestedTensor:
688688
"""Runs a single training step.
689689
690690
Args:
691-
input_batch: a NestedTensor.
691+
input_batch: a NestedTensor containing global arrays.
692692
693693
Returns:
694694
A dict containing 'loss' and 'aux' outputs.
695695
"""
696-
input_batch = utils.host_to_global_device_array(input_batch)
697-
698696
with jax.profiler.StepTraceAnnotation("train", step_num=self.step):
699697
# Note(Jan 2022):
700698
# pjit currently requires all parameters to be specified as positional args.

0 commit comments

Comments
 (0)