Source code for fbgemm_gpu.split_table_batched_embeddings_ops_inference

@@ -458,13 +458,15 @@
 # pyre-ignore-all-errors[56]

 import logging
+import uuid
 from itertools import accumulate
 from typing import List, Optional, Tuple, Union

 import fbgemm_gpu  # noqa: F401
 import torch  # usort:skip
 from torch import nn, Tensor  # usort:skip

+from fbgemm_gpu.config import FeatureGateName
 from fbgemm_gpu.split_embedding_configs import sparse_type_to_int, SparseType
 from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
     BoundsCheckMode,
@@ -817,6 +819,10 @@
         indices_dtype: torch.dtype = torch.int32,  # Used for construction of the remap_indices tensors. Should match the dtype of the indices passed in the forward() call (INT32 or INT64).
     ) -> None:  # noqa C901  # tuple of (rows, dims,)
         super(IntNBitTableBatchedEmbeddingBagsCodegen, self).__init__()
+        self.uuid = str(uuid.uuid4())
+        self.log(
+            f"Feature Gates: {[(feature.name, feature.is_enabled()) for feature in FeatureGateName]}"
+        )

         # 64 for AMD
         if cache_assoc == 32 and torch.version.hip is not None:
@@ -1072,6 +1078,20 @@
         self.fp8_exponent_bits = -1
         self.fp8_exponent_bias = -1

+    @torch.jit.ignore
+    def log(self, msg: str) -> None:
+        """
+        Log with TBE id prefix to distinguish between multiple TBE instances
+        per process
+
+        Args:
+            msg (str): The message to print
+
+        Returns:
+            None
+        """
+        logging.info(f"[TBE={self.uuid}] {msg}")
+
     def get_cache_miss_counter(self) -> Tensor:
         # cache_miss_counter[0]: cache_miss_forward_count which records the total number of forwards which have at least one cache miss
         # cache_miss_counter[1]: unique_cache_miss_count which records the total number of unique (dedup) cache misses
@@ -1120,17 +1140,17 @@
         assert (
             self.record_cache_metrics.record_cache_miss_counter
         ), "record_cache_miss_counter should be true to access counter values"
-        logging.info(
+        self.log(
             f"\n"
             f"Miss counter value [0] - # of miss occurred iters : {self.cache_miss_counter[0]}, \n"
             f"Miss counter value [1] - # of unique misses : {self.cache_miss_counter[1]}, \n"
             f"Miss counter value [2] - # of unique requested indices : {self.cache_miss_counter[2]}, \n"
             f"Miss counter value [3] - # of total requested indices : {self.cache_miss_counter[3]}, "
         )
-        logging.info(
+        self.log(
             f"unique_miss_rate using counter : {self.cache_miss_counter[1] / self.cache_miss_counter[2]}, \n"
         )
-        logging.info(
+        self.log(
             f"total_miss_rate using counter : {self.cache_miss_counter[1] / self.cache_miss_counter[3]}, \n"
         )

@@ -1145,7 +1165,7 @@
             self.gather_uvm_cache_stats
         ), "gather_uvm_cache_stats should be set to true to access uvm cache stats."
         uvm_cache_stats = self.uvm_cache_stats.tolist()
-        logging.info(
+        self.log(
             f"N_called: {uvm_cache_stats[0]}\n"
             f"N_requested_indices: {uvm_cache_stats[1]}\n"
             f"N_unique_indices: {uvm_cache_stats[2]}\n"
@@ -1154,7 +1174,7 @@
             f"N_conflict_misses: {uvm_cache_stats[5]}\n"
         )
         if uvm_cache_stats[1]:
-            logging.info(
+            self.log(
                 f"unique indices / requested indices: {uvm_cache_stats[2] / uvm_cache_stats[1]}\n"
                 f"unique misses / requested indices: {uvm_cache_stats[3] / uvm_cache_stats[1]}\n"
             )
@@ -1660,7 +1680,7 @@
         assert not self.use_cpu
         if enforce_hbm:
             if not torch.jit.is_scripting():
-                logging.info("Enforce hbm for the cache location")
+                self.log("Enforce hbm for the cache location")
             self.weights_uvm = torch.zeros(
                 uvm_size,
                 device=self.current_device,
@@ -1800,7 +1820,7 @@
         if cache_algorithm == CacheAlgorithm.LFU:
             assert cache_sets < 2**24 - 1
         cache_size = cache_sets * self.cache_assoc * self.max_D_cache
-        logging.info(
+        self.log(
            f"Using on-device cache with admission algorithm "
            f"{cache_algorithm}, {cache_sets} sets, "
            f"cache_load_factor: {cache_load_factor :.3f}, "