diff --git a/torchtitan/float8_linear.py b/torchtitan/float8_linear.py
index 770531d50..557fca64c 100644
--- a/torchtitan/float8_linear.py
+++ b/torchtitan/float8_linear.py
@@ -12,7 +12,6 @@
 # Note: Performance
 # Float8 experimental is intended to be ran under `torch.compile`` for competitive performance
 
-import contextlib
 import functools
 
 from typing import Optional
@@ -24,20 +23,6 @@
 from torchtitan.logging_utils import logger
 
 
-@contextlib.contextmanager
-def set_enable_fsdp_float8_all_gather(enable_fsdp_fp8_all_gather: bool):
-    import float8_experimental.config as config
-
-    prev = config.enable_fsdp_fp8_all_gather
-    torch.distributed.barrier()
-    config.enable_fsdp_fp8_all_gather = enable_fsdp_fp8_all_gather
-    try:
-        yield
-    finally:
-        torch.distributed.barrier()
-        config.enable_fsdp_fp8_all_gather = prev
-
-
 @functools.lru_cache(None)
 def is_sm90_or_later():
     # Float8 is only supported on H100+ GPUs
@@ -63,21 +48,26 @@ def maybe_build_fp8_linear(
         )
         return
     try:
-        from float8_experimental.float8_linear import TensorScalingType
-        from float8_experimental.float8_linear_utils import (
-            swap_linear_with_float8_linear,
+        from float8_experimental import (
+            CastConfig,
+            convert_to_float8_training,
+            Float8LinearConfig,
+            ScalingType,
         )
 
         # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
         enable_fsdp_float8_all_gather = (
             job_config.training.enable_fsdp_float8_all_gather and dp_enabled
         )
-        with set_enable_fsdp_float8_all_gather(enable_fsdp_float8_all_gather):
-            swap_linear_with_float8_linear(
-                model,
-                scaling_type_w=TensorScalingType.DYNAMIC,
-                skip_fqn_list=["output"],
-            )
+        float8_config = Float8LinearConfig(
+            enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
+            cast_config_weight=CastConfig(scaling_type=ScalingType.DYNAMIC),
+        )
+        convert_to_float8_training(
+            model,
+            config=float8_config,
+            module_filter_fn=lambda mod, fqn: fqn != "output",
+        )
         logger.info(
             f"Swapped to Float8Linear layers with {enable_fsdp_float8_all_gather=}"
         )
@@ -102,6 +92,6 @@ def maybe_precompute_fp8_dynamic_scale_for_fsdp(
             "Skipped precomputing fp8 scales because SM90 or later is not available",
         )
         return
-    from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
+    from float8_experimental import precompute_float8_dynamic_scale_for_fsdp
 
     precompute_float8_dynamic_scale_for_fsdp(model)
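
For reference, a minimal standalone sketch of the new top-level float8_experimental API this patch switches to. The config and conversion calls mirror the patch above; the ToyModel, its dimensions, and the hard-coded enable_fsdp_float8_all_gather value are illustrative assumptions, and running it assumes an SM90+ GPU with a float8_experimental build that exports these names.

import torch.nn as nn

from float8_experimental import (
    CastConfig,
    convert_to_float8_training,
    Float8LinearConfig,
    ScalingType,
)


class ToyModel(nn.Module):
    # Illustrative stand-in for the model handled by maybe_build_fp8_linear.
    def __init__(self):
        super().__init__()
        self.w1 = nn.Linear(256, 256)
        self.w2 = nn.Linear(256, 256)
        self.output = nn.Linear(256, 64)  # kept out of fp8 by the filter below

    def forward(self, x):
        return self.output(self.w2(self.w1(x)))


model = ToyModel()

# Same configuration shape as the patch: dynamic scaling for weights, and
# fp8 all-gather only when FSDP data parallelism is actually in use.
float8_config = Float8LinearConfig(
    enable_fsdp_float8_all_gather=False,
    cast_config_weight=CastConfig(scaling_type=ScalingType.DYNAMIC),
)

# Swaps nn.Linear modules for Float8Linear in place, skipping the "output" projection.
convert_to_float8_training(
    model,
    config=float8_config,
    module_filter_fn=lambda mod, fqn: fqn != "output",
)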