Skip to content

Commit 9d6039b

Browse files
authored
fix group_conv3d caculate error (apache#12500)
1 parent 41be1b4 commit 9d6039b

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

Diff for: python/tvm/topi/x86/conv3d.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ def _conv3d_ndhwc(cfg, data, kernel, strides, padding, dilation, groups, out_dty
275275
# pack kernel
276276
shape = (
277277
num_filter // oc_bn,
278-
in_channel // groups // ic_bn,
278+
in_channel // groups // ic_bn if (in_channel // groups // ic_bn) else 1,
279279
kernel_depth,
280280
kernel_height,
281281
kernel_width,
@@ -392,7 +392,7 @@ def _conv3d_ncdhw(cfg, data, kernel, strides, padding, dilation, layout, groups,
392392
# pack kernel
393393
shape = (
394394
num_filter // oc_bn,
395-
in_channel // groups // ic_bn,
395+
in_channel // groups // ic_bn if (in_channel // groups // ic_bn) else 1,
396396
kernel_depth,
397397
kernel_height,
398398
kernel_width,

Diff for: tests/python/frontend/onnx/test_forward.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2967,7 +2967,7 @@ def repeat(num, dims):
29672967
)
29682968

29692969
# TODO(jwfromm): Merge with other tests once group_conv3d is supported.
2970-
for dims in [1, 2]:
2970+
for dims in [1, 2, 3]:
29712971
# Group Convolution
29722972
verify_conv(
29732973
(1, 8) + repeat(5, dims),

0 commit comments

Comments
 (0)