Skip to content

Commit 1439da4

Browse files
authored
Fixes incorrect strategy init with HPUAccelerator (#19615)
1 parent 97a95ed commit 1439da4

File tree

1 file changed

+16
-12
lines changed

1 file changed

+16
-12
lines changed

src/lightning/pytorch/trainer/connectors/accelerator_connector.py

+16-12
Original file line numberDiff line numberDiff line change
@@ -408,21 +408,25 @@ def _choose_and_init_cluster_environment(self) -> ClusterEnvironment:
408408
return LightningEnvironment()
409409

410410
def _choose_strategy(self) -> Union[Strategy, str]:
411-
if self._accelerator_flag == "hpu":
412-
if not _habana_available_and_importable():
413-
raise ImportError(
414-
"You have asked for HPU but you miss install related integration."
415-
" Please run `pip install lightning-habana` or see for further instructions"
416-
" in https://github.com/Lightning-AI/lightning-Habana/."
417-
)
418-
if self._parallel_devices and len(self._parallel_devices) > 1:
419-
from lightning_habana import HPUParallelStrategy
411+
if _habana_available_and_importable():
412+
from lightning_habana import HPUAccelerator
420413

421-
return HPUParallelStrategy.strategy_name
414+
if self._accelerator_flag == "hpu" or isinstance(self._accelerator_flag, HPUAccelerator):
415+
if self._parallel_devices and len(self._parallel_devices) > 1:
416+
from lightning_habana import HPUParallelStrategy
422417

423-
from lightning_habana import SingleHPUStrategy
418+
return HPUParallelStrategy.strategy_name
419+
420+
from lightning_habana import SingleHPUStrategy
421+
422+
return SingleHPUStrategy(device=torch.device("hpu"))
423+
if self._accelerator_flag == "hpu" and not _habana_available_and_importable():
424+
raise ImportError(
425+
"You asked to run with HPU but you are missing a required dependency."
426+
" Please run `pip install lightning-habana` or seek further instructions"
427+
" in https://github.com/Lightning-AI/lightning-Habana/."
428+
)
424429

425-
return SingleHPUStrategy(device=torch.device("hpu"))
426430
if self._accelerator_flag == "tpu" or isinstance(self._accelerator_flag, XLAAccelerator):
427431
if self._parallel_devices and len(self._parallel_devices) > 1:
428432
return XLAStrategy.strategy_name

0 commit comments

Comments
 (0)