
Commit e36d294

reformatting and dropout fixes to fastmri and vit
1 parent 31f6019 commit e36d294

7 files changed: +362 additions, -310 deletions
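
The dropout handling below repeatedly replaces the falsy check `if not dropout_rate:` with `if dropout_rate is None:` in the fastmri and ViT models. A minimal standalone sketch (not code from this repo) of why the two checks differ:

MODULE_DEFAULT = 0.1

def resolve_falsy(dropout_rate=None):
  # Old behaviour: an explicit 0.0 is falsy, so it is silently replaced.
  if not dropout_rate:
    dropout_rate = MODULE_DEFAULT
  return dropout_rate

def resolve_none(dropout_rate=None):
  # New behaviour: fall back to the default only when no rate was passed.
  if dropout_rate is None:
    dropout_rate = MODULE_DEFAULT
  return dropout_rate

assert resolve_falsy(0.0) == 0.1  # caller's 0.0 is lost
assert resolve_none(0.0) == 0.0   # caller's 0.0 is respected
assert resolve_none(None) == 0.1  # default still applies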


algoperf/workloads/criteo1tb/criteo1tb_jax/models.py

Lines changed: 4 additions & 3 deletions
@@ -7,6 +7,7 @@
 
 from algoperf.jax_utils import Dropout
 
+
 class DLRMResNet(nn.Module):
   """Define a DLRMResNet model.
 
@@ -30,7 +31,7 @@ class DLRMResNet(nn.Module):
   @nn.compact
   def __call__(self, x, train, dropout_rate=None):
     if not dropout_rate:
-      dropout_rate=self.dropout_rate
+      dropout_rate = self.dropout_rate
 
     bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1)
     cat_features = jnp.asarray(cat_features, dtype=jnp.int32)
@@ -157,7 +158,7 @@ class DlrmSmall(nn.Module):
   def __call__(self, x, train, dropout_rate=None):
     if not dropout_rate:
       dropout_rate = self.dropout_rate
-
+
     bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1)
     cat_features = jnp.asarray(cat_features, dtype=jnp.int32)
 
@@ -219,6 +220,6 @@ def scaled_init(key, shape, dtype=jnp.float_):
       if (dropout_rate is not None and dropout_rate > 0.0 and
           layer_idx == num_layers_top - 2):
         top_mlp_input = Dropout(deterministic=not train)(
-          top_mlp_input, rate=dropout_rate)
+            top_mlp_input, rate=dropout_rate)
     logits = top_mlp_input
     return logits

algoperf/workloads/fastmri/fastmri_jax/models.py

Lines changed: 10 additions & 8 deletions
@@ -21,6 +21,7 @@
 
 from algoperf.jax_utils import Dropout
 
+
 def _instance_norm2d(x, axes, epsilon=1e-5):
   # promote x to at least float32, this avoids half precision computation
   # but preserves double or complex floating points
@@ -57,13 +58,13 @@ class UNet(nn.Module):
   num_channels: int = 32
   num_pool_layers: int = 4
   out_channels = 1
-  dropout_rate: Optional[float] = 0.0  # If None, defaults to 0.0.
+  dropout_rate: float = 0.0
   use_tanh: bool = False
   use_layer_norm: bool = False
 
   @nn.compact
   def __call__(self, x, train=True, dropout_rate=None):
-    if not dropout_rate:
+    if dropout_rate is None:
       dropout_rate = self.dropout_rate
 
     # pylint: disable=invalid-name
@@ -138,7 +139,7 @@ class ConvBlock(nn.Module):
     dropout_rate: Dropout probability.
   """
   out_channels: int
-  dropout_rate: float
+  dropout_rate: float = 0.0
   use_tanh: bool
   use_layer_norm: bool
 
@@ -152,8 +153,8 @@ def __call__(self, x, train=True, dropout_rate=None):
     Returns:
       jnp.array: Output tensor of shape `(N, H, W, out_channels)`.
     """
-    if not dropout_rate:
-      dropout_rate=self.dropout_rate
+    if dropout_rate is None:
+      dropout_rate = self.dropout_rate
     x = nn.Conv(
         features=self.out_channels,
         kernel_size=(3, 3),
@@ -174,8 +175,9 @@ def __call__(self, x, train=True, dropout_rate=None):
     x = activation_fn(x)
     # Ref code uses dropout2d which applies the same mask for the entire channel
     # Replicated by using broadcast dims to have the same filter on HW
-    x = Dropout(broadcast_dims=(1, 2), deterministic=not train)(
-        x, rate=dropout_rate )
+    x = Dropout(
+        dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(
+            x, rate=dropout_rate)
     x = nn.Conv(
         features=self.out_channels,
         kernel_size=(3, 3),
@@ -188,7 +190,7 @@ def __call__(self, x, train=True, dropout_rate=None):
     x = _instance_norm2d(x, (1, 2))
     x = activation_fn(x)
     x = Dropout(
-        broadcast_dims=(1, 2), deterministic=not train)(
+        dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(
            x, rate=dropout_rate)
     return x
 
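
The ConvBlock hunks above note that the reference model uses dropout2d, i.e. one mask per channel, replicated here by broadcasting the mask over H and W. A rough sketch of the same effect using the standard flax.linen.Dropout as a stand-in for the custom Dropout from algoperf.jax_utils (whose exact signature differs):

import flax.linen as nn
import jax
import jax.numpy as jnp

x = jnp.ones((4, 8, 8, 3))  # NHWC, as in the UNet
# broadcast_dims=(1, 2) reuses a single Bernoulli draw across H and W, so each
# (sample, channel) feature map is kept or dropped as a whole, like Dropout2d.
drop = nn.Dropout(rate=0.5, broadcast_dims=(1, 2), deterministic=False)
y = drop.apply({}, x, rngs={'dropout': jax.random.PRNGKey(0)})

# Every H x W slice is either all zeros or all scaled by 1 / (1 - rate).
flat = y.reshape(4, -1, 3)
assert bool(jnp.all((flat == 0).all(axis=1) | (flat == 2.0).all(axis=1)))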

algoperf/workloads/fastmri/fastmri_jax/workload.py

Lines changed: 15 additions & 6 deletions
@@ -26,12 +26,21 @@ def init_model_fn(
     """aux_dropout_rate is unused."""
     del aux_dropout_rate
     fake_batch = jnp.zeros((13, 320, 320))
-    self._model = UNet(
-        num_pool_layers=self.num_pool_layers,
-        num_channels=self.num_channels,
-        use_tanh=self.use_tanh,
-        use_layer_norm=self.use_layer_norm,
-        dropout_rate=dropout_rate)
+    if dropout_rate is None:
+      self._model = UNet(
+          num_pool_layers=self.num_pool_layers,
+          num_channels=self.num_channels,
+          use_tanh=self.use_tanh,
+          use_layer_norm=self.use_layer_norm,
+      )
+    else:
+      self._model = UNet(
+          num_pool_layers=self.num_pool_layers,
+          num_channels=self.num_channels,
+          use_tanh=self.use_tanh,
+          use_layer_norm=self.use_layer_norm,
+          dropout_rate=dropout_rate)
+
     params_rng, dropout_rng = jax.random.split(rng)
     variables = jax.jit(
         self._model.init)({'params': params_rng, 'dropout': dropout_rng},
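
The branch above exists because forwarding dropout_rate=None straight into the Flax module would overwrite the module's own default with None instead of leaving it at 0.0. A small sketch with a hypothetical stub module (UNetStub is not the repo's UNet):

import flax.linen as nn


class UNetStub(nn.Module):
  dropout_rate: float = 0.0


assert UNetStub().dropout_rate == 0.0                    # default kept
assert UNetStub(dropout_rate=0.5).dropout_rate == 0.5    # explicit override
assert UNetStub(dropout_rate=None).dropout_rate is None  # default clobbered


def make_model(dropout_rate=None):
  # Branch-free alternative: only forward the keyword when it was given.
  kwargs = {} if dropout_rate is None else {'dropout_rate': dropout_rate}
  return UNetStub(**kwargs)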

algoperf/workloads/imagenet_vit/imagenet_jax/models.py

Lines changed: 49 additions & 26 deletions
@@ -39,9 +39,12 @@ class MlpBlock(nn.Module):
   dropout_rate: float = 0.0
 
   @nn.compact
-  def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=None) -> spec.Tensor:
+  def __call__(self,
+               x: spec.Tensor,
+               train: bool = True,
+               dropout_rate=None) -> spec.Tensor:
     """Applies Transformer MlpBlock module."""
-    if not dropout_rate:
+    if dropout_rate is None:
       dropout_rate = self.dropout_rate
 
     inits = {
@@ -57,7 +60,7 @@ def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=None) -> spec.Tensor:
       y = nn.Dense(self.mlp_dim, **inits)(x)
       x = x * y
 
-    x = Dropout()(x, train, rate=dropout_rate)
+    x = Dropout(dropout_rate)(x, train, rate=dropout_rate)
     x = nn.Dense(d, **inits)(x)
     return x
 
@@ -71,9 +74,12 @@ class Encoder1DBlock(nn.Module):
   dropout_rate: float = 0.0
 
   @nn.compact
-  def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=dropout_rate) -> spec.Tensor:
-    if not dropout_rate:
-      dropout_rate=self.dropout_rate
+  def __call__(self,
+               x: spec.Tensor,
+               train: bool = True,
+               dropout_rate=dropout_rate) -> spec.Tensor:
+    if dropout_rate is None:
+      dropout_rate = self.dropout_rate
 
     if not self.use_post_layer_norm:
       y = nn.LayerNorm(name='LayerNorm_0')(x)
@@ -83,15 +89,14 @@ def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=dropout_rate) -> spec.Tensor:
           deterministic=train,
           name='MultiHeadDotProductAttention_1')(
               y)
-      y = Dropout()(y, train, dropout_rate=dropout_rate)
+      y = Dropout(dropout_rate)(y, train, dropout_rate=dropout_rate)
       x = x + y
 
       y = nn.LayerNorm(name='LayerNorm_2')(x)
       y = MlpBlock(
-          mlp_dim=self.mlp_dim,
-          use_glu=self.use_glu,
-          name='MlpBlock_3')(y, train, dropout_rate=dropout_rate)
-      y = Dropout()(y, train, rate=dropout_rate)
+          mlp_dim=self.mlp_dim, use_glu=self.use_glu, name='MlpBlock_3')(
+              y, train, dropout_rate=dropout_rate)
+      y = Dropout(dropout_rate)(y, train, rate=dropout_rate)
       x = x + y
     else:
       y = x
@@ -101,16 +106,18 @@ def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=dropout_rate) -> spec.Tensor:
           deterministic=train,
           name='MultiHeadDotProductAttention_1')(
              y)
-      y = Dropout()(y, train, rate=dropout_rate)
+      y = Dropout(dropout_rate)(y, train, rate=dropout_rate)
       x = x + y
       x = nn.LayerNorm(name='LayerNorm_0')(x)
 
       y = x
       y = MlpBlock(
           mlp_dim=self.mlp_dim,
           use_glu=self.use_glu,
-          name='MlpBlock_3')(y, train, dropout_rate=dropout_rate)
-      y = Dropout()(y, train)(rate=dropout_rate)
+          name='MlpBlock_3',
+          dropout_rate=dropout_rate)(
+              y, train, dropout_rate=dropout_rate)
+      y = Dropout(dropout_rate)(y, train)(rate=dropout_rate)
       x = x + y
       x = nn.LayerNorm(name='LayerNorm_2')(x)
 
@@ -127,9 +134,12 @@ class Encoder(nn.Module):
   use_post_layer_norm: bool = False
 
   @nn.compact
-  def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=None) -> spec.Tensor:
-    if not dropout_rate:
-      dropout_rate=self.dropout_rate
+  def __call__(self,
+               x: spec.Tensor,
+               train: bool = True,
+               dropout_rate=None) -> spec.Tensor:
+    if dropout_rate is None:
+      dropout_rate = self.dropout_rate
 
     # Input Encoder
     for lyr in range(self.depth):
@@ -139,7 +149,8 @@ def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=None) -> spec.Tensor:
           num_heads=self.num_heads,
          use_glu=self.use_glu,
          use_post_layer_norm=self.use_post_layer_norm,
-      )(dropout_rate=dropout_rate)
+          dropout_rate=dropout_rate)(
+              dropout_rate=dropout_rate)
       x = block(x, train)
     if not self.use_post_layer_norm:
       return nn.LayerNorm(name='encoder_layernorm')(x)
@@ -151,9 +162,12 @@ class MAPHead(nn.Module):
   """Multihead Attention Pooling."""
   mlp_dim: Optional[int] = None  # Defaults to 4x input dim
   num_heads: int = 12
+  dropout_rate: 0.0
 
   @nn.compact
-  def __call__(self, x):
+  def __call__(self, x, dropout_rate=None):
+    if dropout_rate is None:
+      dropout_rate = self.dropout_rate
     n, _, d = x.shape
     probe = self.param('probe',
                        nn.initializers.xavier_uniform(), (1, 1, d),
@@ -166,7 +180,7 @@ def __call__(self, x):
         kernel_init=nn.initializers.xavier_uniform())(probe, x)
 
     y = nn.LayerNorm()(x)
-    x = x + MlpBlock(mlp_dim=self.mlp_dim)(y)
+    x = x + MlpBlock(mlp_dim=self.mlp_dim, dropout_rate=dropout_rate)(y)
     return x[:, 0]
 
 
@@ -180,7 +194,7 @@ class ViT(nn.Module):
   mlp_dim: Optional[int] = None  # Defaults to 4x input dim.
   num_heads: int = 12
   rep_size: Union[int, bool] = True
-  dropout_rate: Optional[float] = 0.0  # If None, defaults to 0.0.
+  dropout_rate: Optional[float] = 0.0
   reinit: Optional[Sequence[str]] = None
   head_zeroinit: bool = True
   use_glu: bool = False
@@ -194,8 +208,12 @@ def get_posemb(self,
     return posemb_sincos_2d(*seqshape, width, dtype=dtype)
 
   @nn.compact
-  def __call__(self, x: spec.Tensor, *, train: bool = False, dropout_rate=None) -> spec.Tensor:
-    if not dropout_rate:
+  def __call__(self,
+               x: spec.Tensor,
+               *,
+               train: bool = False,
+               dropout_rate=None) -> spec.Tensor:
+    if dropout_rate is None:
       dropout_rate = self.dropout_rate
     # Patch extraction
     x = nn.Conv(
@@ -212,19 +230,24 @@ def __call__(self, x: spec.Tensor, *, train: bool = False, dropout_rate=None) -> spec.Tensor:
     # Add posemb before adding extra token.
     x = x + self.get_posemb((h, w), c, x.dtype)
 
-    x = Dropout()(x, not train, rate=dropout_rate)
+    x = Dropout(dropout_rate)(x, not train, rate=dropout_rate)
 
     x = Encoder(
         depth=self.depth,
         mlp_dim=self.mlp_dim,
         num_heads=self.num_heads,
        use_glu=self.use_glu,
        use_post_layer_norm=self.use_post_layer_norm,
-        name='Transformer')(
+        name='Transformer',
+        dropout_rate=dropout_rate)(
            x, train=not train, dropout_rate=dropout_rate)
 
     if self.use_map:
-      x = MAPHead(num_heads=self.num_heads, mlp_dim=self.mlp_dim)(x)
+      x = MAPHead(
+          num_heads=self.num_heads,
+          mlp_dim=self.mlp_dim,
+          dropout_rate=dropout_rate)(
+              x, dropout_rate=dropout_rate)
     else:
       x = jnp.mean(x, axis=1)
 
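
The ViT hunks above thread dropout_rate through every submodule twice: once as a constructor default and once as a call-time argument. A hypothetical two-level sketch of that pattern, using the standard flax.linen.Dropout in place of the algoperf helper:

import flax.linen as nn
import jax
import jax.numpy as jnp


class Child(nn.Module):
  dropout_rate: float = 0.0

  @nn.compact
  def __call__(self, x, train=True, dropout_rate=None):
    if dropout_rate is None:
      dropout_rate = self.dropout_rate
    x = nn.Dense(16)(x)
    return nn.Dropout(rate=dropout_rate, deterministic=not train)(x)


class Parent(nn.Module):
  dropout_rate: float = 0.0

  @nn.compact
  def __call__(self, x, train=True, dropout_rate=None):
    if dropout_rate is None:
      dropout_rate = self.dropout_rate
    # Forward the resolved rate both as the child's default and at call time,
    # mirroring the Encoder -> Encoder1DBlock -> MlpBlock plumbing above.
    return Child(dropout_rate=dropout_rate)(x, train, dropout_rate=dropout_rate)


x = jnp.ones((2, 8))
variables = Parent(dropout_rate=0.1).init(
    {'params': jax.random.PRNGKey(0), 'dropout': jax.random.PRNGKey(1)}, x)
y = Parent(dropout_rate=0.1).apply(
    variables, x, dropout_rate=0.5, rngs={'dropout': jax.random.PRNGKey(2)})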

algoperf/workloads/imagenet_vit/imagenet_jax/workload.py

Lines changed: 15 additions & 7 deletions
@@ -36,13 +36,21 @@ def init_model_fn(
       dropout_rate: Optional[float] = None,
       aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
     del aux_dropout_rate
-    self._model = models.ViT(
-        dropout_rate=dropout_rate,
-        num_classes=self._num_classes,
-        use_glu=self.use_glu,
-        use_post_layer_norm=self.use_post_layer_norm,
-        use_map=self.use_map,
-        **decode_variant('S/16'))
+    if dropout_rate is None:
+      self._model = models.ViT(
+          num_classes=self._num_classes,
+          use_glu=self.use_glu,
+          use_post_layer_norm=self.use_post_layer_norm,
+          use_map=self.use_map,
+          **decode_variant('S/16'))
+    else:
+      self._model = models.ViT(
+          dropout_rate=dropout_rate,
+          num_classes=self._num_classes,
+          use_glu=self.use_glu,
+          use_post_layer_norm=self.use_post_layer_norm,
+          use_map=self.use_map,
+          **decode_variant('S/16'))
     params, model_state = self.initialized(rng, self._model)
     self._param_shapes = param_utils.jax_param_shapes(params)
     self._param_types = param_utils.jax_param_types(self._param_shapes)

algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py

Lines changed: 1 addition & 3 deletions
@@ -111,9 +111,7 @@ def __call__(self, inputs, output_paddings, train, dropout_rate=None):
       input_dropout_rate = 0.1
     else:
       input_dropout_rate = config.input_dropout_rate
-    outputs = Dropout(
-        rate=input_dropout_rate, deterministic=not train)(
-            outputs)
+    outputs = Dropout(rate=input_dropout_rate, deterministic=not train)(outputs)
 
     return outputs, output_paddings
 
