Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions mmrotate/core/evaluation/eval_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,12 +276,20 @@ def print_map_summary(mean_ap,
if scale_ranges is not None:
assert len(scale_ranges) == num_scales

num_classes = len(results)
# num_classes = len(results)
# recalls = np.zeros((num_scales, num_classes), dtype=np.float32)

####################################################################################################################
num_classes = len(dataset) if dataset else len(results)
num_results_classes = len(results)
min_classes = min(num_classes, num_results_classes)
recalls = np.zeros((num_scales, min_classes), dtype=np.float32)
####################################################################################################################

recalls = np.zeros((num_scales, num_classes), dtype=np.float32)
aps = np.zeros((num_scales, num_classes), dtype=np.float32)
num_gts = np.zeros((num_scales, num_classes), dtype=int)
for i, cls_result in enumerate(results):
# for i, cls_result in enumerate(results):
for i, cls_result in enumerate(results[:min_classes]):
if cls_result['recall'].size > 0:
recalls[:, i] = np.array(cls_result['recall'], ndmin=2)[:, -1]
aps[:, i] = cls_result['ap']
Expand All @@ -290,7 +298,8 @@ def print_map_summary(mean_ap,
if dataset is None:
label_names = [str(i) for i in range(num_classes)]
else:
label_names = dataset
# label_names = dataset
label_names = dataset[:min_classes]

if not isinstance(mean_ap, list):
mean_ap = [mean_ap]
Expand All @@ -302,11 +311,13 @@ def print_map_summary(mean_ap,
table_data = [header]
for j in range(num_classes):
row_data = [
label_names[j], num_gts[i, j], results[j]['num_dets'],
# label_names[j], num_gts[i, j], results[j]['num_dets'],
label_names[j], num_gts[i, j], results[j].get('num_dets', 0),
f'{recalls[i, j]:.3f}', f'{aps[i, j]:.3f}'
]
table_data.append(row_data)
table_data.append(['mAP', '', '', '', f'{mean_ap[i]:.3f}'])
table = AsciiTable(table_data)
table.inner_footing_row_border = True
print_log('\n' + table.table, logger=logger)

4 changes: 3 additions & 1 deletion mmrotate/datasets/pipelines/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def _crop_data(self, results, crop_size, allow_negative_crop):
img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
img_shape = img.shape
results[key] = img
results['img_shape'] = img_shape
results['img_shape'] = img_shape[:2]

height, width, _ = img_shape

Expand All @@ -375,6 +375,8 @@ def _crop_data(self, results, crop_size, allow_negative_crop):
if (key == 'gt_bboxes' and not valid_inds.any()
and not allow_negative_crop):
return None

valid_inds = np.atleast_1d(valid_inds)
results[key] = bboxes[valid_inds, :]
# label fields. e.g. gt_labels and gt_labels_ignore
label_key = self.bbox2label.get(key)
Expand Down
6 changes: 5 additions & 1 deletion tools/analysis_tools/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,8 @@ def plot_confusion_matrix(confusion_matrix,
if save_dir is not None:
plt.savefig(
os.path.join(save_dir, 'confusion_matrix.png'), format='png')
plt.savefig(
os.path.join(save_dir, 'confusion_matrix.svg'), format='svg')
if show:
plt.show()

Expand Down Expand Up @@ -261,8 +263,10 @@ def main():
confusion_matrix,
dataset.CLASSES + ('background', ),
save_dir=args.save_dir,
show=args.show)
show=args.show,
color_theme=args.color_theme)


if __name__ == '__main__':
main()