Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
)
from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
from airflow.triggers.base import StartTriggerArgs

if TYPE_CHECKING:
from google.api_core import operation
Expand Down Expand Up @@ -1880,6 +1881,8 @@ class DataprocSubmitJobOperator(GoogleCloudBaseOperator):
The value is considered only when running in deferrable mode. Must be greater than 0.
:param cancel_on_kill: Flag which indicates whether cancel the hook's job or not, when on_kill is called
:param wait_timeout: How many seconds wait for job to be ready. Used only if ``asynchronous`` is False
:param start_from_trigger: If True and deferrable is True, the operator will start directly
from the triggerer without occupying a worker slot.
"""

template_fields: Sequence[str] = (
Expand All @@ -1894,6 +1897,15 @@ class DataprocSubmitJobOperator(GoogleCloudBaseOperator):

operator_extra_links = (DataprocJobLink(),)

start_trigger_args = StartTriggerArgs(
trigger_cls="airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobDirectTrigger",
trigger_kwargs={},
next_method="execute_complete",
next_kwargs=None,
timeout=None,
)
start_from_trigger = False

def __init__(
self,
*,
Expand All @@ -1911,6 +1923,7 @@ def __init__(
polling_interval_seconds: int = 10,
cancel_on_kill: bool = True,
wait_timeout: int | None = None,
start_from_trigger: bool = False,
openlineage_inject_parent_job_info: bool = conf.getboolean(
"openlineage", "spark_inject_parent_job_info", fallback=False
),
Expand Down Expand Up @@ -1938,9 +1951,28 @@ def __init__(
self.hook: DataprocHook | None = None
self.job_id: str | None = None
self.wait_timeout = wait_timeout
self.start_from_trigger = start_from_trigger
self.openlineage_inject_parent_job_info = openlineage_inject_parent_job_info
self.openlineage_inject_transport_info = openlineage_inject_transport_info

if self.deferrable and self.start_from_trigger:
self.start_trigger_args = StartTriggerArgs(
trigger_cls="airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobDirectTrigger",
trigger_kwargs={
"job": self.job,
"project_id": self.project_id,
"region": self.region,
"gcp_conn_id": self.gcp_conn_id,
"impersonation_chain": self.impersonation_chain,
"polling_interval_seconds": self.polling_interval_seconds,
"cancel_on_kill": self.cancel_on_kill,
"request_id": self.request_id,
},
next_method="execute_complete",
next_kwargs=None,
timeout=None,
)

def execute(self, context: Context):
self.log.info("Submitting job")
self.hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,155 @@ async def run(self):
raise e


class DataprocSubmitJobDirectTrigger(DataprocBaseTrigger):
    """
    Trigger that submits a Dataproc job and polls for its completion.

    Used for direct-to-triggerer functionality where job submission and polling
    are handled entirely by the triggerer without requiring a worker.

    :param job: The job resource dict to submit.
    :param project_id: Google Cloud Project where the job is running.
    :param region: The Cloud Dataproc region in which to handle the request.
    :param gcp_conn_id: The connection ID used to connect to Google Cloud Platform.
    :param impersonation_chain: Optional service account to impersonate using short-term credentials.
    :param polling_interval_seconds: Polling period in seconds to check for the status.
    :param cancel_on_kill: Flag indicating whether to cancel the job when on_kill is called.
    :param request_id: Optional unique id used to identify the request.
    """

    def __init__(
        self,
        job: dict,
        request_id: str | None = None,
        **kwargs,
    ):
        self.job = job
        self.request_id = request_id
        # Set after successful submission; needed for both polling and cancellation.
        self.job_id: str | None = None
        super().__init__(**kwargs)

    def serialize(self):
        """Serialize the trigger init arguments so it can be rehydrated on a triggerer."""
        return (
            "airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobDirectTrigger",
            {
                "job": self.job,
                "request_id": self.request_id,
                "project_id": self.project_id,
                "region": self.region,
                "gcp_conn_id": self.gcp_conn_id,
                "impersonation_chain": self.impersonation_chain,
                "polling_interval_seconds": self.polling_interval_seconds,
                "cancel_on_kill": self.cancel_on_kill,
            },
        )

    if not AIRFLOW_V_3_0_PLUS:

        @provide_session
        def get_task_instance(self, session: Session) -> TaskInstance:
            """
            Get the task instance for the current task.

            :param session: Sqlalchemy session
            :raises RuntimeError: if no matching TaskInstance row exists.
            """
            task_instance = session.scalar(
                select(TaskInstance).where(
                    TaskInstance.dag_id == self.task_instance.dag_id,
                    TaskInstance.task_id == self.task_instance.task_id,
                    TaskInstance.run_id == self.task_instance.run_id,
                    TaskInstance.map_index == self.task_instance.map_index,
                )
            )
            if task_instance is None:
                # RuntimeError does not %-format its arguments like a logger call
                # would, so build the message explicitly.
                raise RuntimeError(
                    f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
                    f"task_id: {self.task_instance.task_id}, "
                    f"run_id: {self.task_instance.run_id} and "
                    f"map_index: {self.task_instance.map_index} is not found"
                )
            return task_instance

    async def get_task_state(self):
        """Fetch the current state of this task instance via the task SDK (Airflow 3+)."""
        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance

        task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)(
            dag_id=self.task_instance.dag_id,
            task_ids=[self.task_instance.task_id],
            run_ids=[self.task_instance.run_id],
            map_index=self.task_instance.map_index,
        )
        try:
            task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id]
        except (KeyError, TypeError) as err:
            # Missing run_id/task_id key (or a malformed response) means the TI
            # could not be resolved; surface that as an explicit error.
            raise RuntimeError(
                f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
                f"task_id: {self.task_instance.task_id}, "
                f"run_id: {self.task_instance.run_id} and "
                f"map_index: {self.task_instance.map_index} is not found"
            ) from err
        return task_state

    async def safe_to_cancel(self) -> bool:
        """
        Whether it is safe to cancel the external job which is being executed by this trigger.

        This is to avoid the case that `asyncio.CancelledError` is called because the trigger itself is stopped.
        Because in those cases, we should NOT cancel the external job.
        """
        if AIRFLOW_V_3_0_PLUS:
            task_state = await self.get_task_state()
        else:
            task_instance = self.get_task_instance()  # type: ignore[call-arg]
            task_state = task_instance.state
        # If the TI is still DEFERRED, the trigger is merely being relocated/stopped,
        # so the external job must be left running.
        return task_state != TaskInstanceState.DEFERRED

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """Submit the job, poll until it reaches a terminal state, and emit one event."""
        try:
            hook = self.get_async_hook()
            self.log.info("Submitting Dataproc job.")
            job_object = await hook.submit_job(
                project_id=self.project_id,
                region=self.region,
                job=self.job,
                request_id=self.request_id,
            )
            self.job_id = job_object.reference.job_id
            self.log.info("Dataproc job %s submitted successfully.", self.job_id)

            while True:
                job = await hook.get_job(project_id=self.project_id, region=self.region, job_id=self.job_id)
                state = job.status.state
                self.log.info("Dataproc job: %s is in state: %s", self.job_id, state)
                if state in (JobStatus.State.DONE, JobStatus.State.CANCELLED, JobStatus.State.ERROR):
                    break
                await asyncio.sleep(self.polling_interval_seconds)

            yield TriggerEvent(
                {"job_id": self.job_id, "job_state": JobStatus.State(state).name, "job": Job.to_dict(job)}
            )
        except asyncio.CancelledError:
            self.log.info("Task got cancelled.")
            try:
                if self.job_id and self.cancel_on_kill and await self.safe_to_cancel():
                    self.log.info("Cancelling the job: %s", self.job_id)
                    # Cancellation uses the sync hook: the async loop is already
                    # being torn down when CancelledError fires.
                    self.get_sync_hook().cancel_job(
                        job_id=self.job_id, project_id=self.project_id, region=self.region
                    )
                    self.log.info("Job: %s is cancelled", self.job_id)
                    yield TriggerEvent(
                        {
                            "job_id": self.job_id,
                            "job_state": ClusterStatus.State.DELETING.name,  # type: ignore[attr-defined]
                        }
                    )
            except Exception:
                self.log.exception("Failed to cancel the job: %s", self.job_id)
                # Bare raise preserves the original traceback.
                raise


class DataprocClusterTrigger(DataprocBaseTrigger):
"""
DataprocClusterTrigger run on the trigger worker to perform create Build operation.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Example Airflow DAG for DataprocSubmitJobOperator with start_from_trigger.
"""

from __future__ import annotations

import os
from datetime import datetime

from google.api_core.retry import Retry

from airflow.models.dag import DAG
from airflow.providers.google.cloud.operators.dataproc import (
DataprocCreateClusterOperator,
DataprocDeleteClusterOperator,
DataprocSubmitJobOperator,
)

try:
from airflow.sdk import TriggerRule
except ImportError:
# Compatibility for Airflow < 3.1
from airflow.utils.trigger_rule import TriggerRule # type: ignore[no-redef,attr-defined]

from system.google import DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID

# Environment-driven identifiers for the system test run.
ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
DAG_ID = "dataproc_start_from_trigger"
PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") or DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID

CLUSTER_NAME_BASE = f"cluster-{DAG_ID}".replace("_", "-")
CLUSTER_NAME_FULL = CLUSTER_NAME_BASE + f"-{ENV_ID}".replace("_", "-")
# Cluster names are length-limited; drop the env suffix when the full name is too long.
CLUSTER_NAME = CLUSTER_NAME_FULL if len(CLUSTER_NAME_FULL) < 33 else CLUSTER_NAME_BASE

REGION = "europe-west1"

# Cluster definition: one master and two workers, each with a small standard disk.
_NODE_DISK_CONFIG = {"boot_disk_type": "pd-standard", "boot_disk_size_gb": 32}
CLUSTER_CONFIG = {
    "master_config": {
        "num_instances": 1,
        "machine_type_uri": "n1-standard-4",
        "disk_config": dict(_NODE_DISK_CONFIG),
    },
    "worker_config": {
        "num_instances": 2,
        "machine_type_uri": "n1-standard-4",
        "disk_config": dict(_NODE_DISK_CONFIG),
    },
}

# Job definition: the stock SparkPi example shipped with the Dataproc image.
SPARK_JOB = {
    "reference": {"project_id": PROJECT_ID},
    "placement": {"cluster_name": CLUSTER_NAME},
    "spark_job": {
        "jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"],
        "main_class": "org.apache.spark.examples.SparkPi",
    },
}


with DAG(
    DAG_ID,
    schedule="@once",
    start_date=datetime(2021, 1, 1),
    catchup=False,
    tags=["example", "dataproc", "start_from_trigger"],
) as dag:
    # Provision the Dataproc cluster the Spark job will run on.
    create_cluster = DataprocCreateClusterOperator(
        task_id="create_cluster",
        project_id=PROJECT_ID,
        cluster_config=CLUSTER_CONFIG,
        region=REGION,
        cluster_name=CLUSTER_NAME,
        retry=Retry(maximum=100.0, initial=10.0, multiplier=1.0),
        num_retries_if_resource_is_not_ready=3,
    )

    # The behavior under test: with deferrable=True and start_from_trigger=True
    # the operator should start directly on the triggerer, without first
    # occupying a worker slot.
    spark_task = DataprocSubmitJobOperator(
        task_id="spark_task",
        job=SPARK_JOB,
        region=REGION,
        project_id=PROJECT_ID,
        deferrable=True,
        start_from_trigger=True,
    )

    # Teardown: ALL_DONE ensures the cluster is deleted even if upstream failed.
    delete_cluster = DataprocDeleteClusterOperator(
        task_id="delete_cluster",
        project_id=PROJECT_ID,
        cluster_name=CLUSTER_NAME,
        region=REGION,
        trigger_rule=TriggerRule.ALL_DONE,
    )

    (
        # TEST SETUP
        create_cluster
        # TEST BODY
        >> spark_task
        # TEST TEARDOWN
        >> delete_cluster
    )

    from tests_common.test_utils.watcher import watcher

    # This test needs watcher in order to properly mark success/failure
    # when "teardown" task with trigger rule is part of the DAG
    list(dag.tasks) >> watcher()


from tests_common.test_utils.system_tests import get_test_run  # noqa: E402

# Needed to run the example DAG with pytest (see: contributing-docs/testing/system_tests.rst)
test_run = get_test_run(dag)
Loading
Loading