import torch
+
from ._ext import ctc_decode
@@ -17,13 +18,24 @@ class CTCBeamDecoder(object):
        cutoff_prob (float): Cutoff probability in pruning. 1.0 means no pruning.
        beam_width (int): This controls how broad the beam search is. Higher values are more likely to find top beams,
            but they also will make your beam search exponentially slower.
-        num_processes (int): Parallelize the batch using num_processes workers.
+        num_processes (int): Parallelize the batch using num_processes workers.
        blank_id (int): Index of the CTC blank token (probably 0) used when training your model.
        log_probs_input (bool): False if your model has passed through a softmax and output probabilities sum to 1.
    """

-    def __init__(self, labels, model_path=None, alpha=0, beta=0, cutoff_top_n=40, cutoff_prob=1.0, beam_width=100,
-                 num_processes=4, blank_id=0, log_probs_input=False):
+    def __init__(
+        self,
+        labels,
+        model_path=None,
+        alpha=0,
+        beta=0,
+        cutoff_top_n=40,
+        cutoff_prob=1.0,
+        beam_width=100,
+        num_processes=4,
+        blank_id=0,
+        log_probs_input=False,
+    ):
        self.cutoff_top_n = cutoff_top_n
        self._beam_width = beam_width
        self._scorer = None
@@ -33,8 +45,9 @@ def __init__(self, labels, model_path=None, alpha=0, beta=0, cutoff_top_n=40, cu
        self._blank_id = blank_id
        self._log_probs = 1 if log_probs_input else 0
        if model_path:
-            self._scorer = ctc_decode.paddle_get_scorer(alpha, beta, model_path.encode(), self._labels,
-                                                        self._num_labels)
+            self._scorer = ctc_decode.paddle_get_scorer(
+                alpha, beta, model_path.encode(), self._labels, self._num_labels
+            )
        self._cutoff_prob = cutoff_prob

    def decode(self, probs, seq_lens=None):
@@ -72,14 +85,40 @@ def decode(self, probs, seq_lens=None):
        scores = torch.FloatTensor(batch_size, self._beam_width).cpu().float()
        out_seq_len = torch.zeros(batch_size, self._beam_width).cpu().int()
        if self._scorer:
-            ctc_decode.paddle_beam_decode_lm(probs, seq_lens, self._labels, self._num_labels, self._beam_width,
-                                             self._num_processes, self._cutoff_prob, self.cutoff_top_n, self._blank_id,
-                                             self._log_probs, self._scorer, output, timesteps, scores, out_seq_len)
+            ctc_decode.paddle_beam_decode_lm(
+                probs,
+                seq_lens,
+                self._labels,
+                self._num_labels,
+                self._beam_width,
+                self._num_processes,
+                self._cutoff_prob,
+                self.cutoff_top_n,
+                self._blank_id,
+                self._log_probs,
+                self._scorer,
+                output,
+                timesteps,
+                scores,
+                out_seq_len,
+            )
        else:
-            ctc_decode.paddle_beam_decode(probs, seq_lens, self._labels, self._num_labels, self._beam_width,
-                                          self._num_processes,
-                                          self._cutoff_prob, self.cutoff_top_n, self._blank_id, self._log_probs,
-                                          output, timesteps, scores, out_seq_len)
+            ctc_decode.paddle_beam_decode(
+                probs,
+                seq_lens,
+                self._labels,
+                self._num_labels,
+                self._beam_width,
+                self._num_processes,
+                self._cutoff_prob,
+                self.cutoff_top_n,
+                self._blank_id,
+                self._log_probs,
+                output,
+                timesteps,
+                scores,
+                out_seq_len,
+            )

        return output, scores, timesteps, out_seq_len

@@ -99,3 +138,135 @@ def reset_params(self, alpha, beta):
    def __del__(self):
        if self._scorer is not None:
            ctc_decode.paddle_release_scorer(self._scorer)
+
+
+class OnlineCTCBeamDecoder(object):
+    """
+    PyTorch wrapper for DeepSpeech PaddlePaddle Beam Search Decoder with an interface for online decoding.
+    Args:
+        labels (list): The tokens/vocab used to train your model.
+                       They should be in the same order as they are in your model's outputs.
+        model_path (basestring): The path to your external KenLM language model (LM).
+        alpha (float): Weighting associated with the LM's probabilities.
+            A weight of 0 means the LM has no effect.
+        beta (float): Weight associated with the number of words within our beam.
+        cutoff_top_n (int): Cutoff number in pruning. Only the top cutoff_top_n characters
+            with the highest probability in the vocab will be used in beam search.
+        cutoff_prob (float): Cutoff probability in pruning. 1.0 means no pruning.
+        beam_width (int): This controls how broad the beam search is. Higher values are more likely to find top beams,
+            but they also will make your beam search exponentially slower.
+        num_processes (int): Parallelize the batch using num_processes workers.
+        blank_id (int): Index of the CTC blank token (probably 0) used when training your model.
+        log_probs_input (bool): False if your model has passed through a softmax and output probabilities sum to 1.
+    """
+    def __init__(
+        self,
+        labels,
+        model_path=None,
+        alpha=0,
+        beta=0,
+        cutoff_top_n=40,
+        cutoff_prob=1.0,
+        beam_width=100,
+        num_processes=4,
+        blank_id=0,
+        log_probs_input=False,
+    ):
+        self._cutoff_top_n = cutoff_top_n
+        self._beam_width = beam_width
+        self._scorer = None
+        self._num_processes = num_processes
+        self._labels = list(labels)  # Ensure labels are a list
+        self._num_labels = len(labels)
+        self._blank_id = blank_id
+        self._log_probs = 1 if log_probs_input else 0
+        if model_path:
+            self._scorer = ctc_decode.paddle_get_scorer(
+                alpha, beta, model_path.encode(), self._labels, self._num_labels
+            )
+        self._cutoff_prob = cutoff_prob
+
+    def decode(self, probs, states, is_eos_s, seq_lens=None):
+        """
+        Conducts the beam search on model outputs and returns results.
+        Args:
+            probs (Tensor) - A rank 3 tensor representing model outputs. Shape is batch x num_timesteps x num_labels.
+            states (Sequence[DecoderState]) - sequence of decoding states with length equal to batch_size.
+            is_eos_s (Sequence[bool]) - sequence of bools with length equal to batch_size.
+                Should be False if you have not pushed all chunks yet, and True once you push the last chunk
+                and want to get an answer.
+            seq_lens (Tensor) - A rank 1 tensor representing the sequence length of the items in the batch. Optional;
+                if not provided, the size of axis 1 (num_timesteps) of `probs` is used for all items.
+
+        Returns:
+            tuple: (beam_results, beam_scores, timesteps, out_lens)
+
+            beam_results (Tensor): A 3-dim tensor representing the top n beams of a batch of items.
+                Shape: batchsize x num_beams x num_timesteps.
+                Results are still encoded as ints at this stage.
+            beam_scores (Tensor): A 2-dim tensor representing the likelihood of each beam in beam_results.
+                Shape: batchsize x num_beams
+            timesteps (Tensor): A 2-dim tensor representing the timesteps at which the nth output character
+                has peak probability.
+                To be used as alignment between audio and transcript.
+                Shape: batchsize x num_beams
+            out_lens (Tensor): A 2-dim tensor representing the length of each beam in beam_results.
+                Shape: batchsize x num_beams.
+
+        """
+        probs = probs.cpu().float()
+        batch_size, max_seq_len = probs.size(0), probs.size(1)
+        if seq_lens is None:
+            seq_lens = torch.IntTensor(batch_size).fill_(max_seq_len)
+        else:
+            seq_lens = seq_lens.cpu().int()
+        scores = torch.FloatTensor(batch_size, self._beam_width).cpu().float()
+        out_seq_len = torch.zeros(batch_size, self._beam_width).cpu().int()
+
+        decode_fn = ctc_decode.paddle_beam_decode_with_given_state
+        res_beam_results, res_timesteps = decode_fn(
+            probs,
+            seq_lens,
+            self._num_processes,
+            [state.state for state in states],
+            is_eos_s,
+            scores,
+            out_seq_len
+        )
+        res_beam_results = res_beam_results.int()
+        res_timesteps = res_timesteps.int()
+
+        return res_beam_results, scores, res_timesteps, out_seq_len
+
+    def character_based(self):
+        return ctc_decode.is_character_based(self._scorer) if self._scorer else None
+
+    def max_order(self):
+        return ctc_decode.get_max_order(self._scorer) if self._scorer else None
+
+    def dict_size(self):
+        return ctc_decode.get_dict_size(self._scorer) if self._scorer else None
+
+    def reset_state(self, state):
+        ctc_decode.paddle_release_state(state)
+
+
+class DecoderState:
+    """
+    Class used to maintain the decoding state of one beam search for one unique audio source across chunks.
+    Note: after using a state you should delete it; do not reuse it for another stream.
+    Args:
+        decoder (OnlineCTCBeamDecoder) - decoder you will use for decoding.
+    """
+    def __init__(self, decoder):
+        self.state = ctc_decode.paddle_get_decoder_state(
+            decoder._labels,
+            decoder._beam_width,
+            decoder._cutoff_prob,
+            decoder._cutoff_top_n,
+            decoder._blank_id,
+            decoder._log_probs,
+            decoder._scorer,
+        )
+
+    def __del__(self):
+        ctc_decode.paddle_release_state(self.state)
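For reference, here is a minimal usage sketch of the offline `CTCBeamDecoder` API reformatted above. It is not part of the diff: the label set and model outputs are made up, and it assumes this module is installed as the `ctcdecode` package.

```python
import torch
from ctcdecode import CTCBeamDecoder

# Hypothetical label set; index 0 is the CTC blank token.
labels = ["_", " ", "a", "b", "c", "d"]

decoder = CTCBeamDecoder(
    labels,
    beam_width=25,
    blank_id=0,
    log_probs_input=True,  # we feed log-softmax outputs below
)

# Stand-in for acoustic model output: batch x num_timesteps x num_labels.
probs = torch.randn(2, 50, len(labels)).log_softmax(dim=2)

beam_results, beam_scores, timesteps, out_lens = decoder.decode(probs)

# Decode the best beam of the first batch item, using out_lens to trim padding.
best = beam_results[0][0][: out_lens[0][0]]
print("".join(labels[idx] for idx in best))
```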
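And a sketch of how the new streaming classes added in this diff are meant to be combined. Again this is not part of the change: the chunk sizes and labels are hypothetical, and the imports assume `OnlineCTCBeamDecoder` and `DecoderState` are exported from the `ctcdecode` package.

```python
import torch
from ctcdecode import OnlineCTCBeamDecoder, DecoderState

labels = ["_", " ", "a", "b", "c", "d"]  # hypothetical; index 0 is the CTC blank
decoder = OnlineCTCBeamDecoder(labels, beam_width=25, blank_id=0, log_probs_input=True)

# One DecoderState per audio stream; per the docstring, do not reuse it.
state = DecoderState(decoder)

# Stand-in for three consecutive chunks of model output for a single stream.
chunks = [torch.randn(1, 20, len(labels)).log_softmax(dim=2) for _ in range(3)]

for i, chunk in enumerate(chunks):
    is_last_chunk = i == len(chunks) - 1
    beam_results, beam_scores, timesteps, out_lens = decoder.decode(
        chunk, [state], [is_last_chunk]
    )

# After the final chunk, read the best hypothesis for stream 0.
best = beam_results[0][0][: out_lens[0][0]]
print("".join(labels[idx] for idx in best))
```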