78 changes: 57 additions & 21 deletions bionemo-recipes/models/esm2/src/esm/collator.py
@@ -20,7 +20,7 @@

import logging
from dataclasses import dataclass
from typing import Any
from typing import Any, TypedDict

import datasets
import torch
@@ -334,7 +334,7 @@ class ContextParallelDataLoaderWrapper:

def __init__(
self,
dataloader: torch.utils.data.DataLoader,
dataloader: torch.utils.data.DataLoader | None,
cp_mesh: torch.distributed.device_mesh.DeviceMesh,
):
"""A dataloader wrapper that distributes the data across the context parallelism group.
@@ -348,15 +348,28 @@ def __init__(
cp_mesh: The context parallel mesh.
dataloader: The dataloader to wrap. Must be provided on CP rank 0 and None on all other ranks.
"""
self.dataloader = dataloader
if cp_mesh.get_local_rank() == 0:
assert dataloader is not None, "dataloader must be provided on rank 0"
self.dataloader = dataloader

else:
assert dataloader is None, "Dataloader on non-rank 0 will not be used"

self.cp_rank = cp_mesh.get_local_rank()
self.cp_group = cp_mesh.get_group()
self.num_cp_ranks = cp_mesh.size()
self._iterator = None

logger.debug(
"Created ContextParallelDataLoaderWrapper on global rank %s, cp rank %s",
torch.distributed.get_rank() if torch.distributed.is_initialized() else "<not initialized>",
self.cp_rank,
)

def __iter__(self):
"""Make the dataloader iterable."""
self._iterator = iter(self.dataloader) # < --- collator output.
if self.cp_rank == 0:
self._iterator = iter(self.dataloader)  # The collator yields a list with one batch per CP rank.
return self

def __next__(self):
@@ -385,24 +398,19 @@ def _send_data_to_cp_ranks(self):
batch: The batch for the current CP rank.

"""
if self.cp_rank == 0:
# Get data once, then make copies for each rank.
if self._iterator is None:
self._iterator = iter(self.dataloader)
combined_batch = next(self._iterator)
try:
combined_batch = next(self._iterator) if self.cp_rank == 0 else None
except StopIteration as ex:
# If we encounter a StopIteration in the dataloader, we want to raise this error on all the CP ranks, so
# that the dataloader can be restarted.
combined_batch = [ex] * self.num_cp_ranks

else:
combined_batch = None

scatter_object_output_list = [None]
# Note: This does not provide an async_op handle. Thus its blocking.
torch.distributed.scatter_object_list(
scatter_object_output_list=scatter_object_output_list,
scatter_object_input_list=combined_batch,
group=self.cp_group,
group_src=0,
)
return scatter_object_output_list[0]
batch_on_this_rank = _scatter_batch_to_cp_ranks(combined_batch, self.cp_group)

if isinstance(batch_on_this_rank, StopIteration):
raise batch_on_this_rank

return batch_on_this_rank


def _split_sample_by_num_tokens(sample: dict[str, Any], num_tokens: int) -> tuple[dict[str, Any], dict[str, Any]]:
@@ -670,3 +678,31 @@ def _get_group_local_rank(group: torch.distributed.ProcessGroup | None = None) -
return torch.distributed.get_rank()
global_rank = torch.distributed.get_rank()
return torch.distributed.get_group_rank(group, global_rank)


class BatchType(TypedDict):
"""The fields in the batch dictionary for context parallel."""

input_ids: torch.Tensor
labels: torch.Tensor
cu_seq_lens_q: torch.Tensor
cu_seq_lens_k: torch.Tensor
cu_seq_lens_q_padded: torch.Tensor
cu_seq_lens_k_padded: torch.Tensor
max_length_q: int
max_length_k: int


def _scatter_batch_to_cp_ranks(
batch: list[BatchType] | list[StopIteration], cp_group: torch.distributed.ProcessGroup | None = None
) -> BatchType | StopIteration:
"""Scatter a batch to all the CP ranks."""
scatter_object_output_list = [None]
# Note: This does not provide an async_op handle. Thus it's blocking.
torch.distributed.scatter_object_list(
scatter_object_output_list=scatter_object_output_list,
scatter_object_input_list=batch,
group=cp_group,
group_src=0,
)
return scatter_object_output_list[0]
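
For reference, a minimal standalone sketch of the rank-0 scatter pattern the wrapper now implements (hypothetical script, not part of this PR; it assumes a PyTorch version that accepts the group_src keyword, as the diff above does, and can be launched with torchrun --nproc_per_node=2):

# cp_scatter_sketch.py -- hypothetical illustration of the rank-0 scatter pattern.
import torch.distributed as dist


def scatter_from_rank0(objects, group=None):
    """Scatter one object per rank from rank 0; non-source ranks pass None."""
    out = [None]
    # Blocking collective: object scatter does not provide an async_op handle.
    dist.scatter_object_list(out, objects, group=group, group_src=0)
    return out[0]


def main():
    dist.init_process_group("gloo")
    rank, world = dist.get_rank(), dist.get_world_size()

    # Only rank 0 materializes batches, mirroring ContextParallelDataLoaderWrapper.
    batches = [{"shard": i} for i in range(world)] if rank == 0 else None
    print(f"rank {rank} received {scatter_from_rank0(batches)}")

    # End of epoch: rank 0 scatters a StopIteration to every rank, and each
    # rank re-raises it locally so the whole CP group restarts together.
    sentinel = [StopIteration()] * world if rank == 0 else None
    if isinstance(scatter_from_rank0(sentinel), StopIteration):
        print(f"rank {rank} observed end of epoch")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()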
@@ -15,7 +15,6 @@

import copy
import unittest
from itertools import pairwise
from typing import Dict, Iterator, List
from unittest import mock

@@ -229,57 +228,6 @@ def size(self) -> int:
return self._size


def _fake_get_batch(
cu_seqlens_padded,
input_ids_padded,
labels_padded,
cp_size,
qvk_format,
cp_rank,
):
total_slices = 2 * cp_size
seq_tokens = input_ids_padded.view(-1)
seq_labels = labels_padded.view(-1)
shard_tokens: List[torch.Tensor] = []
shard_labels: List[torch.Tensor] = []

for start, end in pairwise(cu_seqlens_padded):
start_idx = int(start)
end_idx = int(end)
slice_size = (end_idx - start_idx) // total_slices

first_start = start_idx + (cp_rank * slice_size)
first_end = first_start + slice_size
second_start = start_idx + ((total_slices - cp_rank - 1) * slice_size)
second_end = second_start + slice_size

shard_tokens.append(torch.cat([seq_tokens[first_start:first_end], seq_tokens[second_start:second_end]]))
shard_labels.append(torch.cat([seq_labels[first_start:first_end], seq_labels[second_start:second_end]]))

return (
torch.cat(shard_tokens).unsqueeze(0),
torch.cat(shard_labels).unsqueeze(0),
)


def _make_cp_shards(base_batch: Dict[str, torch.Tensor], cp_size: int):
combined_batch = []
for cp_rank in range(cp_size):
input_ids_sharded, labels_sharded = _fake_get_batch(
cu_seqlens_padded=base_batch["cu_seq_lens_q_padded"],
input_ids_padded=base_batch["input_ids"],
labels_padded=base_batch["labels"],
cp_size=cp_size,
qvk_format="thd",
cp_rank=cp_rank,
)
batch_shard = dict(base_batch)
batch_shard["input_ids"] = input_ids_sharded
batch_shard["labels"] = labels_sharded
combined_batch.append(batch_shard)
return combined_batch


def test_pad_thd_sequences_for_cp():
pid = 1 # The pad token id.
label_pad = -100 # The label pad id.
@@ -410,7 +358,7 @@ def run_roundtrip(base_batch):
cp_mesh_rank0 = _DummyDeviceMesh(size=cp_size, rank=0)
cp_mesh_rank1 = _DummyDeviceMesh(size=cp_size, rank=1)
loader_rank0 = ContextParallelDataLoaderWrapper(_DummyLoader(combined_batch), cp_mesh_rank0)
loader_rank1 = ContextParallelDataLoaderWrapper(_DummyLoader(combined_batch), cp_mesh_rank1)
loader_rank1 = ContextParallelDataLoaderWrapper(None, cp_mesh_rank1)

scatter_payload: Dict[str, List[Dict[str, torch.Tensor]]] = {}
current_rank = {"value": None}
@@ -499,7 +447,7 @@ def run_roundtrip(base_batch):
cp_mesh_rank0 = _DummyDeviceMesh(size=cp_size, rank=0)
cp_mesh_rank1 = _DummyDeviceMesh(size=cp_size, rank=1)
loader_rank0 = ContextParallelDataLoaderWrapper(_DummyLoader(combined_batch), cp_mesh_rank0)
loader_rank1 = ContextParallelDataLoaderWrapper(_DummyLoader(combined_batch), cp_mesh_rank1)
loader_rank1 = ContextParallelDataLoaderWrapper(None, cp_mesh_rank1)

scatter_payload: Dict[str, List[Dict[str, torch.Tensor]]] = {}
current_rank = {"value": None}
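
The tests above drive the wrapper in a single process by substituting test doubles for the device mesh and dataloader. A condensed sketch of those doubles (class bodies are assumptions reconstructed from the names and calls in the diff; the real tests also mock the scatter collective):

class _DummyDeviceMesh:
    """Stands in for torch.distributed.device_mesh.DeviceMesh in tests."""

    def __init__(self, size: int, rank: int):
        self._size, self._rank = size, rank

    def get_local_rank(self) -> int:
        return self._rank

    def get_group(self):
        return None  # unused here: scatter_object_list is mocked in the tests

    def size(self) -> int:
        return self._size


class _DummyLoader:
    """Stands in for a DataLoader; yields pre-built combined batches."""

    def __init__(self, batches):
        self._batches = batches

    def __iter__(self):
        return iter(self._batches)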
78 changes: 57 additions & 21 deletions bionemo-recipes/recipes/esm2_native_te/collator.py
@@ -20,7 +20,7 @@

import logging
from dataclasses import dataclass
from typing import Any
from typing import Any, TypedDict

import datasets
import torch
@@ -334,7 +334,7 @@ class ContextParallelDataLoaderWrapper:

def __init__(
self,
dataloader: torch.utils.data.DataLoader,
dataloader: torch.utils.data.DataLoader | None,
cp_mesh: torch.distributed.device_mesh.DeviceMesh,
):
"""A dataloader wrapper that distributes the data across the context parallelism group.
@@ -348,15 +348,28 @@ def __init__(
cp_mesh: The context parallel mesh.
dataloader: The dataloader to wrap. Must be provided on CP rank 0 and None on all other ranks.
"""
self.dataloader = dataloader
if cp_mesh.get_local_rank() == 0:
assert dataloader is not None, "dataloader must be provided on rank 0"
self.dataloader = dataloader

else:
assert dataloader is None, "Dataloader on non-rank 0 will not be used"

self.cp_rank = cp_mesh.get_local_rank()
self.cp_group = cp_mesh.get_group()
self.num_cp_ranks = cp_mesh.size()
self._iterator = None

logger.debug(
"Created ContextParallelDataLoaderWrapper on global rank %s, cp rank %s",
torch.distributed.get_rank() if torch.distributed.is_initialized() else "<not initialized>",
self.cp_rank,
)

def __iter__(self):
"""Make the dataloader iterable."""
self._iterator = iter(self.dataloader) # < --- collator output.
if self.cp_rank == 0:
self._iterator = iter(self.dataloader)  # The collator yields a list with one batch per CP rank.
return self

def __next__(self):
@@ -385,24 +398,19 @@ def _send_data_to_cp_ranks(self):
batch: The batch for the current CP rank.

"""
if self.cp_rank == 0:
# Get data once, then make copies for each rank.
if self._iterator is None:
self._iterator = iter(self.dataloader)
combined_batch = next(self._iterator)
try:
combined_batch = next(self._iterator) if self.cp_rank == 0 else None
except StopIteration as ex:
# If we encounter a StopIteration in the dataloader, we want to raise this error on all the CP ranks, so
# that the dataloader can be restarted.
combined_batch = [ex] * self.num_cp_ranks

else:
combined_batch = None

scatter_object_output_list = [None]
# Note: This does not provide an async_op handle. Thus its blocking.
torch.distributed.scatter_object_list(
scatter_object_output_list=scatter_object_output_list,
scatter_object_input_list=combined_batch,
group=self.cp_group,
group_src=0,
)
return scatter_object_output_list[0]
batch_on_this_rank = _scatter_batch_to_cp_ranks(combined_batch, self.cp_group)

if isinstance(batch_on_this_rank, StopIteration):
raise batch_on_this_rank

return batch_on_this_rank


def _split_sample_by_num_tokens(sample: dict[str, Any], num_tokens: int) -> tuple[dict[str, Any], dict[str, Any]]:
@@ -670,3 +678,31 @@ def _get_group_local_rank(group: torch.distributed.ProcessGroup | None = None) -
return torch.distributed.get_rank()
global_rank = torch.distributed.get_rank()
return torch.distributed.get_group_rank(group, global_rank)


class BatchType(TypedDict):
"""The fields in the batch dictionary for context parallel."""

input_ids: torch.Tensor
labels: torch.Tensor
cu_seq_lens_q: torch.Tensor
cu_seq_lens_k: torch.Tensor
cu_seq_lens_q_padded: torch.Tensor
cu_seq_lens_k_padded: torch.Tensor
max_length_q: int
max_length_k: int


def _scatter_batch_to_cp_ranks(
batch: list[BatchType] | list[StopIteration], cp_group: torch.distributed.ProcessGroup | None = None
) -> BatchType | StopIteration:
"""Scatter a batch to all the CP ranks."""
scatter_object_output_list = [None]
# Note: This does not provide an async_op handle. Thus it's blocking.
torch.distributed.scatter_object_list(
scatter_object_output_list=scatter_object_output_list,
scatter_object_input_list=batch,
group=cp_group,
group_src=0,
)
return scatter_object_output_list[0]
15 changes: 10 additions & 5 deletions bionemo-recipes/recipes/esm2_native_te/dataset.py
@@ -252,11 +252,16 @@ def create_cp_dataloader(
logger.info("pad_sequences_to_be_divisible_by is not provided, using cp_mesh.size() * 2")
kwargs["pad_sequences_to_be_divisible_by"] = cp_mesh.size() * 2

train_dataloader, tokenized_dataset = create_thd_dataloader(*args, **kwargs)
if cp_mesh.get_local_rank() == 0:
train_dataloader, tokenized_dataset = create_thd_dataloader(*args, **kwargs)

train_dataloader.collate_fn = DataCollatorForContextParallel(
collator=train_dataloader.collate_fn,
cp_world_size=cp_mesh.size(),
)
train_dataloader.collate_fn = DataCollatorForContextParallel(
collator=train_dataloader.collate_fn,
cp_world_size=cp_mesh.size(),
)

else:
train_dataloader = None
tokenized_dataset = None

return ContextParallelDataLoaderWrapper(train_dataloader, cp_mesh), tokenized_dataset
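
To make the new rank-0-only construction concrete: only rank 0 builds the THD dataloader and wraps its collate_fn so each fetched batch is a list with one shard per CP rank, which the wrapper then scatters. A toy stand-in for that collator (hypothetical; the real DataCollatorForContextParallel slices each padded sequence along the token axis rather than replicating it):

from typing import Any, Callable


class ToyContextParallelCollator:
    """Toy stand-in: wraps a base collator and emits one batch per CP rank."""

    def __init__(self, collator: Callable[[list[Any]], dict], cp_world_size: int):
        self.collator = collator
        self.cp_world_size = cp_world_size

    def __call__(self, features: list[Any]) -> list[dict]:
        batch = self.collator(features)
        # Replicate the batch once per rank for illustration only; the real
        # collator produces genuinely different token shards per CP rank.
        return [dict(batch, cp_rank=rank) for rank in range(self.cp_world_size)]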
3 changes: 2 additions & 1 deletion bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py
@@ -197,7 +197,8 @@ def main(args: DictConfig) -> float | None:

# Dataloader exhausted, incrementing epoch
epoch += 1
dataset_or_sampler.set_epoch(epoch)
if dataset_or_sampler is not None: # The dataset only exists on rank 0
dataset_or_sampler.set_epoch(epoch)

# Save final model to a .safetensors file.
if args.checkpoint.save_final_model and ckpt_path:
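
Since the dataset and sampler now exist only on CP rank 0, the epoch bump must be guarded on every other rank. A fully self-contained toy version of the loop (variable names mirror the training script; the same guard appears in train_fsdp2_cp.py below):

# Toy stand-ins for the real dataloader and sampler.
train_dataloader = [{"input_ids": i} for i in range(4)]
dataset_or_sampler = None  # a sampler on CP rank 0, None everywhere else
num_train_steps, step, epoch = 10, 0, 0

train_iterator = iter(train_dataloader)
while step < num_train_steps:
    try:
        batch = next(train_iterator)
    except StopIteration:
        # Every CP rank raises together (the wrapper scatters StopIteration),
        # so the epoch counter stays in sync across the group.
        epoch += 1
        if dataset_or_sampler is not None:  # the dataset only exists on rank 0
            dataset_or_sampler.set_epoch(epoch)
        train_iterator = iter(train_dataloader)
        continue
    step += 1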
3 changes: 2 additions & 1 deletion bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py
@@ -214,7 +214,8 @@ def main(args: DictConfig) -> float | None:

# Dataloader exhausted, incrementing epoch
epoch += 1
dataset_or_sampler.set_epoch(epoch)
if dataset_or_sampler is not None: # The dataset only exists on rank 0
dataset_or_sampler.set_epoch(epoch)

# Save final model to a .safetensors file.
if args.checkpoint.save_final_model and ckpt_path:
14 changes: 8 additions & 6 deletions bionemo-recipes/recipes/llama3_native_te/checkpoint.py
@@ -328,13 +328,15 @@ def save_checkpoint_fsdp2(
)
logger.info(f"Saved FSDP2 dataloader to {ckpt_path}")

# If we're using asynchronous checkpointing, make sure we only have one checkpoint future at a time.
if async_save and "fsdp2" in _ckpt_futures and _ckpt_futures["fsdp2"] is not None:
_ckpt_futures["fsdp2"].result()

state_dict = {"app": AppState(model=model, optimizer=optimizer, scheduler=scheduler, step=step, epoch=epoch)}
ckpt_save_func = dcp_async_save if async_save else dcp_save
_ckpt_futures["fsdp2"] = ckpt_save_func(state_dict, checkpoint_id=checkpoint_path, process_group=process_group)
if async_save:
# If we're using asynchronous checkpointing, make sure we only have one checkpoint future at a time.
if "fsdp2" in _ckpt_futures and _ckpt_futures["fsdp2"] is not None:
_ckpt_futures["fsdp2"].result()

_ckpt_futures["fsdp2"] = dcp_async_save(state_dict, checkpoint_id=checkpoint_path, process_group=process_group)
else:
dcp_save(state_dict, checkpoint_id=checkpoint_path, process_group=process_group)

if max_checkpoints is not None and dist_config.is_main_process():
prune_checkpoints(ckpt_path, max_checkpoints)
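
The restructured branch guarantees at most one in-flight asynchronous save. A condensed sketch of the pattern (assuming torch.distributed.checkpoint's save/async_save; the recipe's own dcp_save/dcp_async_save helpers are assumed to wrap these):

import torch.distributed.checkpoint as dcp

_ckpt_future = None


def save_once(state_dict, checkpoint_id, async_save=False):
    """Keep at most one outstanding async save; block on the previous first."""
    global _ckpt_future
    if async_save:
        if _ckpt_future is not None:
            _ckpt_future.result()  # drain the prior async save before starting
        _ckpt_future = dcp.async_save(state_dict, checkpoint_id=checkpoint_id)
    else:
        dcp.save(state_dict, checkpoint_id=checkpoint_id)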