Skip to content

Commit bc8c0fd

Browse files
Fix: Prevent multiple processes from copying the same file when using… (#353)
Co-authored-by: tchaton <[email protected]>
1 parent 8382067 commit bc8c0fd

File tree

5 files changed

+27
-7
lines changed

5 files changed

+27
-7
lines changed

src/litdata/streaming/downloader.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14+
import contextlib
1415
import os
1516
import shutil
1617
import subprocess
@@ -169,8 +170,17 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
169170
if not os.path.exists(remote_filepath):
170171
raise FileNotFoundError(f"The provided remote_path doesn't exist: {remote_filepath}")
171172

172-
if remote_filepath != local_filepath and not os.path.exists(local_filepath):
173-
shutil.copy(remote_filepath, local_filepath)
173+
try:
174+
with FileLock(local_filepath + ".lock", timeout=3 if remote_filepath.endswith(_INDEX_FILENAME) else 0):
175+
if remote_filepath != local_filepath and not os.path.exists(local_filepath):
176+
# make an atomic operation to be safe
177+
temp_file_path = local_filepath + ".tmp"
178+
shutil.copy(remote_filepath, temp_file_path)
179+
os.rename(temp_file_path, local_filepath)
180+
with contextlib.suppress(Exception):
181+
os.remove(local_filepath + ".lock")
182+
except Timeout:
183+
pass
174184

175185

176186
class LocalDownloaderWithCache(LocalDownloader):

tests/streaming/test_dataset.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@
4545
from litdata.utilities.shuffle import _associate_chunks_and_intervals_to_workers
4646
from torch.utils.data import DataLoader
4747

48+
from tests.streaming.utils import filter_lock_files
49+
4850

4951
def seed_everything(random_seed):
5052
random.seed(random_seed)
@@ -861,13 +863,13 @@ def test_resumable_dataset_two_workers_2_epochs(tmpdir):
861863
for batch in dataloader:
862864
batches_epoch_1.append(batch)
863865

864-
assert len(os.listdir(cache_dir)) == 51
866+
assert len(filter_lock_files(os.listdir(cache_dir))) == 51
865867

866868
batches_epoch_2 = []
867869
for batch in dataloader:
868870
batches_epoch_2.append(batch)
869871

870-
assert len(os.listdir(cache_dir)) == 51
872+
assert len(filter_lock_files(os.listdir(cache_dir))) == 51
871873
assert not all(torch.equal(b1, b2) for b1, b2 in zip(batches_epoch_1, batches_epoch_2))
872874

873875

tests/streaming/test_downloader.py

+4
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,12 @@ def test_download_with_cache(tmpdir, monkeypatch):
8080
try:
8181
local_downloader = LocalDownloaderWithCache(tmpdir, tmpdir, [])
8282
shutil_mock = MagicMock()
83+
os_mock = MagicMock()
8384
monkeypatch.setattr(shutil, "copy", shutil_mock)
85+
monkeypatch.setattr(os, "rename", os_mock)
86+
8487
local_downloader.download_file("local:a.txt", os.path.join(tmpdir, "a.txt"))
8588
shutil_mock.assert_called()
89+
os_mock.assert_called()
8690
finally:
8791
os.remove("a.txt")

tests/streaming/test_reader.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from litdata.streaming.resolver import Dir
1212
from litdata.utilities.env import _DistributedEnv
1313

14+
from tests.streaming.utils import filter_lock_files
15+
1416

1517
def test_reader_chunk_removal(tmpdir):
1618
cache_dir = os.path.join(tmpdir, "cache_dir")
@@ -32,19 +34,19 @@ def test_reader_chunk_removal(tmpdir):
3234
index = ChunkedIndex(*cache._get_chunk_index_from_index(i), is_last_index=i == 24)
3335
assert cache[index] == i
3436

35-
assert len(os.listdir(cache_dir)) == 14
37+
assert len(filter_lock_files(os.listdir(cache_dir))) == 14
3638

3739
cache = Cache(input_dir=Dir(path=cache_dir, url=remote_dir), chunk_size=2, max_cache_size=2800)
3840

3941
shutil.rmtree(cache_dir)
4042
os.makedirs(cache_dir, exist_ok=True)
4143

4244
for i in range(25):
43-
assert len(os.listdir(cache_dir)) <= 3
45+
assert len(filter_lock_files(os.listdir(cache_dir))) <= 3
4446
index = ChunkedIndex(*cache._get_chunk_index_from_index(i), is_last_index=i == 24)
4547
assert cache[index] == i
4648

47-
assert len(os.listdir(cache_dir)) in [2, 3]
49+
assert len(filter_lock_files(os.listdir(cache_dir))) in [2, 3]
4850

4951

5052
def test_get_folder_size(tmpdir):

tests/streaming/utils.py

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
def filter_lock_files(files):
    """Return the entries of *files* that are not lock files.

    An entry counts as a lock file when its name ends with ``".lock"``.
    The input is not modified; a new list is returned with the remaining
    entries in their original order.

    Args:
        files: Iterable of file names (for example the result of
            ``os.listdir``).

    Returns:
        A list of the names from *files* that do not end with ``".lock"``.
    """
    return [name for name in files if not name.endswith(".lock")]

0 commit comments

Comments
 (0)