Skip to content

Commit

Permalink
add tmpfs and use to test fsspec in get_from_cache
Browse files: browse the repository at this point in the history
  • Loading branch information
dwyatte committed Mar 5, 2023
1 parent 0864e14 commit f618135
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 14 deletions.
29 changes: 25 additions & 4 deletions tests/fixtures/fsspec.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import posixpath
from pathlib import Path
from unittest.mock import patch

import fsspec
import pytest
Expand Down Expand Up @@ -40,10 +41,6 @@ def info(self, path, *args, **kwargs):
out["name"] = out["name"][len(self.local_root_dir) :]
return out

def get_file(self, rpath, lpath, *args, **kwargs):
    """Delegate ``get_file`` to the wrapped filesystem with a rooted source path.

    ``rpath`` has its protocol stripped via ``self._strip_protocol`` and is
    then joined onto ``self.local_root_dir``. ``lpath`` and any extra
    positional/keyword arguments are forwarded untouched to
    ``self._fs.get_file``, whose return value is returned as-is.
    """
    root_dir = self.local_root_dir
    relative_path = self._strip_protocol(rpath)
    rooted_path = posixpath.join(root_dir, relative_path)
    return self._fs.get_file(rooted_path, lpath, *args, **kwargs)

def cp_file(self, path1, path2, *args, **kwargs):
path1 = posixpath.join(self.local_root_dir, self._strip_protocol(path1))
path2 = posixpath.join(self.local_root_dir, self._strip_protocol(path2))
Expand Down Expand Up @@ -77,10 +74,27 @@ def _strip_protocol(cls, path):
return path


class TmpDirFileSystem(MockFileSystem):
    """``MockFileSystem`` variant exposed under the ``tmp`` protocol.

    The root directory is not passed by the caller; it is read from the
    class attribute ``tmp_dir``, which must be set before instantiation.
    """

    protocol = "tmp"
    # Directory handed to the parent as ``local_root_dir``; None until set.
    tmp_dir = None

    def __init__(self, *args, **kwargs):
        """Initialise the parent with ``tmp_dir`` as root and ``auto_mkdir=True``.

        Fails with an ``AssertionError`` if ``tmp_dir`` has not been set.
        """
        root_dir = self.tmp_dir
        assert root_dir is not None, "TmpDirFileSystem.tmp_dir is not set"
        super().__init__(*args, **kwargs, local_root_dir=root_dir, auto_mkdir=True)

    @classmethod
    def _strip_protocol(cls, path):
        """Return ``path`` stringified, without a leading ``tmp://`` if present."""
        prefix = "tmp://"
        text = stringify_path(path)
        return text[len(prefix) :] if text.startswith(prefix) else text


@pytest.fixture
def mock_fsspec():
    """Register the ``mock`` and ``tmp`` implementations with fsspec for a test.

    A copy of ``fsspec.registry`` is taken before registering and assigned
    back to ``fsspec.registry`` once the test has finished.
    """
    registry_snapshot = fsspec.registry.copy()
    implementations = (("mock", MockFileSystem), ("tmp", TmpDirFileSystem))
    for protocol_name, filesystem_cls in implementations:
        fsspec.register_implementation(protocol_name, filesystem_cls)
    yield
    # NOTE(review): this rebinds the ``fsspec.registry`` attribute to the saved
    # copy; confirm that it actually undoes the registrations made above.
    fsspec.registry = registry_snapshot

Expand All @@ -89,3 +103,10 @@ def mock_fsspec():
def mockfs(tmp_path_factory, mock_fsspec):
    """Return a ``MockFileSystem`` rooted in a fresh ``mockfs`` temp directory.

    ``mock_fsspec`` is requested only so the mock implementations are
    registered before the filesystem is built; its value is not used.
    """
    return MockFileSystem(
        local_root_dir=tmp_path_factory.mktemp("mockfs"),
        auto_mkdir=True,
    )


@pytest.fixture
def tmpfs(tmp_path_factory, mock_fsspec):
    """Yield a ``TmpDirFileSystem`` whose ``tmp_dir`` is a fresh temp directory.

    ``TmpDirFileSystem.tmp_dir`` is patched to a new ``tmpfs`` directory for
    as long as the fixture is active and restored when the patch exits.
    """
    root_dir = tmp_path_factory.mktemp("tmpfs")
    tmp_dir_patch = patch.object(TmpDirFileSystem, "tmp_dir", root_dir)
    with tmp_dir_patch:
        # NOTE(review): if the base filesystem class caches instances by
        # constructor arguments, this no-argument call could return an
        # instance built for an earlier tmp_dir -- confirm.
        yield TmpDirFileSystem()
18 changes: 8 additions & 10 deletions tests/test_file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ def zstd_path(tmp_path_factory):


@pytest.fixture
def mockfs_file(mockfs):
with open(os.path.join(mockfs.local_root_dir, FILE_PATH), "w") as f:
def tmpfs_file(tmpfs):
with open(os.path.join(tmpfs.local_root_dir, FILE_PATH), "w") as f:
f.write(FILE_CONTENT)
return mockfs
return FILE_PATH


@pytest.mark.parametrize("compression_format", ["gzip", "xz", "zstd"])
Expand Down Expand Up @@ -99,13 +99,11 @@ def test_cached_path_missing_local(tmp_path):
cached_path(missing_file)


def test_get_from_cache_fsspec(mockfs_file):
    """Check ``get_from_cache`` output for a ``mock://`` URL.

    ``fsspec.get_fs_token_paths`` (as seen from ``datasets.utils.file_utils``)
    is patched to return the ``mockfs_file`` filesystem, an empty token and
    ``[FILE_PATH]``; the file at the returned path must hold ``FILE_CONTENT``.
    """
    patch_target = "datasets.utils.file_utils.fsspec.get_fs_token_paths"
    fake_result = (mockfs_file, "", [FILE_PATH])
    with patch(patch_target, return_value=fake_result):
        cached_file = get_from_cache("mock://huggingface.co")
        with open(cached_file) as handle:
            content = handle.read()
        assert content == FILE_CONTENT
def test_get_from_cache_fsspec(tmpfs_file):
    """Check ``get_from_cache`` output for a ``tmp://`` URL.

    ``tmpfs_file`` is the path of the file inside the tmp filesystem; the
    file at the path returned by ``get_from_cache`` must hold ``FILE_CONTENT``.
    """
    source_url = f"tmp://{tmpfs_file}"
    cached_file = get_from_cache(source_url)
    with open(cached_file) as handle:
        content = handle.read()
    assert content == FILE_CONTENT


@patch("datasets.config.HF_DATASETS_OFFLINE", True)
Expand Down

0 comments on commit f618135

Please sign in to comment.