@@ -39,9 +39,12 @@ class MlpBlock(nn.Module):
   dropout_rate: float = 0.0
 
   @nn.compact
-  def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=None) -> spec.Tensor:
+  def __call__(self,
+               x: spec.Tensor,
+               train: bool = True,
+               dropout_rate=None) -> spec.Tensor:
     """Applies Transformer MlpBlock module."""
-    if not dropout_rate:
+    if dropout_rate is None:
       dropout_rate = self.dropout_rate
 
     inits = {
@@ -57,7 +60,7 @@ def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=None) -> spe
       y = nn.Dense(self.mlp_dim, **inits)(x)
       x = x * y
 
-    x = Dropout()(x, train, rate=dropout_rate)
+    x = Dropout(dropout_rate)(x, train, rate=dropout_rate)
     x = nn.Dense(d, **inits)(x)
     return x
 
@@ -71,9 +74,12 @@ class Encoder1DBlock(nn.Module):
   dropout_rate: float = 0.0
 
   @nn.compact
-  def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=dropout_rate) -> spec.Tensor:
-    if not dropout_rate:
-      dropout_rate = self.dropout_rate
+  def __call__(self,
+               x: spec.Tensor,
+               train: bool = True,
+               dropout_rate=None) -> spec.Tensor:
+    if dropout_rate is None:
+      dropout_rate = self.dropout_rate
 
     if not self.use_post_layer_norm:
       y = nn.LayerNorm(name='LayerNorm_0')(x)
@@ -83,15 +89,14 @@ def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=dropout_rate
           deterministic=train,
           name='MultiHeadDotProductAttention_1')(
               y)
-      y = Dropout()(y, train, dropout_rate=dropout_rate)
+      y = Dropout(dropout_rate)(y, train, dropout_rate=dropout_rate)
       x = x + y
 
       y = nn.LayerNorm(name='LayerNorm_2')(x)
       y = MlpBlock(
-          mlp_dim=self.mlp_dim,
-          use_glu=self.use_glu,
-          name='MlpBlock_3')(y, train, dropout_rate=dropout_rate)
-      y = Dropout()(y, train, rate=dropout_rate)
+          mlp_dim=self.mlp_dim, use_glu=self.use_glu, name='MlpBlock_3')(
+              y, train, dropout_rate=dropout_rate)
+      y = Dropout(dropout_rate)(y, train, rate=dropout_rate)
       x = x + y
     else:
       y = x
@@ -101,16 +106,18 @@ def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=dropout_rate
           deterministic=train,
           name='MultiHeadDotProductAttention_1')(
               y)
-      y = Dropout()(y, train, rate=dropout_rate)
+      y = Dropout(dropout_rate)(y, train, rate=dropout_rate)
       x = x + y
       x = nn.LayerNorm(name='LayerNorm_0')(x)
 
       y = x
       y = MlpBlock(
           mlp_dim=self.mlp_dim,
           use_glu=self.use_glu,
-          name='MlpBlock_3')(y, train, dropout_rate=dropout_rate)
-      y = Dropout()(y, train)(rate=dropout_rate)
+          name='MlpBlock_3',
+          dropout_rate=dropout_rate)(
+              y, train, dropout_rate=dropout_rate)
+      y = Dropout(dropout_rate)(y, train, rate=dropout_rate)
       x = x + y
       x = nn.LayerNorm(name='LayerNorm_2')(x)
 
@@ -127,9 +134,12 @@ class Encoder(nn.Module):
   use_post_layer_norm: bool = False
 
   @nn.compact
-  def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=None) -> spec.Tensor:
-    if not dropout_rate:
-      dropout_rate = self.dropout_rate
+  def __call__(self,
+               x: spec.Tensor,
+               train: bool = True,
+               dropout_rate=None) -> spec.Tensor:
+    if dropout_rate is None:
+      dropout_rate = self.dropout_rate
 
     # Input Encoder
     for lyr in range(self.depth):
@@ -139,7 +149,8 @@ def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=None) -> spe
           num_heads=self.num_heads,
           use_glu=self.use_glu,
           use_post_layer_norm=self.use_post_layer_norm,
-      )(dropout_rate=dropout_rate)
+          dropout_rate=dropout_rate)(
+              dropout_rate=dropout_rate)
       x = block(x, train)
     if not self.use_post_layer_norm:
       return nn.LayerNorm(name='encoder_layernorm')(x)
@@ -151,9 +162,12 @@ class MAPHead(nn.Module):
   """Multihead Attention Pooling."""
   mlp_dim: Optional[int] = None  # Defaults to 4x input dim
   num_heads: int = 12
+  dropout_rate: float = 0.0
 
   @nn.compact
-  def __call__(self, x):
+  def __call__(self, x, dropout_rate=None):
+    if dropout_rate is None:
+      dropout_rate = self.dropout_rate
     n, _, d = x.shape
     probe = self.param('probe',
                        nn.initializers.xavier_uniform(), (1, 1, d),
@@ -166,7 +180,7 @@ def __call__(self, x):
         kernel_init=nn.initializers.xavier_uniform())(probe, x)
 
     y = nn.LayerNorm()(x)
-    x = x + MlpBlock(mlp_dim=self.mlp_dim)(y)
+    x = x + MlpBlock(mlp_dim=self.mlp_dim, dropout_rate=dropout_rate)(y)
     return x[:, 0]
 
 
@@ -180,7 +194,7 @@ class ViT(nn.Module):
   mlp_dim: Optional[int] = None  # Defaults to 4x input dim.
   num_heads: int = 12
   rep_size: Union[int, bool] = True
-  dropout_rate: Optional[float] = 0.0  # If None, defaults to 0.0.
+  dropout_rate: Optional[float] = 0.0
   reinit: Optional[Sequence[str]] = None
   head_zeroinit: bool = True
   use_glu: bool = False
@@ -194,8 +208,12 @@ def get_posemb(self,
     return posemb_sincos_2d(*seqshape, width, dtype=dtype)
 
   @nn.compact
-  def __call__(self, x: spec.Tensor, *, train: bool = False, dropout_rate=None) -> spec.Tensor:
-    if not dropout_rate:
+  def __call__(self,
+               x: spec.Tensor,
+               *,
+               train: bool = False,
+               dropout_rate=None) -> spec.Tensor:
+    if dropout_rate is None:
       dropout_rate = self.dropout_rate
     # Patch extraction
     x = nn.Conv(
@@ -212,19 +230,24 @@ def __call__(self, x: spec.Tensor, *, train: bool = False, dropout_rate=None) ->
     # Add posemb before adding extra token.
     x = x + self.get_posemb((h, w), c, x.dtype)
 
-    x = Dropout()(x, not train, rate=dropout_rate)
+    x = Dropout(dropout_rate)(x, not train, rate=dropout_rate)
 
     x = Encoder(
         depth=self.depth,
         mlp_dim=self.mlp_dim,
        num_heads=self.num_heads,
         use_glu=self.use_glu,
         use_post_layer_norm=self.use_post_layer_norm,
-        name='Transformer')(
+        name='Transformer',
+        dropout_rate=dropout_rate)(
            x, train=not train, dropout_rate=dropout_rate)
 
     if self.use_map:
-      x = MAPHead(num_heads=self.num_heads, mlp_dim=self.mlp_dim)(x)
+      x = MAPHead(
+          num_heads=self.num_heads,
+          mlp_dim=self.mlp_dim,
+          dropout_rate=dropout_rate)(
+              x, dropout_rate=dropout_rate)
     else:
       x = jnp.mean(x, axis=1)
 
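For reference, below is a minimal, self-contained sketch of the call-time dropout override pattern this diff applies throughout the ViT modules: each module keeps a dropout_rate field as its default, and __call__ accepts an optional override that falls back to the field only when the caller passes None. It is not taken from the repository; TinyBlock, its sizes, and the use of flax.linen.Dropout in place of the repo's custom Dropout wrapper are assumptions for illustration.

# Hypothetical sketch of the dropout_rate threading pattern (not repo code).
from typing import Optional

import flax.linen as nn
import jax
import jax.numpy as jnp


class TinyBlock(nn.Module):
  features: int = 16
  dropout_rate: float = 0.0  # class-level default, overridable per call

  @nn.compact
  def __call__(self,
               x,
               train: bool = True,
               dropout_rate: Optional[float] = None):
    if dropout_rate is None:  # fall back to the module's stored default
      dropout_rate = self.dropout_rate
    x = nn.Dense(self.features)(x)
    # Dropout is active only when train=True.
    x = nn.Dropout(rate=dropout_rate, deterministic=not train)(x)
    return x


rng = jax.random.PRNGKey(0)
model = TinyBlock(dropout_rate=0.1)
x = jnp.ones((2, 8))
variables = model.init({'params': rng, 'dropout': rng}, x, train=True)
# Override the stored rate at call time, mirroring how the diff threads
# dropout_rate through MlpBlock, Encoder1DBlock, Encoder, MAPHead, and ViT.
y = model.apply(
    variables, x, train=True, dropout_rate=0.5, rngs={'dropout': rng})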