From d9f327b5d98390188fe853ca26adf4de710a2595 Mon Sep 17 00:00:00 2001
From: vansin
Date: Fri, 11 Nov 2022 20:17:49 +0800
Subject: [PATCH 1/2] [WIP] init support yolov5 ncnn

---
 .../deploy/models/dense_heads/yolov5_head.py  | 106 ++++++++++++++++++
 1 file changed, 106 insertions(+)

diff --git a/mmyolo/deploy/models/dense_heads/yolov5_head.py b/mmyolo/deploy/models/dense_heads/yolov5_head.py
index cf61fb3ca..6da18cadd 100644
--- a/mmyolo/deploy/models/dense_heads/yolov5_head.py
+++ b/mmyolo/deploy/models/dense_heads/yolov5_head.py
@@ -3,10 +3,12 @@ from functools import partial
 from typing import List, Optional, Tuple
 
+import numpy as np
 import torch
 from mmdeploy.codebase.mmdet import get_post_processing_params
 from mmdeploy.codebase.mmdet.models.layers import multiclass_nms
 from mmdeploy.core import FUNCTION_REWRITER
+from mmdeploy.utils import Backend
 from mmengine.config import ConfigDict
 from mmengine.structures import InstanceData
 from torch import Tensor
 
@@ -146,3 +148,107 @@ def yolov5_head__predict_by_feat(ctx,
 
     return nms_func(bboxes, scores, max_output_boxes_per_class, iou_threshold,
                     score_threshold, pre_top_k, keep_top_k)
+
+
+@FUNCTION_REWRITER.register_rewriter(
+    func_name='mmyolo.models.dense_heads.yolov5_head.'
+    'YOLOv5Head.predict_by_feat',
+    backend=Backend.NCNN.value)
+def yolov5_head__predict_by_feat__ncnn(ctx,
+                                       self,
+                                       pred_maps,
+                                       with_nms=True,
+                                       cfg=None,
+                                       **kwargs):
+    """Rewrite `predict_by_feat` of YOLOv5Head for ncnn backend.
+
+    This rewrite performs no box decoding or NMS in PyTorch. It forwards the
+    raw multi-level prediction maps to a single custom ONNX node,
+    `mmdeploy::Yolov3DetectionOutput`, and attaches the thresholds, the
+    flattened anchor base sizes, the anchor masks and the feature map strides
+    to that node as attributes.
+
+    Args:
+        ctx (ContextCaller): The context with additional information.
+        self: Represent the instance of the original class.
+        pred_maps (list[Tensor]): Raw predictions for a batch of images.
+        with_nms (bool): If True, do nms before return boxes.
+            Default: True. Not read by this rewrite.
+        cfg (mmengine.Config | None): Test / postprocessing configuration,
+            if None, test_cfg would be used. Default: None.
+        **kwargs: Accepted for signature compatibility; not read here.
+
+    Returns:
+        Tensor: Detection_output of shape [num_boxes, 6],
+            each row is [label, score, x1, y1, x2, y2]. Note that
+            fore-ground class label in Yolov3DetectionOutput starts
+            from `1`. x1, y1, x2, y2 are normalized in range(0,1).
+    """
+    num_levels = len(pred_maps)
+    cfg = self.test_cfg if cfg is None else cfg
+    post_params = get_post_processing_params(ctx.cfg)
+
+    # NOTE(review): confirm that test_cfg really defines `conf_thr`; when the
+    # key is absent, post_params.confidence_threshold is always used.
+    confidence_threshold = cfg.get('conf_thr',
+                                   post_params.confidence_threshold)
+    iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
+    # Flatten the anchor base sizes of all levels into one plain list.
+    anchor_biases = np.array(
+        self.prior_generator.base_sizes).reshape(-1).tolist()
+    num_box = len(self.prior_generator.base_sizes[0])
+    # One mask index per anchor: 0 .. num_levels * num_box - 1.
+    bias_masks = list(range(num_levels * num_box))
+
+    def _create_yolov5_detection_output():
+        """Help create Yolov3DetectionOutput op in ONNX."""
+
+        class Yolov5DetectionOutputOp(torch.autograd.Function):
+            """Create Yolov3DetectionOutput op.
+
+            Args:
+                *inputs (Tensor): Multiple predicted feature maps.
+                num_class (int): Number of classes.
+                num_box (int): Number of box per grid.
+                confidence_threshold (float): Threshold of object
+                    score.
+                nms_threshold (float): IoU threshold for NMS.
+                biases (List[float]): Base sizes to compute anchors
+                    for each FPN.
+                mask (List[float]): Used to select base sizes in
+                    biases.
+                anchors_scale (List[float]): Down-sampling scales of
+                    each FPN layer, e.g.: [32, 16].
+            """
+
+            @staticmethod
+            def forward(ctx, *args):
+                # create dummy output of shape [num_boxes, 6],
+                # each row is [label, score, x1, y1, x2, y2]
+                output = torch.rand(100, 6)
+                return output
+
+            @staticmethod
+            def symbolic(g, *args):
+                anchors_scale = args[-1]
+                inputs = args[:len(anchors_scale)]
+                assert len(args) == (len(anchors_scale) + 7)
+                return g.op(
+                    'mmdeploy::Yolov3DetectionOutput',
+                    *inputs,
+                    num_class_i=args[-7],
+                    num_box_i=args[-6],
+                    confidence_threshold_f=args[-5],
+                    nms_threshold_f=args[-4],
+                    biases_f=args[-3],
+                    mask_f=args[-2],
+                    anchors_scale_f=anchors_scale,
+                    outputs=1)
+
+        return Yolov5DetectionOutputOp.apply(*pred_maps, self.num_classes,
+                                             num_box, confidence_threshold,
+                                             iou_threshold, anchor_biases,
+                                             bias_masks, self.featmap_strides)
+
+    output = _create_yolov5_detection_output()
+    return output

From 53afcd1f18b2bfe2c752b26f81835256936c52b9 Mon Sep 17 00:00:00 2001
From: vansin
Date: Fri, 11 Nov 2022 22:22:38 +0800
Subject: [PATCH 2/2] [Feat] add ncnn deploy config

---
 configs/deploy/single-stage_ncnn_dynamic.py         | 4 ++++
 configs/deploy/single-stage_ncnn_static-300x300.py  | 4 ++++
 configs/deploy/single-stage_ncnn_static-416x416.py  | 4 ++++
 configs/deploy/single-stage_ncnn_static-800x1344.py | 4 ++++
 4 files changed, 16 insertions(+)
 create mode 100644 configs/deploy/single-stage_ncnn_dynamic.py
 create mode 100644 configs/deploy/single-stage_ncnn_static-300x300.py
 create mode 100644 configs/deploy/single-stage_ncnn_static-416x416.py
 create mode 100644 configs/deploy/single-stage_ncnn_static-800x1344.py

diff --git a/configs/deploy/single-stage_ncnn_dynamic.py b/configs/deploy/single-stage_ncnn_dynamic.py
new file mode 100644
index 000000000..bd917e0be
--- /dev/null
+++ b/configs/deploy/single-stage_ncnn_dynamic.py
@@ -0,0 +1,4 @@
+_base_ = '../_base_/base_dynamic.py'
+backend_config = dict(type='ncnn', precision='FP32', use_vulkan=False)
+codebase_config = dict(model_type='ncnn_end2end')
+onnx_config = dict(output_names=['detection_output'], input_shape=None)
diff --git a/configs/deploy/single-stage_ncnn_static-300x300.py b/configs/deploy/single-stage_ncnn_static-300x300.py
new file mode 100644
index 000000000..7dd5eaef8
--- /dev/null
+++ b/configs/deploy/single-stage_ncnn_static-300x300.py
@@ -0,0 +1,4 @@
+_base_ = '../_base_/base_static.py'
+backend_config = dict(type='ncnn', precision='FP32', use_vulkan=False)
+codebase_config = dict(model_type='ncnn_end2end')
+onnx_config = dict(output_names=['detection_output'], input_shape=[300, 300])
diff --git a/configs/deploy/single-stage_ncnn_static-416x416.py b/configs/deploy/single-stage_ncnn_static-416x416.py
new file mode 100644
index 000000000..dd09d0480
--- /dev/null
+++ b/configs/deploy/single-stage_ncnn_static-416x416.py
@@ -0,0 +1,4 @@
+_base_ = '../_base_/base_static.py'
+backend_config = dict(type='ncnn', precision='FP32', use_vulkan=False)
+codebase_config = dict(model_type='ncnn_end2end')
+onnx_config = dict(output_names=['detection_output'], input_shape=[416, 416])
diff --git a/configs/deploy/single-stage_ncnn_static-800x1344.py b/configs/deploy/single-stage_ncnn_static-800x1344.py
new file mode 100644
index 000000000..4b4973c94
--- /dev/null
+++ b/configs/deploy/single-stage_ncnn_static-800x1344.py
@@ -0,0 +1,4 @@
+_base_ = '../_base_/base_static.py'
+backend_config = dict(type='ncnn', precision='FP32', use_vulkan=False)
+codebase_config = dict(model_type='ncnn_end2end')
+onnx_config = dict(output_names=['detection_output'], input_shape=[1344, 800])