Skip to content

Commit 7d22496

Browse files
agalashov and Hackable Diffusion Authors
authored and committed
Correct the DiT blocks to match their corresponding implementations
PiperOrigin-RevId: 919043097
1 parent efcd19e commit 7d22496

4 files changed

Lines changed: 242 additions & 21 deletions

File tree

hackable_diffusion/lib/architecture/attention.py

Lines changed: 30 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414

1515
"""Attention layers and utils."""
1616

17-
from typing import Callable
17+
from typing import Callable, Literal
1818
import warnings
1919

2020
import flax.linen as nn
@@ -37,6 +37,8 @@
3737
RoPEPositionType = arch_typing.RoPEPositionType
3838
INVALID_INT = arch_typing.INVALID_INT
3939

40+
AttnQKNormMethod = Literal["l2", "rms_norm"]
41+
4042
################################################################################
4143
# MARK: Constants
4244
################################################################################
@@ -211,6 +213,7 @@ class MultiHeadAttention(nn.Module):
211213
num_heads: int = INVALID_INT
212214
head_dim: int = INVALID_INT
213215
normalize_qk: bool = False
216+
qk_norm_method: AttnQKNormMethod = "l2"
214217
use_rope: bool = False
215218
rope_position_type: RoPEPositionType = RoPEPositionType.SQUARE
216219
use_bias: bool = True
@@ -303,6 +306,32 @@ def __call__(
303306
v = v.reshape(b, seq_len_kv, num_heads, head_d).transpose(0, 2, 1, 3)
304307
# shape is [batch, num_heads, sequence_length, head_dim]
305308

309+
if self.normalize_qk:
310+
if self.qk_norm_method == "rms_norm":
311+
q = nn.RMSNorm(name="RMSNorm_Q")(q)
312+
k = nn.RMSNorm(name="RMSNorm_K")(k)
313+
scale = 1.0 / jnp.sqrt(jnp.float32(head_d))
314+
# QK L2 normalization: https://arxiv.org/abs/2010.04245
315+
elif self.qk_norm_method == "l2":
316+
scale = self.param(
317+
"norm_qk_scale",
318+
nn.initializers.constant(
319+
jnp.log2(seq_len_kv**2 - seq_len_kv + SAFETY_EPSILON)
320+
),
321+
(1, 1, 1, 1),
322+
)
323+
324+
norm_q = jnp.linalg.norm(q, ord=2, axis=-1, keepdims=True)
325+
norm_k = jnp.linalg.norm(k, ord=2, axis=-1, keepdims=True)
326+
q = q / (norm_q + SAFETY_EPSILON)
327+
k = k / (norm_k + SAFETY_EPSILON)
328+
else:
329+
raise ValueError(
330+
f"Unsupported QK normalization method: {self.qk_norm_method}."
331+
)
332+
else:
333+
scale = 1.0 / jnp.sqrt(jnp.float32(head_d))
334+
306335
# RoPE: https://arxiv.org/abs/2104.09864
307336
if self.use_rope:
308337
q = sequence_embedders.RoPESequenceEmbedding(
@@ -313,23 +342,6 @@ def __call__(
313342
)(k)
314343
# shape is [batch, num_heads, sequence_length, head_dim]
315344

316-
# QK normalization: https://arxiv.org/abs/2010.04245.
317-
if self.normalize_qk:
318-
scale = self.param(
319-
"norm_qk_scale",
320-
nn.initializers.constant(
321-
jnp.log2(seq_len_kv**2 - seq_len_kv + SAFETY_EPSILON)
322-
),
323-
(1, 1, 1, 1),
324-
)
325-
326-
norm_q = jnp.linalg.norm(q, ord=2, axis=-1, keepdims=True)
327-
norm_k = jnp.linalg.norm(k, ord=2, axis=-1, keepdims=True)
328-
q = q / (norm_q + SAFETY_EPSILON)
329-
k = k / (norm_k + SAFETY_EPSILON)
330-
else:
331-
scale = 1.0 / jnp.sqrt(head_d)
332-
333345
attn_output = _dot_product_attention(
334346
q=q,
335347
k=k,

hackable_diffusion/lib/architecture/attention_test.py

Lines changed: 101 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -555,6 +555,107 @@ def test_multi_head_attention_no_bias_param_shapes(self):
555555
}
556556
self.assertDictEqual(expected, variables_shapes)
557557

558+
# MARK: qk_norm_method tests
559+
560+
@parameterized.named_parameters(
561+
("l2", "l2"),
562+
("rms_norm", "rms_norm"),
563+
)
564+
def test_qk_norm_method_output_shape(self, qk_norm_method):
565+
"""Verifies output shape is correct for each qk_norm_method."""
566+
module = attention.MultiHeadAttention(
567+
num_heads=self.num_heads,
568+
normalize_qk=True,
569+
qk_norm_method=qk_norm_method,
570+
)
571+
variables = module.init(self.rng, self.x, c=None)
572+
output = module.apply(variables, self.x, c=None, is_training=False)
573+
self.assertEqual(output.shape, self.x.shape)
574+
575+
def test_qk_norm_l2_param_shapes(self):
576+
"""Verifies L2 QK normalization creates a norm_qk_scale parameter."""
577+
module = attention.MultiHeadAttention(
578+
num_heads=self.num_heads,
579+
normalize_qk=True,
580+
qk_norm_method="l2",
581+
)
582+
variables = module.init(self.rng, self.x, c=None)
583+
leaves = test_helpers.get_leaves_with_paths(variables)
584+
# L2 method should have a norm_qk_scale param
585+
self.assertIn("params/norm_qk_scale", leaves)
586+
self.assertEqual(leaves["params/norm_qk_scale"].shape, (1, 1, 1, 1))
587+
# Should NOT have RMSNorm_Q/K
588+
rms_paths = [p for p in leaves if "RMSNorm" in p]
589+
self.assertEmpty(rms_paths)
590+
591+
def test_qk_norm_rms_norm_param_shapes(self):
592+
"""Verifies RMSNorm QK normalization creates RMSNorm_Q/K scale params."""
593+
module = attention.MultiHeadAttention(
594+
num_heads=self.num_heads,
595+
normalize_qk=True,
596+
qk_norm_method="rms_norm",
597+
)
598+
variables = module.init(self.rng, self.x, c=None)
599+
leaves = test_helpers.get_leaves_with_paths(variables)
600+
# RMSNorm method should have RMSNorm_Q/scale and RMSNorm_K/scale
601+
self.assertIn("params/RMSNorm_Q/scale", leaves)
602+
self.assertIn("params/RMSNorm_K/scale", leaves)
603+
self.assertEqual(leaves["params/RMSNorm_Q/scale"].shape, (self.head_dim,))
604+
self.assertEqual(leaves["params/RMSNorm_K/scale"].shape, (self.head_dim,))
605+
# Should NOT have norm_qk_scale
606+
self.assertNotIn("params/norm_qk_scale", leaves)
607+
608+
def test_qk_norm_rms_norm_with_rope(self):
609+
"""Verifies RMSNorm QK norm works with RoPE (norm before RoPE)."""
610+
module = attention.MultiHeadAttention(
611+
num_heads=self.num_heads,
612+
normalize_qk=True,
613+
qk_norm_method="rms_norm",
614+
use_rope=True,
615+
rope_position_type=RoPEPositionType.SQUARE,
616+
)
617+
x = jnp.ones((self.batch_size, self.seq_len_kv, self.dim))
618+
variables = module.init(self.rng, x, c=None)
619+
output = module.apply(variables, x, c=None, is_training=False)
620+
self.assertEqual(output.shape, x.shape)
621+
622+
def test_qk_norm_l2_with_rope(self):
623+
"""Verifies L2 QK norm works with RoPE (norm before RoPE)."""
624+
module = attention.MultiHeadAttention(
625+
num_heads=self.num_heads,
626+
normalize_qk=True,
627+
qk_norm_method="l2",
628+
use_rope=True,
629+
rope_position_type=RoPEPositionType.SQUARE,
630+
)
631+
x = jnp.ones((self.batch_size, self.seq_len_kv, self.dim))
632+
variables = module.init(self.rng, x, c=None)
633+
output = module.apply(variables, x, c=None, is_training=False)
634+
self.assertEqual(output.shape, x.shape)
635+
636+
def test_qk_norm_disabled_has_no_norm_params(self):
637+
"""Verifies that normalize_qk=False creates no norm params."""
638+
module = attention.MultiHeadAttention(
639+
num_heads=self.num_heads,
640+
normalize_qk=False,
641+
)
642+
variables = module.init(self.rng, self.x, c=None)
643+
leaves = test_helpers.get_leaves_with_paths(variables)
644+
norm_paths = [p for p in leaves if "norm_qk" in p or "RMSNorm" in p]
645+
self.assertEmpty(norm_paths)
646+
647+
def test_qk_norm_invalid_method_raises_error(self):
648+
"""Verifies that an invalid qk_norm_method raises ValueError."""
649+
module = attention.MultiHeadAttention(
650+
num_heads=self.num_heads,
651+
normalize_qk=True,
652+
qk_norm_method="invalid_method", # pytype: disable=wrong-arg-types
653+
)
654+
with self.assertRaisesRegex(
655+
ValueError, "Unsupported QK normalization method"
656+
):
657+
module.init(self.rng, self.x, c=None)
658+
558659

559660
if __name__ == "__main__":
560661
absltest.main()

hackable_diffusion/lib/architecture/dit_blocks.py

Lines changed: 15 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -133,6 +133,7 @@ class DiTBlock(nn.Module):
133133
to avoid bias in the FFN.
134134
ffn_activation: Activation function for the FFN.
135135
attn_normalize_qk: Whether to normalize query and key in attention.
136+
attn_qk_norm_method: Normalization method for query and key in attention.
136137
attn_use_bias: Whether to use bias in the attention QKV and output
137138
projections.
138139
mlp_ratio: The ratio of the MLP hidden dimension to the hidden size.
@@ -149,10 +150,11 @@ class DiTBlock(nn.Module):
149150
num_heads: int = INVALID_INT
150151
head_dim: int = INVALID_INT
151152
use_gates: bool = True
152-
ffn_type: mlp_blocks.FFNType = 'swiglu'
153-
ffn_use_bias: bool = False
153+
ffn_type: mlp_blocks.FFNType = 'dense'
154+
ffn_use_bias: bool = True
154155
ffn_activation: str = 'gelu'
155156
attn_normalize_qk: bool = True
157+
attn_qk_norm_method: attention.AttnQKNormMethod = 'l2'
156158
attn_use_bias: bool = True
157159
mlp_ratio: float = 4.0
158160
use_rope: bool = False
@@ -194,6 +196,7 @@ def setup(self):
194196
zero_init_output=self.zero_init_output,
195197
dtype=self.dtype,
196198
normalize_qk=self.attn_normalize_qk,
199+
qk_norm_method=self.attn_qk_norm_method,
197200
use_bias=self.attn_use_bias,
198201
dropout_rate=self.dropout_rate,
199202
)
@@ -282,7 +285,11 @@ class DiTBlockFlux(DiTBlock):
282285
use_gates: bool = dataclasses.field(init=False, default=False)
283286
zero_init_output: bool = dataclasses.field(init=False, default=True)
284287
attn_normalize_qk: bool = dataclasses.field(init=False, default=True)
288+
attn_qk_norm_method: attention.AttnQKNormMethod = dataclasses.field(
289+
init=False, default='rms_norm'
290+
)
285291
attn_use_bias: bool = dataclasses.field(init=False, default=False)
292+
ffn_use_bias: bool = dataclasses.field(init=False, default=False)
286293

287294
def __post_init__(self):
288295
self.norm_factory = normalization.NormalizationLayerFactory(
@@ -305,7 +312,11 @@ class DiTBlockSD3(DiTBlock):
305312
use_gates: bool = dataclasses.field(init=False, default=True)
306313
zero_init_output: bool = dataclasses.field(init=False, default=False)
307314
attn_normalize_qk: bool = dataclasses.field(init=False, default=True)
308-
attn_use_bias: bool = dataclasses.field(init=False, default=False)
315+
attn_qk_norm_method: attention.AttnQKNormMethod = dataclasses.field(
316+
init=False, default='rms_norm'
317+
)
318+
attn_use_bias: bool = dataclasses.field(init=False, default=True)
319+
ffn_use_bias: bool = dataclasses.field(init=False, default=True)
309320

310321
def __post_init__(self):
311322
self.norm_factory = normalization.NormalizationLayerFactory(
@@ -330,6 +341,7 @@ class DiTBlockAdaLNZero(DiTBlock):
330341
zero_init_output: bool = dataclasses.field(init=False, default=False)
331342
attn_normalize_qk: bool = dataclasses.field(init=False, default=False)
332343
attn_use_bias: bool = dataclasses.field(init=False, default=True)
344+
ffn_use_bias: bool = dataclasses.field(init=False, default=True)
333345

334346
def __post_init__(self):
335347
self.norm_factory = normalization.NormalizationLayerFactory(

hackable_diffusion/lib/architecture/dit_blocks_test.py

Lines changed: 96 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -170,9 +170,11 @@ def test_variable_shapes_ada_rms_norm(self):
170170
'ffn': {
171171
'Dense_Up': {
172172
'kernel': (self.d, mlp_hidden * 2),
173+
'bias': (mlp_hidden * 2,),
173174
},
174175
'Dense_Down': {
175176
'kernel': (mlp_hidden, self.d),
177+
'bias': (self.d,),
176178
},
177179
},
178180
'attn': {
@@ -226,9 +228,11 @@ def test_variable_shapes_ada_ln_zero(self):
226228
'ffn': {
227229
'Dense_Up': {
228230
'kernel': (self.d, mlp_hidden),
231+
'bias': (mlp_hidden,),
229232
},
230233
'Dense_Down': {
231234
'kernel': (mlp_hidden, self.d),
235+
'bias': (self.d,),
232236
},
233237
},
234238
'ConditionalNorm_MLP': {
@@ -376,6 +380,98 @@ def test_ada_ln_zero_has_gates(self):
376380
gate_paths = [p for p in leaves_with_paths if 'Gate' in p]
377381
self.assertNotEmpty(gate_paths)
378382

383+
# MARK: qk_norm_method tests
384+
385+
@parameterized.named_parameters(
386+
('l2', 'l2'),
387+
('rms_norm', 'rms_norm'),
388+
)
389+
def test_preset_qk_norm_method_output_shape(self, qk_norm_method):
390+
"""Tests that DiTBlock with each qk_norm_method produces correct shape."""
391+
x = jnp.ones((self.batch, self.n, self.d))
392+
cond = jnp.ones((self.batch, self.c))
393+
module = dit_blocks.DiTBlock(
394+
hidden_size=self.d,
395+
num_heads=4,
396+
norm_factory=normalization.NormalizationLayerFactory(
397+
normalization_method=NormalizationType.RMS_NORM,
398+
use_conditional_shift=False,
399+
),
400+
use_gates=False,
401+
attn_normalize_qk=True,
402+
attn_qk_norm_method=qk_norm_method,
403+
)
404+
variables = module.init(self.key, x, cond, is_training=False)
405+
output = module.apply(variables, x, cond, is_training=False)
406+
self.assertEqual(output.shape, (self.batch, self.n, self.d))
407+
408+
def test_flux_uses_rms_norm_qk(self):
409+
"""Verifies DiTBlockFlux uses RMSNorm QK normalization."""
410+
x = jnp.ones((self.batch, self.n, self.d))
411+
cond = jnp.ones((self.batch, self.c))
412+
module = dit_blocks.DiTBlockFlux(hidden_size=self.d, num_heads=4)
413+
variables = module.init(self.key, x, cond, is_training=False)
414+
leaves = test_helpers.get_leaves_with_paths(variables)
415+
# Flux uses rms_norm method: should have RMSNorm_Q/K, no norm_qk_scale
416+
rms_paths = [p for p in leaves if 'RMSNorm_Q' in p or 'RMSNorm_K' in p]
417+
self.assertNotEmpty(rms_paths)
418+
l2_paths = [p for p in leaves if 'norm_qk_scale' in p]
419+
self.assertEmpty(l2_paths)
420+
421+
def test_sd3_uses_rms_norm_qk(self):
422+
"""Verifies DiTBlockSD3 uses RMSNorm QK normalization."""
423+
x = jnp.ones((self.batch, self.n, self.d))
424+
cond = jnp.ones((self.batch, self.c))
425+
module = dit_blocks.DiTBlockSD3(hidden_size=self.d, num_heads=4)
426+
variables = module.init(self.key, x, cond, is_training=False)
427+
leaves = test_helpers.get_leaves_with_paths(variables)
428+
# SD3 uses rms_norm method: should have RMSNorm_Q/K, no norm_qk_scale
429+
rms_paths = [p for p in leaves if 'RMSNorm_Q' in p or 'RMSNorm_K' in p]
430+
self.assertNotEmpty(rms_paths)
431+
l2_paths = [p for p in leaves if 'norm_qk_scale' in p]
432+
self.assertEmpty(l2_paths)
433+
434+
def test_ada_ln_zero_has_no_qk_norm(self):
435+
"""Verifies DiTBlockAdaLNZero has no QK normalization params."""
436+
x = jnp.ones((self.batch, self.n, self.d))
437+
cond = jnp.ones((self.batch, self.c))
438+
module = dit_blocks.DiTBlockAdaLNZero(hidden_size=self.d, num_heads=4)
439+
variables = module.init(self.key, x, cond, is_training=False)
440+
leaves = test_helpers.get_leaves_with_paths(variables)
441+
norm_paths = [
442+
p
443+
for p in leaves
444+
if 'norm_qk' in p or 'RMSNorm_Q' in p or 'RMSNorm_K' in p
445+
]
446+
self.assertEmpty(norm_paths)
447+
448+
def test_dit_block_no_attn_bias_with_rms_norm_qk(self):
449+
"""Verifies DiTBlock with use_bias=False and rms_norm QK norm."""
450+
x = jnp.ones((self.batch, self.n, self.d))
451+
cond = jnp.ones((self.batch, self.c))
452+
module = dit_blocks.DiTBlock(
453+
hidden_size=self.d,
454+
num_heads=4,
455+
norm_factory=normalization.NormalizationLayerFactory(
456+
normalization_method=NormalizationType.RMS_NORM,
457+
use_conditional_shift=False,
458+
),
459+
use_gates=False,
460+
attn_normalize_qk=True,
461+
attn_qk_norm_method='rms_norm',
462+
attn_use_bias=False,
463+
)
464+
variables = module.init(self.key, x, cond, is_training=False)
465+
leaves = test_helpers.get_leaves_with_paths(variables)
466+
# No bias in attention
467+
attn_bias_paths = [
468+
p for p in leaves if p.startswith('params/attn/') and 'bias' in p
469+
]
470+
self.assertEmpty(attn_bias_paths)
471+
# Has RMSNorm_Q/K
472+
rms_paths = [p for p in leaves if 'RMSNorm_Q' in p or 'RMSNorm_K' in p]
473+
self.assertNotEmpty(rms_paths)
474+
379475

380476
class PositionalEmbeddingTest(parameterized.TestCase):
381477

0 commit comments

Comments
 (0)