Skip to content

Commit 3e62699

Browse files
dwyattealvarobarttlhoestq
authored
Support cloud storage in load_dataset via fsspec (#5580)
* support cloud storage in load_dataset via fsspec * fsspec get uses tqdm, tries to handle additional protocols, and computes pseudo etag from head response * Update setup.py * add test * Update setup.py Co-authored-by: Alvaro Bartolome <[email protected]> * Update tests/test_file_utils.py Co-authored-by: Quentin Lhoest <[email protected]> * add tmpfs and use to test fsspec in get_from_cache * Update src/datasets/utils/file_utils.py Co-authored-by: Quentin Lhoest <[email protected]> * Update src/datasets/utils/file_utils.py Co-authored-by: Quentin Lhoest <[email protected]> * remove comment --------- Co-authored-by: Alvaro Bartolome <[email protected]> Co-authored-by: Quentin Lhoest <[email protected]>
1 parent e502117 commit 3e62699

File tree

3 files changed

+98
-4
lines changed

3 files changed

+98
-4
lines changed

src/datasets/utils/file_utils.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from typing import List, Optional, Type, TypeVar, Union
2323
from urllib.parse import urljoin, urlparse
2424

25+
import fsspec
2526
import huggingface_hub
2627
import requests
2728
from huggingface_hub import HfFolder
@@ -327,6 +328,28 @@ def _request_with_retry(
327328
return response
328329

329330

331+
def fsspec_head(url, timeout=10.0):
332+
_raise_if_offline_mode_is_enabled(f"Tried to reach {url}")
333+
fs, _, paths = fsspec.get_fs_token_paths(url, storage_options={"requests_timeout": timeout})
334+
if len(paths) > 1:
335+
raise ValueError(f"HEAD can be called with at most one path but was called with {paths}")
336+
return fs.info(paths[0])
337+
338+
339+
def fsspec_get(url, temp_file, timeout=10.0, desc=None):
340+
_raise_if_offline_mode_is_enabled(f"Tried to reach {url}")
341+
fs, _, paths = fsspec.get_fs_token_paths(url, storage_options={"requests_timeout": timeout})
342+
if len(paths) > 1:
343+
raise ValueError(f"GET can be called with at most one path but was called with {paths}")
344+
callback = fsspec.callbacks.TqdmCallback(
345+
tqdm_kwargs={
346+
"desc": desc or "Downloading",
347+
"disable": not logging.is_progress_bar_enabled(),
348+
}
349+
)
350+
fs.get_file(paths[0], temp_file.name, callback=callback)
351+
352+
330353
def ftp_head(url, timeout=10.0):
331354
_raise_if_offline_mode_is_enabled(f"Tried to reach {url}")
332355
try:
@@ -400,6 +423,8 @@ def http_head(
400423

401424

402425
def request_etag(url: str, use_auth_token: Optional[Union[str, bool]] = None) -> Optional[str]:
426+
if urlparse(url).scheme not in ("http", "https"):
427+
return None
403428
headers = get_authentication_headers_for_url(url, use_auth_token=use_auth_token)
404429
response = http_head(url, headers=headers, max_retries=3)
405430
response.raise_for_status()
@@ -453,6 +478,7 @@ def get_from_cache(
453478
cookies = None
454479
etag = None
455480
head_error = None
481+
scheme = None
456482

457483
# Try a first time to file the file on the local file system without eTag (None)
458484
# if we don't ask for 'force_download' then we spare a request
@@ -469,8 +495,14 @@ def get_from_cache(
469495

470496
# We don't have the file locally or we need an eTag
471497
if not local_files_only:
472-
if url.startswith("ftp://"):
498+
scheme = urlparse(url).scheme
499+
if scheme == "ftp":
473500
connected = ftp_head(url)
501+
elif scheme not in ("http", "https"):
502+
response = fsspec_head(url)
503+
# s3fs uses "ETag", gcsfs uses "etag"
504+
etag = (response.get("ETag", None) or response.get("etag", None)) if use_etag else None
505+
connected = True
474506
try:
475507
response = http_head(
476508
url,
@@ -569,8 +601,10 @@ def _resumable_file_manager():
569601
logger.info(f"{url} not found in cache or force_download set to True, downloading to {temp_file.name}")
570602

571603
# GET file object
572-
if url.startswith("ftp://"):
604+
if scheme == "ftp":
573605
ftp_get(url, temp_file)
606+
elif scheme not in ("http", "https"):
607+
fsspec_get(url, temp_file, desc=download_desc)
574608
else:
575609
http_get(
576610
url,

tests/fixtures/fsspec.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import posixpath
22
from pathlib import Path
3+
from unittest.mock import patch
34

45
import fsspec
56
import pytest
@@ -73,10 +74,27 @@ def _strip_protocol(cls, path):
7374
return path
7475

7576

77+
class TmpDirFileSystem(MockFileSystem):
78+
protocol = "tmp"
79+
tmp_dir = None
80+
81+
def __init__(self, *args, **kwargs):
82+
assert self.tmp_dir is not None, "TmpDirFileSystem.tmp_dir is not set"
83+
super().__init__(*args, **kwargs, local_root_dir=self.tmp_dir, auto_mkdir=True)
84+
85+
@classmethod
86+
def _strip_protocol(cls, path):
87+
path = stringify_path(path)
88+
if path.startswith("tmp://"):
89+
path = path[6:]
90+
return path
91+
92+
7693
@pytest.fixture
7794
def mock_fsspec():
7895
original_registry = fsspec.registry.copy()
7996
fsspec.register_implementation("mock", MockFileSystem)
97+
fsspec.register_implementation("tmp", TmpDirFileSystem)
8098
yield
8199
fsspec.registry = original_registry
82100

@@ -85,3 +103,10 @@ def mock_fsspec():
85103
def mockfs(tmp_path_factory, mock_fsspec):
86104
local_fs_dir = tmp_path_factory.mktemp("mockfs")
87105
return MockFileSystem(local_root_dir=local_fs_dir, auto_mkdir=True)
106+
107+
108+
@pytest.fixture
109+
def tmpfs(tmp_path_factory, mock_fsspec):
110+
tmp_fs_dir = tmp_path_factory.mktemp("tmpfs")
111+
with patch.object(TmpDirFileSystem, "tmp_dir", tmp_fs_dir):
112+
yield TmpDirFileSystem()

tests/test_file_utils.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,42 @@
66
import zstandard as zstd
77

88
from datasets.download.download_config import DownloadConfig
9-
from datasets.utils.file_utils import OfflineModeIsEnabled, cached_path, ftp_get, ftp_head, http_get, http_head
9+
from datasets.utils.file_utils import (
10+
OfflineModeIsEnabled,
11+
cached_path,
12+
fsspec_get,
13+
fsspec_head,
14+
ftp_get,
15+
ftp_head,
16+
get_from_cache,
17+
http_get,
18+
http_head,
19+
)
1020

1121

1222
FILE_CONTENT = """\
1323
Text data.
1424
Second line of data."""
1525

26+
FILE_PATH = "file"
27+
1628

1729
@pytest.fixture(scope="session")
1830
def zstd_path(tmp_path_factory):
19-
path = tmp_path_factory.mktemp("data") / "file.zstd"
31+
path = tmp_path_factory.mktemp("data") / (FILE_PATH + ".zstd")
2032
data = bytes(FILE_CONTENT, "utf-8")
2133
with zstd.open(path, "wb") as f:
2234
f.write(data)
2335
return path
2436

2537

38+
@pytest.fixture
39+
def tmpfs_file(tmpfs):
40+
with open(os.path.join(tmpfs.local_root_dir, FILE_PATH), "w") as f:
41+
f.write(FILE_CONTENT)
42+
return FILE_PATH
43+
44+
2645
@pytest.mark.parametrize("compression_format", ["gzip", "xz", "zstd"])
2746
def test_cached_path_extract(compression_format, gz_file, xz_file, zstd_path, tmp_path, text_file):
2847
input_paths = {"gzip": gz_file, "xz": xz_file, "zstd": zstd_path}
@@ -80,6 +99,13 @@ def test_cached_path_missing_local(tmp_path):
8099
cached_path(missing_file)
81100

82101

102+
def test_get_from_cache_fsspec(tmpfs_file):
103+
output_path = get_from_cache(f"tmp://{tmpfs_file}")
104+
with open(output_path) as f:
105+
output_file_content = f.read()
106+
assert output_file_content == FILE_CONTENT
107+
108+
83109
@patch("datasets.config.HF_DATASETS_OFFLINE", True)
84110
def test_cached_path_offline():
85111
with pytest.raises(OfflineModeIsEnabled):
@@ -102,3 +128,12 @@ def test_ftp_offline(tmp_path_factory):
102128
ftp_get("ftp://huggingface.co", temp_file=filename)
103129
with pytest.raises(OfflineModeIsEnabled):
104130
ftp_head("ftp://huggingface.co")
131+
132+
133+
@patch("datasets.config.HF_DATASETS_OFFLINE", True)
134+
def test_fsspec_offline(tmp_path_factory):
135+
filename = tmp_path_factory.mktemp("data") / "file.html"
136+
with pytest.raises(OfflineModeIsEnabled):
137+
fsspec_get("s3://huggingface.co", temp_file=filename)
138+
with pytest.raises(OfflineModeIsEnabled):
139+
fsspec_head("s3://huggingface.co")

0 commit comments

Comments
 (0)