|
| 1 | +# Copyright 2026 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""Regression tests for google/adk-python#5282. |
| 16 | +
|
| 17 | +Runner._run_node_async (the dispatch path for Workflow / BaseNode roots) |
| 18 | +dispatches on_user_message_callback, before_run_callback, and |
| 19 | +on_event_callback, but does not dispatch run_after_run_callback |
| 20 | +(runners.py:427 TODO). |
| 21 | +
|
| 22 | +Three tests: |
| 23 | + (a) Baseline — pre-run and event hooks DO fire on a Workflow root. |
| 24 | + (b) Regression anchor — after_run_callback does NOT fire (strict xfail). |
| 25 | + Remove the xfail when the TODO at runners.py:427 is wired. |
| 26 | + (c) Workaround proof — Runner subclass wrapping run_async restores |
| 27 | + after_run_callback dispatch without touching ADK source. |
| 28 | +
|
| 29 | +Concurrency note for the WorkaroundRunner pattern: under concurrent |
| 30 | +run_async calls on a shared Runner, the _last_ic stash should live on a |
| 31 | +contextvars.ContextVar rather than self. The instance attribute is safe |
| 32 | +for single-invocation tests but will race under concurrent load. |
| 33 | +""" |
| 34 | +from __future__ import annotations |
| 35 | + |
| 36 | +from dataclasses import dataclass |
| 37 | +from typing import AsyncGenerator |
| 38 | +from typing import Optional |
| 39 | + |
| 40 | +from google.adk.agents.invocation_context import InvocationContext |
| 41 | +from google.adk.apps.app import App |
| 42 | +from google.adk.events.event import Event |
| 43 | +from google.adk.memory.in_memory_memory_service import InMemoryMemoryService |
| 44 | +from google.adk.plugins.base_plugin import BasePlugin |
| 45 | +from google.adk.runners import Runner |
| 46 | +from google.adk.sessions.in_memory_session_service import InMemorySessionService |
| 47 | +from google.adk.workflow import Workflow |
| 48 | +from google.genai import types |
| 49 | +import pytest |
| 50 | + |
| 51 | +APP_NAME = "issue_5282_repro" |
| 52 | +USER_ID = "u1" |
| 53 | + |
| 54 | + |
| 55 | +# --------------------------------------------------------------------------- |
| 56 | +# Fixtures |
| 57 | +# --------------------------------------------------------------------------- |
| 58 | + |
| 59 | + |
| 60 | +@dataclass |
| 61 | +class CallbackCounts: |
| 62 | + on_user_message_callback: int = 0 |
| 63 | + before_run_callback: int = 0 |
| 64 | + on_event_callback: int = 0 |
| 65 | + after_run_callback: int = 0 |
| 66 | + |
| 67 | + |
| 68 | +class TracerPlugin(BasePlugin): |
| 69 | + """Counts every Plugin lifecycle callback the Runner dispatches.""" |
| 70 | + |
| 71 | + __test__ = False |
| 72 | + |
| 73 | + def __init__(self) -> None: |
| 74 | + super().__init__(name="tracer") |
| 75 | + self.counts = CallbackCounts() |
| 76 | + |
| 77 | + async def on_user_message_callback( |
| 78 | + self, |
| 79 | + *, |
| 80 | + invocation_context: InvocationContext, |
| 81 | + user_message: types.Content, |
| 82 | + ) -> Optional[types.Content]: |
| 83 | + self.counts.on_user_message_callback += 1 |
| 84 | + return None |
| 85 | + |
| 86 | + async def before_run_callback( |
| 87 | + self, *, invocation_context: InvocationContext |
| 88 | + ) -> Optional[types.Content]: |
| 89 | + self.counts.before_run_callback += 1 |
| 90 | + return None |
| 91 | + |
| 92 | + async def on_event_callback( |
| 93 | + self, *, invocation_context: InvocationContext, event: Event |
| 94 | + ) -> Optional[Event]: |
| 95 | + self.counts.on_event_callback += 1 |
| 96 | + return None |
| 97 | + |
| 98 | + async def after_run_callback( |
| 99 | + self, *, invocation_context: InvocationContext |
| 100 | + ) -> None: |
| 101 | + self.counts.after_run_callback += 1 |
| 102 | + return None |
| 103 | + |
| 104 | + |
| 105 | +async def _terminal_node(ctx) -> Event: |
| 106 | + """Minimal terminal node yielding a content-bearing Event. |
| 107 | +
|
| 108 | + Content (not just state) ensures _consume_event_queue runs the |
| 109 | + on_event_callback path — the canonical case the plugin hook targets. |
| 110 | + """ |
| 111 | + return Event( |
| 112 | + content=types.Content( |
| 113 | + parts=[types.Part(text="done")], |
| 114 | + role="model", |
| 115 | + ) |
| 116 | + ) |
| 117 | + |
| 118 | + |
| 119 | +def _build_runner( |
| 120 | + plugin: TracerPlugin, *, runner_cls: type[Runner] = Runner |
| 121 | +) -> Runner: |
| 122 | + workflow = Workflow( |
| 123 | + name="Issue5282Repro", edges=[("START", _terminal_node)] |
| 124 | + ) |
| 125 | + app = App(name=APP_NAME, root_agent=workflow, plugins=[plugin]) |
| 126 | + return runner_cls( |
| 127 | + app_name=APP_NAME, |
| 128 | + app=app, |
| 129 | + session_service=InMemorySessionService(), |
| 130 | + memory_service=InMemoryMemoryService(), |
| 131 | + ) |
| 132 | + |
| 133 | + |
| 134 | +async def _drive_one_invocation(runner: Runner) -> None: |
| 135 | + session = await runner.session_service.create_session( |
| 136 | + app_name=APP_NAME, user_id=USER_ID |
| 137 | + ) |
| 138 | + async for _ in runner.run_async( |
| 139 | + user_id=USER_ID, |
| 140 | + session_id=session.id, |
| 141 | + new_message=types.Content( |
| 142 | + parts=[types.Part(text="hi")], role="user" |
| 143 | + ), |
| 144 | + ): |
| 145 | + pass |
| 146 | + |
| 147 | + |
| 148 | +# --------------------------------------------------------------------------- |
| 149 | +# Workaround: Runner subclass dispatching run_after_run_callback post-drain |
| 150 | +# --------------------------------------------------------------------------- |
| 151 | + |
| 152 | + |
| 153 | +class WorkaroundRunner(Runner): |
| 154 | + """Interim workaround for #5282. |
| 155 | +
|
| 156 | + Wraps run_async to dispatch plugin_manager.run_after_run_callback once |
| 157 | + the inner generator drains. Captures the active InvocationContext via |
| 158 | + _new_invocation_context (called once at runners.py:446). |
| 159 | +
|
| 160 | + Drop this class when the runners.py:427 TODO is resolved — the stock |
| 161 | + Runner will dispatch after_run_callback natively, and |
| 162 | + test_workflow_root_after_run_callback_not_dispatched will flip green |
| 163 | + as the signal. |
| 164 | + """ |
| 165 | + |
| 166 | + def __init__(self, *args, **kwargs) -> None: |
| 167 | + super().__init__(*args, **kwargs) |
| 168 | + self._last_ic: Optional[InvocationContext] = None |
| 169 | + |
| 170 | + def _new_invocation_context(self, session, **kwargs) -> InvocationContext: |
| 171 | + ic = super()._new_invocation_context(session, **kwargs) |
| 172 | + self._last_ic = ic |
| 173 | + return ic |
| 174 | + |
| 175 | + async def run_async(self, **kwargs) -> AsyncGenerator[Event, None]: |
| 176 | + async for event in super().run_async(**kwargs): |
| 177 | + yield event |
| 178 | + ic = self._last_ic |
| 179 | + if ic is not None: |
| 180 | + await ic.plugin_manager.run_after_run_callback( |
| 181 | + invocation_context=ic |
| 182 | + ) |
| 183 | + |
| 184 | + |
| 185 | +# --------------------------------------------------------------------------- |
| 186 | +# Tests |
| 187 | +# --------------------------------------------------------------------------- |
| 188 | + |
| 189 | + |
| 190 | +@pytest.mark.asyncio |
| 191 | +async def test_workflow_root_dispatches_pre_run_and_event_hooks(): |
| 192 | + """Baseline: pre-run and event hooks fire on a Workflow (BaseNode) root.""" |
| 193 | + plugin = TracerPlugin() |
| 194 | + runner = _build_runner(plugin) |
| 195 | + |
| 196 | + await _drive_one_invocation(runner) |
| 197 | + |
| 198 | + assert plugin.counts.on_user_message_callback == 1 |
| 199 | + assert plugin.counts.before_run_callback == 1 |
| 200 | + assert plugin.counts.on_event_callback >= 1, ( |
| 201 | + "on_event_callback should fire via _consume_event_queue " |
| 202 | + "(runners.py:619) for the content-bearing terminal event" |
| 203 | + ) |
| 204 | + |
| 205 | + |
| 206 | +@pytest.mark.xfail( |
| 207 | + reason=( |
| 208 | + "#5282: runners.py:427 TODO — _run_node_async does not dispatch " |
| 209 | + "plugin_manager.run_after_run_callback on the BaseNode path. " |
| 210 | + "Remove this xfail when the TODO lands." |
| 211 | + ), |
| 212 | + strict=True, |
| 213 | +) |
| 214 | +@pytest.mark.asyncio |
| 215 | +async def test_workflow_root_after_run_callback_not_dispatched(): |
| 216 | + """Regression anchor: stock Runner does NOT fire after_run_callback. |
| 217 | +
|
| 218 | + Strict xfail — passes (as xfail) while the bug exists, fails loudly if |
| 219 | + after_run_callback starts firing unexpectedly. When the fix lands, delete |
| 220 | + the @xfail decorator and the test becomes a green regression guard. |
| 221 | + """ |
| 222 | + plugin = TracerPlugin() |
| 223 | + runner = _build_runner(plugin) |
| 224 | + |
| 225 | + await _drive_one_invocation(runner) |
| 226 | + |
| 227 | + assert plugin.counts.after_run_callback == 1 |
| 228 | + |
| 229 | + |
| 230 | +@pytest.mark.asyncio |
| 231 | +async def test_workaround_runner_dispatches_after_run_callback(): |
| 232 | + """WorkaroundRunner restores after_run_callback without touching ADK source.""" |
| 233 | + plugin = TracerPlugin() |
| 234 | + runner = _build_runner(plugin, runner_cls=WorkaroundRunner) |
| 235 | + |
| 236 | + await _drive_one_invocation(runner) |
| 237 | + |
| 238 | + assert plugin.counts.on_user_message_callback == 1 |
| 239 | + assert plugin.counts.before_run_callback == 1 |
| 240 | + assert plugin.counts.on_event_callback >= 1 |
| 241 | + assert plugin.counts.after_run_callback == 1 |
0 commit comments