diff --git a/nemo/collections/llm/gpt/model/base.py b/nemo/collections/llm/gpt/model/base.py
index b92ca669db49..b5c375830ee3 100644
--- a/nemo/collections/llm/gpt/model/base.py
+++ b/nemo/collections/llm/gpt/model/base.py
@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from contextlib import nullcontext
 from dataclasses import dataclass
+from functools import partial
 from typing import TYPE_CHECKING, Callable, Dict, Literal, Optional, Union
 
 import lightning.pytorch as L
@@ -32,6 +34,7 @@
 from nemo.lightning.pytorch.optim import MegatronOptimizerModule, OptimizerModule
 from nemo.utils import logging
 from nemo.utils.import_utils import safe_import
+from nemo.utils.te_utils import te_version
 
 _, HAVE_TE = safe_import("transformer_engine")
 
@@ -218,22 +221,37 @@ def configure_model(self, tokenizer, pre_process=None, post_process=None) -> "MC
         else:
             vocab_size = get_vocab_size(self, tokenizer.vocab_size, self.make_vocab_size_divisible_by)
 
-        model = MCoreGPTModel(
-            self,
-            transformer_layer_spec=transformer_layer_spec,
-            vocab_size=vocab_size,
-            max_sequence_length=self.seq_length,
-            fp16_lm_cross_entropy=self.fp16_lm_cross_entropy,
-            parallel_output=self.parallel_output,
-            share_embeddings_and_output_weights=self.share_embeddings_and_output_weights,
-            position_embedding_type=self.position_embedding_type,
-            rotary_percent=self.rotary_percent,
-            rotary_base=self.rotary_base,
-            seq_len_interpolation_factor=self.seq_len_interpolation_factor,
-            pre_process=pre_process or parallel_state.is_pipeline_first_stage(),
-            post_process=post_process or parallel_state.is_pipeline_last_stage(),
-            scatter_embedding_sequence_parallel=self.scatter_embedding_sequence_parallel,
-        )
+        # Set the FP8 recipe to DelayedScaling so the model is initialized with float8 precision.
+        build_model_context = nullcontext
+        if self.fp8 is not None:
+            assert HAVE_TE, "Transformer Engine is required for FP8 training."
+            te_pytorch, _ = safe_import("transformer_engine.pytorch")
+            fp8_model_init = te_pytorch.fp8_model_init
+            if te_version() >= (2, 0):
+                # In TE 2.0, the default recipe is MXFP8BlockScaling; it needs to be changed to DelayedScaling.
+                te_recipe, _ = safe_import("transformer_engine.common.recipe")
+                recipe = te_recipe.DelayedScaling()
+                build_model_context = partial(fp8_model_init, recipe=recipe)
+            else:
+                build_model_context = fp8_model_init
+
+        with build_model_context():
+            model = MCoreGPTModel(
+                self,
+                transformer_layer_spec=transformer_layer_spec,
+                vocab_size=vocab_size,
+                max_sequence_length=self.seq_length,
+                fp16_lm_cross_entropy=self.fp16_lm_cross_entropy,
+                parallel_output=self.parallel_output,
+                share_embeddings_and_output_weights=self.share_embeddings_and_output_weights,
+                position_embedding_type=self.position_embedding_type,
+                rotary_percent=self.rotary_percent,
+                rotary_base=self.rotary_base,
+                seq_len_interpolation_factor=self.seq_len_interpolation_factor,
+                pre_process=pre_process or parallel_state.is_pipeline_first_stage(),
+                post_process=post_process or parallel_state.is_pipeline_last_stage(),
+                scatter_embedding_sequence_parallel=self.scatter_embedding_sequence_parallel,
+            )
 
         # If using full TE layer, need to set TP, CP group since the module call
         # is not routed through megatron core, which normally handles passing the
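
The core of the change is wrapping model construction in Transformer Engine's `fp8_model_init` context so that parameters are created with float8 storage, selecting the `DelayedScaling` recipe explicitly on TE >= 2.0. The snippet below is a minimal sketch of the same pattern outside NeMo, not part of this patch: it assumes Transformer Engine is installed, and `make_build_context`, its `te_major` argument, and the `te.Linear(1024, 1024)` layer are illustrative placeholders standing in for NeMo's `te_version()` check and `MCoreGPTModel`.

```python
from contextlib import nullcontext
from functools import partial

import transformer_engine.pytorch as te
from transformer_engine.common import recipe as te_recipe


def make_build_context(use_fp8: bool, te_major: int):
    """Return a context-manager factory mirroring build_model_context in this diff.

    `te_major` is a hypothetical stand-in for NeMo's te_version() check.
    """
    if not use_fp8:
        # No FP8 requested: build the model under a no-op context.
        return nullcontext
    if te_major >= 2:
        # TE 2.x defaults to MXFP8 block scaling; pin DelayedScaling explicitly,
        # matching the behavior added in this diff.
        return partial(te.fp8_model_init, recipe=te_recipe.DelayedScaling())
    # Older TE: fp8_model_init takes no recipe argument.
    return te.fp8_model_init


# Usage: modules constructed inside the context allocate float8 parameters.
build_model_context = make_build_context(use_fp8=True, te_major=2)
with build_model_context():
    layer = te.Linear(1024, 1024)  # placeholder module; the diff builds MCoreGPTModel here
```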