@@ -128,14 +128,13 @@ def forward(self, input: Tensor, condition: Tensor, time: Tensor) -> Tensor:
128128 std = self .approx_marginal_std (time )
129129
130130 # As input to the neural net we want to have something that changes proportionally
131- # to how the scores change (a la c_noise in edm)
131+ # to how the scores change
132132 time_enc = self .std_fn (time )
133133
134- # Time dependent z-scoring! Keeps input at similar scales (c_in in edm)
134+ # Time dependent z-scoring! Keeps input at similar scales
135135 input_enc = (input - mean ) / std
136136
137137 # Approximate score becoming exact for t -> t_max, "skip connection"
138- # (a la c_skip in edm)
139138 score_gaussian = (input - mean ) / std ** 2
140139
141140 # Score prediction by the network
@@ -145,7 +144,6 @@ def forward(self, input: Tensor, condition: Tensor, time: Tensor) -> Tensor:
145144 # The learnable part will be largely scaled at the beginning of the diffusion
146145 # and the gaussian part (where it should end up) will dominate at the end of
147146 # the diffusion.
148- # (a la c_out in edm)
149147 scale = self .mean_t_fn (time ) / self .std_fn (time )
150148 output_score = - scale * score_pred - score_gaussian
151149
@@ -182,7 +180,7 @@ def loss(
182180 # update device if required
183181 self .device = input .device if self .device != input .device else self .device
184182
185- # Sample times from the Markov chain
183+ # Sample times from the Markov chain, use batch dimension
186184 if times is None :
187185 times = self .times_schedule (input .shape [0 ])
188186
@@ -355,7 +353,7 @@ def times_schedule(
355353 self , num_samples : int , t_min : float = None , t_max : float = None
356354 ) -> Tensor :
357355 """
358- Construction time samples for evaluating the diffusion model.
356+ Time samples for evaluating the diffusion model.
359357 Perform uniform sampling of time variables within the range [t_min, t_max].
360358 The `times` tensor will be put on the same device as the stored network.
361359
@@ -374,7 +372,12 @@ def times_schedule(
374372 t_min = self .t_min if isinstance (t_min , type (None )) else t_min
375373 t_max = self .t_max if isinstance (t_max , type (None )) else t_max
376374
377- return torch .rand (num_samples , device = self .device ) * (t_max - t_min ) + t_min
375+ times = torch .rand (num_samples , device = self .device ) * (t_max - t_min ) + t_min
376+
377+ # t_min and t_max need to be part of the sequence
378+ times [0 ,...] = t_min
379+ times [- 1 ,...] = t_max
380+ return torch .Tensor (sorted (times ))
378381
379382 def _set_weight_fn (self , weight_fn : Union [str , Callable ]):
380383 """Set the weight function.
@@ -495,162 +498,6 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
495498 return g
496499
497500
class ImprovedScoreEstimator(ConditionalScoreEstimator):
    """EDM-like score matching estimator as in [1].

    The overridden ``mean_t_fn`` / ``std_fn`` play the roles of the ``c_in`` and
    ``c_noise`` preconditioning terms of the EDM parametrisation, and
    ``times_schedule`` produces the rho-warped noise levels between
    ``beta_min`` and ``beta_max``.

    [1] Karras et al "Elucidating the Design Space of Diffusion-Based
        Generative Models", https://arxiv.org/abs/2206.00364
    """

    def __init__(
        self,
        net: nn.Module,
        input_shape: torch.Size,
        condition_shape: torch.Size,
        weight_fn: Union[str, Callable] = "max_likelihood",
        beta_min: float = 0.002,  # sigma_min in the paper
        beta_max: float = 80.0,  # sigma_max in the paper
        beta_data: float = 0.5,  # sigma_data in the paper
        mean_0: Union[Tensor, float] = 0.0,
        std_0: Union[Tensor, float] = 1.0,
        t_min: float = 1e-5,  # will be ignored due to EDM setup
        t_max: float = 1.0,
        pmean: float = -1.2,  # mean of noise scheme for training
        pstd: float = 1.2,  # std of noise scheme for training
        sigma_data: float = 0.5,
    ) -> None:
        """Set up the EDM-specific hyper-parameters and the base estimator.

        Args:
            net: Score network, forwarded to the base class.
            input_shape: Shape of a single input, forwarded to the base class.
            condition_shape: Shape of a single condition, forwarded to the
                base class.
            weight_fn: Loss weighting (name or callable), forwarded to the
                base class.
            beta_min: Smallest noise level (sigma_min in the paper).
            beta_max: Largest noise level (sigma_max in the paper).
            beta_data: Data standard deviation (sigma_data in the paper), used
                by `mean_t_fn`.
            mean_0: Forwarded to the base class.
            std_0: Forwarded to the base class.
            t_min: Forwarded to the base class; not used by this class's
                `times_schedule`.
            t_max: Forwarded to the base class; not used by this class's
                `times_schedule`.
            pmean: Mean of the log-normal training noise distribution.
            pstd: Standard deviation of the log-normal training noise
                distribution.
            sigma_data: Currently unused; kept so existing callers that pass
                it keep working. `beta_data` is the value actually used.
        """
        # NOTE(review): the attributes below are assigned before
        # ``super().__init__`` as in the original ordering -- presumably so the
        # overridden methods work if the base constructor calls them; confirm.

        # TODO: store sigma values for training in extra field
        self.pmean, self.pstd = pmean, pstd

        # ``scale`` of a frozen normal distribution is a standard deviation,
        # not a variance, so ``pstd`` must be passed unsquared.
        # (assumes ``stats`` is ``scipy.stats`` and ``exp`` is ``math.exp``)
        noise_dist = stats.norm(loc=pmean, scale=pstd)
        self.sigma_min = exp(noise_dist.ppf(0.01))
        self.sigma_max = exp(noise_dist.ppf(0.99))

        self.beta_data = beta_data  # sigma_data from the EDM paper
        self.rho = 7

        super().__init__(
            net,
            input_shape,
            condition_shape,
            mean_0=mean_0,
            std_0=std_0,
            weight_fn=weight_fn,
            beta_min=beta_min,
            beta_max=beta_max,
            t_min=t_min,
            t_max=t_max,
        )

    def mean_t_fn(self, times: Tensor) -> Tensor:
        """Conditional mean function for EDM-style DMs.

        This is required to model c_in: it evaluates
        ``1 / sqrt(noise**2 + beta_data**2)`` with ``noise`` taken from
        `noise_schedule`.

        Args:
            times: Time (noise level) variable.

        Returns:
            Conditional mean factor at the given times, with one trailing
            singleton dimension appended per entry of ``self.input_shape``.
        """
        noise = self.noise_schedule(times)
        phi = 1.0 / torch.sqrt(noise**2 + self.beta_data**2)
        # Append trailing singleton dims so the factor broadcasts over inputs.
        for _ in self.input_shape:
            phi = phi.unsqueeze(-1)
        return phi

    def std_fn(self, times: Tensor) -> Tensor:
        """Standard deviation function for EDM-style DMs.

        This is akin to c_noise in the network/precond parametrisation and
        evaluates ``0.25 * log(times)``.

        Args:
            times: Time (noise level) variable; must be positive for the
                logarithm to be finite.

        Returns:
            The value at the given times, with one trailing singleton
            dimension appended per entry of ``self.input_shape``.
        """
        std = 0.25 * torch.log(times)
        # Append trailing singleton dims so the factor broadcasts over inputs.
        for _ in self.input_shape:
            std = std.unsqueeze(-1)
        return std

    def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
        """Drift function of the EDM-style SDE.

        Args:
            input: Original data, x0.
            times: SDE time variable.

        Returns:
            ``-0.5 * noise_schedule(times) * input``.
        """
        phi = -0.5 * self.noise_schedule(times)
        # Pad trailing dims until phi broadcasts against ``input``.
        while phi.dim() < input.dim():
            phi = phi.unsqueeze(-1)
        return phi * input

    def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
        """Diffusion function of the EDM-style SDE.

        Args:
            input: Original data, x0 (only its number of dimensions is used).
            times: SDE time variable.

        Returns:
            ``sqrt(noise_schedule(times))``, padded with trailing singleton
            dimensions up to the rank of ``input``.
        """
        g = torch.sqrt(self.noise_schedule(times))
        while g.dim() < input.dim():
            g = g.unsqueeze(-1)
        return g

    def noise_schedule(self, times: Tensor) -> Tensor:
        """Return the noise level for the given times.

        In this EDM-style setup the time variable is used directly as the
        noise level, so this is the identity.

        Args:
            times (Tensor): SDE times, e.g. as produced by `times_schedule`.

        Returns:
            Tensor: ``times`` unchanged.

        [1] Karras et al "Elucidating the Design Space of Diffusion-Based
            Generative Models", https://arxiv.org/abs/2206.00364
        """
        return times

    def times_schedule(
        self, num_samples: int, t_min: float = None, t_max: float = None
    ) -> Tensor:
        """Construct time samples as suggested in the EDM paper [1].

        Evaluates ``(beta_min**(1/rho) + s * (beta_max**(1/rho) -
        beta_min**(1/rho)))**rho`` on ``num_samples`` equally spaced
        ``s`` in [0, 1].

        Args:
            num_samples (int): Number of samples to generate.
            t_min (float, optional): Accepted for interface compatibility with
                the base class; ignored here.
            t_max (float, optional): Accepted for interface compatibility with
                the base class; ignored here.

        Returns:
            Tensor: ``num_samples`` increasing values running from
                ``self.beta_min`` to ``self.beta_max``.

        [1] Karras et al "Elucidating the Design Space of Diffusion-Based
            Generative Models", https://arxiv.org/abs/2206.00364
        """
        # Create the grid on the estimator's device (when one has been set) so
        # the schedule lives where the rest of the computation does.
        grid = torch.linspace(
            0.0, 1.0, steps=num_samples, device=getattr(self, "device", None)
        )
        inv_rho = 1.0 / self.rho

        offset = self.beta_min**inv_rho
        beta_scale = self.beta_max**inv_rho - offset

        return (offset + beta_scale * grid) ** self.rho
652-
653-
654501class SubVPScoreEstimator (ConditionalScoreEstimator ):
655502 """Class for score estimators with sub-variance preserving SDEs."""
656503
@@ -842,9 +689,6 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
842689 return g
843690
844691
845- # TODO: try to add a EDM-like estimator
846-
847-
848692class GaussianFourierTimeEmbedding (nn .Module ):
849693 """Gaussian random features for encoding time steps.
850694
0 commit comments