
Commit fb997b8

Fix bug in UCIR.
1 parent cb6efdd · commit fb997b8

2 files changed (+15, -16 lines)

inclearn/models/ucir.py

Lines changed: 14 additions & 15 deletions
@@ -40,9 +40,8 @@ def __init__(self, args):
         self._eval_every_x_epochs = args.get("eval_every_x_epochs")
 
         self._use_mimic_score = args.get("mimic_score")
-        self._use_less_forget = args.get("less_forget")
-        self._lambda_schedule = args.get("lambda_schedule", True)
-        self._use_ranking = args.get("ranking_loss")
+        self._less_forget = args.get("less_forget")
+        self._ranking_loss = args.get("ranking_loss")
 
         self._network = network.BasicNet(
             args["convnet"],
@@ -62,10 +61,6 @@ def __init__(self, args):
 
         self._finetuning_config = args.get("finetuning_config")
 
-        self._lambda = args.get("base_lambda", 5)
-        self._nb_negatives = args.get("nb_negatives", 2)
-        self._margin = args.get("ranking_margin", 0.2)
-
         self._weight_generation = args.get("weight_generation")
 
         self._herding_indexes = []
@@ -79,6 +74,10 @@ def __init__(self, args):
         self._args = args
         self._args["_logs"] = {}
 
+        self._during_finetune = False
+        self._clip_classifier = None
+        self._align_weights_after_epoch = False
+
     def _after_task(self, inc_dataset):
         if "scale" not in self._args["_logs"]:
             self._args["_logs"]["scale"] = []
@@ -205,11 +204,11 @@ def _compute_loss(self, inputs, outputs, targets, onehot_targets, memory_flags):
         old_outputs = self._old_model(inputs)
         old_features = old_outputs["raw_features"]
 
-        if self._use_less_forget:
-            if self._lambda_schedule:
-                scheduled_lambda = self._lambda * math.sqrt(self._n_classes / self._task_size)
+        if self._less_forget:
+            if self._less_forget["scheduled_factor"]:
+                scheduled_lambda = self._less_forget["lambda"] * math.sqrt(self._n_classes / self._task_size)
             else:
-                scheduled_lambda = 1.
+                scheduled_lambda = self._less_forget["lambda"]
 
             lessforget_loss = scheduled_lambda * losses.embeddings_similarity(
                 old_features, features
@@ -225,14 +224,14 @@ def _compute_loss(self, inputs, outputs, targets, onehot_targets, memory_flags):
             loss += mimic_loss
             self._metrics["mimic"] += mimic_loss.item()
 
-        if self._use_ranking:
-            ranking_loss = losses.ucir_ranking(
+        if self._ranking_loss:
+            ranking_loss = self._ranking_loss["factor"] * losses.ucir_ranking(
                 logits,
                 targets,
                 self._n_classes,
                 self._task_size,
-                nb_negatives=max(self._nb_negatives, self._task_size),
-                margin=self._margin
+                nb_negatives=min(self._ranking_loss["nb_negatives"], self._task_size),
+                margin=self._ranking_loss["margin"]
             )
             loss += ranking_loss
             self._metrics["rank"] += ranking_loss.item()

options/ucir/ucir_cifar100.yaml

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ scheduling:
 gamma: 0.1
 lr_decay: 0.1
 optimizer: sgd
-epochs: 160
+epochs: 1 #60
 
 weight_generation:
   type: imprinted
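
For clarity on the config edit, a tiny sketch of how the new `epochs` line parses; using PyYAML here is an assumption about the project's config loader, and only the edited line itself comes from the diff.

import yaml

# The edited line keeps the old value as an inline YAML comment: "#60" is
# ignored by the parser, so the config now requests a single training epoch.
print(yaml.safe_load("epochs: 1 #60"))  # {'epochs': 1}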
