core_motif.py: 56 changes (26 additions, 30 deletions)
@@ -267,7 +267,7 @@ def create_candidate_motifs(self):


     def gumbel_softmax(self, logits: torch.Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1, \
-                       g_ratio: float = 1e-3) -> torch.Tensor:
+                       g_ratio: float = 1.) -> torch.Tensor:
         gumbels = (
             -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
         ) # ~Gumbel(0,1)
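Note: raising the default `g_ratio` from `1e-3` to `1.` means the method now injects Gumbel noise at full strength unless a caller scales it down, which is what the standard Gumbel-Softmax estimator prescribes. As a reference, here is a minimal self-contained sketch of what this signature computes; the body is assumed to mirror `torch.nn.functional.gumbel_softmax` with `g_ratio` as an extra noise-scale knob, so treat it as an illustration rather than the repo's exact implementation:

```python
import torch
import torch.nn.functional as F

def gumbel_softmax(logits: torch.Tensor, tau: float = 1., hard: bool = False,
                   dim: int = -1, g_ratio: float = 1.) -> torch.Tensor:
    # Gumbel(0,1) noise: -log(E) with E ~ Exponential(1)
    gumbels = -torch.empty_like(logits).exponential_().log()
    # g_ratio scales the noise; 1. recovers the standard estimator,
    # while 0. degenerates to a deterministic tempered softmax
    y_soft = F.softmax((logits + g_ratio * gumbels) / tau, dim=dim)
    if hard:
        # straight-through: one-hot on the forward pass,
        # soft-sample gradients on the backward pass
        index = y_soft.argmax(dim, keepdim=True)
        y_hard = torch.zeros_like(y_soft).scatter_(dim, index, 1.0)
        return y_hard - y_soft.detach() + y_soft
    return y_soft
```

Crucially, this only samples from `Categorical(softmax(logits))` when `logits` are log-probabilities or unnormalized logits; the later hunks fix the call sites accordingly.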
@@ -328,18 +328,17 @@ def forward(self, graph_emb, node_emb, g, cands, deterministic=False):
         if g.batch_size != 1:
             ac_first_prob = [torch.softmax(logit, dim=0)
                             for i, logit in enumerate(torch.split(logits_first, att_len, dim=0))]
-            ac_first_prob = [p+1e-8 for p in ac_first_prob]
             log_ac_first_prob = [x.log() for x in ac_first_prob]
 
         else:
-            ac_first_prob = torch.softmax(logits_first, dim=0) + 1e-8
+            ac_first_prob = torch.softmax(logits_first, dim=0)
             log_ac_first_prob = ac_first_prob.log()
 
         if g.batch_size != 1:
             first_stack = []
             first_ac_stack = []
             for i, node_emb_i in enumerate(torch.split(att_emb, att_len, dim=0)):
-                ac_first_hot_i = self.gumbel_softmax(ac_first_prob[i], tau=self.tau, hard=True, dim=0).transpose(0,1)
+                ac_first_hot_i = self.gumbel_softmax(log_ac_first_prob[i], tau=self.tau, hard=True, dim=0).transpose(0,1)
                 ac_first_i = torch.argmax(ac_first_hot_i, dim=-1)
                 first_stack.append(torch.matmul(ac_first_hot_i, node_emb_i))
                 first_ac_stack.append(ac_first_i)
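The substantive fix in this hunk: `gumbel_softmax` is now fed `log_ac_first_prob` rather than `ac_first_prob`. The Gumbel-max trick says that `argmax(log p_i + g_i)` with i.i.d. `g_i ~ Gumbel(0,1)` is an exact sample from `Categorical(p)`; adding the noise to raw probabilities instead, which all sit in `[0, 1]` and are nearly indistinguishable next to noise with standard deviation of about 1.28, skews the sample toward uniform. Raw unnormalized logits work equally well, since softmax is shift-invariant. A standalone check (not from the PR) illustrating the difference:

```python
import torch

p = torch.tensor([0.7, 0.2, 0.1])
n = 100_000
g = -torch.empty(n, 3).exponential_().log()      # Gumbel(0,1) noise

correct = (p.log() + g).argmax(dim=-1)           # Gumbel-max on log-probs
buggy = (p + g).argmax(dim=-1)                   # noise added to raw probs

print(torch.bincount(correct, minlength=3) / n)  # ~[0.70, 0.20, 0.10]
print(torch.bincount(buggy, minlength=3) / n)    # noticeably flatter than p
```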
@@ -355,20 +354,20 @@ def forward(self, graph_emb, node_emb, g, cands, deterministic=False):
                 for i, ac_first_prob_i in enumerate(ac_first_prob)], dim=0).contiguous()
 
             log_ac_first_prob = torch.cat([
-                torch.cat([log_ac_first_prob_i, log_ac_first_prob_i.new_zeros(
-                    max(self.max_action - log_ac_first_prob_i.size(0),0),1)]
+                torch.cat([log_ac_first_prob_i, torch.full(
+                    (max(self.max_action - log_ac_first_prob_i.size(0),0),1), float("-inf"), device=self.device)]
                     , 0).contiguous().view(1,self.max_action)
                 for i, log_ac_first_prob_i in enumerate(log_ac_first_prob)], dim=0).contiguous()
 
         else:
-            ac_first_hot = self.gumbel_softmax(ac_first_prob, tau=self.tau, hard=True, dim=0).transpose(0,1)
+            ac_first_hot = self.gumbel_softmax(log_ac_first_prob, tau=self.tau, hard=True, dim=0).transpose(0,1)
             ac_first = torch.argmax(ac_first_hot, dim=-1)
             emb_first = torch.matmul(ac_first_hot, att_emb)
             ac_first_prob = torch.cat([ac_first_prob, ac_first_prob.new_zeros(
                 max(self.max_action - ac_first_prob.size(0),0),1)]
                 , 0).contiguous().view(1,self.max_action)
-            log_ac_first_prob = torch.cat([log_ac_first_prob, log_ac_first_prob.new_zeros(
-                max(self.max_action - log_ac_first_prob.size(0),0),1)]
+            log_ac_first_prob = torch.cat([log_ac_first_prob, torch.full(
+                (max(self.max_action - log_ac_first_prob.size(0),0),1), float("-inf"), device=self.device)]
                 , 0).contiguous().view(1,self.max_action)
 
         # ===============================
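The padding change in this hunk matters for correctness, not just style: the old code zero-padded `log_ac_first_prob` out to `self.max_action`, but a zero in log-space claims probability `exp(0) = 1` for padding slots that correspond to no real attachment point. Padding with `float("-inf")` encodes probability 0, so the padded vector still describes the same distribution. A tiny illustration:

```python
import torch

log_p = torch.tensor([0.5, 0.5]).log()
zero_pad = torch.cat([log_p, torch.zeros(2)])                  # old behaviour
inf_pad = torch.cat([log_p, torch.full((2,), float("-inf"))])  # new behaviour

print(zero_pad.exp())  # tensor([0.5, 0.5, 1., 1.]) -> padding "actions" look certain
print(inf_pad.exp())   # tensor([0.5, 0.5, 0., 0.]) -> still a valid distribution
```

Note that the probability-space padding just above (`ac_first_prob`) keeps `new_zeros`, which is already correct: zero is probability 0 in linear space.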
@@ -381,18 +380,16 @@ def forward(self, graph_emb, node_emb, g, cands, deterministic=False):
             self.action2_layers[1](cand_expand) + self.action2_layers[2](emb_first_expand)
 
         logit_second = self.action2_layers[3](emb_cat).squeeze(-1)
-        ac_second_prob = F.softmax(logit_second, dim=-1) + 1e-8
+        ac_second_prob = F.softmax(logit_second, dim=-1)
         log_ac_second_prob = ac_second_prob.log()
 
-        ac_second_hot = self.gumbel_softmax(ac_second_prob, tau=self.tau, hard=True, g_ratio=1e-3)
+        ac_second_hot = self.gumbel_softmax(log_ac_second_prob, tau=self.tau, hard=True)
        emb_second = torch.matmul(ac_second_hot, cand_graph_emb)
         ac_second = torch.argmax(ac_second_hot, dim=-1)
 
-        # Print gumbel otuput
-        ac_second_gumbel = self.gumbel_softmax(ac_second_prob, tau=self.tau, hard=False, g_ratio=1e-3)
-
 
         # ===============================
-        # step 4 : where to add on motif
+        # step 3 : where to add on motif
         # ===============================
@@ -418,20 +415,19 @@ def forward(self, graph_emb, node_emb, g, cands, deterministic=False):
         if g.batch_size != 1:
             ac_third_prob = [torch.softmax(logit,dim=-1)
                             for i, logit in enumerate(torch.split(logits_third.squeeze(-1), ac3_att_len, dim=0))]
-            ac_third_prob = [p+1e-8 for p in ac_third_prob]
             log_ac_third_prob = [x.log() for x in ac_third_prob]
 
         else:
             logits_third = logits_third.transpose(1,0)
-            ac_third_prob = torch.softmax(logits_third, dim=-1) + 1e-8
+            ac_third_prob = torch.softmax(logits_third, dim=-1)
             log_ac_third_prob = ac_third_prob.log()
 
         # gumbel softmax sampling and zero-padding
         if g.batch_size != 1:
             third_stack = []
             third_ac_stack = []
             for i, node_emb_i in enumerate(torch.split(emb_cat_ac3, ac3_att_len, dim=0)):
-                ac_third_hot_i = self.gumbel_softmax(ac_third_prob[i], tau=self.tau, hard=True, dim=-1)
+                ac_third_hot_i = self.gumbel_softmax(log_ac_third_prob[i], tau=self.tau, hard=True, dim=-1)
                 ac_third_i = torch.argmax(ac_third_hot_i, dim=-1)
                 third_stack.append(torch.matmul(ac_third_hot_i, node_emb_i))
                 third_ac_stack.append(ac_third_i)
@@ -446,21 +442,21 @@ def forward(self, graph_emb, node_emb, g, cands, deterministic=False):
                 for i, ac_third_prob_i in enumerate(ac_third_prob)], dim=0).contiguous()
 
             log_ac_third_prob = torch.cat([
-                torch.cat([log_ac_third_prob_i, log_ac_third_prob_i.new_zeros(
-                    self.max_action - log_ac_third_prob_i.size(0))]
+                torch.cat([log_ac_third_prob_i, torch.full(
+                    (self.max_action - log_ac_third_prob_i.size(0), ), float("-inf"), device=self.device)]
                     , 0).contiguous().view(1,self.max_action)
                 for i, log_ac_third_prob_i in enumerate(log_ac_third_prob)], dim=0).contiguous()
 
         else:
-            ac_third_hot = self.gumbel_softmax(ac_third_prob, tau=self.tau, hard=True, dim=-1)
+            ac_third_hot = self.gumbel_softmax(log_ac_third_prob, tau=self.tau, hard=True, dim=-1)
             ac_third = torch.argmax(ac_third_hot, dim=-1)
             emb_third = torch.matmul(ac_third_hot, emb_cat_ac3)
 
             ac_third_prob = torch.cat([ac_third_prob, ac_third_prob.new_zeros(
                 1, self.max_action - ac_third_prob.size(1))]
                 , -1).contiguous()
-            log_ac_third_prob = torch.cat([log_ac_third_prob, log_ac_third_prob.new_zeros(
-                1, self.max_action - log_ac_third_prob.size(1))]
+            log_ac_third_prob = torch.cat([log_ac_third_prob, torch.full(
+                (1, self.max_action - log_ac_third_prob.size(1)), float("-inf"), device=self.device)]
                 , -1).contiguous()
 
         # ==== concat everything ====
@@ -470,7 +466,7 @@ def forward(self, graph_emb, node_emb, g, cands, deterministic=False):
                                 log_ac_second_prob, log_ac_third_prob], dim=1).contiguous()
         ac = torch.stack([ac_first, ac_second, ac_third], dim=1)
 
-        return ac, (ac_prob, log_ac_prob), (ac_first_prob, ac_second_hot, ac_third_prob)
+        return ac, (ac_prob, (log_ac_prob.exp() + 1e-8).log()), (ac_first_prob, ac_second_hot, ac_third_prob)
 
     def sample(self, ac, graph_emb, node_emb, g, cands):
         g.ndata['node_emb'] = node_emb
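With `-inf` now present in the padded log-probabilities, returning `log_ac_prob` raw would hand downstream consumers a tensor that produces NaNs in expressions like `prob * log_prob` (the usual entropy term in an SAC-style loss), because `0 * -inf` is NaN in IEEE arithmetic. Round-tripping through `(log_ac_prob.exp() + 1e-8).log()` clamps every `-inf` to `log(1e-8)`, about `-18.4`, right at the API boundary while leaving real probabilities essentially untouched. A standalone illustration (not from the PR):

```python
import torch

log_p = torch.tensor([float("-inf"), -0.6931, -0.6931])  # one padded slot, two real actions
p = log_p.exp()                                          # tensor([0.0, 0.5, 0.5])

print((p * log_p).sum())                  # nan, from 0 * -inf
safe_log_p = (log_p.exp() + 1e-8).log()   # -inf -> log(1e-8), about -18.42
print((p * safe_log_p).sum())             # about -0.693, finite
```

The `sample` method below applies the same clamp at its return.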
@@ -504,8 +500,8 @@ def sample(self, ac, graph_emb, node_emb, g, cands):
             max(self.max_action - ac_first_prob.size(1),0))]
             , 1).contiguous()
 
-        log_ac_first_prob = torch.cat([log_ac_first_prob, log_ac_first_prob.new_zeros(1,
-            max(self.max_action - log_ac_first_prob.size(1),0))]
+        log_ac_first_prob = torch.cat([log_ac_first_prob, torch.full((1,
+            max(self.max_action - log_ac_first_prob.size(1),0)), float("-inf"), device=self.device)]
             , 1).contiguous()
         emb_first = att_emb[ac[0]].unsqueeze(0)
 
@@ -522,7 +518,7 @@ def sample(self, ac, graph_emb, node_emb, g, cands):
         ac_second_prob = F.softmax(logit_second, dim=-1) + 1e-8
         log_ac_second_prob = ac_second_prob.log()
 
-        ac_second_hot = self.gumbel_softmax(ac_second_prob, tau=self.tau, hard=True, g_ratio=1e-3)
+        ac_second_hot = self.gumbel_softmax(logit_second, tau=self.tau, hard=True)
         emb_second = torch.matmul(ac_second_hot, cand_graph_emb)
         ac_second = torch.argmax(ac_second_hot, dim=-1)
 
@@ -559,16 +555,16 @@ def sample(self, ac, graph_emb, node_emb, g, cands):
         ac_third_prob = torch.cat([ac_third_prob, ac_third_prob.new_zeros(
             1, self.max_action - ac_third_prob.size(1))]
             , -1).contiguous()
-        log_ac_third_prob = torch.cat([log_ac_third_prob, log_ac_third_prob.new_zeros(
-            1, self.max_action - log_ac_third_prob.size(1))]
+        log_ac_third_prob = torch.cat([log_ac_third_prob, torch.full(
+            (1, self.max_action - log_ac_third_prob.size(1)), float("-inf"), device=self.device)]
             , -1).contiguous()
 
         # ==== concat everything ====
         ac_prob = torch.cat([ac_first_prob, ac_second_prob, ac_third_prob], dim=1).contiguous()
         log_ac_prob = torch.cat([log_ac_first_prob,
                                  log_ac_second_prob, log_ac_third_prob], dim=1).contiguous()
 
-        return (ac_prob, log_ac_prob), (ac_first_prob, ac_second_hot, ac_third_prob)
+        return (ac_prob, (log_ac_prob.exp() + 1e-8).log()), (ac_first_prob, ac_second_hot, ac_third_prob)
 
 
 class GCNEmbed(nn.Module):