diff --git a/airflow-core/src/airflow/config_templates/config.yml b/airflow-core/src/airflow/config_templates/config.yml index 8558a71236eb7..2d70e625c1544 100644 --- a/airflow-core/src/airflow/config_templates/config.yml +++ b/airflow-core/src/airflow/config_templates/config.yml @@ -1414,9 +1414,21 @@ traces: description: | If True, then traces from Airflow internal methods are exported. Defaults to False. version_added: 3.1.0 + version_deprecated: 3.2.0 + deprecation_reason: | + This parameter is no longer used. type: string example: ~ default: "False" + task_runner_flush_timeout_milliseconds: + description: | + Timeout in milliseconds to wait for the OpenTelemetry span exporter to flush pending spans + when a task runner process exits. If the exporter does not finish within this time, any + buffered spans may be dropped. + version_added: 3.2.0 + type: integer + example: ~ + default: "30000" secrets: description: ~ options: diff --git a/airflow-core/src/airflow/executors/base_executor.py b/airflow-core/src/airflow/executors/base_executor.py index 2997d55d8bb3b..d67c25c7bafaa 100644 --- a/airflow-core/src/airflow/executors/base_executor.py +++ b/airflow-core/src/airflow/executors/base_executor.py @@ -32,14 +32,11 @@ from airflow.configuration import conf from airflow.executors import workloads from airflow.executors.executor_loader import ExecutorLoader -from airflow.executors.workloads.task import TaskInstanceDTO from airflow.models import Log from airflow.models.callback import CallbackKey from airflow.observability.metrics import stats_utils -from airflow.observability.trace import Trace from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.state import TaskInstanceState -from airflow.utils.thread_safe_dict import ThreadSafeDict PARALLELISM: int = conf.getint("core", "PARALLELISM") @@ -143,8 +140,6 @@ class BaseExecutor(LoggingMixin): :param parallelism: how many jobs should run at one time. """ - active_spans = ThreadSafeDict() - supports_ad_hoc_ti_run: bool = False supports_callbacks: bool = False supports_multi_team: bool = False @@ -217,10 +212,6 @@ def __repr__(self): _repr += ")" return _repr - @classmethod - def set_active_spans(cls, active_spans: ThreadSafeDict): - cls.active_spans = active_spans - def start(self): # pragma: no cover """Executors may need to get things started.""" @@ -340,17 +331,6 @@ def _emit_metrics(self, open_slots, num_running_tasks, num_queued_tasks): queued_tasks_metric_name = self._get_metric_name("executor.queued_tasks") running_tasks_metric_name = self._get_metric_name("executor.running_tasks") - span = Trace.get_current_span() - if span.is_recording(): - span.add_event( - name="executor", - attributes={ - open_slots_metric_name: open_slots, - queued_tasks_metric_name: num_queued_tasks, - running_tasks_metric_name: num_running_tasks, - }, - ) - self.log.debug("%s running task instances for executor %s", num_running_tasks, name) self.log.debug("%s in queue for executor %s", num_queued_tasks, name) if open_slots == 0: @@ -415,30 +395,6 @@ def trigger_tasks(self, open_slots: int) -> None: if key in self.attempts: del self.attempts[key] - if isinstance(workload, workloads.ExecuteTask) and hasattr(workload, "ti"): - ti = workload.ti - - # If it's None, then the span for the current id hasn't been started. - if self.active_spans is not None and self.active_spans.get("ti:" + str(ti.id)) is None: - if isinstance(ti, TaskInstanceDTO): - parent_context = Trace.extract(ti.parent_context_carrier) - else: - parent_context = Trace.extract(ti.dag_run.context_carrier) - # Start a new span using the context from the parent. - # Attributes will be set once the task has finished so that all - # values will be available (end_time, duration, etc.). - - span = Trace.start_child_span( - span_name=f"{ti.task_id}", - parent_context=parent_context, - component="task", - start_as_current=False, - ) - self.active_spans.set("ti:" + str(ti.id), span) - # Inject the current context into the carrier. - carrier = Trace.inject() - ti.context_carrier = carrier - workload_list.append(workload) if workload_list: diff --git a/airflow-core/src/airflow/executors/workloads/task.py b/airflow-core/src/airflow/executors/workloads/task.py index d691dcb6f0968..a5939cf424412 100644 --- a/airflow-core/src/airflow/executors/workloads/task.py +++ b/airflow-core/src/airflow/executors/workloads/task.py @@ -86,7 +86,7 @@ def make( from airflow.utils.helpers import log_filename_template_renderer ser_ti = TaskInstanceDTO.model_validate(ti, from_attributes=True) - ser_ti.parent_context_carrier = ti.dag_run.context_carrier + ser_ti.context_carrier = ti.dag_run.context_carrier if not bundle_info: bundle_info = BundleInfo( name=ti.dag_model.bundle_name, diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index 2d58f295c6ecc..c667631756edc 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -32,7 +32,6 @@ from functools import lru_cache, partial from itertools import groupby from typing import TYPE_CHECKING, Any -from uuid import UUID from sqlalchemy import ( and_, @@ -98,17 +97,14 @@ from airflow.models.team import Team from airflow.models.trigger import TRIGGER_FAIL_REPR, Trigger, TriggerFailureReason from airflow.observability.metrics import stats_utils -from airflow.observability.trace import Trace from airflow.serialization.definitions.assets import SerializedAssetUniqueKey from airflow.serialization.definitions.notset import NOTSET from airflow.ti_deps.dependencies_states import EXECUTION_STATES from airflow.timetables.simple import AssetTriggeredTimetable -from airflow.utils.dates import datetime_to_nano from airflow.utils.event_scheduler import EventScheduler from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.retries import MAX_DB_RETRIES, retry_db_transaction, run_with_db_retries from airflow.utils.session import NEW_SESSION, create_session, provide_session -from airflow.utils.span_status import SpanStatus from airflow.utils.sqlalchemy import ( get_dialect_name, is_lock_not_available_error, @@ -116,7 +112,6 @@ with_row_locks, ) from airflow.utils.state import CallbackState, DagRunState, State, TaskInstanceState -from airflow.utils.thread_safe_dict import ThreadSafeDict from airflow.utils.types import DagRunTriggeredByType, DagRunType if TYPE_CHECKING: @@ -273,14 +268,6 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin): job_type = "SchedulerJob" - # For a dagrun span - # - key: dag_run.run_id | value: span - # - dagrun keys will be prefixed with 'dr:'. - # For a ti span - # - key: ti.id | value: span - # - taskinstance keys will be prefixed with 'ti:'. - active_spans = ThreadSafeDict() - def __init__( self, job: Job, @@ -434,9 +421,6 @@ def _get_workload_team_name(self, workload: SchedulerWorkload, session: Session) def _exit_gracefully(self, signum: int, frame: FrameType | None) -> None: """Clean up processor_agent to avoid leaving orphan processes.""" - if self._is_tracing_enabled(): - self._end_active_spans() - if not _is_parent_process(): # Only the parent process should perform the cleanup. return @@ -1311,18 +1295,6 @@ def process_executor_events( ti.pid, ) - if (active_ti_span := cls.active_spans.get("ti:" + str(ti.id))) is not None: - cls.set_ti_span_attrs(span=active_ti_span, state=state, ti=ti) - # End the span and remove it from the active_spans dict. - active_ti_span.end(end_time=datetime_to_nano(ti.end_date)) - cls.active_spans.delete("ti:" + str(ti.id)) - ti.span_status = SpanStatus.ENDED - else: - if ti.span_status == SpanStatus.ACTIVE: - # Another scheduler has started the span. - # Update the SpanStatus to let the process know that it must end it. - ti.span_status = SpanStatus.SHOULD_END - # There are two scenarios why the same TI with the same try_number is queued # after executor is finished with it: # 1) the TI was killed externally and it had no time to mark itself failed @@ -1459,39 +1431,6 @@ def process_executor_events( return len(event_buffer) - @classmethod - def set_ti_span_attrs(cls, span, state, ti): - span.set_attributes( - { - "airflow.category": "scheduler", - "airflow.task.id": ti.id, - "airflow.task.task_id": ti.task_id, - "airflow.task.dag_id": ti.dag_id, - "airflow.task.state": ti.state, - "airflow.task.error": state == TaskInstanceState.FAILED, - "airflow.task.start_date": str(ti.start_date), - "airflow.task.end_date": str(ti.end_date), - "airflow.task.duration": ti.duration, - "airflow.task.executor_config": str(ti.executor_config), - "airflow.task.logical_date": str(ti.logical_date), - "airflow.task.hostname": ti.hostname, - "airflow.task.log_url": ti.log_url, - "airflow.task.operator": str(ti.operator), - "airflow.task.try_number": ti.try_number, - "airflow.task.executor_state": state, - "airflow.task.pool": ti.pool, - "airflow.task.queue": ti.queue, - "airflow.task.priority_weight": ti.priority_weight, - "airflow.task.queued_dttm": str(ti.queued_dttm), - "airflow.task.queued_by_job_id": ti.queued_by_job_id, - "airflow.task.pid": ti.pid, - } - ) - if span.is_recording(): - span.add_event(name="airflow.task.queued", timestamp=datetime_to_nano(ti.queued_dttm)) - span.add_event(name="airflow.task.started", timestamp=datetime_to_nano(ti.start_date)) - span.add_event(name="airflow.task.ended", timestamp=datetime_to_nano(ti.end_date)) - def _execute(self) -> int | None: import os @@ -1515,12 +1454,6 @@ def _execute(self) -> int | None: executor.start() # local import due to type_checking. - from airflow.executors.base_executor import BaseExecutor - - # Pass a reference to the dictionary. - # Any changes made by a dag_run instance, will be reflected to the dictionary of this class. - DagRun.set_active_spans(active_spans=self.active_spans) - BaseExecutor.set_active_spans(active_spans=self.active_spans) stats_factory = stats_utils.get_stats_factory(Stats) Stats.initialize(factory=stats_factory) @@ -1571,162 +1504,6 @@ def _update_dag_run_state_for_paused_dags(self, session: Session = NEW_SESSION) except Exception as e: # should not fail the scheduler self.log.exception("Failed to update dag run state for paused dags due to %s", e) - @provide_session - def _end_active_spans(self, session: Session = NEW_SESSION): - # No need to do a commit for every update. The annotation will commit all of them once at the end. - for prefixed_key, span in self.active_spans.get_all().items(): - # Use partition to split on the first occurrence of ':'. - prefix, sep, key = prefixed_key.partition(":") - - if prefix == "ti": - ti_result = session.get(TaskInstance, UUID(key)) - if ti_result is None: - continue - ti: TaskInstance = ti_result - - if ti.state in State.finished: - self.set_ti_span_attrs(span=span, state=ti.state, ti=ti) - span.end(end_time=datetime_to_nano(ti.end_date)) - ti.span_status = SpanStatus.ENDED - else: - span.end() - ti.span_status = SpanStatus.NEEDS_CONTINUANCE - elif prefix == "dr": - dag_run: DagRun | None = session.scalars( - select(DagRun).where(DagRun.id == int(key)) - ).one_or_none() - if dag_run is None: - continue - if dag_run.state in State.finished_dr_states: - dag_run.set_dagrun_span_attrs(span=span) - - span.end(end_time=datetime_to_nano(dag_run.end_date)) - dag_run.span_status = SpanStatus.ENDED - else: - span.end() - dag_run.span_status = SpanStatus.NEEDS_CONTINUANCE - initial_dag_run_context = Trace.extract(dag_run.context_carrier) - with Trace.start_child_span( - span_name="current_scheduler_exited", parent_context=initial_dag_run_context - ) as s: - s.set_attribute("trace_status", "needs continuance") - else: - self.log.error("Found key with unknown prefix: '%s'", prefixed_key) - - # Even if there is a key with an unknown prefix, clear the dict. - # If this method has been called, the scheduler is exiting. - self.active_spans.clear() - - def _end_spans_of_externally_ended_ops(self, session: Session): - # The scheduler that starts a dag_run or a task is also the one that starts the spans. - # Each scheduler should end the spans that it has started. - # - # Otel spans are implemented in a certain way so that the objects - # can't be shared between processes or get recreated. - # It is done so that the process that starts a span, is also the one that ends it. - # - # If another scheduler has finished processing a dag_run or a task and there is a reference - # on the active_spans dictionary, then the current scheduler started the span, - # and therefore must end it. - dag_runs_should_end: list[DagRun] = list( - session.scalars(select(DagRun).where(DagRun.span_status == SpanStatus.SHOULD_END)) - ) - tis_should_end: list[TaskInstance] = list( - session.scalars(select(TaskInstance).where(TaskInstance.span_status == SpanStatus.SHOULD_END)) - ) - - for dag_run in dag_runs_should_end: - active_dagrun_span = self.active_spans.get("dr:" + str(dag_run.id)) - if active_dagrun_span is not None: - if dag_run.state in State.finished_dr_states: - dag_run.set_dagrun_span_attrs(span=active_dagrun_span) - - active_dagrun_span.end(end_time=datetime_to_nano(dag_run.end_date)) - else: - active_dagrun_span.end() - self.active_spans.delete("dr:" + str(dag_run.id)) - dag_run.span_status = SpanStatus.ENDED - - for ti in tis_should_end: - active_ti_span = self.active_spans.get(f"ti:{ti.id}") - if active_ti_span is not None: - if ti.state in State.finished: - self.set_ti_span_attrs(span=active_ti_span, state=ti.state, ti=ti) - active_ti_span.end(end_time=datetime_to_nano(ti.end_date)) - else: - active_ti_span.end() - self.active_spans.delete(f"ti:{ti.id}") - ti.span_status = SpanStatus.ENDED - - def _recreate_unhealthy_scheduler_spans_if_needed(self, dag_run: DagRun, session: Session): - # There are two scenarios: - # 1. scheduler is unhealthy but managed to update span_status - # 2. scheduler is unhealthy and didn't manage to make any updates - # Check the span_status first, in case the 2nd db query can be avoided (scenario 1). - - # If the dag_run is scheduled by a different scheduler, and it's still running and the span is active, - # then check the Job table to determine if the initial scheduler is still healthy. - if ( - dag_run.scheduled_by_job_id != self.job.id - and dag_run.state in State.unfinished_dr_states - and dag_run.span_status == SpanStatus.ACTIVE - ): - initial_scheduler_id = dag_run.scheduled_by_job_id - job: Job | None = session.scalars( - select(Job).where( - Job.id == initial_scheduler_id, - Job.job_type == "SchedulerJob", - ) - ).one_or_none() - if job is None: - return - - if not job.is_alive(): - # Start a new span for the dag_run. - dr_span = Trace.start_root_span( - span_name=f"{dag_run.dag_id}_recreated", - component="dag", - start_time=dag_run.queued_at, - start_as_current=False, - ) - carrier = Trace.inject() - # Update the context_carrier and leave the SpanStatus as ACTIVE. - dag_run.context_carrier = carrier - self.active_spans.set("dr:" + str(dag_run.id), dr_span) - - tis = dag_run.get_task_instances(session=session) - - # At this point, any tis will have been adopted by the current scheduler, - # and ti.queued_by_job_id will point to the current id. - # Any tis that have been executed by the unhealthy scheduler, will need a new span - # so that it can be associated with the new dag_run span. - tis_needing_spans = [ - ti - for ti in tis - # If it has started and there is a reference on the active_spans dict, - # then it was started by the current scheduler. - if ti.start_date is not None and self.active_spans.get(f"ti:{ti.id}") is None - ] - - dr_context = Trace.extract(dag_run.context_carrier) - for ti in tis_needing_spans: - ti_span = Trace.start_child_span( - span_name=f"{ti.task_id}_recreated", - parent_context=dr_context, - start_time=ti.queued_dttm, - start_as_current=False, - ) - ti_carrier = Trace.inject() - ti.context_carrier = ti_carrier - - if ti.state in State.finished: - self.set_ti_span_attrs(span=ti_span, state=ti.state, ti=ti) - ti_span.end(end_time=datetime_to_nano(ti.end_date)) - ti.span_status = SpanStatus.ENDED - else: - ti.span_status = SpanStatus.ACTIVE - self.active_spans.set(f"ti:{ti.id}", ti_span) - def _run_scheduler_loop(self) -> None: """ Harvest DAG parsing results, queue tasks, and perform executor heartbeat; the actual scheduler loop. @@ -1819,9 +1596,6 @@ def _run_scheduler_loop(self) -> None: for loop_count in itertools.count(start=1): with Stats.timer("scheduler.scheduler_loop_duration") as timer: with create_session() as session: - if self._is_tracing_enabled(): - self._end_spans_of_externally_ended_ops(session) - # This will schedule for as many executors as possible. num_queued_tis = self._do_scheduling(session) # Don't keep any objects alive -- we've possibly just looked at 500+ ORM objects! @@ -2357,16 +2131,6 @@ def _start_queued_dagruns(self, session: Session) -> None: active_runs_of_dags = Counter({(dag_id, br_id): num for dag_id, br_id, num in session.execute(query)}) def _update_state(dag: SerializedDAG, dag_run: DagRun): - span = Trace.get_current_span() - span.set_attributes( - { - "state": str(DagRunState.RUNNING), - "run_id": dag_run.run_id, - "type": dag_run.run_type, - "dag_id": dag_run.dag_id, - } - ) - dag_run.state = DagRunState.RUNNING dag_run.start_date = timezone.utcnow() if ( @@ -2383,18 +2147,12 @@ def _update_state(dag: SerializedDAG, dag_run: DagRun): tags={}, extra_tags={"dag_id": dag.dag_id}, ) - if span.is_recording(): - span.add_event( - name="schedule_delay", - attributes={"dag_id": dag.dag_id, "schedule_delay": str(schedule_delay)}, - ) # cache saves time during scheduling of many dag_runs for same dag cached_get_dag: Callable[[DagRun], SerializedDAG | None] = lru_cache()( partial(self.scheduler_dag_bag.get_dag_for_run, session=session) ) - span = Trace.get_current_span() for dag_run in dag_runs: dag_id = dag_run.dag_id run_id = dag_run.run_id @@ -2434,15 +2192,6 @@ def _update_state(dag: SerializedDAG, dag_run: DagRun): dag_run.run_id, ) continue - if span.is_recording(): - span.add_event( - name="dag_run", - attributes={ - "run_id": dag_run.run_id, - "dag_id": dag_run.dag_id, - "conf": str(dag_run.conf), - }, - ) active_runs_of_dags[(dag_run.dag_id, backfill_id)] += 1 _update_state(dag, dag_run) dag_run.notify_dagrun_state_changed(msg="started") @@ -2554,17 +2303,6 @@ def _schedule_dag_run( self.log.warning("The DAG disappeared before verifying integrity: %s. Skipping.", dag_run.dag_id) return callback - if ( - self._is_tracing_enabled() - and dag_run.scheduled_by_job_id is not None - and dag_run.scheduled_by_job_id != self.job.id - and self.active_spans.get("dr:" + str(dag_run.id)) is None - ): - # If the dag_run has been previously scheduled by another job and there is no active span, - # then check if the job is still healthy. - # If it's not healthy, then recreate the spans. - self._recreate_unhealthy_scheduler_spans_if_needed(dag_run, session) - dag_run.scheduled_by_job_id = self.job.id # TODO[HA]: Rename update_state -> schedule_dag_run, ?? something else? diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index ca11647811099..1406283c05cb3 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -50,7 +50,6 @@ from airflow.jobs.job import perform_heartbeat from airflow.models.trigger import Trigger from airflow.observability.metrics import stats_utils -from airflow.observability.trace import Trace from airflow.sdk.api.datamodels._generated import HITLDetailResponse from airflow.sdk.execution_time.comms import ( CommsDecoder, @@ -627,15 +626,6 @@ def emit_metrics(self): extra_tags={"hostname": self.job.hostname}, ) - span = Trace.get_current_span() - span.set_attributes( - { - "trigger host": self.job.hostname, - "triggers running": len(self.running_triggers), - "capacity left": capacity_left, - } - ) - def update_triggers(self, requested_trigger_ids: set[int]): """ Request that we update what triggers we're running. diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py index 4bc47dddea7ac..61242e45390d6 100644 --- a/airflow-core/src/airflow/models/dagrun.py +++ b/airflow-core/src/airflow/models/dagrun.py @@ -28,6 +28,9 @@ from uuid import UUID import structlog +from opentelemetry import context, trace +from opentelemetry.trace import StatusCode +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from sqlalchemy import ( JSON, Enum, @@ -72,12 +75,11 @@ from airflow.models.taskinstancehistory import TaskInstanceHistory as TIH from airflow.models.tasklog import LogTemplate from airflow.models.taskmap import TaskMap -from airflow.observability.trace import Trace +from airflow.observability.traces import new_dagrun_trace_carrier, override_ids from airflow.serialization.definitions.deadline import SerializedReferenceModels from airflow.serialization.definitions.notset import NOTSET, ArgNotSet, is_arg_set from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.dependencies_states import SCHEDULEABLE_STATES -from airflow.utils.dates import datetime_to_nano from airflow.utils.helpers import chunks, is_container, prune_dict from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.retries import retry_db_transaction @@ -92,19 +94,16 @@ ) from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.strings import get_random_string -from airflow.utils.thread_safe_dict import ThreadSafeDict from airflow.utils.types import DagRunTriggeredByType, DagRunType if TYPE_CHECKING: from typing import Literal, TypeAlias - from opentelemetry.sdk.trace import Span from pydantic import NonNegativeInt from sqlalchemy.engine import ScalarResult from sqlalchemy.orm import Session from sqlalchemy.sql.elements import Case, ColumnElement - from airflow._shared.observability.traces.base_tracer import EmptySpan from airflow.models.dag_version import DagVersion from airflow.models.taskinstancekey import TaskInstanceKey from airflow.sdk import DAG as SDKDAG @@ -120,6 +119,8 @@ log = structlog.get_logger(__name__) +tracer = trace.get_tracer(__name__) + class TISchedulingDecision(NamedTuple): """Type of return for DagRun.task_instance_scheduling_decisions.""" @@ -153,8 +154,6 @@ class DagRun(Base, LoggingMixin): external trigger (i.e. manual runs). """ - active_spans = ThreadSafeDict() - __tablename__ = "dag_run" id: Mapped[int] = mapped_column(Integer, primary_key=True) @@ -368,7 +367,8 @@ def __init__( self.triggered_by = triggered_by self.triggering_user_name = triggering_user_name self.scheduled_by_job_id = None - self.context_carrier = {} + self.context_carrier: dict[str, str] = new_dagrun_trace_carrier() + if not isinstance(partition_key, str | None): raise ValueError( f"Expected partition_key to be a `str` or `None` but got `{partition_key.__class__.__name__}`" @@ -461,10 +461,6 @@ def check_version_id_exists_in_dr(self, dag_version_id: UUID, session: Session = def stats_tags(self) -> dict[str, str]: return prune_dict({"dag_id": self.dag_id, "run_type": self.run_type}) - @classmethod - def set_active_spans(cls, active_spans: ThreadSafeDict): - cls.active_spans = active_spans - def get_state(self): return self._state @@ -1019,131 +1015,28 @@ def is_effective_leaf(task): leaf_tis = {ti for ti in tis if ti.task_id in leaf_task_ids if ti.state != TaskInstanceState.REMOVED} return leaf_tis - def set_dagrun_span_attrs(self, span: Span | EmptySpan): - if self._state == DagRunState.FAILED: - span.set_attribute("airflow.dag_run.error", True) - - # Explicitly set the value type to Union[...] to avoid a mypy error. - attributes: dict[str, AttributeValueType] = { - "airflow.category": "DAG runs", - "airflow.dag_run.dag_id": str(self.dag_id), - "airflow.dag_run.logical_date": str(self.logical_date), - "airflow.dag_run.run_id": str(self.run_id), - "airflow.dag_run.queued_at": str(self.queued_at), - "airflow.dag_run.run_start_date": str(self.start_date), - "airflow.dag_run.run_end_date": str(self.end_date), - "airflow.dag_run.run_duration": str( - (self.end_date - self.start_date).total_seconds() if self.start_date and self.end_date else 0 - ), - "airflow.dag_run.state": str(self._state), - "airflow.dag_run.run_type": str(self.run_type), - "airflow.dag_run.data_interval_start": str(self.data_interval_start), - "airflow.dag_run.data_interval_end": str(self.data_interval_end), - "airflow.dag_run.conf": str(self.conf), - } - if span.is_recording(): - span.add_event(name="airflow.dag_run.queued", timestamp=datetime_to_nano(self.queued_at)) - span.add_event(name="airflow.dag_run.started", timestamp=datetime_to_nano(self.start_date)) - span.add_event(name="airflow.dag_run.ended", timestamp=datetime_to_nano(self.end_date)) - span.set_attributes(attributes) - - def start_dr_spans_if_needed(self, tis: list[TI]): - # If there is no value in active_spans, then the span hasn't already been started. - if self.active_spans is not None and self.active_spans.get("dr:" + str(self.id)) is None: - if self.span_status == SpanStatus.NOT_STARTED or self.span_status == SpanStatus.NEEDS_CONTINUANCE: - dr_span = None - continue_ti_spans = False - if self.span_status == SpanStatus.NOT_STARTED: - dr_span = Trace.start_root_span( - span_name=f"{self.dag_id}", - component="dag", - start_time=self.queued_at, # This is later converted to nano. - start_as_current=False, - ) - elif self.span_status == SpanStatus.NEEDS_CONTINUANCE: - # Use the existing context_carrier to set the initial dag_run span as the parent. - parent_context = Trace.extract(self.context_carrier) - with Trace.start_child_span( - span_name="new_scheduler", parent_context=parent_context - ) as s: - s.set_attribute("trace_status", "continued") - - dr_span = Trace.start_child_span( - span_name=f"{self.dag_id}_continued", - parent_context=parent_context, - component="dag", - # No start time - start_as_current=False, - ) - # After this span is started, the context_carrier will be replaced by the new one. - # New task span will use this span as the parent. - continue_ti_spans = True - carrier = Trace.inject() - self.context_carrier = carrier - self.span_status = SpanStatus.ACTIVE - # Set the span in a synchronized dictionary, so that the variable can be used to end the span. - self.active_spans.set("dr:" + str(self.id), dr_span) - self.log.debug( - "DagRun span has been started and the injected context_carrier is: %s", - self.context_carrier, - ) - # Start TI spans that also need continuance. - if continue_ti_spans: - new_dagrun_context = Trace.extract(self.context_carrier) - for ti in tis: - if ti.span_status == SpanStatus.NEEDS_CONTINUANCE: - ti_span = Trace.start_child_span( - span_name=f"{ti.task_id}_continued", - parent_context=new_dagrun_context, - start_as_current=False, - ) - ti_carrier = Trace.inject() - ti.context_carrier = ti_carrier - ti.span_status = SpanStatus.ACTIVE - self.active_spans.set(f"ti:{ti.id}", ti_span) - else: - self.log.debug( - "Found span_status '%s', while updating state for dag_run '%s'", - self.span_status, - self.run_id, - ) - - def end_dr_span_if_needed(self): - if self.active_spans is not None: - active_span = self.active_spans.get("dr:" + str(self.id)) - if active_span is not None: - self.log.debug( - "Found active span with span_id: %s, for dag_id: %s, run_id: %s, state: %s", - active_span.get_span_context().span_id, - self.dag_id, - self.run_id, - self.state, - ) - - self.set_dagrun_span_attrs(span=active_span) - active_span.end(end_time=datetime_to_nano(self.end_date)) - # Remove the span from the dict. - self.active_spans.delete("dr:" + str(self.id)) - self.span_status = SpanStatus.ENDED - else: - if self.span_status == SpanStatus.ACTIVE: - # Another scheduler has started the span. - # Update the DB SpanStatus to notify the owner to end it. - self.span_status = SpanStatus.SHOULD_END - elif self.span_status == SpanStatus.NEEDS_CONTINUANCE: - # This is a corner case where the scheduler exited gracefully - # while the dag_run was almost done. - # Since it reached this point, the dag has finished but there has been no time - # to create a new span for the current scheduler. - # There is no need for more spans, update the status on the db. - self.span_status = SpanStatus.ENDED - else: - self.log.debug( - "No active span has been found for dag_id: %s, run_id: %s, state: %s", - self.dag_id, - self.run_id, - self.state, - ) + def _emit_dagrun_span(self, state: DagRunState): + ctx = TraceContextTextMapPropagator().extract(self.context_carrier) + span = trace.get_current_span(context=ctx) + span_context = span.get_span_context() + with override_ids(span_context.trace_id, span_context.span_id): + attributes = { + "airflow.dag_id": str(self.dag_id), + "airflow.dag_run.run_id": self.run_id, + } + if self.logical_date: + attributes["airflow.dag_run.logical_date"] = str(self.logical_date) + if self.partition_key: + attributes["airflow.dag_run.partition_key"] = str(self.partition_key) + span = tracer.start_span( + name=f"dag_run.{self.dag_id}", + start_time=int((self.start_date or timezone.utcnow()).timestamp() * 1e9), + attributes=attributes, + context=context.Context(), + ) + status_code = StatusCode.OK if state == DagRunState.SUCCESS else StatusCode.ERROR + span.set_status(status_code) + span.end() @provide_session def update_state( @@ -1302,9 +1195,6 @@ def recalculate(self) -> _UnfinishedStates: # finally, if the leaves aren't done, the dag is still running else: - # It might need to start TI spans as well. - self.start_dr_spans_if_needed(tis=tis) - self.set_state(DagRunState.RUNNING) if self._state == DagRunState.FAILED or self._state == DagRunState.SUCCESS: @@ -1331,10 +1221,8 @@ def recalculate(self) -> _UnfinishedStates: self.data_interval_start, self.data_interval_end, ) - - self.end_dr_span_if_needed() - session.flush() + self._emit_dagrun_span(state=self.state) self._emit_true_scheduling_delay_stats_for_finished_state(finished_tis) self._emit_duration_stats_for_finished_state() diff --git a/airflow-core/src/airflow/observability/traces/__init__.py b/airflow-core/src/airflow/observability/traces/__init__.py index 217e5db960782..6bf0019f74708 100644 --- a/airflow-core/src/airflow/observability/traces/__init__.py +++ b/airflow-core/src/airflow/observability/traces/__init__.py @@ -15,3 +15,137 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + +import logging +import os +from contextlib import contextmanager +from importlib.metadata import entry_points + +from opentelemetry import context, trace +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanExporter +from opentelemetry.sdk.trace.id_generator import RandomIdGenerator +from opentelemetry.trace import NonRecordingSpan, SpanContext, TraceFlags +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator + +from airflow.configuration import conf + +log = logging.getLogger(__name__) + +OVERRIDE_SPAN_ID_KEY = context.create_key("override_span_id") +OVERRIDE_TRACE_ID_KEY = context.create_key("override_trace_id") + + +class OverrideableRandomIdGenerator(RandomIdGenerator): + """Lets you override the span id.""" + + def generate_span_id(self): + override = context.get_value(OVERRIDE_SPAN_ID_KEY) + if override is not None: + return override + return super().generate_span_id() + + def generate_trace_id(self): + override = context.get_value(OVERRIDE_TRACE_ID_KEY) + if override is not None: + return override + return super().generate_trace_id() + + +def new_dagrun_trace_carrier() -> dict[str, str]: + """Generate a fresh W3C traceparent carrier without creating a recordable span.""" + gen = RandomIdGenerator() + span_ctx = SpanContext( + trace_id=gen.generate_trace_id(), + span_id=gen.generate_span_id(), + is_remote=False, + trace_flags=TraceFlags(TraceFlags.SAMPLED), + ) + ctx = trace.set_span_in_context(NonRecordingSpan(span_ctx)) + carrier: dict[str, str] = {} + TraceContextTextMapPropagator().inject(carrier, context=ctx) + return carrier + + +@contextmanager +def override_ids(trace_id, span_id, ctx=None): + ctx = context.set_value(OVERRIDE_TRACE_ID_KEY, trace_id, context=ctx) + ctx = context.set_value(OVERRIDE_SPAN_ID_KEY, span_id, context=ctx) + token = context.attach(ctx) + try: + yield + finally: + context.detach(token) + + +def _get_backcompat_config() -> tuple[str | None, Resource | None]: + """ + Possibly get deprecated Airflow configs for otel. + + Ideally we return (None, None) here. But if the old configuration is there, + then we will use it. + """ + resource = None + if not os.environ.get("OTEL_SERVICE_NAME") and not os.environ.get("OTEL_RESOURCE_ATTRIBUTES"): + service_name = conf.get("traces", "otel_service", fallback=None) + if service_name: + resource = Resource({"service.name": service_name}) + + endpoint = None + if not os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT") and not os.environ.get( + "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT" + ): + # this is only for backcompat! + host = conf.get("traces", "otel_host", fallback=None) + port = conf.get("traces", "otel_port", fallback=None) + ssl_active = conf.getboolean("traces", "otel_ssl_active", fallback=False) + if host and port: + scheme = "https" if ssl_active else "http" + endpoint = f"{scheme}://{host}:{port}/v1/traces" + return endpoint, resource + + +def _load_exporter_from_env() -> SpanExporter: + """ + Load a span exporter using the OTEL_TRACES_EXPORTER env var. + + Mirrors the entry-point mechanism used by the OTEL SDK auto-instrumentation + configurator. Supported values (from installed packages): + - ``otlp`` (default) — OTLP/gRPC + - ``otlp_proto_http`` — OTLP/HTTP + - ``console`` — stdout (useful for debugging) + """ + exporter_name = os.environ.get("OTEL_TRACES_EXPORTER", "otlp") + eps = entry_points(group="opentelemetry_traces_exporter", name=exporter_name) + ep = next(iter(eps), None) + if ep is None: + raise RuntimeError( + f"No span exporter found for OTEL_TRACES_EXPORTER={exporter_name!r}. " + f"Available: {[e.name for e in entry_points(group='opentelemetry_traces_exporter')]}" + ) + return ep.load()() + + +def configure_otel(): + otel_on = conf.getboolean("traces", "otel_on", fallback=False) + if not otel_on: + return + + # ideally both endpoint and resource are None here + # they would only be something other than None if user is using deprecated + # Airflow-defined otel configs + backcompat_endpoint, resource = _get_backcompat_config() + + # backcompat: if old-style host/port config provided an endpoint, set the + # env var so the exporter (loaded below) picks it up automatically + + otlp_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT") + otlp_traces_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT") + if backcompat_endpoint and not (otlp_endpoint or otlp_traces_endpoint): + os.environ["OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"] = backcompat_endpoint + + provider = TracerProvider(id_generator=OverrideableRandomIdGenerator(), resource=resource) + provider.add_span_processor(BatchSpanProcessor(_load_exporter_from_env())) + trace.set_tracer_provider(provider) diff --git a/airflow-core/src/airflow/settings.py b/airflow-core/src/airflow/settings.py index b8bc480ef156f..49d46f652c6fb 100644 --- a/airflow-core/src/airflow/settings.py +++ b/airflow-core/src/airflow/settings.py @@ -38,6 +38,8 @@ ) from sqlalchemy.orm import scoped_session, sessionmaker +from airflow.observability.traces import configure_otel + try: from sqlalchemy.ext.asyncio import async_sessionmaker except ImportError: @@ -722,7 +724,7 @@ def initialize(): load_policy_plugins(policy_mgr) import_local_settings() configure_logging() - + configure_otel() configure_adapters() # The webservers import this file from models.py with the default settings. diff --git a/airflow-core/tests/integration/otel/dags/otel_test_dag.py b/airflow-core/tests/integration/otel/dags/otel_test_dag.py index 6c005a9927ee9..25861c8f622ae 100644 --- a/airflow-core/tests/integration/otel/dags/otel_test_dag.py +++ b/airflow-core/tests/integration/otel/dags/otel_test_dag.py @@ -22,12 +22,12 @@ from opentelemetry import trace from airflow import DAG -from airflow.sdk import chain, task -from airflow.sdk.observability.trace import Trace -from airflow.sdk.observability.traces import otel_tracer +from airflow.sdk import task logger = logging.getLogger("airflow.otel_test_dag") +tracer = trace.get_tracer(__name__) + args = { "owner": "airflow", "start_date": datetime(2024, 9, 1), @@ -36,52 +36,13 @@ @task -def task1(ti): - logger.info("Starting Task_1.") - - context_carrier = ti.context_carrier - - otel_task_tracer = otel_tracer.get_otel_tracer_for_task(Trace) - tracer_provider = otel_task_tracer.get_otel_tracer_provider() - - if context_carrier is not None: - logger.info("Found ti.context_carrier: %s.", str(context_carrier)) - logger.info("Extracting the span context from the context_carrier.") - parent_context = otel_task_tracer.extract(context_carrier) - with otel_task_tracer.start_child_span( - span_name="task1_sub_span1", - parent_context=parent_context, - component="dag", - ) as s1: - s1.set_attribute("attr1", "val1") - logger.info("From task sub_span1.") - - with otel_task_tracer.start_child_span("task1_sub_span2") as s2: - s2.set_attribute("attr2", "val2") - logger.info("From task sub_span2.") +def task1(): + logger.info("starting task1") - tracer = trace.get_tracer("trace_test.tracer", tracer_provider=tracer_provider) - with tracer.start_as_current_span(name="task1_sub_span3") as s3: - s3.set_attribute("attr3", "val3") - logger.info("From task sub_span3.") + with tracer.start_as_current_span("sub_span1") as s1: + s1.set_attribute("attr1", "val1") - with otel_task_tracer.start_child_span( - span_name="task1_sub_span4", - parent_context=parent_context, - component="dag", - ) as s4: - s4.set_attribute("attr4", "val4") - logger.info("From task sub_span4.") - - logger.info("Task_1 finished.") - - -@task -def task2(): - logger.info("Starting Task_2.") - for i in range(3): - logger.info("Task_2, iteration '%d'.", i) - logger.info("Task_2 finished.") + logger.info("task1 finished.") with DAG( @@ -90,4 +51,4 @@ def task2(): schedule=None, catchup=False, ) as dag: - chain(task1(), task2()) # type: ignore + task1() diff --git a/airflow-core/tests/integration/otel/dags/otel_test_dag_with_pause_between_tasks.py b/airflow-core/tests/integration/otel/dags/otel_test_dag_with_pause_between_tasks.py deleted file mode 100644 index 72fb9148a40e5..0000000000000 --- a/airflow-core/tests/integration/otel/dags/otel_test_dag_with_pause_between_tasks.py +++ /dev/null @@ -1,158 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import logging -import os -import time -from datetime import datetime - -from opentelemetry import trace -from sqlalchemy import select - -from airflow import DAG -from airflow.models import TaskInstance -from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS -from airflow.sdk import chain, task -from airflow.sdk.observability.trace import Trace -from airflow.sdk.observability.traces import otel_tracer -from airflow.utils.session import create_session - -logger = logging.getLogger("airflow.otel_test_dag_with_pause") - -args = { - "owner": "airflow", - "start_date": datetime(2024, 9, 2), - "retries": 0, -} - - -@task -def task1(ti): - logger.info("Starting Task_1.") - - context_carrier = ti.context_carrier - - otel_task_tracer = otel_tracer.get_otel_tracer_for_task(Trace) - tracer_provider = otel_task_tracer.get_otel_tracer_provider() - - if context_carrier is not None: - logger.info("Found ti.context_carrier: %s.", context_carrier) - logger.info("Extracting the span context from the context_carrier.") - - # If the task takes too long to execute, then the ti should be read from the db - # to make sure that the initial context_carrier is the same. - # Since Airflow 3, direct db access has been removed entirely. - if not AIRFLOW_V_3_0_PLUS: - with create_session() as session: - session_ti: TaskInstance = session.scalars( - select(TaskInstance).where( - TaskInstance.task_id == ti.task_id, - TaskInstance.run_id == ti.run_id, - ) - ).one() - context_carrier = session_ti.context_carrier - - parent_context = Trace.extract(context_carrier) - with otel_task_tracer.start_child_span( - span_name="task1_sub_span1", - parent_context=parent_context, - component="dag", - ) as s1: - s1.set_attribute("attr1", "val1") - logger.info("From task sub_span1.") - - with otel_task_tracer.start_child_span("task1_sub_span2") as s2: - s2.set_attribute("attr2", "val2") - logger.info("From task sub_span2.") - - tracer = trace.get_tracer("trace_test.tracer", tracer_provider=tracer_provider) - with tracer.start_as_current_span(name="task1_sub_span3") as s3: - s3.set_attribute("attr3", "val3") - logger.info("From task sub_span3.") - - if not AIRFLOW_V_3_0_PLUS: - with create_session() as session: - session_ti: TaskInstance = session.scalars( - select(TaskInstance).where( - TaskInstance.task_id == ti.task_id, - TaskInstance.run_id == ti.run_id, - ) - ).one() - context_carrier = session_ti.context_carrier - parent_context = Trace.extract(context_carrier) - - with otel_task_tracer.start_child_span( - span_name="task1_sub_span4", - parent_context=parent_context, - component="dag", - ) as s4: - s4.set_attribute("attr4", "val4") - logger.info("From task sub_span4.") - - logger.info("Task_1 finished.") - - -@task -def paused_task(): - logger.info("Starting Paused_task.") - - dag_folder = os.path.dirname(os.path.abspath(__file__)) - control_file = os.path.join(dag_folder, "dag_control.txt") - - # Create the file and write 'pause' to it. - with open(control_file, "w") as file: - file.write("pause") - - # Pause execution until the word 'pause' is replaced on the file. - while True: - # If there is an exception, then writing to the file failed. Let it exit. - file_contents = None - with open(control_file) as file: - file_contents = file.read() - - if "pause" in file_contents: - logger.info("Task has been paused.") - time.sleep(1) - continue - logger.info("Resuming task execution.") - # Break the loop and finish with the task execution. - break - - # Cleanup the control file. - if os.path.exists(control_file): - os.remove(control_file) - print("Control file has been cleaned up.") - - logger.info("Paused_task finished.") - - -@task -def task2(): - logger.info("Starting Task_2.") - for i in range(3): - logger.info("Task_2, iteration '%d'.", i) - logger.info("Task_2 finished.") - - -with DAG( - "otel_test_dag_with_pause_between_tasks", - default_args=args, - schedule=None, - catchup=False, -) as dag: - chain(task1(), paused_task(), task2()) # type: ignore diff --git a/airflow-core/tests/integration/otel/dags/otel_test_dag_with_pause_in_task.py b/airflow-core/tests/integration/otel/dags/otel_test_dag_with_pause_in_task.py deleted file mode 100644 index dfc5c30243f08..0000000000000 --- a/airflow-core/tests/integration/otel/dags/otel_test_dag_with_pause_in_task.py +++ /dev/null @@ -1,151 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import logging -import os -import time -from datetime import datetime - -from opentelemetry import trace -from sqlalchemy import select - -from airflow import DAG -from airflow.models import TaskInstance -from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS -from airflow.sdk import chain, task -from airflow.sdk.observability.trace import Trace -from airflow.sdk.observability.traces import otel_tracer -from airflow.utils.session import create_session - -logger = logging.getLogger("airflow.otel_test_dag_with_pause_in_task") - -args = { - "owner": "airflow", - "start_date": datetime(2024, 9, 2), - "retries": 0, -} - - -@task -def task1(ti): - logger.info("Starting Task_1.") - - context_carrier = ti.context_carrier - - dag_folder = os.path.dirname(os.path.abspath(__file__)) - control_file = os.path.join(dag_folder, "dag_control.txt") - - # Create the file and write 'pause' to it. - with open(control_file, "w") as file: - file.write("pause") - - # Pause execution until the word 'pause' is replaced on the file. - while True: - # If there is an exception, then writing to the file failed. Let it exit. - file_contents = None - with open(control_file) as file: - file_contents = file.read() - - if "pause" in file_contents: - logger.info("Task has been paused.") - time.sleep(1) - continue - logger.info("Resuming task execution.") - # Break the loop and finish with the task execution. - break - - otel_task_tracer = otel_tracer.get_otel_tracer_for_task(Trace) - tracer_provider = otel_task_tracer.get_otel_tracer_provider() - - if context_carrier is not None: - logger.info("Found ti.context_carrier: %s.", context_carrier) - logger.info("Extracting the span context from the context_carrier.") - - # If the task takes too long to execute, then the ti should be read from the db - # to make sure that the initial context_carrier is the same. - # Since Airflow 3, direct db access has been removed entirely. - if not AIRFLOW_V_3_0_PLUS: - with create_session() as session: - session_ti: TaskInstance = session.scalars( - select(TaskInstance).where( - TaskInstance.task_id == ti.task_id, - TaskInstance.run_id == ti.run_id, - ) - ).one() - context_carrier = session_ti.context_carrier - - parent_context = Trace.extract(context_carrier) - with otel_task_tracer.start_child_span( - span_name="task1_sub_span1", - parent_context=parent_context, - component="dag", - ) as s1: - s1.set_attribute("attr1", "val1") - logger.info("From task sub_span1.") - - with otel_task_tracer.start_child_span("task1_sub_span2") as s2: - s2.set_attribute("attr2", "val2") - logger.info("From task sub_span2.") - - tracer = trace.get_tracer("trace_test.tracer", tracer_provider=tracer_provider) - with tracer.start_as_current_span(name="task1_sub_span3") as s3: - s3.set_attribute("attr3", "val3") - logger.info("From task sub_span3.") - - if not AIRFLOW_V_3_0_PLUS: - with create_session() as session: - session_ti: TaskInstance = session.scalars( - select(TaskInstance).where( - TaskInstance.task_id == ti.task_id, - TaskInstance.run_id == ti.run_id, - ) - ).one() - context_carrier = session_ti.context_carrier - parent_context = Trace.extract(context_carrier) - - with otel_task_tracer.start_child_span( - span_name="task1_sub_span4", - parent_context=parent_context, - component="dag", - ) as s4: - s4.set_attribute("attr4", "val4") - logger.info("From task sub_span4.") - - # Cleanup the control file. - if os.path.exists(control_file): - os.remove(control_file) - print("Control file has been cleaned up.") - - logger.info("Task_1 finished.") - - -@task -def task2(): - logger.info("Starting Task_2.") - for i in range(3): - logger.info("Task_2, iteration '%d'.", i) - logger.info("Task_2 finished.") - - -with DAG( - "otel_test_dag_with_pause_in_task", - default_args=args, - schedule=None, - catchup=False, -) as dag: - chain(task1(), task2()) # type: ignore diff --git a/airflow-core/tests/integration/otel/test_otel.py b/airflow-core/tests/integration/otel/test_otel.py index 0e4546e301d77..60af1060ce12a 100644 --- a/airflow-core/tests/integration/otel/test_otel.py +++ b/airflow-core/tests/integration/otel/test_otel.py @@ -250,7 +250,7 @@ def serialize_and_get_dags(cls) -> dict[str, SerializedDAG]: dag_bag = DagBag(dag_folder=cls.dag_folder, include_examples=False) dag_ids = dag_bag.dag_ids - assert len(dag_ids) == 3 + assert len(dag_ids) == 1 dag_dict: dict[str, SerializedDAG] = {} with create_session() as session: @@ -317,7 +317,7 @@ def dag_execution_for_testing_metrics(self, capfd): try: # Start the processes here and not as fixtures or in a common setup, # so that the test can capture their output. - scheduler_process, apiserver_process = self.start_worker_and_scheduler() + scheduler_process, apiserver_process = self.start_scheduler() dag_id = "otel_test_dag" @@ -441,7 +441,7 @@ def test_dag_execution_succeeds(self, capfd): try: # Start the processes here and not as fixtures or in a common setup, # so that the test can capture their output. - scheduler_process, apiserver_process = self.start_worker_and_scheduler() + scheduler_process, apiserver_process = self.start_scheduler() dag_id = "otel_test_dag" @@ -486,10 +486,8 @@ def test_dag_execution_succeeds(self, capfd): log.info("out-start --\n%s\n-- out-end", out) log.info("err-start --\n%s\n-- err-end", err) - # host = "host.docker.internal" host = "jaeger" service_name = os.environ.get("OTEL_SERVICE_NAME", "test") - # service_name ``= "my-service-name" r = requests.get(f"http://{host}:16686/api/traces?service={service_name}") data = r.json() @@ -510,16 +508,12 @@ def get_parent_span_id(span): nested = get_span_hierarchy() assert nested == { - "otel_test_dag": None, - "task1": None, - "task1_sub_span1": None, - "task1_sub_span2": None, - "task1_sub_span3": "task1_sub_span2", - "task1_sub_span4": None, - "task2": None, + "sub_span1": "task_run.task1", + "task_run.task1": "dag_run.otel_test_dag", + "dag_run.otel_test_dag": None, } - def start_worker_and_scheduler(self): + def start_scheduler(self): scheduler_process = subprocess.Popen( self.scheduler_command_args, env=os.environ.copy(), diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index 3aeebcedbdc2c..c23f180f3a479 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -81,7 +81,6 @@ from airflow.models.taskinstance import TaskInstance from airflow.models.team import Team from airflow.models.trigger import Trigger -from airflow.observability.trace import Trace from airflow.partition_mappers.base import PartitionMapper as CorePartitionMapper from airflow.providers.standard.operators.bash import BashOperator from airflow.providers.standard.operators.empty import EmptyOperator @@ -93,9 +92,7 @@ from airflow.serialization.serialized_objects import LazyDeserializedDAG from airflow.timetables.base import DagRunInfo, DataInterval from airflow.utils.session import create_session, provide_session -from airflow.utils.span_status import SpanStatus from airflow.utils.state import CallbackState, DagRunState, State, TaskInstanceState -from airflow.utils.thread_safe_dict import ThreadSafeDict from airflow.utils.types import DagRunTriggeredByType, DagRunType from tests_common.pytest_plugin import AIRFLOW_ROOT_PATH @@ -3283,190 +3280,6 @@ def test_runs_are_created_after_max_active_runs_was_reached(self, dag_maker, ses dag_runs = DagRun.find(dag_id=dag.dag_id, session=session) assert len(dag_runs) == 2 - @pytest.mark.parametrize( - ("ti_state", "final_ti_span_status"), - [ - pytest.param(State.SUCCESS, SpanStatus.ENDED, id="dr_ended_successfully"), - pytest.param(State.RUNNING, SpanStatus.ACTIVE, id="dr_still_running"), - ], - ) - def test_recreate_unhealthy_scheduler_spans_if_needed(self, ti_state, final_ti_span_status, dag_maker): - with dag_maker( - dag_id="test_recreate_unhealthy_scheduler_spans_if_needed", - start_date=DEFAULT_DATE, - max_active_runs=1, - dagrun_timeout=datetime.timedelta(seconds=60), - ): - EmptyOperator(task_id="dummy") - - session = settings.Session() - - old_job = Job() - old_job.job_type = SchedulerJobRunner.job_type - - session.add(old_job) - session.commit() - - assert old_job.is_alive() is False - - new_job = Job() - new_job.job_type = SchedulerJobRunner.job_type - session.add(new_job) - session.flush() - - self.job_runner = SchedulerJobRunner(job=new_job) - self.job_runner.active_spans = ThreadSafeDict() - assert len(self.job_runner.active_spans.get_all()) == 0 - - dr = dag_maker.create_dagrun() - dr.state = State.RUNNING - dr.span_status = SpanStatus.ACTIVE - dr.scheduled_by_job_id = old_job.id - - ti = dr.get_task_instances(session=session)[0] - ti.state = ti_state - ti.start_date = timezone.utcnow() - ti.span_status = SpanStatus.ACTIVE - ti.queued_by_job_id = old_job.id - session.merge(ti) - session.merge(dr) - session.commit() - - assert dr.scheduled_by_job_id != self.job_runner.job.id - assert dr.scheduled_by_job_id == old_job.id - assert dr.run_id is not None - assert dr.state == State.RUNNING - assert dr.span_status == SpanStatus.ACTIVE - assert self.job_runner.active_spans.get("dr:" + str(dr.id)) is None - - assert self.job_runner.active_spans.get(f"ti:{ti.id}") is None - assert ti.state == ti_state - assert ti.span_status == SpanStatus.ACTIVE - - self.job_runner._recreate_unhealthy_scheduler_spans_if_needed(dr, session) - - assert self.job_runner.active_spans.get("dr:" + str(dr.id)) is not None - - if final_ti_span_status == SpanStatus.ACTIVE: - assert self.job_runner.active_spans.get(f"ti:{ti.id}") is not None - assert len(self.job_runner.active_spans.get_all()) == 2 - else: - assert self.job_runner.active_spans.get(f"ti:{ti.id}") is None - assert len(self.job_runner.active_spans.get_all()) == 1 - - assert dr.span_status == SpanStatus.ACTIVE - assert ti.span_status == final_ti_span_status - - def test_end_spans_of_externally_ended_ops(self, dag_maker): - with dag_maker( - dag_id="test_end_spans_of_externally_ended_ops", - start_date=DEFAULT_DATE, - max_active_runs=1, - dagrun_timeout=datetime.timedelta(seconds=60), - ): - EmptyOperator(task_id="dummy") - - session = settings.Session() - - job = Job() - job.job_type = SchedulerJobRunner.job_type - session.add(job) - - self.job_runner = SchedulerJobRunner(job=job) - self.job_runner.active_spans = ThreadSafeDict() - assert len(self.job_runner.active_spans.get_all()) == 0 - - dr = dag_maker.create_dagrun() - dr.state = State.SUCCESS - dr.span_status = SpanStatus.SHOULD_END - - ti = dr.get_task_instances(session=session)[0] - ti.state = State.SUCCESS - ti.span_status = SpanStatus.SHOULD_END - ti.context_carrier = {} - session.merge(ti) - session.merge(dr) - session.commit() - - dr_span = Trace.start_root_span(span_name="dag_run_span", start_as_current=False) - ti_span = Trace.start_child_span(span_name="ti_span", start_as_current=False) - - self.job_runner.active_spans.set("dr:" + str(dr.id), dr_span) - self.job_runner.active_spans.set(f"ti:{ti.id}", ti_span) - - assert dr.span_status == SpanStatus.SHOULD_END - assert ti.span_status == SpanStatus.SHOULD_END - - assert self.job_runner.active_spans.get("dr:" + str(dr.id)) is not None - assert self.job_runner.active_spans.get(f"ti:{ti.id}") is not None - - self.job_runner._end_spans_of_externally_ended_ops(session) - - assert dr.span_status == SpanStatus.ENDED - assert ti.span_status == SpanStatus.ENDED - - assert self.job_runner.active_spans.get("dr:" + str(dr.id)) is None - assert self.job_runner.active_spans.get(f"ti:{ti.id}") is None - - @pytest.mark.parametrize( - ("state", "final_span_status"), - [ - pytest.param(State.SUCCESS, SpanStatus.ENDED, id="dr_ended_successfully"), - pytest.param(State.RUNNING, SpanStatus.NEEDS_CONTINUANCE, id="dr_still_running"), - ], - ) - def test_end_active_spans(self, state, final_span_status, dag_maker): - with dag_maker( - dag_id="test_end_active_spans", - start_date=DEFAULT_DATE, - max_active_runs=1, - dagrun_timeout=datetime.timedelta(seconds=60), - ): - EmptyOperator(task_id="dummy") - - session = settings.Session() - - job = Job() - job.job_type = SchedulerJobRunner.job_type - - self.job_runner = SchedulerJobRunner(job=job) - self.job_runner.active_spans = ThreadSafeDict() - assert len(self.job_runner.active_spans.get_all()) == 0 - - dr = dag_maker.create_dagrun() - dr.state = state - dr.span_status = SpanStatus.ACTIVE - - ti = dr.get_task_instances(session=session)[0] - ti.state = state - ti.span_status = SpanStatus.ACTIVE - ti.context_carrier = {} - session.merge(ti) - session.merge(dr) - session.commit() - - dr_span = Trace.start_root_span(span_name="dag_run_span", start_as_current=False) - ti_span = Trace.start_child_span(span_name="ti_span", start_as_current=False) - - self.job_runner.active_spans.set("dr:" + str(dr.id), dr_span) - self.job_runner.active_spans.set(f"ti:{ti.id}", ti_span) - - assert dr.span_status == SpanStatus.ACTIVE - assert ti.span_status == SpanStatus.ACTIVE - - assert self.job_runner.active_spans.get("dr:" + str(dr.id)) is not None - assert self.job_runner.active_spans.get(f"ti:{ti.id}") is not None - assert len(self.job_runner.active_spans.get_all()) == 2 - - self.job_runner._end_active_spans(session) - - assert dr.span_status == final_span_status - assert ti.span_status == final_span_status - - assert self.job_runner.active_spans.get("dr:" + str(dr.id)) is None - assert self.job_runner.active_spans.get(f"ti:{ti.id}") is None - assert len(self.job_runner.active_spans.get_all()) == 0 - def test_dagrun_timeout_verify_max_active_runs(self, dag_maker, session): """ Test if a dagrun will not be scheduled if max_dag_runs diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py index f3de13422fac3..14722f83b0cce 100644 --- a/airflow-core/tests/unit/models/test_dagrun.py +++ b/airflow-core/tests/unit/models/test_dagrun.py @@ -27,6 +27,7 @@ import pendulum import pytest +from opentelemetry.sdk.trace import TracerProvider from sqlalchemy import func, select from sqlalchemy.orm import joinedload @@ -54,9 +55,7 @@ from airflow.settings import get_policy_plugin_manager from airflow.task.trigger_rule import TriggerRule from airflow.triggers.base import StartTriggerArgs -from airflow.utils.span_status import SpanStatus from airflow.utils.state import DagRunState, State, TaskInstanceState -from airflow.utils.thread_safe_dict import ThreadSafeDict from airflow.utils.types import DagRunTriggeredByType, DagRunType from tests_common.test_utils import db @@ -560,142 +559,6 @@ def test_on_success_callback_when_task_skipped(self, session, testing_dag_bundle assert dag_run.state == DagRunState.SUCCESS mock_on_success.assert_called_once() - def test_start_dr_spans_if_needed_new_span(self, dag_maker, session): - with dag_maker( - dag_id="test_start_dr_spans_if_needed_new_span", - schedule=datetime.timedelta(days=1), - start_date=datetime.datetime(2017, 1, 1), - ) as dag: - dag_task1 = EmptyOperator(task_id="test_task1") - dag_task2 = EmptyOperator(task_id="test_task2") - dag_task1.set_downstream(dag_task2) - - initial_task_states = { - "test_task1": TaskInstanceState.QUEUED, - "test_task2": TaskInstanceState.QUEUED, - } - - dag_run = self.create_dag_run(dag=dag, task_states=initial_task_states, session=session) - - active_spans = ThreadSafeDict() - dag_run.set_active_spans(active_spans) - - tis = dag_run.get_task_instances() - - assert dag_run.active_spans is not None - assert dag_run.active_spans.get("dr:" + str(dag_run.id)) is None - assert dag_run.span_status == SpanStatus.NOT_STARTED - - dag_run.start_dr_spans_if_needed(tis=tis) - - assert dag_run.span_status == SpanStatus.ACTIVE - assert dag_run.active_spans.get("dr:" + str(dag_run.id)) is not None - - def test_start_dr_spans_if_needed_span_with_continuance(self, dag_maker, session): - with dag_maker( - dag_id="test_start_dr_spans_if_needed_span_with_continuance", - schedule=datetime.timedelta(days=1), - start_date=datetime.datetime(2017, 1, 1), - ) as dag: - dag_task1 = EmptyOperator(task_id="test_task1") - dag_task2 = EmptyOperator(task_id="test_task2") - dag_task1.set_downstream(dag_task2) - - initial_task_states = { - "test_task1": TaskInstanceState.RUNNING, - "test_task2": TaskInstanceState.QUEUED, - } - - dag_run = self.create_dag_run(dag=dag, task_states=initial_task_states, session=session) - - active_spans = ThreadSafeDict() - dag_run.set_active_spans(active_spans) - - dag_run.span_status = SpanStatus.NEEDS_CONTINUANCE - - tis = dag_run.get_task_instances() - - first_ti = tis[0] - first_ti.span_status = SpanStatus.NEEDS_CONTINUANCE - - assert dag_run.active_spans is not None - assert dag_run.active_spans.get("dr:" + str(dag_run.id)) is None - assert dag_run.active_spans.get(f"ti:{first_ti.id}") is None - assert dag_run.span_status == SpanStatus.NEEDS_CONTINUANCE - assert first_ti.span_status == SpanStatus.NEEDS_CONTINUANCE - - dag_run.start_dr_spans_if_needed(tis=tis) - - assert dag_run.span_status == SpanStatus.ACTIVE - assert first_ti.span_status == SpanStatus.ACTIVE - assert dag_run.active_spans.get("dr:" + str(dag_run.id)) is not None - assert dag_run.active_spans.get(f"ti:{first_ti.id}") is not None - - def test_end_dr_span_if_needed(self, testing_dag_bundle, dag_maker, session): - with dag_maker( - dag_id="test_end_dr_span_if_needed", - schedule=datetime.timedelta(days=1), - start_date=datetime.datetime(2017, 1, 1), - ) as dag: - dag_task1 = EmptyOperator(task_id="test_task1") - dag_task2 = EmptyOperator(task_id="test_task2") - dag_task1.set_downstream(dag_task2) - - initial_task_states = { - "test_task1": TaskInstanceState.SUCCESS, - "test_task2": TaskInstanceState.SUCCESS, - } - - dag_run = self.create_dag_run(dag=dag, task_states=initial_task_states, session=session) - - active_spans = ThreadSafeDict() - dag_run.set_active_spans(active_spans) - - from airflow.observability.trace import Trace - - dr_span = Trace.start_root_span(span_name="test_span", start_as_current=False) - - active_spans.set("dr:" + str(dag_run.id), dr_span) - - assert dag_run.active_spans is not None - assert dag_run.active_spans.get("dr:" + str(dag_run.id)) is not None - - dag_run.end_dr_span_if_needed() - - assert dag_run.span_status == SpanStatus.ENDED - assert dag_run.active_spans.get("dr:" + str(dag_run.id)) is None - - def test_end_dr_span_if_needed_with_span_from_another_scheduler( - self, testing_dag_bundle, dag_maker, session - ): - with dag_maker( - dag_id="test_end_dr_span_if_needed_with_span_from_another_scheduler", - schedule=datetime.timedelta(days=1), - start_date=datetime.datetime(2017, 1, 1), - ) as dag: - dag_task1 = EmptyOperator(task_id="test_task1") - dag_task2 = EmptyOperator(task_id="test_task2") - dag_task1.set_downstream(dag_task2) - - initial_task_states = { - "test_task1": TaskInstanceState.SUCCESS, - "test_task2": TaskInstanceState.SUCCESS, - } - - dag_run = self.create_dag_run(dag=dag, task_states=initial_task_states, session=session) - - active_spans = ThreadSafeDict() - dag_run.set_active_spans(active_spans) - - dag_run.span_status = SpanStatus.ACTIVE - - assert dag_run.active_spans is not None - assert dag_run.active_spans.get("dr:" + str(dag_run.id)) is None - - dag_run.end_dr_span_if_needed() - - assert dag_run.span_status == SpanStatus.SHOULD_END - def test_dagrun_update_state_with_handle_callback_success(self, testing_dag_bundle, dag_maker, session): def on_success_callable(context): assert context["dag_run"].dag_id == "test_dagrun_update_state_with_handle_callback_success" @@ -744,7 +607,6 @@ def on_success_callable(context): ) def test_dagrun_update_state_with_handle_callback_failure(self, testing_dag_bundle, dag_maker, session): - def on_failure_callable(context): assert context["dag_run"].dag_id == "test_dagrun_update_state_with_handle_callback_failure" @@ -3292,3 +3154,159 @@ def on_failure(context): assert context_received["ti"].task_id == "test_task" assert context_received["ti"].dag_id == "test_dag" assert context_received["ti"].run_id == dr.run_id + + +class TestDagRunTracing: + """Tests for DagRun OpenTelemetry span behavior.""" + + @pytest.fixture(autouse=True) + def sdk_tracer_provider(self): + """Patch the module-level tracer with one backed by a real SDK provider so spans have valid IDs.""" + provider = TracerProvider() + real_tracer = provider.get_tracer("airflow.models.dagrun") + with mock.patch("airflow.models.dagrun.tracer", real_tracer): + yield + + def test_context_carrier_set_on_init(self, dag_maker): + """DagRun.__init__ should populate context_carrier with a W3C traceparent.""" + with dag_maker("test_tracing_init"): + EmptyOperator(task_id="t1") + dr = dag_maker.create_dagrun() + + assert dr.context_carrier is not None + assert isinstance(dr.context_carrier, dict) + assert "traceparent" in dr.context_carrier + + def test_context_carrier_unique_per_dagrun(self, dag_maker): + """Each DagRun should get a distinct trace context.""" + with dag_maker("test_tracing_unique1"): + EmptyOperator(task_id="t1") + dr1 = dag_maker.create_dagrun() + + with dag_maker("test_tracing_unique2"): + EmptyOperator(task_id="t1") + dr2 = dag_maker.create_dagrun() + + assert dr1.context_carrier["traceparent"] != dr2.context_carrier["traceparent"] + + @pytest.mark.parametrize("final_state", [DagRunState.SUCCESS, DagRunState.FAILED]) + def test_emit_dagrun_span_called_on_completion(self, dag_maker, session, final_state): + """_emit_dagrun_span should be called exactly once when a dag run finishes.""" + with dag_maker("test_tracing_emit", session=session) as dag: + EmptyOperator(task_id="t1") + + dr = dag_maker.create_dagrun(state=DagRunState.RUNNING) + ti = dr.get_task_instance("t1", session=session) + ti.state = ( + TaskInstanceState.SUCCESS if final_state == DagRunState.SUCCESS else TaskInstanceState.FAILED + ) + session.flush() + + dr.dag = dag + + with mock.patch.object(dr, "_emit_dagrun_span") as mock_emit: + dr.update_state(session=session) + + mock_emit.assert_called_once_with(state=final_state) + + def test_emit_dagrun_span_not_called_while_running(self, dag_maker, session): + """_emit_dagrun_span should not be called while the dag run is still running.""" + with dag_maker("test_tracing_no_emit_running", session=session) as dag: + EmptyOperator(task_id="t1") + EmptyOperator(task_id="t2") + + dr = dag_maker.create_dagrun(state=DagRunState.RUNNING) + tis = dr.get_task_instances(session=session) + for ti in tis: + if ti.task_id == "t1": + ti.state = TaskInstanceState.SUCCESS + else: + ti.state = TaskInstanceState.RUNNING + session.flush() + + dr.dag = dag + + with mock.patch.object(dr, "_emit_dagrun_span") as mock_emit: + dr.update_state(session=session) + + mock_emit.assert_not_called() + + def test_emit_dagrun_span_uses_context_carrier_ids(self, dag_maker, session): + """The emitted span should inherit trace_id/span_id from the context_carrier.""" + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator + + from airflow.observability.traces import OverrideableRandomIdGenerator + + in_mem_exporter = InMemorySpanExporter() + provider = TracerProvider(id_generator=OverrideableRandomIdGenerator()) + provider.add_span_processor(SimpleSpanProcessor(in_mem_exporter)) + test_tracer = provider.get_tracer("test") + + with dag_maker("test_tracing_ids", session=session) as dag: + EmptyOperator(task_id="t1") + + dr = dag_maker.create_dagrun(state=DagRunState.RUNNING) + ti = dr.get_task_instance("t1", session=session) + ti.state = TaskInstanceState.SUCCESS + session.flush() + dr.dag = dag + + with mock.patch("airflow.models.dagrun.tracer", test_tracer): + dr.update_state(session=session) + + spans = in_mem_exporter.get_finished_spans() + assert len(spans) == 1 + span = spans[0] + + # Decode the expected trace_id/span_id from the stored context_carrier + ctx = TraceContextTextMapPropagator().extract(dr.context_carrier) + from opentelemetry import trace as otel_trace + + stored_span = otel_trace.get_current_span(context=ctx) + stored_ctx = stored_span.get_span_context() + + assert span.context.trace_id == stored_ctx.trace_id + assert span.context.span_id == stored_ctx.span_id + + @pytest.mark.parametrize("final_state", [DagRunState.SUCCESS, DagRunState.FAILED]) + def test_emit_dagrun_span_attributes_and_status(self, dag_maker, session, final_state): + """The emitted span should have the correct name, attributes, and status code.""" + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + from opentelemetry.trace import StatusCode + + from airflow.observability.traces import OverrideableRandomIdGenerator + + in_mem_exporter = InMemorySpanExporter() + provider = TracerProvider(id_generator=OverrideableRandomIdGenerator()) + provider.add_span_processor(SimpleSpanProcessor(in_mem_exporter)) + test_tracer = provider.get_tracer("test") + + with dag_maker("test_tracing_attrs", session=session) as dag: + EmptyOperator(task_id="t1") + + dr = dag_maker.create_dagrun(state=DagRunState.RUNNING) + ti = dr.get_task_instance("t1", session=session) + ti.state = ( + TaskInstanceState.SUCCESS if final_state == DagRunState.SUCCESS else TaskInstanceState.FAILED + ) + session.flush() + dr.dag = dag + + with mock.patch("airflow.models.dagrun.tracer", test_tracer): + dr.update_state(session=session) + + spans = in_mem_exporter.get_finished_spans() + assert len(spans) == 1 + span = spans[0] + + assert span.name == f"dag_run.{dr.dag_id}" + assert span.attributes["airflow.dag_id"] == dr.dag_id + assert span.attributes["airflow.dag_run.run_id"] == dr.run_id + + expected_status = StatusCode.OK if final_state == DagRunState.SUCCESS else StatusCode.ERROR + assert span.status.status_code == expected_status diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index f6b79fba29aba..aa0bd36de11ad 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -1002,6 +1002,7 @@ middleware middlewares midnights milli +millis milton minikube misconfigured diff --git a/scripts/ci/docker-compose/integration-otel.yml b/scripts/ci/docker-compose/integration-otel.yml index 9d5c6c8117ded..f0d32104a14ab 100644 --- a/scripts/ci/docker-compose/integration-otel.yml +++ b/scripts/ci/docker-compose/integration-otel.yml @@ -70,7 +70,7 @@ services: - INTEGRATION_OTEL=true - OTEL_SERVICE_NAME=test - OTEL_EXPORTER_OTLP_PROTOCOL=http/protobuf - - OTEL_TRACES_EXPORTER=otlp + - OTEL_TRACES_EXPORTER=otlp_proto_http - OTEL_METRICS_EXPORTER=otlp - OTEL_EXPORTER_OTLP_TRACES_ENDPOINT=http://breeze-otel-collector:4318/v1/traces - OTEL_EXPORTER_OTLP_METRICS_ENDPOINT=http://breeze-otel-collector:4318/v1/metrics diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py index 94e447e37a56d..c3f786584746b 100644 --- a/task-sdk/src/airflow/sdk/definitions/dag.py +++ b/task-sdk/src/airflow/sdk/definitions/dag.py @@ -1324,10 +1324,6 @@ def test( triggered_by=DagRunTriggeredByType.TEST, triggering_user_name="dag_test", ) - # Start a mock span so that one is present and not started downstream. We - # don't care about otel in dag.test and starting the span during dagrun update - # is not functioning properly in this context anyway. - dr.start_dr_spans_if_needed(tis=[]) log.debug("starting dagrun") # Instead of starting a scheduler, we run the minimal loop possible to check diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 0ef7a74a24b48..674935f5eaecb 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -19,14 +19,13 @@ from __future__ import annotations -import contextlib import contextvars import functools import os import sys import time from collections.abc import Callable, Iterable, Iterator, Mapping -from contextlib import suppress +from contextlib import ExitStack, contextmanager, suppress from datetime import datetime, timedelta, timezone from itertools import product from pathlib import Path @@ -36,6 +35,8 @@ import attrs import lazy_object_proxy import structlog +from opentelemetry import trace +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from pydantic import AwareDatetime, ConfigDict, Field, JsonValue, TypeAdapter from airflow.dag_processing.bundles.base import BaseDagBundle, BundleVersionLock @@ -133,6 +134,32 @@ from airflow.sdk.exceptions import DagRunTriggerException from airflow.sdk.types import OutletEventAccessorsProtocol +log = structlog.get_logger("task") + +tracer = trace.get_tracer(__name__) + + +@contextmanager +def _make_task_span(msg: StartupDetails): + parent_context = ( + TraceContextTextMapPropagator().extract(msg.ti.context_carrier) if msg.ti.context_carrier else None + ) + ti = msg.ti + span_name = f"task_run.{ti.task_id}" + if ti.map_index is not None and ti.map_index >= 0: + span_name += f"_{ti.map_index}" + with tracer.start_as_current_span(span_name, context=parent_context) as span: + span.set_attributes( + { + "airflow.dag_id": ti.dag_id, + "airflow.task_id": ti.task_id, + "airflow.dag_run.run_id": ti.run_id, + "airflow.task_instance.try_number": ti.try_number, + "airflow.task_instance.map_index": ti.map_index if ti.map_index is not None else -1, + } + ) + yield span + class TaskRunnerMarker: """Marker for listener hooks, to properly detect from which component they are called.""" @@ -476,8 +503,6 @@ def get_first_reschedule_date(self, context: Context) -> AwareDatetime | None: retries: int = self.task.retries or 0 first_try_number = max_tries - retries + 1 - log = structlog.get_logger(logger_name="task") - log.debug("Requesting first reschedule date from supervisor") response = SUPERVISOR_COMMS.send( @@ -494,8 +519,6 @@ def get_previous_dagrun(self, state: str | None = None) -> DagRun | None: context = self.get_template_context() dag_run = context.get("dag_run") - log = structlog.get_logger(logger_name="task") - log.debug("Getting previous Dag run", dag_run=dag_run) if dag_run is None: @@ -530,7 +553,6 @@ def get_previous_ti( context = self.get_template_context() dag_run = context.get("dag_run") - log = structlog.get_logger(logger_name="task") log.debug("Getting previous task instance", task_id=self.task_id, state=state) # Use current dag run's logical_date if not provided @@ -846,7 +868,6 @@ def _verify_bundle_access(bundle_instance: BaseDagBundle, log: Logger) -> None: def get_startup_details() -> StartupDetails: # The parent sends us a StartupDetails message un-prompted. After this, every single message is only sent # in response to us sending a request. - log = structlog.get_logger(logger_name="task") if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") == "1" and ( msgjson := os.environ.get("_AIRFLOW__STARTUP_MSG") @@ -871,7 +892,6 @@ def get_startup_details() -> StartupDetails: def startup(msg: StartupDetails) -> tuple[RuntimeTaskInstance, Context, Logger]: - log = structlog.get_logger("task") # setproctitle causes issue on Mac OS: https://github.com/benoitc/gunicorn/issues/3021 os_type = sys.platform if os_type == "darwin": @@ -1238,7 +1258,7 @@ def _on_term(signum, frame): import jinja2 # If the task failed, swallow rendering error so it doesn't mask the main error. - with contextlib.suppress(jinja2.TemplateSyntaxError, jinja2.UndefinedError): + with suppress(jinja2.TemplateSyntaxError, jinja2.UndefinedError): previous_rendered_map_index = ti.rendered_map_index ti.rendered_map_index = _render_map_index(context, ti=ti, log=log) # Send update only if value changed (e.g., user set context variables during execution) @@ -1796,6 +1816,20 @@ def finalize( log.exception("error calling listener") +@contextmanager +def flush_spans(): + try: + yield + finally: + provider = trace.get_tracer_provider() + if hasattr(provider, "force_flush"): + from airflow.sdk.configuration import conf + + timeout_millis = conf.getint("traces", "task_runner_flush_timeout_milliseconds", fallback=30000) + provider.force_flush(timeout_millis=timeout_millis) + + +@flush_spans() def main(): log = structlog.get_logger(logger_name="task") @@ -1805,38 +1839,42 @@ def main(): stats_factory = stats_utils.get_stats_factory(Stats) Stats.initialize(factory=stats_factory) - try: + stack = ExitStack() + with stack: try: - startup_details = get_startup_details() - ti, context, log = startup(msg=startup_details) - except AirflowRescheduleException as reschedule: - log.warning("Rescheduling task during startup, marking task as UP_FOR_RESCHEDULE") - SUPERVISOR_COMMS.send( - msg=RescheduleTask( - reschedule_date=reschedule.reschedule_date, - end_date=datetime.now(tz=timezone.utc), + try: + startup_details = get_startup_details() + span = _make_task_span(msg=startup_details) + stack.enter_context(span) + ti, context, log = startup(msg=startup_details) + except AirflowRescheduleException as reschedule: + log.warning("Rescheduling task during startup, marking task as UP_FOR_RESCHEDULE") + SUPERVISOR_COMMS.send( + msg=RescheduleTask( + reschedule_date=reschedule.reschedule_date, + end_date=datetime.now(tz=timezone.utc), + ) ) - ) - sys.exit(0) - with BundleVersionLock( - bundle_name=ti.bundle_instance.name, - bundle_version=ti.bundle_instance.version, - ): - state, _, error = run(ti, context, log) - context["exception"] = error - finalize(ti, state, context, log, error) - except KeyboardInterrupt: - log.exception("Ctrl-c hit") - sys.exit(2) - except Exception: - log.exception("Top level error") - sys.exit(1) - finally: - # Ensure the request socket is closed on the child side in all circumstances - # before the process fully terminates. - if SUPERVISOR_COMMS and SUPERVISOR_COMMS.socket: - with suppress(Exception): - SUPERVISOR_COMMS.socket.close() + sys.exit(0) + with BundleVersionLock( + bundle_name=ti.bundle_instance.name, + bundle_version=ti.bundle_instance.version, + ): + state, _, error = run(ti, context, log) + context["exception"] = error + finalize(ti, state, context, log, error) + except KeyboardInterrupt: + log.exception("Ctrl-c hit") + sys.exit(2) + except Exception: + log.exception("Top level error") + sys.exit(1) + finally: + # Ensure the request socket is closed on the child side in all circumstances + # before the process fully terminates. + if SUPERVISOR_COMMS and SUPERVISOR_COMMS.socket: + with suppress(Exception): + SUPERVISOR_COMMS.socket.close() def reinit_supervisor_comms() -> None: @@ -1851,7 +1889,6 @@ def reinit_supervisor_comms() -> None: if "SUPERVISOR_COMMS" not in globals(): global SUPERVISOR_COMMS - log = structlog.get_logger(logger_name="task") fd = int(os.environ.get("__AIRFLOW_SUPERVISOR_FD", "0")) diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 2a495e557f70f..05191806be8a5 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -127,6 +127,7 @@ TaskRunnerMarker, _defer_task, _execute_task, + _make_task_span, _push_xcom_if_needed, _xcom_push, finalize, @@ -367,7 +368,7 @@ def test_parse_not_found_does_not_reschedule_when_max_attempts_reached(test_dags @mock.patch("airflow.sdk.execution_time.task_runner.get_startup_details") @mock.patch("airflow.sdk.execution_time.task_runner.CommsDecoder") def test_main_sends_reschedule_task_when_startup_reschedules( - mock_comms_decoder_cls, mock_get_startup_details, mock_startup, mock_exit, time_machine + mock_comms_decoder_cls, mock_get_startup_details, mock_startup, mock_exit, time_machine, make_ti_context ): """ If startup raises AirflowRescheduleException, the task runner should report a RescheduleTask @@ -379,7 +380,23 @@ def test_main_sends_reschedule_task_when_startup_reschedules( mock_comms_instance = mock.Mock() mock_comms_instance.socket = None mock_comms_decoder_cls.__getitem__.return_value.return_value = mock_comms_instance - mock_get_startup_details.return_value = mock.Mock() + what = StartupDetails( + ti=TaskInstance( + id=uuid7(), + task_id="my_task", + dag_id="test_dag", + run_id="test_run", + try_number=1, + dag_version_id=uuid7(), + context_carrier={}, + ), + dag_rel_path="", + bundle_info=BundleInfo(name="my-bundle", version=None), + ti_context=make_ti_context(), + start_date=timezone.utcnow(), + sentry_integration="", + ) + mock_get_startup_details.return_value = what mock_startup.side_effect = AirflowRescheduleException(reschedule_date=reschedule_date) # Move time @@ -395,6 +412,102 @@ def test_main_sends_reschedule_task_when_startup_reschedules( ] +def test_task_span_is_child_of_dag_run_span(make_ti_context): + """Task span must be a child of the dag run span propagated via context_carrier.""" + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator + + # Build a real SDK provider and exporter so we can inspect finished spans. + in_mem_exporter = InMemorySpanExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(in_mem_exporter)) + + # Create a "dag run" span whose context we will propagate into the task. + dag_run_tracer = provider.get_tracer("dag_run") + with dag_run_tracer.start_as_current_span("dag_run.test_dag") as dag_run_span: + carrier: dict[str, str] = {} + TraceContextTextMapPropagator().inject(carrier) + dag_run_span_ctx = dag_run_span.get_span_context() + + what = StartupDetails( + ti=TaskInstance( + id=uuid7(), + task_id="my_task", + dag_id="test_dag", + run_id="test_run", + try_number=1, + dag_version_id=uuid7(), + context_carrier=carrier, + ), + dag_rel_path="", + bundle_info=BundleInfo(name="my-bundle", version=None), + ti_context=make_ti_context(), + start_date=timezone.utcnow(), + sentry_integration="", + ) + + task_tracer = provider.get_tracer("airflow.sdk.execution_time.task_runner") + with mock.patch("airflow.sdk.execution_time.task_runner.tracer", task_tracer): + with _make_task_span(what) as span: + task_span_ctx = span.get_span_context() + + # The task span must share the dag run's trace ID. + assert task_span_ctx.trace_id == dag_run_span_ctx.trace_id + + # The task span's parent must be the dag run span. + finished = in_mem_exporter.get_finished_spans() + task_spans = [s for s in finished if s.name == "task_run.my_task"] + assert len(task_spans) == 1 + assert task_spans[0].parent is not None + assert task_spans[0].parent.span_id == dag_run_span_ctx.span_id + + # Span attributes are set correctly. + attrs = task_spans[0].attributes + assert attrs["airflow.dag_id"] == "test_dag" + assert attrs["airflow.task_id"] == "my_task" + assert attrs["airflow.dag_run.run_id"] == "test_run" + assert attrs["airflow.task_instance.try_number"] == 1 + + +def test_task_span_no_parent_when_no_context_carrier(make_ti_context): + """When context_carrier is absent, the task span should be a root span (no parent).""" + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + in_mem_exporter = InMemorySpanExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(in_mem_exporter)) + + what = StartupDetails( + ti=TaskInstance( + id=uuid7(), + task_id="standalone_task", + dag_id="test_dag", + run_id="test_run", + try_number=1, + dag_version_id=uuid7(), + context_carrier=None, + ), + dag_rel_path="", + bundle_info=BundleInfo(name="my-bundle", version=None), + ti_context=make_ti_context(), + start_date=timezone.utcnow(), + sentry_integration="", + ) + + task_tracer = provider.get_tracer("airflow.sdk.execution_time.task_runner") + with mock.patch("airflow.sdk.execution_time.task_runner.tracer", task_tracer): + with _make_task_span(what): + pass + + finished = in_mem_exporter.get_finished_spans() + assert len(finished) == 1 + assert finished[0].parent is None + + def test_parse_module_in_bundle_root(tmp_path: Path, make_ti_context): """Check that the bundle path is added to sys.path, so Dags can import shared modules.""" tmp_path.joinpath("util.py").write_text("NAME = 'dag_name'")