fix_utils.py
from typing import Tuple, Optional, List, Dict

import torch
import torch.nn as nn

from dalib.modules.grl import WarmStartGradientReverseLayer


def shift_log(x: torch.Tensor, offset: Optional[float] = 1e-6) -> torch.Tensor:
    """Shift ``x`` by ``offset`` and clamp the result to at most 1 before taking the log, so log(0) is never evaluated."""
    return torch.log(torch.clamp(x + offset, max=1.))
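
# A quick doctest-style illustration of shift_log (the example tensor is an
# assumption chosen for clarity): exact zeros map to log(offset) rather than
# -inf, and the clamp keeps the result from exceeding 0.
# >>> shift_log(torch.tensor([0.0, 0.5, 1.0]))   # approx. [-13.8155, -0.6931, 0.0]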


class ImageClassifier(nn.Module):
    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck: Optional[nn.Module] = None,
                 bottleneck_dim: Optional[int] = -1, head: Optional[nn.Module] = None, finetune=True):
        super(ImageClassifier, self).__init__()
        # Backbone followed by global average pooling and flattening, so features
        # come out as a (batch, backbone.out_features) matrix.
        self.backbone = nn.Sequential(
            backbone,
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten()
        )
        self.num_classes = num_classes
        if bottleneck is None and bottleneck_dim > 0:
            # Default bottleneck: a single projection layer with batch norm and ReLU.
            bottleneck = nn.Sequential(
                nn.Linear(backbone.out_features, bottleneck_dim),
                nn.BatchNorm1d(bottleneck_dim),
                nn.ReLU()
            )
        if bottleneck is None:
            # No bottleneck: pass the backbone features through unchanged.
            self.bottleneck = nn.Identity()
            self._features_dim = backbone.out_features
        else:
            self.bottleneck = bottleneck
            assert bottleneck_dim > 0
            self._features_dim = bottleneck_dim
        if head is None:
            self.head = nn.Linear(self._features_dim, num_classes)
        else:
            self.head = head
        self.finetune = finetune

    @property
    def features_dim(self) -> int:
        """The dimension of features before the final `head` layer."""
        return self._features_dim

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return class predictions together with the backbone features they were computed from."""
        f = self.backbone(x)
        f1 = self.bottleneck(f)
        predictions = self.head(f1)
        return predictions, f

    def get_parameters(self, base_lr=1.0) -> List[Dict]:
        """Return a parameter list that sets per-group optimization hyper-parameters,
        such as the relative learning rate of each sub-module.
        """
        params = [
            # When finetuning a pretrained backbone, train it with a 10x smaller learning rate.
            {"params": self.backbone.parameters(), "lr": 0.1 * base_lr if self.finetune else 1.0 * base_lr},
            {"params": self.bottleneck.parameters(), "lr": 1.0 * base_lr},
            {"params": self.head.parameters(), "lr": 1.0 * base_lr},
        ]
        return params
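

if __name__ == "__main__":
    # Minimal usage sketch: wiring ImageClassifier to an SGD optimizer via
    # get_parameters(). Real code would pass a dalib backbone (which exposes
    # `out_features`); the tiny stand-in backbone, class count, and
    # hyper-parameters below are illustrative assumptions only.
    import torch.optim as optim

    class ToyBackbone(nn.Module):
        """Stand-in backbone exposing `out_features`, as the dalib backbones do."""
        out_features = 32

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, self.out_features, kernel_size=3, padding=1)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.conv(x)

    model = ImageClassifier(ToyBackbone(), num_classes=10, bottleneck_dim=16)
    # The backbone group gets a 10x smaller learning rate because finetune=True by default.
    optimizer = optim.SGD(model.get_parameters(base_lr=0.01), lr=0.01,
                          momentum=0.9, weight_decay=1e-3)

    images = torch.randn(4, 3, 32, 32)
    predictions, features = model(images)   # shapes: (4, 10) and (4, 32)
    loss = nn.functional.cross_entropy(predictions, torch.randint(0, 10, (4,)))
    loss.backward()
    optimizer.step()
    print(predictions.shape, features.shape)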