Skip to content

Commit 28f990a

Browse files
committed
implement MatrixApproximator
1 parent 5725edd commit 28f990a

File tree

1 file changed

+53
-0
lines changed

1 file changed

+53
-0
lines changed

network/oe.py

+53
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,59 @@ def soft_clip(self, x):
138138
return (direction * (norm + self.K)).view(original_shape)
139139

140140

141+
class MatrixApproximation(nn.Module):
142+
"""
143+
Fully connected NN to learn features on top of image features in the joint embedding space.
144+
"""
145+
146+
def __init__(self, normalize, input_dim=2048, output_dim=10, K=None):
147+
"""
148+
Constructor to prepare layers for the embedding.
149+
"""
150+
super(MatrixApproximation, self).__init__()
151+
self.u = nn.Parameter(torch.randn((input_dim)))
152+
self.v = nn.Parameter(torch.randn((output_dim)))
153+
self.d = nn.Parameter(torch.randn((output_dim)))
154+
self.pad = nn.ZeroPad2d((0, 0, 0, input_dim - output_dim))
155+
156+
self.normalize = normalize
157+
self.K = K
158+
159+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
160+
161+
def forward(self, x):
162+
"""
163+
Forward pass through the model.
164+
"""
165+
W = self.pad(torch.diag(self.d)) + torch.ger(self.u, self.v)
166+
x = torch.matmul(x, W)
167+
168+
if self.normalize == 'unit_norm':
169+
original_shape = x.shape
170+
x = x.view(-1, original_shape[-1])
171+
x = F.normalize(x, p=2, dim=1)
172+
x = x.view(original_shape)
173+
elif self.normalize == 'max_norm':
174+
original_shape = x.shape
175+
x = x.view(-1, original_shape[-1])
176+
norm_x = torch.norm(x, p=2, dim=1)
177+
x[norm_x > 1.0] = F.normalize(x[norm_x > 1.0], p=2, dim=1)
178+
x = x.view(original_shape)
179+
else:
180+
if self.K:
181+
return self.soft_clip(x)
182+
else:
183+
return x
184+
return x
185+
186+
def soft_clip(self, x):
187+
original_shape = x.shape
188+
x = x.view(-1, original_shape[-1])
189+
direction = F.normalize(x, dim=1)
190+
norm = torch.norm(x, dim=1, keepdim=True)
191+
return (direction * (norm + self.K)).view(original_shape)
192+
193+
141194
class FeatCNN18(nn.Module):
142195
"""
143196
Fully connected NN to learn features on top of image features in the joint embedding space.

0 commit comments

Comments
 (0)