import logging

from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizer

logger = logging.getLogger(__name__)


def debug_dataloader_samples(
    dataloader: DataLoader,
    tokenizer: PreTrainedTokenizer,
    num_samples: int = 2,
) -> None:
    """
    Debug utility to inspect samples from a DataLoader.

    This function pulls the first batch from the given DataLoader,
    decodes the 'input_ids' with the provided tokenizer, and logs the
    decoded text for a few samples.

    Args:
        dataloader (torch.utils.data.DataLoader): The DataLoader to inspect.
        tokenizer (PreTrainedTokenizer): Tokenizer used to decode input_ids.
        num_samples (int): Number of samples to log from the first batch.
    """
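    # Pull a single batch up front; dataset or collate errors surface here.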
    try:
        batch = next(iter(dataloader))
    except Exception as e:
        logger.error("[debug] Failed to retrieve batch from dataloader: %s", e)
        return

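    # Assumes a dict-style batch (e.g., produced by an HF data collator).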
    input_ids = batch.get("input_ids")
    if input_ids is None:
        logger.warning("[debug] 'input_ids' not found in batch. Available keys: %s", list(batch.keys()))
        return

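    # Move tensors off any accelerator before decoding on the host.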
    if hasattr(input_ids, "cpu"):
        input_ids = input_ids.cpu()

    logger.info("\n[Debug] Printing detokenized samples from the first batch:\n")
    for i in range(min(num_samples, len(input_ids))):
        try:
            decoded = tokenizer.decode(input_ids[i], skip_special_tokens=True)
            logger.info("[Sample %d]:\n%s\n%s", i + 1, decoded, "=" * 40)
        except Exception as e:
            logger.error("[debug] Failed to decode sample %d: %s", i + 1, e)
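
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original utility).
# The model name "bert-base-uncased" and the toy in-memory dataset below are
# assumptions chosen for this example; any DataLoader that yields dict
# batches with an 'input_ids' key works.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    logging.basicConfig(level=logging.INFO)

    tok = AutoTokenizer.from_pretrained("bert-base-uncased")
    texts = ["Hello world!", "Debugging dataloaders saves time."]
    encodings = tok(texts, padding=True, return_tensors="pt")

    # A list of per-sample dicts lets the default collate_fn build the
    # dict-of-tensors batch that debug_dataloader_samples expects.
    dataset = [{k: v[i] for k, v in encodings.items()} for i in range(len(texts))]
    loader = DataLoader(dataset, batch_size=2)

    debug_dataloader_samples(loader, tok, num_samples=2)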