-
Notifications
You must be signed in to change notification settings - Fork 7k
/
Copy pathnms.py
98 lines (80 loc) · 4.98 KB
/
nms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import triton
import triton.language as tl
@triton.jit
def _combine_bits(val0, val1):
    # Combiner for tl.reduce: bitwise-OR two packed int32 mask words.
    # Used by triton_nms_IoU_kernel to fold 32 per-box suppression bits
    # (one bit per column box) into a single int32 word.
    tl.static_assert(val0.dtype == tl.int32, "input must be int32")
    tl.static_assert(val1.dtype == tl.int32, "input must be int32")
    return val0 | val1
@triton.jit
def triton_nms_IoU_kernel(boxes, output_ptr, threshold, num_boxes, stride_i, stride_j, BLOCK_SIZE: tl.constexpr):
    """
    This nms_kernel computes the suppressed mask of boxes [i, j].
    mask[i, j]==1 means if we choose box i, the box j will be suppressed.
    The output is a mask of size [num_boxes, num_boxes//32], where each item is
    one packed word of 32 suppression bits (written out as int64; see note below).

    Args:
        boxes (tl.tensor): A tensor containing the bounding boxes with shape (num_boxes, 4).
            The four per-box values are loaded at offsets 0..3 and used as
            (x1, y1, x2, y2) corner coordinates.
        output_ptr (tl.pointer): A pointer to the output tensor where the mask will be stored.
        threshold (float): The IoU threshold for suppressing boxes; a pair is
            suppressed when IoU is strictly greater than this value.
        num_boxes (int): The total number of boxes.
        stride_i (int): The stride of the output tensor along the first dimension.
        stride_j (int): The stride of the output tensor along the second dimension.
        BLOCK_SIZE (tl.constexpr): The block size for the Triton kernel.

    Returns:
        Tensor with size [num_boxes, num_boxes//32]. It indicates that if `box i` is
        chosen, whether box `j` could be chosen. The value `1` means it cannot be chosen.
    """
    # The Triton kernel is a 2D block kernel. The block size is BLOCK_SIZE x BLOCK_SIZE.
    # Each kernel will compute the IoU of boxes[row: row + BLOCK_SIZE, col: col + BLOCK_SIZE]
    row_block_pid = tl.program_id(axis=0)
    col_block_pid = tl.program_id(axis=1)
    row_block_start = row_block_pid * BLOCK_SIZE
    col_block_start = col_block_pid * BLOCK_SIZE
    row_block_offsets = row_block_start + tl.arange(0, BLOCK_SIZE)
    col_block_offsets = col_block_start + tl.arange(0, BLOCK_SIZE)
    # Tail blocks: mask off lanes past the end of the box list so the loads
    # below do not read out of bounds.
    row_block_mask = row_block_offsets < num_boxes
    col_block_mask = col_block_offsets < num_boxes
    # Since Triton does not support tensor slicing yet, we need to load point elements individually
    # Every row_block is loaded as a 1 dim tensor of size [BLOCK_SIZE]
    # We then expand 1 dim for row. So that the row block dim would be [BLOCK_SIZE, 1]
    row_block_x1 = tl.load(boxes + row_block_offsets * 4 + 0, mask=row_block_mask)[:, None]
    row_block_y1 = tl.load(boxes + row_block_offsets * 4 + 1, mask=row_block_mask)[:, None]
    row_block_x2 = tl.load(boxes + row_block_offsets * 4 + 2, mask=row_block_mask)[:, None]
    row_block_y2 = tl.load(boxes + row_block_offsets * 4 + 3, mask=row_block_mask)[:, None]
    # Expand 1 dim for col. So that the col block dim would be [1, BLOCK_SIZE]
    col_block_x1 = tl.load(boxes + col_block_offsets * 4 + 0, mask=col_block_mask)[None, :]
    col_block_y1 = tl.load(boxes + col_block_offsets * 4 + 1, mask=col_block_mask)[None, :]
    col_block_x2 = tl.load(boxes + col_block_offsets * 4 + 2, mask=col_block_mask)[None, :]
    col_block_y2 = tl.load(boxes + col_block_offsets * 4 + 3, mask=col_block_mask)[None, :]
    # Together, the minimum / maximum will broadcast and form into a [BLOCK_SIZE, BLOCK_SIZE] matrix
    # holding the intersection rectangle of every (row box, col box) pair.
    left = tl.maximum(row_block_x1, col_block_x1)
    right = tl.minimum(row_block_x2, col_block_x2)
    top = tl.maximum(row_block_y1, col_block_y1)
    bottom = tl.minimum(row_block_y2, col_block_y2)
    # Clamp at zero: non-overlapping pairs produce negative extents.
    width = tl.maximum(right - left, 0)
    height = tl.maximum(bottom - top, 0)
    intersection = width * height
    area_a = (row_block_x2 - row_block_x1) * (row_block_y2 - row_block_y1)
    area_b = (col_block_x2 - col_block_x1) * (col_block_y2 - col_block_y1)
    union = area_a + area_b - intersection
    # 1 where the pair overlaps more than `threshold` (strict >), else 0.
    iou_keep_out_bit_mask = ((intersection / union) > threshold).to(tl.int32)
    # Pack each run of 32 column bits into one int32 word. tl.flip reverses the
    # per-word shift amounts, so column j's bit lands at position (31 - j % 32):
    # the lowest column index within a word occupies the most-significant bit.
    shift_offsets = tl.arange(0, BLOCK_SIZE) % 32
    shift_offsets = tl.flip(shift_offsets, 0)[None, :]
    shift_offsets = tl.broadcast_to(shift_offsets.to(tl.int32), [BLOCK_SIZE, BLOCK_SIZE])
    iou_keep_out_bit_mask = iou_keep_out_bit_mask << shift_offsets
    # The process of combine bits. Note that the Triton seems having problem when the dtype is int64.
    # Thus choosing 32 bits as the mask. And convert it to int64 at the end to avoid further potential overflow.
    iou_keep_out_bit_mask = tl.reshape(iou_keep_out_bit_mask, (BLOCK_SIZE, (BLOCK_SIZE + 32 - 1) // 32, 32))
    # OR-reduce the 32 shifted bits of each group into a single packed word.
    iou_keep_out_combined = tl.reduce(iou_keep_out_bit_mask, axis=2, combine_fn=_combine_bits)
    iou_keep_out_combined = iou_keep_out_combined.to(tl.int64)
    # The bits are combined along the col, thus we need to change the col block offsets
    # For the row offset, it will remain the same.
    combined_col_blk_offsets = col_block_pid * ((BLOCK_SIZE + 31) // 32)
    output_block_ptr = tl.make_block_ptr(
        output_ptr,
        shape=(num_boxes, (num_boxes + 32 - 1) // 32),
        strides=(stride_i, stride_j),
        offsets=(row_block_start, combined_col_blk_offsets),
        block_shape=(BLOCK_SIZE, (BLOCK_SIZE + 32 - 1) // 32),
        # NOTE(review): order=(0, 1) declares axis 0 as the fastest-varying
        # dimension; for a row-major (stride_i, stride_j) output one would
        # usually expect order=(1, 0) — confirm against the caller's layout.
        order=(0, 1),
    )
    # boundary_check clips the store on both axes for the ragged tail blocks.
    tl.store(output_block_ptr, iou_keep_out_combined, boundary_check=(0, 1))