Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any

from airflow.providers.common.compat.sdk import BaseOperator, BaseOperatorLink, XCom, conf
from airflow.providers.common.compat.sdk import AirflowException, BaseOperator, BaseOperatorLink, XCom, conf
from airflow.providers.dbt.cloud.hooks.dbt import (
DbtCloudHook,
DbtCloudJobRunException,
Expand Down Expand Up @@ -70,7 +70,9 @@ class DbtCloudRunJobOperator(BaseOperator):
enabled but could be disabled to perform an asynchronous wait for a long-running job run execution
using the ``DbtCloudJobRunSensor``.
:param timeout: Time in seconds to wait for a job run to reach a terminal status for non-asynchronous
waits. Used only if ``wait_for_termination`` is True. Defaults to 7 days.
waits. Used only if ``wait_for_termination`` is True. This limits how long the operator waits for the
job to complete and does not imply job cancellation. Task-level timeouts should be
enforced via ``execution_timeout``. Defaults to 7 days.
:param check_interval: Time in seconds to check on a job run's status for non-asynchronous waits.
Used only if ``wait_for_termination`` is True. Defaults to 60 seconds.
:param additional_run_config: Optional. Any additional parameters that should be included in the API
Expand All @@ -83,6 +85,9 @@ class DbtCloudRunJobOperator(BaseOperator):
https://docs.getdbt.com/dbt-cloud/api-v2#/operations/Retry%20Failed%20Job
:param deferrable: Run operator in the deferrable mode
:param hook_params: Extra arguments passed to the DbtCloudHook constructor.
:param execution_timeout: Maximum time allowed for the task to run. If exceeded, the dbt Cloud
job will be cancelled and the task will fail. When both ``execution_timeout`` and
``timeout`` are set, the earlier deadline takes precedence.
:return: The ID of the triggered dbt Cloud job run.
"""

Expand Down Expand Up @@ -212,16 +217,26 @@ def execute(self, context: Context):
raise DbtCloudJobRunException(f"Job run {self.run_id} has failed or has been cancelled.")

return self.run_id

# Derive absolute deadlines for deferrable execution.
# execution_timeout is a hard task-level limit (cancels the job),
# while timeout only limits how long we wait for the job to finish.
# If both are set, the earliest deadline wins.
end_time = time.time() + self.timeout
execution_deadline = None
if self.execution_timeout:
execution_deadline = time.time() + self.execution_timeout.total_seconds()

job_run_info = JobRunInfo(account_id=self.account_id, run_id=self.run_id)
job_run_status = self.hook.get_job_run_status(**job_run_info)
if not DbtCloudJobRunStatus.is_terminal(job_run_status):
self.defer(
timeout=self.execution_timeout,
timeout=None,
trigger=DbtCloudRunJobTrigger(
conn_id=self.dbt_cloud_conn_id,
run_id=self.run_id,
end_time=end_time,
execution_deadline=execution_deadline,
account_id=self.account_id,
poll_interval=self.check_interval,
),
Expand Down Expand Up @@ -252,6 +267,12 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> int:
raise DbtCloudJobRunException(f"Job run {self.run_id} has been cancelled.")
if event["status"] == "error":
raise DbtCloudJobRunException(f"Job run {self.run_id} has failed.")

# Enforce execution_timeout semantics in deferrable mode by cancelling the job.
if event["status"] == "timeout":
self.hook.cancel_job_run(account_id=self.account_id, run_id=self.run_id)
raise AirflowException(f"Job run {self.run_id} has timed out.")

self.log.info(event["message"])
return int(event["run_id"])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class DbtCloudRunJobTrigger(BaseTrigger):
:param conn_id: The connection identifier for connecting to Dbt.
:param run_id: The ID of a dbt Cloud job.
:param end_time: Time in seconds to wait for a job run to reach a terminal status. Defaults to 7 days.
:param execution_deadline: Optional absolute timestamp (in seconds since the epoch) after which
the task is considered timed out.
:param account_id: The ID of a dbt Cloud account.
:param poll_interval: polling period in seconds to check for the status.
:param hook_params: Extra arguments passed to the DbtCloudHook constructor.
Expand All @@ -47,12 +49,14 @@ def __init__(
poll_interval: float,
account_id: int | None,
hook_params: dict[str, Any] | None = None,
execution_deadline: float | None = None,
):
super().__init__()
self.run_id = run_id
self.account_id = account_id
self.conn_id = conn_id
self.end_time = end_time
self.execution_deadline = execution_deadline
self.poll_interval = poll_interval
self.hook_params = hook_params or {}

Expand All @@ -65,6 +69,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"account_id": self.account_id,
"conn_id": self.conn_id,
"end_time": self.end_time,
"execution_deadline": self.execution_deadline,
"poll_interval": self.poll_interval,
"hook_params": self.hook_params,
},
Expand All @@ -75,6 +80,17 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
hook = DbtCloudHook(self.conn_id, **self.hook_params)
try:
while await self.is_still_running(hook):
if self.execution_deadline is not None:
if self.execution_deadline < time.time():
yield TriggerEvent(
{
"status": "timeout",
"message": f"Job run {self.run_id} has timed out.",
"run_id": self.run_id,
}
)
return

if self.end_time < time.time():
# Perform a final status check before declaring timeout, in case the
# job completed between the last poll and the timeout expiry.
Expand Down
38 changes: 37 additions & 1 deletion providers/dbt/cloud/tests/unit/dbt/cloud/operators/test_dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import pytest

from airflow.models import DAG, Connection
from airflow.providers.common.compat.sdk import TaskDeferred, timezone
from airflow.providers.common.compat.sdk import AirflowException, TaskDeferred, timezone
from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook, DbtCloudJobRunException, DbtCloudJobRunStatus
from airflow.providers.dbt.cloud.operators.dbt import (
DbtCloudGetJobRunArtifactOperator,
Expand Down Expand Up @@ -214,6 +214,42 @@ def test_dbt_run_job_op_async(self, mock_trigger_job_run, mock_dbt_hook, mock_jo
dbt_op.execute(MagicMock())
assert isinstance(exc.value.trigger, DbtCloudRunJobTrigger), "Trigger is not a DbtCloudRunJobTrigger"

def test_execute_complete_timeout_cancels_job(self):
    """
    Verify that when a deferrable dbt job emits a timeout event,
    the operator cancels the job and fails.
    """
    op = DbtCloudRunJobOperator(
        task_id=TASK_ID,
        dbt_cloud_conn_id=ACCOUNT_ID_CONN,
        job_id=JOB_ID,
        dag=self.dag,
        deferrable=True,
    )

    # Simulate a job that was already triggered, with a mocked hook so
    # the cancellation call can be asserted without hitting the API.
    op.run_id = RUN_ID
    op.hook = MagicMock()

    event = {
        "status": "timeout",
        "run_id": RUN_ID,
        "message": "Job run timed out.",
    }

    # The timeout event must surface as a task failure...
    with pytest.raises(AirflowException, match="has timed out"):
        op.execute_complete(context=self.mock_context, event=event)

    # ...and the dbt Cloud job run must have been cancelled exactly once.
    op.hook.cancel_job_run.assert_called_once_with(
        account_id=op.account_id,
        run_id=RUN_ID,
    )

@patch(
"airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_job_by_name",
return_value=mock_response_json(DEFAULT_ACCOUNT_JOB_RESPONSE),
Expand Down
33 changes: 33 additions & 0 deletions providers/dbt/cloud/tests/unit/dbt/cloud/triggers/test_dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class TestDbtCloudRunJobTrigger:
CONN_ID = "dbt_cloud_default"
ACCOUNT_ID = 12340
END_TIME = time.time() + 60 * 60 * 24 * 7
EXECUTION_DEADLINE = time.time() + 60 * 60 * 24 * 7
POLL_INTERVAL = 3.0

def test_serialization(self):
Expand All @@ -43,6 +44,7 @@ def test_serialization(self):
conn_id=self.CONN_ID,
poll_interval=self.POLL_INTERVAL,
end_time=self.END_TIME,
execution_deadline=self.EXECUTION_DEADLINE,
run_id=self.RUN_ID,
account_id=self.ACCOUNT_ID,
hook_params={"retry_delay": 10},
Expand All @@ -54,6 +56,7 @@ def test_serialization(self):
"account_id": self.ACCOUNT_ID,
"conn_id": self.CONN_ID,
"end_time": self.END_TIME,
"execution_deadline": self.EXECUTION_DEADLINE,
"poll_interval": self.POLL_INTERVAL,
"hook_params": {"retry_delay": 10},
}
Expand Down Expand Up @@ -255,6 +258,36 @@ async def test_dbt_job_run_timeout_with_final_status_check(self, mock_get_job_st
)
assert expected == actual

@pytest.mark.asyncio
@mock.patch("airflow.providers.dbt.cloud.triggers.dbt.DbtCloudRunJobTrigger.is_still_running")
async def test_dbt_job_run_execution_timeout(self, mocked_is_still_running):
    """Assert that run emits timeout event after execution_deadline elapsed"""
    # Job never reaches a terminal state on its own; the deadline is set
    # to "now" so it is already expired on the first loop iteration.
    mocked_is_still_running.return_value = True

    dbt_trigger = DbtCloudRunJobTrigger(
        run_id=self.RUN_ID,
        account_id=self.ACCOUNT_ID,
        conn_id=self.CONN_ID,
        poll_interval=self.POLL_INTERVAL,
        end_time=time.time() + 60,
        execution_deadline=time.time(),
    )

    expected = TriggerEvent(
        {
            "status": "timeout",
            "message": f"Job run {self.RUN_ID} has timed out.",
            "run_id": self.RUN_ID,
        }
    )

    # Advance the async generator to its first yielded event.
    emitted = await dbt_trigger.run().asend(None)

    assert expected == emitted

@pytest.mark.asyncio
@pytest.mark.parametrize(
("mock_response", "expected_status"),
Expand Down
Loading