Skip to content

Commit 3b16811

Browse files
author
Sameer Mesiah
committed
Fix upstream map index resolution after placeholder expansion and add a unit test.
1 parent 84110f4 commit 3b16811

2 files changed

Lines changed: 145 additions & 0 deletions

File tree

airflow-core/src/airflow/models/taskinstance.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2187,6 +2187,14 @@ def tg2(inp):
21872187
# and "ti_count == ancestor_ti_count" does not work, since the further
21882188
# expansion may be of length 1.
21892189
if not _is_further_mapped_inside(relative, common_ancestor):
2190+
placeholder_index = resolve_placeholder_map_index(
2191+
task=task, relative=relative, map_index=ancestor_map_index, run_id=run_id, session=session
2192+
)
2193+
# Handle cases where an upstream mapped placeholder (map_index = -1) has already
2194+
# been expanded and replaced by its successor (map_index = 0) at evaluation time.
2195+
if placeholder_index is not None:
2196+
return placeholder_index
2197+
21902198
return ancestor_map_index
21912199

21922200
# Otherwise we need a partial aggregation for values from selected task
@@ -2261,6 +2269,54 @@ def _visit_relevant_relatives_for_mapped(mapped_tasks: Iterable[tuple[str, int]]
22612269
return visited
22622270

22632271

2272+
def resolve_placeholder_map_index(
    *,
    task: Operator,
    relative: Operator,
    map_index: int,
    run_id: str,
    session: Session,
) -> int | None:
    """
    Resolve the correct map_index for upstream dependency evaluation.

    This handles the transition from map_index = -1 (pre-expansion placeholder)
    to map_index = 0 (post-expansion placeholder successor).

    :param task: The task whose upstream dependency is being evaluated.
    :param relative: The upstream task whose task instance rows are inspected.
    :param map_index: The map index resolved so far for ``relative``; only a
        value of ``-1`` is ever considered for an override.
    :param run_id: ID of the DAG run whose task instances are inspected.
    :param session: Database session used to query the task instance rows.
    :return: ``0`` if ``task`` still has a ``-1`` row while ``relative`` has a
        ``0`` row and no ``-1`` row; ``None`` if no override should be applied.
    """
    # Only the placeholder index can need rewriting; skip the query otherwise.
    if map_index != -1:
        return None

    rows = session.execute(
        select(TaskInstance.task_id, TaskInstance.map_index).where(
            TaskInstance.dag_id == relative.dag_id,
            TaskInstance.run_id == run_id,
            TaskInstance.task_id.in_([task.task_id, relative.task_id]),
            TaskInstance.map_index.in_([-1, 0]),
        )
    ).all()

    # A set of (task_id, map_index) pairs is all that the membership checks need.
    existing = {(task_id, mi) for task_id, mi in rows}

    # We only rewrite when:
    # 1) the current task is still using the placeholder (-1)
    # 2) the upstream placeholder (-1) no longer exists
    # 3) the post-expansion placeholder (0) does exist
    if (
        (task.task_id, -1) in existing
        and (relative.task_id, -1) not in existing
        and (relative.task_id, 0) in existing
    ):
        return 0

    return None
2318+
2319+
22642320
class TaskInstanceNote(Base):
22652321
"""For storage of arbitrary notes concerning the task instance."""
22662322

airflow-core/tests/unit/models/test_taskinstance.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3021,6 +3021,95 @@ def g(v):
30213021
assert result == expected
30223022

30233023

3024+
def test_downstream_placeholder_handles_upstream_post_expansion(dag_maker, session):
    """
    Check upstream map index resolution after the upstream placeholder is replaced.

    Once the upstream mapped task has been expanded, its placeholder task
    instance (map_index = -1) has been replaced by the first expanded task
    instance (map_index = 0). The relevant upstream map index is expected to
    be 0 both when evaluated from the downstream placeholder and when
    evaluated from the first expanded downstream task instance.
    """
    with dag_maker(session=session) as dag:

        @task
        def get_mapping_source():
            return ["one", "two", "three"]

        @task
        def mapped_task(x):
            return f"{x}"

        @task_group(prefix_group_id=False)
        def the_task_group(x):
            first = MockOperator(task_id="start")
            middle = mapped_task(x)
            # A plain operator inside the task group (no mapping source).
            last = MockOperator(task_id="downstream")

            first >> middle >> last

        source = get_mapping_source()
        expanded_group = the_task_group.expand(x=source)

        source >> expanded_group

    # Create the DAG run and execute the task that provides the mapping input.
    dag_run = dag_maker.create_dagrun()
    dag_maker.run_ti("get_mapping_source", map_index=-1, dag_run=dag_run, session=session)

    # Force expansion of the upstream mapped task.
    upstream_op = dag.get_task("mapped_task")
    _, highest_index = TaskMap.expand_mapped_task(upstream_op, dag_run.run_id, session=session)
    upstream_ti_count = highest_index + 1

    downstream_op = dag.get_task("downstream")

    def relevant_index_for(map_index, ti_count):
        """Resolve the upstream map index as seen from one downstream task instance."""
        ti = dag_run.get_task_instance(task_id="downstream", map_index=map_index, session=session)
        ti.refresh_from_task(downstream_op)
        return ti.get_relevant_upstream_map_indexes(
            upstream=upstream_op,
            ti_count=ti_count,
            session=session,
        )

    # Evaluate from the downstream placeholder TI.
    assert relevant_index_for(-1, upstream_ti_count) == 0

    # Repeat for an expanded downstream TI to ensure existing behavior is not broken.
    # Force expansion of the downstream mapped task.
    _, highest_index = TaskMap.expand_mapped_task(downstream_op, dag_run.run_id, session=session)
    downstream_ti_count = highest_index + 1

    # The first expanded downstream TI stands in for every case where map_index >= 0;
    # the result must be unchanged once the downstream task itself has expanded.
    assert relevant_index_for(0, downstream_ti_count) == 0
3111+
3112+
30243113
def test_find_relevant_relatives_with_non_mapped_task_as_tuple(dag_maker, session):
30253114
"""Test that specifying a non-mapped task as a tuple doesn't raise NotMapped exception."""
30263115
# t1 -> t2 (non-mapped) -> t3

0 commit comments

Comments
 (0)