@@ -23,7 +23,7 @@
 )
 from autoPyTorch.utils.common import FitRequirement, hash_array_or_matrix
 
-BaseDatasetType = Union[Tuple[np.ndarray, np.ndarray], Dataset]
+BaseDatasetInputType = Union[Tuple[np.ndarray, np.ndarray], Dataset]
 
 
 def check_valid_data(data: Any) -> None:
@@ -32,10 +32,9 @@ def check_valid_data(data: Any) -> None:
             'The specified Data for Dataset must have both __getitem__ and __len__ attribute.')
 
 
-def type_check(train_tensors: BaseDatasetType, val_tensors: Optional[BaseDatasetType] = None) -> None:
-    """To avoid unexpected behavior, we use loops over indices."""
-    for i in range(len(train_tensors)):
-        check_valid_data(train_tensors[i])
+def type_check(train_tensors: BaseDatasetInputType, val_tensors: Optional[BaseDatasetInputType] = None) -> None:
+    for train_tensor in train_tensors:
+        check_valid_data(train_tensor)
     if val_tensors is not None:
         for i in range(len(val_tensors)):
             check_valid_data(val_tensors[i])
@@ -63,10 +62,10 @@ def __getitem__(self, idx: int) -> np.ndarray:
 class BaseDataset(Dataset, metaclass=ABCMeta):
     def __init__(
         self,
-        train_tensors: BaseDatasetType,
+        train_tensors: BaseDatasetInputType,
         dataset_name: Optional[str] = None,
-        val_tensors: Optional[BaseDatasetType] = None,
-        test_tensors: Optional[BaseDatasetType] = None,
+        val_tensors: Optional[BaseDatasetInputType] = None,
+        test_tensors: Optional[BaseDatasetInputType] = None,
         resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
         resampling_strategy_args: Optional[Dict[str, Any]] = None,
         shuffle: Optional[bool] = True,
@@ -313,7 +312,7 @@ def get_dataset_for_training(self, split_id: int) -> Tuple[Dataset, Dataset]:
         return (TransformSubset(self, self.splits[split_id][0], train=True),
                 TransformSubset(self, self.splits[split_id][1], train=False))
 
-    def replace_data(self, X_train: BaseDatasetType, X_test: Optional[BaseDatasetType]) -> 'BaseDataset':
+    def replace_data(self, X_train: BaseDatasetInputType, X_test: Optional[BaseDatasetInputType]) -> 'BaseDataset':
         """
         To speed up the training of small dataset, early pre-processing of the data
         can be made on the fly by the pipeline.
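
For context, here is a minimal runnable sketch of the two validation helpers as they read after this diff. This is not the library source: the `if not all(...)` body of `check_valid_data` is reconstructed around the error message visible in the hunk above, and only the signatures, the error string, and the loop shapes are taken from the diff itself.

```python
from typing import Any, Optional, Tuple, Union

import numpy as np
from torch.utils.data import Dataset

BaseDatasetInputType = Union[Tuple[np.ndarray, np.ndarray], Dataset]


def check_valid_data(data: Any) -> None:
    # Every tensor handed to the dataset must be indexable and sized
    # (reconstructed condition; the error string is verbatim from the diff).
    if not all(hasattr(data, attr) for attr in ('__getitem__', '__len__')):
        raise ValueError(
            'The specified Data for Dataset must have both __getitem__ and __len__ attribute.')


def type_check(train_tensors: BaseDatasetInputType,
               val_tensors: Optional[BaseDatasetInputType] = None) -> None:
    # After this diff: iterate over the elements directly instead of over indices.
    for train_tensor in train_tensors:
        check_valid_data(train_tensor)
    if val_tensors is not None:
        for i in range(len(val_tensors)):
            check_valid_data(val_tensors[i])


X = np.random.rand(100, 4)
y = np.random.randint(0, 2, size=100)
type_check((X, y))          # ndarrays expose __getitem__ and __len__, so this passes
type_check((X, y), (X, y))  # validation tensors go through the same check
```

For a `(X, y)` tuple, the new direct iteration visits exactly the same two elements as the removed index loop, so for tuple inputs the change appears to be a pure rename plus a more idiomatic loop.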