@@ -40,9 +40,8 @@ def __init__(self, args):
40
40
self ._eval_every_x_epochs = args .get ("eval_every_x_epochs" )
41
41
42
42
self ._use_mimic_score = args .get ("mimic_score" )
43
- self ._use_less_forget = args .get ("less_forget" )
44
- self ._lambda_schedule = args .get ("lambda_schedule" , True )
45
- self ._use_ranking = args .get ("ranking_loss" )
43
+ self ._less_forget = args .get ("less_forget" )
44
+ self ._ranking_loss = args .get ("ranking_loss" )
46
45
47
46
self ._network = network .BasicNet (
48
47
args ["convnet" ],
@@ -62,10 +61,6 @@ def __init__(self, args):
62
61
63
62
self ._finetuning_config = args .get ("finetuning_config" )
64
63
65
- self ._lambda = args .get ("base_lambda" , 5 )
66
- self ._nb_negatives = args .get ("nb_negatives" , 2 )
67
- self ._margin = args .get ("ranking_margin" , 0.2 )
68
-
69
64
self ._weight_generation = args .get ("weight_generation" )
70
65
71
66
self ._herding_indexes = []
@@ -79,6 +74,10 @@ def __init__(self, args):
79
74
self ._args = args
80
75
self ._args ["_logs" ] = {}
81
76
77
+ self ._during_finetune = False
78
+ self ._clip_classifier = None
79
+ self ._align_weights_after_epoch = False
80
+
82
81
def _after_task (self , inc_dataset ):
83
82
if "scale" not in self ._args ["_logs" ]:
84
83
self ._args ["_logs" ]["scale" ] = []
@@ -205,11 +204,11 @@ def _compute_loss(self, inputs, outputs, targets, onehot_targets, memory_flags):
205
204
old_outputs = self ._old_model (inputs )
206
205
old_features = old_outputs ["raw_features" ]
207
206
208
- if self ._use_less_forget :
209
- if self ._lambda_schedule :
210
- scheduled_lambda = self ._lambda * math .sqrt (self ._n_classes / self ._task_size )
207
+ if self ._less_forget :
208
+ if self ._less_forget [ "scheduled_factor" ] :
209
+ scheduled_lambda = self ._less_forget [ "lambda" ] * math .sqrt (self ._n_classes / self ._task_size )
211
210
else :
212
- scheduled_lambda = 1.
211
+ scheduled_lambda = self . _less_forget [ "lambda" ]
213
212
214
213
lessforget_loss = scheduled_lambda * losses .embeddings_similarity (
215
214
old_features , features
@@ -225,14 +224,14 @@ def _compute_loss(self, inputs, outputs, targets, onehot_targets, memory_flags):
225
224
loss += mimic_loss
226
225
self ._metrics ["mimic" ] += mimic_loss .item ()
227
226
228
- if self ._use_ranking :
229
- ranking_loss = losses .ucir_ranking (
227
+ if self ._ranking_loss :
228
+ ranking_loss = self . _ranking_loss [ "factor" ] * losses .ucir_ranking (
230
229
logits ,
231
230
targets ,
232
231
self ._n_classes ,
233
232
self ._task_size ,
234
- nb_negatives = max (self ._nb_negatives , self ._task_size ),
235
- margin = self ._margin
233
+ nb_negatives = min (self ._ranking_loss [ "nb_negatives" ] , self ._task_size ),
234
+ margin = self ._ranking_loss [ "margin" ]
236
235
)
237
236
loss += ranking_loss
238
237
self ._metrics ["rank" ] += ranking_loss .item ()
0 commit comments