Commit 31f6019: modify wmt model for dropout passing
Parent: 9bba078

2 files changed (+13, -9 lines):
  algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py
  algoperf/workloads/wmt/wmt_jax/models.py

algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py (1 addition, 1 deletion)

@@ -73,7 +73,7 @@ class Subsample(nn.Module):
  config: DeepspeechConfig

  @nn.compact
- def __call__(self, inputs, output_paddings, train):
+ def __call__(self, inputs, output_paddings, train, dropout_rate=None):
    config = self.config
    outputs = jnp.expand_dims(inputs, axis=-1)
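This hunk only widens Subsample's call signature; the rate is presumably consumed by the module's dropout layers elsewhere in the file. Below is a minimal sketch of the convention, assuming a dropout_rate=None argument that falls back to a module default; Block and default_dropout_rate are illustrative names, not the actual DeepspeechConfig fields.

import flax.linen as nn
import jax
import jax.numpy as jnp


class Block(nn.Module):  # illustrative stand-in, not the real Subsample
  default_dropout_rate: float = 0.1  # hypothetical default, not a real config field

  @nn.compact
  def __call__(self, x, train, dropout_rate=None):
    # None means "use the module's default", mirroring the new signature.
    if dropout_rate is None:
      dropout_rate = self.default_dropout_rate
    x = nn.Dense(16)(x)
    return nn.Dropout(rate=dropout_rate, deterministic=not train)(x)


x = jnp.ones((2, 8))
block = Block()
variables = block.init(jax.random.PRNGKey(0), x, train=False)
y = block.apply(variables, x, train=True, dropout_rate=0.2,
                rngs={'dropout': jax.random.PRNGKey(1)})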

algoperf/workloads/wmt/wmt_jax/models.py (12 additions, 8 deletions)

@@ -236,7 +236,7 @@ def __call__(self, inputs, encoder_mask=None, dropout_rate=None):

    # MLP block.
    y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x
-   y = MlpBlock(config=cfg)(y)
+   y = MlpBlock(config=cfg)(y, dropout_rate=dropout_rate)

    return x + y if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(x + y)

@@ -324,7 +324,7 @@ def __call__(

    # MLP block.
    z = nn.LayerNorm(dtype=cfg.dtype)(y) if pre_ln else y
-   z = MlpBlock(config=cfg)(z)
+   z = MlpBlock(config=cfg)(z, dropout_rate=dropout_rate)

    return y + z if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(y + z)

@@ -382,7 +382,7 @@ def __call__(

    # Input Encoder
    for lyr in range(cfg.num_layers):
-     x = Encoder1DBlock(config=cfg, name=f"encoderblock_{lyr}")(x, encoder_mask)
+     x = Encoder1DBlock(config=cfg, name=f"encoderblock_{lyr}")(x, encoder_mask, dropout_rate)

    encoded = (
        nn.LayerNorm(dtype=cfg.dtype, name="encoder_layernorm")(x)

@@ -464,6 +464,7 @@ def __call__(
        encoded,
        decoder_mask=decoder_mask,
        encoder_decoder_mask=encoder_decoder_mask,
+       dropout_rate=dropout_rate,
    )
    y = (
        nn.LayerNorm(dtype=cfg.dtype, name="encoderdecoder_layernorm")(y)

@@ -503,7 +504,7 @@ def setup(self):
    self.encoder = Encoder(config=cfg, shared_embedding=self.shared_embedding)
    self.decoder = Decoder(config=cfg, shared_embedding=self.shared_embedding)

-  def encode(self, inputs, inputs_positions=None, inputs_segmentation=None):
+  def encode(self, inputs, inputs_positions=None, inputs_segmentation=None, dropout_rate=None):
    """Applies Transformer encoder-branch on the inputs.

    Args:

@@ -528,7 +529,7 @@ def encode(self, inputs, inputs_positions=None, inputs_segmentation=None):
            jnp.equal,
            dtype=cfg.dtype))
    return self.encoder(
-       inputs, inputs_positions=inputs_positions, encoder_mask=encoder_mask)
+       inputs, inputs_positions=inputs_positions, encoder_mask=encoder_mask, dropout_rate=dropout_rate)

  def decode(
      self,

@@ -595,7 +596,8 @@ def __call__(self,
               inputs_positions=None,
               targets_positions=None,
               inputs_segmentation=None,
-              targets_segmentation=None):
+              targets_segmentation=None,
+              dropout_rate=None):
    """Applies Transformer model on the inputs.

    Args:

@@ -612,12 +614,14 @@ def __call__(self,
    encoded = self.encode(
        inputs,
        inputs_positions=inputs_positions,
-       inputs_segmentation=inputs_segmentation)
+       inputs_segmentation=inputs_segmentation,
+       dropout_rate=dropout_rate)

    return self.decode(
        encoded,
        inputs,  # only used for masks
        targets,
        targets_positions=targets_positions,
        inputs_segmentation=inputs_segmentation,
-       targets_segmentation=targets_segmentation)
+       targets_segmentation=targets_segmentation,
+       dropout_rate=dropout_rate)
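Taken together, these hunks thread a single per-call dropout_rate from Transformer.__call__ through encode/decode and the encoder and decoder blocks down to MlpBlock. Below is a self-contained sketch of that forwarding pattern, using illustrative module names (Inner, Outer) rather than the actual WMT classes.

import flax.linen as nn
import jax
import jax.numpy as jnp


class Inner(nn.Module):  # stands in for MlpBlock
  @nn.compact
  def __call__(self, x, dropout_rate=0.1, deterministic=False):
    x = nn.Dense(16)(x)
    return nn.Dropout(rate=dropout_rate, deterministic=deterministic)(x)


class Outer(nn.Module):  # stands in for Transformer / Encoder1DBlock
  @nn.compact
  def __call__(self, x, dropout_rate=0.1, deterministic=False):
    x = nn.LayerNorm()(x)
    # Forward the caller-supplied rate to the nested block, as in
    # MlpBlock(config=cfg)(y, dropout_rate=dropout_rate) above.
    return Inner()(x, dropout_rate=dropout_rate, deterministic=deterministic)


x = jnp.ones((2, 8))
model = Outer()
variables = model.init(jax.random.PRNGKey(0), x, deterministic=True)
out = model.apply(variables, x, dropout_rate=0.2, deterministic=False,
                  rngs={'dropout': jax.random.PRNGKey(1)})

Accepting the rate at apply time lets a training harness adjust dropout per call instead of rebuilding the module from a modified config.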
