From 2fceb77153d70514e57d2ee1e0fa55d2016a62ee Mon Sep 17 00:00:00 2001 From: Dean Wyatte Date: Sun, 26 Feb 2023 16:13:52 -0700 Subject: [PATCH] support cloud storage in load_dataset via fsspec --- setup.py | 1 + src/datasets/utils/file_utils.py | 30 ++++++++++++++++++++++++++++-- tests/test_file_utils.py | 20 +++++++++++++++++++- 3 files changed, 48 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 6674625158eb..53759e8b4184 100644 --- a/setup.py +++ b/setup.py @@ -230,6 +230,7 @@ "tensorflow_gpu": ["tensorflow-gpu>=2.2.0,!=2.6.0,!=2.6.1"], "torch": ["torch"], "jax": ["jax>=0.2.8,!=0.3.2,<=0.3.25", "jaxlib>=0.1.65,<=0.3.25"], + "gcsfs": ["gcsfs"], "s3": ["s3fs"], "streaming": [], # for backward compatibility "dev": TESTS_REQUIRE + QUALITY_REQUIRE + DOCS_REQUIRE, diff --git a/src/datasets/utils/file_utils.py b/src/datasets/utils/file_utils.py index 48242e8a4a1d..26de041fd8cb 100644 --- a/src/datasets/utils/file_utils.py +++ b/src/datasets/utils/file_utils.py @@ -22,6 +22,7 @@ from typing import List, Optional, Type, TypeVar, Union from urllib.parse import urljoin, urlparse +import fsspec import huggingface_hub import requests from huggingface_hub import HfFolder @@ -327,6 +328,23 @@ def _request_with_retry( return response +def fsspec_head(url, timeout=10.0): + _raise_if_offline_mode_is_enabled(f"Tried to reach {url}") + try: + fsspec.filesystem(urlparse(url).scheme).info(url, timeout=timeout) + except Exception: + return False + return True + + +def fsspec_get(url, temp_file, timeout=10.0): + _raise_if_offline_mode_is_enabled(f"Tried to reach {url}") + try: + fsspec.filesystem(urlparse(url).scheme).get(url, temp_file, timeout=timeout) + except fsspec.FSTimeoutError as e: + raise ConnectionError(e) from None + + def ftp_head(url, timeout=10.0): _raise_if_offline_mode_is_enabled(f"Tried to reach {url}") try: @@ -400,6 +418,8 @@ def http_head( def request_etag(url: str, use_auth_token: Optional[Union[str, bool]] = None) -> Optional[str]: + if urlparse(url).scheme not in ("http", "https"): + return None headers = get_authentication_headers_for_url(url, use_auth_token=use_auth_token) response = http_head(url, headers=headers, max_retries=3) response.raise_for_status() @@ -453,6 +473,7 @@ def get_from_cache( cookies = None etag = None head_error = None + scheme = None # Try a first time to file the file on the local file system without eTag (None) # if we don't ask for 'force_download' then we spare a request @@ -469,8 +490,11 @@ def get_from_cache( # We don't have the file locally or we need an eTag if not local_files_only: - if url.startswith("ftp://"): + scheme = urlparse(url).scheme + if scheme == "ftp": connected = ftp_head(url) + elif scheme in ("s3", "gs"): + connected = fsspec_head(url) try: response = http_head( url, @@ -569,8 +593,10 @@ def _resumable_file_manager(): logger.info(f"{url} not found in cache or force_download set to True, downloading to {temp_file.name}") # GET file object - if url.startswith("ftp://"): + if scheme == "ftp": ftp_get(url, temp_file) + elif scheme in ("gs", "s3"): + fsspec_get(url, temp_file) else: http_get( url, diff --git a/tests/test_file_utils.py b/tests/test_file_utils.py index 04b37e5eb403..09f3eeb4f7df 100644 --- a/tests/test_file_utils.py +++ b/tests/test_file_utils.py @@ -6,7 +6,16 @@ import zstandard as zstd from datasets.download.download_config import DownloadConfig -from datasets.utils.file_utils import OfflineModeIsEnabled, cached_path, ftp_get, ftp_head, http_get, http_head +from datasets.utils.file_utils import ( + OfflineModeIsEnabled, + cached_path, + fsspec_get, + fsspec_head, + ftp_get, + ftp_head, + http_get, + http_head, +) FILE_CONTENT = """\ @@ -102,3 +111,12 @@ def test_ftp_offline(tmp_path_factory): ftp_get("ftp://huggingface.co", temp_file=filename) with pytest.raises(OfflineModeIsEnabled): ftp_head("ftp://huggingface.co") + + +@patch("datasets.config.HF_DATASETS_OFFLINE", True) +def test_fsspec_offline(tmp_path_factory): + filename = tmp_path_factory.mktemp("data") / "file.html" + with pytest.raises(OfflineModeIsEnabled): + fsspec_get("s3://huggingface.co", temp_file=filename) + with pytest.raises(OfflineModeIsEnabled): + fsspec_head("s3://huggingface.co")