@@ -408,21 +408,25 @@ def _choose_and_init_cluster_environment(self) -> ClusterEnvironment:
408
408
return LightningEnvironment ()
409
409
410
410
def _choose_strategy (self ) -> Union [Strategy , str ]:
411
- if self ._accelerator_flag == "hpu" :
412
- if not _habana_available_and_importable ():
413
- raise ImportError (
414
- "You have asked for HPU but you miss install related integration."
415
- " Please run `pip install lightning-habana` or see for further instructions"
416
- " in https://github.com/Lightning-AI/lightning-Habana/."
417
- )
418
- if self ._parallel_devices and len (self ._parallel_devices ) > 1 :
419
- from lightning_habana import HPUParallelStrategy
411
+ if _habana_available_and_importable ():
412
+ from lightning_habana import HPUAccelerator
420
413
421
- return HPUParallelStrategy .strategy_name
414
+ if self ._accelerator_flag == "hpu" or isinstance (self ._accelerator_flag , HPUAccelerator ):
415
+ if self ._parallel_devices and len (self ._parallel_devices ) > 1 :
416
+ from lightning_habana import HPUParallelStrategy
422
417
423
- from lightning_habana import SingleHPUStrategy
418
+ return HPUParallelStrategy .strategy_name
419
+
420
+ from lightning_habana import SingleHPUStrategy
421
+
422
+ return SingleHPUStrategy (device = torch .device ("hpu" ))
423
+ if self ._accelerator_flag == "hpu" and not _habana_available_and_importable ():
424
+ raise ImportError (
425
+ "You asked to run with HPU but you are missing a required dependency."
426
+ " Please run `pip install lightning-habana` or seek further instructions"
427
+ " in https://github.com/Lightning-AI/lightning-Habana/."
428
+ )
424
429
425
- return SingleHPUStrategy (device = torch .device ("hpu" ))
426
430
if self ._accelerator_flag == "tpu" or isinstance (self ._accelerator_flag , XLAAccelerator ):
427
431
if self ._parallel_devices and len (self ._parallel_devices ) > 1 :
428
432
return XLAStrategy .strategy_name
0 commit comments