diff --git a/rfdetr/detr.py b/rfdetr/detr.py index 1cb03a3..1f1d66d 100644 --- a/rfdetr/detr.py +++ b/rfdetr/detr.py @@ -139,6 +139,9 @@ def train_from_config(self, config: TrainConfig, **kwargs): raise ValueError(f"Invalid dataset file: {config.dataset_file}") if self.model_config.num_classes != num_classes: + logger.warning( + f"Reinitializing your detection head with {num_classes} classes." + ) self.model.reinitialize_detection_head(num_classes) train_config = config.dict() diff --git a/rfdetr/main.py b/rfdetr/main.py index f52a238..c873c0e 100644 --- a/rfdetr/main.py +++ b/rfdetr/main.py @@ -100,6 +100,9 @@ def __init__(self, **kwargs): checkpoint_num_classes = checkpoint['model']['class_embed.bias'].shape[0] if checkpoint_num_classes != args.num_classes + 1: + logger.warning( + f"Reinitializing detection head with {checkpoint_num_classes} classes" + ) self.reinitialize_detection_head(checkpoint_num_classes) # add support to exclude_keys # e.g., when load object365 pretrain, do not load `class_embed.[weight, bias]`