diff --git a/src/datasets/utils/file_utils.py b/src/datasets/utils/file_utils.py index 406dbef3131c..fb5c6794ef69 100644 --- a/src/datasets/utils/file_utils.py +++ b/src/datasets/utils/file_utils.py @@ -330,24 +330,24 @@ def _request_with_retry( def fsspec_head(url, timeout=10.0): _raise_if_offline_mode_is_enabled(f"Tried to reach {url}") - fs, _, paths = fsspec.get_fs_token_paths(url) + fs, _, paths = fsspec.get_fs_token_paths(url, storage_options={"requests_timeout": timeout}) if len(paths) > 1: - raise ValueError("HEAD can be called with at most one path but was called with {paths}") + raise ValueError(f"HEAD can be called with at most one path but was called with {paths}") return fs.info(paths[0], timeout=timeout) def fsspec_get(url, temp_file, timeout=10.0, desc=None): _raise_if_offline_mode_is_enabled(f"Tried to reach {url}") - fs, _, paths = fsspec.get_fs_token_paths(url) + fs, _, paths = fsspec.get_fs_token_paths(url, storage_options={"requests_timeout": timeout}) if len(paths) > 1: - raise ValueError("GET can be called with at most one path but was called with {paths}") + raise ValueError(f"GET can be called with at most one path but was called with {paths}") callback = fsspec.callbacks.TqdmCallback( tqdm_kwargs={ "desc": desc or "Downloading", "disable": logging.is_progress_bar_enabled(), } ) - fs.get(paths[0], temp_file, timeout=timeout, callback=callback) + fs.get_file(paths[0], temp_file.name, timeout=timeout, callback=callback) def ftp_head(url, timeout=10.0): diff --git a/tests/fixtures/fsspec.py b/tests/fixtures/fsspec.py index be49dd0bdeb7..e7b653c7f5e4 100644 --- a/tests/fixtures/fsspec.py +++ b/tests/fixtures/fsspec.py @@ -40,6 +40,10 @@ def info(self, path, *args, **kwargs): out["name"] = out["name"][len(self.local_root_dir) :] return out + def get_file(self, rpath, lpath, *args, **kwargs): + rpath = posixpath.join(self.local_root_dir, self._strip_protocol(rpath)) + return self._fs.get_file(rpath, lpath, *args, **kwargs) + def cp_file(self, path1, path2, *args, **kwargs): path1 = posixpath.join(self.local_root_dir, self._strip_protocol(path1)) path2 = posixpath.join(self.local_root_dir, self._strip_protocol(path2)) diff --git a/tests/test_file_utils.py b/tests/test_file_utils.py index 09f3eeb4f7df..be0992460b8f 100644 --- a/tests/test_file_utils.py +++ b/tests/test_file_utils.py @@ -13,6 +13,7 @@ fsspec_head, ftp_get, ftp_head, + get_from_cache, http_get, http_head, ) @@ -22,16 +23,25 @@ Text data. Second line of data.""" +FILE_PATH = "file" + @pytest.fixture(scope="session") def zstd_path(tmp_path_factory): - path = tmp_path_factory.mktemp("data") / "file.zstd" + path = tmp_path_factory.mktemp("data") / FILE_PATH data = bytes(FILE_CONTENT, "utf-8") with zstd.open(path, "wb") as f: f.write(data) return path +@pytest.fixture +def mockfs_file(mockfs): + with open(os.path.join(mockfs.local_root_dir, FILE_PATH), "w") as f: + f.write(FILE_CONTENT) + return mockfs + + @pytest.mark.parametrize("compression_format", ["gzip", "xz", "zstd"]) def test_cached_path_extract(compression_format, gz_file, xz_file, zstd_path, tmp_path, text_file): input_paths = {"gzip": gz_file, "xz": xz_file, "zstd": zstd_path} @@ -89,6 +99,15 @@ def test_cached_path_missing_local(tmp_path): cached_path(missing_file) +def test_get_from_cache_fsspec(mockfs_file): + with patch("datasets.utils.file_utils.fsspec.get_fs_token_paths") as mock_get_fs_token_paths: + mock_get_fs_token_paths.return_value = (mockfs_file, "", [FILE_PATH]) + output_path = get_from_cache("mock://huggingface.co") + with open(output_path) as f: + output_file_content = f.read() + assert output_file_content == FILE_CONTENT + + @patch("datasets.config.HF_DATASETS_OFFLINE", True) def test_cached_path_offline(): with pytest.raises(OfflineModeIsEnabled):