Skip to content

Commit 72d4799

Browse files
committed
Add more data augmentation options for Unsupervised Domain Adaptation for Image Classification
1 parent a1fcfb1 commit 72d4799

File tree

13 files changed

+126
-72
lines changed

13 files changed

+126
-72
lines changed

examples/domain_adaptation/image_classification/README.md

+20-18
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
# Unsupervised Domain Adaptation for Image Classification
22

33
## Installation
4+
45
It’s suggested to use **pytorch==1.7.1** and torchvision==0.8.2 in order to reproduce the benchmark results.
56

6-
Example scripts support all models in [PyTorch-Image-Models](https://github.com/rwightman/pytorch-image-models).
7-
You also need to install timm to use PyTorch-Image-Models.
7+
Example scripts support all models in [PyTorch-Image-Models](https://github.com/rwightman/pytorch-image-models). You
8+
also need to install timm to use PyTorch-Image-Models.
89

910
```
1011
pip install timm
@@ -14,19 +15,22 @@ pip install timm
1415

1516
Following datasets can be downloaded automatically:
1617

17-
- [MNIST](http://yann.lecun.com/exdb/mnist/), [SVHN](http://ufldl.stanford.edu/housenumbers/), [USPS](https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps)
18+
- [MNIST](http://yann.lecun.com/exdb/mnist/), [SVHN](http://ufldl.stanford.edu/housenumbers/)
19+
, [USPS](https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps)
1820
- [Office31](https://www.cc.gatech.edu/~judy/domainadapt/)
1921
- [OfficeCaltech](https://www.cc.gatech.edu/~judy/domainadapt/)
2022
- [OfficeHome](https://www.hemanthdv.org/officeHomeDataset.html)
2123
- [VisDA2017](http://ai.bu.edu/visda-2017/)
2224
- [DomainNet](http://ai.bu.edu/M3SDA/)
2325

2426
You need to prepare following datasets manually if you want to use them:
27+
2528
- [ImageNet](https://www.image-net.org/)
2629
- [ImageNetR](https://github.com/hendrycks/imagenet-r)
2730
- [ImageNet-Sketch](https://github.com/HaohanWang/ImageNet-Sketch)
2831

29-
and prepare them following [Documentation for ImageNetR](/common/vision/datasets/imagenet_r.py) and [ImageNet-Sketch](/common/vision/datasets/imagenet_sketch.py).
32+
and prepare them following [Documentation for ImageNetR](/common/vision/datasets/imagenet_r.py)
33+
and [ImageNet-Sketch](/common/vision/datasets/imagenet_sketch.py).
3034

3135
## Supported Methods
3236

@@ -45,8 +49,8 @@ Supported methods include:
4549

4650
## Usage
4751

48-
The shell files give the script to reproduce the benchmark with specified hyper-parameters.
49-
For example, if you want to train DANN on Office31, use the following script
52+
The shell files give the script to reproduce the benchmark with specified hyper-parameters. For example, if you want to
53+
train DANN on Office31, use the following script
5054

5155
```shell script
5256
# Train a DANN on Office-31 Amazon -> Webcam task using ResNet 50.
@@ -55,16 +59,17 @@ For example, if you want to train DANN on Office31, use the following script
5559
CUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 1 --log logs/dann/Office31_A2W
5660
```
5761

58-
Note that ``-s`` specifies the source domain, ``-t`` specifies the target domain, and ``--log`` specifies where to store results.
62+
Note that ``-s`` specifies the source domain, ``-t`` specifies the target domain, and ``--log`` specifies where to store
63+
results.
5964

60-
After running the above command, it will download ``Office-31`` datasets from the Internet if it's the first time you run the code. Directory that stores datasets will be named as
65+
After running the above command, it will download ``Office-31`` datasets from the Internet if it's the first time you
66+
run the code. Directory that stores datasets will be named as
6167
``examples/domain_adaptation/image_classification/data/<dataset name>``.
6268

6369
If everything works fine, you will see results in following format::
6470

6571
Epoch: [1][ 900/1000] Time 0.60 ( 0.69) Data 0.22 ( 0.31) Loss 0.74 ( 0.85) Cls Acc 96.9 (95.1) Domain Acc 64.1 (62.6)
6672

67-
6873
You can also watch these results in the log file ``logs/dann/Office31_A2W/log.txt``.
6974

7075
After training, you can test your algorithm's performance by passing in ``--phase test``.
@@ -73,21 +78,19 @@ After training, you can test your algorithm's performance by passing in ``--phas
7378
CUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 1 --log logs/dann/Office31_A2W --phase test
7479
```
7580

76-
7781
## Experiment and Results
7882

7983
**Notations**
84+
8085
- ``Origin`` means the accuracy reported by the original paper.
8186
- ``Avg`` is the accuracy reported by `TLlib`.
8287
- ``ERM`` refers to the model trained with data from the source domain.
8388
- ``Oracle`` refers to the model trained with data from the target domain.
8489

85-
86-
We found that the accuracies of adversarial methods (including DANN, ADDA, CDAN, MCD, BSP and MDD) are not stable
87-
even after the random seed is fixed, thus we repeat running adversarial methods on *Office-31* and *VisDA-2017*
90+
We found that the accuracies of adversarial methods (including DANN, ADDA, CDAN, MCD, BSP and MDD) are not stable even
91+
after the random seed is fixed, thus we repeat running adversarial methods on *Office-31* and *VisDA-2017*
8892
for three times and report their average accuracy.
8993

90-
9194
### Office-31 accuracy on ResNet-50
9295

9396
| Methods | Origin | Avg | A → W | D → W | W → D | A → D | D → A | W → A |
@@ -162,8 +165,8 @@ for three times and report their average accuracy.
162165
| MDD | 42.9 | 59.5 | 47.5 | 48.6 | 59.4 | 42.6 | 58.3 | 53.7 | 46.2 | 58.7 | 46.5 | 57.7 | 51.8 |
163166
| MCC | 37.7 | 55.7 | 42.6 | 45.4 | 59.8 | 39.9 | 54.4 | 53.1 | 37.0 | 58.1 | 46.3 | 56.2 | 48.9 |
164167

165-
166168
### DomainNet accuracy on ResNet-101 (Multi-Source)
169+
167170
| Methods | Origin | Avg | :c | :i | :p | :q | :r | :s |
168171
|-------------|--------|------|------|------|------|------|------|------|
169172
| ERM | 32.9 | 47.0 | 64.9 | 25.2 | 54.4 | 16.9 | 68.2 | 52.3 |
@@ -185,7 +188,6 @@ for three times and report their average accuracy.
185188

186189
## Visualization
187190

188-
189191
After training `DANN`, run the following command
190192

191193
```
@@ -200,15 +202,15 @@ Following are the t-SNE of representations from ResNet50 trained on source domai
200202
<img src="./fig/resnet_A2W.png" width="300"/>
201203
<img src="./fig/dann_A2W.png" width="300"/>
202204

203-
204205
## TODO
206+
205207
1. Support self-training methods
206208
2. Support translation methods
207209
3. Add results on ViT
208210
4. Add results on ImageNet
209-
5. Add more data augmentation options
210211

211212
## Citation
213+
212214
If you use these methods in your research, please consider citing.
213215

214216
```

examples/domain_adaptation/image_classification/adda.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ def main(args: argparse.Namespace):
5858
cudnn.benchmark = True
5959

6060
# Data loading code
61-
train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=not args.no_hflip,
61+
train_transform = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,
62+
random_horizontal_flip=not args.no_hflip,
6263
random_color_jitter=False, resize_size=args.resize_size,
6364
norm_mean=args.norm_mean, norm_std=args.norm_std)
6465
val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
@@ -100,7 +101,8 @@ def main(args: argparse.Namespace):
100101
for epoch in range(args.pretrain_epochs):
101102
print("lr:", pretrain_lr_scheduler.get_lr())
102103
# pretrain for one epoch
103-
utils.empirical_risk_minimization(train_source_iter, pretrain_model, pretrain_optimizer, pretrain_lr_scheduler, epoch, args,
104+
utils.empirical_risk_minimization(train_source_iter, pretrain_model, pretrain_optimizer,
105+
pretrain_lr_scheduler, epoch, args,
104106
device)
105107
# validate to show pretrain process
106108
utils.validate(val_loader, pretrain_model, args, device)
@@ -244,6 +246,10 @@ def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverData
244246
parser.add_argument('--val-resizing', type=str, default='default')
245247
parser.add_argument('--resize-size', type=int, default=224,
246248
help='the image size after resizing')
249+
parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
250+
help='Random resize scale (default: 0.08 1.0)')
251+
parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
252+
help='Random resize aspect ratio (default: 0.75 1.33)')
247253
parser.add_argument('--no-hflip', action='store_true',
248254
help='no random horizontal flipping during training')
249255
parser.add_argument('--norm-mean', type=float, nargs='+',

examples/domain_adaptation/image_classification/bsp.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
from tllib.utils.logger import CompleteLogger
2929
from tllib.utils.analysis import collect_feature, tsne, a_distance
3030

31-
3231
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
3332

3433

@@ -49,7 +48,8 @@ def main(args: argparse.Namespace):
4948
cudnn.benchmark = True
5049

5150
# Data loading code
52-
train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=not args.no_hflip,
51+
train_transform = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,
52+
random_horizontal_flip=not args.no_hflip,
5353
random_color_jitter=False, resize_size=args.resize_size,
5454
norm_mean=args.norm_mean, norm_std=args.norm_std)
5555
val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
@@ -126,7 +126,8 @@ def main(args: argparse.Namespace):
126126
for epoch in range(args.pretrain_epochs):
127127
print("lr:", pretrain_lr_scheduler.get_lr())
128128
# pretrain for one epoch
129-
utils.empirical_risk_minimization(train_source_iter, pretrain_model, pretrain_optimizer, pretrain_lr_scheduler, epoch, args,
129+
utils.empirical_risk_minimization(train_source_iter, pretrain_model, pretrain_optimizer,
130+
pretrain_lr_scheduler, epoch, args,
130131
device)
131132
# validate to show pretrain process
132133
utils.validate(val_loader, pretrain_model, args, device)
@@ -237,6 +238,10 @@ def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverData
237238
parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')
238239
parser.add_argument('--train-resizing', type=str, default='default')
239240
parser.add_argument('--val-resizing', type=str, default='default')
241+
parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
242+
help='Random resize scale (default: 0.08 1.0)')
243+
parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
244+
help='Random resize aspect ratio (default: 0.75 1.33)')
240245
parser.add_argument('--resize-size', type=int, default=224,
241246
help='the image size after resizing')
242247
parser.add_argument('--no-hflip', action='store_true',

examples/domain_adaptation/image_classification/cdan.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from tllib.utils.logger import CompleteLogger
2828
from tllib.utils.analysis import collect_feature, tsne, a_distance
2929

30-
3130
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
3231

3332

@@ -48,7 +47,8 @@ def main(args: argparse.Namespace):
4847
cudnn.benchmark = True
4948

5049
# Data loading code
51-
train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=not args.no_hflip,
50+
train_transform = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,
51+
random_horizontal_flip=not args.no_hflip,
5252
random_color_jitter=False, resize_size=args.resize_size,
5353
norm_mean=args.norm_mean, norm_std=args.norm_std)
5454
val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
@@ -221,6 +221,10 @@ def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverData
221221
parser.add_argument('--val-resizing', type=str, default='default')
222222
parser.add_argument('--resize-size', type=int, default=224,
223223
help='the image size after resizing')
224+
parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
225+
help='Random resize scale (default: 0.08 1.0)')
226+
parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
227+
help='Random resize aspect ratio (default: 0.75 1.33)')
224228
parser.add_argument('--no-hflip', action='store_true',
225229
help='no random horizontal flipping during training')
226230
parser.add_argument('--norm-mean', type=float, nargs='+',

examples/domain_adaptation/image_classification/dan.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from tllib.utils.logger import CompleteLogger
2828
from tllib.utils.analysis import collect_feature, tsne, a_distance
2929

30-
3130
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
3231

3332

@@ -48,7 +47,8 @@ def main(args: argparse.Namespace):
4847
cudnn.benchmark = True
4948

5049
# Data loading code
51-
train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=not args.no_hflip,
50+
train_transform = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,
51+
random_horizontal_flip=not args.no_hflip,
5252
random_color_jitter=False, resize_size=args.resize_size,
5353
norm_mean=args.norm_mean, norm_std=args.norm_std)
5454
val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
@@ -77,7 +77,7 @@ def main(args: argparse.Namespace):
7777

7878
# define optimizer and lr scheduler
7979
optimizer = SGD(classifier.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True)
80-
lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
80+
lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
8181

8282
# define loss function
8383
mkmmd_loss = MultipleKernelMaximumMeanDiscrepancy(
@@ -207,6 +207,10 @@ def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverData
207207
parser.add_argument('--val-resizing', type=str, default='default')
208208
parser.add_argument('--resize-size', type=int, default=224,
209209
help='the image size after resizing')
210+
parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
211+
help='Random resize scale (default: 0.08 1.0)')
212+
parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
213+
help='Random resize aspect ratio (default: 0.75 1.33)')
210214
parser.add_argument('--no-hflip', action='store_true',
211215
help='no random horizontal flipping during training')
212216
parser.add_argument('--norm-mean', type=float, nargs='+',
@@ -259,4 +263,3 @@ def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverData
259263
"When phase is 'analysis', only analysis the model.")
260264
args = parser.parse_args()
261265
main(args)
262-

examples/domain_adaptation/image_classification/dann.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from tllib.utils.logger import CompleteLogger
2828
from tllib.utils.analysis import collect_feature, tsne, a_distance
2929

30-
3130
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
3231

3332

@@ -48,7 +47,8 @@ def main(args: argparse.Namespace):
4847
cudnn.benchmark = True
4948

5049
# Data loading code
51-
train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=not args.no_hflip,
50+
train_transform = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,
51+
random_horizontal_flip=not args.no_hflip,
5252
random_color_jitter=False, resize_size=args.resize_size,
5353
norm_mean=args.norm_mean, norm_std=args.norm_std)
5454
val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
@@ -79,7 +79,7 @@ def main(args: argparse.Namespace):
7979
# define optimizer and lr scheduler
8080
optimizer = SGD(classifier.get_parameters() + domain_discri.get_parameters(),
8181
args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
82-
lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
82+
lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
8383

8484
# define loss function
8585
domain_adv = DomainAdversarialLoss(domain_discri).to(device)
@@ -210,6 +210,10 @@ def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverData
210210
parser.add_argument('--val-resizing', type=str, default='default')
211211
parser.add_argument('--resize-size', type=int, default=224,
212212
help='the image size after resizing')
213+
parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
214+
help='Random resize scale (default: 0.08 1.0)')
215+
parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
216+
help='Random resize aspect ratio (default: 0.75 1.33)')
213217
parser.add_argument('--no-hflip', action='store_true',
214218
help='no random horizontal flipping during training')
215219
parser.add_argument('--norm-mean', type=float, nargs='+',
@@ -239,7 +243,7 @@ def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverData
239243
parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
240244
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
241245
help='momentum')
242-
parser.add_argument('--wd', '--weight-decay',default=1e-3, type=float,
246+
parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,
243247
metavar='W', help='weight decay (default: 1e-3)',
244248
dest='weight_decay')
245249
parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
@@ -261,4 +265,3 @@ def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverData
261265
"When phase is 'analysis', only analysis the model.")
262266
args = parser.parse_args()
263267
main(args)
264-

examples/domain_adaptation/image_classification/erm.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from tllib.utils.analysis import collect_feature, tsne, a_distance
2222
from tllib.utils.data import ForeverDataIterator
2323

24-
2524
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
2625

2726

@@ -42,7 +41,8 @@ def main(args):
4241
cudnn.benchmark = True
4342

4443
# Data loading code
45-
train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=not args.no_hflip,
44+
train_transform = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,
45+
random_horizontal_flip=not args.no_hflip,
4646
random_color_jitter=False, resize_size=args.resize_size,
4747
norm_mean=args.norm_mean, norm_std=args.norm_std)
4848
val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
@@ -136,6 +136,10 @@ def main(args):
136136
parser.add_argument('--val-resizing', type=str, default='default')
137137
parser.add_argument('--resize-size', type=int, default=224,
138138
help='the image size after resizing')
139+
parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
140+
help='Random resize scale (default: 0.08 1.0)')
141+
parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
142+
help='Random resize aspect ratio (default: 0.75 1.33)')
139143
parser.add_argument('--no-hflip', action='store_true',
140144
help='no random horizontal flipping during training')
141145
parser.add_argument('--norm-mean', type=float, nargs='+',

0 commit comments

Comments
 (0)