-
Notifications
You must be signed in to change notification settings - Fork 295
[train] Make TrainingInputBatch to PAD only to left, hence response tensors be right-aligned #1285
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
CharlieFRuan marked this conversation as resolved.
Show resolved
Hide resolved
|
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
| @@ -1,9 +1,12 @@ | ||||
| import logging | ||||
| from typing import List, Optional, Tuple | ||||
|
|
||||
| import torch | ||||
| from jaxtyping import Float, Integer | ||||
| from transformers import AutoTokenizer | ||||
|
|
||||
| logger = logging.getLogger(__name__) | ||||
|
|
||||
|
|
||||
| def _verify_inputs( | ||||
| prompts: List[List[int]], | ||||
|
|
@@ -34,6 +37,7 @@ def convert_prompts_responses_to_batch_tensors( | |||
| loss_masks: List[List[int]], | ||||
| logprobs: Optional[List[List[float]]] = None, | ||||
| rollout_expert_indices: Optional[List[List[List[List[int]]]]] = None, | ||||
| max_seq_len: Optional[int] = None, | ||||
| ) -> Tuple[ | ||||
| Float[torch.Tensor, "batch seq_len"], | ||||
| Float[torch.Tensor, "batch seq_len"], | ||||
|
|
@@ -46,12 +50,33 @@ def convert_prompts_responses_to_batch_tensors( | |||
| """ | ||||
| Convert prompts and responses to batch tensors for training. | ||||
|
|
||||
| This function concatenates all prompts and responses to the following format: | ||||
| Each sequence is laid out as a single left-padded block: | ||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this in general makes sense to me, but can you test the forward pass for megatron in
just running a basic forward pass for any model should suffice
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. made a megatron gsm8k run after rebasing, seems normal. Added to the PR description. I'll address some R3 things before merging |
||||
|
|
||||
| | [PAD] [PAD] prompt prompt prompt respon respon | | ||||
| | [PAD] prompt prompt prompt respon respon respon | | ||||
| | prompt prompt prompt respon respon respon respon | | ||||
| |<---- max_response_len ---->| | ||||
|
|
||||
| The padded sequence length is ``max(prompt_len_i + response_len_i)`` over the batch, | ||||
| so it stays within ``max_seq_len`` whenever every generated sequence respects that limit. | ||||
|
|
||||
| This makes the response-level tensors (action_mask, rewards, loss_masks, logprobs): | ||||
| | prompt prompt respon respon | | ||||
| | prompt respon respon respon | | ||||
| | respon respon respon respon | | ||||
|
|
||||
| So the action_mask is: | ||||
| | 0 0 1 1 | | ||||
| | 0 1 1 1 | | ||||
| | 1 1 1 1 | | ||||
|
|
||||
| Attention mask is 1 for all real tokens, 0 for padding. | ||||
| Action mask is 1 for the last ``response_len_i`` positions, 0 for padding. | ||||
|
|
||||
| | [PAD] [PAD] token token token | token token [PAD] [PAD] | | ||||
| | token token token token token | token token [PAD] [PAD] | | ||||
| | [PAD] [PAD] [PAD] token token | token token token [PAD] | | ||||
| |<---------- prompt ----------->|<-------- answer ------->| | ||||
| Response-level tensors are **right-aligned** within ``(batch, max_response_len)``: non-padded | ||||
| values occupy the last ``response_len_i`` positions, with leading zeros. This matches the model | ||||
| forward pass which extracts ``log_probs[:, -num_actions-1:-1]`` — response tokens are always at | ||||
| the end of the sequence, so their logprobs are right-aligned in the slice. | ||||
|
|
||||
| Assumes that the responses already contain an eos token at index -1. | ||||
|
|
||||
|
|
@@ -62,88 +87,89 @@ def convert_prompts_responses_to_batch_tensors( | |||
| rewards: List of rewards for each response | ||||
| loss_masks: List of loss masks for each response | ||||
| logprobs: List of rollout log probs for each response | ||||
| max_seq_len: Optional. If provided and ``max(prompt_i + response_i)`` | ||||
| exceeds it, a warning is logged (no truncation is performed). | ||||
|
|
||||
| Returns: | ||||
| sequences: Full trajectories (padded and concatenated prompts and responses). Size: (batch, seq_len). | ||||
| attention_mask: Attention mask for the model. Size: (batch, seq_len) | ||||
| action_mask: Response mask for the model. Size: (batch, response_len) | ||||
| rewards: Rewards for each output. Size: (batch, response_len) | ||||
| loss_masks: Loss masks for each output. Size: (batch, response_len) | ||||
| sequences: ``(batch, max_total)`` where ``max_total = max(prompt_i + response_i)``. | ||||
| attention_mask: ``(batch, max_total)`` | ||||
| action_mask: ``(batch, max_response)`` — right-aligned response indicator. | ||||
| rewards: ``(batch, max_response)`` — right-aligned. | ||||
| loss_masks: ``(batch, max_response)`` — right-aligned. | ||||
| logprobs: ``(batch, max_response)`` — right-aligned, or ``None``. | ||||
| """ | ||||
| _verify_inputs(prompts, responses, rewards, loss_masks) | ||||
|
|
||||
| max_input_len, max_output_len = 0, 0 | ||||
| prompt_token_lens, response_token_lens = [], [] | ||||
| inputs_token_ids, outputs_token_ids = [], [] | ||||
| for prompt, response in zip(prompts, responses): | ||||
| prompt_token_lens = [len(p) for p in prompts] | ||||
| response_token_lens = [len(r) for r in responses] | ||||
|
|
||||
| inputs_token_ids.append(prompt) | ||||
| outputs_token_ids.append(response) | ||||
| max_response = max(response_token_lens) | ||||
| # Pad to the tightest bound: max per-sample total. | ||||
| max_total = max(p + r for p, r in zip(prompt_token_lens, response_token_lens)) | ||||
|
|
||||
| prompt_token_len = len(prompt) | ||||
| response_token_len = len(response) | ||||
| prompt_token_lens.append(prompt_token_len) | ||||
| response_token_lens.append(response_token_len) | ||||
|
|
||||
| max_input_len = max(max_input_len, prompt_token_len) | ||||
| max_output_len = max(max_output_len, response_token_len) | ||||
| if max_seq_len is not None and max_total > max_seq_len: | ||||
| logger.warning( | ||||
| f"Max sequence length in batch ({max_total}) exceeds max_seq_len ({max_seq_len}). " | ||||
| f"No truncation is performed; consider checking generator settings." | ||||
| ) | ||||
|
|
||||
| pad_token_id = tokenizer.pad_token_id | ||||
| sequences = [] | ||||
| attention_masks = [] | ||||
| action_masks = [] | ||||
| for i, prompt in enumerate(prompts): | ||||
| # left padding input | ||||
| input_len = prompt_token_lens[i] | ||||
| input_ids = [pad_token_id] * (max_input_len - input_len) + list(inputs_token_ids[i]) | ||||
| input_attention_mask = [0] * (max_input_len - input_len) + [1] * input_len | ||||
|
|
||||
| # right padding output | ||||
| output_len = response_token_lens[i] | ||||
| output_ids = list(outputs_token_ids[i]) + [pad_token_id] * (max_output_len - output_len) | ||||
| output_attention_mask = [1] * output_len + [0] * (max_output_len - output_len) | ||||
|
|
||||
| # concat input and output | ||||
| sequences.append(input_ids + output_ids) | ||||
| attention_masks.append(input_attention_mask + output_attention_mask) | ||||
| action_masks.append(output_attention_mask) | ||||
| for i in range(len(prompts)): | ||||
| total_real = prompt_token_lens[i] + response_token_lens[i] | ||||
| pad_len = max_total - total_real | ||||
|
|
||||
| # Unified left-pad: [PAD ... PAD PROMPT RESPONSE] | ||||
| seq = [pad_token_id] * pad_len + prompts[i] + responses[i] | ||||
| attention_mask_i = [0] * pad_len + [1] * total_real | ||||
|
|
||||
| # Response indicator within the last max_response positions (right-aligned). | ||||
| resp_pad = max_response - response_token_lens[i] | ||||
| action_mask_i = [0] * resp_pad + [1] * response_token_lens[i] | ||||
|
|
||||
| sequences.append(seq) | ||||
| attention_masks.append(attention_mask_i) | ||||
| action_masks.append(action_mask_i) | ||||
|
|
||||
| sequences = torch.tensor(sequences) | ||||
| attention_mask = torch.tensor(attention_masks, dtype=torch.int64) | ||||
| action_mask = torch.tensor(action_masks, dtype=torch.int64) | ||||
|
|
||||
| # initialize ret loss masks to be the same as action mask | ||||
| ret_loss_masks = torch.zeros_like(action_mask, dtype=torch.float) | ||||
| for i, loss_mask in enumerate(loss_masks): | ||||
| ret_loss_masks[i, : len(loss_mask)] = torch.tensor(loss_mask) | ||||
| # Response-level tensors are RIGHT-ALIGNED to match the model output. | ||||
| # The model's log_probs[:, -num_actions-1:-1] returns logprobs where | ||||
| # response tokens occupy the last response_len_i positions. | ||||
| ret_loss_masks = torch.zeros(len(prompts), max_response, dtype=torch.float) | ||||
| for i, lm in enumerate(loss_masks): | ||||
| ret_loss_masks[i, max_response - len(lm) :] = torch.tensor(lm, dtype=torch.float) | ||||
|
|
||||
| # do the same for custom rewards | ||||
| ret_rewards = torch.zeros_like(action_mask, dtype=torch.float) | ||||
| # Same thing for rewards. | ||||
| ret_rewards = torch.zeros(len(prompts), max_response, dtype=torch.float) | ||||
| for i, custom_reward in enumerate(rewards): | ||||
| if isinstance(custom_reward, list): | ||||
| custom_reward = torch.tensor(custom_reward) | ||||
| ret_rewards[i, : len(custom_reward)] = custom_reward | ||||
| ret_rewards[i, max_response - len(custom_reward) :] = custom_reward | ||||
|
|
||||
| # Same thing for logprobs. | ||||
| logprobs_tensor = None | ||||
| if logprobs: | ||||
| max_output_len = action_mask.size(1) | ||||
| padded_logprobs = [ | ||||
| sample_logprobs + [0.0] * (max_output_len - len(sample_logprobs)) for sample_logprobs in logprobs | ||||
| ] | ||||
| logprobs_tensor = torch.tensor(padded_logprobs, dtype=torch.float) | ||||
| logprobs_tensor = torch.zeros(len(prompts), max_response, dtype=torch.float) | ||||
| for i, sample_logprobs in enumerate(logprobs): | ||||
| lp = torch.tensor(sample_logprobs, dtype=torch.float) | ||||
| logprobs_tensor[i, max_response - len(sample_logprobs) :] = lp | ||||
|
|
||||
| rollout_expert_indices_tensor = None | ||||
| if rollout_expert_indices: | ||||
| first_non_empty = next((x for x in rollout_expert_indices if x), None) | ||||
| if first_non_empty: | ||||
| total_seq_len = max_input_len + max_output_len | ||||
| num_layers = len(first_non_empty[0]) | ||||
| topk = len(first_non_empty[0][0]) if num_layers > 0 else 0 | ||||
| padded = torch.zeros(len(rollout_expert_indices), total_seq_len, num_layers, topk, dtype=torch.int32) | ||||
| padded = torch.zeros(len(rollout_expert_indices), max_total, num_layers, topk, dtype=torch.int32) | ||||
| for i, sample_indices in enumerate(rollout_expert_indices): | ||||
| if sample_indices: | ||||
| left_pad = max_input_len - prompt_token_lens[i] | ||||
| n = min(len(sample_indices), total_seq_len - left_pad) | ||||
| left_pad = max_total - (prompt_token_lens[i] + response_token_lens[i]) | ||||
| n = min(len(sample_indices), max_total - left_pad) | ||||
| padded[i, left_pad : left_pad + n] = torch.tensor(sample_indices[:n], dtype=torch.int32) | ||||
| rollout_expert_indices_tensor = padded | ||||
|
|
||||
|
|
||||
Uh oh!
There was an error while loading. Please reload this page.