Skip to content

Commit 70476a8

Browse files
committed
Fix ModelCheckpoint.file_exists OOM in DDP
1 parent 8f702b3 commit 70476a8

File tree

3 files changed

+62
-2
lines changed

3 files changed

+62
-2
lines changed

src/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -999,8 +999,10 @@ def to_yaml(self, filepath: Optional[_PATH] = None) -> None:
999999
def file_exists(self, filepath: _PATH, trainer: "pl.Trainer") -> bool:
10001000
"""Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing the internal
10011001
state from diverging between ranks."""
1002-
exists = self._fs.exists(filepath)
1003-
return trainer.strategy.broadcast(exists)
1002+
# In distributed setups, only global rank 0 touches the filesystem
1003+
local_decision = self._fs.exists(filepath) if trainer.is_global_zero else False
1004+
# Reduce the decision across ranks using an "any"-style reduction to decide if the file exists anywhere
1005+
return trainer.strategy.reduce_boolean_decision(local_decision, all=False)
10041006

10051007
def _should_remove_checkpoint(self, trainer: "pl.Trainer", previous: str, current: str) -> bool:
10061008
"""Checks if the previous checkpoint should be deleted.

tests/tests_pytorch/checkpointing/test_checkpoint_callback_frequency.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,3 +121,28 @@ def on_train_epoch_end(self):
121121
trainer.fit(model)
122122
if os.getenv("LOCAL_RANK") == "0":
123123
assert save_mock.call_count == expected
124+
125+
126+
@RunIf(min_cuda_gpus=2, standalone=True)
127+
def test_model_checkpoint_ddp_monitor_none(tmp_path):
128+
"""Ensure that ModelCheckpoint with monitor=None works correctly under DDP and exercises the file_exists path."""
129+
130+
model = BoringModel()
131+
checkpoint = callbacks.ModelCheckpoint(dirpath=tmp_path, monitor=None, save_top_k=1)
132+
133+
trainer = Trainer(
134+
default_root_dir=tmp_path,
135+
callbacks=[checkpoint],
136+
enable_progress_bar=False,
137+
enable_model_summary=False,
138+
max_epochs=1,
139+
strategy="ddp",
140+
accelerator="gpu",
141+
devices=2,
142+
limit_train_batches=2,
143+
limit_val_batches=0,
144+
)
145+
146+
trainer.fit(model)
147+
if os.getenv("LOCAL_RANK") == "0":
148+
assert checkpoint.best_model_path

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2180,3 +2180,36 @@ def on_validation_epoch_end(self):
21802180
assert len(checkpoint_files) == expected_files, (
21812181
f"Expected {expected_files} files, got {len(checkpoint_files)}: {checkpoint_names}"
21822182
)
2183+
2184+
2185+
def test_model_checkpoint_file_exists_distributed_branch(tmp_path):
2186+
"""Ensure the distributed branch of ModelCheckpoint.file_exists uses reduce_boolean_decision."""
2187+
2188+
checkpoint = ModelCheckpoint(dirpath=tmp_path)
2189+
calls = []
2190+
2191+
class DummyStrategy:
2192+
def reduce_boolean_decision(self, decision, all=True):
2193+
calls.append((decision, all))
2194+
return decision
2195+
2196+
class DummyTrainer:
2197+
def __init__(self, is_global_zero: bool):
2198+
self.world_size = 2
2199+
self.is_global_zero = is_global_zero
2200+
self.strategy = DummyStrategy()
2201+
2202+
# global rank 0: filesystem is touched and decision=True is reduced with all=False
2203+
checkpoint._fs.exists = Mock(return_value=True)
2204+
trainer = DummyTrainer(is_global_zero=True)
2205+
assert checkpoint.file_exists("ignored", trainer)
2206+
checkpoint._fs.exists.assert_called_once_with("ignored")
2207+
assert calls == [(True, False)]
2208+
2209+
# non-global ranks: filesystem is not touched and local decision is False
2210+
calls.clear()
2211+
checkpoint._fs.exists = Mock(return_value=True)
2212+
trainer = DummyTrainer(is_global_zero=False)
2213+
assert not checkpoint.file_exists("ignored", trainer)
2214+
checkpoint._fs.exists.assert_not_called()
2215+
assert calls == [(False, False)]

0 commit comments

Comments
 (0)