diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index ced6c58c55f2..384323aecf95 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -11,6 +11,7 @@
 import abc
 import ctypes
 import inspect
+import itertools
 import json
 import warnings
 from collections import OrderedDict
@@ -910,8 +911,9 @@ class Sequence(abc.ABC):
 
         # Get total row number.
         >>> len(seq)
-        # Random access by row index. Used for data sampling.
+        # Random access by row index, or by list of indexes. Used for data sampling.
         >>> seq[10]
+        >>> seq[[0, 2, 5, 7, 12]]
         # Range data access. Used to read data in batch when constructing Dataset.
         >>> seq[0:100]
         # Optionally specify batch_size to control range data read size.
@@ -2204,19 +2206,37 @@ def _lazy_init(
         return self.set_feature_name(feature_name)
 
     @staticmethod
-    def _yield_row_from_seqlist(seqs: List[Sequence], indices: Iterable[int]) -> Iterator[np.ndarray]:
-        offset = 0
-        seq_id = 0
-        seq = seqs[seq_id]
-        for row_id in indices:
-            assert row_id >= offset, "sample indices are expected to be monotonic"
-            while row_id >= offset + len(seq):
-                offset += len(seq)
-                seq_id += 1
-                seq = seqs[seq_id]
-            id_in_seq = row_id - offset
-            row = seq[id_in_seq]
-            yield row if row.flags["OWNDATA"] else row.copy()
+    def _yield_all_rows_from_seqlist(seqs: List[Sequence]) -> Iterator[np.ndarray]:
+        for seq in seqs:
+            nrow = len(seq)
+            batch_size = getattr(seq, "batch_size", None) or Sequence.batch_size
+            for start in range(0, nrow, batch_size):
+                end = min(start + batch_size, nrow)
+                yield seq[start:end]
+
+    @staticmethod
+    def _yield_rows_from_seqlist(seqs: List[Sequence], indices: Iterable[int]) -> Iterator[np.ndarray]:
+        indices = np.asarray(indices)
+        if len(indices) == sum(len(seq) for seq in seqs):
+            # Fast path to sample all rows.
+            for rows in Dataset._yield_all_rows_from_seqlist(seqs):
+                yield rows if rows.flags["OWNDATA"] else rows.copy()
+            return
+
+        # Identify the Sequence object corresponding to each element in `indices`.
+        seq_starts = np.array([0, *(len(seq) for seq in seqs)], dtype=np.int64).cumsum()
+        seq_ends = seq_starts[1:]
+        seq_ids = np.searchsorted(seq_ends, np.asarray(indices), side="right")
+
+        # Sample from each identified sequence, in batch-wise fashion.
+        for id_of_seq, id_indices in itertools.groupby(zip(seq_ids, indices), key=lambda t: t[0]):
+            seq = seqs[id_of_seq]
+            batch_size = getattr(seq, "batch_size", None) or Sequence.batch_size
+            id_indices = [int(index - seq_starts[id_of_seq]) for _, index in id_indices]
+            for begin in range(0, len(id_indices), batch_size):
+                batch_indices = id_indices[begin : min(begin + batch_size, len(id_indices))]
+                rows = seq[batch_indices]
+                yield rows if rows.flags["OWNDATA"] else rows.copy()
 
     def __sample(self, seqs: List[Sequence], total_nrow: int) -> Tuple[List[np.ndarray], List[np.ndarray]]:
         """Sample data from seqs.
@@ -2230,7 +2250,7 @@ def __sample(self, seqs: List[Sequence], total_nrow: int) -> Tuple[List[np.ndarr
         indices = self._create_sample_indices(total_nrow)
 
         # Select sampled rows, transpose to column order.
-        sampled = np.array(list(self._yield_row_from_seqlist(seqs, indices)))
+        sampled = np.vstack(list(self._yield_rows_from_seqlist(seqs, indices)))
         sampled = sampled.T
 
         filtered = []
@@ -2271,12 +2291,9 @@ def __init_from_seqs(
         sample_data, col_indices = self.__sample(seqs, total_nrow)
         self._init_from_sample(sample_data, col_indices, sample_cnt, total_nrow)
 
-        for seq in seqs:
-            nrow = len(seq)
-            batch_size = getattr(seq, "batch_size", None) or Sequence.batch_size
-            for start in range(0, nrow, batch_size):
-                end = min(start + batch_size, nrow)
-                self._push_rows(seq[start:end])
+        for rows in self._yield_all_rows_from_seqlist(seqs):
+            self._push_rows(rows)
+
         return self
 
     def __init_from_np2d(
@@ -3268,7 +3285,7 @@ def get_data(self) -> Optional[_LGBM_TrainDataType]:
         elif isinstance(self.data, Sequence):
             self.data = self.data[self.used_indices]
         elif _is_list_of_sequences(self.data) and len(self.data) > 0:
-            self.data = np.array(list(self._yield_row_from_seqlist(self.data, self.used_indices)))
+            self.data = np.vstack(list(self._yield_rows_from_seqlist(self.data, self.used_indices)))
         else:
             _log_warning(
                 f"Cannot subset {type(self.data).__name__} type of raw data.\nReturning original raw data"
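The updated docstring tightens the Sequence contract: implementations must now support random access by a list of indices, in addition to single-index and slice access, so that _yield_rows_from_seqlist can fetch sampled rows in batches. A minimal sketch of a conforming implementation backed by a NumPy array follows; the NumpySequence name and the toy shapes are illustrative assumptions, not part of this patch.

import numpy as np

import lightgbm as lgb


class NumpySequence(lgb.Sequence):
    """Toy Sequence over an in-memory 2-D array (illustrative only)."""

    def __init__(self, ndarray, batch_size=4096):
        self.ndarray = ndarray
        self.batch_size = batch_size  # controls range-read size when constructing Dataset

    def __getitem__(self, idx):
        # `idx` may be an int (seq[10]), a list of ints (seq[[0, 2, 5, 7, 12]]),
        # or a slice (seq[0:100]); NumPy indexing handles all three.
        return self.ndarray[idx]

    def __len__(self):
        return len(self.ndarray)


# Usage: a Dataset built from a list of sequences, as in __init_from_seqs.
rng = np.random.default_rng(0)
data = rng.random((1_000, 10))
dataset = lgb.Dataset([NumpySequence(data, batch_size=256)], label=rng.random(1_000))

Note that NumPy fancy (list) indexing returns an owning copy, so the rows.flags["OWNDATA"] check in _yield_rows_from_seqlist yields such batches without an additional copy, whereas slice access returns a view, which the fast path copies before yielding.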