Skip to content

Commit 540c9c5

Browse files
fix conv transpose dispatch into wrong kernel for forward path (#109)
1 parent dbe4de7 commit 540c9c5

File tree

9 files changed

+275
-38
lines changed

9 files changed

+275
-38
lines changed

tests/cpu/test_bf16_lazy_reorder.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,35 @@ def test_Conv2d_with_cpu(self):
6262
self.assertTrue(ipex.core.is_bf16_dil_tensor(res_auto_bf16))
6363
self.assertEqual(res_man_bf16.float(), res_auto_bf16.float(), 1e-2)
6464

65+
class TestDeconv(TestCase):
66+
def test_Deconv2d_with_cpu(self):
67+
rand_seed = int(get_rand_seed())
68+
print("{} rand sed: {}".format(sys._getframe().f_code.co_name, rand_seed))
69+
torch.manual_seed(rand_seed)
70+
71+
_deconv = torch.nn.ConvTranspose2d(2, 3, (3, 3))
72+
73+
bn_man_bf16 =copy.deepcopy(_deconv).to(device=device).to(torch.bfloat16)
74+
bn_auto_mix =copy.deepcopy(_deconv).to(device=device)
75+
76+
_in_cpu = torch.rand((1, 2, 7, 7))
77+
in_auto_mix = _in_cpu.to(device=device)
78+
in_man_bf16 = in_auto_mix.to(torch.bfloat16)
79+
80+
res_cpu_fp32 = _deconv(_in_cpu)
81+
82+
with AutoDNNL(True), AutoMixPrecision(False):
83+
res_man_bf16 = bn_man_bf16(in_man_bf16)
84+
self.assertEqual(res_man_bf16.dtype, torch.bfloat16)
85+
self.assertEqual(res_cpu_fp32.bfloat16().float(), res_man_bf16, 1e-2)
86+
87+
with AutoMixPrecision(True):
88+
self.assertEqual(in_auto_mix.dtype, torch.float)
89+
self.assertFalse(ipex.core.is_bf16_dil_tensor(in_auto_mix))
90+
res_auto_bf16 = bn_auto_mix(in_auto_mix)
91+
self.assertTrue(ipex.core.is_bf16_dil_tensor(res_auto_bf16))
92+
self.assertEqual(res_man_bf16.float(), res_auto_bf16.float(), 1e-2)
93+
6594
class TestBatchNorm(TestCase):
6695
def test_batch_norm2d(self):
6796
rand_seed = int(get_rand_seed())

tests/cpu/test_lazy_reorder.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,52 @@ def test_seq_conv(self):
105105
res_dpcpp = self._seq_conf(device, rand_seed)
106106
self.assertEqual(res_cpu, res_dpcpp.to('cpu'))
107107

108+
class TestDeonv(TestCase):
109+
def test_Deonv2d_with_cpu(self):
110+
rand_seed = int(get_rand_seed())
111+
print("{} rand sed: {}".format(sys._getframe().f_code.co_name, rand_seed))
112+
torch.manual_seed(rand_seed)
113+
deconv_cpu = torch.nn.ConvTranspose2d(2, 3, (3, 3))
114+
deconv_dpcpp = torch.nn.ConvTranspose2d(2, 3, (3, 3)).to(device=device)
115+
116+
deconv_dpcpp.weight.data = deconv_cpu.weight.data.to(device=device)
117+
deconv_dpcpp.bias.data = deconv_cpu.bias.data.to(device=device)
118+
119+
input_cpu = torch.rand((1, 2, 7, 7))
120+
input_dpcpp = input_cpu.to(device=device)
121+
122+
ipex.core.enable_auto_dnnl()
123+
out_dpcpp = deconv_dpcpp(input_dpcpp)
124+
125+
ipex.core.disable_auto_dnnl()
126+
out_dpcpp_cpu = out_dpcpp.to('cpu')
127+
out_cpu = deconv_cpu(input_cpu)
128+
self.assertEqual(out_dpcpp.size(), out_cpu.size())
129+
self.assertEqual(out_cpu, out_dpcpp_cpu)
130+
131+
def _seq_conf(self, device, rand_seed):
132+
torch.manual_seed(rand_seed)
133+
deconv_dpcpp1 = torch.nn.ConvTranspose2d(2, 3, (7, 7)).to(device=device)
134+
deconv_dpcpp2 = torch.nn.ConvTranspose2d(3, 4, (5, 5)).to(device=device)
135+
deconv_dpcpp3 = torch.nn.ConvTranspose2d(4, 5, (3, 3)).to(device=device)
136+
input_cpu = torch.rand((1, 2, 10, 10))
137+
input_dpcpp = input_cpu.to(device=device)
138+
139+
out_dpcpp1 = deconv_dpcpp1(input_dpcpp)
140+
out_dpcpp2 = deconv_dpcpp2(out_dpcpp1)
141+
out_dpcpp3 = deconv_dpcpp3(out_dpcpp2)
142+
return out_dpcpp3
143+
144+
def test_seq_deconv(self):
145+
ipex.core.disable_auto_dnnl()
146+
rand_seed = int(get_rand_seed())
147+
print("{} rand sed: {}".format(sys._getframe().f_code.co_name, rand_seed))
148+
res_cpu = self._seq_conf('cpu', rand_seed)
149+
150+
ipex.core.enable_auto_dnnl()
151+
res_dpcpp = self._seq_conf(device, rand_seed)
152+
self.assertEqual(res_cpu, res_dpcpp.to('cpu'))
153+
108154
class TestBinaryOp(TestCase):
109155
def test_add(self):
110156
ipex.core.enable_auto_dnnl()

torch_ipex/csrc/cpu/DevOPs.cpp

Lines changed: 74 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "torch_ipex/csrc/utils.h"
1414
#include "dbl/Common.h"
1515
#include "dbl/Conv.h"
16+
#include "dbl/Deconv.h"
1617
#include "dbl/Pool.h"
1718
#include "dbl/DNNLChecker.h"
1819
#include "ShadeDataContext.h"
@@ -60,11 +61,11 @@ at::Tensor AtenIpexCPUDev::dil_convolution(
6061
}
6162

6263
dbl::comm::reorder_to_bf16_for_mix_prec(weight);
63-
dbl::conv::prepack_conv_weights(input, dil_input,
64+
dbl::conv::prepack_conv_weights(input, dil_input,
6465
weight, stride, padding, dilation, groups);
6566
dil_weight = dbl::comm::try_gen_dil_tensor(weight);
6667

67-
dil::tensor dil_output = dbl::conv::conv2d_impl(
68+
dil::tensor dil_output = dbl::conv::convolution_impl(
6869
dil_input,
6970
dil_weight,
7071
dil_bias,
@@ -172,6 +173,53 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> AtenIpexCPUDev::dil_convolution_bac
172173
return std::make_tuple(grad_input, grad_weight, grad_bias);
173174
}
174175

176+
at::Tensor AtenIpexCPUDev::dil_deconvolution(
177+
const at::Tensor & input,
178+
const at::Tensor & weight,
179+
const at::Tensor & bias,
180+
at::IntArrayRef padding,
181+
at::IntArrayRef output_padding,
182+
at::IntArrayRef stride,
183+
at::IntArrayRef dilation,
184+
int64_t groups) {
185+
DEBUG("AtenIpexCPUDev::dil_deconvolution\n");
186+
dil::tensor dil_input;
187+
dil::tensor dil_weight;
188+
c10::optional<dil::tensor> dil_bias{c10::nullopt};
189+
190+
CHECK_DNNL_OP_PRE_COND(input);
191+
CHECK_DNNL_OP_PRE_COND(weight);
192+
193+
dbl::comm::reorder_to_bf16_for_mix_prec(input);
194+
dil_input = dbl::comm::try_gen_dil_tensor(input);
195+
196+
if (bias.defined()) {
197+
CHECK_DNNL_OP_PRE_COND(bias);
198+
dbl::comm::reorder_to_bf16_for_mix_prec(bias);
199+
dil_bias = dbl::comm::try_gen_dil_tensor(bias);
200+
}
201+
202+
dbl::comm::reorder_to_bf16_for_mix_prec(weight);
203+
204+
// TODO
205+
// dbl::deconv::prepack_deconv_weights(input, dil_input,
206+
// weight, stride, padding, dilation, groups);
207+
208+
dil_weight = dbl::comm::try_gen_dil_tensor(weight).transpose_(0, 1);
209+
210+
dil::tensor dil_output = dbl::deconv::deconvolution_impl(
211+
dil_input,
212+
dil_weight,
213+
dil_bias,
214+
padding,
215+
output_padding,
216+
stride,
217+
dilation,
218+
groups);
219+
220+
return dbl::comm::gen_aten_tensor_by(std::move(dil_output));
221+
}
222+
175223
at::Tensor AtenIpexCPUDev::dil_convolution_overrideable(const at::Tensor & input, const at::Tensor & weight, const at::Tensor & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups) {
176224
DEBUG("AtenIpexCPUDev::convolution_overrideable\n");
177225

@@ -184,7 +232,11 @@ at::Tensor AtenIpexCPUDev::dil_convolution_overrideable(const at::Tensor & input
184232
dnnl_input_tensors.push_back(bias);
185233
}
186234
if (dbl::chk::dnnl_support_the_tensors(dnnl_input_tensors))
187-
return AtenIpexCPUDev::dil_convolution(input.is_contiguous() ? input : input.contiguous(), weight.is_contiguous() ? weight : weight.contiguous(), bias.defined() && !bias.is_contiguous() ? bias.contiguous() : bias, stride, padding, dilation, groups);
235+
if (transposed) {
236+
return AtenIpexCPUDev::dil_deconvolution(input.is_contiguous() ? input : input.contiguous(), weight.is_contiguous() ? weight : weight.contiguous(), bias.defined() && !bias.is_contiguous() ? bias.contiguous() : bias, padding, output_padding, stride, dilation, groups);
237+
} else {
238+
return AtenIpexCPUDev::dil_convolution(input.is_contiguous() ? input : input.contiguous(), weight.is_contiguous() ? weight : weight.contiguous(), bias.defined() && !bias.is_contiguous() ? bias.contiguous() : bias, stride, padding, dilation, groups);
239+
}
188240
}
189241
} catch (std::exception& e) {
190242
#if defined(_DEBUG)
@@ -198,43 +250,34 @@ at::Tensor AtenIpexCPUDev::dil_convolution_overrideable(const at::Tensor & input
198250
auto&& _ipex_input = bridge::shallowFallbackToCPUTensor(input);
199251
auto&& _ipex_weight = bridge::shallowFallbackToCPUTensor(weight);
200252
auto&& _ipex_bias = bridge::shallowFallbackToCPUTensor(bias);
201-
auto&& _ipex_result = at::mkldnn_convolution(_ipex_input, _ipex_weight, _ipex_bias, padding, stride, dilation, groups);
253+
auto&& _ipex_result = at::convolution(_ipex_input, _ipex_weight, _ipex_bias, stride, padding, dilation, transposed, output_padding, groups);
202254
static_cast<void>(_ipex_result); // Avoid warnings in case not used
203255
return bridge::shallowUpgradeToDPCPPTensor(_ipex_result);
204256
}
205257

206-
at::Tensor AtenIpexCPUDev::mkldnn_convolution(const at::Tensor & self, const at::Tensor & weight, const at::Tensor & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups) {
207-
DEBUG("AtenIpexCPUDev::mkldnn_convolution\n");
208-
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self.defined());
209-
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(weight.defined());
210-
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self.layout() == c10::kStrided);
211-
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(weight.layout() == c10::kStrided);
212-
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!(bias.defined()) || (bias.defined() && bias.layout() == c10::kStrided));
213-
auto&& _ipex_self = bridge::shallowFallbackToCPUTensor(self);
214-
auto&& _ipex_weight = bridge::shallowFallbackToCPUTensor(weight);
215-
auto&& _ipex_bias = bridge::shallowFallbackToCPUTensor(bias);
216-
auto&& _ipex_result = at::mkldnn_convolution(_ipex_self.contiguous(), _ipex_weight.contiguous(), _ipex_bias.contiguous(), padding, stride, dilation, groups);
217-
static_cast<void>(_ipex_result); // Avoid warnings in case not used
218-
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_ipex_result.is_contiguous());
219-
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_ipex_result.layout() == c10::kStrided);
220-
return bridge::shallowUpgradeToDPCPPTensor(_ipex_result);
221-
}
222-
223258
std::tuple<at::Tensor,at::Tensor,at::Tensor> AtenIpexCPUDev::dil_convolution_backward_overrideable(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, std::array<bool,3> output_mask) {
224259
DEBUG("AtenIpexCPUDev::convolution_backward_overrideable\n");
225260
// NOTE: DO NOT always call contiguous. It may break lazy-reorder. Because contiguous will call reorder instantly.
226261
if (check_auto_dnnl()) {
227-
return dil_convolution_backward(
228-
input.is_contiguous() ? input : input.contiguous(),
229-
grad_output.is_contiguous() ? grad_output : grad_output.contiguous(),
230-
weight.is_contiguous() ? weight : weight.contiguous(),
231-
padding,
232-
stride,
233-
dilation,
234-
groups,
235-
output_mask);
262+
if (transposed) {
263+
IPEX_CHECK(false, "deconvolution backward not support for dnnl path now");
264+
} else {
265+
return AtenIpexCPUDev::dil_convolution_backward(
266+
input.is_contiguous() ? input : input.contiguous(),
267+
grad_output.is_contiguous() ? grad_output : grad_output.contiguous(),
268+
weight.is_contiguous() ? weight : weight.contiguous(),
269+
padding,
270+
stride,
271+
dilation,
272+
groups,
273+
output_mask);
274+
}
236275
} else {
237-
return mkldnn_convolution_backward(input, grad_output, weight, padding, stride, dilation, groups, output_mask);
276+
if (transposed) {
277+
IPEX_CHECK(false, "deconvolution backward not support for native path now");
278+
} else {
279+
return AtenIpexCPUDev::mkldnn_convolution_backward(input, grad_output, weight, padding, stride, dilation, groups, output_mask);
280+
}
238281
}
239282
}
240283

torch_ipex/csrc/cpu/DevOPs.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@ class AtenIpexCPUDev {
1212
static at::Tensor dil_convolution_overrideable(const at::Tensor & input, const at::Tensor & weight, const at::Tensor & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups);
1313
static std::tuple<at::Tensor,at::Tensor,at::Tensor> dil_convolution_backward_overrideable(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, std::array<bool,3> output_mask);
1414
// aten::mkldnn_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups) -> Tensor
15-
static at::Tensor mkldnn_convolution(const at::Tensor & self, const at::Tensor & weight, const at::Tensor & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups);
1615
static std::tuple<at::Tensor,at::Tensor,at::Tensor> mkldnn_convolution_backward(const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, std::array<bool,3> output_mask);
1716

1817
// For DNNL OPs
1918
static at::Tensor dil_convolution(const at::Tensor & input, const at::Tensor & weight, const at::Tensor & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups);
2019
static std::tuple<at::Tensor,at::Tensor,at::Tensor> dil_convolution_backward(const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, std::array<bool,3> output_mask);
20+
static at::Tensor dil_deconvolution(const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias, at::IntArrayRef padding, at::IntArrayRef ouput_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups);
21+
// static std::tuple<at::Tensor,at::Tensor,at::Tensor> dil_deconvolution_backward(const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, std::array<bool,3> output_mask);
2122
static at::Tensor& dil_add_out(at::Tensor& result, const at::Tensor& self, const at::Tensor& other, at::Scalar alpha);
2223
static at::Tensor dil_add(const at::Tensor& self, const at::Tensor& other, at::Scalar alpha);
2324
static at::Tensor & dil_add_(at::Tensor & self, const at::Tensor & other, at::Scalar alpha);

torch_ipex/csrc/cpu/FusionOPs.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ at::Tensor AtenIpexJITDev::dil_convolution_relu(
5252
weight_contiguous, stride, padding, dilation, groups);
5353
dil_weight = try_gen_dil_tensor(weight_contiguous);
5454

55-
dil::tensor dil_output = dbl::conv::conv2d_impl(
55+
dil::tensor dil_output = dbl::conv::convolution_impl(
5656
dil_input,
5757
dil_weight,
5858
dil_bias,
@@ -100,7 +100,7 @@ static at::Tensor& dil_convolution_inplace_fusion(
100100
weight_contiguous, stride, padding, dilation, groups);
101101
dil_weight = try_gen_dil_tensor(weight_contiguous);
102102

103-
dbl::conv::conv2d_inplace_impl(
103+
dbl::conv::convolution_inplace_impl(
104104
dil_input,
105105
dil_weight,
106106
dil_bias,

torch_ipex/csrc/cpu/dbl/Conv.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ std::vector<int64_t> calc_conv_output_size(
2525
return output_size;
2626
}
2727

28-
dil::tensor conv2d_impl(
28+
dil::tensor convolution_impl(
2929
const dil::tensor& x,
3030
const dil::tensor& w,
3131
const c10::optional<dil::tensor>& b,
@@ -87,7 +87,7 @@ dil::tensor conv2d_impl(
8787
return y;
8888
}
8989

90-
void conv2d_inplace_impl(
90+
void convolution_inplace_impl(
9191
const dil::tensor& x,
9292
const dil::tensor& w,
9393
const c10::optional<dil::tensor>& b,

torch_ipex/csrc/cpu/dbl/Conv.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ std::vector<int64_t> calc_conv_output_size(
1818
at::IntArrayRef stride,
1919
at::IntArrayRef dilation);
2020

21-
dil::tensor conv2d_impl(
21+
dil::tensor convolution_impl(
2222
const dil::tensor& x,
2323
const dil::tensor& w,
2424
const c10::optional<dil::tensor>& b,
@@ -28,7 +28,7 @@ dil::tensor conv2d_impl(
2828
int64_t groups,
2929
const dil::attr_t& attr = dil::attr_t());
3030

31-
void conv2d_inplace_impl(
31+
void convolution_inplace_impl(
3232
const dil::tensor& x,
3333
const dil::tensor& w,
3434
const c10::optional<dil::tensor>& b,

torch_ipex/csrc/cpu/dbl/Deconv.cpp

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
#include "Deconv.h"
2+
3+
#include "Common.h"
4+
#include "cpu/ShadeDataContext.h"
5+
6+
namespace torch_ipex {
7+
namespace cpu {
8+
namespace dbl {
9+
namespace deconv {
10+
11+
std::vector<int64_t> calc_deconv_input_size(
12+
at::IntArrayRef output_size,
13+
at::IntArrayRef kernel_size,
14+
at::IntArrayRef padding,
15+
at::IntArrayRef output_padding,
16+
at::IntArrayRef stride,
17+
at::IntArrayRef dilation,
18+
int64_t groups) {
19+
auto dim = output_size.size();
20+
std::vector<int64_t> input_size(dim);
21+
input_size[0] = output_size[0];
22+
input_size[1] = kernel_size[1] * groups;
23+
for (size_t d = 2; d < dim; ++d) {
24+
auto kernel = dilation[d - 2] * (kernel_size[d] - 1) + 1;
25+
input_size[d] = (output_size[d] - 1) * stride[d - 2] - (2 * padding[d - 2]) +
26+
kernel + output_padding[d - 2];
27+
}
28+
return input_size;
29+
}
30+
31+
dil::tensor deconvolution_impl(
32+
const dil::tensor& x,
33+
const dil::tensor& w,
34+
const c10::optional<dil::tensor>& b,
35+
at::IntArrayRef padding,
36+
at::IntArrayRef output_padding,
37+
at::IntArrayRef stride,
38+
at::IntArrayRef dilation,
39+
int64_t groups,
40+
const dil::attr_t& attr) {
41+
const dil::dims x_dims = x.get_dims();
42+
const dil::dims w_dims = w.get_dims();
43+
std::vector<int64_t> input_size{x_dims.cbegin(), x_dims.cend()};
44+
std::vector<int64_t> kernel_size{w_dims.cbegin(), w_dims.cend()};
45+
std::swap(kernel_size[0], kernel_size[1]);
46+
std::vector<int64_t> output_sizes = calc_deconv_input_size(input_size, kernel_size, padding, output_padding, stride, dilation, groups);
47+
48+
dil::tensor y;
49+
if (b.has_value()) {
50+
dil::convolution_transpose_forward::compute(
51+
x,
52+
w,
53+
b.value(),
54+
{output_sizes.cbegin(), output_sizes.cend()},
55+
y,
56+
{stride.begin(), stride.end()},
57+
{padding.begin(), padding.end()},
58+
{padding.begin(), padding.end()},
59+
{dilation.begin(), dilation.end()},
60+
groups,
61+
attr);
62+
} else {
63+
dil::convolution_transpose_forward::compute(
64+
x,
65+
w,
66+
{output_sizes.cbegin(), output_sizes.cend()},
67+
y,
68+
{stride.begin(), stride.end()},
69+
{padding.begin(), padding.end()},
70+
{padding.begin(), padding.end()},
71+
{dilation.begin(), dilation.end()},
72+
groups,
73+
attr);
74+
}
75+
return y;
76+
}
77+
78+
} // namespace deconv
79+
} // namespace dbl
80+
} // namespace cpu
81+
} // namespace torch_ipex

0 commit comments

Comments
 (0)