Commit 89cdc43

[llama4] fall back to for-loop based MoE if not on SM90 or later (#1096)
As titled; also fixes a breakage caused by #1086.
Parent: 4f532e0
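
The pattern the commit introduces, in a short sketch: probe the GPU's compute capability once, and if grouped MM is unavailable (it needs SM90, i.e. H100-class hardware), disable it so the MoE layer takes the for-loop path. The helper below mirrors the new torchtitan.tools.utils.has_cuda_capability from this diff; the surrounding flag handling is illustrative only, not torchtitan's actual code.

    import torch

    def has_cuda_capability(major: int, minor: int) -> bool:
        # Same check the commit factors out into torchtitan/tools/utils.py.
        return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (
            major,
            minor,
        )

    # Illustrative gate: the real guard lives in the llama4 model args (see below).
    use_grouped_mm = True
    if use_grouped_mm and not has_cuda_capability(9, 0):
        # torch._grouped_mm needs SM90 or later; fall back to the for-loop MoE.
        use_grouped_mm = False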

4 files changed: +24 -10 lines

torchtitan/components/float8.py
Lines changed: 3 additions & 8 deletions

@@ -13,7 +13,6 @@
 # Note: Performance
 # Float8 experimental is intended to be ran under `torch.compile`` for competitive performance
 
-import torch
 import torch.nn as nn
 
 from torchtitan.config_manager import JobConfig
@@ -23,19 +22,15 @@
     register_model_converter,
 )
 from torchtitan.tools.logging import logger
-
-
-def _is_sm89_or_later():
-    # Float8 is only supported on SM89 or later (H100+ GPUs)
-    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
+from torchtitan.tools.utils import has_cuda_capability
 
 
 class Float8Converter(ModelConverter):
     def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         self.enabled = False
 
         float8_config = job_config.float8
-        if not _is_sm89_or_later():
+        if not has_cuda_capability(8, 9):
             logger.warning(
                 "Failed to swap to Float8Linear because float8 is only supported on SM89 or later",
             )
@@ -73,7 +68,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
             )
 
         else:
-            # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
+            # Mutates the model inplace replacing instances of nn.Linear with Float8Linear
             enable_fsdp_float8_all_gather = (
                 parallel_dims.dp_shard_enabled
                 and float8_config.enable_fsdp_float8_all_gather

torchtitan/experiments/llama4/model/args.py
Lines changed: 6 additions & 0 deletions

@@ -14,6 +14,7 @@
 
 from torchtitan.protocols.train_spec import BaseModelArgs
 from torchtitan.tools.logging import logger
+from torchtitan.tools.utils import has_cuda_capability
 
 
 @dataclass
@@ -54,6 +55,11 @@ def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None
         self.vocab_size = tokenizer.n_words
         self.max_seq_len = job_config.training.seq_len
         self.use_flex_attn = job_config.model.use_flex_attn
+        if self.use_grouped_mm and not has_cuda_capability(9, 0):
+            logger.warning(
+                "Failed to use grouped mm, which is only supported on SM90 or later",
+            )
+            self.use_grouped_mm = False
 
     def get_nparams_and_flops(
         self, model: nn.Module, seq_len: int

torchtitan/experiments/llama4/model/moe.py
Lines changed: 8 additions & 2 deletions

@@ -79,6 +79,9 @@ def forward(
             # fall back to regular bmm between 3D tensors
             assert x.dim() == 3
 
+        assert (
+            x.dtype == self.w1.dtype == self.w2.dtype == self.w3.dtype == torch.bfloat16
+        ), "torch._grouped_mm only supports bf16 dtypes"
         h = F.silu(torch._grouped_mm(x, self.w1, offs=offsets))
         h = h * torch._grouped_mm(x, self.w3, offs=offsets)
         out = torch._grouped_mm(h, self.w2, offs=offsets)
@@ -246,14 +249,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         ALIGN_SIZE_M = 16
 
         with torch.no_grad():
-            permuted_indices, m_sizes = generate_permute_indices(
+            (
+                permuted_indices,
+                num_local_tokens_per_expert,
+                _,
+            ) = generate_permute_indices(
                 num_local_tokens_per_expert,
                 self.experts.num_experts,
                 1,
                 token_indices.shape[0] + self.experts.num_experts * ALIGN_SIZE_M,
                 ALIGN_SIZE_M,
             )
-            num_local_tokens_per_expert = m_sizes
         token_indices = torch.vstack(
             (token_indices, token_indices.new_zeros((dim)))
         )
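
The for-loop MoE path named in the commit title is not part of this hunk; for orientation, here is a rough, hedged sketch of the two expert-compute paths (function names, shapes, and the equal-tokens-per-expert assumption are illustrative, not the file's actual code):

    import torch
    import torch.nn.functional as F

    def experts_grouped(x, w1, w2, w3, offsets):
        # Grouped path (SM90+ only, bf16 only): fused kernels across all experts,
        # mirroring the torch._grouped_mm calls in the hunk above.
        h = F.silu(torch._grouped_mm(x, w1, offs=offsets))
        h = h * torch._grouped_mm(x, w3, offs=offsets)
        return torch._grouped_mm(h, w2, offs=offsets)

    def experts_for_loop(x, w1, w2, w3):
        # Fallback path: plain per-expert matmuls, assuming
        # x: (num_experts, tokens, dim), w1/w3: (num_experts, dim, hidden),
        # w2: (num_experts, hidden, dim).
        outs = []
        for e in range(w1.shape[0]):
            h = F.silu(x[e] @ w1[e]) * (x[e] @ w3[e])
            outs.append(h @ w2[e])
        return torch.stack(outs)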

torchtitan/tools/utils.py
Lines changed: 7 additions & 0 deletions

@@ -16,6 +16,13 @@
 from torchtitan.tools.logging import logger
 
 
+def has_cuda_capability(major: int, minor: int) -> bool:
+    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (
+        major,
+        minor,
+    )
+
+
 def get_device_info():
     device_type = _get_available_device_type()
     if device_type is None:
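
Usage is a plain tuple comparison against the device's compute capability; for example, on an H100 (which reports capability (9, 0)) the checks used elsewhere in this commit would resolve as follows (hypothetical session):

    from torchtitan.tools.utils import has_cuda_capability

    has_cuda_capability(8, 9)   # True  -> Float8Linear swap allowed
    has_cuda_capability(9, 0)   # True  -> grouped MM path allowed
    has_cuda_capability(10, 0)  # False -> fall back to the for-loop MoE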
