Skip to content

Commit 923d2fd

Browse files
author
Ivan Zaitsev
committed
[autorevert] Surface test identity at signal level for advisor JSON
The autorevert AI advisor's signal_pattern JSON does not surface authoritative test identity (file/classname/name) for TEST signals. The advisor knows the test from `signal_key` (which is `file::name`), but `classname` is dropped by the test_id collapse — so when many tests share the same file or name, the advisor has to re-derive structure from logs. When the suspect commit's parent workflow_job is still in_progress at advisor dispatch time, that re-derivation fails and the advisor rationalizes pending-heavy partitions as flake → garbage. A 30d audit found 40 garbage verdicts on 23 distinct commits later auto-reverted; 33/37 had a workflow_job for the suspect commit still in_progress at verdict time. A TEST signal corresponds to a SINGLE specific test by construction (one signal per `(workflow, job_base, test_id)`). Every FAILURE event in that signal IS that specific test failing — no per-event test identity is needed (`signal_key` + `status` already convey it). This change surfaces the structured test identity ONCE at the top of the advisor signal_pattern payload: { "signal_key": "test/foo.py::test_bar", "signal_source": "test", "test_file": "test/foo.py", "test_classname": "TestFooBar", "test_name": "test_bar", ... } Implementation: - `Signal` gains optional `test_file`/`test_classname`/`test_name` attributes (default None). - `_build_test_signals` captures the (file, classname, name) triple from any TestRow contributing to the signal (collapses cleanly because each TEST signal is a single test by definition) and attaches them to the Signal at construction time. - `_build_signal_pattern_json` emits the three fields as top-level payload keys when populated, only for `signal_source == "test"`. - New regression tests: - `test_test_identity_surfaced_for_test_signals` — top-level keys present; no per-event leak. - `test_test_identity_omitted_for_job_signals` — JOB signals omit the keys entirely. Companion change: pytorch/pytorch prompt PR adding a "Partial Visibility Warning" section, tightening the `garbage` definition to require concrete completed-log infra evidence, and adding an explicit `signal_key`-is-authoritative assertion that aligns with the top-level test identity exposed here.
1 parent f6af685 commit 923d2fd

5 files changed

Lines changed: 312 additions & 12 deletions

File tree

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,9 @@ def __init__(
359359
job_base_name: Optional[str] = None,
360360
test_module: Optional[str] = None,
361361
source: SignalSource = SignalSource.TEST,
362+
test_file: Optional[str] = None,
363+
test_classname: Optional[str] = None,
364+
test_name: Optional[str] = None,
362365
):
363366
self.key = key
364367
self.workflow_name = workflow_name
@@ -369,6 +372,15 @@ def __init__(
369372
self.test_module = test_module
370373
# Track the origin of the signal (test-track or job-track).
371374
self.source = source
375+
# For TEST signals: structured identity of the failing test, sourced
376+
# from tests.all_test_runs. Surfaced once at the top of the advisor
377+
# signal_pattern JSON so the AI advisor has authoritative ground
378+
# truth for which test failed (file/classname/name) without
379+
# re-deriving from logs. Every FAILURE event in this signal IS this
380+
# specific test failing.
381+
self.test_file = test_file
382+
self.test_classname = test_classname
383+
self.test_name = test_name
372384

373385
def detect_fixed(self) -> bool:
374386
"""

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal_actions.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from dataclasses import dataclass
77
from datetime import datetime, timedelta
88
from enum import Enum
9-
from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple, Union
9+
from typing import Any, Dict, FrozenSet, Iterable, List, Optional, Tuple, Union
1010

1111
import github
1212

@@ -18,6 +18,7 @@
1818
Ineligible,
1919
RestartCommits,
2020
Signal,
21+
SignalSource,
2122
)
2223
from .signal_extraction_types import RunContext
2324
from .utils import (
@@ -807,17 +808,28 @@ def _partition_label(sha: str) -> str:
807808
}
808809
)
809810

810-
return json.dumps(
811-
{
812-
"signal_key": signal.key,
813-
"signal_source": signal.source.value if signal.source else "unknown",
814-
"workflow_name": signal.workflow_name,
815-
"job_base_name": signal.job_base_name,
816-
"commit_order": "newest_first",
817-
"suspect_commit": dispatch_advisor.suspect_commit,
818-
"commits": commits_json,
819-
}
820-
)
811+
payload: Dict[str, Any] = {
812+
"signal_key": signal.key,
813+
"signal_source": signal.source.value if signal.source else "unknown",
814+
"workflow_name": signal.workflow_name,
815+
"job_base_name": signal.job_base_name,
816+
"commit_order": "newest_first",
817+
"suspect_commit": dispatch_advisor.suspect_commit,
818+
}
819+
# For TEST signals, surface authoritative test identity once at the
820+
# top of the payload (file/classname/name from tests.all_test_runs).
821+
# Every FAILURE event in this signal IS this specific test failing —
822+
# the AI advisor does not need to re-derive from logs. Only emitted
823+
# when populated to keep the payload backward-compatible.
824+
if signal.source == SignalSource.TEST:
825+
if signal.test_file:
826+
payload["test_file"] = signal.test_file
827+
if signal.test_classname:
828+
payload["test_classname"] = signal.test_classname
829+
if signal.test_name:
830+
payload["test_name"] = signal.test_name
831+
payload["commits"] = commits_json
832+
return json.dumps(payload)
821833

822834
def _commit_message_check_pr_is_revert(
823835
self, commit_message: str, ctx: RunContext

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal_extraction.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,9 @@ def _dedup_signal_events(self, signals: List[Signal]) -> List[Signal]:
151151
job_base_name=s.job_base_name,
152152
test_module=s.test_module,
153153
source=s.source,
154+
test_file=s.test_file,
155+
test_classname=s.test_classname,
156+
test_name=s.test_name,
154157
)
155158
)
156159
return deduped
@@ -237,6 +240,9 @@ def _inject_pending_workflow_events(
237240
job_base_name=s.job_base_name,
238241
test_module=s.test_module,
239242
source=s.source,
243+
test_file=s.test_file,
244+
test_classname=s.test_classname,
245+
test_name=s.test_name,
240246
)
241247
)
242248
return out
@@ -302,6 +308,9 @@ def _attach_advisor_verdicts(
302308
job_base_name=s.job_base_name,
303309
test_module=s.test_module,
304310
source=s.source,
311+
test_file=s.test_file,
312+
test_classname=s.test_classname,
313+
test_name=s.test_name,
305314
)
306315
)
307316
return out
@@ -393,6 +402,11 @@ def _build_test_signals(
393402
failing_tests_by_job_base_name: Set[
394403
Tuple[WorkflowName, JobBaseName, TestId]
395404
] = set()
405+
# Capture structured test identity (file/classname/name) per test_id so
406+
# we can attach it once at the Signal level. The TestRow `test_id`
407+
# property collapses to "file::name" and drops classname; we keep the
408+
# full triple here for the advisor signal_pattern JSON.
409+
test_identity_by_test_id: Dict[TestId, Tuple[str, str, str]] = {}
396410
for tr in test_rows:
397411
job = jobs_by_id.get(tr.job_id)
398412
job_base_name = job.base_name
@@ -419,6 +433,9 @@ def _build_test_signals(
419433
outcome = existing
420434

421435
tests_by_group_attempt[key] = outcome
436+
test_identity_by_test_id.setdefault(
437+
tr.test_id, (tr.file, tr.classname, tr.name)
438+
)
422439

423440
# Track keys that have at least one persistent failure (no retry success)
424441
if outcome.failure_runs > 0 and outcome.success_runs == 0:
@@ -523,6 +540,9 @@ def _build_test_signals(
523540
# Extract test module from test_id (format: "file.py::test_name")
524541
# Result: "file" or "path/to/file" without .py extension
525542
test_module = test_id.split("::")[0].replace(".py", "")
543+
test_file, test_classname, test_name = test_identity_by_test_id.get(
544+
test_id, ("", "", "")
545+
)
526546

527547
signals.append(
528548
Signal(
@@ -532,6 +552,9 @@ def _build_test_signals(
532552
job_base_name=str(job_base_name),
533553
test_module=test_module,
534554
source=SignalSource.TEST,
555+
test_file=test_file or None,
556+
test_classname=test_classname or None,
557+
test_name=test_name or None,
535558
)
536559
)
537560

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/tests/test_signal_actions.py

Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -966,6 +966,132 @@ def test_signal_pattern_sanity_check(self):
966966
reparsed = json.loads(json.dumps(result))
967967
self.assertEqual(reparsed, result)
968968

969+
def test_test_identity_surfaced_for_test_signals(self):
970+
"""TEST signals surface authoritative test identity (file/classname/name)
971+
once at the top of the payload. Every FAILURE event in the signal IS
972+
this specific test failing — no duplication per event."""
973+
import json
974+
975+
from pytorch_auto_revert.signal import (
976+
DispatchAdvisor,
977+
Signal,
978+
SignalCommit,
979+
SignalEvent,
980+
SignalSource,
981+
SignalStatus,
982+
)
983+
984+
t0 = datetime(2025, 8, 19, 12, 0, 0)
985+
t1 = datetime(2025, 8, 19, 11, 0, 0)
986+
987+
c_fail = SignalCommit(
988+
head_sha="sha_fail",
989+
timestamp=t0,
990+
events=[
991+
SignalEvent(
992+
"test", SignalStatus.FAILURE, t0, wf_run_id=100, job_id=200
993+
),
994+
],
995+
)
996+
c_base = SignalCommit(
997+
head_sha="sha_base",
998+
timestamp=t1,
999+
events=[
1000+
SignalEvent(
1001+
"test", SignalStatus.SUCCESS, t1, wf_run_id=99, job_id=199
1002+
),
1003+
],
1004+
)
1005+
signal = Signal(
1006+
key="test/inductor/test_comm_analysis.py::test_nccl_estimate_device_resolution_gpu",
1007+
workflow_name="trunk",
1008+
commits=[c_fail, c_base],
1009+
source=SignalSource.TEST,
1010+
job_base_name="linux-jammy / test",
1011+
test_file="test/inductor/test_comm_analysis.py",
1012+
test_classname="TestNcclEstimateDeviceResolution",
1013+
test_name="test_nccl_estimate_device_resolution_gpu",
1014+
)
1015+
advisor = DispatchAdvisor(
1016+
suspect_commit="sha_fail",
1017+
failed_commits=("sha_fail",),
1018+
successful_commits=("sha_base",),
1019+
)
1020+
1021+
result = json.loads(
1022+
SignalActionProcessor._build_signal_pattern_json(
1023+
signal=signal,
1024+
dispatch_advisor=advisor,
1025+
repo_full_name="pytorch/pytorch",
1026+
)
1027+
)
1028+
# Top-level test identity is present, populated, and authoritative
1029+
self.assertEqual(result["test_file"], "test/inductor/test_comm_analysis.py")
1030+
self.assertEqual(result["test_classname"], "TestNcclEstimateDeviceResolution")
1031+
self.assertEqual(
1032+
result["test_name"], "test_nccl_estimate_device_resolution_gpu"
1033+
)
1034+
1035+
# No per-event test_failures emission (would be redundant with signal_key)
1036+
for c in result["commits"]:
1037+
for ev in c["events"]:
1038+
self.assertNotIn("test_failures", ev)
1039+
self.assertNotIn("test_file", ev)
1040+
self.assertNotIn("test_classname", ev)
1041+
self.assertNotIn("test_name", ev)
1042+
1043+
def test_test_identity_omitted_for_job_signals(self):
1044+
"""JOB signals (and TEST signals without identity populated) must not
1045+
emit the test_* top-level keys."""
1046+
import json
1047+
1048+
from pytorch_auto_revert.signal import (
1049+
DispatchAdvisor,
1050+
Signal,
1051+
SignalCommit,
1052+
SignalEvent,
1053+
SignalSource,
1054+
SignalStatus,
1055+
)
1056+
1057+
t0 = datetime(2025, 8, 19, 12, 0, 0)
1058+
1059+
c_fail = SignalCommit(
1060+
head_sha="sha_fail",
1061+
timestamp=t0,
1062+
events=[
1063+
SignalEvent("j", SignalStatus.FAILURE, t0, wf_run_id=1, job_id=1),
1064+
],
1065+
)
1066+
c_base = SignalCommit(
1067+
head_sha="sha_base",
1068+
timestamp=t0,
1069+
events=[
1070+
SignalEvent("j", SignalStatus.SUCCESS, t0, wf_run_id=2, job_id=2),
1071+
],
1072+
)
1073+
signal = Signal(
1074+
key="lint",
1075+
workflow_name="trunk",
1076+
commits=[c_fail, c_base],
1077+
source=SignalSource.JOB,
1078+
)
1079+
advisor = DispatchAdvisor(
1080+
suspect_commit="sha_fail",
1081+
failed_commits=("sha_fail",),
1082+
successful_commits=("sha_base",),
1083+
)
1084+
1085+
result = json.loads(
1086+
SignalActionProcessor._build_signal_pattern_json(
1087+
signal=signal,
1088+
dispatch_advisor=advisor,
1089+
repo_full_name="pytorch/pytorch",
1090+
)
1091+
)
1092+
for k in ("test_file", "test_classname", "test_name"):
1093+
self.assertNotIn(k, result)
1094+
9691095

9701096
class TestDispatchAdvisorsMethod(unittest.TestCase):
9711097
"""Tests for SignalActionProcessor.dispatch_advisors."""
@@ -1270,6 +1396,108 @@ def test_invalid_verdict_string_defaults_to_unsure(self):
12701396
result[0].commits[0].advisor_result.verdict, AdvisorVerdict.UNSURE
12711397
)
12721398

1399+
def test_preserves_test_identity_on_attach(self):
1400+
# Signal-level test_file/test_classname/test_name must survive the
1401+
# Signal reconstruction inside _attach_advisor_verdicts.
1402+
from pytorch_auto_revert.signal import (
1403+
Signal,
1404+
SignalCommit,
1405+
SignalEvent,
1406+
SignalSource,
1407+
SignalStatus,
1408+
)
1409+
from pytorch_auto_revert.signal_extraction import SignalExtractor
1410+
from pytorch_auto_revert.signal_extraction_types import Sha
1411+
1412+
t0 = datetime(2025, 8, 19, 12, 0, 0)
1413+
c1 = SignalCommit(
1414+
"sha_aaa",
1415+
t0,
1416+
[SignalEvent("j", SignalStatus.FAILURE, t0, wf_run_id=1, job_id=10)],
1417+
)
1418+
signal = Signal(
1419+
key="test/foo.py::test_bar",
1420+
workflow_name="trunk",
1421+
commits=[c1],
1422+
source=SignalSource.TEST,
1423+
test_file="test/foo.py",
1424+
test_classname="TestFooBar",
1425+
test_name="test_bar",
1426+
)
1427+
extractor = SignalExtractor(workflows=["trunk"], lookback_hours=16)
1428+
extractor._datasource = Mock()
1429+
extractor._datasource.fetch_advisor_verdicts.return_value = {
1430+
("sha_aaa", "test/foo.py::test_bar"): ("revert", 0.95, t0),
1431+
}
1432+
out = extractor._attach_advisor_verdicts([signal], [(Sha("sha_aaa"), t0)])
1433+
self.assertEqual(out[0].test_file, "test/foo.py")
1434+
self.assertEqual(out[0].test_classname, "TestFooBar")
1435+
self.assertEqual(out[0].test_name, "test_bar")
1436+
1437+
1438+
class TestInjectPendingWorkflowEvents(unittest.TestCase):
1439+
"""Tests for SignalExtractor._inject_pending_workflow_events."""
1440+
1441+
def test_preserves_test_identity_on_inject(self):
1442+
# Signal-level test_file/test_classname/test_name must survive the
1443+
# Signal reconstruction inside _inject_pending_workflow_events.
1444+
from pytorch_auto_revert.signal import (
1445+
Signal,
1446+
SignalCommit,
1447+
SignalEvent,
1448+
SignalSource,
1449+
SignalStatus,
1450+
)
1451+
from pytorch_auto_revert.signal_extraction import SignalExtractor
1452+
from pytorch_auto_revert.signal_extraction_types import (
1453+
JobBaseName,
1454+
JobId,
1455+
JobName,
1456+
JobRow,
1457+
RunAttempt,
1458+
Sha,
1459+
WfRunId,
1460+
WorkflowName,
1461+
)
1462+
1463+
t0 = datetime(2025, 8, 19, 12, 0, 0)
1464+
c1 = SignalCommit(
1465+
"sha_aaa",
1466+
t0,
1467+
[SignalEvent("j", SignalStatus.FAILURE, t0, wf_run_id=1, job_id=10)],
1468+
)
1469+
signal = Signal(
1470+
key="test/foo.py::test_bar",
1471+
workflow_name="trunk",
1472+
commits=[c1],
1473+
source=SignalSource.TEST,
1474+
test_file="test/foo.py",
1475+
test_classname="TestFooBar",
1476+
test_name="test_bar",
1477+
)
1478+
# One pending JobRow on a different wf_run_id triggers synthesis on c1.
1479+
pending_job = JobRow(
1480+
head_sha=Sha("sha_aaa"),
1481+
workflow_name=WorkflowName("trunk"),
1482+
wf_run_id=WfRunId(2),
1483+
job_id=JobId(20),
1484+
run_attempt=RunAttempt(1),
1485+
name=JobName("j"),
1486+
status="in_progress",
1487+
conclusion="",
1488+
started_at=t0,
1489+
created_at=t0,
1490+
rule="",
1491+
)
1492+
extractor = SignalExtractor(workflows=["trunk"], lookback_hours=16)
1493+
out = extractor._inject_pending_workflow_events([signal], [pending_job])
1494+
self.assertEqual(out[0].test_file, "test/foo.py")
1495+
self.assertEqual(out[0].test_classname, "TestFooBar")
1496+
self.assertEqual(out[0].test_name, "test_bar")
1497+
# Synthesis actually fired (otherwise the test wouldn't exercise
1498+
# the reconstruction path)
1499+
self.assertGreater(len(out[0].commits[0].events), 1)
1500+
12731501

12741502
if __name__ == "__main__":
12751503
unittest.main()

0 commit comments

Comments
 (0)