Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions hackable_diffusion/lib/architecture/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ def _dot_product_attention(
rescale: Float["..."],
*,
mask: Bool["batch sequence_key"] | None = None,
dropout_rate: float = 0.0,
is_training: bool = True,
) -> Float["batch sequence_query head*dim"]:
"""Performs dot product attention.

Expand All @@ -137,6 +139,8 @@ def _dot_product_attention(
rescale: Rescale factor for the attention scores.
mask: Mask tensor. Mask is True for tokens we want to keep and False for
tokens we want to mask. If None, no masking is performed.
dropout_rate: The dropout rate for the attention weights.
is_training: Whether the model is in training mode.

Returns:
The output tensor.
Expand All @@ -156,6 +160,11 @@ def _dot_product_attention(
# Softmax and attention weights
attn_weights = _stable_softmax(logits=attn_logits)

if dropout_rate > 0.0:
attn_weights = nn.Dropout(rate=dropout_rate)(
attn_weights, deterministic=not is_training
)

# Calculate attention output
attn_output = jnp.einsum("bhts,bhsd->bhtd", attn_weights, v)

Expand Down Expand Up @@ -194,6 +203,7 @@ class MultiHeadAttention(nn.Module):
use_rope is True.
zero_init_output: If True, the kernel of the final output projection layer
is initialized to zeros.
dropout_rate: The dropout rate for the attention weights.
dtype: The data type of the computation.
"""

Expand All @@ -203,6 +213,7 @@ class MultiHeadAttention(nn.Module):
use_rope: bool = False
rope_position_type: RoPEPositionType = RoPEPositionType.SQUARE
zero_init_output: bool = False
dropout_rate: float = 0.0
dtype: DType = jnp.float32

def setup(self):
Expand All @@ -226,6 +237,7 @@ def __call__(
c: Float["batch sequence2 dim2"] | None,
*,
mask: Bool["batch sequence1|sequence2"] | None = None,
is_training: bool = True,
) -> Float["batch sequence1 dim1"]:
"""Computes multi-head attention.

Expand Down Expand Up @@ -319,6 +331,8 @@ def __call__(
v=v,
rescale=scale,
mask=mask,
dropout_rate=self.dropout_rate,
is_training=is_training,
)

attn_output = nn.Dense(
Expand Down
96 changes: 96 additions & 0 deletions hackable_diffusion/lib/architecture/attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,102 @@ def test_multi_head_attention_invalid_mask_shape_raises_error(
):
module.init(self.rng, self.x, c, mask=invalid_mask)

# MARK: Dropout Tests

def test_multi_head_attention_dropout_disabled_during_evaluation(self):
  """Verifies dropout is inactive when is_training=False (evaluation mode).

  With an aggressive dropout rate of 0.5, two properties are checked:
    * repeated evaluation-mode calls, made without any dropout RNG, agree;
    * the evaluation-mode output equals the output of a module built with
      dropout_rate=0.0 applied to the very same variables.
  """
  dropout_module = attention.MultiHeadAttention(
      num_heads=self.num_heads,
      dropout_rate=0.5,
  )
  # Reference module that never takes the dropout branch (rate is not > 0).
  reference_module = attention.MultiHeadAttention(
      num_heads=self.num_heads,
      dropout_rate=0.0,
  )

  input_rng, init_rng = jax.random.split(self.rng)
  x_rand = jax.random.normal(
      input_rng, (self.batch_size, self.seq_len_q, self.dim)
  )

  # Initialize in evaluation mode: only the params key is supplied to init,
  # so the dropout layer must not be asked for randomness here.
  variables = dropout_module.init(
      init_rng, x_rand, c=None, is_training=False
  )

  # No `rngs` are passed: evaluation mode must not require a dropout key.
  output_eval_1 = dropout_module.apply(
      variables, x_rand, c=None, is_training=False
  )
  output_eval_2 = dropout_module.apply(
      variables, x_rand, c=None, is_training=False
  )
  # nn.Dropout declares no parameters, so the same variables are expected to
  # be usable by the dropout-free reference module.
  output_reference = reference_module.apply(
      variables, x_rand, c=None, is_training=False
  )

  # Even with a 50% dropout rate, evaluation outputs must be identical...
  np.testing.assert_allclose(
      output_eval_1,
      output_eval_2,
      atol=1e-6,
  )
  # ...and must match what a model without dropout computes.
  np.testing.assert_allclose(
      output_eval_1,
      output_reference,
      atol=1e-6,
  )

def test_multi_head_attention_dropout_active_during_training(self):
  """Verifies dropout alters outputs randomly when is_training=True.

  Two training-mode calls with different `dropout` keys must disagree, while
  repeating a call with the same key must reproduce the same output.
  """
  module = attention.MultiHeadAttention(
      num_heads=self.num_heads,
      dropout_rate=0.5,
  )

  input_rng, init_rng, rng_dropout1, rng_dropout2 = jax.random.split(
      self.rng, 4
  )
  x_rand = jax.random.normal(
      input_rng, (self.batch_size, self.seq_len_q, self.dim)
  )

  # Initialize in evaluation mode: only the params key is supplied to init,
  # so the dropout layer must not be asked for randomness here.
  variables = module.init(init_rng, x_rand, c=None, is_training=False)

  def apply_in_training(dropout_rng):
    # Flax requires a 'dropout' RNG stream passed inside a dict whenever
    # execution hits an active nn.Dropout layer during training.
    return module.apply(
        variables,
        x_rand,
        c=None,
        is_training=True,
        rngs={"dropout": dropout_rng},
    )

  output_train_1 = apply_in_training(rng_dropout1)
  output_train_1_repeat = apply_in_training(rng_dropout1)
  output_train_2 = apply_in_training(rng_dropout2)

  # The same key must yield the same mask, so any difference observed below
  # is attributable to the dropout stream and not to other nondeterminism.
  np.testing.assert_allclose(
      output_train_1,
      output_train_1_repeat,
      atol=1e-6,
  )

  # Since two distinct keys were injected into the dropout stream,
  # different masks were dropped, meaning outputs must differ.
  self.assertFalse(jnp.allclose(output_train_1, output_train_2, atol=1e-5))

def test_multi_head_attention_dropout_scales_retained_activations(self):
  """Verifies dropout scales active entries by 1 / (1 - rate) during training.

  Inverted dropout rescales the kept attention weights so that their
  expectation equals the undropped weights. Comparing a single training
  output against the evaluation output (e.g. by maximum magnitude) does not
  test this: the result depends on which entries happen to be dropped.
  Instead, the training output is averaged over many independent dropout
  keys and must land much closer to the evaluation output than a single
  sample does. Without the rescaling the average would stay biased.

  NOTE(review): this assumes everything applied after the attention-weight
  dropout (value einsum, output projection) is affine in the weights, so the
  expectation carries through to the module output -- confirm against
  MultiHeadAttention.__call__.
  """
  rate = 0.5
  num_samples = 256
  module = attention.MultiHeadAttention(
      num_heads=self.num_heads,
      dropout_rate=rate,
  )

  input_rng, init_rng, rng_dropout = jax.random.split(self.rng, 3)
  x_rand = jax.random.normal(
      input_rng, (self.batch_size, self.seq_len_q, self.dim)
  )

  # Initialize in evaluation mode: only the params key is supplied to init,
  # so the dropout layer must not be asked for randomness here.
  variables = module.init(init_rng, x_rand, c=None, is_training=False)

  output_eval = module.apply(variables, x_rand, c=None, is_training=False)

  def apply_in_training(dropout_rng):
    return module.apply(
        variables,
        x_rand,
        c=None,
        is_training=True,
        rngs={"dropout": dropout_rng},
    )

  # One independent dropout mask per key; vmap stacks the outputs on axis 0.
  dropout_keys = jax.random.split(rng_dropout, num_samples)
  outputs_train = jax.vmap(apply_in_training)(dropout_keys)

  def rms(error):
    return float(jnp.sqrt(jnp.mean(jnp.square(error))))

  single_sample_error = rms(outputs_train[0] - output_eval)
  averaged_error = rms(jnp.mean(outputs_train, axis=0) - output_eval)

  # Dropout must actually perturb an individual training output.
  self.assertGreater(single_sample_error, 0.0)
  # For an unbiased estimate the error of the mean shrinks like
  # 1 / sqrt(num_samples) (1/16 here); 0.25 leaves a wide safety margin
  # while still rejecting a mean that is biased by missing rescaling.
  self.assertLess(averaged_error, 0.25 * single_sample_error)


# Run the absltest test runner when this file is executed as a script.
if __name__ == "__main__":
  absltest.main()
Loading