Skip to content

Commit

Permalink
Support cloud storage in load_dataset via fsspec (#5580)
Browse files Browse the repository at this point in the history
* support cloud storage in load_dataset via fsspec

* fsspec get uses tqdm, tries to handle additional protocols, and computes pseudo etag from head response

* Update setup.py

* add test

* Update setup.py

Co-authored-by: Alvaro Bartolome <[email protected]>

* Update tests/test_file_utils.py

Co-authored-by: Quentin Lhoest <[email protected]>

* add tmpfs and use to test fsspec in get_from_cache

* Update src/datasets/utils/file_utils.py

Co-authored-by: Quentin Lhoest <[email protected]>

* Update src/datasets/utils/file_utils.py

Co-authored-by: Quentin Lhoest <[email protected]>

* remove comment

---------

Co-authored-by: Alvaro Bartolome <[email protected]>
Co-authored-by: Quentin Lhoest <[email protected]>
  • Loading branch information
3 people authored Mar 11, 2023
1 parent e502117 commit 3e62699
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 4 deletions.
38 changes: 36 additions & 2 deletions src/datasets/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from typing import List, Optional, Type, TypeVar, Union
from urllib.parse import urljoin, urlparse

import fsspec
import huggingface_hub
import requests
from huggingface_hub import HfFolder
Expand Down Expand Up @@ -327,6 +328,28 @@ def _request_with_retry(
return response


def fsspec_head(url, timeout=10.0):
_raise_if_offline_mode_is_enabled(f"Tried to reach {url}")
fs, _, paths = fsspec.get_fs_token_paths(url, storage_options={"requests_timeout": timeout})
if len(paths) > 1:
raise ValueError(f"HEAD can be called with at most one path but was called with {paths}")
return fs.info(paths[0])


def fsspec_get(url, temp_file, timeout=10.0, desc=None):
_raise_if_offline_mode_is_enabled(f"Tried to reach {url}")
fs, _, paths = fsspec.get_fs_token_paths(url, storage_options={"requests_timeout": timeout})
if len(paths) > 1:
raise ValueError(f"GET can be called with at most one path but was called with {paths}")
callback = fsspec.callbacks.TqdmCallback(
tqdm_kwargs={
"desc": desc or "Downloading",
"disable": not logging.is_progress_bar_enabled(),
}
)
fs.get_file(paths[0], temp_file.name, callback=callback)


def ftp_head(url, timeout=10.0):
_raise_if_offline_mode_is_enabled(f"Tried to reach {url}")
try:
Expand Down Expand Up @@ -400,6 +423,8 @@ def http_head(


def request_etag(url: str, use_auth_token: Optional[Union[str, bool]] = None) -> Optional[str]:
if urlparse(url).scheme not in ("http", "https"):
return None
headers = get_authentication_headers_for_url(url, use_auth_token=use_auth_token)
response = http_head(url, headers=headers, max_retries=3)
response.raise_for_status()
Expand Down Expand Up @@ -453,6 +478,7 @@ def get_from_cache(
cookies = None
etag = None
head_error = None
scheme = None

# Try a first time to file the file on the local file system without eTag (None)
# if we don't ask for 'force_download' then we spare a request
Expand All @@ -469,8 +495,14 @@ def get_from_cache(

# We don't have the file locally or we need an eTag
if not local_files_only:
if url.startswith("ftp://"):
scheme = urlparse(url).scheme
if scheme == "ftp":
connected = ftp_head(url)
elif scheme not in ("http", "https"):
response = fsspec_head(url)
# s3fs uses "ETag", gcsfs uses "etag"
etag = (response.get("ETag", None) or response.get("etag", None)) if use_etag else None
connected = True
try:
response = http_head(
url,
Expand Down Expand Up @@ -569,8 +601,10 @@ def _resumable_file_manager():
logger.info(f"{url} not found in cache or force_download set to True, downloading to {temp_file.name}")

# GET file object
if url.startswith("ftp://"):
if scheme == "ftp":
ftp_get(url, temp_file)
elif scheme not in ("http", "https"):
fsspec_get(url, temp_file, desc=download_desc)
else:
http_get(
url,
Expand Down
25 changes: 25 additions & 0 deletions tests/fixtures/fsspec.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import posixpath
from pathlib import Path
from unittest.mock import patch

import fsspec
import pytest
Expand Down Expand Up @@ -73,10 +74,27 @@ def _strip_protocol(cls, path):
return path


class TmpDirFileSystem(MockFileSystem):
protocol = "tmp"
tmp_dir = None

def __init__(self, *args, **kwargs):
assert self.tmp_dir is not None, "TmpDirFileSystem.tmp_dir is not set"
super().__init__(*args, **kwargs, local_root_dir=self.tmp_dir, auto_mkdir=True)

@classmethod
def _strip_protocol(cls, path):
path = stringify_path(path)
if path.startswith("tmp://"):
path = path[6:]
return path


@pytest.fixture
def mock_fsspec():
original_registry = fsspec.registry.copy()
fsspec.register_implementation("mock", MockFileSystem)
fsspec.register_implementation("tmp", TmpDirFileSystem)
yield
fsspec.registry = original_registry

Expand All @@ -85,3 +103,10 @@ def mock_fsspec():
def mockfs(tmp_path_factory, mock_fsspec):
local_fs_dir = tmp_path_factory.mktemp("mockfs")
return MockFileSystem(local_root_dir=local_fs_dir, auto_mkdir=True)


@pytest.fixture
def tmpfs(tmp_path_factory, mock_fsspec):
tmp_fs_dir = tmp_path_factory.mktemp("tmpfs")
with patch.object(TmpDirFileSystem, "tmp_dir", tmp_fs_dir):
yield TmpDirFileSystem()
39 changes: 37 additions & 2 deletions tests/test_file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,42 @@
import zstandard as zstd

from datasets.download.download_config import DownloadConfig
from datasets.utils.file_utils import OfflineModeIsEnabled, cached_path, ftp_get, ftp_head, http_get, http_head
from datasets.utils.file_utils import (
OfflineModeIsEnabled,
cached_path,
fsspec_get,
fsspec_head,
ftp_get,
ftp_head,
get_from_cache,
http_get,
http_head,
)


FILE_CONTENT = """\
Text data.
Second line of data."""

FILE_PATH = "file"


@pytest.fixture(scope="session")
def zstd_path(tmp_path_factory):
path = tmp_path_factory.mktemp("data") / "file.zstd"
path = tmp_path_factory.mktemp("data") / (FILE_PATH + ".zstd")
data = bytes(FILE_CONTENT, "utf-8")
with zstd.open(path, "wb") as f:
f.write(data)
return path


@pytest.fixture
def tmpfs_file(tmpfs):
with open(os.path.join(tmpfs.local_root_dir, FILE_PATH), "w") as f:
f.write(FILE_CONTENT)
return FILE_PATH


@pytest.mark.parametrize("compression_format", ["gzip", "xz", "zstd"])
def test_cached_path_extract(compression_format, gz_file, xz_file, zstd_path, tmp_path, text_file):
input_paths = {"gzip": gz_file, "xz": xz_file, "zstd": zstd_path}
Expand Down Expand Up @@ -80,6 +99,13 @@ def test_cached_path_missing_local(tmp_path):
cached_path(missing_file)


def test_get_from_cache_fsspec(tmpfs_file):
output_path = get_from_cache(f"tmp://{tmpfs_file}")
with open(output_path) as f:
output_file_content = f.read()
assert output_file_content == FILE_CONTENT


@patch("datasets.config.HF_DATASETS_OFFLINE", True)
def test_cached_path_offline():
with pytest.raises(OfflineModeIsEnabled):
Expand All @@ -102,3 +128,12 @@ def test_ftp_offline(tmp_path_factory):
ftp_get("ftp://huggingface.co", temp_file=filename)
with pytest.raises(OfflineModeIsEnabled):
ftp_head("ftp://huggingface.co")


@patch("datasets.config.HF_DATASETS_OFFLINE", True)
def test_fsspec_offline(tmp_path_factory):
filename = tmp_path_factory.mktemp("data") / "file.html"
with pytest.raises(OfflineModeIsEnabled):
fsspec_get("s3://huggingface.co", temp_file=filename)
with pytest.raises(OfflineModeIsEnabled):
fsspec_head("s3://huggingface.co")

0 comments on commit 3e62699

Please sign in to comment.