Commit 7af5c94

add dropout piping for conformer and deepspeech
1 parent f0c385b commit 7af5c94

8 files changed, +83 -59 lines changed

algoperf/workloads/fastmri/fastmri_jax/workload.py

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ def init_model_fn(
       self,
       rng: spec.RandomState,
       dropout_rate: Optional[float] = None,
-  ) -> spec.ModelInitState:
+  ) -> spec.ModelInitState:
     """aux_dropout_rate is unused."""
     fake_batch = jnp.zeros((13, 320, 320))
     if dropout_rate is None:

algoperf/workloads/librispeech_conformer/librispeech_jax/models.py

Lines changed: 2 additions & 5 deletions
@@ -38,13 +38,10 @@ class ConformerConfig:
   num_attention_heads: int = 8
   num_encoder_layers: int = 4
   attention_dropout_rate: float = 0.0
-  # If None, defaults to 0.1.
-  attention_residual_dropout_rate: Optional[float] = 0.1
-  # If None, defaults to 0.0.
+  attention_residual_dropout_rate: Optional[float] = 0.0
   conv_residual_dropout_rate: Optional[float] = 0.0
   feed_forward_dropout_rate: float = 0.0
-  # If None, defaults to 0.1.
-  feed_forward_residual_dropout_rate: Optional[float] = 0.1
+  feed_forward_residual_dropout_rate: Optional[float] = 0.0
   convolution_kernel_size: int = 5
   feed_forward_expansion_factor: int = 4
   freq_mask_count: int = 2

algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ def init_model_fn(
       self,
       rng: spec.RandomState,
       dropout_rate: Optional[float] = None,
-  ) -> spec.ModelInitState:
+  ) -> spec.ModelInitState:
     """Conformer model init function.

     Here we use dropout_rate as *_residual_dropout_rate, and for

algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py

Lines changed: 21 additions & 13 deletions
@@ -75,6 +75,9 @@ class Subsample(nn.Module):
   @nn.compact
   def __call__(self, inputs, output_paddings, train, dropout_rate=None):
     config = self.config
+    if dropout_rate is None:
+      dropout_rate = config.dropout_rate
+
     outputs = jnp.expand_dims(inputs, axis=-1)

     outputs, output_paddings = Conv2dSubsampling(
@@ -111,7 +114,9 @@ def __call__(self, inputs, output_paddings, train, dropout_rate=None):
       input_dropout_rate = 0.1
     else:
       input_dropout_rate = config.input_dropout_rate
-    outputs = Dropout(rate=input_dropout_rate, deterministic=not train)(outputs)
+    outputs = Dropout(
+        rate=input_dropout_rate, deterministic=not train)(
+            outputs, rate=dropout_rate)

     return outputs, output_paddings

@@ -187,7 +192,13 @@ class FeedForwardModule(nn.Module):
   config: DeepspeechConfig

   @nn.compact
-  def __call__(self, inputs, input_paddings=None, train=False):
+  def __call__(self,
+               inputs,
+               input_paddings=None,
+               train=False,
+               dropout_rate=None):
+    if dropout_rate is None:
+      dropout_rate = self.config.feed_forward_dropout_rate
     padding_mask = jnp.expand_dims(1 - input_paddings, -1)
     config = self.config

@@ -211,12 +222,8 @@ def __call__(self, inputs, input_paddings=None, train=False):
     inputs = nn.relu(inputs)
     inputs *= padding_mask

-    if config.feed_forward_dropout_rate is None:
-      feed_forward_dropout_rate = 0.1
-    else:
-      feed_forward_dropout_rate = config.feed_forward_dropout_rate
-    inputs = Dropout(rate=feed_forward_dropout_rate)(
-        inputs, deterministic=not train)
+    inputs = Dropout(rate=dropout_rate)(
+        inputs, deterministic=not train, rate=dropout_rate)

     return inputs

@@ -472,8 +479,10 @@ def setup(self):
     )

   @nn.compact
-  def __call__(self, inputs, input_paddings, train):
+  def __call__(self, inputs, input_paddings, train, dropout_rate=None):
     config = self.config
+    if dropout_rate is None:
+      dropout_rate = config.dropout_rate

     outputs = inputs
     output_paddings = input_paddings
@@ -493,7 +502,7 @@ def __call__(self, inputs, input_paddings, train):

     # Subsample input by a factor of 4 by performing strided convolutions.
     outputs, output_paddings = Subsample(
-        config=config)(outputs, output_paddings, train)
+        config=config)(outputs, output_paddings, train, dropout_rate=dropout_rate)

     # Run the lstm layers.
     for _ in range(config.num_lstm_layers):
@@ -507,9 +516,8 @@ def __call__(self, inputs, input_paddings, train):
         outputs = outputs + FeedForwardModule(config=self.config)(
             outputs, output_paddings, train)
       else:
-        outputs = FeedForwardModule(config=self.config)(outputs,
-                                                        output_paddings,
-                                                        train)
+        outputs = FeedForwardModule(config=self.config)(
+            outputs, output_paddings, train, dropout_rate=dropout_rate)

     # Run the decoder which in this case is a trivial projection layer.
     if config.enable_decoder_layer_norm:
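Note on the Dropout calls in these hunks: the new code passes rate= at call time (e.g. Dropout(rate=dropout_rate)(inputs, deterministic=not train, rate=dropout_rate)), which plain flax.linen.Dropout does not accept. The hunks therefore assume the repository's own Dropout wrapper with a call-time override. A minimal sketch of what such a module could look like, for illustration only and not the repository's actual implementation:

# Illustrative sketch: a Flax Dropout variant whose rate can be
# overridden per call, matching the call sites in the diff above.
import jax
import jax.numpy as jnp
from flax import linen as nn

class Dropout(nn.Module):
  rate: float = 0.0

  @nn.compact
  def __call__(self, x, deterministic=False, rate=None):
    # A call-time rate takes precedence over the construction-time rate.
    if rate is None:
      rate = self.rate
    if deterministic or rate == 0.0:
      return x
    keep_prob = 1.0 - rate
    rng = self.make_rng('dropout')
    mask = jax.random.bernoulli(rng, p=keep_prob, shape=x.shape)
    return jnp.where(mask, x / keep_prob, jnp.zeros_like(x))

Under this shape of API, the config-level rates stay in place as defaults, and a dropout_rate piped down from init_model_fn / model_fn wins whenever it is not None.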

algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py

Lines changed: 26 additions & 17 deletions
@@ -18,24 +18,31 @@ class LibriSpeechDeepSpeechWorkload(LibriSpeechConformerWorkload):
   def init_model_fn(
       self,
       rng: spec.RandomState,
-      dropout_rate: Optional[float] = None,
-      aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
+      dropout_rate: Optional[float] = None) -> spec.ModelInitState:
     """Deepspeech model init function.
-
-    Here we use dropout_rate as feed_forward_dropout_rate, and aux_dropout_rate
-    as input_dropout_rate.
     """
-    model_config = models.DeepspeechConfig(
-        feed_forward_dropout_rate=dropout_rate,
-        use_specaug=self.use_specaug,
-        input_dropout_rate=aux_dropout_rate,
-        use_tanh=self.use_tanh,
-        enable_residual_connections=self.enable_residual_connections,
-        enable_decoder_layer_norm=self.enable_decoder_layer_norm,
-        layernorm_everywhere=self.layernorm_everywhere,
-        freq_mask_count=self.freq_mask_count,
-        time_mask_count=self.time_mask_count,
-    )
+    if dropout_rate is None:
+      model_config = models.DeepspeechConfig(
+          use_specaug=self.use_specaug,
+          use_tanh=self.use_tanh,
+          enable_residual_connections=self.enable_residual_connections,
+          enable_decoder_layer_norm=self.enable_decoder_layer_norm,
+          layernorm_everywhere=self.layernorm_everywhere,
+          freq_mask_count=self.freq_mask_count,
+          time_mask_count=self.time_mask_count,
+      )
+    else:
+      model_config = models.DeepspeechConfig(
+          feed_forward_dropout_rate=dropout_rate,
+          use_specaug=self.use_specaug,
+          input_dropout_rate=dropout_rate,
+          use_tanh=self.use_tanh,
+          enable_residual_connections=self.enable_residual_connections,
+          enable_decoder_layer_norm=self.enable_decoder_layer_norm,
+          layernorm_everywhere=self.layernorm_everywhere,
+          freq_mask_count=self.freq_mask_count,
+          time_mask_count=self.time_mask_count,
+      )
     self._model = models.Deepspeech(model_config)
     input_shape = [(320000,), (320000,)]
     fake_input_batch = [np.zeros((2, *x), jnp.float32) for x in input_shape]
@@ -64,6 +71,7 @@ def model_fn(
       rng: spec.RandomState,
       update_batch_norm: bool,
-      use_running_average_bn: Optional[bool] = None
+      use_running_average_bn: Optional[bool] = None,
+      dropout_rate: Optional[float] = None
   ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
     variables = {'params': params, **model_state}
     inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs']
@@ -75,7 +83,8 @@ def model_fn(
           input_paddings,
           train=True,
           rngs={'dropout' : rng},
-          mutable=['batch_stats'])
+          mutable=['batch_stats'],
+          dropout_rate=dropout_rate)
       return (logits, logit_paddings), new_model_state
     else:
       logits, logit_paddings = self._model.apply(
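For orientation, a sketch of how the piped value would be exercised from the training side, assuming the usual AlgoPerf model_fn argument order; the earlier positional parameters are not shown in these hunks, so treat the call below as hypothetical:

# Hypothetical call; only the trailing dropout_rate keyword is new in this commit.
(logits, logit_paddings), model_state = workload.model_fn(
    params,
    batch,
    model_state,
    spec.ForwardPassMode.TRAIN,
    rng,
    update_batch_norm=True,
    dropout_rate=0.1)  # overrides the DeepspeechConfig dropout defaults for this step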

algoperf/workloads/ogbg/ogbg_jax/models.py

Lines changed: 2 additions & 2 deletions
@@ -48,9 +48,9 @@ class GNN(nn.Module):

   @nn.compact
   def __call__(self, graph, train, dropout_rate=None):
-    if not dropout_rate:
+    if dropout_rate is None:
       dropout_rate = self.dropout_rate
-    dropout = Dropout(deterministic=not train, rate=dropout_rate)
+    dropout = Dropout(rate=dropout_rate, deterministic=not train)

     graph = graph._replace(
         globals=jnp.zeros([graph.n_node.shape[0], self.num_outputs]))

algoperf/workloads/ogbg/ogbg_jax/workload.py

Lines changed: 16 additions & 10 deletions
@@ -20,18 +20,24 @@ class OgbgWorkload(BaseOgbgWorkload):
   def init_model_fn(
       self,
       rng: spec.RandomState,
-      dropout_rate: Optional[float] = None,
-      aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
+      dropout_rate: Optional[float] = None) -> spec.ModelInitState:
     """aux_dropout_rate is unused."""
-    del aux_dropout_rate
     rng, params_rng, dropout_rng = jax.random.split(rng, 3)
-    self._model = models.GNN(
-        self._num_outputs,
-        dropout_rate=dropout_rate,
-        activation_fn_name=self.activation_fn_name,
-        hidden_dims=self.hidden_dims,
-        latent_dim=self.latent_dim,
-        num_message_passing_steps=self.num_message_passing_steps)
+    if dropout_rate is None:
+      self._model = models.GNN(
+          self._num_outputs,
+          activation_fn_name=self.activation_fn_name,
+          hidden_dims=self.hidden_dims,
+          latent_dim=self.latent_dim,
+          num_message_passing_steps=self.num_message_passing_steps)
+    else:
+      self._model = models.GNN(
+          self._num_outputs,
+          dropout_rate=dropout_rate,
+          activation_fn_name=self.activation_fn_name,
+          hidden_dims=self.hidden_dims,
+          latent_dim=self.latent_dim,
+          num_message_passing_steps=self.num_message_passing_steps)
     init_fn = jax.jit(functools.partial(self._model.init, train=False))
     fake_batch = jraph.GraphsTuple(
         n_node=jnp.asarray([1]),

algoperf/workloads/wmt/wmt_jax/workload.py

Lines changed: 14 additions & 10 deletions
@@ -209,10 +209,7 @@ def translate_and_calculate_bleu(self,
   def init_model_fn(
       self,
       rng: spec.RandomState,
-      dropout_rate: Optional[float] = 0.0,
-      aux_dropout_rate: Optional[float] = 0.0) -> spec.ModelInitState:
-    """aux_dropout_rate is used as attention_dropout_rate."""
-
+      dropout_rate: Optional[float] = 0.0) -> spec.ModelInitState:
     init_fake_batch_size = 2
     input_shape = (init_fake_batch_size, 256)
     target_shape = (init_fake_batch_size, 256)
@@ -224,13 +221,20 @@ def init_model_fn(
     else:
       raise ValueError(f'Unknown activation function {self.activation}.')

+    if dropout_rate is None:
+      model_config = models.TransformerConfig(
+          pre_ln=self.pre_ln,
+          attention_temp=self.attention_temp,
+          activation=activation,
+          glu=self.glu)
+    else:
       model_config = models.TransformerConfig(
-        dropout_rate=dropout_rate,
-        attention_dropout_rate=aux_dropout_rate,
-        pre_ln=self.pre_ln,
-        attention_temp=self.attention_temp,
-        activation=activation,
-        glu=self.glu)
+          dropout_rate=dropout_rate,
+          attention_dropout_rate=dropout_rate,
+          pre_ln=self.pre_ln,
+          attention_temp=self.attention_temp,
+          activation=activation,
+          glu=self.glu)
     self._train_model = models.Transformer(model_config)
     eval_config = replace(model_config, deterministic=True)
     self._eval_model = models.Transformer(eval_config)
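A side note on the construction pattern used here and in the ogbg and deepspeech workloads: the if dropout_rate is None branch repeats every other config field. An equivalent, more compact form, shown only as a sketch and assuming nothing beyond keyword construction of TransformerConfig, passes the dropout fields conditionally:

# Hypothetical refactor sketch; behavior matches the branch shown above.
dropout_kwargs = {}
if dropout_rate is not None:
  dropout_kwargs = dict(
      dropout_rate=dropout_rate,
      attention_dropout_rate=dropout_rate)
model_config = models.TransformerConfig(
    pre_ln=self.pre_ln,
    attention_temp=self.attention_temp,
    activation=activation,
    glu=self.glu,
    **dropout_kwargs)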
