diff --git a/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp b/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp index 7f81e4916b..4caa4a7ead 100644 --- a/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp +++ b/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp @@ -22,10 +22,10 @@ inline int max(int a, int b) { return a >= b ? a : b; } -template +template struct AvgPool2dKernelFunctor { void operator()(sycl::nd_item<1> item) const { - auto index = item.get_global_linear_id(); + index_t index = item.get_global_linear_id(); if (index < total_elements_) { const int pw = index % pooled_width_; @@ -73,10 +73,10 @@ struct AvgPool2dKernelFunctor { AvgPool2dKernelFunctor( scalar_t* top_data, const scalar_t* bottom_data, - int64_t total_elements, - int64_t channels, - int64_t height, - int64_t width, + index_t total_elements, + index_t channels, + index_t height, + index_t width, int pooled_height, int pooled_width, int kernel_h, @@ -109,10 +109,10 @@ struct AvgPool2dKernelFunctor { private: scalar_t* top_data_; const scalar_t* bottom_data_; - int64_t total_elements_; - int64_t channels_; - int64_t height_; - int64_t width_; + index_t total_elements_; + index_t channels_; + index_t height_; + index_t width_; int pooled_height_; int pooled_width_; int kernel_h_; @@ -126,10 +126,10 @@ struct AvgPool2dKernelFunctor { bool use_divisor_; }; -template +template struct AvgPool2dChannelsLastKernelFunctor { void operator()(sycl::nd_item<1> item) const { - auto index = item.get_global_linear_id(); + index_t index = item.get_global_linear_id(); if (index < total_elements_) { const int c = index % channels_; @@ -175,10 +175,10 @@ struct AvgPool2dChannelsLastKernelFunctor { AvgPool2dChannelsLastKernelFunctor( scalar_t* top_data, const scalar_t* bottom_data, - int64_t total_elements, - int64_t channels, - int64_t height, - int64_t width, + index_t total_elements, + index_t channels, + index_t height, + index_t width, int pooled_height, int pooled_width, int kernel_h, @@ -211,10 +211,10 @@ struct AvgPool2dChannelsLastKernelFunctor { private: scalar_t* top_data_; const scalar_t* bottom_data_; - int64_t total_elements_; - int64_t channels_; - int64_t height_; - int64_t width_; + index_t total_elements_; + index_t channels_; + index_t height_; + index_t width_; int pooled_height_; int pooled_width_; int kernel_h_; @@ -228,13 +228,13 @@ struct AvgPool2dChannelsLastKernelFunctor { bool use_divisor_; }; -template +template void launch_avg_pool2d_channels_last_kernel( const int total_elements, const Tensor& input, - const int64_t channels, - const int64_t height, - const int64_t width, + const index_t channels, + const index_t height, + const index_t width, const int pooled_height, const int pooled_width, const int kernel_h, @@ -255,7 +255,7 @@ void launch_avg_pool2d_channels_last_kernel( const uint32_t global_range = ceil_div(total_elements, group_size) * group_size; - auto kfn = AvgPool2dChannelsLastKernelFunctor( + auto kfn = AvgPool2dChannelsLastKernelFunctor( top_data, bottom_data, total_elements, @@ -276,13 +276,13 @@ void launch_avg_pool2d_channels_last_kernel( sycl_kernel_submit(global_range, group_size, queue, kfn); } -template +template void launch_avg_pool2d_kernel( const int total_elements, const Tensor& input, - const int64_t channels, - const int64_t height, - const int64_t width, + const index_t channels, + const index_t height, + const index_t width, const int pooled_height, const int pooled_width, const int kernel_h, @@ -303,7 +303,7 @@ void launch_avg_pool2d_kernel( const uint32_t global_range = ceil_div(total_elements, group_size) * group_size; - auto kfn = AvgPool2dKernelFunctor( + auto kfn = AvgPool2dKernelFunctor( top_data, bottom_data, total_elements, @@ -664,58 +664,67 @@ void avg_pool2d_kernel( AT_DISPATCH_FLOATING_TYPES_AND2( kHalf, kBFloat16, input.scalar_type(), "avg_pool2d_xpu", [&] { using accscalar_t = acc_type_device; - - switch (memory_format) { - case MemoryFormat::ChannelsLast: { - output.unsafeGetTensorImpl()->empty_tensor_restride( - MemoryFormat::ChannelsLast); - launch_avg_pool2d_channels_last_kernel( - count, - input, - nInputPlane, - inputHeight, - inputWidth, - outputHeight, - outputWidth, - kH_, - kW_, - dH_, - dW_, - padH_, - padW_, - output, - divisor_override_value, - count_include_pad, - use_divisor); - break; - } - case MemoryFormat::Contiguous: { - launch_avg_pool2d_kernel( - count, - input, - nInputPlane, - inputHeight, - inputWidth, - outputHeight, - outputWidth, - kH_, - kW_, - dH_, - dW_, - padH_, - padW_, - output, - divisor_override_value, - count_include_pad, - use_divisor); - break; - } - default: - TORCH_CHECK( - false, - "Unsupported memory format. Supports only " - "ChannelsLast, Contiguous"); - } + AT_DISPATCH_INDEX_TYPES( + at::native::canUse32BitIndexMath(output, INT_MAX) + ? ScalarType::Int + : ScalarType::Long, + "avg_pool2d_xpu", + [&] { + switch (memory_format) { + case MemoryFormat::ChannelsLast: { + output.unsafeGetTensorImpl()->empty_tensor_restride( + MemoryFormat::ChannelsLast); + launch_avg_pool2d_channels_last_kernel< + scalar_t, + accscalar_t, + index_t>( + count, + input, + nInputPlane, + inputHeight, + inputWidth, + outputHeight, + outputWidth, + kH_, + kW_, + dH_, + dW_, + padH_, + padW_, + output, + divisor_override_value, + count_include_pad, + use_divisor); + break; + } + case MemoryFormat::Contiguous: { + launch_avg_pool2d_kernel( + count, + input, + nInputPlane, + inputHeight, + inputWidth, + outputHeight, + outputWidth, + kH_, + kW_, + dH_, + dW_, + padH_, + padW_, + output, + divisor_override_value, + count_include_pad, + use_divisor); + break; + } + default: + TORCH_CHECK( + false, + "Unsupported memory format. Supports only " + "ChannelsLast, Contiguous"); + } + }); }); } }