Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 8 additions & 43 deletions src/art/preprocessing/tokenize.py
Original file line number | Diff line number | Diff line change
Expand Up
@@ -387,13 +387,9 @@ def tokenize_sft_batch(
tokenizer=tokenizer,
return_function=True,
)
# Handle missing pad_token_id (common for LLaMA and similar models)
pad_token_id = tokenizer.pad_token_id
if pad_token_id is None:
pad_token_id = tokenizer.eos_token_id

# First pass: tokenize all trajectories
tokenized_trajectories = []
# Tokenize all trajectories (no padding — each keeps its natural length)
trajectory_tensors = []
num_trainable_tokens = 0
for trajectory in trajectory_batch:
messages = trajectory.messages_and_choices
tools = trajectory.tools
Expand All
@@ -409,49 +405,18 @@ def tokenize_sft_batch(
),
)

# Create attention mask (all 1s - no padding yet)
attention_mask = [1] * len(input_ids)

labels = train_on_responses_only_fn({"input_ids": [input_ids]})["labels"][0]

tokenized_trajectories.append(
trajectory_tensors.append(
{
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
"input_ids": torch.tensor([input_ids], dtype=torch.long),
"attention_mask": torch.tensor([attention_mask], dtype=torch.long),
"labels": torch.tensor([labels], dtype=torch.long),
}
)

# Find max length in this batch for padding
max_seq_length = max(len(t["input_ids"]) for t in tokenized_trajectories)

# Second pass: pad all trajectories to max_seq_length
trajectory_tensors = []
for tokenized in tokenized_trajectories:
input_ids = tokenized["input_ids"]
attention_mask = tokenized["attention_mask"]
labels = tokenized["labels"]

# Pad to max_seq_length
padding_length = max_seq_length - len(input_ids)
if padding_length > 0:
input_ids = input_ids + [pad_token_id] * padding_length
attention_mask = attention_mask + [0] * padding_length
labels = labels + [-100] * padding_length

trajectory_tensor = {
"input_ids": torch.tensor([input_ids], dtype=torch.long),
"attention_mask": torch.tensor([attention_mask], dtype=torch.long),
"labels": torch.tensor([labels], dtype=torch.long),
}

trajectory_tensors.append(trajectory_tensor)

# Calculate total trainable tokens (labels != -100)
num_trainable_tokens = sum(
(tensor_dict["labels"] != -100).sum().item()
for tensor_dict in trajectory_tensors
)
num_trainable_tokens += sum(1 for l in labels if l != -100)

return SFTBatch(
trajectory_tensors=trajectory_tensors,
Expand Down