Cleanups for the EEG-based TBE benchmark CLI, pt 2 #3815

Status: Closed · wants to merge 2 commits
8 changes: 4 additions & 4 deletions .github/scripts/fbgemm_gpu_benchmarks.bash
@@ -67,10 +67,10 @@ run_tbe_microbench () {
     --tbe-pooling-size 55 \
     --tbe-num-embeddings 10000000 \
     --tbe-num-tables 1 \
-    --weights-precision fp16 \
-    --cache-precision "${cache_type}" \
-    --output-dtype bf16 \
-    --managed="${managed}" \
+    --emb-weights-dtype fp16 \
+    --emb-cache-dtype "${cache_type}" \
+    --emb-output-dtype bf16 \
+    --emb-location "${managed}" \
     --row-wise
 }

1 change: 1 addition & 0 deletions fbgemm_gpu/bench/tbe/tbe_inference_benchmark.py
@@ -48,6 +48,7 @@
 from torch.profiler import profile

 logging.basicConfig(level=logging.DEBUG)
+reporter = BenchmarkReporter(True)


 def kineto_trace_profiler(p: profile, trace_info: tuple[str, str, str, str]) -> float:
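The only confirmed facts about the new module-level reporter are the `BenchmarkReporter(True)` call above and the re-export added to `fbgemm_gpu/tbe/bench/__init__.py` below; the import path in this sketch follows that re-export, and the meaning of the boolean argument (assumed to be an enable flag) is a guess:

```python
# Sketch only: import path comes from the __init__.py re-export in this PR;
# the constructor argument is assumed to toggle report emission.
from fbgemm_gpu.tbe.bench import BenchmarkReporter

reporter = BenchmarkReporter(True)
```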
86 changes: 32 additions & 54 deletions fbgemm_gpu/bench/tbe/tbe_training_benchmark.py
@@ -19,11 +19,8 @@
 import torch
 from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType, SparseType
 from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
-    BoundsCheckMode,
     CacheAlgorithm,
     EmbeddingLocation,
-    str_to_embedding_location,
-    str_to_pooling_mode,
 )
 from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
     ComputeDevice,
@@ -32,6 +29,7 @@
 )
 from fbgemm_gpu.tbe.bench import (
     benchmark_requests,
+    EmbeddingOpsCommonConfigLoader,
     TBEBenchmarkingConfigLoader,
     TBEDataConfigLoader,
 )
@@ -50,50 +48,39 @@ def cli() -> None:


 @cli.command()
-@click.option("--weights-precision", type=SparseType, default=SparseType.FP32)
-@click.option("--cache-precision", type=SparseType, default=None)
-@click.option("--stoc", is_flag=True, default=False)
-@click.option(
-    "--managed",
-    default="device",
-    type=click.Choice(["device", "managed", "managed_caching"], case_sensitive=False),
-)
 @click.option(
     "--emb-op-type",
     default="split",
     type=click.Choice(["split", "dense", "ssd"], case_sensitive=False),
     help="The type of the embedding op to benchmark",
 )
-@click.option("--row-wise/--no-row-wise", default=True)
-@click.option("--pooling", type=str, default="sum")
-@click.option("--weighted-num-requires-grad", type=int, default=None)
-@click.option("--bounds-check-mode", type=int, default=BoundsCheckMode.NONE.value)
-@click.option("--output-dtype", type=SparseType, default=SparseType.FP32)
+@click.option(
+    "--row-wise/--no-row-wise",
+    default=True,
+    help="Whether to use row-wise adagrad optimizer or not",
+)
 @click.option(
-    "--uvm-host-mapped",
-    is_flag=True,
-    default=False,
-    help="Use host mapped UVM buffers in SSD-TBE (malloc+cudaHostRegister)",
+    "--weighted-num-requires-grad",
+    type=int,
+    default=None,
+    help="The number of weighted tables that require gradient",
 )
 @click.option(
-    "--ssd-prefix", type=str, default="/tmp/ssd_benchmark", help="SSD directory prefix"
+    "--ssd-prefix",
+    type=str,
+    default="/tmp/ssd_benchmark",
+    help="SSD directory prefix",
 )
 @click.option("--cache-load-factor", default=0.2)
 @TBEBenchmarkingConfigLoader.options
 @TBEDataConfigLoader.options
+@EmbeddingOpsCommonConfigLoader.options
 @click.pass_context
 def device(  # noqa C901
     context: click.Context,
     emb_op_type: click.Choice,
-    weights_precision: SparseType,
-    cache_precision: Optional[SparseType],
-    stoc: bool,
-    managed: click.Choice,
     row_wise: bool,
-    pooling: str,
     weighted_num_requires_grad: Optional[int],
-    bounds_check_mode: int,
-    output_dtype: SparseType,
-    uvm_host_mapped: bool,
     cache_load_factor: float,
     # SSD params
     ssd_prefix: str,
@@ -110,6 +97,9 @@ def device(  # noqa C901
     # Load TBE data configuration from cli arguments
     tbeconfig = TBEDataConfigLoader.load(context)

+    # Load common embedding op configuration from cli arguments
+    embconfig = EmbeddingOpsCommonConfigLoader.load(context)
+
     # Generate feature_requires_grad
     feature_requires_grad = (
         tbeconfig.generate_feature_requires_grad(weighted_num_requires_grad)
@@ -123,22 +113,8 @@ def device(  # noqa C901
     # Determine the optimizer
     optimizer = OptimType.EXACT_ROWWISE_ADAGRAD if row_wise else OptimType.EXACT_ADAGRAD

-    # Determine the embedding location
-    embedding_location = str_to_embedding_location(str(managed))
-    if embedding_location is EmbeddingLocation.DEVICE and not torch.cuda.is_available():
-        embedding_location = EmbeddingLocation.HOST
-
-    # Determine the pooling mode
-    pooling_mode = str_to_pooling_mode(pooling)
-
     # Construct the common split arguments for the embedding op
-    common_split_args: Dict[str, Any] = {
-        "weights_precision": weights_precision,
-        "stochastic_rounding": stoc,
-        "output_dtype": output_dtype,
-        "pooling_mode": pooling_mode,
-        "bounds_check_mode": BoundsCheckMode(bounds_check_mode),
-        "uvm_host_mapped": uvm_host_mapped,
+    common_split_args: Dict[str, Any] = embconfig.split_args() | {
         "optimizer": optimizer,
         "learning_rate": 0.1,
         "eps": 0.1,
@@ -154,7 +130,7 @@ def device(  # noqa C901
                 )
                 for d in Ds
             ],
-            pooling_mode=pooling_mode,
+            pooling_mode=embconfig.pooling_mode,
             use_cpu=not torch.cuda.is_available(),
         )
     elif emb_op_type == "ssd":
@@ -177,7 +153,7 @@ def device(  # noqa C901
                 (
                     tbeconfig.E,
                     d,
-                    embedding_location,
+                    embconfig.embedding_location,
                     (
                         ComputeDevice.CUDA
                         if torch.cuda.is_available()
@@ -187,25 +163,27 @@ def device(  # noqa C901
                 for d in Ds
             ],
             cache_precision=(
-                weights_precision if cache_precision is None else cache_precision
+                embconfig.weights_dtype
+                if embconfig.cache_dtype is None
+                else embconfig.cache_dtype
             ),
             cache_algorithm=CacheAlgorithm.LRU,
             cache_load_factor=cache_load_factor,
             **common_split_args,
         )
     embedding_op = embedding_op.to(get_device())

-    if weights_precision == SparseType.INT8:
+    if embconfig.weights_dtype == SparseType.INT8:
         # pyre-fixme[29]: `Union[(self: DenseTableBatchedEmbeddingBagsCodegen,
         #  min_val: float, max_val: float) -> None, (self:
         #  SplitTableBatchedEmbeddingBagsCodegen, min_val: float, max_val: float) ->
         #  None, Tensor, Module]` is not a function.
         embedding_op.init_embedding_weights_uniform(-0.0003, 0.0003)

     nparams = sum(d * tbeconfig.E for d in Ds)
-    param_size_multiplier = weights_precision.bit_rate() / 8.0
-    output_size_multiplier = output_dtype.bit_rate() / 8.0
-    if pooling_mode.do_pooling():
+    param_size_multiplier = embconfig.weights_dtype.bit_rate() / 8.0
+    output_size_multiplier = embconfig.output_dtype.bit_rate() / 8.0
+    if embconfig.pooling_mode.do_pooling():
         read_write_bytes = (
             output_size_multiplier * tbeconfig.batch_params.B * sum(Ds)
             + param_size_multiplier
@@ -225,7 +203,7 @@ def device(  # noqa C901
             * tbeconfig.pooling_params.L
         )

-    logging.info(f"Managed option: {managed}")
+    logging.info(f"Managed option: {embconfig.embedding_location}")
     logging.info(
         f"Embedding parameters: {nparams / 1.0e9: .2f} GParam, "
         f"{nparams * param_size_multiplier / 1.0e9: .2f} GB"
@@ -274,11 +252,11 @@ def _context_factory(on_trace_ready: Callable[[profile], None]):
         f"T: {time_per_iter * 1.0e6:.0f}us"
     )

-    if output_dtype == SparseType.INT8:
+    if embconfig.output_dtype == SparseType.INT8:
         # backward bench not representative
         return

-    if pooling_mode.do_pooling():
+    if embconfig.pooling_mode.do_pooling():
         grad_output = torch.randn(tbeconfig.batch_params.B, sum(Ds)).to(get_device())
     else:
         grad_output = torch.randn(
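The benchmark now follows the same pattern as `TBEBenchmarkingConfigLoader` and `TBEDataConfigLoader`: a loader class exposes an `.options` decorator that registers a group of click options on the command, and a `load(context)` classmethod that assembles the parsed values into a config object. The sketch below shows the shape of that pattern under stated assumptions only; it is not the actual `EmbeddingOpsCommonConfigLoader` implementation. The flag names are taken from the bash script change above (`--emb-weights-dtype`, `--emb-output-dtype`), and the config fields and `split_args()` method are inferred from how `embconfig` is used in this diff:

```python
import dataclasses
from typing import Any, Callable, Dict

import click


@dataclasses.dataclass(frozen=True)
class EmbOpsCommonConfigSketch:
    """Stand-in for the real config object; fields inferred from the diff."""

    weights_dtype: str
    output_dtype: str

    def split_args(self) -> Dict[str, Any]:
        # Common kwargs later dict-merged into the embedding-op constructor call.
        return {
            "weights_precision": self.weights_dtype,
            "output_dtype": self.output_dtype,
        }


class EmbOpsCommonConfigLoaderSketch:
    @classmethod
    def options(cls, func: Callable) -> Callable:
        # Applied as @EmbOpsCommonConfigLoaderSketch.options: registers the
        # whole option group on the decorated click command in one shot.
        group = [
            click.option("--emb-weights-dtype", default="fp32"),
            click.option("--emb-output-dtype", default="fp32"),
        ]
        for option in reversed(group):
            func = option(func)
        return func

    @classmethod
    def load(cls, context: click.Context) -> EmbOpsCommonConfigSketch:
        # click exposes the parsed option values on the invocation context.
        return EmbOpsCommonConfigSketch(
            weights_dtype=context.params["emb_weights_dtype"],
            output_dtype=context.params["emb_output_dtype"],
        )
```

Under this pattern the `device` command just stacks `@...options` alongside the other loaders and calls `load(context)` in the body, which is exactly the shape visible in the hunks above; it also explains why the old standalone options and function parameters could be deleted wholesale.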
50 changes: 26 additions & 24 deletions fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py
@@ -33,19 +33,20 @@ class EmbeddingLocation(enum.IntEnum):
     HOST = 3
     MTIA = 4

-
-def str_to_embedding_location(key: str) -> EmbeddingLocation:
-    lookup = {
-        "device": EmbeddingLocation.DEVICE,
-        "managed": EmbeddingLocation.MANAGED,
-        "managed_caching": EmbeddingLocation.MANAGED_CACHING,
-        "host": EmbeddingLocation.HOST,
-        "mtia": EmbeddingLocation.MTIA,
-    }
-    if key in lookup:
-        return lookup[key]
-    else:
-        raise ValueError(f"Cannot parse value into EmbeddingLocation: {key}")
+    @classmethod
+    # pyre-ignore[3]
+    def from_str(cls, key: str):
+        lookup = {
+            "device": EmbeddingLocation.DEVICE,
+            "managed": EmbeddingLocation.MANAGED,
+            "managed_caching": EmbeddingLocation.MANAGED_CACHING,
+            "host": EmbeddingLocation.HOST,
+            "mtia": EmbeddingLocation.MTIA,
+        }
+        if key in lookup:
+            return lookup[key]
+        else:
+            raise ValueError(f"Cannot parse value into EmbeddingLocation: {key}")


 class CacheAlgorithm(enum.Enum):
class CacheAlgorithm(enum.Enum):
Expand Down Expand Up @@ -74,17 +75,18 @@ class PoolingMode(enum.IntEnum):
def do_pooling(self) -> bool:
return self is not PoolingMode.NONE


def str_to_pooling_mode(key: str) -> PoolingMode:
lookup = {
"sum": PoolingMode.SUM,
"mean": PoolingMode.MEAN,
"none": PoolingMode.NONE,
}
if key in lookup:
return lookup[key]
else:
raise ValueError(f"Cannot parse value into PoolingMode: {key}")
@classmethod
# pyre-ignore[3]
def from_str(cls, key: str):
lookup = {
"sum": PoolingMode.SUM,
"mean": PoolingMode.MEAN,
"none": PoolingMode.NONE,
}
if key in lookup:
return lookup[key]
else:
raise ValueError(f"Cannot parse value into PoolingMode: {key}")


class BoundsCheckMode(enum.IntEnum):
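With the parsers moved onto the enums as classmethods, call sites switch from the free functions to `EnumType.from_str(...)`; behavior is unchanged, including the `ValueError` on unrecognized keys. A quick usage sketch based directly on the definitions above:

```python
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
    EmbeddingLocation,
    PoolingMode,
)

# Replaces str_to_embedding_location("managed_caching")
location = EmbeddingLocation.from_str("managed_caching")
assert location is EmbeddingLocation.MANAGED_CACHING

# Replaces str_to_pooling_mode(...); unknown keys still raise ValueError
assert PoolingMode.from_str("mean").do_pooling()
assert not PoolingMode.from_str("none").do_pooling()
```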
11 changes: 8 additions & 3 deletions fbgemm_gpu/fbgemm_gpu/tbe/bench/__init__.py
@@ -19,12 +19,17 @@
     benchmark_requests_refer,
     benchmark_vbe,
 )
-from .config import TBEDataConfig  # noqa F401
-from .config_loader import TBEDataConfigLoader  # noqa F401
-from .config_param_models import BatchParams, IndicesParams, PoolingParams  # noqa F401
+from .embedding_ops_common_config import EmbeddingOpsCommonConfigLoader  # noqa F401
 from .eval_compression import (  # noqa F401
     benchmark_eval_compression,
     EvalCompressionBenchmarkOutput,
 )
 from .reporter import BenchmarkReporter  # noqa F401
+from .tbe_data_config import TBEDataConfig  # noqa F401
+from .tbe_data_config_loader import TBEDataConfigLoader  # noqa F401
+from .tbe_data_config_param_models import (  # noqa F401
+    BatchParams,
+    IndicesParams,
+    PoolingParams,
+)
 from .utils import fill_random_scale_bias  # noqa F401
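Because the package `__init__.py` re-exports the moved symbols, downstream code that imports through the package root should be unaffected by the `config*` → `tbe_data_config*` module renames; only imports that reached into the old module paths directly would break. For example:

```python
# Unaffected: these names resolve via the package __init__,
# which re-exports them from the renamed modules.
from fbgemm_gpu.tbe.bench import (
    BatchParams,
    TBEDataConfig,
    TBEDataConfigLoader,
)

# Broken by this diff: the module was renamed to tbe_data_config.
# from fbgemm_gpu.tbe.bench.config import TBEDataConfig
```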