Skip to content

Commit 812d1e1

Browse files
committed
feat: add evaluate_from_embeddings
1 parent 81f7be7 commit 812d1e1

File tree

2 files changed

+109
-57
lines changed

2 files changed

+109
-57
lines changed

graph_datasets/utils/evaluation/eval_tools.py

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from sklearn.metrics import normalized_mutual_info_score as NMI
1818
from sklearn.svm import LinearSVC
1919

20+
from ..common import get_str_time
2021
from ..common import tab_printer
2122

2223

@@ -38,7 +39,14 @@ def save_dict(di_, filename_):
3839
pickle.dump(di_, f)
3940

4041

41-
def split_train_test_nodes(data, train_ratio, valid_ratio, data_name, split_id=0, fixed_split=True):
42+
def split_train_test_nodes(
43+
num_nodes,
44+
train_ratio,
45+
valid_ratio,
46+
data_name,
47+
split_id=0,
48+
fixed_split=True,
49+
):
4250
if fixed_split:
4351
file_path = f"../input/fixed_splits/{data_name}-{train_ratio}-{valid_ratio}-splits.npy"
4452
if not os.path.exists(file_path):
@@ -47,13 +55,12 @@ def split_train_test_nodes(data, train_ratio, valid_ratio, data_name, split_id=0
4755
splits = {}
4856
for idx in range(10):
4957
# set up train val and test
50-
shuffle = list(range(data.num_nodes))
58+
shuffle = list(range(num_nodes))
5159
random.shuffle(shuffle)
52-
train_nodes = shuffle[:int(data.num_nodes * train_ratio / 100)]
53-
val_nodes = shuffle[
54-
int(data.num_nodes * train_ratio /
55-
100):int(data.num_nodes * (train_ratio + valid_ratio) / 100)]
56-
test_nodes = shuffle[int(data.num_nodes * (train_ratio + valid_ratio) / 100):]
60+
train_nodes = shuffle[:int(num_nodes * train_ratio / 100)]
61+
val_nodes = shuffle[int(num_nodes * train_ratio /
62+
100):int(num_nodes * (train_ratio + valid_ratio) / 100)]
63+
test_nodes = shuffle[int(num_nodes * (train_ratio + valid_ratio) / 100):]
5764
splits[idx] = {"train": train_nodes, "valid": val_nodes, "test": test_nodes}
5865
save_dict(di_=splits, filename_=file_path)
5966
else:
@@ -62,12 +69,12 @@ def split_train_test_nodes(data, train_ratio, valid_ratio, data_name, split_id=0
6269
train_nodes, val_nodes, test_nodes = split["train"], split["valid"], split["test"]
6370
else:
6471
# set up train val and test
65-
shuffle = list(range(data.num_nodes))
72+
shuffle = list(range(num_nodes))
6673
random.shuffle(shuffle)
67-
train_nodes = shuffle[:int(data.num_nodes * train_ratio / 100)]
68-
val_nodes = shuffle[int(data.num_nodes * train_ratio /
69-
100):int(data.num_nodes * (train_ratio + valid_ratio) / 100)]
70-
test_nodes = shuffle[int(data.num_nodes * (train_ratio + valid_ratio) / 100):]
74+
train_nodes = shuffle[:int(num_nodes * train_ratio / 100)]
75+
val_nodes = shuffle[int(num_nodes * train_ratio /
76+
100):int(num_nodes * (train_ratio + valid_ratio) / 100)]
77+
test_nodes = shuffle[int(num_nodes * (train_ratio + valid_ratio) / 100):]
7178

7279
return np.array(train_nodes), np.array(val_nodes), np.array(test_nodes)
7380

@@ -175,18 +182,18 @@ def kmeans_test(X, y, n_clusters, repeat=10):
175182
)
176183

177184

178-
def svm_test(data, embeddings, labels, train_ratios=(10, 20, 30, 40), repeat=10):
185+
def svm_test(num_nodes, data_name, embeddings, labels, train_ratios=(10, 20, 30, 40), repeat=10):
179186
result_macro_f1_list = []
180187
result_micro_f1_list = []
181188
for train_ratio in train_ratios:
182189
macro_f1_list = []
183190
micro_f1_list = []
184191
for i in range(repeat):
185192
train_idx, val_idx, test_idx = split_train_test_nodes(
186-
data=data,
193+
num_nodes=num_nodes,
187194
train_ratio=train_ratio,
188195
valid_ratio=train_ratio,
189-
data_name=data.name,
196+
data_name=data_name,
190197
split_id=i,
191198
)
192199
X_train, X_test = embeddings[np.concatenate([train_idx, val_idx])], embeddings[test_idx]
@@ -204,16 +211,16 @@ def svm_test(data, embeddings, labels, train_ratios=(10, 20, 30, 40), repeat=10)
204211

205212

206213
def evaluate_results_nc(
207-
data,
214+
labels,
215+
num_classes,
216+
num_nodes,
217+
data_name,
208218
embeddings,
209219
quiet=False,
210220
method="unsup",
211221
alpha: float = 2.0,
212222
beta: float = 2.0,
213223
):
214-
labels = data.y.detach().cpu().numpy()
215-
num_classes = data.num_classes
216-
num_nodes = data.num_nodes
217224
if embeddings.shape[0] > num_nodes:
218225
z_1 = embeddings[:num_nodes]
219226
z_2 = embeddings[num_nodes:]
@@ -227,7 +234,8 @@ def evaluate_results_nc(
227234
svm_macro_f1_list,
228235
svm_micro_f1_list,
229236
) = svm_test(
230-
data=data,
237+
num_nodes=num_nodes,
238+
data_name=data_name,
231239
embeddings=embeddings,
232240
labels=labels,
233241
)
@@ -315,7 +323,7 @@ def save_embedding(
315323
verbose: bool or int = True,
316324
):
317325
dataset_name = dataset_name.replace("_", "-")
318-
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
326+
timestamp = get_str_time()
319327
file_name = f"{dataset_name.lower()}_{model_name.lower()}_embeds_{timestamp}.pth"
320328
file_path = os.path.join(save_dir, file_name)
321329

graph_datasets/utils/evaluation/evaluation.py

Lines changed: 80 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Dict
66
from typing import Tuple
77

8+
import numpy as np
89
import torch
910

1011
from .eval_tools import evaluate_results_nc
@@ -22,44 +23,31 @@ def load_from_file(file_name):
2223
return embeddings
2324

2425

25-
def evaluate_from_embed_file(
26-
embedding_file: str,
27-
data_file: str,
28-
save_path: str = "./tmp/",
26+
def evaluate_from_embeddings(
27+
labels: np.ndarray,
28+
num_classes: int,
29+
num_nodes: int,
30+
data_name: str,
31+
embeddings: torch.Tensor,
2932
quiet: bool = True,
33+
method: str = "both",
3034
) -> Tuple[Dict, Dict]:
31-
"""Evaluation of representation quality using clustering and classification tasks.
35+
"""Evaluate embeddings with SVM classification and clustering.
3236
3337
Args:
34-
embedding_file (str): Embedded file name.
35-
data_file (str): Data file name.
36-
save_path (str, optional): Folder path to store. Defaults to './tmp/'.
37-
quiet (bool, optional): Whether to print results. Defaults to True.
38+
labels (np.ndarray): labels.
39+
num_classes (int): number of classes.
40+
num_nodes (int): number of nodes.
41+
data_name (str): name of the dataset.
42+
embeddings (torch.Tensor): embeddings.
43+
quiet (bool, optional): whether to print info. Defaults to True.
44+
method (str, optional): method for evaluation, \
45+
"sup" for linear regression, "unsup" for svm clustering, "both" for both.\
46+
Defaults to "both".
3847
3948
Returns:
40-
Tuple[Dict, Dict]: Two dicts are included, \
41-
which are the evaluation results of clustering and classification.
42-
43-
Example:
44-
.. code-block:: python
45-
46-
from graph_datasets import evaluate_from_embed_file
47-
48-
method_name='orderedgnn'
49-
data_name='texas'
50-
51-
clustering_res, classification_res = evaluate_from_embed_file(
52-
f'{data_name}_{method_name}_embeds.pth',
53-
f'{data_name}_data.pth',
54-
save_path='./save/',
55-
)
49+
Tuple[Dict, Dict]: (clustering_results, classification_results)
5650
"""
57-
embedding_file = os.path.join(save_path, embedding_file)
58-
data_file = os.path.join(save_path, data_file)
59-
60-
embeddings = load_from_file(embedding_file).cpu().detach()
61-
data = load_from_file(data_file)
62-
6351
# Call the evaluate_results_nc function with the loaded embeddings
6452
(
6553
svm_macro_f1_list,
@@ -75,10 +63,13 @@ def evaluate_from_embed_file(
7563
f1_mean,
7664
f1_std,
7765
) = evaluate_results_nc(
78-
data,
66+
labels,
67+
num_classes,
68+
num_nodes,
69+
data_name,
7970
embeddings,
8071
quiet=quiet,
81-
method="both",
72+
method=method,
8273
)
8374

8475
# Format the output as desired
@@ -87,20 +78,73 @@ def evaluate_from_embed_file(
8778
"NMI": f"{nmi_mean * 100:.2f}±{nmi_std * 100:.2f}",
8879
"AMI": f"{ami_mean * 100:.2f}±{ami_std * 100:.2f}",
8980
"ARI": f"{ari_mean * 100:.2f}±{ari_std * 100:.2f}",
90-
"Macro F1": f"{f1_mean * 100:.2f}±{f1_std * 100:.2f}",
81+
"MaF1": f"{f1_mean * 100:.2f}±{f1_std * 100:.2f}",
9182
}
9283

9384
svm_macro_f1_list = [f"{res[0] * 100:.2f}±{res[1] * 100:.2f}" for res in svm_macro_f1_list]
9485
svm_micro_f1_list = [f"{res[0] * 100:.2f}±{res[1] * 100:.2f}" for res in svm_micro_f1_list]
9586

9687
classification_results = {}
9788
for i, percent in enumerate(["10%", "20%", "30%", "40%"]):
98-
classification_results[f"{percent}_Macro-F1"] = svm_macro_f1_list[i]
99-
classification_results[f"{percent}_Micro-F1"] = svm_micro_f1_list[i]
89+
classification_results[f"{percent}_MaF1"] = svm_macro_f1_list[i]
90+
classification_results[f"{percent}_MiF1"] = svm_micro_f1_list[i]
10091

10192
return clustering_results, classification_results
10293

10394

95+
def evaluate_from_embed_file(
96+
embedding_file: str,
97+
data_file: str,
98+
save_path: str = "./tmp/",
99+
quiet: bool = True,
100+
) -> Tuple[Dict, Dict]:
101+
"""Evaluation of representation quality using clustering and classification tasks.
102+
103+
Args:
104+
embedding_file (str): Embedded file name.
105+
data_file (str): Data file name.
106+
save_path (str, optional): Folder path to store. Defaults to './tmp/'.
107+
quiet (bool, optional): Whether to print results. Defaults to True.
108+
109+
Returns:
110+
Tuple[Dict, Dict]: Two dicts are included, \
111+
which are the evaluation results of clustering and classification.
112+
113+
Example:
114+
.. code-block:: python
115+
116+
from graph_datasets import evaluate_from_embed_file
117+
118+
method_name='orderedgnn'
119+
data_name='texas'
120+
121+
clustering_res, classification_res = evaluate_from_embed_file(
122+
f'{data_name}_{method_name}_embeds.pth',
123+
f'{data_name}_data.pth',
124+
save_path='./save/',
125+
)
126+
"""
127+
embedding_file = os.path.join(save_path, embedding_file)
128+
data_file = os.path.join(save_path, data_file)
129+
130+
embeddings = load_from_file(embedding_file).cpu().detach()
131+
data = load_from_file(data_file)
132+
133+
labels = data.y.detach().cpu().numpy()
134+
num_classes = data.num_classes
135+
num_nodes = data.num_nodes
136+
data_name = data.name
137+
138+
return evaluate_from_embeddings(
139+
labels=labels,
140+
num_classes=num_classes,
141+
num_nodes=num_nodes,
142+
data_name=data_name,
143+
embeddings=embeddings,
144+
quiet=quiet,
145+
)
146+
147+
104148
# if __name__ == "__main__":
105149
# method_name = 'orderedgnn' # 'selene' 'greet' 'hgrl' 'nwr-gae' 'orderedgnn'
106150
# data_name = 'texas' # 'actor' 'chameleon' 'cornell' 'squirrel' 'texas' 'wisconsin'

0 commit comments

Comments
 (0)