19
19
import torch
20
20
from fbgemm_gpu .split_embedding_configs import EmbOptimType as OptimType , SparseType
21
21
from fbgemm_gpu .split_table_batched_embeddings_ops_common import (
22
- BoundsCheckMode ,
23
22
CacheAlgorithm ,
24
23
EmbeddingLocation ,
25
- str_to_embedding_location ,
26
24
str_to_pooling_mode ,
27
25
)
28
26
from fbgemm_gpu .split_table_batched_embeddings_ops_training import (
32
30
)
33
31
from fbgemm_gpu .tbe .bench import (
34
32
benchmark_requests ,
33
+ EmbeddingOpsCommonConfigLoader ,
35
34
TBEBenchmarkingConfigLoader ,
36
35
TBEDataConfigLoader ,
37
36
)
@@ -50,50 +49,39 @@ def cli() -> None:
50
49
51
50
52
51
@cli .command ()
53
- @click .option ("--weights-precision" , type = SparseType , default = SparseType .FP32 )
54
- @click .option ("--cache-precision" , type = SparseType , default = None )
55
- @click .option ("--stoc" , is_flag = True , default = False )
56
- @click .option (
57
- "--managed" ,
58
- default = "device" ,
59
- type = click .Choice (["device" , "managed" , "managed_caching" ], case_sensitive = False ),
60
- )
61
52
@click .option (
62
53
"--emb-op-type" ,
63
54
default = "split" ,
64
55
type = click .Choice (["split" , "dense" , "ssd" ], case_sensitive = False ),
56
+ help = "The type of the embedding op to benchmark" ,
57
+ )
58
+ @click .option (
59
+ "--row-wise/--no-row-wise" ,
60
+ default = True ,
61
+ help = "Whether to use row-wise adagrad optimzier or not" ,
65
62
)
66
- @click .option ("--row-wise/--no-row-wise" , default = True )
67
- @click .option ("--pooling" , type = str , default = "sum" )
68
- @click .option ("--weighted-num-requires-grad" , type = int , default = None )
69
- @click .option ("--bounds-check-mode" , type = int , default = BoundsCheckMode .NONE .value )
70
- @click .option ("--output-dtype" , type = SparseType , default = SparseType .FP32 )
71
63
@click .option (
72
- "--uvm-host-mapped " ,
73
- is_flag = True ,
74
- default = False ,
75
- help = "Use host mapped UVM buffers in SSD-TBE (malloc+cudaHostRegister) " ,
64
+ "--weighted-num-requires-grad " ,
65
+ type = int ,
66
+ default = None ,
67
+ help = "The number of weighted tables that require gradient " ,
76
68
)
77
69
@click .option (
78
- "--ssd-prefix" , type = str , default = "/tmp/ssd_benchmark" , help = "SSD directory prefix"
70
+ "--ssd-prefix" ,
71
+ type = str ,
72
+ default = "/tmp/ssd_benchmark" ,
73
+ help = "SSD directory prefix" ,
79
74
)
80
75
@click .option ("--cache-load-factor" , default = 0.2 )
81
76
@TBEBenchmarkingConfigLoader .options
82
77
@TBEDataConfigLoader .options
78
+ @EmbeddingOpsCommonConfigLoader .options
83
79
@click .pass_context
84
80
def device ( # noqa C901
85
81
context : click .Context ,
86
82
emb_op_type : click .Choice ,
87
- weights_precision : SparseType ,
88
- cache_precision : Optional [SparseType ],
89
- stoc : bool ,
90
- managed : click .Choice ,
91
83
row_wise : bool ,
92
- pooling : str ,
93
84
weighted_num_requires_grad : Optional [int ],
94
- bounds_check_mode : int ,
95
- output_dtype : SparseType ,
96
- uvm_host_mapped : bool ,
97
85
cache_load_factor : float ,
98
86
# SSD params
99
87
ssd_prefix : str ,
@@ -110,6 +98,9 @@ def device( # noqa C901
110
98
# Load TBE data configuration from cli arguments
111
99
tbeconfig = TBEDataConfigLoader .load (context )
112
100
101
+ # Load common embedding op configuration from cli arguments
102
+ embconfig = EmbeddingOpsCommonConfigLoader .load (context )
103
+
113
104
# Generate feature_requires_grad
114
105
feature_requires_grad = (
115
106
tbeconfig .generate_feature_requires_grad (weighted_num_requires_grad )
@@ -123,22 +114,8 @@ def device( # noqa C901
123
114
# Determine the optimizer
124
115
optimizer = OptimType .EXACT_ROWWISE_ADAGRAD if row_wise else OptimType .EXACT_ADAGRAD
125
116
126
- # Determine the embedding location
127
- embedding_location = str_to_embedding_location (str (managed ))
128
- if embedding_location is EmbeddingLocation .DEVICE and not torch .cuda .is_available ():
129
- embedding_location = EmbeddingLocation .HOST
130
-
131
- # Determine the pooling mode
132
- pooling_mode = str_to_pooling_mode (pooling )
133
-
134
117
# Construct the common split arguments for the embedding op
135
- common_split_args : Dict [str , Any ] = {
136
- "weights_precision" : weights_precision ,
137
- "stochastic_rounding" : stoc ,
138
- "output_dtype" : output_dtype ,
139
- "pooling_mode" : pooling_mode ,
140
- "bounds_check_mode" : BoundsCheckMode (bounds_check_mode ),
141
- "uvm_host_mapped" : uvm_host_mapped ,
118
+ common_split_args : Dict [str , Any ] = embconfig .split_args () | {
142
119
"optimizer" : optimizer ,
143
120
"learning_rate" : 0.1 ,
144
121
"eps" : 0.1 ,
@@ -154,7 +131,7 @@ def device( # noqa C901
154
131
)
155
132
for d in Ds
156
133
],
157
- pooling_mode = pooling_mode ,
134
+ pooling_mode = embconfig . pooling_mode ,
158
135
use_cpu = not torch .cuda .is_available (),
159
136
)
160
137
elif emb_op_type == "ssd" :
@@ -177,7 +154,7 @@ def device( # noqa C901
177
154
(
178
155
tbeconfig .E ,
179
156
d ,
180
- embedding_location ,
157
+ embconfig . embedding_location ,
181
158
(
182
159
ComputeDevice .CUDA
183
160
if torch .cuda .is_available ()
@@ -187,25 +164,27 @@ def device( # noqa C901
187
164
for d in Ds
188
165
],
189
166
cache_precision = (
190
- weights_precision if cache_precision is None else cache_precision
167
+ embconfig .weights_dtype
168
+ if embconfig .cache_dtype is None
169
+ else embconfig .cache_dtype
191
170
),
192
171
cache_algorithm = CacheAlgorithm .LRU ,
193
172
cache_load_factor = cache_load_factor ,
194
173
** common_split_args ,
195
174
)
196
175
embedding_op = embedding_op .to (get_device ())
197
176
198
- if weights_precision == SparseType .INT8 :
177
+ if embconfig . weights_dtype == SparseType .INT8 :
199
178
# pyre-fixme[29]: `Union[(self: DenseTableBatchedEmbeddingBagsCodegen,
200
179
# min_val: float, max_val: float) -> None, (self:
201
180
# SplitTableBatchedEmbeddingBagsCodegen, min_val: float, max_val: float) ->
202
181
# None, Tensor, Module]` is not a function.
203
182
embedding_op .init_embedding_weights_uniform (- 0.0003 , 0.0003 )
204
183
205
184
nparams = sum (d * tbeconfig .E for d in Ds )
206
- param_size_multiplier = weights_precision .bit_rate () / 8.0
207
- output_size_multiplier = output_dtype .bit_rate () / 8.0
208
- if pooling_mode .do_pooling ():
185
+ param_size_multiplier = embconfig . weights_dtype .bit_rate () / 8.0
186
+ output_size_multiplier = embconfig . output_dtype .bit_rate () / 8.0
187
+ if embconfig . pooling_mode .do_pooling ():
209
188
read_write_bytes = (
210
189
output_size_multiplier * tbeconfig .batch_params .B * sum (Ds )
211
190
+ param_size_multiplier
@@ -225,7 +204,7 @@ def device( # noqa C901
225
204
* tbeconfig .pooling_params .L
226
205
)
227
206
228
- logging .info (f"Managed option: { managed } " )
207
+ logging .info (f"Managed option: { embconfig . embedding_location } " )
229
208
logging .info (
230
209
f"Embedding parameters: { nparams / 1.0e9 : .2f} GParam, "
231
210
f"{ nparams * param_size_multiplier / 1.0e9 : .2f} GB"
@@ -274,11 +253,11 @@ def _context_factory(on_trace_ready: Callable[[profile], None]):
274
253
f"T: { time_per_iter * 1.0e6 :.0f} us"
275
254
)
276
255
277
- if output_dtype == SparseType .INT8 :
256
+ if embconfig . output_dtype == SparseType .INT8 :
278
257
# backward bench not representative
279
258
return
280
259
281
- if pooling_mode .do_pooling ():
260
+ if embconfig . pooling_mode .do_pooling ():
282
261
grad_output = torch .randn (tbeconfig .batch_params .B , sum (Ds )).to (get_device ())
283
262
else :
284
263
grad_output = torch .randn (
0 commit comments