
Commit 6fabafc

committed
init
1 parent 159c7a3 commit 6fabafc

File tree

4 files changed: +573, -0 lines changed


cosine_similarity.py (+19 lines)
@@ -0,0 +1,19 @@
from sklearn.metrics.pairwise import cosine_similarity
import os
import json
import numpy as np


def cos_sim(vector_a, vector_b):
    """Cosine similarity between two 1-D vectors."""
    vector_a = np.asarray(vector_a, dtype=float).ravel()
    vector_b = np.asarray(vector_b, dtype=float).ravel()
    num = float(np.dot(vector_a, vector_b))
    denom = np.linalg.norm(vector_a) * np.linalg.norm(vector_b)
    sim = num / denom
    return sim


# with open('path/imagenet_glove.json', 'r') as f:
#     data = json.load(f)
#
# source_vectors = data['0']
# target_vectors = data['1']
# simple_sim = cos_sim(source_vectors, target_vectors)
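A minimal usage sketch of cos_sim (not part of the committed file; the two vectors below are made-up illustration values):

# Hypothetical inputs -- any two equal-length lists of floats work.
a = [0.1, 0.2, 0.3]
b = [0.3, 0.2, 0.1]
print(cos_sim(a, b))  # prints a value in [-1, 1]; 1.0 means identical direction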

embedding.py (+251 lines)
@@ -0,0 +1,251 @@
from sklearn.metrics.pairwise import cosine_similarity
import os
import json
import numpy as np
from cosine_similarity import *
from get_avelabel import *
import scipy.io as scio


def read_mat():
    dataFile = 'path/imagelabels.mat'
    data = scio.loadmat(dataFile)


def get_imagenet_labels():
    """Return list of ImageNet labels

    Returns:
        [list(str)] -- list of ImageNet labels
    """
    with open('imagenet_class_index.json', 'r') as f:
        class_idx = json.load(f)
    imagenet_labels = [class_idx[str(k)][1] for k in range(len(class_idx))]
    return imagenet_labels


def get_flower_labels():
    fname = 'path/flo_labels.txt'
    with open(fname, 'r+', encoding='utf-8') as f:
        s = [i[:-1].split(',') for i in f.readlines()]
    flo_labels = [s[k][1] for k in range(len(s))]
    return flo_labels


def get_inat_labels():
    # fname = 'path/inat_label.txt'
    # with open(fname, 'r+', encoding='utf-8') as f:
    #     s = [i.lower().replace('\n','') for i in f.readlines()]
    # return s

    # "8000": build the 8000-category name list and cache it to a text file.
    with open('path/categories.json', 'r') as f:
        data = json.load(f)
    name_list = [i['name'] for i in data]

    with open('path/inat_8000categories.txt', 'w') as filename:
        for i in name_list:
            filename.write(i)
            filename.write('\n')
    return name_list


def get_cal_labels():
    fname = 'path/cal_label.txt'
    with open(fname, 'r+', encoding='utf-8') as f:
        s = [i.replace('-', ' ').replace('\n', '') for i in f.readlines()]
    return s


def get_sun_labels():
    fname = 'path/ClassName.txt'
    with open(fname, 'r+', encoding='utf-8') as f:
        s = [i.replace('_', ' ').replace('\n', '').split('/')[-1] for i in f.readlines()]
    return s


def get_nih_labels():
    fname = 'path/nih.txt'
    with open(fname, 'r+', encoding='utf-8') as f:
        n = [i.replace('\n', '').split(',')[-1] for i in f.readlines()]
    with open(fname, 'r+', encoding='utf-8') as f:
        s = [i.split(',')[0] for i in f.readlines()]
    return s, n


def label_to_embedding(label, word2emb):
    """Look up the GloVe vector for a label (or the first word of a label list)."""
    # for idx, word in enumerate(label):
    #     if word not in word2emb:
    #         return None
    #     glove_v = word2emb[word]

    # try:
    #     if label not in word2emb:
    #         return None
    #     else:
    #         glove_v = word2emb[label]
    #     return glove_v
    # except:
    #     print('label corrupt', label)
    if isinstance(label, list):
        label_key = label[0]
    else:
        label_key = label
    if label_key not in word2emb:
        return None
    else:
        glove_v = word2emb[label_key]
        return glove_v


def imagenet_embedding(word2emb):
    source_vectors = {}
    source = get_imagenet_labels()
    target = get_imagenet_labels()
    for i, label in enumerate(source):
        imagenet_label = label.replace('_', ' ').split(' ')
        if len(imagenet_label) > 1:
            # Multi-word label: average the embeddings of the words found in word2emb.
            vector_average = 0
            for word in imagenet_label:
                vector_add = label_to_embedding(word, word2emb)
                if vector_add is not None:
                    vector_average = vector_average + vector_add
            # vector_average is still the int 0 when no word was found; skip those labels.
            if not isinstance(vector_average, int):
                vector_average = vector_average / len(imagenet_label)
                source_vectors[i] = np.array(vector_average).tolist()
        else:
            source_v = label_to_embedding(imagenet_label, word2emb)
            if source_v is not None:
                source_vectors[i] = np.array(source_v).tolist()
        print(i)

    with open("path/imagenet_glove.json", "w") as f:
        json.dump(source_vectors, f)
    print("loading finished")


def COVID_embedding(word2emb):
    p_emb = label_to_embedding('pneumonia', word2emb)
    ### 349
    n_emb_add = label_to_embedding('not', word2emb)
    n_emb = (p_emb + n_emb_add) / 2
    ### 398
    ## total 747

    return p_emb, n_emb


def phe_embedding(word2emb):
    p_emb = label_to_embedding('pneumonia', word2emb)
    ### 3875 + 8 + 390 = 4273
    n_emb_add = label_to_embedding('not', word2emb)
    n_emb = (p_emb + n_emb_add) / 2
    ### 1341 + 8 + 234 = 1583
    ## total 5856
    return p_emb, n_emb


def luna_embedding(word2emb):
    p_emb = (label_to_embedding('lung', word2emb) + label_to_embedding('cancer', word2emb)) / 2
    ### 785
    n_emb = (label_to_embedding('not', word2emb) + label_to_embedding('lung', word2emb) + label_to_embedding('cancer', word2emb)) / 3
    ### 70720
    ## total 71505

    return p_emb, n_emb


def embedding(word2emb):
    source_vectors = {}
    source = get_cal_labels()
    for i, label in enumerate(source):
        imagenet_label = label.replace('_', ' ').split(' ')
        if len(imagenet_label) > 1:
            vector_average = 0
            for word in imagenet_label:
                vector_add = label_to_embedding(word, word2emb)
                if vector_add is not None:
                    vector_average = vector_average + vector_add
            if not isinstance(vector_average, int):
                vector_average = vector_average / len(imagenet_label)
                source_vectors[i] = np.array(vector_average).tolist()
        else:
            source_v = label_to_embedding(imagenet_label, word2emb)
            if source_v is not None:
                source_vectors[i] = np.array(source_v).tolist()

    with open("flo_glove.json", "w") as f:
        json.dump(source_vectors, f)
    print("loading finished")


def inat_embedding(word2emb):
    source_vectors = {}
    source = get_inat_labels()
    for i, label in enumerate(source):
        imagenet_label = label.lower().split(' ')
        if len(imagenet_label) > 1:
            vector_average = 0
            for word in imagenet_label:
                vector_add = label_to_embedding(word, word2emb)
                if vector_add is not None:
                    vector_average = vector_average + vector_add
            if not isinstance(vector_average, int):
                vector_average = vector_average / len(imagenet_label)
                source_vectors[i] = np.array(vector_average).tolist()
        else:
            source_v = label_to_embedding(imagenet_label, word2emb)
            if source_v is not None:
                source_vectors[i] = np.array(source_v).tolist()
        print('a')
    with open("inat_glove8000.json", "w") as f:
        json.dump(source_vectors, f)
    print("loading finished")


def cal_embedding(word2emb):
    source_vectors = {}
    source = get_cal_labels()
    original_cal_label = {}
    for i, label in enumerate(source):
        imagenet_label = label.lower().split(' ')
        if len(imagenet_label) > 1:
            vector_average = 0
            for word in imagenet_label:
                vector_add = label_to_embedding(word, word2emb)
                if vector_add is not None:
                    vector_average = vector_average + vector_add
            if not isinstance(vector_average, int):
                vector_average = vector_average / len(imagenet_label)
                source_vectors[i] = np.array(vector_average).tolist()
                original_cal_label[i] = imagenet_label
        else:
            source_v = label_to_embedding(imagenet_label, word2emb)
            if source_v is not None:
                source_vectors[i] = np.array(source_v).tolist()
                original_cal_label[i] = imagenet_label
        print('a')
    with open("cal_glove.json", "w") as f:
        json.dump(source_vectors, f)
    print("loading finished")
    with open("cal_label_vertorized.json", "w") as f:
        json.dump(original_cal_label, f)
    print("loading finished")


def nih_embedding(word2emb):
    source_vectors = {}
    number_vectors = {}
    source, n = get_nih_labels()
    for i, label in enumerate(source):
        imagenet_label = label.lower().replace('-', ' ').split(' ')
        if len(imagenet_label) > 1:
            vector_average = 0
            for word in imagenet_label:
                vector_add = label_to_embedding(word, word2emb)
                if vector_add is not None:
                    vector_average = vector_average + vector_add
            if not isinstance(vector_average, int):
                vector_average = vector_average / len(imagenet_label)
                source_vectors[i] = np.array(vector_average).tolist()
        else:
            source_v = label_to_embedding(imagenet_label, word2emb)
            if source_v is not None:
                source_vectors[i] = np.array(source_v).tolist()
        number_vectors[i] = n[i]
        print('a')
    with open("nih_glove.json", "w") as f:
        json.dump(source_vectors, f)
    print("loading finished")
    return number_vectors
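Every function above takes a word2emb dictionary that this commit never constructs; a minimal sketch of loading one from a standard GloVe text file, assuming numpy is imported as np as in embedding.py (the file path and dimensionality are assumptions, not part of this commit):

def load_glove(path='path/glove.6B.300d.txt'):
    # Each GloVe line is "word v1 v2 ... vN"; build a {word: np.ndarray} map.
    word2emb = {}
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.rstrip().split(' ')
            word2emb[parts[0]] = np.array(parts[1:], dtype=np.float32)
    return word2emb

# word2emb = load_glove()
# imagenet_embedding(word2emb)  # writes path/imagenet_glove.json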

get_avelabel.py (+64 lines)
@@ -0,0 +1,64 @@
from sklearn.metrics.pairwise import cosine_similarity
import os
import json
import numpy as np
import scipy.io as scio  # needed by read_mat()


def read_mat():
    dataFile = 'path/imagelabels.mat'
    data = scio.loadmat(dataFile)


def get_imagenet_labels():
    """Return list of ImageNet labels

    Returns:
        [list(str)] -- list of ImageNet labels
    """
    with open('../imagenet_class_index.json', 'r') as f:
        class_idx = json.load(f)
    imagenet_labels = [class_idx[str(k)][1] for k in range(len(class_idx))]
    return imagenet_labels


def get_flower_labels():
    fname = 'path/flo_labels.txt'
    with open(fname, 'r+', encoding='utf-8') as f:
        s = [i[:-1].split(',') for i in f.readlines()]
    flo_labels = [s[k][1] for k in range(len(s))]
    return flo_labels


def get_inat_labels():
    # fname = 'path/inat_label.txt'
    # with open(fname, 'r+', encoding='utf-8') as f:
    #     s = [i.lower().replace('\n','') for i in f.readlines()]
    # return s

    # "8000": build the 8000-category name list and cache it to a text file.
    with open('path/categories.json', 'r') as f:
        data = json.load(f)
    name_list = [i['name'] for i in data]

    with open('path/inat_8000categories.txt', 'w') as filename:
        for i in name_list:
            filename.write(i)
            filename.write('\n')
    return name_list


def get_cal_labels():
    fname = 'path/cal_label.txt'
    with open(fname, 'r+', encoding='utf-8') as f:
        s = [i.replace('-', ' ').replace('\n', '') for i in f.readlines()]
    return s


def get_sun_labels():
    fname = 'path/ClassName.txt'
    with open(fname, 'r+', encoding='utf-8') as f:
        s = [i.replace('_', ' ').replace('\n', '').split('/')[-1] for i in f.readlines()]
    return s


def get_nih_labels():
    fname = 'path/nih.txt'
    with open(fname, 'r+', encoding='utf-8') as f:
        n = [i.replace('\n', '').split(',')[-1] for i in f.readlines()]
    with open(fname, 'r+', encoding='utf-8') as f:
        s = [i.split(',')[0] for i in f.readlines()]
    return s, n
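For reference, get_nih_labels implies that each line of nih.txt is comma-separated with the class name in the first field and a per-class count or identifier in the last field; a hypothetical line and the values it would produce (the example is illustrative, not taken from the dataset):

# Hypothetical nih.txt line:  Pneumonia,0,1431
# get_nih_labels() would then return s = ['Pneumonia', ...] and n = ['1431', ...]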
