Commit 6635b3b: "reflected the comments from ravin"
1 parent 2cbc1ce

2 files changed (+29, -20 lines)

autoPyTorch/datasets/base_dataset.py

Lines changed: 8 additions & 9 deletions
@@ -23,7 +23,7 @@
 )
 from autoPyTorch.utils.common import FitRequirement, hash_array_or_matrix

-BaseDatasetType = Union[Tuple[np.ndarray, np.ndarray], Dataset]
+BaseDatasetInputType = Union[Tuple[np.ndarray, np.ndarray], Dataset]


 def check_valid_data(data: Any) -> None:
@@ -32,10 +32,9 @@ def check_valid_data(data: Any) -> None:
             'The specified Data for Dataset must have both __getitem__ and __len__ attribute.')


-def type_check(train_tensors: BaseDatasetType, val_tensors: Optional[BaseDatasetType] = None) -> None:
-    """To avoid unexpected behavior, we use loops over indices."""
-    for i in range(len(train_tensors)):
-        check_valid_data(train_tensors[i])
+def type_check(train_tensors: BaseDatasetInputType, val_tensors: Optional[BaseDatasetInputType] = None) -> None:
+    for train_tensor in train_tensors:
+        check_valid_data(train_tensor)
     if val_tensors is not None:
         for i in range(len(val_tensors)):
             check_valid_data(val_tensors[i])
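The hunk above swaps the index-based loop over train_tensors for direct iteration over its elements. A minimal, self-contained sketch of how the function reads after this commit; the body of check_valid_data beyond its error string is an assumption (the diff only shows the raised message), and the val_tensors loop is untouched context here:

from typing import Any, Optional, Tuple, Union

import numpy as np
from torch.utils.data import Dataset

BaseDatasetInputType = Union[Tuple[np.ndarray, np.ndarray], Dataset]


def check_valid_data(data: Any) -> None:
    # Assumed body: the diff only shows the error string below.
    if not all(hasattr(data, attr) for attr in ('__getitem__', '__len__')):
        raise ValueError(
            'The specified Data for Dataset must have both __getitem__ and __len__ attribute.')


def type_check(train_tensors: BaseDatasetInputType,
               val_tensors: Optional[BaseDatasetInputType] = None) -> None:
    # After this commit: iterate over the elements themselves.
    for train_tensor in train_tensors:
        check_valid_data(train_tensor)
    # Unchanged context in this hunk: still index-based for val_tensors.
    if val_tensors is not None:
        for i in range(len(val_tensors)):
            check_valid_data(val_tensors[i])


# A (features, targets) pair of ndarrays passes both checks.
X, y = np.zeros((10, 3)), np.zeros(10)
type_check((X, y))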
@@ -63,10 +62,10 @@ def __getitem__(self, idx: int) -> np.ndarray:
 class BaseDataset(Dataset, metaclass=ABCMeta):
     def __init__(
         self,
-        train_tensors: BaseDatasetType,
+        train_tensors: BaseDatasetInputType,
         dataset_name: Optional[str] = None,
-        val_tensors: Optional[BaseDatasetType] = None,
-        test_tensors: Optional[BaseDatasetType] = None,
+        val_tensors: Optional[BaseDatasetInputType] = None,
+        test_tensors: Optional[BaseDatasetInputType] = None,
         resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
         resampling_strategy_args: Optional[Dict[str, Any]] = None,
         shuffle: Optional[bool] = True,
@@ -313,7 +312,7 @@ def get_dataset_for_training(self, split_id: int) -> Tuple[Dataset, Dataset]:
         return (TransformSubset(self, self.splits[split_id][0], train=True),
                 TransformSubset(self, self.splits[split_id][1], train=False))

-    def replace_data(self, X_train: BaseDatasetType, X_test: Optional[BaseDatasetType]) -> 'BaseDataset':
+    def replace_data(self, X_train: BaseDatasetInputType, X_test: Optional[BaseDatasetInputType]) -> 'BaseDataset':
         """
         To speed up the training of small dataset, early pre-processing of the data
         can be made on the fly by the pipeline.
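For context on the rename: the alias accepts either an (X, y) pair of arrays or a torch Dataset, and both forms type-check against BaseDatasetInputType wherever the signatures above use it. An illustrative snippet (the variable names are ours, not part of the commit):

from typing import Tuple, Union

import numpy as np
import torch
from torch.utils.data import Dataset, TensorDataset

BaseDatasetInputType = Union[Tuple[np.ndarray, np.ndarray], Dataset]

# Either form can be passed as train/val/test tensors or to replace_data.
as_pair: BaseDatasetInputType = (np.zeros((5, 2)), np.zeros(5))
as_dataset: BaseDatasetInputType = TensorDataset(torch.zeros(5, 2), torch.zeros(5))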

autoPyTorch/datasets/resampling_strategy.py

Lines changed: 21 additions & 11 deletions
@@ -150,9 +150,12 @@ def k_fold_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any)
     """
     Standard k fold cross validation.

-    :param indices: array of indices to be split
-    :param num_splits: number of cross validation splits
-    :return: list of tuples of training and validation indices
+    Args:
+        indices (np.ndarray): array of indices to be split
+        num_splits (int): number of cross validation splits
+
+    Returns:
+        splits (List[Tuple[List, List]]): list of tuples of training and validation indices
     """
     cv = KFold(n_splits=num_splits)
     splits = list(cv.split(indices))
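The reworked docstring can be sanity-checked directly against sklearn, which the function wraps. A short driver, illustrative only and not part of the commit:

import numpy as np
from sklearn.model_selection import KFold

indices = np.array([0, 1, 2, 3])
cv = KFold(n_splits=2)
# KFold.split yields (train, validation) index arrays, matching the
# Returns description above.
print(list(cv.split(indices)))
# [(array([2, 3]), array([0, 1])), (array([0, 1]), array([2, 3]))]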
@@ -163,14 +166,21 @@ def time_series_cross_validation(num_splits: int, indices: np.ndarray, **kwargs:
         -> List[Tuple[np.ndarray, np.ndarray]]:
     """
     Returns train and validation indices respecting the temporal ordering of the data.
-    Dummy example: [0, 1, 2, 3] with 3 folds yields
-        [0] [1]
-        [0, 1] [2]
-        [0, 1, 2] [3]
-
-    :param indices: array of indices to be split
-    :param num_splits: number of cross validation splits
-    :return: list of tuples of training and validation indices
+
+    Args:
+        indices (np.ndarray): array of indices to be split
+        num_splits (int): number of cross validation splits
+
+    Returns:
+        splits (List[Tuple[List, List]]): list of tuples of training and validation indices
+
+    Examples:
+        >>> indices = np.array([0, 1, 2, 3])
+        >>> CrossValFuncs.time_series_cross_validation(3, indices)
+        [([0], [1]),
+         ([0, 1], [2]),
+         ([0, 1, 2], [3])]
+
     """
     cv = TimeSeriesSplit(n_splits=num_splits)
     splits = list(cv.split(indices))
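The new Examples block can likewise be reproduced with sklearn's TimeSeriesSplit, which this function wraps; a quick check (the driver code is ours):

import numpy as np
from sklearn.model_selection import TimeSeriesSplit

indices = np.array([0, 1, 2, 3])
cv = TimeSeriesSplit(n_splits=3)
# Expanding-window folds, exactly as listed in the docstring example.
for train_idx, val_idx in cv.split(indices):
    print(train_idx.tolist(), val_idx.tolist())
# [0] [1]
# [0, 1] [2]
# [0, 1, 2] [3]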
