Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions hackable_diffusion/lib/architecture/dit_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ class DiTBlock(nn.Module):
to avoid bias in the FFN.
ffn_activation: Activation function for the FFN.
attn_normalize_qk: Whether to normalize query and key in attention.
attn_qk_norm_method: Normalization method for query and key in attention.
attn_use_bias: Whether to use bias in the attention QKV and output
projections.
mlp_ratio: The ratio of the MLP hidden dimension to the hidden size.
Expand All @@ -149,10 +150,11 @@ class DiTBlock(nn.Module):
num_heads: int = INVALID_INT
head_dim: int = INVALID_INT
use_gates: bool = True
ffn_type: mlp_blocks.FFNType = 'swiglu'
ffn_use_bias: bool = False
ffn_type: mlp_blocks.FFNType = 'dense'
ffn_use_bias: bool = True
ffn_activation: str = 'gelu'
attn_normalize_qk: bool = True
attn_qk_norm_method: attention.AttnQKNormMethod = 'l2'
attn_use_bias: bool = True
mlp_ratio: float = 4.0
use_rope: bool = False
Expand Down Expand Up @@ -194,6 +196,7 @@ def setup(self):
zero_init_output=self.zero_init_output,
dtype=self.dtype,
normalize_qk=self.attn_normalize_qk,
qk_norm_method=self.attn_qk_norm_method,
use_bias=self.attn_use_bias,
dropout_rate=self.dropout_rate,
)
Expand Down Expand Up @@ -282,7 +285,11 @@ class DiTBlockFlux(DiTBlock):
use_gates: bool = dataclasses.field(init=False, default=False)
zero_init_output: bool = dataclasses.field(init=False, default=True)
attn_normalize_qk: bool = dataclasses.field(init=False, default=True)
attn_qk_norm_method: attention.AttnQKNormMethod = dataclasses.field(
init=False, default='rms_norm'
)
attn_use_bias: bool = dataclasses.field(init=False, default=False)
ffn_use_bias: bool = dataclasses.field(init=False, default=False)

def __post_init__(self):
self.norm_factory = normalization.NormalizationLayerFactory(
Expand All @@ -305,7 +312,11 @@ class DiTBlockSD3(DiTBlock):
use_gates: bool = dataclasses.field(init=False, default=True)
zero_init_output: bool = dataclasses.field(init=False, default=False)
attn_normalize_qk: bool = dataclasses.field(init=False, default=True)
attn_use_bias: bool = dataclasses.field(init=False, default=False)
attn_qk_norm_method: attention.AttnQKNormMethod = dataclasses.field(
init=False, default='rms_norm'
)
attn_use_bias: bool = dataclasses.field(init=False, default=True)
ffn_use_bias: bool = dataclasses.field(init=False, default=True)

def __post_init__(self):
self.norm_factory = normalization.NormalizationLayerFactory(
Expand All @@ -330,6 +341,7 @@ class DiTBlockAdaLNZero(DiTBlock):
zero_init_output: bool = dataclasses.field(init=False, default=False)
attn_normalize_qk: bool = dataclasses.field(init=False, default=False)
attn_use_bias: bool = dataclasses.field(init=False, default=True)
ffn_use_bias: bool = dataclasses.field(init=False, default=True)

def __post_init__(self):
self.norm_factory = normalization.NormalizationLayerFactory(
Expand Down
96 changes: 96 additions & 0 deletions hackable_diffusion/lib/architecture/dit_blocks_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,11 @@ def test_variable_shapes_ada_rms_norm(self):
'ffn': {
'Dense_Up': {
'kernel': (self.d, mlp_hidden * 2),
'bias': (mlp_hidden * 2,),
},
'Dense_Down': {
'kernel': (mlp_hidden, self.d),
'bias': (self.d,),
},
},
'attn': {
Expand Down Expand Up @@ -226,9 +228,11 @@ def test_variable_shapes_ada_ln_zero(self):
'ffn': {
'Dense_Up': {
'kernel': (self.d, mlp_hidden),
'bias': (mlp_hidden,),
},
'Dense_Down': {
'kernel': (mlp_hidden, self.d),
'bias': (self.d,),
},
},
'ConditionalNorm_MLP': {
Expand Down Expand Up @@ -376,6 +380,98 @@ def test_ada_ln_zero_has_gates(self):
gate_paths = [p for p in leaves_with_paths if 'Gate' in p]
self.assertNotEmpty(gate_paths)

# MARK: qk_norm_method tests

@parameterized.named_parameters(
    ('l2_qk_norm_method', 'l2'),
    ('rms_norm_qk_norm_method', 'rms_norm'),
)
def test_preset_qk_norm_method_output_shape(self, qk_norm_method):
  """Checks the output shape of DiTBlock for each QK normalization method.

  Args:
    qk_norm_method: Value forwarded to `attn_qk_norm_method` of the block.
  """
  expected_shape = (self.batch, self.n, self.d)
  inputs = jnp.ones(expected_shape)
  conditioning = jnp.ones((self.batch, self.c))
  # RMS norm without conditional shift and no gates; only the QK
  # normalization method varies between the parameterized cases.
  rms_factory = normalization.NormalizationLayerFactory(
      normalization_method=NormalizationType.RMS_NORM,
      use_conditional_shift=False,
  )
  block = dit_blocks.DiTBlock(
      hidden_size=self.d,
      num_heads=4,
      norm_factory=rms_factory,
      use_gates=False,
      attn_normalize_qk=True,
      attn_qk_norm_method=qk_norm_method,
  )
  params = block.init(self.key, inputs, conditioning, is_training=False)
  result = block.apply(params, inputs, conditioning, is_training=False)
  # The block must map its input to an output of the same shape.
  self.assertEqual(result.shape, expected_shape)

def test_flux_uses_rms_norm_qk(self):
  """Checks that DiTBlockFlux initializes RMSNorm Q/K params only.

  The rms_norm method is expected to create `RMSNorm_Q` / `RMSNorm_K`
  entries and no `norm_qk_scale` entry in the variable tree.
  """
  tokens = jnp.ones((self.batch, self.n, self.d))
  conditioning = jnp.ones((self.batch, self.c))
  block = dit_blocks.DiTBlockFlux(hidden_size=self.d, num_heads=4)
  params = block.init(self.key, tokens, conditioning, is_training=False)
  paths = test_helpers.get_leaves_with_paths(params)

  def _paths_containing(*markers):
    # Keep every path that mentions at least one of the markers.
    return [p for p in paths if any(marker in p for marker in markers)]

  self.assertNotEmpty(_paths_containing('RMSNorm_Q', 'RMSNorm_K'))
  self.assertEmpty(_paths_containing('norm_qk_scale'))

def test_sd3_uses_rms_norm_qk(self):
  """Checks that DiTBlockSD3 initializes RMSNorm Q/K params only.

  The rms_norm method is expected to create `RMSNorm_Q` / `RMSNorm_K`
  entries and no `norm_qk_scale` entry in the variable tree.
  """
  tokens = jnp.ones((self.batch, self.n, self.d))
  conditioning = jnp.ones((self.batch, self.c))
  block = dit_blocks.DiTBlockSD3(hidden_size=self.d, num_heads=4)
  params = block.init(self.key, tokens, conditioning, is_training=False)
  paths = test_helpers.get_leaves_with_paths(params)

  rms_hits = []
  for path in paths:
    if 'RMSNorm_Q' in path or 'RMSNorm_K' in path:
      rms_hits.append(path)
  self.assertNotEmpty(rms_hits)

  l2_hits = []
  for path in paths:
    if 'norm_qk_scale' in path:
      l2_hits.append(path)
  self.assertEmpty(l2_hits)

def test_ada_ln_zero_has_no_qk_norm(self):
  """Checks that DiTBlockAdaLNZero creates no QK normalization params."""
  tokens = jnp.ones((self.batch, self.n, self.d))
  conditioning = jnp.ones((self.batch, self.c))
  block = dit_blocks.DiTBlockAdaLNZero(hidden_size=self.d, num_heads=4)
  params = block.init(self.key, tokens, conditioning, is_training=False)
  paths = test_helpers.get_leaves_with_paths(params)
  # Any of these substrings in a path would indicate a QK norm parameter.
  qk_norm_markers = ('norm_qk', 'RMSNorm_Q', 'RMSNorm_K')
  offending = [
      path
      for path in paths
      if any(marker in path for marker in qk_norm_markers)
  ]
  self.assertEmpty(offending)

def test_dit_block_no_attn_bias_with_rms_norm_qk(self):
  """Checks DiTBlock with attention bias disabled and rms_norm QK norm.

  Expects no bias entries under `params/attn/` and at least one
  `RMSNorm_Q` / `RMSNorm_K` entry in the initialized variables.
  """
  tokens = jnp.ones((self.batch, self.n, self.d))
  conditioning = jnp.ones((self.batch, self.c))
  rms_factory = normalization.NormalizationLayerFactory(
      normalization_method=NormalizationType.RMS_NORM,
      use_conditional_shift=False,
  )
  block = dit_blocks.DiTBlock(
      hidden_size=self.d,
      num_heads=4,
      norm_factory=rms_factory,
      use_gates=False,
      attn_normalize_qk=True,
      attn_qk_norm_method='rms_norm',
      attn_use_bias=False,
  )
  params = block.init(self.key, tokens, conditioning, is_training=False)
  paths = test_helpers.get_leaves_with_paths(params)

  # Attention sub-tree must not contain any bias parameter.
  attn_prefix = 'params/attn/'
  biased_attn = [
      path
      for path in paths
      if path.startswith(attn_prefix) and 'bias' in path
  ]
  self.assertEmpty(biased_attn)

  # RMSNorm parameters for query and/or key must be present.
  rms_hits = [
      path for path in paths if 'RMSNorm_Q' in path or 'RMSNorm_K' in path
  ]
  self.assertNotEmpty(rms_hits)


class PositionalEmbeddingTest(parameterized.TestCase):

Expand Down
2 changes: 2 additions & 0 deletions hackable_diffusion/lib/architecture/dit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,11 @@ def test_variable_shapes_with_patchify(self):
'ffn': {
'Dense_Up': {
'kernel': (self.embedding_dim, mlp_hidden),
'bias': (mlp_hidden,),
},
'Dense_Down': {
'kernel': (mlp_hidden, self.embedding_dim),
'bias': (self.embedding_dim,),
},
},
'ConditionalNorm_MLP': {
Expand Down
Loading