improve reshard_after_forward logic #1094

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 1 commit into main
39 changes: 25 additions & 14 deletions torchtitan/models/llama3/parallelize_llama.py
@@ -353,30 +353,41 @@ def apply_fsdp(
     if cpu_offload:
         fsdp_config["offload_policy"] = CPUOffloadPolicy()
 
-    for layer_id, transformer_block in model.layers.items():
-        if reshard_after_forward_policy == "always":
+    match reshard_after_forward_policy:
+        case "always":
             reshard_after_forward = True
-        elif reshard_after_forward_policy == "never":
+        case "never":
             reshard_after_forward = False
-        elif reshard_after_forward_policy == "default":
-            if pp_enabled:
-                # For PP, do not reshard after forward to avoid per-microbatch
-                # all-gathers, which can be expensive and non-overlapped
-                reshard_after_forward = False
-            else:
-                # As an optimization, do not reshard after forward for the last
-                # transformer block since FSDP would prefetch it immediately
-                reshard_after_forward = int(layer_id) < len(model.layers) - 1
-        else:
+        case "default":
+            # For PP, by default do not reshard after forward to avoid per-microbatch
+            # all-gathers, which can be expensive and non-overlapped
+            reshard_after_forward = not pp_enabled
+        case _:
             raise ValueError(
                 f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}."
             )
+
+    if model.tok_embeddings is not None:
+        fully_shard(
+            model.tok_embeddings,
+            **fsdp_config,
+            reshard_after_forward=reshard_after_forward,
+        )
+    for layer_id, transformer_block in model.layers.items():
         fully_shard(
             transformer_block,
             **fsdp_config,
             reshard_after_forward=reshard_after_forward,
         )
-    fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
+    # As an optimization, do not reshard_after_forward the last layers by default
+    # since FSDP would prefetch them immediately after the forward pass
+    if model.norm is not None and model.output is not None:
+        fully_shard(
+            [model.norm, model.output],
+            **fsdp_config,
+            reshard_after_forward=reshard_after_forward_policy == "always",
+        )
+    fully_shard(model, **fsdp_config)
 
 
 def apply_ddp(
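For readers skimming the diff, here is a minimal, standalone sketch of the flag mapping the new code establishes. The function name and the grouping into tok_embeddings / transformer blocks / norm+output are illustrative only, not part of the PR or the torchtitan API:

```python
# A minimal sketch of the policy resolution introduced above (Python 3.10+).
# It only mirrors the match/case mapping and the special cases for the
# norm/output modules; it is not the torchtitan implementation.

def resolve_reshard_flags(policy: str, pp_enabled: bool) -> dict[str, bool]:
    """Return the reshard_after_forward flag each module group would receive."""
    match policy:
        case "always":
            reshard_after_forward = True
        case "never":
            reshard_after_forward = False
        case "default":
            # With PP, avoid per-microbatch all-gathers by not resharding.
            reshard_after_forward = not pp_enabled
        case _:
            raise ValueError(f"Invalid reshard_after_forward_policy: {policy}.")

    return {
        "tok_embeddings": reshard_after_forward,
        "transformer_blocks": reshard_after_forward,
        # norm/output are only resharded when explicitly forced, since FSDP
        # would prefetch them again right after the forward pass anyway.
        "norm_and_output": policy == "always",
    }


# Example: the "default" policy without pipeline parallelism.
assert resolve_reshard_flags("default", pp_enabled=False) == {
    "tok_embeddings": True,
    "transformer_blocks": True,
    "norm_and_output": False,
}
```

Note that the root `fully_shard(model, **fsdp_config)` call no longer passes `reshard_after_forward` explicitly, so the root module now falls back to FSDP's own default behavior.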
10 changes: 5 additions & 5 deletions torchtitan/tools/utils.py
@@ -33,17 +33,17 @@ def __init__(self, gc_freq=1000):
         assert gc_freq > 0, "gc_freq must be a positive integer"
         self.gc_freq = gc_freq
         gc.disable()
-        self.collect("Initial GC collection.")
+        self.collect("Initial GC collection")
 
     def run(self, step_count):
         if step_count > 1 and step_count % self.gc_freq == 0:
-            self.collect("Peforming periodical GC collection.")
+            self.collect("Peforming periodical GC collection")
 
     @staticmethod
     def collect(reason: str):
         begin = time.monotonic()
         gc.collect(1)
-        logger.info("[GC] %s %.2f seconds.", reason, time.monotonic() - begin)
+        logger.info("[GC] %s %.2f seconds", reason, time.monotonic() - begin)
 
 
 # hardcoded BF16 type peak flops for NVIDIA A100, H100, H200 GPU and AMD MI250, MI300X, AMD MI325X and Intel PVC
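For context, a small sketch of how this GC helper is typically driven from a training loop. The class name `GarbageCollection` and the surrounding loop are assumptions for illustration; only `gc_freq`, `run`, and `collect` appear in this diff:

```python
# Hypothetical usage sketch of the GC helper shown above.
import gc
import time


class GarbageCollection:
    def __init__(self, gc_freq: int = 1000):
        assert gc_freq > 0, "gc_freq must be a positive integer"
        self.gc_freq = gc_freq
        gc.disable()  # take manual control of collection timing
        self.collect("Initial GC collection")

    def run(self, step_count: int):
        # Collect only every gc_freq steps to keep pauses off the hot path.
        if step_count > 1 and step_count % self.gc_freq == 0:
            self.collect("Performing periodic GC collection")

    @staticmethod
    def collect(reason: str):
        begin = time.monotonic()
        gc.collect(1)  # a generation-1 pass is usually enough between steps
        print(f"[GC] {reason} {time.monotonic() - begin:.2f} seconds")


# Drive it from the training loop (loop body omitted).
gc_handler = GarbageCollection(gc_freq=500)
for step in range(1, 2001):
    # ... forward / backward / optimizer step would go here ...
    gc_handler.run(step)
```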
@@ -132,12 +132,12 @@ def check_if_feature_in_pytorch(
     if "git" in torch.__version__:  # pytorch is built from source
         # notify users to check if the pull request is included in their pytorch
         logger.warning(
-            "detected that the pytorch is built from source. Please make sure the PR "
+            "Detected that the pytorch is built from source. Please make sure the PR "
             f"({pull_request_link}) is included in pytorch for correct {feature_name}."
         )
     elif min_nightly_version is not None and torch.__version__ < min_nightly_version:
         logger.warning(
-            f"detected that the pytorch version {torch.__version__} is older than "
+            f"Detected that the pytorch version {torch.__version__} is older than "
             f"{min_nightly_version}. Please upgrade a newer version to include the "
             f"change in ({pull_request_link}) for correct {feature_name}."
         )
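The helper touched here is a version/feature guard: it warns when the running PyTorch build may predate a required upstream change. A self-contained sketch of the same pattern follows; the function and parameter names are assumptions for illustration, since the real signature is not shown in this diff, and the example values are placeholders:

```python
# Simplified sketch of the version-guard pattern used by
# check_if_feature_in_pytorch; names and values below are illustrative only.
import logging

import torch

logger = logging.getLogger(__name__)


def warn_if_feature_missing(
    feature_name: str,
    pull_request_link: str,
    min_nightly_version: str | None = None,
) -> None:
    if "git" in torch.__version__:  # pytorch is built from source
        # Source builds carry a "git" suffix, so a version comparison is
        # unreliable; ask the user to verify the PR is included.
        logger.warning(
            "Detected that the pytorch is built from source. Please make sure the PR "
            f"({pull_request_link}) is included in pytorch for correct {feature_name}."
        )
    elif min_nightly_version is not None and torch.__version__ < min_nightly_version:
        logger.warning(
            f"Detected that the pytorch version {torch.__version__} is older than "
            f"{min_nightly_version}. Please upgrade a newer version to include the "
            f"change in ({pull_request_link}) for correct {feature_name}."
        )


# Example call with placeholder values.
warn_if_feature_missing(
    feature_name="FSDP2 resharding",
    pull_request_link="https://github.com/pytorch/pytorch/pull/000000",
    min_nightly_version="2.7.0.dev20250101",
)
```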
6 changes: 3 additions & 3 deletions torchtitan/train.py
@@ -294,7 +294,7 @@ def __init__(self, job_config: JobConfig):
             f"global batch size {job_config.training.batch_size * dp_degree}, "
             f"sequence length {job_config.training.seq_len}, "
             f"total steps {job_config.training.steps} "
-            f"(warmup {job_config.lr_scheduler.warmup_steps})."
+            f"(warmup {job_config.lr_scheduler.warmup_steps})"
         )
 
     def next_batch(
@@ -400,7 +400,7 @@ def train(self):
         job_config = self.job_config
 
         self.checkpointer.load(step=job_config.checkpoint.load_step)
-        logger.info(f"Training starts at step {self.step + 1}.")
+        logger.info(f"Training starts at step {self.step + 1}")
 
         with maybe_enable_profiling(
             job_config, global_step=self.step
@@ -478,4 +478,4 @@ def close(self) -> None:
 
         if torch.distributed.is_initialized():
             torch.distributed.destroy_process_group()
-            logger.info("Process group destroyed.")
+            logger.info("Process group destroyed")