Skip to content

Commit

Permalink
fsspec get uses tqdm, tries to handle additional protocols, and compu…
Browse files Browse the repository at this point in the history
…tes pseudo etag from head response
  • Loading branch information
dwyatte committed Feb 28, 2023
1 parent 2fceb77 commit 17d88d1
Showing 1 changed file with 22 additions and 14 deletions.
36 changes: 22 additions & 14 deletions src/datasets/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,19 +330,24 @@ def _request_with_retry(

def fsspec_head(url, timeout=10.0):
    """Return the fsspec ``info`` result for the single path behind ``url``.

    Args:
        url: URL resolved to a filesystem and path list with
            ``fsspec.get_fs_token_paths``.
        timeout: Forwarded to ``fs.info`` as the ``timeout`` keyword.

    Returns:
        Whatever ``fs.info`` returns for the resolved path.

    Raises:
        ValueError: If ``url`` resolves to more than one path.
    """
    _raise_if_offline_mode_is_enabled(f"Tried to reach {url}")
    fs, _, paths = fsspec.get_fs_token_paths(url)
    if len(paths) > 1:
        # Must be an f-string, otherwise the literal text "{paths}" is reported.
        raise ValueError(f"HEAD can be called with at most one path but was called with {paths}")
    return fs.info(paths[0], timeout=timeout)


def fsspec_get(url, temp_file, timeout=10.0, desc=None):
    """Download the single path behind ``url`` into ``temp_file`` via fsspec.

    Args:
        url: URL resolved to a filesystem and path list with
            ``fsspec.get_fs_token_paths``.
        temp_file: Destination handed to ``fs.get`` as its second argument.
        timeout: Forwarded to ``fs.get`` as the ``timeout`` keyword.
        desc: Progress-bar description; falls back to ``"Downloading"`` when
            falsy.

    Raises:
        ValueError: If ``url`` resolves to more than one path.
    """
    _raise_if_offline_mode_is_enabled(f"Tried to reach {url}")
    fs, _, paths = fsspec.get_fs_token_paths(url)
    if len(paths) > 1:
        # Must be an f-string, otherwise the literal text "{paths}" is reported.
        raise ValueError(f"GET can be called with at most one path but was called with {paths}")
    callback = fsspec.callbacks.TqdmCallback(
        tqdm_kwargs={
            "desc": desc or "Downloading",
            # tqdm hides the bar when ``disable`` is true, so the
            # "progress bars enabled" flag has to be negated here.
            "disable": not logging.is_progress_bar_enabled(),
        }
    )
    fs.get(paths[0], temp_file, timeout=timeout, callback=callback)


def ftp_head(url, timeout=10.0):
Expand Down Expand Up @@ -493,8 +498,11 @@ def get_from_cache(
scheme = urlparse(url).scheme
if scheme == "ftp":
connected = ftp_head(url)
elif scheme in ("s3", "gs"):
connected = fsspec_head(url)
elif scheme not in ("http", "https"):
response = fsspec_head(url)
# use the hash of the response as a pseudo ETag to detect changes
etag = json.dumps(response, sort_keys=True) if use_etag else None
connected = True
try:
response = http_head(
url,
Expand Down Expand Up @@ -595,8 +603,8 @@ def _resumable_file_manager():
# GET file object
if scheme == "ftp":
ftp_get(url, temp_file)
elif scheme in ("gs", "s3"):
fsspec_get(url, temp_file)
elif scheme not in ("http", "https"):
fsspec_get(url, temp_file, desc=download_desc)
else:
http_get(
url,
Expand Down

0 comments on commit 17d88d1

Please sign in to comment.