diff --git a/src/confopt/blackbox/__init__.py b/src/confopt/blackbox/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/confopt/blackbox/optimizer_sampler.py b/src/confopt/blackbox/optimizer_sampler.py
new file mode 100644
index 00000000..65c07d3f
--- /dev/null
+++ b/src/confopt/blackbox/optimizer_sampler.py
@@ -0,0 +1,579 @@
+from __future__ import annotations
+
+from typing import Any
+
+from ConfigSpace import Categorical, ConfigurationSpace, Float, Integer
+
+from confopt.profiles.profile_config import BaseProfile
+from confopt.train.experiment import SamplerType, SearchSpaceType
+
+
+class OneShotOptimizerSampler:
+    def __init__(self, search_space: SearchSpaceType):
+        self.search_space = search_space
+        self.CONFIG_SUFFIX = "_config"
+        self.init_config_space()
+
+    def init_config_space(self) -> None:
+        """Initializes the configuration spaces for the one-shot optimizer."""
+        ##### Sampler Space #####
+        self.sampler_space = ConfigurationSpace(
+            name="sampler",
+            space={
+                "sampler_type": Categorical(
+                    "sampler_type", ["darts", "drnas", "gdas", "reinmax"]
+                ),
+            },
+        )
+
+        gdas_cs = ConfigurationSpace(
+            name="gdas_space",
+            space={
+                "tau_min": Float("tau_min", bounds=(0.1, 2)),
+                "tau_max": Float("tau_max", bounds=(5, 15)),
+            },
+        )
+
+        self.sampler_space.add_configuration_space(
+            prefix=f"gdas{self.CONFIG_SUFFIX}",
+            configuration_space=gdas_cs,
+            parent_hyperparameter={
+                "parent": self.sampler_space["sampler_type"],
+                "value": "gdas",
+            },
+        )
+
+        self.sampler_space.add_configuration_space(
+            prefix=f"reinmax{self.CONFIG_SUFFIX}",
+            configuration_space=gdas_cs,
+            parent_hyperparameter={
+                "parent": self.sampler_space["sampler_type"],
+                "value": "reinmax",
+            },
+        )
+
+        ##### LoRA Space #####
+        self.lora_space = ConfigurationSpace(
+            name="lora",
+            space={
+                "use_lora": Categorical("use_lora", [True, False]),
+            },
+        )
+
+        lora_cs = ConfigurationSpace(
+            name="lora_space",
+            space={
+                "r": Categorical("r", [1, 2, 4, 8]),
+                "lora_alpha": Integer("lora_alpha", bounds=(1, 8)),
+                "lora_dropout": Float("lora_dropout", bounds=(0.0, 1.0)),
+                "lora_warm_epochs": Integer("lora_warm_epochs", bounds=(5, 25)),
+            },
+        )
+
+        self.lora_space.add_configuration_space(
+            prefix=f"lora{self.CONFIG_SUFFIX}",
+            configuration_space=lora_cs,
+            parent_hyperparameter={
+                "parent": self.lora_space["use_lora"],
+                "value": True,
+            },
+        )
+
+        ##### Perturbation Space #####
+        self.perturbation_space = ConfigurationSpace(
+            name="perturbation",
+            space={
+                "use_perturbation": Categorical("use_perturbation", [True, False]),
+            },
+        )
+
+        perturbation_cs = ConfigurationSpace(
+            name="perturbation_space",
+            space={
+                "perturbation": Categorical("perturbation", ["random", "adversarial"]),
+                "epsilon": Float("epsilon", bounds=(0.05, 1.0)),
+            },
+        )
+
+        adversarial_cs = ConfigurationSpace(
+            name="adversarial_space",
+            space={
+                "steps": Integer("steps", bounds=(1, 50)),
+                "random_start": Categorical("random_start", [True, False]),
+            },
+        )
+
+        perturbation_cs.add_configuration_space(
+            prefix=f"adversarial{self.CONFIG_SUFFIX}",
+            configuration_space=adversarial_cs,
+            parent_hyperparameter={
+                "parent": perturbation_cs["perturbation"],
+                "value": "adversarial",
+            },
+        )
+
+        self.perturbation_space.add_configuration_space(
+            prefix=f"perturbation{self.CONFIG_SUFFIX}",
+            configuration_space=perturbation_cs,
+            parent_hyperparameter={
+                "parent": self.perturbation_space["use_perturbation"],
+                "value": True,
+            },
+        )
+
+        ##### Partial Connection Space #####
+        self.partial_connection_space = ConfigurationSpace(
+            name="partial_connection",
+            space={
+                "use_partial_connection": Categorical(
+                    "use_partial_connection", [True, False]
+                ),
+            },
+        )
+
+        partial_connection_cs = ConfigurationSpace(
+            name="partial_connection_space",
+            space={
+                "k": Categorical("k", [1, 2, 4, 8, 16]),
+            },
+        )
+
+        self.partial_connection_space.add_configuration_space(
+            prefix=f"partial_connection{self.CONFIG_SUFFIX}",
+            configuration_space=partial_connection_cs,
+            parent_hyperparameter={
+                "parent": self.partial_connection_space["use_partial_connection"],
+                "value": True,
+            },
+        )
+
+        ##### Prune Space #####
+        self.prune_space = ConfigurationSpace(
+            name="prune",
+            space={
+                "use_prune": Categorical("use_prune", [True, False]),
+            },
+        )
+
+        prune_cs = ConfigurationSpace(
+            name="prune_space",
+            space={
+                "n_prune": Integer("n_prune", bounds=(1, 5)),
+                "prune_interval": Categorical("prune_interval", [5, 10]),
+            },
+        )
+
+        self.prune_space.add_configuration_space(
+            prefix=f"prune{self.CONFIG_SUFFIX}",
+            configuration_space=prune_cs,
+            parent_hyperparameter={
+                "parent": self.prune_space["use_prune"],
+                "value": True,
+            },
+        )
+
+        ##### Arch Attention Space #####
+        self.arch_attention_space = ConfigurationSpace(
+            name="arch_attention",
+            space={
+                "use_arch_attention": Categorical("use_arch_attention", [True, False]),
+            },
+        )
+
+        #### Arch Params Combine Function Space ####
+        self.sampler_arch_combine_fn = ConfigurationSpace(
+            name="sampler_arch_combine_fn",
+            space={
+                "sampler_arch_combine_fn": Categorical(
+                    "sampler_arch_combine_fn", ["softmax", "sigmoid"]
+                ),
+            },
+        )
+
+    def sample_sampler(
+        self, sampler: SamplerType | None, sample_sampler_config: bool
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
+        """Samples the sampler configuration.
+
+        Args:
+            sampler: Sampler to be used. Sampled if None.
+            sample_sampler_config: Whether to sample the hyperparameters of the
+                optimizer. If False, uses the default values.
+
+        Returns:
+            tuple[dict[str, Any], dict[str, Any]]: Base sampler configuration and
+                the sampler-specific extra configuration.
+        """
+        if sampler is not None and sample_sampler_config is False:  # Nothing to sample
+            return {"sampler_type": sampler.value}, {}
+
+        def split_config(
+            config: dict[str, Any]
+        ) -> tuple[dict[str, Any], dict[str, Any]]:
+            base_params = {
+                k: v for k, v in config.items() if self.CONFIG_SUFFIX + ":" not in k
+            }
+            config_params = {
+                k.split(":")[-1]: v
+                for k, v in config.items()
+                if self.CONFIG_SUFFIX + ":" in k
+            }
+
+            return base_params, config_params
+
+        # Sampler is given, but have to sample its configuration
+        if sampler is not None:
+            while True:
+                config = self.sampler_space.sample_configuration(1)
+                if config["sampler_type"] == sampler.value:
+                    break
+        else:  # Sample the sampler and its configuration
+            config = self.sampler_space.sample_configuration(1)
+
+        base_config, extra_config = split_config(config)
+        extra_config = extra_config if sample_sampler_config else {}
+
+        return base_config, extra_config
+
+    def sample_lora(
+        self, lora: bool | None, sample_lora_config: bool
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
+        """Samples the LoRA configuration.
+
+        Args:
+            lora: Whether to use LoRA. Sampled if None.
+            sample_lora_config: Whether to sample the hyperparameters of the LoRA
+                modules. If False, uses the default values.
+
+        Returns:
+            tuple[dict[str, Any], dict[str, Any]]: Base LoRA configuration and the
+                extra LoRA configuration.
+        """
+        if lora is False:
+            return {"lora_rank": 0}, {}
+
+        if lora is True:
+            while True:
+                config = self.lora_space.sample_configuration(1)
+                if config["use_lora"]:
+                    break
+        elif lora is None:
+            config = self.lora_space.sample_configuration(1)
+
+        base_config = {
+            "lora_rank": config["lora_config:r"] if "lora_config:r" in config else 0
+        }
+
+        extra_keys = [
+            "lora_config:lora_warm_epochs",
+            "lora_config:lora_alpha",
+            "lora_config:lora_dropout",
+            # "lora_config:lora_toggle_probability",  # TODO-AK: Add this to the config?
+        ]
+        extra_config = (
+            {k.split(":")[-1]: config[k] for k in extra_keys if k in config}
+            if sample_lora_config
+            else {}
+        )
+
+        return base_config, extra_config
+
+    def sample_perturbation(
+        self, perturbation: bool | None, sample_perturbation_config: bool
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
+        """Samples the perturbation configuration.
+
+        Args:
+            perturbation: Whether to use perturbation. Sampled if None.
+            sample_perturbation_config: Whether to sample the hyperparameters of the
+                perturbation. If False, uses the default values.
+
+        Returns:
+            tuple[dict[str, Any], dict[str, Any]]: Base perturbation configuration
+                and the extra perturbation configuration.
+        """
+        if perturbation is False:
+            return {"perturbation": None}, {}
+
+        if perturbation is True:
+            while True:
+                config = self.perturbation_space.sample_configuration(1)
+                if config["use_perturbation"]:
+                    break
+        elif perturbation is None:
+            config = self.perturbation_space.sample_configuration(1)
+
+        if config["use_perturbation"]:
+            base_config = {"perturbation": config["perturbation_config:perturbation"]}
+        else:
+            base_config = {"perturbation": None}
+
+        extra_config = {
+            k.split(":")[-1]: v
+            for k, v in config.items()
+            if self.CONFIG_SUFFIX + ":" in k
+        }
+        extra_config.pop("perturbation", None)
+        extra_config = extra_config if sample_perturbation_config else {}
+
+        return base_config, extra_config
+
+    def sample_partial_connection(
+        self, partial_connection: bool | None, sample_partial_connection_config: bool
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
+        """Samples the partial connection configuration.
+
+        Args:
+            partial_connection: Whether to use partial connection. Sampled if None.
+            sample_partial_connection_config: Whether to sample the hyperparameters of
+                the partial connection. If False, uses the default values.
+
+        Returns:
+            tuple[dict[str, Any], dict[str, Any]]: Base partial connection
+                configuration and the extra partial connection configuration.
+        """
+        if partial_connection is False:
+            return {"is_partial_connection": False}, {}
+
+        if partial_connection is True:
+            while True:
+                config = self.partial_connection_space.sample_configuration(1)
+                if config["use_partial_connection"]:
+                    break
+        elif partial_connection is None:
+            config = self.partial_connection_space.sample_configuration(1)
+
+        base_config = {"is_partial_connection": config["use_partial_connection"]}
+        # The child space is registered under a prefix, so "k" only appears as
+        # "partial_connection_config:k" in the sampled configuration.
+        k_key = f"partial_connection{self.CONFIG_SUFFIX}:k"
+        extra_config = {"k": config[k_key]} if k_key in config else {}
+        extra_config = extra_config if sample_partial_connection_config else {}
+
+        return base_config, extra_config
+
+    def sample_prune(
+        self, prune: bool | None, sample_prune_config: bool, epochs: int
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
+        """Samples the prune configuration.
+
+        Args:
+            prune: Whether to use prune. Sampled if None.
+            sample_prune_config: Whether to sample the hyperparameters of the
+                prune. If False, uses the default values.
+            epochs: Number of epochs to train the supernet.
+
+        Returns:
+            tuple[dict[str, Any], dict[str, Any]]: Base prune configuration and the
+                extra prune configuration.
+        """
+        if prune is False:
+            return {}, {}
+
+        if prune is True:
+            while True:
+                config = self.prune_space.sample_configuration(1)
+                if config["use_prune"]:
+                    break
+        elif prune is None:
+            config = self.prune_space.sample_configuration(1)
+
+        _ = epochs
+        _ = sample_prune_config
+
+        return (
+            {},
+            {},
+        )  # TODO-AK: Incomplete. Deal with pruning when the API is finalized.
+
+    def sample_arch_attention(
+        self, arch_attention: bool | None
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
+        """Samples the architecture attention configuration.
+
+        Args:
+            arch_attention: Whether to use attention between edges for arch
+                parameters. Sampled if None.
+
+        Returns:
+            tuple[dict[str, Any], dict[str, Any]]: Architecture attention
+                configuration and an empty extra configuration.
+        """
+        if arch_attention is None:
+            config = self.arch_attention_space.sample_configuration(1)
+            return {"is_arch_attention_enabled": config["use_arch_attention"]}, {}
+
+        return {"is_arch_attention_enabled": arch_attention}, {}
+
+    def sample_arch_params_combine_fn(
+        self, sampler_arch_combine_fn: str | None
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
+        """Samples the architecture parameters combine function configuration.
+
+        Args:
+            sampler_arch_combine_fn: Post-processing function for the arch
+                parameters. Sampled if None.
+
+        Returns:
+            tuple[dict[str, Any], dict[str, Any]]: Architecture parameters combine
+                function configuration and an empty extra configuration.
+        """
+        if sampler_arch_combine_fn is None:
+            config = self.sampler_arch_combine_fn.sample_configuration(1)
+            return {"sampler_arch_combine_fn": config["sampler_arch_combine_fn"]}, {}
+
+        return {"sampler_arch_combine_fn": sampler_arch_combine_fn}, {}
+
+    def sample_entangle_op_weights(
+        self, entangle_op_weights: bool | None
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
+        """Samples the weight entanglement configuration.
+
+        Args:
+            entangle_op_weights: Whether to use weight entanglement. Sampled if None.
+
+        Returns:
+            tuple[dict[str, Any], dict[str, Any]]: Weight entanglement configuration
+                and an empty extra configuration.
+        """
+        if entangle_op_weights is None:
+            # NOTE: currently defaults to True rather than sampling a value.
+            return {"entangle_op_weights": True}, {}
+
+        return {"entangle_op_weights": entangle_op_weights}, {}
+
+    def sample(
+        self,
+        epochs: int,
+        *,
+        sampler: SamplerType | None = None,
+        lora: bool | None = None,
+        perturbation: bool | None = None,
+        partial_connection: bool | None = None,
+        prune: bool | None = None,
+        arch_attention: bool | None = None,
+        arch_params_combine_fn: str | None = None,
+        entangle_op_weights: bool | None = None,
+        sample_sampler_config: bool = True,
+        sample_lora_config: bool = True,
+        sample_perturbation_config: bool = True,
+        sample_partial_connection_config: bool = True,
+        sample_prune_config: bool = True,
+    ) -> BaseProfile:
+        """Samples a new optimizer profile.
+
+        Key items to sample:
+        - sampler
+        - lora
+        - perturbation
+        - partial_connection
+        - prune
+        - attention between edges for arch parameters
+        - post-processing function for the arch parameters (softmax or sigmoid)
+        - weight entanglement
+
+        All parameters are sampled by default, indicated by None. Additionally,
+        this method lets the user override the default sampling behavior by
+        specifying values for the parameters. E.g., if the user wants a specific
+        optimizer, they can pass the sampler parameter and leave the remaining
+        parameters as None.
+
+        Args:
+            epochs: Number of epochs to train the supernet.
+            sampler: Sampler to be used. Sampled if None.
+            lora: Whether to use LoRA. Sampled if None.
+            perturbation: Whether to use perturbation. Sampled if None.
+            partial_connection: Whether to use partial connection. Sampled if None.
+            prune: Whether to use prune. Sampled if None.
+            arch_attention: Whether to use attention between edges for arch
+                parameters. Sampled if None.
+            arch_params_combine_fn: Post-processing function for the arch
+                parameters. Sampled if None.
+            entangle_op_weights: Whether to use weight entanglement. Sampled if None.
+            sample_sampler_config: Whether to sample the hyperparameters of the
+                optimizer. If False, uses the default values.
+            sample_lora_config: Whether to sample the hyperparameters of the LoRA
+                modules. If False, uses the default values.
+            sample_perturbation_config: Whether to sample the hyperparameters of the
+                perturbation. If False, uses the default values.
+            sample_partial_connection_config: Whether to sample the hyperparameters
+                of the partial connection. If False, uses the default values.
+            sample_prune_config: Whether to sample the hyperparameters of the prune.
+                If False, uses the default values.
+
+        Returns:
+            BaseProfile: Sampled optimizer profile
+        """
+        full_config = {}
+
+        base_config, extra_sampler_config = self.sample_sampler(
+            sampler, sample_sampler_config
+        )
+        full_config.update(base_config)
+
+        base_config, extra_lora_config = self.sample_lora(lora, sample_lora_config)
+        full_config.update(base_config)
+
+        base_config, extra_perturbation_config = self.sample_perturbation(
+            perturbation, sample_perturbation_config
+        )
+        full_config.update(base_config)
+
+        base_config, extra_partial_connection_config = self.sample_partial_connection(
+            partial_connection, sample_partial_connection_config
+        )
+        full_config.update(base_config)
+
+        base_config, extra_prune_config = self.sample_prune(
+            prune, sample_prune_config, epochs
+        )
+        full_config.update(base_config)
+
+        base_config, extra_arch_attention_config = self.sample_arch_attention(
+            arch_attention
+        )
+        full_config.update(base_config)
+
+        (
+            base_config,
+            extra_arch_params_combine_fn_config,
+        ) = self.sample_arch_params_combine_fn(arch_params_combine_fn)
+        full_config.update(base_config)
+
+        base_config, extra_entangle_op_weights_config = self.sample_entangle_op_weights(
+            entangle_op_weights
+        )
+        full_config.update(base_config)
+
+        profile = BaseProfile(epochs=epochs, **full_config)
+
+        ###### TODO-AK: Fix. Super ugly. ######
+        sampler_config = {
+            "sample_frequency": None,
+            "arch_combine_fn": None,
+        }
+        if full_config["sampler_type"] in ["gdas", "reinmax"]:
+            sampler_config.update(
+                {
+                    "tau_min": None,
+                    "tau_max": None,
+                }
+            )
+        profile.sampler_config = sampler_config
+        ###### END TODO ######
+
+        profile.configure_sampler(**extra_sampler_config)
+        profile.configure_lora(**extra_lora_config)
+
+        # full_config["perturbation"] holds "random"/"adversarial" or None,
+        # never a bool.
+        if full_config["perturbation"] is not None:
+            profile.configure_perturbator(**extra_perturbation_config)
+
+        if full_config["is_partial_connection"]:
+            profile.configure_partial_connector(**extra_partial_connection_config)
+
+        return profile
+
+
+if __name__ == "__main__":
+    optimizer_sampler = OneShotOptimizerSampler(SearchSpaceType.DARTS)
+
+    for _ in range(100):
+        try:
+            print("*" * 10)
+            profile = optimizer_sampler.sample(epochs=100)
+            print("SUCCEEDED")
+        except Exception as e:  # noqa: BLE001
+            print("FAILED")
+            print(e)
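For context, the conditional spaces above lean on one ConfigSpace behavior: a child space attached via add_configuration_space gets "<prefix>:" prepended to each of its hyperparameter names, and those entries only appear in a sampled configuration when the parent condition holds. A minimal standalone sketch of that mechanism, assuming a recent ConfigSpace (>= 0.7, the API the patch itself uses):

    from ConfigSpace import Categorical, ConfigurationSpace, Float

    root = ConfigurationSpace(
        space={"sampler_type": Categorical("sampler_type", ["darts", "gdas"])}
    )
    child = ConfigurationSpace(space={"tau_min": Float("tau_min", bounds=(0.1, 2))})
    root.add_configuration_space(
        prefix="gdas_config",
        configuration_space=child,
        parent_hyperparameter={"parent": root["sampler_type"], "value": "gdas"},
    )

    # "gdas_config:tau_min" is present only when sampler_type == "gdas"; stripping
    # the prefix with k.split(":")[-1] is exactly what split_config() does above.
    print(dict(root.sample_configuration()))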
diff --git a/src/confopt/oneshot/perturbator/sdarts/perturb.py b/src/confopt/oneshot/perturbator/sdarts/perturb.py
index a10e44c3..26276887 100644
--- a/src/confopt/oneshot/perturbator/sdarts/perturb.py
+++ b/src/confopt/oneshot/perturbator/sdarts/perturb.py
@@ -17,7 +17,7 @@ def __init__(
         search_space: SearchSpace | None = None,
         data: tuple[torch.Tensor, torch.Tensor] | None = None,
         loss_criterion: torch.nn.modules.loss._Loss | None = None,
-        attack_type: Literal["random", "adverserial"] = "random",
+        attack_type: Literal["random", "adversarial"] = "random",
         steps: int = 7,
         random_start: bool = True,
         sample_frequency: Literal["epoch", "step"] = "step",
@@ -29,12 +29,12 @@
 
         assert attack_type in [
             "random",
-            "adverserial",
-        ], "attack_type must be either 'random' or 'adverserial'"
+            "adversarial",
+        ], "attack_type must be either 'random' or 'adversarial'"
         self.attack_type = attack_type
 
-        # Initialize variables for adverserial attack
-        if self.attack_type == "adverserial":
+        # Initialize variables for adversarial attack
+        if self.attack_type == "adversarial":
             assert search_space is not None, "search_space should not be None"
             assert data is not None, "data should not be None"
 
diff --git a/src/confopt/profiles/profile_config.py b/src/confopt/profiles/profile_config.py
index d2a837f3..59e82166 100644
--- a/src/confopt/profiles/profile_config.py
+++ b/src/confopt/profiles/profile_config.py
@@ -7,7 +7,7 @@
 DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 
 # TODO Change this to real data
-ADVERSERIAL_DATA = (
+ADVERSARIAL_DATA = (
     torch.randn(2, 3, 32, 32).to(DEVICE),
     torch.randint(0, 9, (2,)).to(DEVICE),
 )
@@ -17,12 +17,16 @@ class BaseProfile:
     def __init__(
         self,
-        config_type: str,
-        epochs: int = 100,
+        sampler_type: str,
+        epochs: int,
+        *,
         is_partial_connection: bool = False,
+        partial_connector_config: dict | None = None,
         dropout: float | None = None,
+        sampler_sample_frequency: str = "step",
         perturbation: str | None = None,
         perturbator_sample_frequency: str = "epoch",
+        perturbator_config: dict | None = None,
         sampler_arch_combine_fn: str = "default",
         entangle_op_weights: bool = False,
         lora_rank: int = 0,
@@ -37,7 +41,8 @@ def __init__(
         prune_num_keeps: list[int] | None = None,
         is_arch_attention_enabled: bool = False,
     ) -> None:
-        self.config_type = config_type
+        self.sampler_type = sampler_type
+        self.sampler_sample_frequency = sampler_sample_frequency
         self.epochs = epochs
         self.lora_warm_epochs = lora_warm_epochs
         self.seed = seed
@@ -57,10 +62,14 @@ def __init__(
         self.entangle_op_weights = entangle_op_weights
         self._set_oles_configs(oles, calc_gm_score)
         self._set_pruner_configs(prune_epochs, prune_num_keeps)
-        PROFILE_TYPE = "BASE"
-        self.sampler_type = str.lower(PROFILE_TYPE)
         self.is_arch_attention_enabled = is_arch_attention_enabled
+
+        if partial_connector_config is not None:
+            self.configure_partial_connector(**partial_connector_config)
+
+        if perturbator_config is not None:
+            self.configure_perturbator(**perturbator_config)
 
     def _set_pruner_configs(
         self,
         prune_epochs: list[int] | None = None,
@@ -90,6 +99,7 @@ def _set_lora_configs(
     ) -> None:
         self.lora_config = {
             "r": lora_rank,
+            "lora_warm_epochs": lora_warm_epochs,
             "lora_dropout": lora_dropout,
             "lora_alpha": lora_alpha,
             "merge_weights": merge_weights,
@@ -115,8 +125,8 @@ def _set_perturb(
         perturb_type: str | None = None,
         perturbator_sample_frequency: str = "epoch",
     ) -> None:
-        assert perturbator_sample_frequency in ["epoch", "step"]
-        assert perturb_type in ["adverserial", "random", "none", None]
+        assert perturbator_sample_frequency in ["epoch", "step"], "Invalid frequency"
+        assert perturb_type in ["adversarial", "random", "none", None], "Invalid type"
         if perturb_type is None:
             self.perturb_type = "none"
         else:
@@ -174,10 +184,10 @@ def _initialize_sampler_config(self) -> None:
 
     @abstractmethod
     def _initialize_perturbation_config(self) -> None:
-        if self.perturb_type == "adverserial":
+        if self.perturb_type == "adversarial":
             perturb_config = {
                 "epsilon": 0.3,
-                "data": ADVERSERIAL_DATA,
+                "data": ADVERSARIAL_DATA,
                 "loss_criterion": torch.nn.CrossEntropyLoss(),
                 "steps": 20,
                 "random_start": True,
@@ -241,19 +251,19 @@ def _initialize_dropout_config(self) -> None:
         self.dropout_config = dropout_config
 
     def configure_sampler(self, **kwargs) -> None:  # type: ignore
-        assert self.sampler_config is not None
+        assert self.sampler_config is not None, "sampler_config is None"
         for config_key in kwargs:
             assert (
                 config_key in self.sampler_config  # type: ignore
             ), f"{config_key} not a valid configuration for the sampler of type \
-            {self.config_type}"
+            {self.sampler_type}"
             self.sampler_config[config_key] = kwargs[config_key]  # type: ignore
 
     def configure_perturbator(self, **kwargs) -> None:  # type: ignore
         assert (
             self.perturb_type != "none"
         ), "Perturbator is initialized with None, \
-            re-initialize with random or adverserial"
+            re-initialize with random or adversarial"
 
         for config_key in kwargs:
             assert (
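Since partial_connector_config and perturbator_config are now consumed once in BaseProfile.__init__, every profile subclass accepts them through **kwargs. A hedged usage sketch (assuming DARTSProfile is exported from confopt.profiles alongside the classes imported in the tests below):

    from confopt.profiles import DARTSProfile

    profile = DARTSProfile(
        epochs=50,
        perturbation="adversarial",
        perturbator_config={"epsilon": 0.3},  # validated by configure_perturbator
        is_partial_connection=True,
        partial_connector_config={"k": 4},  # validated by configure_partial_connector
    )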
diff --git a/src/confopt/profiles/profiles.py b/src/confopt/profiles/profiles.py
index a587a4ea..b765d3a4 100644
--- a/src/confopt/profiles/profiles.py
+++ b/src/confopt/profiles/profiles.py
@@ -11,129 +11,36 @@
 
 
 class DARTSProfile(BaseProfile, ABC):
-    def __init__(
+    def __init__(  # type: ignore
         self,
         epochs: int,
-        is_partial_connection: bool = False,
-        dropout: float | None = None,
-        perturbation: str | None = None,
-        sampler_sample_frequency: str = "step",
-        sampler_arch_combine_fn: str = "default",
-        perturbator_sample_frequency: str = "epoch",
-        partial_connector_config: dict | None = None,
-        perturbator_config: dict | None = None,
-        entangle_op_weights: bool = False,
-        lora_rank: int = 0,
-        lora_warm_epochs: int = 0,
-        lora_toggle_epochs: list[int] | None = None,
-        lora_toggle_probability: float | None = None,
-        seed: int = 100,
-        searchspace_str: str = "nb201",
-        oles: bool = False,
-        calc_gm_score: bool = False,
-        prune_epochs: list[int] | None = None,
-        prune_num_keeps: list[int] | None = None,
-        is_arch_attention_enabled: bool = False,
+        **kwargs,
     ) -> None:
-        PROFILE_TYPE = "DARTS"
-        self.sampler_sample_frequency = sampler_sample_frequency
-        super().__init__(
-            PROFILE_TYPE,
-            epochs,
-            is_partial_connection,
-            dropout,
-            perturbation,
-            perturbator_sample_frequency,
-            sampler_arch_combine_fn,
-            entangle_op_weights,
-            lora_rank,
-            lora_warm_epochs,
-            lora_toggle_epochs,
-            lora_toggle_probability,
-            seed,
-            searchspace_str,
-            oles,
-            calc_gm_score,
-            prune_epochs,
-            prune_num_keeps,
-            is_arch_attention_enabled,
-        )
-        self.sampler_type = str.lower(PROFILE_TYPE)
-
-        if partial_connector_config is not None:
-            self.configure_partial_connector(**partial_connector_config)
-
-        if perturbator_config is not None:
-            self.configure_perturbator(**perturbator_config)
+        super().__init__("darts", epochs, **kwargs)
 
     def _initialize_sampler_config(self) -> None:
         darts_config = {
             "sample_frequency": self.sampler_sample_frequency,
             "arch_combine_fn": self.sampler_arch_combine_fn,
         }
-        self.sampler_config = darts_config  # type: ignore
+        self.sampler_config = darts_config  # type: ignore
 
 
 class GDASProfile(BaseProfile, ABC):
-    PROFILE_TYPE = "GDAS"
-
-    def __init__(
+    def __init__(  # type: ignore
         self,
         epochs: int,
-        is_partial_connection: bool = False,
-        dropout: float | None = None,
-        perturbation: str | None = None,
-        sampler_sample_frequency: str = "step",
-        sampler_arch_combine_fn: str = "default",
-        perturbator_sample_frequency: str = "epoch",
         tau_min: float = 0.1,
         tau_max: float = 10,
-        partial_connector_config: dict | None = None,
-        perturbator_config: dict | None = None,
-        entangle_op_weights: bool = False,
-        lora_rank: int = 0,
-        lora_warm_epochs: int = 0,
-        lora_toggle_epochs: list[int] | None = None,
-        lora_toggle_probability: float | None = None,
-        seed: int = 100,
-        searchspace_str: str = "nb201",
-        oles: bool = False,
-        calc_gm_score: bool = False,
-        prune_epochs: list[int] | None = None,
-        prune_num_keeps: list[int] | None = None,
-        is_arch_attention_enabled: bool = False,
+        **kwargs,
    ) -> None:
-        self.sampler_sample_frequency = sampler_sample_frequency
         self.tau_min = tau_min
         self.tau_max = tau_max
         super().__init__(
-            self.PROFILE_TYPE,
+            "gdas",
             epochs,
-            is_partial_connection,
-            dropout,
-            perturbation,
-            perturbator_sample_frequency,
-            sampler_arch_combine_fn,
-            entangle_op_weights,
-            lora_rank,
-            lora_warm_epochs,
-            lora_toggle_epochs,
-            lora_toggle_probability,
-            seed,
-            searchspace_str,
-            oles,
-            calc_gm_score,
-            prune_epochs,
-            prune_num_keeps,
-            is_arch_attention_enabled,
+            **kwargs,
         )
-        self.sampler_type = str.lower(self.PROFILE_TYPE)
-
-        if partial_connector_config is not None:
-            self.configure_partial_connector(**partial_connector_config)
-
-        if perturbator_config is not None:
-            self.configure_perturbator(**perturbator_config)
 
     def _initialize_sampler_config(self) -> None:
         gdas_config = {
@@ -146,72 +53,40 @@ def _initialize_sampler_config(self) -> None:
 
 
 class ReinMaxProfile(GDASProfile):
-    PROFILE_TYPE = "REINMAX"
+    def __init__(  # type: ignore
+        self,
+        epochs: int,
+        tau_min: float = 0.1,
+        tau_max: float = 10,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            epochs,
+            tau_min,
+            tau_max,
+            **kwargs,
+        )
+        # GDASProfile.__init__ registers the sampler as "gdas"; override it so
+        # downstream SamplerType lookups resolve to ReinMax.
+        self.sampler_type = "reinmax"
 
 
 class SNASProfile(BaseProfile, ABC):
-    def __init__(
+    def __init__(  # type: ignore
         self,
         epochs: int,
-        is_partial_connection: bool = False,
-        dropout: float | None = None,
-        perturbation: str | None = None,
-        sampler_sample_frequency: str = "step",
-        sampler_arch_combine_fn: str = "default",
-        perturbator_sample_frequency: str = "epoch",
         temp_init: float = 1.0,
         temp_min: float = 0.33,
         temp_annealing: bool = True,
         total_epochs: int = 250,
-        partial_connector_config: dict | None = None,
-        perturbator_config: dict | None = None,
-        entangle_op_weights: bool = False,
-        lora_rank: int = 0,
-        lora_warm_epochs: int = 0,
-        lora_toggle_epochs: list[int] | None = None,
-        lora_toggle_probability: float | None = None,
-        seed: int = 100,
-        searchspace_str: str = "nb201",
-        oles: bool = False,
-        calc_gm_score: bool = False,
-        prune_epochs: list[int] | None = None,
-        prune_num_keeps: list[int] | None = None,
-        is_arch_attention_enabled: bool = False,
+        **kwargs,
     ) -> None:
-        PROFILE_TYPE = "SNAS"
-        self.sampler_sample_frequency = sampler_sample_frequency
         self.temp_init = temp_init
         self.temp_min = temp_min
         self.temp_annealing = temp_annealing
         self.total_epochs = total_epochs
-        super().__init__(  # type: ignore
-            PROFILE_TYPE,
+        super().__init__(  # type: ignore
+            "snas",
             epochs,
-            is_partial_connection,
-            dropout,
-            perturbation,
-            perturbator_sample_frequency,
-            sampler_arch_combine_fn,
-            entangle_op_weights,
-            lora_rank,
-            lora_warm_epochs,
-            lora_toggle_epochs,
-            lora_toggle_probability,
-            seed,
-            searchspace_str,
-            oles,
-            calc_gm_score,
-            prune_epochs,
-            prune_num_keeps,
-            is_arch_attention_enabled,
+            **kwargs,
         )
-        self.sampler_type = str.lower(PROFILE_TYPE)
-
-        if partial_connector_config is not None:
-            self.configure_partial_connector(**partial_connector_config)
-
-        if perturbator_config is not None:
-            self.configure_perturbator(**perturbator_config)
 
     def _initialize_sampler_config(self) -> None:
         snas_config = {
@@ -226,67 +101,15 @@
 
 
 class DRNASProfile(BaseProfile, ABC):
-    def __init__(
-        self,
-        epochs: int,
-        is_partial_connection: bool = False,
-        dropout: float | None = None,
-        perturbation: str | None = None,
-        sampler_sample_frequency: str = "step",
-        perturbator_sample_frequency: str = "epoch",
-        sampler_arch_combine_fn: str = "default",
-        partial_connector_config: dict | None = None,
-        perturbator_config: dict | None = None,
-        entangle_op_weights: bool = False,
-        lora_rank: int = 0,
-        lora_warm_epochs: int = 0,
-        lora_toggle_epochs: list[int] | None = None,
-        lora_toggle_probability: float | None = None,
-        seed: int = 100,
-        searchspace_str: str = "nb201",
-        oles: bool = False,
-        calc_gm_score: bool = False,
-        prune_epochs: list[int] | None = None,
-        prune_num_keeps: list[int] | None = None,
-        is_arch_attention_enabled: bool = False,
-    ) -> None:
-        PROFILE_TYPE = "DRNAS"
-        self.sampler_sample_frequency = sampler_sample_frequency
-        super().__init__(  # type: ignore
-            PROFILE_TYPE,
-            epochs,
-            is_partial_connection,
-            dropout,
-            perturbation,
-            perturbator_sample_frequency,
-            sampler_arch_combine_fn,
-            entangle_op_weights,
-            lora_rank,
-            lora_warm_epochs,
-            lora_toggle_epochs,
-            lora_toggle_probability,
-            seed,
-            searchspace_str,
-            oles,
-            calc_gm_score,
-            prune_epochs,
-            prune_num_keeps,
-            is_arch_attention_enabled,
-        )
-        self.sampler_type = str.lower(PROFILE_TYPE)
-
-        if partial_connector_config is not None:
-            self.configure_partial_connector(**partial_connector_config)
-
-        if perturbator_config is not None:
-            self.configure_perturbator(**perturbator_config)
+    def __init__(self, epochs: int, **kwargs) -> None:  # type: ignore
+        super().__init__("drnas", epochs, **kwargs)  # type: ignore
 
     def _initialize_sampler_config(self) -> None:
         drnas_config = {
             "sample_frequency": self.sampler_sample_frequency,
             "arch_combine_fn": self.sampler_arch_combine_fn,
         }
-        self.sampler_config = drnas_config  # type: ignore
+        self.sampler_config = drnas_config  # type: ignore
 
 
 class DiscreteProfile:
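With the boilerplate gone, a new profile reduces to its sampler-specific knobs: store them before calling super().__init__ (which, per the GDAS/SNAS pattern above, is what triggers _initialize_sampler_config), and forward everything else. A hypothetical subclass, for illustration only:

    from confopt.profiles import BaseProfile

    class MyProfile(BaseProfile):
        def __init__(self, epochs: int, temperature: float = 1.0, **kwargs) -> None:  # type: ignore
            # Sampler-specific state must exist before super().__init__ runs.
            self.temperature = temperature
            # "darts" is reused here; a genuinely new sampler would also need a
            # SamplerType entry in confopt.train.experiment.
            super().__init__("darts", epochs, **kwargs)

        def _initialize_sampler_config(self) -> None:
            self.sampler_config = {
                "sample_frequency": self.sampler_sample_frequency,
                "arch_combine_fn": self.sampler_arch_combine_fn,
                "temperature": self.temperature,
            }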
100644
--- a/src/confopt/train/experiment.py
+++ b/src/confopt/train/experiment.py
@@ -62,7 +62,7 @@
 DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 
 # TODO Change this to real data
-ADVERSERIAL_DATA = (
+ADVERSARIAL_DATA = (
     torch.randn(2, 3, 32, 32).to(DEVICE),
     torch.randint(0, 9, (2,)).to(DEVICE),
 )
@@ -92,7 +92,7 @@ class SamplerType(Enum):
 
 class PerturbatorType(Enum):
     RANDOM = "random"
-    ADVERSERIAL = "adverserial"
+    ADVERSARIAL = "adversarial"
     NONE = "none"
 
 
@@ -173,7 +173,6 @@ def train_supernet(
         config = profile.get_config()
         run_name = profile.get_name_wandb_run()
 
-        assert hasattr(profile, "sampler_type")
         self.sampler_str = SamplerType(profile.sampler_type)
         self.perturbator_str = PerturbatorType(profile.perturb_type)
         self.is_partial_connection = profile.is_partial_connection
@@ -847,7 +846,7 @@ def _train_discrete_model(
     parser.add_argument(
         "--perturbator",
         default="none",
-        help="Type of perturbation in (none, random, adverserial)",
+        help="Type of perturbation in (none, random, adversarial)",
         type=str,
     )
     parser.add_argument(
diff --git a/tests/test_oneshot.py b/tests/test_oneshot.py
index 8d481b5a..95241b57 100644
--- a/tests/test_oneshot.py
+++ b/tests/test_oneshot.py
@@ -242,7 +242,7 @@ def test_sdarts_perturbator(self) -> None:
         # Changes the model's alpha as well, but if the loss does not decrease, it does
         # not change alpha
         # TODO Improve this test
-        perturbator.attack_type = "adverserial"
+        perturbator.attack_type = "adversarial"
         alphas_before = [
             arch_param.clone() for arch_param in searchspace.arch_parameters
         ]
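Experiment.train_supernet resolves the profile's strings through these enums, so the spelling fix is load-bearing rather than cosmetic; the renamed member round-trips as expected:

    from confopt.train.experiment import PerturbatorType

    assert PerturbatorType("adversarial") is PerturbatorType.ADVERSARIAL
    assert PerturbatorType("none") is PerturbatorType.NONE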
diff --git a/tests/test_profiles.py b/tests/test_profiles.py
index 3906a801..c0249121 100644
--- a/tests/test_profiles.py
+++ b/tests/test_profiles.py
@@ -7,6 +7,7 @@
     BaseProfile,
     SNASProfile,
 )
+from confopt.profiles.profiles import ReinMaxProfile
 
 
 class TestBaseProfile(unittest.TestCase):
@@ -241,6 +242,55 @@ def test_sampler_change(self) -> None:
         with self.assertRaises(AssertionError):
             profile.configure_sampler(invalid_config="step")
 
+
+class TestReinMaxProfile(unittest.TestCase):
+    def test_initialization(self) -> None:
+        perturb_config = {"epsilon": 0.5}
+        partial_connector_config = {
+            "k": 2,
+        }
+        profile = ReinMaxProfile(
+            epochs=100,
+            is_partial_connection=True,
+            perturbation="random",
+            sampler_sample_frequency="step",
+            partial_connector_config=partial_connector_config,
+            perturbator_config=perturb_config,
+        )
+
+        assert profile.sampler_config is not None
+        assert profile.partial_connector_config["k"] == partial_connector_config["k"]
+        assert profile.perturb_config["epsilon"] == perturb_config["epsilon"]
+
+    def test_invalid_initialization(self) -> None:
+        perturb_config = {"invalid_config": 0.5}
+        partial_connector_config = {
+            "invalid_config": 2,
+        }
+
+        with self.assertRaises(AssertionError):
+            profile = ReinMaxProfile(  # noqa: F841
+                epochs=100,
+                is_partial_connection=True,
+                perturbation="random",
+                sampler_sample_frequency="step",
+                partial_connector_config=partial_connector_config,
+                perturbator_config=perturb_config,
+            )
+
+    def test_sampler_change(self) -> None:
+        profile = ReinMaxProfile(
+            epochs=100,
+            sampler_sample_frequency="step",
+        )
+        sampler_config = {"tau_max": 12, "tau_min": 0.3}
+        profile.configure_sampler(**sampler_config)
+
+        assert profile.sampler_config["tau_max"] == sampler_config["tau_max"]
+        assert profile.sampler_config["tau_min"] == sampler_config["tau_min"]
+
+        with self.assertRaises(AssertionError):
+            profile.configure_sampler(invalid_config="step")
+
 
 class TestSNASProfile(unittest.TestCase):
     def test_initialization(self) -> None: