class RBitmapMasks(BitmapMasks):
    """Bitmap masks that can also report minimum-area rotated rectangles.

    Compared to the original ``BitmapMasks`` class, this class supports
    getting the minimum area rectangles from masks.

    Args:
        masks (ndarray): ndarray of masks in shape (N, H, W), where N is
            the number of objects.
        height (int): height of masks
        width (int): width of masks
    """

    def get_rbboxes(self):
        """Compute the minimum-area rotated rectangle of every mask.

        All foreground regions of a mask contribute to its rectangle, so a
        mask that has been split into several disconnected pieces (e.g. by
        a pasted object) still yields one box enclosing every piece.

        Returns:
            ndarray: float32 array of shape (N, 5). Each row is
            ``(cx, cy, w, h, angle)`` as produced by ``cv2.minAreaRect``,
            with the angle converted from degrees to radians. Rows that
            belong to empty masks are left as zeros.
        """
        num_masks = len(self)
        rboxes = np.zeros((num_masks, 5), dtype=np.float32)
        for idx in range(num_masks):
            # cv2.findContours only accepts 8-bit single-channel images.
            mask = np.ascontiguousarray(self.masks[idx], dtype=np.uint8)
            if not mask.any():
                continue
            # Index [-2] selects the contour list for both the OpenCV 3
            # (image, contours, hierarchy) and the OpenCV 4
            # (contours, hierarchy) return signatures.
            contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
                                        cv2.CHAIN_APPROX_SIMPLE)[-2]
            if len(contours) == 0:
                continue
            # Merge the points of every outer contour so the rectangle
            # covers the whole mask instead of only its first component.
            points = np.concatenate(contours, axis=0)
            (cx, cy), (w, h), angle = cv2.minAreaRect(points)
            rboxes[idx, :] = np.array(
                [cx, cy, w, h, np.radians(angle)], dtype=np.float32)
        return rboxes
@ROTATED_PIPELINES.register_module()
class RCopyPaste(CopyPaste):
    """Simple Copy-Paste is a Strong Data Augmentation Method for Instance
    Segmentation, adapted here to rotated bounding boxes.

    The simple copy-paste transform steps are as follows:

    1. The destination image is already resized with aspect ratio kept,
       cropped and padded.
    2. Randomly select a source image, which is also already resized
       with aspect ratio kept, cropped and padded in a similar way
       as the destination image.
    3. Randomly select some objects from the source image.
    4. Paste these source objects to the destination image directly,
       due to the source and destination image have the same size.
    5. Update object masks of the destination image, for some origin objects
       may be occluded.
    6. Generate rotated bboxes from the updated destination masks and
       filter some objects which are totally occluded, and adjust bboxes
       which are partly occluded.
    7. Append selected source bboxes, masks, and labels.

    Args:
        max_num_pasted (int): The maximum number of pasted objects.
            Default: 100.
        rbbox_occluded_iou_thr (float): IoU threshold between a destination
            rbbox before and after pasting. An object whose IoU is above
            the threshold is kept regardless of its remaining mask area.
            Default: 0.3.
        mask_occluded_thr (int): Visible mask area (in pixels) above which
            an object is kept regardless of how much its rbbox changed.
            Default: 300.
        selected (bool): Whether select objects or not. If select is False,
            all objects of the source image will be pasted to the
            destination image.
            Default: True.
        version (str, optional): Angle representations. Defaults to `le90`.
    """

    def __init__(
        self,
        max_num_pasted=100,
        rbbox_occluded_iou_thr=0.3,
        mask_occluded_thr=300,
        selected=True,
        version='le90',
    ):
        self.max_num_pasted = max_num_pasted
        self.rbbox_occluded_iou_thr = rbbox_occluded_iou_thr
        self.mask_occluded_thr = mask_occluded_thr
        self.selected = selected
        # Set to True once masks had to be synthesised from boxes; in that
        # case `_copy_paste` removes 'gt_masks' from the results again.
        self.paste_by_box = False
        self.version = version

    def gen_masks_from_bboxes(self, bboxes, img_shape):
        """Generate gt_masks based on gt_bboxes.

        Every rotated box is rasterised as a filled polygon into its own
        mask plane.

        Args:
            bboxes (ndarray | list): Rotated boxes, one
                ``(cx, cy, w, h, angle)`` row per object. May be empty.
            img_shape (tuple): The shape of image; only the first two
                entries (height, width) are used.

        Returns:
            RBitmapMasks: One binary mask per input box.
        """
        self.paste_by_box = True
        img_h, img_w = img_shape[:2]
        bboxes = np.asarray(bboxes)
        gt_masks = np.zeros((len(bboxes), img_h, img_w), dtype=np.uint8)
        if bboxes.size == 0:
            # Nothing to draw (also covers a plain empty list as input).
            return RBitmapMasks(gt_masks, img_h, img_w)

        # obb2poly_np expects a trailing score column, so append a dummy
        # one and strip it from the resulting polygons again.
        scored = np.concatenate(
            [bboxes, np.zeros((bboxes.shape[0], 1))], axis=-1)
        # np.int0 was removed in NumPy 2.0; int32 is the point type that
        # OpenCV drawing functions work with.
        polys = obb2poly_np(scored,
                            self.version)[:, :-1].reshape(-1, 4,
                                                          2).astype(np.int32)

        for i, poly in enumerate(polys):
            cv2.drawContours(gt_masks[i], [poly], 0, 1, -1)
        return RBitmapMasks(gt_masks, img_h, img_w)

    def _copy_paste(self, dst_results, src_results):
        """CopyPaste transform function.

        Args:
            dst_results (dict): Result dict of the destination image.
            src_results (dict): Result dict of the source image.

        Returns:
            dict: Updated result dict.
        """
        dst_img = dst_results['img']
        dst_bboxes = dst_results['gt_bboxes']
        dst_labels = dst_results['gt_labels']
        dst_masks = dst_results['gt_masks']

        src_img = src_results['img']
        src_bboxes = src_results['gt_bboxes']
        src_labels = src_results['gt_labels']
        src_masks = src_results['gt_masks']

        if len(src_bboxes) == 0:
            if self.paste_by_box:
                dst_results.pop('gt_masks')
            return dst_results

        # update masks and generate bboxes from updated masks
        composed_mask = np.where(np.any(src_masks.masks, axis=0), 1, 0)
        updated_dst_masks = self.get_updated_masks(dst_masks, composed_mask)
        # NOTE(review): these boxes come straight from cv2.minAreaRect and
        # are not normalised to `self.version` here - confirm whether a
        # conversion to the configured angle convention is required.
        updated_dst_bboxes = updated_dst_masks.get_rbboxes()
        assert len(updated_dst_bboxes) == len(updated_dst_masks)

        # filter totally occluded objects
        if len(dst_bboxes) > 0:
            # Compare every destination box only with its own updated
            # version. A fully occluded object has an all-zero updated box
            # and therefore an IoU of 0, which must not mark it as valid.
            ious = box_iou_rotated(
                torch.tensor(dst_bboxes, dtype=torch.float32),
                torch.tensor(updated_dst_bboxes, dtype=torch.float32),
                aligned=True).numpy()
            bboxes_inds = ious > self.rbbox_occluded_iou_thr
        else:
            # No destination objects: nothing to compare.
            bboxes_inds = np.zeros(0, dtype=bool)
        masks_inds = updated_dst_masks.masks.sum(
            axis=(1, 2)) > self.mask_occluded_thr
        valid_inds = bboxes_inds | masks_inds

        # Paste source objects to destination image directly
        paste_region = composed_mask[..., np.newaxis]
        img = dst_img * (1 - paste_region) + src_img * paste_region
        bboxes = np.concatenate([updated_dst_bboxes[valid_inds], src_bboxes])
        labels = np.concatenate([dst_labels[valid_inds], src_labels])
        masks = np.concatenate(
            [updated_dst_masks.masks[valid_inds], src_masks.masks])

        dst_results['img'] = img
        dst_results['gt_bboxes'] = bboxes
        dst_results['gt_labels'] = labels
        if self.paste_by_box:
            # Masks were only synthesised for pasting; do not leak them.
            dst_results.pop('gt_masks')
        else:
            dst_results['gt_masks'] = RBitmapMasks(masks, masks.shape[1],
                                                   masks.shape[2])

        return dst_results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'max_num_pasted={self.max_num_pasted}, '
        # The attribute is spelled `rbbox_...`; the former `rbox_...`
        # lookup raised AttributeError.
        repr_str += f'rbbox_occluded_iou_thr={self.rbbox_occluded_iou_thr}, '
        repr_str += f'mask_occluded_thr={self.mask_occluded_thr}, '
        repr_str += f'selected={self.selected}, '
        repr_str += f'version={self.version}, '
        return repr_str