diff --git a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/triggers/dbt.py b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/triggers/dbt.py index 9d4c59473b1f0..24e05e1f236b4 100644 --- a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/triggers/dbt.py +++ b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/triggers/dbt.py @@ -75,17 +75,21 @@ async def run(self) -> AsyncIterator[TriggerEvent]: hook = DbtCloudHook(self.conn_id, **self.hook_params) try: while await self.is_still_running(hook): - if self.end_time < time.time(): - yield TriggerEvent( - { - "status": "error", - "message": f"Job run {self.run_id} has not reached a terminal status after " - f"{self.end_time} seconds.", - "run_id": self.run_id, - } - ) - return await asyncio.sleep(self.poll_interval) + if self.end_time < time.time(): + # Final status check: the job may have completed during the sleep. + if await self.is_still_running(hook): + yield TriggerEvent( + { + "status": "error", + "message": f"Job run {self.run_id} has not reached a terminal " + f"status within the configured timeout.", + "run_id": self.run_id, + } + ) + return + # Job reached a terminal state — exit loop to handle below. + break job_run_status = await hook.get_job_status(self.run_id, self.account_id) if job_run_status == DbtCloudJobRunStatus.SUCCESS.value: yield TriggerEvent( diff --git a/providers/dbt/cloud/tests/unit/dbt/cloud/triggers/test_dbt.py b/providers/dbt/cloud/tests/unit/dbt/cloud/triggers/test_dbt.py index 57ca1848d1542..539f5680c8329 100644 --- a/providers/dbt/cloud/tests/unit/dbt/cloud/triggers/test_dbt.py +++ b/providers/dbt/cloud/tests/unit/dbt/cloud/triggers/test_dbt.py @@ -202,13 +202,13 @@ async def test_dbt_job_run_exception(self, mock_get_job_status, mocked_is_still_ @mock.patch("airflow.providers.dbt.cloud.triggers.dbt.DbtCloudRunJobTrigger.is_still_running") @mock.patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_job_status") async def test_dbt_job_run_timeout(self, mock_get_job_status, mocked_is_still_running): - """Assert that run timeout after end_time elapsed""" + """Assert that run yields a timeout error after end_time has elapsed.""" mocked_is_still_running.return_value = True mock_get_job_status.side_effect = Exception("Test exception") end_time = time.time() trigger = DbtCloudRunJobTrigger( conn_id=self.CONN_ID, - poll_interval=self.POLL_INTERVAL, + poll_interval=0.1, end_time=end_time, run_id=self.RUN_ID, account_id=self.ACCOUNT_ID, @@ -219,7 +219,35 @@ async def test_dbt_job_run_timeout(self, mock_get_job_status, mocked_is_still_ru { "status": "error", "message": f"Job run {self.RUN_ID} has not reached a terminal status " - f"after {end_time} seconds.", + f"within the configured timeout.", + "run_id": self.RUN_ID, + } + ) + assert expected == actual + + @pytest.mark.asyncio + @mock.patch("airflow.providers.dbt.cloud.triggers.dbt.DbtCloudRunJobTrigger.is_still_running") + @mock.patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_job_status") + async def test_dbt_job_run_timeout_but_job_completes( + self, mock_get_job_status, mocked_is_still_running + ): + """Assert that a job completing at the timeout boundary is treated as success, not timeout.""" + mocked_is_still_running.side_effect = [True, False] + mock_get_job_status.return_value = DbtCloudJobRunStatus.SUCCESS.value + end_time = time.time() + trigger = DbtCloudRunJobTrigger( + conn_id=self.CONN_ID, + poll_interval=0.1, + end_time=end_time, + run_id=self.RUN_ID, + account_id=self.ACCOUNT_ID, + ) + generator = trigger.run() + actual = await generator.asend(None) + expected = TriggerEvent( + { + "status": "success", + "message": f"Job run {self.RUN_ID} has completed successfully.", "run_id": self.RUN_ID, } )