This repository was archived by the owner on May 29, 2023. It is now read-only.
-
-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy path deformable_conv.py
51 lines (43 loc) · 1.5 KB
/
deformable_conv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
import torch.nn as nn
import torchvision.ops as ops
def _as_pair(value):
    """Return *value* as an ``(h, w)`` pair, repeating a scalar for both axes.

    Accepts either a single number or a 2-element tuple/list, which covers both
    the user-facing int form and the pair form torchvision stores internally.
    """
    if isinstance(value, (tuple, list)):
        height, width = value
        return height, width
    return value, value


class DeformableConvFunc(torch.autograd.Function):
    """Autograd function that maps a deformable convolution to a custom ONNX op.

    During ONNX export, :meth:`symbolic` emits a single ``DeformableConv2D``
    node, followed by an ``Add`` when the module has a bias. When run eagerly,
    :meth:`forward` delegates to the wrapped module's original
    ``torchvision.ops.DeformConv2d`` forward.

    NOTE: no ``backward`` is defined, so gradients cannot be propagated through
    this function. It is meant for inference/export only.
    """

    @staticmethod
    def symbolic(g, cls, x, offset):
        """Emit the ONNX nodes that compute ``cls(x, offset)``.

        Args:
            g: Graph context supplied by the ONNX exporter.
            cls: The ``DeformableConvolution`` module being exported (a module
                instance, despite the name). Its weight and bias are embedded
                as constants.
            x: Input tensor value.
            offset: Sampling-offset tensor value.

        Returns:
            The graph value holding the convolution output.
        """
        # DeformConv2d stores stride/padding/dilation as pairs. Normalize so
        # that both ints and (h, w) tuples produce flat attribute lists.
        stride_h, stride_w = _as_pair(cls.stride)
        pad_h, pad_w = _as_pair(cls.padding)
        dil_h, dil_w = _as_pair(cls.dilation)

        weight = g.op("Constant", value_t=cls.weight.detach())
        out = g.op(
            "DeformableConv2D",
            x,
            offset,
            weight,
            strides_i=(stride_h, stride_w),
            # ONNX pads layout: (h_begin, w_begin, h_end, w_end).
            pads_i=(pad_h, pad_w, pad_h, pad_w),
            dilations_i=(dil_h, dil_w),
            # NOTE(review): DeformConv2d's ``groups`` is the convolution group
            # count; torchvision derives deformable groups from the offset
            # channel count. These only coincide when both are 1 — confirm.
            deformable_group_i=cls.groups,
        )
        # The custom node receives no bias input. Add the bias explicitly
        # (broadcast over N, H, W) so the exported graph matches the eager
        # result.
        if cls.bias is not None:
            bias = g.op("Constant", value_t=cls.bias.detach().reshape(1, -1, 1, 1))
            out = g.op("Add", out, bias)
        return out

    @staticmethod
    def forward(ctx, cls, x, offset):
        """Run the module's original (torchvision) forward pass eagerly."""
        return cls.origin_forward(x, offset)


class DeformableConvolution(ops.DeformConv2d):
    """``torchvision.ops.DeformConv2d`` variant that can be exported to ONNX.

    Takes exactly the same constructor arguments as ``DeformConv2d``. The
    forward pass is routed through ``DeformableConvFunc``, so ONNX export
    produces a custom ``DeformableConv2D`` node.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Keep the un-overridden forward so DeformableConvFunc.forward can run
        # the real computation eagerly.
        self.origin_forward = super().forward
        # stride / padding / dilation / groups were already set by
        # DeformConv2d.__init__ from *args or **kwargs. Re-reading them from
        # kwargs alone would silently reset positionally-passed values.
        self.pad_l = nn.ConstantPad2d((1, 1, 1, 1), 0)

    def forward(self, x, offset):
        """Apply the deformable convolution to ``x`` using ``offset``.

        Using paddings is a workaround for the 2021.4 release (presumably of the
        runtime consuming the exported op — confirm). ``x`` and ``offset`` are
        zero-padded by one pixel on each spatial side, and the extra border is
        cropped from the output. The cropped result equals the unpadded
        convolution only when stride is 1. With a larger stride, the padded
        offset no longer matches the output's spatial size.
        """
        x = self.pad_l(x)
        offset = self.pad_l(offset)
        y = DeformableConvFunc.apply(self, x, offset)
        # Drop the 1-pixel border introduced by the padding above.
        return y[:, :, 1:-1, 1:-1]