Loading model #168

Open · wants to merge 13 commits into Develop_copy
8 changes: 4 additions & 4 deletions examples/example_runner.ipynb
@@ -71,10 +71,10 @@
"evalue": "name 'utils' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-4-11fe646b1b18>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mconfig\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_config_from_args\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconfig_type\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'nas'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mlogger\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msetup_logger\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m\"/log.log\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msetLevel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlogging\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mINFO\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mNameError\u001b[0m: name 'utils' is not defined"
"\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
"\u001B[0;31mNameError\u001B[0m Traceback (most recent call last)",
"\u001B[0;32m<ipython-input-4-11fe646b1b18>\u001B[0m in \u001B[0;36m<module>\u001B[0;34m\u001B[0m\n\u001B[0;32m----> 1\u001B[0;31m \u001B[0mconfig\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mutils\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mget_config_from_args\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mconfig_type\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;34m'nas'\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 2\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 3\u001B[0m \u001B[0mlogger\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0msetup_logger\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mconfig\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0msave\u001B[0m \u001B[0;34m+\u001B[0m \u001B[0;34m\"/log.log\"\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 4\u001B[0m \u001B[0mlogger\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0msetLevel\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mlogging\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mINFO\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 5\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n",
"\u001B[0;31mNameError\u001B[0m: name 'utils' is not defined"
]
}
],
10 changes: 5 additions & 5 deletions examples/naslib_tutorial.ipynb
@@ -159,11 +159,11 @@
"evalue": "No module named 'naslib.search_spaces.simple_cell'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-3-76203c895428>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mnaslib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msearch_spaces\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mNasBench201SearchSpace\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mNB201\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;31m# instantiate the search space object\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0msearch_space\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mNB201\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/naslib/lib/python3.7/site-packages/naslib/search_spaces/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0msimple_cell\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgraph\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mSimpleCellSearchSpace\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mdarts\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgraph\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mDartsSearchSpace\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mnasbench101\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgraph\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mNasBench101SearchSpace\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mnasbench201\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgraph\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mNasBench201SearchSpace\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mhierarchical\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgraph\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mHierarchicalSearchSpace\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'naslib.search_spaces.simple_cell'"
"\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
"\u001B[0;31mModuleNotFoundError\u001B[0m Traceback (most recent call last)",
"\u001B[0;32m<ipython-input-3-76203c895428>\u001B[0m in \u001B[0;36m<module>\u001B[0;34m\u001B[0m\n\u001B[0;32m----> 1\u001B[0;31m \u001B[0;32mfrom\u001B[0m \u001B[0mnaslib\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0msearch_spaces\u001B[0m \u001B[0;32mimport\u001B[0m \u001B[0mNasBench201SearchSpace\u001B[0m \u001B[0;32mas\u001B[0m \u001B[0mNB201\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 2\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 3\u001B[0m \u001B[0;31m# instantiate the search space object\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 4\u001B[0m \u001B[0msearch_space\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mNB201\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
"\u001B[0;32m~/anaconda3/envs/naslib/lib/python3.7/site-packages/naslib/search_spaces/__init__.py\u001B[0m in \u001B[0;36m<module>\u001B[0;34m\u001B[0m\n\u001B[0;32m----> 1\u001B[0;31m \u001B[0;32mfrom\u001B[0m \u001B[0;34m.\u001B[0m\u001B[0msimple_cell\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mgraph\u001B[0m \u001B[0;32mimport\u001B[0m \u001B[0mSimpleCellSearchSpace\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 2\u001B[0m \u001B[0;32mfrom\u001B[0m \u001B[0;34m.\u001B[0m\u001B[0mdarts\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mgraph\u001B[0m \u001B[0;32mimport\u001B[0m \u001B[0mDartsSearchSpace\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 3\u001B[0m \u001B[0;32mfrom\u001B[0m \u001B[0;34m.\u001B[0m\u001B[0mnasbench101\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mgraph\u001B[0m \u001B[0;32mimport\u001B[0m \u001B[0mNasBench101SearchSpace\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 4\u001B[0m \u001B[0;32mfrom\u001B[0m \u001B[0;34m.\u001B[0m\u001B[0mnasbench201\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mgraph\u001B[0m \u001B[0;32mimport\u001B[0m \u001B[0mNasBench201SearchSpace\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 5\u001B[0m \u001B[0;32mfrom\u001B[0m \u001B[0;34m.\u001B[0m\u001B[0mhierarchical\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mgraph\u001B[0m \u001B[0;32mimport\u001B[0m \u001B[0mHierarchicalSearchSpace\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
"\u001B[0;31mModuleNotFoundError\u001B[0m: No module named 'naslib.search_spaces.simple_cell'"
]
}
],
10 changes: 5 additions & 5 deletions naslib/defaults/darts_defaults.yaml
@@ -26,7 +26,7 @@ search:
arch_learning_rate: 0.0003
arch_weight_decay: 0.001
output_weights: True

fidelity: 200

# GDAS
@@ -36,7 +36,7 @@ search:
# RE
sample_size: 10
population_size: 100

#LS
num_init: 10

@@ -56,7 +56,7 @@ search:
#train_portion: 0.9
#data_size: 25000


# BANANAS
k: 10
num_ensemble: 3
@@ -66,7 +66,7 @@ search:
num_arches_to_mutate: 2
max_mutations: 1
num_candidates: 100

# BasePredictor
predictor_type: var_sparse_gp
debug_predictor: False
@@ -88,4 +88,4 @@ evaluation:
cutout_length: 16
cutout_prob: 1.0
drop_path_prob: 0.2
auxiliary_weight: 0.4
auxiliary_weight: 0.4
125 changes: 125 additions & 0 deletions naslib/defaults/drnas_defaults.yaml
@@ -0,0 +1,125 @@
# options: cifar10, cifar100, ImageNet16-120; reported test accuracies are available for these
dataset: ImageNet16-120
# in the code base the default value for the seed is 2.
# random seeds are used and logged, but the log files are not provided
# the paper does not mention which random seeds were used
seed: 99
# options: darts (or nb301), nb201
search_space: nasbench301
out_dir: run
optimizer: drnas

search:
checkpoint_freq: 5
# the default batch size in the code is 64
batch_size: 64
# learning rate for progressive and original: 0.025
learning_rate: 0.025
learning_rate_min: 0.001
momentum: 0.9
# weight_decay for progressive and original: 0.0003
weight_decay: 0.0003
# for cifar10 the search runs in 2 stages of 25 epochs each
# the code states that the default number of training epochs for nb201 is 100
epochs: 100
warm_start_epochs: 0
grad_clip: 5
# for cifar10 the train and optimization data (50k) are equally partitioned
train_portion: 0.5
data_size: 25000

# for these four args the values are the same for ordinary and progressive mode on nb201
cutout: False
cutout_length: 16
cutout_prob: 1.0
drop_path_prob: 0.0

# for nb201 this value is false
unrolled: False
arch_learning_rate: 0.0003
# not mentioned for progressive mode, but for ordinary mode it is 1e-3 on nb201
arch_weight_decay: 0.001
output_weights: True

fidelity: 200

# GDAS
Contributor

To make the YAML files generally more readable, should we focus only on the optimizer-specific settings, @Neonkraft?

Collaborator

Agreed.

Collaborator Author

The darts_defaults.yaml was reverted to the format of the Develop_copy branch.

tau_max: 10
tau_min: 0.1

# RE
sample_size: 10
population_size: 100

#LS
num_init: 10

#GSparsity-> Uncomment the lines below for GSparsity
#seed: 50
#grad_clip: 0
#threshold: 0.000001
#weight_decay: 120
#learning_rate: 0.01
#momentum: 0.8
#normalization: div
#normalization_exponent: 0.5
#batch_size: 256
#learning_rate_min: 0.0001
#epochs: 100
#warm_start_epochs: 0
#train_portion: 0.9
#data_size: 25000


# BANANAS
k: 10
num_ensemble: 3
acq_fn_type: its
acq_fn_optimization: mutation
encoding_type: path
num_arches_to_mutate: 2
max_mutations: 1
num_candidates: 100

# BasePredictor
predictor_type: var_sparse_gp
debug_predictor: False

evaluation:
checkpoint_freq: 30
# neither the paper nor the code base specifies the batch size, but the default value is 64
batch_size: 64

learning_rate: 0.025
learning_rate_min: 0.00
# momentum is 0.9
momentum: 0.9
# for cifar weight_decay is 3e-4
weight_decay: 0.0003
# cifar evaluation runs for 600 epochs; for imagenet it is 250
epochs: 250
# for imagenet there are 5 warm-start epochs
warm_start_epochs: 5
grad_clip: 5
# uses the whole training data of cifar10 (50K) to train from scratch for 600 epochs
train_portion: 1.
data_size: 50000

# for cifar10, cutout is used to allow fair comparison with previous work
cutout: True
# cifar10 cutout length is 16
cutout_length: 16
cutout_prob: 1.0
# for cifar, drop_path_prob is 0.3
drop_path_prob: 0.2
# cifar auxiliary is 0.4
auxiliary_weight: 0.4



# there is a partial-channel variable whose default is 1 in ordinary mode and 4 in progressive mode.
# the nb201 code also mentions a regularization scale for l2 and kl (used for the dirichlet distribution)
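
For context on the partial-channel note above: in PC-DARTS-style partial connection, which DrNAS's progressive mode builds on, only 1/k of the channels are sent through the candidate operations while the rest bypass them. A minimal sketch, with all names illustrative rather than NASLib API:

```python
import torch

def partial_channel_mixed_op(x, weights, ops, k=4):
    # Route only the first 1/k of the channels through the weighted mix of
    # candidate operations; the remaining channels bypass the mixed op.
    c = x.size(1) // k
    x_active, x_bypass = x[:, :c], x[:, c:]
    mixed = sum(w * op(x_active) for w, op in zip(weights, ops))
    return torch.cat([mixed, x_bypass], dim=1)
```

PC-DARTS additionally shuffles the channels after concatenation so different channels are sampled each step; that detail is omitted here.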
22 changes: 15 additions & 7 deletions naslib/defaults/trainer.py
@@ -88,7 +88,6 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int
np.random.seed(self.config.search.seed)
torch.manual_seed(self.config.search.seed)

self.optimizer.before_training()
checkpoint_freq = self.config.search.checkpoint_freq
if self.optimizer.using_step_function:
self.scheduler = self.build_search_scheduler(
@@ -101,6 +100,8 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int
else:
start_epoch = self._setup_checkpointers(resume_from, period=checkpoint_freq)

self.optimizer.before_training()

if self.optimizer.using_step_function:
self.train_queue, self.valid_queue, _ = self.build_search_dataloaders(
self.config
@@ -146,7 +147,7 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int

self.train_loss.update(float(train_loss.detach().cpu()))
self.val_loss.update(float(val_loss.detach().cpu()))

self.scheduler.step()

end_time = time.time()
@@ -179,7 +180,9 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int
self.train_top1.avg = train_acc
self.val_top1.avg = valid_acc

self.periodic_checkpointer.step(e)
add_checkpointables = self.optimizer.get_checkpointables()
del add_checkpointables["model"]
self.periodic_checkpointer.step(e, **add_checkpointables)

anytime_results = self.optimizer.test_statistics()
# if anytime_results:
@@ -216,8 +219,8 @@ def evaluate_oneshot(self, resume_from="", dataloader=None):
evaluate with the current one-shot weights.
"""
logger.info("Start one-shot evaluation")
self.optimizer.before_training()
self._setup_checkpointers(resume_from)
self.optimizer.before_training()

loss = torch.nn.CrossEntropyLoss()

@@ -286,7 +289,7 @@ def evaluate(
best_arch = self.optimizer.get_final_architecture()
logger.info(f"Final architecture hash: {best_arch.get_hash()}")

if best_arch.QUERYABLE:
if best_arch.QUERYABLE and (not retrain):
if metric is None:
metric = Metric.TEST_ACCURACY
result = best_arch.query(
@@ -408,8 +411,10 @@ def evaluate(
logits_valid, target_valid, "val"
)

arch_weights = self.optimizer.get_checkpointables()["arch_weights"]

scheduler.step()
self.periodic_checkpointer.step(e)
self.periodic_checkpointer.step(iteration=e, arch_weights=arch_weights)
self._log_and_reset_accuracies(e)

# Disable drop path
@@ -585,8 +590,11 @@ def _setup_checkpointers(

if resume_from:
logger.info("loading model from file {}".format(resume_from))
checkpoint = checkpointer.resume_or_load(resume_from, resume=True)
# if resume=True, start from the last checkpoint
# if resume=False, start from the path given as resume_from
checkpoint = checkpointer.resume_or_load(resume_from, resume=False)
if checkpointer.has_checkpoint():
self.optimizer.set_checkpointables(checkpoint)
return checkpoint.get("iteration", -1) + 1
return 0
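
Taken together, the trainer changes in this file set up the following round trip; a condensed sketch, where the names come from the diff above and the fvcore-style Checkpointer API is an assumption:

```python
# While searching: persist the optimizer's extra state next to the model.
add_checkpointables = self.optimizer.get_checkpointables()
del add_checkpointables["model"]          # the model itself is already registered
self.periodic_checkpointer.step(e, **add_checkpointables)

# When resuming: resume=False loads exactly the file passed as resume_from,
# while resume=True would prefer the last checkpoint recorded on disk.
checkpoint = checkpointer.resume_or_load(resume_from, resume=False)
if checkpointer.has_checkpoint():
    self.optimizer.set_checkpointables(checkpoint)  # hand restored objects back
```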

1 change: 1 addition & 0 deletions naslib/evaluators/zc_evaluator.py
@@ -138,6 +138,7 @@ def single_evaluate(self, test_data, zc_api):
logger.info("Querying the predictor")
query_time_start = time.time()

# TODO: shouldn't mode="val" be passed?
Contributor

Makes sense to me.

_, _, test_loader, _, _ = utils.get_train_val_loaders(self.config)

# Iterate over the architectures, instantiate a graph with each architecture
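
If the TODO above were acted on, the call would presumably become the following; whether get_train_val_loaders accepts a mode argument, and its accepted values, are assumptions not confirmed by this diff:

```python
# Hypothetical: pass mode="val" so the loaders are built for evaluation;
# the keyword and its semantics are assumed, not taken from the codebase.
_, _, test_loader, _, _ = utils.get_train_val_loaders(self.config, mode="val")
```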
10 changes: 10 additions & 0 deletions naslib/optimizers/core/metaclasses.py
@@ -121,3 +121,13 @@ def get_checkpointables(self):
(dict): with name as key and object as value. e.g. graph, arch weights, optimizers, ...
"""
pass

def set_checkpointables(self, checkpointables):
"""
Set the values of the objects saved in a checkpoint during training/evaluation.

Args:
(dict): with name as key and object as value. e.g. op and arch optimizers, arch weights, ...
"""

pass
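
A minimal sketch of the intended round trip with get_checkpointables; the key names follow the one-shot optimizer diffs below, and the surrounding code is illustrative only:

```python
# Illustrative only: save and later restore an optimizer's extra state.
state = optimizer.get_checkpointables()   # e.g. graph, arch weights, op/arch optimizers
# ... the checkpointer serializes `state`; training is interrupted and resumed ...
optimizer.set_checkpointables(state)      # restores "op_optimizer", "arch_optimizer", "arch_weights"
```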
1 change: 1 addition & 0 deletions naslib/optimizers/discrete/bananas/optimizer.py
@@ -318,3 +318,4 @@ def get_arch_as_string(self, arch):
else:
str_arch = str(arch)
return str_arch

5 changes: 5 additions & 0 deletions naslib/optimizers/oneshot/configurable/optimizer.py
@@ -202,6 +202,11 @@ def get_checkpointables(self):
"arch_weights": self.architectural_weights,
}

def set_checkpointables(self, checkpointables):
self.op_optimizer = checkpointables.get("op_optimizer")
self.arch_optimizer = checkpointables.get("arch_optimizer")
self.architectural_weights = checkpointables.get("arch_weights")

def before_training(self):
"""
Move the graph into cuda memory if available.
5 changes: 5 additions & 0 deletions naslib/optimizers/oneshot/darts/optimizer.py
@@ -200,6 +200,11 @@ def get_op_optimizer(self):
def get_model_size(self):
return count_parameters_in_MB(self.graph)

def set_checkpointables(self, checkpointables):
self.op_optimizer = checkpointables.get("op_optimizer")
self.arch_optimizer = checkpointables.get("arch_optimizer")
self.architectural_weights = checkpointables.get("arch_weights")

def test_statistics(self):
# nb301 is not there but we use it anyways to generate the arch strings.
# if self.graph.QUERYABLE:
7 changes: 5 additions & 2 deletions naslib/optimizers/oneshot/drnas/optimizer.py
@@ -65,9 +65,11 @@ def __init__(
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def new_epoch(self, epoch):
# TODO: use this for the things that should be done for progressive learning
# at the beginning of each epoch
super().new_epoch(epoch)

def adapt_search_space(self, search_space, scope=None):
def adapt_search_space(self, search_space, scope=None, **kwargs):
"""
Same as in darts with a different mixop.
If you want to checkpoint the dirichlet 'concentration' parameter (beta) add it to the buffer here.
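
A hypothetical sketch of what the docstring suggests; the buffer name, the ELU reparameterization, and `search_model` are assumptions:

```python
import torch
import torch.nn.functional as F

# Illustrative only: derive the concentration beta from the architecture
# parameters and register it as a buffer so the checkpointer will save it.
beta = F.elu(arch_parameters) + 1
search_model.register_buffer("dirichlet_beta", beta.detach())
```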
@@ -94,7 +96,7 @@ def step(self, data_train, data_val):
self.arch_optimizer.zero_grad()
logits_val = self.graph(input_val)
val_loss = self.loss(logits_val, target_val)

# TODO: this is the additional loss term from Eq. 2 in the paper
if self.reg_type == "kl":
val_loss += self._get_kl_reg()
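
For reference, a regularizer of this form can be sketched as follows; this illustrates the Eq. 2 term rather than reproducing _get_kl_reg, and the ELU reparameterization and reg_scale name are assumptions:

```python
import torch
import torch.nn.functional as F
from torch.distributions.dirichlet import Dirichlet
from torch.distributions.kl import kl_divergence

def kl_reg(arch_parameters, reg_scale=1e-3):
    # KL(Dir(beta) || Dir(1)): pulls the learned concentration towards a
    # flat Dirichlet prior, acting as the extra loss term in the objective.
    beta = F.elu(arch_parameters) + 1            # keep concentrations positive
    q = Dirichlet(beta)
    p = Dirichlet(torch.ones_like(beta))
    return reg_scale * kl_divergence(q, p).sum()
```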

@@ -182,6 +184,7 @@ def process_weights(self, weights):
return weights

def apply_weights(self, x, weights):
# TODO: change this based on the progressive formulation
weighted_sum = sum(
w * op(x, None)
for w, op in zip(weights, self.primitives)
4 changes: 4 additions & 0 deletions naslib/optimizers/oneshot/gsparsity/optimizer.py
@@ -419,6 +419,10 @@ def get_op_optimizer(self):
def get_model_size(self):
return count_parameters_in_MB(self.graph)

def set_checkpointables(self, checkpointables):
self.op_optimizer = checkpointables.get("op_optimizer")
self.op_optimizer_evaluate = checkpointables.get("op_optimizer_evaluate")

def get_checkpointables(self):
"""
Return all objects that should be saved in a checkpoint during training.