diff --git a/src/ATen/native/xpu/sycl/LossNLLKernel.cpp b/src/ATen/native/xpu/sycl/LossNLLKernel.cpp index d45d06545..99ac4cb19 100644 --- a/src/ATen/native/xpu/sycl/LossNLLKernel.cpp +++ b/src/ATen/native/xpu/sycl/LossNLLKernel.cpp @@ -471,11 +471,11 @@ struct NllLossBackwardReduce2DKernelFunctor { ? static_cast(gradOutput_ptr[0]) / static_cast(*total_weight_ptr) : static_cast(gradOutput_ptr[0])); - for (i = local_item_id; i < nframe; i += local_size) { - t = (int)target_ptr[i]; + i = local_item_id; + if (i < ndim * nframe) { + t = (int)target_ptr[i % ndim]; if (t != (int)ignore_index) { - gradInput_ptr[i * ndim + t] = - has_weights ? weights_ptr[t] * grad : grad; + gradInput_ptr[i] = has_weights ? weights_ptr[t] * grad : grad; } } }