import torch
import torch.nn as nn
import torch.nn.functional as F

class Code2Vec(nn.Module):
    def __init__(self, nodes_dim, paths_dim, embedding_dim, output_dim, dropout):
        super().__init__()

        self.node_embedding = nn.Embedding(nodes_dim, embedding_dim)
        self.path_embedding = nn.Embedding(paths_dim, embedding_dim)
        self.W = nn.Parameter(torch.randn(1, embedding_dim, 3 * embedding_dim))
        self.a = nn.Parameter(torch.randn(1, embedding_dim, 1))
        self.out = nn.Linear(embedding_dim, output_dim)
        self.do = nn.Dropout(dropout)

    def forward(self, starts, paths, ends):

        #starts = paths = ends = [batch size, max length]

        W = self.W.repeat(starts.shape[0], 1, 1)

        #W = [batch size, embedding dim, embedding dim * 3]

        embedded_starts = self.node_embedding(starts)
        embedded_paths = self.path_embedding(paths)
        embedded_ends = self.node_embedding(ends)

        #embedded_* = [batch size, max length, embedding dim]

        c = self.do(torch.cat((embedded_starts, embedded_paths, embedded_ends), dim=2))

        #c = [batch size, max length, embedding dim * 3]

        c = c.permute(0, 2, 1)

        #c = [batch size, embedding dim * 3, max length]

        x = torch.tanh(torch.bmm(W, c))

        #x = [batch size, embedding dim, max length]

        x = x.permute(0, 2, 1)

        #x = [batch size, max length, embedding dim]

        a = self.a.repeat(starts.shape[0], 1, 1)

        #a = [batch size, embedding dim, 1]

        z = torch.bmm(x, a).squeeze(2)

        #z = [batch size, max length]

        z = F.softmax(z, dim=1)

        #z = [batch size, max length]

        z = z.unsqueeze(2)

        #z = [batch size, max length, 1]

        x = x.permute(0, 2, 1)

        #x = [batch size, embedding dim, max length]

        v = torch.bmm(x, z).squeeze(2)

        #v = [batch size, embedding dim]

        out = self.out(v)

        #out = [batch size, output dim]

        return out
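
# Minimal usage sketch (an assumption, not part of the original commit): the
# vocabulary sizes, hyperparameters, batch size, and max length below are
# illustrative only. Each example is a bag of path contexts, i.e. padded
# sequences of (start node, path, end node) index triples.
if __name__ == "__main__":
    model = Code2Vec(nodes_dim=10_000, paths_dim=20_000, embedding_dim=128,
                     output_dim=250, dropout=0.25)
    starts = torch.randint(0, 10_000, (32, 200))  # [batch size, max length]
    paths = torch.randint(0, 20_000, (32, 200))   # [batch size, max length]
    ends = torch.randint(0, 10_000, (32, 200))    # [batch size, max length]
    logits = model(starts, paths, ends)           # [batch size, output dim]
    print(logits.shape)                           # torch.Size([32, 250])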