Add fsspec-backed downloads to get_from_cache.

- src/datasets/utils/file_utils.py: add fsspec_head() and fsspec_get();
  request_etag() returns None for URLs whose scheme is not http/https;
  get_from_cache() calls the fsspec helpers for URLs whose scheme is not
  ftp, http or https.
- tests/fixtures/fsspec.py: add TmpDirFileSystem (protocol "tmp"), register
  it in mock_fsspec, and add a tmpfs fixture.
- tests/test_file_utils.py: test get_from_cache() over tmp:// and offline
  mode for fsspec_get()/fsspec_head().

diff --git a/src/datasets/utils/file_utils.py b/src/datasets/utils/file_utils.py
index 48242e8a4a1..c2f19b5ce7a 100644
--- a/src/datasets/utils/file_utils.py
+++ b/src/datasets/utils/file_utils.py
@@ -22,6 +22,7 @@
 from typing import List, Optional, Type, TypeVar, Union
 from urllib.parse import urljoin, urlparse
 
+import fsspec
 import huggingface_hub
 import requests
 from huggingface_hub import HfFolder
@@ -327,6 +328,28 @@ def _request_with_retry(
     return response
 
 
+def fsspec_head(url, timeout=10.0):
+    _raise_if_offline_mode_is_enabled(f"Tried to reach {url}")
+    fs, _, paths = fsspec.get_fs_token_paths(url, storage_options={"requests_timeout": timeout})
+    if len(paths) > 1:
+        raise ValueError(f"HEAD can be called with at most one path but was called with {paths}")
+    return fs.info(paths[0])
+
+
+def fsspec_get(url, temp_file, timeout=10.0, desc=None):
+    _raise_if_offline_mode_is_enabled(f"Tried to reach {url}")
+    fs, _, paths = fsspec.get_fs_token_paths(url, storage_options={"requests_timeout": timeout})
+    if len(paths) > 1:
+        raise ValueError(f"GET can be called with at most one path but was called with {paths}")
+    callback = fsspec.callbacks.TqdmCallback(
+        tqdm_kwargs={
+            "desc": desc or "Downloading",
+            "disable": not logging.is_progress_bar_enabled(),
+        }
+    )
+    fs.get_file(paths[0], temp_file.name, callback=callback)
+
+
 def ftp_head(url, timeout=10.0):
     _raise_if_offline_mode_is_enabled(f"Tried to reach {url}")
     try:
@@ -400,6 +423,8 @@ def http_head(
 
 
 def request_etag(url: str, use_auth_token: Optional[Union[str, bool]] = None) -> Optional[str]:
+    if urlparse(url).scheme not in ("http", "https"):
+        return None
     headers = get_authentication_headers_for_url(url, use_auth_token=use_auth_token)
     response = http_head(url, headers=headers, max_retries=3)
     response.raise_for_status()
@@ -453,6 +478,7 @@ def get_from_cache(
     cookies = None
     etag = None
     head_error = None
+    scheme = None
 
     # Try a first time to file the file on the local file system without eTag (None)
     # if we don't ask for 'force_download' then we spare a request
@@ -469,8 +495,14 @@ def get_from_cache(
 
     # We don't have the file locally or we need an eTag
     if not local_files_only:
-        if url.startswith("ftp://"):
+        scheme = urlparse(url).scheme
+        if scheme == "ftp":
             connected = ftp_head(url)
+        elif scheme not in ("http", "https"):
+            response = fsspec_head(url)
+            # s3fs uses "ETag", gcsfs uses "etag"
+            etag = (response.get("ETag", None) or response.get("etag", None)) if use_etag else None
+            connected = True
         try:
             response = http_head(
                 url,
@@ -569,8 +601,10 @@ def _resumable_file_manager():
             logger.info(f"{url} not found in cache or force_download set to True, downloading to {temp_file.name}")
 
             # GET file object
-            if url.startswith("ftp://"):
+            if scheme == "ftp":
                 ftp_get(url, temp_file)
+            elif scheme not in ("http", "https"):
+                fsspec_get(url, temp_file, desc=download_desc)
             else:
                 http_get(
                     url,
diff --git a/tests/fixtures/fsspec.py b/tests/fixtures/fsspec.py
index be49dd0bdeb..8aaa181a77e 100644
--- a/tests/fixtures/fsspec.py
+++ b/tests/fixtures/fsspec.py
@@ -1,5 +1,6 @@
 import posixpath
 from pathlib import Path
+from unittest.mock import patch
 
 import fsspec
 import pytest
@@ -73,10 +74,27 @@ def _strip_protocol(cls, path):
         return path
 
 
+class TmpDirFileSystem(MockFileSystem):
+    protocol = "tmp"
+    tmp_dir = None
+
+    def __init__(self, *args, **kwargs):
+        assert self.tmp_dir is not None, "TmpDirFileSystem.tmp_dir is not set"
+        super().__init__(*args, **kwargs, local_root_dir=self.tmp_dir, auto_mkdir=True)
+
+    @classmethod
+    def _strip_protocol(cls, path):
+        path = stringify_path(path)
+        if path.startswith("tmp://"):
+            path = path[6:]
+        return path
+
+
 @pytest.fixture
 def mock_fsspec():
     original_registry = fsspec.registry.copy()
     fsspec.register_implementation("mock", MockFileSystem)
+    fsspec.register_implementation("tmp", TmpDirFileSystem)
     yield
     fsspec.registry = original_registry
 
@@ -85,3 +103,10 @@ def mock_fsspec():
 def mockfs(tmp_path_factory, mock_fsspec):
     local_fs_dir = tmp_path_factory.mktemp("mockfs")
     return MockFileSystem(local_root_dir=local_fs_dir, auto_mkdir=True)
+
+
+@pytest.fixture
+def tmpfs(tmp_path_factory, mock_fsspec):
+    tmp_fs_dir = tmp_path_factory.mktemp("tmpfs")
+    with patch.object(TmpDirFileSystem, "tmp_dir", tmp_fs_dir):
+        yield TmpDirFileSystem()
diff --git a/tests/test_file_utils.py b/tests/test_file_utils.py
index 04b37e5eb40..a6175c3dd17 100644
--- a/tests/test_file_utils.py
+++ b/tests/test_file_utils.py
@@ -6,23 +6,42 @@
 import zstandard as zstd
 
 from datasets.download.download_config import DownloadConfig
-from datasets.utils.file_utils import OfflineModeIsEnabled, cached_path, ftp_get, ftp_head, http_get, http_head
+from datasets.utils.file_utils import (
+    OfflineModeIsEnabled,
+    cached_path,
+    fsspec_get,
+    fsspec_head,
+    ftp_get,
+    ftp_head,
+    get_from_cache,
+    http_get,
+    http_head,
+)
 
 
 FILE_CONTENT = """\
 Text data.
 Second line of data."""
 
+FILE_PATH = "file"
+
 
 @pytest.fixture(scope="session")
 def zstd_path(tmp_path_factory):
-    path = tmp_path_factory.mktemp("data") / "file.zstd"
+    path = tmp_path_factory.mktemp("data") / (FILE_PATH + ".zstd")
     data = bytes(FILE_CONTENT, "utf-8")
     with zstd.open(path, "wb") as f:
         f.write(data)
     return path
 
 
+@pytest.fixture
+def tmpfs_file(tmpfs):
+    with open(os.path.join(tmpfs.local_root_dir, FILE_PATH), "w") as f:
+        f.write(FILE_CONTENT)
+    return FILE_PATH
+
+
 @pytest.mark.parametrize("compression_format", ["gzip", "xz", "zstd"])
 def test_cached_path_extract(compression_format, gz_file, xz_file, zstd_path, tmp_path, text_file):
     input_paths = {"gzip": gz_file, "xz": xz_file, "zstd": zstd_path}
@@ -80,6 +99,13 @@ def test_cached_path_missing_local(tmp_path):
         cached_path(missing_file)
 
 
+def test_get_from_cache_fsspec(tmpfs_file):
+    output_path = get_from_cache(f"tmp://{tmpfs_file}")
+    with open(output_path) as f:
+        output_file_content = f.read()
+    assert output_file_content == FILE_CONTENT
+
+
 @patch("datasets.config.HF_DATASETS_OFFLINE", True)
 def test_cached_path_offline():
     with pytest.raises(OfflineModeIsEnabled):
@@ -102,3 +128,12 @@ def test_ftp_offline(tmp_path_factory):
         ftp_get("ftp://huggingface.co", temp_file=filename)
     with pytest.raises(OfflineModeIsEnabled):
         ftp_head("ftp://huggingface.co")
+
+
+@patch("datasets.config.HF_DATASETS_OFFLINE", True)
+def test_fsspec_offline(tmp_path_factory):
+    filename = tmp_path_factory.mktemp("data") / "file.html"
+    with pytest.raises(OfflineModeIsEnabled):
+        fsspec_get("s3://huggingface.co", temp_file=filename)
+    with pytest.raises(OfflineModeIsEnabled):
+        fsspec_head("s3://huggingface.co")