some layers can now export to normal nn.Module for serialization
XuankangLin committed Aug 4, 2020
1 parent 5ad1c54 commit 9a4d3db
Showing 2 changed files with 38 additions and 0 deletions.
diffabs/deeppoly.py: 19 additions, 0 deletions
@@ -445,6 +445,19 @@ class Linear(nn.Linear):
    def __str__(self):
        return f'{Dom.name}.' + super().__str__()

    @classmethod
    def from_module(cls, src: nn.Linear) -> Linear:
        with_bias = src.bias is not None
        new_lin = Linear(src.in_features, src.out_features, with_bias)
        new_lin.load_state_dict(src.state_dict())
        return new_lin

    def export(self) -> nn.Linear:
        with_bias = self.bias is not None
        lin = nn.Linear(self.in_features, self.out_features, with_bias)
        lin.load_state_dict(self.state_dict())
        return lin

    def forward(self, *ts: Union[Tensor, Ele]) -> Union[Tensor, Ele, Tuple[Tensor, ...]]:
        """
        :param ts: either Tensor, Ele, or Ele tensors
@@ -814,6 +827,9 @@ class ReLU(nn.ReLU):
    def __str__(self):
        return f'{Dom.name}.' + super().__str__()

    def export(self) -> nn.ReLU:
        return nn.ReLU()

    def forward(self, *ts: Union[Tensor, Ele]) -> Union[Tensor, Ele, Tuple[Tensor, ...]]:
        """ According to the paper, it approximates E by either of the two cases, whichever has the smaller area.
            Mathematically, it can be proved that the (linear) approximation is optimal in terms of approximated area.
@@ -933,6 +949,9 @@ class Tanh(nn.Tanh):
    def __str__(self):
        return f'{Dom.name}.' + super().__str__()

    def export(self) -> nn.Tanh:
        return nn.Tanh()

    def forward(self, *ts: Union[Tensor, Ele]) -> Union[Tensor, Ele, Tuple[Tensor, ...]]:
        """ For both LB' and UB', it chooses the smaller slope between LB-UB and LB'/UB'. Specifically,
            when L > 0, LB' chooses LB-UB, otherwise LB';
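For reference, a minimal round-trip sketch of the two new methods. Only the import path is assumed (per the file name above); everything else follows the diff:

import torch
import torch.nn as nn
from diffabs.deeppoly import Linear  # assumed import path, per diffabs/deeppoly.py above

src = nn.Linear(4, 2)

# Lift the concrete layer into the abstract domain for analysis...
abs_lin = Linear.from_module(src)

# ...then export it back to a plain nn.Linear for serialization.
plain = abs_lin.export()

# The weights survive the round trip, so the exported module can be saved
# without any diffabs-specific classes ending up in the pickle.
assert torch.equal(plain.weight, src.weight)
torch.save(plain.state_dict(), 'lin.pt')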
diffabs/interval.py: 19 additions, 0 deletions
@@ -271,6 +271,19 @@ class Linear(nn.Linear):
    def __str__(self):
        return f'{Dom.name}.' + super().__str__()

    @classmethod
    def from_module(cls, src: nn.Linear) -> Linear:
        with_bias = src.bias is not None
        new_lin = Linear(src.in_features, src.out_features, with_bias)
        new_lin.load_state_dict(src.state_dict())
        return new_lin

    def export(self) -> nn.Linear:
        with_bias = self.bias is not None
        lin = nn.Linear(self.in_features, self.out_features, with_bias)
        lin.load_state_dict(self.state_dict())
        return lin

    def forward(self, *ts: Union[Tensor, Ele]) -> Union[Tensor, Ele, Tuple[Tensor, ...]]:
        """ Re-implement the forward computation myself, because F.linear() may apply an optimization using
            torch.addmm(), which requires its inputs to be Tensors.
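To illustrate the constraint the docstring mentions, here is a self-contained sketch of the decomposed affine map; the weight, bias, and input names are hypothetical stand-ins, not from the diff:

import torch

# In diffabs the input may be an abstract Ele rather than a Tensor; plain
# tensors stand in for it here.
weight, bias, x = torch.randn(2, 4), torch.randn(2), torch.randn(3, 4)

# F.linear(x, weight, bias) may dispatch to torch.addmm(), which accepts only
# real Tensors; spelling the affine map out keeps overloaded inputs usable.
out = x.matmul(weight.t())
if bias is not None:
    out = out + bias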
@@ -416,6 +429,9 @@ class ReLU(nn.ReLU):
    def __str__(self):
        return f'{Dom.name}.' + super().__str__()

    def export(self) -> nn.ReLU:
        return nn.ReLU()

    def forward(self, *ts: Union[Tensor, Ele]) -> Union[Tensor, Ele, Tuple[Tensor, ...]]:
        return _distribute_to_super(super().forward, *ts)
    pass
@@ -425,6 +441,9 @@ class Tanh(nn.Tanh):
    def __str__(self):
        return f'{Dom.name}.' + super().__str__()

    def export(self) -> nn.Tanh:
        return nn.Tanh()

    def forward(self, *ts: Union[Tensor, Ele]) -> Union[Tensor, Ele, Tuple[Tensor, ...]]:
        return _distribute_to_super(super().forward, *ts)
    pass
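Taken together, the export() methods let an abstract network be serialized as vanilla PyTorch modules. A hedged sketch; the Sequential wrapping is illustrative, not part of this commit, and the import path is assumed:

import torch
import torch.nn as nn
from diffabs.interval import Linear, ReLU  # assumed import path, per diffabs/interval.py above

abs_net = nn.Sequential(Linear(4, 8), ReLU(), Linear(8, 2))

# Export each layer back to its torch counterpart, then save normally;
# loading the result needs no diffabs classes on the unpickling side.
plain_net = nn.Sequential(*[m.export() for m in abs_net])
torch.save(plain_net, 'net.pt')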
