diff --git a/examples/example_runner.ipynb b/examples/example_runner.ipynb
index bdb2127bcf..52cc74c5dc 100644
--- a/examples/example_runner.ipynb
+++ b/examples/example_runner.ipynb
@@ -71,10 +71,10 @@
 "evalue": "name 'utils' is not defined",
 "output_type": "error",
 "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
- "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mconfig\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_config_from_args\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconfig_type\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'nas'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mlogger\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msetup_logger\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m\"/log.log\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msetLevel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlogging\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mINFO\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;31mNameError\u001b[0m: name 'utils' is not defined"
+ "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
+ "\u001B[0;31mNameError\u001B[0m Traceback (most recent call last)",
+ "\u001B[0;32m\u001B[0m in \u001B[0;36m\u001B[0;34m\u001B[0m\n\u001B[0;32m----> 1\u001B[0;31m \u001B[0mconfig\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mutils\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mget_config_from_args\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mconfig_type\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;34m'nas'\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 2\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 3\u001B[0m \u001B[0mlogger\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0msetup_logger\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mconfig\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0msave\u001B[0m \u001B[0;34m+\u001B[0m \u001B[0;34m\"/log.log\"\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 4\u001B[0m \u001B[0mlogger\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0msetLevel\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mlogging\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mINFO\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 5\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n",
+ "\u001B[0;31mNameError\u001B[0m: name 'utils' is not defined"
 ]
 }
],
diff --git a/examples/naslib_tutorial.ipynb b/examples/naslib_tutorial.ipynb
index 7d7ac59ae0..74cd1af4b1 100644
--- a/examples/naslib_tutorial.ipynb
+++ b/examples/naslib_tutorial.ipynb
@@ -159,11 +159,11 @@
 "evalue": "No module named 'naslib.search_spaces.simple_cell'",
 "output_type": "error",
 "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
- "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mnaslib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msearch_spaces\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mNasBench201SearchSpace\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mNB201\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;31m# instantiate the search space object\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0msearch_space\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mNB201\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/anaconda3/envs/naslib/lib/python3.7/site-packages/naslib/search_spaces/__init__.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0msimple_cell\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgraph\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mSimpleCellSearchSpace\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mdarts\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgraph\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mDartsSearchSpace\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mnasbench101\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgraph\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mNasBench101SearchSpace\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mnasbench201\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgraph\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mNasBench201SearchSpace\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mhierarchical\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgraph\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mHierarchicalSearchSpace\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'naslib.search_spaces.simple_cell'"
+ "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
+ "\u001B[0;31mModuleNotFoundError\u001B[0m Traceback (most recent call last)",
+ "\u001B[0;32m\u001B[0m in \u001B[0;36m\u001B[0;34m\u001B[0m\n\u001B[0;32m----> 1\u001B[0;31m \u001B[0;32mfrom\u001B[0m \u001B[0mnaslib\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0msearch_spaces\u001B[0m \u001B[0;32mimport\u001B[0m \u001B[0mNasBench201SearchSpace\u001B[0m \u001B[0;32mas\u001B[0m \u001B[0mNB201\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 2\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 3\u001B[0m \u001B[0;31m# instantiate the search space object\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 4\u001B[0m \u001B[0msearch_space\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mNB201\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
+ "\u001B[0;32m~/anaconda3/envs/naslib/lib/python3.7/site-packages/naslib/search_spaces/__init__.py\u001B[0m in \u001B[0;36m\u001B[0;34m\u001B[0m\n\u001B[0;32m----> 1\u001B[0;31m \u001B[0;32mfrom\u001B[0m \u001B[0;34m.\u001B[0m\u001B[0msimple_cell\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mgraph\u001B[0m \u001B[0;32mimport\u001B[0m \u001B[0mSimpleCellSearchSpace\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 2\u001B[0m \u001B[0;32mfrom\u001B[0m \u001B[0;34m.\u001B[0m\u001B[0mdarts\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mgraph\u001B[0m \u001B[0;32mimport\u001B[0m \u001B[0mDartsSearchSpace\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 3\u001B[0m \u001B[0;32mfrom\u001B[0m \u001B[0;34m.\u001B[0m\u001B[0mnasbench101\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mgraph\u001B[0m \u001B[0;32mimport\u001B[0m \u001B[0mNasBench101SearchSpace\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 4\u001B[0m \u001B[0;32mfrom\u001B[0m \u001B[0;34m.\u001B[0m\u001B[0mnasbench201\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mgraph\u001B[0m \u001B[0;32mimport\u001B[0m \u001B[0mNasBench201SearchSpace\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 5\u001B[0m \u001B[0;32mfrom\u001B[0m \u001B[0;34m.\u001B[0m\u001B[0mhierarchical\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mgraph\u001B[0m \u001B[0;32mimport\u001B[0m \u001B[0mHierarchicalSearchSpace\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
+ "\u001B[0;31mModuleNotFoundError\u001B[0m: No module named 'naslib.search_spaces.simple_cell'"
 ]
 }
],
diff --git a/naslib/defaults/darts_defaults.yaml b/naslib/defaults/darts_defaults.yaml
index cc1af7a71a..c5af7ae08d 100644
--- a/naslib/defaults/darts_defaults.yaml
+++ b/naslib/defaults/darts_defaults.yaml
@@ -26,7 +26,7 @@ search:
   arch_learning_rate: 0.0003
   arch_weight_decay: 0.001
   output_weights: True
-  
+
   fidelity: 200

   # GDAS
@@ -36,7 +36,7 @@ search:
   # RE
   sample_size: 10
   population_size: 100
-  
+
   #LS
   num_init: 10

@@ -56,7 +56,7 @@ search:

   #train_portion: 0.9
   #data_size: 25000
-  
+
   # BANANAS
   k: 10
   num_ensemble: 3
@@ -66,7 +66,7 @@ search:
   num_arches_to_mutate: 2
   max_mutations: 1
   num_candidates: 100
-  
+
   # BasePredictor
   predictor_type: var_sparse_gp
   debug_predictor: False
@@ -88,4 +88,4 @@ evaluation:
   cutout_length: 16
   cutout_prob: 1.0
   drop_path_prob: 0.2
-  auxiliary_weight: 0.4
+  auxiliary_weight: 0.4
\ No newline at end of file
diff --git a/naslib/defaults/drnas_defaults.yaml b/naslib/defaults/drnas_defaults.yaml
new file mode 100644
index 0000000000..d7a045c1a6
--- /dev/null
+++ b/naslib/defaults/drnas_defaults.yaml
@@ -0,0 +1,125 @@
+# options: cifar10, cifar100, ImageNet16-120; reported test accuracies are available for all three
+dataset: ImageNet16-120
+# in the code base the default value for the seed is 2.
+# random seeds are logged during the runs, but the log files are not provided
+# the paper does not state which random seeds were used
+seed: 99
+# darts (or nb301)
+# nb201
+search_space: nasbench301
+out_dir: run
+optimizer: drnas
+
+search:
+  checkpoint_freq: 5
+  # the default batch size in the code is 64
+  batch_size: 64
+  # learning rate for both progressive and original: 0.025
+  learning_rate: 0.025
+  learning_rate_min: 0.001
+  momentum: 0.9
+  # weight_decay for both progressive and original: 0.0003
+  weight_decay: 0.0003
+  # for cifar10 the learning process is 2 stages of 25 epochs each
+  # the code sets the default number of training epochs to 100 for nb201
+  epochs: 100
+  warm_start_epochs: 0
+  grad_clip: 5
+  # for cifar10 the train and optimization data (50k) is equally partitioned
+  train_portion: 0.5
+  data_size: 25000
+
+  # these four values are the same for ordinary and progressive mode on nb201
+  cutout: False
+  cutout_length: 16
+  cutout_prob: 1.0
+  drop_path_prob: 0.0
+
+  # for nb201 this value is False
+  unrolled: False
+  arch_learning_rate: 0.0003
+  # not mentioned for progressive mode, but for ordinary mode it is 1e-3 on nb201
+  arch_weight_decay: 0.001
+  output_weights: True
+
+  fidelity: 200
+
+  # GDAS
+  tau_max: 10
+  tau_min: 0.1
+
+  # RE
+  sample_size: 10
+  population_size: 100
+
+  #LS
+  num_init: 10
+
+  #GSparsity-> Uncomment the lines below for GSparsity
+  #seed: 50
+  #grad_clip: 0
+  #threshold: 0.000001
+  #weight_decay: 120
+  #learning_rate: 0.01
+  #momentum: 0.8
+  #normalization: div
+  #normalization_exponent: 0.5
+  #batch_size: 256
+  #learning_rate_min: 0.0001
+  #epochs: 100
+  #warm_start_epochs: 0
+  #train_portion: 0.9
+  #data_size: 25000
+
+
+  # BANANAS
+  k: 10
+  num_ensemble: 3
+  acq_fn_type: its
+  acq_fn_optimization: mutation
+  encoding_type: path
+  num_arches_to_mutate: 2
+  max_mutations: 1
+  num_candidates: 100
+
+  # BasePredictor
+  predictor_type: var_sparse_gp
+  debug_predictor: False
+
+evaluation:
+  checkpoint_freq: 30
+  # neither the paper nor the code base states the batch size; the default value is 64
+  batch_size: 64
+
+  learning_rate: 0.025
+  learning_rate_min: 0.00
+  # momentum is 0.9
+  momentum: 0.9
+  # for cifar the weight_decay is 3e-4
+  weight_decay: 0.0003
+  # evaluation on cifar runs for 600 epochs; on imagenet it is 250
+  epochs: 250
+  # imagenet uses 5 warm-start epochs
+  warm_start_epochs: 5
+  grad_clip: 5
+  # uses the whole cifar10 training set (50K) to train from scratch for 600 epochs
+  train_portion: 1.
+  data_size: 50000
+
+  # for cifar10, cutout is used for fair comparison with previous work
+  cutout: True
+  # cifar10 cutout length is 16
+  cutout_length: 16
+  cutout_prob: 1.0
+  # for cifar the paper uses a drop path prob of 0.3
+  drop_path_prob: 0.2
+  # for cifar the auxiliary weight is 0.4
+  auxiliary_weight: 0.4
+
+
+
+# there is a partial-channel variable whose default is 1 in ordinary mode and 4 in progressive mode
+# the nb201 code also mentions regularization scales for l2 and kl (used for the dirichlet prior)
diff --git a/naslib/defaults/trainer.py b/naslib/defaults/trainer.py
index ccffda1c7f..65a9587141 100644
--- a/naslib/defaults/trainer.py
+++ b/naslib/defaults/trainer.py
@@ -88,7 +88,6 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int
         np.random.seed(self.config.search.seed)
         torch.manual_seed(self.config.search.seed)

-        self.optimizer.before_training()
         checkpoint_freq = self.config.search.checkpoint_freq
         if self.optimizer.using_step_function:
             self.scheduler = self.build_search_scheduler(
@@ -101,6 +100,8 @@
         else:
             start_epoch = self._setup_checkpointers(resume_from, period=checkpoint_freq)

+        self.optimizer.before_training()
+
         if self.optimizer.using_step_function:
             self.train_queue, self.valid_queue, _ = self.build_search_dataloaders(
                 self.config
@@ -146,7 +147,7 @@
                     self.train_loss.update(float(train_loss.detach().cpu()))
                     self.val_loss.update(float(val_loss.detach().cpu()))

-                    
+
                     self.scheduler.step()

                     end_time = time.time()
@@ -179,7 +180,9 @@
                 self.train_top1.avg = train_acc
                 self.val_top1.avg = valid_acc

-            self.periodic_checkpointer.step(e)
+            add_checkpointables = self.optimizer.get_checkpointables()
+            del add_checkpointables["model"]
+            self.periodic_checkpointer.step(e, **add_checkpointables)

             anytime_results = self.optimizer.test_statistics()
             # if anytime_results:
@@ -216,8 +219,8 @@ def evaluate_oneshot(self, resume_from="", dataloader=None):
         evaluate with the current one-shot weights.
         """
         logger.info("Start one-shot evaluation")
-        self.optimizer.before_training()
         self._setup_checkpointers(resume_from)
+        self.optimizer.before_training()

         loss = torch.nn.CrossEntropyLoss()

@@ -286,7 +289,7 @@ def evaluate(
             best_arch = self.optimizer.get_final_architecture()
         logger.info(f"Final architecture hash: {best_arch.get_hash()}")

-        if best_arch.QUERYABLE:
+        if best_arch.QUERYABLE and (not retrain):
             if metric is None:
                 metric = Metric.TEST_ACCURACY
             result = best_arch.query(
@@ -408,8 +411,10 @@ def evaluate(
                                 logits_valid, target_valid, "val"
                             )

+            arch_weights = self.optimizer.get_checkpointables()["arch_weights"]
+
             scheduler.step()
-            self.periodic_checkpointer.step(e)
+            self.periodic_checkpointer.step(iteration=e, arch_weights=arch_weights)
             self._log_and_reset_accuracies(e)

         # Disable drop path
@@ -585,8 +590,11 @@ def _setup_checkpointers(

         if resume_from:
             logger.info("loading model from file {}".format(resume_from))
-            checkpoint = checkpointer.resume_or_load(resume_from, resume=True)
+            # resume=True resumes from the last checkpoint
+            # resume=False loads from the path given as resume_from
+            checkpoint = checkpointer.resume_or_load(resume_from, resume=False)
             if checkpointer.has_checkpoint():
+                self.optimizer.set_checkpointables(checkpoint)
                 return checkpoint.get("iteration", -1) + 1

         return 0
diff --git a/naslib/evaluators/zc_evaluator.py b/naslib/evaluators/zc_evaluator.py
index 677bba4551..ba5cb2674a 100644
--- a/naslib/evaluators/zc_evaluator.py
+++ b/naslib/evaluators/zc_evaluator.py
@@ -138,6 +138,7 @@ def single_evaluate(self, test_data, zc_api):
         logger.info("Querying the predictor")
         query_time_start = time.time()

+        # TODO: shouldn't mode="val" be passed?
         _, _, test_loader, _, _ = utils.get_train_val_loaders(self.config)

         # Iterate over the architectures, instantiate a graph with each architecture
diff --git a/naslib/optimizers/core/metaclasses.py b/naslib/optimizers/core/metaclasses.py
index b97ce83e82..d1abeb7456 100644
--- a/naslib/optimizers/core/metaclasses.py
+++ b/naslib/optimizers/core/metaclasses.py
@@ -121,3 +121,13 @@ def get_checkpointables(self):
             (dict): with name as key and object as value. e.g. graph, arch weights, optimizers, ...
         """
         pass
+
+    def set_checkpointables(self, checkpointables):
+        """
+        Set the values of the objects saved in the checkpoint during training/evaluation.
+
+        Args:
+            (dict): with name as key and object as value, e.g. op and arch optimizers, arch weights, ...
+        """
+
+        pass
diff --git a/naslib/optimizers/discrete/bananas/optimizer.py b/naslib/optimizers/discrete/bananas/optimizer.py
index d9abebe58a..5fbaaa59d0 100644
--- a/naslib/optimizers/discrete/bananas/optimizer.py
+++ b/naslib/optimizers/discrete/bananas/optimizer.py
@@ -318,3 +318,4 @@ def get_arch_as_string(self, arch):
         else:
             str_arch = str(arch)
         return str_arch
+
diff --git a/naslib/optimizers/oneshot/configurable/optimizer.py b/naslib/optimizers/oneshot/configurable/optimizer.py
index 23956502cf..025e333ec6 100644
--- a/naslib/optimizers/oneshot/configurable/optimizer.py
+++ b/naslib/optimizers/oneshot/configurable/optimizer.py
@@ -202,6 +202,11 @@ def get_checkpointables(self):
             "arch_weights": self.architectural_weights,
         }

+    def set_checkpointables(self, checkpointables):
+        self.op_optimizer = checkpointables.get("op_optimizer")
+        self.arch_optimizer = checkpointables.get("arch_optimizer")
+        self.architectural_weights = checkpointables.get("arch_weights")
+
     def before_training(self):
         """
         Move the graph into cuda memory if available.
diff --git a/naslib/optimizers/oneshot/darts/optimizer.py b/naslib/optimizers/oneshot/darts/optimizer.py
index ab4702a159..bc0b275bd9 100644
--- a/naslib/optimizers/oneshot/darts/optimizer.py
+++ b/naslib/optimizers/oneshot/darts/optimizer.py
@@ -200,6 +200,11 @@ def get_op_optimizer(self):
     def get_model_size(self):
         return count_parameters_in_MB(self.graph)

+    def set_checkpointables(self, checkpointables):
+        self.op_optimizer = checkpointables.get("op_optimizer")
+        self.arch_optimizer = checkpointables.get("arch_optimizer")
+        self.architectural_weights = checkpointables.get("arch_weights")
+
     def test_statistics(self):
         # nb301 is not there but we use it anyways to generate the arch strings.
         # if self.graph.QUERYABLE:
diff --git a/naslib/optimizers/oneshot/drnas/optimizer.py b/naslib/optimizers/oneshot/drnas/optimizer.py
index 787df25c8b..82e49a2905 100644
--- a/naslib/optimizers/oneshot/drnas/optimizer.py
+++ b/naslib/optimizers/oneshot/drnas/optimizer.py
@@ -65,9 +65,11 @@ def __init__(
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

     def new_epoch(self, epoch):
+        # TODO: use this hook for the progressive-learning steps that should
+        # run at the beginning of each epoch
         super().new_epoch(epoch)

-    def adapt_search_space(self, search_space, scope=None):
+    def adapt_search_space(self, search_space, scope=None, **kwargs):
         """
         Same as in darts with a different mixop. If you want to checkpoint the
         dirichlet 'concentration' parameter (beta) add it to the buffer here.
@@ -94,7 +96,7 @@ def step(self, data_train, data_val):
         self.arch_optimizer.zero_grad()
         logits_val = self.graph(input_val)
         val_loss = self.loss(logits_val, target_val)
-
+        # TODO: this is the additional loss term from Eq. 2 in the paper
         if self.reg_type == "kl":
             val_loss += self._get_kl_reg()

@@ -182,6 +184,7 @@ def process_weights(self, weights):
         return weights

     def apply_weights(self, x, weights):
+        # TODO: change this based on the progressive formulation
         weighted_sum = sum(
             w * op(x, None)
             for w, op in zip(weights, self.primitives)
diff --git a/naslib/optimizers/oneshot/gsparsity/optimizer.py b/naslib/optimizers/oneshot/gsparsity/optimizer.py
index 814621a9e9..61f202f260 100644
--- a/naslib/optimizers/oneshot/gsparsity/optimizer.py
+++ b/naslib/optimizers/oneshot/gsparsity/optimizer.py
@@ -419,6 +419,10 @@ def get_op_optimizer(self):
     def get_model_size(self):
         return count_parameters_in_MB(self.graph)

+    def set_checkpointables(self, checkpointables):
+        self.op_optimizer = checkpointables.get("op_optimizer")
+        self.op_optimizer_evaluate = checkpointables.get("op_optimizer_evaluate")
+
     def get_checkpointables(self):
         """
         Return all objects that should be saved in a checkpoint during training.
diff --git a/naslib/runners/nas/runner.py b/naslib/runners/nas/runner.py
index 2c0b96aa7b..75264c4f01 100644
--- a/naslib/runners/nas/runner.py
+++ b/naslib/runners/nas/runner.py
@@ -45,9 +45,9 @@
 }

 supported_search_spaces = {
-    'nasbench101': NasBench101SearchSpace(),
-    'nasbench201': NasBench201SearchSpace(),
-    'nasbench301': NasBench301SearchSpace(),
+    'nasbench101': NasBench101SearchSpace(n_classes=config.n_classes),
+    'nasbench201': NasBench201SearchSpace(n_classes=config.n_classes),
+    'nasbench301': NasBench301SearchSpace(n_classes=config.n_classes, auxiliary=False),
     'nlp': NasBenchNLPSearchSpace(),
     'transbench101_micro': TransBench101SearchSpaceMicro(config.dataset),
     'transbench101_macro': TransBench101SearchSpaceMacro(),
@@ -61,16 +61,15 @@
 optimizer = supported_optimizers[config.optimizer]
 optimizer.adapt_search_space(search_space, dataset_api=dataset_api)
-    
+
 import torch
 if config.dataset in ['class_object', 'class_scene']:
     optimizer.loss = SoftmaxCrossEntropyWithLogits()
 elif config.dataset == 'autoencoder':
     optimizer.loss = torch.nn.L1Loss()

-
 trainer = Trainer(optimizer, config, lightweight_output=True)
 trainer.search(resume_from="")
-trainer.evaluate(resume_from="", dataset_api=dataset_api)
+trainer.evaluate(resume_from="", dataset_api=dataset_api)
\ No newline at end of file
diff --git a/naslib/search_spaces/nasbench301/graph.py b/naslib/search_spaces/nasbench301/graph.py
index 212fdaf91d..647966c45d 100644
--- a/naslib/search_spaces/nasbench301/graph.py
+++ b/naslib/search_spaces/nasbench301/graph.py
@@ -402,7 +402,7 @@ def query(
             genotype = convert_naslib_to_genotype(self)
         else:
             genotype = convert_compact_to_genotype(self.compact)
-        if metric == Metric.VAL_ACCURACY:
+        if metric == Metric.VAL_ACCURACY or metric == Metric.TEST_ACCURACY:
             val_acc = dataset_api["nb301_model"][0].predict(
                 config=genotype, representation="genotype"
             )
diff --git a/naslib/utils/utils.py b/naslib/utils/utils.py
index 0ff4d259bb..9c97122c8a 100644
--- a/naslib/utils/utils.py
+++ b/naslib/utils/utils.py
@@ -9,6 +9,7 @@
 from scipy import stats
 import copy
 import json
+import warnings

 from collections import OrderedDict

@@ -173,7 +174,6 @@ def get_config_from_args(args=None, config_type="nas"):
     if args is None:
         args = parse_args()
     logger.info("Command line args: {}".format(args))
-
     if args.config_file is None:
         config = load_default_config(config_type=config_type)
     else:
@@ -196,6 +196,16 @@
             config.set_new_allowed(True)
             config.merge_from_list(args.opts)

+            if config.dataset == 'cifar10':
+                config.n_classes = 10
+            elif config.dataset == 'cifar100':
+                config.n_classes = 100
+            elif config.dataset == 'ImageNet16-120':
+                config.n_classes = 120
+            else:
+                warnings.warn("Number of classes was not set. Defaulting to 10.")
+                config.n_classes = 10
+
         except AttributeError:
             for arg, value in pairwise(args):
                 config[arg] = value
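
A minimal sketch of the checkpointable round-trip this patch introduces, using illustrative names (ToyOptimizer is not a NASLib class): the trainer collects an optimizer's stateful objects via get_checkpointables(), forwards them to the periodic checkpointer, and the new set_checkpointables() hook restores them when a run resumes.

import torch

class ToyOptimizer:
    def __init__(self):
        self.arch_weights = torch.nn.Parameter(1e-3 * torch.randn(3, 5))
        self.arch_optimizer = torch.optim.Adam([self.arch_weights], lr=3e-4)

    def get_checkpointables(self):
        # Objects the trainer forwards to periodic_checkpointer.step(...)
        return {
            "arch_weights": self.arch_weights,
            "arch_optimizer": self.arch_optimizer,
        }

    def set_checkpointables(self, checkpointables):
        # Mirrors the set_checkpointables() hooks added in this patch:
        # restore the saved objects when resuming from a checkpoint.
        self.arch_weights = checkpointables.get("arch_weights")
        self.arch_optimizer = checkpointables.get("arch_optimizer")

saved = ToyOptimizer().get_checkpointables()   # what gets checkpointed
resumed = ToyOptimizer()
resumed.set_checkpointables(saved)             # what happens on resume
assert resumed.arch_weights is saved["arch_weights"]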