Skip to content

Commit c40fe92

Browse files
committed
fixed bug in train.py when using longer training examples than normal [skip ci]
1 parent fe57deb commit c40fe92

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

openwakeword/train.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -813,13 +813,13 @@ def convert_onnx_to_tflite(onnx_model_path, output_path):
813813
# Create openwakeword model
814814
if args.train_model is True:
815815
F = openwakeword.utils.AudioFeatures(device='cpu')
816-
input_shape = F.get_embedding_shape(config["total_length"]//16000) # training data is always 16 khz
816+
input_shape = np.load(os.path.join(feature_save_dir, "positive_features_test.npy")).shape[1:]
817817

818818
oww = Model(n_classes=1, input_shape=input_shape, model_type=config["model_type"],
819819
layer_dim=config["layer_size"], seconds_per_example=1280*input_shape[0]/16000)
820820

821821
# Create data transform function for batch generation to handle differ clip lengths (todo: write tests for this)
822-
def f(x, n=16):
822+
def f(x, n=input_shape[0]):
823823
"""Simple transformation function to ensure negative data is the appropriate shape for the model size"""
824824
if n > x.shape[1] or n < x.shape[1]:
825825
x = np.vstack(x)

0 commit comments

Comments
 (0)