From 1d4cbafcbfc85afb3f0a6b343a6c75c9a4ab66f9 Mon Sep 17 00:00:00 2001 From: abhash-er Date: Mon, 28 Apr 2025 18:46:20 +0200 Subject: [PATCH 1/3] feat(CompositeSampler): add composite sampler to confopt --- src/confopt/oneshot/archsampler/__init__.py | 2 + .../oneshot/archsampler/composite_sampler.py | 41 +++++++++++++++++++ 2 files changed, 43 insertions(+) create mode 100644 src/confopt/oneshot/archsampler/composite_sampler.py diff --git a/src/confopt/oneshot/archsampler/__init__.py b/src/confopt/oneshot/archsampler/__init__.py index cee3873e..122bff66 100644 --- a/src/confopt/oneshot/archsampler/__init__.py +++ b/src/confopt/oneshot/archsampler/__init__.py @@ -1,4 +1,5 @@ from .base_sampler import BaseSampler +from .composite_sampler import CompositeSampler from .darts.sampler import DARTSSampler from .drnas.sampler import DRNASSampler from .gdas.sampler import GDASSampler @@ -12,4 +13,5 @@ "GDASSampler", "SNASSampler", "ReinMaxSampler", + "CompositeSampler", ] diff --git a/src/confopt/oneshot/archsampler/composite_sampler.py b/src/confopt/oneshot/archsampler/composite_sampler.py new file mode 100644 index 00000000..e27297da --- /dev/null +++ b/src/confopt/oneshot/archsampler/composite_sampler.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +import torch + +from confopt.oneshot.archsampler import BaseSampler +from confopt.oneshot.base import OneShotComponent + + +class CompositeSampler(OneShotComponent): + def __init__( + self, + arch_samplers: list[BaseSampler], + arch_parameters: list[torch.Tensor], + ) -> None: + super().__init__() + self.arch_samplers = arch_samplers + self.arch_parameters = arch_parameters + + # get sample frequency from the samplers + self.sample_frequency = arch_samplers[0].sample_frequency + for sampler in arch_samplers: + assert ( + self.sample_frequency == sampler.sample_frequency + ), "All the sampler must have the same sample frequency" + + def sample(self, alpha: torch.Tensor) -> torch.Tensor: + sampled_alphas = alpha + for sampler in self.arch_samplers: + sampled_alphas = sampler.sample(sampled_alphas) + + return sampled_alphas + + def new_epoch(self) -> None: + super().new_epoch() + for sampler in self.arch_samplers: + sampler.new_epoch() + + def new_step(self) -> None: + super().new_step() + for sampler in self.arch_samplers: + sampler.new_step() From 8dfbc8a62337a5912f8afb82928197d471db8032 Mon Sep 17 00:00:00 2001 From: abhash-er Date: Mon, 28 Apr 2025 19:22:18 +0200 Subject: [PATCH 2/3] feat(CompositeProfile): add composite profile integrate composite sampler to confopt --- src/confopt/enums.py | 1 + src/confopt/profile/__init__.py | 2 + src/confopt/profile/profiles.py | 143 ++++++++++++++++++++++ src/confopt/train/experiment.py | 42 +++++-- src/confopt/train/search_space_handler.py | 9 +- 5 files changed, 184 insertions(+), 13 deletions(-) diff --git a/src/confopt/enums.py b/src/confopt/enums.py index 898699b8..eb5bc886 100644 --- a/src/confopt/enums.py +++ b/src/confopt/enums.py @@ -21,6 +21,7 @@ class SamplerType(Enum): GDAS = "gdas" SNAS = "snas" REINMAX = "reinmax" + COMPOSITE = "composite" def __str__(self) -> str: return self.value diff --git a/src/confopt/profile/__init__.py b/src/confopt/profile/__init__.py index ca3892ec..cb90fdcd 100644 --- a/src/confopt/profile/__init__.py +++ b/src/confopt/profile/__init__.py @@ -1,5 +1,6 @@ from .base import BaseProfile from .profiles import ( + CompositeProfile, DARTSProfile, DiscreteProfile, DRNASProfile, @@ -16,4 +17,5 @@ "SNASProfile", "DiscreteProfile", "ReinMaxProfile", + "CompositeProfile", ] diff --git a/src/confopt/profile/profiles.py b/src/confopt/profile/profiles.py index 72353017..36595463 100644 --- a/src/confopt/profile/profiles.py +++ b/src/confopt/profile/profiles.py @@ -3,6 +3,8 @@ from abc import ABC from typing import Any +from typing_extensions import override + from confopt.enums import SamplerType, SearchSpaceType from confopt.searchspace.darts.core.genotypes import DARTSGenotype from confopt.utils import get_num_classes @@ -402,6 +404,147 @@ def configure_sampler(self, **kwargs) -> None: # type: ignore super().configure_sampler(**kwargs) +class CompositeProfile(BaseProfile, ABC): + SAMPLER_TYPE = SamplerType.COMPOSITE + + def __init__( + self, + searchspace_type: str | SearchSpaceType, + samplers: list[str | SamplerType], + epochs: int, + # GDAS configs + tau_min: float = 0.1, + tau_max: float = 10, + # SNAS configs + temp_init: float = 1.0, + temp_min: float = 0.03, + temp_annealing: bool = True, + **kwargs: Any, + ) -> None: + self.samplers = [] + for sampler in samplers: + if isinstance(sampler, str): + sampler = SamplerType(sampler) + self.samplers.append(sampler) + + self.tau_min = tau_min + self.tau_max = tau_max + + self.temp_init = temp_init + self.temp_min = temp_min + self.temp_annealing = temp_annealing + + super().__init__( # type: ignore + self.SAMPLER_TYPE, + searchspace_type, + epochs, + **kwargs, + ) + + def _initialize_sampler_config(self) -> None: + """Initializes the sampler configuration for Composite samplers. + + The sampler configuration includes the sample frequency and the architecture + combine function. + + Args: + None + + Returns: + None + """ + self.sampler_config: dict[int, dict] = {} # type: ignore + for i, sampler in enumerate(self.samplers): + if sampler in [SamplerType.DARTS, SamplerType.DRNAS]: + config = { + "sampler_type": sampler, + "sample_frequency": self.sampler_sample_frequency, + "arch_combine_fn": self.sampler_arch_combine_fn, + } + elif sampler in [SamplerType.GDAS, SamplerType.REINMAX]: + config = { + "sampler_type": sampler, + "sample_frequency": self.sampler_sample_frequency, + "arch_combine_fn": self.sampler_arch_combine_fn, + "tau_min": self.tau_min, + "tau_max": self.tau_max, + } + elif sampler == SamplerType.SNAS: + config = { + "sampler_type": sampler, + "sample_frequency": self.sampler_sample_frequency, + "arch_combine_fn": self.sampler_arch_combine_fn, + "temp_init": self.temp_init, + "temp_min": self.temp_min, + "temp_annealing": self.temp_annealing, + "total_epochs": self.epochs, + } + else: + raise AttributeError(f"Illegal sampler type {sampler} provided!") + + self.sampler_config[i] = config + + @override + def configure_sampler( # type: ignore[override] + self, sampler_config_map: dict[int, dict] + ) -> None: + """Configures the sampler settings based on the provided configurations. + + Args: + sampler_config_map (dict[int, dict]): A dictionary where each key is an \ + integer representing the order of the sampler (zero-indexed), and \ + each value is a dictionary containing the configuration parameters \ + for that sampler. + + The inner configuration dictionary can contain different sets of keys + depending on the type of sampler being configured. The available keys + include: + + Generic Configurations: + - sample_frequency (str): The rate at which samples should be taken. + - arch_combine_fn (str): Function to combine architectures. + For FairDARTS, set this to 'sigmoid'. Default is 'default'. + + GDAS-specific Configurations: + - tau_min (float): Minimum temperature for sampling. + - tau_max (float): Maximum temperature for sampling. + + SNAS-specific Configurations: + - temp_init (float): Initial temperature for sampling. + - temp_min (float): Minimum temperature for sampling. + - temp_annealing (bool): Whether to apply temperature annealing. + - total_epochs (int): Total number of training epochs. + + The specific keys required in the dictionary depend on the type of + sampler being used. + Please make sure that the sample frequency of all the configurations are same. + Each configuration is validated, and an error is raised if unknown keys are + provided. + + Raises: + ValueError: If an unrecognized configuration key is detected. + + Returns: + None + """ + assert self.sampler_config is not None + + for idx in sampler_config_map: + for config_key in sampler_config_map[idx]: + assert idx in self.sampler_config + exists = False + sampler_type = self.sampler_config[idx]["sampler_type"] + if config_key in self.sampler_config[idx]: + exists = True + self.sampler_config[idx][config_key] = sampler_config_map[idx][ + config_key + ] + assert exists, ( + f"{config_key} is not a valid configuration for {sampler_type}", + "sampler inside composite sampler", + ) + + class DiscreteProfile: def __init__( self, diff --git a/src/confopt/train/experiment.py b/src/confopt/train/experiment.py index cfd2416a..c4609f80 100644 --- a/src/confopt/train/experiment.py +++ b/src/confopt/train/experiment.py @@ -40,6 +40,7 @@ ) from confopt.oneshot.archsampler import ( BaseSampler, + CompositeSampler, DARTSSampler, DRNASSampler, GDASSampler, @@ -374,17 +375,36 @@ def _set_sampler( config: dict, ) -> None: arch_params = self.search_space.arch_parameters - self.sampler: BaseSampler | None = None - if sampler == SamplerType.DARTS: - self.sampler = DARTSSampler(**config, arch_parameters=arch_params) - elif sampler == SamplerType.DRNAS: - self.sampler = DRNASSampler(**config, arch_parameters=arch_params) - elif sampler == SamplerType.GDAS: - self.sampler = GDASSampler(**config, arch_parameters=arch_params) - elif sampler == SamplerType.SNAS: - self.sampler = SNASSampler(**config, arch_parameters=arch_params) - elif sampler == SamplerType.REINMAX: - self.sampler = ReinMaxSampler(**config, arch_parameters=arch_params) + self.sampler: BaseSampler | CompositeSampler | None = None + + def _get_sampler_class(sampler: SamplerType) -> Callable: + if sampler == SamplerType.DARTS: + return DARTSSampler + if sampler == SamplerType.DRNAS: + return DRNASSampler + if sampler == SamplerType.GDAS: + return GDASSampler + if sampler == SamplerType.SNAS: + return SNASSampler + if sampler == SamplerType.REINMAX: + return ReinMaxSampler + + raise ValueError(f"Illegal sampler {sampler} provided") + + if sampler == SamplerType.COMPOSITE: + sub_samplers: list[BaseSampler] = [] + for _, sampler_config in config.items(): + sampler_type = sampler_config["sampler_type"] + del sampler_config["sampler_type"] + sampler_component = _get_sampler_class(sampler_type)( + **sampler_config, arch_parameters=arch_params + ) + sub_samplers.append(sampler_component) + self.sampler = CompositeSampler(sub_samplers, arch_parameters=arch_params) + else: + self.sampler = _get_sampler_class(sampler)( + **config, arch_parameters=arch_params + ) def _set_perturbator( self, diff --git a/src/confopt/train/search_space_handler.py b/src/confopt/train/search_space_handler.py index e720589e..25cd2ea8 100644 --- a/src/confopt/train/search_space_handler.py +++ b/src/confopt/train/search_space_handler.py @@ -2,7 +2,12 @@ import torch -from confopt.oneshot.archsampler import BaseSampler, DARTSSampler, GDASSampler +from confopt.oneshot.archsampler import ( + BaseSampler, + CompositeSampler, + DARTSSampler, + GDASSampler, +) from confopt.oneshot.dropout import Dropout from confopt.oneshot.lora_toggler import LoRAToggler from confopt.oneshot.partial_connector import PartialConnector @@ -23,7 +28,7 @@ class SearchSpaceHandler: def __init__( self, - sampler: BaseSampler, + sampler: BaseSampler | CompositeSampler, edge_normalization: bool = False, partial_connector: PartialConnector | None = None, perturbation: BasePerturbator | None = None, From 3535b0fef9f263e3c2deba4a5beb3ad90caeae60 Mon Sep 17 00:00:00 2001 From: abhash-er Date: Mon, 28 Apr 2025 19:22:55 +0200 Subject: [PATCH 3/3] test(test_sampler): add tests for composite sampler --- tests/test_sampler.py | 55 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/tests/test_sampler.py b/tests/test_sampler.py index 917593eb..55d14eb9 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -8,6 +8,7 @@ DRNASSampler, GDASSampler, ReinMaxSampler, + CompositeSampler, ) @@ -111,5 +112,59 @@ def test_post_sample_fn_sigmoid(self) -> None: _test_arch_combine_fn_default(sampler, alphas) +class TestCompositeSampler(unittest.TestCase): + def test_post_sample_fn_default(self) -> None: + alphas = [torch.randn(14, 8), torch.randn(14, 8)] + inner_samplers = [ + DRNASSampler(alphas, sample_frequency="epoch", arch_combine_fn="default"), + GDASSampler(alphas, sample_frequency="epoch", arch_combine_fn="default"), + DARTSSampler(alphas, sample_frequency="epoch", arch_combine_fn="default"), + ] + sampler = CompositeSampler(inner_samplers, alphas) + _test_arch_combine_fn_default(sampler, alphas) + + def test_post_sample_fn_sigmoid(self) -> None: + alphas = [torch.randn(14, 8), torch.randn(14, 8)] + inner_samplers = [ + DRNASSampler(alphas, sample_frequency="epoch", arch_combine_fn="sigmoid"), + GDASSampler(alphas, sample_frequency="epoch", arch_combine_fn="sigmoid"), + DARTSSampler(alphas, sample_frequency="epoch", arch_combine_fn="sigmoid"), + ] + sampler = CompositeSampler(inner_samplers, alphas) + _test_arch_combine_fn_sigmoid(sampler, alphas) + + def test_post_sample_fn_mixed(self) -> None: + alphas = [torch.randn(14, 8), torch.randn(14, 8)] + + # Case 1 + inner_samplers = [ + DRNASSampler(alphas, sample_frequency="epoch", arch_combine_fn="default"), + GDASSampler(alphas, sample_frequency="epoch", arch_combine_fn="sigmoid"), + DARTSSampler(alphas, sample_frequency="epoch", arch_combine_fn="default"), + ] + sampler = CompositeSampler(inner_samplers, alphas) + # The last sampler's combine function matters, + # and if its GDAS/Reinmax, it would get ignored + _test_arch_combine_fn_default(sampler, alphas) + + # Case 2 + inner_samplers = [ + DRNASSampler(alphas, sample_frequency="epoch", arch_combine_fn="default"), + DARTSSampler(alphas, sample_frequency="epoch", arch_combine_fn="sigmoid"), + GDASSampler(alphas, sample_frequency="epoch", arch_combine_fn="sigmoid"), + ] + sampler = CompositeSampler(inner_samplers, alphas) + _test_arch_combine_fn_default(sampler, alphas) + + # Case 3 + inner_samplers = [ + DRNASSampler(alphas, sample_frequency="epoch", arch_combine_fn="default"), + GDASSampler(alphas, sample_frequency="epoch", arch_combine_fn="sigmoid"), + DARTSSampler(alphas, sample_frequency="epoch", arch_combine_fn="sigmoid"), + ] + sampler = CompositeSampler(inner_samplers, alphas) + _test_arch_combine_fn_sigmoid(sampler, alphas) + + if __name__ == "__main__": unittest.main()