Commit 8457046

altchediegolascasas authored and committed
Add checkpoints from the ablation study.
PiperOrigin-RevId: 328023346
1 parent 22c3daf commit 8457046

33 files changed (+397 −363 lines)

byol/README.md (+34)
```diff
@@ -176,3 +176,37 @@ python -m byol.main_loop \
 With these settings, BYOL should achieve ~92.3% top-1 accuracy (for the
 *online* classifier) in roughly 4 hours. Note that the above parameters were not
 finely tuned and may not be optimal.
+
+
+## Additional checkpoints
+
+Alongside the pretrained ResNet-50 and ResNet-200 2x, we provide the
+following checkpoints from our ablation study. They all correspond to a
+ResNet-50 1x pre-trained over 300 epochs and were randomly selected from the
+three seeds; each file is roughly 640MB.
+
+- [Baseline](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_baseline.pkl)
+
+- Smaller batch sizes (figure 3a):
+    - [Batch size 2048](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_batchsize_2048.pkl)
+    - [Batch size 1024](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_batchsize_1024.pkl)
+    - [Batch size 512](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_batchsize_512.pkl)
+    - [Batch size 256](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_batchsize_256.pkl)
+    - [Batch size 128](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_batchsize_128.pkl)
+    - [Batch size 64](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_batchsize_64.pkl)
+
+- Ablation on transformations (figure 3b):
+    - [Remove grayscale](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_no_grayscale.pkl)
+    - [Remove color](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_no_color.pkl)
+    - [Crop and blur only](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_crop_and_blur_only.pkl)
+    - [Crop only](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_crop_only.pkl)
+    - [Crop and color only (from Table 18)](https://storage.googleapis.com/deepmind-byol/checkpoints/ablations/res50x1_crop_and_color_only.pkl)
+
+
+## License
+
+While the code is licensed under the Apache 2.0 License, the checkpoint weights
+are made available for non-commercial use only under the terms of the
+Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)
+license. You can find details at:
+https://creativecommons.org/licenses/by-nc/4.0/legalcode.
```
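To try one of these ablation checkpoints locally, it first has to be fetched from the bucket. Below is a minimal sketch of downloading and inspecting the baseline file; it assumes the `.pkl` files are plain pickled Python objects (the key layout is not documented in this commit, so the script only prints what it finds; if the files were written with `dill` rather than `pickle`, swap in `dill.load`).

```python
# Hypothetical download-and-inspect helper; not part of this commit.
import pickle
import urllib.request

URL = ('https://storage.googleapis.com/deepmind-byol/checkpoints/'
       'ablations/res50x1_baseline.pkl')

path, _ = urllib.request.urlretrieve(URL, 'res50x1_baseline.pkl')
with open(path, 'rb') as f:
  checkpoint = pickle.load(f)  # May need dill.load if pickle fails here.

# Assumption: the file unpickles to a dict; otherwise report its type.
if isinstance(checkpoint, dict):
  print(sorted(checkpoint))
else:
  print(type(checkpoint))
```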

byol/byol_experiment.py (+47 −47)
```diff
@@ -56,17 +56,17 @@ class ByolExperiment:
 
   def __init__(
       self,
-      random_seed,
-      num_classes,
-      batch_size,
-      max_steps,
-      enable_double_transpose,
-      base_target_ema,
-      network_config,
-      optimizer_config,
-      lr_schedule_config,
-      evaluation_config,
-      checkpointing_config):
+      random_seed: int,
+      num_classes: int,
+      batch_size: int,
+      max_steps: int,
+      enable_double_transpose: bool,
+      base_target_ema: float,
+      network_config: Mapping[Text, Any],
+      optimizer_config: Mapping[Text, Any],
+      lr_schedule_config: Mapping[Text, Any],
+      evaluation_config: Mapping[Text, Any],
+      checkpointing_config: Mapping[Text, Any]):
     """Constructs the experiment.
 
     Args:
@@ -115,15 +115,15 @@ def __init__(
 
   def _forward(
       self,
-      inputs,
-      projector_hidden_size,
-      projector_output_size,
-      predictor_hidden_size,
-      encoder_class,
-      encoder_config,
-      bn_config,
-      is_training,
-  ):
+      inputs: dataset.Batch,
+      projector_hidden_size: int,
+      projector_output_size: int,
+      predictor_hidden_size: int,
+      encoder_class: Text,
+      encoder_config: Mapping[Text, Any],
+      bn_config: Mapping[Text, Any],
+      is_training: bool,
+  ) -> Mapping[Text, jnp.ndarray]:
     """Forward application of byol's architecture.
 
     Args:
@@ -163,7 +163,7 @@ def _forward(
     classifier = hk.Linear(
         output_size=self._num_classes, name='classifier')
 
-    def apply_once_fn(images, suffix = ''):
+    def apply_once_fn(images: jnp.ndarray, suffix: Text = ''):
       images = dataset.normalize_images(images)
 
       embedding = net(images, is_training=is_training)
@@ -186,7 +186,7 @@ def apply_once_fn(images, suffix = ''):
     else:
       return apply_once_fn(inputs['images'], '')
 
-  def _optimizer(self, learning_rate):
+  def _optimizer(self, learning_rate: float) -> optax.GradientTransformation:
     """Build optimizer from config."""
     return optimizers.lars(
         learning_rate,
```
```diff
@@ -196,13 +196,13 @@ def _optimizer(self, learning_rate):
 
   def loss_fn(
       self,
-      online_params,
-      target_params,
-      online_state,
-      target_state,
-      rng,
-      inputs,
-  ):
+      online_params: hk.Params,
+      target_params: hk.Params,
+      online_state: hk.State,
+      target_state: hk.State,
+      rng: jnp.ndarray,
+      inputs: dataset.Batch,
+  ) -> Tuple[jnp.ndarray, Tuple[Mapping[Text, hk.State], LogsDict]]:
     """Compute BYOL's loss function.
 
     Args:
```
```diff
@@ -292,11 +292,11 @@ def _should_transpose_images(self):
 
   def _update_fn(
       self,
-      byol_state,
-      global_step,
-      rng,
-      inputs,
-  ):
+      byol_state: _ByolExperimentState,
+      global_step: jnp.ndarray,
+      rng: jnp.ndarray,
+      inputs: dataset.Batch,
+  ) -> Tuple[_ByolExperimentState, LogsDict]:
     """Update online and target parameters.
 
     Args:
@@ -352,9 +352,9 @@ def _update_fn(
 
   def _make_initial_state(
       self,
-      rng,
-      dummy_input,
-  ):
+      rng: jnp.ndarray,
+      dummy_input: dataset.Batch,
+  ) -> _ByolExperimentState:
     """BYOL's _ByolExperimentState initialization.
 
     Args:
@@ -393,8 +393,8 @@ def _make_initial_state(
     )
 
   def step(self, *,
-           global_step,
-           rng):
+           global_step: jnp.ndarray,
+           rng: jnp.ndarray) -> Mapping[Text, np.ndarray]:
     """Performs a single training step."""
     if self._train_input is None:
       self._initialize_train()
@@ -410,11 +410,11 @@ def step(self, *,
 
     return helpers.get_first(scalars)
 
-  def save_checkpoint(self, step, rng):
+  def save_checkpoint(self, step: int, rng: jnp.ndarray):
     self._checkpointer.maybe_save_checkpoint(
         self._byol_state, step=step, rng=rng, is_final=step >= self._max_steps)
 
-  def load_checkpoint(self):
+  def load_checkpoint(self) -> Union[Tuple[int, jnp.ndarray], None]:
     checkpoint_data = self._checkpointer.maybe_load_checkpoint()
     if checkpoint_data is None:
       return None
@@ -444,7 +444,7 @@ def _initialize_train(self):
 
     self._byol_state = init_byol(rng=init_rng, dummy_input=inputs)
 
-  def _build_train_input(self):
+  def _build_train_input(self) -> Generator[dataset.Batch, None, None]:
     """Loads the (infinitely looping) dataset iterator."""
     num_devices = jax.device_count()
     global_batch_size = self._batch_size
@@ -463,10 +463,10 @@ def _build_train_input(self):
 
   def _eval_batch(
       self,
-      params,
-      state,
-      batch,
-  ):
+      params: hk.Params,
+      state: hk.State,
+      batch: dataset.Batch,
+  ) -> Mapping[Text, jnp.ndarray]:
     """Evaluates a batch.
 
     Args:
@@ -494,7 +494,7 @@ def _eval_batch(
         'top5_accuracy': top5_correct,
     }
 
-  def _eval_epoch(self, subset, batch_size):
+  def _eval_epoch(self, subset: Text, batch_size: int):
     """Evaluates an epoch."""
     num_samples = 0.
     summed_scalars = None
```
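The annotations added above reference names that must already be in scope in `byol_experiment.py`. The diff does not show the module's import block, but the signatures imply roughly the following; this is a sketch inferred from the types used, not the commit's actual imports, and `dataset.Batch`, `_ByolExperimentState`, and `LogsDict` are defined inside the repo itself.

```python
# Inferred from the annotated signatures above; not the commit's actual
# import block, which this diff does not show.
from typing import Any, Generator, Mapping, Text, Tuple, Union

import haiku as hk       # hk.Params, hk.State, hk.Linear
import jax               # jax.device_count()
import jax.numpy as jnp  # jnp.ndarray
import numpy as np       # np.ndarray in step()'s return type
import optax             # optax.GradientTransformation
```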

byol/configs/byol.py (+1 −1)
```diff
@@ -23,7 +23,7 @@
 _EMA_PRESETS = {40: 0.97, 100: 0.99, 300: 0.99, 1000: 0.996}
 
 
-def get_config(num_epochs, batch_size):
+def get_config(num_epochs: int, batch_size: int):
   """Return config object, containing all hyperparameters for training."""
   train_images_per_epoch = dataset.Split.TRAIN_AND_VALID.num_examples
 
```
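Call sites are unchanged by the annotation; a short usage sketch follows. The `_EMA_PRESETS` keys shown above suggest `num_epochs` is expected to be 40, 100, 300, or 1000, and treating the returned config as printable is an assumption, since the diff does not show the function body.

```python
# Sketch of calling the now-typed config getter.
from byol.configs import byol as byol_config

# _EMA_PRESETS suggests num_epochs in {40, 100, 300, 1000}; the ablation
# checkpoints listed in the README were pre-trained for 300 epochs.
config = byol_config.get_config(num_epochs=300, batch_size=4096)
print(config)  # Assumes a printable, dict-like config object.
```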

byol/configs/eval.py (+1 −1)
```diff
@@ -19,7 +19,7 @@
 from byol.utils import dataset
 
 
-def get_config(checkpoint_to_evaluate, batch_size):
+def get_config(checkpoint_to_evaluate: Text, batch_size: int):
   """Return config object for training."""
   train_images_per_epoch = dataset.Split.TRAIN_AND_VALID.num_examples
 
```
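This getter can point directly at one of the ablation checkpoints from the README once downloaded; a sketch is below, where the local path and batch size are placeholders, not values from the commit.

```python
# Sketch: build an eval config for a downloaded ablation checkpoint.
from byol.configs import eval as eval_config

config = eval_config.get_config(
    checkpoint_to_evaluate='/tmp/res50x1_baseline.pkl',  # placeholder path
    batch_size=1024)  # placeholder batch size
```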
