diff --git a/environment.yml b/environment.yml index 4b0e5076..5e450476 100644 --- a/environment.yml +++ b/environment.yml @@ -23,6 +23,7 @@ dependencies: - numpy=1.22.4 - pandas=1.5.1 - google-cloud-sdk=410.0.0 + - aria2=1.36.0 - pip=22.3 - pip: - earthengine-api==0.1.329 diff --git a/weather_dl/download_pipeline/clients.py b/weather_dl/download_pipeline/clients.py index 3ff07b1d..2d01bb93 100644 --- a/weather_dl/download_pipeline/clients.py +++ b/weather_dl/download_pipeline/clients.py @@ -20,14 +20,17 @@ import json import logging import os +import subprocess import typing as t import warnings +from urllib.parse import urljoin import cdsapi import urllib3 -from ecmwfapi import ECMWFService +from ecmwfapi import ECMWFService, api from .config import Config, optimize_selection_partition +from .util import retry_with_exponential_backoff warnings.simplefilter( "ignore", category=urllib3.connectionpool.InsecureRequestWarning) @@ -55,8 +58,9 @@ def retrieve(self, dataset: str, selection: t.Dict, output: str) -> None: """Download from data source.""" pass + @classmethod @abc.abstractmethod - def num_requests_per_key(self, dataset: str) -> int: + def num_requests_per_key(cls, dataset: str) -> int: """Specifies the number of workers to be used per api key for the dataset.""" pass @@ -108,7 +112,8 @@ def retrieve(self, dataset: str, selection: t.Dict, target: str) -> None: def license_url(self): return 'https://cds.climate.copernicus.eu/api/v2/terms/static/licence-to-use-copernicus-products.pdf' - def num_requests_per_key(self, dataset: str) -> int: + @classmethod + def num_requests_per_key(cls, dataset: str) -> int: """Number of requests per key from the CDS API. CDS has dynamic, data-specific limits, defined here: @@ -123,7 +128,7 @@ def num_requests_per_key(self, dataset: str) -> int: https://cds.climate.copernicus.eu/cdsapp#!/yourrequests """ # TODO(#15): Parse live CDS limits API to set data-specific limits. - for internal_set in self.cds_hosted_datasets: + for internal_set in cls.cds_hosted_datasets: if dataset.startswith(internal_set): return 5 return 2 @@ -154,6 +159,80 @@ def __exit__(self, exc_type, exc_value, traceback): self._redirector.__exit__(exc_type, exc_value, traceback) +class SplitMARSRequest(api.APIRequest): + """Extended MARS APIRequest class that separates fetch and download stage.""" + @retry_with_exponential_backoff + def _download(self, url, path: str, size: int) -> None: + self.log( + "Transferring %s into %s" % (self._bytename(size), path) + ) + self.log("From %s" % (url,)) + + dir_path, file_name = os.path.split(path) + try: + subprocess.run( + ['aria2c', '-x', '16', '-s', '16', url, '-d', dir_path, '-o', file_name, '--allow-overwrite'], + check=True, + capture_output=True) + except subprocess.CalledProcessError as e: + self.log(f'Failed download from ECMWF server {url!r} to {path!r} due to {e.stderr.decode("utf-8")}') + + def fetch(self, request: t.Dict) -> t.Dict: + status = None + + self.connection.submit("%s/%s/requests" % (self.url, self.service), request) + self.log("Request submitted") + self.log("Request id: " + self.connection.last.get("name")) + if self.connection.status != status: + status = self.connection.status + self.log("Request is %s" % (status,)) + + while not self.connection.ready(): + if self.connection.status != status: + status = self.connection.status + self.log("Request is %s" % (status,)) + self.connection.wait() + + if self.connection.status != status: + status = self.connection.status + self.log("Request is %s" % (status,)) + + result = self.connection.result() + return result + + def download(self, result: t.Dict, target: t.Optional[str] = None) -> None: + if target: + if os.path.exists(target): + # Empty the target file, if it already exists, otherwise the + # transfer below might be fooled into thinking we're resuming + # an interrupted download. + open(target, "w").close() + + self._download(urljoin(self.url, result["href"]), target, result["size"]) + self.connection.cleanup() + + +class MARSECMWFServiceExtended(ECMWFService): + """Extended MARS ECMFService class that separates fetch and download stage.""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.c = SplitMARSRequest( + self.url, + "services/%s" % (self.service,), + email=self.email, + key=self.key, + log=self.log, + verbose=self.verbose, + quiet=self.quiet, + ) + + def fetch(self, req: t.Dict) -> t.Dict: + return self.c.fetch(req) + + def download(self, res: t.Dict, target: str) -> None: + self.c.download(res, target) + + class MarsClient(Client): """A client to access data from the Meteorological Archival and Retrieval System (MARS). @@ -176,25 +255,27 @@ class MarsClient(Client): def __init__(self, config: Config, level: int = logging.INFO) -> None: super().__init__(config, level) - self.c = ECMWFService( + + def retrieve(self, dataset: str, selection: t.Dict, output: str) -> None: + self.c = MARSECMWFServiceExtended( "mars", - key=config.kwargs.get('api_key', os.environ.get("MARSAPI_KEY")), - url=config.kwargs.get('api_url', os.environ.get("MARSAPI_URL")), - email=config.kwargs.get('api_email', os.environ.get("MARSAPI_EMAIL")), + key=self.config.kwargs.get('api_key', os.environ.get("MARSAPI_KEY")), + url=self.config.kwargs.get('api_url', os.environ.get("MARSAPI_URL")), + email=self.config.kwargs.get('api_email', os.environ.get("MARSAPI_EMAIL")), log=self.logger.debug, - verbose=True + verbose=True, ) - - def retrieve(self, dataset: str, selection: t.Dict, output: str) -> None: selection_ = optimize_selection_partition(selection) with StdoutLogger(self.logger, level=logging.DEBUG): - self.c.execute(req=selection_, target=output) + result = self.c.fetch(req=selection_) + self.c.download(result, target=output) @property def license_url(self): return 'https://apps.ecmwf.int/datasets/licences/general/' - def num_requests_per_key(self, dataset: str) -> int: + @classmethod + def num_requests_per_key(cls, dataset: str) -> int: """Number of requests per key (or user) for the Mars API. Mars allows 2 active requests per user and 20 queued requests per user, as of Sept 27, 2021. @@ -221,7 +302,8 @@ def retrieve(self, dataset: str, selection: t.Dict, output: str) -> None: def license_url(self): return 'lorem ipsum' - def num_requests_per_key(self, dataset: str) -> int: + @classmethod + def num_requests_per_key(cls, dataset: str) -> int: return 1 diff --git a/weather_dl/download_pipeline/clients_test.py b/weather_dl/download_pipeline/clients_test.py index f11f8ea6..f211e324 100644 --- a/weather_dl/download_pipeline/clients_test.py +++ b/weather_dl/download_pipeline/clients_test.py @@ -15,29 +15,20 @@ import unittest from .clients import FakeClient, CdsClient, MarsClient -from .config import Config class MaxWorkersTest(unittest.TestCase): def test_cdsclient_internal(self): - client = CdsClient(Config.from_dict({'parameters': {'api_url': 'url', 'api_key': 'key'}})) - self.assertEqual( - client.num_requests_per_key("reanalysis-era5-some-data"), 5) + self.assertEqual(CdsClient.num_requests_per_key("reanalysis-era5-some-data"), 5) def test_cdsclient_mars_hosted(self): - client = CdsClient(Config.from_dict({'parameters': {'api_url': 'url', 'api_key': 'key'}})) - self.assertEqual( - client.num_requests_per_key("reanalysis-carra-height-levels"), 2) + self.assertEqual(CdsClient.num_requests_per_key("reanalysis-carra-height-levels"), 2) def test_marsclient(self): - client = MarsClient(Config.from_dict({'parameters': {}})) - self.assertEqual( - client.num_requests_per_key("reanalysis-era5-some-data"), 2) + self.assertEqual(MarsClient.num_requests_per_key("reanalysis-era5-some-data"), 2) def test_fakeclient(self): - client = FakeClient(Config.from_dict({'parameters': {}})) - self.assertEqual( - client.num_requests_per_key("reanalysis-era5-some-data"), 1) + self.assertEqual(FakeClient.num_requests_per_key("reanalysis-era5-some-data"), 1) if __name__ == '__main__':