Skip to content

Commit bf9e642

Browse files
authored
Merge pull request #54 from ChEB-AI/refactor_chebiOverXPartial
Refactor ChEBIOverXPartial
2 parents a95415b + 19b194a commit bf9e642

File tree

2 files changed

+31
-24
lines changed

2 files changed

+31
-24
lines changed

chebai/preprocessing/datasets/base.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
)
1515
from lightning.pytorch.core.datamodule import LightningDataModule
1616
from lightning_utilities.core.rank_zero import rank_zero_info
17+
from sklearn.model_selection import StratifiedShuffleSplit
1718
from torch.utils.data import DataLoader
1819

1920
from chebai.preprocessing import reader as dr
@@ -929,11 +930,17 @@ def get_test_split(
929930
labels_list = df["labels"].tolist()
930931

931932
test_size = 1 - self.train_split - (1 - self.train_split) ** 2
932-
msss = MultilabelStratifiedShuffleSplit(
933-
n_splits=1, test_size=test_size, random_state=seed
934-
)
935933

936-
train_indices, test_indices = next(msss.split(labels_list, labels_list))
934+
if len(labels_list[0]) > 1:
935+
splitter = MultilabelStratifiedShuffleSplit(
936+
n_splits=1, test_size=test_size, random_state=seed
937+
)
938+
else:
939+
splitter = StratifiedShuffleSplit(
940+
n_splits=1, test_size=test_size, random_state=seed
941+
)
942+
943+
train_indices, test_indices = next(splitter.split(labels_list, labels_list))
937944

938945
df_train = df.iloc[train_indices]
939946
df_test = df.iloc[test_indices]
@@ -985,12 +992,18 @@ def get_train_val_splits_given_test(
985992

986993
# scale val set size by 1/self.train_split to compensate for (hypothetical) test set size (1-self.train_split)
987994
test_size = ((1 - self.train_split) ** 2) / self.train_split
988-
msss = MultilabelStratifiedShuffleSplit(
989-
n_splits=1, test_size=test_size, random_state=seed
990-
)
995+
996+
if len(labels_list_trainval[0]) > 1:
997+
splitter = MultilabelStratifiedShuffleSplit(
998+
n_splits=1, test_size=test_size, random_state=seed
999+
)
1000+
else:
1001+
splitter = StratifiedShuffleSplit(
1002+
n_splits=1, test_size=test_size, random_state=seed
1003+
)
9911004

9921005
train_indices, validation_indices = next(
993-
msss.split(labels_list_trainval, labels_list_trainval)
1006+
splitter.split(labels_list_trainval, labels_list_trainval)
9941007
)
9951008

9961009
df_validation = df_trainval.iloc[validation_indices]

chebai/preprocessing/datasets/chebi.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,9 @@ def __init__(self, top_class_id: int, **kwargs):
736736
top_class_id (int): The ID of the top class from which to extract subclasses.
737737
**kwargs: Additional keyword arguments passed to the superclass initializer.
738738
"""
739+
if "top_class_id" not in kwargs:
740+
kwargs["top_class_id"] = top_class_id
741+
739742
self.top_class_id: int = top_class_id
740743
super().__init__(**kwargs)
741744

@@ -758,27 +761,18 @@ def _extract_class_hierarchy(self, chebi_path: str) -> nx.DiGraph:
758761
"""
759762
Extracts a subset of ChEBI based on subclasses of the top class ID.
760763
764+
This method calls the superclass method to extract the full class hierarchy,
765+
then extracts the subgraph containing only the descendants of the top class ID, including itself.
766+
761767
Args:
762768
chebi_path (str): The file path to the ChEBI ontology file.
763769
764770
Returns:
765-
nx.DiGraph: The extracted class hierarchy as a directed graph.
771+
nx.DiGraph: The extracted class hierarchy as a directed graph, limited to the
772+
descendants of the top class ID.
766773
"""
767-
with open(chebi_path, encoding="utf-8") as chebi:
768-
chebi = "\n".join(l for l in chebi if not l.startswith("xref:"))
769-
elements = [
770-
term_callback(clause)
771-
for clause in fastobo.loads(chebi)
772-
if clause and ":" in str(clause.id)
773-
]
774-
g = nx.DiGraph()
775-
for n in elements:
776-
g.add_node(n["id"], **n)
777-
g.add_edges_from([(p, q["id"]) for q in elements for p in q["parents"]])
778-
779-
g = nx.transitive_closure_dag(g)
780-
g = g.subgraph(list(nx.descendants(g, self.top_class_id)) + [self.top_class_id])
781-
print("Compute transitive closure")
774+
g = super()._extract_class_hierarchy(chebi_path)
775+
g = g.subgraph(list(g.successors(self.top_class_id)) + [self.top_class_id])
782776
return g
783777

784778

0 commit comments

Comments
 (0)