Commit 65da062
Commit message: test fix
1 parent: b185681

3 files changed: +198 −165 lines

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 1 addition & 0 deletions
@@ -162,6 +162,7 @@ logical_axis_rules: [
   ['norm', 'tensor'],
   ['conv_batch', ['data','fsdp']],
   ['out_channels', 'tensor'],
+  ['conv_in', 'fsdp'],
   ['conv_out', 'fsdp'],
 ]
 data_sharding: [['data', 'fsdp', 'tensor']]
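
Each entry in `logical_axis_rules` maps a logical axis name, as used in parameter sharding annotations, to a physical mesh axis; the new `['conv_in', 'fsdp']` rule therefore shards tensors annotated with `conv_in` across the `fsdp` mesh axis. A minimal sketch of how such a rule resolves via `flax.linen.partitioning`, not part of this commit and trimmed to two of the rules shown above:

# Minimal sketch, not part of this commit: how an axis rule resolves.
from jax.sharding import PartitionSpec
from flax.linen import partitioning as nn_partitioning

rules = [("conv_in", "fsdp"), ("conv_out", "fsdp")]  # entries from the config above

with nn_partitioning.axis_rules(rules):
  # A tensor annotated with logical axes ('conv_in', None) shards its
  # first dimension across the 'fsdp' mesh axis.
  spec = nn_partitioning.logical_to_mesh_axes(("conv_in", None))
assert spec == PartitionSpec("fsdp", None)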

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 78 additions & 63 deletions
@@ -23,7 +23,7 @@
 from absl.testing import absltest
 from flax import nnx
 from jax.sharding import Mesh
-
+from flax.linen import partitioning as nn_partitioning
 from .. import pyconfig
 from ..max_utils import (create_device_mesh, get_flash_block_sizes)
 from ..models.wan.transformers.transformer_wan import (
@@ -53,6 +53,18 @@ class WanTransformerTest(unittest.TestCase):
 
   def setUp(self):
     WanTransformerTest.dummy_data = {}
+    pyconfig.initialize(
+        [
+            None,
+            os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
+        ],
+        unittest=True,
+    )
+    config = pyconfig.config
+    self.config = config
+    devices_array = create_device_mesh(config)
+    self.mesh = Mesh(devices_array, config.mesh_axes)
+
 
   def test_rotary_pos_embed(self):
     batch_size = 1
@@ -70,28 +82,31 @@ def test_nnx_pixart_alpha_text_projection(self):
     key = jax.random.key(0)
     rngs = nnx.Rngs(key)
     dummy_caption = jnp.ones((1, 512, 4096))
-    layer = NNXPixArtAlphaTextProjection(rngs=rngs, in_features=4096, hidden_size=5120)
-    dummy_output = layer(dummy_caption)
-    dummy_output.shape == (1, 512, 5120)
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      layer = NNXPixArtAlphaTextProjection(rngs=rngs, in_features=4096, hidden_size=5120)
+      dummy_output = layer(dummy_caption)
+      dummy_output.shape == (1, 512, 5120)
 
   def test_nnx_timestep_embedding(self):
     key = jax.random.key(0)
     rngs = nnx.Rngs(key)
 
     dummy_sample = jnp.ones((1, 256))
-    layer = NNXTimestepEmbedding(rngs=rngs, in_channels=256, time_embed_dim=5120)
-    dummy_output = layer(dummy_sample)
-    assert dummy_output.shape == (1, 5120)
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      layer = NNXTimestepEmbedding(rngs=rngs, in_channels=256, time_embed_dim=5120)
+      dummy_output = layer(dummy_sample)
+      assert dummy_output.shape == (1, 5120)
 
   def test_fp32_layer_norm(self):
     key = jax.random.key(0)
     rngs = nnx.Rngs(key)
     batch_size = 1
     dummy_hidden_states = jnp.ones((batch_size, 75600, 5120))
     # expected same output shape with same dtype
-    layer = FP32LayerNorm(rngs=rngs, dim=5120, eps=1e-6, elementwise_affine=False)
-    dummy_output = layer(dummy_hidden_states)
-    assert dummy_output.shape == dummy_hidden_states.shape
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      layer = FP32LayerNorm(rngs=rngs, dim=5120, eps=1e-6, elementwise_affine=False)
+      dummy_output = layer(dummy_hidden_states)
+      assert dummy_output.shape == dummy_hidden_states.shape
 
   @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
   def test_wan_time_text_embedding(self):
@@ -102,20 +117,21 @@ def test_wan_time_text_embedding(self):
     time_freq_dim = 256
     time_proj_dim = 30720
     text_embed_dim = 4096
-    layer = WanTimeTextImageEmbedding(
-        rngs=rngs, dim=dim, time_freq_dim=time_freq_dim, time_proj_dim=time_proj_dim, text_embed_dim=text_embed_dim
-    )
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      layer = WanTimeTextImageEmbedding(
+          rngs=rngs, dim=dim, time_freq_dim=time_freq_dim, time_proj_dim=time_proj_dim, text_embed_dim=text_embed_dim
+      )
 
-    dummy_timestep = jnp.ones(batch_size)
+      dummy_timestep = jnp.ones(batch_size)
 
-    encoder_hidden_states_shape = (batch_size, time_freq_dim * 2, text_embed_dim)
-    dummy_encoder_hidden_states = jnp.ones(encoder_hidden_states_shape)
-    temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = layer(
-        dummy_timestep, dummy_encoder_hidden_states
-    )
-    assert temb.shape == (batch_size, dim)
-    assert timestep_proj.shape == (batch_size, time_proj_dim)
-    assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim)
+      encoder_hidden_states_shape = (batch_size, time_freq_dim * 2, text_embed_dim)
+      dummy_encoder_hidden_states = jnp.ones(encoder_hidden_states_shape)
+      temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = layer(
+          dummy_timestep, dummy_encoder_hidden_states
+      )
+      assert temb.shape == (batch_size, dim)
+      assert timestep_proj.shape == (batch_size, time_proj_dim)
+      assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim)
 
   def test_wan_block(self):
     key = jax.random.key(0)
@@ -163,20 +179,19 @@ def test_wan_block(self):
     dummy_encoder_hidden_states = jnp.ones((batch_size, 512, dim))
 
     dummy_temb = jnp.ones((batch_size, 6, dim))
-
-    wan_block = WanTransformerBlock(
-        rngs=rngs,
-        dim=dim,
-        ffn_dim=ffn_dim,
-        num_heads=num_heads,
-        qk_norm=qk_norm,
-        cross_attn_norm=cross_attn_norm,
-        eps=eps,
-        attention="flash",
-        mesh=mesh,
-        flash_block_sizes=flash_block_sizes,
-    )
-    with mesh:
+    with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      wan_block = WanTransformerBlock(
+          rngs=rngs,
+          dim=dim,
+          ffn_dim=ffn_dim,
+          num_heads=num_heads,
+          qk_norm=qk_norm,
+          cross_attn_norm=cross_attn_norm,
+          eps=eps,
+          attention="flash",
+          mesh=mesh,
+          flash_block_sizes=flash_block_sizes,
+      )
       dummy_output = wan_block(dummy_hidden_states, dummy_encoder_hidden_states, dummy_temb, dummy_rotary_emb)
       assert dummy_output.shape == dummy_hidden_states.shape
 
@@ -209,40 +224,39 @@ def test_wan_attention(self):
     mesh = Mesh(devices_array, config.mesh_axes)
     batch_size = 1
     query_dim = 5120
-    attention = FlaxWanAttention(
-        rngs=rngs,
-        query_dim=query_dim,
-        heads=40,
-        dim_head=128,
-        attention_kernel="flash",
-        mesh=mesh,
-        flash_block_sizes=flash_block_sizes,
-    )
-
-    dummy_hidden_states_shape = (batch_size, 75600, query_dim)
-
-    dummy_hidden_states = jnp.ones(dummy_hidden_states_shape)
-    dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape)
-    with mesh:
-      dummy_output = attention(
-          hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb
-      )
-    assert dummy_output.shape == dummy_hidden_states_shape
-
-    # dot product
-    try:
+    with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
       attention = FlaxWanAttention(
           rngs=rngs,
           query_dim=query_dim,
           heads=40,
           dim_head=128,
-          attention_kernel="dot_product",
-          split_head_dim=True,
+          attention_kernel="flash",
           mesh=mesh,
           flash_block_sizes=flash_block_sizes,
       )
-    except NotImplementedError:
-      pass
+      dummy_hidden_states_shape = (batch_size, 75600, query_dim)
+
+      dummy_hidden_states = jnp.ones(dummy_hidden_states_shape)
+      dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape)
+      dummy_output = attention(
+          hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb
+      )
+      assert dummy_output.shape == dummy_hidden_states_shape
+
+      # dot product
+      try:
+        attention = FlaxWanAttention(
+            rngs=rngs,
+            query_dim=query_dim,
+            heads=40,
+            dim_head=128,
+            attention_kernel="dot_product",
+            split_head_dim=True,
+            mesh=mesh,
+            flash_block_sizes=flash_block_sizes,
+        )
+      except NotImplementedError:
+        pass
 
   @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
   def test_wan_model(self):
@@ -272,7 +286,8 @@ def test_wan_model(self):
     mesh = Mesh(devices_array, config.mesh_axes)
     batch_size = 1
     num_layers = 1
-    wan_model = WanModel(rngs=rngs, attention="flash", mesh=mesh, flash_block_sizes=flash_block_sizes, num_layers=num_layers)
+    with nn_partitioning.axis_rules(config.logical_axis_rules):
+      wan_model = WanModel(rngs=rngs, attention="flash", mesh=mesh, flash_block_sizes=flash_block_sizes, num_layers=num_layers)
 
     dummy_timestep = jnp.ones((batch_size))
     dummy_encoder_hidden_states = jnp.ones((batch_size, 512, 4096))
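
Every touched test now follows the same construction pattern: enter the device mesh and the config's logical axis rules together, then build and call the module inside that scope so its logical sharding annotations can be resolved. Below is a standalone sketch of that pattern, independent of maxdiffusion; the 'embed' rule and the array shape are hypothetical, not from this commit:

# Standalone sketch of the `with mesh, nn_partitioning.axis_rules(...)` pattern.
# The 'embed' rule and the array shape are hypothetical, not from this commit.
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh
from flax.linen import partitioning as nn_partitioning

mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), ("fsdp",))
rules = [("embed", "fsdp")]

x = jnp.ones((jax.device_count() * 2, 8))
with mesh, nn_partitioning.axis_rules(rules):
  # While the mesh and rules are active, the logical name 'embed' resolves to
  # the 'fsdp' mesh axis, so dim 0 of x is sharded across devices.
  y = nn_partitioning.with_sharding_constraint(x, ("embed", None))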
