
Commit 363da8a

dropout fix for criteo1tb jax
1 parent e36d294 commit 363da8a

File tree

1 file changed: +6 -5 lines changed

  • algoperf/workloads/criteo1tb/criteo1tb_jax

algoperf/workloads/criteo1tb/criteo1tb_jax/models.py

Lines changed: 6 additions & 5 deletions
@@ -30,7 +30,7 @@ class DLRMResNet(nn.Module):
 
   @nn.compact
   def __call__(self, x, train, dropout_rate=None):
-    if not dropout_rate:
+    if dropout_rate is None:
       dropout_rate = self.dropout_rate
 
     bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1)
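
The switch from `if not dropout_rate:` to `if dropout_rate is None:` matters because a caller may pass `dropout_rate=0.0` to disable dropout; 0.0 is falsy, so the old check silently replaced it with the module default. A minimal standalone sketch of the two behaviours (hypothetical helper names, not the workload code):

```python
MODULE_DEFAULT = 0.1  # stand-in for self.dropout_rate

def resolve_rate_old(dropout_rate=None):
  if not dropout_rate:  # 0.0 is falsy, so an explicit 0.0 is overridden
    dropout_rate = MODULE_DEFAULT
  return dropout_rate

def resolve_rate_new(dropout_rate=None):
  if dropout_rate is None:  # only a missing value falls back to the default
    dropout_rate = MODULE_DEFAULT
  return dropout_rate

assert resolve_rate_old(0.0) == 0.1  # caller's "no dropout" request is lost
assert resolve_rate_new(0.0) == 0.0  # caller's 0.0 is respected
```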
@@ -93,7 +93,7 @@ def scaled_init(key, shape, dtype=jnp.float_):
                          top_mlp_input)
       x = nn.relu(x)
       if dropout_rate and layer_idx == num_layers_top - 2:
-        x = Dropout(deterministic=not train)(x, rate=dropout_rate)
+        x = Dropout(dropout_rate, deterministic=not train)(x, rate=dropout_rate)
       top_mlp_input += x
     # In the DLRM model the last layer width is always 1. We can hardcode that
     # below.
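
Both `Dropout` call sites now pass the rate to the constructor as well as at call time, so the module carries a default rate and still accepts a call-time override. The workload's `Dropout` is a custom wrapper; a minimal Flax sketch of a module with that interface (an assumed shape for illustration, not the repo's actual class):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class Dropout(nn.Module):
  """Hypothetical dropout wrapper: the default rate is fixed at construction
  and can be overridden by the optional `rate` argument at call time."""
  rate: float = 0.0
  deterministic: bool = False

  @nn.compact
  def __call__(self, x, rate=None):
    rate = self.rate if rate is None else rate
    if self.deterministic or rate == 0.0:
      return x
    keep_prob = 1.0 - rate
    mask = jax.random.bernoulli(
        self.make_rng('dropout'), p=keep_prob, shape=x.shape)
    return jnp.where(mask, x / keep_prob, jnp.zeros_like(x))
```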
@@ -156,7 +156,7 @@ class DlrmSmall(nn.Module):
 
   @nn.compact
   def __call__(self, x, train, dropout_rate=None):
-    if not dropout_rate:
+    if dropout_rate is None:
       dropout_rate = self.dropout_rate
 
     bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1)
@@ -219,7 +219,8 @@ def scaled_init(key, shape, dtype=jnp.float_):
         top_mlp_input = nn.LayerNorm()(top_mlp_input)
       if (dropout_rate is not None and dropout_rate > 0.0 and
           layer_idx == num_layers_top - 2):
-        top_mlp_input = Dropout(deterministic=not train)(
-            top_mlp_input, rate=dropout_rate)
+        top_mlp_input = Dropout(
+            dropout_rate, deterministic=not train)(
+                top_mlp_input, rate=dropout_rate)
       logits = top_mlp_input
     return logits
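
Continuing the hypothetical sketch above, the reworked call sites let a caller override the constructor default at apply time, for example to switch dropout off without rebuilding the model:

```python
x = jnp.ones((4, 16))
dropout = Dropout(rate=0.1, deterministic=False)
variables = dropout.init(
    {'params': jax.random.PRNGKey(0), 'dropout': jax.random.PRNGKey(1)}, x)

# Call-time override: rate=0.0 disables dropout despite the 0.1 default.
y = dropout.apply(variables, x, rate=0.0,
                  rngs={'dropout': jax.random.PRNGKey(2)})
```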
