@@ -312,9 +312,7 @@ class GensimWordEmbedding(AbstractWordEmbedding):
312312 def __init__ (self , keyed_vectors ):
313313 gensim = utils .LazyLoader ("gensim" , globals (), "gensim" )
314314
315- if isinstance (
316- keyed_vectors , gensim .models .keyedvectors .WordEmbeddingsKeyedVectors
317- ):
315+ if isinstance (keyed_vectors , gensim .models .KeyedVectors ):
318316 self .keyed_vectors = keyed_vectors
319317 else :
320318 raise ValueError (
def __getitem__(self, index):
    """Get the embedding vector for a word or word id.

    Args:
        index (Union[str, int]): either a word or the integer id of a word
            in the embedding matrix.
    Returns:
        1-D embedding vector (ndarray), or ``None`` if no vector can be
        found for `index`.
    """
    if isinstance(index, str):
        # gensim 4.x: `key_to_index` is a plain dict, so `.get` returns None
        # for out-of-vocabulary words instead of raising KeyError. The old
        # `except KeyError` was dead code here, and letting a None index flow
        # into the array lookup below would silently return the whole matrix
        # (numpy treats `arr[None]` as adding an axis) — check explicitly.
        index = self.keyed_vectors.key_to_index.get(index)
        if index is None:
            return None
    try:
        return self.keyed_vectors.get_normed_vectors()[index]
    except IndexError:
        # word embedding ID out of bounds
        return None
def word2index(self, word):
    """Convert a word to its id (i.e. the index of the word in the
    embedding matrix).

    Args:
        word (str)
    Returns:
        index (int)
    Raises:
        KeyError: if `word` is not in the vocabulary.
    """
    # gensim 4.x: `key_to_index` is a plain dict mapping word -> int id, so a
    # direct lookup already raises KeyError(word) for out-of-vocabulary words.
    # This removes the get-then-check dance, whose local name `vocab` was also
    # a misnomer — the stored value is an index, not a vocab entry.
    return self.keyed_vectors.key_to_index[word]
360358 def index2word (self , index ):
361359 """
@@ -368,7 +366,7 @@ def index2word(self, index):
368366 """
369367 try :
370368 # this is a list, so the error would be IndexError
371- return self .keyed_vectors .index2word [index ]
369+ return self .keyed_vectors .index_to_key [index ]
372370 except IndexError :
373371 raise KeyError (index )
374372
@@ -386,8 +384,8 @@ def get_mse_dist(self, a, b):
386384 try :
387385 mse_dist = self ._mse_dist_mat [a ][b ]
388386 except KeyError :
389- e1 = self .keyed_vectors .vectors_norm [a ]
390- e2 = self .keyed_vectors .vectors_norm [b ]
387+ e1 = self .keyed_vectors .get_normed_vectors () [a ]
388+ e2 = self .keyed_vectors .get_normed_vectors () [b ]
391389 e1 = torch .tensor (e1 ).to (utils .device )
392390 e2 = torch .tensor (e2 ).to (utils .device )
393391 mse_dist = torch .sum ((e1 - e2 ) ** 2 ).item ()
def get_cos_sim(self, a, b):
    """Return the cosine similarity between two words.

    Args:
        a: first word, given as a string or as a vocabulary id (int).
        b: second word, given as a string or as a vocabulary id (int).
    Returns:
        distance (float): cosine similarity
    """

    def _as_word(token):
        # Integer ids are mapped back to their word via gensim 4.x's
        # `index_to_key` list; strings pass through unchanged.
        if isinstance(token, str):
            return token
        return self.keyed_vectors.index_to_key[token]

    return self.keyed_vectors.similarity(_as_word(a), _as_word(b))
414412
@@ -421,7 +419,7 @@ def nearest_neighbours(self, index, topn, return_words=True):
421419 Returns:
422420 neighbours (list[int]): List of indices of the nearest neighbours
423421 """
424- word = self .keyed_vectors .index2word [index ]
422+ word = self .keyed_vectors .index_to_key [index ]
425423 return [
426424 self .word2index (i [0 ])
427425 for i in self .keyed_vectors .similar_by_word (word , topn )