Skip to content

Commit 586ff73

Browse files
agalashovHackable Diffusion Authors
authored andcommitted
Implement attention dropout
PiperOrigin-RevId: 918682357
1 parent d23ca46 commit 586ff73

4 files changed

Lines changed: 261 additions & 12 deletions

File tree

hackable_diffusion/lib/architecture/attention.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,8 @@ def _dot_product_attention(
127127
rescale: Float["..."],
128128
*,
129129
mask: Bool["batch sequence_key"] | None = None,
130+
dropout_rate: float = 0.0,
131+
is_training: bool = True,
130132
) -> Float["batch sequence_query head*dim"]:
131133
"""Performs dot product attention.
132134
@@ -137,6 +139,8 @@ def _dot_product_attention(
137139
rescale: Rescale factor for the attention scores.
138140
mask: Mask tensor. Mask is True for tokens we want to keep and False for
139141
tokens we want to mask. If None, no masking is performed.
142+
dropout_rate: The dropout rate for the attention weights.
143+
is_training: Whether the model is in training mode.
140144
141145
Returns:
142146
The output tensor.
@@ -156,6 +160,11 @@ def _dot_product_attention(
156160
# Softmax and attention weights
157161
attn_weights = _stable_softmax(logits=attn_logits)
158162

163+
if dropout_rate > 0.0:
164+
attn_weights = nn.Dropout(rate=dropout_rate)(
165+
attn_weights, deterministic=not is_training
166+
)
167+
159168
# Calculate attention output
160169
attn_output = jnp.einsum("bhts,bhsd->bhtd", attn_weights, v)
161170

@@ -194,6 +203,7 @@ class MultiHeadAttention(nn.Module):
194203
use_rope is True.
195204
zero_init_output: If True, the kernel of the final output projection layer
196205
is initialized to zeros.
206+
dropout_rate: The dropout rate for the attention weights.
197207
dtype: The data type of the computation.
198208
"""
199209

@@ -203,6 +213,7 @@ class MultiHeadAttention(nn.Module):
203213
use_rope: bool = False
204214
rope_position_type: RoPEPositionType = RoPEPositionType.SQUARE
205215
zero_init_output: bool = False
216+
dropout_rate: float = 0.0
206217
dtype: DType = jnp.float32
207218

208219
def setup(self):
@@ -226,6 +237,7 @@ def __call__(
226237
c: Float["batch sequence2 dim2"] | None,
227238
*,
228239
mask: Bool["batch sequence1|sequence2"] | None = None,
240+
is_training: bool = True,
229241
) -> Float["batch sequence1 dim1"]:
230242
"""Computes multi-head attention.
231243
@@ -319,6 +331,8 @@ def __call__(
319331
v=v,
320332
rescale=scale,
321333
mask=mask,
334+
dropout_rate=self.dropout_rate,
335+
is_training=is_training,
322336
)
323337

324338
attn_output = nn.Dense(

hackable_diffusion/lib/architecture/attention_test.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,102 @@ def test_multi_head_attention_invalid_mask_shape_raises_error(
408408
):
409409
module.init(self.rng, self.x, c, mask=invalid_mask)
410410

411+
# MARK: Dropout Tests
412+
413+
def test_multi_head_attention_dropout_disabled_during_evaluation(self):
414+
"""Verifies dropout is inactive when is_training=False (evaluation mode)."""
415+
# Initialize with an aggressive dropout rate (e.g., 0.5)
416+
module = attention.MultiHeadAttention(
417+
num_heads=self.num_heads,
418+
dropout_rate=0.5,
419+
)
420+
421+
# Generate random inputs to capture exact matrix values
422+
rng1, rng2 = jax.random.split(self.rng)
423+
x_rand = jax.random.normal(
424+
rng1, (self.batch_size, self.seq_len_q, self.dim)
425+
)
426+
427+
variables = module.init(rng2, x_rand, c=None)
428+
429+
# Run twice with evaluation mode (is_training=False).
430+
# Even with a 50% dropout rate, the outputs should be completely identical.
431+
output_eval_1 = module.apply(variables, x_rand, c=None, is_training=False)
432+
output_eval_2 = module.apply(variables, x_rand, c=None, is_training=False)
433+
434+
np.testing.assert_allclose(
435+
output_eval_1,
436+
output_eval_2,
437+
atol=1e-6,
438+
)
439+
440+
def test_multi_head_attention_dropout_active_during_training(self):
441+
"""Verifies dropout alters outputs randomly when is_training=True."""
442+
module = attention.MultiHeadAttention(
443+
num_heads=self.num_heads,
444+
dropout_rate=0.5,
445+
)
446+
447+
rng1, rng2, rng_dropout1, rng_dropout2 = jax.random.split(self.rng, 4)
448+
x_rand = jax.random.normal(
449+
rng1, (self.batch_size, self.seq_len_q, self.dim)
450+
)
451+
452+
variables = module.init(rng2, x_rand, c=None)
453+
454+
# Flax requires a 'dropout' RNG stream state passed inside a dict
455+
# whenever execution hits an active nn.Dropout layer during training.
456+
output_train_1 = module.apply(
457+
variables,
458+
x_rand,
459+
c=None,
460+
is_training=True,
461+
rngs={"dropout": rng_dropout1},
462+
)
463+
output_train_2 = module.apply(
464+
variables,
465+
x_rand,
466+
c=None,
467+
is_training=True,
468+
rngs={"dropout": rng_dropout2},
469+
)
470+
471+
# Since two distinct keys were injected into the dropout stream,
472+
# different masks were dropped, meaning outputs must differ.
473+
self.assertFalse(jnp.allclose(output_train_1, output_train_2, atol=1e-5))
474+
475+
def test_multi_head_attention_dropout_scales_retained_activations(self):
476+
"""Verifies dropout scales active entries by 1 / (1 - rate) during training."""
477+
# Set a 50% rate. Active entries must double in value (multiplied by 2.0)
478+
rate = 0.5
479+
module = attention.MultiHeadAttention(
480+
num_heads=self.num_heads,
481+
dropout_rate=rate,
482+
)
483+
484+
rng1, rng2, rng_dropout = jax.random.split(self.rng, 3)
485+
x_rand = jax.random.normal(
486+
rng1, (self.batch_size, self.seq_len_q, self.dim)
487+
)
488+
489+
variables = module.init(rng2, x_rand, c=None)
490+
491+
output_eval = module.apply(variables, x_rand, c=None, is_training=False)
492+
output_train = module.apply(
493+
variables,
494+
x_rand,
495+
c=None,
496+
is_training=True,
497+
rngs={"dropout": rng_dropout},
498+
)
499+
500+
# Standard inverted dropout behavior means active values must be larger
501+
# than non-dropped values to preserve target expectation bounds.
502+
max_train_val = float(jnp.max(jnp.abs(output_train)))
503+
max_eval_val = float(jnp.max(jnp.abs(output_eval)))
504+
505+
self.assertGreater(max_train_val, max_eval_val)
506+
411507

412508
if __name__ == "__main__":
413509
absltest.main()

hackable_diffusion/lib/architecture/normalization.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ class NormalizationLayer(nn.Module):
7979
dtype: The data type of the computation.
8080
use_bias: Whether to use bias in the normalization layer.
8181
use_scale: Whether to use scale in the normalization layer.
82+
use_conditional_shift: Whether to use conditional shift in the normalization
83+
layer (only applies when `conditional` is True).
8284
"""
8385

8486
normalization_method: NormalizationType
@@ -88,6 +90,7 @@ class NormalizationLayer(nn.Module):
8890
dtype: DType = jnp.float32
8991
use_bias: bool = True
9092
use_scale: bool = True
93+
use_conditional_shift: bool = True
9194

9295
def setup(self):
9396
if (
@@ -169,18 +172,28 @@ def __call__(
169172
)
170173

171174
if self.conditional:
172-
173-
scale_and_shift = nn.Dense(
174-
ch * 2,
175-
kernel_init=nn.zeros_init(),
176-
bias_init=nn.zeros_init(),
177-
dtype=self.dtype,
178-
)(c)
179-
scale, shift = jnp.split(scale_and_shift, 2, axis=-1) # (B, ch) each.
180-
181175
x = einops.rearrange(x, "b ... c -> b c ...") # (B, ch, ...).
182-
scale = jax_helpers.bcast_right(scale, x.ndim)
183-
shift = jax_helpers.bcast_right(shift, x.ndim)
176+
# Scale + shift adaptive conditioning.
177+
if self.use_conditional_shift:
178+
scale_and_shift = nn.Dense(
179+
ch * 2,
180+
kernel_init=nn.zeros_init(),
181+
bias_init=nn.zeros_init(),
182+
dtype=self.dtype,
183+
)(c)
184+
scale, shift = jnp.split(scale_and_shift, 2, axis=-1) # (B, ch) each.
185+
scale = jax_helpers.bcast_right(scale, x.ndim)
186+
shift = jax_helpers.bcast_right(shift, x.ndim)
187+
else:
188+
# Scale-only adaptive conditioning (no shift).
189+
scale = nn.Dense(
190+
ch,
191+
kernel_init=nn.zeros_init(),
192+
bias_init=nn.zeros_init(),
193+
dtype=self.dtype,
194+
)(c)
195+
scale = jax_helpers.bcast_right(scale, x.ndim)
196+
shift = jnp.zeros_like(scale)
184197
x = (1.0 + scale) * x + shift
185198
x = einops.rearrange(x, "b c ... -> b ... c")
186199

@@ -211,6 +224,8 @@ class NormalizationLayerFactory:
211224
dtype: The data type of the computation.
212225
use_bias: Whether to use bias in the normalization layer.
213226
use_scale: Whether to use scale in the normalization layer.
227+
use_conditional_shift: Whether to use conditional shift in the normalization
228+
layer (only applies when `conditional` is True).
214229
"""
215230

216231
def __init__(
@@ -221,13 +236,15 @@ def __init__(
221236
dtype: DType = jnp.float32,
222237
use_bias: bool = True,
223238
use_scale: bool = True,
239+
use_conditional_shift: bool = True,
224240
):
225241
self.normalization_method = normalization_method
226242
self.epsilon = epsilon
227243
self.num_groups = num_groups
228244
self.dtype = dtype
229245
self.use_bias = use_bias
230246
self.use_scale = use_scale
247+
self.use_conditional_shift = use_conditional_shift
231248

232249
def unconditional_norm(
233250
self, norm_name: str = "UnconditionalNorm"
@@ -242,12 +259,13 @@ def unconditional_norm(
242259
dtype=self.dtype,
243260
use_bias=self.use_bias,
244261
use_scale=self.use_scale,
262+
use_conditional_shift=self.use_conditional_shift,
245263
)
246264

247265
def conditional_norm(
248266
self, norm_name: str = "ConditionalNorm"
249267
) -> NormalizationLayer:
250-
"""Returns a factory for creating conditional normalization layers."""
268+
"""Returns a conditional normalization layer."""
251269
return NormalizationLayer(
252270
normalization_method=self.normalization_method,
253271
conditional=True,
@@ -257,4 +275,5 @@ def conditional_norm(
257275
dtype=self.dtype,
258276
use_bias=self.use_bias,
259277
use_scale=self.use_scale,
278+
use_conditional_shift=self.use_conditional_shift,
260279
)

hackable_diffusion/lib/architecture/normalization_test.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,126 @@ def test_rmsnorm_mask_equivalence(self):
468468
),
469469
)
470470

471+
def test_conditional_rmsnorm_scale_only_at_init(self):
472+
"""Tests conditional RMSNorm with scale-only (no shift) at init."""
473+
norm_layer = normalization.NormalizationLayer(
474+
normalization_method=NormalizationType.RMS_NORM,
475+
conditional=True,
476+
use_conditional_shift=False,
477+
)
478+
params = norm_layer.init(self.rng, self.x, self.c)
479+
output = norm_layer.apply(params, self.x, self.c)
480+
self.assertEqual(output.shape, self.x_shape)
481+
482+
# At init, scale=0, so output should match plain RMSNorm.
483+
x2 = jnp.mean(self.x**2, -1, keepdims=True)
484+
output_ref = self.x * lax.rsqrt(x2 + norm_layer.epsilon)
485+
np.testing.assert_allclose(output, output_ref, rtol=1e-5, atol=1e-5)
486+
487+
def test_conditional_rmsnorm_scale_only_perturbed(self):
488+
"""Tests conditional RMSNorm scale-only with perturbed params."""
489+
norm_layer = normalization.NormalizationLayer(
490+
normalization_method=NormalizationType.RMS_NORM,
491+
conditional=True,
492+
use_conditional_shift=False,
493+
)
494+
params = norm_layer.init(self.rng, self.x, self.c)
495+
params_perturbed = _perturb_params(params=params, key=self.rng)
496+
output_perturbed = norm_layer.apply(params_perturbed, self.x, self.c)
497+
498+
# Compute unconditional RMSNorm for comparison.
499+
x2 = jnp.mean(self.x**2, -1, keepdims=True)
500+
output_ref = self.x * lax.rsqrt(x2 + norm_layer.epsilon)
501+
502+
self.assertEqual(output_perturbed.shape, self.x_shape)
503+
self.assertFalse(
504+
np.allclose(output_perturbed, output_ref, rtol=1e-5, atol=1e-5),
505+
msg=(
506+
"Scale-only conditional output should differ from unconditional"
507+
" output after perturbing params."
508+
),
509+
)
510+
511+
def test_conditional_scale_only_projects_to_ch(self):
512+
"""Tests that scale-only conditioning projects to ch (not ch*2)."""
513+
norm_layer = normalization.NormalizationLayer(
514+
normalization_method=NormalizationType.RMS_NORM,
515+
conditional=True,
516+
use_conditional_shift=False,
517+
)
518+
params = norm_layer.init(self.rng, self.x, self.c)
519+
# The Dense layer should project to ch (not ch * 2).
520+
dense_kernel = params["params"]["Dense_0"]["kernel"]
521+
expected_shape = (self.c_shape[-1], self.x_shape[-1]) # (cond_dim, ch)
522+
self.assertEqual(dense_kernel.shape, expected_shape)
523+
524+
def test_conditional_scale_shift_projects_to_ch_times_2(self):
525+
"""Tests that scale+shift conditioning projects to ch * 2."""
526+
norm_layer = normalization.NormalizationLayer(
527+
normalization_method=NormalizationType.RMS_NORM,
528+
conditional=True,
529+
use_conditional_shift=True,
530+
)
531+
params = norm_layer.init(self.rng, self.x, self.c)
532+
dense_kernel = params["params"]["Dense_0"]["kernel"]
533+
expected_shape = (self.c_shape[-1], self.x_shape[-1] * 2)
534+
self.assertEqual(dense_kernel.shape, expected_shape)
535+
536+
def test_conditional_rmsnorm_scale_only_padding_invariance(self):
537+
"""Tests scale-only conditional RMSNorm padding invariance."""
538+
norm_layer = normalization.NormalizationLayer(
539+
normalization_method=NormalizationType.RMS_NORM,
540+
conditional=True,
541+
use_conditional_shift=False,
542+
)
543+
c_small = jax.random.normal(self.rng, self.c_shape)
544+
params = norm_layer.init(self.rng, self.x_small, c_small)
545+
params_perturbed = _perturb_params(params=params, key=self.rng)
546+
547+
out_small = norm_layer.apply(params_perturbed, self.x_small, c_small)
548+
out_large = norm_layer.apply(params_perturbed, self.x_large, c_small)
549+
np.testing.assert_allclose(
550+
out_small[:, :, : self.unpadded_seq_len, :],
551+
out_large[:, :, : self.unpadded_seq_len, :],
552+
atol=1e-5,
553+
)
554+
555+
@parameterized.product(
556+
normalization_method=[
557+
NormalizationType.RMS_NORM,
558+
NormalizationType.LAYER_NORM,
559+
NormalizationType.GROUP_NORM,
560+
],
561+
conditional=[False, True],
562+
dtype=[jnp.float32, jnp.bfloat16],
563+
)
564+
def test_output_dtype(self, normalization_method, conditional, dtype):
565+
"""Tests that the output dtype matches the configured dtype."""
566+
num_groups = (
567+
self.num_groups
568+
if (normalization_method == NormalizationType.GROUP_NORM)
569+
else None
570+
)
571+
572+
norm_layer = normalization.NormalizationLayer(
573+
normalization_method=normalization_method,
574+
conditional=conditional,
575+
num_groups=num_groups,
576+
dtype=dtype,
577+
)
578+
579+
x = self.x.astype(dtype)
580+
if conditional:
581+
c = self.c.astype(dtype)
582+
params = norm_layer.init(self.rng, x, c)
583+
output = norm_layer.apply(params, x, c)
584+
else:
585+
params = norm_layer.init(self.rng, x)
586+
output = norm_layer.apply(params, x)
587+
588+
self.assertEqual(output.dtype, dtype)
589+
self.assertEqual(output.shape, self.x_shape)
590+
471591

472592
if __name__ == "__main__":
473593
absltest.main()

0 commit comments

Comments
 (0)