Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions airflow-core/src/airflow/jobs/scheduler_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
from airflow.observability.metrics import stats_utils
from airflow.serialization.definitions.assets import SerializedAssetUniqueKey
from airflow.serialization.definitions.notset import NOTSET
from airflow.ti_deps.dependencies_states import EXECUTION_STATES
from airflow.ti_deps.dependencies_states import EXECUTION_STATES, TASK_CONCURRENCY_EXECUTION_STATES
from airflow.timetables.simple import AssetTriggeredTimetable
from airflow.utils.event_scheduler import EventScheduler
from airflow.utils.log.logging_mixin import LoggingMixin
Expand Down Expand Up @@ -205,15 +205,22 @@ def load(self, session: Session) -> None:
self.dag_run_active_tasks_map.clear()
self.task_concurrency_map.clear()
self.task_dagrun_concurrency_map.clear()
# Run a single grouped query over TASK_CONCURRENCY_EXECUTION_STATES, grouped by state as well,
# so DEFERRED rows can be left out of dag_run_active_tasks_map yet still count for task-level limits.
query = session.execute(
select(TI.dag_id, TI.task_id, TI.run_id, func.count("*"))
.where(TI.state.in_(EXECUTION_STATES))
.group_by(TI.task_id, TI.run_id, TI.dag_id)
select(TI.dag_id, TI.task_id, TI.run_id, TI.state, func.count("*"))
.where(TI.state.in_(TASK_CONCURRENCY_EXECUTION_STATES))
.group_by(TI.dag_id, TI.task_id, TI.run_id, TI.state)
)
for dag_id, task_id, run_id, c in query:
self.dag_run_active_tasks_map[dag_id, run_id] += c
self.task_concurrency_map[(dag_id, task_id)] += c
self.task_dagrun_concurrency_map[(dag_id, run_id, task_id)] += c
for dag_id, task_id, run_id, state, count in query:
# Always count towards task-level concurrency (max_active_tis_per_dag /
# max_active_tis_per_dagrun), including DEFERRED.
self.task_concurrency_map[(dag_id, task_id)] += count
self.task_dagrun_concurrency_map[(dag_id, run_id, task_id)] += count
# Only count non-deferred states towards DAG-run active tasks
# (max_active_tasks / worker slot accounting).
if state != TaskInstanceState.DEFERRED:
self.dag_run_active_tasks_map[dag_id, run_id] += count


def _is_parent_process() -> bool:
Expand Down
12 changes: 10 additions & 2 deletions airflow-core/src/airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1916,15 +1916,23 @@ def xcom_pull(

@provide_session
def get_num_running_task_instances(self, session: Session, same_dagrun: bool = False) -> int:
"""Return Number of running TIs from the DB."""
"""
Return number of active TIs for this task from the DB.

Counts task instances in running, queued, or deferred state.
Deferred TIs are included because they are still logically in-flight
and must count against max_active_tis_per_dag / max_active_tis_per_dagrun.
"""
from airflow.ti_deps.dependencies_states import TASK_CONCURRENCY_EXECUTION_STATES

# .count() is inefficient
num_running_task_instances_query = (
select(func.count())
.select_from(TaskInstance)
.where(
TaskInstance.dag_id == self.dag_id,
TaskInstance.task_id == self.task_id,
TaskInstance.state == TaskInstanceState.RUNNING,
TaskInstance.state.in_(TASK_CONCURRENCY_EXECUTION_STATES),
)
)
if same_dagrun:
Expand Down
9 changes: 9 additions & 0 deletions airflow-core/src/airflow/ti_deps/dependencies_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@
TaskInstanceState.QUEUED,
}

# Task-level concurrency accounting (max_active_tis_per_dag /
# max_active_tis_per_dagrun) uses this set: everything in EXECUTION_STATES
# plus DEFERRED, so that a deferred task instance keeps holding its
# task-level slot and blocks further instances of the same task.
# It is deliberately a separate constant instead of a wider EXECUTION_STATES,
# which leaves every existing user of EXECUTION_STATES untouched.
TASK_CONCURRENCY_EXECUTION_STATES = EXECUTION_STATES.union({TaskInstanceState.DEFERRED})

# In order to be able to get queued a task must have one of these states
SCHEDULEABLE_STATES = {
None,
Expand Down
270 changes: 270 additions & 0 deletions airflow-core/tests/unit/jobs/test_scheduler_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -2124,6 +2124,276 @@ def test_find_executable_task_instances_max_active_tis_per_dag(self, dag_maker):
assert len(res) == 1
session.rollback()

def test_find_executable_task_instances_max_active_tis_per_dag_deferred_blocks(self, dag_maker, session):
    """
    Check that a DEFERRED TI occupies a ``max_active_tis_per_dag`` slot.

    The task is limited to one active TI per DAG. The TI of the first run is
    DEFERRED, so the SCHEDULED TI of the second run must not be selected
    for queueing.
    Regression test for https://github.com/apache/airflow/issues/61700
    """
    with dag_maker(
        dag_id="SchedulerJobTest.test_max_active_tis_per_dag_deferred_blocks",
        max_active_tasks=16,
        session=session,
    ):
        limited_task = EmptyOperator(task_id="deferrable_task", max_active_tis_per_dag=1)

    self.job_runner = SchedulerJobRunner(job=Job())

    first_run = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED, run_id="run_1", session=session)
    second_run = dag_maker.create_dagrun_after(
        first_run, run_type=DagRunType.SCHEDULED, run_id="run_2", session=session
    )

    # run_1's TI is put in DEFERRED first, then run_2's TI in SCHEDULED.
    for dag_run, wanted_state in (
        (first_run, TaskInstanceState.DEFERRED),
        (second_run, State.SCHEDULED),
    ):
        ti = dag_run.get_task_instance(limited_task.task_id, session)
        ti.state = wanted_state
        session.merge(ti)
    session.flush()

    queued = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session)
    # The deferred TI already uses the single allowed slot, so nothing is queued.
    assert len(queued) == 0
    session.rollback()

def test_find_executable_task_instances_max_active_tis_per_dag_deferred_plus_running(
    self, dag_maker, session
):
    """
    Check that RUNNING and DEFERRED TIs are summed for ``max_active_tis_per_dag``.

    The limit is 2. One TI is RUNNING and one is DEFERRED, so the SCHEDULED
    TI of a third run must stay unqueued.
    """
    with dag_maker(
        dag_id="SchedulerJobTest.test_max_active_tis_per_dag_deferred_plus_running",
        max_active_tasks=16,
        session=session,
    ):
        limited_task = EmptyOperator(task_id="task", max_active_tis_per_dag=2)

    self.job_runner = SchedulerJobRunner(job=Job())

    run_1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED, run_id="run_1", session=session)
    run_2 = dag_maker.create_dagrun_after(
        run_1, run_type=DagRunType.SCHEDULED, run_id="run_2", session=session
    )
    run_3 = dag_maker.create_dagrun_after(
        run_2, run_type=DagRunType.SCHEDULED, run_id="run_3", session=session
    )

    # One TI per run, assigned in run order: RUNNING, DEFERRED, SCHEDULED.
    state_per_run = (
        (run_1, TaskInstanceState.RUNNING),
        (run_2, TaskInstanceState.DEFERRED),
        (run_3, State.SCHEDULED),
    )
    for dag_run, wanted_state in state_per_run:
        ti = dag_run.get_task_instance(limited_task.task_id, session)
        ti.state = wanted_state
        session.merge(ti)
    session.flush()

    queued = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session)
    # running (1) + deferred (1) already equals the limit of 2.
    assert len(queued) == 0
    session.rollback()

def test_find_executable_task_instances_max_active_tis_per_dag_deferred_with_room(
    self, dag_maker, session
):
    """
    Check that spare ``max_active_tis_per_dag`` capacity is still usable.

    The limit is 2 and a single TI is DEFERRED. Two further TIs are
    SCHEDULED, and exactly one of them is expected to be queued.
    """
    with dag_maker(
        dag_id="SchedulerJobTest.test_max_active_tis_per_dag_deferred_with_room",
        max_active_tasks=16,
        session=session,
    ):
        limited_task = EmptyOperator(task_id="task", max_active_tis_per_dag=2)

    self.job_runner = SchedulerJobRunner(job=Job())

    run_1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED, run_id="run_1", session=session)
    run_2 = dag_maker.create_dagrun_after(
        run_1, run_type=DagRunType.SCHEDULED, run_id="run_2", session=session
    )
    run_3 = dag_maker.create_dagrun_after(
        run_2, run_type=DagRunType.SCHEDULED, run_id="run_3", session=session
    )

    # One TI per run, assigned in run order: DEFERRED, SCHEDULED, SCHEDULED.
    state_per_run = (
        (run_1, TaskInstanceState.DEFERRED),
        (run_2, State.SCHEDULED),
        (run_3, State.SCHEDULED),
    )
    for dag_run, wanted_state in state_per_run:
        ti = dag_run.get_task_instance(limited_task.task_id, session)
        ti.state = wanted_state
        session.merge(ti)
    session.flush()

    queued = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session)
    # The deferred TI takes one of the two slots, leaving room for one more.
    assert len(queued) == 1
    session.rollback()

def test_find_executable_task_instances_deferred_does_not_block_different_task(
    self, dag_maker, session
):
    """
    Check that a DEFERRED TI only blocks TIs of its own task.

    ``max_active_tis_per_dag`` is set on task_a alone. A deferred task_a TI
    must keep another task_a TI out of the queue, while a SCHEDULED task_b TI
    is still queued.
    """
    with dag_maker(
        dag_id="SchedulerJobTest.test_deferred_does_not_block_different_task",
        max_active_tasks=16,
        session=session,
    ):
        limited = EmptyOperator(task_id="task_a", max_active_tis_per_dag=1)
        unlimited = EmptyOperator(task_id="task_b")

    self.job_runner = SchedulerJobRunner(job=Job())

    run_1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED, run_id="run_1", session=session)
    run_2 = dag_maker.create_dagrun_after(
        run_1, run_type=DagRunType.SCHEDULED, run_id="run_2", session=session
    )

    setup = (
        # task_a in run_1 is deferred
        (run_1, limited, TaskInstanceState.DEFERRED),
        # task_a in run_2 is scheduled (expected to be blocked)
        (run_2, limited, State.SCHEDULED),
        # task_b in run_1 is scheduled (expected to go through)
        (run_1, unlimited, State.SCHEDULED),
    )
    for dag_run, operator, wanted_state in setup:
        ti = dag_run.get_task_instance(operator.task_id, session)
        ti.state = wanted_state
        session.merge(ti)
    session.flush()

    queued = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session)
    queued_names = {ti.task_id for ti in queued}
    assert "task_b" in queued_names
    assert "task_a" not in queued_names
    session.rollback()

def test_find_executable_task_instances_deferred_to_success_unblocks(self, dag_maker, session):
    """
    Check that the slot is released once the deferred TI reaches SUCCESS.

    First pass: the blocker is DEFERRED and the waiting TI is not queued.
    Second pass: the blocker is SUCCESS and the waiting TI is queued.
    """
    with dag_maker(
        dag_id="SchedulerJobTest.test_deferred_to_success_unblocks",
        max_active_tasks=16,
        session=session,
    ):
        limited_task = EmptyOperator(task_id="task", max_active_tis_per_dag=1)

    self.job_runner = SchedulerJobRunner(job=Job())

    run_1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED, run_id="run_1", session=session)
    run_2 = dag_maker.create_dagrun_after(
        run_1, run_type=DagRunType.SCHEDULED, run_id="run_2", session=session
    )

    blocker = run_1.get_task_instance(limited_task.task_id, session)
    waiting = run_2.get_task_instance(limited_task.task_id, session)

    # Phase 1: the only slot is held by the deferred TI.
    blocker.state = TaskInstanceState.DEFERRED
    waiting.state = State.SCHEDULED
    for ti in (blocker, waiting):
        session.merge(ti)
    session.flush()

    first_pass = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session)
    assert len(first_pass) == 0

    # Phase 2: the blocker finishes, which frees the slot for the waiting TI.
    blocker.state = TaskInstanceState.SUCCESS
    session.merge(blocker)
    session.flush()

    second_pass = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session)
    assert len(second_pass) == 1
    assert second_pass[0].key == waiting.key
    session.rollback()

def test_find_executable_task_instances_max_active_tis_per_dagrun_deferred(self, dag_maker, session):
    """
    Check that a DEFERRED TI occupies a ``max_active_tis_per_dagrun`` slot.

    task_a is mapped over two inputs with a per-dagrun limit of 1. Map index 0
    is DEFERRED, so map index 1 (SCHEDULED, same dagrun) must not be queued.
    """
    with dag_maker(
        dag_id="SchedulerJobTest.test_max_active_tis_per_dagrun_deferred",
        max_active_tasks=16,
        session=session,
    ):
        mapped_task = EmptyOperator.partial(
            task_id="task_a", max_active_tis_per_dagrun=1
        ).expand_kwargs([{"inputs": 1}, {"inputs": 2}])
        EmptyOperator(task_id="task_b")

    self.job_runner = SchedulerJobRunner(job=Job())

    dag_run = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED, session=session)

    deferred_ti = dag_run.get_task_instance(mapped_task.task_id, session, map_index=0)
    scheduled_ti = dag_run.get_task_instance(mapped_task.task_id, session, map_index=1)

    deferred_ti.state = TaskInstanceState.DEFERRED
    scheduled_ti.state = State.SCHEDULED
    for ti in (deferred_ti, scheduled_ti):
        session.merge(ti)
    session.flush()

    queued = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session)
    queued_keys = [(ti.task_id, ti.map_index) for ti in queued]
    # Only the second mapped task_a TI is checked; task_b's state is never set by this test.
    assert ("task_a", 1) not in queued_keys
    session.rollback()

def test_find_executable_task_instances_deferred_does_not_affect_max_active_tasks(
    self, dag_maker, session
):
    """
    Check that DEFERRED TIs are left out of the ``max_active_tasks`` count.

    The DAG allows 2 active tasks. One TI is DEFERRED and two are SCHEDULED;
    both scheduled TIs are expected to be queued.
    """
    with dag_maker(
        dag_id="SchedulerJobTest.test_deferred_does_not_affect_max_active_tasks",
        max_active_tasks=2,
        session=session,
    ):
        for task_id in ("task_1", "task_2", "task_3"):
            EmptyOperator(task_id=task_id)

    self.job_runner = SchedulerJobRunner(job=Job())

    dag_run = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED, session=session)
    # Sort by task_id so the unpacking order is task_1, task_2, task_3.
    deferred_ti, scheduled_one, scheduled_two = sorted(
        dag_run.get_task_instances(session=session), key=lambda ti: ti.task_id
    )

    deferred_ti.state = TaskInstanceState.DEFERRED
    scheduled_one.state = State.SCHEDULED
    scheduled_two.state = State.SCHEDULED
    for ti in (deferred_ti, scheduled_one, scheduled_two):
        session.merge(ti)
    session.flush()

    queued = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session)
    # With the deferred TI not counted, max_active_tasks=2 admits both scheduled TIs.
    assert len(queued) == 2
    session.rollback()

def test_change_state_for_executable_task_instances_no_tis_with_state(self, dag_maker):
dag_id = "SchedulerJobTest.test_change_state_for__no_tis_with_state"
task_id_1 = "dummy"
Expand Down
Loading
Loading