Skip to content

Commit 95a461e

Browse files
author
Li Fuchen
authored
Add interface supporting double type (#163)
* add float64 input for warpctc * add comment to function of double
1 parent fc7f226 commit 95a461e

File tree

5 files changed

+234
-17
lines changed

5 files changed

+234
-17
lines changed

include/ctc.h

+75-5
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,9 @@ struct ctcOptions {
6565
int blank_label;
6666
};
6767

68-
/** Compute the connectionist temporal classification loss between a sequence
69-
* of probabilities and a ground truth labeling. Optionally compute the
70-
* gradient with respect to the inputs.
68+
/** Compute the connectionist temporal classification loss between
69+
* a probability sequence with dtype float and a ground truth labeling.
70+
* Optionally compute the gradient with respect to the inputs.
7171
* \param [in] activations pointer to the activations in either CPU or GPU
7272
* addressable memory, depending on options.loc. We assume a fixed
7373
* memory layout for this 3 dimensional tensor, which has dimension
@@ -112,10 +112,57 @@ API_REFERENCE ctcStatus_t compute_ctc_loss(const float* const activations,
112112
void *workspace,
113113
ctcOptions options);
114114

115+
/** Compute the connectionist temporal classification loss between
116+
* a probability sequence of dtype double and a ground truth labeling.
117+
* Optionally compute the gradient with respect to the inputs.
118+
* \param [in] activations pointer to the activations in either CPU or GPU
119+
* addressable memory, depending on options.loc. We assume a fixed
120+
* memory layout for this 3 dimensional tensor, which has dimension
121+
* (t, n, p), where t is the time index, n is the minibatch index,
122+
* and p indexes over probabilities of each symbol in the alphabet.
123+
* The memory layout is (t, n, p) in C order (slowest to fastest changing
124+
* index, aka row-major), or (p, n, t) in Fortran order (fastest to slowest
125+
* changing index, aka column-major). We also assume strides are equal to
126+
* dimensions - there is no padding between dimensions.
127+
* More precisely, element (t, n, p), for a problem with mini_batch examples
128+
* in the mini batch, and alphabet_size symbols in the alphabet, is located at:
129+
* activations[(t * mini_batch + n) * alphabet_size + p]
130+
* \param [out] gradients if not NULL, then gradients are computed. Should be
131+
* allocated in the same memory space as probs and memory
132+
* ordering is identical.
133+
* \param [in] flat_labels Always in CPU memory. A concatenation
134+
* of all the labels for the minibatch.
135+
* \param [in] label_lengths Always in CPU memory. The length of each label
136+
* for each example in the minibatch.
137+
* \param [in] input_lengths Always in CPU memory. The number of time steps
138+
* for each sequence in the minibatch.
139+
* \param [in] alphabet_size The number of possible output symbols. There
140+
* should be this many probabilities for each time step.
141+
* \param [in] minibatch How many examples in a minibatch.
142+
* \param [out] costs Always in CPU memory. The cost of each example in the
143+
* minibatch.
144+
* \param [in,out] workspace In same memory space as probs. Should be of
145+
* size requested by get_workspace_size_double.
146+
* \param [in] options see struct ctcOptions
147+
*
148+
* \return Status information
149+
*
150+
* */
151+
API_REFERENCE ctcStatus_t compute_ctc_loss_double(const double* const activations,
152+
double* gradients,
153+
const int* const flat_labels,
154+
const int* const label_lengths,
155+
const int* const input_lengths,
156+
int alphabet_size,
157+
int minibatch,
158+
double *costs,
159+
void *workspace,
160+
ctcOptions options);
161+
115162

116163
/** For a given set of labels and minibatch size return the required workspace
117-
* size. This will need to be allocated in the same memory space as your
118-
* probabilities.
164+
* size when the dtype of your probabilities is float. This will need to be allocated
165+
* in the same memory space as your probabilities.
119166
* \param [in] label_lengths Always in CPU memory. The length of each label
120167
* for each example in the minibatch.
121168
* \param [in] input_lengths Always in CPU memory. The number of time steps
@@ -136,6 +183,29 @@ API_REFERENCE ctcStatus_t get_workspace_size(const int* const label_lengths,
136183
ctcOptions info,
137184
size_t* size_bytes);
138185

186+
/** For a given set of labels and minibatch size return the required workspace
187+
* size when the dtype of your probabilities is double. This will need to be allocated
188+
* in the same memory space as your probabilities.
189+
* \param [in] label_lengths Always in CPU memory. The length of each label
190+
* for each example in the minibatch.
191+
* \param [in] input_lengths Always in CPU memory. The number of time steps
192+
* for each sequence in the minibatch.
193+
* \param [in] alphabet_size How many symbols in the alphabet or, equivalently,
194+
* the number of probabilities at each time step
195+
* \param [in] minibatch How many examples in a minibatch.
196+
* \param [in] info see struct ctcOptions
197+
* \param [out] size_bytes is pointer to a scalar where the memory
198+
* requirement in bytes will be placed. This memory should be allocated
199+
* at the same place, CPU or GPU, that the probs are in
200+
*
201+
* \return Status information
202+
**/
203+
API_REFERENCE ctcStatus_t get_workspace_size_double(const int* const label_lengths,
204+
const int* const input_lengths,
205+
int alphabet_size, int minibatch,
206+
ctcOptions info,
207+
size_t* size_bytes);
208+
139209
#ifdef __cplusplus
140210
}
141211
#endif

include/detail/gpu_ctc.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ GpuCTC<ProbT>::compute_probs(const ProbT* const activations) {
367367

368368
// Numerically stable SM
369369
ctcStatus_t ctc_status =
370-
reduce_max(probs_, denoms_, out_dim_,
370+
reduce_max<ProbT>(probs_, denoms_, out_dim_,
371371
activation_cols_, 1, stream_);
372372
if (ctc_status != CTC_STATUS_SUCCESS)
373373
return ctc_status;
@@ -385,7 +385,7 @@ GpuCTC<ProbT>::compute_probs(const ProbT* const activations) {
385385

386386
// Reduce along columns to calculate denominator
387387
ctc_status =
388-
reduce_exp(probs_, denoms_, out_dim_,
388+
reduce_exp<ProbT>(probs_, denoms_, out_dim_,
389389
activation_cols_, 1, stream_);
390390
if (ctc_status != CTC_STATUS_SUCCESS)
391391
return ctc_status;

include/detail/reduce.h

+6-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
#pragma once
22

3-
ctcStatus_t reduce_negate(const float* input, float* output, int rows, int cols, bool axis, cudaStream_t stream);
4-
ctcStatus_t reduce_exp(const float* input, float* output, int rows, int cols, bool axis, cudaStream_t stream);
5-
ctcStatus_t reduce_max(const float* input, float* output, int rows, int cols, bool axis, cudaStream_t stream);
3+
template <typename T>
4+
ctcStatus_t reduce_negate(const T* input, T* output, int rows, int cols, bool axis, cudaStream_t stream);
5+
template <typename T>
6+
ctcStatus_t reduce_exp(const T* input, T* output, int rows, int cols, bool axis, cudaStream_t stream);
7+
template <typename T>
8+
ctcStatus_t reduce_max(const T* input, T* output, int rows, int cols, bool axis, cudaStream_t stream);

src/ctc_entrypoint.cpp

+136
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include <cstddef>
22
#include <iostream>
33
#include <algorithm>
4+
#include <cstdio>
45

56
#include <ctc.h>
67

@@ -89,6 +90,59 @@ ctcStatus_t compute_ctc_loss(const float* const activations,
8990
}
9091
}
9192

93+
ctcStatus_t compute_ctc_loss_double(const double* const activations,
94+
double* gradients,
95+
const int* const flat_labels,
96+
const int* const label_lengths,
97+
const int* const input_lengths,
98+
int alphabet_size,
99+
int minibatch,
100+
double *costs,
101+
void *workspace,
102+
ctcOptions options) {
103+
if (activations == nullptr ||
104+
flat_labels == nullptr ||
105+
label_lengths == nullptr ||
106+
input_lengths == nullptr ||
107+
costs == nullptr ||
108+
workspace == nullptr ||
109+
alphabet_size <= 0 ||
110+
minibatch <= 0)
111+
return CTC_STATUS_INVALID_VALUE;
112+
113+
if (options.loc == CTC_CPU) {
114+
CpuCTC<double> ctc(alphabet_size, minibatch, workspace, options.num_threads,
115+
options.blank_label);
116+
117+
if (gradients != NULL)
118+
return ctc.cost_and_grad(activations, gradients,
119+
costs,
120+
flat_labels, label_lengths,
121+
input_lengths);
122+
else
123+
return ctc.score_forward(activations, costs, flat_labels,
124+
label_lengths, input_lengths);
125+
} else if (options.loc == CTC_GPU) {
126+
#ifdef __CUDACC__
127+
GpuCTC<double> ctc(alphabet_size, minibatch, workspace, options.stream,
128+
options.blank_label);
129+
130+
if (gradients != NULL)
131+
return ctc.cost_and_grad(activations, gradients, costs,
132+
flat_labels, label_lengths,
133+
input_lengths);
134+
else
135+
return ctc.score_forward(activations, costs, flat_labels,
136+
label_lengths, input_lengths);
137+
#else
138+
std::cerr << "GPU execution requested, but not compiled with GPU support" << std::endl;
139+
return CTC_STATUS_EXECUTION_FAILED;
140+
#endif
141+
} else {
142+
return CTC_STATUS_INVALID_VALUE;
143+
}
144+
}
145+
92146

93147
ctcStatus_t get_workspace_size(const int* const label_lengths,
94148
const int* const input_lengths,
@@ -172,4 +226,86 @@ ctcStatus_t get_workspace_size(const int* const label_lengths,
172226
return CTC_STATUS_SUCCESS;
173227
}
174228

229+
ctcStatus_t get_workspace_size_double(const int* const label_lengths,
230+
const int* const input_lengths,
231+
int alphabet_size, int minibatch,
232+
ctcOptions options,
233+
size_t* size_bytes)
234+
{
235+
if (label_lengths == nullptr ||
236+
input_lengths == nullptr ||
237+
size_bytes == nullptr ||
238+
alphabet_size <= 0 ||
239+
minibatch <= 0)
240+
return CTC_STATUS_INVALID_VALUE;
241+
242+
// This is the max of all S and T for all examples in the minibatch.
243+
int maxL = *std::max_element(label_lengths, label_lengths + minibatch);
244+
int maxT = *std::max_element(input_lengths, input_lengths + minibatch);
245+
246+
const int S = 2 * maxL + 1;
247+
248+
*size_bytes = 0;
249+
250+
if (options.loc == CTC_GPU) {
251+
// GPU storage
252+
//nll_forward, nll_backward
253+
*size_bytes += 2 * sizeof(double) * minibatch;
254+
255+
//repeats
256+
*size_bytes += sizeof(int) * minibatch;
257+
258+
//label offsets
259+
*size_bytes += sizeof(int) * minibatch;
260+
261+
//utt_length
262+
*size_bytes += sizeof(int) * minibatch;
263+
264+
//label lengths
265+
*size_bytes += sizeof(int) * minibatch;
266+
267+
//labels without blanks - overallocate for now
268+
*size_bytes += sizeof(int) * maxL * minibatch;
269+
270+
//labels with blanks
271+
*size_bytes += sizeof(int) * S * minibatch;
272+
273+
//alphas
274+
*size_bytes += sizeof(double) * S * maxT * minibatch;
275+
276+
//denoms
277+
*size_bytes += sizeof(double) * maxT * minibatch;
278+
279+
//probs (since we will pass in activations)
280+
*size_bytes += sizeof(double) * alphabet_size * maxT * minibatch;
281+
282+
} else {
283+
//cpu can eventually replace all minibatch with
284+
//max number of concurrent threads if memory is
285+
//really tight
286+
287+
//per minibatch memory
288+
size_t per_minibatch_bytes = 0;
289+
290+
//output
291+
per_minibatch_bytes += sizeof(double) * alphabet_size ;
292+
293+
//alphas
294+
per_minibatch_bytes += sizeof(double) * S * maxT;
295+
296+
//betas
297+
per_minibatch_bytes += sizeof(double) * S;
298+
299+
//labels w/blanks, e_inc, s_inc
300+
per_minibatch_bytes += 3 * sizeof(int) * S;
301+
302+
*size_bytes = per_minibatch_bytes * minibatch;
303+
304+
//probs
305+
*size_bytes += sizeof(double) * alphabet_size * maxT * minibatch;
306+
}
307+
308+
return CTC_STATUS_SUCCESS;
309+
}
310+
175311
}

src/reduce.cu

+15-7
Original file line numberDiff line numberDiff line change
@@ -148,15 +148,23 @@ ctcStatus_t reduce(Iof f, Rof g, const T* input, T* output, int rows, int cols,
148148

149149
return CTC_STATUS_SUCCESS;
150150
}
151-
152-
ctcStatus_t reduce_negate(const float *input, float *output, int rows, int cols, bool axis, cudaStream_t stream) {
153-
return reduce(ctc_helper::negate<float>(), ctc_helper::add<float>(), input, output, rows, cols, axis, stream);
151+
template<typename T>
152+
ctcStatus_t reduce_negate(const T *input, T *output, int rows, int cols, bool axis, cudaStream_t stream) {
153+
return reduce(ctc_helper::negate<T>(), ctc_helper::add<T>(), input, output, rows, cols, axis, stream);
154154
}
155+
template ctcStatus_t reduce_negate<float>(const float *input, float *output, int rows, int cols, bool axis, cudaStream_t stream);
156+
template ctcStatus_t reduce_negate<double>(const double *input, double *output, int rows, int cols, bool axis, cudaStream_t stream);
155157

156-
ctcStatus_t reduce_exp(const float *input, float *output, int rows, int cols, bool axis, cudaStream_t stream) {
157-
return reduce(ctc_helper::exponential<float>(), ctc_helper::add<float>(), input, output, rows, cols, axis, stream);
158+
template<typename T>
159+
ctcStatus_t reduce_exp(const T *input, T *output, int rows, int cols, bool axis, cudaStream_t stream) {
160+
return reduce(ctc_helper::exponential<T>(), ctc_helper::add<T>(), input, output, rows, cols, axis, stream);
158161
}
162+
template ctcStatus_t reduce_exp<float>(const float *input, float *output, int rows, int cols, bool axis, cudaStream_t stream);
163+
template ctcStatus_t reduce_exp<double>(const double *input, double *output, int rows, int cols, bool axis, cudaStream_t stream);
159164

160-
ctcStatus_t reduce_max(const float *input, float *output, int rows, int cols, bool axis, cudaStream_t stream) {
161-
return reduce(ctc_helper::identity<float>(), ctc_helper::maximum<float>(),input, output, rows, cols, axis, stream);
165+
template<typename T>
166+
ctcStatus_t reduce_max(const T *input, T *output, int rows, int cols, bool axis, cudaStream_t stream) {
167+
return reduce(ctc_helper::identity<T>(), ctc_helper::maximum<T>(),input, output, rows, cols, axis, stream);
162168
}
169+
template ctcStatus_t reduce_max<float>(const float *input, float *output, int rows, int cols, bool axis, cudaStream_t stream);
170+
template ctcStatus_t reduce_max<double>(const double *input, double *output, int rows, int cols, bool axis, cudaStream_t stream);

0 commit comments

Comments
 (0)