Skip to content

Commit f808d8e

Browse files
committed
Debug BF16 and RMVPE
1 parent 3678712 commit f808d8e

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

train.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
140140
if writers is not None:
141141
writer, writer_eval = writers
142142

143-
half_type = torch.float16 if hps.train.half_type=="fp16" else torch.bfloat16
143+
half_type = torch.bfloat16 if hps.train.half_type=="bf16" else torch.float16
144144

145145
# train_loader.batch_sampler.set_epoch(epoch)
146146
global global_step

utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def get_f0_predictor(f0_predictor,hop_length,sampling_rate,**kargs):
9999
f0_predictor_object = DioF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate)
100100
elif f0_predictor == "rmvpe":
101101
from modules.F0Predictor.RMVPEF0Predictor import RMVPEF0Predictor
102-
f0_predictor_object = RMVPEF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate,dtype=torch.float16 ,device=kargs["device"],threshold=kargs["threshold"])
102+
f0_predictor_object = RMVPEF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate,dtype=torch.float32 ,device=kargs["device"],threshold=kargs["threshold"])
103103
else:
104104
raise Exception("Unknown f0 predictor")
105105
return f0_predictor_object

0 commit comments

Comments
 (0)