Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions pdebench/data_download/download_direct.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import os
import argparse
from pathlib import Path

Expand Down Expand Up @@ -58,7 +59,7 @@ def parse_metadata(pde_names):
return meta_df[meta_df["PDE"].isin(pde_names)]


def download_data(root_folder, pde_name):
def download_data(root_folder, pde_name, no_fast_download=False):
""" "
Download data splits specific to a given PDE.

Expand All @@ -75,6 +76,14 @@ def download_data(root_folder, pde_name):
# Iterate filtered dataframe and download the files
for _, row in tqdm(pde_df.iterrows(), total=pde_df.shape[0]):
file_path = Path(root_folder) / row["Path"]
if not no_fast_download:
try:
file_size = os.path.getsize(file_path / row["Filename"])
except (FileNotFoundError, OSError):
file_size = 0
if file_size != 0:
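# A non-empty file already exists at this path, so skip re-downloading it.
# NOTE(review): only the size is checked here; the MD5 is not verified on this path, so a partial download would also be skipped.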
continue
download_url(row["URL"], file_path, row["Filename"], md5=row["MD5"])


Expand All @@ -96,7 +105,12 @@ def download_data(root_folder, pde_name):
action="append",
help="Name of the PDE dataset to download. You can use this flag multiple times to download multiple datasets",
)
arg_parser.add_argument(
"--no_fast_download",
action="store_true",
help="Disable fast download mode, which skips files that are already downloaded",
)

args = arg_parser.parse_args()

download_data(args.root_folder, args.pde_name)
download_data(args.root_folder, args.pde_name, args.no_fast_download)