diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
index 67006a7d8ee27..ce6e0ae67900f 100644
--- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
+++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -88,7 +88,7 @@
 from airflow.observability.metrics import stats_utils
 from airflow.serialization.definitions.assets import SerializedAssetUniqueKey
 from airflow.serialization.definitions.notset import NOTSET
-from airflow.ti_deps.dependencies_states import EXECUTION_STATES
+from airflow.ti_deps.dependencies_states import EXECUTION_STATES, TASK_CONCURRENCY_EXECUTION_STATES
 from airflow.timetables.simple import AssetTriggeredTimetable
 from airflow.utils.event_scheduler import EventScheduler
 from airflow.utils.log.logging_mixin import LoggingMixin
@@ -205,15 +205,22 @@ def load(self, session: Session) -> None:
         self.dag_run_active_tasks_map.clear()
         self.task_concurrency_map.clear()
         self.task_dagrun_concurrency_map.clear()
+        # Use one grouped query on TASK_CONCURRENCY_EXECUTION_STATES to exclude DEFERRED
+        # from dag_run_active_tasks_map while still counting it for task-level limits.
         query = session.execute(
-            select(TI.dag_id, TI.task_id, TI.run_id, func.count("*"))
-            .where(TI.state.in_(EXECUTION_STATES))
-            .group_by(TI.task_id, TI.run_id, TI.dag_id)
+            select(TI.dag_id, TI.task_id, TI.run_id, TI.state, func.count("*"))
+            .where(TI.state.in_(TASK_CONCURRENCY_EXECUTION_STATES))
+            .group_by(TI.dag_id, TI.task_id, TI.run_id, TI.state)
         )
-        for dag_id, task_id, run_id, c in query:
-            self.dag_run_active_tasks_map[dag_id, run_id] += c
-            self.task_concurrency_map[(dag_id, task_id)] += c
-            self.task_dagrun_concurrency_map[(dag_id, run_id, task_id)] += c
+        for dag_id, task_id, run_id, state, count in query:
+            # Always count towards task-level concurrency (max_active_tis_per_dag /
+            # max_active_tis_per_dagrun), including DEFERRED.
+            self.task_concurrency_map[(dag_id, task_id)] += count
+            self.task_dagrun_concurrency_map[(dag_id, run_id, task_id)] += count
+            # Only count non-deferred states towards DAG-run active tasks
+            # (max_active_tasks / worker slot accounting).
+            if state != TaskInstanceState.DEFERRED:
+                self.dag_run_active_tasks_map[dag_id, run_id] += count
 
 
 def _is_parent_process() -> bool:
diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py
index 964c345a3bf6c..fd7d1c8ac5eaf 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -1916,7 +1916,15 @@ def xcom_pull(
 
     @provide_session
     def get_num_running_task_instances(self, session: Session, same_dagrun: bool = False) -> int:
-        """Return Number of running TIs from the DB."""
+        """
+        Return number of active TIs for this task from the DB.
+
+        Counts task instances in running, queued, or deferred state.
+        Deferred TIs are included because they are still logically in-flight
+        and must count against max_active_tis_per_dag / max_active_tis_per_dagrun.
+        """
+        from airflow.ti_deps.dependencies_states import TASK_CONCURRENCY_EXECUTION_STATES
+
         # .count() is inefficient
         num_running_task_instances_query = (
             select(func.count())
@@ -1924,7 +1932,7 @@ def get_num_running_task_instances(self, session: Session, same_dagrun: bool = F
             .where(
                 TaskInstance.dag_id == self.dag_id,
                 TaskInstance.task_id == self.task_id,
-                TaskInstance.state == TaskInstanceState.RUNNING,
+                TaskInstance.state.in_(TASK_CONCURRENCY_EXECUTION_STATES),
             )
         )
         if same_dagrun:
diff --git a/airflow-core/src/airflow/ti_deps/dependencies_states.py b/airflow-core/src/airflow/ti_deps/dependencies_states.py
index ebf581ab48e18..a0840c0202939 100644
--- a/airflow-core/src/airflow/ti_deps/dependencies_states.py
+++ b/airflow-core/src/airflow/ti_deps/dependencies_states.py
@@ -23,6 +23,15 @@
     TaskInstanceState.QUEUED,
 }
 
+# States counted for task-level concurrency limits (max_active_tis_per_dag /
+# max_active_tis_per_dagrun). Includes DEFERRED because a deferred task
+# instance is still logically in-flight and must block additional instances
+# from being scheduled. This is intentionally separate from EXECUTION_STATES
+# so that DAG-level max_active_tasks and pool slot calculations are unaffected.
+TASK_CONCURRENCY_EXECUTION_STATES = EXECUTION_STATES | {
+    TaskInstanceState.DEFERRED,
+}
+
 # In order to be able to get queued a task must have one of these states
 SCHEDULEABLE_STATES = {
     None,
diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py
index 5ccbb7e40568d..823636673cfd2 100644
--- a/airflow-core/tests/unit/jobs/test_scheduler_job.py
+++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -2124,6 +2124,276 @@ def test_find_executable_task_instances_max_active_tis_per_dag(self, dag_maker):
         assert len(res) == 1
         session.rollback()
 
+    def test_find_executable_task_instances_max_active_tis_per_dag_deferred_blocks(self, dag_maker, session):
+        """
+        A DEFERRED TI should count against max_active_tis_per_dag.
+
+        When one TI is deferred, no additional TI of the same task should be
+        scheduled if the limit is already reached.
+        Regression test for https://github.com/apache/airflow/issues/61700
+        """
+        dag_id = "SchedulerJobTest.test_max_active_tis_per_dag_deferred_blocks"
+        with dag_maker(dag_id=dag_id, max_active_tasks=16, session=session):
+            task1 = EmptyOperator(task_id="deferrable_task", max_active_tis_per_dag=1)
+
+        scheduler_job = Job()
+        self.job_runner = SchedulerJobRunner(job=scheduler_job)
+
+        dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED, run_id="run_1", session=session)
+        dr2 = dag_maker.create_dagrun_after(
+            dr1, run_type=DagRunType.SCHEDULED, run_id="run_2", session=session
+        )
+
+        # DR1's TI is deferred (waiting on a trigger)
+        ti1 = dr1.get_task_instance(task1.task_id, session)
+        ti1.state = TaskInstanceState.DEFERRED
+        session.merge(ti1)
+
+        # DR2's TI is scheduled and wants to run
+        ti2 = dr2.get_task_instance(task1.task_id, session)
+        ti2.state = State.SCHEDULED
+        session.merge(ti2)
+        session.flush()
+
+        res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session)
+        # ti2 should be blocked because ti1 is deferred and counts as active
+        assert len(res) == 0
+        session.rollback()
+
+    def test_find_executable_task_instances_max_active_tis_per_dag_deferred_plus_running(
+        self, dag_maker, session
+    ):
+        """
+        Deferred + running TIs together fill the max_active_tis_per_dag limit.
+
+        With max_active_tis_per_dag=2 and one RUNNING + one DEFERRED, a third
+        SCHEDULED TI should be blocked.
+        """
+        dag_id = "SchedulerJobTest.test_max_active_tis_per_dag_deferred_plus_running"
+        with dag_maker(dag_id=dag_id, max_active_tasks=16, session=session):
+            task1 = EmptyOperator(task_id="task", max_active_tis_per_dag=2)
+
+        scheduler_job = Job()
+        self.job_runner = SchedulerJobRunner(job=scheduler_job)
+
+        dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED, run_id="run_1", session=session)
+        dr2 = dag_maker.create_dagrun_after(
+            dr1, run_type=DagRunType.SCHEDULED, run_id="run_2", session=session
+        )
+        dr3 = dag_maker.create_dagrun_after(
+            dr2, run_type=DagRunType.SCHEDULED, run_id="run_3", session=session
+        )
+
+        ti1 = dr1.get_task_instance(task1.task_id, session)
+        ti1.state = TaskInstanceState.RUNNING
+        session.merge(ti1)
+
+        ti2 = dr2.get_task_instance(task1.task_id, session)
+        ti2.state = TaskInstanceState.DEFERRED
+        session.merge(ti2)
+
+        ti3 = dr3.get_task_instance(task1.task_id, session)
+        ti3.state = State.SCHEDULED
+        session.merge(ti3)
+        session.flush()
+
+        # 1 running + 1 deferred = 2, which equals the limit
+        res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session)
+        assert len(res) == 0
+        session.rollback()
+
+    def test_find_executable_task_instances_max_active_tis_per_dag_deferred_with_room(
+        self, dag_maker, session
+    ):
+        """
+        With max_active_tis_per_dag=2 and only 1 deferred, one more TI
+        should be allowed to schedule.
+        """
+        dag_id = "SchedulerJobTest.test_max_active_tis_per_dag_deferred_with_room"
+        with dag_maker(dag_id=dag_id, max_active_tasks=16, session=session):
+            task1 = EmptyOperator(task_id="task", max_active_tis_per_dag=2)
+
+        scheduler_job = Job()
+        self.job_runner = SchedulerJobRunner(job=scheduler_job)
+
+        dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED, run_id="run_1", session=session)
+        dr2 = dag_maker.create_dagrun_after(
+            dr1, run_type=DagRunType.SCHEDULED, run_id="run_2", session=session
+        )
+        dr3 = dag_maker.create_dagrun_after(
+            dr2, run_type=DagRunType.SCHEDULED, run_id="run_3", session=session
+        )
+
+        ti1 = dr1.get_task_instance(task1.task_id, session)
+        ti1.state = TaskInstanceState.DEFERRED
+        session.merge(ti1)
+
+        ti2 = dr2.get_task_instance(task1.task_id, session)
+        ti2.state = State.SCHEDULED
+        session.merge(ti2)
+
+        ti3 = dr3.get_task_instance(task1.task_id, session)
+        ti3.state = State.SCHEDULED
+        session.merge(ti3)
+        session.flush()
+
+        # 1 deferred -> room for 1 more (limit is 2)
+        res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session)
+        assert len(res) == 1
+        session.rollback()
+
+    def test_find_executable_task_instances_deferred_does_not_block_different_task(
+        self, dag_maker, session
+    ):
+        """
+        A DEFERRED TI of task A should NOT block task B from scheduling.
+
+        max_active_tis_per_dag is per-task, not per-DAG.
+        """
+        dag_id = "SchedulerJobTest.test_deferred_does_not_block_different_task"
+        with dag_maker(dag_id=dag_id, max_active_tasks=16, session=session):
+            task_a = EmptyOperator(task_id="task_a", max_active_tis_per_dag=1)
+            task_b = EmptyOperator(task_id="task_b")
+
+        scheduler_job = Job()
+        self.job_runner = SchedulerJobRunner(job=scheduler_job)
+
+        dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED, run_id="run_1", session=session)
+        dr2 = dag_maker.create_dagrun_after(
+            dr1, run_type=DagRunType.SCHEDULED, run_id="run_2", session=session
+        )
+
+        # task_a in DR1 is deferred
+        ti_a1 = dr1.get_task_instance(task_a.task_id, session)
+        ti_a1.state = TaskInstanceState.DEFERRED
+        session.merge(ti_a1)
+
+        # task_a in DR2 is scheduled (should be blocked by deferred ti_a1)
+        ti_a2 = dr2.get_task_instance(task_a.task_id, session)
+        ti_a2.state = State.SCHEDULED
+        session.merge(ti_a2)
+
+        # task_b in DR1 is scheduled (should NOT be blocked)
+        ti_b1 = dr1.get_task_instance(task_b.task_id, session)
+        ti_b1.state = State.SCHEDULED
+        session.merge(ti_b1)
+        session.flush()
+
+        res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session)
+        queued_task_ids = [ti.task_id for ti in res]
+        # task_b should be queued, task_a should be blocked
+        assert "task_b" in queued_task_ids
+        assert "task_a" not in queued_task_ids
+        session.rollback()
+
+    def test_find_executable_task_instances_deferred_to_success_unblocks(self, dag_maker, session):
+        """
+        When a deferred TI completes (SUCCESS), the next TI should be unblocked.
+        """
+        dag_id = "SchedulerJobTest.test_deferred_to_success_unblocks"
+        with dag_maker(dag_id=dag_id, max_active_tasks=16, session=session):
+            task1 = EmptyOperator(task_id="task", max_active_tis_per_dag=1)
+
+        scheduler_job = Job()
+        self.job_runner = SchedulerJobRunner(job=scheduler_job)
+
+        dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED, run_id="run_1", session=session)
+        dr2 = dag_maker.create_dagrun_after(
+            dr1, run_type=DagRunType.SCHEDULED, run_id="run_2", session=session
+        )
+
+        ti1 = dr1.get_task_instance(task1.task_id, session)
+        ti2 = dr2.get_task_instance(task1.task_id, session)
+
+        # Step 1: ti1 is deferred, ti2 scheduled -> ti2 blocked
+        ti1.state = TaskInstanceState.DEFERRED
+        ti2.state = State.SCHEDULED
+        session.merge(ti1)
+        session.merge(ti2)
+        session.flush()
+
+        res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session)
+        assert len(res) == 0
+
+        # Step 2: ti1 completes -> ti2 should be unblocked
+        ti1.state = TaskInstanceState.SUCCESS
+        session.merge(ti1)
+        session.flush()
+
+        res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session)
+        assert len(res) == 1
+        assert res[0].key == ti2.key
+        session.rollback()
+
+    def test_find_executable_task_instances_max_active_tis_per_dagrun_deferred(self, dag_maker, session):
+        """
+        DEFERRED TIs should also count against max_active_tis_per_dagrun.
+
+        With max_active_tis_per_dagrun=1 and 2 mapped instances in the same
+        dagrun, if one is deferred, the other should be blocked.
+        """
+        dag_id = "SchedulerJobTest.test_max_active_tis_per_dagrun_deferred"
+        with dag_maker(dag_id=dag_id, max_active_tasks=16, session=session):
+            task_a = EmptyOperator.partial(
+                task_id="task_a", max_active_tis_per_dagrun=1
+            ).expand_kwargs([{"inputs": 1}, {"inputs": 2}])
+            EmptyOperator(task_id="task_b")
+
+        scheduler_job = Job()
+        self.job_runner = SchedulerJobRunner(job=scheduler_job)
+
+        dr = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED, session=session)
+
+        ti_a0 = dr.get_task_instance(task_a.task_id, session, map_index=0)
+        ti_a1 = dr.get_task_instance(task_a.task_id, session, map_index=1)
+
+        ti_a0.state = TaskInstanceState.DEFERRED
+        ti_a1.state = State.SCHEDULED
+        session.merge(ti_a0)
+        session.merge(ti_a1)
+        session.flush()
+
+        res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session)
+        queued_task_ids = [(ti.task_id, ti.map_index) for ti in res]
+        # ti_a1 should be blocked, task_b may be queued
+        assert ("task_a", 1) not in queued_task_ids
+        session.rollback()
+
+    def test_find_executable_task_instances_deferred_does_not_affect_max_active_tasks(
+        self, dag_maker, session
+    ):
+        """
+        Deferred TIs should NOT count toward max_active_tasks.
+
+        max_active_tasks is about worker-level parallelism, while deferred tasks
+        don't consume worker slots. With max_active_tasks=2 and 1 deferred TI,
+        2 more SCHEDULED TIs should be allowed.
+        """
+        dag_id = "SchedulerJobTest.test_deferred_does_not_affect_max_active_tasks"
+        with dag_maker(dag_id=dag_id, max_active_tasks=2, session=session):
+            EmptyOperator(task_id="task_1")
+            EmptyOperator(task_id="task_2")
+            EmptyOperator(task_id="task_3")
+
+        scheduler_job = Job()
+        self.job_runner = SchedulerJobRunner(job=scheduler_job)
+
+        dr = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED, session=session)
+        t1, t2, t3 = sorted(dr.get_task_instances(session=session), key=lambda t: t.task_id)
+
+        t1.state = TaskInstanceState.DEFERRED
+        t2.state = State.SCHEDULED
+        t3.state = State.SCHEDULED
+        session.merge(t1)
+        session.merge(t2)
+        session.merge(t3)
+        session.flush()
+
+        res = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session)
+        # Deferred doesn't count toward max_active_tasks=2, so both scheduled can run
+        assert len(res) == 2
+        session.rollback()
+
     def test_change_state_for_executable_task_instances_no_tis_with_state(self, dag_maker):
         dag_id = "SchedulerJobTest.test_change_state_for__no_tis_with_state"
         task_id_1 = "dummy"
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py
index 9eba07daaa130..cc38d6ab74907 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -1580,8 +1580,11 @@ def test_get_num_running_task_instances(self, dag_maker, create_task_instance):
         assert ti3 in session
         session.commit()
 
-        assert ti1.get_num_running_task_instances(session=session) == 1
-        assert ti2.get_num_running_task_instances(session=session) == 1
+        # get_num_running_task_instances now counts RUNNING + QUEUED + DEFERRED.
+        # ti1 (RUNNING) and ti2 (QUEUED) share the same dag_id/task_id, so both
+        # see a count of 2. ti3 is in a different dag, so it sees 1.
+        assert ti1.get_num_running_task_instances(session=session) == 2
+        assert ti2.get_num_running_task_instances(session=session) == 2
         assert ti3.get_num_running_task_instances(session=session) == 1
 
     def test_get_num_running_task_instances_per_dagrun(self, create_task_instance, dag_maker):
@@ -1623,16 +1626,51 @@ def test_get_num_running_task_instances_per_dagrun(self, create_task_instance, d
 
         session.commit()
 
-        assert tis1[("task_1", 0)].get_num_running_task_instances(session=session, same_dagrun=True) == 1
-        assert tis1[("task_1", 1)].get_num_running_task_instances(session=session, same_dagrun=True) == 1
+        # With QUEUED now counted, task_1 in each dagrun has 2 (1 RUNNING + 1 QUEUED)
+        assert tis1[("task_1", 0)].get_num_running_task_instances(session=session, same_dagrun=True) == 2
+        assert tis1[("task_1", 1)].get_num_running_task_instances(session=session, same_dagrun=True) == 2
         assert tis1[("task_2", 0)].get_num_running_task_instances(session=session) == 2
         assert tis1[("task_3", 0)].get_num_running_task_instances(session=session, same_dagrun=True) == 1
 
-        assert tis2[("task_1", 0)].get_num_running_task_instances(session=session, same_dagrun=True) == 1
-        assert tis2[("task_1", 1)].get_num_running_task_instances(session=session, same_dagrun=True) == 1
+        assert tis2[("task_1", 0)].get_num_running_task_instances(session=session, same_dagrun=True) == 2
+        assert tis2[("task_1", 1)].get_num_running_task_instances(session=session, same_dagrun=True) == 2
         assert tis2[("task_2", 0)].get_num_running_task_instances(session=session) == 2
         assert tis2[("task_3", 0)].get_num_running_task_instances(session=session, same_dagrun=True) == 1
 
+    def test_get_num_running_task_instances_includes_deferred(self, dag_maker, create_task_instance):
+        """
+        get_num_running_task_instances should count DEFERRED TIs.
+
+        Regression test for https://github.com/apache/airflow/issues/61700
+        """
+        session = settings.Session()
+
+        ti1 = create_task_instance(
+            dag_id="test_get_num_running_task_instances_deferred", task_id="task1", session=session
+        )
+
+        logical_date = DEFAULT_DATE + datetime.timedelta(days=1)
+        dr = dag_maker.create_dagrun(
+            logical_date=logical_date,
+            run_type=DagRunType.MANUAL,
+            state=None,
+            run_id="2",
+            session=session,
+            data_interval=(logical_date, logical_date),
+            run_after=logical_date,
+            triggered_by=DagRunTriggeredByType.TEST,
+        )
+        ti2 = dr.task_instances[0]
+        ti2.task = ti1.task
+
+        ti1.state = TaskInstanceState.RUNNING
+        ti2.state = TaskInstanceState.DEFERRED
+        session.commit()
+
+        # Both RUNNING and DEFERRED should be counted
+        assert ti1.get_num_running_task_instances(session=session) == 2
+        assert ti2.get_num_running_task_instances(session=session) == 2
+
     def test_log_url(self, create_task_instance):
         ti = create_task_instance(dag_id="my_dag", task_id="op", logical_date=timezone.datetime(2018, 1, 1))
 
diff --git a/airflow-core/tests/unit/ti_deps/deps/test_task_concurrency.py b/airflow-core/tests/unit/ti_deps/deps/test_task_concurrency.py
index 4b7150f2cf7cc..c33ac2e7ee7cd 100644
--- a/airflow-core/tests/unit/ti_deps/deps/test_task_concurrency.py
+++ b/airflow-core/tests/unit/ti_deps/deps/test_task_concurrency.py
@@ -47,6 +47,12 @@ def _get_task(self, **kwargs):
             ({"max_active_tis_per_dag": 1, "max_active_tis_per_dagrun": 2}, 1, False),
             ({"max_active_tis_per_dag": 2, "max_active_tis_per_dagrun": 1}, 1, False),
             ({"max_active_tis_per_dag": 1, "max_active_tis_per_dagrun": 1}, 1, False),
+            # Deferred-specific scenarios: the count returned by
+            # get_num_running_task_instances now includes DEFERRED TIs.
+            # 1 deferred TI fills a limit of 1 -> blocked
+            ({"max_active_tis_per_dag": 1}, 1, False),
+            # 1 deferred + 1 running = 2, limit 3 -> allowed
+            ({"max_active_tis_per_dag": 3}, 2, True),
         ],
     )
     def test_concurrency(self, kwargs, num_running_tis, is_task_concurrency_dep_met):