In the rehearsal example there's the following template:
```python
from torch.utils.data import DataLoader

from continuum import ClassIncremental, rehearsal
from continuum.datasets import CIFAR100

scenario = ClassIncremental(
    CIFAR100(data_path="my/data/path", download=True, train=True),
    increment=10,
    initial_increment=50
)
memory = rehearsal.RehearsalMemory(
    memory_size=2000,
    herding_method="barycenter"
)

for task_id, taskset in enumerate(scenario):
    if task_id > 0:
        mem_x, mem_y, mem_t = memory.get()
        taskset.add_samples(mem_x, mem_y, mem_t)

    loader = DataLoader(taskset, shuffle=True)
    for epoch in range(epochs):
        for x, y, t in loader:
            ...  # Do your training here

    # Herding based on the barycenter (as iCaRL did) needs features,
    # so we need to extract those features, but beware to use a loader
    # without shuffling.
    loader = DataLoader(taskset, shuffle=False)
    features = my_function_to_extract_features(my_model, loader)

    # Important! Draw the raw samples from `scenario[task_id]` to
    # re-generate the taskset, otherwise you'd risk sampling from both new
    # data and memory data, which is probably not what you want to do.
    memory.add(*scenario[task_id].get_raw_samples(), features)
```
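
For context, here is a minimal sketch of what `my_function_to_extract_features` might look like in my setup (both `my_function_to_extract_features` and the `model.features` head are placeholders, not part of the library):

```python
import torch

@torch.no_grad()
def my_function_to_extract_features(model, loader):
    # shuffle=False on the loader keeps the features aligned with the
    # order of the raw samples returned by get_raw_samples().
    model.eval()
    feats = []
    for x, y, t in loader:
        feats.append(model.features(x))  # hypothetical feature head
    return torch.cat(feats).cpu().numpy()
```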
How does that change if we add train/valid splits?
```python
import torch.utils.data as tud

from continuum import tasks

for task_id, taskset in enumerate(scenario):
    if task_id > 0:
        mem_x, mem_y, mem_t = memory.get()
        taskset.add_samples(mem_x, mem_y, mem_t)

    dataset_train, dataset_val = tasks.split_train_val(taskset, val_split=0.1)
    train_loader = tud.DataLoader(dataset_train, shuffle=True)
    val_loader = tud.DataLoader(dataset_val, shuffle=True)

    for epoch in range(epochs):
        for x, y, t in train_loader:
            ...  # Do your training here

    # Herding based on the barycenter (as iCaRL did) needs features,
    # so we need to extract those features, but beware to use a loader
    # without shuffling.
    unshuffled_loader = tud.DataLoader(taskset, shuffle=False)  # --> should this be taskset or dataset_train?
    features = my_function_to_extract_features(my_model, unshuffled_loader)

    # Important! Draw the raw samples from `scenario[task_id]` to
    # re-generate the taskset, otherwise you'd risk sampling from both new
    # data and memory data, which is probably not what you want to do.
    memory.add(*scenario[task_id].get_raw_samples(), features)  # --> does get_raw_samples() return all samples in the current taskset?
```
If `unshuffled_loader` uses `dataset_train`, then `len(features)` won't match `len(scenario[task_id].get_raw_samples()[0])`.
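
To make the concern concrete, here's a rough illustration of the size mismatch (the exact numbers depend on the dataset and split):

```python
raw_x, raw_y, raw_t = scenario[task_id].get_raw_samples()
print(len(raw_x))          # all raw samples of the current task
print(len(dataset_train))  # only ~90% of the taskset after val_split=0.1,
                           # so the features wouldn't line up one-to-one
```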
Another question: do we need to add samples to the memory buffer from both the train and valid splits, or just the train split? In my understanding, the taskset contains all samples before the train/valid split, right?
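
If that understanding is right, something like the following check should hold (assuming `split_train_val` simply partitions the taskset into two parts):

```python
dataset_train, dataset_val = tasks.split_train_val(taskset, val_split=0.1)
assert len(dataset_train) + len(dataset_val) == len(taskset)
```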