Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 111 additions & 55 deletions rfdetr/util/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,65 +48,121 @@ def get_array(key):
epochs = get_array('epoch')
train_loss = get_array('train_loss')
test_loss = get_array('test_loss')
test_coco_eval = [h['test_coco_eval_bbox'] for h in self.history if 'test_coco_eval_bbox' in h]
ap50_90 = np.array([safe_index(x, 0) for x in test_coco_eval if x is not None], dtype=np.float32)
ap50 = np.array([safe_index(x, 1) for x in test_coco_eval if x is not None], dtype=np.float32)
ar50_90 = np.array([safe_index(x, 8) for x in test_coco_eval if x is not None], dtype=np.float32)

ema_coco_eval = [h['ema_test_coco_eval_bbox'] for h in self.history if 'ema_test_coco_eval_bbox' in h]
ema_ap50_90 = np.array([safe_index(x, 0) for x in ema_coco_eval if x is not None], dtype=np.float32)
ema_ap50 = np.array([safe_index(x, 1) for x in ema_coco_eval if x is not None], dtype=np.float32)
ema_ar50_90 = np.array([safe_index(x, 8) for x in ema_coco_eval if x is not None], dtype=np.float32)

fig, axes = plt.subplots(2, 2, figsize=(18, 12))
test_coco_eval_bbox = [h['test_coco_eval_bbox'] for h in self.history if 'test_coco_eval_bbox' in h]
ap50_90_bbox = np.array([safe_index(x, 0) for x in test_coco_eval_bbox if x is not None], dtype=np.float32)
ap50_bbox = np.array([safe_index(x, 1) for x in test_coco_eval_bbox if x is not None], dtype=np.float32)
ar50_90_bbox = np.array([safe_index(x, 8) for x in test_coco_eval_bbox if x is not None], dtype=np.float32)

ema_coco_eval_bbox = [h['ema_test_coco_eval_bbox'] for h in self.history if 'ema_test_coco_eval_bbox' in h]
ema_ap50_90_bbox = np.array([safe_index(x, 0) for x in ema_coco_eval_bbox if x is not None], dtype=np.float32)
ema_ap50_bbox = np.array([safe_index(x, 1) for x in ema_coco_eval_bbox if x is not None], dtype=np.float32)
ema_ar50_90_bbox = np.array([safe_index(x, 8) for x in ema_coco_eval_bbox if x is not None], dtype=np.float32)

test_coco_eval_masks = [h['test_coco_eval_masks'] for h in self.history if 'test_coco_eval_masks' in h]
ap50_90_masks = np.array([safe_index(x, 0) for x in test_coco_eval_masks if x is not None], dtype=np.float32)
ap50_masks = np.array([safe_index(x, 1) for x in test_coco_eval_masks if x is not None], dtype=np.float32)
ar50_90_masks = np.array([safe_index(x, 8) for x in test_coco_eval_masks if x is not None], dtype=np.float32)

ema_coco_eval_masks = [h['ema_test_coco_eval_masks'] for h in self.history if 'ema_test_coco_eval_masks' in h]
ema_ap50_90_masks = np.array([safe_index(x, 0) for x in ema_coco_eval_masks if x is not None], dtype=np.float32)
ema_ap50_masks = np.array([safe_index(x, 1) for x in ema_coco_eval_masks if x is not None], dtype=np.float32)
ema_ar50_90_masks = np.array([safe_index(x, 8) for x in ema_coco_eval_masks if x is not None], dtype=np.float32)

if len(test_coco_eval_masks) or len(ema_coco_eval_masks):
fig, axes = plt.subplots(4, 2, figsize=(18, 12))
else:
fig, axes = plt.subplots(2, 2, figsize=(18, 12))

# Subplot (0,0): Training and Validation Loss
# Training and Validation Loss
r, c = 0, 0
if len(epochs) > 0:
if len(train_loss):
axes[0][0].plot(epochs, train_loss, label='Training Loss', marker='o', linestyle='-')
axes[r][c].plot(epochs, train_loss, label='Training Loss', marker='o', linestyle='-')
if len(test_loss):
axes[0][0].plot(epochs, test_loss, label='Validation Loss', marker='o', linestyle='--')
axes[0][0].set_title('Training and Validation Loss')
axes[0][0].set_xlabel('Epoch Number')
axes[0][0].set_ylabel('Loss Value')
axes[0][0].legend()
axes[0][0].grid(True)

# Subplot (0,1): Average Precision @0.50
if ap50.size > 0 or ema_ap50.size > 0:
if ap50.size > 0:
axes[0][1].plot(epochs[:len(ap50)], ap50, marker='o', linestyle='-', label='Base Model')
if ema_ap50.size > 0:
axes[0][1].plot(epochs[:len(ema_ap50)], ema_ap50, marker='o', linestyle='--', label='EMA Model')
axes[0][1].set_title('Average Precision @0.50')
axes[0][1].set_xlabel('Epoch Number')
axes[0][1].set_ylabel('AP50')
axes[0][1].legend()
axes[0][1].grid(True)

# Subplot (1,0): Average Precision @0.50:0.95
if ap50_90.size > 0 or ema_ap50_90.size > 0:
if ap50_90.size > 0:
axes[1][0].plot(epochs[:len(ap50_90)], ap50_90, marker='o', linestyle='-', label='Base Model')
if ema_ap50_90.size > 0:
axes[1][0].plot(epochs[:len(ema_ap50_90)], ema_ap50_90, marker='o', linestyle='--', label='EMA Model')
axes[1][0].set_title('Average Precision @0.50:0.95')
axes[1][0].set_xlabel('Epoch Number')
axes[1][0].set_ylabel('AP')
axes[1][0].legend()
axes[1][0].grid(True)

# Subplot (1,1): Average Recall @0.50:0.95
if ar50_90.size > 0 or ema_ar50_90.size > 0:
if ar50_90.size > 0:
axes[1][1].plot(epochs[:len(ar50_90)], ar50_90, marker='o', linestyle='-', label='Base Model')
if ema_ar50_90.size > 0:
axes[1][1].plot(epochs[:len(ema_ar50_90)], ema_ar50_90, marker='o', linestyle='--', label='EMA Model')
axes[1][1].set_title('Average Recall @0.50:0.95')
axes[1][1].set_xlabel('Epoch Number')
axes[1][1].set_ylabel('AR')
axes[1][1].legend()
axes[1][1].grid(True)
axes[r][c].plot(epochs, test_loss, label='Validation Loss', marker='o', linestyle='--')
axes[r][c].set_title('Training and Validation Loss')
axes[r][c].set_xlabel('Epoch Number')
axes[r][c].set_ylabel('Loss Value')
axes[r][c].legend()
axes[r][c].grid(True)

# BBox Average Precision @0.50
r, c = 0, 1
if ap50_bbox.size > 0 or ema_ap50_bbox.size > 0:
if ap50_bbox.size > 0:
axes[r][c].plot(epochs[:len(ap50_bbox)], ap50_bbox, marker='o', linestyle='-', label='Base Model')
if ema_ap50_bbox.size > 0:
axes[r][c].plot(epochs[:len(ema_ap50_bbox)], ema_ap50_bbox, marker='o', linestyle='--', label='EMA Model')
axes[r][c].set_title('Average Precision @0.50 (BBox)')
axes[r][c].set_xlabel('Epoch Number')
axes[r][c].set_ylabel('AP50')
axes[r][c].legend()
axes[r][c].grid(True)

# BBox Average Precision @0.50:0.95
r, c = 1, 0
if ap50_90_bbox.size > 0 or ema_ap50_90_bbox.size > 0:
if ap50_90_bbox.size > 0:
axes[r][c].plot(epochs[:len(ap50_90_bbox)], ap50_90_bbox, marker='o', linestyle='-', label='Base Model')
if ema_ap50_90_bbox.size > 0:
axes[r][c].plot(epochs[:len(ema_ap50_90_bbox)], ema_ap50_90_bbox, marker='o', linestyle='--', label='EMA Model')
axes[r][c].set_title('Average Precision @0.50:0.95 (BBox)')
axes[r][c].set_xlabel('Epoch Number')
axes[r][c].set_ylabel('AP')
axes[r][c].legend()
axes[r][c].grid(True)

# BBox Average Recall @0.50:0.95
r, c = 1, 1
if ar50_90_bbox.size > 0 or ema_ar50_90_bbox.size > 0:
if ar50_90_bbox.size > 0:
axes[r][c].plot(epochs[:len(ar50_90_bbox)], ar50_90_bbox, marker='o', linestyle='-', label='Base Model')
if ema_ar50_90_bbox.size > 0:
axes[r][c].plot(epochs[:len(ema_ar50_90_bbox)], ema_ar50_90_bbox, marker='o', linestyle='--', label='EMA Model')
axes[r][c].set_title('Average Recall @0.50:0.95 (BBox)')
axes[r][c].set_xlabel('Epoch Number')
axes[r][c].set_ylabel('AR')
axes[r][c].legend()
axes[r][c].grid(True)

# Masks Average Precision @0.50
r, c = 2, 0
if ap50_masks.size > 0 or ema_ap50_masks.size > 0:
if ap50_masks.size > 0:
axes[r][c].plot(epochs[:len(ap50_masks)], ap50_masks, marker='o', linestyle='-', label='Base Model')
if ema_ap50_masks.size > 0:
axes[r][c].plot(epochs[:len(ema_ap50_masks)], ema_ap50_masks, marker='o', linestyle='--', label='EMA Model')
axes[r][c].set_title('Average Precision @0.50 (Masks)')
axes[r][c].set_xlabel('Epoch Number')
axes[r][c].set_ylabel('AP50')
axes[r][c].legend()
axes[r][c].grid(True)

# Masks Average Precision @0.50:0.95
r, c = 2, 1
if ap50_90_masks.size > 0 or ema_ap50_90_masks.size > 0:
if ap50_90_masks.size > 0:
axes[r][c].plot(epochs[:len(ap50_90_masks)], ap50_90_masks, marker='o', linestyle='-', label='Base Model')
if ema_ap50_90_masks.size > 0:
axes[r][c].plot(epochs[:len(ema_ap50_90_masks)], ema_ap50_90_masks, marker='o', linestyle='--', label='EMA Model')
axes[r][c].set_title('Average Precision @0.50:0.95 (Masks)')
axes[r][c].set_xlabel('Epoch Number')
axes[r][c].set_ylabel('AP')
axes[r][c].legend()
axes[r][c].grid(True)

# Masks Average Recall @0.50:0.95
r, c = 3, 0
if ar50_90_masks.size > 0 or ema_ar50_90_masks.size > 0:
if ar50_90_masks.size > 0:
axes[r][c].plot(epochs[:len(ar50_90_masks)], ar50_90_masks, marker='o', linestyle='-', label='Base Model')
if ema_ar50_90_masks.size > 0:
axes[r][c].plot(epochs[:len(ema_ar50_90_masks)], ema_ar50_90_masks, marker='o', linestyle='--', label='EMA Model')
axes[r][c].set_title('Average Recall @0.50:0.95 (Masks)')
axes[r][c].set_xlabel('Epoch Number')
axes[r][c].set_ylabel('AR')
axes[r][c].legend()
axes[r][c].grid(True)

plt.tight_layout()
plt.savefig(f"{self.output_dir}/{PLOT_FILE_NAME}")
Expand Down