Hi,
I just found a subtle bug (or at least an inconsistency) in the regression layers, which I noticed when drawing posterior samples from the predictive.
For the standard MVN case, `W` is defined as a method:

```python
def W(self):
    cov_diag = torch.exp(self.W_logdiag)
    if self.W_dist == Normal:
        cov = self.W_dist(self.W_mean, cov_diag)
    elif self.W_dist == DenseNormal:
        tril = torch.tril(self.W_offdiag, diagonal=-1) + torch.diag_embed(cov_diag)
        cov = self.W_dist(self.W_mean, tril)
    elif self.W_dist == LowRankNormal:
        cov = self.W_dist(self.W_mean, self.W_offdiag, cov_diag)
    return cov
```

whereas for the t-VBLL regression layer, it is defined as a property:
```python
@property
def W(self):
    cov_diag = torch.exp(self.W_logdiag)
    if self.W_dist == Normal:
        cov = self.W_dist(self.W_mean, cov_diag)
    elif self.W_dist == DenseNormal:
        tril = torch.tril(self.W_offdiag, diagonal=-1) + torch.diag_embed(cov_diag)
        cov = self.W_dist(self.W_mean, tril)
    elif self.W_dist == LowRankNormal:
        cov = self.W_dist(self.W_mean, self.W_offdiag, cov_diag)
    return cov
```

This then changes how you sample from `W`:
- for VBLL: `layer.W().rsample()`
- for t-VBLL: `layer.W.rsample()`
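To make the difference concrete, here is a minimal torch-free mock of the two access patterns. `ToyNormal`, `MethodLayer`, `PropertyLayer` and `get_W` are hypothetical names, not part of the library, and `ToyNormal.rsample()` is a plain `random.gauss` draw that only mimics the interface of a torch distribution. `get_W` sketches how downstream code could paper over the difference until the API is unified, assuming the distribution object itself is not callable.

```python
import random


class ToyNormal:
    """Hypothetical stand-in for a torch distribution; mimics only rsample()."""

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def rsample(self):
        # Plain Gaussian draw; no reparameterization gradients here.
        return random.gauss(self.mean, self.std)


class MethodLayer:
    """Mimics the VBLL regression layer: W is a regular method."""

    def __init__(self, mean=0.0, std=1.0):
        self.W_mean, self.W_std = mean, std

    def W(self):
        return ToyNormal(self.W_mean, self.W_std)


class PropertyLayer:
    """Mimics the t-VBLL regression layer: W is a property."""

    def __init__(self, mean=0.0, std=1.0):
        self.W_mean, self.W_std = mean, std

    @property
    def W(self):
        return ToyNormal(self.W_mean, self.W_std)


def get_W(layer):
    """Hypothetical workaround: return the W distribution for either layer style."""
    w = layer.W
    # A bound method is callable; the (toy) distribution object is not.
    return w() if callable(w) else w


vbll, tvbll = MethodLayer(), PropertyLayer()
vbll.W().rsample()   # method: the call is required
tvbll.W.rsample()    # property: no call
# tvbll.W() raises TypeError: 'ToyNormal' object is not callable
# vbll.W.rsample() raises AttributeError: 'method' object has no attribute 'rsample'

# The helper works for both styles:
get_W(vbll).rsample()
get_W(tvbll).rsample()
```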
I personally prefer `W` as a property. I'm happy to open a PR for this, but wanted to double-check with you first.
EDIT: Just checked; the same holds for the classification layers.
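If `W` becomes a property everywhere, one possible way to avoid immediately breaking existing `layer.W()` callers is a transitional shim: the property returns a distribution that can also be called, in which case it returns itself and emits a `DeprecationWarning`. This is only a sketch with a toy stand-in distribution; `ToyNormal`, `_CallableW` and `Layer` are hypothetical names, not the library's classes.

```python
import random
import warnings


class ToyNormal:
    """Hypothetical stand-in for a torch distribution; mimics only rsample()."""

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def rsample(self):
        # Plain Gaussian draw; no reparameterization gradients here.
        return random.gauss(self.mean, self.std)


class _CallableW(ToyNormal):
    """Hypothetical shim: a distribution that still tolerates the old layer.W() call."""

    def __call__(self):
        warnings.warn(
            "layer.W() is deprecated; use layer.W instead",
            DeprecationWarning,
            stacklevel=2,
        )
        return self


class Layer:
    """Regression and classification layers would both expose W this way."""

    def __init__(self, mean=0.0, std=1.0):
        self.W_mean, self.W_std = mean, std

    @property
    def W(self):
        return _CallableW(self.W_mean, self.W_std)


layer = Layer()
layer.W.rsample()    # new, preferred style
layer.W().rsample()  # old style still works, but warns
```

With the real `Normal`, `DenseNormal` and `LowRankNormal` classes, a shim like this would have to wrap whichever class `W_dist` points to, so simply making the breaking change in one release may turn out to be the simpler option.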