From 1f666efbbe2a0b0135e5a64d087350f3be2fda50 Mon Sep 17 00:00:00 2001
From: Alexandre Galashov
Date: Thu, 21 May 2026 02:32:19 -0700
Subject: [PATCH] Implement attention dropout

PiperOrigin-RevId: 918928569
---
 .../lib/architecture/attention.py             |  14 +++
 .../lib/architecture/attention_test.py        | 104 +++++++++++++++++++
 2 files changed, 118 insertions(+)

diff --git a/hackable_diffusion/lib/architecture/attention.py b/hackable_diffusion/lib/architecture/attention.py
index 9171f0f..58ccdbc 100644
--- a/hackable_diffusion/lib/architecture/attention.py
+++ b/hackable_diffusion/lib/architecture/attention.py
@@ -127,6 +127,8 @@ def _dot_product_attention(
     rescale: Float["..."],
     *,
     mask: Bool["batch sequence_key"] | None = None,
+    dropout_rate: float = 0.0,
+    is_training: bool = True,
 ) -> Float["batch sequence_query head*dim"]:
   """Performs dot product attention.
 
@@ -137,6 +139,8 @@ def _dot_product_attention(
     rescale: Rescale factor for the attention scores.
     mask: Mask tensor. Mask is True for tokens we want to keep and False for
       tokens we want to mask. If None, no masking is performed.
+    dropout_rate: The dropout rate for the attention weights.
+    is_training: Whether the model is in training mode.
 
   Returns:
     The output tensor.
@@ -156,6 +160,11 @@ def _dot_product_attention(
   # Softmax and attention weights
   attn_weights = _stable_softmax(logits=attn_logits)
 
+  if dropout_rate > 0.0:
+    attn_weights = nn.Dropout(rate=dropout_rate)(
+        attn_weights, deterministic=not is_training
+    )
+
   # Calculate attention output
   attn_output = jnp.einsum("bhts,bhsd->bhtd", attn_weights, v)
 
@@ -194,6 +203,7 @@ class MultiHeadAttention(nn.Module):
       use_rope is True.
     zero_init_output: If True, the kernel of the final output projection layer
       is initialized to zeros.
+    dropout_rate: The dropout rate for the attention weights.
     dtype: The data type of the computation.
   """
 
@@ -203,6 +213,7 @@ class MultiHeadAttention(nn.Module):
   use_rope: bool = False
   rope_position_type: RoPEPositionType = RoPEPositionType.SQUARE
   zero_init_output: bool = False
+  dropout_rate: float = 0.0
   dtype: DType = jnp.float32
 
   def setup(self):
@@ -226,6 +237,7 @@ def __call__(
       c: Float["batch sequence2 dim2"] | None,
       *,
       mask: Bool["batch sequence1|sequence2"] | None = None,
+      is_training: bool = True,
   ) -> Float["batch sequence1 dim1"]:
     """Computes multi-head attention.
 
@@ -319,6 +331,8 @@ def __call__(
         v=v,
         rescale=scale,
         mask=mask,
+        dropout_rate=self.dropout_rate,
+        is_training=is_training,
     )
 
     attn_output = nn.Dense(
diff --git a/hackable_diffusion/lib/architecture/attention_test.py b/hackable_diffusion/lib/architecture/attention_test.py
index a1c3298..7a31767 100644
--- a/hackable_diffusion/lib/architecture/attention_test.py
+++ b/hackable_diffusion/lib/architecture/attention_test.py
@@ -408,6 +408,110 @@ def test_multi_head_attention_invalid_mask_shape_raises_error(
     ):
       module.init(self.rng, self.x, c, mask=invalid_mask)
 
+  # MARK: Dropout Tests
+
+  def test_multi_head_attention_dropout_disabled_during_evaluation(self):
+    """Verifies dropout is inactive when is_training=False (evaluation mode)."""
+    # Initialize with an aggressive dropout rate (e.g., 0.5)
+    module = attention.MultiHeadAttention(
+        num_heads=self.num_heads,
+        dropout_rate=0.5,
+    )
+
+    # Generate random inputs to capture exact matrix values
+    rng1, rng2 = jax.random.split(self.rng)
+    x_rand = jax.random.normal(
+        rng1, (self.batch_size, self.seq_len_q, self.dim)
+    )
+
+    variables = module.init(rng2, x_rand, c=None, is_training=False)
+
+    # Run twice with evaluation mode (is_training=False).
+    # Even with a 50% dropout rate, the outputs should be completely identical.
+    output_eval_1 = module.apply(variables, x_rand, c=None, is_training=False)
+    output_eval_2 = module.apply(variables, x_rand, c=None, is_training=False)
+
+    np.testing.assert_allclose(
+        output_eval_1,
+        output_eval_2,
+        atol=1e-6,
+    )
+
+  def test_multi_head_attention_dropout_active_during_training(self):
+    """Verifies dropout alters outputs randomly when is_training=True."""
+    module = attention.MultiHeadAttention(
+        num_heads=self.num_heads,
+        dropout_rate=0.5,
+    )
+
+    rng1, rng2, rng_dropout1, rng_dropout2 = jax.random.split(self.rng, 4)
+    x_rand = jax.random.normal(
+        rng1, (self.batch_size, self.seq_len_q, self.dim)
+    )
+
+    variables = module.init(rng2, x_rand, c=None, is_training=False)
+
+    # Flax requires a 'dropout' RNG stream state passed inside a dict
+    # whenever execution hits an active nn.Dropout layer during training.
+    output_train_1 = module.apply(
+        variables,
+        x_rand,
+        c=None,
+        is_training=True,
+        rngs={"dropout": rng_dropout1},
+    )
+    output_train_2 = module.apply(
+        variables,
+        x_rand,
+        c=None,
+        is_training=True,
+        rngs={"dropout": rng_dropout2},
+    )
+
+    # Since two distinct keys were injected into the dropout stream,
+    # different masks were dropped, meaning outputs must differ.
+    self.assertFalse(jnp.allclose(output_train_1, output_train_2, atol=1e-5))
+
+  def test_multi_head_attention_dropout_scales_retained_activations(self):
+    """Verifies inverted dropout preserves the expected attention output."""
+    rate = 0.5
+    num_samples = 64
+    module = attention.MultiHeadAttention(
+        num_heads=self.num_heads,
+        dropout_rate=rate,
+    )
+
+    rng1, rng2, rng_dropout = jax.random.split(self.rng, 3)
+    x_rand = jax.random.normal(
+        rng1, (self.batch_size, self.seq_len_q, self.dim)
+    )
+
+    variables = module.init(rng2, x_rand, c=None, is_training=False)
+
+    output_eval = module.apply(variables, x_rand, c=None, is_training=False)
+    train_outputs = [
+        module.apply(
+            variables,
+            x_rand,
+            c=None,
+            is_training=True,
+            rngs={"dropout": key},
+        )
+        for key in jax.random.split(rng_dropout, num_samples)
+    ]
+    mean_train = jnp.mean(jnp.stack(train_outputs), axis=0)
+
+    # Inverted dropout rescales the kept attention weights by 1 / (1 - rate),
+    # so the training output averaged over many masks approaches the eval
+    # output. Without rescaling it would approach (1 - rate) * output_eval,
+    # assuming the output projection bias is zero at initialization.
+    err_scaled = float(jnp.linalg.norm(mean_train - output_eval))
+    err_unscaled = float(
+        jnp.linalg.norm(mean_train - (1.0 - rate) * output_eval)
+    )
+
+    self.assertLess(err_scaled, err_unscaled)
+
 
 if __name__ == "__main__":
   absltest.main()