Commit

delete delayed scaling from torchao.float8 (#1753)
Update

[ghstack-poisoned]
vkuzo authored Feb 22, 2025
1 parent 25ddb77 commit d370196
Showing 25 changed files with 93 additions and 2,296 deletions.
53 changes: 13 additions & 40 deletions benchmarks/float8/bench_linear_float8.py
@@ -23,10 +23,6 @@
     ScalingType,
 )
 from torchao.float8.float8_linear import Float8Linear
-from torchao.float8.float8_linear_utils import (
-    linear_requires_sync,
-    sync_float8_amax_and_scale_history,
-)
 from torchao.float8.float8_tensor import ScaledMMConfig
 
 # estimating TOPs for matmuls in fp32, fp16, fp8
@@ -122,39 +118,18 @@ def main(
     scaling_type_grad_output = ScalingType(scaling_type_grad_output)
     scaling_granularity = ScalingGranularity(scaling_granularity)
 
-    if scaling_type_input is ScalingType.STATIC:
-        cast_config_input = CastConfig(
-            scaling_type=scaling_type_input,
-            static_scale=torch.tensor([1.0], device="cuda"),
-            scaling_granularity=scaling_granularity,
-        )
-    else:
-        cast_config_input = CastConfig(
-            scaling_type=scaling_type_input,
-            scaling_granularity=scaling_granularity,
-        )
-    if scaling_type_weight is ScalingType.STATIC:
-        cast_config_weight = CastConfig(
-            scaling_type=scaling_type_weight,
-            static_scale=torch.tensor([1.0], device="cuda"),
-            scaling_granularity=scaling_granularity,
-        )
-    else:
-        cast_config_weight = CastConfig(
-            scaling_type=scaling_type_weight,
-            scaling_granularity=scaling_granularity,
-        )
-    if scaling_type_grad_output is ScalingType.STATIC:
-        cast_config_grad_output = CastConfig(
-            scaling_type=scaling_type_grad_output,
-            static_scale=torch.tensor([1.0], device="cuda"),
-            scaling_granularity=scaling_granularity,
-        )
-    else:
-        cast_config_grad_output = CastConfig(
-            scaling_type=scaling_type_grad_output,
-            scaling_granularity=scaling_granularity,
-        )
+    cast_config_input = CastConfig(
+        scaling_type=scaling_type_input,
+        scaling_granularity=scaling_granularity,
+    )
+    cast_config_weight = CastConfig(
+        scaling_type=scaling_type_weight,
+        scaling_granularity=scaling_granularity,
+    )
+    cast_config_grad_output = CastConfig(
+        scaling_type=scaling_type_grad_output,
+        scaling_granularity=scaling_granularity,
+    )
 
     config = Float8LinearConfig(
         cast_config_input=cast_config_input,
@@ -185,7 +160,7 @@ def main(
         copy.deepcopy(linear_ref),
         config=config,
     )
-    scaling_repr = f"{linear_float8.scaling_type_repr()},{linear_float8.scaling_granularity_repr()}"
+    scaling_repr = linear_float8.extra_repr()
 
     if fast_accum:
         linear_float8.forward_config = ScaledMMConfig(False, True, False)
@@ -196,8 +171,6 @@ def main(
     ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward()
 
     def float8_forw_backward():
-        if linear_requires_sync(config):
-            sync_float8_amax_and_scale_history(linear_float8)
         linear_float8(input_tensor).sum().backward()
 
     def n_times(n, fn, *args, **kwargs):
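With static and delayed scaling removed, the benchmark now builds the same dynamic CastConfig for input, weight, and grad_output, and float8_forw_backward no longer calls sync_float8_amax_and_scale_history before the backward pass. Below is a minimal sketch of the surviving dynamic-scaling setup, using only APIs that appear in this diff; the toy model, tensor shapes, and explicit enum values are illustrative assumptions rather than code from the benchmark.

import copy

import torch
import torch.nn as nn

from torchao.float8 import convert_to_float8_training
from torchao.float8.config import (
    CastConfig,
    Float8LinearConfig,
    ScalingGranularity,
    ScalingType,
)

# Hypothetical toy model; any module containing nn.Linear children works.
m = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024))
m = m.cuda().bfloat16()

# Dynamic is the only remaining tensorwise scaling type, so every CastConfig
# is built the same way: no static_scale and no amax/scale history buffers.
cast_config = CastConfig(
    scaling_type=ScalingType.DYNAMIC,
    scaling_granularity=ScalingGranularity.TENSORWISE,
)
config = Float8LinearConfig(
    cast_config_input=cast_config,
    cast_config_weight=cast_config,
    cast_config_grad_output=cast_config,
)

# Swap nn.Linear modules for Float8Linear and compile, as the benchmark does.
m_fp8 = convert_to_float8_training(copy.deepcopy(m), config=config)
m_fp8 = torch.compile(m_fp8)

x = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)
# Forward + backward with no amax/scale sync step in between.
m_fp8(x).sum().backward()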
180 changes: 0 additions & 180 deletions benchmarks/float8/bench_multi_gpu.py

This file was deleted.

47 changes: 0 additions & 47 deletions benchmarks/float8/float8_roofline.py
@@ -58,9 +58,7 @@
 )
 
 from torchao.float8 import (
-    CastConfig,
     Float8LinearConfig,
-    ScalingType,
     convert_to_float8_training,
 )
 from torchao.float8.roofline_utils import (
@@ -219,24 +217,6 @@ def run(
         scaling_type_weight="dynamic",
         scaling_type_grad_output="dynamic",
     )
-    fp8_mem_time_sympy_del_limit = get_float8_mem_sympy(
-        M,
-        K,
-        N,
-        model_torch_compile_limitations=True,
-        scaling_type_input="delayed",
-        scaling_type_weight="delayed",
-        scaling_type_grad_output="delayed",
-    )
-    fp8_mem_time_sympy_del_nolimit = get_float8_mem_sympy(
-        M,
-        K,
-        N,
-        model_torch_compile_limitations=False,
-        scaling_type_input="delayed",
-        scaling_type_weight="delayed",
-        scaling_type_grad_output="delayed",
-    )
 
     if gemm_time_strategy == "roofline":
         bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16)
@@ -258,16 +238,12 @@ def run(
         # roofline memory overhead estimates
         "fp8_oh_dyn_limit",
         "fp8_oh_dyn_nolimit",
-        "fp8_oh_del_limit",
-        "fp8_oh_del_nolimit",
         # actual e2e measurements
         "bf16_s",
         "fp8_dyn_s",
-        "fp8_del_s",
         "fp8_dyn_axs_s",
         # 'fp8_lw_s',
         "fp8_dyn_sp",
-        "fp8_del_sp",
         "fp8_dyn_axs_sp",
         # 'fp8_lw_sp',
     ]
@@ -309,12 +285,6 @@ def run(
         fp8_mem_time_dyn_nolimit_s = (
             fp8_mem_time_sympy_dyn_nolimit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
         )
-        fp8_mem_time_del_limit_s = (
-            fp8_mem_time_sympy_del_limit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
-        )
-        fp8_mem_time_del_nolimit_s = (
-            fp8_mem_time_sympy_del_nolimit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
-        )
 
         # create the model
         m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16()
@@ -333,19 +303,6 @@ def run(
         m_fp8_dyn = torch.compile(m_fp8_dyn)
         fp8_dyn_time_actual_s = get_gpu_kernel_time(m_fp8_dyn, x)
 
-        # get the float8 delayed scaling gpu kernel time
-        torch._dynamo.reset()
-        config = Float8LinearConfig(
-            enable_amax_init=False,
-            enable_pre_and_post_forward=False,
-            cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
-            cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
-            cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
-        )
-        m_fp8_del = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
-        m_fp8_del = torch.compile(m_fp8_del)
-        fp8_del_time_actual_s = get_gpu_kernel_time(m_fp8_del, x)
-
         # get the float8 dynamic axiswise scaling gpu kernel time
         torch._dynamo.reset()
         config = Float8LinearConfig.from_recipe_name("rowwise")
@@ -374,16 +331,12 @@ def run(
             # roofline overhead estimates
             fp8_mem_time_dyn_limit_s,
             fp8_mem_time_dyn_nolimit_s,
-            fp8_mem_time_del_limit_s,
-            fp8_mem_time_del_nolimit_s,
             # e2e numbers
             bf16_time_actual_s,
             fp8_dyn_time_actual_s,
-            fp8_del_time_actual_s,
             fp8_dyn_axs_time_actual_s,
             # fp8_lw_time_actual_s,
             bf16_time_actual_s / fp8_dyn_time_actual_s,
-            bf16_time_actual_s / fp8_del_time_actual_s,
             bf16_time_actual_s / fp8_dyn_axs_time_actual_s,
             # bf16_time_actual_s / fp8_lw_time_actual_s,
         ]
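After this change the roofline script compares two float8 configurations: plain dynamic tensorwise scaling and the rowwise (axiswise) recipe selected via Float8LinearConfig.from_recipe_name("rowwise"), which remains in the hunks above. Below is a minimal sketch of that recipe path; the stand-in module and shapes are assumptions, and the script's own LNLinearSigmoid model and get_gpu_kernel_time helper are not reproduced here.

import copy

import torch
import torch.nn as nn

from torchao.float8 import Float8LinearConfig, convert_to_float8_training

# Hypothetical stand-in for the script's LNLinearSigmoid model.
m_orig = nn.Sequential(nn.LayerNorm(4096), nn.Linear(4096, 4096), nn.Sigmoid())
m_orig = m_orig.cuda().bfloat16()

# Rowwise (axiswise) dynamic scaling, configured from a named recipe rather
# than from per-tensor CastConfig objects.
torch._dynamo.reset()
config = Float8LinearConfig.from_recipe_name("rowwise")
m_fp8_dyn_axs = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)

x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)
# Profile this forward + backward region to estimate the float8 kernel time.
m_fp8_dyn_axs(x).sum().backward()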