Commit 840e71f

Yixin Bao authored and facebook-github-bot committed
Check CUDA kernel launches (/fbcode/caffe2/) (pytorch#49145)
Summary:
Pull Request resolved: pytorch#49145
Pull Request resolved: pytorch#49105

(1) Add a safety check `C10_CUDA_KERNEL_LAUNCH_CHECK()` after each kernel launch. This diff only changes files inside the directories /fbsource/fbcode/caffe2/modules/, /fbsource/fbcode/caffe2/fb/, and /fbsource/fbcode/caffe2/test/.
(2) Get rid of the old `AT_CUDA_CHECK(cudaGetLastError())` check where necessary.

Test Plan:
Test build:
```
buck build mode/dev-nosan //caffe2/modules/detectron:
buck test mode/dev-nosan //caffe2/modules/detectron:
buck build mode/dev-nosan //caffe2/torch/fb/:
buck test mode/dev-nosan //caffe2/torch/fb/:
```

To check for launches without checks:
```
python3 caffe2/torch/testing/check_kernel_launches.py
```
Make sure none of the updated files are in the returned list.

Reviewed By: r-barnes

Differential Revision: D25452852

fbshipit-source-id: d6657edab612c9e0fa99b29c68460be8b1a20064
1 parent 524adfb commit 840e71f

14 files changed, +36 -0 lines changed
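For background: a CUDA kernel launch returns no status, and a failed launch (bad grid/block configuration, missing kernel image, or a prior async error) only surfaces at a later CUDA API call. `C10_CUDA_KERNEL_LAUNCH_CHECK()`, from `c10/cuda/CUDAException.h`, checks `cudaGetLastError()` immediately after the launch and raises a descriptive error with file/line context. A minimal sketch of the pattern this commit applies (the `scale_kernel`/`scale_cuda` names are hypothetical, for illustration only):

```cuda
#include <c10/cuda/CUDAException.h>

// Hypothetical kernel, standing in for the real ops touched by this commit.
__global__ void scale_kernel(float* data, float alpha, int size) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < size) {
    data[i] *= alpha;
  }
}

void scale_cuda(float* data, float alpha, int size) {
  const int threads = 1024;
  const int blocks = (size + threads - 1) / threads;
  scale_kernel<<<blocks, threads>>>(data, alpha, size);
  // Old style, removed by this commit where present:
  //   AT_CUDA_CHECK(cudaGetLastError());
  // New style: one macro that surfaces launch errors immediately.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
```

The check directly after the `<<<...>>>` launch is the pattern that `caffe2/torch/testing/check_kernel_launches.py` scans for; a launch with no adjacent check is reported.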

modules/detectron/group_spatial_softmax_op.cu (+3)

```diff
@@ -112,6 +112,7 @@ bool GroupSpatialSoftmaxOp<float, CUDAContext>::RunOnDevice() {
   GroupSpatialSoftmaxKernel<<<CAFFE_GET_BLOCKS(N), CAFFE_CUDA_NUM_THREADS,
       0, context_.cuda_stream()>>>(
       N, A, W, H, Xdata, Pdata, num_classes_);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
   return true;
 }
@@ -158,11 +159,13 @@ bool GroupSpatialSoftmaxGradientOp<float, CUDAContext>::RunOnDevice() {
   SumProbsKernel<<<CAFFE_GET_BLOCKS(N), CAFFE_CUDA_NUM_THREADS, 0,
       context_.cuda_stream()>>>(
       N, A, W, H, Ydata, dYdata, sum_probs_data, num_classes_);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
 
   // Step 2: dX[i] = dX[i] - s
   SubSumKernel<<<CAFFE_GET_BLOCKS(Y.size()), CAFFE_CUDA_NUM_THREADS, 0,
       context_.cuda_stream()>>>(
       N, A, W, H, sum_probs_.data<float>(), dXdata, num_classes_);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
 
   // Step 3: dX[i] = Y[i] * dX[i]
   math::Mul<float, CUDAContext>(Y.size(), dXdata, Ydata, dXdata, &context_);
```

modules/detectron/ps_roi_pool_op.cu (+2)

```diff
@@ -253,6 +253,7 @@ bool PSRoIPoolOp<float, CUDAContext>::RunOnDevice() {
       output_size, X.data<float>(), spatial_scale_, X.dim32(1), X.dim32(2),
       X.dim32(3), pooled_height_, pooled_width_, R.data<float>(), output_dim_,
       group_size_, Y->mutable_data<float>(), A->mutable_data<int>());
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
   return true;
 }
@@ -276,6 +277,7 @@ bool PSRoIPoolGradientOp<float, CUDAContext>::RunOnDevice() {
       dY.size(), dY.data<float>(), A.data<int>(), R.dim32(0), spatial_scale_,
       X.dim32(1), X.dim32(2), X.dim32(3), pooled_height_, pooled_width_,
       output_dim_, dX->mutable_data<float>(), R.data<float>());
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
   return true;
 }
```

modules/detectron/roi_pool_f_op.cu (+2)

```diff
@@ -149,6 +149,7 @@ bool RoIPoolFOp<float, CUDAContext>::RunOnDevice() {
       output_size, X.data<float>(), spatial_scale_, X.dim32(1), X.dim32(2),
       X.dim32(3), pooled_height_, pooled_width_, R.data<float>(),
       Y->mutable_data<float>(), A->mutable_data<int>());
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
   return true;
 }
@@ -173,6 +174,7 @@ bool RoIPoolFGradientOp<float, CUDAContext>::RunOnDevice() {
         dY.size(), dY.data<float>(), A.data<int>(), R.dim32(0), spatial_scale_,
         X.dim32(1), X.dim32(2), X.dim32(3), pooled_height_, pooled_width_,
         dX->mutable_data<float>(), R.data<float>());
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
   }
   return true;
 }
```

modules/detectron/select_smooth_l1_loss_op.cu (+2)

```diff
@@ -129,6 +129,7 @@ bool SelectSmoothL1LossOp<float, CUDAContext>::RunOnDevice() {
       M, Y_hat.data<float>(), Y.data<float>(),
       L.data<float>(), buff_.mutable_data<float>(),
       S.data<float>(), beta_);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
 
   // Sum of all losses
   // al := sum_i l_i
@@ -175,6 +176,7 @@ bool SelectSmoothL1LossGradientOp<float, CUDAContext>::RunOnDevice() {
       D, H, W, M, Y_hat.data<float>(), Y.data<float>(),
       L.data<float>(), d_Y_hat->mutable_data<float>(),
       d_avg_loss.data<float>(), scale_, S.data<float>(), beta_);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
 
   return true;
 }
```

modules/detectron/sigmoid_cross_entropy_loss_op.cu (+5)

```diff
@@ -93,6 +93,8 @@ bool SigmoidCrossEntropyLossOp<float, CUDAContext>::RunOnDevice() {
       T.data<int>(),
       losses_.mutable_data<float>(),
       counts_.mutable_data<float>());
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+
   float* avg_loss_data = avg_loss->mutable_data<float>();
   math::Sum<float, CUDAContext>(
       losses_.size(), losses_.data<float>(), avg_loss_data, &context_);
@@ -106,6 +108,7 @@ bool SigmoidCrossEntropyLossOp<float, CUDAContext>::RunOnDevice() {
         CAFFE_CUDA_NUM_THREADS,
         0,
         context_.cuda_stream()>>>(normalizer_.size(), normalizer_data, 1e-5);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
     math::Div<float, CUDAContext>(
         1, avg_loss_data, normalizer_data, avg_loss_data, &context_);
   }
@@ -135,6 +138,7 @@ bool SigmoidCrossEntropyLossGradientOp<float, CUDAContext>::RunOnDevice() {
       T.data<int>(),
       dX->mutable_data<float>(),
       counts_.mutable_data<float>());
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
   if (normalize_) {
     float* normalizer_data = normalizer_.mutable_data<float>();
     math::Sum<float, CUDAContext>(
@@ -145,6 +149,7 @@ bool SigmoidCrossEntropyLossGradientOp<float, CUDAContext>::RunOnDevice() {
         CAFFE_CUDA_NUM_THREADS,
         0,
         context_.cuda_stream()>>>(normalizer_.size(), normalizer_data, 1e-5);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
     math::Div<float, CUDAContext>(
         1,
         d_avg_loss.data<float>(),
```

modules/detectron/sigmoid_focal_loss_op.cu (+2)

```diff
@@ -134,6 +134,7 @@ bool SigmoidFocalLossOp<float, CUDAContext>::RunOnDevice() {
       N, D, H, W, X.data<float>(), T.data<int>(),
       wp.data<float>(), gamma_, alpha_, num_classes_,
       losses_.mutable_data<float>());
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
 
   math::Sum<float, CUDAContext>(
       losses_.size(), losses_.data<float>(), avg_loss_data, &context_);
@@ -165,6 +166,7 @@ bool SigmoidFocalLossGradientOp<float, CUDAContext>::RunOnDevice() {
       N, D, H, W, X.data<float>(), T.data<int>(), dX->mutable_data<float>(),
       wp.data<float>(), gamma_, alpha_, num_classes_,
       d_avg_loss.data<float>());
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
   math::Scale<float, float, CUDAContext>(
       dX->size(),
       scale_,
```

modules/detectron/smooth_l1_loss_op.cu (+3)

```diff
@@ -102,6 +102,7 @@ bool SmoothL1LossOp<float, CUDAContext>::RunOnDevice() {
       context_.cuda_stream()>>>(
       buff_.size(), buff_.data<float>(), buff_.mutable_data<float>(),
       beta_);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
 
   // Element-wise weighted smooth l1 loss (can be used to specify a per-element
   // loss weight)
@@ -164,6 +165,8 @@ bool SmoothL1LossGradientOp<float, CUDAContext>::RunOnDevice() {
       context_.cuda_stream()>>>(
       buff_.size(), buff_.data<float>(), d_Y_hat->mutable_data<float>(),
       d_avg_loss.data<float>(), scale_ / N, beta_);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+
   // Element-wise scale by alpha_in and alpha_out
   math::Mul<float, CUDAContext>(
       d_Y_hat->size(), d_Y_hat->data<float>(), alpha_in.data<float>(),
```

modules/detectron/softmax_focal_loss_op.cu (+5)

```diff
@@ -176,6 +176,7 @@ bool SoftmaxFocalLossOp<float, CUDAContext>::RunOnDevice() {
       <<<CAFFE_GET_BLOCKS(N * A * H * W), CAFFE_CUDA_NUM_THREADS,
          0, context_.cuda_stream()>>>(
       N, A, H, W, Xdata, P->mutable_data<float>(), num_classes_);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
 
   // Compute loss for each x,y location
   const int* Tdata = T.data<int>();
@@ -184,6 +185,7 @@ bool SoftmaxFocalLossOp<float, CUDAContext>::RunOnDevice() {
          0, context_.cuda_stream()>>>(
       N, A, H, W, P->data<float>(), Tdata, losses_.mutable_data<float>(),
       Wdata, gamma_, alpha_, num_classes_);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
 
   // sum the losses
   float* avg_loss_data = avg_loss->mutable_data<float>();
@@ -227,13 +229,16 @@ bool SoftmaxFocalLossGradientOp<float, CUDAContext>::RunOnDevice() {
          0, context_.cuda_stream()>>>(
       N, A, H, W, Pdata, Tdata, buff_.mutable_data<float>(),
       Wdata, gamma_, alpha_, num_classes_);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+
   // Compute the gradient with the weights
   const float* Bdata = buff_.data<float>();
   SoftmaxFocalLossGradientKernel
       <<<CAFFE_GET_BLOCKS(N * D * H * W), CAFFE_CUDA_NUM_THREADS,
          0, context_.cuda_stream()>>>(
       N, D, H, W, Pdata, Tdata, Bdata, d_avg_loss.data<float>(),
       dX->mutable_data<float>(), num_classes_);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
   math::Scale<float, float, CUDAContext>(
       dX->size(),
       scale_,
```

modules/detectron/spatial_narrow_as_op.cu (+2)

```diff
@@ -115,6 +115,7 @@ bool SpatialNarrowAsOp<CUDAContext>::DoRunWithType() {
       out_width,
       A.template data<T>(),
       C->template mutable_data<T>());
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
 
   return true;
 }
@@ -152,6 +153,7 @@ bool SpatialNarrowAsGradientOp<CUDAContext>::DoRunWithType() {
       out_width,
       dC.template data<T>(),
       dA->template mutable_data<T>());
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
 
   return true;
 }
```

modules/detectron/upsample_nearest_op.cu (+3)

```diff
@@ -164,6 +164,8 @@ bool UpsampleNearestOp<float, CUDAContext>::RunOnDevice() {
 
   upscale<<<blocks, threads, 0, context_.cuda_stream()>>>(
       input_data, output_data, no_elements, scale_, d1, d2, d3);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+
   return true;
 }
 
@@ -209,6 +211,7 @@ bool UpsampleNearestGradientOp<float, CUDAContext>::RunOnDevice() {
   math::Set<float, CUDAContext>(no_elements, 0.f, gradInput_data, &context_);
   downscale<<<blocks, threads, 0, context_.cuda_stream()>>>(
       gradInput_data, gradOutput_data, no_elements, scale_, d1, d2, d3);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
 
   return true;
 }
```

test/cpp_extensions/cuda_extension.cu (+2)

```diff
@@ -6,6 +6,7 @@
 
 #include <cuda.h>
 #include <cuda_runtime.h>
+#include <c10/cuda/CUDAException.h>
 
 #include <ATen/ATen.h>
 
@@ -26,4 +27,5 @@ void sigmoid_add_cuda(const float* x, const float* y, float* output, int size) {
   const int threads = 1024;
   const int blocks = (size + threads - 1) / threads;
   sigmoid_add_kernel<<<blocks, threads>>>(x, y, output, size);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
```

test/cpp_extensions/cuda_extension_kernel.cu (+2)

```diff
@@ -1,5 +1,6 @@
 #include <cuda.h>
 #include <cuda_runtime.h>
+#include <c10/cuda/CUDAException.h>
 
 #include <ATen/ATen.h>
 
@@ -20,4 +21,5 @@ void sigmoid_add_cuda(const float* x, const float* y, float* output, int size) {
   const int threads = 1024;
   const int blocks = (size + threads - 1) / threads;
   sigmoid_add_kernel<<<blocks, threads>>>(x, y, output, size);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
```

test/cpp_extensions/cuda_extension_kernel2.cu (+2)

```diff
@@ -1,5 +1,6 @@
 #include <cuda.h>
 #include <cuda_runtime.h>
+#include <c10/cuda/CUDAException.h>
 
 #include <ATen/ATen.h>
 
@@ -20,4 +21,5 @@ void tanh_add_cuda(const float* x, const float* y, float* output, int size) {
   const int threads = 1024;
   const int blocks = (size + threads - 1) / threads;
   tanh_add_kernel<<<blocks, threads>>>(x, y, output, size);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
```

torch/lib/c10d/test/CUDATest.cu (+1)

```diff
@@ -17,6 +17,7 @@ __global__ void waitClocks(const uint64_t count) {
 
 void cudaSleep(at::cuda::CUDAStream& stream, uint64_t clocks) {
   waitClocks<<<1, 1, 0, stream.stream()>>>(clocks);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
 
 int cudaNumDevices() {
```
