Commit 5909da5

Sven Gowal authored and committed
Added jaxline pipeline to train adversarially robust models.
PiperOrigin-RevId: 383399487
1 parent d8df415 commit 5909da5

36 files changed: +5229 -79 lines changed

adversarial_robustness/README.md

+70-8
@@ -13,7 +13,7 @@ We have released our top-performing models in two formats compatible with
 [JAX](https://github.com/google/jax) and [PyTorch](https://pytorch.org/).
 This repository also contains our model definitions.
 
-## Running the example code
+## Running the code
 
 ### Downloading a model
 
@@ -47,18 +47,80 @@ The following table contains the models from **Rebuffi et al., 2021**.
 | CIFAR-100 | &#8467;<sub>&infin;</sub> | 8 / 255 | WRN-70-16 | &#x2717; | 63.56% | 34.64% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_cutmix_ddpm.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_cutmix_ddpm.pt)
 | CIFAR-100 | &#8467;<sub>&infin;</sub> | 8 / 255 | WRN-28-10 | &#x2717; | 62.41% | 32.06% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn28-10_cutmix_ddpm.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn28-10_cutmix_ddpm.pt)
 
-### Using the model
+### Installing
 
-Once downloaded, a model can be evaluated (clean accuracy) by running the
-`eval.py` script in either the `jax` or `pytorch` folders. E.g.:
+The following has been tested using Python 3.9.2.
+Using `run.sh` will create and activate a virtualenv, install all necessary
+dependencies and run a test program to ensure that you can import all the
+modules.
+
+```
+# Run from the parent directory.
+sh adversarial_robustness/run.sh
+```
+
+To run the provided code, use this virtualenv:
+
+```
+source /tmp/adversarial_robustness_venv/bin/activate
+```
+
+You may want to edit `requirements.txt` before running `run.sh` if GPU support
+is needed (e.g., use `jaxlib==0.1.67+cuda111`). See JAX's installation
+[instructions](https://github.com/google/jax#installation) for more details.
+
+### Using pre-trained models
+
+Once downloaded, a model can be evaluated by running the `eval.py` script in
+either the `jax` or `pytorch` folders. E.g.:
 
 ```
 cd jax
 python3 eval.py \
 --ckpt=${PATH_TO_CHECKPOINT} --depth=70 --width=16 --dataset=cifar10
 ```
 
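The released `jax` checkpoints in the table above are `.npy` files. A minimal sketch of round-tripping such a file, under the assumption that it stores a pickled dict of named parameter arrays (the dict here is a synthetic stand-in, not a real download):

```
import numpy as np

# Hypothetical stand-in for a downloaded checkpoint such as
# cifar10_linf_wrn70-16_cutmix_ddpm.npy; assumption: the .npy file holds
# a pickled dict mapping parameter names to arrays.
params = {
    'conv_0/w': np.zeros((3, 3, 3, 16), dtype=np.float32),
    'logits/b': np.zeros((10,), dtype=np.float32),
}
np.save('/tmp/ckpt.npy', params, allow_pickle=True)

# np.save wraps the dict in a 0-d object array; .item() unwraps it.
restored = np.load('/tmp/ckpt.npy', allow_pickle=True).item()
print(sorted(restored))            # parameter names
print(restored['logits/b'].shape)  # (10,)
```

Pickled object arrays require `allow_pickle=True` on both save and load; only trust this flag with files from sources you trust.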
-## Generated datasets
+These models are also directly available within
+[RobustBench](https://github.com/RobustBench/robustbench#model-zoo-quick-tour)'s
+model zoo.
+
+### Training your own model
+
+We also provide a training pipeline that reproduces results from both
+publications. This pipeline uses [Jaxline](https://github.com/deepmind/jaxline)
+and is written using [JAX](https://github.com/google/jax) and
+[Haiku](https://github.com/deepmind/dm-haiku). To train a model, modify the
+configuration in the `get_config()` function of `jax/experiment.py` and issue
+the following command from within the virtualenv created above:
+
+```
+cd jax
+python3 train.py --config=experiment.py
+```
+
+The training pipeline can run with multiple worker machines and multiple devices
+(either GPU or TPU). See [Jaxline](https://github.com/deepmind/jaxline) for more
+details.
+
+We do not provide a PyTorch implementation of our training pipeline. However,
+you may find one on GitHub, e.g.,
+[adversarial_robustness_pytorch](https://github.com/imrahulr/adversarial_robustness_pytorch)
+(by Rahul Rade).
+
+## Datasets
+
+### Extracted dataset
+
+Gowal et al. (2020) use samples extracted from
+[TinyImages-80M](https://groups.csail.mit.edu/vision/TinyImages/).
+Unfortunately, since then, the official TinyImages-80M dataset has been
+withdrawn (due to the presence of offensive images). As such, we cannot provide
+a download link to our extracted data until we have manually verified that all
+extracted images are not offensive. If you want to reproduce our setup, consider
+the generated datasets below. We are also happy to help, so feel free to reach
+out to Sven Gowal directly.
+
+### Generated datasets
 
 Rebuffi et al. (2021) use samples generated by a Denoising Diffusion
 Probabilistic Model [(DDPM; Ho et al., 2020)](https://arxiv.org/abs/2006.11239)
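The context of the next hunk shows the README's loading pattern for these generated datasets (`labels = npzfile['label']`). That pattern can be exercised end-to-end with a synthetic archive; the filename, array shapes, and dtypes below are illustrative assumptions:

```
import numpy as np

# Hypothetical stand-in for a generated dataset archive; assumption: the
# .npz stores uint8 images under 'image' and integer labels under 'label'.
images = np.random.randint(0, 256, size=(8, 32, 32, 3), dtype=np.uint8)
labels = np.random.randint(0, 10, size=(8,), dtype=np.int64)
np.savez('/tmp/ddpm_sample.npz', image=images, label=labels)

# Loading mirrors the README snippet: np.load on an .npz returns a
# dict-like NpzFile keyed by the names passed to np.savez.
npzfile = np.load('/tmp/ddpm_sample.npz')
print(npzfile['image'].shape)  # (8, 32, 32, 3)
print(npzfile['label'].shape)  # (8,)
```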
@@ -82,8 +144,8 @@ labels = npzfile['label']
 
 ## Citing this work
 
-If you use this code, data or these models in your work, please cite the
-relevant accompanying paper:
+If you use this code (or any derived code), data or these models in your work,
+please cite the relevant accompanying paper:
 
 ```
 @article{gowal2020uncovering,
@@ -95,7 +157,7 @@ relevant accompanying paper:
 }
 ```
 
-or
+and/or
 
 ```
 @article{rebuffi2021fixing,
