Skip to content

Commit

Permalink
remove dependency of parfive, introduce brand new multi-threading & r…
Browse files Browse the repository at this point in the history
…esumable downloader
  • Loading branch information
weixingjian committed Mar 10, 2023
1 parent 22e9ef6 commit 3703cfc
Show file tree
Hide file tree
Showing 3 changed files with 397 additions and 98 deletions.
27 changes: 16 additions & 11 deletions opendatalab/cli/cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,24 +136,29 @@ def info(obj: ContextInfo, name):
@command(synopsis=("$ odl get dataset_name # get dataset files into local",))
@click.argument("name", nargs=1)
@click.option(
"--conn",
"-c",
default=5,
help="The number of parallel download slots",
show_default=True,
"--dest",
"-d",
default='',
help="Desired dataset store path",
show_default=True
)
@click.option(
"--workers",
"-w",
default = 8,
help= "number of workers",
show_default = True
)
@click.pass_obj
def get(obj: ContextInfo, name, conn = 5):
def get(obj: ContextInfo, name, dest, workers):
"""Get(Download) dataset files into local path.\f
Args:
obj (ContextInfo): context info\f
name (str): dataset name\f
conn (int): The number of parallel download slots\f
destination(str): desired dataset store path\f
"""

from opendatalab.cli.get import implement_get
implement_get(obj, name, conn)


implement_get(obj, name, dest, workers)
if __name__ == "__main__":
cli()
118 changes: 31 additions & 87 deletions opendatalab/cli/get.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,17 @@
#
# Copyright 2022 Shanghai AI Lab. Licensed under MIT License.
#
import logging
import os
import sys
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import List

import click
import parfive
import requests
from tqdm import tqdm

from opendatalab.cli.policy import private_policy_url, service_agreement_url
from opendatalab.cli.utility import ContextInfo, exception_handler
from opendatalab.client import downloader
from opendatalab.exception import OdlDataNotExistsError


Expand All @@ -30,55 +25,10 @@ def handler(dwCtrlType):
if sys.platform == "win32":
import win32api
win32api.SetConsoleCtrlHandler(handler, True)

@exception_handler
def download_from_url(url:str, pth: str, file_name:str):
    """Download a single object, resuming from a partial local file if present.

    Args:
        url (str): single download url
        pth (str): local download path
        file_name (str): file name (may contain relative path)

    Returns:
        int: total size of the remote object in bytes, taken from the
        ``content-length`` header of the initial response.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    # Probe the total size. The body of this response is never read, so
    # close it immediately instead of leaking the streamed connection.
    with requests.get(url, stream=True) as probe:
        probe.raise_for_status()
        file_size = int(probe.headers['content-length'])

    target = os.path.join(pth, file_name)
    # file_name may carry a relative path; make sure its directory exists.
    os.makedirs(os.path.dirname(target) or '.', exist_ok=True)

    # An existing file is treated as a partial download to resume from.
    first_byte = os.path.getsize(target) if os.path.exists(target) else 0

    # Nothing left to fetch: report success and hand the size back rather
    # than terminating the whole process with a failure exit code.
    if first_byte >= file_size:
        click.secho('Download Complete')
        return file_size

    # Open-ended range: no spaces are allowed in the header value, and the
    # last-byte position is inclusive, so "-{file_size}" was off by one.
    header = {"Range": f"bytes={first_byte}-"}

    with requests.get(url, headers=header, stream=True) as req:
        req.raise_for_status()
        # A server that ignores Range replies 200 with the full body;
        # appending that to the partial file would corrupt it, so restart.
        if req.status_code != 206:
            first_byte = 0
        mode = 'ab' if first_byte else 'wb'

        with open(target, mode) as f, tqdm(total=file_size,
                                           initial=first_byte,
                                           unit='B',
                                           unit_scale=True,
                                           desc='Downloading Progress:') as pbar:
            for chunk in req.iter_content(chunk_size=1024):
                if chunk:
                    f.write(chunk)
                    # The final chunk is usually shorter than chunk_size.
                    pbar.update(len(chunk))
    return file_size


@exception_handler
def implement_get(obj: ContextInfo, name: str, conn = 5):
def implement_get(obj: ContextInfo, name: str, destination:str, num_workers:int):
"""
implementation for getting dataset files
Args:
Expand All @@ -96,7 +46,7 @@ def implement_get(obj: ContextInfo, name: str, conn = 5):
dataset_name = name
sub_dir = ""

# print(name, ds_split ,dataset_name)
# print(name, ds_split ,dataset_name, sub_dir)

client = obj.get_client()
data_info = client.get_api().get_info(dataset_name)
Expand All @@ -113,16 +63,16 @@ def implement_get(obj: ContextInfo, name: str, conn = 5):
curr_dict = {}
if not info['isDir']:
curr_dict['size'] = info['size']
curr_dict['name'] = info['path']
curr_dict['name'] = os.path.join(sub_dir,info['path'])
obj_info_list.append(curr_dict)

# if not sub_dir:
print(obj_info_list, sub_dir)
# print(obj_info_list, sub_dir)
download_urls_list = client.get_api().get_dataset_download_urls(
dataset_id=info_dataset_id,
dataset_list=obj_info_list)
# print(obj_info_list)
print('____________________________________________________-')
print('___________________________________________________')


url_list = []
Expand All @@ -131,11 +81,9 @@ def implement_get(obj: ContextInfo, name: str, conn = 5):
url_list.append(item['url'])
item_list.append(item['name'])

print(url_list[0], item_list[0])
# print(url_list[0], item_list[0])



local_dir = Path.cwd().joinpath(info_dataset_name)
local_dir = destination

download_data = client.get_api().get_download_record(info_dataset_name)
has_download = download_data['hasDownload']
Expand All @@ -146,40 +94,36 @@ def implement_get(obj: ContextInfo, name: str, conn = 5):
f"\n[Warning]: Before downloading, please agree above content."):
client.get_api().submit_download_record(info_dataset_name, download_data)
else:
click.secho('bye~')
click.secho('See you next time~!')
sys.exit(1)

if click.confirm(f"Download files into local directory: {local_dir} ?", default=True):
if not Path(local_dir).exists():
Path(local_dir).mkdir(parents=True)
print(f"create local dir: {local_dir}")
else:
click.secho('bye~')
click.secho('See you next time~!')
sys.exit(1)

# print(url_list[0], item_list[0])
########################################################################
size = download_from_url(url_list[0], pth=local_dir, file_name = item_list[0])
########################################################################
print(size)


# downloader = parfive.Downloader(max_conn = conn,
# max_splits= 5,
# progress= True)

# for idx, url in enumerate(url_list):
# downloader.enqueue_file(url, path = local_dir, filename=item_list[idx])

# results = downloader.download()

# for i in results:
# click.echo(i)

# err_str = ''
# for err in results.errors:
# err_str += f"{err.url} \t {err.exception}\n"
# if not err_str:
# print(f"{info_dataset_name}, download completed!")
# else:
# sys.exit(err_str)

with tqdm(total = len(url_list)) as pbar:
for idx in range(len(url_list)):
if len(item_list[idx].split('/')) == 1:
filename = item_list[idx]
prefix = ''
else:
filename = item_list[idx].split('/')[-1]
prefix = os.path.dirname(item_list[idx])

click.echo(f"Downloading No.{idx+1} of total {len(url_list)} files\n")
if os.path.exists(os.path.join(destination,info_dataset_name, prefix,filename)):
click.echo('target already exists, jumping to next!')
continue
downloader.Downloader(url = url_list[idx],
filename=item_list[idx],
download_dir = os.path.join(destination, info_dataset_name),
blocks_num= num_workers).start()
pbar.update(1)
click.echo(f"\nDownload Complete!")
Loading

0 comments on commit 3703cfc

Please sign in to comment.