Pass down storage options #5673

Merged · 4 commits · Mar 28, 2023
7 changes: 6 additions & 1 deletion src/datasets/builder.py
@@ -251,6 +251,8 @@ class DatasetBuilder:
             `os.path.join(data_dir, "**")` as `data_files`.
             For builders that require manual download, it must be the path to the local directory containing the
             manually downloaded data.
+        storage_options (`dict`, *optional*):
+            Key/value pairs to be passed on to the dataset file-system backend, if any.
         writer_batch_size (`int`, *optional*):
             Batch size used by the ArrowWriter.
             It defines the number of samples that are kept in memory before writing them
@@ -299,6 +301,7 @@ def __init__(
         repo_id: Optional[str] = None,
         data_files: Optional[Union[str, list, dict, DataFilesDict]] = None,
         data_dir: Optional[str] = None,
+        storage_options: Optional[dict] = None,
         writer_batch_size: Optional[int] = None,
         name="deprecated",
         **config_kwargs,
@@ -315,6 +318,7 @@ def __init__(
         self.base_path = base_path
         self.use_auth_token = use_auth_token
         self.repo_id = repo_id
+        self.storage_options = storage_options
         self._writer_batch_size = writer_batch_size or self.DEFAULT_WRITER_BATCH_SIZE

         if data_files is not None and not isinstance(data_files, DataFilesDict):
@@ -778,6 +782,7 @@ def download_and_prepare(
                     use_etag=False,
                     num_proc=num_proc,
                     use_auth_token=use_auth_token,
+                    storage_options=self.storage_options,
                 )  # We don't use etag for data files to speed up the process

             dl_manager = DownloadManager(
@@ -1251,7 +1256,7 @@ def as_streaming_dataset(

         dl_manager = StreamingDownloadManager(
             base_path=base_path or self.base_path,
-            download_config=DownloadConfig(use_auth_token=self.use_auth_token),
+            download_config=DownloadConfig(use_auth_token=self.use_auth_token, storage_options=self.storage_options),
             dataset_name=self.name,
             data_dir=self.config.data_dir,
         )
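Taken together, the `builder.py` hunks store a single dict at construction time and hand it to both download paths: the cached path via `DownloadConfig` in `download_and_prepare()`, and the streaming path via `StreamingDownloadManager` in `as_streaming_dataset()`. A minimal sketch of that flow (an illustration, not the library's code; `BuilderSketch` is a made-up stand-in for `DatasetBuilder`):

```python
# Illustrative sketch only; a trimmed stand-in for DatasetBuilder.
from typing import Optional

from datasets.download import DownloadConfig, DownloadManager, StreamingDownloadManager


class BuilderSketch:
    def __init__(self, storage_options: Optional[dict] = None):
        self.storage_options = storage_options  # stored once at construction...

    def download_and_prepare(self) -> DownloadManager:
        # ...forwarded to the cached-download layer,
        config = DownloadConfig(storage_options=self.storage_options)
        return DownloadManager(download_config=config)

    def as_streaming_dataset(self) -> StreamingDownloadManager:
        # ...and to the streaming download manager alike.
        config = DownloadConfig(storage_options=self.storage_options)
        return StreamingDownloadManager(download_config=config)
```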
3 changes: 3 additions & 0 deletions src/datasets/download/download_config.py
@@ -42,6 +42,8 @@ class DownloadConfig:
         ignore_url_params (`bool`, defaults to `False`):
             Whether to strip all query parameters and fragments from
             the download URL before using it for caching the file.
+        storage_options (`dict`, *optional*):
+            Key/value pairs to be passed on to the dataset file-system backend, if any.
         download_desc (`str`, *optional*):
             A description to be displayed alongside with the progress bar while downloading the files.
     """
@@ -60,6 +62,7 @@ class DownloadConfig:
     max_retries: int = 1
     use_auth_token: Optional[Union[str, bool]] = None
     ignore_url_params: bool = False
+    storage_options: Optional[Dict] = None
     download_desc: Optional[str] = None

     def copy(self) -> "DownloadConfig":
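Since `DownloadConfig` is the object every download call site already receives, adding the field here lets the options flow through with no further plumbing. A small example; the `requests_timeout` value mirrors the default that `file_utils.py` previously hard-coded (see that diff below):

```python
# A DownloadConfig carrying fsspec storage options.
from datasets import DownloadConfig

config = DownloadConfig(storage_options={"requests_timeout": 10.0})

# copy() deep-copies the fields, so per-call tweaks don't leak back.
per_call = config.copy()
per_call.download_desc = "Downloading data files"
assert config.download_desc is None
```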
13 changes: 13 additions & 0 deletions src/datasets/load.py
@@ -1409,6 +1409,7 @@ def load_dataset_builder(
     download_mode: Optional[Union[DownloadMode, str]] = None,
     revision: Optional[Union[str, Version]] = None,
     use_auth_token: Optional[Union[bool, str]] = None,
+    storage_options: Optional[Dict] = None,
     **config_kwargs,
 ) -> DatasetBuilder:
     """Load a dataset builder from the Hugging Face Hub, or a local dataset. A dataset builder can be used to inspect general information that is required to build a dataset (cache directory, config, dataset info, etc.)
@@ -1469,6 +1470,10 @@
         use_auth_token (`str` or `bool`, *optional*):
             Optional string or boolean to use as Bearer token for remote files on the Datasets Hub.
             If `True`, or not specified, will get token from `"~/.huggingface"`.
+        storage_options (`dict`, *optional*, defaults to `None`):
+            **Experimental**. Key/value pairs to be passed on to the dataset file-system backend, if any.
+
+            <Added version="2.11.0"/>
         **config_kwargs (additional keyword arguments):
             Keyword arguments to be passed to the [`BuilderConfig`]
             and used in the [`DatasetBuilder`].
@@ -1524,6 +1529,7 @@
         hash=hash,
         features=features,
         use_auth_token=use_auth_token,
+        storage_options=storage_options,
         **builder_kwargs,
         **config_kwargs,
     )
@@ -1550,6 +1556,7 @@ def load_dataset(
     task: Optional[Union[str, TaskTemplate]] = None,
     streaming: bool = False,
     num_proc: Optional[int] = None,
+    storage_options: Optional[Dict] = None,
     **config_kwargs,
 ) -> Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset]:
     """Load a dataset from the Hugging Face Hub, or a local dataset.
@@ -1671,6 +1678,10 @@
             Multiprocessing is disabled by default.

             <Added version="2.7.0"/>
+        storage_options (`dict`, *optional*, defaults to `None`):
+            **Experimental**. Key/value pairs to be passed on to the dataset file-system backend, if any.
+
+            <Added version="2.11.0"/>
         **config_kwargs (additional keyword arguments):
             Keyword arguments to be passed to the `BuilderConfig`
             and used in the [`DatasetBuilder`].
@@ -1764,6 +1775,7 @@
         download_mode=download_mode,
         revision=revision,
         use_auth_token=use_auth_token,
+        storage_options=storage_options,
         **config_kwargs,
     )

@@ -1782,6 +1794,7 @@
         verification_mode=verification_mode,
         try_from_hf_gcs=try_from_hf_gcs,
         num_proc=num_proc,
+        storage_options=storage_options,
     )

     # Build dataset for splits
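With `load.py` updated, the parameter is user-facing on both `load_dataset` and `load_dataset_builder`. A hedged usage sketch: the bucket, file path, and credentials below are placeholders, and S3 access additionally requires `s3fs` to be installed:

```python
# Usage sketch with a placeholder S3 path and s3fs credentials.
from datasets import load_dataset

dataset = load_dataset(
    "json",
    data_files="s3://my-bucket/data/train.jsonl",  # placeholder, not a real bucket
    streaming=True,
    storage_options={"key": "<aws-access-key-id>", "secret": "<aws-secret-access-key>"},
)
```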
14 changes: 8 additions & 6 deletions src/datasets/utils/file_utils.py
@@ -192,6 +192,7 @@ def cached_path(
             max_retries=download_config.max_retries,
             use_auth_token=download_config.use_auth_token,
             ignore_url_params=download_config.ignore_url_params,
+            storage_options=download_config.storage_options,
             download_desc=download_config.download_desc,
         )
     elif os.path.exists(url_or_filename):
@@ -328,17 +329,17 @@ def _request_with_retry(
     return response


-def fsspec_head(url, timeout=10.0):
+def fsspec_head(url, storage_options=None):
     _raise_if_offline_mode_is_enabled(f"Tried to reach {url}")
-    fs, _, paths = fsspec.get_fs_token_paths(url, storage_options={"requests_timeout": timeout})
+    fs, _, paths = fsspec.get_fs_token_paths(url, storage_options=storage_options)
     if len(paths) > 1:
         raise ValueError(f"HEAD can be called with at most one path but was called with {paths}")
     return fs.info(paths[0])


-def fsspec_get(url, temp_file, timeout=10.0, desc=None):
+def fsspec_get(url, temp_file, storage_options=None, desc=None):
     _raise_if_offline_mode_is_enabled(f"Tried to reach {url}")
-    fs, _, paths = fsspec.get_fs_token_paths(url, storage_options={"requests_timeout": timeout})
+    fs, _, paths = fsspec.get_fs_token_paths(url, storage_options=storage_options)
     if len(paths) > 1:
         raise ValueError(f"GET can be called with at most one path but was called with {paths}")
     callback = fsspec.callbacks.TqdmCallback(
@@ -445,6 +446,7 @@ def get_from_cache(
     max_retries=0,
     use_auth_token=None,
     ignore_url_params=False,
+    storage_options=None,
     download_desc=None,
 ) -> str:
     """
@@ -499,7 +501,7 @@
         if scheme == "ftp":
             connected = ftp_head(url)
         elif scheme not in ("http", "https"):
-            response = fsspec_head(url)
+            response = fsspec_head(url, storage_options=storage_options)
             # s3fs uses "ETag", gcsfs uses "etag"
             etag = (response.get("ETag", None) or response.get("etag", None)) if use_etag else None
             connected = True
@@ -604,7 +606,7 @@ def _resumable_file_manager():
             if scheme == "ftp":
                 ftp_get(url, temp_file)
             elif scheme not in ("http", "https"):
-                fsspec_get(url, temp_file, desc=download_desc)
+                fsspec_get(url, temp_file, storage_options=storage_options, desc=download_desc)
             else:
                 http_get(
                     url,
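The `file_utils.py` hunks carry the behavioral change: `fsspec_head` and `fsspec_get` drop the hard-coded `{"requests_timeout": timeout}` dict and instead pass whatever the caller supplied straight to `fsspec.get_fs_token_paths`, which forwards it to the filesystem class selected by the URL's protocol. A standalone `fsspec` illustration (placeholder URL; `anon` is an `s3fs` option):

```python
# Standalone fsspec illustration, independent of datasets internals.
import fsspec

fs, _, paths = fsspec.get_fs_token_paths(
    "s3://my-bucket/file.txt",       # placeholder; "s3://" selects s3fs
    storage_options={"anon": True},  # becomes s3fs.S3FileSystem(anon=True)
)
info = fs.info(paths[0])            # roughly what fsspec_head() returns
fs.get_file(paths[0], "local.bin")  # roughly fsspec_get(), minus the tqdm callback
```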