Skip to content

Commit 833c17b

Browse files
committed
removed all mentions of edm
1 parent 8f0c65a commit 833c17b

File tree

1 file changed

+10
-166
lines changed

1 file changed

+10
-166
lines changed

sbi/neural_nets/estimators/score_estimator.py

Lines changed: 10 additions & 166 deletions
Original file line numberDiff line numberDiff line change
@@ -128,14 +128,13 @@ def forward(self, input: Tensor, condition: Tensor, time: Tensor) -> Tensor:
128128
std = self.approx_marginal_std(time)
129129

130130
# As input to the neural net we want to have something that changes proportianl
131-
# to how the scores change (a la c_noise in edm)
131+
# to how the scores change
132132
time_enc = self.std_fn(time)
133133

134-
# Time dependent z-scoring! Keeps input at similar scales (c_in in edm)
134+
# Time dependent z-scoring! Keeps input at similar scales
135135
input_enc = (input - mean) / std
136136

137137
# Approximate score becoming exact for t -> t_max, "skip connection"
138-
# (a la c_skip in edm)
139138
score_gaussian = (input - mean) / std**2
140139

141140
# Score prediction by the network
@@ -145,7 +144,6 @@ def forward(self, input: Tensor, condition: Tensor, time: Tensor) -> Tensor:
145144
# The learnable part will be largly scaled at the beginning of the diffusion
146145
# and the gaussian part (where it should end up) will dominate at the end of
147146
# the diffusion.
148-
# (a la c_out in edm)
149147
scale = self.mean_t_fn(time) / self.std_fn(time)
150148
output_score = -scale * score_pred - score_gaussian
151149

@@ -182,7 +180,7 @@ def loss(
182180
# update device if required
183181
self.device = input.device if self.device != input.device else self.device
184182

185-
# Sample times from the Markov chain
183+
# Sample times from the Markov chain, use batch dimension
186184
if times is None:
187185
times = self.times_schedule(input.shape[0])
188186

@@ -355,7 +353,7 @@ def times_schedule(
355353
self, num_samples: int, t_min: float = None, t_max: float = None
356354
) -> Tensor:
357355
"""
358-
Construction time samples for evaluating the diffusion model.
356+
Time samples for evaluating the diffusion model.
359357
Perform uniform sampling of time variables within the range [t_min, t_max].
360358
The `times` tensor will be put on the same device as the stored network.
361359
@@ -374,7 +372,12 @@ def times_schedule(
374372
t_min = self.t_min if isinstance(t_min, type(None)) else t_min
375373
t_max = self.t_max if isinstance(t_max, type(None)) else t_max
376374

377-
return torch.rand(num_samples, device=self.device) * (t_max - t_min) + t_min
375+
times = torch.rand(num_samples, device=self.device) * (t_max - t_min) + t_min
376+
377+
# t_min and t_max need to be part of the sequence
378+
times[0,...] = t_min
379+
times[-1,...] = t_max
380+
return torch.Tensor(sorted(times))
378381

379382
def _set_weight_fn(self, weight_fn: Union[str, Callable]):
380383
"""Set the weight function.
@@ -495,162 +498,6 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
495498
return g
496499

497500

498-
class ImprovedScoreEstimator(ConditionalScoreEstimator):
499-
"""Implement EDM-like score matching estimator as in [1]
500-
501-
[1] Karras et al "Elucidating the Design Space of Diffusion-Based
502-
Generative Models", https://arxiv.org/abs/2206.00364
503-
"""
504-
505-
def __init__(
506-
self,
507-
net: nn.Module,
508-
input_shape: torch.Size,
509-
condition_shape: torch.Size,
510-
weight_fn: Union[str, Callable] = "max_likelihood",
511-
beta_min: float = 0.002, # sigma_min in the paper
512-
beta_max: float = 80.0, # sigma_max in the paper
513-
beta_data: float = .5, #sigma_data in the paper
514-
mean_0: Union[Tensor, float] = 0.0,
515-
std_0: Union[Tensor, float] = 1.0,
516-
t_min: float = 1e-5, # will be ignored due to EDM setup
517-
t_max: float = 1.0, #
518-
pmean: float = -1.2, # mean of noise scheme for training
519-
pstd: float = 1.2, # std of noise scheme for training
520-
sigma_data: float = 0.5,
521-
) -> None:
522-
523-
524-
#TODO: store sigma values for training in extra field
525-
self.pmean, self.pstd = pmean, pstd
526-
noise_dist = stats.norm(pmean, pstd**2)
527-
self.sigma_min = exp(noise_dist.ppf(0.01))
528-
self.sigma_max = exp(noise_dist.ppf(0.99))
529-
530-
self.beta_data = beta_data #sigma data from edm paper
531-
self.rho = 7
532-
533-
super().__init__(
534-
net,
535-
input_shape,
536-
condition_shape,
537-
mean_0=mean_0,
538-
std_0=std_0,
539-
weight_fn=weight_fn,
540-
beta_min=beta_min,
541-
beta_max=beta_max,
542-
t_min=t_min,
543-
t_max=t_max,
544-
)
545-
546-
def mean_t_fn(self, times: Tensor) -> Tensor:
547-
"""Conditional mean function for EDM-style DMs.
548-
This is required to model c_in.
549-
550-
Args:
551-
times: time variable in [0,1].
552-
553-
Returns:
554-
Conditional mean at a given time.
555-
"""
556-
noise = self.noise_schedule(times)
557-
phi = 1./torch.sqrt(noise**2 + self.beta_data**2)
558-
for _ in range(len(self.input_shape)):
559-
phi = phi.unsqueeze(-1)
560-
return phi
561-
562-
def std_fn(self, times: Tensor) -> Tensor:
563-
"""Standard deviation function for EDM style DMs.
564-
This is akin to c_noise in the network/precond parametrisation.
565-
Args:
566-
times: time variable in [0,1].
567-
568-
Returns:
569-
Standard deviation at a given time.
570-
"""
571-
std = .25*torch.log(times)
572-
for _ in range(len(self.input_shape)):
573-
std = std.unsqueeze(-1)
574-
return std
575-
576-
def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
577-
"""Drift function for variance preserving SDEs.
578-
579-
Args:
580-
input: Original data, x0.
581-
times: SDE time variable in [0,1].
582-
583-
Returns:
584-
Drift function at a given time.
585-
"""
586-
phi = -0.5 * self.noise_schedule(times)
587-
while len(phi.shape) < len(input.shape):
588-
phi = phi.unsqueeze(-1)
589-
return phi * input
590-
591-
def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
592-
"""Diffusion function for variance preserving SDEs.
593-
594-
Args:
595-
input: Original data, x0.
596-
times: SDE time variable in [0,1].
597-
598-
Returns:
599-
Drift function at a given time.
600-
"""
601-
g = torch.sqrt(self.noise_schedule(times))
602-
while len(g.shape) < len(input.shape):
603-
g = g.unsqueeze(-1)
604-
return g
605-
606-
def noise_schedule(self, times: Tensor) -> Tensor:
607-
"""
608-
Generate a beta schedule similar to suggestions in the EDM [1] paper.
609-
610-
This method acts as a fallback in case derivative classes do not
611-
implement it on their own. It calculates a linear beta schedule defined
612-
by the input `times`, which represent the normalized time steps t ∈ [0, 1].
613-
614-
Args:
615-
times (Tensor):
616-
SDE times in [0, 1]. This tensor will be regenerated from
617-
self.times_schedule
618-
619-
Returns:
620-
Tensor: Generated beta schedule at a given time.
621-
622-
[1] Karras et al "Elucidating the Design Space of Diffusion-Based
623-
Generative Models", https://arxiv.org/abs/2206.00364
624-
"""
625-
return times
626-
627-
def times_schedule(
628-
self, num_samples: int, t_min: float = None, t_max: float = None
629-
) -> Tensor:
630-
"""
631-
Construct time samples as suggested in EDM paper [1].
632-
633-
Args:
634-
num_samples (int): Number of samples to generate.
635-
t_min (float, optional): The minimum time value. Defaults to self.t_min.
636-
t_max (float, optional): The maximum time value. Defaults to self.t_max.
637-
638-
Returns:
639-
Tensor: A tensor of sampled time variables scaled and shifted to
640-
the range [0,1].
641-
642-
[1] Karras et al "Elucidating the Design Space of Diffusion-Based
643-
Generative Models", https://arxiv.org/abs/2206.00364
644-
"""
645-
times = torch.linspace(0.0, 1.0, steps=num_samples)
646-
inv_rho = 1.0 / self.rho
647-
648-
beta_scale = self.beta_max ** (inv_rho) - self.beta_min ** (inv_rho)
649-
offset = self.beta_min ** (inv_rho)
650-
651-
return (offset + beta_scale * times) ** (self.rho)
652-
653-
654501
class SubVPScoreEstimator(ConditionalScoreEstimator):
655502
"""Class for score estimators with sub-variance preserving SDEs."""
656503

@@ -842,9 +689,6 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
842689
return g
843690

844691

845-
# TODO: try to add a EDM-like estimator
846-
847-
848692
class GaussianFourierTimeEmbedding(nn.Module):
849693
"""Gaussian random features for encoding time steps.
850694

0 commit comments

Comments
 (0)