
Commit c90ad94

Merge pull request #188 from stas6626/master
Online Decoding
2 parents 9a20e00 + 930672d

9 files changed (+642 / -120 lines)

README.md (+30 lines)

@@ -64,6 +64,36 @@ beam_results, beam_scores, timesteps, out_lens = decoder.decode(output)
1. `timesteps` - Shape: BATCHSIZE x N_BEAMS. The timestep at which the nth output character has peak probability. Can be used as alignment between the audio and the transcript.
1. `out_lens` - Shape: BATCHSIZE x N_BEAMS. `out_lens[i][j]` is the length of the jth beam_result, of item i of your batch.

### Online decoding

```python
import torch
from ctcdecode import OnlineCTCBeamDecoder, DecoderState

decoder = OnlineCTCBeamDecoder(
    labels,
    model_path=None,
    alpha=0,
    beta=0,
    cutoff_top_n=40,
    cutoff_prob=1.0,
    beam_width=100,
    num_processes=4,
    blank_id=0,
    log_probs_input=False
)

# One DecoderState per data source; it accumulates the chunks you push for that source.
state1 = DecoderState(decoder)

probs_seq = torch.FloatTensor([probs_seq])

# Push the first chunk; is_eos is False because more chunks will follow.
beam_results, beam_scores, timesteps, out_seq_len = decoder.decode(probs_seq[:, :2], [state1], [False])
# Push the final chunk with is_eos True to get the finished hypotheses.
beam_results, beam_scores, timesteps, out_seq_len = decoder.decode(probs_seq[:, 2:], [state1], [True])
```
The online decoder mirrors the `CTCBeamDecoder` interface, but additionally takes a sequence of `states` and a sequence of `is_eos_s` flags.

States are used to accumulate the sequence of chunks, one state per data source. `is_eos_s` tells the decoder whether chunks have stopped being pushed to the corresponding state.
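As a rough sketch of how this extends to multiple streams (the chunk lists `chunks_a` and `chunks_b` below are placeholders for your own streaming pipeline, assumed to be equal-length lists of `1 x chunk_timesteps x num_labels` tensors), each source gets its own `DecoderState` and the flag is set per source:

```python
import torch
from ctcdecode import OnlineCTCBeamDecoder, DecoderState

decoder = OnlineCTCBeamDecoder(labels, beam_width=100, blank_id=0)

# One state per source; each state accumulates only the chunks of its own stream.
states = [DecoderState(decoder), DecoderState(decoder)]

for i, (chunk_a, chunk_b) in enumerate(zip(chunks_a, chunks_b)):
    probs = torch.cat([chunk_a, chunk_b], dim=0)  # batch of 2, one row per source
    is_last = i == len(chunks_a) - 1
    beam_results, beam_scores, timesteps, out_lens = decoder.decode(
        probs, states, [is_last, is_last]
    )

# After the final call, row k of beam_results holds the beams for source k.
```
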
### More examples

Get the top beam for the first item in your batch
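A minimal sketch of that lookup, assuming the `beam_results`, `out_lens`, and `labels` variables from the batched example above:

```python
# beam_results is BATCHSIZE x N_BEAMS x N_TIMESTEPS and still holds label indices;
# out_lens[i][j] gives the valid length of beam j for batch item i.
best_beam = beam_results[0][0][:out_lens[0][0]]
transcript = "".join(labels[n] for n in best_beam)
```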

ctcdecode/__init__.py (+183 / -12 lines)

@@ -1,4 +1,5 @@
 import torch
+
 from ._ext import ctc_decode

@@ -17,13 +18,24 @@ class CTCBeamDecoder(object):
         cutoff_prob (float): Cutoff probability in pruning. 1.0 means no pruning.
         beam_width (int): This controls how broad the beam search is. Higher values are more likely to find top beams,
             but they also will make your beam search exponentially slower.
-        num_processes (int): Parallelize the batch using num_processes workers.
+        num_processes (int): Parallelize the batch using num_processes workers.
         blank_id (int): Index of the CTC blank token (probably 0) used when training your model.
         log_probs_input (bool): False if your model has passed through a softmax and output probabilities sum to 1.
     """

-    def __init__(self, labels, model_path=None, alpha=0, beta=0, cutoff_top_n=40, cutoff_prob=1.0, beam_width=100,
-                 num_processes=4, blank_id=0, log_probs_input=False):
+    def __init__(
+        self,
+        labels,
+        model_path=None,
+        alpha=0,
+        beta=0,
+        cutoff_top_n=40,
+        cutoff_prob=1.0,
+        beam_width=100,
+        num_processes=4,
+        blank_id=0,
+        log_probs_input=False,
+    ):
         self.cutoff_top_n = cutoff_top_n
         self._beam_width = beam_width
         self._scorer = None

@@ -33,8 +45,9 @@ def __init__(self, labels, model_path=None, alpha=0, beta=0, cutoff_top_n=40, cu
         self._blank_id = blank_id
         self._log_probs = 1 if log_probs_input else 0
         if model_path:
-            self._scorer = ctc_decode.paddle_get_scorer(alpha, beta, model_path.encode(), self._labels,
-                                                        self._num_labels)
+            self._scorer = ctc_decode.paddle_get_scorer(
+                alpha, beta, model_path.encode(), self._labels, self._num_labels
+            )
         self._cutoff_prob = cutoff_prob

     def decode(self, probs, seq_lens=None):

@@ -72,14 +85,40 @@ def decode(self, probs, seq_lens=None):
         scores = torch.FloatTensor(batch_size, self._beam_width).cpu().float()
         out_seq_len = torch.zeros(batch_size, self._beam_width).cpu().int()
         if self._scorer:
-            ctc_decode.paddle_beam_decode_lm(probs, seq_lens, self._labels, self._num_labels, self._beam_width,
-                                             self._num_processes, self._cutoff_prob, self.cutoff_top_n, self._blank_id,
-                                             self._log_probs, self._scorer, output, timesteps, scores, out_seq_len)
+            ctc_decode.paddle_beam_decode_lm(
+                probs,
+                seq_lens,
+                self._labels,
+                self._num_labels,
+                self._beam_width,
+                self._num_processes,
+                self._cutoff_prob,
+                self.cutoff_top_n,
+                self._blank_id,
+                self._log_probs,
+                self._scorer,
+                output,
+                timesteps,
+                scores,
+                out_seq_len,
+            )
         else:
-            ctc_decode.paddle_beam_decode(probs, seq_lens, self._labels, self._num_labels, self._beam_width,
-                                          self._num_processes,
-                                          self._cutoff_prob, self.cutoff_top_n, self._blank_id, self._log_probs,
-                                          output, timesteps, scores, out_seq_len)
+            ctc_decode.paddle_beam_decode(
+                probs,
+                seq_lens,
+                self._labels,
+                self._num_labels,
+                self._beam_width,
+                self._num_processes,
+                self._cutoff_prob,
+                self.cutoff_top_n,
+                self._blank_id,
+                self._log_probs,
+                output,
+                timesteps,
+                scores,
+                out_seq_len,
+            )

         return output, scores, timesteps, out_seq_len

@@ -99,3 +138,135 @@ def reset_params(self, alpha, beta):
     def __del__(self):
         if self._scorer is not None:
             ctc_decode.paddle_release_scorer(self._scorer)
+
+
+class OnlineCTCBeamDecoder(object):
+    """
+    PyTorch wrapper for the DeepSpeech PaddlePaddle Beam Search Decoder, with an interface for online decoding.
+    Args:
+        labels (list): The tokens/vocab used to train your model.
+            They should be in the same order as they are in your model's outputs.
+        model_path (basestring): The path to your external KenLM language model (LM).
+        alpha (float): Weighting associated with the LM probabilities.
+            A weight of 0 means the LM has no effect.
+        beta (float): Weight associated with the number of words within our beam.
+        cutoff_top_n (int): Cutoff number in pruning. Only the top cutoff_top_n characters
+            with the highest probability in the vocab will be used in beam search.
+        cutoff_prob (float): Cutoff probability in pruning. 1.0 means no pruning.
+        beam_width (int): This controls how broad the beam search is. Higher values are more likely to find top beams,
+            but they also will make your beam search exponentially slower.
+        num_processes (int): Parallelize the batch using num_processes workers.
+        blank_id (int): Index of the CTC blank token (probably 0) used when training your model.
+        log_probs_input (bool): False if your model has passed through a softmax and output probabilities sum to 1.
+    """
+    def __init__(
+        self,
+        labels,
+        model_path=None,
+        alpha=0,
+        beta=0,
+        cutoff_top_n=40,
+        cutoff_prob=1.0,
+        beam_width=100,
+        num_processes=4,
+        blank_id=0,
+        log_probs_input=False,
+    ):
+        self._cutoff_top_n = cutoff_top_n
+        self._beam_width = beam_width
+        self._scorer = None
+        self._num_processes = num_processes
+        self._labels = list(labels)  # Ensure labels are a list
+        self._num_labels = len(labels)
+        self._blank_id = blank_id
+        self._log_probs = 1 if log_probs_input else 0
+        if model_path:
+            self._scorer = ctc_decode.paddle_get_scorer(
+                alpha, beta, model_path.encode(), self._labels, self._num_labels
+            )
+        self._cutoff_prob = cutoff_prob
+
+    def decode(self, probs, states, is_eos_s, seq_lens=None):
+        """
+        Conducts the beam search on model outputs and returns results.
+        Args:
+            probs (Tensor) - A rank 3 tensor representing model outputs. Shape is batch x num_timesteps x num_labels.
+            states (Sequence[DecoderState]) - A sequence of decoding states with length equal to the batch size.
+            is_eos_s (Sequence[bool]) - A sequence of bools with length equal to the batch size.
+                Should be False while you haven't pushed all chunks yet, and True once you have pushed the last chunk and want to get an answer.
+            seq_lens (Tensor) - A rank 1 tensor representing the sequence length of the items in the batch. Optional,
+                if not provided the size of axis 1 (num_timesteps) of `probs` is used for all items.
+
+        Returns:
+            tuple: (beam_results, beam_scores, timesteps, out_lens)
+
+            beam_results (Tensor): A 3-dim tensor representing the top n beams of a batch of items.
+                Shape: batchsize x num_beams x num_timesteps.
+                Results are still encoded as ints at this stage.
+            beam_scores (Tensor): A 3-dim tensor representing the likelihood of each beam in beam_results.
+                Shape: batchsize x num_beams x num_timesteps
+            timesteps (Tensor): A 2-dim tensor representing the timesteps at which the nth output character
+                has peak probability.
+                To be used as alignment between audio and transcript.
+                Shape: batchsize x num_beams
+            out_lens (Tensor): A 2-dim tensor representing the length of each beam in beam_results.
+                Shape: batchsize x n_beams.
+        """
+        probs = probs.cpu().float()
+        batch_size, max_seq_len = probs.size(0), probs.size(1)
+        if seq_lens is None:
+            seq_lens = torch.IntTensor(batch_size).fill_(max_seq_len)
+        else:
+            seq_lens = seq_lens.cpu().int()
+        scores = torch.FloatTensor(batch_size, self._beam_width).cpu().float()
+        out_seq_len = torch.zeros(batch_size, self._beam_width).cpu().int()
+
+        decode_fn = ctc_decode.paddle_beam_decode_with_given_state
+        res_beam_results, res_timesteps = decode_fn(
+            probs,
+            seq_lens,
+            self._num_processes,
+            [state.state for state in states],
+            is_eos_s,
+            scores,
+            out_seq_len
+        )
+        res_beam_results = res_beam_results.int()
+        res_timesteps = res_timesteps.int()
+
+        return res_beam_results, scores, res_timesteps, out_seq_len
+
+    def character_based(self):
+        return ctc_decode.is_character_based(self._scorer) if self._scorer else None
+
+    def max_order(self):
+        return ctc_decode.get_max_order(self._scorer) if self._scorer else None
+
+    def dict_size(self):
+        return ctc_decode.get_dict_size(self._scorer) if self._scorer else None
+
+    def reset_state(self, state):
+        ctc_decode.paddle_release_state(state)
+
+
+class DecoderState:
+    """
+    Class used to maintain the chunks of data pushed to one beam-search run, corresponding to one unique source.
+    Note: after using a state you should delete it; don't reuse it.
+    Args:
+        decoder (OnlineCTCBeamDecoder) - The decoder you will use for decoding.
+    """
+    def __init__(self, decoder):
+        self.state = ctc_decode.paddle_get_decoder_state(
+            decoder._labels,
+            decoder._beam_width,
+            decoder._cutoff_prob,
+            decoder._cutoff_top_n,
+            decoder._blank_id,
+            decoder._log_probs,
+            decoder._scorer,
+        )
+
+    def __del__(self):
+        ctc_decode.paddle_release_state(self.state)
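
To make the state lifecycle in the `DecoderState` docstring concrete, here is a hedged sketch (the chunk lists `first_utterance_chunks`/`second_utterance_chunks`, the helper `decode_stream`, and `labels` are placeholders, not part of the library) of decoding one utterance and then starting a fresh one with a new state rather than reusing the finished one:

```python
from ctcdecode import OnlineCTCBeamDecoder, DecoderState

decoder = OnlineCTCBeamDecoder(labels)

def decode_stream(decoder, chunks):
    """Push chunks one by one; only the final call (is_eos=True) yields the finished beams."""
    state = DecoderState(decoder)
    results = None
    for i, chunk in enumerate(chunks):  # chunks: list of 1 x chunk_timesteps x num_labels tensors
        is_eos = i == len(chunks) - 1
        results = decoder.decode(chunk, [state], [is_eos])
    del state  # per the DecoderState note: delete after use, don't reuse it
    return results

beam_results, beam_scores, timesteps, out_lens = decode_stream(decoder, first_utterance_chunks)
# For the next utterance, decode_stream builds a fresh DecoderState internally.
beam_results, beam_scores, timesteps, out_lens = decode_stream(decoder, second_utterance_chunks)
```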
