Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 91 additions & 82 deletions src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ inline int max(int a, int b) {
return a >= b ? a : b;
}

template <typename scalar_t, typename accscalar_t>
template <typename scalar_t, typename accscalar_t, typename index_t>
struct AvgPool2dKernelFunctor {
void operator()(sycl::nd_item<1> item) const {
auto index = item.get_global_linear_id();
index_t index = item.get_global_linear_id();

if (index < total_elements_) {
const int pw = index % pooled_width_;
Expand Down Expand Up @@ -73,10 +73,10 @@ struct AvgPool2dKernelFunctor {
AvgPool2dKernelFunctor(
scalar_t* top_data,
const scalar_t* bottom_data,
int64_t total_elements,
int64_t channels,
int64_t height,
int64_t width,
index_t total_elements,
index_t channels,
index_t height,
index_t width,
int pooled_height,
int pooled_width,
int kernel_h,
Expand Down Expand Up @@ -109,10 +109,10 @@ struct AvgPool2dKernelFunctor {
private:
scalar_t* top_data_;
const scalar_t* bottom_data_;
int64_t total_elements_;
int64_t channels_;
int64_t height_;
int64_t width_;
index_t total_elements_;
index_t channels_;
index_t height_;
index_t width_;
int pooled_height_;
int pooled_width_;
int kernel_h_;
Expand All @@ -126,10 +126,10 @@ struct AvgPool2dKernelFunctor {
bool use_divisor_;
};

template <typename scalar_t, typename accscalar_t>
template <typename scalar_t, typename accscalar_t, typename index_t>
struct AvgPool2dChannelsLastKernelFunctor {
void operator()(sycl::nd_item<1> item) const {
auto index = item.get_global_linear_id();
index_t index = item.get_global_linear_id();

if (index < total_elements_) {
const int c = index % channels_;
Expand Down Expand Up @@ -175,10 +175,10 @@ struct AvgPool2dChannelsLastKernelFunctor {
AvgPool2dChannelsLastKernelFunctor(
scalar_t* top_data,
const scalar_t* bottom_data,
int64_t total_elements,
int64_t channels,
int64_t height,
int64_t width,
index_t total_elements,
index_t channels,
index_t height,
index_t width,
int pooled_height,
int pooled_width,
int kernel_h,
Expand Down Expand Up @@ -211,10 +211,10 @@ struct AvgPool2dChannelsLastKernelFunctor {
private:
scalar_t* top_data_;
const scalar_t* bottom_data_;
int64_t total_elements_;
int64_t channels_;
int64_t height_;
int64_t width_;
index_t total_elements_;
index_t channels_;
index_t height_;
index_t width_;
int pooled_height_;
int pooled_width_;
int kernel_h_;
Expand All @@ -228,13 +228,13 @@ struct AvgPool2dChannelsLastKernelFunctor {
bool use_divisor_;
};

template <typename scalar_t, typename accscalar_t>
template <typename scalar_t, typename accscalar_t, typename index_t>
void launch_avg_pool2d_channels_last_kernel(
const int total_elements,
const Tensor& input,
const int64_t channels,
const int64_t height,
const int64_t width,
const index_t channels,
const index_t height,
const index_t width,
const int pooled_height,
const int pooled_width,
const int kernel_h,
Expand All @@ -255,7 +255,7 @@ void launch_avg_pool2d_channels_last_kernel(
const uint32_t global_range =
ceil_div<uint32_t>(total_elements, group_size) * group_size;

auto kfn = AvgPool2dChannelsLastKernelFunctor<scalar_t, accscalar_t>(
auto kfn = AvgPool2dChannelsLastKernelFunctor<scalar_t, accscalar_t, index_t>(
top_data,
bottom_data,
total_elements,
Expand All @@ -276,13 +276,13 @@ void launch_avg_pool2d_channels_last_kernel(
sycl_kernel_submit(global_range, group_size, queue, kfn);
}

template <typename scalar_t, typename accscalar_t>
template <typename scalar_t, typename accscalar_t, typename index_t>
void launch_avg_pool2d_kernel(
const int total_elements,
const Tensor& input,
const int64_t channels,
const int64_t height,
const int64_t width,
const index_t channels,
const index_t height,
const index_t width,
const int pooled_height,
const int pooled_width,
const int kernel_h,
Expand All @@ -303,7 +303,7 @@ void launch_avg_pool2d_kernel(
const uint32_t global_range =
ceil_div<uint32_t>(total_elements, group_size) * group_size;

auto kfn = AvgPool2dKernelFunctor<scalar_t, accscalar_t>(
auto kfn = AvgPool2dKernelFunctor<scalar_t, accscalar_t, index_t>(
top_data,
bottom_data,
total_elements,
Expand Down Expand Up @@ -664,58 +664,67 @@ void avg_pool2d_kernel(
AT_DISPATCH_FLOATING_TYPES_AND2(
kHalf, kBFloat16, input.scalar_type(), "avg_pool2d_xpu", [&] {
using accscalar_t = acc_type_device<scalar_t, kXPU>;

switch (memory_format) {
case MemoryFormat::ChannelsLast: {
output.unsafeGetTensorImpl()->empty_tensor_restride(
MemoryFormat::ChannelsLast);
launch_avg_pool2d_channels_last_kernel<scalar_t, accscalar_t>(
count,
input,
nInputPlane,
inputHeight,
inputWidth,
outputHeight,
outputWidth,
kH_,
kW_,
dH_,
dW_,
padH_,
padW_,
output,
divisor_override_value,
count_include_pad,
use_divisor);
break;
}
case MemoryFormat::Contiguous: {
launch_avg_pool2d_kernel<scalar_t, accscalar_t>(
count,
input,
nInputPlane,
inputHeight,
inputWidth,
outputHeight,
outputWidth,
kH_,
kW_,
dH_,
dW_,
padH_,
padW_,
output,
divisor_override_value,
count_include_pad,
use_divisor);
break;
}
default:
TORCH_CHECK(
false,
"Unsupported memory format. Supports only "
"ChannelsLast, Contiguous");
}
AT_DISPATCH_INDEX_TYPES(
at::native::canUse32BitIndexMath(output, INT_MAX)
? ScalarType::Int
: ScalarType::Long,
"avg_pool2d_xpu",
[&] {
switch (memory_format) {
case MemoryFormat::ChannelsLast: {
output.unsafeGetTensorImpl()->empty_tensor_restride(
MemoryFormat::ChannelsLast);
launch_avg_pool2d_channels_last_kernel<
scalar_t,
accscalar_t,
index_t>(
count,
input,
nInputPlane,
inputHeight,
inputWidth,
outputHeight,
outputWidth,
kH_,
kW_,
dH_,
dW_,
padH_,
padW_,
output,
divisor_override_value,
count_include_pad,
use_divisor);
break;
}
case MemoryFormat::Contiguous: {
launch_avg_pool2d_kernel<scalar_t, accscalar_t, index_t>(
count,
input,
nInputPlane,
inputHeight,
inputWidth,
outputHeight,
outputWidth,
kH_,
kW_,
dH_,
dW_,
padH_,
padW_,
output,
divisor_override_value,
count_include_pad,
use_divisor);
break;
}
default:
TORCH_CHECK(
false,
"Unsupported memory format. Supports only "
"ChannelsLast, Contiguous");
}
});
});
}
}
Expand Down