🚀 Feature
- Add a `--device` arg to determine the backend, and change the hard-coded CUDA logic to use the selected accelerator.
- Add Ascend NPU support by introducing the `torch-npu` adapter (see the sketch after this list).
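For reference, the adapter path would look roughly like the minimal sketch below: importing `torch_npu` is what registers the `torch.npu` module, and the helper name `pick_default_device` is made up for illustration only.

```python
import torch

try:
    # Importing torch_npu registers the Ascend backend and the torch.npu module.
    import torch_npu  # noqa: F401
    _HAS_NPU = torch.npu.is_available()
except ImportError:
    _HAS_NPU = False


def pick_default_device() -> str:
    """Illustrative helper: choose a sensible default for --device."""
    if torch.cuda.is_available():
        return "cuda"
    if _HAS_NPU:
        return "npu"
    return "cpu"
```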
Motivation & Examples
Currently, PyTorch supports many accelerators besides NVIDIA GPUs, e.g., XLA devices (like TPUs), XPU, MPS, and Ascend NPU. Adding a `--device` argument that lets users specify the accelerator they would like to use would be helpful, for example as sketched below. If this is acceptable to the community, I would like to do this work.
Building on the `--device` arg, I would also like to add Ascend NPU backend support to detectron2.
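As a sketch of the intended interface only (not the final implementation), the flag could be layered on top of detectron2's existing `default_argument_parser`:

```python
from detectron2.engine import default_argument_parser


def argument_parser_with_device():
    """Sketch: extend detectron2's default parser with a --device flag."""
    parser = default_argument_parser()
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="device string passed to torch.device, e.g. 'cuda', 'npu', 'xpu', 'mps', or 'cpu'",
    )
    return parser


# args = argument_parser_with_device().parse_args()
```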
A tiny example
The modification of the `_distributed_worker` function:
```diff
 def _distributed_worker(
     local_rank,
     main_func,
     world_size,
-    num_gpus_per_machine,
+    num_accelerators_per_machine,
     machine_rank,
     dist_url,
     args,
     timeout=DEFAULT_TIMEOUT,
 ):
-    has_gpu = torch.cuda.is_available()
-    if has_gpu:
-        assert num_gpus_per_machine <= torch.cuda.device_count()
-    global_rank = machine_rank * num_gpus_per_machine + local_rank
+    # Pick the distributed backend that matches the requested accelerator.
+    device = args[0].device
+    dist_backend = "gloo"
+    if "cuda" in device:
+        if torch.cuda.is_available():
+            assert num_accelerators_per_machine <= torch.cuda.device_count()
+            dist_backend = "nccl"
+    elif "npu" in device:
+        if torch.npu.is_available():
+            assert num_accelerators_per_machine <= torch.npu.device_count()
+            dist_backend = "hccl"
+    global_rank = machine_rank * num_accelerators_per_machine + local_rank
     try:
         dist.init_process_group(
-            backend="NCCL" if has_gpu else "GLOO",
+            backend=dist_backend,
             init_method=dist_url,
             world_size=world_size,
             rank=global_rank,
             timeout=timeout,
         )
     except Exception as e:
         logger = logging.getLogger(__name__)
         logger.error("Process group URL: {}".format(dist_url))
         raise e
     # Setup the local process group.
-    comm.create_local_process_group(num_gpus_per_machine)
-    if has_gpu:
+    comm.create_local_process_group(num_accelerators_per_machine)
+    if torch.cuda.is_available():
         torch.cuda.set_device(local_rank)
+    elif "npu" in device and torch.npu.is_available():
+        torch.npu.set_device(local_rank)
     # synchronize is needed here to prevent a possible timeout after calling init_process_group
     # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
     comm.synchronize()
     main_func(*args)
```
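Downstream of this diff, a training script could forward the flag both to the config and to `launch`. A rough usage sketch, reusing the hypothetical `argument_parser_with_device` from above and assuming `launch` keeps its current signature apart from the renamed per-machine count; note that `_distributed_worker` reads `args[0].device`, which corresponds to `args.device` here:

```python
from detectron2.config import get_cfg
from detectron2.engine import launch


def main(args):
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    # detectron2 already places the model on cfg.MODEL.DEVICE, so the flag can
    # simply override it ("cuda", "npu", "cpu", ...).
    cfg.MODEL.DEVICE = args.device
    cfg.freeze()
    # ... build the trainer and run training as usual ...


if __name__ == "__main__":
    args = argument_parser_with_device().parse_args()  # hypothetical parser from the sketch above
    launch(
        main,
        args.num_gpus,  # would be renamed to num_accelerators_per_machine under this proposal
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )
```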
Related Info
- `torch.device`: https://pytorch.org/docs/stable/tensor_attributes.html#torch-device
- torch PrivateUse1 (registering a new backend module to PyTorch): https://pytorch.org/tutorials/advanced/privateuseone.html
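Not required for this proposal (the maintained `torch-npu` package already handles this), but for context, the Python-side portion of the PrivateUse1 flow from the tutorial above looks roughly like the sketch below; the accompanying C++ kernel registration is omitted:

```python
import torch

# Rename the reserved PrivateUse1 dispatch key so tensors report device type "npu".
torch.utils.rename_privateuse1_backend("npu")

# Generate convenience methods such as Tensor.npu() / Tensor.is_npu for the
# renamed backend; the actual kernels still have to be registered from C++.
torch.utils.generate_methods_for_privateuse1_backend()
```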