diff --git a/rfdetr/detr.py b/rfdetr/detr.py index 021bf69..311c4a5 100644 --- a/rfdetr/detr.py +++ b/rfdetr/detr.py @@ -126,8 +126,8 @@ def train_from_config(self, config: TrainConfig, **kwargs): os.path.join(config.dataset_dir, "train", "_annotations.coco.json"), "r" ) as f: anns = json.load(f) - num_classes = len(anns["categories"]) class_names = [c["name"] for c in anns["categories"] if c["supercategory"] != "none"] + num_classes = max(c["id"] for c in anns["categories"] if c["supercategory"] != "none") self.model.class_names = class_names if self.model_config.num_classes != num_classes: @@ -135,7 +135,7 @@ def train_from_config(self, config: TrainConfig, **kwargs): f"num_classes mismatch: model has {self.model_config.num_classes} classes, but your dataset has {num_classes} classes\n" f"reinitializing your detection head with {num_classes} classes." ) - self.model.reinitialize_detection_head(num_classes) + self.model.reinitialize_detection_head(num_classes+1) train_config = config.dict() model_config = self.model_config.dict()