Created by Prarthana Bhattacharyya.
Disclaimer: This is not an official product and is meant to be a proof-of-concept and for academic/educational use only.
This repository contains the PyTorch implementation for the paper Visual Representation Learning with Self-Supervised Attention for Low-Label High-Data Regime, to be presented at ICASSP-2022.
Self-supervision has shown outstanding results for natural language processing, and more recently, for image recognition. Simultaneously, vision transformers and their variants have emerged as a promising and scalable alternative to convolutions on various computer vision tasks. In this paper, we are the first to question whether self-supervised vision transformers (SSL-ViTs) can be adapted to two important computer vision tasks in the low-label, high-data regime: few-shot image classification and zero-shot image retrieval. The motivation is to reduce the number of manual annotations required to train a visual embedder, and to produce generalizable, semantically meaningful and robust embeddings.
- SSL-ViT + few-shot image classification:
- Qualitative analysis for base-classes chosen by supervised CNN and SSL-ViT for few-shot distribution calibration:
- SSL-ViT + zero-shot image retrieval:
- Run DINO with a ViT-small backbone on a single node with 4 GPUs for 100 epochs using the following command:
```bash
cd dino/
python -m torch.distributed.launch --nproc_per_node=4 main_dino.py --arch vit_small --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
```
- For mini-ImageNet pretraining, we use the classes listed in `ssl-vit-fewshot/data/ImageNetSSLTrainingSplit_mini.txt` (a subset-building sketch follows below).
- For tiered-ImageNet pretraining, we use the classes listed in `ssl-vit-fewshot/data/ImageNetSSLTrainingSplit_tiered.txt`.
- For CUB-200, Cars-196 and SOP, we use the pretrained model from:
```python
import torch
vits16 = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
```

Please follow the instructions in FRN for few-shot image classification and in RevisitDML for zero-shot image retrieval to download the datasets, and place them in the `ssl-vit-fewshot/data` and `DIML/data` folders respectively.
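A minimal sketch of how the class-split file might be used to carve out the pretraining subset passed as `--data_path`. This is an assumption rather than part of the repository: the split file is assumed to list one ImageNet class-folder name (WordNet ID) per line, and all paths below are placeholders.

```python
# Hypothetical helper (not in the repo): build a pretraining subset by symlinking
# only the classes listed in ImageNetSSLTrainingSplit_mini.txt into a new folder,
# which is then passed to main_dino.py as --data_path.
import os

imagenet_train = "/path/to/imagenet/train"        # placeholder: full ImageNet train dir
subset_dir = "/path/to/imagenet_mini_ssl/train"   # placeholder: directory used as --data_path
split_file = "ssl-vit-fewshot/data/ImageNetSSLTrainingSplit_mini.txt"

os.makedirs(subset_dir, exist_ok=True)
with open(split_file) as f:
    classes = [line.strip() for line in f if line.strip()]  # assumed: one class folder name per line

for cls in classes:
    src = os.path.join(imagenet_train, cls)
    dst = os.path.join(subset_dir, cls)
    if os.path.isdir(src) and not os.path.exists(dst):
        os.symlink(src, dst)  # symlink instead of copying to save disk space
```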
- The first step is to extract features for base and novel classes using the pretrained SSL-ViT.
- `get_dino_miniimagenet_feats.ipynb` extracts SSL-ViT features for the base and novel classes (a minimal extraction sketch follows the table below).
- Change the hyper-parameter `data_path` to use CUB or tiered-ImageNet.
- The SSL-ViT checkpoints for the various datasets are provided below (note: these were trained without labels). We also provide the extracted features, which need to be stored in `ssl-vit-fewshot/dino_features_data/`.
| arch | dataset | download | extracted-train | extracted-test |
|---|---|---|---|---|
| ViT-S/16 | mini-ImageNet | mini_imagenet_checkpoint.pth | train.p | test.p |
| ViT-S/16 | tiered-ImageNet | tiered_imagenet_checkpoint.pth | train.p | test.p |
| ViT-S/16 | CUB | cub_checkpoint.pth | train.p | test.p |
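For reference, a minimal sketch of the kind of feature extraction `get_dino_miniimagenet_feats.ipynb` performs: a frozen DINO ViT-S/16 embeds every image and the features are pickled per class. The dataset path, transform, and `{class_index: [feature, ...]}` pickle layout are assumptions; the notebook remains the authoritative version.

```python
# Sketch only (assumptions: ImageFolder layout and per-class pickle format).
import pickle
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.hub.load('facebookresearch/dino:main', 'dino_vits16').to(device).eval()

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
dataset = datasets.ImageFolder("ssl-vit-fewshot/data/miniimagenet/train", preprocess)  # path is a placeholder
loader = DataLoader(dataset, batch_size=64, num_workers=4)

features = {}  # assumed layout: {class_index: [384-d feature, ...]}
with torch.no_grad():
    for images, labels in loader:
        feats = model(images.to(device)).cpu()  # ViT-S/16 CLS embeddings, shape [B, 384]
        for f, y in zip(feats, labels.tolist()):
            features.setdefault(y, []).append(f.numpy())

with open("ssl-vit-fewshot/dino_features_data/train.p", "wb") as f:
    pickle.dump(features, f)
```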
- For n-way-k-shot evaluation, we provide `miniimagenet_evaluate_dinoDC.ipynb` (a simplified episode sketch follows below).
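For intuition only, here is a simplified n-way-k-shot episode on the pickled test features. The notebook evaluates with distribution calibration; the snippet below is just a nearest-centroid baseline and assumes the `{class_index: [feature, ...]}` pickle layout described above.

```python
# Nearest-centroid episode sketch (not the distribution-calibration pipeline).
import pickle
import random
import numpy as np

n_way, k_shot, n_query = 5, 1, 15

with open("ssl-vit-fewshot/dino_features_data/test.p", "rb") as f:
    feats = pickle.load(f)  # assumed layout: {class_index: [feature, ...]}

classes = random.sample(list(feats.keys()), n_way)
prototypes, queries, query_labels = [], [], []
for i, c in enumerate(classes):
    samples = random.sample(list(feats[c]), k_shot + n_query)
    prototypes.append(np.mean(samples[:k_shot], axis=0))  # class prototype from the support set
    queries.extend(samples[k_shot:])
    query_labels.extend([i] * n_query)

prototypes = np.stack(prototypes)   # [n_way, dim]
queries = np.stack(queries)         # [n_way * n_query, dim]
dists = np.linalg.norm(queries[:, None, :] - prototypes[None, :, :], axis=-1)
acc = (dists.argmin(axis=1) == np.array(query_labels)).mean()
print(f"{n_way}-way {k_shot}-shot episode accuracy: {acc:.3f}")
```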
- To train the baseline CNN models, run the scripts in `DIML/scripts/baselines`. The checkpoints are saved in the `Training_Results` folder. For example:
```bash
cd DIML/
CUDA_VISIBLE_DEVICES=0 ./scripts/baselines/cub_runs.sh
```
- To train the supervised ViT and the self-supervised ViT:
```bash
cp -r ssl-vit-retrieval/architectures/* DIML/ssl-vit-retrieval/architectures/
CUDA_VISIBLE_DEVICES=0 ./scripts/baselines/cub_runs.sh --arch vits
CUDA_VISIBLE_DEVICES=0 ./scripts/baselines/cub_runs.sh --arch dino
```
- To test the models, first edit the checkpoint paths in `test_diml.py`, then run the following (a generic Recall@1 sketch follows the checkpoint table below):
```bash
CUDA_VISIBLE_DEVICES=0 ./scripts/diml/test_diml.sh cub200
```

| dataset | Loss | SSL-ViT-download |
|---|---|---|
| CUB | Margin | cub_ssl-vit-margin.pth |
| CUB | Proxy-NCA | cub_ssl-vit-proxynca.pth |
| CUB | Multi-Similarity | cub_ssl-vit-ms.pth |
| Cars-196 | Margin | cars_ssl-vit-margin.pth |
| Cars-196 | Proxy-NCA | cars_ssl-vit-proxynca.pth |
| Cars-196 | Multi-Similarity | cars_ssl-vit-ms.pth |
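For reference, a generic Recall@1 computation for zero-shot retrieval. `test_diml.sh` is the actual evaluation entry point; this standalone sketch only illustrates the metric on L2-normalized embeddings, and the random tensors are placeholders for real model outputs.

```python
# Standalone Recall@1 sketch (illustrative; not DIML's evaluation code).
import torch
import torch.nn.functional as F

def recall_at_1(embeddings: torch.Tensor, labels: torch.Tensor) -> float:
    """embeddings: [N, D] L2-normalized embeddings; labels: [N] class ids."""
    sims = embeddings @ embeddings.t()      # cosine similarities
    sims.fill_diagonal_(-float("inf"))      # exclude self-retrieval
    nearest = sims.argmax(dim=1)            # nearest neighbor for each query
    return (labels[nearest] == labels).float().mean().item()

# Placeholder data to show the call signature; real embeddings come from the trained model.
emb = F.normalize(torch.randn(100, 384), dim=1)
lab = torch.randint(0, 10, (100,))
print(f"Recall@1: {recall_at_1(emb, lab):.3f}")
```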
The code is based on:
- DINO
- FRN
- DIML
- RevisitDML