@@ -170,9 +170,11 @@ def test_variable_shapes_ada_rms_norm(self):
170170 'ffn' : {
171171 'Dense_Up' : {
172172 'kernel' : (self .d , mlp_hidden * 2 ),
173+ 'bias' : (mlp_hidden * 2 ,),
173174 },
174175 'Dense_Down' : {
175176 'kernel' : (mlp_hidden , self .d ),
177+ 'bias' : (self .d ,),
176178 },
177179 },
178180 'attn' : {
@@ -226,9 +228,11 @@ def test_variable_shapes_ada_ln_zero(self):
226228 'ffn' : {
227229 'Dense_Up' : {
228230 'kernel' : (self .d , mlp_hidden ),
231+ 'bias' : (mlp_hidden ,),
229232 },
230233 'Dense_Down' : {
231234 'kernel' : (mlp_hidden , self .d ),
235+ 'bias' : (self .d ,),
232236 },
233237 },
234238 'ConditionalNorm_MLP' : {
@@ -376,6 +380,98 @@ def test_ada_ln_zero_has_gates(self):
376380 gate_paths = [p for p in leaves_with_paths if 'Gate' in p ]
377381 self .assertNotEmpty (gate_paths )
378382
383+ # MARK: qk_norm_method tests
384+
@parameterized.named_parameters(
    ('l2', 'l2'),
    ('rms_norm', 'rms_norm'),
)
def test_preset_qk_norm_method_output_shape(self, qk_norm_method):
  """Checks that a QK-normalized DiTBlock preserves the input shape.

  Runs once per parameterized `qk_norm_method` value ('l2' and 'rms_norm').
  The block is built with an RMS-norm factory without conditional shift,
  without gates, and with QK normalization enabled.

  Args:
    qk_norm_method: Value forwarded to `attn_qk_norm_method`.
  """
  expected_shape = (self.batch, self.n, self.d)
  tokens = jnp.ones(expected_shape)
  conditioning = jnp.ones((self.batch, self.c))

  rms_factory = normalization.NormalizationLayerFactory(
      normalization_method=NormalizationType.RMS_NORM,
      use_conditional_shift=False,
  )
  block = dit_blocks.DiTBlock(
      hidden_size=self.d,
      num_heads=4,
      norm_factory=rms_factory,
      use_gates=False,
      attn_normalize_qk=True,
      attn_qk_norm_method=qk_norm_method,
  )

  params = block.init(self.key, tokens, conditioning, is_training=False)
  result = block.apply(params, tokens, conditioning, is_training=False)

  # The output must have exactly the shape of the token input.
  self.assertEqual(result.shape, expected_shape)
407+
def test_flux_uses_rms_norm_qk(self):
  """Checks which QK-norm parameters a default DiTBlockFlux initializes.

  Expects at least one parameter path mentioning `RMSNorm_Q` or `RMSNorm_K`,
  and no parameter path mentioning `norm_qk_scale`.
  """
  tokens = jnp.ones((self.batch, self.n, self.d))
  conditioning = jnp.ones((self.batch, self.c))
  block = dit_blocks.DiTBlockFlux(hidden_size=self.d, num_heads=4)
  params = block.init(self.key, tokens, conditioning, is_training=False)
  param_paths = test_helpers.get_leaves_with_paths(params)

  # RMSNorm-based QK normalization must contribute parameters.
  rms_markers = ('RMSNorm_Q', 'RMSNorm_K')
  with_rms_norm = [
      path for path in param_paths if any(m in path for m in rms_markers)
  ]
  self.assertNotEmpty(with_rms_norm)

  # The `norm_qk_scale` parameter must be absent.
  with_qk_scale = [path for path in param_paths if 'norm_qk_scale' in path]
  self.assertEmpty(with_qk_scale)
420+
def test_sd3_uses_rms_norm_qk(self):
  """Checks which QK-norm parameters a default DiTBlockSD3 initializes.

  Expects at least one parameter path mentioning `RMSNorm_Q` or `RMSNorm_K`,
  and no parameter path mentioning `norm_qk_scale`.
  """
  tokens = jnp.ones((self.batch, self.n, self.d))
  conditioning = jnp.ones((self.batch, self.c))
  block = dit_blocks.DiTBlockSD3(hidden_size=self.d, num_heads=4)
  params = block.init(self.key, tokens, conditioning, is_training=False)
  param_paths = test_helpers.get_leaves_with_paths(params)

  # RMSNorm-based QK normalization must contribute parameters.
  rms_markers = ('RMSNorm_Q', 'RMSNorm_K')
  with_rms_norm = [
      path for path in param_paths if any(m in path for m in rms_markers)
  ]
  self.assertNotEmpty(with_rms_norm)

  # The `norm_qk_scale` parameter must be absent.
  with_qk_scale = [path for path in param_paths if 'norm_qk_scale' in path]
  self.assertEmpty(with_qk_scale)
433+
def test_ada_ln_zero_has_no_qk_norm(self):
  """Checks that a default DiTBlockAdaLNZero has no QK-norm parameters.

  No initialized parameter path may mention `norm_qk`, `RMSNorm_Q` or
  `RMSNorm_K`.
  """
  tokens = jnp.ones((self.batch, self.n, self.d))
  conditioning = jnp.ones((self.batch, self.c))
  block = dit_blocks.DiTBlockAdaLNZero(hidden_size=self.d, num_heads=4)
  params = block.init(self.key, tokens, conditioning, is_training=False)
  param_paths = test_helpers.get_leaves_with_paths(params)

  # Any of these substrings in a path indicates a QK-norm parameter.
  qk_norm_markers = ('norm_qk', 'RMSNorm_Q', 'RMSNorm_K')
  qk_norm_paths = [
      path
      for path in param_paths
      if any(marker in path for marker in qk_norm_markers)
  ]
  self.assertEmpty(qk_norm_paths)
447+
def test_dit_block_no_attn_bias_with_rms_norm_qk(self):
  """Checks a DiTBlock with bias-free attention and `rms_norm` QK norm.

  The block is built with `attn_use_bias=False`, QK normalization enabled
  and `attn_qk_norm_method='rms_norm'`. After initialization, no path under
  `params/attn/` may mention `bias`, and at least one path must mention
  `RMSNorm_Q` or `RMSNorm_K`.
  """
  tokens = jnp.ones((self.batch, self.n, self.d))
  conditioning = jnp.ones((self.batch, self.c))

  rms_factory = normalization.NormalizationLayerFactory(
      normalization_method=NormalizationType.RMS_NORM,
      use_conditional_shift=False,
  )
  block = dit_blocks.DiTBlock(
      hidden_size=self.d,
      num_heads=4,
      norm_factory=rms_factory,
      use_gates=False,
      attn_normalize_qk=True,
      attn_qk_norm_method='rms_norm',
      attn_use_bias=False,
  )
  params = block.init(self.key, tokens, conditioning, is_training=False)
  param_paths = test_helpers.get_leaves_with_paths(params)

  # The attention sub-tree must not hold any bias parameter.
  attn_biases = [
      path
      for path in param_paths
      if path.startswith('params/attn/') and 'bias' in path
  ]
  self.assertEmpty(attn_biases)

  # RMSNorm-based QK normalization must still contribute parameters.
  rms_markers = ('RMSNorm_Q', 'RMSNorm_K')
  with_rms_norm = [
      path for path in param_paths if any(m in path for m in rms_markers)
  ]
  self.assertNotEmpty(with_rms_norm)
474+
379475
380476class PositionalEmbeddingTest (parameterized .TestCase ):
381477
0 commit comments