Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions airflow/emr_serverless/hooks/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook

DEFAULT_COUNTDOWN = 25 * 60
DEFAULT_CHECK_INTERVAL_SECONDS = 60


class EmrServerlessHook(AwsBaseHook):
"""
Expand Down Expand Up @@ -54,8 +57,8 @@ def waiter(
failure_states: Set,
object_type: str,
action: str,
countdown: int = 25 * 60,
check_interval_seconds: int = 60,
countdown: int = DEFAULT_COUNTDOWN,
check_interval_seconds: int = DEFAULT_CHECK_INTERVAL_SECONDS,
) -> None:
"""
Will run the sensor until it turns True.
Expand Down
13 changes: 12 additions & 1 deletion airflow/emr_serverless/operators/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@
from typing import TYPE_CHECKING, Dict, Optional, Sequence
from uuid import uuid4

from emr_serverless.hooks.emr import EmrServerlessHook
from emr_serverless.hooks.emr import (
EmrServerlessHook,
DEFAULT_COUNTDOWN,
DEFAULT_CHECK_INTERVAL_SECONDS,
)

from emr_serverless.sensors.emr import (
EmrServerlessApplicationSensor,
EmrServerlessJobSensor,
Expand Down Expand Up @@ -158,6 +163,8 @@ def __init__(
config: Optional[dict] = None,
wait_for_completion: bool = True,
aws_conn_id: str = "aws_default",
countdown: int = DEFAULT_COUNTDOWN,
check_interval_seconds: int = DEFAULT_CHECK_INTERVAL_SECONDS,
**kwargs,
):
self.aws_conn_id = aws_conn_id
Expand All @@ -167,6 +174,8 @@ def __init__(
self.configuration_overrides = configuration_overrides
self.wait_for_completion = wait_for_completion
self.config = config or {}
self.countdown = countdown
self.check_interval_seconds = check_interval_seconds
super().__init__(**kwargs)

self.client_request_token = client_request_token or str(uuid4())
Expand Down Expand Up @@ -221,6 +230,8 @@ def execute(self, context: "Context") -> Dict:
failure_states=EmrServerlessJobSensor.FAILURE_STATES,
object_type="job",
action="run",
countdown=self.countdown,
check_interval_seconds=self.check_interval_seconds
)
return response["jobRunId"]

Expand Down