Skip to content

Commit bf9779a

Browse files
committed
fix: update n_init of kmeans
1 parent 873859b commit bf9779a

File tree

3 files changed

+17
-12
lines changed

3 files changed

+17
-12
lines changed

.ci/install-dev.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
rm -rf .env
22

3-
# # install python 3.8.16 use pyenv
3+
# # install python 3.8.16 using pyenv:
44
# pyenv install 3.8.16
55
# pyenv local 3.8.16
66

graph_datasets/utils/common.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,18 @@ def format_result(
2020
timezone="Asia/Shanghai",
2121
**kwargs,
2222
):
23-
time = get_str_time(timezone)
2423
if sort_kw:
2524
kwargs = dict(sorted(kwargs.items()))
26-
return {
27-
**kwargs,
28-
"ds": dataset,
29-
"src": source,
30-
"model": model,
31-
"time": time,
32-
}
25+
26+
kwargs.update(
27+
{
28+
"ds": dataset,
29+
"src": source,
30+
"model": model,
31+
"time": get_str_time(timezone),
32+
}
33+
)
34+
return kwargs
3335

3436

3537
def get_str_time(timezone="Asia/Shanghai"):
@@ -93,7 +95,8 @@ def tab_printer(
9395
[
9496
k.replace("_", " "),
9597
f"{args[k]}" if isinstance(args[k], bool) else format_value(args[k]),
96-
] for k in keys
98+
]
99+
for k in keys
97100
]
98101
)
99102
if cols_align is not None:
@@ -116,7 +119,9 @@ def download_tip(info: Dict) -> None:
116119
data_file (str): filepath.
117120
url (str): url for downloading.
118121
"""
119-
info["Tip"] = "If the download fails, \
122+
info[
123+
"Tip"
124+
] = "If the download fails, \
120125
use the 'Download URL' to download manually and move the file to the 'Save Path'."
121126

122127
tab_printer(info)

graph_datasets/utils/evaluation/eval_tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def kmeans_test(X, y, n_clusters, repeat=10):
151151
ari_list = []
152152
f1_list = []
153153
for _ in range(repeat):
154-
kmeans = KMeans(n_clusters=n_clusters)
154+
kmeans = KMeans(n_clusters=n_clusters, n_init='auto')
155155
y_pred = kmeans.fit_predict(X)
156156
(
157157
acc_score,

0 commit comments

Comments
 (0)