diff --git a/test/common_utils.py b/test/common_utils.py
index 99c7931587d..a0dac03ab0f 100644
--- a/test/common_utils.py
+++ b/test/common_utils.py
@@ -29,6 +29,7 @@
 IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"
 CUDA_NOT_AVAILABLE_MSG = "CUDA device not available"
 MPS_NOT_AVAILABLE_MSG = "MPS device not available"
+XPU_NOT_AVAILABLE_MSG = "XPU device not available"  # skip reason used by conftest for needs_xpu tests
 OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda."
 
 
@@ -141,6 +142,12 @@ def needs_mps(test_func):
     return pytest.mark.needs_mps(test_func)
 
 
+def needs_xpu(test_func):  # decorator: tags test_func with the needs_xpu marker
+    import pytest  # noqa
+
+    return pytest.mark.needs_xpu(test_func)
+
+
 def _create_data(height=3, width=3, channels=3, device="cpu"):
     # TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture
     tensor = torch.randint(0, 256, (channels, height, width), dtype=torch.uint8, device=device)
diff --git a/test/conftest.py b/test/conftest.py
index a9768598ded..984cba981b9 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -11,6 +11,7 @@
     IN_RE_WORKER,
     MPS_NOT_AVAILABLE_MSG,
     OSS_CI_GPU_NO_CUDA_MSG,
+    XPU_NOT_AVAILABLE_MSG,
 )
 
 
@@ -18,6 +19,7 @@ def pytest_configure(config):
     # register an additional marker (see pytest_collection_modifyitems)
     config.addinivalue_line("markers", "needs_cuda: mark for tests that rely on a CUDA device")
     config.addinivalue_line("markers", "needs_mps: mark for tests that rely on a MPS device")
+    config.addinivalue_line("markers", "needs_xpu: mark for tests that rely on a XPU device")
     config.addinivalue_line("markers", "dont_collect: mark for tests that should not be collected")
     config.addinivalue_line("markers", "opcheck_only_one: only opcheck one parametrization")
 
@@ -43,12 +45,18 @@ def pytest_collection_modifyitems(items):
         # and the ones with device == 'cpu' won't have the mark.
needs_cuda = item.get_closest_marker("needs_cuda") is not None needs_mps = item.get_closest_marker("needs_mps") is not None + needs_xpu = item.get_closest_marker("needs_xpu") is not None if needs_cuda and not torch.cuda.is_available(): # In general, we skip cuda tests on machines without a GPU # There are special cases though, see below item.add_marker(pytest.mark.skip(reason=CUDA_NOT_AVAILABLE_MSG)) + if needs_xpu and not torch.xpu.is_available(): + # In general, we skip xpu tests on machines without a GPU + # There are special cases though, see below + item.add_marker(pytest.mark.skip(reason=XPU_NOT_AVAILABLE_MSG)) + if needs_mps and not torch.backends.mps.is_available(): item.add_marker(pytest.mark.skip(reason=MPS_NOT_AVAILABLE_MSG)) diff --git a/test/test_ops.py b/test/test_ops.py index 1ba7a2c9efa..4519ed967a6 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -831,6 +831,7 @@ def test_qnms(self, iou, scale, zero_point): ( pytest.param("cuda", marks=pytest.mark.needs_cuda), pytest.param("mps", marks=pytest.mark.needs_mps), + pytest.param("xpu", marks=pytest.mark.needs_xpu), ), ) @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) diff --git a/torchvision/csrc/ops/cpu/nms_kernel.cpp b/torchvision/csrc/ops/cpu/nms_kernel.cpp index 50479066cbd..6bb03a355d4 100644 --- a/torchvision/csrc/ops/cpu/nms_kernel.cpp +++ b/torchvision/csrc/ops/cpu/nms_kernel.cpp @@ -107,10 +107,60 @@ at::Tensor nms_kernel( return result; } + +/** + * @brief Post-processes the results of the Non-Maximum Suppression (NMS) algorithm. + * + * This function iterates over the boxes and determines which ones to keep based on the IOU (Intersection Over Union) keep-out mask. + * It uses a 32-bitmask to efficiently track and suppress overlapping boxes. + * + * @param order A tensor containing the order of the boxes. + * @param iou_keep_out_mask A tensor containing the IOU keep-out mask. This mask has the shape (N, N//32), where N is the number of boxes. + * The datatype MUST be int32. 
+ * @param num_boxes The total number of boxes. + * @return A tensor containing the indices of the boxes to keep. + */ + +at::Tensor nms_kernel_postprocess( + const at::Tensor& order, + const at::Tensor& iou_keep_out_mask, + const int64_t num_boxes) { + // Calculate the number of 32-bit blocks needed to cover all boxes + const int col_blocks = (num_boxes + 32 - 1) / 32; + std::vector remove_box(col_blocks); + std::memset(&remove_box[0], 0, sizeof(unsigned long) * col_blocks); + + + at::Tensor keep = at::empty({num_boxes}, order.options().dtype(at::kLong).device(at::kCPU)); + int64_t * keep_data_ptr = keep.data_ptr(); + + unsigned long long* iou_keep_out_mask_data_ptr = (unsigned long long*)iou_keep_out_mask.data_ptr(); + int num_to_keep = 0; + // Note that the iou_keep_out_mask has the shape of (N, N//32) + // The following function iterate over each box to check if it should be kept + for (int64_t i = 0; i < num_boxes; i++) { + int nblock = i / 32; + // This is equivalent to module: 31 - i % 32 + int inblock = (31 - i) & (32 -1); + + if (!(remove_box[nblock] & (1UL << inblock))){ + keep_data_ptr[num_to_keep++]=i; + unsigned long long*p = iou_keep_out_mask_data_ptr + i*col_blocks; + for (int j = nblock; j < col_blocks; j++){ + remove_box[j] |= p[j]; + } + } + } + return order.index({keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep)}); +} + + + } // namespace TORCH_LIBRARY_IMPL(torchvision, CPU, m) { m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel)); + m.impl(TORCH_SELECTIVE_NAME("torchvision::nms_kernel_postprocess"), TORCH_FN(nms_kernel_postprocess)); } } // namespace ops diff --git a/torchvision/csrc/ops/nms.cpp b/torchvision/csrc/ops/nms.cpp index 5ecf8812f1b..f1eb9c0ee0f 100644 --- a/torchvision/csrc/ops/nms.cpp +++ b/torchvision/csrc/ops/nms.cpp @@ -22,6 +22,8 @@ TORCH_LIBRARY_FRAGMENT(torchvision, m) { m.set_python_module("torchvision._meta_registrations"); m.def(TORCH_SELECTIVE_SCHEMA( "torchvision::nms(Tensor dets, Tensor 
scores, float iou_threshold) -> Tensor"));
+  m.def(TORCH_SELECTIVE_SCHEMA(
+      "torchvision::nms_kernel_postprocess(Tensor order, Tensor iou_keep_out_mask, int num_boxes) -> Tensor"));
 }
 
 } // namespace ops
diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py
index 96631278d48..5d8356a5cae 100644
--- a/torchvision/ops/boxes.py
+++ b/torchvision/ops/boxes.py
@@ -138,6 +138,21 @@ def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor:
     return keep
 
 
+def _nms_kernel_postprocess(order: Tensor, iou_keep_out_mask: Tensor, num_boxes: int) -> Tensor:
+    """
+    Post-processes the results of the non-maximum suppression (NMS) kernel.
+    Args:
+        order (Tensor): A tensor containing the order of the boxes.
+        iou_keep_out_mask (Tensor): A tensor containing the mask of boxes to suppress based on IoU.
+            The datatype is int64 (each item carries the mask bits of 32 boxes).
+        num_boxes (int): The number of boxes.
+    Returns:
+        Tensor: A tensor containing the indices of the boxes kept by NMS.
+    """
+
+    return torch.ops.torchvision.nms_kernel_postprocess(order, iou_keep_out_mask, num_boxes)
+
+
 def clip_boxes_to_image(boxes: Tensor, size: Tuple[int, int]) -> Tensor:
     """
     Clip boxes so that they lie inside an image of size ``size``.
diff --git a/torchvision/ops/triton/__init__.py b/torchvision/ops/triton/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/torchvision/ops/triton/nms.py b/torchvision/ops/triton/nms.py
new file mode 100644
index 00000000000..545c850e6d0
--- /dev/null
+++ b/torchvision/ops/triton/nms.py
@@ -0,0 +1,99 @@
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _combine_bits(val0, val1):
+    tl.static_assert(val0.dtype == tl.int32, "input must be int32")
+    tl.static_assert(val1.dtype == tl.int32, "input must be int32")
+    return val0 | val1
+
+
+@triton.jit
+def triton_nms_IoU_kernel(boxes, output_ptr, threshold, num_boxes, stride_i, stride_j, BLOCK_SIZE: tl.constexpr):
+    """
+    This nms_kernel computes the suppressed mask of boxes [i, j].
+    mask[i, j]==1 means if we choose box i, the box j will be suppressed.
+    The output is a mask of size [num_boxes, ceil(num_boxes / 32)]. Each item packs the mask bits of
+    32 boxes and is stored as int64; only its low 32 bits are meaningful.
+
+    Args:
+        boxes (tl.tensor): A contiguous tensor containing the bounding boxes with shape (num_boxes, 4).
+        output_ptr (tl.pointer): A pointer to the output tensor where the mask will be stored.
+        threshold (float): The IoU threshold for suppressing boxes.
+        num_boxes (int): The total number of boxes.
+        stride_i (int): The stride of the output tensor along the first dimension.
+        stride_j (int): The stride of the output tensor along the second dimension.
+        BLOCK_SIZE (tl.constexpr): The block size for the Triton kernel. It must be a multiple of 32.
+    Returns:
+        None. The mask is written through ``output_ptr``: bit ``31 - j % 32`` of item ``[i, j // 32]``
+        is ``1`` when box ``j`` cannot be chosen once box ``i`` is chosen.
+    """
+
+    # The Triton kernel is a 2D block kernel. The block size is BLOCK_SIZE x BLOCK_SIZE.
+    # Each kernel will compute the IoU of boxes[row: row + BLOCK_SIZE, col: col + BLOCK_SIZE]
+    row_block_pid = tl.program_id(axis=0)
+    col_block_pid = tl.program_id(axis=1)
+
+    row_block_start = row_block_pid * BLOCK_SIZE
+    col_block_start = col_block_pid * BLOCK_SIZE
+
+    row_block_offsets = row_block_start + tl.arange(0, BLOCK_SIZE)
+    col_block_offsets = col_block_start + tl.arange(0, BLOCK_SIZE)
+
+    row_block_mask = row_block_offsets < num_boxes
+    col_block_mask = col_block_offsets < num_boxes
+
+    # Since Triton does not support tensor slicing yet, we need to load point elements individually
+    # Every row_block is loaded as a 1 dim tensor of size [BLOCK_SIZE]
+    # We then expand 1 dim for row. So that the row block dim would be [BLOCK_SIZE, 1]
+    row_block_x1 = tl.load(boxes + row_block_offsets * 4 + 0, mask=row_block_mask)[:, None]
+    row_block_y1 = tl.load(boxes + row_block_offsets * 4 + 1, mask=row_block_mask)[:, None]
+    row_block_x2 = tl.load(boxes + row_block_offsets * 4 + 2, mask=row_block_mask)[:, None]
+    row_block_y2 = tl.load(boxes + row_block_offsets * 4 + 3, mask=row_block_mask)[:, None]
+
+    # Expand 1 dim for col. So that the col block dim would be [1, BLOCK_SIZE]
+    col_block_x1 = tl.load(boxes + col_block_offsets * 4 + 0, mask=col_block_mask)[None, :]
+    col_block_y1 = tl.load(boxes + col_block_offsets * 4 + 1, mask=col_block_mask)[None, :]
+    col_block_x2 = tl.load(boxes + col_block_offsets * 4 + 2, mask=col_block_mask)[None, :]
+    col_block_y2 = tl.load(boxes + col_block_offsets * 4 + 3, mask=col_block_mask)[None, :]
+
+    # Together, the minimum / maximum will broadcast and form into a [BLOCK_SIZE, BLOCK_SIZE] matrix
+    left = tl.maximum(row_block_x1, col_block_x1)
+    right = tl.minimum(row_block_x2, col_block_x2)
+    top = tl.maximum(row_block_y1, col_block_y1)
+    bottom = tl.minimum(row_block_y2, col_block_y2)
+
+    width = tl.maximum(right - left, 0)
+    height = tl.maximum(bottom - top, 0)
+
+    intersection = width * height
+    area_a = (row_block_x2 - row_block_x1) * (row_block_y2 - row_block_y1)
+    area_b = (col_block_x2 - col_block_x1) * (col_block_y2 - col_block_y1)
+    union = area_a + area_b - intersection
+
+    iou_keep_out_bit_mask = ((intersection / union) > threshold).to(tl.int32)
+
+    shift_offsets = tl.arange(0, BLOCK_SIZE) % 32
+    shift_offsets = tl.flip(shift_offsets, 0)[None, :]
+    shift_offsets = tl.broadcast_to(shift_offsets.to(tl.int32), [BLOCK_SIZE, BLOCK_SIZE])
+    iou_keep_out_bit_mask = iou_keep_out_bit_mask << shift_offsets
+
+    # The process of combining bits. Note that Triton seems to have problems when the dtype is int64.
+    # Thus choosing 32 bits as the mask. And convert it to int64 at the end to avoid further potential overflow.
+    iou_keep_out_bit_mask = tl.reshape(iou_keep_out_bit_mask, (BLOCK_SIZE, (BLOCK_SIZE + 32 - 1) // 32, 32))
+    iou_keep_out_combined = tl.reduce(iou_keep_out_bit_mask, axis=2, combine_fn=_combine_bits)
+    iou_keep_out_combined = iou_keep_out_combined.to(tl.int64)
+
+    # The bits are combined along the col, thus we need to change the col block offsets
+    # For the row offset, it will remain the same.
+    combined_col_blk_offsets = col_block_pid * ((BLOCK_SIZE + 31) // 32)
+    output_block_ptr = tl.make_block_ptr(
+        output_ptr,
+        shape=(num_boxes, (num_boxes + 32 - 1) // 32),
+        strides=(stride_i, stride_j),
+        offsets=(row_block_start, combined_col_blk_offsets),
+        block_shape=(BLOCK_SIZE, (BLOCK_SIZE + 32 - 1) // 32),
+        order=(0, 1),
+    )
+    tl.store(output_block_ptr, iou_keep_out_combined, boundary_check=(0, 1))
diff --git a/torchvision/ops/xpu/__init__.py b/torchvision/ops/xpu/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/torchvision/ops/xpu/nms.py b/torchvision/ops/xpu/nms.py
new file mode 100644
index 00000000000..7b43e0f8b59
--- /dev/null
+++ b/torchvision/ops/xpu/nms.py
@@ -0,0 +1,65 @@
+import torch
+import triton
+from torchvision.ops.boxes import _nms_kernel_postprocess
+
+from torchvision.ops.triton.nms import triton_nms_IoU_kernel
+
+
+@torch.library.register_kernel("torchvision::nms", "xpu")
+def xpu_triton_nms(boxes: torch.Tensor, scores: torch.Tensor, threshold: float) -> torch.Tensor:
+    """
+    Performs non-maximum suppression (NMS) on the boxes according
+    to their intersection-over-union (IoU).
+
+    NMS iteratively removes lower scoring boxes which have an
+    IoU greater than ``threshold`` with another (higher scoring)
+    box.
+
+    If multiple boxes have the exact same score and satisfy the IoU
+    criterion with respect to a reference box, the selected box is
+    not guaranteed to be the same between CPU and GPU. This is similar
+    to the behavior of argsort in PyTorch when repeated values are present.
+
+    Args:
+        boxes (Tensor[N, 4]): boxes to perform NMS on. They
+            are expected to be in ``(x1, y1, x2, y2)`` format with ``0 <= x1 < x2`` and
+            ``0 <= y1 < y2``.
+        scores (Tensor[N]): scores for each one of the boxes
+        threshold (float): discards all overlapping boxes with IoU > threshold
+
+    Returns:
+        Tensor: int64 tensor with the indices of the elements that have been kept
+        by NMS, sorted in decreasing order of scores
+    """
+    num_boxes = boxes.shape[0]
+    if num_boxes == 0:
+        # Nothing to suppress. Returning early also avoids launching the Triton kernel on an empty grid.
+        return torch.empty((0,), dtype=torch.int64, device=boxes.device)
+
+    # Triton does not support argsort yet, thus it needs to fallback to ATen Calls
+    order = torch.argsort(scores, descending=True)
+    boxes = boxes[order]
+    iou_keep_out_mask = torch.zeros(num_boxes, (num_boxes + 32 - 1) // 32, dtype=torch.int64, device=boxes.device)
+
+    grid = lambda meta: (  # noqa: E731
+        triton.cdiv(num_boxes, meta["BLOCK_SIZE"]),
+        triton.cdiv(num_boxes, meta["BLOCK_SIZE"]),
+    )
+
+    # This triton kernel will calculate the IoU matrix for all the input boxes (iou_keep_out_mask).
+    # The iou_keep_out_mask is a bitmask matrix with 32 mask bits per item, so its shape is [N, ceil(N / 32)].
+    # Bit (i, j) being set means that box j must be suppressed when we choose box i.
+    triton_nms_IoU_kernel[grid](
+        boxes,
+        iou_keep_out_mask,
+        threshold,
+        num_boxes,
+        iou_keep_out_mask.stride(0),
+        iou_keep_out_mask.stride(1),
+        BLOCK_SIZE=64,
+        num_warps=4,
+    )
+
+    # The postprocess will calculate the final indices of the boxes that should be kept.
+    # It is a serialized process, and we choose to run it on CPU for more generalization.
+    return _nms_kernel_postprocess(order.cpu(), iou_keep_out_mask.cpu(), num_boxes).to(order.device)