-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathneural_mean_discrepancy.py
91 lines (67 loc) · 3.26 KB
/
neural_mean_discrepancy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import torch
import torch.nn as nn
from tqdm import tqdm
class Neural_Mean_Discrepancy(nn.Module):
    """Neural Mean Discrepancy (NMD) feature extractor and scorer.

    Forward hooks on the sub-modules named in ``layer_names`` capture their
    outputs. For every sample, each captured output is averaged over all dims
    except dim 1 (the channel/feature dim), and the per-layer results are
    concatenated, in ``layer_names`` order, into one feature vector.
    ``fit_in_distribution_dataset`` stores the mean of those vectors over an
    in-distribution dataset as ``self.nmf``; ``predict_nmd_unk_distribution``
    subtracts ``self.nmf`` from the feature vectors of another dataset.
    Returned tensors are left on whatever device the model produces them on
    (no ``.cpu()`` is performed).
    """

    def __init__(self, model, layer_names, device):
        """
        :param model: torch.nn.Module to probe; it is put into eval mode here.
        :param layer_names: names, as yielded by ``model.named_modules()``, of the
            modules whose outputs are used as features. A single string is
            accepted and treated as a one-element list.
        :param device: device each input image is moved to before the forward pass.
        :raises ValueError: if a name in ``layer_names`` matches no sub-module.
        """
        super().__init__()
        self.model = model
        self.model.eval()
        self.device = device
        # A bare string would otherwise be iterated character by character,
        # and ``name in "layer4"`` would do substring matching.
        self.layer_names = [layer_names] if isinstance(layer_names, str) else layer_names
        # layer name -> most recent detached output of that layer
        self.activation = {}
        # handles returned by register_forward_hook, kept so the hooks can be removed
        self._nmd_hook_handles = []
        # register the hooks using the requested layer names
        self.register_activations()

    def get_activations(self, name):
        """Return a forward hook that stores the module's detached output under ``name``."""
        # https://discuss.pytorch.org/t/how-can-l-load-my-best-model-as-a-feature-extractor-evaluator/17254/6
        def hook(module, inputs, output):
            self.activation[name] = output.detach()
        return hook

    def register_activations(self):
        """Register a forward hook on every module named in ``self.layer_names``.

        :raises ValueError: if any requested name is not a sub-module of the model.
            The check happens before any hook is attached, so the model is left
            untouched on failure.
        """
        modules = dict(self.model.named_modules())
        missing = [name for name in self.layer_names if name not in modules]
        if missing:
            raise ValueError(f"layer names not found in model: {missing}")
        for name, layer in modules.items():
            if name in self.layer_names:
                handle = layer.register_forward_hook(self.get_activations(name))
                self._nmd_hook_handles.append(handle)

    def remove_hooks(self):
        """Detach every forward hook this instance registered on the model."""
        for handle in self._nmd_hook_handles:
            handle.remove()
        self._nmd_hook_handles.clear()

    def fit_in_distribution_dataset(self, id_dataset):
        """Compute and store the neural mean feature (``self.nmf``) of ``id_dataset``.

        :param id_dataset: iterable (e.g. a torch Dataset) yielding ``(image, label)``
            pairs or bare images, each image a ``torch.Tensor`` of shape ``[C, H, W]``.
        """
        print('Fitting in-distribution dataset..')
        _, self.nmf = self.compute_activations(id_dataset)

    def predict_nmd_unk_distribution(self, ud_dataset):
        """Score a dataset of unknown distribution against the fitted ``self.nmf``.

        Requires ``fit_in_distribution_dataset`` to have been called first.

        :param ud_dataset: a list of images as ``[torch.Tensor(), ...]``, or an iterable
            (e.g. a torch Dataset) yielding ``(image, label)`` pairs.
        :return: ``(nmd_score, nmd_per_sample)`` -- ``nmd_per_sample`` is the per-sample
            feature vector minus ``self.nmf`` (shape ``[N, D]``) and ``nmd_score`` is its
            mean over samples (shape ``[D]``).
        """
        print('Predicting nmd of unknown distribution dataset..')
        nmf_per_sample, _ = self.compute_activations(ud_dataset)
        nmd_score, nmd_per_sample = self.nmd_score(nmf_per_sample)
        return nmd_score, nmd_per_sample

    # Backward-compatible alias for the original (misspelled) method name.
    predict_nmd_unk_distribtion = predict_nmd_unk_distribution

    def compute_activations(self, dataset):
        """Extract the per-sample feature vectors of ``dataset`` and their mean.

        :param dataset: iterable yielding ``(image, label)`` pairs or bare images,
            each image a ``torch.Tensor`` of shape ``[C, H, W]``.
        :return: ``(activations_per_sample, nmf)`` -- ``activations_per_sample`` has
            shape ``[N, D]`` (D = sum over hooked layers of their dim-1 size, in
            ``self.layer_names`` order) and ``nmf`` is its mean over samples, shape ``[D]``.
        :raises ValueError: if ``dataset`` yields no samples.
        """
        features = [self._sample_features(item) for item in tqdm(dataset)]
        if not features:
            raise ValueError("dataset is empty; cannot compute activations")
        activations_per_sample = torch.stack(features, dim=0)
        # neural mean feature: mean over all examples
        nmf = activations_per_sample.mean(dim=0)
        return activations_per_sample, nmf

    def _sample_features(self, item):
        """Run one sample through the model and return its concatenated per-layer channel means."""
        # Datasets usually yield (image, label); a plain list of images yields bare tensors.
        x = item[0] if isinstance(item, (tuple, list)) else item
        # Hooks store detached outputs, so no autograd graph is needed for the forward pass.
        with torch.no_grad():
            # pass single image (1, C, H, W) through model; hooks fill self.activation
            self.model(x.to(self.device).unsqueeze(0))
        return torch.cat(
            [self._channel_means(self.activation[name]) for name in self.layer_names],
            dim=0)

    @staticmethod
    def _channel_means(activation_map):
        """Average an activation over every dim except dim 1 (channels/features)."""
        # For a 4-D (1, C, H, W) map this is mean over (0, 2, 3); for a 2-D (1, D)
        # output it is mean over (0,).
        reduce_dims = (0, *range(2, activation_map.dim()))
        return activation_map.mean(dim=reduce_dims)

    def nmd_score(self, activations):
        """Subtract the fitted neural mean feature from per-sample features.

        :param activations: tensor of shape ``[N, D]`` as returned by ``compute_activations``.
        :return: ``(nmd_reduced, nmd)`` -- ``nmd = activations - self.nmf`` (shape ``[N, D]``)
            and ``nmd_reduced`` is its mean over samples (shape ``[D]``).
        """
        nmd = activations - self.nmf
        nmd_reduced = nmd.mean(dim=0)
        return nmd_reduced, nmd