Fix for te v2.0 #12273

Open · wants to merge 13 commits into main
50 changes: 34 additions & 16 deletions nemo/collections/llm/gpt/model/base.py
@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import nullcontext
from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Callable, Dict, Literal, Optional, Union

import lightning.pytorch as L
@@ -32,6 +34,7 @@
from nemo.lightning.pytorch.optim import MegatronOptimizerModule, OptimizerModule
from nemo.utils import logging
from nemo.utils.import_utils import safe_import
from nemo.utils.te_utils import te_version

_, HAVE_TE = safe_import("transformer_engine")

@@ -218,22 +221,37 @@ def configure_model(self, tokenizer, pre_process=None, post_process=None) -> "MCoreGPTModel":
    else:
        vocab_size = get_vocab_size(self, tokenizer.vocab_size, self.make_vocab_size_divisible_by)

    model = MCoreGPTModel(
        self,
        transformer_layer_spec=transformer_layer_spec,
        vocab_size=vocab_size,
        max_sequence_length=self.seq_length,
        fp16_lm_cross_entropy=self.fp16_lm_cross_entropy,
        parallel_output=self.parallel_output,
        share_embeddings_and_output_weights=self.share_embeddings_and_output_weights,
        position_embedding_type=self.position_embedding_type,
        rotary_percent=self.rotary_percent,
        rotary_base=self.rotary_base,
        seq_len_interpolation_factor=self.seq_len_interpolation_factor,
        pre_process=pre_process or parallel_state.is_pipeline_first_stage(),
        post_process=post_process or parallel_state.is_pipeline_last_stage(),
        scatter_embedding_sequence_parallel=self.scatter_embedding_sequence_parallel,
    )
    # Set the FP8 recipe to DelayedScaling so the model is initialized with float8 precision.
    build_model_context = nullcontext
    if self.fp8 is not None:
        assert HAVE_TE, "Transformer Engine is required for FP8 training."
        te_pytorch, _ = safe_import("transformer_engine.pytorch")
        fp8_model_init = te_pytorch.fp8_model_init
        if te_version() >= (2, 0):
            # In TE 2.0, the default recipe is MXFP8BlockScaling, so it must be changed to DelayedScaling.
            te_recipe, _ = safe_import("transformer_engine.common.recipe")
            recipe = te_recipe.DelayedScaling()
            build_model_context = partial(fp8_model_init, recipe=recipe)
        else:
            build_model_context = fp8_model_init

    with build_model_context():
        model = MCoreGPTModel(
            self,
            transformer_layer_spec=transformer_layer_spec,
            vocab_size=vocab_size,
            max_sequence_length=self.seq_length,
            fp16_lm_cross_entropy=self.fp16_lm_cross_entropy,
            parallel_output=self.parallel_output,
            share_embeddings_and_output_weights=self.share_embeddings_and_output_weights,
            position_embedding_type=self.position_embedding_type,
            rotary_percent=self.rotary_percent,
            rotary_base=self.rotary_base,
            seq_len_interpolation_factor=self.seq_len_interpolation_factor,
            pre_process=pre_process or parallel_state.is_pipeline_first_stage(),
            post_process=post_process or parallel_state.is_pipeline_last_stage(),
            scatter_embedding_sequence_parallel=self.scatter_embedding_sequence_parallel,
        )

    # If using full TE layer, need to set TP, CP group since the module call
    # is not routed through megatron core, which normally handles passing the
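For reference, the version-gated FP8 initialization this diff introduces can be exercised on its own. The sketch below is not NeMo code: it assumes te_version() returns a comparable version tuple (as implied by the >= (2, 0) check above) and that safe_import returns a (module, success) pair (as implied by the existing imports); the helper name fp8_build_context and the commented call site are illustrative only.

    # Minimal standalone sketch of the version-gated FP8 init used above; assumptions noted inline.
    from contextlib import nullcontext
    from functools import partial

    from nemo.utils.import_utils import safe_import  # assumed to return (module, success_flag)
    from nemo.utils.te_utils import te_version  # assumed to return a comparable version tuple


    def fp8_build_context(fp8_enabled: bool):
        """Pick the context manager under which model parameters are created (hypothetical helper)."""
        if not fp8_enabled:
            return nullcontext
        te_pytorch, have_te = safe_import("transformer_engine.pytorch")
        assert have_te, "Transformer Engine is required for FP8 training."
        fp8_model_init = te_pytorch.fp8_model_init
        if te_version() >= (2, 0):
            # TE 2.0 defaults to MXFP8BlockScaling; pin an explicit DelayedScaling recipe instead.
            te_recipe, _ = safe_import("transformer_engine.common.recipe")
            return partial(fp8_model_init, recipe=te_recipe.DelayedScaling())
        # Pre-2.0 TE already defaults to delayed scaling, so the plain context is enough.
        return fp8_model_init


    # Usage: modules constructed inside the context get FP8-initialized parameters.
    # with fp8_build_context(fp8_enabled=True)():
    #     model = MCoreGPTModel(...)  # call site as in the diff above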