7
7
# pyre-unsafe
8
8
9
9
import copy
10
+ import torch
10
11
import random
11
12
import statistics
12
13
import tempfile
13
14
import unittest
14
15
from contextlib import redirect_stdout
15
16
16
- from typing import Callable , List
17
+ from typing import Callable , List , Union
17
18
18
19
from unittest .mock import patch
19
20
56
57
57
58
OP_TYPE = "aten::add"
58
59
EVENT_BLOCK_NAME = "block_0"
59
- EVENTS_SIZE = 5
60
+ EVENTS_SIZE = 10
60
61
RAW_DATA_SIZE = 10
61
62
ETDUMP_PATH = "unittest_etdump_path"
62
63
ETRECORD_PATH = "unittest_etrecord_path"
@@ -72,7 +73,7 @@ def test_perf_data(self) -> None:
72
73
self .assertAlmostEqual (perfData .p50 , statistics .median (random_floats ))
73
74
74
75
def test_event_block_to_dataframe (self ) -> None :
75
- eventBlock = EventBlock (name = EVENT_BLOCK_NAME , events = self ._gen_random_events ())
76
+ eventBlock = EventBlock (name = EVENT_BLOCK_NAME , events = self ._gen_events ())
76
77
77
78
df = eventBlock .to_dataframe ()
78
79
# Check some fields of the returned dataframe
@@ -154,7 +155,7 @@ def test_inspector_print_data_tabular(self):
154
155
# The mock inspector instance starts with having an empty event blocks list.
155
156
# Add non-empty event blocks to test print_data_tabular().
156
157
inspector_instance .event_blocks = [
157
- EventBlock (name = EVENT_BLOCK_NAME , events = self ._gen_random_events ())
158
+ EventBlock (name = EVENT_BLOCK_NAME , events = self ._gen_events ())
158
159
]
159
160
# Call print_data_tabular(), make sure it doesn't crash
160
161
with redirect_stdout (None ):
@@ -535,17 +536,111 @@ def test_consume_etrecord_populates_correct_aot_intermediate_outputs(self):
535
536
)
536
537
)
537
538
539
+ def test_get_runtime_intermediate_outputs (self ):
540
+ # Create a context manager to patch functions called by Inspector.__init__
541
+ with patch .object (
542
+ _inspector , "parse_etrecord" , return_value = None
543
+ ), patch .object (
544
+ _inspector , "gen_etdump_object" , return_value = None
545
+ ), patch .object (
546
+ EventBlock , "_gen_from_etdump"
547
+ ), patch .object (
548
+ _inspector , "gen_graphs_from_etrecord"
549
+ ):
550
+ # Call the constructor of Inspector
551
+ inspector_instance = Inspector (
552
+ etdump_path = ETDUMP_PATH ,
553
+ etrecord = ETRECORD_PATH ,
554
+ )
555
+
556
+ # The mock inspector instance starts with having an empty event blocks list.
557
+ # Add pre-defined event blocks to test _get_runtime_outputs().
558
+ inspector_instance .event_blocks = [
559
+ EventBlock (name = EVENT_BLOCK_NAME , events = self ._gen_events ())
560
+ ]
561
+
562
+ runtime_outputs = inspector_instance ._get_runtime_intermediate_outputs ()
563
+ # This output should be a dictionary with 5 keys
564
+ self .assertEqual (len (runtime_outputs ), 5 , )
565
+ # Check that keys (0,) and (1,) are not in the dictionary(skip OPERATOR_CALL and op_types are empty)
566
+ self .assertNotIn ((0 ,), runtime_outputs )
567
+ self .assertNotIn ((1 ,), runtime_outputs )
568
+
569
+ # Same debug_handle but different instruction_id, should record the last one
570
+ self .assertIn ((4 ,), runtime_outputs )
571
+ self .assertTrue (torch .equal (runtime_outputs [(4 ,)][0 ], torch .tensor ([4.0 , 5.0 , 6.0 ])))
572
+ # Check that keys (5,) to (8,) are in the dictionary and have values of the correct size
573
+ for key in range (5 , 9 ):
574
+ self .assertIn ((key ,), runtime_outputs )
575
+ self .assertEqual (len (runtime_outputs [(key ,)]), RAW_DATA_SIZE )
576
+
538
577
def _gen_random_float_list (self ) -> List [float ]:
539
578
return [random .uniform (0 , 10 ) for _ in range (RAW_DATA_SIZE )]
540
579
541
- def _gen_random_events (self ) -> List [Event ]:
580
+ def _gen_random_runtime_output (self ) -> List [Union [None , List [torch .Tensor ], bool , float , int , str , torch .Tensor ]]:
581
+ return list (torch .randn (RAW_DATA_SIZE ))
582
+
583
+ def _gen_events (self ) -> List [Event ]:
542
584
events = []
543
- for i in range (EVENTS_SIZE ):
585
+ for i in range (2 ):
586
+ events .append (
587
+ # OPERATOR_CALL with debug_hanldes/instruction_id 0 and 2
588
+ Event (
589
+ name = "OPERATOR_CALL" ,
590
+ op_types = [OP_TYPE ],
591
+ perf_data = PerfData (self ._gen_random_float_list ()),
592
+ debug_handles = i * 2 ,
593
+ _instruction_id = i * 2 ,
594
+ debug_data = self ._gen_random_runtime_output ()
595
+ )
596
+ )
597
+ events .append (
598
+ # op_0/op_1 wiht empty op_types and with debug_hanldes/instruction_id 1 and 3
599
+ Event (
600
+ name = f"op_{ i } " ,
601
+ op_types = [],
602
+ perf_data = PerfData (self ._gen_random_float_list ()),
603
+ debug_handles = i * 2 + 1 ,
604
+ _instruction_id = i * 2 + 1 ,
605
+ debug_data = self ._gen_random_runtime_output ()
606
+ )
607
+ )
608
+
609
+ # op_2 with debug_hanldes/instruction_id 4
610
+ events .append (
611
+ Event (
612
+ name = f"op_2" ,
613
+ op_types = [OP_TYPE ],
614
+ perf_data = PerfData (self ._gen_random_float_list ()),
615
+ debug_handles = 4 ,
616
+ debug_data = [torch .tensor ([1.0 , 2.0 , 3.0 ])],
617
+ _instruction_id = 4
618
+
619
+ )
620
+ )
621
+ # op_3 also with debug_hanldes 4 but with instruction_id 5
622
+ events .append (
623
+ Event (
624
+ name = f"op_3" ,
625
+ op_types = [OP_TYPE ],
626
+ perf_data = PerfData (self ._gen_random_float_list ()),
627
+ debug_handles = 4 ,
628
+ debug_data = [torch .tensor ([4.0 , 5.0 , 6.0 ])],
629
+ _instruction_id = 5
630
+
631
+ )
632
+ )
633
+
634
+ # op_4 to op_7 with debug_hanldes 5 to 8 and instruction_id 6 to 9
635
+ for i in range (4 , EVENTS_SIZE - 2 ):
544
636
events .append (
545
637
Event (
546
638
name = f"op_{ i } " ,
547
639
op_types = [OP_TYPE ],
548
640
perf_data = PerfData (self ._gen_random_float_list ()),
641
+ debug_handles = i + 1 ,
642
+ debug_data = self ._gen_random_runtime_output (),
643
+ _instruction_id = i + 2
549
644
)
550
645
)
551
646
return events
0 commit comments