@@ -138,6 +138,59 @@ def soft_clip(self, x):
138
138
return (direction * (norm + self .K )).view (original_shape )
139
139
140
140
141
+ class MatrixApproximation (nn .Module ):
142
+ """
143
+ Fully connected NN to learn features on top of image features in the joint embedding space.
144
+ """
145
+
146
+ def __init__ (self , normalize , input_dim = 2048 , output_dim = 10 , K = None ):
147
+ """
148
+ Constructor to prepare layers for the embedding.
149
+ """
150
+ super (MatrixApproximation , self ).__init__ ()
151
+ self .u = nn .Parameter (torch .randn ((input_dim )))
152
+ self .v = nn .Parameter (torch .randn ((output_dim )))
153
+ self .d = nn .Parameter (torch .randn ((output_dim )))
154
+ self .pad = nn .ZeroPad2d ((0 , 0 , 0 , input_dim - output_dim ))
155
+
156
+ self .normalize = normalize
157
+ self .K = K
158
+
159
+ self .device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
160
+
161
+ def forward (self , x ):
162
+ """
163
+ Forward pass through the model.
164
+ """
165
+ W = self .pad (torch .diag (self .d )) + torch .ger (self .u , self .v )
166
+ x = torch .matmul (x , W )
167
+
168
+ if self .normalize == 'unit_norm' :
169
+ original_shape = x .shape
170
+ x = x .view (- 1 , original_shape [- 1 ])
171
+ x = F .normalize (x , p = 2 , dim = 1 )
172
+ x = x .view (original_shape )
173
+ elif self .normalize == 'max_norm' :
174
+ original_shape = x .shape
175
+ x = x .view (- 1 , original_shape [- 1 ])
176
+ norm_x = torch .norm (x , p = 2 , dim = 1 )
177
+ x [norm_x > 1.0 ] = F .normalize (x [norm_x > 1.0 ], p = 2 , dim = 1 )
178
+ x = x .view (original_shape )
179
+ else :
180
+ if self .K :
181
+ return self .soft_clip (x )
182
+ else :
183
+ return x
184
+ return x
185
+
186
+ def soft_clip (self , x ):
187
+ original_shape = x .shape
188
+ x = x .view (- 1 , original_shape [- 1 ])
189
+ direction = F .normalize (x , dim = 1 )
190
+ norm = torch .norm (x , dim = 1 , keepdim = True )
191
+ return (direction * (norm + self .K )).view (original_shape )
192
+
193
+
141
194
class FeatCNN18 (nn .Module ):
142
195
"""
143
196
Fully connected NN to learn features on top of image features in the joint embedding space.
0 commit comments