From b9c8bd968bd8dd37b3ef3001a296e45f1938aaea Mon Sep 17 00:00:00 2001 From: "Meng, Chunhuan" Date: Fri, 25 Apr 2025 07:51:03 +0000 Subject: [PATCH 1/4] Enhance Performance by Adding `index_t` Template Parameter --- .../native/xpu/sycl/AveragePool2dKernels.cpp | 171 +++++++++--------- 1 file changed, 88 insertions(+), 83 deletions(-) diff --git a/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp b/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp index 9002ce3ff1..564c1f1df3 100644 --- a/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp +++ b/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp @@ -22,10 +22,10 @@ inline int max(int a, int b) { return a >= b ? a : b; } -template +template struct AvgPool2dKernelFunctor { void operator()(sycl::nd_item<1> item) const { - auto index = item.get_global_linear_id(); + index_t index = item.get_global_linear_id(); if (index < total_elements_) { const int pw = index % pooled_width_; @@ -73,10 +73,10 @@ struct AvgPool2dKernelFunctor { AvgPool2dKernelFunctor( scalar_t* top_data, const scalar_t* bottom_data, - int64_t total_elements, - int64_t channels, - int64_t height, - int64_t width, + index_t total_elements, + index_t channels, + index_t height, + index_t width, int pooled_height, int pooled_width, int kernel_h, @@ -109,10 +109,10 @@ struct AvgPool2dKernelFunctor { private: scalar_t* top_data_; const scalar_t* bottom_data_; - int64_t total_elements_; - int64_t channels_; - int64_t height_; - int64_t width_; + index_t total_elements_; + index_t channels_; + index_t height_; + index_t width_; int pooled_height_; int pooled_width_; int kernel_h_; @@ -126,10 +126,10 @@ struct AvgPool2dKernelFunctor { bool use_divisor_; }; -template +template struct AvgPool2dChannelsLastKernelFunctor { void operator()(sycl::nd_item<1> item) const { - auto index = item.get_global_linear_id(); + index_t index = item.get_global_linear_id(); if (index < total_elements_) { const int c = index % channels_; @@ -175,10 +175,10 @@ struct AvgPool2dChannelsLastKernelFunctor { AvgPool2dChannelsLastKernelFunctor( scalar_t* top_data, const scalar_t* bottom_data, - int64_t total_elements, - int64_t channels, - int64_t height, - int64_t width, + index_t total_elements, + index_t channels, + index_t height, + index_t width, int pooled_height, int pooled_width, int kernel_h, @@ -211,10 +211,10 @@ struct AvgPool2dChannelsLastKernelFunctor { private: scalar_t* top_data_; const scalar_t* bottom_data_; - int64_t total_elements_; - int64_t channels_; - int64_t height_; - int64_t width_; + index_t total_elements_; + index_t channels_; + index_t height_; + index_t width_; int pooled_height_; int pooled_width_; int kernel_h_; @@ -228,13 +228,13 @@ struct AvgPool2dChannelsLastKernelFunctor { bool use_divisor_; }; -template +template void launch_avg_pool2d_channels_last_kernel( const int total_elements, const Tensor& input, - const int64_t channels, - const int64_t height, - const int64_t width, + const index_t channels, + const index_t height, + const index_t width, const int pooled_height, const int pooled_width, const int kernel_h, @@ -255,7 +255,7 @@ void launch_avg_pool2d_channels_last_kernel( const uint32_t global_range = ceil_div(total_elements, group_size) * group_size; - auto kfn = AvgPool2dChannelsLastKernelFunctor( + auto kfn = AvgPool2dChannelsLastKernelFunctor( top_data, bottom_data, total_elements, @@ -276,13 +276,13 @@ void launch_avg_pool2d_channels_last_kernel( sycl_kernel_submit(global_range, group_size, queue, kfn); } -template +template void launch_avg_pool2d_kernel( const int total_elements, const Tensor& input, - const int64_t channels, - const int64_t height, - const int64_t width, + const index_t channels, + const index_t height, + const index_t width, const int pooled_height, const int pooled_width, const int kernel_h, @@ -303,7 +303,7 @@ void launch_avg_pool2d_kernel( const uint32_t global_range = ceil_div(total_elements, group_size) * group_size; - auto kfn = AvgPool2dKernelFunctor( + auto kfn = AvgPool2dKernelFunctor( top_data, bottom_data, total_elements, @@ -664,58 +664,63 @@ void avg_pool2d_kernel( AT_DISPATCH_FLOATING_TYPES_AND2( kHalf, kBFloat16, input.scalar_type(), "avg_pool2d_xpu", [&] { using accscalar_t = acc_type_device; - - switch (memory_format) { - case MemoryFormat::ChannelsLast: { - output.unsafeGetTensorImpl()->empty_tensor_restride( - MemoryFormat::ChannelsLast); - launch_avg_pool2d_channels_last_kernel( - count, - input, - nInputPlane, - inputHeight, - inputWidth, - outputHeight, - outputWidth, - kH_, - kW_, - dH_, - dW_, - padH_, - padW_, - output, - divisor_override_value, - count_include_pad, - use_divisor); - break; - } - case MemoryFormat::Contiguous: { - launch_avg_pool2d_kernel( - count, - input, - nInputPlane, - inputHeight, - inputWidth, - outputHeight, - outputWidth, - kH_, - kW_, - dH_, - dW_, - padH_, - padW_, - output, - divisor_override_value, - count_include_pad, - use_divisor); - break; - } - default: - TORCH_CHECK( - false, - "Unsupported memory format. Supports only " - "ChannelsLast, Contiguous"); - } + AT_DISPATCH_INDEX_TYPES( + at::native::canUse32BitIndexMath(output, INT_MAX) ? ScalarType::Int + : ScalarType::Long, + "avg_pool2d_backward_xpu", + [&] { + switch (memory_format) { + case MemoryFormat::ChannelsLast: { + output.unsafeGetTensorImpl()->empty_tensor_restride( + MemoryFormat::ChannelsLast); + launch_avg_pool2d_channels_last_kernel( + count, + input, + nInputPlane, + inputHeight, + inputWidth, + outputHeight, + outputWidth, + kH_, + kW_, + dH_, + dW_, + padH_, + padW_, + output, + divisor_override_value, + count_include_pad, + use_divisor); + break; + } + case MemoryFormat::Contiguous: { + launch_avg_pool2d_kernel( + count, + input, + nInputPlane, + inputHeight, + inputWidth, + outputHeight, + outputWidth, + kH_, + kW_, + dH_, + dW_, + padH_, + padW_, + output, + divisor_override_value, + count_include_pad, + use_divisor); + break; + } + default: + TORCH_CHECK( + false, + "Unsupported memory format. Supports only " + "ChannelsLast, Contiguous"); + } + }); }); } } @@ -835,4 +840,4 @@ void avg_pool2d_backward_kernel( } } // namespace xpu -} // namespace at::native +} // namespace at::native \ No newline at end of file From 9edfc2ea70cd134e555828ff763c815de808fd86 Mon Sep 17 00:00:00 2001 From: chunhuanMeng <105194461+chunhuanMeng@users.noreply.github.com> Date: Mon, 28 Apr 2025 10:34:38 +0800 Subject: [PATCH 2/4] correct macro identifier --- src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp b/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp index 564c1f1df3..27ecbbc989 100644 --- a/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp +++ b/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp @@ -667,7 +667,7 @@ void avg_pool2d_kernel( AT_DISPATCH_INDEX_TYPES( at::native::canUse32BitIndexMath(output, INT_MAX) ? ScalarType::Int : ScalarType::Long, - "avg_pool2d_backward_xpu", + "avg_pool2d_xpu", [&] { switch (memory_format) { case MemoryFormat::ChannelsLast: { @@ -840,4 +840,4 @@ void avg_pool2d_backward_kernel( } } // namespace xpu -} // namespace at::native \ No newline at end of file +} // namespace at::native From 204f3dfd17e9090d9859beab069f700fc0cccd5d Mon Sep 17 00:00:00 2001 From: chunhuanMeng <105194461+chunhuanMeng@users.noreply.github.com> Date: Tue, 29 Apr 2025 10:52:12 +0800 Subject: [PATCH 3/4] format --- src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp b/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp index 7e68f3844e..bedb37d4bd 100644 --- a/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp +++ b/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp @@ -22,7 +22,7 @@ inline int max(int a, int b) { return a >= b ? a : b; } -template +template struct AvgPool2dKernelFunctor { void operator()(sycl::nd_item<1> item) const { index_t index = item.get_global_linear_id(); @@ -126,7 +126,7 @@ struct AvgPool2dKernelFunctor { bool use_divisor_; }; -template +template struct AvgPool2dChannelsLastKernelFunctor { void operator()(sycl::nd_item<1> item) const { index_t index = item.get_global_linear_id(); @@ -228,7 +228,7 @@ struct AvgPool2dChannelsLastKernelFunctor { bool use_divisor_; }; -template +template void launch_avg_pool2d_channels_last_kernel( const int total_elements, const Tensor& input, @@ -255,7 +255,7 @@ void launch_avg_pool2d_channels_last_kernel( const uint32_t global_range = ceil_div(total_elements, group_size) * group_size; - auto kfn = AvgPool2dChannelsLastKernelFunctor( + auto kfn = AvgPool2dChannelsLastKernelFunctor( top_data, bottom_data, total_elements, @@ -276,7 +276,7 @@ void launch_avg_pool2d_channels_last_kernel( sycl_kernel_submit(global_range, group_size, queue, kfn); } -template +template void launch_avg_pool2d_kernel( const int total_elements, const Tensor& input, @@ -303,7 +303,7 @@ void launch_avg_pool2d_kernel( const uint32_t global_range = ceil_div(total_elements, group_size) * group_size; - auto kfn = AvgPool2dKernelFunctor( + auto kfn = AvgPool2dKernelFunctor( top_data, bottom_data, total_elements, From 78e0a1455cb05de1a5322ed7c1e21eeb7f3f19ff Mon Sep 17 00:00:00 2001 From: chunhuanMeng <105194461+chunhuanMeng@users.noreply.github.com> Date: Tue, 29 Apr 2025 14:32:17 +0800 Subject: [PATCH 4/4] format --- .../native/xpu/sycl/AveragePool2dKernels.cpp | 116 +++++++++--------- 1 file changed, 60 insertions(+), 56 deletions(-) diff --git a/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp b/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp index bedb37d4bd..4caa4a7ead 100644 --- a/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp +++ b/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp @@ -664,63 +664,67 @@ void avg_pool2d_kernel( AT_DISPATCH_FLOATING_TYPES_AND2( kHalf, kBFloat16, input.scalar_type(), "avg_pool2d_xpu", [&] { using accscalar_t = acc_type_device; - AT_DISPATCH_INDEX_TYPES( - at::native::canUse32BitIndexMath(output, INT_MAX) ? ScalarType::Int - : ScalarType::Long, - "avg_pool2d_xpu", - [&] { - switch (memory_format) { - case MemoryFormat::ChannelsLast: { - output.unsafeGetTensorImpl()->empty_tensor_restride( - MemoryFormat::ChannelsLast); - launch_avg_pool2d_channels_last_kernel( - count, - input, - nInputPlane, - inputHeight, - inputWidth, - outputHeight, - outputWidth, - kH_, - kW_, - dH_, - dW_, - padH_, - padW_, - output, - divisor_override_value, - count_include_pad, - use_divisor); - break; - } - case MemoryFormat::Contiguous: { - launch_avg_pool2d_kernel( - count, - input, - nInputPlane, - inputHeight, - inputWidth, - outputHeight, - outputWidth, - kH_, - kW_, - dH_, - dW_, - padH_, - padW_, - output, - divisor_override_value, - count_include_pad, - use_divisor); - break; + AT_DISPATCH_INDEX_TYPES( + at::native::canUse32BitIndexMath(output, INT_MAX) + ? ScalarType::Int + : ScalarType::Long, + "avg_pool2d_xpu", + [&] { + switch (memory_format) { + case MemoryFormat::ChannelsLast: { + output.unsafeGetTensorImpl()->empty_tensor_restride( + MemoryFormat::ChannelsLast); + launch_avg_pool2d_channels_last_kernel< + scalar_t, + accscalar_t, + index_t>( + count, + input, + nInputPlane, + inputHeight, + inputWidth, + outputHeight, + outputWidth, + kH_, + kW_, + dH_, + dW_, + padH_, + padW_, + output, + divisor_override_value, + count_include_pad, + use_divisor); + break; + } + case MemoryFormat::Contiguous: { + launch_avg_pool2d_kernel( + count, + input, + nInputPlane, + inputHeight, + inputWidth, + outputHeight, + outputWidth, + kH_, + kW_, + dH_, + dW_, + padH_, + padW_, + output, + divisor_override_value, + count_include_pad, + use_divisor); + break; + } + default: + TORCH_CHECK( + false, + "Unsupported memory format. Supports only " + "ChannelsLast, Contiguous"); } - default: - TORCH_CHECK( - false, - "Unsupported memory format. Supports only " - "ChannelsLast, Contiguous"); - } - }); + }); }); } }