Skip to content

Commit 9c57771

Browse files
committed
Report dataset download progress
1 parent 73c3b6c commit 9c57771

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

benchmark/dataset.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import os
22
import shutil
33
import tarfile
4+
import tqdm
45
import urllib.request
56
from dataclasses import dataclass, field
6-
from typing import Dict, Optional
7+
from typing import Dict, Optional, Callable
78
from urllib.request import build_opener, install_opener
89

910
from benchmark import DATASETS_DIR
@@ -54,7 +55,12 @@ def download(self):
5455

5556
if self.config.link:
5657
print(f"Downloading {self.config.link}...")
57-
tmp_path, _ = urllib.request.urlretrieve(self.config.link)
58+
with tqdm.tqdm(
59+
unit="B", unit_scale=True, miniters=1, dynamic_ncols=True, disable=None
60+
) as t:
61+
tmp_path, _ = urllib.request.urlretrieve(
62+
self.config.link, reporthook=_tqdm_reporthook(t)
63+
)
5864

5965
if self.config.link.endswith(".tgz") or self.config.link.endswith(
6066
".tar.gz"
@@ -76,6 +82,15 @@ def get_reader(self, normalize: bool) -> BaseReader:
7682
return reader_class(DATASETS_DIR / self.config.path, normalize=normalize)
7783

7884

85+
def _tqdm_reporthook(t: tqdm.tqdm) -> Callable[[int, int, int], None]:
86+
def reporthook(blocknum: int, block_size: int, total_size: int) -> None:
87+
if total_size > 0:
88+
t.total = total_size
89+
t.update(blocknum * block_size - t.n)
90+
91+
return reporthook
92+
93+
7994
if __name__ == "__main__":
8095
dataset = Dataset(
8196
{

0 commit comments

Comments
 (0)