From 097ae0d9fd1f893ac0f81e001c228ae2e7139ee1 Mon Sep 17 00:00:00 2001 From: Guyue Huang Date: Wed, 19 Feb 2025 13:14:01 -0800 Subject: [PATCH 1/9] Fix FP8GlobalStateManager for FP8 since default recipe is MXFP8 Signed-off-by: Guyue Huang --- nemo/lightning/fabric/plugins.py | 2 ++ nemo/lightning/pytorch/plugins/mixed_precision.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/nemo/lightning/fabric/plugins.py b/nemo/lightning/fabric/plugins.py index 58bf5f5ca9f9..8f72fa11cb42 100644 --- a/nemo/lightning/fabric/plugins.py +++ b/nemo/lightning/fabric/plugins.py @@ -68,7 +68,9 @@ def __init__( te_fp8, HAVE_TE = safe_import("transformer_engine.pytorch.fp8") assert HAVE_TE, "FP8 precision requires transformer engine." if fp8_params: + from transformer_engine.common.recipe import DelayedScaling te_fp8.FP8GlobalStateManager.FP8_PARAMETERS = True + te_fp8.FP8_RECIPE = DelayedScaling() fp8_param_gather = True dtype = torch.bfloat16 if precision in ['bf16', 'bf16-mixed'] else torch.float32 diff --git a/nemo/lightning/pytorch/plugins/mixed_precision.py b/nemo/lightning/pytorch/plugins/mixed_precision.py index 830978ba11e7..a6217dde6e56 100644 --- a/nemo/lightning/pytorch/plugins/mixed_precision.py +++ b/nemo/lightning/pytorch/plugins/mixed_precision.py @@ -101,7 +101,9 @@ def __init__( te_fp8, HAVE_TE = safe_import("transformer_engine.pytorch.fp8") assert HAVE_TE, "FP8 precision requires transformer engine." if fp8_params: + from transformer_engine.common.recipe import DelayedScaling te_fp8.FP8GlobalStateManager.FP8_PARAMETERS = True + te_fp8.FP8_RECIPE = DelayedScaling() fp8_param_gather = True dtype = torch.bfloat16 if precision in ['bf16', 'bf16-mixed'] else torch.float32 From a1dd339254dd049751b58f5de98496a532ae505b Mon Sep 17 00:00:00 2001 From: Guyue Huang Date: Wed, 19 Feb 2025 14:58:30 -0800 Subject: [PATCH 2/9] Fix Signed-off-by: Guyue Huang --- nemo/collections/llm/gpt/model/base.py | 34 ++++++++++++++------------ 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/nemo/collections/llm/gpt/model/base.py b/nemo/collections/llm/gpt/model/base.py index b92ca669db49..ccb506d10485 100644 --- a/nemo/collections/llm/gpt/model/base.py +++ b/nemo/collections/llm/gpt/model/base.py @@ -218,22 +218,24 @@ def configure_model(self, tokenizer, pre_process=None, post_process=None) -> "MC else: vocab_size = get_vocab_size(self, tokenizer.vocab_size, self.make_vocab_size_divisible_by) - model = MCoreGPTModel( - self, - transformer_layer_spec=transformer_layer_spec, - vocab_size=vocab_size, - max_sequence_length=self.seq_length, - fp16_lm_cross_entropy=self.fp16_lm_cross_entropy, - parallel_output=self.parallel_output, - share_embeddings_and_output_weights=self.share_embeddings_and_output_weights, - position_embedding_type=self.position_embedding_type, - rotary_percent=self.rotary_percent, - rotary_base=self.rotary_base, - seq_len_interpolation_factor=self.seq_len_interpolation_factor, - pre_process=pre_process or parallel_state.is_pipeline_first_stage(), - post_process=post_process or parallel_state.is_pipeline_last_stage(), - scatter_embedding_sequence_parallel=self.scatter_embedding_sequence_parallel, - ) + import transformer_engine + with transformer_engine.pytorch.fp8_model_init(recipe=transformer_engine.common.recipe.DelayedScaling()): + model = MCoreGPTModel( + self, + transformer_layer_spec=transformer_layer_spec, + vocab_size=vocab_size, + max_sequence_length=self.seq_length, + fp16_lm_cross_entropy=self.fp16_lm_cross_entropy, + parallel_output=self.parallel_output, + 
share_embeddings_and_output_weights=self.share_embeddings_and_output_weights, + position_embedding_type=self.position_embedding_type, + rotary_percent=self.rotary_percent, + rotary_base=self.rotary_base, + seq_len_interpolation_factor=self.seq_len_interpolation_factor, + pre_process=pre_process or parallel_state.is_pipeline_first_stage(), + post_process=post_process or parallel_state.is_pipeline_last_stage(), + scatter_embedding_sequence_parallel=self.scatter_embedding_sequence_parallel, + ) # If using full TE layer, need to set TP, CP group since the module call # is not routed through megatron core, which normally handles passing the From 97deae8dccdf0c6c6a4dada32632db829de851a7 Mon Sep 17 00:00:00 2001 From: Guyue Huang Date: Thu, 20 Feb 2025 11:44:11 -0800 Subject: [PATCH 3/9] Revert changes to fabric, improve changes to GPT model Signed-off-by: Guyue Huang --- nemo/collections/llm/gpt/model/base.py | 43 +++++++++++-------- nemo/lightning/fabric/plugins.py | 4 +- .../pytorch/plugins/mixed_precision.py | 4 +- 3 files changed, 27 insertions(+), 24 deletions(-) diff --git a/nemo/collections/llm/gpt/model/base.py b/nemo/collections/llm/gpt/model/base.py index ccb506d10485..171a4a97f8cf 100644 --- a/nemo/collections/llm/gpt/model/base.py +++ b/nemo/collections/llm/gpt/model/base.py @@ -218,24 +218,31 @@ def configure_model(self, tokenizer, pre_process=None, post_process=None) -> "MC else: vocab_size = get_vocab_size(self, tokenizer.vocab_size, self.make_vocab_size_divisible_by) - import transformer_engine - with transformer_engine.pytorch.fp8_model_init(recipe=transformer_engine.common.recipe.DelayedScaling()): - model = MCoreGPTModel( - self, - transformer_layer_spec=transformer_layer_spec, - vocab_size=vocab_size, - max_sequence_length=self.seq_length, - fp16_lm_cross_entropy=self.fp16_lm_cross_entropy, - parallel_output=self.parallel_output, - share_embeddings_and_output_weights=self.share_embeddings_and_output_weights, - position_embedding_type=self.position_embedding_type, - rotary_percent=self.rotary_percent, - rotary_base=self.rotary_base, - seq_len_interpolation_factor=self.seq_len_interpolation_factor, - pre_process=pre_process or parallel_state.is_pipeline_first_stage(), - post_process=post_process or parallel_state.is_pipeline_last_stage(), - scatter_embedding_sequence_parallel=self.scatter_embedding_sequence_parallel, - ) + # Set FP8 recipe to DelayedScaling to initialize model with float8 precision. + # If not set, the default recipe is MXFP8BlockScaling which will initialize + # model with mxfp8 precision. + if self.fp8 is not None: + assert HAVE_TE, "Transformer Engine is required for FP8 training." 
+ te_fp8, _ = safe_import("transformer_engine.pytorch.fp8") + te_recipe, _ = safe_import("transformer_engine.common.recipe") + te_fp8.FP8GlobalStateManager.FP8_RECIPE = te_recipe.DelayedScaling() + + model = MCoreGPTModel( + self, + transformer_layer_spec=transformer_layer_spec, + vocab_size=vocab_size, + max_sequence_length=self.seq_length, + fp16_lm_cross_entropy=self.fp16_lm_cross_entropy, + parallel_output=self.parallel_output, + share_embeddings_and_output_weights=self.share_embeddings_and_output_weights, + position_embedding_type=self.position_embedding_type, + rotary_percent=self.rotary_percent, + rotary_base=self.rotary_base, + seq_len_interpolation_factor=self.seq_len_interpolation_factor, + pre_process=pre_process or parallel_state.is_pipeline_first_stage(), + post_process=post_process or parallel_state.is_pipeline_last_stage(), + scatter_embedding_sequence_parallel=self.scatter_embedding_sequence_parallel, + ) # If using full TE layer, need to set TP, CP group since the module call # is not routed through megatron core, which normally handles passing the diff --git a/nemo/lightning/fabric/plugins.py b/nemo/lightning/fabric/plugins.py index 8f72fa11cb42..d3f15ffc6a67 100644 --- a/nemo/lightning/fabric/plugins.py +++ b/nemo/lightning/fabric/plugins.py @@ -13,7 +13,7 @@ # limitations under the License. from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Generator, Literal, TypeVar +from typing import TYPE_CHECKING, Generator, Literal, TypeVar import torch from lightning.fabric.plugins.precision import MixedPrecision @@ -68,9 +68,7 @@ def __init__( te_fp8, HAVE_TE = safe_import("transformer_engine.pytorch.fp8") assert HAVE_TE, "FP8 precision requires transformer engine." if fp8_params: - from transformer_engine.common.recipe import DelayedScaling te_fp8.FP8GlobalStateManager.FP8_PARAMETERS = True - te_fp8.FP8_RECIPE = DelayedScaling() fp8_param_gather = True dtype = torch.bfloat16 if precision in ['bf16', 'bf16-mixed'] else torch.float32 diff --git a/nemo/lightning/pytorch/plugins/mixed_precision.py b/nemo/lightning/pytorch/plugins/mixed_precision.py index a6217dde6e56..e93bc5f9c325 100644 --- a/nemo/lightning/pytorch/plugins/mixed_precision.py +++ b/nemo/lightning/pytorch/plugins/mixed_precision.py @@ -14,7 +14,7 @@ from contextlib import contextmanager from dataclasses import dataclass, fields -from typing import Any, Callable, Generator, List, Literal, Tuple, TypeVar, Union +from typing import Generator, Literal, TypeVar, Union import torch from lightning.pytorch.plugins.precision import Precision @@ -101,9 +101,7 @@ def __init__( te_fp8, HAVE_TE = safe_import("transformer_engine.pytorch.fp8") assert HAVE_TE, "FP8 precision requires transformer engine." 
if fp8_params: - from transformer_engine.common.recipe import DelayedScaling te_fp8.FP8GlobalStateManager.FP8_PARAMETERS = True - te_fp8.FP8_RECIPE = DelayedScaling() fp8_param_gather = True dtype = torch.bfloat16 if precision in ['bf16', 'bf16-mixed'] else torch.float32 From aa01aca65ecde97916e82269300d5e92a4097500 Mon Sep 17 00:00:00 2001 From: Guyue Huang Date: Thu, 20 Feb 2025 14:09:37 -0800 Subject: [PATCH 4/9] Use context instead of directly setting FP8GlobalStateManager Signed-off-by: Guyue Huang --- nemo/collections/llm/gpt/model/base.py | 51 +++++++++++++++----------- 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/nemo/collections/llm/gpt/model/base.py b/nemo/collections/llm/gpt/model/base.py index 171a4a97f8cf..4d41c112296d 100644 --- a/nemo/collections/llm/gpt/model/base.py +++ b/nemo/collections/llm/gpt/model/base.py @@ -14,6 +14,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Callable, Dict, Literal, Optional, Union +from functools import partial import lightning.pytorch as L import torch @@ -32,6 +33,7 @@ from nemo.lightning.pytorch.optim import MegatronOptimizerModule, OptimizerModule from nemo.utils import logging from nemo.utils.import_utils import safe_import +from nemo.utils.te_utils import te_version _, HAVE_TE = safe_import("transformer_engine") @@ -219,30 +221,35 @@ def configure_model(self, tokenizer, pre_process=None, post_process=None) -> "MC vocab_size = get_vocab_size(self, tokenizer.vocab_size, self.make_vocab_size_divisible_by) # Set FP8 recipe to DelayedScaling to initialize model with float8 precision. - # If not set, the default recipe is MXFP8BlockScaling which will initialize - # model with mxfp8 precision. if self.fp8 is not None: assert HAVE_TE, "Transformer Engine is required for FP8 training." 
- te_fp8, _ = safe_import("transformer_engine.pytorch.fp8") - te_recipe, _ = safe_import("transformer_engine.common.recipe") - te_fp8.FP8GlobalStateManager.FP8_RECIPE = te_recipe.DelayedScaling() - - model = MCoreGPTModel( - self, - transformer_layer_spec=transformer_layer_spec, - vocab_size=vocab_size, - max_sequence_length=self.seq_length, - fp16_lm_cross_entropy=self.fp16_lm_cross_entropy, - parallel_output=self.parallel_output, - share_embeddings_and_output_weights=self.share_embeddings_and_output_weights, - position_embedding_type=self.position_embedding_type, - rotary_percent=self.rotary_percent, - rotary_base=self.rotary_base, - seq_len_interpolation_factor=self.seq_len_interpolation_factor, - pre_process=pre_process or parallel_state.is_pipeline_first_stage(), - post_process=post_process or parallel_state.is_pipeline_last_stage(), - scatter_embedding_sequence_parallel=self.scatter_embedding_sequence_parallel, - ) + te_pytorch, _ = safe_import("transformer_engine.pytorch") + fp8_model_init = te_pytorch.fp8_model_init + if te_version() >= (2, 0): + # In TE 2.0, the default recipe is MXFP8BlockScaling, need to change it to DelayedScaling + te_recipe, _ = safe_import("transformer_engine.common.recipe") + recipe = te_recipe.DelayedScaling() + build_model_context = partial(fp8_model_init, recipe=recipe) + else: + build_model_context = fp8_model_init + + with build_model_context(): + model = MCoreGPTModel( + self, + transformer_layer_spec=transformer_layer_spec, + vocab_size=vocab_size, + max_sequence_length=self.seq_length, + fp16_lm_cross_entropy=self.fp16_lm_cross_entropy, + parallel_output=self.parallel_output, + share_embeddings_and_output_weights=self.share_embeddings_and_output_weights, + position_embedding_type=self.position_embedding_type, + rotary_percent=self.rotary_percent, + rotary_base=self.rotary_base, + seq_len_interpolation_factor=self.seq_len_interpolation_factor, + pre_process=pre_process or parallel_state.is_pipeline_first_stage(), + post_process=post_process or parallel_state.is_pipeline_last_stage(), + scatter_embedding_sequence_parallel=self.scatter_embedding_sequence_parallel, + ) # If using full TE layer, need to set TP, CP group since the module call # is not routed through megatron core, which normally handles passing the From 430651208440e059f615fe0b8192b8bc6c91e49f Mon Sep 17 00:00:00 2001 From: guyueh1 Date: Thu, 20 Feb 2025 22:10:53 +0000 Subject: [PATCH 5/9] Apply isort and black reformatting Signed-off-by: guyueh1 --- nemo/collections/llm/gpt/model/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/llm/gpt/model/base.py b/nemo/collections/llm/gpt/model/base.py index 4d41c112296d..b14b7f4ae43b 100644 --- a/nemo/collections/llm/gpt/model/base.py +++ b/nemo/collections/llm/gpt/model/base.py @@ -13,8 +13,8 @@ # limitations under the License. 
from dataclasses import dataclass -from typing import TYPE_CHECKING, Callable, Dict, Literal, Optional, Union from functools import partial +from typing import TYPE_CHECKING, Callable, Dict, Literal, Optional, Union import lightning.pytorch as L import torch From 499aa6895e2e243c07012649db208caed10dd9a0 Mon Sep 17 00:00:00 2001 From: Guyue Huang Date: Thu, 20 Feb 2025 16:40:53 -0800 Subject: [PATCH 6/9] Fix for bf16 Signed-off-by: Guyue Huang --- nemo/collections/llm/gpt/model/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nemo/collections/llm/gpt/model/base.py b/nemo/collections/llm/gpt/model/base.py index 4d41c112296d..c4b0be749912 100644 --- a/nemo/collections/llm/gpt/model/base.py +++ b/nemo/collections/llm/gpt/model/base.py @@ -15,6 +15,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Callable, Dict, Literal, Optional, Union from functools import partial +from contextlib import nullcontext import lightning.pytorch as L import torch @@ -221,6 +222,7 @@ def configure_model(self, tokenizer, pre_process=None, post_process=None) -> "MC vocab_size = get_vocab_size(self, tokenizer.vocab_size, self.make_vocab_size_divisible_by) # Set FP8 recipe to DelayedScaling to initialize model with float8 precision. + build_model_context = nullcontext if self.fp8 is not None: assert HAVE_TE, "Transformer Engine is required for FP8 training." te_pytorch, _ = safe_import("transformer_engine.pytorch") From e5ce8fe395a14444721a908f834a14958b596603 Mon Sep 17 00:00:00 2001 From: guyueh1 Date: Fri, 21 Feb 2025 00:42:28 +0000 Subject: [PATCH 7/9] Apply isort and black reformatting Signed-off-by: guyueh1 --- nemo/collections/llm/gpt/model/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/llm/gpt/model/base.py b/nemo/collections/llm/gpt/model/base.py index cee0d4086760..b5c375830ee3 100644 --- a/nemo/collections/llm/gpt/model/base.py +++ b/nemo/collections/llm/gpt/model/base.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from contextlib import nullcontext from dataclasses import dataclass from functools import partial -from contextlib import nullcontext from typing import TYPE_CHECKING, Callable, Dict, Literal, Optional, Union import lightning.pytorch as L From 80c43e39aea17d13234c0d1b913686b2ac3cd0ef Mon Sep 17 00:00:00 2001 From: Guyue Huang Date: Fri, 21 Feb 2025 11:13:33 -0800 Subject: [PATCH 8/9] Revert unintended changes Signed-off-by: Guyue Huang --- nemo/lightning/fabric/plugins.py | 2 +- nemo/lightning/pytorch/plugins/mixed_precision.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo/lightning/fabric/plugins.py b/nemo/lightning/fabric/plugins.py index d3f15ffc6a67..e787bd743441 100644 --- a/nemo/lightning/fabric/plugins.py +++ b/nemo/lightning/fabric/plugins.py @@ -13,7 +13,7 @@ # limitations under the License. 
from contextlib import contextmanager -from typing import TYPE_CHECKING, Generator, Literal, TypeVar +from typing import Any, TYPE_CHECKING, Generator, Literal, TypeVar import torch from lightning.fabric.plugins.precision import MixedPrecision diff --git a/nemo/lightning/pytorch/plugins/mixed_precision.py b/nemo/lightning/pytorch/plugins/mixed_precision.py index e93bc5f9c325..830978ba11e7 100644 --- a/nemo/lightning/pytorch/plugins/mixed_precision.py +++ b/nemo/lightning/pytorch/plugins/mixed_precision.py @@ -14,7 +14,7 @@ from contextlib import contextmanager from dataclasses import dataclass, fields -from typing import Generator, Literal, TypeVar, Union +from typing import Any, Callable, Generator, List, Literal, Tuple, TypeVar, Union import torch from lightning.pytorch.plugins.precision import Precision From 98d10a77e1dfde19ec8aaf3c9fd396f80f3b2dd5 Mon Sep 17 00:00:00 2001 From: guyueh1 Date: Fri, 21 Feb 2025 19:14:58 +0000 Subject: [PATCH 9/9] Apply isort and black reformatting Signed-off-by: guyueh1 --- nemo/lightning/fabric/plugins.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/lightning/fabric/plugins.py b/nemo/lightning/fabric/plugins.py index e787bd743441..58bf5f5ca9f9 100644 --- a/nemo/lightning/fabric/plugins.py +++ b/nemo/lightning/fabric/plugins.py @@ -13,7 +13,7 @@ # limitations under the License. from contextlib import contextmanager -from typing import Any, TYPE_CHECKING, Generator, Literal, TypeVar +from typing import TYPE_CHECKING, Any, Generator, Literal, TypeVar import torch from lightning.fabric.plugins.precision import MixedPrecision
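
For readers following the series, below is a minimal, self-contained sketch of the pattern the final revision converges on: pick the context manager under which model parameters are created, pinning the FP8 recipe to `DelayedScaling` on Transformer Engine 2.0+ (whose default recipe is MXFP8 block scaling) and falling back to `nullcontext` when FP8 is disabled. This is an illustration only, not part of the patch series: `make_build_model_context`, `is_fp8_enabled`, and `build_model` are hypothetical stand-ins, and the actual change lives in `GPTConfig.configure_model`, which uses NeMo's `safe_import` and `te_version()` helpers rather than the direct imports and version check shown here.

```python
# Illustrative sketch of the approach in PATCH 4/9 + 6/9; not NeMo code.
# Assumes Transformer Engine >= 1.x is installed and exposes __version__.
from contextlib import nullcontext
from functools import partial

from packaging.version import Version

import transformer_engine as te
import transformer_engine.pytorch as te_pytorch
from transformer_engine.common.recipe import DelayedScaling


def make_build_model_context(is_fp8_enabled: bool):
    """Return a context-manager factory for model construction (hypothetical helper)."""
    if not is_fp8_enabled:
        # bf16/fp32 runs: no special parameter initialization (the "Fix for bf16" commit).
        return nullcontext
    if Version(te.__version__) >= Version("2.0"):
        # TE 2.0 defaults to MXFP8BlockScaling; pin DelayedScaling so parameters
        # are created as plain float8 tensors rather than MXFP8 tensors.
        return partial(te_pytorch.fp8_model_init, recipe=DelayedScaling())
    # Older TE: fp8_model_init takes no recipe argument and already uses delayed scaling.
    return te_pytorch.fp8_model_init


# Usage (hypothetical): build_model stands in for the MCoreGPTModel construction.
# build_model_context = make_build_model_context(is_fp8_enabled=True)
# with build_model_context():
#     model = build_model()
```

The design choice the series settles on is to scope the recipe to model construction through a context manager instead of mutating `FP8GlobalStateManager` module state, which keeps the precision plugins untouched and avoids leaking the recipe override into unrelated code paths.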