
Commit d619858

q10 authored and facebook-github-bot committed
Cleanups for the EEG-based TBE benchmark CLI, pt 2 (pytorch#3815)
Summary:
Pull Request resolved: pytorch#3815

X-link: facebookresearch/FBGEMM#890

- Cleanups for the EEG-based TBE benchmark CLI, pt 2

Reviewed By: jiawenliu64

Differential Revision: D70426271
1 parent 05d089a commit d619858

8 files changed: +213 −80 lines

fbgemm_gpu/bench/tbe/tbe_training_benchmark.py

+32 −53
@@ -19,10 +19,8 @@
 import torch
 from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType, SparseType
 from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
-    BoundsCheckMode,
     CacheAlgorithm,
     EmbeddingLocation,
-    str_to_embedding_location,
     str_to_pooling_mode,
 )
 from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
@@ -32,6 +30,7 @@
 )
 from fbgemm_gpu.tbe.bench import (
     benchmark_requests,
+    EmbeddingOpsCommonConfigLoader,
     TBEBenchmarkingConfigLoader,
     TBEDataConfigLoader,
 )
@@ -50,50 +49,39 @@ def cli() -> None:


 @cli.command()
-@click.option("--weights-precision", type=SparseType, default=SparseType.FP32)
-@click.option("--cache-precision", type=SparseType, default=None)
-@click.option("--stoc", is_flag=True, default=False)
-@click.option(
-    "--managed",
-    default="device",
-    type=click.Choice(["device", "managed", "managed_caching"], case_sensitive=False),
-)
 @click.option(
     "--emb-op-type",
     default="split",
     type=click.Choice(["split", "dense", "ssd"], case_sensitive=False),
+    help="The type of the embedding op to benchmark",
+)
+@click.option(
+    "--row-wise/--no-row-wise",
+    default=True,
+    help="Whether to use row-wise adagrad optimizer or not",
 )
-@click.option("--row-wise/--no-row-wise", default=True)
-@click.option("--pooling", type=str, default="sum")
-@click.option("--weighted-num-requires-grad", type=int, default=None)
-@click.option("--bounds-check-mode", type=int, default=BoundsCheckMode.NONE.value)
-@click.option("--output-dtype", type=SparseType, default=SparseType.FP32)
 @click.option(
-    "--uvm-host-mapped",
-    is_flag=True,
-    default=False,
-    help="Use host mapped UVM buffers in SSD-TBE (malloc+cudaHostRegister)",
+    "--weighted-num-requires-grad",
+    type=int,
+    default=None,
+    help="The number of weighted tables that require gradient",
 )
 @click.option(
-    "--ssd-prefix", type=str, default="/tmp/ssd_benchmark", help="SSD directory prefix"
+    "--ssd-prefix",
+    type=str,
+    default="/tmp/ssd_benchmark",
+    help="SSD directory prefix",
 )
 @click.option("--cache-load-factor", default=0.2)
 @TBEBenchmarkingConfigLoader.options
 @TBEDataConfigLoader.options
+@EmbeddingOpsCommonConfigLoader.options
 @click.pass_context
 def device(  # noqa C901
     context: click.Context,
     emb_op_type: click.Choice,
-    weights_precision: SparseType,
-    cache_precision: Optional[SparseType],
-    stoc: bool,
-    managed: click.Choice,
     row_wise: bool,
-    pooling: str,
     weighted_num_requires_grad: Optional[int],
-    bounds_check_mode: int,
-    output_dtype: SparseType,
-    uvm_host_mapped: bool,
     cache_load_factor: float,
     # SSD params
     ssd_prefix: str,
@@ -110,6 +98,9 @@ def device( # noqa C901
     # Load TBE data configuration from cli arguments
     tbeconfig = TBEDataConfigLoader.load(context)

+    # Load common embedding op configuration from cli arguments
+    embconfig = EmbeddingOpsCommonConfigLoader.load(context)
+
     # Generate feature_requires_grad
     feature_requires_grad = (
         tbeconfig.generate_feature_requires_grad(weighted_num_requires_grad)
@@ -123,22 +114,8 @@ def device( # noqa C901
     # Determine the optimizer
     optimizer = OptimType.EXACT_ROWWISE_ADAGRAD if row_wise else OptimType.EXACT_ADAGRAD

-    # Determine the embedding location
-    embedding_location = str_to_embedding_location(str(managed))
-    if embedding_location is EmbeddingLocation.DEVICE and not torch.cuda.is_available():
-        embedding_location = EmbeddingLocation.HOST
-
-    # Determine the pooling mode
-    pooling_mode = str_to_pooling_mode(pooling)
-
     # Construct the common split arguments for the embedding op
-    common_split_args: Dict[str, Any] = {
-        "weights_precision": weights_precision,
-        "stochastic_rounding": stoc,
-        "output_dtype": output_dtype,
-        "pooling_mode": pooling_mode,
-        "bounds_check_mode": BoundsCheckMode(bounds_check_mode),
-        "uvm_host_mapped": uvm_host_mapped,
+    common_split_args: Dict[str, Any] = embconfig.split_args() | {
         "optimizer": optimizer,
         "learning_rate": 0.1,
         "eps": 0.1,
@@ -154,7 +131,7 @@ def device( # noqa C901
                 )
                 for d in Ds
             ],
-            pooling_mode=pooling_mode,
+            pooling_mode=embconfig.pooling_mode,
             use_cpu=not torch.cuda.is_available(),
         )
     elif emb_op_type == "ssd":
@@ -177,7 +154,7 @@ def device( # noqa C901
                 (
                     tbeconfig.E,
                     d,
-                    embedding_location,
+                    embconfig.embedding_location,
                     (
                         ComputeDevice.CUDA
                         if torch.cuda.is_available()
@@ -187,25 +164,27 @@ def device( # noqa C901
                 for d in Ds
             ],
             cache_precision=(
-                weights_precision if cache_precision is None else cache_precision
+                embconfig.weights_dtype
+                if embconfig.cache_dtype is None
+                else embconfig.cache_dtype
             ),
             cache_algorithm=CacheAlgorithm.LRU,
             cache_load_factor=cache_load_factor,
             **common_split_args,
         )
     embedding_op = embedding_op.to(get_device())

-    if weights_precision == SparseType.INT8:
+    if embconfig.weights_dtype == SparseType.INT8:
         # pyre-fixme[29]: `Union[(self: DenseTableBatchedEmbeddingBagsCodegen,
         #  min_val: float, max_val: float) -> None, (self:
         #  SplitTableBatchedEmbeddingBagsCodegen, min_val: float, max_val: float) ->
         #  None, Tensor, Module]` is not a function.
         embedding_op.init_embedding_weights_uniform(-0.0003, 0.0003)

     nparams = sum(d * tbeconfig.E for d in Ds)
-    param_size_multiplier = weights_precision.bit_rate() / 8.0
-    output_size_multiplier = output_dtype.bit_rate() / 8.0
-    if pooling_mode.do_pooling():
+    param_size_multiplier = embconfig.weights_dtype.bit_rate() / 8.0
+    output_size_multiplier = embconfig.output_dtype.bit_rate() / 8.0
+    if embconfig.pooling_mode.do_pooling():
         read_write_bytes = (
             output_size_multiplier * tbeconfig.batch_params.B * sum(Ds)
             + param_size_multiplier
@@ -225,7 +204,7 @@ def device( # noqa C901
             * tbeconfig.pooling_params.L
         )

-    logging.info(f"Managed option: {managed}")
+    logging.info(f"Managed option: {embconfig.embedding_location}")
     logging.info(
         f"Embedding parameters: {nparams / 1.0e9: .2f} GParam, "
         f"{nparams * param_size_multiplier / 1.0e9: .2f} GB"
@@ -274,11 +253,11 @@ def _context_factory(on_trace_ready: Callable[[profile], None]):
         f"T: {time_per_iter * 1.0e6:.0f}us"
     )

-    if output_dtype == SparseType.INT8:
+    if embconfig.output_dtype == SparseType.INT8:
         # backward bench not representative
         return

-    if pooling_mode.do_pooling():
+    if embconfig.pooling_mode.do_pooling():
         grad_output = torch.randn(tbeconfig.batch_params.B, sum(Ds)).to(get_device())
     else:
         grad_output = torch.randn(
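
The main simplification in this file is that the hand-built common_split_args dict is now produced by EmbeddingOpsCommonConfig.split_args() and merged with the optimizer-specific entries via Python's dict union operator (|, available since Python 3.9). A minimal sketch of that merge semantics, with placeholder values standing in for the real config object:

# Sketch of the dict-union merge used to build common_split_args.
# The values here are placeholders, not the benchmark's actual arguments.
split_args = {
    "weights_precision": "fp32",
    "stochastic_rounding": False,
    "output_dtype": "fp32",
}

common_split_args = split_args | {
    "optimizer": "EXACT_ROWWISE_ADAGRAD",
    "learning_rate": 0.1,
    "eps": 0.1,
}

# The union produces a new dict; entries on the right are added (and would
# override duplicate keys) without mutating split_args itself.
assert common_split_args["optimizer"] == "EXACT_ROWWISE_ADAGRAD"
assert "optimizer" not in split_args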

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py

+14 −13
@@ -33,19 +33,20 @@ class EmbeddingLocation(enum.IntEnum):
     HOST = 3
     MTIA = 4

-
-def str_to_embedding_location(key: str) -> EmbeddingLocation:
-    lookup = {
-        "device": EmbeddingLocation.DEVICE,
-        "managed": EmbeddingLocation.MANAGED,
-        "managed_caching": EmbeddingLocation.MANAGED_CACHING,
-        "host": EmbeddingLocation.HOST,
-        "mtia": EmbeddingLocation.MTIA,
-    }
-    if key in lookup:
-        return lookup[key]
-    else:
-        raise ValueError(f"Cannot parse value into EmbeddingLocation: {key}")
+    @classmethod
+    # pyre-ignore[3]
+    def from_str(cls, key: str):
+        lookup = {
+            "device": EmbeddingLocation.DEVICE,
+            "managed": EmbeddingLocation.MANAGED,
+            "managed_caching": EmbeddingLocation.MANAGED_CACHING,
+            "host": EmbeddingLocation.HOST,
+            "mtia": EmbeddingLocation.MTIA,
+        }
+        if key in lookup:
+            return lookup[key]
+        else:
+            raise ValueError(f"Cannot parse value into EmbeddingLocation: {key}")


 class CacheAlgorithm(enum.Enum):
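
With the free function folded into the enum, callers resolve a location string through the new classmethod instead of str_to_embedding_location. A short usage sketch (assuming fbgemm_gpu is importable):

from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation

# Known keys map to the corresponding enum member.
assert EmbeddingLocation.from_str("managed_caching") is EmbeddingLocation.MANAGED_CACHING

# Unknown keys raise ValueError, matching the old helper's behavior.
try:
    EmbeddingLocation.from_str("gpu")
except ValueError as err:
    print(err)  # Cannot parse value into EmbeddingLocation: gpu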

fbgemm_gpu/fbgemm_gpu/tbe/bench/__init__.py

+8 −3
@@ -19,12 +19,17 @@
     benchmark_requests_refer,
     benchmark_vbe,
 )
-from .config import TBEDataConfig  # noqa F401
-from .config_loader import TBEDataConfigLoader  # noqa F401
-from .config_param_models import BatchParams, IndicesParams, PoolingParams  # noqa F401
+from .embedding_ops_common_config import EmbeddingOpsCommonConfigLoader  # noqa F401
 from .eval_compression import (  # noqa F401
     benchmark_eval_compression,
     EvalCompressionBenchmarkOutput,
 )
 from .reporter import BenchmarkReporter  # noqa F401
+from .tbe_data_config import TBEDataConfig  # noqa F401
+from .tbe_data_config_loader import TBEDataConfigLoader  # noqa F401
+from .tbe_data_config_param_models import (  # noqa F401
+    BatchParams,
+    IndicesParams,
+    PoolingParams,
+)
 from .utils import fill_random_scale_bias  # noqa F401
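
The config modules are renamed here (config to tbe_data_config, config_loader to tbe_data_config_loader, config_param_models to tbe_data_config_param_models) and the new loader is re-exported, but the package-level import surface is unchanged, so call sites such as the benchmark CLI can keep importing from fbgemm_gpu.tbe.bench:

# Existing names still resolve through the package __init__, alongside the
# newly exported EmbeddingOpsCommonConfigLoader.
from fbgemm_gpu.tbe.bench import (
    BatchParams,
    EmbeddingOpsCommonConfigLoader,
    IndicesParams,
    PoolingParams,
    TBEDataConfig,
    TBEDataConfigLoader,
)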
fbgemm_gpu/fbgemm_gpu/tbe/bench/embedding_ops_common_config.py (new file)

+151 −0
@@ -0,0 +1,151 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import dataclasses
+from typing import Any, Dict, Optional
+
+import click
+import torch
+from fbgemm_gpu.split_embedding_configs import SparseType
+from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
+    BoundsCheckMode,
+    EmbeddingLocation,
+    PoolingMode,
+    str_to_pooling_mode,
+)
+
+
+@dataclasses.dataclass(frozen=True)
+class EmbeddingOpsCommonConfig:
+    # Precision of the embedding weights
+    weights_dtype: SparseType
+    # Precision of the embedding cache
+    cache_dtype: Optional[SparseType]
+    # Precision of the embedding output
+    output_dtype: SparseType
+    # Enable stochastic rounding when performing quantization
+    stochastic_rounding: bool
+    # Pooling operation to perform
+    pooling_mode: PoolingMode
+    # Use host-mapped UVM buffers
+    uvm_host_mapped: bool
+    # Memory location of the embeddings
+    embedding_location: EmbeddingLocation
+    # Bounds check mode
+    bounds_check_mode: BoundsCheckMode
+
+    # pyre-ignore [3]
+    def validate(self):
+        return self
+
+    def split_args(self) -> Dict[str, Any]:
+        return {
+            "weights_precision": self.weights_dtype,
+            "stochastic_rounding": self.stochastic_rounding,
+            "output_dtype": self.output_dtype,
+            "pooling_mode": self.pooling_mode,
+            "bounds_check_mode": self.bounds_check_mode,
+            "uvm_host_mapped": self.uvm_host_mapped,
+        }
+
+
+class EmbeddingOpsCommonConfigLoader:
+    @classmethod
+    # pyre-ignore [2]
+    def options(cls, func) -> click.Command:
+        options = [
+            click.option(
+                "--emb-weights-dtype",
+                type=SparseType,
+                default=SparseType.FP32,
+                help="Precision of the embedding weights",
+            ),
+            click.option(
+                "--emb-cache-dtype",
+                type=SparseType,
+                default=None,
+                help="Precision of the embedding cache",
+            ),
+            click.option(
+                "--emb-output-dtype",
+                type=SparseType,
+                default=SparseType.FP32,
+                help="Precision of the embedding output",
+            ),
+            click.option(
+                "--emb-stochastic-rounding",
+                is_flag=True,
+                default=False,
+                help="Enable stochastic rounding when performing quantization",
+            ),
+            click.option(
+                "--emb-pooling-mode",
+                type=click.Choice(["sum", "mean", "none"], case_sensitive=False),
+                default="sum",
+                help="Pooling operation to perform",
+            ),
+            click.option(
+                "--emb-uvm-host-mapped",
+                is_flag=True,
+                default=False,
+                help="Use host-mapped UVM buffers",
+            ),
+            click.option(
+                "--emb-location",
+                default="device",
+                type=click.Choice(
+                    ["device", "managed", "managed_caching"], case_sensitive=False
+                ),
+                help="Memory location of the embeddings",
+            ),
+            click.option(
+                "--emb-bounds-check",
+                type=int,
+                default=BoundsCheckMode.WARNING.value,
+                help="Bounds check mode. "
+                f"Available modes: FATAL={BoundsCheckMode.FATAL.value}, "
+                f"WARNING={BoundsCheckMode.WARNING.value}, "
+                f"IGNORE={BoundsCheckMode.IGNORE.value}, "
+                f"NONE={BoundsCheckMode.NONE.value}",
+            ),
+        ]
+
+        for option in reversed(options):
+            func = option(func)
+        return func
+
+    @classmethod
+    def load(cls, context: click.Context) -> EmbeddingOpsCommonConfig:
+        params = context.params
+
+        weights_dtype = params["emb_weights_dtype"]
+        cache_dtype = params["emb_cache_dtype"]
+        output_dtype = params["emb_output_dtype"]
+        stochastic_rounding = params["emb_stochastic_rounding"]
+        pooling_mode = str_to_pooling_mode(str(params["emb_pooling_mode"]))
+        uvm_host_mapped = params["emb_uvm_host_mapped"]
+        bounds_check_mode = BoundsCheckMode(params["emb_bounds_check"])
+
+        embedding_location = EmbeddingLocation.from_str(str(params["emb_location"]))
+        if (
+            embedding_location is EmbeddingLocation.DEVICE
+            and not torch.cuda.is_available()
+        ):
+            embedding_location = EmbeddingLocation.HOST
+
+        return EmbeddingOpsCommonConfig(
+            weights_dtype,
+            cache_dtype,
+            output_dtype,
+            stochastic_rounding,
+            pooling_mode,
+            uvm_host_mapped,
+            embedding_location,
+            bounds_check_mode,
+        ).validate()
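
The intended usage pattern, as exercised by the device command above, is to stack EmbeddingOpsCommonConfigLoader.options onto a click command to register the --emb-* flags, then call load(context) inside the command body to get a frozen EmbeddingOpsCommonConfig back. A minimal sketch (the command name bench and its body are illustrative only):

import click

from fbgemm_gpu.tbe.bench import EmbeddingOpsCommonConfigLoader


@click.command()
@EmbeddingOpsCommonConfigLoader.options  # registers the --emb-* options
@click.pass_context
def bench(context: click.Context, **kwargs) -> None:
    # **kwargs absorbs the individual --emb-* parameters that click passes in;
    # the loader re-reads them from the click context instead.
    embconfig = EmbeddingOpsCommonConfigLoader.load(context)
    click.echo(embconfig.split_args())


if __name__ == "__main__":
    bench()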
