1
- from enum import IntEnum
2
- from typing import Any , Dict , List , Optional , Tuple , Union
1
+ from enum import Enum
2
+ from functools import partial
3
+ from typing import List , NamedTuple , Optional , Tuple , Union
3
4
4
5
import numpy as np
5
6
12
13
train_test_split
13
14
)
14
15
15
- from typing_extensions import Protocol
16
+ from torch . utils . data import Dataset
16
17
17
18
18
- # Use callback protocol as workaround, since callable with function fields count 'self' as argument
19
- class CrossValFunc (Protocol ):
20
- def __call__ (self ,
21
- random_state : np .random .RandomState ,
22
- num_splits : int ,
23
- indices : np .ndarray ,
24
- stratify : Optional [Any ]) -> List [Tuple [np .ndarray , np .ndarray ]]:
25
- ...
19
+ class _ResamplingStrategyArgs (NamedTuple ):
20
+ val_share : float = 0.33
21
+ num_splits : int = 5
22
+ shuffle : bool = False
23
+ stratify : bool = False
26
24
27
25
28
- class HoldOutFunc (Protocol ):
29
- def __call__ (self , random_state : np .random .RandomState , val_share : float ,
30
- indices : np .ndarray , stratify : Optional [Any ]
31
- ) -> Tuple [np .ndarray , np .ndarray ]:
32
- ...
33
-
34
-
35
- class CrossValTypes (IntEnum ):
36
- """The type of cross validation
37
-
38
- This class is used to specify the cross validation function
39
- and is not supposed to be instantiated.
40
-
41
- Examples: This class is supposed to be used as follows
42
- >>> cv_type = CrossValTypes.k_fold_cross_validation
43
- >>> print(cv_type.name)
44
-
45
- k_fold_cross_validation
46
-
47
- >>> for cross_val_type in CrossValTypes:
48
- print(cross_val_type.name, cross_val_type.value)
49
-
50
- stratified_k_fold_cross_validation 1
51
- k_fold_cross_validation 2
52
- stratified_shuffle_split_cross_validation 3
53
- shuffle_split_cross_validation 4
54
- time_series_cross_validation 5
55
- """
56
- stratified_k_fold_cross_validation = 1
57
- k_fold_cross_validation = 2
58
- stratified_shuffle_split_cross_validation = 3
59
- shuffle_split_cross_validation = 4
60
- time_series_cross_validation = 5
61
-
62
- def is_stratified (self ) -> bool :
63
- stratified = [self .stratified_k_fold_cross_validation ,
64
- self .stratified_shuffle_split_cross_validation ]
65
- return getattr (self , self .name ) in stratified
66
-
67
-
68
- class HoldoutValTypes (IntEnum ):
69
- """TODO: change to enum using functools.partial"""
70
- """The type of hold out validation (refer to CrossValTypes' doc-string)"""
71
- holdout_validation = 6
72
- stratified_holdout_validation = 7
73
-
74
- def is_stratified (self ) -> bool :
75
- stratified = [self .stratified_holdout_validation ]
76
- return getattr (self , self .name ) in stratified
77
-
78
-
79
- # TODO: replace it with another way
80
- RESAMPLING_STRATEGIES = [CrossValTypes , HoldoutValTypes ]
81
-
82
- DEFAULT_RESAMPLING_PARAMETERS = {
83
- HoldoutValTypes .holdout_validation : {
84
- 'val_share' : 0.33 ,
85
- },
86
- HoldoutValTypes .stratified_holdout_validation : {
87
- 'val_share' : 0.33 ,
88
- },
89
- CrossValTypes .k_fold_cross_validation : {
90
- 'num_splits' : 5 ,
91
- },
92
- CrossValTypes .stratified_k_fold_cross_validation : {
93
- 'num_splits' : 5 ,
94
- },
95
- CrossValTypes .shuffle_split_cross_validation : {
96
- 'num_splits' : 5 ,
97
- },
98
- CrossValTypes .time_series_cross_validation : {
99
- 'num_splits' : 5 ,
100
- },
101
- } # type: Dict[Union[HoldoutValTypes, CrossValTypes], Dict[str, Any]]
102
-
103
-
104
- class HoldOutFuncs ():
26
class HoldoutFuncs():
    """Namespace for the holdout (single train/validation) split functions."""

    @staticmethod
    def holdout_validation(
        random_state: np.random.RandomState,
        val_share: float,
        indices: np.ndarray,
        shuffle: bool = False,
        labels_to_stratify: Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]] = None
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Split the indices once into a training and a validation part.

        Args:
            random_state (np.random.RandomState): random number generator used for
                shuffling; ignored when `shuffle` is False
            val_share (float): forwarded to `train_test_split` as `test_size`
            indices (np.ndarray): array of indices to be split
            shuffle (bool): whether to shuffle the indices before splitting
            labels_to_stratify: forwarded to `train_test_split` as `stratify`;
                None disables stratification.
                NOTE: scikit-learn documents that `stratify` must be None when
                `shuffle` is False, so stratification requires `shuffle=True`.

        Returns:
            (train, val) (Tuple[np.ndarray, np.ndarray]): training and validation indices
        """
        train, val = train_test_split(
            indices,
            test_size=val_share,
            shuffle=shuffle,
            # A seed is only meaningful when the data are actually shuffled.
            random_state=random_state if shuffle else None,
            stratify=labels_to_stratify,
        )
        return train, val
117
42
118
- @staticmethod
119
- def stratified_holdout_validation (random_state : np .random .RandomState ,
120
- val_share : float ,
121
- indices : np .ndarray ,
122
- ** kwargs : Any
123
- ) -> Tuple [np .ndarray , np .ndarray ]:
124
- train , val = train_test_split (indices , test_size = val_share , shuffle = True , stratify = kwargs ["stratify" ],
125
- random_state = random_state )
126
- return train , val
127
-
128
- @classmethod
129
- def get_holdout_validators (cls , * holdout_val_types : HoldoutValTypes ) -> Dict [str , HoldOutFunc ]:
130
-
131
- holdout_validators = {
132
- holdout_val_type .name : getattr (cls , holdout_val_type .name )
133
- for holdout_val_type in holdout_val_types
134
- }
135
- return holdout_validators
136
-
137
43
138
44
class CrossValFuncs ():
139
- @staticmethod
140
- def shuffle_split_cross_validation (random_state : np .random .RandomState ,
141
- num_splits : int ,
142
- indices : np .ndarray ,
143
- ** kwargs : Any
144
- ) -> List [Tuple [np .ndarray , np .ndarray ]]:
145
- cv = ShuffleSplit (n_splits = num_splits , random_state = random_state )
146
- splits = list (cv .split (indices ))
147
- return splits
148
-
149
- @staticmethod
150
- def stratified_shuffle_split_cross_validation (random_state : np .random .RandomState ,
151
- num_splits : int ,
152
- indices : np .ndarray ,
153
- ** kwargs : Any
154
- ) -> List [Tuple [np .ndarray , np .ndarray ]]:
155
- cv = StratifiedShuffleSplit (n_splits = num_splits , random_state = random_state )
156
- splits = list (cv .split (indices , kwargs ["stratify" ]))
157
- return splits
158
-
159
- @staticmethod
160
- def stratified_k_fold_cross_validation (random_state : np .random .RandomState ,
161
- num_splits : int ,
162
- indices : np .ndarray ,
163
- ** kwargs : Any
164
- ) -> List [Tuple [np .ndarray , np .ndarray ]]:
165
- cv = StratifiedKFold (n_splits = num_splits , random_state = random_state )
166
- splits = list (cv .split (indices , kwargs ["stratify" ]))
167
- return splits
45
+ # (shuffle, is_stratify) -> split_fn
46
+ _args2split_fn = {
47
+ (True , True ): StratifiedShuffleSplit ,
48
+ (True , False ): ShuffleSplit ,
49
+ (False , True ): StratifiedKFold ,
50
+ (False , False ): KFold ,
51
+ }
168
52
169
53
@staticmethod
def k_fold_cross_validation(
    random_state: np.random.RandomState,
    num_splits: int,
    indices: np.ndarray,
    shuffle: bool = False,
    labels_to_stratify: Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]] = None
) -> List[Tuple[np.ndarray, np.ndarray]]:
    """
    Cross validation whose splitter is chosen from `shuffle` and the presence
    of `labels_to_stratify` (see `CrossValFuncs._args2split_fn`).

    Args:
        random_state (np.random.RandomState): random number generator; only
            forwarded to the splitter when `shuffle` is True
        num_splits (int): number of cross validation splits
        indices (np.ndarray): array of indices to be split
        shuffle (bool): selects the shuffling splitters when True
        labels_to_stratify: labels used for stratification; None selects the
            non-stratified splitters

    Returns:
        splits (List[Tuple[List, List]]): list of tuples of training and validation indices
    """
    split_fn = CrossValFuncs._args2split_fn[(shuffle, labels_to_stratify is not None)]
    # The K-fold splitters reject a seed when they do not shuffle, so the
    # random state is only handed over when shuffling is requested.
    cv = split_fn(n_splits=num_splits, random_state=random_state if shuffle else None)
    # The stratified splitters need the labels; the plain ones accept y=None.
    splits = list(cv.split(indices, labels_to_stratify))
    return splits
189
70
190
71
@staticmethod
191
- def time_series_cross_validation (random_state : np .random .RandomState ,
192
- num_splits : int ,
193
- indices : np .ndarray ,
194
- ** kwargs : Any
195
- ) -> List [Tuple [np .ndarray , np .ndarray ]]:
72
+ def time_series (
73
+ random_state : np .random .RandomState ,
74
+ num_splits : int ,
75
+ indices : np .ndarray ,
76
+ shuffle : bool = False ,
77
+ labels_to_stratify : Optional [Union [Tuple [np .ndarray , np .ndarray ], Dataset ]] = None
78
+ ) -> List [Tuple [np .ndarray , np .ndarray ]]:
196
79
"""
197
80
Returns train and validation indices respecting the temporal ordering of the data.
198
81
@@ -215,10 +98,115 @@ def time_series_cross_validation(random_state: np.random.RandomState,
215
98
splits = list (cv .split (indices ))
216
99
return splits
217
100
218
- @classmethod
219
- def get_cross_validators (cls , * cross_val_types : CrossValTypes ) -> Dict [str , CrossValFunc ]:
220
- cross_validators = {
221
- cross_val_type .name : getattr (cls , cross_val_type .name )
222
- for cross_val_type in cross_val_types
223
- }
224
- return cross_validators
101
+
102
class CrossValTypes(Enum):
    """The type of cross validation

    This class is used to specify the cross validation function
    and is not supposed to be instantiated.

    Examples: This class is supposed to be used as follows
    >>> cv_type = CrossValTypes.k_fold_cross_validation
    >>> print(cv_type.name)

    k_fold_cross_validation

    >>> for cross_val_type in CrossValTypes:
            print(cross_val_type.name, cross_val_type.value)

    k_fold_cross_validation functools.partial(<function CrossValFuncs.k_fold_cross_validation at ...>)
    time_series functools.partial(<function CrossValFuncs.time_series at ...>)

    Additionally, a member can be called directly, e.g.
    CrossValTypes.k_fold_cross_validation(random_state=..., indices=...).
    """
    # Plain functions assigned in an Enum body are treated as methods, not as
    # members; wrapping them in `partial` makes them regular member values.
    # NOTE(review): newer Python versions warn about `partial` objects in an
    # Enum body and offer `enum.member` instead -- confirm against the Python
    # versions this project supports.
    k_fold_cross_validation = partial(CrossValFuncs.k_fold_cross_validation)
    time_series = partial(CrossValFuncs.time_series)

    def __call__(
        self,
        random_state: np.random.RandomState,
        indices: np.ndarray,
        num_splits: int = 5,
        shuffle: bool = False,
        labels_to_stratify: Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]] = None
    ) -> List[Tuple[np.ndarray, np.ndarray]]:
        """
        This function allows to call and type-check the specified function.

        Args:
            random_state (np.random.RandomState): random number generator for the reproducibility
            indices (np.ndarray): The indices of data points in a dataset
            num_splits (int): The number of splits in cross validation
            shuffle (bool): If shuffle the indices or not
            labels_to_stratify (Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]]):
                The labels of the corresponding data points. It is used for the stratification.

        Returns:
            splits (List[Tuple[np.ndarray, np.ndarray]]):
                splits[a split identifier][0: train, 1: val][a data point identifier]

        """
        # Everything is forwarded by keyword, so the differing parameter order
        # of the wrapped function does not matter.
        return self.value(
            random_state=random_state,
            num_splits=num_splits,
            indices=indices,
            shuffle=shuffle,
            labels_to_stratify=labels_to_stratify
        )
154
+
155
+
156
class HoldoutValTypes(Enum):
    """The type of holdout validation

    This class is used to specify the holdout validation function
    and is not supposed to be instantiated.

    Examples: This class is supposed to be used as follows
    >>> holdout_type = HoldoutValTypes.holdout
    >>> print(holdout_type.name)

    holdout

    >>> print(holdout_type.value)

    functools.partial(<function HoldoutFuncs.holdout_validation at ...>)

    >>> for holdout_type in HoldoutValTypes:
            print(holdout_type.name)

    holdout

    Additionally, HoldoutValTypes.<member> can be called directly.
    """

    # Plain functions assigned in an Enum body are treated as methods, not as
    # members; wrapping the function in `partial` makes it a regular member value.
    # Note that the member is named `holdout`, not `holdout_validation`.
    holdout = partial(HoldoutFuncs.holdout_validation)

    def __call__(
        self,
        random_state: np.random.RandomState,
        indices: np.ndarray,
        val_share: float = 0.33,
        shuffle: bool = False,
        labels_to_stratify: Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]] = None
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        This function allows to call and type-check the specified function.

        Args:
            random_state (np.random.RandomState): random number generator for the reproducibility
            indices (np.ndarray): The indices of data points in a dataset
            val_share (float): The ratio of validation dataset vs the given dataset
            shuffle (bool): If shuffle the indices or not
            labels_to_stratify (Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]]):
                The labels of the corresponding data points. It is used for the stratification.

        Returns:
            (train, val) (Tuple[np.ndarray, np.ndarray]):
                The training indices and the validation indices of the single split.

        """
        # Everything is forwarded by keyword, so the differing parameter order
        # of the wrapped function does not matter.
        return self.value(
            random_state=random_state,
            val_share=val_share,
            indices=indices,
            shuffle=shuffle,
            labels_to_stratify=labels_to_stratify
        )
0 commit comments