Fix: Prevent multiple processes from copying the same file when using… #353

Merged: 8 commits, Sep 3, 2024
14 changes: 12 additions & 2 deletions src/litdata/streaming/downloader.py
@@ -11,6 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import contextlib
 import os
 import shutil
 import subprocess
@@ -169,8 +170,17 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
         if not os.path.exists(remote_filepath):
             raise FileNotFoundError(f"The provided remote_path doesn't exist: {remote_filepath}")
 
-        if remote_filepath != local_filepath and not os.path.exists(local_filepath):
-            shutil.copy(remote_filepath, local_filepath)
+        try:
+            with FileLock(local_filepath + ".lock", timeout=3 if remote_filepath.endswith(_INDEX_FILENAME) else 0):
+                if remote_filepath != local_filepath and not os.path.exists(local_filepath):
+                    # make an atomic operation to be safe
+                    temp_file_path = local_filepath + ".tmp"
+                    shutil.copy(remote_filepath, temp_file_path)
+                    os.rename(temp_file_path, local_filepath)
+            with contextlib.suppress(Exception):
+                os.remove(local_filepath + ".lock")
+        except Timeout:
+            pass
 
 
 class LocalDownloaderWithCache(LocalDownloader):
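For context: the change serializes concurrent downloads of the same file behind a per-destination lock, and makes the copy itself atomic by writing to a temporary file and renaming it into place. FileLock and Timeout here presumably come from the third-party filelock package. A minimal standalone sketch of the same pattern, where copy_once is a hypothetical helper and not litdata API:

import contextlib
import os
import shutil

from filelock import FileLock, Timeout


def copy_once(src: str, dst: str, timeout: float = 0) -> None:
    # Hypothetical helper mirroring the change above. timeout=0 tries the
    # lock exactly once; a positive timeout waits briefly, as the PR does
    # for index files (timeout=3).
    try:
        with FileLock(dst + ".lock", timeout=timeout):
            if not os.path.exists(dst):
                # Copy to a temp path, then rename. os.rename is atomic on
                # POSIX filesystems, so readers never see a partial file.
                tmp = dst + ".tmp"
                shutil.copy(src, tmp)
                os.rename(tmp, dst)
        # Best-effort cleanup; another process may have removed it already.
        with contextlib.suppress(Exception):
            os.remove(dst + ".lock")
    except Timeout:
        # Another process holds the lock: assume it will produce dst.
        pass

The asymmetric timeouts suggest a deliberate trade-off: for regular chunk files, losing the race means skipping the copy entirely, while index files get a short wait because they are needed immediately. Since the .lock files land next to the cached chunks, the tests below stop counting them when asserting cache-directory contents, via the new filter_lock_files helper.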
6 changes: 4 additions & 2 deletions tests/streaming/test_dataset.py
@@ -45,6 +45,8 @@
 from litdata.utilities.shuffle import _associate_chunks_and_intervals_to_workers
 from torch.utils.data import DataLoader
 
+from tests.streaming.utils import filter_lock_files
+
 
 def seed_everything(random_seed):
     random.seed(random_seed)
@@ -861,13 +863,13 @@ def test_resumable_dataset_two_workers_2_epochs(tmpdir):
     for batch in dataloader:
         batches_epoch_1.append(batch)
 
-    assert len(os.listdir(cache_dir)) == 51
+    assert len(filter_lock_files(os.listdir(cache_dir))) == 51
 
     batches_epoch_2 = []
     for batch in dataloader:
         batches_epoch_2.append(batch)
 
-    assert len(os.listdir(cache_dir)) == 51
+    assert len(filter_lock_files(os.listdir(cache_dir))) == 51
     assert not all(torch.equal(b1, b2) for b1, b2 in zip(batches_epoch_1, batches_epoch_2))
4 changes: 4 additions & 0 deletions tests/streaming/test_downloader.py
@@ -80,8 +80,12 @@ def test_download_with_cache(tmpdir, monkeypatch):
     try:
         local_downloader = LocalDownloaderWithCache(tmpdir, tmpdir, [])
         shutil_mock = MagicMock()
+        os_mock = MagicMock()
         monkeypatch.setattr(shutil, "copy", shutil_mock)
+        monkeypatch.setattr(os, "rename", os_mock)
+
         local_downloader.download_file("local:a.txt", os.path.join(tmpdir, "a.txt"))
         shutil_mock.assert_called()
+        os_mock.assert_called()
     finally:
         os.remove("a.txt")
8 changes: 5 additions & 3 deletions tests/streaming/test_reader.py
@@ -11,6 +11,8 @@
 from litdata.streaming.resolver import Dir
 from litdata.utilities.env import _DistributedEnv
 
+from tests.streaming.utils import filter_lock_files
+
 
 def test_reader_chunk_removal(tmpdir):
     cache_dir = os.path.join(tmpdir, "cache_dir")
@@ -32,19 +34,19 @@ def test_reader_chunk_removal(tmpdir):
         index = ChunkedIndex(*cache._get_chunk_index_from_index(i), is_last_index=i == 24)
         assert cache[index] == i
 
-    assert len(os.listdir(cache_dir)) == 14
+    assert len(filter_lock_files(os.listdir(cache_dir))) == 14
 
     cache = Cache(input_dir=Dir(path=cache_dir, url=remote_dir), chunk_size=2, max_cache_size=2800)
 
     shutil.rmtree(cache_dir)
     os.makedirs(cache_dir, exist_ok=True)
 
     for i in range(25):
-        assert len(os.listdir(cache_dir)) <= 3
+        assert len(filter_lock_files(os.listdir(cache_dir))) <= 3
         index = ChunkedIndex(*cache._get_chunk_index_from_index(i), is_last_index=i == 24)
         assert cache[index] == i
 
-    assert len(os.listdir(cache_dir)) in [2, 3]
+    assert len(filter_lock_files(os.listdir(cache_dir))) in [2, 3]
 
 
 def test_get_folder_size(tmpdir):
2 changes: 2 additions & 0 deletions tests/streaming/utils.py
@@ -0,0 +1,2 @@
+def filter_lock_files(files):
+    return [f for f in files if not f.endswith(".lock")]
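A quick usage sketch of the new helper (the file names are illustrative, not from the PR):

files = ["chunk-0-0.bin", "chunk-0-0.bin.lock", "index.json"]
assert filter_lock_files(files) == ["chunk-0-0.bin", "index.json"]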