Skip to content
This repository was archived by the owner on Nov 1, 2024. It is now read-only.
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions metaseq/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from metaseq.data import iterators, data_utils
from metaseq.data.plasma_utils import PlasmaStore
from metaseq.dataclass.utils import convert_namespace_to_omegaconf
from metaseq.distributed import fsdp_enable_wrap, fsdp_wrap, utils as distributed_utils
from metaseq.distributed import fsdp_enable_wrap, fsdp_wrap, FullyShardedDataParallel, utils as distributed_utils
from metaseq.file_io import PathManager
from metaseq.logging import meters, metrics, progress_bar
from metaseq.trainer import Trainer
Expand Down Expand Up @@ -144,10 +144,12 @@ def main(cfg: DictConfig) -> None:
cfg.distributed_training,
use_sharded_state=cfg.distributed_training.use_sharded_state,
):
model = fsdp_wrap(
task.build_model(cfg.model),
process_group=distributed_utils.get_data_parallel_group(),
)
model = task.build_model(cfg.model)
if not isinstance(model, FullyShardedDataParallel):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to confirm, this is for loading a consolidated model for training?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes.
I added support for changing the MP size at job launch, and for that I need to wrap the model in FullyShardedDataParallel inside build_model.
Since I don't want to double-wrap it, I needed to add this `if`.

model = fsdp_wrap(
model,
process_group=distributed_utils.get_data_parallel_group(),
)
else:
model = task.build_model(cfg.model)

Expand Down
54 changes: 35 additions & 19 deletions metaseq/data/cm3_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are the changes to the cm3 objectives that I landed in scaling_racm3, correct?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, exactly.


import numpy as np
import random
import torch

from typing import List, Optional, Tuple
Expand Down Expand Up @@ -76,6 +77,7 @@ def __init__(
to_skip=0,
permute_documents=True,
source_target=False,
percent_full_document_rotation: float = 0.0
):
super().__init__(
dataset,
Expand Down Expand Up @@ -106,6 +108,7 @@ def __init__(
self.sentinel_fixed = self.sentinel_method == "fixed"
self.allow_rotation_across_eod = allow_rotation_across_eod
self.eod = eod
self.percent_full_document_rotation = percent_full_document_rotation

def get_sentinel(self, i):
    """Return the entry at index ``i`` of ``self.sentinel_tokens``."""
    sentinel = self.sentinel_tokens[i]
    return sentinel
Expand Down Expand Up @@ -139,7 +142,8 @@ def sentinel_targets(self, document: torch.Tensor, spans: List[Tuple[int, int]])
index = index + size + 1
return target

def get_spans_to_mask(self, document_length: int) -> List[Tuple[int, int]]:
def get_spans_to_mask(self, document_length: int, document_boundaries: List[Tuple[int, int]]) -> List[
Tuple[int, int]]:
# Ok, we do not use a budget here but instead
# our goal is to sample from ~ U[0,1] in the case of len(sentinel_tokens) = 1
# If len(sentinel_tokens) > 1 we try to find len(sentinel_tokens) non intersecting spans
Expand All @@ -156,18 +160,23 @@ def get_spans_to_mask(self, document_length: int) -> List[Tuple[int, int]]:
if len_sentinel_tokens == 0:
return None
if len_sentinel_tokens == 1:
if np.random.random() < self.percent_full_document_rotation:
return [random.choice(document_boundaries)]

start, end = np.random.uniform(size=2)
if end < start:
start, end = end, start
# round down
start = int(start * document_length)
start = max(1, int(start * document_length))
# round up
end = int(end * document_length + 0.5)
if start == end:
return None
else:
assert start < end
return [(start, end)]
if len_sentinel_tokens < len(document_boundaries) and np.random.random() < self.percent_full_document_rotation:
return random.sample(document_boundaries, len_sentinel_tokens)

# Let's implement the general case. We will create len(self.sentinel_tokens) ** 2 possible candidates
# And we will filter one by one to insure no intersections. If we can't find anything then so be it.
Expand Down Expand Up @@ -200,24 +209,31 @@ def get_document_boundaries(self, item: torch.Tensor):
boundaries = boundaries + [item.size(0)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is get_document_boundaries() robust to the case where there are no break tokens?

spans = []
for i in range(1, len(boundaries)):
spans.append((boundaries[i - 1], boundaries[i]))
spans.append((boundaries[i - 1] + 1, boundaries[i]))
return spans

def cm3_shuffle(self, item):
    """Apply the sentinel masking rearrangement to a single block.

    Spans are requested from ``get_spans_to_mask``; unless
    ``allow_rotation_across_eod`` is set they are additionally passed
    through ``adjust_spans`` together with the document boundaries.

    Returns ``item`` itself when no spans were selected. Otherwise returns
    the sentinel-masked source followed by the sentinel targets, with the
    source shortened by however much the combined length exceeds
    ``self.tokens_per_sample`` and the result cut to at most
    ``self.tokens_per_sample`` elements.
    """
    assert len(item) > 0
    boundaries = self.get_document_boundaries(item)
    spans = self.get_spans_to_mask(len(item), boundaries)
    if not self.allow_rotation_across_eod and spans is not None:
        spans = adjust_spans(spans, boundaries)

    # Nothing to mask: hand the block back untouched.
    if spans is None:
        return item

    ordered = self.get_ordered_spans(spans)
    source = self.sentinel_masking(item, ordered)
    targets = self.sentinel_targets(item, ordered)

    # Trim the source side so the targets are kept when the combined
    # sequence is longer than the sample budget.
    overflow = (len(source) + len(targets)) - self.tokens_per_sample
    keep = len(source) - overflow
    shuffled = torch.cat([source[:keep], targets])
    # EOSS tokens can add just enough tokens to get off by 1-2.
    return shuffled[: self.tokens_per_sample]

def __iter__(self):
for packed_item in super().__iter__():
item = packed_item["block"]
assert len(item) > 0
spans = self.get_spans_to_mask(len(item))
if not self.allow_rotation_across_eod:
document_boundaries = self.get_document_boundaries(item)
spans = adjust_spans(spans, document_boundaries)
if spans is None:
yield packed_item
else:
spans = self.get_ordered_spans(spans)
causal_source = self.sentinel_masking(item, spans)
causal_masked = self.sentinel_targets(item, spans)
packed_item["block"] = torch.cat([causal_source, causal_masked])[
: self.tokens_per_sample
] # EOSS tokens can add just enough tokens to get off by 1-2.
yield packed_item
packed_item["block"] = self.cm3_shuffle(packed_item["block"])
yield packed_item
16 changes: 13 additions & 3 deletions metaseq/modules/megatron/fused_kernels/scaled_masked_softmax.h
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,9 @@ void dispatch_scaled_softmax_forward(
scaled_softmax_warp_forward<input_t, output_t, acc_t, 12>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 13: // 8192
scaled_softmax_warp_forward<input_t, output_t, acc_t, 13>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
default:
break;
}
Expand All @@ -541,7 +544,7 @@ void dispatch_scaled_masked_softmax_forward(
int attn_heads,
int pad_batches)
{
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096 );
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 8192 );
if (key_seq_len == 0) {
return;
} else {
Expand Down Expand Up @@ -617,6 +620,10 @@ void dispatch_scaled_masked_softmax_forward(
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 12>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 13: // 8192
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 13>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
default:
break;
}
Expand All @@ -634,7 +641,7 @@ void dispatch_scaled_masked_softmax_backward(
int batches,
int attn_heads)
{
TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 4096 );
TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 8192 );
if (key_seq_len == 0) {
return;
} else {
Expand Down Expand Up @@ -709,7 +716,10 @@ void dispatch_scaled_masked_softmax_backward(
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 12>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;

case 13: // 8192
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 13>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
default:
break;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ torch::Tensor fwd_cuda(
const int attn_heads = input.size(1);
const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3);
TORCH_INTERNAL_ASSERT(key_seq_len <= 4096);
TORCH_INTERNAL_ASSERT(key_seq_len <= 8192);
TORCH_INTERNAL_ASSERT(query_seq_len > 1);
TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
int softmax_elements_stride,
int attn_batches)
{
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 );
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 8192 );
if (softmax_elements == 0) {
return;
} else {
Expand Down Expand Up @@ -415,6 +415,14 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 12: // 4096
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 12>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 13: // 8192
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 13>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
default:
break;
}
Expand All @@ -431,7 +439,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
int softmax_elements_stride,
int attn_batches)
{
TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 );
TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 8192 );
if (softmax_elements == 0) {
return;
} else {
Expand Down Expand Up @@ -506,6 +514,14 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 12: // 4096
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 12>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 13: // 8192
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 13>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
default:
break;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ torch::Tensor fwd_cuda(
// input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
const int attn_batches = input.size(0);
const int seq_len = input.size(1);
TORCH_INTERNAL_ASSERT(seq_len <= 2048);
TORCH_INTERNAL_ASSERT(seq_len <= 8192);

// Output
auto act_options = input.options().requires_grad(false);
Expand Down
Loading