diff --git a/aeon/datasets/_data_loaders.py b/aeon/datasets/_data_loaders.py index d4c20b5cf8..53cda54ffc 100644 --- a/aeon/datasets/_data_loaders.py +++ b/aeon/datasets/_data_loaders.py @@ -23,6 +23,7 @@ import zipfile from datetime import datetime from http.client import IncompleteRead, RemoteDisconnected +from pathlib import Path from urllib.error import HTTPError, URLError from urllib.parse import urlparse from urllib.request import Request, urlopen, urlretrieve @@ -465,7 +466,9 @@ def _download_and_extract(url, extract_path=None): with open(zip_file_name, "wb") as out_file: out_file.write(response.read()) if extract_path is None: - extract_path = os.path.join(MODULE, "local_data/%s/" % file_name.split(".")[0]) + extract_path = os.path.join( + str(Path.home() / ".aeon"), "local_data/%s/" % file_name.split(".")[0] + ) else: extract_path = os.path.join(extract_path, "%s/" % file_name.split(".")[0]) @@ -524,8 +527,14 @@ def _load_tsc_dataset( local_module = extract_path local_dirname = "" else: - local_module = MODULE - local_dirname = "data" + bundled_path = os.path.join(MODULE, "data", name) + if os.path.exists(bundled_path): + local_module = MODULE + local_dirname = "data" + else: + aeon_home = Path.home() / ".aeon" + local_module = str(aeon_home) + local_dirname = "data" if not os.path.exists(os.path.join(local_module, local_dirname)): os.makedirs(os.path.join(local_module, local_dirname)) @@ -545,7 +554,11 @@ def _load_tsc_dataset( try: _download_and_extract( url, - extract_path=extract_path, + extract_path=( + extract_path + if extract_path is not None + else os.path.join(local_module, local_dirname) + ), ) except zipfile.BadZipFile as e: raise ValueError( @@ -987,8 +1000,13 @@ def load_forecasting(name, extract_path=None, return_metadata=False): local_module = extract_path local_dirname = "" else: - local_module = MODULE - local_dirname = "data" + bundled_path = os.path.join(MODULE, "data", name) + if os.path.exists(bundled_path): + local_module = MODULE + local_dirname = "data" + else: + local_module = str(Path.home() / ".aeon") + local_dirname = "data" if not os.path.exists(os.path.join(local_module, local_dirname)): os.makedirs(os.path.join(local_module, local_dirname)) @@ -1028,7 +1046,11 @@ def load_forecasting(name, extract_path=None, return_metadata=False): try: _download_and_extract( url, - extract_path=extract_path, + extract_path=( + extract_path + if extract_path is not None + else os.path.join(local_module, local_dirname) + ), ) except zipfile.BadZipFile: raise ValueError( @@ -1141,8 +1163,13 @@ def load_regression( local_module = extract_path local_dirname = "" else: - local_module = MODULE - local_dirname = "data" + bundled_path = os.path.join(MODULE, "data", name) + if os.path.exists(bundled_path): + local_module = MODULE + local_dirname = "data" + else: + local_module = str(Path.home() / ".aeon") + local_dirname = "data" error_str = ( f"File name {name} is not in the list of valid files to download," f"see aeon.datasets.tser_datasetss.tser_soton for the list. " @@ -1182,7 +1209,11 @@ def load_regression( try: _download_and_extract( url, - extract_path=extract_path, + extract_path=( + extract_path + if extract_path is not None + else os.path.join(local_module, local_dirname) + ), ) except zipfile.BadZipFile: try_monash = True @@ -1322,8 +1353,13 @@ def load_classification( local_module = extract_path local_dirname = None else: - local_module = MODULE - local_dirname = "data" + bundled_path = os.path.join(MODULE, "data", name) + if os.path.exists(bundled_path): + local_module = MODULE + local_dirname = "data" + else: + local_module = str(Path.home() / ".aeon") + local_dirname = "data" if local_dirname is None: path = local_module else: @@ -1362,7 +1398,11 @@ def load_classification( try: _download_and_extract( url, - extract_path=extract_path, + extract_path=( + extract_path + if extract_path is not None + else os.path.join(local_module, local_dirname) + ), ) except zipfile.BadZipFile: try_zenodo = True @@ -1443,7 +1483,7 @@ def download_all_regression(extract_path=None): local_module = extract_path local_dirname = "" else: - local_module = MODULE + local_module = str(Path.home() / ".aeon") local_dirname = "data" if not os.path.exists(os.path.join(local_module, local_dirname)): diff --git a/aeon/datasets/tests/test_data_loaders.py b/aeon/datasets/tests/test_data_loaders.py index 29d7049b9e..59da1980d6 100644 --- a/aeon/datasets/tests/test_data_loaders.py +++ b/aeon/datasets/tests/test_data_loaders.py @@ -57,7 +57,7 @@ def test_load_forecasting_from_repo(): assert not meta["contain_missing_values"] assert not meta["contain_equal_length"] - shutil.rmtree(os.path.dirname(__file__) + "/../local_data") + shutil.rmtree(os.path.dirname(__file__) + "/../local_data", ignore_errors=True) @pytest.mark.skipif( @@ -84,7 +84,7 @@ def test_load_classification_from_repo(): assert meta["classlabel"] assert not meta["targetlabel"] assert meta["class_values"] == ["1", "2"] - shutil.rmtree(os.path.dirname(__file__) + "/../local_data") + shutil.rmtree(os.path.dirname(__file__) + "/../local_data", ignore_errors=True) @pytest.mark.skipif(