from typing import Dict
from typing import Tuple

+import numpy as np
import torch

from .eval_tools import evaluate_results_nc
@@ -22,44 +23,31 @@ def load_from_file(file_name):
    return embeddings


-def evaluate_from_embed_file(
-    embedding_file: str,
-    data_file: str,
-    save_path: str = "./tmp/",
+def evaluate_from_embeddings(
+    labels: np.ndarray,
+    num_classes: int,
+    num_nodes: int,
+    data_name: str,
+    embeddings: torch.Tensor,
    quiet: bool = True,
+    method: str = "both",
) -> Tuple[Dict, Dict]:
31- """Evaluation of representation quality using clustering and classification tasks .
35+ """evaluate embeddings with LR and Clustering .
3236
3337 Args:
34- embedding_file (str): Embedded file name.
35- data_file (str): Data file name.
36- save_path (str, optional): Folder path to store. Defaults to './tmp/'.
37- quiet (bool, optional): Whether to print results. Defaults to True.
38+ labels (np.ndarray): labels.
39+ num_classes (int): number of classes.
40+ num_nodes (int): number of nodes.
41+ data_name (str): name of the datasets.
42+ embeddings (torch.Tensor): embeddings.
43+ quiet (bool, optional): whether to print info. Defaults to True.
44+ method (bool, optional): method for evaluation, \
45+ "sup" for linear regression, "unsup" for svm clustering, "both" for both.\
46+ Defaults to "both".
3847
3948 Returns:
-        Tuple[Dict, Dict]: Two dicts are included, \
-            which are the evaluation results of clustering and classification.
-
-    Example:
-        .. code-block:: python
-
-            from graph_datasets import evaluate_from_embed_file
-
-            method_name='orderedgnn'
-            data_name='texas'
-
-            clustering_res, classification_res = evaluate_from_embed_file(
-                f'{data_name}_{method_name}_embeds.pth',
-                f'{data_name}_data.pth',
-                save_path='./save/',
-            )
+        Tuple[Dict, Dict]: (clustering_results, classification_results)
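+
+    Example:
+        A minimal sketch of calling this function; the random labels, the
+        embedding size, and the dataset name "toy" below are illustrative
+        placeholders rather than a real dataset.
+
+        .. code-block:: python
+
+            import numpy as np
+            import torch
+
+            num_nodes, num_classes = 100, 4
+            labels = np.random.randint(0, num_classes, size=num_nodes)
+            embeddings = torch.randn(num_nodes, 32)
+
+            clustering_res, classification_res = evaluate_from_embeddings(
+                labels=labels,
+                num_classes=num_classes,
+                num_nodes=num_nodes,
+                data_name="toy",
+                embeddings=embeddings,
+                quiet=True,
+                method="both",
+            )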
5650 """
57- embedding_file = os .path .join (save_path , embedding_file )
58- data_file = os .path .join (save_path , data_file )
59-
60- embeddings = load_from_file (embedding_file ).cpu ().detach ()
61- data = load_from_file (data_file )
62-
6351 # Call the evaluate_results_nc function with the loaded embeddings
6452 (
6553 svm_macro_f1_list ,
@@ -75,10 +63,13 @@ def evaluate_from_embed_file(
        f1_mean,
        f1_std,
    ) = evaluate_results_nc(
-        data,
+        labels,
+        num_classes,
+        num_nodes,
+        data_name,
        embeddings,
        quiet=quiet,
-        method="both",
+        method=method,
    )

    # Format the output as desired
@@ -87,20 +78,73 @@ def evaluate_from_embed_file(
8778 "NMI" : f"{ nmi_mean * 100 :.2f} ±{ nmi_std * 100 :.2f} " ,
8879 "AMI" : f"{ ami_mean * 100 :.2f} ±{ ami_std * 100 :.2f} " ,
8980 "ARI" : f"{ ari_mean * 100 :.2f} ±{ ari_std * 100 :.2f} " ,
90- "Macro F1 " : f"{ f1_mean * 100 :.2f} ±{ f1_std * 100 :.2f} " ,
81+ "MaF1 " : f"{ f1_mean * 100 :.2f} ±{ f1_std * 100 :.2f} " ,
9182 }
9283
9384 svm_macro_f1_list = [f"{ res [0 ] * 100 :.2f} ±{ res [1 ] * 100 :.2f} " for res in svm_macro_f1_list ]
9485 svm_micro_f1_list = [f"{ res [0 ] * 100 :.2f} ±{ res [1 ] * 100 :.2f} " for res in svm_micro_f1_list ]
9586
9687 classification_results = {}
9788 for i , percent in enumerate (["10%" , "20%" , "30%" , "40%" ]):
98- classification_results [f"{ percent } _Macro-F1 " ] = svm_macro_f1_list [i ]
99- classification_results [f"{ percent } _Micro-F1 " ] = svm_micro_f1_list [i ]
89+ classification_results [f"{ percent } _MaF1 " ] = svm_macro_f1_list [i ]
90+ classification_results [f"{ percent } _MiF1 " ] = svm_micro_f1_list [i ]
10091
10192 return clustering_results , classification_results
10293
10394
95+ def evaluate_from_embed_file (
96+ embedding_file : str ,
97+ data_file : str ,
98+ save_path : str = "./tmp/" ,
99+ quiet : bool = True ,
100+ ) -> Tuple [Dict , Dict ]:
101+ """Evaluation of representation quality using clustering and classification tasks.
102+
103+ Args:
104+ embedding_file (str): Embedded file name.
105+ data_file (str): Data file name.
106+ save_path (str, optional): Folder path to store. Defaults to './tmp/'.
107+ quiet (bool, optional): Whether to print results. Defaults to True.
108+
109+ Returns:
110+ Tuple[Dict, Dict]: Two dicts are included, \
111+ which are the evaluation results of clustering and classification.
112+
+    Example:
+        .. code-block:: python
+
+            from graph_datasets import evaluate_from_embed_file
+
+            method_name='orderedgnn'
+            data_name='texas'
+
+            clustering_res, classification_res = evaluate_from_embed_file(
+                f'{data_name}_{method_name}_embeds.pth',
+                f'{data_name}_data.pth',
+                save_path='./save/',
+            )
+    """
+    embedding_file = os.path.join(save_path, embedding_file)
+    data_file = os.path.join(save_path, data_file)
+
+    embeddings = load_from_file(embedding_file).cpu().detach()
+    data = load_from_file(data_file)
+
+    # Extract the labels and dataset metadata from the loaded data object
+    labels = data.y.detach().cpu().numpy()
+    num_classes = data.num_classes
+    num_nodes = data.num_nodes
+    data_name = data.name
+
+    return evaluate_from_embeddings(
+        labels=labels,
+        num_classes=num_classes,
+        num_nodes=num_nodes,
+        data_name=data_name,
+        embeddings=embeddings,
+        quiet=quiet,
+    )
+
+
# if __name__ == "__main__":
# method_name = 'orderedgnn' # 'selene' 'greet' 'hgrl' 'nwr-gae' 'orderedgnn'
# data_name = 'texas' # 'actor' 'chameleon' 'cornell' 'squirrel' 'texas' 'wisconsin'