Skip to content

Commit 3635b23

Browse files
authored
Merge pull request #205 from carlthome/patch-2
Raise errors if sampler_type unhandled
2 parents c1e9c46 + 3dc8e6e commit 3635b23

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

stable_audio_tools/inference/sampling.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,8 @@ def sample_k(
384384
return K.sampling.sample_dpmpp_2m_sde(denoiser, x, sigmas, disable=False, callback=callback, extra_args=extra_args)
385385
elif sampler_type == "dpmpp-3m-sde":
386386
return K.sampling.sample_dpmpp_3m_sde(denoiser, x, sigmas, disable=False, callback=callback, extra_args=extra_args)
387+
else:
388+
raise ValueError(f"Unknown sampler_type: {sampler_type}")
387389
elif is_v_diff:
388390

389391
if sigma_max > 1: # sigma_max should be between 0 and 1
@@ -454,4 +456,6 @@ def sample_rf(
454456
elif sampler_type == "dpmpp":
455457
return sample_flow_dpmpp(model_fn, x, sigmas=t, sigma_max=sigma_max, callback=callback, **extra_args)
456458
elif sampler_type == "pingpong":
457-
return sample_flow_pingpong(model_fn, x, sigmas=t, sigma_max=sigma_max, callback=callback, **extra_args)
459+
return sample_flow_pingpong(model_fn, x, sigmas=t, sigma_max=sigma_max, callback=callback, **extra_args)
460+
else:
461+
raise ValueError(f"Unknown sampler_type: {sampler_type}")

0 commit comments

Comments
 (0)