Skip to content

Commit 7f277f6

Browse files
committed
Fix runtime issue
1 parent 607c839 commit 7f277f6

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

torchvision/ops/triton/nms.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ def _combine_bits(val0, val1):
99
return val0 | val1
1010

1111

12+
@triton.jit
1213
def triton_nms_IoU_kernel(boxes, output_ptr, threshold, num_boxes, stride_i, stride_j, BLOCK_SIZE: tl.constexpr):
1314
"""
1415
This nms_kernel computes the suppressed mask of boxes [i, j].
@@ -76,13 +77,16 @@ def triton_nms_IoU_kernel(boxes, output_ptr, threshold, num_boxes, stride_i, str
7677

7778
iou_keep_out_bit_mask = tl.reshape(iou_keep_out_bit_mask, (BLOCK_SIZE, (BLOCK_SIZE + 32 - 1) // 32, 32))
7879
iou_keep_out_combined = tl.reduce(iou_keep_out_bit_mask, axis=2, combine_fn=_combine_bits)
79-
8080
iou_keep_out_combined = iou_keep_out_combined.to(tl.int64)
81+
82+
# The bits are packed 32 per word along the column axis, so the column block offset must be expressed in packed words rather than in boxes.
83+
# The row offset is unaffected by the packing and stays the same.
84+
combined_col_blk_offsets = col_block_pid * ((BLOCK_SIZE + 31) // 32)
8185
output_block_ptr = tl.make_block_ptr(
8286
output_ptr,
8387
shape=(num_boxes, (num_boxes + 32 - 1) // 32),
8488
strides=(stride_i, stride_j),
85-
offsets=(row_block_start, 0),
89+
offsets=(row_block_start, combined_col_blk_offsets),
8690
block_shape=(BLOCK_SIZE, (BLOCK_SIZE + 32 - 1) // 32),
8791
order=(0, 1),
8892
)

0 commit comments

Comments
 (0)