Skip to content

Commit 3952a3c

Browse files
committed
Remove a CUDA stream synchronize.
1 parent 3a2ca8f commit 3952a3c

File tree

2 files changed

+22
-59
lines changed

2 files changed

+22
-59
lines changed

include/detail/gpu_ctc.h

+12-59
Original file line numberDiff line numberDiff line change
@@ -204,30 +204,14 @@ GpuCTC<ProbT>::setup_gpu_metadata(const int* const flat_labels,
204204
Lmax = std::max(Lmax, L);
205205
}
206206

207-
#ifdef __HIPCC__
208-
cuda_status = hipMemcpyAsync(&(repeats_[start_idx]), repeats,
209-
(end_idx - start_idx) * sizeof(int),
210-
hipMemcpyHostToDevice, stream_);
211-
#else
212-
cuda_status = cudaMemcpyAsync(&(repeats_[start_idx]), repeats,
213-
(end_idx - start_idx) * sizeof(int),
214-
cudaMemcpyHostToDevice, stream_);
215-
#endif
216-
207+
cuda_status = warpctc::memcpy_h2d_async(
208+
&(repeats_[start_idx]), repeats, (end_idx - start_idx) * sizeof(int), stream_);
217209
if (cuda_status != gpuSuccess)
218210
return CTC_STATUS_MEMOPS_FAILED;
219211

220212

221-
#ifdef __HIPCC__
222-
cuda_status = hipMemcpyAsync(&(label_offsets_[start_idx]), label_offsets,
223-
(end_idx - start_idx) * sizeof(int),
224-
hipMemcpyHostToDevice, stream_);
225-
#else
226-
cuda_status = cudaMemcpyAsync(&(label_offsets_[start_idx]), label_offsets,
227-
(end_idx - start_idx) * sizeof(int),
228-
cudaMemcpyHostToDevice, stream_);
229-
#endif
230-
213+
cuda_status = warpctc::memcpy_h2d_async(
214+
&(label_offsets_[start_idx]), label_offsets, (end_idx - start_idx) * sizeof(int), stream_);
231215
if (cuda_status != gpuSuccess)
232216
return CTC_STATUS_MEMOPS_FAILED;
233217
}
@@ -243,16 +227,8 @@ GpuCTC<ProbT>::setup_gpu_metadata(const int* const flat_labels,
243227
gpu_bytes_used);
244228
gpu_bytes_used += minibatch_ * sizeof(int);
245229

246-
#ifdef __HIPCC__
247-
cuda_status = hipMemcpyAsync(utt_length_, input_lengths,
248-
minibatch_ * sizeof(int),
249-
hipMemcpyHostToDevice, stream_);
250-
#else
251-
cuda_status = cudaMemcpyAsync(utt_length_, input_lengths,
252-
minibatch_ * sizeof(int),
253-
cudaMemcpyHostToDevice, stream_);
254-
#endif
255-
230+
cuda_status = warpctc::memcpy_h2d_async(
231+
utt_length_, input_lengths, minibatch_ * sizeof(int), stream_);
256232
if (cuda_status != gpuSuccess)
257233
return CTC_STATUS_MEMOPS_FAILED;
258234

@@ -261,16 +237,8 @@ GpuCTC<ProbT>::setup_gpu_metadata(const int* const flat_labels,
261237
gpu_bytes_used);
262238
gpu_bytes_used += minibatch_ * sizeof(int);
263239

264-
#ifdef __HIPCC__
265-
cuda_status = hipMemcpyAsync(label_sizes_, label_lengths,
266-
minibatch_ * sizeof(int),
267-
hipMemcpyHostToDevice, stream_);
268-
#else
269-
cuda_status = cudaMemcpyAsync(label_sizes_, label_lengths,
270-
minibatch_ * sizeof(int),
271-
cudaMemcpyHostToDevice, stream_);
272-
#endif
273-
240+
cuda_status = warpctc::memcpy_h2d_async(
241+
label_sizes_, label_lengths, minibatch_ * sizeof(int), stream_);
274242
if (cuda_status != gpuSuccess)
275243
return CTC_STATUS_MEMOPS_FAILED;
276244

@@ -279,16 +247,8 @@ GpuCTC<ProbT>::setup_gpu_metadata(const int* const flat_labels,
279247
gpu_bytes_used);
280248
gpu_bytes_used += Lmax * minibatch_ * sizeof(int);
281249

282-
#ifdef __HIPCC__
283-
cuda_status = hipMemcpyAsync(labels_without_blanks_, flat_labels,
284-
total_label_length * sizeof(int),
285-
hipMemcpyHostToDevice, stream_);
286-
#else
287-
cuda_status = cudaMemcpyAsync(labels_without_blanks_, flat_labels,
288-
total_label_length * sizeof(int),
289-
cudaMemcpyHostToDevice, stream_);
290-
#endif
291-
250+
cuda_status = warpctc::memcpy_h2d_async(
251+
labels_without_blanks_, flat_labels, total_label_length * sizeof(int), stream_);
292252
if (cuda_status != gpuSuccess)
293253
return CTC_STATUS_MEMOPS_FAILED;
294254

@@ -302,7 +262,6 @@ GpuCTC<ProbT>::setup_gpu_metadata(const int* const flat_labels,
302262
gpu_bytes_used);
303263
gpu_bytes_used += (S_ * T_) * minibatch_ * sizeof(ProbT);
304264

305-
306265
denoms_ =
307266
reinterpret_cast<ProbT *>(static_cast<char*>(gpu_workspace_) +
308267
gpu_bytes_used);
@@ -330,25 +289,19 @@ ctcStatus_t GpuCTC<ProbT>::launch_alpha_beta_kernels(const ProbT* const probs,
330289
// away
331290
const int stride = minibatch_;
332291

333-
if (compute_alpha)
292+
if (compute_alpha) {
334293
compute_alpha_kernel<ProbT, NT, VT><<<grid_size, NT, 0, stream_>>>
335294
(probs, label_sizes_, utt_length_,
336295
repeats_, labels_without_blanks_, label_offsets_,
337296
labels_with_blanks_, alphas_, nll_forward_,
338297
stride, out_dim_, S_, T_, blank_label_);
339-
298+
}
340299

341300
if (compute_beta) {
342301
compute_betas_and_grad_kernel<ProbT, NT, VT><<<grid_size, NT, 0, stream_>>>
343302
(probs, label_sizes_, utt_length_, repeats_,
344303
labels_with_blanks_, alphas_, nll_forward_, nll_backward_,
345304
grads, stride, out_dim_, S_, T_, blank_label_);
346-
347-
#ifdef __HIPCC__
348-
hipStreamSynchronize(stream_);
349-
#else
350-
cudaStreamSynchronize(stream_);
351-
#endif
352305
}
353306

354307
#ifdef __HIPCC__

include/detail/gpu_helper.h

+10
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,16 @@ static gpuError_t memcpy_d2h_async(void *dst, const void *src, size_t bytes, GPU
1414
return status;
1515
}
1616

17+
static gpuError_t memcpy_h2d_async(void *dst, const void *src, size_t bytes, GPUstream stream) {
18+
gpuError_t status;
19+
#ifdef __HIPCC__
20+
status = hipMemcpyAsync(dst, src, bytes, hipMemcpyHostToDevice, stream);
21+
#else
22+
status = cudaMemcpyAsync(dst, src, bytes, cudaMemcpyHostToDevice, stream);
23+
#endif
24+
return status;
25+
}
26+
1727
static gpuError_t synchronize(GPUstream stream) {
1828
gpuError_t status;
1929
#ifdef __HIPCC__

0 commit comments

Comments (0)