Skip to content

Commit

Permalink
support cloud storage in load_dataset via fsspec
Browse files Browse the repository at this point in the history
  • Loading branch information
dwyatte committed Feb 27, 2023
1 parent a940972 commit b7f309e
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 3 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@
"tensorflow_gpu": ["tensorflow-gpu>=2.2.0,!=2.6.0,!=2.6.1"],
"torch": ["torch"],
"jax": ["jax>=0.2.8,!=0.3.2,<=0.3.25", "jaxlib>=0.1.65,<=0.3.25"],
"gcsfs": ["gcsfs"],
"s3": ["s3fs"],
"streaming": [], # for backward compatibility
"dev": TESTS_REQUIRE + QUALITY_REQUIRE + DOCS_REQUIRE,
Expand Down
29 changes: 27 additions & 2 deletions src/datasets/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from typing import List, Optional, Type, TypeVar, Union
from urllib.parse import urljoin, urlparse

import fsspec
import huggingface_hub
import requests
from huggingface_hub import HfFolder
Expand Down Expand Up @@ -327,6 +328,23 @@ def _request_with_retry(
return response


def fsspec_head(url, timeout=10.0):
_raise_if_offline_mode_is_enabled(f"Tried to reach {url}")
try:
fsspec.filesystem(urlparse(url).scheme).info(url, timeout=timeout)
except Exception:
return False
return True


def fsspec_get(url, temp_file, timeout=10.0):
_raise_if_offline_mode_is_enabled(f"Tried to reach {url}")
try:
fsspec.filesystem(urlparse(url).scheme).get(url, temp_file, timeout=timeout)
except fsspec.FSTimeoutError as e:
raise ConnectionError(e) from None


def ftp_head(url, timeout=10.0):
_raise_if_offline_mode_is_enabled(f"Tried to reach {url}")
try:
Expand Down Expand Up @@ -400,6 +418,8 @@ def http_head(


def request_etag(url: str, use_auth_token: Optional[Union[str, bool]] = None) -> Optional[str]:
if urlparse(url).scheme not in ("http", "https"):
return None
headers = get_authentication_headers_for_url(url, use_auth_token=use_auth_token)
response = http_head(url, headers=headers, max_retries=3)
response.raise_for_status()
Expand Down Expand Up @@ -469,8 +489,11 @@ def get_from_cache(

# We don't have the file locally or we need an eTag
if not local_files_only:
if url.startswith("ftp://"):
scheme = urlparse(url).scheme
if scheme == "ftp":
connected = ftp_head(url)
elif scheme in ("s3", "gs"):
connected = fsspec_head(url)
try:
response = http_head(
url,
Expand Down Expand Up @@ -569,8 +592,10 @@ def _resumable_file_manager():
logger.info(f"{url} not found in cache or force_download set to True, downloading to {temp_file.name}")

# GET file object
if url.startswith("ftp://"):
if scheme == "ftp":
ftp_get(url, temp_file)
elif scheme in ("gs", "s3"):
fsspec_get(url, temp_file)
else:
http_get(
url,
Expand Down
20 changes: 19 additions & 1 deletion tests/test_file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,16 @@
import zstandard as zstd

from datasets.download.download_config import DownloadConfig
from datasets.utils.file_utils import OfflineModeIsEnabled, cached_path, ftp_get, ftp_head, http_get, http_head
from datasets.utils.file_utils import (
OfflineModeIsEnabled,
cached_path,
fsspec_get,
fsspec_head,
ftp_get,
ftp_head,
http_get,
http_head,
)


FILE_CONTENT = """\
Expand Down Expand Up @@ -102,3 +111,12 @@ def test_ftp_offline(tmp_path_factory):
ftp_get("ftp://huggingface.co", temp_file=filename)
with pytest.raises(OfflineModeIsEnabled):
ftp_head("ftp://huggingface.co")


@patch("datasets.config.HF_DATASETS_OFFLINE", True)
def test_fsspec_offline(tmp_path_factory):
filename = tmp_path_factory.mktemp("data") / "file.html"
with pytest.raises(OfflineModeIsEnabled):
fsspec_get("s3://huggingface.co", temp_file=filename)
with pytest.raises(OfflineModeIsEnabled):
fsspec_head("s3://huggingface.co")

0 comments on commit b7f309e

Please sign in to comment.