@@ -56,17 +56,17 @@ class ByolExperiment:
 
   def __init__(
       self,
-      random_seed,
-      num_classes,
-      batch_size,
-      max_steps,
-      enable_double_transpose,
-      base_target_ema,
-      network_config,
-      optimizer_config,
-      lr_schedule_config,
-      evaluation_config,
-      checkpointing_config):
+      random_seed: int,
+      num_classes: int,
+      batch_size: int,
+      max_steps: int,
+      enable_double_transpose: bool,
+      base_target_ema: float,
+      network_config: Mapping[Text, Any],
+      optimizer_config: Mapping[Text, Any],
+      lr_schedule_config: Mapping[Text, Any],
+      evaluation_config: Mapping[Text, Any],
+      checkpointing_config: Mapping[Text, Any]):
     """Constructs the experiment.
 
     Args:
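The annotations in this hunk assume several names are imported at the top of the file. A minimal import block that would satisfy them (the typing names are standard library; dataset, helpers, and optimizers are the repository's own modules, imported separately):

    from typing import Any, Generator, Mapping, Text, Tuple, Union

    import haiku as hk
    import jax
    import jax.numpy as jnp
    import numpy as np
    import optax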
@@ -115,15 +115,15 @@ def __init__(
 
   def _forward(
       self,
-      inputs,
-      projector_hidden_size,
-      projector_output_size,
-      predictor_hidden_size,
-      encoder_class,
-      encoder_config,
-      bn_config,
-      is_training,
-  ):
+      inputs: dataset.Batch,
+      projector_hidden_size: int,
+      projector_output_size: int,
+      predictor_hidden_size: int,
+      encoder_class: Text,
+      encoder_config: Mapping[Text, Any],
+      bn_config: Mapping[Text, Any],
+      is_training: bool,
+  ) -> Mapping[Text, jnp.ndarray]:
     """Forward application of byol's architecture.
 
     Args:
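Returning Mapping[Text, jnp.ndarray] makes _forward's functional contract explicit: Haiku later lifts the method into a pure (init, apply) pair. A minimal sketch of that pattern, not the repository's exact wiring (the functools.partial plumbing of the network config is an assumption):

    import functools

    # Lift the side-effectful module code into pure functions.
    forward = hk.without_apply_rng(hk.transform_with_state(
        functools.partial(self._forward, **network_config)))
    # init consumes an rng and a dummy batch to create params and state.
    params, state = forward.init(rng, dummy_input, is_training=True)
    # apply is then a pure function of params, state, and inputs.
    outputs, new_state = forward.apply(params, state, inputs, is_training=True)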
@@ -163,7 +163,7 @@ def _forward(
     classifier = hk.Linear(
         output_size=self._num_classes, name='classifier')
 
-    def apply_once_fn(images, suffix=''):
+    def apply_once_fn(images: jnp.ndarray, suffix: Text = ''):
       images = dataset.normalize_images(images)
 
       embedding = net(images, is_training=is_training)
@@ -186,7 +186,7 @@ def apply_once_fn(images, suffix=''):
     else:
       return apply_once_fn(inputs['images'], '')
 
-  def _optimizer(self, learning_rate):
+  def _optimizer(self, learning_rate: float) -> optax.GradientTransformation:
     """Build optimizer from config."""
     return optimizers.lars(
         learning_rate,
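The new optax.GradientTransformation return annotation documents that the LARS optimizer is just a pair of pure (init, update) functions. A sketch of how such a transformation is driven (variable names are illustrative):

    optimizer = self._optimizer(learning_rate)
    opt_state = optimizer.init(online_params)
    updates, opt_state = optimizer.update(grads, opt_state, online_params)
    online_params = optax.apply_updates(online_params, updates)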
@@ -196,13 +196,13 @@ def _optimizer(self, learning_rate):
 
   def loss_fn(
       self,
-      online_params,
-      target_params,
-      online_state,
-      target_state,
-      rng,
-      inputs,
-  ):
+      online_params: hk.Params,
+      target_params: hk.Params,
+      online_state: hk.State,
+      target_state: hk.State,
+      rng: jnp.ndarray,
+      inputs: dataset.Batch,
+  ) -> Tuple[jnp.ndarray, Tuple[Mapping[Text, hk.State], LogsDict]]:
     """Compute BYOL's loss function.
 
     Args:
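The nested return type is shaped for jax.grad with has_aux=True: the first element is the scalar loss to differentiate, and the auxiliary tuple carries the updated network states plus a dict of logs. Roughly, with the argument order annotated above:

    # Differentiates with respect to online_params (argument 0 by default).
    grad_fn = jax.grad(self.loss_fn, has_aux=True)
    grads, (new_states, logs) = grad_fn(
        online_params, target_params, online_state, target_state, rng, inputs)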
@@ -292,11 +292,11 @@ def _should_transpose_images(self):
 
   def _update_fn(
       self,
-      byol_state,
-      global_step,
-      rng,
-      inputs,
-  ):
+      byol_state: _ByolExperimentState,
+      global_step: jnp.ndarray,
+      rng: jnp.ndarray,
+      inputs: dataset.Batch,
+  ) -> Tuple[_ByolExperimentState, LogsDict]:
     """Update online and target parameters.
 
     Args:
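The target network is not trained by gradient descent: _update_fn moves it toward the online network with an exponential moving average whose rate follows the schedule derived from base_target_ema. A hedged sketch of that step (tau denotes the current schedule value):

    target_params = jax.tree_map(
        lambda target, online: tau * target + (1. - tau) * online,
        target_params, online_params)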
@@ -352,9 +352,9 @@ def _update_fn(
 
   def _make_initial_state(
       self,
-      rng,
-      dummy_input,
-  ):
+      rng: jnp.ndarray,
+      dummy_input: dataset.Batch,
+  ) -> _ByolExperimentState:
     """BYOL's _ByolExperimentState initialization.
 
     Args:
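The _ByolExperimentState named tuple referenced by these annotations bundles everything the update step threads through. A plausible definition consistent with this diff (the optimizer-state field and its exact type are assumptions, not read from the source):

    from typing import NamedTuple

    class _ByolExperimentState(NamedTuple):
      """Model parameters and optimization state for one training run."""
      online_params: hk.Params
      target_params: hk.Params
      online_state: hk.State
      target_state: hk.State
      opt_state: optax.OptState  # assumed; depends on the optimizer used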
@@ -393,8 +393,8 @@ def _make_initial_state(
     )
 
   def step(self, *,
-           global_step,
-           rng):
+           global_step: jnp.ndarray,
+           rng: jnp.ndarray) -> Mapping[Text, np.ndarray]:
     """Performs a single training step."""
     if self._train_input is None:
       self._initialize_train()
@@ -410,11 +410,11 @@ def step(self, *,
 
     return helpers.get_first(scalars)
 
-  def save_checkpoint(self, step, rng):
+  def save_checkpoint(self, step: int, rng: jnp.ndarray):
     self._checkpointer.maybe_save_checkpoint(
         self._byol_state, step=step, rng=rng, is_final=step >= self._max_steps)
 
-  def load_checkpoint(self):
+  def load_checkpoint(self) -> Union[Tuple[int, jnp.ndarray], None]:
     checkpoint_data = self._checkpointer.maybe_load_checkpoint()
     if checkpoint_data is None:
       return None
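step is annotated to return host-side np.ndarray scalars even though the update runs across devices: the pmapped update yields values replicated along a leading device axis, and helpers.get_first strips that axis by keeping device 0's copy. One plausible implementation of the helper:

    def get_first(xs):
      """Keeps the leading (device) replica of every leaf in the tree."""
      return jax.tree_map(lambda x: x[0], xs)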
@@ -444,7 +444,7 @@ def _initialize_train(self):
 
     self._byol_state = init_byol(rng=init_rng, dummy_input=inputs)
 
-  def _build_train_input(self):
+  def _build_train_input(self) -> Generator[dataset.Batch, None, None]:
     """Loads the (infinitely looping) dataset iterator."""
     num_devices = jax.device_count()
     global_batch_size = self._batch_size
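_build_train_input shards the global batch evenly across local devices. A sketch of the split these first lines set up (the error handling is illustrative):

    num_devices = jax.device_count()
    per_device_batch_size, ragged = divmod(global_batch_size, num_devices)
    if ragged:
      raise ValueError(
          f'Global batch size {global_batch_size} must be divisible by '
          f'the number of devices ({num_devices}).')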
@@ -463,10 +463,10 @@ def _build_train_input(self):
 
   def _eval_batch(
       self,
-      params,
-      state,
-      batch,
-  ):
+      params: hk.Params,
+      state: hk.State,
+      batch: dataset.Batch,
+  ) -> Mapping[Text, jnp.ndarray]:
     """Evaluates a batch.
 
     Args:
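_eval_batch produces accuracy scalars such as the top5_accuracy entry visible in the next hunk. One standard way to compute top-1 and top-5 correctness in JAX, assuming logits of shape [batch, num_classes] and integer labels of shape [batch]:

    top1_correct = jnp.argmax(logits, axis=-1) == labels
    _, top5_indices = jax.lax.top_k(logits, k=5)
    top5_correct = jnp.any(top5_indices == labels[:, None], axis=-1)
    # Mean over the batch gives the accuracy scalars.
    top1_accuracy = jnp.mean(top1_correct.astype(jnp.float32))
    top5_accuracy = jnp.mean(top5_correct.astype(jnp.float32))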
@@ -494,7 +494,7 @@ def _eval_batch(
         'top5_accuracy': top5_correct,
     }
 
-  def _eval_epoch(self, subset, batch_size):
+  def _eval_epoch(self, subset: Text, batch_size: int):
     """Evaluates an epoch."""
     num_samples = 0.
     summed_scalars = None
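_eval_epoch averages the per-batch scalars over the whole evaluation set by keeping a running sum and a sample count. Roughly, assuming each per-batch scalar is a sum over examples so that dividing by the total sample count yields a mean (the batch bookkeeping and the 'labels' key are illustrative):

    for batch in eval_batches:
      scalars = self._eval_batch(params, state, batch)
      num_samples += batch['labels'].shape[0]
      if summed_scalars is None:
        summed_scalars = scalars
      else:
        summed_scalars = jax.tree_map(jnp.add, summed_scalars, scalars)
    mean_scalars = jax.tree_map(lambda x: x / num_samples, summed_scalars)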