From d6fe57aa99897f4e6592c980c7293dbe1758831e Mon Sep 17 00:00:00 2001 From: John Robertson Date: Wed, 14 Dec 2022 11:43:18 +0100 Subject: [PATCH 1/6] modifed Lukas' code for alpha convergence plotting --- examples/plot_darts.py | 26 +++++++ naslib/defaults/trainer.py | 144 ++++++++++++++++++++++++------------- 2 files changed, 122 insertions(+), 48 deletions(-) create mode 100644 examples/plot_darts.py diff --git a/examples/plot_darts.py b/examples/plot_darts.py new file mode 100644 index 0000000000..f449993a76 --- /dev/null +++ b/examples/plot_darts.py @@ -0,0 +1,26 @@ +import os +import logging +from naslib.defaults.trainer import Trainer +from naslib.optimizers import DARTSOptimizer, GDASOptimizer, RandomSearch +from naslib.search_spaces import DartsSearchSpace, SimpleCellSearchSpace + +from naslib.utils import set_seed, setup_logger, get_config_from_args + +config = get_config_from_args() # use --help so see the options +config.search.batch_size = 128 +config.search.epochs = 1 +config.save_arch_weights = True +config.plot_arch_weights = True +config.save_arch_weights_path = f"{config.save}/save_arch" +set_seed(config.seed) + +logger = setup_logger(config.save + "/log.log") +logger.setLevel(logging.INFO) # default DEBUG is very verbose + +search_space = SimpleCellSearchSpace() # DartsSearchSpace() # use SimpleCellSearchSpace() for less heavy search + +optimizer = DARTSOptimizer(config) +optimizer.adapt_search_space(search_space) + +trainer = Trainer(optimizer, config) +trainer.search() \ No newline at end of file diff --git a/naslib/defaults/trainer.py b/naslib/defaults/trainer.py index ccffda1c7f..5db72f8ce7 100644 --- a/naslib/defaults/trainer.py +++ b/naslib/defaults/trainer.py @@ -8,6 +8,10 @@ import torch import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns + +from pathlib import Path from fvcore.common.checkpoint import PeriodicCheckpointer from naslib.search_spaces.core.query_metrics import Metric @@ -57,8 +61,8 @@ def __init__(self, optimizer, config, lightweight_output=False): self.val_loss = utils.AverageMeter() n_parameters = optimizer.get_model_size() - # logger.info("param size = %fMB", n_parameters) - self.search_trajectory = utils.AttrDict( + logger.info("param size = %fMB", n_parameters) + self.errors_dict = utils.AttrDict( { "train_acc": [], "train_loss": [], @@ -73,7 +77,8 @@ def __init__(self, optimizer, config, lightweight_output=False): } ) - def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int], None]=None, report_incumbent=True): + def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int], None] = None, + report_incumbent=True): """ Start the architecture search. @@ -83,7 +88,7 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int resume_from (str): Checkpoint file to resume from. If not given then train from scratch. 
""" - logger.info("Beginning search") + logger.info("Start training") np.random.seed(self.config.search.seed) torch.manual_seed(self.config.search.seed) @@ -108,11 +113,26 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int for e in range(start_epoch, self.epochs): + # create the arch directory (without overwriting) + if self.config.save_arch_weights: + Path(f"{self.config.save_arch_weights_path}/epoch_{e}").mkdir(parents=True, exist_ok=False) + start_time = time.time() self.optimizer.new_epoch(e) + arch_weights_lst = [] if self.optimizer.using_step_function: for step, data_train in enumerate(self.train_queue): + + # save arch weights to array of tensors + if self.config.save_arch_weights: + if len(arch_weights_lst) == 0: + for alpha_i in self.optimizer.architectural_weights: + arch_weights_lst.append(torch.unsqueeze(alpha_i.detach(), dim=0)) + else: + for idx, alpha_i in enumerate(self.optimizer.architectural_weights): + arch_weights_lst[idx] = torch.cat((arch_weights_lst[idx], torch.unsqueeze(alpha_i.detach(), dim=0)), dim=0) + data_train = ( data_train[0].to(self.device), data_train[1].to(self.device, non_blocking=True), @@ -151,11 +171,11 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int end_time = time.time() - self.search_trajectory.train_acc.append(self.train_top1.avg) - self.search_trajectory.train_loss.append(self.train_loss.avg) - self.search_trajectory.valid_acc.append(self.val_top1.avg) - self.search_trajectory.valid_loss.append(self.val_loss.avg) - self.search_trajectory.runtime.append(end_time - start_time) + self.errors_dict.train_acc.append(self.train_top1.avg) + self.errors_dict.train_loss.append(self.train_loss.avg) + self.errors_dict.valid_acc.append(self.val_top1.avg) + self.errors_dict.valid_loss.append(self.val_loss.avg) + self.errors_dict.runtime.append(end_time - start_time) else: end_time = time.time() # TODO: nasbench101 does not have train_loss, valid_loss, test_loss implemented, so this is a quick fix for now @@ -168,28 +188,28 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int ) = self.optimizer.train_statistics(report_incumbent) train_loss, valid_loss, test_loss = -1, -1, -1 - self.search_trajectory.train_acc.append(train_acc) - self.search_trajectory.train_loss.append(train_loss) - self.search_trajectory.valid_acc.append(valid_acc) - self.search_trajectory.valid_loss.append(valid_loss) - self.search_trajectory.test_acc.append(test_acc) - self.search_trajectory.test_loss.append(test_loss) - self.search_trajectory.runtime.append(end_time - start_time) - self.search_trajectory.train_time.append(train_time) + self.errors_dict.train_acc.append(train_acc) + self.errors_dict.train_loss.append(train_loss) + self.errors_dict.valid_acc.append(valid_acc) + self.errors_dict.valid_loss.append(valid_loss) + self.errors_dict.test_acc.append(test_acc) + self.errors_dict.test_loss.append(test_loss) + self.errors_dict.runtime.append(end_time - start_time) + self.errors_dict.train_time.append(train_time) self.train_top1.avg = train_acc self.val_top1.avg = valid_acc self.periodic_checkpointer.step(e) anytime_results = self.optimizer.test_statistics() - # if anytime_results: + if anytime_results: # record anytime performance - # self.search_trajectory.arch_eval.append(anytime_results) - # log_every_n_seconds( - # logging.INFO, - # "Epoch {}, Anytime results: {}".format(e, anytime_results), - # n=5, - # ) + self.errors_dict.arch_eval.append(anytime_results) + log_every_n_seconds( + 
logging.INFO, + "Epoch {}, Anytime results: {}".format(e, anytime_results), + n=5, + ) self._log_to_json() @@ -198,6 +218,23 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int if after_epoch is not None: after_epoch(e) + logger.info(f"Saving architectural weight tensors: {self.config.save_arch_weights_path}/epoch_{e}") + + # writing arch weights to file and plotting + if self.config.save_arch_weights: + if not Path(f'{self.config.save_arch_weights_path}/epoch_{e}/tensor_0.pt').exists(): + for idx in range(len(arch_weights_lst)): + if self.config.plot_arch_weights: + self._plot_architectural_weights(idx, alpha_i=arch_weights_lst[idx], epoch_num=e) + torch.save(arch_weights_lst[idx], f'{self.config.save_arch_weights_path}/epoch_{e}/tensor_{idx}.pt') + else: + for idx in range(len(self.optimizer.architectural_weights)): + old_arch_weights = torch.load(f'{self.config.save_arch_weights_path}/epoch_{e}/tensor_{idx}.pt') + arch_weights_lst[idx] = torch.cat((old_arch_weights, arch_weights_lst[idx]), dim=0) + if self.config.plot_arch_weights: + self._plot_architectural_weights(idx, alpha_i=arch_weights_lst[idx], epoch_num=e) + torch.save(arch_weights_lst[idx], f'{self.config.save_arch_weights_path}/epoch_{e}/tensor_{idx}.pt') + self.optimizer.after_training() if summary_writer is not None: @@ -240,9 +277,9 @@ def evaluate_oneshot(self, resume_from="", dataloader=None): end_time = time.time() - self.search_trajectory.valid_acc.append(self.val_top1.avg) - self.search_trajectory.valid_loss.append(self.val_loss.avg) - self.search_trajectory.runtime.append(end_time - start_time) + self.errors_dict.valid_acc.append(self.val_top1.avg) + self.errors_dict.valid_loss.append(self.val_loss.avg) + self.errors_dict.runtime.append(end_time - start_time) self._log_to_json() @@ -250,13 +287,13 @@ def evaluate_oneshot(self, resume_from="", dataloader=None): return self.val_top1.avg def evaluate( - self, - retrain:bool=True, - search_model:str="", - resume_from:str="", - best_arch:Graph=None, - dataset_api:object=None, - metric:Metric=None, + self, + retrain: bool = True, + search_model: str = "", + resume_from: str = "", + best_arch: Graph = None, + dataset_api: object = None, + metric: Metric = None, ): """ Evaluate the final architecture as given from the optimizer. 
@@ -284,7 +321,7 @@ def evaluate( self._setup_checkpointers(search_model) # required to load the architecture best_arch = self.optimizer.get_final_architecture() - logger.info(f"Final architecture hash: {best_arch.get_hash()}") + logger.info("Final architecture:\n" + best_arch.modules_str()) if best_arch.QUERYABLE: if metric is None: @@ -293,7 +330,6 @@ def evaluate( metric=metric, dataset=self.config.dataset, dataset_api=dataset_api ) logger.info("Queried results ({}): {}".format(metric, result)) - return result else: best_arch.to(self.device) if retrain: @@ -366,14 +402,14 @@ def evaluate( logits_train = best_arch(input_train) train_loss = loss(logits_train, target_train) if hasattr( - best_arch, "auxilary_logits" + best_arch, "auxilary_logits" ): # darts specific stuff log_first_n(logging.INFO, "Auxiliary is used", n=10) auxiliary_loss = loss( best_arch.auxilary_logits(), target_train ) train_loss += ( - self.config.evaluation.auxiliary_weight * auxiliary_loss + self.config.evaluation.auxiliary_weight * auxiliary_loss ) train_loss.backward() if grad_clip: @@ -395,9 +431,8 @@ def evaluate( if self.valid_queue: best_arch.eval() for i, (input_valid, target_valid) in enumerate( - self.valid_queue + self.valid_queue ): - input_valid = input_valid.to(self.device).float() target_valid = target_valid.to(self.device).float() @@ -453,8 +488,6 @@ def evaluate( ) ) - return top1.avg - @staticmethod def build_search_dataloaders(config): train_queue, valid_queue, test_queue, _, _ = utils.get_train_val_loaders( @@ -496,10 +529,12 @@ def build_eval_scheduler(optimizer, config): def _log_and_reset_accuracies(self, epoch, writer=None): logger.info( - "Epoch {} done. Train accuracy: {:.5f}, Validation accuracy: {:.5f}".format( + "Epoch {} done. Train accuracy (top1, top5): {:.5f}, {:.5f}, Validation accuracy: {:.5f}, {:.5f}".format( epoch, self.train_top1.avg, + self.train_top5.avg, self.val_top1.avg, + self.val_top5.avg, ) ) @@ -550,7 +585,7 @@ def _prepare_dataloaders(self, config, mode="train"): self.test_queue = test_queue def _setup_checkpointers( - self, resume_from="", search=True, period=1, **add_checkpointables + self, resume_from="", search=True, period=1, **add_checkpointables ): """ Sets up a periodic chechkpointer which can be used to save checkpoints @@ -596,14 +631,27 @@ def _log_to_json(self): os.makedirs(self.config.save) if not self.lightweight_output: with codecs.open( - os.path.join(self.config.save, "errors.json"), "w", encoding="utf-8" + os.path.join(self.config.save, "errors.json"), "w", encoding="utf-8" ) as file: - json.dump(self.search_trajectory, file, separators=(",", ":")) + json.dump(self.errors_dict, file, separators=(",", ":")) else: with codecs.open( - os.path.join(self.config.save, "errors.json"), "w", encoding="utf-8" + os.path.join(self.config.save, "errors.json"), "w", encoding="utf-8" ) as file: - lightweight_dict = copy.deepcopy(self.search_trajectory) + lightweight_dict = copy.deepcopy(self.errors_dict) for key in ["arch_eval", "train_loss", "valid_loss", "test_loss"]: lightweight_dict.pop(key) json.dump([self.config, lightweight_dict], file, separators=(",", ":")) + + def _plot_architectural_weights(self, idx, alpha_i, epoch_num): + # Todo check if softmax is suitable here. In which range are the weights for e.g. 
GDAS + alpha_i = torch.softmax(alpha_i.detach(), dim=1).cpu().numpy() + g = sns.heatmap(alpha_i.T, cmap=sns.diverging_palette(230, 0, 90, 60, as_cmap=True)) + g.set_xticklabels(g.get_xticklabels(), rotation=60) + + plt.title(f"arch weights for operation {idx}") + plt.xlabel("steps") + plt.ylabel("alpha values") + plt.tight_layout() + plt.savefig(f"{self.config.save_arch_weights_path}/epoch_{epoch_num}/heatmap_{idx}.png") + plt.close() \ No newline at end of file From a7fba2290aee68ec6f9642eb266665ead665c26b Mon Sep 17 00:00:00 2001 From: John Robertson Date: Tue, 20 Dec 2022 16:06:41 +0100 Subject: [PATCH 2/6] initial changes for alpha plotting procedure --- examples/plot_weights.py | 36 +++++++++++++++++++++++++++ naslib/defaults/trainer.py | 20 +++++++++------ naslib/utils/vis/__init__.py | 3 +++ naslib/utils/vis/utils.py | 48 ++++++++++++++++++++++++++++++++++++ 4 files changed, 99 insertions(+), 8 deletions(-) create mode 100644 examples/plot_weights.py create mode 100644 naslib/utils/vis/__init__.py create mode 100644 naslib/utils/vis/utils.py diff --git a/examples/plot_weights.py b/examples/plot_weights.py new file mode 100644 index 0000000000..b00b17c3cc --- /dev/null +++ b/examples/plot_weights.py @@ -0,0 +1,36 @@ +import logging +from naslib.defaults.trainer import Trainer +from naslib.optimizers import DARTSOptimizer, GDASOptimizer, RandomSearch +from naslib.search_spaces import DartsSearchSpace, SimpleCellSearchSpace + +from naslib.utils import set_seed, setup_logger, get_config_from_args +from naslib.utils.vis import plot_architectural_weights + +config = get_config_from_args() # use --help so see the options +config.save_arch_weights = True +# config.search.batch_size = 32 +config.search.epochs = 1 +set_seed(config.seed) + +logger = setup_logger(config.save + "/log.log") +logger.setLevel(logging.INFO) # default DEBUG is very verbose + +search_space = SimpleCellSearchSpace() + +optimizer = DARTSOptimizer(config) +optimizer.adapt_search_space(search_space) + +trainer = Trainer(optimizer, config) +trainer.search() + +for u,v in trainer.optimizer.graph.edges: + print(trainer.optimizer.graph.edges[u,v].op) + +# for u,v in optimizer.graph.edges: +# print(optimizer.graph.edges[u,v].op) + +# trainer.evaluate() + +# plot_architectural_weights(config) + + diff --git a/naslib/defaults/trainer.py b/naslib/defaults/trainer.py index 5db72f8ce7..5701b5e3d4 100644 --- a/naslib/defaults/trainer.py +++ b/naslib/defaults/trainer.py @@ -111,6 +111,7 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int self.config ) + arch_weights = [] for e in range(start_epoch, self.epochs): # create the arch directory (without overwriting) @@ -123,15 +124,14 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int arch_weights_lst = [] if self.optimizer.using_step_function: for step, data_train in enumerate(self.train_queue): - - # save arch weights to array of tensors - if self.config.save_arch_weights: - if len(arch_weights_lst) == 0: - for alpha_i in self.optimizer.architectural_weights: - arch_weights_lst.append(torch.unsqueeze(alpha_i.detach(), dim=0)) + + if hasattr(self.config, "save_arch_weights") and self.config.save_arch_weights: + if len(arch_weights) == 0: + for edge_weights in self.optimizer.architectural_weights: + arch_weights.append(torch.unsqueeze(edge_weights.detach(), dim=0)) else: - for idx, alpha_i in enumerate(self.optimizer.architectural_weights): - arch_weights_lst[idx] = torch.cat((arch_weights_lst[idx], 
torch.unsqueeze(alpha_i.detach(), dim=0)), dim=0) + for i, edge_weights in enumerate(self.optimizer.architectural_weights): + arch_weights[i] = torch.cat((arch_weights[i], torch.unsqueeze(edge_weights.detach(), dim=0)), dim=0) data_train = ( data_train[0].to(self.device), @@ -240,6 +240,10 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int if summary_writer is not None: summary_writer.close() + logger.info(f"Saving architectural weight tensors: {self.config.save}/arch_weights.pt") + if hasattr(self.config, "save_arch_weights") and self.config.save_arch_weights: + torch.save(arch_weights, f'{self.config.save}/arch_weights.pt') + logger.info("Training finished") def evaluate_oneshot(self, resume_from="", dataloader=None): diff --git a/naslib/utils/vis/__init__.py b/naslib/utils/vis/__init__.py new file mode 100644 index 0000000000..868c86a414 --- /dev/null +++ b/naslib/utils/vis/__init__.py @@ -0,0 +1,3 @@ +from .utils import ( + plot_architectural_weights, +) \ No newline at end of file diff --git a/naslib/utils/vis/utils.py b/naslib/utils/vis/utils.py new file mode 100644 index 0000000000..934440739b --- /dev/null +++ b/naslib/utils/vis/utils.py @@ -0,0 +1,48 @@ +import logging +import torch +import numpy as np + +import matplotlib.pyplot as plt +from matplotlib.cm import ScalarMappable +import seaborn as sns + +logger = logging.getLogger(__name__) + +def plot_architectural_weights(config): + arch_weights = torch.load(f'{config.save}/arch_weights.pt') + + for i, edge_weights in enumerate(arch_weights): + arch_weights[i] = torch.softmax(edge_weights.detach(), dim=1).cpu().numpy() + + num_epochs = config.search.epochs + cmap = sns.diverging_palette(230, 0, 90, 60, as_cmap=True) + + fig, axes = plt.subplots(nrows=len(arch_weights)) + cax = fig.add_axes([.9, 0.05, .0125, 0.925]) + + for i, edge_weights in enumerate(arch_weights): + num_steps, num_alphas = edge_weights.shape + sns.heatmap( + edge_weights.T, + cmap=cmap, + vmin=np.min(arch_weights), + vmax=np.max(arch_weights), + ax=axes[i], + cbar=True, + cbar_ax=cax + ) + + if i == len(arch_weights) - 1: + axes[i].set_xticks(np.arange(stop=num_steps+num_steps/num_epochs, step=num_steps/num_epochs)) + axes[i].set_xticklabels(np.arange(num_epochs+1), rotation=360, fontdict=dict(fontsize=6)) + else: + axes[i].set_xticks([]) + + axes[i].set_ylabel('edge', fontdict=dict(fontsize=6)) + axes[i].set_yticks(np.arange(num_alphas)) + axes[i].set_yticklabels(['op'] * num_alphas, rotation=360, fontdict=dict(fontsize=6)) + + fig.tight_layout(rect=[0, 0, 0.9, 1], pad=0.5) + fig.savefig(f"{config.save}/arch_weights.png", dpi=300) + plt.close() + From 6ca93d9ccaa764354ed54e89dead4be2f3a476d8 Mon Sep 17 00:00:00 2001 From: John Robertson Date: Mon, 9 Jan 2023 15:14:15 +0100 Subject: [PATCH 3/6] one shot model alpha weights visualization --- ...lot_darts.py => plot_save_arch_weights.py} | 17 +++--- examples/plot_weights.py | 36 ------------ naslib/defaults/trainer.py | 54 ++++------------- naslib/utils/utils.py | 4 +- naslib/utils/vis/utils.py | 58 ++++++++++++++----- 5 files changed, 67 insertions(+), 102 deletions(-) rename examples/{plot_darts.py => plot_save_arch_weights.py} (57%) delete mode 100644 examples/plot_weights.py diff --git a/examples/plot_darts.py b/examples/plot_save_arch_weights.py similarity index 57% rename from examples/plot_darts.py rename to examples/plot_save_arch_weights.py index f449993a76..894f8fd035 100644 --- a/examples/plot_darts.py +++ b/examples/plot_save_arch_weights.py @@ -1,26 +1,29 @@ import os 
import logging from naslib.defaults.trainer import Trainer -from naslib.optimizers import DARTSOptimizer, GDASOptimizer, RandomSearch -from naslib.search_spaces import DartsSearchSpace, SimpleCellSearchSpace +from naslib.optimizers import DARTSOptimizer +from naslib.search_spaces import NasBench201SearchSpace from naslib.utils import set_seed, setup_logger, get_config_from_args +from naslib.utils.vis import plot_architectural_weights config = get_config_from_args() # use --help so see the options -config.search.batch_size = 128 -config.search.epochs = 1 +config.search.epochs = 50 config.save_arch_weights = True config.plot_arch_weights = True -config.save_arch_weights_path = f"{config.save}/save_arch" + set_seed(config.seed) logger = setup_logger(config.save + "/log.log") logger.setLevel(logging.INFO) # default DEBUG is very verbose -search_space = SimpleCellSearchSpace() # DartsSearchSpace() # use SimpleCellSearchSpace() for less heavy search +search_space = NasBench201SearchSpace() optimizer = DARTSOptimizer(config) optimizer.adapt_search_space(search_space) trainer = Trainer(optimizer, config) -trainer.search() \ No newline at end of file +# trainer.search() + +plot_architectural_weights(config, optimizer) + diff --git a/examples/plot_weights.py b/examples/plot_weights.py deleted file mode 100644 index b00b17c3cc..0000000000 --- a/examples/plot_weights.py +++ /dev/null @@ -1,36 +0,0 @@ -import logging -from naslib.defaults.trainer import Trainer -from naslib.optimizers import DARTSOptimizer, GDASOptimizer, RandomSearch -from naslib.search_spaces import DartsSearchSpace, SimpleCellSearchSpace - -from naslib.utils import set_seed, setup_logger, get_config_from_args -from naslib.utils.vis import plot_architectural_weights - -config = get_config_from_args() # use --help so see the options -config.save_arch_weights = True -# config.search.batch_size = 32 -config.search.epochs = 1 -set_seed(config.seed) - -logger = setup_logger(config.save + "/log.log") -logger.setLevel(logging.INFO) # default DEBUG is very verbose - -search_space = SimpleCellSearchSpace() - -optimizer = DARTSOptimizer(config) -optimizer.adapt_search_space(search_space) - -trainer = Trainer(optimizer, config) -trainer.search() - -for u,v in trainer.optimizer.graph.edges: - print(trainer.optimizer.graph.edges[u,v].op) - -# for u,v in optimizer.graph.edges: -# print(optimizer.graph.edges[u,v].op) - -# trainer.evaluate() - -# plot_architectural_weights(config) - - diff --git a/naslib/defaults/trainer.py b/naslib/defaults/trainer.py index 5701b5e3d4..9017ce0e77 100644 --- a/naslib/defaults/trainer.py +++ b/naslib/defaults/trainer.py @@ -18,6 +18,7 @@ from naslib.utils import utils from naslib.utils.logging import log_every_n_seconds, log_first_n +from naslib.utils.vis import plot_architectural_weights from typing import Callable from .additional_primitives import DropPathWrapper @@ -114,18 +115,14 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int arch_weights = [] for e in range(start_epoch, self.epochs): - # create the arch directory (without overwriting) - if self.config.save_arch_weights: - Path(f"{self.config.save_arch_weights_path}/epoch_{e}").mkdir(parents=True, exist_ok=False) - start_time = time.time() self.optimizer.new_epoch(e) - arch_weights_lst = [] if self.optimizer.using_step_function: for step, data_train in enumerate(self.train_queue): - - if hasattr(self.config, "save_arch_weights") and self.config.save_arch_weights: + + # save arch weights to array of tensors + if 
self.config.save_arch_weights: if len(arch_weights) == 0: for edge_weights in self.optimizer.architectural_weights: arch_weights.append(torch.unsqueeze(edge_weights.detach(), dim=0)) @@ -218,32 +215,18 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int if after_epoch is not None: after_epoch(e) - logger.info(f"Saving architectural weight tensors: {self.config.save_arch_weights_path}/epoch_{e}") - - # writing arch weights to file and plotting - if self.config.save_arch_weights: - if not Path(f'{self.config.save_arch_weights_path}/epoch_{e}/tensor_0.pt').exists(): - for idx in range(len(arch_weights_lst)): - if self.config.plot_arch_weights: - self._plot_architectural_weights(idx, alpha_i=arch_weights_lst[idx], epoch_num=e) - torch.save(arch_weights_lst[idx], f'{self.config.save_arch_weights_path}/epoch_{e}/tensor_{idx}.pt') - else: - for idx in range(len(self.optimizer.architectural_weights)): - old_arch_weights = torch.load(f'{self.config.save_arch_weights_path}/epoch_{e}/tensor_{idx}.pt') - arch_weights_lst[idx] = torch.cat((old_arch_weights, arch_weights_lst[idx]), dim=0) - if self.config.plot_arch_weights: - self._plot_architectural_weights(idx, alpha_i=arch_weights_lst[idx], epoch_num=e) - torch.save(arch_weights_lst[idx], f'{self.config.save_arch_weights_path}/epoch_{e}/tensor_{idx}.pt') - + # save and possibly plot architectural weights + logger.info(f"Saving architectural weight tensors: {self.config.save}/arch_weights.pt") + if hasattr(self.config, "save_arch_weights") and self.config.save_arch_weights: + torch.save(arch_weights, f'{self.config.save}/arch_weights.pt') + if hasattr(self.config, "plot_arch_weights") and self.config.plot_arch_weights: + plot_architectural_weights(self.config, self.optimizer) + self.optimizer.after_training() if summary_writer is not None: summary_writer.close() - logger.info(f"Saving architectural weight tensors: {self.config.save}/arch_weights.pt") - if hasattr(self.config, "save_arch_weights") and self.config.save_arch_weights: - torch.save(arch_weights, f'{self.config.save}/arch_weights.pt') - logger.info("Training finished") def evaluate_oneshot(self, resume_from="", dataloader=None): @@ -645,17 +628,4 @@ def _log_to_json(self): lightweight_dict = copy.deepcopy(self.errors_dict) for key in ["arch_eval", "train_loss", "valid_loss", "test_loss"]: lightweight_dict.pop(key) - json.dump([self.config, lightweight_dict], file, separators=(",", ":")) - - def _plot_architectural_weights(self, idx, alpha_i, epoch_num): - # Todo check if softmax is suitable here. In which range are the weights for e.g. 
GDAS - alpha_i = torch.softmax(alpha_i.detach(), dim=1).cpu().numpy() - g = sns.heatmap(alpha_i.T, cmap=sns.diverging_palette(230, 0, 90, 60, as_cmap=True)) - g.set_xticklabels(g.get_xticklabels(), rotation=60) - - plt.title(f"arch weights for operation {idx}") - plt.xlabel("steps") - plt.ylabel("alpha values") - plt.tight_layout() - plt.savefig(f"{self.config.save_arch_weights_path}/epoch_{epoch_num}/heatmap_{idx}.png") - plt.close() \ No newline at end of file + json.dump([self.config, lightweight_dict], file, separators=(",", ":")) \ No newline at end of file diff --git a/naslib/utils/utils.py b/naslib/utils/utils.py index 999540c822..efb48ba48f 100644 --- a/naslib/utils/utils.py +++ b/naslib/utils/utils.py @@ -322,8 +322,8 @@ def get_train_val_loaders(config, mode="train"): data = config.data dataset = config.dataset seed = config.search.seed - batch_size = config.batch_size - train_portion = config.train_portion + batch_size = config.batch_size if hasattr(config, "batch_size") else config.search.batch_size + train_portion = config.train_portion if hasattr(config, "train_portion") else config.search.train_portion config = config.search if mode == "train" else config.evaluation if dataset == "cifar10": train_transform, valid_transform = _data_transforms_cifar10(config) diff --git a/naslib/utils/vis/utils.py b/naslib/utils/vis/utils.py index 934440739b..de38fd9abb 100644 --- a/naslib/utils/vis/utils.py +++ b/naslib/utils/vis/utils.py @@ -8,20 +8,43 @@ logger = logging.getLogger(__name__) -def plot_architectural_weights(config): +def plot_architectural_weights(config, optimizer): + # load alphas arch_weights = torch.load(f'{config.save}/arch_weights.pt') - + + # discretize and softmax alphas for i, edge_weights in enumerate(arch_weights): - arch_weights[i] = torch.softmax(edge_weights.detach(), dim=1).cpu().numpy() + total_steps, num_alphas = edge_weights.shape + num_epochs = config.search.epochs + steps_per_epoch = total_steps // num_epochs + + disc_weights = torch.mean(edge_weights.detach().reshape(-1, steps_per_epoch, num_alphas), axis=1).cpu() + arch_weights[i] = torch.softmax(disc_weights, dim=-1).numpy() - num_epochs = config.search.epochs + # define diverging colormap with NASLib colors cmap = sns.diverging_palette(230, 0, 90, 60, as_cmap=True) - fig, axes = plt.subplots(nrows=len(arch_weights)) - cax = fig.add_axes([.9, 0.05, .0125, 0.925]) + # unpack search space information + edge_names, op_names = [], [] + for graph in optimizer.graph._get_child_graphs(single_instances=True): + for u, v, edge_data in graph.edges.data(): + if edge_data.has("alpha"): + edge_names.append((u, v)) + op_names.append([op.get_op_name for op in edge_data.op.get_embedded_ops()]) - for i, edge_weights in enumerate(arch_weights): + # define figure and axes + fig, axes = plt.subplots(nrows=len(arch_weights), figsize=(10, np.array(op_names).size/10)) + cax = fig.add_axes([.95, 0.12, 0.0075, 0.795]) + cax.tick_params(labelsize=6) + cax.set_title('alphas', fontdict=dict(fontsize=6)) + + # unpack number of epochs + num_epochs = config.search.epochs + + # iterate over arch weights and create heatmaps + for (i, edge_weights) in enumerate(arch_weights): num_steps, num_alphas = edge_weights.shape + sns.heatmap( edge_weights.T, cmap=cmap, @@ -33,16 +56,21 @@ def plot_architectural_weights(config): ) if i == len(arch_weights) - 1: - axes[i].set_xticks(np.arange(stop=num_steps+num_steps/num_epochs, step=num_steps/num_epochs)) - axes[i].set_xticklabels(np.arange(num_epochs+1), rotation=360, fontdict=dict(fontsize=6)) + 
# axes[i].set_xticks(np.arange(stop=num_steps+num_steps/num_epochs, step=num_steps/num_epochs)) + # axes[i].set_xticklabels(np.arange(num_epochs+1), rotation=360, fontdict=dict(fontsize=6)) + axes[i].xaxis.set_tick_params(labelsize=6) + axes[i].set_xlabel("Epoch", fontdict=dict(fontsize=6)) else: axes[i].set_xticks([]) + + axes[i].set_ylabel(edge_names[i], fontdict=dict(fontsize=6)) + axes[i].set_yticks(np.arange(num_alphas) + 0.5) + axes[i].set_yticklabels(op_names[i], rotation=360, fontdict=dict(fontsize=5)) - axes[i].set_ylabel('edge', fontdict=dict(fontsize=6)) - axes[i].set_yticks(np.arange(num_alphas)) - axes[i].set_yticklabels(['op'] * num_alphas, rotation=360, fontdict=dict(fontsize=6)) + fig.tight_layout(rect=[0, 0, 0.95, 0.925], pad=0.25) + + _, search_space, dataset, optimizer, seed = config.save.split('/') + fig.suptitle(f"optimizer: {optimizer}, search space: {search_space}, dataset: {dataset}, seed: {seed}") - fig.tight_layout(rect=[0, 0, 0.9, 1], pad=0.5) - fig.savefig(f"{config.save}/arch_weights.png", dpi=300) - plt.close() + fig.savefig(f"{config.save}/arch_weights.pdf", dpi=300) From d3d83001cac720e7372947adab79331871e7cc95 Mon Sep 17 00:00:00 2001 From: John Robertson Date: Mon, 30 Jan 2023 09:31:33 +0100 Subject: [PATCH 4/6] divided plots into sets of 4 heatmaps --- examples/plot_save_arch_weights.py | 33 +++++++-- naslib/utils/vis/utils.py | 108 +++++++++++++++-------------- 2 files changed, 81 insertions(+), 60 deletions(-) diff --git a/examples/plot_save_arch_weights.py b/examples/plot_save_arch_weights.py index 894f8fd035..15dd8dfac5 100644 --- a/examples/plot_save_arch_weights.py +++ b/examples/plot_save_arch_weights.py @@ -1,29 +1,48 @@ import os import logging from naslib.defaults.trainer import Trainer -from naslib.optimizers import DARTSOptimizer -from naslib.search_spaces import NasBench201SearchSpace +from naslib.optimizers import DARTSOptimizer, GDASOptimizer, DrNASOptimizer +from naslib.search_spaces import NasBench101SearchSpace, NasBench201SearchSpace, NasBench301SearchSpace -from naslib.utils import set_seed, setup_logger, get_config_from_args +from naslib.utils import set_seed, setup_logger, get_config_from_args, create_exp_dir from naslib.utils.vis import plot_architectural_weights -config = get_config_from_args() # use --help so see the options +config = get_config_from_args() # use --help so see the options config.search.epochs = 50 config.save_arch_weights = True config.plot_arch_weights = True +config.optimizer = 'gdas' +config.search_space = 'nasbench301' +config.save = "{}/{}/{}/{}/{}".format( + config.out_dir, config.search_space, config.dataset, config.optimizer, config.seed +) +create_exp_dir(config.save) +create_exp_dir(config.save + "/search") # required for the checkpoints +create_exp_dir(config.save + "/eval") + +optimizers = { + 'gdas': GDASOptimizer(config), + 'darts': DARTSOptimizer(config), + 'drnas': DrNASOptimizer(config), +} + +search_spaces = { + 'nasbench101': NasBench101SearchSpace(), + 'nasbench201': NasBench201SearchSpace(), + 'nasbench301': NasBench301SearchSpace(), +} set_seed(config.seed) logger = setup_logger(config.save + "/log.log") logger.setLevel(logging.INFO) # default DEBUG is very verbose -search_space = NasBench201SearchSpace() +search_space = search_spaces[config.search_space] -optimizer = DARTSOptimizer(config) +optimizer = optimizers[config.optimizer] optimizer.adapt_search_space(search_space) trainer = Trainer(optimizer, config) # trainer.search() plot_architectural_weights(config, optimizer) - diff --git 
a/naslib/utils/vis/utils.py b/naslib/utils/vis/utils.py index de38fd9abb..1daaf6da1e 100644 --- a/naslib/utils/vis/utils.py +++ b/naslib/utils/vis/utils.py @@ -9,68 +9,70 @@ logger = logging.getLogger(__name__) def plot_architectural_weights(config, optimizer): - # load alphas - arch_weights = torch.load(f'{config.save}/arch_weights.pt') - - # discretize and softmax alphas - for i, edge_weights in enumerate(arch_weights): - total_steps, num_alphas = edge_weights.shape - num_epochs = config.search.epochs - steps_per_epoch = total_steps // num_epochs - - disc_weights = torch.mean(edge_weights.detach().reshape(-1, steps_per_epoch, num_alphas), axis=1).cpu() - arch_weights[i] = torch.softmax(disc_weights, dim=-1).numpy() - - # define diverging colormap with NASLib colors - cmap = sns.diverging_palette(230, 0, 90, 60, as_cmap=True) + all_weights = torch.load(f'{config.save}/arch_weights.pt') # load alphas # unpack search space information - edge_names, op_names = [], [] + alpha_dict = {} + min_soft, max_soft = np.inf, -np.inf for graph in optimizer.graph._get_child_graphs(single_instances=True): - for u, v, edge_data in graph.edges.data(): + for edge_weights, (u, v, edge_data) in zip(all_weights, graph.edges.data()): + if edge_data.has("alpha"): - edge_names.append((u, v)) - op_names.append([op.get_op_name for op in edge_data.op.get_embedded_ops()]) + total_steps, num_alphas = edge_weights.shape + steps_per_epoch = total_steps // config.search.epochs + disc_weights = torch.mean(edge_weights.detach().reshape(-1, steps_per_epoch, num_alphas), axis=1).cpu() + soft_weights = torch.softmax(disc_weights, dim=-1).numpy() - # define figure and axes - fig, axes = plt.subplots(nrows=len(arch_weights), figsize=(10, np.array(op_names).size/10)) - cax = fig.add_axes([.95, 0.12, 0.0075, 0.795]) - cax.tick_params(labelsize=6) - cax.set_title('alphas', fontdict=dict(fontsize=6)) + cell_name = edge_data['cell_name'] if hasattr(edge_data, 'cell_name') else "" + alpha_dict[(u, v, cell_name)] = {} + alpha_dict[(u, v, cell_name)]['op_names'] = [op.get_op_name for op in edge_data.op.get_embedded_ops()] + alpha_dict[(u, v, cell_name)]['alphas'] = soft_weights - # unpack number of epochs - num_epochs = config.search.epochs + min_soft = min(min_soft, np.min(soft_weights)) + max_soft = max(max_soft, np.max(soft_weights)) - # iterate over arch weights and create heatmaps - for (i, edge_weights) in enumerate(arch_weights): - num_steps, num_alphas = edge_weights.shape + max_rows = 4 # plot heatmaps in increments of n_rows edges + for start_id in range(0, len(alpha_dict.keys()), max_rows): - sns.heatmap( - edge_weights.T, - cmap=cmap, - vmin=np.min(arch_weights), - vmax=np.max(arch_weights), - ax=axes[i], - cbar=True, - cbar_ax=cax - ) + # calculate number of rows in plot + n_rows = min(max_rows, len(alpha_dict.keys())-start_id) + logger.info(f"Creating plot {config.save}/arch_weights_{start_id+1}to{start_id+n_rows}.png") - if i == len(arch_weights) - 1: - # axes[i].set_xticks(np.arange(stop=num_steps+num_steps/num_epochs, step=num_steps/num_epochs)) - # axes[i].set_xticklabels(np.arange(num_epochs+1), rotation=360, fontdict=dict(fontsize=6)) - axes[i].xaxis.set_tick_params(labelsize=6) - axes[i].set_xlabel("Epoch", fontdict=dict(fontsize=6)) - else: - axes[i].set_xticks([]) - - axes[i].set_ylabel(edge_names[i], fontdict=dict(fontsize=6)) - axes[i].set_yticks(np.arange(num_alphas) + 0.5) - axes[i].set_yticklabels(op_names[i], rotation=360, fontdict=dict(fontsize=5)) + # define figure and axes and NASLib colormap + fig, axes 
= plt.subplots(nrows=n_rows, figsize=(10, max_rows)) + cmap = sns.diverging_palette(230, 0, 90, 60, as_cmap=True) + + # iterate over arch weights and create heatmaps + for ax_id, (u, v, cell_name) in enumerate(list(alpha_dict.keys())[start_id:start_id+n_rows]): + map = sns.heatmap( + alpha_dict[u, v, cell_name]['alphas'].T, + cmap=cmap, + vmin=min_soft, + vmax=max_soft, + ax=axes[ax_id], + cbar=True + ) + + op_names = alpha_dict[(u, v, cell_name)]['op_names'] + + if ax_id < n_rows-1: + axes[ax_id].set_xticks([]) + axes[ax_id].set_ylabel(f"{u, v}", fontdict=dict(fontsize=6)) + axes[ax_id].set_yticks(np.arange(len(op_names)) + 0.5) + fontsize = max(6, 40/len(op_names)) + axes[ax_id].set_yticklabels(op_names, rotation=360, fontdict=dict(fontsize=fontsize)) + if cell_name != "": + axes[ax_id].set_title(cell_name, fontdict=dict(fontsize=6)) + cbar = map.collections[0].colorbar + cbar.ax.tick_params(labelsize=6) + cbar.ax.set_title('softmax', fontdict=dict(fontsize=6)) - fig.tight_layout(rect=[0, 0, 0.95, 0.925], pad=0.25) - - _, search_space, dataset, optimizer, seed = config.save.split('/') - fig.suptitle(f"optimizer: {optimizer}, search space: {search_space}, dataset: {dataset}, seed: {seed}") + # axes[ax_id].set_xticks(np.arange(config.search.epochs+1)) + # axes[ax_id].set_xticklabels(np.arange(config.search.epochs+1)) + axes[ax_id].xaxis.set_tick_params(labelsize=6) + axes[ax_id].set_xlabel("Epoch", fontdict=dict(fontsize=6)) - fig.savefig(f"{config.save}/arch_weights.pdf", dpi=300) + fig.suptitle(f"optimizer: {config.optimizer}, search space: {config.search_space}, dataset: {config.dataset}, seed: {config.seed}") + fig.tight_layout() + fig.savefig(f"{config.save}/arch_weights_{start_id+1}to{start_id+n_rows}.png", dpi=300) From 00823491af9d087a2937c55556908797f02601d8 Mon Sep 17 00:00:00 2001 From: John Robertson Date: Thu, 2 Feb 2023 11:14:02 +0100 Subject: [PATCH 5/6] resolved discrepencies unrelated to pull request and added trainer.py --- naslib/defaults/trainer.py | 101 ++++++++++++++++++------------------- 1 file changed, 50 insertions(+), 51 deletions(-) diff --git a/naslib/defaults/trainer.py b/naslib/defaults/trainer.py index 9017ce0e77..35e60fb637 100644 --- a/naslib/defaults/trainer.py +++ b/naslib/defaults/trainer.py @@ -8,9 +8,6 @@ import torch import numpy as np -import matplotlib.pyplot as plt -import seaborn as sns - from pathlib import Path from fvcore.common.checkpoint import PeriodicCheckpointer @@ -63,7 +60,7 @@ def __init__(self, optimizer, config, lightweight_output=False): n_parameters = optimizer.get_model_size() logger.info("param size = %fMB", n_parameters) - self.errors_dict = utils.AttrDict( + self.search_trajectory = utils.AttrDict( { "train_acc": [], "train_loss": [], @@ -78,8 +75,7 @@ def __init__(self, optimizer, config, lightweight_output=False): } ) - def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int], None] = None, - report_incumbent=True): + def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int], None]=None, report_incumbent=True): """ Start the architecture search. @@ -89,7 +85,7 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int resume_from (str): Checkpoint file to resume from. If not given then train from scratch. 
""" - logger.info("Start training") + logger.info("Beginning search") np.random.seed(self.config.search.seed) torch.manual_seed(self.config.search.seed) @@ -122,7 +118,7 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int for step, data_train in enumerate(self.train_queue): # save arch weights to array of tensors - if self.config.save_arch_weights: + if self.config.save_arch_weights is True: if len(arch_weights) == 0: for edge_weights in self.optimizer.architectural_weights: arch_weights.append(torch.unsqueeze(edge_weights.detach(), dim=0)) @@ -168,11 +164,11 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int end_time = time.time() - self.errors_dict.train_acc.append(self.train_top1.avg) - self.errors_dict.train_loss.append(self.train_loss.avg) - self.errors_dict.valid_acc.append(self.val_top1.avg) - self.errors_dict.valid_loss.append(self.val_loss.avg) - self.errors_dict.runtime.append(end_time - start_time) + self.search_trajectory.train_acc.append(self.train_top1.avg) + self.search_trajectory.train_loss.append(self.train_loss.avg) + self.search_trajectory.valid_acc.append(self.val_top1.avg) + self.search_trajectory.valid_loss.append(self.val_loss.avg) + self.search_trajectory.runtime.append(end_time - start_time) else: end_time = time.time() # TODO: nasbench101 does not have train_loss, valid_loss, test_loss implemented, so this is a quick fix for now @@ -185,28 +181,28 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int ) = self.optimizer.train_statistics(report_incumbent) train_loss, valid_loss, test_loss = -1, -1, -1 - self.errors_dict.train_acc.append(train_acc) - self.errors_dict.train_loss.append(train_loss) - self.errors_dict.valid_acc.append(valid_acc) - self.errors_dict.valid_loss.append(valid_loss) - self.errors_dict.test_acc.append(test_acc) - self.errors_dict.test_loss.append(test_loss) - self.errors_dict.runtime.append(end_time - start_time) - self.errors_dict.train_time.append(train_time) + self.search_trajectory.train_acc.append(train_acc) + self.search_trajectory.train_loss.append(train_loss) + self.search_trajectory.valid_acc.append(valid_acc) + self.search_trajectory.valid_loss.append(valid_loss) + self.search_trajectory.test_acc.append(test_acc) + self.search_trajectory.test_loss.append(test_loss) + self.search_trajectory.runtime.append(end_time - start_time) + self.search_trajectory.train_time.append(train_time) self.train_top1.avg = train_acc self.val_top1.avg = valid_acc self.periodic_checkpointer.step(e) anytime_results = self.optimizer.test_statistics() - if anytime_results: - # record anytime performance - self.errors_dict.arch_eval.append(anytime_results) - log_every_n_seconds( - logging.INFO, - "Epoch {}, Anytime results: {}".format(e, anytime_results), - n=5, - ) + # if anytime_results: + # # record anytime performance + # self.search_trajectory.arch_eval.append(anytime_results) + # log_every_n_seconds( + # logging.INFO, + # "Epoch {}, Anytime results: {}".format(e, anytime_results), + # n=5, + # ) self._log_to_json() @@ -264,9 +260,9 @@ def evaluate_oneshot(self, resume_from="", dataloader=None): end_time = time.time() - self.errors_dict.valid_acc.append(self.val_top1.avg) - self.errors_dict.valid_loss.append(self.val_loss.avg) - self.errors_dict.runtime.append(end_time - start_time) + self.search_trajectory.valid_acc.append(self.val_top1.avg) + self.search_trajectory.valid_loss.append(self.val_loss.avg) + self.search_trajectory.runtime.append(end_time - 
start_time) self._log_to_json() @@ -274,13 +270,13 @@ def evaluate_oneshot(self, resume_from="", dataloader=None): return self.val_top1.avg def evaluate( - self, - retrain: bool = True, - search_model: str = "", - resume_from: str = "", - best_arch: Graph = None, - dataset_api: object = None, - metric: Metric = None, + self, + retrain:bool=True, + search_model:str="", + resume_from:str="", + best_arch:Graph=None, + dataset_api:object=None, + metric:Metric=None, ): """ Evaluate the final architecture as given from the optimizer. @@ -308,7 +304,7 @@ def evaluate( self._setup_checkpointers(search_model) # required to load the architecture best_arch = self.optimizer.get_final_architecture() - logger.info("Final architecture:\n" + best_arch.modules_str()) + logger.info(f"Final architecture hash: {best_arch.get_hash()}") if best_arch.QUERYABLE: if metric is None: @@ -317,6 +313,7 @@ def evaluate( metric=metric, dataset=self.config.dataset, dataset_api=dataset_api ) logger.info("Queried results ({}): {}".format(metric, result)) + return result else: best_arch.to(self.device) if retrain: @@ -396,7 +393,7 @@ def evaluate( best_arch.auxilary_logits(), target_train ) train_loss += ( - self.config.evaluation.auxiliary_weight * auxiliary_loss + self.config.evaluation.auxiliary_weight * auxiliary_loss ) train_loss.backward() if grad_clip: @@ -418,8 +415,9 @@ def evaluate( if self.valid_queue: best_arch.eval() for i, (input_valid, target_valid) in enumerate( - self.valid_queue + self.valid_queue ): + input_valid = input_valid.to(self.device).float() target_valid = target_valid.to(self.device).float() @@ -475,6 +473,9 @@ def evaluate( ) ) + return top1.avg + + @staticmethod def build_search_dataloaders(config): train_queue, valid_queue, test_queue, _, _ = utils.get_train_val_loaders( @@ -516,12 +517,10 @@ def build_eval_scheduler(optimizer, config): def _log_and_reset_accuracies(self, epoch, writer=None): logger.info( - "Epoch {} done. Train accuracy (top1, top5): {:.5f}, {:.5f}, Validation accuracy: {:.5f}, {:.5f}".format( + "Epoch {} done. 
Train accuracy: {:.5f}, Validation accuracy: {:.5f}".format( epoch, self.train_top1.avg, - self.train_top5.avg, - self.val_top1.avg, - self.val_top5.avg, + self.val_top1.avg ) ) @@ -572,7 +571,7 @@ def _prepare_dataloaders(self, config, mode="train"): self.test_queue = test_queue def _setup_checkpointers( - self, resume_from="", search=True, period=1, **add_checkpointables + self, resume_from="", search=True, period=1, **add_checkpointables ): """ Sets up a periodic chechkpointer which can be used to save checkpoints @@ -618,14 +617,14 @@ def _log_to_json(self): os.makedirs(self.config.save) if not self.lightweight_output: with codecs.open( - os.path.join(self.config.save, "errors.json"), "w", encoding="utf-8" + os.path.join(self.config.save, "errors.json"), "w", encoding="utf-8" ) as file: - json.dump(self.errors_dict, file, separators=(",", ":")) + json.dump(self.search_trajectory, file, separators=(",", ":")) else: with codecs.open( - os.path.join(self.config.save, "errors.json"), "w", encoding="utf-8" + os.path.join(self.config.save, "errors.json"), "w", encoding="utf-8" ) as file: - lightweight_dict = copy.deepcopy(self.errors_dict) + lightweight_dict = copy.deepcopy(self.search_trajectory) for key in ["arch_eval", "train_loss", "valid_loss", "test_loss"]: lightweight_dict.pop(key) json.dump([self.config, lightweight_dict], file, separators=(",", ":")) \ No newline at end of file From 3d689cf438bb00084a76793a9258fb9ba74b1e33 Mon Sep 17 00:00:00 2001 From: John Robertson Date: Thu, 2 Feb 2023 11:22:11 +0100 Subject: [PATCH 6/6] added utils change and vis file to commit --- naslib/utils/__init__.py | 3 ++- naslib/utils/vis/utils.py | 2 -- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/naslib/utils/__init__.py b/naslib/utils/__init__.py index db300252a6..37f300f838 100644 --- a/naslib/utils/__init__.py +++ b/naslib/utils/__init__.py @@ -9,7 +9,8 @@ parse_args, get_train_val_loaders, get_project_root, - compute_scores + compute_scores, + create_exp_dir ) from .logging import setup_logger from .get_dataset_api import get_dataset_api, get_zc_benchmark_api, load_sampled_architectures diff --git a/naslib/utils/vis/utils.py b/naslib/utils/vis/utils.py index 1daaf6da1e..a9504ec101 100644 --- a/naslib/utils/vis/utils.py +++ b/naslib/utils/vis/utils.py @@ -67,8 +67,6 @@ def plot_architectural_weights(config, optimizer): cbar.ax.tick_params(labelsize=6) cbar.ax.set_title('softmax', fontdict=dict(fontsize=6)) - # axes[ax_id].set_xticks(np.arange(config.search.epochs+1)) - # axes[ax_id].set_xticklabels(np.arange(config.search.epochs+1)) axes[ax_id].xaxis.set_tick_params(labelsize=6) axes[ax_id].set_xlabel("Epoch", fontdict=dict(fontsize=6))
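
A minimal usage sketch, assuming the search ran with save_arch_weights enabled and that "run/arch_weights.pt" stands in for the actual config.save path: the saved file holds the list of per-edge alpha tensors accumulated in trainer.py, one row per recorded training step, which can be loaded and inspected directly.

import torch

# arch_weights.pt holds a list with one tensor per architectural edge;
# each tensor stacks the detached alphas recorded at every training step,
# so its shape is (num_recorded_steps, num_operations_on_that_edge).
arch_weights = torch.load("run/arch_weights.pt")  # placeholder path for config.save

for edge_idx, edge_weights in enumerate(arch_weights):
    # softmax over operations, as done before plotting in naslib/utils/vis/utils.py
    final_alphas = torch.softmax(edge_weights[-1], dim=-1)
    print(f"edge {edge_idx}: {final_alphas.tolist()}")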