@@ -9,6 +9,7 @@ def _combine_bits(val0, val1):
9
9
return val0 | val1
10
10
11
11
12
+ @triton .jit
12
13
def triton_nms_IoU_kernel (boxes , output_ptr , threshold , num_boxes , stride_i , stride_j , BLOCK_SIZE : tl .constexpr ):
13
14
"""
14
15
This nms_kernel computes the suppressed mask of boxes [i, j].
@@ -76,13 +77,16 @@ def triton_nms_IoU_kernel(boxes, output_ptr, threshold, num_boxes, stride_i, str
76
77
77
78
iou_keep_out_bit_mask = tl .reshape (iou_keep_out_bit_mask , (BLOCK_SIZE , (BLOCK_SIZE + 32 - 1 ) // 32 , 32 ))
78
79
iou_keep_out_combined = tl .reduce (iou_keep_out_bit_mask , axis = 2 , combine_fn = _combine_bits )
79
-
80
80
iou_keep_out_combined = iou_keep_out_combined .to (tl .int64 )
81
+
82
+ # The keep bits are packed 32 per word along the column axis, so the column block offset must be expressed in packed words: col_block_pid * ceil(BLOCK_SIZE / 32).
83
+ # The row offset is not affected by the packing and stays the same.
84
+ combined_col_blk_offsets = col_block_pid * ((BLOCK_SIZE + 31 ) // 32 )
81
85
output_block_ptr = tl .make_block_ptr (
82
86
output_ptr ,
83
87
shape = (num_boxes , (num_boxes + 32 - 1 ) // 32 ),
84
88
strides = (stride_i , stride_j ),
85
- offsets = (row_block_start , 0 ),
89
+ offsets = (row_block_start , combined_col_blk_offsets ),
86
90
block_shape = (BLOCK_SIZE , (BLOCK_SIZE + 32 - 1 ) // 32 ),
87
91
order = (0 , 1 ),
88
92
)
0 commit comments