Skip to content

Commit 8b40864

Browse files
committed
test: regression tests for #5282 (after_run_callback on BaseNode roots)
Three tests for Runner._run_node_async plugin lifecycle on Workflow roots: - Baseline: pre-run + event hooks fire (PASSED on 2.0.0a3) - Regression anchor: after_run_callback strict xfail (remove when runners.py:427 TODO is wired — flips green as the signal) - Workaround proof: Runner subclass wrapping run_async dispatches after_run_callback post-drain (PASSED on 2.0.0a3) Closes #5282
1 parent abcf14c commit 8b40864

File tree

1 file changed

+241
-0
lines changed

1 file changed

+241
-0
lines changed
Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Regression tests for google/adk-python#5282.
16+
17+
Runner._run_node_async (the dispatch path for Workflow / BaseNode roots)
18+
dispatches on_user_message_callback, before_run_callback, and
19+
on_event_callback, but does not dispatch run_after_run_callback
20+
(runners.py:427 TODO).
21+
22+
Three tests:
23+
(a) Baseline — pre-run and event hooks DO fire on a Workflow root.
24+
(b) Regression anchor — after_run_callback does NOT fire (strict xfail).
25+
Remove the xfail when the TODO at runners.py:427 is wired.
26+
(c) Workaround proof — Runner subclass wrapping run_async restores
27+
after_run_callback dispatch without touching ADK source.
28+
29+
Concurrency note for the WorkaroundRunner pattern: under concurrent
30+
run_async calls on a shared Runner, the _last_ic stash should live on a
31+
contextvars.ContextVar rather than self. The instance attribute is safe
32+
for single-invocation tests but will race under concurrent load.
33+
"""
34+
from __future__ import annotations
35+
36+
from dataclasses import dataclass
37+
from typing import AsyncGenerator
38+
from typing import Optional
39+
40+
from google.adk.agents.invocation_context import InvocationContext
41+
from google.adk.apps.app import App
42+
from google.adk.events.event import Event
43+
from google.adk.memory.in_memory_memory_service import InMemoryMemoryService
44+
from google.adk.plugins.base_plugin import BasePlugin
45+
from google.adk.runners import Runner
46+
from google.adk.sessions.in_memory_session_service import InMemorySessionService
47+
from google.adk.workflow import Workflow
48+
from google.genai import types
49+
import pytest
50+
51+
APP_NAME = "issue_5282_repro"
52+
USER_ID = "u1"
53+
54+
55+
# ---------------------------------------------------------------------------
56+
# Fixtures
57+
# ---------------------------------------------------------------------------
58+
59+
60+
@dataclass
61+
class CallbackCounts:
62+
on_user_message_callback: int = 0
63+
before_run_callback: int = 0
64+
on_event_callback: int = 0
65+
after_run_callback: int = 0
66+
67+
68+
class TracerPlugin(BasePlugin):
69+
"""Counts every Plugin lifecycle callback the Runner dispatches."""
70+
71+
__test__ = False
72+
73+
def __init__(self) -> None:
74+
super().__init__(name="tracer")
75+
self.counts = CallbackCounts()
76+
77+
async def on_user_message_callback(
78+
self,
79+
*,
80+
invocation_context: InvocationContext,
81+
user_message: types.Content,
82+
) -> Optional[types.Content]:
83+
self.counts.on_user_message_callback += 1
84+
return None
85+
86+
async def before_run_callback(
87+
self, *, invocation_context: InvocationContext
88+
) -> Optional[types.Content]:
89+
self.counts.before_run_callback += 1
90+
return None
91+
92+
async def on_event_callback(
93+
self, *, invocation_context: InvocationContext, event: Event
94+
) -> Optional[Event]:
95+
self.counts.on_event_callback += 1
96+
return None
97+
98+
async def after_run_callback(
99+
self, *, invocation_context: InvocationContext
100+
) -> None:
101+
self.counts.after_run_callback += 1
102+
return None
103+
104+
105+
async def _terminal_node(ctx) -> Event:
106+
"""Minimal terminal node yielding a content-bearing Event.
107+
108+
Content (not just state) ensures _consume_event_queue runs the
109+
on_event_callback path — the canonical case the plugin hook targets.
110+
"""
111+
return Event(
112+
content=types.Content(
113+
parts=[types.Part(text="done")],
114+
role="model",
115+
)
116+
)
117+
118+
119+
def _build_runner(
120+
plugin: TracerPlugin, *, runner_cls: type[Runner] = Runner
121+
) -> Runner:
122+
workflow = Workflow(
123+
name="Issue5282Repro", edges=[("START", _terminal_node)]
124+
)
125+
app = App(name=APP_NAME, root_agent=workflow, plugins=[plugin])
126+
return runner_cls(
127+
app_name=APP_NAME,
128+
app=app,
129+
session_service=InMemorySessionService(),
130+
memory_service=InMemoryMemoryService(),
131+
)
132+
133+
134+
async def _drive_one_invocation(runner: Runner) -> None:
135+
session = await runner.session_service.create_session(
136+
app_name=APP_NAME, user_id=USER_ID
137+
)
138+
async for _ in runner.run_async(
139+
user_id=USER_ID,
140+
session_id=session.id,
141+
new_message=types.Content(
142+
parts=[types.Part(text="hi")], role="user"
143+
),
144+
):
145+
pass
146+
147+
148+
# ---------------------------------------------------------------------------
149+
# Workaround: Runner subclass dispatching run_after_run_callback post-drain
150+
# ---------------------------------------------------------------------------
151+
152+
153+
class WorkaroundRunner(Runner):
154+
"""Interim workaround for #5282.
155+
156+
Wraps run_async to dispatch plugin_manager.run_after_run_callback once
157+
the inner generator drains. Captures the active InvocationContext via
158+
_new_invocation_context (called once at runners.py:446).
159+
160+
Drop this class when the runners.py:427 TODO is resolved — the stock
161+
Runner will dispatch after_run_callback natively, and
162+
test_workflow_root_after_run_callback_not_dispatched will flip green
163+
as the signal.
164+
"""
165+
166+
def __init__(self, *args, **kwargs) -> None:
167+
super().__init__(*args, **kwargs)
168+
self._last_ic: Optional[InvocationContext] = None
169+
170+
def _new_invocation_context(self, session, **kwargs) -> InvocationContext:
171+
ic = super()._new_invocation_context(session, **kwargs)
172+
self._last_ic = ic
173+
return ic
174+
175+
async def run_async(self, **kwargs) -> AsyncGenerator[Event, None]:
176+
async for event in super().run_async(**kwargs):
177+
yield event
178+
ic = self._last_ic
179+
if ic is not None:
180+
await ic.plugin_manager.run_after_run_callback(
181+
invocation_context=ic
182+
)
183+
184+
185+
# ---------------------------------------------------------------------------
186+
# Tests
187+
# ---------------------------------------------------------------------------
188+
189+
190+
@pytest.mark.asyncio
191+
async def test_workflow_root_dispatches_pre_run_and_event_hooks():
192+
"""Baseline: pre-run and event hooks fire on a Workflow (BaseNode) root."""
193+
plugin = TracerPlugin()
194+
runner = _build_runner(plugin)
195+
196+
await _drive_one_invocation(runner)
197+
198+
assert plugin.counts.on_user_message_callback == 1
199+
assert plugin.counts.before_run_callback == 1
200+
assert plugin.counts.on_event_callback >= 1, (
201+
"on_event_callback should fire via _consume_event_queue "
202+
"(runners.py:619) for the content-bearing terminal event"
203+
)
204+
205+
206+
@pytest.mark.xfail(
207+
reason=(
208+
"#5282: runners.py:427 TODO — _run_node_async does not dispatch "
209+
"plugin_manager.run_after_run_callback on the BaseNode path. "
210+
"Remove this xfail when the TODO lands."
211+
),
212+
strict=True,
213+
)
214+
@pytest.mark.asyncio
215+
async def test_workflow_root_after_run_callback_not_dispatched():
216+
"""Regression anchor: stock Runner does NOT fire after_run_callback.
217+
218+
Strict xfail — passes (as xfail) while the bug exists, fails loudly if
219+
after_run_callback starts firing unexpectedly. When the fix lands, delete
220+
the @xfail decorator and the test becomes a green regression guard.
221+
"""
222+
plugin = TracerPlugin()
223+
runner = _build_runner(plugin)
224+
225+
await _drive_one_invocation(runner)
226+
227+
assert plugin.counts.after_run_callback == 1
228+
229+
230+
@pytest.mark.asyncio
231+
async def test_workaround_runner_dispatches_after_run_callback():
232+
"""WorkaroundRunner restores after_run_callback without touching ADK source."""
233+
plugin = TracerPlugin()
234+
runner = _build_runner(plugin, runner_cls=WorkaroundRunner)
235+
236+
await _drive_one_invocation(runner)
237+
238+
assert plugin.counts.on_user_message_callback == 1
239+
assert plugin.counts.before_run_callback == 1
240+
assert plugin.counts.on_event_callback >= 1
241+
assert plugin.counts.after_run_callback == 1

0 commit comments

Comments
 (0)