
Commit 896441c
committed 2026-03-19
1 parent fafaff3 commit 896441c

4 files changed: 75 additions & 63 deletions


Lines changed: 35 additions & 26 deletions
@@ -1,27 +1,34 @@
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from torch import Tensor
-from torch.distributions import Gamma
 
 
 class GroupEncoder(nn.Module):
-    """Infer per-group rate posteriors q(τ_k) from local encoder features.
+    """Infer per-group rate posteriors q(log τ_k) from local encoder features.
+
+    Works in log-space: q(log τ_k) = Normal(μ_k, σ_k²), then τ_k = exp(log τ_k).
+    This avoids Gamma rsample instabilities when τ → 0.
 
     Architecture (DeepSets):
-        x_i → φ(x_i) → mean-pool by group → ρ(·) → (α_k, β_k)
+        x_i → φ(x_i) → mean-pool by group → ρ(·) → (μ_k, logvar_k)
 
     Parameters
     ----------
     encoder_out : int
         Dimension of per-reflection encoder features.
     hidden_dim : int
         Width of φ and ρ hidden layers.
+    log_tau_init : float
+        Initial bias for head_mu (should match prior mean, e.g. -6.9).
     """
 
-    def __init__(self, encoder_out: int, hidden_dim: int = 64, alpha_min: float = 0.1):
+    def __init__(
+        self,
+        encoder_out: int,
+        hidden_dim: int = 64,
+        log_tau_init: float = -6.9,
+    ):
         super().__init__()
-        self.alpha_min = alpha_min
 
         # φ: per-element transform (before pooling)
         self.phi = nn.Sequential(
@@ -37,19 +44,19 @@ def __init__(self, encoder_out: int, hidden_dim: int = 64, alpha_min: float = 0.
             nn.SiLU(),
         )
 
-        self.head_alpha = nn.Linear(hidden_dim, 1)
-        self.head_beta = nn.Linear(hidden_dim, 1)
+        self.head_mu = nn.Linear(hidden_dim, 1)
+        self.head_logvar = nn.Linear(hidden_dim, 1)
 
-        nn.init.zeros_(self.head_alpha.weight)
-        nn.init.zeros_(self.head_alpha.bias)
-        nn.init.zeros_(self.head_beta.weight)
-        nn.init.zeros_(self.head_beta.bias)
+        nn.init.zeros_(self.head_mu.weight)
+        nn.init.constant_(self.head_mu.bias, log_tau_init)
+        nn.init.zeros_(self.head_logvar.weight)
+        nn.init.constant_(self.head_logvar.bias, -2.0)
 
     def forward(
         self,
         x: Tensor,
         group_labels: Tensor,
-    ) -> tuple[Gamma, Tensor]:
+    ) -> tuple[Tensor, Tensor, Tensor]:
         """
         Parameters
         ----------
@@ -60,15 +67,18 @@ def forward(
 
         Returns
         -------
-        q_tau : Gamma with batch shape (n_groups_in_batch,)
-        tau_per_refl : (B, 1) sampled τ_k broadcast to each reflection.
+        mu : (n_groups,)
+            Posterior mean of log τ_k.
+        logvar : (n_groups,)
+            Posterior log-variance of log τ_k.
+        tau_per_refl : (B, 1)
+            Sampled τ_k = exp(log τ_k) broadcast to each reflection.
         """
         # φ: transform each reflection
         z = self.phi(x)  # (B, hidden_dim)
 
         # Mean-pool by group (simple loop — K is small)
         unique_groups = torch.unique(group_labels)
-        n_groups = unique_groups.shape[0]
 
         group_means = []
         for k in unique_groups:
@@ -77,21 +87,20 @@ def forward(
 
         group_features = torch.stack(group_means)  # (n_groups, hidden_dim)
 
-        # ρ: per-group transform → Gamma params
+        # ρ: per-group transform → Normal params in log-space
         h = self.rho(group_features)  # (n_groups, hidden_dim)
 
-        alpha = (
-            F.softplus(self.head_alpha(h)).squeeze(-1) + self.alpha_min
-        )  # (n_groups,)
-        beta = F.softplus(self.head_beta(h)).squeeze(-1) + self.alpha_min  # (n_groups,)
-
-        q_tau = Gamma(concentration=alpha, rate=beta)
+        mu = self.head_mu(h).squeeze(-1)  # (n_groups,)
+        logvar = self.head_logvar(h).squeeze(-1).clamp(-10.0, 4.0)  # (n_groups,)
 
-        # Sample one τ per group, broadcast to reflections
-        tau_group = q_tau.rsample()  # (n_groups,)
+        # Reparameterized sample: log τ_k = μ_k + σ_k * ε, ε ~ N(0,1)
+        std = torch.exp(0.5 * logvar)
+        eps = torch.randn_like(std)
+        log_tau = mu + std * eps  # (n_groups,)
+        tau_group = torch.exp(log_tau)  # (n_groups,), always positive
 
         # Map back: unique_groups is sorted (from torch.unique), so use searchsorted
         indices = torch.searchsorted(unique_groups, group_labels)
         tau_per_refl = tau_group[indices].unsqueeze(1)  # (B, 1)
 
-        return q_tau, tau_per_refl
+        return mu, logvar, tau_per_refl

src/integrator/model/integrators/hierarchical_integrator.py

Lines changed: 14 additions & 12 deletions
@@ -1,8 +1,8 @@
 """Hierarchical Integrator: groups reflections and learns per-group intensity priors.
 
 Uses the standard encoder architecture but adds:
-1. A GroupEncoder that pools local features by radial bin → q(τ_k)
-2. Conditioning of the qi surrogate on the sampled τ_k
+1. A GroupEncoder that pools local features by radial bin → q(log τ_k)
+2. Conditioning of the qi surrogate on the sampled log τ_k
 """
 
 from typing import Any, Literal
@@ -67,11 +67,12 @@ def _forward_impl(
         x_profile = self.encoders["profile"](shoebox_reshaped)
         x_intensity = self.encoders["intensity"](shoebox_reshaped)
 
-        # Group encoder: pool by radial bin -> sample τ_k per reflection
+        # Group encoder: pool by radial bin → sample τ_k in log-space
         group_labels = metadata["group_label"].long()
-        q_tau, tau_per_refl = self.group_encoder(x_intensity, group_labels)
+        mu, logvar, tau_per_refl = self.group_encoder(x_intensity, group_labels)
 
         # Condition qi on τ_k: concatenate log(τ) to intensity features
+        # tau_per_refl = exp(log_tau), so log(tau_per_refl) recovers log_tau
         log_tau = torch.log(tau_per_refl + 1e-6)
         x_intensity_cond = torch.cat([x_intensity, log_tau], dim=-1)
 
@@ -103,20 +104,20 @@ def _forward_impl(
         )
         out = _assemble_outputs(out)
 
-        # Store q(τ_k) parameters and per-reflection τ for SBC / prediction
-        # q_tau has batch shape [n_groups_in_batch]; scatter back to [B]
+        # Store q(log τ_k) parameters and per-reflection τ for SBC / prediction
         _, inv = torch.unique(group_labels, return_inverse=True)
-        out["tau_per_refl"] = tau_per_refl.squeeze(-1)  # [B]
-        out["q_tau_concentration"] = q_tau.concentration[inv]  # [B]
-        out["q_tau_rate"] = q_tau.rate[inv]  # [B]
-        out["group_label"] = group_labels  # [B]
+        out["tau_per_refl"] = tau_per_refl.squeeze(-1)  # [B]
+        out["q_log_tau_mu"] = mu[inv]  # [B]
+        out["q_log_tau_logvar"] = logvar[inv]  # [B]
+        out["group_label"] = group_labels  # [B]
 
         return {
             "forward_out": out,
             "qp": qp,
             "qi": qi,
             "qbg": qbg,
-            "q_tau": q_tau,
+            "mu": mu,
+            "logvar": logvar,
             "tau_per_refl": tau_per_refl,
             "group_labels": group_labels,
         }
@@ -133,7 +134,8 @@ def _step(self, batch, step: Literal["train", "val"]):
             qi=outputs["qi"],
             qbg=outputs["qbg"],
             mask=forward_out["mask"],
-            q_tau=outputs["q_tau"],
+            mu=outputs["mu"],
+            logvar=outputs["logvar"],
             tau_per_refl=outputs["tau_per_refl"],
             group_labels=outputs["group_labels"],
         )
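
Both the integrator (mu[inv]) and the GroupEncoder (searchsorted) rely on torch.unique returning sorted group IDs. A toy sketch (illustrative labels, not the real metadata) showing the two mappings agree:

import torch

group_labels = torch.tensor([5, 2, 2, 9, 5])

# return_inverse gives, for each reflection, its group's index in the sorted uniques
unique_groups, inv = torch.unique(group_labels, return_inverse=True)
# unique_groups = [2, 5, 9]; inv = [1, 0, 0, 2, 1]

# GroupEncoder's mapping: position of each label in the sorted unique array
indices = torch.searchsorted(unique_groups, group_labels)
assert torch.equal(inv, indices)  # same per-reflection indexing

mu = torch.tensor([-7.0, -6.5, -6.0])  # one posterior mean per group
print(mu[inv])  # tensor([-6.5000, -7.0000, -7.0000, -6.0000, -6.5000]), shape [B]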

src/integrator/model/loss/hierarchical_shoebox_loss.py

Lines changed: 25 additions & 24 deletions
@@ -5,13 +5,13 @@
 - KL( q(prf) || p(prf) ) — profile prior
 - KL( q(I_i) || Exp(τ_{k(i)}) ) — adaptive intensity prior
 - KL( q(bg_i) || Exp(λ_bg) ) — background prior
-- (1/N) Σ_k KL( q(τ_k) || Gamma(α, β) ) — global hyperprior
+- (1/N) Σ_k KL( q(log τ_k) || N(μ_0, σ_0²) ) — global hyperprior
 """
 
 import torch
 import torch.nn as nn
 from torch import Tensor
-from torch.distributions import Distribution, Exponential, Gamma, Poisson
+from torch.distributions import Distribution, Gamma, Poisson
 
 from integrator.configs.priors import PriorConfig
 from integrator.model.distributions.logistic_normal import ProfilePosterior
@@ -27,8 +27,8 @@ class HierarchicalShoeboxLoss(nn.Module):
     """ELBO loss with per-group learned Exponential intensity priors.
 
     The intensity prior for reflection i in group k is Exp(τ_k), where
-    q(τ_k) = Gamma(α_q, β_q) is learned by the GroupEncoder. The global
-    hyperprior is p(τ_k) = Gamma(hp_alpha, hp_beta).
+    q(log τ_k) = Normal(μ_k, σ_k²) is learned by the GroupEncoder.
+    The global hyperprior is p(log τ_k) = Normal(log_tau_mu, log_tau_sigma²).
 
     Parameters
     ----------
@@ -42,10 +42,10 @@ class HierarchicalShoeboxLoss(nn.Module):
         Monte Carlo samples for KL estimation.
     eps : float
         Numerical stability constant.
-    hp_alpha : float
-        Hyperprior Gamma concentration for τ_k.
-    hp_beta : float
-        Hyperprior Gamma rate for τ_k.
+    log_tau_mu : float
+        Prior mean for log τ_k (Normal hyperprior).
+    log_tau_sigma : float
+        Prior std for log τ_k (Normal hyperprior).
     dataset_size : int
         Total training set size N for global KL scaling.
     """
@@ -58,15 +58,15 @@ def __init__(
         pi_cfg: PriorConfig | None = None,
         mc_samples: int = 4,
         eps: float = 1e-6,
-        hp_alpha: float = 2.0,
-        hp_beta: float = 1.0,
+        log_tau_mu: float = -6.9,
+        log_tau_sigma: float = 1.0,
         dataset_size: int = 1,
     ):
         super().__init__()
         self.mc_samples = mc_samples
         self.eps = eps
-        self.hp_alpha = hp_alpha
-        self.hp_beta = hp_beta
+        self.log_tau_mu = log_tau_mu
+        self.log_tau_sigma = log_tau_sigma
         self.dataset_size = dataset_size
 
         # Profile prior (Dirichlet path — ignored when qp is ProfilePosterior)
@@ -94,7 +94,8 @@ def forward(
         qi: Distribution,
         qbg: Distribution,
         mask: Tensor,
-        q_tau: Gamma,
+        mu: Tensor,
+        logvar: Tensor,
         tau_per_refl: Tensor,
         group_labels: Tensor,
     ) -> dict[str, Tensor]:
@@ -126,11 +127,6 @@ def forward(
        kl = kl + kl_prf
 
        # ── Intensity KL: KL(q(I_i) || Exp(τ_{k(i)})) ──────────────
-        # tau_per_refl is [B, 1]; flatten to [B] for the Exponential rate
-        tau_flat = tau_per_refl.squeeze(-1).detach()  # stop gradient to τ for this KL
-        # Actually we want gradients through τ for the global KL, but the
-        # per-reflection intensity KL should use the same τ sample.
-        # Re-enable gradient: use tau_per_refl directly (no detach).
        tau_flat = tau_per_refl.squeeze(-1)
        p_i = Gamma(
            concentration=torch.ones_like(tau_flat),
@@ -152,13 +148,18 @@ def forward(
        )
        kl = kl + kl_bg
 
-        # ── Global KL: KL(q(τ_k) || Gamma(α, β)) / N ───────────────
-        p_tau = Gamma(
-            concentration=torch.tensor(self.hp_alpha, device=device),
-            rate=torch.tensor(self.hp_beta, device=device),
-        )
+        # ── Global KL: KL(N(μ_k, σ_k²) || N(μ_0, σ_0²)) / N ─────
+        sigma_q_sq = logvar.exp()  # (n_groups,)
+        sigma_p_sq = self.log_tau_sigma**2
+
        kl_global = (
-            torch.distributions.kl.kl_divergence(q_tau, p_tau).sum()
+            0.5
+            * (
+                sigma_q_sq / sigma_p_sq
+                + (mu - self.log_tau_mu) ** 2 / sigma_p_sq
+                - 1.0
+                - torch.log(sigma_q_sq / sigma_p_sq)
+            ).sum()
            / self.dataset_size
        )

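The new kl_global is the closed-form KL between univariate Gaussians, KL(N(μ_q, σ_q²) || N(μ_p, σ_p²)) = ½(σ_q²/σ_p² + (μ_q − μ_p)²/σ_p² − 1 − log(σ_q²/σ_p²)). A quick sanity check (toy values only) that the hand-rolled expression matches torch's registered divergence:

import torch
from torch.distributions import Normal, kl_divergence

mu = torch.tensor([-6.5, -7.2])      # per-group posterior means
logvar = torch.tensor([-2.0, -1.0])  # per-group posterior log-variances
log_tau_mu, log_tau_sigma = -6.9, 1.0

sigma_q_sq = logvar.exp()
sigma_p_sq = log_tau_sigma**2
kl_manual = 0.5 * (
    sigma_q_sq / sigma_p_sq
    + (mu - log_tau_mu) ** 2 / sigma_p_sq
    - 1.0
    - torch.log(sigma_q_sq / sigma_p_sq)
)

kl_ref = kl_divergence(
    Normal(mu, sigma_q_sq.sqrt()),
    Normal(torch.tensor(log_tau_mu), torch.tensor(log_tau_sigma)),
)
torch.testing.assert_close(kl_manual, kl_ref)  # identical up to float error

Unlike the Gamma‖Gamma term it replaces, this KL needs no distribution objects or device-placed tensors, which is why the loss now takes raw mu and logvar.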
src/integrator/utils/factory_utils.py

Lines changed: 1 addition & 1 deletion
@@ -191,7 +191,7 @@ def _get_loss_module(
     kwargs = shallow_dict(loss_args)
 
     # Forward extra keys from loss.args for custom loss classes
-    # (e.g. hp_alpha, hp_beta for HierarchicalShoeboxLoss)
+    # (e.g. log_tau_mu, log_tau_sigma for HierarchicalShoeboxLoss)
     standard_keys = {"mc_samples", "eps", "pprf_cfg", "pbg_cfg", "pi_cfg"}
     for k, v in cfg["loss"]["args"].items():
         if k not in standard_keys:
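
For context, a sketch of the forwarding this comment describes (the config dict and the loop body are illustrative assumptions; only standard_keys and the loop condition appear in the diff):

cfg = {
    "loss": {
        "args": {
            "mc_samples": 4,       # standard key, consumed explicitly
            "log_tau_mu": -6.9,    # extra key, forwarded to HierarchicalShoeboxLoss
            "log_tau_sigma": 1.0,  # extra key, forwarded to HierarchicalShoeboxLoss
        }
    }
}

standard_keys = {"mc_samples", "eps", "pprf_cfg", "pbg_cfg", "pi_cfg"}
kwargs = {}
for k, v in cfg["loss"]["args"].items():
    if k not in standard_keys:
        kwargs[k] = v  # hypothetical body; the diff elides what follows the condition

print(kwargs)  # {'log_tau_mu': -6.9, 'log_tau_sigma': 1.0}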
