diff --git a/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py b/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
index 6dc6e0ab77529..38ac0ff8fc4da 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
@@ -63,6 +63,7 @@
 )
 from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
 from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
+from airflow.triggers.base import StartTriggerArgs
 
 if TYPE_CHECKING:
     from google.api_core import operation
@@ -1880,6 +1881,8 @@ class DataprocSubmitJobOperator(GoogleCloudBaseOperator):
         The value is considered only when running in deferrable mode. Must be greater than 0.
     :param cancel_on_kill: Flag which indicates whether cancel the hook's job or not, when on_kill is called
     :param wait_timeout: How many seconds wait for job to be ready. Used only if ``asynchronous`` is False
+    :param start_from_trigger: If True and deferrable is True, the operator will start directly
+        from the triggerer without occupying a worker slot.
     """
 
     template_fields: Sequence[str] = (
@@ -1894,6 +1897,15 @@ class DataprocSubmitJobOperator(GoogleCloudBaseOperator):
 
     operator_extra_links = (DataprocJobLink(),)
 
+    start_trigger_args = StartTriggerArgs(
+        trigger_cls="airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobDirectTrigger",
+        trigger_kwargs={},
+        next_method="execute_complete",
+        next_kwargs=None,
+        timeout=None,
+    )
+    start_from_trigger = False
+
     def __init__(
         self,
         *,
@@ -1911,6 +1923,7 @@ def __init__(
         polling_interval_seconds: int = 10,
         cancel_on_kill: bool = True,
         wait_timeout: int | None = None,
+        start_from_trigger: bool = False,
         openlineage_inject_parent_job_info: bool = conf.getboolean(
             "openlineage", "spark_inject_parent_job_info", fallback=False
         ),
@@ -1938,9 +1951,28 @@ def __init__(
         self.hook: DataprocHook | None = None
         self.job_id: str | None = None
         self.wait_timeout = wait_timeout
+        self.start_from_trigger = start_from_trigger
         self.openlineage_inject_parent_job_info = openlineage_inject_parent_job_info
         self.openlineage_inject_transport_info = openlineage_inject_transport_info
 
+        if self.deferrable and self.start_from_trigger:
+            self.start_trigger_args = StartTriggerArgs(
+                trigger_cls="airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobDirectTrigger",
+                trigger_kwargs={
+                    "job": self.job,
+                    "project_id": self.project_id,
+                    "region": self.region,
+                    "gcp_conn_id": self.gcp_conn_id,
+                    "impersonation_chain": self.impersonation_chain,
+                    "polling_interval_seconds": self.polling_interval_seconds,
+                    "cancel_on_kill": self.cancel_on_kill,
+                    "request_id": self.request_id,
+                },
+                next_method="execute_complete",
+                next_kwargs=None,
+                timeout=None,
+            )
+
     def execute(self, context: Context):
         self.log.info("Submitting job")
         self.hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
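For context on how the two new knobs combine: `start_from_trigger` only changes behavior when `deferrable` is also set, in which case the scheduler hands `start_trigger_args` straight to the triggerer and no worker slot is occupied first. A minimal usage sketch (project, region, and cluster names are placeholders):

```python
from airflow.providers.google.cloud.operators.dataproc import DataprocSubmitJobOperator

spark_pi = DataprocSubmitJobOperator(
    task_id="spark_pi",
    project_id="my-project",  # placeholder
    region="europe-west1",    # placeholder
    job={
        "reference": {"project_id": "my-project"},
        "placement": {"cluster_name": "my-cluster"},  # placeholder cluster
        "spark_job": {
            "jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"],
            "main_class": "org.apache.spark.examples.SparkPi",
        },
    },
    deferrable=True,          # start_from_trigger only takes effect in deferrable mode
    start_from_trigger=True,  # submit from the triggerer; no worker slot is used first
)
```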
diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py b/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py
index 73dd18c4c294a..76aa2f5021fe5 100644
--- a/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py
+++ b/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py
@@ -230,6 +230,155 @@ async def run(self):
             raise e
 
 
+class DataprocSubmitJobDirectTrigger(DataprocBaseTrigger):
+    """
+    Trigger that submits a Dataproc job and polls for its completion.
+
+    Used for direct-to-triggerer functionality where job submission and polling
+    are handled entirely by the triggerer without requiring a worker.
+
+    :param job: The job resource dict to submit.
+    :param project_id: Google Cloud Project where the job is running.
+    :param region: The Cloud Dataproc region in which to handle the request.
+    :param gcp_conn_id: The connection ID used to connect to Google Cloud Platform.
+    :param impersonation_chain: Optional service account to impersonate using short-term credentials.
+    :param polling_interval_seconds: Polling period in seconds to check for the status.
+    :param cancel_on_kill: Flag indicating whether to cancel the job when on_kill is called.
+    :param request_id: Optional unique id used to identify the request.
+    """
+
+    def __init__(
+        self,
+        job: dict,
+        request_id: str | None = None,
+        **kwargs,
+    ):
+        self.job = job
+        self.request_id = request_id
+        self.job_id: str | None = None
+        super().__init__(**kwargs)
+
+    def serialize(self):
+        return (
+            "airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobDirectTrigger",
+            {
+                "job": self.job,
+                "request_id": self.request_id,
+                "project_id": self.project_id,
+                "region": self.region,
+                "gcp_conn_id": self.gcp_conn_id,
+                "impersonation_chain": self.impersonation_chain,
+                "polling_interval_seconds": self.polling_interval_seconds,
+                "cancel_on_kill": self.cancel_on_kill,
+            },
+        )
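The triggerer recreates the trigger from the classpath and kwargs returned by `serialize()`, so the round trip has to be lossless. A quick sanity-check sketch, assuming the provider package is importable (no GCP calls are made here):

```python
from airflow.providers.google.cloud.triggers.dataproc import DataprocSubmitJobDirectTrigger

trigger = DataprocSubmitJobDirectTrigger(
    job={"placement": {"cluster_name": "demo-cluster"}},
    project_id="my-project",
    region="europe-west1",
    gcp_conn_id="google_cloud_default",
    polling_interval_seconds=10,
    cancel_on_kill=True,
)

classpath, kwargs = trigger.serialize()
# Rebuilding from the serialized kwargs must produce an equivalent trigger.
rebuilt = DataprocSubmitJobDirectTrigger(**kwargs)
assert rebuilt.serialize() == (classpath, kwargs)
```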
+ """ + if AIRFLOW_V_3_0_PLUS: + task_state = await self.get_task_state() + else: + task_instance = self.get_task_instance() # type: ignore[call-arg] + task_state = task_instance.state + return task_state != TaskInstanceState.DEFERRED + + async def run(self) -> AsyncIterator[TriggerEvent]: + try: + hook = self.get_async_hook() + self.log.info("Submitting Dataproc job.") + job_object = await hook.submit_job( + project_id=self.project_id, + region=self.region, + job=self.job, + request_id=self.request_id, + ) + self.job_id = job_object.reference.job_id + self.log.info("Dataproc job %s submitted successfully.", self.job_id) + + while True: + job = await hook.get_job(project_id=self.project_id, region=self.region, job_id=self.job_id) + state = job.status.state + self.log.info("Dataproc job: %s is in state: %s", self.job_id, state) + if state in (JobStatus.State.DONE, JobStatus.State.CANCELLED, JobStatus.State.ERROR): + break + await asyncio.sleep(self.polling_interval_seconds) + + yield TriggerEvent( + {"job_id": self.job_id, "job_state": JobStatus.State(state).name, "job": Job.to_dict(job)} + ) + except asyncio.CancelledError: + self.log.info("Task got cancelled.") + try: + if self.job_id and self.cancel_on_kill and await self.safe_to_cancel(): + self.log.info("Cancelling the job: %s", self.job_id) + self.get_sync_hook().cancel_job( + job_id=self.job_id, project_id=self.project_id, region=self.region + ) + self.log.info("Job: %s is cancelled", self.job_id) + yield TriggerEvent( + { + "job_id": self.job_id, + "job_state": ClusterStatus.State.DELETING.name, # type: ignore[attr-defined] + } + ) + except Exception as e: + self.log.error("Failed to cancel the job: %s with error : %s", self.job_id, str(e)) + raise e + + class DataprocClusterTrigger(DataprocBaseTrigger): """ DataprocClusterTrigger run on the trigger worker to perform create Build operation. diff --git a/providers/google/tests/system/google/cloud/dataproc/example_dataproc_start_from_trigger.py b/providers/google/tests/system/google/cloud/dataproc/example_dataproc_start_from_trigger.py new file mode 100644 index 0000000000000..23f1d93cd9e58 --- /dev/null +++ b/providers/google/tests/system/google/cloud/dataproc/example_dataproc_start_from_trigger.py @@ -0,0 +1,132 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example Airflow DAG for DataprocSubmitJobOperator with start_from_trigger. 
+""" + +from __future__ import annotations + +import os +from datetime import datetime + +from google.api_core.retry import Retry + +from airflow.models.dag import DAG +from airflow.providers.google.cloud.operators.dataproc import ( + DataprocCreateClusterOperator, + DataprocDeleteClusterOperator, + DataprocSubmitJobOperator, +) + +try: + from airflow.sdk import TriggerRule +except ImportError: + # Compatibility for Airflow < 3.1 + from airflow.utils.trigger_rule import TriggerRule # type: ignore[no-redef,attr-defined] + +from system.google import DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID + +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default") +DAG_ID = "dataproc_start_from_trigger" +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") or DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID + +CLUSTER_NAME_BASE = f"cluster-{DAG_ID}".replace("_", "-") +CLUSTER_NAME_FULL = CLUSTER_NAME_BASE + f"-{ENV_ID}".replace("_", "-") +CLUSTER_NAME = CLUSTER_NAME_BASE if len(CLUSTER_NAME_FULL) >= 33 else CLUSTER_NAME_FULL + +REGION = "europe-west1" + +# Cluster definition +CLUSTER_CONFIG = { + "master_config": { + "num_instances": 1, + "machine_type_uri": "n1-standard-4", + "disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb": 32}, + }, + "worker_config": { + "num_instances": 2, + "machine_type_uri": "n1-standard-4", + "disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb": 32}, + }, +} + +# Jobs definitions +SPARK_JOB = { + "reference": {"project_id": PROJECT_ID}, + "placement": {"cluster_name": CLUSTER_NAME}, + "spark_job": { + "jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"], + "main_class": "org.apache.spark.examples.SparkPi", + }, +} + + +with DAG( + DAG_ID, + schedule="@once", + start_date=datetime(2021, 1, 1), + catchup=False, + tags=["example", "dataproc", "start_from_trigger"], +) as dag: + create_cluster = DataprocCreateClusterOperator( + task_id="create_cluster", + project_id=PROJECT_ID, + cluster_config=CLUSTER_CONFIG, + region=REGION, + cluster_name=CLUSTER_NAME, + retry=Retry(maximum=100.0, initial=10.0, multiplier=1.0), + num_retries_if_resource_is_not_ready=3, + ) + + spark_task = DataprocSubmitJobOperator( + task_id="spark_task", + job=SPARK_JOB, + region=REGION, + project_id=PROJECT_ID, + deferrable=True, + start_from_trigger=True, + ) + + delete_cluster = DataprocDeleteClusterOperator( + task_id="delete_cluster", + project_id=PROJECT_ID, + cluster_name=CLUSTER_NAME, + region=REGION, + trigger_rule=TriggerRule.ALL_DONE, + ) + + ( + # TEST SETUP + create_cluster + # TEST BODY + >> spark_task + # TEST TEARDOWN + >> delete_cluster + ) + + from tests_common.test_utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "teardown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + + +from tests_common.test_utils.system_tests import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: contributing-docs/testing/system_tests.rst) +test_run = get_test_run(dag) diff --git a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py index d284c8755488a..edb2bbe6c1c84 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py +++ b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py @@ -2248,6 +2248,61 @@ def test_missing_region_parameter(self): impersonation_chain=IMPERSONATION_CHAIN, ) + def test_start_from_trigger_default_false(self): + op 
diff --git a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
index d284c8755488a..edb2bbe6c1c84 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
@@ -2248,6 +2248,61 @@ def test_missing_region_parameter(self):
                 impersonation_chain=IMPERSONATION_CHAIN,
             )
 
+    def test_start_from_trigger_default_false(self):
+        op = DataprocSubmitJobOperator(
+            task_id=TASK_ID,
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            job={},
+            gcp_conn_id=GCP_CONN_ID,
+        )
+        assert op.start_from_trigger is False
+
+    def test_start_from_trigger_sets_start_trigger_args(self):
+        job = {"placement": {"cluster_name": "test-cluster"}}
+        op = DataprocSubmitJobOperator(
+            task_id=TASK_ID,
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            job=job,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            deferrable=True,
+            start_from_trigger=True,
+            polling_interval_seconds=15,
+            cancel_on_kill=False,
+            request_id=REQUEST_ID,
+        )
+        assert op.start_from_trigger is True
+        assert (
+            op.start_trigger_args.trigger_cls
+            == "airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobDirectTrigger"
+        )
+        assert op.start_trigger_args.trigger_kwargs == {
+            "job": job,
+            "project_id": GCP_PROJECT,
+            "region": GCP_REGION,
+            "gcp_conn_id": GCP_CONN_ID,
+            "impersonation_chain": IMPERSONATION_CHAIN,
+            "polling_interval_seconds": 15,
+            "cancel_on_kill": False,
+            "request_id": REQUEST_ID,
+        }
+        assert op.start_trigger_args.next_method == "execute_complete"
+
+    def test_start_from_trigger_without_deferrable_does_not_set_args(self):
+        op = DataprocSubmitJobOperator(
+            task_id=TASK_ID,
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            job={},
+            gcp_conn_id=GCP_CONN_ID,
+            deferrable=False,
+            start_from_trigger=True,
+        )
+        assert op.start_from_trigger is True
+        assert op.start_trigger_args.trigger_kwargs == {}
+
 
 @pytest.mark.db_test
 @pytest.mark.need_serialized_dag
diff --git a/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py b/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py
index aa66d2237ed4b..ed3e5c08f3610 100644
--- a/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py
+++ b/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py
@@ -31,6 +31,7 @@
     DataprocBatchTrigger,
     DataprocClusterTrigger,
     DataprocOperationTrigger,
+    DataprocSubmitJobDirectTrigger,
     DataprocSubmitTrigger,
 )
 from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
@@ -138,6 +139,26 @@ def submit_trigger():
     )
 
 
+TEST_JOB = {
+    "placement": {"cluster_name": "test-cluster"},
+    "pyspark_job": {"main_python_file_uri": "gs://test"},
+}
+TEST_REQUEST_ID = "test-request-id"
+
+
+@pytest.fixture
+def submit_job_direct_trigger():
+    return DataprocSubmitJobDirectTrigger(
+        job=TEST_JOB,
+        request_id=TEST_REQUEST_ID,
+        project_id=TEST_PROJECT_ID,
+        region=TEST_REGION,
+        gcp_conn_id=TEST_GCP_CONN_ID,
+        polling_interval_seconds=TEST_POLL_INTERVAL,
+        cancel_on_kill=True,
+    )
+
+
 @pytest.fixture
 def async_get_batch():
     def func(**kwargs):
@@ -661,3 +682,136 @@ async def test_submit_trigger_run_cancelled(
 
         # Clean up the generator
         await async_gen.aclose()
+
+
+class TestDataprocSubmitJobDirectTrigger:
+    def test_serialization(self, submit_job_direct_trigger):
+        classpath, kwargs = submit_job_direct_trigger.serialize()
+        assert classpath == "airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobDirectTrigger"
+        assert kwargs == {
+            "job": TEST_JOB,
+            "request_id": TEST_REQUEST_ID,
+            "project_id": TEST_PROJECT_ID,
+            "region": TEST_REGION,
+            "gcp_conn_id": TEST_GCP_CONN_ID,
+            "polling_interval_seconds": TEST_POLL_INTERVAL,
+            "cancel_on_kill": True,
+            "impersonation_chain": None,
+        }
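The async tests below stub the hook's coroutine methods with already-resolved `asyncio.Future` objects: awaiting a completed future returns its result immediately, so a plain `MagicMock` whose `return_value` is such a future behaves like an async method. The pattern in isolation (hypothetical `demo()` helper, no Dataproc involved):

```python
import asyncio
from unittest import mock


async def demo() -> None:
    hook = mock.MagicMock()

    # An already-resolved Future can be awaited and yields its result at once.
    submit_future: asyncio.Future = asyncio.Future()
    submit_future.set_result("job-123")
    hook.submit_job.return_value = submit_future

    # Production code that does `await hook.submit_job(...)` works unmodified.
    assert await hook.submit_job(project_id="p", region="r") == "job-123"


asyncio.run(demo())
```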
+
+    @pytest.mark.asyncio
+    @mock.patch(
+        "airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobDirectTrigger.get_async_hook"
+    )
+    async def test_run_submits_and_polls_success(self, mock_get_async_hook, submit_job_direct_trigger):
+        mock_hook = mock_get_async_hook.return_value
+
+        mock_submitted_job = mock.MagicMock()
+        mock_submitted_job.reference.job_id = TEST_JOB_ID
+        submit_future = asyncio.Future()
+        submit_future.set_result(mock_submitted_job)
+        mock_hook.submit_job.return_value = submit_future
+
+        mock_done_job = Job(status=JobStatus(state=JobStatus.State.DONE))
+        get_future = asyncio.Future()
+        get_future.set_result(mock_done_job)
+        mock_hook.get_job.return_value = get_future
+
+        async_gen = submit_job_direct_trigger.run()
+        event = await async_gen.asend(None)
+
+        expected_event = TriggerEvent(
+            {"job_id": TEST_JOB_ID, "job_state": JobStatus.State.DONE.name, "job": Job.to_dict(mock_done_job)}
+        )
+        assert event.payload == expected_event.payload
+
+        mock_hook.submit_job.assert_called_once_with(
+            project_id=TEST_PROJECT_ID,
+            region=TEST_REGION,
+            job=TEST_JOB,
+            request_id=TEST_REQUEST_ID,
+        )
+
+    @pytest.mark.asyncio
+    @mock.patch(
+        "airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobDirectTrigger.get_async_hook"
+    )
+    async def test_run_submits_and_polls_error(self, mock_get_async_hook, submit_job_direct_trigger):
+        mock_hook = mock_get_async_hook.return_value
+
+        mock_submitted_job = mock.MagicMock()
+        mock_submitted_job.reference.job_id = TEST_JOB_ID
+        submit_future = asyncio.Future()
+        submit_future.set_result(mock_submitted_job)
+        mock_hook.submit_job.return_value = submit_future
+
+        mock_error_job = Job(status=JobStatus(state=JobStatus.State.ERROR))
+        get_future = asyncio.Future()
+        get_future.set_result(mock_error_job)
+        mock_hook.get_job.return_value = get_future
+
+        async_gen = submit_job_direct_trigger.run()
+        event = await async_gen.asend(None)
+
+        expected_event = TriggerEvent(
+            {
+                "job_id": TEST_JOB_ID,
+                "job_state": JobStatus.State.ERROR.name,
+                "job": Job.to_dict(mock_error_job),
+            }
+        )
+        assert event.payload == expected_event.payload
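The parametrized test that follows exercises both branches of the cancellation guard in `run()`'s `CancelledError` handler. Distilled, the external job is cancelled only when all three conditions hold (a restatement of the trigger code above, not new logic):

```python
async def should_cancel(trigger) -> bool:
    """Mirrors the guard in DataprocSubmitJobDirectTrigger.run()'s CancelledError handler."""
    return (
        bool(trigger.job_id)                # the job was actually submitted
        and trigger.cancel_on_kill          # the operator opted into cleanup on kill
        and await trigger.safe_to_cancel()  # the task is not merely deferred
    )
```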
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize("is_safe_to_cancel", [True, False])
+    @mock.patch(
+        "airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobDirectTrigger.get_async_hook"
+    )
+    @mock.patch(
+        "airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobDirectTrigger.get_sync_hook"
+    )
+    @mock.patch(
+        "airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobDirectTrigger.safe_to_cancel"
+    )
+    async def test_run_cancelled_after_submit(
+        self,
+        mock_safe_to_cancel,
+        mock_get_sync_hook,
+        mock_get_async_hook,
+        submit_job_direct_trigger,
+        is_safe_to_cancel,
+    ):
+        mock_safe_to_cancel.return_value = is_safe_to_cancel
+        mock_hook = mock_get_async_hook.return_value
+
+        mock_submitted_job = mock.MagicMock()
+        mock_submitted_job.reference.job_id = TEST_JOB_ID
+        submit_future = asyncio.Future()
+        submit_future.set_result(mock_submitted_job)
+        mock_hook.submit_job.return_value = submit_future
+
+        mock_hook.get_job.side_effect = asyncio.CancelledError
+
+        mock_sync_hook = mock_get_sync_hook.return_value
+        mock_sync_hook.cancel_job = mock.MagicMock()
+
+        async_gen = submit_job_direct_trigger.run()
+
+        try:
+            await async_gen.asend(None)
+            await async_gen.asend(None)
+        except (asyncio.CancelledError, StopAsyncIteration):
+            pass
+        except Exception as e:
+            pytest.fail(f"Unexpected exception raised: {e}")
+
+        if submit_job_direct_trigger.cancel_on_kill and is_safe_to_cancel:
+            mock_sync_hook.cancel_job.assert_called_once_with(
+                job_id=TEST_JOB_ID,
+                project_id=submit_job_direct_trigger.project_id,
+                region=submit_job_direct_trigger.region,
+            )
+        else:
+            mock_sync_hook.cancel_job.assert_not_called()
+
+        await async_gen.aclose()