Skip to content

Commit e02e466

Browse files
committed
[feature] Add debug_dataloader_samples utility to print decoded dataloader samples
1 parent fb1d6e9 commit e02e466

File tree

1 file changed

+47
-0
lines changed

1 file changed

+47
-0
lines changed

src/nanotron/debug_utils.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import logging
2+
from typing import Optional
3+
4+
import torch
5+
from torch.utils.data import DataLoader
6+
from transformers import PreTrainedTokenizer
7+
8+
logger = logging.getLogger(__name__)
9+
10+
def debug_dataloader_samples(
    dataloader: DataLoader,
    tokenizer: PreTrainedTokenizer,
    num_samples: int = 2,
) -> None:
    """Log the decoded text of a few samples from a DataLoader's first batch.

    The function pulls one batch with ``next(iter(dataloader))``, looks up its
    ``"input_ids"`` entry, and logs ``tokenizer.decode(...)`` (with
    ``skip_special_tokens=True``) for up to ``num_samples`` rows.

    This is a best-effort debugging aid: every failure (no batch available,
    batch is not dict-like, ``"input_ids"`` missing or without a length, a
    sample that cannot be decoded) is reported through the module logger and
    never raised to the caller.

    Args:
        dataloader: The DataLoader to inspect. A fresh iterator is created, so
            only the first batch it yields is examined.
        tokenizer: Tokenizer used to decode each row of ``input_ids``.
        num_samples: Maximum number of rows of the first batch to log. Values
            less than or equal to zero log only the header line.
    """
    try:
        batch = next(iter(dataloader))
    except Exception as e:
        # Includes StopIteration raised by an empty dataloader.
        logger.error("[debug] Failed to retrieve batch from dataloader: %s", e)
        return

    # The lookup below needs a dict-like batch. Tuple/list/tensor batches would
    # otherwise raise AttributeError out of a helper meant to be failure-safe.
    if not (hasattr(batch, "get") and hasattr(batch, "keys")):
        logger.warning(
            "[debug] Batch of type %s is not dict-like; cannot look up 'input_ids'.",
            type(batch).__name__,
        )
        return

    input_ids = batch.get("input_ids")
    if input_ids is None:
        logger.warning("[debug] 'input_ids' not found in batch. Available keys: %s", list(batch.keys()))
        return

    # Move tensor-like values to host memory before decoding; plain Python
    # containers have no ``cpu`` attribute and are used as-is.
    if hasattr(input_ids, "cpu"):
        input_ids = input_ids.cpu()

    # Objects without a usable length (e.g. a 0-d tensor or a non-sequence
    # placeholder) cannot be indexed per sample, so report and stop.
    try:
        batch_size = len(input_ids)
    except TypeError as e:
        logger.warning(
            "[debug] 'input_ids' of type %s cannot be indexed per sample: %s",
            type(input_ids).__name__,
            e,
        )
        return

    logger.info("\n[Debug] Printing detokenized samples from the first batch:\n")
    for i in range(min(num_samples, batch_size)):
        # Decode each sample independently so one bad row does not hide the rest.
        try:
            decoded = tokenizer.decode(input_ids[i], skip_special_tokens=True)
            logger.info("[Sample %d]:\n%s\n%s", i + 1, decoded, "=" * 40)
        except Exception as e:
            logger.error("[debug] Failed to decode sample %d: %s", i + 1, e)

0 commit comments

Comments
 (0)