import torch
import torch.nn as nn
from torch import Tensor


class GroupEncoder(nn.Module):
    """Infer per-group rate posteriors q(log τ_k) from local encoder features.

    Works in log-space: q(log τ_k) = Normal(μ_k, σ_k²), then τ_k = exp(log τ_k).
    This avoids Gamma rsample instabilities when τ → 0.

    Architecture (DeepSets):
        x_i → φ(x_i) → mean-pool by group → ρ(·) → (μ_k, logvar_k)

    Parameters
    ----------
    encoder_out : int
        Dimension of per-reflection encoder features.
    hidden_dim : int
        Width of φ and ρ hidden layers.
    log_tau_init : float
        Initial bias of head_mu; should match the prior mean of log τ (e.g. -6.9).
    """

    def __init__(
        self,
        encoder_out: int,
        hidden_dim: int = 64,
        log_tau_init: float = -6.9,
    ):
        super().__init__()

        # φ: per-element transform (before pooling)
        self.phi = nn.Sequential(
            nn.Linear(encoder_out, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
        )

        # ρ: per-group transform (after pooling); depth of φ and ρ is assumed here
        self.rho = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
        )

        self.head_mu = nn.Linear(hidden_dim, 1)
        self.head_logvar = nn.Linear(hidden_dim, 1)

        nn.init.zeros_(self.head_mu.weight)
        nn.init.constant_(self.head_mu.bias, log_tau_init)
        nn.init.zeros_(self.head_logvar.weight)
        nn.init.constant_(self.head_logvar.bias, -2.0)

    def forward(
        self,
        x: Tensor,
        group_labels: Tensor,
    ) -> tuple[Tensor, Tensor, Tensor]:
        """
        Parameters
        ----------
        x : (B, encoder_out)
            Per-reflection encoder features.
        group_labels : (B,)
            Integer group index of each reflection.

        Returns
        -------
        mu : (n_groups,)
            Posterior mean of log τ_k.
        logvar : (n_groups,)
            Posterior log-variance of log τ_k.
        tau_per_refl : (B, 1)
            Sampled τ_k = exp(log τ_k) broadcast to each reflection.
        """
        # φ: transform each reflection
        z = self.phi(x)  # (B, hidden_dim)

        # Mean-pool by group (simple loop — K is small)
        unique_groups = torch.unique(group_labels)

        group_means = []
        for k in unique_groups:
            # average φ features over the reflections belonging to group k
            mask = group_labels == k
            group_means.append(z[mask].mean(dim=0))

        group_features = torch.stack(group_means)  # (n_groups, hidden_dim)

        # ρ: per-group transform → Normal params in log-space
        h = self.rho(group_features)  # (n_groups, hidden_dim)

        mu = self.head_mu(h).squeeze(-1)  # (n_groups,)
        logvar = self.head_logvar(h).squeeze(-1).clamp(-10.0, 4.0)  # (n_groups,)

        # Reparameterized sample: log τ_k = μ_k + σ_k * ε, ε ~ N(0, 1)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        log_tau = mu + std * eps  # (n_groups,)
        tau_group = torch.exp(log_tau)  # (n_groups,), always positive

        # Map back: unique_groups is sorted (from torch.unique), so use searchsorted
        indices = torch.searchsorted(unique_groups, group_labels)
        tau_per_refl = tau_group[indices].unsqueeze(1)  # (B, 1)

        return mu, logvar, tau_per_refl
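

# A minimal usage sketch. The batch/group sizes and the prior std over log τ below
# are illustrative assumptions; only the prior mean -6.9 comes from the module's
# defaults. It shows the closed-form KL(q(log τ_k) || N(prior_mu, prior_std²)) that
# such a log-normal posterior would typically contribute to an ELBO.
if __name__ == "__main__":
    torch.manual_seed(0)

    B, D, K = 128, 32, 4                      # reflections, feature dim, groups (assumed)
    x = torch.randn(B, D)                     # stand-in per-reflection encoder features
    group_labels = torch.randint(0, K, (B,))  # integer group index per reflection

    enc = GroupEncoder(encoder_out=D, hidden_dim=64, log_tau_init=-6.9)
    mu, logvar, tau_per_refl = enc(x, group_labels)

    # KL( N(mu, var) || N(prior_mu, prior_std²) ) per group, in closed form
    prior_mu, prior_std = -6.9, 1.0           # prior std is an assumed value
    var = logvar.exp()
    kl = (
        0.5 * ((var + (mu - prior_mu) ** 2) / prior_std**2 - 1.0)
        + torch.log(torch.as_tensor(prior_std))
        - 0.5 * logvar
    )

    print(tau_per_refl.shape, kl.sum().item())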