
Commit 1266180

cleaned up code

1 parent 5aff8e6

File tree

7 files changed: +317 −520 lines changed


.gitignore

+1

@@ -108,3 +108,4 @@ venv.bak/
 data/
 logs/
 models/
+checkpoints/

balance.py (−47): This file was deleted.

code2vec (−1): This file was deleted.

download.sh (−45): This file was deleted.

flatten.py (−17): This file was deleted.

models.py

+74
@@ -0,0 +1,74 @@

import torch
import torch.nn as nn
import torch.nn.functional as F

class Code2Vec(nn.Module):
    def __init__(self, nodes_dim, paths_dim, embedding_dim, output_dim, dropout):
        super().__init__()

        self.node_embedding = nn.Embedding(nodes_dim, embedding_dim)
        self.path_embedding = nn.Embedding(paths_dim, embedding_dim)
        self.W = nn.Parameter(torch.randn(1, embedding_dim, 3*embedding_dim))
        self.a = nn.Parameter(torch.randn(1, embedding_dim, 1))
        self.out = nn.Linear(embedding_dim, output_dim)
        self.do = nn.Dropout(dropout)

    def forward(self, starts, paths, ends):

        #starts = paths = ends = [batch size, max length]

        W = self.W.repeat(starts.shape[0], 1, 1)

        #W = [batch size, embedding dim, embedding dim * 3]

        embedded_starts = self.node_embedding(starts)
        embedded_paths = self.path_embedding(paths)
        embedded_ends = self.node_embedding(ends)

        #embedded_* = [batch size, max length, embedding dim]

        c = self.do(torch.cat((embedded_starts, embedded_paths, embedded_ends), dim=2))

        #c = [batch size, max length, embedding dim * 3]

        c = c.permute(0, 2, 1)

        #c = [batch size, embedding dim * 3, max length]

        x = torch.tanh(torch.bmm(W, c))

        #x = [batch size, embedding dim, max length]

        x = x.permute(0, 2, 1)

        #x = [batch size, max length, embedding dim]

        a = self.a.repeat(starts.shape[0], 1, 1)

        #a = [batch size, embedding dim, 1]

        z = torch.bmm(x, a).squeeze(2)

        #z = [batch size, max length]

        z = F.softmax(z, dim=1)

        #z = [batch size, max length]

        z = z.unsqueeze(2)

        #z = [batch size, max length, 1]

        x = x.permute(0, 2, 1)

        #x = [batch size, embedding dim, max length]

        v = torch.bmm(x, z).squeeze(2)

        #v = [batch size, embedding dim]

        out = self.out(v)

        #out = [batch size, output dim]

        return out
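
For readers skimming the diff, the following is a minimal sketch of how the Code2Vec module added above could be exercised with dummy inputs. The vocabulary sizes, embedding dimension, output dimension, batch size, and sequence length are illustrative assumptions, not values taken from this repository.

import torch

from models import Code2Vec

# Hypothetical dimensions, for illustration only (not from this repo).
NODES_DIM = 1000      # size of the terminal-node vocabulary
PATHS_DIM = 5000      # size of the path vocabulary
EMBEDDING_DIM = 128
OUTPUT_DIM = 50       # number of target labels
DROPOUT = 0.25

model = Code2Vec(NODES_DIM, PATHS_DIM, EMBEDDING_DIM, OUTPUT_DIM, DROPOUT)

# A fake batch of 4 examples, each with 200 (start node, path, end node) contexts.
batch_size, max_length = 4, 200
starts = torch.randint(0, NODES_DIM, (batch_size, max_length))
paths = torch.randint(0, PATHS_DIM, (batch_size, max_length))
ends = torch.randint(0, NODES_DIM, (batch_size, max_length))

logits = model(starts, paths, ends)
print(logits.shape)  # torch.Size([4, 50]) -> one score per output label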
