Skip to content

Commit

Permalink
Faster data transfers from MARS. (#235)
Browse files Browse the repository at this point in the history
I'm taking a leaf from @mahrsee1997's PR #148 so that we can copy data from the MARS server faster (using a larger buffer size). Thanks for the primary contribution here, Rahul.

* restructured the fetch stage to be a Composite Beam transform and separated the request & download stage of fetching

* retry logic of downloads for MARS client & other cosmetic changes.

* Remove fetch / dl split

* retrieve in two steps.

* rm fetch + dl methods.

* Fix: `nim_requests_per_key` does not require class construction.

* fix lint: removed unused import.

* add support for aria2 for faster download

* code changes as per Alex feedback.

Co-authored-by: mahrsee1997 <[email protected]>
  • Loading branch information
alxmrs and mahrsee1997 authored Jan 4, 2023
1 parent 00e659c commit 7cae996
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 27 deletions.
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ dependencies:
- numpy=1.22.4
- pandas=1.5.1
- google-cloud-sdk=410.0.0
- aria2=1.36.0
- pip=22.3
- pip:
- earthengine-api==0.1.329
Expand Down
110 changes: 96 additions & 14 deletions weather_dl/download_pipeline/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,17 @@
import json
import logging
import os
import subprocess
import typing as t
import warnings
from urllib.parse import urljoin

import cdsapi
import urllib3
from ecmwfapi import ECMWFService
from ecmwfapi import ECMWFService, api

from .config import Config, optimize_selection_partition
from .util import retry_with_exponential_backoff

warnings.simplefilter(
"ignore", category=urllib3.connectionpool.InsecureRequestWarning)
Expand Down Expand Up @@ -55,8 +58,9 @@ def retrieve(self, dataset: str, selection: t.Dict, output: str) -> None:
"""Download from data source."""
pass

@classmethod
@abc.abstractmethod
def num_requests_per_key(self, dataset: str) -> int:
def num_requests_per_key(cls, dataset: str) -> int:
"""Specifies the number of workers to be used per api key for the dataset."""
pass

Expand Down Expand Up @@ -108,7 +112,8 @@ def retrieve(self, dataset: str, selection: t.Dict, target: str) -> None:
def license_url(self):
return 'https://cds.climate.copernicus.eu/api/v2/terms/static/licence-to-use-copernicus-products.pdf'

def num_requests_per_key(self, dataset: str) -> int:
@classmethod
def num_requests_per_key(cls, dataset: str) -> int:
"""Number of requests per key from the CDS API.
CDS has dynamic, data-specific limits, defined here:
Expand All @@ -123,7 +128,7 @@ def num_requests_per_key(self, dataset: str) -> int:
https://cds.climate.copernicus.eu/cdsapp#!/yourrequests
"""
# TODO(#15): Parse live CDS limits API to set data-specific limits.
for internal_set in self.cds_hosted_datasets:
for internal_set in cls.cds_hosted_datasets:
if dataset.startswith(internal_set):
return 5
return 2
Expand Down Expand Up @@ -154,6 +159,80 @@ def __exit__(self, exc_type, exc_value, traceback):
self._redirector.__exit__(exc_type, exc_value, traceback)


class SplitMARSRequest(api.APIRequest):
"""Extended MARS APIRequest class that separates fetch and download stage."""
@retry_with_exponential_backoff
def _download(self, url, path: str, size: int) -> None:
self.log(
"Transferring %s into %s" % (self._bytename(size), path)
)
self.log("From %s" % (url,))

dir_path, file_name = os.path.split(path)
try:
subprocess.run(
['aria2c', '-x', '16', '-s', '16', url, '-d', dir_path, '-o', file_name, '--allow-overwrite'],
check=True,
capture_output=True)
except subprocess.CalledProcessError as e:
self.log(f'Failed download from ECMWF server {url!r} to {path!r} due to {e.stderr.decode("utf-8")}')

def fetch(self, request: t.Dict) -> t.Dict:
status = None

self.connection.submit("%s/%s/requests" % (self.url, self.service), request)
self.log("Request submitted")
self.log("Request id: " + self.connection.last.get("name"))
if self.connection.status != status:
status = self.connection.status
self.log("Request is %s" % (status,))

while not self.connection.ready():
if self.connection.status != status:
status = self.connection.status
self.log("Request is %s" % (status,))
self.connection.wait()

if self.connection.status != status:
status = self.connection.status
self.log("Request is %s" % (status,))

result = self.connection.result()
return result

def download(self, result: t.Dict, target: t.Optional[str] = None) -> None:
if target:
if os.path.exists(target):
# Empty the target file, if it already exists, otherwise the
# transfer below might be fooled into thinking we're resuming
# an interrupted download.
open(target, "w").close()

self._download(urljoin(self.url, result["href"]), target, result["size"])
self.connection.cleanup()


class MARSECMWFServiceExtended(ECMWFService):
"""Extended MARS ECMFService class that separates fetch and download stage."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.c = SplitMARSRequest(
self.url,
"services/%s" % (self.service,),
email=self.email,
key=self.key,
log=self.log,
verbose=self.verbose,
quiet=self.quiet,
)

def fetch(self, req: t.Dict) -> t.Dict:
return self.c.fetch(req)

def download(self, res: t.Dict, target: str) -> None:
self.c.download(res, target)


class MarsClient(Client):
"""A client to access data from the Meteorological Archival and Retrieval System (MARS).
Expand All @@ -176,25 +255,27 @@ class MarsClient(Client):

def __init__(self, config: Config, level: int = logging.INFO) -> None:
super().__init__(config, level)
self.c = ECMWFService(

def retrieve(self, dataset: str, selection: t.Dict, output: str) -> None:
self.c = MARSECMWFServiceExtended(
"mars",
key=config.kwargs.get('api_key', os.environ.get("MARSAPI_KEY")),
url=config.kwargs.get('api_url', os.environ.get("MARSAPI_URL")),
email=config.kwargs.get('api_email', os.environ.get("MARSAPI_EMAIL")),
key=self.config.kwargs.get('api_key', os.environ.get("MARSAPI_KEY")),
url=self.config.kwargs.get('api_url', os.environ.get("MARSAPI_URL")),
email=self.config.kwargs.get('api_email', os.environ.get("MARSAPI_EMAIL")),
log=self.logger.debug,
verbose=True
verbose=True,
)

def retrieve(self, dataset: str, selection: t.Dict, output: str) -> None:
selection_ = optimize_selection_partition(selection)
with StdoutLogger(self.logger, level=logging.DEBUG):
self.c.execute(req=selection_, target=output)
result = self.c.fetch(req=selection_)
self.c.download(result, target=output)

@property
def license_url(self):
return 'https://apps.ecmwf.int/datasets/licences/general/'

def num_requests_per_key(self, dataset: str) -> int:
@classmethod
def num_requests_per_key(cls, dataset: str) -> int:
"""Number of requests per key (or user) for the Mars API.
Mars allows 2 active requests per user and 20 queued requests per user, as of Sept 27, 2021.
Expand All @@ -221,7 +302,8 @@ def retrieve(self, dataset: str, selection: t.Dict, output: str) -> None:
def license_url(self):
return 'lorem ipsum'

def num_requests_per_key(self, dataset: str) -> int:
@classmethod
def num_requests_per_key(cls, dataset: str) -> int:
return 1


Expand Down
17 changes: 4 additions & 13 deletions weather_dl/download_pipeline/clients_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,20 @@
import unittest

from .clients import FakeClient, CdsClient, MarsClient
from .config import Config


class MaxWorkersTest(unittest.TestCase):
def test_cdsclient_internal(self):
client = CdsClient(Config.from_dict({'parameters': {'api_url': 'url', 'api_key': 'key'}}))
self.assertEqual(
client.num_requests_per_key("reanalysis-era5-some-data"), 5)
self.assertEqual(CdsClient.num_requests_per_key("reanalysis-era5-some-data"), 5)

def test_cdsclient_mars_hosted(self):
client = CdsClient(Config.from_dict({'parameters': {'api_url': 'url', 'api_key': 'key'}}))
self.assertEqual(
client.num_requests_per_key("reanalysis-carra-height-levels"), 2)
self.assertEqual(CdsClient.num_requests_per_key("reanalysis-carra-height-levels"), 2)

def test_marsclient(self):
client = MarsClient(Config.from_dict({'parameters': {}}))
self.assertEqual(
client.num_requests_per_key("reanalysis-era5-some-data"), 2)
self.assertEqual(MarsClient.num_requests_per_key("reanalysis-era5-some-data"), 2)

def test_fakeclient(self):
client = FakeClient(Config.from_dict({'parameters': {}}))
self.assertEqual(
client.num_requests_per_key("reanalysis-era5-some-data"), 1)
self.assertEqual(FakeClient.num_requests_per_key("reanalysis-era5-some-data"), 1)


if __name__ == '__main__':
Expand Down

0 comments on commit 7cae996

Please sign in to comment.