|
14 | 14 | ) |
15 | 15 | from lightning.pytorch.core.datamodule import LightningDataModule |
16 | 16 | from lightning_utilities.core.rank_zero import rank_zero_info |
| 17 | +from sklearn.model_selection import StratifiedShuffleSplit |
17 | 18 | from torch.utils.data import DataLoader |
18 | 19 |
|
19 | 20 | from chebai.preprocessing import reader as dr |
@@ -929,11 +930,17 @@ def get_test_split( |
929 | 930 | labels_list = df["labels"].tolist() |
930 | 931 |
|
931 | 932 | test_size = 1 - self.train_split - (1 - self.train_split) ** 2 |
932 | | - msss = MultilabelStratifiedShuffleSplit( |
933 | | - n_splits=1, test_size=test_size, random_state=seed |
934 | | - ) |
935 | 933 |
|
936 | | - train_indices, test_indices = next(msss.split(labels_list, labels_list)) |
| 934 | + if len(labels_list[0]) > 1: |
| 935 | + splitter = MultilabelStratifiedShuffleSplit( |
| 936 | + n_splits=1, test_size=test_size, random_state=seed |
| 937 | + ) |
| 938 | + else: |
| 939 | + splitter = StratifiedShuffleSplit( |
| 940 | + n_splits=1, test_size=test_size, random_state=seed |
| 941 | + ) |
| 942 | + |
| 943 | + train_indices, test_indices = next(splitter.split(labels_list, labels_list)) |
937 | 944 |
|
938 | 945 | df_train = df.iloc[train_indices] |
939 | 946 | df_test = df.iloc[test_indices] |
@@ -985,12 +992,18 @@ def get_train_val_splits_given_test( |
985 | 992 |
|
986 | 993 | # scale val set size by 1/self.train_split to compensate for (hypothetical) test set size (1-self.train_split) |
987 | 994 | test_size = ((1 - self.train_split) ** 2) / self.train_split |
988 | | - msss = MultilabelStratifiedShuffleSplit( |
989 | | - n_splits=1, test_size=test_size, random_state=seed |
990 | | - ) |
| 995 | + |
| 996 | + if len(labels_list_trainval[0]) > 1: |
| 997 | + splitter = MultilabelStratifiedShuffleSplit( |
| 998 | + n_splits=1, test_size=test_size, random_state=seed |
| 999 | + ) |
| 1000 | + else: |
| 1001 | + splitter = StratifiedShuffleSplit( |
| 1002 | + n_splits=1, test_size=test_size, random_state=seed |
| 1003 | + ) |
991 | 1004 |
|
992 | 1005 | train_indices, validation_indices = next( |
993 | | - msss.split(labels_list_trainval, labels_list_trainval) |
| 1006 | + splitter.split(labels_list_trainval, labels_list_trainval) |
994 | 1007 | ) |
995 | 1008 |
|
996 | 1009 | df_validation = df_trainval.iloc[validation_indices] |
|
0 commit comments