diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index a727f5fab3d..809f8bd3966 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -5316,7 +5316,12 @@ def path_in_repo(_index, shard):
             for data_file in data_files
             if data_file.startswith(f"data/{split}-") and data_file not in shards_path_in_repo
         ]
-        deleted_size = sum(xgetsize(hf_hub_url(repo_id, data_file), token=token) for data_file in data_files_to_delete)
+
+        download_config = DownloadConfig(token=token)
+        deleted_size = sum(
+            xgetsize(hf_hub_url(repo_id, data_file), download_config=download_config)
+            for data_file in data_files_to_delete
+        )
 
         def delete_file(file):
             api.delete_file(file, repo_id=repo_id, token=token, repo_type="dataset", revision=branch)
diff --git a/src/datasets/download/streaming_download_manager.py b/src/datasets/download/streaming_download_manager.py
index 67b28e008a0..40c9b958ccf 100644
--- a/src/datasets/download/streaming_download_manager.py
+++ b/src/datasets/download/streaming_download_manager.py
@@ -77,6 +77,8 @@
     for magic_number in chain(MAGIC_NUMBER_TO_COMPRESSION_PROTOCOL, MAGIC_NUMBER_TO_UNSUPPORTED_COMPRESSION_PROTOCOL)
 )
 
+SUPPORTED_REMOTE_SERVER_TYPE = ["http", "https", "s3"]
+
 
 class NonStreamableDatasetError(Exception):
     pass
@@ -140,13 +142,12 @@ def xdirname(a):
     return "::".join([a] + b)
 
 
-def xexists(urlpath: str, token: Optional[Union[str, bool]] = None):
+def xexists(urlpath: str, download_config: Optional[DownloadConfig] = None):
    """Extend `os.path.exists` function to support both local and remote files.

    Args:
        urlpath (`str`): URL path.
-        token (`bool` or `str`, *optional*): Whether to use token or token to authenticate on the
-            Hugging Face Hub for private remote files.
+        download_config (`DownloadConfig`, *optional*): Configuration whose `token` or `storage_options` are used to authenticate on the different supported platforms.

    Returns:
        `bool`
@@ -156,16 +157,7 @@ def xexists(urlpath: str, token: Optional[Union[str, bool]] = None):
    if is_local_path(main_hop):
        return os.path.exists(main_hop)
    else:
-        if not rest_hops and (main_hop.startswith("http://") or main_hop.startswith("https://")):
-            main_hop, http_kwargs = _prepare_http_url_kwargs(main_hop, token=token)
-            storage_options = http_kwargs
-        elif rest_hops and (rest_hops[0].startswith("http://") or rest_hops[0].startswith("https://")):
-            url = rest_hops[0]
-            url, http_kwargs = _prepare_http_url_kwargs(url, token=token)
-            storage_options = {"https": http_kwargs}
-            urlpath = "::".join([main_hop, url, *rest_hops[1:]])
-        else:
-            storage_options = None
+        urlpath, storage_options = _prepare_server_config(urlpath, download_config=download_config)
        fs, *_ = fsspec.get_fs_token_paths(urlpath, storage_options=storage_options)
        return fs.exists(main_hop)
@@ -250,13 +242,12 @@ def xsplitext(a):
    return "::".join([a] + b), ext
 
 
-def xisfile(path, token: Optional[Union[str, bool]] = None) -> bool:
+def xisfile(path, download_config: Optional[DownloadConfig] = None) -> bool:
    """Extend `os.path.isfile` function to support remote files.

    Args:
        path (`str`): URL path.
-        token (`bool` or `str`, *optional*): Whether to use token or token to authenticate on the
-            Hugging Face Hub for private remote files.
+        download_config (`DownloadConfig`, *optional*): Configuration whose `token` or `storage_options` are used to authenticate on the different supported platforms.

    Returns:
        `bool`
@@ -265,27 +256,17 @@ def xisfile(path, token: Optional[Union[str, bool]] = None) -> bool:
    if is_local_path(main_hop):
        return os.path.isfile(path)
    else:
-        if not rest_hops and (main_hop.startswith("http://") or main_hop.startswith("https://")):
-            main_hop, http_kwargs = _prepare_http_url_kwargs(main_hop, token=token)
-            storage_options = http_kwargs
-        elif rest_hops and (rest_hops[0].startswith("http://") or rest_hops[0].startswith("https://")):
-            url = rest_hops[0]
-            url, http_kwargs = _prepare_http_url_kwargs(url, token=token)
-            storage_options = {"https": http_kwargs}
-            path = "::".join([main_hop, url, *rest_hops[1:]])
-        else:
-            storage_options = None
+        path, storage_options = _prepare_server_config(path, download_config=download_config)
        fs, *_ = fsspec.get_fs_token_paths(path, storage_options=storage_options)
        return fs.isfile(main_hop)
 
 
-def xgetsize(path, token: Optional[Union[str, bool]] = None) -> int:
+def xgetsize(path, download_config: Optional[DownloadConfig] = None) -> int:
    """Extend `os.path.getsize` function to support remote files.

    Args:
        path (`str`): URL path.
-        token (`bool` or `str`, *optional*): Whether to use token or token to authenticate on the
-            Hugging Face Hub for private remote files.
+        download_config (`DownloadConfig`, *optional*): Configuration whose `token` or `storage_options` are used to authenticate on the different supported platforms.

    Returns:
        `int`: optional
@@ -294,32 +275,22 @@ def xgetsize(path, token: Optional[Union[str, bool]] = None) -> int:
    if is_local_path(main_hop):
        return os.path.getsize(path)
    else:
-        if not rest_hops and (main_hop.startswith("http://") or main_hop.startswith("https://")):
-            main_hop, http_kwargs = _prepare_http_url_kwargs(main_hop, token=token)
-            storage_options = http_kwargs
-        elif rest_hops and (rest_hops[0].startswith("http://") or rest_hops[0].startswith("https://")):
-            url = rest_hops[0]
-            url, http_kwargs = _prepare_http_url_kwargs(url, token=token)
-            storage_options = {"https": http_kwargs}
-            path = "::".join([main_hop, url, *rest_hops[1:]])
-        else:
-            storage_options = None
+        path, storage_options = _prepare_server_config(path, download_config=download_config)
        fs, *_ = fsspec.get_fs_token_paths(path, storage_options=storage_options)
        size = fs.size(main_hop)
        if size is None:
            # use xopen instead of fs.open to make data fetching more robust
-            with xopen(path, token=token) as f:
+            with xopen(path, download_config=download_config) as f:
                size = len(f.read())
        return size
 
 
-def xisdir(path, token: Optional[Union[str, bool]] = None) -> bool:
+def xisdir(path, download_config: Optional[DownloadConfig] = None) -> bool:
    """Extend `os.path.isdir` function to support remote files.

    Args:
        path (`str`): URL path.
-        token (`bool` or `str`, *optional*): Whether to use token or token to authenticate on the
-            Hugging Face Hub for private remote files.
+        download_config (`DownloadConfig`, *optional*): Configuration whose `token` or `storage_options` are used to authenticate on the different supported platforms.

    Returns:
        `bool`
@@ -328,15 +299,7 @@ def xisdir(path, token: Optional[Union[str, bool]] = None) -> bool:
    if is_local_path(main_hop):
        return os.path.isdir(path)
    else:
-        if not rest_hops and (main_hop.startswith("http://") or main_hop.startswith("https://")):
-            raise NotImplementedError("os.path.isdir is not extended to support URLs in streaming mode")
-        elif rest_hops and (rest_hops[0].startswith("http://") or rest_hops[0].startswith("https://")):
-            url = rest_hops[0]
-            url, http_kwargs = _prepare_http_url_kwargs(url, token=token)
-            storage_options = {"https": http_kwargs}
-            path = "::".join([main_hop, url, *rest_hops[1:]])
-        else:
-            storage_options = None
+        path, storage_options = _prepare_server_config(path, download_config=download_config, implemented=False)
        fs, *_ = fsspec.get_fs_token_paths(path, storage_options=storage_options)
        inner_path = main_hop.split("://")[1]
        if not inner_path.strip("/"):
@@ -412,7 +375,7 @@ def _get_extraction_protocol_with_magic_number(f) -> Optional[str]:
    raise NotImplementedError(f"Compression protocol '{compression}' not implemented.")
 
 
-def _get_extraction_protocol(urlpath: str, token: Optional[Union[str, bool]] = None) -> Optional[str]:
+def _get_extraction_protocol(urlpath: str, download_config: Optional[DownloadConfig] = None) -> Optional[str]:
    # get inner file: zip://train-00000.json.gz::https://foo.bar/data.zip -> zip://train-00000.json.gz
    urlpath = str(urlpath)
    path = urlpath.split("::")[0]
@@ -427,7 +390,7 @@ def _get_extraction_protocol(urlpath: str, token: Optional[Union[str, bool]] = N
        return COMPRESSION_EXTENSION_TO_PROTOCOL[extension]
    if is_remote_url(urlpath):
        # get headers and cookies for authentication on the HF Hub and for Google Drive
-        urlpath, kwargs = _prepare_http_url_kwargs(urlpath, token=token)
+        urlpath, kwargs = _prepare_http_url_kwargs(urlpath, download_config=download_config)
    else:
        urlpath, kwargs = urlpath, {}
    try:
@@ -442,14 +405,43 @@ def _get_extraction_protocol(urlpath: str, token: Optional[Union[str, bool]] = N
        raise
 
 
-def _prepare_http_url_kwargs(url: str, token: Optional[Union[str, bool]] = None) -> Tuple[str, dict]:
+def _validate_servers(urlpath: str):
+    server = urlpath.split("://")[0]
+    return server in SUPPORTED_REMOTE_SERVER_TYPE
+
+
+def _prepare_server_config(
+    path: str, download_config: Optional[DownloadConfig] = None, implemented: bool = True
+) -> Tuple[str, dict]:
+    main_hop, *rest_hops = str(path).split("::")
+    if not rest_hops and _validate_servers(main_hop):
+        if not implemented:
+            raise NotImplementedError("This operation is not extended to support URLs in streaming mode")
+        main_hop, http_kwargs = _prepare_http_url_kwargs(main_hop, download_config=download_config)
+        storage_options = http_kwargs
+    elif rest_hops and _validate_servers(rest_hops[0]):
+        url = rest_hops[0]
+        url, http_kwargs = _prepare_http_url_kwargs(url, download_config=download_config)
+        storage_options = {url.split("://")[0]: http_kwargs}
+        path = "::".join([main_hop, url, *rest_hops[1:]])
+    else:
+        storage_options = None
+    return path, storage_options
+
+
+def _prepare_http_url_kwargs(url: str, download_config: Optional[DownloadConfig] = None) -> Tuple[str, dict]:
    """
    Prepare the URL and the kwargs that must be passed to the HttpFileSystem or to requests.get/head

    In particular it resolves google drive URLs and it adds the authentication headers for the Hugging Face Hub.
+
+    It also resolves the S3 filesystem as a special case, since the S3 filesystem stores its parameters in the
+    storage_options field.
""" kwargs = { - "headers": get_authentication_headers_for_url(url, token=token), + "headers": get_authentication_headers_for_url( + url, token=None if download_config is None else download_config.token + ), "client_kwargs": {"trust_env": True}, # Enable reading proxy env variables. } if "drive.google.com" in url: @@ -466,10 +458,13 @@ def _prepare_http_url_kwargs(url: str, token: Optional[Union[str, bool]] = None) if url.startswith("https://raw.githubusercontent.com/"): # Workaround for served data with gzip content-encoding: https://github.com/fsspec/filesystem_spec/issues/389 kwargs["block_size"] = 0 + # Fix S3 file system + if url.startswith("s3://"): + kwargs = None if download_config is None else download_config.storage_options return url, kwargs -def xopen(file: str, mode="r", *args, token: Optional[Union[str, bool]] = None, **kwargs): +def xopen(file: str, mode="r", *args, download_config: Optional[DownloadConfig] = None, **kwargs): """Extend `open` function to support remote files using `fsspec`. It also has a retry mechanism in case connection fails. @@ -479,8 +474,7 @@ def xopen(file: str, mode="r", *args, token: Optional[Union[str, bool]] = None, file (`str`): Path name of the file to be opened. mode (`str`, *optional*, default "r"): Mode in which the file is opened. *args: Arguments to be passed to `fsspec.open`. - token (`bool` or `str`, *optional*): Whether to use token or token to authenticate on the - Hugging Face Hub for private remote files. + download_config : mainly use token or storage_options to support different platforms and auth types. **kwargs: Keyword arguments to be passed to `fsspec.open`. Returns: @@ -492,14 +486,8 @@ def xopen(file: str, mode="r", *args, token: Optional[Union[str, bool]] = None, if is_local_path(main_hop): return open(main_hop, mode, *args, **kwargs) # add headers and cookies for authentication on the HF Hub and for Google Drive - if not rest_hops and (main_hop.startswith("http://") or main_hop.startswith("https://")): - file, new_kwargs = _prepare_http_url_kwargs(file_str, token=token) - elif rest_hops and (rest_hops[0].startswith("http://") or rest_hops[0].startswith("https://")): - url = rest_hops[0] - url, http_kwargs = _prepare_http_url_kwargs(url, token=token) - new_kwargs = {"https": http_kwargs} - file = "::".join([main_hop, url, *rest_hops[1:]]) - else: + file, new_kwargs = _prepare_server_config(file_str, download_config=download_config) + if new_kwargs is None: new_kwargs = {} kwargs = {**kwargs, **new_kwargs} try: @@ -523,13 +511,12 @@ def xopen(file: str, mode="r", *args, token: Optional[Union[str, bool]] = None, return file_obj -def xlistdir(path: str, token: Optional[Union[str, bool]] = None) -> List[str]: +def xlistdir(path: str, download_config: Optional[DownloadConfig] = None) -> List[str]: """Extend `os.listdir` function to support remote files. Args: path (`str`): URL path. - token (`bool` or `str`, *optional*): Whether to use token or token to authenticate on the - Hugging Face Hub for private remote files. + download_config : mainly use token or storage_options to support different platforms and auth types. 

    Returns:
        `list` of `str`
@@ -539,15 +526,7 @@ def xlistdir(path: str, token: Optional[Union[str, bool]] = None) -> List[str]:
        return os.listdir(path)
    else:
        # globbing inside a zip in a private repo requires authentication
-        if not rest_hops and (main_hop.startswith("http://") or main_hop.startswith("https://")):
-            raise NotImplementedError("os.listdir is not extended to support URLs in streaming mode")
-        elif rest_hops and (rest_hops[0].startswith("http://") or rest_hops[0].startswith("https://")):
-            url = rest_hops[0]
-            url, http_kwargs = _prepare_http_url_kwargs(url, token=token)
-            storage_options = {"https": http_kwargs}
-            path = "::".join([main_hop, url, *rest_hops[1:]])
-        else:
-            storage_options = None
+        path, storage_options = _prepare_server_config(path, download_config=download_config, implemented=False)
        fs, *_ = fsspec.get_fs_token_paths(path, storage_options=storage_options)
        inner_path = main_hop.split("://")[1]
        if inner_path.strip("/") and not fs.isdir(inner_path):
@@ -556,15 +535,14 @@ def xlistdir(path: str, token: Optional[Union[str, bool]] = None) -> List[str]:
    return [os.path.basename(obj["name"]) for obj in objects]
 
 
-def xglob(urlpath, *, recursive=False, token: Optional[Union[str, bool]] = None):
+def xglob(urlpath, *, recursive=False, download_config: Optional[DownloadConfig] = None):
    """Extend `glob.glob` function to support remote files.

    Args:
        urlpath (`str`): URL path with shell-style wildcard patterns.
        recursive (`bool`, default `False`): Whether to match the "**" pattern recursively to zero or more
            directories or subdirectories.
-        token (`bool` or `str`, *optional*): Whether to use token or token to authenticate on the
-            Hugging Face Hub for private remote files.
+        download_config (`DownloadConfig`, *optional*): Configuration whose `token` or `storage_options` are used to authenticate on the different supported platforms.

    Returns:
        `list` of `str`
@@ -574,15 +552,7 @@ def xglob(urlpath, *, recursive=False, token: Optional[Union[str, bool]] = None)
        return glob.glob(main_hop, recursive=recursive)
    else:
        # globbing inside a zip in a private repo requires authentication
-        if not rest_hops and (main_hop.startswith("http://") or main_hop.startswith("https://")):
-            raise NotImplementedError("glob.glob is not extended to support URLs in streaming mode")
-        elif rest_hops and (rest_hops[0].startswith("http://") or rest_hops[0].startswith("https://")):
-            url = rest_hops[0]
-            url, kwargs = _prepare_http_url_kwargs(url, token=token)
-            storage_options = {"https": kwargs}
-            urlpath = "::".join([main_hop, url, *rest_hops[1:]])
-        else:
-            storage_options = None
+        urlpath, storage_options = _prepare_server_config(urlpath, download_config=download_config, implemented=False)
        fs, *_ = fsspec.get_fs_token_paths(urlpath, storage_options=storage_options)
        # - If there's no "*" in the pattern, get_fs_token_paths() doesn't do any pattern matching
        # so to be able to glob patterns like "[0-9]", we have to call `fs.glob`.
@@ -594,13 +564,12 @@ def xglob(urlpath, *, recursive=False, token: Optional[Union[str, bool]] = None)
    return ["::".join([f"{protocol}://{globbed_path}"] + rest_hops) for globbed_path in globbed_paths]
 
 
-def xwalk(urlpath, token: Optional[Union[str, bool]] = None, **kwargs):
+def xwalk(urlpath, download_config: Optional[DownloadConfig] = None, **kwargs):
    """Extend `os.walk` function to support remote files.

    Args:
        urlpath (`str`): URL root path.
-        token (`bool` or `str`, *optional*): Whether to use token or token to authenticate on the
-            Hugging Face Hub for private remote files.
+        download_config (`DownloadConfig`, *optional*): Configuration whose `token` or `storage_options` are used to authenticate on the different supported platforms.
        **kwargs: Additional keyword arguments forwarded to the underlying filesystem.

@@ -612,15 +581,7 @@ def xwalk(urlpath, token: Optional[Union[str, bool]] = None, **kwargs):
        yield from os.walk(main_hop, **kwargs)
    else:
        # walking inside a zip in a private repo requires authentication
-        if not rest_hops and (main_hop.startswith("http://") or main_hop.startswith("https://")):
-            raise NotImplementedError("os.walk is not extended to support URLs in streaming mode")
-        elif rest_hops and (rest_hops[0].startswith("http://") or rest_hops[0].startswith("https://")):
-            url = rest_hops[0]
-            url, kwargs = _prepare_http_url_kwargs(url, token=token)
-            storage_options = {"https": kwargs}
-            urlpath = "::".join([main_hop, url, *rest_hops[1:]])
-        else:
-            storage_options = None
+        urlpath, storage_options = _prepare_server_config(urlpath, download_config=download_config, implemented=False)
        fs, *_ = fsspec.get_fs_token_paths(urlpath, storage_options=storage_options)
        inner_path = main_hop.split("://")[1]
        if inner_path.strip("/") and not fs.isdir(inner_path):
            return []
@@ -643,25 +604,23 @@ def __str__(self):
        path_as_posix += "//" if path_as_posix.endswith(":") else ""  # Add slashes to root of the protocol
        return path_as_posix

-    def exists(self, token: Optional[Union[str, bool]] = None):
+    def exists(self, download_config: Optional[DownloadConfig] = None):
        """Extend `pathlib.Path.exists` method to support both local and remote files.

        Args:
-            token (`bool` or `str`, *optional*): Whether to use token or token to authenticate on the
-                Hugging Face Hub for private remote files.
+            download_config (`DownloadConfig`, *optional*): Configuration whose `token` or `storage_options` are used to authenticate on the different supported platforms.

        Returns:
            `bool`
        """
-        return xexists(str(self), token=token)
+        return xexists(str(self), download_config=download_config)

-    def glob(self, pattern, token: Optional[Union[str, bool]] = None):
+    def glob(self, pattern, download_config: Optional[DownloadConfig] = None):
        """Glob function for argument of type :obj:`~pathlib.Path` that supports both local paths end remote URLs.

        Args:
            pattern (`str`): Pattern that resulting paths must match.
-            token (`bool` or `str`, *optional*): Whether to use token or token to authenticate on the
-                Hugging Face Hub for private remote files.
+            download_config (`DownloadConfig`, *optional*): Configuration whose `token` or `storage_options` are used to authenticate on the different supported platforms.

        Yields:
            [`xPath`]
@@ -672,9 +631,9 @@ def glob(self, pattern, token: Optional[Union[str, bool]] = None):
            yield from Path(main_hop).glob(pattern)
        else:
            # globbing inside a zip in a private repo requires authentication
-            if rest_hops and (rest_hops[0].startswith("http://") or rest_hops[0].startswith("https://")):
+            if rest_hops and _validate_servers(rest_hops[0]):
                url = rest_hops[0]
-                url, kwargs = _prepare_http_url_kwargs(url, token=token)
+                url, kwargs = _prepare_http_url_kwargs(url, download_config=download_config)
                storage_options = {"https": kwargs}
                posix_path = "::".join([main_hop, url, *rest_hops[1:]])
            else:
@@ -772,27 +731,27 @@ def _as_str(path: Union[str, Path, xPath]):
    return str(path) if isinstance(path, xPath) else str(xPath(str(path)))
 
 
-def xgzip_open(filepath_or_buffer, *args, token: Optional[Union[str, bool]] = None, **kwargs):
+def xgzip_open(filepath_or_buffer, *args, download_config: Optional[DownloadConfig] = None, **kwargs):
    import gzip

    if hasattr(filepath_or_buffer, "read"):
        return gzip.open(filepath_or_buffer, *args, **kwargs)
    else:
        filepath_or_buffer = str(filepath_or_buffer)
-        return gzip.open(xopen(filepath_or_buffer, "rb", token=token), *args, **kwargs)
+        return gzip.open(xopen(filepath_or_buffer, "rb", download_config=download_config), *args, **kwargs)
 
 
-def xnumpy_load(filepath_or_buffer, *args, token: Optional[Union[str, bool]] = None, **kwargs):
+def xnumpy_load(filepath_or_buffer, *args, download_config: Optional[DownloadConfig] = None, **kwargs):
    import numpy as np

    if hasattr(filepath_or_buffer, "read"):
        return np.load(filepath_or_buffer, *args, **kwargs)
    else:
        filepath_or_buffer = str(filepath_or_buffer)
-        return np.load(xopen(filepath_or_buffer, "rb", token=token), *args, **kwargs)
+        return np.load(xopen(filepath_or_buffer, "rb", download_config=download_config), *args, **kwargs)
 
 
-def xpandas_read_csv(filepath_or_buffer, token: Optional[Union[str, bool]] = None, **kwargs):
+def xpandas_read_csv(filepath_or_buffer, download_config: Optional[DownloadConfig] = None, **kwargs):
    import pandas as pd

    if hasattr(filepath_or_buffer, "read"):
@@ -800,11 +759,11 @@ def xpandas_read_csv(filepath_or_buffer, token: Optional[Union[str, bool]] = Non
    else:
        filepath_or_buffer = str(filepath_or_buffer)
        if kwargs.get("compression", "infer") == "infer":
-            kwargs["compression"] = _get_extraction_protocol(filepath_or_buffer, token=token)
-        return pd.read_csv(xopen(filepath_or_buffer, "rb", token=token), **kwargs)
+            kwargs["compression"] = _get_extraction_protocol(filepath_or_buffer, download_config=download_config)
+        return pd.read_csv(xopen(filepath_or_buffer, "rb", download_config=download_config), **kwargs)
 
 
-def xpandas_read_excel(filepath_or_buffer, token: Optional[Union[str, bool]] = None, **kwargs):
+def xpandas_read_excel(filepath_or_buffer, download_config: Optional[DownloadConfig] = None, **kwargs):
    import pandas as pd

    if hasattr(filepath_or_buffer, "read"):
@@ -815,28 +774,29 @@ def xpandas_read_excel(filepath_or_buffer, token: Optional[Union[str, bool]] = N
    else:
        filepath_or_buffer = str(filepath_or_buffer)
    try:
-        return pd.read_excel(xopen(filepath_or_buffer, "rb", token=token), **kwargs)
+        return pd.read_excel(xopen(filepath_or_buffer, "rb", download_config=download_config), **kwargs)
    except ValueError:  # Cannot seek streaming HTTP file
-        return pd.read_excel(BytesIO(xopen(filepath_or_buffer, "rb", token=token).read()), **kwargs)
+        return pd.read_excel(
+            BytesIO(xopen(filepath_or_buffer, "rb", download_config=download_config).read()), **kwargs
+        )
 
 
-def xsio_loadmat(filepath_or_buffer, token: Optional[Union[str, bool]] = None, **kwargs):
+def xsio_loadmat(filepath_or_buffer, download_config: Optional[DownloadConfig] = None, **kwargs):
    import scipy.io as sio

    if hasattr(filepath_or_buffer, "read"):
        return sio.loadmat(filepath_or_buffer, **kwargs)
    else:
-        return sio.loadmat(xopen(filepath_or_buffer, "rb", token=token), **kwargs)
+        return sio.loadmat(xopen(filepath_or_buffer, "rb", download_config=download_config), **kwargs)
 
 
-def xet_parse(source, parser=None, token: Optional[Union[str, bool]] = None):
+def xet_parse(source, parser=None, download_config: Optional[DownloadConfig] = None):
    """Extend `xml.etree.ElementTree.parse` function to support remote files.

    Args:
        source: File path or file object.
        parser (`XMLParser`, *optional*, default `XMLParser`): Parser instance.
-        token (`bool` or `str`, optional): Whether to use token or token to authenticate on the
-            Hugging Face Hub for private remote files.
+        download_config (`DownloadConfig`, *optional*): Configuration whose `token` or `storage_options` are used to authenticate on the different supported platforms.

    Returns:
        `xml.etree.ElementTree.Element`: Root element of the given source document.
@@ -844,17 +804,16 @@ def xet_parse(source, parser=None, token: Optional[Union[str, bool]] = None):
    if hasattr(source, "read"):
        return ET.parse(source, parser=parser)
    else:
-        with xopen(source, "rb", token=token) as f:
+        with xopen(source, "rb", download_config=download_config) as f:
            return ET.parse(f, parser=parser)
 
 
-def xxml_dom_minidom_parse(filename_or_file, token: Optional[Union[str, bool]] = None, **kwargs):
+def xxml_dom_minidom_parse(filename_or_file, download_config: Optional[DownloadConfig] = None, **kwargs):
    """Extend `xml.dom.minidom.parse` function to support remote files.

    Args:
        filename_or_file (`str` or file): File path or file object.
-        token (`bool` or `str`, *optional*): Whether to use token or token to authenticate on the
-            Hugging Face Hub for private remote files.
+        download_config (`DownloadConfig`, *optional*): Configuration whose `token` or `storage_options` are used to authenticate on the different supported platforms.
        **kwargs (optional): Additional keyword arguments passed to `xml.dom.minidom.parse`.

    Returns:
@@ -863,7 +822,7 @@ def xxml_dom_minidom_parse(filename_or_file, token: Optional[Union[str, bool]] =
    if hasattr(filename_or_file, "read"):
        return xml.dom.minidom.parse(filename_or_file, **kwargs)
    else:
-        with xopen(filename_or_file, "rb", token=token) as f:
+        with xopen(filename_or_file, "rb", download_config=download_config) as f:
            return xml.dom.minidom.parse(f, **kwargs)
 
 
@@ -924,10 +883,10 @@ def _iter_from_fileobj(cls, f) -> Generator[Tuple, None, None]:

    @classmethod
    def _iter_from_urlpath(
-        cls, urlpath: str, token: Optional[Union[str, bool]] = None
+        cls, urlpath: str, download_config: Optional[DownloadConfig] = None
    ) -> Generator[Tuple, None, None]:
-        compression = _get_extraction_protocol(urlpath, token=token)
-        with xopen(urlpath, "rb", token=token) as f:
+        compression = _get_extraction_protocol(urlpath, download_config=download_config)
+        with xopen(urlpath, "rb", download_config=download_config) as f:
            if compression == "zip":
                yield from cls._iter_zip(f)
            else:
@@ -938,8 +897,8 @@ def from_buf(cls, fileobj) -> "ArchiveIterable":
        return cls(cls._iter_from_fileobj, fileobj)

    @classmethod
-    def from_urlpath(cls, urlpath_or_buf, token: Optional[Union[str, bool]] = None) -> "ArchiveIterable":
-        return cls(cls._iter_from_urlpath, urlpath_or_buf, token)
+    def from_urlpath(cls, urlpath_or_buf, download_config: Optional[DownloadConfig] = None) -> "ArchiveIterable":
+        return cls(cls._iter_from_urlpath, urlpath_or_buf, download_config)
 
 
 class FilesIterable(_IterableFromGenerator):
@@ -947,18 +906,18 @@ class FilesIterable(_IterableFromGenerator):

    @classmethod
    def _iter_from_urlpaths(
-        cls, urlpaths: Union[str, List[str]], token: Optional[Union[str, bool]] = None
+        cls, urlpaths: Union[str, List[str]], download_config: Optional[DownloadConfig] = None
    ) -> Generator[str, None, None]:
        if not isinstance(urlpaths, list):
            urlpaths = [urlpaths]
        for urlpath in urlpaths:
-            if xisfile(urlpath, token=token):
+            if xisfile(urlpath, download_config=download_config):
                if xbasename(urlpath).startswith((".", "__")):
                    # skipping hidden files
                    return
                yield urlpath
            else:
-                for dirpath, dirnames, filenames in xwalk(urlpath, token=token):
+                for dirpath, dirnames, filenames in xwalk(urlpath, download_config=download_config):
                    # skipping hidden directories; prune the search
                    # [:] for the in-place list modification required by os.walk
                    # (only works for local paths as fsspec's walk doesn't support the in-place modification)
@@ -973,8 +932,8 @@ def _iter_from_urlpaths(
                    yield xjoin(dirpath, filename)

    @classmethod
-    def from_urlpaths(cls, urlpaths, token: Optional[Union[str, bool]] = None) -> "FilesIterable":
-        return cls(cls._iter_from_urlpaths, urlpaths, token)
+    def from_urlpaths(cls, urlpaths, download_config: Optional[DownloadConfig] = None) -> "FilesIterable":
+        return cls(cls._iter_from_urlpaths, urlpaths, download_config)
 
 
 class StreamingDownloadManager:
@@ -1054,7 +1013,7 @@ def extract(self, url_or_urls):

    def _extract(self, urlpath: str) -> str:
        urlpath = str(urlpath)
-        protocol = _get_extraction_protocol(urlpath, token=self.download_config.token)
+        protocol = _get_extraction_protocol(urlpath, download_config=self.download_config)
        # get inner file: zip://train-00000.json.gz::https://foo.bar/data.zip -> zip://train-00000.json.gz
        path = urlpath.split("::")[0]
        extension = _get_path_extension(path)
@@ -1122,7 +1081,7 @@ def iter_archive(self, urlpath_or_buf: Union[str, io.BufferedReader]) -> Iterabl
        if hasattr(urlpath_or_buf, "read"):
            return ArchiveIterable.from_buf(urlpath_or_buf)
        else:
-            return ArchiveIterable.from_urlpath(urlpath_or_buf, token=self.download_config.token)
+            return ArchiveIterable.from_urlpath(urlpath_or_buf, download_config=self.download_config)

    def iter_files(self, urlpaths: Union[str, List[str]]) -> Iterable[str]:
        """Iterate over files.
@@ -1141,4 +1100,4 @@ def iter_files(self, urlpaths: Union[str, List[str]]) -> Iterable[str]:
        >>> files = dl_manager.iter_files(files)
        ```
        """
-        return FilesIterable.from_urlpaths(urlpaths, token=self.download_config.token)
+        return FilesIterable.from_urlpaths(urlpaths, download_config=self.download_config)
diff --git a/src/datasets/features/audio.py b/src/datasets/features/audio.py
index 790736e10d2..fd5605f1b06 100644
--- a/src/datasets/features/audio.py
+++ b/src/datasets/features/audio.py
@@ -7,6 +7,7 @@
 import pyarrow as pa
 
 from .. import config
+from ..download.download_config import DownloadConfig
 from ..download.streaming_download_manager import xopen, xsplitext
 from ..table import array_cast
 from ..utils.py_utils import no_op_if_value_is_null, string_to_dict
@@ -172,13 +173,14 @@ def decode_example(
        if file is None:
            token_per_repo_id = token_per_repo_id or {}
            source_url = path.split("::")[-1]
            try:
                repo_id = string_to_dict(source_url, config.HUB_DATASETS_URL)["repo_id"]
                token = token_per_repo_id[repo_id]
            except (ValueError, KeyError):
                token = None

-            with xopen(path, "rb", token=token) as f:
+            download_config = DownloadConfig(token=token)
+            with xopen(path, "rb", download_config=download_config) as f:
                array, sampling_rate = sf.read(f)

        else:
diff --git a/src/datasets/features/image.py b/src/datasets/features/image.py
index aaa2b63cfe6..1f71b98a8b0 100644
--- a/src/datasets/features/image.py
+++ b/src/datasets/features/image.py
@@ -9,6 +9,7 @@
 import pyarrow as pa
 
 from .. import config
+from ..download.download_config import DownloadConfig
 from ..download.streaming_download_manager import xopen
 from ..table import array_cast
 from ..utils.file_utils import is_local_path
@@ -167,10 +168,11 @@ def decode_example(self, value: dict, token_per_repo_id=None) -> "PIL.Image.Imag
                source_url = path.split("::")[-1]
                try:
                    repo_id = string_to_dict(source_url, config.HUB_DATASETS_URL)["repo_id"]
                    token = token_per_repo_id.get(repo_id)
                except ValueError:
                    token = None
-                with xopen(path, "rb", token=token) as f:
+                download_config = DownloadConfig(token=token)
+                with xopen(path, "rb", download_config=download_config) as f:
                    bytes_ = BytesIO(f.read())
                image = PIL.Image.open(bytes_)
            else:
diff --git a/src/datasets/load.py b/src/datasets/load.py
index 05979a6438f..30944df7719 100644
--- a/src/datasets/load.py
+++ b/src/datasets/load.py
@@ -343,7 +343,7 @@ def _create_importable_file(
 
 
 def infer_module_for_data_files(
-    data_files_list: DataFilesList, token: Optional[Union[bool, str]] = None
+    data_files_list: DataFilesList, download_config: Optional[DownloadConfig] = None
 ) -> Optional[Tuple[str, str]]:
    """Infer module (and builder kwargs) from list of data files.

    Args:
        data_files_list (DataFilesList): List of data files.
-        token (bool or str, optional): Whether to use token or token to authenticate on the Hugging Face Hub
-            for private remote files.
+        download_config (DownloadConfig, optional): Configuration whose token or storage_options are used to authenticate on the different supported platforms.

    Returns:
        tuple[str, str]: Tuple with
@@ -376,19 +375,18 @@ def sort_key(ext_count: Tuple[str, int]) -> Tuple[int, bool]:
        if ext in _EXTENSION_TO_MODULE:
            return _EXTENSION_TO_MODULE[ext]
        elif ext == ".zip":
-            return infer_module_for_data_files_in_archives(data_files_list, token=token)
+            return infer_module_for_data_files_in_archives(data_files_list, download_config=download_config)
    return None, {}
 
 
 def infer_module_for_data_files_in_archives(
-    data_files_list: DataFilesList, token: Optional[Union[bool, str]]
+    data_files_list: DataFilesList, download_config: Optional[DownloadConfig]
 ) -> Optional[Tuple[str, str]]:
    """Infer module (and builder kwargs) from list of archive data files.

    Args:
        data_files_list (DataFilesList): List of data files.
-        token (bool or str, optional): Whether to use token or token to authenticate on the Hugging Face Hub
-            for private remote files.
+        download_config (DownloadConfig, optional): Configuration whose token or storage_options are used to authenticate on the different supported platforms.

    Returns:
        tuple[str, str]: Tuple with
@@ -405,7 +403,7 @@ def infer_module_for_data_files_in_archives(
            extracted = xjoin(StreamingDownloadManager().extract(filepath), "**")
            archived_files += [
                f.split("::")[0]
-                for f in xglob(extracted, recursive=True, token=token)[
+                for f in xglob(extracted, recursive=True, download_config=download_config)[
                    : config.ARCHIVED_DATA_FILES_MAX_NUMBER_FOR_MODULE_INFERENCE
                ]
            ]
@@ -785,7 +783,7 @@ def get_module(self) -> DatasetModule:
            allowed_extensions=ALL_ALLOWED_EXTENSIONS,
        )
        split_modules = {
-            split: infer_module_for_data_files(data_files_list, token=self.download_config.token)
+            split: infer_module_for_data_files(data_files_list, download_config=self.download_config)
            for split, data_files_list in data_files.items()
        }
        module_name, builder_kwargs = next(iter(split_modules.values()))
diff --git a/src/datasets/streaming.py b/src/datasets/streaming.py
index d53949fab4a..67109014aff 100644
--- a/src/datasets/streaming.py
+++ b/src/datasets/streaming.py
@@ -1,8 +1,9 @@
 import importlib
 import inspect
 from functools import wraps
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, Optional
 
+from .download.download_config import DownloadConfig
 from .download.streaming_download_manager import (
     xbasename,
     xdirname,
@@ -39,7 +40,7 @@
     from .builder import DatasetBuilder
 
 
-def extend_module_for_streaming(module_path, token: Optional[Union[str, bool]] = None):
+def extend_module_for_streaming(module_path, download_config: Optional[DownloadConfig] = None):
    """Extend the module to support streaming.

    We patch some functions in the module to use `fsspec` to support data streaming:
@@ -55,8 +56,7 @@ def extend_module_for_streaming(module_path, token: Optional[Union[str, bool]] =
    Args:
        module_path: Path to the module to be extended.
-        token (``str`` or :obj:`bool`, optional): Optional string or boolean to use as Bearer token for remote files on the Datasets Hub.
-            If True, or not specified, will get token from `"~/.huggingface"`.
+        download_config (``DownloadConfig``, optional): Configuration whose token or storage_options are used to authenticate on the different supported platforms.
""" module = importlib.import_module(module_path) @@ -68,7 +68,7 @@ def extend_module_for_streaming(module_path, token: Optional[Union[str, bool]] = def wrap_auth(function): @wraps(function) def wrapper(*args, **kwargs): - return function(*args, token=token, **kwargs) + return function(*args, download_config=download_config, **kwargs) wrapper._decorator_name_ = "wrap_auth" return wrapper @@ -109,14 +109,15 @@ def extend_dataset_builder_for_streaming(builder: "DatasetBuilder"): builder (:class:`DatasetBuilder`): Dataset builder instance. """ # this extends the open and os.path.join functions for data streaming - extend_module_for_streaming(builder.__module__, token=builder.token) + download_config = DownloadConfig(storage_options=builder.storage_options, token=builder.use_auth_token) + extend_module_for_streaming(builder.__module__, download_config=download_config) # if needed, we also have to extend additional internal imports (like wmt14 -> wmt_utils) if not builder.__module__.startswith("datasets."): # check that it's not a packaged builder like csv for imports in get_imports(inspect.getfile(builder.__class__)): if imports[0] == "internal": internal_import_name = imports[1] internal_module_name = ".".join(builder.__module__.split(".")[:-1] + [internal_import_name]) - extend_module_for_streaming(internal_module_name, token=builder.token) + extend_module_for_streaming(internal_module_name, download_config=download_config) # builders can inherit from other builders that might use streaming functionality # (for example, ImageFolder and AudioFolder inherit from FolderBuilder which implements examples generation) @@ -129,4 +130,4 @@ def extend_dataset_builder_for_streaming(builder: "DatasetBuilder"): if issubclass(cls, DatasetBuilder) and cls.__module__ != DatasetBuilder.__module__ ] # check it's not a standard builder from datasets.builder for module in parent_builder_modules: - extend_module_for_streaming(module, token=builder.token) + extend_module_for_streaming(module, download_config=download_config) diff --git a/tests/test_streaming_download_manager.py b/tests/test_streaming_download_manager.py index 0934313393a..34be5ad0111 100644 --- a/tests/test_streaming_download_manager.py +++ b/tests/test_streaming_download_manager.py @@ -7,6 +7,7 @@ from fsspec.registry import _registry as _fsspec_registry from fsspec.spec import AbstractBufferedFile, AbstractFileSystem +from datasets.download.download_config import DownloadConfig from datasets.download.streaming_download_manager import ( StreamingDownloadManager, _get_extraction_protocol, @@ -236,8 +237,9 @@ def test_xexists(input_path, exists, tmp_path, mock_fsspec): @pytest.mark.integration def test_xexists_private(hf_private_dataset_repo_txt_data, hf_token): root_url = hf_hub_url(hf_private_dataset_repo_txt_data, "") - assert xexists(root_url + "data/text_data.txt", token=hf_token) - assert not xexists(root_url + "file_that_doesnt_exist.txt", token=hf_token) + download_config = DownloadConfig(token=hf_token) + assert xexists(root_url + "data/text_data.txt", download_config=download_config) + assert not xexists(root_url + "file_that_doesnt_exist.txt", download_config=download_config) @pytest.mark.parametrize( @@ -320,12 +322,13 @@ def test_xlistdir(input_path, expected_paths, tmp_path, mock_fsspec): @pytest.mark.integration def test_xlistdir_private(hf_private_dataset_repo_zipped_txt_data, hf_token): root_url = hf_hub_url(hf_private_dataset_repo_zipped_txt_data, "data.zip") - assert len(xlistdir("zip://::" + root_url, token=hf_token)) == 1 - 
+    download_config = DownloadConfig(token=hf_token)
+    assert len(xlistdir("zip://::" + root_url, download_config=download_config)) == 1
+    assert len(xlistdir("zip://main_dir::" + root_url, download_config=download_config)) == 2
     with pytest.raises(FileNotFoundError):
-        xlistdir("zip://qwertyuiop::" + root_url, token=hf_token)
+        xlistdir("zip://qwertyuiop::" + root_url, download_config=download_config)
     with pytest.raises(NotImplementedError):
-        xlistdir(root_url, token=hf_token)
+        xlistdir(root_url, download_config=download_config)
 
 
 @pytest.mark.parametrize(
@@ -348,11 +351,13 @@ def test_xisdir(input_path, isdir, tmp_path, mock_fsspec):
 @pytest.mark.integration
 def test_xisdir_private(hf_private_dataset_repo_zipped_txt_data, hf_token):
     root_url = hf_hub_url(hf_private_dataset_repo_zipped_txt_data, "data.zip")
-    assert xisdir("zip://::" + root_url, token=hf_token) is True
-    assert xisdir("zip://main_dir::" + root_url, token=hf_token) is True
-    assert xisdir("zip://qwertyuiop::" + root_url, token=hf_token) is False
+
+    download_config = DownloadConfig(token=hf_token)
+    assert xisdir("zip://::" + root_url, download_config=download_config) is True
+    assert xisdir("zip://main_dir::" + root_url, download_config=download_config) is True
+    assert xisdir("zip://qwertyuiop::" + root_url, download_config=download_config) is False
     with pytest.raises(NotImplementedError):
-        xisdir(root_url, token=hf_token)
+        xisdir(root_url, download_config=download_config)
 
 
 @pytest.mark.parametrize(
@@ -374,8 +379,9 @@ def test_xisfile(input_path, isfile, tmp_path, mock_fsspec):
 @pytest.mark.integration
 def test_xisfile_private(hf_private_dataset_repo_txt_data, hf_token):
     root_url = hf_hub_url(hf_private_dataset_repo_txt_data, "")
-    assert xisfile(root_url + "data/text_data.txt", token=hf_token) is True
-    assert xisfile(root_url + "qwertyuiop", token=hf_token) is False
+    download_config = DownloadConfig(token=hf_token)
+    assert xisfile(root_url + "data/text_data.txt", download_config=download_config) is True
+    assert xisfile(root_url + "qwertyuiop", download_config=download_config) is False
 
 
 @pytest.mark.parametrize(
@@ -397,9 +403,10 @@ def test_xgetsize(input_path, size, tmp_path, mock_fsspec):
 @pytest.mark.integration
 def test_xgetsize_private(hf_private_dataset_repo_txt_data, hf_token):
     root_url = hf_hub_url(hf_private_dataset_repo_txt_data, "")
-    assert xgetsize(root_url + "data/text_data.txt", token=hf_token) == 39
+    download_config = DownloadConfig(token=hf_token)
+    assert xgetsize(root_url + "data/text_data.txt", download_config=download_config) == 39
     with pytest.raises(FileNotFoundError):
-        xgetsize(root_url + "qwertyuiop", token=hf_token)
+        xgetsize(root_url + "qwertyuiop", download_config=download_config)
 
 
 @pytest.mark.parametrize(
@@ -440,8 +447,9 @@ def test_xglob(input_path, expected_paths, tmp_path, mock_fsspec):
 @pytest.mark.integration
 def test_xglob_private(hf_private_dataset_repo_zipped_txt_data, hf_token):
     root_url = hf_hub_url(hf_private_dataset_repo_zipped_txt_data, "data.zip")
-    assert len(xglob("zip://**::" + root_url, token=hf_token)) == 3
-    assert len(xglob("zip://qwertyuiop/*::" + root_url, token=hf_token)) == 0
+    download_config = DownloadConfig(token=hf_token)
+    assert len(xglob("zip://**::" + root_url, download_config=download_config)) == 3
+    assert len(xglob("zip://qwertyuiop/*::" + root_url, download_config=download_config)) == 0
 
 
 @pytest.mark.parametrize(
@@ -478,9 +486,10 @@ def test_xwalk(input_path, expected_outputs, tmp_path, mock_fsspec):
 @pytest.mark.integration
 def test_xwalk_private(hf_private_dataset_repo_zipped_txt_data, hf_token):
     root_url = hf_hub_url(hf_private_dataset_repo_zipped_txt_data, "data.zip")
-    assert len(list(xwalk("zip://::" + root_url, token=hf_token))) == 2
-    assert len(list(xwalk("zip://main_dir::" + root_url, token=hf_token))) == 1
-    assert len(list(xwalk("zip://qwertyuiop::" + root_url, token=hf_token))) == 0
+    download_config = DownloadConfig(token=hf_token)
+    assert len(list(xwalk("zip://::" + root_url, download_config=download_config))) == 2
+    assert len(list(xwalk("zip://main_dir::" + root_url, download_config=download_config))) == 1
+    assert len(list(xwalk("zip://qwertyuiop::" + root_url, download_config=download_config))) == 0
 
 
 @pytest.mark.parametrize(
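
A note for reviewers on the calling convention this diff introduces. The sketch below is illustrative only and assumes a `datasets` build that includes this branch; the repo URL, token, bucket name, and credential values are placeholders, not real resources.

```python
# Illustrative only: assumes a `datasets` build containing this branch.
# The repo URL, token, bucket name, and credentials are placeholders.
from datasets.download.download_config import DownloadConfig
from datasets.download.streaming_download_manager import xexists, xopen

# Hugging Face Hub: the token now travels inside a DownloadConfig
# instead of being passed as a bare `token=` argument.
hub_config = DownloadConfig(token="hf_xxx")  # placeholder token
url = "https://huggingface.co/datasets/user/private-repo/resolve/main/data.txt"  # placeholder URL
if xexists(url, download_config=hub_config):
    with xopen(url, "rb", download_config=hub_config) as f:
        print(f.read(10))

# S3: credentials travel in storage_options, which is why _prepare_http_url_kwargs
# special-cases "s3://" URLs rather than building HTTP auth headers for them.
s3_config = DownloadConfig(storage_options={"key": "<aws-key>", "secret": "<aws-secret>"})  # placeholders
with xopen("s3://my-bucket/data.csv", "rb", download_config=s3_config) as f:
    head = f.read(100)
```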
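The `_prepare_server_config` helper centralizes the branching that every `x*` function previously repeated. The standalone sketch below mirrors its routing decisions so they can be checked in isolation; it is not code from the patch itself.

```python
# Standalone sketch mirroring the routing in _prepare_server_config.
from typing import Optional, Tuple

SUPPORTED_REMOTE_SERVER_TYPE = ["http", "https", "s3"]

def _validate_servers(urlpath: str) -> bool:
    return urlpath.split("://")[0] in SUPPORTED_REMOTE_SERVER_TYPE

def route(path: str) -> Tuple[str, Optional[str]]:
    """Return (path, protocol whose storage_options get populated)."""
    main_hop, *rest_hops = str(path).split("::")
    if not rest_hops and _validate_servers(main_hop):
        return path, main_hop.split("://")[0]      # bare https:// or s3:// URL
    if rest_hops and _validate_servers(rest_hops[0]):
        return path, rest_hops[0].split("://")[0]  # chained, e.g. zip://...::https://...
    return path, None                              # local path or unsupported protocol

print(route("zip://train.csv::https://foo.bar/data.zip"))  # (..., 'https')
print(route("s3://bucket/data.json"))                      # (..., 's3')
print(route("/local/file.txt"))                            # (..., None)
```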
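Finally, `extend_module_for_streaming` threads one shared `DownloadConfig` through every patched function via a closure-based decorator. This condensed sketch shows the pattern with a hypothetical stand-in for `xopen`; only the `wrap_auth` shape comes from the patch.

```python
# Condensed sketch of the wrap_auth pattern used by extend_module_for_streaming.
# `xopen_demo` is a hypothetical stand-in, not the real xopen.
from functools import wraps

def make_wrap_auth(download_config):
    def wrap_auth(function):
        @wraps(function)
        def wrapper(*args, **kwargs):
            # inject the shared config into every patched call
            return function(*args, download_config=download_config, **kwargs)

        wrapper._decorator_name_ = "wrap_auth"
        return wrapper

    return wrap_auth

def xopen_demo(path, mode="r", download_config=None):
    return f"open({path!r}, {mode!r}, config={download_config!r})"

wrap_auth = make_wrap_auth({"token": "hf_xxx"})  # placeholder config
patched_open = wrap_auth(xopen_demo)
print(patched_open("s3://bucket/file.txt", "rb"))  # config injected automatically
```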