Skip to content

Commit 9be8f1b

Browse files
author
Alejandro Gaston Alvarez Franceschi
committed
Add convolution for torch script
1 parent 62ef5a2 commit 9be8f1b

File tree

2 files changed

+67
-47
lines changed

2 files changed

+67
-47
lines changed

coremltools/converters/mil/frontend/torch/ops.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,10 @@ def get_bindings(alist) -> List[Any]:
217217

218218
for i in alist:
219219
if isinstance(i, str):
220-
results.append(context[i])
220+
try:
221+
results.append(context[i])
222+
except ValueError:
223+
results.append(None)
221224
elif isinstance(i, (list, tuple)) and all(isinstance(j, int) for j in i):
222225
results.append(mb.const(val=i))
223226
elif isinstance(i, (list, tuple)):
@@ -962,7 +965,7 @@ def linear(context, node):
962965
context.add(res, torch_name=node.name)
963966

964967

965-
@register_torch_op(torch_alias=["conv2d", "convolution"])
968+
@register_torch_op(torch_alias=["convolution", "conv1d", "conv2d", "conv3d", "conv_transpose1d", "conv_transpose2d", "conv_transpose3d"])
966969
def _convolution(context, node):
967970
inputs = _get_inputs(context, node)
968971

@@ -980,11 +983,25 @@ def _convolution(context, node):
980983
# we require a (2 * n)-tuple, where n is the number of spatial dimensions, start and end for each spatial dimension
981984
pad = inputs[4].val
982985

983-
if len(weight.shape) in (3, 4):
984-
# 1D and 2D: Need to explicitly state L-R, T-B pad
986+
if type(pad) == str:
987+
if pad == "same":
988+
pad = 1
989+
elif pad == "valid":
990+
pad = 0
991+
else:
992+
raise ValueError(f"Unknown padding string value: '{pad}'")
993+
994+
if len(weight.shape) == 3:
995+
# 1D padding: need to explicitly state L-R pad for the x dim
985996
pad = _np.repeat(pad, 2)
997+
elif len(weight.shape) == 4:
998+
# 2D padding: need to explicitly state L-R pad for the x,y dims
999+
if type(pad) == int:
1000+
pad = _np.repeat(pad, 4)
1001+
elif len(pad) == 2:
1002+
pad = _np.repeat(pad, 2)
9861003
elif len(weight.shape) == 5:
987-
# 3D: Need to explicitly state F-Bk, L-R, T-B pad
1004+
# 3D padding: need to explicitly state L-R pad for the x,y,z dims
9881005
if type(pad) == int:
9891006
pad = _np.repeat(pad, 6)
9901007
elif len(pad) == 3:
@@ -1000,6 +1017,11 @@ def _convolution(context, node):
10001017
transposed = inputs[6].val
10011018
out_pad = inputs[7].val
10021019
group = inputs[8]
1020+
elif len(inputs) == 8:
1021+
transposed = True
1022+
out_pad = inputs[5].val
1023+
dilations = inputs[7]
1024+
group = inputs[6]
10031025
elif len(inputs) == 7:
10041026
transposed = False
10051027
group = inputs[6]

coremltools/converters/mil/frontend/torch/test/test_torch_ops.py

Lines changed: 40 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -226,29 +226,6 @@ def forward(self, x):
226226
use_scripting=True,
227227
)
228228

229-
@pytest.mark.parametrize("compute_unit, backend", itertools.product(compute_units, backends))
230-
def test_conv(self, compute_unit, backend):
231-
pytest.xfail(
232-
"rdar://88194776 ([Converter] coremltools is not working with scripted torch convolution model)"
233-
)
234-
model = torch.nn.Conv2d(
235-
in_channels=2,
236-
out_channels=3,
237-
kernel_size=1,
238-
padding="same",
239-
stride=1,
240-
dilation=1,
241-
groups=1,
242-
bias=False,
243-
)
244-
self.run_compare_torch(
245-
(1, 2, 4, 5),
246-
model,
247-
backend=backend,
248-
compute_unit=compute_unit,
249-
use_scripting=True,
250-
)
251-
252229

253230
class TestMean(TorchBaseTest):
254231
@pytest.mark.parametrize(
@@ -1456,6 +1433,7 @@ class TestConv(TorchBaseTest):
14561433
[
14571434
"compute_unit",
14581435
"backend",
1436+
"scripting",
14591437
"padding",
14601438
"stride",
14611439
"length",
@@ -1467,10 +1445,11 @@ class TestConv(TorchBaseTest):
14671445
]
14681446
),
14691447
[
1470-
(compute_unit, backend, padding, stride, *param)
1471-
for compute_unit, backend, padding, stride, param in itertools.product(
1448+
(compute_unit, backend, scripting, padding, stride, *param)
1449+
for compute_unit, backend, scripting, padding, stride, param in itertools.product(
14721450
[ct.ComputeUnit.CPU_ONLY],
14731451
backends,
1452+
[True, False],
14741453
["same", "valid", 0, 1],
14751454
[1, 2, 3],
14761455
[
@@ -1490,6 +1469,7 @@ def test_convolution1d(
14901469
self,
14911470
compute_unit,
14921471
backend,
1472+
scripting,
14931473
padding,
14941474
stride,
14951475
length,
@@ -1503,6 +1483,7 @@ def test_convolution1d(
15031483
if padding == "same" and stride != 1:
15041484
# configuration not supported
15051485
return
1486+
15061487
model = nn.Conv1d(
15071488
in_channels=in_channels,
15081489
out_channels=out_channels,
@@ -1511,19 +1492,22 @@ def test_convolution1d(
15111492
padding=padding,
15121493
dilation=dilation,
15131494
bias=bias,
1495+
groups=groups,
15141496
)
15151497
self.run_compare_torch(
15161498
(1, in_channels, length),
15171499
model,
15181500
backend=backend,
15191501
compute_unit=compute_unit,
1502+
use_scripting=scripting,
15201503
)
15211504

15221505
@pytest.mark.parametrize(
15231506
",".join(
15241507
[
15251508
"compute_unit",
15261509
"backend",
1510+
"scripting",
15271511
"padding",
15281512
"stride",
15291513
"height",
@@ -1536,10 +1520,11 @@ def test_convolution1d(
15361520
]
15371521
),
15381522
[
1539-
(compute_unit, backend, padding, stride, *param)
1540-
for compute_unit, backend, padding, stride, param in itertools.product(
1523+
(compute_unit, backend, scripting, padding, stride, *param)
1524+
for compute_unit, backend, scripting, padding, stride, param in itertools.product(
15411525
[ct.ComputeUnit.CPU_ONLY],
15421526
backends,
1527+
[True, False],
15431528
["same", "valid", 1, 0],
15441529
[1, 2, 3],
15451530
[
@@ -1559,6 +1544,7 @@ def test_convolution2d(
15591544
self,
15601545
compute_unit,
15611546
backend,
1547+
scripting,
15621548
padding,
15631549
stride,
15641550
height,
@@ -1571,7 +1557,9 @@ def test_convolution2d(
15711557
groups=1,
15721558
):
15731559
if padding == "same" and stride != 1:
1560+
# configuration not supported
15741561
return
1562+
15751563
model = nn.Conv2d(
15761564
in_channels=in_channels,
15771565
out_channels=out_channels,
@@ -1580,19 +1568,22 @@ def test_convolution2d(
15801568
padding=padding,
15811569
dilation=dilation,
15821570
bias=bias,
1571+
groups=groups,
15831572
)
15841573
self.run_compare_torch(
15851574
(1, in_channels, height, width),
15861575
model,
15871576
backend=backend,
15881577
compute_unit=compute_unit,
1578+
use_scripting=scripting,
15891579
)
15901580

15911581
@pytest.mark.parametrize(
15921582
",".join(
15931583
[
15941584
"compute_unit",
15951585
"backend",
1586+
"scripting",
15961587
"padding",
15971588
"stride",
15981589
"depth",
@@ -1606,10 +1597,11 @@ def test_convolution2d(
16061597
]
16071598
),
16081599
[
1609-
(compute_unit, backend, padding, stride, *param)
1610-
for compute_unit, backend, padding, stride, param in itertools.product(
1600+
(compute_unit, backend, scripting, padding, stride, *param)
1601+
for compute_unit, backend, scripting, padding, stride, param in itertools.product(
16111602
[ct.ComputeUnit.CPU_ONLY],
16121603
backends,
1604+
[True, False],
16131605
["same", "valid", 1, 0],
16141606
[1, 2, 3],
16151607
[
@@ -1629,6 +1621,7 @@ def test_convolution3d(
16291621
self,
16301622
compute_unit,
16311623
backend,
1624+
scripting,
16321625
padding,
16331626
stride,
16341627
depth,
@@ -1642,52 +1635,57 @@ def test_convolution3d(
16421635
groups=1,
16431636
):
16441637
if padding == "same" and stride != 1:
1638+
# configuration not supported
16451639
return
1640+
16461641
model = nn.Conv3d(
16471642
in_channels=in_channels,
16481643
out_channels=out_channels,
16491644
kernel_size=kernel_size,
1645+
bias=bias,
16501646
stride=stride,
16511647
padding=padding,
16521648
dilation=dilation,
1653-
bias=bias,
1649+
groups=groups,
16541650
)
16551651
self.run_compare_torch(
16561652
(1, in_channels, depth, height, width),
16571653
model,
16581654
backend=backend,
16591655
compute_unit=compute_unit,
1656+
use_scripting=scripting,
16601657
)
16611658

16621659

1663-
class TestDynamicConv(TorchBaseTest):
1660+
class TestFunctionalConv(TorchBaseTest):
16641661
@pytest.mark.parametrize(
16651662
",".join(
16661663
[
16671664
"compute_unit",
16681665
"backend",
1666+
"padding",
16691667
"width",
16701668
"in_channels",
16711669
"out_channels",
16721670
"kernel_size",
16731671
"stride",
1674-
"padding",
16751672
]
16761673
),
16771674
[
1678-
(compute_unit, backend, *param)
1679-
for compute_unit, backend, param in itertools.product(
1675+
(compute_unit, backend, padding, *param)
1676+
for compute_unit, backend, padding, param in itertools.product(
16801677
compute_units,
16811678
backends,
1679+
["same", "valid", 1, 0],
16821680
[
1683-
(5, 1, 1, 1, 2, 1),
1684-
(3, 1, 1, 1, 2, 3),
1685-
(4, 3, 3, 1, 2, 1),
1686-
(7, 3, 3, 1, 3, 1),
1687-
(5, 3, 3, 2, 2, 1),
1688-
(3, 3, 3, 1, 3, 1),
1689-
(3, 3, 3, 1, 3, 3),
1690-
(7, 3, 3, 3, 1, 3),
1681+
(5, 1, 1, 1, 2),
1682+
(3, 1, 1, 1, 2),
1683+
(4, 3, 3, 1, 2),
1684+
(7, 3, 3, 1, 3),
1685+
(5, 3, 3, 2, 2),
1686+
(3, 3, 3, 1, 3),
1687+
(3, 3, 3, 1, 3),
1688+
(7, 3, 3, 3, 1),
16911689
],
16921690
)
16931691
],

0 commit comments

Comments
 (0)