Skip to content

Commit 55dd869

Browse files
fix(rmvpe): pass device when loading torch model (svc-develop-team#301)
1 parent e7b4785 commit 55dd869

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

modules/F0Predictor/rmvpe/inference.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def __init__(self, model_path, device=None, dtype = torch.float32, hop_length=16
1616
else:
1717
self.device = device
1818
model = E2E0(4, 1, (2, 2))
19-
ckpt = torch.load(model_path)
19+
ckpt = torch.load(model_path, map_location=torch.device(self.device))
2020
model.load_state_dict(ckpt['model'])
2121
model = model.to(dtype).to(self.device)
2222
model.eval()
@@ -54,4 +54,4 @@ def infer_from_audio(self, audio, sample_rate=16000, thred=0.05, use_viterbi=Fal
5454
mel = mel_extractor(audio_res, center=True).to(self.dtype)
5555
hidden = self.mel2hidden(mel)
5656
f0 = self.decode(hidden.squeeze(0), thred=thred, use_viterbi=use_viterbi)
57-
return f0
57+
return f0

0 commit comments

Comments
 (0)