diff --git a/src/art/preprocessing/tokenize.py b/src/art/preprocessing/tokenize.py
index 8fbcedca..a8f553c6 100644
--- a/src/art/preprocessing/tokenize.py
+++ b/src/art/preprocessing/tokenize.py
@@ -387,13 +387,9 @@ def tokenize_sft_batch(
         tokenizer=tokenizer,
         return_function=True,
     )
-    # Handle missing pad_token_id (common for LLaMA and similar models)
-    pad_token_id = tokenizer.pad_token_id
-    if pad_token_id is None:
-        pad_token_id = tokenizer.eos_token_id
-
-    # First pass: tokenize all trajectories
-    tokenized_trajectories = []
+    # Tokenize all trajectories (no padding — each keeps its natural length)
+    trajectory_tensors = []
+    num_trainable_tokens = 0
     for trajectory in trajectory_batch:
         messages = trajectory.messages_and_choices
         tools = trajectory.tools
@@ -409,49 +405,18 @@
             ),
         )
 
-        # Create attention mask (all 1s - no padding yet)
         attention_mask = [1] * len(input_ids)
 
         labels = train_on_responses_only_fn({"input_ids": [input_ids]})["labels"][0]
 
-        tokenized_trajectories.append(
+        trajectory_tensors.append(
             {
-                "input_ids": input_ids,
-                "attention_mask": attention_mask,
-                "labels": labels,
+                "input_ids": torch.tensor([input_ids], dtype=torch.long),
+                "attention_mask": torch.tensor([attention_mask], dtype=torch.long),
+                "labels": torch.tensor([labels], dtype=torch.long),
             }
         )
-
-    # Find max length in this batch for padding
-    max_seq_length = max(len(t["input_ids"]) for t in tokenized_trajectories)
-
-    # Second pass: pad all trajectories to max_seq_length
-    trajectory_tensors = []
-    for tokenized in tokenized_trajectories:
-        input_ids = tokenized["input_ids"]
-        attention_mask = tokenized["attention_mask"]
-        labels = tokenized["labels"]
-
-        # Pad to max_seq_length
-        padding_length = max_seq_length - len(input_ids)
-        if padding_length > 0:
-            input_ids = input_ids + [pad_token_id] * padding_length
-            attention_mask = attention_mask + [0] * padding_length
-            labels = labels + [-100] * padding_length
-
-        trajectory_tensor = {
-            "input_ids": torch.tensor([input_ids], dtype=torch.long),
-            "attention_mask": torch.tensor([attention_mask], dtype=torch.long),
-            "labels": torch.tensor([labels], dtype=torch.long),
-        }
-
-        trajectory_tensors.append(trajectory_tensor)
-
-    # Calculate total trainable tokens (labels != -100)
-    num_trainable_tokens = sum(
-        (tensor_dict["labels"] != -100).sum().item()
-        for tensor_dict in trajectory_tensors
-    )
+        num_trainable_tokens += sum(1 for l in labels if l != -100)
 
     return SFTBatch(
         trajectory_tensors=trajectory_tensors,
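
Note: since padding is gone, each entry in trajectory_tensors keeps its own sequence length, so a consumer has to process trajectories one at a time rather than as a stacked batch. A minimal sketch of what that could look like, assuming an SFTBatch exposing trajectory_tensors and num_trainable_tokens and a Hugging Face-style model forward that returns a mean loss over non-ignored labels (the sft_step function and its signature are hypothetical, not part of this change):

```python
import torch

def sft_step(model, batch, optimizer):
    # Hypothetical consumer of SFTBatch: one forward/backward per trajectory,
    # since the unpadded [1, seq_len] tensors have differing lengths.
    optimizer.zero_grad()
    for tensors in batch.trajectory_tensors:
        out = model(
            input_ids=tensors["input_ids"],
            attention_mask=tensors["attention_mask"],
            labels=tensors["labels"],  # -100 positions are ignored by the loss
        )
        # out.loss is the mean over this trajectory's trainable tokens; rescale
        # to a token-level sum and divide by the batch-wide total so the
        # accumulated gradient equals the mean over all trainable tokens.
        n_trainable = (tensors["labels"] != -100).sum().item()
        (out.loss * n_trainable / batch.num_trainable_tokens).backward()
    optimizer.step()
```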