From f7a42d0265404782b15cab463938c263e4234312 Mon Sep 17 00:00:00 2001 From: huanghaian Date: Fri, 29 Dec 2023 15:48:58 +0800 Subject: [PATCH 01/24] update code --- mmdet/models/language_models/bert.py | 41 +- projects/mm_gdino_clip/__init__.py | 6 + projects/mm_gdino_clip/batch_sampler.py | 61 +++ .../mm_gdino_clip/browse_grounding_dataset.py | 221 ++++++++++ .../mm_gdino_clip/browse_grounding_raw.py | 288 ++++++++++++++ ...nding_dino_swin-t_pretrain_obj365_goldg.py | 277 +++++++++++++ projects/mm_gdino_clip/grounding_dino.py | 46 +++ projects/mm_gdino_clip/odvgrec.py | 155 ++++++++ projects/mm_gdino_clip/refcoco2rec.py | 97 +++++ projects/mm_gdino_clip/text_transformers.py | 376 ++++++++++++++++++ 10 files changed, 1564 insertions(+), 4 deletions(-) create mode 100644 projects/mm_gdino_clip/__init__.py create mode 100644 projects/mm_gdino_clip/batch_sampler.py create mode 100644 projects/mm_gdino_clip/browse_grounding_dataset.py create mode 100644 projects/mm_gdino_clip/browse_grounding_raw.py create mode 100644 projects/mm_gdino_clip/configs/grounding_dino_swin-t_pretrain_obj365_goldg.py create mode 100644 projects/mm_gdino_clip/grounding_dino.py create mode 100644 projects/mm_gdino_clip/odvgrec.py create mode 100644 projects/mm_gdino_clip/refcoco2rec.py create mode 100644 projects/mm_gdino_clip/text_transformers.py diff --git a/mmdet/models/language_models/bert.py b/mmdet/models/language_models/bert.py index efb0f46bad6..f0ce94f1524 100644 --- a/mmdet/models/language_models/bert.py +++ b/mmdet/models/language_models/bert.py @@ -14,6 +14,7 @@ HFBertModel = None from mmdet.registry import MODELS +from mmdet.models.utils import align_tensor def generate_masks_with_special_tokens_and_transfer_map( @@ -70,6 +71,14 @@ def generate_masks_with_special_tokens_and_transfer_map( return attention_mask, position_ids.to(torch.long) +def split_tensor(tensor, num_levels): + level_targets = [] + start = 0 + for n in num_levels: + end = start + n + level_targets.append(target[:, start:end]) + start = end + return level_targets @MODELS.register_module() class BertModel(BaseModel): @@ -134,9 +143,14 @@ def __init__(self, self.special_tokens = self.tokenizer.convert_tokens_to_ids( special_tokens_list) - def forward(self, captions: Sequence[str], **kwargs) -> dict: + def forward(self, captions: Sequence[str], task='VG', **kwargs) -> dict: """Forward function.""" device = next(self.language_backbone.parameters()).device + + if task == 'OD': + batch_len_captions = [len(item) for item in captions] + captions = [item for sublist in captions for item in sublist] + tokenized = self.tokenizer.batch_encode_plus( captions, max_length=self.max_tokens, @@ -145,12 +159,11 @@ def forward(self, captions: Sequence[str], **kwargs) -> dict: return_tensors='pt', truncation=True).to(device) input_ids = tokenized.input_ids - if self.use_sub_sentence_represent: + if self.use_sub_sentence_represent and task == 'VG': attention_mask, position_ids = \ generate_masks_with_special_tokens_and_transfer_map( tokenized, self.special_tokens) token_type_ids = tokenized['token_type_ids'] - else: attention_mask = tokenized.attention_mask position_ids = None @@ -163,10 +176,30 @@ def forward(self, captions: Sequence[str], **kwargs) -> dict: 'token_type_ids': token_type_ids } language_dict_features = self.language_backbone(tokenizer_input) - if self.use_sub_sentence_represent: + if self.use_sub_sentence_represent and task == 'VG': language_dict_features['position_ids'] = position_ids language_dict_features[ 'text_token_mask'] = 
tokenized.attention_mask.bool() + else: + end_token_idx = input_ids.argmin(dim=-1) - 1 + embedded = language_dict_features['embedded'] + embedded = embedded[torch.arange(embedded.shape[0]), end_token_idx] + + batch_embedded = [] + batch_mask=[] + batch_text_token_mask=[] + + embedded = split_tensor(embedded, batch_len_captions) + embedded = align_tensor(embedded) + attention_mask = split_tensor(tokenized.attention_mask.bool(), batch_len_captions) + attention_mask = align_tensor(attention_mask) + mask = split_tensor(language_dict_features['mask'], batch_len_captions) + mask = align_tensor(mask) + + language_dict_features['embedded'] = embedded + language_dict_features['hidden'] = embedded + language_dict_features['text_token_mask'] = attention_mask + language_dict_features['mask'] = mask return language_dict_features diff --git a/projects/mm_gdino_clip/__init__.py b/projects/mm_gdino_clip/__init__.py new file mode 100644 index 00000000000..cbf3a3d673a --- /dev/null +++ b/projects/mm_gdino_clip/__init__.py @@ -0,0 +1,6 @@ +from .odvgrec import ODVGRECDataset +from .text_transformers import RandomSamplingNegPosV2 +from .batch_sampler import MultiTaskAspectRatioBatchSampler +from .grounding_dino import GroundingDINOV2 + +__all__ = ['ODVGRECDataset', 'RandomSamplingNegPosV2', 'MultiTaskAspectRatioBatchSampler', 'GroundingDINOV2'] diff --git a/projects/mm_gdino_clip/batch_sampler.py b/projects/mm_gdino_clip/batch_sampler.py new file mode 100644 index 00000000000..510d2196556 --- /dev/null +++ b/projects/mm_gdino_clip/batch_sampler.py @@ -0,0 +1,61 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Sequence + +from torch.utils.data import BatchSampler, Sampler +from mmdet.registry import DATA_SAMPLERS + + +@DATA_SAMPLERS.register_module() +class MultiTaskAspectRatioBatchSampler(BatchSampler): + def __init__(self, + sampler: Sampler, + batch_size: int, + drop_last: bool = True) -> None: + if not isinstance(sampler, Sampler): + raise TypeError('sampler should be an instance of ``Sampler``, ' + f'but got {sampler}') + if not isinstance(batch_size, int) or batch_size <= 0: + raise ValueError('batch_size should be a positive integer value, ' + f'but got batch_size={batch_size}') + self.sampler = sampler + self.batch_size = batch_size + self.drop_last = drop_last + # two groups for w < h and w >= h and two task + self._aspect_ratio_buckets = [[] for _ in range(2 * 2)] + + def __iter__(self) -> Sequence[int]: + for idx in self.sampler: + data_info = self.sampler.dataset.get_data_info(idx) + width, height = data_info['width'], data_info['height'] + bucket_id = 0 if width < height else 1 + # REC and OVD: 0 2 + # VG: 1 3 + if data_info['dataset_mode'] in ['REC', 'OVD']: + bucket_id = bucket_id * 2 + else: + bucket_id = bucket_id * 2 + 1 + bucket = self._aspect_ratio_buckets[bucket_id] + bucket.append(idx) + # yield a batch of indices in the same aspect ratio group + if len(bucket) == self.batch_size: + yield bucket[:] + del bucket[:] + + # yield the rest data and reset the bucket + # left_data = self._aspect_ratio_buckets[0] + self._aspect_ratio_buckets[ + # 1] + self._aspect_ratio_buckets[2] + self._aspect_ratio_buckets[3] + self._aspect_ratio_buckets = [[] for _ in range(2 * 2)] + # while len(left_data) > 0: + # if len(left_data) <= self.batch_size: + # if not self.drop_last: + # yield left_data[:] + # left_data = [] + # else: + # yield left_data[:self.batch_size] + # left_data = left_data[self.batch_size:] + + def __len__(self) -> int: + if self.drop_last: + return len(self.sampler) 
// self.batch_size + else: + return (len(self.sampler) + self.batch_size - 1) // self.batch_size diff --git a/projects/mm_gdino_clip/browse_grounding_dataset.py b/projects/mm_gdino_clip/browse_grounding_dataset.py new file mode 100644 index 00000000000..2fcf4b51cc2 --- /dev/null +++ b/projects/mm_gdino_clip/browse_grounding_dataset.py @@ -0,0 +1,221 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp + +import numpy as np +from mmcv.image import imwrite +from mmengine.config import Config, DictAction +from mmengine.registry import init_default_scope +from mmengine.utils import ProgressBar + +from mmdet.registry import DATASETS, VISUALIZERS +from mmdet.structures.bbox import BaseBoxes + + +# configs/grounding_dino_swin-t_pretrain_obj365_goldg.py -o aa --not-show --shuffle +def parse_args(): + parser = argparse.ArgumentParser(description='Browse a dataset') + parser.add_argument('config', help='train config file path') + parser.add_argument( + '--output-dir', + '-o', + default=None, + type=str, + help='If there is no display interface, you can save it') + parser.add_argument('--not-show', default=False, action='store_true') + parser.add_argument('--show-num', '-n', type=int, default=30) + parser.add_argument('--shuffle', default=False, action='store_true') + parser.add_argument( + '--show-interval', + type=float, + default=0, + help='the interval of show (s)') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + args = parser.parse_args() + return args + + +def draw_all_character(visualizer, characters, w): + start_index = 2 + y_index = 5 + for char in characters: + if isinstance(char, str): + visualizer.draw_texts( + str(char), + positions=np.array([start_index, y_index]), + colors=(0, 0, 0), + font_families='monospace') + start_index += len(char) * 8 + else: + visualizer.draw_texts( + str(char[0]), + positions=np.array([start_index, y_index]), + colors=char[1], + font_families='monospace') + start_index += len(char[0]) * 8 + + if start_index > w - 10: + start_index = 2 + y_index += 15 + + drawn_text = visualizer.get_image() + return drawn_text + + +def main(): + args = parse_args() + cfg = Config.fromfile(args.config) + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + assert args.show_num > 0 + + # register all modules in mmdet into the registries + init_default_scope(cfg.get('default_scope', 'mmdet')) + + dataset = DATASETS.build(cfg.train_dataloader.dataset) + visualizer = VISUALIZERS.build(cfg.visualizer) + visualizer.dataset_meta = dataset.metainfo + + dataset_index = list(range(len(dataset))) + if args.shuffle: + import random + random.shuffle(dataset_index) + + progress_bar = ProgressBar(len(dataset)) + for i in dataset_index[:args.show_num]: + item = dataset[i] + img = item['inputs'].permute(1, 2, 0).numpy() + data_sample = item['data_samples'].numpy() + gt_instances = data_sample.gt_instances + + gt_labels = gt_instances.labels + + base_name = osp.basename(item['data_samples'].img_path) + name, extension = osp.splitext(base_name) + + img = img[..., [2, 1, 0]] # bgr to rgb + gt_bboxes = gt_instances.get('bboxes', None) + 
if gt_bboxes is not None and isinstance(gt_bboxes, BaseBoxes): + gt_instances.bboxes = gt_bboxes.tensor + + dataset_mode = data_sample.dataset_mode + print(base_name, dataset_mode, data_sample.text) + + out_file = osp.join(args.output_dir, dataset_mode + '_' + name + '_' + str(i) + + extension) if args.output_dir is not None else None + + if dataset_mode == 'VG': + tokens_positive = data_sample.tokens_positive + + max_label = int(max(gt_labels) if len(gt_labels) > 0 else 0) + palette = np.random.randint(0, 256, size=(max_label + 1, 3)) + bbox_palette = [tuple(c) for c in palette] + # bbox_palette = get_palette('random', max_label + 1) + colors = [bbox_palette[label] for label in gt_labels] + + visualizer.set_image(img) + + for label, bbox, color in zip(gt_labels, gt_bboxes, colors): + visualizer.draw_bboxes( + bbox, edge_colors=color, face_colors=color, alpha=0.3) + visualizer.draw_bboxes(bbox, edge_colors=color, alpha=1) + + drawn_img = visualizer.get_image() + + new_image = np.ones((100, img.shape[1], 3), dtype=np.uint8) * 255 + visualizer.set_image(new_image) + + gt_tokens_positive = [ + tokens_positive[label] for label in gt_labels + ] + split_by_character = [char for char in data_sample.text] + characters = [] + start_index = 0 + end_index = 0 + for w in split_by_character: + end_index += len(w) + is_find = False + for i, positive in enumerate(gt_tokens_positive): + for p in positive: + if start_index >= p[0] and end_index <= p[1]: + characters.append([w, colors[i]]) + is_find = True + break + if is_find: + break + if not is_find: + characters.append([w, (0, 0, 0)]) + start_index = end_index + + drawn_text = draw_all_character(visualizer, characters, + img.shape[1]) + drawn_img = np.concatenate((drawn_img, drawn_text), axis=0) + elif dataset_mode == 'OD': + tokens_positive = data_sample.tokens_positive + gt_labels = gt_instances.labels + text = data_sample.text + label_names = [] + for label in gt_labels: + label_names.append(text[ + tokens_positive[label][0][0]:tokens_positive[label][0][1]]) + gt_instances.label_names = label_names + data_sample.gt_instances = gt_instances + + visualizer.add_datasample( + base_name, + img, + data_sample, + draw_pred=False, + show=False, + wait_time=0, + out_file=None) + drawn_img = visualizer.get_image() + + new_image = np.ones((100, img.shape[1], 3), dtype=np.uint8) * 255 + visualizer.set_image(new_image) + + characters = [char for char in text] + drawn_text = draw_all_character(visualizer, characters, + img.shape[1]) + drawn_img = np.concatenate((drawn_img, drawn_text), axis=0) + else: + gt_labels = gt_instances.labels + text = data_sample.text + label_names = [] + for label in gt_labels: + label_names.append(text[label]) + gt_instances.label_names = label_names + data_sample.gt_instances = gt_instances + + visualizer.add_datasample( + base_name, + img, + data_sample, + draw_pred=False, + show=False, + wait_time=0, + out_file=None) + drawn_img = visualizer.get_image() + + if not args.not_show: + visualizer.show( + drawn_img, win_name=base_name, wait_time=args.show_interval) + + if out_file is not None: + imwrite(drawn_img[..., ::-1], out_file) + + progress_bar.update() + + +if __name__ == '__main__': + main() diff --git a/projects/mm_gdino_clip/browse_grounding_raw.py b/projects/mm_gdino_clip/browse_grounding_raw.py new file mode 100644 index 00000000000..5961fe069b3 --- /dev/null +++ b/projects/mm_gdino_clip/browse_grounding_raw.py @@ -0,0 +1,288 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import argparse +import json +import os.path as osp + +import cv2 +import numpy as np +from mmcv.image import imfrombytes, imwrite +from mmengine.fileio import get +from mmengine.structures import InstanceData +from mmengine.utils import mkdir_or_exist + +from mmdet.structures import DetDataSample +from mmdet.visualization import DetLocalVisualizer +from mmdet.visualization.palette import _get_adaptive_scales + +# backend_args = dict( +# backend='petrel', +# path_mapping=dict({ +# './data/': 's3://openmmlab/datasets/detection/', +# 'data/': 's3://openmmlab/datasets/detection/' +# })) +backend_args = None + + +# /home/PJLAB/huanghaian/dataset/coco2014/ mdetr_annotations/finetune_refcocog_train_ref.json train2014 --not-show --shuffle -o rex +def parse_args(): + parser = argparse.ArgumentParser(description='Browse a dataset') + parser.add_argument('data_root') + parser.add_argument('ann_file') + parser.add_argument('img_prefix') + parser.add_argument('--label-map-file', '-m', default=None) + parser.add_argument( + '--output-dir', + '-o', + default=None, + type=str, + help='If there is no display interface, you can save it') + parser.add_argument('--not-show', default=False, action='store_true') + parser.add_argument('--show-num', '-n', type=int, default=30) + parser.add_argument('--shuffle', default=False, action='store_true') + parser.add_argument( + '--show-interval', + type=float, + default=0, + help='the interval of show (s)') + args = parser.parse_args() + return args + + +def draw_all_character(visualizer, characters, w): + start_index = 2 + y_index = 5 + for char in characters: + if isinstance(char, str): + visualizer.draw_texts( + str(char), + positions=np.array([start_index, y_index]), + colors=(0, 0, 0), + font_families='monospace') + start_index += len(char) * 8 + else: + visualizer.draw_texts( + str(char[0]), + positions=np.array([start_index, y_index]), + colors=char[1], + font_families='monospace') + start_index += len(char[0]) * 8 + + if start_index > w - 10: + start_index = 2 + y_index += 15 + + drawn_text = visualizer.get_image() + return drawn_text + + +def main(): + args = parse_args() + assert args.show_num > 0 + + local_path = osp.join(args.data_root, args.ann_file) + with open(local_path, 'r') as f: + data_list = [json.loads(line) for line in f] + + dataset_index = list(range(len(data_list))) + if args.shuffle: + import random + random.shuffle(dataset_index) + + if args.label_map_file is not None: + label_map_file = osp.join(args.data_root, args.label_map_file) + with open(label_map_file, 'r') as file: + label_map = json.load(file) + + visualizer = DetLocalVisualizer() + + for i in dataset_index[:args.show_num]: + item = data_list[i] + + img_path = osp.join(args.data_root, args.img_prefix, item['filename']) + if backend_args is not None: + img_bytes = get(img_path, backend_args) + img = imfrombytes(img_bytes, flag='color') + else: + img = cv2.imread(img_path) + img = img[..., [2, 1, 0]] # bgr to rgb + + base_name, extension = osp.splitext(item['filename']) + + out_file = osp.join(args.output_dir, base_name + '_' + str(i) + + extension) if args.output_dir is not None else None + + if args.output_dir is not None: + mkdir_or_exist(args.output_dir) + + if 'detection' in item: + anno = item['detection'] + + instances = [obj for obj in anno['instances']] + bboxes = [obj['bbox'] for obj in instances] + bbox_labels = [int(obj['label']) for obj in instances] + label_names = [label_map[str(label)] for label in bbox_labels] + + data_sample = DetDataSample() + gt_instances = 
InstanceData() + if len(instances) > 0 and 'score' in instances[0]: + score = [obj['score'] for obj in instances] + gt_instances['scores'] = np.array(score) + + gt_instances['bboxes'] = np.array(bboxes).reshape(-1, 4) + gt_instances['labels'] = np.array(bbox_labels) + gt_instances['label_names'] = label_names + data_sample.gt_instances = gt_instances + + visualizer.add_datasample( + osp.basename(img_path), + img, + data_sample, + draw_pred=False, + show=not args.not_show, + wait_time=args.show_interval, + out_file=out_file) + elif 'grounding' in item: + anno = item['grounding'] + text = anno['caption'] + regions = anno['regions'] + + max_label = len(regions) if len(regions) > 0 else 0 + palette = np.random.randint(0, 256, size=(max_label + 1, 3)) + bbox_palette = [tuple(c) for c in palette] + # bbox_palette = get_palette('random', max_label + 1) + colors = [bbox_palette[label] for label in range(max_label)] + + visualizer.set_image(img) + + gt_tokens_positive = [] + for i, region in enumerate(regions): + bbox = region['bbox'] + bbox = np.array(bbox).reshape(-1, 4) + tokens_positive = region['tokens_positive'] + gt_tokens_positive.append(tokens_positive) + visualizer.draw_bboxes( + bbox, + edge_colors=colors[i], + face_colors=colors[i], + alpha=0.3) + visualizer.draw_bboxes(bbox, edge_colors=colors[i], alpha=1) + + if 'score' in region: + areas = (bbox[:, 3] - bbox[:, 1]) * ( + bbox[:, 2] - bbox[:, 0]) + scales = _get_adaptive_scales(areas) + score = region['score'][0] + score = [str(s) for s in score] + font_sizes = [ + int(13 * scales[i]) for i in range(len(scales)) + ] + visualizer.draw_texts( + score, + bbox[:, :2].astype(np.int32), + colors=(255, 255, 255), + font_sizes=font_sizes, + bboxes=[{ + 'facecolor': 'black', + 'alpha': 0.8, + 'pad': 0.7, + 'edgecolor': 'none' + }] * len(bbox)) + + drawn_img = visualizer.get_image() + new_image = np.ones((100, img.shape[1], 3), dtype=np.uint8) * 255 + visualizer.set_image(new_image) + + split_by_character = [char for char in text] + characters = [] + start_index = 0 + end_index = 0 + for w in split_by_character: + end_index += len(w) + is_find = False + for i, positive in enumerate(gt_tokens_positive): + for p in positive: + if start_index >= p[0] and end_index <= p[1]: + characters.append([w, colors[i]]) + is_find = True + break + if is_find: + break + if not is_find: + characters.append([w, (0, 0, 0)]) + start_index = end_index + + drawn_text = draw_all_character(visualizer, characters, + img.shape[1]) + drawn_img = np.concatenate((drawn_img, drawn_text), axis=0) + + if not args.not_show: + visualizer.show( + drawn_img, + win_name=base_name, + wait_time=args.show_interval) + + if out_file is not None: + imwrite(drawn_img[..., ::-1], out_file) + + elif 'referring' in item: + referring = item['referring']['instances'] + + max_label = len(referring) if len(referring) > 0 else 0 + palette = np.random.randint(0, 256, size=(max_label + 1, 3)) + bbox_palette = [tuple(c) for c in palette] + # bbox_palette = get_palette('random', max_label + 1) + colors = [bbox_palette[label] for label in range(max_label)] + + visualizer.set_image(img) + phrases = [] + for i, ref in enumerate(referring): + bbox = ref['bbox'] + phrases.append(ref['exp']) + bbox = np.array(bbox).reshape(-1, 4) + + visualizer.draw_bboxes( + bbox, + edge_colors=colors[i], + face_colors=colors[i], + alpha=0.3) + visualizer.draw_bboxes(bbox, edge_colors=colors[i], alpha=1) + drawn_img = visualizer.get_image() + + new_image = np.ones((100, img.shape[1], 3), dtype=np.uint8) * 255 + 
visualizer.set_image(new_image) + + start_index = 2 + y_index = 5 + + chunk_size = max(min(img.shape[1] - 400, 70), 50) + for i, p in enumerate(phrases): + if not isinstance(p, list): + p = [p] + + for _p in p: + chunk_p = [ + _p[i:i + chunk_size] for i in range(0, len(_p), chunk_size) + ] + for cp in chunk_p: + visualizer.draw_texts( + cp, + positions=np.array([start_index, y_index]), + colors=colors[i], + font_families='monospace') + y_index += 15 + + drawn_text = visualizer.get_image() + drawn_img = np.concatenate((drawn_img, drawn_text), axis=0) + + if not args.not_show: + visualizer.show( + drawn_img, + win_name=base_name, + wait_time=args.show_interval) + + if out_file is not None: + imwrite(drawn_img[..., ::-1], out_file) + + +if __name__ == '__main__': + main() diff --git a/projects/mm_gdino_clip/configs/grounding_dino_swin-t_pretrain_obj365_goldg.py b/projects/mm_gdino_clip/configs/grounding_dino_swin-t_pretrain_obj365_goldg.py new file mode 100644 index 00000000000..33e4a28cd2f --- /dev/null +++ b/projects/mm_gdino_clip/configs/grounding_dino_swin-t_pretrain_obj365_goldg.py @@ -0,0 +1,277 @@ +_base_ = [ + '../../../configs/_base_/datasets/coco_detection.py', + '../../../configs/_base_/schedules/schedule_1x.py', '../../../configs/_base_/default_runtime.py' +] +pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth' # noqa +lang_model_name = 'bert-base-uncased' + +custom_imports = dict( + imports=['projects.mm_gdino_clip'], allow_failed_imports=False) + +model = dict( + type='GroundingDINOV2', + num_queries=900, + with_box_refine=True, + as_two_stage=True, + data_preprocessor=dict( + type='DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_mask=False, + ), + language_model=dict( + type='BertModel', + name=lang_model_name, + max_tokens=256, + pad_to_max=False, + use_sub_sentence_represent=True, + special_tokens_list=['[CLS]', '[SEP]', '.', '?'], + add_pooling_layer=False, + ), + backbone=dict( + type='SwinTransformer', + embed_dims=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4, + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + patch_norm=True, + out_indices=(1, 2, 3), + with_cp=True, + convert_weights=True, + frozen_stages=-1, + init_cfg=dict(type='Pretrained', checkpoint=pretrained)), + neck=dict( + type='ChannelMapper', + in_channels=[192, 384, 768], + kernel_size=1, + out_channels=256, + act_cfg=None, + bias=True, + norm_cfg=dict(type='GN', num_groups=32), + num_outs=4), + encoder=dict( + num_layers=6, + num_cp=6, + # visual layer config + layer_cfg=dict( + self_attn_cfg=dict(embed_dims=256, num_levels=4, dropout=0.0), + ffn_cfg=dict( + embed_dims=256, feedforward_channels=2048, ffn_drop=0.0)), + # text layer config + text_layer_cfg=dict( + self_attn_cfg=dict(num_heads=4, embed_dims=256, dropout=0.0), + ffn_cfg=dict( + embed_dims=256, feedforward_channels=1024, ffn_drop=0.0)), + # fusion layer config + fusion_layer_cfg=dict( + v_dim=256, + l_dim=256, + embed_dim=1024, + num_heads=4, + init_values=1e-4), + ), + decoder=dict( + num_layers=6, + return_intermediate=True, + layer_cfg=dict( + # query self attention layer + self_attn_cfg=dict(embed_dims=256, num_heads=8, dropout=0.0), + # cross attention layer query to text + cross_attn_text_cfg=dict(embed_dims=256, num_heads=8, dropout=0.0), + # cross attention layer query to image + cross_attn_cfg=dict(embed_dims=256, num_heads=8, 
dropout=0.0), + ffn_cfg=dict( + embed_dims=256, feedforward_channels=2048, ffn_drop=0.0)), + post_norm_cfg=None), + positional_encoding=dict( + num_feats=128, normalize=True, offset=0.0, temperature=20), + bbox_head=dict( + type='GroundingDINOHead', + num_classes=256, + sync_cls_avg_factor=True, + contrastive_cfg=dict(max_text_len=256, log_scale='auto', bias=True), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), # 2.0 in DeformDETR + loss_bbox=dict(type='L1Loss', loss_weight=5.0)), + dn_cfg=dict( # TODO: Move to model.train_cfg ? + label_noise_scale=0.5, + box_noise_scale=1.0, # 0.4 for DN-DETR + group_cfg=dict(dynamic=True, num_groups=None, + num_dn_queries=100)), # TODO: half num_dn_queries + # training and testing settings + train_cfg=dict( + assigner=dict( + type='HungarianAssigner', + match_costs=[ + dict(type='BinaryFocalLossCost', weight=2.0), + dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'), + dict(type='IoUCost', iou_mode='giou', weight=2.0) + ])), + test_cfg=dict(max_per_img=300)) + +# dataset settings +train_pipeline = [ + dict(type='LoadImageFromFile', backend_args=_base_.backend_args), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='RandomFlip', prob=0.5), + dict( + type='FixScaleResize', + scale=(400, 1340033), + keep_ratio=True, + backend='pillow'), + # dict( + # type='RandomChoice', + # transforms=[ + # [ + # dict( + # type='RandomChoiceResize', + # scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), + # (608, 1333), (640, 1333), (672, 1333), (704, 1333), + # (736, 1333), (768, 1333), (800, 1333)], + # keep_ratio=True) + # ], + # [ + # dict( + # type='RandomChoiceResize', + # # The radio of all image in train dataset < 7 + # # follow the original implement + # scales=[(400, 4200), (500, 4200), (600, 4200)], + # keep_ratio=True), + # dict( + # type='RandomCrop', + # crop_type='absolute_range', + # crop_size=(384, 600), + # allow_negative_crop=True), + # dict( + # type='RandomChoiceResize', + # scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), + # (608, 1333), (640, 1333), (672, 1333), (704, 1333), + # (736, 1333), (768, 1333), (800, 1333)], + # keep_ratio=True) + # ] + # ]), + dict(type='FilterAnnotations', min_gt_bbox_wh=(1e-2, 1e-2)), + dict( + type='RandomSamplingNegPosV2', + tokenizer_name=lang_model_name, + num_sample_negative=85, + max_tokens=256), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor', 'flip', 'flip_direction', 'text', + 'custom_entities', 'tokens_positive', 'dataset_mode')) +] + +test_pipeline = [ + dict( + type='LoadImageFromFile', backend_args=None, + imdecode_backend='pillow'), + dict( + type='FixScaleResize', + scale=(400, 400), + keep_ratio=True, + backend='pillow'), + dict(type='LoadAnnotations', with_bbox=True), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor', 'text', 'custom_entities', + 'tokens_positive')) +] + +dataset_type = 'ODVGRECDataset' + +o365_data_root = '/home/PJLAB/huanghaian/dataset/grounding/obj365v1_200/' +obj365_od_dataset = dict( + type=dataset_type, + data_root=o365_data_root, + ann_file='o365v1_train_odvg.json', + label_map_file='o365v1_label_map.json', + data_prefix=dict(img='train/'), + filter_cfg=dict(filter_empty_gt=False), + pipeline=train_pipeline, + return_classes=True, + backend_args=None) + +rec_data_root = '/home/PJLAB/huanghaian/dataset/coco2014/' +rec_rec_dataset = dict( + type=dataset_type, + 
data_root=rec_data_root, + ann_file='mdetr_annotations/finetune_refcocog_train_ref.json', + data_prefix=dict(img='train2014/'), + filter_cfg=dict(filter_empty_gt=False), + pipeline=train_pipeline, + return_classes=True, + backend_args=None) + +flickr30k_vg_data_root = '/home/PJLAB/huanghaian/dataset/grounding/flickr30k_200/' +flickr30k_vg_dataset = dict( + type=dataset_type, + data_root=flickr30k_vg_data_root, + ann_file='final_flickr_separateGT_train_vg.json', + data_prefix=dict(img='flickr30k_images/'), + filter_cfg=dict(filter_empty_gt=False), + pipeline=train_pipeline, + return_classes=True, + backend_args=None) + +train_dataloader = dict( + _delete_=True, + batch_size=2, + num_workers=0, + persistent_workers=False, + sampler=dict(type='DefaultSampler', shuffle=True), + batch_sampler=dict(type='MultiTaskAspectRatioBatchSampler'), + dataset=dict(type='ConcatDataset', datasets=[obj365_od_dataset, rec_rec_dataset, flickr30k_vg_dataset])) + +val_dataloader = dict( + dataset=dict(pipeline=test_pipeline, return_classes=True)) +test_dataloader = val_dataloader + +optim_wrapper = dict( + _delete_=True, + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=0.0004, + weight_decay=0.0001), + clip_grad=dict(max_norm=0.1, norm_type=2), + paramwise_cfg=dict( + custom_keys={ + 'absolute_pos_embed': dict(decay_mult=0.), + 'backbone': dict(lr_mult=0.1), + 'language_model': dict(lr_mult=0.1), + })) + +# learning policy +max_epochs = 30 +param_scheduler = [ + dict(type='LinearLR', start_factor=0.1, by_epoch=False, begin=0, end=1000), + dict( + type='MultiStepLR', + begin=0, + end=max_epochs, + by_epoch=True, + milestones=[19, 26], + gamma=0.1) +] + +train_cfg = dict( + type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=1) + +# NOTE: `auto_scale_lr` is for automatically scaling LR, +# USER SHOULD NOT CHANGE ITS VALUES. +# base_batch_size = (16 GPUs) x (2 samples per GPU) +auto_scale_lr = dict(base_batch_size=64) + +default_hooks = dict(visualization=dict(type='GroundingVisualizationHook')) diff --git a/projects/mm_gdino_clip/grounding_dino.py b/projects/mm_gdino_clip/grounding_dino.py new file mode 100644 index 00000000000..f8b28f6afd9 --- /dev/null +++ b/projects/mm_gdino_clip/grounding_dino.py @@ -0,0 +1,46 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import copy +import re +import warnings +from typing import Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.structures import OptSampleList, SampleList +from mmdet.utils import ConfigType +from mmdet.models.detectors import GroundingDINO + +task_map={'OD': 0, 'REC': 0, 'VG': 1} + +@MODELS.register_module() +class GroundingDINOV2(GroundingDINO): + + def loss(self, batch_inputs: Tensor, + batch_data_samples: SampleList) -> Union[dict, list]: + tasks=[data_samples.dataset_mode for data_samples in batch_data_samples] + tasks= [task_map[task] for task in tasks] + assert len(set(tasks)) == 1, 'Only support one task in one batch, but got {}'.format(tasks) + + if tasks[0]==1: + # VG + return super().loss(batch_inputs, batch_data_samples) + else: + # OD=REC + text_prompts = [ + data_samples.text for data_samples in batch_data_samples + ] + + gt_labels = [ + data_samples.gt_instances.labels + for data_samples in batch_data_samples + ] + + text_dict = self.language_model(text_prompts, task='OD') + if self.text_feat_map is not None: + text_dict['embedded'] = self.text_feat_map(text_dict['embedded']) + + + diff --git a/projects/mm_gdino_clip/odvgrec.py b/projects/mm_gdino_clip/odvgrec.py new file mode 100644 index 00000000000..27b09bbbeb6 --- /dev/null +++ b/projects/mm_gdino_clip/odvgrec.py @@ -0,0 +1,155 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +import os.path as osp +from typing import List, Optional + +from mmengine.fileio import get_local_path + +from mmdet.registry import DATASETS +from mmdet.datasets import BaseDetDataset + + +@DATASETS.register_module() +class ODVGRECDataset(BaseDetDataset): + """object detection and visual grounding dataset.""" + + def __init__(self, + *args, + data_root: str = '', + label_map_file: Optional[str] = None, + need_text: bool = True, + **kwargs) -> None: + self.dataset_mode = 'VG' + self.need_text = need_text + if label_map_file: + label_map_file = osp.join(data_root, label_map_file) + with open(label_map_file, 'r') as file: + self.label_map = json.load(file) + self.dataset_mode = 'OD' + super().__init__(*args, data_root=data_root, **kwargs) + assert self.return_classes is True + + def load_data_list(self) -> List[dict]: + self.image_to_exp = {} + with get_local_path( + self.ann_file, backend_args=self.backend_args) as local_path: + with open(local_path, 'r') as f: + data_list = [json.loads(line) for line in f] + + out_data_list = [] + for data in data_list: + data_info = {} + img_path = osp.join(self.data_prefix['img'], data['filename']) + data_info['img_path'] = img_path + data_info['height'] = data['height'] + data_info['width'] = data['width'] + + if 'referring' in data: + self.dataset_mode = 'REC' + + if self.dataset_mode == 'OD': + if self.need_text: + data_info['text'] = self.label_map + anno = data.get('detection', {}) + instances = [obj for obj in anno.get('instances', [])] + bboxes = [obj['bbox'] for obj in instances] + bbox_labels = [str(obj['label']) for obj in instances] + + instances = [] + for bbox, label in zip(bboxes, bbox_labels): + instance = {} + x1, y1, x2, y2 = bbox + inter_w = max(0, min(x2, data['width']) - max(x1, 0)) + inter_h = max(0, min(y2, data['height']) - max(y1, 0)) + if inter_w * inter_h == 0: + continue + if (x2 - x1) < 1 or (y2 - y1) < 1: + continue + instance['ignore_flag'] = 0 + instance['bbox'] = bbox + instance['bbox_label'] = int(label) + instances.append(instance) + data_info['instances'] = instances + 
data_info['dataset_mode'] = self.dataset_mode + out_data_list.append(data_info) + elif self.dataset_mode == 'REC': + anno = data.get('referring', {}) + instances = [obj for obj in anno.get('instances', [])] + bboxes = [obj['bbox'] for obj in instances] + bbox_exp = [obj['exp'] for obj in instances] + + self.image_to_exp[img_path] = bbox_exp + + bbox_labels = list(range(len(bboxes))) + + phrases = {} + instances = [] + i = 0 + for bbox, exp, label in zip(bboxes, bbox_exp, bbox_labels): + instance = {} + x1, y1, x2, y2 = bbox + inter_w = max(0, min(x2, data['width']) - max(x1, 0)) + inter_h = max(0, min(y2, data['height']) - max(y1, 0)) + if inter_w * inter_h == 0: + continue + if (x2 - x1) < 1 or (y2 - y1) < 1: + continue + instance['ignore_flag'] = 0 + instance['bbox'] = bbox + instance['bbox_label'] = int(label) + instances.append(instance) + phrases[i] = exp + i += 1 + + data_info['instances'] = instances + data_info['dataset_mode'] = self.dataset_mode + data_info['text'] = phrases + out_data_list.append(data_info) + else: + anno = data['grounding'] + data_info['text'] = anno['caption'] + regions = anno['regions'] + + instances = [] + phrases = {} + for i, region in enumerate(regions): + bbox = region['bbox'] + phrase = region['phrase'] + tokens_positive = region['tokens_positive'] + if not isinstance(bbox[0], list): + bbox = [bbox] + for box in bbox: + instance = {} + x1, y1, x2, y2 = box + inter_w = max(0, min(x2, data['width']) - max(x1, 0)) + inter_h = max(0, min(y2, data['height']) - max(y1, 0)) + if inter_w * inter_h == 0: + continue + if (x2 - x1) < 1 or (y2 - y1) < 1: + continue + instance['ignore_flag'] = 0 + instance['bbox'] = box + instance['bbox_label'] = i + phrases[i] = { + 'phrase': phrase, + 'tokens_positive': tokens_positive + } + instances.append(instance) + data_info['instances'] = instances + data_info['phrases'] = phrases + data_info['dataset_mode'] = self.dataset_mode + out_data_list.append(data_info) + + del data_list + return out_data_list + + def prepare_data(self, idx: int): + """Pass the dataset to the pipeline during training to support mixed + data augmentation, such as Mosaic and MixUp.""" + if self.test_mode is False: + data_info = self.get_data_info(idx) + if self.dataset_mode == 'REC': + data_info['image_to_exp'] = self.image_to_exp + return self.pipeline(data_info) + else: + return super().prepare_data(idx) diff --git a/projects/mm_gdino_clip/refcoco2rec.py b/projects/mm_gdino_clip/refcoco2rec.py new file mode 100644 index 00000000000..00328df7b2b --- /dev/null +++ b/projects/mm_gdino_clip/refcoco2rec.py @@ -0,0 +1,97 @@ +import jsonlines +from pycocotools.coco import COCO +from tqdm import tqdm +import os + +ann_path = '/home/PJLAB/huanghaian/dataset/coco2014/mdetr_annotations/finetune_refcocog_train.json' + + +def _has_only_empty_bbox(anno): + return all(any(o <= 1 for o in obj['bbox'][2:]) for obj in anno) + + +def has_valid_annotation(anno): + # if it's empty, there is no annotation + if len(anno) == 0: + return False + # if all boxes have close to zero area, there is no annotation + if _has_only_empty_bbox(anno): + return False + return True + + +coco = COCO(ann_path) +ids = list(sorted(coco.imgs.keys())) +out_results = [] + +i = 0 +for img_id in tqdm(ids): + if i > 1000: + break + if isinstance(img_id, str): + ann_ids = coco.getAnnIds(imgIds=[img_id], iscrowd=0) + else: + ann_ids = coco.getAnnIds(imgIds=img_id, iscrowd=0) + annos = coco.loadAnns(ann_ids) + + if not has_valid_annotation(annos): + continue + + img_info = coco.loadImgs(img_id)[0] + 
file_name = img_info['file_name'] + caption = img_info['caption'] + instance_list = [] + + for anno in annos: + box = anno['bbox'] + + x1, y1, w, h = box + inter_w = max(0, min(x1 + w, int(img_info['width'])) - max(x1, 0)) + inter_h = max(0, min(y1 + h, int(img_info['height'])) - max(y1, 0)) + if inter_w * inter_h == 0: + continue + if anno['area'] <= 0 or w < 1 or h < 1: + continue + + if anno.get('iscrowd', False): + continue + bbox_xyxy = [ + x1, y1, + min(x1 + w, int(img_info['width'])), + min(y1 + h, int(img_info['height'])) + ] + instance_list.append({ + 'bbox': bbox_xyxy, + 'exp': caption, + }) + + # 相同图片名的实例合并到一起 + if i != 0 and file_name == out_results[-1]['filename']: + pre_instance_list = out_results[-1]['referring']['instances'] + for instance in instance_list: + no_find = True + for pre_instance in pre_instance_list: + if instance['bbox'] == pre_instance['bbox'] and instance['exp'] != pre_instance['exp']: + if isinstance(pre_instance['exp'], list): + pre_instance['exp'].append(instance['exp']) + else: + pre_instance['exp'] = [pre_instance['exp'], instance['exp']] + no_find = False + break + if no_find: + pre_instance_list.append(instance) + else: + out_results.append({ + 'filename': file_name, + 'height': img_info['height'], + 'width': img_info['width'], + 'referring': { + 'instances': instance_list + } + }) + i += 1 +file_name = os.path.basename(ann_path) +out_path = os.path.join(os.path.dirname(ann_path), os.path.basename(ann_path)[:-5] + '_ref.json') +with jsonlines.open(out_path, mode='w') as writer: + writer.write_all(out_results) +print(f'save to {out_path}') diff --git a/projects/mm_gdino_clip/text_transformers.py b/projects/mm_gdino_clip/text_transformers.py new file mode 100644 index 00000000000..c63df796501 --- /dev/null +++ b/projects/mm_gdino_clip/text_transformers.py @@ -0,0 +1,376 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json + +from mmcv.transforms import BaseTransform + +from mmdet.registry import TRANSFORMS +from mmdet.structures.bbox import BaseBoxes + +try: + from transformers import AutoTokenizer + from transformers import BertModel as HFBertModel +except ImportError: + AutoTokenizer = None + HFBertModel = None + +import random +import re + +import numpy as np + + +def clean_name(name): + name = re.sub(r'\(.*\)', '', name) + name = re.sub(r'_', ' ', name) + name = re.sub(r' ', ' ', name) + name = name.lower() + return name + + +def check_for_positive_overflow(gt_bboxes, gt_labels, text, tokenizer, + max_tokens): + # Check if we have too many positive labels + # generate a caption by appending the positive labels + positive_label_list = np.unique(gt_labels).tolist() + # random shuffule so we can sample different annotations + # at different epochs + random.shuffle(positive_label_list) + + kept_lables = [] + length = 0 + + for index, label in enumerate(positive_label_list): + + label_text = clean_name(text[str(label)]) + '. 
' + + tokenized = tokenizer.tokenize(label_text) + + length += len(tokenized) + + if length > max_tokens: + break + else: + kept_lables.append(label) + + keep_box_index = [] + keep_gt_labels = [] + for i in range(len(gt_labels)): + if gt_labels[i] in kept_lables: + keep_box_index.append(i) + keep_gt_labels.append(gt_labels[i]) + + return gt_bboxes[keep_box_index], np.array( + keep_gt_labels, dtype=np.long), length + + +def generate_senetence_given_labels(positive_label_list, negative_label_list, + text): + label_to_positions = {} + + label_list = negative_label_list + positive_label_list + + random.shuffle(label_list) + + pheso_caption = '' + + label_remap_dict = {} + for index, label in enumerate(label_list): + + start_index = len(pheso_caption) + + pheso_caption += clean_name(text[str(label)]) + + end_index = len(pheso_caption) + + if label in positive_label_list: + label_to_positions[index] = [[start_index, end_index]] + label_remap_dict[int(label)] = index + + # if index != len(label_list) - 1: + # pheso_caption += '. ' + pheso_caption += '. ' + + return label_to_positions, pheso_caption, label_remap_dict + + +@TRANSFORMS.register_module() +class RandomSamplingNegPosV2(BaseTransform): + + def __init__(self, + tokenizer_name, + num_sample_negative=85, + max_tokens=256, + full_sampling_prob=0.5, + label_map_file=None): + if AutoTokenizer is None: + raise RuntimeError( + 'transformers is not installed, please install it by: ' + 'pip install transformers.') + + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + self.num_sample_negative = num_sample_negative + self.full_sampling_prob = full_sampling_prob + self.max_tokens = max_tokens + self.label_map = None + if label_map_file: + with open(label_map_file, 'r') as file: + self.label_map = json.load(file) + + def transform(self, results: dict) -> dict: + dataset_mode = results['dataset_mode'] + if dataset_mode == 'OD': + if np.random.rand() > 0.3: + return self.rec_aug(results) + else: + return self.od_aug(results) + elif dataset_mode == 'VG': + return self.vg_aug(results) + else: + return self.rec_aug(results) + + def rec_aug(self, results): + gt_bboxes = results['gt_bboxes'] + if isinstance(gt_bboxes, BaseBoxes): + gt_bboxes = gt_bboxes.tensor + gt_labels = results['gt_bboxes_labels'] + + if 'text' not in results: + assert self.label_map is not None + text = self.label_map + else: + text = results['text'] + + if results['dataset_mode'] == 'REC': + assert 'image_to_exp' in results + keys = list(results['image_to_exp'].keys()) + positive_label_list = np.unique(gt_labels).tolist() + + full_negative = self.num_sample_negative + + if full_negative > len(keys): + full_negative = len(keys) + + outer_prob = random.random() + + if outer_prob < self.full_sampling_prob: + # c. 
probability_full: add both all positive and all negatives + num_negatives = full_negative + else: + if random.random() < 1.0: + num_negatives = np.random.choice(max(1, full_negative)) + 1 + else: + num_negatives = full_negative + + negative_label_list = set() + if num_negatives != -1: + if num_negatives > len(keys): + num_negatives = len(keys) + + for i in np.random.choice( + keys, size=num_negatives, replace=False): + if i not in results['img_path']: + others_exp = results['image_to_exp'][i] + if isinstance(others_exp, list): + others_exp = random.choice(others_exp) + if isinstance(others_exp, list): + others_exp = random.choice(others_exp) + negative_label_list.add(others_exp) + + random.shuffle(positive_label_list) + + negative_label_list = list(negative_label_list) + random.shuffle(negative_label_list) + + label_list = positive_label_list + negative_label_list + + text = results['text'] # dict + + random.shuffle(label_list) + + label_remap_dict = {} + new_text = [] + for index, label in enumerate(label_list): + if label in positive_label_list: + label_remap_dict[int(label)] = index + _text = text[label] + if isinstance(_text, list): + _text = random.choice(_text) + new_text.append(_text) + else: + new_text.append(label) + + if len(gt_labels) > 0: + gt_labels = np.vectorize(lambda x: label_remap_dict[x])(gt_labels) + + results['gt_bboxes'] = gt_bboxes + results['gt_bboxes_labels'] = gt_labels + results['text'] = new_text + else: + valid_negative_indexes = list(text.keys()) + + positive_label_list = np.unique(gt_labels).tolist() + full_negative = self.num_sample_negative + + if full_negative > len(valid_negative_indexes): + full_negative = len(valid_negative_indexes) + + outer_prob = random.random() + + if outer_prob < self.full_sampling_prob: + # c. probability_full: add both all positive and all negatives + num_negatives = full_negative + else: + if random.random() < 1.0: + num_negatives = np.random.choice(max(1, full_negative)) + 1 + else: + num_negatives = full_negative + + # Keep some negatives + negative_label_list = set() + if num_negatives != -1: + if num_negatives > len(valid_negative_indexes): + num_negatives = len(valid_negative_indexes) + + for i in np.random.choice( + valid_negative_indexes, size=num_negatives, replace=False): + if i not in positive_label_list: + negative_label_list.add(i) + + random.shuffle(positive_label_list) + + negative_label_list = list(negative_label_list) + random.shuffle(negative_label_list) + + label_list = positive_label_list + negative_label_list + random.shuffle(label_list) + + label_remap_dict = {} + for index, label in enumerate(label_list): + if label in positive_label_list: + label_remap_dict[int(label)] = index + + if len(gt_labels) > 0: + gt_labels = np.vectorize(lambda x: label_remap_dict[x])(gt_labels) + + results['gt_bboxes'] = gt_bboxes + results['gt_bboxes_labels'] = gt_labels + results['text'] = [text[str(l)] for l in label_list] + + results['dataset_mode'] = 'REC' + return results + + def vg_aug(self, results): + gt_bboxes = results['gt_bboxes'] + if isinstance(gt_bboxes, BaseBoxes): + gt_bboxes = gt_bboxes.tensor + gt_labels = results['gt_bboxes_labels'] + text = results['text'].lower().strip() + if not text.endswith('.'): + text = text + '. 
' + + phrases = results['phrases'] + # TODO: add neg + positive_label_list = np.unique(gt_labels).tolist() + label_to_positions = {} + for label in positive_label_list: + label_to_positions[label] = phrases[label]['tokens_positive'] + + results['gt_bboxes'] = gt_bboxes + results['gt_bboxes_labels'] = gt_labels + + results['text'] = text + results['tokens_positive'] = label_to_positions + return results + + def od_aug(self, results): + gt_bboxes = results['gt_bboxes'] + if isinstance(gt_bboxes, BaseBoxes): + gt_bboxes = gt_bboxes.tensor + gt_labels = results['gt_bboxes_labels'] + + if 'text' not in results: + assert self.label_map is not None + text = self.label_map + else: + text = results['text'] + + original_box_num = len(gt_labels) + # If the category name is in the format of 'a/b' (in object365), + # we randomly select one of them. + for key, value in text.items(): + if '/' in value: + text[key] = random.choice(value.split('/')).strip() + + gt_bboxes, gt_labels, positive_caption_length = \ + check_for_positive_overflow(gt_bboxes, gt_labels, + text, self.tokenizer, self.max_tokens) + + if len(gt_bboxes) < original_box_num: + print('WARNING: removed {} boxes due to positive caption overflow'. + format(original_box_num - len(gt_bboxes))) + + valid_negative_indexes = list(text.keys()) + + positive_label_list = np.unique(gt_labels).tolist() + full_negative = self.num_sample_negative + + if full_negative > len(valid_negative_indexes): + full_negative = len(valid_negative_indexes) + + outer_prob = random.random() + + if outer_prob < self.full_sampling_prob: + # c. probability_full: add both all positive and all negatives + num_negatives = full_negative + else: + if random.random() < 1.0: + num_negatives = np.random.choice(max(1, full_negative)) + 1 + else: + num_negatives = full_negative + + # Keep some negatives + negative_label_list = set() + if num_negatives != -1: + if num_negatives > len(valid_negative_indexes): + num_negatives = len(valid_negative_indexes) + + for i in np.random.choice( + valid_negative_indexes, size=num_negatives, replace=False): + if i not in positive_label_list: + negative_label_list.add(i) + + random.shuffle(positive_label_list) + + negative_label_list = list(negative_label_list) + random.shuffle(negative_label_list) + + negative_max_length = self.max_tokens - positive_caption_length + screened_negative_label_list = [] + + for negative_label in negative_label_list: + label_text = clean_name(text[str(negative_label)]) + '. 
' + + tokenized = self.tokenizer.tokenize(label_text) + + negative_max_length -= len(tokenized) + + if negative_max_length > 0: + screened_negative_label_list.append(negative_label) + else: + break + negative_label_list = screened_negative_label_list + label_to_positions, pheso_caption, label_remap_dict = \ + generate_senetence_given_labels(positive_label_list, + negative_label_list, text) + + # label remap + if len(gt_labels) > 0: + gt_labels = np.vectorize(lambda x: label_remap_dict[x])(gt_labels) + + results['gt_bboxes'] = gt_bboxes + results['gt_bboxes_labels'] = gt_labels + + results['text'] = pheso_caption + results['tokens_positive'] = label_to_positions + + return results From 0c5238c8349f9fffbd2fa37ea944a29e3c3bce09 Mon Sep 17 00:00:00 2001 From: huanghaian Date: Tue, 2 Jan 2024 16:59:28 +0800 Subject: [PATCH 02/24] update code --- .../models/dense_heads/grounding_dino_head.py | 103 ++++++++++++------ mmdet/models/detectors/grounding_dino.py | 17 +-- mmdet/models/language_models/bert.py | 21 ++-- .../transformer/grounding_dino_layers.py | 2 +- .../task_modules/assigners/match_cost.py | 27 ++++- projects/mm_gdino_clip/batch_sampler.py | 16 ++- ...nding_dino_swin-t_pretrain_obj365_goldg.py | 89 +++++++-------- projects/mm_gdino_clip/grounding_dino.py | 30 +++-- projects/mm_gdino_clip/text_transformers.py | 15 +-- 9 files changed, 197 insertions(+), 123 deletions(-) diff --git a/mmdet/models/dense_heads/grounding_dino_head.py b/mmdet/models/dense_heads/grounding_dino_head.py index 8088322546f..33912a684b6 100644 --- a/mmdet/models/dense_heads/grounding_dino_head.py +++ b/mmdet/models/dense_heads/grounding_dino_head.py @@ -18,6 +18,7 @@ from ..layers import inverse_sigmoid from .atss_vlfusion_head import convert_grounding_to_cls_scores from .dino_head import DINOHead +import torch.nn.functional as F class ContrastiveEmbed(nn.Module): @@ -60,7 +61,7 @@ def __init__(self, torch.Tensor([bias_value]), requires_grad=True) def forward(self, visual_feat: Tensor, text_feat: Tensor, - text_token_mask: Tensor) -> Tensor: + text_token_mask: Tensor, need_expand=True) -> Tensor: """Forward function. Args: @@ -79,13 +80,14 @@ def forward(self, visual_feat: Tensor, text_feat: Tensor, res = res / math.sqrt(visual_feat.shape[-1]) if self.bias is not None: res = res + self.bias - res.masked_fill_(~text_token_mask[:, None, :], float('-inf')) - - new_res = torch.full((*res.shape[:-1], self.max_text_len), - float('-inf'), - device=res.device) - new_res[..., :res.shape[-1]] = res - + if need_expand: + res.masked_fill_(~text_token_mask[:, None, :], float('-inf')) + new_res = torch.full((*res.shape[:-1], self.max_text_len), + float('-inf'), + device=res.device) + new_res[..., :res.shape[-1]] = res + else: + new_res = res return new_res @@ -190,10 +192,16 @@ def _get_targets_single(self, cls_score: Tensor, bbox_pred: Tensor, # Major changes. The labels are 0-1 binary labels for each bbox # and text tokens. 
- labels = gt_bboxes.new_full((num_bboxes, self.max_text_len), - 0, - dtype=torch.float32) - labels[pos_inds] = gt_instances.positive_maps[pos_assigned_gt_inds] + if 'positive_maps' in gt_instances: + labels = gt_bboxes.new_full((num_bboxes, self.max_text_len), + 0, + dtype=torch.float32) + labels[pos_inds] = gt_instances.positive_maps[pos_assigned_gt_inds] + else: + labels = gt_bboxes.new_full((num_bboxes,), + cls_score.size(1), + dtype=torch.long) + labels[pos_inds] = gt_instances.labels[pos_assigned_gt_inds] label_weights = gt_bboxes.new_ones(num_bboxes) # bbox targets @@ -211,11 +219,12 @@ def _get_targets_single(self, cls_score: Tensor, bbox_pred: Tensor, neg_inds) def forward( - self, - hidden_states: Tensor, - references: List[Tensor], - memory_text: Tensor, - text_token_mask: Tensor, + self, + hidden_states: Tensor, + references: List[Tensor], + memory_text: Tensor, + text_token_mask: Tensor, + need_expand=True ) -> Tuple[Tensor]: """Forward function. @@ -257,7 +266,7 @@ def forward( hidden_state = hidden_states[layer_id] outputs_class = self.cls_branches[layer_id](hidden_state, memory_text, - text_token_mask) + text_token_mask, need_expand) tmp_reg_preds = self.reg_branches[layer_id](hidden_state) if reference.shape[-1] == 4: # When `layer` is 0 and `as_two_stage` of the detector @@ -492,7 +501,12 @@ def loss(self, hidden_states: Tensor, references: List[Tensor], batch_img_metas.append(data_sample.metainfo) batch_gt_instances.append(data_sample.gt_instances) - outs = self(hidden_states, references, memory_text, text_token_mask) + if 'tokens_positive' in batch_data_samples[0]: + need_expand = True + else: + need_expand = False + + outs = self(hidden_states, references, memory_text, text_token_mask, need_expand) self.text_masks = text_token_mask loss_inputs = outs + (enc_outputs_class, enc_outputs_coord, batch_gt_instances, batch_img_metas, dn_meta) @@ -539,22 +553,28 @@ def loss_by_feat_single(self, cls_scores: Tensor, bbox_preds: Tensor, # ===== this change ===== # Loss is not computed for the padded regions of the text. 
assert (self.text_masks.dim() == 2) - text_masks = self.text_masks.new_zeros( - (self.text_masks.size(0), self.max_text_len)) - text_masks[:, :self.text_masks.size(1)] = self.text_masks + if 'positive_maps' in batch_gt_instances[0]: + text_masks = self.text_masks.new_zeros( + (self.text_masks.size(0), self.max_text_len)) + text_masks[:, :self.text_masks.size(1)] = self.text_masks + else: + text_masks = self.text_masks + num_classes = cls_scores.size(-1) + labels = F.one_hot(labels, num_classes=num_classes + 1) + labels = labels[..., :num_classes] text_mask = (text_masks > 0).unsqueeze(1) text_mask = text_mask.repeat(1, cls_scores.size(1), 1) cls_scores = torch.masked_select(cls_scores, text_mask).contiguous() labels = torch.masked_select(labels, text_mask) label_weights = label_weights[..., - None].repeat(1, 1, text_mask.size(-1)) + None].repeat(1, 1, text_mask.size(-1)) label_weights = torch.masked_select(label_weights, text_mask) # classification loss # construct weighted avg_factor to match with the official DETR repo cls_avg_factor = num_total_pos * 1.0 + \ - num_total_neg * self.bg_cls_weight + num_total_neg * self.bg_cls_weight if self.sync_cls_avg_factor: cls_avg_factor = reduce_mean( cls_scores.new_tensor([cls_avg_factor])) @@ -578,7 +598,7 @@ def loss_by_feat_single(self, cls_scores: Tensor, bbox_preds: Tensor, img_h, img_w, = img_meta['img_shape'] factor = bbox_pred.new_tensor([img_w, img_h, img_w, img_h]).unsqueeze(0).repeat( - bbox_pred.size(0), 1) + bbox_pred.size(0), 1) factors.append(factor) factors = torch.cat(factors, 0) @@ -637,15 +657,23 @@ def _loss_dn_single(self, dn_cls_scores: Tensor, dn_bbox_preds: Tensor, # ===== this change ===== # Loss is not computed for the padded regions of the text. assert (self.text_masks.dim() == 2) - text_masks = self.text_masks.new_zeros( - (self.text_masks.size(0), self.max_text_len)) - text_masks[:, :self.text_masks.size(1)] = self.text_masks + if 'positive_maps' in batch_gt_instances[0]: + text_masks = self.text_masks.new_zeros( + (self.text_masks.size(0), self.max_text_len)) + text_masks[:, :self.text_masks.size(1)] = self.text_masks + else: + text_masks = self.text_masks + num_classes = dn_cls_scores.size(-1) + # 临时方案,由于 _get_dn_targets_single 获取不到 dn_cls_scores + labels[labels == self.max_text_len] = num_classes + labels = F.one_hot(labels, num_classes=num_classes + 1) + labels = labels[..., :num_classes] text_mask = (text_masks > 0).unsqueeze(1) text_mask = text_mask.repeat(1, dn_cls_scores.size(1), 1) cls_scores = torch.masked_select(dn_cls_scores, text_mask).contiguous() + labels = torch.masked_select(labels, text_mask) - label_weights = label_weights[..., - None].repeat(1, 1, text_mask.size(-1)) + label_weights = label_weights[..., None].repeat(1, 1, text_mask.size(-1)) label_weights = torch.masked_select(label_weights, text_mask) # ======================= @@ -749,10 +777,17 @@ def _get_dn_targets_single(self, gt_instances: InstanceData, neg_inds = pos_inds + num_queries_each_group // 2 # label targets # this change - labels = gt_bboxes.new_full((num_denoising_queries, self.max_text_len), - 0, - dtype=torch.float32) - labels[pos_inds] = gt_instances.positive_maps[pos_assigned_gt_inds] + + if 'positive_maps' in gt_instances: + labels = gt_bboxes.new_full((num_denoising_queries, self.max_text_len), + 0, + dtype=torch.float32) + labels[pos_inds] = gt_instances.positive_maps[pos_assigned_gt_inds] + else: + labels = gt_bboxes.new_full((num_denoising_queries,), + self.max_text_len, + dtype=torch.long) + labels[pos_inds] = 
gt_labels[pos_assigned_gt_inds] label_weights = gt_bboxes.new_ones(num_denoising_queries) # bbox targets diff --git a/mmdet/models/detectors/grounding_dino.py b/mmdet/models/detectors/grounding_dino.py index 4ec9d14e634..f7866364db4 100644 --- a/mmdet/models/detectors/grounding_dino.py +++ b/mmdet/models/detectors/grounding_dino.py @@ -329,8 +329,8 @@ def forward_encoder(self, feat: Tensor, feat_mask: Tensor, # for text encoder memory_text=text_dict['embedded'], text_attention_mask=~text_token_mask, - position_ids=text_dict['position_ids'], - text_self_attention_masks=text_dict['masks']) + position_ids=text_dict.get('position_ids', None), + text_self_attention_masks=text_dict.get('masks', None)) encoder_outputs_dict = dict( memory=memory, memory_mask=feat_mask, @@ -353,13 +353,14 @@ def pre_decoder( output_memory, output_proposals = self.gen_encoder_output_proposals( memory, memory_mask, spatial_shapes) + if 'tokens_positive' in batch_data_samples[0]: + need_expand = True + else: + need_expand = False enc_outputs_class = self.bbox_head.cls_branches[ - self.decoder.num_layers](output_memory, memory_text, - text_token_mask) - cls_out_features = self.bbox_head.cls_branches[ - self.decoder.num_layers].max_text_len + self.decoder.num_layers](output_memory, memory_text, text_token_mask, need_expand) enc_outputs_coord_unact = self.bbox_head.reg_branches[ - self.decoder.num_layers](output_memory) + output_proposals + self.decoder.num_layers](output_memory) + output_proposals # NOTE The DINO selects top-k proposals according to scores of # multi-class classification, while DeformDETR, where the input @@ -370,7 +371,7 @@ def pre_decoder( topk_score = torch.gather( enc_outputs_class, 1, - topk_indices.unsqueeze(-1).repeat(1, 1, cls_out_features)) + topk_indices.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1])) topk_coords_unact = torch.gather( enc_outputs_coord_unact, 1, topk_indices.unsqueeze(-1).repeat(1, 1, 4)) diff --git a/mmdet/models/language_models/bert.py b/mmdet/models/language_models/bert.py index f0ce94f1524..b6850a7cfee 100644 --- a/mmdet/models/language_models/bert.py +++ b/mmdet/models/language_models/bert.py @@ -71,15 +71,17 @@ def generate_masks_with_special_tokens_and_transfer_map( return attention_mask, position_ids.to(torch.long) + def split_tensor(tensor, num_levels): level_targets = [] start = 0 for n in num_levels: end = start + n - level_targets.append(target[:, start:end]) + level_targets.append(tensor[start:end]) start = end return level_targets + @MODELS.register_module() class BertModel(BaseModel): """BERT model for language embedding only encoder. 
@@ -147,7 +149,7 @@ def forward(self, captions: Sequence[str], task='VG', **kwargs) -> dict: """Forward function.""" device = next(self.language_backbone.parameters()).device - if task == 'OD': + if task == 'REC': batch_len_captions = [len(item) for item in captions] captions = [item for sublist in captions for item in sublist] @@ -185,21 +187,16 @@ def forward(self, captions: Sequence[str], task='VG', **kwargs) -> dict: embedded = language_dict_features['embedded'] embedded = embedded[torch.arange(embedded.shape[0]), end_token_idx] - batch_embedded = [] - batch_mask=[] - batch_text_token_mask=[] - embedded = split_tensor(embedded, batch_len_captions) embedded = align_tensor(embedded) - attention_mask = split_tensor(tokenized.attention_mask.bool(), batch_len_captions) + attention_mask = split_tensor(embedded.new_ones((len(tokenized.attention_mask))).bool(), batch_len_captions) attention_mask = align_tensor(attention_mask) - mask = split_tensor(language_dict_features['mask'], batch_len_captions) - mask = align_tensor(mask) - + # mask = split_tensor(language_dict_features['masks'], batch_len_captions) + # mask = align_tensor(mask) + del language_dict_features['masks'] + del language_dict_features['hidden'] language_dict_features['embedded'] = embedded - language_dict_features['hidden'] = embedded language_dict_features['text_token_mask'] = attention_mask - language_dict_features['mask'] = mask return language_dict_features diff --git a/mmdet/models/layers/transformer/grounding_dino_layers.py b/mmdet/models/layers/transformer/grounding_dino_layers.py index 3c285768f36..065cae318f9 100644 --- a/mmdet/models/layers/transformer/grounding_dino_layers.py +++ b/mmdet/models/layers/transformer/grounding_dino_layers.py @@ -240,7 +240,7 @@ def forward(self, query=memory_text, query_pos=(pos_text if pos_text is not None else None), attn_mask=~text_self_attention_masks.repeat( - text_num_heads, 1, 1), # note we use ~ for mask here + text_num_heads, 1, 1) if text_self_attention_masks is not None else None, # note we use ~ for mask here key_padding_mask=None, ) output = layer( diff --git a/mmdet/models/task_modules/assigners/match_cost.py b/mmdet/models/task_modules/assigners/match_cost.py index 5fc62f01f29..05586d110e5 100644 --- a/mmdet/models/task_modules/assigners/match_cost.py +++ b/mmdet/models/task_modules/assigners/match_cost.py @@ -334,6 +334,25 @@ def __call__(self, @TASK_UTILS.register_module() class BinaryFocalLossCost(FocalLossCost): + def _default_focal_loss_cost(self, cls_pred: Tensor, gt_labels: Tensor) -> Tensor: + """ + Args: + cls_pred (Tensor): Predicted classification logits, shape + (num_queries, num_class). + gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,). 
+ + Returns: + torch.Tensor: cls_cost value with weight + """ + cls_pred = cls_pred.sigmoid() + neg_cost = -(1 - cls_pred + self.eps).log() * ( + 1 - self.alpha) * cls_pred.pow(self.gamma) + pos_cost = -(cls_pred + self.eps).log() * self.alpha * ( + 1 - cls_pred).pow(self.gamma) + + cls_cost = pos_cost[:, gt_labels] - neg_cost[:, gt_labels] + return cls_cost * self.weight + def _focal_loss_cost(self, cls_pred: Tensor, gt_labels: Tensor) -> Tensor: """ Args: @@ -378,8 +397,12 @@ def __call__(self, text_token_mask = torch.nonzero( gt_instances.text_token_mask[0]).squeeze(-1) pred_scores = pred_instances.scores[:, text_token_mask] - gt_labels = gt_instances.positive_maps[:, text_token_mask] - return self._focal_loss_cost(pred_scores, gt_labels) + if 'positive_maps' in gt_instances: + gt_labels = gt_instances.positive_maps[:, text_token_mask] + return self._focal_loss_cost(pred_scores, gt_labels) + else: + gt_labels = gt_instances.labels + return self._default_focal_loss_cost(pred_scores, gt_labels) @TASK_UTILS.register_module() diff --git a/projects/mm_gdino_clip/batch_sampler.py b/projects/mm_gdino_clip/batch_sampler.py index 510d2196556..b433951fdcd 100644 --- a/projects/mm_gdino_clip/batch_sampler.py +++ b/projects/mm_gdino_clip/batch_sampler.py @@ -3,6 +3,7 @@ from torch.utils.data import BatchSampler, Sampler from mmdet.registry import DATA_SAMPLERS +import numpy as np @DATA_SAMPLERS.register_module() @@ -10,7 +11,8 @@ class MultiTaskAspectRatioBatchSampler(BatchSampler): def __init__(self, sampler: Sampler, batch_size: int, - drop_last: bool = True) -> None: + drop_last: bool = True, + od_to_rec_prob=0.7) -> None: if not isinstance(sampler, Sampler): raise TypeError('sampler should be an instance of ``Sampler``, ' f'but got {sampler}') @@ -22,15 +24,21 @@ def __init__(self, self.drop_last = drop_last # two groups for w < h and w >= h and two task self._aspect_ratio_buckets = [[] for _ in range(2 * 2)] + self.od_to_rec_prob = od_to_rec_prob def __iter__(self) -> Sequence[int]: for idx in self.sampler: data_info = self.sampler.dataset.get_data_info(idx) width, height = data_info['width'], data_info['height'] bucket_id = 0 if width < height else 1 - # REC and OVD: 0 2 - # VG: 1 3 - if data_info['dataset_mode'] in ['REC', 'OVD']: + + if data_info['dataset_mode'] == 'OD': + if np.random.random() > 1-self.od_to_rec_prob: + data_info['dataset_mode'] = 'REC' + + # REC: 0 2 + # VG and OD: 1 3 + if data_info['dataset_mode'] == 'REC': bucket_id = bucket_id * 2 else: bucket_id = bucket_id * 2 + 1 diff --git a/projects/mm_gdino_clip/configs/grounding_dino_swin-t_pretrain_obj365_goldg.py b/projects/mm_gdino_clip/configs/grounding_dino_swin-t_pretrain_obj365_goldg.py index 33e4a28cd2f..4013dea0adb 100644 --- a/projects/mm_gdino_clip/configs/grounding_dino_swin-t_pretrain_obj365_goldg.py +++ b/projects/mm_gdino_clip/configs/grounding_dino_swin-t_pretrain_obj365_goldg.py @@ -28,6 +28,7 @@ use_sub_sentence_represent=True, special_tokens_list=['[CLS]', '[SEP]', '.', '?'], add_pooling_layer=False, + use_checkpoint=True, # change this ), backbone=dict( type='SwinTransformer', @@ -125,42 +126,42 @@ dict(type='LoadImageFromFile', backend_args=_base_.backend_args), dict(type='LoadAnnotations', with_bbox=True), dict(type='RandomFlip', prob=0.5), - dict( - type='FixScaleResize', - scale=(400, 1340033), - keep_ratio=True, - backend='pillow'), # dict( - # type='RandomChoice', - # transforms=[ - # [ - # dict( - # type='RandomChoiceResize', - # scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), - # (608, 
1333), (640, 1333), (672, 1333), (704, 1333), - # (736, 1333), (768, 1333), (800, 1333)], - # keep_ratio=True) - # ], - # [ - # dict( - # type='RandomChoiceResize', - # # The radio of all image in train dataset < 7 - # # follow the original implement - # scales=[(400, 4200), (500, 4200), (600, 4200)], - # keep_ratio=True), - # dict( - # type='RandomCrop', - # crop_type='absolute_range', - # crop_size=(384, 600), - # allow_negative_crop=True), - # dict( - # type='RandomChoiceResize', - # scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), - # (608, 1333), (640, 1333), (672, 1333), (704, 1333), - # (736, 1333), (768, 1333), (800, 1333)], - # keep_ratio=True) - # ] - # ]), + # type='FixScaleResize', + # scale=(400, 400), + # keep_ratio=True, + # backend='pillow'), + dict( + type='RandomChoice', + transforms=[ + [ + dict( + type='RandomChoiceResize', + scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), + (608, 1333), (640, 1333), (672, 1333), (704, 1333), + (736, 1333), (768, 1333), (800, 1333)], + keep_ratio=True) + ], + [ + dict( + type='RandomChoiceResize', + # The radio of all image in train dataset < 7 + # follow the original implement + scales=[(400, 4200), (500, 4200), (600, 4200)], + keep_ratio=True), + dict( + type='RandomCrop', + crop_type='absolute_range', + crop_size=(384, 600), + allow_negative_crop=True), + dict( + type='RandomChoiceResize', + scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), + (608, 1333), (640, 1333), (672, 1333), (704, 1333), + (736, 1333), (768, 1333), (800, 1333)], + keep_ratio=True) + ] + ]), dict(type='FilterAnnotations', min_gt_bbox_wh=(1e-2, 1e-2)), dict( type='RandomSamplingNegPosV2', @@ -180,7 +181,7 @@ imdecode_backend='pillow'), dict( type='FixScaleResize', - scale=(400, 400), + scale=(1333, 800), keep_ratio=True, backend='pillow'), dict(type='LoadAnnotations', with_bbox=True), @@ -193,7 +194,7 @@ dataset_type = 'ODVGRECDataset' -o365_data_root = '/home/PJLAB/huanghaian/dataset/grounding/obj365v1_200/' +o365_data_root = 'obj365v1_200/' obj365_od_dataset = dict( type=dataset_type, data_root=o365_data_root, @@ -205,7 +206,7 @@ return_classes=True, backend_args=None) -rec_data_root = '/home/PJLAB/huanghaian/dataset/coco2014/' +rec_data_root = 'data/coco/' rec_rec_dataset = dict( type=dataset_type, data_root=rec_data_root, @@ -216,7 +217,7 @@ return_classes=True, backend_args=None) -flickr30k_vg_data_root = '/home/PJLAB/huanghaian/dataset/grounding/flickr30k_200/' +flickr30k_vg_data_root = 'flickr30k_200/' flickr30k_vg_dataset = dict( type=dataset_type, data_root=flickr30k_vg_data_root, @@ -229,12 +230,12 @@ train_dataloader = dict( _delete_=True, - batch_size=2, - num_workers=0, - persistent_workers=False, + batch_size=4, + num_workers=2, + persistent_workers=True, sampler=dict(type='DefaultSampler', shuffle=True), - batch_sampler=dict(type='MultiTaskAspectRatioBatchSampler'), - dataset=dict(type='ConcatDataset', datasets=[obj365_od_dataset, rec_rec_dataset, flickr30k_vg_dataset])) + batch_sampler=dict(type='MultiTaskAspectRatioBatchSampler', od_to_rec_prob=0.7), + dataset=dict(type='ConcatDataset', datasets=[obj365_od_dataset, flickr30k_vg_dataset])) val_dataloader = dict( dataset=dict(pipeline=test_pipeline, return_classes=True)) diff --git a/projects/mm_gdino_clip/grounding_dino.py b/projects/mm_gdino_clip/grounding_dino.py index f8b28f6afd9..0b7512a15fc 100644 --- a/projects/mm_gdino_clip/grounding_dino.py +++ b/projects/mm_gdino_clip/grounding_dino.py @@ -13,34 +13,42 @@ from mmdet.utils import ConfigType from 
mmdet.models.detectors import GroundingDINO -task_map={'OD': 0, 'REC': 0, 'VG': 1} +task_map = {'REC': 0, 'VG': 1} + @MODELS.register_module() class GroundingDINOV2(GroundingDINO): def loss(self, batch_inputs: Tensor, batch_data_samples: SampleList) -> Union[dict, list]: - tasks=[data_samples.dataset_mode for data_samples in batch_data_samples] - tasks= [task_map[task] for task in tasks] + tasks = [data_samples.dataset_mode for data_samples in batch_data_samples] + tasks = [task_map[task] for task in tasks] assert len(set(tasks)) == 1, 'Only support one task in one batch, but got {}'.format(tasks) - if tasks[0]==1: + if tasks[0] == 1: # VG return super().loss(batch_inputs, batch_data_samples) else: - # OD=REC + # REC text_prompts = [ data_samples.text for data_samples in batch_data_samples ] - gt_labels = [ - data_samples.gt_instances.labels - for data_samples in batch_data_samples - ] - - text_dict = self.language_model(text_prompts, task='OD') + text_dict = self.language_model(text_prompts, task='REC') if self.text_feat_map is not None: text_dict['embedded'] = self.text_feat_map(text_dict['embedded']) + for i, data_samples in enumerate(batch_data_samples): + # for calc BinaryFocalLossCost + text_token_mask = text_dict['text_token_mask'][i] + data_samples.gt_instances.text_token_mask = \ + text_token_mask.unsqueeze(0).repeat( + len(data_samples.gt_instances), 1) + visual_features = self.extract_feat(batch_inputs) + head_inputs_dict = self.forward_transformer(visual_features, text_dict, + batch_data_samples) + losses = self.bbox_head.loss( + **head_inputs_dict, batch_data_samples=batch_data_samples) + return losses diff --git a/projects/mm_gdino_clip/text_transformers.py b/projects/mm_gdino_clip/text_transformers.py index c63df796501..b2862241bbd 100644 --- a/projects/mm_gdino_clip/text_transformers.py +++ b/projects/mm_gdino_clip/text_transformers.py @@ -119,10 +119,8 @@ def __init__(self, def transform(self, results: dict) -> dict: dataset_mode = results['dataset_mode'] if dataset_mode == 'OD': - if np.random.rand() > 0.3: - return self.rec_aug(results) - else: - return self.od_aug(results) + results['dataset_mode'] = 'VG' + return self.od_aug(results) elif dataset_mode == 'VG': return self.vg_aug(results) else: @@ -140,11 +138,11 @@ def rec_aug(self, results): else: text = results['text'] - if results['dataset_mode'] == 'REC': - assert 'image_to_exp' in results + if 'image_to_exp' in results: # REC keys = list(results['image_to_exp'].keys()) positive_label_list = np.unique(gt_labels).tolist() + # 85 有点大,会消耗比较多显存,稍微改小点 full_negative = self.num_sample_negative if full_negative > len(keys): @@ -205,7 +203,7 @@ def rec_aug(self, results): results['gt_bboxes'] = gt_bboxes results['gt_bboxes_labels'] = gt_labels results['text'] = new_text - else: + else: # OD valid_negative_indexes = list(text.keys()) positive_label_list = np.unique(gt_labels).tolist() @@ -257,6 +255,9 @@ def rec_aug(self, results): results['text'] = [text[str(l)] for l in label_list] results['dataset_mode'] = 'REC' + if 'tokens_positive' in results: + del results['tokens_positive'] + return results def vg_aug(self, results): From 482605e7c0bd90f3e385bcc916034a5e01145d7a Mon Sep 17 00:00:00 2001 From: huanghaian Date: Tue, 2 Jan 2024 19:01:37 +0800 Subject: [PATCH 03/24] update eval code --- .../models/dense_heads/grounding_dino_head.py | 19 +++++++----- projects/mm_gdino_clip/grounding_dino.py | 29 +++++++++++++++++++ 2 files changed, 41 insertions(+), 7 deletions(-) diff --git 
a/mmdet/models/dense_heads/grounding_dino_head.py b/mmdet/models/dense_heads/grounding_dino_head.py index 33912a684b6..570a2f26c75 100644 --- a/mmdet/models/dense_heads/grounding_dino_head.py +++ b/mmdet/models/dense_heads/grounding_dino_head.py @@ -328,12 +328,17 @@ def predict(self, batch_img_metas = [ data_samples.metainfo for data_samples in batch_data_samples ] - batch_token_positive_maps = [ - data_samples.token_positive_map - for data_samples in batch_data_samples - ] - outs = self(hidden_states, references, memory_text, text_token_mask) + need_expand = True + batch_token_positive_maps = [] + for data_samples in batch_data_samples: + if 'token_positive_map' in data_samples: + batch_token_positive_maps.append(data_samples.token_positive_map) + else: + batch_token_positive_maps.append(None) + need_expand = False + + outs = self(hidden_states, references, memory_text, text_token_mask, need_expand=False) predictions = self.predict_by_feat( *outs, @@ -710,7 +715,7 @@ def _loss_dn_single(self, dn_cls_scores: Tensor, dn_bbox_preds: Tensor, img_h, img_w = img_meta['img_shape'] factor = bbox_pred.new_tensor([img_w, img_h, img_w, img_h]).unsqueeze(0).repeat( - bbox_pred.size(0), 1) + bbox_pred.size(0), 1) factors.append(factor) factors = torch.cat(factors) @@ -732,7 +737,7 @@ def _loss_dn_single(self, dn_cls_scores: Tensor, dn_bbox_preds: Tensor, def _get_dn_targets_single(self, gt_instances: InstanceData, img_meta: dict, dn_meta: Dict[str, - int]) -> tuple: + int]) -> tuple: """Get targets in denoising part for one image. Args: diff --git a/projects/mm_gdino_clip/grounding_dino.py b/projects/mm_gdino_clip/grounding_dino.py index 0b7512a15fc..f42354e0514 100644 --- a/projects/mm_gdino_clip/grounding_dino.py +++ b/projects/mm_gdino_clip/grounding_dino.py @@ -52,3 +52,32 @@ def loss(self, batch_inputs: Tensor, losses = self.bbox_head.loss( **head_inputs_dict, batch_data_samples=batch_data_samples) return losses + + def predict(self, batch_inputs, batch_data_samples, rescale: bool = True): + # only od eval for now + text_prompts = [data_samples.text for data_samples in batch_data_samples] + text_prompts = text_prompts[0] + + visual_feats = self.extract_feat(batch_inputs) + + text_dict = self.language_model([text_prompts], task='REC') + if self.text_feat_map is not None: + text_dict['embedded'] = self.text_feat_map( + text_dict['embedded']) + head_inputs_dict = self.forward_transformer( + visual_feats, text_dict, batch_data_samples) + results_list = self.bbox_head.predict( + **head_inputs_dict, + rescale=rescale, + batch_data_samples=batch_data_samples) + + for data_sample, pred_instances in zip( + batch_data_samples, results_list): + if len(pred_instances) > 0: + label_names = [] + for labels in pred_instances.labels: + label_names.append(text_prompts[labels]) + # for visualization + pred_instances.label_names = label_names + data_sample.pred_instances = pred_instances + return batch_data_samples From 449d5acd6c40486b32d3f341307d957eb5042fb8 Mon Sep 17 00:00:00 2001 From: huanghaian Date: Wed, 3 Jan 2024 10:14:12 +0800 Subject: [PATCH 04/24] fix batch sampler --- projects/mm_gdino_clip/batch_sampler.py | 40 ++++++++++++++++++------- 1 file changed, 29 insertions(+), 11 deletions(-) diff --git a/projects/mm_gdino_clip/batch_sampler.py b/projects/mm_gdino_clip/batch_sampler.py index b433951fdcd..2380408f0e6 100644 --- a/projects/mm_gdino_clip/batch_sampler.py +++ b/projects/mm_gdino_clip/batch_sampler.py @@ -25,6 +25,7 @@ def __init__(self, # two groups for w < h and w >= h and two task 
self._aspect_ratio_buckets = [[] for _ in range(2 * 2)] self.od_to_rec_prob = od_to_rec_prob + assert drop_last is True def __iter__(self) -> Sequence[int]: for idx in self.sampler: @@ -33,7 +34,7 @@ def __iter__(self) -> Sequence[int]: bucket_id = 0 if width < height else 1 if data_info['dataset_mode'] == 'OD': - if np.random.random() > 1-self.od_to_rec_prob: + if np.random.random() > 1 - self.od_to_rec_prob: data_info['dataset_mode'] = 'REC' # REC: 0 2 @@ -50,17 +51,34 @@ def __iter__(self) -> Sequence[int]: del bucket[:] # yield the rest data and reset the bucket - # left_data = self._aspect_ratio_buckets[0] + self._aspect_ratio_buckets[ - # 1] + self._aspect_ratio_buckets[2] + self._aspect_ratio_buckets[3] + left_rec_data = self._aspect_ratio_buckets[0] + self._aspect_ratio_buckets[2] + left_vg_data = self._aspect_ratio_buckets[1] + self._aspect_ratio_buckets[3] self._aspect_ratio_buckets = [[] for _ in range(2 * 2)] - # while len(left_data) > 0: - # if len(left_data) <= self.batch_size: - # if not self.drop_last: - # yield left_data[:] - # left_data = [] - # else: - # yield left_data[:self.batch_size] - # left_data = left_data[self.batch_size:] + + while len(left_rec_data) > 0: + if len(left_rec_data) > self.batch_size: + yield left_rec_data[:self.batch_size] + left_rec_data = left_rec_data[self.batch_size:] + else: + break + + while len(left_vg_data) > 0: + if len(left_vg_data) > self.batch_size: + yield left_vg_data[:self.batch_size] + left_vg_data = left_vg_data[self.batch_size:] + else: + break + + if 0 < len(left_rec_data) < self.batch_size: + left_rec_data.extend([left_rec_data[-1]] * (self.batch_size - len(left_rec_data))) + + if 0 < len(left_vg_data) < self.batch_size: + left_vg_data.extend([left_vg_data[-1]] * (self.batch_size - len(left_vg_data))) + + all_left_data = left_rec_data + left_vg_data + while len(all_left_data) > 0: + yield all_left_data[:self.batch_size] + all_left_data = all_left_data[self.batch_size:] def __len__(self) -> int: if self.drop_last: From e6d2da2e0d0e32a23cf274911d15d27bf43103e4 Mon Sep 17 00:00:00 2001 From: huanghaian Date: Wed, 3 Jan 2024 17:49:46 +0800 Subject: [PATCH 05/24] fix bug --- mmdet/models/dense_heads/grounding_dino_head.py | 3 ++- mmdet/models/detectors/grounding_dino.py | 2 +- projects/mm_gdino_clip/batch_sampler.py | 1 + 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/mmdet/models/dense_heads/grounding_dino_head.py b/mmdet/models/dense_heads/grounding_dino_head.py index 570a2f26c75..5b31d238866 100644 --- a/mmdet/models/dense_heads/grounding_dino_head.py +++ b/mmdet/models/dense_heads/grounding_dino_head.py @@ -87,6 +87,7 @@ def forward(self, visual_feat: Tensor, text_feat: Tensor, device=res.device) new_res[..., :res.shape[-1]] = res else: + res.masked_fill_(~text_token_mask[:, None, :], float('-inf')) new_res = res return new_res @@ -338,7 +339,7 @@ def predict(self, batch_token_positive_maps.append(None) need_expand = False - outs = self(hidden_states, references, memory_text, text_token_mask, need_expand=False) + outs = self(hidden_states, references, memory_text, text_token_mask, need_expand=need_expand) predictions = self.predict_by_feat( *outs, diff --git a/mmdet/models/detectors/grounding_dino.py b/mmdet/models/detectors/grounding_dino.py index f7866364db4..43af4169571 100644 --- a/mmdet/models/detectors/grounding_dino.py +++ b/mmdet/models/detectors/grounding_dino.py @@ -353,7 +353,7 @@ def pre_decoder( output_memory, output_proposals = self.gen_encoder_output_proposals( memory, memory_mask, spatial_shapes) - 
if 'tokens_positive' in batch_data_samples[0]: + if 'tokens_positive' in batch_data_samples[0] or 'token_positive_map' in batch_data_samples[0]: need_expand = True else: need_expand = False diff --git a/projects/mm_gdino_clip/batch_sampler.py b/projects/mm_gdino_clip/batch_sampler.py index 2380408f0e6..f8576e26501 100644 --- a/projects/mm_gdino_clip/batch_sampler.py +++ b/projects/mm_gdino_clip/batch_sampler.py @@ -33,6 +33,7 @@ def __iter__(self) -> Sequence[int]: width, height = data_info['width'], data_info['height'] bucket_id = 0 if width < height else 1 + # BUG if data_info['dataset_mode'] == 'OD': if np.random.random() > 1 - self.od_to_rec_prob: data_info['dataset_mode'] = 'REC' From 1a247381f7a25d31e57826c8973f44ac23ae412a Mon Sep 17 00:00:00 2001 From: huanghaian Date: Wed, 3 Jan 2024 18:22:45 +0800 Subject: [PATCH 06/24] fix bug --- .../configs/grounding_dino_swin-t_pretrain_obj365_goldg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/projects/mm_gdino_clip/configs/grounding_dino_swin-t_pretrain_obj365_goldg.py b/projects/mm_gdino_clip/configs/grounding_dino_swin-t_pretrain_obj365_goldg.py index 4013dea0adb..09ef25991ca 100644 --- a/projects/mm_gdino_clip/configs/grounding_dino_swin-t_pretrain_obj365_goldg.py +++ b/projects/mm_gdino_clip/configs/grounding_dino_swin-t_pretrain_obj365_goldg.py @@ -28,7 +28,7 @@ use_sub_sentence_represent=True, special_tokens_list=['[CLS]', '[SEP]', '.', '?'], add_pooling_layer=False, - use_checkpoint=True, # change this + use_checkpoint=False, # change this ), backbone=dict( type='SwinTransformer', @@ -181,7 +181,7 @@ imdecode_backend='pillow'), dict( type='FixScaleResize', - scale=(1333, 800), + scale=(800, 1333), keep_ratio=True, backend='pillow'), dict(type='LoadAnnotations', with_bbox=True), From 43383d31a4d724a7f7ba3c59b85df82484b629e2 Mon Sep 17 00:00:00 2001 From: huanghaian Date: Thu, 4 Jan 2024 09:34:15 +0800 Subject: [PATCH 07/24] fix bug --- projects/mm_gdino_clip/__init__.py | 3 +- projects/mm_gdino_clip/batch_sampler.py | 18 ++++++---- projects/mm_gdino_clip/concat_dataset.py | 34 +++++++++++++++++++ ...nding_dino_swin-t_pretrain_obj365_goldg.py | 2 +- projects/mm_gdino_clip/odvgrec.py | 7 ++++ 5 files changed, 55 insertions(+), 9 deletions(-) create mode 100644 projects/mm_gdino_clip/concat_dataset.py diff --git a/projects/mm_gdino_clip/__init__.py b/projects/mm_gdino_clip/__init__.py index cbf3a3d673a..4a57c00811f 100644 --- a/projects/mm_gdino_clip/__init__.py +++ b/projects/mm_gdino_clip/__init__.py @@ -2,5 +2,6 @@ from .text_transformers import RandomSamplingNegPosV2 from .batch_sampler import MultiTaskAspectRatioBatchSampler from .grounding_dino import GroundingDINOV2 +from .concat_dataset import CustomConcatDataset -__all__ = ['ODVGRECDataset', 'RandomSamplingNegPosV2', 'MultiTaskAspectRatioBatchSampler', 'GroundingDINOV2'] +__all__ = ['ODVGRECDataset', 'RandomSamplingNegPosV2', 'MultiTaskAspectRatioBatchSampler', 'GroundingDINOV2', 'CustomConcatDataset'] diff --git a/projects/mm_gdino_clip/batch_sampler.py b/projects/mm_gdino_clip/batch_sampler.py index f8576e26501..80ac8a83056 100644 --- a/projects/mm_gdino_clip/batch_sampler.py +++ b/projects/mm_gdino_clip/batch_sampler.py @@ -28,24 +28,28 @@ def __init__(self, assert drop_last is True def __iter__(self) -> Sequence[int]: + for idx in self.sampler: - data_info = self.sampler.dataset.get_data_info(idx) - width, height = data_info['width'], data_info['height'] + wh_mode = self.sampler.dataset.get_wh_mode(idx) + dataset_mode, height, width = wh_mode 
bucket_id = 0 if width < height else 1 - # BUG - if data_info['dataset_mode'] == 'OD': + od_to_rec_flag = False + if dataset_mode == 'OD': if np.random.random() > 1 - self.od_to_rec_prob: - data_info['dataset_mode'] = 'REC' + dataset_mode = 'REC' + od_to_rec_flag = True + else: + od_to_rec_flag = False # REC: 0 2 # VG and OD: 1 3 - if data_info['dataset_mode'] == 'REC': + if dataset_mode == 'REC': bucket_id = bucket_id * 2 else: bucket_id = bucket_id * 2 + 1 bucket = self._aspect_ratio_buckets[bucket_id] - bucket.append(idx) + bucket.append([idx, od_to_rec_flag]) # yield a batch of indices in the same aspect ratio group if len(bucket) == self.batch_size: yield bucket[:] diff --git a/projects/mm_gdino_clip/concat_dataset.py b/projects/mm_gdino_clip/concat_dataset.py new file mode 100644 index 00000000000..459cb868824 --- /dev/null +++ b/projects/mm_gdino_clip/concat_dataset.py @@ -0,0 +1,34 @@ +from mmdet.datasets import ConcatDataset as _ConcatDataset +from mmdet.registry import DATASETS +from mmengine.logging import print_log +import logging + + +@DATASETS.register_module() +class CustomConcatDataset(_ConcatDataset): + + def __getitem__(self, idx: list): + if not self._fully_initialized: + print_log( + 'Please call `full_init` method manually to ' + 'accelerate the speed.', + logger='current', + level=logging.WARNING) + self.full_init() + + od_to_rec_flag = idx[1] + + dataset_idx, sample_idx = self._get_ori_dataset_idx(idx[0]) + + if od_to_rec_flag: + for _ in range(30): + data_info = self.datasets[dataset_idx].get_data_info(sample_idx) + assert data_info['dataset_mode'] == 'OD' + data_info['dataset_mode'] = 'REC' + data = self.datasets[dataset_idx].pipeline(data_info) + if data is None: + sample_idx = self.datasets[dataset_idx]._rand_another() + continue + return data + else: + return self.datasets[dataset_idx][sample_idx] diff --git a/projects/mm_gdino_clip/configs/grounding_dino_swin-t_pretrain_obj365_goldg.py b/projects/mm_gdino_clip/configs/grounding_dino_swin-t_pretrain_obj365_goldg.py index 09ef25991ca..bec0933aea7 100644 --- a/projects/mm_gdino_clip/configs/grounding_dino_swin-t_pretrain_obj365_goldg.py +++ b/projects/mm_gdino_clip/configs/grounding_dino_swin-t_pretrain_obj365_goldg.py @@ -235,7 +235,7 @@ persistent_workers=True, sampler=dict(type='DefaultSampler', shuffle=True), batch_sampler=dict(type='MultiTaskAspectRatioBatchSampler', od_to_rec_prob=0.7), - dataset=dict(type='ConcatDataset', datasets=[obj365_od_dataset, flickr30k_vg_dataset])) + dataset=dict(type='CustomConcatDataset', datasets=[obj365_od_dataset])) val_dataloader = dict( dataset=dict(pipeline=test_pipeline, return_classes=True)) diff --git a/projects/mm_gdino_clip/odvgrec.py b/projects/mm_gdino_clip/odvgrec.py index 27b09bbbeb6..b6afec51cb8 100644 --- a/projects/mm_gdino_clip/odvgrec.py +++ b/projects/mm_gdino_clip/odvgrec.py @@ -31,6 +31,7 @@ def __init__(self, def load_data_list(self) -> List[dict]: self.image_to_exp = {} + self.wh_modes=[] with get_local_path( self.ann_file, backend_args=self.backend_args) as local_path: with open(local_path, 'r') as f: @@ -71,6 +72,7 @@ def load_data_list(self) -> List[dict]: instances.append(instance) data_info['instances'] = instances data_info['dataset_mode'] = self.dataset_mode + self.wh_modes.append([self.dataset_mode, data_info['height'], data_info['width']]) out_data_list.append(data_info) elif self.dataset_mode == 'REC': anno = data.get('referring', {}) @@ -104,6 +106,7 @@ def load_data_list(self) -> List[dict]: data_info['instances'] = instances 
data_info['dataset_mode'] = self.dataset_mode data_info['text'] = phrases + self.wh_modes.append([self.dataset_mode, data_info['height'], data_info['width']]) out_data_list.append(data_info) else: anno = data['grounding'] @@ -138,11 +141,15 @@ def load_data_list(self) -> List[dict]: data_info['instances'] = instances data_info['phrases'] = phrases data_info['dataset_mode'] = self.dataset_mode + self.wh_modes.append([self.dataset_mode, data_info['height'], data_info['width']]) out_data_list.append(data_info) del data_list return out_data_list + def get_wh_mode(self, idx): + return self.wh_modes[idx] + def prepare_data(self, idx: int): """Pass the dataset to the pipeline during training to support mixed data augmentation, such as Mosaic and MixUp.""" From 82da1c45dcebd577f40a7875926542183d76311d Mon Sep 17 00:00:00 2001 From: huanghaian Date: Thu, 4 Jan 2024 11:38:29 +0800 Subject: [PATCH 08/24] fix bug --- projects/mm_gdino_clip/concat_dataset.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/projects/mm_gdino_clip/concat_dataset.py b/projects/mm_gdino_clip/concat_dataset.py index 459cb868824..ddb27d98685 100644 --- a/projects/mm_gdino_clip/concat_dataset.py +++ b/projects/mm_gdino_clip/concat_dataset.py @@ -32,3 +32,7 @@ def __getitem__(self, idx: list): return data else: return self.datasets[dataset_idx][sample_idx] + + def get_wh_mode(self, idx): + dataset_idx, sample_idx = self._get_ori_dataset_idx(idx) + return self.datasets[dataset_idx].get_wh_mode(sample_idx) From 63150e6c52448f0d29bb15f6b864de317ae41ee2 Mon Sep 17 00:00:00 2001 From: huanghaian Date: Thu, 4 Jan 2024 16:09:28 +0800 Subject: [PATCH 09/24] fix bug --- .../layers/transformer/grounding_dino_layers.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/mmdet/models/layers/transformer/grounding_dino_layers.py b/mmdet/models/layers/transformer/grounding_dino_layers.py index 065cae318f9..e1344793c51 100644 --- a/mmdet/models/layers/transformer/grounding_dino_layers.py +++ b/mmdet/models/layers/transformer/grounding_dino_layers.py @@ -236,12 +236,19 @@ def forward(self, if self.text_layers: text_num_heads = self.text_layers[ layer_id].self_attn_cfg.num_heads + if text_self_attention_masks is None: + # rec + key_padding_mask = text_attention_mask + else: + # phrase grounding + key_padding_mask = None memory_text = self.text_layers[layer_id]( query=memory_text, query_pos=(pos_text if pos_text is not None else None), attn_mask=~text_self_attention_masks.repeat( - text_num_heads, 1, 1) if text_self_attention_masks is not None else None, # note we use ~ for mask here - key_padding_mask=None, + text_num_heads, 1, 1) if text_self_attention_masks is not None else None, + # note we use ~ for mask here + key_padding_mask=key_padding_mask, ) output = layer( query=output, From f6358c5b3f8f6e3f4dc44cfcad3bf451a3a885a0 Mon Sep 17 00:00:00 2001 From: huanghaian Date: Thu, 4 Jan 2024 19:22:13 +0800 Subject: [PATCH 10/24] fix inference bug --- mmdet/models/dense_heads/grounding_dino_head.py | 10 ++++++---- projects/mm_gdino_clip/batch_sampler.py | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/mmdet/models/dense_heads/grounding_dino_head.py b/mmdet/models/dense_heads/grounding_dino_head.py index 5b31d238866..e00e93467e2 100644 --- a/mmdet/models/dense_heads/grounding_dino_head.py +++ b/mmdet/models/dense_heads/grounding_dino_head.py @@ -442,11 +442,13 @@ def _predict_by_feat_single(self, bbox_index = indexes // num_classes bbox_pred = bbox_pred[bbox_index] else: + # TODO: REC cls_score = 
cls_score.sigmoid() - scores, _ = cls_score.max(-1) - scores, indexes = scores.topk(max_per_img) - bbox_pred = bbox_pred[indexes] - det_labels = scores.new_zeros(scores.shape, dtype=torch.long) + scores, indexes = cls_score.view(-1).topk(max_per_img) + num_classes = cls_score.shape[-1] + det_labels = indexes % num_classes + bbox_index = indexes // num_classes + bbox_pred = bbox_pred[bbox_index] det_bboxes = bbox_cxcywh_to_xyxy(bbox_pred) det_bboxes[:, 0::2] = det_bboxes[:, 0::2] * img_shape[1] diff --git a/projects/mm_gdino_clip/batch_sampler.py b/projects/mm_gdino_clip/batch_sampler.py index 80ac8a83056..320f41f28c3 100644 --- a/projects/mm_gdino_clip/batch_sampler.py +++ b/projects/mm_gdino_clip/batch_sampler.py @@ -36,7 +36,7 @@ def __iter__(self) -> Sequence[int]: od_to_rec_flag = False if dataset_mode == 'OD': - if np.random.random() > 1 - self.od_to_rec_prob: + if np.random.random() >= 1 - self.od_to_rec_prob: dataset_mode = 'REC' od_to_rec_flag = True else: From 167e73960207d4c5dd9e357df93e1e53cfe9ce1f Mon Sep 17 00:00:00 2001 From: huanghaian Date: Tue, 9 Jan 2024 17:32:27 +0800 Subject: [PATCH 11/24] update scripts --- mmdet/models/language_models/bert.py | 12 +- projects/mm_gdino_clip/batch_sampler.py | 7 +- .../mm_gdino_clip/browse_grounding_raw.py | 17 +- .../mm_gdino_clip/script/flickr30k2rec.py | 440 ++++++++++++++++++ projects/mm_gdino_clip/script/gqa2rec.py | 384 +++++++++++++++ .../mm_gdino_clip/{ => script}/refcoco2rec.py | 0 projects/mm_gdino_clip/script/utils/boxes.py | 85 ++++ projects/mm_gdino_clip/script/utils/dump.py | 104 +++++ projects/mm_gdino_clip/script/utils/spans.py | 235 ++++++++++ projects/mm_gdino_clip/script/utils/text.py | 135 ++++++ .../mm_gdino_clip/script/utils/unionfind.py | 31 ++ 11 files changed, 1436 insertions(+), 14 deletions(-) create mode 100644 projects/mm_gdino_clip/script/flickr30k2rec.py create mode 100644 projects/mm_gdino_clip/script/gqa2rec.py rename projects/mm_gdino_clip/{ => script}/refcoco2rec.py (100%) create mode 100644 projects/mm_gdino_clip/script/utils/boxes.py create mode 100644 projects/mm_gdino_clip/script/utils/dump.py create mode 100644 projects/mm_gdino_clip/script/utils/spans.py create mode 100644 projects/mm_gdino_clip/script/utils/text.py create mode 100644 projects/mm_gdino_clip/script/utils/unionfind.py diff --git a/mmdet/models/language_models/bert.py b/mmdet/models/language_models/bert.py index b6850a7cfee..0fec27acfb8 100644 --- a/mmdet/models/language_models/bert.py +++ b/mmdet/models/language_models/bert.py @@ -116,11 +116,13 @@ def __init__(self, add_pooling_layer: bool = False, num_layers_of_embedded: int = 1, use_checkpoint: bool = False, + reduce_type: str = 'avg', # avg start **kwargs) -> None: super().__init__(**kwargs) self.max_tokens = max_tokens self.pad_to_max = pad_to_max + self.reduce_type = reduce_type if AutoTokenizer is None: raise RuntimeError( @@ -183,16 +185,18 @@ def forward(self, captions: Sequence[str], task='VG', **kwargs) -> dict: language_dict_features[ 'text_token_mask'] = tokenized.attention_mask.bool() else: - end_token_idx = input_ids.argmin(dim=-1) - 1 embedded = language_dict_features['embedded'] - embedded = embedded[torch.arange(embedded.shape[0]), end_token_idx] + if self.reduce_type == 'start': + end_token_idx = 0 + embedded = embedded[torch.arange(embedded.shape[0]), end_token_idx] + else: + embedded = embedded * tokenized.attention_mask[..., None].float() + embedded = embedded.sum(1) / tokenized.attention_mask.float().sum(-1)[..., None] embedded = split_tensor(embedded, 
batch_len_captions) embedded = align_tensor(embedded) attention_mask = split_tensor(embedded.new_ones((len(tokenized.attention_mask))).bool(), batch_len_captions) attention_mask = align_tensor(attention_mask) - # mask = split_tensor(language_dict_features['masks'], batch_len_captions) - # mask = align_tensor(mask) del language_dict_features['masks'] del language_dict_features['hidden'] language_dict_features['embedded'] = embedded diff --git a/projects/mm_gdino_clip/batch_sampler.py b/projects/mm_gdino_clip/batch_sampler.py index 320f41f28c3..9391356a4e6 100644 --- a/projects/mm_gdino_clip/batch_sampler.py +++ b/projects/mm_gdino_clip/batch_sampler.py @@ -36,9 +36,10 @@ def __iter__(self) -> Sequence[int]: od_to_rec_flag = False if dataset_mode == 'OD': - if np.random.random() >= 1 - self.od_to_rec_prob: - dataset_mode = 'REC' - od_to_rec_flag = True + # TODO + # if np.random.random() >= 1 - self.od_to_rec_prob: + # dataset_mode = 'REC' + od_to_rec_flag = True else: od_to_rec_flag = False diff --git a/projects/mm_gdino_clip/browse_grounding_raw.py b/projects/mm_gdino_clip/browse_grounding_raw.py index 5961fe069b3..ab57e6a5d48 100644 --- a/projects/mm_gdino_clip/browse_grounding_raw.py +++ b/projects/mm_gdino_clip/browse_grounding_raw.py @@ -237,18 +237,21 @@ def main(): phrases = [] for i, ref in enumerate(referring): bbox = ref['bbox'] - phrases.append(ref['exp']) + if isinstance(ref['exp'], list): + phrases.append(' / '.join(ref['exp'])) + else: + phrases.append(ref['exp']) bbox = np.array(bbox).reshape(-1, 4) - visualizer.draw_bboxes( - bbox, - edge_colors=colors[i], - face_colors=colors[i], - alpha=0.3) + # visualizer.draw_bboxes( + # bbox, + # edge_colors=colors[i], + # face_colors=colors[i], + # alpha=0.3) visualizer.draw_bboxes(bbox, edge_colors=colors[i], alpha=1) drawn_img = visualizer.get_image() - new_image = np.ones((100, img.shape[1], 3), dtype=np.uint8) * 255 + new_image = np.ones((len(phrases) * 20 + 100, img.shape[1], 3), dtype=np.uint8) * 255 visualizer.set_image(new_image) start_index = 2 diff --git a/projects/mm_gdino_clip/script/flickr30k2rec.py b/projects/mm_gdino_clip/script/flickr30k2rec.py new file mode 100644 index 00000000000..f69179737d0 --- /dev/null +++ b/projects/mm_gdino_clip/script/flickr30k2rec.py @@ -0,0 +1,440 @@ +import argparse +import os +from collections import defaultdict +from pathlib import Path +from typing import Any, Dict, List, Tuple +from xml.etree.ElementTree import parse + +import numpy as np +import torch +import xmltodict # TODO +from tqdm import tqdm +from torchvision.ops.boxes import box_area, batched_nms +import jsonlines +import copy +import os.path as osp + +""" +data/flickr30k_entities + - annotations + - Annotations + - Sentences + - flickr30k_images + - train.txt +""" + + +def parse_args(): + parser = argparse.ArgumentParser("Conversion script") + + parser.add_argument( + "--flickr_path", + default='data/flickr30k_entities', + type=str, + help="Path to the flickr dataset", + ) + parser.add_argument( + "--out_path", + default="", + type=str, + help="Path where to export the resulting dataset.", + ) + + parser.add_argument( + "--merge_ground_truth", + action="store_true", + help="Whether to follow Bryan Plummer protocol and merge ground truth. 
By default, all the boxes for an entity are kept separate", + ) + + return parser.parse_args() + + +def box_xywh_to_xyxy(x): + """Accepts a list of bounding boxes in coco format (xmin,ymin, width, height) + Returns the list of boxes in pascal format (xmin,ymin,xmax,ymax) + + The boxes are expected as a numpy array + """ + # result = x.copy() + result = x.clone() + result[..., 2:] += result[..., :2] + return result + + +def xyxy2xywh(box: List): + """Accepts a list of bounding boxes in pascal format (xmin,ymin,xmax,ymax) + Returns the list of boxes in coco format (xmin,ymin, width, height) + """ + xmin, ymin, xmax, ymax = box + h = ymax - ymin + w = xmax - xmin + return [xmin, ymin, w, h] + + +def get_sentence_data(filename) -> List[Dict[str, Any]]: + """ + Parses a sentence file from the Flickr30K Entities dataset + + input: + filename - full file path to the sentence file to parse + + output: + a list of dictionaries for each sentence with the following fields: + sentence - the original sentence + phrases - a list of dictionaries for each phrase with the + following fields: + phrase - the text of the annotated phrase + first_word_index - the position of the first word of + the phrase in the sentence + phrase_id - an identifier for this phrase + phrase_type - a list of the coarse categories this + phrase belongs to + + """ + with open(filename, "r") as f: + sentences = f.read().split("\n") + + annotations = [] + for sentence in sentences: + if not sentence: + continue + + first_word = [] + phrases = [] + phrase_id = [] + phrase_type = [] + words = [] + current_phrase = [] + add_to_phrase = False + for token in sentence.split(): + if add_to_phrase: + if token[-1] == "]": + add_to_phrase = False + token = token[:-1] + current_phrase.append(token) + phrases.append(" ".join(current_phrase)) + current_phrase = [] + else: + current_phrase.append(token) + + words.append(token) + else: + if token[0] == "[": + add_to_phrase = True + first_word.append(len(words)) + parts = token.split("/") + phrase_id.append(parts[1][3:]) + phrase_type.append(parts[2:]) + else: + words.append(token) + + sentence_data = {"sentence": " ".join(words), "phrases": []} + for index, phrase, p_id, p_type in zip(first_word, phrases, phrase_id, phrase_type): + sentence_data["phrases"].append( + {"first_word_index": index, "phrase": phrase, "phrase_id": p_id, "phrase_type": p_type} + ) + + annotations.append(sentence_data) + + return annotations + + +def _box_inter_union(boxes1: np.array, boxes2: np.array) -> Tuple[np.array, np.array]: + area1 = box_area(boxes1) + area2 = box_area(boxes2) + + lt = np.maximum(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] + rb = np.minimum(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + + wh = (rb - lt).clip(min=0) # [N,M,2] + inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] + + union = area1[:, None] + area2 - inter + + return inter, union + + +def box_iou(boxes1: np.array, boxes2: np.array) -> np.array: + """ + Return intersection-over-union (Jaccard index) of boxes. + + Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with + ``0 <= x1 < x2`` and ``0 <= y1 < y2``. 
+ + Args: + boxes1 (Tensor[N, 4]) + boxes2 (Tensor[M, 4]) + + Returns: + iou (Tensor[N, M]): the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2 + """ + inter, union = _box_inter_union(boxes1, boxes2) + iou = inter / union + return iou + + +class UnionFind: + """Optimized union find structure""" + + def __init__(self, n): + """Initialize a union find with n components""" + self.compo = list(range(n)) + self.weight = [1] * n + self.nb_compo = n + + def get_nb_compo(self): + return self.nb_compo + + def find(self, x): + if self.compo[x] == x: + return x + self.compo[x] = self.find(self.compo[x]) + return self.compo[x] + + def unite(self, a, b): + fa = self.find(a) + fb = self.find(b) + if fa != fb: + self.nb_compo -= 1 + if self.weight[fb] > self.weight[fa]: + fa, fb = fb, fa + self.compo[fb] = fa + self.weight[fa] += self.weight[fb] + + +def get_equivalent_boxes(all_boxes, iou_threshold=0.95): + """Find clusters of highly overlapping boxes + Parameters: + - all_boxes: a list of boxes in [center_x, center_y, w, h] format + - iou_threshold: threshold at which we consider two boxes to be the same + + Returns a dict where the keys are an arbitrary id, and the values are the equivalence lists + """ + if len(all_boxes) == 0: + return {0: []} + uf = UnionFind(len(all_boxes)) + + # xy_boxes = box_xywh_to_xyxy(np.asarray(all_boxes)) + xy_boxes = box_xywh_to_xyxy(torch.as_tensor(all_boxes, dtype=torch.float)) + iou = box_iou(xy_boxes, xy_boxes) + for i, j in zip(*np.where(iou >= iou_threshold)): + uf.unite(i, j) + compo = defaultdict(list) + for i in range(len(all_boxes)): + compo[uf.find(i)].append(i) + return compo + + +def convert( + subset: str, flickr_path: Path, merge_ground_truth: bool, next_img_id: int = 1, + next_id: int = 1 +): + with open(flickr_path / f"{subset}.txt") as fd: + ids = [int(l.strip()) for l in fd] + + multibox_entity_count = 0 + + out_results = [] + total_phrase = 0 + total_bbox = 0 + + print(f"Exporting {subset}...") + for img_id in tqdm(ids): + + with open(flickr_path / "annotations" / "Annotations" / f"{img_id}.xml") as xml_file: + annotation = xmltodict.parse(xml_file.read())["annotation"] + + cur_img = { + "filename": annotation["filename"], + "height": int(annotation["size"]["height"]), + "width": int(annotation["size"]["width"]), + } + + instance_list = [] + # image = cv2.imread(output_path / "flickr30k-images" / annotation["filename"]) + # if image.shape[1] != cur_img["width"] or image.shape[0] != cur_img["height"]: + # print("before exif correction: ", cur_img) + # cur_img["width"], cur_img["height"] = image.shape[1], image.shape[0] + # print("after exif correction: ", cur_img) + + anno_file = os.path.join(flickr_path, "annotations/Annotations/%d.xml" % img_id) + + # Parse Annotation + root = parse(anno_file).getroot() + obj_elems = root.findall("./object") + target_bboxes = {} + + for elem in obj_elems: + if elem.find("bndbox") is None or len(elem.find("bndbox")) == 0: + continue + xmin = float(elem.findtext("./bndbox/xmin")) + ymin = float(elem.findtext("./bndbox/ymin")) + xmax = float(elem.findtext("./bndbox/xmax")) + ymax = float(elem.findtext("./bndbox/ymax")) + assert 0 < xmin and 0 < ymin + + h = ymax - ymin + w = xmax - xmin + + coco_box = [xmin, ymin, w, h] + + for name in elem.findall("name"): + entity_id = int(name.text) + assert 0 < entity_id + if not entity_id in target_bboxes: + target_bboxes[entity_id] = [] + else: + multibox_entity_count += 1 + # Dict from entity_id to list of all the bounding boxes + 
target_bboxes[entity_id].append(coco_box) + + if merge_ground_truth: + merged_bboxes = defaultdict(list) + for eid, bbox_list in target_bboxes.items(): + boxes_xyxy = box_xywh_to_xyxy(torch.as_tensor(bbox_list, dtype=torch.float)) + gt_box_merged = [ + min(boxes_xyxy[:, 0]).item(), + min(boxes_xyxy[:, 1]).item(), + max(boxes_xyxy[:, 2]).item(), + max(boxes_xyxy[:, 3]).item(), + ] + merged_bboxes[eid] = [xyxy2xywh(gt_box_merged)] # convert back to xywh for coco format + + target_bboxes = merged_bboxes + + sents = get_sentence_data(flickr_path / "annotations/Sentences" / f"{img_id}.txt") + for sent_id, sent in enumerate(sents): + + spans = {} # global phrase ID to span in sentence + phraseid2entityid = {} + entityid2phraseid = defaultdict(list) + sentence = sent["sentence"] + entity_ids = [int(p["phrase_id"]) for p in sent["phrases"]] + + for global_phrase_id, phrase in enumerate(sent["phrases"]): + phraseid2entityid[global_phrase_id] = int(phrase["phrase_id"]) + entityid2phraseid[int(phrase["phrase_id"])].append(global_phrase_id) + first_word = phrase["first_word_index"] + beg = sum([len(x) for x in sentence.split()[:first_word]]) + first_word + spans[global_phrase_id] = (beg, beg + len(phrase["phrase"])) + assert sentence[beg: beg + len(phrase["phrase"])] == phrase["phrase"] + + all_boxes_in_sent = [] + for ent_id in entity_ids: + if ent_id in target_bboxes: + for bb in target_bboxes[ent_id]: + all_boxes_in_sent.append({"ent_id": int(ent_id), "coords": bb}) + + equivalences = get_equivalent_boxes([b["coords"] for b in all_boxes_in_sent], 0.95) + + tokens_positive_eval = [] + for gpid, span in spans.items(): + if phraseid2entityid[gpid] in target_bboxes: + tokens_positive_eval.append([span]) + + for equiv in equivalences.values(): + if len(equiv) == 0: + continue + cur_entids = set([all_boxes_in_sent[bid]["ent_id"] for bid in equiv]) + token_spans = [] + for entid in cur_entids: + token_spans += [spans[gid] for gid in entityid2phraseid[entid]] + xmin, ymin, w, h = all_boxes_in_sent[equiv[-1]]["coords"] + + phrase = " ".join([sentence[sp[0]:sp[1]] for sp in token_spans]) + + cur_obj = { + "bbox": [xmin, ymin, w + xmin, h + ymin], + "exp": phrase, + } + next_id += 1 + instance_list.append(cur_obj) + + # 相同图片名的实例合并到一起 + out_instance = {} + for instance in instance_list: + if instance['exp'] in out_instance: + data = out_instance[instance['exp']] + if isinstance(data['bbox'][0], list): + # 如果 bbox 是相同的,就直接留一个就行 + is_same = False + for bbox in data['bbox']: + if bbox == instance['bbox']: + is_same = True + break + if not is_same: + data['bbox'].append(instance['bbox']) + else: + # 如果 bbox 是相同的,就直接留一个就行 + if data['bbox'] != instance['bbox']: + data['bbox'] = [data['bbox'], instance['bbox']] + else: + out_instance[instance['exp']] = copy.deepcopy(instance) + + out_instance = list(out_instance.values()) + + # 不同 phrase 但是 bbox 相同的需要合并 + new_out_instance = [] + temp_bboxes = [] + for instance in out_instance: + bbox = instance['bbox'] + if bbox not in temp_bboxes: + new_out_instance.append(copy.deepcopy(instance)) + temp_bboxes.append(bbox) + else: + index = temp_bboxes.index(bbox) + instance_ = new_out_instance[index] + if isinstance(instance_['exp'], list): + # 如果 phrase 是相同的,就直接留一个就行 + is_same = False + for exp in instance_['exp']: + if exp.lower() == instance['exp'].lower(): + is_same = True + break + if not is_same: + instance_['exp'].append(instance['exp']) + else: + # 如果去除大小写后一样,则只保留其中一个 + if instance_['exp'].lower() != instance['exp'].lower(): + instance_['exp'] = [instance_['exp'], 
instance['exp']] + + # 每条数据 nms + for instance in new_out_instance: + if isinstance(instance['bbox'][0], list): + bboxes = torch.as_tensor(instance['bbox'], dtype=torch.float).reshape(-1, 4) + score = torch.ones(len(bboxes), dtype=torch.float) + index = batched_nms(bboxes, score, score, iou_threshold=0.9) + if len(index) != len(score): + # print('nms vaild', cur_img['filename'], instance['exp'], bboxes, bboxes[index]) + print('nms vaild', cur_img['filename'], instance['exp']) + bboxes = bboxes[index].numpy().tolist() + instance['bbox'] = bboxes + + total_phrase += len(new_out_instance) + total_bbox_ = [len(ins['bbox']) for ins in new_out_instance] + total_bbox += sum(total_bbox_) + next_img_id += 1 + cur_img['referring'] = {} + cur_img['referring']['instances'] = new_out_instance + out_results.append(cur_img) + + print(f'total image: {len(out_results)}, total phrase: {total_phrase}, total bbox: {total_bbox}') + if merge_ground_truth: + filename = f"flickr30k_mergedGT_{subset}_rec.json" + else: + filename = f"flickr30k_separateGT_{subset}_rec.json" + + out_path = osp.join(flickr_path, filename) + + with jsonlines.open(out_path, mode='w') as writer: + writer.write_all(out_results) + print(f'save to {out_path}') + + +def main(args): + flickr_path = Path(args.flickr_path) + convert("train", flickr_path, args.merge_ground_truth) + + +if __name__ == "__main__": + main(parse_args()) diff --git a/projects/mm_gdino_clip/script/gqa2rec.py b/projects/mm_gdino_clip/script/gqa2rec.py new file mode 100644 index 00000000000..677f776efa7 --- /dev/null +++ b/projects/mm_gdino_clip/script/gqa2rec.py @@ -0,0 +1,384 @@ +""" +data_path : path to original GQA annotations to be downloaded from https://cs.stanford.edu/people/dorarad/gqa/download.html +img_path : path to original GQA images to be downloaded from https://cs.stanford.edu/people/dorarad/gqa/download.html +sg_path : path to original GQA scene graphs to be downloaded from https://cs.stanford.edu/people/dorarad/gqa/download.html +vg_img_data_path : path to image info for VG images to be downloaded from https://visualgenome.org/static/data/dataset/image_data.json.zip + + +data/gqa + - questions1.2 + - sceneGraphs + - image_data.json # from VG +""" + +import argparse +import json +import os +import re +from collections import defaultdict +from pathlib import Path +import sys +from tqdm import tqdm +import os.path as osp +import jsonlines +import torch +import copy +from torchvision.ops.boxes import batched_nms + + +PACKAGE_PARENT = "." 
+SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__)))) +sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT))) +from utils.spans import consolidate_spans + + +# pip install nltk spacy + + +def parse_args(): + parser = argparse.ArgumentParser("Conversion script") + + parser.add_argument( + "--data_path", + default='data/gqa/questions1.2/', + type=str, + help="Path to the gqa dataset", + ) + parser.add_argument( + "--sg_path", + default='data/gqa/sceneGraphs/', + type=str, + help="Path to the gqa dataset scene graph", + ) + + parser.add_argument( + "--vg_img_data_path", + default='data/gqa/', + type=str, + help="Path to image meta data for VG" + ) + return parser.parse_args() + + +def convert(out_results, split, data_path, sg_path, imid2data, next_img_id=1, next_id=1): + print("Loading", data_path / f"{split}_balanced_questions.json") + with open(data_path / f"{split}_balanced_questions.json", "r") as f: + data = json.load(f) + print("Loading", sg_path / f"{split}_sceneGraphs.json") + with open(sg_path / f"{split}_sceneGraphs.json", "r") as f: + sg_data = json.load(f) + + img2ann = defaultdict(dict) + for k, v in data.items(): + img2ann[v["imageId"]][k] = v + print(len(img2ann)) + print(img2ann["2354786"]) + print(img2ann[list(img2ann.keys())[0]].keys()) + + # Add missing annotations by inspecting the semantic field + regexp = re.compile(r"([0-9]+)") + regexp2 = re.compile(r"([A-z]+)") + count = 0 + + for k, v in img2ann.items(): + for ann_id, annotations in v.items(): + expected_boxes = [] + for item in annotations["semantic"]: + if item["operation"] == "select": + if len(regexp.findall(item["argument"])) > 0: + expected_boxes.append( + (regexp2.findall(item["argument"])[0].strip(), regexp.findall(item["argument"])[0]) + ) + question_boxes = list(annotations["annotations"]["question"].values()) + + for name, box_id in expected_boxes: + if box_id not in question_boxes: + count += 1 + beg = annotations["question"].find(name) + end = beg + len(name) + annotations["annotations"]["question"][(beg, end)] = box_id + + print(len(img2ann)) + print(img2ann["2354786"]) + print(img2ann[list(img2ann.keys())[0]].keys()) + + # Add annotations for the questions where there is a box for the answer but not for the question (what/where/who questions) + for k, v in img2ann.items(): + for ann_id, ann in v.items(): + question_objects = list(ann["annotations"]["question"].values()) + answer_objects = list(ann["annotations"]["answer"].values()) + if len(set(answer_objects) - set(question_objects)) > 0: + + for box_id in answer_objects: + if box_id not in question_objects: + + if ann["question"].find("What") > -1: + beg = ann["question"].find("What") + end = beg + len("What") + elif ann["question"].find("what") > -1: + beg = ann["question"].find("what") + end = beg + len("what") + elif ann["question"].find("Who") > -1: + beg = ann["question"].find("Who") + end = beg + len("Who") + elif ann["question"].find("who") > -1: + beg = ann["question"].find("who") + end = beg + len("who") + elif ann["question"].find("Where") > -1: + beg = ann["question"].find("Where") + end = beg + len("Where") + elif ann["question"].find("where") > -1: + beg = ann["question"].find("where") + end = beg + len("where") + else: + continue + + ann["annotations"]["question"][(beg, end)] = box_id + + print(f"Dumping {split}...") + # next_img_id = 0 + # next_id = 0 + + for k, v in tqdm(img2ann.items()): + filename = f"{k}.jpg" + cur_img = { + "filename": filename, + "height": 
imid2data[int(k)]["height"], + "width": imid2data[int(k)]["width"], + # "id": next_img_id, + # "original_id": k, + } + instance_list = [] + + # image = read_image(data_path / "images" / filename, format="BGR") + # if image.shape[1] != cur_img["width"] or image.shape[0] != cur_img["height"]: + # print("before exif correction: ", cur_img) + # cur_img["width"], cur_img["height"] = image.shape[1], image.shape[0] + # print("after exif correction: ", cur_img) + # if filename == "860.jpg": + # print(v) + + for ann_id, annotation in v.items(): + question = annotation["question"] + answer = annotation["answer"] + full_answer = annotation["fullAnswer"] + + if len(annotation["annotations"]["question"]) > 0: + + # assert len(annotation["annotations"]["question"]) == 1 + # if len(annotation["annotations"]["question"]) > 1: + # print(annotation) + phrase_all = [] + for text_tok_id, box_anno_id in annotation["annotations"]["question"].items(): + target_bbox = sg_data[k]["objects"][box_anno_id] + x, y, h, w = target_bbox["x"], target_bbox["y"], target_bbox["h"], target_bbox["w"] + target_bbox = [x, y, w, h] + + if isinstance(text_tok_id, str): + if ":" in text_tok_id: + text_tok_id = text_tok_id.split(":") + if isinstance(text_tok_id, list) and len(text_tok_id) > 1: + beg = sum([len(x) for x in question.split()[: int(text_tok_id[0])]]) + int(text_tok_id[0]) + end = ( + sum([len(x) for x in question.split()[: int(text_tok_id[1]) - 1]]) + + int(text_tok_id[1]) + - 1 + ) + end = end + len(question.split()[int(text_tok_id[1]) - 1]) + else: + beg = sum([len(x) for x in question.split()[: int(text_tok_id)]]) + int(text_tok_id) + end = beg + len(question.split()[int(text_tok_id)]) + else: + beg, end = text_tok_id + + cleaned_span = consolidate_spans([(beg, end)], question) + + question_positive = " ".join([question[sp[0]:sp[1]] for sp in cleaned_span]) + + if question_positive.lower() in ["what", "who", "where"]: + phrase = answer + else: + phrase = question_positive + phrase_all.append(phrase) + + for text_tok_id, box_anno_id in annotation["annotations"]["question"].items(): + target_bbox = sg_data[k]["objects"][box_anno_id] + x, y, h, w = target_bbox["x"], target_bbox["y"], target_bbox["h"], target_bbox["w"] + target_bbox = [x, y, w + x, h + y] + + if isinstance(text_tok_id, str): + if ":" in text_tok_id: + text_tok_id = text_tok_id.split(":") + if isinstance(text_tok_id, list) and len(text_tok_id) > 1: + beg = sum([len(x) for x in question.split()[: int(text_tok_id[0])]]) + int(text_tok_id[0]) + end = ( + sum([len(x) for x in question.split()[: int(text_tok_id[1]) - 1]]) + + int(text_tok_id[1]) + - 1 + ) + end = end + len(question.split()[int(text_tok_id[1]) - 1]) + else: + beg = sum([len(x) for x in question.split()[: int(text_tok_id)]]) + int(text_tok_id) + end = beg + len(question.split()[int(text_tok_id)]) + else: + beg, end = text_tok_id + + cleaned_span = consolidate_spans([(beg, end)], question) + + question_positive = " ".join([question[sp[0]:sp[1]] for sp in cleaned_span]) + + phrase = question_positive + if any([phrase.lower().startswith(p) for p in ["what", "who", "where"]]): + phrase = answer + elif question_positive.lower() == "wh": + phrase = answer + elif question_positive.lower() == "ho": + phrase = answer + + if sum([1 if p in full_answer else 0 for p in phrase_all]) == 1: + if answer in full_answer and phrase in full_answer: + phrase = full_answer + # beg = full_answer.index(phrase) + # end = beg + len(phrase) + # print([[(beg, end)]], full_answer, phrase) + # cleaned_span, phrase = 
get_canonical_spans([[(beg, end)]], full_answer) + # print(cleaned_span, phrase) + + if phrase.lower() == "he": + if "man" in full_answer or "boy" in full_answer or "guy" in full_answer: + phrase = full_answer + else: + phrase = "man" + if phrase.lower() == "she": + if "woman" in full_answer or "lady" in full_answer or "girl" in full_answer: + phrase = full_answer + else: + phrase = "woman" + + if len(phrase) == 2 and not (phrase.lower() == "tv" or phrase.lower() == "cd"): + phrase = full_answer + + if len(phrase) == 1: + phrase = full_answer + + if phrase.lower().startswith("no, "): + phrase = phrase[4:] + if phrase.lower().startswith("yes, "): + phrase = phrase[5:] + + cur_obj = { + # "area": h * w, + # "iscrowd": 0, + # "category_id": 1, + "bbox": target_bbox, + # "image_id": next_img_id, + # "id": next_id, + # "question": question, + # "answer": answer, + # "full_answer": full_answer, + # "tokens_positive": cleaned_span, + # "question_positive": question_positive, + "exp": phrase, + } + + next_id += 1 + instance_list.append(cur_obj) + + # 相同图片名的实例合并到一起 + out_instance = {} + for instance in instance_list: + if instance['exp'] in out_instance: + data = out_instance[instance['exp']] + if isinstance(data['bbox'][0], list): + # 如果 bbox 是相同的,就直接留一个就行 + is_same = False + for bbox in data['bbox']: + if bbox == instance['bbox']: + is_same = True + break + if not is_same: + data['bbox'].append(instance['bbox']) + else: + # 如果 bbox 是相同的,就直接留一个就行 + if data['bbox'] != instance['bbox']: + data['bbox'] = [data['bbox'], instance['bbox']] + else: + out_instance[instance['exp']] = copy.deepcopy(instance) + + out_instance = list(out_instance.values()) + + # 不同 phrase 但是 bbox 相同的需要合并 + new_out_instance = [] + temp_bboxes = [] + for instance in out_instance: + bbox = instance['bbox'] + if bbox not in temp_bboxes: + new_out_instance.append(copy.deepcopy(instance)) + temp_bboxes.append(bbox) + else: + index = temp_bboxes.index(bbox) + instance_ = new_out_instance[index] + if isinstance(instance_['exp'], list): + # 如果 phrase 是相同的,就直接留一个就行 + is_same = False + for exp in instance_['exp']: + if exp.lower() == instance['exp'].lower(): + is_same = True + break + if not is_same: + instance_['exp'].append(instance['exp']) + else: + # 如果去除大小写后一样,则只保留其中一个 + if instance_['exp'].lower() != instance['exp'].lower(): + instance_['exp'] = [instance_['exp'], instance['exp']] + + # 每条数据 nms + for instance in new_out_instance: + if isinstance(instance['bbox'][0], list): + bboxes = torch.as_tensor(instance['bbox'], dtype=torch.float).reshape(-1, 4) + score = torch.ones(len(bboxes), dtype=torch.float) + index = batched_nms(bboxes, score, score, iou_threshold=0.9) + if len(index) != len(score): + # print('nms vaild', cur_img['filename'], instance['exp'], bboxes, bboxes[index]) + print('nms vaild', cur_img['filename'], instance['exp']) + bboxes = bboxes[index].numpy().tolist() + instance['bbox'] = bboxes + + next_img_id += 1 + cur_img['referring'] = {} + cur_img['referring']['instances'] = new_out_instance + out_results.append(cur_img) + + return out_results, next_img_id, next_id + +def main(args): + data_path = Path(args.data_path) + sg_path = Path(args.sg_path) + + print("Loading", f"{args.vg_img_data_path}/image_data.json") + with open(f"{args.vg_img_data_path}/image_data.json", "r") as f: + image_data = json.load(f) + imid2data = {x["image_id"]: x for x in image_data} + + out_results = [] + out_results, next_img_id, next_id = convert(out_results, "train", data_path, sg_path, imid2data) + out_results, _, _ = 
convert(out_results, "val", data_path, sg_path, imid2data, next_img_id, next_id) + + total_phrase = 0 + total_bbox = 0 + for result in out_results: + total_phrase += len(result['referring']['instances']) + total_bbox_ = [len(ins['bbox']) for ins in result['referring']['instances']] + total_bbox += sum(total_bbox_) + print(f'total image: {len(out_results)}, total phrase: {total_phrase}, total bbox: {total_bbox}') + + filename = f"gqa_rec.json" + out_path = osp.join(args.vg_img_data_path, filename) + + with jsonlines.open(out_path, mode='w') as writer: + writer.write_all(out_results) + print(f'save to {out_path}') + + +if __name__ == "__main__": + main(parse_args()) diff --git a/projects/mm_gdino_clip/refcoco2rec.py b/projects/mm_gdino_clip/script/refcoco2rec.py similarity index 100% rename from projects/mm_gdino_clip/refcoco2rec.py rename to projects/mm_gdino_clip/script/refcoco2rec.py diff --git a/projects/mm_gdino_clip/script/utils/boxes.py b/projects/mm_gdino_clip/script/utils/boxes.py new file mode 100644 index 00000000000..e655602dca5 --- /dev/null +++ b/projects/mm_gdino_clip/script/utils/boxes.py @@ -0,0 +1,85 @@ +# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved +"""Utilities to manipulate and convert boxes""" +from collections import defaultdict +from typing import Any, Dict + +import torch +from torchvision.ops.boxes import box_iou + +from .unionfind import UnionFind + + +def obj_to_box(obj: Dict[str, Any]): + """Extract the bounding box of a given object as a list""" + return [obj["x"], obj["y"], obj["w"], obj["h"]] + + +def region_to_box(obj: Dict[str, Any]): + """Extract the bounding box of a given region as a list""" + return [obj["x"], obj["y"], obj["width"], obj["height"]] + + +def get_boxes_equiv(orig_boxes, iou_threshold): + """Given a set of boxes, returns a dict containing clusters of boxes that are highly overlapping. + For optimization, return None if none of the boxes are overlapping + A high overlap is characterized by the iou_threshold + Boxes are expected as [top_left_x, top_left_y, width, height] + """ + boxes = torch.as_tensor(orig_boxes, dtype=torch.float) + # Convert to (x,y,x,y) format + boxes[:, 2:] += boxes[:, :2] + ious = box_iou(boxes, boxes) + uf = UnionFind(len(boxes)) + for i in range(len(boxes)): + for j in range(i + 1, len(boxes)): + if ious[i][j] >= iou_threshold: + uf.unite(i, j) + if len(orig_boxes) == uf.nb_compo: + # We didn't found any opportunity for merging, returning as-is + # print("no merging") + return None, None + # print("merging") + compo2boxes = defaultdict(list) + compo2id = defaultdict(list) + + for i in range(len(boxes)): + compo2boxes[uf.find(i)].append(boxes[i]) + compo2id[uf.find(i)].append(i) + assert len(compo2boxes) == uf.nb_compo + return compo2boxes, compo2id + + +def xyxy_to_xywh(boxes: torch.Tensor): + """Converts a set of boxes in [top_left_x, top_left_y, bottom_right_x, bottom_right_y] format to + [top_left_x, top_left_y, width, height] format""" + assert boxes.shape[-1] == 4 + converted = boxes.clone() + converted[..., 2:] -= converted[..., :2] + return converted + + +def combine_boxes(orig_boxes, iou_threshold=0.7): + """Given a set of boxes, returns the average of all clusters of boxes that are highly overlapping. 
+ A high overlap is characterized by the iou_threshold + Boxes are expected as [top_left_x, top_left_y, width, height] + """ + compo2boxes, _ = get_boxes_equiv(orig_boxes, iou_threshold) + if compo2boxes is None: + return orig_boxes + result_boxes = [] + for box_list in compo2boxes.values(): + result_boxes.append(xyxy_to_xywh(torch.stack(box_list, 0).mean(0)).tolist()) + return result_boxes + + +def box_iou_helper(b1, b2): + """returns the iou matrix between two sets of boxes + The boxes are expected in the format [top_left_x, top_left_y, w, h] + """ + boxes_r1 = torch.as_tensor(b1, dtype=torch.float) + # Convert to (x,y,x,y) format + boxes_r1[:, 2:] += boxes_r1[:, :2] + boxes_r2 = torch.as_tensor(b2, dtype=torch.float) + # Convert to (x,y,x,y) format + boxes_r2[:, 2:] += boxes_r2[:, :2] + return box_iou(boxes_r1, boxes_r2) diff --git a/projects/mm_gdino_clip/script/utils/dump.py b/projects/mm_gdino_clip/script/utils/dump.py new file mode 100644 index 00000000000..d4e2763a6f2 --- /dev/null +++ b/projects/mm_gdino_clip/script/utils/dump.py @@ -0,0 +1,104 @@ +# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved +import json +from typing import Any, List, NamedTuple, Optional, Tuple + + +class Annotation(NamedTuple): + area: float + iscrowd: int + category_id: int + bbox: List[float] + giou_friendly_bbox: List[float] + tokens_positive: List[Tuple[int, int]] + + +class Datapoint(NamedTuple): + image_id: int + dataset_name: str + tokens_negative: List[Tuple[int, int]] + original_id: int + caption: str + annotations: List[Annotation] + + +def convert2dataset_combined( + datapoint_list_coco: List[Datapoint], + datapoint_list_vg: List[Datapoint], + imgid2imginfo_coco, + imgid2imginfo_vg, + output_path, +): + """""" + print(f"Dumping combined coco and vg images related all training examples...") + next_img_id = 0 + next_id = 0 + + annotations = [] + images = [] + + for datapoint in datapoint_list_coco: + img_id = datapoint.image_id + filename = imgid2imginfo_coco[img_id]["file_name"] + cur_img = { + "file_name": filename, + "height": imgid2imginfo_coco[img_id]["height"], + "width": imgid2imginfo_coco[img_id]["width"], + "id": next_img_id, + "original_id": img_id, + "caption": datapoint.caption, + "tokens_negative": datapoint.tokens_negative, + "data_source": "coco", + "dataset_name": datapoint.dataset_name, + } + + for anns in datapoint.annotations: + cur_obj = { + "area": float(anns.area), + "iscrowd": anns.iscrowd, + "image_id": next_img_id, + "category_id": anns.category_id, + "id": next_id, + "bbox": anns.bbox, + "tokens_positive": anns.tokens_positive, + } + next_id += 1 + annotations.append(cur_obj) + + next_img_id += 1 + images.append(cur_img) + + for datapoint in datapoint_list_vg: + img_id = datapoint.image_id + filename = f"{img_id}.jpg" + cur_img = { + "file_name": filename, + "height": imgid2imginfo_vg[img_id]["height"], + "width": imgid2imginfo_vg[img_id]["width"], + "id": next_img_id, + "original_id": img_id, + "caption": datapoint.caption, + "tokens_negative": datapoint.tokens_negative, + "data_source": "vg", + "dataset_name": datapoint.dataset_name, + } + + for anns in datapoint.annotations: + cur_obj = { + "area": float(anns.area), + "iscrowd": anns.iscrowd, + "image_id": next_img_id, + "category_id": anns.category_id, + "id": next_id, + "bbox": anns.bbox, + "tokens_positive": anns.tokens_positive, + } + next_id += 1 + annotations.append(cur_obj) + + next_img_id += 1 + images.append(cur_img) + + ds = {"info": [], "licenses": 
[], "images": images, "annotations": annotations, "categories": []} + with open(output_path / f"final_mixed_train.json", "w") as j_file: + json.dump(ds, j_file) + return next_img_id, next_id diff --git a/projects/mm_gdino_clip/script/utils/spans.py b/projects/mm_gdino_clip/script/utils/spans.py new file mode 100644 index 00000000000..b2839ac1e85 --- /dev/null +++ b/projects/mm_gdino_clip/script/utils/spans.py @@ -0,0 +1,235 @@ +# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved +"""""" + +from typing import List, Tuple + +from .text import STOP_WORDS, nlp + + +class PreprocessError(Exception): + pass + + +def span_intersect_span(span1: Tuple[int, int], span2: Tuple[int, int]): + """Returns True if the given spans intersect""" + return (span1[0] <= span2[0] < span1[1]) or (span2[0] <= span1[0] < span2[1]) + + +def span_intersect_spanlist(span: Tuple[int, int], target_spans: List[Tuple[int, int]]): + """Returns True if the given spans intersect with any in the given list""" + for t in target_spans: + if span_intersect_span(span, t): + return True + return False + + +def spanlist_intersect_spanlist(spans: List[Tuple[int, int]], target_spans: List[Tuple[int, int]]): + """Returns True if the given spans intersect with any in the given list""" + for s in spans: + if span_intersect_spanlist(s, target_spans): + return True + return False + + +def consolidate_spans(spans: List[Tuple[int, int]], caption: str, rec=True): + """Accepts a list of spans and the the corresponding caption. + Returns a cleaned list of spans where: + - Overlapping spans are merged + - It is guaranteed that spans start and end on a word + """ + sorted_spans = sorted(spans) + cur_end = -1 + cur_beg = None + final_spans: List[Tuple[int, int]] = [] + for s in sorted_spans: + if s[0] >= cur_end: + if cur_beg is not None: + final_spans.append((cur_beg, cur_end)) + cur_beg = s[0] + cur_end = max(cur_end, s[1]) + + if cur_beg is not None: + final_spans.append((cur_beg, cur_end)) + + # Now clean the begining/end + clean_spans: List[Tuple[int, int]] = [] + for s in final_spans: + beg, end = s + end = min(end, len(caption)) + while beg < len(caption) and not caption[beg].isalnum(): + beg += 1 + while end > 0 and not caption[end - 1].isalnum(): + end -= 1 + # Try to get hyphenated words + if end < len(caption) and caption[end] == "-": + # print("trigg") + next_space = caption.find(" ", end) + if next_space == -1: + end = len(caption) + else: + end = next_space + 1 + if beg > 0 and caption[beg - 1] == "-": + prev_space = caption.rfind(" ", 0, beg) + if prev_space == -1: + beg = 0 + else: + beg = prev_space + 1 + if 0 <= beg < end <= len(caption): + clean_spans.append((beg, end)) + if rec: + return consolidate_spans(clean_spans, caption, False) + return clean_spans + + +def get_canonical_spans(orig_spans: List[List[Tuple[int, int]]], orig_caption: str, whitespace_only=False): + """This functions computes the spans after reduction of the caption to it's normalized version + For example, if the caption is "There is a man wearing sneakers" and the span is [(11,14)] ("man"), + then the normalized sentence is "man wearing sneakers" so the new span is [(0,3)] + """ + # print("orig caption", orig_caption) + # print("orig spans", [orig_caption[t[0]:t[1]] for span in orig_spans for t in span]) + new_spans = [sorted(spans) for spans in orig_spans] + caption = orig_caption.lower() + + def remove_chars(pos, amount): + for i in range(len(new_spans)): + for j in range(len(new_spans[i])): + if 
pos >= new_spans[i][j][1]: + continue + beg, end = new_spans[i][j] + if span_intersect_span(new_spans[i][j], (pos, pos + amount)): + # assert new_spans[i][j][0] == pos or amount == 1, "unexpected deletion from middle of span" + new_spans[i][j] = (beg, end - amount) + else: + new_spans[i][j] = (beg - amount, end - amount) + + def change_chars(old_beg, old_end, delta): + for i in range(len(new_spans)): + for j in range(len(new_spans[i])): + if old_beg >= new_spans[i][j][1]: + continue + beg, end = new_spans[i][j] + if span_intersect_span(new_spans[i][j], (old_beg, old_end)): + if not (new_spans[i][j][0] <= old_beg < old_end <= new_spans[i][j][1]): + raise PreprocessError(f"deleted spans should be contained in known span") + assert ( + new_spans[i][j][0] <= old_beg < old_end <= new_spans[i][j][1] + ), "deleted spans should be contained in known span" + new_spans[i][j] = (beg, end + delta) + else: + new_spans[i][j] = (beg + delta, end + delta) + + # Pre pass, removing double spaces and leading spaces + # Check for leading spaces + while caption[0] == " ": + remove_chars(0, 1) + caption = caption[1:] + cur_start = 0 + pos = caption.find(" ", cur_start) + while pos != -1: + amount = 1 + # print("remvoing", removed, pos) + remove_chars(pos, amount) + caption = caption.replace(" ", " ", 1) + pos = caption.find(" ", cur_start) + # print("after whitespace caption", caption) + # print("after whitespace spans", [caption[t[0]:t[1]] for span in new_spans for t in span]) + if whitespace_only: + return new_spans, caption + + # First pass, removing punctuation + for punct in [".", ",", "!", "?", ":"]: + pos = caption.find(punct) + while pos != -1: + remove_chars(pos, len(punct)) + caption = caption.replace(punct, "", 1) + pos = caption.find(punct) + # print("after punct caption", caption) + # print("after punct spans", [caption[t[0]:t[1]] for span in new_spans for t in span]) + + # parsing needs to happen before stop words removal + all_tokens = nlp(caption) + tokens = [] + + # Second pass, removing stop words + ## Remove from tokenization + for t in all_tokens: + if str(t) not in STOP_WORDS: + tokens.append(t) + ## Remove from actual sentence + for stop in STOP_WORDS: + cur_start = 0 + pos = caption.find(stop, cur_start) + while pos != -1: + # Check that we are matching a full word + if (pos == 0 or caption[pos - 1] == " ") and ( + pos + len(stop) == len(caption) or caption[pos + len(stop)] == " " + ): + removed = stop + spaces = 0 + if pos + len(stop) < len(caption) and caption[pos + len(stop)] == " ": + removed += " " + spaces += 1 + if pos > 0 and caption[pos - 1] == " ": + removed = " " + removed + spaces += 1 + if spaces == 0: + raise PreprocessError( + f"No spaces found in '{caption}', position={pos}, stopword={stop}, len={len(stop)}" + ) + assert spaces > 0 + replaced = "" if spaces == 1 else " " + amount = len(removed) - len(replaced) + # print("remvoing", removed, pos) + remove_chars(pos, amount) + caption = caption.replace(removed, replaced, 1) + # print("cur caption", caption) + # print("cur spans", [caption[t[0]:t[1]] for span in new_spans for t in span if t[0] < t[1]]) + else: + cur_start += 1 + pos = caption.find(stop, cur_start) + + # print("final caption", caption) + # print("final spans", [caption[t[0]:t[1]] for span in new_spans for t in span if t[0] < t[1]]) + + # Third pass, lemmatization + final_caption = [] + if len(tokens) != len(caption.strip().split(" ")): + raise PreprocessError( + f"''{tokens}'', len={len(tokens)}, {caption.strip().split(' ')}, len={len(caption.strip().split(' 
'))}" + ) + + # tokens = nlp(caption) + cur_beg = 0 + for i, w in enumerate(caption.strip().split(" ")): + if tokens[i].lemma_[0] != "-": + # print(w, "lemmatized to", tokens[i].lemma_) + final_caption.append(tokens[i].lemma_) + change_chars(cur_beg, cur_beg + len(w), len(tokens[i].lemma_) - len(w)) + else: + # print(w, "skipped lemmatized to", tokens[i].lemma_) + final_caption.append(w) + cur_beg += 1 + len(final_caption[-1]) + # print("cur_beg", cur_beg) + # print("cur spans", [caption[t[0]:t[1]] for span in new_spans for t in span if t[0] < t[1]], new_spans) + + clean_caption = " ".join(final_caption) + # Cleanup empty spans + clean_spans = [] + for spans in new_spans: + cur = [] + for s in spans: + if 0 <= s[0] < s[1]: + cur.append(s) + clean_spans.append(cur) + + # print("clean caption", clean_caption) + # print("clean spans", [clean_caption[t[0]:t[1]] for span in clean_spans for t in span]) + return clean_spans, clean_caption + + +def shift_spans(spans: List[Tuple[int, int]], offset: int) -> List[Tuple[int, int]]: + final_spans = [] + for beg, end in spans: + final_spans.append((beg + offset, end + offset)) + return final_spans diff --git a/projects/mm_gdino_clip/script/utils/text.py b/projects/mm_gdino_clip/script/utils/text.py new file mode 100644 index 00000000000..b094bf4a239 --- /dev/null +++ b/projects/mm_gdino_clip/script/utils/text.py @@ -0,0 +1,135 @@ +# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved +"""Provides various text related util function""" +import re +from typing import List, Tuple + +import nltk +import spacy + +nlp = spacy.load("en_core_web_sm") + +nltk.download("stopwords") +from nltk.corpus import stopwords + +STOP_WORDS = set(stopwords.words("english")) - set(["above", "below", "between", "further", "he", "she", "they"]) + + +def get_noun_phrase(root): + queue = [root] + all_toks = [root] + while len(queue) > 0: + curr = queue.pop() + if curr.tag_ in ["NN", "NNS", "NNP", "NNPS"]: + queue += curr.lefts + all_toks += curr.lefts + return all_toks + + +def get_root_and_nouns(text: str, lazy=True) -> Tuple[str, str, List[Tuple[int, int]], List[Tuple[int, int]]]: + """Given a sentence, returns a tuple with the following items: + -- root text:str : the text associated with the root of the sentence + -- negative_text:str: all the text that shouldn't be positively matched with a box other than the main one + -- root_span: List[Tuple[int, int]] spans covering the root expressions, returned as a list of (beg, end) character spans + -- negative_span: List[Tuple[int, int]] spans covering the negative expressions, returned as a list of (beg, end) character spans + + If lazy is False, then we try a bit harder to find the precise root of the sentence + """ + sents = nlp(text) + negative_text = [] + + if len([x for x in sents if x.tag_ in ["NN", "NNS", "NNP", "NNPS", "PRP"]]) <= 1: + if lazy or len([x for x in sents if x.tag_ in ["NN", "NNS", "NNP", "NNPS", "PRP"]]) == 0: + return text, " ", [(0, len(text))], [(0, len(text))] + + root = None + for token in sents: + if token.dep_ == "ROOT": + if token.tag_ == "UH": + continue + root = token + break + + if root is None: + return text, "", [(0, len(text))], [(0, len(text))] + + if ( + len([c for c in root.children if c.tag_ in ["VB", "VBD", "VBG", "VBN", "VBP", "VBZ"] and c.dep_ == "compound"]) + > 0 + ): + return text, "", [(0, len(text))], [(0, len(text))] + + all_toks = [] + if root.tag_ in ["NN", "NNS", "NNP", "NNPS"]: + all_toks = get_noun_phrase(root) + root_text 
= " ".join([x.text for x in all_toks]) + root_spans = [(x.idx, x.idx + len(x.text)) for x in all_toks] + else: + root = [x for x in root.children if x.tag_ in ["NN", "NNS", "NNP", "NNPS", "PRP"]] + if len(root) < 1: + return text, "", [(0, len(text))], [(0, len(text))] + else: + root = root[0] + all_toks = list(root.lefts) + [root] + root_text = " ".join([x.text for x in all_toks]) + root_spans = [(x.idx, x.idx + len(x.text)) for x in all_toks] + + everything_else = set() + for token in sents: + if token.tag_ in ["NN", "NNS", "NNP", "NNPS"] and token.dep_ not in ["ROOT"] and token not in all_toks: + everything_else = everything_else.union(set(get_noun_phrase(token))) + + negative_tokens = set(sents) - set(everything_else) + negative_text = " ".join([x.text for x in negative_tokens]) + negative_spans = [(x.idx, x.idx + len(x.text)) for x in negative_tokens] + + return root_text, negative_text, root_spans, negative_spans + + +def normalize_sentence(sentence): + """Returns a list of non stopwords for the sentence, obtained after cleaning ponctuation and spaces""" + + sent = sentence.lower() + sent = remove_punctuation(sentence.lower()) + sent = normalize_whitespace(sent) + tokens = nlp(sent) + return " ".join( + [ + tokens[i].lemma_ if tokens[i].lemma_[0] != "-" else w + for i, w in enumerate(sent.split(" ")) + if w not in STOP_WORDS + ] + ) + + +def remove_punctuation(text): + """ + This function removes all ponctuation. + """ + corrected = str(text) + corrected = re.sub(r"([!?,;.:-])", r"", corrected) + return corrected + + +def simplify_punctuation(text): + """ + This function simplifies doubled or more complex punctuation. The exception is '...'. + """ + corrected = str(text) + corrected = re.sub(r"([!?,;:-])\1+", r"\1", corrected) + corrected = re.sub(r"\.{2,}", r"...", corrected) + corrected = re.sub(r"\s?-\s?", r"-", corrected) + return corrected + + +def normalize_whitespace(text): + """ + This function normalizes whitespaces, removing duplicates and converting all to standard spaces + """ + corrected = str(text) + corrected = re.sub(r"//t", r"\t", corrected) + corrected = re.sub(r"\n", r" ", corrected) + corrected = re.sub(r"_", r" ", corrected) + corrected = re.sub(r"\r", r" ", corrected) + corrected = re.sub(r"\t", r" ", corrected) + corrected = re.sub(r"\s+", r" ", corrected) + return corrected.strip(" ") diff --git a/projects/mm_gdino_clip/script/utils/unionfind.py b/projects/mm_gdino_clip/script/utils/unionfind.py new file mode 100644 index 00000000000..9617a7b20b5 --- /dev/null +++ b/projects/mm_gdino_clip/script/utils/unionfind.py @@ -0,0 +1,31 @@ +# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved +"""Simple union find structure implementation""" + + +class UnionFind: + """Optimized union find structure""" + + def __init__(self, n): + """Initialize a union find with n components""" + self.compo = list(range(n)) + self.weight = [1] * n + self.nb_compo = n + + def get_nb_compo(self): + return self.nb_compo + + def find(self, x): + if self.compo[x] == x: + return x + self.compo[x] = self.find(self.compo[x]) + return self.compo[x] + + def unite(self, a, b): + fa = self.find(a) + fb = self.find(b) + if fa != fb: + self.nb_compo -= 1 + if self.weight[fb] > self.weight[fa]: + fa, fb = fb, fa + self.compo[fb] = fa + self.weight[fa] += self.weight[fb] From c8e61633997f9d6807292156d649c51110f4e45c Mon Sep 17 00:00:00 2001 From: huanghaian Date: Tue, 9 Jan 2024 18:50:12 +0800 Subject: [PATCH 12/24] update configs --- ... 
grounding_dino_swin-t_pretrain_obj365.py} | 0 ...in-t_pretrain_obj365_goldg_grit9m_v3det.py | 117 ++++++++++++++++++ projects/mm_gdino_clip/script/gqa2rec.py | 3 +- 3 files changed, 118 insertions(+), 2 deletions(-) rename projects/mm_gdino_clip/configs/{grounding_dino_swin-t_pretrain_obj365_goldg.py => grounding_dino_swin-t_pretrain_obj365.py} (100%) create mode 100644 projects/mm_gdino_clip/configs/grounding_dino_swin-t_pretrain_obj365_goldg_grit9m_v3det.py diff --git a/projects/mm_gdino_clip/configs/grounding_dino_swin-t_pretrain_obj365_goldg.py b/projects/mm_gdino_clip/configs/grounding_dino_swin-t_pretrain_obj365.py similarity index 100% rename from projects/mm_gdino_clip/configs/grounding_dino_swin-t_pretrain_obj365_goldg.py rename to projects/mm_gdino_clip/configs/grounding_dino_swin-t_pretrain_obj365.py diff --git a/projects/mm_gdino_clip/configs/grounding_dino_swin-t_pretrain_obj365_goldg_grit9m_v3det.py b/projects/mm_gdino_clip/configs/grounding_dino_swin-t_pretrain_obj365_goldg_grit9m_v3det.py new file mode 100644 index 00000000000..16ca1491389 --- /dev/null +++ b/projects/mm_gdino_clip/configs/grounding_dino_swin-t_pretrain_obj365_goldg_grit9m_v3det.py @@ -0,0 +1,117 @@ +_base_ = 'grounding_dino_swin-t_pretrain_obj365.py' + +o365v1_od_dataset = dict( + type='ODVGRECDataset', + data_root='data/objects365v1/', + ann_file='o365v1_train_odvg.json', + label_map_file='o365v1_label_map.json', + data_prefix=dict(img='train/'), + filter_cfg=dict(filter_empty_gt=False), + pipeline=_base_.train_pipeline, + return_classes=True, + backend_args=None, +) + +flickr30k_dataset = dict( + type='ODVGRECDataset', + data_root='data/flickr30k/', + ann_file='flickr30k_separateGT_train_rec.json', + label_map_file=None, + data_prefix=dict(img='flickr30k_images/'), + filter_cfg=dict(filter_empty_gt=False), + pipeline=_base_.train_pipeline, + return_classes=True, + backend_args=None) + +gqa_dataset = dict( + type='ODVGRECDataset', + data_root='data/gqa/', + ann_file='gqa_rec.json', + label_map_file=None, + data_prefix=dict(img='images/'), + filter_cfg=dict(filter_empty_gt=False), + pipeline=_base_.train_pipeline, + return_classes=True, + backend_args=None) + +v3d_train_pipeline = [ + dict(type='LoadImageFromFile', backend_args=_base_.backend_args), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='RandomFlip', prob=0.5), + dict( + type='RandomChoice', + transforms=[ + [ + dict( + type='RandomChoiceResize', + scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), + (608, 1333), (640, 1333), (672, 1333), (704, 1333), + (736, 1333), (768, 1333), (800, 1333)], + keep_ratio=True) + ], + [ + dict( + type='RandomChoiceResize', + # The radio of all image in train dataset < 7 + # follow the original implement + scales=[(400, 4200), (500, 4200), (600, 4200)], + keep_ratio=True), + dict( + type='RandomCrop', + crop_type='absolute_range', + crop_size=(384, 600), + allow_negative_crop=True), + dict( + type='RandomChoiceResize', + scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), + (608, 1333), (640, 1333), (672, 1333), (704, 1333), + (736, 1333), (768, 1333), (800, 1333)], + keep_ratio=True) + ] + ]), + dict(type='FilterAnnotations', min_gt_bbox_wh=(1e-2, 1e-2)), + dict( + type='RandomSamplingNegPosV2', + tokenizer_name=_base_.lang_model_name, + num_sample_negative=85, + # change this + label_map_file='data/V3Det/annotations/v3det_2023_v1_label_map.json', + max_tokens=256), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor', 'flip', 
'flip_direction', 'text', + 'custom_entities', 'tokens_positive', 'dataset_mode')) +] +v3det_dataset = dict( + type='ODVGDataset', + data_root='data/V3Det/', + ann_file='annotations/v3det_2023_v1_train_od.json', + label_map_file='annotations/v3det_2023_v1_label_map.json', + data_prefix=dict(img=''), + filter_cfg=dict(filter_empty_gt=False), + need_text=False, # change this + pipeline=v3d_train_pipeline, + return_classes=True, + backend_args=None) + +grit_dataset = dict( + type='ODVGDataset', + data_root='grit_processed/', + ann_file='grit20m_rec.json', + label_map_file=None, + data_prefix=dict(img=''), + filter_cfg=dict(filter_empty_gt=False), + pipeline=_base_.train_pipeline, + return_classes=True, + backend_args=None) + +train_dataloader = dict( + sampler=dict( + _delete_=True, + type='CustomSampleSizeSampler', + dataset_size=[-1, -1, -1, -1, 500000]), + dataset=dict(datasets=[ + o365v1_od_dataset, flickr30k_dataset, gqa_dataset, v3det_dataset, + grit_dataset + ])) diff --git a/projects/mm_gdino_clip/script/gqa2rec.py b/projects/mm_gdino_clip/script/gqa2rec.py index 677f776efa7..7c0e55f4422 100644 --- a/projects/mm_gdino_clip/script/gqa2rec.py +++ b/projects/mm_gdino_clip/script/gqa2rec.py @@ -33,8 +33,7 @@ # pip install nltk spacy - - +# python -m spacy download en_core_web_sm def parse_args(): parser = argparse.ArgumentParser("Conversion script") From dbf0bc4af15f6e6addddfcc58fa2923ec6a59edf Mon Sep 17 00:00:00 2001 From: huanghaian Date: Tue, 9 Jan 2024 19:05:37 +0800 Subject: [PATCH 13/24] fix bug --- projects/mm_gdino_clip/odvgrec.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/projects/mm_gdino_clip/odvgrec.py b/projects/mm_gdino_clip/odvgrec.py index b6afec51cb8..022724d70bf 100644 --- a/projects/mm_gdino_clip/odvgrec.py +++ b/projects/mm_gdino_clip/odvgrec.py @@ -89,17 +89,20 @@ def load_data_list(self) -> List[dict]: i = 0 for bbox, exp, label in zip(bboxes, bbox_exp, bbox_labels): instance = {} - x1, y1, x2, y2 = bbox - inter_w = max(0, min(x2, data['width']) - max(x1, 0)) - inter_h = max(0, min(y2, data['height']) - max(y1, 0)) - if inter_w * inter_h == 0: - continue - if (x2 - x1) < 1 or (y2 - y1) < 1: - continue - instance['ignore_flag'] = 0 - instance['bbox'] = bbox - instance['bbox_label'] = int(label) - instances.append(instance) + if not isinstance(bbox[0], list): + bbox = [bbox] + for b in bbox: + x1, y1, x2, y2 = b + inter_w = max(0, min(x2, data['width']) - max(x1, 0)) + inter_h = max(0, min(y2, data['height']) - max(y1, 0)) + if inter_w * inter_h == 0: + continue + if (x2 - x1) < 1 or (y2 - y1) < 1: + continue + instance['ignore_flag'] = 0 + instance['bbox'] = bbox + instance['bbox_label'] = int(label) + instances.append(instance) phrases[i] = exp i += 1 From 6db3637da6781cdb64a7d7c462b8f148aa8be28a Mon Sep 17 00:00:00 2001 From: huanghaian Date: Tue, 9 Jan 2024 19:10:35 +0800 Subject: [PATCH 14/24] fix bug --- mmdet/models/layers/transformer/grounding_dino_layers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mmdet/models/layers/transformer/grounding_dino_layers.py b/mmdet/models/layers/transformer/grounding_dino_layers.py index e1344793c51..50e858c63d3 100644 --- a/mmdet/models/layers/transformer/grounding_dino_layers.py +++ b/mmdet/models/layers/transformer/grounding_dino_layers.py @@ -238,17 +238,17 @@ def forward(self, layer_id].self_attn_cfg.num_heads if text_self_attention_masks is None: # rec - key_padding_mask = text_attention_mask + l_key_padding_mask = text_attention_mask else: # 
phrase grounding - key_padding_mask = None + l_key_padding_mask = None memory_text = self.text_layers[layer_id]( query=memory_text, query_pos=(pos_text if pos_text is not None else None), attn_mask=~text_self_attention_masks.repeat( text_num_heads, 1, 1) if text_self_attention_masks is not None else None, # note we use ~ for mask here - key_padding_mask=key_padding_mask, + key_padding_mask=l_key_padding_mask, ) output = layer( query=output, From 54a0cfeacb7fcf20f96907fed099726c858226b6 Mon Sep 17 00:00:00 2001 From: huanghaian Date: Tue, 9 Jan 2024 19:16:25 +0800 Subject: [PATCH 15/24] fix bug --- projects/mm_gdino_clip/odvgrec.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/projects/mm_gdino_clip/odvgrec.py b/projects/mm_gdino_clip/odvgrec.py index 022724d70bf..fce3b17cb11 100644 --- a/projects/mm_gdino_clip/odvgrec.py +++ b/projects/mm_gdino_clip/odvgrec.py @@ -100,7 +100,7 @@ def load_data_list(self) -> List[dict]: if (x2 - x1) < 1 or (y2 - y1) < 1: continue instance['ignore_flag'] = 0 - instance['bbox'] = bbox + instance['bbox'] = b instance['bbox_label'] = int(label) instances.append(instance) phrases[i] = exp From 60ef88857a980fcf3126d2a477d88c1eefb680e3 Mon Sep 17 00:00:00 2001 From: huanghaian Date: Tue, 9 Jan 2024 19:25:50 +0800 Subject: [PATCH 16/24] fix bug --- projects/mm_gdino_clip/text_transformers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/projects/mm_gdino_clip/text_transformers.py b/projects/mm_gdino_clip/text_transformers.py index b2862241bbd..2951c71bb83 100644 --- a/projects/mm_gdino_clip/text_transformers.py +++ b/projects/mm_gdino_clip/text_transformers.py @@ -168,9 +168,11 @@ def rec_aug(self, results): keys, size=num_negatives, replace=False): if i not in results['img_path']: others_exp = results['image_to_exp'][i] + if len(others_exp) == 0: + continue if isinstance(others_exp, list): others_exp = random.choice(others_exp) - if isinstance(others_exp, list): + if isinstance(others_exp, list) and len(others_exp) > 0: others_exp = random.choice(others_exp) negative_label_list.add(others_exp) From 5aae2a5ba828f87b29be6dc639dece582f120c86 Mon Sep 17 00:00:00 2001 From: huanghaian Date: Wed, 10 Jan 2024 11:31:54 +0800 Subject: [PATCH 17/24] fix ass bug --- mmdet/models/task_modules/assigners/hungarian_assigner.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/mmdet/models/task_modules/assigners/hungarian_assigner.py b/mmdet/models/task_modules/assigners/hungarian_assigner.py index a6745a36cdc..64afa37e9a9 100644 --- a/mmdet/models/task_modules/assigners/hungarian_assigner.py +++ b/mmdet/models/task_modules/assigners/hungarian_assigner.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from typing import List, Optional, Union - +import numpy as np import torch from mmengine import ConfigDict from mmengine.structures import InstanceData @@ -128,6 +128,11 @@ def assign(self, raise ImportError('Please run "pip install scipy" ' 'to install scipy first.') + has_nan = np.isnan(cost).any() + if has_nan: + print(f' has nan {cost}, replace to 10000000.0') + cost[np.isnan(cost)] = 10000000.0 + matched_row_inds, matched_col_inds = linear_sum_assignment(cost) matched_row_inds = torch.from_numpy(matched_row_inds).to(device) matched_col_inds = torch.from_numpy(matched_col_inds).to(device) From 9a226a6660d6f6682082be7b0f3460c97455ce15 Mon Sep 17 00:00:00 2001 From: huanghaian Date: Wed, 10 Jan 2024 11:34:38 +0800 Subject: [PATCH 18/24] add file --- .../mm_gdino_clip/script/merge_vg_to_rec.py | 69 +++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 projects/mm_gdino_clip/script/merge_vg_to_rec.py diff --git a/projects/mm_gdino_clip/script/merge_vg_to_rec.py b/projects/mm_gdino_clip/script/merge_vg_to_rec.py new file mode 100644 index 00000000000..872cfdfedd4 --- /dev/null +++ b/projects/mm_gdino_clip/script/merge_vg_to_rec.py @@ -0,0 +1,69 @@ +import json +import jsonlines + +root_path = '/home/PJLAB/huanghaian/dataset/flickr30k_entities/' +rec_path = root_path + 'flickr30k_separateGT_train_rec.json' +vg_path = root_path + 'final_flickr_separateGT_train_vg.json' + +with open(rec_path, 'r') as f: + rec_data_list = [json.loads(line) for line in f] + +rec_data_list_name = [data['filename'] for data in rec_data_list] + +with open(vg_path, 'r') as f: + vg_data_list = [json.loads(line) for line in f] + +num = 0 +in_num = 0 + +for vg_data in vg_data_list: + anno = vg_data['grounding'] + regions = anno['regions'] + + if len(regions) > 1: + continue + + filename = vg_data['filename'] + caption = anno['caption'] + bbox = regions[0]['bbox'] + + index = rec_data_list_name.index(filename) + if index == -1: + continue + + if not isinstance(bbox[0], list): + bbox = [bbox] + bbox = set([sum(r) for r in bbox]) + + rec_data = rec_data_list[index] + anno = rec_data.get('referring', {}) + instances = [obj for obj in anno.get('instances', [])] + for ins in instances: + rec_bbox = ins['bbox'] + if not isinstance(rec_bbox[0], list): + rec_bbox = [rec_bbox] + rec_bbox = set([sum(r) for r in rec_bbox]) + if rec_bbox == bbox: + if isinstance(ins['exp'], list): + is_same = False + for exp in ins['exp']: + if exp.lower() == caption.lower(): + is_same = True + break + if not is_same: + in_num += 1 + ins['exp'].append(caption) + else: + if ins['exp'].lower() != caption.lower(): + in_num += 1 + ins['exp'] = [ins['exp'], caption] + break + num += 1 + +print(num) +print(in_num) + +out_path = root_path + 'flickr30k_separateGT_train_mergevg_rec.json' +with jsonlines.open(out_path, mode='w') as writer: + writer.write_all(rec_data_list) +print(f'save to {out_path}') From af680d6776529ba4753cb850b0b82b5ea2eabbd7 Mon Sep 17 00:00:00 2001 From: huanghaian Date: Wed, 10 Jan 2024 15:43:17 +0800 Subject: [PATCH 19/24] add file --- .../mm_gdino_clip/script/grit_vg_to_rec.py | 22 ++++ ..._vg_to_rec.py => merge_flickrvg_to_rec.py} | 6 +- .../script/merge_gqavg_to_rec.py | 112 ++++++++++++++++++ 3 files changed, 138 insertions(+), 2 deletions(-) create mode 100644 projects/mm_gdino_clip/script/grit_vg_to_rec.py rename projects/mm_gdino_clip/script/{merge_vg_to_rec.py => merge_flickrvg_to_rec.py} (94%) create mode 100644 projects/mm_gdino_clip/script/merge_gqavg_to_rec.py diff --git 
a/projects/mm_gdino_clip/script/grit_vg_to_rec.py b/projects/mm_gdino_clip/script/grit_vg_to_rec.py new file mode 100644 index 00000000000..f8183e3ad29 --- /dev/null +++ b/projects/mm_gdino_clip/script/grit_vg_to_rec.py @@ -0,0 +1,22 @@ +import json +import jsonlines + +root_path = '/mnt/workspace/zhaoxiangyu/code_new/grounding_mm_mine/grit_try/' +grit_path = root_path + 'grit_ref_all_after_filter.jsonl' + +with open(grit_path, 'r') as f: + rec_data_list = [json.loads(line) for line in f] + +for data in rec_data_list: + referring = data['referring'] + new_dict = {} + for ref in referring: + new_dict['exp'] = ref['phrase'] + new_dict['bbox'] = ref['bbox'] + data['referring'] = {} + data['referring']['instances'] = new_dict + +out_path = root_path + 'grit_ref_all_after_filter_rec.json' +with jsonlines.open(out_path, mode='w') as writer: + writer.write_all(rec_data_list) +print(f'save to {out_path}') diff --git a/projects/mm_gdino_clip/script/merge_vg_to_rec.py b/projects/mm_gdino_clip/script/merge_flickrvg_to_rec.py similarity index 94% rename from projects/mm_gdino_clip/script/merge_vg_to_rec.py rename to projects/mm_gdino_clip/script/merge_flickrvg_to_rec.py index 872cfdfedd4..5fd72f0f591 100644 --- a/projects/mm_gdino_clip/script/merge_vg_to_rec.py +++ b/projects/mm_gdino_clip/script/merge_flickrvg_to_rec.py @@ -20,6 +20,7 @@ anno = vg_data['grounding'] regions = anno['regions'] + # 每个 caption 只有一个 phrase if len(regions) > 1: continue @@ -43,6 +44,7 @@ if not isinstance(rec_bbox[0], list): rec_bbox = [rec_bbox] rec_bbox = set([sum(r) for r in rec_bbox]) + # 严格匹配 if rec_bbox == bbox: if isinstance(ins['exp'], list): is_same = False @@ -60,8 +62,8 @@ break num += 1 -print(num) -print(in_num) +print(num) # 17233 +print(in_num) # 17111 out_path = root_path + 'flickr30k_separateGT_train_mergevg_rec.json' with jsonlines.open(out_path, mode='w') as writer: diff --git a/projects/mm_gdino_clip/script/merge_gqavg_to_rec.py b/projects/mm_gdino_clip/script/merge_gqavg_to_rec.py new file mode 100644 index 00000000000..297ba3d3c0a --- /dev/null +++ b/projects/mm_gdino_clip/script/merge_gqavg_to_rec.py @@ -0,0 +1,112 @@ +import json +import jsonlines +import re +import tqdm + +root_path = '/home/PJLAB/huanghaian/dataset/gqa/' +rec_path = root_path + 'gqa_rec.json' +vg_path = root_path + 'final_mixed_train_no_coco_vg.json' + +with open(rec_path, 'r') as f: + rec_data_list = [json.loads(line) for line in f] + +rec_data_list_name = [data['filename'] for data in rec_data_list] + +with open(vg_path, 'r') as f: + vg_data_list = [json.loads(line) for line in f] + + +def split_sentence(sentence): + pattern = r'([?.])' # 正则表达式模式,匹配问号 "?" 或句号 "." 
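+    # The capturing group makes re.split keep the "?"/"." delimiters, so the
+    # even/odd zip below re-attaches each delimiter to its preceding sentence.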
+ sentences = re.split(pattern, sentence) + sentences = [s.strip() + p for s, p in zip(sentences[0::2], sentences[1::2])] + return sentences + + +num = 0 +in_num = 0 +new_results = [] + +for vg_data in tqdm.tqdm(vg_data_list): + filename = vg_data['filename'] + anno = vg_data['grounding'] + regions = anno['regions'] + all_phrase = [r['phrase'] for r in regions] + caption = anno['caption'] + # 按照分隔符切割为多段 + caption_list = split_sentence(caption) + + for caption in caption_list: + if caption.endswith('?'): # 问句不要了 + continue + count = 0 + for i, p in enumerate(all_phrase): + # 如果这个 phrase 是列表,则抛弃 + if isinstance(p, list): + break + # 如果这个 caption 位于多个 phrase 中,则抛弃 + if p in caption: + index = i + count += 1 + if count > 1 or count == 0: + continue + num += 1 + + # 我们只需要这个 caption 中只有一个名词短语的数据 + data = regions[index] + new_results.append({'bbox': data['bbox'], 'exp': caption, 'filename': filename, 'height': vg_data['height'], + 'width': vg_data['width']}) + +print(num) # 989203 +print(len(new_results), new_results[0]) + +new_image = 0 +for new in tqdm.tqdm(new_results): + filename = new.pop('filename') + width = new.pop('width') + height = new.pop('height') + bbox = new['bbox'] + caption = new['exp'] + if not isinstance(bbox[0], list): + bbox = [bbox] + new_bbox = set([sum(r) for r in bbox]) + + if filename not in rec_data_list_name: + new_image += 1 + rec_data_list.append({'filename': filename, 'width': width, 'height': height, + 'referring': {'instances': [{'bbox': new['bbox'], 'exp': new['exp']}]}}) + rec_data_list_name = [data['filename'] for data in rec_data_list] + else: + index = rec_data_list_name.index(filename) + rec_data = rec_data_list[index] + anno = rec_data.get('referring', {}) + instances = [obj for obj in anno.get('instances', [])] + for ins in instances: + rec_bbox = ins['bbox'] + if not isinstance(rec_bbox[0], list): + rec_bbox = [rec_bbox] + rec_bbox = set([sum(r) for r in rec_bbox]) + # 非常严格的匹配策略,确保不会出现错误 + if rec_bbox == new_bbox: + if isinstance(ins['exp'], list): + is_same = False + for exp in ins['exp']: + if exp.lower() == caption.lower(): + is_same = True + break + if not is_same: + in_num += 1 + ins['exp'].append(caption) + else: + if ins['exp'].lower() != caption.lower(): + in_num += 1 + ins['exp'] = [ins['exp'], caption] + break + +print(in_num) # 47266 +print(new_image) # 12052 + +out_path = root_path + 'gqa_mergevg_rec.json' +with jsonlines.open(out_path, mode='w') as writer: + writer.write_all(rec_data_list) +print(f'save to {out_path}') From 22b3359b72678877721688afbcea3ad46ea400cf Mon Sep 17 00:00:00 2001 From: huanghaian Date: Fri, 12 Jan 2024 09:56:49 +0800 Subject: [PATCH 20/24] fix bug --- .../models/dense_heads/grounding_dino_head.py | 19 ++++++++++- projects/mm_gdino_clip/batch_sampler.py | 20 ++++++++++- projects/mm_gdino_clip/text_transformers.py | 33 +++++++++++++++++++ 3 files changed, 70 insertions(+), 2 deletions(-) diff --git a/mmdet/models/dense_heads/grounding_dino_head.py b/mmdet/models/dense_heads/grounding_dino_head.py index e00e93467e2..468acc17aeb 100644 --- a/mmdet/models/dense_heads/grounding_dino_head.py +++ b/mmdet/models/dense_heads/grounding_dino_head.py @@ -594,6 +594,9 @@ def loss_by_feat_single(self, cls_scores: Tensor, bbox_preds: Tensor, else: loss_cls = self.loss_cls( cls_scores, labels, label_weights, avg_factor=cls_avg_factor) + if torch.isnan(loss_cls): + print(f'has nan of loss_cls') + loss_cls = cls_scores.sum() * 0 # Compute the average number of gt boxes across all gpus, for # normalization purposes @@ -620,10 +623,15 @@ 
def loss_by_feat_single(self, cls_scores: Tensor, bbox_preds: Tensor, # regression IoU loss, defaultly GIoU loss loss_iou = self.loss_iou( bboxes, bboxes_gt, bbox_weights, avg_factor=num_total_pos) - + if torch.isnan(loss_iou): + print(f'has nan of loss_iou') + loss_iou = bboxes.sum() * 0 # regression L1 loss loss_bbox = self.loss_bbox( bbox_preds, bbox_targets, bbox_weights, avg_factor=num_total_pos) + if torch.isnan(loss_bbox): + print(f'has nan of loss_bbox') + loss_bbox = bbox_preds.sum() * 0 return loss_cls, loss_bbox, loss_iou def _loss_dn_single(self, dn_cls_scores: Tensor, dn_bbox_preds: Tensor, @@ -703,6 +711,9 @@ def _loss_dn_single(self, dn_cls_scores: Tensor, dn_bbox_preds: Tensor, labels, label_weights, avg_factor=cls_avg_factor) + if torch.isnan(loss_cls): + print(f'has nan of dn loss_cls') + loss_cls = cls_scores.sum() * 0 else: loss_cls = torch.zeros( 1, dtype=cls_scores.dtype, device=cls_scores.device) @@ -732,10 +743,16 @@ def _loss_dn_single(self, dn_cls_scores: Tensor, dn_bbox_preds: Tensor, # regression IoU loss, defaultly GIoU loss loss_iou = self.loss_iou( bboxes, bboxes_gt, bbox_weights, avg_factor=num_total_pos) + if torch.isnan(loss_iou): + print(f'has nan of dn loss_iou') + loss_iou = bboxes.sum() * 0 # regression L1 loss loss_bbox = self.loss_bbox( bbox_preds, bbox_targets, bbox_weights, avg_factor=num_total_pos) + if torch.isnan(loss_bbox): + print(f'has nan of dn loss_bbox') + loss_bbox = bbox_preds.sum() * 0 return loss_cls, loss_bbox, loss_iou def _get_dn_targets_single(self, gt_instances: InstanceData, diff --git a/projects/mm_gdino_clip/batch_sampler.py b/projects/mm_gdino_clip/batch_sampler.py index 9391356a4e6..2124961cf75 100644 --- a/projects/mm_gdino_clip/batch_sampler.py +++ b/projects/mm_gdino_clip/batch_sampler.py @@ -28,7 +28,8 @@ def __init__(self, assert drop_last is True def __iter__(self) -> Sequence[int]: - + batch_count = 0 + total_count = len(self.sampler) // self.batch_size for idx in self.sampler: wh_mode = self.sampler.dataset.get_wh_mode(idx) dataset_mode, height, width = wh_mode @@ -39,6 +40,7 @@ def __iter__(self) -> Sequence[int]: # TODO # if np.random.random() >= 1 - self.od_to_rec_prob: # dataset_mode = 'REC' + dataset_mode = 'REC' od_to_rec_flag = True else: od_to_rec_flag = False @@ -54,6 +56,7 @@ def __iter__(self) -> Sequence[int]: # yield a batch of indices in the same aspect ratio group if len(bucket) == self.batch_size: yield bucket[:] + batch_count += 1 del bucket[:] # yield the rest data and reset the bucket @@ -61,17 +64,29 @@ def __iter__(self) -> Sequence[int]: left_vg_data = self._aspect_ratio_buckets[1] + self._aspect_ratio_buckets[3] self._aspect_ratio_buckets = [[] for _ in range(2 * 2)] + if batch_count >= total_count: + left_rec_data = [] + left_vg_data = [] + while len(left_rec_data) > 0: if len(left_rec_data) > self.batch_size: yield left_rec_data[:self.batch_size] + batch_count += 1 left_rec_data = left_rec_data[self.batch_size:] + if batch_count >= total_count: + left_rec_data = [] + left_vg_data = [] else: break while len(left_vg_data) > 0: if len(left_vg_data) > self.batch_size: yield left_vg_data[:self.batch_size] + batch_count += 1 left_vg_data = left_vg_data[self.batch_size:] + if batch_count >= total_count: + left_rec_data = [] + left_vg_data = [] else: break @@ -84,7 +99,10 @@ def __iter__(self) -> Sequence[int]: all_left_data = left_rec_data + left_vg_data while len(all_left_data) > 0: yield all_left_data[:self.batch_size] + batch_count += 1 all_left_data = all_left_data[self.batch_size:] + if 
batch_count >= total_count: + all_left_data = [] def __len__(self) -> int: if self.drop_last: diff --git a/projects/mm_gdino_clip/text_transformers.py b/projects/mm_gdino_clip/text_transformers.py index 2951c71bb83..1c6694ef13a 100644 --- a/projects/mm_gdino_clip/text_transformers.py +++ b/projects/mm_gdino_clip/text_transformers.py @@ -19,6 +19,31 @@ import numpy as np +def clean_string(phrase): + # return re.sub(r"([.,'!?\"()*#:;])", "", phrase.lower()).replace("-", " ").replace("/", " ") + + phrase = re.sub(r"([.,'!?\"()*#:;])", "", phrase.lower()).replace("-", " ").replace("/", " ") + phrase = phrase.strip("\n").strip("\r").strip().lstrip(" ").rstrip(" ") + phrase = re.sub(" +", " ", phrase) + + replacements = { + "½": "half", + "—": "-", + "™": "", + "¢": "cent", + "ç": "c", + "û": "u", + "é": "e", + "°": " degree", + "è": "e", + "…": "", + } + for k, v in replacements.items(): + phrase = phrase.replace(k, v) + + return phrase + + def clean_name(name): name = re.sub(r'\(.*\)', '', name) name = re.sub(r'_', ' ', name) @@ -204,6 +229,14 @@ def rec_aug(self, results): results['gt_bboxes'] = gt_bboxes results['gt_bboxes_labels'] = gt_labels + + new_text = [clean_string(phrase) for phrase in new_text] + + if results.get('flip', False): + new_text = [ + phrase.replace("left", "@").replace("right", "left").replace("@", "right") + for phrase in new_text + ] results['text'] = new_text else: # OD valid_negative_indexes = list(text.keys()) From 549c68eb6211b184746f8a3d1e6133bb6a08f71a Mon Sep 17 00:00:00 2001 From: huanghaian Date: Tue, 16 Jan 2024 13:07:41 +0800 Subject: [PATCH 21/24] support rec eval --- mmdet/models/detectors/grounding_dino.py | 3 ++- mmdet/models/language_models/bert.py | 5 ++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/mmdet/models/detectors/grounding_dino.py b/mmdet/models/detectors/grounding_dino.py index 43af4169571..fabdb2979fa 100644 --- a/mmdet/models/detectors/grounding_dino.py +++ b/mmdet/models/detectors/grounding_dino.py @@ -353,7 +353,8 @@ def pre_decoder( output_memory, output_proposals = self.gen_encoder_output_proposals( memory, memory_mask, spatial_shapes) - if 'tokens_positive' in batch_data_samples[0] or 'token_positive_map' in batch_data_samples[0]: + if ('tokens_positive' in batch_data_samples[0] and batch_data_samples[0].tokens_positive !=-1) \ + or 'token_positive_map' in batch_data_samples[0]: need_expand = True else: need_expand = False diff --git a/mmdet/models/language_models/bert.py b/mmdet/models/language_models/bert.py index 0fec27acfb8..ad1156fde64 100644 --- a/mmdet/models/language_models/bert.py +++ b/mmdet/models/language_models/bert.py @@ -153,7 +153,10 @@ def forward(self, captions: Sequence[str], task='VG', **kwargs) -> dict: if task == 'REC': batch_len_captions = [len(item) for item in captions] - captions = [item for sublist in captions for item in sublist] + if isinstance(captions, tuple): + captions=list(captions) + if isinstance(captions[0], (list, tuple)): + captions = [item for sublist in captions for item in sublist] tokenized = self.tokenizer.batch_encode_plus( captions, From fa121846d0f7a38d6b7d737797a5f5be5f577dde Mon Sep 17 00:00:00 2001 From: huanghaian Date: Tue, 16 Jan 2024 13:07:44 +0800 Subject: [PATCH 22/24] support rec eval --- mmdet/engine/hooks/visualization_hook.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mmdet/engine/hooks/visualization_hook.py b/mmdet/engine/hooks/visualization_hook.py index 3408186b6ef..90df932f9af 100644 --- 
a/mmdet/engine/hooks/visualization_hook.py +++ b/mmdet/engine/hooks/visualization_hook.py @@ -390,7 +390,7 @@ def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict, gt_bboxes = gt_instances.get('bboxes', None) if gt_bboxes is not None and isinstance(gt_bboxes, BaseBoxes): gt_instances.bboxes = gt_bboxes.tensor - print(gt_labels, tokens_positive, gt_bboxes, img_path) + # print(gt_labels, tokens_positive, gt_bboxes, img_path) pred_instances = data_sample.pred_instances pred_instances = pred_instances[ pred_instances.scores > self.score_thr] @@ -416,8 +416,8 @@ def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict, self._visualizer.set_image(img) for label, bbox, color in zip(gt_labels, gt_bboxes, colors): - self._visualizer.draw_bboxes( - bbox, edge_colors=color, face_colors=color, alpha=0.3) + # self._visualizer.draw_bboxes( + # bbox, edge_colors=color, face_colors=color, alpha=0.3) self._visualizer.draw_bboxes( bbox, edge_colors=color, alpha=1) @@ -460,11 +460,11 @@ def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict, for label, bbox, color in zip(pred_labels, pred_bboxes, colors): - self._visualizer.draw_bboxes( - bbox, edge_colors=color, face_colors=color, alpha=0.3) + # self._visualizer.draw_bboxes( + # bbox, edge_colors=color, face_colors=color, alpha=0.3) self._visualizer.draw_bboxes( bbox, edge_colors=color, alpha=1) - print(pred_labels, pred_bboxes, pred_scores, colors) + # print(pred_labels, pred_bboxes, pred_scores, colors) areas = (pred_bboxes[:, 3] - pred_bboxes[:, 1]) * ( pred_bboxes[:, 2] - pred_bboxes[:, 0]) scales = _get_adaptive_scales(areas) From 3593a157a7a451d30ae0fb1d568bc332f0c34bb5 Mon Sep 17 00:00:00 2001 From: huanghaian Date: Tue, 16 Jan 2024 19:36:30 +0800 Subject: [PATCH 23/24] fix attention bug --- .../layers/transformer/grounding_dino_layers.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/mmdet/models/layers/transformer/grounding_dino_layers.py b/mmdet/models/layers/transformer/grounding_dino_layers.py index 50e858c63d3..43e4cf3a8ed 100644 --- a/mmdet/models/layers/transformer/grounding_dino_layers.py +++ b/mmdet/models/layers/transformer/grounding_dino_layers.py @@ -238,15 +238,23 @@ def forward(self, layer_id].self_attn_cfg.num_heads if text_self_attention_masks is None: # rec - l_key_padding_mask = text_attention_mask + # l_key_padding_mask = text_attention_mask + # text_self_attention_masks1=None + + l_key_padding_mask = None + text_self_attention_masks1 = \ + torch.eye(text_attention_mask.shape[1], + device=memory_text.device).bool().unsqueeze(0).repeat( + bs, 1, 1) else: # phrase grounding l_key_padding_mask = None + text_self_attention_masks1 = text_self_attention_masks memory_text = self.text_layers[layer_id]( query=memory_text, query_pos=(pos_text if pos_text is not None else None), - attn_mask=~text_self_attention_masks.repeat( - text_num_heads, 1, 1) if text_self_attention_masks is not None else None, + attn_mask=~text_self_attention_masks1.repeat( + text_num_heads, 1, 1) if text_self_attention_masks1 is not None else None, # note we use ~ for mask here key_padding_mask=l_key_padding_mask, ) From 3dc97c34b5f6e0975281f6b1cce134cfd0a8a759 Mon Sep 17 00:00:00 2001 From: huanghaian Date: Fri, 19 Jan 2024 17:21:51 +0800 Subject: [PATCH 24/24] fix dataset bug --- projects/mm_gdino_clip/odvgrec.py | 2 +- projects/mm_gdino_clip/text_transformers.py | 22 +++++++++++++++++---- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git 
a/projects/mm_gdino_clip/odvgrec.py b/projects/mm_gdino_clip/odvgrec.py index fce3b17cb11..d0be2b27a8a 100644 --- a/projects/mm_gdino_clip/odvgrec.py +++ b/projects/mm_gdino_clip/odvgrec.py @@ -88,10 +88,10 @@ def load_data_list(self) -> List[dict]: instances = [] i = 0 for bbox, exp, label in zip(bboxes, bbox_exp, bbox_labels): - instance = {} if not isinstance(bbox[0], list): bbox = [bbox] for b in bbox: + instance = {} x1, y1, x2, y2 = b inter_w = max(0, min(x2, data['width']) - max(x1, 0)) inter_h = max(0, min(y2, data['height']) - max(y1, 0)) diff --git a/projects/mm_gdino_clip/text_transformers.py b/projects/mm_gdino_clip/text_transformers.py index 1c6694ef13a..99d50b00cf5 100644 --- a/projects/mm_gdino_clip/text_transformers.py +++ b/projects/mm_gdino_clip/text_transformers.py @@ -204,11 +204,25 @@ def rec_aug(self, results): random.shuffle(positive_label_list) negative_label_list = list(negative_label_list) - random.shuffle(negative_label_list) - - label_list = positive_label_list + negative_label_list text = results['text'] # dict + _pos_texts = [text[p] for p in positive_label_list] + _flat_pos_texts = [] + for p in _pos_texts: + if isinstance(p, list): + p = [_p.lower() for _p in p] + _flat_pos_texts.extend(p) + else: + _flat_pos_texts.append(p.lower()) + _negative_label_list = [] + for n in negative_label_list: + if n.lower() not in _flat_pos_texts: + _negative_label_list.append(n) + negative_label_list = _negative_label_list + if len(negative_label_list) > 0: + random.shuffle(negative_label_list) + + label_list = positive_label_list + negative_label_list random.shuffle(label_list) @@ -266,7 +280,7 @@ def rec_aug(self, results): for i in np.random.choice( valid_negative_indexes, size=num_negatives, replace=False): - if i not in positive_label_list: + if int(i) not in positive_label_list: negative_label_list.add(i) random.shuffle(positive_label_list)
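The final text_transformers.py hunk amounts to a case-insensitive de-duplication step: flatten the positive expressions (which may be strings or lists of strings), then keep only the negative candidates that do not already appear among the positives before shuffling. Below is a minimal standalone sketch of that logic; the function and variable names are illustrative, not the project's actual API.

import random
from typing import Dict, List, Union

Phrase = Union[str, List[str]]

def filter_negative_phrases(text: Dict[int, Phrase],
                            positive_ids: List[int],
                            negatives: List[str]) -> List[str]:
    # Flatten the positive expressions to lowercase so the comparison is
    # case-insensitive, mirroring the _flat_pos_texts list in the patch.
    flat_pos: List[str] = []
    for pid in positive_ids:
        p = text[pid]
        if isinstance(p, list):
            flat_pos.extend(s.lower() for s in p)
        else:
            flat_pos.append(p.lower())
    # Drop any negative phrase that already occurs among the positives.
    kept = [n for n in negatives if n.lower() not in flat_pos]
    if kept:
        random.shuffle(kept)
    return kept

# Example: 'The Car' is dropped because it matches a positive expression.
text = {0: ['a red car', 'the car'], 1: 'a dog'}
print(filter_negative_phrases(text, [0], ['The Car', 'a cat']))  # ['a cat']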