In the validation_forget() function, accuracy is called without the task argument, which raises an AssertionError:
acc = accuracy(pred, label, ignore_index=-100)
I have replaced it with
acc = accuracy(pred, label, task="multiclass", num_classes=5063, ignore_index=-100)
I found 5063 to be the number of unique labels. Is this the right fix?
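One check that may help confirm the num_classes value: as far as I understand the torchmetrics docs, multiclass targets are expected to lie in the range [0, num_classes - 1] (apart from ignore_index), so the number of unique labels is only a safe choice if the label ids are contiguous and start at 0. The sketch below is plain Python and does not call torchmetrics; infer_num_classes is a hypothetical helper name, and the label list is made-up illustration data, not the real dataset.

```python
def infer_num_classes(labels, ignore_index=-100):
    """Return (number of unique labels, max label + 1), skipping ignore_index.

    Hypothetical helper for sanity-checking a num_classes value.
    """
    valid = [int(x) for x in labels if int(x) != ignore_index]
    if not valid:
        raise ValueError("no valid labels found")
    unique = set(valid)
    return len(unique), max(unique) + 1


# Made-up labels with a gap: ids 0, 2 and 5 are present, 1, 3 and 4 are not.
labels = [0, 2, 2, -100, 5]
n_unique, n_required = infer_num_classes(labels)

if n_unique != n_required:
    # Label ids are not contiguous, so the unique count is too small.
    print(f"unique labels: {n_unique}, but num_classes should be at least {n_required}")
```

If both numbers come out as 5063 on the real labels, then num_classes=5063 looks consistent; if the second number is larger, that larger value is probably the one to pass.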