
How to manage single class datasets #48

Open
picjul opened this issue Mar 25, 2025 · 6 comments

picjul commented Mar 25, 2025

Hi!

I am trying to perform a fine-tuning of the network on a single-class dataset.

This is the log:

UserWarning: 
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Loading pretrain weights
WARNING:rfdetr.detr:num_classes mismatch: model has 90 classes, but your dataset has 1 classes
reinitializing your detection head with 1 classes.
Not using distributed mode
git:
  sha: N/A, status: clean, branch: N/A

Namespace(num_classes=1, grad_accum_steps=1, amp=True, lr=0.0001, lr_encoder=0.00015, batch_size=4, weight_decay=0.0001, epochs=10, lr_drop=100, clip_max_norm=0.1, lr_vit_layer_decay=0.8, lr_component_decay=0.7, do_benchmark=False, dropout=0, drop_path=0.0, drop_mode='standard', drop_schedule='constant', cutoff_epoch=0, pretrained_encoder=None, pretrain_weights='rf-detr-base.pth', pretrain_exclude_keys=None, pretrain_keys_modify_to_load=None, pretrained_distiller=None, encoder='dinov2_windowed_small', vit_encoder_num_layers=12, window_block_indexes=None, position_embedding='sine', out_feature_indexes=[2, 5, 8, 11], freeze_encoder=False, layer_norm=True, rms_norm=False, backbone_lora=False, force_no_pretrain=False, dec_layers=3, dim_feedforward=2048, hidden_dim=256, sa_nheads=8, ca_nheads=16, num_queries=300, group_detr=13, two_stage=True, projector_scale=['P4'], lite_refpoint_refine=True, num_select=300, dec_n_points=2, decoder_norm='LN', bbox_reparam=True, freeze_batch_norm=False, set_cost_class=2, set_cost_bbox=5, set_cost_giou=2, cls_loss_coef=1.0, bbox_loss_coef=5, giou_loss_coef=2, focal_alpha=0.25, aux_loss=True, sum_group_losses=False, use_varifocal_loss=False, use_position_supervised_loss=False, ia_bce_loss=True, dataset_file='roboflow', coco_path=None, dataset_dir='/content/COCO_Ormelle_COLAB', square_resize_div_64=True, output_dir='output', dont_save_weights=False, checkpoint_interval=10, seed=42, resume='', start_epoch=0, eval=False, use_ema=True, ema_decay=0.993, ema_tau=100, num_workers=2, device='cpu', world_size=1, dist_url='env://', sync_bn=True, fp16_eval=False, encoder_only=False, backbone_only=False, resolution=560, use_cls_token=False, multi_scale=True, expanded_scales=True, warmup_epochs=0, lr_scheduler='step', lr_min_factor=0.0, distributed=False)
number of params: 31850710
[392, 448, 504, 560, 616, 672, 728, 784]
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
[392, 448, 504, 560, 616, 672, 728, 784]
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
Get benchmark
Start training
Grad accum steps:  1
Total batch size:  4
LENGTH OF DATA LOADER: 18
FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
UserWarning: torch.cuda.amp.GradScaler is enabled, but CUDA is not available.  Disabling.
FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
UserWarning: User provided device_type of 'cuda', but CUDA is not available. Disabling
UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at /pytorch/aten/src/ATen/native/TensorShape.cpp:3637.)

And this is the exception thrown:

[/usr/local/lib/python3.11/dist-packages/rfdetr/models/matcher.py](https://localhost:8080/#) in forward(self, outputs, targets, group_detr)
     85         neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log())
     86         pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
---> 87         cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]
     88 
     89         # Compute the L1 cost between boxes

IndexError: index 1 is out of bounds for dimension 0 with size 1
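The failing line can be reproduced in isolation. Below is a minimal stand-in (plain Python lists instead of torch tensors, for illustration only): `pos_cost_class` has one column per model class, so after the detection head is reinitialized with `num_classes=1` there is only column 0, yet the dataset's COCO category id (1) is used directly as a column index.

```python
# Stand-in for the matcher's cost_class indexing with a 1-class head.
pos_cost_class = [[0.3], [0.7]]   # shape: (num_queries=2, num_classes=1)
tgt_ids = [1]                     # category ids taken straight from the COCO JSON

try:
    # Mirrors pos_cost_class[:, tgt_ids]: select column 1, which does not exist.
    cost_class = [[row[j] for j in tgt_ids] for row in pos_cost_class]
except IndexError as err:
    print("IndexError:", err)     # same class of failure as in the traceback
```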

How should this type of dataset be handled?

System: Colab on CPU (on CUDA I think the same issue surfaces as `CUDA error: device-side assert triggered`)
Steps: the ones listed in `how-to-finetune-rf-detr-on-detection-dataset.ipynb`

Thank you!


picjul commented Mar 26, 2025

I think I have solved the problem. Using a sample dataset on Roboflow, I noticed that a "background" class, here called "Workers", is created by default with id 0, as in the following snippet:

"categories": [
        {
            "id": 0,
            "name": "Workers",
            "supercategory": "none"
        },
        {
            "id": 1,
            "name": "custom",
            "supercategory": "Workers"
        }
]

My `categories` array contained only one class:

"categories": [
        {
            "supercategory": "none",
            "id": 1,
            "name": "custom"
        }
]

By adjusting the JSON files, the training starts normally.
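For legacy datasets, the adjustment can be done programmatically. A minimal sketch (the function name and the `"background"` label are placeholders, not part of RF-DETR): prepend an id-0 background category to a COCO annotation file whose ids start at 1.

```python
import json

def add_background_category(ann_path, background_name="background"):
    """Prepend an id-0 background category to a COCO annotation file
    whose categories start at id 1, as legacy exports often do."""
    with open(ann_path) as f:
        coco = json.load(f)
    ids = {c["id"] for c in coco.get("categories", [])}
    if 0 not in ids:
        coco["categories"].insert(0, {
            "id": 0,
            "name": background_name,
            "supercategory": "none",
        })
        with open(ann_path, "w") as f:
            json.dump(coco, f)
    return coco["categories"]
```

Run it on each split's `_annotations.coco.json` before training; files that already contain an id-0 category are left unchanged.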

This should probably be mentioned in the project README, to avoid problems with 'legacy' datasets not exported directly from Roboflow.

Bye! 🚀

@picjul picjul closed this as completed Mar 26, 2025
@isaacrob-roboflow (Collaborator)

@SkalskiP do we want to natively support this format? Alternatively, we could build some internal check? Otherwise we should update the docs to show that this is happening.


picjul commented Mar 26, 2025

IMHO it is not necessary to handle it as a special case. I have tried to train several networks with single-class COCO datasets and always hit the same anomalies.
The codebase is probably not 'optimised' to handle these cases; it would be sufficient to update the documentation.

@SkalskiP (Collaborator)

Hi @picjul Thanks for your interest in RF-DETR and for diving deep and helping to understand the source of the problem.

I'm concerned because the one-class fine-tuning error is one of the most frequently reported so far and I think we need to try to solve it somehow. The question is how could we automatically verify this? @picjul, is it a correct assumption that class IDs always start from 1 and we always discard the class with ID 0? If so, the fix seems simple.
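One possible shape for such an automatic check (a sketch only, not RF-DETR's actual loader; the function name is a placeholder): warn when the annotation file has no id-0 category, which is what distinguishes legacy exports from Roboflow's.

```python
import json

def has_background_category(ann_path):
    """Return True if the COCO file defines an id-0 category,
    False (with a warning) otherwise.  A real loader might remap
    ids instead of just warning."""
    with open(ann_path) as f:
        categories = json.load(f).get("categories", [])
    ids = sorted(c["id"] for c in categories)
    if 0 in ids:
        return True
    print(f"warning: {ann_path} has category ids {ids}; "
          "expected an id-0 background class as in Roboflow exports")
    return False
```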


picjul commented Mar 27, 2025

The original COCO dataset annotations begin with index 1, so this is probably the correct approach. I don't recall finding any contraindications to starting with 0. I think there were some sources (I've forgotten which, sorry 😅) that mentioned class 0 as the background class.
By default, from what I could see, if a class with no instances is defined in the dataset, it does not contribute to the training/validation metrics.

It would be better to handle the problem in terms of indexes. On what basis is the size of the array that produces the out-of-bounds indexes defined?
(Forgive me, but I cannot spare the time to explore the code in depth 😭)

Maybe using a single class with id = 0 can do the job? 🤔

[Writing from mobile phone, pardon errors and typos]


BartvanMarrewijk commented Mar 31, 2025

For people who need a quick fix: subtract 1 from the class IDs at line 82 in `ConvertCoco`:

        classes = torch.tensor(classes, dtype=torch.int64) - 1

It will then train, but evaluation during training does not work nicely; to fix that, shift the IDs back at line 99 in `coco_eval.py`:

            labels = (prediction["labels"] + 1).tolist()
