This repository was archived by the owner on Nov 16, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 919
/
Copy pathurl_utils.py
133 lines (112 loc) · 4.29 KB
/
url_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""Common utilities for downloading and extracting datasets."""
import logging
import math
import os
import tarfile
import zipfile
from contextlib import contextmanager
from tempfile import TemporaryDirectory
import requests
from tqdm import tqdm
from google_drive_downloader import GoogleDriveDownloader as gdd
logger = logging.getLogger(__name__)
def maybe_download(url, filename=None, work_directory=".", expected_bytes=None):
    """Download a file if it is not already downloaded.

    Args:
        url (str): URL of the file to download.
        filename (str, optional): File name. If None, the last component
            of the URL path is used.
        work_directory (str, optional): Working directory. Defaults to ".".
        expected_bytes (int, optional): Expected file size in bytes. When
            given, a size mismatch deletes the file and raises IOError.

    Returns:
        str: File path of the file downloaded.

    Raises:
        IOError: If ``expected_bytes`` is set and the file size differs.
        requests.HTTPError: If the server responds with an error status.
    """
    if filename is None:
        filename = url.split("/")[-1]
    os.makedirs(work_directory, exist_ok=True)
    filepath = os.path.join(work_directory, filename)
    if not os.path.exists(filepath):
        # Stream the response so large files are not held in memory, and
        # close it deterministically via the context manager.
        with requests.get(url, stream=True) as r:
            # Fail fast on HTTP errors instead of writing an error page to disk.
            r.raise_for_status()
            total_size = int(r.headers.get("content-length", 0))
            block_size = 1024
            num_iterables = math.ceil(total_size / block_size)
            with open(filepath, "wb") as file:
                for data in tqdm(
                    r.iter_content(block_size),
                    total=num_iterables,
                    unit="KB",
                    unit_scale=True,
                ):
                    file.write(data)
    else:
        logger.info("File {} already downloaded".format(filepath))
    if expected_bytes is not None:
        statinfo = os.stat(filepath)
        if statinfo.st_size != expected_bytes:
            # Remove the corrupt/partial file so a retry starts clean.
            os.remove(filepath)
            raise IOError("Failed to verify {}".format(filepath))
    return filepath
def maybe_download_googledrive(
    google_file_id, file_name, work_directory=".", expected_bytes=None
):
    """Download a file from google drive if it is not already downloaded.

    Args:
        google_file_id (str): The ID of the google file which can be found in
            the file link, e.g. https://drive.google.com/file/d/{google_file_id}/view
        file_name (str): Name of the downloaded file.
        work_directory (str, optional): Directory to download the file to.
            Defaults to ".".
        expected_bytes (int, optional): Expected file size in bytes.

    Returns:
        str: File path of the file downloaded.

    Raises:
        IOError: If ``expected_bytes`` is set and the file size differs.
    """
    os.makedirs(work_directory, exist_ok=True)
    destination = os.path.join(work_directory, file_name)
    if os.path.exists(destination):
        logger.info("File {} already downloaded".format(destination))
    else:
        gdd.download_file_from_google_drive(
            file_id=google_file_id, dest_path=destination
        )
    if expected_bytes is not None and os.stat(destination).st_size != expected_bytes:
        # Size mismatch: discard the file and signal verification failure.
        os.remove(destination)
        raise IOError("Failed to verify {}".format(destination))
    return destination
def extract_tar(file_path, dest_path="."):
    """Extracts all contents of a tar archive file.

    Args:
        file_path (str): Path of file to extract.
        dest_path (str, optional): Destination directory. Defaults to ".".

    Raises:
        IOError: If ``file_path`` or ``dest_path`` does not exist, or if
            the archive contains a member that would extract outside of
            ``dest_path`` (path traversal).
    """
    if not os.path.exists(file_path):
        raise IOError("File doesn't exist")
    if not os.path.exists(dest_path):
        raise IOError("Destination directory doesn't exist")
    dest_root = os.path.realpath(dest_path)
    with tarfile.open(file_path) as t:
        # Archives are typically downloaded from the network, so guard
        # against "tar-slip" members such as ../../etc/passwd before
        # extracting anything.
        for member in t.getmembers():
            target = os.path.realpath(os.path.join(dest_root, member.name))
            if os.path.commonpath([dest_root, target]) != dest_root:
                raise IOError("Unsafe tar member: {}".format(member.name))
        t.extractall(path=dest_root)
def extract_zip(file_path, dest_path="."):
    """Extracts all contents of a zip archive file.

    Members whose names end with a carriage return are skipped.

    Args:
        file_path (str): Path of file to extract.
        dest_path (str, optional): Destination directory. Defaults to ".".

    Raises:
        IOError: If ``file_path`` or ``dest_path`` does not exist.
    """
    if not os.path.exists(file_path):
        raise IOError("File doesn't exist")
    if not os.path.exists(dest_path):
        raise IOError("Destination directory doesn't exist")
    with zipfile.ZipFile(file_path) as archive:
        members = [name for name in archive.namelist() if not name.endswith("\r")]
        archive.extractall(dest_path, members)
@contextmanager
def download_path(path):
    """Yield a directory to download files into.

    Args:
        path (str or None): Directory to use. If None, a temporary
            directory is created for the duration of the context and
            removed on exit; otherwise the real path of ``path`` is
            yielded and left untouched.

    Yields:
        str: Path of the download directory.
    """
    if path is None:
        # Only allocate a temporary directory when the caller did not
        # supply one; the original created (and cleaned up) one even
        # when it was never used.
        tmp_dir = TemporaryDirectory()
        try:
            yield tmp_dir.name
        finally:
            tmp_dir.cleanup()
    else:
        yield os.path.realpath(path)