Skip to content

Commit e5113c2

Browse files
committed
fix: custom verifier type hints
fix: be sure to close pickle file after opening
1 parent c40fe92 commit e5113c2

File tree

1 file changed

+14
-11
lines changed

1 file changed

+14
-11
lines changed

openwakeword/custom_verifier_model.py

+14-11
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,19 @@
1313
# limitations under the License.
1414

1515
# Imports
16-
import os
17-
from tqdm import tqdm
1816
import collections
19-
import openwakeword
20-
import numpy as np
21-
import scipy
17+
import os
2218
import pickle
19+
from typing import List, Union
2320

21+
import numpy as np
22+
import scipy
2423
from sklearn.linear_model import LogisticRegression
2524
from sklearn.pipeline import make_pipeline
2625
from sklearn.preprocessing import FunctionTransformer, StandardScaler
26+
from tqdm import tqdm
27+
28+
import openwakeword
2729

2830

2931
# Define functions to prepare data for speaker dependent verifier model
@@ -112,8 +114,8 @@ def train_verifier_model(features: np.ndarray, labels: np.ndarray):
112114

113115

114116
def train_custom_verifier(
115-
positive_reference_clips: str,
116-
negative_reference_clips: str,
117+
positive_reference_clips: List[Union[str, os.PathLike]],
118+
negative_reference_clips: List[Union[str, os.PathLike]],
117119
output_path: str,
118120
model_name: str,
119121
**kwargs
@@ -123,11 +125,11 @@ def train_custom_verifier(
123125
from a single user.
124126
125127
Args:
126-
positive_reference_clips (str): The path to a directory containing single-channel 16khz, 16-bit WAV files
128+
positive_reference_clips (List[Union[str, os.PathLike]]): The path(s) to single-channel 16khz, 16-bit WAV files
127129
of the target wake word/phrase.
128-
negative_reference_clips (str): The path to a directory containing single-channel 16khz, 16-bit WAV files
130+
negative_reference_clips (List[Union[str, os.PathLike]]): The path(s) to single-channel 16khz, 16-bit WAV files
129131
of miscellaneous speech not containing the target wake word/phrase.
130-
output_path (str): The location to save the trained verifier model (as a scikit-learn .joblib file)
132+
output_path (str): The location to save the trained verifier model (as a Python pickle file (.pkl))
131133
model_name (str): The name or path of the trained openWakeWord model that the verifier model will be
132134
based on. If only a name, it must be one of the pre-trained models included in the
133135
openWakeWord release.
@@ -171,4 +173,5 @@ def train_custom_verifier(
171173

172174
# Save logistic regression model to specified output location
173175
print("Done!")
174-
pickle.dump(lr_model, open(output_path, "wb"))
176+
with open(output_path, "wb") as f:
177+
pickle.dump(lr_model, f)

0 commit comments

Comments
 (0)