Skip to content

Commit f618135

Browse files
committed
Add a tmpfs fixture and use it to test fsspec in get_from_cache
1 parent 0864e14 commit f618135

File tree

2 files changed

+33
-14
lines changed

2 files changed

+33
-14
lines changed

tests/fixtures/fsspec.py

Lines changed: 25 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
import posixpath
22
from pathlib import Path
3+
from unittest.mock import patch
34

45
import fsspec
56
import pytest
@@ -40,10 +41,6 @@ def info(self, path, *args, **kwargs):
4041
out["name"] = out["name"][len(self.local_root_dir) :]
4142
return out
4243

43-
def get_file(self, rpath, lpath, *args, **kwargs):
44-
rpath = posixpath.join(self.local_root_dir, self._strip_protocol(rpath))
45-
return self._fs.get_file(rpath, lpath, *args, **kwargs)
46-
4744
def cp_file(self, path1, path2, *args, **kwargs):
4845
path1 = posixpath.join(self.local_root_dir, self._strip_protocol(path1))
4946
path2 = posixpath.join(self.local_root_dir, self._strip_protocol(path2))
@@ -77,10 +74,27 @@ def _strip_protocol(cls, path):
7774
return path
7875

7976

77+
class TmpDirFileSystem(MockFileSystem):
    """Mock filesystem registered under the ``tmp`` protocol.

    The root directory is not passed to the constructor; it is read from the
    class attribute ``tmp_dir``, which must be assigned before instantiation.
    """

    # Root directory forwarded to MockFileSystem as ``local_root_dir``.
    # Stays None until a caller sets it on the class.
    tmp_dir = None
    protocol = "tmp"

    @classmethod
    def _strip_protocol(cls, path):
        """Return ``path`` as a string with a leading ``tmp://`` removed."""
        scheme_prefix = "tmp://"
        text = stringify_path(path)
        if not text.startswith(scheme_prefix):
            return text
        return text[len(scheme_prefix):]

    def __init__(self, *args, **kwargs):
        """Build the filesystem rooted at ``tmp_dir`` with ``auto_mkdir=True``.

        Positional and keyword arguments are forwarded unchanged to the
        parent constructor.
        """
        root = self.tmp_dir
        assert root is not None, "TmpDirFileSystem.tmp_dir is not set"
        super().__init__(*args, **kwargs, local_root_dir=root, auto_mkdir=True)
91+
92+
8093
@pytest.fixture
def mock_fsspec():
    """Register the ``mock`` and ``tmp`` filesystems for the duration of a test.

    The previous teardown rebound ``fsspec.registry`` to a dict copy. That
    only replaces an attribute on the ``fsspec`` package; it does not remove
    the entries that ``register_implementation`` stored, so both protocols
    stayed registered after the test. Patching the underlying registry dict
    restores its exact prior contents on exit.
    """
    # Local import: the mutable registry is a private fsspec name, so keep the
    # dependency confined to this fixture.
    # NOTE(review): relies on fsspec exposing ``fsspec.registry._registry`` - confirm
    # against the minimum fsspec version supported by the test suite.
    from fsspec.registry import _registry as fsspec_registry

    test_implementations = {"mock": MockFileSystem, "tmp": TmpDirFileSystem}
    with patch.dict(fsspec_registry, test_implementations):
        yield
86100

@@ -89,3 +103,10 @@ def mock_fsspec():
89103
def mockfs(tmp_path_factory, mock_fsspec):
    """Return a MockFileSystem rooted at a fresh "mockfs" temporary directory."""
    return MockFileSystem(
        local_root_dir=tmp_path_factory.mktemp("mockfs"),
        auto_mkdir=True,
    )
106+
107+
108+
@pytest.fixture
def tmpfs(tmp_path_factory, mock_fsspec):
    """Yield a TmpDirFileSystem rooted at a fresh "tmpfs" temporary directory.

    ``TmpDirFileSystem.tmp_dir`` is patched only while the fixture is active.
    """
    tmp_fs_dir = tmp_path_factory.mktemp("tmpfs")
    with patch.object(TmpDirFileSystem, "tmp_dir", tmp_fs_dir):
        try:
            yield TmpDirFileSystem()
        finally:
            # fsspec caches filesystem instances keyed on constructor
            # arguments, and this class is always built with none. Without
            # clearing, a later test could receive the instance still rooted
            # at this test's directory.
            # NOTE(review): assumes MockFileSystem derives from fsspec's
            # AbstractFileSystem, which provides clear_instance_cache - confirm.
            TmpDirFileSystem.clear_instance_cache()

tests/test_file_utils.py

Lines changed: 8 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -36,10 +36,10 @@ def zstd_path(tmp_path_factory):
3636

3737

3838
@pytest.fixture
39-
def mockfs_file(mockfs):
40-
with open(os.path.join(mockfs.local_root_dir, FILE_PATH), "w") as f:
39+
def tmpfs_file(tmpfs):
40+
with open(os.path.join(tmpfs.local_root_dir, FILE_PATH), "w") as f:
4141
f.write(FILE_CONTENT)
42-
return mockfs
42+
return FILE_PATH
4343

4444

4545
@pytest.mark.parametrize("compression_format", ["gzip", "xz", "zstd"])
@@ -99,13 +99,11 @@ def test_cached_path_missing_local(tmp_path):
9999
cached_path(missing_file)
100100

101101

102-
def test_get_from_cache_fsspec(mockfs_file):
103-
with patch("datasets.utils.file_utils.fsspec.get_fs_token_paths") as mock_get_fs_token_paths:
104-
mock_get_fs_token_paths.return_value = (mockfs_file, "", [FILE_PATH])
105-
output_path = get_from_cache("mock://huggingface.co")
106-
with open(output_path) as f:
107-
output_file_content = f.read()
108-
assert output_file_content == FILE_CONTENT
102+
def test_get_from_cache_fsspec(tmpfs_file):
103+
output_path = get_from_cache(f"tmp://{tmpfs_file}")
104+
with open(output_path) as f:
105+
output_file_content = f.read()
106+
assert output_file_content == FILE_CONTENT
109107

110108

111109
@patch("datasets.config.HF_DATASETS_OFFLINE", True)

0 commit comments

Comments
 (0)