Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 39 additions & 22 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import abc
import ctypes
import inspect
import itertools
import json
import warnings
from collections import OrderedDict
Expand Down Expand Up @@ -910,8 +911,9 @@ class Sequence(abc.ABC):

# Get total row number.
>>> len(seq)
# Random access by row index. Used for data sampling.
# Random access by row index, or by list of indexes. Used for data sampling.
>>> seq[10]
>>> seq[[0, 2, 5, 7, 12]]
# Range data access. Used to read data in batch when constructing Dataset.
>>> seq[0:100]
# Optionally specify batch_size to control range data read size.
Expand Down Expand Up @@ -2204,19 +2206,37 @@ def _lazy_init(
return self.set_feature_name(feature_name)

@staticmethod
def _yield_row_from_seqlist(seqs: List[Sequence], indices: Iterable[int]) -> Iterator[np.ndarray]:
    """Yield one row per entry of ``indices`` from a list of sequences.

    ``indices`` are row positions counted across all of ``seqs`` laid end to
    end. They must be non-decreasing with respect to the start of the
    sequence currently being read, because the walk over ``seqs`` only ever
    moves forward.

    Each yielded row is copied first when it does not own its data
    (``flags["OWNDATA"]`` is false), so the caller always receives an array
    that owns its buffer.
    """
    base = 0  # global row id of the first row of the active sequence
    cursor = 0  # position of the active sequence inside ``seqs``
    active = seqs[cursor]
    for row_id in indices:
        assert row_id >= base, "sample indices are expected to be monotonic"
        # Advance through the sequences until ``row_id`` falls inside the active one.
        while True:
            local_id = row_id - base
            if local_id < len(active):
                break
            base += len(active)
            cursor += 1
            active = seqs[cursor]
        picked = active[local_id]
        if picked.flags["OWNDATA"]:
            yield picked
        else:
            yield picked.copy()
def _yield_all_rows_from_seqlist(seqs: List[Sequence]) -> Iterator[np.ndarray]:
    """Yield the rows of every sequence in ``seqs``, in order, one batch at a time.

    Each sequence is read through slice access (``seq[start:end]``). The
    slice width is the sequence's own ``batch_size`` attribute when that is
    present and truthy, and ``Sequence.batch_size`` otherwise. The last
    batch of a sequence is clipped to the sequence length.
    """
    for current in seqs:
        n_rows = len(current)
        # A missing or falsy per-sequence batch size falls back to the class default.
        step = getattr(current, "batch_size", None)
        if not step:
            step = Sequence.batch_size
        for lo in range(0, n_rows, step):
            hi = min(lo + step, n_rows)
            yield current[lo:hi]

@staticmethod
def _yield_rows_from_seqlist(seqs: List[Sequence], indices: Iterable[int]) -> Iterator[np.ndarray]:
    """Yield the rows selected by ``indices`` from a list of sequences, in batches.

    Parameters
    ----------
    seqs : list of Sequence
        Sequences whose rows are addressed as if they were laid end to end.
    indices : iterable of int
        Row positions in ``[0, total number of rows)``, counted across all of
        ``seqs``. Rows are yielded in the order given by ``indices``.

    Yields
    ------
    rows : numpy.ndarray
        One batch of rows. A batch never spans two sequences and holds at
        most ``batch_size`` rows (the sequence's own ``batch_size`` when
        truthy, ``Sequence.batch_size`` otherwise). A batch that does not own
        its data is copied before being yielded.

    Raises
    ------
    IndexError
        If any index is negative or not smaller than the total number of rows.
    """
    indices = np.asarray(indices)

    # seq_starts[i] is the global row id of the first row of seqs[i];
    # the final element is the total number of rows.
    seq_starts = np.array([0, *(len(seq) for seq in seqs)], dtype=np.int64).cumsum()
    total_nrow = int(seq_starts[-1])

    # Fast path: every row is requested, exactly once and in order, so plain
    # slice reads can be used. Matching the count alone is not sufficient:
    # repeated or shuffled indices of the same length must go through the
    # general path below, otherwise the wrong rows would be returned.
    if len(indices) == total_nrow and np.array_equal(indices, np.arange(total_nrow)):
        for rows in Dataset._yield_all_rows_from_seqlist(seqs):
            yield rows if rows.flags["OWNDATA"] else rows.copy()
        return

    if len(indices) == 0:
        return

    # Reject bad indices up front instead of failing later with an opaque
    # "list index out of range" or silently wrapping around on negatives.
    if indices.min() < 0 or indices.max() >= total_nrow:
        raise IndexError(f"Row indices must be in the range [0, {total_nrow}).")

    # Identify the Sequence object corresponding to each element in `indices`.
    # side="right" maps an index equal to a sequence's end to the next sequence.
    seq_ids = np.searchsorted(seq_starts[1:], indices, side="right")

    # Sample from each identified sequence, in batch-wise fashion. groupby only
    # merges consecutive equal ids, which keeps the output in `indices` order.
    for id_of_seq, id_indices in itertools.groupby(zip(seq_ids, indices), key=lambda t: t[0]):
        seq = seqs[id_of_seq]
        batch_size = getattr(seq, "batch_size", None) or Sequence.batch_size
        seq_start = seq_starts[id_of_seq]
        # Convert global row ids to plain Python ints local to this sequence.
        local_indices = [int(index - seq_start) for _, index in id_indices]
        for begin in range(0, len(local_indices), batch_size):
            rows = seq[local_indices[begin : begin + batch_size]]
            yield rows if rows.flags["OWNDATA"] else rows.copy()

def __sample(self, seqs: List[Sequence], total_nrow: int) -> Tuple[List[np.ndarray], List[np.ndarray]]:
"""Sample data from seqs.
Expand All @@ -2230,7 +2250,7 @@ def __sample(self, seqs: List[Sequence], total_nrow: int) -> Tuple[List[np.ndarr
indices = self._create_sample_indices(total_nrow)

# Select sampled rows, transpose to column order.
sampled = np.array(list(self._yield_row_from_seqlist(seqs, indices)))
sampled = np.vstack(list(self._yield_rows_from_seqlist(seqs, indices)))
sampled = sampled.T

filtered = []
Expand Down Expand Up @@ -2271,12 +2291,9 @@ def __init_from_seqs(
sample_data, col_indices = self.__sample(seqs, total_nrow)
self._init_from_sample(sample_data, col_indices, sample_cnt, total_nrow)

for seq in seqs:
nrow = len(seq)
batch_size = getattr(seq, "batch_size", None) or Sequence.batch_size
for start in range(0, nrow, batch_size):
end = min(start + batch_size, nrow)
self._push_rows(seq[start:end])
for rows in self._yield_all_rows_from_seqlist(seqs):
self._push_rows(rows)

return self

def __init_from_np2d(
Expand Down Expand Up @@ -3268,7 +3285,7 @@ def get_data(self) -> Optional[_LGBM_TrainDataType]:
elif isinstance(self.data, Sequence):
self.data = self.data[self.used_indices]
elif _is_list_of_sequences(self.data) and len(self.data) > 0:
self.data = np.array(list(self._yield_row_from_seqlist(self.data, self.used_indices)))
self.data = np.vstack(list(self._yield_rows_from_seqlist(self.data, self.used_indices)))
else:
_log_warning(
f"Cannot subset {type(self.data).__name__} type of raw data.\nReturning original raw data"
Expand Down
Loading