Add device argument for multi-backends access & Ascend NPU support #5285

Open
@MengqingCao

Description

🚀 Feature

  1. Add a --device argument to let users select the backend, and replace the CUDA-specific hard-coded logic with device-agnostic accelerator handling (a sketch of the flag follows this list).
  2. Add Ascend NPU support by introducing the torch-npu adapter.
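
As a rough illustration of the first point, the flag could be wired into the argument parser along these lines; the flag name, default, and set of choices here are assumptions for discussion, not a settled API:

```python
import argparse

def add_device_arg(parser: argparse.ArgumentParser) -> None:
    # Hypothetical sketch: let users pick the accelerator backend.
    # The actual choices would depend on which backends end up supported.
    parser.add_argument(
        "--device",
        default="cuda",
        choices=["cuda", "npu", "xpu", "mps", "cpu"],
        help="device type to run on (default: cuda)",
    )

parser = argparse.ArgumentParser()
add_device_arg(parser)
args = parser.parse_args(["--device", "npu"])
print(args.device)  # -> "npu"
```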

Motivation & Examples

Currently, PyTorch supports many accelerators besides NVIDIA GPUs, e.g., XLA devices (such as TPUs), XPU, MPS, and Ascend NPU. Adding a --device argument that lets users specify the accelerator they would like to use would be helpful. If this is acceptable to the community, I would like to do this work.

Moreover, building on the --device argument, I would like to add support for the Ascend NPU backend in detectron2.
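
For context, once the torch-npu adapter is imported it exposes NPU devices through the usual torch API; a minimal sketch, assuming torch and torch-npu are installed on an Ascend machine:

```python
import torch
import torch_npu  # importing the adapter registers the "npu" device with PyTorch

if torch.npu.is_available():
    torch.npu.set_device(0)
    # Tensors can then target the NPU just like CUDA devices.
    x = torch.randn(2, 3, device="npu:0")
    print(x.device)  # -> npu:0
```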

A tiny example
The proposed modification to the `_distributed_worker` function:

```diff
 def _distributed_worker(
     local_rank,
     main_func,
     world_size,
-    num_gpus_per_machine,
+    num_accelerators_per_machine,
     machine_rank,
     dist_url,
     args,
     timeout=DEFAULT_TIMEOUT,
 ):
-    has_gpu = torch.cuda.is_available()
-    if has_gpu:
-        assert num_gpus_per_machine <= torch.cuda.device_count()
-    global_rank = machine_rank * num_gpus_per_machine + local_rank
+    device = args[0].device
+    dist_backend = "gloo"  # CPU fallback
+    if "cuda" in device:
+        if torch.cuda.is_available():
+            assert num_accelerators_per_machine <= torch.cuda.device_count()
+            dist_backend = "nccl"
+    elif "npu" in device:
+        if torch.npu.is_available():
+            assert num_accelerators_per_machine <= torch.npu.device_count()
+            dist_backend = "hccl"
+    global_rank = machine_rank * num_accelerators_per_machine + local_rank
     try:
         dist.init_process_group(
-            backend="NCCL" if has_gpu else "GLOO",
+            backend=dist_backend,
             init_method=dist_url,
             world_size=world_size,
             rank=global_rank,
             timeout=timeout,
         )
     except Exception as e:
         logger = logging.getLogger(__name__)
         logger.error("Process group URL: {}".format(dist_url))
         raise e

     # Setup the local process group.
-    comm.create_local_process_group(num_gpus_per_machine)
-    if has_gpu:
+    comm.create_local_process_group(num_accelerators_per_machine)
+    if torch.cuda.is_available():
         torch.cuda.set_device(local_rank)
+    elif torch.npu.is_available():
+        torch.npu.set_device(local_rank)

     # synchronize is needed here to prevent a possible timeout after calling init_process_group
     # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
     comm.synchronize()

     main_func(*args)
```
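
The backend-selection logic above could also be factored into a small helper so that other entry points stay device-agnostic; a sketch with a hypothetical helper name (get_dist_backend is not existing detectron2 API):

```python
import torch

# Hypothetical mapping from device type to the matching torch.distributed
# backend and its availability check; "gloo" is the CPU fallback.
_BACKENDS = {
    "cuda": ("nccl", lambda: torch.cuda.is_available()),
    "npu": ("hccl", lambda: hasattr(torch, "npu") and torch.npu.is_available()),
}

def get_dist_backend(device: str) -> str:
    for dev_type, (backend, available) in _BACKENDS.items():
        if dev_type in device and available():
            return backend
    return "gloo"

print(get_dist_backend("cuda:0"))  # "nccl" on a CUDA machine, otherwise "gloo"
```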

Related Info

torch.device : https://pytorch.org/docs/stable/tensor_attributes.html#torch-device

torch PrivateUse1 (registering a new backend module to PyTorch): https://pytorch.org/tutorials/advanced/privateuseone.html
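
For reference, the PrivateUse1 mechanism lets an out-of-tree backend expose itself under a custom name; a minimal sketch, assuming PyTorch >= 2.0, with a hypothetical "foo" backend standing in for a real adapter such as torch-npu:

```python
import torch

class _FooModule:
    # Placeholder runtime module; a real adapter would back this
    # with an actual device runtime.
    @staticmethod
    def is_available() -> bool:
        return False

# Rename the PrivateUse1 dispatch key to "foo", then expose torch.foo.
torch.utils.rename_privateuse1_backend("foo")
torch._register_device_module("foo", _FooModule)

print(torch.foo.is_available())  # False until a real runtime backs it
```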
