-
Notifications
You must be signed in to change notification settings - Fork 7k
/
Copy pathnms.cpp
30 lines (25 loc) · 928 Bytes
/
nms.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
#include "nms.h"
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>
#include <torch/types.h>
namespace vision {
namespace ops {

/// @brief Public C++ entry point for torchvision's non-maximum suppression op.
///
/// This function contains no NMS algorithm itself: it forwards its arguments
/// through the PyTorch dispatcher to whichever kernel the dispatcher selects
/// for the `torchvision::nms` schema given these inputs.
///
/// @param dets          Box tensor. NOTE(review): presumably shape (N, 4);
///                      confirm against the registered kernels.
/// @param scores        Per-box score tensor. NOTE(review): presumably shape
///                      (N,); confirm against the registered kernels.
/// @param iou_threshold IoU threshold, bound to the schema's `float` argument.
/// @return The tensor returned by the dispatched kernel.
at::Tensor nms(
    const at::Tensor& dets,
    const at::Tensor& scores,
    double iou_threshold) {
  // Records usage of this API under the given key (once per process).
  C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.nms.nms");

  // The schema lookup runs only on the first call: the resulting typed
  // handle lives in a function-local static and is reused afterwards.
  // findSchemaOrThrow fails if "torchvision::nms" (overload name "") has not
  // been registered -- see the library fragment below.
  static auto typed_handle = []() {
    auto& dispatcher = c10::Dispatcher::singleton();
    auto handle = dispatcher.findSchemaOrThrow("torchvision::nms", "");
    return handle.typed<decltype(nms)>();
  }();

  return typed_handle.call(dets, scores, iou_threshold);
}

// Declares the schemas of torchvision ops in the `torchvision` library
// fragment. Only schemas are defined here; kernels are registered elsewhere
// (NOTE(review): no kernel registrations are visible in this file).
TORCH_LIBRARY_FRAGMENT(torchvision, lib) {
  // NOTE(review): presumably tells PyTorch where the Python-side (meta/fake)
  // implementations of these ops live -- confirm against
  // torchvision/_meta_registrations.py.
  lib.set_python_module("torchvision._meta_registrations");

  lib.def(TORCH_SELECTIVE_SCHEMA(
      "torchvision::nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor"));

  // NOTE(review): this op has a schema but no C++ convenience wrapper in
  // this file; confirm where its kernels and callers live.
  lib.def(TORCH_SELECTIVE_SCHEMA(
      "torchvision::nms_kernel_postprocess(Tensor order, Tensor iou_keep_out_mask, int num_boxes) -> Tensor"));
}

} // namespace ops
} // namespace vision