 from absl.testing import absltest
 from flax import nnx
 from jax.sharding import Mesh
-
+from flax.linen import partitioning as nn_partitioning
 from .. import pyconfig
 from ..max_utils import (create_device_mesh, get_flash_block_sizes)
 from ..models.wan.transformers.transformer_wan import (
@@ -53,6 +53,18 @@ class WanTransformerTest(unittest.TestCase):
 
   def setUp(self):
     WanTransformerTest.dummy_data = {}
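+    # Load the shared WAN 14B test config and build its device mesh once per
+    # test; the tests below enter this mesh together with the config's
+    # logical axis rules.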
+    pyconfig.initialize(
+        [
+            None,
+            os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
+        ],
+        unittest=True,
+    )
+    config = pyconfig.config
+    self.config = config
+    devices_array = create_device_mesh(config)
+    self.mesh = Mesh(devices_array, config.mesh_axes)
+
 
   def test_rotary_pos_embed(self):
     batch_size = 1
@@ -70,28 +82,31 @@ def test_nnx_pixart_alpha_text_projection(self):
     key = jax.random.key(0)
     rngs = nnx.Rngs(key)
     dummy_caption = jnp.ones((1, 512, 4096))
-    layer = NNXPixArtAlphaTextProjection(rngs=rngs, in_features=4096, hidden_size=5120)
-    dummy_output = layer(dummy_caption)
-    dummy_output.shape == (1, 512, 5120)
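+    # Initializing under the mesh and the config's logical axis rules lets
+    # logically annotated parameters resolve to concrete mesh axes.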
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      layer = NNXPixArtAlphaTextProjection(rngs=rngs, in_features=4096, hidden_size=5120)
+      dummy_output = layer(dummy_caption)
+      assert dummy_output.shape == (1, 512, 5120)
 
   def test_nnx_timestep_embedding(self):
     key = jax.random.key(0)
     rngs = nnx.Rngs(key)
 
     dummy_sample = jnp.ones((1, 256))
-    layer = NNXTimestepEmbedding(rngs=rngs, in_channels=256, time_embed_dim=5120)
-    dummy_output = layer(dummy_sample)
-    assert dummy_output.shape == (1, 5120)
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      layer = NNXTimestepEmbedding(rngs=rngs, in_channels=256, time_embed_dim=5120)
+      dummy_output = layer(dummy_sample)
+      assert dummy_output.shape == (1, 5120)
 
   def test_fp32_layer_norm(self):
     key = jax.random.key(0)
     rngs = nnx.Rngs(key)
     batch_size = 1
     dummy_hidden_states = jnp.ones((batch_size, 75600, 5120))
     # expect the same output shape and dtype as the input
-    layer = FP32LayerNorm(rngs=rngs, dim=5120, eps=1e-6, elementwise_affine=False)
-    dummy_output = layer(dummy_hidden_states)
-    assert dummy_output.shape == dummy_hidden_states.shape
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      layer = FP32LayerNorm(rngs=rngs, dim=5120, eps=1e-6, elementwise_affine=False)
+      dummy_output = layer(dummy_hidden_states)
+      assert dummy_output.shape == dummy_hidden_states.shape
 
   @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
   def test_wan_time_text_embedding(self):
@@ -102,20 +117,21 @@ def test_wan_time_text_embedding(self):
     time_freq_dim = 256
     time_proj_dim = 30720
     text_embed_dim = 4096
-    layer = WanTimeTextImageEmbedding(
-        rngs=rngs, dim=dim, time_freq_dim=time_freq_dim, time_proj_dim=time_proj_dim, text_embed_dim=text_embed_dim
-    )
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      layer = WanTimeTextImageEmbedding(
+          rngs=rngs, dim=dim, time_freq_dim=time_freq_dim, time_proj_dim=time_proj_dim, text_embed_dim=text_embed_dim
+      )
 
-    dummy_timestep = jnp.ones(batch_size)
+      dummy_timestep = jnp.ones(batch_size)
 
-    encoder_hidden_states_shape = (batch_size, time_freq_dim * 2, text_embed_dim)
-    dummy_encoder_hidden_states = jnp.ones(encoder_hidden_states_shape)
-    temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = layer(
-        dummy_timestep, dummy_encoder_hidden_states
-    )
-    assert temb.shape == (batch_size, dim)
-    assert timestep_proj.shape == (batch_size, time_proj_dim)
-    assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim)
+      encoder_hidden_states_shape = (batch_size, time_freq_dim * 2, text_embed_dim)
+      dummy_encoder_hidden_states = jnp.ones(encoder_hidden_states_shape)
+      temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = layer(
+          dummy_timestep, dummy_encoder_hidden_states
+      )
+      assert temb.shape == (batch_size, dim)
+      assert timestep_proj.shape == (batch_size, time_proj_dim)
+      assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim)
 
   def test_wan_block(self):
     key = jax.random.key(0)
@@ -163,20 +179,19 @@ def test_wan_block(self):
     dummy_encoder_hidden_states = jnp.ones((batch_size, 512, dim))
 
     dummy_temb = jnp.ones((batch_size, 6, dim))
-
-    wan_block = WanTransformerBlock(
-        rngs=rngs,
-        dim=dim,
-        ffn_dim=ffn_dim,
-        num_heads=num_heads,
-        qk_norm=qk_norm,
-        cross_attn_norm=cross_attn_norm,
-        eps=eps,
-        attention="flash",
-        mesh=mesh,
-        flash_block_sizes=flash_block_sizes,
-    )
-    with mesh:
+    with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
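+      # The block is now constructed, not just called, inside the sharding
+      # context so its parameters pick up the logical axis annotations at init.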
+      wan_block = WanTransformerBlock(
+          rngs=rngs,
+          dim=dim,
+          ffn_dim=ffn_dim,
+          num_heads=num_heads,
+          qk_norm=qk_norm,
+          cross_attn_norm=cross_attn_norm,
+          eps=eps,
+          attention="flash",
+          mesh=mesh,
+          flash_block_sizes=flash_block_sizes,
+      )
       dummy_output = wan_block(dummy_hidden_states, dummy_encoder_hidden_states, dummy_temb, dummy_rotary_emb)
       assert dummy_output.shape == dummy_hidden_states.shape
 
@@ -209,40 +224,39 @@ def test_wan_attention(self):
     mesh = Mesh(devices_array, config.mesh_axes)
     batch_size = 1
     query_dim = 5120
-    attention = FlaxWanAttention(
-        rngs=rngs,
-        query_dim=query_dim,
-        heads=40,
-        dim_head=128,
-        attention_kernel="flash",
-        mesh=mesh,
-        flash_block_sizes=flash_block_sizes,
-    )
-
-    dummy_hidden_states_shape = (batch_size, 75600, query_dim)
-
-    dummy_hidden_states = jnp.ones(dummy_hidden_states_shape)
-    dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape)
-    with mesh:
-      dummy_output = attention(
-          hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb
-      )
-      assert dummy_output.shape == dummy_hidden_states_shape
-
-    # dot product
-    try:
+    with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
       attention = FlaxWanAttention(
           rngs=rngs,
           query_dim=query_dim,
           heads=40,
           dim_head=128,
-          attention_kernel="dot_product",
-          split_head_dim=True,
+          attention_kernel="flash",
           mesh=mesh,
           flash_block_sizes=flash_block_sizes,
       )
-    except NotImplementedError:
-      pass
+      dummy_hidden_states_shape = (batch_size, 75600, query_dim)
+
+      dummy_hidden_states = jnp.ones(dummy_hidden_states_shape)
+      dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape)
+      dummy_output = attention(
+          hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb
+      )
+      assert dummy_output.shape == dummy_hidden_states_shape
+
+      # dot_product kernel: construction may raise NotImplementedError, which is acceptable here
+      try:
+        attention = FlaxWanAttention(
+            rngs=rngs,
+            query_dim=query_dim,
+            heads=40,
+            dim_head=128,
+            attention_kernel="dot_product",
+            split_head_dim=True,
+            mesh=mesh,
+            flash_block_sizes=flash_block_sizes,
+        )
+      except NotImplementedError:
+        pass
 
   @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
   def test_wan_model(self):
@@ -272,7 +286,8 @@ def test_wan_model(self):
     mesh = Mesh(devices_array, config.mesh_axes)
     batch_size = 1
     num_layers = 1
-    wan_model = WanModel(rngs=rngs, attention="flash", mesh=mesh, flash_block_sizes=flash_block_sizes, num_layers=num_layers)
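+    # Note: only the axis rules are entered here; the mesh itself is handed to
+    # WanModel via its mesh argument.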
+    with nn_partitioning.axis_rules(config.logical_axis_rules):
+      wan_model = WanModel(rngs=rngs, attention="flash", mesh=mesh, flash_block_sizes=flash_block_sizes, num_layers=num_layers)
 
     dummy_timestep = jnp.ones((batch_size))
     dummy_encoder_hidden_states = jnp.ones((batch_size, 512, 4096))
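
Taken together, the changes above apply one pattern: load the shared config in setUp, build a mesh from it, and construct and call every NNX module under both the mesh and the config's logical axis rules. A minimal sketch of that pattern follows. It assumes the same package layout and names as the test file (THIS_DIR, pyconfig, create_device_mesh, NNXTimestepEmbedding) and is not a drop-in snippet, since the relative imports only resolve inside the test package.

import os

import jax
import jax.numpy as jnp
from flax import nnx
from flax.linen import partitioning as nn_partitioning
from jax.sharding import Mesh

from .. import pyconfig
from ..max_utils import create_device_mesh
from ..models.wan.transformers.transformer_wan import NNXTimestepEmbedding

# Load the test config and build the device mesh it describes.
pyconfig.initialize([None, os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml")], unittest=True)
config = pyconfig.config
mesh = Mesh(create_device_mesh(config), config.mesh_axes)

# Entering the mesh together with the logical axis rules lets parameters that
# carry logical sharding annotations resolve to concrete mesh axes at init.
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
  layer = NNXTimestepEmbedding(rngs=nnx.Rngs(jax.random.key(0)), in_channels=256, time_embed_dim=5120)
  assert layer(jnp.ones((1, 256))).shape == (1, 5120)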