Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for storage_options for load_dataset API #5919

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5316,7 +5316,12 @@ def path_in_repo(_index, shard):
for data_file in data_files
if data_file.startswith(f"data/{split}-") and data_file not in shards_path_in_repo
]
deleted_size = sum(xgetsize(hf_hub_url(repo_id, data_file), token=token) for data_file in data_files_to_delete)

download_config = DownloadConfig(token=token)
deleted_size = sum(
xgetsize(hf_hub_url(repo_id, data_file), download_config=download_config)
for data_file in data_files_to_delete
)

def delete_file(file):
api.delete_file(file, repo_id=repo_id, token=token, repo_type="dataset", revision=branch)
Expand Down
249 changes: 104 additions & 145 deletions src/datasets/download/streaming_download_manager.py

Large diffs are not rendered by default.

9 changes: 6 additions & 3 deletions src/datasets/features/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pyarrow as pa

from .. import config
from ..download.download_config import DownloadConfig
from ..download.streaming_download_manager import xopen, xsplitext
from ..table import array_cast
from ..utils.py_utils import no_op_if_value_is_null, string_to_dict
Expand Down Expand Up @@ -172,13 +173,15 @@ def decode_example(
if file is None:
token_per_repo_id = token_per_repo_id or {}
source_url = path.split("::")[-1]
repo_id = None
try:
repo_id = string_to_dict(source_url, config.HUB_DATASETS_URL)["repo_id"]
token = token_per_repo_id[repo_id]
token_per_repo_id[repo_id]
except (ValueError, KeyError):
token = None
pass
Comment on lines -177 to +181
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you should revert this no ? we need to get the token from token_per_repo_id

Copy link
Contributor Author

@janineguo janineguo Jul 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need, I can change the setting of DownloadConfig.


with xopen(path, "rb", token=token) as f:
download_config = DownloadConfig(token=None if repo_id is None else token_per_repo_id[repo_id])
with xopen(path, "rb", download_config=download_config) as f:
array, sampling_rate = sf.read(f)

else:
Expand Down
10 changes: 7 additions & 3 deletions src/datasets/features/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pyarrow as pa

from .. import config
from ..download.download_config import DownloadConfig
from ..download.streaming_download_manager import xopen
from ..table import array_cast
from ..utils.file_utils import is_local_path
Expand Down Expand Up @@ -167,10 +168,13 @@ def decode_example(self, value: dict, token_per_repo_id=None) -> "PIL.Image.Imag
source_url = path.split("::")[-1]
try:
repo_id = string_to_dict(source_url, config.HUB_DATASETS_URL)["repo_id"]
token = token_per_repo_id.get(repo_id)
token_per_repo_id.get(repo_id)
download_config = DownloadConfig(token=token_per_repo_id.get(repo_id))
except ValueError:
token = None
with xopen(path, "rb", token=token) as f:
use_auth_token = None
download_config = DownloadConfig(token=use_auth_token)

with xopen(path, "rb", download_config=download_config) as f:
bytes_ = BytesIO(f.read())
image = PIL.Image.open(bytes_)
else:
Expand Down
16 changes: 7 additions & 9 deletions src/datasets/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def _create_importable_file(


def infer_module_for_data_files(
data_files_list: DataFilesList, token: Optional[Union[bool, str]] = None
data_files_list: DataFilesList, download_config: Optional[DownloadConfig] = None
) -> Optional[Tuple[str, str]]:
"""Infer module (and builder kwargs) from list of data files.

Expand All @@ -352,8 +352,7 @@ def infer_module_for_data_files(

Args:
data_files_list (DataFilesList): List of data files.
token (bool or str, optional): Whether to use token or token to authenticate on the Hugging Face Hub
for private remote files.
        download_config (DownloadConfig, optional): Download configuration, mainly used to pass a token or storage_options to support different platforms and auth types.

Returns:
tuple[str, str]: Tuple with
Expand All @@ -376,19 +375,18 @@ def sort_key(ext_count: Tuple[str, int]) -> Tuple[int, bool]:
if ext in _EXTENSION_TO_MODULE:
return _EXTENSION_TO_MODULE[ext]
elif ext == ".zip":
return infer_module_for_data_files_in_archives(data_files_list, token=token)
return infer_module_for_data_files_in_archives(data_files_list, download_config=download_config)
return None, {}


def infer_module_for_data_files_in_archives(
data_files_list: DataFilesList, token: Optional[Union[bool, str]]
data_files_list: DataFilesList, download_config: Optional[DownloadConfig]
) -> Optional[Tuple[str, str]]:
"""Infer module (and builder kwargs) from list of archive data files.

Args:
data_files_list (DataFilesList): List of data files.
token (bool or str, optional): Whether to use token or token to authenticate on the Hugging Face Hub
for private remote files.
        download_config (DownloadConfig, optional): Download configuration, mainly used to pass a token or storage_options to support different platforms and auth types.

Returns:
tuple[str, str]: Tuple with
Expand All @@ -405,7 +403,7 @@ def infer_module_for_data_files_in_archives(
extracted = xjoin(StreamingDownloadManager().extract(filepath), "**")
archived_files += [
f.split("::")[0]
for f in xglob(extracted, recursive=True, token=token)[
for f in xglob(extracted, recursive=True, download_config=download_config)[
: config.ARCHIVED_DATA_FILES_MAX_NUMBER_FOR_MODULE_INFERENCE
]
]
Expand Down Expand Up @@ -785,7 +783,7 @@ def get_module(self) -> DatasetModule:
allowed_extensions=ALL_ALLOWED_EXTENSIONS,
)
split_modules = {
split: infer_module_for_data_files(data_files_list, token=self.download_config.token)
split: infer_module_for_data_files(data_files_list, download_config=self.download_config)
for split, data_files_list in data_files.items()
}
module_name, builder_kwargs = next(iter(split_modules.values()))
Expand Down
17 changes: 9 additions & 8 deletions src/datasets/streaming.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import importlib
import inspect
from functools import wraps
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Optional

from .download.download_config import DownloadConfig
from .download.streaming_download_manager import (
xbasename,
xdirname,
Expand Down Expand Up @@ -39,7 +40,7 @@
from .builder import DatasetBuilder


def extend_module_for_streaming(module_path, token: Optional[Union[str, bool]] = None):
def extend_module_for_streaming(module_path, download_config: Optional[DownloadConfig] = None):
"""Extend the module to support streaming.

We patch some functions in the module to use `fsspec` to support data streaming:
Expand All @@ -55,8 +56,7 @@ def extend_module_for_streaming(module_path, token: Optional[Union[str, bool]] =

Args:
module_path: Path to the module to be extended.
token (``str`` or :obj:`bool`, optional): Optional string or boolean to use as Bearer token for remote files on the Datasets Hub.
If True, or not specified, will get token from `"~/.huggingface"`.
        download_config (DownloadConfig, optional): Download configuration, mainly used to pass a token or storage_options to support different platforms and auth types.
"""

module = importlib.import_module(module_path)
Expand All @@ -68,7 +68,7 @@ def extend_module_for_streaming(module_path, token: Optional[Union[str, bool]] =
def wrap_auth(function):
@wraps(function)
def wrapper(*args, **kwargs):
return function(*args, token=token, **kwargs)
return function(*args, download_config=download_config, **kwargs)

wrapper._decorator_name_ = "wrap_auth"
return wrapper
Expand Down Expand Up @@ -109,14 +109,15 @@ def extend_dataset_builder_for_streaming(builder: "DatasetBuilder"):
builder (:class:`DatasetBuilder`): Dataset builder instance.
"""
# this extends the open and os.path.join functions for data streaming
extend_module_for_streaming(builder.__module__, token=builder.token)
download_config = DownloadConfig(storage_options=builder.storage_options, token=builder.use_auth_token)
extend_module_for_streaming(builder.__module__, download_config=download_config)
# if needed, we also have to extend additional internal imports (like wmt14 -> wmt_utils)
if not builder.__module__.startswith("datasets."): # check that it's not a packaged builder like csv
for imports in get_imports(inspect.getfile(builder.__class__)):
if imports[0] == "internal":
internal_import_name = imports[1]
internal_module_name = ".".join(builder.__module__.split(".")[:-1] + [internal_import_name])
extend_module_for_streaming(internal_module_name, token=builder.token)
extend_module_for_streaming(internal_module_name, download_config=download_config)

# builders can inherit from other builders that might use streaming functionality
# (for example, ImageFolder and AudioFolder inherit from FolderBuilder which implements examples generation)
Expand All @@ -129,4 +130,4 @@ def extend_dataset_builder_for_streaming(builder: "DatasetBuilder"):
if issubclass(cls, DatasetBuilder) and cls.__module__ != DatasetBuilder.__module__
] # check it's not a standard builder from datasets.builder
for module in parent_builder_modules:
extend_module_for_streaming(module, token=builder.token)
extend_module_for_streaming(module, download_config=download_config)
47 changes: 28 additions & 19 deletions tests/test_streaming_download_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from fsspec.registry import _registry as _fsspec_registry
from fsspec.spec import AbstractBufferedFile, AbstractFileSystem

from datasets.download.download_config import DownloadConfig
from datasets.download.streaming_download_manager import (
StreamingDownloadManager,
_get_extraction_protocol,
Expand Down Expand Up @@ -236,8 +237,9 @@ def test_xexists(input_path, exists, tmp_path, mock_fsspec):
@pytest.mark.integration
def test_xexists_private(hf_private_dataset_repo_txt_data, hf_token):
root_url = hf_hub_url(hf_private_dataset_repo_txt_data, "")
assert xexists(root_url + "data/text_data.txt", token=hf_token)
assert not xexists(root_url + "file_that_doesnt_exist.txt", token=hf_token)
download_config = DownloadConfig(token=hf_token)
assert xexists(root_url + "data/text_data.txt", download_config=download_config)
assert not xexists(root_url + "file_that_doesnt_exist.txt", download_config=download_config)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -320,12 +322,13 @@ def test_xlistdir(input_path, expected_paths, tmp_path, mock_fsspec):
@pytest.mark.integration
def test_xlistdir_private(hf_private_dataset_repo_zipped_txt_data, hf_token):
root_url = hf_hub_url(hf_private_dataset_repo_zipped_txt_data, "data.zip")
assert len(xlistdir("zip://::" + root_url, token=hf_token)) == 1
assert len(xlistdir("zip://main_dir::" + root_url, token=hf_token)) == 2
download_config = DownloadConfig(token=hf_token)
assert len(xlistdir("zip://::" + root_url, download_config=download_config)) == 1
assert len(xlistdir("zip://main_dir::" + root_url, download_config=download_config)) == 2
with pytest.raises(FileNotFoundError):
xlistdir("zip://qwertyuiop::" + root_url, token=hf_token)
xlistdir("zip://qwertyuiop::" + root_url, download_config=download_config)
with pytest.raises(NotImplementedError):
xlistdir(root_url, token=hf_token)
xlistdir(root_url, download_config=download_config)


@pytest.mark.parametrize(
Expand All @@ -348,11 +351,13 @@ def test_xisdir(input_path, isdir, tmp_path, mock_fsspec):
@pytest.mark.integration
def test_xisdir_private(hf_private_dataset_repo_zipped_txt_data, hf_token):
root_url = hf_hub_url(hf_private_dataset_repo_zipped_txt_data, "data.zip")
assert xisdir("zip://::" + root_url, token=hf_token) is True
assert xisdir("zip://main_dir::" + root_url, token=hf_token) is True
assert xisdir("zip://qwertyuiop::" + root_url, token=hf_token) is False

download_config = DownloadConfig(token=hf_token)
assert xisdir("zip://::" + root_url, download_config=download_config) is True
assert xisdir("zip://main_dir::" + root_url, download_config=download_config) is True
assert xisdir("zip://qwertyuiop::" + root_url, download_config=download_config) is False
with pytest.raises(NotImplementedError):
xisdir(root_url, token=hf_token)
xisdir(root_url, download_config=download_config)


@pytest.mark.parametrize(
Expand All @@ -374,8 +379,9 @@ def test_xisfile(input_path, isfile, tmp_path, mock_fsspec):
@pytest.mark.integration
def test_xisfile_private(hf_private_dataset_repo_txt_data, hf_token):
root_url = hf_hub_url(hf_private_dataset_repo_txt_data, "")
assert xisfile(root_url + "data/text_data.txt", token=hf_token) is True
assert xisfile(root_url + "qwertyuiop", token=hf_token) is False
download_config = DownloadConfig(token=hf_token)
assert xisfile(root_url + "data/text_data.txt", download_config=download_config) is True
assert xisfile(root_url + "qwertyuiop", download_config=download_config) is False


@pytest.mark.parametrize(
Expand All @@ -397,9 +403,10 @@ def test_xgetsize(input_path, size, tmp_path, mock_fsspec):
@pytest.mark.integration
def test_xgetsize_private(hf_private_dataset_repo_txt_data, hf_token):
root_url = hf_hub_url(hf_private_dataset_repo_txt_data, "")
assert xgetsize(root_url + "data/text_data.txt", token=hf_token) == 39
download_config = DownloadConfig(token=hf_token)
assert xgetsize(root_url + "data/text_data.txt", download_config=download_config) == 39
with pytest.raises(FileNotFoundError):
xgetsize(root_url + "qwertyuiop", token=hf_token)
xgetsize(root_url + "qwertyuiop", download_config=download_config)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -440,8 +447,9 @@ def test_xglob(input_path, expected_paths, tmp_path, mock_fsspec):
@pytest.mark.integration
def test_xglob_private(hf_private_dataset_repo_zipped_txt_data, hf_token):
root_url = hf_hub_url(hf_private_dataset_repo_zipped_txt_data, "data.zip")
assert len(xglob("zip://**::" + root_url, token=hf_token)) == 3
assert len(xglob("zip://qwertyuiop/*::" + root_url, token=hf_token)) == 0
download_config = DownloadConfig(token=hf_token)
assert len(xglob("zip://**::" + root_url, download_config=download_config)) == 3
assert len(xglob("zip://qwertyuiop/*::" + root_url, download_config=download_config)) == 0


@pytest.mark.parametrize(
Expand Down Expand Up @@ -478,9 +486,10 @@ def test_xwalk(input_path, expected_outputs, tmp_path, mock_fsspec):
@pytest.mark.integration
def test_xwalk_private(hf_private_dataset_repo_zipped_txt_data, hf_token):
root_url = hf_hub_url(hf_private_dataset_repo_zipped_txt_data, "data.zip")
assert len(list(xwalk("zip://::" + root_url, token=hf_token))) == 2
assert len(list(xwalk("zip://main_dir::" + root_url, token=hf_token))) == 1
assert len(list(xwalk("zip://qwertyuiop::" + root_url, token=hf_token))) == 0
download_config = DownloadConfig(token=hf_token)
assert len(list(xwalk("zip://::" + root_url, download_config=download_config))) == 2
assert len(list(xwalk("zip://main_dir::" + root_url, download_config=download_config))) == 1
assert len(list(xwalk("zip://qwertyuiop::" + root_url, download_config=download_config))) == 0


@pytest.mark.parametrize(
Expand Down