This repository was archived by the owner on Nov 1, 2024. It is now read-only.
Cm3 integration #727
Open: urielsinger wants to merge 25 commits into main from cm3_seq_len
Changes shown from 8 of 25 commits.

Commits:
dc55c3f: support seq_len > 2048 (4096 and 8192)
1a63cb5: old/new tokens conversion
6fba607: fsdp double wrap disable
3c40198: local symlink
d464d96: local symlink
f37dcc2: Merge remote-tracking branch 'origin/main'
56f8b54: pr fix
c5a58b6: revert "force_distributed=True"
f0b5275: fixed
d54eca8 (adampolyak): Merge remote-tracking branch 'origin/main'
d30921d: improve free port finding for single node dist init
865c4b3 (adampolyak): Merge remote-tracking branch 'origin/cm3_seq_len'
75b74e9: pytorch FSDP support
61f8792: fix bug
4a677a4: fix bug
af57884: back to fairscale
59556ee: back to fairscale
82d4c77: fix delete_old_checkpoint_files
a72de97: stop training when loss_scale reached minimum
00df75a: stop training when loss_scale reached minimum
bc68a84: add validate_on_first_step support
d5f50e8: fix for single files
a42d648 (adampolyak): add no_c10d support
2c484fa: Merge remote-tracking branch 'origin/cm3_seq_len' into cm3_seq_len
7e7b5e3: criterion fsdp
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.

Contributor: These are changes to the cm3 objectives that I landed in scaling_racm3, correct?
Author: Yes, exactly.

 import numpy as np
+import random
 import torch

 from typing import List, Optional, Tuple

@@ -76,6 +77,7 @@ def __init__(
         to_skip=0,
         permute_documents=True,
         source_target=False,
+        percent_full_document_rotation: float = 0.0,
     ):
         super().__init__(
             dataset,

@@ -106,6 +108,7 @@ def __init__(
         self.sentinel_fixed = self.sentinel_method == "fixed"
         self.allow_rotation_across_eod = allow_rotation_across_eod
         self.eod = eod
+        self.percent_full_document_rotation = percent_full_document_rotation

     def get_sentinel(self, i):
         return self.sentinel_tokens[i]

@@ -139,7 +142,8 @@ def sentinel_targets(self, document: torch.Tensor, spans: List[Tuple[int, int]])
             index = index + size + 1
         return target

-    def get_spans_to_mask(self, document_length: int) -> List[Tuple[int, int]]:
+    def get_spans_to_mask(self, document_length: int, document_boundaries: List[Tuple[int, int]]) -> List[
+            Tuple[int, int]]:
         # Ok, we do not use a budget here but instead
         # our goal is to sample from ~ U[0,1] in the case of len(sentinel_tokens) = 1
         # If len(sentinel_tokens) > 1 we try to find len(sentinel_tokens) non intersecting spans

@@ -156,18 +160,23 @@ def get_spans_to_mask(self, document_length: int) -> List[Tuple[int, int]]:
         if len_sentinel_tokens == 0:
             return None
         if len_sentinel_tokens == 1:
+            if np.random.random() < self.percent_full_document_rotation:
+                return [random.choice(document_boundaries)]
+
             start, end = np.random.uniform(size=2)
             if end < start:
                 start, end = end, start
             # round down
-            start = int(start * document_length)
+            start = max(1, int(start * document_length))
             # round up
             end = int(end * document_length + 0.5)
             if start == end:
                 return None
             else:
                 assert start < end
             return [(start, end)]
+        if len_sentinel_tokens < len(document_boundaries) and np.random.random() < self.percent_full_document_rotation:
+            return random.sample(document_boundaries, len_sentinel_tokens)

         # Let's implement the general case. We will create len(self.sentinel_tokens) ** 2 possible candidates
         # And we will filter one by one to insure no intersections. If we can't find anything then so be it.

@@ -200,24 +209,31 @@ def get_document_boundaries(self, item: torch.Tensor):
         boundaries = boundaries + [item.size(0)]

Contributor: Is get_document_boundaries() robust to the case where there are no break tokens?

         spans = []
         for i in range(1, len(boundaries)):
-            spans.append((boundaries[i - 1], boundaries[i]))
+            spans.append((boundaries[i - 1] + 1, boundaries[i]))
         return spans

+    def cm3_shuffle(self, item):
+        assert len(item) > 0
+        document_boundaries = self.get_document_boundaries(item)
+        spans = self.get_spans_to_mask(len(item), document_boundaries)
+        if not self.allow_rotation_across_eod and spans is not None:
+            spans = adjust_spans(spans, document_boundaries)
+        if spans is None:
+            return item
+        else:
+            spans = self.get_ordered_spans(spans)
+            causal_source = self.sentinel_masking(item, spans)
+            causal_masked = self.sentinel_targets(item, spans)
+
+            total_count = len(causal_source) + len(causal_masked)
+            total_diff = total_count - self.tokens_per_sample
+            total_causal_length = len(causal_source) - total_diff
+            return torch.cat([
+                causal_source[:total_causal_length],
+                causal_masked
+            ])[: self.tokens_per_sample]  # EOSS tokens can add just enough tokens to get off by 1-2.

     def __iter__(self):
         for packed_item in super().__iter__():
             item = packed_item["block"]
             assert len(item) > 0
-            spans = self.get_spans_to_mask(len(item))
-            if not self.allow_rotation_across_eod:
-                document_boundaries = self.get_document_boundaries(item)
-                spans = adjust_spans(spans, document_boundaries)
-            if spans is None:
-                yield packed_item
-            else:
-                spans = self.get_ordered_spans(spans)
-                causal_source = self.sentinel_masking(item, spans)
-                causal_masked = self.sentinel_targets(item, spans)
-                packed_item["block"] = torch.cat([causal_source, causal_masked])[
-                    : self.tokens_per_sample
-                ]  # EOSS tokens can add just enough tokens to get off by 1-2.
-                yield packed_item
+            packed_item["block"] = self.cm3_shuffle(packed_item["block"])
+            yield packed_item
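To make the single-sentinel branch of the patched `get_spans_to_mask` easier to check in isolation, here is a standalone paraphrase of its sampling logic. The function name `sample_span` is hypothetical, and it returns a single span (or `None`) rather than the one-element list the real method returns; the clamp `max(1, ...)` and the `percent_full_document_rotation` shortcut follow the diff above.

```python
import random

import numpy as np


def sample_span(document_length, document_boundaries, percent_full_document_rotation=0.0):
    """Sketch of the single-sentinel span sampling from the patch.

    With probability percent_full_document_rotation, rotate a whole
    document (one of document_boundaries) instead of a random span.
    Otherwise draw start/end ~ U[0, 1], scale to token positions, and
    clamp start to >= 1 as the patch does.
    """
    if np.random.random() < percent_full_document_rotation:
        # Full-document rotation: pick an existing document span.
        return random.choice(document_boundaries)

    start, end = np.random.uniform(size=2)
    if end < start:
        start, end = end, start
    # Round down, but never allow the span to start at position 0.
    start = max(1, int(start * document_length))
    # Round up.
    end = int(end * document_length + 0.5)
    if start >= end:
        # Degenerate span; the caller leaves the item unmasked.
        # (The original asserts start < end after handling start == end.)
        return None
    return (start, end)
```

This makes it easy to sanity-check that sampled spans always fall inside `[1, document_length]` and that a rotation probability of 1.0 always returns a whole document.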
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
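The length bookkeeping in `cm3_shuffle` (how `total_count`, `total_diff`, and `total_causal_length` interact) can be verified with a tiny pure-Python paraphrase; `truncate_to_sample` is a hypothetical helper name, and plain lists stand in for the tensors that `torch.cat` handles in the real code.

```python
def truncate_to_sample(causal_source, causal_masked, tokens_per_sample):
    """Mirror cm3_shuffle's trimming: drop tokens from the causal
    source part so that source + masked targets fit the budget."""
    total_count = len(causal_source) + len(causal_masked)
    total_diff = total_count - tokens_per_sample
    total_causal_length = len(causal_source) - total_diff
    # Trim the source, keep the masked targets intact, then apply the
    # final budget slice (EOSS tokens can still leave it off by 1-2).
    return (causal_source[:total_causal_length] + causal_masked)[:tokens_per_sample]


# Example: a 10-token source plus 6 masked-target tokens, budget 12,
# drops 4 tokens from the end of the source and keeps all 6 targets.
out = truncate_to_sample(list(range(10)), list(range(6)), tokens_per_sample=12)
```

The point of trimming the source rather than the targets is that the sentinel targets carry the masked spans the objective trains on, so they are preserved in full whenever possible.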
Contributor: Just to confirm, this is for loading up the consolidated model for training?

Author: Yes. I added support for changing the MP size during job launch, and for that I need to wrap the model in FullyShardedDataParallel inside build_model. Since I don't want to double-wrap it, I needed to add this if.
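The double-wrap guard described above might look like the following sketch. It is not the PR's actual code: `maybe_wrap_fsdp` is a hypothetical helper, and `fsdp_cls` stands in for fairscale's `FullyShardedDataParallel` so the guard logic can be shown without the library.

```python
def maybe_wrap_fsdp(module, fsdp_cls):
    """Wrap `module` in FSDP only if it is not already wrapped.

    If build_model already returned an fsdp_cls instance (e.g. because
    the model was sharded there to change the MP size), wrapping it
    again would double-wrap it, so return it unchanged instead.
    """
    if isinstance(module, fsdp_cls):
        return module  # already sharded inside build_model
    return fsdp_cls(module)
```

With the real fairscale class this would be called as `maybe_wrap_fsdp(model, FullyShardedDataParallel)` at the outer wrapping site, making the outer wrap a no-op when the inner one already happened.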