diff --git a/devtools/inspector/_inspector.py b/devtools/inspector/_inspector.py index 7c0bf3241af..199a740737a 100644 --- a/devtools/inspector/_inspector.py +++ b/devtools/inspector/_inspector.py @@ -46,6 +46,7 @@ display_or_print_df, EDGE_DIALECT_GRAPH_KEY, EXCLUDED_COLUMNS_WHEN_PRINTING, + EXCLUDED_EVENTS_FOR_INTERMEDIATE_OUTPUT, EXCLUDED_EVENTS_WHEN_PRINTING, find_populated_event, FORWARD, @@ -1149,6 +1150,36 @@ def _consume_etrecord(self) -> None: self._etrecord._representative_inputs ) + # TODO: Make it more extensible to further merge overlapping debug handles + def _get_runtime_intermediate_outputs(self) -> Dict[Tuple[int, ...], Any]: + """ + Retrieve the raw runtime intermediate outputs (a mapping from debug handles to values) + from the event blocks. These outputs will be processed later to merge overlapping debug handles. + """ + debug_handle_to_output = {} + for event_block in self.event_blocks: + for event in event_block.events: + # Skip OPERATOR_CALL events (to avoid double-counting and exclude framework tax) and events without op_types + if ( + event.name in EXCLUDED_EVENTS_FOR_INTERMEDIATE_OUTPUT + or not event.op_types + ): + continue + # Normalize debug_handles to a tuple + debug_handles = event.debug_handles + if isinstance(debug_handles, int): + debug_handles = (debug_handles,) + else: + debug_handles = tuple(debug_handles) + current_entry = debug_handle_to_output.get(debug_handles, (-1, None)) + # When events share the same debug handles, only keep the one with the largest instruction id + if event._instruction_id > current_entry[0]: + debug_handle_to_output[debug_handles] = ( + event._instruction_id, + event.debug_data, + ) + return {k: v[1] for k, v in debug_handle_to_output.items()} + def to_dataframe( self, include_units: bool = True, diff --git a/devtools/inspector/_inspector_utils.py b/devtools/inspector/_inspector_utils.py index 3faf62e008a..f3cdbc1238a 100644 --- a/devtools/inspector/_inspector_utils.py +++ b/devtools/inspector/_inspector_utils.py @@ -52,6 +52,8 @@ ] 
EXCLUDED_EVENTS_WHEN_PRINTING = {"OPERATOR_CALL"} +EXCLUDED_EVENTS_FOR_INTERMEDIATE_OUTPUT = {"OPERATOR_CALL"} + class TimeScale(Enum): NS = "ns" diff --git a/devtools/inspector/tests/inspector_test.py b/devtools/inspector/tests/inspector_test.py index bcf2fb2230d..b96a694b581 100644 --- a/devtools/inspector/tests/inspector_test.py +++ b/devtools/inspector/tests/inspector_test.py @@ -13,7 +13,7 @@ import unittest from contextlib import redirect_stdout -from typing import Callable, List +from typing import Callable, List, Union from unittest.mock import patch @@ -56,7 +56,7 @@ OP_TYPE = "aten::add" EVENT_BLOCK_NAME = "block_0" -EVENTS_SIZE = 5 +EVENTS_SIZE = 10 RAW_DATA_SIZE = 10 ETDUMP_PATH = "unittest_etdump_path" ETRECORD_PATH = "unittest_etrecord_path" @@ -535,17 +535,116 @@ def test_consume_etrecord_populates_correct_aot_intermediate_outputs(self): ) ) + def test_get_runtime_intermediate_outputs(self): + # Create a context manager to patch functions called by Inspector.__init__ + with patch.object( + _inspector, "parse_etrecord", return_value=None + ), patch.object( + _inspector, "gen_etdump_object", return_value=None + ), patch.object( + EventBlock, "_gen_from_etdump" + ), patch.object( + _inspector, "gen_graphs_from_etrecord" + ): + # Call the constructor of Inspector + inspector_instance = Inspector( + etdump_path=ETDUMP_PATH, + etrecord=ETRECORD_PATH, + ) + + # The mock inspector instance starts with an empty event blocks list. + # Add pre-defined event blocks to test _get_runtime_intermediate_outputs(). 
+ inspector_instance.event_blocks = [ + EventBlock(name=EVENT_BLOCK_NAME, events=self._gen_random_events()) + ] + + runtime_outputs = inspector_instance._get_runtime_intermediate_outputs() + # This output should be a dictionary with 5 keys + self.assertEqual( + len(runtime_outputs), + 5, + ) + # Check that keys (0,) and (1,) are not in the dictionary (OPERATOR_CALL events and events with empty op_types are skipped) + self.assertNotIn((0,), runtime_outputs) + self.assertNotIn((1,), runtime_outputs) + + # Same debug_handle but different instruction_id, should record the last one + self.assertIn((4,), runtime_outputs) + self.assertTrue( + torch.equal(runtime_outputs[(4,)][0], torch.tensor([4.0, 5.0, 6.0])) + ) + # Check that keys (5,) to (8,) are in the dictionary and have values of the correct size + for key in range(5, 9): + self.assertIn((key,), runtime_outputs) + self.assertEqual(len(runtime_outputs[(key,)]), RAW_DATA_SIZE) + def _gen_random_float_list(self) -> List[float]: return [random.uniform(0, 10) for _ in range(RAW_DATA_SIZE)] + def _gen_random_runtime_output( + self, + ) -> List[Union[None, List[torch.Tensor], bool, float, int, str, torch.Tensor]]: + return list(torch.randn(RAW_DATA_SIZE)) + def _gen_random_events(self) -> List[Event]: events = [] - for i in range(EVENTS_SIZE): + for i in range(2): + events.append( + # OPERATOR_CALL with debug_handles/instruction_id 0 and 2 + Event( + name="OPERATOR_CALL", + op_types=[OP_TYPE], + perf_data=PerfData(self._gen_random_float_list()), + debug_handles=i * 2, + _instruction_id=i * 2, + debug_data=self._gen_random_runtime_output(), + ) + ) + events.append( + # op_0/op_1 with empty op_types and with debug_handles/instruction_id 1 and 3 + Event( + name=f"op_{i}", + op_types=[], + perf_data=PerfData(self._gen_random_float_list()), + debug_handles=i * 2 + 1, + _instruction_id=i * 2 + 1, + debug_data=self._gen_random_runtime_output(), + ) + ) + + # op_2 with debug_handles/instruction_id 4 + events.append( + Event( + name="op_2", 
+ op_types=[OP_TYPE], + perf_data=PerfData(self._gen_random_float_list()), + debug_handles=4, + debug_data=[torch.tensor([1.0, 2.0, 3.0])], + _instruction_id=4, + ) + ) + # op_3 also with debug_handles 4 but with instruction_id 5 + events.append( + Event( + name="op_3", + op_types=[OP_TYPE], + perf_data=PerfData(self._gen_random_float_list()), + debug_handles=4, + debug_data=[torch.tensor([4.0, 5.0, 6.0])], + _instruction_id=5, + ) + ) + + # op_4 to op_7 with debug_handles 5 to 8 and instruction_id 6 to 9 + for i in range(4, EVENTS_SIZE - 2): events.append( Event( name=f"op_{i}", op_types=[OP_TYPE], perf_data=PerfData(self._gen_random_float_list()), + debug_handles=i + 1, + debug_data=self._gen_random_runtime_output(), + _instruction_id=i + 2, ) ) return events