From ee3ca19868f4b05a6c2458deb6d2f51f15bbbe20 Mon Sep 17 00:00:00 2001
From: alexander-telepov
Date: Mon, 18 Apr 2022 15:36:19 +0300
Subject: [PATCH 1/2] fix gumbel-softmax sampling

---
 core_motif.py        | 32 ++++++++++++++------------------
 core_motif_vbased.py | 39 ++++++++++++++++++---------------------
 2 files changed, 32 insertions(+), 39 deletions(-)

diff --git a/core_motif.py b/core_motif.py
index 81d707f..84baab1 100644
--- a/core_motif.py
+++ b/core_motif.py
@@ -267,7 +267,7 @@ def create_candidate_motifs(self):
 
     def gumbel_softmax(self, logits: torch.Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1, \
-                    g_ratio: float = 1e-3) -> torch.Tensor:
+                    g_ratio: float = 1.) -> torch.Tensor:
         gumbels = (
             -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
         )  # ~Gumbel(0,1)
@@ -328,18 +328,17 @@ def forward(self, graph_emb, node_emb, g, cands, deterministic=False):
 
         if g.batch_size != 1:
             ac_first_prob = [torch.softmax(logit, dim=0) for i, logit in enumerate(torch.split(logits_first, att_len, dim=0))]
-            ac_first_prob = [p+1e-8 for p in ac_first_prob]
             log_ac_first_prob = [x.log() for x in ac_first_prob]
 
         else:
-            ac_first_prob = torch.softmax(logits_first, dim=0) + 1e-8
+            ac_first_prob = torch.softmax(logits_first, dim=0)
             log_ac_first_prob = ac_first_prob.log()
 
         if g.batch_size != 1:
             first_stack = []
             first_ac_stack = []
             for i, node_emb_i in enumerate(torch.split(att_emb, att_len, dim=0)):
-                ac_first_hot_i = self.gumbel_softmax(ac_first_prob[i], tau=self.tau, hard=True, dim=0).transpose(0,1)
+                ac_first_hot_i = self.gumbel_softmax(log_ac_first_prob[i], tau=self.tau, hard=True, dim=0).transpose(0,1)
                 ac_first_i = torch.argmax(ac_first_hot_i, dim=-1)
                 first_stack.append(torch.matmul(ac_first_hot_i, node_emb_i))
                 first_ac_stack.append(ac_first_i)
@@ -361,7 +360,7 @@ def forward(self, graph_emb, node_emb, g, cands, deterministic=False):
                                 for i, log_ac_first_prob_i in enumerate(log_ac_first_prob)], dim=0).contiguous()
 
         else:
-            ac_first_hot = self.gumbel_softmax(ac_first_prob, tau=self.tau, hard=True, dim=0).transpose(0,1)
+            ac_first_hot = self.gumbel_softmax(log_ac_first_prob, tau=self.tau, hard=True, dim=0).transpose(0,1)
             ac_first = torch.argmax(ac_first_hot, dim=-1)
             emb_first = torch.matmul(ac_first_hot, att_emb)
             ac_first_prob = torch.cat([ac_first_prob, ac_first_prob.new_zeros(
@@ -381,18 +380,16 @@ def forward(self, graph_emb, node_emb, g, cands, deterministic=False):
                     self.action2_layers[1](cand_expand) + self.action2_layers[2](emb_first_expand)
 
         logit_second = self.action2_layers[3](emb_cat).squeeze(-1)
-        ac_second_prob = F.softmax(logit_second, dim=-1) + 1e-8
+        ac_second_prob = F.softmax(logit_second, dim=-1)
         log_ac_second_prob = ac_second_prob.log()
 
-        ac_second_hot = self.gumbel_softmax(ac_second_prob, tau=self.tau, hard=True, g_ratio=1e-3)
+        ac_second_hot = self.gumbel_softmax(log_ac_second_prob, tau=self.tau, hard=True)
         emb_second = torch.matmul(ac_second_hot, cand_graph_emb)
         ac_second = torch.argmax(ac_second_hot, dim=-1)
-
-        # Print gumbel otuput
-        ac_second_gumbel = self.gumbel_softmax(ac_second_prob, tau=self.tau, hard=False, g_ratio=1e-3)
+
         # ===============================
-        # step 4 : where to add on motif
+        # step 3 : where to add on motif
         # ===============================
         # Select att points from candidate
         cand_att_emb = torch.masked_select(cand_node_emb, cand_att_mask.unsqueeze(-1))
@@ -418,12 +415,11 @@ def forward(self, graph_emb, node_emb, g, cands, deterministic=False):
 
         if g.batch_size != 1:
            ac_third_prob = [torch.softmax(logit,dim=-1) for i, logit in enumerate(torch.split(logits_third.squeeze(-1), ac3_att_len, dim=0))]
-            ac_third_prob = [p+1e-8 for p in ac_third_prob]
             log_ac_third_prob = [x.log() for x in ac_third_prob]
 
         else:
             logits_third = logits_third.transpose(1,0)
-            ac_third_prob = torch.softmax(logits_third, dim=-1) + 1e-8
+            ac_third_prob = torch.softmax(logits_third, dim=-1)
             log_ac_third_prob = ac_third_prob.log()
 
         # gumbel softmax sampling and zero-padding
@@ -431,7 +427,7 @@ def forward(self, graph_emb, node_emb, g, cands, deterministic=False):
             third_stack = []
             third_ac_stack = []
             for i, node_emb_i in enumerate(torch.split(emb_cat_ac3, ac3_att_len, dim=0)):
-                ac_third_hot_i = self.gumbel_softmax(ac_third_prob[i], tau=self.tau, hard=True, dim=-1)
+                ac_third_hot_i = self.gumbel_softmax(log_ac_third_prob[i], tau=self.tau, hard=True, dim=-1)
                 ac_third_i = torch.argmax(ac_third_hot_i, dim=-1)
                 third_stack.append(torch.matmul(ac_third_hot_i, node_emb_i))
                 third_ac_stack.append(ac_third_i)
@@ -452,7 +448,7 @@ def forward(self, graph_emb, node_emb, g, cands, deterministic=False):
                                 for i, log_ac_third_prob_i in enumerate(log_ac_third_prob)], dim=0).contiguous()
 
         else:
-            ac_third_hot = self.gumbel_softmax(ac_third_prob, tau=self.tau, hard=True, dim=-1)
+            ac_third_hot = self.gumbel_softmax(log_ac_third_prob, tau=self.tau, hard=True, dim=-1)
             ac_third = torch.argmax(ac_third_hot, dim=-1)
             emb_third = torch.matmul(ac_third_hot, emb_cat_ac3)
@@ -470,7 +466,7 @@ def forward(self, graph_emb, node_emb, g, cands, deterministic=False):
                                 log_ac_second_prob, log_ac_third_prob], dim=1).contiguous()
 
         ac = torch.stack([ac_first, ac_second, ac_third], dim=1)
-        return ac, (ac_prob, log_ac_prob), (ac_first_prob, ac_second_hot, ac_third_prob)
+        return ac, (ac_prob, (log_ac_prob.exp() + 1e-8).log()), (ac_first_prob, ac_second_hot, ac_third_prob)
 
     def sample(self, ac, graph_emb, node_emb, g, cands):
         g.ndata['node_emb'] = node_emb
@@ -522,7 +518,7 @@ def sample(self, ac, graph_emb, node_emb, g, cands):
         ac_second_prob = F.softmax(logit_second, dim=-1) + 1e-8
         log_ac_second_prob = ac_second_prob.log()
 
-        ac_second_hot = self.gumbel_softmax(ac_second_prob, tau=self.tau, hard=True, g_ratio=1e-3)
+        ac_second_hot = self.gumbel_softmax(logit_second, tau=self.tau, hard=True)
         emb_second = torch.matmul(ac_second_hot, cand_graph_emb)
         ac_second = torch.argmax(ac_second_hot, dim=-1)
@@ -568,7 +564,7 @@ def sample(self, ac, graph_emb, node_emb, g, cands):
         log_ac_prob = torch.cat([log_ac_first_prob,
                                 log_ac_second_prob, log_ac_third_prob], dim=1).contiguous()
 
-        return (ac_prob, log_ac_prob), (ac_first_prob, ac_second_hot, ac_third_prob)
+        return (ac_prob, (log_ac_prob.exp() + 1e-8).log()), (ac_first_prob, ac_second_hot, ac_third_prob)
 
 class GCNEmbed(nn.Module):

diff --git a/core_motif_vbased.py b/core_motif_vbased.py
index 56b6223..3b84434 100644
--- a/core_motif_vbased.py
+++ b/core_motif_vbased.py
@@ -225,7 +225,7 @@ def create_candidate_motifs(self):
 
     def gumbel_softmax(self, logits: torch.Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1, \
-                    g_ratio: float = 1e-3) -> torch.Tensor:
+                    g_ratio: float = 1.) -> torch.Tensor:
         gumbels = (
             -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
         )  # ~Gumbel(0,1)
@@ -286,18 +286,17 @@ def forward(self, graph_emb, node_emb, g, cands, deterministic=False):
 
         if g.batch_size != 1:
             ac_first_prob = [torch.softmax(logit, dim=0) for i, logit in enumerate(torch.split(logits_first, att_len, dim=0))]
-            ac_first_prob = [p+1e-8 for p in ac_first_prob]
             log_ac_first_prob = [x.log() for x in ac_first_prob]
 
         else:
-            ac_first_prob = torch.softmax(logits_first, dim=0) + 1e-8
+            ac_first_prob = torch.softmax(logits_first, dim=0)
             log_ac_first_prob = ac_first_prob.log()
 
         if g.batch_size != 1:
             first_stack = []
             first_ac_stack = []
             for i, node_emb_i in enumerate(torch.split(att_emb, att_len, dim=0)):
-                ac_first_hot_i = self.gumbel_softmax(ac_first_prob[i], tau=self.tau, hard=True, dim=0).transpose(0,1)
+                ac_first_hot_i = self.gumbel_softmax(log_ac_first_prob[i], tau=self.tau, hard=True, dim=0).transpose(0,1)
                 ac_first_i = torch.argmax(ac_first_hot_i, dim=-1)
                 first_stack.append(torch.matmul(ac_first_hot_i, node_emb_i))
                 first_ac_stack.append(ac_first_i)
@@ -318,7 +317,7 @@ def forward(self, graph_emb, node_emb, g, cands, deterministic=False):
                                 for i, log_ac_first_prob_i in enumerate(log_ac_first_prob)], dim=0).contiguous()
 
         else:
-            ac_first_hot = self.gumbel_softmax(ac_first_prob, tau=self.tau, hard=True, dim=0).transpose(0,1)
+            ac_first_hot = self.gumbel_softmax(log_ac_first_prob, tau=self.tau, hard=True, dim=0).transpose(0,1)
             ac_first = torch.argmax(ac_first_hot, dim=-1)
             emb_first = torch.matmul(ac_first_hot, att_emb)
             ac_first_prob = torch.cat([ac_first_prob, ac_first_prob.new_zeros(
@@ -328,6 +327,7 @@ def forward(self, graph_emb, node_emb, g, cands, deterministic=False):
                 max(self.max_action - log_ac_first_prob.size(0),0),1)]
                 , 0).contiguous().view(1,self.max_action)
 
+
         # ===============================
         # step 2 : which motif to add - Using Descriptors
         # ===============================
@@ -339,18 +339,16 @@ def forward(self, graph_emb, node_emb, g, cands, deterministic=False):
                     self.action2_layers[1](cand_expand) + self.action2_layers[2](emb_first_expand)
 
         logit_second = self.action2_layers[3](emb_cat).squeeze(-1)
-        ac_second_prob = F.softmax(logit_second, dim=-1) + 1e-8
+        ac_second_prob = F.softmax(logit_second, dim=-1)
         log_ac_second_prob = ac_second_prob.log()
 
-        ac_second_hot = self.gumbel_softmax(ac_second_prob, tau=self.tau, hard=True, g_ratio=1e-3)
+        ac_second_hot = self.gumbel_softmax(log_ac_second_prob, tau=self.tau, hard=True)
         emb_second = torch.matmul(ac_second_hot, cand_graph_emb)
         ac_second = torch.argmax(ac_second_hot, dim=-1)
-
-        # Print gumbel otuput
-        ac_second_gumbel = self.gumbel_softmax(ac_second_prob, tau=self.tau, hard=False, g_ratio=1e-3)
+
         # ===============================
-        # step 4 : where to add on motif
+        # step 3 : where to add on motif
         # ===============================
         # Select att points from candidate
         cand_att_emb = torch.masked_select(cand_node_emb, cand_att_mask.unsqueeze(-1))
@@ -379,12 +377,11 @@ def forward(self, graph_emb, node_emb, g, cands, deterministic=False):
 
         if g.batch_size != 1:
             ac_third_prob = [torch.softmax(logit,dim=-1) for i, logit in enumerate(torch.split(logits_third.squeeze(-1), ac3_att_len, dim=0))]
-            ac_third_prob = [p+1e-8 for p in ac_third_prob]
             log_ac_third_prob = [x.log() for x in ac_third_prob]
 
         else:
             logits_third = logits_third.transpose(1,0)
-            ac_third_prob = torch.softmax(logits_third, dim=-1) + 1e-8
+            ac_third_prob = torch.softmax(logits_third, dim=-1)
             log_ac_third_prob = ac_third_prob.log()
 
         # gumbel softmax sampling and zero-padding
@@ -392,7 +389,7 @@ def forward(self, graph_emb, node_emb, g, cands, deterministic=False):
             third_stack = []
             third_ac_stack = []
             for i, node_emb_i in enumerate(torch.split(emb_cat_ac3, ac3_att_len, dim=0)):
-                ac_third_hot_i = self.gumbel_softmax(ac_third_prob[i], tau=self.tau, hard=True, dim=-1)
+                ac_third_hot_i = self.gumbel_softmax(log_ac_third_prob[i], tau=self.tau, hard=True, dim=-1)
                 ac_third_i = torch.argmax(ac_third_hot_i, dim=-1)
                 third_stack.append(torch.matmul(ac_third_hot_i, node_emb_i))
                 third_ac_stack.append(ac_third_i)
@@ -413,7 +410,7 @@ def forward(self, graph_emb, node_emb, g, cands, deterministic=False):
                                 for i, log_ac_third_prob_i in enumerate(log_ac_third_prob)], dim=0).contiguous()
 
         else:
-            ac_third_hot = self.gumbel_softmax(ac_third_prob, tau=self.tau, hard=True, dim=-1)
+            ac_third_hot = self.gumbel_softmax(log_ac_third_prob, tau=self.tau, hard=True, dim=-1)
             ac_third = torch.argmax(ac_third_hot, dim=-1)
             emb_third = torch.matmul(ac_third_hot, emb_cat_ac3)
@@ -431,7 +428,7 @@ def forward(self, graph_emb, node_emb, g, cands, deterministic=False):
                                 log_ac_second_prob, log_ac_third_prob], dim=1).contiguous()
 
         ac = torch.stack([ac_first, ac_second, ac_third], dim=1)
-        return ac, ac_prob, log_ac_prob
+        return ac, ac_prob, (log_ac_prob.exp() + 1e-8).log()
 
     def _distribution(self, ac_prob):
@@ -469,7 +466,7 @@ def sample(self, ac, graph_emb, node_emb, g, cands):
                     + self.action1_layers[2](graph_expand)
         logits_first = self.action1_layers[3](att_emb).transpose(1,0)
 
-        ac_first_prob = torch.softmax(logits_first, dim=-1) + 1e-8
+        ac_first_prob = torch.softmax(logits_first, dim=-1)
         log_ac_first_prob = ac_first_prob.log()
 
         ac_first_prob = torch.cat([ac_first_prob, ac_first_prob.new_zeros(1,
@@ -491,10 +488,10 @@ def sample(self, ac, graph_emb, node_emb, g, cands):
                     self.action2_layers[1](cand_expand) + self.action2_layers[2](emb_first_expand)
 
         logit_second = self.action2_layers[3](emb_cat).squeeze(-1)
-        ac_second_prob = F.softmax(logit_second, dim=-1) + 1e-8
+        ac_second_prob = F.softmax(logit_second, dim=-1)
         log_ac_second_prob = ac_second_prob.log()
 
-        ac_second_hot = self.gumbel_softmax(ac_second_prob, tau=self.tau, hard=True, g_ratio=1e-3)
+        ac_second_hot = self.gumbel_softmax(ac_second_prob, tau=self.tau, hard=True)
         emb_second = torch.matmul(ac_second_hot, cand_graph_emb)
         ac_second = torch.argmax(ac_second_hot, dim=-1)
@@ -524,7 +521,7 @@ def sample(self, ac, graph_emb, node_emb, g, cands):
         logits_third = self.action3_layers[3](emb_cat_ac3)
         logits_third = logits_third.transpose(1,0)
 
-        ac_third_prob = torch.softmax(logits_third, dim=-1) + 1e-8
+        ac_third_prob = torch.softmax(logits_third, dim=-1)
         log_ac_third_prob = ac_third_prob.log()
 
         # gumbel softmax sampling and zero-padding
@@ -541,7 +538,7 @@ def sample(self, ac, graph_emb, node_emb, g, cands):
         log_ac_prob = torch.cat([log_ac_first_prob,
                                 log_ac_second_prob, log_ac_third_prob], dim=1).contiguous()
 
-        return ac_prob, log_ac_prob
+        return ac_prob, (log_ac_prob.exp() + 1e-8).log()
 
 class GCNEmbed(nn.Module):

From 80f1edb5e793dcb6a9c640e2963a9cf0f5a7a01a Mon Sep 17 00:00:00 2001
From: alexander-telepov
Date: Mon, 18 Apr 2022 15:40:57 +0300
Subject: [PATCH 2/2] fix log_prob padding

---
 core_motif.py        | 24 ++++++++++++------------
 core_motif_vbased.py | 25 ++++++++++++-------------
 2 files changed, 24 insertions(+), 25 deletions(-)

diff --git a/core_motif.py b/core_motif.py
index 84baab1..5f826d4 100644
--- a/core_motif.py
+++ b/core_motif.py
@@ -354,8 +354,8 @@ def forward(self, graph_emb, node_emb, g, cands, deterministic=False):
                                 for i, ac_first_prob_i in enumerate(ac_first_prob)], dim=0).contiguous()
 
             log_ac_first_prob = torch.cat([
-                torch.cat([log_ac_first_prob_i, log_ac_first_prob_i.new_zeros(
-                max(self.max_action - log_ac_first_prob_i.size(0),0),1)]
+                torch.cat([log_ac_first_prob_i, torch.full(
+                (max(self.max_action - log_ac_first_prob_i.size(0),0),1), float("-inf"), device=self.device)]
                 , 0).contiguous().view(1,self.max_action)
                                 for i, log_ac_first_prob_i in enumerate(log_ac_first_prob)], dim=0).contiguous()
@@ -366,8 +366,8 @@ def forward(self, graph_emb, node_emb, g, cands, deterministic=False):
             ac_first_prob = torch.cat([ac_first_prob, ac_first_prob.new_zeros(
                 max(self.max_action - ac_first_prob.size(0),0),1)]
                 , 0).contiguous().view(1,self.max_action)
-            log_ac_first_prob = torch.cat([log_ac_first_prob, log_ac_first_prob.new_zeros(
-                max(self.max_action - log_ac_first_prob.size(0),0),1)]
+            log_ac_first_prob = torch.cat([log_ac_first_prob, torch.full(
+                (max(self.max_action - log_ac_first_prob.size(0),0),1), float("-inf"), device=self.device)]
                 , 0).contiguous().view(1,self.max_action)
 
         # ===============================
@@ -442,8 +442,8 @@ def forward(self, graph_emb, node_emb, g, cands, deterministic=False):
                                 for i, ac_third_prob_i in enumerate(ac_third_prob)], dim=0).contiguous()
 
             log_ac_third_prob = torch.cat([
-                torch.cat([log_ac_third_prob_i, log_ac_third_prob_i.new_zeros(
-                self.max_action - log_ac_third_prob_i.size(0))]
+                torch.cat([log_ac_third_prob_i, torch.full(
+                (self.max_action - log_ac_third_prob_i.size(0), ), float("-inf"), device=self.device)]
                 , 0).contiguous().view(1,self.max_action)
                                 for i, log_ac_third_prob_i in enumerate(log_ac_third_prob)], dim=0).contiguous()
@@ -455,8 +455,8 @@ def forward(self, graph_emb, node_emb, g, cands, deterministic=False):
             ac_third_prob = torch.cat([ac_third_prob, ac_third_prob.new_zeros(
                 1, self.max_action - ac_third_prob.size(1))]
                 , -1).contiguous()
-            log_ac_third_prob = torch.cat([log_ac_third_prob, log_ac_third_prob.new_zeros(
-                1, self.max_action - log_ac_third_prob.size(1))]
+            log_ac_third_prob = torch.cat([log_ac_third_prob, torch.full(
+                (1, self.max_action - log_ac_third_prob.size(1)), float("-inf"), device=self.device)]
                 , -1).contiguous()
 
         # ==== concat everything ====
@@ -500,8 +500,8 @@ def sample(self, ac, graph_emb, node_emb, g, cands):
                             max(self.max_action - ac_first_prob.size(1),0))]
                             , 1).contiguous()
 
-        log_ac_first_prob = torch.cat([log_ac_first_prob, log_ac_first_prob.new_zeros(1,
-                            max(self.max_action - log_ac_first_prob.size(1),0))]
+        log_ac_first_prob = torch.cat([log_ac_first_prob, torch.full((1,
+                            max(self.max_action - log_ac_first_prob.size(1),0)), float("-inf"), device=self.device)]
                             , 1).contiguous()
 
         emb_first = att_emb[ac[0]].unsqueeze(0)
@@ -555,8 +555,8 @@ def sample(self, ac, graph_emb, node_emb, g, cands):
         ac_third_prob = torch.cat([ac_third_prob, ac_third_prob.new_zeros(
             1, self.max_action - ac_third_prob.size(1))]
             , -1).contiguous()
-        log_ac_third_prob = torch.cat([log_ac_third_prob, log_ac_third_prob.new_zeros(
-            1, self.max_action - log_ac_third_prob.size(1))]
+        log_ac_third_prob = torch.cat([log_ac_third_prob, torch.full(
+            (1, self.max_action - log_ac_third_prob.size(1)), float("-inf"), device=self.device)]
            , -1).contiguous()
 
         # ==== concat everything ====

diff --git a/core_motif_vbased.py b/core_motif_vbased.py
index 3b84434..165066f 100644
--- a/core_motif_vbased.py
+++ b/core_motif_vbased.py
@@ -311,8 +311,8 @@ def forward(self, graph_emb, node_emb, g, cands, deterministic=False):
                                 for i, ac_first_prob_i in enumerate(ac_first_prob)], dim=0).contiguous()
 
             log_ac_first_prob = torch.cat([
-                torch.cat([log_ac_first_prob_i, log_ac_first_prob_i.new_zeros(
-                max(self.max_action - log_ac_first_prob_i.size(0),0),1)]
+                torch.cat([log_ac_first_prob_i, torch.full(
+                (max(self.max_action - log_ac_first_prob_i.size(0),0),1), float("-inf"), device=self.device)]
                 , 0).contiguous().view(1,self.max_action)
                                 for i, log_ac_first_prob_i in enumerate(log_ac_first_prob)], dim=0).contiguous()
@@ -323,11 +323,10 @@ def forward(self, graph_emb, node_emb, g, cands, deterministic=False):
             ac_first_prob = torch.cat([ac_first_prob, ac_first_prob.new_zeros(
                 max(self.max_action - ac_first_prob.size(0),0),1)]
                 , 0).contiguous().view(1,self.max_action)
-            log_ac_first_prob = torch.cat([log_ac_first_prob, log_ac_first_prob.new_zeros(
-                max(self.max_action - log_ac_first_prob.size(0),0),1)]
+            log_ac_first_prob = torch.cat([log_ac_first_prob, torch.full(
+                (max(self.max_action - log_ac_first_prob.size(0),0),1), float("-inf"), device=self.device)]
                 , 0).contiguous().view(1,self.max_action)
 
-
         # ===============================
         # step 2 : which motif to add - Using Descriptors
         # ===============================
@@ -404,8 +403,8 @@ def forward(self, graph_emb, node_emb, g, cands, deterministic=False):
                                 for i, ac_third_prob_i in enumerate(ac_third_prob)], dim=0).contiguous()
 
             log_ac_third_prob = torch.cat([
-                torch.cat([log_ac_third_prob_i, log_ac_third_prob_i.new_zeros(
-                self.max_action - log_ac_third_prob_i.size(0))]
+                torch.cat([log_ac_third_prob_i, torch.full(
+                (self.max_action - log_ac_third_prob_i.size(0), ), float("-inf"), device=self.device)]
                 , 0).contiguous().view(1,self.max_action)
                                 for i, log_ac_third_prob_i in enumerate(log_ac_third_prob)], dim=0).contiguous()
@@ -417,8 +416,8 @@ def forward(self, graph_emb, node_emb, g, cands, deterministic=False):
             ac_third_prob = torch.cat([ac_third_prob, ac_third_prob.new_zeros(
                 1, self.max_action - ac_third_prob.size(1))]
                 , -1).contiguous()
-            log_ac_third_prob = torch.cat([log_ac_third_prob, log_ac_third_prob.new_zeros(
-                1, self.max_action - log_ac_third_prob.size(1))]
+            log_ac_third_prob = torch.cat([log_ac_third_prob, torch.full(
+                (1, self.max_action - log_ac_third_prob.size(1)), float("-inf"), device=self.device)]
                 , -1).contiguous()
 
         # ==== concat everything ====
@@ -473,8 +472,8 @@ def sample(self, ac, graph_emb, node_emb, g, cands):
                             max(self.max_action - ac_first_prob.size(1),0))]
                             , 1).contiguous()
 
-        log_ac_first_prob = torch.cat([log_ac_first_prob, log_ac_first_prob.new_zeros(1,
-                            max(self.max_action - log_ac_first_prob.size(1),0))]
+        log_ac_first_prob = torch.cat([log_ac_first_prob, torch.full((1,
+                            max(self.max_action - log_ac_first_prob.size(1),0)), float("-inf"), device=self.device)]
                             , 1).contiguous()
 
         emb_first = att_emb[ac[0]].unsqueeze(0)
@@ -529,8 +528,8 @@ def sample(self, ac, graph_emb, node_emb, g, cands):
         ac_third_prob = torch.cat([ac_third_prob, ac_third_prob.new_zeros(
             1, self.max_action - ac_third_prob.size(1))]
             , -1).contiguous()
-        log_ac_third_prob = torch.cat([log_ac_third_prob, log_ac_third_prob.new_zeros(
-            1, self.max_action - log_ac_third_prob.size(1))]
+        log_ac_third_prob = torch.cat([log_ac_third_prob, torch.full(
+            (1, self.max_action - log_ac_third_prob.size(1)), float("-inf"), device=self.device)]
            , -1).contiguous()
 
         # ==== concat everything ====
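
The first patch hinges on one property of the Gumbel-Max trick: perturbing log-probabilities (or raw logits) with Gumbel(0, 1) noise and taking the argmax samples from the intended categorical distribution, while perturbing the probabilities themselves does not. That is why the patch feeds log_ac_*_prob (or logit_second) into gumbel_softmax instead of the softmaxed probabilities, and drops the old +1e-8 shift. Below is a minimal, self-contained sketch of that difference in plain PyTorch; it does not use the repository's gumbel_softmax wrapper, and the helper name gumbel_argmax, the example logits, and the sample count are illustrative only.

import torch

torch.manual_seed(0)

n_samples = 20000
logits = torch.tensor([2.0, 0.0, -2.0])
probs = torch.softmax(logits, dim=-1)   # target distribution, roughly [0.87, 0.12, 0.02]

def gumbel_argmax(scores, n):
    # Gumbel(0, 1) noise, constructed the same way as in the patched gumbel_softmax
    g = -torch.empty(n, scores.numel()).exponential_().log()
    return torch.argmax(scores + g, dim=-1)

# Correct: perturb log-probabilities (equivalently, the raw logits)
freq_log = torch.bincount(gumbel_argmax(probs.log(), n_samples), minlength=3) / n_samples
# Incorrect: perturb the probabilities themselves, as the pre-patch code did
freq_prob = torch.bincount(gumbel_argmax(probs, n_samples), minlength=3) / n_samples

print(probs)      # intended distribution
print(freq_log)   # matches probs up to sampling noise
print(freq_prob)  # much flatter: roughly softmax(probs) rather than probs

With the old behaviour the sampled frequencies follow softmax(probs), a nearly uniform distribution over attachment points and motifs, which is what the commit subject "fix gumbel-softmax sampling" refers to.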
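The second patch follows from the same log-space reasoning: a log-probability row padded out to max_action entries must be padded with log(0) = -inf rather than 0.0, because log p = 0 claims probability 1 for actions that do not exist. A small sketch under assumed values (max_action = 6 and the example probabilities are made up for illustration):

import torch

max_action = 6                                   # illustrative value, not the repository's setting
probs = torch.tensor([0.7, 0.2, 0.1])            # probabilities of the real actions
pad = max_action - probs.numel()

log_zero_pad = torch.cat([probs.log(), torch.zeros(pad)])                   # old padding
log_inf_pad = torch.cat([probs.log(), torch.full((pad,), float("-inf"))])   # patched padding

print(log_zero_pad.exp())   # ~[0.7, 0.2, 0.1, 1.0, 1.0, 1.0]: padded slots claim probability 1
print(log_inf_pad.exp())    # ~[0.7, 0.2, 0.1, 0.0, 0.0, 0.0]: padded slots stay at probability 0

# The patched return values re-clamp -inf to a finite floor before use:
print((log_inf_pad.exp() + 1e-8).log())   # padded entries become log(1e-8), about -18.4

The (log_ac_prob.exp() + 1e-8).log() expression introduced in the returns above performs exactly that re-clamping, so the -inf padding never reaches downstream arithmetic as an infinity.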