Skip to content

Commit d313a48

Browse files
committed
[refactor] Update the split functions to be able to call function directly
1 parent 3191642 commit d313a48

File tree

1 file changed

+159
-171
lines changed

1 file changed

+159
-171
lines changed
Lines changed: 159 additions & 171 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from enum import IntEnum
2-
from typing import Any, Dict, List, Optional, Tuple, Union
1+
from enum import Enum
2+
from functools import partial
3+
from typing import List, NamedTuple, Optional, Tuple, Union
34

45
import numpy as np
56

@@ -12,187 +13,69 @@
1213
train_test_split
1314
)
1415

15-
from typing_extensions import Protocol
16+
from torch.utils.data import Dataset
1617

1718

18-
# Use callback protocol as workaround, since callable with function fields count 'self' as argument
19-
class CrossValFunc(Protocol):
20-
def __call__(self,
21-
random_state: np.random.RandomState,
22-
num_splits: int,
23-
indices: np.ndarray,
24-
stratify: Optional[Any]) -> List[Tuple[np.ndarray, np.ndarray]]:
25-
...
19+
class _ResamplingStrategyArgs(NamedTuple):
20+
val_share: float = 0.33
21+
num_splits: int = 5
22+
shuffle: bool = False
23+
stratify: bool = False
2624

2725

28-
class HoldOutFunc(Protocol):
29-
def __call__(self, random_state: np.random.RandomState, val_share: float,
30-
indices: np.ndarray, stratify: Optional[Any]
31-
) -> Tuple[np.ndarray, np.ndarray]:
32-
...
33-
34-
35-
class CrossValTypes(IntEnum):
36-
"""The type of cross validation
37-
38-
This class is used to specify the cross validation function
39-
and is not supposed to be instantiated.
40-
41-
Examples: This class is supposed to be used as follows
42-
>>> cv_type = CrossValTypes.k_fold_cross_validation
43-
>>> print(cv_type.name)
44-
45-
k_fold_cross_validation
46-
47-
>>> for cross_val_type in CrossValTypes:
48-
print(cross_val_type.name, cross_val_type.value)
49-
50-
stratified_k_fold_cross_validation 1
51-
k_fold_cross_validation 2
52-
stratified_shuffle_split_cross_validation 3
53-
shuffle_split_cross_validation 4
54-
time_series_cross_validation 5
55-
"""
56-
stratified_k_fold_cross_validation = 1
57-
k_fold_cross_validation = 2
58-
stratified_shuffle_split_cross_validation = 3
59-
shuffle_split_cross_validation = 4
60-
time_series_cross_validation = 5
61-
62-
def is_stratified(self) -> bool:
63-
stratified = [self.stratified_k_fold_cross_validation,
64-
self.stratified_shuffle_split_cross_validation]
65-
return getattr(self, self.name) in stratified
66-
67-
68-
class HoldoutValTypes(IntEnum):
69-
"""TODO: change to enum using functools.partial"""
70-
"""The type of hold out validation (refer to CrossValTypes' doc-string)"""
71-
holdout_validation = 6
72-
stratified_holdout_validation = 7
73-
74-
def is_stratified(self) -> bool:
75-
stratified = [self.stratified_holdout_validation]
76-
return getattr(self, self.name) in stratified
77-
78-
79-
# TODO: replace it with another way
80-
RESAMPLING_STRATEGIES = [CrossValTypes, HoldoutValTypes]
81-
82-
DEFAULT_RESAMPLING_PARAMETERS = {
83-
HoldoutValTypes.holdout_validation: {
84-
'val_share': 0.33,
85-
},
86-
HoldoutValTypes.stratified_holdout_validation: {
87-
'val_share': 0.33,
88-
},
89-
CrossValTypes.k_fold_cross_validation: {
90-
'num_splits': 5,
91-
},
92-
CrossValTypes.stratified_k_fold_cross_validation: {
93-
'num_splits': 5,
94-
},
95-
CrossValTypes.shuffle_split_cross_validation: {
96-
'num_splits': 5,
97-
},
98-
CrossValTypes.time_series_cross_validation: {
99-
'num_splits': 5,
100-
},
101-
} # type: Dict[Union[HoldoutValTypes, CrossValTypes], Dict[str, Any]]
102-
103-
104-
class HoldOutFuncs():
26+
class HoldoutFuncs():
10527
@staticmethod
106-
def holdout_validation(random_state: np.random.RandomState,
107-
val_share: float,
108-
indices: np.ndarray,
109-
**kwargs: Any
110-
) -> Tuple[np.ndarray, np.ndarray]:
111-
shuffle = kwargs.get('shuffle', True)
112-
train, val = train_test_split(indices, test_size=val_share,
113-
shuffle=shuffle,
114-
random_state=random_state if shuffle else None,
115-
)
28+
def holdout_validation(
29+
random_state: np.random.RandomState,
30+
val_share: float,
31+
indices: np.ndarray,
32+
shuffle: bool = False,
33+
labels_to_stratify: Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]] = None
34+
):
35+
36+
train, val = train_test_split(
37+
indices, test_size=val_share, shuffle=shuffle,
38+
random_state=random_state if shuffle else None,
39+
stratify=labels_to_stratify
40+
)
11641
return train, val
11742

118-
@staticmethod
119-
def stratified_holdout_validation(random_state: np.random.RandomState,
120-
val_share: float,
121-
indices: np.ndarray,
122-
**kwargs: Any
123-
) -> Tuple[np.ndarray, np.ndarray]:
124-
train, val = train_test_split(indices, test_size=val_share, shuffle=True, stratify=kwargs["stratify"],
125-
random_state=random_state)
126-
return train, val
127-
128-
@classmethod
129-
def get_holdout_validators(cls, *holdout_val_types: HoldoutValTypes) -> Dict[str, HoldOutFunc]:
130-
131-
holdout_validators = {
132-
holdout_val_type.name: getattr(cls, holdout_val_type.name)
133-
for holdout_val_type in holdout_val_types
134-
}
135-
return holdout_validators
136-
13743

13844
class CrossValFuncs():
139-
@staticmethod
140-
def shuffle_split_cross_validation(random_state: np.random.RandomState,
141-
num_splits: int,
142-
indices: np.ndarray,
143-
**kwargs: Any
144-
) -> List[Tuple[np.ndarray, np.ndarray]]:
145-
cv = ShuffleSplit(n_splits=num_splits, random_state=random_state)
146-
splits = list(cv.split(indices))
147-
return splits
148-
149-
@staticmethod
150-
def stratified_shuffle_split_cross_validation(random_state: np.random.RandomState,
151-
num_splits: int,
152-
indices: np.ndarray,
153-
**kwargs: Any
154-
) -> List[Tuple[np.ndarray, np.ndarray]]:
155-
cv = StratifiedShuffleSplit(n_splits=num_splits, random_state=random_state)
156-
splits = list(cv.split(indices, kwargs["stratify"]))
157-
return splits
158-
159-
@staticmethod
160-
def stratified_k_fold_cross_validation(random_state: np.random.RandomState,
161-
num_splits: int,
162-
indices: np.ndarray,
163-
**kwargs: Any
164-
) -> List[Tuple[np.ndarray, np.ndarray]]:
165-
cv = StratifiedKFold(n_splits=num_splits, random_state=random_state)
166-
splits = list(cv.split(indices, kwargs["stratify"]))
167-
return splits
45+
# (shuffle, is_stratify) -> split_fn
46+
_args2split_fn = {
47+
(True, True): StratifiedShuffleSplit,
48+
(True, False): ShuffleSplit,
49+
(False, True): StratifiedKFold,
50+
(False, False): KFold,
51+
}
16852

16953
@staticmethod
170-
def k_fold_cross_validation(random_state: np.random.RandomState,
171-
num_splits: int,
172-
indices: np.ndarray,
173-
**kwargs: Any
174-
) -> List[Tuple[np.ndarray, np.ndarray]]:
54+
def k_fold_cross_validation(
55+
random_state: np.random.RandomState,
56+
num_splits: int,
57+
indices: np.ndarray,
58+
shuffle: bool = False,
59+
labels_to_stratify: Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]] = None
60+
) -> List[Tuple[np.ndarray, np.ndarray]]:
17561
"""
176-
Standard k fold cross validation.
177-
178-
Args:
179-
indices (np.ndarray): array of indices to be split
180-
num_splits (int): number of cross validation splits
181-
18262
Returns:
18363
splits (List[Tuple[List, List]]): list of tuples of training and validation indices
18464
"""
185-
shuffle = kwargs.get('shuffle', True)
186-
cv = KFold(n_splits=num_splits, random_state=random_state if shuffle else None, shuffle=shuffle)
65+
66+
split_fn = CrossValFuncs._args2split_fn[(shuffle, labels_to_stratify is not None)]
67+
cv = split_fn(n_splits=num_splits, random_state=random_state)
18768
splits = list(cv.split(indices))
18869
return splits
18970

19071
@staticmethod
191-
def time_series_cross_validation(random_state: np.random.RandomState,
192-
num_splits: int,
193-
indices: np.ndarray,
194-
**kwargs: Any
195-
) -> List[Tuple[np.ndarray, np.ndarray]]:
72+
def time_series(
73+
random_state: np.random.RandomState,
74+
num_splits: int,
75+
indices: np.ndarray,
76+
shuffle: bool = False,
77+
labels_to_stratify: Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]] = None
78+
) -> List[Tuple[np.ndarray, np.ndarray]]:
19679
"""
19780
Returns train and validation indices respecting the temporal ordering of the data.
19881
@@ -215,10 +98,115 @@ def time_series_cross_validation(random_state: np.random.RandomState,
21598
splits = list(cv.split(indices))
21699
return splits
217100

218-
@classmethod
219-
def get_cross_validators(cls, *cross_val_types: CrossValTypes) -> Dict[str, CrossValFunc]:
220-
cross_validators = {
221-
cross_val_type.name: getattr(cls, cross_val_type.name)
222-
for cross_val_type in cross_val_types
223-
}
224-
return cross_validators
101+
102+
class CrossValTypes(Enum):
103+
"""The type of cross validation
104+
105+
This class is used to specify the cross validation function
106+
and is not supposed to be instantiated.
107+
108+
Examples: This class is supposed to be used as follows
109+
>>> cv_type = CrossValTypes.k_fold_cross_validation
110+
>>> print(cv_type.name)
111+
112+
k_fold_cross_validation
113+
114+
>>> for cross_val_type in CrossValTypes:
115+
print(cross_val_type.name, cross_val_type.value)
116+
117+
k_fold_cross_validation functools.partial(<function CrossValFuncs.k_fold_cross_validation at ...>)
118+
time_series <function CrossValFuncs.time_series>
119+
"""
120+
k_fold_cross_validation = partial(CrossValFuncs.k_fold_cross_validation)
121+
time_series = partial(CrossValFuncs.time_series)
122+
123+
def __call__(
124+
self,
125+
random_state: np.random.RandomState,
126+
indices: np.ndarray,
127+
num_splits: int = 5,
128+
shuffle: bool = False,
129+
labels_to_stratify: Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]] = None
130+
) -> List[Tuple[np.ndarray, np.ndarray]]:
131+
"""
132+
This function allows to call and type-check the specified function.
133+
134+
Args:
135+
random_state (np.random.RandomState): random number genetor for the reproducibility
136+
num_splits (int): The number of splits in cross validation
137+
indices (np.ndarray): The indices of data points in a dataset
138+
shuffle (bool): If shuffle the indices or not
139+
labels_to_stratify (Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]]):
140+
The labels of the corresponding data points. It is used for the stratification.
141+
142+
Returns:
143+
splits (List[Tuple[np.ndarray, np.ndarray]]):
144+
splits[a split identifier][0: train, 1: val][a data point identifier]
145+
146+
"""
147+
return self.value(
148+
random_state=random_state,
149+
num_splits=num_splits,
150+
indices=indices,
151+
shuffle=shuffle,
152+
labels_to_stratify=labels_to_stratify
153+
)
154+
155+
156+
class HoldoutValTypes(Enum):
157+
"""The type of holdout validation
158+
159+
This class is used to specify the holdout validation function
160+
and is not supposed to be instantiated.
161+
162+
Examples: This class is supposed to be used as follows
163+
>>> holdout_type = HoldoutValTypes.holdout_validation
164+
>>> print(holdout_type.name)
165+
166+
holdout_validation
167+
168+
>>> print(holdout_type.value)
169+
170+
functools.partial(<function HoldoutValTypes.holdout_validation at ...>)
171+
172+
>>> for holdout_type in HoldoutValTypes:
173+
print(holdout_type.name)
174+
175+
holdout_validation
176+
177+
Additionally, HoldoutValTypes.<function> can be called directly.
178+
"""
179+
180+
holdout = partial(HoldoutFuncs.holdout_validation)
181+
182+
def __call__(
183+
self,
184+
random_state: np.random.RandomState,
185+
indices: np.ndarray,
186+
val_share: float = 0.33,
187+
shuffle: bool = False,
188+
labels_to_stratify: Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]] = None
189+
) -> List[Tuple[np.ndarray, np.ndarray]]:
190+
"""
191+
This function allows to call and type-check the specified function.
192+
193+
Args:
194+
random_state (np.random.RandomState): random number genetor for the reproducibility
195+
val_share (float): The ratio of validation dataset vs the given dataset
196+
indices (np.ndarray): The indices of data points in a dataset
197+
shuffle (bool): If shuffle the indices or not
198+
labels_to_stratify (Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]]):
199+
The labels of the corresponding data points. It is used for the stratification.
200+
201+
Returns:
202+
splits (List[Tuple[np.ndarray, np.ndarray]]):
203+
splits[a split identifier][0: train, 1: val][a data point identifier]
204+
205+
"""
206+
return self.value(
207+
random_state=random_state,
208+
val_share=val_share,
209+
indices=indices,
210+
shuffle=shuffle,
211+
labels_to_stratify=labels_to_stratify
212+
)

0 commit comments

Comments
 (0)