def resolve_placeholder_map_index(
    *,
    task: Operator,
    relative: Operator,
    map_index: int,
    run_id: str,
    session: Session,
) -> int | None:
    """
    Decide whether an upstream lookup at ``map_index == -1`` should use ``0`` instead.

    A ``map_index`` of -1 denotes the pre-expansion placeholder task instance.
    Once the upstream task has been expanded, that placeholder is replaced by
    the instance with ``map_index == 0``, so a downstream task that is still a
    placeholder has to look at index 0 rather than -1.

    :param task: The task whose upstream dependency is being evaluated.
    :param relative: The upstream task being looked up.
    :param map_index: The map index computed for the upstream lookup. Anything
        other than -1 is left alone.
    :param run_id: The DAG run whose task instances are inspected.
    :param session: Database session used to query task instances.
    :return: ``0`` when the override applies, otherwise ``None`` (meaning the
        caller should keep using *map_index* unchanged).
    """
    # Only the placeholder index is ever rewritten; skip the query otherwise.
    if map_index != -1:
        return None

    query = select(TaskInstance.task_id, TaskInstance.map_index).where(
        TaskInstance.dag_id == relative.dag_id,
        TaskInstance.run_id == run_id,
        TaskInstance.task_id.in_([task.task_id, relative.task_id]),
        TaskInstance.map_index.in_([-1, 0]),
    )
    # (task_id, map_index) pairs that currently exist for the two tasks in this run.
    existing = {(found_task_id, found_index) for found_task_id, found_index in session.execute(query).all()}

    # The rewrite applies only when all three hold:
    #   1) the current task is still represented by its placeholder (-1),
    #   2) the upstream placeholder (-1) is gone,
    #   3) the upstream post-expansion instance (0) is present.
    task_is_placeholder = (task.task_id, -1) in existing
    upstream_placeholder_gone = (relative.task_id, -1) not in existing
    upstream_first_expanded = (relative.task_id, 0) in existing

    if task_is_placeholder and upstream_placeholder_gone and upstream_first_expanded:
        return 0
    return None
def test_downstream_placeholder_handles_upstream_post_expansion(dag_maker, session):
    """
    Check upstream map-index resolution for a plain task inside a mapped task group.

    ``mapped_task`` is expanded first, then ``get_relevant_upstream_map_indexes``
    is called on the ``downstream`` task instance twice: once on the
    ``map_index=-1`` placeholder instance, and once on the ``map_index=0``
    instance after ``downstream`` itself has been expanded. Both calls are
    expected to return 0.
    """
    with dag_maker(session=session) as dag:

        @task
        def get_mapping_source():
            return ["one", "two", "three"]

        @task
        def mapped_task(x):
            return f"{x}"

        @task_group(prefix_group_id=False)
        def the_task_group(x):
            first = MockOperator(task_id="start")
            middle = mapped_task(x)
            # Plain operator with no mapping source of its own.
            last = MockOperator(task_id="downstream")

            first >> middle >> last

        source = get_mapping_source()
        source >> the_task_group.expand(x=source)

    # Create the DAG run and execute the task that feeds the mapping.
    dag_run = dag_maker.create_dagrun()
    dag_maker.run_ti("get_mapping_source", map_index=-1, dag_run=dag_run, session=session)

    upstream_task = dag.get_task("mapped_task")
    downstream_task = dag.get_task("downstream")

    def expanded_count(operator):
        """Force expansion of *operator* and return how many instances it produced."""
        _, highest_index = TaskMap.expand_mapped_task(operator, dag_run.run_id, session=session)
        return highest_index + 1

    def relevant_index(*, map_index, ti_count):
        """Resolve the upstream map index as seen from the given ``downstream`` instance."""
        ti = dag_run.get_task_instance(task_id="downstream", map_index=map_index, session=session)
        ti.refresh_from_task(downstream_task)
        return ti.get_relevant_upstream_map_indexes(
            upstream=upstream_task,
            ti_count=ti_count,
            session=session,
        )

    # Upstream expanded, downstream still the placeholder (-1).
    upstream_count = expanded_count(upstream_task)
    assert relevant_index(map_index=-1, ti_count=upstream_count) == 0

    # Downstream expanded as well: the first expanded instance (map_index=0)
    # must resolve to the same upstream index.
    downstream_count = expanded_count(downstream_task)
    assert relevant_index(map_index=0, ti_count=downstream_count) == 0