Commit f0c385b

remove aux dropout option from conformer and from init_model_fn signature for fastmri, vit and criteo

1 parent 341bf89

4 files changed: +20 -17 lines

algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py

Lines changed: 0 additions & 2 deletions
@@ -73,11 +73,9 @@ def init_model_fn(
       self,
       rng: spec.RandomState,
       dropout_rate: Optional[float] = None,
-      aux_dropout_rate: Optional[float] = None,
       tabulate: Optional[bool] = False,
   ) -> spec.ModelInitState:
     """Only dropout is used."""
-    del aux_dropout_rate
     if self.use_resnet:
       model_class = models.DLRMResNet
     else:
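For callers, the visible effect of this change is that init_model_fn no longer accepts an aux_dropout_rate keyword. A minimal sketch of the updated call, assuming the workload class in this file is named Criteo1TbDlrmSmallWorkload and that algoperf and its JAX dependencies are installed:

import jax

# Class name assumed from the file path; check the actual module.
from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import \
    Criteo1TbDlrmSmallWorkload

workload = Criteo1TbDlrmSmallWorkload()
rng = jax.random.PRNGKey(0)

# Old call would now raise TypeError:
#   workload.init_model_fn(rng, dropout_rate=0.1, aux_dropout_rate=0.1)
# New call: one dropout_rate, or None to keep the model's defaults.
model_init_state = workload.init_model_fn(rng, dropout_rate=0.1)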

algoperf/workloads/fastmri/fastmri_jax/workload.py

Lines changed: 1 addition & 2 deletions
@@ -22,9 +22,8 @@ def init_model_fn(
       self,
       rng: spec.RandomState,
       dropout_rate: Optional[float] = None,
-      aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
+  ) -> spec.ModelInitState:
     """aux_dropout_rate is unused."""
-    del aux_dropout_rate
     fake_batch = jnp.zeros((13, 320, 320))
     if dropout_rate is None:
       self._model = UNet(
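The "if dropout_rate is None" branch follows the convention used across these workloads: None means "keep the model's built-in default", while an explicit float overrides it. A self-contained sketch of that pattern (all names here are illustrative, not taken from this file):

from typing import Any, Dict, Optional

def build_model_kwargs(dropout_rate: Optional[float] = None) -> Dict[str, Any]:
  """Forward dropout_rate only when the caller supplied one."""
  kwargs: Dict[str, Any] = {'num_channels': 32}  # illustrative fixed args
  if dropout_rate is not None:
    kwargs['dropout_rate'] = dropout_rate
  return kwargs

assert 'dropout_rate' not in build_model_kwargs()      # default preserved
assert build_model_kwargs(0.1)['dropout_rate'] == 0.1  # override applied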

algoperf/workloads/imagenet_vit/imagenet_jax/workload.py

Lines changed: 1 addition & 3 deletions
@@ -33,9 +33,7 @@ def initialized(self, key: spec.RandomState,
   def init_model_fn(
       self,
       rng: spec.RandomState,
-      dropout_rate: Optional[float] = None,
-      aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
-    del aux_dropout_rate
+      dropout_rate: Optional[float] = None) -> spec.ModelInitState:
     if dropout_rate is None:
       self._model = models.ViT(
           num_classes=self._num_classes,
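Removing the parameter outright, instead of keeping the old "del aux_dropout_rate" no-op, means stale call sites now fail loudly. A small stub, not the real workload, showing the difference:

def init_model_fn_old(rng, dropout_rate=None, aux_dropout_rate=None):
  del aux_dropout_rate  # old behavior: accepted, then silently discarded
  return rng, dropout_rate

def init_model_fn_new(rng, dropout_rate=None):
  return rng, dropout_rate

init_model_fn_old(0, dropout_rate=0.1, aux_dropout_rate=0.2)  # no error
try:
  init_model_fn_new(0, dropout_rate=0.1, aux_dropout_rate=0.2)
except TypeError as err:
  print(err)  # ... got an unexpected keyword argument 'aux_dropout_rate'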

algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py

Lines changed: 18 additions & 10 deletions
@@ -61,24 +61,32 @@ def init_model_fn(
       self,
       rng: spec.RandomState,
       dropout_rate: Optional[float] = None,
-      aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
+  ) -> spec.ModelInitState:
     """Conformer model init function.

-    Here we use dropout_rate as *_residual_dropout_rate, and aux_dropout_rate as
+    Here we use dropout_rate as *_residual_dropout_rate, and for
     input_dropout_rate.
     """
     if self.use_gelu:
       activation_function_name = 'gelu'
     else:
       activation_function_name = 'swish'
-    model_config = models.ConformerConfig(
-        attention_residual_dropout_rate=dropout_rate,
-        feed_forward_residual_dropout_rate=dropout_rate,
-        input_dropout_rate=aux_dropout_rate,
-        use_specaug=self.use_specaug,
-        attention_temperature=self.attention_temperature,
-        use_post_layer_norm=self.use_post_layer_norm,
-        activation_function_name=activation_function_name)
+    if dropout_rate is not None:
+      model_config = models.ConformerConfig(
+          attention_residual_dropout_rate=dropout_rate,
+          feed_forward_residual_dropout_rate=dropout_rate,
+          input_dropout_rate=dropout_rate,
+          use_specaug=self.use_specaug,
+          attention_temperature=self.attention_temperature,
+          use_post_layer_norm=self.use_post_layer_norm,
+          activation_function_name=activation_function_name)
+    else:
+      model_config = models.ConformerConfig(
+          use_specaug=self.use_specaug,
+          attention_temperature=self.attention_temperature,
+          use_post_layer_norm=self.use_post_layer_norm,
+          activation_function_name=activation_function_name)
+
     self._model = models.Conformer(model_config)
     input_shape = [(320000,), (320000,)]
     fake_input_batch = [np.zeros((2, *x), jnp.float32) for x in input_shape]
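The conformer change does two things at once: it branches on whether a dropout_rate was supplied (falling back to defaults otherwise), and it fans the single rate out to the two residual dropouts and to input_dropout_rate, which previously came from aux_dropout_rate. A minimal sketch of that construction with an illustrative dataclass (field names and defaults are assumptions, not the real ConformerConfig):

from dataclasses import dataclass
from typing import Optional

@dataclass
class ConformerConfigSketch:
  # Stand-in for models.ConformerConfig; real fields/defaults may differ.
  attention_residual_dropout_rate: Optional[float] = 0.1
  feed_forward_residual_dropout_rate: Optional[float] = 0.1
  input_dropout_rate: Optional[float] = 0.1

def make_config(dropout_rate: Optional[float] = None) -> ConformerConfigSketch:
  """One rate now drives all three dropout fields; None keeps defaults."""
  if dropout_rate is not None:
    return ConformerConfigSketch(
        attention_residual_dropout_rate=dropout_rate,
        feed_forward_residual_dropout_rate=dropout_rate,
        input_dropout_rate=dropout_rate)
  return ConformerConfigSketch()

print(make_config())     # dataclass defaults
print(make_config(0.2))  # single rate fans out to all three fields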
