@@ -30,7 +30,7 @@ class DLRMResNet(nn.Module):

   @nn.compact
   def __call__(self, x, train, dropout_rate=None):
-    if not dropout_rate:
+    if dropout_rate is None:
       dropout_rate = self.dropout_rate

     bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1)
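Why `is None` instead of a truthiness check: `0.0` is falsy in Python, so `if not dropout_rate` silently replaces an explicitly requested rate of `0.0` with `self.dropout_rate`, re-enabling dropout the caller tried to turn off. A minimal standalone sketch of the two behaviours (hypothetical helper names and default value, for illustration only):

def resolve_rate_buggy(dropout_rate, default=0.1):
  # `not 0.0` evaluates to True, so an explicit 0.0 falls back to the default.
  if not dropout_rate:
    dropout_rate = default
  return dropout_rate

def resolve_rate_fixed(dropout_rate, default=0.1):
  # Only an omitted rate (None) falls back to the default; 0.0 is respected.
  if dropout_rate is None:
    dropout_rate = default
  return dropout_rate

assert resolve_rate_buggy(0.0) == 0.1  # dropout unintentionally re-enabled
assert resolve_rate_fixed(0.0) == 0.0  # dropout stays disabled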
@@ -93,7 +93,7 @@ def scaled_init(key, shape, dtype=jnp.float_):
               top_mlp_input)
       x = nn.relu(x)
       if dropout_rate and layer_idx == num_layers_top - 2:
-        x = Dropout(deterministic=not train)(x, rate=dropout_rate)
+        x = Dropout(dropout_rate, deterministic=not train)(x, rate=dropout_rate)
       top_mlp_input += x
     # In the DLRM model the last layer width is always 1. We can hardcode that
     # below.
@@ -156,7 +156,7 @@ class DlrmSmall(nn.Module):

   @nn.compact
   def __call__(self, x, train, dropout_rate=None):
-    if not dropout_rate:
+    if dropout_rate is None:
       dropout_rate = self.dropout_rate

     bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1)
@@ -219,7 +219,8 @@ def scaled_init(key, shape, dtype=jnp.float_):
         top_mlp_input = nn.LayerNorm()(top_mlp_input)
       if (dropout_rate is not None and dropout_rate > 0.0 and
           layer_idx == num_layers_top - 2):
-        top_mlp_input = Dropout(deterministic=not train)(
-            top_mlp_input, rate=dropout_rate)
+        top_mlp_input = Dropout(
+            dropout_rate, deterministic=not train)(
+                top_mlp_input, rate=dropout_rate)
     logits = top_mlp_input
     return logits
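The other change in both models passes `dropout_rate` to `Dropout` at construction time in addition to the call-time `rate=` override. The repository's `Dropout` class is not shown in this diff; the sketch below is only an assumption-laden illustration of a Flax module with that calling convention (construction-time default rate and `deterministic` flag, optional call-time `rate` override), not the actual implementation.

import flax.linen as nn
import jax
import jax.numpy as jnp


class Dropout(nn.Module):
  """Sketch of a dropout module matching the call sites above (hypothetical)."""
  rate: float = 0.0            # default rate, fixed at construction
  deterministic: bool = True   # True disables dropout (e.g. at eval time)

  @nn.compact
  def __call__(self, x, rate=None):
    rate = self.rate if rate is None else rate  # call-time override wins
    if self.deterministic or rate == 0.0:
      return x
    keep_prob = 1.0 - rate
    mask = jax.random.bernoulli(
        self.make_rng('dropout'), p=keep_prob, shape=x.shape)
    # Inverted dropout: scale kept activations so the expectation is unchanged.
    return jnp.where(mask, x / keep_prob, jnp.zeros_like(x))

Under these assumptions, `Dropout(0.1, deterministic=False).apply({}, x, rate=0.1, rngs={'dropout': jax.random.PRNGKey(0)})` applies dropout with probability 0.1, while `deterministic=not train` makes evaluation passes a no-op.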