diff --git a/tests/fixtures/fsspec.py b/tests/fixtures/fsspec.py index e7b653c7f5e4..8aaa181a77e4 100644 --- a/tests/fixtures/fsspec.py +++ b/tests/fixtures/fsspec.py @@ -1,5 +1,6 @@ import posixpath from pathlib import Path +from unittest.mock import patch import fsspec import pytest @@ -40,10 +41,6 @@ def info(self, path, *args, **kwargs): out["name"] = out["name"][len(self.local_root_dir) :] return out - def get_file(self, rpath, lpath, *args, **kwargs): - rpath = posixpath.join(self.local_root_dir, self._strip_protocol(rpath)) - return self._fs.get_file(rpath, lpath, *args, **kwargs) - def cp_file(self, path1, path2, *args, **kwargs): path1 = posixpath.join(self.local_root_dir, self._strip_protocol(path1)) path2 = posixpath.join(self.local_root_dir, self._strip_protocol(path2)) @@ -77,10 +74,27 @@ def _strip_protocol(cls, path): return path +class TmpDirFileSystem(MockFileSystem): + protocol = "tmp" + tmp_dir = None + + def __init__(self, *args, **kwargs): + assert self.tmp_dir is not None, "TmpDirFileSystem.tmp_dir is not set" + super().__init__(*args, **kwargs, local_root_dir=self.tmp_dir, auto_mkdir=True) + + @classmethod + def _strip_protocol(cls, path): + path = stringify_path(path) + if path.startswith("tmp://"): + path = path[6:] + return path + + @pytest.fixture def mock_fsspec(): original_registry = fsspec.registry.copy() fsspec.register_implementation("mock", MockFileSystem) + fsspec.register_implementation("tmp", TmpDirFileSystem) yield fsspec.registry = original_registry @@ -89,3 +103,10 @@ def mock_fsspec(): def mockfs(tmp_path_factory, mock_fsspec): local_fs_dir = tmp_path_factory.mktemp("mockfs") return MockFileSystem(local_root_dir=local_fs_dir, auto_mkdir=True) + + +@pytest.fixture +def tmpfs(tmp_path_factory, mock_fsspec): + tmp_fs_dir = tmp_path_factory.mktemp("tmpfs") + with patch.object(TmpDirFileSystem, "tmp_dir", tmp_fs_dir): + yield TmpDirFileSystem() diff --git a/tests/test_file_utils.py b/tests/test_file_utils.py index 64b0583e13e4..7c5d720e2395 100644 --- a/tests/test_file_utils.py +++ b/tests/test_file_utils.py @@ -36,10 +36,10 @@ def zstd_path(tmp_path_factory): @pytest.fixture -def mockfs_file(mockfs): - with open(os.path.join(mockfs.local_root_dir, FILE_PATH), "w") as f: +def tmpfs_file(tmpfs): + with open(os.path.join(tmpfs.local_root_dir, FILE_PATH), "w") as f: f.write(FILE_CONTENT) - return mockfs + return f"tmp://{FILE_PATH}" @pytest.mark.parametrize("compression_format", ["gzip", "xz", "zstd"]) @@ -99,13 +99,11 @@ def test_cached_path_missing_local(tmp_path): cached_path(missing_file) -def test_get_from_cache_fsspec(mockfs_file): - with patch("datasets.utils.file_utils.fsspec.get_fs_token_paths") as mock_get_fs_token_paths: - mock_get_fs_token_paths.return_value = (mockfs_file, "", [FILE_PATH]) - output_path = get_from_cache("mock://huggingface.co") - with open(output_path) as f: - output_file_content = f.read() - assert output_file_content == FILE_CONTENT +def test_get_from_cache_fsspec(tmpfs_file): + output_path = get_from_cache(tmpfs_file) + with open(output_path) as f: + output_file_content = f.read() + assert output_file_content == FILE_CONTENT @patch("datasets.config.HF_DATASETS_OFFLINE", True)