|
1 | 1 | import torch
|
2 |
| -import torch.nn.functional as F |
3 |
| -import math |
4 | 2 |
|
5 | 3 |
|
6 | 4 | class NoiseScheduleVP:
|
@@ -559,7 +557,7 @@ def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=Fal
|
559 | 557 | x_t: A pytorch tensor. The approximated solution at time `t`.
|
560 | 558 | """
|
561 | 559 | ns = self.noise_schedule
|
562 |
| - dims = x.dim() |
| 560 | + x.dim() |
563 | 561 | lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
|
564 | 562 | h = lambda_t - lambda_s
|
565 | 563 | log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
|
@@ -984,20 +982,25 @@ def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol
|
984 | 982 | nfe = 0
|
985 | 983 | if order == 2:
|
986 | 984 | r1 = 0.5
|
987 |
| - lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True) |
988 |
| - higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs) |
| 985 | + def lower_update(x, s, t): |
| 986 | + return self.dpm_solver_first_update(x, s, t, return_intermediate=True) |
| 987 | + def higher_update(x, s, t, **kwargs): |
| 988 | + return self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs) |
989 | 989 | elif order == 3:
|
990 | 990 | r1, r2 = 1. / 3., 2. / 3.
|
991 |
| - lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type) |
992 |
| - higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs) |
| 991 | + def lower_update(x, s, t): |
| 992 | + return self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type) |
| 993 | + def higher_update(x, s, t, **kwargs): |
| 994 | + return self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs) |
993 | 995 | else:
|
994 | 996 | raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
|
995 | 997 | while torch.abs((s - t_0)).mean() > t_err:
|
996 | 998 | t = ns.inverse_lambda(lambda_s + h)
|
997 | 999 | x_lower, lower_noise_kwargs = lower_update(x, s, t)
|
998 | 1000 | x_higher = higher_update(x, s, t, **lower_noise_kwargs)
|
999 | 1001 | delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
|
1000 |
| - norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) |
| 1002 | + def norm_fn(v): |
| 1003 | + return torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) |
1001 | 1004 | E = norm_fn((x_higher - x_lower) / delta).max()
|
1002 | 1005 | if torch.all(E <= 1.):
|
1003 | 1006 | x = x_higher
|
|
0 commit comments