13
13
# limitations under the License.
14
14
15
15
# Imports
16
- import os
17
- from tqdm import tqdm
18
16
import collections
19
- import openwakeword
20
- import numpy as np
21
- import scipy
17
+ import os
22
18
import pickle
19
+ from typing import List , Union
23
20
21
+ import numpy as np
22
+ import scipy
24
23
from sklearn .linear_model import LogisticRegression
25
24
from sklearn .pipeline import make_pipeline
26
25
from sklearn .preprocessing import FunctionTransformer , StandardScaler
26
+ from tqdm import tqdm
27
+
28
+ import openwakeword
27
29
28
30
29
31
# Define functions to prepare data for speaker dependent verifier model
@@ -112,8 +114,8 @@ def train_verifier_model(features: np.ndarray, labels: np.ndarray):
112
114
113
115
114
116
def train_custom_verifier (
115
- positive_reference_clips : str ,
116
- negative_reference_clips : str ,
117
+ positive_reference_clips : List [ Union [ str , os . PathLike ]] ,
118
+ negative_reference_clips : List [ Union [ str , os . PathLike ]] ,
117
119
output_path : str ,
118
120
model_name : str ,
119
121
** kwargs
@@ -123,11 +125,11 @@ def train_custom_verifier(
123
125
from a single user.
124
126
125
127
Args:
126
- positive_reference_clips (str): The path to a directory containing single-channel 16khz, 16-bit WAV files
128
+ positive_reference_clips (List[Union[ str, os.PathLike]] ): The path(s) to single-channel 16khz, 16-bit WAV files
127
129
of the target wake word/phrase.
128
- negative_reference_clips (str): The path to a directory containing single-channel 16khz, 16-bit WAV files
130
+ negative_reference_clips (List[Union[ str, os.PathLike]] ): The path(s) to single-channel 16khz, 16-bit WAV files
129
131
of miscellaneous speech not containing the target wake word/phrase.
130
- output_path (str): The location to save the trained verifier model (as a scikit-learn .joblib file)
132
+ output_path (str): The location to save the trained verifier model (as a Python pickle file (.pkl) )
131
133
model_name (str): The name or path of the trained openWakeWord model that the verifier model will be
132
134
based on. If only a name, it must be one of the pre-trained models included in the
133
135
openWakeWord release.
@@ -171,4 +173,5 @@ def train_custom_verifier(
171
173
172
174
# Save logistic regression model to specified output location
173
175
print ("Done!" )
174
- pickle .dump (lr_model , open (output_path , "wb" ))
176
+ with open (output_path , "wb" ) as f :
177
+ pickle .dump (lr_model , f )
0 commit comments