From 69b341a21b0bb58c6cfaf4fdc10f6bfb69a2dcb2 Mon Sep 17 00:00:00 2001 From: mahrsee1997 Date: Wed, 20 Apr 2022 13:27:53 -0700 Subject: [PATCH 1/9] restructured the fetch stage to be a Composite Beam transform and separated the request & download stage of fetching --- weather_dl/download_pipeline/clients.py | 120 +++++++++++++++++++++++- 1 file changed, 118 insertions(+), 2 deletions(-) diff --git a/weather_dl/download_pipeline/clients.py b/weather_dl/download_pipeline/clients.py index 3ff07b1d..098c6109 100644 --- a/weather_dl/download_pipeline/clients.py +++ b/weather_dl/download_pipeline/clients.py @@ -20,12 +20,14 @@ import json import logging import os +import time import typing as t import warnings +from urllib.parse import urljoin import cdsapi import urllib3 -from ecmwfapi import ECMWFService +from ecmwfapi import ECMWFService, api from .config import Config, optimize_selection_partition @@ -60,6 +62,16 @@ def num_requests_per_key(self, dataset: str) -> int: """Specifies the number of workers to be used per api key for the dataset.""" pass + @abc.abstractmethod + def fetch(self, dataset: str, selection: t.Dict) -> t.Dict: + """Fetch data from data source.""" + pass + + @abc.abstractmethod + def download(self, dataset: str, result: t.Dict, output: str) -> None: + """Download from data source.""" + pass + @property @abc.abstractmethod def license_url(self): @@ -104,6 +116,12 @@ def retrieve(self, dataset: str, selection: t.Dict, target: str) -> None: selection_ = optimize_selection_partition(selection) self.c.retrieve(dataset, selection_, target) + def fetch(self, dataset: str, selection: t.Dict) -> None: + pass + + def download(self, dataset: str, result: t.Dict, output: str) -> None: + pass + @property def license_url(self): return 'https://cds.climate.copernicus.eu/api/v2/terms/static/licence-to-use-copernicus-products.pdf' @@ -154,6 +172,90 @@ def __exit__(self, exc_type, exc_value, traceback): self._redirector.__exit__(exc_type, exc_value, traceback) +class APIRequestExtended(api.APIRequest): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def fetch(self, request): + status = None + + self.connection.submit("%s/%s/requests" % (self.url, self.service), request) + self.log("Request submitted") + self.log("Request id: " + self.connection.last.get("name")) + if self.connection.status != status: + status = self.connection.status + self.log("Request is %s" % (status,)) + + while not self.connection.ready(): + if self.connection.status != status: + status = self.connection.status + self.log("Request is %s" % (status,)) + self.connection.wait() + + if self.connection.status != status: + status = self.connection.status + self.log("Request is %s" % (status,)) + + result = self.connection.result() + return result + + def download(self, result, target=None): + if target: + if os.path.exists(target): + # Empty the target file, if it already exists, otherwise the + # transfer below might be fooled into thinking we're resuming + # an interrupted download. + open(target, "w").close() + + size = -1 + tries = 0 + while size != result["size"] and tries < 10: + size = self._transfer( + urljoin(self.url, result["href"]), target, result["size"] + ) + if size != result["size"] and tries < 10: + tries += 1 + self.log("Transfer interrupted, resuming in 60s...") + time.sleep(60) + else: + break + + assert size == result["size"] + + self.connection.cleanup() + + return result + + +class ECMWFServiceExtended(ECMWFService): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def fetch(self, req): + c = APIRequestExtended( + self.url, + "services/%s" % (self.service,), + email=self.email, + key=self.key, + log=self.log, + verbose=self.verbose, + quiet=self.quiet, + ) + return c.fetch(req) + + def download(self, res, target): + c = APIRequestExtended( + self.url, + "services/%s" % (self.service,), + email=self.email, + key=self.key, + log=self.log, + verbose=self.verbose, + quiet=self.quiet, + ) + c.download(res, target) + + class MarsClient(Client): """A client to access data from the Meteorological Archival and Retrieval System (MARS). @@ -176,7 +278,7 @@ class MarsClient(Client): def __init__(self, config: Config, level: int = logging.INFO) -> None: super().__init__(config, level) - self.c = ECMWFService( + self.c = ECMWFServiceExtended( "mars", key=config.kwargs.get('api_key', os.environ.get("MARSAPI_KEY")), url=config.kwargs.get('api_url', os.environ.get("MARSAPI_URL")), @@ -190,6 +292,14 @@ def retrieve(self, dataset: str, selection: t.Dict, output: str) -> None: with StdoutLogger(self.logger, level=logging.DEBUG): self.c.execute(req=selection_, target=output) + def fetch(self, dataset: str, selection: t.Dict) -> t.Dict: + with StdoutLogger(self.logger, level=logging.DEBUG): + return self.c.fetch(req=selection) + + def download(self, dataset: str, result: t.Dict, output: str) -> None: + with StdoutLogger(self.logger, level=logging.DEBUG): + self.c.download(res=result, target=output) + @property def license_url(self): return 'https://apps.ecmwf.int/datasets/licences/general/' @@ -217,6 +327,12 @@ def retrieve(self, dataset: str, selection: t.Dict, output: str) -> None: with open(output, 'w') as f: json.dump({dataset: selection}, f) + def fetch(self, dataset: str, selection: t.Dict) -> None: + pass + + def download(self, dataset: str, result: t.Dict, output: str) -> None: + pass + @property def license_url(self): return 'lorem ipsum' From 96a6c3338db392277b347613c2e4662607947eab Mon Sep 17 00:00:00 2001 From: mahrsee1997 Date: Fri, 22 Apr 2022 04:01:59 -0700 Subject: [PATCH 2/9] retry logic of downloads for MARS client & other cosmetic changes. --- weather_dl/download_pipeline/clients.py | 90 +++++++++++++------------ 1 file changed, 46 insertions(+), 44 deletions(-) diff --git a/weather_dl/download_pipeline/clients.py b/weather_dl/download_pipeline/clients.py index 098c6109..3edaa6ef 100644 --- a/weather_dl/download_pipeline/clients.py +++ b/weather_dl/download_pipeline/clients.py @@ -20,16 +20,23 @@ import json import logging import os -import time +import shutil import typing as t import warnings +from contextlib import closing from urllib.parse import urljoin +from urllib.request import ( + Request, + urlopen, +) import cdsapi import urllib3 +from apache_beam.io.gcp.gcsio import DEFAULT_READ_BUFFER_SIZE from ecmwfapi import ECMWFService, api from .config import Config, optimize_selection_partition +from .util import retry_with_exponential_backoff warnings.simplefilter( "ignore", category=urllib3.connectionpool.InsecureRequestWarning) @@ -117,10 +124,10 @@ def retrieve(self, dataset: str, selection: t.Dict, target: str) -> None: self.c.retrieve(dataset, selection_, target) def fetch(self, dataset: str, selection: t.Dict) -> None: - pass + raise NotImplementedError() def download(self, dataset: str, result: t.Dict, output: str) -> None: - pass + raise NotImplementedError() @property def license_url(self): @@ -172,11 +179,30 @@ def __exit__(self, exc_type, exc_value, traceback): self._redirector.__exit__(exc_type, exc_value, traceback) -class APIRequestExtended(api.APIRequest): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) +class SplitMARSRequest(api.APIRequest): + """Extended MARS APIRequest class that separates fetch and download stage.""" + @retry_with_exponential_backoff + def _download(self, url, path: str, size: int) -> None: + existing_size = 0 + req = Request(url) + + if os.path.exists(path): + mode = "ab" + existing_size = os.path.getsize(path) + req.add_header("Range", "bytes=%s-" % existing_size) + else: + mode = "wb" - def fetch(self, request): + self.log( + "Transfering %s into %s" % (self._bytename(size), path) + ) + self.log("From %s" % (url,)) + + with open(path, mode) as f: + with closing(urlopen(req)) as http: + shutil.copyfileobj(http, f, DEFAULT_READ_BUFFER_SIZE) + + def fetch(self, request: t.Dict) -> t.Dict: status = None self.connection.submit("%s/%s/requests" % (self.url, self.service), request) @@ -199,7 +225,7 @@ def fetch(self, request): result = self.connection.result() return result - def download(self, result, target=None): + def download(self, result: t.Dict, target: t.Optional[str] = None) -> None: if target: if os.path.exists(target): # Empty the target file, if it already exists, otherwise the @@ -207,32 +233,15 @@ def download(self, result, target=None): # an interrupted download. open(target, "w").close() - size = -1 - tries = 0 - while size != result["size"] and tries < 10: - size = self._transfer( - urljoin(self.url, result["href"]), target, result["size"] - ) - if size != result["size"] and tries < 10: - tries += 1 - self.log("Transfer interrupted, resuming in 60s...") - time.sleep(60) - else: - break - - assert size == result["size"] - + self._download(urljoin(self.url, result["href"]), target, result["size"]) self.connection.cleanup() - return result - -class ECMWFServiceExtended(ECMWFService): +class MARSECMWFServiceExtended(ECMWFService): + """Extended MARS ECMFService class that separates fetch and download stage.""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - - def fetch(self, req): - c = APIRequestExtended( + self.c = SplitMARSRequest( self.url, "services/%s" % (self.service,), email=self.email, @@ -241,19 +250,12 @@ def fetch(self, req): verbose=self.verbose, quiet=self.quiet, ) - return c.fetch(req) - def download(self, res, target): - c = APIRequestExtended( - self.url, - "services/%s" % (self.service,), - email=self.email, - key=self.key, - log=self.log, - verbose=self.verbose, - quiet=self.quiet, - ) - c.download(res, target) + def fetch(self, req: t.Dict) -> t.Dict: + return self.c.fetch(req) + + def download(self, res: t.Dict, target: str) -> None: + self.c.download(res, target) class MarsClient(Client): @@ -278,7 +280,7 @@ class MarsClient(Client): def __init__(self, config: Config, level: int = logging.INFO) -> None: super().__init__(config, level) - self.c = ECMWFServiceExtended( + self.c = MARSECMWFServiceExtended( "mars", key=config.kwargs.get('api_key', os.environ.get("MARSAPI_KEY")), url=config.kwargs.get('api_url', os.environ.get("MARSAPI_URL")), @@ -328,10 +330,10 @@ def retrieve(self, dataset: str, selection: t.Dict, output: str) -> None: json.dump({dataset: selection}, f) def fetch(self, dataset: str, selection: t.Dict) -> None: - pass + raise NotImplementedError() def download(self, dataset: str, result: t.Dict, output: str) -> None: - pass + raise NotImplementedError() @property def license_url(self): From 752d2c80cbf15291bd15cb8e4f5d302141b7066f Mon Sep 17 00:00:00 2001 From: Alex Merose Date: Fri, 16 Sep 2022 12:44:11 -0700 Subject: [PATCH 3/9] Remove fetch / dl split --- weather_dl/download_pipeline/clients.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/weather_dl/download_pipeline/clients.py b/weather_dl/download_pipeline/clients.py index 3edaa6ef..5dc4033f 100644 --- a/weather_dl/download_pipeline/clients.py +++ b/weather_dl/download_pipeline/clients.py @@ -69,16 +69,6 @@ def num_requests_per_key(self, dataset: str) -> int: """Specifies the number of workers to be used per api key for the dataset.""" pass - @abc.abstractmethod - def fetch(self, dataset: str, selection: t.Dict) -> t.Dict: - """Fetch data from data source.""" - pass - - @abc.abstractmethod - def download(self, dataset: str, result: t.Dict, output: str) -> None: - """Download from data source.""" - pass - @property @abc.abstractmethod def license_url(self): @@ -294,14 +284,6 @@ def retrieve(self, dataset: str, selection: t.Dict, output: str) -> None: with StdoutLogger(self.logger, level=logging.DEBUG): self.c.execute(req=selection_, target=output) - def fetch(self, dataset: str, selection: t.Dict) -> t.Dict: - with StdoutLogger(self.logger, level=logging.DEBUG): - return self.c.fetch(req=selection) - - def download(self, dataset: str, result: t.Dict, output: str) -> None: - with StdoutLogger(self.logger, level=logging.DEBUG): - self.c.download(res=result, target=output) - @property def license_url(self): return 'https://apps.ecmwf.int/datasets/licences/general/' From 5f7ce16b1b2c819d0efc7dbef4c86967bea01daf Mon Sep 17 00:00:00 2001 From: Alex Merose Date: Mon, 19 Sep 2022 16:29:40 -0700 Subject: [PATCH 4/9] retrieve in two steps. --- weather_dl/download_pipeline/clients.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/weather_dl/download_pipeline/clients.py b/weather_dl/download_pipeline/clients.py index 5dc4033f..f618b8ae 100644 --- a/weather_dl/download_pipeline/clients.py +++ b/weather_dl/download_pipeline/clients.py @@ -173,7 +173,6 @@ class SplitMARSRequest(api.APIRequest): """Extended MARS APIRequest class that separates fetch and download stage.""" @retry_with_exponential_backoff def _download(self, url, path: str, size: int) -> None: - existing_size = 0 req = Request(url) if os.path.exists(path): @@ -184,7 +183,7 @@ def _download(self, url, path: str, size: int) -> None: mode = "wb" self.log( - "Transfering %s into %s" % (self._bytename(size), path) + "Transferring %s into %s" % (self._bytename(size), path) ) self.log("From %s" % (url,)) @@ -282,7 +281,8 @@ def __init__(self, config: Config, level: int = logging.INFO) -> None: def retrieve(self, dataset: str, selection: t.Dict, output: str) -> None: selection_ = optimize_selection_partition(selection) with StdoutLogger(self.logger, level=logging.DEBUG): - self.c.execute(req=selection_, target=output) + result = self.c.fetch(req=selection_) + self.c.download(result, target=output) @property def license_url(self): From 81558925ea81d2be6c86bb68aa1a68bddf5dc390 Mon Sep 17 00:00:00 2001 From: Alex Merose Date: Mon, 19 Sep 2022 16:32:26 -0700 Subject: [PATCH 5/9] rm fetch + dl methods. --- weather_dl/download_pipeline/clients.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/weather_dl/download_pipeline/clients.py b/weather_dl/download_pipeline/clients.py index f618b8ae..9f7041ec 100644 --- a/weather_dl/download_pipeline/clients.py +++ b/weather_dl/download_pipeline/clients.py @@ -113,12 +113,6 @@ def retrieve(self, dataset: str, selection: t.Dict, target: str) -> None: selection_ = optimize_selection_partition(selection) self.c.retrieve(dataset, selection_, target) - def fetch(self, dataset: str, selection: t.Dict) -> None: - raise NotImplementedError() - - def download(self, dataset: str, result: t.Dict, output: str) -> None: - raise NotImplementedError() - @property def license_url(self): return 'https://cds.climate.copernicus.eu/api/v2/terms/static/licence-to-use-copernicus-products.pdf' @@ -311,12 +305,6 @@ def retrieve(self, dataset: str, selection: t.Dict, output: str) -> None: with open(output, 'w') as f: json.dump({dataset: selection}, f) - def fetch(self, dataset: str, selection: t.Dict) -> None: - raise NotImplementedError() - - def download(self, dataset: str, result: t.Dict, output: str) -> None: - raise NotImplementedError() - @property def license_url(self): return 'lorem ipsum' From 6eb1c79d00866163a80ca97312da1658b7f16a7d Mon Sep 17 00:00:00 2001 From: Alex Merose Date: Thu, 1 Dec 2022 22:42:56 -0800 Subject: [PATCH 6/9] Fix: `nim_requests_per_key` does not require class construction. --- weather_dl/download_pipeline/clients.py | 14 +++++++++----- weather_dl/download_pipeline/clients_test.py | 16 ++++------------ 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/weather_dl/download_pipeline/clients.py b/weather_dl/download_pipeline/clients.py index 9f7041ec..dcb56323 100644 --- a/weather_dl/download_pipeline/clients.py +++ b/weather_dl/download_pipeline/clients.py @@ -64,8 +64,9 @@ def retrieve(self, dataset: str, selection: t.Dict, output: str) -> None: """Download from data source.""" pass + @classmethod @abc.abstractmethod - def num_requests_per_key(self, dataset: str) -> int: + def num_requests_per_key(cls, dataset: str) -> int: """Specifies the number of workers to be used per api key for the dataset.""" pass @@ -117,7 +118,8 @@ def retrieve(self, dataset: str, selection: t.Dict, target: str) -> None: def license_url(self): return 'https://cds.climate.copernicus.eu/api/v2/terms/static/licence-to-use-copernicus-products.pdf' - def num_requests_per_key(self, dataset: str) -> int: + @classmethod + def num_requests_per_key(cls, dataset: str) -> int: """Number of requests per key from the CDS API. CDS has dynamic, data-specific limits, defined here: @@ -132,7 +134,7 @@ def num_requests_per_key(self, dataset: str) -> int: https://cds.climate.copernicus.eu/cdsapp#!/yourrequests """ # TODO(#15): Parse live CDS limits API to set data-specific limits. - for internal_set in self.cds_hosted_datasets: + for internal_set in cls.cds_hosted_datasets: if dataset.startswith(internal_set): return 5 return 2 @@ -282,7 +284,8 @@ def retrieve(self, dataset: str, selection: t.Dict, output: str) -> None: def license_url(self): return 'https://apps.ecmwf.int/datasets/licences/general/' - def num_requests_per_key(self, dataset: str) -> int: + @classmethod + def num_requests_per_key(cls, dataset: str) -> int: """Number of requests per key (or user) for the Mars API. Mars allows 2 active requests per user and 20 queued requests per user, as of Sept 27, 2021. @@ -309,7 +312,8 @@ def retrieve(self, dataset: str, selection: t.Dict, output: str) -> None: def license_url(self): return 'lorem ipsum' - def num_requests_per_key(self, dataset: str) -> int: + @classmethod + def num_requests_per_key(cls, dataset: str) -> int: return 1 diff --git a/weather_dl/download_pipeline/clients_test.py b/weather_dl/download_pipeline/clients_test.py index f11f8ea6..c9aa6562 100644 --- a/weather_dl/download_pipeline/clients_test.py +++ b/weather_dl/download_pipeline/clients_test.py @@ -20,24 +20,16 @@ class MaxWorkersTest(unittest.TestCase): def test_cdsclient_internal(self): - client = CdsClient(Config.from_dict({'parameters': {'api_url': 'url', 'api_key': 'key'}})) - self.assertEqual( - client.num_requests_per_key("reanalysis-era5-some-data"), 5) + self.assertEqual(CdsClient.num_requests_per_key("reanalysis-era5-some-data"), 5) def test_cdsclient_mars_hosted(self): - client = CdsClient(Config.from_dict({'parameters': {'api_url': 'url', 'api_key': 'key'}})) - self.assertEqual( - client.num_requests_per_key("reanalysis-carra-height-levels"), 2) + self.assertEqual(CdsClient.num_requests_per_key("reanalysis-carra-height-levels"), 2) def test_marsclient(self): - client = MarsClient(Config.from_dict({'parameters': {}})) - self.assertEqual( - client.num_requests_per_key("reanalysis-era5-some-data"), 2) + self.assertEqual(MarsClient.num_requests_per_key("reanalysis-era5-some-data"), 2) def test_fakeclient(self): - client = FakeClient(Config.from_dict({'parameters': {}})) - self.assertEqual( - client.num_requests_per_key("reanalysis-era5-some-data"), 1) + self.assertEqual(FakeClient.num_requests_per_key("reanalysis-era5-some-data"), 1) if __name__ == '__main__': From 03592993ad7b6ceb4ca7ab263eb8242b2cb929db Mon Sep 17 00:00:00 2001 From: Alex Merose Date: Thu, 1 Dec 2022 22:44:18 -0800 Subject: [PATCH 7/9] fix lint: removed unused import. --- weather_dl/download_pipeline/clients_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/weather_dl/download_pipeline/clients_test.py b/weather_dl/download_pipeline/clients_test.py index c9aa6562..f211e324 100644 --- a/weather_dl/download_pipeline/clients_test.py +++ b/weather_dl/download_pipeline/clients_test.py @@ -15,7 +15,6 @@ import unittest from .clients import FakeClient, CdsClient, MarsClient -from .config import Config class MaxWorkersTest(unittest.TestCase): From 8fb7af199ef10ef24ae1a58e24e5bbae28f06760 Mon Sep 17 00:00:00 2001 From: mahrsee1997 Date: Tue, 3 Jan 2023 17:53:22 +0000 Subject: [PATCH 8/9] add support for aria2 for faster download --- environment.yml | 1 + weather_dl/download_pipeline/clients.py | 60 +++++++++++------------- weather_dl/download_pipeline/pipeline.py | 2 +- 3 files changed, 29 insertions(+), 34 deletions(-) diff --git a/environment.yml b/environment.yml index 4b0e5076..5e450476 100644 --- a/environment.yml +++ b/environment.yml @@ -23,6 +23,7 @@ dependencies: - numpy=1.22.4 - pandas=1.5.1 - google-cloud-sdk=410.0.0 + - aria2=1.36.0 - pip=22.3 - pip: - earthengine-api==0.1.329 diff --git a/weather_dl/download_pipeline/clients.py b/weather_dl/download_pipeline/clients.py index dcb56323..6c4c2657 100644 --- a/weather_dl/download_pipeline/clients.py +++ b/weather_dl/download_pipeline/clients.py @@ -20,19 +20,13 @@ import json import logging import os -import shutil +import subprocess import typing as t import warnings -from contextlib import closing from urllib.parse import urljoin -from urllib.request import ( - Request, - urlopen, -) import cdsapi import urllib3 -from apache_beam.io.gcp.gcsio import DEFAULT_READ_BUFFER_SIZE from ecmwfapi import ECMWFService, api from .config import Config, optimize_selection_partition @@ -53,11 +47,12 @@ class Client(abc.ABC): level: Default log level for the client. """ - def __init__(self, config: Config, level: int = logging.INFO) -> None: + def __init__(self, config: Config, level: int = logging.INFO, initialize_connection: bool = True) -> None: """Clients are initialized with the general CLI configuration.""" self.config = config self.logger = logging.getLogger(f'{__name__}.{type(self).__name__}') self.logger.setLevel(level) + self.initialize_connection = initialize_connection @abc.abstractmethod def retrieve(self, dataset: str, selection: t.Dict, output: str) -> None: @@ -99,7 +94,7 @@ class CdsClient(Client): """Name patterns of datasets that are hosted internally on CDS servers.""" cds_hosted_datasets = {'reanalysis-era'} - def __init__(self, config: Config, level: int = logging.INFO) -> None: + def __init__(self, config: Config, level: int = logging.INFO, initialize_connection: bool = True) -> None: super().__init__(config, level) self.c = cdsapi.Client( url=config.kwargs.get('api_url', os.environ.get('CDSAPI_URL')), @@ -169,23 +164,19 @@ class SplitMARSRequest(api.APIRequest): """Extended MARS APIRequest class that separates fetch and download stage.""" @retry_with_exponential_backoff def _download(self, url, path: str, size: int) -> None: - req = Request(url) - - if os.path.exists(path): - mode = "ab" - existing_size = os.path.getsize(path) - req.add_header("Range", "bytes=%s-" % existing_size) - else: - mode = "wb" - self.log( "Transferring %s into %s" % (self._bytename(size), path) ) self.log("From %s" % (url,)) - with open(path, mode) as f: - with closing(urlopen(req)) as http: - shutil.copyfileobj(http, f, DEFAULT_READ_BUFFER_SIZE) + dir_path, file_name = os.path.split(path) + try: + subprocess.run( + ['aria2c', '-x', '16', '-s', '16', url, '-d', dir_path, '-o', file_name, '--allow-overwrite'], + check=True, + capture_output=True) + except subprocess.CalledProcessError as e: + self.log(f'Failed download from ECMWF server {url!r} to {path!r} due to {e.stderr.decode("utf-8")}') def fetch(self, request: t.Dict) -> t.Dict: status = None @@ -225,16 +216,18 @@ def download(self, result: t.Dict, target: t.Optional[str] = None) -> None: class MARSECMWFServiceExtended(ECMWFService): """Extended MARS ECMFService class that separates fetch and download stage.""" def __init__(self, *args, **kwargs): + initialize_connection = kwargs.pop('initialize_connection') super().__init__(*args, **kwargs) - self.c = SplitMARSRequest( - self.url, - "services/%s" % (self.service,), - email=self.email, - key=self.key, - log=self.log, - verbose=self.verbose, - quiet=self.quiet, - ) + if initialize_connection: + self.c = SplitMARSRequest( + self.url, + "services/%s" % (self.service,), + email=self.email, + key=self.key, + log=self.log, + verbose=self.verbose, + quiet=self.quiet, + ) def fetch(self, req: t.Dict) -> t.Dict: return self.c.fetch(req) @@ -263,15 +256,16 @@ class MarsClient(Client): level: Default log level for the client. """ - def __init__(self, config: Config, level: int = logging.INFO) -> None: - super().__init__(config, level) + def __init__(self, config: Config, level: int = logging.INFO, initialize_connection: bool = True) -> None: + super().__init__(config, level, initialize_connection) self.c = MARSECMWFServiceExtended( "mars", key=config.kwargs.get('api_key', os.environ.get("MARSAPI_KEY")), url=config.kwargs.get('api_url', os.environ.get("MARSAPI_URL")), email=config.kwargs.get('api_email', os.environ.get("MARSAPI_EMAIL")), log=self.logger.debug, - verbose=True + verbose=True, + initialize_connection=initialize_connection ) def retrieve(self, dataset: str, selection: t.Dict, output: str) -> None: diff --git a/weather_dl/download_pipeline/pipeline.py b/weather_dl/download_pipeline/pipeline.py index d101993d..c88db65d 100644 --- a/weather_dl/download_pipeline/pipeline.py +++ b/weather_dl/download_pipeline/pipeline.py @@ -197,7 +197,7 @@ def run(argv: t.List[str], save_main_session: bool = True) -> PipelineArgs: manifest = LocalManifest(Location(local_dir)) num_requesters_per_key = known_args.num_requests_per_key - client = CLIENTS[client_name](configs[0]) + client = CLIENTS[client_name](configs[0], initialize_connection=False) if num_requesters_per_key == -1: num_requesters_per_key = client.num_requests_per_key(config.dataset) From 7c7aa543a2bd80e5de032e2a3656bb207a08f895 Mon Sep 17 00:00:00 2001 From: mahrsee1997 Date: Wed, 4 Jan 2023 07:38:33 +0000 Subject: [PATCH 9/9] code changes as per Alex feedback. --- weather_dl/download_pipeline/clients.py | 40 +++++++++++------------- weather_dl/download_pipeline/pipeline.py | 2 +- 2 files changed, 19 insertions(+), 23 deletions(-) diff --git a/weather_dl/download_pipeline/clients.py b/weather_dl/download_pipeline/clients.py index 6c4c2657..2d01bb93 100644 --- a/weather_dl/download_pipeline/clients.py +++ b/weather_dl/download_pipeline/clients.py @@ -47,12 +47,11 @@ class Client(abc.ABC): level: Default log level for the client. """ - def __init__(self, config: Config, level: int = logging.INFO, initialize_connection: bool = True) -> None: + def __init__(self, config: Config, level: int = logging.INFO) -> None: """Clients are initialized with the general CLI configuration.""" self.config = config self.logger = logging.getLogger(f'{__name__}.{type(self).__name__}') self.logger.setLevel(level) - self.initialize_connection = initialize_connection @abc.abstractmethod def retrieve(self, dataset: str, selection: t.Dict, output: str) -> None: @@ -94,7 +93,7 @@ class CdsClient(Client): """Name patterns of datasets that are hosted internally on CDS servers.""" cds_hosted_datasets = {'reanalysis-era'} - def __init__(self, config: Config, level: int = logging.INFO, initialize_connection: bool = True) -> None: + def __init__(self, config: Config, level: int = logging.INFO) -> None: super().__init__(config, level) self.c = cdsapi.Client( url=config.kwargs.get('api_url', os.environ.get('CDSAPI_URL')), @@ -216,18 +215,16 @@ def download(self, result: t.Dict, target: t.Optional[str] = None) -> None: class MARSECMWFServiceExtended(ECMWFService): """Extended MARS ECMFService class that separates fetch and download stage.""" def __init__(self, *args, **kwargs): - initialize_connection = kwargs.pop('initialize_connection') super().__init__(*args, **kwargs) - if initialize_connection: - self.c = SplitMARSRequest( - self.url, - "services/%s" % (self.service,), - email=self.email, - key=self.key, - log=self.log, - verbose=self.verbose, - quiet=self.quiet, - ) + self.c = SplitMARSRequest( + self.url, + "services/%s" % (self.service,), + email=self.email, + key=self.key, + log=self.log, + verbose=self.verbose, + quiet=self.quiet, + ) def fetch(self, req: t.Dict) -> t.Dict: return self.c.fetch(req) @@ -256,19 +253,18 @@ class MarsClient(Client): level: Default log level for the client. """ - def __init__(self, config: Config, level: int = logging.INFO, initialize_connection: bool = True) -> None: - super().__init__(config, level, initialize_connection) + def __init__(self, config: Config, level: int = logging.INFO) -> None: + super().__init__(config, level) + + def retrieve(self, dataset: str, selection: t.Dict, output: str) -> None: self.c = MARSECMWFServiceExtended( "mars", - key=config.kwargs.get('api_key', os.environ.get("MARSAPI_KEY")), - url=config.kwargs.get('api_url', os.environ.get("MARSAPI_URL")), - email=config.kwargs.get('api_email', os.environ.get("MARSAPI_EMAIL")), + key=self.config.kwargs.get('api_key', os.environ.get("MARSAPI_KEY")), + url=self.config.kwargs.get('api_url', os.environ.get("MARSAPI_URL")), + email=self.config.kwargs.get('api_email', os.environ.get("MARSAPI_EMAIL")), log=self.logger.debug, verbose=True, - initialize_connection=initialize_connection ) - - def retrieve(self, dataset: str, selection: t.Dict, output: str) -> None: selection_ = optimize_selection_partition(selection) with StdoutLogger(self.logger, level=logging.DEBUG): result = self.c.fetch(req=selection_) diff --git a/weather_dl/download_pipeline/pipeline.py b/weather_dl/download_pipeline/pipeline.py index c88db65d..d101993d 100644 --- a/weather_dl/download_pipeline/pipeline.py +++ b/weather_dl/download_pipeline/pipeline.py @@ -197,7 +197,7 @@ def run(argv: t.List[str], save_main_session: bool = True) -> PipelineArgs: manifest = LocalManifest(Location(local_dir)) num_requesters_per_key = known_args.num_requests_per_key - client = CLIENTS[client_name](configs[0], initialize_connection=False) + client = CLIENTS[client_name](configs[0]) if num_requesters_per_key == -1: num_requesters_per_key = client.num_requests_per_key(config.dataset)