@@ -236,7 +236,7 @@ def __call__(self, inputs, encoder_mask=None, dropout_rate=None):

     # MLP block.
     y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x
-    y = MlpBlock(config=cfg)(y)
+    y = MlpBlock(config=cfg)(y, dropout_rate=dropout_rate)

     return x + y if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(x + y)
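Note on the MlpBlock call above (and the matching one in the next hunk): it assumes MlpBlock.__call__ now accepts an optional dropout_rate argument. The block itself is outside this section, so the following is only a minimal sketch of how such a block might resolve the override, assuming it falls back to cfg.dropout_rate when the caller passes None; field names such as mlp_dim and deterministic are the usual Flax Transformer config fields and are assumptions here.

# Minimal sketch only: the real MlpBlock is not shown in this diff. The
# fallback to cfg.dropout_rate when dropout_rate is None is an assumption,
# as are the config field names (mlp_dim, dtype, deterministic, dropout_rate).
from typing import Any
import flax.linen as nn

class MlpBlock(nn.Module):
  config: Any  # assumed to be the module's TransformerConfig

  @nn.compact
  def __call__(self, inputs, dropout_rate=None):
    cfg = self.config
    if dropout_rate is None:
      dropout_rate = cfg.dropout_rate  # fall back to the configured rate
    x = nn.Dense(cfg.mlp_dim, dtype=cfg.dtype)(inputs)
    x = nn.relu(x)
    x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic)
    output = nn.Dense(inputs.shape[-1], dtype=cfg.dtype)(x)
    return nn.Dropout(rate=dropout_rate)(output, deterministic=cfg.deterministic)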
@@ -324,7 +324,7 @@ def __call__(

     # MLP block.
     z = nn.LayerNorm(dtype=cfg.dtype)(y) if pre_ln else y
-    z = MlpBlock(config=cfg)(z)
+    z = MlpBlock(config=cfg)(z, dropout_rate=dropout_rate)

     return y + z if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(y + z)
@@ -382,7 +382,7 @@ def __call__(

     # Input Encoder
     for lyr in range(cfg.num_layers):
-      x = Encoder1DBlock(config=cfg, name=f"encoderblock_{lyr}")(x, encoder_mask)
+      x = Encoder1DBlock(config=cfg, name=f"encoderblock_{lyr}")(x, encoder_mask, dropout_rate)

     encoded = (
         nn.LayerNorm(dtype=cfg.dtype, name="encoder_layernorm")(x)
@@ -464,6 +464,7 @@ def __call__(
           encoded,
           decoder_mask=decoder_mask,
           encoder_decoder_mask=encoder_decoder_mask,
+          dropout_rate=dropout_rate,
       )
     y = (
         nn.LayerNorm(dtype=cfg.dtype, name="encoderdecoder_layernorm")(y)
@@ -503,7 +504,7 @@ def setup(self):
     self.encoder = Encoder(config=cfg, shared_embedding=self.shared_embedding)
     self.decoder = Decoder(config=cfg, shared_embedding=self.shared_embedding)

-  def encode(self, inputs, inputs_positions=None, inputs_segmentation=None):
+  def encode(self, inputs, inputs_positions=None, inputs_segmentation=None, dropout_rate=None):
     """Applies Transformer encoder-branch on the inputs.

     Args:
@@ -528,7 +529,7 @@ def encode(self, inputs, inputs_positions=None, inputs_segmentation=None):
             jnp.equal,
             dtype=cfg.dtype))
     return self.encoder(
-        inputs, inputs_positions=inputs_positions, encoder_mask=encoder_mask)
+        inputs, inputs_positions=inputs_positions, encoder_mask=encoder_mask, dropout_rate=dropout_rate)

   def decode(
       self,
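Since encode now exposes dropout_rate as a keyword, the encoder branch can also be driven on its own through Flax's method= argument to apply. A hedged usage sketch follows; the config object, token shapes, and the 0.2 rate are illustrative assumptions, not values from this PR.

# Usage sketch (assumptions: cfg is a valid config for this module and dummy
# integer token ids are acceptable inputs).
import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
params_rng, dropout_rng = jax.random.split(rng)

inputs = jnp.ones((2, 16), jnp.int32)    # dummy source token ids
targets = jnp.ones((2, 16), jnp.int32)   # dummy target token ids

model = Transformer(config=cfg)
variables = model.init({'params': params_rng, 'dropout': dropout_rng},
                       inputs, targets)

# Run only the encoder branch, overriding dropout at call time.
encoded = model.apply(
    variables,
    inputs,
    dropout_rate=0.2,                    # new keyword threaded by this change
    method=Transformer.encode,
    rngs={'dropout': dropout_rng})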
@@ -595,7 +596,8 @@ def __call__(self,
                inputs_positions=None,
                targets_positions=None,
                inputs_segmentation=None,
-               targets_segmentation=None):
+               targets_segmentation=None,
+               dropout_rate=None):
     """Applies Transformer model on the inputs.

     Args:
@@ -612,12 +614,14 @@ def __call__(self,
     encoded = self.encode(
         inputs,
         inputs_positions=inputs_positions,
-        inputs_segmentation=inputs_segmentation)
+        inputs_segmentation=inputs_segmentation,
+        dropout_rate=dropout_rate)

     return self.decode(
         encoded,
         inputs,  # only used for masks
         targets,
         targets_positions=targets_positions,
         inputs_segmentation=inputs_segmentation,
-        targets_segmentation=targets_segmentation)
+        targets_segmentation=targets_segmentation,
+        dropout_rate=dropout_rate)