
Commit 83d5c16

[Flux] Fix file saving race condition (#1217)
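We see an `OSError` when saving generated images. It happens because the execution of the following two lines (from `sampling.py::save_image()`) is interleaved between ranks: one rank can create the directory after another rank's `os.path.exists` check has returned `False` but before that rank calls `os.makedirs`, which then fails.

```
if not os.path.exists(output_dir):
    os.makedirs(output_dir)
```

Changed to use `os.makedirs(output_dir, exist_ok=True)` instead of the `os.path.exists` check.

Thanks @tianyu-l for identifying this issue.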
1 parent c1e796b commit 83d5c16


1 file changed: +4 -4 lines changed


torchtitan/experiments/flux/sampling.py

Lines changed: 4 additions & 4 deletions
@@ -27,6 +27,7 @@
     preprocess_data,
     unpack_latents,
 )
+from torchtitan.tools.logging import logger
 
 
 # ----------------------------------------
@@ -218,11 +219,10 @@ def save_image(
     add_sampling_metadata: bool,
     prompt: str,
 ):
-    output_dir = os.path.join(output_dir, f"rank_{torch.distributed.get_rank()}")
-    print(f"Saving {output_dir}/{name}")
-    if not os.path.exists(output_dir):
-        os.makedirs(output_dir)
+    logger.info(f"Saving image to {output_dir}/{name}")
+    os.makedirs(output_dir, exist_ok=True)
     output_name = os.path.join(output_dir, name)
+
     # bring into PIL format and save
     x = x.clamp(-1, 1)
     x = rearrange(x[0], "c h w -> h w c")
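
The check-then-create race described in the commit message can be sketched outside of torchtitan. The snippet below is a minimal illustration, not the project's code: threads stand in for distributed ranks, and the helper names `ensure_dir_racy`, `ensure_dir_safe`, and `run_workers` are hypothetical. The racy variant may or may not fail on a given run, since that depends on thread scheduling; the `exist_ok=True` variant should not raise when the directory already exists.

```python
import os
import tempfile
import threading


def ensure_dir_racy(path: str) -> None:
    # Check-then-create: another worker can create `path` between the two
    # calls, in which case os.makedirs raises FileExistsError (an OSError).
    if not os.path.exists(path):
        os.makedirs(path)


def ensure_dir_safe(path: str) -> None:
    # Single call; an already-existing directory is not treated as an error.
    os.makedirs(path, exist_ok=True)


def run_workers(ensure_dir, path: str, num_workers: int = 8) -> list:
    """Call ensure_dir(path) from several threads at once; return raised errors."""
    errors = []
    barrier = threading.Barrier(num_workers)

    def worker() -> None:
        barrier.wait()  # release all workers together to encourage interleaving
        try:
            ensure_dir(path)
        except OSError as exc:
            errors.append(exc)

    threads = [threading.Thread(target=worker) for _ in range(num_workers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return errors


# Threads here are only a stand-in for ranks writing to a shared output_dir.
with tempfile.TemporaryDirectory() as root:
    safe_errors = run_workers(ensure_dir_safe, os.path.join(root, "safe_out"))
    safe_dir_exists = os.path.isdir(os.path.join(root, "safe_out"))
    racy_errors = run_workers(ensure_dir_racy, os.path.join(root, "racy_out"))

print(f"safe variant errors: {len(safe_errors)}")
print(f"racy variant errors: {len(racy_errors)} (varies from run to run)")
```

Any errors collected from the racy variant are `FileExistsError` instances, which is consistent with the `OSError` reported in the commit message.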
