diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml
index 5c6d8148ac1..e633b01392a 100644
--- a/oneflow/core/functional/functional_api.yaml
+++ b/oneflow/core/functional/functional_api.yaml
@@ -2789,7 +2789,7 @@
   bind_python: False
 
 - name: "nms"
-  signature: "Tensor (Tensor x, Float iou_threshold, Int32 keep_n=-1) => Nms"
+  signature: "Tensor (Tensor x, Tensor scores=None, Tensor input_indices=None, Float iou_threshold, Int32 keep_n=-1) => Nms"
   bind_python: True
 
 - name: "roi_align"
diff --git a/oneflow/core/functional/impl/array_functor.cpp b/oneflow/core/functional/impl/array_functor.cpp
index aef7ef62a3b..4a960cba34d 100644
--- a/oneflow/core/functional/impl/array_functor.cpp
+++ b/oneflow/core/functional/impl/array_functor.cpp
@@ -588,7 +588,25 @@ class ArgWhereFunctor {
                                const Symbol<DType>& dtype) const {
     auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dtype");
     attrs.SetAllAttrs(dtype->data_type());
-    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {x}, attrs);
+
+    auto device_type = DeviceType::kCPU;
+    if (x->is_global()) {
+      device_type = JUST(x->parallel_desc())->device_type();
+    } else {
+      device_type = JUST(x->device())->enum_type();
+    }
+
+    if (device_type == DeviceType::kNPU) {
+      // NOTE: use cpu argwhere when device="npu"
+      auto cpu_tensor = JUST(one::functional::To(x, "cpu"));
+      auto result = JUST(OpInterpUtil::Dispatch<TensorTuple>(*op_, {cpu_tensor}, attrs));
+      for (int i = 0; i < result->size(); ++i) {
+        (*result)[i] = JUST(one::functional::To((*result)[i], "npu"));
+      }
+      return result;
+    } else {
+      return OpInterpUtil::Dispatch<TensorTuple>(*op_, {x}, attrs);
+    }
   }
 
  private:
diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp
index 51f42367c07..3861c417a31 100644
--- a/oneflow/core/functional/impl/nn_functor.cpp
+++ b/oneflow/core/functional/impl/nn_functor.cpp
@@ -4014,17 +4014,38 @@ class PariticalFCSampleDisableBoxing {
 
 class NmsFunctor {
  public:
-  NmsFunctor() { op_ = CHECK_JUST(one::OpBuilder("nms").Input("in").Output("out").Build()); }
+  NmsFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("nms").Input("in").Output("out").Build());
+    fused_op_ = CHECK_JUST(one::OpBuilder("nms")
+                               .Input("in")
+                               .Input("scores")
+                               .Input("input_indices")
+                               .Output("out")
+                               .Build());
+  }
 
-  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const float& iou_threshold,
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,
+                           const Optional<one::Tensor>& scores,
+                           const Optional<one::Tensor>& input_indices, const float& iou_threshold,
                            const int32_t& keep_n) const {
     auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("iou_threshold", "keep_n");
     attrs.SetAllAttrs(iou_threshold, keep_n);
-    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);
+    DeviceType device_type = JUST(x->device())->enum_type();
+    if (device_type == DeviceType::kNPU) {
+      if (scores) {
+        return OpInterpUtil::Dispatch<Tensor>(*fused_op_, {x, JUST(scores), JUST(input_indices)},
+                                              attrs);
+      } else {
+        return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);
+      }
+    } else {
+      return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);
+    }
   }
 
  private:
   std::shared_ptr<OpExpr> op_;
+  std::shared_ptr<OpExpr> fused_op_;
 };
 
 class RoiAlignFunctor {
diff --git a/oneflow/ir/include/OneFlow/OneFlowUserOps.td b/oneflow/ir/include/OneFlow/OneFlowUserOps.td
index 7532cfe441e..190536c9ce5 100644
--- a/oneflow/ir/include/OneFlow/OneFlowUserOps.td
+++ b/oneflow/ir/include/OneFlow/OneFlowUserOps.td
@@ -1886,7 +1886,9 @@ def OneFlow_InTopKOp : OneFlow_BaseOp<"in_top_k", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
 
 def OneFlow_NmsOp : OneFlow_BaseOp<"nms", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
   let input = (ins
-    OneFlow_Tensor:$in
+    OneFlow_Tensor:$in,
+    Optional<OneFlow_Tensor>:$scores,
+    Optional<OneFlow_Tensor>:$input_indices
   );
   let output = (outs
     OneFlow_Tensor:$out
diff --git a/oneflow/user/ops/nms_op.cpp b/oneflow/user/ops/nms_op.cpp
index e49236c7e8c..56e3883254d 100644
--- a/oneflow/user/ops/nms_op.cpp
+++ b/oneflow/user/ops/nms_op.cpp
@@ -26,7 +26,11 @@ Maybe<void> InferNmsTensorDesc(user_op::InferContext* ctx) {
 }
 
 Maybe<void> InferNmsDataType(user_op::InferContext* ctx) {
-  ctx->SetOutputDType("out", 0, DataType::kInt8);
+  if (ctx->parallel_desc().device_type() == DeviceType::kNPU) {
+    ctx->SetOutputDType("out", 0, DataType::kInt32);
+  } else {
+    ctx->SetOutputDType("out", 0, DataType::kInt8);
+  }
   return Maybe<void>::Ok();
 }
diff --git a/python/oneflow/nn/modules/nms.py b/python/oneflow/nn/modules/nms.py
index 7fdb64f0087..f6059d38161 100644
--- a/python/oneflow/nn/modules/nms.py
+++ b/python/oneflow/nn/modules/nms.py
@@ -20,7 +20,13 @@
 
 def nms_op(boxes, scores, iou_threshold: float):
     score_inds = flow.argsort(scores, dim=0, descending=True)
-    boxes = flow._C.gather(boxes, score_inds, axis=0)
-    keep = flow._C.nms(boxes, iou_threshold)
+    if boxes.device == flow.device("npu"):
+        sorted_scores = flow.gather(scores, dim=0, index=score_inds)
+        keep = flow._C.nms(
+            boxes, sorted_scores, score_inds.to(flow.int32), iou_threshold=iou_threshold
+        )
+    else:
+        boxes = flow._C.gather(boxes, score_inds, axis=0)
+        keep = flow._C.nms(boxes, iou_threshold=iou_threshold)
     index = flow.squeeze(flow.argwhere(keep), dim=[1])
     return flow._C.gather(score_inds, index, axis=0)