Skip to content

Commit ac60d57

Browse files
agalashov authored and Hackable Diffusion Authors committed
Remove SwiGLU but introduce Feedforward with SwiGLU
PiperOrigin-RevId: 918687250
1 parent d23ca46 commit ac60d57

6 files changed

Lines changed: 573 additions & 116 deletions

File tree

hackable_diffusion/lib/architecture/attention.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,8 @@ def _dot_product_attention(
127127
rescale: Float["..."],
128128
*,
129129
mask: Bool["batch sequence_key"] | None = None,
130+
dropout_rate: float = 0.0,
131+
is_training: bool = True,
130132
) -> Float["batch sequence_query head*dim"]:
131133
"""Performs dot product attention.
132134
@@ -137,6 +139,8 @@ def _dot_product_attention(
137139
rescale: Rescale factor for the attention scores.
138140
mask: Mask tensor. Mask is True for tokens we want to keep and False for
139141
tokens we want to mask. If None, no masking is performed.
142+
dropout_rate: The dropout rate for the attention weights.
143+
is_training: Whether the model is in training mode.
140144
141145
Returns:
142146
The output tensor.
@@ -156,6 +160,11 @@ def _dot_product_attention(
156160
# Softmax and attention weights
157161
attn_weights = _stable_softmax(logits=attn_logits)
158162

163+
if dropout_rate > 0.0:
164+
attn_weights = nn.Dropout(rate=dropout_rate)(
165+
attn_weights, deterministic=not is_training
166+
)
167+
159168
# Calculate attention output
160169
attn_output = jnp.einsum("bhts,bhsd->bhtd", attn_weights, v)
161170

@@ -194,6 +203,7 @@ class MultiHeadAttention(nn.Module):
194203
use_rope is True.
195204
zero_init_output: If True, the kernel of the final output projection layer
196205
is initialized to zeros.
206+
dropout_rate: The dropout rate for the attention weights.
197207
dtype: The data type of the computation.
198208
"""
199209

@@ -203,6 +213,7 @@ class MultiHeadAttention(nn.Module):
203213
use_rope: bool = False
204214
rope_position_type: RoPEPositionType = RoPEPositionType.SQUARE
205215
zero_init_output: bool = False
216+
dropout_rate: float = 0.0
206217
dtype: DType = jnp.float32
207218

208219
def setup(self):
@@ -226,6 +237,7 @@ def __call__(
226237
c: Float["batch sequence2 dim2"] | None,
227238
*,
228239
mask: Bool["batch sequence1|sequence2"] | None = None,
240+
is_training: bool = True,
229241
) -> Float["batch sequence1 dim1"]:
230242
"""Computes multi-head attention.
231243
@@ -319,6 +331,8 @@ def __call__(
319331
v=v,
320332
rescale=scale,
321333
mask=mask,
334+
dropout_rate=self.dropout_rate,
335+
is_training=is_training,
322336
)
323337

324338
attn_output = nn.Dense(

hackable_diffusion/lib/architecture/attention_test.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,102 @@ def test_multi_head_attention_invalid_mask_shape_raises_error(
408408
):
409409
module.init(self.rng, self.x, c, mask=invalid_mask)
410410

411+
# MARK: Dropout Tests
412+
413+
def test_multi_head_attention_dropout_disabled_during_evaluation(self):
414+
"""Verifies dropout is inactive when is_training=False (evaluation mode)."""
415+
# Initialize with an aggressive dropout rate (e.g., 0.5)
416+
module = attention.MultiHeadAttention(
417+
num_heads=self.num_heads,
418+
dropout_rate=0.5,
419+
)
420+
421+
# Generate random inputs to capture exact matrix values
422+
rng1, rng2 = jax.random.split(self.rng)
423+
x_rand = jax.random.normal(
424+
rng1, (self.batch_size, self.seq_len_q, self.dim)
425+
)
426+
427+
variables = module.init(rng2, x_rand, c=None)
428+
429+
# Run twice with evaluation mode (is_training=False).
430+
# Even with a 50% dropout rate, the outputs should be completely identical.
431+
output_eval_1 = module.apply(variables, x_rand, c=None, is_training=False)
432+
output_eval_2 = module.apply(variables, x_rand, c=None, is_training=False)
433+
434+
np.testing.assert_allclose(
435+
output_eval_1,
436+
output_eval_2,
437+
atol=1e-6,
438+
)
439+
440+
def test_multi_head_attention_dropout_active_during_training(self):
441+
"""Verifies dropout alters outputs randomly when is_training=True."""
442+
module = attention.MultiHeadAttention(
443+
num_heads=self.num_heads,
444+
dropout_rate=0.5,
445+
)
446+
447+
rng1, rng2, rng_dropout1, rng_dropout2 = jax.random.split(self.rng, 4)
448+
x_rand = jax.random.normal(
449+
rng1, (self.batch_size, self.seq_len_q, self.dim)
450+
)
451+
452+
variables = module.init(rng2, x_rand, c=None)
453+
454+
# Flax requires a 'dropout' RNG stream state passed inside a dict
455+
# whenever execution hits an active nn.Dropout layer during training.
456+
output_train_1 = module.apply(
457+
variables,
458+
x_rand,
459+
c=None,
460+
is_training=True,
461+
rngs={"dropout": rng_dropout1},
462+
)
463+
output_train_2 = module.apply(
464+
variables,
465+
x_rand,
466+
c=None,
467+
is_training=True,
468+
rngs={"dropout": rng_dropout2},
469+
)
470+
471+
# Two distinct keys were injected into the dropout stream, so two different
472+
# dropout masks were sampled and the outputs are expected to differ.
473+
self.assertFalse(jnp.allclose(output_train_1, output_train_2, atol=1e-5))
474+
475+
def test_multi_head_attention_dropout_scales_retained_activations(self):
476+
"""Verifies dropout scales active entries by 1 / (1 - rate) during training."""
477+
# Set a 50% rate. Active entries must double in value (multiplied by 2.0)
478+
rate = 0.5
479+
module = attention.MultiHeadAttention(
480+
num_heads=self.num_heads,
481+
dropout_rate=rate,
482+
)
483+
484+
rng1, rng2, rng_dropout = jax.random.split(self.rng, 3)
485+
x_rand = jax.random.normal(
486+
rng1, (self.batch_size, self.seq_len_q, self.dim)
487+
)
488+
489+
variables = module.init(rng2, x_rand, c=None)
490+
491+
output_eval = module.apply(variables, x_rand, c=None, is_training=False)
492+
output_train = module.apply(
493+
variables,
494+
x_rand,
495+
c=None,
496+
is_training=True,
497+
rngs={"dropout": rng_dropout},
498+
)
499+
500+
# Inverted dropout rescales the retained attention weights by 1 / (1 - rate),
501+
# so the peak output magnitude is expected (not guaranteed for every mask) to grow.
502+
max_train_val = float(jnp.max(jnp.abs(output_train)))
503+
max_eval_val = float(jnp.max(jnp.abs(output_eval)))
504+
505+
self.assertGreater(max_train_val, max_eval_val)
506+
411507

412508
if __name__ == "__main__":
413509
absltest.main()

hackable_diffusion/lib/architecture/mlp_blocks.py

Lines changed: 104 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -88,67 +88,132 @@ def __call__(
8888

8989

9090
################################################################################
91-
# MARK: SwiGLU
91+
# MARK: LinearSwiGLU
9292
################################################################################
9393

9494

95-
class SwiGLU(nn.Module):
96-
"""SwiGLU feed-forward network.
95+
class LinearSwiGLU(nn.Module):
96+
"""A Dense layer variant that outputs SwiGLU gating directly.
9797
9898
A gated feed-forward network using SiLU (Swish) activation for the gate,
9999
following "GLU Variants Improve Transformer" (Shazeer, 2020):
100100
https://arxiv.org/abs/2002.05202
101101
102-
The forward pass is:
103-
104-
gate_and_val = x @ W_up # (*, hidden_size) -> (*, ff_size * 2)
105-
val, gate = split(gate_and_val) # (*, ff_size) each
106-
x = val * SiLU(gate) # (*, ff_size)
107-
x = dropout(x)
108-
x = x @ W_down # (*, ff_size) -> (*, hidden_size)
109-
110-
Attributes:
111-
hidden_size: Output dimension (residual stream width).
112-
ff_size: Intermediate dimension (before gating).
113-
zero_init_output: If True, the down-projection kernel is initialized to
114-
zeros so the block starts as identity.
115-
dropout_rate: Dropout rate applied after gating.
116-
dtype: Data type for computation.
102+
Projects the input dimension to features * 2, chunks the result across the
103+
last dimension, and gates the activation channel with SiLU.
117104
"""
118105

119-
hidden_size: int
120-
ff_size: int
121-
zero_init_output: bool = False
122-
dropout_rate: float = 0.0
106+
features: int
107+
use_bias: bool = False
123108
dtype: DType = jnp.float32
124109

125110
@nn.compact
126111
@kt.typechecked
127-
def __call__(
128-
self, x: Float['batch *other_dims hidden_size'], *, is_training: bool
129-
) -> Float['batch *other_dims hidden_size']:
130-
# Up-projection: (*, hidden_size) -> (*, ff_size * 2).
112+
def __call__(self, x: Float["*batch d_in"]) -> Float["*batch features"]:
113+
# Project to double feature width
131114
gate_and_val = nn.Dense(
132-
features=self.ff_size * 2,
133-
use_bias=False,
115+
features=self.features * 2,
116+
use_bias=self.use_bias,
134117
dtype=self.dtype,
135-
name='Dense_Up',
118+
name="Dense_Gate_Val",
136119
)(x)
137-
# Split into value and gate, apply SiLU gating.
120+
121+
# Split and apply SiLU gating (mirrors torch.chunk(2, dim=-1))
138122
val, gate = jnp.split(gate_and_val, 2, axis=-1)
139-
x = val * nn.silu(gate)
140-
x = nn.Dropout(rate=self.dropout_rate, deterministic=not is_training)(x)
141-
# Down-projection: (*, ff_size) -> (*, hidden_size).
123+
return val * nn.silu(gate)
124+
125+
126+
################################################################################
127+
# MARK: FeedForward Unified Block
128+
################################################################################
129+
130+
131+
class FeedForward(nn.Module):
132+
"""A unified FeedForward block selecting between SwiGLU or traditional layers.
133+
134+
Attributes:
135+
output_size: Output dimension (residual stream width).
136+
hidden_size: Intermediate dimension (for 'swiglu', the width after gating).
137+
ffn_type: Feed-forward variant. 'swiglu' uses a gated SwiGLU up-projection;
138+
'standard' uses a dense up-projection followed by `activation`. Any other
139+
value raises a ValueError.
140+
activation: Name of the activation function to use when
141+
`ffn_type='standard'` (e.g., 'gelu', 'silu', 'relu'). This parameter is
142+
explicitly ignored when `ffn_type='swiglu'` because the SwiGLU path uses
143+
its own mathematical gating mechanism (SiLU/Swish).
144+
zero_init_output: If True, the down-projection kernel is initialized to
145+
zeros instead of LeCun-normal values.
146+
dropout_rate: Dropout rate applied between the up- and down-projections.
147+
dtype: Data type for computation.
148+
"""
149+
150+
output_size: int
151+
hidden_size: int
152+
ffn_type: str = "standard"
153+
activation: str = "gelu"
154+
zero_init_output: bool = False
155+
dropout_rate: float = 0.0
156+
dtype: DType = jnp.float32
157+
158+
def setup(self):
159+
if self.ffn_type not in ("standard", "swiglu"):
160+
raise ValueError(
161+
f"Unknown ffn_type: {self.ffn_type}. Must be 'standard' or 'swiglu'."
162+
)
163+
# Regularization Dropout Layer
164+
self.dropout = nn.Dropout(rate=self.dropout_rate)
165+
166+
# Down Projection Layer Config
142167
down_kernel_init = (
143168
nn.initializers.zeros_init()
144169
if self.zero_init_output
145170
else nn.initializers.lecun_normal()
146171
)
147-
x = nn.Dense(
148-
features=self.hidden_size,
149-
use_bias=False,
150-
dtype=self.dtype,
172+
# Standard SwiGLU down-projections generally omit biases
173+
self.use_down_bias = False if self.ffn_type == "swiglu" else True
174+
175+
self.down_proj = nn.Dense(
176+
features=self.output_size,
177+
use_bias=self.use_down_bias,
151178
kernel_init=down_kernel_init,
152-
name='Dense_Down',
153-
)(x)
179+
dtype=self.dtype,
180+
name="Dense_Down",
181+
)
182+
183+
@nn.compact
184+
@kt.typechecked
185+
def __call__(
186+
self, x: Float["batch *other_dims output_size"], *, is_training: bool
187+
) -> Float["batch *other_dims output_size"]:
188+
# Up-projection step
189+
if self.ffn_type == "swiglu":
190+
# Project to double feature width
191+
gate_and_val = nn.Dense(
192+
features=self.hidden_size * 2,
193+
use_bias=False,
194+
dtype=self.dtype,
195+
name="Dense_Up",
196+
)(x)
197+
# Split and apply SiLU gating
198+
val, gate = jnp.split(gate_and_val, 2, axis=-1)
199+
x = val * nn.silu(gate)
200+
elif self.ffn_type == "standard":
201+
x = nn.Dense(
202+
features=self.hidden_size,
203+
use_bias=True,
204+
dtype=self.dtype,
205+
name="Dense_Up",
206+
)(x)
207+
# Apply the configured activation function
208+
activation_fn = getattr(nn, self.activation)
209+
x = activation_fn(x)
210+
else:
211+
raise ValueError(f"Unknown ffn_type mapping strategy: {self.ffn_type!r}")
212+
213+
# Middle regularization step
214+
x = self.dropout(x, deterministic=not is_training)
215+
216+
# Final down-projection step
217+
x = self.down_proj(x)
218+
154219
return x

0 commit comments

Comments
 (0)